diff options
author | vanaigr <vanaigranov@gmail.com> | 2024-05-16 09:57:58 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-16 16:57:58 +0200 |
commit | 4b029163345333a2c6975cd0dace6613b036ae47 (patch) | |
tree | ef4b6f914f6415d019e0a6fb64f1e8fd355aa5bc | |
parent | 31dc6279693886a628119cd6c779e580faab32fd (diff) | |
download | rneovim-4b029163345333a2c6975cd0dace6613b036ae47.tar.gz rneovim-4b029163345333a2c6975cd0dace6613b036ae47.tar.bz2 rneovim-4b029163345333a2c6975cd0dace6613b036ae47.zip |
perf(treesitter): use child_containing_descendant() in has-ancestor? (#28512)
Problem: `has-ancestor?` is O(n²) in the depth of the tree, since it iterates over each of the node's ancestors (bottom-up) and retrieving each ancestor takes O(n) time.
This happens because tree-sitter's nodes don't store their parent nodes, and the tree is searched (top-down) each time a new parent is requested.
Solution: Use the new `ts_node_child_containing_descendant()` from tree-sitter v0.22.6 (which is now the minimum required version) to rewrite the `has-ancestor?` predicate in C so that it runs in O(n).
For a sample file, this decreases the time taken by `has-ancestor?` from 360ms to 6ms.
-rw-r--r-- | runtime/doc/treesitter.txt | 5 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/_meta.lua | 1 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 13 | ||||
-rw-r--r-- | src/nvim/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/nvim/lua/treesitter.c | 45 | ||||
-rw-r--r-- | test/functional/treesitter/node_spec.lua | 26 | ||||
-rw-r--r-- | test/functional/treesitter/query_spec.lua | 64 |
7 files changed, 129 insertions, 27 deletions
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt index e105d06ebb..0b84bb60d4 100644 --- a/runtime/doc/treesitter.txt +++ b/runtime/doc/treesitter.txt @@ -78,6 +78,8 @@ An instance `TSNode` of a treesitter node supports the following methods. TSNode:parent() *TSNode:parent()* Get the node's immediate parent. + Prefer |TSNode:child_containing_descendant()| + for iterating over the node's ancestors. TSNode:next_sibling() *TSNode:next_sibling()* Get the node's next sibling. @@ -114,6 +116,9 @@ TSNode:named_child({index}) *TSNode:named_child()* Get the node's named child at the given {index}, where zero represents the first named child. +TSNode:child_containing_descendant({descendant}) *TSNode:child_containing_descendant()* + Get the node's child that contains {descendant}. + TSNode:start() *TSNode:start()* Get the node's start position. Return three values: the row, column and total byte count (all zero-based). diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua index 34a51e42f6..177699a207 100644 --- a/runtime/lua/vim/treesitter/_meta.lua +++ b/runtime/lua/vim/treesitter/_meta.lua @@ -20,6 +20,7 @@ error('Cannot require a meta file') ---@field descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode? ---@field named_descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode? ---@field parent fun(self: TSNode): TSNode? +---@field child_containing_descendant fun(self: TSNode, descendant: TSNode): TSNode? ---@field next_sibling fun(self: TSNode): TSNode? ---@field prev_sibling fun(self: TSNode): TSNode? ---@field next_named_sibling fun(self: TSNode): TSNode? 
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 36c78b7f1d..ef5c2143a7 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -457,17 +457,8 @@ local predicate_handlers = { end for _, node in ipairs(nodes) do - local ancestor_types = {} --- @type table<string, boolean> - for _, type in ipairs({ unpack(predicate, 3) }) do - ancestor_types[type] = true - end - - local cur = node:parent() - while cur do - if ancestor_types[cur:type()] then - return true - end - cur = cur:parent() + if node:__has_ancestor(predicate) then + return true end end return false diff --git a/src/nvim/CMakeLists.txt b/src/nvim/CMakeLists.txt index d9cc695c55..937cfaaa31 100644 --- a/src/nvim/CMakeLists.txt +++ b/src/nvim/CMakeLists.txt @@ -33,7 +33,7 @@ find_package(Libuv 1.28.0 REQUIRED) find_package(Libvterm 0.3.3 REQUIRED) find_package(Lpeg REQUIRED) find_package(Msgpack 1.0.0 REQUIRED) -find_package(Treesitter 0.20.9 REQUIRED) +find_package(Treesitter 0.22.6 REQUIRED) find_package(Unibilium 2.0 REQUIRED) target_link_libraries(main_lib INTERFACE diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c index 8befc6d32d..e87cf756a8 100644 --- a/src/nvim/lua/treesitter.c +++ b/src/nvim/lua/treesitter.c @@ -725,6 +725,8 @@ static struct luaL_Reg node_meta[] = { { "descendant_for_range", node_descendant_for_range }, { "named_descendant_for_range", node_named_descendant_for_range }, { "parent", node_parent }, + { "__has_ancestor", __has_ancestor }, + { "child_containing_descendant", node_child_containing_descendant }, { "iter_children", node_iter_children }, { "next_sibling", node_next_sibling }, { "prev_sibling", node_prev_sibling }, @@ -1052,6 +1054,49 @@ static int node_parent(lua_State *L) return 1; } +static int __has_ancestor(lua_State *L) +{ + TSNode descendant = node_check(L, 1); + if (lua_type(L, 2) != LUA_TTABLE) { + lua_pushboolean(L, false); + return 1; + } + int const pred_len = 
(int)lua_objlen(L, 2); + + TSNode node = ts_tree_root_node(descendant.tree); + while (!ts_node_is_null(node)) { + char const *node_type = ts_node_type(node); + size_t node_type_len = strlen(node_type); + + for (int i = 3; i <= pred_len; i++) { + lua_rawgeti(L, 2, i); + if (lua_type(L, -1) == LUA_TSTRING) { + size_t check_len; + char const *check_str = lua_tolstring(L, -1, &check_len); + if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) { + lua_pushboolean(L, true); + return 1; + } + } + lua_pop(L, 1); + } + + node = ts_node_child_containing_descendant(node, descendant); + } + + lua_pushboolean(L, false); + return 1; +} + +static int node_child_containing_descendant(lua_State *L) +{ + TSNode node = node_check(L, 1); + TSNode descendant = node_check(L, 2); + TSNode child = ts_node_child_containing_descendant(node, descendant); + push_node(L, child, 1); + return 1; +} + static int node_next_sibling(lua_State *L) { TSNode node = node_check(L, 1); diff --git a/test/functional/treesitter/node_spec.lua b/test/functional/treesitter/node_spec.lua index 8adec82774..96579f296b 100644 --- a/test/functional/treesitter/node_spec.lua +++ b/test/functional/treesitter/node_spec.lua @@ -143,4 +143,30 @@ describe('treesitter node API', function() eq(28, lua_eval('root:byte_length()')) eq(3, lua_eval('child:byte_length()')) end) + + it('child_containing_descendant() works', function() + insert([[ + int main() { + int x = 3; + }]]) + + exec_lua([[ + tree = vim.treesitter.get_parser(0, "c"):parse()[1] + root = tree:root() + main = root:child(0) + body = main:child(2) + statement = body:child(1) + declarator = statement:child(1) + value = declarator:child(1) + ]]) + + eq(lua_eval('main:type()'), lua_eval('root:child_containing_descendant(value):type()')) + eq(lua_eval('body:type()'), lua_eval('main:child_containing_descendant(value):type()')) + eq(lua_eval('statement:type()'), lua_eval('body:child_containing_descendant(value):type()')) + eq( + 
lua_eval('declarator:type()'), + lua_eval('statement:child_containing_descendant(value):type()') + ) + eq(vim.NIL, lua_eval('declarator:child_containing_descendant(value)')) + end) end) diff --git a/test/functional/treesitter/query_spec.lua b/test/functional/treesitter/query_spec.lua index 96665ee2e7..c3a376cd71 100644 --- a/test/functional/treesitter/query_spec.lua +++ b/test/functional/treesitter/query_spec.lua @@ -10,6 +10,22 @@ local pcall_err = t.pcall_err local api = n.api local fn = n.fn +local get_query_result_code = [[ + function get_query_result(query_text) + cquery = vim.treesitter.query.parse("c", query_text) + parser = vim.treesitter.get_parser(0, "c") + tree = parser:parse()[1] + res = {} + for cid, node in cquery:iter_captures(tree:root(), 0) do + -- can't transmit node over RPC. just check the name, range, and text + local text = vim.treesitter.get_node_text(node, 0) + local range = {node:range()} + table.insert(res, { cquery.captures[cid], node:type(), range, text }) + end + return res + end +]] + describe('treesitter query API', function() before_each(function() clear() @@ -291,21 +307,7 @@ void ui_refresh(void) return 0; } ]]) - exec_lua([[ - function get_query_result(query_text) - cquery = vim.treesitter.query.parse("c", query_text) - parser = vim.treesitter.get_parser(0, "c") - tree = parser:parse()[1] - res = {} - for cid, node in cquery:iter_captures(tree:root(), 0) do - -- can't transmit node over RPC. 
just check the name, range, and text - local text = vim.treesitter.get_node_text(node, 0) - local range = {node:range()} - table.insert(res, { cquery.captures[cid], node:type(), range, text }) - end - return res - end - ]]) + exec_lua(get_query_result_code) local res0 = exec_lua( [[return get_query_result(...)]], @@ -333,6 +335,38 @@ void ui_refresh(void) }, res1) end) + it('supports builtin predicate has-ancestor?', function() + insert([[ + int x = 123; + enum C { y = 124 }; + int main() { int z = 125; }]]) + exec_lua(get_query_result_code) + + local result = exec_lua( + [[return get_query_result(...)]], + [[((number_literal) @literal (#has-ancestor? @literal "function_definition"))]] + ) + eq({ { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' } }, result) + + result = exec_lua( + [[return get_query_result(...)]], + [[((number_literal) @literal (#has-ancestor? @literal "function_definition" "enum_specifier"))]] + ) + eq({ + { 'literal', 'number_literal', { 1, 13, 1, 16 }, '124' }, + { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' }, + }, result) + + result = exec_lua( + [[return get_query_result(...)]], + [[((number_literal) @literal (#not-has-ancestor? @literal "enum_specifier"))]] + ) + eq({ + { 'literal', 'number_literal', { 0, 8, 0, 11 }, '123' }, + { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' }, + }, result) + end) + it('allows loading query with escaped quotes and capture them `#{lua,vim}-match`?', function() insert('char* astring = "Hello World!";') |