aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua178
1 files changed, 124 insertions, 54 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 3537ba78f5..5a27d740a2 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -8,36 +8,10 @@ Query.__index = Query
local M = {}
--- Filter the runtime query files, the spec is like regular runtime files but in the new `queries`
--- directory. They resemble ftplugins, that is that you can override queries by adding things in the
--- `queries` directory, and extend using the `after/queries` directory.
-local function filter_files(file_list)
- local main = nil
- local after = {}
-
- for _, fname in ipairs(file_list) do
- -- Only get the name of the directory containing the queries directory
- if vim.fn.fnamemodify(fname, ":p:h:h:h:t") == "after" then
- table.insert(after, fname)
- -- The first one is the one with most priority
- elseif not main then
- main = fname
- end
- end
-
- return main and { main, unpack(after) } or after
-end
-
-local function runtime_query_path(lang, query_name)
- return string.format('queries/%s/%s.scm', lang, query_name)
-end
-
-local function filtered_runtime_queries(lang, query_name)
- return filter_files(a.nvim_get_runtime_file(runtime_query_path(lang, query_name), true) or {})
-end
-local function get_query_files(lang, query_name, is_included)
- local lang_files = filtered_runtime_queries(lang, query_name)
+function M.get_query_files(lang, query_name, is_included)
+ local query_path = string.format('queries/%s/%s.scm', lang, query_name)
+ local lang_files = a.nvim_get_runtime_file(query_path, true)
if #lang_files == 0 then return {} end
@@ -51,10 +25,10 @@ local function get_query_files(lang, query_name, is_included)
local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
for _, file in ipairs(lang_files) do
- local modeline = vim.fn.readfile(file, "", 1)
+ local modeline = io.open(file, 'r'):read('*l')
- if #modeline == 1 then
- local langlist = modeline[1]:match(MODELINE_FORMAT)
+ if modeline then
+ local langlist = modeline:match(MODELINE_FORMAT)
if langlist then
for _, incllang in ipairs(vim.split(langlist, ',', true)) do
@@ -74,7 +48,7 @@ local function get_query_files(lang, query_name, is_included)
local query_files = {}
for _, base_lang in ipairs(base_langs) do
- local base_files = get_query_files(base_lang, query_name, true)
+ local base_files = M.get_query_files(base_lang, query_name, true)
vim.list_extend(query_files, base_files)
end
vim.list_extend(query_files, lang_files)
@@ -86,10 +60,21 @@ local function read_query_files(filenames)
local contents = {}
for _,filename in ipairs(filenames) do
- vim.list_extend(contents, vim.fn.readfile(filename))
+ table.insert(contents, io.open(filename, 'r'):read('*a'))
end
- return table.concat(contents, '\n')
+ return table.concat(contents, '')
+end
+
+local match_metatable = {
+ __index = function(tbl, key)
+ rawset(tbl, key, {})
+ return tbl[key]
+ end
+}
+
+local function new_match_metadata()
+ return setmetatable({}, match_metatable)
end
--- Returns the runtime query {query_name} for {lang}.
@@ -99,7 +84,7 @@ end
--
-- @return The corresponding query, parsed.
function M.get_query(lang, query_name)
- local query_files = get_query_files(lang, query_name)
+ local query_files = M.get_query_files(lang, query_name)
local query_string = read_query_files(query_files)
if #query_string > 0 then
@@ -222,6 +207,44 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match?
predicate_handlers["vim-match?"] = predicate_handlers["match?"]
+
+-- Directives store metadata or perform side effects against a match.
+-- Directives should always end with a `!`.
+-- Directive handler receive the following arguments
+-- (match, pattern, bufnr, predicate)
+local directive_handlers = {
+ ["set!"] = function(_, _, _, pred, metadata)
+ if #pred == 4 then
+ -- (set! @capture "key" "value")
+ metadata[pred[2]][pred[3]] = pred[4]
+ else
+ -- (set! "key" "value")
+ metadata[pred[2]] = pred[3]
+ end
+ end,
+ -- Shifts the range of a node.
+ -- Example: (#offset! @_node 0 1 0 -1)
+ ["offset!"] = function(match, _, _, pred, metadata)
+ local offset_node = match[pred[2]]
+ local range = {offset_node:range()}
+ local start_row_offset = pred[3] or 0
+ local start_col_offset = pred[4] or 0
+ local end_row_offset = pred[5] or 0
+ local end_col_offset = pred[6] or 0
+ local key = pred[7] or "offset"
+
+ range[1] = range[1] + start_row_offset
+ range[2] = range[2] + start_col_offset
+ range[3] = range[3] + end_row_offset
+ range[4] = range[4] + end_col_offset
+
+ -- If this produces an invalid range, we just skip it.
+ if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
+ metadata[pred[2]][key] = range
+ end
+ end
+}
+
--- Adds a new predicates to be used in queries
--
-- @param name the name of the predicate, without leading #
@@ -229,12 +252,25 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
- a.nvim_err_writeln(string.format("Overriding %s", name))
+ error(string.format("Overriding %s", name))
end
predicate_handlers[name] = handler
end
+--- Adds a new directive to be used in queries
+--
+-- @param name the name of the directive, without leading #
+-- @param handler the handler function to be used
+-- signature will be (match, pattern, bufnr, predicate)
+function M.add_directive(name, handler, force)
+ if directive_handlers[name] and not force then
+ error(string.format("Overriding %s", name))
+ end
+
+ directive_handlers[name] = handler
+end
+
--- Returns the list of currently supported predicates
function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
@@ -244,6 +280,10 @@ local function xor(x, y)
return (x or y) and not (x and y)
end
+local function is_directive(name)
+ return string.sub(name, -1) == "!"
+end
+
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -254,30 +294,52 @@ function Query:match_preds(match, pattern, source)
-- Also, tree-sitter strips the leading # from predicates for us.
local pred_name
local is_not
- if string.sub(pred[1], 1, 4) == "not-" then
- pred_name = string.sub(pred[1], 5)
- is_not = true
- else
- pred_name = pred[1]
- is_not = false
- end
- local handler = predicate_handlers[pred_name]
+ -- Skip over directives... they will get processed after all the predicates.
+ if not is_directive(pred[1]) then
+ if string.sub(pred[1], 1, 4) == "not-" then
+ pred_name = string.sub(pred[1], 5)
+ is_not = true
+ else
+ pred_name = pred[1]
+ is_not = false
+ end
+
+ local handler = predicate_handlers[pred_name]
- if not handler then
- a.nvim_err_writeln(string.format("No handler for %s", pred[1]))
- return false
- end
+ if not handler then
+ error(string.format("No handler for %s", pred[1]))
+ return false
+ end
- local pred_matches = handler(match, pattern, source, pred)
+ local pred_matches = handler(match, pattern, source, pred)
- if not xor(is_not, pred_matches) then
- return false
+ if not xor(is_not, pred_matches) then
+ return false
+ end
end
end
return true
end
+--- Applies directives against a match and pattern.
+function Query:apply_directives(match, pattern, source, metadata)
+ local preds = self.info.patterns[pattern]
+
+ for _, pred in pairs(preds or {}) do
+ if is_directive(pred[1]) then
+ local handler = directive_handlers[pred[1]]
+
+ if not handler then
+ error(string.format("No handler for %s", pred[1]))
+ return
+ end
+
+ handler(match, pattern, source, pred, metadata)
+ end
+ end
+end
+
--- Iterates of the captures of self on a given range.
--
-- @param node The node under witch the search will occur
@@ -294,14 +356,18 @@ function Query:iter_captures(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, true, start, stop)
local function iter()
local capture, captured_node, match = raw_iter()
+ local metadata = new_match_metadata()
+
if match ~= nil then
local active = self:match_preds(match, match.pattern, source)
match.active = active
if not active then
return iter() -- tail call: try next match
end
+
+ self:apply_directives(match, match.pattern, source, metadata)
end
- return capture, captured_node
+ return capture, captured_node, metadata
end
return iter
end
@@ -322,13 +388,17 @@ function Query:iter_matches(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, false, start, stop)
local function iter()
local pattern, match = raw_iter()
+ local metadata = new_match_metadata()
+
if match ~= nil then
local active = self:match_preds(match, pattern, source)
if not active then
return iter() -- tail call: try next match
end
+
+ self:apply_directives(match, pattern, source, metadata)
end
- return pattern, match
+ return pattern, match, metadata
end
return iter
end