From 42657e70b8a8ddf8edbe261f410aeb6169e5f6dc Mon Sep 17 00:00:00 2001 From: Mathias Fussenegger Date: Sun, 8 Dec 2024 18:10:28 +0100 Subject: perf(lsp): optimize content length extraction from rpc headers - No redundant `:gsub` to turn `-` in `Content-Length` into `_` - No table allocations only to add and later get the content-length header --- runtime/lua/vim/lsp/rpc.lua | 49 +++++++++++++++++---------------------------- 1 file changed, 18 insertions(+), 31 deletions(-) (limited to 'runtime/lua/vim') diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index d31d94cab7..a358582033 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -16,34 +16,21 @@ local function format_message_with_content_length(message) }) end ----@class (private) vim.lsp.rpc.Headers: {string: any} ----@field content_length integer - ---- Parses an LSP Message's header +--- Extract content-length from the header --- ----@param header string The header to parse. ----@return vim.lsp.rpc.Headers#parsed headers -local function parse_headers(header) - assert(type(header) == 'string', 'header must be a string') - --- @type vim.lsp.rpc.Headers - local headers = {} - for line in vim.gsplit(header, '\r\n', { plain = true }) do +---@param header string The header to parse +---@return integer? +local function get_content_length(header) + for line in header:gmatch('(.-)\r\n') do if line == '' then break end - --- @type string?, string? local key, value = line:match('^%s*(%S+)%s*:%s*(.+)%s*$') - if key then - key = key:lower():gsub('%-', '_') --- @type string - headers[key] = value - else - log.error('invalid header line %q', line) - error(string.format('invalid header line %q', line)) + if key:lower() == 'content-length' then + return tonumber(value) end end - headers.content_length = tonumber(headers.content_length) - or error(string.format('Content-Length not found in headers. %q', header)) - return headers + error('Content-Length not found in header: ' .. header) end -- This is the start of any possible header patterns. The gsub converts it to a @@ -58,9 +45,9 @@ local function request_parser_loop() while true do -- A message can only be complete if it has a double CRLF and also the full -- payload, so first let's check for the CRLFs - local start, finish = buffer:find('\r\n\r\n', 1, true) + local header_end, body_start = buffer:find('\r\n\r\n', 1, true) -- Start parsing the headers - if start then + if header_end then -- This is a workaround for servers sending initial garbage before -- sending headers, such as if a bash script sends stdout. It assumes -- that we know all of the headers ahead of time. At this moment, the @@ -76,13 +63,13 @@ local function request_parser_loop() ) ) end - local headers = parse_headers(buffer:sub(buffer_start, start - 1)) - local content_length = headers.content_length + local header = buffer:sub(buffer_start, header_end + 1) + local content_length = get_content_length(header) -- Use table instead of just string to buffer the message. It prevents -- a ton of strings allocating. -- ref. http://www.lua.org/pil/11.6.html ---@type string[] - local body_chunks = { buffer:sub(finish + 1) } + local body_chunks = { buffer:sub(body_start + 1) } local body_length = #body_chunks[1] -- Keep waiting for data until we have enough. while body_length < content_length do @@ -103,7 +90,7 @@ local function request_parser_loop() -- Yield our data. --- @type string - local data = coroutine.yield(headers, body) + local data = coroutine.yield(body) or error('Expected more data for the body. The server may have died.') buffer = rest .. data else @@ -237,7 +224,7 @@ local default_dispatchers = { --- @param on_exit? fun() --- @param on_error fun(err: any) function M.create_read_loop(handle_body, on_exit, on_error) - local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): vim.lsp.rpc.Headers?, string?]] + local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): string]] parse_chunk() return function(err, chunk) if err then @@ -253,9 +240,9 @@ function M.create_read_loop(handle_body, on_exit, on_error) end while true do - local headers, body = parse_chunk(chunk) - if headers then - handle_body(assert(body)) + local body = parse_chunk(chunk) + if body then + handle_body(body) chunk = '' else break -- cgit From f517fcd14847e30c55b88b2ccffdb6ba4b80018c Mon Sep 17 00:00:00 2001 From: Mathias Fussenegger Date: Sun, 8 Dec 2024 18:14:30 +0100 Subject: perf(lsp): use string.buffer for rpc loop Avoids some table allocations. In a quick test over 50000 iterations it reduces the time from 130ms to 74 ms For the test setup details see: https://github.com/mfussenegger/nvim-dap/pull/1394#issue-2725352391 --- runtime/lua/vim/lsp/rpc.lua | 144 ++++++++++++++++++++++++++------------------ 1 file changed, 87 insertions(+), 57 deletions(-) (limited to 'runtime/lua/vim') diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index a358582033..510c2e5199 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -25,8 +25,8 @@ local function get_content_length(header) if line == '' then break end - local key, value = line:match('^%s*(%S+)%s*:%s*(.+)%s*$') - if key:lower() == 'content-length' then + local key, value = line:match('^%s*(%S+)%s*:%s*(%d+)%s*$') + if key and key:lower() == 'content-length' then return tonumber(value) end end @@ -39,66 +39,96 @@ local header_start_pattern = ('content'):gsub('%w', function(c) return '[' .. c .. c:upper() .. ']' end) +local has_strbuffer, strbuffer = pcall(require, 'string.buffer') + --- The actual workhorse. -local function request_parser_loop() - local buffer = '' -- only for header part - while true do - -- A message can only be complete if it has a double CRLF and also the full - -- payload, so first let's check for the CRLFs - local header_end, body_start = buffer:find('\r\n\r\n', 1, true) - -- Start parsing the headers - if header_end then - -- This is a workaround for servers sending initial garbage before - -- sending headers, such as if a bash script sends stdout. It assumes - -- that we know all of the headers ahead of time. At this moment, the - -- only valid headers start with "Content-*", so that's the thing we will - -- be searching for. - -- TODO(ashkan) I'd like to remove this, but it seems permanent :( - local buffer_start = buffer:find(header_start_pattern) - if not buffer_start then - error( - string.format( - "Headers were expected, a different response was received. The server response was '%s'.", - buffer - ) - ) - end - local header = buffer:sub(buffer_start, header_end + 1) - local content_length = get_content_length(header) - -- Use table instead of just string to buffer the message. It prevents - -- a ton of strings allocating. - -- ref. http://www.lua.org/pil/11.6.html - ---@type string[] - local body_chunks = { buffer:sub(body_start + 1) } - local body_length = #body_chunks[1] - -- Keep waiting for data until we have enough. - while body_length < content_length do - ---@type string +---@type function +local request_parser_loop + +if has_strbuffer then + request_parser_loop = function() + local buf = strbuffer.new() + while true do + local msg = buf:tostring() + local header_end = msg:find('\r\n\r\n', 1, true) + if header_end then + local header = buf:get(header_end + 1) + buf:skip(2) -- skip past header boundary + local content_length = get_content_length(header) + while #buf < content_length do + local chunk = coroutine.yield() + buf:put(chunk) + end + local body = buf:get(content_length) + local chunk = coroutine.yield(body) + buf:put(chunk) + else local chunk = coroutine.yield() - or error('Expected more data for the body. The server may have died.') -- TODO hmm. - table.insert(body_chunks, chunk) - body_length = body_length + #chunk + buf:put(chunk) end - local last_chunk = body_chunks[#body_chunks] + end + end +else + request_parser_loop = function() + local buffer = '' -- only for header part + while true do + -- A message can only be complete if it has a double CRLF and also the full + -- payload, so first let's check for the CRLFs + local header_end, body_start = buffer:find('\r\n\r\n', 1, true) + -- Start parsing the headers + if header_end then + -- This is a workaround for servers sending initial garbage before + -- sending headers, such as if a bash script sends stdout. It assumes + -- that we know all of the headers ahead of time. At this moment, the + -- only valid headers start with "Content-*", so that's the thing we will + -- be searching for. + -- TODO(ashkan) I'd like to remove this, but it seems permanent :( + local buffer_start = buffer:find(header_start_pattern) + if not buffer_start then + error( + string.format( + "Headers were expected, a different response was received. The server response was '%s'.", + buffer + ) + ) + end + local header = buffer:sub(buffer_start, header_end + 1) + local content_length = get_content_length(header) + -- Use table instead of just string to buffer the message. It prevents + -- a ton of strings allocating. + -- ref. http://www.lua.org/pil/11.6.html + ---@type string[] + local body_chunks = { buffer:sub(body_start + 1) } + local body_length = #body_chunks[1] + -- Keep waiting for data until we have enough. + while body_length < content_length do + ---@type string + local chunk = coroutine.yield() + or error('Expected more data for the body. The server may have died.') -- TODO hmm. + table.insert(body_chunks, chunk) + body_length = body_length + #chunk + end + local last_chunk = body_chunks[#body_chunks] - body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1) - local rest = '' - if body_length > content_length then - rest = last_chunk:sub(content_length - body_length) - end - local body = table.concat(body_chunks) - -- Yield our data. + body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1) + local rest = '' + if body_length > content_length then + rest = last_chunk:sub(content_length - body_length) + end + local body = table.concat(body_chunks) + -- Yield our data. - --- @type string - local data = coroutine.yield(body) - or error('Expected more data for the body. The server may have died.') - buffer = rest .. data - else - -- Get more data since we don't have enough. - --- @type string - local data = coroutine.yield() - or error('Expected more data for the header. The server may have died.') - buffer = buffer .. data + --- @type string + local data = coroutine.yield(body) + or error('Expected more data for the body. The server may have died.') + buffer = rest .. data + else + -- Get more data since we don't have enough. + --- @type string + local data = coroutine.yield() + or error('Expected more data for the header. The server may have died.') + buffer = buffer .. data + end end end end -- cgit From e76a7e8afb1d683de473f881289677f17ef79410 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sat, 1 Mar 2025 14:44:29 +0000 Subject: refactor: add basic stringbuffer shim --- runtime/lua/vim/_stringbuffer.lua | 110 +++++++++++++++++++++++++++++++ runtime/lua/vim/lsp/rpc.lua | 132 ++++++++------------------------------ 2 files changed, 137 insertions(+), 105 deletions(-) create mode 100644 runtime/lua/vim/_stringbuffer.lua (limited to 'runtime/lua/vim') diff --git a/runtime/lua/vim/_stringbuffer.lua b/runtime/lua/vim/_stringbuffer.lua new file mode 100644 index 0000000000..92965ee54d --- /dev/null +++ b/runtime/lua/vim/_stringbuffer.lua @@ -0,0 +1,110 @@ +-- Basic shim for LuaJIT's stringbuffer. +-- Note this does not implement the full API. +-- This is intentionally internal-only. If we want to expose it, we should +-- reimplement this a userdata and ship it as `string.buffer` +-- (minus the FFI stuff) for Lua 5.1. +local M = {} + +local has_strbuffer, strbuffer = pcall(require, 'string.buffer') + +if has_strbuffer then + M.new = strbuffer.new + + -- Lua 5.1 does not have __len metamethod so we need to provide a len() + -- function to use instead. + + --- @param buf vim._stringbuffer + --- @return integer + function M.len(buf) + return #buf + end + + return M +end + +--- @class vim._stringbuffer +--- @field private buf string[] +--- @field package len integer absolute length of the `buf` +--- @field package skip_ptr integer +local StrBuffer = {} +StrBuffer.__index = StrBuffer + +--- @return string +function StrBuffer:tostring() + if #self.buf > 1 then + self.buf = { table.concat(self.buf) } + end + + -- assert(self.len == #(self.buf[1] or ''), 'len mismatch') + + if self.skip_ptr > 0 then + if self.buf[1] then + self.buf[1] = self.buf[1]:sub(self.skip_ptr + 1) + self.len = self.len - self.skip_ptr + end + self.skip_ptr = 0 + end + + -- assert(self.len == #(self.buf[1] or ''), 'len mismatch') + + return self.buf[1] or '' +end + +StrBuffer.__tostring = StrBuffer.tostring + +--- @private +--- Efficiently peak at the first `n` characters of the buffer. +--- @param n integer +--- @return string +function StrBuffer:_peak(n) + local skip, buf1 = self.skip_ptr, self.buf[1] + if buf1 and (n + skip) < #buf1 then + return buf1:sub(skip + 1, skip + n) + end + return self:tostring():sub(1, n) +end + +--- @param chunk string +function StrBuffer:put(chunk) + local s = tostring(chunk) + self.buf[#self.buf + 1] = s + self.len = self.len + #s + return self +end + +--- @param str string +function StrBuffer:set(str) + return self:reset():put(str) +end + +--- @param n integer +--- @return string +function StrBuffer:get(n) + local r = self:_peak(n) + self:skip(n) + return r +end + +--- @param n integer +function StrBuffer:skip(n) + self.skip_ptr = math.min(self.len, self.skip_ptr + n) + return self +end + +function StrBuffer:reset() + self.buf = {} + self.skip_ptr = 0 + self.len = 0 + return self +end + +function M.new() + return setmetatable({}, StrBuffer):reset() +end + +--- @param buf vim._stringbuffer +function M.len(buf) + return buf.len - buf.skip_ptr +end + +return M diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index 510c2e5199..637cda84d3 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -5,8 +5,8 @@ local validate, schedule_wrap = vim.validate, vim.schedule_wrap --- Embeds the given string into a table and correctly computes `Content-Length`. --- ----@param message string ----@return string message with `Content-Length` attribute +--- @param message string +--- @return string message with `Content-Length` attribute local function format_message_with_content_length(message) return table.concat({ 'Content-Length: ', @@ -18,8 +18,8 @@ end --- Extract content-length from the header --- ----@param header string The header to parse ----@return integer? +--- @param header string The header to parse +--- @return integer local function get_content_length(header) for line in header:gmatch('(.-)\r\n') do if line == '' then @@ -27,112 +27,12 @@ local function get_content_length(header) end local key, value = line:match('^%s*(%S+)%s*:%s*(%d+)%s*$') if key and key:lower() == 'content-length' then - return tonumber(value) + return assert(tonumber(value)) end end error('Content-Length not found in header: ' .. header) end --- This is the start of any possible header patterns. The gsub converts it to a --- case insensitive pattern. -local header_start_pattern = ('content'):gsub('%w', function(c) - return '[' .. c .. c:upper() .. ']' -end) - -local has_strbuffer, strbuffer = pcall(require, 'string.buffer') - ---- The actual workhorse. ----@type function -local request_parser_loop - -if has_strbuffer then - request_parser_loop = function() - local buf = strbuffer.new() - while true do - local msg = buf:tostring() - local header_end = msg:find('\r\n\r\n', 1, true) - if header_end then - local header = buf:get(header_end + 1) - buf:skip(2) -- skip past header boundary - local content_length = get_content_length(header) - while #buf < content_length do - local chunk = coroutine.yield() - buf:put(chunk) - end - local body = buf:get(content_length) - local chunk = coroutine.yield(body) - buf:put(chunk) - else - local chunk = coroutine.yield() - buf:put(chunk) - end - end - end -else - request_parser_loop = function() - local buffer = '' -- only for header part - while true do - -- A message can only be complete if it has a double CRLF and also the full - -- payload, so first let's check for the CRLFs - local header_end, body_start = buffer:find('\r\n\r\n', 1, true) - -- Start parsing the headers - if header_end then - -- This is a workaround for servers sending initial garbage before - -- sending headers, such as if a bash script sends stdout. It assumes - -- that we know all of the headers ahead of time. At this moment, the - -- only valid headers start with "Content-*", so that's the thing we will - -- be searching for. - -- TODO(ashkan) I'd like to remove this, but it seems permanent :( - local buffer_start = buffer:find(header_start_pattern) - if not buffer_start then - error( - string.format( - "Headers were expected, a different response was received. The server response was '%s'.", - buffer - ) - ) - end - local header = buffer:sub(buffer_start, header_end + 1) - local content_length = get_content_length(header) - -- Use table instead of just string to buffer the message. It prevents - -- a ton of strings allocating. - -- ref. http://www.lua.org/pil/11.6.html - ---@type string[] - local body_chunks = { buffer:sub(body_start + 1) } - local body_length = #body_chunks[1] - -- Keep waiting for data until we have enough. - while body_length < content_length do - ---@type string - local chunk = coroutine.yield() - or error('Expected more data for the body. The server may have died.') -- TODO hmm. - table.insert(body_chunks, chunk) - body_length = body_length + #chunk - end - local last_chunk = body_chunks[#body_chunks] - - body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1) - local rest = '' - if body_length > content_length then - rest = last_chunk:sub(content_length - body_length) - end - local body = table.concat(body_chunks) - -- Yield our data. - - --- @type string - local data = coroutine.yield(body) - or error('Expected more data for the body. The server may have died.') - buffer = rest .. data - else - -- Get more data since we don't have enough. - --- @type string - local data = coroutine.yield() - or error('Expected more data for the header. The server may have died.') - buffer = buffer .. data - end - end - end -end - local M = {} --- Mapping of error codes used by the client @@ -249,6 +149,28 @@ local default_dispatchers = { end, } +local strbuffer = require('vim._stringbuffer') + +local function request_parser_loop() + local buf = strbuffer.new() + while true do + local msg = buf:tostring() + local header_end = msg:find('\r\n\r\n', 1, true) + if header_end then + local header = buf:get(header_end + 1) + buf:skip(2) -- skip past header boundary + local content_length = get_content_length(header) + while strbuffer.len(buf) < content_length do + buf:put(coroutine.yield()) + end + local body = buf:get(content_length) + buf:put(coroutine.yield(body)) + else + buf:put(coroutine.yield()) + end + end +end + --- @private --- @param handle_body fun(body: string) --- @param on_exit? fun() -- cgit