diff options
author | Josh Rahm <joshuarahm@gmail.com> | 2023-11-30 20:35:25 +0000 |
---|---|---|
committer | Josh Rahm <joshuarahm@gmail.com> | 2023-11-30 20:35:25 +0000 |
commit | 1b7b916b7631ddf73c38e3a0070d64e4636cb2f3 (patch) | |
tree | cd08258054db80bb9a11b1061bb091c70b76926a /runtime/lua/vim/treesitter | |
parent | eaa89c11d0f8aefbb512de769c6c82f61a8baca3 (diff) | |
parent | 4a8bf24ac690004aedf5540fa440e788459e5e34 (diff) | |
download | rneovim-1b7b916b7631ddf73c38e3a0070d64e4636cb2f3.tar.gz rneovim-1b7b916b7631ddf73c38e3a0070d64e4636cb2f3.tar.bz2 rneovim-1b7b916b7631ddf73c38e3a0070d64e4636cb2f3.zip |
Merge remote-tracking branch 'upstream/master' into aucmd_textputpost
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 456 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta.lua | 80 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_query_linter.lua | 249 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_range.lua | 193 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/dev.lua | 645 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/health.lua | 30 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 183 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 144 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 1041 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/playground.lua | 186 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 441 |
11 files changed, 2957 insertions, 691 deletions
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua new file mode 100644 index 0000000000..5c1cc06908 --- /dev/null +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -0,0 +1,456 @@ +local ts = vim.treesitter + +local Range = require('vim.treesitter._range') + +local api = vim.api + +---@class TS.FoldInfo +---@field levels table<integer,string> +---@field levels0 table<integer,integer> +---@field private start_counts table<integer,integer> +---@field private stop_counts table<integer,integer> +local FoldInfo = {} +FoldInfo.__index = FoldInfo + +---@private +function FoldInfo.new() + return setmetatable({ + start_counts = {}, + stop_counts = {}, + levels0 = {}, + levels = {}, + }, FoldInfo) +end + +---@package +---@param srow integer +---@param erow integer +function FoldInfo:invalidate_range(srow, erow) + for i = srow, erow do + self.start_counts[i + 1] = nil + self.stop_counts[i + 1] = nil + self.levels0[i + 1] = nil + self.levels[i + 1] = nil + end +end + +--- Efficiently remove items from middle of a list a list. +--- +--- Calling table.remove() in a loop will re-index the tail of the table on +--- every iteration, instead this function will re-index the table exactly +--- once. +--- +--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 +--- +---@param t any[] +---@param first integer +---@param last integer +local function list_remove(t, first, last) + local n = #t + for i = 0, n - first do + t[first + i] = t[last + 1 + i] + t[last + 1 + i] = nil + end +end + +---@package +---@param srow integer +---@param erow integer +function FoldInfo:remove_range(srow, erow) + list_remove(self.levels, srow + 1, erow) + list_remove(self.levels0, srow + 1, erow) + list_remove(self.start_counts, srow + 1, erow) + list_remove(self.stop_counts, srow + 1, erow) +end + +--- Efficiently insert items into the middle of a list. 
+--- +--- Calling table.insert() in a loop will re-index the tail of the table on +--- every iteration, instead this function will re-index the table exactly +--- once. +--- +--- Based on https://stackoverflow.com/questions/12394841/safely-remove-items-from-an-array-table-while-iterating/53038524#53038524 +--- +---@param t any[] +---@param first integer +---@param last integer +---@param v any +local function list_insert(t, first, last, v) + local n = #t + + -- Shift table forward + for i = n - first, 0, -1 do + t[last + 1 + i] = t[first + i] + end + + -- Fill in new values + for i = first, last do + t[i] = v + end +end + +---@package +---@param srow integer +---@param erow integer +function FoldInfo:add_range(srow, erow) + list_insert(self.levels, srow + 1, erow, '-1') + list_insert(self.levels0, srow + 1, erow, -1) + list_insert(self.start_counts, srow + 1, erow, nil) + list_insert(self.stop_counts, srow + 1, erow, nil) +end + +---@package +---@param lnum integer +function FoldInfo:add_start(lnum) + self.start_counts[lnum] = (self.start_counts[lnum] or 0) + 1 +end + +---@package +---@param lnum integer +function FoldInfo:add_stop(lnum) + self.stop_counts[lnum] = (self.stop_counts[lnum] or 0) + 1 +end + +---@package +---@param lnum integer +---@return integer +function FoldInfo:get_start(lnum) + return self.start_counts[lnum] or 0 +end + +---@package +---@param lnum integer +---@return integer +function FoldInfo:get_stop(lnum) + return self.stop_counts[lnum] or 0 +end + +local function trim_level(level) + local max_fold_level = vim.wo.foldnestmax + if level > max_fold_level then + return max_fold_level + end + return level +end + +--- If a parser doesn't have any ranges explicitly set, treesitter will +--- return a range with end_row and end_bytes with a value of UINT32_MAX, +--- so clip end_row to the max buffer line. +--- +--- TODO(lewis6991): Handle this generally +--- +--- @param bufnr integer +--- @param erow integer? 
+--- @return integer +local function normalise_erow(bufnr, erow) + local max_erow = api.nvim_buf_line_count(bufnr) - 1 + return math.min(erow or max_erow, max_erow) +end + +-- TODO(lewis6991): Setup a decor provider so injections folds can be parsed +-- as the window is redrawn +---@param bufnr integer +---@param info TS.FoldInfo +---@param srow integer? +---@param erow integer? +---@param parse_injections? boolean +local function get_folds_levels(bufnr, info, srow, erow, parse_injections) + srow = srow or 0 + erow = normalise_erow(bufnr, erow) + + info:invalidate_range(srow, erow) + + local prev_start = -1 + local prev_stop = -1 + + local parser = ts.get_parser(bufnr) + + parser:parse(parse_injections and { srow, erow } or nil) + + parser:for_each_tree(function(tree, ltree) + local query = ts.query.get(ltree:lang(), 'folds') + if not query then + return + end + + -- erow in query is end-exclusive + local q_erow = erow and erow + 1 or -1 + + for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow, q_erow) do + if query.captures[id] == 'fold' then + local range = ts.get_range(node, bufnr, metadata[id]) + local start, _, stop, stop_col = Range.unpack4(range) + + if stop_col == 0 then + stop = stop - 1 + end + + local fold_length = stop - start + 1 + + -- Fold only multiline nodes that are not exactly the same as previously met folds + -- Checking against just the previously found fold is sufficient if nodes + -- are returned in preorder or postorder when traversing tree + if + fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop) + then + info:add_start(start + 1) + info:add_stop(stop + 1) + prev_start = start + prev_stop = stop + end + end + end + end) + + local current_level = info.levels0[srow] or 0 + + -- We now have the list of fold opening and closing, fill the gaps and mark where fold start + for lnum = srow + 1, erow + 1 do + local last_trimmed_level = trim_level(current_level) + current_level = current_level + 
info:get_start(lnum) + info.levels0[lnum] = current_level + + local trimmed_level = trim_level(current_level) + current_level = current_level - info:get_stop(lnum) + + -- Determine if it's the start/end of a fold + -- NB: vim's fold-expr interface does not have a mechanism to indicate that + -- two (or more) folds start at this line, so it cannot distinguish between + -- ( \n ( \n )) \n (( \n ) \n ) + -- versus + -- ( \n ( \n ) \n ( \n ) \n ) + -- If it did have such a mechanism, (trimmed_level - last_trimmed_level) + -- would be the correct number of starts to pass on. + local prefix = '' + if trimmed_level - last_trimmed_level > 0 then + prefix = '>' + end + + info.levels[lnum] = prefix .. tostring(trimmed_level) + end +end + +local M = {} + +---@type table<integer,TS.FoldInfo> +local foldinfos = {} + +local group = api.nvim_create_augroup('treesitter/fold', {}) + +--- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that +--- the user doesn't use different foldexpr for the same buffer). +--- +--- Nvim usually automatically updates folds when text changes, but it doesn't work here because +--- FoldInfo update is scheduled. So we do it manually. +local function foldupdate(bufnr) + local function do_update() + for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do + api.nvim_win_call(win, function() + if vim.wo.foldmethod == 'expr' then + vim._foldupdate() + end + end) + end + end + + if api.nvim_get_mode().mode == 'i' then + -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave + if #(api.nvim_get_autocmds({ + group = group, + buffer = bufnr, + })) > 0 then + return + end + api.nvim_create_autocmd('InsertLeave', { + group = group, + buffer = bufnr, + once = true, + callback = do_update, + }) + return + end + + do_update() +end + +--- Schedule a function only if bufnr is loaded. 
+--- We schedule fold level computation for the following reasons: +--- * queries seem to use the old buffer state in on_bytes for some unknown reason; +--- * to avoid textlock; +--- * to avoid infinite recursion: +--- get_folds_levels → parse → _do_callback → on_changedtree → get_folds_levels. +---@param bufnr integer +---@param fn function +local function schedule_if_loaded(bufnr, fn) + vim.schedule(function() + if not api.nvim_buf_is_loaded(bufnr) then + return + end + fn() + end) +end + +---@param bufnr integer +---@param foldinfo TS.FoldInfo +---@param tree_changes Range4[] +local function on_changedtree(bufnr, foldinfo, tree_changes) + schedule_if_loaded(bufnr, function() + for _, change in ipairs(tree_changes) do + local srow, _, erow = Range.unpack4(change) + get_folds_levels(bufnr, foldinfo, srow, erow) + end + if #tree_changes > 0 then + foldupdate(bufnr) + end + end) +end + +---@param bufnr integer +---@param foldinfo TS.FoldInfo +---@param start_row integer +---@param old_row integer +---@param new_row integer +local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) + local end_row_old = start_row + old_row + local end_row_new = start_row + new_row + + if new_row ~= old_row then + if new_row < old_row then + foldinfo:remove_range(end_row_new, end_row_old) + else + foldinfo:add_range(start_row, end_row_new) + end + schedule_if_loaded(bufnr, function() + get_folds_levels(bufnr, foldinfo, start_row, end_row_new) + foldupdate(bufnr) + end) + end +end + +---@package +---@param lnum integer|nil +---@return string +function M.foldexpr(lnum) + lnum = lnum or vim.v.lnum + local bufnr = api.nvim_get_current_buf() + + local parser = vim.F.npcall(ts.get_parser, bufnr) + if not parser then + return '0' + end + + if not foldinfos[bufnr] then + foldinfos[bufnr] = FoldInfo.new() + get_folds_levels(bufnr, foldinfos[bufnr]) + + parser:register_cbs({ + on_changedtree = function(tree_changes) + on_changedtree(bufnr, foldinfos[bufnr], tree_changes) + end, + + 
on_bytes = function(_, _, start_row, _, _, old_row, _, _, new_row, _, _) + on_bytes(bufnr, foldinfos[bufnr], start_row, old_row, new_row) + end, + + on_detach = function() + foldinfos[bufnr] = nil + end, + }) + end + + return foldinfos[bufnr].levels[lnum] or '0' +end + +---@package +---@return { [1]: string, [2]: string[] }[]|string +function M.foldtext() + local foldstart = vim.v.foldstart + local bufnr = api.nvim_get_current_buf() + + ---@type boolean, LanguageTree + local ok, parser = pcall(ts.get_parser, bufnr) + if not ok then + return vim.fn.foldtext() + end + + local query = ts.query.get(parser:lang(), 'highlights') + if not query then + return vim.fn.foldtext() + end + + local tree = parser:parse({ foldstart - 1, foldstart })[1] + + local line = api.nvim_buf_get_lines(bufnr, foldstart - 1, foldstart, false)[1] + if not line then + return vim.fn.foldtext() + end + + ---@type { [1]: string, [2]: string[], range: { [1]: integer, [2]: integer } }[] | { [1]: string, [2]: string[] }[] + local result = {} + + local line_pos = 0 + + for id, node, metadata in query:iter_captures(tree:root(), 0, foldstart - 1, foldstart) do + local name = query.captures[id] + local start_row, start_col, end_row, end_col = node:range() + + local priority = tonumber(metadata.priority or vim.highlight.priorities.treesitter) + + if start_row == foldstart - 1 and end_row == foldstart - 1 then + -- check for characters ignored by treesitter + if start_col > line_pos then + table.insert(result, { + line:sub(line_pos + 1, start_col), + {}, + range = { line_pos, start_col }, + }) + end + line_pos = end_col + + local text = line:sub(start_col + 1, end_col) + table.insert(result, { text, { { '@' .. 
name, priority } }, range = { start_col, end_col } }) + end + end + + local i = 1 + while i <= #result do + -- find first capture that is not in current range and apply highlights on the way + local j = i + 1 + while + j <= #result + and result[j].range[1] >= result[i].range[1] + and result[j].range[2] <= result[i].range[2] + do + for k, v in ipairs(result[i][2]) do + if not vim.tbl_contains(result[j][2], v) then + table.insert(result[j][2], k, v) + end + end + j = j + 1 + end + + -- remove the parent capture if it is split into children + if j > i + 1 then + table.remove(result, i) + else + -- highlights need to be sorted by priority, on equal prio, the deeper nested capture (earlier + -- in list) should be considered higher prio + if #result[i][2] > 1 then + table.sort(result[i][2], function(a, b) + return a[2] < b[2] + end) + end + + result[i][2] = vim.tbl_map(function(tbl) + return tbl[1] + end, result[i][2]) + result[i] = { result[i][1], result[i][2] } + + i = i + 1 + end + end + + return result +end + +return M diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua new file mode 100644 index 0000000000..80c998b555 --- /dev/null +++ b/runtime/lua/vim/treesitter/_meta.lua @@ -0,0 +1,80 @@ +---@meta + +---@class TSNode: userdata +---@field id fun(self: TSNode): string +---@field tree fun(self: TSNode): TSTree +---@field range fun(self: TSNode, include_bytes: false?): integer, integer, integer, integer +---@field range fun(self: TSNode, include_bytes: true): integer, integer, integer, integer, integer, integer +---@field start fun(self: TSNode): integer, integer, integer +---@field end_ fun(self: TSNode): integer, integer, integer +---@field type fun(self: TSNode): string +---@field symbol fun(self: TSNode): integer +---@field named fun(self: TSNode): boolean +---@field missing fun(self: TSNode): boolean +---@field extra fun(self: TSNode): boolean +---@field child_count fun(self: TSNode): integer +---@field named_child_count 
fun(self: TSNode): integer +---@field child fun(self: TSNode, index: integer): TSNode? +---@field named_child fun(self: TSNode, index: integer): TSNode? +---@field descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode? +---@field named_descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode? +---@field parent fun(self: TSNode): TSNode? +---@field next_sibling fun(self: TSNode): TSNode? +---@field prev_sibling fun(self: TSNode): TSNode? +---@field next_named_sibling fun(self: TSNode): TSNode? +---@field prev_named_sibling fun(self: TSNode): TSNode? +---@field named_children fun(self: TSNode): TSNode[] +---@field has_changes fun(self: TSNode): boolean +---@field has_error fun(self: TSNode): boolean +---@field sexpr fun(self: TSNode): string +---@field equal fun(self: TSNode, other: TSNode): boolean +---@field iter_children fun(self: TSNode): fun(): TSNode, string +---@field field fun(self: TSNode, name: string): TSNode[] +---@field byte_length fun(self: TSNode): integer +local TSNode = {} + +---@param query userdata +---@param captures true +---@param start? integer +---@param end_? integer +---@param opts? table +---@return fun(): integer, TSNode, any +function TSNode:_rawquery(query, captures, start, end_, opts) end + +---@param query userdata +---@param captures false +---@param start? integer +---@param end_? integer +---@param opts? 
table +---@return fun(): string, any +function TSNode:_rawquery(query, captures, start, end_, opts) end + +---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string) + +---@class TSParser +---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: true): TSTree, Range6[] +---@field parse fun(self: TSParser, tree: TSTree?, source: integer|string, include_bytes: false|nil): TSTree, Range4[] +---@field reset fun(self: TSParser) +---@field included_ranges fun(self: TSParser, include_bytes: boolean?): integer[] +---@field set_included_ranges fun(self: TSParser, ranges: (Range6|TSNode)[]) +---@field set_timeout fun(self: TSParser, timeout: integer) +---@field timeout fun(self: TSParser): integer +---@field _set_logger fun(self: TSParser, lex: boolean, parse: boolean, cb: TSLoggerCallback) +---@field _logger fun(self: TSParser): TSLoggerCallback + +---@class TSTree +---@field root fun(self: TSTree): TSNode +---@field edit fun(self: TSTree, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _: integer, _:integer) +---@field copy fun(self: TSTree): TSTree +---@field included_ranges fun(self: TSTree, include_bytes: true): Range6[] +---@field included_ranges fun(self: TSTree, include_bytes: false): Range4[] + +---@return integer +vim._ts_get_language_version = function() end + +---@return integer +vim._ts_get_minimum_language_version = function() end + +---@param lang string +---@return TSParser +vim._create_ts_parser = function(lang) end diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua new file mode 100644 index 0000000000..87d74789a3 --- /dev/null +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -0,0 +1,249 @@ +local api = vim.api + +local namespace = api.nvim_create_namespace('vim.treesitter.query_linter') + +local M = {} + +--- @class QueryLinterNormalizedOpts +--- @field langs string[] +--- @field clear boolean + +--- @alias 
vim.treesitter.ParseError {msg: string, range: Range4} + +--- Contains language dependent context for the query linter +--- @class QueryLinterLanguageContext +--- @field lang string? Current `lang` of the targeted parser +--- @field parser_info table? Parser info returned by vim.treesitter.language.inspect +--- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs` + +--- Adds a diagnostic for node in the query buffer +--- @param diagnostics Diagnostic[] +--- @param range Range4 +--- @param lint string +--- @param lang string? +local function add_lint_for_node(diagnostics, range, lint, lang) + local message = lint:gsub('\n', ' ') + diagnostics[#diagnostics + 1] = { + lnum = range[1], + end_lnum = range[3], + col = range[2], + end_col = range[4], + severity = vim.diagnostic.ERROR, + message = message, + source = lang, + } +end + +--- Determines the target language of a query file by its path: <lang>/<query_type>.scm +--- @param buf integer +--- @return string? 
+local function guess_query_lang(buf) + local filename = api.nvim_buf_get_name(buf) + if filename ~= '' then + return vim.F.npcall(vim.fn.fnamemodify, filename, ':p:h:t') + end +end + +--- @param buf integer +--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil +--- @return QueryLinterNormalizedOpts +local function normalize_opts(buf, opts) + opts = opts or {} + if not opts.langs then + opts.langs = guess_query_lang(buf) + end + + if type(opts.langs) ~= 'table' then + --- @diagnostic disable-next-line:assign-type-mismatch + opts.langs = { opts.langs } + end + + --- @cast opts QueryLinterNormalizedOpts + opts.langs = opts.langs or {} + return opts +end + +local lint_query = [[;; query + (program [(named_node) (list) (grouping)] @toplevel) + (named_node + name: _ @node.named) + (anonymous_node + name: _ @node.anonymous) + (field_definition + name: (identifier) @field) + (predicate + name: (identifier) @predicate.name + type: (predicate_type) @predicate.type) + (ERROR) @error +]] + +--- @param err string +--- @param node TSNode +--- @return vim.treesitter.ParseError +local function get_error_entry(err, node) + local start_line, start_col = node:range() + local line_offset, col_offset, msg = err:gmatch('.-:%d+: Query error at (%d+):(%d+)%. 
([^:]+)')() ---@type string, string, string + start_line, start_col = + start_line + tonumber(line_offset) - 1, start_col + tonumber(col_offset) - 1 + local end_line, end_col = start_line, start_col + if msg:match('^Invalid syntax') or msg:match('^Impossible') then + -- Use the length of the underlined node + local underlined = vim.split(err, '\n')[2] + end_col = end_col + #underlined + elseif msg:match('^Invalid') then + -- Use the length of the problematic type/capture/field + end_col = end_col + #msg:match('"([^"]+)"') + end + + return { + msg = msg, + range = { start_line, start_col, end_line, end_col }, + } +end + +--- @param node TSNode +--- @param buf integer +--- @param lang string +local function hash_parse(node, buf, lang) + return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang +end + +--- @param node TSNode +--- @param buf integer +--- @param lang string +--- @return vim.treesitter.ParseError? +local parse = vim.func._memoize(hash_parse, function(node, buf, lang) + local query_text = vim.treesitter.get_node_text(node, buf) + local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query + + if not ok and type(err) == 'string' then + return get_error_entry(err, node) + end +end) + +--- @param buf integer +--- @param match table<integer,TSNode> +--- @param query Query +--- @param lang_context QueryLinterLanguageContext +--- @param diagnostics Diagnostic[] +local function lint_match(buf, match, query, lang_context, diagnostics) + local lang = lang_context.lang + local parser_info = lang_context.parser_info + + for id, node in pairs(match) do + local cap_id = query.captures[id] + + -- perform language-independent checks only for first lang + if lang_context.is_first_lang and cap_id == 'error' then + local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ') + add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. 
node_text) + end + + -- other checks rely on Neovim parser introspection + if lang and parser_info and cap_id == 'toplevel' then + local err = parse(node, buf, lang) + if err then + add_lint_for_node(diagnostics, err.range, err.msg, lang) + end + end + end +end + +--- @private +--- @param buf integer Buffer to lint +--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil Options for linting +function M.lint(buf, opts) + if buf == 0 then + buf = api.nvim_get_current_buf() + end + + local diagnostics = {} + local query = vim.treesitter.query.parse('query', lint_query) + + opts = normalize_opts(buf, opts) + + -- perform at least one iteration even with no langs to perform language independent checks + for i = 1, math.max(1, #opts.langs) do + local lang = opts.langs[i] + + --- @type (table|nil) + local parser_info = vim.F.npcall(vim.treesitter.language.inspect, lang) + + local parser = vim.treesitter.get_parser(buf) + parser:parse() + parser:for_each_tree(function(tree, ltree) + if ltree:lang() == 'query' then + for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do + local lang_context = { + lang = lang, + parser_info = parser_info, + is_first_lang = i == 1, + } + lint_match(buf, match, query, lang_context, diagnostics) + end + end + end) + end + + vim.diagnostic.set(namespace, buf, diagnostics) +end + +--- @private +--- @param buf integer +function M.clear(buf) + vim.diagnostic.reset(namespace, buf) +end + +--- @private +--- @param findstart integer +--- @param base string +function M.omnifunc(findstart, base) + if findstart == 1 then + local result = + api.nvim_get_current_line():sub(1, api.nvim_win_get_cursor(0)[2]):find('["#%-%w]*$') + return result - 1 + end + + local buf = api.nvim_get_current_buf() + local query_lang = guess_query_lang(buf) + + local ok, parser_info = pcall(vim.treesitter.language.inspect, query_lang) + if not ok then + return -2 + end + + local items = {} + for _, f in pairs(parser_info.fields) do + if f:find(base, 1, true) 
then + table.insert(items, f .. ':') + end + end + for _, p in pairs(vim.treesitter.query.list_predicates()) do + local text = '#' .. p + local found = text:find(base, 1, true) + if found and found <= 2 then -- with or without '#' + table.insert(items, text) + end + text = '#not-' .. p + found = text:find(base, 1, true) + if found and found <= 2 then -- with or without '#' + table.insert(items, text) + end + end + for _, p in pairs(vim.treesitter.query.list_directives()) do + local text = '#' .. p + local found = text:find(base, 1, true) + if found and found <= 2 then -- with or without '#' + table.insert(items, text) + end + end + for _, s in pairs(parser_info.symbols) do + local text = s[2] and s[1] or '"' .. s[1]:gsub([[\]], [[\\]]) .. '"' ---@type string + if text:find(base, 1, true) then + table.insert(items, text) + end + end + return { words = items, refresh = 'always' } +end + +return M diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua new file mode 100644 index 0000000000..8d727c3c52 --- /dev/null +++ b/runtime/lua/vim/treesitter/_range.lua @@ -0,0 +1,193 @@ +local api = vim.api + +local M = {} + +---@class Range2 +---@field [1] integer start row +---@field [2] integer end row + +---@class Range4 +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer end row +---@field [4] integer end column + +---@class Range6 +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer start bytes +---@field [4] integer end row +---@field [5] integer end column +---@field [6] integer end bytes + +---@alias Range Range2|Range4|Range6 + +---@private +---@param a_row integer +---@param a_col integer +---@param b_row integer +---@param b_col integer +---@return integer +--- 1: a > b +--- 0: a == b +--- -1: a < b +local function cmp_pos(a_row, a_col, b_row, b_col) + if a_row == b_row then + if a_col > b_col then + return 1 + elseif a_col < b_col then + return -1 + else 
+ return 0 + end + elseif a_row > b_row then + return 1 + end + + return -1 +end + +M.cmp_pos = { + lt = function(...) + return cmp_pos(...) == -1 + end, + le = function(...) + return cmp_pos(...) ~= 1 + end, + gt = function(...) + return cmp_pos(...) == 1 + end, + ge = function(...) + return cmp_pos(...) ~= -1 + end, + eq = function(...) + return cmp_pos(...) == 0 + end, + ne = function(...) + return cmp_pos(...) ~= 0 + end, +} + +setmetatable(M.cmp_pos, { __call = cmp_pos }) + +---@private +---Check if a variable is a valid range object +---@param r any +---@return boolean +function M.validate(r) + if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then + return false + end + + for _, e in + ipairs(r --[[@as any[] ]]) + do + if type(e) ~= 'number' then + return false + end + end + + return true +end + +---@private +---@param r1 Range +---@param r2 Range +---@return boolean +function M.intercepts(r1, r2) + local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) + local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2) + + -- r1 is above r2 + if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then + return false + end + + -- r1 is below r2 + if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then + return false + end + + return true +end + +---@private +---@param r Range +---@return integer, integer, integer, integer +function M.unpack4(r) + if #r == 2 then + return r[1], 0, r[2], 0 + end + local off_1 = #r == 6 and 1 or 0 + return r[1], r[2], r[3 + off_1], r[4 + off_1] +end + +---@private +---@param r Range6 +---@return integer, integer, integer, integer, integer, integer +function M.unpack6(r) + return r[1], r[2], r[3], r[4], r[5], r[6] +end + +---@private +---@param r1 Range +---@param r2 Range +---@return boolean whether r1 contains r2 +function M.contains(r1, r2) + local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) + local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2) + + -- start doesn't fit + if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then + return false + end + + -- 
end doesn't fit + if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then + return false + end + + return true +end + +--- @param source integer|string +--- @param index integer +--- @return integer +local function get_offset(source, index) + if index == 0 then + return 0 + end + + if type(source) == 'number' then + return api.nvim_buf_get_offset(source, index) + end + + local byte = 0 + local next_offset = source:gmatch('()\n') + local line = 1 + while line <= index do + byte = next_offset() --[[@as integer]] + line = line + 1 + end + + return byte +end + +---@private +---@param source integer|string +---@param range Range +---@return Range6 +function M.add_bytes(source, range) + if type(range) == 'table' and #range == 6 then + return range --[[@as Range6]] + end + + local start_row, start_col, end_row, end_col = M.unpack4(range) + -- TODO(vigoux): proper byte computation here, and account for EOL ? + local start_byte = get_offset(source, start_row) + start_col + local end_byte = get_offset(source, end_row) + end_col + + return { start_row, start_col, start_byte, end_row, end_col, end_byte } +end + +return M diff --git a/runtime/lua/vim/treesitter/dev.lua b/runtime/lua/vim/treesitter/dev.lua new file mode 100644 index 0000000000..69ddc9b558 --- /dev/null +++ b/runtime/lua/vim/treesitter/dev.lua @@ -0,0 +1,645 @@ +local api = vim.api + +---@class TSDevModule +local M = {} + +---@class TSTreeView +---@field ns integer API namespace +---@field opts table Options table with the following keys: +--- - anon (boolean): If true, display anonymous nodes +--- - lang (boolean): If true, display the language alongside each node +--- - indent (number): Number of spaces to indent nested lines. Default is 2. 
+---@field nodes TSP.Node[] +---@field named TSP.Node[] +local TSTreeView = {} + +---@class TSP.Node +---@field id integer Node id +---@field text string Node text +---@field named boolean True if this is a named (non-anonymous) node +---@field depth integer Depth of the node within the tree +---@field lnum integer Beginning line number of this node in the source buffer +---@field col integer Beginning column number of this node in the source buffer +---@field end_lnum integer Final line number of this node in the source buffer +---@field end_col integer Final column number of this node in the source buffer +---@field lang string Source language of this node +---@field root TSNode + +---@class TSP.Injection +---@field lang string Source language of this injection +---@field root TSNode Root node of the injection + +--- Traverse all child nodes starting at {node}. +--- +--- This is a recursive function. The {depth} parameter indicates the current recursion level. +--- {lang} is a string indicating the language of the tree currently being traversed. Each traversed +--- node is added to {tree}. When recursion completes, {tree} is an array of all nodes in the order +--- they were visited. +--- +--- {injections} is a table mapping node ids from the primary tree to language tree injections. Each +--- injected language has a series of trees nested within the primary language's tree, and the root +--- node of each of these trees is contained within a node in the primary tree. The {injections} +--- table maps nodes in the primary tree to root nodes of injected trees. 
+--- +---@param node TSNode Starting node to begin traversal |tsnode| +---@param depth integer Current recursion depth +---@param lang string Language of the tree currently being traversed +---@param injections table<string, TSP.Injection> Mapping of node ids to root nodes +--- of injected language trees (see explanation above) +---@param tree TSP.Node[] Output table containing a list of tables each representing a node in the tree +local function traverse(node, depth, lang, injections, tree) + local injection = injections[node:id()] + if injection then + traverse(injection.root, depth, injection.lang, injections, tree) + end + + for child, field in node:iter_children() do + local type = child:type() + local lnum, col, end_lnum, end_col = child:range() + local named = child:named() + local text ---@type string + if named then + if field then + text = string.format('%s: (%s', field, type) + else + text = string.format('(%s', type) + end + else + text = string.format('"%s"', type:gsub('\n', '\\n'):gsub('"', '\\"')) + end + + table.insert(tree, { + id = child:id(), + text = text, + named = named, + depth = depth, + lnum = lnum, + col = col, + end_lnum = end_lnum, + end_col = end_col, + lang = lang, + }) + + traverse(child, depth + 1, lang, injections, tree) + + if named then + tree[#tree].text = string.format('%s)', tree[#tree].text) + end + end + + return tree +end + +--- Create a new treesitter view. +--- +---@param bufnr integer Source buffer number +---@param lang string|nil Language of source buffer +--- +---@return TSTreeView|nil +---@return string|nil Error message, if any +--- +---@package +function TSTreeView:new(bufnr, lang) + local ok, parser = pcall(vim.treesitter.get_parser, bufnr or 0, lang) + if not ok then + return nil, 'No parser available for the given buffer' + end + + -- For each child tree (injected language), find the root of the tree and locate the node within + -- the primary tree that contains that root. 
Add a mapping from the node in the primary tree to + -- the root in the child tree to the {injections} table. + local root = parser:parse(true)[1]:root() + local injections = {} ---@type table<string, TSP.Injection> + + parser:for_each_tree(function(parent_tree, parent_ltree) + local parent = parent_tree:root() + for _, child in pairs(parent_ltree:children()) do + child:for_each_tree(function(tree, ltree) + local r = tree:root() + local node = assert(parent:named_descendant_for_range(r:range())) + local id = node:id() + if not injections[id] or r:byte_length() > injections[id].root:byte_length() then + injections[id] = { + lang = ltree:lang(), + root = r, + } + end + end) + end + end) + + local nodes = traverse(root, 0, parser:lang(), injections, {}) + + local named = {} ---@type TSP.Node[] + for _, v in ipairs(nodes) do + if v.named then + named[#named + 1] = v + end + end + + local t = { + ns = api.nvim_create_namespace('treesitter/dev-inspect'), + nodes = nodes, + named = named, + opts = { + anon = false, + lang = false, + indent = 2, + }, + } + + setmetatable(t, self) + self.__index = self + return t +end + +local decor_ns = api.nvim_create_namespace('ts.dev') + +---@param lnum integer +---@param col integer +---@param end_lnum integer +---@param end_col integer +---@return string +local function get_range_str(lnum, col, end_lnum, end_col) + if lnum == end_lnum then + return string.format('[%d:%d - %d]', lnum + 1, col + 1, end_col) + end + return string.format('[%d:%d - %d:%d]', lnum + 1, col + 1, end_lnum + 1, end_col) +end + +---@param w integer +---@return boolean closed Whether the window was closed. 
+local function close_win(w) + if api.nvim_win_is_valid(w) then + api.nvim_win_close(w, true) + return true + end + + return false +end + +---@param w integer +---@param b integer +local function set_dev_properties(w, b) + vim.wo[w].scrolloff = 5 + vim.wo[w].wrap = false + vim.wo[w].foldmethod = 'manual' -- disable folding + vim.bo[b].buflisted = false + vim.bo[b].buftype = 'nofile' + vim.bo[b].bufhidden = 'wipe' + vim.bo[b].filetype = 'query' +end + +--- Updates the cursor position in the inspector to match the node under the cursor. +--- +--- @param treeview TSTreeView +--- @param lang string +--- @param source_buf integer +--- @param inspect_buf integer +--- @param inspect_win integer +--- @param pos? { [1]: integer, [2]: integer } +local function set_inspector_cursor(treeview, lang, source_buf, inspect_buf, inspect_win, pos) + api.nvim_buf_clear_namespace(inspect_buf, treeview.ns, 0, -1) + + local cursor_node = vim.treesitter.get_node({ + bufnr = source_buf, + lang = lang, + pos = pos, + ignore_injections = false, + }) + if not cursor_node then + return + end + + local cursor_node_id = cursor_node:id() + for i, v in treeview:iter() do + if v.id == cursor_node_id then + local start = v.depth * treeview.opts.indent ---@type integer + local end_col = start + #v.text + api.nvim_buf_set_extmark(inspect_buf, treeview.ns, i - 1, start, { + end_col = end_col, + hl_group = 'Visual', + }) + api.nvim_win_set_cursor(inspect_win, { i, 0 }) + break + end + end +end + +--- Write the contents of this View into {bufnr}. +--- +---@param bufnr integer Buffer number to write into. 
+---@package +function TSTreeView:draw(bufnr) + vim.bo[bufnr].modifiable = true + local lines = {} ---@type string[] + local lang_hl_marks = {} ---@type table[] + + for _, item in self:iter() do + local range_str = get_range_str(item.lnum, item.col, item.end_lnum, item.end_col) + local lang_str = self.opts.lang and string.format(' %s', item.lang) or '' + local line = string.format( + '%s%s ; %s%s', + string.rep(' ', item.depth * self.opts.indent), + item.text, + range_str, + lang_str + ) + + if self.opts.lang then + lang_hl_marks[#lang_hl_marks + 1] = { + col = #line - #lang_str, + end_col = #line, + } + end + + lines[#lines + 1] = line + end + + api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) + + api.nvim_buf_clear_namespace(bufnr, decor_ns, 0, -1) + + for i, m in ipairs(lang_hl_marks) do + api.nvim_buf_set_extmark(bufnr, decor_ns, i - 1, m.col, { + hl_group = 'Title', + end_col = m.end_col, + }) + end + + vim.bo[bufnr].modifiable = false +end + +--- Get node {i} from this View. +--- +--- The node number is dependent on whether or not anonymous nodes are displayed. +--- +---@param i integer Node number to get +---@return TSP.Node +---@package +function TSTreeView:get(i) + local t = self.opts.anon and self.nodes or self.named + return t[i] +end + +--- Iterate over all of the nodes in this View. +--- +---@return (fun(): integer, TSP.Node) Iterator over all nodes in this View +---@return table +---@return integer +---@package +function TSTreeView:iter() + return ipairs(self.opts.anon and self.nodes or self.named) +end + +--- @class InspectTreeOpts +--- @field lang string? The language of the source buffer. If omitted, the +--- filetype of the source buffer is used. +--- @field bufnr integer? Buffer to draw the tree into. If omitted, a new +--- buffer is created. +--- @field winid integer? Window id to display the tree buffer in. If omitted, +--- a new window is created with {command}. +--- @field command string? Vimscript command to create the window. 
Default +--- value is "60vnew". Only used when {winid} is nil. +--- @field title (string|fun(bufnr:integer):string|nil) Title of the window. If a +--- function, it accepts the buffer number of the source +--- buffer as its only argument and should return a string. + +--- @private +--- +--- @param opts InspectTreeOpts? +function M.inspect_tree(opts) + vim.validate({ + opts = { opts, 't', true }, + }) + + opts = opts or {} + + local buf = api.nvim_get_current_buf() + local win = api.nvim_get_current_win() + local treeview = assert(TSTreeView:new(buf, opts.lang)) + + -- Close any existing inspector window + if vim.b[buf].dev_inspect then + close_win(vim.b[buf].dev_inspect) + end + + local w = opts.winid + if not w then + vim.cmd(opts.command or '60vnew') + w = api.nvim_get_current_win() + end + + local b = opts.bufnr + if b then + api.nvim_win_set_buf(w, b) + else + b = api.nvim_win_get_buf(w) + end + + vim.b[buf].dev_inspect = w + vim.b[b].dev_base = win -- base window handle + vim.b[b].disable_query_linter = true + set_dev_properties(w, b) + + local title --- @type string? 
+ local opts_title = opts.title + if not opts_title then + local bufname = api.nvim_buf_get_name(buf) + title = string.format('Syntax tree for %s', vim.fn.fnamemodify(bufname, ':.')) + elseif type(opts_title) == 'function' then + title = opts_title(buf) + end + + assert(type(title) == 'string', 'Window title must be a string') + api.nvim_buf_set_name(b, title) + + treeview:draw(b) + + local cursor = api.nvim_win_get_cursor(win) + set_inspector_cursor(treeview, opts.lang, buf, b, w, { cursor[1] - 1, cursor[2] }) + + api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1) + api.nvim_buf_set_keymap(b, 'n', '<CR>', '', { + desc = 'Jump to the node under the cursor in the source buffer', + callback = function() + local row = api.nvim_win_get_cursor(w)[1] + local pos = treeview:get(row) + api.nvim_set_current_win(win) + api.nvim_win_set_cursor(win, { pos.lnum + 1, pos.col }) + end, + }) + api.nvim_buf_set_keymap(b, 'n', 'a', '', { + desc = 'Toggle anonymous nodes', + callback = function() + local row, col = unpack(api.nvim_win_get_cursor(w)) ---@type integer, integer + local curnode = treeview:get(row) + while curnode and not curnode.named do + row = row - 1 + curnode = treeview:get(row) + end + + treeview.opts.anon = not treeview.opts.anon + treeview:draw(b) + + if not curnode then + return + end + + local id = curnode.id + for i, node in treeview:iter() do + if node.id == id then + api.nvim_win_set_cursor(w, { i, col }) + break + end + end + end, + }) + api.nvim_buf_set_keymap(b, 'n', 'I', '', { + desc = 'Toggle language display', + callback = function() + treeview.opts.lang = not treeview.opts.lang + treeview:draw(b) + end, + }) + api.nvim_buf_set_keymap(b, 'n', 'o', '', { + desc = 'Toggle query editor', + callback = function() + local edit_w = vim.b[buf].dev_edit + if not edit_w or not close_win(edit_w) then + M.edit_query() + end + end, + }) + + local group = api.nvim_create_augroup('treesitter/dev', {}) + + api.nvim_create_autocmd('CursorMoved', { + group = group, + 
buffer = b, + callback = function() + if not api.nvim_buf_is_loaded(buf) then + return true + end + + api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1) + local row = api.nvim_win_get_cursor(w)[1] + local pos = treeview:get(row) + api.nvim_buf_set_extmark(buf, treeview.ns, pos.lnum, pos.col, { + end_row = pos.end_lnum, + end_col = math.max(0, pos.end_col), + hl_group = 'Visual', + }) + + local topline, botline = vim.fn.line('w0', win), vim.fn.line('w$', win) + + -- Move the cursor if highlighted range is completely out of view + if pos.lnum < topline and pos.end_lnum < topline then + api.nvim_win_set_cursor(win, { pos.end_lnum + 1, 0 }) + elseif pos.lnum > botline and pos.end_lnum > botline then + api.nvim_win_set_cursor(win, { pos.lnum + 1, 0 }) + end + end, + }) + + api.nvim_create_autocmd('CursorMoved', { + group = group, + buffer = buf, + callback = function() + if not api.nvim_buf_is_loaded(b) then + return true + end + + set_inspector_cursor(treeview, opts.lang, buf, b, w) + end, + }) + + api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, { + group = group, + buffer = buf, + callback = function() + if not api.nvim_buf_is_loaded(b) then + return true + end + + treeview = assert(TSTreeView:new(buf, opts.lang)) + treeview:draw(b) + end, + }) + + api.nvim_create_autocmd('BufLeave', { + group = group, + buffer = b, + callback = function() + if not api.nvim_buf_is_loaded(buf) then + return true + end + api.nvim_buf_clear_namespace(buf, treeview.ns, 0, -1) + end, + }) + + api.nvim_create_autocmd('BufLeave', { + group = group, + buffer = buf, + callback = function() + if not api.nvim_buf_is_loaded(b) then + return true + end + api.nvim_buf_clear_namespace(b, treeview.ns, 0, -1) + end, + }) + + api.nvim_create_autocmd('BufHidden', { + group = group, + buffer = buf, + once = true, + callback = function() + close_win(w) + end, + }) +end + +local edit_ns = api.nvim_create_namespace('treesitter/dev-edit') + +---@param query_win integer +---@param base_win 
integer +---@param lang string +local function update_editor_highlights(query_win, base_win, lang) + local base_buf = api.nvim_win_get_buf(base_win) + local query_buf = api.nvim_win_get_buf(query_win) + local parser = vim.treesitter.get_parser(base_buf, lang) + api.nvim_buf_clear_namespace(base_buf, edit_ns, 0, -1) + local query_content = table.concat(api.nvim_buf_get_lines(query_buf, 0, -1, false), '\n') + + local ok_query, query = pcall(vim.treesitter.query.parse, lang, query_content) + if not ok_query then + return + end + + local cursor_word = vim.fn.expand('<cword>') --[[@as string]] + -- Only highlight captures if the cursor is on a capture name + if cursor_word:find('^@') == nil then + return + end + -- Remove the '@' from the cursor word + cursor_word = cursor_word:sub(2) + local topline, botline = vim.fn.line('w0', base_win), vim.fn.line('w$', base_win) + for id, node in query:iter_captures(parser:trees()[1]:root(), base_buf, topline - 1, botline) do + local capture_name = query.captures[id] + if capture_name == cursor_word then + local lnum, col, end_lnum, end_col = node:range() + api.nvim_buf_set_extmark(base_buf, edit_ns, lnum, col, { + end_row = end_lnum, + end_col = end_col, + hl_group = 'Visual', + virt_text = { + { capture_name, 'Title' }, + }, + }) + end + end +end + +--- @private +--- @param lang? string language to open the query editor for. +function M.edit_query(lang) + local buf = api.nvim_get_current_buf() + local win = api.nvim_get_current_win() + + -- Close any existing editor window + if vim.b[buf].dev_edit then + close_win(vim.b[buf].dev_edit) + end + + local cmd = '60vnew' + -- If the inspector is open, place the editor above it. + local base_win = vim.b[buf].dev_base ---@type integer? 
+ local base_buf = base_win and api.nvim_win_get_buf(base_win) + local inspect_win = base_buf and vim.b[base_buf].dev_inspect + if base_win and base_buf and api.nvim_win_is_valid(inspect_win) then + vim.api.nvim_set_current_win(inspect_win) + buf = base_buf + win = base_win + cmd = 'new' + end + vim.cmd(cmd) + + local ok, parser = pcall(vim.treesitter.get_parser, buf, lang) + if not ok then + return nil, 'No parser available for the given buffer' + end + lang = parser:lang() + + local query_win = api.nvim_get_current_win() + local query_buf = api.nvim_win_get_buf(query_win) + + vim.b[buf].dev_edit = query_win + vim.bo[query_buf].omnifunc = 'v:lua.vim.treesitter.query.omnifunc' + set_dev_properties(query_win, query_buf) + + -- Note that omnifunc guesses the language based on the containing folder, + -- so we add the parser's language to the buffer's name so that omnifunc + -- can infer the language later. + api.nvim_buf_set_name(query_buf, string.format('%s/query_editor.scm', lang)) + + local group = api.nvim_create_augroup('treesitter/dev-edit', {}) + api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave' }, { + group = group, + buffer = query_buf, + desc = 'Update query editor diagnostics when the query changes', + callback = function() + vim.treesitter.query.lint(query_buf, { langs = lang, clear = false }) + end, + }) + api.nvim_create_autocmd({ 'TextChanged', 'InsertLeave', 'CursorMoved', 'BufEnter' }, { + group = group, + buffer = query_buf, + desc = 'Update query editor highlights when the cursor moves', + callback = function() + if api.nvim_win_is_valid(win) then + update_editor_highlights(query_win, win, lang) + end + end, + }) + api.nvim_create_autocmd('BufLeave', { + group = group, + buffer = query_buf, + desc = 'Clear highlights when leaving the query editor', + callback = function() + api.nvim_buf_clear_namespace(buf, edit_ns, 0, -1) + end, + }) + api.nvim_create_autocmd('BufLeave', { + group = group, + buffer = buf, + desc = 'Clear the query editor 
highlights when leaving the source buffer', + callback = function() + if not api.nvim_buf_is_loaded(query_buf) then + return true + end + + api.nvim_buf_clear_namespace(query_buf, edit_ns, 0, -1) + end, + }) + api.nvim_create_autocmd('BufHidden', { + group = group, + buffer = buf, + desc = 'Close the editor window when the source buffer is hidden', + once = true, + callback = function() + close_win(query_win) + end, + }) + + api.nvim_buf_set_lines(query_buf, 0, -1, false, { + ';; Write queries here (see $VIMRUNTIME/queries/ for examples).', + ';; Move cursor to a capture ("@foo") to highlight matches in the source buffer.', + ';; Completion for grammar nodes is available (:help compl-omni)', + '', + '', + }) + vim.cmd('normal! G') + vim.cmd.startinsert() +end + +return M diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua index c0a1eca0ce..ed1161e97f 100644 --- a/runtime/lua/vim/treesitter/health.lua +++ b/runtime/lua/vim/treesitter/health.lua @@ -2,30 +2,28 @@ local M = {} local ts = vim.treesitter local health = require('vim.health') ---- Lists the parsers currently installed ---- ----@return string[] list of parser files -function M.list_parsers() - return vim.api.nvim_get_runtime_file('parser/*', true) -end - --- Performs a healthcheck for treesitter integration function M.check() - local parsers = M.list_parsers() + local parsers = vim.api.nvim_get_runtime_file('parser/*', true) - health.report_info(string.format('Nvim runtime ABI version: %d', ts.language_version)) + health.info(string.format('Nvim runtime ABI version: %d', ts.language_version)) for _, parser in pairs(parsers) do local parsername = vim.fn.fnamemodify(parser, ':t:r') - local is_loadable, ret = pcall(ts.language.require_language, parsername) + local is_loadable, err_or_nil = pcall(ts.language.add, parsername) - if not is_loadable or not ret then - health.report_error( - string.format('Parser "%s" failed to load (path: %s): %s', parsername, parser, ret or 
'?') + if not is_loadable then + health.error( + string.format( + 'Parser "%s" failed to load (path: %s): %s', + parsername, + parser, + err_or_nil or '?' + ) ) - elseif ret then - local lang = ts.language.inspect_language(parsername) - health.report_ok( + else + local lang = ts.language.inspect(parsername) + health.ok( string.format('Parser: %-10s ABI: %d, path: %s', parsername, lang._abi_version, parser) ) end diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index d77a0d0d03..496193c6ed 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -1,17 +1,34 @@ -local a = vim.api -local query = require('vim.treesitter.query') +local api = vim.api +local query = vim.treesitter.query +local Range = require('vim.treesitter._range') + +---@alias TSHlIter fun(end_line: integer|nil): integer, TSNode, TSMetadata + +---@class TSHighlightState +---@field next_row integer +---@field iter TSHlIter|nil --- support reload for quick experimentation ---@class TSHighlighter +---@field active table<integer,TSHighlighter> +---@field bufnr integer +---@field orig_spelloptions string +---@field _highlight_states table<TSTree,TSHighlightState> +---@field _queries table<string,TSHighlighterQuery> +---@field tree LanguageTree +---@field redraw_count integer local TSHighlighter = rawget(vim.treesitter, 'TSHighlighter') or {} TSHighlighter.__index = TSHighlighter +--- @nodoc TSHighlighter.active = TSHighlighter.active or {} +---@class TSHighlighterQuery +---@field _query Query|nil +---@field hl_cache table<integer,integer> local TSHighlighterQuery = {} TSHighlighterQuery.__index = TSHighlighterQuery -local ns = a.nvim_create_namespace('treesitter/highlighter') +local ns = api.nvim_create_namespace('treesitter/highlighter') ---@private function TSHighlighterQuery.new(lang, query_string) @@ -22,7 +39,7 @@ function TSHighlighterQuery.new(lang, query_string) local name = 
self._query.captures[capture] local id = 0 if not vim.startswith(name, '_') then - id = a.nvim_get_hl_id_by_name('@' .. name .. '.' .. lang) + id = api.nvim_get_hl_id_by_name('@' .. name .. '.' .. lang) end rawset(table, capture, id) @@ -31,22 +48,24 @@ function TSHighlighterQuery.new(lang, query_string) }) if query_string then - self._query = query.parse_query(lang, query_string) + self._query = query.parse(lang, query_string) else - self._query = query.get_query(lang, 'highlights') + self._query = query.get(lang, 'highlights') end return self end ----@private +---@package function TSHighlighterQuery:query() return self._query end ---- Creates a new highlighter using @param tree +---@package +--- +--- Creates a highlighter for `tree`. --- ----@param tree LanguageTree |LanguageTree| parser object to use for highlighting +---@param tree LanguageTree parser object to use for highlighting ---@param opts (table|nil) Configuration of the highlighter: --- - queries table overwrite queries used by the highlighter ---@return TSHighlighter Created highlighter object @@ -57,27 +76,37 @@ function TSHighlighter.new(tree, opts) error('TSHighlighter can not be used with a string parser source.') end - opts = opts or {} + opts = opts or {} ---@type { queries: table<string,string> } self.tree = tree tree:register_cbs({ - on_changedtree = function(...) - self:on_changedtree(...) - end, on_bytes = function(...) self:on_bytes(...) end, - on_detach = function(...) - self:on_detach(...) + on_detach = function() + self:on_detach() end, }) - self.bufnr = tree:source() + tree:register_cbs({ + on_changedtree = function(...) + self:on_changedtree(...) + end, + on_child_removed = function(child) + child:for_each_tree(function(t) + self:on_changedtree(t:included_ranges(true)) + end) + end, + }, true) + + self.bufnr = tree:source() --[[@as integer]] self.edit_count = 0 self.redraw_count = 0 self.line_count = {} -- A map of highlight states. 
-- This state is kept during rendering across each line update. self._highlight_states = {} + + ---@type table<string,TSHighlighterQuery> self._queries = {} -- Queries for a specific language can be overridden by a custom @@ -103,7 +132,7 @@ function TSHighlighter.new(tree, opts) vim.cmd.runtime({ 'syntax/synload.vim', bang = true }) end - a.nvim_buf_call(self.bufnr, function() + api.nvim_buf_call(self.bufnr, function() vim.opt_local.spelloptions:append('noplainbuffer') end) @@ -112,6 +141,7 @@ function TSHighlighter.new(tree, opts) return self end +--- @nodoc --- Removes all internal references to the highlighter function TSHighlighter:destroy() if TSHighlighter.active[self.bufnr] then @@ -122,12 +152,14 @@ function TSHighlighter:destroy() vim.bo[self.bufnr].spelloptions = self.orig_spelloptions vim.b[self.bufnr].ts_highlight = nil if vim.g.syntax_on == 1 then - a.nvim_exec_autocmds('FileType', { group = 'syntaxset', buffer = self.bufnr }) + api.nvim_exec_autocmds('FileType', { group = 'syntaxset', buffer = self.bufnr }) end end end ----@private +---@package +---@param tstree TSTree +---@return TSHighlightState function TSHighlighter:get_highlight_state(tstree) if not self._highlight_states[tstree] then self._highlight_states[tstree] = { @@ -144,28 +176,31 @@ function TSHighlighter:reset_highlight_state() self._highlight_states = {} end ----@private +---@package +---@param start_row integer +---@param new_end integer function TSHighlighter:on_bytes(_, _, start_row, _, _, _, _, _, new_end) - a.nvim__buf_redraw_range(self.bufnr, start_row, start_row + new_end + 1) + api.nvim__buf_redraw_range(self.bufnr, start_row, start_row + new_end + 1) end ----@private +---@package function TSHighlighter:on_detach() self:destroy() end ----@private +---@package +---@param changes Range6[] function TSHighlighter:on_changedtree(changes) - for _, ch in ipairs(changes or {}) do - a.nvim__buf_redraw_range(self.bufnr, ch[1], ch[3] + 1) + for _, ch in ipairs(changes) do + 
api.nvim__buf_redraw_range(self.bufnr, ch[1], ch[4] + 1) end end --- Gets the query used for @param lang -- ----@private +---@package ---@param lang string Language used by the highlighter. ----@return Query +---@return TSHighlighterQuery function TSHighlighter:get_query(lang) if not self._queries[lang] then self._queries[lang] = TSHighlighterQuery.new(lang) @@ -174,7 +209,10 @@ function TSHighlighter:get_query(lang) return self._queries[lang] end ----@private +---@param self TSHighlighter +---@param buf integer +---@param line integer +---@param is_spell_nav boolean local function on_line_impl(self, buf, line, is_spell_nav) self.tree:for_each_tree(function(tstree, tree) if not tstree then @@ -203,45 +241,54 @@ local function on_line_impl(self, buf, line, is_spell_nav) end while line >= state.next_row do - local capture, node, metadata = state.iter() + local capture, node, metadata = state.iter(line) - if capture == nil then - break + local range = { root_end_row + 1, 0, root_end_row + 1, 0 } + if node then + range = vim.treesitter.get_range(node, buf, metadata and metadata[capture]) end - - local start_row, start_col, end_row, end_col = node:range() - local hl = highlighter_query.hl_cache[capture] - - local capture_name = highlighter_query:query().captures[capture] - local spell = nil - if capture_name == 'spell' then - spell = true - elseif capture_name == 'nospell' then - spell = false + local start_row, start_col, end_row, end_col = Range.unpack4(range) + + if capture then + local hl = highlighter_query.hl_cache[capture] + + local capture_name = highlighter_query:query().captures[capture] + local spell = nil ---@type boolean? + if capture_name == 'spell' then + spell = true + elseif capture_name == 'nospell' then + spell = false + end + + -- Give nospell a higher priority so it always overrides spell captures. 
+ local spell_pri_offset = capture_name == 'nospell' and 1 or 0 + + if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then + local priority = (tonumber(metadata.priority) or vim.highlight.priorities.treesitter) + + spell_pri_offset + api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { + end_line = end_row, + end_col = end_col, + hl_group = hl, + ephemeral = true, + priority = priority, + conceal = metadata.conceal, + spell = spell, + }) + end end - -- Give nospell a higher priority so it always overrides spell captures. - local spell_pri_offset = capture_name == 'nospell' and 1 or 0 - - if hl and end_row >= line and (not is_spell_nav or spell ~= nil) then - a.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_line = end_row, - end_col = end_col, - hl_group = hl, - ephemeral = true, - priority = (tonumber(metadata.priority) or 100) + spell_pri_offset, -- Low but leaves room below - conceal = metadata.conceal, - spell = spell, - }) - end if start_row > line then state.next_row = start_row end end - end, true) + end) end ---@private +---@param _win integer +---@param buf integer +---@param line integer function TSHighlighter._on_line(_, _win, buf, line, _) local self = TSHighlighter.active[buf] if not self then @@ -252,6 +299,9 @@ function TSHighlighter._on_line(_, _win, buf, line, _) end ---@private +---@param buf integer +---@param srow integer +---@param erow integer function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _) local self = TSHighlighter.active[buf] if not self then @@ -266,27 +316,22 @@ function TSHighlighter._on_spell_nav(_, _, buf, srow, _, erow, _) end ---@private -function TSHighlighter._on_buf(_, buf) - local self = TSHighlighter.active[buf] - if self then - self.tree:parse() - end -end - ----@private -function TSHighlighter._on_win(_, _win, buf, _topline) +---@param _win integer +---@param buf integer +---@param topline integer +---@param botline integer +function TSHighlighter._on_win(_, _win, buf, topline, 
botline) local self = TSHighlighter.active[buf] if not self then return false end - + self.tree:parse({ topline, botline + 1 }) self:reset_highlight_state() self.redraw_count = self.redraw_count + 1 return true end -a.nvim_set_decoration_provider(ns, { - on_buf = TSHighlighter._on_buf, +api.nvim_set_decoration_provider(ns, { on_win = TSHighlighter._on_win, on_line = TSHighlighter._on_line, _on_spell_nav = TSHighlighter._on_spell_nav, diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index 8634e53b7b..15bf666a1e 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -1,48 +1,132 @@ -local a = vim.api +local api = vim.api +---@class TSLanguageModule local M = {} ---- Asserts that a parser for the language {lang} is installed. +---@type table<string,string> +local ft_to_lang = { + help = 'vimdoc', +} + +--- Get the filetypes associated with the parser named {lang}. +--- @param lang string Name of parser +--- @return string[] filetypes +function M.get_filetypes(lang) + local r = {} ---@type string[] + for ft, p in pairs(ft_to_lang) do + if p == lang then + r[#r + 1] = ft + end + end + return r +end + +--- @param filetype string +--- @return string|nil +function M.get_lang(filetype) + if filetype == '' then + return + end + if ft_to_lang[filetype] then + return ft_to_lang[filetype] + end + -- support subfiletypes like html.glimmer + filetype = vim.split(filetype, '.', { plain = true })[1] + return ft_to_lang[filetype] +end + +---@deprecated +function M.require_language(lang, path, silent, symbol_name) + local opts = { + silent = silent, + path = path, + symbol_name = symbol_name, + } + + if silent then + local installed = pcall(M.add, lang, opts) + return installed + end + + M.add(lang, opts) + return true +end + +---@class treesitter.RequireLangOpts +---@field path? string +---@field silent? boolean +---@field filetype? string|string[] +---@field symbol_name? 
string + +--- Load parser with name {lang} --- --- Parsers are searched in the `parser` runtime directory, or the provided {path} --- ----@param lang string Language the parser should parse (alphanumerical and `_` only) ----@param path (string|nil) Optional path the parser is located at ----@param silent (boolean|nil) Don't throw an error if language not found ----@param symbol_name (string|nil) Internal symbol name for the language to load ----@return boolean If the specified language is installed -function M.require_language(lang, path, silent, symbol_name) +---@param lang string Name of the parser (alphanumerical and `_` only) +---@param opts (table|nil) Options: +--- - filetype (string|string[]) Default filetype the parser should be associated with. +--- Defaults to {lang}. +--- - path (string|nil) Optional path the parser is located at +--- - symbol_name (string|nil) Internal symbol name for the language to load +function M.add(lang, opts) + ---@cast opts treesitter.RequireLangOpts + opts = opts or {} + local path = opts.path + local filetype = opts.filetype or lang + local symbol_name = opts.symbol_name + + vim.validate({ + lang = { lang, 'string' }, + path = { path, 'string', true }, + symbol_name = { symbol_name, 'string', true }, + filetype = { filetype, { 'string', 'table' }, true }, + }) + if vim._ts_has_language(lang) then - return true + M.register(lang, filetype) + return end + if path == nil then if not (lang and lang:match('[%w_]+') == lang) then - if silent then - return false - end error("'" .. lang .. "' is not a valid language name") end local fname = 'parser/' .. lang .. '.*' - local paths = a.nvim_get_runtime_file(fname, false) + local paths = api.nvim_get_runtime_file(fname, false) if #paths == 0 then - if silent then - return false - end error("no parser for '" .. lang .. 
"' language, see :help treesitter-parsers") end path = paths[1] end - if silent then - return pcall(function() - vim._ts_add_language(path, lang, symbol_name) - end) - else - vim._ts_add_language(path, lang, symbol_name) + vim._ts_add_language(path, lang, symbol_name) + M.register(lang, filetype) +end + +--- @param x string|string[] +--- @return string[] +local function ensure_list(x) + if type(x) == 'table' then + return x end + return { x } +end - return true +--- Register a parser named {lang} to be used for {filetype}(s). +--- @param lang string Name of parser +--- @param filetype string|string[] Filetype(s) to associate with lang +function M.register(lang, filetype) + vim.validate({ + lang = { lang, 'string' }, + filetype = { filetype, { 'string', 'table' } }, + }) + + for _, f in ipairs(ensure_list(filetype)) do + if f ~= '' then + ft_to_lang[f] = lang + end + end end --- Inspects the provided language. @@ -51,9 +135,19 @@ end --- ---@param lang string Language ---@return table -function M.inspect_language(lang) - M.require_language(lang) +function M.inspect(lang) + M.add(lang) return vim._ts_inspect_language(lang) end +---@deprecated +function M.inspect_language(...) + vim.deprecate( + 'vim.treesitter.language.inspect_language()', + 'vim.treesitter.language.inspect()', + '0.10' + ) + return M.inspect(...) +end + return M diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index a1e96f8ef2..0171b416cd 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -1,85 +1,257 @@ -local a = vim.api +--- @defgroup lua-treesitter-languagetree +--- +--- @brief A \*LanguageTree\* contains a tree of parsers: the root treesitter parser for {lang} and +--- any "injected" language parsers, which themselves may inject other languages, recursively. +--- For example a Lua buffer containing some Vimscript commands needs multiple parsers to fully +--- understand its contents. 
+--- +--- To create a LanguageTree (parser object) for a given buffer and language, use: +--- +--- ```lua +--- local parser = vim.treesitter.get_parser(bufnr, lang) +--- ``` +--- +--- (where `bufnr=0` means current buffer). `lang` defaults to 'filetype'. +--- Note: currently the parser is retained for the lifetime of a buffer but this may change; +--- a plugin should keep a reference to the parser object if it wants incremental updates. +--- +--- Whenever you need to access the current syntax tree, parse the buffer: +--- +--- ```lua +--- local tree = parser:parse({ start_row, end_row }) +--- ``` +--- +--- This returns a table of immutable |treesitter-tree| objects representing the current state of +--- the buffer. When the plugin wants to access the state after a (possible) edit it must call +--- `parse()` again. If the buffer wasn't edited, the same tree will be returned again without extra +--- work. If the buffer was parsed before, incremental parsing will be done of the changed parts. +--- +--- Note: To use the parser directly inside a |nvim_buf_attach()| Lua callback, you must call +--- |vim.treesitter.get_parser()| before you register your callback. But preferably parsing +--- shouldn't be done directly in the change callback anyway as they will be very frequent. Rather +--- a plugin that does any kind of analysis on a tree should use a timer to throttle too frequent +--- updates. +--- + +-- Debugging: +-- +-- vim.g.__ts_debug levels: +-- - 1. Messages from languagetree.lua +-- - 2. Parse messages from treesitter +-- - 3.
Lex messages from treesitter +-- +-- Log file can be found in stdpath('log')/treesitter.log + local query = require('vim.treesitter.query') local language = require('vim.treesitter.language') +local Range = require('vim.treesitter._range') + +---@alias TSCallbackName +---| 'changedtree' +---| 'bytes' +---| 'detach' +---| 'child_added' +---| 'child_removed' + +---@alias TSCallbackNameOn +---| 'on_changedtree' +---| 'on_bytes' +---| 'on_detach' +---| 'on_child_added' +---| 'on_child_removed' + +--- @type table<TSCallbackNameOn,TSCallbackName> +local TSCallbackNames = { + on_changedtree = 'changedtree', + on_bytes = 'bytes', + on_detach = 'detach', + on_child_added = 'child_added', + on_child_removed = 'child_removed', +} ---@class LanguageTree ----@field _callbacks function[] Callback handlers ----@field _children LanguageTree[] Injected languages ----@field _injection_query table Queries defining injected languages ----@field _opts table Options ----@field _parser userdata Parser for language ----@field _regions table List of regions this tree should manage and parse ----@field _lang string Language name ----@field _regions table ----@field _source (number|string) Buffer or string to parse ----@field _trees userdata[] Reference to parsed |tstree| (one for each language) ----@field _valid boolean If the parsed tree is valid - +---@field private _callbacks table<TSCallbackName,function[]> Callback handlers +---@field package _callbacks_rec table<TSCallbackName,function[]> Callback handlers (recursive) +---@field private _children table<string,LanguageTree> Injected languages +---@field private _injection_query Query Queries defining injected languages +---@field private _injections_processed boolean +---@field private _opts table Options +---@field private _parser TSParser Parser for language +---@field private _has_regions boolean +---@field private _regions table<integer, Range6[]>? +---List of regions this tree should manage and parse. 
If nil then regions are +---taken from _trees. This is mostly a short-lived cache for included_regions() +---@field private _lang string Language name +---@field private _parent_lang? string Parent language name +---@field private _source (integer|string) Buffer or string to parse +---@field private _trees table<integer, TSTree> Reference to parsed tree (one for each language). +---Each key is the index of region, which is synced with _regions and _valid. +---@field private _valid boolean|table<integer,boolean> If the parsed tree is valid +---@field private _logger? fun(logtype: string, msg: string) +---@field private _logfile? file* local LanguageTree = {} + +---@class LanguageTreeOpts +---@field queries table<string,string> -- Deprecated +---@field injections table<string,string> + LanguageTree.__index = LanguageTree ---- A |LanguageTree| holds the treesitter parser for a given language {lang} used ---- to parse a buffer. As the buffer may contain injected languages, the LanguageTree ---- needs to store parsers for these child languages as well (which in turn may contain ---- child languages themselves, hence the name). ---- ----@param source (number|string) Buffer or a string of text to parse ----@param lang string Root language this tree represents ----@param opts (table|nil) Optional keyword arguments: ---- - injections table Mapping language to injection query strings. ---- This is useful for overriding the built-in ---- runtime file searching for the injection language ---- query per language. ----@return LanguageTree |LanguageTree| parser object -function LanguageTree.new(source, lang, opts) - language.require_language(lang) +--- @package +--- +--- |LanguageTree| contains a tree of parsers: the root treesitter parser for {lang} and any +--- "injected" language parsers, which themselves may inject other languages, recursively. 
+--- +---@param source (integer|string) Buffer or text string to parse +---@param lang string Root language of this tree +---@param opts (table|nil) Optional arguments: +--- - injections table Map of language to injection query strings. Overrides the +--- built-in runtime file searching for language injections. +---@param parent_lang? string Parent language name of this tree +---@return LanguageTree parser object +function LanguageTree.new(source, lang, opts, parent_lang) + language.add(lang) + ---@type LanguageTreeOpts opts = opts or {} - if opts.queries then - a.nvim_err_writeln("'queries' is no longer supported. Use 'injections' now") - opts.injections = opts.queries + if source == 0 then + source = vim.api.nvim_get_current_buf() end local injections = opts.injections or {} - local self = setmetatable({ + + --- @type LanguageTree + local self = { _source = source, _lang = lang, + _parent_lang = parent_lang, _children = {}, - _regions = {}, _trees = {}, _opts = opts, - _injection_query = injections[lang] and query.parse_query(lang, injections[lang]) - or query.get_query(lang, 'injections'), + _injection_query = injections[lang] and query.parse(lang, injections[lang]) + or query.get(lang, 'injections'), + _has_regions = false, + _injections_processed = false, _valid = false, _parser = vim._create_ts_parser(lang), - _callbacks = { - changedtree = {}, - bytes = {}, - detach = {}, - child_added = {}, - child_removed = {}, - }, - }, LanguageTree) + _callbacks = {}, + _callbacks_rec = {}, + } + + setmetatable(self, LanguageTree) + + if vim.g.__ts_debug and type(vim.g.__ts_debug) == 'number' then + self:_set_logger() + self:_log('START') + end + + for _, name in pairs(TSCallbackNames) do + self._callbacks[name] = {} + self._callbacks_rec[name] = {} + end return self end +--- @private +function LanguageTree:_set_logger() + local source = self:source() + source = type(source) == 'string' and 'text' or tostring(source) + + local lang = self:lang() + + 
vim.fn.mkdir(vim.fn.stdpath('log'), 'p') + local logfilename = vim.fs.joinpath(vim.fn.stdpath('log'), 'treesitter.log') + + local logfile, openerr = io.open(logfilename, 'a+') + + if not logfile or openerr then + error(string.format('Could not open file (%s) for logging: %s', logfilename, openerr)) + return + end + + self._logfile = logfile + + self._logger = function(logtype, msg) + self._logfile:write(string.format('%s:%s:(%s) %s\n', source, lang, logtype, msg)) + self._logfile:flush() + end + + local log_lex = vim.g.__ts_debug >= 3 + local log_parse = vim.g.__ts_debug >= 2 + self._parser:_set_logger(log_lex, log_parse, self._logger) +end + +---Measure execution time of a function +---@generic R1, R2, R3 +---@param f fun(): R1, R2, R3 +---@return number, R1, R2, R3 +local function tcall(f, ...) + local start = vim.uv.hrtime() + ---@diagnostic disable-next-line + local r = { f(...) } + --- @type number + local duration = (vim.uv.hrtime() - start) / 1000000 + return duration, unpack(r) +end + +---@private +---@vararg any +function LanguageTree:_log(...) + if not self._logger then + return + end + + if not vim.g.__ts_debug or vim.g.__ts_debug < 1 then + return + end + + local args = { ...
} + if type(args[1]) == 'function' then + args = { args[1]() } + end + + local info = debug.getinfo(2, 'nl') + local nregions = vim.tbl_count(self:included_regions()) + local prefix = + string.format('%s:%d: (#regions=%d) ', info.name or '???', info.currentline or 0, nregions) + + local msg = { prefix } + for _, x in ipairs(args) do + if type(x) == 'string' then + msg[#msg + 1] = x + else + msg[#msg + 1] = vim.inspect(x, { newline = ' ', indent = '' }) + end + end + self._logger('nvim', table.concat(msg, ' ')) +end + --- Invalidates this parser and all its children +---@param reload boolean|nil function LanguageTree:invalidate(reload) self._valid = false -- buffer was reloaded, reparse all trees if reload then + for _, t in pairs(self._trees) do + self:_do_callback('changedtree', t:included_ranges(true), t) + end self._trees = {} end - for _, child in ipairs(self._children) do + for _, child in pairs(self._children) do child:invalidate(reload) end end ---- Returns all trees this language tree contains. +--- Returns all trees of the regions parsed by this parser. --- Does not include child languages. +--- The result is list-like if +--- * this LanguageTree is the root, in which case the result is empty or a singleton list; or +--- * the root LanguageTree is fully parsed. +--- +---@return table<integer, TSTree> function LanguageTree:trees() return self._trees end @@ -89,11 +261,39 @@ function LanguageTree:lang() return self._lang end ---- Determines whether this tree is valid. ---- If the tree is invalid, call `parse()`. ---- This will return the updated tree. -function LanguageTree:is_valid() - return self._valid +--- Returns whether this LanguageTree is valid, i.e., |LanguageTree:trees()| reflects the latest +--- state of the source. If invalid, user should call |LanguageTree:parse()|. 
+---@param exclude_children boolean|nil whether to ignore the validity of children (default `false`) +---@return boolean +function LanguageTree:is_valid(exclude_children) + local valid = self._valid + + if type(valid) == 'table' then + for i, _ in pairs(self:included_regions()) do + if not valid[i] then + return false + end + end + end + + if not exclude_children then + if not self._injections_processed then + return false + end + + for _, child in pairs(self._children) do + if not child:is_valid(exclude_children) then + return false + end + end + end + + if type(valid) == 'boolean' then + return valid + end + + self._valid = true + return true end --- Returns a map of language to child tree. @@ -106,50 +306,77 @@ function LanguageTree:source() return self._source end ---- Parses all defined regions using a treesitter parser ---- for the language this tree represents. ---- This will run the injection query for this language to ---- determine if any child languages should be created. ---- ----@return userdata[] Table of parsed |tstree| ----@return table Change list -function LanguageTree:parse() - if self._valid then - return self._trees +--- @param region Range6[] +--- @param range? boolean|Range +--- @return boolean +local function intercepts_region(region, range) + if #region == 0 then + return true + end + + if range == nil then + return false + end + + if type(range) == 'boolean' then + return range end - local parser = self._parser + for _, r in ipairs(region) do + if Range.intercepts(r, range) then + return true + end + end + + return false +end + +--- @private +--- @param range boolean|Range? 
+--- @return Range6[] changes +--- @return integer no_regions_parsed +--- @return number total_parse_time +function LanguageTree:_parse_regions(range) local changes = {} + local no_regions_parsed = 0 + local total_parse_time = 0 - local old_trees = self._trees - self._trees = {} + if type(self._valid) ~= 'table' then + self._valid = {} + end -- If there are no ranges, set to an empty list -- so the included ranges in the parser are cleared. - if self._regions and #self._regions > 0 then - for i, ranges in ipairs(self._regions) do - local old_tree = old_trees[i] - parser:set_included_ranges(ranges) + for i, ranges in pairs(self:included_regions()) do + if not self._valid[i] and intercepts_region(ranges, range) then + self._parser:set_included_ranges(ranges) + local parse_time, tree, tree_changes = + tcall(self._parser.parse, self._parser, self._trees[i], self._source, true) - local tree, tree_changes = parser:parse(old_tree, self._source) - self:_do_callback('changedtree', tree_changes, tree) + -- Pass ranges if this is an initial parse + local cb_changes = self._trees[i] and tree_changes or tree:included_ranges(true) - table.insert(self._trees, tree) + self:_do_callback('changedtree', cb_changes, tree) + self._trees[i] = tree vim.list_extend(changes, tree_changes) - end - else - local tree, tree_changes = parser:parse(old_trees[1], self._source) - self:_do_callback('changedtree', tree_changes, tree) - table.insert(self._trees, tree) - vim.list_extend(changes, tree_changes) + total_parse_time = total_parse_time + parse_time + no_regions_parsed = no_regions_parsed + 1 + self._valid[i] = true + end end - local injections_by_lang = self:_get_injections() - local seen_langs = {} + return changes, no_regions_parsed, total_parse_time +end + +--- @private +--- @return number +function LanguageTree:_add_injections() + local seen_langs = {} ---@type table<string,boolean> - for lang, injection_ranges in pairs(injections_by_lang) do - local has_lang = 
language.require_language(lang, nil, true) + local query_time, injections_by_lang = tcall(self._get_injections, self) + for lang, injection_regions in pairs(injections_by_lang) do + local has_lang = pcall(language.add, lang) -- Child language trees should just be ignored if not found, since -- they can depend on the text of a node. Intermediate strings @@ -161,16 +388,7 @@ function LanguageTree:parse() child = self:add_child(lang) end - child:set_included_regions(injection_ranges) - - local _, child_changes = child:parse() - - -- Propagate any child changes so they are included in the - -- the change list for the callback. - if child_changes then - vim.list_extend(changes, child_changes) - end - + child:set_included_regions(injection_regions) seen_langs[lang] = true end end @@ -181,16 +399,71 @@ function LanguageTree:parse() end end - self._valid = true + return query_time +end + +--- Recursively parse all regions in the language tree using |treesitter-parsers| +--- for the corresponding languages and run injection queries on the parsed trees +--- to determine whether child trees should be created and parsed. +--- +--- Any region with empty range (`{}`, typically only the root tree) is always parsed; +--- otherwise (typically injections) only if it intersects {range} (or if {range} is `true`). +--- +--- @param range boolean|Range|nil: Parse this range in the parser's source. +--- Set to `true` to run a complete parse of the source (Note: Can be slow!) +--- Set to `false|nil` to only parse regions with empty ranges (typically +--- only the root tree without injections). +--- @return table<integer, TSTree> +function LanguageTree:parse(range) + if self:is_valid() then + self:_log('valid') + return self._trees + end - return self._trees, changes + local changes --- @type Range6[]? 
+ + -- Collect some stats + local no_regions_parsed = 0 + local query_time = 0 + local total_parse_time = 0 + + --- At least 1 region is invalid + if not self:is_valid(true) then + changes, no_regions_parsed, total_parse_time = self:_parse_regions(range) + -- Need to run injections when we parsed something + if no_regions_parsed > 0 then + self._injections_processed = false + end + end + + if not self._injections_processed and range ~= false and range ~= nil then + query_time = self:_add_injections() + self._injections_processed = true + end + + self:_log({ + changes = changes and #changes > 0 and changes or nil, + regions_parsed = no_regions_parsed, + parse_time = total_parse_time, + query_time = query_time, + range = range, + }) + + for _, child in pairs(self._children) do + child:parse(range) + end + + return self._trees end +---@deprecated Misleading name. Use `LanguageTree:children()` (non-recursive) instead, +--- add recursion yourself if needed. --- Invokes the callback for each |LanguageTree| and its children recursively --- ----@param fn function(tree: LanguageTree, lang: string) ----@param include_self boolean Whether to include the invoking tree in the results +---@param fn fun(tree: LanguageTree, lang: string) +---@param include_self boolean|nil Whether to include the invoking tree in the results function LanguageTree:for_each_child(fn, include_self) + vim.deprecate('LanguageTree:for_each_child()', 'LanguageTree:children()', '0.11') if include_self then fn(self, self._lang) end @@ -204,9 +477,9 @@ end --- --- Note: This includes the invoking tree's child trees as well. --- ----@param fn function(tree: TSTree, languageTree: LanguageTree) +---@param fn fun(tree: TSTree, ltree: LanguageTree) function LanguageTree:for_each_tree(fn) - for _, tree in ipairs(self._trees) do + for _, tree in pairs(self._trees) do fn(tree, self) end @@ -221,15 +494,20 @@ end --- ---@private ---@param lang string Language to add. 
----@return LanguageTree Injected |LanguageTree| +---@return LanguageTree injected function LanguageTree:add_child(lang) if self._children[lang] then self:remove_child(lang) end - self._children[lang] = LanguageTree.new(self._source, lang, self._opts) + local child = LanguageTree.new(self._source, lang, self._opts, self:lang()) - self:invalidate() + -- Inherit recursive callbacks + for nm, cb in pairs(self._callbacks_rec) do + vim.list_extend(child._callbacks_rec[nm], cb) + end + + self._children[lang] = child self:_do_callback('child_added', self._children[lang]) return self._children[lang] @@ -245,7 +523,6 @@ function LanguageTree:remove_child(lang) if child then self._children[lang] = nil child:destroy() - self:invalidate() self:_do_callback('child_removed', child) end end @@ -258,11 +535,60 @@ end --- `remove_child` must be called on the parent to remove it. function LanguageTree:destroy() -- Cleanup here - for _, child in ipairs(self._children) do + for _, child in pairs(self._children) do child:destroy() end end +---@param region Range6[] +local function region_tostr(region) + if #region == 0 then + return '[]' + end + local srow, scol = region[1][1], region[1][2] + local erow, ecol = region[#region][4], region[#region][5] + return string.format('[%d:%d-%d:%d]', srow, scol, erow, ecol) +end + +---@private +---Iterate through all the regions. fn returns a boolean to indicate if the +---region is valid or not. 
+---@param fn fun(index: integer, region: Range6[]): boolean +function LanguageTree:_iter_regions(fn) + if not self._valid then + return + end + + local was_valid = type(self._valid) ~= 'table' + + if was_valid then + self:_log('was valid', self._valid) + self._valid = {} + end + + local all_valid = true + + for i, region in pairs(self:included_regions()) do + if was_valid or self._valid[i] then + self._valid[i] = fn(i, region) + if not self._valid[i] then + self:_log(function() + return 'invalidating region', i, region_tostr(region) + end) + end + end + + if not self._valid[i] then + all_valid = false + end + end + + -- Compress the valid value to 'true' if there are no invalid regions + if all_valid then + self._valid = all_valid + end +end + --- Sets the included regions that should be parsed by this |LanguageTree|. --- A region is a set of nodes and/or ranges that will be parsed in the same context. --- @@ -277,151 +603,253 @@ end --- This allows for embedded languages to be parsed together across different --- nodes, which is useful for templating languages like ERB and EJS. --- ---- Note: This call invalidates the tree and requires it to be parsed again. ---- ---@private ----@param regions table List of regions this tree should manage and parse. -function LanguageTree:set_included_regions(regions) +---@param new_regions (Range4|Range6|TSNode)[][] List of regions this tree should manage and parse. +function LanguageTree:set_included_regions(new_regions) + self._has_regions = true + -- Transform the tables from 4 element long to 6 element long (with byte offset) - for _, region in ipairs(regions) do + for _, region in ipairs(new_regions) do for i, range in ipairs(region) do if type(range) == 'table' and #range == 4 then - local start_row, start_col, end_row, end_col = unpack(range) - local start_byte = 0 - local end_byte = 0 - -- TODO(vigoux): proper byte computation here, and account for EOL ? 
- if type(self._source) == 'number' then - -- Easy case, this is a buffer parser - start_byte = a.nvim_buf_get_offset(self._source, start_row) + start_col - end_byte = a.nvim_buf_get_offset(self._source, end_row) + end_col - elseif type(self._source) == 'string' then - -- string parser, single `\n` delimited string - start_byte = vim.fn.byteidx(self._source, start_col) - end_byte = vim.fn.byteidx(self._source, end_col) - end - - region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte } + region[i] = Range.add_bytes(self._source, range --[[@as Range4]]) + elseif type(range) == 'userdata' then + region[i] = { range:range(true) } end end end - self._regions = regions - -- Trees are no longer valid now that we have changed regions. - -- TODO(vigoux,steelsojka): Look into doing this smarter so we can use some of the - -- old trees for incremental parsing. Currently, this only - -- affects injected languages. - self._trees = {} - self:invalidate() + -- included_regions is not guaranteed to be list-like, but this is still sound, i.e. if + -- new_regions is different from included_regions, then outdated regions in included_regions are + -- invalidated. For example, if included_regions = new_regions ++ hole ++ outdated_regions, then + -- outdated_regions is invalidated by _iter_regions in else branch. + if #self:included_regions() ~= #new_regions then + -- TODO(lewis6991): inefficient; invalidate trees incrementally + for _, t in pairs(self._trees) do + self:_do_callback('changedtree', t:included_ranges(true), t) + end + self._trees = {} + self:invalidate() + else + self:_iter_regions(function(i, region) + return vim.deep_equal(new_regions[i], region) + end) + end + + self._regions = new_regions end ---- Gets the set of included regions +---Gets the set of included regions managed by this LanguageTree. 
This can be different from the +---regions set by injection query, because a partial |LanguageTree:parse()| drops the regions +---outside the requested range. +---@return table<integer, Range6[]> function LanguageTree:included_regions() - return self._regions + if self._regions then + return self._regions + end + + if not self._has_regions then + -- treesitter.c will default empty ranges to { -1, -1, -1, -1, -1, -1} (the full range) + return { {} } + end + + local regions = {} ---@type Range6[][] + for i, _ in pairs(self._trees) do + regions[i] = self._trees[i]:included_ranges(true) + end + + self._regions = regions + return regions +end + +---@param node TSNode +---@param source string|integer +---@param metadata TSMetadata +---@param include_children boolean +---@return Range6[] +local function get_node_ranges(node, source, metadata, include_children) + local range = vim.treesitter.get_range(node, source, metadata) + local child_count = node:named_child_count() + + if include_children or child_count == 0 then + return { range } + end + + local ranges = {} ---@type Range6[] + + local srow, scol, sbyte, erow, ecol, ebyte = Range.unpack6(range) + + -- We are excluding children so we need to mask out their ranges + for i = 0, child_count - 1 do + local child = assert(node:named_child(i)) + local c_srow, c_scol, c_sbyte, c_erow, c_ecol, c_ebyte = child:range(true) + if c_srow > srow or c_scol > scol then + ranges[#ranges + 1] = { srow, scol, sbyte, c_srow, c_scol, c_sbyte } + end + srow = c_erow + scol = c_ecol + sbyte = c_ebyte + end + + if erow > srow or ecol > scol then + ranges[#ranges + 1] = Range.add_bytes(source, { srow, scol, sbyte, erow, ecol, ebyte }) + end + + return ranges +end + +---@class TSInjectionElem +---@field combined boolean +---@field regions Range6[][] + +---@alias TSInjection table<string,table<integer,TSInjectionElem>> + +---@param t table<integer,TSInjection> +---@param tree_index integer +---@param pattern integer +---@param lang string 
+---@param combined boolean +---@param ranges Range6[] +local function add_injection(t, tree_index, pattern, lang, combined, ranges) + if #ranges == 0 then + -- Make sure not to add an empty range set as this is interpreted to mean the whole buffer. + return + end + + -- Each tree index should be isolated from the other nodes. + if not t[tree_index] then + t[tree_index] = {} + end + + if not t[tree_index][lang] then + t[tree_index][lang] = {} + end + + -- Key this by pattern. If combined is set to true all captures of this pattern + -- will be parsed by treesitter as the same "source". + -- If combined is false, each "region" will be parsed as a single source. + if not t[tree_index][lang][pattern] then + t[tree_index][lang][pattern] = { combined = combined, regions = {} } + end + + table.insert(t[tree_index][lang][pattern].regions, ranges) +end + +-- TODO(clason): replace by refactored `ts.has_parser` API (without registering) +--- The result of this function is cached to prevent nvim_get_runtime_file from being +--- called too often +--- @param lang string parser name +--- @return boolean # true if parser for {lang} exists on rtp +local has_parser = vim.func._memoize(1, function(lang) + return vim._ts_has_language(lang) + or #vim.api.nvim_get_runtime_file('parser/' .. lang .. '.*', false) > 0 +end) + +--- Return parser name for language (if exists) or filetype (if registered and exists). +--- Also attempts with the input lower-cased. +--- +---@param alias string language or filetype name +---@return string? 
# resolved parser name +local function resolve_lang(alias) + if has_parser(alias) then + return alias + end + + if has_parser(alias:lower()) then + return alias:lower() + end + + local lang = vim.treesitter.language.get_lang(alias) + if lang and has_parser(lang) then + return lang + end + + lang = vim.treesitter.language.get_lang(alias:lower()) + if lang and has_parser(lang) then + return lang + end end ---@private -local function get_range_from_metadata(node, id, metadata) - if metadata[id] and metadata[id].range then - return metadata[id].range +--- Extract injections according to: +--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection +---@param match table<integer,TSNode> +---@param metadata TSMetadata +---@return string?, boolean, Range6[] +function LanguageTree:_get_injection(match, metadata) + local ranges = {} ---@type Range6[] + local combined = metadata['injection.combined'] ~= nil + local injection_lang = metadata['injection.language'] --[[@as string?]] + local lang = metadata['injection.self'] ~= nil and self:lang() + or metadata['injection.parent'] ~= nil and self._parent_lang + or (injection_lang and resolve_lang(injection_lang)) + local include_children = metadata['injection.include-children'] ~= nil + + for id, node in pairs(match) do + local name = self._injection_query.captures[id] + -- Lang should override any other language tag + if name == 'injection.language' then + local text = vim.treesitter.get_node_text(node, self._source, { metadata = metadata[id] }) + lang = resolve_lang(text) + elseif name == 'injection.content' then + ranges = get_node_ranges(node, self._source, metadata[id], include_children) + end end - return { node:range() } + + return lang, combined, ranges end ---- Gets language injection points by language. +--- Can't use vim.tbl_flatten since a range is just a table. 
+---@param regions Range6[][] +---@return Range6[] +local function combine_regions(regions) + local result = {} ---@type Range6[] + for _, region in ipairs(regions) do + for _, range in ipairs(region) do + result[#result + 1] = range + end + end + return result +end + +--- Gets language injection regions by language. --- --- This is where most of the injection processing occurs. --- --- TODO: Allow for an offset predicate to tailor the injection range --- instead of using the entire nodes range. ----@private +--- @private +--- @return table<string, Range6[][]> function LanguageTree:_get_injections() if not self._injection_query then return {} end + ---@type table<integer,TSInjection> local injections = {} - for tree_index, tree in ipairs(self._trees) do + for index, tree in pairs(self._trees) do local root_node = tree:root() local start_line, _, end_line, _ = root_node:range() for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1) do - local lang = nil - local ranges = {} - local combined = metadata.combined - - -- Directives can configure how injections are captured as well as actual node captures. - -- This allows more advanced processing for determining ranges and language resolution. 
- if metadata.content then - local content = metadata.content - - -- Allow for captured nodes to be used - if type(content) == 'number' then - content = { match[content]:range() } - end - - if type(content) == 'table' and #content >= 4 then - vim.list_extend(ranges, content) - end + local lang, combined, ranges = self:_get_injection(match, metadata) + if lang then + add_injection(injections, index, pattern, lang, combined, ranges) + else + self:_log('match from injection query failed for pattern', pattern) end - - if metadata.language then - lang = metadata.language - end - - -- You can specify the content and language together - -- using a tag with the language, for example - -- @javascript - for id, node in pairs(match) do - local name = self._injection_query.captures[id] - - -- Lang should override any other language tag - if name == 'language' and not lang then - lang = query.get_node_text(node, self._source) - elseif name == 'combined' then - combined = true - elseif name == 'content' and #ranges == 0 then - table.insert(ranges, get_range_from_metadata(node, id, metadata)) - -- Ignore any tags that start with "_" - -- Allows for other tags to be used in matches - elseif string.sub(name, 1, 1) ~= '_' then - if not lang then - lang = name - end - - if #ranges == 0 then - table.insert(ranges, get_range_from_metadata(node, id, metadata)) - end - end - end - - -- Each tree index should be isolated from the other nodes. - if not injections[tree_index] then - injections[tree_index] = {} - end - - if not injections[tree_index][lang] then - injections[tree_index][lang] = {} - end - - -- Key this by pattern. If combined is set to true all captures of this pattern - -- will be parsed by treesitter as the same "source". - -- If combined is false, each "region" will be parsed as a single source. 
- if not injections[tree_index][lang][pattern] then - injections[tree_index][lang][pattern] = { combined = combined, regions = {} } - end - - table.insert(injections[tree_index][lang][pattern].regions, ranges) end end + ---@type table<string,Range6[][]> local result = {} -- Generate a map by lang of node lists. -- Each list is a set of ranges that should be parsed together. - for _, lang_map in ipairs(injections) do + for _, lang_map in pairs(injections) do for lang, patterns in pairs(lang_map) do if not result[lang] then result[lang] = {} @@ -429,12 +857,9 @@ function LanguageTree:_get_injections() for _, entry in pairs(patterns) do if entry.combined then - local regions = vim.tbl_map(function(e) - return vim.tbl_flatten(e) - end, entry.regions) - table.insert(result[lang], regions) + table.insert(result[lang], combine_regions(entry.regions)) else - for _, ranges in ipairs(entry.regions) do + for _, ranges in pairs(entry.regions) do table.insert(result[lang], ranges) end end @@ -446,13 +871,94 @@ function LanguageTree:_get_injections() end ---@private +---@param cb_name TSCallbackName function LanguageTree:_do_callback(cb_name, ...) for _, cb in ipairs(self._callbacks[cb_name]) do cb(...) end + for _, cb in ipairs(self._callbacks_rec[cb_name]) do + cb(...) 
+ end end ----@private +---@package +function LanguageTree:_edit( + start_byte, + end_byte_old, + end_byte_new, + start_row, + start_col, + end_row_old, + end_col_old, + end_row_new, + end_col_new +) + for _, tree in pairs(self._trees) do + tree:edit( + start_byte, + end_byte_old, + end_byte_new, + start_row, + start_col, + end_row_old, + end_col_old, + end_row_new, + end_col_new + ) + end + + self._regions = nil + + local changed_range = { + start_row, + start_col, + start_byte, + end_row_old, + end_col_old, + end_byte_old, + } + + -- Validate regions after editing the tree + self:_iter_regions(function(_, region) + if #region == 0 then + -- empty region, use the full source + return false + end + for _, r in ipairs(region) do + if Range.intercepts(r, changed_range) then + return false + end + end + return true + end) + + for _, child in pairs(self._children) do + child:_edit( + start_byte, + end_byte_old, + end_byte_new, + start_row, + start_col, + end_row_old, + end_col_old, + end_row_new, + end_col_new + ) + end +end + +---@package +---@param bufnr integer +---@param changed_tick integer +---@param start_row integer +---@param start_col integer +---@param start_byte integer +---@param old_row integer +---@param old_col integer +---@param old_byte integer +---@param new_row integer +---@param new_col integer +---@param new_byte integer function LanguageTree:_on_bytes( bufnr, changed_tick, @@ -466,26 +972,36 @@ function LanguageTree:_on_bytes( new_col, new_byte ) - self:invalidate() - local old_end_col = old_col + ((old_row == 0) and start_col or 0) local new_end_col = new_col + ((new_row == 0) and start_col or 0) - -- Edit all trees recursively, together BEFORE emitting a bytes callback. - -- In most cases this callback should only be called from the root tree. 
- self:for_each_tree(function(tree) - tree:edit( - start_byte, - start_byte + old_byte, - start_byte + new_byte, - start_row, - start_col, - start_row + old_row, - old_end_col, - start_row + new_row, - new_end_col - ) - end) + self:_log( + 'on_bytes', + bufnr, + changed_tick, + start_row, + start_col, + start_byte, + old_row, + old_col, + old_byte, + new_row, + new_col, + new_byte + ) + + -- Edit trees together BEFORE emitting a bytes callback. + self:_edit( + start_byte, + start_byte + old_byte, + start_byte + new_byte, + start_row, + start_col, + start_row + old_row, + old_end_col, + start_row + new_row, + new_end_col + ) self:_do_callback( 'bytes', @@ -503,63 +1019,65 @@ function LanguageTree:_on_bytes( ) end ----@private +---@package function LanguageTree:_on_reload() self:invalidate(true) end ----@private +---@package function LanguageTree:_on_detach(...) self:invalidate(true) self:_do_callback('detach', ...) + if self._logfile then + self._logger('nvim', 'detaching') + self._logger = nil + self._logfile:close() + end end --- Registers callbacks for the |LanguageTree|. ---@param cbs table An |nvim_buf_attach()|-like table argument with the following handlers: --- - `on_bytes` : see |nvim_buf_attach()|, but this will be called _after_ the parsers callback. --- - `on_changedtree` : a callback that will be called every time the tree has syntactical changes. ---- It will only be passed one argument, which is a table of the ranges (as node ranges) that ---- changed. +--- It will be passed two arguments: a table of the ranges (as node ranges) that +--- changed and the changed tree. --- - `on_child_added` : emitted when a child is added to the tree. --- - `on_child_removed` : emitted when a child is removed from the tree. -function LanguageTree:register_cbs(cbs) +--- - `on_detach` : emitted when the buffer is detached, see |nvim_buf_detach_event|. +--- Takes one argument, the number of the buffer. +--- @param recursive? 
boolean Apply callbacks recursively for all children. Any new children will +--- also inherit the callbacks. +function LanguageTree:register_cbs(cbs, recursive) + ---@cast cbs table<TSCallbackNameOn,function> if not cbs then return end - if cbs.on_changedtree then - table.insert(self._callbacks.changedtree, cbs.on_changedtree) - end - - if cbs.on_bytes then - table.insert(self._callbacks.bytes, cbs.on_bytes) - end + local callbacks = recursive and self._callbacks_rec or self._callbacks - if cbs.on_detach then - table.insert(self._callbacks.detach, cbs.on_detach) - end - - if cbs.on_child_added then - table.insert(self._callbacks.child_added, cbs.on_child_added) + for name, cbname in pairs(TSCallbackNames) do + if cbs[name] then + table.insert(callbacks[cbname], cbs[name]) + end end - if cbs.on_child_removed then - table.insert(self._callbacks.child_removed, cbs.on_child_removed) + if recursive then + for _, child in pairs(self._children) do + child:register_cbs(cbs, true) + end end end ----@private +---@param tree TSTree +---@param range Range +---@return boolean local function tree_contains(tree, range) - local start_row, start_col, end_row, end_col = tree:root():range() - local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2]) - local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4]) - - return start_fits and end_fits + return Range.contains({ tree:root():range() }, range) end --- Determines whether {range} is contained in the |LanguageTree|. --- ----@param range table `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@return boolean function LanguageTree:contains(range) for _, tree in pairs(self._trees) do @@ -573,20 +1091,19 @@ end --- Gets the tree that contains {range}. 
--- ----@param range table `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@param opts table|nil Optional keyword arguments: --- - ignore_injections boolean Ignore injected languages (default true) ----@return userdata|nil Contained |tstree| +---@return TSTree|nil function LanguageTree:tree_for_range(range, opts) opts = opts or {} local ignore = vim.F.if_nil(opts.ignore_injections, true) if not ignore then for _, child in pairs(self._children) do - for _, tree in pairs(child:trees()) do - if tree_contains(tree, range) then - return tree - end + local tree = child:tree_for_range(range, opts) + if tree then + return tree end end end @@ -602,10 +1119,10 @@ end --- Gets the smallest named node that contains {range}. --- ----@param range table `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@param opts table|nil Optional keyword arguments: --- - ignore_injections boolean Ignore injected languages (default true) ----@return userdata|nil Found |tsnode| +---@return TSNode | nil Found node function LanguageTree:named_node_for_range(range, opts) local tree = self:tree_for_range(range, opts) if tree then @@ -615,7 +1132,7 @@ end --- Gets the appropriate language that contains {range}. 
--- ----@param range table `{ start_line, start_col, end_line, end_col }` +---@param range Range4 `{ start_line, start_col, end_line, end_col }` ---@return LanguageTree Managing {range} function LanguageTree:language_for_range(range) for _, child in pairs(self._children) do diff --git a/runtime/lua/vim/treesitter/playground.lua b/runtime/lua/vim/treesitter/playground.lua deleted file mode 100644 index bb073290c6..0000000000 --- a/runtime/lua/vim/treesitter/playground.lua +++ /dev/null @@ -1,186 +0,0 @@ -local api = vim.api - -local M = {} - ----@class Playground ----@field ns number API namespace ----@field opts table Options table with the following keys: ---- - anon (boolean): If true, display anonymous nodes ---- - lang (boolean): If true, display the language alongside each node ---- ----@class Node ----@field id number Node id ----@field text string Node text ----@field named boolean True if this is a named (non-anonymous) node ----@field depth number Depth of the node within the tree ----@field lnum number Beginning line number of this node in the source buffer ----@field col number Beginning column number of this node in the source buffer ----@field end_lnum number Final line number of this node in the source buffer ----@field end_col number Final column number of this node in the source buffer ----@field lang string Source language of this node - ---- Traverse all child nodes starting at {node}. ---- ---- This is a recursive function. The {depth} parameter indicates the current recursion level. ---- {lang} is a string indicating the language of the tree currently being traversed. Each traversed ---- node is added to {tree}. When recursion completes, {tree} is an array of all nodes in the order ---- they were visited. ---- ---- {injections} is a table mapping node ids from the primary tree to language tree injections. 
Each ---- injected language has a series of trees nested within the primary language's tree, and the root ---- node of each of these trees is contained within a node in the primary tree. The {injections} ---- table maps nodes in the primary tree to root nodes of injected trees. ---- ----@param node userdata Starting node to begin traversal |tsnode| ----@param depth number Current recursion depth ----@param lang string Language of the tree currently being traversed ----@param injections table Mapping of node ids to root nodes of injected language trees (see ---- explanation above) ----@param tree Node[] Output table containing a list of tables each representing a node in the tree ----@private -local function traverse(node, depth, lang, injections, tree) - local injection = injections[node:id()] - if injection then - traverse(injection.root, depth, injection.lang, injections, tree) - end - - for child, field in node:iter_children() do - local type = child:type() - local lnum, col, end_lnum, end_col = child:range() - local named = child:named() - local text - if named then - if field then - text = string.format('%s: (%s)', field, type) - else - text = string.format('(%s)', type) - end - else - text = string.format('"%s"', type:gsub('\n', '\\n')) - end - - table.insert(tree, { - id = child:id(), - text = text, - named = named, - depth = depth, - lnum = lnum, - col = col, - end_lnum = end_lnum, - end_col = end_col, - lang = lang, - }) - - traverse(child, depth + 1, lang, injections, tree) - end - - return tree -end - ---- Create a new Playground object. 
---- ----@param bufnr number Source buffer number ----@param lang string|nil Language of source buffer ---- ----@return Playground|nil ----@return string|nil Error message, if any ---- ----@private -function M.new(self, bufnr, lang) - local ok, parser = pcall(vim.treesitter.get_parser, bufnr or 0, lang) - if not ok then - return nil, 'No parser available for the given buffer' - end - - -- For each child tree (injected language), find the root of the tree and locate the node within - -- the primary tree that contains that root. Add a mapping from the node in the primary tree to - -- the root in the child tree to the {injections} table. - local root = parser:parse()[1]:root() - local injections = {} - parser:for_each_child(function(child, lang_) - child:for_each_tree(function(tree) - local r = tree:root() - local node = root:named_descendant_for_range(r:range()) - if node then - injections[node:id()] = { - lang = lang_, - root = r, - } - end - end) - end) - - local nodes = traverse(root, 0, parser:lang(), injections, {}) - - local named = {} - for _, v in ipairs(nodes) do - if v.named then - named[#named + 1] = v - end - end - - local t = { - ns = api.nvim_create_namespace(''), - nodes = nodes, - named = named, - opts = { - anon = false, - lang = false, - }, - } - - setmetatable(t, self) - self.__index = self - return t -end - ---- Write the contents of this Playground into {bufnr}. ---- ----@param bufnr number Buffer number to write into. 
----@private -function M.draw(self, bufnr) - vim.bo[bufnr].modifiable = true - local lines = {} - for _, item in self:iter() do - lines[#lines + 1] = table.concat({ - string.rep(' ', item.depth), - item.text, - item.lnum == item.end_lnum - and string.format(' [%d:%d-%d]', item.lnum + 1, item.col + 1, item.end_col) - or string.format( - ' [%d:%d-%d:%d]', - item.lnum + 1, - item.col + 1, - item.end_lnum + 1, - item.end_col - ), - self.opts.lang and string.format(' %s', item.lang) or '', - }) - end - api.nvim_buf_set_lines(bufnr, 0, -1, false, lines) - vim.bo[bufnr].modifiable = false -end - ---- Get node {i} from this Playground object. ---- ---- The node number is dependent on whether or not anonymous nodes are displayed. ---- ----@param i number Node number to get ----@return Node ----@private -function M.get(self, i) - local t = self.opts.anon and self.nodes or self.named - return t[i] -end - ---- Iterate over all of the nodes in this Playground object. ---- ----@return function Iterator over all nodes in this Playground ----@return table ----@return number ----@private -function M.iter(self) - return ipairs(self.opts.anon and self.nodes or self.named) -end - -return M diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index dbf134573d..8cbbffcd60 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,21 +1,25 @@ -local a = vim.api +local api = vim.api local language = require('vim.treesitter.language') --- query: pattern matching on trees --- predicate matching is implemented in lua --- ---@class Query ---@field captures string[] List of captures used in query ----@field info table Contains used queries, predicates, directives +---@field info TSQueryInfo Contains used queries, predicates, directives ---@field query userdata Parsed query local Query = {} Query.__index = Query +---@class TSQueryInfo +---@field captures table +---@field patterns table<string,any[][]> + +---@class 
TSQueryModule local M = {} ----@private +---@param files string[] +---@return string[] local function dedupe_files(files) local result = {} + ---@type table<string,boolean> local seen = {} for _, path in ipairs(files) do @@ -28,7 +32,6 @@ local function dedupe_files(files) return result end ----@private local function safe_read(filename, read_quantifier) local file, err = io.open(filename, 'r') if not file then @@ -39,7 +42,6 @@ local function safe_read(filename, read_quantifier) return content end ----@private --- Adds {ilang} to {base_langs}, only if {ilang} is different than {lang} --- ---@return boolean true If lang == ilang @@ -51,24 +53,34 @@ local function add_included_lang(base_langs, lang, ilang) return false end +---@deprecated +function M.get_query_files(...) + vim.deprecate( + 'vim.treesitter.query.get_query_files()', + 'vim.treesitter.query.get_files()', + '0.10' + ) + return M.get_files(...) +end + --- Gets the list of files used to make up a query --- ---@param lang string Language to get query for ---@param query_name string Name of the query to load (e.g., "highlights") ---@param is_included (boolean|nil) Internal parameter, most of the time left as `nil` ---@return string[] query_files List of files to load for given query and language -function M.get_query_files(lang, query_name, is_included) +function M.get_files(lang, query_name, is_included) local query_path = string.format('queries/%s/%s.scm', lang, query_name) - local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true)) + local lang_files = dedupe_files(api.nvim_get_runtime_file(query_path, true)) if #lang_files == 0 then return {} end - local base_query = nil + local base_query = nil ---@type string? 
local extensions = {} - local base_langs = {} + local base_langs = {} ---@type string[] -- Now get the base languages by looking at the first line of every file -- The syntax is the following : @@ -87,6 +99,7 @@ function M.get_query_files(lang, query_name, is_included) local extension = false for modeline in + ---@return string function() return file:read('*l') end @@ -97,6 +110,7 @@ function M.get_query_files(lang, query_name, is_included) local langlist = modeline:match(MODELINE_FORMAT) if langlist then + ---@diagnostic disable-next-line:param-type-mismatch for _, incllang in ipairs(vim.split(langlist, ',', true)) do local is_optional = incllang:match('%(.*%)') @@ -127,7 +141,7 @@ function M.get_query_files(lang, query_name, is_included) local query_files = {} for _, base_lang in ipairs(base_langs) do - local base_files = M.get_query_files(base_lang, query_name, true) + local base_files = M.get_files(base_lang, query_name, true) vim.list_extend(query_files, base_files) end vim.list_extend(query_files, { base_query }) @@ -136,7 +150,8 @@ function M.get_query_files(lang, query_name, is_included) return query_files end ----@private +---@param filenames string[] +---@return string local function read_query_files(filenames) local contents = {} @@ -147,7 +162,8 @@ local function read_query_files(filenames) return table.concat(contents, '') end ---- The explicitly set queries from |vim.treesitter.query.set_query()| +-- The explicitly set queries from |vim.treesitter.query.set()| +---@type table<string,table<string,Query>> local explicit_queries = setmetatable({}, { __index = function(t, k) local lang_queries = {} @@ -157,6 +173,12 @@ local explicit_queries = setmetatable({}, { end, }) +---@deprecated +function M.set_query(...) + vim.deprecate('vim.treesitter.query.set_query()', 'vim.treesitter.query.set()', '0.10') + M.set(...) 
+end + --- Sets the runtime query named {query_name} for {lang} --- --- This allows users to override any runtime files and/or configuration @@ -165,8 +187,14 @@ local explicit_queries = setmetatable({}, { ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g., "highlights") ---@param text string Query text (unparsed). -function M.set_query(lang, query_name, text) - explicit_queries[lang][query_name] = M.parse_query(lang, text) +function M.set(lang, query_name, text) + explicit_queries[lang][query_name] = M.parse(lang, text) +end + +---@deprecated +function M.get_query(...) + vim.deprecate('vim.treesitter.query.get_query()', 'vim.treesitter.query.get()', '0.10') + return M.get(...) end --- Returns the runtime query {query_name} for {lang}. @@ -174,24 +202,28 @@ end ---@param lang string Language to use for the query ---@param query_name string Name of the query (e.g. "highlights") --- ----@return Query Parsed query -function M.get_query(lang, query_name) +---@return Query|nil Parsed query +M.get = vim.func._memoize('concat-2', function(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] end - local query_files = M.get_query_files(lang, query_name) + local query_files = M.get_files(lang, query_name) local query_string = read_query_files(query_files) - if #query_string > 0 then - return M.parse_query(lang, query_string) + if #query_string == 0 then + return nil end -end -local query_cache = vim.defaulttable(function() - return setmetatable({}, { __mode = 'v' }) + return M.parse(lang, query_string) end) +---@deprecated +function M.parse_query(...) + vim.deprecate('vim.treesitter.query.parse_query()', 'vim.treesitter.query.parse()', '0.10') + return M.parse(...) +end + --- Parse {query} as a string. (If the query is in a file, the caller --- should read the contents into a string before calling). 
--- @@ -209,81 +241,50 @@ end) ---@param query string Query in s-expr syntax --- ---@return Query Parsed query -function M.parse_query(lang, query) - language.require_language(lang) - local cached = query_cache[lang][query] - if cached then - return cached - else - local self = setmetatable({}, Query) - self.query = vim._ts_parse_query(lang, query) - self.info = self.query:inspect() - self.captures = self.info.captures - query_cache[lang][query] = self - return self - end -end +M.parse = vim.func._memoize('concat-2', function(lang, query) + language.add(lang) + + local self = setmetatable({}, Query) + self.query = vim._ts_parse_query(lang, query) + self.info = self.query:inspect() + self.captures = self.info.captures + return self +end) ---- Gets the text corresponding to a given node ---- ----@param node userdata |tsnode| ----@param source (number|string) Buffer or string from which the {node} is extracted ----@param opts (table|nil) Optional parameters. ---- - concat: (boolean) Concatenate result in a string (default true) ----@return (string[]|string) -function M.get_node_text(node, source, opts) - opts = opts or {} - local concat = vim.F.if_nil(opts.concat, true) - - local start_row, start_col, start_byte = node:start() - local end_row, end_col, end_byte = node:end_() - - if type(source) == 'number' then - local lines - local eof_row = a.nvim_buf_line_count(source) - if start_row >= eof_row then - return nil - end +---@deprecated +function M.get_range(...) + vim.deprecate('vim.treesitter.query.get_range()', 'vim.treesitter.get_range()', '0.10') + return vim.treesitter.get_range(...) +end - if end_col == 0 then - lines = a.nvim_buf_get_lines(source, start_row, end_row, true) - end_col = -1 - else - lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true) - end +---@deprecated +function M.get_node_text(...) + vim.deprecate('vim.treesitter.query.get_node_text()', 'vim.treesitter.get_node_text()', '0.10') + return vim.treesitter.get_node_text(...) 
+end - if #lines > 0 then - if #lines == 1 then - lines[1] = string.sub(lines[1], start_col + 1, end_col) - else - lines[1] = string.sub(lines[1], start_col + 1) - lines[#lines] = string.sub(lines[#lines], 1, end_col) - end - end +---@alias TSMatch table<integer,TSNode> - return concat and table.concat(lines, '\n') or lines - elseif type(source) == 'string' then - return source:sub(start_byte + 1, end_byte) - end -end +---@alias TSPredicate fun(match: TSMatch, _, _, predicate: any[]): boolean -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) +---@type table<string,TSPredicate> local predicate_handlers = { ['eq?'] = function(match, _, source, predicate) local node = match[predicate[2]] if not node then return true end - local node_text = M.get_node_text(node, source) + local node_text = vim.treesitter.get_node_text(node, source) - local str + local str ---@type string if type(predicate[3]) == 'string' then -- (#eq? @aa "foo") str = predicate[3] else -- (#eq? 
@aa @bb) - str = M.get_node_text(match[predicate[3]], source) + str = vim.treesitter.get_node_text(match[predicate[3]], source) end if node_text ~= str or str == nil then @@ -299,12 +300,11 @@ local predicate_handlers = { return true end local regex = predicate[3] - return string.find(M.get_node_text(node, source), regex) + return string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil end, ['match?'] = (function() local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true } - ---@private local function check_magic(str) if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then return str @@ -321,12 +321,14 @@ local predicate_handlers = { }) return function(match, _, source, pred) + ---@cast match TSMatch local node = match[pred[2]] if not node then return true end + ---@diagnostic disable-next-line no-unknown local regex = compiled_vim_regexes[pred[3]] - return regex:match_str(M.get_node_text(node, source)) + return regex:match_str(vim.treesitter.get_node_text(node, source)) end end)(), @@ -335,7 +337,7 @@ local predicate_handlers = { if not node then return true end - local node_text = M.get_node_text(node, source) + local node_text = vim.treesitter.get_node_text(node, source) for i = 3, #predicate do if string.find(node_text, predicate[i], 1, true) then @@ -351,7 +353,7 @@ local predicate_handlers = { if not node then return true end - local node_text = M.get_node_text(node, source) + local node_text = vim.treesitter.get_node_text(node, source) -- Since 'predicate' will not be used by callers of this function, use it -- to store a string set built from the list of words to check against. 
@@ -359,6 +361,7 @@ local predicate_handlers = { if not string_set then string_set = {} for i = 3, #predicate do + ---@diagnostic disable-next-line:no-unknown string_set[predicate[i]] = true end predicate['string_set'] = string_set @@ -366,36 +369,85 @@ local predicate_handlers = { return string_set[node_text] end, + + ['has-ancestor?'] = function(match, _, _, predicate) + local node = match[predicate[2]] + if not node then + return true + end + + local ancestor_types = {} + for _, type in ipairs({ unpack(predicate, 3) }) do + ancestor_types[type] = true + end + + node = node:parent() + while node do + if ancestor_types[node:type()] then + return true + end + node = node:parent() + end + return false + end, + + ['has-parent?'] = function(match, _, _, predicate) + local node = match[predicate[2]] + if not node then + return true + end + + if vim.list_contains({ unpack(predicate, 3) }, node:parent():type()) then + return true + end + return false + end, } -- As we provide lua-match? also expose vim-match? predicate_handlers['vim-match?'] = predicate_handlers['match?'] +---@class TSMetadata +---@field range? Range +---@field conceal? string +---@field [integer] TSMetadata +---@field [string] integer|string + +---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata) + +-- Predicate handler receive the following arguments +-- (match, pattern, bufnr, predicate) + -- Directives store metadata or perform side effects against a match. -- Directives should always end with a `!`. -- Directive handler receive the following arguments -- (match, pattern, bufnr, predicate, metadata) +---@type table<string,TSDirective> local directive_handlers = { ['set!'] = function(_, _, _, pred, metadata) - if #pred == 4 then - -- (#set! @capture "key" "value") - local _, capture_id, key, value = unpack(pred) + if #pred >= 3 and type(pred[2]) == 'number' then + -- (#set! 
@capture key value) + local capture_id, key, value = pred[2], pred[3], pred[4] if not metadata[capture_id] then metadata[capture_id] = {} end metadata[capture_id][key] = value else - local _, key, value = unpack(pred) - -- (#set! "key" "value") - metadata[key] = value + -- (#set! key value) + local key, value = pred[2], pred[3] + metadata[key] = value or true end end, -- Shifts the range of a node. -- Example: (#offset! @_node 0 1 0 -1) ['offset!'] = function(match, _, _, pred, metadata) + ---@cast pred integer[] local capture_id = pred[2] - local offset_node = match[capture_id] - local range = { offset_node:range() } + if not metadata[capture_id] then + metadata[capture_id] = {} + end + + local range = metadata[capture_id].range or { match[capture_id]:range() } local start_row_offset = pred[3] or 0 local start_col_offset = pred[4] or 0 local end_row_offset = pred[5] or 0 @@ -408,19 +460,74 @@ local directive_handlers = { -- If this produces an invalid range, we just skip it. if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then - if not metadata[capture_id] then - metadata[capture_id] = {} - end metadata[capture_id].range = range end end, + -- Transform the content of the node + -- Example: (#gsub! @_node ".*%.(.*)" "%1") + ['gsub!'] = function(match, _, bufnr, pred, metadata) + assert(#pred == 4) + + local id = pred[2] + assert(type(id) == 'number') + + local node = match[id] + local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' + + if not metadata[id] then + metadata[id] = {} + end + + local pattern, replacement = pred[3], pred[4] + assert(type(pattern) == 'string') + assert(type(replacement) == 'string') + + metadata[id].text = text:gsub(pattern, replacement) + end, + -- Trim blank lines from end of the node + -- Example: (#trim! 
@fold) + -- TODO(clason): generalize to arbitrary whitespace removal + ['trim!'] = function(match, _, bufnr, pred, metadata) + local capture_id = pred[2] + assert(type(capture_id) == 'number') + + local node = match[capture_id] + if not node then + return + end + + local start_row, start_col, end_row, end_col = node:range() + + -- Don't trim if region ends in middle of a line + if end_col ~= 0 then + return + end + + while end_row >= start_row do + -- As we only care when end_col == 0, always inspect one line above end_row. + local end_line = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)[1] + + if end_line ~= '' then + break + end + + end_row = end_row - 1 + end + + -- If this produces an invalid range, we just skip it. + if start_row < end_row or (start_row == end_row and start_col <= end_col) then + metadata[capture_id] = metadata[capture_id] or {} + metadata[capture_id].range = { start_row, start_col, end_row, end_col } + end + end, } --- Adds a new predicate to be used in queries --- ---@param name string Name of the predicate, without leading # ----@param handler function(match:table, pattern:string, bufnr:number, predicate:string[]) +---@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[]) --- - see |vim.treesitter.query.add_directive()| for argument meanings +---@param force boolean|nil function M.add_predicate(name, handler, force) if predicate_handlers[name] and not force then error(string.format('Overriding %s', name)) @@ -437,12 +544,13 @@ end --- metadata table `metadata[capture_id].key = value` --- ---@param name string Name of the directive, without leading # ----@param handler function(match:table, pattern:string, bufnr:number, predicate:string[], metadata:table) +---@param handler function(match:table<string,TSNode>, pattern:string, bufnr:integer, predicate:string[], metadata:table) --- - match: see |treesitter-query| --- - node-level data are accessible via `match[capture_id]` --- - 
pattern: see |treesitter-query| --- - predicate: list of strings containing the full directive being called, e.g. --- `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }` +---@param force boolean|nil function M.add_directive(name, handler, force) if directive_handlers[name] and not force then error(string.format('Overriding %s', name)) @@ -463,17 +571,18 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end ----@private local function xor(x, y) return (x or y) and not (x and y) end ----@private local function is_directive(name) return string.sub(name, -1) == '!' end ---@private +---@param match TSMatch +---@param pattern string +---@param source integer|string function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] @@ -482,8 +591,9 @@ function Query:match_preds(match, pattern, source) -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). -- Also, tree-sitter strips the leading # from predicates for us. - local pred_name - local is_not + local pred_name ---@type string + + local is_not ---@type boolean -- Skip over directives... they will get processed after all the predicates. if not is_directive(pred[1]) then @@ -513,6 +623,8 @@ function Query:match_preds(match, pattern, source) end ---@private +---@param match TSMatch +---@param metadata TSMetadata function Query:apply_directives(match, pattern, source, metadata) local preds = self.info.patterns[pattern] @@ -533,7 +645,10 @@ end --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. 
----@private +---@param start integer +---@param stop integer +---@param node TSNode +---@return integer, integer local function value_or_node_range(start, stop, node) if start == nil and stop == nil then local node_start, _, node_stop, _ = node:range() @@ -547,42 +662,41 @@ end --- --- {source} is needed if the query contains predicates; then the caller --- must ensure to use a freshly parsed tree consistent with the current ---- text of the buffer (if relevant). {start_row} and {end_row} can be used to limit +--- text of the buffer (if relevant). {start} and {stop} can be used to limit --- matches inside a row range (this is typically used with root node --- as the {node}, i.e., to get syntax highlight matches in the current ---- viewport). When omitted, the {start} and {end} row values are used from the given node. +--- viewport). When omitted, the {start} and {stop} row values are used from the given node. --- --- The iterator returns three values: a numeric id identifying the capture, --- the captured node, and metadata from any directives processing the match. --- The following example shows how to get captures by name: ---- <pre>lua +--- +--- ```lua --- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do --- local name = query.captures[id] -- name of the capture in the query --- -- typically useful info about the node: --- local type = node:type() -- type of the captured node --- local row1, col1, row2, col2 = node:range() -- range of the capture ---- ... use the info here ... +--- -- ... use the info here ... 
--- end ---- </pre> +--- ``` --- ----@param node userdata |tsnode| under which the search will occur ----@param source (number|string) Source buffer or string to extract text from ----@param start number Starting line for the search ----@param stop number Stopping line for the search (end-exclusive) +---@param node TSNode under which the search will occur +---@param source (integer|string) Source buffer or string to extract text from +---@param start integer Starting line for the search +---@param stop integer Stopping line for the search (end-exclusive) --- ----@return number capture Matching capture id ----@return table capture_node Capture for {node} ----@return table metadata for the {capture} +---@return (fun(end_line: integer|nil): integer, TSNode, TSMetadata): +--- capture id, capture node, metadata function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then - source = vim.api.nvim_get_current_buf() + source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) local raw_iter = node:_rawquery(self.query, true, start, stop) - ---@private - local function iter() + local function iter(end_line) local capture, captured_node, match = raw_iter() local metadata = {} @@ -590,7 +704,10 @@ function Query:iter_captures(node, source, start, stop) local active = self:match_preds(match, match.pattern, source) match.active = active if not active then - return iter() -- tail call: try next match + if end_line and captured_node:range() > end_line then + return nil, captured_node, nil + end + return iter(end_line) -- tail call: try next match end self:apply_directives(match, match.pattern, source, metadata) @@ -609,7 +726,8 @@ end --- If the query has more than one pattern, the capture table might be sparse --- and e.g. `pairs()` method should be used over `ipairs`. 
--- Here is an example iterating over all captures in every match: ---- <pre>lua +--- +--- ```lua --- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do --- for id, node in pairs(match) do --- local name = query.captures[id] @@ -617,27 +735,30 @@ end --- --- local node_data = metadata[id] -- Node level metadata --- ---- ... use the info here ... +--- -- ... use the info here ... --- end --- end ---- </pre> +--- ``` --- ----@param node userdata |tsnode| under which the search will occur ----@param source (number|string) Source buffer or string to search ----@param start number Starting line for the search ----@param stop number Stopping line for the search (end-exclusive) +---@param node TSNode under which the search will occur +---@param source (integer|string) Source buffer or string to search +---@param start integer Starting line for the search +---@param stop integer Stopping line for the search (end-exclusive) +---@param opts table|nil Options: +--- - max_start_depth (integer) if non-zero, sets the maximum start depth +--- for each match. This is used to prevent traversing too deep into a tree. +--- Requires treesitter >= 0.20.9. 
--- ----@return number pattern id ----@return table match ----@return table metadata -function Query:iter_matches(node, source, start, stop) +---@return (fun(): integer, table<integer,TSNode>, table): pattern id, match, metadata +function Query:iter_matches(node, source, start, stop, opts) if type(source) == 'number' and source == 0 then - source = vim.api.nvim_get_current_buf() + source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop) + local raw_iter = node:_rawquery(self.query, false, start, stop, opts) + ---@cast raw_iter fun(): string, any local function iter() local pattern, match = raw_iter() local metadata = {} @@ -655,4 +776,58 @@ function Query:iter_matches(node, source, start, stop) return iter end +---@class QueryLinterOpts +---@field langs (string|string[]|nil) +---@field clear (boolean) + +--- Lint treesitter queries using installed parser, or clear lint errors. +--- +--- Use |treesitter-parsers| in runtimepath to check the query file in {buf} for errors: +--- +--- - verify that used nodes are valid identifiers in the grammar. +--- - verify that predicates and directives are valid. +--- - verify that top-level s-expressions are valid. +--- +--- The found diagnostics are reported using |diagnostic-api|. +--- By default, the parser used for verification is determined by the containing folder +--- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the +--- `lua` language will be used. +---@param buf (integer) Buffer handle +---@param opts (QueryLinterOpts|nil) Optional keyword arguments: +--- - langs (string|string[]|nil) Language(s) to use for checking the query. 
+--- If multiple languages are specified, queries are validated for all of them +--- - clear (boolean) if `true`, just clear current lint errors +function M.lint(buf, opts) + if opts and opts.clear then + require('vim.treesitter._query_linter').clear(buf) + else + require('vim.treesitter._query_linter').lint(buf, opts) + end +end + +--- Omnifunc for completing node names and predicates in treesitter queries. +--- +--- Use via +--- +--- ```lua +--- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc' +--- ``` +--- +function M.omnifunc(findstart, base) + return require('vim.treesitter._query_linter').omnifunc(findstart, base) +end + +--- Opens a live editor to query the buffer you started from. +--- +--- Can also be shown with *:EditQuery*. +--- +--- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in +--- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find +--- example queries at `$VIMRUNTIME/queries/`. +--- +--- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype. +function M.edit(lang) + require('vim.treesitter.dev').edit_query(lang) +end + return M |