aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
authorJosh Rahm <rahm@google.com>2022-07-18 19:37:18 +0000
committerJosh Rahm <rahm@google.com>2022-07-18 19:37:18 +0000
commit308e1940dcd64aa6c344c403d4f9e0dda58d9c5c (patch)
tree35fe43e01755e0f312650667004487a44d6b7941 /runtime/lua/vim/treesitter/query.lua
parent96a00c7c588b2f38a2424aeeb4ea3581d370bf2d (diff)
parente8c94697bcbe23a5c7b07c292b90a6b70aadfa87 (diff)
downloadrneovim-308e1940dcd64aa6c344c403d4f9e0dda58d9c5c.tar.gz
rneovim-308e1940dcd64aa6c344c403d4f9e0dda58d9c5c.tar.bz2
rneovim-308e1940dcd64aa6c344c403d4f9e0dda58d9c5c.zip
Merge remote-tracking branch 'upstream/master' into rahm
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua195
1 files changed, 118 insertions, 77 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index ebed502c92..103e85abfd 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,5 +1,5 @@
local a = vim.api
-local language = require'vim.treesitter.language'
+local language = require('vim.treesitter.language')
-- query: pattern matching on trees
-- predicate matching is implemented in lua
@@ -43,7 +43,9 @@ function M.get_query_files(lang, query_name, is_included)
local query_path = string.format('queries/%s/%s.scm', lang, query_name)
local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true))
- if #lang_files == 0 then return {} end
+ if #lang_files == 0 then
+ return {}
+ end
local base_langs = {}
@@ -52,7 +54,7 @@ function M.get_query_files(lang, query_name, is_included)
-- ;+ inherits: ({language},)*{language}
--
-- {language} ::= {lang} | ({lang})
- local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
+ local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
for _, file in ipairs(lang_files) do
local modeline = safe_read(file, '*l')
@@ -62,7 +64,7 @@ function M.get_query_files(lang, query_name, is_included)
if langlist then
for _, incllang in ipairs(vim.split(langlist, ',', true)) do
- local is_optional = incllang:match("%(.*%)")
+ local is_optional = incllang:match('%(.*%)')
if is_optional then
if not is_included then
@@ -90,7 +92,7 @@ end
local function read_query_files(filenames)
local contents = {}
- for _,filename in ipairs(filenames) do
+ for _, filename in ipairs(filenames) do
table.insert(contents, safe_read(filename, '*a'))
end
@@ -138,30 +140,43 @@ function M.get_query(lang, query_name)
end
end
+local query_cache = setmetatable({}, {
+ __index = function(tbl, key)
+ rawset(tbl, key, {})
+ return rawget(tbl, key)
+ end,
+})
+
--- Parse {query} as a string. (If the query is in a file, the caller
---- should read the contents into a string before calling).
+--- should read the contents into a string before calling).
---
--- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
--- search nodes in the syntax tree for the patterns defined in {query}
--- using `iter_*` methods below.
---
---- Exposes `info` and `captures` with additional information about the {query}.
+--- Exposes `info` and `captures` with additional context about {query}.
--- - `captures` contains the list of unique capture names defined in
--- {query}.
--- -` info.captures` also points to `captures`.
--- - `info.patterns` contains information about predicates.
---
----@param lang The language
----@param query A string containing the query (s-expr syntax)
+---@param lang string The language
+---@param query string A string containing the query (s-expr syntax)
---
---@returns The query
function M.parse_query(lang, query)
language.require_language(lang)
- local self = setmetatable({}, Query)
- self.query = vim._ts_parse_query(lang, query)
- self.info = self.query:inspect()
- self.captures = self.info.captures
- return self
+ local cached = query_cache[lang][query]
+ if cached then
+ return cached
+ else
+ local self = setmetatable({}, Query)
+ self.query = vim._ts_parse_query(lang, query)
+ self.info = self.query:inspect()
+ self.captures = self.info.captures
+ query_cache[lang][query] = self
+ return self
+ end
end
--- Gets the text corresponding to a given node
@@ -172,7 +187,7 @@ function M.get_node_text(node, source)
local start_row, start_col, start_byte = node:start()
local end_row, end_col, end_byte = node:end_()
- if type(source) == "number" then
+ if type(source) == 'number' then
local lines
local eof_row = a.nvim_buf_line_count(source)
if start_row >= eof_row then
@@ -186,56 +201,64 @@ function M.get_node_text(node, source)
lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true)
end
- if #lines == 1 then
- lines[1] = string.sub(lines[1], start_col+1, end_col)
- else
- lines[1] = string.sub(lines[1], start_col+1)
- lines[#lines] = string.sub(lines[#lines], 1, end_col)
+ if #lines > 0 then
+ if #lines == 1 then
+ lines[1] = string.sub(lines[1], start_col + 1, end_col)
+ else
+ lines[1] = string.sub(lines[1], start_col + 1)
+ lines[#lines] = string.sub(lines[#lines], 1, end_col)
+ end
end
- return table.concat(lines, "\n")
- elseif type(source) == "string" then
- return source:sub(start_byte+1, end_byte)
+ return table.concat(lines, '\n')
+ elseif type(source) == 'string' then
+ return source:sub(start_byte + 1, end_byte)
end
end
-- Predicate handler receive the following arguments
-- (match, pattern, bufnr, predicate)
local predicate_handlers = {
- ["eq?"] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- local node_text = M.get_node_text(node, source)
-
- local str
- if type(predicate[3]) == "string" then
- -- (#eq? @aa "foo")
- str = predicate[3]
- else
- -- (#eq? @aa @bb)
- str = M.get_node_text(match[predicate[3]], source)
- end
+ ['eq?'] = function(match, _, source, predicate)
+ local node = match[predicate[2]]
+ if not node then
+ return true
+ end
+ local node_text = M.get_node_text(node, source)
- if node_text ~= str or str == nil then
- return false
- end
+ local str
+ if type(predicate[3]) == 'string' then
+ -- (#eq? @aa "foo")
+ str = predicate[3]
+ else
+ -- (#eq? @aa @bb)
+ str = M.get_node_text(match[predicate[3]], source)
+ end
- return true
+ if node_text ~= str or str == nil then
+ return false
+ end
+
+ return true
end,
- ["lua-match?"] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- local regex = predicate[3]
- return string.find(M.get_node_text(node, source), regex)
+ ['lua-match?'] = function(match, _, source, predicate)
+ local node = match[predicate[2]]
+ if not node then
+ return true
+ end
+ local regex = predicate[3]
+ return string.find(M.get_node_text(node, source), regex)
end,
- ["match?"] = (function()
- local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true}
+ ['match?'] = (function()
+ local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
---@private
local function check_magic(str)
- if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then
+ if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
return str
end
- return '\\v'..str
+ return '\\v' .. str
end
local compiled_vim_regexes = setmetatable({}, {
@@ -243,21 +266,27 @@ local predicate_handlers = {
local res = vim.regex(check_magic(pattern))
rawset(t, pattern, res)
return res
- end
+ end,
})
return function(match, _, source, pred)
local node = match[pred[2]]
+ if not node then
+ return true
+ end
local regex = compiled_vim_regexes[pred[3]]
return regex:match_str(M.get_node_text(node, source))
end
end)(),
- ["contains?"] = function(match, _, source, predicate)
+ ['contains?'] = function(match, _, source, predicate)
local node = match[predicate[2]]
+ if not node then
+ return true
+ end
local node_text = M.get_node_text(node, source)
- for i=3,#predicate do
+ for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then
return true
end
@@ -266,19 +295,22 @@ local predicate_handlers = {
return false
end,
- ["any-of?"] = function(match, _, source, predicate)
+ ['any-of?'] = function(match, _, source, predicate)
local node = match[predicate[2]]
+ if not node then
+ return true
+ end
local node_text = M.get_node_text(node, source)
-- Since 'predicate' will not be used by callers of this function, use it
-- to store a string set built from the list of words to check against.
- local string_set = predicate["string_set"]
+ local string_set = predicate['string_set']
if not string_set then
string_set = {}
- for i=3,#predicate do
+ for i = 3, #predicate do
string_set[predicate[i]] = true
end
- predicate["string_set"] = string_set
+ predicate['string_set'] = string_set
end
return string_set[node_text]
@@ -286,32 +318,33 @@ local predicate_handlers = {
}
-- As we provide lua-match? also expose vim-match?
-predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-
+predicate_handlers['vim-match?'] = predicate_handlers['match?']
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
-- (match, pattern, bufnr, predicate, metadata)
local directive_handlers = {
- ["set!"] = function(_, _, _, pred, metadata)
+ ['set!'] = function(_, _, _, pred, metadata)
if #pred == 4 then
-- (#set! @capture "key" "value")
- local capture = pred[2]
- if not metadata[capture] then
- metadata[capture] = {}
+ local _, capture_id, key, value = unpack(pred)
+ if not metadata[capture_id] then
+ metadata[capture_id] = {}
end
- metadata[capture][pred[3]] = pred[4]
+ metadata[capture_id][key] = value
else
+ local _, key, value = unpack(pred)
-- (#set! "key" "value")
- metadata[pred[2]] = pred[3]
+ metadata[key] = value
end
end,
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
- ["offset!"] = function(match, _, _, pred, metadata)
- local offset_node = match[pred[2]]
- local range = {offset_node:range()}
+ ['offset!'] = function(match, _, _, pred, metadata)
+ local capture_id = pred[2]
+ local offset_node = match[capture_id]
+ local range = { offset_node:range() }
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
@@ -324,9 +357,12 @@ local directive_handlers = {
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
- metadata.content = {range}
+ if not metadata[capture_id] then
+ metadata[capture_id] = {}
+ end
+ metadata[capture_id].range = range
end
- end
+ end,
}
--- Adds a new predicate to be used in queries
@@ -336,7 +372,7 @@ local directive_handlers = {
--- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
- error(string.format("Overriding %s", name))
+ error(string.format('Overriding %s', name))
end
predicate_handlers[name] = handler
@@ -344,17 +380,23 @@ end
--- Adds a new directive to be used in queries
---
+--- Handlers can set match level data by setting directly on the
+--- metadata object `metadata.key = value`, additionally, handlers
+--- can set node level data by using the capture id on the
+--- metadata table `metadata[capture_id].key = value`
+---
---@param name the name of the directive, without leading #
---@param handler the handler function to be used
---- signature will be (match, pattern, bufnr, predicate)
+--- signature will be (match, pattern, bufnr, predicate, metadata)
function M.add_directive(name, handler, force)
if directive_handlers[name] and not force then
- error(string.format("Overriding %s", name))
+ error(string.format('Overriding %s', name))
end
directive_handlers[name] = handler
end
+--- Lists the currently available directives to use in queries.
---@return The list of supported directives.
function M.list_directives()
return vim.tbl_keys(directive_handlers)
@@ -372,7 +414,7 @@ end
---@private
local function is_directive(name)
- return string.sub(name, -1) == "!"
+ return string.sub(name, -1) == '!'
end
---@private
@@ -389,7 +431,7 @@ function Query:match_preds(match, pattern, source)
-- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
- if string.sub(pred[1], 1, 4) == "not-" then
+ if string.sub(pred[1], 1, 4) == 'not-' then
pred_name = string.sub(pred[1], 5)
is_not = true
else
@@ -400,7 +442,7 @@ function Query:match_preds(match, pattern, source)
local handler = predicate_handlers[pred_name]
if not handler then
- error(string.format("No handler for %s", pred[1]))
+ error(string.format('No handler for %s', pred[1]))
return false
end
@@ -423,7 +465,7 @@ function Query:apply_directives(match, pattern, source, metadata)
local handler = directive_handlers[pred[1]]
if not handler then
- error(string.format("No handler for %s", pred[1]))
+ error(string.format('No handler for %s', pred[1]))
return
end
@@ -432,7 +474,6 @@ function Query:apply_directives(match, pattern, source, metadata)
end
end
-
--- Returns the start and stop value if set else the node's range.
-- When the node's range is used, the stop is incremented by 1
-- to make the search inclusive.
@@ -477,7 +518,7 @@ end
---@returns The matching capture id
---@returns The captured node
function Query:iter_captures(node, source, start, stop)
- if type(source) == "number" and source == 0 then
+ if type(source) == 'number' and source == 0 then
source = vim.api.nvim_get_current_buf()
end
@@ -534,7 +575,7 @@ end
---@returns The matching pattern id
---@returns The matching match
function Query:iter_matches(node, source, start, stop)
- if type(source) == "number" and source == 0 then
+ if type(source) == 'number' and source == 0 then
source = vim.api.nvim_get_current_buf()
end