aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua
diff options
context:
space:
mode:
authorThomas Vigouroux <thomas.vigouroux@protonmail.com>2024-02-16 18:54:47 +0100
committerGitHub <noreply@github.com>2024-02-16 11:54:47 -0600
commitbd5008de07d29a6457ddc7fe13f9f85c9c4619d2 (patch)
tree1c73e5c0bdefb1fa635afdae86516219a7c34fff /runtime/lua
parent1ba3500abdb23027b7ba9bcc9b4f697dcd5ad886 (diff)
downloadrneovim-bd5008de07d29a6457ddc7fe13f9f85c9c4619d2.tar.gz
rneovim-bd5008de07d29a6457ddc7fe13f9f85c9c4619d2.tar.bz2
rneovim-bd5008de07d29a6457ddc7fe13f9f85c9c4619d2.zip
fix(treesitter): correctly handle query quantifiers (#24738)
Query patterns can contain quantifiers (e.g. (foo)+ @bar), so a single capture can map to multiple nodes. The iter_matches API can not handle this situation because the match table incorrectly maps capture indices to a single node instead of to an array of nodes. The match table should be updated to map capture indices to an array of nodes. However, this is a massively breaking change, so must be done with a proper deprecation period. `iter_matches`, `add_predicate` and `add_directive` must opt-in to the correct behavior for backward compatibility. This is done with a new "all" option. This option will become the default and removed after the 0.10 release. Co-authored-by: Christian Clason <c.clason@uni-graz.at> Co-authored-by: MDeiml <matthias@deiml.net> Co-authored-by: Gregory Anders <greg@gpanders.com>
Diffstat (limited to 'runtime/lua')
-rw-r--r--runtime/lua/vim/treesitter/_meta.lua4
-rw-r--r--runtime/lua/vim/treesitter/_query_linter.lua28
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua28
-rw-r--r--runtime/lua/vim/treesitter/query.lua422
4 files changed, 335 insertions, 147 deletions
diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua
index 6a714de052..0b285d2d7f 100644
--- a/runtime/lua/vim/treesitter/_meta.lua
+++ b/runtime/lua/vim/treesitter/_meta.lua
@@ -39,7 +39,7 @@ local TSNode = {}
---@param start? integer
---@param end_? integer
---@param opts? table
----@return fun(): integer, TSNode, any
+---@return fun(): integer, TSNode, TSMatch
function TSNode:_rawquery(query, captures, start, end_, opts) end
---@param query TSQuery
@@ -47,7 +47,7 @@ function TSNode:_rawquery(query, captures, start, end_, opts) end
---@param start? integer
---@param end_? integer
---@param opts? table
----@return fun(): integer, any
+---@return fun(): integer, TSMatch
function TSNode:_rawquery(query, captures, start, end_, opts) end
---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)
diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua
index 8651e187c2..378e9c67aa 100644
--- a/runtime/lua/vim/treesitter/_query_linter.lua
+++ b/runtime/lua/vim/treesitter/_query_linter.lua
@@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
end)
--- @param buf integer
---- @param match table<integer,TSNode>
+--- @param match table<integer,TSNode[]>
--- @param query Query
--- @param lang_context QueryLinterLanguageContext
--- @param diagnostics Diagnostic[]
@@ -130,20 +130,22 @@ local function lint_match(buf, match, query, lang_context, diagnostics)
local lang = lang_context.lang
local parser_info = lang_context.parser_info
- for id, node in pairs(match) do
- local cap_id = query.captures[id]
+ for id, nodes in pairs(match) do
+ for _, node in ipairs(nodes) do
+ local cap_id = query.captures[id]
- -- perform language-independent checks only for first lang
- if lang_context.is_first_lang and cap_id == 'error' then
- local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
- add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
- end
+ -- perform language-independent checks only for first lang
+ if lang_context.is_first_lang and cap_id == 'error' then
+ local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
+ add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
+ end
- -- other checks rely on Neovim parser introspection
- if lang and parser_info and cap_id == 'toplevel' then
- local err = parse(node, buf, lang)
- if err then
- add_lint_for_node(diagnostics, err.range, err.msg, lang)
+ -- other checks rely on Neovim parser introspection
+ if lang and parser_info and cap_id == 'toplevel' then
+ local err = parse(node, buf, lang)
+ if err then
+ add_lint_for_node(diagnostics, err.range, err.msg, lang)
+ end
end
end
end
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index 971c4449e8..79566f5eeb 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -784,7 +784,7 @@ end
---@private
--- Extract injections according to:
--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
----@param match table<integer,TSNode>
+---@param match table<integer,TSNode[]>
---@param metadata TSMetadata
---@return string?, boolean, Range6[]
function LanguageTree:_get_injection(match, metadata)
@@ -796,14 +796,16 @@ function LanguageTree:_get_injection(match, metadata)
or (injection_lang and resolve_lang(injection_lang))
local include_children = metadata['injection.include-children'] ~= nil
- for id, node in pairs(match) do
- local name = self._injection_query.captures[id]
- -- Lang should override any other language tag
- if name == 'injection.language' then
- local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] })
- lang = resolve_lang(text)
- elseif name == 'injection.content' then
- ranges = get_node_ranges(node, self._source, metadata[id], include_children)
+ for id, nodes in pairs(match) do
+ for _, node in ipairs(nodes) do
+ local name = self._injection_query.captures[id]
+ -- Lang should override any other language tag
+ if name == 'injection.language' then
+ local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] })
+ lang = resolve_lang(text)
+ elseif name == 'injection.content' then
+ ranges = get_node_ranges(node, self._source, metadata[id], include_children)
+ end
end
end
@@ -844,7 +846,13 @@ function LanguageTree:_get_injections()
local start_line, _, end_line, _ = root_node:range()
for pattern, match, metadata in
- self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
+ self._injection_query:iter_matches(
+ root_node,
+ self._source,
+ start_line,
+ end_line + 1,
+ { all = true }
+ )
do
local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index cd65c0d7f6..5bb9e07a82 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -290,47 +290,71 @@ function M.get_node_text(...)
return vim.treesitter.get_node_text(...)
end
----@alias TSMatch table<integer,TSNode>
-
----@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean
-
--- Predicate handler receive the following arguments
--- (match, pattern, bufnr, predicate)
----@type table<string,TSPredicate>
-local predicate_handlers = {
- ['eq?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+--- Implementations of predicates that can optionally be prefixed with "any-".
+---
+--- These functions contain the implementations for each predicate, correctly
+--- handling the "any" vs "all" semantics. They are called from the
+--- predicate_handlers table with the appropriate arguments for each predicate.
+local impl = {
+ --- @param match TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ ['eq'] = function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local node_text = vim.treesitter.get_node_text(node, source)
- local str ---@type string
- if type(predicate[3]) == 'string' then
- -- (#eq? @aa "foo")
- str = predicate[3]
- else
- -- (#eq? @aa @bb)
- str = vim.treesitter.get_node_text(match[predicate[3]], source)
- end
+ for _, node in ipairs(nodes) do
+ local node_text = vim.treesitter.get_node_text(node, source)
+
+ local str ---@type string
+ if type(predicate[3]) == 'string' then
+ -- (#eq? @aa "foo")
+ str = predicate[3]
+ else
+ -- (#eq? @aa @bb)
+ local other = assert(match[predicate[3]])
+ assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
+ str = vim.treesitter.get_node_text(other[1], source)
+ end
- if node_text ~= str or str == nil then
- return false
+ local res = str ~= nil and node_text == str
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
end
- return true
+ return not any
end,
- ['lua-match?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+ --- @param match TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ ['lua-match'] = function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local regex = predicate[3]
- return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
+
+ for _, node in ipairs(nodes) do
+ local regex = predicate[3]
+ local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
+ end
+
+ return not any
end,
- ['match?'] = (function()
+ ['match'] = (function()
local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
local function check_magic(str)
if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
@@ -347,85 +371,160 @@ local predicate_handlers = {
end,
})
- return function(match, _, source, pred)
- ---@cast match TSMatch
- local node = match[pred[2]]
- if not node then
+ --- @param match TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ return function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- ---@diagnostic disable-next-line no-unknown
- local regex = compiled_vim_regexes[pred[3]]
- return regex:match_str(vim.treesitter.get_node_text(node, source))
+
+ for _, node in ipairs(nodes) do
+ local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
+ local res = regex:match_str(vim.treesitter.get_node_text(node, source))
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
+ end
+ return not any
end
end)(),
- ['contains?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+ --- @param match TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ ['contains'] = function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local node_text = vim.treesitter.get_node_text(node, source)
- for i = 3, #predicate do
- if string.find(node_text, predicate[i], 1, true) then
- return true
+ for _, node in ipairs(nodes) do
+ local node_text = vim.treesitter.get_node_text(node, source)
+
+ for i = 3, #predicate do
+ local res = string.find(node_text, predicate[i], 1, true)
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
end
end
- return false
+ return not any
+ end,
+}
+
+---@class TSMatch
+---@field pattern? integer
+---@field active? boolean
+---@field [integer] TSNode[]
+
+---@alias TSPredicate fun(match: TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean
+
+-- Predicate handler receive the following arguments
+-- (match, pattern, bufnr, predicate)
+---@type table<string,TSPredicate>
+local predicate_handlers = {
+ ['eq?'] = function(match, _, source, predicate)
+ return impl['eq'](match, source, predicate, false)
+ end,
+
+ ['any-eq?'] = function(match, _, source, predicate)
+ return impl['eq'](match, source, predicate, true)
+ end,
+
+ ['lua-match?'] = function(match, _, source, predicate)
+ return impl['lua-match'](match, source, predicate, false)
+ end,
+
+ ['any-lua-match?'] = function(match, _, source, predicate)
+ return impl['lua-match'](match, source, predicate, true)
+ end,
+
+ ['match?'] = function(match, _, source, predicate)
+ return impl['match'](match, source, predicate, false)
+ end,
+
+ ['any-match?'] = function(match, _, source, predicate)
+ return impl['match'](match, source, predicate, true)
+ end,
+
+ ['contains?'] = function(match, _, source, predicate)
+ return impl['contains'](match, source, predicate, false)
+ end,
+
+ ['any-contains?'] = function(match, _, source, predicate)
+ return impl['contains'](match, source, predicate, true)
end,
['any-of?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local node_text = vim.treesitter.get_node_text(node, source)
- -- Since 'predicate' will not be used by callers of this function, use it
- -- to store a string set built from the list of words to check against.
- local string_set = predicate['string_set']
- if not string_set then
- string_set = {}
- for i = 3, #predicate do
- ---@diagnostic disable-next-line:no-unknown
- string_set[predicate[i]] = true
+ for _, node in ipairs(nodes) do
+ local node_text = vim.treesitter.get_node_text(node, source)
+
+ -- Since 'predicate' will not be used by callers of this function, use it
+ -- to store a string set built from the list of words to check against.
+ local string_set = predicate['string_set'] --- @type table<string, boolean>
+ if not string_set then
+ string_set = {}
+ for i = 3, #predicate do
+ string_set[predicate[i]] = true
+ end
+ predicate['string_set'] = string_set
+ end
+
+ if string_set[node_text] then
+ return true
end
- predicate['string_set'] = string_set
end
- return string_set[node_text]
+ return false
end,
['has-ancestor?'] = function(match, _, _, predicate)
- local node = match[predicate[2]]
- if not node then
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local ancestor_types = {}
- for _, type in ipairs({ unpack(predicate, 3) }) do
- ancestor_types[type] = true
- end
+ for _, node in ipairs(nodes) do
+ local ancestor_types = {} --- @type table<string, boolean>
+ for _, type in ipairs({ unpack(predicate, 3) }) do
+ ancestor_types[type] = true
+ end
- node = node:parent()
- while node do
- if ancestor_types[node:type()] then
- return true
+ local cur = node:parent()
+ while cur do
+ if ancestor_types[cur:type()] then
+ return true
+ end
+ cur = cur:parent()
end
- node = node:parent()
end
return false
end,
['has-parent?'] = function(match, _, _, predicate)
- local node = match[predicate[2]]
- if not node then
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then
- return true
+ for _, node in ipairs(nodes) do
+ if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then
+ return true
+ end
end
return false
end,
@@ -433,6 +532,7 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match?
predicate_handlers['vim-match?'] = predicate_handlers['match?']
+predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
---@class TSMetadata
---@field range? Range
@@ -468,13 +568,17 @@ local directive_handlers = {
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
['offset!'] = function(match, _, _, pred, metadata)
- ---@cast pred integer[]
- local capture_id = pred[2]
+ local capture_id = pred[2] --[[@as integer]]
+ local nodes = match[capture_id]
+ assert(#nodes == 1, '#offset! does not support captures on multiple nodes')
+
+ local node = nodes[1]
+
if not metadata[capture_id] then
metadata[capture_id] = {}
end
- local range = metadata[capture_id].range or { match[capture_id]:range() }
+ local range = metadata[capture_id].range or { node:range() }
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
@@ -498,7 +602,9 @@ local directive_handlers = {
local id = pred[2]
assert(type(id) == 'number')
- local node = match[id]
+ local nodes = match[id]
+ assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
+ local node = nodes[1]
local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then
@@ -518,10 +624,9 @@ local directive_handlers = {
local capture_id = pred[2]
assert(type(capture_id) == 'number')
- local node = match[capture_id]
- if not node then
- return
- end
+ local nodes = match[capture_id]
+ assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
+ local node = nodes[1]
local start_row, start_col, end_row, end_col = node:range()
@@ -552,38 +657,93 @@ local directive_handlers = {
--- Adds a new predicate to be used in queries
---
---@param name string Name of the predicate, without leading #
----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[])
+---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table)
--- - see |vim.treesitter.query.add_directive()| for argument meanings
----@param force boolean|nil
-function M.add_predicate(name, handler, force)
- if predicate_handlers[name] and not force then
- error(string.format('Overriding %s', name))
+---@param opts table<string, any> Optional options:
+--- - force (boolean): Override an existing
+--- predicate of the same name
+--- - all (boolean): Use the correct
+--- implementation of the match table where
+--- capture IDs map to a list of nodes instead
+--- of a single node. Defaults to false (for
+--- backward compatibility). This option will
+--- eventually become the default and removed.
+function M.add_predicate(name, handler, opts)
+ -- Backward compatibility: old signature had "force" as boolean argument
+ if type(opts) == 'boolean' then
+ opts = { force = opts }
end
- predicate_handlers[name] = handler
+ opts = opts or {}
+
+ if predicate_handlers[name] and not opts.force then
+ error(string.format('Overriding existing predicate %s', name))
+ end
+
+ if opts.all then
+ predicate_handlers[name] = handler
+ else
+ --- @param match table<integer, TSNode[]>
+ local function wrapper(match, ...)
+ local m = {} ---@type table<integer, TSNode>
+ for k, v in pairs(match) do
+ if type(k) == 'number' then
+ m[k] = v[#v]
+ end
+ end
+ return handler(m, ...)
+ end
+ predicate_handlers[name] = wrapper
+ end
end
--- Adds a new directive to be used in queries
---
--- Handlers can set match level data by setting directly on the
---- metadata object `metadata.key = value`, additionally, handlers
+--- metadata object `metadata.key = value`. Additionally, handlers
--- can set node level data by using the capture id on the
--- metadata table `metadata[capture_id].key = value`
---
---@param name string Name of the directive, without leading #
----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table)
---- - match: see |treesitter-query|
---- - node-level data are accessible via `match[capture_id]`
---- - pattern: see |treesitter-query|
+---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table)
+--- - match: A table mapping capture IDs to a list of captured nodes
+--- - pattern: the index of the matching pattern in the query file
--- - predicate: list of strings containing the full directive being called, e.g.
--- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }`
----@param force boolean|nil
-function M.add_directive(name, handler, force)
- if directive_handlers[name] and not force then
- error(string.format('Overriding %s', name))
+---@param opts table<string, any> Optional options:
+--- - force (boolean): Override an existing
+--- predicate of the same name
+--- - all (boolean): Use the correct
+--- implementation of the match table where
+--- capture IDs map to a list of nodes instead
+--- of a single node. Defaults to false (for
+--- backward compatibility). This option will
+--- eventually become the default and removed.
+function M.add_directive(name, handler, opts)
+ -- Backward compatibility: old signature had "force" as boolean argument
+ if type(opts) == 'boolean' then
+ opts = { force = opts }
+ end
+
+ opts = opts or {}
+
+ if directive_handlers[name] and not opts.force then
+ error(string.format('Overriding existing directive %s', name))
end
- directive_handlers[name] = handler
+ if opts.all then
+ directive_handlers[name] = handler
+ else
+ --- @param match table<integer, TSNode[]>
+ local function wrapper(match, ...)
+ local m = {} ---@type table<integer, TSNode>
+ for k, v in pairs(match) do
+ m[k] = v[#v]
+ end
+ handler(m, ...)
+ end
+ directive_handlers[name] = wrapper
+ end
end
--- Lists the currently available directives to use in queries.
@@ -608,7 +768,7 @@ end
---@private
---@param match TSMatch
----@param pattern string
+---@param pattern integer
---@param source integer|string
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -618,18 +778,14 @@ function Query:match_preds(match, pattern, source)
-- continue on the other case. This way unknown predicates will not be considered,
-- which allows some testing and easier user extensibility (#12173).
-- Also, tree-sitter strips the leading # from predicates for us.
- local pred_name ---@type string
-
- local is_not ---@type boolean
+ local is_not = false
-- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
- if string.sub(pred[1], 1, 4) == 'not-' then
- pred_name = string.sub(pred[1], 5)
+ local pred_name = pred[1]
+ if pred_name:match('^not%-') then
+ pred_name = pred_name:sub(5)
is_not = true
- else
- pred_name = pred[1]
- is_not = false
end
local handler = predicate_handlers[pred_name]
@@ -724,7 +880,7 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, true, start, stop)
+ local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, TSMatch
local function iter(end_line)
local capture, captured_node, match = raw_iter()
local metadata = {}
@@ -748,27 +904,34 @@ end
--- Iterates the matches of self on a given range.
---
---- Iterate over all matches within a {node}. The arguments are the same as
---- for |Query:iter_captures()| but the iterated values are different:
---- an (1-based) index of the pattern in the query, a table mapping
---- capture indices to nodes, and metadata from any directives processing the match.
---- If the query has more than one pattern, the capture table might be sparse
---- and e.g. `pairs()` method should be used over `ipairs`.
---- Here is an example iterating over all captures in every match:
+--- Iterate over all matches within a {node}. The arguments are the same as for
+--- |Query:iter_captures()| but the iterated values are different: an (1-based)
+--- index of the pattern in the query, a table mapping capture indices to a list
+--- of nodes, and metadata from any directives processing the match.
+---
+--- WARNING: Set `all=true` to ensure all matching nodes in a match are
+--- returned, otherwise only the last node in a match is returned, breaking captures
+--- involving quantifiers such as `(comment)+ @comment`. The default option
+--- `all=false` is only provided for backward compatibility and will be removed
+--- after Nvim 0.10.
+---
+--- Example:
---
--- ```lua
---- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do
---- for id, node in pairs(match) do
+--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do
+--- for id, nodes in pairs(match) do
--- local name = query.captures[id]
---- -- `node` was captured by the `name` capture in the match
+--- for _, node in ipairs(nodes) do
+--- -- `node` was captured by the `name` capture in the match
---
---- local node_data = metadata[id] -- Node level metadata
----
---- -- ... use the info here ...
+--- local node_data = metadata[id] -- Node level metadata
+--- ... use the info here ...
+--- end
--- end
--- end
--- ```
---
+---
---@param node TSNode under which the search will occur
---@param source (integer|string) Source buffer or string to search
---@param start? integer Starting line for the search. Defaults to `node:start()`.
@@ -776,17 +939,20 @@ end
---@param opts? table Optional keyword arguments:
--- - max_start_depth (integer) if non-zero, sets the maximum start depth
--- for each match. This is used to prevent traversing too deep into a tree.
+--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
+--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is
+--- incorrect behavior. This option will eventually become the default and removed.
---
----@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata
+---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
function Query:iter_matches(node, source, start, stop, opts)
+ local all = opts and opts.all
if type(source) == 'number' and source == 0 then
source = api.nvim_get_current_buf()
end
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, false, start, stop, opts)
- ---@cast raw_iter fun(): string, any
+ local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, TSMatch
local function iter()
local pattern, match = raw_iter()
local metadata = {}
@@ -799,6 +965,18 @@ function Query:iter_matches(node, source, start, stop, opts)
self:apply_directives(match, pattern, source, metadata)
end
+
+ if not all then
+ -- Convert the match table into the old buggy version for backward
+ -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
+ -- option!
+ local old_match = {} ---@type table<integer, TSNode>
+ for k, v in pairs(match or {}) do
+ old_match[k] = v[#v]
+ end
+ return pattern, old_match, metadata
+ end
+
return pattern, match, metadata
end
return iter