diff options
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 46 | ||||
-rw-r--r-- | test/functional/lua/treesitter_spec.lua | 36 |
2 files changed, 76 insertions, 6 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index b30bf5fb6b..b43c28b0ab 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -24,7 +24,11 @@ function M.parse_query(lang, query) end -- TODO(vigoux): support multiline nodes too -local function get_node_text(node, bufnr) + +--- Gets the text corresponding to a given node +-- @param node the node +-- @param bufnr the buffer from which the node in extracted. +function M.get_node_text(node, bufnr) local start_row, start_col, end_row, end_col = node:range() if start_row ~= end_row then return nil @@ -34,11 +38,11 @@ local function get_node_text(node, bufnr) end -- Predicate handler receive the following arguments --- (match, pattern, bufnr, regexes, index, predicate) +-- (match, pattern, bufnr, predicate) local predicate_handlers = { ["eq?"] = function(match, _, bufnr, predicate) local node = match[predicate[2]] - local node_text = get_node_text(node, bufnr) + local node_text = M.get_node_text(node, bufnr) local str if type(predicate[3]) == "string" then @@ -46,7 +50,7 @@ local predicate_handlers = { str = predicate[3] else -- (#eq? @aa @bb) - str = get_node_text(match[predicate[3]], bufnr) + str = M.get_node_text(match[predicate[3]], bufnr) end if node_text ~= str or str == nil then @@ -63,12 +67,42 @@ local predicate_handlers = { return false end - return string.find(get_node_text(node, bufnr), regex) + return string.find(M.get_node_text(node, bufnr), regex) end, + ["vim-match?"] = (function() + + local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true} + local function check_magic(str) + if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then + return str + end + return '\\v'..str + end + + local compiled_vim_regexes = setmetatable({}, { + __index = function(t, pattern) + local res = vim.regex(check_magic(pattern)) + rawset(t, pattern, res) + return res + end + }) + + return function(match, _, bufnr, pred) + local node = match[pred[2]] + local start_row, start_col, end_row, end_col = node:range() + if start_row ~= end_row then + return false + end + + local regex = compiled_vim_regexes[pred[3]] + return regex:match_line(bufnr, start_row, start_col, end_col) + end + end)(), + ["contains?"] = function(match, _, bufnr, predicate) local node = match[predicate[2]] - local node_text = get_node_text(node, bufnr) + local node_text = M.get_node_text(node, bufnr) for i=3,#predicate do if string.find(node_text, predicate[i], 1, true) then diff --git a/test/functional/lua/treesitter_spec.lua b/test/functional/lua/treesitter_spec.lua index 808f6aa8db..b0ac9e079a 100644 --- a/test/functional/lua/treesitter_spec.lua +++ b/test/functional/lua/treesitter_spec.lua @@ -198,6 +198,41 @@ void ui_refresh(void) }, res) end) + it('allows to add predicates', function() + insert([[ + int main(void) { + return 0; + } + ]]) + + local custom_query = "((identifier) @main (#is-main? @main))" + + local res = exec_lua([[ + local query = require"vim.treesitter.query" + + local function is_main(match, pattern, bufnr, predicate) + local node = match[ predicate[2] ] + + return query.get_node_text(node, bufnr) + end + + local parser = vim.treesitter.get_parser(0, "c") + + query.add_predicate("is-main?", is_main) + + local query = query.parse_query("c", ...) + + local nodes = {} + for _, node in query:iter_captures(parser:parse():root(), 0, 0, 19) do + table.insert(nodes, {node:range()}) + end + + return nodes + ]], custom_query) + + eq({{0, 4, 0, 8}}, res) + end) + it('supports highlighting', function() if not check_parser() then return end @@ -246,6 +281,7 @@ static int nlua_schedule(lua_State *const lstate) ; Use lua regexes ((identifier) @Identifier (#contains? @Identifier "lua_")) ((identifier) @Constant (#match? @Constant "^[A-Z_]+$")) +((identifier) @Normal (#vim-match? @Constant "^lstate$")) ((binary_expression left: (identifier) @WarningMsg.left right: (identifier) @WarningMsg.right) (#eq? @WarningMsg.left @WarningMsg.right)) |