diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 422 |
1 files changed, 300 insertions, 122 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index cd65c0d7f6..5bb9e07a82 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -290,47 +290,71 @@ function M.get_node_text(...) return vim.treesitter.get_node_text(...) end ----@alias TSMatch table<integer,TSNode> - ----@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean - --- Predicate handler receive the following arguments --- (match, pattern, bufnr, predicate) ----@type table<string,TSPredicate> -local predicate_handlers = { - ['eq?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then +--- Implementations of predicates that can optionally be prefixed with "any-". +--- +--- These functions contain the implementations for each predicate, correctly +--- handling the "any" vs "all" semantics. They are called from the +--- predicate_handlers table with the appropriate arguments for each predicate. +local impl = { + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['eq'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - local str ---@type string - if type(predicate[3]) == 'string' then - -- (#eq? @aa "foo") - str = predicate[3] - else - -- (#eq? @aa @bb) - str = vim.treesitter.get_node_text(match[predicate[3]], source) - end + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + local str ---@type string + if type(predicate[3]) == 'string' then + -- (#eq? @aa "foo") + str = predicate[3] + else + -- (#eq? @aa @bb) + local other = assert(match[predicate[3]]) + assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes') + str = vim.treesitter.get_node_text(other[1], source) + end - if node_text ~= str or str == nil then - return false + local res = str ~= nil and node_text == str + if any and res then + return true + elseif not any and not res then + return false + end end - return true + return not any end, - ['lua-match?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['lua-match'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local regex = predicate[3] - return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil + + for _, node in ipairs(nodes) do + local regex = predicate[3] + local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil + if any and res then + return true + elseif not any and not res then + return false + end + end + + return not any end, - ['match?'] = (function() + ['match'] = (function() local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true } local function check_magic(str) if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then @@ -347,85 +371,160 @@ local predicate_handlers = { end, }) - return function(match, _, source, pred) - ---@cast match TSMatch - local node = match[pred[2]] - if not node then + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + return function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - ---@diagnostic disable-next-line no-unknown - local regex = compiled_vim_regexes[pred[3]] - return regex:match_str(vim.treesitter.get_node_text(node, source)) + + for _, node in ipairs(nodes) do + local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex + local res = regex:match_str(vim.treesitter.get_node_text(node, source)) + if any and res then + return true + elseif not any and not res then + return false + end + end + return not any end end)(), - ['contains?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + --- @param match TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['contains'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - for i = 3, #predicate do - if string.find(node_text, predicate[i], 1, true) then - return true + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + for i = 3, #predicate do + local res = string.find(node_text, predicate[i], 1, true) + if any and res then + return true + elseif not any and not res then + return false + end end end - return false + return not any + end, +} + +---@class TSMatch +---@field pattern? integer +---@field active? boolean +---@field [integer] TSNode[] + +---@alias TSPredicate fun(match: TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean + +-- Predicate handler receive the following arguments +-- (match, pattern, bufnr, predicate) +---@type table<string,TSPredicate> +local predicate_handlers = { + ['eq?'] = function(match, _, source, predicate) + return impl['eq'](match, source, predicate, false) + end, + + ['any-eq?'] = function(match, _, source, predicate) + return impl['eq'](match, source, predicate, true) + end, + + ['lua-match?'] = function(match, _, source, predicate) + return impl['lua-match'](match, source, predicate, false) + end, + + ['any-lua-match?'] = function(match, _, source, predicate) + return impl['lua-match'](match, source, predicate, true) + end, + + ['match?'] = function(match, _, source, predicate) + return impl['match'](match, source, predicate, false) + end, + + ['any-match?'] = function(match, _, source, predicate) + return impl['match'](match, source, predicate, true) + end, + + ['contains?'] = function(match, _, source, predicate) + return impl['contains'](match, source, predicate, false) + end, + + ['any-contains?'] = function(match, _, source, predicate) + return impl['contains'](match, source, predicate, true) end, ['any-of?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - -- Since 'predicate' will not be used by callers of this function, use it - -- to store a string set built from the list of words to check against. - local string_set = predicate['string_set'] - if not string_set then - string_set = {} - for i = 3, #predicate do - ---@diagnostic disable-next-line:no-unknown - string_set[predicate[i]] = true + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + -- Since 'predicate' will not be used by callers of this function, use it + -- to store a string set built from the list of words to check against. + local string_set = predicate['string_set'] --- @type table<string, boolean> + if not string_set then + string_set = {} + for i = 3, #predicate do + string_set[predicate[i]] = true + end + predicate['string_set'] = string_set + end + + if string_set[node_text] then + return true end - predicate['string_set'] = string_set end - return string_set[node_text] + return false end, ['has-ancestor?'] = function(match, _, _, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local ancestor_types = {} - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end + for _, node in ipairs(nodes) do + local ancestor_types = {} --- @type table<string, boolean> + for _, type in ipairs({ unpack(predicate, 3) }) do + ancestor_types[type] = true + end - node = node:parent() - while node do - if ancestor_types[node:type()] then - return true + local cur = node:parent() + while cur do + if ancestor_types[cur:type()] then + return true + end + cur = cur:parent() end - node = node:parent() end return false end, ['has-parent?'] = function(match, _, _, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then - return true + for _, node in ipairs(nodes) do + if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then + return true + end end return false end, @@ -433,6 +532,7 @@ local predicate_handlers = { -- As we provide lua-match? also expose vim-match? predicate_handlers['vim-match?'] = predicate_handlers['match?'] +predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@class TSMetadata ---@field range? Range @@ -468,13 +568,17 @@ local directive_handlers = { -- Shifts the range of a node. -- Example: (#offset! @_node 0 1 0 -1) ['offset!'] = function(match, _, _, pred, metadata) - ---@cast pred integer[] - local capture_id = pred[2] + local capture_id = pred[2] --[[@as integer]] + local nodes = match[capture_id] + assert(#nodes == 1, '#offset! does not support captures on multiple nodes') + + local node = nodes[1] + if not metadata[capture_id] then metadata[capture_id] = {} end - local range = metadata[capture_id].range or { match[capture_id]:range() } + local range = metadata[capture_id].range or { node:range() } local start_row_offset = pred[3] or 0 local start_col_offset = pred[4] or 0 local end_row_offset = pred[5] or 0 @@ -498,7 +602,9 @@ local directive_handlers = { local id = pred[2] assert(type(id) == 'number') - local node = match[id] + local nodes = match[id] + assert(#nodes == 1, '#gsub! does not support captures on multiple nodes') + local node = nodes[1] local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' if not metadata[id] then @@ -518,10 +624,9 @@ local directive_handlers = { local capture_id = pred[2] assert(type(capture_id) == 'number') - local node = match[capture_id] - if not node then - return - end + local nodes = match[capture_id] + assert(#nodes == 1, '#trim! does not support captures on multiple nodes') + local node = nodes[1] local start_row, start_col, end_row, end_col = node:range() @@ -552,38 +657,93 @@ local directive_handlers = { --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[]) +---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - see |vim.treesitter.query.add_directive()| for argument meanings ----@param force boolean|nil -function M.add_predicate(name, handler, force) - if predicate_handlers[name] and not force then - error(string.format('Overriding %s', name)) +---@param opts table<string, any> Optional options: +--- - force (boolean): Override an existing +--- predicate of the same name +--- - all (boolean): Use the correct +--- implementation of the match table where +--- capture IDs map to a list of nodes instead +--- of a single node. Defaults to false (for +--- backward compatibility). This option will +--- eventually become the default and removed. +function M.add_predicate(name, handler, opts) + -- Backward compatibility: old signature had "force" as boolean argument + if type(opts) == 'boolean' then + opts = { force = opts } end - predicate_handlers[name] = handler + opts = opts or {} + + if predicate_handlers[name] and not opts.force then + error(string.format('Overriding existing predicate %s', name)) + end + + if opts.all then + predicate_handlers[name] = handler + else + --- @param match table<integer, TSNode[]> + local function wrapper(match, ...) + local m = {} ---@type table<integer, TSNode> + for k, v in pairs(match) do + if type(k) == 'number' then + m[k] = v[#v] + end + end + return handler(m, ...) + end + predicate_handlers[name] = wrapper + end end --- Adds a new directive to be used in queries --- --- Handlers can set match level data by setting directly on the ---- metadata object `metadata.key = value`, additionally, handlers +--- metadata object `metadata.key = value`. Additionally, handlers --- can set node level data by using the capture id on the --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table) ---- - match: see |treesitter-query| ---- - node-level data are accessible via `match[capture_id]` ---- - pattern: see |treesitter-query| +---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) +--- - match: A table mapping capture IDs to a list of captured nodes +--- - pattern: the index of the matching pattern in the query file --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` ----@param force boolean|nil -function M.add_directive(name, handler, force) - if directive_handlers[name] and not force then - error(string.format('Overriding %s', name)) +---@param opts table<string, any> Optional options: +--- - force (boolean): Override an existing +--- predicate of the same name +--- - all (boolean): Use the correct +--- implementation of the match table where +--- capture IDs map to a list of nodes instead +--- of a single node. Defaults to false (for +--- backward compatibility). This option will +--- eventually become the default and removed. +function M.add_directive(name, handler, opts) + -- Backward compatibility: old signature had "force" as boolean argument + if type(opts) == 'boolean' then + opts = { force = opts } + end + + opts = opts or {} + + if directive_handlers[name] and not opts.force then + error(string.format('Overriding existing directive %s', name)) end - directive_handlers[name] = handler + if opts.all then + directive_handlers[name] = handler + else + --- @param match table<integer, TSNode[]> + local function wrapper(match, ...) + local m = {} ---@type table<integer, TSNode> + for k, v in pairs(match) do + m[k] = v[#v] + end + handler(m, ...) + end + directive_handlers[name] = wrapper + end end --- Lists the currently available directives to use in queries. @@ -608,7 +768,7 @@ end ---@private ---@param match TSMatch ----@param pattern string +---@param pattern integer ---@param source integer|string function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] @@ -618,18 +778,14 @@ function Query:match_preds(match, pattern, source) -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). -- Also, tree-sitter strips the leading # from predicates for us. - local pred_name ---@type string - - local is_not ---@type boolean + local is_not = false -- Skip over directives... they will get processed after all the predicates. if not is_directive(pred[1]) then - if string.sub(pred[1], 1, 4) == 'not-' then - pred_name = string.sub(pred[1], 5) + local pred_name = pred[1] + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) is_not = true - else - pred_name = pred[1] - is_not = false end local handler = predicate_handlers[pred_name] @@ -724,7 +880,7 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) + local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, TSMatch local function iter(end_line) local capture, captured_node, match = raw_iter() local metadata = {} @@ -748,27 +904,34 @@ end --- Iterates the matches of self on a given range. --- ---- Iterate over all matches within a {node}. The arguments are the same as ---- for |Query:iter_captures()| but the iterated values are different: ---- an (1-based) index of the pattern in the query, a table mapping ---- capture indices to nodes, and metadata from any directives processing the match. ---- If the query has more than one pattern, the capture table might be sparse ---- and e.g. `pairs()` method should be used over `ipairs`. ---- Here is an example iterating over all captures in every match: +--- Iterate over all matches within a {node}. The arguments are the same as for +--- |Query:iter_captures()| but the iterated values are different: an (1-based) +--- index of the pattern in the query, a table mapping capture indices to a list +--- of nodes, and metadata from any directives processing the match. +--- +--- WARNING: Set `all=true` to ensure all matching nodes in a match are +--- returned, otherwise only the last node in a match is returned, breaking captures +--- involving quantifiers such as `(comment)+ @comment`. The default option +--- `all=false` is only provided for backward compatibility and will be removed +--- after Nvim 0.10. +--- +--- Example: --- --- ```lua ---- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do ---- for id, node in pairs(match) do +--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do +--- for id, nodes in pairs(match) do --- local name = query.captures[id] ---- -- `node` was captured by the `name` capture in the match +--- for _, node in ipairs(nodes) do +--- -- `node` was captured by the `name` capture in the match --- ---- local node_data = metadata[id] -- Node level metadata ---- ---- -- ... use the info here ... +--- local node_data = metadata[id] -- Node level metadata +--- ... use the info here ... +--- end --- end --- end --- ``` --- +--- ---@param node TSNode under which the search will occur ---@param source (integer|string) Source buffer or string to search ---@param start? integer Starting line for the search. Defaults to `node:start()`. @@ -776,17 +939,20 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. +--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is +--- incorrect behavior. This option will eventually become the default and removed. --- ----@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata +---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) + local all = opts and opts.all if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) - ---@cast raw_iter fun(): string, any + local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, TSMatch local function iter() local pattern, match = raw_iter() local metadata = {} @@ -799,6 +965,18 @@ function Query:iter_matches(node, source, start, stop, opts) self:apply_directives(match, pattern, source, metadata) end + + if not all then + -- Convert the match table into the old buggy version for backward + -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" + -- option! + local old_match = {} ---@type table<integer, TSNode> + for k, v in pairs(match or {}) do + old_match[k] = v[#v] + end + return pattern, old_match, metadata + end + return pattern, match, metadata end return iter |