diff options
author | Josh Rahm <joshuarahm@gmail.com> | 2025-02-05 23:09:29 +0000 |
---|---|---|
committer | Josh Rahm <joshuarahm@gmail.com> | 2025-02-05 23:09:29 +0000 |
commit | d5f194ce780c95821a855aca3c19426576d28ae0 (patch) | |
tree | d45f461b19f9118ad2bb1f440a7a08973ad18832 /runtime/lua/vim/treesitter/query.lua | |
parent | c5d770d311841ea5230426cc4c868e8db27300a8 (diff) | |
parent | 44740e561fc93afe3ebecfd3618bda2d2abeafb0 (diff) | |
download | rneovim-rahm.tar.gz rneovim-rahm.tar.bz2 rneovim-rahm.zip |
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 418 |
1 files changed, 280 insertions, 138 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 1677e8d364..10fb82e533 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,17 +1,77 @@ +--- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse +--- text. See |vim.treesitter.query.parse()| for a working example. + local api = vim.api local language = require('vim.treesitter.language') local memoize = vim.func._memoize +local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$' +local EXTENDS_FORMAT = '^;+%s*extends%s*$' + local M = {} +local function is_directive(name) + return string.sub(name, -1) == '!' +end + +---@nodoc +---@class vim.treesitter.query.ProcessedPredicate +---@field [1] string predicate name +---@field [2] boolean should match +---@field [3] (integer|string)[] the original predicate + +---@alias vim.treesitter.query.ProcessedDirective (integer|string)[] + +---@nodoc +---@class vim.treesitter.query.ProcessedPattern { +---@field predicates vim.treesitter.query.ProcessedPredicate[] +---@field directives vim.treesitter.query.ProcessedDirective[] + +--- Splits the query patterns into predicates and directives. +---@param patterns table<integer, (integer|string)[][]> +---@return table<integer, vim.treesitter.query.ProcessedPattern> +local function process_patterns(patterns) + ---@type table<integer, vim.treesitter.query.ProcessedPattern> + local processed_patterns = {} + + for k, pattern_list in pairs(patterns) do + ---@type vim.treesitter.query.ProcessedPredicate[] + local predicates = {} + ---@type vim.treesitter.query.ProcessedDirective[] + local directives = {} + + for _, pattern in ipairs(pattern_list) do + -- Note: tree-sitter strips the leading # from predicates for us. + local pred_name = pattern[1] + ---@cast pred_name string + + if is_directive(pred_name) then + table.insert(directives, pattern) + else + local should_match = true + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) + should_match = false + end + table.insert(predicates, { pred_name, should_match, pattern }) + end + end + + processed_patterns[k] = { predicates = predicates, directives = directives } + end + + return processed_patterns +end + ---@nodoc ---Parsed query, see |vim.treesitter.query.parse()| --- ---@class vim.treesitter.Query ----@field lang string name of the language for this parser +---@field lang string parser language name ---@field captures string[] list of (unique) capture names defined in query ----@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives) +---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives) ---@field query TSQuery userdata query object +---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern> local Query = {} Query.__index = Query @@ -30,6 +90,7 @@ function Query.new(lang, ts_query) patterns = query_info.patterns, } self.captures = self.info.captures + self._processed_patterns = process_patterns(self.info.patterns) return self end @@ -109,9 +170,6 @@ function M.get_files(lang, query_name, is_included) -- ;+ inherits: ({language},)*{language} -- -- {language} ::= {lang} | ({lang}) - local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$' - local EXTENDS_FORMAT = '^;+%s*extends%s*$' - for _, filename in ipairs(lang_files) do local file, err = io.open(filename, 'r') if not file then @@ -184,8 +242,8 @@ local function read_query_files(filenames) return table.concat(contents, '') end --- The explicitly set queries from |vim.treesitter.query.set()| ----@type table<string,table<string,vim.treesitter.Query>> +-- The explicitly set query strings from |vim.treesitter.query.set()| +---@type table<string,table<string,string>> local explicit_queries = setmetatable({}, { __index = function(t, k) local lang_queries = {} @@ -197,14 +255,27 @@ local explicit_queries = setmetatable({}, { --- Sets the runtime query named {query_name} for {lang} --- ---- This allows users to override any runtime files and/or configuration +--- This allows users to override or extend any runtime files and/or configuration --- set by plugins. --- +--- For example, you could enable spellchecking of `C` identifiers with the +--- following code: +--- ```lua +--- vim.treesitter.query.set( +--- 'c', +--- 'highlights', +--- [[;inherits c +--- (identifier) @spell]]) +--- ]]) +--- ``` +--- ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g., "highlights") ---@param text string Query text (unparsed). function M.set(lang, query_name, text) - explicit_queries[lang][query_name] = M.parse(lang, text) + --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics + M.get:clear(lang, query_name) + explicit_queries[lang][query_name] = text end --- Returns the runtime query {query_name} for {lang}. @@ -214,34 +285,82 @@ end --- ---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found. M.get = memoize('concat-2', function(lang, query_name) + local query_string ---@type string + if explicit_queries[lang][query_name] then - return explicit_queries[lang][query_name] - end + local query_files = {} + local base_langs = {} ---@type string[] - local query_files = M.get_files(lang, query_name) - local query_string = read_query_files(query_files) + for line in explicit_queries[lang][query_name]:gmatch('([^\n]*)\n?') do + if not vim.startswith(line, ';') then + break + end + + local lang_list = line:match(MODELINE_FORMAT) + if lang_list then + for _, incl_lang in ipairs(vim.split(lang_list, ',')) do + local is_optional = incl_lang:match('%(.*%)') + + if is_optional then + add_included_lang(base_langs, lang, incl_lang:sub(2, #incl_lang - 1)) + else + add_included_lang(base_langs, lang, incl_lang) + end + end + elseif line:match(EXTENDS_FORMAT) then + table.insert(base_langs, lang) + end + end + + for _, base_lang in ipairs(base_langs) do + local base_files = M.get_files(base_lang, query_name, true) + vim.list_extend(query_files, base_files) + end + + query_string = read_query_files(query_files) .. explicit_queries[lang][query_name] + else + local query_files = M.get_files(lang, query_name) + query_string = read_query_files(query_files) + end if #query_string == 0 then return nil end return M.parse(lang, query_string) -end) +end, false) + +api.nvim_create_autocmd('OptionSet', { + pattern = { 'runtimepath' }, + group = api.nvim_create_augroup('nvim.treesitter.query_cache_reset', { clear = true }), + callback = function() + --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics + M.get:clear() + end, +}) ---- Parse {query} as a string. (If the query is in a file, the caller ---- should read the contents into a string before calling). ---- ---- Returns a `Query` (see |lua-treesitter-query|) object which can be used to ---- search nodes in the syntax tree for the patterns defined in {query} ---- using the `iter_captures` and `iter_matches` methods. +--- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used +--- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|), +--- or inspect the query via these fields: +--- - `captures`: a list of unique capture names defined in the query (alias: `info.captures`). +--- - `info.patterns`: information about predicates. --- ---- Exposes `info` and `captures` with additional context about {query}. ---- - `captures` contains the list of unique capture names defined in {query}. ---- - `info.captures` also points to `captures`. ---- - `info.patterns` contains information about predicates. +--- Example: +--- ```lua +--- local query = vim.treesitter.query.parse('vimdoc', [[ +--- ; query +--- ((h1) @str +--- (#trim! @str 1 1 1 1)) +--- ]]) +--- local tree = vim.treesitter.get_parser():parse()[1] +--- for id, node, metadata in query:iter_captures(tree:root(), 0) do +--- -- Print the node name and source text. +--- vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())}) +--- end +--- ``` --- ---@param lang string Language to use for the query ----@param query string Query in s-expr syntax +---@param query string Query text, in s-expr syntax --- ---@return vim.treesitter.Query : Parsed query --- @@ -250,7 +369,7 @@ M.parse = memoize('concat-2', function(lang, query) assert(language.add(lang)) local ts_query = vim._ts_parse_query(lang, query) return Query.new(lang, ts_query) -end) +end, false) --- Implementations of predicates that can optionally be prefixed with "any-". --- @@ -572,13 +691,17 @@ local directive_handlers = { metadata[id].text = text:gsub(pattern, replacement) end, - -- Trim blank lines from end of the node - -- Example: (#trim! @fold) - -- TODO(clason): generalize to arbitrary whitespace removal + -- Trim whitespace from both sides of the node + -- Example: (#trim! @fold 1 1 1 1) ['trim!'] = function(match, _, bufnr, pred, metadata) local capture_id = pred[2] assert(type(capture_id) == 'number') + local trim_start_lines = pred[3] == '1' + local trim_start_cols = pred[4] == '1' + local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility + local trim_end_cols = pred[6] == '1' + local nodes = match[capture_id] if not nodes or #nodes == 0 then return @@ -588,20 +711,45 @@ local directive_handlers = { local start_row, start_col, end_row, end_col = node:range() - -- Don't trim if region ends in middle of a line - if end_col ~= 0 then - return + local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n') + if end_col == 0 then + -- get_node_text() will ignore the last line if the node ends at column 0 + node_text[#node_text + 1] = '' end - while end_row >= start_row do - -- As we only care when end_col == 0, always inspect one line above end_row. - local end_line = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1] + local end_idx = #node_text + local start_idx = 1 - if end_line ~= '' then - break + if trim_end_lines then + while end_idx > 0 and node_text[end_idx]:find('^%s*$') do + end_idx = end_idx - 1 + end_row = end_row - 1 + -- set the end position to the last column of the next line, or 0 if we just trimmed the + -- last line + end_col = end_idx > 0 and #node_text[end_idx] or 0 end + end + if trim_end_cols then + if end_idx == 0 then + end_row = start_row + end_col = start_col + else + local whitespace_start = node_text[end_idx]:find('(%s*)$') + end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0) + end + end - end_row = end_row - 1 + if trim_start_lines then + while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do + start_idx = start_idx + 1 + start_row = start_row + 1 + start_col = 0 + end + end + if trim_start_cols and node_text[start_idx] then + local _, whitespace_end = node_text[start_idx]:find('^(%s*)') + whitespace_end = whitespace_end or 0 + start_col = (start_idx == 1 and start_col or 0) + whitespace_end end -- If this produces an invalid range, we just skip it. @@ -711,84 +859,50 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end -local function xor(x, y) - return (x or y) and not (x and y) -end - -local function is_directive(name) - return string.sub(name, -1) == '!' -end - ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param predicates vim.treesitter.query.ProcessedPredicate[] +---@param captures table<integer, TSNode[]> ---@param source integer|string -function Query:match_preds(match, source) - local _, pattern = match:info() - local preds = self.info.patterns[pattern] - - if not preds then - return true - end - - local captures = match:captures() - - for _, pred in pairs(preds) do - -- Here we only want to return if a predicate DOES NOT match, and - -- continue on the other case. This way unknown predicates will not be considered, - -- which allows some testing and easier user extensibility (#12173). - -- Also, tree-sitter strips the leading # from predicates for us. - local is_not = false - - -- Skip over directives... they will get processed after all the predicates. - if not is_directive(pred[1]) then - local pred_name = pred[1] - if pred_name:match('^not%-') then - pred_name = pred_name:sub(5) - is_not = true - end - - local handler = predicate_handlers[pred_name] - - if not handler then - error(string.format('No handler for %s', pred[1])) - return false - end - - local pred_matches = handler(captures, pattern, source, pred) +---@return boolean whether the predicates match +function Query:_match_predicates(predicates, pattern_i, captures, source) + for _, predicate in ipairs(predicates) do + local processed_name = predicate[1] + local should_match = predicate[2] + local orig_predicate = predicate[3] + + local handler = predicate_handlers[processed_name] + if not handler then + error(string.format('No handler for %s', orig_predicate[1])) + return false + end - if not xor(is_not, pred_matches) then - return false - end + local does_match = handler(captures, pattern_i, source, orig_predicate) + if does_match ~= should_match then + return false end end return true end ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param directives vim.treesitter.query.ProcessedDirective[] +---@param source integer|string +---@param captures table<integer, TSNode[]> ---@return vim.treesitter.query.TSMetadata metadata -function Query:apply_directives(match, source) +function Query:_apply_directives(directives, pattern_i, captures, source) ---@type vim.treesitter.query.TSMetadata local metadata = {} - local _, pattern = match:info() - local preds = self.info.patterns[pattern] - - if not preds then - return metadata - end - local captures = match:captures() - - for _, pred in pairs(preds) do - if is_directive(pred[1]) then - local handler = directive_handlers[pred[1]] - - if not handler then - error(string.format('No handler for %s', pred[1])) - end + for _, directive in pairs(directives) do + local handler = directive_handlers[directive[1]] - handler(captures, pattern, source, pred, metadata) + if not handler then + error(string.format('No handler for %s', directive[1])) end + + handler(captures, pattern_i, source, directive, metadata) end return metadata @@ -812,26 +926,22 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- @param match TSQueryMatch ---- @return integer -local function match_id_hash(_, match) - return (match:info()) -end - ---- Iterate over all captures from all matches inside {node} +--- Iterates over all captures from all matches in {node}. --- ---- {source} is needed if the query contains predicates; then the caller +--- {source} is required if the query contains predicates; then the caller --- must ensure to use a freshly parsed tree consistent with the current --- text of the buffer (if relevant). {start} and {stop} can be used to limit --- matches inside a row range (this is typically used with root node --- as the {node}, i.e., to get syntax highlight matches in the current --- viewport). When omitted, the {start} and {stop} row values are used from the given node. --- ---- The iterator returns four values: a numeric id identifying the capture, ---- the captured node, metadata from any directives processing the match, ---- and the match itself. ---- The following example shows how to get captures by name: +--- The iterator returns four values: +--- 1. the numeric id identifying the capture +--- 2. the captured node +--- 3. metadata from any directives processing the match +--- 4. the match itself --- +--- Example: how to get captures by name: --- ```lua --- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do --- local name = query.captures[id] -- name of the capture in the query @@ -847,8 +957,8 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch): ---- capture id, capture node, metadata, match +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch, TSTree): +--- capture id, capture node, metadata, match, tree --- ---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) @@ -858,10 +968,14 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) + -- Copy the tree to ensure it is valid during the entire lifetime of the iterator + local tree = node:tree():copy() local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) - local apply_directives = memoize(match_id_hash, self.apply_directives, true) - local match_preds = memoize(match_id_hash, self.match_preds, true) + -- For faster checks that a match is not in the cache. + local highest_cached_match_id = -1 + ---@type table<integer, vim.treesitter.query.TSMetadata> + local match_cache = {} local function iter(end_line) local capture, captured_node, match = cursor:next_capture() @@ -870,18 +984,39 @@ function Query:iter_captures(node, source, start, stop) return end - if not match_preds(self, match, source) then - local match_id = match:info() - cursor:remove_match(match_id) - if end_line and captured_node:range() > end_line then - return nil, captured_node, nil, nil - end - return iter(end_line) -- tail call: try next match + local match_id, pattern_i = match:info() + + --- @type vim.treesitter.query.TSMetadata + local metadata + if match_id <= highest_cached_match_id then + metadata = match_cache[match_id] end - local metadata = apply_directives(self, match, source) + if not metadata then + metadata = {} + + local processed_pattern = self._processed_patterns[pattern_i] + if processed_pattern then + local captures = match:captures() - return capture, captured_node, metadata, match + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then + cursor:remove_match(match_id) + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil, nil + end + return iter(end_line) -- tail call: try next match + end + + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) + end + + highest_cached_match_id = math.max(highest_cached_match_id, match_id) + match_cache[match_id] = metadata + end + + return capture, captured_node, metadata, match, tree end return iter end @@ -903,7 +1038,7 @@ end --- -- `node` was captured by the `name` capture in the match --- --- local node_data = metadata[id] -- Node level metadata ---- ... use the info here ... +--- -- ... use the info here ... --- end --- end --- end @@ -922,7 +1057,7 @@ end --- (last) node instead of the full list of matching nodes. This option is only for backward --- compatibility and will be removed in a future release. --- ----@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata): pattern id, match, metadata +---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata, TSTree): pattern id, match, metadata, tree function Query:iter_matches(node, source, start, stop, opts) opts = opts or {} opts.match_limit = opts.match_limit or 256 @@ -933,6 +1068,8 @@ function Query:iter_matches(node, source, start, stop, opts) start, stop = value_or_node_range(start, stop, node) + -- Copy the tree to ensure it is valid during the entire lifetime of the iterator + local tree = node:tree():copy() local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) local function iter() @@ -942,17 +1079,22 @@ function Query:iter_matches(node, source, start, stop, opts) return end - local match_id, pattern = match:info() + local match_id, pattern_i = match:info() + local processed_pattern = self._processed_patterns[pattern_i] + local captures = match:captures() - if not self:match_preds(match, source) then - cursor:remove_match(match_id) - return iter() -- tail call: try next match + --- @type vim.treesitter.query.TSMetadata + local metadata = {} + if processed_pattern then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match + end + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end - local metadata = self:apply_directives(match, source) - - local captures = match:captures() - if opts.all == false then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow, but we only do it when the caller explicitly opted into it by @@ -961,11 +1103,11 @@ function Query:iter_matches(node, source, start, stop, opts) for k, v in pairs(captures or {}) do old_match[k] = v[#v] end - return pattern, old_match, metadata + return pattern_i, old_match, metadata end -- TODO(lewis6991): create a new function that returns {match, metadata} - return pattern, captures, metadata + return pattern_i, captures, metadata, tree end return iter end |