aboutsummaryrefslogtreecommitdiff
path: root/src/nvim/os/channel.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/nvim/os/channel.c')
-rw-r--r--src/nvim/os/channel.c307
1 files changed, 243 insertions, 64 deletions
diff --git a/src/nvim/os/channel.c b/src/nvim/os/channel.c
index 653f09756a..9bba247a7b 100644
--- a/src/nvim/os/channel.c
+++ b/src/nvim/os/channel.c
@@ -5,6 +5,7 @@
#include "nvim/api/private/helpers.h"
#include "nvim/os/channel.h"
+#include "nvim/os/event.h"
#include "nvim/os/rstream.h"
#include "nvim/os/rstream_defs.h"
#include "nvim/os/wstream.h"
@@ -12,17 +13,24 @@
#include "nvim/os/job.h"
#include "nvim/os/job_defs.h"
#include "nvim/os/msgpack_rpc.h"
+#include "nvim/os/msgpack_rpc_helpers.h"
#include "nvim/vim.h"
#include "nvim/memory.h"
+#include "nvim/message.h"
#include "nvim/map.h"
#include "nvim/lib/kvec.h"
typedef struct {
+ uint64_t request_id;
+ bool errored;
+ Object result;
+} ChannelCallFrame;
+
+typedef struct {
uint64_t id;
PMap(cstr_t) *subscribed_events;
- bool is_job;
+ bool is_job, enabled;
msgpack_unpacker *unpacker;
- msgpack_sbuffer *sbuffer;
union {
Job *job;
struct {
@@ -31,12 +39,15 @@ typedef struct {
uv_stream_t *uv;
} streams;
} data;
+ uint64_t next_request_id;
+ kvec_t(ChannelCallFrame *) call_stack;
+ size_t rpc_call_level;
} Channel;
static uint64_t next_id = 1;
static PMap(uint64_t) *channels = NULL;
static PMap(cstr_t) *event_strings = NULL;
-static msgpack_sbuffer msgpack_event_buffer;
+static msgpack_sbuffer out_buffer;
#ifdef INCLUDE_GENERATED_DECLARATIONS
# include "os/channel.c.generated.h"
@@ -47,7 +58,7 @@ void channel_init()
{
channels = pmap_new(uint64_t)();
event_strings = pmap_new(cstr_t)();
- msgpack_sbuffer_init(&msgpack_event_buffer);
+ msgpack_sbuffer_init(&out_buffer);
}
/// Teardown the module
@@ -80,6 +91,7 @@ bool channel_from_job(char **argv)
job_err,
job_exit,
true,
+ 0,
&status);
if (status <= 0) {
@@ -104,7 +116,7 @@ void channel_from_stream(uv_stream_t *stream)
rstream_set_stream(channel->data.streams.read, stream);
rstream_start(channel->data.streams.read);
// write stream
- channel->data.streams.write = wstream_new(1024 * 1024);
+ channel->data.streams.write = wstream_new(0);
wstream_set_stream(channel->data.streams.write, stream);
channel->data.streams.uv = stream;
}
@@ -113,26 +125,98 @@ void channel_from_stream(uv_stream_t *stream)
///
/// @param id The channel id. If 0, the event will be sent to all
/// channels that have subscribed to the event type
-/// @param type The event type, an arbitrary string
-/// @param obj The event data
+/// @param name The event name, an arbitrary string
+/// @param arg The event arg
/// @return True if the data was sent successfully, false otherwise.
-bool channel_send_event(uint64_t id, char *type, Object data)
+bool channel_send_event(uint64_t id, char *name, Object arg)
{
Channel *channel = NULL;
if (id > 0) {
if (!(channel = pmap_get(uint64_t)(channels, id))) {
- msgpack_rpc_free_object(data);
+ msgpack_rpc_free_object(arg);
return false;
}
- send_event(channel, type, data);
+ send_event(channel, name, arg);
} else {
- broadcast_event(type, data);
+ broadcast_event(name, arg);
}
return true;
}
+bool channel_send_call(uint64_t id,
+ char *name,
+ Object arg,
+ Object *result,
+ bool *errored)
+{
+ Channel *channel = NULL;
+
+ if (!(channel = pmap_get(uint64_t)(channels, id))) {
+ msgpack_rpc_free_object(arg);
+ return false;
+ }
+
+ if (kv_size(channel->call_stack) > 20) {
+ // 20 stack depth is more than anyone should ever need for RPC calls
+ *errored = true;
+ char buf[256];
+ snprintf(buf,
+ sizeof(buf),
+ "Channel %" PRIu64 " was closed due to a high stack depth "
+ "while processing a RPC call",
+ channel->id);
+ *result = STRING_OBJ(cstr_to_string(buf));
+ }
+
+ uint64_t request_id = channel->next_request_id++;
+ // Send the msgpack-rpc request
+ send_request(channel, request_id, name, arg);
+
+ if (!kv_size(channel->call_stack)) {
+ // This is the first frame, we must disable event deferral for this
+ // channel because we won't be returning until the client sends a
+ // response
+ if (channel->is_job) {
+ job_set_defer(channel->data.job, false);
+ } else {
+ rstream_set_defer(channel->data.streams.read, false);
+ }
+ }
+
+ // Push the frame
+ ChannelCallFrame frame = {request_id, false, NIL};
+ kv_push(ChannelCallFrame *, channel->call_stack, &frame);
+ size_t size = kv_size(channel->call_stack);
+
+ do {
+ event_poll(-1);
+ } while (
+ // Continue running if ...
+ channel->enabled && // the channel is still enabled
+ kv_size(channel->call_stack) >= size); // the call didn't return
+
+ if (!kv_size(channel->call_stack)) {
+ // Popped last frame, restore event deferral
+ if (channel->is_job) {
+ job_set_defer(channel->data.job, true);
+ } else {
+ rstream_set_defer(channel->data.streams.read, true);
+ }
+ if (!channel->enabled && !channel->rpc_call_level) {
+ // Close the channel if it has been disabled and we have not been called
+ // by `parse_msgpack`(It would be unsafe to close the channel otherwise)
+ close_channel(channel);
+ }
+ }
+
+ *errored = frame.errored;
+ *result = frame.result;
+
+ return true;
+}
+
/// Subscribes to event broadcasts
///
/// @param id The channel id
@@ -191,10 +275,17 @@ static void parse_msgpack(RStream *rstream, void *data, bool eof)
Channel *channel = data;
if (eof) {
- close_channel(channel);
+ char buf[256];
+ snprintf(buf,
+ sizeof(buf),
+ "Before returning from a RPC call, channel %" PRIu64 " was "
+ "closed by the client",
+ channel->id);
+ disable_channel(channel, buf);
return;
}
+ channel->rpc_call_level++;
uint32_t count = rstream_available(rstream);
// Feed the unpacker with data
@@ -205,23 +296,34 @@ static void parse_msgpack(RStream *rstream, void *data, bool eof)
msgpack_unpacked unpacked;
msgpack_unpacked_init(&unpacked);
UnpackResult result;
- msgpack_packer response;
// Deserialize everything we can.
while ((result = msgpack_rpc_unpack(channel->unpacker, &unpacked))
== kUnpackResultOk) {
- // Each object is a new msgpack-rpc request and requires an empty response
- msgpack_packer_init(&response, channel->sbuffer, msgpack_sbuffer_write);
+ if (kv_size(channel->call_stack) && is_rpc_response(&unpacked.data)) {
+ if (is_valid_rpc_response(&unpacked.data, channel)) {
+ call_stack_pop(&unpacked.data, channel);
+ } else {
+ char buf[256];
+ snprintf(buf,
+ sizeof(buf),
+ "Channel %" PRIu64 " returned a response that doesn't have "
+ " a matching id for the current RPC call. Ensure the client "
+ " is properly synchronized",
+ channel->id);
+ call_stack_unwind(channel, buf, 1);
+ }
+ msgpack_unpacked_destroy(&unpacked);
+ // Bail out from this event loop iteration
+ goto end;
+ }
+
// Perform the call
- msgpack_rpc_call(channel->id, &unpacked.data, &response);
- wstream_write(channel->data.streams.write,
- wstream_new_buffer(xmemdup(channel->sbuffer->data,
- channel->sbuffer->size),
- channel->sbuffer->size,
- free));
-
- // Clear the buffer for future calls
- msgpack_sbuffer_clear(channel->sbuffer);
+ WBuffer *resp = msgpack_rpc_call(channel->id, &unpacked.data, &out_buffer);
+ // write the response
+ if (!channel_write(channel, resp)) {
+ goto end;
+ }
}
if (result == kUnpackResultFail) {
@@ -231,50 +333,87 @@ static void parse_msgpack(RStream *rstream, void *data, bool eof)
// A not so uncommon cause for this might be deserializing objects with
// a high nesting level: msgpack will break when it's internal parse stack
// size exceeds MSGPACK_EMBED_STACK_SIZE(defined as 32 by default)
- msgpack_packer_init(&response, channel->sbuffer, msgpack_sbuffer_write);
- msgpack_pack_array(&response, 4);
- msgpack_pack_int(&response, 1);
- msgpack_pack_int(&response, 0);
- msgpack_rpc_error("Invalid msgpack payload. "
- "This error can also happen when deserializing "
- "an object with high level of nesting",
- &response);
- wstream_write(channel->data.streams.write,
- wstream_new_buffer(xmemdup(channel->sbuffer->data,
- channel->sbuffer->size),
- channel->sbuffer->size,
- free));
- // Clear the buffer for future calls
- msgpack_sbuffer_clear(channel->sbuffer);
+ send_error(channel, 0, "Invalid msgpack payload. "
+ "This error can also happen when deserializing "
+ "an object with high level of nesting");
+ }
+
+end:
+ channel->rpc_call_level--;
+ if (!channel->enabled && !kv_size(channel->call_stack)) {
+ // Now it's safe to destroy the channel
+ close_channel(channel);
}
}
-static void send_event(Channel *channel, char *type, Object data)
+static bool channel_write(Channel *channel, WBuffer *buffer)
{
- wstream_write(channel->data.streams.write, serialize_event(type, data));
+ bool success;
+
+ if (channel->is_job) {
+ success = job_write(channel->data.job, buffer);
+ } else {
+ success = wstream_write(channel->data.streams.write, buffer);
+ }
+
+ if (!success) {
+ // If the write failed for any reason, close the channel
+ char buf[256];
+ snprintf(buf,
+ sizeof(buf),
+ "Before returning from a RPC call, channel %" PRIu64 " was "
+ "closed due to a failed write",
+ channel->id);
+ disable_channel(channel, buf);
+ }
+
+ return success;
}
-static void broadcast_event(char *type, Object data)
+static void send_error(Channel *channel, uint64_t id, char *err)
+{
+ channel_write(channel, serialize_response(id, err, NIL, &out_buffer));
+}
+
+static void send_request(Channel *channel,
+ uint64_t id,
+ char *name,
+ Object arg)
+{
+ String method = {.size = strlen(name), .data = name};
+ channel_write(channel, serialize_request(id, method, arg, &out_buffer));
+}
+
+static void send_event(Channel *channel,
+ char *name,
+ Object arg)
+{
+ String method = {.size = strlen(name), .data = name};
+ channel_write(channel, serialize_request(0, method, arg, &out_buffer));
+}
+
+static void broadcast_event(char *name, Object arg)
{
kvec_t(Channel *) subscribed;
kv_init(subscribed);
Channel *channel;
map_foreach_value(channels, channel, {
- if (pmap_has(cstr_t)(channel->subscribed_events, type)) {
+ if (pmap_has(cstr_t)(channel->subscribed_events, name)) {
kv_push(Channel *, subscribed, channel);
}
});
if (!kv_size(subscribed)) {
- msgpack_rpc_free_object(data);
+ msgpack_rpc_free_object(arg);
goto end;
}
- WBuffer *buffer = serialize_event(type, data);
+ String method = {.size = strlen(name), .data = name};
+ WBuffer *buffer = serialize_request(0, method, arg, &out_buffer);
for (size_t i = 0; i < kv_size(subscribed); i++) {
- wstream_write(kv_A(subscribed, i)->data.streams.write, buffer);
+ channel_write(kv_A(subscribed, i), buffer);
}
end:
@@ -300,7 +439,6 @@ static void unsubscribe(Channel *channel, char *event)
static void close_channel(Channel *channel)
{
pmap_del(uint64_t)(channels, channel->id);
- msgpack_sbuffer_free(channel->sbuffer);
msgpack_unpacker_free(channel->unpacker);
if (channel->is_job) {
@@ -320,6 +458,7 @@ static void close_channel(Channel *channel)
});
pmap_free(cstr_t)(channel->subscribed_events);
+ kv_destroy(channel->call_stack);
free(channel);
}
@@ -329,29 +468,69 @@ static void close_cb(uv_handle_t *handle)
free(handle);
}
-static WBuffer *serialize_event(char *type, Object data)
-{
- String event_type = {.size = strnlen(type, EVENT_MAXLEN), .data = type};
- msgpack_packer packer;
- msgpack_packer_init(&packer, &msgpack_event_buffer, msgpack_sbuffer_write);
- msgpack_rpc_notification(event_type, data, &packer);
- WBuffer *rv = wstream_new_buffer(xmemdup(msgpack_event_buffer.data,
- msgpack_event_buffer.size),
- msgpack_event_buffer.size,
- free);
- msgpack_rpc_free_object(data);
- msgpack_sbuffer_clear(&msgpack_event_buffer);
-
- return rv;
-}
-
static Channel *register_channel()
{
Channel *rv = xmalloc(sizeof(Channel));
+ rv->enabled = true;
+ rv->rpc_call_level = 0;
rv->unpacker = msgpack_unpacker_new(MSGPACK_UNPACKER_INIT_BUFFER_SIZE);
- rv->sbuffer = msgpack_sbuffer_new();
rv->id = next_id++;
rv->subscribed_events = pmap_new(cstr_t)();
+ rv->next_request_id = 1;
+ kv_init(rv->call_stack);
pmap_put(uint64_t)(channels, rv->id, rv);
return rv;
}
+
+static bool is_rpc_response(msgpack_object *obj)
+{
+ return obj->type == MSGPACK_OBJECT_ARRAY
+ && obj->via.array.size == 4
+ && obj->via.array.ptr[0].type == MSGPACK_OBJECT_POSITIVE_INTEGER
+ && obj->via.array.ptr[0].via.u64 == 1
+ && obj->via.array.ptr[1].type == MSGPACK_OBJECT_POSITIVE_INTEGER;
+}
+
+static bool is_valid_rpc_response(msgpack_object *obj, Channel *channel)
+{
+ uint64_t response_id = obj->via.array.ptr[1].via.u64;
+ // Must be equal to the frame at the stack's bottom
+ return response_id == kv_A(channel->call_stack,
+ kv_size(channel->call_stack) - 1)->request_id;
+}
+
+static void call_stack_pop(msgpack_object *obj, Channel *channel)
+{
+ ChannelCallFrame *frame = kv_A(channel->call_stack,
+ kv_size(channel->call_stack) - 1);
+ frame->errored = obj->via.array.ptr[2].type != MSGPACK_OBJECT_NIL;
+ (void)kv_pop(channel->call_stack);
+
+ if (frame->errored) {
+ msgpack_rpc_to_object(&obj->via.array.ptr[2], &frame->result);
+ } else {
+ msgpack_rpc_to_object(&obj->via.array.ptr[3], &frame->result);
+ }
+}
+
+static void call_stack_unwind(Channel *channel, char *msg, int count)
+{
+ while (kv_size(channel->call_stack) && count--) {
+ ChannelCallFrame *frame = kv_pop(channel->call_stack);
+ frame->errored = true;
+ frame->result = STRING_OBJ(cstr_to_string(msg));
+ }
+}
+
+static void disable_channel(Channel *channel, char *msg)
+{
+ if (kv_size(channel->call_stack)) {
+ // Channel is currently in the middle of a call, remove all frames and mark
+ // it as "dead"
+ channel->enabled = false;
+ call_stack_unwind(channel, msg, -1);
+ } else {
+ // Safe to close it now
+ close_channel(channel);
+ }
+}