aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua223
1 files changed, 166 insertions, 57 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 8b94348994..c0140f9186 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -8,10 +8,40 @@ Query.__index = Query
local M = {}
+---@private
+local function dedupe_files(files)
+ local result = {}
+ local seen = {}
+
+ for _, path in ipairs(files) do
+ if not seen[path] then
+ table.insert(result, path)
+ seen[path] = true
+ end
+ end
+
+ return result
+end
+
+---@private
+local function safe_read(filename, read_quantifier)
+ local file, err = io.open(filename, 'r')
+ if not file then
+ error(err)
+ end
+ local content = file:read(read_quantifier)
+ io.close(file)
+ return content
+end
+--- Gets the list of files used to make up a query
+---
+--- @param lang The language
+--- @param query_name The name of the query to load
+--- @param is_included Internal parameter, most of the time left as `nil`
function M.get_query_files(lang, query_name, is_included)
local query_path = string.format('queries/%s/%s.scm', lang, query_name)
- local lang_files = a.nvim_get_runtime_file(query_path, true)
+ local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true))
if #lang_files == 0 then return {} end
@@ -25,7 +55,7 @@ function M.get_query_files(lang, query_name, is_included)
local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
for _, file in ipairs(lang_files) do
- local modeline = io.open(file, 'r'):read('*l')
+ local modeline = safe_read(file, '*l')
if modeline then
local langlist = modeline:match(MODELINE_FORMAT)
@@ -56,34 +86,50 @@ function M.get_query_files(lang, query_name, is_included)
return query_files
end
+---@private
local function read_query_files(filenames)
local contents = {}
for _,filename in ipairs(filenames) do
- table.insert(contents, io.open(filename, 'r'):read('*a'))
+ table.insert(contents, safe_read(filename, '*a'))
end
return table.concat(contents, '')
end
-local match_metatable = {
- __index = function(tbl, key)
- rawset(tbl, key, {})
- return tbl[key]
- end
-}
+--- The explicitly set queries from |vim.treesitter.query.set_query()|
+local explicit_queries = setmetatable({}, {
+ __index = function(t, k)
+ local lang_queries = {}
+ rawset(t, k, lang_queries)
-local function new_match_metadata()
- return setmetatable({}, match_metatable)
+ return lang_queries
+ end,
+})
+
+--- Sets the runtime query {query_name} for {lang}
+---
+--- This allows users to override any runtime files and/or configuration
+--- set by plugins.
+---
+--- @param lang string: The language to use for the query
+--- @param query_name string: The name of the query (i.e. "highlights")
+--- @param text string: The query text (unparsed).
+function M.set_query(lang, query_name, text)
+ explicit_queries[lang][query_name] = M.parse_query(lang, text)
end
--- Returns the runtime query {query_name} for {lang}.
---
--- @param lang The language to use for the query
--- @param query_name The name of the query (i.e. "highlights")
---
--- @return The corresponding query, parsed.
+---
+--- @param lang The language to use for the query
+--- @param query_name The name of the query (i.e. "highlights")
+---
+--- @return The corresponding query, parsed.
function M.get_query(lang, query_name)
+ if explicit_queries[lang][query_name] then
+ return explicit_queries[lang][query_name]
+ end
+
local query_files = M.get_query_files(lang, query_name)
local query_string = read_query_files(query_files)
@@ -92,12 +138,23 @@ function M.get_query(lang, query_name)
end
end
---- Parses a query.
---
--- @param language The language
--- @param query A string containing the query (s-expr syntax)
---
--- @returns The query
+--- Parse {query} as a string. (If the query is in a file, the caller
+--- should read the contents into a string before calling).
+---
+--- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
+--- search nodes in the syntax tree for the patterns defined in {query}
+--- using `iter_*` methods below.
+---
+--- Exposes `info` and `captures` with additional information about the {query}.
+--- - `captures` contains the list of unique capture names defined in
+--- {query}.
+--- - `info.captures` also points to `captures`.
+--- - `info.patterns` contains information about predicates.
+---
+--- @param lang The language
+--- @param query A string containing the query (s-expr syntax)
+---
+--- @returns The query
function M.parse_query(lang, query)
language.require_language(lang)
local self = setmetatable({}, Query)
@@ -110,8 +167,9 @@ end
-- TODO(vigoux): support multiline nodes too
--- Gets the text corresponding to a given node
--- @param node the node
--- @param bufnr the buffer from which the node is extracted.
+---
+--- @param node the node
+--- @param source The buffer or string from which the node is extracted
function M.get_node_text(node, source)
local start_row, start_col, start_byte = node:start()
local end_row, end_col, end_byte = node:end_()
@@ -163,6 +221,7 @@ local predicate_handlers = {
["match?"] = (function()
local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true}
+ ---@private
local function check_magic(str)
if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then
return str
@@ -172,7 +231,7 @@ local predicate_handlers = {
local compiled_vim_regexes = setmetatable({}, {
__index = function(t, pattern)
- local res = vim.regex(check_magic(vim.fn.escape(pattern, '\\')))
+ local res = vim.regex(check_magic(pattern))
rawset(t, pattern, res)
return res
end
@@ -211,12 +270,16 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
--- (match, pattern, bufnr, predicate)
+-- (match, pattern, bufnr, predicate, metadata)
local directive_handlers = {
["set!"] = function(_, _, _, pred, metadata)
if #pred == 4 then
-- (#set! @capture "key" "value")
- metadata[pred[2]][pred[3]] = pred[4]
+ local capture = pred[2]
+ if not metadata[capture] then
+ metadata[capture] = {}
+ end
+ metadata[capture][pred[3]] = pred[4]
else
-- (#set! "key" "value")
metadata[pred[2]] = pred[3]
@@ -231,7 +294,6 @@ local directive_handlers = {
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
local end_col_offset = pred[6] or 0
- local key = pred[7] or "offset"
range[1] = range[1] + start_row_offset
range[2] = range[2] + start_col_offset
@@ -240,16 +302,16 @@ local directive_handlers = {
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
- metadata[pred[2]][key] = range
+ metadata.content = {range}
end
end
}
--- Adds a new predicate to be used in queries
---
--- @param name the name of the predicate, without leading #
--- @param handler the handler function to be used
--- signature will be (match, pattern, bufnr, predicate)
+---
+--- @param name the name of the predicate, without leading #
+--- @param handler the handler function to be used
+--- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
error(string.format("Overriding %s", name))
@@ -259,10 +321,10 @@ function M.add_predicate(name, handler, force)
end
--- Adds a new directive to be used in queries
---
--- @param name the name of the directive, without leading #
--- @param handler the handler function to be used
--- signature will be (match, pattern, bufnr, predicate)
+---
+--- @param name the name of the directive, without leading #
+--- @param handler the handler function to be used
+--- signature will be (match, pattern, bufnr, predicate, metadata)
function M.add_directive(name, handler, force)
if directive_handlers[name] and not force then
error(string.format("Overriding %s", name))
@@ -276,14 +338,17 @@ function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
end
+---@private
local function xor(x, y)
return (x or y) and not (x and y)
end
+---@private
local function is_directive(name)
return string.sub(name, -1) == "!"
end
+---@private
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -322,7 +387,7 @@ function Query:match_preds(match, pattern, source)
return true
end
---- Applies directives against a match and pattern.
+---@private
function Query:apply_directives(match, pattern, source, metadata)
local preds = self.info.patterns[pattern]
@@ -344,6 +409,7 @@ end
--- Returns the start and stop value if set else the node's range.
-- When the node's range is used, the stop is incremented by 1
-- to make the search inclusive.
+---@private
local function value_or_node_range(start, stop, node)
if start == nil and stop == nil then
local node_start, _, node_stop, _ = node:range()
@@ -353,15 +419,36 @@ local function value_or_node_range(start, stop, node)
return start, stop
end
---- Iterates of the captures of self on a given range.
---
--- @param node The node under which the search will occur
--- @param buffer The source buffer to search
--- @param start The starting line of the search
--- @param stop The stopping line of the search (end-exclusive)
---
--- @returns The matching capture id
--- @returns The captured node
+--- Iterate over all captures from all matches inside {node}
+---
+--- {source} is needed if the query contains predicates; then the caller
+--- must ensure to use a freshly parsed tree consistent with the current
+--- text of the buffer (if relevant). {start} and {stop} can be used to limit
+--- matches inside a row range (this is typically used with root node
+--- as the node, i.e., to get syntax highlight matches in the current
+--- viewport). When omitted, the start and stop row values are taken from the given node.
+---
+--- The iterator returns three values, a numeric id identifying the capture,
+--- the captured node, and metadata from any directives processing the match.
+--- The following example shows how to get captures by name:
+---
+--- <pre>
+--- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do
+--- local name = query.captures[id] -- name of the capture in the query
+--- -- typically useful info about the node:
+--- local type = node:type() -- type of the captured node
+--- local row1, col1, row2, col2 = node:range() -- range of the capture
+--- ... use the info here ...
+--- end
+--- </pre>
+---
+--- @param node The node under which the search will occur
+--- @param source The source buffer or string to extract text from
+--- @param start The starting line of the search
+--- @param stop The stopping line of the search (end-exclusive)
+---
+--- @returns The matching capture id
+--- @returns The captured node
function Query:iter_captures(node, source, start, stop)
if type(source) == "number" and source == 0 then
source = vim.api.nvim_get_current_buf()
@@ -370,9 +457,10 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
local raw_iter = node:_rawquery(self.query, true, start, stop)
+ ---@private
local function iter()
local capture, captured_node, match = raw_iter()
- local metadata = new_match_metadata()
+ local metadata = {}
if match ~= nil then
local active = self:match_preds(match, match.pattern, source)
@@ -389,14 +477,35 @@ function Query:iter_captures(node, source, start, stop)
end
--- Iterates the matches of self on a given range.
---
--- @param node The node under which the search will occur
--- @param buffer The source buffer to search
--- @param start The starting line of the search
--- @param stop The stopping line of the search (end-exclusive)
---
--- @returns The matching pattern id
--- @returns The matching match
+---
+--- Iterate over all matches within a node. The arguments are the same as
+--- for |query:iter_captures()| but the iterated values are different:
+--- a (1-based) index of the pattern in the query, a table mapping
+--- capture indices to nodes, and metadata from any directives processing the match.
+--- If the query has more than one pattern, the capture table might be sparse,
+--- so `pairs()` should be used to iterate it rather than `ipairs()`.
+--- Here is an example iterating over all captures in every match:
+---
+--- <pre>
+--- for pattern, match, metadata in query:iter_matches(tree:root(), bufnr, first, last) do
+--- for id, node in pairs(match) do
+--- local name = query.captures[id]
+--- -- `node` was captured by the `name` capture in the match
+---
+--- local node_data = metadata[id] -- Node level metadata
+---
+--- ... use the info here ...
+--- end
+--- end
+--- </pre>
+---
+--- @param node The node under which the search will occur
+--- @param source The source buffer or string to search
+--- @param start The starting line of the search
+--- @param stop The stopping line of the search (end-exclusive)
+---
+--- @returns The matching pattern id
+--- @returns The matching match
function Query:iter_matches(node, source, start, stop)
if type(source) == "number" and source == 0 then
source = vim.api.nvim_get_current_buf()
@@ -407,7 +516,7 @@ function Query:iter_matches(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, false, start, stop)
local function iter()
local pattern, match = raw_iter()
- local metadata = new_match_metadata()
+ local metadata = {}
if match ~= nil then
local active = self:match_preds(match, pattern, source)