diff options
author | Lewis Russell <lewis6991@gmail.com> | 2024-03-18 23:19:01 +0000 |
---|---|---|
committer | Lewis Russell <me@lewisr.dev> | 2024-03-19 14:24:59 +0000 |
commit | aca2048bcd57937ea1c7b7f0325f25d5b82588db (patch) | |
tree | 74f26cf464e6ce345cb296c07477cb38c088f613 /runtime/lua/vim/treesitter/query.lua | |
parent | 16a416cb3c17ed3a7f21d35da5d211fcad947768 (diff) | |
download | rneovim-aca2048bcd57937ea1c7b7f0325f25d5b82588db.tar.gz rneovim-aca2048bcd57937ea1c7b7f0325f25d5b82588db.tar.bz2 rneovim-aca2048bcd57937ea1c7b7f0325f25d5b82588db.zip |
refactor(treesitter): redesign query iterating
Problem:
`TSNode:_rawquery()` is complicated, has known issues and the Lua and
C code is awkwardly coupled (see logic with `active`).
Solution:
- Add `TSQueryCursor` and `TSQueryMatch` bindings.
- Replace `TSNode:_rawquery()` with `TSQueryCursor:next_capture()` and `TSQueryCursor:next_match()`
- Do more stuff in Lua
- API for `Query:iter_captures()` and `Query:iter_matches()` remains the same.
- `treesitter.c` no longer contains any logic related to predicates.
- Add `match_limit` option to `iter_matches()`. Default is still 256.
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 125 |
1 files changed, 80 insertions, 45 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 30cd00c617..075fd0e99b 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -258,7 +258,7 @@ end) --- handling the "any" vs "all" semantics. They are called from the --- predicate_handlers table with the appropriate arguments for each predicate. local impl = { - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -293,7 +293,7 @@ local impl = { return not any end, - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -333,7 +333,7 @@ local impl = { end, }) - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -356,7 +356,7 @@ local impl = { end end)(), - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -383,13 +383,7 @@ local impl = { end, } ----@nodoc ----@class vim.treesitter.query.TSMatch ----@field pattern? integer ----@field active? boolean ----@field [integer] TSNode[] - ----@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean +---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -504,7 +498,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@field [integer] vim.treesitter.query.TSMetadata ---@field [string] integer|string ----@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) +---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -726,13 +720,19 @@ local function is_directive(name) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param pattern integer +---@param match TSQueryMatch ---@param source integer|string -function Query:match_preds(match, pattern, source) +function Query:match_preds(match, source) + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return true + end + + local captures = match:captures() + + for _, pred in pairs(preds) do -- Here we only want to return if a predicate DOES NOT match, and -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). @@ -754,7 +754,7 @@ function Query:match_preds(match, pattern, source) return false end - local pred_matches = handler(match, pattern, source, pred) + local pred_matches = handler(captures, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -765,23 +765,33 @@ function Query:match_preds(match, pattern, source) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param metadata vim.treesitter.query.TSMetadata -function Query:apply_directives(match, pattern, source, metadata) +---@param match TSQueryMatch +---@return vim.treesitter.query.TSMetadata metadata +function Query:apply_directives(match, source) + ---@type vim.treesitter.query.TSMetadata + local metadata = {} + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return metadata + end + + local captures = match:captures() + + for _, pred in pairs(preds) do if is_directive(pred[1]) then local handler = directive_handlers[pred[1]] if not handler then error(string.format('No handler for %s', pred[1])) - return end - handler(match, pattern, source, pred, metadata) + handler(captures, pattern, source, pred, metadata) end end + + return metadata end --- Returns the start and stop value if set else the node's range. @@ -831,8 +841,10 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>): +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?): --- capture id, capture node, metadata, match +--- +---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -840,24 +852,38 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) + + local max_match_id = -1 + local function iter(end_line) - local capture, captured_node, match = raw_iter() + local capture, captured_node, match = cursor:next_capture() + + if not capture then + return + end + + local captures --- @type table<integer,TSNode[]>? + local match_id, pattern_index = match:info() + local metadata = {} - if match ~= nil then - local active = self:match_preds(match, match.pattern, source) - match.active = active - if not active then + local preds = self.info.patterns[pattern_index] or {} + + if #preds > 0 and match_id > max_match_id then + captures = match:captures() + max_match_id = match_id + if not self:match_preds(match, source) then + cursor:remove_match(match_id) if end_line and captured_node:range() > end_line then return nil, captured_node, nil end return iter(end_line) -- tail call: try next match end - self:apply_directives(match, match.pattern, source, metadata) + metadata = self:apply_directives(match, source) end - return capture, captured_node, metadata, match + return capture, captured_node, metadata, captures end return iter end @@ -899,45 +925,54 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256). --- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. --- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is --- incorrect behavior. This option will eventually become the default and removed. --- ---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) - local all = opts and opts.all + opts = opts or {} + opts.match_limit = opts.match_limit or 256 + if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) + local function iter() - local pattern, match = raw_iter() - local metadata = {} + local match = cursor:next_match() - if match ~= nil then - local active = self:match_preds(match, pattern, source) - if not active then - return iter() -- tail call: try next match - end + if not match then + return + end - self:apply_directives(match, pattern, source, metadata) + local match_id, pattern = match:info() + + if not self:match_preds(match, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match end - if not all then + local metadata = self:apply_directives(match, source) + + local captures = match:captures() + + if not opts.all then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" -- option! local old_match = {} ---@type table<integer, TSNode> - for k, v in pairs(match or {}) do + for k, v in pairs(captures or {}) do old_match[k] = v[#v] end return pattern, old_match, metadata end - return pattern, match, metadata + return pattern, captures, metadata end return iter end |