diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 172 |
1 files changed, 103 insertions, 69 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 01fdb708eb..1fc001b39f 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -7,6 +7,59 @@ local memoize = vim.func._memoize local M = {} +local function is_directive(name) + return string.sub(name, -1) == '!' +end + +---@nodoc +---@class vim.treesitter.query.ProcessedPredicate +---@field [1] string predicate name +---@field [2] boolean should match +---@field [3] (integer|string)[] the original predicate + +---@alias vim.treesitter.query.ProcessedDirective (integer|string)[] + +---@nodoc +---@class vim.treesitter.query.ProcessedPattern { +---@field predicates vim.treesitter.query.ProcessedPredicate[] +---@field directives vim.treesitter.query.ProcessedDirective[] + +--- Splits the query patterns into predicates and directives. +---@param patterns table<integer, (integer|string)[][]> +---@return table<integer, vim.treesitter.query.ProcessedPattern> +local function process_patterns(patterns) + ---@type table<integer, vim.treesitter.query.ProcessedPattern> + local processed_patterns = {} + + for k, pattern_list in pairs(patterns) do + ---@type vim.treesitter.query.ProcessedPredicate[] + local predicates = {} + ---@type vim.treesitter.query.ProcessedDirective[] + local directives = {} + + for _, pattern in ipairs(pattern_list) do + -- Note: tree-sitter strips the leading # from predicates for us. + local pred_name = pattern[1] + ---@cast pred_name string + + if is_directive(pred_name) then + table.insert(directives, pattern) + else + local should_match = true + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) + should_match = false + end + table.insert(predicates, { pred_name, should_match, pattern }) + end + end + + processed_patterns[k] = { predicates = predicates, directives = directives } + end + + return processed_patterns +end + ---@nodoc ---Parsed query, see |vim.treesitter.query.parse()| --- @@ -15,6 +68,7 @@ local M = {} ---@field captures string[] list of (unique) capture names defined in query ---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives) ---@field query TSQuery userdata query object +---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern> local Query = {} Query.__index = Query @@ -33,6 +87,7 @@ function Query.new(lang, ts_query) patterns = query_info.patterns, } self.captures = self.info.captures + self._processed_patterns = process_patterns(self.info.patterns) return self end @@ -751,67 +806,50 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end -local function xor(x, y) - return (x or y) and not (x and y) -end - -local function is_directive(name) - return string.sub(name, -1) == '!' -end - ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param predicates vim.treesitter.query.ProcessedPredicate[] +---@param captures table<integer, TSNode[]> ---@param source integer|string -function Query:match_preds(preds, pattern, captures, source) - for _, pred in pairs(preds) do - -- Here we only want to return if a predicate DOES NOT match, and - -- continue on the other case. This way unknown predicates will not be considered, - -- which allows some testing and easier user extensibility (#12173). - -- Also, tree-sitter strips the leading # from predicates for us. - local is_not = false - - -- Skip over directives... they will get processed after all the predicates. - if not is_directive(pred[1]) then - local pred_name = pred[1] - if pred_name:match('^not%-') then - pred_name = pred_name:sub(5) - is_not = true - end - - local handler = predicate_handlers[pred_name] - - if not handler then - error(string.format('No handler for %s', pred[1])) - return false - end - - local pred_matches = handler(captures, pattern, source, pred) +---@return boolean whether the predicates match +function Query:_match_predicates(predicates, pattern_i, captures, source) + for _, predicate in ipairs(predicates) do + local processed_name = predicate[1] + local should_match = predicate[2] + local orig_predicate = predicate[3] + + local handler = predicate_handlers[processed_name] + if not handler then + error(string.format('No handler for %s', orig_predicate[1])) + return false + end - if not xor(is_not, pred_matches) then - return false - end + local does_match = handler(captures, pattern_i, source, orig_predicate) + if does_match ~= should_match then + return false end end return true end ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param directives vim.treesitter.query.ProcessedDirective[] +---@param source integer|string +---@param captures table<integer, TSNode[]> ---@return vim.treesitter.query.TSMetadata metadata -function Query:apply_directives(preds, pattern, captures, source) +function Query:_apply_directives(directives, pattern_i, captures, source) ---@type vim.treesitter.query.TSMetadata local metadata = {} - for _, pred in pairs(preds) do - if is_directive(pred[1]) then - local handler = directive_handlers[pred[1]] + for _, directive in pairs(directives) do + local handler = directive_handlers[directive[1]] - if not handler then - error(string.format('No handler for %s', pred[1])) - end - - handler(captures, pattern, source, pred, metadata) + if not handler then + error(string.format('No handler for %s', directive[1])) end + + handler(captures, pattern_i, source, directive, metadata) end return metadata @@ -835,12 +873,6 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- @param match TSQueryMatch ---- @return integer -local function match_id_hash(_, match) - return (match:info()) -end - --- Iterates over all captures from all matches in {node}. --- --- {source} is required if the query contains predicates; then the caller @@ -897,7 +929,7 @@ function Query:iter_captures(node, source, start, stop) return end - local match_id, pattern = match:info() + local match_id, pattern_i = match:info() --- @type vim.treesitter.query.TSMetadata local metadata @@ -906,11 +938,14 @@ function Query:iter_captures(node, source, start, stop) end if not metadata then - local preds = self.info.patterns[pattern] - if preds then + metadata = {} + + local processed_pattern = self._processed_patterns[pattern_i] + if processed_pattern then local captures = match:captures() - if not self:match_preds(preds, pattern, captures, source) then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then cursor:remove_match(match_id) if end_line and captured_node:range() > end_line then return nil, captured_node, nil, nil @@ -918,9 +953,8 @@ function Query:iter_captures(node, source, start, stop) return iter(end_line) -- tail call: try next match end - metadata = self:apply_directives(preds, pattern, captures, source) - else - metadata = {} + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end highest_cached_match_id = math.max(highest_cached_match_id, match_id) @@ -988,20 +1022,20 @@ function Query:iter_matches(node, source, start, stop, opts) return end - local match_id, pattern = match:info() - local preds = self.info.patterns[pattern] + local match_id, pattern_i = match:info() + local processed_pattern = self._processed_patterns[pattern_i] local captures = match:captures() --- @type vim.treesitter.query.TSMetadata - local metadata - if preds then - if not self:match_preds(preds, pattern, captures, source) then + local metadata = {} + if processed_pattern then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then cursor:remove_match(match_id) return iter() -- tail call: try next match end - metadata = self:apply_directives(preds, pattern, captures, source) - else - metadata = {} + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end if opts.all == false then @@ -1012,11 +1046,11 @@ function Query:iter_matches(node, source, start, stop, opts) for k, v in pairs(captures or {}) do old_match[k] = v[#v] end - return pattern, old_match, metadata + return pattern_i, old_match, metadata end -- TODO(lewis6991): create a new function that returns {match, metadata} - return pattern, captures, metadata + return pattern_i, captures, metadata end return iter end |