diff options
Diffstat (limited to 'runtime/lua/vim/treesitter.lua')
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index 70f2c425ed..6431162799 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -118,4 +118,127 @@ function M.get_string_parser(str, lang, opts) return LanguageTree.new(str, lang, opts) end +--- Determines whether a node is the ancestor of another +--- +---@param dest table the possible ancestor +---@param source table the possible descendant node +--- +---@returns (boolean) True if dest is an ancestor of source +function M.is_ancestor(dest, source) + if not (dest and source) then + return false + end + + local current = source + while current ~= nil do + if current == dest then + return true + end + + current = current:parent() + end + + return false +end + +--- Get the node's range or unpack a range table +--- +---@param node_or_range table +--- +---@returns start_row, start_col, end_row, end_col +function M.get_node_range(node_or_range) + if type(node_or_range) == 'table' then + return unpack(node_or_range) + else + return node_or_range:range() + end +end + +---Determines whether (line, col) position is in node range +--- +---@param node Node defining the range +---@param line A line (0-based) +---@param col A column (0-based) +function M.is_in_node_range(node, line, col) + local start_line, start_col, end_line, end_col = M.get_node_range(node) + if line >= start_line and line <= end_line then + if line == start_line and line == end_line then + return col >= start_col and col < end_col + elseif line == start_line then + return col >= start_col + elseif line == end_line then + return col < end_col + else + return true + end + else + return false + end +end + +---Determines if a node contains a range +---@param node table The node +---@param range table The range +--- +---@returns (boolean) True if the node contains the range +function M.node_contains(node, range) + local start_row, start_col, end_row, end_col = node:range() + local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2]) + local end_fits = end_row > range[3] or (end_row == range[3] and end_col >= range[4]) + + return start_fits and end_fits +end + +---Gets a list of captures for a given cursor position +---@param bufnr number The buffer number +---@param row number The position row +---@param col number The position column +--- +---@returns (table) A table of captures +function M.get_captures_at_position(bufnr, row, col) + if bufnr == 0 then + bufnr = a.nvim_get_current_buf() + end + local buf_highlighter = M.highlighter.active[bufnr] + + if not buf_highlighter then + return {} + end + + local matches = {} + + buf_highlighter.tree:for_each_tree(function(tstree, tree) + if not tstree then + return + end + + local root = tstree:root() + local root_start_row, _, root_end_row, _ = root:range() + + -- Only worry about trees within the line range + if root_start_row > row or root_end_row < row then + return + end + + local q = buf_highlighter:get_query(tree:lang()) + + -- Some injected languages may not have highlight queries. + if not q:query() then + return + end + + local iter = q:query():iter_captures(root, buf_highlighter.bufnr, row, row + 1) + + for capture, node, metadata in iter do + if M.is_in_node_range(node, row, col) then + local c = q._query.captures[capture] -- name of the capture in the query + if c ~= nil then + table.insert(matches, { capture = c, priority = metadata.priority }) + end + end + end + end, true) + return matches +end + return M |