From 6e6c36ca5bc31de39504a2949da85043d1469db8 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Wed, 17 Nov 2021 15:17:15 +0000 Subject: feat(treesitter): multiline match predicates --- runtime/lua/vim/treesitter/query.lua | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) (limited to 'runtime/lua/vim') diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 5fa45289d8..ebed502c92 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -167,30 +167,31 @@ end --- Gets the text corresponding to a given node --- ---@param node the node ----@param bsource The buffer or string from which the node is extracted +---@param source The buffer or string from which the node is extracted function M.get_node_text(node, source) local start_row, start_col, start_byte = node:start() local end_row, end_col, end_byte = node:end_() if type(source) == "number" then local lines - local eof_row = vim.api.nvim_buf_line_count(source) + local eof_row = a.nvim_buf_line_count(source) if start_row >= eof_row then return nil end + if end_col == 0 then lines = a.nvim_buf_get_lines(source, start_row, end_row, true) - end_col = #lines[#lines] + end_col = -1 else lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true) end - lines[1] = string.sub(lines[1], start_col + 1) - local end_index = end_col if #lines == 1 then - end_index = end_col - start_col + lines[1] = string.sub(lines[1], start_col+1, end_col) + else + lines[1] = string.sub(lines[1], start_col+1) + lines[#lines] = string.sub(lines[#lines], 1, end_col) end - lines[#lines] = string.sub(lines[#lines], 1, end_index) return table.concat(lines, "\n") elseif type(source) == "string" then @@ -247,13 +248,8 @@ local predicate_handlers = { return function(match, _, source, pred) local node = match[pred[2]] - local start_row, start_col, end_row, end_col = node:range() - if start_row ~= end_row then - return false - end - local regex = compiled_vim_regexes[pred[3]] - return regex:match_line(source, start_row, start_col, end_col) + return regex:match_str(M.get_node_text(node, source)) end end)(), -- cgit