aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2023-02-05 21:31:30 +0000
committerGitHub <noreply@github.com>2023-02-05 21:31:30 +0000
commit7963a160e90f9ded63db1dfa24e607ee70af18ba (patch)
treec392a7df8678f8b026882b641038609d6a2adb74 /runtime/lua/vim
parent23e34fe534d201a1323ab040cb2201d21fe865cc (diff)
parent4c66f5ff97a52fbc933fdbe1907c4b960d5a7403 (diff)
downloadrneovim-7963a160e90f9ded63db1dfa24e607ee70af18ba.tar.gz
rneovim-7963a160e90f9ded63db1dfa24e607ee70af18ba.tar.bz2
rneovim-7963a160e90f9ded63db1dfa24e607ee70af18ba.zip
Merge pull request #21548 from figsoda/transform-capture
feat(treesitter): allow capture text to be transformed
Diffstat (limited to 'runtime/lua/vim')
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua2
-rw-r--r--runtime/lua/vim/treesitter/query.lua96
2 files changed, 63 insertions, 35 deletions
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index 89aac3ae26..8255c6f4fe 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -408,7 +408,7 @@ function LanguageTree:_get_injections()
-- Lang should override any other language tag
if name == 'language' and not lang then
- lang = query.get_node_text(node, self._source) --[[@as string]]
+ lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'combined' then
combined = true
elseif name == 'content' and #ranges == 0 then
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 84ed2667b9..a0522d7cda 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -55,6 +55,38 @@ local function add_included_lang(base_langs, lang, ilang)
return false
end
+---@private
+---@param buf (number)
+---@param range (table)
+---@param concat (boolean)
+---@returns (string[]|string|nil)
+local function buf_range_get_text(buf, range, concat)
+ local lines
+ local start_row, start_col, end_row, end_col = unpack(range)
+ local eof_row = a.nvim_buf_line_count(buf)
+ if start_row >= eof_row then
+ return nil
+ end
+
+ if end_col == 0 then
+ lines = a.nvim_buf_get_lines(buf, start_row, end_row, true)
+ end_col = -1
+ else
+ lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true)
+ end
+
+ if #lines > 0 then
+ if #lines == 1 then
+ lines[1] = string.sub(lines[1], start_col + 1, end_col)
+ else
+ lines[1] = string.sub(lines[1], start_col + 1)
+ lines[#lines] = string.sub(lines[#lines], 1, end_col)
+ end
+ end
+
+ return concat and table.concat(lines, '\n') or lines
+end
+
--- Gets the list of files used to make up a query
---
---@param lang string Language to get query for
@@ -240,40 +272,22 @@ end
---@param source (number|string) Buffer or string from which the {node} is extracted
---@param opts (table|nil) Optional parameters.
--- - concat: (boolean) Concatenate result in a string (default true)
+--- - metadata (table) Metadata of a specific capture. This would be
+--- set to `metadata[capture_id]` when using
+--- |vim.treesitter.query.add_directive()|.
---@return (string[]|string|nil)
function M.get_node_text(node, source, opts)
opts = opts or {}
local concat = vim.F.if_nil(opts.concat, true)
+ local metadata = opts.metadata or {}
- local start_row, start_col, start_byte = node:start()
- local end_row, end_col, end_byte = node:end_()
-
- if type(source) == 'number' then
- local eof_row = a.nvim_buf_line_count(source)
- if start_row >= eof_row then
- return nil
- end
-
- local lines ---@type string[]
- if end_col == 0 then
- lines = a.nvim_buf_get_lines(source, start_row, end_row, true)
- end_col = -1
- else
- lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true)
- end
-
- if #lines > 0 then
- if #lines == 1 then
- lines[1] = string.sub(lines[1], start_col + 1, end_col)
- else
- lines[1] = string.sub(lines[1], start_col + 1)
- lines[#lines] = string.sub(lines[#lines], 1, end_col)
- end
- end
-
- return concat and table.concat(lines, '\n') or lines
+ if metadata.text then
+ return metadata.text
+ elseif type(source) == 'number' then
+ return metadata.range and buf_range_get_text(source, metadata.range, concat)
+ or buf_range_get_text(source, { node:range() }, concat)
elseif type(source) == 'string' then
- return source:sub(start_byte + 1, end_byte)
+ return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
end
end
@@ -431,9 +445,11 @@ local directive_handlers = {
['offset!'] = function(match, _, _, pred, metadata)
---@cast pred integer[]
local capture_id = pred[2]
- local offset_node = match[capture_id]
- local range = { offset_node:range() }
- ---@cast range integer[] bug in sumneko
+ if not metadata[capture_id] then
+ metadata[capture_id] = {}
+ end
+
+ local range = metadata[capture_id].range or { match[capture_id]:range() }
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
@@ -446,12 +462,24 @@ local directive_handlers = {
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
- if not metadata[capture_id] then
- metadata[capture_id] = {}
- end
metadata[capture_id].range = range
end
end,
+
+ -- Transform the content of the node
+ -- Example: (#gsub! @_node ".*%.(.*)" "%1")
+ ['gsub!'] = function(match, _, bufnr, pred, metadata)
+ assert(#pred == 4)
+
+ local id = pred[2]
+ local node = match[id]
+ local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
+
+ if not metadata[id] then
+ metadata[id] = {}
+ end
+ metadata[id].text = text:gsub(pred[3], pred[4])
+ end,
}
--- Adds a new predicate to be used in queries