aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2023-03-10 16:10:05 +0000
committerLewis Russell <lewis6991@gmail.com>2023-03-10 16:35:06 +0000
commit9d70fe062ca01ac0673faa6ccbb88345916aeea7 (patch)
tree8b4db135bc63d7055185c808624f7555f96bc059
parent845efb8e12cb014b385deac62fb83622a99024ec (diff)
downloadrneovim-9d70fe062ca01ac0673faa6ccbb88345916aeea7.tar.gz
rneovim-9d70fe062ca01ac0673faa6ccbb88345916aeea7.tar.bz2
rneovim-9d70fe062ca01ac0673faa6ccbb88345916aeea7.zip
feat(treesitter)!: consolidate query util functions
- And address more type errors. - Removed the `concat` option from `get_node_text` since it was applied inconsistently and made typing awkward.
-rw-r--r--runtime/doc/treesitter.txt17
-rw-r--r--runtime/lua/vim/treesitter/_fold.lua15
-rw-r--r--runtime/lua/vim/treesitter/_range.lua15
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua55
-rw-r--r--runtime/lua/vim/treesitter/query.lua97
5 files changed, 93 insertions, 106 deletions
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt
index c7a0e1927b..dc577a015e 100644
--- a/runtime/doc/treesitter.txt
+++ b/runtime/doc/treesitter.txt
@@ -789,14 +789,12 @@ get_node_text({node}, {source}, {opts})
• {source} (integer|string) Buffer or string from which the {node} is
extracted
• {opts} (table|nil) Optional parameters.
- • concat: (boolean) Concatenate result in a string (default
- true)
• metadata (table) Metadata of a specific capture. This
would be set to `metadata[capture_id]` when using
|vim.treesitter.add_directive()|.
Return: ~
- (string[]|string|nil)
+ (string)
get_query({lang}, {query_name}) *vim.treesitter.get_query()*
Returns the runtime query {query_name} for {lang}.
@@ -822,6 +820,19 @@ get_query_files({lang}, {query_name}, {is_included})
string[] query_files List of files to load for given query and
language
+get_range({node}, {source}, {metadata}) *vim.treesitter.get_range()*
+ Get the range of a |TSNode|. Can also supply {source} and {metadata} to
+ get the range with directives applied.
+
+ Parameters: ~
+ • {node} |TSNode|
+ • {source} integer|string|nil Buffer or string from which the {node}
+ is extracted
+ • {metadata} TSMetadata|nil
+
+ Return: ~
+ (table)
+
list_directives() *vim.treesitter.list_directives()*
Lists the currently available directives to use in queries.
diff --git a/runtime/lua/vim/treesitter/_fold.lua b/runtime/lua/vim/treesitter/_fold.lua
index 435cb9fdb6..fd2c707d17 100644
--- a/runtime/lua/vim/treesitter/_fold.lua
+++ b/runtime/lua/vim/treesitter/_fold.lua
@@ -1,4 +1,5 @@
local Range = require('vim.treesitter._range')
+local Query = require('vim.treesitter.query')
local api = vim.api
@@ -74,18 +75,6 @@ function FoldInfo:get_stop(lnum)
return self.stop_counts[lnum] or 0
end
----@private
---- TODO(lewis6991): copied from languagetree.lua. Consolidate
----@param node TSNode
----@param metadata TSMetadata
----@return Range4
-local function get_range_from_metadata(node, metadata)
- if metadata and metadata.range then
- return metadata.range --[[@as Range4]]
- end
- return { node:range() }
-end
-
local function trim_level(level)
local max_fold_level = vim.wo.foldnestmax
if level > max_fold_level then
@@ -118,7 +107,7 @@ local function get_folds_levels(bufnr, info, srow, erow)
for id, node, metadata in query:iter_captures(tree:root(), bufnr, srow or 0, q_erow) do
if query.captures[id] == 'fold' then
- local range = get_range_from_metadata(node, metadata[id])
+ local range = Query.get_range(node, bufnr, metadata[id])
local start, _, stop, stop_col = Range.unpack4(range)
if stop_col == 0 then
diff --git a/runtime/lua/vim/treesitter/_range.lua b/runtime/lua/vim/treesitter/_range.lua
index 02918da23f..0017a567ec 100644
--- a/runtime/lua/vim/treesitter/_range.lua
+++ b/runtime/lua/vim/treesitter/_range.lua
@@ -2,8 +2,19 @@ local api = vim.api
local M = {}
----@alias Range4 {[1]: integer, [2]: integer, [3]: integer, [4]: integer}
----@alias Range6 {[1]: integer, [2]: integer, [3]: integer, [4]: integer, [5]: integer, [6]: integer}
+---@class Range4
+---@field [1] integer start row
+---@field [2] integer start column
+---@field [3] integer end row
+---@field [4] integer end column
+
+---@class Range6
+---@field [1] integer start row
+---@field [2] integer start column
+---@field [3] integer start bytes
+---@field [4] integer end row
+---@field [5] integer end column
+---@field [6] integer end bytes
---@private
---@param a_row integer
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index c89419085f..0bb0601241 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -455,7 +455,7 @@ end
--- nodes, which is useful for templating languages like ERB and EJS.
---
---@private
----@param new_regions Range4[][] List of regions this tree should manage and parse.
+---@param new_regions Range6[][] List of regions this tree should manage and parse.
function LanguageTree:set_included_regions(new_regions)
-- Transform the tables from 4 element long to 6 element long (with byte offset)
for _, region in ipairs(new_regions) do
@@ -484,25 +484,13 @@ function LanguageTree:included_regions()
end
---@private
----@param node TSNode
----@param source integer|string
----@param metadata TSMetadata
----@return Range6
-local function get_range_from_metadata(node, source, metadata)
- if metadata and metadata.range then
- return Range.add_bytes(source, metadata.range --[[@as Range4|Range6]])
- end
- return { node:range(true) }
-end
-
----@private
--- TODO(lewis6991): cleanup of the node_range interface
---@param node TSNode
---@param source string|integer
---@param metadata TSMetadata
---@return Range6[]
local function get_node_ranges(node, source, metadata, include_children)
- local range = get_range_from_metadata(node, source, metadata)
+ local range = query.get_range(node, source, metadata)
if include_children then
return { range }
@@ -566,30 +554,17 @@ local function add_injection(t, tree_index, pattern, lang, combined, ranges)
end
---@private
----Get node text
----
----Note: `query.get_node_text` returns string|string[]|nil so use this simple alias function
----to annotate it returns string.
----
----TODO(lewis6991): use [at]overload annotations on `query.get_node_text`
----@param node TSNode
----@param source integer|string
----@param metadata table
----@return string
-local function get_node_text(node, source, metadata)
- return query.get_node_text(node, source, { metadata = metadata }) --[[@as string]]
-end
-
----@private
--- Extract injections according to:
--- https://tree-sitter.github.io/tree-sitter/syntax-highlighting#language-injection
---@param match table<integer,TSNode>
----@param metadata table
+---@param metadata TSMetadata
---@return string, boolean, Range4[]
function LanguageTree:_get_injection(match, metadata)
local ranges = {} ---@type Range4[]
local combined = metadata['injection.combined'] ~= nil
- local lang = metadata['injection.language'] ---@type string
+ local lang = metadata['injection.language']
+ assert(type(lang) == 'string')
+
local include_children = metadata['injection.include-children'] ~= nil
for id, node in pairs(match) do
@@ -597,7 +572,7 @@ function LanguageTree:_get_injection(match, metadata)
-- Lang should override any other language tag
if name == 'injection.language' then
- lang = get_node_text(node, self._source, metadata[id])
+ lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'injection.content' then
ranges = get_node_ranges(node, self._source, metadata[id], include_children)
end
@@ -608,11 +583,11 @@ end
---@private
---@param match table<integer,TSNode>
----@param metadata table
+---@param metadata TSMetadata
---@return string, boolean, Range4[]
function LanguageTree:_get_injection_deprecated(match, metadata)
local lang = nil ---@type string
- local ranges = {} ---@type Range4[]
+ local ranges = {} ---@type Range6[]
local combined = metadata.combined ~= nil
-- Directives can configure how injections are captured as well as actual node captures.
@@ -630,8 +605,10 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end
end
- if metadata.language then
- lang = metadata.language ---@type string
+ local mlang = metadata.language
+ if mlang ~= nil then
+ assert(type(mlang) == 'string')
+ lang = mlang
end
-- You can specify the content and language together
@@ -642,11 +619,11 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
-- Lang should override any other language tag
if name == 'language' and not lang then
- lang = get_node_text(node, self._source, metadata[id])
+ lang = query.get_node_text(node, self._source, { metadata = metadata[id] })
elseif name == 'combined' then
combined = true
elseif name == 'content' and #ranges == 0 then
- ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id])
+ ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id])
-- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= '_' then
@@ -655,7 +632,7 @@ function LanguageTree:_get_injection_deprecated(match, metadata)
end
if #ranges == 0 then
- ranges[#ranges + 1] = get_range_from_metadata(node, self._source, metadata[id])
+ ranges[#ranges + 1] = query.get_range(node, self._source, metadata[id])
end
end
end
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index e7cf42283d..70af4f7bce 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,6 +1,8 @@
local a = vim.api
local language = require('vim.treesitter.language')
+local Range = require('vim.treesitter._range')
+
---@class Query
---@field captures string[] List of captures used in query
---@field info TSQueryInfo Contains used queries, predicates, directives
@@ -56,35 +58,13 @@ local function add_included_lang(base_langs, lang, ilang)
end
---@private
----@param buf (integer)
----@param range (table)
----@param concat (boolean)
----@returns (string[]|string|nil)
-local function buf_range_get_text(buf, range, concat)
- local lines
- local start_row, start_col, end_row, end_col = unpack(range)
- local eof_row = a.nvim_buf_line_count(buf)
- if start_row >= eof_row then
- return nil
- end
-
- if end_col == 0 then
- lines = a.nvim_buf_get_lines(buf, start_row, end_row, true)
- end_col = -1
- else
- lines = a.nvim_buf_get_lines(buf, start_row, end_row + 1, true)
- end
-
- if #lines > 0 then
- if #lines == 1 then
- lines[1] = string.sub(lines[1], start_col + 1, end_col)
- else
- lines[1] = string.sub(lines[1], start_col + 1)
- lines[#lines] = string.sub(lines[#lines], 1, end_col)
- end
- end
-
- return concat and table.concat(lines, '\n') or lines
+---@param buf integer
+---@param range Range6
+---@returns string
+local function buf_range_get_text(buf, range)
+ local start_row, start_col, end_row, end_col = Range.unpack4(range)
+ local lines = a.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {})
+ return table.concat(lines, '\n')
end
--- Gets the list of files used to make up a query
@@ -256,14 +236,28 @@ function M.parse_query(lang, query)
local cached = query_cache[lang][query]
if cached then
return cached
- else
- local self = setmetatable({}, Query)
- self.query = vim._ts_parse_query(lang, query)
- self.info = self.query:inspect()
- self.captures = self.info.captures
- query_cache[lang][query] = self
- return self
end
+
+ local self = setmetatable({}, Query)
+ self.query = vim._ts_parse_query(lang, query)
+ self.info = self.query:inspect()
+ self.captures = self.info.captures
+ query_cache[lang][query] = self
+ return self
+end
+
+---Get the range of a |TSNode|. Can also supply {source} and {metadata}
+---to get the range with directives applied.
+---@param node TSNode
+---@param source integer|string|nil Buffer or string from which the {node} is extracted
+---@param metadata TSMetadata|nil
+---@return Range6
+function M.get_range(node, source, metadata)
+ if metadata and metadata.range then
+ assert(source)
+ return Range.add_bytes(source, metadata.range)
+ end
+ return { node:range(true) }
end
--- Gets the text corresponding to a given node
@@ -271,24 +265,22 @@ end
---@param node TSNode
---@param source (integer|string) Buffer or string from which the {node} is extracted
---@param opts (table|nil) Optional parameters.
---- - concat: (boolean) Concatenate result in a string (default true)
--- - metadata (table) Metadata of a specific capture. This would be
--- set to `metadata[capture_id]` when using |vim.treesitter.add_directive()|.
----@return (string[]|string|nil)
+---@return string
function M.get_node_text(node, source, opts)
opts = opts or {}
- -- TODO(lewis6991): concat only works when source is number.
- local concat = vim.F.if_nil(opts.concat, true)
local metadata = opts.metadata or {}
if metadata.text then
return metadata.text
elseif type(source) == 'number' then
- return metadata.range and buf_range_get_text(source, metadata.range, concat)
- or buf_range_get_text(source, { node:range() }, concat)
- elseif type(source) == 'string' then
- return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
+ local range = M.get_range(node, source, metadata)
+ return buf_range_get_text(source, range)
end
+
+ ---@cast source string
+ return source:sub(select(3, node:start()) + 1, select(3, node:end_()))
end
---@alias TSMatch table<integer,TSNode>
@@ -312,7 +304,7 @@ local predicate_handlers = {
str = predicate[3]
else
-- (#eq? @aa @bb)
- str = M.get_node_text(match[predicate[3]], source) --[[@as string]]
+ str = M.get_node_text(match[predicate[3]], source)
end
if node_text ~= str or str == nil then
@@ -328,7 +320,7 @@ local predicate_handlers = {
return true
end
local regex = predicate[3]
- return string.find(M.get_node_text(node, source) --[[@as string]], regex) ~= nil
+ return string.find(M.get_node_text(node, source), regex) ~= nil
end,
['match?'] = (function()
@@ -366,7 +358,7 @@ local predicate_handlers = {
if not node then
return true
end
- local node_text = M.get_node_text(node, source) --[[@as string]]
+ local node_text = M.get_node_text(node, source)
for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then
@@ -404,9 +396,9 @@ local predicate_handlers = {
predicate_handlers['vim-match?'] = predicate_handlers['match?']
---@class TSMetadata
+---@field range Range4|Range6
---@field [integer] TSMetadata
---@field [string] integer|string
----@field range Range4
---@alias TSDirective fun(match: TSMatch, _, _, predicate: (string|integer)[], metadata: TSMetadata)
@@ -465,13 +457,20 @@ local directive_handlers = {
assert(#pred == 4)
local id = pred[2]
+ assert(type(id) == 'number')
+
local node = match[id]
local text = M.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''
if not metadata[id] then
metadata[id] = {}
end
- metadata[id].text = text:gsub(pred[3], pred[4])
+
+ local pattern, replacement = pred[3], pred[3]
+ assert(type(pattern) == 'string')
+ assert(type(replacement) == 'string')
+
+ metadata[id].text = text:gsub(pattern, replacement)
end,
}