aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--runtime/doc/treesitter.txt8
-rw-r--r--runtime/lua/vim/treesitter/_meta.lua44
-rw-r--r--runtime/lua/vim/treesitter/_query_linter.lua2
-rw-r--r--runtime/lua/vim/treesitter/query.lua125
-rw-r--r--src/nvim/lua/executor.c3
-rw-r--r--src/nvim/lua/treesitter.c265
6 files changed, 245 insertions, 202 deletions
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt
index e036df5130..a76fa3c123 100644
--- a/runtime/doc/treesitter.txt
+++ b/runtime/doc/treesitter.txt
@@ -1152,6 +1152,10 @@ Query:iter_captures({node}, {source}, {start}, {stop})
end
<
+ Note: ~
+ • Captures are only returned if the query pattern of a specific capture
+ contained predicates.
+
Parameters: ~
• {node} (`TSNode`) under which the search will occur
• {source} (`integer|string`) Source buffer or string to extract text
@@ -1162,7 +1166,7 @@ Query:iter_captures({node}, {source}, {start}, {stop})
Defaults to `node:end_()`.
Return: ~
- (`fun(end_line: integer?): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>`)
+ (`fun(end_line: integer?): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?`)
capture id, capture node, metadata, match
*Query:iter_matches()*
@@ -1206,6 +1210,8 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts})
• max_start_depth (integer) if non-zero, sets the maximum
start depth for each match. This is used to prevent
traversing too deep into a tree.
+ • match_limit (integer) Set the maximum number of
+ in-progress matches (Default: 256).
• all (boolean) When set, the returned match table maps
capture IDs to a list of nodes. Older versions of
iter_matches incorrectly mapped capture IDs to a single
diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua
index 19d97d2820..e2768d4b06 100644
--- a/runtime/lua/vim/treesitter/_meta.lua
+++ b/runtime/lua/vim/treesitter/_meta.lua
@@ -34,22 +34,6 @@ error('Cannot require a meta file')
---@field byte_length fun(self: TSNode): integer
local TSNode = {}
----@param query TSQuery
----@param captures true
----@param start? integer
----@param end_? integer
----@param opts? table
----@return fun(): integer, TSNode, vim.treesitter.query.TSMatch
-function TSNode:_rawquery(query, captures, start, end_, opts) end
-
----@param query TSQuery
----@param captures false
----@param start? integer
----@param end_? integer
----@param opts? table
----@return fun(): integer, vim.treesitter.query.TSMatch
-function TSNode:_rawquery(query, captures, start, end_, opts) end
-
---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string)
---@class TSParser: userdata
@@ -90,3 +74,31 @@ vim._ts_parse_query = function(lang, query) end
---@param lang string
---@return TSParser
vim._create_ts_parser = function(lang) end
+
+--- @class TSQueryMatch: userdata
+--- @field captures fun(self: TSQueryMatch): table<integer,TSNode[]>
+local TSQueryMatch = {}
+
+--- @return integer match_id
+--- @return integer pattern_index
+function TSQueryMatch:info() end
+
+--- @class TSQueryCursor: userdata
+--- @field remove_match fun(self: TSQueryCursor, id: integer)
+local TSQueryCursor = {}
+
+--- @return integer capture
+--- @return TSNode captured_node
+--- @return TSQueryMatch match
+function TSQueryCursor:next_capture() end
+
+--- @return TSQueryMatch match
+function TSQueryCursor:next_match() end
+
+--- @param node TSNode
+--- @param query TSQuery
+--- @param start integer?
+--- @param stop integer?
+--- @param opts? { max_start_depth?: integer, match_limit?: integer}
+--- @return TSQueryCursor
+function vim._create_ts_querycursor(node, query, start, stop, opts) end
diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua
index 6216d4e891..12b4cbc7b9 100644
--- a/runtime/lua/vim/treesitter/_query_linter.lua
+++ b/runtime/lua/vim/treesitter/_query_linter.lua
@@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
end)
--- @param buf integer
---- @param match vim.treesitter.query.TSMatch
+--- @param match table<integer,TSNode[]>
--- @param query vim.treesitter.Query
--- @param lang_context QueryLinterLanguageContext
--- @param diagnostics vim.Diagnostic[]
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 30cd00c617..075fd0e99b 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -258,7 +258,7 @@ end)
--- handling the "any" vs "all" semantics. They are called from the
--- predicate_handlers table with the appropriate arguments for each predicate.
local impl = {
- --- @param match vim.treesitter.query.TSMatch
+ --- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -293,7 +293,7 @@ local impl = {
return not any
end,
- --- @param match vim.treesitter.query.TSMatch
+ --- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -333,7 +333,7 @@ local impl = {
end,
})
- --- @param match vim.treesitter.query.TSMatch
+ --- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -356,7 +356,7 @@ local impl = {
end
end)(),
- --- @param match vim.treesitter.query.TSMatch
+ --- @param match table<integer,TSNode[]>
--- @param source integer|string
--- @param predicate any[]
--- @param any boolean
@@ -383,13 +383,7 @@ local impl = {
end,
}
----@nodoc
----@class vim.treesitter.query.TSMatch
----@field pattern? integer
----@field active? boolean
----@field [integer] TSNode[]
-
----@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean
+---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean
-- Predicate handler receive the following arguments
-- (match, pattern, bufnr, predicate)
@@ -504,7 +498,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
---@field [integer] vim.treesitter.query.TSMetadata
---@field [string] integer|string
----@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
+---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)
-- Predicate handler receive the following arguments
-- (match, pattern, bufnr, predicate)
@@ -726,13 +720,19 @@ local function is_directive(name)
end
---@private
----@param match vim.treesitter.query.TSMatch
----@param pattern integer
+---@param match TSQueryMatch
---@param source integer|string
-function Query:match_preds(match, pattern, source)
+function Query:match_preds(match, source)
+ local _, pattern = match:info()
local preds = self.info.patterns[pattern]
- for _, pred in pairs(preds or {}) do
+ if not preds then
+ return true
+ end
+
+ local captures = match:captures()
+
+ for _, pred in pairs(preds) do
-- Here we only want to return if a predicate DOES NOT match, and
-- continue on the other case. This way unknown predicates will not be considered,
-- which allows some testing and easier user extensibility (#12173).
@@ -754,7 +754,7 @@ function Query:match_preds(match, pattern, source)
return false
end
- local pred_matches = handler(match, pattern, source, pred)
+ local pred_matches = handler(captures, pattern, source, pred)
if not xor(is_not, pred_matches) then
return false
@@ -765,23 +765,33 @@ function Query:match_preds(match, pattern, source)
end
---@private
----@param match vim.treesitter.query.TSMatch
----@param metadata vim.treesitter.query.TSMetadata
-function Query:apply_directives(match, pattern, source, metadata)
+---@param match TSQueryMatch
+---@return vim.treesitter.query.TSMetadata metadata
+function Query:apply_directives(match, source)
+ ---@type vim.treesitter.query.TSMetadata
+ local metadata = {}
+ local _, pattern = match:info()
local preds = self.info.patterns[pattern]
- for _, pred in pairs(preds or {}) do
+ if not preds then
+ return metadata
+ end
+
+ local captures = match:captures()
+
+ for _, pred in pairs(preds) do
if is_directive(pred[1]) then
local handler = directive_handlers[pred[1]]
if not handler then
error(string.format('No handler for %s', pred[1]))
- return
end
- handler(match, pattern, source, pred, metadata)
+ handler(captures, pattern, source, pred, metadata)
end
end
+
+ return metadata
end
--- Returns the start and stop value if set else the node's range.
@@ -831,8 +841,10 @@ end
---@param start? integer Starting line for the search. Defaults to `node:start()`.
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
---
----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>):
+---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?):
--- capture id, capture node, metadata, match
+---
+---@note Captures are only returned if the query pattern of a specific capture contained predicates.
function Query:iter_captures(node, source, start, stop)
if type(source) == 'number' and source == 0 then
source = api.nvim_get_current_buf()
@@ -840,24 +852,38 @@ function Query:iter_captures(node, source, start, stop)
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch
+ local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
+
+ local max_match_id = -1
+
local function iter(end_line)
- local capture, captured_node, match = raw_iter()
+ local capture, captured_node, match = cursor:next_capture()
+
+ if not capture then
+ return
+ end
+
+ local captures --- @type table<integer,TSNode[]>?
+ local match_id, pattern_index = match:info()
+
local metadata = {}
- if match ~= nil then
- local active = self:match_preds(match, match.pattern, source)
- match.active = active
- if not active then
+ local preds = self.info.patterns[pattern_index] or {}
+
+ if #preds > 0 and match_id > max_match_id then
+ captures = match:captures()
+ max_match_id = match_id
+ if not self:match_preds(match, source) then
+ cursor:remove_match(match_id)
if end_line and captured_node:range() > end_line then
return nil, captured_node, nil
end
return iter(end_line) -- tail call: try next match
end
- self:apply_directives(match, match.pattern, source, metadata)
+ metadata = self:apply_directives(match, source)
end
- return capture, captured_node, metadata, match
+ return capture, captured_node, metadata, captures
end
return iter
end
@@ -899,45 +925,54 @@ end
---@param opts? table Optional keyword arguments:
--- - max_start_depth (integer) if non-zero, sets the maximum start depth
--- for each match. This is used to prevent traversing too deep into a tree.
+--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
--- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes.
--- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is
--- incorrect behavior. This option will eventually become the default and removed.
---
---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata
function Query:iter_matches(node, source, start, stop, opts)
- local all = opts and opts.all
+ opts = opts or {}
+ opts.match_limit = opts.match_limit or 256
+
if type(source) == 'number' and source == 0 then
source = api.nvim_get_current_buf()
end
start, stop = value_or_node_range(start, stop, node)
- local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch
+ local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
+
local function iter()
- local pattern, match = raw_iter()
- local metadata = {}
+ local match = cursor:next_match()
- if match ~= nil then
- local active = self:match_preds(match, pattern, source)
- if not active then
- return iter() -- tail call: try next match
- end
+ if not match then
+ return
+ end
- self:apply_directives(match, pattern, source, metadata)
+ local match_id, pattern = match:info()
+
+ if not self:match_preds(match, source) then
+ cursor:remove_match(match_id)
+ return iter() -- tail call: try next match
end
- if not all then
+ local metadata = self:apply_directives(match, source)
+
+ local captures = match:captures()
+
+ if not opts.all then
-- Convert the match table into the old buggy version for backward
-- compatibility. This is slow. Plugin authors, if you're reading this, set the "all"
-- option!
local old_match = {} ---@type table<integer, TSNode>
- for k, v in pairs(match or {}) do
+ for k, v in pairs(captures or {}) do
old_match[k] = v[#v]
end
return pattern, old_match, metadata
end
- return pattern, match, metadata
+ return pattern, captures, metadata
end
return iter
end
diff --git a/src/nvim/lua/executor.c b/src/nvim/lua/executor.c
index 78c746d169..d5d35c5295 100644
--- a/src/nvim/lua/executor.c
+++ b/src/nvim/lua/executor.c
@@ -1909,6 +1909,9 @@ static void nlua_add_treesitter(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
lua_pushcfunction(lstate, tslua_push_parser);
lua_setfield(lstate, -2, "_create_ts_parser");
+ lua_pushcfunction(lstate, tslua_push_querycursor);
+ lua_setfield(lstate, -2, "_create_ts_querycursor");
+
lua_pushcfunction(lstate, tslua_add_language);
lua_setfield(lstate, -2, "_ts_add_language");
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c
index 6d6ef6c7b9..2d44e485cb 100644
--- a/src/nvim/lua/treesitter.c
+++ b/src/nvim/lua/treesitter.c
@@ -33,15 +33,10 @@
#define TS_META_NODE "treesitter_node"
#define TS_META_QUERY "treesitter_query"
#define TS_META_QUERYCURSOR "treesitter_querycursor"
+#define TS_META_QUERYMATCH "treesitter_querymatch"
#define TS_META_TREECURSOR "treesitter_treecursor"
typedef struct {
- TSQueryCursor *cursor;
- int predicated_match;
- int max_match_id;
-} TSLua_cursor;
-
-typedef struct {
LuaRef cb;
lua_State *lstate;
bool lex;
@@ -108,7 +103,6 @@ static struct luaL_Reg node_meta[] = {
{ "named_descendant_for_range", node_named_descendant_for_range },
{ "parent", node_parent },
{ "iter_children", node_iter_children },
- { "_rawquery", node_rawquery },
{ "next_sibling", node_next_sibling },
{ "prev_sibling", node_prev_sibling },
{ "next_named_sibling", node_next_named_sibling },
@@ -130,18 +124,27 @@ static struct luaL_Reg query_meta[] = {
{ NULL, NULL }
};
-// cursors are not exposed, but still needs garbage collection
+// TSQueryCursor
static struct luaL_Reg querycursor_meta[] = {
+ { "remove_match", querycursor_remove_match },
+ { "next_capture", querycursor_next_capture },
+ { "next_match", querycursor_next_match },
{ "__gc", querycursor_gc },
{ NULL, NULL }
};
+// TSQueryMatch
+static struct luaL_Reg querymatch_meta[] = {
+ { "info", querymatch_info },
+ { "captures", querymatch_captures },
+ { NULL, NULL }
+};
+
static struct luaL_Reg treecursor_meta[] = {
{ "__gc", treecursor_gc },
{ NULL, NULL }
};
-static kvec_t(TSQueryCursor *) cursors = KV_INITIAL_VALUE;
static PMap(cstr_t) langs = MAP_INIT;
static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
@@ -166,6 +169,7 @@ void tslua_init(lua_State *L)
build_meta(L, TS_META_NODE, node_meta);
build_meta(L, TS_META_QUERY, query_meta);
build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
+ build_meta(L, TS_META_QUERYMATCH, querymatch_meta);
build_meta(L, TS_META_TREECURSOR, treecursor_meta);
ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
@@ -1361,173 +1365,156 @@ static int node_equal(lua_State *L)
return 1;
}
-/// assumes the match table being on top of the stack
-static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx)
-{
- // [match]
- for (size_t i = 0; i < match->capture_count; i++) {
- lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures]
- if (lua_isnil(L, -1)) { // [match, nil]
- lua_pop(L, 1); // [match]
- lua_createtable(L, 1, 0); // [match, captures]
- }
- push_node(L, match->captures[i].node, nodeidx); // [match, captures, node]
- lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures]
- lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match]
- }
-}
-
-static int query_next_match(lua_State *L)
-{
- TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
- TSQueryCursor *cursor = ud->cursor;
-
- TSQuery *query = query_check(L, lua_upvalueindex(3));
- TSQueryMatch match;
- if (ts_query_cursor_next_match(cursor, &match)) {
- lua_pushinteger(L, match.pattern_index + 1); // [index]
- lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match]
- set_match(L, &match, lua_upvalueindex(2));
- return 2;
- }
- return 0;
-}
-
-static int query_next_capture(lua_State *L)
-{
- // Upvalues are:
- // [ cursor, node, query, current_match ]
- TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
- TSQueryCursor *cursor = ud->cursor;
-
- TSQuery *query = query_check(L, lua_upvalueindex(3));
-
- if (ud->predicated_match > -1) {
- lua_getfield(L, lua_upvalueindex(4), "active");
- bool active = lua_toboolean(L, -1);
- lua_pop(L, 1);
- if (!active) {
- ts_query_cursor_remove_match(cursor, (uint32_t)ud->predicated_match);
- }
- ud->predicated_match = -1;
- }
-
- TSQueryMatch match;
- uint32_t capture_index;
- if (ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
- TSQueryCapture capture = match.captures[capture_index];
-
- // TODO(vigoux): handle capture quantifiers here
- lua_pushinteger(L, capture.index + 1); // [index]
- push_node(L, capture.node, lua_upvalueindex(2)); // [index, node]
-
- // Now check if we need to run the predicates
- uint32_t n_pred;
- ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred);
-
- if (n_pred > 0 && (ud->max_match_id < (int)match.id)) {
- ud->max_match_id = (int)match.id;
-
- // Create a new cleared match table
- lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match]
- set_match(L, &match, lua_upvalueindex(2));
- lua_pushinteger(L, match.pattern_index + 1);
- lua_setfield(L, -2, "pattern");
-
- if (match.capture_count > 1) {
- ud->predicated_match = (int)match.id;
- lua_pushboolean(L, false);
- lua_setfield(L, -2, "active");
- }
-
- // Set current_match to the new match
- lua_replace(L, lua_upvalueindex(4)); // [index, node]
- lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match]
- return 3;
- }
- return 2;
- }
- return 0;
-}
-
-static int node_rawquery(lua_State *L)
+int tslua_push_querycursor(lua_State *L)
{
TSNode node;
if (!node_check(L, 1, &node)) {
- return 0;
+ return luaL_error(L, "TSNode expected");
}
- TSQuery *query = query_check(L, 2);
- TSQueryCursor *cursor;
- if (kv_size(cursors) > 0) {
- cursor = kv_pop(cursors);
- } else {
- cursor = ts_query_cursor_new();
+ TSQuery *query = query_check(L, 2);
+ if (!query) {
+ return luaL_error(L, "TSQuery expected");
}
- ts_query_cursor_set_max_start_depth(cursor, UINT32_MAX);
- ts_query_cursor_set_match_limit(cursor, 256);
+ TSQueryCursor *cursor = ts_query_cursor_new();
ts_query_cursor_exec(cursor, query, node);
- bool captures = lua_toboolean(L, 3);
-
- if (lua_gettop(L) >= 4) {
- uint32_t start = (uint32_t)luaL_checkinteger(L, 4);
- uint32_t end = lua_gettop(L) >= 5 ? (uint32_t)luaL_checkinteger(L, 5) : MAXLNUM;
+ if (lua_gettop(L) >= 3) {
+ uint32_t start = (uint32_t)luaL_checkinteger(L, 3);
+ uint32_t end = lua_gettop(L) >= 4 ? (uint32_t)luaL_checkinteger(L, 4) : MAXLNUM;
ts_query_cursor_set_point_range(cursor, (TSPoint){ start, 0 }, (TSPoint){ end, 0 });
}
- if (lua_gettop(L) >= 6 && !lua_isnil(L, 6)) {
- if (!lua_istable(L, 6)) {
+ if (lua_gettop(L) >= 5 && !lua_isnil(L, 5)) {
+ if (!lua_istable(L, 5)) {
return luaL_error(L, "table expected");
}
- lua_pushnil(L);
- // stack: [dict, ..., nil]
- while (lua_next(L, 6)) {
- // stack: [dict, ..., key, value]
+ lua_pushnil(L); // [dict, ..., nil]
+ while (lua_next(L, 5)) {
+ // [dict, ..., key, value]
if (lua_type(L, -2) == LUA_TSTRING) {
char *k = (char *)lua_tostring(L, -2);
if (strequal("max_start_depth", k)) {
uint32_t max_start_depth = (uint32_t)lua_tointeger(L, -1);
ts_query_cursor_set_max_start_depth(cursor, max_start_depth);
+ } else if (strequal("match_limit", k)) {
+ uint32_t match_limit = (uint32_t)lua_tointeger(L, -1);
+ ts_query_cursor_set_match_limit(cursor, match_limit);
}
}
- lua_pop(L, 1); // pop the value; lua_next will pop the key.
- // stack: [dict, ..., key]
+ // pop the value; lua_next will pop the key.
+ lua_pop(L, 1); // [dict, ..., key]
}
}
- TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud)); // [udata]
- ud->cursor = cursor;
- ud->predicated_match = -1;
- ud->max_match_id = -1;
+ TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud)); // [node, query, ..., udata]
+ *ud = cursor;
+ lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); // [node, query, ..., udata, meta]
+ lua_setmetatable(L, -2); // [node, query, ..., udata]
+
+ // Copy the fenv which contains the nodes tree.
+ lua_getfenv(L, 1); // [udata, reftable]
+ lua_setfenv(L, -2); // [udata]
- lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR);
+ return 1;
+}
+
+static int querycursor_remove_match(lua_State *L)
+{
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+ uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2);
+ ts_query_cursor_remove_match(cursor, match_id);
+ return 0;
+}
+
+static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex)
+{
+ TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch)); // [udata]
+ *ud = *match;
+ lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH); // [udata, meta]
lua_setmetatable(L, -2); // [udata]
- lua_pushvalue(L, 1); // [udata, node]
- // include query separately, as to keep a ref to it for gc
- lua_pushvalue(L, 2); // [udata, node, query]
+ // Copy the fenv which contains the nodes tree.
+ lua_getfenv(L, uindex); // [udata, reftable]
+ lua_setfenv(L, -2); // [udata]
+}
+
+static int querycursor_next_capture(lua_State *L)
+{
+ TSQueryCursor *cursor = querycursor_check(L, 1);
- if (captures) {
- // placeholder for match state
- lua_createtable(L, (int)ts_query_capture_count(query), 2); // [u, n, q, match]
- lua_pushcclosure(L, query_next_capture, 4); // [closure]
- } else {
- lua_pushcclosure(L, query_next_match, 3); // [closure]
+ TSQueryMatch match;
+ uint32_t capture_index;
+ if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
+ return 0;
}
+ TSQueryCapture capture = match.captures[capture_index];
+
+ // Handle capture quantifiers here
+ lua_pushinteger(L, capture.index + 1); // [index]
+ push_node(L, capture.node, 1); // [index, node]
+ push_querymatch(L, &match, 1);
+
+ return 3;
+}
+
+static int querycursor_next_match(lua_State *L)
+{
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+
+ TSQueryMatch match;
+ if (!ts_query_cursor_next_match(cursor, &match)) {
+ return 0;
+ }
+
+ push_querymatch(L, &match, 1);
+
return 1;
}
+static TSQueryCursor *querycursor_check(lua_State *L, int index)
+{
+ TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR);
+ return *ud;
+}
+
static int querycursor_gc(lua_State *L)
{
- TSLua_cursor *ud = luaL_checkudata(L, 1, TS_META_QUERYCURSOR);
- kv_push(cursors, ud->cursor);
- ud->cursor = NULL;
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+ ts_query_cursor_delete(cursor);
return 0;
}
+static int querymatch_info(lua_State *L)
+{
+ TSQueryMatch *ud = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
+ lua_pushinteger(L, ud->id);
+ lua_pushinteger(L, ud->pattern_index + 1);
+ return 2;
+}
+
+static int querymatch_captures(lua_State *L)
+{
+ TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
+ lua_newtable(L); // [match, nodes, captures]
+ for (size_t i = 0; i < match->capture_count; i++) {
+ TSQueryCapture capture = match->captures[i];
+ int index = (int)capture.index + 1;
+
+ lua_rawgeti(L, -1, index); // [match, node, captures]
+ if (lua_isnil(L, -1)) { // [match, node, captures, nil]
+ lua_pop(L, 1); // [match, node, captures]
+ lua_newtable(L); // [match, node, captures, nodes]
+ }
+ push_node(L, capture.node, 1); // [match, node, captures, nodes, node]
+ lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, node, captures, nodes]
+ lua_rawseti(L, -2, index); // [match, node, captures]
+ }
+ return 1;
+}
+
// Query methods
int tslua_parse_query(lua_State *L)
@@ -1638,7 +1625,7 @@ static void query_err_string(const char *src, int error_offset, TSQueryError err
static TSQuery *query_check(lua_State *L, int index)
{
TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
- return *ud;
+ return ud ? *ud : NULL;
}
static int query_gc(lua_State *L)