diff options
author | Lewis Russell <lewis6991@gmail.com> | 2024-03-19 14:25:54 +0000 |
---|---|---|
committer | Lewis Russell <me@lewisr.dev> | 2024-03-19 16:16:54 +0000 |
commit | aca6c930025e191f22cfb541b01cb89093b9b809 (patch) | |
tree | ec243e31057dfb08d9ecda43898ba9ac3a4488c0 /src | |
parent | aca2048bcd57937ea1c7b7f0325f25d5b82588db (diff) | |
download | rneovim-aca6c930025e191f22cfb541b01cb89093b9b809.tar.gz rneovim-aca6c930025e191f22cfb541b01cb89093b9b809.tar.bz2 rneovim-aca6c930025e191f22cfb541b01cb89093b9b809.zip |
refactor(treesitter): simplify argument checks for userdata
Diffstat (limited to 'src')
-rw-r--r-- | src/nvim/lua/treesitter.c | 418 |
1 files changed, 136 insertions, 282 deletions
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index 2d44e485cb..0cf1ad2833 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -258,14 +258,19 @@ int tslua_remove_lang(lua_State *L) return 1; } -int tslua_inspect_lang(lua_State *L) +static TSLanguage *lang_check(lua_State *L, int index) { - const char *lang_name = luaL_checkstring(L, 1); - + const char *lang_name = luaL_checkstring(L, index); TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); if (!lang) { - return luaL_error(L, "no such language: %s", lang_name); + luaL_error(L, "no such language: %s", lang_name); } + return lang; +} + +int tslua_inspect_lang(lua_State *L) +{ + TSLanguage *lang = lang_check(L, 1); lua_createtable(L, 0, 2); // [retval] @@ -308,19 +313,14 @@ int tslua_inspect_lang(lua_State *L) int tslua_push_parser(lua_State *L) { - // Gather language name - const char *lang_name = luaL_checkstring(L, 1); - - TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); - if (!lang) { - return luaL_error(L, "no such language: %s", lang_name); - } + TSLanguage *lang = lang_check(L, 1); TSParser **parser = lua_newuserdata(L, sizeof(TSParser *)); *parser = ts_parser_new(); if (!ts_parser_set_language(*parser, lang)) { ts_parser_delete(*parser); + const char *lang_name = luaL_checkstring(L, 1); return luaL_error(L, "Failed to load language : %s", lang_name); } @@ -329,9 +329,13 @@ int tslua_push_parser(lua_State *L) return 1; } -static TSParser **parser_check(lua_State *L, uint16_t index) +static TSParser *parser_check(lua_State *L, uint16_t index) { - return luaL_checkudata(L, index, TS_META_PARSER); + TSParser **ud = luaL_checkudata(L, index, TS_META_PARSER); + if (!ud || !(*ud)) { + luaL_argerror(L, index, "TSParser expected"); + } + return *ud; } static void logger_gc(TSLogger logger) @@ -347,13 +351,9 @@ static void logger_gc(TSLogger logger) static int parser_gc(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } - - logger_gc(ts_parser_logger(*p)); - ts_parser_delete(*p); + TSParser *p = parser_check(L, 1); + logger_gc(ts_parser_logger(p)); + ts_parser_delete(p); return 0; } @@ -426,11 +426,7 @@ static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length static int parser_parse(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p || !(*p)) { - return 0; - } - + TSParser *p = parser_check(L, 1); TSTree *old_tree = NULL; if (!lua_isnil(L, 2)) { TSLuaTree *ud = tree_check(L, 2); @@ -449,7 +445,7 @@ static int parser_parse(lua_State *L) switch (lua_type(L, 3)) { case LUA_TSTRING: str = lua_tolstring(L, 3, &len); - new_tree = ts_parser_parse_string(*p, old_tree, str, (uint32_t)len); + new_tree = ts_parser_parse_string(p, old_tree, str, (uint32_t)len); break; case LUA_TNUMBER: @@ -465,7 +461,7 @@ static int parser_parse(lua_State *L) } input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8 }; - new_tree = ts_parser_parse(*p, old_tree, input); + new_tree = ts_parser_parse(p, old_tree, input); break; @@ -496,21 +492,14 @@ static int parser_parse(lua_State *L) static int parser_reset(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (p && *p) { - ts_parser_reset(*p); - } - + TSParser *p = parser_check(L, 1); + ts_parser_reset(p); return 0; } static int tree_copy(lua_State *L) { TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } - TSTree *copy = ts_tree_copy(ud->tree); push_tree(L, copy); // [tree] @@ -525,9 +514,6 @@ static int tree_edit(lua_State *L) } TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } uint32_t start_byte = (uint32_t)luaL_checkint(L, 2); uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3); @@ -547,9 +533,6 @@ static int tree_edit(lua_State *L) static int tree_get_ranges(lua_State *L) { TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); @@ -562,6 +545,11 @@ static int tree_get_ranges(lua_State *L) return 1; } +static void range_err(lua_State *L) +{ + luaL_error(L, "Ranges can only be made from 6 element long tables or nodes."); +} + // Use the top of the stack (without popping it) to create a TSRange, it can be // either a lua table or a TSNode static void range_from_lua(lua_State *L, TSRange *range) @@ -571,7 +559,7 @@ static void range_from_lua(lua_State *L, TSRange *range) if (lua_istable(L, -1)) { // should be a table of 6 elements if (lua_objlen(L, -1) != 6) { - goto error; + range_err(L); } lua_rawgeti(L, -1, 1); // [ range, start_row] @@ -610,7 +598,7 @@ static void range_from_lua(lua_State *L, TSRange *range) .start_byte = start_byte, .end_byte = end_byte, }; - } else if (node_check(L, -1, &node)) { + } else if (node_check2(L, -1, &node)) { *range = (TSRange) { .start_point = ts_node_start_point(node), .end_point = ts_node_end_point(node), @@ -618,29 +606,20 @@ static void range_from_lua(lua_State *L, TSRange *range) .end_byte = ts_node_end_byte(node) }; } else { - goto error; + range_err(L); } - return; -error: - luaL_error(L, - "Ranges can only be made from 6 element long tables or nodes."); } static int parser_set_ranges(lua_State *L) { if (lua_gettop(L) < 2) { - return luaL_error(L, - "not enough args to parser:set_included_ranges()"); + return luaL_error(L, "not enough args to parser:set_included_ranges()"); } - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); if (!lua_istable(L, 2)) { - return luaL_error(L, - "argument for parser:set_included_ranges() should be a table."); + return luaL_argerror(L, 2, "table expected."); } size_t tbl_len = lua_objlen(L, 2); @@ -654,7 +633,7 @@ static int parser_set_ranges(lua_State *L) } // This memcpies ranges, thus we can free it afterwards - ts_parser_set_included_ranges(*p, ranges, (uint32_t)tbl_len); + ts_parser_set_included_ranges(p, ranges, (uint32_t)tbl_len); xfree(ranges); return 0; @@ -662,15 +641,12 @@ static int parser_set_ranges(lua_State *L) static int parser_get_ranges(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); uint32_t len; - const TSRange *ranges = ts_parser_included_ranges(*p, &len); + const TSRange *ranges = ts_parser_included_ranges(p, &len); push_ranges(L, ranges, len, include_bytes); return 1; @@ -678,28 +654,21 @@ static int parser_get_ranges(lua_State *L) static int parser_set_timeout(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); if (lua_gettop(L) < 2) { luaL_error(L, "integer expected"); } uint32_t timeout = (uint32_t)luaL_checkinteger(L, 2); - ts_parser_set_timeout_micros(*p, timeout); + ts_parser_set_timeout_micros(p, timeout); return 0; } static int parser_get_timeout(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } - - lua_pushinteger(L, (lua_Integer)ts_parser_timeout_micros(*p)); + TSParser *p = parser_check(L, 1); + lua_pushinteger(L, (lua_Integer)ts_parser_timeout_micros(p)); return 1; } @@ -723,10 +692,7 @@ static void logger_cb(void *payload, TSLogType logtype, const char *s) static int parser_set_logger(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); if (!lua_isboolean(L, 2)) { return luaL_argerror(L, 2, "boolean expected"); @@ -756,18 +722,14 @@ static int parser_set_logger(lua_State *L) .log = logger_cb }; - ts_parser_set_logger(*p, logger); + ts_parser_set_logger(p, logger); return 0; } static int parser_get_logger(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } - - TSLogger logger = ts_parser_logger(*p); + TSParser *p = parser_check(L, 1); + TSLogger logger = ts_parser_logger(p); if (logger.log) { TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload; lua_rawgeti(L, LUA_REGISTRYINDEX, opts->cb); @@ -811,15 +773,16 @@ static void push_tree(lua_State *L, TSTree *tree) static TSLuaTree *tree_check(lua_State *L, int index) { TSLuaTree *ud = luaL_checkudata(L, index, TS_META_TREE); + if (!ud) { + luaL_argerror(L, index, "TSTree expected"); + } return ud; } static int tree_gc(lua_State *L) { TSLuaTree *ud = tree_check(L, 1); - if (ud) { - ts_tree_delete(ud->tree); - } + ts_tree_delete(ud->tree); return 0; } @@ -832,9 +795,6 @@ static int tree_tostring(lua_State *L) static int tree_root(lua_State *L) { TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } TSNode root = ts_tree_root_node(ud->tree); push_node(L, root, 1); return 1; @@ -864,7 +824,7 @@ static void push_node(lua_State *L, TSNode node, int uindex) lua_setfenv(L, -2); // [udata] } -static bool node_check(lua_State *L, int index, TSNode *res) +static bool node_check2(lua_State *L, int index, TSNode *res) { TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); if (ud) { @@ -874,12 +834,18 @@ static bool node_check(lua_State *L, int index, TSNode *res) return false; } -static int node_tostring(lua_State *L) +static TSNode node_check(lua_State *L, int index) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; + TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); + if (!ud) { + luaL_argerror(L, index, "TSNode expected"); } + return *ud; +} + +static int node_tostring(lua_State *L) +{ + TSNode node = node_check(L, 1); lua_pushstring(L, "<node "); lua_pushstring(L, ts_node_type(node)); lua_pushstring(L, ">"); @@ -889,37 +855,22 @@ static int node_tostring(lua_State *L) static int node_eq(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - - TSNode node2; - if (!node_check(L, 2, &node2)) { - return 0; - } - + TSNode node = node_check(L, 1); + TSNode node2 = node_check(L, 2); lua_pushboolean(L, ts_node_eq(node, node2)); return 1; } static int node_id(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + TSNode node = node_check(L, 1); lua_pushlstring(L, (const char *)&node.id, sizeof node.id); return 1; } static int node_range(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); @@ -945,10 +896,7 @@ static int node_range(lua_State *L) static int node_start(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint start = ts_node_start_point(node); uint32_t start_byte = ts_node_start_byte(node); lua_pushinteger(L, start.row); @@ -959,10 +907,7 @@ static int node_start(lua_State *L) static int node_end(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint end = ts_node_end_point(node); uint32_t end_byte = ts_node_end_byte(node); lua_pushinteger(L, end.row); @@ -973,10 +918,7 @@ static int node_end(lua_State *L) static int node_child_count(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t count = ts_node_child_count(node); lua_pushinteger(L, count); return 1; @@ -984,10 +926,7 @@ static int node_child_count(lua_State *L) static int node_named_child_count(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t count = ts_node_named_child_count(node); lua_pushinteger(L, count); return 1; @@ -995,20 +934,14 @@ static int node_named_child_count(lua_State *L) static int node_type(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushstring(L, ts_node_type(node)); return 1; } static int node_symbol(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSSymbol symbol = ts_node_symbol(node); lua_pushinteger(L, symbol); return 1; @@ -1016,10 +949,7 @@ static int node_symbol(lua_State *L) static int node_field(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); size_t name_len; const char *field_name = luaL_checklstring(L, 2, &name_len); @@ -1046,20 +976,14 @@ static int node_field(lua_State *L) static int node_named(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_is_named(node)); return 1; } static int node_sexpr(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); char *allocated = ts_node_string(node); lua_pushstring(L, allocated); xfree(allocated); @@ -1068,50 +992,35 @@ static int node_sexpr(lua_State *L) static int node_missing(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_is_missing(node)); return 1; } static int node_extra(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_is_extra(node)); return 1; } static int node_has_changes(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_has_changes(node)); return 1; } static int node_has_error(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_has_error(node)); return 1; } static int node_child(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t num = (uint32_t)lua_tointeger(L, 2); TSNode child = ts_node_child(node, num); @@ -1121,10 +1030,7 @@ static int node_child(lua_State *L) static int node_named_child(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t num = (uint32_t)lua_tointeger(L, 2); TSNode child = ts_node_named_child(node, num); @@ -1134,10 +1040,7 @@ static int node_named_child(lua_State *L) static int node_descendant_for_range(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint start = { (uint32_t)lua_tointeger(L, 2), (uint32_t)lua_tointeger(L, 3) }; TSPoint end = { (uint32_t)lua_tointeger(L, 4), @@ -1150,10 +1053,7 @@ static int node_descendant_for_range(lua_State *L) static int node_named_descendant_for_range(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint start = { (uint32_t)lua_tointeger(L, 2), (uint32_t)lua_tointeger(L, 3) }; TSPoint end = { (uint32_t)lua_tointeger(L, 4), @@ -1164,56 +1064,54 @@ static int node_named_descendant_for_range(lua_State *L) return 1; } -static int node_next_child(lua_State *L) +static TSTreeCursor *treecursor_check(lua_State *L, int index) { TSTreeCursor *ud = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR); if (!ud) { - return 0; + luaL_error(L, "TSTreeCursor expected"); } + return ud; +} - TSNode source; - if (!node_check(L, lua_upvalueindex(2), &source)) { - return 0; - } +static int node_next_child(lua_State *L) +{ + TSTreeCursor *cursor = treecursor_check(L, lua_upvalueindex(1)); + TSNode source = node_check(L, lua_upvalueindex(2)); // First call should return first child - if (ts_node_eq(source, ts_tree_cursor_current_node(ud))) { - if (ts_tree_cursor_goto_first_child(ud)) { + if (ts_node_eq(source, ts_tree_cursor_current_node(cursor))) { + if (ts_tree_cursor_goto_first_child(cursor)) { goto push; } else { - goto end; + return 0; } } - if (ts_tree_cursor_goto_next_sibling(ud)) { + if (ts_tree_cursor_goto_next_sibling(cursor)) { push: push_node(L, - ts_tree_cursor_current_node(ud), + ts_tree_cursor_current_node(cursor), lua_upvalueindex(2)); // [node] - const char *field = ts_tree_cursor_current_field_name(ud); + const char *field = ts_tree_cursor_current_field_name(cursor); if (field != NULL) { - lua_pushstring(L, ts_tree_cursor_current_field_name(ud)); + lua_pushstring(L, ts_tree_cursor_current_field_name(cursor)); } else { lua_pushnil(L); } // [node, field_name_or_nil] return 2; } -end: return 0; } static int node_iter_children(lua_State *L) { - TSNode source; - if (!node_check(L, 1, &source)) { - return 0; - } + TSNode node = node_check(L, 1); TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata] - *ud = ts_tree_cursor_new(source); + *ud = ts_tree_cursor_new(node); lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR); // [udata, mt] lua_setmetatable(L, -2); // [udata] @@ -1225,17 +1123,14 @@ static int node_iter_children(lua_State *L) static int treecursor_gc(lua_State *L) { - TSTreeCursor *ud = luaL_checkudata(L, 1, TS_META_TREECURSOR); - ts_tree_cursor_delete(ud); + TSTreeCursor *cursor = treecursor_check(L, 1); + ts_tree_cursor_delete(cursor); return 0; } static int node_parent(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode parent = ts_node_parent(node); push_node(L, parent, 1); return 1; @@ -1243,10 +1138,7 @@ static int node_parent(lua_State *L) static int node_next_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_next_sibling(node); push_node(L, sibling, 1); return 1; @@ -1254,10 +1146,7 @@ static int node_next_sibling(lua_State *L) static int node_prev_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_prev_sibling(node); push_node(L, sibling, 1); return 1; @@ -1265,10 +1154,7 @@ static int node_prev_sibling(lua_State *L) static int node_next_named_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_next_named_sibling(node); push_node(L, sibling, 1); return 1; @@ -1276,10 +1162,7 @@ static int node_next_named_sibling(lua_State *L) static int node_prev_named_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_prev_named_sibling(node); push_node(L, sibling, 1); return 1; @@ -1287,10 +1170,7 @@ static int node_prev_named_sibling(lua_State *L) static int node_named_children(lua_State *L) { - TSNode source; - if (!node_check(L, 1, &source)) { - return 0; - } + TSNode source = node_check(L, 1); TSTreeCursor cursor = ts_tree_cursor_new(source); lua_newtable(L); @@ -1312,11 +1192,7 @@ static int node_named_children(lua_State *L) static int node_root(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + TSNode node = node_check(L, 1); TSNode root = ts_tree_root_node(node.tree); push_node(L, root, 1); return 1; @@ -1324,59 +1200,34 @@ static int node_root(lua_State *L) static int node_tree(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + node_check(L, 1); lua_getfenv(L, 1); // [udata, reftable] lua_rawgeti(L, -1, 1); // [udata, reftable, tree_udata] - return 1; } static int node_byte_length(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + TSNode node = node_check(L, 1); uint32_t start_byte = ts_node_start_byte(node); uint32_t end_byte = ts_node_end_byte(node); - lua_pushinteger(L, end_byte - start_byte); return 1; } static int node_equal(lua_State *L) { - TSNode node1; - if (!node_check(L, 1, &node1)) { - return 0; - } - - TSNode node2; - if (!node_check(L, 2, &node2)) { - return luaL_error(L, "TSNode expected"); - } - + TSNode node1 = node_check(L, 1); + TSNode node2 = node_check(L, 2); lua_pushboolean(L, ts_node_eq(node1, node2)); return 1; } int tslua_push_querycursor(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return luaL_error(L, "TSNode expected"); - } + TSNode node = node_check(L, 1); TSQuery *query = query_check(L, 2); - if (!query) { - return luaL_error(L, "TSQuery expected"); - } - TSQueryCursor *cursor = ts_query_cursor_new(); ts_query_cursor_exec(cursor, query, node); @@ -1388,7 +1239,7 @@ int tslua_push_querycursor(lua_State *L) if (lua_gettop(L) >= 5 && !lua_isnil(L, 5)) { if (!lua_istable(L, 5)) { - return luaL_error(L, "table expected"); + return luaL_argerror(L, 5, "table expected"); } lua_pushnil(L); // [dict, ..., nil] while (lua_next(L, 5)) { @@ -1443,7 +1294,6 @@ static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex) static int querycursor_next_capture(lua_State *L) { TSQueryCursor *cursor = querycursor_check(L, 1); - TSQueryMatch match; uint32_t capture_index; if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) { @@ -1477,6 +1327,9 @@ static int querycursor_next_match(lua_State *L) static TSQueryCursor *querycursor_check(lua_State *L, int index) { TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR); + if (!ud || !(*ud)) { + luaL_argerror(L, index, "TSQueryCursor expected"); + } return *ud; } @@ -1487,17 +1340,26 @@ static int querycursor_gc(lua_State *L) return 0; } -static int querymatch_info(lua_State *L) +static TSQueryMatch *querymatch_check(lua_State *L, int index) { TSQueryMatch *ud = luaL_checkudata(L, 1, TS_META_QUERYMATCH); - lua_pushinteger(L, ud->id); - lua_pushinteger(L, ud->pattern_index + 1); + if (!ud) { + luaL_argerror(L, index, "TSQueryMatch expected"); + } + return ud; +} + +static int querymatch_info(lua_State *L) +{ + TSQueryMatch *match = querymatch_check(L, 1); + lua_pushinteger(L, match->id); + lua_pushinteger(L, match->pattern_index + 1); return 2; } static int querymatch_captures(lua_State *L) { - TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH); + TSQueryMatch *match = querymatch_check(L, 1); lua_newtable(L); // [match, nodes, captures] for (size_t i = 0; i < match->capture_count; i++) { TSQueryCapture capture = match->captures[i]; @@ -1523,11 +1385,7 @@ int tslua_parse_query(lua_State *L) return luaL_error(L, "string expected"); } - const char *lang_name = lua_tostring(L, 1); - TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); - if (!lang) { - return luaL_error(L, "no such language: %s", lang_name); - } + TSLanguage *lang = lang_check(L, 1); size_t len; const char *src = lua_tolstring(L, 2, &len); @@ -1625,16 +1483,15 @@ static void query_err_string(const char *src, int error_offset, TSQueryError err static TSQuery *query_check(lua_State *L, int index) { TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY); - return ud ? *ud : NULL; + if (!ud || !(*ud)) { + luaL_argerror(L, index, "TSQuery expected"); + } + return *ud; } static int query_gc(lua_State *L) { TSQuery *query = query_check(L, 1); - if (!query) { - return 0; - } - ts_query_delete(query); return 0; } @@ -1648,9 +1505,6 @@ static int query_tostring(lua_State *L) static int query_inspect(lua_State *L) { TSQuery *query = query_check(L, 1); - if (!query) { - return 0; - } // TSQueryInfo lua_createtable(L, 0, 2); // [retval] |