aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/treesitter')
-rw-r--r--runtime/lua/vim/treesitter/language.lua16
-rw-r--r--runtime/lua/vim/treesitter/languagetree.lua43
-rw-r--r--runtime/lua/vim/treesitter/query.lua178
3 files changed, 163 insertions, 74 deletions
diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua
index a7e36a0b89..d60cd2d0c7 100644
--- a/runtime/lua/vim/treesitter/language.lua
+++ b/runtime/lua/vim/treesitter/language.lua
@@ -8,7 +8,8 @@ local M = {}
--
-- @param lang The language the parser should parse
-- @param path Optionnal path the parser is located at
-function M.require_language(lang, path)
+-- @param silent Don't throw an error if language not found
+function M.require_language(lang, path, silent)
if vim._ts_has_language(lang) then
return true
end
@@ -16,12 +17,23 @@ function M.require_language(lang, path)
local fname = 'parser/' .. lang .. '.*'
local paths = a.nvim_get_runtime_file(fname, false)
if #paths == 0 then
+ if silent then
+ return false
+ end
+
-- TODO(bfredl): help tag?
error("no parser for '"..lang.."' language, see :help treesitter-parsers")
end
path = paths[1]
end
- vim._ts_add_language(path, lang)
+
+ if silent then
+ return pcall(function() vim._ts_add_language(path, lang) end)
+ else
+ vim._ts_add_language(path, lang)
+ end
+
+ return true
end
--- Inspects the provided language.
diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua
index a8b62e21b9..9c620c422c 100644
--- a/runtime/lua/vim/treesitter/languagetree.lua
+++ b/runtime/lua/vim/treesitter/languagetree.lua
@@ -121,23 +121,30 @@ function LanguageTree:parse()
local seen_langs = {}
for lang, injection_ranges in pairs(injections_by_lang) do
- local child = self._children[lang]
+ local has_lang = language.require_language(lang, nil, true)
- if not child then
- child = self:add_child(lang)
- end
+ -- Child language trees should just be ignored if not found, since
+ -- they can depend on the text of a node. Intermediate strings
+ -- would cause errors for unknown parsers.
+ if has_lang then
+ local child = self._children[lang]
- child:set_included_regions(injection_ranges)
+ if not child then
+ child = self:add_child(lang)
+ end
- local _, child_changes = child:parse()
+ child:set_included_regions(injection_ranges)
- -- Propagate any child changes so they are included in the
- -- the change list for the callback.
- if child_changes then
- vim.list_extend(changes, child_changes)
- end
+ local _, child_changes = child:parse()
- seen_langs[lang] = true
+ -- Propagate any child changes so they are included in the
+ -- the change list for the callback.
+ if child_changes then
+ vim.list_extend(changes, child_changes)
+ end
+
+ seen_langs[lang] = true
+ end
end
for lang, _ in pairs(self._children) do
@@ -282,7 +289,7 @@ function LanguageTree:_get_injections()
local root_node = tree:root()
local start_line, _, end_line, _ = root_node:range()
- for pattern, match in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
+ for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do
local lang = nil
local injection_node = nil
local combined = false
@@ -291,9 +298,9 @@ function LanguageTree:_get_injections()
-- using a tag with the language, for example
-- @javascript
for id, node in pairs(match) do
+ local data = metadata[id]
local name = self._injection_query.captures[id]
- -- TODO add a way to offset the content passed to the parser.
- -- Needed to shave off leading quotes and things of that nature.
+ local offset_range = data and data.offset
-- Lang should override any other language tag
if name == "language" then
@@ -301,7 +308,7 @@ function LanguageTree:_get_injections()
elseif name == "combined" then
combined = true
elseif name == "content" then
- injection_node = node
+ injection_node = offset_range or node
-- Ignore any tags that start with "_"
-- Allows for other tags to be used in matches
elseif string.sub(name, 1, 1) ~= "_" then
@@ -310,7 +317,7 @@ function LanguageTree:_get_injections()
end
if not injection_node then
- injection_node = node
+ injection_node = offset_range or node
end
end
end
@@ -445,7 +452,7 @@ end
function LanguageTree:language_for_range(range)
for _, child in pairs(self._children) do
if child:contains(range) then
- return child:node_for_range(range)
+ return child:language_for_range(range)
end
end
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index 3537ba78f5..5a27d740a2 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -8,36 +8,10 @@ Query.__index = Query
local M = {}
--- Filter the runtime query files, the spec is like regular runtime files but in the new `queries`
--- directory. They resemble ftplugins, that is that you can override queries by adding things in the
--- `queries` directory, and extend using the `after/queries` directory.
-local function filter_files(file_list)
- local main = nil
- local after = {}
-
- for _, fname in ipairs(file_list) do
- -- Only get the name of the directory containing the queries directory
- if vim.fn.fnamemodify(fname, ":p:h:h:h:t") == "after" then
- table.insert(after, fname)
- -- The first one is the one with most priority
- elseif not main then
- main = fname
- end
- end
-
- return main and { main, unpack(after) } or after
-end
-
-local function runtime_query_path(lang, query_name)
- return string.format('queries/%s/%s.scm', lang, query_name)
-end
-
-local function filtered_runtime_queries(lang, query_name)
- return filter_files(a.nvim_get_runtime_file(runtime_query_path(lang, query_name), true) or {})
-end
-local function get_query_files(lang, query_name, is_included)
- local lang_files = filtered_runtime_queries(lang, query_name)
+function M.get_query_files(lang, query_name, is_included)
+ local query_path = string.format('queries/%s/%s.scm', lang, query_name)
+ local lang_files = a.nvim_get_runtime_file(query_path, true)
if #lang_files == 0 then return {} end
@@ -51,10 +25,10 @@ local function get_query_files(lang, query_name, is_included)
local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$"
for _, file in ipairs(lang_files) do
- local modeline = vim.fn.readfile(file, "", 1)
+ local modeline = io.open(file, 'r'):read('*l')
- if #modeline == 1 then
- local langlist = modeline[1]:match(MODELINE_FORMAT)
+ if modeline then
+ local langlist = modeline:match(MODELINE_FORMAT)
if langlist then
for _, incllang in ipairs(vim.split(langlist, ',', true)) do
@@ -74,7 +48,7 @@ local function get_query_files(lang, query_name, is_included)
local query_files = {}
for _, base_lang in ipairs(base_langs) do
- local base_files = get_query_files(base_lang, query_name, true)
+ local base_files = M.get_query_files(base_lang, query_name, true)
vim.list_extend(query_files, base_files)
end
vim.list_extend(query_files, lang_files)
@@ -86,10 +60,21 @@ local function read_query_files(filenames)
local contents = {}
for _,filename in ipairs(filenames) do
- vim.list_extend(contents, vim.fn.readfile(filename))
+ table.insert(contents, io.open(filename, 'r'):read('*a'))
end
- return table.concat(contents, '\n')
+ return table.concat(contents, '')
+end
+
+local match_metatable = {
+ __index = function(tbl, key)
+ rawset(tbl, key, {})
+ return tbl[key]
+ end
+}
+
+local function new_match_metadata()
+ return setmetatable({}, match_metatable)
end
--- Returns the runtime query {query_name} for {lang}.
@@ -99,7 +84,7 @@ end
--
-- @return The corresponding query, parsed.
function M.get_query(lang, query_name)
- local query_files = get_query_files(lang, query_name)
+ local query_files = M.get_query_files(lang, query_name)
local query_string = read_query_files(query_files)
if #query_string > 0 then
@@ -222,6 +207,44 @@ local predicate_handlers = {
-- As we provide lua-match? also expose vim-match?
predicate_handlers["vim-match?"] = predicate_handlers["match?"]
+
+-- Directives store metadata or perform side effects against a match.
+-- Directives should always end with a `!`.
+-- Directive handler receive the following arguments
+-- (match, pattern, bufnr, predicate)
+local directive_handlers = {
+ ["set!"] = function(_, _, _, pred, metadata)
+ if #pred == 4 then
+ -- (set! @capture "key" "value")
+ metadata[pred[2]][pred[3]] = pred[4]
+ else
+ -- (set! "key" "value")
+ metadata[pred[2]] = pred[3]
+ end
+ end,
+ -- Shifts the range of a node.
+ -- Example: (#offset! @_node 0 1 0 -1)
+ ["offset!"] = function(match, _, _, pred, metadata)
+ local offset_node = match[pred[2]]
+ local range = {offset_node:range()}
+ local start_row_offset = pred[3] or 0
+ local start_col_offset = pred[4] or 0
+ local end_row_offset = pred[5] or 0
+ local end_col_offset = pred[6] or 0
+ local key = pred[7] or "offset"
+
+ range[1] = range[1] + start_row_offset
+ range[2] = range[2] + start_col_offset
+ range[3] = range[3] + end_row_offset
+ range[4] = range[4] + end_col_offset
+
+ -- If this produces an invalid range, we just skip it.
+ if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
+ metadata[pred[2]][key] = range
+ end
+ end
+}
+
--- Adds a new predicates to be used in queries
--
-- @param name the name of the predicate, without leading #
@@ -229,12 +252,25 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"]
-- signature will be (match, pattern, bufnr, predicate)
function M.add_predicate(name, handler, force)
if predicate_handlers[name] and not force then
- a.nvim_err_writeln(string.format("Overriding %s", name))
+ error(string.format("Overriding %s", name))
end
predicate_handlers[name] = handler
end
+--- Adds a new directive to be used in queries
+--
+-- @param name the name of the directive, without leading #
+-- @param handler the handler function to be used
+-- signature will be (match, pattern, bufnr, predicate)
+function M.add_directive(name, handler, force)
+ if directive_handlers[name] and not force then
+ error(string.format("Overriding %s", name))
+ end
+
+ directive_handlers[name] = handler
+end
+
--- Returns the list of currently supported predicates
function M.list_predicates()
return vim.tbl_keys(predicate_handlers)
@@ -244,6 +280,10 @@ local function xor(x, y)
return (x or y) and not (x and y)
end
+local function is_directive(name)
+ return string.sub(name, -1) == "!"
+end
+
function Query:match_preds(match, pattern, source)
local preds = self.info.patterns[pattern]
@@ -254,30 +294,52 @@ function Query:match_preds(match, pattern, source)
-- Also, tree-sitter strips the leading # from predicates for us.
local pred_name
local is_not
- if string.sub(pred[1], 1, 4) == "not-" then
- pred_name = string.sub(pred[1], 5)
- is_not = true
- else
- pred_name = pred[1]
- is_not = false
- end
- local handler = predicate_handlers[pred_name]
+ -- Skip over directives... they will get processed after all the predicates.
+ if not is_directive(pred[1]) then
+ if string.sub(pred[1], 1, 4) == "not-" then
+ pred_name = string.sub(pred[1], 5)
+ is_not = true
+ else
+ pred_name = pred[1]
+ is_not = false
+ end
+
+ local handler = predicate_handlers[pred_name]
- if not handler then
- a.nvim_err_writeln(string.format("No handler for %s", pred[1]))
- return false
- end
+ if not handler then
+ error(string.format("No handler for %s", pred[1]))
+ return false
+ end
- local pred_matches = handler(match, pattern, source, pred)
+ local pred_matches = handler(match, pattern, source, pred)
- if not xor(is_not, pred_matches) then
- return false
+ if not xor(is_not, pred_matches) then
+ return false
+ end
end
end
return true
end
+--- Applies directives against a match and pattern.
+function Query:apply_directives(match, pattern, source, metadata)
+ local preds = self.info.patterns[pattern]
+
+ for _, pred in pairs(preds or {}) do
+ if is_directive(pred[1]) then
+ local handler = directive_handlers[pred[1]]
+
+ if not handler then
+ error(string.format("No handler for %s", pred[1]))
+ return
+ end
+
+ handler(match, pattern, source, pred, metadata)
+ end
+ end
+end
+
--- Iterates of the captures of self on a given range.
--
-- @param node The node under witch the search will occur
@@ -294,14 +356,18 @@ function Query:iter_captures(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, true, start, stop)
local function iter()
local capture, captured_node, match = raw_iter()
+ local metadata = new_match_metadata()
+
if match ~= nil then
local active = self:match_preds(match, match.pattern, source)
match.active = active
if not active then
return iter() -- tail call: try next match
end
+
+ self:apply_directives(match, match.pattern, source, metadata)
end
- return capture, captured_node
+ return capture, captured_node, metadata
end
return iter
end
@@ -322,13 +388,17 @@ function Query:iter_matches(node, source, start, stop)
local raw_iter = node:_rawquery(self.query, false, start, stop)
local function iter()
local pattern, match = raw_iter()
+ local metadata = new_match_metadata()
+
if match ~= nil then
local active = self:match_preds(match, pattern, source)
if not active then
return iter() -- tail call: try next match
end
+
+ self:apply_directives(match, pattern, source, metadata)
end
- return pattern, match
+ return pattern, match, metadata
end
return iter
end