aboutsummaryrefslogtreecommitdiff
path: root/src/nvim/lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/nvim/lua')
-rw-r--r--src/nvim/lua/api_wrappers.c18
-rw-r--r--src/nvim/lua/base64.c5
-rw-r--r--src/nvim/lua/converter.c52
-rw-r--r--src/nvim/lua/converter.h6
-rw-r--r--src/nvim/lua/executor.c43
-rw-r--r--src/nvim/lua/executor.h4
-rw-r--r--src/nvim/lua/stdlib.c33
-rw-r--r--src/nvim/lua/treesitter.c1109
-rw-r--r--src/nvim/lua/treesitter.h5
9 files changed, 579 insertions, 696 deletions
diff --git a/src/nvim/lua/api_wrappers.c b/src/nvim/lua/api_wrappers.c
new file mode 100644
index 0000000000..2b7b0c6471
--- /dev/null
+++ b/src/nvim/lua/api_wrappers.c
@@ -0,0 +1,18 @@
+#include <lauxlib.h>
+#include <lua.h>
+#include <lualib.h>
+
+#include "nvim/api/private/defs.h"
+#include "nvim/api/private/dispatch.h"
+#include "nvim/api/private/helpers.h"
+#include "nvim/ex_docmd.h"
+#include "nvim/ex_getln.h"
+#include "nvim/func_attr.h"
+#include "nvim/globals.h"
+#include "nvim/lua/converter.h"
+#include "nvim/lua/executor.h"
+#include "nvim/memory.h"
+
+#ifdef INCLUDE_GENERATED_DECLARATIONS
+# include "lua_api_c_bindings.generated.h"
+#endif
diff --git a/src/nvim/lua/base64.c b/src/nvim/lua/base64.c
index c1f43a37d7..8fe918493a 100644
--- a/src/nvim/lua/base64.c
+++ b/src/nvim/lua/base64.c
@@ -45,12 +45,13 @@ static int nlua_base64_decode(lua_State *L)
size_t src_len = 0;
const char *src = lua_tolstring(L, 1, &src_len);
- const char *ret = base64_decode(src, src_len);
+ size_t out_len = 0;
+ const char *ret = base64_decode(src, src_len, &out_len);
if (ret == NULL) {
return luaL_error(L, "Invalid input");
}
- lua_pushstring(L, ret);
+ lua_pushlstring(L, ret, out_len);
xfree((void *)ret);
return 1;
diff --git a/src/nvim/lua/converter.c b/src/nvim/lua/converter.c
index bba771f8a5..38ccb03cfc 100644
--- a/src/nvim/lua/converter.c
+++ b/src/nvim/lua/converter.c
@@ -597,9 +597,9 @@ static bool typval_conv_special = false;
/// @param[in] tv typval_T to convert.
///
/// @return true in case of success, false otherwise.
-bool nlua_push_typval(lua_State *lstate, typval_T *const tv, bool special)
+bool nlua_push_typval(lua_State *lstate, typval_T *const tv, int flags)
{
- typval_conv_special = special;
+ typval_conv_special = (flags & kNluaPushSpecial);
const int initial_size = lua_gettop(lstate);
if (!lua_checkstack(lstate, initial_size + 2)) {
@@ -662,7 +662,7 @@ static inline void nlua_create_typed_table(lua_State *lstate, const size_t narr,
/// Convert given String to lua string
///
/// Leaves converted string on top of the stack.
-void nlua_push_String(lua_State *lstate, const String s, bool special)
+void nlua_push_String(lua_State *lstate, const String s, int flags)
FUNC_ATTR_NONNULL_ALL
{
lua_pushlstring(lstate, s.data, s.size);
@@ -671,7 +671,7 @@ void nlua_push_String(lua_State *lstate, const String s, bool special)
/// Convert given Integer to lua number
///
/// Leaves converted number on top of the stack.
-void nlua_push_Integer(lua_State *lstate, const Integer n, bool special)
+void nlua_push_Integer(lua_State *lstate, const Integer n, int flags)
FUNC_ATTR_NONNULL_ALL
{
lua_pushnumber(lstate, (lua_Number)n);
@@ -680,10 +680,10 @@ void nlua_push_Integer(lua_State *lstate, const Integer n, bool special)
/// Convert given Float to lua table
///
/// Leaves converted table on top of the stack.
-void nlua_push_Float(lua_State *lstate, const Float f, bool special)
+void nlua_push_Float(lua_State *lstate, const Float f, int flags)
FUNC_ATTR_NONNULL_ALL
{
- if (special) {
+ if (flags & kNluaPushSpecial) {
nlua_create_typed_table(lstate, 0, 1, kObjectTypeFloat);
nlua_push_val_idx(lstate);
lua_pushnumber(lstate, (lua_Number)f);
@@ -696,7 +696,7 @@ void nlua_push_Float(lua_State *lstate, const Float f, bool special)
/// Convert given Float to lua boolean
///
/// Leaves converted value on top of the stack.
-void nlua_push_Boolean(lua_State *lstate, const Boolean b, bool special)
+void nlua_push_Boolean(lua_State *lstate, const Boolean b, int flags)
FUNC_ATTR_NONNULL_ALL
{
lua_pushboolean(lstate, b);
@@ -705,21 +705,21 @@ void nlua_push_Boolean(lua_State *lstate, const Boolean b, bool special)
/// Convert given Dictionary to lua table
///
/// Leaves converted table on top of the stack.
-void nlua_push_Dictionary(lua_State *lstate, const Dictionary dict, bool special)
+void nlua_push_Dictionary(lua_State *lstate, const Dictionary dict, int flags)
FUNC_ATTR_NONNULL_ALL
{
- if (dict.size == 0 && special) {
+ if (dict.size == 0 && (flags & kNluaPushSpecial)) {
nlua_create_typed_table(lstate, 0, 0, kObjectTypeDictionary);
} else {
lua_createtable(lstate, 0, (int)dict.size);
- if (dict.size == 0 && !special) {
+ if (dict.size == 0 && !(flags & kNluaPushSpecial)) {
nlua_pushref(lstate, nlua_global_refs->empty_dict_ref);
lua_setmetatable(lstate, -2);
}
}
for (size_t i = 0; i < dict.size; i++) {
- nlua_push_String(lstate, dict.items[i].key, special);
- nlua_push_Object(lstate, &dict.items[i].value, special);
+ nlua_push_String(lstate, dict.items[i].key, flags);
+ nlua_push_Object(lstate, &dict.items[i].value, flags);
lua_rawset(lstate, -3);
}
}
@@ -727,18 +727,18 @@ void nlua_push_Dictionary(lua_State *lstate, const Dictionary dict, bool special
/// Convert given Array to lua table
///
/// Leaves converted table on top of the stack.
-void nlua_push_Array(lua_State *lstate, const Array array, bool special)
+void nlua_push_Array(lua_State *lstate, const Array array, int flags)
FUNC_ATTR_NONNULL_ALL
{
lua_createtable(lstate, (int)array.size, 0);
for (size_t i = 0; i < array.size; i++) {
- nlua_push_Object(lstate, &array.items[i], special);
+ nlua_push_Object(lstate, &array.items[i], flags);
lua_rawseti(lstate, -2, (int)i + 1);
}
}
#define GENERATE_INDEX_FUNCTION(type) \
- void nlua_push_##type(lua_State *lstate, const type item, bool special) \
+ void nlua_push_##type(lua_State *lstate, const type item, int flags) \
FUNC_ATTR_NONNULL_ALL \
{ \
lua_pushnumber(lstate, (lua_Number)(item)); \
@@ -753,12 +753,12 @@ GENERATE_INDEX_FUNCTION(Tabpage)
/// Convert given Object to lua value
///
/// Leaves converted value on top of the stack.
-void nlua_push_Object(lua_State *lstate, Object *obj, bool special)
+void nlua_push_Object(lua_State *lstate, Object *obj, int flags)
FUNC_ATTR_NONNULL_ALL
{
switch (obj->type) {
case kObjectTypeNil:
- if (special) {
+ if (flags & kNluaPushSpecial) {
lua_pushnil(lstate);
} else {
nlua_pushref(lstate, nlua_global_refs->nil_ref);
@@ -766,13 +766,15 @@ void nlua_push_Object(lua_State *lstate, Object *obj, bool special)
break;
case kObjectTypeLuaRef: {
nlua_pushref(lstate, obj->data.luaref);
- api_free_luaref(obj->data.luaref);
- obj->data.luaref = LUA_NOREF;
+ if (flags & kNluaPushFreeRefs) {
+ api_free_luaref(obj->data.luaref);
+ obj->data.luaref = LUA_NOREF;
+ }
break;
}
#define ADD_TYPE(type, data_key) \
case kObjectType##type: { \
- nlua_push_##type(lstate, obj->data.data_key, special); \
+ nlua_push_##type(lstate, obj->data.data_key, flags); \
break; \
}
ADD_TYPE(Boolean, boolean)
@@ -784,7 +786,7 @@ void nlua_push_Object(lua_State *lstate, Object *obj, bool special)
#undef ADD_TYPE
#define ADD_REMOTE_TYPE(type) \
case kObjectType##type: { \
- nlua_push_##type(lstate, (type)obj->data.integer, special); \
+ nlua_push_##type(lstate, (type)obj->data.integer, flags); \
break; \
}
ADD_REMOTE_TYPE(Buffer)
@@ -1380,7 +1382,7 @@ void nlua_push_keydict(lua_State *L, void *value, KeySetLink *table)
lua_pushstring(L, field->str);
if (field->type == kObjectTypeNil) {
- nlua_push_Object(L, (Object *)mem, false);
+ nlua_push_Object(L, (Object *)mem, 0);
} else if (field->type == kObjectTypeInteger) {
lua_pushinteger(L, *(Integer *)mem);
} else if (field->type == kObjectTypeBuffer || field->type == kObjectTypeWindow
@@ -1391,11 +1393,11 @@ void nlua_push_keydict(lua_State *L, void *value, KeySetLink *table)
} else if (field->type == kObjectTypeBoolean) {
lua_pushboolean(L, *(Boolean *)mem);
} else if (field->type == kObjectTypeString) {
- nlua_push_String(L, *(String *)mem, false);
+ nlua_push_String(L, *(String *)mem, 0);
} else if (field->type == kObjectTypeArray) {
- nlua_push_Array(L, *(Array *)mem, false);
+ nlua_push_Array(L, *(Array *)mem, 0);
} else if (field->type == kObjectTypeDictionary) {
- nlua_push_Dictionary(L, *(Dictionary *)mem, false);
+ nlua_push_Dictionary(L, *(Dictionary *)mem, 0);
} else if (field->type == kObjectTypeLuaRef) {
nlua_pushref(L, *(LuaRef *)mem);
} else {
diff --git a/src/nvim/lua/converter.h b/src/nvim/lua/converter.h
index a502df80d9..d1ba61bcee 100644
--- a/src/nvim/lua/converter.h
+++ b/src/nvim/lua/converter.h
@@ -9,6 +9,12 @@
#define nlua_pop_Window nlua_pop_handle
#define nlua_pop_Tabpage nlua_pop_handle
+/// Flags for nlua_push_*() functions.
+enum {
+ kNluaPushSpecial = 0x01, ///< Use lua-special-tbl when necessary
+ kNluaPushFreeRefs = 0x02, ///< Free luarefs to elide an api_luarefs_free_*() later
+};
+
#ifdef INCLUDE_GENERATED_DECLARATIONS
# include "lua/converter.h.generated.h"
#endif
diff --git a/src/nvim/lua/executor.c b/src/nvim/lua/executor.c
index 1a9bd026b5..a76b8213e5 100644
--- a/src/nvim/lua/executor.c
+++ b/src/nvim/lua/executor.c
@@ -103,7 +103,7 @@ typedef struct {
if (args[i].v_type == VAR_UNKNOWN) { \
lua_pushnil(lstate); \
} else { \
- nlua_push_typval(lstate, &args[i], special); \
+ nlua_push_typval(lstate, &args[i], (special) ? kNluaPushSpecial : 0); \
} \
}
@@ -325,7 +325,7 @@ static int nlua_thr_api_nvim__get_runtime(lua_State *lstate)
}
ArrayOf(String) ret = runtime_get_named_thread(is_lua, pat, all);
- nlua_push_Array(lstate, ret, true);
+ nlua_push_Array(lstate, ret, kNluaPushSpecial);
api_free_array(ret);
api_free_array(pat);
@@ -1210,7 +1210,7 @@ int nlua_call(lua_State *lstate)
});
if (!ERROR_SET(&err)) {
- nlua_push_typval(lstate, &rettv, false);
+ nlua_push_typval(lstate, &rettv, 0);
}
tv_clear(&rettv);
@@ -1261,7 +1261,7 @@ static int nlua_rpc(lua_State *lstate, bool request)
ArenaMem res_mem = NULL;
Object result = rpc_send_call(chan_id, name, args, &res_mem, &err);
if (!ERROR_SET(&err)) {
- nlua_push_Object(lstate, &result, false);
+ nlua_push_Object(lstate, &result, 0);
arena_mem_free(res_mem);
}
} else {
@@ -1487,7 +1487,7 @@ static void nlua_typval_exec(const char *lcmd, size_t lcmd_len, const char *name
}
}
-int nlua_source_using_linegetter(LineGetter fgetline, void *cookie, char *name)
+void nlua_source_str(const char *code, char *name)
{
const sctx_T save_current_sctx = current_sctx;
current_sctx.sc_sid = SID_STR;
@@ -1495,22 +1495,11 @@ int nlua_source_using_linegetter(LineGetter fgetline, void *cookie, char *name)
current_sctx.sc_lnum = 0;
estack_push(ETYPE_SCRIPT, name, 0);
- garray_T ga;
- char *line = NULL;
-
- ga_init(&ga, (int)sizeof(char *), 10);
- while ((line = fgetline(0, cookie, 0, false)) != NULL) {
- GA_APPEND(char *, &ga, line);
- }
- char *code = ga_concat_strings_sep(&ga, "\n");
size_t len = strlen(code);
nlua_typval_exec(code, len, name, NULL, 0, false, NULL);
estack_pop();
current_sctx = save_current_sctx;
- ga_clear_strings(&ga);
- xfree(code);
- return OK;
}
/// Call a LuaCallable given some typvals
@@ -1564,7 +1553,7 @@ Object nlua_exec(const String str, const Array args, LuaRetMode mode, Arena *are
}
for (size_t i = 0; i < args.size; i++) {
- nlua_push_Object(lstate, &args.items[i], false);
+ nlua_push_Object(lstate, &args.items[i], 0);
}
if (nlua_pcall(lstate, (int)args.size, 1)) {
@@ -1611,7 +1600,7 @@ Object nlua_call_ref(LuaRef ref, const char *name, Array args, LuaRetMode mode,
nargs++;
}
for (size_t i = 0; i < args.size; i++) {
- nlua_push_Object(lstate, &args.items[i], false);
+ nlua_push_Object(lstate, &args.items[i], 0);
}
if (nlua_pcall(lstate, nargs, 1)) {
@@ -1767,7 +1756,7 @@ void ex_luado(exarg_T *const eap)
lua_pushvalue(lstate, -1);
const char *const old_line = ml_get_buf(curbuf, l);
// Get length of old_line here as calling Lua code may free it.
- const size_t old_line_len = strlen(old_line);
+ const colnr_T old_line_len = ml_get_buf_len(curbuf, l);
lua_pushstring(lstate, old_line);
lua_pushnumber(lstate, (lua_Number)l);
if (nlua_pcall(lstate, 2, 1)) {
@@ -1791,13 +1780,13 @@ void ex_luado(exarg_T *const eap)
}
}
ml_replace(l, new_line_transformed, false);
- inserted_bytes(l, 0, (int)old_line_len, (int)new_line_len);
+ inserted_bytes(l, 0, old_line_len, (int)new_line_len);
}
lua_pop(lstate, 1);
}
lua_pop(lstate, 1);
- check_cursor();
+ check_cursor(curwin);
redraw_curbuf_later(UPD_NOT_VALID);
}
@@ -1909,6 +1898,9 @@ static void nlua_add_treesitter(lua_State *const lstate) FUNC_ATTR_NONNULL_ALL
lua_pushcfunction(lstate, tslua_push_parser);
lua_setfield(lstate, -2, "_create_ts_parser");
+ lua_pushcfunction(lstate, tslua_push_querycursor);
+ lua_setfield(lstate, -2, "_create_ts_querycursor");
+
lua_pushcfunction(lstate, tslua_add_language);
lua_setfield(lstate, -2, "_ts_add_language");
@@ -2061,9 +2053,9 @@ char *nlua_register_table_as_callable(const typval_T *const arg)
return name;
}
-void nlua_execute_on_key(int c)
+void nlua_execute_on_key(int c, char *typed_buf, size_t typed_len)
{
- char buf[NUMBUFLEN];
+ char buf[MB_MAXBYTES * 3 + 4];
size_t buf_len = special_to_buf(c, mod_mask, false, buf);
lua_State *const lstate = global_lstate;
@@ -2082,9 +2074,12 @@ void nlua_execute_on_key(int c)
// [ vim, vim._on_key, buf ]
lua_pushlstring(lstate, buf, buf_len);
+ // [ vim, vim._on_key, buf, typed_buf ]
+ lua_pushlstring(lstate, typed_buf, typed_len);
+
int save_got_int = got_int;
got_int = false; // avoid interrupts when the key typed is Ctrl-C
- if (nlua_pcall(lstate, 1, 0)) {
+ if (nlua_pcall(lstate, 2, 0)) {
nlua_error(lstate,
_("Error executing vim.on_key Lua callback: %.*s"));
}
diff --git a/src/nvim/lua/executor.h b/src/nvim/lua/executor.h
index ebcd62122f..32fde3853b 100644
--- a/src/nvim/lua/executor.h
+++ b/src/nvim/lua/executor.h
@@ -12,7 +12,7 @@
#include "nvim/types_defs.h"
#include "nvim/usercmd.h" // IWYU pragma: keep
-// Generated by msgpack-gen.lua
+// Generated by generators/gen_api_dispatch.lua
void nlua_add_api_functions(lua_State *lstate) REAL_FATTR_NONNULL_ALL;
typedef struct {
@@ -43,7 +43,7 @@ typedef enum {
kRetLuaref, ///< return value becomes a single Luaref, regardless of type (except NIL)
} LuaRetMode;
-/// To use with kRetNilBool for quick thuthyness check
+/// To use with kRetNilBool for quick truthiness check
#define LUARET_TRUTHY(res) ((res).type == kObjectTypeBoolean && (res).data.boolean == true)
#ifdef INCLUDE_GENERATED_DECLARATIONS
diff --git a/src/nvim/lua/stdlib.c b/src/nvim/lua/stdlib.c
index 8f58fd1a1a..22ee0a1c98 100644
--- a/src/nvim/lua/stdlib.c
+++ b/src/nvim/lua/stdlib.c
@@ -107,15 +107,15 @@ static int regex_match_line(lua_State *lstate)
}
char *line = ml_get_buf(buf, rownr + 1);
- size_t len = strlen(line);
+ colnr_T len = ml_get_buf_len(buf, rownr + 1);
- if (start < 0 || (size_t)start > len) {
+ if (start < 0 || start > len) {
return luaL_error(lstate, "invalid start");
}
char save = NUL;
if (end >= 0) {
- if ((size_t)end > len || end < start) {
+ if (end > len || end < start) {
return luaL_error(lstate, "invalid end");
}
save = line[end];
@@ -449,7 +449,7 @@ int nlua_getvar(lua_State *lstate)
if (di == NULL) {
return 0; // nil
}
- nlua_push_typval(lstate, &di->di_tv, false);
+ nlua_push_typval(lstate, &di->di_tv, 0);
return 1;
}
@@ -543,14 +543,27 @@ static int nlua_iconv(lua_State *lstate)
return 1;
}
-// Update foldlevels (e.g., by evaluating 'foldexpr') for all lines in the current window without
-// invoking other side effects. Unlike `zx`, it does not close manually opened folds and does not
-// open folds under the cursor.
+// Update foldlevels (e.g., by evaluating 'foldexpr') for the given line range in the given window,
+// without invoking other side effects. Unlike `zx`, it does not close manually opened folds and
+// does not open folds under the cursor.
static int nlua_foldupdate(lua_State *lstate)
{
- curwin->w_foldinvalid = true; // recompute folds
- foldUpdate(curwin, 1, (linenr_T)MAXLNUM);
- curwin->w_foldinvalid = false;
+ handle_T window = (handle_T)luaL_checkinteger(lstate, 1);
+ win_T *win = handle_get_window(window);
+ if (!win) {
+ return luaL_error(lstate, "invalid window");
+ }
+ // input is zero-based end-exclusive range
+ linenr_T top = (linenr_T)luaL_checkinteger(lstate, 2) + 1;
+ if (top < 1 || top > win->w_buffer->b_ml.ml_line_count) {
+ return luaL_error(lstate, "invalid top");
+ }
+ linenr_T bot = (linenr_T)luaL_checkinteger(lstate, 3);
+ if (top > bot) {
+ return luaL_error(lstate, "invalid bot");
+ }
+
+ foldUpdate(win, top, bot);
return 0;
}
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c
index 25a753b179..e87cf756a8 100644
--- a/src/nvim/lua/treesitter.c
+++ b/src/nvim/lua/treesitter.c
@@ -33,15 +33,10 @@
#define TS_META_NODE "treesitter_node"
#define TS_META_QUERY "treesitter_query"
#define TS_META_QUERYCURSOR "treesitter_querycursor"
+#define TS_META_QUERYMATCH "treesitter_querymatch"
#define TS_META_TREECURSOR "treesitter_treecursor"
typedef struct {
- TSQueryCursor *cursor;
- int predicated_match;
- int max_match_id;
-} TSLua_cursor;
-
-typedef struct {
LuaRef cb;
lua_State *lstate;
bool lex;
@@ -56,126 +51,44 @@ typedef struct {
# include "lua/treesitter.c.generated.h"
#endif
-// TSParser
-static struct luaL_Reg parser_meta[] = {
- { "__gc", parser_gc },
- { "__tostring", parser_tostring },
- { "parse", parser_parse },
- { "reset", parser_reset },
- { "set_included_ranges", parser_set_ranges },
- { "included_ranges", parser_get_ranges },
- { "set_timeout", parser_set_timeout },
- { "timeout", parser_get_timeout },
- { "_set_logger", parser_set_logger },
- { "_logger", parser_get_logger },
- { NULL, NULL }
-};
-
-// TSTree
-static struct luaL_Reg tree_meta[] = {
- { "__gc", tree_gc },
- { "__tostring", tree_tostring },
- { "root", tree_root },
- { "edit", tree_edit },
- { "included_ranges", tree_get_ranges },
- { "copy", tree_copy },
- { NULL, NULL }
-};
-
-// TSNode
-static struct luaL_Reg node_meta[] = {
- { "__tostring", node_tostring },
- { "__eq", node_eq },
- { "__len", node_child_count },
- { "id", node_id },
- { "range", node_range },
- { "start", node_start },
- { "end_", node_end },
- { "type", node_type },
- { "symbol", node_symbol },
- { "field", node_field },
- { "named", node_named },
- { "missing", node_missing },
- { "extra", node_extra },
- { "has_changes", node_has_changes },
- { "has_error", node_has_error },
- { "sexpr", node_sexpr },
- { "child_count", node_child_count },
- { "named_child_count", node_named_child_count },
- { "child", node_child },
- { "named_child", node_named_child },
- { "descendant_for_range", node_descendant_for_range },
- { "named_descendant_for_range", node_named_descendant_for_range },
- { "parent", node_parent },
- { "iter_children", node_iter_children },
- { "_rawquery", node_rawquery },
- { "next_sibling", node_next_sibling },
- { "prev_sibling", node_prev_sibling },
- { "next_named_sibling", node_next_named_sibling },
- { "prev_named_sibling", node_prev_named_sibling },
- { "named_children", node_named_children },
- { "root", node_root },
- { "tree", node_tree },
- { "byte_length", node_byte_length },
- { "equal", node_equal },
-
- { NULL, NULL }
-};
-
-// TSQuery
-static struct luaL_Reg query_meta[] = {
- { "__gc", query_gc },
- { "__tostring", query_tostring },
- { "inspect", query_inspect },
- { NULL, NULL }
-};
+static PMap(cstr_t) langs = MAP_INIT;
-// cursors are not exposed, but still needs garbage collection
-static struct luaL_Reg querycursor_meta[] = {
- { "__gc", querycursor_gc },
- { NULL, NULL }
-};
+// TSLanguage
-static struct luaL_Reg treecursor_meta[] = {
- { "__gc", treecursor_gc },
- { NULL, NULL }
-};
-
-static kvec_t(TSQueryCursor *) cursors = KV_INITIAL_VALUE;
-static PMap(cstr_t) langs = MAP_INIT;
+int tslua_has_language(lua_State *L)
+{
+ const char *lang_name = luaL_checkstring(L, 1);
+ lua_pushboolean(L, map_has(cstr_t, &langs, lang_name));
+ return 1;
+}
-static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
+static TSLanguage *load_language(lua_State *L, const char *path, const char *lang_name,
+ const char *symbol)
{
- if (luaL_newmetatable(L, tname)) { // [meta]
- luaL_register(L, NULL, meta);
+ uv_lib_t lib;
+ if (uv_dlopen(path, &lib)) {
+ uv_dlclose(&lib);
+ luaL_error(L, "Failed to load parser for language '%s': uv_dlopen: %s",
+ lang_name, uv_dlerror(&lib));
+ }
- lua_pushvalue(L, -1); // [meta, meta]
- lua_setfield(L, -2, "__index"); // [meta]
+ char symbol_buf[128];
+ snprintf(symbol_buf, sizeof(symbol_buf), "tree_sitter_%s", symbol);
+
+ TSLanguage *(*lang_parser)(void);
+ if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) {
+ uv_dlclose(&lib);
+ luaL_error(L, "Failed to load parser: uv_dlsym: %s", uv_dlerror(&lib));
}
- lua_pop(L, 1); // [] (don't use it now)
-}
-/// Init the tslua library.
-///
-/// All global state is stored in the registry of the lua_State.
-void tslua_init(lua_State *L)
-{
- // type metatables
- build_meta(L, TS_META_PARSER, parser_meta);
- build_meta(L, TS_META_TREE, tree_meta);
- build_meta(L, TS_META_NODE, node_meta);
- build_meta(L, TS_META_QUERY, query_meta);
- build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
- build_meta(L, TS_META_TREECURSOR, treecursor_meta);
+ TSLanguage *lang = lang_parser();
- ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
-}
+ if (lang == NULL) {
+ uv_dlclose(&lib);
+ luaL_error(L, "Failed to load parser %s: internal error", path);
+ }
-int tslua_has_language(lua_State *L)
-{
- const char *lang_name = luaL_checkstring(L, 1);
- lua_pushboolean(L, map_has(cstr_t, &langs, lang_name));
- return 1;
+ return lang;
}
// Creates the language into the internal language map.
@@ -196,34 +109,7 @@ int tslua_add_language(lua_State *L)
return 1;
}
-#define BUFSIZE 128
- char symbol_buf[BUFSIZE];
- snprintf(symbol_buf, BUFSIZE, "tree_sitter_%s", symbol_name);
-#undef BUFSIZE
-
- uv_lib_t lib;
- if (uv_dlopen(path, &lib)) {
- snprintf(IObuff, IOSIZE, "Failed to load parser for language '%s': uv_dlopen: %s",
- lang_name, uv_dlerror(&lib));
- uv_dlclose(&lib);
- lua_pushstring(L, IObuff);
- return lua_error(L);
- }
-
- TSLanguage *(*lang_parser)(void);
- if (uv_dlsym(&lib, symbol_buf, (void **)&lang_parser)) {
- snprintf(IObuff, IOSIZE, "Failed to load parser: uv_dlsym: %s",
- uv_dlerror(&lib));
- uv_dlclose(&lib);
- lua_pushstring(L, IObuff);
- return lua_error(L);
- }
-
- TSLanguage *lang = lang_parser();
- if (lang == NULL) {
- uv_dlclose(&lib);
- return luaL_error(L, "Failed to load parser %s: internal error", path);
- }
+ TSLanguage *lang = load_language(L, path, lang_name, symbol_name);
uint32_t lang_version = ts_language_version(lang);
if (lang_version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION
@@ -254,14 +140,19 @@ int tslua_remove_lang(lua_State *L)
return 1;
}
-int tslua_inspect_lang(lua_State *L)
+static TSLanguage *lang_check(lua_State *L, int index)
{
- const char *lang_name = luaL_checkstring(L, 1);
-
+ const char *lang_name = luaL_checkstring(L, index);
TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name);
if (!lang) {
- return luaL_error(L, "no such language: %s", lang_name);
+ luaL_error(L, "no such language: %s", lang_name);
}
+ return lang;
+}
+
+int tslua_inspect_lang(lua_State *L)
+{
+ TSLanguage *lang = lang_check(L, 1);
lua_createtable(L, 0, 2); // [retval]
@@ -295,28 +186,38 @@ int tslua_inspect_lang(lua_State *L)
lua_setfield(L, -2, "fields"); // [retval]
- uint32_t lang_version = ts_language_version(lang);
- lua_pushinteger(L, lang_version); // [retval, version]
+ lua_pushinteger(L, ts_language_version(lang)); // [retval, version]
lua_setfield(L, -2, "_abi_version");
return 1;
}
+// TSParser
+
+static struct luaL_Reg parser_meta[] = {
+ { "__gc", parser_gc },
+ { "__tostring", parser_tostring },
+ { "parse", parser_parse },
+ { "reset", parser_reset },
+ { "set_included_ranges", parser_set_ranges },
+ { "included_ranges", parser_get_ranges },
+ { "set_timeout", parser_set_timeout },
+ { "timeout", parser_get_timeout },
+ { "_set_logger", parser_set_logger },
+ { "_logger", parser_get_logger },
+ { NULL, NULL }
+};
+
int tslua_push_parser(lua_State *L)
{
- // Gather language name
- const char *lang_name = luaL_checkstring(L, 1);
-
- TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name);
- if (!lang) {
- return luaL_error(L, "no such language: %s", lang_name);
- }
+ TSLanguage *lang = lang_check(L, 1);
TSParser **parser = lua_newuserdata(L, sizeof(TSParser *));
*parser = ts_parser_new();
if (!ts_parser_set_language(*parser, lang)) {
ts_parser_delete(*parser);
+ const char *lang_name = luaL_checkstring(L, 1);
return luaL_error(L, "Failed to load language : %s", lang_name);
}
@@ -325,9 +226,11 @@ int tslua_push_parser(lua_State *L)
return 1;
}
-static TSParser **parser_check(lua_State *L, uint16_t index)
+static TSParser *parser_check(lua_State *L, uint16_t index)
{
- return luaL_checkudata(L, index, TS_META_PARSER);
+ TSParser **ud = luaL_checkudata(L, index, TS_META_PARSER);
+ luaL_argcheck(L, *ud, index, "TSParser expected");
+ return *ud;
}
static void logger_gc(TSLogger logger)
@@ -343,13 +246,9 @@ static void logger_gc(TSLogger logger)
static int parser_gc(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
-
- logger_gc(ts_parser_logger(*p));
- ts_parser_delete(*p);
+ TSParser *p = parser_check(L, 1);
+ logger_gc(ts_parser_logger(p));
+ ts_parser_delete(p);
return 0;
}
@@ -371,7 +270,7 @@ static const char *input_cb(void *payload, uint32_t byte_index, TSPoint position
return "";
}
char *line = ml_get_buf(bp, (linenr_T)position.row + 1);
- size_t len = strlen(line);
+ size_t len = (size_t)ml_get_buf_len(bp, (linenr_T)position.row + 1);
if (position.column > len) {
*bytes_read = 0;
return "";
@@ -422,14 +321,10 @@ static void push_ranges(lua_State *L, const TSRange *ranges, const size_t length
static int parser_parse(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p || !(*p)) {
- return 0;
- }
-
+ TSParser *p = parser_check(L, 1);
TSTree *old_tree = NULL;
if (!lua_isnil(L, 2)) {
- TSLuaTree *ud = tree_check(L, 2);
+ TSLuaTree *ud = luaL_checkudata(L, 2, TS_META_TREE);
old_tree = ud ? ud->tree : NULL;
}
@@ -445,7 +340,7 @@ static int parser_parse(lua_State *L)
switch (lua_type(L, 3)) {
case LUA_TSTRING:
str = lua_tolstring(L, 3, &len);
- new_tree = ts_parser_parse_string(*p, old_tree, str, (uint32_t)len);
+ new_tree = ts_parser_parse_string(p, old_tree, str, (uint32_t)len);
break;
case LUA_TNUMBER:
@@ -461,7 +356,7 @@ static int parser_parse(lua_State *L)
}
input = (TSInput){ (void *)buf, input_cb, TSInputEncodingUTF8 };
- new_tree = ts_parser_parse(*p, old_tree, input);
+ new_tree = ts_parser_parse(p, old_tree, input);
break;
@@ -492,70 +387,14 @@ static int parser_parse(lua_State *L)
static int parser_reset(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (p && *p) {
- ts_parser_reset(*p);
- }
-
+ TSParser *p = parser_check(L, 1);
+ ts_parser_reset(p);
return 0;
}
-static int tree_copy(lua_State *L)
+static void range_err(lua_State *L)
{
- TSLuaTree *ud = tree_check(L, 1);
- if (!ud) {
- return 0;
- }
-
- TSTree *copy = ts_tree_copy(ud->tree);
- push_tree(L, copy); // [tree]
-
- return 1;
-}
-
-static int tree_edit(lua_State *L)
-{
- if (lua_gettop(L) < 10) {
- lua_pushstring(L, "not enough args to tree:edit()");
- return lua_error(L);
- }
-
- TSLuaTree *ud = tree_check(L, 1);
- if (!ud) {
- return 0;
- }
-
- uint32_t start_byte = (uint32_t)luaL_checkint(L, 2);
- uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3);
- uint32_t new_end_byte = (uint32_t)luaL_checkint(L, 4);
- TSPoint start_point = { (uint32_t)luaL_checkint(L, 5), (uint32_t)luaL_checkint(L, 6) };
- TSPoint old_end_point = { (uint32_t)luaL_checkint(L, 7), (uint32_t)luaL_checkint(L, 8) };
- TSPoint new_end_point = { (uint32_t)luaL_checkint(L, 9), (uint32_t)luaL_checkint(L, 10) };
-
- TSInputEdit edit = { start_byte, old_end_byte, new_end_byte,
- start_point, old_end_point, new_end_point };
-
- ts_tree_edit(ud->tree, &edit);
-
- return 0;
-}
-
-static int tree_get_ranges(lua_State *L)
-{
- TSLuaTree *ud = tree_check(L, 1);
- if (!ud) {
- return 0;
- }
-
- bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
-
- uint32_t len;
- TSRange *ranges = ts_tree_included_ranges(ud->tree, &len);
-
- push_ranges(L, ranges, len, include_bytes);
-
- xfree(ranges);
- return 1;
+ luaL_error(L, "Ranges can only be made from 6 element long tables or nodes.");
}
// Use the top of the stack (without popping it) to create a TSRange, it can be
@@ -567,7 +406,7 @@ static void range_from_lua(lua_State *L, TSRange *range)
if (lua_istable(L, -1)) {
// should be a table of 6 elements
if (lua_objlen(L, -1) != 6) {
- goto error;
+ range_err(L);
}
lua_rawgeti(L, -1, 1); // [ range, start_row]
@@ -606,7 +445,7 @@ static void range_from_lua(lua_State *L, TSRange *range)
.start_byte = start_byte,
.end_byte = end_byte,
};
- } else if (node_check(L, -1, &node)) {
+ } else if (node_check_opt(L, -1, &node)) {
*range = (TSRange) {
.start_point = ts_node_start_point(node),
.end_point = ts_node_end_point(node),
@@ -614,30 +453,19 @@ static void range_from_lua(lua_State *L, TSRange *range)
.end_byte = ts_node_end_byte(node)
};
} else {
- goto error;
+ range_err(L);
}
- return;
-error:
- luaL_error(L,
- "Ranges can only be made from 6 element long tables or nodes.");
}
static int parser_set_ranges(lua_State *L)
{
if (lua_gettop(L) < 2) {
- return luaL_error(L,
- "not enough args to parser:set_included_ranges()");
+ return luaL_error(L, "not enough args to parser:set_included_ranges()");
}
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
+ TSParser *p = parser_check(L, 1);
- if (!lua_istable(L, 2)) {
- return luaL_error(L,
- "argument for parser:set_included_ranges() should be a table.");
- }
+ luaL_argcheck(L, lua_istable(L, 2), 2, "table expected.");
size_t tbl_len = lua_objlen(L, 2);
TSRange *ranges = xmalloc(sizeof(TSRange) * tbl_len);
@@ -650,7 +478,7 @@ static int parser_set_ranges(lua_State *L)
}
// This memcpies ranges, thus we can free it afterwards
- ts_parser_set_included_ranges(*p, ranges, (uint32_t)tbl_len);
+ ts_parser_set_included_ranges(p, ranges, (uint32_t)tbl_len);
xfree(ranges);
return 0;
@@ -658,15 +486,12 @@ static int parser_set_ranges(lua_State *L)
static int parser_get_ranges(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
+ TSParser *p = parser_check(L, 1);
bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
uint32_t len;
- const TSRange *ranges = ts_parser_included_ranges(*p, &len);
+ const TSRange *ranges = ts_parser_included_ranges(p, &len);
push_ranges(L, ranges, len, include_bytes);
return 1;
@@ -674,28 +499,21 @@ static int parser_get_ranges(lua_State *L)
static int parser_set_timeout(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
+ TSParser *p = parser_check(L, 1);
if (lua_gettop(L) < 2) {
luaL_error(L, "integer expected");
}
uint32_t timeout = (uint32_t)luaL_checkinteger(L, 2);
- ts_parser_set_timeout_micros(*p, timeout);
+ ts_parser_set_timeout_micros(p, timeout);
return 0;
}
static int parser_get_timeout(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
-
- lua_pushinteger(L, (lua_Integer)ts_parser_timeout_micros(*p));
+ TSParser *p = parser_check(L, 1);
+ lua_pushinteger(L, (lua_Integer)ts_parser_timeout_micros(p));
return 1;
}
@@ -719,22 +537,11 @@ static void logger_cb(void *payload, TSLogType logtype, const char *s)
static int parser_set_logger(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
+ TSParser *p = parser_check(L, 1);
- if (!lua_isboolean(L, 2)) {
- return luaL_argerror(L, 2, "boolean expected");
- }
-
- if (!lua_isboolean(L, 3)) {
- return luaL_argerror(L, 3, "boolean expected");
- }
-
- if (!lua_isfunction(L, 4)) {
- return luaL_argerror(L, 4, "function expected");
- }
+ luaL_argcheck(L, lua_isboolean(L, 2), 2, "boolean expected");
+ luaL_argcheck(L, lua_isboolean(L, 3), 3, "boolean expected");
+ luaL_argcheck(L, lua_isfunction(L, 4), 4, "function expected");
TSLuaLoggerOpts *opts = xmalloc(sizeof(TSLuaLoggerOpts));
lua_pushvalue(L, 4);
@@ -752,18 +559,14 @@ static int parser_set_logger(lua_State *L)
.log = logger_cb
};
- ts_parser_set_logger(*p, logger);
+ ts_parser_set_logger(p, logger);
return 0;
}
static int parser_get_logger(lua_State *L)
{
- TSParser **p = parser_check(L, 1);
- if (!p) {
- return 0;
- }
-
- TSLogger logger = ts_parser_logger(*p);
+ TSParser *p = parser_check(L, 1);
+ TSLogger logger = ts_parser_logger(p);
if (logger.log) {
TSLuaLoggerOpts *opts = (TSLuaLoggerOpts *)logger.payload;
lua_rawgeti(L, LUA_REGISTRYINDEX, opts->cb);
@@ -774,7 +577,17 @@ static int parser_get_logger(lua_State *L)
return 1;
}
-// Tree methods
+// TSTree
+
+static struct luaL_Reg tree_meta[] = {
+ { "__gc", tree_gc },
+ { "__tostring", tree_tostring },
+ { "root", tree_root },
+ { "edit", tree_edit },
+ { "included_ranges", tree_get_ranges },
+ { "copy", tree_copy },
+ { NULL, NULL }
+};
/// Push tree interface on to the lua stack.
///
@@ -804,18 +617,58 @@ static void push_tree(lua_State *L, TSTree *tree)
lua_setfenv(L, -2); // [udata]
}
-static TSLuaTree *tree_check(lua_State *L, int index)
+static int tree_copy(lua_State *L)
{
- TSLuaTree *ud = luaL_checkudata(L, index, TS_META_TREE);
- return ud;
+ TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
+ TSTree *copy = ts_tree_copy(ud->tree);
+ push_tree(L, copy); // [tree]
+
+ return 1;
}
-static int tree_gc(lua_State *L)
+static int tree_edit(lua_State *L)
{
- TSLuaTree *ud = tree_check(L, 1);
- if (ud) {
- ts_tree_delete(ud->tree);
+ if (lua_gettop(L) < 10) {
+ lua_pushstring(L, "not enough args to tree:edit()");
+ return lua_error(L);
}
+
+ TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
+
+ uint32_t start_byte = (uint32_t)luaL_checkint(L, 2);
+ uint32_t old_end_byte = (uint32_t)luaL_checkint(L, 3);
+ uint32_t new_end_byte = (uint32_t)luaL_checkint(L, 4);
+ TSPoint start_point = { (uint32_t)luaL_checkint(L, 5), (uint32_t)luaL_checkint(L, 6) };
+ TSPoint old_end_point = { (uint32_t)luaL_checkint(L, 7), (uint32_t)luaL_checkint(L, 8) };
+ TSPoint new_end_point = { (uint32_t)luaL_checkint(L, 9), (uint32_t)luaL_checkint(L, 10) };
+
+ TSInputEdit edit = { start_byte, old_end_byte, new_end_byte,
+ start_point, old_end_point, new_end_point };
+
+ ts_tree_edit(ud->tree, &edit);
+
+ return 0;
+}
+
+static int tree_get_ranges(lua_State *L)
+{
+ TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
+
+ bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
+
+ uint32_t len;
+ TSRange *ranges = ts_tree_included_ranges(ud->tree, &len);
+
+ push_ranges(L, ranges, len, include_bytes);
+
+ xfree(ranges);
+ return 1;
+}
+
+static int tree_gc(lua_State *L)
+{
+ TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
+ ts_tree_delete(ud->tree);
return 0;
}
@@ -827,16 +680,66 @@ static int tree_tostring(lua_State *L)
static int tree_root(lua_State *L)
{
- TSLuaTree *ud = tree_check(L, 1);
- if (!ud) {
- return 0;
- }
+ TSLuaTree *ud = luaL_checkudata(L, 1, TS_META_TREE);
TSNode root = ts_tree_root_node(ud->tree);
push_node(L, root, 1);
return 1;
}
-// Node methods
+// TSTreeCursor
+
+static struct luaL_Reg treecursor_meta[] = {
+ { "__gc", treecursor_gc },
+ { NULL, NULL }
+};
+
+static int treecursor_gc(lua_State *L)
+{
+ TSTreeCursor *cursor = luaL_checkudata(L, 1, TS_META_TREECURSOR);
+ ts_tree_cursor_delete(cursor);
+ return 0;
+}
+
+// TSNode
+static struct luaL_Reg node_meta[] = {
+ { "__tostring", node_tostring },
+ { "__eq", node_eq },
+ { "__len", node_child_count },
+ { "id", node_id },
+ { "range", node_range },
+ { "start", node_start },
+ { "end_", node_end },
+ { "type", node_type },
+ { "symbol", node_symbol },
+ { "field", node_field },
+ { "named", node_named },
+ { "missing", node_missing },
+ { "extra", node_extra },
+ { "has_changes", node_has_changes },
+ { "has_error", node_has_error },
+ { "sexpr", node_sexpr },
+ { "child_count", node_child_count },
+ { "named_child_count", node_named_child_count },
+ { "child", node_child },
+ { "named_child", node_named_child },
+ { "descendant_for_range", node_descendant_for_range },
+ { "named_descendant_for_range", node_named_descendant_for_range },
+ { "parent", node_parent },
+ { "__has_ancestor", __has_ancestor },
+ { "child_containing_descendant", node_child_containing_descendant },
+ { "iter_children", node_iter_children },
+ { "next_sibling", node_next_sibling },
+ { "prev_sibling", node_prev_sibling },
+ { "next_named_sibling", node_next_named_sibling },
+ { "prev_named_sibling", node_prev_named_sibling },
+ { "named_children", node_named_children },
+ { "root", node_root },
+ { "tree", node_tree },
+ { "byte_length", node_byte_length },
+ { "equal", node_equal },
+
+ { NULL, NULL }
+};
/// Push node interface on to the Lua stack
///
@@ -860,7 +763,7 @@ static void push_node(lua_State *L, TSNode node, int uindex)
lua_setfenv(L, -2); // [udata]
}
-static bool node_check(lua_State *L, int index, TSNode *res)
+static bool node_check_opt(lua_State *L, int index, TSNode *res)
{
TSNode *ud = luaL_checkudata(L, index, TS_META_NODE);
if (ud) {
@@ -870,12 +773,15 @@ static bool node_check(lua_State *L, int index, TSNode *res)
return false;
}
+static TSNode node_check(lua_State *L, int index)
+{
+ TSNode *ud = luaL_checkudata(L, index, TS_META_NODE);
+ return *ud;
+}
+
static int node_tostring(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushstring(L, "<node ");
lua_pushstring(L, ts_node_type(node));
lua_pushstring(L, ">");
@@ -885,37 +791,22 @@ static int node_tostring(lua_State *L)
static int node_eq(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
-
- TSNode node2;
- if (!node_check(L, 2, &node2)) {
- return 0;
- }
-
+ TSNode node = node_check(L, 1);
+ TSNode node2 = node_check(L, 2);
lua_pushboolean(L, ts_node_eq(node, node2));
return 1;
}
static int node_id(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
-
+ TSNode node = node_check(L, 1);
lua_pushlstring(L, (const char *)&node.id, sizeof node.id);
return 1;
}
static int node_range(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
bool include_bytes = (lua_gettop(L) >= 2) && lua_toboolean(L, 2);
@@ -941,10 +832,7 @@ static int node_range(lua_State *L)
static int node_start(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSPoint start = ts_node_start_point(node);
uint32_t start_byte = ts_node_start_byte(node);
lua_pushinteger(L, start.row);
@@ -955,10 +843,7 @@ static int node_start(lua_State *L)
static int node_end(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSPoint end = ts_node_end_point(node);
uint32_t end_byte = ts_node_end_byte(node);
lua_pushinteger(L, end.row);
@@ -969,10 +854,7 @@ static int node_end(lua_State *L)
static int node_child_count(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
uint32_t count = ts_node_child_count(node);
lua_pushinteger(L, count);
return 1;
@@ -980,10 +862,7 @@ static int node_child_count(lua_State *L)
static int node_named_child_count(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
uint32_t count = ts_node_named_child_count(node);
lua_pushinteger(L, count);
return 1;
@@ -991,20 +870,14 @@ static int node_named_child_count(lua_State *L)
static int node_type(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushstring(L, ts_node_type(node));
return 1;
}
static int node_symbol(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSSymbol symbol = ts_node_symbol(node);
lua_pushinteger(L, symbol);
return 1;
@@ -1012,10 +885,7 @@ static int node_symbol(lua_State *L)
static int node_field(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
size_t name_len;
const char *field_name = luaL_checklstring(L, 2, &name_len);
@@ -1042,20 +912,14 @@ static int node_field(lua_State *L)
static int node_named(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushboolean(L, ts_node_is_named(node));
return 1;
}
static int node_sexpr(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
char *allocated = ts_node_string(node);
lua_pushstring(L, allocated);
xfree(allocated);
@@ -1064,50 +928,35 @@ static int node_sexpr(lua_State *L)
static int node_missing(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushboolean(L, ts_node_is_missing(node));
return 1;
}
static int node_extra(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushboolean(L, ts_node_is_extra(node));
return 1;
}
static int node_has_changes(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushboolean(L, ts_node_has_changes(node));
return 1;
}
static int node_has_error(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
lua_pushboolean(L, ts_node_has_error(node));
return 1;
}
static int node_child(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
uint32_t num = (uint32_t)lua_tointeger(L, 2);
TSNode child = ts_node_child(node, num);
@@ -1117,10 +966,7 @@ static int node_child(lua_State *L)
static int node_named_child(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
uint32_t num = (uint32_t)lua_tointeger(L, 2);
TSNode child = ts_node_named_child(node, num);
@@ -1130,10 +976,7 @@ static int node_named_child(lua_State *L)
static int node_descendant_for_range(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSPoint start = { (uint32_t)lua_tointeger(L, 2),
(uint32_t)lua_tointeger(L, 3) };
TSPoint end = { (uint32_t)lua_tointeger(L, 4),
@@ -1146,10 +989,7 @@ static int node_descendant_for_range(lua_State *L)
static int node_named_descendant_for_range(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSPoint start = { (uint32_t)lua_tointeger(L, 2),
(uint32_t)lua_tointeger(L, 3) };
TSPoint end = { (uint32_t)lua_tointeger(L, 4),
@@ -1162,54 +1002,41 @@ static int node_named_descendant_for_range(lua_State *L)
static int node_next_child(lua_State *L)
{
- TSTreeCursor *ud = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR);
- if (!ud) {
- return 0;
- }
-
- TSNode source;
- if (!node_check(L, lua_upvalueindex(2), &source)) {
- return 0;
- }
+ TSTreeCursor *cursor = luaL_checkudata(L, lua_upvalueindex(1), TS_META_TREECURSOR);
+ TSNode source = node_check(L, lua_upvalueindex(2));
// First call should return first child
- if (ts_node_eq(source, ts_tree_cursor_current_node(ud))) {
- if (ts_tree_cursor_goto_first_child(ud)) {
+ if (ts_node_eq(source, ts_tree_cursor_current_node(cursor))) {
+ if (ts_tree_cursor_goto_first_child(cursor)) {
goto push;
} else {
- goto end;
+ return 0;
}
}
- if (ts_tree_cursor_goto_next_sibling(ud)) {
-push:
- push_node(L,
- ts_tree_cursor_current_node(ud),
- lua_upvalueindex(2)); // [node]
+ if (!ts_tree_cursor_goto_next_sibling(cursor)) {
+ return 0;
+ }
- const char *field = ts_tree_cursor_current_field_name(ud);
+push:
+ push_node(L, ts_tree_cursor_current_node(cursor), lua_upvalueindex(2)); // [node]
- if (field != NULL) {
- lua_pushstring(L, ts_tree_cursor_current_field_name(ud));
- } else {
- lua_pushnil(L);
- } // [node, field_name_or_nil]
- return 2;
- }
+ const char *field = ts_tree_cursor_current_field_name(cursor);
-end:
- return 0;
+ if (field != NULL) {
+ lua_pushstring(L, ts_tree_cursor_current_field_name(cursor));
+ } else {
+ lua_pushnil(L);
+ } // [node, field_name_or_nil]
+ return 2;
}
static int node_iter_children(lua_State *L)
{
- TSNode source;
- if (!node_check(L, 1, &source)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata]
- *ud = ts_tree_cursor_new(source);
+ *ud = ts_tree_cursor_new(node);
lua_getfield(L, LUA_REGISTRYINDEX, TS_META_TREECURSOR); // [udata, mt]
lua_setmetatable(L, -2); // [udata]
@@ -1219,30 +1046,60 @@ static int node_iter_children(lua_State *L)
return 1;
}
-static int treecursor_gc(lua_State *L)
+static int node_parent(lua_State *L)
{
- TSTreeCursor *ud = luaL_checkudata(L, 1, TS_META_TREECURSOR);
- ts_tree_cursor_delete(ud);
- return 0;
+ TSNode node = node_check(L, 1);
+ TSNode parent = ts_node_parent(node);
+ push_node(L, parent, 1);
+ return 1;
}
-static int node_parent(lua_State *L)
+static int __has_ancestor(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
+ TSNode descendant = node_check(L, 1);
+ if (lua_type(L, 2) != LUA_TTABLE) {
+ lua_pushboolean(L, false);
+ return 1;
}
- TSNode parent = ts_node_parent(node);
- push_node(L, parent, 1);
+ int const pred_len = (int)lua_objlen(L, 2);
+
+ TSNode node = ts_tree_root_node(descendant.tree);
+ while (!ts_node_is_null(node)) {
+ char const *node_type = ts_node_type(node);
+ size_t node_type_len = strlen(node_type);
+
+ for (int i = 3; i <= pred_len; i++) {
+ lua_rawgeti(L, 2, i);
+ if (lua_type(L, -1) == LUA_TSTRING) {
+ size_t check_len;
+ char const *check_str = lua_tolstring(L, -1, &check_len);
+ if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) {
+ lua_pushboolean(L, true);
+ return 1;
+ }
+ }
+ lua_pop(L, 1);
+ }
+
+ node = ts_node_child_containing_descendant(node, descendant);
+ }
+
+ lua_pushboolean(L, false);
+ return 1;
+}
+
+static int node_child_containing_descendant(lua_State *L)
+{
+ TSNode node = node_check(L, 1);
+ TSNode descendant = node_check(L, 2);
+ TSNode child = ts_node_child_containing_descendant(node, descendant);
+ push_node(L, child, 1);
return 1;
}
static int node_next_sibling(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSNode sibling = ts_node_next_sibling(node);
push_node(L, sibling, 1);
return 1;
@@ -1250,10 +1107,7 @@ static int node_next_sibling(lua_State *L)
static int node_prev_sibling(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSNode sibling = ts_node_prev_sibling(node);
push_node(L, sibling, 1);
return 1;
@@ -1261,10 +1115,7 @@ static int node_prev_sibling(lua_State *L)
static int node_next_named_sibling(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSNode sibling = ts_node_next_named_sibling(node);
push_node(L, sibling, 1);
return 1;
@@ -1272,10 +1123,7 @@ static int node_next_named_sibling(lua_State *L)
static int node_prev_named_sibling(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
+ TSNode node = node_check(L, 1);
TSNode sibling = ts_node_prev_named_sibling(node);
push_node(L, sibling, 1);
return 1;
@@ -1283,10 +1131,7 @@ static int node_prev_named_sibling(lua_State *L)
static int node_named_children(lua_State *L)
{
- TSNode source;
- if (!node_check(L, 1, &source)) {
- return 0;
- }
+ TSNode source = node_check(L, 1);
TSTreeCursor cursor = ts_tree_cursor_new(source);
lua_newtable(L);
@@ -1308,11 +1153,7 @@ static int node_named_children(lua_State *L)
static int node_root(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
-
+ TSNode node = node_check(L, 1);
TSNode root = ts_tree_root_node(node.tree);
push_node(L, root, 1);
return 1;
@@ -1320,215 +1161,196 @@ static int node_root(lua_State *L)
static int node_tree(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
-
+ node_check(L, 1);
lua_getfenv(L, 1); // [udata, reftable]
lua_rawgeti(L, -1, 1); // [udata, reftable, tree_udata]
-
return 1;
}
static int node_byte_length(lua_State *L)
{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
-
+ TSNode node = node_check(L, 1);
uint32_t start_byte = ts_node_start_byte(node);
uint32_t end_byte = ts_node_end_byte(node);
-
lua_pushinteger(L, end_byte - start_byte);
return 1;
}
static int node_equal(lua_State *L)
{
- TSNode node1;
- if (!node_check(L, 1, &node1)) {
- return 0;
- }
-
- TSNode node2;
- if (!node_check(L, 2, &node2)) {
- return luaL_error(L, "TSNode expected");
- }
-
+ TSNode node1 = node_check(L, 1);
+ TSNode node2 = node_check(L, 2);
lua_pushboolean(L, ts_node_eq(node1, node2));
return 1;
}
-/// assumes the match table being on top of the stack
-static void set_match(lua_State *L, TSQueryMatch *match, int nodeidx)
-{
- // [match]
- for (size_t i = 0; i < match->capture_count; i++) {
- lua_rawgeti(L, -1, (int)match->captures[i].index + 1); // [match, captures]
- if (lua_isnil(L, -1)) { // [match, nil]
- lua_pop(L, 1); // [match]
- lua_createtable(L, 1, 0); // [match, captures]
- }
- push_node(L, match->captures[i].node, nodeidx); // [match, captures, node]
- lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, captures]
- lua_rawseti(L, -2, (int)match->captures[i].index + 1); // [match]
- }
-}
+// TSQueryCursor
-static int query_next_match(lua_State *L)
-{
- TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
- TSQueryCursor *cursor = ud->cursor;
-
- TSQuery *query = query_check(L, lua_upvalueindex(3));
- TSQueryMatch match;
- if (ts_query_cursor_next_match(cursor, &match)) {
- lua_pushinteger(L, match.pattern_index + 1); // [index]
- lua_createtable(L, (int)ts_query_capture_count(query), 0); // [index, match]
- set_match(L, &match, lua_upvalueindex(2));
- return 2;
- }
- return 0;
-}
+static struct luaL_Reg querycursor_meta[] = {
+ { "remove_match", querycursor_remove_match },
+ { "next_capture", querycursor_next_capture },
+ { "next_match", querycursor_next_match },
+ { "__gc", querycursor_gc },
+ { NULL, NULL }
+};
-static int query_next_capture(lua_State *L)
+int tslua_push_querycursor(lua_State *L)
{
- // Upvalues are:
- // [ cursor, node, query, current_match ]
- TSLua_cursor *ud = lua_touserdata(L, lua_upvalueindex(1));
- TSQueryCursor *cursor = ud->cursor;
-
- TSQuery *query = query_check(L, lua_upvalueindex(3));
-
- if (ud->predicated_match > -1) {
- lua_getfield(L, lua_upvalueindex(4), "active");
- bool active = lua_toboolean(L, -1);
- lua_pop(L, 1);
- if (!active) {
- ts_query_cursor_remove_match(cursor, (uint32_t)ud->predicated_match);
- }
- ud->predicated_match = -1;
- }
+ TSNode node = node_check(L, 1);
- TSQueryMatch match;
- uint32_t capture_index;
- if (ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
- TSQueryCapture capture = match.captures[capture_index];
-
- // TODO(vigoux): handle capture quantifiers here
- lua_pushinteger(L, capture.index + 1); // [index]
- push_node(L, capture.node, lua_upvalueindex(2)); // [index, node]
-
- // Now check if we need to run the predicates
- uint32_t n_pred;
- ts_query_predicates_for_pattern(query, match.pattern_index, &n_pred);
-
- if (n_pred > 0 && (ud->max_match_id < (int)match.id)) {
- ud->max_match_id = (int)match.id;
-
- // Create a new cleared match table
- lua_createtable(L, (int)ts_query_capture_count(query), 2); // [index, node, match]
- set_match(L, &match, lua_upvalueindex(2));
- lua_pushinteger(L, match.pattern_index + 1);
- lua_setfield(L, -2, "pattern");
-
- if (match.capture_count > 1) {
- ud->predicated_match = (int)match.id;
- lua_pushboolean(L, false);
- lua_setfield(L, -2, "active");
- }
-
- // Set current_match to the new match
- lua_replace(L, lua_upvalueindex(4)); // [index, node]
- lua_pushvalue(L, lua_upvalueindex(4)); // [index, node, match]
- return 3;
- }
- return 2;
- }
- return 0;
-}
-
-static int node_rawquery(lua_State *L)
-{
- TSNode node;
- if (!node_check(L, 1, &node)) {
- return 0;
- }
TSQuery *query = query_check(L, 2);
-
- TSQueryCursor *cursor;
- if (kv_size(cursors) > 0) {
- cursor = kv_pop(cursors);
- } else {
- cursor = ts_query_cursor_new();
- }
-
- ts_query_cursor_set_max_start_depth(cursor, UINT32_MAX);
- ts_query_cursor_set_match_limit(cursor, 256);
+ TSQueryCursor *cursor = ts_query_cursor_new();
ts_query_cursor_exec(cursor, query, node);
- bool captures = lua_toboolean(L, 3);
-
- if (lua_gettop(L) >= 4) {
- uint32_t start = (uint32_t)luaL_checkinteger(L, 4);
- uint32_t end = lua_gettop(L) >= 5 ? (uint32_t)luaL_checkinteger(L, 5) : MAXLNUM;
+ if (lua_gettop(L) >= 3) {
+ uint32_t start = (uint32_t)luaL_checkinteger(L, 3);
+ uint32_t end = lua_gettop(L) >= 4 ? (uint32_t)luaL_checkinteger(L, 4) : MAXLNUM;
ts_query_cursor_set_point_range(cursor, (TSPoint){ start, 0 }, (TSPoint){ end, 0 });
}
- if (lua_gettop(L) >= 6 && !lua_isnil(L, 6)) {
- if (!lua_istable(L, 6)) {
- return luaL_error(L, "table expected");
- }
- lua_pushnil(L);
- // stack: [dict, ..., nil]
- while (lua_next(L, 6)) {
- // stack: [dict, ..., key, value]
+ if (lua_gettop(L) >= 5 && !lua_isnil(L, 5)) {
+ luaL_argcheck(L, lua_istable(L, 5), 5, "table expected");
+ lua_pushnil(L); // [dict, ..., nil]
+ while (lua_next(L, 5)) {
+ // [dict, ..., key, value]
if (lua_type(L, -2) == LUA_TSTRING) {
char *k = (char *)lua_tostring(L, -2);
if (strequal("max_start_depth", k)) {
uint32_t max_start_depth = (uint32_t)lua_tointeger(L, -1);
ts_query_cursor_set_max_start_depth(cursor, max_start_depth);
+ } else if (strequal("match_limit", k)) {
+ uint32_t match_limit = (uint32_t)lua_tointeger(L, -1);
+ ts_query_cursor_set_match_limit(cursor, match_limit);
}
}
- lua_pop(L, 1); // pop the value; lua_next will pop the key.
- // stack: [dict, ..., key]
+ // pop the value; lua_next will pop the key.
+ lua_pop(L, 1); // [dict, ..., key]
}
}
- TSLua_cursor *ud = lua_newuserdata(L, sizeof(*ud)); // [udata]
- ud->cursor = cursor;
- ud->predicated_match = -1;
- ud->max_match_id = -1;
+ TSQueryCursor **ud = lua_newuserdata(L, sizeof(*ud)); // [node, query, ..., udata]
+ *ud = cursor;
+ lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR); // [node, query, ..., udata, meta]
+ lua_setmetatable(L, -2); // [node, query, ..., udata]
- lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYCURSOR);
- lua_setmetatable(L, -2); // [udata]
- lua_pushvalue(L, 1); // [udata, node]
+ // Copy the fenv which contains the nodes tree.
+ lua_getfenv(L, 1); // [udata, reftable]
+ lua_setfenv(L, -2); // [udata]
- // include query separately, as to keep a ref to it for gc
- lua_pushvalue(L, 2); // [udata, node, query]
+ return 1;
+}
- if (captures) {
- // placeholder for match state
- lua_createtable(L, (int)ts_query_capture_count(query), 2); // [u, n, q, match]
- lua_pushcclosure(L, query_next_capture, 4); // [closure]
- } else {
- lua_pushcclosure(L, query_next_match, 3); // [closure]
+static int querycursor_remove_match(lua_State *L)
+{
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+ uint32_t match_id = (uint32_t)luaL_checkinteger(L, 2);
+ ts_query_cursor_remove_match(cursor, match_id);
+ return 0;
+}
+
+static int querycursor_next_capture(lua_State *L)
+{
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+ TSQueryMatch match;
+ uint32_t capture_index;
+ if (!ts_query_cursor_next_capture(cursor, &match, &capture_index)) {
+ return 0;
+ }
+
+ TSQueryCapture capture = match.captures[capture_index];
+
+ // Handle capture quantifiers here
+ lua_pushinteger(L, capture.index + 1); // [index]
+ push_node(L, capture.node, 1); // [index, node]
+ push_querymatch(L, &match, 1);
+
+ return 3;
+}
+
+static int querycursor_next_match(lua_State *L)
+{
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+
+ TSQueryMatch match;
+ if (!ts_query_cursor_next_match(cursor, &match)) {
+ return 0;
}
+ push_querymatch(L, &match, 1);
+
return 1;
}
+static TSQueryCursor *querycursor_check(lua_State *L, int index)
+{
+ TSQueryCursor **ud = luaL_checkudata(L, index, TS_META_QUERYCURSOR);
+ luaL_argcheck(L, *ud, index, "TSQueryCursor expected");
+ return *ud;
+}
+
static int querycursor_gc(lua_State *L)
{
- TSLua_cursor *ud = luaL_checkudata(L, 1, TS_META_QUERYCURSOR);
- kv_push(cursors, ud->cursor);
- ud->cursor = NULL;
+ TSQueryCursor *cursor = querycursor_check(L, 1);
+ ts_query_cursor_delete(cursor);
return 0;
}
-// Query methods
+// TSQueryMatch
+
+static struct luaL_Reg querymatch_meta[] = {
+ { "info", querymatch_info },
+ { "captures", querymatch_captures },
+ { NULL, NULL }
+};
+
+static void push_querymatch(lua_State *L, TSQueryMatch *match, int uindex)
+{
+ TSQueryMatch *ud = lua_newuserdata(L, sizeof(TSQueryMatch)); // [udata]
+ *ud = *match;
+ lua_getfield(L, LUA_REGISTRYINDEX, TS_META_QUERYMATCH); // [udata, meta]
+ lua_setmetatable(L, -2); // [udata]
+
+ // Copy the fenv which contains the nodes tree.
+ lua_getfenv(L, uindex); // [udata, reftable]
+ lua_setfenv(L, -2); // [udata]
+}
+
+static int querymatch_info(lua_State *L)
+{
+ TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
+ lua_pushinteger(L, match->id);
+ lua_pushinteger(L, match->pattern_index + 1);
+ return 2;
+}
+
+static int querymatch_captures(lua_State *L)
+{
+ TSQueryMatch *match = luaL_checkudata(L, 1, TS_META_QUERYMATCH);
+ lua_newtable(L); // [match, nodes, captures]
+ for (size_t i = 0; i < match->capture_count; i++) {
+ TSQueryCapture capture = match->captures[i];
+ int index = (int)capture.index + 1;
+
+ lua_rawgeti(L, -1, index); // [match, node, captures]
+ if (lua_isnil(L, -1)) { // [match, node, captures, nil]
+ lua_pop(L, 1); // [match, node, captures]
+ lua_newtable(L); // [match, node, captures, nodes]
+ }
+ push_node(L, capture.node, 1); // [match, node, captures, nodes, node]
+ lua_rawseti(L, -2, (int)lua_objlen(L, -2) + 1); // [match, node, captures, nodes]
+ lua_rawseti(L, -2, index); // [match, node, captures]
+ }
+ return 1;
+}
+
+// TSQuery
+
+static struct luaL_Reg query_meta[] = {
+ { "__gc", query_gc },
+ { "__tostring", query_tostring },
+ { "inspect", query_inspect },
+ { NULL, NULL }
+};
int tslua_parse_query(lua_State *L)
{
@@ -1536,15 +1358,12 @@ int tslua_parse_query(lua_State *L)
return luaL_error(L, "string expected");
}
- const char *lang_name = lua_tostring(L, 1);
- TSLanguage *lang = pmap_get(cstr_t)(&langs, lang_name);
- if (!lang) {
- return luaL_error(L, "no such language: %s", lang_name);
- }
+ TSLanguage *lang = lang_check(L, 1);
size_t len;
const char *src = lua_tolstring(L, 2, &len);
+ tslua_query_parse_count++;
uint32_t error_offset;
TSQueryError error_type;
TSQuery *query = ts_query_new(lang, src, (uint32_t)len, &error_offset, &error_type);
@@ -1638,16 +1457,13 @@ static void query_err_string(const char *src, int error_offset, TSQueryError err
static TSQuery *query_check(lua_State *L, int index)
{
TSQuery **ud = luaL_checkudata(L, index, TS_META_QUERY);
+ luaL_argcheck(L, *ud, index, "TSQuery expected");
return *ud;
}
static int query_gc(lua_State *L)
{
TSQuery *query = query_check(L, 1);
- if (!query) {
- return 0;
- }
-
ts_query_delete(query);
return 0;
}
@@ -1661,9 +1477,6 @@ static int query_tostring(lua_State *L)
static int query_inspect(lua_State *L)
{
TSQuery *query = query_check(L, 1);
- if (!query) {
- return 0;
- }
// TSQueryInfo
lua_createtable(L, 0, 2); // [retval]
@@ -1718,3 +1531,33 @@ static int query_inspect(lua_State *L)
return 1;
}
+
+// Library init
+
+static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta)
+{
+ if (luaL_newmetatable(L, tname)) { // [meta]
+ luaL_register(L, NULL, meta);
+
+ lua_pushvalue(L, -1); // [meta, meta]
+ lua_setfield(L, -2, "__index"); // [meta]
+ }
+ lua_pop(L, 1); // [] (don't use it now)
+}
+
+/// Init the tslua library.
+///
+/// All global state is stored in the registry of the lua_State.
+void tslua_init(lua_State *L)
+{
+ // type metatables
+ build_meta(L, TS_META_PARSER, parser_meta);
+ build_meta(L, TS_META_TREE, tree_meta);
+ build_meta(L, TS_META_NODE, node_meta);
+ build_meta(L, TS_META_QUERY, query_meta);
+ build_meta(L, TS_META_QUERYCURSOR, querycursor_meta);
+ build_meta(L, TS_META_QUERYMATCH, querymatch_meta);
+ build_meta(L, TS_META_TREECURSOR, treecursor_meta);
+
+ ts_set_allocator(xmalloc, xcalloc, xrealloc, xfree);
+}
diff --git a/src/nvim/lua/treesitter.h b/src/nvim/lua/treesitter.h
index 4ef9a10602..14df06e184 100644
--- a/src/nvim/lua/treesitter.h
+++ b/src/nvim/lua/treesitter.h
@@ -1,7 +1,12 @@
#pragma once
#include <lua.h> // IWYU pragma: keep
+#include <stdint.h>
+
+#include "nvim/macros_defs.h"
#ifdef INCLUDE_GENERATED_DECLARATIONS
# include "lua/treesitter.h.generated.h"
#endif
+
+EXTERN uint64_t tslua_query_parse_count INIT( = 0);