Diffstat (limited to 'src/nvim/os/channel.c')
-rw-r--r-- | src/nvim/os/channel.c | 307
1 file changed, 243 insertions, 64 deletions
diff --git a/src/nvim/os/channel.c b/src/nvim/os/channel.c
index 653f09756a..9bba247a7b 100644
--- a/src/nvim/os/channel.c
+++ b/src/nvim/os/channel.c
@@ -5,6 +5,7 @@
 #include "nvim/api/private/helpers.h"
 #include "nvim/os/channel.h"
+#include "nvim/os/event.h"
 #include "nvim/os/rstream.h"
 #include "nvim/os/rstream_defs.h"
 #include "nvim/os/wstream.h"
@@ -12,17 +13,24 @@
 #include "nvim/os/job.h"
 #include "nvim/os/job_defs.h"
 #include "nvim/os/msgpack_rpc.h"
+#include "nvim/os/msgpack_rpc_helpers.h"
 #include "nvim/vim.h"
 #include "nvim/memory.h"
+#include "nvim/message.h"
 #include "nvim/map.h"
 #include "nvim/lib/kvec.h"
 
 typedef struct {
+  uint64_t request_id;
+  bool errored;
+  Object result;
+} ChannelCallFrame;
+
+typedef struct {
   uint64_t id;
   PMap(cstr_t) *subscribed_events;
-  bool is_job;
+  bool is_job, enabled;
   msgpack_unpacker *unpacker;
-  msgpack_sbuffer *sbuffer;
   union {
     Job *job;
     struct {
@@ -31,12 +39,15 @@ typedef struct {
       uv_stream_t *uv;
     } streams;
   } data;
+  uint64_t next_request_id;
+  kvec_t(ChannelCallFrame *) call_stack;
+  size_t rpc_call_level;
 } Channel;
 
 static uint64_t next_id = 1;
 static PMap(uint64_t) *channels = NULL;
 static PMap(cstr_t) *event_strings = NULL;
-static msgpack_sbuffer msgpack_event_buffer;
+static msgpack_sbuffer out_buffer;
 
 #ifdef INCLUDE_GENERATED_DECLARATIONS
 # include "os/channel.c.generated.h"
@@ -47,7 +58,7 @@ void channel_init()
 {
   channels = pmap_new(uint64_t)();
   event_strings = pmap_new(cstr_t)();
-  msgpack_sbuffer_init(&msgpack_event_buffer);
+  msgpack_sbuffer_init(&out_buffer);
 }
 
 /// Teardown the module
@@ -80,6 +91,7 @@ bool channel_from_job(char **argv)
                               job_err,
                               job_exit,
                               true,
+                              0,
                               &status);
 
   if (status <= 0) {
@@ -104,7 +116,7 @@ void channel_from_stream(uv_stream_t *stream)
   rstream_set_stream(channel->data.streams.read, stream);
   rstream_start(channel->data.streams.read);
   // write stream
-  channel->data.streams.write = wstream_new(1024 * 1024);
+  channel->data.streams.write = wstream_new(0);
   wstream_set_stream(channel->data.streams.write, stream);
   channel->data.streams.uv = stream;
 }
@@ -113,26 +125,98 @@
 ///
 /// @param id The channel id. If 0, the event will be sent to all
 ///        channels that have subscribed to the event type
-/// @param type The event type, an arbitrary string
-/// @param obj The event data
+/// @param name The event name, an arbitrary string
+/// @param arg The event arg
 /// @return True if the data was sent successfully, false otherwise.
-bool channel_send_event(uint64_t id, char *type, Object data)
+bool channel_send_event(uint64_t id, char *name, Object arg)
 {
   Channel *channel = NULL;
 
   if (id > 0) {
     if (!(channel = pmap_get(uint64_t)(channels, id))) {
-      msgpack_rpc_free_object(data);
+      msgpack_rpc_free_object(arg);
       return false;
     }
-    send_event(channel, type, data);
+    send_event(channel, name, arg);
   } else {
-    broadcast_event(type, data);
+    broadcast_event(name, arg);
   }
 
   return true;
 }
 
+bool channel_send_call(uint64_t id,
+                       char *name,
+                       Object arg,
+                       Object *result,
+                       bool *errored)
+{
+  Channel *channel = NULL;
+
+  if (!(channel = pmap_get(uint64_t)(channels, id))) {
+    msgpack_rpc_free_object(arg);
+    return false;
+  }
+
+  if (kv_size(channel->call_stack) > 20) {
+    // 20 stack depth is more than anyone should ever need for RPC calls
+    *errored = true;
+    char buf[256];
+    snprintf(buf,
+             sizeof(buf),
+             "Channel %" PRIu64 " was closed due to a high stack depth "
+             "while processing an RPC call",
+             channel->id);
+    *result = STRING_OBJ(cstr_to_string(buf));
+  }
+
+  uint64_t request_id = channel->next_request_id++;
+  // Send the msgpack-rpc request
+  send_request(channel, request_id, name, arg);
+
+  if (!kv_size(channel->call_stack)) {
+    // This is the first frame, we must disable event deferral for this
+    // channel because we won't be returning until the client sends a
+    // response
+    if (channel->is_job) {
+      job_set_defer(channel->data.job, false);
+    } else {
+      rstream_set_defer(channel->data.streams.read, false);
+    }
+  }
+
+  // Push the frame
+  ChannelCallFrame frame = {request_id, false, NIL};
+  kv_push(ChannelCallFrame *, channel->call_stack, &frame);
+  size_t size = kv_size(channel->call_stack);
+
+  do {
+    event_poll(-1);
+  } while (
+      // Continue running if ...
+      channel->enabled &&                     // the channel is still enabled
+      kv_size(channel->call_stack) >= size);  // the call didn't return
+
+  if (!kv_size(channel->call_stack)) {
+    // Popped last frame, restore event deferral
+    if (channel->is_job) {
+      job_set_defer(channel->data.job, true);
+    } else {
+      rstream_set_defer(channel->data.streams.read, true);
+    }
+    if (!channel->enabled && !channel->rpc_call_level) {
+      // Close the channel if it has been disabled and we have not been
+      // called by `parse_msgpack` (it would be unsafe to close the channel
+      // otherwise)
+      close_channel(channel);
+    }
+  }
+
+  *errored = frame.errored;
+  *result = frame.result;
+
+  return true;
+}
+
 /// Subscribes to event broadcasts
 ///
 /// @param id The channel id
@@ -191,10 +275,17 @@ static void parse_msgpack(RStream *rstream, void *data, bool eof)
   Channel *channel = data;
 
   if (eof) {
-    close_channel(channel);
+    char buf[256];
+    snprintf(buf,
+             sizeof(buf),
+             "Before returning from an RPC call, channel %" PRIu64 " was "
+             "closed by the client",
+             channel->id);
+    disable_channel(channel, buf);
     return;
   }
 
+  channel->rpc_call_level++;
   uint32_t count = rstream_available(rstream);
 
   // Feed the unpacker with data
@@ -205,23 +296,34 @@ static void parse_msgpack(RStream *rstream, void *data, bool eof)
   msgpack_unpacked unpacked;
   msgpack_unpacked_init(&unpacked);
   UnpackResult result;
-  msgpack_packer response;
 
   // Deserialize everything we can.
   while ((result = msgpack_rpc_unpack(channel->unpacker, &unpacked))
          == kUnpackResultOk) {
-    // Each object is a new msgpack-rpc request and requires an empty response
-    msgpack_packer_init(&response, channel->sbuffer, msgpack_sbuffer_write);
+    if (kv_size(channel->call_stack) && is_rpc_response(&unpacked.data)) {
+      if (is_valid_rpc_response(&unpacked.data, channel)) {
+        call_stack_pop(&unpacked.data, channel);
+      } else {
+        char buf[256];
+        snprintf(buf,
+                 sizeof(buf),
+                 "Channel %" PRIu64 " returned a response that doesn't have "
+                 "a matching id for the current RPC call. Ensure the client "
+                 "is properly synchronized",
+                 channel->id);
+        call_stack_unwind(channel, buf, 1);
+      }
+      msgpack_unpacked_destroy(&unpacked);
+      // Bail out from this event loop iteration
+      goto end;
+    }
+
     // Perform the call
-    msgpack_rpc_call(channel->id, &unpacked.data, &response);
-    wstream_write(channel->data.streams.write,
-                  wstream_new_buffer(xmemdup(channel->sbuffer->data,
-                                             channel->sbuffer->size),
-                                     channel->sbuffer->size,
-                                     free));
-
-    // Clear the buffer for future calls
-    msgpack_sbuffer_clear(channel->sbuffer);
+    WBuffer *resp = msgpack_rpc_call(channel->id, &unpacked.data, &out_buffer);
+    // write the response
+    if (!channel_write(channel, resp)) {
+      goto end;
+    }
   }
 
   if (result == kUnpackResultFail) {
@@ -231,50 +333,87 @@ static void parse_msgpack(RStream *rstream, void *data, bool eof)
     // A not so uncommon cause for this might be deserializing objects with
    // a high nesting level: msgpack will break when its internal parse stack
     // size exceeds MSGPACK_EMBED_STACK_SIZE (defined as 32 by default)
-    msgpack_packer_init(&response, channel->sbuffer, msgpack_sbuffer_write);
-    msgpack_pack_array(&response, 4);
-    msgpack_pack_int(&response, 1);
-    msgpack_pack_int(&response, 0);
-    msgpack_rpc_error("Invalid msgpack payload. "
-                      "This error can also happen when deserializing "
-                      "an object with high level of nesting",
-                      &response);
-    wstream_write(channel->data.streams.write,
-                  wstream_new_buffer(xmemdup(channel->sbuffer->data,
-                                             channel->sbuffer->size),
-                                     channel->sbuffer->size,
-                                     free));
-    // Clear the buffer for future calls
-    msgpack_sbuffer_clear(channel->sbuffer);
+    send_error(channel, 0, "Invalid msgpack payload. "
+                           "This error can also happen when deserializing "
+                           "an object with a high level of nesting");
+  }
+
+end:
+  channel->rpc_call_level--;
+  if (!channel->enabled && !kv_size(channel->call_stack)) {
+    // Now it's safe to destroy the channel
+    close_channel(channel);
   }
 }
 
-static void send_event(Channel *channel, char *type, Object data)
+static bool channel_write(Channel *channel, WBuffer *buffer)
 {
-  wstream_write(channel->data.streams.write, serialize_event(type, data));
+  bool success;
+
+  if (channel->is_job) {
+    success = job_write(channel->data.job, buffer);
+  } else {
+    success = wstream_write(channel->data.streams.write, buffer);
+  }
+
+  if (!success) {
+    // If the write failed for any reason, close the channel
+    char buf[256];
+    snprintf(buf,
+             sizeof(buf),
+             "Before returning from an RPC call, channel %" PRIu64 " was "
+             "closed due to a failed write",
+             channel->id);
+    disable_channel(channel, buf);
+  }
+
+  return success;
 }
 
-static void broadcast_event(char *type, Object data)
+static void send_error(Channel *channel, uint64_t id, char *err)
+{
+  channel_write(channel, serialize_response(id, err, NIL, &out_buffer));
+}
+
+static void send_request(Channel *channel,
+                         uint64_t id,
+                         char *name,
+                         Object arg)
+{
+  String method = {.size = strlen(name), .data = name};
+  channel_write(channel, serialize_request(id, method, arg, &out_buffer));
+}
+
+static void send_event(Channel *channel,
+                       char *name,
+                       Object arg)
+{
+  String method = {.size = strlen(name), .data = name};
+  channel_write(channel, serialize_request(0, method, arg, &out_buffer));
+}
+
+static void broadcast_event(char *name, Object arg)
 {
   kvec_t(Channel *) subscribed;
   kv_init(subscribed);
   Channel *channel;
 
   map_foreach_value(channels, channel, {
-    if (pmap_has(cstr_t)(channel->subscribed_events, type)) {
+    if (pmap_has(cstr_t)(channel->subscribed_events, name)) {
       kv_push(Channel *, subscribed, channel);
     }
   });
 
   if (!kv_size(subscribed)) {
-    msgpack_rpc_free_object(data);
+    msgpack_rpc_free_object(arg);
     goto end;
   }
 
-  WBuffer *buffer = serialize_event(type, data);
+  String method = {.size = strlen(name), .data = name};
+  WBuffer *buffer = serialize_request(0, method, arg, &out_buffer);
 
   for (size_t i = 0; i < kv_size(subscribed); i++) {
-    wstream_write(kv_A(subscribed, i)->data.streams.write, buffer);
+    channel_write(kv_A(subscribed, i), buffer);
   }
 
 end:
@@ -300,7 +439,6 @@ static void unsubscribe(Channel *channel, char *event)
 static void close_channel(Channel *channel)
 {
   pmap_del(uint64_t)(channels, channel->id);
-  msgpack_sbuffer_free(channel->sbuffer);
   msgpack_unpacker_free(channel->unpacker);
 
   if (channel->is_job) {
@@ -320,6 +458,7 @@ static void close_channel(Channel *channel)
   });
 
   pmap_free(cstr_t)(channel->subscribed_events);
+  kv_destroy(channel->call_stack);
   free(channel);
 }
 
@@ -329,29 +468,69 @@ static void close_cb(uv_handle_t *handle)
   free(handle);
 }
 
-static WBuffer *serialize_event(char *type, Object data)
-{
-  String event_type = {.size = strnlen(type, EVENT_MAXLEN), .data = type};
-  msgpack_packer packer;
-  msgpack_packer_init(&packer, &msgpack_event_buffer, msgpack_sbuffer_write);
-  msgpack_rpc_notification(event_type, data, &packer);
-  WBuffer *rv = wstream_new_buffer(xmemdup(msgpack_event_buffer.data,
-                                           msgpack_event_buffer.size),
-                                   msgpack_event_buffer.size,
-                                   free);
-  msgpack_rpc_free_object(data);
-  msgpack_sbuffer_clear(&msgpack_event_buffer);
-
-  return rv;
-}
-
 static Channel *register_channel()
 {
   Channel *rv = xmalloc(sizeof(Channel));
+  rv->enabled = true;
+  rv->rpc_call_level = 0;
   rv->unpacker =
       msgpack_unpacker_new(MSGPACK_UNPACKER_INIT_BUFFER_SIZE);
-  rv->sbuffer = msgpack_sbuffer_new();
   rv->id = next_id++;
   rv->subscribed_events = pmap_new(cstr_t)();
+  rv->next_request_id = 1;
+  kv_init(rv->call_stack);
   pmap_put(uint64_t)(channels, rv->id, rv);
   return rv;
 }
+
+static bool is_rpc_response(msgpack_object *obj)
+{
+  return obj->type == MSGPACK_OBJECT_ARRAY
+      && obj->via.array.size == 4
+      && obj->via.array.ptr[0].type == MSGPACK_OBJECT_POSITIVE_INTEGER
+      && obj->via.array.ptr[0].via.u64 == 1
+      && obj->via.array.ptr[1].type == MSGPACK_OBJECT_POSITIVE_INTEGER;
+}
+
+static bool is_valid_rpc_response(msgpack_object *obj, Channel *channel)
+{
+  uint64_t response_id = obj->via.array.ptr[1].via.u64;
+  // Must be equal to the id of the frame at the top of the stack
+  return response_id == kv_A(channel->call_stack,
+                             kv_size(channel->call_stack) - 1)->request_id;
+}
+
+static void call_stack_pop(msgpack_object *obj, Channel *channel)
+{
+  ChannelCallFrame *frame = kv_A(channel->call_stack,
+                                 kv_size(channel->call_stack) - 1);
+  frame->errored = obj->via.array.ptr[2].type != MSGPACK_OBJECT_NIL;
+  (void)kv_pop(channel->call_stack);
+
+  if (frame->errored) {
+    msgpack_rpc_to_object(&obj->via.array.ptr[2], &frame->result);
+  } else {
+    msgpack_rpc_to_object(&obj->via.array.ptr[3], &frame->result);
+  }
+}
+
+static void call_stack_unwind(Channel *channel, char *msg, int count)
+{
+  while (kv_size(channel->call_stack) && count--) {
+    ChannelCallFrame *frame = kv_pop(channel->call_stack);
+    frame->errored = true;
+    frame->result = STRING_OBJ(cstr_to_string(msg));
+  }
+}
+
+static void disable_channel(Channel *channel, char *msg)
+{
+  if (kv_size(channel->call_stack)) {
+    // Channel is currently in the middle of a call, remove all frames and
+    // mark it as "dead"
+    channel->enabled = false;
+    call_stack_unwind(channel, msg, -1);
+  } else {
+    // Safe to close it now
+    close_channel(channel);
+  }
+}
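For context, a caller of the new synchronous path might look like the sketch below. This is a minimal illustration, not code from this commit: the `demo_blocking_call` helper, the header paths, and the `"vim_command"` method name are assumptions; only `channel_send_call`, `STRING_OBJ`, `cstr_to_string`, and `msgpack_rpc_free_object` are taken from the diff above. `channel_send_call` blocks in `event_poll(-1)` until the client answers with a msgpack-rpc response frame of the shape `[1, request_id, error, result]`, which `is_rpc_response`/`is_valid_rpc_response` match against the frame at the top of the call stack.

#include <stdbool.h>
#include <stdint.h>

#include "nvim/api/private/defs.h"     // Object (assumed location)
#include "nvim/api/private/helpers.h"  // cstr_to_string
#include "nvim/os/channel.h"           // channel_send_call (added here)
#include "nvim/os/msgpack_rpc.h"       // msgpack_rpc_free_object

// Hypothetical helper: performs a blocking RPC call on `channel_id`.
static void demo_blocking_call(uint64_t channel_id)
{
  Object result;
  bool errored;

  // Ownership of `arg` moves to the channel layer; it is freed there if the
  // channel lookup fails.
  Object arg = STRING_OBJ(cstr_to_string("echo 'hello'"));

  // Blocks until the client responds or the channel is disabled, then
  // reports the outcome through the out-parameters.
  if (!channel_send_call(channel_id, "vim_command", arg, &result, &errored)) {
    return;  // no channel registered under `channel_id`
  }

  // ... inspect `result` (it holds the error value when `errored` is set) ...
  msgpack_rpc_free_object(result);  // the caller owns the returned Object
}

Note that ownership transfers both ways across this API: the argument is consumed by the channel layer, while the returned `result` must be freed by the caller.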