aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2025-04-01 13:49:45 +0100
committerGitHub <noreply@github.com>2025-04-01 13:49:45 +0100
commit9b239a6a86ed0caaaf7522cfc600da4b35d94d04 (patch)
tree1c5124c8a313956341ef09969e00bc3a8c2e639c /runtime/lua/vim
parentec18ebcb417bb9f2afc81d247db6993eaa48701f (diff)
parente76a7e8afb1d683de473f881289677f17ef79410 (diff)
downloadrneovim-9b239a6a86ed0caaaf7522cfc600da4b35d94d04.tar.gz
rneovim-9b239a6a86ed0caaaf7522cfc600da4b35d94d04.tar.bz2
rneovim-9b239a6a86ed0caaaf7522cfc600da4b35d94d04.zip
Merge pull request #32686 from lewis6991/lsp-rpc-perf
perf(lsp): improve rpc loop performance (with shim)
Diffstat (limited to 'runtime/lua/vim')
-rw-r--r--runtime/lua/vim/_stringbuffer.lua110
-rw-r--r--runtime/lua/vim/lsp/rpc.lua135
2 files changed, 147 insertions, 98 deletions
diff --git a/runtime/lua/vim/_stringbuffer.lua b/runtime/lua/vim/_stringbuffer.lua
new file mode 100644
index 0000000000..92965ee54d
--- /dev/null
+++ b/runtime/lua/vim/_stringbuffer.lua
@@ -0,0 +1,110 @@
+-- Basic shim for LuaJIT's stringbuffer.
+-- Note this does not implement the full API.
+-- This is intentionally internal-only. If we want to expose it, we should
+-- reimplement this a userdata and ship it as `string.buffer`
+-- (minus the FFI stuff) for Lua 5.1.
+local M = {}
+
+local has_strbuffer, strbuffer = pcall(require, 'string.buffer')
+
+if has_strbuffer then
+ M.new = strbuffer.new
+
+ -- Lua 5.1 does not have __len metamethod so we need to provide a len()
+ -- function to use instead.
+
+ --- @param buf vim._stringbuffer
+ --- @return integer
+ function M.len(buf)
+ return #buf
+ end
+
+ return M
+end
+
+--- @class vim._stringbuffer
+--- @field private buf string[]
+--- @field package len integer absolute length of the `buf`
+--- @field package skip_ptr integer
+local StrBuffer = {}
+StrBuffer.__index = StrBuffer
+
+--- @return string
+function StrBuffer:tostring()
+ if #self.buf > 1 then
+ self.buf = { table.concat(self.buf) }
+ end
+
+ -- assert(self.len == #(self.buf[1] or ''), 'len mismatch')
+
+ if self.skip_ptr > 0 then
+ if self.buf[1] then
+ self.buf[1] = self.buf[1]:sub(self.skip_ptr + 1)
+ self.len = self.len - self.skip_ptr
+ end
+ self.skip_ptr = 0
+ end
+
+ -- assert(self.len == #(self.buf[1] or ''), 'len mismatch')
+
+ return self.buf[1] or ''
+end
+
+StrBuffer.__tostring = StrBuffer.tostring
+
+--- @private
+--- Efficiently peak at the first `n` characters of the buffer.
+--- @param n integer
+--- @return string
+function StrBuffer:_peak(n)
+ local skip, buf1 = self.skip_ptr, self.buf[1]
+ if buf1 and (n + skip) < #buf1 then
+ return buf1:sub(skip + 1, skip + n)
+ end
+ return self:tostring():sub(1, n)
+end
+
+--- @param chunk string
+function StrBuffer:put(chunk)
+ local s = tostring(chunk)
+ self.buf[#self.buf + 1] = s
+ self.len = self.len + #s
+ return self
+end
+
+--- @param str string
+function StrBuffer:set(str)
+ return self:reset():put(str)
+end
+
+--- @param n integer
+--- @return string
+function StrBuffer:get(n)
+ local r = self:_peak(n)
+ self:skip(n)
+ return r
+end
+
+--- @param n integer
+function StrBuffer:skip(n)
+ self.skip_ptr = math.min(self.len, self.skip_ptr + n)
+ return self
+end
+
+function StrBuffer:reset()
+ self.buf = {}
+ self.skip_ptr = 0
+ self.len = 0
+ return self
+end
+
+function M.new()
+ return setmetatable({}, StrBuffer):reset()
+end
+
+--- @param buf vim._stringbuffer
+function M.len(buf)
+ return buf.len - buf.skip_ptr
+end
+
+return M
diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua
index d31d94cab7..637cda84d3 100644
--- a/runtime/lua/vim/lsp/rpc.lua
+++ b/runtime/lua/vim/lsp/rpc.lua
@@ -5,8 +5,8 @@ local validate, schedule_wrap = vim.validate, vim.schedule_wrap
--- Embeds the given string into a table and correctly computes `Content-Length`.
---
----@param message string
----@return string message with `Content-Length` attribute
+--- @param message string
+--- @return string message with `Content-Length` attribute
local function format_message_with_content_length(message)
return table.concat({
'Content-Length: ',
@@ -16,104 +16,21 @@ local function format_message_with_content_length(message)
})
end
----@class (private) vim.lsp.rpc.Headers: {string: any}
----@field content_length integer
-
---- Parses an LSP Message's header
+--- Extract content-length from the header
---
----@param header string The header to parse.
----@return vim.lsp.rpc.Headers#parsed headers
-local function parse_headers(header)
- assert(type(header) == 'string', 'header must be a string')
- --- @type vim.lsp.rpc.Headers
- local headers = {}
- for line in vim.gsplit(header, '\r\n', { plain = true }) do
+--- @param header string The header to parse
+--- @return integer
+local function get_content_length(header)
+ for line in header:gmatch('(.-)\r\n') do
if line == '' then
break
end
- --- @type string?, string?
- local key, value = line:match('^%s*(%S+)%s*:%s*(.+)%s*$')
- if key then
- key = key:lower():gsub('%-', '_') --- @type string
- headers[key] = value
- else
- log.error('invalid header line %q', line)
- error(string.format('invalid header line %q', line))
- end
- end
- headers.content_length = tonumber(headers.content_length)
- or error(string.format('Content-Length not found in headers. %q', header))
- return headers
-end
-
--- This is the start of any possible header patterns. The gsub converts it to a
--- case insensitive pattern.
-local header_start_pattern = ('content'):gsub('%w', function(c)
- return '[' .. c .. c:upper() .. ']'
-end)
-
---- The actual workhorse.
-local function request_parser_loop()
- local buffer = '' -- only for header part
- while true do
- -- A message can only be complete if it has a double CRLF and also the full
- -- payload, so first let's check for the CRLFs
- local start, finish = buffer:find('\r\n\r\n', 1, true)
- -- Start parsing the headers
- if start then
- -- This is a workaround for servers sending initial garbage before
- -- sending headers, such as if a bash script sends stdout. It assumes
- -- that we know all of the headers ahead of time. At this moment, the
- -- only valid headers start with "Content-*", so that's the thing we will
- -- be searching for.
- -- TODO(ashkan) I'd like to remove this, but it seems permanent :(
- local buffer_start = buffer:find(header_start_pattern)
- if not buffer_start then
- error(
- string.format(
- "Headers were expected, a different response was received. The server response was '%s'.",
- buffer
- )
- )
- end
- local headers = parse_headers(buffer:sub(buffer_start, start - 1))
- local content_length = headers.content_length
- -- Use table instead of just string to buffer the message. It prevents
- -- a ton of strings allocating.
- -- ref. http://www.lua.org/pil/11.6.html
- ---@type string[]
- local body_chunks = { buffer:sub(finish + 1) }
- local body_length = #body_chunks[1]
- -- Keep waiting for data until we have enough.
- while body_length < content_length do
- ---@type string
- local chunk = coroutine.yield()
- or error('Expected more data for the body. The server may have died.') -- TODO hmm.
- table.insert(body_chunks, chunk)
- body_length = body_length + #chunk
- end
- local last_chunk = body_chunks[#body_chunks]
-
- body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1)
- local rest = ''
- if body_length > content_length then
- rest = last_chunk:sub(content_length - body_length)
- end
- local body = table.concat(body_chunks)
- -- Yield our data.
-
- --- @type string
- local data = coroutine.yield(headers, body)
- or error('Expected more data for the body. The server may have died.')
- buffer = rest .. data
- else
- -- Get more data since we don't have enough.
- --- @type string
- local data = coroutine.yield()
- or error('Expected more data for the header. The server may have died.')
- buffer = buffer .. data
+ local key, value = line:match('^%s*(%S+)%s*:%s*(%d+)%s*$')
+ if key and key:lower() == 'content-length' then
+ return assert(tonumber(value))
end
end
+ error('Content-Length not found in header: ' .. header)
end
local M = {}
@@ -232,12 +149,34 @@ local default_dispatchers = {
end,
}
+local strbuffer = require('vim._stringbuffer')
+
+local function request_parser_loop()
+ local buf = strbuffer.new()
+ while true do
+ local msg = buf:tostring()
+ local header_end = msg:find('\r\n\r\n', 1, true)
+ if header_end then
+ local header = buf:get(header_end + 1)
+ buf:skip(2) -- skip past header boundary
+ local content_length = get_content_length(header)
+ while strbuffer.len(buf) < content_length do
+ buf:put(coroutine.yield())
+ end
+ local body = buf:get(content_length)
+ buf:put(coroutine.yield(body))
+ else
+ buf:put(coroutine.yield())
+ end
+ end
+end
+
--- @private
--- @param handle_body fun(body: string)
--- @param on_exit? fun()
--- @param on_error fun(err: any)
function M.create_read_loop(handle_body, on_exit, on_error)
- local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): vim.lsp.rpc.Headers?, string?]]
+ local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): string]]
parse_chunk()
return function(err, chunk)
if err then
@@ -253,9 +192,9 @@ function M.create_read_loop(handle_body, on_exit, on_error)
end
while true do
- local headers, body = parse_chunk(chunk)
- if headers then
- handle_body(assert(body))
+ local body = parse_chunk(chunk)
+ if body then
+ handle_body(body)
chunk = ''
else
break