aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorvanaigr <vanaigranov@gmail.com>2024-12-18 01:06:41 -0600
committervanaigr <vanaigranov@gmail.com>2025-01-06 00:35:19 -0600
commitdd234135ad20119917831fd8ffcb19d8562022ca (patch)
tree9c26bb172ead60b9bdab2d847525edb3c4b34881
parent8d2ee542a82a0d162198f27de316ddfc81e8761c (diff)
downloadrneovim-dd234135ad20119917831fd8ffcb19d8562022ca.tar.gz
rneovim-dd234135ad20119917831fd8ffcb19d8562022ca.tar.bz2
rneovim-dd234135ad20119917831fd8ffcb19d8562022ca.zip
refactor: split predicates and directives
-rw-r--r--runtime/lua/vim/treesitter/highlighter.lua4
-rw-r--r--runtime/lua/vim/treesitter/query.lua172
-rw-r--r--test/benchmark/decor_spec.lua1
-rw-r--r--test/functional/treesitter/query_spec.lua4
4 files changed, 108 insertions, 73 deletions
diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua
index 8ce8652f7d..96503c38ea 100644
--- a/runtime/lua/vim/treesitter/highlighter.lua
+++ b/runtime/lua/vim/treesitter/highlighter.lua
@@ -299,6 +299,8 @@ local function on_line_impl(self, buf, line, is_spell_nav)
state.highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1)
end
+ local captures = state.highlighter_query:query().captures
+
while line >= state.next_row do
local capture, node, metadata, match = state.iter(line)
@@ -311,7 +313,7 @@ local function on_line_impl(self, buf, line, is_spell_nav)
if capture then
local hl = state.highlighter_query:get_hl_from_capture(capture)
- local capture_name = state.highlighter_query:query().captures[capture]
+ local capture_name = captures[capture]
local spell, spell_pri_offset = get_spell(capture_name)
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 01fdb708eb..1fc001b39f 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -7,6 +7,59 @@ local memoize = vim.func._memoize
local M = {}
+local function is_directive(name)
+ return string.sub(name, -1) == '!'
+end
+
+---@nodoc
+---@class vim.treesitter.query.ProcessedPredicate
+---@field [1] string predicate name
+---@field [2] boolean should match
+---@field [3] (integer|string)[] the original predicate
+
+---@alias vim.treesitter.query.ProcessedDirective (integer|string)[]
+
+---@nodoc
+---@class vim.treesitter.query.ProcessedPattern {
+---@field predicates vim.treesitter.query.ProcessedPredicate[]
+---@field directives vim.treesitter.query.ProcessedDirective[]
+
+--- Splits the query patterns into predicates and directives.
+---@param patterns table<integer, (integer|string)[][]>
+---@return table<integer, vim.treesitter.query.ProcessedPattern>
+local function process_patterns(patterns)
+ ---@type table<integer, vim.treesitter.query.ProcessedPattern>
+ local processed_patterns = {}
+
+ for k, pattern_list in pairs(patterns) do
+ ---@type vim.treesitter.query.ProcessedPredicate[]
+ local predicates = {}
+ ---@type vim.treesitter.query.ProcessedDirective[]
+ local directives = {}
+
+ for _, pattern in ipairs(pattern_list) do
+ -- Note: tree-sitter strips the leading # from predicates for us.
+ local pred_name = pattern[1]
+ ---@cast pred_name string
+
+ if is_directive(pred_name) then
+ table.insert(directives, pattern)
+ else
+ local should_match = true
+ if pred_name:match('^not%-') then
+ pred_name = pred_name:sub(5)
+ should_match = false
+ end
+ table.insert(predicates, { pred_name, should_match, pattern })
+ end
+ end
+
+ processed_patterns[k] = { predicates = predicates, directives = directives }
+ end
+
+ return processed_patterns
+end
+
---@nodoc
---Parsed query, see |vim.treesitter.query.parse()|
---
@@ -15,6 +68,7 @@ local M = {}
---@field captures string[] list of (unique) capture names defined in query
---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
---@field query TSQuery userdata query object
+---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern>
local Query = {}
Query.__index = Query
@@ -33,6 +87,7 @@ function Query.new(lang, ts_query)
patterns = query_info.patterns,
}
self.captures = self.info.captures
+ self._processed_patterns = process_patterns(self.info.patterns)
return self
end
@@ -751,67 +806,50 @@ function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
end
-local function xor(x, y)
- return (x or y) and not (x and y)
-end
-
-local function is_directive(name)
- return string.sub(name, -1) == '!'
-end
-
---@private
----@param match TSQueryMatch
+---@param pattern_i integer
+---@param predicates vim.treesitter.query.ProcessedPredicate[]
+---@param captures table<integer, TSNode[]>
---@param source integer|string
-function Query:match_preds(preds, pattern, captures, source)
- for _, pred in pairs(preds) do
- -- Here we only want to return if a predicate DOES NOT match, and
- -- continue on the other case. This way unknown predicates will not be considered,
- -- which allows some testing and easier user extensibility (#12173).
- -- Also, tree-sitter strips the leading # from predicates for us.
- local is_not = false
-
- -- Skip over directives... they will get processed after all the predicates.
- if not is_directive(pred[1]) then
- local pred_name = pred[1]
- if pred_name:match('^not%-') then
- pred_name = pred_name:sub(5)
- is_not = true
- end
-
- local handler = predicate_handlers[pred_name]
-
- if not handler then
- error(string.format('No handler for %s', pred[1]))
- return false
- end
-
- local pred_matches = handler(captures, pattern, source, pred)
+---@return boolean whether the predicates match
+function Query:_match_predicates(predicates, pattern_i, captures, source)
+ for _, predicate in ipairs(predicates) do
+ local processed_name = predicate[1]
+ local should_match = predicate[2]
+ local orig_predicate = predicate[3]
+
+ local handler = predicate_handlers[processed_name]
+ if not handler then
+ error(string.format('No handler for %s', orig_predicate[1]))
+ return false
+ end
- if not xor(is_not, pred_matches) then
- return false
- end
+ local does_match = handler(captures, pattern_i, source, orig_predicate)
+ if does_match ~= should_match then
+ return false
end
end
return true
end
---@private
----@param match TSQueryMatch
+---@param pattern_i integer
+---@param directives vim.treesitter.query.ProcessedDirective[]
+---@param source integer|string
+---@param captures table<integer, TSNode[]>
---@return vim.treesitter.query.TSMetadata metadata
-function Query:apply_directives(preds, pattern, captures, source)
+function Query:_apply_directives(directives, pattern_i, captures, source)
---@type vim.treesitter.query.TSMetadata
local metadata = {}
- for _, pred in pairs(preds) do
- if is_directive(pred[1]) then
- local handler = directive_handlers[pred[1]]
+ for _, directive in pairs(directives) do
+ local handler = directive_handlers[directive[1]]
- if not handler then
- error(string.format('No handler for %s', pred[1]))
- end
-
- handler(captures, pattern, source, pred, metadata)
+ if not handler then
+ error(string.format('No handler for %s', directive[1]))
end
+
+ handler(captures, pattern_i, source, directive, metadata)
end
return metadata
@@ -835,12 +873,6 @@ local function value_or_node_range(start, stop, node)
return start, stop
end
---- @param match TSQueryMatch
---- @return integer
-local function match_id_hash(_, match)
- return (match:info())
-end
-
--- Iterates over all captures from all matches in {node}.
---
--- {source} is required if the query contains predicates; then the caller
@@ -897,7 +929,7 @@ function Query:iter_captures(node, source, start, stop)
return
end
- local match_id, pattern = match:info()
+ local match_id, pattern_i = match:info()
--- @type vim.treesitter.query.TSMetadata
local metadata
@@ -906,11 +938,14 @@ function Query:iter_captures(node, source, start, stop)
end
if not metadata then
- local preds = self.info.patterns[pattern]
- if preds then
+ metadata = {}
+
+ local processed_pattern = self._processed_patterns[pattern_i]
+ if processed_pattern then
local captures = match:captures()
- if not self:match_preds(preds, pattern, captures, source) then
+ local predicates = processed_pattern.predicates
+ if not self:_match_predicates(predicates, pattern_i, captures, source) then
cursor:remove_match(match_id)
if end_line and captured_node:range() > end_line then
return nil, captured_node, nil, nil
@@ -918,9 +953,8 @@ function Query:iter_captures(node, source, start, stop)
return iter(end_line) -- tail call: try next match
end
- metadata = self:apply_directives(preds, pattern, captures, source)
- else
- metadata = {}
+ local directives = processed_pattern.directives
+ metadata = self:_apply_directives(directives, pattern_i, captures, source)
end
highest_cached_match_id = math.max(highest_cached_match_id, match_id)
@@ -988,20 +1022,20 @@ function Query:iter_matches(node, source, start, stop, opts)
return
end
- local match_id, pattern = match:info()
- local preds = self.info.patterns[pattern]
+ local match_id, pattern_i = match:info()
+ local processed_pattern = self._processed_patterns[pattern_i]
local captures = match:captures()
--- @type vim.treesitter.query.TSMetadata
- local metadata
- if preds then
- if not self:match_preds(preds, pattern, captures, source) then
+ local metadata = {}
+ if processed_pattern then
+ local predicates = processed_pattern.predicates
+ if not self:_match_predicates(predicates, pattern_i, captures, source) then
cursor:remove_match(match_id)
return iter() -- tail call: try next match
end
- metadata = self:apply_directives(preds, pattern, captures, source)
- else
- metadata = {}
+ local directives = processed_pattern.directives
+ metadata = self:_apply_directives(directives, pattern_i, captures, source)
end
if opts.all == false then
@@ -1012,11 +1046,11 @@ function Query:iter_matches(node, source, start, stop, opts)
for k, v in pairs(captures or {}) do
old_match[k] = v[#v]
end
- return pattern, old_match, metadata
+ return pattern_i, old_match, metadata
end
-- TODO(lewis6991): create a new function that returns {match, metadata}
- return pattern, captures, metadata
+ return pattern_i, captures, metadata
end
return iter
end
diff --git a/test/benchmark/decor_spec.lua b/test/benchmark/decor_spec.lua
index 42b2d1e744..1b7e763a09 100644
--- a/test/benchmark/decor_spec.lua
+++ b/test/benchmark/decor_spec.lua
@@ -99,7 +99,6 @@ describe('decor perf', function()
print('\nTotal ' .. fmt(total) .. '\nDecoration provider: ' .. fmt(provider))
end)
-
it('can handle full screen of highlighting', function()
Screen.new(100, 51)
diff --git a/test/functional/treesitter/query_spec.lua b/test/functional/treesitter/query_spec.lua
index 634f8af83d..6e21ed1d99 100644
--- a/test/functional/treesitter/query_spec.lua
+++ b/test/functional/treesitter/query_spec.lua
@@ -835,9 +835,9 @@ void ui_refresh(void)
local result = exec_lua(function()
local query0 = vim.treesitter.query.parse('c', query)
- local match_preds = query0.match_preds
+ local match_preds = query0._match_predicates
local called = 0
- function query0:match_preds(...)
+ function query0:_match_predicates(...)
called = called + 1
return match_preds(self, ...)
end