diff options
author | Josh Rahm <joshuarahm@gmail.com> | 2024-11-19 22:57:13 +0000 |
---|---|---|
committer | Josh Rahm <joshuarahm@gmail.com> | 2024-11-19 22:57:13 +0000 |
commit | 9be89f131f87608f224f0ee06d199fcd09d32176 (patch) | |
tree | 11022dcfa9e08cb4ac5581b16734196128688d48 /src/nvim/lua/treesitter.c | |
parent | ff7ed8f586589d620a806c3758fac4a47a8e7e15 (diff) | |
parent | 88085c2e80a7e3ac29aabb6b5420377eed99b8b6 (diff) | |
download | rneovim-9be89f131f87608f224f0ee06d199fcd09d32176.tar.gz rneovim-9be89f131f87608f224f0ee06d199fcd09d32176.tar.bz2 rneovim-9be89f131f87608f224f0ee06d199fcd09d32176.zip |
Merge remote-tracking branch 'upstream/master' into mix_20240309
Diffstat (limited to 'src/nvim/lua/treesitter.c')
-rw-r--r-- | src/nvim/lua/treesitter.c | 234 |
1 files changed, 162 insertions, 72 deletions
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index e87cf756a8..ab97704dfe 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -15,6 +15,10 @@ #include <tree_sitter/api.h> #include <uv.h> +#ifdef HAVE_WASMTIME +# include <wasm.h> +#endif + #include "klib/kvec.h" #include "nvim/api/private/helpers.h" #include "nvim/buffer_defs.h" @@ -24,6 +28,7 @@ #include "nvim/map_defs.h" #include "nvim/memline.h" #include "nvim/memory.h" +#include "nvim/os/fs.h" #include "nvim/pos_defs.h" #include "nvim/strings.h" #include "nvim/types_defs.h" @@ -34,7 +39,6 @@ #define TS_META_QUERY "treesitter_query" #define TS_META_QUERYCURSOR "treesitter_querycursor" #define TS_META_QUERYMATCH "treesitter_querymatch" -#define TS_META_TREECURSOR "treesitter_treecursor" typedef struct { LuaRef cb; @@ -53,6 +57,11 @@ typedef struct { static PMap(cstr_t) langs = MAP_INIT; +#ifdef HAVE_WASMTIME +static wasm_engine_t *wasmengine; +static TSWasmStore *ts_wasmstore; +#endif + // TSLanguage int tslua_has_language(lua_State *L) @@ -62,8 +71,59 @@ int tslua_has_language(lua_State *L) return 1; } -static TSLanguage *load_language(lua_State *L, const char *path, const char *lang_name, - const char *symbol) +#ifdef HAVE_WASMTIME +static char *read_file(const char *path, size_t *len) + FUNC_ATTR_MALLOC +{ + FILE *file = os_fopen(path, "r"); + if (file == NULL) { + return NULL; + } + fseek(file, 0L, SEEK_END); + *len = (size_t)ftell(file); + fseek(file, 0L, SEEK_SET); + char *data = xmalloc(*len); + if (fread(data, *len, 1, file) != 1) { + xfree(data); + fclose(file); + return NULL; + } + fclose(file); + return data; +} + +static const char *wasmerr_to_str(TSWasmErrorKind werr) +{ + switch (werr) { + case TSWasmErrorKindParse: + return "PARSE"; + case TSWasmErrorKindCompile: + return "COMPILE"; + case TSWasmErrorKindInstantiate: + return "INSTANTIATE"; + case TSWasmErrorKindAllocate: + return "ALLOCATE"; + default: + return "UNKNOWN"; + } +} +#endif + +int tslua_add_language_from_wasm(lua_State *L) +{ + return add_language(L, true); +} + +// Creates the language into the internal language map. +// +// Returns true if the language is correctly loaded in the language map +int tslua_add_language_from_object(lua_State *L) +{ + return add_language(L, false); +} + +static const TSLanguage *load_language_from_object(lua_State *L, const char *path, + const char *lang_name, const char *symbol) { uv_lib_t lib; if (uv_dlopen(path, &lib)) { @@ -91,16 +151,59 @@ static TSLanguage *load_language(lua_State *L, const char *path, const char *lan return lang; } -// Creates the language into the internal language map. -// -// Returns true if the language is correctly loaded in the language map -int tslua_add_language(lua_State *L) +static const TSLanguage *load_language_from_wasm(lua_State *L, const char *path, + const char *lang_name) +{ +#ifndef HAVE_WASMTIME + luaL_error(L, "Not supported"); + return NULL; +#else + if (wasmengine == NULL) { + wasmengine = wasm_engine_new(); + } + assert(wasmengine != NULL); + + TSWasmError werr = { 0 }; + if (ts_wasmstore == NULL) { + ts_wasmstore = ts_wasm_store_new(wasmengine, &werr); + } + + if (werr.kind > 0) { + luaL_error(L, "Error creating wasm store: (%s) %s", wasmerr_to_str(werr.kind), werr.message); + } + + size_t file_size = 0; + char *data = read_file(path, &file_size); + + if (data == NULL) { + luaL_error(L, "Unable to read file", path); + } + + const TSLanguage *lang = ts_wasm_store_load_language(ts_wasmstore, lang_name, data, + (uint32_t)file_size, &werr); + + xfree(data); + + if (werr.kind > 0) { + luaL_error(L, "Failed to load WASM parser %s: (%s) %s", path, wasmerr_to_str(werr.kind), + werr.message); + } + + if (lang == NULL) { + luaL_error(L, "Failed to load parser %s: internal error", path); + } + + return lang; +#endif +} + +static int add_language(lua_State *L, bool is_wasm) { const char *path = luaL_checkstring(L, 1); const char *lang_name = luaL_checkstring(L, 2); const char *symbol_name = lang_name; - if (lua_gettop(L) >= 3 && !lua_isnil(L, 3)) { + if (!is_wasm && lua_gettop(L) >= 3 && !lua_isnil(L, 3)) { symbol_name = luaL_checkstring(L, 3); } @@ -109,7 +212,9 @@ int tslua_add_language(lua_State *L) return 1; } - TSLanguage *lang = load_language(L, path, lang_name, symbol_name); + const TSLanguage *lang = is_wasm + ? load_language_from_wasm(L, path, lang_name) + : load_language_from_object(L, path, lang_name, symbol_name); uint32_t lang_version = ts_language_version(lang); if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION @@ -121,7 +226,7 @@ int tslua_add_language(lua_State *L) TREE_SITTER_LANGUAGE_VERSION, lang_version); } - pmap_put(cstr_t)(&langs, xstrdup(lang_name), lang); + pmap_put(cstr_t)(&langs, xstrdup(lang_name), (TSLanguage *)lang); lua_pushboolean(L, true); return 1; @@ -186,6 +291,9 @@ int tslua_inspect_lang(lua_State *L) lua_setfield(L, -2, "fields"); // [retval] + lua_pushboolean(L, ts_language_is_wasm(lang)); + lua_setfield(L, -2, "_wasm"); + lua_pushinteger(L, ts_language_version(lang)); // [retval, version] lua_setfield(L, -2, "_abi_version"); @@ -215,6 +323,13 @@ int tslua_push_parser(lua_State *L) TSParser **parser = lua_newuserdata(L, sizeof(TSParser *)); *parser = ts_parser_new(); +#ifdef HAVE_WASMTIME + if (ts_language_is_wasm(lang)) { + assert(wasmengine != NULL); + ts_parser_set_wasm_store(*parser, ts_wasmstore); + } +#endif + if (!ts_parser_set_language(*parser, lang)) { ts_parser_delete(*parser); const char *lang_name = luaL_checkstring(L, 1); @@ -279,7 +394,7 @@ static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position memcpy(buf, line + position.column, tocopy); // Translate embedded \n to NUL - memchrsub(buf, '\n', '\0', tocopy); + memchrsub(buf, '\n', NUL, tocopy); *bytes_read = (uint32_t)tocopy; if (tocopy < BUFSIZE) { // now add the final \n. If it didn't fit, input_cb will be called again @@ -686,20 +801,6 @@ static int tree_root(lua_State *L) return 1; } -// TSTreeCursor - -static struct luaL_Reg treecursor_meta[] = { - { "__gc", treecursor_gc }, - { NULL, NULL } -}; - -static int treecursor_gc(lua_State *L) -{ - TSTreeCursor *cursor = luaL_checkudata(L, 1, TS_META_TREECURSOR); - ts_tree_cursor_delete(cursor); - return 0; -} - // TSNode static struct luaL_Reg node_meta[] = { { "__tostring", node_tostring }, @@ -890,23 +991,14 @@ static int node_field(lua_State *L) size_t name_len; const char *field_name = luaL_checklstring(L, 2, &name_len); - TSTreeCursor cursor = ts_tree_cursor_new(node); - lua_newtable(L); // [table] - size_t curr_index = 0; - - if (ts_tree_cursor_goto_first_child(&cursor)) { - do { - const char *current_field = ts_tree_cursor_current_field_name(&cursor); - if (current_field != NULL && !strcmp(field_name, current_field)) { - push_node(L, ts_tree_cursor_current_node(&cursor), 1); // [table, node] - lua_rawseti(L, -2, (int)++curr_index); - } - } while (ts_tree_cursor_goto_next_sibling(&cursor)); + TSNode field = ts_node_child_by_field_name(node, field_name, (uint32_t)name_len); + if (!ts_node_is_null(field)) { + push_node(L, field, 1); // [table, node] + lua_rawseti(L, -2, 1); } - ts_tree_cursor_delete(&cursor); return 1; } @@ -1002,45 +1094,35 @@ static int node_named_descendant_for_range(lua_State *L) static int node_next_child(lua_State *L) { - TSTreeCursor *cursor = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR); + uint32_t *child_index = lua_touserdata(L, lua_upvalueindex(1)); TSNode source = node_check(L, lua_upvalueindex(2)); - // First call should return first child - if (ts_node_eq(source, ts_tree_cursor_current_node(cursor))) { - if (ts_tree_cursor_goto_first_child(cursor)) { - goto push; - } else { - return 0; - } - } - - if (!ts_tree_cursor_goto_next_sibling(cursor)) { + if (*child_index >= ts_node_child_count(source)) { return 0; } -push: - push_node(L, ts_tree_cursor_current_node(cursor), lua_upvalueindex(2)); // [node] - - const char *field = ts_tree_cursor_current_field_name(cursor); + TSNode child = ts_node_child(source, *child_index); + push_node(L, child, lua_upvalueindex(2)); + const char *field = ts_node_field_name_for_child(source, *child_index); if (field != NULL) { - lua_pushstring(L, ts_tree_cursor_current_field_name(cursor)); + lua_pushstring(L, field); } else { lua_pushnil(L); } // [node, field_name_or_nil] + + (*child_index)++; + return 2; } static int node_iter_children(lua_State *L) { - TSNode node = node_check(L, 1); - - TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata] - *ud = ts_tree_cursor_new(node); + node_check(L, 1); + uint32_t *child_index = lua_newuserdata(L, sizeof(uint32_t)); // [source_node,..., udata] + *child_index = 0; - lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR); // [udata, mt] - lua_setmetatable(L, -2); // [udata] - lua_pushvalue(L, 1); // [udata, source_node] + lua_pushvalue(L, 1); // [source_node, ..., udata, source_node] lua_pushcclosure(L, node_next_child, 2); return 1; @@ -1132,22 +1214,19 @@ static int node_prev_named_sibling(lua_State *L) static int node_named_children(lua_State *L) { TSNode source = node_check(L, 1); - TSTreeCursor cursor = ts_tree_cursor_new(source); lua_newtable(L); int curr_index = 0; - if (ts_tree_cursor_goto_first_child(&cursor)) { - do { - TSNode node = ts_tree_cursor_current_node(&cursor); - if (ts_node_is_named(node)) { - push_node(L, node, 1); - lua_rawseti(L, -2, ++curr_index); - } - } while (ts_tree_cursor_goto_next_sibling(&cursor)); + uint32_t n = ts_node_child_count(source); + for (uint32_t i = 0; i < n; i++) { + TSNode child = ts_node_child(source, i); + if (ts_node_is_named(child)) { + push_node(L, child, 1); + lua_rawseti(L, -2, ++curr_index); + } } - ts_tree_cursor_delete(&cursor); return 1; } @@ -1557,7 +1636,18 @@ void tslua_init(lua_State *L) build_meta(L, TS_META_QUERY, query_meta); build_meta(L, TS_META_QUERYCURSOR, querycursor_meta); build_meta(L, TS_META_QUERYMATCH, querymatch_meta); - build_meta(L, TS_META_TREECURSOR, treecursor_meta); ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree); } + +void tslua_free(void) +{ +#ifdef HAVE_WASMTIME + if (wasmengine != NULL) { + wasm_engine_delete(wasmengine); + } + if (ts_wasmstore != NULL) { + ts_wasm_store_delete(ts_wasmstore); + } +#endif +} |