diff options
Diffstat (limited to 'runtime/lua')
-rw-r--r-- | runtime/lua/vim/_editor.lua | 39 | ||||
-rw-r--r-- | runtime/lua/vim/lsp.lua | 20 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/buf.lua | 13 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/protocol.lua | 174 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_fold.lua | 18 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_range.lua | 29 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 64 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 99 |
8 files changed, 147 insertions, 309 deletions
diff --git a/runtime/lua/vim/_editor.lua b/runtime/lua/vim/_editor.lua index 9516233b45..5445c4e492 100644 --- a/runtime/lua/vim/_editor.lua +++ b/runtime/lua/vim/_editor.lua @@ -778,22 +778,37 @@ do end end ----Prints given arguments in human-readable format. ----Example: ----<pre>lua ---- -- Print highlight group Normal and store it's contents in a variable. ---- local hl_normal = vim.pretty_print(vim.api.nvim_get_hl_by_name("Normal", true)) ----</pre> ----@see |vim.inspect()| ----@return any # given arguments. +---@private function vim.pretty_print(...) - local objects = {} + vim.deprecate('vim.pretty_print', 'vim.print', '0.10') + return vim.print(...) +end + +--- "Pretty prints" the given arguments and returns them unmodified. +--- +--- Example: +--- <pre>lua +--- local hl_normal = vim.print(vim.api.nvim_get_hl_by_name('Normal', true)) +--- </pre> +--- +--- @see |vim.inspect()| +--- @return any # given arguments. +function vim.print(...) + if vim.in_fast_event() then + print(...) + return ... + end + for i = 1, select('#', ...) do - local v = select(i, ...) - table.insert(objects, vim.inspect(v)) + local o = select(i, ...) + if type(o) == 'string' then + vim.api.nvim_out_write(o) + else + vim.api.nvim_out_write(vim.inspect(o, { newline = '\n', indent = ' ' })) + end + vim.api.nvim_out_write('\n') end - print(table.concat(objects, ' ')) return ... end diff --git a/runtime/lua/vim/lsp.lua b/runtime/lua/vim/lsp.lua index 117b32dc57..39665a3d4f 100644 --- a/runtime/lua/vim/lsp.lua +++ b/runtime/lua/vim/lsp.lua @@ -1,3 +1,4 @@ +---@diagnostic disable: invisible local default_handlers = require('vim.lsp.handlers') local log = require('vim.lsp.log') local lsp_rpc = require('vim.lsp.rpc') @@ -1037,7 +1038,7 @@ function lsp.start_client(config) --- Returns the default handler if the user hasn't set a custom one. --- ---@param method (string) LSP method name - ---@return function|nil The handler for the given method, if defined, or the default from |vim.lsp.handlers| + ---@return lsp-handler|nil The handler for the given method, if defined, or the default from |vim.lsp.handlers| local function resolve_handler(method) return handlers[method] or default_handlers[method] end @@ -1592,6 +1593,11 @@ local function text_document_did_save_handler(bufnr) local name = api.nvim_buf_get_name(bufnr) local old_name = changetracking._get_and_set_name(client, bufnr, name) if old_name and name ~= old_name then + client.notify('textDocument/didClose', { + textDocument = { + uri = vim.uri_from_fname(old_name), + }, + }) client.notify('textDocument/didOpen', { textDocument = { version = 0, @@ -1932,7 +1938,7 @@ api.nvim_create_autocmd('VimLeavePre', { ---@param bufnr (integer) Buffer handle, or 0 for current. ---@param method (string) LSP method name ---@param params table|nil Parameters to send to the server ----@param handler function|nil See |lsp-handler| +---@param handler lsp-handler|nil See |lsp-handler| --- If nil, follows resolution strategy defined in |lsp-handler-configuration| --- ---@return table<integer, integer>, fun() 2-tuple: @@ -1992,9 +1998,10 @@ end ---@param bufnr (integer) Buffer handle, or 0 for current. ---@param method (string) LSP method name ---@param params (table|nil) Parameters to send to the server ----@param callback (function) The callback to call when all requests are finished. +---@param callback fun(request_results: table<integer, {error: lsp.ResponseError, result: any}>) (function) +--- The callback to call when all requests are finished. --- Unlike `buf_request`, this will collect all the responses from each server instead of handling them. ---- A map of client_id:request_result will be provided to the callback +--- A map of client_id:request_result will be provided to the callback. --- ---@return fun() cancel A function that will cancel all requests function lsp.buf_request_all(bufnr, method, params, callback) @@ -2037,9 +2044,8 @@ end ---@param timeout_ms (integer|nil) Maximum time in milliseconds to wait for a --- result. Defaults to 1000 --- ----@return table<integer, any>|nil result, string|nil err Map of client_id:request_result. ---- On timeout, cancel or error, returns `(nil, err)` where `err` is a string describing ---- the failure reason. +---@return table<integer, {err: lsp.ResponseError, result: any}>|nil (table) result Map of client_id:request_result. +---@return string|nil err On timeout, cancel, or error, `err` is a string describing the failure reason, and `result` is nil. function lsp.buf_request_sync(bufnr, method, params, timeout_ms) local request_results diff --git a/runtime/lua/vim/lsp/buf.lua b/runtime/lua/vim/lsp/buf.lua index 6ac885c78f..0e16e8f820 100644 --- a/runtime/lua/vim/lsp/buf.lua +++ b/runtime/lua/vim/lsp/buf.lua @@ -118,8 +118,10 @@ function M.completion(context) end ---@private +---@param bufnr integer +---@param mode "v"|"V" ---@return table {start={row, col}, end={row, col}} using (1, 0) indexing -local function range_from_selection() +local function range_from_selection(bufnr, mode) -- TODO: Use `vim.region()` instead https://github.com/neovim/neovim/pull/13896 -- [bufnum, lnum, col, off]; both row and column 1-indexed @@ -138,6 +140,11 @@ local function range_from_selection() start_row, end_row = end_row, start_row start_col, end_col = end_col, start_col end + if mode == 'V' then + start_col = 1 + local lines = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true) + end_col = #lines[1] + end return { ['start'] = { start_row, start_col - 1 }, ['end'] = { end_row, end_col - 1 }, @@ -200,7 +207,7 @@ function M.format(options) local mode = api.nvim_get_mode().mode local range = options.range if not range and mode == 'v' or mode == 'V' then - range = range_from_selection() + range = range_from_selection(bufnr, mode) end local method = range and 'textDocument/rangeFormatting' or 'textDocument/formatting' @@ -772,7 +779,7 @@ function M.code_action(options) local end_ = assert(options.range['end'], 'range must have a `end` property') params = util.make_given_range_params(start, end_) elseif mode == 'v' or mode == 'V' then - local range = range_from_selection() + local range = range_from_selection(0, mode) params = util.make_given_range_params(range.start, range['end']) else params = util.make_range_params() diff --git a/runtime/lua/vim/lsp/protocol.lua b/runtime/lua/vim/lsp/protocol.lua index 27dd68645a..1686e22c48 100644 --- a/runtime/lua/vim/lsp/protocol.lua +++ b/runtime/lua/vim/lsp/protocol.lua @@ -854,7 +854,6 @@ function protocol.make_client_capabilities() } end -local if_nil = vim.F.if_nil --- Creates a normalized object describing LSP server capabilities. ---@param server_capabilities table Table of capabilities supported by the server ---@return table Normalized table of capabilities @@ -892,178 +891,5 @@ function protocol.resolve_capabilities(server_capabilities) return server_capabilities end ----@private ---- Creates a normalized object describing LSP server capabilities. --- @deprecated access resolved_capabilities instead ----@param server_capabilities table Table of capabilities supported by the server ----@return table Normalized table of capabilities -function protocol._resolve_capabilities_compat(server_capabilities) - local general_properties = {} - local text_document_sync_properties - do - local TextDocumentSyncKind = protocol.TextDocumentSyncKind - local textDocumentSync = server_capabilities.textDocumentSync - if textDocumentSync == nil then - -- Defaults if omitted. - text_document_sync_properties = { - text_document_open_close = false, - text_document_did_change = TextDocumentSyncKind.None, - -- text_document_did_change = false; - text_document_will_save = false, - text_document_will_save_wait_until = false, - text_document_save = false, - text_document_save_include_text = false, - } - elseif type(textDocumentSync) == 'number' then - -- Backwards compatibility - if not TextDocumentSyncKind[textDocumentSync] then - return nil, 'Invalid server TextDocumentSyncKind for textDocumentSync' - end - text_document_sync_properties = { - text_document_open_close = true, - text_document_did_change = textDocumentSync, - text_document_will_save = false, - text_document_will_save_wait_until = false, - text_document_save = true, - text_document_save_include_text = false, - } - elseif type(textDocumentSync) == 'table' then - text_document_sync_properties = { - text_document_open_close = if_nil(textDocumentSync.openClose, false), - text_document_did_change = if_nil(textDocumentSync.change, TextDocumentSyncKind.None), - text_document_will_save = if_nil(textDocumentSync.willSave, true), - text_document_will_save_wait_until = if_nil(textDocumentSync.willSaveWaitUntil, true), - text_document_save = if_nil(textDocumentSync.save, false), - text_document_save_include_text = if_nil( - type(textDocumentSync.save) == 'table' and textDocumentSync.save.includeText, - false - ), - } - else - return nil, string.format('Invalid type for textDocumentSync: %q', type(textDocumentSync)) - end - end - general_properties.completion = server_capabilities.completionProvider ~= nil - general_properties.hover = server_capabilities.hoverProvider or false - general_properties.goto_definition = server_capabilities.definitionProvider or false - general_properties.find_references = server_capabilities.referencesProvider or false - general_properties.document_highlight = server_capabilities.documentHighlightProvider or false - general_properties.document_symbol = server_capabilities.documentSymbolProvider or false - general_properties.workspace_symbol = server_capabilities.workspaceSymbolProvider or false - general_properties.document_formatting = server_capabilities.documentFormattingProvider or false - general_properties.document_range_formatting = server_capabilities.documentRangeFormattingProvider - or false - general_properties.call_hierarchy = server_capabilities.callHierarchyProvider or false - general_properties.execute_command = server_capabilities.executeCommandProvider ~= nil - - if server_capabilities.renameProvider == nil then - general_properties.rename = false - elseif type(server_capabilities.renameProvider) == 'boolean' then - general_properties.rename = server_capabilities.renameProvider - else - general_properties.rename = true - end - - if server_capabilities.codeLensProvider == nil then - general_properties.code_lens = false - general_properties.code_lens_resolve = false - elseif type(server_capabilities.codeLensProvider) == 'table' then - general_properties.code_lens = true - general_properties.code_lens_resolve = server_capabilities.codeLensProvider.resolveProvider - or false - else - error('The server sent invalid codeLensProvider') - end - - if server_capabilities.codeActionProvider == nil then - general_properties.code_action = false - elseif - type(server_capabilities.codeActionProvider) == 'boolean' - or type(server_capabilities.codeActionProvider) == 'table' - then - general_properties.code_action = server_capabilities.codeActionProvider - else - error('The server sent invalid codeActionProvider') - end - - if server_capabilities.declarationProvider == nil then - general_properties.declaration = false - elseif type(server_capabilities.declarationProvider) == 'boolean' then - general_properties.declaration = server_capabilities.declarationProvider - elseif type(server_capabilities.declarationProvider) == 'table' then - general_properties.declaration = server_capabilities.declarationProvider - else - error('The server sent invalid declarationProvider') - end - - if server_capabilities.typeDefinitionProvider == nil then - general_properties.type_definition = false - elseif type(server_capabilities.typeDefinitionProvider) == 'boolean' then - general_properties.type_definition = server_capabilities.typeDefinitionProvider - elseif type(server_capabilities.typeDefinitionProvider) == 'table' then - general_properties.type_definition = server_capabilities.typeDefinitionProvider - else - error('The server sent invalid typeDefinitionProvider') - end - - if server_capabilities.implementationProvider == nil then - general_properties.implementation = false - elseif type(server_capabilities.implementationProvider) == 'boolean' then - general_properties.implementation = server_capabilities.implementationProvider - elseif type(server_capabilities.implementationProvider) == 'table' then - general_properties.implementation = server_capabilities.implementationProvider - else - error('The server sent invalid implementationProvider') - end - - local workspace = server_capabilities.workspace - local workspace_properties = {} - if workspace == nil or workspace.workspaceFolders == nil then - -- Defaults if omitted. - workspace_properties = { - workspace_folder_properties = { - supported = false, - changeNotifications = false, - }, - } - elseif type(workspace.workspaceFolders) == 'table' then - workspace_properties = { - workspace_folder_properties = { - supported = if_nil(workspace.workspaceFolders.supported, false), - changeNotifications = if_nil(workspace.workspaceFolders.changeNotifications, false), - }, - } - else - error('The server sent invalid workspace') - end - - local signature_help_properties - if server_capabilities.signatureHelpProvider == nil then - signature_help_properties = { - signature_help = false, - signature_help_trigger_characters = {}, - } - elseif type(server_capabilities.signatureHelpProvider) == 'table' then - signature_help_properties = { - signature_help = true, - -- The characters that trigger signature help automatically. - signature_help_trigger_characters = server_capabilities.signatureHelpProvider.triggerCharacters - or {}, - } - else - error('The server sent invalid signatureHelpProvider') - end - - local capabilities = vim.tbl_extend( - 'error', - text_document_sync_properties, - signature_help_properties, - workspace_properties, - general_properties - ) - - return capabilities -end - return protocol -- vim:sw=2 ts=2 et diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua index 435cb9fdb6..90f4394fcc 100644 --- a/runtime/lua/vim/treesitter/_fold.lua +++ b/runtime/lua/vim/treesitter/_fold.lua @@ -1,4 +1,5 @@ local Range = require('vim.treesitter._range') +local Query = require('vim.treesitter.query') local api = vim.api @@ -74,18 +75,6 @@ function FoldInfo:get_stop(lnum) return self.stop_counts[lnum] or 0 end ----@private ---- TODO(lewis6991): copied from languagetree.lua. Consolidate ----@param node TSNode ----@param metadata TSMetadata ----@return Range4 -local function get_range_from_metadata(node, metadata) - if metadata and metadata.range then - return metadata.range --[[@as Range4]] - end - return { node:range() } -end - local function trim_level(level) local max_fold_level = vim.wo.foldnestmax if level > max_fold_level then @@ -118,7 +107,7 @@ local function get_folds_levels(bufnr, info, srow, erow) for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do if query.captures[id] == 'fold' then - local range = get_range_from_metadata(node, metadata[id]) + local range = Query.get_range(node, bufnr, metadata[id]) local start, _, stop, stop_col = Range.unpack4(range) if stop_col == 0 then @@ -211,8 +200,9 @@ end local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row) local end_row_old = start_row + old_row local end_row_new = start_row + new_row + if new_row < old_row then - foldinfo:remove_range(end_row_old, end_row_new) + foldinfo:remove_range(end_row_new, end_row_old) elseif new_row > old_row then foldinfo:add_range(start_row, end_row_new) vim.schedule(function() diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua index 02918da23f..f4db5016ac 100644 --- a/runtime/lua/vim/treesitter/_range.lua +++ b/runtime/lua/vim/treesitter/_range.lua @@ -2,8 +2,21 @@ local api = vim.api local M = {} ----@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer} ----@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer} +---@class Range4 +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer end row +---@field [4] integer end column + +---@class Range6 +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer start bytes +---@field [4] integer end row +---@field [5] integer end column +---@field [6] integer end bytes + +---@alias Range Range4|Range6 ---@private ---@param a_row integer @@ -74,8 +87,8 @@ function M.validate(r) end ---@private ----@param r1 Range4|Range6 ----@param r2 Range4|Range6 +---@param r1 Range +---@param r2 Range ---@return boolean function M.intercepts(r1, r2) local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) @@ -95,7 +108,7 @@ function M.intercepts(r1, r2) end ---@private ----@param r Range4|Range6 +---@param r Range ---@return integer, integer, integer, integer function M.unpack4(r) local off_1 = #r == 6 and 1 or 0 @@ -110,8 +123,8 @@ function M.unpack6(r) end ---@private ----@param r1 Range4|Range6 ----@param r2 Range4|Range6 +---@param r1 Range +---@param r2 Range ---@return boolean whether r1 contains r2 function M.contains(r1, r2) local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) @@ -132,7 +145,7 @@ end ---@private ---@param source integer|string ----@param range Range4|Range6 +---@param range Range ---@return Range6 function M.add_bytes(source, range) if type(range) == 'table' and #range == 6 then diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 26321cd1f4..bdfe281a5b 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -448,7 +448,7 @@ end --- nodes, which is useful for templating languages like ERB and EJS. --- ---@private ----@param new_regions Range4[][] List of regions this tree should manage and parse. +---@param new_regions Range6[][] List of regions this tree should manage and parse. function LanguageTree:set_included_regions(new_regions) -- Transform the tables from 4 element long to 6 element long (with byte offset) for _, region in ipairs(new_regions) do @@ -478,24 +478,11 @@ end ---@private ---@param node TSNode ----@param source integer|string ----@param metadata TSMetadata ----@return Range6 -local function get_range_from_metadata(node, source, metadata) - if metadata and metadata.range then - return Range.add_bytes(source, metadata.range --[[@as Range4|Range6]]) - end - return { node:range(true) } -end - ----@private ---- TODO(lewis6991): cleanup of the node_range interface ----@param node TSNode ---@param source string|integer ---@param metadata TSMetadata ---@return Range6[] local function get_node_ranges(node, source, metadata, include_children) - local range = get_range_from_metadata(node, source, metadata) + local range = query.get_range(node, source, metadata) if include_children then return { range } @@ -535,7 +522,7 @@ end ---@param pattern integer ---@param lang string ---@param combined boolean ----@param ranges Range4[] +---@param ranges Range6[] local function add_injection(t, tree_index, pattern, lang, combined, ranges) assert(type(lang) == 'string') @@ -559,30 +546,15 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges) end ---@private ----Get node text ---- ----Note: `query.get_node_text` returns string|string[]|nil so use this simple alias function ----to annotate it returns string. ---- ----TODO(lewis6991): use [at]overload annotations on `query.get_node_text` ----@param node TSNode ----@param source integer|string ----@param metadata table ----@return string -local function get_node_text(node, source, metadata) - return query.get_node_text(node, source, { metadata = metadata }) --[[@as string]] -end - ----@private --- Extract injections according to: --- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection ---@param match table<integer,TSNode> ----@param metadata table ----@return string, boolean, Range4[] +---@param metadata TSMetadata +---@return string?, boolean, Range6[] function LanguageTree:_get_injection(match, metadata) - local ranges = {} ---@type Range4[] + local ranges = {} ---@type Range6[] local combined = metadata['injection.combined'] ~= nil - local lang = metadata['injection.language'] ---@type string + local lang = metadata['injection.language'] --[[@as string?]] local include_children = metadata['injection.include-children'] ~= nil for id, node in pairs(match) do @@ -590,7 +562,7 @@ function LanguageTree:_get_injection(match, metadata) -- Lang should override any other language tag if name == 'injection.language' then - lang = get_node_text(node, self._source, metadata[id]) + lang = query.get_node_text(node, self._source, { metadata = metadata[id] }) elseif name == 'injection.content' then ranges = get_node_ranges(node, self._source, metadata[id], include_children) end @@ -601,11 +573,11 @@ end ---@private ---@param match table<integer,TSNode> ----@param metadata table ----@return string, boolean, Range4[] +---@param metadata TSMetadata +---@return string, boolean, Range6[] function LanguageTree:_get_injection_deprecated(match, metadata) local lang = nil ---@type string - local ranges = {} ---@type Range4[] + local ranges = {} ---@type Range6[] local combined = metadata.combined ~= nil -- Directives can configure how injections are captured as well as actual node captures. @@ -623,8 +595,10 @@ function LanguageTree:_get_injection_deprecated(match, metadata) end end - if metadata.language then - lang = metadata.language ---@type string + local mlang = metadata.language + if mlang ~= nil then + assert(type(mlang) == 'string') + lang = mlang end -- You can specify the content and language together @@ -635,11 +609,11 @@ function LanguageTree:_get_injection_deprecated(match, metadata) -- Lang should override any other language tag if name == 'language' and not lang then - lang = get_node_text(node, self._source, metadata[id]) + lang = query.get_node_text(node, self._source, { metadata = metadata[id] }) elseif name == 'combined' then combined = true elseif name == 'content' and #ranges == 0 then - ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id]) + ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id]) -- Ignore any tags that start with "_" -- Allows for other tags to be used in matches elseif string.sub(name, 1, 1) ~= '_' then @@ -648,7 +622,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata) end if #ranges == 0 then - ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id]) + ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id]) end end end @@ -926,7 +900,7 @@ end ---@private ---@param tree TSTree ----@param range Range4 +---@param range Range ---@return boolean local function tree_contains(tree, range) return Range.contains({ tree:root():range() }, range) diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index e7cf42283d..f4e038b2d8 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,6 +1,8 @@ local a = vim.api local language = require('vim.treesitter.language') +local Range = require('vim.treesitter._range') + ---@class Query ---@field captures string[] List of captures used in query ---@field info TSQueryInfo Contains used queries, predicates, directives @@ -56,35 +58,21 @@ local function add_included_lang(base_langs, lang, ilang) end ---@private ----@param buf (integer) ----@param range (table) ----@param concat (boolean) ----@returns (string[]|string|nil) -local function buf_range_get_text(buf, range, concat) - local lines - local start_row, start_col, end_row, end_col = unpack(range) - local eof_row = a.nvim_buf_line_count(buf) - if start_row >= eof_row then - return nil - end - +---@param buf integer +---@param range Range +---@returns string +local function buf_range_get_text(buf, range) + local start_row, start_col, end_row, end_col = Range.unpack4(range) if end_col == 0 then - lines = a.nvim_buf_get_lines(buf, start_row, end_row, true) - end_col = -1 - else - lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true) - end - - if #lines > 0 then - if #lines == 1 then - lines[1] = string.sub(lines[1], start_col + 1, end_col) - else - lines[1] = string.sub(lines[1], start_col + 1) - lines[#lines] = string.sub(lines[#lines], 1, end_col) + if start_row == end_row then + start_col = -1 + start_row = start_row - 1 end + end_col = -1 + end_row = end_row - 1 end - - return concat and table.concat(lines, '\n') or lines + local lines = a.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) + return table.concat(lines, '\n') end --- Gets the list of files used to make up a query @@ -256,14 +244,28 @@ function M.parse_query(lang, query) local cached = query_cache[lang][query] if cached then return cached - else - local self = setmetatable({}, Query) - self.query = vim._ts_parse_query(lang, query) - self.info = self.query:inspect() - self.captures = self.info.captures - query_cache[lang][query] = self - return self end + + local self = setmetatable({}, Query) + self.query = vim._ts_parse_query(lang, query) + self.info = self.query:inspect() + self.captures = self.info.captures + query_cache[lang][query] = self + return self +end + +---Get the range of a |TSNode|. Can also supply {source} and {metadata} +---to get the range with directives applied. +---@param node TSNode +---@param source integer|string|nil Buffer or string from which the {node} is extracted +---@param metadata TSMetadata|nil +---@return Range6 +function M.get_range(node, source, metadata) + if metadata and metadata.range then + assert(source) + return Range.add_bytes(source, metadata.range) + end + return { node:range(true) } end --- Gets the text corresponding to a given node @@ -271,24 +273,22 @@ end ---@param node TSNode ---@param source (integer|string) Buffer or string from which the {node} is extracted ---@param opts (table|nil) Optional parameters. ---- - concat: (boolean) Concatenate result in a string (default true) --- - metadata (table) Metadata of a specific capture. This would be --- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|. ----@return (string[]|string|nil) +---@return string function M.get_node_text(node, source, opts) opts = opts or {} - -- TODO(lewis6991): concat only works when source is number. - local concat = vim.F.if_nil(opts.concat, true) local metadata = opts.metadata or {} if metadata.text then return metadata.text elseif type(source) == 'number' then - return metadata.range and buf_range_get_text(source, metadata.range, concat) - or buf_range_get_text(source, { node:range() }, concat) - elseif type(source) == 'string' then - return source:sub(select(3, node:start()) + 1, select(3, node:end_())) + local range = M.get_range(node, source, metadata) + return buf_range_get_text(source, range) end + + ---@cast source string + return source:sub(select(3, node:start()) + 1, select(3, node:end_())) end ---@alias TSMatch table<integer,TSNode> @@ -312,7 +312,7 @@ local predicate_handlers = { str = predicate[3] else -- (#eq? @aa @bb) - str = M.get_node_text(match[predicate[3]], source) --[[@as string]] + str = M.get_node_text(match[predicate[3]], source) end if node_text ~= str or str == nil then @@ -328,7 +328,7 @@ local predicate_handlers = { return true end local regex = predicate[3] - return string.find(M.get_node_text(node, source) --[[@as string]], regex) ~= nil + return string.find(M.get_node_text(node, source), regex) ~= nil end, ['match?'] = (function() @@ -366,7 +366,7 @@ local predicate_handlers = { if not node then return true end - local node_text = M.get_node_text(node, source) --[[@as string]] + local node_text = M.get_node_text(node, source) for i = 3, #predicate do if string.find(node_text, predicate[i], 1, true) then @@ -404,9 +404,9 @@ local predicate_handlers = { predicate_handlers['vim-match?'] = predicate_handlers['match?'] ---@class TSMetadata +---@field range Range ---@field [integer] TSMetadata ---@field [string] integer|string ----@field range Range4 ---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata) @@ -465,13 +465,20 @@ local directive_handlers = { assert(#pred == 4) local id = pred[2] + assert(type(id) == 'number') + local node = match[id] local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or '' if not metadata[id] then metadata[id] = {} end - metadata[id].text = text:gsub(pred[3], pred[4]) + + local pattern, replacement = pred[3], pred[3] + assert(type(pattern) == 'string') + assert(type(replacement) == 'string') + + metadata[id].text = text:gsub(pattern, replacement) end, } |