diff options
-rw-r--r-- | runtime/doc/lua.txt | 6 | ||||
-rw-r--r-- | src/nvim/lua/treesitter.c | 77 | ||||
-rw-r--r-- | test/functional/lua/treesitter_spec.lua | 24 |
3 files changed, 106 insertions, 1 deletions
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index 8c306135d0..e948a7c9aa 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -622,6 +622,12 @@ Node methods *lua-treesitter-node* tsnode:parent() *tsnode:parent()* Get the node's immediate parent. +tsnode:iter_children() *tsnode:iter_children()* + Iterates over all the direct children of {tsnode}, regardless of + wether they are named or not. + Returns the child node plus the eventual field name corresponding to + this child node. + tsnode:child_count() *tsnode:child_count()* Get the node's number of children. diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index 138031237e..308bfe8cfb 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -73,6 +73,7 @@ static struct luaL_Reg node_meta[] = { { "descendant_for_range", node_descendant_for_range }, { "named_descendant_for_range", node_named_descendant_for_range }, { "parent", node_parent }, + { "iter_children", node_iter_children }, { "_rawquery", node_rawquery }, { NULL, NULL } }; @@ -84,12 +85,17 @@ static struct luaL_Reg query_meta[] = { { NULL, NULL } }; -// cursor is not exposed, but still needs garbage collection +// cursors are not exposed, but still needs garbage collection static struct luaL_Reg querycursor_meta[] = { { "__gc", querycursor_gc }, { NULL, NULL } }; +static struct luaL_Reg treecursor_meta[] = { + { "__gc", treecursor_gc }, + { NULL, NULL } +}; + static PMap(cstr_t) *langs; static void build_meta(lua_State *L, const char *tname, const luaL_Reg *meta) @@ -116,6 +122,7 @@ void tslua_init(lua_State *L) build_meta(L, "treesitter_node", node_meta); build_meta(L, "treesitter_query", query_meta); build_meta(L, "treesitter_querycursor", querycursor_meta); + build_meta(L, "treesitter_treecursor", treecursor_meta); } int tslua_has_language(lua_State *L) @@ -746,6 +753,74 @@ static int node_named_descendant_for_range(lua_State *L) return 1; } +static int node_next_child(lua_State *L) +{ + TSTreeCursor *ud = luaL_checkudata( + L, lua_upvalueindex(1), "treesitter_treecursor"); + if (!ud) { + return 0; + } + + TSNode source; + if (!node_check(L, lua_upvalueindex(2), &source)) { + return 0; + } + + // First call should return first child + if (ts_node_eq(source, ts_tree_cursor_current_node(ud))) { + if (ts_tree_cursor_goto_first_child(ud)) { + goto push; + } else { + goto end; + } + } + + if (ts_tree_cursor_goto_next_sibling(ud)) { +push: + push_node( + L, + ts_tree_cursor_current_node(ud), + lua_upvalueindex(2)); // [node] + + const char * field = ts_tree_cursor_current_field_name(ud); + + if (field != NULL) { + lua_pushstring(L, ts_tree_cursor_current_field_name(ud)); + } else { + lua_pushnil(L); + } // [node, field_name_or_nil] + return 2; + } + +end: + return 0; +} + +static int node_iter_children(lua_State *L) +{ + TSNode source; + if (!node_check(L, 1, &source)) { + return 0; + } + + TSTreeCursor *ud = lua_newuserdata(L, sizeof(TSTreeCursor)); // [udata] + *ud = ts_tree_cursor_new(source); + + lua_getfield(L, LUA_REGISTRYINDEX, "treesitter_treecursor"); // [udata, mt] + lua_setmetatable(L, -2); // [udata] + lua_pushvalue(L, 1); // [udata, source_node] + lua_pushcclosure(L, node_next_child, 2); + + return 1; +} + +static int treecursor_gc(lua_State *L) +{ + TSTreeCursor *ud = luaL_checkudata(L, 1, "treesitter_treecursor"); + ts_tree_cursor_delete(ud); + return 0; +} + static int node_parent(lua_State *L) { TSNode node; diff --git a/test/functional/lua/treesitter_spec.lua b/test/functional/lua/treesitter_spec.lua index b0ac9e079a..f8d7f30261 100644 --- a/test/functional/lua/treesitter_spec.lua +++ b/test/functional/lua/treesitter_spec.lua @@ -127,6 +127,30 @@ void ui_refresh(void) } }]] + it('allows to iterate over nodes children', function() + if not check_parser() then return end + + insert(test_text); + + local res = exec_lua([[ + parser = vim.treesitter.get_parser(0, "c") + + func_node = parser:parse():root():child(0) + + res = {} + for node, field in func_node:iter_children() do + table.insert(res, {node:type(), field}) + end + return res + ]]) + + eq({ + {"primitive_type", "type"}, + {"function_declarator", "declarator"}, + {"compound_statement", "body"} + }, res) + end) + local query = [[ ((call_expression function: (identifier) @minfunc (argument_list (identifier) @min_id)) (eq? @minfunc "MIN")) "for" @keyword |