aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua300
-rw-r--r--runtime/lua/vim/treesitter/_meta.lua28
-rw-r--r--runtime/lua/vim/treesitter/_query_linter.lua46
-rw-r--r--runtime/lua/vim/treesitter/dev.lua228
-rw-r--r--runtime/lua/vim/treesitter/health.lua2
-rw-r--r--runtime/lua/vim/treesitter/highlighter.lua194
-rw-r--r--runtime/lua/vim/treesitter/language.lua40
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua95
-rw-r--r--runtime/lua/vim/treesitter/query.lua602
9 files changed, 861 insertions, 674 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 5c1cc06908..d96cc966de 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -5,35 +5,20 @@ local Range = require('vim.treesitter._range')
local api = vim.api
---@class TS.FoldInfo
----@field levels table<integer,string>
----@field levels0 table<integer,integer>
----@field private start_counts table<integer,integer>
----@field private stop_counts table<integer,integer>
+---@field levels string[] the foldexpr result for each line
+---@field levels0 integer[] the raw fold levels
+---@field edits? {[1]: integer, [2]: integer} line range edited since the last invocation of the callback scheduled in on_bytes. 0-indexed, end-exclusive.
local FoldInfo = {}
FoldInfo.__index = FoldInfo
---@private
function FoldInfo.new()
return setmetatable({
- start_counts = {},
- stop_counts = {},
levels0 = {},
levels = {},
}, FoldInfo)
end
----@package
----@param srow integer
----@param erow integer
-function FoldInfo:invalidate_range(srow, erow)
- for i = srow, erow do
- self.start_counts[i + 1] = nil
- self.stop_counts[i + 1] = nil
- self.levels0[i + 1] = nil
- self.levels[i + 1] = nil
- end
-end
-
--- Efficiently remove items from middle of a list a list.
---
--- Calling table.remove() in a loop will re-index the tail of the table on
@@ -55,12 +40,10 @@ end
---@package
---@param srow integer
----@param erow integer
+---@param erow integer 0-indexed, exclusive
function FoldInfo:remove_range(srow, erow)
list_remove(self.levels, srow + 1, erow)
list_remove(self.levels0, srow + 1, erow)
- list_remove(self.start_counts, srow + 1, erow)
- list_remove(self.stop_counts, srow + 1, erow)
end
--- Efficiently insert items into the middle of a list.
@@ -91,46 +74,37 @@ end
---@package
---@param srow integer
----@param erow integer
+---@param erow integer 0-indexed, exclusive
function FoldInfo:add_range(srow, erow)
- list_insert(self.levels, srow + 1, erow, '-1')
+ list_insert(self.levels, srow + 1, erow, '=')
list_insert(self.levels0, srow + 1, erow, -1)
- list_insert(self.start_counts, srow + 1, erow, nil)
- list_insert(self.stop_counts, srow + 1, erow, nil)
end
---@package
----@param lnum integer
-function FoldInfo:add_start(lnum)
- self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1
-end
-
----@package
----@param lnum integer
-function FoldInfo:add_stop(lnum)
- self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1
-end
-
----@package
----@param lnum integer
----@return integer
-function FoldInfo:get_start(lnum)
- return self.start_counts[lnum] or 0
+---@param srow integer
+---@param erow_old integer
+---@param erow_new integer 0-indexed, exclusive
+function FoldInfo:edit_range(srow, erow_old, erow_new)
+ if self.edits then
+ self.edits[1] = math.min(srow, self.edits[1])
+ if erow_old <= self.edits[2] then
+ self.edits[2] = self.edits[2] + (erow_new - erow_old)
+ end
+ self.edits[2] = math.max(self.edits[2], erow_new)
+ else
+ self.edits = { srow, erow_new }
+ end
end
---@package
----@param lnum integer
----@return integer
-function FoldInfo:get_stop(lnum)
- return self.stop_counts[lnum] or 0
-end
-
-local function trim_level(level)
- local max_fold_level = vim.wo.foldnestmax
- if level > max_fold_level then
- return max_fold_level
+---@return integer? srow
+---@return integer? erow 0-indexed, exclusive
+function FoldInfo:flush_edit()
+ if self.edits then
+ local srow, erow = self.edits[1], self.edits[2]
+ self.edits = nil
+ return srow, erow
end
- return level
end
--- If a parser doesn't have any ranges explicitly set, treesitter will
@@ -140,10 +114,10 @@ end
--- TODO(lewis6991): Handle this generally
---
--- @param bufnr integer
---- @param erow integer?
+--- @param erow integer? 0-indexed, exclusive
--- @return integer
local function normalise_erow(bufnr, erow)
- local max_erow = api.nvim_buf_line_count(bufnr) - 1
+ local max_erow = api.nvim_buf_line_count(bufnr)
return math.min(erow or max_erow, max_erow)
end
@@ -152,31 +126,30 @@ end
---@param bufnr integer
---@param info TS.FoldInfo
---@param srow integer?
----@param erow integer?
+---@param erow integer? 0-indexed, exclusive
---@param parse_injections? boolean
local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
srow = srow or 0
erow = normalise_erow(bufnr, erow)
- info:invalidate_range(srow, erow)
-
- local prev_start = -1
- local prev_stop = -1
-
local parser = ts.get_parser(bufnr)
parser:parse(parse_injections and { srow, erow } or nil)
+ local enter_counts = {} ---@type table<integer, integer>
+ local leave_counts = {} ---@type table<integer, integer>
+ local prev_start = -1
+ local prev_stop = -1
+
parser:for_each_tree(function(tree, ltree)
local query = ts.query.get(ltree:lang(), 'folds')
if not query then
return
end
- -- erow in query is end-exclusive
- local q_erow = erow and erow + 1 or -1
-
- for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do
+ -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
+ -- srow - 1 from the level of srow - 1 to get accurate level of srow.
+ for id, node, metadata in query:iter_captures(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
if query.captures[id] == 'fold' then
local range = ts.get_range(node, bufnr, metadata[id])
local start, _, stop, stop_col = Range.unpack4(range)
@@ -193,8 +166,8 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
if
fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
then
- info:add_start(start + 1)
- info:add_stop(stop + 1)
+ enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
+ leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
prev_start = start
prev_stop = stop
end
@@ -202,16 +175,15 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
end
end)
- local current_level = info.levels0[srow] or 0
+ local nestmax = vim.wo.foldnestmax
+ local level0_prev = info.levels0[srow] or 0
+ local leave_prev = leave_counts[srow] or 0
-- We now have the list of fold opening and closing, fill the gaps and mark where fold start
- for lnum = srow + 1, erow + 1 do
- local last_trimmed_level = trim_level(current_level)
- current_level = current_level + info:get_start(lnum)
- info.levels0[lnum] = current_level
-
- local trimmed_level = trim_level(current_level)
- current_level = current_level - info:get_stop(lnum)
+ for lnum = srow + 1, erow do
+ local enter_line = enter_counts[lnum] or 0
+ local leave_line = leave_counts[lnum] or 0
+ local level0 = level0_prev - leave_prev + enter_line
-- Determine if it's the start/end of a fold
-- NB: vim's fold-expr interface does not have a mechanism to indicate that
@@ -219,14 +191,36 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
-- ( \n ( \n )) \n (( \n ) \n )
-- versus
-- ( \n ( \n ) \n ( \n ) \n )
- -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
+ -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
+ -- vim interprets as the second case.
+ -- If it did have such a mechanism, (clamped - clamped_prev)
-- would be the correct number of starts to pass on.
+ local adjusted = level0 ---@type integer
local prefix = ''
- if trimmed_level - last_trimmed_level > 0 then
+ if enter_line > 0 then
prefix = '>'
+ if leave_line > 0 then
+ -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
+ -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
+ -- foldminlines, but we don't handle it for simplicity.
+ adjusted = level0 - leave_line
+ leave_line = 0
+ end
+ end
+
+ -- Clamp at foldnestmax.
+ local clamped = adjusted
+ if adjusted > nestmax then
+ prefix = ''
+ clamped = nestmax
end
- info.levels[lnum] = prefix .. tostring(trimmed_level)
+ -- Record the "real" level, so that it can be used as "base" of later get_folds_levels().
+ info.levels0[lnum] = adjusted
+ info.levels[lnum] = prefix .. tostring(clamped)
+
+ leave_prev = leave_line
+ level0_prev = adjusted
end
end
@@ -296,8 +290,12 @@ end
local function on_changedtree(bufnr, foldinfo, tree_changes)
schedule_if_loaded(bufnr, function()
for _, change in ipairs(tree_changes) do
- local srow, _, erow = Range.unpack4(change)
- get_folds_levels(bufnr, foldinfo, srow, erow)
+ local srow, _, erow, ecol = Range.unpack4(change)
+ if ecol > 0 then
+ erow = erow + 1
+ end
+ -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
+ get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow)
end
if #tree_changes > 0 then
foldupdate(bufnr)
@@ -309,19 +307,46 @@ end
---@param foldinfo TS.FoldInfo
---@param start_row integer
---@param old_row integer
+---@param old_col integer
---@param new_row integer
-local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
- local end_row_old = start_row + old_row
- local end_row_new = start_row + new_row
+---@param new_col integer
+local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col)
+ -- extend the end to fully include the range
+ local end_row_old = start_row + old_row + 1
+ local end_row_new = start_row + new_row + 1
if new_row ~= old_row then
+ -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the
+ -- outdated levels, which may spuriously open the folds that didn't change. So we should shift
+ -- folds as accurately as possible. For this to be perfectly accurate, we should track the
+ -- actual TSNodes that account for each fold, and compare the node's range with the edited
+ -- range. But for simplicity, we just check whether the start row is completely removed (e.g.,
+ -- `dd`) or shifted (e.g., `o`).
if new_row < old_row then
- foldinfo:remove_range(end_row_new, end_row_old)
+ if start_col == 0 and new_row == 0 and new_col == 0 then
+ foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new))
+ else
+ foldinfo:remove_range(end_row_new, end_row_old)
+ end
else
- foldinfo:add_range(start_row, end_row_new)
+ if start_col == 0 and old_row == 0 and old_col == 0 then
+ foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old))
+ else
+ foldinfo:add_range(end_row_old, end_row_new)
+ end
end
+ foldinfo:edit_range(start_row, end_row_old, end_row_new)
+
+ -- This callback must not use on_bytes arguments, because they can be outdated when the callback
+ -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
+ -- the scheduled callback. So we should collect the edits.
schedule_if_loaded(bufnr, function()
- get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
+ local srow, erow = foldinfo:flush_edit()
+ if not srow then
+ return
+ end
+ -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
+ get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow)
foldupdate(bufnr)
end)
end
@@ -348,8 +373,8 @@ function M.foldexpr(lnum)
on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
end,
- on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
- on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
+ on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _)
+ on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col)
end,
on_detach = function()
@@ -361,96 +386,15 @@ function M.foldexpr(lnum)
return foldinfos[bufnr].levels[lnum] or '0'
end
----@package
----@return { [1]: string, [2]: string[] }[]|string
-function M.foldtext()
- local foldstart = vim.v.foldstart
- local bufnr = api.nvim_get_current_buf()
-
- ---@type boolean, LanguageTree
- local ok, parser = pcall(ts.get_parser, bufnr)
- if not ok then
- return vim.fn.foldtext()
- end
-
- local query = ts.query.get(parser:lang(), 'highlights')
- if not query then
- return vim.fn.foldtext()
- end
-
- local tree = parser:parse({ foldstart - 1, foldstart })[1]
-
- local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1]
- if not line then
- return vim.fn.foldtext()
- end
-
- ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[]
- local result = {}
-
- local line_pos = 0
-
- for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do
- local name = query.captures[id]
- local start_row, start_col, end_row, end_col = node:range()
-
- local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter)
-
- if start_row == foldstart - 1 and end_row == foldstart - 1 then
- -- check for characters ignored by treesitter
- if start_col > line_pos then
- table.insert(result, {
- line:sub(line_pos + 1, start_col),
- {},
- range = { line_pos, start_col },
- })
- end
- line_pos = end_col
-
- local text = line:sub(start_col + 1, end_col)
- table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } })
- end
- end
-
- local i = 1
- while i <= #result do
- -- find first capture that is not in current range and apply highlights on the way
- local j = i + 1
- while
- j <= #result
- and result[j].range[1] >= result[i].range[1]
- and result[j].range[2] <= result[i].range[2]
- do
- for k, v in ipairs(result[i][2]) do
- if not vim.tbl_contains(result[j][2], v) then
- table.insert(result[j][2], k, v)
- end
- end
- j = j + 1
- end
-
- -- remove the parent capture if it is split into children
- if j > i + 1 then
- table.remove(result, i)
- else
- -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier
- -- in list) should be considered higher prio
- if #result[i][2] > 1 then
- table.sort(result[i][2], function(a, b)
- return a[2] < b[2]
- end)
- end
-
- result[i][2] = vim.tbl_map(function(tbl)
- return tbl[1]
- end, result[i][2])
- result[i] = { result[i][1], result[i][2] }
-
- i = i + 1
+api.nvim_create_autocmd('OptionSet', {
+ pattern = { 'foldminlines', 'foldnestmax' },
+ desc = 'Refresh treesitter folds',
+ callback = function()
+ for _, bufnr in ipairs(vim.tbl_keys(foldinfos)) do
+ foldinfos[bufnr] = FoldInfo.new()
+ get_folds_levels(bufnr, foldinfos[bufnr])
+ foldupdate(bufnr)
end
- end
-
- return result
-end
-
+ end,
+})
return M
diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua
index 80c998b555..19d97d2820 100644
--- a/runtime/lua/vim/treesitter/_meta.lua
+++ b/runtime/lua/vim/treesitter/_meta.lua
@@ -1,4 +1,5 @@
---@meta
+error('Cannot require a meta file')
---@class TSNode: userdata
---@field id fun(self: TSNode): string
@@ -33,27 +34,26 @@
---@field byte_length fun(self: TSNode): integer
local TSNode = {}
----@param query userdata
+---@param query TSQuery
---@param captures true
---@param start? integer
---@param end_? integer
---@param opts? table
----@return fun(): integer, TSNode, any
+---@return fun(): integer, TSNode, vim.treesitter.query.TSMatch
function TSNode:_rawquery(query, captures, start, end_, opts) end
----@param query userdata
+---@param query TSQuery
---@param captures false
---@param start? integer
---@param end_? integer
---@param opts? table
----@return fun(): string, any
+---@return fun(): integer, vim.treesitter.query.TSMatch
function TSNode:_rawquery(query, captures, start, end_, opts) end
---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)
----@class TSParser
----@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: true): TSTree, Range6[]
----@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: false|nil): TSTree, Range4[]
+---@class TSParser: userdata
+---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: boolean): TSTree, (Range4|Range6)[]
---@field reset fun(self: TSParser)
---@field included_ranges fun(self: TSParser, include_bytes: boolean?): integer[]
---@field set_included_ranges fun(self: TSParser, ranges: (Range6|TSNode)[])
@@ -62,19 +62,31 @@ function TSNode:_rawquery(query, captures, start, end_, opts) end
---@field _set_logger fun(self: TSParser, lex: boolean, parse: boolean, cb: TSLoggerCallback)
---@field _logger fun(self: TSParser): TSLoggerCallback
----@class TSTree
+---@class TSTree: userdata
---@field root fun(self: TSTree): TSNode
---@field edit fun(self: TSTree, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _:integer)
---@field copy fun(self: TSTree): TSTree
---@field included_ranges fun(self: TSTree, include_bytes: true): Range6[]
---@field included_ranges fun(self: TSTree, include_bytes: false): Range4[]
+---@class TSQuery: userdata
+---@field inspect fun(self: TSQuery): TSQueryInfo
+
+---@class (exact) TSQueryInfo
+---@field captures string[]
+---@field patterns table<integer, (integer|string)[][]>
+
---@return integer
vim._ts_get_language_version = function() end
---@return integer
vim._ts_get_minimum_language_version = function() end
+---@param lang string Language to use for the query
+---@param query string Query string in s-expr syntax
+---@return TSQuery
+vim._ts_parse_query = function(lang, query) end
+
---@param lang string
---@return TSParser
vim._create_ts_parser = function(lang) end
diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua
index 87d74789a3..6216d4e891 100644
--- a/runtime/lua/vim/treesitter/_query_linter.lua
+++ b/runtime/lua/vim/treesitter/_query_linter.lua
@@ -17,7 +17,7 @@ local M = {}
--- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`
--- Adds a diagnostic for node in the query buffer
---- @param diagnostics Diagnostic[]
+--- @param diagnostics vim.Diagnostic[]
--- @param range Range4
--- @param lint string
--- @param lang string?
@@ -45,7 +45,7 @@ local function guess_query_lang(buf)
end
--- @param buf integer
---- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil
+--- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil
--- @return QueryLinterNormalizedOpts
local function normalize_opts(buf, opts)
opts = opts or {}
@@ -92,7 +92,7 @@ local function get_error_entry(err, node)
end_col = end_col + #underlined
elseif msg:match('^Invalid') then
-- Use the length of the problematic type/capture/field
- end_col = end_col + #msg:match('"([^"]+)"')
+ end_col = end_col + #(msg:match('"([^"]+)"') or '')
end
return {
@@ -114,7 +114,7 @@ end
--- @return vim.treesitter.ParseError?
local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
local query_text = vim.treesitter.get_node_text(node, buf)
- local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query
+ local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|vim.treesitter.Query
if not ok and type(err) == 'string' then
return get_error_entry(err, node)
@@ -122,28 +122,30 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
end)
--- @param buf integer
---- @param match table<integer,TSNode>
---- @param query Query
+--- @param match vim.treesitter.query.TSMatch
+--- @param query vim.treesitter.Query
--- @param lang_context QueryLinterLanguageContext
---- @param diagnostics Diagnostic[]
+--- @param diagnostics vim.Diagnostic[]
local function lint_match(buf, match, query, lang_context, diagnostics)
local lang = lang_context.lang
local parser_info = lang_context.parser_info
- for id, node in pairs(match) do
- local cap_id = query.captures[id]
+ for id, nodes in pairs(match) do
+ for _, node in ipairs(nodes) do
+ local cap_id = query.captures[id]
- -- perform language-independent checks only for first lang
- if lang_context.is_first_lang and cap_id == 'error' then
- local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
- add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
- end
+ -- perform language-independent checks only for first lang
+ if lang_context.is_first_lang and cap_id == 'error' then
+ local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
+ add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
+ end
- -- other checks rely on Neovim parser introspection
- if lang and parser_info and cap_id == 'toplevel' then
- local err = parse(node, buf, lang)
- if err then
- add_lint_for_node(diagnostics, err.range, err.msg, lang)
+ -- other checks rely on Neovim parser introspection
+ if lang and parser_info and cap_id == 'toplevel' then
+ local err = parse(node, buf, lang)
+ if err then
+ add_lint_for_node(diagnostics, err.range, err.msg, lang)
+ end
end
end
end
@@ -151,7 +153,7 @@ end
--- @private
--- @param buf integer Buffer to lint
---- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil Options for linting
+--- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil Options for linting
function M.lint(buf, opts)
if buf == 0 then
buf = api.nvim_get_current_buf()
@@ -173,7 +175,7 @@ function M.lint(buf, opts)
parser:parse()
parser:for_each_tree(function(tree, ltree)
if ltree:lang() == 'query' then
- for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do
+ for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1, { all = true }) do
local lang_context = {
lang = lang,
parser_info = parser_info,
@@ -195,7 +197,7 @@ function M.clear(buf)
end
--- @private
---- @param findstart integer
+--- @param findstart 0|1
--- @param base string
function M.omnifunc(findstart, base)
if findstart == 1 then
diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua
index 69ddc9b558..dc2a14d238 100644
--- a/runtime/lua/vim/treesitter/dev.lua
+++ b/runtime/lua/vim/treesitter/dev.lua
@@ -1,31 +1,29 @@
local api = vim.api
----@class TSDevModule
local M = {}
----@class TSTreeView
+---@class (private) vim.treesitter.dev.TSTreeView
---@field ns integer API namespace
----@field opts table Options table with the following keys:
---- - anon (boolean): If true, display anonymous nodes
---- - lang (boolean): If true, display the language alongside each node
---- - indent (number): Number of spaces to indent nested lines. Default is 2.
----@field nodes TSP.Node[]
----@field named TSP.Node[]
+---@field opts vim.treesitter.dev.TSTreeViewOpts
+---@field nodes vim.treesitter.dev.Node[]
+---@field named vim.treesitter.dev.Node[]
local TSTreeView = {}
----@class TSP.Node
----@field id integer Node id
----@field text string Node text
----@field named boolean True if this is a named (non-anonymous) node
----@field depth integer Depth of the node within the tree
----@field lnum integer Beginning line number of this node in the source buffer
----@field col integer Beginning column number of this node in the source buffer
----@field end_lnum integer Final line number of this node in the source buffer
----@field end_col integer Final column number of this node in the source buffer
+---@private
+---@class (private) vim.treesitter.dev.TSTreeViewOpts
+---@field anon boolean If true, display anonymous nodes.
+---@field lang boolean If true, display the language alongside each node.
+---@field indent number Number of spaces to indent nested lines.
+
+---@class (private) vim.treesitter.dev.Node
+---@field node TSNode Treesitter node
+---@field field string? Node field
+---@field depth integer Depth of this node in the tree
+---@field text string? Text displayed in the inspector for this node. Not computed until the
+--- inspector is drawn.
---@field lang string Source language of this node
----@field root TSNode
----@class TSP.Injection
+---@class (private) vim.treesitter.dev.Injection
---@field lang string Source language of this injection
---@field root TSNode Root node of the injection
@@ -43,48 +41,26 @@ local TSTreeView = {}
---
---@param node TSNode Starting node to begin traversal |tsnode|
---@param depth integer Current recursion depth
+---@param field string|nil The field of the current node
---@param lang string Language of the tree currently being traversed
----@param injections table<string, TSP.Injection> Mapping of node ids to root nodes
+---@param injections table<string, vim.treesitter.dev.Injection> Mapping of node ids to root nodes
--- of injected language trees (see explanation above)
----@param tree TSP.Node[] Output table containing a list of tables each representing a node in the tree
-local function traverse(node, depth, lang, injections, tree)
+---@param tree vim.treesitter.dev.Node[] Output table containing a list of tables each representing a node in the tree
+local function traverse(node, depth, field, lang, injections, tree)
+ table.insert(tree, {
+ node = node,
+ depth = depth,
+ lang = lang,
+ field = field,
+ })
+
local injection = injections[node:id()]
if injection then
- traverse(injection.root, depth, injection.lang, injections, tree)
+ traverse(injection.root, depth + 1, nil, injection.lang, injections, tree)
end
- for child, field in node:iter_children() do
- local type = child:type()
- local lnum, col, end_lnum, end_col = child:range()
- local named = child:named()
- local text ---@type string
- if named then
- if field then
- text = string.format('%s: (%s', field, type)
- else
- text = string.format('(%s', type)
- end
- else
- text = string.format('"%s"', type:gsub('\n', '\\n'):gsub('"', '\\"'))
- end
-
- table.insert(tree, {
- id = child:id(),
- text = text,
- named = named,
- depth = depth,
- lnum = lnum,
- col = col,
- end_lnum = end_lnum,
- end_col = end_col,
- lang = lang,
- })
-
- traverse(child, depth + 1, lang, injections, tree)
-
- if named then
- tree[#tree].text = string.format('%s)', tree[#tree].text)
- end
+ for child, child_field in node:iter_children() do
+ traverse(child, depth + 1, child_field, lang, injections, tree)
end
return tree
@@ -95,44 +71,45 @@ end
---@param bufnr integer Source buffer number
---@param lang string|nil Language of source buffer
---
----@return TSTreeView|nil
+---@return vim.treesitter.dev.TSTreeView|nil
---@return string|nil Error message, if any
---
---@package
function TSTreeView:new(bufnr, lang)
local ok, parser = pcall(vim.treesitter.get_parser, bufnr or 0, lang)
if not ok then
- return nil, 'No parser available for the given buffer'
+ local err = parser --[[ @as string ]]
+ return nil, 'No parser available for the given buffer:\n' .. err
end
-- For each child tree (injected language), find the root of the tree and locate the node within
-- the primary tree that contains that root. Add a mapping from the node in the primary tree to
-- the root in the child tree to the {injections} table.
local root = parser:parse(true)[1]:root()
- local injections = {} ---@type table<string, TSP.Injection>
+ local injections = {} ---@type table<string, vim.treesitter.dev.Injection>
parser:for_each_tree(function(parent_tree, parent_ltree)
local parent = parent_tree:root()
for _, child in pairs(parent_ltree:children()) do
- child:for_each_tree(function(tree, ltree)
+ for _, tree in pairs(child:trees()) do
local r = tree:root()
local node = assert(parent:named_descendant_for_range(r:range()))
local id = node:id()
if not injections[id] or r:byte_length() > injections[id].root:byte_length() then
injections[id] = {
- lang = ltree:lang(),
+ lang = child:lang(),
root = r,
}
end
- end)
+ end
end
end)
- local nodes = traverse(root, 0, parser:lang(), injections, {})
+ local nodes = traverse(root, 0, nil, parser:lang(), injections, {})
- local named = {} ---@type TSP.Node[]
+ local named = {} ---@type vim.treesitter.dev.Node[]
for _, v in ipairs(nodes) do
- if v.named then
+ if v.node:named() then
named[#named + 1] = v
end
end
@@ -141,6 +118,7 @@ function TSTreeView:new(bufnr, lang)
ns = api.nvim_create_namespace('treesitter/dev-inspect'),
nodes = nodes,
named = named,
+ ---@type vim.treesitter.dev.TSTreeViewOpts
opts = {
anon = false,
lang = false,
@@ -155,16 +133,12 @@ end
local decor_ns = api.nvim_create_namespace('ts.dev')
----@param lnum integer
----@param col integer
----@param end_lnum integer
----@param end_col integer
+---@param range Range4
---@return string
-local function get_range_str(lnum, col, end_lnum, end_col)
- if lnum == end_lnum then
- return string.format('[%d:%d - %d]', lnum + 1, col + 1, end_col)
- end
- return string.format('[%d:%d - %d:%d]', lnum + 1, col + 1, end_lnum + 1, end_col)
+local function range_to_string(range)
+ ---@type integer, integer, integer, integer
+ local row, col, end_row, end_col = unpack(range)
+ return string.format('[%d, %d] - [%d, %d]', row, col, end_row, end_col)
end
---@param w integer
@@ -183,7 +157,10 @@ end
local function set_dev_properties(w, b)
vim.wo[w].scrolloff = 5
vim.wo[w].wrap = false
- vim.wo[w].foldmethod = 'manual' -- disable folding
+ vim.wo[w].foldmethod = 'expr'
+ vim.wo[w].foldexpr = 'v:lua.vim.treesitter.foldexpr()' -- explicitly set foldexpr
+ vim.wo[w].foldenable = false -- Don't fold on first open InspectTree
+ vim.wo[w].foldlevel = 99
vim.bo[b].buflisted = false
vim.bo[b].buftype = 'nofile'
vim.bo[b].bufhidden = 'wipe'
@@ -192,7 +169,7 @@ end
--- Updates the cursor position in the inspector to match the node under the cursor.
---
---- @param treeview TSTreeView
+--- @param treeview vim.treesitter.dev.TSTreeView
--- @param lang string
--- @param source_buf integer
--- @param inspect_buf integer
@@ -213,7 +190,7 @@ local function set_inspector_cursor(treeview, lang, source_buf, inspect_buf, ins
local cursor_node_id = cursor_node:id()
for i, v in treeview:iter() do
- if v.id == cursor_node_id then
+ if v.node:id() == cursor_node_id then
local start = v.depth * treeview.opts.indent ---@type integer
local end_col = start + #v.text
api.nvim_buf_set_extmark(inspect_buf, treeview.ns, i - 1, start, {
@@ -228,6 +205,8 @@ end
--- Write the contents of this View into {bufnr}.
---
+--- Calling this function computes the text that is displayed for each node.
+---
---@param bufnr integer Buffer number to write into.
---@package
function TSTreeView:draw(bufnr)
@@ -235,13 +214,35 @@ function TSTreeView:draw(bufnr)
local lines = {} ---@type string[]
local lang_hl_marks = {} ---@type table[]
- for _, item in self:iter() do
- local range_str = get_range_str(item.lnum, item.col, item.end_lnum, item.end_col)
+ for i, item in self:iter() do
+ local range_str = range_to_string({ item.node:range() })
local lang_str = self.opts.lang and string.format(' %s', item.lang) or ''
+
+ local text ---@type string
+ if item.node:named() then
+ if item.field then
+ text = string.format('%s: (%s', item.field, item.node:type())
+ else
+ text = string.format('(%s', item.node:type())
+ end
+ else
+ text = string.format('"%s"', item.node:type():gsub('\n', '\\n'):gsub('"', '\\"'))
+ end
+
+ local next = self:get(i + 1)
+ if not next or next.depth <= item.depth then
+ local parens = item.depth - (next and next.depth or 0) + (item.node:named() and 1 or 0)
+ if parens > 0 then
+ text = string.format('%s%s', text, string.rep(')', parens))
+ end
+ end
+
+ item.text = text
+
local line = string.format(
'%s%s ; %s%s',
string.rep(' ', item.depth * self.opts.indent),
- item.text,
+ text,
range_str,
lang_str
)
@@ -253,7 +254,7 @@ function TSTreeView:draw(bufnr)
}
end
- lines[#lines + 1] = line
+ lines[i] = line
end
api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
@@ -275,7 +276,7 @@ end
--- The node number is dependent on whether or not anonymous nodes are displayed.
---
---@param i integer Node number to get
----@return TSP.Node
+---@return vim.treesitter.dev.Node
---@package
function TSTreeView:get(i)
local t = self.opts.anon and self.nodes or self.named
@@ -284,7 +285,7 @@ end
--- Iterate over all of the nodes in this View.
---
----@return (fun(): integer, TSP.Node) Iterator over all nodes in this View
+---@return (fun(): integer, vim.treesitter.dev.Node) Iterator over all nodes in this View
---@return table
---@return integer
---@package
@@ -292,22 +293,31 @@ function TSTreeView:iter()
return ipairs(self.opts.anon and self.nodes or self.named)
end
---- @class InspectTreeOpts
---- @field lang string? The language of the source buffer. If omitted, the
---- filetype of the source buffer is used.
---- @field bufnr integer? Buffer to draw the tree into. If omitted, a new
---- buffer is created.
---- @field winid integer? Window id to display the tree buffer in. If omitted,
---- a new window is created with {command}.
---- @field command string? Vimscript command to create the window. Default
---- value is "60vnew". Only used when {winid} is nil.
---- @field title (string|fun(bufnr:integer):string|nil) Title of the window. If a
---- function, it accepts the buffer number of the source
---- buffer as its only argument and should return a string.
+--- @class vim.treesitter.dev.inspect_tree.Opts
+--- @inlinedoc
+---
+--- The language of the source buffer. If omitted, the filetype of the source
+--- buffer is used.
+--- @field lang string?
+---
+--- Buffer to draw the tree into. If omitted, a new buffer is created.
+--- @field bufnr integer?
+---
+--- Window id to display the tree buffer in. If omitted, a new window is
+--- created with {command}.
+--- @field winid integer?
+---
+--- Vimscript command to create the window. Default value is "60vnew".
+--- Only used when {winid} is nil.
+--- @field command string?
+---
+--- Title of the window. If a function, it accepts the buffer number of the
+--- source buffer as its only argument and should return a string.
+--- @field title (string|fun(bufnr:integer):string|nil)
--- @private
---
---- @param opts InspectTreeOpts?
+--- @param opts vim.treesitter.dev.inspect_tree.Opts?
function M.inspect_tree(opts)
vim.validate({
opts = { opts, 't', true },
@@ -364,9 +374,9 @@ function M.inspect_tree(opts)
desc = 'Jump to the node under the cursor in the source buffer',
callback = function()
local row = api.nvim_win_get_cursor(w)[1]
- local pos = treeview:get(row)
+ local lnum, col = treeview:get(row).node:start()
api.nvim_set_current_win(win)
- api.nvim_win_set_cursor(win, { pos.lnum + 1, pos.col })
+ api.nvim_win_set_cursor(win, { lnum + 1, col })
end,
})
api.nvim_buf_set_keymap(b, 'n', 'a', '', {
@@ -374,7 +384,7 @@ function M.inspect_tree(opts)
callback = function()
local row, col = unpack(api.nvim_win_get_cursor(w)) ---@type integer, integer
local curnode = treeview:get(row)
- while curnode and not curnode.named do
+ while curnode and not curnode.node:named() do
row = row - 1
curnode = treeview:get(row)
end
@@ -386,9 +396,9 @@ function M.inspect_tree(opts)
return
end
- local id = curnode.id
+ local id = curnode.node:id()
for i, node in treeview:iter() do
- if node.id == id then
+ if node.node:id() == id then
api.nvim_win_set_cursor(w, { i, col })
break
end
@@ -424,20 +434,20 @@ function M.inspect_tree(opts)
api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1)
local row = api.nvim_win_get_cursor(w)[1]
- local pos = treeview:get(row)
- api.nvim_buf_set_extmark(buf, treeview.ns, pos.lnum, pos.col, {
- end_row = pos.end_lnum,
- end_col = math.max(0, pos.end_col),
+ local lnum, col, end_lnum, end_col = treeview:get(row).node:range()
+ api.nvim_buf_set_extmark(buf, treeview.ns, lnum, col, {
+ end_row = end_lnum,
+ end_col = math.max(0, end_col),
hl_group = 'Visual',
})
local topline, botline = vim.fn.line('w0', win), vim.fn.line('w$', win)
-- Move the cursor if highlighted range is completely out of view
- if pos.lnum < topline and pos.end_lnum < topline then
- api.nvim_win_set_cursor(win, { pos.end_lnum + 1, 0 })
- elseif pos.lnum > botline and pos.end_lnum > botline then
- api.nvim_win_set_cursor(win, { pos.lnum + 1, 0 })
+ if lnum < topline and end_lnum < topline then
+ api.nvim_win_set_cursor(win, { end_lnum + 1, 0 })
+ elseif lnum > botline and end_lnum > botline then
+ api.nvim_win_set_cursor(win, { lnum + 1, 0 })
end
end,
})
@@ -462,7 +472,9 @@ function M.inspect_tree(opts)
return true
end
+ local treeview_opts = treeview.opts
treeview = assert(TSTreeView:new(buf, opts.lang))
+ treeview.opts = treeview_opts
treeview:draw(b)
end,
})
diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua
index ed1161e97f..a9b066d158 100644
--- a/runtime/lua/vim/treesitter/health.lua
+++ b/runtime/lua/vim/treesitter/health.lua
@@ -1,6 +1,6 @@
local M = {}
local ts = vim.treesitter
-local health = require('vim.health')
+local health = vim.health
--- Performs a healthcheck for treesitter integration
function M.check()
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua
index 496193c6ed..388680259a 100644
--- a/runtime/lua/vim/treesitter/highlighter.lua
+++ b/runtime/lua/vim/treesitter/highlighter.lua
@@ -2,50 +2,25 @@ local api = vim.api
local query = vim.treesitter.query
local Range = require('vim.treesitter._range')
----@alias TSHlIter fun(end_line: integer|nil): integer, TSNode, TSMetadata
-
----@class TSHighlightState
----@field next_row integer
----@field iter TSHlIter|nil
-
----@class TSHighlighter
----@field active table<integer,TSHighlighter>
----@field bufnr integer
----@field orig_spelloptions string
----@field _highlight_states table<TSTree,TSHighlightState>
----@field _queries table<string,TSHighlighterQuery>
----@field tree LanguageTree
----@field redraw_count integer
-local TSHighlighter = rawget(vim.treesitter, 'TSHighlighter') or {}
-TSHighlighter.__index = TSHighlighter
+local ns = api.nvim_create_namespace('treesitter/highlighter')
---- @nodoc
-TSHighlighter.active = TSHighlighter.active or {}
+---@alias vim.treesitter.highlighter.Iter fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata
----@class TSHighlighterQuery
----@field _query Query|nil
----@field hl_cache table<integer,integer>
+---@class (private) vim.treesitter.highlighter.Query
+---@field private _query vim.treesitter.Query?
+---@field private lang string
+---@field private hl_cache table<integer,integer>
local TSHighlighterQuery = {}
TSHighlighterQuery.__index = TSHighlighterQuery
-local ns = api.nvim_create_namespace('treesitter/highlighter')
-
---@private
+---@param lang string
+---@param query_string string?
+---@return vim.treesitter.highlighter.Query
function TSHighlighterQuery.new(lang, query_string)
- local self = setmetatable({}, { __index = TSHighlighterQuery })
-
- self.hl_cache = setmetatable({}, {
- __index = function(table, capture)
- local name = self._query.captures[capture]
- local id = 0
- if not vim.startswith(name, '_') then
- id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. lang)
- end
-
- rawset(table, capture, id)
- return id
- end,
- })
+ local self = setmetatable({}, TSHighlighterQuery)
+ self.lang = lang
+ self.hl_cache = {}
if query_string then
self._query = query.parse(lang, query_string)
@@ -57,18 +32,57 @@ function TSHighlighterQuery.new(lang, query_string)
end
---@package
+---@param capture integer
+---@return integer?
+function TSHighlighterQuery:get_hl_from_capture(capture)
+ if not self.hl_cache[capture] then
+ local name = self._query.captures[capture]
+ local id = 0
+ if not vim.startswith(name, '_') then
+ id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. self.lang)
+ end
+ self.hl_cache[capture] = id
+ end
+
+ return self.hl_cache[capture]
+end
+
+---@package
function TSHighlighterQuery:query()
return self._query
end
+---@class (private) vim.treesitter.highlighter.State
+---@field tstree TSTree
+---@field next_row integer
+---@field iter vim.treesitter.highlighter.Iter?
+---@field highlighter_query vim.treesitter.highlighter.Query
+
+---@nodoc
+---@class vim.treesitter.highlighter
+---@field active table<integer,vim.treesitter.highlighter>
+---@field bufnr integer
+---@field private orig_spelloptions string
+--- A map of highlight states.
+--- This state is kept during rendering across each line update.
+---@field private _highlight_states vim.treesitter.highlighter.State[]
+---@field private _queries table<string,vim.treesitter.highlighter.Query>
+---@field tree vim.treesitter.LanguageTree
+---@field private redraw_count integer
+local TSHighlighter = {
+ active = {},
+}
+
+TSHighlighter.__index = TSHighlighter
+
---@package
---
--- Creates a highlighter for `tree`.
---
----@param tree LanguageTree parser object to use for highlighting
+---@param tree vim.treesitter.LanguageTree parser object to use for highlighting
---@param opts (table|nil) Configuration of the highlighter:
--- - queries table overwrite queries used by the highlighter
----@return TSHighlighter Created highlighter object
+---@return vim.treesitter.highlighter Created highlighter object
function TSHighlighter.new(tree, opts)
local self = setmetatable({}, TSHighlighter)
@@ -98,15 +112,12 @@ function TSHighlighter.new(tree, opts)
end,
}, true)
- self.bufnr = tree:source() --[[@as integer]]
- self.edit_count = 0
+ local source = tree:source()
+ assert(type(source) == 'number')
+
+ self.bufnr = source
self.redraw_count = 0
- self.line_count = {}
- -- A map of highlight states.
- -- This state is kept during rendering across each line update.
self._highlight_states = {}
-
- ---@type table<string,TSHighlighterQuery>
self._queries = {}
-- Queries for a specific language can be overridden by a custom
@@ -144,11 +155,9 @@ end
--- @nodoc
--- Removes all internal references to the highlighter
function TSHighlighter:destroy()
- if TSHighlighter.active[self.bufnr] then
- TSHighlighter.active[self.bufnr] = nil
- end
+ TSHighlighter.active[self.bufnr] = nil
- if vim.api.nvim_buf_is_loaded(self.bufnr) then
+ if api.nvim_buf_is_loaded(self.bufnr) then
vim.bo[self.bufnr].spelloptions = self.orig_spelloptions
vim.b[self.bufnr].ts_highlight = nil
if vim.g.syntax_on == 1 then
@@ -157,23 +166,49 @@ function TSHighlighter:destroy()
end
end
----@package
----@param tstree TSTree
----@return TSHighlightState
-function TSHighlighter:get_highlight_state(tstree)
- if not self._highlight_states[tstree] then
- self._highlight_states[tstree] = {
+---@param srow integer
+---@param erow integer exclusive
+---@private
+function TSHighlighter:prepare_highlight_states(srow, erow)
+ self._highlight_states = {}
+
+ self.tree:for_each_tree(function(tstree, tree)
+ if not tstree then
+ return
+ end
+
+ local root_node = tstree:root()
+ local root_start_row, _, root_end_row, _ = root_node:range()
+
+ -- Only consider trees within the visible range
+ if root_start_row > erow or root_end_row < srow then
+ return
+ end
+
+ local highlighter_query = self:get_query(tree:lang())
+
+ -- Some injected languages may not have highlight queries.
+ if not highlighter_query:query() then
+ return
+ end
+
+ -- _highlight_states should be a list so that the highlights are added in the same order as
+ -- for_each_tree traversal. This ensures that parents' highlight don't override children's.
+ table.insert(self._highlight_states, {
+ tstree = tstree,
next_row = 0,
iter = nil,
- }
- end
-
- return self._highlight_states[tstree]
+ highlighter_query = highlighter_query,
+ })
+ end)
end
----@private
-function TSHighlighter:reset_highlight_state()
- self._highlight_states = {}
+---@param fn fun(state: vim.treesitter.highlighter.State)
+---@package
+function TSHighlighter:for_each_highlight_state(fn)
+ for _, state in ipairs(self._highlight_states) do
+ fn(state)
+ end
end
---@package
@@ -197,10 +232,9 @@ function TSHighlighter:on_changedtree(changes)
end
--- Gets the query used for @param lang
---
---@package
---@param lang string Language used by the highlighter.
----@return TSHighlighterQuery
+---@return vim.treesitter.highlighter.Query
function TSHighlighter:get_query(lang)
if not self._queries[lang] then
self._queries[lang] = TSHighlighterQuery.new(lang)
@@ -209,35 +243,23 @@ function TSHighlighter:get_query(lang)
return self._queries[lang]
end
----@param self TSHighlighter
+---@param self vim.treesitter.highlighter
---@param buf integer
---@param line integer
---@param is_spell_nav boolean
local function on_line_impl(self, buf, line, is_spell_nav)
- self.tree:for_each_tree(function(tstree, tree)
- if not tstree then
- return
- end
-
- local root_node = tstree:root()
+ self:for_each_highlight_state(function(state)
+ local root_node = state.tstree:root()
local root_start_row, _, root_end_row, _ = root_node:range()
- -- Only worry about trees within the line range
+ -- Only consider trees that contain this line
if root_start_row > line or root_end_row < line then
return
end
- local state = self:get_highlight_state(tstree)
- local highlighter_query = self:get_query(tree:lang())
-
- -- Some injected languages may not have highlight queries.
- if not highlighter_query:query() then
- return
- end
-
if state.iter == nil or state.next_row < line then
state.iter =
- highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
+ state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
end
while line >= state.next_row do
@@ -250,9 +272,9 @@ local function on_line_impl(self, buf, line, is_spell_nav)
local start_row, start_col, end_row, end_col = Range.unpack4(range)
if capture then
- local hl = highlighter_query.hl_cache[capture]
+ local hl = state.highlighter_query:get_hl_from_capture(capture)
- local capture_name = highlighter_query:query().captures[capture]
+ local capture_name = state.highlighter_query:query().captures[capture]
local spell = nil ---@type boolean?
if capture_name == 'spell' then
spell = true
@@ -308,7 +330,7 @@ function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _)
return
end
- self:reset_highlight_state()
+ self:prepare_highlight_states(srow, erow)
for row = srow, erow do
on_line_impl(self, buf, row, true)
@@ -326,7 +348,7 @@ function TSHighlighter._on_win(_, _win, buf, topline, botline)
return false
end
self.tree:parse({ topline, botline + 1 })
- self:reset_highlight_state()
+ self:prepare_highlight_states(topline, botline + 1)
self.redraw_count = self.redraw_count + 1
return true
end
diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua
index 15bf666a1e..47abf65332 100644
--- a/runtime/lua/vim/treesitter/language.lua
+++ b/runtime/lua/vim/treesitter/language.lua
@@ -1,6 +1,5 @@
local api = vim.api
----@class TSLanguageModule
local M = {}
---@type table<string,string>
@@ -37,6 +36,11 @@ end
---@deprecated
function M.require_language(lang, path, silent, symbol_name)
+ vim.deprecate(
+ 'vim.treesitter.language.require_language()',
+ 'vim.treesitter.language.add()',
+ '0.12'
+ )
local opts = {
silent = silent,
path = path,
@@ -52,10 +56,17 @@ function M.require_language(lang, path, silent, symbol_name)
return true
end
----@class treesitter.RequireLangOpts
----@field path? string
----@field silent? boolean
+---@class vim.treesitter.language.add.Opts
+---@inlinedoc
+---
+---Default filetype the parser should be associated with.
+---(Default: {lang})
---@field filetype? string|string[]
+---
+---Optional path the parser is located at
+---@field path? string
+---
+---Internal symbol name for the language to load
---@field symbol_name? string
--- Load parser with name {lang}
@@ -63,13 +74,8 @@ end
--- Parsers are searched in the `parser` runtime directory, or the provided {path}
---
---@param lang string Name of the parser (alphanumerical and `_` only)
----@param opts (table|nil) Options:
---- - filetype (string|string[]) Default filetype the parser should be associated with.
---- Defaults to {lang}.
---- - path (string|nil) Optional path the parser is located at
---- - symbol_name (string|nil) Internal symbol name for the language to load
+---@param opts? vim.treesitter.language.add.Opts Options:
function M.add(lang, opts)
- ---@cast opts treesitter.RequireLangOpts
opts = opts or {}
local path = opts.path
local filetype = opts.filetype or lang
@@ -114,6 +120,10 @@ local function ensure_list(x)
end
--- Register a parser named {lang} to be used for {filetype}(s).
+---
+--- Note: this adds or overrides the mapping for {filetype}, any existing mappings from other
+--- filetypes to {lang} will be preserved.
+---
--- @param lang string Name of parser
--- @param filetype string|string[] Filetype(s) to associate with lang
function M.register(lang, filetype)
@@ -140,14 +150,4 @@ function M.inspect(lang)
return vim._ts_inspect_language(lang)
end
----@deprecated
-function M.inspect_language(...)
- vim.deprecate(
- 'vim.treesitter.language.inspect_language()',
- 'vim.treesitter.language.inspect()',
- '0.10'
- )
- return M.inspect(...)
-end
-
return M
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index 0171b416cd..62714d3f1b 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -1,6 +1,4 @@
---- @defgroup lua-treesitter-languagetree
----
---- @brief A \*LanguageTree\* contains a tree of parsers: the root treesitter parser for {lang} and
+--- @brief A [LanguageTree]() contains a tree of parsers: the root treesitter parser for {lang} and
--- any "injected" language parsers, which themselves may inject other languages, recursively.
--- For example a Lua buffer containing some Vimscript commands needs multiple parsers to fully
--- understand its contents.
@@ -69,11 +67,12 @@ local TSCallbackNames = {
on_child_removed = 'child_removed',
}
----@class LanguageTree
+---@nodoc
+---@class vim.treesitter.LanguageTree
---@field private _callbacks table<TSCallbackName,function[]> Callback handlers
---@field package _callbacks_rec table<TSCallbackName,function[]> Callback handlers (recursive)
----@field private _children table<string,LanguageTree> Injected languages
----@field private _injection_query Query Queries defining injected languages
+---@field private _children table<string,vim.treesitter.LanguageTree> Injected languages
+---@field private _injection_query vim.treesitter.Query Queries defining injected languages
---@field private _injections_processed boolean
---@field private _opts table Options
---@field private _parser TSParser Parser for language
@@ -91,9 +90,11 @@ local TSCallbackNames = {
---@field private _logfile? file*
local LanguageTree = {}
----@class LanguageTreeOpts
----@field queries table<string,string> -- Deprecated
----@field injections table<string,string>
+---Optional arguments:
+---@class vim.treesitter.LanguageTree.new.Opts
+---@inlinedoc
+---@field queries? table<string,string> -- Deprecated
+---@field injections? table<string,string>
LanguageTree.__index = LanguageTree
@@ -104,14 +105,11 @@ LanguageTree.__index = LanguageTree
---
---@param source (integer|string) Buffer or text string to parse
---@param lang string Root language of this tree
----@param opts (table|nil) Optional arguments:
---- - injections table Map of language to injection query strings. Overrides the
---- built-in runtime file searching for language injections.
+---@param opts vim.treesitter.LanguageTree.new.Opts?
---@param parent_lang? string Parent language name of this tree
----@return LanguageTree parser object
+---@return vim.treesitter.LanguageTree parser object
function LanguageTree.new(source, lang, opts, parent_lang)
language.add(lang)
- ---@type LanguageTreeOpts
opts = opts or {}
if source == 0 then
@@ -120,7 +118,7 @@ function LanguageTree.new(source, lang, opts, parent_lang)
local injections = opts.injections or {}
- --- @type LanguageTree
+ --- @type vim.treesitter.LanguageTree
local self = {
_source = source,
_lang = lang,
@@ -196,7 +194,7 @@ local function tcall(f, ...)
end
---@private
----@vararg any
+---@param ... any
function LanguageTree:_log(...)
if not self._logger then
return
@@ -348,7 +346,13 @@ function LanguageTree:_parse_regions(range)
-- If there are no ranges, set to an empty list
-- so the included ranges in the parser are cleared.
for i, ranges in pairs(self:included_regions()) do
- if not self._valid[i] and intercepts_region(ranges, range) then
+ if
+ not self._valid[i]
+ and (
+ intercepts_region(ranges, range)
+ or (self._trees[i] and intercepts_region(self._trees[i]:included_ranges(false), range))
+ )
+ then
self._parser:set_included_ranges(ranges)
local parse_time, tree, tree_changes =
tcall(self._parser.parse, self._parser, self._trees[i], self._source, true)
@@ -427,7 +431,7 @@ function LanguageTree:parse(range)
local query_time = 0
local total_parse_time = 0
- --- At least 1 region is invalid
+ -- At least 1 region is invalid
if not self:is_valid(true) then
changes, no_regions_parsed, total_parse_time = self:_parse_regions(range)
-- Need to run injections when we parsed something
@@ -460,7 +464,7 @@ end
--- add recursion yourself if needed.
--- Invokes the callback for each |LanguageTree| and its children recursively
---
----@param fn fun(tree: LanguageTree, lang: string)
+---@param fn fun(tree: vim.treesitter.LanguageTree, lang: string)
---@param include_self boolean|nil Whether to include the invoking tree in the results
function LanguageTree:for_each_child(fn, include_self)
vim.deprecate('LanguageTree:for_each_child()', 'LanguageTree:children()', '0.11')
@@ -469,6 +473,7 @@ function LanguageTree:for_each_child(fn, include_self)
end
for _, child in pairs(self._children) do
+ --- @diagnostic disable-next-line:deprecated
child:for_each_child(fn, true)
end
end
@@ -477,7 +482,7 @@ end
---
--- Note: This includes the invoking tree's child trees as well.
---
----@param fn fun(tree: TSTree, ltree: LanguageTree)
+---@param fn fun(tree: TSTree, ltree: vim.treesitter.LanguageTree)
function LanguageTree:for_each_tree(fn)
for _, tree in pairs(self._trees) do
fn(tree, self)
@@ -494,7 +499,7 @@ end
---
---@private
---@param lang string Language to add.
----@return LanguageTree injected
+---@return vim.treesitter.LanguageTree injected
function LanguageTree:add_child(lang)
if self._children[lang] then
self:remove_child(lang)
@@ -664,7 +669,7 @@ end
---@param node TSNode
---@param source string|integer
----@param metadata TSMetadata
+---@param metadata vim.treesitter.query.TSMetadata
---@param include_children boolean
---@return Range6[]
local function get_node_ranges(node, source, metadata, include_children)
@@ -698,13 +703,14 @@ local function get_node_ranges(node, source, metadata, include_children)
return ranges
end
----@class TSInjectionElem
+---@nodoc
+---@class vim.treesitter.languagetree.InjectionElem
---@field combined boolean
---@field regions Range6[][]
----@alias TSInjection table<string,table<integer,TSInjectionElem>>
+---@alias vim.treesitter.languagetree.Injection table<string,table<integer,vim.treesitter.languagetree.InjectionElem>>
----@param t table<integer,TSInjection>
+---@param t table<integer,vim.treesitter.languagetree.Injection>
---@param tree_index integer
---@param pattern integer
---@param lang string
@@ -751,6 +757,11 @@ end)
---@param alias string language or filetype name
---@return string? # resolved parser name
local function resolve_lang(alias)
+ -- validate that `alias` is a legal language
+ if not (alias and alias:match('[%w_]+') == alias) then
+ return
+ end
+
if has_parser(alias) then
return alias
end
@@ -773,8 +784,8 @@ end
---@private
--- Extract injections according to:
--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
----@param match table<integer,TSNode>
----@param metadata TSMetadata
+---@param match table<integer,TSNode[]>
+---@param metadata vim.treesitter.query.TSMetadata
---@return string?, boolean, Range6[]
function LanguageTree:_get_injection(match, metadata)
local ranges = {} ---@type Range6[]
@@ -785,14 +796,16 @@ function LanguageTree:_get_injection(match, metadata)
or (injection_lang and resolve_lang(injection_lang))
local include_children = metadata['injection.include-children'] ~= nil
- for id, node in pairs(match) do
- local name = self._injection_query.captures[id]
- -- Lang should override any other language tag
- if name == 'injection.language' then
- local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] })
- lang = resolve_lang(text)
- elseif name == 'injection.content' then
- ranges = get_node_ranges(node, self._source, metadata[id], include_children)
+ for id, nodes in pairs(match) do
+ for _, node in ipairs(nodes) do
+ local name = self._injection_query.captures[id]
+ -- Lang should override any other language tag
+ if name == 'injection.language' then
+ local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] })
+ lang = resolve_lang(text)
+ elseif name == 'injection.content' then
+ ranges = get_node_ranges(node, self._source, metadata[id], include_children)
+ end
end
end
@@ -825,7 +838,7 @@ function LanguageTree:_get_injections()
return {}
end
- ---@type table<integer,TSInjection>
+ ---@type table<integer,vim.treesitter.languagetree.Injection>
local injections = {}
for index, tree in pairs(self._trees) do
@@ -833,7 +846,13 @@ function LanguageTree:_get_injections()
local start_line, _, end_line, _ = root_node:range()
for pattern, match, metadata in
- self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
+ self._injection_query:iter_matches(
+ root_node,
+ self._source,
+ start_line,
+ end_line + 1,
+ { all = true }
+ )
do
local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then
@@ -1133,7 +1152,7 @@ end
--- Gets the appropriate language that contains {range}.
---
---@param range Range4 `{ start_line, start_col, end_line, end_col }`
----@return LanguageTree Managing {range}
+---@return vim.treesitter.LanguageTree Managing {range}
function LanguageTree:language_for_range(range)
for _, child in pairs(self._children) do
if child:contains(range) then
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 8cbbffcd60..a086f5e876 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,19 +1,50 @@
local api = vim.api
local language = require('vim.treesitter.language')
----@class Query
----@field captures string[] List of captures used in query
----@field info TSQueryInfo Contains used queries, predicates, directives
----@field query userdata Parsed query
+local M = {}
+
+---@nodoc
+---Parsed query, see |vim.treesitter.query.parse()|
+---
+---@class vim.treesitter.Query
+---@field lang string name of the language for this parser
+---@field captures string[] list of (unique) capture names defined in query
+---@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives)
+---@field query TSQuery userdata query object
local Query = {}
Query.__index = Query
----@class TSQueryInfo
----@field captures table
----@field patterns table<string,any[][]>
+---@package
+---@see vim.treesitter.query.parse
+---@param lang string
+---@param ts_query TSQuery
+---@return vim.treesitter.Query
+function Query.new(lang, ts_query)
+ local self = setmetatable({}, Query)
+ local query_info = ts_query:inspect() ---@type TSQueryInfo
+ self.query = ts_query
+ self.lang = lang
+ self.info = {
+ captures = query_info.captures,
+ patterns = query_info.patterns,
+ }
+ self.captures = self.info.captures
+ return self
+end
----@class TSQueryModule
-local M = {}
+---@nodoc
+---Information for Query, see |vim.treesitter.query.parse()|
+---@class vim.treesitter.QueryInfo
+---
+---List of (unique) capture names defined in query.
+---@field captures string[]
+---
+---Contains information about predicates and directives.
+---Key is pattern id, and value is list of predicates or directives defined in the pattern.
+---A predicate or directive is a list of (integer|string); integer represents `capture_id`, and
+---string represents (literal) arguments to predicate/directive. See |treesitter-predicates|
+---and |treesitter-directives| for more details.
+---@field patterns table<integer, (integer|string)[][]>
---@param files string[]
---@return string[]
@@ -53,16 +84,6 @@ local function add_included_lang(base_langs, lang, ilang)
return false
end
----@deprecated
-function M.get_query_files(...)
- vim.deprecate(
- 'vim.treesitter.query.get_query_files()',
- 'vim.treesitter.query.get_files()',
- '0.10'
- )
- return M.get_files(...)
-end
-
--- Gets the list of files used to make up a query
---
---@param lang string Language to get query for
@@ -163,7 +184,7 @@ local function read_query_files(filenames)
end
-- The explicitly set queries from |vim.treesitter.query.set()|
----@type table<string,table<string,Query>>
+---@type table<string,table<string,vim.treesitter.Query>>
local explicit_queries = setmetatable({}, {
__index = function(t, k)
local lang_queries = {}
@@ -173,12 +194,6 @@ local explicit_queries = setmetatable({}, {
end,
})
----@deprecated
-function M.set_query(...)
- vim.deprecate('vim.treesitter.query.set_query()', 'vim.treesitter.query.set()', '0.10')
- M.set(...)
-end
-
--- Sets the runtime query named {query_name} for {lang}
---
--- This allows users to override any runtime files and/or configuration
@@ -191,18 +206,12 @@ function M.set(lang, query_name, text)
explicit_queries[lang][query_name] = M.parse(lang, text)
end
----@deprecated
-function M.get_query(...)
- vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10')
- return M.get(...)
-end
-
--- Returns the runtime query {query_name} for {lang}.
---
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g. "highlights")
---
----@return Query|nil Parsed query
+---@return vim.treesitter.Query|nil : Parsed query. `nil` if no query files are found.
M.get = vim.func._memoize('concat-2', function(lang, query_name)
if explicit_queries[lang][query_name] then
return explicit_queries[lang][query_name]
@@ -218,92 +227,96 @@ M.get = vim.func._memoize('concat-2', function(lang, query_name)
return M.parse(lang, query_string)
end)
----@deprecated
-function M.parse_query(...)
- vim.deprecate('vim.treesitter.query.parse_query()', 'vim.treesitter.query.parse()', '0.10')
- return M.parse(...)
-end
-
--- Parse {query} as a string. (If the query is in a file, the caller
--- should read the contents into a string before calling).
---
--- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
--- search nodes in the syntax tree for the patterns defined in {query}
---- using `iter_*` methods below.
+--- using the `iter_captures` and `iter_matches` methods.
---
--- Exposes `info` and `captures` with additional context about {query}.
---- - `captures` contains the list of unique capture names defined in
---- {query}.
---- -` info.captures` also points to `captures`.
+--- - `captures` contains the list of unique capture names defined in {query}.
+--- - `info.captures` also points to `captures`.
--- - `info.patterns` contains information about predicates.
---
---@param lang string Language to use for the query
---@param query string Query in s-expr syntax
---
----@return Query Parsed query
+---@return vim.treesitter.Query Parsed query
+---
+---@see |vim.treesitter.query.get()|
M.parse = vim.func._memoize('concat-2', function(lang, query)
language.add(lang)
- local self = setmetatable({}, Query)
- self.query = vim._ts_parse_query(lang, query)
- self.info = self.query:inspect()
- self.captures = self.info.captures
- return self
+ local ts_query = vim._ts_parse_query(lang, query)
+ return Query.new(lang, ts_query)
end)
----@deprecated
-function M.get_range(...)
- vim.deprecate('vim.treesitter.query.get_range()', 'vim.treesitter.get_range()', '0.10')
- return vim.treesitter.get_range(...)
-end
-
----@deprecated
-function M.get_node_text(...)
- vim.deprecate('vim.treesitter.query.get_node_text()', 'vim.treesitter.get_node_text()', '0.10')
- return vim.treesitter.get_node_text(...)
-end
-
----@alias TSMatch table<integer,TSNode>
-
----@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean
-
--- Predicate handler receive the following arguments
--- (match, pattern, bufnr, predicate)
----@type table<string,TSPredicate>
-local predicate_handlers = {
- ['eq?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+--- Implementations of predicates that can optionally be prefixed with "any-".
+---
+--- These functions contain the implementations for each predicate, correctly
+--- handling the "any" vs "all" semantics. They are called from the
+--- predicate_handlers table with the appropriate arguments for each predicate.
+local impl = {
+ --- @param match vim.treesitter.query.TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ ['eq'] = function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local node_text = vim.treesitter.get_node_text(node, source)
- local str ---@type string
- if type(predicate[3]) == 'string' then
- -- (#eq? @aa "foo")
- str = predicate[3]
- else
- -- (#eq? @aa @bb)
- str = vim.treesitter.get_node_text(match[predicate[3]], source)
- end
+ for _, node in ipairs(nodes) do
+ local node_text = vim.treesitter.get_node_text(node, source)
+
+ local str ---@type string
+ if type(predicate[3]) == 'string' then
+ -- (#eq? @aa "foo")
+ str = predicate[3]
+ else
+ -- (#eq? @aa @bb)
+ local other = assert(match[predicate[3]])
+ assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
+ str = vim.treesitter.get_node_text(other[1], source)
+ end
- if node_text ~= str or str == nil then
- return false
+ local res = str ~= nil and node_text == str
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
end
- return true
+ return not any
end,
- ['lua-match?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+ --- @param match vim.treesitter.query.TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ ['lua-match'] = function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local regex = predicate[3]
- return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
+
+ for _, node in ipairs(nodes) do
+ local regex = predicate[3]
+ local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
+ end
+
+ return not any
end,
- ['match?'] = (function()
+ ['match'] = (function()
local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
local function check_magic(str)
if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
@@ -320,85 +333,161 @@ local predicate_handlers = {
end,
})
- return function(match, _, source, pred)
- ---@cast match TSMatch
- local node = match[pred[2]]
- if not node then
+ --- @param match vim.treesitter.query.TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ return function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- ---@diagnostic disable-next-line no-unknown
- local regex = compiled_vim_regexes[pred[3]]
- return regex:match_str(vim.treesitter.get_node_text(node, source))
+
+ for _, node in ipairs(nodes) do
+ local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
+ local res = regex:match_str(vim.treesitter.get_node_text(node, source))
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
+ end
+ return not any
end
end)(),
- ['contains?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+ --- @param match vim.treesitter.query.TSMatch
+ --- @param source integer|string
+ --- @param predicate any[]
+ --- @param any boolean
+ ['contains'] = function(match, source, predicate, any)
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local node_text = vim.treesitter.get_node_text(node, source)
- for i = 3, #predicate do
- if string.find(node_text, predicate[i], 1, true) then
- return true
+ for _, node in ipairs(nodes) do
+ local node_text = vim.treesitter.get_node_text(node, source)
+
+ for i = 3, #predicate do
+ local res = string.find(node_text, predicate[i], 1, true)
+ if any and res then
+ return true
+ elseif not any and not res then
+ return false
+ end
end
end
- return false
+ return not any
+ end,
+}
+
+---@nodoc
+---@class vim.treesitter.query.TSMatch
+---@field pattern? integer
+---@field active? boolean
+---@field [integer] TSNode[]
+
+---@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean
+
+-- Predicate handler receive the following arguments
+-- (match, pattern, bufnr, predicate)
+---@type table<string,TSPredicate>
+local predicate_handlers = {
+ ['eq?'] = function(match, _, source, predicate)
+ return impl['eq'](match, source, predicate, false)
+ end,
+
+ ['any-eq?'] = function(match, _, source, predicate)
+ return impl['eq'](match, source, predicate, true)
+ end,
+
+ ['lua-match?'] = function(match, _, source, predicate)
+ return impl['lua-match'](match, source, predicate, false)
+ end,
+
+ ['any-lua-match?'] = function(match, _, source, predicate)
+ return impl['lua-match'](match, source, predicate, true)
+ end,
+
+ ['match?'] = function(match, _, source, predicate)
+ return impl['match'](match, source, predicate, false)
+ end,
+
+ ['any-match?'] = function(match, _, source, predicate)
+ return impl['match'](match, source, predicate, true)
+ end,
+
+ ['contains?'] = function(match, _, source, predicate)
+ return impl['contains'](match, source, predicate, false)
+ end,
+
+ ['any-contains?'] = function(match, _, source, predicate)
+ return impl['contains'](match, source, predicate, true)
end,
['any-of?'] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- if not node then
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local node_text = vim.treesitter.get_node_text(node, source)
- -- Since 'predicate' will not be used by callers of this function, use it
- -- to store a string set built from the list of words to check against.
- local string_set = predicate['string_set']
- if not string_set then
- string_set = {}
- for i = 3, #predicate do
- ---@diagnostic disable-next-line:no-unknown
- string_set[predicate[i]] = true
+ for _, node in ipairs(nodes) do
+ local node_text = vim.treesitter.get_node_text(node, source)
+
+ -- Since 'predicate' will not be used by callers of this function, use it
+ -- to store a string set built from the list of words to check against.
+ local string_set = predicate['string_set'] --- @type table<string, boolean>
+ if not string_set then
+ string_set = {}
+ for i = 3, #predicate do
+ string_set[predicate[i]] = true
+ end
+ predicate['string_set'] = string_set
+ end
+
+ if string_set[node_text] then
+ return true
end
- predicate['string_set'] = string_set
end
- return string_set[node_text]
+ return false
end,
['has-ancestor?'] = function(match, _, _, predicate)
- local node = match[predicate[2]]
- if not node then
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- local ancestor_types = {}
- for _, type in ipairs({ unpack(predicate, 3) }) do
- ancestor_types[type] = true
- end
+ for _, node in ipairs(nodes) do
+ local ancestor_types = {} --- @type table<string, boolean>
+ for _, type in ipairs({ unpack(predicate, 3) }) do
+ ancestor_types[type] = true
+ end
- node = node:parent()
- while node do
- if ancestor_types[node:type()] then
- return true
+ local cur = node:parent()
+ while cur do
+ if ancestor_types[cur:type()] then
+ return true
+ end
+ cur = cur:parent()
end
- node = node:parent()
end
return false
end,
['has-parent?'] = function(match, _, _, predicate)
- local node = match[predicate[2]]
- if not node then
+ local nodes = match[predicate[2]]
+ if not nodes or #nodes == 0 then
return true
end
- if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then
- return true
+ for _, node in ipairs(nodes) do
+ if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then
+ return true
+ end
end
return false
end,
@@ -406,14 +495,16 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match?
predicate_handlers['vim-match?'] = predicate_handlers['match?']
+predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
----@class TSMetadata
+---@nodoc
+---@class vim.treesitter.query.TSMetadata
---@field range? Range
---@field conceal? string
----@field [integer] TSMetadata
+---@field [integer] vim.treesitter.query.TSMetadata
---@field [string] integer|string
----@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata)
+---@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
-- Predicate handler receive the following arguments
-- (match, pattern, bufnr, predicate)
@@ -441,13 +532,17 @@ local directive_handlers = {
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
['offset!'] = function(match, _, _, pred, metadata)
- ---@cast pred integer[]
- local capture_id = pred[2]
+ local capture_id = pred[2] --[[@as integer]]
+ local nodes = match[capture_id]
+ assert(#nodes == 1, '#offset! does not support captures on multiple nodes')
+
+ local node = nodes[1]
+
if not metadata[capture_id] then
metadata[capture_id] = {}
end
- local range = metadata[capture_id].range or { match[capture_id]:range() }
+ local range = metadata[capture_id].range or { node:range() }
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
@@ -471,7 +566,9 @@ local directive_handlers = {
local id = pred[2]
assert(type(id) == 'number')
- local node = match[id]
+ local nodes = match[id]
+ assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
+ local node = nodes[1]
local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then
@@ -491,10 +588,9 @@ local directive_handlers = {
local capture_id = pred[2]
assert(type(capture_id) == 'number')
- local node = match[capture_id]
- if not node then
- return
- end
+ local nodes = match[capture_id]
+ assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
+ local node = nodes[1]
local start_row, start_col, end_row, end_col = node:range()
@@ -525,38 +621,93 @@ local directive_handlers = {
--- Adds a new predicate to be used in queries
---
---@param name string Name of the predicate, without leading #
----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[])
+---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table)
--- - see |vim.treesitter.query.add_directive()| for argument meanings
----@param force boolean|nil
-function M.add_predicate(name, handler, force)
- if predicate_handlers[name] and not force then
- error(string.format('Overriding %s', name))
+---@param opts table<string, any> Optional options:
+--- - force (boolean): Override an existing
+--- predicate of the same name
+--- - all (boolean): Use the correct
+--- implementation of the match table where
+--- capture IDs map to a list of nodes instead
+--- of a single node. Defaults to false (for
+--- backward compatibility). This option will
+--- eventually become the default and removed.
+function M.add_predicate(name, handler, opts)
+ -- Backward compatibility: old signature had "force" as boolean argument
+ if type(opts) == 'boolean' then
+ opts = { force = opts }
end
- predicate_handlers[name] = handler
+ opts = opts or {}
+
+ if predicate_handlers[name] and not opts.force then
+ error(string.format('Overriding existing predicate %s', name))
+ end
+
+ if opts.all then
+ predicate_handlers[name] = handler
+ else
+ --- @param match table<integer, TSNode[]>
+ local function wrapper(match, ...)
+ local m = {} ---@type table<integer, TSNode>
+ for k, v in pairs(match) do
+ if type(k) == 'number' then
+ m[k] = v[#v]
+ end
+ end
+ return handler(m, ...)
+ end
+ predicate_handlers[name] = wrapper
+ end
end
--- Adds a new directive to be used in queries
---
--- Handlers can set match level data by setting directly on the
---- metadata object `metadata.key = value`, additionally, handlers
+--- metadata object `metadata.key = value`. Additionally, handlers
--- can set node level data by using the capture id on the
--- metadata table `metadata[capture_id].key = value`
---
---@param name string Name of the directive, without leading #
----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table)
---- - match: see |treesitter-query|
---- - node-level data are accessible via `match[capture_id]`
---- - pattern: see |treesitter-query|
+---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table)
+--- - match: A table mapping capture IDs to a list of captured nodes
+--- - pattern: the index of the matching pattern in the query file
--- - predicate: list of strings containing the full directive being called, e.g.
--- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }`
----@param force boolean|nil
-function M.add_directive(name, handler, force)
- if directive_handlers[name] and not force then
- error(string.format('Overriding %s', name))
+---@param opts table<string, any> Optional options:
+--- - force (boolean): Override an existing
+--- predicate of the same name
+--- - all (boolean): Use the correct
+--- implementation of the match table where
+--- capture IDs map to a list of nodes instead
+--- of a single node. Defaults to false (for
+--- backward compatibility). This option will
+--- eventually become the default and removed.
+function M.add_directive(name, handler, opts)
+ -- Backward compatibility: old signature had "force" as boolean argument
+ if type(opts) == 'boolean' then
+ opts = { force = opts }
end
- directive_handlers[name] = handler
+ opts = opts or {}
+
+ if directive_handlers[name] and not opts.force then
+ error(string.format('Overriding existing directive %s', name))
+ end
+
+ if opts.all then
+ directive_handlers[name] = handler
+ else
+ --- @param match table<integer, TSNode[]>
+ local function wrapper(match, ...)
+ local m = {} ---@type table<integer, TSNode>
+ for k, v in pairs(match) do
+ m[k] = v[#v]
+ end
+ handler(m, ...)
+ end
+ directive_handlers[name] = wrapper
+ end
end
--- Lists the currently available directives to use in queries.
@@ -580,8 +731,8 @@ local function is_directive(name)
end
---@private
----@param match TSMatch
----@param pattern string
+---@param match vim.treesitter.query.TSMatch
+---@param pattern integer
---@param source integer|string
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -591,18 +742,14 @@ function Query:match_preds(match, pattern, source)
-- continue on the other case. This way unknown predicates will not be considered,
-- which allows some testing and easier user extensibility (#12173).
-- Also, tree-sitter strips the leading # from predicates for us.
- local pred_name ---@type string
-
- local is_not ---@type boolean
+ local is_not = false
-- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
- if string.sub(pred[1], 1, 4) == 'not-' then
- pred_name = string.sub(pred[1], 5)
+ local pred_name = pred[1]
+ if pred_name:match('^not%-') then
+ pred_name = pred_name:sub(5)
is_not = true
- else
- pred_name = pred[1]
- is_not = false
end
local handler = predicate_handlers[pred_name]
@@ -623,8 +770,8 @@ function Query:match_preds(match, pattern, source)
end
---@private
----@param match TSMatch
----@param metadata TSMetadata
+---@param match vim.treesitter.query.TSMatch
+---@param metadata vim.treesitter.query.TSMetadata
function Query:apply_directives(match, pattern, source, metadata)
local preds = self.info.patterns[pattern]
@@ -645,14 +792,16 @@ end
--- Returns the start and stop value if set else the node's range.
-- When the node's range is used, the stop is incremented by 1
-- to make the search inclusive.
----@param start integer
----@param stop integer
+---@param start integer|nil
+---@param stop integer|nil
---@param node TSNode
---@return integer, integer
local function value_or_node_range(start, stop, node)
- if start == nil and stop == nil then
- local node_start, _, node_stop, _ = node:range()
- return node_start, node_stop + 1 -- Make stop inclusive
+ if start == nil then
+ start = node:start()
+ end
+ if stop == nil then
+ stop = node:end_() + 1 -- Make stop inclusive
end
return start, stop
@@ -683,10 +832,10 @@ end
---
---@param node TSNode under which the search will occur
---@param source (integer|string) Source buffer or string to extract text from
----@param start integer Starting line for the search
----@param stop integer Stopping line for the search (end-exclusive)
+---@param start? integer Starting line for the search. Defaults to `node:start()`.
+---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
---
----@return (fun(end_line: integer|nil): integer, TSNode, TSMetadata):
+---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata):
--- capture id, capture node, metadata
function Query:iter_captures(node, source, start, stop)
if type(source) == 'number' and source == 0 then
@@ -695,7 +844,7 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, true, start, stop)
+ local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch
local function iter(end_line)
local capture, captured_node, match = raw_iter()
local metadata = {}
@@ -719,46 +868,55 @@ end
--- Iterates the matches of self on a given range.
---
---- Iterate over all matches within a {node}. The arguments are the same as
---- for |Query:iter_captures()| but the iterated values are different:
---- an (1-based) index of the pattern in the query, a table mapping
---- capture indices to nodes, and metadata from any directives processing the match.
---- If the query has more than one pattern, the capture table might be sparse
---- and e.g. `pairs()` method should be used over `ipairs`.
---- Here is an example iterating over all captures in every match:
+--- Iterate over all matches within a {node}. The arguments are the same as for
+--- |Query:iter_captures()| but the iterated values are different: an (1-based)
+--- index of the pattern in the query, a table mapping capture indices to a list
+--- of nodes, and metadata from any directives processing the match.
+---
+--- WARNING: Set `all=true` to ensure all matching nodes in a match are
+--- returned, otherwise only the last node in a match is returned, breaking captures
+--- involving quantifiers such as `(comment)+ @comment`. The default option
+--- `all=false` is only provided for backward compatibility and will be removed
+--- after Nvim 0.10.
+---
+--- Example:
---
--- ```lua
---- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do
---- for id, node in pairs(match) do
+--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do
+--- for id, nodes in pairs(match) do
--- local name = query.captures[id]
---- -- `node` was captured by the `name` capture in the match
+--- for _, node in ipairs(nodes) do
+--- -- `node` was captured by the `name` capture in the match
---
---- local node_data = metadata[id] -- Node level metadata
----
---- -- ... use the info here ...
+--- local node_data = metadata[id] -- Node level metadata
+--- ... use the info here ...
+--- end
--- end
--- end
--- ```
---
+---
---@param node TSNode under which the search will occur
---@param source (integer|string) Source buffer or string to search
----@param start integer Starting line for the search
----@param stop integer Stopping line for the search (end-exclusive)
----@param opts table|nil Options:
+---@param start? integer Starting line for the search. Defaults to `node:start()`.
+---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
+---@param opts? table Optional keyword arguments:
--- - max_start_depth (integer) if non-zero, sets the maximum start depth
--- for each match. This is used to prevent traversing too deep into a tree.
---- Requires treesitter >= 0.20.9.
+--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
+--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is
+--- incorrect behavior. This option will eventually become the default and removed.
---
----@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata
+---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
function Query:iter_matches(node, source, start, stop, opts)
+ local all = opts and opts.all
if type(source) == 'number' and source == 0 then
source = api.nvim_get_current_buf()
end
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, false, start, stop, opts)
- ---@cast raw_iter fun(): string, any
+ local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch
local function iter()
local pattern, match = raw_iter()
local metadata = {}
@@ -771,14 +929,33 @@ function Query:iter_matches(node, source, start, stop, opts)
self:apply_directives(match, pattern, source, metadata)
end
+
+ if not all then
+ -- Convert the match table into the old buggy version for backward
+ -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
+ -- option!
+ local old_match = {} ---@type table<integer, TSNode>
+ for k, v in pairs(match or {}) do
+ old_match[k] = v[#v]
+ end
+ return pattern, old_match, metadata
+ end
+
return pattern, match, metadata
end
return iter
end
----@class QueryLinterOpts
----@field langs (string|string[]|nil)
----@field clear (boolean)
+--- Optional keyword arguments:
+--- @class vim.treesitter.query.lint.Opts
+--- @inlinedoc
+---
+--- Language(s) to use for checking the query.
+--- If multiple languages are specified, queries are validated for all of them
+--- @field langs? string|string[]
+---
+--- Just clear current lint errors
+--- @field clear boolean
--- Lint treesitter queries using installed parser, or clear lint errors.
---
@@ -793,15 +970,12 @@ end
--- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the
--- `lua` language will be used.
---@param buf (integer) Buffer handle
----@param opts (QueryLinterOpts|nil) Optional keyword arguments:
---- - langs (string|string[]|nil) Language(s) to use for checking the query.
---- If multiple languages are specified, queries are validated for all of them
---- - clear (boolean) if `true`, just clear current lint errors
+---@param opts? vim.treesitter.query.lint.Opts
function M.lint(buf, opts)
if opts and opts.clear then
- require('vim.treesitter._query_linter').clear(buf)
+ vim.treesitter._query_linter.clear(buf)
else
- require('vim.treesitter._query_linter').lint(buf, opts)
+ vim.treesitter._query_linter.lint(buf, opts)
end
end
@@ -813,13 +987,15 @@ end
--- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc'
--- ```
---
+--- @param findstart 0|1
+--- @param base string
function M.omnifunc(findstart, base)
- return require('vim.treesitter._query_linter').omnifunc(findstart, base)
+ return vim.treesitter._query_linter.omnifunc(findstart, base)
end
--- Opens a live editor to query the buffer you started from.
---
---- Can also be shown with *:EditQuery*.
+--- Can also be shown with [:EditQuery]().
---
--- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in
--- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find
@@ -827,7 +1003,7 @@ end
---
--- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype.
function M.edit(lang)
- require('vim.treesitter.dev').edit_query(lang)
+ vim.treesitter.dev.edit_query(lang)
end
return M