diff options
Diffstat (limited to 'runtime/lua')
-rw-r--r-- | runtime/lua/vim/lsp.lua | 27 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/buf.lua | 10 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/diagnostic.lua | 18 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/handlers.lua | 101 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/protocol.lua | 54 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/util.lua | 77 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 16 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 43 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 178 |
10 files changed, 427 insertions, 99 deletions
diff --git a/runtime/lua/vim/lsp.lua b/runtime/lua/vim/lsp.lua index f082fe29f2..0326550245 100644 --- a/runtime/lua/vim/lsp.lua +++ b/runtime/lua/vim/lsp.lua @@ -1,3 +1,5 @@ +local if_nil = vim.F.if_nil + local default_handlers = require 'vim.lsp.handlers' local log = require 'vim.lsp.log' local lsp_rpc = require 'vim.lsp.rpc' @@ -226,6 +228,7 @@ local function validate_client_config(config) on_init = { config.on_init, "f", true }; before_init = { config.before_init, "f", true }; offset_encoding = { config.offset_encoding, "s", true }; + flags = { config.flags, "t", true }; } -- TODO(remove-callbacks) @@ -434,6 +437,17 @@ end --- --@param trace: "off" | "messages" | "verbose" | nil passed directly to the language --- server in the initialize request. Invalid/empty values will default to "off" +--@param flags: A table with flags for the client. The current (experimental) flags are: +--- - allow_incremental_sync (bool, default false): Allow using on_line callbacks for lsp +--- +--- <pre> +--- -- In attach function for the client, you can do: +--- local custom_attach = function(client) +--- if client.config.flags then +--- client.config.flags.allow_incremental_sync = true +--- end +--- end +--- </pre> --- --@returns Client id. |vim.lsp.get_client_by_id()| Note: client may not be --- fully initialized. Use `on_init` to do any actions once @@ -442,6 +456,8 @@ function lsp.start_client(config) local cleaned_config = validate_client_config(config) local cmd, cmd_args, offset_encoding = cleaned_config.cmd, cleaned_config.cmd_args, cleaned_config.offset_encoding + config.flags = config.flags or {} + local client_id = next_client_id() -- TODO(remove-callbacks) @@ -553,6 +569,8 @@ function lsp.start_client(config) -- TODO(remove-callbacks) callbacks = handlers; handlers = handlers; + -- for $/progress report + messages = { name = name, messages = {}, progress = {}, status = {} } } -- Store the uninitialized_clients for cleanup in case we exit before initialize finishes. @@ -799,6 +817,7 @@ do local size_index = encoding_index[client.offset_encoding] local length = select(size_index, old_byte_size, old_utf16_size, old_utf32_size) local lines = nvim_buf_get_lines(bufnr, firstline, new_lastline, true) + -- This is necessary because we are specifying the full line including the -- newline in range. Therefore, we must replace the newline as well. if #lines > 0 then @@ -820,6 +839,8 @@ do end) local uri = vim.uri_from_bufnr(bufnr) for_each_buffer_client(bufnr, function(client, _client_id) + local allow_incremental_sync = if_nil(client.config.flags.allow_incremental_sync, false) + local text_document_did_change = client.resolved_capabilities.text_document_did_change local changes if text_document_did_change == protocol.TextDocumentSyncKind.None then @@ -830,7 +851,7 @@ do -- is no way to specify the sync capability by the client. -- See https://github.com/palantir/python-language-server/commit/cfd6675bc10d5e8dbc50fc50f90e4a37b7178821#diff-f68667852a14e9f761f6ebf07ba02fc8 for an example of pyls handling both. --]=] - elseif true or text_document_did_change == protocol.TextDocumentSyncKind.Full then + elseif not allow_incremental_sync or text_document_did_change == protocol.TextDocumentSyncKind.Full then changes = full_changes(client) elseif text_document_did_change == protocol.TextDocumentSyncKind.Incremental then changes = incremental_changes(client) @@ -862,8 +883,8 @@ function lsp._text_document_did_save_handler(bufnr) client.notify('textDocument/didSave', { textDocument = { uri = uri; - text = included_text; - } + }; + text = included_text; }) end end) diff --git a/runtime/lua/vim/lsp/buf.lua b/runtime/lua/vim/lsp/buf.lua index a70581478b..00219b6d98 100644 --- a/runtime/lua/vim/lsp/buf.lua +++ b/runtime/lua/vim/lsp/buf.lua @@ -149,7 +149,7 @@ end --@param options Table with valid `FormattingOptions` entries. --@param start_pos ({number, number}, optional) mark-indexed position. ---Defaults to the start of the last visual selection. ---@param start_pos ({number, number}, optional) mark-indexed position. +--@param end_pos ({number, number}, optional) mark-indexed position. ---Defaults to the end of the last visual selection. function M.range_formatting(options, start_pos, end_pos) validate { options = {options, 't', true} } @@ -239,6 +239,7 @@ function M.outgoing_calls() end --- List workspace folders. +--- function M.list_workspace_folders() local workspace_folders = {} for _, client in ipairs(vim.lsp.buf_get_clients()) do @@ -249,7 +250,8 @@ function M.list_workspace_folders() return workspace_folders end ---- Add a workspace folder. +--- Add the folder at path to the workspace folders. If {path} is +--- not provided, the user will be prompted for a path using |input()|. function M.add_workspace_folder(workspace_folder) workspace_folder = workspace_folder or npcall(vfn.input, "Workspace Folder: ", vfn.expand('%:p:h')) vim.api.nvim_command("redraw") @@ -275,7 +277,9 @@ function M.add_workspace_folder(workspace_folder) end end ---- Remove a workspace folder. +--- Remove the folder at path from the workspace folders. If +--- {path} is not provided, the user will be prompted for +--- a path using |input()|. function M.remove_workspace_folder(workspace_folder) workspace_folder = workspace_folder or npcall(vfn.input, "Workspace Folder: ", vfn.expand('%:p:h')) vim.api.nvim_command("redraw") diff --git a/runtime/lua/vim/lsp/diagnostic.lua b/runtime/lua/vim/lsp/diagnostic.lua index 27a1f53f89..072349b226 100644 --- a/runtime/lua/vim/lsp/diagnostic.lua +++ b/runtime/lua/vim/lsp/diagnostic.lua @@ -400,9 +400,9 @@ end --- let sl = '' --- if luaeval('not vim.tbl_isempty(vim.lsp.buf_get_clients(0))') --- let sl.='%#MyStatuslineLSP#E:' ---- let sl.='%#MyStatuslineLSPErrors#%{luaeval("vim.lsp.diagnostic.get_count([[Error]])")}' +--- let sl.='%#MyStatuslineLSPErrors#%{luaeval("vim.lsp.diagnostic.get_count(0, [[Error]])")}' --- let sl.='%#MyStatuslineLSP# W:' ---- let sl.='%#MyStatuslineLSPWarnings#%{luaeval("vim.lsp.diagnostic.get_count([[Warning]])")}' +--- let sl.='%#MyStatuslineLSPWarnings#%{luaeval("vim.lsp.diagnostic.get_count(0, [[Warning]])")}' --- else --- let sl.='%#MyStatuslineLSPErrors#off' --- endif @@ -510,7 +510,7 @@ end --- Get the previous diagnostic closest to the cursor_position --- ----@param opts table See |vim.lsp.diagnostics.goto_next()| +---@param opts table See |vim.lsp.diagnostic.goto_next()| ---@return table Previous diagnostic function M.get_prev(opts) opts = opts or {} @@ -523,7 +523,7 @@ function M.get_prev(opts) end --- Return the pos, {row, col}, for the prev diagnostic in the current buffer. ----@param opts table See |vim.lsp.diagnostics.goto_next()| +---@param opts table See |vim.lsp.diagnostic.goto_next()| ---@return table Previous diagnostic position function M.get_prev_pos(opts) return _iter_diagnostic_lines_pos( @@ -533,7 +533,7 @@ function M.get_prev_pos(opts) end --- Move to the previous diagnostic ----@param opts table See |vim.lsp.diagnostics.goto_next()| +---@param opts table See |vim.lsp.diagnostic.goto_next()| function M.goto_prev(opts) return _iter_diagnostic_move_pos( "DiagnosticPrevious", @@ -543,7 +543,7 @@ function M.goto_prev(opts) end --- Get the previous diagnostic closest to the cursor_position ----@param opts table See |vim.lsp.diagnostics.goto_next()| +---@param opts table See |vim.lsp.diagnostic.goto_next()| ---@return table Next diagnostic function M.get_next(opts) opts = opts or {} @@ -556,7 +556,7 @@ function M.get_next(opts) end --- Return the pos, {row, col}, for the next diagnostic in the current buffer. ----@param opts table See |vim.lsp.diagnostics.goto_next()| +---@param opts table See |vim.lsp.diagnostic.goto_next()| ---@return table Next diagnostic position function M.get_next_pos(opts) return _iter_diagnostic_lines_pos( @@ -1044,6 +1044,8 @@ function M.display(diagnostics, bufnr, client_id, config) diagnostics = diagnostics or M.get(bufnr, client_id) + vim.api.nvim_command("doautocmd <nomodeline> User LspDiagnosticsChanged") + if not diagnostics or vim.tbl_isempty(diagnostics) then return end @@ -1062,8 +1064,6 @@ function M.display(diagnostics, bufnr, client_id, config) if signs_opts then M.set_signs(diagnostics, bufnr, client_id, nil, signs_opts) end - - vim.api.nvim_command("doautocmd User LspDiagnosticsChanged") end -- }}} -- Diagnostic User Functions {{{ diff --git a/runtime/lua/vim/lsp/handlers.lua b/runtime/lua/vim/lsp/handlers.lua index e034923afb..87f35363b1 100644 --- a/runtime/lua/vim/lsp/handlers.lua +++ b/runtime/lua/vim/lsp/handlers.lua @@ -24,6 +24,79 @@ M['workspace/executeCommand'] = function(err, _) end end +-- @msg of type ProgressParams +-- Basically a token of type number/string +local function progress_callback(_, _, params, client_id) + local client = vim.lsp.get_client_by_id(client_id) + local client_name = client and client.name or string.format("id=%d", client_id) + if not client then + err_message("LSP[", client_name, "] client has shut down after sending the message") + end + local val = params.value -- unspecified yet + local token = params.token -- string or number + + + if val.kind then + if val.kind == 'begin' then + client.messages.progress[token] = { + title = val.title, + message = val.message, + percentage = val.percentage, + } + elseif val.kind == 'report' then + client.messages.progress[token].message = val.message; + client.messages.progress[token].percentage = val.percentage; + elseif val.kind == 'end' then + if client.messages.progress[token] == nil then + err_message("LSP[", client_name, "] received `end` message with no corresponding `begin`") + else + client.messages.progress[token].message = val.message + client.messages.progress[token].done = true + end + end + else + table.insert(client.messages, {content = val, show_once = true, shown = 0}) + end + + vim.api.nvim_command("doautocmd <nomodeline> User LspProgressUpdate") +end + +--@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#progress +M['$/progress'] = progress_callback + +--@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#window_workDoneProgress_create +M['window/workDoneProgress/create'] = function(_, _, params, client_id) + local client = vim.lsp.get_client_by_id(client_id) + local token = params.token -- string or number + local client_name = client and client.name or string.format("id=%d", client_id) + if not client then + err_message("LSP[", client_name, "] client has shut down after sending the message") + end + client.messages.progress[token] = {} + return vim.NIL +end + +--@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#window_showMessageRequest +M['window/showMessageRequest'] = function(_, _, params) + + local actions = params.actions + print(params.message) + local option_strings = {params.message, "\nRequest Actions:"} + for i, action in ipairs(actions) do + local title = action.title:gsub('\r\n', '\\r\\n') + title = title:gsub('\n', '\\n') + table.insert(option_strings, string.format("%d. %s", i, title)) + end + + -- window/showMessageRequest can return either MessageActionItem[] or null. + local choice = vim.fn.inputlist(option_strings) + if choice < 1 or choice > #actions then + return vim.NIL + else + return actions[choice] + end +end + --@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_codeAction M['textDocument/codeAction'] = function(_, _, actions) if actions == nil or vim.tbl_isempty(actions) then @@ -72,6 +145,31 @@ M['workspace/applyEdit'] = function(_, _, workspace_edit) } end +--@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#workspace_configuration +M['workspace/configuration'] = function(err, _, params, client_id) + local client = vim.lsp.get_client_by_id(client_id) + if not client then + err_message("LSP[id=", client_id, "] client has shut down after sending the message") + end + if err then error(vim.inspect(err)) end + if not params.items then + return {} + end + + local result = {} + for _, item in ipairs(params.items) do + if item.section then + local value = util.lookup_section(client.config.settings, item.section) or vim.NIL + -- For empty sections with no explicit '' key, return settings as is + if value == vim.NIL and item.section == '' then + value = client.config.settings or vim.NIL + end + table.insert(result, value) + end + end + return result +end + M['textDocument/publishDiagnostics'] = function(...) return require('vim.lsp.diagnostic').on_publish_diagnostics(...) end @@ -212,9 +310,8 @@ M['textDocument/signatureHelp'] = function(_, method, result) end --@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_documentHighlight -M['textDocument/documentHighlight'] = function(_, _, result, _) +M['textDocument/documentHighlight'] = function(_, _, result, _, bufnr, _) if not result then return end - local bufnr = api.nvim_get_current_buf() util.buf_highlight_references(bufnr, result) end diff --git a/runtime/lua/vim/lsp/protocol.lua b/runtime/lua/vim/lsp/protocol.lua index 218424fa14..3e111c154a 100644 --- a/runtime/lua/vim/lsp/protocol.lua +++ b/runtime/lua/vim/lsp/protocol.lua @@ -34,6 +34,13 @@ local constants = { Hint = 4; }; + DiagnosticTag = { + -- Unused or unnecessary code + Unnecessary = 1; + -- Deprecated or obsolete code + Deprecated = 2; + }; + MessageType = { -- An error message. Error = 1; @@ -292,8 +299,9 @@ local constants = { } for k, v in pairs(constants) do - vim.tbl_add_reverse_lookup(v) - protocol[k] = v + local tbl = vim.deepcopy(v) + vim.tbl_add_reverse_lookup(tbl) + protocol[k] = tbl end --[=[ @@ -520,6 +528,13 @@ export interface TextDocumentClientCapabilities { publishDiagnostics?: { --Whether the clients accepts diagnostics with related information. relatedInformation?: boolean; + --Client supports the tag property to provide meta data about a diagnostic. + --Clients supporting tags have to handle unknown tags gracefully. + --Since 3.15.0 + tagSupport?: { + --The tags supported by this client + valueSet: DiagnosticTag[]; + }; }; --Capabilities specific to `textDocument/foldingRange` requests. -- @@ -623,7 +638,11 @@ function protocol.make_client_capabilities() codeActionLiteralSupport = { codeActionKind = { - valueSet = vim.tbl_values(protocol.CodeActionKind); + valueSet = (function() + local res = vim.tbl_values(protocol.CodeActionKind) + table.sort(res) + return res + end)(); }; }; }; @@ -643,7 +662,7 @@ function protocol.make_client_capabilities() completionItemKind = { valueSet = (function() local res = {} - for k in pairs(protocol.CompletionItemKind) do + for k in ipairs(protocol.CompletionItemKind) do if type(k) == 'number' then table.insert(res, k) end end return res @@ -689,7 +708,7 @@ function protocol.make_client_capabilities() symbolKind = { valueSet = (function() local res = {} - for k in pairs(protocol.SymbolKind) do + for k in ipairs(protocol.SymbolKind) do if type(k) == 'number' then table.insert(res, k) end end return res @@ -701,6 +720,18 @@ function protocol.make_client_capabilities() dynamicRegistration = false; prepareSupport = true; }; + publishDiagnostics = { + relatedInformation = true; + tagSupport = { + valueSet = (function() + local res = {} + for k in ipairs(protocol.DiagnosticTag) do + if type(k) == 'number' then table.insert(res, k) end + end + return res + end)(); + }; + }; }; workspace = { symbol = { @@ -708,7 +739,7 @@ function protocol.make_client_capabilities() symbolKind = { valueSet = (function() local res = {} - for k in pairs(protocol.SymbolKind) do + for k in ipairs(protocol.SymbolKind) do if type(k) == 'number' then table.insert(res, k) end end return res @@ -723,6 +754,17 @@ function protocol.make_client_capabilities() dynamicRegistration = false; }; experimental = nil; + window = { + workDoneProgress = true; + showMessage = { + messageActionItem = { + additionalPropertiesSupport = false; + }; + }; + showDocument = { + support = false; + }; + }; } end diff --git a/runtime/lua/vim/lsp/util.lua b/runtime/lua/vim/lsp/util.lua index 5804ac6656..ecff95f61e 100644 --- a/runtime/lua/vim/lsp/util.lua +++ b/runtime/lua/vim/lsp/util.lua @@ -120,6 +120,63 @@ local function get_line_byte_from_position(bufnr, position) return col end +--- Process and return progress reports from lsp server +function M.get_progress_messages() + + local new_messages = {} + local msg_remove = {} + local progress_remove = {} + + for _, client in ipairs(vim.lsp.get_active_clients()) do + local messages = client.messages + local data = messages + for token, ctx in pairs(data.progress) do + + local new_report = { + name = data.name, + title = ctx.title or "empty title", + message = ctx.message, + percentage = ctx.percentage, + progress = true, + } + table.insert(new_messages, new_report) + + if ctx.done then + table.insert(progress_remove, {client = client, token = token}) + end + end + + for i, msg in ipairs(data.messages) do + if msg.show_once then + msg.shown = msg.shown + 1 + if msg.shown > 1 then + table.insert(msg_remove, {client = client, idx = i}) + end + end + + table.insert(new_messages, {name = data.name, content = msg.content}) + end + + if next(data.status) ~= nil then + table.insert(new_messages, { + name = data.name, + content = data.status.content, + uri = data.status.uri, + status = true + }) + end + for _, item in ipairs(msg_remove) do + table.remove(client.messages, item.idx) + end + + for _, item in ipairs(progress_remove) do + client.messages.progress[item.token] = nil + end + end + + return new_messages +end + --- Applies a list of text edits to a buffer. --@param text_edits (table) list of `TextEdit` objects --@param buf_nr (number) Buffer id @@ -1022,7 +1079,7 @@ do --@deprecated function M.buf_diagnostics_signs(bufnr, diagnostics, client_id) - warn_once("buf_diagnostics_signs is deprecated. Use 'vim.lsp.diagnostics.set_signs'") + warn_once("buf_diagnostics_signs is deprecated. Use 'vim.lsp.diagnostic.set_signs'") return vim.lsp.diagnostic.set_signs(diagnostics, bufnr, client_id) end @@ -1315,6 +1372,9 @@ function M.make_text_document_params() return { uri = vim.uri_from_bufnr(0) } end +--- Create the workspace params +--@param added +--@param removed function M.make_workspace_params(added, removed) return { event = { added = added; removed = removed; } } end @@ -1362,6 +1422,21 @@ function M.character_offset(buf, row, col) return str_utfindex(line, col) end +--- Helper function to return nested values in language server settings +--- +--@param settings a table of language server settings +--@param section a string indicating the field of the settings table +--@returns (table or string) The value of settings accessed via section +function M.lookup_section(settings, section) + for part in vim.gsplit(section, '.', true) do + settings = settings[part] + if not settings then + return + end + end + return settings +end + M._get_line_byte_from_position = get_line_byte_from_position M._warn_once = warn_once diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index 6886f0c178..79dcf77f9e 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -50,7 +50,7 @@ function M._create_parser(bufnr, lang, opts) end end - a.nvim_buf_attach(self.bufnr, false, {on_bytes=bytes_cb, on_detach=detach_cb}) + a.nvim_buf_attach(self.bufnr, false, {on_bytes=bytes_cb, on_detach=detach_cb, preview=true}) self:parse() diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index a7e36a0b89..d60cd2d0c7 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -8,7 +8,8 @@ local M = {} -- -- @param lang The language the parser should parse -- @param path Optionnal path the parser is located at -function M.require_language(lang, path) +-- @param silent Don't throw an error if language not found +function M.require_language(lang, path, silent) if vim._ts_has_language(lang) then return true end @@ -16,12 +17,23 @@ function M.require_language(lang, path) local fname = 'parser/' .. lang .. '.*' local paths = a.nvim_get_runtime_file(fname, false) if #paths == 0 then + if silent then + return false + end + -- TODO(bfredl): help tag? error("no parser for '"..lang.."' language, see :help treesitter-parsers") end path = paths[1] end - vim._ts_add_language(path, lang) + + if silent then + return pcall(function() vim._ts_add_language(path, lang) end) + else + vim._ts_add_language(path, lang) + end + + return true end --- Inspects the provided language. diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index a8b62e21b9..9c620c422c 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -121,23 +121,30 @@ function LanguageTree:parse() local seen_langs = {} for lang, injection_ranges in pairs(injections_by_lang) do - local child = self._children[lang] + local has_lang = language.require_language(lang, nil, true) - if not child then - child = self:add_child(lang) - end + -- Child language trees should just be ignored if not found, since + -- they can depend on the text of a node. Intermediate strings + -- would cause errors for unknown parsers. + if has_lang then + local child = self._children[lang] - child:set_included_regions(injection_ranges) + if not child then + child = self:add_child(lang) + end - local _, child_changes = child:parse() + child:set_included_regions(injection_ranges) - -- Propagate any child changes so they are included in the - -- the change list for the callback. - if child_changes then - vim.list_extend(changes, child_changes) - end + local _, child_changes = child:parse() - seen_langs[lang] = true + -- Propagate any child changes so they are included in the + -- the change list for the callback. + if child_changes then + vim.list_extend(changes, child_changes) + end + + seen_langs[lang] = true + end end for lang, _ in pairs(self._children) do @@ -282,7 +289,7 @@ function LanguageTree:_get_injections() local root_node = tree:root() local start_line, _, end_line, _ = root_node:range() - for pattern, match in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do + for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do local lang = nil local injection_node = nil local combined = false @@ -291,9 +298,9 @@ function LanguageTree:_get_injections() -- using a tag with the language, for example -- @javascript for id, node in pairs(match) do + local data = metadata[id] local name = self._injection_query.captures[id] - -- TODO add a way to offset the content passed to the parser. - -- Needed to shave off leading quotes and things of that nature. + local offset_range = data and data.offset -- Lang should override any other language tag if name == "language" then @@ -301,7 +308,7 @@ function LanguageTree:_get_injections() elseif name == "combined" then combined = true elseif name == "content" then - injection_node = node + injection_node = offset_range or node -- Ignore any tags that start with "_" -- Allows for other tags to be used in matches elseif string.sub(name, 1, 1) ~= "_" then @@ -310,7 +317,7 @@ function LanguageTree:_get_injections() end if not injection_node then - injection_node = node + injection_node = offset_range or node end end end @@ -445,7 +452,7 @@ end function LanguageTree:language_for_range(range) for _, child in pairs(self._children) do if child:contains(range) then - return child:node_for_range(range) + return child:language_for_range(range) end end diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 3537ba78f5..5a27d740a2 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -8,36 +8,10 @@ Query.__index = Query local M = {} --- Filter the runtime query files, the spec is like regular runtime files but in the new `queries` --- directory. They resemble ftplugins, that is that you can override queries by adding things in the --- `queries` directory, and extend using the `after/queries` directory. -local function filter_files(file_list) - local main = nil - local after = {} - - for _, fname in ipairs(file_list) do - -- Only get the name of the directory containing the queries directory - if vim.fn.fnamemodify(fname, ":p:h:h:h:t") == "after" then - table.insert(after, fname) - -- The first one is the one with most priority - elseif not main then - main = fname - end - end - - return main and { main, unpack(after) } or after -end - -local function runtime_query_path(lang, query_name) - return string.format('queries/%s/%s.scm', lang, query_name) -end - -local function filtered_runtime_queries(lang, query_name) - return filter_files(a.nvim_get_runtime_file(runtime_query_path(lang, query_name), true) or {}) -end -local function get_query_files(lang, query_name, is_included) - local lang_files = filtered_runtime_queries(lang, query_name) +function M.get_query_files(lang, query_name, is_included) + local query_path = string.format('queries/%s/%s.scm', lang, query_name) + local lang_files = a.nvim_get_runtime_file(query_path, true) if #lang_files == 0 then return {} end @@ -51,10 +25,10 @@ local function get_query_files(lang, query_name, is_included) local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$" for _, file in ipairs(lang_files) do - local modeline = vim.fn.readfile(file, "", 1) + local modeline = io.open(file, 'r'):read('*l') - if #modeline == 1 then - local langlist = modeline[1]:match(MODELINE_FORMAT) + if modeline then + local langlist = modeline:match(MODELINE_FORMAT) if langlist then for _, incllang in ipairs(vim.split(langlist, ',', true)) do @@ -74,7 +48,7 @@ local function get_query_files(lang, query_name, is_included) local query_files = {} for _, base_lang in ipairs(base_langs) do - local base_files = get_query_files(base_lang, query_name, true) + local base_files = M.get_query_files(base_lang, query_name, true) vim.list_extend(query_files, base_files) end vim.list_extend(query_files, lang_files) @@ -86,10 +60,21 @@ local function read_query_files(filenames) local contents = {} for _,filename in ipairs(filenames) do - vim.list_extend(contents, vim.fn.readfile(filename)) + table.insert(contents, io.open(filename, 'r'):read('*a')) end - return table.concat(contents, '\n') + return table.concat(contents, '') +end + +local match_metatable = { + __index = function(tbl, key) + rawset(tbl, key, {}) + return tbl[key] + end +} + +local function new_match_metadata() + return setmetatable({}, match_metatable) end --- Returns the runtime query {query_name} for {lang}. @@ -99,7 +84,7 @@ end -- -- @return The corresponding query, parsed. function M.get_query(lang, query_name) - local query_files = get_query_files(lang, query_name) + local query_files = M.get_query_files(lang, query_name) local query_string = read_query_files(query_files) if #query_string > 0 then @@ -222,6 +207,44 @@ local predicate_handlers = { -- As we provide lua-match? also expose vim-match? predicate_handlers["vim-match?"] = predicate_handlers["match?"] + +-- Directives store metadata or perform side effects against a match. +-- Directives should always end with a `!`. +-- Directive handler receive the following arguments +-- (match, pattern, bufnr, predicate) +local directive_handlers = { + ["set!"] = function(_, _, _, pred, metadata) + if #pred == 4 then + -- (set! @capture "key" "value") + metadata[pred[2]][pred[3]] = pred[4] + else + -- (set! "key" "value") + metadata[pred[2]] = pred[3] + end + end, + -- Shifts the range of a node. + -- Example: (#offset! @_node 0 1 0 -1) + ["offset!"] = function(match, _, _, pred, metadata) + local offset_node = match[pred[2]] + local range = {offset_node:range()} + local start_row_offset = pred[3] or 0 + local start_col_offset = pred[4] or 0 + local end_row_offset = pred[5] or 0 + local end_col_offset = pred[6] or 0 + local key = pred[7] or "offset" + + range[1] = range[1] + start_row_offset + range[2] = range[2] + start_col_offset + range[3] = range[3] + end_row_offset + range[4] = range[4] + end_col_offset + + -- If this produces an invalid range, we just skip it. + if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then + metadata[pred[2]][key] = range + end + end +} + --- Adds a new predicates to be used in queries -- -- @param name the name of the predicate, without leading # @@ -229,12 +252,25 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"] -- signature will be (match, pattern, bufnr, predicate) function M.add_predicate(name, handler, force) if predicate_handlers[name] and not force then - a.nvim_err_writeln(string.format("Overriding %s", name)) + error(string.format("Overriding %s", name)) end predicate_handlers[name] = handler end +--- Adds a new directive to be used in queries +-- +-- @param name the name of the directive, without leading # +-- @param handler the handler function to be used +-- signature will be (match, pattern, bufnr, predicate) +function M.add_directive(name, handler, force) + if directive_handlers[name] and not force then + error(string.format("Overriding %s", name)) + end + + directive_handlers[name] = handler +end + --- Returns the list of currently supported predicates function M.list_predicates() return vim.tbl_keys(predicate_handlers) @@ -244,6 +280,10 @@ local function xor(x, y) return (x or y) and not (x and y) end +local function is_directive(name) + return string.sub(name, -1) == "!" +end + function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] @@ -254,30 +294,52 @@ function Query:match_preds(match, pattern, source) -- Also, tree-sitter strips the leading # from predicates for us. local pred_name local is_not - if string.sub(pred[1], 1, 4) == "not-" then - pred_name = string.sub(pred[1], 5) - is_not = true - else - pred_name = pred[1] - is_not = false - end - local handler = predicate_handlers[pred_name] + -- Skip over directives... they will get processed after all the predicates. + if not is_directive(pred[1]) then + if string.sub(pred[1], 1, 4) == "not-" then + pred_name = string.sub(pred[1], 5) + is_not = true + else + pred_name = pred[1] + is_not = false + end + + local handler = predicate_handlers[pred_name] - if not handler then - a.nvim_err_writeln(string.format("No handler for %s", pred[1])) - return false - end + if not handler then + error(string.format("No handler for %s", pred[1])) + return false + end - local pred_matches = handler(match, pattern, source, pred) + local pred_matches = handler(match, pattern, source, pred) - if not xor(is_not, pred_matches) then - return false + if not xor(is_not, pred_matches) then + return false + end end end return true end +--- Applies directives against a match and pattern. +function Query:apply_directives(match, pattern, source, metadata) + local preds = self.info.patterns[pattern] + + for _, pred in pairs(preds or {}) do + if is_directive(pred[1]) then + local handler = directive_handlers[pred[1]] + + if not handler then + error(string.format("No handler for %s", pred[1])) + return + end + + handler(match, pattern, source, pred, metadata) + end + end +end + --- Iterates of the captures of self on a given range. -- -- @param node The node under witch the search will occur @@ -294,14 +356,18 @@ function Query:iter_captures(node, source, start, stop) local raw_iter = node:_rawquery(self.query, true, start, stop) local function iter() local capture, captured_node, match = raw_iter() + local metadata = new_match_metadata() + if match ~= nil then local active = self:match_preds(match, match.pattern, source) match.active = active if not active then return iter() -- tail call: try next match end + + self:apply_directives(match, match.pattern, source, metadata) end - return capture, captured_node + return capture, captured_node, metadata end return iter end @@ -322,13 +388,17 @@ function Query:iter_matches(node, source, start, stop) local raw_iter = node:_rawquery(self.query, false, start, stop) local function iter() local pattern, match = raw_iter() + local metadata = new_match_metadata() + if match ~= nil then local active = self:match_preds(match, pattern, source) if not active then return iter() -- tail call: try next match end + + self:apply_directives(match, pattern, source, metadata) end - return pattern, match + return pattern, match, metadata end return iter end |