aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
authorJosh Rahm <joshuarahm@gmail.com>2025-02-05 23:09:29 +0000
committerJosh Rahm <joshuarahm@gmail.com>2025-02-05 23:09:29 +0000
commitd5f194ce780c95821a855aca3c19426576d28ae0 (patch)
treed45f461b19f9118ad2bb1f440a7a08973ad18832 /runtime/lua/vim/treesitter/query.lua
parentc5d770d311841ea5230426cc4c868e8db27300a8 (diff)
parent44740e561fc93afe3ebecfd3618bda2d2abeafb0 (diff)
downloadrneovim-rahm.tar.gz
rneovim-rahm.tar.bz2
rneovim-rahm.zip
Merge remote-tracking branch 'upstream/master' into mix_20240309HEADrahm
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua418
1 files changed, 280 insertions, 138 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 1677e8d364..10fb82e533 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,17 +1,77 @@
+--- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse
+--- text. See |vim.treesitter.query.parse()| for a working example.
+
local api = vim.api
local language = require('vim.treesitter.language')
local memoize = vim.func._memoize
+local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
+local EXTENDS_FORMAT = '^;+%s*extends%s*$'
+
local M = {}
+local function is_directive(name)
+ return string.sub(name, -1) == '!'
+end
+
+---@nodoc
+---@class vim.treesitter.query.ProcessedPredicate
+---@field [1] string predicate name
+---@field [2] boolean should match
+---@field [3] (integer|string)[] the original predicate
+
+---@alias vim.treesitter.query.ProcessedDirective (integer|string)[]
+
+---@nodoc
+---@class vim.treesitter.query.ProcessedPattern {
+---@field predicates vim.treesitter.query.ProcessedPredicate[]
+---@field directives vim.treesitter.query.ProcessedDirective[]
+
+--- Splits the query patterns into predicates and directives.
+---@param patterns table<integer, (integer|string)[][]>
+---@return table<integer, vim.treesitter.query.ProcessedPattern>
+local function process_patterns(patterns)
+ ---@type table<integer, vim.treesitter.query.ProcessedPattern>
+ local processed_patterns = {}
+
+ for k, pattern_list in pairs(patterns) do
+ ---@type vim.treesitter.query.ProcessedPredicate[]
+ local predicates = {}
+ ---@type vim.treesitter.query.ProcessedDirective[]
+ local directives = {}
+
+ for _, pattern in ipairs(pattern_list) do
+ -- Note: tree-sitter strips the leading # from predicates for us.
+ local pred_name = pattern[1]
+ ---@cast pred_name string
+
+ if is_directive(pred_name) then
+ table.insert(directives, pattern)
+ else
+ local should_match = true
+ if pred_name:match('^not%-') then
+ pred_name = pred_name:sub(5)
+ should_match = false
+ end
+ table.insert(predicates, { pred_name, should_match, pattern })
+ end
+ end
+
+ processed_patterns[k] = { predicates = predicates, directives = directives }
+ end
+
+ return processed_patterns
+end
+
---@nodoc
---Parsed query, see |vim.treesitter.query.parse()|
---
---@class vim.treesitter.Query
----@field lang string name of the language for this parser
+---@field lang string parser language name
---@field captures string[] list of (unique) capture names defined in query
----@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives)
+---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
---@field query TSQuery userdata query object
+---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern>
local Query = {}
Query.__index = Query
@@ -30,6 +90,7 @@ function Query.new(lang, ts_query)
patterns = query_info.patterns,
}
self.captures = self.info.captures
+ self._processed_patterns = process_patterns(self.info.patterns)
return self
end
@@ -109,9 +170,6 @@ function M.get_files(lang, query_name, is_included)
-- ;+ inherits: ({language},)*{language}
--
-- {language} ::= {lang} | ({lang})
- local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
- local EXTENDS_FORMAT = '^;+%s*extends%s*$'
-
for _, filename in ipairs(lang_files) do
local file, err = io.open(filename, 'r')
if not file then
@@ -184,8 +242,8 @@ local function read_query_files(filenames)
return table.concat(contents, '')
end
--- The explicitly set queries from |vim.treesitter.query.set()|
----@type table<string,table<string,vim.treesitter.Query>>
+-- The explicitly set query strings from |vim.treesitter.query.set()|
+---@type table<string,table<string,string>>
local explicit_queries = setmetatable({}, {
__index = function(t, k)
local lang_queries = {}
@@ -197,14 +255,27 @@ local explicit_queries = setmetatable({}, {
--- Sets the runtime query named {query_name} for {lang}
---
---- This allows users to override any runtime files and/or configuration
+--- This allows users to override or extend any runtime files and/or configuration
--- set by plugins.
---
+--- For example, you could enable spellchecking of `C` identifiers with the
+--- following code:
+--- ```lua
+--- vim.treesitter.query.set(
+--- 'c',
+--- 'highlights',
+--- [[;inherits c
+--- (identifier) @spell]])
+--- ]])
+--- ```
+---
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g., "highlights")
---@param text string Query text (unparsed).
function M.set(lang, query_name, text)
- explicit_queries[lang][query_name] = M.parse(lang, text)
+ --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
+ M.get:clear(lang, query_name)
+ explicit_queries[lang][query_name] = text
end
--- Returns the runtime query {query_name} for {lang}.
@@ -214,34 +285,82 @@ end
---
---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found.
M.get = memoize('concat-2', function(lang, query_name)
+ local query_string ---@type string
+
if explicit_queries[lang][query_name] then
- return explicit_queries[lang][query_name]
- end
+ local query_files = {}
+ local base_langs = {} ---@type string[]
- local query_files = M.get_files(lang, query_name)
- local query_string = read_query_files(query_files)
+ for line in explicit_queries[lang][query_name]:gmatch('([^\n]*)\n?') do
+ if not vim.startswith(line, ';') then
+ break
+ end
+
+ local lang_list = line:match(MODELINE_FORMAT)
+ if lang_list then
+ for _, incl_lang in ipairs(vim.split(lang_list, ',')) do
+ local is_optional = incl_lang:match('%(.*%)')
+
+ if is_optional then
+ add_included_lang(base_langs, lang, incl_lang:sub(2, #incl_lang - 1))
+ else
+ add_included_lang(base_langs, lang, incl_lang)
+ end
+ end
+ elseif line:match(EXTENDS_FORMAT) then
+ table.insert(base_langs, lang)
+ end
+ end
+
+ for _, base_lang in ipairs(base_langs) do
+ local base_files = M.get_files(base_lang, query_name, true)
+ vim.list_extend(query_files, base_files)
+ end
+
+ query_string = read_query_files(query_files) .. explicit_queries[lang][query_name]
+ else
+ local query_files = M.get_files(lang, query_name)
+ query_string = read_query_files(query_files)
+ end
if #query_string == 0 then
return nil
end
return M.parse(lang, query_string)
-end)
+end, false)
+
+api.nvim_create_autocmd('OptionSet', {
+ pattern = { 'runtimepath' },
+ group = api.nvim_create_augroup('nvim.treesitter.query_cache_reset', { clear = true }),
+ callback = function()
+ --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
+ M.get:clear()
+ end,
+})
---- Parse {query} as a string. (If the query is in a file, the caller
---- should read the contents into a string before calling).
----
---- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
---- search nodes in the syntax tree for the patterns defined in {query}
---- using the `iter_captures` and `iter_matches` methods.
+--- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used
+--- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|),
+--- or inspect the query via these fields:
+--- - `captures`: a list of unique capture names defined in the query (alias: `info.captures`).
+--- - `info.patterns`: information about predicates.
---
---- Exposes `info` and `captures` with additional context about {query}.
---- - `captures` contains the list of unique capture names defined in {query}.
---- - `info.captures` also points to `captures`.
---- - `info.patterns` contains information about predicates.
+--- Example:
+--- ```lua
+--- local query = vim.treesitter.query.parse('vimdoc', [[
+--- ; query
+--- ((h1) @str
+--- (#trim! @str 1 1 1 1))
+--- ]])
+--- local tree = vim.treesitter.get_parser():parse()[1]
+--- for id, node, metadata in query:iter_captures(tree:root(), 0) do
+--- -- Print the node name and source text.
+--- vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())})
+--- end
+--- ```
---
---@param lang string Language to use for the query
----@param query string Query in s-expr syntax
+---@param query string Query text, in s-expr syntax
---
---@return vim.treesitter.Query : Parsed query
---
@@ -250,7 +369,7 @@ M.parse = memoize('concat-2', function(lang, query)
assert(language.add(lang))
local ts_query = vim._ts_parse_query(lang, query)
return Query.new(lang, ts_query)
-end)
+end, false)
--- Implementations of predicates that can optionally be prefixed with "any-".
---
@@ -572,13 +691,17 @@ local directive_handlers = {
metadata[id].text = text:gsub(pattern, replacement)
end,
- -- Trim blank lines from end of the node
- -- Example: (#trim! @fold)
- -- TODO(clason): generalize to arbitrary whitespace removal
+ -- Trim whitespace from both sides of the node
+ -- Example: (#trim! @fold 1 1 1 1)
['trim!'] = function(match, _, bufnr, pred, metadata)
local capture_id = pred[2]
assert(type(capture_id) == 'number')
+ local trim_start_lines = pred[3] == '1'
+ local trim_start_cols = pred[4] == '1'
+ local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility
+ local trim_end_cols = pred[6] == '1'
+
local nodes = match[capture_id]
if not nodes or #nodes == 0 then
return
@@ -588,20 +711,45 @@ local directive_handlers = {
local start_row, start_col, end_row, end_col = node:range()
- -- Don't trim if region ends in middle of a line
- if end_col ~= 0 then
- return
+ local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n')
+ if end_col == 0 then
+ -- get_node_text() will ignore the last line if the node ends at column 0
+ node_text[#node_text + 1] = ''
end
- while end_row >= start_row do
- -- As we only care when end_col == 0, always inspect one line above end_row.
- local end_line = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1]
+ local end_idx = #node_text
+ local start_idx = 1
- if end_line ~= '' then
- break
+ if trim_end_lines then
+ while end_idx > 0 and node_text[end_idx]:find('^%s*$') do
+ end_idx = end_idx - 1
+ end_row = end_row - 1
+ -- set the end position to the last column of the next line, or 0 if we just trimmed the
+ -- last line
+ end_col = end_idx > 0 and #node_text[end_idx] or 0
end
+ end
+ if trim_end_cols then
+ if end_idx == 0 then
+ end_row = start_row
+ end_col = start_col
+ else
+ local whitespace_start = node_text[end_idx]:find('(%s*)$')
+ end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0)
+ end
+ end
- end_row = end_row - 1
+ if trim_start_lines then
+ while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do
+ start_idx = start_idx + 1
+ start_row = start_row + 1
+ start_col = 0
+ end
+ end
+ if trim_start_cols and node_text[start_idx] then
+ local _, whitespace_end = node_text[start_idx]:find('^(%s*)')
+ whitespace_end = whitespace_end or 0
+ start_col = (start_idx == 1 and start_col or 0) + whitespace_end
end
-- If this produces an invalid range, we just skip it.
@@ -711,84 +859,50 @@ function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
end
-local function xor(x, y)
- return (x or y) and not (x and y)
-end
-
-local function is_directive(name)
- return string.sub(name, -1) == '!'
-end
-
---@private
----@param match TSQueryMatch
+---@param pattern_i integer
+---@param predicates vim.treesitter.query.ProcessedPredicate[]
+---@param captures table<integer, TSNode[]>
---@param source integer|string
-function Query:match_preds(match, source)
- local _, pattern = match:info()
- local preds = self.info.patterns[pattern]
-
- if not preds then
- return true
- end
-
- local captures = match:captures()
-
- for _, pred in pairs(preds) do
- -- Here we only want to return if a predicate DOES NOT match, and
- -- continue on the other case. This way unknown predicates will not be considered,
- -- which allows some testing and easier user extensibility (#12173).
- -- Also, tree-sitter strips the leading # from predicates for us.
- local is_not = false
-
- -- Skip over directives... they will get processed after all the predicates.
- if not is_directive(pred[1]) then
- local pred_name = pred[1]
- if pred_name:match('^not%-') then
- pred_name = pred_name:sub(5)
- is_not = true
- end
-
- local handler = predicate_handlers[pred_name]
-
- if not handler then
- error(string.format('No handler for %s', pred[1]))
- return false
- end
-
- local pred_matches = handler(captures, pattern, source, pred)
+---@return boolean whether the predicates match
+function Query:_match_predicates(predicates, pattern_i, captures, source)
+ for _, predicate in ipairs(predicates) do
+ local processed_name = predicate[1]
+ local should_match = predicate[2]
+ local orig_predicate = predicate[3]
+
+ local handler = predicate_handlers[processed_name]
+ if not handler then
+ error(string.format('No handler for %s', orig_predicate[1]))
+ return false
+ end
- if not xor(is_not, pred_matches) then
- return false
- end
+ local does_match = handler(captures, pattern_i, source, orig_predicate)
+ if does_match ~= should_match then
+ return false
end
end
return true
end
---@private
----@param match TSQueryMatch
+---@param pattern_i integer
+---@param directives vim.treesitter.query.ProcessedDirective[]
+---@param source integer|string
+---@param captures table<integer, TSNode[]>
---@return vim.treesitter.query.TSMetadata metadata
-function Query:apply_directives(match, source)
+function Query:_apply_directives(directives, pattern_i, captures, source)
---@type vim.treesitter.query.TSMetadata
local metadata = {}
- local _, pattern = match:info()
- local preds = self.info.patterns[pattern]
-
- if not preds then
- return metadata
- end
- local captures = match:captures()
-
- for _, pred in pairs(preds) do
- if is_directive(pred[1]) then
- local handler = directive_handlers[pred[1]]
-
- if not handler then
- error(string.format('No handler for %s', pred[1]))
- end
+ for _, directive in pairs(directives) do
+ local handler = directive_handlers[directive[1]]
- handler(captures, pattern, source, pred, metadata)
+ if not handler then
+ error(string.format('No handler for %s', directive[1]))
end
+
+ handler(captures, pattern_i, source, directive, metadata)
end
return metadata
@@ -812,26 +926,22 @@ local function value_or_node_range(start, stop, node)
return start, stop
end
---- @param match TSQueryMatch
---- @return integer
-local function match_id_hash(_, match)
- return (match:info())
-end
-
---- Iterate over all captures from all matches inside {node}
+--- Iterates over all captures from all matches in {node}.
---
---- {source} is needed if the query contains predicates; then the caller
+--- {source} is required if the query contains predicates; then the caller
--- must ensure to use a freshly parsed tree consistent with the current
--- text of the buffer (if relevant). {start} and {stop} can be used to limit
--- matches inside a row range (this is typically used with root node
--- as the {node}, i.e., to get syntax highlight matches in the current
--- viewport). When omitted, the {start} and {stop} row values are used from the given node.
---
---- The iterator returns four values: a numeric id identifying the capture,
---- the captured node, metadata from any directives processing the match,
---- and the match itself.
---- The following example shows how to get captures by name:
+--- The iterator returns four values:
+--- 1. the numeric id identifying the capture
+--- 2. the captured node
+--- 3. metadata from any directives processing the match
+--- 4. the match itself
---
+--- Example: how to get captures by name:
--- ```lua
--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do
--- local name = query.captures[id] -- name of the capture in the query
@@ -847,8 +957,8 @@ end
---@param start? integer Starting line for the search. Defaults to `node:start()`.
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
---
----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch):
---- capture id, capture node, metadata, match
+---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch, TSTree):
+--- capture id, capture node, metadata, match, tree
---
---@note Captures are only returned if the query pattern of a specific capture contained predicates.
function Query:iter_captures(node, source, start, stop)
@@ -858,10 +968,14 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
+ -- Copy the tree to ensure it is valid during the entire lifetime of the iterator
+ local tree = node:tree():copy()
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
- local apply_directives = memoize(match_id_hash, self.apply_directives, true)
- local match_preds = memoize(match_id_hash, self.match_preds, true)
+ -- For faster checks that a match is not in the cache.
+ local highest_cached_match_id = -1
+ ---@type table<integer, vim.treesitter.query.TSMetadata>
+ local match_cache = {}
local function iter(end_line)
local capture, captured_node, match = cursor:next_capture()
@@ -870,18 +984,39 @@ function Query:iter_captures(node, source, start, stop)
return
end
- if not match_preds(self, match, source) then
- local match_id = match:info()
- cursor:remove_match(match_id)
- if end_line and captured_node:range() > end_line then
- return nil, captured_node, nil, nil
- end
- return iter(end_line) -- tail call: try next match
+ local match_id, pattern_i = match:info()
+
+ --- @type vim.treesitter.query.TSMetadata
+ local metadata
+ if match_id <= highest_cached_match_id then
+ metadata = match_cache[match_id]
end
- local metadata = apply_directives(self, match, source)
+ if not metadata then
+ metadata = {}
+
+ local processed_pattern = self._processed_patterns[pattern_i]
+ if processed_pattern then
+ local captures = match:captures()
- return capture, captured_node, metadata, match
+ local predicates = processed_pattern.predicates
+ if not self:_match_predicates(predicates, pattern_i, captures, source) then
+ cursor:remove_match(match_id)
+ if end_line and captured_node:range() > end_line then
+ return nil, captured_node, nil, nil
+ end
+ return iter(end_line) -- tail call: try next match
+ end
+
+ local directives = processed_pattern.directives
+ metadata = self:_apply_directives(directives, pattern_i, captures, source)
+ end
+
+ highest_cached_match_id = math.max(highest_cached_match_id, match_id)
+ match_cache[match_id] = metadata
+ end
+
+ return capture, captured_node, metadata, match, tree
end
return iter
end
@@ -903,7 +1038,7 @@ end
--- -- `node` was captured by the `name` capture in the match
---
--- local node_data = metadata[id] -- Node level metadata
---- ... use the info here ...
+--- -- ... use the info here ...
--- end
--- end
--- end
@@ -922,7 +1057,7 @@ end
--- (last) node instead of the full list of matching nodes. This option is only for backward
--- compatibility and will be removed in a future release.
---
----@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata): pattern id, match, metadata
+---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata, TSTree): pattern id, match, metadata, tree
function Query:iter_matches(node, source, start, stop, opts)
opts = opts or {}
opts.match_limit = opts.match_limit or 256
@@ -933,6 +1068,8 @@ function Query:iter_matches(node, source, start, stop, opts)
start, stop = value_or_node_range(start, stop, node)
+ -- Copy the tree to ensure it is valid during the entire lifetime of the iterator
+ local tree = node:tree():copy()
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
local function iter()
@@ -942,17 +1079,22 @@ function Query:iter_matches(node, source, start, stop, opts)
return
end
- local match_id, pattern = match:info()
+ local match_id, pattern_i = match:info()
+ local processed_pattern = self._processed_patterns[pattern_i]
+ local captures = match:captures()
- if not self:match_preds(match, source) then
- cursor:remove_match(match_id)
- return iter() -- tail call: try next match
+ --- @type vim.treesitter.query.TSMetadata
+ local metadata = {}
+ if processed_pattern then
+ local predicates = processed_pattern.predicates
+ if not self:_match_predicates(predicates, pattern_i, captures, source) then
+ cursor:remove_match(match_id)
+ return iter() -- tail call: try next match
+ end
+ local directives = processed_pattern.directives
+ metadata = self:_apply_directives(directives, pattern_i, captures, source)
end
- local metadata = self:apply_directives(match, source)
-
- local captures = match:captures()
-
if opts.all == false then
-- Convert the match table into the old buggy version for backward
-- compatibility. This is slow, but we only do it when the caller explicitly opted into it by
@@ -961,11 +1103,11 @@ function Query:iter_matches(node, source, start, stop, opts)
for k, v in pairs(captures or {}) do
old_match[k] = v[#v]
end
- return pattern, old_match, metadata
+ return pattern_i, old_match, metadata
end
-- TODO(lewis6991): create a new function that returns {match, metadata}
- return pattern, captures, metadata
+ return pattern_i, captures, metadata, tree
end
return iter
end