aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua54
1 files changed, 51 insertions, 3 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 8b94348994..f40e1d5294 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -8,10 +8,33 @@ Query.__index = Query
local M = {}
+local function dedupe_files(files)
+ local result = {}
+ local seen = {}
+
+ for _, path in ipairs(files) do
+ if not seen[path] then
+ table.insert(result, path)
+ seen[path] = true
+ end
+ end
+
+ return result
+end
+
+local function safe_read(filename, read_quantifier)
+ local file, err = io.open(filename, 'r')
+ if not file then
+ error(err)
+ end
+ local content = file:read(read_quantifier)
+ io.close(file)
+ return content
+end
function M.get_query_files(lang, query_name, is_included)
local query_path = string.format('queries/%s/%s.scm', lang, query_name)
- local lang_files = a.nvim_get_runtime_file(query_path, true)
+ local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true))
if #lang_files == 0 then return {} end
@@ -25,7 +48,7 @@ function M.get_query_files(lang, query_name, is_included)
local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
for _, file in ipairs(lang_files) do
- local modeline = io.open(file, 'r'):read('*l')
+ local modeline = safe_read(file, '*l')
if modeline then
local langlist = modeline:match(MODELINE_FORMAT)
@@ -60,7 +83,7 @@ local function read_query_files(filenames)
local contents = {}
for _,filename in ipairs(filenames) do
- table.insert(contents, io.open(filename, 'r'):read('*a'))
+ table.insert(contents, safe_read(filename, '*a'))
end
return table.concat(contents, '')
@@ -77,6 +100,27 @@ local function new_match_metadata()
return setmetatable({}, match_metatable)
end
+--- The explicitly set queries from |vim.treesitter.query.set_query()|
+local explicit_queries = setmetatable({}, {
+ __index = function(t, k)
+ local lang_queries = {}
+ rawset(t, k, lang_queries)
+
+ return lang_queries
+ end,
+})
+
+--- Sets the runtime query {query_name} for {lang}
+---
+--- This allows users to override any runtime files and/or configuration
+--- set by plugins.
+---@param lang string: The language to use for the query
+---@param query_name string: The name of the query (i.e. "highlights")
+---@param text string: The query text (unparsed).
+function M.set_query(lang, query_name, text)
+ explicit_queries[lang][query_name] = M.parse_query(lang, text)
+end
+
--- Returns the runtime query {query_name} for {lang}.
--
-- @param lang The language to use for the query
@@ -84,6 +128,10 @@ end
--
-- @return The corresponding query, parsed.
function M.get_query(lang, query_name)
+ if explicit_queries[lang][query_name] then
+ return explicit_queries[lang][query_name]
+ end
+
local query_files = M.get_query_files(lang, query_name)
local query_string = read_query_files(query_files)