diff options
author | Justin M. Keyes <justinkz@gmail.com> | 2025-01-06 07:09:34 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-06 07:09:34 -0800 |
commit | 3d9fb975b957c12d412ddc508998cf59aac2a4d7 (patch) | |
tree | 673b598c36c90019cb2dfcb03c4b7dab0e1120da /runtime/lua | |
parent | 86770108e2c6e08c2b8b95f1611923ba99b854dd (diff) | |
parent | dd234135ad20119917831fd8ffcb19d8562022ca (diff) | |
download | rneovim-3d9fb975b957c12d412ddc508998cf59aac2a4d7.tar.gz rneovim-3d9fb975b957c12d412ddc508998cf59aac2a4d7.tar.bz2 rneovim-3d9fb975b957c12d412ddc508998cf59aac2a4d7.zip |
Merge #31625 perf(decor): improve iter_captures() cache
Diffstat (limited to 'runtime/lua')
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 4 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 221 |
2 files changed, 135 insertions, 90 deletions
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 8ce8652f7d..96503c38ea 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -299,6 +299,8 @@ local function on_line_impl(self, buf, line, is_spell_nav) state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end + local captures = state.highlighter_query:query().captures + while line >= state.next_row do local capture, node, metadata, match = state.iter(line) @@ -311,7 +313,7 @@ local function on_line_impl(self, buf, line, is_spell_nav) if capture then local hl = state.highlighter_query:get_hl_from_capture(capture) - local capture_name = state.highlighter_query:query().captures[capture] + local capture_name = captures[capture] local spell, spell_pri_offset = get_spell(capture_name) diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 2b3b9096a6..1fc001b39f 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -7,6 +7,59 @@ local memoize = vim.func._memoize local M = {} +local function is_directive(name) + return string.sub(name, -1) == '!' +end + +---@nodoc +---@class vim.treesitter.query.ProcessedPredicate +---@field [1] string predicate name +---@field [2] boolean should match +---@field [3] (integer|string)[] the original predicate + +---@alias vim.treesitter.query.ProcessedDirective (integer|string)[] + +---@nodoc +---@class vim.treesitter.query.ProcessedPattern { +---@field predicates vim.treesitter.query.ProcessedPredicate[] +---@field directives vim.treesitter.query.ProcessedDirective[] + +--- Splits the query patterns into predicates and directives. +---@param patterns table<integer, (integer|string)[][]> +---@return table<integer, vim.treesitter.query.ProcessedPattern> +local function process_patterns(patterns) + ---@type table<integer, vim.treesitter.query.ProcessedPattern> + local processed_patterns = {} + + for k, pattern_list in pairs(patterns) do + ---@type vim.treesitter.query.ProcessedPredicate[] + local predicates = {} + ---@type vim.treesitter.query.ProcessedDirective[] + local directives = {} + + for _, pattern in ipairs(pattern_list) do + -- Note: tree-sitter strips the leading # from predicates for us. + local pred_name = pattern[1] + ---@cast pred_name string + + if is_directive(pred_name) then + table.insert(directives, pattern) + else + local should_match = true + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) + should_match = false + end + table.insert(predicates, { pred_name, should_match, pattern }) + end + end + + processed_patterns[k] = { predicates = predicates, directives = directives } + end + + return processed_patterns +end + ---@nodoc ---Parsed query, see |vim.treesitter.query.parse()| --- @@ -15,6 +68,7 @@ local M = {} ---@field captures string[] list of (unique) capture names defined in query ---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives) ---@field query TSQuery userdata query object +---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern> local Query = {} Query.__index = Query @@ -33,6 +87,7 @@ function Query.new(lang, ts_query) patterns = query_info.patterns, } self.captures = self.info.captures + self._processed_patterns = process_patterns(self.info.patterns) return self end @@ -751,84 +806,50 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end -local function xor(x, y) - return (x or y) and not (x and y) -end - -local function is_directive(name) - return string.sub(name, -1) == '!' -end - ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param predicates vim.treesitter.query.ProcessedPredicate[] +---@param captures table<integer, TSNode[]> ---@param source integer|string -function Query:match_preds(match, source) - local _, pattern = match:info() - local preds = self.info.patterns[pattern] - - if not preds then - return true - end - - local captures = match:captures() - - for _, pred in pairs(preds) do - -- Here we only want to return if a predicate DOES NOT match, and - -- continue on the other case. This way unknown predicates will not be considered, - -- which allows some testing and easier user extensibility (#12173). - -- Also, tree-sitter strips the leading # from predicates for us. - local is_not = false - - -- Skip over directives... they will get processed after all the predicates. - if not is_directive(pred[1]) then - local pred_name = pred[1] - if pred_name:match('^not%-') then - pred_name = pred_name:sub(5) - is_not = true - end - - local handler = predicate_handlers[pred_name] - - if not handler then - error(string.format('No handler for %s', pred[1])) - return false - end - - local pred_matches = handler(captures, pattern, source, pred) +---@return boolean whether the predicates match +function Query:_match_predicates(predicates, pattern_i, captures, source) + for _, predicate in ipairs(predicates) do + local processed_name = predicate[1] + local should_match = predicate[2] + local orig_predicate = predicate[3] + + local handler = predicate_handlers[processed_name] + if not handler then + error(string.format('No handler for %s', orig_predicate[1])) + return false + end - if not xor(is_not, pred_matches) then - return false - end + local does_match = handler(captures, pattern_i, source, orig_predicate) + if does_match ~= should_match then + return false end end return true end ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param directives vim.treesitter.query.ProcessedDirective[] +---@param source integer|string +---@param captures table<integer, TSNode[]> ---@return vim.treesitter.query.TSMetadata metadata -function Query:apply_directives(match, source) +function Query:_apply_directives(directives, pattern_i, captures, source) ---@type vim.treesitter.query.TSMetadata local metadata = {} - local _, pattern = match:info() - local preds = self.info.patterns[pattern] - if not preds then - return metadata - end + for _, directive in pairs(directives) do + local handler = directive_handlers[directive[1]] - local captures = match:captures() - - for _, pred in pairs(preds) do - if is_directive(pred[1]) then - local handler = directive_handlers[pred[1]] - - if not handler then - error(string.format('No handler for %s', pred[1])) - end - - handler(captures, pattern, source, pred, metadata) + if not handler then + error(string.format('No handler for %s', directive[1])) end + + handler(captures, pattern_i, source, directive, metadata) end return metadata @@ -852,12 +873,6 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- @param match TSQueryMatch ---- @return integer -local function match_id_hash(_, match) - return (match:info()) -end - --- Iterates over all captures from all matches in {node}. --- --- {source} is required if the query contains predicates; then the caller @@ -902,8 +917,10 @@ function Query:iter_captures(node, source, start, stop) local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) - local apply_directives = memoize(match_id_hash, self.apply_directives, false) - local match_preds = memoize(match_id_hash, self.match_preds, false) + -- For faster checks that a match is not in the cache. + local highest_cached_match_id = -1 + ---@type table<integer, vim.treesitter.query.TSMetadata> + local match_cache = {} local function iter(end_line) local capture, captured_node, match = cursor:next_capture() @@ -912,16 +929,37 @@ function Query:iter_captures(node, source, start, stop) return end - if not match_preds(self, match, source) then - local match_id = match:info() - cursor:remove_match(match_id) - if end_line and captured_node:range() > end_line then - return nil, captured_node, nil, nil - end - return iter(end_line) -- tail call: try next match + local match_id, pattern_i = match:info() + + --- @type vim.treesitter.query.TSMetadata + local metadata + if match_id <= highest_cached_match_id then + metadata = match_cache[match_id] end - local metadata = apply_directives(self, match, source) + if not metadata then + metadata = {} + + local processed_pattern = self._processed_patterns[pattern_i] + if processed_pattern then + local captures = match:captures() + + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then + cursor:remove_match(match_id) + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil, nil + end + return iter(end_line) -- tail call: try next match + end + + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) + end + + highest_cached_match_id = math.max(highest_cached_match_id, match_id) + match_cache[match_id] = metadata + end return capture, captured_node, metadata, match end @@ -984,17 +1022,22 @@ function Query:iter_matches(node, source, start, stop, opts) return end - local match_id, pattern = match:info() + local match_id, pattern_i = match:info() + local processed_pattern = self._processed_patterns[pattern_i] + local captures = match:captures() - if not self:match_preds(match, source) then - cursor:remove_match(match_id) - return iter() -- tail call: try next match + --- @type vim.treesitter.query.TSMetadata + local metadata = {} + if processed_pattern then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match + end + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end - local metadata = self:apply_directives(match, source) - - local captures = match:captures() - if opts.all == false then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow, but we only do it when the caller explicitly opted into it by @@ -1003,11 +1046,11 @@ function Query:iter_matches(node, source, start, stop, opts) for k, v in pairs(captures or {}) do old_match[k] = v[#v] end - return pattern, old_match, metadata + return pattern_i, old_match, metadata end -- TODO(lewis6991): create a new function that returns {match, metadata} - return pattern, captures, metadata + return pattern_i, captures, metadata end return iter end |