diff options
Diffstat (limited to 'runtime/lua/vim/treesitter/query.lua')
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 65 |
1 files changed, 29 insertions, 36 deletions
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 682f981fbc..e49f54681d 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -8,36 +8,10 @@ Query.__index = Query local M = {} --- Filter the runtime query files, the spec is like regular runtime files but in the new `queries` --- directory. They resemble ftplugins, that is that you can override queries by adding things in the --- `queries` directory, and extend using the `after/queries` directory. -local function filter_files(file_list) - local main = nil - local after = {} - - for _, fname in ipairs(file_list) do - -- Only get the name of the directory containing the queries directory - if vim.fn.fnamemodify(fname, ":p:h:h:h:t") == "after" then - table.insert(after, fname) - -- The first one is the one with most priority - elseif not main then - main = fname - end - end - return main and { main, unpack(after) } or after -end - -local function runtime_query_path(lang, query_name) - return string.format('queries/%s/%s.scm', lang, query_name) -end - -local function filtered_runtime_queries(lang, query_name) - return filter_files(a.nvim_get_runtime_file(runtime_query_path(lang, query_name), true) or {}) -end - -local function get_query_files(lang, query_name, is_included) - local lang_files = filtered_runtime_queries(lang, query_name) +function M.get_query_files(lang, query_name, is_included) + local query_path = string.format('queries/%s/%s.scm', lang, query_name) + local lang_files = a.nvim_get_runtime_file(query_path, true) if #lang_files == 0 then return {} end @@ -51,10 +25,10 @@ local function get_query_files(lang, query_name, is_included) local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$" for _, file in ipairs(lang_files) do - local modeline = vim.fn.readfile(file, "", 1) + local modeline = io.open(file, 'r'):read('*l') - if #modeline == 1 then - local langlist = modeline[1]:match(MODELINE_FORMAT) + if modeline then + local langlist = modeline:match(MODELINE_FORMAT) if langlist then for _, incllang in ipairs(vim.split(langlist, ',', true)) do @@ -74,7 +48,7 @@ local function get_query_files(lang, query_name, is_included) local query_files = {} for _, base_lang in ipairs(base_langs) do - local base_files = get_query_files(base_lang, query_name, true) + local base_files = M.get_query_files(base_lang, query_name, true) vim.list_extend(query_files, base_files) end vim.list_extend(query_files, lang_files) @@ -86,10 +60,10 @@ local function read_query_files(filenames) local contents = {} for _,filename in ipairs(filenames) do - vim.list_extend(contents, vim.fn.readfile(filename)) + table.insert(contents, io.open(filename, 'r'):read('*a')) end - return table.concat(contents, '\n') + return table.concat(contents, '') end local match_metatable = { @@ -110,7 +84,7 @@ end -- -- @return The corresponding query, parsed. function M.get_query(lang, query_name) - local query_files = get_query_files(lang, query_name) + local query_files = M.get_query_files(lang, query_name) local query_string = read_query_files(query_files) if #query_string > 0 then @@ -366,6 +340,19 @@ function Query:apply_directives(match, pattern, source, metadata) end end + +--- Returns the start and stop value if set else the node's range. +-- When the node's range is used, the stop is incremented by 1 +-- to make the search inclusive. +local function value_or_node_range(start, stop, node) + if start == nil and stop == nil then + local node_start, _, node_stop, _ = node:range() + return node_start, node_stop + 1 -- Make stop inclusive + end + + return start, stop +end + --- Iterates of the captures of self on a given range. -- -- @param node The node under witch the search will occur @@ -379,6 +366,9 @@ function Query:iter_captures(node, source, start, stop) if type(source) == "number" and source == 0 then source = vim.api.nvim_get_current_buf() end + + start, stop = value_or_node_range(start, stop, node) + local raw_iter = node:_rawquery(self.query, true, start, stop) local function iter() local capture, captured_node, match = raw_iter() @@ -411,6 +401,9 @@ function Query:iter_matches(node, source, start, stop) if type(source) == "number" and source == 0 then source = vim.api.nvim_get_current_buf() end + + start, stop = value_or_node_range(start, stop, node) + local raw_iter = node:_rawquery(self.query, false, start, stop) local function iter() local pattern, match = raw_iter() |