diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 225 |
1 file changed, 128 insertions, 97 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index a086f5e876..ef5c2143a7 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,5 +1,6 @@ local api = vim.api local language = require('vim.treesitter.language') +local memoize = vim.func._memoize local M = {} @@ -88,7 +89,7 @@ end --- ---@param lang string Language to get query for ---@param query_name string Name of the query to load (e.g., "highlights") ----@param is_included (boolean|nil) Internal parameter, most of the time left as `nil` +---@param is_included? boolean Internal parameter, most of the time left as `nil` ---@return string[] query_files List of files to load for given query and language function M.get_files(lang, query_name, is_included) local query_path = string.format('queries/%s/%s.scm', lang, query_name) @@ -211,8 +212,8 @@ end ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g. "highlights") --- ----@return vim.treesitter.Query|nil : Parsed query. `nil` if no query files are found. -M.get = vim.func._memoize('concat-2', function(lang, query_name) +---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found. +M.get = memoize('concat-2', function(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] end @@ -242,10 +243,10 @@ end) ---@param lang string Language to use for the query ---@param query string Query in s-expr syntax --- ----@return vim.treesitter.Query Parsed query +---@return vim.treesitter.Query : Parsed query --- ----@see |vim.treesitter.query.get()| -M.parse = vim.func._memoize('concat-2', function(lang, query) +---@see [vim.treesitter.query.get()] +M.parse = memoize('concat-2', function(lang, query) language.add(lang) local ts_query = vim._ts_parse_query(lang, query) @@ -258,7 +259,7 @@ end) --- handling the "any" vs "all" semantics. 
They are called from the --- predicate_handlers table with the appropriate arguments for each predicate. local impl = { - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -293,7 +294,7 @@ local impl = { return not any end, - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -333,7 +334,7 @@ local impl = { end, }) - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -356,7 +357,7 @@ local impl = { end end)(), - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -383,13 +384,7 @@ local impl = { end, } ----@nodoc ----@class vim.treesitter.query.TSMatch ----@field pattern? integer ----@field active? 
boolean ----@field [integer] TSNode[] - ----@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean +---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -462,17 +457,8 @@ local predicate_handlers = { end for _, node in ipairs(nodes) do - local ancestor_types = {} --- @type table<string, boolean> - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end - - local cur = node:parent() - while cur do - if ancestor_types[cur:type()] then - return true - end - cur = cur:parent() + if node:__has_ancestor(predicate) then + return true end end return false @@ -504,7 +490,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@field [integer] vim.treesitter.query.TSMetadata ---@field [string] integer|string ----@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) +---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -534,6 +520,9 @@ local directive_handlers = { ['offset!'] = function(match, _, _, pred, metadata) local capture_id = pred[2] --[[@as integer]] local nodes = match[capture_id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#offset! does not support captures on multiple nodes') local node = nodes[1] @@ -567,6 +556,9 @@ local directive_handlers = { assert(type(id) == 'number') local nodes = match[id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#gsub! 
does not support captures on multiple nodes') local node = nodes[1] local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' @@ -589,6 +581,9 @@ local directive_handlers = { assert(type(capture_id) == 'number') local nodes = match[capture_id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#trim! does not support captures on multiple nodes') local node = nodes[1] @@ -618,20 +613,23 @@ local directive_handlers = { end, } +--- @class vim.treesitter.query.add_predicate.Opts +--- @inlinedoc +--- +--- Override an existing predicate of the same name +--- @field force? boolean +--- +--- Use the correct implementation of the match table where capture IDs map to +--- a list of nodes instead of a single node. Defaults to false (for backward +--- compatibility). This option will eventually become the default and removed. +--- @field all? boolean + --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) +---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - see |vim.treesitter.query.add_directive()| for argument meanings ----@param opts table<string, any> Optional options: ---- - force (boolean): Override an existing ---- predicate of the same name ---- - all (boolean): Use the correct ---- implementation of the match table where ---- capture IDs map to a list of nodes instead ---- of a single node. Defaults to false (for ---- backward compatibility). This option will ---- eventually become the default and removed. 
+---@param opts vim.treesitter.query.add_predicate.Opts function M.add_predicate(name, handler, opts) -- Backward compatibility: old signature had "force" as boolean argument if type(opts) == 'boolean' then @@ -669,20 +667,12 @@ end --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) +---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - match: A table mapping capture IDs to a list of captured nodes --- - pattern: the index of the matching pattern in the query file --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` ----@param opts table<string, any> Optional options: ---- - force (boolean): Override an existing ---- predicate of the same name ---- - all (boolean): Use the correct ---- implementation of the match table where ---- capture IDs map to a list of nodes instead ---- of a single node. Defaults to false (for ---- backward compatibility). This option will ---- eventually become the default and removed. +---@param opts vim.treesitter.query.add_predicate.Opts function M.add_directive(name, handler, opts) -- Backward compatibility: old signature had "force" as boolean argument if type(opts) == 'boolean' then @@ -711,13 +701,13 @@ function M.add_directive(name, handler, opts) end --- Lists the currently available directives to use in queries. ----@return string[] List of supported directives. +---@return string[] : Supported directives. function M.list_directives() return vim.tbl_keys(directive_handlers) end --- Lists the currently available predicates to use in queries. ----@return string[] List of supported predicates. 
+---@return string[] : Supported predicates. function M.list_predicates() return vim.tbl_keys(predicate_handlers) end @@ -731,13 +721,19 @@ local function is_directive(name) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param pattern integer +---@param match TSQueryMatch ---@param source integer|string -function Query:match_preds(match, pattern, source) +function Query:match_preds(match, source) + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return true + end + + local captures = match:captures() + + for _, pred in pairs(preds) do -- Here we only want to return if a predicate DOES NOT match, and -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). @@ -759,7 +755,7 @@ function Query:match_preds(match, pattern, source) return false end - local pred_matches = handler(match, pattern, source, pred) + local pred_matches = handler(captures, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -770,30 +766,40 @@ function Query:match_preds(match, pattern, source) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param metadata vim.treesitter.query.TSMetadata -function Query:apply_directives(match, pattern, source, metadata) +---@param match TSQueryMatch +---@return vim.treesitter.query.TSMetadata metadata +function Query:apply_directives(match, source) + ---@type vim.treesitter.query.TSMetadata + local metadata = {} + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return metadata + end + + local captures = match:captures() + + for _, pred in pairs(preds) do if is_directive(pred[1]) then local handler = directive_handlers[pred[1]] if not handler then error(string.format('No handler for %s', pred[1])) - return end - handler(match, pattern, 
source, pred, metadata) + handler(captures, pattern, source, pred, metadata) end end + + return metadata end --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. ----@param start integer|nil ----@param stop integer|nil +---@param start integer? +---@param stop integer? ---@param node TSNode ---@return integer, integer local function value_or_node_range(start, stop, node) @@ -807,6 +813,12 @@ local function value_or_node_range(start, stop, node) return start, stop end +--- @param match TSQueryMatch +--- @return integer +local function match_id_hash(_, match) + return (match:info()) +end + --- Iterate over all captures from all matches inside {node} --- --- {source} is needed if the query contains predicates; then the caller @@ -816,12 +828,13 @@ end --- as the {node}, i.e., to get syntax highlight matches in the current --- viewport). When omitted, the {start} and {stop} row values are used from the given node. --- ---- The iterator returns three values: a numeric id identifying the capture, ---- the captured node, and metadata from any directives processing the match. +--- The iterator returns four values: a numeric id identifying the capture, +--- the captured node, metadata from any directives processing the match, +--- and the match itself. --- The following example shows how to get captures by name: --- --- ```lua ---- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do +--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do --- local name = query.captures[id] -- name of the capture in the query --- -- typically useful info about the node: --- local type = node:type() -- type of the captured node @@ -835,8 +848,10 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). 
Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata): ---- capture id, capture node, metadata +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch): +--- capture id, capture node, metadata, match +--- +---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -844,24 +859,30 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) + + local apply_directives = memoize(match_id_hash, self.apply_directives, true) + local match_preds = memoize(match_id_hash, self.match_preds, true) + local function iter(end_line) - local capture, captured_node, match = raw_iter() - local metadata = {} - - if match ~= nil then - local active = self:match_preds(match, match.pattern, source) - match.active = active - if not active then - if end_line and captured_node:range() > end_line then - return nil, captured_node, nil - end - return iter(end_line) -- tail call: try next match - end + local capture, captured_node, match = cursor:next_capture() - self:apply_directives(match, match.pattern, source, metadata) + if not capture then + return end - return capture, captured_node, metadata + + if not match_preds(self, match, source) then + local match_id = match:info() + cursor:remove_match(match_id) + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil, nil + end + return iter(end_line) -- tail call: try next match + end + + local metadata = apply_directives(self, match, source) + 
+ return capture, captured_node, metadata, match end return iter end @@ -903,45 +924,55 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256). --- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. --- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is --- incorrect behavior. This option will eventually become the default and removed. --- ---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) - local all = opts and opts.all + opts = opts or {} + opts.match_limit = opts.match_limit or 256 + if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) + local function iter() - local pattern, match = raw_iter() - local metadata = {} + local match = cursor:next_match() - if match ~= nil then - local active = self:match_preds(match, pattern, source) - if not active then - return iter() -- tail call: try next match - end + if not match then + return + end - self:apply_directives(match, pattern, source, metadata) + local match_id, pattern = match:info() + + if not self:match_preds(match, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match end - if not all then + local metadata = self:apply_directives(match, source) + + local captures = match:captures() + + if not opts.all then -- Convert the match table into the old buggy version for backward -- 
compatibility. This is slow. Plugin authors, if you're reading this, set the "all" -- option! local old_match = {} ---@type table<integer, TSNode> - for k, v in pairs(match or {}) do + for k, v in pairs(captures or {}) do old_match[k] = v[#v] end return pattern, old_match, metadata end - return pattern, match, metadata + -- TODO(lewis6991): create a new function that returns {match, metadata} + return pattern, captures, metadata end return iter end |