diff options
Diffstat (limited to 'runtime/lua/vim')
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 22 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 56 |
2 files changed, 52 insertions, 26 deletions
diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index 3a475b8f98..77bbfaa3ad 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -21,7 +21,9 @@ function Parser:parse() return self.tree end local changes - self.tree, changes = self._parser:parse_buf(self.bufnr) + + self.tree, changes = self._parser:parse(self:input_source()) + self.valid = true if not vim.tbl_isempty(changes) then @@ -33,6 +35,10 @@ function Parser:parse() return self.tree, changes end +function Parser:input_source() + return self.bufnr or self.str +end + function Parser:_on_bytes(bufnr, changed_tick, start_row, start_col, start_byte, old_row, old_col, old_byte, @@ -152,4 +158,18 @@ function M.get_parser(bufnr, lang, buf_attach_cbs) return parsers[id] end +function M.get_string_parser(str, lang) + vim.validate { + str = { str, 'string' }, + lang = { lang, 'string' } + } + language.require_language(lang) + + local self = setmetatable({str=str, lang=lang, valid=false}, Parser) + self._parser = vim._create_ts_parser(lang) + self:parse() + + return self +end + return M diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index ca27a50c6a..494fb59fa7 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -28,21 +28,27 @@ end --- Gets the text corresponding to a given node -- @param node the node -- @param bufnr the buffer from which the node in extracted. -function M.get_node_text(node, bufnr) - local start_row, start_col, end_row, end_col = node:range() - if start_row ~= end_row then - return nil +function M.get_node_text(node, source) + local start_row, start_col, start_byte = node:start() + local end_row, end_col, end_byte = node:end_() + + if type(source) == "number" then + if start_row ~= end_row then + return nil + end + local line = a.nvim_buf_get_lines(source, start_row, start_row+1, true)[1] + return string.sub(line, start_col+1, end_col) + elseif type(source) == "string" then + return source:sub(start_byte+1, end_byte) end - local line = a.nvim_buf_get_lines(bufnr, start_row, start_row+1, true)[1] - return string.sub(line, start_col+1, end_col) end -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) local predicate_handlers = { - ["eq?"] = function(match, _, bufnr, predicate) + ["eq?"] = function(match, _, source, predicate) local node = match[predicate[2]] - local node_text = M.get_node_text(node, bufnr) + local node_text = M.get_node_text(node, source) local str if type(predicate[3]) == "string" then @@ -50,7 +56,7 @@ local predicate_handlers = { str = predicate[3] else -- (#eq? @aa @bb) - str = M.get_node_text(match[predicate[3]], bufnr) + str = M.get_node_text(match[predicate[3]], source) end if node_text ~= str or str == nil then @@ -60,7 +66,7 @@ local predicate_handlers = { return true end, - ["lua-match?"] = function(match, _, bufnr, predicate) + ["lua-match?"] = function(match, _, source, predicate) local node = match[predicate[2]] local regex = predicate[3] local start_row, _, end_row, _ = node:range() @@ -68,7 +74,7 @@ local predicate_handlers = { return false end - return string.find(M.get_node_text(node, bufnr), regex) + return string.find(M.get_node_text(node, source), regex) end, ["match?"] = (function() @@ -88,7 +94,7 @@ local predicate_handlers = { end }) - return function(match, _, bufnr, pred) + return function(match, _, source, pred) local node = match[pred[2]] local start_row, start_col, end_row, end_col = node:range() if start_row ~= end_row then @@ -96,13 +102,13 @@ local predicate_handlers = { end local regex = compiled_vim_regexes[pred[3]] - return regex:match_line(bufnr, start_row, start_col, end_col) + return regex:match_line(source, start_row, start_col, end_col) end end)(), - ["contains?"] = function(match, _, bufnr, predicate) + ["contains?"] = function(match, _, source, predicate) local node = match[predicate[2]] - local node_text = M.get_node_text(node, bufnr) + local node_text = M.get_node_text(node, source) for i=3,#predicate do if string.find(node_text, predicate[i], 1, true) then @@ -139,7 +145,7 @@ local function xor(x, y) return (x or y) and not (x and y) end -function Query:match_preds(match, pattern, bufnr) +function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] for _, pred in pairs(preds or {}) do @@ -164,7 +170,7 @@ function Query:match_preds(match, pattern, bufnr) return false end - local pred_matches = handler(match, pattern, bufnr, pred) + local pred_matches = handler(match, pattern, source, pred) if not xor(is_not, pred_matches) then return false @@ -182,15 +188,15 @@ end -- -- @returns The matching capture id -- @returns The captured node -function Query:iter_captures(node, bufnr, start, stop) - if bufnr == 0 then - bufnr = vim.api.nvim_get_current_buf() +function Query:iter_captures(node, source, start, stop) + if type(source) == "number" and source == 0 then + source = vim.api.nvim_get_current_buf() end local raw_iter = node:_rawquery(self.query, true, start, stop) local function iter() local capture, captured_node, match = raw_iter() if match ~= nil then - local active = self:match_preds(match, match.pattern, bufnr) + local active = self:match_preds(match, match.pattern, source) match.active = active if not active then return iter() -- tail call: try next match @@ -210,15 +216,15 @@ end -- -- @returns The matching pattern id -- @returns The matching match -function Query:iter_matches(node, bufnr, start, stop) - if bufnr == 0 then - bufnr = vim.api.nvim_get_current_buf() +function Query:iter_matches(node, source, start, stop) + if type(source) == "number" and source == 0 then + source = vim.api.nvim_get_current_buf() end local raw_iter = node:_rawquery(self.query, false, start, stop) local function iter() local pattern, match = raw_iter() if match ~= nil then - local active = self:match_preds(match, pattern, bufnr) + local active = self:match_preds(match, pattern, source) if not active then return iter() -- tail call: try next match end |