diff options
author | Lewis Russell <lewis6991@gmail.com> | 2025-03-01 14:44:29 +0000 |
---|---|---|
committer | Lewis Russell <lewis6991@gmail.com> | 2025-03-31 16:51:18 +0100 |
commit | e76a7e8afb1d683de473f881289677f17ef79410 (patch) | |
tree | 118ed9609babcf88ffaacce077d363ed283754b0 /runtime/lua/vim/lsp/rpc.lua | |
parent | f517fcd14847e30c55b88b2ccffdb6ba4b80018c (diff) | |
download | rneovim-e76a7e8afb1d683de473f881289677f17ef79410.tar.gz rneovim-e76a7e8afb1d683de473f881289677f17ef79410.tar.bz2 rneovim-e76a7e8afb1d683de473f881289677f17ef79410.zip |
refactor: add basic stringbuffer shim
Diffstat (limited to 'runtime/lua/vim/lsp/rpc.lua')
-rw-r--r-- | runtime/lua/vim/lsp/rpc.lua | 132 |
1 files changed, 27 insertions, 105 deletions
diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index 510c2e5199..637cda84d3 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -5,8 +5,8 @@ local validate, schedule_wrap = vim.validate, vim.schedule_wrap --- Embeds the given string into a table and correctly computes `Content-Length`. --- ----@param message string ----@return string message with `Content-Length` attribute +--- @param message string +--- @return string message with `Content-Length` attribute local function format_message_with_content_length(message) return table.concat({ 'Content-Length: ', @@ -18,8 +18,8 @@ end --- Extract content-length from the header --- ----@param header string The header to parse ----@return integer? +--- @param header string The header to parse +--- @return integer local function get_content_length(header) for line in header:gmatch('(.-)\r\n') do if line == '' then @@ -27,112 +27,12 @@ local function get_content_length(header) end local key, value = line:match('^%s*(%S+)%s*:%s*(%d+)%s*$') if key and key:lower() == 'content-length' then - return tonumber(value) + return assert(tonumber(value)) end end error('Content-Length not found in header: ' .. header) end --- This is the start of any possible header patterns. The gsub converts it to a --- case insensitive pattern. -local header_start_pattern = ('content'):gsub('%w', function(c) - return '[' .. c .. c:upper() .. ']' -end) - -local has_strbuffer, strbuffer = pcall(require, 'string.buffer') - ---- The actual workhorse. ----@type function -local request_parser_loop - -if has_strbuffer then - request_parser_loop = function() - local buf = strbuffer.new() - while true do - local msg = buf:tostring() - local header_end = msg:find('\r\n\r\n', 1, true) - if header_end then - local header = buf:get(header_end + 1) - buf:skip(2) -- skip past header boundary - local content_length = get_content_length(header) - while #buf < content_length do - local chunk = coroutine.yield() - buf:put(chunk) - end - local body = buf:get(content_length) - local chunk = coroutine.yield(body) - buf:put(chunk) - else - local chunk = coroutine.yield() - buf:put(chunk) - end - end - end -else - request_parser_loop = function() - local buffer = '' -- only for header part - while true do - -- A message can only be complete if it has a double CRLF and also the full - -- payload, so first let's check for the CRLFs - local header_end, body_start = buffer:find('\r\n\r\n', 1, true) - -- Start parsing the headers - if header_end then - -- This is a workaround for servers sending initial garbage before - -- sending headers, such as if a bash script sends stdout. It assumes - -- that we know all of the headers ahead of time. At this moment, the - -- only valid headers start with "Content-*", so that's the thing we will - -- be searching for. - -- TODO(ashkan) I'd like to remove this, but it seems permanent :( - local buffer_start = buffer:find(header_start_pattern) - if not buffer_start then - error( - string.format( - "Headers were expected, a different response was received. The server response was '%s'.", - buffer - ) - ) - end - local header = buffer:sub(buffer_start, header_end + 1) - local content_length = get_content_length(header) - -- Use table instead of just string to buffer the message. It prevents - -- a ton of strings allocating. - -- ref. http://www.lua.org/pil/11.6.html - ---@type string[] - local body_chunks = { buffer:sub(body_start + 1) } - local body_length = #body_chunks[1] - -- Keep waiting for data until we have enough. - while body_length < content_length do - ---@type string - local chunk = coroutine.yield() - or error('Expected more data for the body. The server may have died.') -- TODO hmm. - table.insert(body_chunks, chunk) - body_length = body_length + #chunk - end - local last_chunk = body_chunks[#body_chunks] - - body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1) - local rest = '' - if body_length > content_length then - rest = last_chunk:sub(content_length - body_length) - end - local body = table.concat(body_chunks) - -- Yield our data. - - --- @type string - local data = coroutine.yield(body) - or error('Expected more data for the body. The server may have died.') - buffer = rest .. data - else - -- Get more data since we don't have enough. - --- @type string - local data = coroutine.yield() - or error('Expected more data for the header. The server may have died.') - buffer = buffer .. data - end - end - end -end - local M = {} --- Mapping of error codes used by the client @@ -249,6 +149,28 @@ local default_dispatchers = { end, } +local strbuffer = require('vim._stringbuffer') + +local function request_parser_loop() + local buf = strbuffer.new() + while true do + local msg = buf:tostring() + local header_end = msg:find('\r\n\r\n', 1, true) + if header_end then + local header = buf:get(header_end + 1) + buf:skip(2) -- skip past header boundary + local content_length = get_content_length(header) + while strbuffer.len(buf) < content_length do + buf:put(coroutine.yield()) + end + local body = buf:get(content_length) + buf:put(coroutine.yield(body)) + else + buf:put(coroutine.yield()) + end + end +end + --- @private --- @param handle_body fun(body: string) --- @param on_exit? fun() |