aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua')
-rw-r--r--runtime/lua/vim/_editor.lua39
-rw-r--r--runtime/lua/vim/lsp.lua20
-rw-r--r--runtime/lua/vim/lsp/buf.lua13
-rw-r--r--runtime/lua/vim/lsp/protocol.lua174
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua18
-rw-r--r--runtime/lua/vim/treesitter/_range.lua29
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua64
-rw-r--r--runtime/lua/vim/treesitter/query.lua99
8 files changed, 147 insertions, 309 deletions
diff --git a/runtime/lua/vim/_editor.lua b/runtime/lua/vim/_editor.lua
index 9516233b45..5445c4e492 100644
--- a/runtime/lua/vim/_editor.lua
+++ b/runtime/lua/vim/_editor.lua
@@ -778,22 +778,37 @@ do
end
end
----Prints given arguments in human-readable format.
----Example:
----<pre>lua
---- -- Print highlight group Normal and store it's contents in a variable.
---- local hl_normal = vim.pretty_print(vim.api.nvim_get_hl_by_name("Normal", true))
----</pre>
----@see |vim.inspect()|
----@return any # given arguments.
+---@private
function vim.pretty_print(...)
- local objects = {}
+ vim.deprecate('vim.pretty_print', 'vim.print', '0.10')
+ return vim.print(...)
+end
+
+--- "Pretty prints" the given arguments and returns them unmodified.
+---
+--- Example:
+--- <pre>lua
+--- local hl_normal = vim.print(vim.api.nvim_get_hl_by_name('Normal', true))
+--- </pre>
+---
+--- @see |vim.inspect()|
+--- @return any # given arguments.
+function vim.print(...)
+ if vim.in_fast_event() then
+ print(...)
+ return ...
+ end
+
for i = 1, select('#', ...) do
- local v = select(i, ...)
- table.insert(objects, vim.inspect(v))
+ local o = select(i, ...)
+ if type(o) == 'string' then
+ vim.api.nvim_out_write(o)
+ else
+ vim.api.nvim_out_write(vim.inspect(o, { newline = '\n', indent = ' ' }))
+ end
+ vim.api.nvim_out_write('\n')
end
- print(table.concat(objects, ' '))
return ...
end
diff --git a/runtime/lua/vim/lsp.lua b/runtime/lua/vim/lsp.lua
index 117b32dc57..39665a3d4f 100644
--- a/runtime/lua/vim/lsp.lua
+++ b/runtime/lua/vim/lsp.lua
@@ -1,3 +1,4 @@
+---@diagnostic disable: invisible
local default_handlers = require('vim.lsp.handlers')
local log = require('vim.lsp.log')
local lsp_rpc = require('vim.lsp.rpc')
@@ -1037,7 +1038,7 @@ function lsp.start_client(config)
--- Returns the default handler if the user hasn't set a custom one.
---
---@param method (string) LSP method name
- ---@return function|nil The handler for the given method, if defined, or the default from |vim.lsp.handlers|
+ ---@return lsp-handler|nil The handler for the given method, if defined, or the default from |vim.lsp.handlers|
local function resolve_handler(method)
return handlers[method] or default_handlers[method]
end
@@ -1592,6 +1593,11 @@ local function text_document_did_save_handler(bufnr)
local name = api.nvim_buf_get_name(bufnr)
local old_name = changetracking._get_and_set_name(client, bufnr, name)
if old_name and name ~= old_name then
+ client.notify('textDocument/didClose', {
+ textDocument = {
+ uri = vim.uri_from_fname(old_name),
+ },
+ })
client.notify('textDocument/didOpen', {
textDocument = {
version = 0,
@@ -1932,7 +1938,7 @@ api.nvim_create_autocmd('VimLeavePre', {
---@param bufnr (integer) Buffer handle, or 0 for current.
---@param method (string) LSP method name
---@param params table|nil Parameters to send to the server
----@param handler function|nil See |lsp-handler|
+---@param handler lsp-handler|nil See |lsp-handler|
--- If nil, follows resolution strategy defined in |lsp-handler-configuration|
---
---@return table<integer, integer>, fun() 2-tuple:
@@ -1992,9 +1998,10 @@ end
---@param bufnr (integer) Buffer handle, or 0 for current.
---@param method (string) LSP method name
---@param params (table|nil) Parameters to send to the server
----@param callback (function) The callback to call when all requests are finished.
+---@param callback fun(request_results: table<integer, {error: lsp.ResponseError, result: any}>) (function)
+--- The callback to call when all requests are finished.
--- Unlike `buf_request`, this will collect all the responses from each server instead of handling them.
---- A map of client_id:request_result will be provided to the callback
+--- A map of client_id:request_result will be provided to the callback.
---
---@return fun() cancel A function that will cancel all requests
function lsp.buf_request_all(bufnr, method, params, callback)
@@ -2037,9 +2044,8 @@ end
---@param timeout_ms (integer|nil) Maximum time in milliseconds to wait for a
--- result. Defaults to 1000
---
----@return table<integer, any>|nil result, string|nil err Map of client_id:request_result.
---- On timeout, cancel or error, returns `(nil, err)` where `err` is a string describing
---- the failure reason.
+---@return table<integer, {err: lsp.ResponseError, result: any}>|nil (table) result Map of client_id:request_result.
+---@return string|nil err On timeout, cancel, or error, `err` is a string describing the failure reason, and `result` is nil.
function lsp.buf_request_sync(bufnr, method, params, timeout_ms)
local request_results
diff --git a/runtime/lua/vim/lsp/buf.lua b/runtime/lua/vim/lsp/buf.lua
index 6ac885c78f..0e16e8f820 100644
--- a/runtime/lua/vim/lsp/buf.lua
+++ b/runtime/lua/vim/lsp/buf.lua
@@ -118,8 +118,10 @@ function M.completion(context)
end
---@private
+---@param bufnr integer
+---@param mode "v"|"V"
---@return table {start={row, col}, end={row, col}} using (1, 0) indexing
-local function range_from_selection()
+local function range_from_selection(bufnr, mode)
-- TODO: Use `vim.region()` instead https://github.com/neovim/neovim/pull/13896
-- [bufnum, lnum, col, off]; both row and column 1-indexed
@@ -138,6 +140,11 @@ local function range_from_selection()
start_row, end_row = end_row, start_row
start_col, end_col = end_col, start_col
end
+ if mode == 'V' then
+ start_col = 1
+ local lines = api.nvim_buf_get_lines(bufnr, end_row - 1, end_row, true)
+ end_col = #lines[1]
+ end
return {
['start'] = { start_row, start_col - 1 },
['end'] = { end_row, end_col - 1 },
@@ -200,7 +207,7 @@ function M.format(options)
local mode = api.nvim_get_mode().mode
local range = options.range
if not range and mode == 'v' or mode == 'V' then
- range = range_from_selection()
+ range = range_from_selection(bufnr, mode)
end
local method = range and 'textDocument/rangeFormatting' or 'textDocument/formatting'
@@ -772,7 +779,7 @@ function M.code_action(options)
local end_ = assert(options.range['end'], 'range must have a `end` property')
params = util.make_given_range_params(start, end_)
elseif mode == 'v' or mode == 'V' then
- local range = range_from_selection()
+ local range = range_from_selection(0, mode)
params = util.make_given_range_params(range.start, range['end'])
else
params = util.make_range_params()
diff --git a/runtime/lua/vim/lsp/protocol.lua b/runtime/lua/vim/lsp/protocol.lua
index 27dd68645a..1686e22c48 100644
--- a/runtime/lua/vim/lsp/protocol.lua
+++ b/runtime/lua/vim/lsp/protocol.lua
@@ -854,7 +854,6 @@ function protocol.make_client_capabilities()
}
end
-local if_nil = vim.F.if_nil
--- Creates a normalized object describing LSP server capabilities.
---@param server_capabilities table Table of capabilities supported by the server
---@return table Normalized table of capabilities
@@ -892,178 +891,5 @@ function protocol.resolve_capabilities(server_capabilities)
return server_capabilities
end
----@private
---- Creates a normalized object describing LSP server capabilities.
--- @deprecated access resolved_capabilities instead
----@param server_capabilities table Table of capabilities supported by the server
----@return table Normalized table of capabilities
-function protocol._resolve_capabilities_compat(server_capabilities)
- local general_properties = {}
- local text_document_sync_properties
- do
- local TextDocumentSyncKind = protocol.TextDocumentSyncKind
- local textDocumentSync = server_capabilities.textDocumentSync
- if textDocumentSync == nil then
- -- Defaults if omitted.
- text_document_sync_properties = {
- text_document_open_close = false,
- text_document_did_change = TextDocumentSyncKind.None,
- -- text_document_did_change = false;
- text_document_will_save = false,
- text_document_will_save_wait_until = false,
- text_document_save = false,
- text_document_save_include_text = false,
- }
- elseif type(textDocumentSync) == 'number' then
- -- Backwards compatibility
- if not TextDocumentSyncKind[textDocumentSync] then
- return nil, 'Invalid server TextDocumentSyncKind for textDocumentSync'
- end
- text_document_sync_properties = {
- text_document_open_close = true,
- text_document_did_change = textDocumentSync,
- text_document_will_save = false,
- text_document_will_save_wait_until = false,
- text_document_save = true,
- text_document_save_include_text = false,
- }
- elseif type(textDocumentSync) == 'table' then
- text_document_sync_properties = {
- text_document_open_close = if_nil(textDocumentSync.openClose, false),
- text_document_did_change = if_nil(textDocumentSync.change, TextDocumentSyncKind.None),
- text_document_will_save = if_nil(textDocumentSync.willSave, true),
- text_document_will_save_wait_until = if_nil(textDocumentSync.willSaveWaitUntil, true),
- text_document_save = if_nil(textDocumentSync.save, false),
- text_document_save_include_text = if_nil(
- type(textDocumentSync.save) == 'table' and textDocumentSync.save.includeText,
- false
- ),
- }
- else
- return nil, string.format('Invalid type for textDocumentSync: %q', type(textDocumentSync))
- end
- end
- general_properties.completion = server_capabilities.completionProvider ~= nil
- general_properties.hover = server_capabilities.hoverProvider or false
- general_properties.goto_definition = server_capabilities.definitionProvider or false
- general_properties.find_references = server_capabilities.referencesProvider or false
- general_properties.document_highlight = server_capabilities.documentHighlightProvider or false
- general_properties.document_symbol = server_capabilities.documentSymbolProvider or false
- general_properties.workspace_symbol = server_capabilities.workspaceSymbolProvider or false
- general_properties.document_formatting = server_capabilities.documentFormattingProvider or false
- general_properties.document_range_formatting = server_capabilities.documentRangeFormattingProvider
- or false
- general_properties.call_hierarchy = server_capabilities.callHierarchyProvider or false
- general_properties.execute_command = server_capabilities.executeCommandProvider ~= nil
-
- if server_capabilities.renameProvider == nil then
- general_properties.rename = false
- elseif type(server_capabilities.renameProvider) == 'boolean' then
- general_properties.rename = server_capabilities.renameProvider
- else
- general_properties.rename = true
- end
-
- if server_capabilities.codeLensProvider == nil then
- general_properties.code_lens = false
- general_properties.code_lens_resolve = false
- elseif type(server_capabilities.codeLensProvider) == 'table' then
- general_properties.code_lens = true
- general_properties.code_lens_resolve = server_capabilities.codeLensProvider.resolveProvider
- or false
- else
- error('The server sent invalid codeLensProvider')
- end
-
- if server_capabilities.codeActionProvider == nil then
- general_properties.code_action = false
- elseif
- type(server_capabilities.codeActionProvider) == 'boolean'
- or type(server_capabilities.codeActionProvider) == 'table'
- then
- general_properties.code_action = server_capabilities.codeActionProvider
- else
- error('The server sent invalid codeActionProvider')
- end
-
- if server_capabilities.declarationProvider == nil then
- general_properties.declaration = false
- elseif type(server_capabilities.declarationProvider) == 'boolean' then
- general_properties.declaration = server_capabilities.declarationProvider
- elseif type(server_capabilities.declarationProvider) == 'table' then
- general_properties.declaration = server_capabilities.declarationProvider
- else
- error('The server sent invalid declarationProvider')
- end
-
- if server_capabilities.typeDefinitionProvider == nil then
- general_properties.type_definition = false
- elseif type(server_capabilities.typeDefinitionProvider) == 'boolean' then
- general_properties.type_definition = server_capabilities.typeDefinitionProvider
- elseif type(server_capabilities.typeDefinitionProvider) == 'table' then
- general_properties.type_definition = server_capabilities.typeDefinitionProvider
- else
- error('The server sent invalid typeDefinitionProvider')
- end
-
- if server_capabilities.implementationProvider == nil then
- general_properties.implementation = false
- elseif type(server_capabilities.implementationProvider) == 'boolean' then
- general_properties.implementation = server_capabilities.implementationProvider
- elseif type(server_capabilities.implementationProvider) == 'table' then
- general_properties.implementation = server_capabilities.implementationProvider
- else
- error('The server sent invalid implementationProvider')
- end
-
- local workspace = server_capabilities.workspace
- local workspace_properties = {}
- if workspace == nil or workspace.workspaceFolders == nil then
- -- Defaults if omitted.
- workspace_properties = {
- workspace_folder_properties = {
- supported = false,
- changeNotifications = false,
- },
- }
- elseif type(workspace.workspaceFolders) == 'table' then
- workspace_properties = {
- workspace_folder_properties = {
- supported = if_nil(workspace.workspaceFolders.supported, false),
- changeNotifications = if_nil(workspace.workspaceFolders.changeNotifications, false),
- },
- }
- else
- error('The server sent invalid workspace')
- end
-
- local signature_help_properties
- if server_capabilities.signatureHelpProvider == nil then
- signature_help_properties = {
- signature_help = false,
- signature_help_trigger_characters = {},
- }
- elseif type(server_capabilities.signatureHelpProvider) == 'table' then
- signature_help_properties = {
- signature_help = true,
- -- The characters that trigger signature help automatically.
- signature_help_trigger_characters = server_capabilities.signatureHelpProvider.triggerCharacters
- or {},
- }
- else
- error('The server sent invalid signatureHelpProvider')
- end
-
- local capabilities = vim.tbl_extend(
- 'error',
- text_document_sync_properties,
- signature_help_properties,
- workspace_properties,
- general_properties
- )
-
- return capabilities
-end
-
return protocol
-- vim:sw=2 ts=2 et
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 435cb9fdb6..90f4394fcc 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -1,4 +1,5 @@
local Range = require('vim.treesitter._range')
+local Query = require('vim.treesitter.query')
local api = vim.api
@@ -74,18 +75,6 @@ function FoldInfo:get_stop(lnum)
return self.stop_counts[lnum] or 0
end
----@private
---- TODO(lewis6991): copied from languagetree.lua. Consolidate
----@param node TSNode
----@param metadata TSMetadata
----@return Range4
-local function get_range_from_metadata(node, metadata)
- if metadata and metadata.range then
- return metadata.range --[[@as Range4]]
- end
- return { node:range() }
-end
-
local function trim_level(level)
local max_fold_level = vim.wo.foldnestmax
if level > max_fold_level then
@@ -118,7 +107,7 @@ local function get_folds_levels(bufnr, info, srow, erow)
for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do
if query.captures[id] == 'fold' then
- local range = get_range_from_metadata(node, metadata[id])
+ local range = Query.get_range(node, bufnr, metadata[id])
local start, _, stop, stop_col = Range.unpack4(range)
if stop_col == 0 then
@@ -211,8 +200,9 @@ end
local function on_bytes(bufnr, foldinfo, start_row, old_row, new_row)
local end_row_old = start_row + old_row
local end_row_new = start_row + new_row
+
if new_row < old_row then
- foldinfo:remove_range(end_row_old, end_row_new)
+ foldinfo:remove_range(end_row_new, end_row_old)
elseif new_row > old_row then
foldinfo:add_range(start_row, end_row_new)
vim.schedule(function()
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
index 02918da23f..f4db5016ac 100644
--- a/runtime/lua/vim/treesitter/_range.lua
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -2,8 +2,21 @@ local api = vim.api
local M = {}
----@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
----@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
+---@class Range4
+---@field [1] integer start row
+---@field [2] integer start column
+---@field [3] integer end row
+---@field [4] integer end column
+
+---@class Range6
+---@field [1] integer start row
+---@field [2] integer start column
+---@field [3] integer start bytes
+---@field [4] integer end row
+---@field [5] integer end column
+---@field [6] integer end bytes
+
+---@alias Range Range4|Range6
---@private
---@param a_row integer
@@ -74,8 +87,8 @@ function M.validate(r)
end
---@private
----@param r1 Range4|Range6
----@param r2 Range4|Range6
+---@param r1 Range
+---@param r2 Range
---@return boolean
function M.intercepts(r1, r2)
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
@@ -95,7 +108,7 @@ function M.intercepts(r1, r2)
end
---@private
----@param r Range4|Range6
+---@param r Range
---@return integer, integer, integer, integer
function M.unpack4(r)
local off_1 = #r == 6 and 1 or 0
@@ -110,8 +123,8 @@ function M.unpack6(r)
end
---@private
----@param r1 Range4|Range6
----@param r2 Range4|Range6
+---@param r1 Range
+---@param r2 Range
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
@@ -132,7 +145,7 @@ end
---@private
---@param source integer|string
----@param range Range4|Range6
+---@param range Range
---@return Range6
function M.add_bytes(source, range)
if type(range) == 'table' and #range == 6 then
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index 26321cd1f4..bdfe281a5b 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -448,7 +448,7 @@ end
--- nodes, which is useful for templating languages like ERB and EJS.
---
---@private
----@param new_regions Range4[][] List of regions this tree should manage and parse.
+---@param new_regions Range6[][] List of regions this tree should manage and parse.
function LanguageTree:set_included_regions(new_regions)
-- Transform the tables from 4 element long to 6 element long (with byte offset)
for _, region in ipairs(new_regions) do
@@ -478,24 +478,11 @@ end
---@private
---@param node TSNode
----@param source integer|string
----@param metadata TSMetadata
----@return Range6
-local function get_range_from_metadata(node, source, metadata)
- if metadata and metadata.range then
- return Range.add_bytes(source, metadata.range --[[@as Range4|Range6]])
- end
- return { node:range(true) }
-end
-
----@private
---- TODO(lewis6991): cleanup of the node_range interface
----@param node TSNode
---@param source string|integer
---@param metadata TSMetadata
---@return Range6[]
local function get_node_ranges(node, source, metadata, include_children)
- local range = get_range_from_metadata(node, source, metadata)
+ local range = query.get_range(node, source, metadata)
if include_children then
return { range }
@@ -535,7 +522,7 @@ end
---@param pattern integer
---@param lang string
---@param combined boolean
----@param ranges Range4[]
+---@param ranges Range6[]
local function add_injection(t, tree_index, pattern, lang, combined, ranges)
assert(type(lang) == 'string')
@@ -559,30 +546,15 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges)
end
---@private
----Get node text
----
----Note: `query.get_node_text` returns string|string[]|nil so use this simple alias function
----to annotate it returns string.
----
----TODO(lewis6991): use [at]overload annotations on `query.get_node_text`
----@param node TSNode
----@param source integer|string
----@param metadata table
----@return string
-local function get_node_text(node, source, metadata)
- return query.get_node_text(node, source, { metadata = metadata }) --[[@as string]]
-end
-
----@private
--- Extract injections according to:
--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
---@param match table<integer,TSNode>
----@param metadata table
----@return string, boolean, Range4[]
+---@param metadata TSMetadata
+---@return string?, boolean, Range6[]
function LanguageTree:_get_injection(match, metadata)
- local ranges = {} ---@type Range4[]
+ local ranges = {} ---@type Range6[]
local combined = metadata['injection.combined'] ~= nil
- local lang = metadata['injection.language'] ---@type string
+ local lang = metadata['injection.language'] --[[@as string?]]
local include_children = metadata['injection.include-children'] ~= nil
for id, node in pairs(match) do
@@ -590,7 +562,7 @@ function LanguageTree:_get_injection(match, metadata)
-- Lang should override any other language tag
if name == 'injection.language' then
- lang = get_node_text(node, self._source, metadata[id])
+ lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'injection.content' then
ranges = get_node_ranges(node, self._source, metadata[id], include_children)
end
@@ -601,11 +573,11 @@ end
---@private
---@param match table<integer,TSNode>
----@param metadata table
----@return string, boolean, Range4[]
+---@param metadata TSMetadata
+---@return string, boolean, Range6[]
function LanguageTree:_get_injection_deprecated(match, metadata)
local lang = nil ---@type string
- local ranges = {} ---@type Range4[]
+ local ranges = {} ---@type Range6[]
local combined = metadata.combined ~= nil
-- Directives can configure how injections are captured as well as actual node captures.
@@ -623,8 +595,10 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end
end
- if metadata.language then
- lang = metadata.language ---@type string
+ local mlang = metadata.language
+ if mlang ~= nil then
+ assert(type(mlang) == 'string')
+ lang = mlang
end
-- You can specify the content and language together
@@ -635,11 +609,11 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
-- Lang should override any other language tag
if name == 'language' and not lang then
- lang = get_node_text(node, self._source, metadata[id])
+ lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'combined' then
combined = true
elseif name == 'content' and #ranges == 0 then
- ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id])
+ ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id])
-- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= '_' then
@@ -648,7 +622,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end
if #ranges == 0 then
- ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id])
+ ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id])
end
end
end
@@ -926,7 +900,7 @@ end
---@private
---@param tree TSTree
----@param range Range4
+---@param range Range
---@return boolean
local function tree_contains(tree, range)
return Range.contains({ tree:root():range() }, range)
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index e7cf42283d..f4e038b2d8 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,6 +1,8 @@
local a = vim.api
local language = require('vim.treesitter.language')
+local Range = require('vim.treesitter._range')
+
---@class Query
---@field captures string[] List of captures used in query
---@field info TSQueryInfo Contains used queries, predicates, directives
@@ -56,35 +58,21 @@ local function add_included_lang(base_langs, lang, ilang)
end
---@private
----@param buf (integer)
----@param range (table)
----@param concat (boolean)
----@returns (string[]|string|nil)
-local function buf_range_get_text(buf, range, concat)
- local lines
- local start_row, start_col, end_row, end_col = unpack(range)
- local eof_row = a.nvim_buf_line_count(buf)
- if start_row >= eof_row then
- return nil
- end
-
+---@param buf integer
+---@param range Range
+---@returns string
+local function buf_range_get_text(buf, range)
+ local start_row, start_col, end_row, end_col = Range.unpack4(range)
if end_col == 0 then
- lines = a.nvim_buf_get_lines(buf, start_row, end_row, true)
- end_col = -1
- else
- lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true)
- end
-
- if #lines > 0 then
- if #lines == 1 then
- lines[1] = string.sub(lines[1], start_col + 1, end_col)
- else
- lines[1] = string.sub(lines[1], start_col + 1)
- lines[#lines] = string.sub(lines[#lines], 1, end_col)
+ if start_row == end_row then
+ start_col = -1
+ start_row = start_row - 1
end
+ end_col = -1
+ end_row = end_row - 1
end
-
- return concat and table.concat(lines, '\n') or lines
+ local lines = a.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {})
+ return table.concat(lines, '\n')
end
--- Gets the list of files used to make up a query
@@ -256,14 +244,28 @@ function M.parse_query(lang, query)
local cached = query_cache[lang][query]
if cached then
return cached
- else
- local self = setmetatable({}, Query)
- self.query = vim._ts_parse_query(lang, query)
- self.info = self.query:inspect()
- self.captures = self.info.captures
- query_cache[lang][query] = self
- return self
end
+
+ local self = setmetatable({}, Query)
+ self.query = vim._ts_parse_query(lang, query)
+ self.info = self.query:inspect()
+ self.captures = self.info.captures
+ query_cache[lang][query] = self
+ return self
+end
+
+---Get the range of a |TSNode|. Can also supply {source} and {metadata}
+---to get the range with directives applied.
+---@param node TSNode
+---@param source integer|string|nil Buffer or string from which the {node} is extracted
+---@param metadata TSMetadata|nil
+---@return Range6
+function M.get_range(node, source, metadata)
+ if metadata and metadata.range then
+ assert(source)
+ return Range.add_bytes(source, metadata.range)
+ end
+ return { node:range(true) }
end
--- Gets the text corresponding to a given node
@@ -271,24 +273,22 @@ end
---@param node TSNode
---@param source (integer|string) Buffer or string from which the {node} is extracted
---@param opts (table|nil) Optional parameters.
---- - concat: (boolean) Concatenate result in a string (default true)
--- - metadata (table) Metadata of a specific capture. This would be
--- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|.
----@return (string[]|string|nil)
+---@return string
function M.get_node_text(node, source, opts)
opts = opts or {}
- -- TODO(lewis6991): concat only works when source is number.
- local concat = vim.F.if_nil(opts.concat, true)
local metadata = opts.metadata or {}
if metadata.text then
return metadata.text
elseif type(source) == 'number' then
- return metadata.range and buf_range_get_text(source, metadata.range, concat)
- or buf_range_get_text(source, { node:range() }, concat)
- elseif type(source) == 'string' then
- return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
+ local range = M.get_range(node, source, metadata)
+ return buf_range_get_text(source, range)
end
+
+ ---@cast source string
+ return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
end
---@alias TSMatch table<integer,TSNode>
@@ -312,7 +312,7 @@ local predicate_handlers = {
str = predicate[3]
else
-- (#eq? @aa @bb)
- str = M.get_node_text(match[predicate[3]], source) --[[@as string]]
+ str = M.get_node_text(match[predicate[3]], source)
end
if node_text ~= str or str == nil then
@@ -328,7 +328,7 @@ local predicate_handlers = {
return true
end
local regex = predicate[3]
- return string.find(M.get_node_text(node, source) --[[@as string]], regex) ~= nil
+ return string.find(M.get_node_text(node, source), regex) ~= nil
end,
['match?'] = (function()
@@ -366,7 +366,7 @@ local predicate_handlers = {
if not node then
return true
end
- local node_text = M.get_node_text(node, source) --[[@as string]]
+ local node_text = M.get_node_text(node, source)
for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then
@@ -404,9 +404,9 @@ local predicate_handlers = {
predicate_handlers['vim-match?'] = predicate_handlers['match?']
---@class TSMetadata
+---@field range Range
---@field [integer] TSMetadata
---@field [string] integer|string
----@field range Range4
---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata)
@@ -465,13 +465,20 @@ local directive_handlers = {
assert(#pred == 4)
local id = pred[2]
+ assert(type(id) == 'number')
+
local node = match[id]
local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then
metadata[id] = {}
end
- metadata[id].text = text:gsub(pred[3], pred[4])
+
+ local pattern, replacement = pred[3], pred[3]
+ assert(type(pattern) == 'string')
+ assert(type(replacement) == 'string')
+
+ metadata[id].text = text:gsub(pattern, replacement)
end,
}