aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r--runtime/lua/vim/treesitter/health.lua14
-rw-r--r--runtime/lua/vim/treesitter/highlighter.lua195
-rw-r--r--runtime/lua/vim/treesitter/language.lua8
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua177
-rw-r--r--runtime/lua/vim/treesitter/query.lua195
5 files changed, 358 insertions, 231 deletions
diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua
index 53ccc6e88d..3bd59ca282 100644
--- a/runtime/lua/vim/treesitter/health.lua
+++ b/runtime/lua/vim/treesitter/health.lua
@@ -15,24 +15,24 @@ function M.check()
local report_error = vim.fn['health#report_error']
local parsers = M.list_parsers()
- report_info(string.format("Runtime ABI version : %d", ts.language_version))
+ report_info(string.format('Runtime ABI version : %d', ts.language_version))
for _, parser in pairs(parsers) do
- local parsername = vim.fn.fnamemodify(parser, ":t:r")
+ local parsername = vim.fn.fnamemodify(parser, ':t:r')
local is_loadable, ret = pcall(ts.language.require_language, parsername)
if not is_loadable then
- report_error(string.format("Impossible to load parser for %s: %s", parsername, ret))
+ report_error(string.format('Impossible to load parser for %s: %s', parsername, ret))
elseif ret then
local lang = ts.language.inspect_language(parsername)
- report_ok(string.format("Loaded parser for %s: ABI version %d",
- parsername, lang._abi_version))
+ report_ok(
+ string.format('Loaded parser for %s: ABI version %d', parsername, lang._abi_version)
+ )
else
- report_error(string.format("Unable to load parser for %s", parsername))
+ report_error(string.format('Unable to load parser for %s', parsername))
end
end
end
return M
-
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua
index 22b528838c..e27a5fa9c3 100644
--- a/runtime/lua/vim/treesitter/highlighter.lua
+++ b/runtime/lua/vim/treesitter/highlighter.lua
@@ -1,5 +1,5 @@
local a = vim.api
-local query = require"vim.treesitter.query"
+local query = require('vim.treesitter.query')
-- support reload for quick experimentation
local TSHighlighter = rawget(vim.treesitter, 'TSHighlighter') or {}
@@ -10,64 +10,88 @@ TSHighlighter.active = TSHighlighter.active or {}
local TSHighlighterQuery = {}
TSHighlighterQuery.__index = TSHighlighterQuery
-local ns = a.nvim_create_namespace("treesitter/highlighter")
+local ns = a.nvim_create_namespace('treesitter/highlighter')
local _default_highlights = {}
local _link_default_highlight_once = function(from, to)
if not _default_highlights[from] then
_default_highlights[from] = true
- vim.cmd(string.format("highlight default link %s %s", from, to))
+ a.nvim_set_hl(0, from, { link = to, default = true })
end
return from
end
-TSHighlighter.hl_map = {
- ["error"] = "Error",
-
--- Miscs
- ["comment"] = "Comment",
- ["punctuation.delimiter"] = "Delimiter",
- ["punctuation.bracket"] = "Delimiter",
- ["punctuation.special"] = "Delimiter",
-
--- Constants
- ["constant"] = "Constant",
- ["constant.builtin"] = "Special",
- ["constant.macro"] = "Define",
- ["string"] = "String",
- ["string.regex"] = "String",
- ["string.escape"] = "SpecialChar",
- ["character"] = "Character",
- ["number"] = "Number",
- ["boolean"] = "Boolean",
- ["float"] = "Float",
-
--- Functions
- ["function"] = "Function",
- ["function.special"] = "Function",
- ["function.builtin"] = "Special",
- ["function.macro"] = "Macro",
- ["parameter"] = "Identifier",
- ["method"] = "Function",
- ["field"] = "Identifier",
- ["property"] = "Identifier",
- ["constructor"] = "Special",
-
--- Keywords
- ["conditional"] = "Conditional",
- ["repeat"] = "Repeat",
- ["label"] = "Label",
- ["operator"] = "Operator",
- ["keyword"] = "Keyword",
- ["exception"] = "Exception",
-
- ["type"] = "Type",
- ["type.builtin"] = "Type",
- ["structure"] = "Structure",
- ["include"] = "Include",
+-- If @definition.special does not exist use @definition instead
+local subcapture_fallback = {
+ __index = function(self, capture)
+ local rtn
+ local shortened = capture
+ while not rtn and shortened do
+ shortened = shortened:match('(.*)%.')
+ rtn = shortened and rawget(self, shortened)
+ end
+ rawset(self, capture, rtn or '__notfound')
+ return rtn
+ end,
}
+TSHighlighter.hl_map = setmetatable({
+ ['error'] = 'Error',
+ ['text.underline'] = 'Underlined',
+ ['todo'] = 'Todo',
+ ['debug'] = 'Debug',
+
+ -- Miscs
+ ['comment'] = 'Comment',
+ ['punctuation.delimiter'] = 'Delimiter',
+ ['punctuation.bracket'] = 'Delimiter',
+ ['punctuation.special'] = 'Delimiter',
+
+ -- Constants
+ ['constant'] = 'Constant',
+ ['constant.builtin'] = 'Special',
+ ['constant.macro'] = 'Define',
+ ['define'] = 'Define',
+ ['macro'] = 'Macro',
+ ['string'] = 'String',
+ ['string.regex'] = 'String',
+ ['string.escape'] = 'SpecialChar',
+ ['character'] = 'Character',
+ ['character.special'] = 'SpecialChar',
+ ['number'] = 'Number',
+ ['boolean'] = 'Boolean',
+ ['float'] = 'Float',
+
+ -- Functions
+ ['function'] = 'Function',
+ ['function.special'] = 'Function',
+ ['function.builtin'] = 'Special',
+ ['function.macro'] = 'Macro',
+ ['parameter'] = 'Identifier',
+ ['method'] = 'Function',
+ ['field'] = 'Identifier',
+ ['property'] = 'Identifier',
+ ['constructor'] = 'Special',
+
+ -- Keywords
+ ['conditional'] = 'Conditional',
+ ['repeat'] = 'Repeat',
+ ['label'] = 'Label',
+ ['operator'] = 'Operator',
+ ['keyword'] = 'Keyword',
+ ['exception'] = 'Exception',
+
+ ['type'] = 'Type',
+ ['type.builtin'] = 'Type',
+ ['type.qualifier'] = 'Type',
+ ['type.definition'] = 'Typedef',
+ ['storageclass'] = 'StorageClass',
+ ['structure'] = 'Structure',
+ ['include'] = 'Include',
+ ['preproc'] = 'PreProc',
+}, subcapture_fallback)
+
---@private
local function is_highlight_name(capture_name)
local firstc = string.sub(capture_name, 1, 1)
@@ -89,13 +113,13 @@ function TSHighlighterQuery.new(lang, query_string)
rawset(table, capture, id)
return id
- end
+ end,
})
if query_string then
self._query = query.parse_query(lang, query_string)
else
- self._query = query.get_query(lang, "highlights")
+ self._query = query.get_query(lang, 'highlights')
end
return self
@@ -128,17 +152,23 @@ end
function TSHighlighter.new(tree, opts)
local self = setmetatable({}, TSHighlighter)
- if type(tree:source()) ~= "number" then
- error("TSHighlighter can not be used with a string parser source.")
+ if type(tree:source()) ~= 'number' then
+ error('TSHighlighter can not be used with a string parser source.')
end
opts = opts or {}
self.tree = tree
- tree:register_cbs {
- on_changedtree = function(...) self:on_changedtree(...) end;
- on_bytes = function(...) self:on_bytes(...) end;
- on_detach = function(...) self:on_detach(...) end;
- }
+ tree:register_cbs({
+ on_changedtree = function(...)
+ self:on_changedtree(...)
+ end,
+ on_bytes = function(...)
+ self:on_bytes(...)
+ end,
+ on_detach = function(...)
+ self:on_detach(...)
+ end,
+ })
self.bufnr = tree:source()
self.edit_count = 0
@@ -157,7 +187,7 @@ function TSHighlighter.new(tree, opts)
end
end
- a.nvim_buf_set_option(self.bufnr, "syntax", "")
+ a.nvim_buf_set_option(self.bufnr, 'syntax', '')
TSHighlighter.active[self.bufnr] = self
@@ -166,7 +196,7 @@ function TSHighlighter.new(tree, opts)
-- syntax FileType autocmds. Later on we should integrate with the
-- `:syntax` and `set syntax=...` machinery properly.
if vim.g.syntax_on ~= 1 then
- vim.api.nvim_command("runtime! syntax/synload.vim")
+ vim.api.nvim_command('runtime! syntax/synload.vim')
end
self.tree:parse()
@@ -186,7 +216,7 @@ function TSHighlighter:get_highlight_state(tstree)
if not self._highlight_states[tstree] then
self._highlight_states[tstree] = {
next_row = 0,
- iter = nil
+ iter = nil,
}
end
@@ -211,7 +241,7 @@ end
---@private
function TSHighlighter:on_changedtree(changes)
for _, ch in ipairs(changes or {}) do
- a.nvim__buf_redraw_range(self.bufnr, ch[1], ch[3]+1)
+ a.nvim__buf_redraw_range(self.bufnr, ch[1], ch[3] + 1)
end
end
@@ -229,39 +259,50 @@ end
---@private
local function on_line_impl(self, buf, line)
self.tree:for_each_tree(function(tstree, tree)
- if not tstree then return end
+ if not tstree then
+ return
+ end
local root_node = tstree:root()
local root_start_row, _, root_end_row, _ = root_node:range()
-- Only worry about trees within the line range
- if root_start_row > line or root_end_row < line then return end
+ if root_start_row > line or root_end_row < line then
+ return
+ end
local state = self:get_highlight_state(tstree)
local highlighter_query = self:get_query(tree:lang())
-- Some injected languages may not have highlight queries.
- if not highlighter_query:query() then return end
+ if not highlighter_query:query() then
+ return
+ end
- if state.iter == nil then
- state.iter = highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
+ if state.iter == nil or state.next_row < line then
+ state.iter =
+ highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
end
while line >= state.next_row do
local capture, node, metadata = state.iter()
- if capture == nil then break end
+ if capture == nil then
+ break
+ end
local start_row, start_col, end_row, end_col = node:range()
local hl = highlighter_query.hl_cache[capture]
if hl and end_row >= line then
- a.nvim_buf_set_extmark(buf, ns, start_row, start_col,
- { end_line = end_row, end_col = end_col,
- hl_group = hl,
- ephemeral = true,
- priority = tonumber(metadata.priority) or 100 -- Low but leaves room below
- })
+ a.nvim_buf_set_extmark(buf, ns, start_row, start_col, {
+ end_line = end_row,
+ end_col = end_col,
+ hl_group = hl,
+ ephemeral = true,
+ priority = tonumber(metadata.priority) or 100, -- Low but leaves room below
+ conceal = metadata.conceal,
+ })
end
if start_row > line then
state.next_row = start_row
@@ -273,7 +314,9 @@ end
---@private
function TSHighlighter._on_line(_, _win, buf, line, _)
local self = TSHighlighter.active[buf]
- if not self then return end
+ if not self then
+ return
+ end
on_line_impl(self, buf, line)
end
@@ -299,9 +342,9 @@ function TSHighlighter._on_win(_, _win, buf, _topline)
end
a.nvim_set_decoration_provider(ns, {
- on_buf = TSHighlighter._on_buf;
- on_win = TSHighlighter._on_win;
- on_line = TSHighlighter._on_line;
+ on_buf = TSHighlighter._on_buf,
+ on_win = TSHighlighter._on_win,
+ on_line = TSHighlighter._on_line,
})
return TSHighlighter
diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua
index 6f347ff25f..dfb6f5be84 100644
--- a/runtime/lua/vim/treesitter/language.lua
+++ b/runtime/lua/vim/treesitter/language.lua
@@ -14,7 +14,7 @@ function M.require_language(lang, path, silent)
return true
end
if path == nil then
- local fname = 'parser/' .. lang .. '.*'
+ local fname = 'parser/' .. vim.fn.fnameescape(lang) .. '.*'
local paths = a.nvim_get_runtime_file(fname, false)
if #paths == 0 then
if silent then
@@ -22,13 +22,15 @@ function M.require_language(lang, path, silent)
end
-- TODO(bfredl): help tag?
- error("no parser for '"..lang.."' language, see :help treesitter-parsers")
+ error("no parser for '" .. lang .. "' language, see :help treesitter-parsers")
end
path = paths[1]
end
if silent then
- return pcall(function() vim._ts_add_language(path, lang) end)
+ return pcall(function()
+ vim._ts_add_language(path, lang)
+ end)
else
vim._ts_add_language(path, lang)
end
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index 85fd5cd8e0..4d3b0631a2 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -1,6 +1,6 @@
local a = vim.api
-local query = require'vim.treesitter.query'
-local language = require'vim.treesitter.language'
+local query = require('vim.treesitter.query')
+local language = require('vim.treesitter.language')
local LanguageTree = {}
LanguageTree.__index = LanguageTree
@@ -32,9 +32,8 @@ function LanguageTree.new(source, lang, opts)
_regions = {},
_trees = {},
_opts = opts,
- _injection_query = injections[lang]
- and query.parse_query(lang, injections[lang])
- or query.get_query(lang, "injections"),
+ _injection_query = injections[lang] and query.parse_query(lang, injections[lang])
+ or query.get_query(lang, 'injections'),
_valid = false,
_parser = vim._create_ts_parser(lang),
_callbacks = {
@@ -42,11 +41,10 @@ function LanguageTree.new(source, lang, opts)
bytes = {},
detach = {},
child_added = {},
- child_removed = {}
+ child_removed = {},
},
}, LanguageTree)
-
return self
end
@@ -76,8 +74,8 @@ function LanguageTree:lang()
end
--- Determines whether this tree is valid.
---- If the tree is invalid, `parse()` must be called
---- to get the updated tree.
+--- If the tree is invalid, call `parse()`.
+--- This will return the updated tree.
function LanguageTree:is_valid()
return self._valid
end
@@ -234,7 +232,9 @@ end
--- Destroys this language tree and all its children.
---
--- Any cleanup logic should be performed here.
---- Note, this DOES NOT remove this tree from a parent.
+---
+--- Note:
+--- This DOES NOT remove this tree from a parent. Instead,
--- `remove_child` must be called on the parent to remove it.
function LanguageTree:destroy()
-- Cleanup here
@@ -259,22 +259,27 @@ end
---
--- Note, this call invalidates the tree and requires it to be parsed again.
---
----@param regions A list of regions this tree should manage and parse.
+---@param regions (table) list of regions this tree should manage and parse.
function LanguageTree:set_included_regions(regions)
- -- TODO(vigoux): I don't think string parsers are useful for now
- if type(self._source) == "number" then
- -- Transform the tables from 4 element long to 6 element long (with byte offset)
- for _, region in ipairs(regions) do
- for i, range in ipairs(region) do
- if type(range) == "table" and #range == 4 then
- local start_row, start_col, end_row, end_col = unpack(range)
+ -- Transform the tables from 4 element long to 6 element long (with byte offset)
+ for _, region in ipairs(regions) do
+ for i, range in ipairs(region) do
+ if type(range) == 'table' and #range == 4 then
+ local start_row, start_col, end_row, end_col = unpack(range)
+ local start_byte = 0
+ local end_byte = 0
+ -- TODO(vigoux): proper byte computation here, and account for EOL ?
+ if type(self._source) == 'number' then
-- Easy case, this is a buffer parser
- -- TODO(vigoux): proper byte computation here, and account for EOL ?
- local start_byte = a.nvim_buf_get_offset(self._source, start_row) + start_col
- local end_byte = a.nvim_buf_get_offset(self._source, end_row) + end_col
-
- region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
+ start_byte = a.nvim_buf_get_offset(self._source, start_row) + start_col
+ end_byte = a.nvim_buf_get_offset(self._source, end_row) + end_col
+ elseif type(self._source) == 'string' then
+ -- string parser, single `\n` delimited string
+ start_byte = vim.fn.byteidx(self._source, start_col)
+ end_byte = vim.fn.byteidx(self._source, end_col)
end
+
+ region[i] = { start_row, start_col, start_byte, end_row, end_col, end_byte }
end
end
end
@@ -293,6 +298,14 @@ function LanguageTree:included_regions()
return self._regions
end
+---@private
+local function get_node_range(node, id, metadata)
+ if metadata[id] and metadata[id].range then
+ return metadata[id].range
+ end
+ return { node:range() }
+end
+
--- Gets language injection points by language.
---
--- This is where most of the injection processing occurs.
@@ -301,7 +314,9 @@ end
--- instead of using the entire nodes range.
---@private
function LanguageTree:_get_injections()
- if not self._injection_query then return {} end
+ if not self._injection_query then
+ return {}
+ end
local injections = {}
@@ -309,7 +324,9 @@ function LanguageTree:_get_injections()
local root_node = tree:root()
local start_line, _, end_line, _ = root_node:range()
- for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
+ for pattern, match, metadata in
+ self._injection_query:iter_matches(root_node, self._source, start_line, end_line + 1)
+ do
local lang = nil
local ranges = {}
local combined = metadata.combined
@@ -320,11 +337,11 @@ function LanguageTree:_get_injections()
local content = metadata.content
-- Allow for captured nodes to be used
- if type(content) == "number" then
- content = {match[content]}
+ if type(content) == 'number' then
+ content = { match[content]:range() }
end
- if content then
+ if type(content) == 'table' and #content >= 4 then
vim.list_extend(ranges, content)
end
end
@@ -340,21 +357,21 @@ function LanguageTree:_get_injections()
local name = self._injection_query.captures[id]
-- Lang should override any other language tag
- if name == "language" and not lang then
+ if name == 'language' and not lang then
lang = query.get_node_text(node, self._source)
- elseif name == "combined" then
+ elseif name == 'combined' then
combined = true
- elseif name == "content" and #ranges == 0 then
- table.insert(ranges, node)
- -- Ignore any tags that start with "_"
- -- Allows for other tags to be used in matches
- elseif string.sub(name, 1, 1) ~= "_" then
+ elseif name == 'content' and #ranges == 0 then
+ table.insert(ranges, get_node_range(node, id, metadata))
+ -- Ignore any tags that start with "_"
+ -- Allows for other tags to be used in matches
+ elseif string.sub(name, 1, 1) ~= '_' then
if not lang then
lang = name
end
if #ranges == 0 then
- table.insert(ranges, node)
+ table.insert(ranges, get_node_range(node, id, metadata))
end
end
end
@@ -391,7 +408,10 @@ function LanguageTree:_get_injections()
for _, entry in pairs(patterns) do
if entry.combined then
- table.insert(result[lang], vim.tbl_flatten(entry.regions))
+ local regions = vim.tbl_map(function(e)
+ return vim.tbl_flatten(e)
+ end, entry.regions)
+ table.insert(result[lang], regions)
else
for _, ranges in ipairs(entry.regions) do
table.insert(result[lang], ranges)
@@ -412,10 +432,19 @@ function LanguageTree:_do_callback(cb_name, ...)
end
---@private
-function LanguageTree:_on_bytes(bufnr, changed_tick,
- start_row, start_col, start_byte,
- old_row, old_col, old_byte,
- new_row, new_col, new_byte)
+function LanguageTree:_on_bytes(
+ bufnr,
+ changed_tick,
+ start_row,
+ start_col,
+ start_byte,
+ old_row,
+ old_col,
+ old_byte,
+ new_row,
+ new_col,
+ new_byte
+)
self:invalidate()
local old_end_col = old_col + ((old_row == 0) and start_col or 0)
@@ -424,16 +453,33 @@ function LanguageTree:_on_bytes(bufnr, changed_tick,
-- Edit all trees recursively, together BEFORE emitting a bytes callback.
-- In most cases this callback should only be called from the root tree.
self:for_each_tree(function(tree)
- tree:edit(start_byte,start_byte+old_byte,start_byte+new_byte,
- start_row, start_col,
- start_row+old_row, old_end_col,
- start_row+new_row, new_end_col)
+ tree:edit(
+ start_byte,
+ start_byte + old_byte,
+ start_byte + new_byte,
+ start_row,
+ start_col,
+ start_row + old_row,
+ old_end_col,
+ start_row + new_row,
+ new_end_col
+ )
end)
- self:_do_callback('bytes', bufnr, changed_tick,
- start_row, start_col, start_byte,
- old_row, old_col, old_byte,
- new_row, new_col, new_byte)
+ self:_do_callback(
+ 'bytes',
+ bufnr,
+ changed_tick,
+ start_row,
+ start_col,
+ start_byte,
+ old_row,
+ old_col,
+ old_byte,
+ new_row,
+ new_col,
+ new_byte
+ )
end
---@private
@@ -441,23 +487,24 @@ function LanguageTree:_on_reload()
self:invalidate(true)
end
-
---@private
function LanguageTree:_on_detach(...)
self:invalidate(true)
self:_do_callback('detach', ...)
end
---- Registers callbacks for the parser
----@param cbs An `nvim_buf_attach`-like table argument with the following keys :
---- `on_bytes` : see `nvim_buf_attach`, but this will be called _after_ the parsers callback.
---- `on_changedtree` : a callback that will be called every time the tree has syntactical changes.
---- it will only be passed one argument, that is a table of the ranges (as node ranges) that
---- changed.
---- `on_child_added` : emitted when a child is added to the tree.
---- `on_child_removed` : emitted when a child is removed from the tree.
+--- Registers callbacks for the parser.
+---@param cbs table An |nvim_buf_attach()|-like table argument with the following keys :
+--- - `on_bytes` : see |nvim_buf_attach()|, but this will be called _after_ the parsers callback.
+--- - `on_changedtree` : a callback that will be called every time the tree has syntactical changes.
+--- It will only be passed one argument, which is a table of the ranges (as node ranges) that
+--- changed.
+--- - `on_child_added` : emitted when a child is added to the tree.
+--- - `on_child_removed` : emitted when a child is removed from the tree.
function LanguageTree:register_cbs(cbs)
- if not cbs then return end
+ if not cbs then
+ return
+ end
if cbs.on_changedtree then
table.insert(self._callbacks.changedtree, cbs.on_changedtree)
@@ -486,16 +533,10 @@ local function tree_contains(tree, range)
local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2])
local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4])
- if start_fits and end_fits then
- return true
- end
-
- return false
+ return start_fits and end_fits
end
---- Determines whether @param range is contained in this language tree
----
---- This goes down the tree to recursively check children.
+--- Determines whether {range} is contained in this language tree
---
---@param range A range, that is a `{ start_line, start_col, end_line, end_col }` table.
function LanguageTree:contains(range)
@@ -508,7 +549,7 @@ function LanguageTree:contains(range)
return false
end
---- Gets the appropriate language that contains @param range
+--- Gets the appropriate language that contains {range}
---
---@param range A text range, see |LanguageTree:contains|
function LanguageTree:language_for_range(range)
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index ebed502c92..103e85abfd 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -1,5 +1,5 @@
local a = vim.api
-local language = require'vim.treesitter.language'
+local language = require('vim.treesitter.language')
-- query: pattern matching on trees
-- predicate matching is implemented in lua
@@ -43,7 +43,9 @@ function M.get_query_files(lang, query_name, is_included)
local query_path = string.format('queries/%s/%s.scm', lang, query_name)
local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true))
- if #lang_files == 0 then return {} end
+ if #lang_files == 0 then
+ return {}
+ end
local base_langs = {}
@@ -52,7 +54,7 @@ function M.get_query_files(lang, query_name, is_included)
-- ;+ inherits: ({language},)*{language}
--
-- {language} ::= {lang} | ({lang})
- local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
+ local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
for _, file in ipairs(lang_files) do
local modeline = safe_read(file, '*l')
@@ -62,7 +64,7 @@ function M.get_query_files(lang, query_name, is_included)
if langlist then
for _, incllang in ipairs(vim.split(langlist, ',', true)) do
- local is_optional = incllang:match("%(.*%)")
+ local is_optional = incllang:match('%(.*%)')
if is_optional then
if not is_included then
@@ -90,7 +92,7 @@ end
local function read_query_files(filenames)
local contents = {}
- for _,filename in ipairs(filenames) do
+ for _, filename in ipairs(filenames) do
table.insert(contents, safe_read(filename, '*a'))
end
@@ -138,30 +140,43 @@ function M.get_query(lang, query_name)
end
end
+local query_cache = setmetatable({}, {
+ __index = function(tbl, key)
+ rawset(tbl, key, {})
+ return rawget(tbl, key)
+ end,
+})
+
--- Parse {query} as a string. (If the query is in a file, the caller
---- should read the contents into a string before calling).
+--- should read the contents into a string before calling).
---
--- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
--- search nodes in the syntax tree for the patterns defined in {query}
--- using `iter_*` methods below.
---
---- Exposes `info` and `captures` with additional information about the {query}.
+--- Exposes `info` and `captures` with additional context about {query}.
--- - `captures` contains the list of unique capture names defined in
--- {query}.
--- -` info.captures` also points to `captures`.
--- - `info.patterns` contains information about predicates.
---
----@param lang The language
----@param query A string containing the query (s-expr syntax)
+---@param lang string The language
+---@param query string A string containing the query (s-expr syntax)
---
---@returns The query
function M.parse_query(lang, query)
language.require_language(lang)
- local self = setmetatable({}, Query)
- self.query = vim._ts_parse_query(lang, query)
- self.info = self.query:inspect()
- self.captures = self.info.captures
- return self
+ local cached = query_cache[lang][query]
+ if cached then
+ return cached
+ else
+ local self = setmetatable({}, Query)
+ self.query = vim._ts_parse_query(lang, query)
+ self.info = self.query:inspect()
+ self.captures = self.info.captures
+ query_cache[lang][query] = self
+ return self
+ end
end
--- Gets the text corresponding to a given node
@@ -172,7 +187,7 @@ function M.get_node_text(node, source)
local start_row, start_col, start_byte = node:start()
local end_row, end_col, end_byte = node:end_()
- if type(source) == "number" then
+ if type(source) == 'number' then
local lines
local eof_row = a.nvim_buf_line_count(source)
if start_row >= eof_row then
@@ -186,56 +201,64 @@ function M.get_node_text(node, source)
lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true)
end
- if #lines == 1 then
- lines[1] = string.sub(lines[1], start_col+1, end_col)
- else
- lines[1] = string.sub(lines[1], start_col+1)
- lines[#lines] = string.sub(lines[#lines], 1, end_col)
+ if #lines > 0 then
+ if #lines == 1 then
+ lines[1] = string.sub(lines[1], start_col + 1, end_col)
+ else
+ lines[1] = string.sub(lines[1], start_col + 1)
+ lines[#lines] = string.sub(lines[#lines], 1, end_col)
+ end
end
- return table.concat(lines, "\n")
- elseif type(source) == "string" then
- return source:sub(start_byte+1, end_byte)
+ return table.concat(lines, '\n')
+ elseif type(source) == 'string' then
+ return source:sub(start_byte + 1, end_byte)
end
end
-- Predicate handler receive the following arguments
-- (match, pattern, bufnr, predicate)
local predicate_handlers = {
- ["eq?"] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- local node_text = M.get_node_text(node, source)
-
- local str
- if type(predicate[3]) == "string" then
- -- (#eq? @aa "foo")
- str = predicate[3]
- else
- -- (#eq? @aa @bb)
- str = M.get_node_text(match[predicate[3]], source)
- end
+ ['eq?'] = function(match, _, source, predicate)
+ local node = match[predicate[2]]
+ if not node then
+ return true
+ end
+ local node_text = M.get_node_text(node, source)
- if node_text ~= str or str == nil then
- return false
- end
+ local str
+ if type(predicate[3]) == 'string' then
+ -- (#eq? @aa "foo")
+ str = predicate[3]
+ else
+ -- (#eq? @aa @bb)
+ str = M.get_node_text(match[predicate[3]], source)
+ end
- return true
+ if node_text ~= str or str == nil then
+ return false
+ end
+
+ return true
end,
- ["lua-match?"] = function(match, _, source, predicate)
- local node = match[predicate[2]]
- local regex = predicate[3]
- return string.find(M.get_node_text(node, source), regex)
+ ['lua-match?'] = function(match, _, source, predicate)
+ local node = match[predicate[2]]
+ if not node then
+ return true
+ end
+ local regex = predicate[3]
+ return string.find(M.get_node_text(node, source), regex)
end,
- ["match?"] = (function()
- local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true}
+ ['match?'] = (function()
+ local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
---@private
local function check_magic(str)
- if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then
+ if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
return str
end
- return '\\v'..str
+ return '\\v' .. str
end
local compiled_vim_regexes = setmetatable({}, {
@@ -243,21 +266,27 @@ local predicate_handlers = {
local res = vim.regex(check_magic(pattern))
rawset(t, pattern, res)
return res
- end
+ end,
})
return function(match, _, source, pred)
local node = match[pred[2]]
+ if not node then
+ return true
+ end
local regex = compiled_vim_regexes[pred[3]]
return regex:match_str(M.get_node_text(node, source))
end
end)(),
- ["contains?"] = function(match, _, source, predicate)
+ ['contains?'] = function(match, _, source, predicate)
local node = match[predicate[2]]
+ if not node then
+ return true
+ end
local node_text = M.get_node_text(node, source)
- for i=3,#predicate do
+ for i = 3, #predicate do
if string.find(node_text, predicate[i], 1, true) then
return true
end
@@ -266,19 +295,22 @@ local predicate_handlers = {
return false
end,
- ["any-of?"] = function(match, _, source, predicate)
+ ['any-of?'] = function(match, _, source, predicate)
local node = match[predicate[2]]
+ if not node then
+ return true
+ end
local node_text = M.get_node_text(node, source)
-- Since 'predicate' will not be used by callers of this function, use it
-- to store a string set built from the list of words to check against.
- local string_set = predicate["string_set"]
+ local string_set = predicate['string_set']
if not string_set then
string_set = {}
- for i=3,#predicate do
+ for i = 3, #predicate do
string_set[predicate[i]] = true
end
- predicate["string_set"] = string_set
+ predicate['string_set'] = string_set
end
return string_set[node_text]
@@ -286,32 +318,33 @@ local predicate_handlers = {
}
-- As we provide lua-match? also expose vim-match?
-predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-
+predicate_handlers['vim-match?'] = predicate_handlers['match?']
-- Directives store metadata or perform side effects against a match.
-- Directives should always end with a `!`.
-- Directive handler receive the following arguments
-- (match, pattern, bufnr, predicate, metadata)
local directive_handlers = {
- ["set!"] = function(_, _, _, pred, metadata)
+ ['set!'] = function(_, _, _, pred, metadata)
if #pred == 4 then
-- (#set! @capture "key" "value")
- local capture = pred[2]
- if not metadata[capture] then
- metadata[capture] = {}
+ local _, capture_id, key, value = unpack(pred)
+ if not metadata[capture_id] then
+ metadata[capture_id] = {}
end
- metadata[capture][pred[3]] = pred[4]
+ metadata[capture_id][key] = value
else
+ local _, key, value = unpack(pred)
-- (#set! "key" "value")
- metadata[pred[2]] = pred[3]
+ metadata[key] = value
end
end,
-- Shifts the range of a node.
-- Example: (#offset! @_node 0 1 0 -1)
- ["offset!"] = function(match, _, _, pred, metadata)
- local offset_node = match[pred[2]]
- local range = {offset_node:range()}
+ ['offset!'] = function(match, _, _, pred, metadata)
+ local capture_id = pred[2]
+ local offset_node = match[capture_id]
+ local range = { offset_node:range() }
local start_row_offset = pred[3] or 0
local start_col_offset = pred[4] or 0
local end_row_offset = pred[5] or 0
@@ -324,9 +357,12 @@ local directive_handlers = {
-- If this produces an invalid range, we just skip it.
if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
- metadata.content = {range}
+ if not metadata[capture_id] then
+ metadata[capture_id] = {}
+ end
+ metadata[capture_id].range = range
end
- end
+ end,
}
--- Adds a new predicate to be used in queries
@@ -336,7 +372,7 @@ local directive_handlers = {
--- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
- error(string.format("Overriding %s", name))
+ error(string.format('Overriding %s', name))
end
predicate_handlers[name] = handler
@@ -344,17 +380,23 @@ end
--- Adds a new directive to be used in queries
---
+--- Handlers can set match level data by setting directly on the
+--- metadata object `metadata.key = value`, additionally, handlers
+--- can set node level data by using the capture id on the
+--- metadata table `metadata[capture_id].key = value`
+---
---@param name the name of the directive, without leading #
---@param handler the handler function to be used
---- signature will be (match, pattern, bufnr, predicate)
+--- signature will be (match, pattern, bufnr, predicate, metadata)
function M.add_directive(name, handler, force)
if directive_handlers[name] and not force then
- error(string.format("Overriding %s", name))
+ error(string.format('Overriding %s', name))
end
directive_handlers[name] = handler
end
+--- Lists the currently available directives to use in queries.
---@return The list of supported directives.
function M.list_directives()
return vim.tbl_keys(directive_handlers)
@@ -372,7 +414,7 @@ end
---@private
local function is_directive(name)
- return string.sub(name, -1) == "!"
+ return string.sub(name, -1) == '!'
end
---@private
@@ -389,7 +431,7 @@ function Query:match_preds(match, pattern, source)
-- Skip over directives... they will get processed after all the predicates.
if not is_directive(pred[1]) then
- if string.sub(pred[1], 1, 4) == "not-" then
+ if string.sub(pred[1], 1, 4) == 'not-' then
pred_name = string.sub(pred[1], 5)
is_not = true
else
@@ -400,7 +442,7 @@ function Query:match_preds(match, pattern, source)
local handler = predicate_handlers[pred_name]
if not handler then
- error(string.format("No handler for %s", pred[1]))
+ error(string.format('No handler for %s', pred[1]))
return false
end
@@ -423,7 +465,7 @@ function Query:apply_directives(match, pattern, source, metadata)
local handler = directive_handlers[pred[1]]
if not handler then
- error(string.format("No handler for %s", pred[1]))
+ error(string.format('No handler for %s', pred[1]))
return
end
@@ -432,7 +474,6 @@ function Query:apply_directives(match, pattern, source, metadata)
end
end
-
--- Returns the start and stop value if set else the node's range.
-- When the node's range is used, the stop is incremented by 1
-- to make the search inclusive.
@@ -477,7 +518,7 @@ end
---@returns The matching capture id
---@returns The captured node
function Query:iter_captures(node, source, start, stop)
- if type(source) == "number" and source == 0 then
+ if type(source) == 'number' and source == 0 then
source = vim.api.nvim_get_current_buf()
end
@@ -534,7 +575,7 @@ end
---@returns The matching pattern id
---@returns The matching match
function Query:iter_matches(node, source, start, stop)
- if type(source) == "number" and source == 0 then
+ if type(source) == 'number' and source == 0 then
source = vim.api.nvim_get_current_buf()
end