diff options
-rw-r--r-- | runtime/doc/treesitter.txt | 8 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta.lua | 44 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_query_linter.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 125 | ||||
-rw-r--r-- | src/nvim/lua/executor.c | 3 | ||||
-rw-r--r-- | src/nvim/lua/treesitter.c | 265 |
6 files changed, 245 insertions, 202 deletions
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index e036df5130..a76fa3c123 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -1152,6 +1152,10 @@ Query:iter_captures({node}, {source}, {start}, {stop}) end < + Note: ~ + • Captures are only returned if the query pattern of a specific capture + contained predicates. + Parameters: ~ • {node} (`TSNode`) under which the search will occur • {source} (`integer|string`) Source buffer or string to extract text @@ -1162,7 +1166,7 @@ Query:iter_captures({node}, {source}, {start}, {stop}) Defaults to `node:end_()`. Return: ~ - (`fun(end_line: integer?): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>`) + (`fun(end_line: integer?): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?`) capture id, capture node, metadata, match *Query:iter_matches()* @@ -1206,6 +1210,8 @@ Query:iter_matches({node}, {source}, {start}, {stop}, {opts}) • max_start_depth (integer) if non-zero, sets the maximum start depth for each match. This is used to prevent traversing too deep into a tree. + • match_limit (integer) Set the maximum number of + in-progress matches (Default: 256). • all (boolean) When set, the returned match table maps capture IDs to a list of nodes. Older versions of iter_matches incorrectly mapped capture IDs to a single diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua index 19d97d2820..e2768d4b06 100644 --- a/runtime/lua/vim/treesitter/_meta.lua +++ b/runtime/lua/vim/treesitter/_meta.lua @@ -34,22 +34,6 @@ error('Cannot require a meta file') ---@field byte_length fun(self: TSNode): integer local TSNode = {} ----@param query TSQuery ----@param captures true ----@param start? integer ----@param end_? integer ----@param opts? table ----@return fun(): integer, TSNode, vim.treesitter.query.TSMatch -function TSNode:_rawquery(query, captures, start, end_, opts) end - ----@param query TSQuery ----@param captures false ----@param start? integer ----@param end_? integer ----@param opts? table ----@return fun(): integer, vim.treesitter.query.TSMatch -function TSNode:_rawquery(query, captures, start, end_, opts) end - ---@alias TSLoggerCallback fun(logtype: 'parse'|'lex', msg: string) ---@class TSParser: userdata @@ -90,3 +74,31 @@ vim._ts_parse_query = function(lang, query) end ---@param lang string ---@return TSParser vim._create_ts_parser = function(lang) end + +--- @class TSQueryMatch: userdata +--- @field captures fun(self: TSQueryMatch): table<integer,TSNode[]> +local TSQueryMatch = {} + +--- @return integer match_id +--- @return integer pattern_index +function TSQueryMatch:info() end + +--- @class TSQueryCursor: userdata +--- @field remove_match fun(self: TSQueryCursor, id: integer) +local TSQueryCursor = {} + +--- @return integer capture +--- @return TSNode captured_node +--- @return TSQueryMatch match +function TSQueryCursor:next_capture() end + +--- @return TSQueryMatch match +function TSQueryCursor:next_match() end + +--- @param node TSNode +--- @param query TSQuery +--- @param start integer? +--- @param stop integer? +--- @param opts? { max_start_depth?: integer, match_limit?: integer} +--- @return TSQueryCursor +function vim._create_ts_querycursor(node, query, start, stop, opts) end diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua index 6216d4e891..12b4cbc7b9 100644 --- a/runtime/lua/vim/treesitter/_query_linter.lua +++ b/runtime/lua/vim/treesitter/_query_linter.lua @@ -122,7 +122,7 @@ local parse = vim.func._memoize(hash_parse, function(node, buf, lang) end) --- @param buf integer ---- @param match vim.treesitter.query.TSMatch +--- @param match table<integer,TSNode[]> --- @param query vim.treesitter.Query --- @param lang_context QueryLinterLanguageContext --- @param diagnostics vim.Diagnostic[] diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 30cd00c617..075fd0e99b 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -258,7 +258,7 @@ end) --- handling the "any" vs "all" semantics. They are called from the --- predicate_handlers table with the appropriate arguments for each predicate. local impl = { - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -293,7 +293,7 @@ local impl = { return not any end, - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -333,7 +333,7 @@ local impl = { end, }) - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -356,7 +356,7 @@ local impl = { end end)(), - --- @param match vim.treesitter.query.TSMatch + --- @param match table<integer,TSNode[]> --- @param source integer|string --- @param predicate any[] --- @param any boolean @@ -383,13 +383,7 @@ local impl = { end, } ----@nodoc ----@class vim.treesitter.query.TSMatch ----@field pattern? integer ----@field active? boolean ----@field [integer] TSNode[] - ----@alias TSPredicate fun(match: vim.treesitter.query.TSMatch, pattern: integer, source: integer|string, predicate: any[]): boolean +---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -504,7 +498,7 @@ predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?'] ---@field [integer] vim.treesitter.query.TSMetadata ---@field [string] integer|string ----@alias TSDirective fun(match: vim.treesitter.query.TSMatch, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) +---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata) -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) @@ -726,13 +720,19 @@ local function is_directive(name) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param pattern integer +---@param match TSQueryMatch ---@param source integer|string -function Query:match_preds(match, pattern, source) +function Query:match_preds(match, source) + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return true + end + + local captures = match:captures() + + for _, pred in pairs(preds) do -- Here we only want to return if a predicate DOES NOT match, and -- continue on the other case. This way unknown predicates will not be considered, -- which allows some testing and easier user extensibility (#12173). @@ -754,7 +754,7 @@ function Query:match_preds(match, pattern, source) return false end - local pred_matches = handler(match, pattern, source, pred) + local pred_matches = handler(captures, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -765,23 +765,33 @@ function Query:match_preds(match, pattern, source) end ---@private ----@param match vim.treesitter.query.TSMatch ----@param metadata vim.treesitter.query.TSMetadata -function Query:apply_directives(match, pattern, source, metadata) +---@param match TSQueryMatch +---@return vim.treesitter.query.TSMetadata metadata +function Query:apply_directives(match, source) + ---@type vim.treesitter.query.TSMetadata + local metadata = {} + local _, pattern = match:info() local preds = self.info.patterns[pattern] - for _, pred in pairs(preds or {}) do + if not preds then + return metadata + end + + local captures = match:captures() + + for _, pred in pairs(preds) do if is_directive(pred[1]) then local handler = directive_handlers[pred[1]] if not handler then error(string.format('No handler for %s', pred[1])) - return end - handler(match, pattern, source, pred, metadata) + handler(captures, pattern, source, pred, metadata) end end + + return metadata end --- Returns the start and stop value if set else the node's range. @@ -831,8 +841,10 @@ end ---@param start? integer Starting line for the search. Defaults to `node:start()`. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`. --- ----@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer, TSNode>): +---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, table<integer,TSNode[]>?): --- capture id, capture node, metadata, match +--- +---@note Captures are only returned if the query pattern of a specific capture contained predicates. function Query:iter_captures(node, source, start, stop) if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() @@ -840,24 +852,38 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, true, start, stop) ---@type fun(): integer, TSNode, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 }) + + local max_match_id = -1 + local function iter(end_line) - local capture, captured_node, match = raw_iter() + local capture, captured_node, match = cursor:next_capture() + + if not capture then + return + end + + local captures --- @type table<integer,TSNode[]>? + local match_id, pattern_index = match:info() + local metadata = {} - if match ~= nil then - local active = self:match_preds(match, match.pattern, source) - match.active = active - if not active then + local preds = self.info.patterns[pattern_index] or {} + + if #preds > 0 and match_id > max_match_id then + captures = match:captures() + max_match_id = match_id + if not self:match_preds(match, source) then + cursor:remove_match(match_id) if end_line and captured_node:range() > end_line then return nil, captured_node, nil end return iter(end_line) -- tail call: try next match end - self:apply_directives(match, match.pattern, source, metadata) + metadata = self:apply_directives(match, source) end - return capture, captured_node, metadata, match + return capture, captured_node, metadata, captures end return iter end @@ -899,45 +925,54 @@ end ---@param opts? table Optional keyword arguments: --- - max_start_depth (integer) if non-zero, sets the maximum start depth --- for each match. This is used to prevent traversing too deep into a tree. +--- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256). --- - all (boolean) When set, the returned match table maps capture IDs to a list of nodes. --- Older versions of iter_matches incorrectly mapped capture IDs to a single node, which is --- incorrect behavior. This option will eventually become the default and removed. --- ---@return (fun(): integer, table<integer, TSNode[]>, table): pattern id, match, metadata function Query:iter_matches(node, source, start, stop, opts) - local all = opts and opts.all + opts = opts or {} + opts.match_limit = opts.match_limit or 256 + if type(source) == 'number' and source == 0 then source = api.nvim_get_current_buf() end start, stop = value_or_node_range(start, stop, node) - local raw_iter = node:_rawquery(self.query, false, start, stop, opts) ---@type fun(): integer, vim.treesitter.query.TSMatch + local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts) + local function iter() - local pattern, match = raw_iter() - local metadata = {} + local match = cursor:next_match() - if match ~= nil then - local active = self:match_preds(match, pattern, source) - if not active then - return iter() -- tail call: try next match - end + if not match then + return + end - self:apply_directives(match, pattern, source, metadata) + local match_id, pattern = match:info() + + if not self:match_preds(match, source) then + cursor:remove_match(match_id) + return iter() -- tail call: try next match end - if not all then + local metadata = self:apply_directives(match, source) + + local captures = match:captures() + + if not opts.all then -- Convert the match table into the old buggy version for backward -- compatibility. This is slow. Plugin authors, if you're reading this, set the "all" -- option! local old_match = {} ---@type table<integer, TSNode> - for k, v in pairs(match or {}) do + for k, v in pairs(captures or {}) do old_match[k] = v[#v] end return pattern, old_match, metadata end - return pattern, match, metadata + return pattern, captures, metadata end return iter end diff --git a/src/nvim/lua/executor.c b/src/nvim/lua/executor.c index 78c746d169..d5d35c5295 100644 --- a/src/nvim/lua/executor.c +++ b/src/nvim/lua/executor.c @@ -1909,6 +1909,9 @@ static void nlua_add_treesitter(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL lua_pushcfunction(lstate, tslua_push_parser); lua_setfield(lstate, -2, "_create_ts_parser"); + lua_pushcfunction(lstate, tslua_push_querycursor); + lua_setfield(lstate, -2, "_create_ts_querycursor"); + lua_pushcfunction(lstate, tslua_add_language); lua_setfield(lstate, -2, "_ts_add_language"); diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index 6d6ef6c7b9..2d44e485cb 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -33,15 +33,10 @@ #define TS_META_NODE "treesitter_node" #define TS_META_QUERY "treesitter_query" #define TS_META_QUERYCURSOR "treesitter_querycursor" +#define TS_META_QUERYMATCH "treesitter_querymatch" #define TS_META_TREECURSOR "treesitter_treecursor" typedef struct { - TSQueryCursor *cursor; - int predicated_match; - int max_match_id; -} TSLua_cursor; - -typedef struct { LuaRef cb; lua_State *lstate; bool lex; @@ -108,7 +103,6 @@ static struct luaL_Reg node_meta[] = { { "named_descendant_for_range", node_named_descendant_for_range }, { "parent", node_parent }, { "iter_children", node_iter_children }, - { "_rawquery", node_rawquery }, { "next_sibling", node_next_sibling }, { "prev_sibling", node_prev_sibling }, { "next_named_sibling", node_next_named_sibling }, @@ -130,18 +124,27 @@ static struct luaL_Reg query_meta[] = { { NULL, NULL } }; -// cursors are not exposed, but still needs garbage collection +// TSQueryCursor static struct luaL_Reg querycursor_meta[] = { + { "remove_match", querycursor_remove_match }, + { "next_capture", querycursor_next_capture }, + { "next_match", querycursor_next_match }, { "__gc", querycursor_gc }, { NULL, NULL } }; +// TSQueryMatch +static struct luaL_Reg querymatch_meta[] = { + { "info", querymatch_info }, + { "captures", querymatch_captures }, + { NULL, NULL } +}; + static struct luaL_Reg treecursor_meta[] = { { "__gc", treecursor_gc }, { NULL, NULL } }; -static kvec_t(TSQueryCursor *) cursors = KV_INITIAL_VALUE; static PMap(cstr_t) langs = MAP_INIT; static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta) @@ -166,6 +169,7 @@ void tslua_init(lua_State *L) build_meta(L, TS_META_NODE, node_meta); build_meta(L, TS_META_QUERY, query_meta); build_meta(L, TS_META_QUERYCURSOR, querycursor_meta); + build_meta(L, TS_META_QUERYMATCH, querymatch_meta); build_meta(L, TS_META_TREECURSOR, treecursor_meta); ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree); @@ -1361,173 +1365,156 @@ static int node_equal(lua_State *L) return 1; } -/// assumes the match table being on top of the stack -static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx) -{ - // [match] - for (size_t i = 0; i < match->capture_count; i++) { - lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures] - if (lua_isnil(L, -1)) { // [match, nil] - lua_pop(L, 1); // [match] - lua_createtable(L, 1, 0); // [match, captures] - } - push_node(L, match->captures[i].node, nodeidx); // [match, captures, node] - lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures] - lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match] - } -} - -static int query_next_match(lua_State *L) -{ - TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1)); - TSQueryCursor *cursor = ud->cursor; - - TSQuery *query = query_check(L, lua_upvalueindex(3)); - TSQueryMatch match; - if (ts_query_cursor_next_match(cursor, &match)) { - lua_pushinteger(L, match.pattern_index + 1); // [index] - lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match] - set_match(L, &match, lua_upvalueindex(2)); - return 2; - } - return 0; -} - -static int query_next_capture(lua_State *L) -{ - // Upvalues are: - // [ cursor, node, query, current_match ] - TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1)); - TSQueryCursor *cursor = ud->cursor; - - TSQuery *query = query_check(L, lua_upvalueindex(3)); - - if (ud->predicated_match > -1) { - lua_getfield(L, lua_upvalueindex(4), "active"); - bool active = lua_toboolean(L, -1); - lua_pop(L, 1); - if (!active) { - ts_query_cursor_remove_match(cursor, (uint32_t)ud->predicated_match); - } - ud->predicated_match = -1; - } - - TSQueryMatch match; - uint32_t capture_index; - if (ts_query_cursor_next_capture(cursor, &match, &capture_index)) { - TSQueryCapture capture = match.captures[capture_index]; - - // TODO(vigoux): handle capture quantifiers here - lua_pushinteger(L, capture.index + 1); // [index] - push_node(L, capture.node, lua_upvalueindex(2)); // [index, node] - - // Now check if we need to run the predicates - uint32_t n_pred; - ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred); - - if (n_pred > 0 && (ud->max_match_id < (int)match.id)) { - ud->max_match_id = (int)match.id; - - // Create a new cleared match table - lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match] - set_match(L, &match, lua_upvalueindex(2)); - lua_pushinteger(L, match.pattern_index + 1); - lua_setfield(L, -2, "pattern"); - - if (match.capture_count > 1) { - ud->predicated_match = (int)match.id; - lua_pushboolean(L, false); - lua_setfield(L, -2, "active"); - } - - // Set current_match to the new match - lua_replace(L, lua_upvalueindex(4)); // [index, node] - lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match] - return 3; - } - return 2; - } - return 0; -} - -static int node_rawquery(lua_State *L) +int tslua_push_querycursor(lua_State *L) { TSNode node; if (!node_check(L, 1, &node)) { - return 0; + return luaL_error(L, "TSNode expected"); } - TSQuery *query = query_check(L, 2); - TSQueryCursor *cursor; - if (kv_size(cursors) > 0) { - cursor = kv_pop(cursors); - } else { - cursor = ts_query_cursor_new(); + TSQuery *query = query_check(L, 2); + if (!query) { + return luaL_error(L, "TSQuery expected"); } - ts_query_cursor_set_max_start_depth(cursor, UINT32_MAX); - ts_query_cursor_set_match_limit(cursor, 256); + TSQueryCursor *cursor = ts_query_cursor_new(); ts_query_cursor_exec(cursor, query, node); - bool captures = lua_toboolean(L, 3); - - if (lua_gettop(L) >= 4) { - uint32_t start = (uint32_t)luaL_checkinteger(L, 4); - uint32_t end = lua_gettop(L) >= 5 ? (uint32_t)luaL_checkinteger(L, 5) : MAXLNUM; + if (lua_gettop(L) >= 3) { + uint32_t start = (uint32_t)luaL_checkinteger(L, 3); + uint32_t end = lua_gettop(L) >= 4 ? (uint32_t)luaL_checkinteger(L, 4) : MAXLNUM; ts_query_cursor_set_point_range(cursor, (TSPoint){ start, 0 }, (TSPoint){ end, 0 }); } - if (lua_gettop(L) >= 6 && !lua_isnil(L, 6)) { - if (!lua_istable(L, 6)) { + if (lua_gettop(L) >= 5 && !lua_isnil(L, 5)) { + if (!lua_istable(L, 5)) { return luaL_error(L, "table expected"); } - lua_pushnil(L); - // stack: [dict, ..., nil] - while (lua_next(L, 6)) { - // stack: [dict, ..., key, value] + lua_pushnil(L); // [dict, ..., nil] + while (lua_next(L, 5)) { + // [dict, ..., key, value] if (lua_type(L, -2) == LUA_TSTRING) { char *k = (char *)lua_tostring(L, -2); if (strequal("max_start_depth", k)) { uint32_t max_start_depth = (uint32_t)lua_tointeger(L, -1); ts_query_cursor_set_max_start_depth(cursor, max_start_depth); + } else if (strequal("match_limit", k)) { + uint32_t match_limit = (uint32_t)lua_tointeger(L, -1); + ts_query_cursor_set_match_limit(cursor, match_limit); } } - lua_pop(L, 1); // pop the value; lua_next will pop the key. - // stack: [dict, ..., key] + // pop the value; lua_next will pop the key. + lua_pop(L, 1); // [dict, ..., key] } } - TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud)); // [udata] - ud->cursor = cursor; - ud->predicated_match = -1; - ud->max_match_id = -1; + TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud)); // [node, query, ..., udata] + *ud = cursor; + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); // [node, query, ..., udata, meta] + lua_setmetatable(L, -2); // [node, query, ..., udata] + + // Copy the fenv which contains the nodes tree. + lua_getfenv(L, 1); // [udata, reftable] + lua_setfenv(L, -2); // [udata] - lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); + return 1; +} + +static int querycursor_remove_match(lua_State *L) +{ + TSQueryCursor *cursor = querycursor_check(L, 1); + uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2); + ts_query_cursor_remove_match(cursor, match_id); + return 0; +} + +static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex) +{ + TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch)); // [udata] + *ud = *match; + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH); // [udata, meta] lua_setmetatable(L, -2); // [udata] - lua_pushvalue(L, 1); // [udata, node] - // include query separately, as to keep a ref to it for gc - lua_pushvalue(L, 2); // [udata, node, query] + // Copy the fenv which contains the nodes tree. + lua_getfenv(L, uindex); // [udata, reftable] + lua_setfenv(L, -2); // [udata] +} + +static int querycursor_next_capture(lua_State *L) +{ + TSQueryCursor *cursor = querycursor_check(L, 1); - if (captures) { - // placeholder for match state - lua_createtable(L, (int)ts_query_capture_count(query), 2); // [u, n, q, match] - lua_pushcclosure(L, query_next_capture, 4); // [closure] - } else { - lua_pushcclosure(L, query_next_match, 3); // [closure] + TSQueryMatch match; + uint32_t capture_index; + if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) { + return 0; } + TSQueryCapture capture = match.captures[capture_index]; + + // Handle capture quantifiers here + lua_pushinteger(L, capture.index + 1); // [index] + push_node(L, capture.node, 1); // [index, node] + push_querymatch(L, &match, 1); + + return 3; +} + +static int querycursor_next_match(lua_State *L) +{ + TSQueryCursor *cursor = querycursor_check(L, 1); + + TSQueryMatch match; + if (!ts_query_cursor_next_match(cursor, &match)) { + return 0; + } + + push_querymatch(L, &match, 1); + return 1; } +static TSQueryCursor *querycursor_check(lua_State *L, int index) +{ + TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR); + return *ud; +} + static int querycursor_gc(lua_State *L) { - TSLua_cursor *ud = luaL_checkudata(L, 1, TS_META_QUERYCURSOR); - kv_push(cursors, ud->cursor); - ud->cursor = NULL; + TSQueryCursor *cursor = querycursor_check(L, 1); + ts_query_cursor_delete(cursor); return 0; } +static int querymatch_info(lua_State *L) +{ + TSQueryMatch *ud = luaL_checkudata(L, 1, TS_META_QUERYMATCH); + lua_pushinteger(L, ud->id); + lua_pushinteger(L, ud->pattern_index + 1); + return 2; +} + +static int querymatch_captures(lua_State *L) +{ + TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH); + lua_newtable(L); // [match, nodes, captures] + for (size_t i = 0; i < match->capture_count; i++) { + TSQueryCapture capture = match->captures[i]; + int index = (int)capture.index + 1; + + lua_rawgeti(L, -1, index); // [match, node, captures] + if (lua_isnil(L, -1)) { // [match, node, captures, nil] + lua_pop(L, 1); // [match, node, captures] + lua_newtable(L); // [match, node, captures, nodes] + } + push_node(L, capture.node, 1); // [match, node, captures, nodes, node] + lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, node, captures, nodes] + lua_rawseti(L, -2, index); // [match, node, captures] + } + return 1; +} + // Query methods int tslua_parse_query(lua_State *L) @@ -1638,7 +1625,7 @@ static void query_err_string(const char *src, int error_offset, TSQueryError err static TSQuery *query_check(lua_State *L, int index) { TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY); - return *ud; + return ud ? *ud : NULL; } static int query_gc(lua_State *L) |