diff options
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 12 | ||||
-rw-r--r-- | test/functional/treesitter/utils_spec.lua | 8 |
2 files changed, 8 insertions, 12 deletions
diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index 809ea59b94..baf47482a8 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -159,16 +159,8 @@ function M.is_ancestor(dest, source) return false end - local current = source ---@type TSNode? - while current ~= nil do - if current == dest then - return true - end - - current = current:parent() - end - - return false + -- child_containing_descendant returns nil if dest is a direct parent + return source:parent() == dest or dest:child_containing_descendant(source) ~= nil end --- Returns the node's range or an unpacked range table diff --git a/test/functional/treesitter/utils_spec.lua b/test/functional/treesitter/utils_spec.lua index e079a7c8e7..34bea349f6 100644 --- a/test/functional/treesitter/utils_spec.lua +++ b/test/functional/treesitter/utils_spec.lua @@ -21,12 +21,16 @@ describe('treesitter utils', function() local parser = vim.treesitter.get_parser(0, 'c') local tree = parser:parse()[1] local root = tree:root() - _G.ancestor = root:child(0) - _G.child = _G.ancestor:child(0) + _G.ancestor = assert(root:child(0)) + _G.child = assert(_G.ancestor:named_child(1)) + _G.child_sibling = assert(_G.ancestor:named_child(2)) + _G.grandchild = assert(_G.child:named_child(0)) end) eq(true, exec_lua('return vim.treesitter.is_ancestor(_G.ancestor, _G.child)')) + eq(true, exec_lua('return vim.treesitter.is_ancestor(_G.ancestor, _G.grandchild)')) eq(false, exec_lua('return vim.treesitter.is_ancestor(_G.child, _G.ancestor)')) + eq(false, exec_lua('return vim.treesitter.is_ancestor(_G.child, _G.child_sibling)')) end) it('can detect if a position is contained in a node', function() |