aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter
diff options
context:
space:
mode:
authorJosh Rahm <joshuarahm@gmail.com>2023-11-29 21:52:58 +0000
committerJosh Rahm <joshuarahm@gmail.com>2023-11-29 21:52:58 +0000
commit931bffbda3668ddc609fc1da8f9eb576b170aa52 (patch)
treed8c1843a95da5ea0bb4acc09f7e37843d9995c86 /runtime/lua/vim/treesitter
parent142d9041391780ac15b89886a54015fdc5c73995 (diff)
parent4a8bf24ac690004aedf5540fa440e788459e5e34 (diff)
downloadrneovim-userreg.tar.gz
rneovim-userreg.tar.bz2
rneovim-userreg.zip
Merge remote-tracking branch 'upstream/master' into userreguserreg
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua456
-rw-r--r--runtime/lua/vim/treesitter/_meta.lua80
-rw-r--r--runtime/lua/vim/treesitter/_query_linter.lua249
-rw-r--r--runtime/lua/vim/treesitter/_range.lua193
-rw-r--r--runtime/lua/vim/treesitter/dev.lua645
-rw-r--r--runtime/lua/vim/treesitter/health.lua30
-rw-r--r--runtime/lua/vim/treesitter/highlighter.lua183
-rw-r--r--runtime/lua/vim/treesitter/language.lua148
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua1041
-rw-r--r--runtime/lua/vim/treesitter/playground.lua186
-rw-r--r--runtime/lua/vim/treesitter/query.lua441
11 files changed, 2962 insertions, 690 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
new file mode 100644
index 0000000000..5c1cc06908
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -0,0 +1,456 @@
+local ts = vim.treesitter
+
+local Range = require('vim.treesitter._range')
+
+local api = vim.api
+
+---@class TS.FoldInfo
+---@field levels table<integer,string>
+---@field levels0 table<integer,integer>
+---@field private start_counts table<integer,integer>
+---@field private stop_counts table<integer,integer>
+local FoldInfo = {}
+FoldInfo.__index = FoldInfo
+
+---@private
+function FoldInfo.new()
+ return setmetatable({
+ start_counts = {},
+ stop_counts = {},
+ levels0 = {},
+ levels = {},
+ }, FoldInfo)
+end
+
+---@package
+---@param srow integer
+---@param erow integer
+function FoldInfo:invalidate_range(srow, erow)
+ for i = srow, erow do
+ self.start_counts[i + 1] = nil
+ self.stop_counts[i + 1] = nil
+ self.levels0[i + 1] = nil
+ self.levels[i + 1] = nil
+ end
+end
+
+--- Efficiently remove items from the middle of a list.
+---
+--- Calling table.remove() in a loop re-indexes the tail of the table on
+--- every iteration; this function instead re-indexes the table exactly
+--- once.
+---
+--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
+---
+---@param t any[]
+---@param first integer
+---@param last integer
+local function list_remove(t, first, last)
+ local n = #t
+ for i = 0, n - first do
+ t[first + i] = t[last + 1 + i]
+ t[last + 1 + i] = nil
+ end
+end
+
+---@package
+---@param srow integer
+---@param erow integer
+function FoldInfo:remove_range(srow, erow)
+ list_remove(self.levels, srow + 1, erow)
+ list_remove(self.levels0, srow + 1, erow)
+ list_remove(self.start_counts, srow + 1, erow)
+ list_remove(self.stop_counts, srow + 1, erow)
+end
+
+--- Efficiently insert items into the middle of a list.
+---
+--- Calling table.insert() in a loop re-indexes the tail of the table on
+--- every iteration; this function instead re-indexes the table exactly
+--- once.
+---
+--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524
+---
+---@param t any[]
+---@param first integer
+---@param last integer
+---@param v any
+local function list_insert(t, first, last, v)
+ local n = #t
+
+ -- Shift table forward
+ for i = n - first, 0, -1 do
+ t[last + 1 + i] = t[first + i]
+ end
+
+ -- Fill in new values
+ for i = first, last do
+ t[i] = v
+ end
+end
+
+---@package
+---@param srow integer
+---@param erow integer
+function FoldInfo:add_range(srow, erow)
+ list_insert(self.levels, srow + 1, erow, '-1')
+ list_insert(self.levels0, srow + 1, erow, -1)
+ list_insert(self.start_counts, srow + 1, erow, nil)
+ list_insert(self.stop_counts, srow + 1, erow, nil)
+end
+
+---@package
+---@param lnum integer
+function FoldInfo:add_start(lnum)
+ self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1
+end
+
+---@package
+---@param lnum integer
+function FoldInfo:add_stop(lnum)
+ self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1
+end
+
+---@package
+---@param lnum integer
+---@return integer
+function FoldInfo:get_start(lnum)
+ return self.start_counts[lnum] or 0
+end
+
+---@package
+---@param lnum integer
+---@return integer
+function FoldInfo:get_stop(lnum)
+ return self.stop_counts[lnum] or 0
+end
+
+local function trim_level(level)
+ local max_fold_level = vim.wo.foldnestmax
+ if level > max_fold_level then
+ return max_fold_level
+ end
+ return level
+end
+
+--- If a parser doesn't have any ranges explicitly set, treesitter will
+--- return a range with end_row and end_bytes with a value of UINT32_MAX,
+--- so clip end_row to the max buffer line.
+---
+--- TODO(lewis6991): Handle this generally
+---
+--- @param bufnr integer
+--- @param erow integer?
+--- @return integer
+local function normalise_erow(bufnr, erow)
+ local max_erow = api.nvim_buf_line_count(bufnr) - 1
+ return math.min(erow or max_erow, max_erow)
+end
+
+-- TODO(lewis6991): Set up a decor provider so injection folds can be parsed
+-- as the window is redrawn
+---@param bufnr integer
+---@param info TS.FoldInfo
+---@param srow integer?
+---@param erow integer?
+---@param parse_injections? boolean
+local function get_folds_levels(bufnr, info, srow, erow, parse_injections)
+ srow = srow or 0
+ erow = normalise_erow(bufnr, erow)
+
+ info:invalidate_range(srow, erow)
+
+ local prev_start = -1
+ local prev_stop = -1
+
+ local parser = ts.get_parser(bufnr)
+
+ parser:parse(parse_injections and { srow, erow } or nil)
+
+ parser:for_each_tree(function(tree, ltree)
+ local query = ts.query.get(ltree:lang(), 'folds')
+ if not query then
+ return
+ end
+
+ -- erow in query is end-exclusive
+ local q_erow = erow and erow + 1 or -1
+
+ for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do
+ if query.captures[id] == 'fold' then
+ local range = ts.get_range(node, bufnr, metadata[id])
+ local start, _, stop, stop_col = Range.unpack4(range)
+
+ if stop_col == 0 then
+ stop = stop - 1
+ end
+
+ local fold_length = stop - start + 1
+
+ -- Fold only multiline nodes that are not exactly the same as previously met folds
+ -- Checking against just the previously found fold is sufficient if nodes
+ -- are returned in preorder or postorder when traversing tree
+ if
+ fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
+ then
+ info:add_start(start + 1)
+ info:add_stop(stop + 1)
+ prev_start = start
+ prev_stop = stop
+ end
+ end
+ end
+ end)
+
+ local current_level = info.levels0[srow] or 0
+
+ -- We now have the lists of fold openings and closings; fill the gaps and mark where folds start
+ for lnum = srow + 1, erow + 1 do
+ local last_trimmed_level = trim_level(current_level)
+ current_level = current_level + info:get_start(lnum)
+ info.levels0[lnum] = current_level
+
+ local trimmed_level = trim_level(current_level)
+ current_level = current_level - info:get_stop(lnum)
+
+ -- Determine if it's the start/end of a fold
+ -- NB: vim's fold-expr interface does not have a mechanism to indicate that
+ -- two (or more) folds start at this line, so it cannot distinguish between
+ -- ( \n ( \n )) \n (( \n ) \n )
+ -- versus
+ -- ( \n ( \n ) \n ( \n ) \n )
+ -- If it did have such a mechanism, (trimmed_level - last_trimmed_level)
+ -- would be the correct number of starts to pass on.
+ local prefix = ''
+ if trimmed_level - last_trimmed_level > 0 then
+ prefix = '>'
+ end
+
+ info.levels[lnum] = prefix .. tostring(trimmed_level)
+ end
+end
+
+local M = {}
+
+---@type table<integer,TS.FoldInfo>
+local foldinfos = {}
+
+local group = api.nvim_create_augroup('treesitter/fold', {})
+
+--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
+--- the user doesn't use different foldexpr for the same buffer).
+---
+--- Nvim usually automatically updates folds when text changes, but it doesn't work here because
+--- FoldInfo update is scheduled. So we do it manually.
+local function foldupdate(bufnr)
+ local function do_update()
+ for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
+ api.nvim_win_call(win, function()
+ if vim.wo.foldmethod == 'expr' then
+ vim._foldupdate()
+ end
+ end)
+ end
+ end
+
+ if api.nvim_get_mode().mode == 'i' then
+ -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
+ if #(api.nvim_get_autocmds({
+ group = group,
+ buffer = bufnr,
+ })) > 0 then
+ return
+ end
+ api.nvim_create_autocmd('InsertLeave', {
+ group = group,
+ buffer = bufnr,
+ once = true,
+ callback = do_update,
+ })
+ return
+ end
+
+ do_update()
+end
+
+--- Schedule a function only if bufnr is loaded.
+--- We schedule fold level computation for the following reasons:
+--- * queries seem to use the old buffer state in on_bytes for some unknown reason;
+--- * to avoid textlock;
+--- * to avoid infinite recursion:
+--- get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels.
+---@param bufnr integer
+---@param fn function
+local function schedule_if_loaded(bufnr, fn)
+ vim.schedule(function()
+ if not api.nvim_buf_is_loaded(bufnr) then
+ return
+ end
+ fn()
+ end)
+end
+
+---@param bufnr integer
+---@param foldinfo TS.FoldInfo
+---@param tree_changes Range4[]
+local function on_changedtree(bufnr, foldinfo, tree_changes)
+ schedule_if_loaded(bufnr, function()
+ for _, change in ipairs(tree_changes) do
+ local srow, _, erow = Range.unpack4(change)
+ get_folds_levels(bufnr, foldinfo, srow, erow)
+ end
+ if #tree_changes > 0 then
+ foldupdate(bufnr)
+ end
+ end)
+end
+
+---@param bufnr integer
+---@param foldinfo TS.FoldInfo
+---@param start_row integer
+---@param old_row integer
+---@param new_row integer
+local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
+ local end_row_old = start_row + old_row
+ local end_row_new = start_row + new_row
+
+ if new_row ~= old_row then
+ if new_row < old_row then
+ foldinfo:remove_range(end_row_new, end_row_old)
+ else
+ foldinfo:add_range(start_row, end_row_new)
+ end
+ schedule_if_loaded(bufnr, function()
+ get_folds_levels(bufnr, foldinfo, start_row, end_row_new)
+ foldupdate(bufnr)
+ end)
+ end
+end
+
+---@package
+---@param lnum integer|nil
+---@return string
+function M.foldexpr(lnum)
+ lnum = lnum or vim.v.lnum
+ local bufnr = api.nvim_get_current_buf()
+
+ local parser = vim.F.npcall(ts.get_parser, bufnr)
+ if not parser then
+ return '0'
+ end
+
+ if not foldinfos[bufnr] then
+ foldinfos[bufnr] = FoldInfo.new()
+ get_folds_levels(bufnr, foldinfos[bufnr])
+
+ parser:register_cbs({
+ on_changedtree = function(tree_changes)
+ on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
+ end,
+
+ on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _)
+ on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row)
+ end,
+
+ on_detach = function()
+ foldinfos[bufnr] = nil
+ end,
+ })
+ end
+
+ return foldinfos[bufnr].levels[lnum] or '0'
+end
+
+---@package
+---@return { [1]: string, [2]: string[] }[]|string
+function M.foldtext()
+ local foldstart = vim.v.foldstart
+ local bufnr = api.nvim_get_current_buf()
+
+ ---@type boolean, LanguageTree
+ local ok, parser = pcall(ts.get_parser, bufnr)
+ if not ok then
+ return vim.fn.foldtext()
+ end
+
+ local query = ts.query.get(parser:lang(), 'highlights')
+ if not query then
+ return vim.fn.foldtext()
+ end
+
+ local tree = parser:parse({ foldstart - 1, foldstart })[1]
+
+ local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1]
+ if not line then
+ return vim.fn.foldtext()
+ end
+
+ ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[]
+ local result = {}
+
+ local line_pos = 0
+
+ for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do
+ local name = query.captures[id]
+ local start_row, start_col, end_row, end_col = node:range()
+
+ local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter)
+
+ if start_row == foldstart - 1 and end_row == foldstart - 1 then
+ -- check for characters ignored by treesitter
+ if start_col > line_pos then
+ table.insert(result, {
+ line:sub(line_pos + 1, start_col),
+ {},
+ range = { line_pos, start_col },
+ })
+ end
+ line_pos = end_col
+
+ local text = line:sub(start_col + 1, end_col)
+ table.insert(result, { text, { { '@' .. name, priority } }, range = { start_col, end_col } })
+ end
+ end
+
+ local i = 1
+ while i <= #result do
+ -- find first capture that is not in current range and apply highlights on the way
+ local j = i + 1
+ while
+ j <= #result
+ and result[j].range[1] >= result[i].range[1]
+ and result[j].range[2] <= result[i].range[2]
+ do
+ for k, v in ipairs(result[i][2]) do
+ if not vim.tbl_contains(result[j][2], v) then
+ table.insert(result[j][2], k, v)
+ end
+ end
+ j = j + 1
+ end
+
+ -- remove the parent capture if it is split into children
+ if j > i + 1 then
+ table.remove(result, i)
+ else
+ -- highlights need to be sorted by priority; on equal priority, the more deeply nested capture
+ -- (earlier in the list) should be considered higher priority
+ if #result[i][2] > 1 then
+ table.sort(result[i][2], function(a, b)
+ return a[2] < b[2]
+ end)
+ end
+
+ result[i][2] = vim.tbl_map(function(tbl)
+ return tbl[1]
+ end, result[i][2])
+ result[i] = { result[i][1], result[i][2] }
+
+ i = i + 1
+ end
+ end
+
+ return result
+end
+
+return M
diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua
new file mode 100644
index 0000000000..80c998b555
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_meta.lua
@@ -0,0 +1,80 @@
+---@meta
+
+---@class TSNode: userdata
+---@field id fun(self: TSNode): string
+---@field tree fun(self: TSNode): TSTree
+---@field range fun(self: TSNode, include_bytes: false?): integer, integer, integer, integer
+---@field range fun(self: TSNode, include_bytes: true): integer, integer, integer, integer, integer, integer
+---@field start fun(self: TSNode): integer, integer, integer
+---@field end_ fun(self: TSNode): integer, integer, integer
+---@field type fun(self: TSNode): string
+---@field symbol fun(self: TSNode): integer
+---@field named fun(self: TSNode): boolean
+---@field missing fun(self: TSNode): boolean
+---@field extra fun(self: TSNode): boolean
+---@field child_count fun(self: TSNode): integer
+---@field named_child_count fun(self: TSNode): integer
+---@field child fun(self: TSNode, index: integer): TSNode?
+---@field named_child fun(self: TSNode, index: integer): TSNode?
+---@field descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode?
+---@field named_descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode?
+---@field parent fun(self: TSNode): TSNode?
+---@field next_sibling fun(self: TSNode): TSNode?
+---@field prev_sibling fun(self: TSNode): TSNode?
+---@field next_named_sibling fun(self: TSNode): TSNode?
+---@field prev_named_sibling fun(self: TSNode): TSNode?
+---@field named_children fun(self: TSNode): TSNode[]
+---@field has_changes fun(self: TSNode): boolean
+---@field has_error fun(self: TSNode): boolean
+---@field sexpr fun(self: TSNode): string
+---@field equal fun(self: TSNode, other: TSNode): boolean
+---@field iter_children fun(self: TSNode): fun(): TSNode, string
+---@field field fun(self: TSNode, name: string): TSNode[]
+---@field byte_length fun(self: TSNode): integer
+local TSNode = {}
+
+---@param query userdata
+---@param captures true
+---@param start? integer
+---@param end_? integer
+---@param opts? table
+---@return fun(): integer, TSNode, any
+function TSNode:_rawquery(query, captures, start, end_, opts) end
+
+---@param query userdata
+---@param captures false
+---@param start? integer
+---@param end_? integer
+---@param opts? table
+---@return fun(): string, any
+function TSNode:_rawquery(query, captures, start, end_, opts) end
+
+---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)
+
+---@class TSParser
+---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: true): TSTree, Range6[]
+---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: false|nil): TSTree, Range4[]
+---@field reset fun(self: TSParser)
+---@field included_ranges fun(self: TSParser, include_bytes: boolean?): integer[]
+---@field set_included_ranges fun(self: TSParser, ranges: (Range6|TSNode)[])
+---@field set_timeout fun(self: TSParser, timeout: integer)
+---@field timeout fun(self: TSParser): integer
+---@field _set_logger fun(self: TSParser, lex: boolean, parse: boolean, cb: TSLoggerCallback)
+---@field _logger fun(self: TSParser): TSLoggerCallback
+
+---@class TSTree
+---@field root fun(self: TSTree): TSNode
+---@field edit fun(self: TSTree, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _:integer)
+---@field copy fun(self: TSTree): TSTree
+---@field included_ranges fun(self: TSTree, include_bytes: true): Range6[]
+---@field included_ranges fun(self: TSTree, include_bytes: false): Range4[]
+
+---@return integer
+vim._ts_get_language_version = function() end
+
+---@return integer
+vim._ts_get_minimum_language_version = function() end
+
+---@param lang string
+---@return TSParser
+vim._create_ts_parser = function(lang) end
diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua
new file mode 100644
index 0000000000..87d74789a3
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_query_linter.lua
@@ -0,0 +1,249 @@
+local api = vim.api
+
+local namespace = api.nvim_create_namespace('vim.treesitter.query_linter')
+
+local M = {}
+
+--- @class QueryLinterNormalizedOpts
+--- @field langs string[]
+--- @field clear boolean
+
+--- @alias vim.treesitter.ParseError {msg: string, range: Range4}
+
+--- Contains language dependent context for the query linter
+--- @class QueryLinterLanguageContext
+--- @field lang string? Current `lang` of the targeted parser
+--- @field parser_info table? Parser info returned by vim.treesitter.language.inspect
+--- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`
+
+--- Adds a diagnostic for node in the query buffer
+--- @param diagnostics Diagnostic[]
+--- @param range Range4
+--- @param lint string
+--- @param lang string?
+local function add_lint_for_node(diagnostics, range, lint, lang)
+ local message = lint:gsub('\n', ' ')
+ diagnostics[#diagnostics + 1] = {
+ lnum = range[1],
+ end_lnum = range[3],
+ col = range[2],
+ end_col = range[4],
+ severity = vim.diagnostic.ERROR,
+ message = message,
+ source = lang,
+ }
+end
+
+--- Determines the target language of a query file by its path: <lang>/<query_type>.scm
+--- @param buf integer
+--- @return string?
+local function guess_query_lang(buf)
+ local filename = api.nvim_buf_get_name(buf)
+ if filename ~= '' then
+ return vim.F.npcall(vim.fn.fnamemodify, filename, ':p:h:t')
+ end
+end
+
+--- @param buf integer
+--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil
+--- @return QueryLinterNormalizedOpts
+local function normalize_opts(buf, opts)
+ opts = opts or {}
+ if not opts.langs then
+ opts.langs = guess_query_lang(buf)
+ end
+
+ if type(opts.langs) ~= 'table' then
+ --- @diagnostic disable-next-line:assign-type-mismatch
+ opts.langs = { opts.langs }
+ end
+
+ --- @cast opts QueryLinterNormalizedOpts
+ opts.langs = opts.langs or {}
+ return opts
+end
+
+local lint_query = [[;; query
+ (program [(named_node) (list) (grouping)] @toplevel)
+ (named_node
+ name: _ @node.named)
+ (anonymous_node
+ name: _ @node.anonymous)
+ (field_definition
+ name: (identifier) @field)
+ (predicate
+ name: (identifier) @predicate.name
+ type: (predicate_type) @predicate.type)
+ (ERROR) @error
+]]
+
+--- @param err string
+--- @param node TSNode
+--- @return vim.treesitter.ParseError
+local function get_error_entry(err, node)
+ local start_line, start_col = node:range()
+ local line_offset, col_offset, msg = err:gmatch('.-:%d+: Query error at (%d+):(%d+)%. ([^:]+)')() ---@type string, string, string
+ start_line, start_col =
+ start_line + tonumber(line_offset) - 1, start_col + tonumber(col_offset) - 1
+ local end_line, end_col = start_line, start_col
+ if msg:match('^Invalid syntax') or msg:match('^Impossible') then
+ -- Use the length of the underlined node
+ local underlined = vim.split(err, '\n')[2]
+ end_col = end_col + #underlined
+ elseif msg:match('^Invalid') then
+ -- Use the length of the problematic type/capture/field
+ end_col = end_col + #msg:match('"([^"]+)"')
+ end
+
+ return {
+ msg = msg,
+ range = { start_line, start_col, end_line, end_col },
+ }
+end
+
+--- @param node TSNode
+--- @param buf integer
+--- @param lang string
+local function hash_parse(node, buf, lang)
+ return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang
+end
+
+--- @param node TSNode
+--- @param buf integer
+--- @param lang string
+--- @return vim.treesitter.ParseError?
+local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
+ local query_text = vim.treesitter.get_node_text(node, buf)
+ local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query
+
+ if not ok and type(err) == 'string' then
+ return get_error_entry(err, node)
+ end
+end)
+
+--- @param buf integer
+--- @param match table<integer,TSNode>
+--- @param query Query
+--- @param lang_context QueryLinterLanguageContext
+--- @param diagnostics Diagnostic[]
+local function lint_match(buf, match, query, lang_context, diagnostics)
+ local lang = lang_context.lang
+ local parser_info = lang_context.parser_info
+
+ for id, node in pairs(match) do
+ local cap_id = query.captures[id]
+
+ -- perform language-independent checks only for first lang
+ if lang_context.is_first_lang and cap_id == 'error' then
+ local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
+ add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
+ end
+
+ -- other checks rely on Neovim parser introspection
+ if lang and parser_info and cap_id == 'toplevel' then
+ local err = parse(node, buf, lang)
+ if err then
+ add_lint_for_node(diagnostics, err.range, err.msg, lang)
+ end
+ end
+ end
+end
+
+--- @private
+--- @param buf integer Buffer to lint
+--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil Options for linting
+function M.lint(buf, opts)
+ if buf == 0 then
+ buf = api.nvim_get_current_buf()
+ end
+
+ local diagnostics = {}
+ local query = vim.treesitter.query.parse('query', lint_query)
+
+ opts = normalize_opts(buf, opts)
+
+ -- perform at least one iteration even with no langs to perform language-independent checks
+ for i = 1, math.max(1, #opts.langs) do
+ local lang = opts.langs[i]
+
+ --- @type (table|nil)
+ local parser_info = vim.F.npcall(vim.treesitter.language.inspect, lang)
+
+ local parser = vim.treesitter.get_parser(buf)
+ parser:parse()
+ parser:for_each_tree(function(tree, ltree)
+ if ltree:lang() == 'query' then
+ for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do
+ local lang_context = {
+ lang = lang,
+ parser_info = parser_info,
+ is_first_lang = i == 1,
+ }
+ lint_match(buf, match, query, lang_context, diagnostics)
+ end
+ end
+ end)
+ end
+
+ vim.diagnostic.set(namespace, buf, diagnostics)
+end
+
+--- @private
+--- @param buf integer
+function M.clear(buf)
+ vim.diagnostic.reset(namespace, buf)
+end
+
+--- @private
+--- @param findstart integer
+--- @param base string
+function M.omnifunc(findstart, base)
+ if findstart == 1 then
+ local result =
+ api.nvim_get_current_line():sub(1, api.nvim_win_get_cursor(0)[2]):find('["#%-%w]*$')
+ return result - 1
+ end
+
+ local buf = api.nvim_get_current_buf()
+ local query_lang = guess_query_lang(buf)
+
+ local ok, parser_info = pcall(vim.treesitter.language.inspect, query_lang)
+ if not ok then
+ return -2
+ end
+
+ local items = {}
+ for _, f in pairs(parser_info.fields) do
+ if f:find(base, 1, true) then
+ table.insert(items, f .. ':')
+ end
+ end
+ for _, p in pairs(vim.treesitter.query.list_predicates()) do
+ local text = '#' .. p
+ local found = text:find(base, 1, true)
+ if found and found <= 2 then -- with or without '#'
+ table.insert(items, text)
+ end
+ text = '#not-' .. p
+ found = text:find(base, 1, true)
+ if found and found <= 2 then -- with or without '#'
+ table.insert(items, text)
+ end
+ end
+ for _, p in pairs(vim.treesitter.query.list_directives()) do
+ local text = '#' .. p
+ local found = text:find(base, 1, true)
+ if found and found <= 2 then -- with or without '#'
+ table.insert(items, text)
+ end
+ end
+ for _, s in pairs(parser_info.symbols) do
+ local text = s[2] and s[1] or '"' .. s[1]:gsub([[\]], [[\\]]) .. '"' ---@type string
+ if text:find(base, 1, true) then
+ table.insert(items, text)
+ end
+ end
+ return { words = items, refresh = 'always' }
+end
+
+return M
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
new file mode 100644
index 0000000000..8d727c3c52
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -0,0 +1,193 @@
+local api = vim.api
+
+local M = {}
+
+---@class Range2
+---@field [1] integer start row
+---@field [2] integer end row
+
+---@class Range4
+---@field [1] integer start row
+---@field [2] integer start column
+---@field [3] integer end row
+---@field [4] integer end column
+
+---@class Range6
+---@field [1] integer start row
+---@field [2] integer start column
+---@field [3] integer start bytes
+---@field [4] integer end row
+---@field [5] integer end column
+---@field [6] integer end bytes
+
+---@alias Range Range2|Range4|Range6
+
+---@private
+---@param a_row integer
+---@param a_col integer
+---@param b_row integer
+---@param b_col integer
+---@return integer
+--- 1: a > b
+--- 0: a == b
+--- -1: a < b
+local function cmp_pos(a_row, a_col, b_row, b_col)
+ if a_row == b_row then
+ if a_col > b_col then
+ return 1
+ elseif a_col < b_col then
+ return -1
+ else
+ return 0
+ end
+ elseif a_row > b_row then
+ return 1
+ end
+
+ return -1
+end
+
+M.cmp_pos = {
+ lt = function(...)
+ return cmp_pos(...) == -1
+ end,
+ le = function(...)
+ return cmp_pos(...) ~= 1
+ end,
+ gt = function(...)
+ return cmp_pos(...) == 1
+ end,
+ ge = function(...)
+ return cmp_pos(...) ~= -1
+ end,
+ eq = function(...)
+ return cmp_pos(...) == 0
+ end,
+ ne = function(...)
+ return cmp_pos(...) ~= 0
+ end,
+}
+
+setmetatable(M.cmp_pos, { __call = cmp_pos })
+
+---@private
+---Check if a variable is a valid range object
+---@param r any
+---@return boolean
+function M.validate(r)
+ if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
+ return false
+ end
+
+ for _, e in
+ ipairs(r --[[@as any[] ]])
+ do
+ if type(e) ~= 'number' then
+ return false
+ end
+ end
+
+ return true
+end
+
+---@private
+---@param r1 Range
+---@param r2 Range
+---@return boolean
+function M.intercepts(r1, r2)
+ local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
+ local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
+
+ -- r1 is above r2
+ if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
+ return false
+ end
+
+ -- r1 is below r2
+ if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
+ return false
+ end
+
+ return true
+end
+
+---@private
+---@param r Range
+---@return integer, integer, integer, integer
+function M.unpack4(r)
+ if #r == 2 then
+ return r[1], 0, r[2], 0
+ end
+ local off_1 = #r == 6 and 1 or 0
+ return r[1], r[2], r[3 + off_1], r[4 + off_1]
+end
+
+---@private
+---@param r Range6
+---@return integer, integer, integer, integer, integer, integer
+function M.unpack6(r)
+ return r[1], r[2], r[3], r[4], r[5], r[6]
+end
+
+---@private
+---@param r1 Range
+---@param r2 Range
+---@return boolean whether r1 contains r2
+function M.contains(r1, r2)
+ local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
+ local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
+
+ -- start doesn't fit
+ if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
+ return false
+ end
+
+ -- end doesn't fit
+ if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
+ return false
+ end
+
+ return true
+end
+
+--- @param source integer|string
+--- @param index integer
+--- @return integer
+local function get_offset(source, index)
+ if index == 0 then
+ return 0
+ end
+
+ if type(source) == 'number' then
+ return api.nvim_buf_get_offset(source, index)
+ end
+
+ local byte = 0
+ local next_offset = source:gmatch('()\n')
+ local line = 1
+ while line <= index do
+ byte = next_offset() --[[@as integer]]
+ line = line + 1
+ end
+
+ return byte
+end
+
+---@private
+---@param source integer|string
+---@param range Range
+---@return Range6
+function M.add_bytes(source, range)
+ if type(range) == 'table' and #range == 6 then
+ return range --[[@as Range6]]
+ end
+
+ local start_row, start_col, end_row, end_col = M.unpack4(range)
+ -- TODO(vigoux): proper byte computation here, and account for EOL ?
+ local start_byte = get_offset(source, start_row) + start_col
+ local end_byte = get_offset(source, end_row) + end_col
+
+ return { start_row, start_col, start_byte, end_row, end_col, end_byte }
+end
+
+return M
diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua
new file mode 100644
index 0000000000..69ddc9b558
--- /dev/null
+++ b/runtime/lua/vim/treesitter/dev.lua
@@ -0,0 +1,645 @@
+local api = vim.api
+
+---@class TSDevModule
+local M = {}
+
+---@class TSTreeView
+---@field ns integer API namespace
+---@field opts table Options table with the following keys:
+--- - anon (boolean): If true, display anonymous nodes
+--- - lang (boolean): If true, display the language alongside each node
+--- - indent (number): Number of spaces to indent nested lines. Default is 2.
+---@field nodes TSP.Node[]
+---@field named TSP.Node[]
+local TSTreeView = {}
+
+---@class TSP.Node
+---@field id integer Node id
+---@field text string Node text
+---@field named boolean True if this is a named (non-anonymous) node
+---@field depth integer Depth of the node within the tree
+---@field lnum integer Beginning line number of this node in the source buffer
+---@field col integer Beginning column number of this node in the source buffer
+---@field end_lnum integer Final line number of this node in the source buffer
+---@field end_col integer Final column number of this node in the source buffer
+---@field lang string Source language of this node
+---@field root TSNode
+
+---@class TSP.Injection
+---@field lang string Source language of this injection
+---@field root TSNode Root node of the injection
+
+--- Traverse all child nodes starting at {node}.
+---
+--- This is a recursive function. The {depth} parameter indicates the current recursion level.
+--- {lang} is a string indicating the language of the tree currently being traversed. Each traversed
+--- node is added to {tree}. When recursion completes, {tree} is an array of all nodes in the order
+--- they were visited.
+---
+--- {injections} is a table mapping node ids from the primary tree to language tree injections. Each
+--- injected language has a series of trees nested within the primary language's tree, and the root
+--- node of each of these trees is contained within a node in the primary tree. The {injections}
+--- table maps nodes in the primary tree to root nodes of injected trees.
+---
+---@param node TSNode Starting node to begin traversal |tsnode|
+---@param depth integer Current recursion depth
+---@param lang string Language of the tree currently being traversed
+---@param injections table<string, TSP.Injection> Mapping of node ids to root nodes
+---                  of injected language trees (see explanation above)
+---@param tree TSP.Node[] Output table containing a list of tables each representing a node in the tree
+---@return TSP.Node[] tree The same {tree} table that was passed in
+local function traverse(node, depth, lang, injections, tree)
+  -- An injected tree rooted inside this node is listed first, at the same depth.
+  local injection = injections[node:id()]
+  if injection then
+    traverse(injection.root, depth, injection.lang, injections, tree)
+  end
+
+  for child, field in node:iter_children() do
+    -- Named `node_type` so that the builtin type() is not shadowed inside the loop.
+    local node_type = child:type()
+    local lnum, col, end_lnum, end_col = child:range()
+    local named = child:named()
+    local text ---@type string
+    if named then
+      if field then
+        text = string.format('%s: (%s', field, node_type)
+      else
+        text = string.format('(%s', node_type)
+      end
+    else
+      -- The outer parentheses drop gsub's substitution count so that only the
+      -- escaped string is passed on to string.format().
+      local escaped = (node_type:gsub('\n', '\\n'):gsub('"', '\\"'))
+      text = string.format('"%s"', escaped)
+    end
+
+    table.insert(tree, {
+      id = child:id(),
+      text = text,
+      named = named,
+      depth = depth,
+      lnum = lnum,
+      col = col,
+      end_lnum = end_lnum,
+      end_col = end_col,
+      lang = lang,
+    })
+
+    traverse(child, depth + 1, lang, injections, tree)
+
+    -- The closing parenthesis goes on the most recently added entry, i.e. this
+    -- node itself or the last entry produced while traversing below it.
+    if named then
+      tree[#tree].text = string.format('%s)', tree[#tree].text)
+    end
+  end
+
+  return tree
+end
+
+--- Build a tree view from the parsed syntax tree of {bufnr}.
+---
+---@param bufnr integer Source buffer number
+---@param lang string|nil Language of source buffer
+---
+---@return TSTreeView|nil
+---@return string|nil Error message, if any
+---
+---@package
+function TSTreeView:new(bufnr, lang)
+  local has_parser, parser = pcall(vim.treesitter.get_parser, bufnr or 0, lang)
+  if not has_parser then
+    return nil, 'No parser available for the given buffer'
+  end
+
+  local root = parser:parse(true)[1]:root()
+
+  -- Maps the id of a node in a parent tree to the injected tree whose root lies
+  -- inside that node.
+  local injections = {} ---@type table<string, TSP.Injection>
+
+  parser:for_each_tree(function(parent_tree, parent_ltree)
+    local parent_root = parent_tree:root()
+    for _, child_ltree in pairs(parent_ltree:children()) do
+      child_ltree:for_each_tree(function(injected_tree, injected_ltree)
+        local injected_root = injected_tree:root()
+        local host = assert(parent_root:named_descendant_for_range(injected_root:range()))
+        local host_id = host:id()
+        local known = injections[host_id]
+        -- If several injected trees map to the same host node, keep the one
+        -- whose root covers the most bytes.
+        if not known or injected_root:byte_length() > known.root:byte_length() then
+          injections[host_id] = {
+            lang = injected_ltree:lang(),
+            root = injected_root,
+          }
+        end
+      end)
+    end
+  end)
+
+  local nodes = traverse(root, 0, parser:lang(), injections, {})
+
+  -- Second list holding only the named nodes, used when anonymous nodes are hidden.
+  local named = {} ---@type TSP.Node[]
+  for i = 1, #nodes do
+    local entry = nodes[i]
+    if entry.named then
+      named[#named + 1] = entry
+    end
+  end
+
+  local view = setmetatable({
+    ns = api.nvim_create_namespace('treesitter/dev-inspect'),
+    nodes = nodes,
+    named = named,
+    opts = {
+      anon = false,
+      lang = false,
+      indent = 2,
+    },
+  }, self)
+  self.__index = self
+
+  return view
+end
+
+local decor_ns = api.nvim_create_namespace('ts.dev')
+
+--- Format a node range for display, converting the rows and the start column
+--- from zero-based to one-based. The end row is omitted for single-line ranges.
+---
+---@param lnum integer
+---@param col integer
+---@param end_lnum integer
+---@param end_col integer
+---@return string
+local function get_range_str(lnum, col, end_lnum, end_col)
+  local first_line, first_col = lnum + 1, col + 1
+
+  if lnum ~= end_lnum then
+    return string.format('[%d:%d - %d:%d]', first_line, first_col, end_lnum + 1, end_col)
+  end
+
+  return string.format('[%d:%d - %d]', first_line, first_col, end_col)
+end
+
+--- Force-close window {w} if it is still valid.
+---
+---@param w integer
+---@return boolean closed Whether the window was closed.
+local function close_win(w)
+  if not api.nvim_win_is_valid(w) then
+    return false
+  end
+
+  api.nvim_win_close(w, true)
+  return true
+end
+
+--- Apply the window and buffer options shared by the dev tool windows.
+---
+---@param w integer
+---@param b integer
+local function set_dev_properties(w, b)
+  local wo = vim.wo[w]
+  local bo = vim.bo[b]
+
+  wo.scrolloff = 5
+  wo.wrap = false
+  wo.foldmethod = 'manual' -- disable folding
+
+  bo.buflisted = false
+  bo.buftype = 'nofile'
+  bo.bufhidden = 'wipe'
+  bo.filetype = 'query'
+end
+
+--- Move the inspector cursor to the entry for the node at {pos} in the source
+--- buffer and highlight that entry's text.
+---
+--- @param treeview TSTreeView
+--- @param lang string
+--- @param source_buf integer
+--- @param inspect_buf integer
+--- @param inspect_win integer
+--- @param pos? { [1]: integer, [2]: integer }
+local function set_inspector_cursor(treeview, lang, source_buf, inspect_buf, inspect_win, pos)
+  -- Drop the previous highlight before looking for the new target.
+  api.nvim_buf_clear_namespace(inspect_buf, treeview.ns, 0, -1)
+
+  local node_under_cursor = vim.treesitter.get_node({
+    bufnr = source_buf,
+    lang = lang,
+    pos = pos,
+    ignore_injections = false,
+  })
+  if not node_under_cursor then
+    return
+  end
+
+  local target_id = node_under_cursor:id()
+  for row, item in treeview:iter() do
+    if item.id == target_id then
+      -- The entry's text begins after its indentation.
+      local hl_start = item.depth * treeview.opts.indent ---@type integer
+      api.nvim_buf_set_extmark(inspect_buf, treeview.ns, row - 1, hl_start, {
+        end_col = hl_start + #item.text,
+        hl_group = 'Visual',
+      })
+      api.nvim_win_set_cursor(inspect_win, { row, 0 })
+      return
+    end
+  end
+end
+
+--- Render this View into {bufnr}, replacing the buffer's current contents.
+---
+---@param bufnr integer Buffer number to write into.
+---@package
+function TSTreeView:draw(bufnr)
+  vim.bo[bufnr].modifiable = true
+
+  local show_lang = self.opts.lang
+  local lines = {} ---@type string[]
+  local lang_hl_marks = {} ---@type table[]
+
+  for _, item in self:iter() do
+    local indent = string.rep(' ', item.depth * self.opts.indent)
+    local range_str = get_range_str(item.lnum, item.col, item.end_lnum, item.end_col)
+    local lang_str = ''
+    if show_lang then
+      lang_str = string.format(' %s', item.lang)
+    end
+
+    local line = string.format('%s%s ; %s%s', indent, item.text, range_str, lang_str)
+
+    -- The language suffix is the tail of the line; remember its columns so it
+    -- can be highlighted once the text is in the buffer.
+    if show_lang then
+      table.insert(lang_hl_marks, {
+        col = #line - #lang_str,
+        end_col = #line,
+      })
+    end
+
+    table.insert(lines, line)
+  end
+
+  api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
+
+  api.nvim_buf_clear_namespace(bufnr, decor_ns, 0, -1)
+
+  -- Marks are collected one per line, so mark number `row` belongs to line `row`.
+  for row = 1, #lang_hl_marks do
+    local mark = lang_hl_marks[row]
+    api.nvim_buf_set_extmark(bufnr, decor_ns, row - 1, mark.col, {
+      hl_group = 'Title',
+      end_col = mark.end_col,
+    })
+  end
+
+  vim.bo[bufnr].modifiable = false
+end
+
+--- Look up entry {i} of this View.
+---
+--- Indexes the full node list when anonymous nodes are displayed and the
+--- named-only list otherwise, so the numbering depends on that option.
+---
+---@param i integer Node number to get
+---@return TSP.Node
+---@package
+function TSTreeView:get(i)
+  if self.opts.anon then
+    return self.nodes[i]
+  end
+
+  return self.named[i]
+end
+
+--- Iterate over the entries of this View: the full node list when anonymous
+--- nodes are displayed, the named-only list otherwise.
+---
+---@return (fun(): integer, TSP.Node) Iterator over all nodes in this View
+---@return table
+---@return integer
+---@package
+function TSTreeView:iter()
+  local visible = self.named
+  if self.opts.anon then
+    visible = self.nodes
+  end
+
+  return ipairs(visible)
+end
+
+--- @class InspectTreeOpts
+--- @field lang string? The language of the source buffer. If omitted, the
+--- filetype of the source buffer is used.
+--- @field bufnr integer? Buffer to draw the tree into. If omitted, a new
+--- buffer is created.
+--- @field winid integer? Window id to display the tree buffer in. If omitted,
+--- a new window is created with {command}.
+--- @field command string? Vimscript command to create the window. Default
+--- value is "60vnew". Only used when {winid} is nil.
+--- @field title (string|fun(bufnr:integer):string|nil) Title of the window. If a
+--- function, it accepts the buffer number of the source
+--- buffer as its only argument and should return a string.
+
+--- @private
+---
+--- Open an inspector window that lists the syntax tree of the current buffer and
+--- keep it in sync with the source buffer through buffer-local mappings and autocommands.
+---
+--- Raises an error if no parser is available for the current buffer or if the
+--- resolved window title is not a string.
+---
+--- @param opts InspectTreeOpts?
+function M.inspect_tree(opts)
+  vim.validate({
+    opts = { opts, 't', true },
+  })
+
+  opts = opts or {}
+
+  local buf = api.nvim_get_current_buf()
+  local win = api.nvim_get_current_win()
+  local treeview = assert(TSTreeView:new(buf, opts.lang))
+
+  -- Close any existing inspector window
+  if vim.b[buf].dev_inspect then
+    close_win(vim.b[buf].dev_inspect)
+  end
+
+  local w = opts.winid
+  if not w then
+    vim.cmd(opts.command or '60vnew')
+    w = api.nvim_get_current_win()
+  end
+
+  local b = opts.bufnr
+  if b then
+    api.nvim_win_set_buf(w, b)
+  else
+    b = api.nvim_win_get_buf(w)
+  end
+
+  vim.b[buf].dev_inspect = w
+  vim.b[b].dev_base = win -- base window handle
+  vim.b[b].disable_query_linter = true
+  set_dev_properties(w, b)
+
+  local title --- @type string?
+  local opts_title = opts.title
+  if not opts_title then
+    local bufname = api.nvim_buf_get_name(buf)
+    title = string.format('Syntax tree for %s', vim.fn.fnamemodify(bufname, ':.'))
+  elseif type(opts_title) == 'function' then
+    title = opts_title(buf)
+  else
+    -- A title given directly (the documented string form) is used as-is; any
+    -- other non-string value is rejected by the assertion below.
+    title = opts_title
+  end
+
+  assert(type(title) == 'string', 'Window title must be a string')
+  api.nvim_buf_set_name(b, title)
+
+  treeview:draw(b)
+
+  local cursor = api.nvim_win_get_cursor(win)
+  set_inspector_cursor(treeview, opts.lang, buf, b, w, { cursor[1] - 1, cursor[2] })
+
+  api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1)
+  api.nvim_buf_set_keymap(b, 'n', '<CR>', '', {
+    desc = 'Jump to the node under the cursor in the source buffer',
+    callback = function()
+      local row = api.nvim_win_get_cursor(w)[1]
+      local pos = treeview:get(row)
+      -- No entry on this row (e.g. the tree is empty): nothing to jump to.
+      if not pos then
+        return
+      end
+      api.nvim_set_current_win(win)
+      api.nvim_win_set_cursor(win, { pos.lnum + 1, pos.col })
+    end,
+  })
+  api.nvim_buf_set_keymap(b, 'n', 'a', '', {
+    desc = 'Toggle anonymous nodes',
+    callback = function()
+      local row, col = unpack(api.nvim_win_get_cursor(w)) ---@type integer, integer
+      -- Walk up to the nearest named entry so the cursor can be restored on an
+      -- entry that exists in both the full and the named-only listing.
+      local curnode = treeview:get(row)
+      while curnode and not curnode.named do
+        row = row - 1
+        curnode = treeview:get(row)
+      end
+
+      treeview.opts.anon = not treeview.opts.anon
+      treeview:draw(b)
+
+      if not curnode then
+        return
+      end
+
+      local id = curnode.id
+      for i, node in treeview:iter() do
+        if node.id == id then
+          api.nvim_win_set_cursor(w, { i, col })
+          break
+        end
+      end
+    end,
+  })
+  api.nvim_buf_set_keymap(b, 'n', 'I', '', {
+    desc = 'Toggle language display',
+    callback = function()
+      treeview.opts.lang = not treeview.opts.lang
+      treeview:draw(b)
+    end,
+  })
+  api.nvim_buf_set_keymap(b, 'n', 'o', '', {
+    desc = 'Toggle query editor',
+    callback = function()
+      local edit_w = vim.b[buf].dev_edit
+      if not edit_w or not close_win(edit_w) then
+        M.edit_query()
+      end
+    end,
+  })
+
+  local group = api.nvim_create_augroup('treesitter/dev', {})
+
+  api.nvim_create_autocmd('CursorMoved', {
+    group = group,
+    buffer = b,
+    callback = function()
+      -- Returning true removes this autocommand once the source buffer is gone.
+      if not api.nvim_buf_is_loaded(buf) then
+        return true
+      end
+
+      api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1)
+      local row = api.nvim_win_get_cursor(w)[1]
+      local pos = treeview:get(row)
+      -- No entry on this row (e.g. the tree is empty): nothing to highlight.
+      if not pos then
+        return
+      end
+      api.nvim_buf_set_extmark(buf, treeview.ns, pos.lnum, pos.col, {
+        end_row = pos.end_lnum,
+        end_col = math.max(0, pos.end_col),
+        hl_group = 'Visual',
+      })
+
+      local topline, botline = vim.fn.line('w0', win), vim.fn.line('w$', win)
+
+      -- Move the cursor if highlighted range is completely out of view
+      if pos.lnum < topline and pos.end_lnum < topline then
+        api.nvim_win_set_cursor(win, { pos.end_lnum + 1, 0 })
+      elseif pos.lnum > botline and pos.end_lnum > botline then
+        api.nvim_win_set_cursor(win, { pos.lnum + 1, 0 })
+      end
+    end,
+  })
+
+  api.nvim_create_autocmd('CursorMoved', {
+    group = group,
+    buffer = buf,
+    callback = function()
+      if not api.nvim_buf_is_loaded(b) then
+        return true
+      end
+
+      set_inspector_cursor(treeview, opts.lang, buf, b, w)
+    end,
+  })
+
+  api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, {
+    group = group,
+    buffer = buf,
+    callback = function()
+      if not api.nvim_buf_is_loaded(b) then
+        return true
+      end
+
+      -- Rebuild the view from scratch so it reflects the edited text.
+      treeview = assert(TSTreeView:new(buf, opts.lang))
+      treeview:draw(b)
+    end,
+  })
+
+  api.nvim_create_autocmd('BufLeave', {
+    group = group,
+    buffer = b,
+    callback = function()
+      if not api.nvim_buf_is_loaded(buf) then
+        return true
+      end
+      api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1)
+    end,
+  })
+
+  api.nvim_create_autocmd('BufLeave', {
+    group = group,
+    buffer = buf,
+    callback = function()
+      if not api.nvim_buf_is_loaded(b) then
+        return true
+      end
+      api.nvim_buf_clear_namespace(b, treeview.ns, 0, -1)
+    end,
+  })
+
+  api.nvim_create_autocmd('BufHidden', {
+    group = group,
+    buffer = buf,
+    once = true,
+    callback = function()
+      close_win(w)
+    end,
+  })
+end
+
+local edit_ns = api.nvim_create_namespace('treesitter/dev-edit')
+
+--- Highlight, in the buffer shown in {base_win}, the matches of the capture
+--- whose name is under the cursor in the query editor.
+---
+---@param query_win integer
+---@param base_win integer
+---@param lang string
+local function update_editor_highlights(query_win, base_win, lang)
+  local base_buf = api.nvim_win_get_buf(base_win)
+  local query_buf = api.nvim_win_get_buf(query_win)
+  local parser = vim.treesitter.get_parser(base_buf, lang)
+  api.nvim_buf_clear_namespace(base_buf, edit_ns, 0, -1)
+
+  local query_lines = api.nvim_buf_get_lines(query_buf, 0, -1, false)
+  local query_text = table.concat(query_lines, '\n')
+
+  -- A query that fails to parse simply leaves the source buffer without highlights.
+  local parsed, query = pcall(vim.treesitter.query.parse, lang, query_text)
+  if not parsed then
+    return
+  end
+
+  -- Only highlight captures if the cursor is on a capture name
+  local cursor_word = vim.fn.expand('<cword>') --[[@as string]]
+  if cursor_word:sub(1, 1) ~= '@' then
+    return
+  end
+  -- Remove the '@' from the cursor word
+  local wanted_capture = cursor_word:sub(2)
+
+  local first_visible = vim.fn.line('w0', base_win)
+  local last_visible = vim.fn.line('w$', base_win)
+  local root = parser:trees()[1]:root()
+
+  for id, node in query:iter_captures(root, base_buf, first_visible - 1, last_visible) do
+    local capture_name = query.captures[id]
+    if capture_name == wanted_capture then
+      local lnum, col, end_lnum, end_col = node:range()
+      api.nvim_buf_set_extmark(base_buf, edit_ns, lnum, col, {
+        end_row = end_lnum,
+        end_col = end_col,
+        hl_group = 'Visual',
+        virt_text = {
+          { capture_name, 'Title' },
+        },
+      })
+    end
+  end
+end
+
+--- @private
+---
+--- Open a scratch query editor for the current buffer (or, when called from the
+--- inspector, for the buffer the inspector belongs to).
+---
+--- Returns `nil` plus an error message when no parser is available; in that case
+--- no new window is opened.
+---
+--- @param lang? string language to open the query editor for.
+function M.edit_query(lang)
+  local buf = api.nvim_get_current_buf()
+  local win = api.nvim_get_current_win()
+
+  -- Close any existing editor window
+  if vim.b[buf].dev_edit then
+    close_win(vim.b[buf].dev_edit)
+  end
+
+  local cmd = '60vnew'
+  -- If the inspector is open, place the editor above it.
+  local base_win = vim.b[buf].dev_base ---@type integer?
+  -- The stored window handles may be stale or unset, so check them before
+  -- handing them to API functions that raise on invalid input.
+  local base_buf ---@type integer?
+  if base_win and api.nvim_win_is_valid(base_win) then
+    base_buf = api.nvim_win_get_buf(base_win)
+  end
+  local inspect_win = base_buf and vim.b[base_buf].dev_inspect ---@type integer?
+  local split_from ---@type integer?
+  if base_buf and inspect_win and api.nvim_win_is_valid(inspect_win) then
+    split_from = inspect_win
+    buf = base_buf
+    win = base_win
+    cmd = 'new'
+  end
+
+  -- Resolve the parser before touching the window layout so that a failure
+  -- does not leave an empty split behind.
+  local ok, parser = pcall(vim.treesitter.get_parser, buf, lang)
+  if not ok then
+    return nil, 'No parser available for the given buffer'
+  end
+  lang = parser:lang()
+
+  if split_from then
+    api.nvim_set_current_win(split_from)
+  end
+  vim.cmd(cmd)
+
+  local query_win = api.nvim_get_current_win()
+  local query_buf = api.nvim_win_get_buf(query_win)
+
+  vim.b[buf].dev_edit = query_win
+  vim.bo[query_buf].omnifunc = 'v:lua.vim.treesitter.query.omnifunc'
+  set_dev_properties(query_win, query_buf)
+
+  -- Note that omnifunc guesses the language based on the containing folder,
+  -- so we add the parser's language to the buffer's name so that omnifunc
+  -- can infer the language later.
+  api.nvim_buf_set_name(query_buf, string.format('%s/query_editor.scm', lang))
+
+  local group = api.nvim_create_augroup('treesitter/dev-edit', {})
+  api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, {
+    group = group,
+    buffer = query_buf,
+    desc = 'Update query editor diagnostics when the query changes',
+    callback = function()
+      vim.treesitter.query.lint(query_buf, { langs = lang, clear = false })
+    end,
+  })
+  api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave', 'CursorMoved', 'BufEnter' }, {
+    group = group,
+    buffer = query_buf,
+    desc = 'Update query editor highlights when the cursor moves',
+    callback = function()
+      if api.nvim_win_is_valid(win) then
+        update_editor_highlights(query_win, win, lang)
+      end
+    end,
+  })
+  api.nvim_create_autocmd('BufLeave', {
+    group = group,
+    buffer = query_buf,
+    desc = 'Clear highlights when leaving the query editor',
+    callback = function()
+      api.nvim_buf_clear_namespace(buf, edit_ns, 0, -1)
+    end,
+  })
+  api.nvim_create_autocmd('BufLeave', {
+    group = group,
+    buffer = buf,
+    desc = 'Clear the query editor highlights when leaving the source buffer',
+    callback = function()
+      -- Returning true removes this autocommand once the editor buffer is gone.
+      if not api.nvim_buf_is_loaded(query_buf) then
+        return true
+      end
+
+      api.nvim_buf_clear_namespace(query_buf, edit_ns, 0, -1)
+    end,
+  })
+  api.nvim_create_autocmd('BufHidden', {
+    group = group,
+    buffer = buf,
+    desc = 'Close the editor window when the source buffer is hidden',
+    once = true,
+    callback = function()
+      close_win(query_win)
+    end,
+  })
+
+  api.nvim_buf_set_lines(query_buf, 0, -1, false, {
+    ';; Write queries here (see $VIMRUNTIME/queries/ for examples).',
+    ';; Move cursor to a capture ("@foo") to highlight matches in the source buffer.',
+    ';; Completion for grammar nodes is available (:help compl-omni)',
+    '',
+    '',
+  })
+  vim.cmd('normal! G')
+  vim.cmd.startinsert()
+end
+
+return M
diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua
index c0a1eca0ce..ed1161e97f 100644
--- a/runtime/lua/vim/treesitter/health.lua
+++ b/runtime/lua/vim/treesitter/health.lua
@@ -2,30 +2,28 @@ local M = {}
local ts = vim.treesitter
local health = require('vim.health')
---- Lists the parsers currently installed
----
----@return string[] list of parser files
-function M.list_parsers()
- return vim.api.nvim_get_runtime_file('parser/*', true)
-end
-
--- Performs a healthcheck for treesitter integration
function M.check()
- local parsers = M.list_parsers()
+ local parsers = vim.api.nvim_get_runtime_file('parser/*', true)
- health.report_info(string.format('Nvim runtime ABI version: %d', ts.language_version))
+ health.info(string.format('Nvim runtime ABI version: %d', ts.language_version))
for _, parser in pairs(parsers) do
local parsername = vim.fn.fnamemodify(parser, ':t:r')
- local is_loadable, ret = pcall(ts.language.require_language, parsername)
+ local is_loadable, err_or_nil = pcall(ts.language.add, parsername)
- if not is_loadable or not ret then
- health.report_error(
- string.format('Parser "%s" failed to load (path: %s): %s', parsername, parser, ret or '?')
+ if not is_loadable then
+ health.error(
+ string.format(
+ 'Parser "%s" failed to load (path: %s): %s',
+ parsername,
+ parser,
+ err_or_nil or '?'
+ )
)
- elseif ret then
- local lang = ts.language.inspect_language(parsername)
- health.report_ok(
+ else
+ local lang = ts.language.inspect(parsername)
+ health.ok(
string.format('Parser: %-10s ABI: %d, path: %s', parsername, lang._abi_version, parser)
)
end
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua
index d77a0d0d03..496193c6ed 100644
--- a/runtime/lua/vim/treesitter/highlighter.lua
+++ b/runtime/lua/vim/treesitter/highlighter.lua
@@ -1,17 +1,34 @@
-local a = vim.api
-local query = require('vim.treesitter.query')
+local api = vim.api
+local query = vim.treesitter.query
+local Range = require('vim.treesitter._range')
+
+---@alias TSHlIter fun(end_line: integer|nil): integer, TSNode, TSMetadata
+
+---@class TSHighlightState
+---@field next_row integer
+---@field iter TSHlIter|nil
--- support reload for quick experimentation
---@class TSHighlighter
+---@field active table<integer,TSHighlighter>
+---@field bufnr integer
+---@field orig_spelloptions string
+---@field _highlight_states table<TSTree,TSHighlightState>
+---@field _queries table<string,TSHighlighterQuery>
+---@field tree LanguageTree
+---@field redraw_count integer
local TSHighlighter = rawget(vim.treesitter, 'TSHighlighter') or {}
TSHighlighter.__index = TSHighlighter
+--- @nodoc
TSHighlighter.active = TSHighlighter.active or {}
+---@class TSHighlighterQuery
+---@field _query Query|nil
+---@field hl_cache table<integer,integer>
local TSHighlighterQuery = {}
TSHighlighterQuery.__index = TSHighlighterQuery
-local ns = a.nvim_create_namespace('treesitter/highlighter')
+local ns = api.nvim_create_namespace('treesitter/highlighter')
---@private
function TSHighlighterQuery.new(lang, query_string)
@@ -22,7 +39,7 @@ function TSHighlighterQuery.new(lang, query_string)
local name = self._query.captures[capture]
local id = 0
if not vim.startswith(name, '_') then
- id = a.nvim_get_hl_id_by_name('@' .. name .. '.' .. lang)
+ id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. lang)
end
rawset(table, capture, id)
@@ -31,22 +48,24 @@ function TSHighlighterQuery.new(lang, query_string)
})
if query_string then
- self._query = query.parse_query(lang, query_string)
+ self._query = query.parse(lang, query_string)
else
- self._query = query.get_query(lang, 'highlights')
+ self._query = query.get(lang, 'highlights')
end
return self
end
----@private
+---@package
function TSHighlighterQuery:query()
return self._query
end
---- Creates a new highlighter using @param tree
+---@package
+---
+--- Creates a highlighter for `tree`.
---
----@param tree LanguageTree |LanguageTree| parser object to use for highlighting
+---@param tree LanguageTree parser object to use for highlighting
---@param opts (table|nil) Configuration of the highlighter:
--- - queries table overwrite queries used by the highlighter
---@return TSHighlighter Created highlighter object
@@ -57,27 +76,37 @@ function TSHighlighter.new(tree, opts)
error('TSHighlighter can not be used with a string parser source.')
end
- opts = opts or {}
+ opts = opts or {} ---@type { queries: table<string,string> }
self.tree = tree
tree:register_cbs({
- on_changedtree = function(...)
- self:on_changedtree(...)
- end,
on_bytes = function(...)
self:on_bytes(...)
end,
- on_detach = function(...)
- self:on_detach(...)
+ on_detach = function()
+ self:on_detach()
end,
})
- self.bufnr = tree:source()
+ tree:register_cbs({
+ on_changedtree = function(...)
+ self:on_changedtree(...)
+ end,
+ on_child_removed = function(child)
+ child:for_each_tree(function(t)
+ self:on_changedtree(t:included_ranges(true))
+ end)
+ end,
+ }, true)
+
+ self.bufnr = tree:source() --[[@as integer]]
self.edit_count = 0
self.redraw_count = 0
self.line_count = {}
-- A map of highlight states.
-- This state is kept during rendering across each line update.
self._highlight_states = {}
+
+ ---@type table<string,TSHighlighterQuery>
self._queries = {}
-- Queries for a specific language can be overridden by a custom
@@ -103,7 +132,7 @@ function TSHighlighter.new(tree, opts)
vim.cmd.runtime({ 'syntax/synload.vim', bang = true })
end
- a.nvim_buf_call(self.bufnr, function()
+ api.nvim_buf_call(self.bufnr, function()
vim.opt_local.spelloptions:append('noplainbuffer')
end)
@@ -112,6 +141,7 @@ function TSHighlighter.new(tree, opts)
return self
end
+--- @nodoc
--- Removes all internal references to the highlighter
function TSHighlighter:destroy()
if TSHighlighter.active[self.bufnr] then
@@ -122,12 +152,14 @@ function TSHighlighter:destroy()
vim.bo[self.bufnr].spelloptions = self.orig_spelloptions
vim.b[self.bufnr].ts_highlight = nil
if vim.g.syntax_on == 1 then
- a.nvim_exec_autocmds('FileType', { group = 'syntaxset', buffer = self.bufnr })
+ api.nvim_exec_autocmds('FileType', { group = 'syntaxset', buffer = self.bufnr })
end
end
end
----@private
+---@package
+---@param tstree TSTree
+---@return TSHighlightState
function TSHighlighter:get_highlight_state(tstree)
if not self._highlight_states[tstree] then
self._highlight_states[tstree] = {
@@ -144,28 +176,31 @@ function TSHighlighter:reset_highlight_state()
self._highlight_states = {}
end
----@private
+---@package
+---@param start_row integer
+---@param new_end integer
function TSHighlighter:on_bytes(_, _, start_row, _, _, _, _, _, new_end)
- a.nvim__buf_redraw_range(self.bufnr, start_row, start_row + new_end + 1)
+ api.nvim__buf_redraw_range(self.bufnr, start_row, start_row + new_end + 1)
end
----@private
+---@package
function TSHighlighter:on_detach()
self:destroy()
end
----@private
+---@package
+---@param changes Range6[]
function TSHighlighter:on_changedtree(changes)
- for _, ch in ipairs(changes or {}) do
- a.nvim__buf_redraw_range(self.bufnr, ch[1], ch[3] + 1)
+ for _, ch in ipairs(changes) do
+ api.nvim__buf_redraw_range(self.bufnr, ch[1], ch[4] + 1)
end
end
--- Gets the query used for @param lang
--
----@private
+---@package
---@param lang string Language used by the highlighter.
----@return Query
+---@return TSHighlighterQuery
function TSHighlighter:get_query(lang)
if not self._queries[lang] then
self._queries[lang] = TSHighlighterQuery.new(lang)
@@ -174,7 +209,10 @@ function TSHighlighter:get_query(lang)
return self._queries[lang]
end
----@private
+---@param self TSHighlighter
+---@param buf integer
+---@param line integer
+---@param is_spell_nav boolean
local function on_line_impl(self, buf, line, is_spell_nav)
self.tree:for_each_tree(function(tstree, tree)
if not tstree then
@@ -203,45 +241,54 @@ local function on_line_impl(self, buf, line, is_spell_nav)
end
while line >= state.next_row do
- local capture, node, metadata = state.iter()
+ local capture, node, metadata = state.iter(line)
- if capture == nil then
- break
+ local range = { root_end_row + 1, 0, root_end_row + 1, 0 }
+ if node then
+ range = vim.treesitter.get_range(node, buf, metadata and metadata[capture])
end
-
- local start_row, start_col, end_row, end_col = node:range()
- local hl = highlighter_query.hl_cache[capture]
-
- local capture_name = highlighter_query:query().captures[capture]
- local spell = nil
- if capture_name == 'spell' then
- spell = true
- elseif capture_name == 'nospell' then
- spell = false
+ local start_row, start_col, end_row, end_col = Range.unpack4(range)
+
+ if capture then
+ local hl = highlighter_query.hl_cache[capture]
+
+ local capture_name = highlighter_query:query().captures[capture]
+ local spell = nil ---@type boolean?
+ if capture_name == 'spell' then
+ spell = true
+ elseif capture_name == 'nospell' then
+ spell = false
+ end
+
+ -- Give nospell a higher priority so it always overrides spell captures.
+ local spell_pri_offset = capture_name == 'nospell' and 1 or 0
+
+ if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then
+ local priority = (tonumber(metadata.priority) or vim.highlight.priorities.treesitter)
+ + spell_pri_offset
+ api.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
+ end_line = end_row,
+ end_col = end_col,
+ hl_group = hl,
+ ephemeral = true,
+ priority = priority,
+ conceal = metadata.conceal,
+ spell = spell,
+ })
+ end
end
- -- Give nospell a higher priority so it always overrides spell captures.
- local spell_pri_offset = capture_name == 'nospell' and 1 or 0
-
- if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then
- a.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
- end_line = end_row,
- end_col = end_col,
- hl_group = hl,
- ephemeral = true,
- priority = (tonumber(metadata.priority) or 100) + spell_pri_offset, -- Low but leaves room below
- conceal = metadata.conceal,
- spell = spell,
- })
- end
if start_row > line then
state.next_row = start_row
end
end
- end, true)
+ end)
end
---@private
+---@param _win integer
+---@param buf integer
+---@param line integer
function TSHighlighter._on_line(_, _win, buf, line, _)
local self = TSHighlighter.active[buf]
if not self then
@@ -252,6 +299,9 @@ function TSHighlighter._on_line(_, _win, buf, line, _)
end
---@private
+---@param buf integer
+---@param srow integer
+---@param erow integer
function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _)
local self = TSHighlighter.active[buf]
if not self then
@@ -266,27 +316,22 @@ function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _)
end
---@private
-function TSHighlighter._on_buf(_, buf)
- local self = TSHighlighter.active[buf]
- if self then
- self.tree:parse()
- end
-end
-
----@private
-function TSHighlighter._on_win(_, _win, buf, _topline)
+---@param _win integer
+---@param buf integer
+---@param topline integer
+---@param botline integer
+function TSHighlighter._on_win(_, _win, buf, topline, botline)
local self = TSHighlighter.active[buf]
if not self then
return false
end
-
+ self.tree:parse({ topline, botline + 1 })
self:reset_highlight_state()
self.redraw_count = self.redraw_count + 1
return true
end
-a.nvim_set_decoration_provider(ns, {
- on_buf = TSHighlighter._on_buf,
+api.nvim_set_decoration_provider(ns, {
on_win = TSHighlighter._on_win,
on_line = TSHighlighter._on_line,
_on_spell_nav = TSHighlighter._on_spell_nav,
diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua
index c92d63b8c4..15bf666a1e 100644
--- a/runtime/lua/vim/treesitter/language.lua
+++ b/runtime/lua/vim/treesitter/language.lua
@@ -1,42 +1,132 @@
-local a = vim.api
+local api = vim.api
+---@class TSLanguageModule
local M = {}
---- Asserts that a parser for the language {lang} is installed.
+---@type table<string,string>
+local ft_to_lang = {
+ help = 'vimdoc',
+}
+
+--- Get the filetypes associated with the parser named {lang}.
+--- @param lang string Name of parser
+--- @return string[] filetypes
+function M.get_filetypes(lang)
+ local r = {} ---@type string[]
+ for ft, p in pairs(ft_to_lang) do
+ if p == lang then
+ r[#r + 1] = ft
+ end
+ end
+ return r
+end
+
+--- @param filetype string
+--- @return string|nil
+function M.get_lang(filetype)
+ if filetype == '' then
+ return
+ end
+ if ft_to_lang[filetype] then
+ return ft_to_lang[filetype]
+ end
+ -- support subfiletypes like html.glimmer
+ filetype = vim.split(filetype, '.', { plain = true })[1]
+ return ft_to_lang[filetype]
+end
+
+---@deprecated
+function M.require_language(lang, path, silent, symbol_name)
+ local opts = {
+ silent = silent,
+ path = path,
+ symbol_name = symbol_name,
+ }
+
+ if silent then
+ local installed = pcall(M.add, lang, opts)
+ return installed
+ end
+
+ M.add(lang, opts)
+ return true
+end
+
+---@class treesitter.RequireLangOpts
+---@field path? string
+---@field silent? boolean
+---@field filetype? string|string[]
+---@field symbol_name? string
+
+--- Load parser with name {lang}
---
--- Parsers are searched in the `parser` runtime directory, or the provided {path}
---
----@param lang string Language the parser should parse
----@param path (string|nil) Optional path the parser is located at
----@param silent (boolean|nil) Don't throw an error if language not found
----@param symbol_name (string|nil) Internal symbol name for the language to load
----@return boolean If the specified language is installed
-function M.require_language(lang, path, silent, symbol_name)
+---@param lang string Name of the parser (alphanumerical and `_` only)
+---@param opts (table|nil) Options:
+--- - filetype (string|string[]) Default filetype the parser should be associated with.
+--- Defaults to {lang}.
+--- - path (string|nil) Optional path the parser is located at
+--- - symbol_name (string|nil) Internal symbol name for the language to load
+function M.add(lang, opts)
+ ---@cast opts treesitter.RequireLangOpts
+ opts = opts or {}
+ local path = opts.path
+ local filetype = opts.filetype or lang
+ local symbol_name = opts.symbol_name
+
+ vim.validate({
+ lang = { lang, 'string' },
+ path = { path, 'string', true },
+ symbol_name = { symbol_name, 'string', true },
+ filetype = { filetype, { 'string', 'table' }, true },
+ })
+
if vim._ts_has_language(lang) then
- return true
+ M.register(lang, filetype)
+ return
end
+
if path == nil then
- local fname = 'parser/' .. vim.fn.fnameescape(lang) .. '.*'
- local paths = a.nvim_get_runtime_file(fname, false)
- if #paths == 0 then
- if silent then
- return false
- end
+ if not (lang and lang:match('[%w_]+') == lang) then
+ error("'" .. lang .. "' is not a valid language name")
+ end
+ local fname = 'parser/' .. lang .. '.*'
+ local paths = api.nvim_get_runtime_file(fname, false)
+ if #paths == 0 then
error("no parser for '" .. lang .. "' language, see :help treesitter-parsers")
end
path = paths[1]
end
- if silent then
- return pcall(function()
- vim._ts_add_language(path, lang, symbol_name)
- end)
- else
- vim._ts_add_language(path, lang, symbol_name)
+ vim._ts_add_language(path, lang, symbol_name)
+ M.register(lang, filetype)
+end
+
+--- @param x string|string[]
+--- @return string[]
+local function ensure_list(x)
+ if type(x) == 'table' then
+ return x
end
+ return { x }
+end
- return true
+--- Register a parser named {lang} to be used for {filetype}(s).
+--- @param lang string Name of parser
+--- @param filetype string|string[] Filetype(s) to associate with lang
+function M.register(lang, filetype)
+ vim.validate({
+ lang = { lang, 'string' },
+ filetype = { filetype, { 'string', 'table' } },
+ })
+
+ for _, f in ipairs(ensure_list(filetype)) do
+ if f ~= '' then
+ ft_to_lang[f] = lang
+ end
+ end
end
--- Inspects the provided language.
@@ -45,9 +135,19 @@ end
---
---@param lang string Language
---@return table
-function M.inspect_language(lang)
- M.require_language(lang)
+function M.inspect(lang)
+ M.add(lang)
return vim._ts_inspect_language(lang)
end
+---@deprecated
+function M.inspect_language(...)
+ vim.deprecate(
+ 'vim.treesitter.language.inspect_language()',
+ 'vim.treesitter.language.inspect()',
+ '0.10'
+ )
+ return M.inspect(...)
+end
+
return M
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index a1e96f8ef2..0171b416cd 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -1,85 +1,257 @@
-local a = vim.api
+--- @defgroup lua-treesitter-languagetree
+---
+--- @brief A \*LanguageTree\* contains a tree of parsers: the root treesitter parser for {lang} and
+--- any "injected" language parsers, which themselves may inject other languages, recursively.
+--- For example a Lua buffer containing some Vimscript commands needs multiple parsers to fully
+--- understand its contents.
+---
+--- To create a LanguageTree (parser object) for a given buffer and language, use:
+---
+--- ```lua
+--- local parser = vim.treesitter.get_parser(bufnr, lang)
+--- ```
+---
+--- (where `bufnr=0` means current buffer). `lang` defaults to 'filetype'.
+--- Note: currently the parser is retained for the lifetime of a buffer but this may change;
+--- a plugin should keep a reference to the parser object if it wants incremental updates.
+---
+--- Whenever you need to access the current syntax tree, parse the buffer:
+---
+--- ```lua
+--- local tree = parser:parse({ start_row, end_row })
+--- ```
+---
+--- This returns a table of immutable |treesitter-tree| objects representing the current state of
+--- the buffer. When the plugin wants to access the state after a (possible) edit it must call
+--- `parse()` again. If the buffer wasn't edited, the same tree will be returned again without extra
+--- work. If the buffer was parsed before, incremental parsing will be done of the changed parts.
+---
+--- Note: To use the parser directly inside a |nvim_buf_attach()| Lua callback, you must call
+--- |vim.treesitter.get_parser()| before you register your callback. But preferably parsing
+--- shouldn't be done directly in the change callback anyway as they will be very frequent. Rather
+--- a plugin that does any kind of analysis on a tree should use a timer to throttle too frequent
+--- updates.
+---
+
+-- Debugging:
+--
+-- vim.g.__ts_debug levels:
+-- - 1. Messages from languagetree.lua
+-- - 2. Parse messages from treesitter
+-- - 3. Lex messages from treesitter
+--
+-- Log file can be found in stdpath('log')/treesitter.log
+
local query = require('vim.treesitter.query')
local language = require('vim.treesitter.language')
+local Range = require('vim.treesitter._range')
+
+---@alias TSCallbackName
+---| 'changedtree'
+---| 'bytes'
+---| 'detach'
+---| 'child_added'
+---| 'child_removed'
+
+---@alias TSCallbackNameOn
+---| 'on_changedtree'
+---| 'on_bytes'
+---| 'on_detach'
+---| 'on_child_added'
+---| 'on_child_removed'
+
+--- @type table<TSCallbackNameOn,TSCallbackName>
+local TSCallbackNames = {
+ on_changedtree = 'changedtree',
+ on_bytes = 'bytes',
+ on_detach = 'detach',
+ on_child_added = 'child_added',
+ on_child_removed = 'child_removed',
+}
---@class LanguageTree
----@field _callbacks function[] Callback handlers
----@field _children LanguageTree[] Injected languages
----@field _injection_query table Queries defining injected languages
----@field _opts table Options
----@field _parser userdata Parser for language
----@field _regions table List of regions this tree should manage and parse
----@field _lang string Language name
----@field _regions table
----@field _source (number|string) Buffer or string to parse
----@field _trees userdata[] Reference to parsed |tstree| (one for each language)
----@field _valid boolean If the parsed tree is valid
-
+---@field private _callbacks table<TSCallbackName,function[]> Callback handlers
+---@field package _callbacks_rec table<TSCallbackName,function[]> Callback handlers (recursive)
+---@field private _children table<string,LanguageTree> Injected languages
+---@field private _injection_query Query Queries defining injected languages
+---@field private _injections_processed boolean
+---@field private _opts table Options
+---@field private _parser TSParser Parser for language
+---@field private _has_regions boolean
+---@field private _regions table<integer, Range6[]>?
+---List of regions this tree should manage and parse. If nil then regions are
+---taken from _trees. This is mostly a short-lived cache for included_regions()
+---@field private _lang string Language name
+---@field private _parent_lang? string Parent language name
+---@field private _source (integer|string) Buffer or string to parse
+---@field private _trees table<integer, TSTree> Reference to parsed tree (one for each region).
+---Each key is the index of region, which is synced with _regions and _valid.
+---@field private _valid boolean|table<integer,boolean> If the parsed tree is valid
+---@field private _logger? fun(logtype: string, msg: string)
+---@field private _logfile? file*
local LanguageTree = {}
+
+---@class LanguageTreeOpts
+---@field queries table<string,string> -- Deprecated
+---@field injections table<string,string>
+
LanguageTree.__index = LanguageTree
---- A |LanguageTree| holds the treesitter parser for a given language {lang} used
---- to parse a buffer. As the buffer may contain injected languages, the LanguageTree
---- needs to store parsers for these child languages as well (which in turn may contain
---- child languages themselves, hence the name).
----
----@param source (number|string) Buffer or a string of text to parse
----@param lang string Root language this tree represents
----@param opts (table|nil) Optional keyword arguments:
---- - injections table Mapping language to injection query strings.
---- This is useful for overriding the built-in
---- runtime file searching for the injection language
---- query per language.
----@return LanguageTree |LanguageTree| parser object
-function LanguageTree.new(source, lang, opts)
- language.require_language(lang)
+--- @package
+---
+--- |LanguageTree| contains a tree of parsers: the root treesitter parser for {lang} and any
+--- "injected" language parsers, which themselves may inject other languages, recursively.
+---
+---@param source (integer|string) Buffer or text string to parse
+---@param lang string Root language of this tree
+---@param opts (table|nil) Optional arguments:
+--- - injections table Map of language to injection query strings. Overrides the
+--- built-in runtime file searching for language injections.
+---@param parent_lang? string Parent language name of this tree
+---@return LanguageTree parser object
+function LanguageTree.new(source, lang, opts, parent_lang)
+ language.add(lang)
+ ---@type LanguageTreeOpts
opts = opts or {}
- if opts.queries then
- a.nvim_err_writeln("'queries' is no longer supported. Use 'injections' now")
- opts.injections = opts.queries
+ if source == 0 then
+ source = vim.api.nvim_get_current_buf()
end
local injections = opts.injections or {}
- local self = setmetatable({
+
+ --- @type LanguageTree
+ local self = {
_source = source,
_lang = lang,
+ _parent_lang = parent_lang,
_children = {},
- _regions = {},
_trees = {},
_opts = opts,
- _injection_query = injections[lang] and query.parse_query(lang, injections[lang])
- or query.get_query(lang, 'injections'),
+ _injection_query = injections[lang] and query.parse(lang, injections[lang])
+ or query.get(lang, 'injections'),
+ _has_regions = false,
+ _injections_processed = false,
_valid = false,
_parser = vim._create_ts_parser(lang),
- _callbacks = {
- changedtree = {},
- bytes = {},
- detach = {},
- child_added = {},
- child_removed = {},
- },
- }, LanguageTree)
+ _callbacks = {},
+ _callbacks_rec = {},
+ }
+
+ setmetatable(self, LanguageTree)
+
+ if vim.g.__ts_debug and type(vim.g.__ts_debug) == 'number' then
+ self:_set_logger()
+ self:_log('START')
+ end
+
+ for _, name in pairs(TSCallbackNames) do
+ self._callbacks[name] = {}
+ self._callbacks_rec[name] = {}
+ end
return self
end
+--- @private
+function LanguageTree:_set_logger()
+ local source = self:source()
+ source = type(source) == 'string' and 'text' or tostring(source)
+
+ local lang = self:lang()
+
+ vim.fn.mkdir(vim.fn.stdpath('log'), 'p')
+ local logfilename = vim.fs.joinpath(vim.fn.stdpath('log'), 'treesitter.log')
+
+ local logfile, openerr = io.open(logfilename, 'a+')
+
+ if not logfile or openerr then
+ error(string.format('Could not open file (%s) for logging: %s', logfilename, openerr))
+ return
+ end
+
+ self._logfile = logfile
+
+ self._logger = function(logtype, msg)
+ self._logfile:write(string.format('%s:%s:(%s) %s\n', source, lang, logtype, msg))
+ self._logfile:flush()
+ end
+
+ local log_lex = vim.g.__ts_debug >= 3
+ local log_parse = vim.g.__ts_debug >= 2
+ self._parser:_set_logger(log_lex, log_parse, self._logger)
+end
+
+---Measure execution time of a function
+---@generic R1, R2, R3
+---@param f fun(): R1, R2, R3
+---@return number, R1, R2, R3
+local function tcall(f, ...)
+ local start = vim.uv.hrtime()
+ ---@diagnostic disable-next-line
+ local r = { f(...) }
+ --- @type number
+ local duration = (vim.uv.hrtime() - start) / 1000000
+ return duration, unpack(r)
+end
+
+---@private
+---@vararg any
+function LanguageTree:_log(...)
+ if not self._logger then
+ return
+ end
+
+ if not vim.g.__ts_debug or vim.g.__ts_debug < 1 then
+ return
+ end
+
+ local args = { ... }
+ if type(args[1]) == 'function' then
+ args = { args[1]() }
+ end
+
+ local info = debug.getinfo(2, 'nl')
+ local nregions = vim.tbl_count(self:included_regions())
+ local prefix =
+ string.format('%s:%d: (#regions=%d) ', info.name or '???', info.currentline or 0, nregions)
+
+ local msg = { prefix }
+ for _, x in ipairs(args) do
+ if type(x) == 'string' then
+ msg[#msg + 1] = x
+ else
+ msg[#msg + 1] = vim.inspect(x, { newline = ' ', indent = '' })
+ end
+ end
+ self._logger('nvim', table.concat(msg, ' '))
+end
+
--- Invalidates this parser and all its children
+---@param reload boolean|nil
function LanguageTree:invalidate(reload)
self._valid = false
-- buffer was reloaded, reparse all trees
if reload then
+ for _, t in pairs(self._trees) do
+ self:_do_callback('changedtree', t:included_ranges(true), t)
+ end
self._trees = {}
end
- for _, child in ipairs(self._children) do
+ for _, child in pairs(self._children) do
child:invalidate(reload)
end
end
---- Returns all trees this language tree contains.
+--- Returns all trees of the regions parsed by this parser.
--- Does not include child languages.
+--- The result is list-like if
+--- * this LanguageTree is the root, in which case the result is empty or a singleton list; or
+--- * the root LanguageTree is fully parsed.
+---
+---@return table<integer, TSTree>
function LanguageTree:trees()
return self._trees
end
@@ -89,11 +261,39 @@ function LanguageTree:lang()
return self._lang
end
---- Determines whether this tree is valid.
---- If the tree is invalid, call `parse()`.
---- This will return the updated tree.
-function LanguageTree:is_valid()
- return self._valid
+--- Returns whether this LanguageTree is valid, i.e., |LanguageTree:trees()| reflects the latest
+--- state of the source. If invalid, user should call |LanguageTree:parse()|.
+---@param exclude_children boolean|nil whether to ignore the validity of children (default `false`)
+---@return boolean
+function LanguageTree:is_valid(exclude_children)
+ local valid = self._valid
+
+ if type(valid) == 'table' then
+ for i, _ in pairs(self:included_regions()) do
+ if not valid[i] then
+ return false
+ end
+ end
+ end
+
+ if not exclude_children then
+ if not self._injections_processed then
+ return false
+ end
+
+ for _, child in pairs(self._children) do
+ if not child:is_valid(exclude_children) then
+ return false
+ end
+ end
+ end
+
+ if type(valid) == 'boolean' then
+ return valid
+ end
+
+ self._valid = true
+ return true
end
--- Returns a map of language to child tree.
@@ -106,50 +306,77 @@ function LanguageTree:source()
return self._source
end
---- Parses all defined regions using a treesitter parser
---- for the language this tree represents.
---- This will run the injection query for this language to
---- determine if any child languages should be created.
----
----@return userdata[] Table of parsed |tstree|
----@return table Change list
-function LanguageTree:parse()
- if self._valid then
- return self._trees
+--- @param region Range6[]
+--- @param range? boolean|Range
+--- @return boolean
+local function intercepts_region(region, range)
+ if #region == 0 then
+ return true
+ end
+
+ if range == nil then
+ return false
+ end
+
+ if type(range) == 'boolean' then
+ return range
end
- local parser = self._parser
+ for _, r in ipairs(region) do
+ if Range.intercepts(r, range) then
+ return true
+ end
+ end
+
+ return false
+end
+
+--- @private
+--- @param range boolean|Range?
+--- @return Range6[] changes
+--- @return integer no_regions_parsed
+--- @return number total_parse_time
+function LanguageTree:_parse_regions(range)
local changes = {}
+ local no_regions_parsed = 0
+ local total_parse_time = 0
- local old_trees = self._trees
- self._trees = {}
+ if type(self._valid) ~= 'table' then
+ self._valid = {}
+ end
-- If there are no ranges, set to an empty list
-- so the included ranges in the parser are cleared.
- if self._regions and #self._regions > 0 then
- for i, ranges in ipairs(self._regions) do
- local old_tree = old_trees[i]
- parser:set_included_ranges(ranges)
+ for i, ranges in pairs(self:included_regions()) do
+ if not self._valid[i] and intercepts_region(ranges, range) then
+ self._parser:set_included_ranges(ranges)
+ local parse_time, tree, tree_changes =
+ tcall(self._parser.parse, self._parser, self._trees[i], self._source, true)
- local tree, tree_changes = parser:parse(old_tree, self._source)
- self:_do_callback('changedtree', tree_changes, tree)
+ -- Pass ranges if this is an initial parse
+ local cb_changes = self._trees[i] and tree_changes or tree:included_ranges(true)
- table.insert(self._trees, tree)
+ self:_do_callback('changedtree', cb_changes, tree)
+ self._trees[i] = tree
vim.list_extend(changes, tree_changes)
- end
- else
- local tree, tree_changes = parser:parse(old_trees[1], self._source)
- self:_do_callback('changedtree', tree_changes, tree)
- table.insert(self._trees, tree)
- vim.list_extend(changes, tree_changes)
+ total_parse_time = total_parse_time + parse_time
+ no_regions_parsed = no_regions_parsed + 1
+ self._valid[i] = true
+ end
end
- local injections_by_lang = self:_get_injections()
- local seen_langs = {}
+ return changes, no_regions_parsed, total_parse_time
+end
+
+--- @private
+--- @return number
+function LanguageTree:_add_injections()
+ local seen_langs = {} ---@type table<string,boolean>
- for lang, injection_ranges in pairs(injections_by_lang) do
- local has_lang = language.require_language(lang, nil, true)
+ local query_time, injections_by_lang = tcall(self._get_injections, self)
+ for lang, injection_regions in pairs(injections_by_lang) do
+ local has_lang = pcall(language.add, lang)
-- Child language trees should just be ignored if not found, since
-- they can depend on the text of a node. Intermediate strings
@@ -161,16 +388,7 @@ function LanguageTree:parse()
child = self:add_child(lang)
end
- child:set_included_regions(injection_ranges)
-
- local _, child_changes = child:parse()
-
- -- Propagate any child changes so they are included in the
- -- the change list for the callback.
- if child_changes then
- vim.list_extend(changes, child_changes)
- end
-
+ child:set_included_regions(injection_regions)
seen_langs[lang] = true
end
end
@@ -181,16 +399,71 @@ function LanguageTree:parse()
end
end
- self._valid = true
+ return query_time
+end
+
+--- Recursively parse all regions in the language tree using |treesitter-parsers|
+--- for the corresponding languages and run injection queries on the parsed trees
+--- to determine whether child trees should be created and parsed.
+---
+--- Any region with empty range (`{}`, typically only the root tree) is always parsed;
+--- otherwise (typically injections) only if it intersects {range} (or if {range} is `true`).
+---
+--- @param range boolean|Range|nil: Parse this range in the parser's source.
+--- Set to `true` to run a complete parse of the source (Note: Can be slow!)
+--- Set to `false|nil` to only parse regions with empty ranges (typically
+--- only the root tree without injections).
+--- @return table<integer, TSTree>
+function LanguageTree:parse(range)
+ if self:is_valid() then
+ self:_log('valid')
+ return self._trees
+ end
- return self._trees, changes
+ local changes --- @type Range6[]?
+
+ -- Collect some stats
+ local no_regions_parsed = 0
+ local query_time = 0
+ local total_parse_time = 0
+
+ --- At least 1 region is invalid
+ if not self:is_valid(true) then
+ changes, no_regions_parsed, total_parse_time = self:_parse_regions(range)
+ -- Need to run injections when we parsed something
+ if no_regions_parsed > 0 then
+ self._injections_processed = false
+ end
+ end
+
+ if not self._injections_processed and range ~= false and range ~= nil then
+ query_time = self:_add_injections()
+ self._injections_processed = true
+ end
+
+ self:_log({
+ changes = changes and #changes > 0 and changes or nil,
+ regions_parsed = no_regions_parsed,
+ parse_time = total_parse_time,
+ query_time = query_time,
+ range = range,
+ })
+
+ for _, child in pairs(self._children) do
+ child:parse(range)
+ end
+
+ return self._trees
end
+---@deprecated Misleading name. Use `LanguageTree:children()` (non-recursive) instead,
+--- add recursion yourself if needed.
--- Invokes the callback for each |LanguageTree| and its children recursively
---
----@param fn function(tree: LanguageTree, lang: string)
----@param include_self boolean Whether to include the invoking tree in the results
+---@param fn fun(tree: LanguageTree, lang: string)
+---@param include_self boolean|nil Whether to include the invoking tree in the results
function LanguageTree:for_each_child(fn, include_self)
+ vim.deprecate('LanguageTree:for_each_child()', 'LanguageTree:children()', '0.11')
if include_self then
fn(self, self._lang)
end
@@ -204,9 +477,9 @@ end
---
--- Note: This includes the invoking tree's child trees as well.
---
----@param fn function(tree: TSTree, languageTree: LanguageTree)
+---@param fn fun(tree: TSTree, ltree: LanguageTree)
function LanguageTree:for_each_tree(fn)
- for _, tree in ipairs(self._trees) do
+ for _, tree in pairs(self._trees) do
fn(tree, self)
end
@@ -221,15 +494,20 @@ end
---
---@private
---@param lang string Language to add.
----@return LanguageTree Injected |LanguageTree|
+---@return LanguageTree injected
function LanguageTree:add_child(lang)
if self._children[lang] then
self:remove_child(lang)
end
- self._children[lang] = LanguageTree.new(self._source, lang, self._opts)
+ local child = LanguageTree.new(self._source, lang, self._opts, self:lang())
- self:invalidate()
+ -- Inherit recursive callbacks
+ for nm, cb in pairs(self._callbacks_rec) do
+ vim.list_extend(child._callbacks_rec[nm], cb)
+ end
+
+ self._children[lang] = child
self:_do_callback('child_added', self._children[lang])
return self._children[lang]
@@ -245,7 +523,6 @@ function LanguageTree:remove_child(lang)
if child then
self._children[lang] = nil
child:destroy()
- self:invalidate()
self:_do_callback('child_removed', child)
end
end
@@ -258,11 +535,60 @@ end
--- `remove_child` must be called on the parent to remove it.
function LanguageTree:destroy()
-- Cleanup here
- for _, child in ipairs(self._children) do
+ for _, child in pairs(self._children) do
child:destroy()
end
end
+---@param region Range6[]
+local function region_tostr(region)
+ if #region == 0 then
+ return '[]'
+ end
+ local srow, scol = region[1][1], region[1][2]
+ local erow, ecol = region[#region][4], region[#region][5]
+ return string.format('[%d:%d-%d:%d]', srow, scol, erow, ecol)
+end
+
+---@private
+---Iterate through all the regions. fn returns a boolean to indicate if the
+---region is valid or not.
+---@param fn fun(index: integer, region: Range6[]): boolean
+function LanguageTree:_iter_regions(fn)
+ if not self._valid then
+ return
+ end
+
+ local was_valid = type(self._valid) ~= 'table'
+
+ if was_valid then
+ self:_log('was valid', self._valid)
+ self._valid = {}
+ end
+
+ local all_valid = true
+
+ for i, region in pairs(self:included_regions()) do
+ if was_valid or self._valid[i] then
+ self._valid[i] = fn(i, region)
+ if not self._valid[i] then
+ self:_log(function()
+ return 'invalidating region', i, region_tostr(region)
+ end)
+ end
+ end
+
+ if not self._valid[i] then
+ all_valid = false
+ end
+ end
+
+ -- Compress the valid value to 'true' if there are no invalid regions
+ if all_valid then
+ self._valid = all_valid
+ end
+end
+
--- Sets the included regions that should be parsed by this |LanguageTree|.
--- A region is a set of nodes and/or ranges that will be parsed in the same context.
---
@@ -277,151 +603,253 @@ end
--- This allows for embedded languages to be parsed together across different
--- nodes, which is useful for templating languages like ERB and EJS.
---
---- Note: This call invalidates the tree and requires it to be parsed again.
----
---@private
----@param regions table List of regions this tree should manage and parse.
-function LanguageTree:set_included_regions(regions)
+---@param new_regions (Range4|Range6|TSNode)[][] List of regions this tree should manage and parse.
+function LanguageTree:set_included_regions(new_regions)
+ self._has_regions = true
+
-- Transform the tables from 4 element long to 6 element long (with byte offset)
- for _, region in ipairs(regions) do
+ for _, region in ipairs(new_regions) do
for i, range in ipairs(region) do
if type(range) == 'table' and #range == 4 then
- local start_row, start_col, end_row, end_col = unpack(range)
- local start_byte = 0
- local end_byte = 0
- -- TODO(vigoux): proper byte computation here, and account for EOL ?
- if type(self._source) == 'number' then
- -- Easy case, this is a buffer parser
- start_byte = a.nvim_buf_get_offset(self._source, start_row) + start_col
- end_byte = a.nvim_buf_get_offset(self._source, end_row) + end_col
- elseif type(self._source) == 'string' then
- -- string parser, single `\n` delimited string
- start_byte = vim.fn.byteidx(self._source, start_col)
- end_byte = vim.fn.byteidx(self._source, end_col)
- end
-
- region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
+ region[i] = Range.add_bytes(self._source, range --[[@as Range4]])
+ elseif type(range) == 'userdata' then
+ region[i] = { range:range(true) }
end
end
end
- self._regions = regions
- -- Trees are no longer valid now that we have changed regions.
- -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the
- -- old trees for incremental parsing. Currently, this only
- -- affects injected languages.
- self._trees = {}
- self:invalidate()
+ -- included_regions is not guaranteed to be list-like, but this is still sound, i.e. if
+ -- new_regions is different from included_regions, then outdated regions in included_regions are
+ -- invalidated. For example, if included_regions = new_regions ++ hole ++ outdated_regions, then
+ -- outdated_regions is invalidated by _iter_regions in else branch.
+ if #self:included_regions() ~= #new_regions then
+ -- TODO(lewis6991): inefficient; invalidate trees incrementally
+ for _, t in pairs(self._trees) do
+ self:_do_callback('changedtree', t:included_ranges(true), t)
+ end
+ self._trees = {}
+ self:invalidate()
+ else
+ self:_iter_regions(function(i, region)
+ return vim.deep_equal(new_regions[i], region)
+ end)
+ end
+
+ self._regions = new_regions
end
---- Gets the set of included regions
+---Gets the set of included regions managed by this LanguageTree. This can be different from the
+---regions set by injection query, because a partial |LanguageTree:parse()| drops the regions
+---outside the requested range.
+---@return table<integer, Range6[]>
function LanguageTree:included_regions()
- return self._regions
+ if self._regions then
+ return self._regions
+ end
+
+ if not self._has_regions then
+ -- treesitter.c will default empty ranges to { -1, -1, -1, -1, -1, -1} (the full range)
+ return { {} }
+ end
+
+ local regions = {} ---@type Range6[][]
+ for i, _ in pairs(self._trees) do
+ regions[i] = self._trees[i]:included_ranges(true)
+ end
+
+ self._regions = regions
+ return regions
+end
+
+---@param node TSNode
+---@param source string|integer
+---@param metadata TSMetadata
+---@param include_children boolean
+---@return Range6[]
+local function get_node_ranges(node, source, metadata, include_children)
+ local range = vim.treesitter.get_range(node, source, metadata)
+ local child_count = node:named_child_count()
+
+ if include_children or child_count == 0 then
+ return { range }
+ end
+
+ local ranges = {} ---@type Range6[]
+
+ local srow, scol, sbyte, erow, ecol, ebyte = Range.unpack6(range)
+
+ -- We are excluding children so we need to mask out their ranges
+ for i = 0, child_count - 1 do
+ local child = assert(node:named_child(i))
+ local c_srow, c_scol, c_sbyte, c_erow, c_ecol, c_ebyte = child:range(true)
+ if c_srow > srow or c_scol > scol then
+ ranges[#ranges + 1] = { srow, scol, sbyte, c_srow, c_scol, c_sbyte }
+ end
+ srow = c_erow
+ scol = c_ecol
+ sbyte = c_ebyte
+ end
+
+ if erow > srow or ecol > scol then
+ ranges[#ranges + 1] = Range.add_bytes(source, { srow, scol, sbyte, erow, ecol, ebyte })
+ end
+
+ return ranges
+end
+
+---@class TSInjectionElem
+---@field combined boolean
+---@field regions Range6[][]
+
+---@alias TSInjection table<string,table<integer,TSInjectionElem>>
+
+---@param t table<integer,TSInjection>
+---@param tree_index integer
+---@param pattern integer
+---@param lang string
+---@param combined boolean
+---@param ranges Range6[]
+local function add_injection(t, tree_index, pattern, lang, combined, ranges)
+ if #ranges == 0 then
+ -- Make sure not to add an empty range set as this is interpreted to mean the whole buffer.
+ return
+ end
+
+ -- Each tree index should be isolated from the other nodes.
+ if not t[tree_index] then
+ t[tree_index] = {}
+ end
+
+ if not t[tree_index][lang] then
+ t[tree_index][lang] = {}
+ end
+
+ -- Key this by pattern. If combined is set to true all captures of this pattern
+ -- will be parsed by treesitter as the same "source".
+ -- If combined is false, each "region" will be parsed as a single source.
+ if not t[tree_index][lang][pattern] then
+ t[tree_index][lang][pattern] = { combined = combined, regions = {} }
+ end
+
+ table.insert(t[tree_index][lang][pattern].regions, ranges)
+end
+
+-- TODO(clason): replace by refactored `ts.has_parser` API (without registering)
+--- The result of this function is cached to prevent nvim_get_runtime_file from being
+--- called too often
+--- @param lang string parser name
+--- @return boolean # true if parser for {lang} exists on rtp
+local has_parser = vim.func._memoize(1, function(lang)
+ return vim._ts_has_language(lang)
+ or #vim.api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0
+end)
+
+--- Return parser name for language (if exists) or filetype (if registered and exists).
+--- Also attempts with the input lower-cased.
+---
+---@param alias string language or filetype name
+---@return string? # resolved parser name
+local function resolve_lang(alias)
+ if has_parser(alias) then
+ return alias
+ end
+
+ if has_parser(alias:lower()) then
+ return alias:lower()
+ end
+
+ local lang = vim.treesitter.language.get_lang(alias)
+ if lang and has_parser(lang) then
+ return lang
+ end
+
+ lang = vim.treesitter.language.get_lang(alias:lower())
+ if lang and has_parser(lang) then
+ return lang
+ end
end
---@private
-local function get_range_from_metadata(node, id, metadata)
- if metadata[id] and metadata[id].range then
- return metadata[id].range
+--- Extract injections according to:
+--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
+---@param match table<integer,TSNode>
+---@param metadata TSMetadata
+---@return string?, boolean, Range6[]
+function LanguageTree:_get_injection(match, metadata)
+ local ranges = {} ---@type Range6[]
+ local combined = metadata['injection.combined'] ~= nil
+ local injection_lang = metadata['injection.language'] --[[@as string?]]
+ local lang = metadata['injection.self'] ~= nil and self:lang()
+ or metadata['injection.parent'] ~= nil and self._parent_lang
+ or (injection_lang and resolve_lang(injection_lang))
+ local include_children = metadata['injection.include-children'] ~= nil
+
+ for id, node in pairs(match) do
+ local name = self._injection_query.captures[id]
+ -- Lang should override any other language tag
+ if name == 'injection.language' then
+ local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] })
+ lang = resolve_lang(text)
+ elseif name == 'injection.content' then
+ ranges = get_node_ranges(node, self._source, metadata[id], include_children)
+ end
end
- return { node:range() }
+
+ return lang, combined, ranges
end
---- Gets language injection points by language.
+--- Can't use vim.tbl_flatten since a range is just a table.
+---@param regions Range6[][]
+---@return Range6[]
+local function combine_regions(regions)
+ local result = {} ---@type Range6[]
+ for _, region in ipairs(regions) do
+ for _, range in ipairs(region) do
+ result[#result + 1] = range
+ end
+ end
+ return result
+end
+
+--- Gets language injection regions by language.
---
--- This is where most of the injection processing occurs.
---
--- TODO: Allow for an offset predicate to tailor the injection range
--- instead of using the entire node's range.
----@private
+--- @private
+--- @return table<string, Range6[][]>
function LanguageTree:_get_injections()
if not self._injection_query then
return {}
end
+ ---@type table<integer,TSInjection>
local injections = {}
- for tree_index, tree in ipairs(self._trees) do
+ for index, tree in pairs(self._trees) do
local root_node = tree:root()
local start_line, _, end_line, _ = root_node:range()
for pattern, match, metadata in
self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
do
- local lang = nil
- local ranges = {}
- local combined = metadata.combined
-
- -- Directives can configure how injections are captured as well as actual node captures.
- -- This allows more advanced processing for determining ranges and language resolution.
- if metadata.content then
- local content = metadata.content
-
- -- Allow for captured nodes to be used
- if type(content) == 'number' then
- content = { match[content]:range() }
- end
-
- if type(content) == 'table' and #content >= 4 then
- vim.list_extend(ranges, content)
- end
+ local lang, combined, ranges = self:_get_injection(match, metadata)
+ if lang then
+ add_injection(injections, index, pattern, lang, combined, ranges)
+ else
+ self:_log('match from injection query failed for pattern', pattern)
end
-
- if metadata.language then
- lang = metadata.language
- end
-
- -- You can specify the content and language together
- -- using a tag with the language, for example
- -- @javascript
- for id, node in pairs(match) do
- local name = self._injection_query.captures[id]
-
- -- Lang should override any other language tag
- if name == 'language' and not lang then
- lang = query.get_node_text(node, self._source)
- elseif name == 'combined' then
- combined = true
- elseif name == 'content' and #ranges == 0 then
- table.insert(ranges, get_range_from_metadata(node, id, metadata))
- -- Ignore any tags that start with "_"
- -- Allows for other tags to be used in matches
- elseif string.sub(name, 1, 1) ~= '_' then
- if not lang then
- lang = name
- end
-
- if #ranges == 0 then
- table.insert(ranges, get_range_from_metadata(node, id, metadata))
- end
- end
- end
-
- -- Each tree index should be isolated from the other nodes.
- if not injections[tree_index] then
- injections[tree_index] = {}
- end
-
- if not injections[tree_index][lang] then
- injections[tree_index][lang] = {}
- end
-
- -- Key this by pattern. If combined is set to true all captures of this pattern
- -- will be parsed by treesitter as the same "source".
- -- If combined is false, each "region" will be parsed as a single source.
- if not injections[tree_index][lang][pattern] then
- injections[tree_index][lang][pattern] = { combined = combined, regions = {} }
- end
-
- table.insert(injections[tree_index][lang][pattern].regions, ranges)
end
end
+ ---@type table<string,Range6[][]>
local result = {}
-- Generate a map by lang of node lists.
-- Each list is a set of ranges that should be parsed together.
- for _, lang_map in ipairs(injections) do
+ for _, lang_map in pairs(injections) do
for lang, patterns in pairs(lang_map) do
if not result[lang] then
result[lang] = {}
@@ -429,12 +857,9 @@ function LanguageTree:_get_injections()
for _, entry in pairs(patterns) do
if entry.combined then
- local regions = vim.tbl_map(function(e)
- return vim.tbl_flatten(e)
- end, entry.regions)
- table.insert(result[lang], regions)
+ table.insert(result[lang], combine_regions(entry.regions))
else
- for _, ranges in ipairs(entry.regions) do
+ for _, ranges in pairs(entry.regions) do
table.insert(result[lang], ranges)
end
end
@@ -446,13 +871,94 @@ function LanguageTree:_get_injections()
end
---@private
+---@param cb_name TSCallbackName
function LanguageTree:_do_callback(cb_name, ...)
for _, cb in ipairs(self._callbacks[cb_name]) do
cb(...)
end
+ for _, cb in ipairs(self._callbacks_rec[cb_name]) do
+ cb(...)
+ end
end
----@private
+---@package
+function LanguageTree:_edit(
+ start_byte,
+ end_byte_old,
+ end_byte_new,
+ start_row,
+ start_col,
+ end_row_old,
+ end_col_old,
+ end_row_new,
+ end_col_new
+)
+ for _, tree in pairs(self._trees) do
+ tree:edit(
+ start_byte,
+ end_byte_old,
+ end_byte_new,
+ start_row,
+ start_col,
+ end_row_old,
+ end_col_old,
+ end_row_new,
+ end_col_new
+ )
+ end
+
+ self._regions = nil
+
+ local changed_range = {
+ start_row,
+ start_col,
+ start_byte,
+ end_row_old,
+ end_col_old,
+ end_byte_old,
+ }
+
+ -- Validate regions after editing the tree
+ self:_iter_regions(function(_, region)
+ if #region == 0 then
+ -- empty region, use the full source
+ return false
+ end
+ for _, r in ipairs(region) do
+ if Range.intercepts(r, changed_range) then
+ return false
+ end
+ end
+ return true
+ end)
+
+ for _, child in pairs(self._children) do
+ child:_edit(
+ start_byte,
+ end_byte_old,
+ end_byte_new,
+ start_row,
+ start_col,
+ end_row_old,
+ end_col_old,
+ end_row_new,
+ end_col_new
+ )
+ end
+end
+
+---@package
+---@param bufnr integer
+---@param changed_tick integer
+---@param start_row integer
+---@param start_col integer
+---@param start_byte integer
+---@param old_row integer
+---@param old_col integer
+---@param old_byte integer
+---@param new_row integer
+---@param new_col integer
+---@param new_byte integer
function LanguageTree:_on_bytes(
bufnr,
changed_tick,
@@ -466,26 +972,36 @@ function LanguageTree:_on_bytes(
new_col,
new_byte
)
- self:invalidate()
-
local old_end_col = old_col + ((old_row == 0) and start_col or 0)
local new_end_col = new_col + ((new_row == 0) and start_col or 0)
- -- Edit all trees recursively, together BEFORE emitting a bytes callback.
- -- In most cases this callback should only be called from the root tree.
- self:for_each_tree(function(tree)
- tree:edit(
- start_byte,
- start_byte + old_byte,
- start_byte + new_byte,
- start_row,
- start_col,
- start_row + old_row,
- old_end_col,
- start_row + new_row,
- new_end_col
- )
- end)
+ self:_log(
+ 'on_bytes',
+ bufnr,
+ changed_tick,
+ start_row,
+ start_col,
+ start_byte,
+ old_row,
+ old_col,
+ old_byte,
+ new_row,
+ new_col,
+ new_byte
+ )
+
+ -- Edit trees together BEFORE emitting a bytes callback.
+ self:_edit(
+ start_byte,
+ start_byte + old_byte,
+ start_byte + new_byte,
+ start_row,
+ start_col,
+ start_row + old_row,
+ old_end_col,
+ start_row + new_row,
+ new_end_col
+ )
self:_do_callback(
'bytes',
@@ -503,63 +1019,65 @@ function LanguageTree:_on_bytes(
)
end
----@private
+---@package
function LanguageTree:_on_reload()
self:invalidate(true)
end
----@private
+---@package
function LanguageTree:_on_detach(...)
self:invalidate(true)
self:_do_callback('detach', ...)
+ if self._logfile then
+ self._logger('nvim', 'detaching')
+ self._logger = nil
+ self._logfile:close()
+ end
end
--- Registers callbacks for the |LanguageTree|.
---@param cbs table An |nvim_buf_attach()|-like table argument with the following handlers:
--- - `on_bytes` : see |nvim_buf_attach()|, but this will be called _after_ the parsers callback.
--- - `on_changedtree` : a callback that will be called every time the tree has syntactical changes.
---- It will only be passed one argument, which is a table of the ranges (as node ranges) that
---- changed.
+--- It will be passed two arguments: a table of the ranges (as node ranges) that
+--- changed and the changed tree.
--- - `on_child_added` : emitted when a child is added to the tree.
--- - `on_child_removed` : emitted when a child is removed from the tree.
-function LanguageTree:register_cbs(cbs)
+--- - `on_detach` : emitted when the buffer is detached, see |nvim_buf_detach_event|.
+--- Takes one argument, the number of the buffer.
+--- @param recursive? boolean Apply callbacks recursively for all children. Any new children will
+--- also inherit the callbacks.
+function LanguageTree:register_cbs(cbs, recursive)
+ ---@cast cbs table<TSCallbackNameOn,function>
if not cbs then
return
end
- if cbs.on_changedtree then
- table.insert(self._callbacks.changedtree, cbs.on_changedtree)
- end
-
- if cbs.on_bytes then
- table.insert(self._callbacks.bytes, cbs.on_bytes)
- end
+ local callbacks = recursive and self._callbacks_rec or self._callbacks
- if cbs.on_detach then
- table.insert(self._callbacks.detach, cbs.on_detach)
- end
-
- if cbs.on_child_added then
- table.insert(self._callbacks.child_added, cbs.on_child_added)
+ for name, cbname in pairs(TSCallbackNames) do
+ if cbs[name] then
+ table.insert(callbacks[cbname], cbs[name])
+ end
end
- if cbs.on_child_removed then
- table.insert(self._callbacks.child_removed, cbs.on_child_removed)
+ if recursive then
+ for _, child in pairs(self._children) do
+ child:register_cbs(cbs, true)
+ end
end
end
----@private
+---@param tree TSTree
+---@param range Range
+---@return boolean
local function tree_contains(tree, range)
- local start_row, start_col, end_row, end_col = tree:root():range()
- local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2])
- local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4])
-
- return start_fits and end_fits
+ return Range.contains({ tree:root():range() }, range)
end
--- Determines whether {range} is contained in the |LanguageTree|.
---
----@param range table `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@return boolean
function LanguageTree:contains(range)
for _, tree in pairs(self._trees) do
@@ -573,20 +1091,19 @@ end
--- Gets the tree that contains {range}.
---
----@param range table `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@param opts table|nil Optional keyword arguments:
--- - ignore_injections boolean Ignore injected languages (default true)
----@return userdata|nil Contained |tstree|
+---@return TSTree|nil
function LanguageTree:tree_for_range(range, opts)
opts = opts or {}
local ignore = vim.F.if_nil(opts.ignore_injections, true)
if not ignore then
for _, child in pairs(self._children) do
- for _, tree in pairs(child:trees()) do
- if tree_contains(tree, range) then
- return tree
- end
+ local tree = child:tree_for_range(range, opts)
+ if tree then
+ return tree
end
end
end
@@ -602,10 +1119,10 @@ end
--- Gets the smallest named node that contains {range}.
---
----@param range table `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@param opts table|nil Optional keyword arguments:
--- - ignore_injections boolean Ignore injected languages (default true)
----@return userdata|nil Found |tsnode|
+---@return TSNode | nil Found node
function LanguageTree:named_node_for_range(range, opts)
local tree = self:tree_for_range(range, opts)
if tree then
@@ -615,7 +1132,7 @@ end
--- Gets the appropriate language that contains {range}.
---
----@param range table `{ start_line, start_col, end_line, end_col }`
+---@param range Range4 `{ start_line, start_col, end_line, end_col }`
---@return LanguageTree Managing {range}
function LanguageTree:language_for_range(range)
for _, child in pairs(self._children) do
diff --git a/runtime/lua/vim/treesitter/playground.lua b/runtime/lua/vim/treesitter/playground.lua
deleted file mode 100644
index bb073290c6..0000000000
--- a/runtime/lua/vim/treesitter/playground.lua
+++ /dev/null
@@ -1,186 +0,0 @@
-local api = vim.api
-
-local M = {}
-
----@class Playground
----@field ns number API namespace
----@field opts table Options table with the following keys:
---- - anon (boolean): If true, display anonymous nodes
---- - lang (boolean): If true, display the language alongside each node
----
----@class Node
----@field id number Node id
----@field text string Node text
----@field named boolean True if this is a named (non-anonymous) node
----@field depth number Depth of the node within the tree
----@field lnum number Beginning line number of this node in the source buffer
----@field col number Beginning column number of this node in the source buffer
----@field end_lnum number Final line number of this node in the source buffer
----@field end_col number Final column number of this node in the source buffer
----@field lang string Source language of this node
-
---- Traverse all child nodes starting at {node}.
----
---- This is a recursive function. The {depth} parameter indicates the current recursion level.
---- {lang} is a string indicating the language of the tree currently being traversed. Each traversed
---- node is added to {tree}. When recursion completes, {tree} is an array of all nodes in the order
---- they were visited.
----
---- {injections} is a table mapping node ids from the primary tree to language tree injections. Each
---- injected language has a series of trees nested within the primary language's tree, and the root
---- node of each of these trees is contained within a node in the primary tree. The {injections}
---- table maps nodes in the primary tree to root nodes of injected trees.
----
----@param node userdata Starting node to begin traversal |tsnode|
----@param depth number Current recursion depth
----@param lang string Language of the tree currently being traversed
----@param injections table Mapping of node ids to root nodes of injected language trees (see
---- explanation above)
----@param tree Node[] Output table containing a list of tables each representing a node in the tree
----@private
-local function traverse(node, depth, lang, injections, tree)
- local injection = injections[node:id()]
- if injection then
- traverse(injection.root, depth, injection.lang, injections, tree)
- end
-
- for child, field in node:iter_children() do
- local type = child:type()
- local lnum, col, end_lnum, end_col = child:range()
- local named = child:named()
- local text
- if named then
- if field then
- text = string.format('%s: (%s)', field, type)
- else
- text = string.format('(%s)', type)
- end
- else
- text = string.format('"%s"', type:gsub('\n', '\\n'))
- end
-
- table.insert(tree, {
- id = child:id(),
- text = text,
- named = named,
- depth = depth,
- lnum = lnum,
- col = col,
- end_lnum = end_lnum,
- end_col = end_col,
- lang = lang,
- })
-
- traverse(child, depth + 1, lang, injections, tree)
- end
-
- return tree
-end
-
---- Create a new Playground object.
----
----@param bufnr number Source buffer number
----@param lang string|nil Language of source buffer
----
----@return Playground|nil
----@return string|nil Error message, if any
----
----@private
-function M.new(self, bufnr, lang)
- local ok, parser = pcall(vim.treesitter.get_parser, bufnr or 0, lang)
- if not ok then
- return nil, 'No parser available for the given buffer'
- end
-
- -- For each child tree (injected language), find the root of the tree and locate the node within
- -- the primary tree that contains that root. Add a mapping from the node in the primary tree to
- -- the root in the child tree to the {injections} table.
- local root = parser:parse()[1]:root()
- local injections = {}
- parser:for_each_child(function(child, lang_)
- child:for_each_tree(function(tree)
- local r = tree:root()
- local node = root:named_descendant_for_range(r:range())
- if node then
- injections[node:id()] = {
- lang = lang_,
- root = r,
- }
- end
- end)
- end)
-
- local nodes = traverse(root, 0, parser:lang(), injections, {})
-
- local named = {}
- for _, v in ipairs(nodes) do
- if v.named then
- named[#named + 1] = v
- end
- end
-
- local t = {
- ns = api.nvim_create_namespace(''),
- nodes = nodes,
- named = named,
- opts = {
- anon = false,
- lang = false,
- },
- }
-
- setmetatable(t, self)
- self.__index = self
- return t
-end
-
---- Write the contents of this Playground into {bufnr}.
----
----@param bufnr number Buffer number to write into.
----@private
-function M.draw(self, bufnr)
- vim.bo[bufnr].modifiable = true
- local lines = {}
- for _, item in self:iter() do
- lines[#lines + 1] = table.concat({
- string.rep(' ', item.depth),
- item.text,
- item.lnum == item.end_lnum
- and string.format(' [%d:%d-%d]', item.lnum + 1, item.col + 1, item.end_col)
- or string.format(
- ' [%d:%d-%d:%d]',
- item.lnum + 1,
- item.col + 1,
- item.end_lnum + 1,
- item.end_col
- ),
- self.opts.lang and string.format(' %s', item.lang) or '',
- })
- end
- api.nvim_buf_set_lines(bufnr, 0, -1, false, lines)
- vim.bo[bufnr].modifiable = false
-end
-
---- Get node {i} from this Playground object.
----
---- The node number is dependent on whether or not anonymous nodes are displayed.
----
----@param i number Node number to get
----@return Node
----@private
-function M.get(self, i)
- local t = self.opts.anon and self.nodes or self.named
- return t[i]
-end
-
---- Iterate over all of the nodes in this Playground object.
----
----@return function Iterator over all nodes in this Playground
----@return table
----@return number
----@private
-function M.iter(self)
- return ipairs(self.opts.anon and self.nodes or self.named)
-end
-
-return M
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index dbf134573d..8cbbffcd60 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,21 +1,25 @@
-local a = vim.api
+local api = vim.api
local language = require('vim.treesitter.language')
--- query: pattern matching on trees
--- predicate matching is implemented in lua
---
---@class Query
---@field captures string[] List of captures used in query
----@field info table Contains used queries, predicates, directives
+---@field info TSQueryInfo Contains used queries, predicates, directives
---@field query userdata Parsed query
local Query = {}
Query.__index = Query
+---@class TSQueryInfo
+---@field captures table
+---@field patterns table<string,any[][]>
+
+---@class TSQueryModule
local M = {}
----@private
+---@param files string[]
+---@return string[]
local function dedupe_files(files)
local result = {}
+ ---@type table<string,boolean>
local seen = {}
for _, path in ipairs(files) do
@@ -28,7 +32,6 @@ local function dedupe_files(files)
return result
end
----@private
local function safe_read(filename, read_quantifier)
local file, err = io.open(filename, 'r')
if not file then
@@ -39,7 +42,6 @@ local function safe_read(filename, read_quantifier)
return content
end
----@private
--- Adds {ilang} to {base_langs}, only if {ilang} is different than {lang}
---
---@return boolean true If lang == ilang
@@ -51,24 +53,34 @@ local function add_included_lang(base_langs, lang, ilang)
return false
end
+---@deprecated
+function M.get_query_files(...)
+ vim.deprecate(
+ 'vim.treesitter.query.get_query_files()',
+ 'vim.treesitter.query.get_files()',
+ '0.10'
+ )
+ return M.get_files(...)
+end
+
--- Gets the list of files used to make up a query
---
---@param lang string Language to get query for
---@param query_name string Name of the query to load (e.g., "highlights")
---@param is_included (boolean|nil) Internal parameter, most of the time left as `nil`
---@return string[] query_files List of files to load for given query and language
-function M.get_query_files(lang, query_name, is_included)
+function M.get_files(lang, query_name, is_included)
local query_path = string.format('queries/%s/%s.scm', lang, query_name)
- local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true))
+ local lang_files = dedupe_files(api.nvim_get_runtime_file(query_path, true))
if #lang_files == 0 then
return {}
end
- local base_query = nil
+ local base_query = nil ---@type string?
local extensions = {}
- local base_langs = {}
+ local base_langs = {} ---@type string[]
-- Now get the base languages by looking at the first line of every file
-- The syntax is the following :
@@ -87,6 +99,7 @@ function M.get_query_files(lang, query_name, is_included)
local extension = false
for modeline in
+ ---@return string
function()
return file:read('*l')
end
@@ -97,6 +110,7 @@ function M.get_query_files(lang, query_name, is_included)
local langlist = modeline:match(MODELINE_FORMAT)
if langlist then
+ ---@diagnostic disable-next-line:param-type-mismatch
for _, incllang in ipairs(vim.split(langlist, ',', true)) do
local is_optional = incllang:match('%(.*%)')
@@ -127,7 +141,7 @@ function M.get_query_files(lang, query_name, is_included)
local query_files = {}
for _, base_lang in ipairs(base_langs) do
- local base_files = M.get_query_files(base_lang, query_name, true)
+ local base_files = M.get_files(base_lang, query_name, true)
vim.list_extend(query_files, base_files)
end
vim.list_extend(query_files, { base_query })
@@ -136,7 +150,8 @@ function M.get_query_files(lang, query_name, is_included)
return query_files
end
----@private
+---@param filenames string[]
+---@return string
local function read_query_files(filenames)
local contents = {}
@@ -147,7 +162,8 @@ local function read_query_files(filenames)
return table.concat(contents, '')
end
---- The explicitly set queries from |vim.treesitter.query.set_query()|
+-- The explicitly set queries from |vim.treesitter.query.set()|
+---@type table<string,table<string,Query>>
local explicit_queries = setmetatable({}, {
__index = function(t, k)
local lang_queries = {}
@@ -157,6 +173,12 @@ local explicit_queries = setmetatable({}, {
end,
})
+---@deprecated
+function M.set_query(...)
+ vim.deprecate('vim.treesitter.query.set_query()', 'vim.treesitter.query.set()', '0.10')
+ M.set(...)
+end
+
--- Sets the runtime query named {query_name} for {lang}
---
--- This allows users to override any runtime files and/or configuration
@@ -165,8 +187,14 @@ local explicit_queries = setmetatable({}, {
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g., "highlights")
---@param text string Query text (unparsed).
-function M.set_query(lang, query_name, text)
- explicit_queries[lang][query_name] = M.parse_query(lang, text)
+function M.set(lang, query_name, text)
+ explicit_queries[lang][query_name] = M.parse(lang, text)
+end
+
+---@deprecated
+function M.get_query(...)
+ vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10')
+ return M.get(...)
end
--- Returns the runtime query {query_name} for {lang}.
@@ -174,24 +202,28 @@ end
---@param lang string Language to use for the query
---@param query_name string Name of the query (e.g. "highlights")
---
----@return Query Parsed query
-function M.get_query(lang, query_name)
+---@return Query|nil Parsed query
+M.get = vim.func._memoize('concat-2', function(lang, query_name)
if explicit_queries[lang][query_name] then
return explicit_queries[lang][query_name]
end
- local query_files = M.get_query_files(lang, query_name)
+ local query_files = M.get_files(lang, query_name)
local query_string = read_query_files(query_files)
- if #query_string > 0 then
- return M.parse_query(lang, query_string)
+ if #query_string == 0 then
+ return nil
end
-end
-local query_cache = vim.defaulttable(function()
- return setmetatable({}, { __mode = 'v' })
+ return M.parse(lang, query_string)
end)
+---@deprecated
+function M.parse_query(...)
+ vim.deprecate('vim.treesitter.query.parse_query()', 'vim.treesitter.query.parse()', '0.10')
+ return M.parse(...)
+end
+
--- Parse {query} as a string. (If the query is in a file, the caller
--- should read the contents into a string before calling).
---
@@ -209,81 +241,50 @@ end)
---@param query string Query in s-expr syntax
---
---@return Query Parsed query
-function M.parse_query(lang, query)
- language.require_language(lang)
- local cached = query_cache[lang][query]
- if cached then
- return cached
- else
- local self = setmetatable({}, Query)
- self.query = vim._ts_parse_query(lang, query)
- self.info = self.query:inspect()
- self.captures = self.info.captures
- query_cache[lang][query] = self
- return self
- end
-end
+M.parse = vim.func._memoize('concat-2', function(lang, query)
+ language.add(lang)
+
+ local self = setmetatable({}, Query)
+ self.query = vim._ts_parse_query(lang, query)
+ self.info = self.query:inspect()
+ self.captures = self.info.captures
+ return self
+end)
---- Gets the text corresponding to a given node
----
----@param node userdata |tsnode|
----@param source (number|string) Buffer or string from which the {node} is extracted
----@param opts (table|nil) Optional parameters.
---- - concat: (boolean) Concatenate result in a string (default true)
----@return (string[]|string)
-function M.get_node_text(node, source, opts)
- opts = opts or {}
- local concat = vim.F.if_nil(opts.concat, true)
-
- local start_row, start_col, start_byte = node:start()
- local end_row, end_col, end_byte = node:end_()
-
- if type(source) == 'number' then
- local lines
- local eof_row = a.nvim_buf_line_count(source)
- if start_row >= eof_row then
- return nil
- end
+---@deprecated
+function M.get_range(...)
+ vim.deprecate('vim.treesitter.query.get_range()', 'vim.treesitter.get_range()', '0.10')
+ return vim.treesitter.get_range(...)
+end
- if end_col == 0 then
- lines = a.nvim_buf_get_lines(source, start_row, end_row, true)
- end_col = -1
- else
- lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true)
- end
+---@deprecated
+function M.get_node_text(...)
+ vim.deprecate('vim.treesitter.query.get_node_text()', 'vim.treesitter.get_node_text()', '0.10')
+ return vim.treesitter.get_node_text(...)
+end
- if #lines > 0 then
- if #lines == 1 then
- lines[1] = string.sub(lines[1], start_col + 1, end_col)
- else
- lines[1] = string.sub(lines[1], start_col + 1)
- lines[#lines] = string.sub(lines[#lines], 1, end_col)
- end
- end
+---@alias TSMatch table<integer,TSNode>
- return concat and table.concat(lines, '\n') or lines
- elseif type(source) == 'string' then
- return source:sub(start_byte + 1, end_byte)
- end
-end
+---@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean
-- Predicate handler receive the following arguments
-- (match, pattern, bufnr, predicate)
+---@type table<string,TSPredicate>
local predicate_handlers = {
['eq?'] = function(match, _, source, predicate)
local node = match[predicate[2]]
if not node then
return true
end
- local node_text = M.get_node_text(node, source)
+ local node_text = vim.treesitter.get_node_text(node, source)
- local str
+ local str ---@type string
if type(predicate[3]) == 'string' then
-- (#eq? @aa "foo")
str = predicate[3]
else
-- (#eq? @aa @bb)
- str = M.get_node_text(match[predicate[3]], source)
+ str = vim.treesitter.get_node_text(match[predicate[3]], source)
end
if node_text ~= str or str == nil then
@@ -299,12 +300,11 @@ local predicate_handlers = {
return true
end
local regex = predicate[3]
- return string.find(M.get_node_text(node, source), regex)
+ return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
end,
['match?'] = (function()
local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
- ---@private
local function check_magic(str)
if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
return str
@@ -321,12 +321,14 @@ local predicate_handlers = {
})
return function(match, _, source, pred)
+ ---@cast match TSMatch
local node = match[pred[2]]
if not node then
return true
end
+ ---@diagnostic disable-next-line no-unknown
local regex = compiled_vim_regexes[pred[3]]
- return regex:match_str(M.get_node_text(node, source))
+ return regex:match_str(vim.treesitter.get_node_text(node, source))
end
end)(),
@@ -335,7 +337,7 @@ local predicate_handlers = {
if not node then
return true
end
- local node_text = M.get_node_text(node, source)
+ local node_text = vim.treesitter.get_node_text(node, source)
for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then
@@ -351,7 +353,7 @@ local predicate_handlers = {
if not node then
return true
end
- local node_text = M.get_node_text(node, source)
+ local node_text = vim.treesitter.get_node_text(node, source)
-- Since 'predicate' will not be used by callers of this function, use it
-- to store a string set built from the list of words to check against.
@@ -359,6 +361,7 @@ local predicate_handlers = {
if not string_set then
string_set = {}
for i = 3, #predicate do
+ ---@diagnostic disable-next-line:no-unknown
string_set[predicate[i]] = true
end
predicate['string_set'] = string_set
@@ -366,36 +369,85 @@ local predicate_handlers = {
return string_set[node_text]
end,
+
+ ['has-ancestor?'] = function(match, _, _, predicate)
+ local node = match[predicate[2]]
+ if not node then
+ return true
+ end
+
+ local ancestor_types = {}
+ for _, type in ipairs({ unpack(predicate, 3) }) do
+ ancestor_types[type] = true
+ end
+
+ node = node:parent()
+ while node do
+ if ancestor_types[node:type()] then
+ return true
+ end
+ node = node:parent()
+ end
+ return false
+ end,
+
+ ['has-parent?'] = function(match, _, _, predicate)
+ local node = match[predicate[2]]
+ if not node then
+ return true
+ end
+
+ if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then
+ return true
+ end
+ return false
+ end,
}
-- As we provide lua-match? also expose vim-match?
predicate_handlers['vim-match?'] = predicate_handlers['match?']
+---@class TSMetadata
+---@field range? Range
+---@field conceal? string
+---@field [integer] TSMetadata
+---@field [string] integer|string
+
+---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata)
+
+-- Predicate handler receive the following arguments
+-- (match, pattern, bufnr, predicate)
+
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
-- (match, pattern, bufnr, predicate, metadata)
+---@type table<string,TSDirective>
local directive_handlers = {
['set!'] = function(_, _, _, pred, metadata)
- if #pred == 4 then
- -- (#set! @capture "key" "value")
- local _, capture_id, key, value = unpack(pred)
+ if #pred >= 3 and type(pred[2]) == 'number' then
+ -- (#set! @capture key value)
+ local capture_id, key, value = pred[2], pred[3], pred[4]
if not metadata[capture_id] then
metadata[capture_id] = {}
end
metadata[capture_id][key] = value
else
- local _, key, value = unpack(pred)
- -- (#set! "key" "value")
- metadata[key] = value
+ -- (#set! key value)
+ local key, value = pred[2], pred[3]
+ metadata[key] = value or true
end
end,
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
['offset!'] = function(match, _, _, pred, metadata)
+ ---@cast pred integer[]
local capture_id = pred[2]
- local offset_node = match[capture_id]
- local range = { offset_node:range() }
+ if not metadata[capture_id] then
+ metadata[capture_id] = {}
+ end
+
+ local range = metadata[capture_id].range or { match[capture_id]:range() }
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
@@ -408,19 +460,74 @@ local directive_handlers = {
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
- if not metadata[capture_id] then
- metadata[capture_id] = {}
- end
metadata[capture_id].range = range
end
end,
+ -- Transform the content of the node
+ -- Example: (#gsub! @_node ".*%.(.*)" "%1")
+ ['gsub!'] = function(match, _, bufnr, pred, metadata)
+ assert(#pred == 4)
+
+ local id = pred[2]
+ assert(type(id) == 'number')
+
+ local node = match[id]
+ local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
+
+ if not metadata[id] then
+ metadata[id] = {}
+ end
+
+ local pattern, replacement = pred[3], pred[4]
+ assert(type(pattern) == 'string')
+ assert(type(replacement) == 'string')
+
+ metadata[id].text = text:gsub(pattern, replacement)
+ end,
+ -- Trim blank lines from end of the node
+ -- Example: (#trim! @fold)
+ -- TODO(clason): generalize to arbitrary whitespace removal
+ ['trim!'] = function(match, _, bufnr, pred, metadata)
+ local capture_id = pred[2]
+ assert(type(capture_id) == 'number')
+
+ local node = match[capture_id]
+ if not node then
+ return
+ end
+
+ local start_row, start_col, end_row, end_col = node:range()
+
+ -- Don't trim if region ends in middle of a line
+ if end_col ~= 0 then
+ return
+ end
+
+ while end_row >= start_row do
+ -- As we only care when end_col == 0, always inspect one line above end_row.
+ local end_line = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1]
+
+ if end_line ~= '' then
+ break
+ end
+
+ end_row = end_row - 1
+ end
+
+ -- If this produces an invalid range, we just skip it.
+ if start_row < end_row or (start_row == end_row and start_col <= end_col) then
+ metadata[capture_id] = metadata[capture_id] or {}
+ metadata[capture_id].range = { start_row, start_col, end_row, end_col }
+ end
+ end,
}
--- Adds a new predicate to be used in queries
---
---@param name string Name of the predicate, without leading #
----@param handler function(match:table, pattern:string, bufnr:number, predicate:string[])
+---@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[])
--- - see |vim.treesitter.query.add_directive()| for argument meanings
+---@param force boolean|nil
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
error(string.format('Overriding %s', name))
@@ -437,12 +544,13 @@ end
--- metadata table `metadata[capture_id].key = value`
---
---@param name string Name of the directive, without leading #
----@param handler function(match:table, pattern:string, bufnr:number, predicate:string[], metadata:table)
+---@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table)
--- - match: see |treesitter-query|
--- - node-level data are accessible via `match[capture_id]`
--- - pattern: see |treesitter-query|
--- - predicate: list of strings containing the full directive being called, e.g.
--- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }`
+---@param force boolean|nil
function M.add_directive(name, handler, force)
if directive_handlers[name] and not force then
error(string.format('Overriding %s', name))
@@ -463,17 +571,18 @@ function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
end
----@private
-- Exclusive-or on Lua truthiness: the result is truthy only when exactly one of
-- {x} and {y} is truthy. When neither is truthy the falsy value of `x or y` is
-- returned unchanged (nil or false); otherwise the result is a boolean.
local function xor(x, y)
  local either = x or y
  if not either then
    return either
  end
  return not (x and y)
end
----@private
-- A name is treated as a directive when its last character is '!'.
local function is_directive(name)
  local last_char = string.sub(name, -1)
  return last_char == '!'
end
---@private
+---@param match TSMatch
+---@param pattern string
+---@param source integer|string
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -482,8 +591,9 @@ function Query:match_preds(match, pattern, source)
-- continue on the other case. This way unknown predicates will not be considered,
-- which allows some testing and easier user extensibility (#12173).
-- Also, tree-sitter strips the leading # from predicates for us.
- local pred_name
- local is_not
+ local pred_name ---@type string
+
+ local is_not ---@type boolean
-- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
@@ -513,6 +623,8 @@ function Query:match_preds(match, pattern, source)
end
---@private
+---@param match TSMatch
+---@param metadata TSMetadata
function Query:apply_directives(match, pattern, source, metadata)
local preds = self.info.patterns[pattern]
@@ -533,7 +645,10 @@ end
--- Returns the start and stop value if set else the node's range.
-- When the node's range is used, the stop is incremented by 1
-- to make the search inclusive.
----@private
+---@param start integer
+---@param stop integer
+---@param node TSNode
+---@return integer, integer
local function value_or_node_range(start, stop, node)
if start == nil and stop == nil then
local node_start, _, node_stop, _ = node:range()
@@ -547,42 +662,41 @@ end
---
--- {source} is needed if the query contains predicates; then the caller
--- must ensure to use a freshly parsed tree consistent with the current
---- text of the buffer (if relevant). {start_row} and {end_row} can be used to limit
+--- text of the buffer (if relevant). {start} and {stop} can be used to limit
--- matches inside a row range (this is typically used with root node
--- as the {node}, i.e., to get syntax highlight matches in the current
---- viewport). When omitted, the {start} and {end} row values are used from the given node.
+--- viewport). When omitted, the {start} and {stop} row values are used from the given node.
---
--- The iterator returns three values: a numeric id identifying the capture,
--- the captured node, and metadata from any directives processing the match.
--- The following example shows how to get captures by name:
---- <pre>lua
+---
+--- ```lua
--- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do
--- local name = query.captures[id] -- name of the capture in the query
--- -- typically useful info about the node:
--- local type = node:type() -- type of the captured node
--- local row1, col1, row2, col2 = node:range() -- range of the capture
---- ... use the info here ...
+--- -- ... use the info here ...
--- end
---- </pre>
+--- ```
---
----@param node userdata |tsnode| under which the search will occur
----@param source (number|string) Source buffer or string to extract text from
----@param start number Starting line for the search
----@param stop number Stopping line for the search (end-exclusive)
+---@param node TSNode under which the search will occur
+---@param source (integer|string) Source buffer or string to extract text from
+---@param start integer Starting line for the search
+---@param stop integer Stopping line for the search (end-exclusive)
---
----@return number capture Matching capture id
----@return table capture_node Capture for {node}
----@return table metadata for the {capture}
+---@return (fun(end_line: integer|nil): integer, TSNode, TSMetadata):
+--- capture id, capture node, metadata
function Query:iter_captures(node, source, start, stop)
if type(source) == 'number' and source == 0 then
- source = vim.api.nvim_get_current_buf()
+ source = api.nvim_get_current_buf()
end
start, stop = value_or_node_range(start, stop, node)
local raw_iter = node:_rawquery(self.query, true, start, stop)
- ---@private
- local function iter()
+ local function iter(end_line)
local capture, captured_node, match = raw_iter()
local metadata = {}
@@ -590,7 +704,10 @@ function Query:iter_captures(node, source, start, stop)
local active = self:match_preds(match, match.pattern, source)
match.active = active
if not active then
- return iter() -- tail call: try next match
+ if end_line and captured_node:range() > end_line then
+ return nil, captured_node, nil
+ end
+ return iter(end_line) -- tail call: try next match
end
self:apply_directives(match, match.pattern, source, metadata)
@@ -609,7 +726,8 @@ end
--- If the query has more than one pattern, the capture table might be sparse
--- and e.g. `pairs()` method should be used over `ipairs`.
--- Here is an example iterating over all captures in every match:
---- <pre>lua
+---
+--- ```lua
--- for pattern, match, metadata in query:iter_matches(tree:root(), bufnr, first, last) do
--- for id, node in pairs(match) do
--- local name = query.captures[id]
@@ -617,27 +735,30 @@ end
---
--- local node_data = metadata[id] -- Node level metadata
---
---- ... use the info here ...
+--- -- ... use the info here ...
--- end
--- end
---- </pre>
+--- ```
---
----@param node userdata |tsnode| under which the search will occur
----@param source (number|string) Source buffer or string to search
----@param start number Starting line for the search
----@param stop number Stopping line for the search (end-exclusive)
+---@param node TSNode under which the search will occur
+---@param source (integer|string) Source buffer or string to search
+---@param start integer Starting line for the search
+---@param stop integer Stopping line for the search (end-exclusive)
+---@param opts table|nil Options:
+--- - max_start_depth (integer) if non-zero, sets the maximum start depth
+--- for each match. This is used to prevent traversing too deep into a tree.
+--- Requires treesitter >= 0.20.9.
---
----@return number pattern id
----@return table match
----@return table metadata
-function Query:iter_matches(node, source, start, stop)
+---@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata
+function Query:iter_matches(node, source, start, stop, opts)
if type(source) == 'number' and source == 0 then
- source = vim.api.nvim_get_current_buf()
+ source = api.nvim_get_current_buf()
end
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, false, start, stop)
+ local raw_iter = node:_rawquery(self.query, false, start, stop, opts)
+ ---@cast raw_iter fun(): string, any
local function iter()
local pattern, match = raw_iter()
local metadata = {}
@@ -655,4 +776,58 @@ function Query:iter_matches(node, source, start, stop)
return iter
end
+---@class QueryLinterOpts
+---@field langs (string|string[]|nil)
+---@field clear (boolean)
+
+--- Lint treesitter queries using installed parser, or clear lint errors.
+---
+--- Use |treesitter-parsers| in runtimepath to check the query file in {buf} for errors:
+---
+--- - verify that used nodes are valid identifiers in the grammar.
+--- - verify that predicates and directives are valid.
+--- - verify that top-level s-expressions are valid.
+---
+--- The found diagnostics are reported using |diagnostic-api|.
+--- By default, the parser used for verification is determined by the containing folder
+--- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the
+--- `lua` language will be used.
+---@param buf (integer) Buffer handle
+---@param opts (QueryLinterOpts|nil) Optional keyword arguments:
+--- - langs (string|string[]|nil) Language(s) to use for checking the query.
+--- If multiple languages are specified, queries are validated for all of them
+--- - clear (boolean) if `true`, just clear current lint errors
+function M.lint(buf, opts)
+  local linter = require('vim.treesitter._query_linter')
+
+  -- With `opts.clear` set, only the existing lint results are cleared; no lint pass runs.
+  if opts and opts.clear then
+    linter.clear(buf)
+    return
+  end
+
+  -- Otherwise lint {buf}; {opts} (possibly nil) is forwarded unchanged.
+  linter.lint(buf, opts)
+end
+
+--- Omnifunc for completing node names and predicates in treesitter queries.
+---
+--- Use via
+---
+--- ```lua
+--- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc'
+--- ```
+---
+function M.omnifunc(findstart, base)
+  -- Thin wrapper: both arguments and all return values pass straight through to the
+  -- query linter module, which is only loaded when this function is called.
+  -- NOTE(review): {findstart}/{base} presumably follow the |complete-functions|
+  -- protocol -- confirm against the linter implementation.
+  local linter = require('vim.treesitter._query_linter')
+  return linter.omnifunc(findstart, base)
+end
+
+--- Opens a live editor to query the buffer you started from.
+---
+--- Can also be shown with *:EditQuery*.
+---
+--- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in
+--- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find
+--- example queries at `$VIMRUNTIME/queries/`.
+---
+--- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype.
+function M.edit(lang)
+  -- The dev module is loaded on demand; {lang} (possibly nil) is forwarded unchanged.
+  local dev = require('vim.treesitter.dev')
+  dev.edit_query(lang)
+end
+
return M