diff options
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 207 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta.lua | 53 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_query_linter.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/dev.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/health.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 77 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 3 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 98 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 225 |
9 files changed, 399 insertions, 270 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index d96cc966de..eecf1ad6b1 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -4,10 +4,21 @@ local Range = require('vim.treesitter._range') local api = vim.api +---Treesitter folding is done in two steps: +---(1) compute the fold levels with the syntax tree and cache the result (`compute_folds_levels`) +---(2) evaluate foldexpr for each window, which reads from the cache (`foldupdate`) ---@class TS.FoldInfo ----@field levels string[] the foldexpr result for each line ----@field levels0 integer[] the raw fold levels ----@field edits? {[1]: integer, [2]: integer} line range edited since the last invocation of the callback scheduled in on_bytes. 0-indexed, end-exclusive. +--- +---@field levels string[] the cached foldexpr result for each line +---@field levels0 integer[] the cached raw fold levels +--- +---The range edited since the last invocation of the callback scheduled in on_bytes. +---Should compute fold levels in this range. +---@field on_bytes_range? Range2 +--- +---The range on which to evaluate foldexpr. +---When in insert mode, the evaluation is deferred to InsertLeave. +---@field foldupdate_range? Range2 local FoldInfo = {} FoldInfo.__index = FoldInfo @@ -80,45 +91,16 @@ function FoldInfo:add_range(srow, erow) list_insert(self.levels0, srow + 1, erow, -1) end ----@package +---@param range Range2 ---@param srow integer ---@param erow_old integer ---@param erow_new integer 0-indexed, exclusive -function FoldInfo:edit_range(srow, erow_old, erow_new) - if self.edits then - self.edits[1] = math.min(srow, self.edits[1]) - if erow_old <= self.edits[2] then - self.edits[2] = self.edits[2] + (erow_new - erow_old) - end - self.edits[2] = math.max(self.edits[2], erow_new) - else - self.edits = { srow, erow_new } +local function edit_range(range, srow, erow_old, erow_new) + range[1] = math.min(srow, range[1]) + if erow_old <= range[2] then + range[2] = range[2] + (erow_new - erow_old) end -end - ----@package ----@return integer? srow ----@return integer? erow 0-indexed, exclusive -function FoldInfo:flush_edit() - if self.edits then - local srow, erow = self.edits[1], self.edits[2] - self.edits = nil - return srow, erow - end -end - ---- If a parser doesn't have any ranges explicitly set, treesitter will ---- return a range with end_row and end_bytes with a value of UINT32_MAX, ---- so clip end_row to the max buffer line. ---- ---- TODO(lewis6991): Handle this generally ---- ---- @param bufnr integer ---- @param erow integer? 0-indexed, exclusive ---- @return integer -local function normalise_erow(bufnr, erow) - local max_erow = api.nvim_buf_line_count(bufnr) - return math.min(erow or max_erow, max_erow) + range[2] = math.max(range[2], erow_new) end -- TODO(lewis6991): Setup a decor provider so injections folds can be parsed @@ -128,9 +110,9 @@ end ---@param srow integer? ---@param erow integer? 0-indexed, exclusive ---@param parse_injections? boolean -local function get_folds_levels(bufnr, info, srow, erow, parse_injections) +local function compute_folds_levels(bufnr, info, srow, erow, parse_injections) srow = srow or 0 - erow = normalise_erow(bufnr, erow) + erow = erow or api.nvim_buf_line_count(bufnr) local parser = ts.get_parser(bufnr) @@ -149,27 +131,43 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) -- Collect folds starting from srow - 1, because we should first subtract the folds that end at -- srow - 1 from the level of srow - 1 to get accurate level of srow. - for id, node, metadata in query:iter_captures(tree:root(), bufnr, math.max(srow - 1, 0), erow) do - if query.captures[id] == 'fold' then - local range = ts.get_range(node, bufnr, metadata[id]) - local start, _, stop, stop_col = Range.unpack4(range) - - if stop_col == 0 then - stop = stop - 1 - end - - local fold_length = stop - start + 1 - - -- Fold only multiline nodes that are not exactly the same as previously met folds - -- Checking against just the previously found fold is sufficient if nodes - -- are returned in preorder or postorder when traversing tree - if - fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) - then - enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 - leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 - prev_start = start - prev_stop = stop + for _, match, metadata in + query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow, { all = true }) + do + for id, nodes in pairs(match) do + if query.captures[id] == 'fold' then + local range = ts.get_range(nodes[1], bufnr, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) + + for i = 2, #nodes, 1 do + local node_range = ts.get_range(nodes[i], bufnr, metadata[id]) + local node_start, _, node_stop, node_stop_col = Range.unpack4(node_range) + if node_start < start then + start = node_start + end + if node_stop > stop then + stop = node_stop + stop_col = node_stop_col + end + end + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 + leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 + prev_start = start + prev_stop = stop + end end end end @@ -215,7 +213,7 @@ local function get_folds_levels(bufnr, info, srow, erow, parse_injections) clamped = nestmax end - -- Record the "real" level, so that it can be used as "base" of later get_folds_levels(). + -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels(). info.levels0[lnum] = adjusted info.levels[lnum] = prefix .. tostring(clamped) @@ -236,18 +234,17 @@ local group = api.nvim_create_augroup('treesitter/fold', {}) --- --- Nvim usually automatically updates folds when text changes, but it doesn't work here because --- FoldInfo update is scheduled. So we do it manually. -local function foldupdate(bufnr) - local function do_update() - for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do - api.nvim_win_call(win, function() - if vim.wo.foldmethod == 'expr' then - vim._foldupdate() - end - end) - end +---@package +---@param srow integer +---@param erow integer 0-indexed, exclusive +function FoldInfo:foldupdate(bufnr, srow, erow) + if self.foldupdate_range then + edit_range(self.foldupdate_range, srow, erow, erow) + else + self.foldupdate_range = { srow, erow } end - if api.nvim_get_mode().mode == 'i' then + if api.nvim_get_mode().mode:match('^i') then -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave if #(api.nvim_get_autocmds({ group = group, @@ -259,12 +256,25 @@ local function foldupdate(bufnr) group = group, buffer = bufnr, once = true, - callback = do_update, + callback = function() + self:do_foldupdate(bufnr) + end, }) return end - do_update() + self:do_foldupdate(bufnr) +end + +---@package +function FoldInfo:do_foldupdate(bufnr) + local srow, erow = self.foldupdate_range[1], self.foldupdate_range[2] + self.foldupdate_range = nil + for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do + if vim.wo[win].foldmethod == 'expr' then + vim._foldupdate(win, srow, erow) + end + end end --- Schedule a function only if bufnr is loaded. @@ -272,7 +282,7 @@ end --- * queries seem to use the old buffer state in on_bytes for some unknown reason; --- * to avoid textlock; --- * to avoid infinite recursion: ---- get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels. +--- compute_folds_levels → parse → _do_callback → on_changedtree → compute_folds_levels. ---@param bufnr integer ---@param fn function local function schedule_if_loaded(bufnr, fn) @@ -289,16 +299,27 @@ end ---@param tree_changes Range4[] local function on_changedtree(bufnr, foldinfo, tree_changes) schedule_if_loaded(bufnr, function() + local srow_upd, erow_upd ---@type integer?, integer? + local max_erow = api.nvim_buf_line_count(bufnr) for _, change in ipairs(tree_changes) do local srow, _, erow, ecol = Range.unpack4(change) - if ecol > 0 then + -- If a parser doesn't have any ranges explicitly set, treesitter will + -- return a range with end_row and end_bytes with a value of UINT32_MAX, + -- so clip end_row to the max buffer line. + -- TODO(lewis6991): Handle this generally + if erow > max_erow then + erow = max_erow + elseif ecol > 0 then erow = erow + 1 end -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. - get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) + srow = math.max(srow - vim.wo.foldminlines, 0) + compute_folds_levels(bufnr, foldinfo, srow, erow) + srow_upd = srow_upd and math.min(srow_upd, srow) or srow + erow_upd = erow_upd and math.max(erow_upd, erow) or erow end if #tree_changes > 0 then - foldupdate(bufnr) + foldinfo:foldupdate(bufnr, srow_upd, erow_upd) end end) end @@ -335,19 +356,29 @@ local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, foldinfo:add_range(end_row_old, end_row_new) end end - foldinfo:edit_range(start_row, end_row_old, end_row_new) + + if foldinfo.on_bytes_range then + edit_range(foldinfo.on_bytes_range, start_row, end_row_old, end_row_new) + else + foldinfo.on_bytes_range = { start_row, end_row_new } + end + if foldinfo.foldupdate_range then + edit_range(foldinfo.foldupdate_range, start_row, end_row_old, end_row_new) + end -- This callback must not use on_bytes arguments, because they can be outdated when the callback -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing - -- the scheduled callback. So we should collect the edits. + -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`. schedule_if_loaded(bufnr, function() - local srow, erow = foldinfo:flush_edit() - if not srow then + if not foldinfo.on_bytes_range then return end + local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2] + foldinfo.on_bytes_range = nil -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. - get_folds_levels(bufnr, foldinfo, math.max(srow - vim.wo.foldminlines, 0), erow) - foldupdate(bufnr) + srow = math.max(srow - vim.wo.foldminlines, 0) + compute_folds_levels(bufnr, foldinfo, srow, erow) + foldinfo:foldupdate(bufnr, srow, erow) end) end end @@ -366,7 +397,7 @@ function M.foldexpr(lnum) if not foldinfos[bufnr] then foldinfos[bufnr] = FoldInfo.new() - get_folds_levels(bufnr, foldinfos[bufnr]) + compute_folds_levels(bufnr, foldinfos[bufnr]) parser:register_cbs({ on_changedtree = function(tree_changes) @@ -390,10 +421,10 @@ api.nvim_create_autocmd('OptionSet', { pattern = { 'foldminlines', 'foldnestmax' }, desc = 'Refresh treesitter folds', callback = function() - for _, bufnr in ipairs(vim.tbl_keys(foldinfos)) do + for bufnr, _ in pairs(foldinfos) do foldinfos[bufnr] = FoldInfo.new() - get_folds_levels(bufnr, foldinfos[bufnr]) - foldupdate(bufnr) + compute_folds_levels(bufnr, foldinfos[bufnr]) + foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr)) end end, }) diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua index 19d97d2820..177699a207 100644 --- a/runtime/lua/vim/treesitter/_meta.lua +++ b/runtime/lua/vim/treesitter/_meta.lua @@ -20,6 +20,7 @@ error('Cannot require a meta file') ---@field descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode? ---@field named_descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode? ---@field parent fun(self: TSNode): TSNode? +---@field child_containing_descendant fun(self: TSNode, descendant: TSNode): TSNode? ---@field next_sibling fun(self: TSNode): TSNode? ---@field prev_sibling fun(self: TSNode): TSNode? ---@field next_named_sibling fun(self: TSNode): TSNode? @@ -34,22 +35,6 @@ error('Cannot require a meta file') ---@field byte_length fun(self: TSNode): integer local TSNode = {} ----@param query TSQuery ----@param captures true ----@param start? integer ----@param end_? integer ----@param opts? table ----@return fun(): integer, TSNode, vim.treesitter.query.TSMatch -function TSNode:_rawquery(query, captures, start, end_, opts) end - ----@param query TSQuery ----@param captures false ----@param start? integer ----@param end_? integer ----@param opts? table ----@return fun(): integer, vim.treesitter.query.TSMatch -function TSNode:_rawquery(query, captures, start, end_, opts) end - ---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string) ---@class TSParser: userdata @@ -76,9 +61,17 @@ function TSNode:_rawquery(query, captures, start, end_, opts) end ---@field captures string[] ---@field patterns table<integer, (integer|string)[][]> +--- @param lang string +vim._ts_inspect_language = function(lang) end + ---@return integer vim._ts_get_language_version = function() end +--- @param path string +--- @param lang string +--- @param symbol_name? string +vim._ts_add_language = function(path, lang, symbol_name) end + ---@return integer vim._ts_get_minimum_language_version = function() end @@ -90,3 +83,31 @@ vim._ts_parse_query = function(lang, query) end ---@param lang string ---@return TSParser vim._create_ts_parser = function(lang) end + +--- @class TSQueryMatch: userdata +--- @field captures fun(self: TSQueryMatch): table<integer,TSNode[]> +local TSQueryMatch = {} + +--- @return integer match_id +--- @return integer pattern_index +function TSQueryMatch:info() end + +--- @class TSQueryCursor: userdata +--- @field remove_match fun(self: TSQueryCursor, id: integer) +local TSQueryCursor = {} + +--- @return integer capture +--- @return TSNode captured_node +--- @return TSQueryMatch match +function TSQueryCursor:next_capture() end + +--- @return TSQueryMatch match +function TSQueryCursor:next_match() end + +--- @param node TSNode +--- @param query TSQuery +--- @param start integer? +--- @param stop integer? +--- @param opts? { max_start_depth?: integer, match_limit?: integer} +--- @return TSQueryCursor +function vim._create_ts_querycursor(node, query, start, stop, opts) end diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua index 6216d4e891..12b4cbc7b9 100644 --- a/runtime/lua/vim/treesitter/_query_linter.lua +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang) end) --- @param buf integer ---- @param match vim.treesitter.query.TSMatch +--- @param match table<integer,TSNode[]> --- @param query vim.treesitter.Query --- @param lang_context QueryLinterLanguageContext --- @param diagnostics vim.Diagnostic[] diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua index dc2a14d238..5c91f101c0 100644 --- a/runtime/lua/vim/treesitter/dev.lua +++ b/runtime/lua/vim/treesitter/dev.lua @@ -226,7 +226,7 @@ function TSTreeView:draw(bufnr) text = string.format('(%s', item.node:type()) end else - text = string.format('"%s"', item.node:type():gsub('\n', '\\n'):gsub('"', '\\"')) + text = string.format('%q', item.node:type()):gsub('\n', 'n') end local next = self:get(i + 1) diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua index a9b066d158..ed3616ef46 100644 --- a/runtime/lua/vim/treesitter/health.lua +++ b/runtime/lua/vim/treesitter/health.lua @@ -24,7 +24,7 @@ function M.check() else local lang = ts.language.inspect(parsername) health.ok( - string.format('Parser: %-10s ABI: %d, path: %s', parsername, lang._abi_version, parser) + string.format('Parser: %-20s ABI: %d, path: %s', parsername, lang._abi_version, parser) ) end end diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 388680259a..d2f986b874 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -4,7 +4,7 @@ local Range = require('vim.treesitter._range') local ns = api.nvim_create_namespace('treesitter/highlighter') ----@alias vim.treesitter.highlighter.Iter fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata +---@alias vim.treesitter.highlighter.Iter fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch ---@class (private) vim.treesitter.highlighter.Query ---@field private _query vim.treesitter.Query? @@ -215,7 +215,7 @@ end ---@param start_row integer ---@param new_end integer function TSHighlighter:on_bytes(_, _, start_row, _, _, _, _, _, new_end) - api.nvim__buf_redraw_range(self.bufnr, start_row, start_row + new_end + 1) + api.nvim__redraw({ buf = self.bufnr, range = { start_row, start_row + new_end + 1 } }) end ---@package @@ -227,7 +227,7 @@ end ---@param changes Range6[] function TSHighlighter:on_changedtree(changes) for _, ch in ipairs(changes) do - api.nvim__buf_redraw_range(self.bufnr, ch[1], ch[4] + 1) + api.nvim__redraw({ buf = self.bufnr, range = { ch[1], ch[4] + 1 } }) end end @@ -243,6 +243,46 @@ function TSHighlighter:get_query(lang) return self._queries[lang] end +--- @param match TSQueryMatch +--- @param bufnr integer +--- @param capture integer +--- @param metadata vim.treesitter.query.TSMetadata +--- @return string? +local function get_url(match, bufnr, capture, metadata) + ---@type string|number|nil + local url = metadata[capture] and metadata[capture].url + + if not url or type(url) == 'string' then + return url + end + + local captures = match:captures() + + if not captures[url] then + return + end + + -- Assume there is only one matching node. If there is more than one, take the URL + -- from the first. + local other_node = captures[url][1] + + return vim.treesitter.get_node_text(other_node, bufnr, { + metadata = metadata[url], + }) +end + +--- @param capture_name string +--- @return boolean?, integer +local function get_spell(capture_name) + if capture_name == 'spell' then + return true, 0 + elseif capture_name == 'nospell' then + -- Give nospell a higher priority so it always overrides spell captures. + return false, 1 + end + return nil, 0 +end + ---@param self vim.treesitter.highlighter ---@param buf integer ---@param line integer @@ -258,12 +298,16 @@ local function on_line_impl(self, buf, line, is_spell_nav) end if state.iter == nil or state.next_row < line then + -- Mainly used to skip over folds + + -- TODO(lewis6991): Creating a new iterator loses the cached predicate results for query + -- matches. Move this logic inside iter_captures() so we can maintain the cache. state.iter = state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end while line >= state.next_row do - local capture, node, metadata = state.iter(line) + local capture, node, metadata, match = state.iter(line) local range = { root_end_row + 1, 0, root_end_row + 1, 0 } if node then @@ -275,27 +319,30 @@ local function on_line_impl(self, buf, line, is_spell_nav) local hl = state.highlighter_query:get_hl_from_capture(capture) local capture_name = state.highlighter_query:query().captures[capture] - local spell = nil ---@type boolean? - if capture_name == 'spell' then - spell = true - elseif capture_name == 'nospell' then - spell = false - end - -- Give nospell a higher priority so it always overrides spell captures. - local spell_pri_offset = capture_name == 'nospell' and 1 or 0 + local spell, spell_pri_offset = get_spell(capture_name) + + -- The "priority" attribute can be set at the pattern level or on a particular capture + local priority = ( + tonumber(metadata.priority or metadata[capture] and metadata[capture].priority) + or vim.highlight.priorities.treesitter + ) + spell_pri_offset + + -- The "conceal" attribute can be set at the pattern level or on a particular capture + local conceal = metadata.conceal or metadata[capture] and metadata[capture].conceal + + local url = get_url(match, buf, capture, metadata) if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then - local priority = (tonumber(metadata.priority) or vim.highlight.priorities.treesitter) - + spell_pri_offset api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { end_line = end_row, end_col = end_col, hl_group = hl, ephemeral = true, priority = priority, - conceal = metadata.conceal, + conceal = conceal, spell = spell, + url = url, }) end end diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index 47abf65332..d0a74daa6c 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -88,6 +88,9 @@ function M.add(lang, opts) filetype = { filetype, { 'string', 'table' }, true }, }) + -- parser names are assumed to be lowercase (consistent behavior on case-insensitive file systems) + lang = lang:lower() + if vim._ts_has_language(lang) then M.register(lang, filetype) return diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 62714d3f1b..b0812123b9 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -81,7 +81,7 @@ local TSCallbackNames = { ---List of regions this tree should manage and parse. If nil then regions are ---taken from _trees. This is mostly a short-lived cache for included_regions() ---@field private _lang string Language name ----@field private _parent_lang? string Parent language name +---@field private _parent? vim.treesitter.LanguageTree Parent LanguageTree ---@field private _source (integer|string) Buffer or string to parse ---@field private _trees table<integer, TSTree> Reference to parsed tree (one for each language). ---Each key is the index of region, which is synced with _regions and _valid. @@ -106,9 +106,8 @@ LanguageTree.__index = LanguageTree ---@param source (integer|string) Buffer or text string to parse ---@param lang string Root language of this tree ---@param opts vim.treesitter.LanguageTree.new.Opts? ----@param parent_lang? string Parent language name of this tree ---@return vim.treesitter.LanguageTree parser object -function LanguageTree.new(source, lang, opts, parent_lang) +function LanguageTree.new(source, lang, opts) language.add(lang) opts = opts or {} @@ -122,7 +121,6 @@ function LanguageTree.new(source, lang, opts, parent_lang) local self = { _source = source, _lang = lang, - _parent_lang = parent_lang, _children = {}, _trees = {}, _opts = opts, @@ -158,8 +156,10 @@ function LanguageTree:_set_logger() local lang = self:lang() - vim.fn.mkdir(vim.fn.stdpath('log'), 'p') - local logfilename = vim.fs.joinpath(vim.fn.stdpath('log'), 'treesitter.log') + local logdir = vim.fn.stdpath('log') --[[@as string]] + + vim.fn.mkdir(logdir, 'p') + local logfilename = vim.fs.joinpath(logdir, 'treesitter.log') local logfile, openerr = io.open(logfilename, 'a+') @@ -225,7 +225,10 @@ function LanguageTree:_log(...) self._logger('nvim', table.concat(msg, ' ')) end ---- Invalidates this parser and all its children +--- Invalidates this parser and its children. +--- +--- Should only be called when the tracked state of the LanguageTree is not valid against the parse +--- tree in treesitter. Doesn't clear filesystem cache. Called often, so needs to be fast. ---@param reload boolean|nil function LanguageTree:invalidate(reload) self._valid = false @@ -460,24 +463,6 @@ function LanguageTree:parse(range) return self._trees end ----@deprecated Misleading name. Use `LanguageTree:children()` (non-recursive) instead, ---- add recursion yourself if needed. ---- Invokes the callback for each |LanguageTree| and its children recursively ---- ----@param fn fun(tree: vim.treesitter.LanguageTree, lang: string) ----@param include_self boolean|nil Whether to include the invoking tree in the results -function LanguageTree:for_each_child(fn, include_self) - vim.deprecate('LanguageTree:for_each_child()', 'LanguageTree:children()', '0.11') - if include_self then - fn(self, self._lang) - end - - for _, child in pairs(self._children) do - --- @diagnostic disable-next-line:deprecated - child:for_each_child(fn, true) - end -end - --- Invokes the callback for each |LanguageTree| recursively. --- --- Note: This includes the invoking tree's child trees as well. @@ -505,19 +490,25 @@ function LanguageTree:add_child(lang) self:remove_child(lang) end - local child = LanguageTree.new(self._source, lang, self._opts, self:lang()) + local child = LanguageTree.new(self._source, lang, self._opts) -- Inherit recursive callbacks for nm, cb in pairs(self._callbacks_rec) do vim.list_extend(child._callbacks_rec[nm], cb) end + child._parent = self self._children[lang] = child self:_do_callback('child_added', self._children[lang]) return self._children[lang] end +--- @package +function LanguageTree:parent() + return self._parent +end + --- Removes a child language from this |LanguageTree|. --- ---@private @@ -752,7 +743,6 @@ local has_parser = vim.func._memoize(1, function(lang) end) --- Return parser name for language (if exists) or filetype (if registered and exists). ---- Also attempts with the input lower-cased. --- ---@param alias string language or filetype name ---@return string? # resolved parser name @@ -766,19 +756,10 @@ local function resolve_lang(alias) return alias end - if has_parser(alias:lower()) then - return alias:lower() - end - local lang = vim.treesitter.language.get_lang(alias) if lang and has_parser(lang) then return lang end - - lang = vim.treesitter.language.get_lang(alias:lower()) - if lang and has_parser(lang) then - return lang - end end ---@private @@ -792,7 +773,7 @@ function LanguageTree:_get_injection(match, metadata) local combined = metadata['injection.combined'] ~= nil local injection_lang = metadata['injection.language'] --[[@as string?]] local lang = metadata['injection.self'] ~= nil and self:lang() - or metadata['injection.parent'] ~= nil and self._parent_lang + or metadata['injection.parent'] ~= nil and self._parent:lang() or (injection_lang and resolve_lang(injection_lang)) local include_children = metadata['injection.include-children'] ~= nil @@ -802,7 +783,11 @@ function LanguageTree:_get_injection(match, metadata) -- Lang should override any other language tag if name == 'injection.language' then local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) - lang = resolve_lang(text) + lang = resolve_lang(text:lower()) -- language names are always lower case + elseif name == 'injection.filename' then + local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) + local ft = vim.filetype.match({ filename = text }) + lang = ft and resolve_lang(ft) elseif name == 'injection.content' then ranges = get_node_ranges(node, self._source, metadata[id], include_children) end @@ -1054,20 +1039,19 @@ function LanguageTree:_on_detach(...) end end ---- Registers callbacks for the |LanguageTree|. ----@param cbs table An |nvim_buf_attach()|-like table argument with the following handlers: ---- - `on_bytes` : see |nvim_buf_attach()|, but this will be called _after_ the parsers callback. +--- Registers callbacks for the [LanguageTree]. +---@param cbs table<TSCallbackNameOn,function> An [nvim_buf_attach()]-like table argument with the following handlers: +--- - `on_bytes` : see [nvim_buf_attach()], but this will be called _after_ the parsers callback. --- - `on_changedtree` : a callback that will be called every time the tree has syntactical changes. --- It will be passed two arguments: a table of the ranges (as node ranges) that --- changed and the changed tree. --- - `on_child_added` : emitted when a child is added to the tree. --- - `on_child_removed` : emitted when a child is removed from the tree. ---- - `on_detach` : emitted when the buffer is detached, see |nvim_buf_detach_event|. +--- - `on_detach` : emitted when the buffer is detached, see [nvim_buf_detach_event]. --- Takes one argument, the number of the buffer. --- @param recursive? boolean Apply callbacks recursively for all children. Any new children will --- also inherit the callbacks. function LanguageTree:register_cbs(cbs, recursive) - ---@cast cbs table<TSCallbackNameOn,function> if not cbs then return end @@ -1091,7 +1075,14 @@ end ---@param range Range ---@return boolean local function tree_contains(tree, range) - return Range.contains({ tree:root():range() }, range) + local tree_ranges = tree:included_ranges(false) + + return Range.contains({ + tree_ranges[1][1], + tree_ranges[1][2], + tree_ranges[#tree_ranges][3], + tree_ranges[#tree_ranges][4], + }, range) end --- Determines whether {range} is contained in the |LanguageTree|. @@ -1108,12 +1099,18 @@ function LanguageTree:contains(range) return false end +--- @class vim.treesitter.LanguageTree.tree_for_range.Opts +--- @inlinedoc +--- +--- Ignore injected languages +--- (default: `true`) +--- @field ignore_injections? boolean + --- Gets the tree that contains {range}. --- ---@param range Range4 `{ start_line, start_col, end_line, end_col }` ----@param opts table|nil Optional keyword arguments: ---- - ignore_injections boolean Ignore injected languages (default true) ----@return TSTree|nil +---@param opts? vim.treesitter.LanguageTree.tree_for_range.Opts +---@return TSTree? function LanguageTree:tree_for_range(range, opts) opts = opts or {} local ignore = vim.F.if_nil(opts.ignore_injections, true) @@ -1139,9 +1136,8 @@ end --- Gets the smallest named node that contains {range}. --- ---@param range Range4 `{ start_line, start_col, end_line, end_col }` ----@param opts table|nil Optional keyword arguments: ---- - ignore_injections boolean Ignore injected languages (default true) ----@return TSNode | nil Found node +---@param opts? vim.treesitter.LanguageTree.tree_for_range.Opts +---@return TSNode? function LanguageTree:named_node_for_range(range, opts) local tree = self:tree_for_range(range, opts) if tree then @@ -1152,7 +1148,7 @@ end --- Gets the appropriate language that contains {range}. --- ---@param range Range4 `{ start_line, start_col, end_line, end_col }` ----@return vim.treesitter.LanguageTree Managing {range} +---@return vim.treesitter.LanguageTree tree Managing {range} function LanguageTree:language_for_range(range) for _, child in pairs(self._children) do if child:contains(range) then diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index a086f5e876..ef5c2143a7 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,5 +1,6 @@ local api = vim.api local language = require('vim.treesitter.language') +local memoize = vim.func._memoize local M = {} @@ -88,7 +89,7 @@ end --- ---@param lang string Language to get query for ---@param query_name string Name of the query to load (e.g., "highlights") ----@param is_included (boolean|nil) Internal parameter, most of the time left as `nil` +---@param is_included? boolean Internal parameter, most of the time left as `nil` ---@return string[] query_files List of files to load for given query and language function M.get_files(lang, query_name, is_included) local query_path = string.format('queries/%s/%s.scm', lang, query_name) @@ -211,8 +212,8 @@ end ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g. "highlights") --- ----@return vim.treesitter.Query|nil : Parsed query. `nil` if no query files are found. -M.get = vim.func._memoize('concat-2', function(lang, query_name) +---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found. +M.get = memoize('concat-2', function(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] end @@ -242,10 +243,10 @@ end) ---@param lang string Language to use for the query ---@param query string Query in s-expr syntax --- ----@return vim.treesitter.Query Parsed query +---@return vim.treesitter.Query : Parsed query --- ----@see |vim.treesitter.query.get()| -M.parse = vim.func._memoize('concat-2', function(lang, query) +---@see [vim.treesitter.query.get()] +M.parse = memoize('concat-2', function(lang, query) language.add(lang) local ts_query = vim._ts_parse_query(lang, query) @@ -258,7 +259,7 @@ end) --- handling the "any" vs "all" semantics. They are called from the --- predicate_handlers table with the appropriate arguments for each predicate. local impl = { - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -293,7 +294,7 @@ local impl = { return not any end, - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -333,7 +334,7 @@ local impl = { end, }) - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -356,7 +357,7 @@ local impl = { end end)(), - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -383,13 +384,7 @@ local impl = { end, } ----@nodoc ----@class vim.treesitter.query.TSMatch ----@field pattern? integer ----@field active? boolean ----@field [integer] TSNode[] - ----@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean +---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -462,17 +457,8 @@ local predicate_handlers = { end for _, node in ipairs(nodes) do - local ancestor_types = {} --- @type table<string, boolean> - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end - - local cur = node:parent() - while cur do - if ancestor_types[cur:type()] then - return true - end - cur = cur:parent() + if node:__has_ancestor(predicate) then + return true end end return false @@ -504,7 +490,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@field [integer] vim.treesitter.query.TSMetadata ---@field [string] integer|string ----@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) +---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -534,6 +520,9 @@ local directive_handlers = { ['offset!'] = function(match, _, _, pred, metadata) local capture_id = pred[2] --[[@as integer]] local nodes = match[capture_id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#offset! does not support captures on multiple nodes') local node = nodes[1] @@ -567,6 +556,9 @@ local directive_handlers = { assert(type(id) == 'number') local nodes = match[id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#gsub! does not support captures on multiple nodes') local node = nodes[1] local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' @@ -589,6 +581,9 @@ local directive_handlers = { assert(type(capture_id) == 'number') local nodes = match[capture_id] + if not nodes or #nodes == 0 then + return + end assert(#nodes == 1, '#trim! does not support captures on multiple nodes') local node = nodes[1] @@ -618,20 +613,23 @@ local directive_handlers = { end, } +--- @class vim.treesitter.query.add_predicate.Opts +--- @inlinedoc +--- +--- Override an existing predicate of the same name +--- @field force? boolean +--- +--- Use the correct implementation of the match table where capture IDs map to +--- a list of nodes instead of a single node. Defaults to false (for backward +--- compatibility). This option will eventually become the default and removed. +--- @field all? boolean + --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) +---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - see |vim.treesitter.query.add_directive()| for argument meanings ----@param opts table<string, any> Optional options: ---- - force (boolean): Override an existing ---- predicate of the same name ---- - all (boolean): Use the correct ---- implementation of the match table where ---- capture IDs map to a list of nodes instead ---- of a single node. Defaults to false (for ---- backward compatibility). This option will ---- eventually become the default and removed. +---@param opts vim.treesitter.query.add_predicate.Opts function M.add_predicate(name, handler, opts) -- Backward compatibility: old signature had "force" as boolean argument if type(opts) == 'boolean' then @@ -669,20 +667,12 @@ end --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) +---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: table) --- - match: A table mapping capture IDs to a list of captured nodes --- - pattern: the index of the matching pattern in the query file --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` ----@param opts table<string, any> Optional options: ---- - force (boolean): Override an existing ---- predicate of the same name ---- - all (boolean): Use the correct ---- implementation of the match table where ---- capture IDs map to a list of nodes instead ---- of a single node. Defaults to false (for ---- backward compatibility). This option will ---- eventually become the default and removed. +---@param opts vim.treesitter.query.add_predicate.Opts function M.add_directive(name, handler, opts) -- Backward compatibility: old signature had "force" as boolean argument if type(opts) == 'boolean' then @@ -711,13 +701,13 @@ function M.add_directive(name, handler, opts) end --- Lists the currently available directives to use in queries. ----@return string[] List of supported directives. +---@return string[] : Supported directives. function M.list_directives() return vim.tbl_keys(directive_handlers) end --- Lists the currently available predicates to use in queries. ----@return string[] List of supported predicates. +---@return string[] : Supported predicates. function M.list_predicates() return vim.tbl_keys(predicate_handlers) end @@ -731,13 +721,19 @@ local function is_directive(name) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param pattern integer +---@param match TSQueryMatch ---@param source integer|string -function Query:match_preds(match, pattern, source) +function Query:match_preds(match, source) + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return true + end + + local captures = match:captures() + + for _, pred in pairs(preds) do -- Here we only want to return if a predicate DOES NOT match, and -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). @@ -759,7 +755,7 @@ function Query:match_preds(match, pattern, source) return false end - local pred_matches = handler(match, pattern, source, pred) + local pred_matches = handler(captures, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -770,30 +766,40 @@ function Query:match_preds(match, pattern, source) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param metadata vim.treesitter.query.TSMetadata -function Query:apply_directives(match, pattern, source, metadata) +---@param match TSQueryMatch +---@return vim.treesitter.query.TSMetadata metadata +function Query:apply_directives(match, source) + ---@type vim.treesitter.query.TSMetadata + local metadata = {} + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return metadata + end + + local captures = match:captures() + + for _, pred in pairs(preds) do if is_directive(pred[1]) then local handler = directive_handlers[pred[1]] if not handler then error(string.format('No handler for %s', pred[1])) - return end - handler(match, pattern, source, pred, metadata) + handler(captures, pattern, source, pred, metadata) end end + + return metadata end --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. ----@param start integer|nil ----@param stop integer|nil +---@param start integer? +---@param stop integer? ---@param node TSNode ---@return integer, integer local function value_or_node_range(start, stop, node) @@ -807,6 +813,12 @@ local function value_or_node_range(start, stop, node) return start, stop end +--- @param match TSQueryMatch +--- @return integer +local function match_id_hash(_, match) + return (match:info()) +end + --- Iterate over all captures from all matches inside {node} --- --- {source} is needed if the query contains predicates; then the caller @@ -816,12 +828,13 @@ end --- as the {node}, i.e., to get syntax highlight matches in the current --- viewport). When omitted, the {start} and {stop} row values are used from the given node. --- ---- The iterator returns three values: a numeric id identifying the capture, ---- the captured node, and metadata from any directives processing the match. +--- The iterator returns four values: a numeric id identifying the capture, +--- the captured node, metadata from any directives processing the match, +--- and the match itself. --- The following example shows how to get captures by name: --- --- ```lua ---- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do +--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do --- local name = query.captures[id] -- name of the capture in the query --- -- typically useful info about the node: --- local type = node:type() -- type of the captured node @@ -835,8 +848,10 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata): ---- capture id, capture node, metadata +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch): +--- capture id, capture node, metadata, match +--- +---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -844,24 +859,30 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) + + local apply_directives = memoize(match_id_hash, self.apply_directives, true) + local match_preds = memoize(match_id_hash, self.match_preds, true) + local function iter(end_line) - local capture, captured_node, match = raw_iter() - local metadata = {} - - if match ~= nil then - local active = self:match_preds(match, match.pattern, source) - match.active = active - if not active then - if end_line and captured_node:range() > end_line then - return nil, captured_node, nil - end - return iter(end_line) -- tail call: try next match - end + local capture, captured_node, match = cursor:next_capture() - self:apply_directives(match, match.pattern, source, metadata) + if not capture then + return end - return capture, captured_node, metadata + + if not match_preds(self, match, source) then + local match_id = match:info() + cursor:remove_match(match_id) + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil, nil + end + return iter(end_line) -- tail call: try next match + end + + local metadata = apply_directives(self, match, source) + + return capture, captured_node, metadata, match end return iter end @@ -903,45 +924,55 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256). --- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. --- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is --- incorrect behavior. This option will eventually become the default and removed. --- ---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) - local all = opts and opts.all + opts = opts or {} + opts.match_limit = opts.match_limit or 256 + if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) + local function iter() - local pattern, match = raw_iter() - local metadata = {} + local match = cursor:next_match() - if match ~= nil then - local active = self:match_preds(match, pattern, source) - if not active then - return iter() -- tail call: try next match - end + if not match then + return + end - self:apply_directives(match, pattern, source, metadata) + local match_id, pattern = match:info() + + if not self:match_preds(match, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match end - if not all then + local metadata = self:apply_directives(match, source) + + local captures = match:captures() + + if not opts.all then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" -- option! local old_match = {} ---@type table<integer, TSNode> - for k, v in pairs(match or {}) do + for k, v in pairs(captures or {}) do old_match[k] = v[#v] end return pattern, old_match, metadata end - return pattern, match, metadata + -- TODO(lewis6991): create a new function that returns {match, metadata} + return pattern, captures, metadata end return iter end |