diff options
Diffstat (limited to 'runtime/lua/vim')
-rw-r--r-- | runtime/lua/vim/lsp.lua | 360 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/buf.lua | 135 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/diagnostic.lua | 16 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/handlers.lua | 70 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/log.lua | 2 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/rpc.lua | 11 | ||||
-rw-r--r-- | runtime/lua/vim/lsp/util.lua | 143 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter.lua | 41 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/health.lua | 38 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/highlighter.lua | 25 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/language.lua | 20 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/languagetree.lua | 226 | ||||
-rw-r--r-- | runtime/lua/vim/treesitter/query.lua | 215 |
13 files changed, 948 insertions, 354 deletions
diff --git a/runtime/lua/vim/lsp.lua b/runtime/lua/vim/lsp.lua index 563ffc479e..93ec9ed624 100644 --- a/runtime/lua/vim/lsp.lua +++ b/runtime/lua/vim/lsp.lua @@ -232,6 +232,12 @@ local function validate_client_config(config) flags = { config.flags, "t", true }; get_language_id = { config.get_language_id, "f", true }; } + assert( + (not config.flags + or not config.flags.debounce_text_changes + or type(config.flags.debounce_text_changes) == 'number'), + "flags.debounce_text_changes must be nil or a number with the debounce time in milliseconds" + ) local cmd, cmd_args = lsp._cmd_parts(config.cmd) local offset_encoding = valid_encodings.UTF16 @@ -260,21 +266,171 @@ local function buf_get_full_text(bufnr) end --@private +--- Memoizes a function. On first run, the function return value is saved and +--- immediately returned on subsequent runs. If the function returns a multival, +--- only the first returned value will be memoized and returned. The function will only be run once, +--- even if it has side-effects. +--- +--@param fn (function) Function to run +--@returns (function) Memoized function +local function once(fn) + local value + local ran = false + return function(...) + if not ran then + value = fn(...) + ran = true + end + return value + end +end + + +local changetracking = {} +do + --- client_id → state + --- + --- state + --- pending_change?: function that the timer starts to trigger didChange + --- pending_changes: list of tables with the pending changesets; for incremental_sync only + --- use_incremental_sync: bool + --- buffers?: table (bufnr → lines); for incremental sync only + --- timer?: uv_timer + local state_by_client = {} + + function changetracking.init(client, bufnr) + local state = state_by_client[client.id] + if not state then + state = { + pending_changes = {}; + use_incremental_sync = ( + if_nil(client.config.flags.allow_incremental_sync, true) + and client.resolved_capabilities.text_document_did_change == protocol.TextDocumentSyncKind.Incremental + ); + } + state_by_client[client.id] = state + end + if not state.use_incremental_sync then + return + end + if not state.buffers then + state.buffers = {} + end + state.buffers[bufnr] = nvim_buf_get_lines(bufnr, 0, -1, true) + end + + function changetracking.reset_buf(client, bufnr) + local state = state_by_client[client.id] + if state then + changetracking._reset_timer(state) + if state.buffers then + state.buffers[bufnr] = nil + end + end + end + + function changetracking.reset(client_id) + local state = state_by_client[client_id] + if state then + state_by_client[client_id] = nil + changetracking._reset_timer(state) + end + end + + function changetracking.prepare(bufnr, firstline, new_lastline, changedtick) + local incremental_changes = function(client) + local cached_buffers = state_by_client[client.id].buffers + local lines = nvim_buf_get_lines(bufnr, 0, -1, true) + local startline = math.min(firstline + 1, math.min(#cached_buffers[bufnr], #lines)) + local endline = math.min(-(#lines - new_lastline), -1) + local incremental_change = vim.lsp.util.compute_diff( + cached_buffers[bufnr], lines, startline, endline, client.offset_encoding or 'utf-16') + cached_buffers[bufnr] = lines + return incremental_change + end + local full_changes = once(function() + return { + text = buf_get_full_text(bufnr); + }; + end) + local uri = vim.uri_from_bufnr(bufnr) + return function(client) + if client.resolved_capabilities.text_document_did_change == protocol.TextDocumentSyncKind.None then + return + end + local state = state_by_client[client.id] + local debounce = client.config.flags.debounce_text_changes + if not debounce then + local changes = state.use_incremental_sync and incremental_changes(client) or full_changes() + client.notify("textDocument/didChange", { + textDocument = { + uri = uri; + version = changedtick; + }; + contentChanges = { changes, } + }) + return + end + changetracking._reset_timer(state) + if state.use_incremental_sync then + -- This must be done immediately and cannot be delayed + -- The contents would further change and startline/endline may no longer fit + table.insert(state.pending_changes, incremental_changes(client)) + end + state.pending_change = function() + state.pending_change = nil + if client.is_stopped() then + return + end + local contentChanges + if state.use_incremental_sync then + contentChanges = state.pending_changes + state.pending_changes = {} + else + contentChanges = { full_changes(), } + end + client.notify("textDocument/didChange", { + textDocument = { + uri = uri; + version = changedtick; + }; + contentChanges = contentChanges + }) + end + state.timer = vim.loop.new_timer() + -- Must use schedule_wrap because `full_changes()` calls nvim_buf_get_lines + state.timer:start(debounce, 0, vim.schedule_wrap(state.pending_change)) + end + end + + function changetracking._reset_timer(state) + if state.timer then + state.timer:stop() + state.timer:close() + state.timer = nil + end + end + + --- Flushes any outstanding change notification. + function changetracking.flush(client) + local state = state_by_client[client.id] + if state then + changetracking._reset_timer(state) + if state.pending_change then + state.pending_change() + end + end + end +end + + +--@private --- Default handler for the 'textDocument/didOpen' LSP notification. --- --@param bufnr (Number) Number of the buffer, or 0 for current --@param client Client object local function text_document_did_open_handler(bufnr, client) - local use_incremental_sync = ( - if_nil(client.config.flags.allow_incremental_sync, true) - and client.resolved_capabilities.text_document_did_change == protocol.TextDocumentSyncKind.Incremental - ) - if use_incremental_sync then - if not client._cached_buffers then - client._cached_buffers = {} - end - client._cached_buffers[bufnr] = nvim_buf_get_lines(bufnr, 0, -1, true) - end + changetracking.init(client, bufnr) if not client.resolved_capabilities.text_document_open_close then return end @@ -327,6 +483,13 @@ end --- result. You can use this with `client.cancel_request(request_id)` --- to cancel the request. --- +--- - request_sync(method, params, timeout_ms, bufnr) +--- Sends a request to the server and synchronously waits for the response. +--- This is a wrapper around {client.request} +--- Returns: { err=err, result=result }, a dictionary, where `err` and `result` come from +--- the |lsp-handler|. On timeout, cancel or error, returns `(nil, err)` where `err` is a +--- string describing the failure reason. If the request was unsuccessful returns `nil`. +--- --- - notify(method, params) --- Sends a notification to an LSP server. --- Returns: a boolean to indicate if the notification was successful. If @@ -469,6 +632,9 @@ end --- server in the initialize request. Invalid/empty values will default to "off" --@param flags: A table with flags for the client. The current (experimental) flags are: --- - allow_incremental_sync (bool, default true): Allow using incremental sync for buffer edits +--- - debounce_text_changes (number, default nil): Debounce didChange +--- notifications to the server by the given number in milliseconds. No debounce +--- occurs if nil --- --@returns Client id. |vim.lsp.get_client_by_id()| Note: client may not be --- fully initialized. Use `on_init` to do any actions once @@ -563,6 +729,7 @@ function lsp.start_client(config) uninitialized_clients[client_id] = nil lsp.diagnostic.reset(client_id, all_buffer_active_clients) + changetracking.reset(client_id) all_client_active_buffers[client_id] = nil for _, client_ids in pairs(all_buffer_active_clients) do client_ids[client_id] = nil @@ -721,6 +888,9 @@ function lsp.start_client(config) handler = resolve_handler(method) or error(string.format("not found: %q request handler for client %q.", method, client.name)) end + -- Ensure pending didChange notifications are sent so that the server doesn't operate on a stale state + changetracking.flush(client) + local _ = log.debug() and log.debug(log_prefix, "client.request", client_id, method, params, handler, bufnr) return rpc.request(method, params, function(err, result) handler(err, method, result, client_id, bufnr) @@ -728,6 +898,42 @@ function lsp.start_client(config) end --@private + --- Sends a request to the server and synchronously waits for the response. + --- + --- This is a wrapper around {client.request} + --- + --@param method (string) LSP method name. + --@param params (table) LSP request params. + --@param timeout_ms (number, optional, default=1000) Maximum time in + ---milliseconds to wait for a result. + --@param bufnr (number) Buffer handle (0 for current). + --@returns { err=err, result=result }, a dictionary, where `err` and `result` come from the |lsp-handler|. + ---On timeout, cancel or error, returns `(nil, err)` where `err` is a + ---string describing the failure reason. If the request was unsuccessful + ---returns `nil`. + --@see |vim.lsp.buf_request_sync()| + function client.request_sync(method, params, timeout_ms, bufnr) + local request_result = nil + local function _sync_handler(err, _, result) + request_result = { err = err, result = result } + end + + local success, request_id = client.request(method, params, _sync_handler, + bufnr) + if not success then return nil end + + local wait_result, reason = vim.wait(timeout_ms or 1000, function() + return request_result ~= nil + end, 10) + + if not wait_result then + client.cancel_request(request_id) + return nil, wait_result_reason[reason] + end + return request_result + end + + --@private --- Sends a notification to an LSP server. --- --@param method (string) LSP method name. @@ -753,7 +959,7 @@ function lsp.start_client(config) -- Track this so that we can escalate automatically if we've alredy tried a -- graceful shutdown - local tried_graceful_shutdown = false + local graceful_shutdown_failed = false --@private --- Stops a client, optionally with force. --- @@ -765,6 +971,7 @@ function lsp.start_client(config) function client.stop(force) lsp.diagnostic.reset(client_id, all_buffer_active_clients) + changetracking.reset(client_id) all_client_active_buffers[client_id] = nil for _, client_ids in pairs(all_buffer_active_clients) do client_ids[client_id] = nil @@ -774,11 +981,10 @@ function lsp.start_client(config) if handle:is_closing() then return end - if force or (not client.initialized) or tried_graceful_shutdown then + if force or (not client.initialized) or graceful_shutdown_failed then handle:kill(15) return end - tried_graceful_shutdown = true -- Sending a signal after a process has exited is acceptable. rpc.request('shutdown', nil, function(err, _) if err == nil then @@ -786,6 +992,7 @@ function lsp.start_client(config) else -- If there was an error in the shutdown request, then term to be safe. handle:kill(15) + graceful_shutdown_failed = true end end) end @@ -816,20 +1023,6 @@ function lsp.start_client(config) end --@private ---- Memoizes a function. On first run, the function return value is saved and ---- immediately returned on subsequent runs. ---- ---@param fn (function) Function to run ---@returns (function) Memoized function -local function once(fn) - local value - return function(...) - if not value then value = fn(...) end - return value - end -end - ---@private --@fn text_document_did_change_handler(_, bufnr, changedtick, firstline, lastline, new_lastline, old_byte_size, old_utf32_size, old_utf16_size) --- Notify all attached clients that a buffer has changed. local text_document_did_change_handler @@ -848,45 +1041,9 @@ do if tbl_isempty(all_buffer_active_clients[bufnr] or {}) then return end - util.buf_versions[bufnr] = changedtick - - local incremental_changes = function(client) - local lines = nvim_buf_get_lines(bufnr, 0, -1, true) - local startline = math.min(firstline + 1, math.min(#client._cached_buffers[bufnr], #lines)) - local endline = math.min(-(#lines - new_lastline), -1) - local incremental_change = vim.lsp.util.compute_diff( - client._cached_buffers[bufnr], lines, startline, endline, client.offset_encoding or "utf-16") - client._cached_buffers[bufnr] = lines - return incremental_change - end - - local full_changes = once(function() - return { - text = buf_get_full_text(bufnr); - }; - end) - - local uri = vim.uri_from_bufnr(bufnr) - for_each_buffer_client(bufnr, function(client) - local allow_incremental_sync = if_nil(client.config.flags.allow_incremental_sync, true) - local text_document_did_change = client.resolved_capabilities.text_document_did_change - local changes - if text_document_did_change == protocol.TextDocumentSyncKind.None then - return - elseif not allow_incremental_sync or text_document_did_change == protocol.TextDocumentSyncKind.Full then - changes = full_changes(client) - elseif text_document_did_change == protocol.TextDocumentSyncKind.Incremental then - changes = incremental_changes(client) - end - client.notify("textDocument/didChange", { - textDocument = { - uri = uri; - version = changedtick; - }; - contentChanges = { changes; } - }) - end) + local compute_change_and_notify = changetracking.prepare(bufnr, firstline, new_lastline, changedtick) + for_each_buffer_client(bufnr, compute_change_and_notify) end end @@ -956,9 +1113,7 @@ function lsp.buf_attach_client(bufnr, client_id) if client.resolved_capabilities.text_document_open_close then client.notify('textDocument/didClose', params) end - if client._cached_buffers then - client._cached_buffers[bufnr] = nil - end + changetracking.reset_buf(client, bufnr) end) util.buf_versions[bufnr] = nil all_buffer_active_clients[bufnr] = nil @@ -1133,42 +1288,77 @@ function lsp.buf_request(bufnr, method, params, handler) return client_request_ids, _cancel_all_requests end ---- Sends a request to a server and waits for the response. +---Sends an async request for all active clients attached to the buffer. +---Executes the callback on the combined result. +---Parameters are the same as |vim.lsp.buf_request()| but the return result and callback are +---different. --- ---- Calls |vim.lsp.buf_request()| but blocks Nvim while awaiting the result. +--@param bufnr (number) Buffer handle, or 0 for current. +--@param method (string) LSP method name +--@param params (optional, table) Parameters to send to the server +--@param callback (function) The callback to call when all requests are finished. +-- Unlike `buf_request`, this will collect all the responses from each server instead of handling them. +-- A map of client_id:request_result will be provided to the callback +-- +--@returns (function) A function that will cancel all requests which is the same as the one returned from `buf_request`. +function lsp.buf_request_all(bufnr, method, params, callback) + local request_results = {} + local result_count = 0 + local expected_result_count = 0 + local cancel, client_request_ids + + local set_expected_result_count = once(function() + for _ in pairs(client_request_ids) do + expected_result_count = expected_result_count + 1 + end + end) + + local function _sync_handler(err, _, result, client_id) + request_results[client_id] = { error = err, result = result } + result_count = result_count + 1 + set_expected_result_count() + + if result_count >= expected_result_count then + callback(request_results) + end + end + + client_request_ids, cancel = lsp.buf_request(bufnr, method, params, _sync_handler) + + return cancel +end + +--- Sends a request to all server and waits for the response of all of them. +--- +--- Calls |vim.lsp.buf_request_all()| but blocks Nvim while awaiting the result. --- Parameters are the same as |vim.lsp.buf_request()| but the return result is ---- different. Wait maximum of {timeout_ms} (default 100) ms. +--- different. Wait maximum of {timeout_ms} (default 1000) ms. --- --@param bufnr (number) Buffer handle, or 0 for current. --@param method (string) LSP method name --@param params (optional, table) Parameters to send to the server ---@param timeout_ms (optional, number, default=100) Maximum time in +--@param timeout_ms (optional, number, default=1000) Maximum time in --- milliseconds to wait for a result. --- --@returns Map of client_id:request_result. On timeout, cancel or error, --- returns `(nil, err)` where `err` is a string describing the failure --- reason. function lsp.buf_request_sync(bufnr, method, params, timeout_ms) - local request_results = {} - local result_count = 0 - local function _sync_handler(err, _, result, client_id) - request_results[client_id] = { error = err, result = result } - result_count = result_count + 1 - end - local client_request_ids, cancel = lsp.buf_request(bufnr, method, params, _sync_handler) - local expected_result_count = 0 - for _ in pairs(client_request_ids) do - expected_result_count = expected_result_count + 1 - end + local request_results + + local cancel = lsp.buf_request_all(bufnr, method, params, function(it) + request_results = it + end) - local wait_result, reason = vim.wait(timeout_ms or 100, function() - return result_count >= expected_result_count + local wait_result, reason = vim.wait(timeout_ms or 1000, function() + return request_results ~= nil end, 10) if not wait_result then cancel() return nil, wait_result_reason[reason] end + return request_results end diff --git a/runtime/lua/vim/lsp/buf.lua b/runtime/lua/vim/lsp/buf.lua index 31116985e2..5dd7109bb0 100644 --- a/runtime/lua/vim/lsp/buf.lua +++ b/runtime/lua/vim/lsp/buf.lua @@ -111,6 +111,39 @@ function M.completion(context) return request('textDocument/completion', params) end +--@private +--- If there is more than one client that supports the given method, +--- asks the user to select one. +-- +--@returns The client that the user selected or nil +local function select_client(method) + local clients = vim.tbl_values(vim.lsp.buf_get_clients()); + clients = vim.tbl_filter(function (client) + return client.supports_method(method) + end, clients) + -- better UX when choices are always in the same order (between restarts) + table.sort(clients, function (a, b) return a.name < b.name end) + + if #clients > 1 then + local choices = {} + for k,v in ipairs(clients) do + table.insert(choices, string.format("%d %s", k, v.name)) + end + local user_choice = vim.fn.confirm( + "Select a language server:", + table.concat(choices, "\n"), + 0, + "Question" + ) + if user_choice == 0 then return nil end + return clients[user_choice] + elseif #clients < 1 then + return nil + else + return clients[1] + end +end + --- Formats the current buffer. --- --@param options (optional, table) Can be used to specify FormattingOptions. @@ -119,8 +152,11 @@ end -- --@see https://microsoft.github.io/language-server-protocol/specification#textDocument_formatting function M.formatting(options) + local client = select_client("textDocument/formatting") + if client == nil then return end + local params = util.make_formatting_params(options) - return request('textDocument/formatting', params) + return client.request("textDocument/formatting", params) end --- Performs |vim.lsp.buf.formatting()| synchronously. @@ -134,14 +170,62 @@ end --- --@param options Table with valid `FormattingOptions` entries --@param timeout_ms (number) Request timeout +--@see |vim.lsp.buf.formatting_seq_sync| function M.formatting_sync(options, timeout_ms) + local client = select_client("textDocument/formatting") + if client == nil then return end + local params = util.make_formatting_params(options) - local result = vim.lsp.buf_request_sync(0, "textDocument/formatting", params, timeout_ms) - if not result or vim.tbl_isempty(result) then return end - local _, formatting_result = next(result) - result = formatting_result.result - if not result then return end - vim.lsp.util.apply_text_edits(result) + local result, err = client.request_sync("textDocument/formatting", params, timeout_ms) + if result and result.result then + util.apply_text_edits(result.result) + elseif err then + vim.notify("vim.lsp.buf.formatting_sync: " .. err, vim.log.levels.WARN) + end +end + +--- Formats the current buffer by sequentially requesting formatting from attached clients. +--- +--- Useful when multiple clients with formatting capability are attached. +--- +--- Since it's synchronous, can be used for running on save, to make sure buffer is formatted +--- prior to being saved. {timeout_ms} is passed on to the |vim.lsp.client| `request_sync` method. +--- Example: +--- <pre> +--- vim.api.nvim_command[[autocmd BufWritePre <buffer> lua vim.lsp.buf.formatting_seq_sync()]] +--- </pre> +--- +--@param options (optional, table) `FormattingOptions` entries +--@param timeout_ms (optional, number) Request timeout +--@param order (optional, table) List of client names. Formatting is requested from clients +---in the following order: first all clients that are not in the `order` list, then +---the remaining clients in the order as they occur in the `order` list. +function M.formatting_seq_sync(options, timeout_ms, order) + local clients = vim.tbl_values(vim.lsp.buf_get_clients()); + + -- sort the clients according to `order` + for _, client_name in ipairs(order or {}) do + -- if the client exists, move to the end of the list + for i, client in ipairs(clients) do + if client.name == client_name then + table.insert(clients, table.remove(clients, i)) + break + end + end + end + + -- loop through the clients and make synchronous formatting requests + for _, client in ipairs(clients) do + if client.resolved_capabilities.document_formatting then + local params = util.make_formatting_params(options) + local result, err = client.request_sync("textDocument/formatting", params, timeout_ms) + if result and result.result then + util.apply_text_edits(result.result) + elseif err then + vim.notify(string.format("vim.lsp.buf.formatting_seq_sync: (%s) %s", client.name, err), vim.log.levels.WARN) + end + end + end end --- Formats a given range. @@ -152,15 +236,12 @@ end --@param end_pos ({number, number}, optional) mark-indexed position. ---Defaults to the end of the last visual selection. function M.range_formatting(options, start_pos, end_pos) - validate { options = {options, 't', true} } - local sts = vim.bo.softtabstop; - options = vim.tbl_extend('keep', options or {}, { - tabSize = (sts > 0 and sts) or (sts < 0 and vim.bo.shiftwidth) or vim.bo.tabstop; - insertSpaces = vim.bo.expandtab; - }) + local client = select_client("textDocument/rangeFormatting") + if client == nil then return end + local params = util.make_given_range_params(start_pos, end_pos) - params.options = options - return request('textDocument/rangeFormatting', params) + params.options = util.make_formatting_params(options).options + return client.request("textDocument/rangeFormatting", params) end --- Renames all references to the symbol under the cursor. @@ -216,26 +297,30 @@ local function pick_call_hierarchy_item(call_hierarchy_items) return choice end +local function call_hierarchy(method) + local params = util.make_position_params() + request('textDocument/prepareCallHierarchy', params, function(err, _, result) + if err then + vim.notify(err.message, vim.log.levels.WARN) + return + end + local call_hierarchy_item = pick_call_hierarchy_item(result) + vim.lsp.buf_request(0, method, { item = call_hierarchy_item }) + end) +end + --- Lists all the call sites of the symbol under the cursor in the --- |quickfix| window. If the symbol can resolve to multiple --- items, the user can pick one in the |inputlist|. function M.incoming_calls() - local params = util.make_position_params() - request('textDocument/prepareCallHierarchy', params, function(_, _, result) - local call_hierarchy_item = pick_call_hierarchy_item(result) - vim.lsp.buf_request(0, 'callHierarchy/incomingCalls', { item = call_hierarchy_item }) - end) + call_hierarchy('callHierarchy/incomingCalls') end --- Lists all the items that are called by the symbol under the --- cursor in the |quickfix| window. If the symbol can resolve to --- multiple items, the user can pick one in the |inputlist|. function M.outgoing_calls() - local params = util.make_position_params() - request('textDocument/prepareCallHierarchy', params, function(_, _, result) - local call_hierarchy_item = pick_call_hierarchy_item(result) - vim.lsp.buf_request(0, 'callHierarchy/outgoingCalls', { item = call_hierarchy_item }) - end) + call_hierarchy('callHierarchy/outgoingCalls') end --- List workspace folders. diff --git a/runtime/lua/vim/lsp/diagnostic.lua b/runtime/lua/vim/lsp/diagnostic.lua index 4e82c46fef..6f2f846a3b 100644 --- a/runtime/lua/vim/lsp/diagnostic.lua +++ b/runtime/lua/vim/lsp/diagnostic.lua @@ -406,9 +406,7 @@ function M.get_line_diagnostics(bufnr, line_nr, opts, client_id) line_diagnostics = filter_by_severity_limit(opts.severity_limit, line_diagnostics) end - if opts.severity_sort then - table.sort(line_diagnostics, function(a, b) return a.severity < b.severity end) - end + table.sort(line_diagnostics, function(a, b) return a.severity < b.severity end) return line_diagnostics end @@ -997,6 +995,8 @@ end --- - See |vim.lsp.diagnostic.set_signs()| --- - update_in_insert: (default=false) --- - Update diagnostics in InsertMode or wait until InsertLeave +--- - severity_sort: (default=false) +--- - Sort diagnostics (and thus signs and virtual text) function M.on_publish_diagnostics(_, _, params, client_id, _, config) local uri = params.uri local bufnr = vim.uri_to_bufnr(uri) @@ -1007,6 +1007,10 @@ function M.on_publish_diagnostics(_, _, params, client_id, _, config) local diagnostics = params.diagnostics + if config and if_nil(config.severity_sort, false) then + table.sort(diagnostics, function(a, b) return a.severity > b.severity end) + end + -- Always save the diagnostics, even if the buf is not loaded. -- Language servers may report compile or build errors via diagnostics -- Users should be able to find these, even if they're in files which @@ -1034,6 +1038,7 @@ function M.display(diagnostics, bufnr, client_id, config) underline = true, virtual_text = true, update_in_insert = false, + severity_sort = false, }, config) -- TODO(tjdevries): Consider how we can make this a "standardized" kind of thing for |lsp-handlers|. @@ -1116,7 +1121,6 @@ end ---@return table {popup_bufnr, win_id} function M.show_line_diagnostics(opts, bufnr, line_nr, client_id) opts = opts or {} - opts.severity_sort = if_nil(opts.severity_sort, true) local show_header = if_nil(opts.show_header, true) @@ -1140,14 +1144,14 @@ function M.show_line_diagnostics(opts, bufnr, line_nr, client_id) local message_lines = vim.split(diagnostic.message, '\n', true) table.insert(lines, prefix..message_lines[1]) - table.insert(highlights, {#prefix + 1, hiname}) + table.insert(highlights, {#prefix, hiname}) for j = 2, #message_lines do table.insert(lines, message_lines[j]) table.insert(highlights, {0, hiname}) end end - local popup_bufnr, winnr = util.open_floating_preview(lines, 'plaintext') + local popup_bufnr, winnr = util.open_floating_preview(lines, 'plaintext', opts) for i, hi in ipairs(highlights) do local prefixlen, hiname = unpack(hi) -- Start highlight after the prefix diff --git a/runtime/lua/vim/lsp/handlers.lua b/runtime/lua/vim/lsp/handlers.lua index eacbd90077..18155ceb7e 100644 --- a/runtime/lua/vim/lsp/handlers.lua +++ b/runtime/lua/vim/lsp/handlers.lua @@ -245,9 +245,22 @@ M['textDocument/completion'] = function(_, _, result) vim.fn.complete(textMatch+1, matches) end ---@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_hover -M['textDocument/hover'] = function(_, method, result) - util.focusable_float(method, function() +--- |lsp-handler| for the method "textDocument/hover" +--- <pre> +--- vim.lsp.handlers["textDocument/hover"] = vim.lsp.with( +--- vim.lsp.handlers.hover, { +--- -- Use a sharp border with `FloatBorder` highlights +--- border = "single" +--- } +--- ) +--- </pre> +---@param config table Configuration table. +--- - border: (default=nil) +--- - Add borders to the floating window +--- - See |vim.api.nvim_open_win()| +function M.hover(_, method, result, _, _, config) + config = config or {} + local bufnr, winnr = util.focusable_float(method, function() if not (result and result.contents) then -- return { 'No information available' } return @@ -258,14 +271,16 @@ M['textDocument/hover'] = function(_, method, result) -- return { 'No information available' } return end - local bufnr, winnr = util.fancy_floating_markdown(markdown_lines, { - pad_left = 1; pad_right = 1; - }) + local bufnr, winnr = util.fancy_floating_markdown(markdown_lines, config) util.close_preview_autocmd({"CursorMoved", "BufHidden", "InsertCharPre"}, winnr) return bufnr, winnr end) + return bufnr, winnr end +--@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_hover +M['textDocument/hover'] = M.hover + --@private --- Jumps to a location. Used as a handler for multiple LSP methods. --@param _ (not used) @@ -303,27 +318,46 @@ M['textDocument/typeDefinition'] = location_handler --@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_implementation M['textDocument/implementation'] = location_handler ---@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_signatureHelp -M['textDocument/signatureHelp'] = function(_, method, result, _, bufnr) +--- |lsp-handler| for the method "textDocument/signatureHelp" +--- <pre> +--- vim.lsp.handlers["textDocument/signatureHelp"] = vim.lsp.with( +--- vim.lsp.handlers.signature_help, { +--- -- Use a sharp border with `FloatBorder` highlights +--- border = "single" +--- } +--- ) +--- </pre> +---@param config table Configuration table. +--- - border: (default=nil) +--- - Add borders to the floating window +--- - See |vim.api.nvim_open_win()| +function M.signature_help(_, method, result, _, bufnr, config) + config = config or {} -- When use `autocmd CompleteDone <silent><buffer> lua vim.lsp.buf.signature_help()` to call signatureHelp handler -- If the completion item doesn't have signatures It will make noise. Change to use `print` that can use `<silent>` to ignore if not (result and result.signatures and result.signatures[1]) then print('No signature help available') return end - local lines = util.convert_signature_help_to_markdown_lines(result) - lines = util.trim_empty_lines(lines) - if vim.tbl_isempty(lines) then - print('No signature help available') - return - end - local syntax = api.nvim_buf_get_option(bufnr, 'syntax') - local p_bufnr, _ = util.focusable_preview(method, function() - return lines, util.try_trim_markdown_code_blocks(lines) + local p_bufnr, winnr = util.focusable_float(method, function() + local ft = api.nvim_buf_get_option(bufnr, 'filetype') + local lines = util.convert_signature_help_to_markdown_lines(result, ft) + lines = util.trim_empty_lines(lines) + if vim.tbl_isempty(lines) then + print('No signature help available') + return + end + local p_bufnr, p_winnr = util.fancy_floating_markdown(lines, config) + util.close_preview_autocmd({"CursorMoved", "CursorMovedI", "BufHidden", "InsertCharPre"}, p_winnr) + + return p_bufnr, p_winnr end) - api.nvim_buf_set_option(p_bufnr, 'syntax', syntax) + return p_bufnr, winnr end +--@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_signatureHelp +M['textDocument/signatureHelp'] = M.signature_help + --@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_documentHighlight M['textDocument/documentHighlight'] = function(_, _, result, _, bufnr, _) if not result then return end diff --git a/runtime/lua/vim/lsp/log.lua b/runtime/lua/vim/lsp/log.lua index 331e980e67..471a311c16 100644 --- a/runtime/lua/vim/lsp/log.lua +++ b/runtime/lua/vim/lsp/log.lua @@ -10,7 +10,7 @@ local log = {} -- Can be used to lookup the number from the name or the name from the number. -- Levels by name: 'trace', 'debug', 'info', 'warn', 'error' -- Level numbers begin with 'trace' at 0 -log.levels = vim.log.levels +log.levels = vim.deepcopy(vim.log.levels) -- Default log level is warn. local current_log_level = log.levels.WARN diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index 4e2dd7c8e8..0cabd1a0d4 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -315,8 +315,10 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params) dispatchers = { dispatchers, 't', true }; } - if not (vim.fn.executable(cmd) == 1) then - error(string.format("The given command %q is not executable.", cmd)) + if extra_spawn_params and extra_spawn_params.cwd then + assert(is_dir(extra_spawn_params.cwd), "cwd must be a directory") + elseif not (vim.fn.executable(cmd) == 1) then + error(string.format("The given command %q is not executable.", cmd)) end if dispatchers then local user_dispatchers = dispatchers @@ -370,9 +372,6 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params) } if extra_spawn_params then spawn_params.cwd = extra_spawn_params.cwd - if spawn_params.cwd then - assert(is_dir(spawn_params.cwd), "cwd must be a directory") - end spawn_params.env = env_merge(extra_spawn_params.env) end handle, pid = uv.spawn(cmd, spawn_params, onexit) @@ -519,7 +518,7 @@ local function start(cmd, cmd_args, dispatchers, extra_spawn_params) send_response(decoded.id, err, result) end) -- This works because we are expecting vim.NIL here - elseif decoded.id and (decoded.result or decoded.error) then + elseif decoded.id and (decoded.result ~= vim.NIL or decoded.error ~= vim.NIL) then -- Server Result decoded.error = convert_NIL(decoded.error) decoded.result = convert_NIL(decoded.result) diff --git a/runtime/lua/vim/lsp/util.lua b/runtime/lua/vim/lsp/util.lua index 8627e0f3fb..01fa22999a 100644 --- a/runtime/lua/vim/lsp/util.lua +++ b/runtime/lua/vim/lsp/util.lua @@ -18,6 +18,40 @@ end local M = {} +local default_border = { + {"", "NormalFloat"}, + {"", "NormalFloat"}, + {"", "NormalFloat"}, + {" ", "NormalFloat"}, + {"", "NormalFloat"}, + {"", "NormalFloat"}, + {"", "NormalFloat"}, + {" ", "NormalFloat"}, +} + +--@private +-- Check the border given by opts or the default border for the additional +-- size it adds to a float. +--@returns size of border in height and width +local function get_border_size(opts) + local border = opts and opts.border or default_border + local height = 0 + local width = 0 + + if type(border) == 'string' then + -- 'single', 'double', etc. + height = 2 + width = 2 + else + height = height + vim.fn.strdisplaywidth(border[2][1]) -- top + height = height + vim.fn.strdisplaywidth(border[6][1]) -- bottom + width = width + vim.fn.strdisplaywidth(border[4][1]) -- right + width = width + vim.fn.strdisplaywidth(border[8][1]) -- left + end + + return { height = height, width = width } +end + --@private local function split_lines(value) return split(value, '\n', true) @@ -39,7 +73,7 @@ function M.set_lines(lines, A, B, new_lines) -- way the LSP describes the range including the last newline is by -- specifying a line number after what we would call the last line. local i_n = math.min(B[1] + 1, #lines) - if not (i_0 >= 1 and i_0 <= #lines and i_n >= 1 and i_n <= #lines) then + if not (i_0 >= 1 and i_0 <= #lines + 1 and i_n >= 1 and i_n <= #lines) then error("Invalid range: "..vim.inspect{A = A; B = B; #lines, new_lines}) end local prefix = "" @@ -436,6 +470,7 @@ function M.apply_text_document_edit(text_document_edit, index) -- `VersionedTextDocumentIdentifier`s version may be null -- https://microsoft.github.io/language-server-protocol/specification#versionedTextDocumentIdentifier if should_check_version and (text_document.version + and text_document.version > 0 and M.buf_versions[bufnr] and M.buf_versions[bufnr] > text_document.version) then print("Buffer ", text_document.uri, " newer than edits.") @@ -531,13 +566,15 @@ end --@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_completion local function get_completion_word(item) if item.textEdit ~= nil and item.textEdit.newText ~= nil and item.textEdit.newText ~= "" then - if protocol.InsertTextFormat[item.insertTextFormat] == "PlainText" then + local insert_text_format = protocol.InsertTextFormat[item.insertTextFormat] + if insert_text_format == "PlainText" or insert_text_format == nil then return item.textEdit.newText else return M.parse_snippet(item.textEdit.newText) end elseif item.insertText ~= nil and item.insertText ~= "" then - if protocol.InsertTextFormat[item.insertTextFormat] == "PlainText" then + local insert_text_format = protocol.InsertTextFormat[item.insertTextFormat] + if insert_text_format == "PlainText" or insert_text_format == nil then return item.insertText else return M.parse_snippet(item.insertText) @@ -767,9 +804,10 @@ end --- Converts `textDocument/SignatureHelp` response to markdown lines. --- --@param signature_help Response of `textDocument/SignatureHelp` +--@param ft optional filetype that will be use as the `lang` for the label markdown code block --@returns list of lines of converted markdown. --@see https://microsoft.github.io/language-server-protocol/specifications/specification-current/#textDocument_signatureHelp -function M.convert_signature_help_to_markdown_lines(signature_help) +function M.convert_signature_help_to_markdown_lines(signature_help, ft) if not signature_help.signatures then return end @@ -787,7 +825,12 @@ function M.convert_signature_help_to_markdown_lines(signature_help) if not signature then return end - vim.list_extend(contents, vim.split(signature.label, '\n', true)) + local label = signature.label + if ft then + -- wrap inside a code block so fancy_markdown can render it properly + label = ("```%s\n%s\n```"):format(ft, label) + end + vim.list_extend(contents, vim.split(label, '\n', true)) if signature.documentation then M.convert_input_to_markdown_lines(signature.documentation, contents) end @@ -856,7 +899,7 @@ function M.make_floating_popup_options(width, height, opts) else anchor = anchor..'S' height = math.min(lines_above, height) - row = 0 + row = -get_border_size(opts).height end if vim.fn.wincol() + width <= api.nvim_get_option('columns') then @@ -875,6 +918,7 @@ function M.make_floating_popup_options(width, height, opts) row = row + (opts.offset_y or 0), style = 'minimal', width = width, + border = opts.border or default_border, } end @@ -913,7 +957,7 @@ end --- --@param location a single `Location` or `LocationLink` --@returns (bufnr,winnr) buffer and window number of floating window or nil -function M.preview_location(location) +function M.preview_location(location, opts) -- location may be LocationLink or Location (more useful for the former) local uri = location.targetUri or location.uri if uri == nil then return end @@ -924,7 +968,13 @@ function M.preview_location(location) local range = location.targetRange or location.range local contents = api.nvim_buf_get_lines(bufnr, range.start.line, range["end"].line+1, false) local syntax = api.nvim_buf_get_option(bufnr, 'syntax') - return M.open_floating_preview(contents, syntax) + if syntax == "" then + -- When no syntax is set, we use filetype as fallback. This might not result + -- in a valid syntax definition. See also ft detection in fancy_floating_win. + -- An empty syntax is more common now with TreeSitter, since TS disables syntax. + syntax = api.nvim_buf_get_option(bufnr, 'filetype') + end + return M.open_floating_preview(contents, syntax, opts) end --@private @@ -981,27 +1031,20 @@ function M.focusable_preview(unique_name, fn) end) end ---- Trims empty lines from input and pad left and right with spaces +--- Trims empty lines from input and pad top and bottom with empty lines --- ---@param contents table of lines to trim and pad ---@param opts dictionary with optional fields ---- - pad_left number of columns to pad contents at left (default 1) ---- - pad_right number of columns to pad contents at right (default 1) --- - pad_top number of lines to pad contents at top (default 0) --- - pad_bottom number of lines to pad contents at bottom (default 0) ---@return contents table of trimmed and padded lines -function M._trim_and_pad(contents, opts) +function M._trim(contents, opts) validate { contents = { contents, 't' }; opts = { opts, 't', true }; } opts = opts or {} - local left_padding = (" "):rep(opts.pad_left or 1) - local right_padding = (" "):rep(opts.pad_right or 1) contents = M.trim_empty_lines(contents) - for i, line in ipairs(contents) do - contents[i] = string.format('%s%s%s', left_padding, line:gsub("\r", ""), right_padding) - end if opts.pad_top then for _ = 1, opts.pad_top do table.insert(contents, 1, "") @@ -1055,12 +1098,16 @@ function M.fancy_floating_markdown(contents, opts) local ft = line:match("^```([a-zA-Z0-9_]*)$") -- local ft = line:match("^```(.*)$") -- TODO(ashkan): validate the filetype here. + local is_pre = line:match("^%s*<pre>%s*$") + if is_pre then + ft = "" + end if ft then local start = #stripped i = i + 1 while i <= #contents do line = contents[i] - if line == "```" then + if line == "```" or (is_pre and line:match("^%s*</pre>%s*$")) then i = i + 1 break end @@ -1078,8 +1125,8 @@ function M.fancy_floating_markdown(contents, opts) end end end - -- Clean up and add padding - stripped = M._trim_and_pad(stripped, opts) + -- Clean up + stripped = M._trim(stripped, opts) -- Compute size of float needed to show (wrapped) lines opts.wrap_at = opts.wrap_at or (vim.wo["wrap"] and api.nvim_win_get_width(0)) @@ -1089,12 +1136,19 @@ function M.fancy_floating_markdown(contents, opts) local insert_separator = opts.separator if insert_separator == nil then insert_separator = true end if insert_separator then - for i, h in ipairs(highlights) do - h.start = h.start + i - 1 - h.finish = h.finish + i - 1 + local offset = 0 + for _, h in ipairs(highlights) do + h.start = h.start + offset + h.finish = h.finish + offset + -- check if a seperator already exists and use that one instead of creating a new one if h.finish + 1 <= #stripped then - table.insert(stripped, h.finish + 1, string.rep("─", math.min(width, opts.wrap_at or width))) - height = height + 1 + if stripped[h.finish + 1]:match("^---+$") then + stripped[h.finish + 1] = string.rep("─", math.min(width, opts.wrap_at or width)) + else + table.insert(stripped, h.finish + 1, string.rep("─", math.min(width, opts.wrap_at or width))) + offset = offset + 1 + height = height + 1 + end end end end @@ -1113,27 +1167,28 @@ function M.fancy_floating_markdown(contents, opts) api.nvim_win_set_option(winnr, 'concealcursor', 'n') vim.cmd("ownsyntax lsp_markdown") + local idx = 1 --@private local function apply_syntax_to_region(ft, start, finish) - if ft == '' then return end + if ft == "" then + vim.cmd(string.format("syntax region markdownCode start=+\\%%%dl+ end=+\\%%%dl+ keepend extend", start, finish + 1)) + return + end local name = ft..idx idx = idx + 1 local lang = "@"..ft:upper() + -- HACK: reset current_syntax, since some syntax files like markdown won't load if it is already set + pcall(vim.api.nvim_buf_del_var, bufnr, "current_syntax") -- TODO(ashkan): better validation before this. if not pcall(vim.cmd, string.format("syntax include %s syntax/%s.vim", lang, ft)) then return end - vim.cmd(string.format("syntax region %s start=+\\%%%dl+ end=+\\%%%dl+ contains=%s", name, start, finish + 1, lang)) + vim.cmd(string.format("syntax region %s start=+\\%%%dl+ end=+\\%%%dl+ contains=%s keepend", name, start, finish + 1, lang)) end - -- Previous highlight region. - -- TODO(ashkan): this wasn't working for some reason, but I would like to - -- make sure that regions between code blocks are definitely markdown. - -- local ph = {start = 0; finish = 1;} + for _, h in ipairs(highlights) do - -- apply_syntax_to_region('markdown', ph.finish, h.start) apply_syntax_to_region(h.ft, h.start, h.finish) - -- ph = h end vim.api.nvim_set_current_win(cwin) @@ -1182,6 +1237,20 @@ function M._make_floating_popup_size(contents, opts) width = math.max(line_widths[i], width) end end + + local border_width = get_border_size(opts).width + local screen_width = api.nvim_win_get_width(0) + width = math.min(width, screen_width) + + -- make sure borders are always inside the screen + if width + border_width > screen_width then + width = width - (width + border_width - screen_width) + end + + if wrap_at and wrap_at > width then + wrap_at = width + end + if max_width then width = math.min(width, max_width) wrap_at = math.min(wrap_at or max_width, max_width) @@ -1235,7 +1304,7 @@ function M.open_floating_preview(contents, syntax, opts) opts = opts or {} -- Clean up input: trim empty lines from the end, pad - contents = M._trim_and_pad(contents, opts) + contents = M._trim(contents, opts) -- Compute size of float needed to show (wrapped) lines opts.wrap_at = opts.wrap_at or (vim.wo["wrap"] and api.nvim_win_get_width(0)) @@ -1521,6 +1590,12 @@ function M.make_given_range_params(start_pos, end_pos) if B[2] > 0 then B = {B[1], M.character_offset(0, B[1], B[2])} end + -- we need to offset the end character position otherwise we loose the last + -- character of the selection, as LSP end position is exclusive + -- see https://microsoft.github.io/language-server-protocol/specification#range + if vim.o.selection ~= 'exclusive' then + B[2] = B[2] + 1 + end return { textDocument = M.make_text_document_params(), range = { diff --git a/runtime/lua/vim/treesitter.lua b/runtime/lua/vim/treesitter.lua index cac0ab864b..de997b2d86 100644 --- a/runtime/lua/vim/treesitter.lua +++ b/runtime/lua/vim/treesitter.lua @@ -17,17 +17,20 @@ setmetatable(M, { if k == "highlighter" then t[k] = require'vim.treesitter.highlighter' return t[k] + elseif k == "language" then + t[k] = require"vim.treesitter.language" + return t[k] end end }) --- Creates a new parser. --- --- It is not recommended to use this, use vim.treesitter.get_parser() instead. --- --- @param bufnr The buffer the parser will be tied to --- @param lang The language of the parser --- @param opts Options to pass to the language tree +--- +--- It is not recommended to use this, use vim.treesitter.get_parser() instead. +--- +--- @param bufnr The buffer the parser will be tied to +--- @param lang The language of the parser +--- @param opts Options to pass to the created language tree function M._create_parser(bufnr, lang, opts) language.require_language(lang) if bufnr == 0 then @@ -38,10 +41,12 @@ function M._create_parser(bufnr, lang, opts) local self = LanguageTree.new(bufnr, lang, opts) + ---@private local function bytes_cb(_, ...) self:_on_bytes(...) end + ---@private local function detach_cb(_, ...) if parsers[bufnr] == self then parsers[bufnr] = nil @@ -49,6 +54,7 @@ function M._create_parser(bufnr, lang, opts) self:_on_detach(...) end + ---@private local function reload_cb(_, ...) self:_on_reload(...) end @@ -61,15 +67,15 @@ function M._create_parser(bufnr, lang, opts) end --- Gets the parser for this bufnr / ft combination. --- --- If needed this will create the parser. --- Unconditionnally attach the provided callback --- --- @param bufnr The buffer the parser should be tied to --- @param ft The filetype of this parser --- @param opts Options object to pass to the parser --- --- @returns The parser +--- +--- If needed this will create the parser. +--- Unconditionnally attach the provided callback +--- +--- @param bufnr The buffer the parser should be tied to +--- @param lang The filetype of this parser +--- @param opts Options object to pass to the created language tree +--- +--- @returns The parser function M.get_parser(bufnr, lang, opts) opts = opts or {} @@ -89,6 +95,11 @@ function M.get_parser(bufnr, lang, opts) return parsers[bufnr] end +--- Gets a string parser +--- +--- @param str The string to parse +--- @param lang The language of this string +--- @param opts Options to pass to the created language tree function M.get_string_parser(str, lang, opts) vim.validate { str = { str, 'string' }, diff --git a/runtime/lua/vim/treesitter/health.lua b/runtime/lua/vim/treesitter/health.lua new file mode 100644 index 0000000000..e031ba1bd6 --- /dev/null +++ b/runtime/lua/vim/treesitter/health.lua @@ -0,0 +1,38 @@ +local M = {} +local ts = vim.treesitter + +--- Lists the parsers currently installed +--- +---@return A list of parsers +function M.list_parsers() + return vim.api.nvim_get_runtime_file('parser/*', true) +end + +--- Performs a healthcheck for treesitter integration +function M.check_health() + local report_info = vim.fn['health#report_info'] + local report_ok = vim.fn['health#report_ok'] + local report_error = vim.fn['health#report_error'] + local parsers = M.list_parsers() + + report_info(string.format("Runtime ABI version : %d", ts.language_version)) + + for _, parser in pairs(parsers) do + local parsername = vim.fn.fnamemodify(parser, ":t:r") + + local is_loadable, ret = pcall(ts.language.require_language, parsername) + + if not is_loadable then + report_error(string.format("Impossible to load parser for %s: %s", parsername, ret)) + elseif ret then + local lang = ts.language.inspect_language(parsername) + report_ok(string.format("Loaded parser for %s: ABI version %d", + parsername, lang._abi_version)) + else + report_error(string.format("Unable to load parser for %s", parsername)) + end + end +end + +return M + diff --git a/runtime/lua/vim/treesitter/highlighter.lua b/runtime/lua/vim/treesitter/highlighter.lua index afb72dc069..84b6a5f135 100644 --- a/runtime/lua/vim/treesitter/highlighter.lua +++ b/runtime/lua/vim/treesitter/highlighter.lua @@ -70,11 +70,13 @@ TSHighlighter.hl_map = { ["include"] = "Include", } +---@private local function is_highlight_name(capture_name) local firstc = string.sub(capture_name, 1, 1) return firstc ~= string.lower(firstc) end +---@private function TSHighlighterQuery.new(lang, query_string) local self = setmetatable({}, { __index = TSHighlighterQuery }) @@ -99,10 +101,12 @@ function TSHighlighterQuery.new(lang, query_string) return self end +---@private function TSHighlighterQuery:query() return self._query end +---@private --- Get the hl from capture. --- Returns a tuple { highlight_name: string, is_builtin: bool } function TSHighlighterQuery:_get_hl_from_capture(capture) @@ -116,6 +120,11 @@ function TSHighlighterQuery:_get_hl_from_capture(capture) end end +--- Creates a new highlighter using @param tree +--- +--- @param tree The language tree to use for highlighting +--- @param opts Table used to configure the highlighter +--- - queries: Table to overwrite queries used by the highlighter function TSHighlighter.new(tree, opts) local self = setmetatable({}, TSHighlighter) @@ -165,12 +174,14 @@ function TSHighlighter.new(tree, opts) return self end +--- Removes all internal references to the highlighter function TSHighlighter:destroy() if TSHighlighter.active[self.bufnr] then TSHighlighter.active[self.bufnr] = nil end end +---@private function TSHighlighter:get_highlight_state(tstree) if not self._highlight_states[tstree] then self._highlight_states[tstree] = { @@ -182,24 +193,31 @@ function TSHighlighter:get_highlight_state(tstree) return self._highlight_states[tstree] end +---@private function TSHighlighter:reset_highlight_state() self._highlight_states = {} end +---@private function TSHighlighter:on_bytes(_, _, start_row, _, _, _, _, _, new_end) a.nvim__buf_redraw_range(self.bufnr, start_row, start_row + new_end + 1) end +---@private function TSHighlighter:on_detach() self:destroy() end +---@private function TSHighlighter:on_changedtree(changes) for _, ch in ipairs(changes or {}) do a.nvim__buf_redraw_range(self.bufnr, ch[1], ch[3]+1) end end +--- Gets the query used for @param lang +--- +--- @param lang A language used by the highlighter. function TSHighlighter:get_query(lang) if not self._queries[lang] then self._queries[lang] = TSHighlighterQuery.new(lang) @@ -208,6 +226,7 @@ function TSHighlighter:get_query(lang) return self._queries[lang] end +---@private local function on_line_impl(self, buf, line) self.tree:for_each_tree(function(tstree, tree) if not tstree then return end @@ -221,6 +240,9 @@ local function on_line_impl(self, buf, line) local state = self:get_highlight_state(tstree) local highlighter_query = self:get_query(tree:lang()) + -- Some injected languages may not have highlight queries. + if not highlighter_query:query() then return end + if state.iter == nil then state.iter = highlighter_query:query():iter_captures(root_node, self.bufnr, line, root_end_row + 1) end @@ -248,6 +270,7 @@ local function on_line_impl(self, buf, line) end, true) end +---@private function TSHighlighter._on_line(_, _win, buf, line, _) local self = TSHighlighter.active[buf] if not self then return end @@ -255,6 +278,7 @@ function TSHighlighter._on_line(_, _win, buf, line, _) on_line_impl(self, buf, line) end +---@private function TSHighlighter._on_buf(_, buf) local self = TSHighlighter.active[buf] if self then @@ -262,6 +286,7 @@ function TSHighlighter._on_buf(_, buf) end end +---@private function TSHighlighter._on_win(_, _win, buf, _topline) local self = TSHighlighter.active[buf] if not self then diff --git a/runtime/lua/vim/treesitter/language.lua b/runtime/lua/vim/treesitter/language.lua index eed28e0e41..6dc37c7848 100644 --- a/runtime/lua/vim/treesitter/language.lua +++ b/runtime/lua/vim/treesitter/language.lua @@ -3,12 +3,12 @@ local a = vim.api local M = {} --- Asserts that the provided language is installed, and optionally provide a path for the parser --- --- Parsers are searched in the `parser` runtime directory. --- --- @param lang The language the parser should parse --- @param path Optional path the parser is located at --- @param silent Don't throw an error if language not found +--- +--- Parsers are searched in the `parser` runtime directory. +--- +--- @param lang The language the parser should parse +--- @param path Optional path the parser is located at +--- @param silent Don't throw an error if language not found function M.require_language(lang, path, silent) if vim._ts_has_language(lang) then return true @@ -37,10 +37,10 @@ function M.require_language(lang, path, silent) end --- Inspects the provided language. --- --- Inspecting provides some useful informations on the language like node names, ... --- --- @param lang The language. +--- +--- Inspecting provides some useful informations on the language like node names, ... +--- +--- @param lang The language. function M.inspect_language(lang) M.require_language(lang) return vim._ts_inspect_language(lang) diff --git a/runtime/lua/vim/treesitter/languagetree.lua b/runtime/lua/vim/treesitter/languagetree.lua index 4168c1e365..899d90e464 100644 --- a/runtime/lua/vim/treesitter/languagetree.lua +++ b/runtime/lua/vim/treesitter/languagetree.lua @@ -5,21 +5,26 @@ local language = require'vim.treesitter.language' local LanguageTree = {} LanguageTree.__index = LanguageTree --- Represents a single treesitter parser for a language. --- The language can contain child languages with in its range, --- hence the tree. --- --- @param source Can be a bufnr or a string of text to parse --- @param lang The language this tree represents --- @param opts Options table --- @param opts.queries A table of language to injection query strings. --- This is useful for overriding the built-in runtime file --- searching for the injection language query per language. +--- Represents a single treesitter parser for a language. +--- The language can contain child languages with in its range, +--- hence the tree. +--- +--- @param source Can be a bufnr or a string of text to parse +--- @param lang The language this tree represents +--- @param opts Options table +--- @param opts.injections A table of language to injection query strings. +--- This is useful for overriding the built-in runtime file +--- searching for the injection language query per language. function LanguageTree.new(source, lang, opts) language.require_language(lang) opts = opts or {} - local custom_queries = opts.queries or {} + if opts.queries then + a.nvim_err_writeln("'queries' is no longer supported. Use 'injections' now") + opts.injections = opts.queries + end + + local injections = opts.injections or {} local self = setmetatable({ _source = source, _lang = lang, @@ -27,8 +32,8 @@ function LanguageTree.new(source, lang, opts) _regions = {}, _trees = {}, _opts = opts, - _injection_query = custom_queries[lang] - and query.parse_query(lang, custom_queries[lang]) + _injection_query = injections[lang] + and query.parse_query(lang, injections[lang]) or query.get_query(lang, "injections"), _valid = false, _parser = vim._create_ts_parser(lang), @@ -45,7 +50,7 @@ function LanguageTree.new(source, lang, opts) return self end --- Invalidates this parser and all its children +--- Invalidates this parser and all its children function LanguageTree:invalidate(reload) self._valid = false @@ -59,38 +64,38 @@ function LanguageTree:invalidate(reload) end end --- Returns all trees this language tree contains. --- Does not include child languages. +--- Returns all trees this language tree contains. +--- Does not include child languages. function LanguageTree:trees() return self._trees end --- Gets the language of this tree layer. +--- Gets the language of this tree node. function LanguageTree:lang() return self._lang end --- Determines whether this tree is valid. --- If the tree is invalid, `parse()` must be called --- to get the an updated tree. +--- Determines whether this tree is valid. +--- If the tree is invalid, `parse()` must be called +--- to get the an updated tree. function LanguageTree:is_valid() return self._valid end --- Returns a map of language to child tree. +--- Returns a map of language to child tree. function LanguageTree:children() return self._children end --- Returns the source content of the language tree (bufnr or string). +--- Returns the source content of the language tree (bufnr or string). function LanguageTree:source() return self._source end --- Parses all defined regions using a treesitter parser --- for the language this tree represents. --- This will run the injection query for this language to --- determine if any child languages should be created. +--- Parses all defined regions using a treesitter parser +--- for the language this tree represents. +--- This will run the injection query for this language to +--- determine if any child languages should be created. function LanguageTree:parse() if self._valid then return self._trees @@ -164,9 +169,10 @@ function LanguageTree:parse() return self._trees, changes end --- Invokes the callback for each LanguageTree and it's children recursively --- @param fn The function to invoke. This is invoked with arguments (tree: LanguageTree, lang: string) --- @param include_self Whether to include the invoking tree in the results. +--- Invokes the callback for each LanguageTree and it's children recursively +--- +--- @param fn The function to invoke. This is invoked with arguments (tree: LanguageTree, lang: string) +--- @param include_self Whether to include the invoking tree in the results. function LanguageTree:for_each_child(fn, include_self) if include_self then fn(self, self._lang) @@ -177,10 +183,12 @@ function LanguageTree:for_each_child(fn, include_self) end end --- Invokes the callback for each treesitter trees recursively. --- Note, this includes the invoking language tree's trees as well. --- @param fn The callback to invoke. The callback is invoked with arguments --- (tree: TSTree, languageTree: LanguageTree) +--- Invokes the callback for each treesitter trees recursively. +--- +--- Note, this includes the invoking language tree's trees as well. +--- +--- @param fn The callback to invoke. The callback is invoked with arguments +--- (tree: TSTree, languageTree: LanguageTree) function LanguageTree:for_each_tree(fn) for _, tree in ipairs(self._trees) do fn(tree, self) @@ -191,9 +199,11 @@ function LanguageTree:for_each_tree(fn) end end --- Adds a child language to this tree. --- If the language already exists as a child, it will first be removed. --- @param lang The language to add. +--- Adds a child language to this tree. +--- +--- If the language already exists as a child, it will first be removed. +--- +--- @param lang The language to add. function LanguageTree:add_child(lang) if self._children[lang] then self:remove_child(lang) @@ -207,8 +217,9 @@ function LanguageTree:add_child(lang) return self._children[lang] end --- Removes a child language from this tree. --- @param lang The language to remove. +--- Removes a child language from this tree. +--- +--- @param lang The language to remove. function LanguageTree:remove_child(lang) local child = self._children[lang] @@ -220,10 +231,11 @@ function LanguageTree:remove_child(lang) end end --- Destroys this language tree and all its children. --- Any cleanup logic should be performed here. --- Note, this DOES NOT remove this tree from a parent. --- `remove_child` must be called on the parent to remove it. +--- Destroys this language tree and all its children. +--- +--- Any cleanup logic should be performed here. +--- Note, this DOES NOT remove this tree from a parent. +--- `remove_child` must be called on the parent to remove it. function LanguageTree:destroy() -- Cleanup here for _, child in ipairs(self._children) do @@ -231,23 +243,23 @@ function LanguageTree:destroy() end end --- Sets the included regions that should be parsed by this parser. --- A region is a set of nodes and/or ranges that will be parsed in the same context. --- --- For example, `{ { node1 }, { node2} }` is two separate regions. --- This will be parsed by the parser in two different contexts... thus resulting --- in two separate trees. --- --- `{ { node1, node2 } }` is a single region consisting of two nodes. --- This will be parsed by the parser in a single context... thus resulting --- in a single tree. --- --- This allows for embedded languages to be parsed together across different --- nodes, which is useful for templating languages like ERB and EJS. --- --- Note, this call invalidates the tree and requires it to be parsed again. --- --- @param regions A list of regions this tree should manage and parse. +--- Sets the included regions that should be parsed by this parser. +--- A region is a set of nodes and/or ranges that will be parsed in the same context. +--- +--- For example, `{ { node1 }, { node2} }` is two separate regions. +--- This will be parsed by the parser in two different contexts... thus resulting +--- in two separate trees. +--- +--- `{ { node1, node2 } }` is a single region consisting of two nodes. +--- This will be parsed by the parser in a single context... thus resulting +--- in a single tree. +--- +--- This allows for embedded languages to be parsed together across different +--- nodes, which is useful for templating languages like ERB and EJS. +--- +--- Note, this call invalidates the tree and requires it to be parsed again. +--- +--- @param regions A list of regions this tree should manage and parse. function LanguageTree:set_included_regions(regions) -- TODO(vigoux): I don't think string parsers are useful for now if type(self._source) == "number" then @@ -276,16 +288,18 @@ function LanguageTree:set_included_regions(regions) self:invalidate() end --- Gets the set of included regions +--- Gets the set of included regions function LanguageTree:included_regions() return self._regions end --- Gets language injection points by language. --- This is where most of the injection processing occurs. --- TODO: Allow for an offset predicate to tailor the injection range --- instead of using the entire nodes range. --- @private +--- Gets language injection points by language. +--- +--- This is where most of the injection processing occurs. +--- +--- TODO: Allow for an offset predicate to tailor the injection range +--- instead of using the entire nodes range. +--- @private function LanguageTree:_get_injections() if not self._injection_query then return {} end @@ -297,33 +311,50 @@ function LanguageTree:_get_injections() for pattern, match, metadata in self._injection_query:iter_matches(root_node, self._source, start_line, end_line+1) do local lang = nil - local injection_node = nil - local combined = false + local ranges = {} + local combined = metadata.combined + + -- Directives can configure how injections are captured as well as actual node captures. + -- This allows more advanced processing for determining ranges and language resolution. + if metadata.content then + local content = metadata.content + + -- Allow for captured nodes to be used + if type(content) == "number" then + content = {match[content]} + end + + if content then + vim.list_extend(ranges, content) + end + end + + if metadata.language then + lang = metadata.language + end -- You can specify the content and language together -- using a tag with the language, for example -- @javascript for id, node in pairs(match) do - local data = metadata[id] local name = self._injection_query.captures[id] - local offset_range = data and data.offset -- Lang should override any other language tag - if name == "language" then + if name == "language" and not lang then lang = query.get_node_text(node, self._source) elseif name == "combined" then combined = true - elseif name == "content" then - injection_node = offset_range or node + elseif name == "content" and #ranges == 0 then + table.insert(ranges, node) -- Ignore any tags that start with "_" -- Allows for other tags to be used in matches elseif string.sub(name, 1, 1) ~= "_" then - if lang == nil then + if not lang then lang = name end - if not injection_node then - injection_node = offset_range or node + if #ranges == 0 then + table.insert(ranges, node) end end end @@ -337,21 +368,21 @@ function LanguageTree:_get_injections() injections[tree_index][lang] = {} end - -- Key by pattern so we can either combine each node to parse in the same - -- context or treat each node independently. + -- Key this by pattern. If combined is set to true all captures of this pattern + -- will be parsed by treesitter as the same "source". + -- If combined is false, each "region" will be parsed as a single source. if not injections[tree_index][lang][pattern] then - injections[tree_index][lang][pattern] = { combined = combined, nodes = {} } + injections[tree_index][lang][pattern] = { combined = combined, regions = {} } end - table.insert(injections[tree_index][lang][pattern].nodes, injection_node) + table.insert(injections[tree_index][lang][pattern].regions, ranges) end end local result = {} -- Generate a map by lang of node lists. - -- Each list is a set of ranges that should be parsed - -- together. + -- Each list is a set of ranges that should be parsed together. for _, lang_map in ipairs(injections) do for lang, patterns in pairs(lang_map) do if not result[lang] then @@ -360,10 +391,10 @@ function LanguageTree:_get_injections() for _, entry in pairs(patterns) do if entry.combined then - table.insert(result[lang], entry.nodes) + table.insert(result[lang], vim.tbl_flatten(entry.regions)) else - for _, node in ipairs(entry.nodes) do - table.insert(result[lang], {node}) + for _, ranges in ipairs(entry.regions) do + table.insert(result[lang], ranges) end end end @@ -373,12 +404,14 @@ function LanguageTree:_get_injections() return result end +---@private function LanguageTree:_do_callback(cb_name, ...) for _, cb in ipairs(self._callbacks[cb_name]) do cb(...) end end +---@private function LanguageTree:_on_bytes(bufnr, changed_tick, start_row, start_col, start_byte, old_row, old_col, old_byte, @@ -403,24 +436,26 @@ function LanguageTree:_on_bytes(bufnr, changed_tick, new_row, new_col, new_byte) end +---@private function LanguageTree:_on_reload() self:invalidate(true) end +---@private function LanguageTree:_on_detach(...) self:invalidate(true) self:_do_callback('detach', ...) end --- Registers callbacks for the parser --- @param cbs An `nvim_buf_attach`-like table argument with the following keys : --- `on_bytes` : see `nvim_buf_attach`, but this will be called _after_ the parsers callback. --- `on_changedtree` : a callback that will be called every time the tree has syntactical changes. --- it will only be passed one argument, that is a table of the ranges (as node ranges) that --- changed. --- `on_child_added` : emitted when a child is added to the tree. --- `on_child_removed` : emitted when a child is removed from the tree. +--- @param cbs An `nvim_buf_attach`-like table argument with the following keys : +--- `on_bytes` : see `nvim_buf_attach`, but this will be called _after_ the parsers callback. +--- `on_changedtree` : a callback that will be called every time the tree has syntactical changes. +--- it will only be passed one argument, that is a table of the ranges (as node ranges) that +--- changed. +--- `on_child_added` : emitted when a child is added to the tree. +--- `on_child_removed` : emitted when a child is removed from the tree. function LanguageTree:register_cbs(cbs) if not cbs then return end @@ -445,6 +480,7 @@ function LanguageTree:register_cbs(cbs) end end +---@private local function tree_contains(tree, range) local start_row, start_col, end_row, end_col = tree:root():range() local start_fits = start_row < range[1] or (start_row == range[1] and start_col <= range[2]) @@ -457,6 +493,11 @@ local function tree_contains(tree, range) return false end +--- Determines wether @param range is contained in this language tree +--- +--- This goes down the tree to recursively check childs. +--- +--- @param range A range, that is a `{ start_line, start_col, end_line, end_col }` table. function LanguageTree:contains(range) for _, tree in pairs(self._trees) do if tree_contains(tree, range) then @@ -467,6 +508,9 @@ function LanguageTree:contains(range) return false end +--- Gets the appropriate language that contains @param range +--- +--- @param range A text range, see |LanguageTree:contains| function LanguageTree:language_for_range(range) for _, child in pairs(self._children) do if child:contains(range) then diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua index 188ec94a6a..b81eb18945 100644 --- a/runtime/lua/vim/treesitter/query.lua +++ b/runtime/lua/vim/treesitter/query.lua @@ -8,6 +8,7 @@ Query.__index = Query local M = {} +---@private local function dedupe_files(files) local result = {} local seen = {} @@ -22,6 +23,22 @@ local function dedupe_files(files) return result end +---@private +local function safe_read(filename, read_quantifier) + local file, err = io.open(filename, 'r') + if not file then + error(err) + end + local content = file:read(read_quantifier) + io.close(file) + return content +end + +--- Gets the list of files used to make up a query +--- +--- @param lang The language +--- @param query_name The name of the query to load +--- @param is_included Internal parameter, most of the time left as `nil` function M.get_query_files(lang, query_name, is_included) local query_path = string.format('queries/%s/%s.scm', lang, query_name) local lang_files = dedupe_files(a.nvim_get_runtime_file(query_path, true)) @@ -38,7 +55,7 @@ function M.get_query_files(lang, query_name, is_included) local MODELINE_FORMAT = "^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$" for _, file in ipairs(lang_files) do - local modeline = io.open(file, 'r'):read('*l') + local modeline = safe_read(file, '*l') if modeline then local langlist = modeline:match(MODELINE_FORMAT) @@ -69,27 +86,17 @@ function M.get_query_files(lang, query_name, is_included) return query_files end +---@private local function read_query_files(filenames) local contents = {} for _,filename in ipairs(filenames) do - table.insert(contents, io.open(filename, 'r'):read('*a')) + table.insert(contents, safe_read(filename, '*a')) end return table.concat(contents, '') end -local match_metatable = { - __index = function(tbl, key) - rawset(tbl, key, {}) - return tbl[key] - end -} - -local function new_match_metadata() - return setmetatable({}, match_metatable) -end - --- The explicitly set queries from |vim.treesitter.query.set_query()| local explicit_queries = setmetatable({}, { __index = function(t, k) @@ -104,19 +111,20 @@ local explicit_queries = setmetatable({}, { --- --- This allows users to override any runtime files and/or configuration --- set by plugins. ----@param lang string: The language to use for the query ----@param query_name string: The name of the query (i.e. "highlights") ----@param text string: The query text (unparsed). +--- +--- @param lang string: The language to use for the query +--- @param query_name string: The name of the query (i.e. "highlights") +--- @param text string: The query text (unparsed). function M.set_query(lang, query_name, text) explicit_queries[lang][query_name] = M.parse_query(lang, text) end --- Returns the runtime query {query_name} for {lang}. --- --- @param lang The language to use for the query --- @param query_name The name of the query (i.e. "highlights") --- --- @return The corresponding query, parsed. +--- +--- @param lang The language to use for the query +--- @param query_name The name of the query (i.e. "highlights") +--- +--- @return The corresponding query, parsed. function M.get_query(lang, query_name) if explicit_queries[lang][query_name] then return explicit_queries[lang][query_name] @@ -130,12 +138,23 @@ function M.get_query(lang, query_name) end end ---- Parses a query. --- --- @param language The language --- @param query A string containing the query (s-expr syntax) --- --- @returns The query +--- Parse {query} as a string. (If the query is in a file, the caller +--- should read the contents into a string before calling). +--- +--- Returns a `Query` (see |lua-treesitter-query|) object which can be used to +--- search nodes in the syntax tree for the patterns defined in {query} +--- using `iter_*` methods below. +--- +--- Exposes `info` and `captures` with additional information about the {query}. +--- - `captures` contains the list of unique capture names defined in +--- {query}. +--- -` info.captures` also points to `captures`. +--- - `info.patterns` contains information about predicates. +--- +--- @param lang The language +--- @param query A string containing the query (s-expr syntax) +--- +--- @returns The query function M.parse_query(lang, query) language.require_language(lang) local self = setmetatable({}, Query) @@ -148,8 +167,9 @@ end -- TODO(vigoux): support multiline nodes too --- Gets the text corresponding to a given node --- @param node the node --- @param bufnr the buffer from which the node is extracted. +--- +--- @param node the node +--- @param bsource The buffer or string from which the node is extracted function M.get_node_text(node, source) local start_row, start_col, start_byte = node:start() local end_row, end_col, end_byte = node:end_() @@ -201,6 +221,7 @@ local predicate_handlers = { ["match?"] = (function() local magic_prefixes = {['\\v']=true, ['\\m']=true, ['\\M']=true, ['\\V']=true} + ---@private local function check_magic(str) if string.len(str) < 2 or magic_prefixes[string.sub(str,1,2)] then return str @@ -210,7 +231,7 @@ local predicate_handlers = { local compiled_vim_regexes = setmetatable({}, { __index = function(t, pattern) - local res = vim.regex(check_magic(vim.fn.escape(pattern, '\\'))) + local res = vim.regex(check_magic(pattern)) rawset(t, pattern, res) return res end @@ -239,7 +260,25 @@ local predicate_handlers = { end return false - end + end, + + ["any-of?"] = function(match, _, source, predicate) + local node = match[predicate[2]] + local node_text = M.get_node_text(node, source) + + -- Since 'predicate' will not be used by callers of this function, use it + -- to store a string set built from the list of words to check against. + local string_set = predicate["string_set"] + if not string_set then + string_set = {} + for i=3,#predicate do + string_set[predicate[i]] = true + end + predicate["string_set"] = string_set + end + + return string_set[node_text] + end, } -- As we provide lua-match? also expose vim-match? @@ -249,12 +288,16 @@ predicate_handlers["vim-match?"] = predicate_handlers["match?"] -- Directives store metadata or perform side effects against a match. -- Directives should always end with a `!`. -- Directive handler receive the following arguments --- (match, pattern, bufnr, predicate) +-- (match, pattern, bufnr, predicate, metadata) local directive_handlers = { ["set!"] = function(_, _, _, pred, metadata) if #pred == 4 then -- (#set! @capture "key" "value") - metadata[pred[2]][pred[3]] = pred[4] + local capture = pred[2] + if not metadata[capture] then + metadata[capture] = {} + end + metadata[capture][pred[3]] = pred[4] else -- (#set! "key" "value") metadata[pred[2]] = pred[3] @@ -269,7 +312,6 @@ local directive_handlers = { local start_col_offset = pred[4] or 0 local end_row_offset = pred[5] or 0 local end_col_offset = pred[6] or 0 - local key = pred[7] or "offset" range[1] = range[1] + start_row_offset range[2] = range[2] + start_col_offset @@ -278,16 +320,16 @@ local directive_handlers = { -- If this produces an invalid range, we just skip it. if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then - metadata[pred[2]][key] = range + metadata.content = {range} end end } --- Adds a new predicate to be used in queries --- --- @param name the name of the predicate, without leading # --- @param handler the handler function to be used --- signature will be (match, pattern, bufnr, predicate) +--- +--- @param name the name of the predicate, without leading # +--- @param handler the handler function to be used +--- signature will be (match, pattern, bufnr, predicate) function M.add_predicate(name, handler, force) if predicate_handlers[name] and not force then error(string.format("Overriding %s", name)) @@ -297,10 +339,10 @@ function M.add_predicate(name, handler, force) end --- Adds a new directive to be used in queries --- --- @param name the name of the directive, without leading # --- @param handler the handler function to be used --- signature will be (match, pattern, bufnr, predicate) +--- +--- @param name the name of the directive, without leading # +--- @param handler the handler function to be used +--- signature will be (match, pattern, bufnr, predicate) function M.add_directive(name, handler, force) if directive_handlers[name] and not force then error(string.format("Overriding %s", name)) @@ -314,14 +356,17 @@ function M.list_predicates() return vim.tbl_keys(predicate_handlers) end +---@private local function xor(x, y) return (x or y) and not (x and y) end +---@private local function is_directive(name) return string.sub(name, -1) == "!" end +---@private function Query:match_preds(match, pattern, source) local preds = self.info.patterns[pattern] @@ -360,7 +405,7 @@ function Query:match_preds(match, pattern, source) return true end ---- Applies directives against a match and pattern. +---@private function Query:apply_directives(match, pattern, source, metadata) local preds = self.info.patterns[pattern] @@ -382,6 +427,7 @@ end --- Returns the start and stop value if set else the node's range. -- When the node's range is used, the stop is incremented by 1 -- to make the search inclusive. +---@private local function value_or_node_range(start, stop, node) if start == nil and stop == nil then local node_start, _, node_stop, _ = node:range() @@ -391,15 +437,36 @@ local function value_or_node_range(start, stop, node) return start, stop end ---- Iterates of the captures of self on a given range. --- --- @param node The node under which the search will occur --- @param buffer The source buffer to search --- @param start The starting line of the search --- @param stop The stopping line of the search (end-exclusive) --- --- @returns The matching capture id --- @returns The captured node +--- Iterate over all captures from all matches inside {node} +--- +--- {source} is needed if the query contains predicates, then the caller +--- must ensure to use a freshly parsed tree consistent with the current +--- text of the buffer (if relevent). {start_row} and {end_row} can be used to limit +--- matches inside a row range (this is typically used with root node +--- as the node, i e to get syntax highlight matches in the current +--- viewport). When omitted the start and end row values are used from the given node. +--- +--- The iterator returns three values, a numeric id identifying the capture, +--- the captured node, and metadata from any directives processing the match. +--- The following example shows how to get captures by name: +--- +--- <pre> +--- for id, node, metadata in query:iter_captures(tree:root(), bufnr, first, last) do +--- local name = query.captures[id] -- name of the capture in the query +--- -- typically useful info about the node: +--- local type = node:type() -- type of the captured node +--- local row1, col1, row2, col2 = node:range() -- range of the capture +--- ... use the info here ... +--- end +--- </pre> +--- +--- @param node The node under which the search will occur +--- @param source The source buffer or string to exctract text from +--- @param start The starting line of the search +--- @param stop The stopping line of the search (end-exclusive) +--- +--- @returns The matching capture id +--- @returns The captured node function Query:iter_captures(node, source, start, stop) if type(source) == "number" and source == 0 then source = vim.api.nvim_get_current_buf() @@ -408,9 +475,10 @@ function Query:iter_captures(node, source, start, stop) start, stop = value_or_node_range(start, stop, node) local raw_iter = node:_rawquery(self.query, true, start, stop) + ---@private local function iter() local capture, captured_node, match = raw_iter() - local metadata = new_match_metadata() + local metadata = {} if match ~= nil then local active = self:match_preds(match, match.pattern, source) @@ -427,14 +495,35 @@ function Query:iter_captures(node, source, start, stop) end --- Iterates the matches of self on a given range. --- --- @param node The node under which the search will occur --- @param buffer The source buffer to search --- @param start The starting line of the search --- @param stop The stopping line of the search (end-exclusive) --- --- @returns The matching pattern id --- @returns The matching match +--- +--- Iterate over all matches within a node. The arguments are the same as +--- for |query:iter_captures()| but the iterated values are different: +--- an (1-based) index of the pattern in the query, a table mapping +--- capture indices to nodes, and metadata from any directives processing the match. +--- If the query has more than one pattern the capture table might be sparse, +--- and e.g. `pairs()` method should be used over `ipairs`. +--- Here an example iterating over all captures in every match: +--- +--- <pre> +--- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, first, last) do +--- for id, node in pairs(match) do +--- local name = query.captures[id] +--- -- `node` was captured by the `name` capture in the match +--- +--- local node_data = metadata[id] -- Node level metadata +--- +--- ... use the info here ... +--- end +--- end +--- </pre> +--- +--- @param node The node under which the search will occur +--- @param source The source buffer or string to search +--- @param start The starting line of the search +--- @param stop The stopping line of the search (end-exclusive) +--- +--- @returns The matching pattern id +--- @returns The matching match function Query:iter_matches(node, source, start, stop) if type(source) == "number" and source == 0 then source = vim.api.nvim_get_current_buf() @@ -445,7 +534,7 @@ function Query:iter_matches(node, source, start, stop) local raw_iter = node:_rawquery(self.query, false, start, stop) local function iter() local pattern, match = raw_iter() - local metadata = new_match_metadata() + local metadata = {} if match ~= nil then local active = self:match_preds(match, pattern, source) |