diff options
Diffstat (limited to 'src/nvim/lua/treesitter.c')
-rw-r--r-- | src/nvim/lua/treesitter.c | 1109 |
1 files changed, 476 insertions, 633 deletions
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index 25a753b179..e87cf756a8 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -33,15 +33,10 @@ #define TS_META_NODE "treesitter_node" #define TS_META_QUERY "treesitter_query" #define TS_META_QUERYCURSOR "treesitter_querycursor" +#define TS_META_QUERYMATCH "treesitter_querymatch" #define TS_META_TREECURSOR "treesitter_treecursor" typedef struct { - TSQueryCursor *cursor; - int predicated_match; - int max_match_id; -} TSLua_cursor; - -typedef struct { LuaRef cb; lua_State *lstate; bool lex; @@ -56,126 +51,44 @@ typedef struct { # include "lua/treesitter.c.generated.h" #endif -// TSParser -static struct luaL_Reg parser_meta[] = { - { "__gc", parser_gc }, - { "__tostring", parser_tostring }, - { "parse", parser_parse }, - { "reset", parser_reset }, - { "set_included_ranges", parser_set_ranges }, - { "included_ranges", parser_get_ranges }, - { "set_timeout", parser_set_timeout }, - { "timeout", parser_get_timeout }, - { "_set_logger", parser_set_logger }, - { "_logger", parser_get_logger }, - { NULL, NULL } -}; - -// TSTree -static struct luaL_Reg tree_meta[] = { - { "__gc", tree_gc }, - { "__tostring", tree_tostring }, - { "root", tree_root }, - { "edit", tree_edit }, - { "included_ranges", tree_get_ranges }, - { "copy", tree_copy }, - { NULL, NULL } -}; - -// TSNode -static struct luaL_Reg node_meta[] = { - { "__tostring", node_tostring }, - { "__eq", node_eq }, - { "__len", node_child_count }, - { "id", node_id }, - { "range", node_range }, - { "start", node_start }, - { "end_", node_end }, - { "type", node_type }, - { "symbol", node_symbol }, - { "field", node_field }, - { "named", node_named }, - { "missing", node_missing }, - { "extra", node_extra }, - { "has_changes", node_has_changes }, - { "has_error", node_has_error }, - { "sexpr", node_sexpr }, - { "child_count", node_child_count }, - { "named_child_count", node_named_child_count }, - { "child", node_child }, - { "named_child", node_named_child }, - { "descendant_for_range", node_descendant_for_range }, - { "named_descendant_for_range", node_named_descendant_for_range }, - { "parent", node_parent }, - { "iter_children", node_iter_children }, - { "_rawquery", node_rawquery }, - { "next_sibling", node_next_sibling }, - { "prev_sibling", node_prev_sibling }, - { "next_named_sibling", node_next_named_sibling }, - { "prev_named_sibling", node_prev_named_sibling }, - { "named_children", node_named_children }, - { "root", node_root }, - { "tree", node_tree }, - { "byte_length", node_byte_length }, - { "equal", node_equal }, - - { NULL, NULL } -}; - -// TSQuery -static struct luaL_Reg query_meta[] = { - { "__gc", query_gc }, - { "__tostring", query_tostring }, - { "inspect", query_inspect }, - { NULL, NULL } -}; +static PMap(cstr_t) langs = MAP_INIT; -// cursors are not exposed, but still needs garbage collection -static struct luaL_Reg querycursor_meta[] = { - { "__gc", querycursor_gc }, - { NULL, NULL } -}; +// TSLanguage -static struct luaL_Reg treecursor_meta[] = { - { "__gc", treecursor_gc }, - { NULL, NULL } -}; - -static kvec_t(TSQueryCursor *) cursors = KV_INITIAL_VALUE; -static PMap(cstr_t) langs = MAP_INIT; +int tslua_has_language(lua_State *L) +{ + const char *lang_name = luaL_checkstring(L, 1); + lua_pushboolean(L, map_has(cstr_t, &langs, lang_name)); + return 1; +} -static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta) +static TSLanguage *load_language(lua_State *L, const char *path, const char *lang_name, + const char *symbol) { - if (luaL_newmetatable(L, tname)) { // [meta] - luaL_register(L, NULL, meta); + uv_lib_t lib; + if (uv_dlopen(path, &lib)) { + uv_dlclose(&lib); + luaL_error(L, "Failed to load parser for language '%s': uv_dlopen: %s", + lang_name, uv_dlerror(&lib)); + } - lua_pushvalue(L, -1); // [meta, meta] - lua_setfield(L, -2, "__index"); // [meta] + char symbol_buf[128]; + snprintf(symbol_buf, sizeof(symbol_buf), "tree_sitter_%s", symbol); + + TSLanguage *(*lang_parser)(void); + if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) { + uv_dlclose(&lib); + luaL_error(L, "Failed to load parser: uv_dlsym: %s", uv_dlerror(&lib)); } - lua_pop(L, 1); // [] (don't use it now) -} -/// Init the tslua library. -/// -/// All global state is stored in the registry of the lua_State. -void tslua_init(lua_State *L) -{ - // type metatables - build_meta(L, TS_META_PARSER, parser_meta); - build_meta(L, TS_META_TREE, tree_meta); - build_meta(L, TS_META_NODE, node_meta); - build_meta(L, TS_META_QUERY, query_meta); - build_meta(L, TS_META_QUERYCURSOR, querycursor_meta); - build_meta(L, TS_META_TREECURSOR, treecursor_meta); + TSLanguage *lang = lang_parser(); - ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree); -} + if (lang == NULL) { + uv_dlclose(&lib); + luaL_error(L, "Failed to load parser %s: internal error", path); + } -int tslua_has_language(lua_State *L) -{ - const char *lang_name = luaL_checkstring(L, 1); - lua_pushboolean(L, map_has(cstr_t, &langs, lang_name)); - return 1; + return lang; } // Creates the language into the internal language map. @@ -196,34 +109,7 @@ int tslua_add_language(lua_State *L) return 1; } -#define BUFSIZE 128 - char symbol_buf[BUFSIZE]; - snprintf(symbol_buf, BUFSIZE, "tree_sitter_%s", symbol_name); -#undef BUFSIZE - - uv_lib_t lib; - if (uv_dlopen(path, &lib)) { - snprintf(IObuff, IOSIZE, "Failed to load parser for language '%s': uv_dlopen: %s", - lang_name, uv_dlerror(&lib)); - uv_dlclose(&lib); - lua_pushstring(L, IObuff); - return lua_error(L); - } - - TSLanguage *(*lang_parser)(void); - if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) { - snprintf(IObuff, IOSIZE, "Failed to load parser: uv_dlsym: %s", - uv_dlerror(&lib)); - uv_dlclose(&lib); - lua_pushstring(L, IObuff); - return lua_error(L); - } - - TSLanguage *lang = lang_parser(); - if (lang == NULL) { - uv_dlclose(&lib); - return luaL_error(L, "Failed to load parser %s: internal error", path); - } + TSLanguage *lang = load_language(L, path, lang_name, symbol_name); uint32_t lang_version = ts_language_version(lang); if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION @@ -254,14 +140,19 @@ int tslua_remove_lang(lua_State *L) return 1; } -int tslua_inspect_lang(lua_State *L) +static TSLanguage *lang_check(lua_State *L, int index) { - const char *lang_name = luaL_checkstring(L, 1); - + const char *lang_name = luaL_checkstring(L, index); TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); if (!lang) { - return luaL_error(L, "no such language: %s", lang_name); + luaL_error(L, "no such language: %s", lang_name); } + return lang; +} + +int tslua_inspect_lang(lua_State *L) +{ + TSLanguage *lang = lang_check(L, 1); lua_createtable(L, 0, 2); // [retval] @@ -295,28 +186,38 @@ int tslua_inspect_lang(lua_State *L) lua_setfield(L, -2, "fields"); // [retval] - uint32_t lang_version = ts_language_version(lang); - lua_pushinteger(L, lang_version); // [retval, version] + lua_pushinteger(L, ts_language_version(lang)); // [retval, version] lua_setfield(L, -2, "_abi_version"); return 1; } +// TSParser + +static struct luaL_Reg parser_meta[] = { + { "__gc", parser_gc }, + { "__tostring", parser_tostring }, + { "parse", parser_parse }, + { "reset", parser_reset }, + { "set_included_ranges", parser_set_ranges }, + { "included_ranges", parser_get_ranges }, + { "set_timeout", parser_set_timeout }, + { "timeout", parser_get_timeout }, + { "_set_logger", parser_set_logger }, + { "_logger", parser_get_logger }, + { NULL, NULL } +}; + int tslua_push_parser(lua_State *L) { - // Gather language name - const char *lang_name = luaL_checkstring(L, 1); - - TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); - if (!lang) { - return luaL_error(L, "no such language: %s", lang_name); - } + TSLanguage *lang = lang_check(L, 1); TSParser **parser = lua_newuserdata(L, sizeof(TSParser *)); *parser = ts_parser_new(); if (!ts_parser_set_language(*parser, lang)) { ts_parser_delete(*parser); + const char *lang_name = luaL_checkstring(L, 1); return luaL_error(L, "Failed to load language : %s", lang_name); } @@ -325,9 +226,11 @@ int tslua_push_parser(lua_State *L) return 1; } -static TSParser **parser_check(lua_State *L, uint16_t index) +static TSParser *parser_check(lua_State *L, uint16_t index) { - return luaL_checkudata(L, index, TS_META_PARSER); + TSParser **ud = luaL_checkudata(L, index, TS_META_PARSER); + luaL_argcheck(L, *ud, index, "TSParser expected"); + return *ud; } static void logger_gc(TSLogger logger) @@ -343,13 +246,9 @@ static void logger_gc(TSLogger logger) static int parser_gc(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } - - logger_gc(ts_parser_logger(*p)); - ts_parser_delete(*p); + TSParser *p = parser_check(L, 1); + logger_gc(ts_parser_logger(p)); + ts_parser_delete(p); return 0; } @@ -371,7 +270,7 @@ static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position return ""; } char *line = ml_get_buf(bp, (linenr_T)position.row + 1); - size_t len = strlen(line); + size_t len = (size_t)ml_get_buf_len(bp, (linenr_T)position.row + 1); if (position.column > len) { *bytes_read = 0; return ""; @@ -422,14 +321,10 @@ static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length static int parser_parse(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p || !(*p)) { - return 0; - } - + TSParser *p = parser_check(L, 1); TSTree *old_tree = NULL; if (!lua_isnil(L, 2)) { - TSLuaTree *ud = tree_check(L, 2); + TSLuaTree *ud = luaL_checkudata(L, 2, TS_META_TREE); old_tree = ud ? ud->tree : NULL; } @@ -445,7 +340,7 @@ static int parser_parse(lua_State *L) switch (lua_type(L, 3)) { case LUA_TSTRING: str = lua_tolstring(L, 3, &len); - new_tree = ts_parser_parse_string(*p, old_tree, str, (uint32_t)len); + new_tree = ts_parser_parse_string(p, old_tree, str, (uint32_t)len); break; case LUA_TNUMBER: @@ -461,7 +356,7 @@ static int parser_parse(lua_State *L) } input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8 }; - new_tree = ts_parser_parse(*p, old_tree, input); + new_tree = ts_parser_parse(p, old_tree, input); break; @@ -492,70 +387,14 @@ static int parser_parse(lua_State *L) static int parser_reset(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (p && *p) { - ts_parser_reset(*p); - } - + TSParser *p = parser_check(L, 1); + ts_parser_reset(p); return 0; } -static int tree_copy(lua_State *L) +static void range_err(lua_State *L) { - TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } - - TSTree *copy = ts_tree_copy(ud->tree); - push_tree(L, copy); // [tree] - - return 1; -} - -static int tree_edit(lua_State *L) -{ - if (lua_gettop(L) < 10) { - lua_pushstring(L, "not enough args to tree:edit()"); - return lua_error(L); - } - - TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } - - uint32_t start_byte = (uint32_t)luaL_checkint(L, 2); - uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3); - uint32_t new_end_byte = (uint32_t)luaL_checkint(L, 4); - TSPoint start_point = { (uint32_t)luaL_checkint(L, 5), (uint32_t)luaL_checkint(L, 6) }; - TSPoint old_end_point = { (uint32_t)luaL_checkint(L, 7), (uint32_t)luaL_checkint(L, 8) }; - TSPoint new_end_point = { (uint32_t)luaL_checkint(L, 9), (uint32_t)luaL_checkint(L, 10) }; - - TSInputEdit edit = { start_byte, old_end_byte, new_end_byte, - start_point, old_end_point, new_end_point }; - - ts_tree_edit(ud->tree, &edit); - - return 0; -} - -static int tree_get_ranges(lua_State *L) -{ - TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } - - bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); - - uint32_t len; - TSRange *ranges = ts_tree_included_ranges(ud->tree, &len); - - push_ranges(L, ranges, len, include_bytes); - - xfree(ranges); - return 1; + luaL_error(L, "Ranges can only be made from 6 element long tables or nodes."); } // Use the top of the stack (without popping it) to create a TSRange, it can be @@ -567,7 +406,7 @@ static void range_from_lua(lua_State *L, TSRange *range) if (lua_istable(L, -1)) { // should be a table of 6 elements if (lua_objlen(L, -1) != 6) { - goto error; + range_err(L); } lua_rawgeti(L, -1, 1); // [ range, start_row] @@ -606,7 +445,7 @@ static void range_from_lua(lua_State *L, TSRange *range) .start_byte = start_byte, .end_byte = end_byte, }; - } else if (node_check(L, -1, &node)) { + } else if (node_check_opt(L, -1, &node)) { *range = (TSRange) { .start_point = ts_node_start_point(node), .end_point = ts_node_end_point(node), @@ -614,30 +453,19 @@ static void range_from_lua(lua_State *L, TSRange *range) .end_byte = ts_node_end_byte(node) }; } else { - goto error; + range_err(L); } - return; -error: - luaL_error(L, - "Ranges can only be made from 6 element long tables or nodes."); } static int parser_set_ranges(lua_State *L) { if (lua_gettop(L) < 2) { - return luaL_error(L, - "not enough args to parser:set_included_ranges()"); + return luaL_error(L, "not enough args to parser:set_included_ranges()"); } - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); - if (!lua_istable(L, 2)) { - return luaL_error(L, - "argument for parser:set_included_ranges() should be a table."); - } + luaL_argcheck(L, lua_istable(L, 2), 2, "table expected."); size_t tbl_len = lua_objlen(L, 2); TSRange *ranges = xmalloc(sizeof(TSRange) * tbl_len); @@ -650,7 +478,7 @@ static int parser_set_ranges(lua_State *L) } // This memcpies ranges, thus we can free it afterwards - ts_parser_set_included_ranges(*p, ranges, (uint32_t)tbl_len); + ts_parser_set_included_ranges(p, ranges, (uint32_t)tbl_len); xfree(ranges); return 0; @@ -658,15 +486,12 @@ static int parser_set_ranges(lua_State *L) static int parser_get_ranges(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); uint32_t len; - const TSRange *ranges = ts_parser_included_ranges(*p, &len); + const TSRange *ranges = ts_parser_included_ranges(p, &len); push_ranges(L, ranges, len, include_bytes); return 1; @@ -674,28 +499,21 @@ static int parser_get_ranges(lua_State *L) static int parser_set_timeout(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); if (lua_gettop(L) < 2) { luaL_error(L, "integer expected"); } uint32_t timeout = (uint32_t)luaL_checkinteger(L, 2); - ts_parser_set_timeout_micros(*p, timeout); + ts_parser_set_timeout_micros(p, timeout); return 0; } static int parser_get_timeout(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } - - lua_pushinteger(L, (lua_Integer)ts_parser_timeout_micros(*p)); + TSParser *p = parser_check(L, 1); + lua_pushinteger(L, (lua_Integer)ts_parser_timeout_micros(p)); return 1; } @@ -719,22 +537,11 @@ static void logger_cb(void *payload, TSLogType logtype, const char *s) static int parser_set_logger(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } + TSParser *p = parser_check(L, 1); - if (!lua_isboolean(L, 2)) { - return luaL_argerror(L, 2, "boolean expected"); - } - - if (!lua_isboolean(L, 3)) { - return luaL_argerror(L, 3, "boolean expected"); - } - - if (!lua_isfunction(L, 4)) { - return luaL_argerror(L, 4, "function expected"); - } + luaL_argcheck(L, lua_isboolean(L, 2), 2, "boolean expected"); + luaL_argcheck(L, lua_isboolean(L, 3), 3, "boolean expected"); + luaL_argcheck(L, lua_isfunction(L, 4), 4, "function expected"); TSLuaLoggerOpts *opts = xmalloc(sizeof(TSLuaLoggerOpts)); lua_pushvalue(L, 4); @@ -752,18 +559,14 @@ static int parser_set_logger(lua_State *L) .log = logger_cb }; - ts_parser_set_logger(*p, logger); + ts_parser_set_logger(p, logger); return 0; } static int parser_get_logger(lua_State *L) { - TSParser **p = parser_check(L, 1); - if (!p) { - return 0; - } - - TSLogger logger = ts_parser_logger(*p); + TSParser *p = parser_check(L, 1); + TSLogger logger = ts_parser_logger(p); if (logger.log) { TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload; lua_rawgeti(L, LUA_REGISTRYINDEX, opts->cb); @@ -774,7 +577,17 @@ static int parser_get_logger(lua_State *L) return 1; } -// Tree methods +// TSTree + +static struct luaL_Reg tree_meta[] = { + { "__gc", tree_gc }, + { "__tostring", tree_tostring }, + { "root", tree_root }, + { "edit", tree_edit }, + { "included_ranges", tree_get_ranges }, + { "copy", tree_copy }, + { NULL, NULL } +}; /// Push tree interface on to the lua stack. /// @@ -804,18 +617,58 @@ static void push_tree(lua_State *L, TSTree *tree) lua_setfenv(L, -2); // [udata] } -static TSLuaTree *tree_check(lua_State *L, int index) +static int tree_copy(lua_State *L) { - TSLuaTree *ud = luaL_checkudata(L, index, TS_META_TREE); - return ud; + TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); + TSTree *copy = ts_tree_copy(ud->tree); + push_tree(L, copy); // [tree] + + return 1; } -static int tree_gc(lua_State *L) +static int tree_edit(lua_State *L) { - TSLuaTree *ud = tree_check(L, 1); - if (ud) { - ts_tree_delete(ud->tree); + if (lua_gettop(L) < 10) { + lua_pushstring(L, "not enough args to tree:edit()"); + return lua_error(L); } + + TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); + + uint32_t start_byte = (uint32_t)luaL_checkint(L, 2); + uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3); + uint32_t new_end_byte = (uint32_t)luaL_checkint(L, 4); + TSPoint start_point = { (uint32_t)luaL_checkint(L, 5), (uint32_t)luaL_checkint(L, 6) }; + TSPoint old_end_point = { (uint32_t)luaL_checkint(L, 7), (uint32_t)luaL_checkint(L, 8) }; + TSPoint new_end_point = { (uint32_t)luaL_checkint(L, 9), (uint32_t)luaL_checkint(L, 10) }; + + TSInputEdit edit = { start_byte, old_end_byte, new_end_byte, + start_point, old_end_point, new_end_point }; + + ts_tree_edit(ud->tree, &edit); + + return 0; +} + +static int tree_get_ranges(lua_State *L) +{ + TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); + + bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); + + uint32_t len; + TSRange *ranges = ts_tree_included_ranges(ud->tree, &len); + + push_ranges(L, ranges, len, include_bytes); + + xfree(ranges); + return 1; +} + +static int tree_gc(lua_State *L) +{ + TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); + ts_tree_delete(ud->tree); return 0; } @@ -827,16 +680,66 @@ static int tree_tostring(lua_State *L) static int tree_root(lua_State *L) { - TSLuaTree *ud = tree_check(L, 1); - if (!ud) { - return 0; - } + TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE); TSNode root = ts_tree_root_node(ud->tree); push_node(L, root, 1); return 1; } -// Node methods +// TSTreeCursor + +static struct luaL_Reg treecursor_meta[] = { + { "__gc", treecursor_gc }, + { NULL, NULL } +}; + +static int treecursor_gc(lua_State *L) +{ + TSTreeCursor *cursor = luaL_checkudata(L, 1, TS_META_TREECURSOR); + ts_tree_cursor_delete(cursor); + return 0; +} + +// TSNode +static struct luaL_Reg node_meta[] = { + { "__tostring", node_tostring }, + { "__eq", node_eq }, + { "__len", node_child_count }, + { "id", node_id }, + { "range", node_range }, + { "start", node_start }, + { "end_", node_end }, + { "type", node_type }, + { "symbol", node_symbol }, + { "field", node_field }, + { "named", node_named }, + { "missing", node_missing }, + { "extra", node_extra }, + { "has_changes", node_has_changes }, + { "has_error", node_has_error }, + { "sexpr", node_sexpr }, + { "child_count", node_child_count }, + { "named_child_count", node_named_child_count }, + { "child", node_child }, + { "named_child", node_named_child }, + { "descendant_for_range", node_descendant_for_range }, + { "named_descendant_for_range", node_named_descendant_for_range }, + { "parent", node_parent }, + { "__has_ancestor", __has_ancestor }, + { "child_containing_descendant", node_child_containing_descendant }, + { "iter_children", node_iter_children }, + { "next_sibling", node_next_sibling }, + { "prev_sibling", node_prev_sibling }, + { "next_named_sibling", node_next_named_sibling }, + { "prev_named_sibling", node_prev_named_sibling }, + { "named_children", node_named_children }, + { "root", node_root }, + { "tree", node_tree }, + { "byte_length", node_byte_length }, + { "equal", node_equal }, + + { NULL, NULL } +}; /// Push node interface on to the Lua stack /// @@ -860,7 +763,7 @@ static void push_node(lua_State *L, TSNode node, int uindex) lua_setfenv(L, -2); // [udata] } -static bool node_check(lua_State *L, int index, TSNode *res) +static bool node_check_opt(lua_State *L, int index, TSNode *res) { TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); if (ud) { @@ -870,12 +773,15 @@ static bool node_check(lua_State *L, int index, TSNode *res) return false; } +static TSNode node_check(lua_State *L, int index) +{ + TSNode *ud = luaL_checkudata(L, index, TS_META_NODE); + return *ud; +} + static int node_tostring(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushstring(L, "<node "); lua_pushstring(L, ts_node_type(node)); lua_pushstring(L, ">"); @@ -885,37 +791,22 @@ static int node_tostring(lua_State *L) static int node_eq(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - - TSNode node2; - if (!node_check(L, 2, &node2)) { - return 0; - } - + TSNode node = node_check(L, 1); + TSNode node2 = node_check(L, 2); lua_pushboolean(L, ts_node_eq(node, node2)); return 1; } static int node_id(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + TSNode node = node_check(L, 1); lua_pushlstring(L, (const char *)&node.id, sizeof node.id); return 1; } static int node_range(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2); @@ -941,10 +832,7 @@ static int node_range(lua_State *L) static int node_start(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint start = ts_node_start_point(node); uint32_t start_byte = ts_node_start_byte(node); lua_pushinteger(L, start.row); @@ -955,10 +843,7 @@ static int node_start(lua_State *L) static int node_end(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint end = ts_node_end_point(node); uint32_t end_byte = ts_node_end_byte(node); lua_pushinteger(L, end.row); @@ -969,10 +854,7 @@ static int node_end(lua_State *L) static int node_child_count(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t count = ts_node_child_count(node); lua_pushinteger(L, count); return 1; @@ -980,10 +862,7 @@ static int node_child_count(lua_State *L) static int node_named_child_count(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t count = ts_node_named_child_count(node); lua_pushinteger(L, count); return 1; @@ -991,20 +870,14 @@ static int node_named_child_count(lua_State *L) static int node_type(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushstring(L, ts_node_type(node)); return 1; } static int node_symbol(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSSymbol symbol = ts_node_symbol(node); lua_pushinteger(L, symbol); return 1; @@ -1012,10 +885,7 @@ static int node_symbol(lua_State *L) static int node_field(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); size_t name_len; const char *field_name = luaL_checklstring(L, 2, &name_len); @@ -1042,20 +912,14 @@ static int node_field(lua_State *L) static int node_named(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_is_named(node)); return 1; } static int node_sexpr(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); char *allocated = ts_node_string(node); lua_pushstring(L, allocated); xfree(allocated); @@ -1064,50 +928,35 @@ static int node_sexpr(lua_State *L) static int node_missing(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_is_missing(node)); return 1; } static int node_extra(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_is_extra(node)); return 1; } static int node_has_changes(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_has_changes(node)); return 1; } static int node_has_error(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); lua_pushboolean(L, ts_node_has_error(node)); return 1; } static int node_child(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t num = (uint32_t)lua_tointeger(L, 2); TSNode child = ts_node_child(node, num); @@ -1117,10 +966,7 @@ static int node_child(lua_State *L) static int node_named_child(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); uint32_t num = (uint32_t)lua_tointeger(L, 2); TSNode child = ts_node_named_child(node, num); @@ -1130,10 +976,7 @@ static int node_named_child(lua_State *L) static int node_descendant_for_range(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint start = { (uint32_t)lua_tointeger(L, 2), (uint32_t)lua_tointeger(L, 3) }; TSPoint end = { (uint32_t)lua_tointeger(L, 4), @@ -1146,10 +989,7 @@ static int node_descendant_for_range(lua_State *L) static int node_named_descendant_for_range(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSPoint start = { (uint32_t)lua_tointeger(L, 2), (uint32_t)lua_tointeger(L, 3) }; TSPoint end = { (uint32_t)lua_tointeger(L, 4), @@ -1162,54 +1002,41 @@ static int node_named_descendant_for_range(lua_State *L) static int node_next_child(lua_State *L) { - TSTreeCursor *ud = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR); - if (!ud) { - return 0; - } - - TSNode source; - if (!node_check(L, lua_upvalueindex(2), &source)) { - return 0; - } + TSTreeCursor *cursor = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR); + TSNode source = node_check(L, lua_upvalueindex(2)); // First call should return first child - if (ts_node_eq(source, ts_tree_cursor_current_node(ud))) { - if (ts_tree_cursor_goto_first_child(ud)) { + if (ts_node_eq(source, ts_tree_cursor_current_node(cursor))) { + if (ts_tree_cursor_goto_first_child(cursor)) { goto push; } else { - goto end; + return 0; } } - if (ts_tree_cursor_goto_next_sibling(ud)) { -push: - push_node(L, - ts_tree_cursor_current_node(ud), - lua_upvalueindex(2)); // [node] + if (!ts_tree_cursor_goto_next_sibling(cursor)) { + return 0; + } - const char *field = ts_tree_cursor_current_field_name(ud); +push: + push_node(L, ts_tree_cursor_current_node(cursor), lua_upvalueindex(2)); // [node] - if (field != NULL) { - lua_pushstring(L, ts_tree_cursor_current_field_name(ud)); - } else { - lua_pushnil(L); - } // [node, field_name_or_nil] - return 2; - } + const char *field = ts_tree_cursor_current_field_name(cursor); -end: - return 0; + if (field != NULL) { + lua_pushstring(L, ts_tree_cursor_current_field_name(cursor)); + } else { + lua_pushnil(L); + } // [node, field_name_or_nil] + return 2; } static int node_iter_children(lua_State *L) { - TSNode source; - if (!node_check(L, 1, &source)) { - return 0; - } + TSNode node = node_check(L, 1); TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata] - *ud = ts_tree_cursor_new(source); + *ud = ts_tree_cursor_new(node); lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR); // [udata, mt] lua_setmetatable(L, -2); // [udata] @@ -1219,30 +1046,60 @@ static int node_iter_children(lua_State *L) return 1; } -static int treecursor_gc(lua_State *L) +static int node_parent(lua_State *L) { - TSTreeCursor *ud = luaL_checkudata(L, 1, TS_META_TREECURSOR); - ts_tree_cursor_delete(ud); - return 0; + TSNode node = node_check(L, 1); + TSNode parent = ts_node_parent(node); + push_node(L, parent, 1); + return 1; } -static int node_parent(lua_State *L) +static int __has_ancestor(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; + TSNode descendant = node_check(L, 1); + if (lua_type(L, 2) != LUA_TTABLE) { + lua_pushboolean(L, false); + return 1; } - TSNode parent = ts_node_parent(node); - push_node(L, parent, 1); + int const pred_len = (int)lua_objlen(L, 2); + + TSNode node = ts_tree_root_node(descendant.tree); + while (!ts_node_is_null(node)) { + char const *node_type = ts_node_type(node); + size_t node_type_len = strlen(node_type); + + for (int i = 3; i <= pred_len; i++) { + lua_rawgeti(L, 2, i); + if (lua_type(L, -1) == LUA_TSTRING) { + size_t check_len; + char const *check_str = lua_tolstring(L, -1, &check_len); + if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) { + lua_pushboolean(L, true); + return 1; + } + } + lua_pop(L, 1); + } + + node = ts_node_child_containing_descendant(node, descendant); + } + + lua_pushboolean(L, false); + return 1; +} + +static int node_child_containing_descendant(lua_State *L) +{ + TSNode node = node_check(L, 1); + TSNode descendant = node_check(L, 2); + TSNode child = ts_node_child_containing_descendant(node, descendant); + push_node(L, child, 1); return 1; } static int node_next_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_next_sibling(node); push_node(L, sibling, 1); return 1; @@ -1250,10 +1107,7 @@ static int node_next_sibling(lua_State *L) static int node_prev_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_prev_sibling(node); push_node(L, sibling, 1); return 1; @@ -1261,10 +1115,7 @@ static int node_prev_sibling(lua_State *L) static int node_next_named_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_next_named_sibling(node); push_node(L, sibling, 1); return 1; @@ -1272,10 +1123,7 @@ static int node_next_named_sibling(lua_State *L) static int node_prev_named_sibling(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } + TSNode node = node_check(L, 1); TSNode sibling = ts_node_prev_named_sibling(node); push_node(L, sibling, 1); return 1; @@ -1283,10 +1131,7 @@ static int node_prev_named_sibling(lua_State *L) static int node_named_children(lua_State *L) { - TSNode source; - if (!node_check(L, 1, &source)) { - return 0; - } + TSNode source = node_check(L, 1); TSTreeCursor cursor = ts_tree_cursor_new(source); lua_newtable(L); @@ -1308,11 +1153,7 @@ static int node_named_children(lua_State *L) static int node_root(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + TSNode node = node_check(L, 1); TSNode root = ts_tree_root_node(node.tree); push_node(L, root, 1); return 1; @@ -1320,215 +1161,196 @@ static int node_root(lua_State *L) static int node_tree(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + node_check(L, 1); lua_getfenv(L, 1); // [udata, reftable] lua_rawgeti(L, -1, 1); // [udata, reftable, tree_udata] - return 1; } static int node_byte_length(lua_State *L) { - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } - + TSNode node = node_check(L, 1); uint32_t start_byte = ts_node_start_byte(node); uint32_t end_byte = ts_node_end_byte(node); - lua_pushinteger(L, end_byte - start_byte); return 1; } static int node_equal(lua_State *L) { - TSNode node1; - if (!node_check(L, 1, &node1)) { - return 0; - } - - TSNode node2; - if (!node_check(L, 2, &node2)) { - return luaL_error(L, "TSNode expected"); - } - + TSNode node1 = node_check(L, 1); + TSNode node2 = node_check(L, 2); lua_pushboolean(L, ts_node_eq(node1, node2)); return 1; } -/// assumes the match table being on top of the stack -static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx) -{ - // [match] - for (size_t i = 0; i < match->capture_count; i++) { - lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures] - if (lua_isnil(L, -1)) { // [match, nil] - lua_pop(L, 1); // [match] - lua_createtable(L, 1, 0); // [match, captures] - } - push_node(L, match->captures[i].node, nodeidx); // [match, captures, node] - lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures] - lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match] - } -} +// TSQueryCursor -static int query_next_match(lua_State *L) -{ - TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1)); - TSQueryCursor *cursor = ud->cursor; - - TSQuery *query = query_check(L, lua_upvalueindex(3)); - TSQueryMatch match; - if (ts_query_cursor_next_match(cursor, &match)) { - lua_pushinteger(L, match.pattern_index + 1); // [index] - lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match] - set_match(L, &match, lua_upvalueindex(2)); - return 2; - } - return 0; -} +static struct luaL_Reg querycursor_meta[] = { + { "remove_match", querycursor_remove_match }, + { "next_capture", querycursor_next_capture }, + { "next_match", querycursor_next_match }, + { "__gc", querycursor_gc }, + { NULL, NULL } +}; -static int query_next_capture(lua_State *L) +int tslua_push_querycursor(lua_State *L) { - // Upvalues are: - // [ cursor, node, query, current_match ] - TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1)); - TSQueryCursor *cursor = ud->cursor; - - TSQuery *query = query_check(L, lua_upvalueindex(3)); - - if (ud->predicated_match > -1) { - lua_getfield(L, lua_upvalueindex(4), "active"); - bool active = lua_toboolean(L, -1); - lua_pop(L, 1); - if (!active) { - ts_query_cursor_remove_match(cursor, (uint32_t)ud->predicated_match); - } - ud->predicated_match = -1; - } + TSNode node = node_check(L, 1); - TSQueryMatch match; - uint32_t capture_index; - if (ts_query_cursor_next_capture(cursor, &match, &capture_index)) { - TSQueryCapture capture = match.captures[capture_index]; - - // TODO(vigoux): handle capture quantifiers here - lua_pushinteger(L, capture.index + 1); // [index] - push_node(L, capture.node, lua_upvalueindex(2)); // [index, node] - - // Now check if we need to run the predicates - uint32_t n_pred; - ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred); - - if (n_pred > 0 && (ud->max_match_id < (int)match.id)) { - ud->max_match_id = (int)match.id; - - // Create a new cleared match table - lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match] - set_match(L, &match, lua_upvalueindex(2)); - lua_pushinteger(L, match.pattern_index + 1); - lua_setfield(L, -2, "pattern"); - - if (match.capture_count > 1) { - ud->predicated_match = (int)match.id; - lua_pushboolean(L, false); - lua_setfield(L, -2, "active"); - } - - // Set current_match to the new match - lua_replace(L, lua_upvalueindex(4)); // [index, node] - lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match] - return 3; - } - return 2; - } - return 0; -} - -static int node_rawquery(lua_State *L) -{ - TSNode node; - if (!node_check(L, 1, &node)) { - return 0; - } TSQuery *query = query_check(L, 2); - - TSQueryCursor *cursor; - if (kv_size(cursors) > 0) { - cursor = kv_pop(cursors); - } else { - cursor = ts_query_cursor_new(); - } - - ts_query_cursor_set_max_start_depth(cursor, UINT32_MAX); - ts_query_cursor_set_match_limit(cursor, 256); + TSQueryCursor *cursor = ts_query_cursor_new(); ts_query_cursor_exec(cursor, query, node); - bool captures = lua_toboolean(L, 3); - - if (lua_gettop(L) >= 4) { - uint32_t start = (uint32_t)luaL_checkinteger(L, 4); - uint32_t end = lua_gettop(L) >= 5 ? (uint32_t)luaL_checkinteger(L, 5) : MAXLNUM; + if (lua_gettop(L) >= 3) { + uint32_t start = (uint32_t)luaL_checkinteger(L, 3); + uint32_t end = lua_gettop(L) >= 4 ? (uint32_t)luaL_checkinteger(L, 4) : MAXLNUM; ts_query_cursor_set_point_range(cursor, (TSPoint){ start, 0 }, (TSPoint){ end, 0 }); } - if (lua_gettop(L) >= 6 && !lua_isnil(L, 6)) { - if (!lua_istable(L, 6)) { - return luaL_error(L, "table expected"); - } - lua_pushnil(L); - // stack: [dict, ..., nil] - while (lua_next(L, 6)) { - // stack: [dict, ..., key, value] + if (lua_gettop(L) >= 5 && !lua_isnil(L, 5)) { + luaL_argcheck(L, lua_istable(L, 5), 5, "table expected"); + lua_pushnil(L); // [dict, ..., nil] + while (lua_next(L, 5)) { + // [dict, ..., key, value] if (lua_type(L, -2) == LUA_TSTRING) { char *k = (char *)lua_tostring(L, -2); if (strequal("max_start_depth", k)) { uint32_t max_start_depth = (uint32_t)lua_tointeger(L, -1); ts_query_cursor_set_max_start_depth(cursor, max_start_depth); + } else if (strequal("match_limit", k)) { + uint32_t match_limit = (uint32_t)lua_tointeger(L, -1); + ts_query_cursor_set_match_limit(cursor, match_limit); } } - lua_pop(L, 1); // pop the value; lua_next will pop the key. - // stack: [dict, ..., key] + // pop the value; lua_next will pop the key. + lua_pop(L, 1); // [dict, ..., key] } } - TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud)); // [udata] - ud->cursor = cursor; - ud->predicated_match = -1; - ud->max_match_id = -1; + TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud)); // [node, query, ..., udata] + *ud = cursor; + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); // [node, query, ..., udata, meta] + lua_setmetatable(L, -2); // [node, query, ..., udata] - lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); - lua_setmetatable(L, -2); // [udata] - lua_pushvalue(L, 1); // [udata, node] + // Copy the fenv which contains the nodes tree. + lua_getfenv(L, 1); // [udata, reftable] + lua_setfenv(L, -2); // [udata] - // include query separately, as to keep a ref to it for gc - lua_pushvalue(L, 2); // [udata, node, query] + return 1; +} - if (captures) { - // placeholder for match state - lua_createtable(L, (int)ts_query_capture_count(query), 2); // [u, n, q, match] - lua_pushcclosure(L, query_next_capture, 4); // [closure] - } else { - lua_pushcclosure(L, query_next_match, 3); // [closure] +static int querycursor_remove_match(lua_State *L) +{ + TSQueryCursor *cursor = querycursor_check(L, 1); + uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2); + ts_query_cursor_remove_match(cursor, match_id); + return 0; +} + +static int querycursor_next_capture(lua_State *L) +{ + TSQueryCursor *cursor = querycursor_check(L, 1); + TSQueryMatch match; + uint32_t capture_index; + if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) { + return 0; + } + + TSQueryCapture capture = match.captures[capture_index]; + + // Handle capture quantifiers here + lua_pushinteger(L, capture.index + 1); // [index] + push_node(L, capture.node, 1); // [index, node] + push_querymatch(L, &match, 1); + + return 3; +} + +static int querycursor_next_match(lua_State *L) +{ + TSQueryCursor *cursor = querycursor_check(L, 1); + + TSQueryMatch match; + if (!ts_query_cursor_next_match(cursor, &match)) { + return 0; } + push_querymatch(L, &match, 1); + return 1; } +static TSQueryCursor *querycursor_check(lua_State *L, int index) +{ + TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR); + luaL_argcheck(L, *ud, index, "TSQueryCursor expected"); + return *ud; +} + static int querycursor_gc(lua_State *L) { - TSLua_cursor *ud = luaL_checkudata(L, 1, TS_META_QUERYCURSOR); - kv_push(cursors, ud->cursor); - ud->cursor = NULL; + TSQueryCursor *cursor = querycursor_check(L, 1); + ts_query_cursor_delete(cursor); return 0; } -// Query methods +// TSQueryMatch + +static struct luaL_Reg querymatch_meta[] = { + { "info", querymatch_info }, + { "captures", querymatch_captures }, + { NULL, NULL } +}; + +static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex) +{ + TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch)); // [udata] + *ud = *match; + lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH); // [udata, meta] + lua_setmetatable(L, -2); // [udata] + + // Copy the fenv which contains the nodes tree. + lua_getfenv(L, uindex); // [udata, reftable] + lua_setfenv(L, -2); // [udata] +} + +static int querymatch_info(lua_State *L) +{ + TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH); + lua_pushinteger(L, match->id); + lua_pushinteger(L, match->pattern_index + 1); + return 2; +} + +static int querymatch_captures(lua_State *L) +{ + TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH); + lua_newtable(L); // [match, nodes, captures] + for (size_t i = 0; i < match->capture_count; i++) { + TSQueryCapture capture = match->captures[i]; + int index = (int)capture.index + 1; + + lua_rawgeti(L, -1, index); // [match, node, captures] + if (lua_isnil(L, -1)) { // [match, node, captures, nil] + lua_pop(L, 1); // [match, node, captures] + lua_newtable(L); // [match, node, captures, nodes] + } + push_node(L, capture.node, 1); // [match, node, captures, nodes, node] + lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, node, captures, nodes] + lua_rawseti(L, -2, index); // [match, node, captures] + } + return 1; +} + +// TSQuery + +static struct luaL_Reg query_meta[] = { + { "__gc", query_gc }, + { "__tostring", query_tostring }, + { "inspect", query_inspect }, + { NULL, NULL } +}; int tslua_parse_query(lua_State *L) { @@ -1536,15 +1358,12 @@ int tslua_parse_query(lua_State *L) return luaL_error(L, "string expected"); } - const char *lang_name = lua_tostring(L, 1); - TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name); - if (!lang) { - return luaL_error(L, "no such language: %s", lang_name); - } + TSLanguage *lang = lang_check(L, 1); size_t len; const char *src = lua_tolstring(L, 2, &len); + tslua_query_parse_count++; uint32_t error_offset; TSQueryError error_type; TSQuery *query = ts_query_new(lang, src, (uint32_t)len, &error_offset, &error_type); @@ -1638,16 +1457,13 @@ static void query_err_string(const char *src, int error_offset, TSQueryError err static TSQuery *query_check(lua_State *L, int index) { TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY); + luaL_argcheck(L, *ud, index, "TSQuery expected"); return *ud; } static int query_gc(lua_State *L) { TSQuery *query = query_check(L, 1); - if (!query) { - return 0; - } - ts_query_delete(query); return 0; } @@ -1661,9 +1477,6 @@ static int query_tostring(lua_State *L) static int query_inspect(lua_State *L) { TSQuery *query = query_check(L, 1); - if (!query) { - return 0; - } // TSQueryInfo lua_createtable(L, 0, 2); // [retval] @@ -1718,3 +1531,33 @@ static int query_inspect(lua_State *L) return 1; } + +// Library init + +static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta) +{ + if (luaL_newmetatable(L, tname)) { // [meta] + luaL_register(L, NULL, meta); + + lua_pushvalue(L, -1); // [meta, meta] + lua_setfield(L, -2, "__index"); // [meta] + } + lua_pop(L, 1); // [] (don't use it now) +} + +/// Init the tslua library. +/// +/// All global state is stored in the registry of the lua_State. +void tslua_init(lua_State *L) +{ + // type metatables + build_meta(L, TS_META_PARSER, parser_meta); + build_meta(L, TS_META_TREE, tree_meta); + build_meta(L, TS_META_NODE, node_meta); + build_meta(L, TS_META_QUERY, query_meta); + build_meta(L, TS_META_QUERYCURSOR, querycursor_meta); + build_meta(L, TS_META_QUERYMATCH, querymatch_meta); + build_meta(L, TS_META_TREECURSOR, treecursor_meta); + + ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree); +} |