diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 56 |
1 files changed, 31 insertions, 25 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index ca27a50c6a..494fb59fa7 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -28,21 +28,27 @@ end --- Gets the text corresponding to a given node -- @param node the node -- @param bufnr the buffer from which the node in extracted. -function M.get_node_text(node, bufnr) - local start_row, start_col, end_row, end_col = node:range() - if start_row ~= end_row then - return nil +function M.get_node_text(node, source) + local start_row, start_col, start_byte = node:start() + local end_row, end_col, end_byte = node:end_() + + if type(source) == "number" then + if start_row ~= end_row then + return nil + end + local line = a.nvim_buf_get_lines(source, start_row, start_row+1, true)[1] + return string.sub(line, start_col+1, end_col) + elseif type(source) == "string" then + return source:sub(start_byte+1, end_byte) end - local line = a.nvim_buf_get_lines(bufnr, start_row, start_row+1, true)[1] - return string.sub(line, start_col+1, end_col) end -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) local predicate_handlers = { - ["eq?"] = function(match, _, bufnr, predicate) + ["eq?"] = function(match, _, source, predicate) local node = match[predicate[2]] - local node_text = M.get_node_text(node, bufnr) + local node_text = M.get_node_text(node, source) local str if type(predicate[3]) == "string" then @@ -50,7 +56,7 @@ local predicate_handlers = { str = predicate[3] else -- (#eq? @aa @bb) - str = M.get_node_text(match[predicate[3]], bufnr) + str = M.get_node_text(match[predicate[3]], source) end if node_text ~= str or str == nil then @@ -60,7 +66,7 @@ local predicate_handlers = { return true end, - ["lua-match?"] = function(match, _, bufnr, predicate) + ["lua-match?"] = function(match, _, source, predicate) local node = match[predicate[2]] local regex = predicate[3] local start_row, _, end_row, _ = node:range() @@ -68,7 +74,7 @@ local predicate_handlers = { return false end - return string.find(M.get_node_text(node, bufnr), regex) + return string.find(M.get_node_text(node, source), regex) end, ["match?"] = (function() @@ -88,7 +94,7 @@ local predicate_handlers = { end }) - return function(match, _, bufnr, pred) + return function(match, _, source, pred) local node = match[pred[2]] local start_row, start_col, end_row, end_col = node:range() if start_row ~= end_row then @@ -96,13 +102,13 @@ local predicate_handlers = { end local regex = compiled_vim_regexes[pred[3]] - return regex:match_line(bufnr, start_row, start_col, end_col) + return regex:match_line(source, start_row, start_col, end_col) end end)(), - ["contains?"] = function(match, _, bufnr, predicate) + ["contains?"] = function(match, _, source, predicate) local node = match[predicate[2]] - local node_text = M.get_node_text(node, bufnr) + local node_text = M.get_node_text(node, source) for i=3,#predicate do if string.find(node_text, predicate[i], 1, true) then @@ -139,7 +145,7 @@ local function xor(x, y) return (x or y) and not (x and y) end -function Query:match_preds(match, pattern, bufnr) +function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] for _, pred in pairs(preds or {}) do @@ -164,7 +170,7 @@ function Query:match_preds(match, pattern, bufnr) return false end - local pred_matches = handler(match, pattern, bufnr, pred) + local pred_matches = handler(match, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -182,15 +188,15 @@ end -- -- @returns The matching capture id -- @returns The captured node -function Query:iter_captures(node, bufnr, start, stop) - if bufnr == 0 then - bufnr = vim.api.nvim_get_current_buf() +function Query:iter_captures(node, source, start, stop) + if type(source) == "number" and source == 0 then + source = vim.api.nvim_get_current_buf() end local raw_iter = node:_rawquery(self.query, true, start, stop) local function iter() local capture, captured_node, match = raw_iter() if match ~= nil then - local active = self:match_preds(match, match.pattern, bufnr) + local active = self:match_preds(match, match.pattern, source) match.active = active if not active then return iter() -- tail call: try next match @@ -210,15 +216,15 @@ end -- -- @returns The matching pattern id -- @returns The matching match -function Query:iter_matches(node, bufnr, start, stop) - if bufnr == 0 then - bufnr = vim.api.nvim_get_current_buf() +function Query:iter_matches(node, source, start, stop) + if type(source) == "number" and source == 0 then + source = vim.api.nvim_get_current_buf() end local raw_iter = node:_rawquery(self.query, false, start, stop) local function iter() local pattern, match = raw_iter() if match ~= nil then - local active = self:match_preds(match, pattern, bufnr) + local active = self:match_preds(match, pattern, source) if not active then return iter() -- tail call: try next match end |