aboutsummaryrefslogtreecommitdiff
path: root/src/nvim/lua/treesitter.c
diff options
context:
space:
mode:
authorJosh Rahm <joshuarahm@gmail.com>2024-11-19 22:57:13 +0000
committerJosh Rahm <joshuarahm@gmail.com>2024-11-19 22:57:13 +0000
commit9be89f131f87608f224f0ee06d199fcd09d32176 (patch)
tree11022dcfa9e08cb4ac5581b16734196128688d48 /src/nvim/lua/treesitter.c
parentff7ed8f586589d620a806c3758fac4a47a8e7e15 (diff)
parent88085c2e80a7e3ac29aabb6b5420377eed99b8b6 (diff)
downloadrneovim-9be89f131f87608f224f0ee06d199fcd09d32176.tar.gz
rneovim-9be89f131f87608f224f0ee06d199fcd09d32176.tar.bz2
rneovim-9be89f131f87608f224f0ee06d199fcd09d32176.zip
Merge remote-tracking branch 'upstream/master' into mix_20240309
Diffstat (limited to 'src/nvim/lua/treesitter.c')
-rw-r--r--src/nvim/lua/treesitter.c234
1 files changed, 162 insertions, 72 deletions
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c
index e87cf756a8..ab97704dfe 100644
--- a/src/nvim/lua/treesitter.c
+++ b/src/nvim/lua/treesitter.c
@@ -15,6 +15,10 @@
#include <tree_sitter/api.h>
#include <uv.h>
+#ifdef HAVE_WASMTIME
+# include <wasm.h>
+#endif
+
#include "klib/kvec.h"
#include "nvim/api/private/helpers.h"
#include "nvim/buffer_defs.h"
@@ -24,6 +28,7 @@
#include "nvim/map_defs.h"
#include "nvim/memline.h"
#include "nvim/memory.h"
+#include "nvim/os/fs.h"
#include "nvim/pos_defs.h"
#include "nvim/strings.h"
#include "nvim/types_defs.h"
@@ -34,7 +39,6 @@
#define TS_META_QUERY "treesitter_query"
#define TS_META_QUERYCURSOR "treesitter_querycursor"
#define TS_META_QUERYMATCH "treesitter_querymatch"
-#define TS_META_TREECURSOR "treesitter_treecursor"
typedef struct {
LuaRef cb;
@@ -53,6 +57,11 @@ typedef struct {
static PMap(cstr_t) langs = MAP_INIT;
+#ifdef HAVE_WASMTIME
+static wasm_engine_t *wasmengine;
+static TSWasmStore *ts_wasmstore;
+#endif
+
// TSLanguage
int tslua_has_language(lua_State *L)
@@ -62,8 +71,59 @@ int tslua_has_language(lua_State *L)
return 1;
}
-static TSLanguage *load_language(lua_State *L, const char *path, const char *lang_name,
- const char *symbol)
+#ifdef HAVE_WASMTIME
+static char *read_file(const char *path, size_t *len)
+ FUNC_ATTR_MALLOC
+{
+ FILE *file = os_fopen(path, "r");
+ if (file == NULL) {
+ return NULL;
+ }
+ fseek(file, 0L, SEEK_END);
+ *len = (size_t)ftell(file);
+ fseek(file, 0L, SEEK_SET);
+ char *data = xmalloc(*len);
+ if (fread(data, *len, 1, file) != 1) {
+ xfree(data);
+ fclose(file);
+ return NULL;
+ }
+ fclose(file);
+ return data;
+}
+
+static const char *wasmerr_to_str(TSWasmErrorKind werr)
+{
+ switch (werr) {
+ case TSWasmErrorKindParse:
+ return "PARSE";
+ case TSWasmErrorKindCompile:
+ return "COMPILE";
+ case TSWasmErrorKindInstantiate:
+ return "INSTANTIATE";
+ case TSWasmErrorKindAllocate:
+ return "ALLOCATE";
+ default:
+ return "UNKNOWN";
+ }
+}
+#endif
+
+int tslua_add_language_from_wasm(lua_State *L)
+{
+ return add_language(L, true);
+}
+
+// Creates the language into the internal language map.
+//
+// Returns true if the language is correctly loaded in the language map
+int tslua_add_language_from_object(lua_State *L)
+{
+ return add_language(L, false);
+}
+
+static const TSLanguage *load_language_from_object(lua_State *L, const char *path,
+ const char *lang_name, const char *symbol)
{
uv_lib_t lib;
if (uv_dlopen(path, &lib)) {
@@ -91,16 +151,59 @@ static TSLanguage *load_language(lua_State *L, const char *path, const char *lan
return lang;
}
-// Creates the language into the internal language map.
-//
-// Returns true if the language is correctly loaded in the language map
-int tslua_add_language(lua_State *L)
+static const TSLanguage *load_language_from_wasm(lua_State *L, const char *path,
+ const char *lang_name)
+{
+#ifndef HAVE_WASMTIME
+ luaL_error(L, "Not supported");
+ return NULL;
+#else
+ if (wasmengine == NULL) {
+ wasmengine = wasm_engine_new();
+ }
+ assert(wasmengine != NULL);
+
+ TSWasmError werr = { 0 };
+ if (ts_wasmstore == NULL) {
+ ts_wasmstore = ts_wasm_store_new(wasmengine, &werr);
+ }
+
+ if (werr.kind > 0) {
+ luaL_error(L, "Error creating wasm store: (%s) %s", wasmerr_to_str(werr.kind), werr.message);
+ }
+
+ size_t file_size = 0;
+ char *data = read_file(path, &file_size);
+
+ if (data == NULL) {
+ luaL_error(L, "Unable to read file", path);
+ }
+
+ const TSLanguage *lang = ts_wasm_store_load_language(ts_wasmstore, lang_name, data,
+ (uint32_t)file_size, &werr);
+
+ xfree(data);
+
+ if (werr.kind > 0) {
+ luaL_error(L, "Failed to load WASM parser %s: (%s) %s", path, wasmerr_to_str(werr.kind),
+ werr.message);
+ }
+
+ if (lang == NULL) {
+ luaL_error(L, "Failed to load parser %s: internal error", path);
+ }
+
+ return lang;
+#endif
+}
+
+static int add_language(lua_State *L, bool is_wasm)
{
const char *path = luaL_checkstring(L, 1);
const char *lang_name = luaL_checkstring(L, 2);
const char *symbol_name = lang_name;
- if (lua_gettop(L) >= 3 && !lua_isnil(L, 3)) {
+ if (!is_wasm && lua_gettop(L) >= 3 && !lua_isnil(L, 3)) {
symbol_name = luaL_checkstring(L, 3);
}
@@ -109,7 +212,9 @@ int tslua_add_language(lua_State *L)
return 1;
}
- TSLanguage *lang = load_language(L, path, lang_name, symbol_name);
+ const TSLanguage *lang = is_wasm
+ ? load_language_from_wasm(L, path, lang_name)
+ : load_language_from_object(L, path, lang_name, symbol_name);
uint32_t lang_version = ts_language_version(lang);
if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION
@@ -121,7 +226,7 @@ int tslua_add_language(lua_State *L)
TREE_SITTER_LANGUAGE_VERSION, lang_version);
}
- pmap_put(cstr_t)(&langs, xstrdup(lang_name), lang);
+ pmap_put(cstr_t)(&langs, xstrdup(lang_name), (TSLanguage *)lang);
lua_pushboolean(L, true);
return 1;
@@ -186,6 +291,9 @@ int tslua_inspect_lang(lua_State *L)
lua_setfield(L, -2, "fields"); // [retval]
+ lua_pushboolean(L, ts_language_is_wasm(lang));
+ lua_setfield(L, -2, "_wasm");
+
lua_pushinteger(L, ts_language_version(lang)); // [retval, version]
lua_setfield(L, -2, "_abi_version");
@@ -215,6 +323,13 @@ int tslua_push_parser(lua_State *L)
TSParser **parser = lua_newuserdata(L, sizeof(TSParser *));
*parser = ts_parser_new();
+#ifdef HAVE_WASMTIME
+ if (ts_language_is_wasm(lang)) {
+ assert(wasmengine != NULL);
+ ts_parser_set_wasm_store(*parser, ts_wasmstore);
+ }
+#endif
+
if (!ts_parser_set_language(*parser, lang)) {
ts_parser_delete(*parser);
const char *lang_name = luaL_checkstring(L, 1);
@@ -279,7 +394,7 @@ static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position
memcpy(buf, line + position.column, tocopy);
// Translate embedded \n to NUL
- memchrsub(buf, '\n', '\0', tocopy);
+ memchrsub(buf, '\n', NUL, tocopy);
*bytes_read = (uint32_t)tocopy;
if (tocopy < BUFSIZE) {
// now add the final \n. If it didn't fit, input_cb will be called again
@@ -686,20 +801,6 @@ static int tree_root(lua_State *L)
return 1;
}
-// TSTreeCursor
-
-static struct luaL_Reg treecursor_meta[] = {
- { "__gc", treecursor_gc },
- { NULL, NULL }
-};
-
-static int treecursor_gc(lua_State *L)
-{
- TSTreeCursor *cursor = luaL_checkudata(L, 1, TS_META_TREECURSOR);
- ts_tree_cursor_delete(cursor);
- return 0;
-}
-
// TSNode
static struct luaL_Reg node_meta[] = {
{ "__tostring", node_tostring },
@@ -890,23 +991,14 @@ static int node_field(lua_State *L)
size_t name_len;
const char *field_name = luaL_checklstring(L, 2, &name_len);
- TSTreeCursor cursor = ts_tree_cursor_new(node);
-
lua_newtable(L); // [table]
- size_t curr_index = 0;
-
- if (ts_tree_cursor_goto_first_child(&cursor)) {
- do {
- const char *current_field = ts_tree_cursor_current_field_name(&cursor);
- if (current_field != NULL && !strcmp(field_name, current_field)) {
- push_node(L, ts_tree_cursor_current_node(&cursor), 1); // [table, node]
- lua_rawseti(L, -2, (int)++curr_index);
- }
- } while (ts_tree_cursor_goto_next_sibling(&cursor));
+ TSNode field = ts_node_child_by_field_name(node, field_name, (uint32_t)name_len);
+ if (!ts_node_is_null(field)) {
+ push_node(L, field, 1); // [table, node]
+ lua_rawseti(L, -2, 1);
}
- ts_tree_cursor_delete(&cursor);
return 1;
}
@@ -1002,45 +1094,35 @@ static int node_named_descendant_for_range(lua_State *L)
static int node_next_child(lua_State *L)
{
- TSTreeCursor *cursor = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR);
+ uint32_t *child_index = lua_touserdata(L, lua_upvalueindex(1));
TSNode source = node_check(L, lua_upvalueindex(2));
- // First call should return first child
- if (ts_node_eq(source, ts_tree_cursor_current_node(cursor))) {
- if (ts_tree_cursor_goto_first_child(cursor)) {
- goto push;
- } else {
- return 0;
- }
- }
-
- if (!ts_tree_cursor_goto_next_sibling(cursor)) {
+ if (*child_index >= ts_node_child_count(source)) {
return 0;
}
-push:
- push_node(L, ts_tree_cursor_current_node(cursor), lua_upvalueindex(2)); // [node]
-
- const char *field = ts_tree_cursor_current_field_name(cursor);
+ TSNode child = ts_node_child(source, *child_index);
+ push_node(L, child, lua_upvalueindex(2));
+ const char *field = ts_node_field_name_for_child(source, *child_index);
if (field != NULL) {
- lua_pushstring(L, ts_tree_cursor_current_field_name(cursor));
+ lua_pushstring(L, field);
} else {
lua_pushnil(L);
} // [node, field_name_or_nil]
+
+ (*child_index)++;
+
return 2;
}
static int node_iter_children(lua_State *L)
{
- TSNode node = node_check(L, 1);
-
- TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata]
- *ud = ts_tree_cursor_new(node);
+ node_check(L, 1);
+ uint32_t *child_index = lua_newuserdata(L, sizeof(uint32_t)); // [source_node,..., udata]
+ *child_index = 0;
- lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR); // [udata, mt]
- lua_setmetatable(L, -2); // [udata]
- lua_pushvalue(L, 1); // [udata, source_node]
+ lua_pushvalue(L, 1); // [source_node, ..., udata, source_node]
lua_pushcclosure(L, node_next_child, 2);
return 1;
@@ -1132,22 +1214,19 @@ static int node_prev_named_sibling(lua_State *L)
static int node_named_children(lua_State *L)
{
TSNode source = node_check(L, 1);
- TSTreeCursor cursor = ts_tree_cursor_new(source);
lua_newtable(L);
int curr_index = 0;
- if (ts_tree_cursor_goto_first_child(&cursor)) {
- do {
- TSNode node = ts_tree_cursor_current_node(&cursor);
- if (ts_node_is_named(node)) {
- push_node(L, node, 1);
- lua_rawseti(L, -2, ++curr_index);
- }
- } while (ts_tree_cursor_goto_next_sibling(&cursor));
+ uint32_t n = ts_node_child_count(source);
+ for (uint32_t i = 0; i < n; i++) {
+ TSNode child = ts_node_child(source, i);
+ if (ts_node_is_named(child)) {
+ push_node(L, child, 1);
+ lua_rawseti(L, -2, ++curr_index);
+ }
}
- ts_tree_cursor_delete(&cursor);
return 1;
}
@@ -1557,7 +1636,18 @@ void tslua_init(lua_State *L)
build_meta(L, TS_META_QUERY, query_meta);
build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
build_meta(L, TS_META_QUERYMATCH, querymatch_meta);
- build_meta(L, TS_META_TREECURSOR, treecursor_meta);
ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
}
+
+void tslua_free(void)
+{
+#ifdef HAVE_WASMTIME
+ if (wasmengine != NULL) {
+ wasm_engine_delete(wasmengine);
+ }
+ if (ts_wasmstore != NULL) {
+ ts_wasm_store_delete(ts_wasmstore);
+ }
+#endif
+}