diff options
author | Josh Rahm <rahm@google.com> | 2022-07-18 19:37:18 +0000 |
---|---|---|
committer | Josh Rahm <rahm@google.com> | 2022-07-18 19:37:18 +0000 |
commit | 308e1940dcd64aa6c344c403d4f9e0dda58d9c5c (patch) | |
tree | 35fe43e01755e0f312650667004487a44d6b7941 /runtime/lua/vim/treesitter/query.lua | |
parent | 96a00c7c588b2f38a2424aeeb4ea3581d370bf2d (diff) | |
parent | e8c94697bcbe23a5c7b07c292b90a6b70aadfa87 (diff) | |
download | rneovim-308e1940dcd64aa6c344c403d4f9e0dda58d9c5c.tar.gz rneovim-308e1940dcd64aa6c344c403d4f9e0dda58d9c5c.tar.bz2 rneovim-308e1940dcd64aa6c344c403d4f9e0dda58d9c5c.zip |
Merge remote-tracking branch 'upstream/master' into rahm
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 195 |
1 files changed, 118 insertions, 77 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index ebed502c92..103e85abfd 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -1,5 +1,5 @@ local a = vim.api -local language = require'vim.treesitter.language' +local language = require('vim.treesitter.language') -- query: pattern matching on trees -- predicate matching is implemented in lua @@ -43,7 +43,9 @@ function M.get_query_files(lang, query_name, is_included) local query_path = string.format('queries/%s/%s.scm', lang, query_name) local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true)) - if #lang_files == 0 then return {} end + if #lang_files == 0 then + return {} + end local base_langs = {} @@ -52,7 +54,7 @@ function M.get_query_files(lang, query_name, is_included) -- ;+ inherits: ({language},)*{language} -- -- {language} ::= {lang} | ({lang}) - local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$" + local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$' for _, file in ipairs(lang_files) do local modeline = safe_read(file, '*l') @@ -62,7 +64,7 @@ function M.get_query_files(lang, query_name, is_included) if langlist then for _, incllang in ipairs(vim.split(langlist, ',', true)) do - local is_optional = incllang:match("%(.*%)") + local is_optional = incllang:match('%(.*%)') if is_optional then if not is_included then @@ -90,7 +92,7 @@ end local function read_query_files(filenames) local contents = {} - for _,filename in ipairs(filenames) do + for _, filename in ipairs(filenames) do table.insert(contents, safe_read(filename, '*a')) end @@ -138,30 +140,43 @@ function M.get_query(lang, query_name) end end +local query_cache = setmetatable({}, { + __index = function(tbl, key) + rawset(tbl, key, {}) + return rawget(tbl, key) + end, +}) + --- Parse {query} as a string. (If the query is in a file, the caller ---- should read the contents into a string before calling). +--- should read the contents into a string before calling). --- --- Returns a `Query` (see |lua-treesitter-query|) object which can be used to --- search nodes in the syntax tree for the patterns defined in {query} --- using `iter_*` methods below. --- ---- Exposes `info` and `captures` with additional information about the {query}. +--- Exposes `info` and `captures` with additional context about {query}. --- - `captures` contains the list of unique capture names defined in --- {query}. --- -` info.captures` also points to `captures`. --- - `info.patterns` contains information about predicates. --- ----@param lang The language ----@param query A string containing the query (s-expr syntax) +---@param lang string The language +---@param query string A string containing the query (s-expr syntax) --- ---@returns The query function M.parse_query(lang, query) language.require_language(lang) - local self = setmetatable({}, Query) - self.query = vim._ts_parse_query(lang, query) - self.info = self.query:inspect() - self.captures = self.info.captures - return self + local cached = query_cache[lang][query] + if cached then + return cached + else + local self = setmetatable({}, Query) + self.query = vim._ts_parse_query(lang, query) + self.info = self.query:inspect() + self.captures = self.info.captures + query_cache[lang][query] = self + return self + end end --- Gets the text corresponding to a given node @@ -172,7 +187,7 @@ function M.get_node_text(node, source) local start_row, start_col, start_byte = node:start() local end_row, end_col, end_byte = node:end_() - if type(source) == "number" then + if type(source) == 'number' then local lines local eof_row = a.nvim_buf_line_count(source) if start_row >= eof_row then @@ -186,56 +201,64 @@ function M.get_node_text(node, source) lines = a.nvim_buf_get_lines(source, start_row, end_row + 1, true) end - if #lines == 1 then - lines[1] = string.sub(lines[1], start_col+1, end_col) - else - lines[1] = string.sub(lines[1], start_col+1) - lines[#lines] = string.sub(lines[#lines], 1, end_col) + if #lines > 0 then + if #lines == 1 then + lines[1] = string.sub(lines[1], start_col + 1, end_col) + else + lines[1] = string.sub(lines[1], start_col + 1) + lines[#lines] = string.sub(lines[#lines], 1, end_col) + end end - return table.concat(lines, "\n") - elseif type(source) == "string" then - return source:sub(start_byte+1, end_byte) + return table.concat(lines, '\n') + elseif type(source) == 'string' then + return source:sub(start_byte + 1, end_byte) end end -- Predicate handler receive the following arguments -- (match, pattern, bufnr, predicate) local predicate_handlers = { - ["eq?"] = function(match, _, source, predicate) - local node = match[predicate[2]] - local node_text = M.get_node_text(node, source) - - local str - if type(predicate[3]) == "string" then - -- (#eq? @aa "foo") - str = predicate[3] - else - -- (#eq? @aa @bb) - str = M.get_node_text(match[predicate[3]], source) - end + ['eq?'] = function(match, _, source, predicate) + local node = match[predicate[2]] + if not node then + return true + end + local node_text = M.get_node_text(node, source) - if node_text ~= str or str == nil then - return false - end + local str + if type(predicate[3]) == 'string' then + -- (#eq? @aa "foo") + str = predicate[3] + else + -- (#eq? @aa @bb) + str = M.get_node_text(match[predicate[3]], source) + end - return true + if node_text ~= str or str == nil then + return false + end + + return true end, - ["lua-match?"] = function(match, _, source, predicate) - local node = match[predicate[2]] - local regex = predicate[3] - return string.find(M.get_node_text(node, source), regex) + ['lua-match?'] = function(match, _, source, predicate) + local node = match[predicate[2]] + if not node then + return true + end + local regex = predicate[3] + return string.find(M.get_node_text(node, source), regex) end, - ["match?"] = (function() - local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true} + ['match?'] = (function() + local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true } ---@private local function check_magic(str) - if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then + if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then return str end - return '\\v'..str + return '\\v' .. str end local compiled_vim_regexes = setmetatable({}, { @@ -243,21 +266,27 @@ local predicate_handlers = { local res = vim.regex(check_magic(pattern)) rawset(t, pattern, res) return res - end + end, }) return function(match, _, source, pred) local node = match[pred[2]] + if not node then + return true + end local regex = compiled_vim_regexes[pred[3]] return regex:match_str(M.get_node_text(node, source)) end end)(), - ["contains?"] = function(match, _, source, predicate) + ['contains?'] = function(match, _, source, predicate) local node = match[predicate[2]] + if not node then + return true + end local node_text = M.get_node_text(node, source) - for i=3,#predicate do + for i = 3, #predicate do if string.find(node_text, predicate[i], 1, true) then return true end @@ -266,19 +295,22 @@ local predicate_handlers = { return false end, - ["any-of?"] = function(match, _, source, predicate) + ['any-of?'] = function(match, _, source, predicate) local node = match[predicate[2]] + if not node then + return true + end local node_text = M.get_node_text(node, source) -- Since 'predicate' will not be used by callers of this function, use it -- to store a string set built from the list of words to check against. - local string_set = predicate["string_set"] + local string_set = predicate['string_set'] if not string_set then string_set = {} - for i=3,#predicate do + for i = 3, #predicate do string_set[predicate[i]] = true end - predicate["string_set"] = string_set + predicate['string_set'] = string_set end return string_set[node_text] @@ -286,32 +318,33 @@ local predicate_handlers = { } -- As we provide lua-match? also expose vim-match? -predicate_handlers["vim-match?"] = predicate_handlers["match?"] - +predicate_handlers['vim-match?'] = predicate_handlers['match?'] -- Directives store metadata or perform side effects against a match. -- Directives should always end with a `!`. -- Directive handler receive the following arguments -- (match, pattern, bufnr, predicate, metadata) local directive_handlers = { - ["set!"] = function(_, _, _, pred, metadata) + ['set!'] = function(_, _, _, pred, metadata) if #pred == 4 then -- (#set! @capture "key" "value") - local capture = pred[2] - if not metadata[capture] then - metadata[capture] = {} + local _, capture_id, key, value = unpack(pred) + if not metadata[capture_id] then + metadata[capture_id] = {} end - metadata[capture][pred[3]] = pred[4] + metadata[capture_id][key] = value else + local _, key, value = unpack(pred) -- (#set! "key" "value") - metadata[pred[2]] = pred[3] + metadata[key] = value end end, -- Shifts the range of a node. -- Example: (#offset! @_node 0 1 0 -1) - ["offset!"] = function(match, _, _, pred, metadata) - local offset_node = match[pred[2]] - local range = {offset_node:range()} + ['offset!'] = function(match, _, _, pred, metadata) + local capture_id = pred[2] + local offset_node = match[capture_id] + local range = { offset_node:range() } local start_row_offset = pred[3] or 0 local start_col_offset = pred[4] or 0 local end_row_offset = pred[5] or 0 @@ -324,9 +357,12 @@ local directive_handlers = { -- If this produces an invalid range, we just skip it. if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then - metadata.content = {range} + if not metadata[capture_id] then + metadata[capture_id] = {} + end + metadata[capture_id].range = range end - end + end, } --- Adds a new predicate to be used in queries @@ -336,7 +372,7 @@ local directive_handlers = { --- signature will be (match, pattern, bufnr, predicate) function M.add_predicate(name, handler, force) if predicate_handlers[name] and not force then - error(string.format("Overriding %s", name)) + error(string.format('Overriding %s', name)) end predicate_handlers[name] = handler @@ -344,17 +380,23 @@ end --- Adds a new directive to be used in queries --- +--- Handlers can set match level data by setting directly on the +--- metadata object `metadata.key = value`, additionally, handlers +--- can set node level data by using the capture id on the +--- metadata table `metadata[capture_id].key = value` +--- ---@param name the name of the directive, without leading # ---@param handler the handler function to be used ---- signature will be (match, pattern, bufnr, predicate) +--- signature will be (match, pattern, bufnr, predicate, metadata) function M.add_directive(name, handler, force) if directive_handlers[name] and not force then - error(string.format("Overriding %s", name)) + error(string.format('Overriding %s', name)) end directive_handlers[name] = handler end +--- Lists the currently available directives to use in queries. ---@return The list of supported directives. function M.list_directives() return vim.tbl_keys(directive_handlers) @@ -372,7 +414,7 @@ end ---@private local function is_directive(name) - return string.sub(name, -1) == "!" + return string.sub(name, -1) == '!' end ---@private @@ -389,7 +431,7 @@ function Query:match_preds(match, pattern, source) -- Skip over directives... they will get processed after all the predicates. if not is_directive(pred[1]) then - if string.sub(pred[1], 1, 4) == "not-" then + if string.sub(pred[1], 1, 4) == 'not-' then pred_name = string.sub(pred[1], 5) is_not = true else @@ -400,7 +442,7 @@ function Query:match_preds(match, pattern, source) local handler = predicate_handlers[pred_name] if not handler then - error(string.format("No handler for %s", pred[1])) + error(string.format('No handler for %s', pred[1])) return false end @@ -423,7 +465,7 @@ function Query:apply_directives(match, pattern, source, metadata) local handler = directive_handlers[pred[1]] if not handler then - error(string.format("No handler for %s", pred[1])) + error(string.format('No handler for %s', pred[1])) return end @@ -432,7 +474,6 @@ function Query:apply_directives(match, pattern, source, metadata) end end - --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. @@ -477,7 +518,7 @@ end ---@returns The matching capture id ---@returns The captured node function Query:iter_captures(node, source, start, stop) - if type(source) == "number" and source == 0 then + if type(source) == 'number' and source == 0 then source = vim.api.nvim_get_current_buf() end @@ -534,7 +575,7 @@ end ---@returns The matching pattern id ---@returns The matching match function Query:iter_matches(node, source, start, stop) - if type(source) == "number" and source == 0 then + if type(source) == 'number' and source == 0 then source = vim.api.nvim_get_current_buf() end |