aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author vanaigr <vanaigranov@gmail.com> 2024-05-16 09:57:58 -0500
committer GitHub <noreply@github.com> 2024-05-16 16:57:58 +0200
commit 4b029163345333a2c6975cd0dace6613b036ae47 (patch)
tree ef4b6f914f6415d019e0a6fb64f1e8fd355aa5bc
parent 31dc6279693886a628119cd6c779e580faab32fd (diff)
download rneovim-4b029163345333a2c6975cd0dace6613b036ae47.tar.gz
rneovim-4b029163345333a2c6975cd0dace6613b036ae47.tar.bz2
rneovim-4b029163345333a2c6975cd0dace6613b036ae47.zip
perf(treesitter): use child_containing_descendant() in has-ancestor? (#28512)
Problem: `has-ancestor?` is O(n²) in the depth of the tree, since it iterates over each of the node's ancestors (bottom-up) and fetching each ancestor takes O(n) time. This happens because tree-sitter's nodes don't store their parent nodes, so the tree is searched (top-down) each time a new parent is requested.

Solution: Make use of the new `ts_node_child_containing_descendant()` in tree-sitter v0.22.6 (which is now the minimum required version) to rewrite the `has-ancestor?` predicate in C so that it becomes O(n). For a sample file, this decreases the time taken by `has-ancestor?` from 360ms to 6ms.
-rw-r--r-- runtime/doc/treesitter.txt 5
-rw-r--r-- runtime/lua/vim/treesitter/_meta.lua 1
-rw-r--r-- runtime/lua/vim/treesitter/query.lua 13
-rw-r--r-- src/nvim/CMakeLists.txt 2
-rw-r--r-- src/nvim/lua/treesitter.c 45
-rw-r--r-- test/functional/treesitter/node_spec.lua 26
-rw-r--r-- test/functional/treesitter/query_spec.lua 64
7 files changed, 129 insertions, 27 deletions
diff --git a/runtime/doc/treesitter.txt b/runtime/doc/treesitter.txt
index e105d06ebb..0b84bb60d4 100644
--- a/runtime/doc/treesitter.txt
+++ b/runtime/doc/treesitter.txt
@@ -78,6 +78,8 @@ An instance `TSNode` of a treesitter node supports the following methods.
TSNode:parent() *TSNode:parent()*
Get the node's immediate parent.
+ Prefer |TSNode:child_containing_descendant()|
+ for iterating over the node's ancestors.
TSNode:next_sibling() *TSNode:next_sibling()*
Get the node's next sibling.
@@ -114,6 +116,9 @@ TSNode:named_child({index}) *TSNode:named_child()*
Get the node's named child at the given {index}, where zero represents the
first named child.
+TSNode:child_containing_descendant({descendant}) *TSNode:child_containing_descendant()*
+ Get the node's child that contains {descendant}.
+
TSNode:start() *TSNode:start()*
Get the node's start position. Return three values: the row, column and
total byte count (all zero-based).
diff --git a/runtime/lua/vim/treesitter/_meta.lua b/runtime/lua/vim/treesitter/_meta.lua
index 34a51e42f6..177699a207 100644
--- a/runtime/lua/vim/treesitter/_meta.lua
+++ b/runtime/lua/vim/treesitter/_meta.lua
@@ -20,6 +20,7 @@ error('Cannot require a meta file')
---@field descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode?
---@field named_descendant_for_range fun(self: TSNode, start_row: integer, start_col: integer, end_row: integer, end_col: integer): TSNode?
---@field parent fun(self: TSNode): TSNode?
+---@field child_containing_descendant fun(self: TSNode, descendant: TSNode): TSNode?
---@field next_sibling fun(self: TSNode): TSNode?
---@field prev_sibling fun(self: TSNode): TSNode?
---@field next_named_sibling fun(self: TSNode): TSNode?
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 36c78b7f1d..ef5c2143a7 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -457,17 +457,8 @@ local predicate_handlers = {
end
for _, node in ipairs(nodes) do
- local ancestor_types = {} --- @type table<string, boolean>
- for _, type in ipairs({ unpack(predicate, 3) }) do
- ancestor_types[type] = true
- end
-
- local cur = node:parent()
- while cur do
- if ancestor_types[cur:type()] then
- return true
- end
- cur = cur:parent()
+ if node:__has_ancestor(predicate) then
+ return true
end
end
return false
diff --git a/src/nvim/CMakeLists.txt b/src/nvim/CMakeLists.txt
index d9cc695c55..937cfaaa31 100644
--- a/src/nvim/CMakeLists.txt
+++ b/src/nvim/CMakeLists.txt
@@ -33,7 +33,7 @@ find_package(Libuv 1.28.0 REQUIRED)
find_package(Libvterm 0.3.3 REQUIRED)
find_package(Lpeg REQUIRED)
find_package(Msgpack 1.0.0 REQUIRED)
-find_package(Treesitter 0.20.9 REQUIRED)
+find_package(Treesitter 0.22.6 REQUIRED)
find_package(Unibilium 2.0 REQUIRED)
target_link_libraries(main_lib INTERFACE
diff --git a/src/nvim/lua/treesitter.c b/src/nvim/lua/treesitter.c
index 8befc6d32d..e87cf756a8 100644
--- a/src/nvim/lua/treesitter.c
+++ b/src/nvim/lua/treesitter.c
@@ -725,6 +725,8 @@ static struct luaL_Reg node_meta[] = {
{ "descendant_for_range", node_descendant_for_range },
{ "named_descendant_for_range", node_named_descendant_for_range },
{ "parent", node_parent },
+ { "__has_ancestor", __has_ancestor },
+ { "child_containing_descendant", node_child_containing_descendant },
{ "iter_children", node_iter_children },
{ "next_sibling", node_next_sibling },
{ "prev_sibling", node_prev_sibling },
@@ -1052,6 +1054,49 @@ static int node_parent(lua_State *L)
return 1;
}
+static int __has_ancestor(lua_State *L)
+{
+ TSNode descendant = node_check(L, 1);
+ if (lua_type(L, 2) != LUA_TTABLE) {
+ lua_pushboolean(L, false);
+ return 1;
+ }
+ int const pred_len = (int)lua_objlen(L, 2);
+
+ TSNode node = ts_tree_root_node(descendant.tree);
+ while (!ts_node_is_null(node)) {
+ char const *node_type = ts_node_type(node);
+ size_t node_type_len = strlen(node_type);
+
+ for (int i = 3; i <= pred_len; i++) {
+ lua_rawgeti(L, 2, i);
+ if (lua_type(L, -1) == LUA_TSTRING) {
+ size_t check_len;
+ char const *check_str = lua_tolstring(L, -1, &check_len);
+ if (node_type_len == check_len && memcmp(node_type, check_str, check_len) == 0) {
+ lua_pushboolean(L, true);
+ return 1;
+ }
+ }
+ lua_pop(L, 1);
+ }
+
+ node = ts_node_child_containing_descendant(node, descendant);
+ }
+
+ lua_pushboolean(L, false);
+ return 1;
+}
+
+static int node_child_containing_descendant(lua_State *L)
+{
+ TSNode node = node_check(L, 1);
+ TSNode descendant = node_check(L, 2);
+ TSNode child = ts_node_child_containing_descendant(node, descendant);
+ push_node(L, child, 1);
+ return 1;
+}
+
static int node_next_sibling(lua_State *L)
{
TSNode node = node_check(L, 1);
diff --git a/test/functional/treesitter/node_spec.lua b/test/functional/treesitter/node_spec.lua
index 8adec82774..96579f296b 100644
--- a/test/functional/treesitter/node_spec.lua
+++ b/test/functional/treesitter/node_spec.lua
@@ -143,4 +143,30 @@ describe('treesitter node API', function()
eq(28, lua_eval('root:byte_length()'))
eq(3, lua_eval('child:byte_length()'))
end)
+
+ it('child_containing_descendant() works', function()
+ insert([[
+ int main() {
+ int x = 3;
+ }]])
+
+ exec_lua([[
+ tree = vim.treesitter.get_parser(0, "c"):parse()[1]
+ root = tree:root()
+ main = root:child(0)
+ body = main:child(2)
+ statement = body:child(1)
+ declarator = statement:child(1)
+ value = declarator:child(1)
+ ]])
+
+ eq(lua_eval('main:type()'), lua_eval('root:child_containing_descendant(value):type()'))
+ eq(lua_eval('body:type()'), lua_eval('main:child_containing_descendant(value):type()'))
+ eq(lua_eval('statement:type()'), lua_eval('body:child_containing_descendant(value):type()'))
+ eq(
+ lua_eval('declarator:type()'),
+ lua_eval('statement:child_containing_descendant(value):type()')
+ )
+ eq(vim.NIL, lua_eval('declarator:child_containing_descendant(value)'))
+ end)
end)
diff --git a/test/functional/treesitter/query_spec.lua b/test/functional/treesitter/query_spec.lua
index 96665ee2e7..c3a376cd71 100644
--- a/test/functional/treesitter/query_spec.lua
+++ b/test/functional/treesitter/query_spec.lua
@@ -10,6 +10,22 @@ local pcall_err = t.pcall_err
local api = n.api
local fn = n.fn
+local get_query_result_code = [[
+ function get_query_result(query_text)
+ cquery = vim.treesitter.query.parse("c", query_text)
+ parser = vim.treesitter.get_parser(0, "c")
+ tree = parser:parse()[1]
+ res = {}
+ for cid, node in cquery:iter_captures(tree:root(), 0) do
+ -- can't transmit node over RPC. just check the name, range, and text
+ local text = vim.treesitter.get_node_text(node, 0)
+ local range = {node:range()}
+ table.insert(res, { cquery.captures[cid], node:type(), range, text })
+ end
+ return res
+ end
+]]
+
describe('treesitter query API', function()
before_each(function()
clear()
@@ -291,21 +307,7 @@ void ui_refresh(void)
return 0;
}
]])
- exec_lua([[
- function get_query_result(query_text)
- cquery = vim.treesitter.query.parse("c", query_text)
- parser = vim.treesitter.get_parser(0, "c")
- tree = parser:parse()[1]
- res = {}
- for cid, node in cquery:iter_captures(tree:root(), 0) do
- -- can't transmit node over RPC. just check the name, range, and text
- local text = vim.treesitter.get_node_text(node, 0)
- local range = {node:range()}
- table.insert(res, { cquery.captures[cid], node:type(), range, text })
- end
- return res
- end
- ]])
+ exec_lua(get_query_result_code)
local res0 = exec_lua(
[[return get_query_result(...)]],
@@ -333,6 +335,38 @@ void ui_refresh(void)
}, res1)
end)
+ it('supports builtin predicate has-ancestor?', function()
+ insert([[
+ int x = 123;
+ enum C { y = 124 };
+ int main() { int z = 125; }]])
+ exec_lua(get_query_result_code)
+
+ local result = exec_lua(
+ [[return get_query_result(...)]],
+ [[((number_literal) @literal (#has-ancestor? @literal "function_definition"))]]
+ )
+ eq({ { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' } }, result)
+
+ result = exec_lua(
+ [[return get_query_result(...)]],
+ [[((number_literal) @literal (#has-ancestor? @literal "function_definition" "enum_specifier"))]]
+ )
+ eq({
+ { 'literal', 'number_literal', { 1, 13, 1, 16 }, '124' },
+ { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' },
+ }, result)
+
+ result = exec_lua(
+ [[return get_query_result(...)]],
+ [[((number_literal) @literal (#not-has-ancestor? @literal "enum_specifier"))]]
+ )
+ eq({
+ { 'literal', 'number_literal', { 0, 8, 0, 11 }, '123' },
+ { 'literal', 'number_literal', { 2, 21, 2, 24 }, '125' },
+ }, result)
+ end)
+
it('allows loading query with escaped quotes and capture them `#{lua,vim}-match`?', function()
insert('char* astring = "Hello World!";')