diff options
author | vanaigr <vanaigranov@gmail.com> | 2024-12-18 01:06:41 -0600 |
---|---|---|
committer | vanaigr <vanaigranov@gmail.com> | 2025-01-06 00:35:19 -0600 |
commit | dd234135ad20119917831fd8ffcb19d8562022ca (patch) | |
tree | 9c26bb172ead60b9bdab2d847525edb3c4b34881 | |
parent | 8d2ee542a82a0d162198f27de316ddfc81e8761c (diff) | |
download | rneovim-dd234135ad20119917831fd8ffcb19d8562022ca.tar.gz rneovim-dd234135ad20119917831fd8ffcb19d8562022ca.tar.bz2 rneovim-dd234135ad20119917831fd8ffcb19d8562022ca.zip |
refactor: split predicates and directives
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 4 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 172 | ||||
-rw-r--r-- | test/benchmark/decor_spec.lua | 1 | ||||
-rw-r--r-- | test/functional/treesitter/query_spec.lua | 4 |
4 files changed, 108 insertions, 73 deletions
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 8ce8652f7d..96503c38ea 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -299,6 +299,8 @@ local function on_line_impl(self, buf, line, is_spell_nav) state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end + local captures = state.highlighter_query:query().captures + while line >= state.next_row do local capture, node, metadata, match = state.iter(line) @@ -311,7 +313,7 @@ local function on_line_impl(self, buf, line, is_spell_nav) if capture then local hl = state.highlighter_query:get_hl_from_capture(capture) - local capture_name = state.highlighter_query:query().captures[capture] + local capture_name = captures[capture] local spell, spell_pri_offset = get_spell(capture_name) diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 01fdb708eb..1fc001b39f 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -7,6 +7,59 @@ local memoize = vim.func._memoize local M = {} +local function is_directive(name) + return string.sub(name, -1) == '!' +end + +---@nodoc +---@class vim.treesitter.query.ProcessedPredicate +---@field [1] string predicate name +---@field [2] boolean should match +---@field [3] (integer|string)[] the original predicate + +---@alias vim.treesitter.query.ProcessedDirective (integer|string)[] + +---@nodoc +---@class vim.treesitter.query.ProcessedPattern { +---@field predicates vim.treesitter.query.ProcessedPredicate[] +---@field directives vim.treesitter.query.ProcessedDirective[] + +--- Splits the query patterns into predicates and directives. +---@param patterns table<integer, (integer|string)[][]> +---@return table<integer, vim.treesitter.query.ProcessedPattern> +local function process_patterns(patterns) + ---@type table<integer, vim.treesitter.query.ProcessedPattern> + local processed_patterns = {} + + for k, pattern_list in pairs(patterns) do + ---@type vim.treesitter.query.ProcessedPredicate[] + local predicates = {} + ---@type vim.treesitter.query.ProcessedDirective[] + local directives = {} + + for _, pattern in ipairs(pattern_list) do + -- Note: tree-sitter strips the leading # from predicates for us. + local pred_name = pattern[1] + ---@cast pred_name string + + if is_directive(pred_name) then + table.insert(directives, pattern) + else + local should_match = true + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) + should_match = false + end + table.insert(predicates, { pred_name, should_match, pattern }) + end + end + + processed_patterns[k] = { predicates = predicates, directives = directives } + end + + return processed_patterns +end + ---@nodoc ---Parsed query, see |vim.treesitter.query.parse()| --- @@ -15,6 +68,7 @@ local M = {} ---@field captures string[] list of (unique) capture names defined in query ---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives) ---@field query TSQuery userdata query object +---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern> local Query = {} Query.__index = Query @@ -33,6 +87,7 @@ function Query.new(lang, ts_query) patterns = query_info.patterns, } self.captures = self.info.captures + self._processed_patterns = process_patterns(self.info.patterns) return self end @@ -751,67 +806,50 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end -local function xor(x, y) - return (x or y) and not (x and y) -end - -local function is_directive(name) - return string.sub(name, -1) == '!' -end - ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param predicates vim.treesitter.query.ProcessedPredicate[] +---@param captures table<integer, TSNode[]> ---@param source integer|string -function Query:match_preds(preds, pattern, captures, source) - for _, pred in pairs(preds) do - -- Here we only want to return if a predicate DOES NOT match, and - -- continue on the other case. This way unknown predicates will not be considered, - -- which allows some testing and easier user extensibility (#12173). - -- Also, tree-sitter strips the leading # from predicates for us. - local is_not = false - - -- Skip over directives... they will get processed after all the predicates. - if not is_directive(pred[1]) then - local pred_name = pred[1] - if pred_name:match('^not%-') then - pred_name = pred_name:sub(5) - is_not = true - end - - local handler = predicate_handlers[pred_name] - - if not handler then - error(string.format('No handler for %s', pred[1])) - return false - end - - local pred_matches = handler(captures, pattern, source, pred) +---@return boolean whether the predicates match +function Query:_match_predicates(predicates, pattern_i, captures, source) + for _, predicate in ipairs(predicates) do + local processed_name = predicate[1] + local should_match = predicate[2] + local orig_predicate = predicate[3] + + local handler = predicate_handlers[processed_name] + if not handler then + error(string.format('No handler for %s', orig_predicate[1])) + return false + end - if not xor(is_not, pred_matches) then - return false - end + local does_match = handler(captures, pattern_i, source, orig_predicate) + if does_match ~= should_match then + return false end end return true end ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param directives vim.treesitter.query.ProcessedDirective[] +---@param source integer|string +---@param captures table<integer, TSNode[]> ---@return vim.treesitter.query.TSMetadata metadata -function Query:apply_directives(preds, pattern, captures, source) +function Query:_apply_directives(directives, pattern_i, captures, source) ---@type vim.treesitter.query.TSMetadata local metadata = {} - for _, pred in pairs(preds) do - if is_directive(pred[1]) then - local handler = directive_handlers[pred[1]] + for _, directive in pairs(directives) do + local handler = directive_handlers[directive[1]] - if not handler then - error(string.format('No handler for %s', pred[1])) - end - - handler(captures, pattern, source, pred, metadata) + if not handler then + error(string.format('No handler for %s', directive[1])) end + + handler(captures, pattern_i, source, directive, metadata) end return metadata @@ -835,12 +873,6 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- @param match TSQueryMatch ---- @return integer -local function match_id_hash(_, match) - return (match:info()) -end - --- Iterates over all captures from all matches in {node}. --- --- {source} is required if the query contains predicates; then the caller @@ -897,7 +929,7 @@ function Query:iter_captures(node, source, start, stop) return end - local match_id, pattern = match:info() + local match_id, pattern_i = match:info() --- @type vim.treesitter.query.TSMetadata local metadata @@ -906,11 +938,14 @@ function Query:iter_captures(node, source, start, stop) end if not metadata then - local preds = self.info.patterns[pattern] - if preds then + metadata = {} + + local processed_pattern = self._processed_patterns[pattern_i] + if processed_pattern then local captures = match:captures() - if not self:match_preds(preds, pattern, captures, source) then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then cursor:remove_match(match_id) if end_line and captured_node:range() > end_line then return nil, captured_node, nil, nil @@ -918,9 +953,8 @@ function Query:iter_captures(node, source, start, stop) return iter(end_line) -- tail call: try next match end - metadata = self:apply_directives(preds, pattern, captures, source) - else - metadata = {} + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end highest_cached_match_id = math.max(highest_cached_match_id, match_id) @@ -988,20 +1022,20 @@ function Query:iter_matches(node, source, start, stop, opts) return end - local match_id, pattern = match:info() - local preds = self.info.patterns[pattern] + local match_id, pattern_i = match:info() + local processed_pattern = self._processed_patterns[pattern_i] local captures = match:captures() --- @type vim.treesitter.query.TSMetadata - local metadata - if preds then - if not self:match_preds(preds, pattern, captures, source) then + local metadata = {} + if processed_pattern then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then cursor:remove_match(match_id) return iter() -- tail call: try next match end - metadata = self:apply_directives(preds, pattern, captures, source) - else - metadata = {} + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end if opts.all == false then @@ -1012,11 +1046,11 @@ function Query:iter_matches(node, source, start, stop, opts) for k, v in pairs(captures or {}) do old_match[k] = v[#v] end - return pattern, old_match, metadata + return pattern_i, old_match, metadata end -- TODO(lewis6991): create a new function that returns {match, metadata} - return pattern, captures, metadata + return pattern_i, captures, metadata end return iter end diff --git a/test/benchmark/decor_spec.lua b/test/benchmark/decor_spec.lua index 42b2d1e744..1b7e763a09 100644 --- a/test/benchmark/decor_spec.lua +++ b/test/benchmark/decor_spec.lua @@ -99,7 +99,6 @@ describe('decor perf', function() print('\nTotal ' .. fmt(total) .. '\nDecoration provider: ' .. fmt(provider)) end) - it('can handle full screen of highlighting', function() Screen.new(100, 51) diff --git a/test/functional/treesitter/query_spec.lua b/test/functional/treesitter/query_spec.lua index 634f8af83d..6e21ed1d99 100644 --- a/test/functional/treesitter/query_spec.lua +++ b/test/functional/treesitter/query_spec.lua @@ -835,9 +835,9 @@ void ui_refresh(void) local result = exec_lua(function() local query0 = vim.treesitter.query.parse('c', query) - local match_preds = query0.match_preds + local match_preds = query0._match_predicates local called = 0 - function query0:match_preds(...) + function query0:_match_predicates(...) called = called + 1 return match_preds(self, ...) end |