From 14e4b6bbd8640675d7393bdeb3e93d74ab875ff1 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sat, 16 Mar 2024 17:11:42 +0000 Subject: refactor(lua): type annotations --- runtime/lua/vim/treesitter/query.lua | 51 ++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 28 deletions(-) (limited to 'runtime/lua/vim/treesitter/query.lua') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index a086f5e876..67b8c596b8 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -88,7 +88,7 @@ end --- ---@param lang string Language to get query for ---@param query_name string Name of the query to load (e.g., "highlights") ----@param is_included (boolean|nil) Internal parameter, most of the time left as `nil` +---@param is_included? boolean Internal parameter, most of the time left as `nil` ---@return string[] query_files List of files to load for given query and language function M.get_files(lang, query_name, is_included) local query_path = string.format('queries/%s/%s.scm', lang, query_name) @@ -211,7 +211,7 @@ end ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g. "highlights") --- ----@return vim.treesitter.Query|nil : Parsed query. `nil` if no query files are found. +---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found. M.get = vim.func._memoize('concat-2', function(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] @@ -242,9 +242,9 @@ end) ---@param lang string Language to use for the query ---@param query string Query in s-expr syntax --- ----@return vim.treesitter.Query Parsed query +---@return vim.treesitter.Query : Parsed query --- ----@see |vim.treesitter.query.get()| +---@see [vim.treesitter.query.get()] M.parse = vim.func._memoize('concat-2', function(lang, query) language.add(lang) @@ -618,20 +618,23 @@ local directive_handlers = { end, } +--- @class vim.treesitter.query.add_predicate.Opts +--- @inlinedoc +--- +--- Override an existing predicate of the same name +--- @field force? boolean +--- +--- Use the correct implementation of the match table where capture IDs map to +--- a list of nodes instead of a single node. Defaults to false (for backward +--- compatibility). This option will eventually become the default and removed. +--- @field all? boolean + --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match: table, pattern: integer, source: integer|string, predicate: any[], metadata: table) +---@param handler fun(match: table, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - see |vim.treesitter.query.add_directive()| for argument meanings ----@param opts table Optional options: ---- - force (boolean): Override an existing ---- predicate of the same name ---- - all (boolean): Use the correct ---- implementation of the match table where ---- capture IDs map to a list of nodes instead ---- of a single node. Defaults to false (for ---- backward compatibility). This option will ---- eventually become the default and removed. +---@param opts vim.treesitter.query.add_predicate.Opts function M.add_predicate(name, handler, opts) -- Backward compatibility: old signature had "force" as boolean argument if type(opts) == 'boolean' then @@ -669,20 +672,12 @@ end --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match: table, pattern: integer, source: integer|string, predicate: any[], metadata: table) +---@param handler fun(match: table, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - match: A table mapping capture IDs to a list of captured nodes --- - pattern: the index of the matching pattern in the query file --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` ----@param opts table Optional options: ---- - force (boolean): Override an existing ---- predicate of the same name ---- - all (boolean): Use the correct ---- implementation of the match table where ---- capture IDs map to a list of nodes instead ---- of a single node. Defaults to false (for ---- backward compatibility). This option will ---- eventually become the default and removed. +---@param opts vim.treesitter.query.add_predicate.Opts function M.add_directive(name, handler, opts) -- Backward compatibility: old signature had "force" as boolean argument if type(opts) == 'boolean' then @@ -711,13 +706,13 @@ function M.add_directive(name, handler, opts) end --- Lists the currently available directives to use in queries. ----@return string[] List of supported directives. +---@return string[] : Supported directives. function M.list_directives() return vim.tbl_keys(directive_handlers) end --- Lists the currently available predicates to use in queries. ----@return string[] List of supported predicates. +---@return string[] : Supported predicates. function M.list_predicates() return vim.tbl_keys(predicate_handlers) end @@ -792,8 +787,8 @@ end --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. ----@param start integer|nil ----@param stop integer|nil +---@param start integer? +---@param stop integer? ---@param node TSNode ---@return integer, integer local function value_or_node_range(start, stop, node) -- cgit From 3b29b39e6deb212456eba691bc79b17edaa8717b Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sun, 17 Mar 2024 18:02:40 +0000 Subject: fix(treesitter): revert to using iter_captures in highlighter Fixes #27895 --- runtime/lua/vim/treesitter/query.lua | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'runtime/lua/vim/treesitter/query.lua') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 67b8c596b8..30cd00c617 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -811,12 +811,13 @@ end --- as the {node}, i.e., to get syntax highlight matches in the current --- viewport). When omitted, the {start} and {stop} row values are used from the given node. --- ---- The iterator returns three values: a numeric id identifying the capture, ---- the captured node, and metadata from any directives processing the match. +--- The iterator returns four values: a numeric id identifying the capture, +--- the captured node, metadata from any directives processing the match, +--- and the match itself. --- The following example shows how to get captures by name: --- --- ```lua ---- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do +--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do --- local name = query.captures[id] -- name of the capture in the query --- -- typically useful info about the node: --- local type = node:type() -- type of the captured node @@ -830,8 +831,8 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata): ---- capture id, capture node, metadata +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table): +--- capture id, capture node, metadata, match function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -856,7 +857,7 @@ function Query:iter_captures(node, source, start, stop) self:apply_directives(match, match.pattern, source, metadata) end - return capture, captured_node, metadata + return capture, captured_node, metadata, match end return iter end -- cgit From aca2048bcd57937ea1c7b7f0325f25d5b82588db Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Mon, 18 Mar 2024 23:19:01 +0000 Subject: refactor(treesitter): redesign query iterating Problem: `TSNode:_rawquery()` is complicated, has known issues and the Lua and C code is awkwardly coupled (see logic with `active`). Solution: - Add `TSQueryCursor` and `TSQueryMatch` bindings. - Replace `TSNode:_rawquery()` with `TSQueryCursor:next_capture()` and `TSQueryCursor:next_match()` - Do more stuff in Lua - API for `Query:iter_captures()` and `Query:iter_matches()` remains the same. - `treesitter.c` no longer contains any logic related to predicates. - Add `match_limit` option to `iter_matches()`. Default is still 256. --- runtime/lua/vim/treesitter/query.lua | 125 ++++++++++++++++++++++------------- 1 file changed, 80 insertions(+), 45 deletions(-) (limited to 'runtime/lua/vim/treesitter/query.lua') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 30cd00c617..075fd0e99b 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -258,7 +258,7 @@ end) --- handling the "any" vs "all" semantics. They are called from the --- predicate_handlers table with the appropriate arguments for each predicate. local impl = { - --- @param match vim.treesitter.query.TSMatch + --- @param match table --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -293,7 +293,7 @@ local impl = { return not any end, - --- @param match vim.treesitter.query.TSMatch + --- @param match table --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -333,7 +333,7 @@ local impl = { end, }) - --- @param match vim.treesitter.query.TSMatch + --- @param match table --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -356,7 +356,7 @@ local impl = { end end)(), - --- @param match vim.treesitter.query.TSMatch + --- @param match table --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -383,13 +383,7 @@ local impl = { end, } ----@nodoc ----@class vim.treesitter.query.TSMatch ----@field pattern? integer ----@field active? boolean ----@field [integer] TSNode[] - ----@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean +---@alias TSPredicate fun(match: table, pattern: integer, source: integer|string, predicate: any[]): boolean -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -504,7 +498,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@field [integer] vim.treesitter.query.TSMetadata ---@field [string] integer|string ----@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) +---@alias TSDirective fun(match: table, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -726,13 +720,19 @@ local function is_directive(name) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param pattern integer +---@param match TSQueryMatch ---@param source integer|string -function Query:match_preds(match, pattern, source) +function Query:match_preds(match, source) + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return true + end + + local captures = match:captures() + + for _, pred in pairs(preds) do -- Here we only want to return if a predicate DOES NOT match, and -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). @@ -754,7 +754,7 @@ function Query:match_preds(match, pattern, source) return false end - local pred_matches = handler(match, pattern, source, pred) + local pred_matches = handler(captures, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -765,23 +765,33 @@ function Query:match_preds(match, pattern, source) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param metadata vim.treesitter.query.TSMetadata -function Query:apply_directives(match, pattern, source, metadata) +---@param match TSQueryMatch +---@return vim.treesitter.query.TSMetadata metadata +function Query:apply_directives(match, source) + ---@type vim.treesitter.query.TSMetadata + local metadata = {} + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return metadata + end + + local captures = match:captures() + + for _, pred in pairs(preds) do if is_directive(pred[1]) then local handler = directive_handlers[pred[1]] if not handler then error(string.format('No handler for %s', pred[1])) - return end - handler(match, pattern, source, pred, metadata) + handler(captures, pattern, source, pred, metadata) end end + + return metadata end --- Returns the start and stop value if set else the node's range. @@ -831,8 +841,10 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table): +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table?): --- capture id, capture node, metadata, match +--- +---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -840,24 +852,38 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) + + local max_match_id = -1 + local function iter(end_line) - local capture, captured_node, match = raw_iter() + local capture, captured_node, match = cursor:next_capture() + + if not capture then + return + end + + local captures --- @type table? + local match_id, pattern_index = match:info() + local metadata = {} - if match ~= nil then - local active = self:match_preds(match, match.pattern, source) - match.active = active - if not active then + local preds = self.info.patterns[pattern_index] or {} + + if #preds > 0 and match_id > max_match_id then + captures = match:captures() + max_match_id = match_id + if not self:match_preds(match, source) then + cursor:remove_match(match_id) if end_line and captured_node:range() > end_line then return nil, captured_node, nil end return iter(end_line) -- tail call: try next match end - self:apply_directives(match, match.pattern, source, metadata) + metadata = self:apply_directives(match, source) end - return capture, captured_node, metadata, match + return capture, captured_node, metadata, captures end return iter end @@ -899,45 +925,54 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256). --- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. --- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is --- incorrect behavior. This option will eventually become the default and removed. --- ---@return (fun(): integer, table, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) - local all = opts and opts.all + opts = opts or {} + opts.match_limit = opts.match_limit or 256 + if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) + local function iter() - local pattern, match = raw_iter() - local metadata = {} + local match = cursor:next_match() - if match ~= nil then - local active = self:match_preds(match, pattern, source) - if not active then - return iter() -- tail call: try next match - end + if not match then + return + end - self:apply_directives(match, pattern, source, metadata) + local match_id, pattern = match:info() + + if not self:match_preds(match, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match end - if not all then + local metadata = self:apply_directives(match, source) + + local captures = match:captures() + + if not opts.all then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" -- option! local old_match = {} ---@type table - for k, v in pairs(match or {}) do + for k, v in pairs(captures or {}) do old_match[k] = v[#v] end return pattern, old_match, metadata end - return pattern, match, metadata + return pattern, captures, metadata end return iter end -- cgit From 7d971500847089ec8ade926a7f84d6bb3a51c8b0 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Mon, 25 Mar 2024 22:06:31 +0000 Subject: fix(treesitter): return correct match table in iter_captures() --- runtime/lua/vim/treesitter/query.lua | 46 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) (limited to 'runtime/lua/vim/treesitter/query.lua') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 075fd0e99b..e68acac929 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,5 +1,6 @@ local api = vim.api local language = require('vim.treesitter.language') +local memoize = vim.func._memoize local M = {} @@ -212,7 +213,7 @@ end ---@param query_name string Name of the query (e.g. "highlights") --- ---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found. -M.get = vim.func._memoize('concat-2', function(lang, query_name) +M.get = memoize('concat-2', function(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] end @@ -245,7 +246,7 @@ end) ---@return vim.treesitter.Query : Parsed query --- ---@see [vim.treesitter.query.get()] -M.parse = vim.func._memoize('concat-2', function(lang, query) +M.parse = memoize('concat-2', function(lang, query) language.add(lang) local ts_query = vim._ts_parse_query(lang, query) @@ -812,6 +813,12 @@ local function value_or_node_range(start, stop, node) return start, stop end +--- @param match TSQueryMatch +--- @return integer +local function match_id_hash(_, match) + return (match:info()) +end + --- Iterate over all captures from all matches inside {node} --- --- {source} is needed if the query contains predicates; then the caller @@ -841,7 +848,7 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table?): +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch): --- capture id, capture node, metadata, match --- ---@note Captures are only returned if the query pattern of a specific capture contained predicates. @@ -854,7 +861,8 @@ function Query:iter_captures(node, source, start, stop) local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) - local max_match_id = -1 + local apply_directives = memoize(match_id_hash, self.apply_directives, true) + local match_preds = memoize(match_id_hash, self.match_preds, true) local function iter(end_line) local capture, captured_node, match = cursor:next_capture() @@ -863,27 +871,18 @@ function Query:iter_captures(node, source, start, stop) return end - local captures --- @type table? - local match_id, pattern_index = match:info() - - local metadata = {} - - local preds = self.info.patterns[pattern_index] or {} - - if #preds > 0 and match_id > max_match_id then - captures = match:captures() - max_match_id = match_id - if not self:match_preds(match, source) then - cursor:remove_match(match_id) - if end_line and captured_node:range() > end_line then - return nil, captured_node, nil - end - return iter(end_line) -- tail call: try next match + if not match_preds(self, match, source) then + local match_id = match:info() + cursor:remove_match(match_id) + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil, nil end - - metadata = self:apply_directives(match, source) + return iter(end_line) -- tail call: try next match end - return capture, captured_node, metadata, captures + + local metadata = apply_directives(self, match, source) + + return capture, captured_node, metadata, match end return iter end @@ -972,6 +971,7 @@ function Query:iter_matches(node, source, start, stop, opts) return pattern, old_match, metadata end + -- TODO(lewis6991): create a new function that returns {match, metadata} return pattern, captures, metadata end return iter -- cgit From 6a264e08974bcb1b91f891eb65ef374f350d2827 Mon Sep 17 00:00:00 2001 From: Riley Bruins Date: Tue, 14 May 2024 07:14:43 -0700 Subject: fix(treesitter): allow optional directive captures (#28664) --- runtime/lua/vim/treesitter/query.lua | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'runtime/lua/vim/treesitter/query.lua') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index e68acac929..36c78b7f1d 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -529,6 +529,9 @@ local directive_handlers = { ['offset!'] = function(match, _, _, pred, metadata) local capture_id = pred[2] --[[@as integer]] local nodes = match[capture_id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#offset! does not support captures on multiple nodes') local node = nodes[1] @@ -562,6 +565,9 @@ local directive_handlers = { assert(type(id) == 'number') local nodes = match[id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#gsub! does not support captures on multiple nodes') local node = nodes[1] local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' @@ -584,6 +590,9 @@ local directive_handlers = { assert(type(capture_id) == 'number') local nodes = match[capture_id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#trim! does not support captures on multiple nodes') local node = nodes[1] -- cgit From 4b029163345333a2c6975cd0dace6613b036ae47 Mon Sep 17 00:00:00 2001 From: vanaigr Date: Thu, 16 May 2024 09:57:58 -0500 Subject: perf(treesitter): use child_containing_descendant() in has-ancestor? (#28512) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Problem: `has-ancestor?` is O(n²) for the depth of the tree since it iterates over each of the node's ancestors (bottom-up), and each ancestor takes O(n) time. This happens because tree-sitter's nodes don't store their parent nodes, and the tree is searched (top-down) each time a new parent is requested. Solution: Make use of new `ts_node_child_containing_descendant()` in tree-sitter v0.22.6 (which is now the minimum required version) to rewrite the `has-ancestor?` predicate in C to become O(n). For a sample file, decreases the time taken by `has-ancestor?` from 360ms to 6ms. --- runtime/lua/vim/treesitter/query.lua | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) (limited to 'runtime/lua/vim/treesitter/query.lua') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 36c78b7f1d..ef5c2143a7 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -457,17 +457,8 @@ local predicate_handlers = { end for _, node in ipairs(nodes) do - local ancestor_types = {} --- @type table - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end - - local cur = node:parent() - while cur do - if ancestor_types[cur:type()] then - return true - end - cur = cur:parent() + if node:__has_ancestor(predicate) then + return true end end return false -- cgit