aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua86
1 files changed, 61 insertions, 25 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index e49f54681d..ed5146be44 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -8,10 +8,33 @@ Query.__index = Query
local M = {}
+local function dedupe_files(files)
+ local result = {}
+ local seen = {}
+
+ for _, path in ipairs(files) do
+ if not seen[path] then
+ table.insert(result, path)
+ seen[path] = true
+ end
+ end
+
+ return result
+end
+
+local function safe_read(filename, read_quantifier)
+ local file, err = io.open(filename, 'r')
+ if not file then
+ error(err)
+ end
+ local content = file:read(read_quantifier)
+ io.close(file)
+ return content
+end
function M.get_query_files(lang, query_name, is_included)
local query_path = string.format('queries/%s/%s.scm', lang, query_name)
- local lang_files = a.nvim_get_runtime_file(query_path, true)
+ local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true))
if #lang_files == 0 then return {} end
@@ -25,7 +48,7 @@ function M.get_query_files(lang, query_name, is_included)
local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
for _, file in ipairs(lang_files) do
- local modeline = io.open(file, 'r'):read('*l')
+ local modeline = safe_read(file, '*l')
if modeline then
local langlist = modeline:match(MODELINE_FORMAT)
@@ -60,21 +83,31 @@ local function read_query_files(filenames)
local contents = {}
for _,filename in ipairs(filenames) do
- table.insert(contents, io.open(filename, 'r'):read('*a'))
+ table.insert(contents, safe_read(filename, '*a'))
end
return table.concat(contents, '')
end
-local match_metatable = {
- __index = function(tbl, key)
- rawset(tbl, key, {})
- return tbl[key]
- end
-}
+--- The explicitly set queries from |vim.treesitter.query.set_query()|
+local explicit_queries = setmetatable({}, {
+ __index = function(t, k)
+ local lang_queries = {}
+ rawset(t, k, lang_queries)
-local function new_match_metadata()
- return setmetatable({}, match_metatable)
+ return lang_queries
+ end,
+})
+
+--- Sets the runtime query {query_name} for {lang}
+---
+--- This allows users to override any runtime files and/or configuration
+--- set by plugins.
+---@param lang string: The language to use for the query
+---@param query_name string: The name of the query (i.e. "highlights")
+---@param text string: The query text (unparsed).
+function M.set_query(lang, query_name, text)
+ explicit_queries[lang][query_name] = M.parse_query(lang, text)
end
--- Returns the runtime query {query_name} for {lang}.
@@ -84,6 +117,10 @@ end
--
-- @return The corresponding query, parsed.
function M.get_query(lang, query_name)
+ if explicit_queries[lang][query_name] then
+ return explicit_queries[lang][query_name]
+ end
+
local query_files = M.get_query_files(lang, query_name)
local query_string = read_query_files(query_files)
@@ -111,7 +148,7 @@ end
--- Gets the text corresponding to a given node
-- @param node the node
--- @param bufnr the buffer from which the node in extracted.
+-- @param bufnr the buffer from which the node is extracted.
function M.get_node_text(node, source)
local start_row, start_col, start_byte = node:start()
local end_row, end_col, end_byte = node:end_()
@@ -211,14 +248,14 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
--- (match, pattern, bufnr, predicate)
+-- (match, pattern, bufnr, predicate, metadata)
local directive_handlers = {
["set!"] = function(_, _, _, pred, metadata)
if #pred == 4 then
- -- (set! @capture "key" "value")
+ -- (#set! @capture "key" "value")
metadata[pred[2]][pred[3]] = pred[4]
else
- -- (set! "key" "value")
+ -- (#set! "key" "value")
metadata[pred[2]] = pred[3]
end
end,
@@ -231,7 +268,6 @@ local directive_handlers = {
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
local end_col_offset = pred[6] or 0
- local key = pred[7] or "offset"
range[1] = range[1] + start_row_offset
range[2] = range[2] + start_col_offset
@@ -240,12 +276,12 @@ local directive_handlers = {
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
- metadata[pred[2]][key] = range
+ metadata.content = {range}
end
end
}
---- Adds a new predicates to be used in queries
+--- Adds a new predicate to be used in queries
--
-- @param name the name of the predicate, without leading #
-- @param handler the handler function to be used
@@ -355,10 +391,10 @@ end
--- Iterates of the captures of self on a given range.
--
--- @param node The node under witch the search will occur
+-- @param node The node under which the search will occur
-- @param buffer The source buffer to search
-- @param start The starting line of the search
--- @param stop The stoping line of the search (end-exclusive)
+-- @param stop The stopping line of the search (end-exclusive)
--
-- @returns The matching capture id
-- @returns The captured node
@@ -372,7 +408,7 @@ function Query:iter_captures(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, true, start, stop)
local function iter()
local capture, captured_node, match = raw_iter()
- local metadata = new_match_metadata()
+ local metadata = {}
if match ~= nil then
local active = self:match_preds(match, match.pattern, source)
@@ -388,12 +424,12 @@ function Query:iter_captures(node, source, start, stop)
return iter
end
---- Iterates of the matches of self on a given range.
+--- Iterates the matches of self on a given range.
--
--- @param node The node under witch the search will occur
+-- @param node The node under which the search will occur
-- @param buffer The source buffer to search
-- @param start The starting line of the search
--- @param stop The stoping line of the search (end-exclusive)
+-- @param stop The stopping line of the search (end-exclusive)
--
-- @returns The matching pattern id
-- @returns The matching match
@@ -407,7 +443,7 @@ function Query:iter_matches(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, false, start, stop)
local function iter()
local pattern, match = raw_iter()
- local metadata = new_match_metadata()
+ local metadata = {}
if match ~= nil then
local active = self:match_preds(match, pattern, source)