diff options
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 300 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta.lua | 28 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_query_linter.lua | 46 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/dev.lua | 228 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/health.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 194 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 40 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 95 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 602 |
9 files changed, 861 insertions, 674 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 5c1cc06908..d96cc966de 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -5,35 +5,20 @@ local Range = require('vim.treesitter._range') local api = vim.api ---@class TS.FoldInfo ----@field levels table<integer,string> ----@field levels0 table<integer,integer> ----@field private start_counts table<integer,integer> ----@field private stop_counts table<integer,integer> +---@field levels string[] the foldexpr result for each line +---@field levels0 integer[] the raw fold levels +---@field edits? {[1]: integer, [2]: integer} line range edited since the last invocation of the callback scheduled in on_bytes. 0-indexed, end-exclusive. local FoldInfo = {} FoldInfo.__index = FoldInfo ---@private function FoldInfo.new() return setmetatable({ - start_counts = {}, - stop_counts = {}, levels0 = {}, levels = {}, }, FoldInfo) end ----@package ----@param srow integer ----@param erow integer -function FoldInfo:invalidate_range(srow, erow) - for i = srow, erow do - self.start_counts[i + 1] = nil - self.stop_counts[i + 1] = nil - self.levels0[i + 1] = nil - self.levels[i + 1] = nil - end -end - --- Efficiently remove items from middle of a list a list. --- --- Calling table.remove() in a loop will re-index the tail of the table on @@ -55,12 +40,10 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:remove_range(srow, erow) list_remove(self.levels, srow + 1, erow) list_remove(self.levels0, srow + 1, erow) - list_remove(self.start_counts, srow + 1, erow) - list_remove(self.stop_counts, srow + 1, erow) end --- Efficiently insert items into the middle of a list. @@ -91,46 +74,37 @@ end ---@package ---@param srow integer ----@param erow integer +---@param erow integer 0-indexed, exclusive function FoldInfo:add_range(srow, erow) - list_insert(self.levels, srow + 1, erow, '-1') + list_insert(self.levels, srow + 1, erow, '=') list_insert(self.levels0, srow + 1, erow, -1) - list_insert(self.start_counts, srow + 1, erow, nil) - list_insert(self.stop_counts, srow + 1, erow, nil) end ---@package ----@param lnum integer -function FoldInfo:add_start(lnum) - self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 -end - ----@package ----@param lnum integer -function FoldInfo:add_stop(lnum) - self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 -end - ----@package ----@param lnum integer ----@return integer -function FoldInfo:get_start(lnum) - return self.start_counts[lnum] or 0 +---@param srow integer +---@param erow_old integer +---@param erow_new integer 0-indexed, exclusive +function FoldInfo:edit_range(srow, erow_old, erow_new) + if self.edits then + self.edits[1] = math.min(srow, self.edits[1]) + if erow_old <= self.edits[2] then + self.edits[2] = self.edits[2] + (erow_new - erow_old) + end + self.edits[2] = math.max(self.edits[2], erow_new) + else + self.edits = { srow, erow_new } + end end ---@package ----@param lnum integer ----@return integer -function FoldInfo:get_stop(lnum) - return self.stop_counts[lnum] or 0 -end - -local function trim_level(level) - local max_fold_level = vim.wo.foldnestmax - if level > max_fold_level then - return max_fold_level +---@return integer? srow +---@return integer? erow 0-indexed, exclusive +function FoldInfo:flush_edit() + if self.edits then + local srow, erow = self.edits[1], self.edits[2] + self.edits = nil + return srow, erow end - return level end --- If a parser doesn't have any ranges explicitly set, treesitter will @@ -140,10 +114,10 @@ end --- TODO(lewis6991): Handle this generally --- --- @param bufnr integer ---- @param erow integer? +--- @param erow integer? 0-indexed, exclusive --- @return integer local function normalise_erow(bufnr, erow) - local max_erow = api.nvim_buf_line_count(bufnr) - 1 + local max_erow = api.nvim_buf_line_count(bufnr) return math.min(erow or max_erow, max_erow) end @@ -152,31 +126,30 @@ end ---@param bufnr integer ---@param info TS.FoldInfo ---@param srow integer? ----@param erow integer? +---@param erow integer? 0-indexed, exclusive ---@param parse_injections? boolean local function get_folds_levels(bufnr, info, srow, erow, parse_injections) srow = srow or 0 erow = normalise_erow(bufnr, erow) - info:invalidate_range(srow, erow) - - local prev_start = -1 - local prev_stop = -1 - local parser = ts.get_parser(bufnr) parser:parse(parse_injections and { srow, erow } or nil) + local enter_counts = {} ---@type table<integer, integer> + local leave_counts = {} ---@type table<integer, integer> + local prev_start = -1 + local prev_stop = -1 + parser:for_each_tree(function(tree, ltree) local query = ts.query.get(ltree:lang(), 'folds') if not query then return end - -- erow in query is end-exclusive - local q_erow = erow and erow + 1 or -1 - - for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do + -- Collect folds starting from srow - 1, because we should first subtract the folds that end at + -- srow - 1 from the level of srow - 1 to get accurate level of srow. + for id, node, metadata in query:iter_captures(tree:root(), bufnr, math.max(srow - 1, 0), erow) do if query.captures[id] == 'fold' then local range = ts.get_range(node, bufnr, metadata[id]) local start, _, stop, stop_col = Range.unpack4(range) @@ -193,8 +166,8 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) if fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) then - info:add_start(start + 1) - info:add_stop(stop + 1) + enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 + leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 prev_start = start prev_stop = stop end @@ -202,16 +175,15 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) end end) - local current_level = info.levels0[srow] or 0 + local nestmax = vim.wo.foldnestmax + local level0_prev = info.levels0[srow] or 0 + local leave_prev = leave_counts[srow] or 0 -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = srow + 1, erow + 1 do - local last_trimmed_level = trim_level(current_level) - current_level = current_level + info:get_start(lnum) - info.levels0[lnum] = current_level - - local trimmed_level = trim_level(current_level) - current_level = current_level - info:get_stop(lnum) + for lnum = srow + 1, erow do + local enter_line = enter_counts[lnum] or 0 + local leave_line = leave_counts[lnum] or 0 + local level0 = level0_prev - leave_prev + enter_line -- Determine if it's the start/end of a fold -- NB: vim's fold-expr interface does not have a mechanism to indicate that @@ -219,14 +191,36 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) -- ( \n ( \n )) \n (( \n ) \n ) -- versus -- ( \n ( \n ) \n ( \n ) \n ) - -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) + -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and + -- vim interprets as the second case. + -- If it did have such a mechanism, (clamped - clamped_prev) -- would be the correct number of starts to pass on. + local adjusted = level0 ---@type integer local prefix = '' - if trimmed_level - last_trimmed_level > 0 then + if enter_line > 0 then prefix = '>' + if leave_line > 0 then + -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line + -- so that f2 gets the correct level on this line. This may reduce the size of f1 below + -- foldminlines, but we don't handle it for simplicity. + adjusted = level0 - leave_line + leave_line = 0 + end + end + + -- Clamp at foldnestmax. + local clamped = adjusted + if adjusted > nestmax then + prefix = '' + clamped = nestmax end - info.levels[lnum] = prefix .. tostring(trimmed_level) + -- Record the "real" level, so that it can be used as "base" of later get_folds_levels(). + info.levels0[lnum] = adjusted + info.levels[lnum] = prefix .. tostring(clamped) + + leave_prev = leave_line + level0_prev = adjusted end end @@ -296,8 +290,12 @@ end local function on_changedtree(bufnr, foldinfo, tree_changes) schedule_if_loaded(bufnr, function() for _, change in ipairs(tree_changes) do - local srow, _, erow = Range.unpack4(change) - get_folds_levels(bufnr, foldinfo, srow, erow) + local srow, _, erow, ecol = Range.unpack4(change) + if ecol > 0 then + erow = erow + 1 + end + -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. + get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) end if #tree_changes > 0 then foldupdate(bufnr) @@ -309,19 +307,46 @@ end ---@param foldinfo TS.FoldInfo ---@param start_row integer ---@param old_row integer +---@param old_col integer ---@param new_row integer -local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) - local end_row_old = start_row + old_row - local end_row_new = start_row + new_row +---@param new_col integer +local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col) + -- extend the end to fully include the range + local end_row_old = start_row + old_row + 1 + local end_row_new = start_row + new_row + 1 if new_row ~= old_row then + -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the + -- outdated levels, which may spuriously open the folds that didn't change. So we should shift + -- folds as accurately as possible. For this to be perfectly accurate, we should track the + -- actual TSNodes that account for each fold, and compare the node's range with the edited + -- range. But for simplicity, we just check whether the start row is completely removed (e.g., + -- `dd`) or shifted (e.g., `o`). if new_row < old_row then - foldinfo:remove_range(end_row_new, end_row_old) + if start_col == 0 and new_row == 0 and new_col == 0 then + foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new)) + else + foldinfo:remove_range(end_row_new, end_row_old) + end else - foldinfo:add_range(start_row, end_row_new) + if start_col == 0 and old_row == 0 and old_col == 0 then + foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old)) + else + foldinfo:add_range(end_row_old, end_row_new) + end end + foldinfo:edit_range(start_row, end_row_old, end_row_new) + + -- This callback must not use on_bytes arguments, because they can be outdated when the callback + -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing + -- the scheduled callback. So we should collect the edits. schedule_if_loaded(bufnr, function() - get_folds_levels(bufnr, foldinfo, start_row, end_row_new) + local srow, erow = foldinfo:flush_edit() + if not srow then + return + end + -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. + get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) foldupdate(bufnr) end) end @@ -348,8 +373,8 @@ function M.foldexpr(lnum) on_changedtree(bufnr, foldinfos[bufnr], tree_changes) end, - on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) - on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) + on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _) + on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col) end, on_detach = function() @@ -361,96 +386,15 @@ function M.foldexpr(lnum) return foldinfos[bufnr].levels[lnum] or '0' end ----@package ----@return { [1]: string, [2]: string[] }[]|string -function M.foldtext() - local foldstart = vim.v.foldstart - local bufnr = api.nvim_get_current_buf() - - ---@type boolean, LanguageTree - local ok, parser = pcall(ts.get_parser, bufnr) - if not ok then - return vim.fn.foldtext() - end - - local query = ts.query.get(parser:lang(), 'highlights') - if not query then - return vim.fn.foldtext() - end - - local tree = parser:parse({ foldstart - 1, foldstart })[1] - - local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1] - if not line then - return vim.fn.foldtext() - end - - ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[] - local result = {} - - local line_pos = 0 - - for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do - local name = query.captures[id] - local start_row, start_col, end_row, end_col = node:range() - - local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter) - - if start_row == foldstart - 1 and end_row == foldstart - 1 then - -- check for characters ignored by treesitter - if start_col > line_pos then - table.insert(result, { - line:sub(line_pos + 1, start_col), - {}, - range = { line_pos, start_col }, - }) - end - line_pos = end_col - - local text = line:sub(start_col + 1, end_col) - table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } }) - end - end - - local i = 1 - while i <= #result do - -- find first capture that is not in current range and apply highlights on the way - local j = i + 1 - while - j <= #result - and result[j].range[1] >= result[i].range[1] - and result[j].range[2] <= result[i].range[2] - do - for k, v in ipairs(result[i][2]) do - if not vim.tbl_contains(result[j][2], v) then - table.insert(result[j][2], k, v) - end - end - j = j + 1 - end - - -- remove the parent capture if it is split into children - if j > i + 1 then - table.remove(result, i) - else - -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier - -- in list) should be considered higher prio - if #result[i][2] > 1 then - table.sort(result[i][2], function(a, b) - return a[2] < b[2] - end) - end - - result[i][2] = vim.tbl_map(function(tbl) - return tbl[1] - end, result[i][2]) - result[i] = { result[i][1], result[i][2] } - - i = i + 1 +api.nvim_create_autocmd('OptionSet', { + pattern = { 'foldminlines', 'foldnestmax' }, + desc = 'Refresh treesitter folds', + callback = function() + for _, bufnr in ipairs(vim.tbl_keys(foldinfos)) do + foldinfos[bufnr] = FoldInfo.new() + get_folds_levels(bufnr, foldinfos[bufnr]) + foldupdate(bufnr) end - end - - return result -end - + end, +}) return M diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua index 80c998b555..19d97d2820 100644 --- a/runtime/lua/vim/treesitter/_meta.lua +++ b/runtime/lua/vim/treesitter/_meta.lua @@ -1,4 +1,5 @@ ---@meta +error('Cannot require a meta file') ---@class TSNode: userdata ---@field id fun(self: TSNode): string @@ -33,27 +34,26 @@ ---@field byte_length fun(self: TSNode): integer local TSNode = {} ----@param query userdata +---@param query TSQuery ---@param captures true ---@param start? integer ---@param end_? integer ---@param opts? table ----@return fun(): integer, TSNode, any +---@return fun(): integer, TSNode, vim.treesitter.query.TSMatch function TSNode:_rawquery(query, captures, start, end_, opts) end ----@param query userdata +---@param query TSQuery ---@param captures false ---@param start? integer ---@param end_? integer ---@param opts? table ----@return fun(): string, any +---@return fun(): integer, vim.treesitter.query.TSMatch function TSNode:_rawquery(query, captures, start, end_, opts) end ---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string) ----@class TSParser ----@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: true): TSTree, Range6[] ----@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: false|nil): TSTree, Range4[] +---@class TSParser: userdata +---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: boolean): TSTree, (Range4|Range6)[] ---@field reset fun(self: TSParser) ---@field included_ranges fun(self: TSParser, include_bytes: boolean?): integer[] ---@field set_included_ranges fun(self: TSParser, ranges: (Range6|TSNode)[]) @@ -62,19 +62,31 @@ function TSNode:_rawquery(query, captures, start, end_, opts) end ---@field _set_logger fun(self: TSParser, lex: boolean, parse: boolean, cb: TSLoggerCallback) ---@field _logger fun(self: TSParser): TSLoggerCallback ----@class TSTree +---@class TSTree: userdata ---@field root fun(self: TSTree): TSNode ---@field edit fun(self: TSTree, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _:integer) ---@field copy fun(self: TSTree): TSTree ---@field included_ranges fun(self: TSTree, include_bytes: true): Range6[] ---@field included_ranges fun(self: TSTree, include_bytes: false): Range4[] +---@class TSQuery: userdata +---@field inspect fun(self: TSQuery): TSQueryInfo + +---@class (exact) TSQueryInfo +---@field captures string[] +---@field patterns table<integer, (integer|string)[][]> + ---@return integer vim._ts_get_language_version = function() end ---@return integer vim._ts_get_minimum_language_version = function() end +---@param lang string Language to use for the query +---@param query string Query string in s-expr syntax +---@return TSQuery +vim._ts_parse_query = function(lang, query) end + ---@param lang string ---@return TSParser vim._create_ts_parser = function(lang) end diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua index 87d74789a3..6216d4e891 100644 --- a/runtime/lua/vim/treesitter/_query_linter.lua +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -17,7 +17,7 @@ local M = {} --- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs` --- Adds a diagnostic for node in the query buffer ---- @param diagnostics Diagnostic[] +--- @param diagnostics vim.Diagnostic[] --- @param range Range4 --- @param lint string --- @param lang string? @@ -45,7 +45,7 @@ local function guess_query_lang(buf) end --- @param buf integer ---- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil +--- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil --- @return QueryLinterNormalizedOpts local function normalize_opts(buf, opts) opts = opts or {} @@ -92,7 +92,7 @@ local function get_error_entry(err, node) end_col = end_col + #underlined elseif msg:match('^Invalid') then -- Use the length of the problematic type/capture/field - end_col = end_col + #msg:match('"([^"]+)"') + end_col = end_col + #(msg:match('"([^"]+)"') or '') end return { @@ -114,7 +114,7 @@ end --- @return vim.treesitter.ParseError? local parse = vim.func._memoize(hash_parse, function(node, buf, lang) local query_text = vim.treesitter.get_node_text(node, buf) - local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query + local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|vim.treesitter.Query if not ok and type(err) == 'string' then return get_error_entry(err, node) @@ -122,28 +122,30 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang) end) --- @param buf integer ---- @param match table<integer,TSNode> ---- @param query Query +--- @param match vim.treesitter.query.TSMatch +--- @param query vim.treesitter.Query --- @param lang_context QueryLinterLanguageContext ---- @param diagnostics Diagnostic[] +--- @param diagnostics vim.Diagnostic[] local function lint_match(buf, match, query, lang_context, diagnostics) local lang = lang_context.lang local parser_info = lang_context.parser_info - for id, node in pairs(match) do - local cap_id = query.captures[id] + for id, nodes in pairs(match) do + for _, node in ipairs(nodes) do + local cap_id = query.captures[id] - -- perform language-independent checks only for first lang - if lang_context.is_first_lang and cap_id == 'error' then - local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') - add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) - end + -- perform language-independent checks only for first lang + if lang_context.is_first_lang and cap_id == 'error' then + local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') + add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) + end - -- other checks rely on Neovim parser introspection - if lang and parser_info and cap_id == 'toplevel' then - local err = parse(node, buf, lang) - if err then - add_lint_for_node(diagnostics, err.range, err.msg, lang) + -- other checks rely on Neovim parser introspection + if lang and parser_info and cap_id == 'toplevel' then + local err = parse(node, buf, lang) + if err then + add_lint_for_node(diagnostics, err.range, err.msg, lang) + end end end end @@ -151,7 +153,7 @@ end --- @private --- @param buf integer Buffer to lint ---- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil Options for linting +--- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil Options for linting function M.lint(buf, opts) if buf == 0 then buf = api.nvim_get_current_buf() @@ -173,7 +175,7 @@ function M.lint(buf, opts) parser:parse() parser:for_each_tree(function(tree, ltree) if ltree:lang() == 'query' then - for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do + for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1, { all = true }) do local lang_context = { lang = lang, parser_info = parser_info, @@ -195,7 +197,7 @@ function M.clear(buf) end --- @private ---- @param findstart integer +--- @param findstart 0|1 --- @param base string function M.omnifunc(findstart, base) if findstart == 1 then diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua index 69ddc9b558..dc2a14d238 100644 --- a/runtime/lua/vim/treesitter/dev.lua +++ b/runtime/lua/vim/treesitter/dev.lua @@ -1,31 +1,29 @@ local api = vim.api ----@class TSDevModule local M = {} ----@class TSTreeView +---@class (private) vim.treesitter.dev.TSTreeView ---@field ns integer API namespace ----@field opts table Options table with the following keys: ---- - anon (boolean): If true, display anonymous nodes ---- - lang (boolean): If true, display the language alongside each node ---- - indent (number): Number of spaces to indent nested lines. Default is 2. ----@field nodes TSP.Node[] ----@field named TSP.Node[] +---@field opts vim.treesitter.dev.TSTreeViewOpts +---@field nodes vim.treesitter.dev.Node[] +---@field named vim.treesitter.dev.Node[] local TSTreeView = {} ----@class TSP.Node ----@field id integer Node id ----@field text string Node text ----@field named boolean True if this is a named (non-anonymous) node ----@field depth integer Depth of the node within the tree ----@field lnum integer Beginning line number of this node in the source buffer ----@field col integer Beginning column number of this node in the source buffer ----@field end_lnum integer Final line number of this node in the source buffer ----@field end_col integer Final column number of this node in the source buffer +---@private +---@class (private) vim.treesitter.dev.TSTreeViewOpts +---@field anon boolean If true, display anonymous nodes. +---@field lang boolean If true, display the language alongside each node. +---@field indent number Number of spaces to indent nested lines. + +---@class (private) vim.treesitter.dev.Node +---@field node TSNode Treesitter node +---@field field string? Node field +---@field depth integer Depth of this node in the tree +---@field text string? Text displayed in the inspector for this node. Not computed until the +--- inspector is drawn. ---@field lang string Source language of this node ----@field root TSNode ----@class TSP.Injection +---@class (private) vim.treesitter.dev.Injection ---@field lang string Source language of this injection ---@field root TSNode Root node of the injection @@ -43,48 +41,26 @@ local TSTreeView = {} --- ---@param node TSNode Starting node to begin traversal |tsnode| ---@param depth integer Current recursion depth +---@param field string|nil The field of the current node ---@param lang string Language of the tree currently being traversed ----@param injections table<string, TSP.Injection> Mapping of node ids to root nodes +---@param injections table<string, vim.treesitter.dev.Injection> Mapping of node ids to root nodes --- of injected language trees (see explanation above) ----@param tree TSP.Node[] Output table containing a list of tables each representing a node in the tree -local function traverse(node, depth, lang, injections, tree) +---@param tree vim.treesitter.dev.Node[] Output table containing a list of tables each representing a node in the tree +local function traverse(node, depth, field, lang, injections, tree) + table.insert(tree, { + node = node, + depth = depth, + lang = lang, + field = field, + }) + local injection = injections[node:id()] if injection then - traverse(injection.root, depth, injection.lang, injections, tree) + traverse(injection.root, depth + 1, nil, injection.lang, injections, tree) end - for child, field in node:iter_children() do - local type = child:type() - local lnum, col, end_lnum, end_col = child:range() - local named = child:named() - local text ---@type string - if named then - if field then - text = string.format('%s: (%s', field, type) - else - text = string.format('(%s', type) - end - else - text = string.format('"%s"', type:gsub('\n', '\\n'):gsub('"', '\\"')) - end - - table.insert(tree, { - id = child:id(), - text = text, - named = named, - depth = depth, - lnum = lnum, - col = col, - end_lnum = end_lnum, - end_col = end_col, - lang = lang, - }) - - traverse(child, depth + 1, lang, injections, tree) - - if named then - tree[#tree].text = string.format('%s)', tree[#tree].text) - end + for child, child_field in node:iter_children() do + traverse(child, depth + 1, child_field, lang, injections, tree) end return tree @@ -95,44 +71,45 @@ end ---@param bufnr integer Source buffer number ---@param lang string|nil Language of source buffer --- ----@return TSTreeView|nil +---@return vim.treesitter.dev.TSTreeView|nil ---@return string|nil Error message, if any --- ---@package function TSTreeView:new(bufnr, lang) local ok, parser = pcall(vim.treesitter.get_parser, bufnr or 0, lang) if not ok then - return nil, 'No parser available for the given buffer' + local err = parser --[[ @as string ]] + return nil, 'No parser available for the given buffer:\n' .. err end -- For each child tree (injected language), find the root of the tree and locate the node within -- the primary tree that contains that root. Add a mapping from the node in the primary tree to -- the root in the child tree to the {injections} table. local root = parser:parse(true)[1]:root() - local injections = {} ---@type table<string, TSP.Injection> + local injections = {} ---@type table<string, vim.treesitter.dev.Injection> parser:for_each_tree(function(parent_tree, parent_ltree) local parent = parent_tree:root() for _, child in pairs(parent_ltree:children()) do - child:for_each_tree(function(tree, ltree) + for _, tree in pairs(child:trees()) do local r = tree:root() local node = assert(parent:named_descendant_for_range(r:range())) local id = node:id() if not injections[id] or r:byte_length() > injections[id].root:byte_length() then injections[id] = { - lang = ltree:lang(), + lang = child:lang(), root = r, } end - end) + end end end) - local nodes = traverse(root, 0, parser:lang(), injections, {}) + local nodes = traverse(root, 0, nil, parser:lang(), injections, {}) - local named = {} ---@type TSP.Node[] + local named = {} ---@type vim.treesitter.dev.Node[] for _, v in ipairs(nodes) do - if v.named then + if v.node:named() then named[#named + 1] = v end end @@ -141,6 +118,7 @@ function TSTreeView:new(bufnr, lang) ns = api.nvim_create_namespace('treesitter/dev-inspect'), nodes = nodes, named = named, + ---@type vim.treesitter.dev.TSTreeViewOpts opts = { anon = false, lang = false, @@ -155,16 +133,12 @@ end local decor_ns = api.nvim_create_namespace('ts.dev') ----@param lnum integer ----@param col integer ----@param end_lnum integer ----@param end_col integer +---@param range Range4 ---@return string -local function get_range_str(lnum, col, end_lnum, end_col) - if lnum == end_lnum then - return string.format('[%d:%d - %d]', lnum + 1, col + 1, end_col) - end - return string.format('[%d:%d - %d:%d]', lnum + 1, col + 1, end_lnum + 1, end_col) +local function range_to_string(range) + ---@type integer, integer, integer, integer + local row, col, end_row, end_col = unpack(range) + return string.format('[%d, %d] - [%d, %d]', row, col, end_row, end_col) end ---@param w integer @@ -183,7 +157,10 @@ end local function set_dev_properties(w, b) vim.wo[w].scrolloff = 5 vim.wo[w].wrap = false - vim.wo[w].foldmethod = 'manual' -- disable folding + vim.wo[w].foldmethod = 'expr' + vim.wo[w].foldexpr = 'v:lua.vim.treesitter.foldexpr()' -- explicitly set foldexpr + vim.wo[w].foldenable = false -- Don't fold on first open InspectTree + vim.wo[w].foldlevel = 99 vim.bo[b].buflisted = false vim.bo[b].buftype = 'nofile' vim.bo[b].bufhidden = 'wipe' @@ -192,7 +169,7 @@ end --- Updates the cursor position in the inspector to match the node under the cursor. --- ---- @param treeview TSTreeView +--- @param treeview vim.treesitter.dev.TSTreeView --- @param lang string --- @param source_buf integer --- @param inspect_buf integer @@ -213,7 +190,7 @@ local function set_inspector_cursor(treeview, lang, source_buf, inspect_buf, ins local cursor_node_id = cursor_node:id() for i, v in treeview:iter() do - if v.id == cursor_node_id then + if v.node:id() == cursor_node_id then local start = v.depth * treeview.opts.indent ---@type integer local end_col = start + #v.text api.nvim_buf_set_extmark(inspect_buf, treeview.ns, i - 1, start, { @@ -228,6 +205,8 @@ end --- Write the contents of this View into {bufnr}. --- +--- Calling this function computes the text that is displayed for each node. +--- ---@param bufnr integer Buffer number to write into. ---@package function TSTreeView:draw(bufnr) @@ -235,13 +214,35 @@ function TSTreeView:draw(bufnr) local lines = {} ---@type string[] local lang_hl_marks = {} ---@type table[] - for _, item in self:iter() do - local range_str = get_range_str(item.lnum, item.col, item.end_lnum, item.end_col) + for i, item in self:iter() do + local range_str = range_to_string({ item.node:range() }) local lang_str = self.opts.lang and string.format(' %s', item.lang) or '' + + local text ---@type string + if item.node:named() then + if item.field then + text = string.format('%s: (%s', item.field, item.node:type()) + else + text = string.format('(%s', item.node:type()) + end + else + text = string.format('"%s"', item.node:type():gsub('\n', '\\n'):gsub('"', '\\"')) + end + + local next = self:get(i + 1) + if not next or next.depth <= item.depth then + local parens = item.depth - (next and next.depth or 0) + (item.node:named() and 1 or 0) + if parens > 0 then + text = string.format('%s%s', text, string.rep(')', parens)) + end + end + + item.text = text + local line = string.format( '%s%s ; %s%s', string.rep(' ', item.depth * self.opts.indent), - item.text, + text, range_str, lang_str ) @@ -253,7 +254,7 @@ function TSTreeView:draw(bufnr) } end - lines[#lines + 1] = line + lines[i] = line end api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) @@ -275,7 +276,7 @@ end --- The node number is dependent on whether or not anonymous nodes are displayed. --- ---@param i integer Node number to get ----@return TSP.Node +---@return vim.treesitter.dev.Node ---@package function TSTreeView:get(i) local t = self.opts.anon and self.nodes or self.named @@ -284,7 +285,7 @@ end --- Iterate over all of the nodes in this View. --- ----@return (fun(): integer, TSP.Node) Iterator over all nodes in this View +---@return (fun(): integer, vim.treesitter.dev.Node) Iterator over all nodes in this View ---@return table ---@return integer ---@package @@ -292,22 +293,31 @@ function TSTreeView:iter() return ipairs(self.opts.anon and self.nodes or self.named) end ---- @class InspectTreeOpts ---- @field lang string? The language of the source buffer. If omitted, the ---- filetype of the source buffer is used. ---- @field bufnr integer? Buffer to draw the tree into. If omitted, a new ---- buffer is created. ---- @field winid integer? Window id to display the tree buffer in. If omitted, ---- a new window is created with {command}. ---- @field command string? Vimscript command to create the window. Default ---- value is "60vnew". Only used when {winid} is nil. ---- @field title (string|fun(bufnr:integer):string|nil) Title of the window. If a ---- function, it accepts the buffer number of the source ---- buffer as its only argument and should return a string. +--- @class vim.treesitter.dev.inspect_tree.Opts +--- @inlinedoc +--- +--- The language of the source buffer. If omitted, the filetype of the source +--- buffer is used. +--- @field lang string? +--- +--- Buffer to draw the tree into. If omitted, a new buffer is created. +--- @field bufnr integer? +--- +--- Window id to display the tree buffer in. If omitted, a new window is +--- created with {command}. +--- @field winid integer? +--- +--- Vimscript command to create the window. Default value is "60vnew". +--- Only used when {winid} is nil. +--- @field command string? +--- +--- Title of the window. If a function, it accepts the buffer number of the +--- source buffer as its only argument and should return a string. +--- @field title (string|fun(bufnr:integer):string|nil) --- @private --- ---- @param opts InspectTreeOpts? +--- @param opts vim.treesitter.dev.inspect_tree.Opts? function M.inspect_tree(opts) vim.validate({ opts = { opts, 't', true }, @@ -364,9 +374,9 @@ function M.inspect_tree(opts) desc = 'Jump to the node under the cursor in the source buffer', callback = function() local row = api.nvim_win_get_cursor(w)[1] - local pos = treeview:get(row) + local lnum, col = treeview:get(row).node:start() api.nvim_set_current_win(win) - api.nvim_win_set_cursor(win, { pos.lnum + 1, pos.col }) + api.nvim_win_set_cursor(win, { lnum + 1, col }) end, }) api.nvim_buf_set_keymap(b, 'n', 'a', '', { @@ -374,7 +384,7 @@ function M.inspect_tree(opts) callback = function() local row, col = unpack(api.nvim_win_get_cursor(w)) ---@type integer, integer local curnode = treeview:get(row) - while curnode and not curnode.named do + while curnode and not curnode.node:named() do row = row - 1 curnode = treeview:get(row) end @@ -386,9 +396,9 @@ function M.inspect_tree(opts) return end - local id = curnode.id + local id = curnode.node:id() for i, node in treeview:iter() do - if node.id == id then + if node.node:id() == id then api.nvim_win_set_cursor(w, { i, col }) break end @@ -424,20 +434,20 @@ function M.inspect_tree(opts) api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1) local row = api.nvim_win_get_cursor(w)[1] - local pos = treeview:get(row) - api.nvim_buf_set_extmark(buf, treeview.ns, pos.lnum, pos.col, { - end_row = pos.end_lnum, - end_col = math.max(0, pos.end_col), + local lnum, col, end_lnum, end_col = treeview:get(row).node:range() + api.nvim_buf_set_extmark(buf, treeview.ns, lnum, col, { + end_row = end_lnum, + end_col = math.max(0, end_col), hl_group = 'Visual', }) local topline, botline = vim.fn.line('w0', win), vim.fn.line('w$', win) -- Move the cursor if highlighted range is completely out of view - if pos.lnum < topline and pos.end_lnum < topline then - api.nvim_win_set_cursor(win, { pos.end_lnum + 1, 0 }) - elseif pos.lnum > botline and pos.end_lnum > botline then - api.nvim_win_set_cursor(win, { pos.lnum + 1, 0 }) + if lnum < topline and end_lnum < topline then + api.nvim_win_set_cursor(win, { end_lnum + 1, 0 }) + elseif lnum > botline and end_lnum > botline then + api.nvim_win_set_cursor(win, { lnum + 1, 0 }) end end, }) @@ -462,7 +472,9 @@ function M.inspect_tree(opts) return true end + local treeview_opts = treeview.opts treeview = assert(TSTreeView:new(buf, opts.lang)) + treeview.opts = treeview_opts treeview:draw(b) end, }) diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua index ed1161e97f..a9b066d158 100644 --- a/runtime/lua/vim/treesitter/health.lua +++ b/runtime/lua/vim/treesitter/health.lua @@ -1,6 +1,6 @@ local M = {} local ts = vim.treesitter -local health = require('vim.health') +local health = vim.health --- Performs a healthcheck for treesitter integration function M.check() diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 496193c6ed..388680259a 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -2,50 +2,25 @@ local api = vim.api local query = vim.treesitter.query local Range = require('vim.treesitter._range') ----@alias TSHlIter fun(end_line: integer|nil): integer, TSNode, TSMetadata - ----@class TSHighlightState ----@field next_row integer ----@field iter TSHlIter|nil - ----@class TSHighlighter ----@field active table<integer,TSHighlighter> ----@field bufnr integer ----@field orig_spelloptions string ----@field _highlight_states table<TSTree,TSHighlightState> ----@field _queries table<string,TSHighlighterQuery> ----@field tree LanguageTree ----@field redraw_count integer -local TSHighlighter = rawget(vim.treesitter, 'TSHighlighter') or {} -TSHighlighter.__index = TSHighlighter +local ns = api.nvim_create_namespace('treesitter/highlighter') ---- @nodoc -TSHighlighter.active = TSHighlighter.active or {} +---@alias vim.treesitter.highlighter.Iter fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata ----@class TSHighlighterQuery ----@field _query Query|nil ----@field hl_cache table<integer,integer> +---@class (private) vim.treesitter.highlighter.Query +---@field private _query vim.treesitter.Query? +---@field private lang string +---@field private hl_cache table<integer,integer> local TSHighlighterQuery = {} TSHighlighterQuery.__index = TSHighlighterQuery -local ns = api.nvim_create_namespace('treesitter/highlighter') - ---@private +---@param lang string +---@param query_string string? +---@return vim.treesitter.highlighter.Query function TSHighlighterQuery.new(lang, query_string) - local self = setmetatable({}, { __index = TSHighlighterQuery }) - - self.hl_cache = setmetatable({}, { - __index = function(table, capture) - local name = self._query.captures[capture] - local id = 0 - if not vim.startswith(name, '_') then - id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. lang) - end - - rawset(table, capture, id) - return id - end, - }) + local self = setmetatable({}, TSHighlighterQuery) + self.lang = lang + self.hl_cache = {} if query_string then self._query = query.parse(lang, query_string) @@ -57,18 +32,57 @@ function TSHighlighterQuery.new(lang, query_string) end ---@package +---@param capture integer +---@return integer? +function TSHighlighterQuery:get_hl_from_capture(capture) + if not self.hl_cache[capture] then + local name = self._query.captures[capture] + local id = 0 + if not vim.startswith(name, '_') then + id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. self.lang) + end + self.hl_cache[capture] = id + end + + return self.hl_cache[capture] +end + +---@package function TSHighlighterQuery:query() return self._query end +---@class (private) vim.treesitter.highlighter.State +---@field tstree TSTree +---@field next_row integer +---@field iter vim.treesitter.highlighter.Iter? +---@field highlighter_query vim.treesitter.highlighter.Query + +---@nodoc +---@class vim.treesitter.highlighter +---@field active table<integer,vim.treesitter.highlighter> +---@field bufnr integer +---@field private orig_spelloptions string +--- A map of highlight states. +--- This state is kept during rendering across each line update. +---@field private _highlight_states vim.treesitter.highlighter.State[] +---@field private _queries table<string,vim.treesitter.highlighter.Query> +---@field tree vim.treesitter.LanguageTree +---@field private redraw_count integer +local TSHighlighter = { + active = {}, +} + +TSHighlighter.__index = TSHighlighter + ---@package --- --- Creates a highlighter for `tree`. --- ----@param tree LanguageTree parser object to use for highlighting +---@param tree vim.treesitter.LanguageTree parser object to use for highlighting ---@param opts (table|nil) Configuration of the highlighter: --- - queries table overwrite queries used by the highlighter ----@return TSHighlighter Created highlighter object +---@return vim.treesitter.highlighter Created highlighter object function TSHighlighter.new(tree, opts) local self = setmetatable({}, TSHighlighter) @@ -98,15 +112,12 @@ function TSHighlighter.new(tree, opts) end, }, true) - self.bufnr = tree:source() --[[@as integer]] - self.edit_count = 0 + local source = tree:source() + assert(type(source) == 'number') + + self.bufnr = source self.redraw_count = 0 - self.line_count = {} - -- A map of highlight states. - -- This state is kept during rendering across each line update. self._highlight_states = {} - - ---@type table<string,TSHighlighterQuery> self._queries = {} -- Queries for a specific language can be overridden by a custom @@ -144,11 +155,9 @@ end --- @nodoc --- Removes all internal references to the highlighter function TSHighlighter:destroy() - if TSHighlighter.active[self.bufnr] then - TSHighlighter.active[self.bufnr] = nil - end + TSHighlighter.active[self.bufnr] = nil - if vim.api.nvim_buf_is_loaded(self.bufnr) then + if api.nvim_buf_is_loaded(self.bufnr) then vim.bo[self.bufnr].spelloptions = self.orig_spelloptions vim.b[self.bufnr].ts_highlight = nil if vim.g.syntax_on == 1 then @@ -157,23 +166,49 @@ function TSHighlighter:destroy() end end ----@package ----@param tstree TSTree ----@return TSHighlightState -function TSHighlighter:get_highlight_state(tstree) - if not self._highlight_states[tstree] then - self._highlight_states[tstree] = { +---@param srow integer +---@param erow integer exclusive +---@private +function TSHighlighter:prepare_highlight_states(srow, erow) + self._highlight_states = {} + + self.tree:for_each_tree(function(tstree, tree) + if not tstree then + return + end + + local root_node = tstree:root() + local root_start_row, _, root_end_row, _ = root_node:range() + + -- Only consider trees within the visible range + if root_start_row > erow or root_end_row < srow then + return + end + + local highlighter_query = self:get_query(tree:lang()) + + -- Some injected languages may not have highlight queries. + if not highlighter_query:query() then + return + end + + -- _highlight_states should be a list so that the highlights are added in the same order as + -- for_each_tree traversal. This ensures that parents' highlight don't override children's. + table.insert(self._highlight_states, { + tstree = tstree, next_row = 0, iter = nil, - } - end - - return self._highlight_states[tstree] + highlighter_query = highlighter_query, + }) + end) end ----@private -function TSHighlighter:reset_highlight_state() - self._highlight_states = {} +---@param fn fun(state: vim.treesitter.highlighter.State) +---@package +function TSHighlighter:for_each_highlight_state(fn) + for _, state in ipairs(self._highlight_states) do + fn(state) + end end ---@package @@ -197,10 +232,9 @@ function TSHighlighter:on_changedtree(changes) end --- Gets the query used for @param lang --- ---@package ---@param lang string Language used by the highlighter. ----@return TSHighlighterQuery +---@return vim.treesitter.highlighter.Query function TSHighlighter:get_query(lang) if not self._queries[lang] then self._queries[lang] = TSHighlighterQuery.new(lang) @@ -209,35 +243,23 @@ function TSHighlighter:get_query(lang) return self._queries[lang] end ----@param self TSHighlighter +---@param self vim.treesitter.highlighter ---@param buf integer ---@param line integer ---@param is_spell_nav boolean local function on_line_impl(self, buf, line, is_spell_nav) - self.tree:for_each_tree(function(tstree, tree) - if not tstree then - return - end - - local root_node = tstree:root() + self:for_each_highlight_state(function(state) + local root_node = state.tstree:root() local root_start_row, _, root_end_row, _ = root_node:range() - -- Only worry about trees within the line range + -- Only consider trees that contain this line if root_start_row > line or root_end_row < line then return end - local state = self:get_highlight_state(tstree) - local highlighter_query = self:get_query(tree:lang()) - - -- Some injected languages may not have highlight queries. - if not highlighter_query:query() then - return - end - if state.iter == nil or state.next_row < line then state.iter = - highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) + state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end while line >= state.next_row do @@ -250,9 +272,9 @@ local function on_line_impl(self, buf, line, is_spell_nav) local start_row, start_col, end_row, end_col = Range.unpack4(range) if capture then - local hl = highlighter_query.hl_cache[capture] + local hl = state.highlighter_query:get_hl_from_capture(capture) - local capture_name = highlighter_query:query().captures[capture] + local capture_name = state.highlighter_query:query().captures[capture] local spell = nil ---@type boolean? if capture_name == 'spell' then spell = true @@ -308,7 +330,7 @@ function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _) return end - self:reset_highlight_state() + self:prepare_highlight_states(srow, erow) for row = srow, erow do on_line_impl(self, buf, row, true) @@ -326,7 +348,7 @@ function TSHighlighter._on_win(_, _win, buf, topline, botline) return false end self.tree:parse({ topline, botline + 1 }) - self:reset_highlight_state() + self:prepare_highlight_states(topline, botline + 1) self.redraw_count = self.redraw_count + 1 return true end diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index 15bf666a1e..47abf65332 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -1,6 +1,5 @@ local api = vim.api ----@class TSLanguageModule local M = {} ---@type table<string,string> @@ -37,6 +36,11 @@ end ---@deprecated function M.require_language(lang, path, silent, symbol_name) + vim.deprecate( + 'vim.treesitter.language.require_language()', + 'vim.treesitter.language.add()', + '0.12' + ) local opts = { silent = silent, path = path, @@ -52,10 +56,17 @@ function M.require_language(lang, path, silent, symbol_name) return true end ----@class treesitter.RequireLangOpts ----@field path? string ----@field silent? boolean +---@class vim.treesitter.language.add.Opts +---@inlinedoc +--- +---Default filetype the parser should be associated with. +---(Default: {lang}) ---@field filetype? string|string[] +--- +---Optional path the parser is located at +---@field path? string +--- +---Internal symbol name for the language to load ---@field symbol_name? string --- Load parser with name {lang} @@ -63,13 +74,8 @@ end --- Parsers are searched in the `parser` runtime directory, or the provided {path} --- ---@param lang string Name of the parser (alphanumerical and `_` only) ----@param opts (table|nil) Options: ---- - filetype (string|string[]) Default filetype the parser should be associated with. ---- Defaults to {lang}. ---- - path (string|nil) Optional path the parser is located at ---- - symbol_name (string|nil) Internal symbol name for the language to load +---@param opts? vim.treesitter.language.add.Opts Options: function M.add(lang, opts) - ---@cast opts treesitter.RequireLangOpts opts = opts or {} local path = opts.path local filetype = opts.filetype or lang @@ -114,6 +120,10 @@ local function ensure_list(x) end --- Register a parser named {lang} to be used for {filetype}(s). +--- +--- Note: this adds or overrides the mapping for {filetype}, any existing mappings from other +--- filetypes to {lang} will be preserved. +--- --- @param lang string Name of parser --- @param filetype string|string[] Filetype(s) to associate with lang function M.register(lang, filetype) @@ -140,14 +150,4 @@ function M.inspect(lang) return vim._ts_inspect_language(lang) end ----@deprecated -function M.inspect_language(...) - vim.deprecate( - 'vim.treesitter.language.inspect_language()', - 'vim.treesitter.language.inspect()', - '0.10' - ) - return M.inspect(...) -end - return M diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 0171b416cd..62714d3f1b 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -1,6 +1,4 @@ ---- @defgroup lua-treesitter-languagetree ---- ---- @brief A \*LanguageTree\* contains a tree of parsers: the root treesitter parser for {lang} and +--- @brief A [LanguageTree]() contains a tree of parsers: the root treesitter parser for {lang} and --- any "injected" language parsers, which themselves may inject other languages, recursively. --- For example a Lua buffer containing some Vimscript commands needs multiple parsers to fully --- understand its contents. @@ -69,11 +67,12 @@ local TSCallbackNames = { on_child_removed = 'child_removed', } ----@class LanguageTree +---@nodoc +---@class vim.treesitter.LanguageTree ---@field private _callbacks table<TSCallbackName,function[]> Callback handlers ---@field package _callbacks_rec table<TSCallbackName,function[]> Callback handlers (recursive) ----@field private _children table<string,LanguageTree> Injected languages ----@field private _injection_query Query Queries defining injected languages +---@field private _children table<string,vim.treesitter.LanguageTree> Injected languages +---@field private _injection_query vim.treesitter.Query Queries defining injected languages ---@field private _injections_processed boolean ---@field private _opts table Options ---@field private _parser TSParser Parser for language @@ -91,9 +90,11 @@ local TSCallbackNames = { ---@field private _logfile? file* local LanguageTree = {} ----@class LanguageTreeOpts ----@field queries table<string,string> -- Deprecated ----@field injections table<string,string> +---Optional arguments: +---@class vim.treesitter.LanguageTree.new.Opts +---@inlinedoc +---@field queries? table<string,string> -- Deprecated +---@field injections? table<string,string> LanguageTree.__index = LanguageTree @@ -104,14 +105,11 @@ LanguageTree.__index = LanguageTree --- ---@param source (integer|string) Buffer or text string to parse ---@param lang string Root language of this tree ----@param opts (table|nil) Optional arguments: ---- - injections table Map of language to injection query strings. Overrides the ---- built-in runtime file searching for language injections. +---@param opts vim.treesitter.LanguageTree.new.Opts? ---@param parent_lang? string Parent language name of this tree ----@return LanguageTree parser object +---@return vim.treesitter.LanguageTree parser object function LanguageTree.new(source, lang, opts, parent_lang) language.add(lang) - ---@type LanguageTreeOpts opts = opts or {} if source == 0 then @@ -120,7 +118,7 @@ function LanguageTree.new(source, lang, opts, parent_lang) local injections = opts.injections or {} - --- @type LanguageTree + --- @type vim.treesitter.LanguageTree local self = { _source = source, _lang = lang, @@ -196,7 +194,7 @@ local function tcall(f, ...) end ---@private ----@vararg any +---@param ... any function LanguageTree:_log(...) if not self._logger then return @@ -348,7 +346,13 @@ function LanguageTree:_parse_regions(range) -- If there are no ranges, set to an empty list -- so the included ranges in the parser are cleared. for i, ranges in pairs(self:included_regions()) do - if not self._valid[i] and intercepts_region(ranges, range) then + if + not self._valid[i] + and ( + intercepts_region(ranges, range) + or (self._trees[i] and intercepts_region(self._trees[i]:included_ranges(false), range)) + ) + then self._parser:set_included_ranges(ranges) local parse_time, tree, tree_changes = tcall(self._parser.parse, self._parser, self._trees[i], self._source, true) @@ -427,7 +431,7 @@ function LanguageTree:parse(range) local query_time = 0 local total_parse_time = 0 - --- At least 1 region is invalid + -- At least 1 region is invalid if not self:is_valid(true) then changes, no_regions_parsed, total_parse_time = self:_parse_regions(range) -- Need to run injections when we parsed something @@ -460,7 +464,7 @@ end --- add recursion yourself if needed. --- Invokes the callback for each |LanguageTree| and its children recursively --- ----@param fn fun(tree: LanguageTree, lang: string) +---@param fn fun(tree: vim.treesitter.LanguageTree, lang: string) ---@param include_self boolean|nil Whether to include the invoking tree in the results function LanguageTree:for_each_child(fn, include_self) vim.deprecate('LanguageTree:for_each_child()', 'LanguageTree:children()', '0.11') @@ -469,6 +473,7 @@ function LanguageTree:for_each_child(fn, include_self) end for _, child in pairs(self._children) do + --- @diagnostic disable-next-line:deprecated child:for_each_child(fn, true) end end @@ -477,7 +482,7 @@ end --- --- Note: This includes the invoking tree's child trees as well. --- ----@param fn fun(tree: TSTree, ltree: LanguageTree) +---@param fn fun(tree: TSTree, ltree: vim.treesitter.LanguageTree) function LanguageTree:for_each_tree(fn) for _, tree in pairs(self._trees) do fn(tree, self) @@ -494,7 +499,7 @@ end --- ---@private ---@param lang string Language to add. ----@return LanguageTree injected +---@return vim.treesitter.LanguageTree injected function LanguageTree:add_child(lang) if self._children[lang] then self:remove_child(lang) @@ -664,7 +669,7 @@ end ---@param node TSNode ---@param source string|integer ----@param metadata TSMetadata +---@param metadata vim.treesitter.query.TSMetadata ---@param include_children boolean ---@return Range6[] local function get_node_ranges(node, source, metadata, include_children) @@ -698,13 +703,14 @@ local function get_node_ranges(node, source, metadata, include_children) return ranges end ----@class TSInjectionElem +---@nodoc +---@class vim.treesitter.languagetree.InjectionElem ---@field combined boolean ---@field regions Range6[][] ----@alias TSInjection table<string,table<integer,TSInjectionElem>> +---@alias vim.treesitter.languagetree.Injection table<string,table<integer,vim.treesitter.languagetree.InjectionElem>> ----@param t table<integer,TSInjection> +---@param t table<integer,vim.treesitter.languagetree.Injection> ---@param tree_index integer ---@param pattern integer ---@param lang string @@ -751,6 +757,11 @@ end) ---@param alias string language or filetype name ---@return string? # resolved parser name local function resolve_lang(alias) + -- validate that `alias` is a legal language + if not (alias and alias:match('[%w_]+') == alias) then + return + end + if has_parser(alias) then return alias end @@ -773,8 +784,8 @@ end ---@private --- Extract injections according to: --- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection ----@param match table<integer,TSNode> ----@param metadata TSMetadata +---@param match table<integer,TSNode[]> +---@param metadata vim.treesitter.query.TSMetadata ---@return string?, boolean, Range6[] function LanguageTree:_get_injection(match, metadata) local ranges = {} ---@type Range6[] @@ -785,14 +796,16 @@ function LanguageTree:_get_injection(match, metadata) or (injection_lang and resolve_lang(injection_lang)) local include_children = metadata['injection.include-children'] ~= nil - for id, node in pairs(match) do - local name = self._injection_query.captures[id] - -- Lang should override any other language tag - if name == 'injection.language' then - local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) - lang = resolve_lang(text) - elseif name == 'injection.content' then - ranges = get_node_ranges(node, self._source, metadata[id], include_children) + for id, nodes in pairs(match) do + for _, node in ipairs(nodes) do + local name = self._injection_query.captures[id] + -- Lang should override any other language tag + if name == 'injection.language' then + local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) + lang = resolve_lang(text) + elseif name == 'injection.content' then + ranges = get_node_ranges(node, self._source, metadata[id], include_children) + end end end @@ -825,7 +838,7 @@ function LanguageTree:_get_injections() return {} end - ---@type table<integer,TSInjection> + ---@type table<integer,vim.treesitter.languagetree.Injection> local injections = {} for index, tree in pairs(self._trees) do @@ -833,7 +846,13 @@ function LanguageTree:_get_injections() local start_line, _, end_line, _ = root_node:range() for pattern, match, metadata in - self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1) + self._injection_query:iter_matches( + root_node, + self._source, + start_line, + end_line + 1, + { all = true } + ) do local lang, combined, ranges = self:_get_injection(match, metadata) if lang then @@ -1133,7 +1152,7 @@ end --- Gets the appropriate language that contains {range}. --- ---@param range Range4 `{ start_line, start_col, end_line, end_col }` ----@return LanguageTree Managing {range} +---@return vim.treesitter.LanguageTree Managing {range} function LanguageTree:language_for_range(range) for _, child in pairs(self._children) do if child:contains(range) then diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 8cbbffcd60..a086f5e876 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,19 +1,50 @@ local api = vim.api local language = require('vim.treesitter.language') ----@class Query ----@field captures string[] List of captures used in query ----@field info TSQueryInfo Contains used queries, predicates, directives ----@field query userdata Parsed query +local M = {} + +---@nodoc +---Parsed query, see |vim.treesitter.query.parse()| +--- +---@class vim.treesitter.Query +---@field lang string name of the language for this parser +---@field captures string[] list of (unique) capture names defined in query +---@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives) +---@field query TSQuery userdata query object local Query = {} Query.__index = Query ----@class TSQueryInfo ----@field captures table ----@field patterns table<string,any[][]> +---@package +---@see vim.treesitter.query.parse +---@param lang string +---@param ts_query TSQuery +---@return vim.treesitter.Query +function Query.new(lang, ts_query) + local self = setmetatable({}, Query) + local query_info = ts_query:inspect() ---@type TSQueryInfo + self.query = ts_query + self.lang = lang + self.info = { + captures = query_info.captures, + patterns = query_info.patterns, + } + self.captures = self.info.captures + return self +end ----@class TSQueryModule -local M = {} +---@nodoc +---Information for Query, see |vim.treesitter.query.parse()| +---@class vim.treesitter.QueryInfo +--- +---List of (unique) capture names defined in query. +---@field captures string[] +--- +---Contains information about predicates and directives. +---Key is pattern id, and value is list of predicates or directives defined in the pattern. +---A predicate or directive is a list of (integer|string); integer represents `capture_id`, and +---string represents (literal) arguments to predicate/directive. See |treesitter-predicates| +---and |treesitter-directives| for more details. +---@field patterns table<integer, (integer|string)[][]> ---@param files string[] ---@return string[] @@ -53,16 +84,6 @@ local function add_included_lang(base_langs, lang, ilang) return false end ----@deprecated -function M.get_query_files(...) - vim.deprecate( - 'vim.treesitter.query.get_query_files()', - 'vim.treesitter.query.get_files()', - '0.10' - ) - return M.get_files(...) -end - --- Gets the list of files used to make up a query --- ---@param lang string Language to get query for @@ -163,7 +184,7 @@ local function read_query_files(filenames) end -- The explicitly set queries from |vim.treesitter.query.set()| ----@type table<string,table<string,Query>> +---@type table<string,table<string,vim.treesitter.Query>> local explicit_queries = setmetatable({}, { __index = function(t, k) local lang_queries = {} @@ -173,12 +194,6 @@ local explicit_queries = setmetatable({}, { end, }) ----@deprecated -function M.set_query(...) - vim.deprecate('vim.treesitter.query.set_query()', 'vim.treesitter.query.set()', '0.10') - M.set(...) -end - --- Sets the runtime query named {query_name} for {lang} --- --- This allows users to override any runtime files and/or configuration @@ -191,18 +206,12 @@ function M.set(lang, query_name, text) explicit_queries[lang][query_name] = M.parse(lang, text) end ----@deprecated -function M.get_query(...) - vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10') - return M.get(...) -end - --- Returns the runtime query {query_name} for {lang}. --- ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g. "highlights") --- ----@return Query|nil Parsed query +---@return vim.treesitter.Query|nil : Parsed query. `nil` if no query files are found. M.get = vim.func._memoize('concat-2', function(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] @@ -218,92 +227,96 @@ M.get = vim.func._memoize('concat-2', function(lang, query_name) return M.parse(lang, query_string) end) ----@deprecated -function M.parse_query(...) - vim.deprecate('vim.treesitter.query.parse_query()', 'vim.treesitter.query.parse()', '0.10') - return M.parse(...) -end - --- Parse {query} as a string. (If the query is in a file, the caller --- should read the contents into a string before calling). --- --- Returns a `Query` (see |lua-treesitter-query|) object which can be used to --- search nodes in the syntax tree for the patterns defined in {query} ---- using `iter_*` methods below. +--- using the `iter_captures` and `iter_matches` methods. --- --- Exposes `info` and `captures` with additional context about {query}. ---- - `captures` contains the list of unique capture names defined in ---- {query}. ---- -` info.captures` also points to `captures`. +--- - `captures` contains the list of unique capture names defined in {query}. +--- - `info.captures` also points to `captures`. --- - `info.patterns` contains information about predicates. --- ---@param lang string Language to use for the query ---@param query string Query in s-expr syntax --- ----@return Query Parsed query +---@return vim.treesitter.Query Parsed query +--- +---@see |vim.treesitter.query.get()| M.parse = vim.func._memoize('concat-2', function(lang, query) language.add(lang) - local self = setmetatable({}, Query) - self.query = vim._ts_parse_query(lang, query) - self.info = self.query:inspect() - self.captures = self.info.captures - return self + local ts_query = vim._ts_parse_query(lang, query) + return Query.new(lang, ts_query) end) ----@deprecated -function M.get_range(...) - vim.deprecate('vim.treesitter.query.get_range()', 'vim.treesitter.get_range()', '0.10') - return vim.treesitter.get_range(...) -end - ----@deprecated -function M.get_node_text(...) - vim.deprecate('vim.treesitter.query.get_node_text()', 'vim.treesitter.get_node_text()', '0.10') - return vim.treesitter.get_node_text(...) -end - ----@alias TSMatch table<integer,TSNode> - ----@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean - --- Predicate handler receive the following arguments --- (match, pattern, bufnr, predicate) ----@type table<string,TSPredicate> -local predicate_handlers = { - ['eq?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then +--- Implementations of predicates that can optionally be prefixed with "any-". +--- +--- These functions contain the implementations for each predicate, correctly +--- handling the "any" vs "all" semantics. They are called from the +--- predicate_handlers table with the appropriate arguments for each predicate. +local impl = { + --- @param match vim.treesitter.query.TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['eq'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - local str ---@type string - if type(predicate[3]) == 'string' then - -- (#eq? @aa "foo") - str = predicate[3] - else - -- (#eq? @aa @bb) - str = vim.treesitter.get_node_text(match[predicate[3]], source) - end + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + local str ---@type string + if type(predicate[3]) == 'string' then + -- (#eq? @aa "foo") + str = predicate[3] + else + -- (#eq? @aa @bb) + local other = assert(match[predicate[3]]) + assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes') + str = vim.treesitter.get_node_text(other[1], source) + end - if node_text ~= str or str == nil then - return false + local res = str ~= nil and node_text == str + if any and res then + return true + elseif not any and not res then + return false + end end - return true + return not any end, - ['lua-match?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + --- @param match vim.treesitter.query.TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['lua-match'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local regex = predicate[3] - return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil + + for _, node in ipairs(nodes) do + local regex = predicate[3] + local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil + if any and res then + return true + elseif not any and not res then + return false + end + end + + return not any end, - ['match?'] = (function() + ['match'] = (function() local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true } local function check_magic(str) if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then @@ -320,85 +333,161 @@ local predicate_handlers = { end, }) - return function(match, _, source, pred) - ---@cast match TSMatch - local node = match[pred[2]] - if not node then + --- @param match vim.treesitter.query.TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + return function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - ---@diagnostic disable-next-line no-unknown - local regex = compiled_vim_regexes[pred[3]] - return regex:match_str(vim.treesitter.get_node_text(node, source)) + + for _, node in ipairs(nodes) do + local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex + local res = regex:match_str(vim.treesitter.get_node_text(node, source)) + if any and res then + return true + elseif not any and not res then + return false + end + end + return not any end end)(), - ['contains?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + --- @param match vim.treesitter.query.TSMatch + --- @param source integer|string + --- @param predicate any[] + --- @param any boolean + ['contains'] = function(match, source, predicate, any) + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - for i = 3, #predicate do - if string.find(node_text, predicate[i], 1, true) then - return true + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + for i = 3, #predicate do + local res = string.find(node_text, predicate[i], 1, true) + if any and res then + return true + elseif not any and not res then + return false + end end end - return false + return not any + end, +} + +---@nodoc +---@class vim.treesitter.query.TSMatch +---@field pattern? integer +---@field active? boolean +---@field [integer] TSNode[] + +---@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean + +-- Predicate handler receive the following arguments +-- (match, pattern, bufnr, predicate) +---@type table<string,TSPredicate> +local predicate_handlers = { + ['eq?'] = function(match, _, source, predicate) + return impl['eq'](match, source, predicate, false) + end, + + ['any-eq?'] = function(match, _, source, predicate) + return impl['eq'](match, source, predicate, true) + end, + + ['lua-match?'] = function(match, _, source, predicate) + return impl['lua-match'](match, source, predicate, false) + end, + + ['any-lua-match?'] = function(match, _, source, predicate) + return impl['lua-match'](match, source, predicate, true) + end, + + ['match?'] = function(match, _, source, predicate) + return impl['match'](match, source, predicate, false) + end, + + ['any-match?'] = function(match, _, source, predicate) + return impl['match'](match, source, predicate, true) + end, + + ['contains?'] = function(match, _, source, predicate) + return impl['contains'](match, source, predicate, false) + end, + + ['any-contains?'] = function(match, _, source, predicate) + return impl['contains'](match, source, predicate, true) end, ['any-of?'] = function(match, _, source, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local node_text = vim.treesitter.get_node_text(node, source) - -- Since 'predicate' will not be used by callers of this function, use it - -- to store a string set built from the list of words to check against. - local string_set = predicate['string_set'] - if not string_set then - string_set = {} - for i = 3, #predicate do - ---@diagnostic disable-next-line:no-unknown - string_set[predicate[i]] = true + for _, node in ipairs(nodes) do + local node_text = vim.treesitter.get_node_text(node, source) + + -- Since 'predicate' will not be used by callers of this function, use it + -- to store a string set built from the list of words to check against. + local string_set = predicate['string_set'] --- @type table<string, boolean> + if not string_set then + string_set = {} + for i = 3, #predicate do + string_set[predicate[i]] = true + end + predicate['string_set'] = string_set + end + + if string_set[node_text] then + return true end - predicate['string_set'] = string_set end - return string_set[node_text] + return false end, ['has-ancestor?'] = function(match, _, _, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - local ancestor_types = {} - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end + for _, node in ipairs(nodes) do + local ancestor_types = {} --- @type table<string, boolean> + for _, type in ipairs({ unpack(predicate, 3) }) do + ancestor_types[type] = true + end - node = node:parent() - while node do - if ancestor_types[node:type()] then - return true + local cur = node:parent() + while cur do + if ancestor_types[cur:type()] then + return true + end + cur = cur:parent() end - node = node:parent() end return false end, ['has-parent?'] = function(match, _, _, predicate) - local node = match[predicate[2]] - if not node then + local nodes = match[predicate[2]] + if not nodes or #nodes == 0 then return true end - if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then - return true + for _, node in ipairs(nodes) do + if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then + return true + end end return false end, @@ -406,14 +495,16 @@ local predicate_handlers = { -- As we provide lua-match? also expose vim-match? predicate_handlers['vim-match?'] = predicate_handlers['match?'] +predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ----@class TSMetadata +---@nodoc +---@class vim.treesitter.query.TSMetadata ---@field range? Range ---@field conceal? string ----@field [integer] TSMetadata +---@field [integer] vim.treesitter.query.TSMetadata ---@field [string] integer|string ----@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata) +---@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -441,13 +532,17 @@ local directive_handlers = { -- Shifts the range of a node. -- Example: (#offset! @_node 0 1 0 -1) ['offset!'] = function(match, _, _, pred, metadata) - ---@cast pred integer[] - local capture_id = pred[2] + local capture_id = pred[2] --[[@as integer]] + local nodes = match[capture_id] + assert(#nodes == 1, '#offset! does not support captures on multiple nodes') + + local node = nodes[1] + if not metadata[capture_id] then metadata[capture_id] = {} end - local range = metadata[capture_id].range or { match[capture_id]:range() } + local range = metadata[capture_id].range or { node:range() } local start_row_offset = pred[3] or 0 local start_col_offset = pred[4] or 0 local end_row_offset = pred[5] or 0 @@ -471,7 +566,9 @@ local directive_handlers = { local id = pred[2] assert(type(id) == 'number') - local node = match[id] + local nodes = match[id] + assert(#nodes == 1, '#gsub! does not support captures on multiple nodes') + local node = nodes[1] local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' if not metadata[id] then @@ -491,10 +588,9 @@ local directive_handlers = { local capture_id = pred[2] assert(type(capture_id) == 'number') - local node = match[capture_id] - if not node then - return - end + local nodes = match[capture_id] + assert(#nodes == 1, '#trim! does not support captures on multiple nodes') + local node = nodes[1] local start_row, start_col, end_row, end_col = node:range() @@ -525,38 +621,93 @@ local directive_handlers = { --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[]) +---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - see |vim.treesitter.query.add_directive()| for argument meanings ----@param force boolean|nil -function M.add_predicate(name, handler, force) - if predicate_handlers[name] and not force then - error(string.format('Overriding %s', name)) +---@param opts table<string, any> Optional options: +--- - force (boolean): Override an existing +--- predicate of the same name +--- - all (boolean): Use the correct +--- implementation of the match table where +--- capture IDs map to a list of nodes instead +--- of a single node. Defaults to false (for +--- backward compatibility). This option will +--- eventually become the default and removed. +function M.add_predicate(name, handler, opts) + -- Backward compatibility: old signature had "force" as boolean argument + if type(opts) == 'boolean' then + opts = { force = opts } end - predicate_handlers[name] = handler + opts = opts or {} + + if predicate_handlers[name] and not opts.force then + error(string.format('Overriding existing predicate %s', name)) + end + + if opts.all then + predicate_handlers[name] = handler + else + --- @param match table<integer, TSNode[]> + local function wrapper(match, ...) + local m = {} ---@type table<integer, TSNode> + for k, v in pairs(match) do + if type(k) == 'number' then + m[k] = v[#v] + end + end + return handler(m, ...) + end + predicate_handlers[name] = wrapper + end end --- Adds a new directive to be used in queries --- --- Handlers can set match level data by setting directly on the ---- metadata object `metadata.key = value`, additionally, handlers +--- metadata object `metadata.key = value`. Additionally, handlers --- can set node level data by using the capture id on the --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table) ---- - match: see |treesitter-query| ---- - node-level data are accessible via `match[capture_id]` ---- - pattern: see |treesitter-query| +---@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) +--- - match: A table mapping capture IDs to a list of captured nodes +--- - pattern: the index of the matching pattern in the query file --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` ----@param force boolean|nil -function M.add_directive(name, handler, force) - if directive_handlers[name] and not force then - error(string.format('Overriding %s', name)) +---@param opts table<string, any> Optional options: +--- - force (boolean): Override an existing +--- predicate of the same name +--- - all (boolean): Use the correct +--- implementation of the match table where +--- capture IDs map to a list of nodes instead +--- of a single node. Defaults to false (for +--- backward compatibility). This option will +--- eventually become the default and removed. +function M.add_directive(name, handler, opts) + -- Backward compatibility: old signature had "force" as boolean argument + if type(opts) == 'boolean' then + opts = { force = opts } end - directive_handlers[name] = handler + opts = opts or {} + + if directive_handlers[name] and not opts.force then + error(string.format('Overriding existing directive %s', name)) + end + + if opts.all then + directive_handlers[name] = handler + else + --- @param match table<integer, TSNode[]> + local function wrapper(match, ...) + local m = {} ---@type table<integer, TSNode> + for k, v in pairs(match) do + m[k] = v[#v] + end + handler(m, ...) + end + directive_handlers[name] = wrapper + end end --- Lists the currently available directives to use in queries. @@ -580,8 +731,8 @@ local function is_directive(name) end ---@private ----@param match TSMatch ----@param pattern string +---@param match vim.treesitter.query.TSMatch +---@param pattern integer ---@param source integer|string function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] @@ -591,18 +742,14 @@ function Query:match_preds(match, pattern, source) -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). -- Also, tree-sitter strips the leading # from predicates for us. - local pred_name ---@type string - - local is_not ---@type boolean + local is_not = false -- Skip over directives... they will get processed after all the predicates. if not is_directive(pred[1]) then - if string.sub(pred[1], 1, 4) == 'not-' then - pred_name = string.sub(pred[1], 5) + local pred_name = pred[1] + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) is_not = true - else - pred_name = pred[1] - is_not = false end local handler = predicate_handlers[pred_name] @@ -623,8 +770,8 @@ function Query:match_preds(match, pattern, source) end ---@private ----@param match TSMatch ----@param metadata TSMetadata +---@param match vim.treesitter.query.TSMatch +---@param metadata vim.treesitter.query.TSMetadata function Query:apply_directives(match, pattern, source, metadata) local preds = self.info.patterns[pattern] @@ -645,14 +792,16 @@ end --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. ----@param start integer ----@param stop integer +---@param start integer|nil +---@param stop integer|nil ---@param node TSNode ---@return integer, integer local function value_or_node_range(start, stop, node) - if start == nil and stop == nil then - local node_start, _, node_stop, _ = node:range() - return node_start, node_stop + 1 -- Make stop inclusive + if start == nil then + start = node:start() + end + if stop == nil then + stop = node:end_() + 1 -- Make stop inclusive end return start, stop @@ -683,10 +832,10 @@ end --- ---@param node TSNode under which the search will occur ---@param source (integer|string) Source buffer or string to extract text from ----@param start integer Starting line for the search ----@param stop integer Stopping line for the search (end-exclusive) +---@param start? integer Starting line for the search. Defaults to `node:start()`. +---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, TSMetadata): +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata): --- capture id, capture node, metadata function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then @@ -695,7 +844,7 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) + local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch local function iter(end_line) local capture, captured_node, match = raw_iter() local metadata = {} @@ -719,46 +868,55 @@ end --- Iterates the matches of self on a given range. --- ---- Iterate over all matches within a {node}. The arguments are the same as ---- for |Query:iter_captures()| but the iterated values are different: ---- an (1-based) index of the pattern in the query, a table mapping ---- capture indices to nodes, and metadata from any directives processing the match. ---- If the query has more than one pattern, the capture table might be sparse ---- and e.g. `pairs()` method should be used over `ipairs`. ---- Here is an example iterating over all captures in every match: +--- Iterate over all matches within a {node}. The arguments are the same as for +--- |Query:iter_captures()| but the iterated values are different: an (1-based) +--- index of the pattern in the query, a table mapping capture indices to a list +--- of nodes, and metadata from any directives processing the match. +--- +--- WARNING: Set `all=true` to ensure all matching nodes in a match are +--- returned, otherwise only the last node in a match is returned, breaking captures +--- involving quantifiers such as `(comment)+ @comment`. The default option +--- `all=false` is only provided for backward compatibility and will be removed +--- after Nvim 0.10. +--- +--- Example: --- --- ```lua ---- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do ---- for id, node in pairs(match) do +--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1, { all = true }) do +--- for id, nodes in pairs(match) do --- local name = query.captures[id] ---- -- `node` was captured by the `name` capture in the match +--- for _, node in ipairs(nodes) do +--- -- `node` was captured by the `name` capture in the match --- ---- local node_data = metadata[id] -- Node level metadata ---- ---- -- ... use the info here ... +--- local node_data = metadata[id] -- Node level metadata +--- ... use the info here ... +--- end --- end --- end --- ``` --- +--- ---@param node TSNode under which the search will occur ---@param source (integer|string) Source buffer or string to search ----@param start integer Starting line for the search ----@param stop integer Stopping line for the search (end-exclusive) ----@param opts table|nil Options: +---@param start? integer Starting line for the search. Defaults to `node:start()`. +---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. +---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. ---- Requires treesitter >= 0.20.9. +--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. +--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is +--- incorrect behavior. This option will eventually become the default and removed. --- ----@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata +---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) + local all = opts and opts.all if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) - ---@cast raw_iter fun(): string, any + local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch local function iter() local pattern, match = raw_iter() local metadata = {} @@ -771,14 +929,33 @@ function Query:iter_matches(node, source, start, stop, opts) self:apply_directives(match, pattern, source, metadata) end + + if not all then + -- Convert the match table into the old buggy version for backward + -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" + -- option! + local old_match = {} ---@type table<integer, TSNode> + for k, v in pairs(match or {}) do + old_match[k] = v[#v] + end + return pattern, old_match, metadata + end + return pattern, match, metadata end return iter end ----@class QueryLinterOpts ----@field langs (string|string[]|nil) ----@field clear (boolean) +--- Optional keyword arguments: +--- @class vim.treesitter.query.lint.Opts +--- @inlinedoc +--- +--- Language(s) to use for checking the query. +--- If multiple languages are specified, queries are validated for all of them +--- @field langs? string|string[] +--- +--- Just clear current lint errors +--- @field clear boolean --- Lint treesitter queries using installed parser, or clear lint errors. --- @@ -793,15 +970,12 @@ end --- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the --- `lua` language will be used. ---@param buf (integer) Buffer handle ----@param opts (QueryLinterOpts|nil) Optional keyword arguments: ---- - langs (string|string[]|nil) Language(s) to use for checking the query. ---- If multiple languages are specified, queries are validated for all of them ---- - clear (boolean) if `true`, just clear current lint errors +---@param opts? vim.treesitter.query.lint.Opts function M.lint(buf, opts) if opts and opts.clear then - require('vim.treesitter._query_linter').clear(buf) + vim.treesitter._query_linter.clear(buf) else - require('vim.treesitter._query_linter').lint(buf, opts) + vim.treesitter._query_linter.lint(buf, opts) end end @@ -813,13 +987,15 @@ end --- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc' --- ``` --- +--- @param findstart 0|1 +--- @param base string function M.omnifunc(findstart, base) - return require('vim.treesitter._query_linter').omnifunc(findstart, base) + return vim.treesitter._query_linter.omnifunc(findstart, base) end --- Opens a live editor to query the buffer you started from. --- ---- Can also be shown with *:EditQuery*. +--- Can also be shown with [:EditQuery](). --- --- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in --- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find @@ -827,7 +1003,7 @@ end --- --- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype. function M.edit(lang) - require('vim.treesitter.dev').edit_query(lang) + vim.treesitter.dev.edit_query(lang) end return M |