aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/treesitter/_query_linter.lua
diff options
context:
space:
mode:
authorJosh Rahm <joshuarahm@gmail.com>2023-11-30 20:35:25 +0000
committerJosh Rahm <joshuarahm@gmail.com>2023-11-30 20:35:25 +0000
commit1b7b916b7631ddf73c38e3a0070d64e4636cb2f3 (patch)
treecd08258054db80bb9a11b1061bb091c70b76926a /runtime/lua/vim/treesitter/_query_linter.lua
parenteaa89c11d0f8aefbb512de769c6c82f61a8baca3 (diff)
parent4a8bf24ac690004aedf5540fa440e788459e5e34 (diff)
downloadrneovim-aucmd_textputpost.tar.gz
rneovim-aucmd_textputpost.tar.bz2
rneovim-aucmd_textputpost.zip
Merge remote-tracking branch 'upstream/master' into aucmd_textputpostaucmd_textputpost
Diffstat (limited to 'runtime/lua/vim/treesitter/_query_linter.lua')
-rw-r--r--runtime/lua/vim/treesitter/_query_linter.lua249
1 files changed, 249 insertions, 0 deletions
diff --git a/runtime/lua/vim/treesitter/_query_linter.lua b/runtime/lua/vim/treesitter/_query_linter.lua
new file mode 100644
index 0000000000..87d74789a3
--- /dev/null
+++ b/runtime/lua/vim/treesitter/_query_linter.lua
@@ -0,0 +1,249 @@
+local api = vim.api
+
+local namespace = api.nvim_create_namespace('vim.treesitter.query_linter')
+
+local M = {}
+
+--- @class QueryLinterNormalizedOpts
+--- @field langs string[]
+--- @field clear boolean
+
+--- @alias vim.treesitter.ParseError {msg: string, range: Range4}
+
+--- Contains language dependent context for the query linter
+--- @class QueryLinterLanguageContext
+--- @field lang string? Current `lang` of the targeted parser
+--- @field parser_info table? Parser info returned by vim.treesitter.language.inspect
+--- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`
+
+--- Adds a diagnostic for node in the query buffer
+--- @param diagnostics Diagnostic[]
+--- @param range Range4
+--- @param lint string
+--- @param lang string?
+local function add_lint_for_node(diagnostics, range, lint, lang)
+ local message = lint:gsub('\n', ' ')
+ diagnostics[#diagnostics + 1] = {
+ lnum = range[1],
+ end_lnum = range[3],
+ col = range[2],
+ end_col = range[4],
+ severity = vim.diagnostic.ERROR,
+ message = message,
+ source = lang,
+ }
+end
+
+--- Determines the target language of a query file by its path: <lang>/<query_type>.scm
+--- @param buf integer
+--- @return string?
+local function guess_query_lang(buf)
+ local filename = api.nvim_buf_get_name(buf)
+ if filename ~= '' then
+ return vim.F.npcall(vim.fn.fnamemodify, filename, ':p:h:t')
+ end
+end
+
+--- @param buf integer
+--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil
+--- @return QueryLinterNormalizedOpts
+local function normalize_opts(buf, opts)
+ opts = opts or {}
+ if not opts.langs then
+ opts.langs = guess_query_lang(buf)
+ end
+
+ if type(opts.langs) ~= 'table' then
+ --- @diagnostic disable-next-line:assign-type-mismatch
+ opts.langs = { opts.langs }
+ end
+
+ --- @cast opts QueryLinterNormalizedOpts
+ opts.langs = opts.langs or {}
+ return opts
+end
+
+local lint_query = [[;; query
+ (program [(named_node) (list) (grouping)] @toplevel)
+ (named_node
+ name: _ @node.named)
+ (anonymous_node
+ name: _ @node.anonymous)
+ (field_definition
+ name: (identifier) @field)
+ (predicate
+ name: (identifier) @predicate.name
+ type: (predicate_type) @predicate.type)
+ (ERROR) @error
+]]
+
+--- @param err string
+--- @param node TSNode
+--- @return vim.treesitter.ParseError
+local function get_error_entry(err, node)
+ local start_line, start_col = node:range()
+ local line_offset, col_offset, msg = err:gmatch('.-:%d+: Query error at (%d+):(%d+)%. ([^:]+)')() ---@type string, string, string
+ start_line, start_col =
+ start_line + tonumber(line_offset) - 1, start_col + tonumber(col_offset) - 1
+ local end_line, end_col = start_line, start_col
+ if msg:match('^Invalid syntax') or msg:match('^Impossible') then
+ -- Use the length of the underlined node
+ local underlined = vim.split(err, '\n')[2]
+ end_col = end_col + #underlined
+ elseif msg:match('^Invalid') then
+ -- Use the length of the problematic type/capture/field
+ end_col = end_col + #msg:match('"([^"]+)"')
+ end
+
+ return {
+ msg = msg,
+ range = { start_line, start_col, end_line, end_col },
+ }
+end
+
+--- @param node TSNode
+--- @param buf integer
+--- @param lang string
+local function hash_parse(node, buf, lang)
+ return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang
+end
+
+--- @param node TSNode
+--- @param buf integer
+--- @param lang string
+--- @return vim.treesitter.ParseError?
+local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
+ local query_text = vim.treesitter.get_node_text(node, buf)
+ local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|Query
+
+ if not ok and type(err) == 'string' then
+ return get_error_entry(err, node)
+ end
+end)
+
+--- @param buf integer
+--- @param match table<integer,TSNode>
+--- @param query Query
+--- @param lang_context QueryLinterLanguageContext
+--- @param diagnostics Diagnostic[]
+local function lint_match(buf, match, query, lang_context, diagnostics)
+ local lang = lang_context.lang
+ local parser_info = lang_context.parser_info
+
+ for id, node in pairs(match) do
+ local cap_id = query.captures[id]
+
+ -- perform language-independent checks only for first lang
+ if lang_context.is_first_lang and cap_id == 'error' then
+ local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
+ add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
+ end
+
+ -- other checks rely on Neovim parser introspection
+ if lang and parser_info and cap_id == 'toplevel' then
+ local err = parse(node, buf, lang)
+ if err then
+ add_lint_for_node(diagnostics, err.range, err.msg, lang)
+ end
+ end
+ end
+end
+
+--- @private
+--- @param buf integer Buffer to lint
+--- @param opts QueryLinterOpts|QueryLinterNormalizedOpts|nil Options for linting
+function M.lint(buf, opts)
+ if buf == 0 then
+ buf = api.nvim_get_current_buf()
+ end
+
+ local diagnostics = {}
+ local query = vim.treesitter.query.parse('query', lint_query)
+
+ opts = normalize_opts(buf, opts)
+
+ -- perform at least one iteration even with no langs to perform language independent checks
+ for i = 1, math.max(1, #opts.langs) do
+ local lang = opts.langs[i]
+
+ --- @type (table|nil)
+ local parser_info = vim.F.npcall(vim.treesitter.language.inspect, lang)
+
+ local parser = vim.treesitter.get_parser(buf)
+ parser:parse()
+ parser:for_each_tree(function(tree, ltree)
+ if ltree:lang() == 'query' then
+ for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do
+ local lang_context = {
+ lang = lang,
+ parser_info = parser_info,
+ is_first_lang = i == 1,
+ }
+ lint_match(buf, match, query, lang_context, diagnostics)
+ end
+ end
+ end)
+ end
+
+ vim.diagnostic.set(namespace, buf, diagnostics)
+end
+
+--- @private
+--- @param buf integer
+function M.clear(buf)
+ vim.diagnostic.reset(namespace, buf)
+end
+
+--- @private
+--- @param findstart integer
+--- @param base string
+function M.omnifunc(findstart, base)
+ if findstart == 1 then
+ local result =
+ api.nvim_get_current_line():sub(1, api.nvim_win_get_cursor(0)[2]):find('["#%-%w]*$')
+ return result - 1
+ end
+
+ local buf = api.nvim_get_current_buf()
+ local query_lang = guess_query_lang(buf)
+
+ local ok, parser_info = pcall(vim.treesitter.language.inspect, query_lang)
+ if not ok then
+ return -2
+ end
+
+ local items = {}
+ for _, f in pairs(parser_info.fields) do
+ if f:find(base, 1, true) then
+ table.insert(items, f .. ':')
+ end
+ end
+ for _, p in pairs(vim.treesitter.query.list_predicates()) do
+ local text = '#' .. p
+ local found = text:find(base, 1, true)
+ if found and found <= 2 then -- with or without '#'
+ table.insert(items, text)
+ end
+ text = '#not-' .. p
+ found = text:find(base, 1, true)
+ if found and found <= 2 then -- with or without '#'
+ table.insert(items, text)
+ end
+ end
+ for _, p in pairs(vim.treesitter.query.list_directives()) do
+ local text = '#' .. p
+ local found = text:find(base, 1, true)
+ if found and found <= 2 then -- with or without '#'
+ table.insert(items, text)
+ end
+ end
+ for _, s in pairs(parser_info.symbols) do
+ local text = s[2] and s[1] or '"' .. s[1]:gsub([[\]], [[\\]]) .. '"' ---@type string
+ if text:find(base, 1, true) then
+ table.insert(items, text)
+ end
+ end
+ return { words = items, refresh = 'always' }
+end
+
+return M