aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/query.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r--runtime/lua/vim/treesitter/query.lua81
1 files changed, 45 insertions, 36 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 2b3b9096a6..01fdb708eb 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -762,16 +762,7 @@ end
---@private
---@param match TSQueryMatch
---@param source integer|string
-function Query:match_preds(match, source)
- local _, pattern = match:info()
- local preds = self.info.patterns[pattern]
-
- if not preds then
- return true
- end
-
- local captures = match:captures()
-
+function Query:match_preds(preds, pattern, captures, source)
for _, pred in pairs(preds) do
-- Here we only want to return if a predicate DOES NOT match, and
-- continue on the other case. This way unknown predicates will not be considered,
@@ -807,17 +798,9 @@ end
---@private
---@param match TSQueryMatch
---@return vim.treesitter.query.TSMetadata metadata
-function Query:apply_directives(match, source)
+function Query:apply_directives(preds, pattern, captures, source)
---@type vim.treesitter.query.TSMetadata
local metadata = {}
- local _, pattern = match:info()
- local preds = self.info.patterns[pattern]
-
- if not preds then
- return metadata
- end
-
- local captures = match:captures()
for _, pred in pairs(preds) do
if is_directive(pred[1]) then
@@ -902,8 +885,10 @@ function Query:iter_captures(node, source, start, stop)
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
- local apply_directives = memoize(match_id_hash, self.apply_directives, false)
- local match_preds = memoize(match_id_hash, self.match_preds, false)
+ -- For faster checks that a match is not in the cache.
+ local highest_cached_match_id = -1
+ ---@type table<integer, vim.treesitter.query.TSMetadata>
+ local match_cache = {}
local function iter(end_line)
local capture, captured_node, match = cursor:next_capture()
@@ -912,16 +897,35 @@ function Query:iter_captures(node, source, start, stop)
return
end
- if not match_preds(self, match, source) then
- local match_id = match:info()
- cursor:remove_match(match_id)
- if end_line and captured_node:range() > end_line then
- return nil, captured_node, nil, nil
- end
- return iter(end_line) -- tail call: try next match
+ local match_id, pattern = match:info()
+
+ --- @type vim.treesitter.query.TSMetadata
+ local metadata
+ if match_id <= highest_cached_match_id then
+ metadata = match_cache[match_id]
end
- local metadata = apply_directives(self, match, source)
+ if not metadata then
+ local preds = self.info.patterns[pattern]
+ if preds then
+ local captures = match:captures()
+
+ if not self:match_preds(preds, pattern, captures, source) then
+ cursor:remove_match(match_id)
+ if end_line and captured_node:range() > end_line then
+ return nil, captured_node, nil, nil
+ end
+ return iter(end_line) -- tail call: try next match
+ end
+
+ metadata = self:apply_directives(preds, pattern, captures, source)
+ else
+ metadata = {}
+ end
+
+ highest_cached_match_id = math.max(highest_cached_match_id, match_id)
+ match_cache[match_id] = metadata
+ end
return capture, captured_node, metadata, match
end
@@ -985,16 +989,21 @@ function Query:iter_matches(node, source, start, stop, opts)
end
local match_id, pattern = match:info()
+ local preds = self.info.patterns[pattern]
+ local captures = match:captures()
- if not self:match_preds(match, source) then
- cursor:remove_match(match_id)
- return iter() -- tail call: try next match
+ --- @type vim.treesitter.query.TSMetadata
+ local metadata
+ if preds then
+ if not self:match_preds(preds, pattern, captures, source) then
+ cursor:remove_match(match_id)
+ return iter() -- tail call: try next match
+ end
+ metadata = self:apply_directives(preds, pattern, captures, source)
+ else
+ metadata = {}
end
- local metadata = self:apply_directives(match, source)
-
- local captures = match:captures()
-
if opts.all == false then
-- Convert the match table into the old buggy version for backward
-- compatibility. This is slow, but we only do it when the caller explicitly opted into it by