diff options
author | Josh Rahm <joshuarahm@gmail.com> | 2025-02-05 23:09:29 +0000 |
---|---|---|
committer | Josh Rahm <joshuarahm@gmail.com> | 2025-02-05 23:09:29 +0000 |
commit | d5f194ce780c95821a855aca3c19426576d28ae0 (patch) | |
tree | d45f461b19f9118ad2bb1f440a7a08973ad18832 /runtime/lua/vim/treesitter | |
parent | c5d770d311841ea5230426cc4c868e8db27300a8 (diff) | |
parent | 44740e561fc93afe3ebecfd3618bda2d2abeafb0 (diff) | |
download | rneovim-rahm.tar.gz rneovim-rahm.tar.bz2 rneovim-rahm.zip |
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 325 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta/misc.lua | 8 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta/tsnode.lua | 16 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_query_linter.lua | 6 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/dev.lua | 25 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 30 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 7 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 356 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 418 |
9 files changed, 749 insertions, 442 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 7237d2e7d4..38318347a7 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -19,76 +19,36 @@ local api = vim.api ---The range on which to evaluate foldexpr. ---When in insert mode, the evaluation is deferred to InsertLeave. ---@field foldupdate_range? Range2 +--- +---The treesitter parser associated with this buffer. +---@field parser? vim.treesitter.LanguageTree local FoldInfo = {} FoldInfo.__index = FoldInfo ---@private -function FoldInfo.new() +---@param bufnr integer +function FoldInfo.new(bufnr) return setmetatable({ levels0 = {}, levels = {}, + parser = ts.get_parser(bufnr, nil, { error = false }), }, FoldInfo) end ---- Efficiently remove items from middle of a list a list. ---- ---- Calling table.remove() in a loop will re-index the tail of the table on ---- every iteration, instead this function will re-index the table exactly ---- once. ---- ---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 ---- ----@param t any[] ----@param first integer ----@param last integer -local function list_remove(t, first, last) - local n = #t - for i = 0, n - first do - t[first + i] = t[last + 1 + i] - t[last + 1 + i] = nil - end -end - ---@package ---@param srow integer ---@param erow integer 0-indexed, exclusive function FoldInfo:remove_range(srow, erow) - list_remove(self.levels, srow + 1, erow) - list_remove(self.levels0, srow + 1, erow) -end - ---- Efficiently insert items into the middle of a list. ---- ---- Calling table.insert() in a loop will re-index the tail of the table on ---- every iteration, instead this function will re-index the table exactly ---- once. ---- ---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 ---- ----@param t any[] ----@param first integer ----@param last integer ----@param v any -local function list_insert(t, first, last, v) - local n = #t - - -- Shift table forward - for i = n - first, 0, -1 do - t[last + 1 + i] = t[first + i] - end - - -- Fill in new values - for i = first, last do - t[i] = v - end + vim._list_remove(self.levels, srow + 1, erow) + vim._list_remove(self.levels0, srow + 1, erow) end ---@package ---@param srow integer ---@param erow integer 0-indexed, exclusive function FoldInfo:add_range(srow, erow) - list_insert(self.levels, srow + 1, erow, -1) - list_insert(self.levels0, srow + 1, erow, -1) + vim._list_insert(self.levels, srow + 1, erow, -1) + vim._list_insert(self.levels0, srow + 1, erow, -1) end ---@param range Range2 @@ -109,111 +69,122 @@ end ---@param info TS.FoldInfo ---@param srow integer? ---@param erow integer? 0-indexed, exclusive ----@param parse_injections? boolean -local function compute_folds_levels(bufnr, info, srow, erow, parse_injections) +---@param callback function? +local function compute_folds_levels(bufnr, info, srow, erow, callback) srow = srow or 0 erow = erow or api.nvim_buf_line_count(bufnr) - local parser = assert(ts.get_parser(bufnr, nil, { error = false })) - - parser:parse(parse_injections and { srow, erow } or nil) - - local enter_counts = {} ---@type table<integer, integer> - local leave_counts = {} ---@type table<integer, integer> - local prev_start = -1 - local prev_stop = -1 + local parser = info.parser + if not parser then + return + end - parser:for_each_tree(function(tree, ltree) - local query = ts.query.get(ltree:lang(), 'folds') - if not query then + parser:parse(nil, function(_, trees) + if not trees then return end - -- Collect folds starting from srow - 1, because we should first subtract the folds that end at - -- srow - 1 from the level of srow - 1 to get accurate level of srow. - for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do - for id, nodes in pairs(match) do - if query.captures[id] == 'fold' then - local range = ts.get_range(nodes[1], bufnr, metadata[id]) - local start, _, stop, stop_col = Range.unpack4(range) - - if #nodes > 1 then - -- assumes nodes are ordered by range - local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id]) - local _, _, end_stop, end_stop_col = Range.unpack4(end_range) - stop = end_stop - stop_col = end_stop_col - end + local enter_counts = {} ---@type table<integer, integer> + local leave_counts = {} ---@type table<integer, integer> + local prev_start = -1 + local prev_stop = -1 - if stop_col == 0 then - stop = stop - 1 - end + parser:for_each_tree(function(tree, ltree) + local query = ts.query.get(ltree:lang(), 'folds') + if not query then + return + end - local fold_length = stop - start + 1 - - -- Fold only multiline nodes that are not exactly the same as previously met folds - -- Checking against just the previously found fold is sufficient if nodes - -- are returned in preorder or postorder when traversing tree - if - fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) - then - enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 - leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 - prev_start = start - prev_stop = stop + -- Collect folds starting from srow - 1, because we should first subtract the folds that end at + -- srow - 1 from the level of srow - 1 to get accurate level of srow. + for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do + for id, nodes in pairs(match) do + if query.captures[id] == 'fold' then + local range = ts.get_range(nodes[1], bufnr, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) + + if #nodes > 1 then + -- assumes nodes are ordered by range + local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id]) + local _, _, end_stop, end_stop_col = Range.unpack4(end_range) + stop = end_stop + stop_col = end_stop_col + end + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1 + leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1 + prev_start = start + prev_stop = stop + end end end end - end - end) + end) - local nestmax = vim.wo.foldnestmax - local level0_prev = info.levels0[srow] or 0 - local leave_prev = leave_counts[srow] or 0 - - -- We now have the list of fold opening and closing, fill the gaps and mark where fold start - for lnum = srow + 1, erow do - local enter_line = enter_counts[lnum] or 0 - local leave_line = leave_counts[lnum] or 0 - local level0 = level0_prev - leave_prev + enter_line - - -- Determine if it's the start/end of a fold - -- NB: vim's fold-expr interface does not have a mechanism to indicate that - -- two (or more) folds start at this line, so it cannot distinguish between - -- ( \n ( \n )) \n (( \n ) \n ) - -- versus - -- ( \n ( \n ) \n ( \n ) \n ) - -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and - -- vim interprets as the second case. - -- If it did have such a mechanism, (clamped - clamped_prev) - -- would be the correct number of starts to pass on. - local adjusted = level0 ---@type integer - local prefix = '' - if enter_line > 0 then - prefix = '>' - if leave_line > 0 then - -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line - -- so that f2 gets the correct level on this line. This may reduce the size of f1 below - -- foldminlines, but we don't handle it for simplicity. - adjusted = level0 - leave_line - leave_line = 0 + local nestmax = vim.wo.foldnestmax + local level0_prev = info.levels0[srow] or 0 + local leave_prev = leave_counts[srow] or 0 + + -- We now have the list of fold opening and closing, fill the gaps and mark where fold start + for lnum = srow + 1, erow do + local enter_line = enter_counts[lnum] or 0 + local leave_line = leave_counts[lnum] or 0 + local level0 = level0_prev - leave_prev + enter_line + + -- Determine if it's the start/end of a fold + -- NB: vim's fold-expr interface does not have a mechanism to indicate that + -- two (or more) folds start at this line, so it cannot distinguish between + -- ( \n ( \n )) \n (( \n ) \n ) + -- versus + -- ( \n ( \n ) \n ( \n ) \n ) + -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and + -- vim interprets as the second case. + -- If it did have such a mechanism, (clamped - clamped_prev) + -- would be the correct number of starts to pass on. + local adjusted = level0 ---@type integer + local prefix = '' + if enter_line > 0 then + prefix = '>' + if leave_line > 0 then + -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line + -- so that f2 gets the correct level on this line. This may reduce the size of f1 below + -- foldminlines, but we don't handle it for simplicity. + adjusted = level0 - leave_line + leave_line = 0 + end end - end - -- Clamp at foldnestmax. - local clamped = adjusted - if adjusted > nestmax then - prefix = '' - clamped = nestmax - end + -- Clamp at foldnestmax. + local clamped = adjusted + if adjusted > nestmax then + prefix = '' + clamped = nestmax + end - -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels(). - info.levels0[lnum] = adjusted - info.levels[lnum] = prefix .. tostring(clamped) + -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels(). + info.levels0[lnum] = adjusted + info.levels[lnum] = prefix .. tostring(clamped) - leave_prev = leave_line - level0_prev = adjusted - end + leave_prev = leave_line + level0_prev = adjusted + end + + if callback then + callback() + end + end) end local M = {} @@ -221,7 +192,7 @@ local M = {} ---@type table<integer,TS.FoldInfo> local foldinfos = {} -local group = api.nvim_create_augroup('treesitter/fold', {}) +local group = api.nvim_create_augroup('nvim.treesitter.fold', {}) --- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that --- the user doesn't use different foldexpr for the same buffer). @@ -298,12 +269,19 @@ local function schedule_if_loaded(bufnr, fn) end ---@param bufnr integer ----@param foldinfo TS.FoldInfo ---@param tree_changes Range4[] -local function on_changedtree(bufnr, foldinfo, tree_changes) +local function on_changedtree(bufnr, tree_changes) schedule_if_loaded(bufnr, function() + -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked. + local foldinfo = foldinfos[bufnr] + if not foldinfo then + return + end + local srow_upd, erow_upd ---@type integer?, integer? local max_erow = api.nvim_buf_line_count(bufnr) + -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved + local iterations = 0 for _, change in ipairs(tree_changes) do local srow, _, erow, ecol = Range.unpack4(change) -- If a parser doesn't have any ranges explicitly set, treesitter will @@ -317,24 +295,31 @@ local function on_changedtree(bufnr, foldinfo, tree_changes) end -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. srow = math.max(srow - vim.wo.foldminlines, 0) - compute_folds_levels(bufnr, foldinfo, srow, erow) srow_upd = srow_upd and math.min(srow_upd, srow) or srow erow_upd = erow_upd and math.max(erow_upd, erow) or erow - end - if #tree_changes > 0 then - foldinfo:foldupdate(bufnr, srow_upd, erow_upd) + compute_folds_levels(bufnr, foldinfo, srow, erow, function() + iterations = iterations + 1 + if iterations == #tree_changes then + foldinfo:foldupdate(bufnr, srow_upd, erow_upd) + end + end) end end) end ---@param bufnr integer ----@param foldinfo TS.FoldInfo ---@param start_row integer ---@param old_row integer ---@param old_col integer ---@param new_row integer ---@param new_col integer -local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col) +local function on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col) + -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked. + local foldinfo = foldinfos[bufnr] + if not foldinfo then + return + end + -- extend the end to fully include the range local end_row_old = start_row + old_row + 1 local end_row_new = start_row + new_row + 1 @@ -373,15 +358,16 @@ local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`. schedule_if_loaded(bufnr, function() - if not foldinfo.on_bytes_range then + if not (foldinfo.on_bytes_range and foldinfos[bufnr]) then return end local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2] foldinfo.on_bytes_range = nil -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit. srow = math.max(srow - vim.wo.foldminlines, 0) - compute_folds_levels(bufnr, foldinfo, srow, erow) - foldinfo:foldupdate(bufnr, srow, erow) + compute_folds_levels(bufnr, foldinfo, srow, erow, function() + foldinfo:foldupdate(bufnr, srow, erow) + end) end) end end @@ -392,22 +378,30 @@ function M.foldexpr(lnum) lnum = lnum or vim.v.lnum local bufnr = api.nvim_get_current_buf() - local parser = ts.get_parser(bufnr, nil, { error = false }) - if not parser then - return '0' - end - if not foldinfos[bufnr] then - foldinfos[bufnr] = FoldInfo.new() + foldinfos[bufnr] = FoldInfo.new(bufnr) + api.nvim_create_autocmd({ 'BufUnload', 'VimEnter' }, { + buffer = bufnr, + once = true, + callback = function() + foldinfos[bufnr] = nil + end, + }) + + local parser = foldinfos[bufnr].parser + if not parser then + return '0' + end + compute_folds_levels(bufnr, foldinfos[bufnr]) parser:register_cbs({ on_changedtree = function(tree_changes) - on_changedtree(bufnr, foldinfos[bufnr], tree_changes) + on_changedtree(bufnr, tree_changes) end, on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _) - on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col) + on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col) end, on_detach = function() @@ -423,10 +417,17 @@ api.nvim_create_autocmd('OptionSet', { pattern = { 'foldminlines', 'foldnestmax' }, desc = 'Refresh treesitter folds', callback = function() - for bufnr, _ in pairs(foldinfos) do - foldinfos[bufnr] = FoldInfo.new() - compute_folds_levels(bufnr, foldinfos[bufnr]) - foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr)) + local buf = api.nvim_get_current_buf() + local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos) + or foldinfos[buf] and { buf } + or {} + for _, bufnr in ipairs(bufs) do + foldinfos[bufnr] = FoldInfo.new(bufnr) + api.nvim_buf_call(bufnr, function() + compute_folds_levels(bufnr, foldinfos[bufnr], nil, nil, function() + foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr)) + end) + end) end end, }) diff --git a/runtime/lua/vim/treesitter/_meta/misc.lua b/runtime/lua/vim/treesitter/_meta/misc.lua index 33701ef254..c532257f49 100644 --- a/runtime/lua/vim/treesitter/_meta/misc.lua +++ b/runtime/lua/vim/treesitter/_meta/misc.lua @@ -20,9 +20,15 @@ error('Cannot require a meta file') ---@class (exact) TSQueryInfo ---@field captures string[] ---@field patterns table<integer, (integer|string)[][]> +--- +---@class TSLangInfo +---@field fields string[] +---@field symbols table<string,boolean> +---@field _wasm boolean +---@field _abi_version integer --- @param lang string ---- @return table +--- @return TSLangInfo vim._ts_inspect_language = function(lang) end ---@return integer diff --git a/runtime/lua/vim/treesitter/_meta/tsnode.lua b/runtime/lua/vim/treesitter/_meta/tsnode.lua index d982b6a505..552905c3f0 100644 --- a/runtime/lua/vim/treesitter/_meta/tsnode.lua +++ b/runtime/lua/vim/treesitter/_meta/tsnode.lua @@ -68,12 +68,6 @@ function TSNode:named_child_count() end --- @return TSNode? function TSNode:named_child(index) end ---- Get the node's child that contains {descendant}. ---- @param descendant TSNode ---- @return TSNode? ---- @deprecated -function TSNode:child_containing_descendant(descendant) end - --- Get the node's child that contains {descendant} (includes {descendant}). --- --- For example, with the following node hierarchy: @@ -109,17 +103,9 @@ function TSNode:end_() end --- - end row --- - end column --- - end byte (if {include_bytes} is `true`) ---- @param include_bytes boolean? -function TSNode:range(include_bytes) end - ---- @nodoc --- @param include_bytes false? --- @return integer, integer, integer, integer -function TSNode:range(include_bytes) end - ---- @nodoc ---- @param include_bytes true ---- @return integer, integer, integer, integer, integer, integer +--- @overload fun(self: TSNode, include_bytes: true): integer, integer, integer, integer, integer, integer function TSNode:range(include_bytes) end --- Get the node's type as a string. diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua index a825505378..3dfc6b0cfe 100644 --- a/runtime/lua/vim/treesitter/_query_linter.lua +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -1,6 +1,6 @@ local api = vim.api -local namespace = api.nvim_create_namespace('vim.treesitter.query_linter') +local namespace = api.nvim_create_namespace('nvim.treesitter.query_linter') local M = {} @@ -138,7 +138,9 @@ local function lint_match(buf, match, query, lang_context, diagnostics) -- perform language-independent checks only for first lang if lang_context.is_first_lang and cap_id == 'error' then local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') - add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text) + ---@diagnostic disable-next-line: missing-fields LuaLS varargs bug + local range = { node:range() } --- @type Range4 + add_lint_for_node(diagnostics, range, 'Syntax error: ' .. node_text) end -- other checks rely on Neovim parser introspection diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua index 26817cdba5..24dd8243db 100644 --- a/runtime/lua/vim/treesitter/dev.lua +++ b/runtime/lua/vim/treesitter/dev.lua @@ -119,7 +119,7 @@ function TSTreeView:new(bufnr, lang) end local t = { - ns = api.nvim_create_namespace('treesitter/dev-inspect'), + ns = api.nvim_create_namespace('nvim.treesitter.dev_inspect'), nodes = nodes, named = named, ---@type vim.treesitter.dev.TSTreeViewOpts @@ -135,15 +135,7 @@ function TSTreeView:new(bufnr, lang) return t end -local decor_ns = api.nvim_create_namespace('ts.dev') - ----@param range Range4 ----@return string -local function range_to_string(range) - ---@type integer, integer, integer, integer - local row, col, end_row, end_col = unpack(range) - return string.format('[%d, %d] - [%d, %d]', row, col, end_row, end_col) -end +local decor_ns = api.nvim_create_namespace('nvim.treesitter.dev') ---@param w integer ---@return boolean closed Whether the window was closed. @@ -227,14 +219,17 @@ function TSTreeView:draw(bufnr) local lang_hl_marks = {} ---@type table[] for i, item in self:iter() do - local range_str = range_to_string({ item.node:range() }) + local range_str = ('[%d, %d] - [%d, %d]'):format(item.node:range()) local lang_str = self.opts.lang and string.format(' %s', item.lang) or '' local text ---@type string if item.node:named() then - text = string.format('(%s', item.node:type()) + text = string.format('(%s%s', item.node:missing() and 'MISSING ' or '', item.node:type()) else text = string.format('%q', item.node:type()):gsub('\n', 'n') + if item.node:missing() then + text = string.format('(MISSING %s)', text) + end end if item.field then text = string.format('%s: %s', item.field, text) @@ -442,7 +437,7 @@ function M.inspect_tree(opts) end, }) - local group = api.nvim_create_augroup('treesitter/dev', {}) + local group = api.nvim_create_augroup('nvim.treesitter.dev', {}) api.nvim_create_autocmd('CursorMoved', { group = group, @@ -547,7 +542,7 @@ function M.inspect_tree(opts) }) end -local edit_ns = api.nvim_create_namespace('treesitter/dev-edit') +local edit_ns = api.nvim_create_namespace('nvim.treesitter.dev_edit') ---@param query_win integer ---@param base_win integer @@ -633,7 +628,7 @@ function M.edit_query(lang) -- can infer the language later. api.nvim_buf_set_name(query_buf, string.format('%s/query_editor.scm', lang)) - local group = api.nvim_create_augroup('treesitter/dev-edit', {}) + local group = api.nvim_create_augroup('nvim.treesitter.dev_edit', {}) api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, { group = group, buffer = query_buf, diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index 8ce8652f7d..6dd47811bd 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -2,7 +2,7 @@ local api = vim.api local query = vim.treesitter.query local Range = require('vim.treesitter._range') -local ns = api.nvim_create_namespace('treesitter/highlighter') +local ns = api.nvim_create_namespace('nvim.treesitter.highlighter') ---@alias vim.treesitter.highlighter.Iter fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch @@ -69,6 +69,7 @@ end ---@field private _queries table<string,vim.treesitter.highlighter.Query> ---@field tree vim.treesitter.LanguageTree ---@field private redraw_count integer +---@field parsing boolean true if we are parsing asynchronously local TSHighlighter = { active = {}, } @@ -147,8 +148,6 @@ function TSHighlighter.new(tree, opts) vim.opt_local.spelloptions:append('noplainbuffer') end) - self.tree:parse() - return self end @@ -161,7 +160,10 @@ function TSHighlighter:destroy() vim.bo[self.bufnr].spelloptions = self.orig_spelloptions vim.b[self.bufnr].ts_highlight = nil if vim.g.syntax_on == 1 then - api.nvim_exec_autocmds('FileType', { group = 'syntaxset', buffer = self.bufnr }) + api.nvim_exec_autocmds( + 'FileType', + { group = 'syntaxset', buffer = self.bufnr, modeline = false } + ) end end end @@ -299,6 +301,8 @@ local function on_line_impl(self, buf, line, is_spell_nav) state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end + local captures = state.highlighter_query:query().captures + while line >= state.next_row do local capture, node, metadata, match = state.iter(line) @@ -311,7 +315,7 @@ local function on_line_impl(self, buf, line, is_spell_nav) if capture then local hl = state.highlighter_query:get_hl_from_capture(capture) - local capture_name = state.highlighter_query:query().captures[capture] + local capture_name = captures[capture] local spell, spell_pri_offset = get_spell(capture_name) @@ -382,19 +386,23 @@ function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _) end ---@private ----@param _win integer ---@param buf integer ---@param topline integer ---@param botline integer -function TSHighlighter._on_win(_, _win, buf, topline, botline) +function TSHighlighter._on_win(_, _, buf, topline, botline) local self = TSHighlighter.active[buf] - if not self then + if not self or self.parsing then return false end - self.tree:parse({ topline, botline + 1 }) - self:prepare_highlight_states(topline, botline + 1) + self.parsing = self.tree:parse({ topline, botline + 1 }, function(_, trees) + if trees and self.parsing then + self.parsing = false + api.nvim__redraw({ buf = buf, valid = false, flush = false }) + end + end) == nil self.redraw_count = self.redraw_count + 1 - return true + self:prepare_highlight_states(topline, botline) + return #self._highlight_states > 0 end api.nvim_set_decoration_provider(ns, { diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index 446051dfd7..16d19bfc5a 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -133,8 +133,9 @@ function M.add(lang, opts) path = paths[1] end - return loadparser(path, lang, symbol_name) or nil, - string.format('Cannot load parser %s for language "%s"', path, lang) + local res = loadparser(path, lang, symbol_name) + return res, + res == nil and string.format('Cannot load parser %s for language "%s"', path, lang) or nil end --- @param x string|string[] @@ -174,7 +175,7 @@ end --- (`"`). --- ---@param lang string Language ----@return table +---@return TSLangInfo function M.inspect(lang) M.add(lang) return vim._ts_inspect_language(lang) diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 4b42164dc8..ea745c4deb 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -44,6 +44,8 @@ local query = require('vim.treesitter.query') local language = require('vim.treesitter.language') local Range = require('vim.treesitter._range') +local default_parse_timeout_ms = 3 + ---@alias TSCallbackName ---| 'changedtree' ---| 'bytes' @@ -58,6 +60,8 @@ local Range = require('vim.treesitter._range') ---| 'on_child_added' ---| 'on_child_removed' +---@alias ParserThreadState { timeout: integer? } + --- @type table<TSCallbackNameOn,TSCallbackName> local TSCallbackNames = { on_changedtree = 'changedtree', @@ -76,8 +80,13 @@ local TSCallbackNames = { ---@field private _injections_processed boolean ---@field private _opts table Options ---@field private _parser TSParser Parser for language ----@field private _has_regions boolean +---Table of regions for which the tree is currently running an async parse +---@field private _ranges_being_parsed table<string, boolean> +---Table of callback queues, keyed by each region for which the callbacks should be run +---@field private _cb_queues table<string, fun(err?: string, trees?: table<integer, TSTree>)[]> ---@field private _regions table<integer, Range6[]>? +---The total number of regions. Since _regions can have holes, we cannot simply read this value from #_regions. +---@field private _num_regions integer ---List of regions this tree should manage and parse. If nil then regions are ---taken from _trees. This is mostly a short-lived cache for included_regions() ---@field private _lang string Language name @@ -85,7 +94,8 @@ local TSCallbackNames = { ---@field private _source (integer|string) Buffer or string to parse ---@field private _trees table<integer, TSTree> Reference to parsed tree (one for each language). ---Each key is the index of region, which is synced with _regions and _valid. ----@field private _valid boolean|table<integer,boolean> If the parsed tree is valid +---@field private _valid_regions table<integer,true> Set of valid region IDs. +---@field private _is_entirely_valid boolean Whether the entire tree (excluding children) is valid. ---@field private _logger? fun(logtype: string, msg: string) ---@field private _logfile? file* local LanguageTree = {} @@ -117,7 +127,7 @@ function LanguageTree.new(source, lang, opts) local injections = opts.injections or {} - --- @type vim.treesitter.LanguageTree + --- @class vim.treesitter.LanguageTree local self = { _source = source, _lang = lang, @@ -126,10 +136,13 @@ function LanguageTree.new(source, lang, opts) _opts = opts, _injection_query = injections[lang] and query.parse(lang, injections[lang]) or query.get(lang, 'injections'), - _has_regions = false, _injections_processed = false, - _valid = false, + _valid_regions = {}, + _num_regions = 1, + _is_entirely_valid = false, _parser = vim._create_ts_parser(lang), + _ranges_being_parsed = {}, + _cb_queues = {}, _callbacks = {}, _callbacks_rec = {}, } @@ -182,7 +195,7 @@ end ---Measure execution time of a function ---@generic R1, R2, R3 ----@param f fun(): R1, R2, R2 +---@param f fun(): R1, R2, R3 ---@return number, R1, R2, R3 local function tcall(f, ...) local start = vim.uv.hrtime() @@ -190,6 +203,7 @@ local function tcall(f, ...) local r = { f(...) } --- @type number local duration = (vim.uv.hrtime() - start) / 1000000 + --- @diagnostic disable-next-line: redundant-return-value return duration, unpack(r) end @@ -231,7 +245,9 @@ end --- tree in treesitter. Doesn't clear filesystem cache. Called often, so needs to be fast. ---@param reload boolean|nil function LanguageTree:invalidate(reload) - self._valid = false + self._valid_regions = {} + self._is_entirely_valid = false + self._parser:reset() -- buffer was reloaded, reparse all trees if reload then @@ -258,20 +274,51 @@ function LanguageTree:trees() end --- Gets the language of this tree node. +--- @return string function LanguageTree:lang() return self._lang end +--- @param region Range6[] +--- @param range? boolean|Range +--- @return boolean +local function intercepts_region(region, range) + if #region == 0 then + return true + end + + if range == nil then + return false + end + + if type(range) == 'boolean' then + return range + end + + for _, r in ipairs(region) do + if Range.intercepts(r, range) then + return true + end + end + + return false +end + --- Returns whether this LanguageTree is valid, i.e., |LanguageTree:trees()| reflects the latest --- state of the source. If invalid, user should call |LanguageTree:parse()|. ----@param exclude_children boolean|nil whether to ignore the validity of children (default `false`) +---@param exclude_children boolean? whether to ignore the validity of children (default `false`) +---@param range Range? range to check for validity ---@return boolean -function LanguageTree:is_valid(exclude_children) - local valid = self._valid +function LanguageTree:is_valid(exclude_children, range) + local valid_regions = self._valid_regions - if type(valid) == 'table' then - for i, _ in pairs(self:included_regions()) do - if not valid[i] then + if not self._is_entirely_valid then + if not range then + return false + end + -- TODO: Efficiently search for possibly intersecting regions using a binary search + for i, region in pairs(self:included_regions()) do + if not valid_regions[i] and intercepts_region(region, range) then return false end end @@ -283,97 +330,81 @@ function LanguageTree:is_valid(exclude_children) end for _, child in pairs(self._children) do - if not child:is_valid(exclude_children) then + if not child:is_valid(exclude_children, range) then return false end end end - if type(valid) == 'boolean' then - return valid - end - - self._valid = true return true end --- Returns a map of language to child tree. +--- @return table<string,vim.treesitter.LanguageTree> function LanguageTree:children() return self._children end --- Returns the source content of the language tree (bufnr or string). +--- @return integer|string function LanguageTree:source() return self._source end ---- @param region Range6[] ---- @param range? boolean|Range ---- @return boolean -local function intercepts_region(region, range) - if #region == 0 then - return true - end - - if range == nil then - return false - end - - if type(range) == 'boolean' then - return range - end - - for _, r in ipairs(region) do - if Range.intercepts(r, range) then - return true - end - end - - return false -end - --- @private --- @param range boolean|Range? +--- @param thread_state ParserThreadState --- @return Range6[] changes --- @return integer no_regions_parsed --- @return number total_parse_time -function LanguageTree:_parse_regions(range) +--- @return boolean finished whether async parsing still needs time +function LanguageTree:_parse_regions(range, thread_state) local changes = {} local no_regions_parsed = 0 local total_parse_time = 0 - if type(self._valid) ~= 'table' then - self._valid = {} - end - -- If there are no ranges, set to an empty list -- so the included ranges in the parser are cleared. for i, ranges in pairs(self:included_regions()) do if - not self._valid[i] + not self._valid_regions[i] and ( intercepts_region(ranges, range) or (self._trees[i] and intercepts_region(self._trees[i]:included_ranges(false), range)) ) then self._parser:set_included_ranges(ranges) + self._parser:set_timeout(thread_state.timeout and thread_state.timeout * 1000 or 0) -- ms -> micros + local parse_time, tree, tree_changes = tcall(self._parser.parse, self._parser, self._trees[i], self._source, true) + while true do + if tree then + break + end + coroutine.yield(changes, no_regions_parsed, total_parse_time, false) - -- Pass ranges if this is an initial parse - local cb_changes = self._trees[i] and tree_changes or tree:included_ranges(true) + parse_time, tree, tree_changes = + tcall(self._parser.parse, self._parser, self._trees[i], self._source, true) + end - self:_do_callback('changedtree', cb_changes, tree) + self:_do_callback('changedtree', tree_changes, tree) self._trees[i] = tree vim.list_extend(changes, tree_changes) total_parse_time = total_parse_time + parse_time no_regions_parsed = no_regions_parsed + 1 - self._valid[i] = true + self._valid_regions[i] = true + + -- _valid_regions can have holes, but that is okay because this equality is only true when it + -- has no holes (meaning all regions are valid) + if #self._valid_regions == self._num_regions then + self._is_entirely_valid = true + end end end - return changes, no_regions_parsed, total_parse_time + return changes, no_regions_parsed, total_parse_time, true end --- @private @@ -409,6 +440,98 @@ function LanguageTree:_add_injections() return query_time end +--- @param range boolean|Range? +--- @return string +local function range_to_string(range) + return type(range) == 'table' and table.concat(range, ',') or tostring(range) +end + +--- @private +--- @param range boolean|Range? +--- @param callback fun(err?: string, trees?: table<integer, TSTree>) +function LanguageTree:_push_async_callback(range, callback) + local key = range_to_string(range) + self._cb_queues[key] = self._cb_queues[key] or {} + local queue = self._cb_queues[key] + queue[#queue + 1] = callback +end + +--- @private +--- @param range boolean|Range? +--- @param err? string +--- @param trees? table<integer, TSTree> +function LanguageTree:_run_async_callbacks(range, err, trees) + local key = range_to_string(range) + for _, cb in ipairs(self._cb_queues[key]) do + cb(err, trees) + end + self._ranges_being_parsed[key] = nil + self._cb_queues[key] = nil +end + +--- Run an asynchronous parse, calling {on_parse} when complete. +--- +--- @private +--- @param range boolean|Range? +--- @param on_parse fun(err?: string, trees?: table<integer, TSTree>) +--- @return table<integer, TSTree>? trees the list of parsed trees, if parsing completed synchronously +function LanguageTree:_async_parse(range, on_parse) + self:_push_async_callback(range, on_parse) + + -- If we are already running an async parse, just queue the callback. + local range_string = range_to_string(range) + if not self._ranges_being_parsed[range_string] then + self._ranges_being_parsed[range_string] = true + else + return + end + + local source = self._source + local is_buffer_parser = type(source) == 'number' + local buf = is_buffer_parser and vim.b[source] or nil + local ct = is_buffer_parser and buf.changedtick or nil + local total_parse_time = 0 + local redrawtime = vim.o.redrawtime + + local thread_state = {} ---@type ParserThreadState + + ---@type fun(): table<integer, TSTree>, boolean + local parse = coroutine.wrap(self._parse) + + local function step() + if is_buffer_parser then + if + not vim.api.nvim_buf_is_valid(source --[[@as number]]) + then + return nil + end + + -- If buffer was changed in the middle of parsing, reset parse state + if buf.changedtick ~= ct then + ct = buf.changedtick + total_parse_time = 0 + parse = coroutine.wrap(self._parse) + end + end + + thread_state.timeout = not vim.g._ts_force_sync_parsing and default_parse_timeout_ms or nil + local parse_time, trees, finished = tcall(parse, self, range, thread_state) + total_parse_time = total_parse_time + parse_time + + if finished then + self:_run_async_callbacks(range, nil, trees) + return trees + elseif total_parse_time > redrawtime then + self:_run_async_callbacks(range, 'TIMEOUT', nil) + return nil + else + vim.schedule(step) + end + end + + return step() +end + --- Recursively parse all regions in the language tree using |treesitter-parsers| --- for the corresponding languages and run injection queries on the parsed trees --- to determine whether child trees should be created and parsed. @@ -420,11 +543,33 @@ end --- Set to `true` to run a complete parse of the source (Note: Can be slow!) --- Set to `false|nil` to only parse regions with empty ranges (typically --- only the root tree without injections). ---- @return table<integer, TSTree> -function LanguageTree:parse(range) - if self:is_valid() then +--- @param on_parse fun(err?: string, trees?: table<integer, TSTree>)? Function invoked when parsing completes. +--- When provided and `vim.g._ts_force_sync_parsing` is not set, parsing will run +--- asynchronously. The first argument to the function is a string representing the error type, +--- in case of a failure (currently only possible for timeouts). The second argument is the list +--- of trees returned by the parse (upon success), or `nil` if the parse timed out (determined +--- by 'redrawtime'). +--- +--- If parsing was still able to finish synchronously (within 3ms), `parse()` returns the list +--- of trees. Otherwise, it returns `nil`. +--- @return table<integer, TSTree>? +function LanguageTree:parse(range, on_parse) + if on_parse then + return self:_async_parse(range, on_parse) + end + local trees, _ = self:_parse(range, {}) + return trees +end + +--- @private +--- @param range boolean|Range|nil +--- @param thread_state ParserThreadState +--- @return table<integer, TSTree> trees +--- @return boolean finished +function LanguageTree:_parse(range, thread_state) + if self:is_valid(nil, type(range) == 'table' and range or nil) then self:_log('valid') - return self._trees + return self._trees, true end local changes --- @type Range6[]? @@ -435,15 +580,27 @@ function LanguageTree:parse(range) local total_parse_time = 0 -- At least 1 region is invalid - if not self:is_valid(true) then - changes, no_regions_parsed, total_parse_time = self:_parse_regions(range) + if not self:is_valid(true, type(range) == 'table' and range or nil) then + ---@type fun(self: vim.treesitter.LanguageTree, range: boolean|Range?, thread_state: ParserThreadState): Range6[], integer, number, boolean + local parse_regions = coroutine.wrap(self._parse_regions) + while true do + local is_finished + changes, no_regions_parsed, total_parse_time, is_finished = + parse_regions(self, range, thread_state) + thread_state.timeout = thread_state.timeout + and math.max(thread_state.timeout - total_parse_time, 0) + if is_finished then + break + end + coroutine.yield(self._trees, false) + end -- Need to run injections when we parsed something if no_regions_parsed > 0 then self._injections_processed = false end end - if not self._injections_processed and range ~= false and range ~= nil then + if not self._injections_processed and range then query_time = self:_add_injections() self._injections_processed = true end @@ -457,10 +614,24 @@ function LanguageTree:parse(range) }) for _, child in pairs(self._children) do - child:parse(range) + if thread_state.timeout == 0 then + coroutine.yield(self._trees, false) + end + + ---@type fun(): table<integer, TSTree>, boolean + local parse = coroutine.wrap(child._parse) + + while true do + local ctime, _, child_finished = tcall(parse, child, range, thread_state) + if child_finished then + thread_state.timeout = thread_state.timeout and math.max(thread_state.timeout - ctime, 0) + break + end + coroutine.yield(self._trees, child_finished) + end end - return self._trees + return self._trees, true end --- Invokes the callback for each |LanguageTree| recursively. @@ -504,7 +675,8 @@ function LanguageTree:add_child(lang) return self._children[lang] end ---- @package +---Returns the parent tree. `nil` for the root tree. +---@return vim.treesitter.LanguageTree? function LanguageTree:parent() return self._parent end @@ -551,38 +723,34 @@ end ---region is valid or not. ---@param fn fun(index: integer, region: Range6[]): boolean function LanguageTree:_iter_regions(fn) - if not self._valid then + if vim.deep_equal(self._valid_regions, {}) then return end - local was_valid = type(self._valid) ~= 'table' - - if was_valid then - self:_log('was valid', self._valid) - self._valid = {} + if self._is_entirely_valid then + self:_log('was valid') end local all_valid = true for i, region in pairs(self:included_regions()) do - if was_valid or self._valid[i] then - self._valid[i] = fn(i, region) - if not self._valid[i] then + if self._valid_regions[i] then + -- Setting this to nil rather than false allows us to determine if all regions were parsed + -- just by checking the length of _valid_regions. + self._valid_regions[i] = fn(i, region) and true or nil + if not self._valid_regions[i] then self:_log(function() return 'invalidating region', i, region_tostr(region) end) end end - if not self._valid[i] then + if not self._valid_regions[i] then all_valid = false end end - -- Compress the valid value to 'true' if there are no invalid regions - if all_valid then - self._valid = all_valid - end + self._is_entirely_valid = all_valid end --- Sets the included regions that should be parsed by this |LanguageTree|. @@ -602,14 +770,13 @@ end ---@private ---@param new_regions (Range4|Range6|TSNode)[][] List of regions this tree should manage and parse. function LanguageTree:set_included_regions(new_regions) - self._has_regions = true - -- Transform the tables from 4 element long to 6 element long (with byte offset) for _, region in ipairs(new_regions) do for i, range in ipairs(region) do if type(range) == 'table' and #range == 4 then region[i] = Range.add_bytes(self._source, range --[[@as Range4]]) elseif type(range) == 'userdata' then + --- @diagnostic disable-next-line: missing-fields LuaLS varargs bug region[i] = { range:range(true) } end end @@ -633,6 +800,7 @@ function LanguageTree:set_included_regions(new_regions) end self._regions = new_regions + self._num_regions = #new_regions end ---Gets the set of included regions managed by this LanguageTree. This can be different from the @@ -646,18 +814,8 @@ function LanguageTree:included_regions() return self._regions end - if not self._has_regions then - -- treesitter.c will default empty ranges to { -1, -1, -1, -1, -1, -1} (the full range) - return { {} } - end - - local regions = {} ---@type Range6[][] - for i, _ in pairs(self._trees) do - regions[i] = self._trees[i]:included_ranges(true) - end - - self._regions = regions - return regions + -- treesitter.c will default empty ranges to { -1, -1, -1, -1, -1, -1} (the full range) + return { {} } end ---@param node TSNode @@ -821,7 +979,7 @@ end --- @private --- @return table<string, Range6[][]> function LanguageTree:_get_injections() - if not self._injection_query then + if not self._injection_query or #self._injection_query.captures == 0 then return {} end @@ -907,7 +1065,15 @@ function LanguageTree:_edit( ) end - self._regions = nil + self._parser:reset() + + if self._regions then + local regions = {} ---@type table<integer, Range6[]> + for i, tree in pairs(self._trees) do + regions[i] = tree:included_ranges(true) + end + self._regions = regions + end local changed_range = { start_row, diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 1677e8d364..10fb82e533 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,17 +1,77 @@ +--- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse +--- text. See |vim.treesitter.query.parse()| for a working example. + local api = vim.api local language = require('vim.treesitter.language') local memoize = vim.func._memoize +local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$' +local EXTENDS_FORMAT = '^;+%s*extends%s*$' + local M = {} +local function is_directive(name) + return string.sub(name, -1) == '!' +end + +---@nodoc +---@class vim.treesitter.query.ProcessedPredicate +---@field [1] string predicate name +---@field [2] boolean should match +---@field [3] (integer|string)[] the original predicate + +---@alias vim.treesitter.query.ProcessedDirective (integer|string)[] + +---@nodoc +---@class vim.treesitter.query.ProcessedPattern { +---@field predicates vim.treesitter.query.ProcessedPredicate[] +---@field directives vim.treesitter.query.ProcessedDirective[] + +--- Splits the query patterns into predicates and directives. +---@param patterns table<integer, (integer|string)[][]> +---@return table<integer, vim.treesitter.query.ProcessedPattern> +local function process_patterns(patterns) + ---@type table<integer, vim.treesitter.query.ProcessedPattern> + local processed_patterns = {} + + for k, pattern_list in pairs(patterns) do + ---@type vim.treesitter.query.ProcessedPredicate[] + local predicates = {} + ---@type vim.treesitter.query.ProcessedDirective[] + local directives = {} + + for _, pattern in ipairs(pattern_list) do + -- Note: tree-sitter strips the leading # from predicates for us. + local pred_name = pattern[1] + ---@cast pred_name string + + if is_directive(pred_name) then + table.insert(directives, pattern) + else + local should_match = true + if pred_name:match('^not%-') then + pred_name = pred_name:sub(5) + should_match = false + end + table.insert(predicates, { pred_name, should_match, pattern }) + end + end + + processed_patterns[k] = { predicates = predicates, directives = directives } + end + + return processed_patterns +end + ---@nodoc ---Parsed query, see |vim.treesitter.query.parse()| --- ---@class vim.treesitter.Query ----@field lang string name of the language for this parser +---@field lang string parser language name ---@field captures string[] list of (unique) capture names defined in query ----@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives) +---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives) ---@field query TSQuery userdata query object +---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern> local Query = {} Query.__index = Query @@ -30,6 +90,7 @@ function Query.new(lang, ts_query) patterns = query_info.patterns, } self.captures = self.info.captures + self._processed_patterns = process_patterns(self.info.patterns) return self end @@ -109,9 +170,6 @@ function M.get_files(lang, query_name, is_included) -- ;+ inherits: ({language},)*{language} -- -- {language} ::= {lang} | ({lang}) - local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$' - local EXTENDS_FORMAT = '^;+%s*extends%s*$' - for _, filename in ipairs(lang_files) do local file, err = io.open(filename, 'r') if not file then @@ -184,8 +242,8 @@ local function read_query_files(filenames) return table.concat(contents, '') end --- The explicitly set queries from |vim.treesitter.query.set()| ----@type table<string,table<string,vim.treesitter.Query>> +-- The explicitly set query strings from |vim.treesitter.query.set()| +---@type table<string,table<string,string>> local explicit_queries = setmetatable({}, { __index = function(t, k) local lang_queries = {} @@ -197,14 +255,27 @@ local explicit_queries = setmetatable({}, { --- Sets the runtime query named {query_name} for {lang} --- ---- This allows users to override any runtime files and/or configuration +--- This allows users to override or extend any runtime files and/or configuration --- set by plugins. --- +--- For example, you could enable spellchecking of `C` identifiers with the +--- following code: +--- ```lua +--- vim.treesitter.query.set( +--- 'c', +--- 'highlights', +--- [[;inherits c +--- (identifier) @spell]]) +--- ]]) +--- ``` +--- ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g., "highlights") ---@param text string Query text (unparsed). function M.set(lang, query_name, text) - explicit_queries[lang][query_name] = M.parse(lang, text) + --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics + M.get:clear(lang, query_name) + explicit_queries[lang][query_name] = text end --- Returns the runtime query {query_name} for {lang}. @@ -214,34 +285,82 @@ end --- ---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found. M.get = memoize('concat-2', function(lang, query_name) + local query_string ---@type string + if explicit_queries[lang][query_name] then - return explicit_queries[lang][query_name] - end + local query_files = {} + local base_langs = {} ---@type string[] - local query_files = M.get_files(lang, query_name) - local query_string = read_query_files(query_files) + for line in explicit_queries[lang][query_name]:gmatch('([^\n]*)\n?') do + if not vim.startswith(line, ';') then + break + end + + local lang_list = line:match(MODELINE_FORMAT) + if lang_list then + for _, incl_lang in ipairs(vim.split(lang_list, ',')) do + local is_optional = incl_lang:match('%(.*%)') + + if is_optional then + add_included_lang(base_langs, lang, incl_lang:sub(2, #incl_lang - 1)) + else + add_included_lang(base_langs, lang, incl_lang) + end + end + elseif line:match(EXTENDS_FORMAT) then + table.insert(base_langs, lang) + end + end + + for _, base_lang in ipairs(base_langs) do + local base_files = M.get_files(base_lang, query_name, true) + vim.list_extend(query_files, base_files) + end + + query_string = read_query_files(query_files) .. explicit_queries[lang][query_name] + else + local query_files = M.get_files(lang, query_name) + query_string = read_query_files(query_files) + end if #query_string == 0 then return nil end return M.parse(lang, query_string) -end) +end, false) + +api.nvim_create_autocmd('OptionSet', { + pattern = { 'runtimepath' }, + group = api.nvim_create_augroup('nvim.treesitter.query_cache_reset', { clear = true }), + callback = function() + --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics + M.get:clear() + end, +}) ---- Parse {query} as a string. (If the query is in a file, the caller ---- should read the contents into a string before calling). ---- ---- Returns a `Query` (see |lua-treesitter-query|) object which can be used to ---- search nodes in the syntax tree for the patterns defined in {query} ---- using the `iter_captures` and `iter_matches` methods. +--- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used +--- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|), +--- or inspect the query via these fields: +--- - `captures`: a list of unique capture names defined in the query (alias: `info.captures`). +--- - `info.patterns`: information about predicates. --- ---- Exposes `info` and `captures` with additional context about {query}. ---- - `captures` contains the list of unique capture names defined in {query}. ---- - `info.captures` also points to `captures`. ---- - `info.patterns` contains information about predicates. +--- Example: +--- ```lua +--- local query = vim.treesitter.query.parse('vimdoc', [[ +--- ; query +--- ((h1) @str +--- (#trim! @str 1 1 1 1)) +--- ]]) +--- local tree = vim.treesitter.get_parser():parse()[1] +--- for id, node, metadata in query:iter_captures(tree:root(), 0) do +--- -- Print the node name and source text. +--- vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())}) +--- end +--- ``` --- ---@param lang string Language to use for the query ----@param query string Query in s-expr syntax +---@param query string Query text, in s-expr syntax --- ---@return vim.treesitter.Query : Parsed query --- @@ -250,7 +369,7 @@ M.parse = memoize('concat-2', function(lang, query) assert(language.add(lang)) local ts_query = vim._ts_parse_query(lang, query) return Query.new(lang, ts_query) -end) +end, false) --- Implementations of predicates that can optionally be prefixed with "any-". --- @@ -572,13 +691,17 @@ local directive_handlers = { metadata[id].text = text:gsub(pattern, replacement) end, - -- Trim blank lines from end of the node - -- Example: (#trim! @fold) - -- TODO(clason): generalize to arbitrary whitespace removal + -- Trim whitespace from both sides of the node + -- Example: (#trim! @fold 1 1 1 1) ['trim!'] = function(match, _, bufnr, pred, metadata) local capture_id = pred[2] assert(type(capture_id) == 'number') + local trim_start_lines = pred[3] == '1' + local trim_start_cols = pred[4] == '1' + local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility + local trim_end_cols = pred[6] == '1' + local nodes = match[capture_id] if not nodes or #nodes == 0 then return @@ -588,20 +711,45 @@ local directive_handlers = { local start_row, start_col, end_row, end_col = node:range() - -- Don't trim if region ends in middle of a line - if end_col ~= 0 then - return + local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n') + if end_col == 0 then + -- get_node_text() will ignore the last line if the node ends at column 0 + node_text[#node_text + 1] = '' end - while end_row >= start_row do - -- As we only care when end_col == 0, always inspect one line above end_row. - local end_line = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1] + local end_idx = #node_text + local start_idx = 1 - if end_line ~= '' then - break + if trim_end_lines then + while end_idx > 0 and node_text[end_idx]:find('^%s*$') do + end_idx = end_idx - 1 + end_row = end_row - 1 + -- set the end position to the last column of the next line, or 0 if we just trimmed the + -- last line + end_col = end_idx > 0 and #node_text[end_idx] or 0 end + end + if trim_end_cols then + if end_idx == 0 then + end_row = start_row + end_col = start_col + else + local whitespace_start = node_text[end_idx]:find('(%s*)$') + end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0) + end + end - end_row = end_row - 1 + if trim_start_lines then + while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do + start_idx = start_idx + 1 + start_row = start_row + 1 + start_col = 0 + end + end + if trim_start_cols and node_text[start_idx] then + local _, whitespace_end = node_text[start_idx]:find('^(%s*)') + whitespace_end = whitespace_end or 0 + start_col = (start_idx == 1 and start_col or 0) + whitespace_end end -- If this produces an invalid range, we just skip it. @@ -711,84 +859,50 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end -local function xor(x, y) - return (x or y) and not (x and y) -end - -local function is_directive(name) - return string.sub(name, -1) == '!' -end - ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param predicates vim.treesitter.query.ProcessedPredicate[] +---@param captures table<integer, TSNode[]> ---@param source integer|string -function Query:match_preds(match, source) - local _, pattern = match:info() - local preds = self.info.patterns[pattern] - - if not preds then - return true - end - - local captures = match:captures() - - for _, pred in pairs(preds) do - -- Here we only want to return if a predicate DOES NOT match, and - -- continue on the other case. This way unknown predicates will not be considered, - -- which allows some testing and easier user extensibility (#12173). - -- Also, tree-sitter strips the leading # from predicates for us. - local is_not = false - - -- Skip over directives... they will get processed after all the predicates. - if not is_directive(pred[1]) then - local pred_name = pred[1] - if pred_name:match('^not%-') then - pred_name = pred_name:sub(5) - is_not = true - end - - local handler = predicate_handlers[pred_name] - - if not handler then - error(string.format('No handler for %s', pred[1])) - return false - end - - local pred_matches = handler(captures, pattern, source, pred) +---@return boolean whether the predicates match +function Query:_match_predicates(predicates, pattern_i, captures, source) + for _, predicate in ipairs(predicates) do + local processed_name = predicate[1] + local should_match = predicate[2] + local orig_predicate = predicate[3] + + local handler = predicate_handlers[processed_name] + if not handler then + error(string.format('No handler for %s', orig_predicate[1])) + return false + end - if not xor(is_not, pred_matches) then - return false - end + local does_match = handler(captures, pattern_i, source, orig_predicate) + if does_match ~= should_match then + return false end end return true end ---@private ----@param match TSQueryMatch +---@param pattern_i integer +---@param directives vim.treesitter.query.ProcessedDirective[] +---@param source integer|string +---@param captures table<integer, TSNode[]> ---@return vim.treesitter.query.TSMetadata metadata -function Query:apply_directives(match, source) +function Query:_apply_directives(directives, pattern_i, captures, source) ---@type vim.treesitter.query.TSMetadata local metadata = {} - local _, pattern = match:info() - local preds = self.info.patterns[pattern] - - if not preds then - return metadata - end - local captures = match:captures() - - for _, pred in pairs(preds) do - if is_directive(pred[1]) then - local handler = directive_handlers[pred[1]] - - if not handler then - error(string.format('No handler for %s', pred[1])) - end + for _, directive in pairs(directives) do + local handler = directive_handlers[directive[1]] - handler(captures, pattern, source, pred, metadata) + if not handler then + error(string.format('No handler for %s', directive[1])) end + + handler(captures, pattern_i, source, directive, metadata) end return metadata @@ -812,26 +926,22 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- @param match TSQueryMatch ---- @return integer -local function match_id_hash(_, match) - return (match:info()) -end - ---- Iterate over all captures from all matches inside {node} +--- Iterates over all captures from all matches in {node}. --- ---- {source} is needed if the query contains predicates; then the caller +--- {source} is required if the query contains predicates; then the caller --- must ensure to use a freshly parsed tree consistent with the current --- text of the buffer (if relevant). {start} and {stop} can be used to limit --- matches inside a row range (this is typically used with root node --- as the {node}, i.e., to get syntax highlight matches in the current --- viewport). When omitted, the {start} and {stop} row values are used from the given node. --- ---- The iterator returns four values: a numeric id identifying the capture, ---- the captured node, metadata from any directives processing the match, ---- and the match itself. ---- The following example shows how to get captures by name: +--- The iterator returns four values: +--- 1. the numeric id identifying the capture +--- 2. the captured node +--- 3. metadata from any directives processing the match +--- 4. the match itself --- +--- Example: how to get captures by name: --- ```lua --- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do --- local name = query.captures[id] -- name of the capture in the query @@ -847,8 +957,8 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch): ---- capture id, capture node, metadata, match +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch, TSTree): +--- capture id, capture node, metadata, match, tree --- ---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) @@ -858,10 +968,14 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) + -- Copy the tree to ensure it is valid during the entire lifetime of the iterator + local tree = node:tree():copy() local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) - local apply_directives = memoize(match_id_hash, self.apply_directives, true) - local match_preds = memoize(match_id_hash, self.match_preds, true) + -- For faster checks that a match is not in the cache. + local highest_cached_match_id = -1 + ---@type table<integer, vim.treesitter.query.TSMetadata> + local match_cache = {} local function iter(end_line) local capture, captured_node, match = cursor:next_capture() @@ -870,18 +984,39 @@ function Query:iter_captures(node, source, start, stop) return end - if not match_preds(self, match, source) then - local match_id = match:info() - cursor:remove_match(match_id) - if end_line and captured_node:range() > end_line then - return nil, captured_node, nil, nil - end - return iter(end_line) -- tail call: try next match + local match_id, pattern_i = match:info() + + --- @type vim.treesitter.query.TSMetadata + local metadata + if match_id <= highest_cached_match_id then + metadata = match_cache[match_id] end - local metadata = apply_directives(self, match, source) + if not metadata then + metadata = {} + + local processed_pattern = self._processed_patterns[pattern_i] + if processed_pattern then + local captures = match:captures() - return capture, captured_node, metadata, match + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then + cursor:remove_match(match_id) + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil, nil + end + return iter(end_line) -- tail call: try next match + end + + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) + end + + highest_cached_match_id = math.max(highest_cached_match_id, match_id) + match_cache[match_id] = metadata + end + + return capture, captured_node, metadata, match, tree end return iter end @@ -903,7 +1038,7 @@ end --- -- `node` was captured by the `name` capture in the match --- --- local node_data = metadata[id] -- Node level metadata ---- ... use the info here ... +--- -- ... use the info here ... --- end --- end --- end @@ -922,7 +1057,7 @@ end --- (last) node instead of the full list of matching nodes. This option is only for backward --- compatibility and will be removed in a future release. --- ----@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata): pattern id, match, metadata +---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata, TSTree): pattern id, match, metadata, tree function Query:iter_matches(node, source, start, stop, opts) opts = opts or {} opts.match_limit = opts.match_limit or 256 @@ -933,6 +1068,8 @@ function Query:iter_matches(node, source, start, stop, opts) start, stop = value_or_node_range(start, stop, node) + -- Copy the tree to ensure it is valid during the entire lifetime of the iterator + local tree = node:tree():copy() local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) local function iter() @@ -942,17 +1079,22 @@ function Query:iter_matches(node, source, start, stop, opts) return end - local match_id, pattern = match:info() + local match_id, pattern_i = match:info() + local processed_pattern = self._processed_patterns[pattern_i] + local captures = match:captures() - if not self:match_preds(match, source) then - cursor:remove_match(match_id) - return iter() -- tail call: try next match + --- @type vim.treesitter.query.TSMetadata + local metadata = {} + if processed_pattern then + local predicates = processed_pattern.predicates + if not self:_match_predicates(predicates, pattern_i, captures, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match + end + local directives = processed_pattern.directives + metadata = self:_apply_directives(directives, pattern_i, captures, source) end - local metadata = self:apply_directives(match, source) - - local captures = match:captures() - if opts.all == false then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow, but we only do it when the caller explicitly opted into it by @@ -961,11 +1103,11 @@ function Query:iter_matches(node, source, start, stop, opts) for k, v in pairs(captures or {}) do old_match[k] = v[#v] end - return pattern, old_match, metadata + return pattern_i, old_match, metadata end -- TODO(lewis6991): create a new function that returns {match, metadata} - return pattern, captures, metadata + return pattern_i, captures, metadata, tree end return iter end |