aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter
diff options
context:
space:
mode:
authorJosh Rahm <joshuarahm@gmail.com>2025-02-05 23:09:29 +0000
committerJosh Rahm <joshuarahm@gmail.com>2025-02-05 23:09:29 +0000
commitd5f194ce780c95821a855aca3c19426576d28ae0 (patch)
treed45f461b19f9118ad2bb1f440a7a08973ad18832 /runtime/lua/vim/treesitter
parentc5d770d311841ea5230426cc4c868e8db27300a8 (diff)
parent44740e561fc93afe3ebecfd3618bda2d2abeafb0 (diff)
downloadrneovim-rahm.tar.gz
rneovim-rahm.tar.bz2
rneovim-rahm.zip
Merge remote-tracking branch 'upstream/master' into mix_20240309HEADrahm
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua325
-rw-r--r--runtime/lua/vim/treesitter/_meta/misc.lua8
-rw-r--r--runtime/lua/vim/treesitter/_meta/tsnode.lua16
-rw-r--r--runtime/lua/vim/treesitter/_query_linter.lua6
-rw-r--r--runtime/lua/vim/treesitter/dev.lua25
-rw-r--r--runtime/lua/vim/treesitter/highlighter.lua30
-rw-r--r--runtime/lua/vim/treesitter/language.lua7
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua356
-rw-r--r--runtime/lua/vim/treesitter/query.lua418
9 files changed, 749 insertions, 442 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 7237d2e7d4..38318347a7 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -19,76 +19,36 @@ local api = vim.api
---The range on which to evaluate foldexpr.
---When in insert mode, the evaluation is deferred to InsertLeave.
---@field foldupdate_range? Range2
+---
+---The treesitter parser associated with this buffer.
+---@field parser? vim.treesitter.LanguageTree
local FoldInfo = {}
FoldInfo.__index = FoldInfo
---@private
-function FoldInfo.new()
+---@param bufnr integer
+function FoldInfo.new(bufnr)
return setmetatable({
levels0 = {},
levels = {},
+ parser = ts.get_parser(bufnr, nil, { error = false }),
}, FoldInfo)
end
---- Efficiently remove items from middle of a list a list.
----
---- Calling table.remove() in a loop will re-index the tail of the table on
---- every iteration, instead this function will re-index the table exactly
---- once.
----
---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
----
----@param t any[]
----@param first integer
----@param last integer
-local function list_remove(t, first, last)
- local n = #t
- for i = 0, n - first do
- t[first + i] = t[last + 1 + i]
- t[last + 1 + i] = nil
- end
-end
-
---@package
---@param srow integer
---@param erow integer 0-indexed, exclusive
function FoldInfo:remove_range(srow, erow)
- list_remove(self.levels, srow + 1, erow)
- list_remove(self.levels0, srow + 1, erow)
-end
-
---- Efficiently insert items into the middle of a list.
----
---- Calling table.insert() in a loop will re-index the tail of the table on
---- every iteration, instead this function will re-index the table exactly
---- once.
----
---- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
----
----@param t any[]
----@param first integer
----@param last integer
----@param v any
-local function list_insert(t, first, last, v)
- local n = #t
-
- -- Shift table forward
- for i = n - first, 0, -1 do
- t[last + 1 + i] = t[first + i]
- end
-
- -- Fill in new values
- for i = first, last do
- t[i] = v
- end
+ vim._list_remove(self.levels, srow + 1, erow)
+ vim._list_remove(self.levels0, srow + 1, erow)
end
---@package
---@param srow integer
---@param erow integer 0-indexed, exclusive
function FoldInfo:add_range(srow, erow)
- list_insert(self.levels, srow + 1, erow, -1)
- list_insert(self.levels0, srow + 1, erow, -1)
+ vim._list_insert(self.levels, srow + 1, erow, -1)
+ vim._list_insert(self.levels0, srow + 1, erow, -1)
end
---@param range Range2
@@ -109,111 +69,122 @@ end
---@param info TS.FoldInfo
---@param srow integer?
---@param erow integer? 0-indexed, exclusive
----@param parse_injections? boolean
-local function compute_folds_levels(bufnr, info, srow, erow, parse_injections)
+---@param callback function?
+local function compute_folds_levels(bufnr, info, srow, erow, callback)
srow = srow or 0
erow = erow or api.nvim_buf_line_count(bufnr)
- local parser = assert(ts.get_parser(bufnr, nil, { error = false }))
-
- parser:parse(parse_injections and { srow, erow } or nil)
-
- local enter_counts = {} ---@type table<integer, integer>
- local leave_counts = {} ---@type table<integer, integer>
- local prev_start = -1
- local prev_stop = -1
+ local parser = info.parser
+ if not parser then
+ return
+ end
- parser:for_each_tree(function(tree, ltree)
- local query = ts.query.get(ltree:lang(), 'folds')
- if not query then
+ parser:parse(nil, function(_, trees)
+ if not trees then
return
end
- -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
- -- srow - 1 from the level of srow - 1 to get accurate level of srow.
- for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
- for id, nodes in pairs(match) do
- if query.captures[id] == 'fold' then
- local range = ts.get_range(nodes[1], bufnr, metadata[id])
- local start, _, stop, stop_col = Range.unpack4(range)
-
- if #nodes > 1 then
- -- assumes nodes are ordered by range
- local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
- local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
- stop = end_stop
- stop_col = end_stop_col
- end
+ local enter_counts = {} ---@type table<integer, integer>
+ local leave_counts = {} ---@type table<integer, integer>
+ local prev_start = -1
+ local prev_stop = -1
- if stop_col == 0 then
- stop = stop - 1
- end
+ parser:for_each_tree(function(tree, ltree)
+ local query = ts.query.get(ltree:lang(), 'folds')
+ if not query then
+ return
+ end
- local fold_length = stop - start + 1
-
- -- Fold only multiline nodes that are not exactly the same as previously met folds
- -- Checking against just the previously found fold is sufficient if nodes
- -- are returned in preorder or postorder when traversing tree
- if
- fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
- then
- enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
- leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
- prev_start = start
- prev_stop = stop
+ -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
+ -- srow - 1 from the level of srow - 1 to get accurate level of srow.
+ for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
+ for id, nodes in pairs(match) do
+ if query.captures[id] == 'fold' then
+ local range = ts.get_range(nodes[1], bufnr, metadata[id])
+ local start, _, stop, stop_col = Range.unpack4(range)
+
+ if #nodes > 1 then
+ -- assumes nodes are ordered by range
+ local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
+ local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
+ stop = end_stop
+ stop_col = end_stop_col
+ end
+
+ if stop_col == 0 then
+ stop = stop - 1
+ end
+
+ local fold_length = stop - start + 1
+
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if
+ fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
+ then
+ enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
+ leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
+ prev_start = start
+ prev_stop = stop
+ end
end
end
end
- end
- end)
+ end)
- local nestmax = vim.wo.foldnestmax
- local level0_prev = info.levels0[srow] or 0
- local leave_prev = leave_counts[srow] or 0
-
- -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
- for lnum = srow + 1, erow do
- local enter_line = enter_counts[lnum] or 0
- local leave_line = leave_counts[lnum] or 0
- local level0 = level0_prev - leave_prev + enter_line
-
- -- Determine if it's the start/end of a fold
- -- NB: vim's fold-expr interface does not have a mechanism to indicate that
- -- two (or more) folds start at this line, so it cannot distinguish between
- -- ( \n ( \n )) \n (( \n ) \n )
- -- versus
- -- ( \n ( \n ) \n ( \n ) \n )
- -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
- -- vim interprets as the second case.
- -- If it did have such a mechanism, (clamped - clamped_prev)
- -- would be the correct number of starts to pass on.
- local adjusted = level0 ---@type integer
- local prefix = ''
- if enter_line > 0 then
- prefix = '>'
- if leave_line > 0 then
- -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
- -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
- -- foldminlines, but we don't handle it for simplicity.
- adjusted = level0 - leave_line
- leave_line = 0
+ local nestmax = vim.wo.foldnestmax
+ local level0_prev = info.levels0[srow] or 0
+ local leave_prev = leave_counts[srow] or 0
+
+ -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
+ for lnum = srow + 1, erow do
+ local enter_line = enter_counts[lnum] or 0
+ local leave_line = leave_counts[lnum] or 0
+ local level0 = level0_prev - leave_prev + enter_line
+
+ -- Determine if it's the start/end of a fold
+ -- NB: vim's fold-expr interface does not have a mechanism to indicate that
+ -- two (or more) folds start at this line, so it cannot distinguish between
+ -- ( \n ( \n )) \n (( \n ) \n )
+ -- versus
+ -- ( \n ( \n ) \n ( \n ) \n )
+ -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
+ -- vim interprets as the second case.
+ -- If it did have such a mechanism, (clamped - clamped_prev)
+ -- would be the correct number of starts to pass on.
+ local adjusted = level0 ---@type integer
+ local prefix = ''
+ if enter_line > 0 then
+ prefix = '>'
+ if leave_line > 0 then
+ -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
+ -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
+ -- foldminlines, but we don't handle it for simplicity.
+ adjusted = level0 - leave_line
+ leave_line = 0
+ end
end
- end
- -- Clamp at foldnestmax.
- local clamped = adjusted
- if adjusted > nestmax then
- prefix = ''
- clamped = nestmax
- end
+ -- Clamp at foldnestmax.
+ local clamped = adjusted
+ if adjusted > nestmax then
+ prefix = ''
+ clamped = nestmax
+ end
- -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
- info.levels0[lnum] = adjusted
- info.levels[lnum] = prefix .. tostring(clamped)
+ -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
+ info.levels0[lnum] = adjusted
+ info.levels[lnum] = prefix .. tostring(clamped)
- leave_prev = leave_line
- level0_prev = adjusted
- end
+ leave_prev = leave_line
+ level0_prev = adjusted
+ end
+
+ if callback then
+ callback()
+ end
+ end)
end
local M = {}
@@ -221,7 +192,7 @@ local M = {}
---@type table<integer,TS.FoldInfo>
local foldinfos = {}
-local group = api.nvim_create_augroup('treesitter/fold', {})
+local group = api.nvim_create_augroup('nvim.treesitter.fold', {})
--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
--- the user doesn't use different foldexpr for the same buffer).
@@ -298,12 +269,19 @@ local function schedule_if_loaded(bufnr, fn)
end
---@param bufnr integer
----@param foldinfo TS.FoldInfo
---@param tree_changes Range4[]
-local function on_changedtree(bufnr, foldinfo, tree_changes)
+local function on_changedtree(bufnr, tree_changes)
schedule_if_loaded(bufnr, function()
+ -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
+ local foldinfo = foldinfos[bufnr]
+ if not foldinfo then
+ return
+ end
+
local srow_upd, erow_upd ---@type integer?, integer?
local max_erow = api.nvim_buf_line_count(bufnr)
+ -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved
+ local iterations = 0
for _, change in ipairs(tree_changes) do
local srow, _, erow, ecol = Range.unpack4(change)
-- If a parser doesn't have any ranges explicitly set, treesitter will
@@ -317,24 +295,31 @@ local function on_changedtree(bufnr, foldinfo, tree_changes)
end
-- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
srow = math.max(srow - vim.wo.foldminlines, 0)
- compute_folds_levels(bufnr, foldinfo, srow, erow)
srow_upd = srow_upd and math.min(srow_upd, srow) or srow
erow_upd = erow_upd and math.max(erow_upd, erow) or erow
- end
- if #tree_changes > 0 then
- foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
+ compute_folds_levels(bufnr, foldinfo, srow, erow, function()
+ iterations = iterations + 1
+ if iterations == #tree_changes then
+ foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
+ end
+ end)
end
end)
end
---@param bufnr integer
----@param foldinfo TS.FoldInfo
---@param start_row integer
---@param old_row integer
---@param old_col integer
---@param new_row integer
---@param new_col integer
-local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col)
+local function on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
+ -- Buffer reload clears `foldinfos[bufnr]`, which may still be nil when callback is invoked.
+ local foldinfo = foldinfos[bufnr]
+ if not foldinfo then
+ return
+ end
+
-- extend the end to fully include the range
local end_row_old = start_row + old_row + 1
local end_row_new = start_row + new_row + 1
@@ -373,15 +358,16 @@ local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col,
-- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
-- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`.
schedule_if_loaded(bufnr, function()
- if not foldinfo.on_bytes_range then
+ if not (foldinfo.on_bytes_range and foldinfos[bufnr]) then
return
end
local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2]
foldinfo.on_bytes_range = nil
-- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
srow = math.max(srow - vim.wo.foldminlines, 0)
- compute_folds_levels(bufnr, foldinfo, srow, erow)
- foldinfo:foldupdate(bufnr, srow, erow)
+ compute_folds_levels(bufnr, foldinfo, srow, erow, function()
+ foldinfo:foldupdate(bufnr, srow, erow)
+ end)
end)
end
end
@@ -392,22 +378,30 @@ function M.foldexpr(lnum)
lnum = lnum or vim.v.lnum
local bufnr = api.nvim_get_current_buf()
- local parser = ts.get_parser(bufnr, nil, { error = false })
- if not parser then
- return '0'
- end
-
if not foldinfos[bufnr] then
- foldinfos[bufnr] = FoldInfo.new()
+ foldinfos[bufnr] = FoldInfo.new(bufnr)
+ api.nvim_create_autocmd({ 'BufUnload', 'VimEnter' }, {
+ buffer = bufnr,
+ once = true,
+ callback = function()
+ foldinfos[bufnr] = nil
+ end,
+ })
+
+ local parser = foldinfos[bufnr].parser
+ if not parser then
+ return '0'
+ end
+
compute_folds_levels(bufnr, foldinfos[bufnr])
parser:register_cbs({
on_changedtree = function(tree_changes)
- on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
+ on_changedtree(bufnr, tree_changes)
end,
on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _)
- on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col)
+ on_bytes(bufnr, start_row, start_col, old_row, old_col, new_row, new_col)
end,
on_detach = function()
@@ -423,10 +417,17 @@ api.nvim_create_autocmd('OptionSet', {
pattern = { 'foldminlines', 'foldnestmax' },
desc = 'Refresh treesitter folds',
callback = function()
- for bufnr, _ in pairs(foldinfos) do
- foldinfos[bufnr] = FoldInfo.new()
- compute_folds_levels(bufnr, foldinfos[bufnr])
- foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
+ local buf = api.nvim_get_current_buf()
+ local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos)
+ or foldinfos[buf] and { buf }
+ or {}
+ for _, bufnr in ipairs(bufs) do
+ foldinfos[bufnr] = FoldInfo.new(bufnr)
+ api.nvim_buf_call(bufnr, function()
+ compute_folds_levels(bufnr, foldinfos[bufnr], nil, nil, function()
+ foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
+ end)
+ end)
end
end,
})
diff --git a/runtime/lua/vim/treesitter/_meta/misc.lua b/runtime/lua/vim/treesitter/_meta/misc.lua
index 33701ef254..c532257f49 100644
--- a/runtime/lua/vim/treesitter/_meta/misc.lua
+++ b/runtime/lua/vim/treesitter/_meta/misc.lua
@@ -20,9 +20,15 @@ error('Cannot require a meta file')
---@class (exact) TSQueryInfo
---@field captures string[]
---@field patterns table<integer, (integer|string)[][]>
+---
+---@class TSLangInfo
+---@field fields string[]
+---@field symbols table<string,boolean>
+---@field _wasm boolean
+---@field _abi_version integer
--- @param lang string
---- @return table
+--- @return TSLangInfo
vim._ts_inspect_language = function(lang) end
---@return integer
diff --git a/runtime/lua/vim/treesitter/_meta/tsnode.lua b/runtime/lua/vim/treesitter/_meta/tsnode.lua
index d982b6a505..552905c3f0 100644
--- a/runtime/lua/vim/treesitter/_meta/tsnode.lua
+++ b/runtime/lua/vim/treesitter/_meta/tsnode.lua
@@ -68,12 +68,6 @@ function TSNode:named_child_count() end
--- @return TSNode?
function TSNode:named_child(index) end
---- Get the node's child that contains {descendant}.
---- @param descendant TSNode
---- @return TSNode?
---- @deprecated
-function TSNode:child_containing_descendant(descendant) end
-
--- Get the node's child that contains {descendant} (includes {descendant}).
---
--- For example, with the following node hierarchy:
@@ -109,17 +103,9 @@ function TSNode:end_() end
--- - end row
--- - end column
--- - end byte (if {include_bytes} is `true`)
---- @param include_bytes boolean?
-function TSNode:range(include_bytes) end
-
---- @nodoc
--- @param include_bytes false?
--- @return integer, integer, integer, integer
-function TSNode:range(include_bytes) end
-
---- @nodoc
---- @param include_bytes true
---- @return integer, integer, integer, integer, integer, integer
+--- @overload fun(self: TSNode, include_bytes: true): integer, integer, integer, integer, integer, integer
function TSNode:range(include_bytes) end
--- Get the node's type as a string.
diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua
index a825505378..3dfc6b0cfe 100644
--- a/runtime/lua/vim/treesitter/_query_linter.lua
+++ b/runtime/lua/vim/treesitter/_query_linter.lua
@@ -1,6 +1,6 @@
local api = vim.api
-local namespace = api.nvim_create_namespace('vim.treesitter.query_linter')
+local namespace = api.nvim_create_namespace('nvim.treesitter.query_linter')
local M = {}
@@ -138,7 +138,9 @@ local function lint_match(buf, match, query, lang_context, diagnostics)
-- perform language-independent checks only for first lang
if lang_context.is_first_lang and cap_id == 'error' then
local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
- add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
+ ---@diagnostic disable-next-line: missing-fields LuaLS varargs bug
+ local range = { node:range() } --- @type Range4
+ add_lint_for_node(diagnostics, range, 'Syntax error: ' .. node_text)
end
-- other checks rely on Neovim parser introspection
diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua
index 26817cdba5..24dd8243db 100644
--- a/runtime/lua/vim/treesitter/dev.lua
+++ b/runtime/lua/vim/treesitter/dev.lua
@@ -119,7 +119,7 @@ function TSTreeView:new(bufnr, lang)
end
local t = {
- ns = api.nvim_create_namespace('treesitter/dev-inspect'),
+ ns = api.nvim_create_namespace('nvim.treesitter.dev_inspect'),
nodes = nodes,
named = named,
---@type vim.treesitter.dev.TSTreeViewOpts
@@ -135,15 +135,7 @@ function TSTreeView:new(bufnr, lang)
return t
end
-local decor_ns = api.nvim_create_namespace('ts.dev')
-
----@param range Range4
----@return string
-local function range_to_string(range)
- ---@type integer, integer, integer, integer
- local row, col, end_row, end_col = unpack(range)
- return string.format('[%d, %d] - [%d, %d]', row, col, end_row, end_col)
-end
+local decor_ns = api.nvim_create_namespace('nvim.treesitter.dev')
---@param w integer
---@return boolean closed Whether the window was closed.
@@ -227,14 +219,17 @@ function TSTreeView:draw(bufnr)
local lang_hl_marks = {} ---@type table[]
for i, item in self:iter() do
- local range_str = range_to_string({ item.node:range() })
+ local range_str = ('[%d, %d] - [%d, %d]'):format(item.node:range())
local lang_str = self.opts.lang and string.format(' %s', item.lang) or ''
local text ---@type string
if item.node:named() then
- text = string.format('(%s', item.node:type())
+ text = string.format('(%s%s', item.node:missing() and 'MISSING ' or '', item.node:type())
else
text = string.format('%q', item.node:type()):gsub('\n', 'n')
+ if item.node:missing() then
+ text = string.format('(MISSING %s)', text)
+ end
end
if item.field then
text = string.format('%s: %s', item.field, text)
@@ -442,7 +437,7 @@ function M.inspect_tree(opts)
end,
})
- local group = api.nvim_create_augroup('treesitter/dev', {})
+ local group = api.nvim_create_augroup('nvim.treesitter.dev', {})
api.nvim_create_autocmd('CursorMoved', {
group = group,
@@ -547,7 +542,7 @@ function M.inspect_tree(opts)
})
end
-local edit_ns = api.nvim_create_namespace('treesitter/dev-edit')
+local edit_ns = api.nvim_create_namespace('nvim.treesitter.dev_edit')
---@param query_win integer
---@param base_win integer
@@ -633,7 +628,7 @@ function M.edit_query(lang)
-- can infer the language later.
api.nvim_buf_set_name(query_buf, string.format('%s/query_editor.scm', lang))
- local group = api.nvim_create_augroup('treesitter/dev-edit', {})
+ local group = api.nvim_create_augroup('nvim.treesitter.dev_edit', {})
api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, {
group = group,
buffer = query_buf,
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua
index 8ce8652f7d..6dd47811bd 100644
--- a/runtime/lua/vim/treesitter/highlighter.lua
+++ b/runtime/lua/vim/treesitter/highlighter.lua
@@ -2,7 +2,7 @@ local api = vim.api
local query = vim.treesitter.query
local Range = require('vim.treesitter._range')
-local ns = api.nvim_create_namespace('treesitter/highlighter')
+local ns = api.nvim_create_namespace('nvim.treesitter.highlighter')
---@alias vim.treesitter.highlighter.Iter fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch
@@ -69,6 +69,7 @@ end
---@field private _queries table<string,vim.treesitter.highlighter.Query>
---@field tree vim.treesitter.LanguageTree
---@field private redraw_count integer
+---@field parsing boolean true if we are parsing asynchronously
local TSHighlighter = {
active = {},
}
@@ -147,8 +148,6 @@ function TSHighlighter.new(tree, opts)
vim.opt_local.spelloptions:append('noplainbuffer')
end)
- self.tree:parse()
-
return self
end
@@ -161,7 +160,10 @@ function TSHighlighter:destroy()
vim.bo[self.bufnr].spelloptions = self.orig_spelloptions
vim.b[self.bufnr].ts_highlight = nil
if vim.g.syntax_on == 1 then
- api.nvim_exec_autocmds('FileType', { group = 'syntaxset', buffer = self.bufnr })
+ api.nvim_exec_autocmds(
+ 'FileType',
+ { group = 'syntaxset', buffer = self.bufnr, modeline = false }
+ )
end
end
end
@@ -299,6 +301,8 @@ local function on_line_impl(self, buf, line, is_spell_nav)
state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
end
+ local captures = state.highlighter_query:query().captures
+
while line >= state.next_row do
local capture, node, metadata, match = state.iter(line)
@@ -311,7 +315,7 @@ local function on_line_impl(self, buf, line, is_spell_nav)
if capture then
local hl = state.highlighter_query:get_hl_from_capture(capture)
- local capture_name = state.highlighter_query:query().captures[capture]
+ local capture_name = captures[capture]
local spell, spell_pri_offset = get_spell(capture_name)
@@ -382,19 +386,23 @@ function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _)
end
---@private
----@param _win integer
---@param buf integer
---@param topline integer
---@param botline integer
-function TSHighlighter._on_win(_, _win, buf, topline, botline)
+function TSHighlighter._on_win(_, _, buf, topline, botline)
local self = TSHighlighter.active[buf]
- if not self then
+ if not self or self.parsing then
return false
end
- self.tree:parse({ topline, botline + 1 })
- self:prepare_highlight_states(topline, botline + 1)
+ self.parsing = self.tree:parse({ topline, botline + 1 }, function(_, trees)
+ if trees and self.parsing then
+ self.parsing = false
+ api.nvim__redraw({ buf = buf, valid = false, flush = false })
+ end
+ end) == nil
self.redraw_count = self.redraw_count + 1
- return true
+ self:prepare_highlight_states(topline, botline)
+ return #self._highlight_states > 0
end
api.nvim_set_decoration_provider(ns, {
diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua
index 446051dfd7..16d19bfc5a 100644
--- a/runtime/lua/vim/treesitter/language.lua
+++ b/runtime/lua/vim/treesitter/language.lua
@@ -133,8 +133,9 @@ function M.add(lang, opts)
path = paths[1]
end
- return loadparser(path, lang, symbol_name) or nil,
- string.format('Cannot load parser %s for language "%s"', path, lang)
+ local res = loadparser(path, lang, symbol_name)
+ return res,
+ res == nil and string.format('Cannot load parser %s for language "%s"', path, lang) or nil
end
--- @param x string|string[]
@@ -174,7 +175,7 @@ end
--- (`"`).
---
---@param lang string Language
----@return table
+---@return TSLangInfo
function M.inspect(lang)
M.add(lang)
return vim._ts_inspect_language(lang)
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index 4b42164dc8..ea745c4deb 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -44,6 +44,8 @@ local query = require('vim.treesitter.query')
local language = require('vim.treesitter.language')
local Range = require('vim.treesitter._range')
+local default_parse_timeout_ms = 3
+
---@alias TSCallbackName
---| 'changedtree'
---| 'bytes'
@@ -58,6 +60,8 @@ local Range = require('vim.treesitter._range')
---| 'on_child_added'
---| 'on_child_removed'
+---@alias ParserThreadState { timeout: integer? }
+
--- @type table<TSCallbackNameOn,TSCallbackName>
local TSCallbackNames = {
on_changedtree = 'changedtree',
@@ -76,8 +80,13 @@ local TSCallbackNames = {
---@field private _injections_processed boolean
---@field private _opts table Options
---@field private _parser TSParser Parser for language
----@field private _has_regions boolean
+---Table of regions for which the tree is currently running an async parse
+---@field private _ranges_being_parsed table<string, boolean>
+---Table of callback queues, keyed by each region for which the callbacks should be run
+---@field private _cb_queues table<string, fun(err?: string, trees?: table<integer, TSTree>)[]>
---@field private _regions table<integer, Range6[]>?
+---The total number of regions. Since _regions can have holes, we cannot simply read this value from #_regions.
+---@field private _num_regions integer
---List of regions this tree should manage and parse. If nil then regions are
---taken from _trees. This is mostly a short-lived cache for included_regions()
---@field private _lang string Language name
@@ -85,7 +94,8 @@ local TSCallbackNames = {
---@field private _source (integer|string) Buffer or string to parse
---@field private _trees table<integer, TSTree> Reference to parsed tree (one for each language).
---Each key is the index of region, which is synced with _regions and _valid.
----@field private _valid boolean|table<integer,boolean> If the parsed tree is valid
+---@field private _valid_regions table<integer,true> Set of valid region IDs.
+---@field private _is_entirely_valid boolean Whether the entire tree (excluding children) is valid.
---@field private _logger? fun(logtype: string, msg: string)
---@field private _logfile? file*
local LanguageTree = {}
@@ -117,7 +127,7 @@ function LanguageTree.new(source, lang, opts)
local injections = opts.injections or {}
- --- @type vim.treesitter.LanguageTree
+ --- @class vim.treesitter.LanguageTree
local self = {
_source = source,
_lang = lang,
@@ -126,10 +136,13 @@ function LanguageTree.new(source, lang, opts)
_opts = opts,
_injection_query = injections[lang] and query.parse(lang, injections[lang])
or query.get(lang, 'injections'),
- _has_regions = false,
_injections_processed = false,
- _valid = false,
+ _valid_regions = {},
+ _num_regions = 1,
+ _is_entirely_valid = false,
_parser = vim._create_ts_parser(lang),
+ _ranges_being_parsed = {},
+ _cb_queues = {},
_callbacks = {},
_callbacks_rec = {},
}
@@ -182,7 +195,7 @@ end
---Measure execution time of a function
---@generic R1, R2, R3
----@param f fun(): R1, R2, R2
+---@param f fun(): R1, R2, R3
---@return number, R1, R2, R3
local function tcall(f, ...)
local start = vim.uv.hrtime()
@@ -190,6 +203,7 @@ local function tcall(f, ...)
local r = { f(...) }
--- @type number
local duration = (vim.uv.hrtime() - start) / 1000000
+ --- @diagnostic disable-next-line: redundant-return-value
return duration, unpack(r)
end
@@ -231,7 +245,9 @@ end
--- tree in treesitter. Doesn't clear filesystem cache. Called often, so needs to be fast.
---@param reload boolean|nil
function LanguageTree:invalidate(reload)
- self._valid = false
+ self._valid_regions = {}
+ self._is_entirely_valid = false
+ self._parser:reset()
-- buffer was reloaded, reparse all trees
if reload then
@@ -258,20 +274,51 @@ function LanguageTree:trees()
end
--- Gets the language of this tree node.
+--- @return string
function LanguageTree:lang()
return self._lang
end
+--- @param region Range6[]
+--- @param range? boolean|Range
+--- @return boolean
+local function intercepts_region(region, range)
+ if #region == 0 then
+ return true
+ end
+
+ if range == nil then
+ return false
+ end
+
+ if type(range) == 'boolean' then
+ return range
+ end
+
+ for _, r in ipairs(region) do
+ if Range.intercepts(r, range) then
+ return true
+ end
+ end
+
+ return false
+end
+
--- Returns whether this LanguageTree is valid, i.e., |LanguageTree:trees()| reflects the latest
--- state of the source. If invalid, user should call |LanguageTree:parse()|.
----@param exclude_children boolean|nil whether to ignore the validity of children (default `false`)
+---@param exclude_children boolean? whether to ignore the validity of children (default `false`)
+---@param range Range? range to check for validity
---@return boolean
-function LanguageTree:is_valid(exclude_children)
- local valid = self._valid
+function LanguageTree:is_valid(exclude_children, range)
+ local valid_regions = self._valid_regions
- if type(valid) == 'table' then
- for i, _ in pairs(self:included_regions()) do
- if not valid[i] then
+ if not self._is_entirely_valid then
+ if not range then
+ return false
+ end
+ -- TODO: Efficiently search for possibly intersecting regions using a binary search
+ for i, region in pairs(self:included_regions()) do
+ if not valid_regions[i] and intercepts_region(region, range) then
return false
end
end
@@ -283,97 +330,81 @@ function LanguageTree:is_valid(exclude_children)
end
for _, child in pairs(self._children) do
- if not child:is_valid(exclude_children) then
+ if not child:is_valid(exclude_children, range) then
return false
end
end
end
- if type(valid) == 'boolean' then
- return valid
- end
-
- self._valid = true
return true
end
--- Returns a map of language to child tree.
+--- @return table<string,vim.treesitter.LanguageTree>
function LanguageTree:children()
return self._children
end
--- Returns the source content of the language tree (bufnr or string).
+--- @return integer|string
function LanguageTree:source()
return self._source
end
---- @param region Range6[]
---- @param range? boolean|Range
---- @return boolean
-local function intercepts_region(region, range)
- if #region == 0 then
- return true
- end
-
- if range == nil then
- return false
- end
-
- if type(range) == 'boolean' then
- return range
- end
-
- for _, r in ipairs(region) do
- if Range.intercepts(r, range) then
- return true
- end
- end
-
- return false
-end
-
--- @private
--- @param range boolean|Range?
+--- @param thread_state ParserThreadState
--- @return Range6[] changes
--- @return integer no_regions_parsed
--- @return number total_parse_time
-function LanguageTree:_parse_regions(range)
+--- @return boolean finished whether parsing completed (`false` if async parsing still needs time)
+function LanguageTree:_parse_regions(range, thread_state)
local changes = {}
local no_regions_parsed = 0
local total_parse_time = 0
- if type(self._valid) ~= 'table' then
- self._valid = {}
- end
-
-- If there are no ranges, set to an empty list
-- so the included ranges in the parser are cleared.
for i, ranges in pairs(self:included_regions()) do
if
- not self._valid[i]
+ not self._valid_regions[i]
and (
intercepts_region(ranges, range)
or (self._trees[i] and intercepts_region(self._trees[i]:included_ranges(false), range))
)
then
self._parser:set_included_ranges(ranges)
+ self._parser:set_timeout(thread_state.timeout and thread_state.timeout * 1000 or 0) -- ms -> micros
+
local parse_time, tree, tree_changes =
tcall(self._parser.parse, self._parser, self._trees[i], self._source, true)
+ while true do
+ if tree then
+ break
+ end
+ coroutine.yield(changes, no_regions_parsed, total_parse_time, false)
- -- Pass ranges if this is an initial parse
- local cb_changes = self._trees[i] and tree_changes or tree:included_ranges(true)
+ parse_time, tree, tree_changes =
+ tcall(self._parser.parse, self._parser, self._trees[i], self._source, true)
+ end
- self:_do_callback('changedtree', cb_changes, tree)
+ self:_do_callback('changedtree', tree_changes, tree)
self._trees[i] = tree
vim.list_extend(changes, tree_changes)
total_parse_time = total_parse_time + parse_time
no_regions_parsed = no_regions_parsed + 1
- self._valid[i] = true
+ self._valid_regions[i] = true
+
+ -- NOTE(review): _valid_regions can have holes, and `#` on a table with holes may return any
+ -- border, so this equality is not a guaranteed "all regions valid" check; consider a counter.
+ if #self._valid_regions == self._num_regions then
+ self._is_entirely_valid = true
+ end
end
end
- return changes, no_regions_parsed, total_parse_time
+ return changes, no_regions_parsed, total_parse_time, true
end
--- @private
@@ -409,6 +440,98 @@ function LanguageTree:_add_injections()
return query_time
end
+--- @param range boolean|Range?
+--- @return string
+local function range_to_string(range)
+ return type(range) == 'table' and table.concat(range, ',') or tostring(range)
+end
+
+--- @private
+--- @param range boolean|Range?
+--- @param callback fun(err?: string, trees?: table<integer, TSTree>)
+function LanguageTree:_push_async_callback(range, callback)
+ local key = range_to_string(range)
+ self._cb_queues[key] = self._cb_queues[key] or {}
+ local queue = self._cb_queues[key]
+ queue[#queue + 1] = callback
+end
+
+--- @private
+--- @param range boolean|Range?
+--- @param err? string
+--- @param trees? table<integer, TSTree>
+function LanguageTree:_run_async_callbacks(range, err, trees)
+ local key = range_to_string(range)
+ for _, cb in ipairs(self._cb_queues[key]) do
+ cb(err, trees)
+ end
+ self._ranges_being_parsed[key] = nil
+ self._cb_queues[key] = nil
+end
+
+--- Run an asynchronous parse, calling {on_parse} when complete.
+---
+--- @private
+--- @param range boolean|Range?
+--- @param on_parse fun(err?: string, trees?: table<integer, TSTree>)
+--- @return table<integer, TSTree>? trees the list of parsed trees, if parsing completed synchronously
+function LanguageTree:_async_parse(range, on_parse)
+ self:_push_async_callback(range, on_parse)
+
+ -- If we are already running an async parse, just queue the callback.
+ local range_string = range_to_string(range)
+ if not self._ranges_being_parsed[range_string] then
+ self._ranges_being_parsed[range_string] = true
+ else
+ return
+ end
+
+ local source = self._source
+ local is_buffer_parser = type(source) == 'number'
+ local buf = is_buffer_parser and vim.b[source] or nil
+ local ct = is_buffer_parser and buf.changedtick or nil
+ local total_parse_time = 0
+ local redrawtime = vim.o.redrawtime
+
+ local thread_state = {} ---@type ParserThreadState
+
+ ---@type fun(): table<integer, TSTree>, boolean
+ local parse = coroutine.wrap(self._parse)
+
+ local function step()
+ if is_buffer_parser then
+ if
+ not vim.api.nvim_buf_is_valid(source --[[@as number]])
+ then
+ return nil
+ end
+
+ -- If buffer was changed in the middle of parsing, reset parse state
+ if buf.changedtick ~= ct then
+ ct = buf.changedtick
+ total_parse_time = 0
+ parse = coroutine.wrap(self._parse)
+ end
+ end
+
+ thread_state.timeout = not vim.g._ts_force_sync_parsing and default_parse_timeout_ms or nil
+ local parse_time, trees, finished = tcall(parse, self, range, thread_state)
+ total_parse_time = total_parse_time + parse_time
+
+ if finished then
+ self:_run_async_callbacks(range, nil, trees)
+ return trees
+ elseif total_parse_time > redrawtime then
+ self:_run_async_callbacks(range, 'TIMEOUT', nil)
+ return nil
+ else
+ vim.schedule(step)
+ end
+ end
+
+ return step()
+end
+
--- Recursively parse all regions in the language tree using |treesitter-parsers|
--- for the corresponding languages and run injection queries on the parsed trees
--- to determine whether child trees should be created and parsed.
@@ -420,11 +543,33 @@ end
--- Set to `true` to run a complete parse of the source (Note: Can be slow!)
--- Set to `false|nil` to only parse regions with empty ranges (typically
--- only the root tree without injections).
---- @return table<integer, TSTree>
-function LanguageTree:parse(range)
- if self:is_valid() then
+--- @param on_parse fun(err?: string, trees?: table<integer, TSTree>)? Function invoked when parsing completes.
+--- When provided and `vim.g._ts_force_sync_parsing` is not set, parsing will run
+--- asynchronously. The first argument to the function is a string representing the error type,
+--- in case of a failure (currently only possible for timeouts). The second argument is the list
+--- of trees returned by the parse (upon success), or `nil` if the parse timed out (determined
+--- by 'redrawtime').
+---
+--- If parsing was still able to finish synchronously (within 3ms), `parse()` returns the list
+--- of trees. Otherwise, it returns `nil`.
+--- @return table<integer, TSTree>?
+function LanguageTree:parse(range, on_parse)
+ if on_parse then
+ return self:_async_parse(range, on_parse)
+ end
+ local trees, _ = self:_parse(range, {})
+ return trees
+end
+
+--- @private
+--- @param range boolean|Range|nil
+--- @param thread_state ParserThreadState
+--- @return table<integer, TSTree> trees
+--- @return boolean finished
+function LanguageTree:_parse(range, thread_state)
+ if self:is_valid(nil, type(range) == 'table' and range or nil) then
self:_log('valid')
- return self._trees
+ return self._trees, true
end
local changes --- @type Range6[]?
@@ -435,15 +580,27 @@ function LanguageTree:parse(range)
local total_parse_time = 0
-- At least 1 region is invalid
- if not self:is_valid(true) then
- changes, no_regions_parsed, total_parse_time = self:_parse_regions(range)
+ if not self:is_valid(true, type(range) == 'table' and range or nil) then
+ ---@type fun(self: vim.treesitter.LanguageTree, range: boolean|Range?, thread_state: ParserThreadState): Range6[], integer, number, boolean
+ local parse_regions = coroutine.wrap(self._parse_regions)
+ while true do
+ local is_finished
+ changes, no_regions_parsed, total_parse_time, is_finished =
+ parse_regions(self, range, thread_state)
+ thread_state.timeout = thread_state.timeout
+ and math.max(thread_state.timeout - total_parse_time, 0)
+ if is_finished then
+ break
+ end
+ coroutine.yield(self._trees, false)
+ end
-- Need to run injections when we parsed something
if no_regions_parsed > 0 then
self._injections_processed = false
end
end
- if not self._injections_processed and range ~= false and range ~= nil then
+ if not self._injections_processed and range then
query_time = self:_add_injections()
self._injections_processed = true
end
@@ -457,10 +614,24 @@ function LanguageTree:parse(range)
})
for _, child in pairs(self._children) do
- child:parse(range)
+ if thread_state.timeout == 0 then
+ coroutine.yield(self._trees, false)
+ end
+
+ ---@type fun(): table<integer, TSTree>, boolean
+ local parse = coroutine.wrap(child._parse)
+
+ while true do
+ local ctime, _, child_finished = tcall(parse, child, range, thread_state)
+ if child_finished then
+ thread_state.timeout = thread_state.timeout and math.max(thread_state.timeout - ctime, 0)
+ break
+ end
+ coroutine.yield(self._trees, child_finished)
+ end
end
- return self._trees
+ return self._trees, true
end
--- Invokes the callback for each |LanguageTree| recursively.
@@ -504,7 +675,8 @@ function LanguageTree:add_child(lang)
return self._children[lang]
end
---- @package
+---Returns the parent tree. `nil` for the root tree.
+---@return vim.treesitter.LanguageTree?
function LanguageTree:parent()
return self._parent
end
@@ -551,38 +723,34 @@ end
---region is valid or not.
---@param fn fun(index: integer, region: Range6[]): boolean
function LanguageTree:_iter_regions(fn)
- if not self._valid then
+ if vim.deep_equal(self._valid_regions, {}) then
return
end
- local was_valid = type(self._valid) ~= 'table'
-
- if was_valid then
- self:_log('was valid', self._valid)
- self._valid = {}
+ if self._is_entirely_valid then
+ self:_log('was valid')
end
local all_valid = true
for i, region in pairs(self:included_regions()) do
- if was_valid or self._valid[i] then
- self._valid[i] = fn(i, region)
- if not self._valid[i] then
+ if self._valid_regions[i] then
+ -- Setting this to nil rather than false allows us to determine if all regions were parsed
+ -- just by checking the length of _valid_regions.
+ self._valid_regions[i] = fn(i, region) and true or nil
+ if not self._valid_regions[i] then
self:_log(function()
return 'invalidating region', i, region_tostr(region)
end)
end
end
- if not self._valid[i] then
+ if not self._valid_regions[i] then
all_valid = false
end
end
- -- Compress the valid value to 'true' if there are no invalid regions
- if all_valid then
- self._valid = all_valid
- end
+ self._is_entirely_valid = all_valid
end
--- Sets the included regions that should be parsed by this |LanguageTree|.
@@ -602,14 +770,13 @@ end
---@private
---@param new_regions (Range4|Range6|TSNode)[][] List of regions this tree should manage and parse.
function LanguageTree:set_included_regions(new_regions)
- self._has_regions = true
-
-- Transform the tables from 4 element long to 6 element long (with byte offset)
for _, region in ipairs(new_regions) do
for i, range in ipairs(region) do
if type(range) == 'table' and #range == 4 then
region[i] = Range.add_bytes(self._source, range --[[@as Range4]])
elseif type(range) == 'userdata' then
+ --- @diagnostic disable-next-line: missing-fields LuaLS varargs bug
region[i] = { range:range(true) }
end
end
@@ -633,6 +800,7 @@ function LanguageTree:set_included_regions(new_regions)
end
self._regions = new_regions
+ self._num_regions = #new_regions
end
---Gets the set of included regions managed by this LanguageTree. This can be different from the
@@ -646,18 +814,8 @@ function LanguageTree:included_regions()
return self._regions
end
- if not self._has_regions then
- -- treesitter.c will default empty ranges to { -1, -1, -1, -1, -1, -1} (the full range)
- return { {} }
- end
-
- local regions = {} ---@type Range6[][]
- for i, _ in pairs(self._trees) do
- regions[i] = self._trees[i]:included_ranges(true)
- end
-
- self._regions = regions
- return regions
+ -- treesitter.c will default empty ranges to { -1, -1, -1, -1, -1, -1} (the full range)
+ return { {} }
end
---@param node TSNode
@@ -821,7 +979,7 @@ end
--- @private
--- @return table<string, Range6[][]>
function LanguageTree:_get_injections()
- if not self._injection_query then
+ if not self._injection_query or #self._injection_query.captures == 0 then
return {}
end
@@ -907,7 +1065,15 @@ function LanguageTree:_edit(
)
end
- self._regions = nil
+ self._parser:reset()
+
+ if self._regions then
+ local regions = {} ---@type table<integer, Range6[]>
+ for i, tree in pairs(self._trees) do
+ regions[i] = tree:included_ranges(true)
+ end
+ self._regions = regions
+ end
local changed_range = {
start_row,
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 1677e8d364..10fb82e533 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,17 +1,77 @@
+--- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse
+--- text. See |vim.treesitter.query.parse()| for a working example.
+
local api = vim.api
local language = require('vim.treesitter.language')
local memoize = vim.func._memoize
+local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
+local EXTENDS_FORMAT = '^;+%s*extends%s*$'
+
local M = {}
+local function is_directive(name)
+ return string.sub(name, -1) == '!'
+end
+
+---@nodoc
+---@class vim.treesitter.query.ProcessedPredicate
+---@field [1] string predicate name
+---@field [2] boolean should match
+---@field [3] (integer|string)[] the original predicate
+
+---@alias vim.treesitter.query.ProcessedDirective (integer|string)[]
+
+---@nodoc
+---@class vim.treesitter.query.ProcessedPattern
+---@field predicates vim.treesitter.query.ProcessedPredicate[]
+---@field directives vim.treesitter.query.ProcessedDirective[]
+
+--- Splits the query patterns into predicates and directives.
+---@param patterns table<integer, (integer|string)[][]>
+---@return table<integer, vim.treesitter.query.ProcessedPattern>
+local function process_patterns(patterns)
+ ---@type table<integer, vim.treesitter.query.ProcessedPattern>
+ local processed_patterns = {}
+
+ for k, pattern_list in pairs(patterns) do
+ ---@type vim.treesitter.query.ProcessedPredicate[]
+ local predicates = {}
+ ---@type vim.treesitter.query.ProcessedDirective[]
+ local directives = {}
+
+ for _, pattern in ipairs(pattern_list) do
+ -- Note: tree-sitter strips the leading # from predicates for us.
+ local pred_name = pattern[1]
+ ---@cast pred_name string
+
+ if is_directive(pred_name) then
+ table.insert(directives, pattern)
+ else
+ local should_match = true
+ if pred_name:match('^not%-') then
+ pred_name = pred_name:sub(5)
+ should_match = false
+ end
+ table.insert(predicates, { pred_name, should_match, pattern })
+ end
+ end
+
+ processed_patterns[k] = { predicates = predicates, directives = directives }
+ end
+
+ return processed_patterns
+end
+
---@nodoc
---Parsed query, see |vim.treesitter.query.parse()|
---
---@class vim.treesitter.Query
----@field lang string name of the language for this parser
+---@field lang string parser language name
---@field captures string[] list of (unique) capture names defined in query
----@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives)
+---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
---@field query TSQuery userdata query object
+---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern>
local Query = {}
Query.__index = Query
@@ -30,6 +90,7 @@ function Query.new(lang, ts_query)
patterns = query_info.patterns,
}
self.captures = self.info.captures
+ self._processed_patterns = process_patterns(self.info.patterns)
return self
end
@@ -109,9 +170,6 @@ function M.get_files(lang, query_name, is_included)
-- ;+ inherits: ({language},)*{language}
--
-- {language} ::= {lang} | ({lang})
- local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
- local EXTENDS_FORMAT = '^;+%s*extends%s*$'
-
for _, filename in ipairs(lang_files) do
local file, err = io.open(filename, 'r')
if not file then
@@ -184,8 +242,8 @@ local function read_query_files(filenames)
return table.concat(contents, '')
end
--- The explicitly set queries from |vim.treesitter.query.set()|
----@type table<string,table<string,vim.treesitter.Query>>
+-- The explicitly set query strings from |vim.treesitter.query.set()|
+---@type table<string,table<string,string>>
local explicit_queries = setmetatable({}, {
__index = function(t, k)
local lang_queries = {}
@@ -197,14 +255,27 @@ local explicit_queries = setmetatable({}, {
--- Sets the runtime query named {query_name} for {lang}
---
---- This allows users to override any runtime files and/or configuration
+--- This allows users to override or extend any runtime files and/or configuration
--- set by plugins.
---
+--- For example, you could enable spellchecking of `C` identifiers with the
+--- following code:
+--- ```lua
+--- vim.treesitter.query.set(
+--- 'c',
+--- 'highlights',
+--- [[;inherits c
+--- (identifier) @spell]]
+--- )
+--- ```
+---
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g., "highlights")
---@param text string Query text (unparsed).
function M.set(lang, query_name, text)
- explicit_queries[lang][query_name] = M.parse(lang, text)
+ --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
+ M.get:clear(lang, query_name)
+ explicit_queries[lang][query_name] = text
end
--- Returns the runtime query {query_name} for {lang}.
@@ -214,34 +285,82 @@ end
---
---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found.
M.get = memoize('concat-2', function(lang, query_name)
+ local query_string ---@type string
+
if explicit_queries[lang][query_name] then
- return explicit_queries[lang][query_name]
- end
+ local query_files = {}
+ local base_langs = {} ---@type string[]
- local query_files = M.get_files(lang, query_name)
- local query_string = read_query_files(query_files)
+ for line in explicit_queries[lang][query_name]:gmatch('([^\n]*)\n?') do
+ if not vim.startswith(line, ';') then
+ break
+ end
+
+ local lang_list = line:match(MODELINE_FORMAT)
+ if lang_list then
+ for _, incl_lang in ipairs(vim.split(lang_list, ',')) do
+ local is_optional = incl_lang:match('%(.*%)')
+
+ if is_optional then
+ add_included_lang(base_langs, lang, incl_lang:sub(2, #incl_lang - 1))
+ else
+ add_included_lang(base_langs, lang, incl_lang)
+ end
+ end
+ elseif line:match(EXTENDS_FORMAT) then
+ table.insert(base_langs, lang)
+ end
+ end
+
+ for _, base_lang in ipairs(base_langs) do
+ local base_files = M.get_files(base_lang, query_name, true)
+ vim.list_extend(query_files, base_files)
+ end
+
+ query_string = read_query_files(query_files) .. explicit_queries[lang][query_name]
+ else
+ local query_files = M.get_files(lang, query_name)
+ query_string = read_query_files(query_files)
+ end
if #query_string == 0 then
return nil
end
return M.parse(lang, query_string)
-end)
+end, false)
+
+api.nvim_create_autocmd('OptionSet', {
+ pattern = { 'runtimepath' },
+ group = api.nvim_create_augroup('nvim.treesitter.query_cache_reset', { clear = true }),
+ callback = function()
+ --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
+ M.get:clear()
+ end,
+})
---- Parse {query} as a string. (If the query is in a file, the caller
---- should read the contents into a string before calling).
----
---- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
---- search nodes in the syntax tree for the patterns defined in {query}
---- using the `iter_captures` and `iter_matches` methods.
+--- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used
+--- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|),
+--- or inspect the query via these fields:
+--- - `captures`: a list of unique capture names defined in the query (alias: `info.captures`).
+--- - `info.patterns`: information about predicates.
---
---- Exposes `info` and `captures` with additional context about {query}.
---- - `captures` contains the list of unique capture names defined in {query}.
---- - `info.captures` also points to `captures`.
---- - `info.patterns` contains information about predicates.
+--- Example:
+--- ```lua
+--- local query = vim.treesitter.query.parse('vimdoc', [[
+--- ; query
+--- ((h1) @str
+--- (#trim! @str 1 1 1 1))
+--- ]])
+--- local tree = vim.treesitter.get_parser():parse()[1]
+--- for id, node, metadata in query:iter_captures(tree:root(), 0) do
+--- -- Print the node name and source text.
+--- vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())})
+--- end
+--- ```
---
---@param lang string Language to use for the query
----@param query string Query in s-expr syntax
+---@param query string Query text, in s-expr syntax
---
---@return vim.treesitter.Query : Parsed query
---
@@ -250,7 +369,7 @@ M.parse = memoize('concat-2', function(lang, query)
assert(language.add(lang))
local ts_query = vim._ts_parse_query(lang, query)
return Query.new(lang, ts_query)
-end)
+end, false)
--- Implementations of predicates that can optionally be prefixed with "any-".
---
@@ -572,13 +691,17 @@ local directive_handlers = {
metadata[id].text = text:gsub(pattern, replacement)
end,
- -- Trim blank lines from end of the node
- -- Example: (#trim! @fold)
- -- TODO(clason): generalize to arbitrary whitespace removal
+ -- Trim whitespace from both sides of the node
+ -- Example: (#trim! @fold 1 1 1 1)
['trim!'] = function(match, _, bufnr, pred, metadata)
local capture_id = pred[2]
assert(type(capture_id) == 'number')
+ local trim_start_lines = pred[3] == '1'
+ local trim_start_cols = pred[4] == '1'
+ local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility
+ local trim_end_cols = pred[6] == '1'
+
local nodes = match[capture_id]
if not nodes or #nodes == 0 then
return
@@ -588,20 +711,45 @@ local directive_handlers = {
local start_row, start_col, end_row, end_col = node:range()
- -- Don't trim if region ends in middle of a line
- if end_col ~= 0 then
- return
+ local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n')
+ if end_col == 0 then
+ -- get_node_text() will ignore the last line if the node ends at column 0
+ node_text[#node_text + 1] = ''
end
- while end_row >= start_row do
- -- As we only care when end_col == 0, always inspect one line above end_row.
- local end_line = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1]
+ local end_idx = #node_text
+ local start_idx = 1
- if end_line ~= '' then
- break
+ if trim_end_lines then
+ while end_idx > 0 and node_text[end_idx]:find('^%s*$') do
+ end_idx = end_idx - 1
+ end_row = end_row - 1
+ -- set the end position to the last column of the preceding line (the new last line), or 0 if
+ -- no lines remain
+ end_col = end_idx > 0 and #node_text[end_idx] or 0
end
+ end
+ if trim_end_cols then
+ if end_idx == 0 then
+ end_row = start_row
+ end_col = start_col
+ else
+ local whitespace_start = node_text[end_idx]:find('(%s*)$')
+ end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0)
+ end
+ end
- end_row = end_row - 1
+ if trim_start_lines then
+ while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do
+ start_idx = start_idx + 1
+ start_row = start_row + 1
+ start_col = 0
+ end
+ end
+ if trim_start_cols and node_text[start_idx] then
+ local _, whitespace_end = node_text[start_idx]:find('^(%s*)')
+ whitespace_end = whitespace_end or 0
+ start_col = (start_idx == 1 and start_col or 0) + whitespace_end
end
-- If this produces an invalid range, we just skip it.
@@ -711,84 +859,50 @@ function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
end
-local function xor(x, y)
- return (x or y) and not (x and y)
-end
-
-local function is_directive(name)
- return string.sub(name, -1) == '!'
-end
-
---@private
----@param match TSQueryMatch
+---@param pattern_i integer
+---@param predicates vim.treesitter.query.ProcessedPredicate[]
+---@param captures table<integer, TSNode[]>
---@param source integer|string
-function Query:match_preds(match, source)
- local _, pattern = match:info()
- local preds = self.info.patterns[pattern]
-
- if not preds then
- return true
- end
-
- local captures = match:captures()
-
- for _, pred in pairs(preds) do
- -- Here we only want to return if a predicate DOES NOT match, and
- -- continue on the other case. This way unknown predicates will not be considered,
- -- which allows some testing and easier user extensibility (#12173).
- -- Also, tree-sitter strips the leading # from predicates for us.
- local is_not = false
-
- -- Skip over directives... they will get processed after all the predicates.
- if not is_directive(pred[1]) then
- local pred_name = pred[1]
- if pred_name:match('^not%-') then
- pred_name = pred_name:sub(5)
- is_not = true
- end
-
- local handler = predicate_handlers[pred_name]
-
- if not handler then
- error(string.format('No handler for %s', pred[1]))
- return false
- end
-
- local pred_matches = handler(captures, pattern, source, pred)
+---@return boolean whether the predicates match
+function Query:_match_predicates(predicates, pattern_i, captures, source)
+ for _, predicate in ipairs(predicates) do
+ local processed_name = predicate[1]
+ local should_match = predicate[2]
+ local orig_predicate = predicate[3]
+
+ local handler = predicate_handlers[processed_name]
+ if not handler then
+ error(string.format('No handler for %s', orig_predicate[1]))
+ return false
+ end
- if not xor(is_not, pred_matches) then
- return false
- end
+ local does_match = handler(captures, pattern_i, source, orig_predicate)
+ if does_match ~= should_match then
+ return false
end
end
return true
end
---@private
----@param match TSQueryMatch
+---@param pattern_i integer
+---@param directives vim.treesitter.query.ProcessedDirective[]
+---@param source integer|string
+---@param captures table<integer, TSNode[]>
---@return vim.treesitter.query.TSMetadata metadata
-function Query:apply_directives(match, source)
+function Query:_apply_directives(directives, pattern_i, captures, source)
---@type vim.treesitter.query.TSMetadata
local metadata = {}
- local _, pattern = match:info()
- local preds = self.info.patterns[pattern]
-
- if not preds then
- return metadata
- end
- local captures = match:captures()
-
- for _, pred in pairs(preds) do
- if is_directive(pred[1]) then
- local handler = directive_handlers[pred[1]]
-
- if not handler then
- error(string.format('No handler for %s', pred[1]))
- end
+ for _, directive in pairs(directives) do
+ local handler = directive_handlers[directive[1]]
- handler(captures, pattern, source, pred, metadata)
+ if not handler then
+ error(string.format('No handler for %s', directive[1]))
end
+
+ handler(captures, pattern_i, source, directive, metadata)
end
return metadata
@@ -812,26 +926,22 @@ local function value_or_node_range(start, stop, node)
return start, stop
end
---- @param match TSQueryMatch
---- @return integer
-local function match_id_hash(_, match)
- return (match:info())
-end
-
---- Iterate over all captures from all matches inside {node}
+--- Iterates over all captures from all matches in {node}.
---
---- {source} is needed if the query contains predicates; then the caller
+--- {source} is required if the query contains predicates; then the caller
--- must ensure to use a freshly parsed tree consistent with the current
--- text of the buffer (if relevant). {start} and {stop} can be used to limit
--- matches inside a row range (this is typically used with root node
--- as the {node}, i.e., to get syntax highlight matches in the current
--- viewport). When omitted, the {start} and {stop} row values are used from the given node.
---
---- The iterator returns four values: a numeric id identifying the capture,
---- the captured node, metadata from any directives processing the match,
---- and the match itself.
---- The following example shows how to get captures by name:
+--- The iterator returns four values:
+--- 1. the numeric id identifying the capture
+--- 2. the captured node
+--- 3. metadata from any directives processing the match
+--- 4. the match itself
---
+--- Example: how to get captures by name:
--- ```lua
--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do
--- local name = query.captures[id] -- name of the capture in the query
@@ -847,8 +957,8 @@ end
---@param start? integer Starting line for the search. Defaults to `node:start()`.
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
---
----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch):
---- capture id, capture node, metadata, match
+---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch, TSTree):
+--- capture id, capture node, metadata, match, tree
---
---@note Captures are only returned if the query pattern of a specific capture contained predicates.
function Query:iter_captures(node, source, start, stop)
@@ -858,10 +968,14 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
+ -- Copy the tree to ensure it is valid during the entire lifetime of the iterator
+ local tree = node:tree():copy()
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
- local apply_directives = memoize(match_id_hash, self.apply_directives, true)
- local match_preds = memoize(match_id_hash, self.match_preds, true)
+ -- For faster checks that a match is not in the cache.
+ local highest_cached_match_id = -1
+ ---@type table<integer, vim.treesitter.query.TSMetadata>
+ local match_cache = {}
local function iter(end_line)
local capture, captured_node, match = cursor:next_capture()
@@ -870,18 +984,39 @@ function Query:iter_captures(node, source, start, stop)
return
end
- if not match_preds(self, match, source) then
- local match_id = match:info()
- cursor:remove_match(match_id)
- if end_line and captured_node:range() > end_line then
- return nil, captured_node, nil, nil
- end
- return iter(end_line) -- tail call: try next match
+ local match_id, pattern_i = match:info()
+
+ --- @type vim.treesitter.query.TSMetadata
+ local metadata
+ if match_id <= highest_cached_match_id then
+ metadata = match_cache[match_id]
end
- local metadata = apply_directives(self, match, source)
+ if not metadata then
+ metadata = {}
+
+ local processed_pattern = self._processed_patterns[pattern_i]
+ if processed_pattern then
+ local captures = match:captures()
- return capture, captured_node, metadata, match
+ local predicates = processed_pattern.predicates
+ if not self:_match_predicates(predicates, pattern_i, captures, source) then
+ cursor:remove_match(match_id)
+ if end_line and captured_node:range() > end_line then
+ return nil, captured_node, nil, nil
+ end
+ return iter(end_line) -- tail call: try next match
+ end
+
+ local directives = processed_pattern.directives
+ metadata = self:_apply_directives(directives, pattern_i, captures, source)
+ end
+
+ highest_cached_match_id = math.max(highest_cached_match_id, match_id)
+ match_cache[match_id] = metadata
+ end
+
+ return capture, captured_node, metadata, match, tree
end
return iter
end
@@ -903,7 +1038,7 @@ end
--- -- `node` was captured by the `name` capture in the match
---
--- local node_data = metadata[id] -- Node level metadata
---- ... use the info here ...
+--- -- ... use the info here ...
--- end
--- end
--- end
@@ -922,7 +1057,7 @@ end
--- (last) node instead of the full list of matching nodes. This option is only for backward
--- compatibility and will be removed in a future release.
---
----@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata): pattern id, match, metadata
+---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata, TSTree): pattern id, match, metadata, tree
function Query:iter_matches(node, source, start, stop, opts)
opts = opts or {}
opts.match_limit = opts.match_limit or 256
@@ -933,6 +1068,8 @@ function Query:iter_matches(node, source, start, stop, opts)
start, stop = value_or_node_range(start, stop, node)
+ -- Copy the tree to ensure it is valid during the entire lifetime of the iterator
+ local tree = node:tree():copy()
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
local function iter()
@@ -942,17 +1079,22 @@ function Query:iter_matches(node, source, start, stop, opts)
return
end
- local match_id, pattern = match:info()
+ local match_id, pattern_i = match:info()
+ local processed_pattern = self._processed_patterns[pattern_i]
+ local captures = match:captures()
- if not self:match_preds(match, source) then
- cursor:remove_match(match_id)
- return iter() -- tail call: try next match
+ --- @type vim.treesitter.query.TSMetadata
+ local metadata = {}
+ if processed_pattern then
+ local predicates = processed_pattern.predicates
+ if not self:_match_predicates(predicates, pattern_i, captures, source) then
+ cursor:remove_match(match_id)
+ return iter() -- tail call: try next match
+ end
+ local directives = processed_pattern.directives
+ metadata = self:_apply_directives(directives, pattern_i, captures, source)
end
- local metadata = self:apply_directives(match, source)
-
- local captures = match:captures()
-
if opts.all == false then
-- Convert the match table into the old buggy version for backward
-- compatibility. This is slow, but we only do it when the caller explicitly opted into it by
@@ -961,11 +1103,11 @@ function Query:iter_matches(node, source, start, stop, opts)
for k, v in pairs(captures or {}) do
old_match[k] = v[#v]
end
- return pattern
+ -- Return the tree on this path too, so both return paths match the documented iterator signature.
+ return pattern_i, old_match, metadata, tree
end
-- TODO(lewis6991): create a new function that returns {match, metadata}
- return pattern, captures, metadata
+ return pattern_i, captures, metadata, tree
end
return iter
end