aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua79
1 files changed, 61 insertions, 18 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 103e85abfd..90ed2a357c 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -34,6 +34,18 @@ local function safe_read(filename, read_quantifier)
return content
end
+---@private
+--- Adds @p ilang to @p base_langs, only if @p ilang is different than @lang
+---
+---@return boolean true it lang == ilang
+local function add_included_lang(base_langs, lang, ilang)
+ if lang == ilang then
+ return true
+ end
+ table.insert(base_langs, ilang)
+ return false
+end
+
--- Gets the list of files used to make up a query
---
---@param lang The language
@@ -47,6 +59,9 @@ function M.get_query_files(lang, query_name, is_included)
return {}
end
+ local base_query = nil
+ local extensions = {}
+
local base_langs = {}
-- Now get the base languages by looking at the first line of every file
@@ -55,35 +70,61 @@ function M.get_query_files(lang, query_name, is_included)
--
-- {language} ::= {lang} | ({lang})
local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
+ local EXTENDS_FORMAT = '^;+%s*extends%s*$'
- for _, file in ipairs(lang_files) do
- local modeline = safe_read(file, '*l')
+ for _, filename in ipairs(lang_files) do
+ local file, err = io.open(filename, 'r')
+ if not file then
+ error(err)
+ end
- if modeline then
- local langlist = modeline:match(MODELINE_FORMAT)
+ local extension = false
+ for modeline in
+ function()
+ return file:read('*l')
+ end
+ do
+ if not vim.startswith(modeline, ';') then
+ break
+ end
+
+ local langlist = modeline:match(MODELINE_FORMAT)
if langlist then
for _, incllang in ipairs(vim.split(langlist, ',', true)) do
local is_optional = incllang:match('%(.*%)')
if is_optional then
if not is_included then
- table.insert(base_langs, incllang:sub(2, #incllang - 1))
+ if add_included_lang(base_langs, lang, incllang:sub(2, #incllang - 1)) then
+ extension = true
+ end
end
else
- table.insert(base_langs, incllang)
+ if add_included_lang(base_langs, lang, incllang) then
+ extension = true
+ end
end
end
+ elseif modeline:match(EXTENDS_FORMAT) then
+ extension = true
end
end
+
+ if extension then
+ table.insert(extensions, filename)
+ elseif base_query == nil then
+ base_query = filename
+ end
+ io.close(file)
end
- local query_files = {}
+ local query_files = { base_query }
for _, base_lang in ipairs(base_langs) do
local base_files = M.get_query_files(base_lang, query_name, true)
vim.list_extend(query_files, base_files)
end
- vim.list_extend(query_files, lang_files)
+ vim.list_extend(query_files, extensions)
return query_files
end
@@ -140,12 +181,9 @@ function M.get_query(lang, query_name)
end
end
-local query_cache = setmetatable({}, {
- __index = function(tbl, key)
- rawset(tbl, key, {})
- return rawget(tbl, key)
- end,
-})
+local query_cache = vim.defaulttable(function()
+ return setmetatable({}, { __mode = 'v' })
+end)
--- Parse {query} as a string. (If the query is in a file, the caller
--- should read the contents into a string before calling).
@@ -181,9 +219,14 @@ end
--- Gets the text corresponding to a given node
---
----@param node the node
----@param source The buffer or string from which the node is extracted
-function M.get_node_text(node, source)
+---@param node table The node
+---@param source table The buffer or string from which the node is extracted
+---@param opts table Optional parameters.
+--- - concat: (boolean default true) Concatenate result in a string
+function M.get_node_text(node, source, opts)
+ opts = opts or {}
+ local concat = vim.F.if_nil(opts.concat, true)
+
local start_row, start_col, start_byte = node:start()
local end_row, end_col, end_byte = node:end_()
@@ -210,7 +253,7 @@ function M.get_node_text(node, source)
end
end
- return table.concat(lines, '\n')
+ return concat and table.concat(lines, '\n') or lines
elseif type(source) == 'string' then
return source:sub(start_byte + 1, end_byte)
end