aboutsummaryrefslogtreecommitdiff
path: root/src/nvim/msgpack_rpc/channel.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/nvim/msgpack_rpc/channel.c')
-rw-r--r--src/nvim/msgpack_rpc/channel.c770
1 files changed, 770 insertions, 0 deletions
diff --git a/src/nvim/msgpack_rpc/channel.c b/src/nvim/msgpack_rpc/channel.c
new file mode 100644
index 0000000000..43bed54b2c
--- /dev/null
+++ b/src/nvim/msgpack_rpc/channel.c
@@ -0,0 +1,770 @@
+#include <stdbool.h>
+#include <string.h>
+#include <inttypes.h>
+
+#include <uv.h>
+#include <msgpack.h>
+
+#include "nvim/lib/klist.h"
+
+#include "nvim/api/private/helpers.h"
+#include "nvim/api/vim.h"
+#include "nvim/msgpack_rpc/channel.h"
+#include "nvim/os/event.h"
+#include "nvim/os/rstream.h"
+#include "nvim/os/rstream_defs.h"
+#include "nvim/os/wstream.h"
+#include "nvim/os/wstream_defs.h"
+#include "nvim/os/job.h"
+#include "nvim/os/job_defs.h"
+#include "nvim/msgpack_rpc/helpers.h"
+#include "nvim/vim.h"
+#include "nvim/ascii.h"
+#include "nvim/memory.h"
+#include "nvim/os_unix.h"
+#include "nvim/message.h"
+#include "nvim/term.h"
+#include "nvim/map.h"
+#include "nvim/log.h"
+#include "nvim/misc1.h"
+#include "nvim/lib/kvec.h"
+
#define CHANNEL_BUFFER_SIZE 0xffff

#if MIN_LOG_LEVEL > DEBUG_LOG_LEVEL
// Message logging compiles away unless debug-level logging is enabled
#define log_client_msg(...)
#define log_server_msg(...)
#endif

// One synchronous RPC call in flight on a channel. Lives on the stack of
// channel_send_call and is completed from the msgpack parser callback.
typedef struct {
  uint64_t request_id;     // id sent with the request; responses must match it
  bool returned, errored;  // returned: response arrived; errored: it is an error
  Object result;           // response payload (the error message when errored)
} ChannelCallFrame;

// A msgpack-rpc communication channel, backed either by a spawned job's
// stdio or by a socket/stdio stream pair (selected by `is_job`).
typedef struct {
  uint64_t id;                      // unique id, key into the `channels` map
  PMap(cstr_t) *subscribed_events;  // interned event names this channel receives
  bool is_job, closed;              // union discriminator / shutdown flag
  msgpack_unpacker *unpacker;       // incremental msgpack parser state
  union {
    Job *job;                       // when is_job: the spawned process
    struct {
      RStream *read;
      WStream *write;
      uv_stream_t *uv;              // NULL for the stdio (embedded) channel
    } streams;                      // when !is_job: stream pair
  } data;
  uint64_t next_request_id;         // id to use for the next outgoing request
  kvec_t(ChannelCallFrame *) call_stack;  // pending synchronous calls
} Channel;

// A request whose handling was deferred to the event queue
typedef struct {
  Channel *channel;
  MsgpackRpcRequestHandler handler;
  Array args;
  uint64_t request_id;
} RequestEvent;

#define RequestEventFreer(x)
KMEMPOOL_INIT(RequestEventPool, RequestEvent, RequestEventFreer)
kmempool_t(RequestEventPool) *request_event_pool = NULL;

static uint64_t next_id = 1;                // next channel id to assign
static PMap(uint64_t) *channels = NULL;     // id -> Channel registry
static PMap(cstr_t) *event_strings = NULL;  // interned event-name strings
static msgpack_sbuffer out_buffer;          // shared serialization scratch buffer

#ifdef INCLUDE_GENERATED_DECLARATIONS
# include "msgpack_rpc/channel.c.generated.h"
#endif
+
+/// Initializes the module
+void channel_init(void)
+{
+ request_event_pool = kmp_init(RequestEventPool);
+ channels = pmap_new(uint64_t)();
+ event_strings = pmap_new(cstr_t)();
+ msgpack_sbuffer_init(&out_buffer);
+
+ if (embedded_mode) {
+ channel_from_stdio();
+ }
+}
+
+/// Teardown the module
+void channel_teardown(void)
+{
+ if (!channels) {
+ return;
+ }
+
+ Channel *channel;
+
+ map_foreach_value(channels, channel, {
+ close_channel(channel);
+ });
+}
+
+/// Creates an API channel by starting a job and connecting to its
+/// stdin/stdout. stderr is forwarded to the editor error stream.
+///
+/// @param argv The argument vector for the process
+/// @return The channel id
+uint64_t channel_from_job(char **argv)
+{
+ Channel *channel = register_channel();
+ channel->is_job = true;
+
+ int status;
+ channel->data.job = job_start(argv,
+ channel,
+ job_out,
+ job_err,
+ job_exit,
+ 0,
+ &status);
+
+ if (status <= 0) {
+ free_channel(channel);
+ return 0;
+ }
+
+ return channel->id;
+}
+
+/// Creates an API channel from a libuv stream representing a tcp or
+/// pipe/socket client connection
+///
+/// @param stream The established connection
+void channel_from_stream(uv_stream_t *stream)
+{
+ Channel *channel = register_channel();
+ stream->data = NULL;
+ channel->is_job = false;
+ // read stream
+ channel->data.streams.read = rstream_new(parse_msgpack,
+ rbuffer_new(CHANNEL_BUFFER_SIZE),
+ channel);
+ rstream_set_stream(channel->data.streams.read, stream);
+ rstream_start(channel->data.streams.read);
+ // write stream
+ channel->data.streams.write = wstream_new(0);
+ wstream_set_stream(channel->data.streams.write, stream);
+ channel->data.streams.uv = stream;
+}
+
+bool channel_exists(uint64_t id)
+{
+ Channel *channel;
+ return (channel = pmap_get(uint64_t)(channels, id)) != NULL
+ && !channel->closed;
+}
+
+/// Sends event/arguments to channel
+///
+/// @param id The channel id. If 0, the event will be sent to all
+/// channels that have subscribed to the event type
+/// @param name The event name, an arbitrary string
+/// @param args Array with event arguments
+/// @return True if the event was sent successfully, false otherwise.
+bool channel_send_event(uint64_t id, char *name, Array args)
+{
+ Channel *channel = NULL;
+
+ if (id > 0) {
+ if (!(channel = pmap_get(uint64_t)(channels, id)) || channel->closed) {
+ api_free_array(args);
+ return false;
+ }
+ send_event(channel, name, args);
+ } else {
+ broadcast_event(name, args);
+ }
+
+ return true;
+}
+
+/// Sends a method call to a channel
+///
+/// @param id The channel id
+/// @param method_name The method name, an arbitrary string
+/// @param args Array with method arguments
+/// @param[out] error True if the return value is an error
+/// @return Whatever the remote method returned
+Object channel_send_call(uint64_t id,
+ char *method_name,
+ Array args,
+ Error *err)
+{
+ Channel *channel = NULL;
+
+ if (!(channel = pmap_get(uint64_t)(channels, id)) || channel->closed) {
+ api_set_error(err, Exception, _("Invalid channel \"%" PRIu64 "\""), id);
+ api_free_array(args);
+ return NIL;
+ }
+
+ if (kv_size(channel->call_stack) > 20) {
+ // 20 stack depth is more than anyone should ever need for RPC calls
+ api_set_error(err,
+ Exception,
+ _("Channel %" PRIu64 " crossed maximum stack depth"),
+ channel->id);
+ api_free_array(args);
+ return NIL;
+ }
+
+ uint64_t request_id = channel->next_request_id++;
+ // Send the msgpack-rpc request
+ send_request(channel, request_id, method_name, args);
+
+ // Push the frame
+ ChannelCallFrame frame = {request_id, false, false, NIL};
+ kv_push(ChannelCallFrame *, channel->call_stack, &frame);
+ event_poll_until(-1, frame.returned);
+ (void)kv_pop(channel->call_stack);
+
+ if (frame.errored) {
+ api_set_error(err, Exception, "%s", frame.result.data.string.data);
+ return NIL;
+ }
+
+ if (channel->closed && !kv_size(channel->call_stack)) {
+ free_channel(channel);
+ }
+
+ return frame.result;
+}
+
+/// Subscribes to event broadcasts
+///
+/// @param id The channel id
+/// @param event The event type string
+void channel_subscribe(uint64_t id, char *event)
+{
+ Channel *channel;
+
+ if (!(channel = pmap_get(uint64_t)(channels, id)) || channel->closed) {
+ abort();
+ }
+
+ char *event_string = pmap_get(cstr_t)(event_strings, event);
+
+ if (!event_string) {
+ event_string = xstrdup(event);
+ pmap_put(cstr_t)(event_strings, event_string, event_string);
+ }
+
+ pmap_put(cstr_t)(channel->subscribed_events, event_string, event_string);
+}
+
+/// Unsubscribes to event broadcasts
+///
+/// @param id The channel id
+/// @param event The event type string
+void channel_unsubscribe(uint64_t id, char *event)
+{
+ Channel *channel;
+
+ if (!(channel = pmap_get(uint64_t)(channels, id)) || channel->closed) {
+ abort();
+ }
+
+ unsubscribe(channel, event);
+}
+
+/// Closes a channel
+///
+/// @param id The channel id
+/// @return true if successful, false otherwise
+bool channel_close(uint64_t id)
+{
+ Channel *channel;
+
+ if (!(channel = pmap_get(uint64_t)(channels, id)) || channel->closed) {
+ return false;
+ }
+
+ close_channel(channel);
+ return true;
+}
+
+/// Creates an API channel from stdin/stdout. This is used when embedding
+/// Neovim
+static void channel_from_stdio(void)
+{
+ Channel *channel = register_channel();
+ channel->is_job = false;
+ // read stream
+ channel->data.streams.read = rstream_new(parse_msgpack,
+ rbuffer_new(CHANNEL_BUFFER_SIZE),
+ channel);
+ rstream_set_file(channel->data.streams.read, 0);
+ rstream_start(channel->data.streams.read);
+ // write stream
+ channel->data.streams.write = wstream_new(0);
+ wstream_set_file(channel->data.streams.write, 1);
+ channel->data.streams.uv = NULL;
+}
+
+static void job_out(RStream *rstream, void *data, bool eof)
+{
+ Job *job = data;
+ parse_msgpack(rstream, job_data(job), eof);
+}
+
+static void job_err(RStream *rstream, void *data, bool eof)
+{
+ size_t count;
+ char buf[256];
+ Channel *channel = job_data(data);
+
+ while ((count = rstream_pending(rstream))) {
+ size_t read = rstream_read(rstream, buf, sizeof(buf) - 1);
+ buf[read] = NUL;
+ ELOG("Channel %" PRIu64 " stderr: %s", channel->id, buf);
+ }
+}
+
+static void job_exit(Job *job, void *data)
+{
+ free_channel((Channel *)data);
+}
+
+static void parse_msgpack(RStream *rstream, void *data, bool eof)
+{
+ Channel *channel = data;
+
+ if (eof) {
+ close_channel(channel);
+ call_set_error(channel, "Channel was closed by the client");
+ return;
+ }
+
+ size_t count = rstream_pending(rstream);
+ DLOG("Feeding the msgpack parser with %u bytes of data from RStream(%p)",
+ count,
+ rstream);
+
+ // Feed the unpacker with data
+ msgpack_unpacker_reserve_buffer(channel->unpacker, count);
+ rstream_read(rstream, msgpack_unpacker_buffer(channel->unpacker), count);
+ msgpack_unpacker_buffer_consumed(channel->unpacker, count);
+
+ msgpack_unpacked unpacked;
+ msgpack_unpacked_init(&unpacked);
+ msgpack_unpack_return result;
+
+ // Deserialize everything we can.
+ while ((result = msgpack_unpacker_next(channel->unpacker, &unpacked)) ==
+ MSGPACK_UNPACK_SUCCESS) {
+ bool is_response = is_rpc_response(&unpacked.data);
+ log_client_msg(channel->id, !is_response, unpacked.data);
+
+ if (kv_size(channel->call_stack) && is_response) {
+ if (is_valid_rpc_response(&unpacked.data, channel)) {
+ complete_call(&unpacked.data, channel);
+ } else {
+ char buf[256];
+ snprintf(buf,
+ sizeof(buf),
+ "Channel %" PRIu64 " returned a response that doesn't have "
+ " a matching id for the current RPC call. Ensure the client "
+ " is properly synchronized",
+ channel->id);
+ call_set_error(channel, buf);
+ }
+ msgpack_unpacked_destroy(&unpacked);
+ // Bail out from this event loop iteration
+ return;
+ }
+
+ handle_request(channel, &unpacked.data);
+ }
+
+ if (result == MSGPACK_UNPACK_NOMEM_ERROR) {
+ OUT_STR(e_outofmem);
+ out_char('\n');
+ preserve_exit();
+ }
+
+ if (result == MSGPACK_UNPACK_PARSE_ERROR) {
+ // See src/msgpack/unpack_template.h in msgpack source tree for
+ // causes for this error(search for 'goto _failed')
+ //
+ // A not so uncommon cause for this might be deserializing objects with
+ // a high nesting level: msgpack will break when it's internal parse stack
+ // size exceeds MSGPACK_EMBED_STACK_SIZE(defined as 32 by default)
+ send_error(channel, 0, "Invalid msgpack payload. "
+ "This error can also happen when deserializing "
+ "an object with high level of nesting");
+ }
+}
+
/// Validates and dispatches one incoming msgpack-rpc request. Malformed
/// requests get an error response immediately; valid ones are either
/// handled right away or deferred to the event queue, depending on the
/// handler's `defer` flag.
static void handle_request(Channel *channel, msgpack_object *request)
  FUNC_ATTR_NONNULL_ALL
{
  uint64_t request_id;
  Error error = ERROR_INIT;
  // Checks the message structure and extracts the request id
  msgpack_rpc_validate(&request_id, request, &error);

  if (error.set) {
    // Validation failed, send response with error
    channel_write(channel,
                  serialize_response(channel->id,
                                     request_id,
                                     &error,
                                     NIL,
                                     &out_buffer));
    return;
  }

  // Retrieve the request handler
  MsgpackRpcRequestHandler handler;
  msgpack_object method = request->via.array.ptr[2];

  if (method.type == MSGPACK_OBJECT_BIN || method.type == MSGPACK_OBJECT_STR) {
    handler = msgpack_rpc_get_handler_for(method.via.bin.ptr,
                                          method.via.bin.size);
  } else {
    // Method name is not a string: respond with a "missing method" error
    handler.fn = msgpack_rpc_handle_missing_method;
    handler.defer = false;
  }

  // Element 3 of the request array holds the argument list
  Array args;
  msgpack_rpc_to_array(request->via.array.ptr + 3, &args);

  // Handle immediately when a synchronous call is pending (the response
  // must not be reordered behind it) or when the handler opts out of
  // deferral
  if (kv_size(channel->call_stack) || !handler.defer) {
    call_request_handler(channel, handler, args, request_id);
    return;
  }

  // Defer calling the request handler.
  RequestEvent *event_data = kmp_alloc(RequestEventPool, request_event_pool);
  event_data->channel = channel;
  event_data->handler = handler;
  event_data->args = args;
  event_data->request_id = request_id;
  event_push((Event) {
    .handler = on_request_event,
    .data = event_data
  });
}
+
+static void on_request_event(Event event)
+{
+ RequestEvent *e = event.data;
+ call_request_handler(e->channel, e->handler, e->args, e->request_id);
+ kmp_free(RequestEventPool, request_event_pool, e);
+}
+
+static void call_request_handler(Channel *channel,
+ MsgpackRpcRequestHandler handler,
+ Array args,
+ uint64_t request_id)
+{
+ Error error = ERROR_INIT;
+ Object result = handler.fn(channel->id, request_id, args, &error);
+ // send the response
+ msgpack_packer response;
+ msgpack_packer_init(&response, &out_buffer, msgpack_sbuffer_write);
+ channel_write(channel, serialize_response(channel->id,
+ request_id,
+ &error,
+ result,
+ &out_buffer));
+ // All arguments were freed already, but we still need to free the array
+ free(args.items);
+}
+
+static bool channel_write(Channel *channel, WBuffer *buffer)
+{
+ bool success;
+
+ if (channel->is_job) {
+ success = job_write(channel->data.job, buffer);
+ } else {
+ success = wstream_write(channel->data.streams.write, buffer);
+ }
+
+ if (!success) {
+ // If the write failed for any reason, close the channel
+ char buf[256];
+ snprintf(buf,
+ sizeof(buf),
+ "Before returning from a RPC call, channel %" PRIu64 " was "
+ "closed due to a failed write",
+ channel->id);
+ call_set_error(channel, buf);
+ }
+
+ return success;
+}
+
+static void send_error(Channel *channel, uint64_t id, char *err)
+{
+ Error e = ERROR_INIT;
+ api_set_error(&e, Exception, "%s", err);
+ channel_write(channel, serialize_response(channel->id,
+ id,
+ &e,
+ NIL,
+ &out_buffer));
+}
+
+static void send_request(Channel *channel,
+ uint64_t id,
+ char *name,
+ Array args)
+{
+ String method = {.size = strlen(name), .data = name};
+ channel_write(channel, serialize_request(channel->id,
+ id,
+ method,
+ args,
+ &out_buffer,
+ 1));
+}
+
+static void send_event(Channel *channel,
+ char *name,
+ Array args)
+{
+ String method = {.size = strlen(name), .data = name};
+ channel_write(channel, serialize_request(channel->id,
+ 0,
+ method,
+ args,
+ &out_buffer,
+ 1));
+}
+
+static void broadcast_event(char *name, Array args)
+{
+ kvec_t(Channel *) subscribed;
+ kv_init(subscribed);
+ Channel *channel;
+
+ map_foreach_value(channels, channel, {
+ if (pmap_has(cstr_t)(channel->subscribed_events, name)) {
+ kv_push(Channel *, subscribed, channel);
+ }
+ });
+
+ if (!kv_size(subscribed)) {
+ api_free_array(args);
+ goto end;
+ }
+
+ String method = {.size = strlen(name), .data = name};
+ WBuffer *buffer = serialize_request(0,
+ 0,
+ method,
+ args,
+ &out_buffer,
+ kv_size(subscribed));
+
+ for (size_t i = 0; i < kv_size(subscribed); i++) {
+ channel_write(kv_A(subscribed, i), buffer);
+ }
+
+end:
+ kv_destroy(subscribed);
+}
+
+static void unsubscribe(Channel *channel, char *event)
+{
+ char *event_string = pmap_get(cstr_t)(event_strings, event);
+ pmap_del(cstr_t)(channel->subscribed_events, event_string);
+
+ map_foreach_value(channels, channel, {
+ if (pmap_has(cstr_t)(channel->subscribed_events, event_string)) {
+ return;
+ }
+ });
+
+ // Since the string is no longer used by other channels, release it's memory
+ pmap_del(cstr_t)(event_strings, event_string);
+ free(event_string);
+}
+
/// Close the channel streams/job. The channel resources will be freed by
/// free_channel later.
static void close_channel(Channel *channel)
{
  if (channel->closed) {
    return;  // idempotent: may be called more than once per channel
  }

  channel->closed = true;
  if (channel->is_job) {
    if (channel->data.job) {
      job_stop(channel->data.job);
    }
  } else {
    rstream_free(channel->data.streams.read);
    wstream_free(channel->data.streams.write);
    uv_handle_t *handle = (uv_handle_t *)channel->data.streams.uv;
    if (handle) {
      uv_close(handle, close_cb);
    } else {
      // No libuv handle means this is the stdio channel of an embedded
      // instance; losing it leaves the editor ownerless, so exit
      mch_exit(0);
    }
  }
}
+
/// Releases all resources owned by `channel` and removes it from the
/// registry. Should only run after close_channel shut down its streams/job.
static void free_channel(Channel *channel)
{
  pmap_del(uint64_t)(channels, channel->id);
  msgpack_unpacker_free(channel->unpacker);

  // Unsubscribe from all events
  char *event_string;
  map_foreach_value(channel->subscribed_events, event_string, {
    // NOTE(review): unsubscribe() deletes entries from the map being
    // iterated here — assumes map_foreach_value tolerates deletion of the
    // current entry; verify against the khash-based map implementation
    unsubscribe(channel, event_string);
  });

  pmap_free(cstr_t)(channel->subscribed_events);
  kv_destroy(channel->call_stack);
  free(channel);
}
+
/// libuv close callback: frees the handle's attached data (free(NULL) is a
/// no-op when none was set) and then the handle itself.
static void close_cb(uv_handle_t *handle)
{
  free(handle->data);
  free(handle);
}
+
+static Channel *register_channel(void)
+{
+ Channel *rv = xmalloc(sizeof(Channel));
+ rv->closed = false;
+ rv->unpacker = msgpack_unpacker_new(MSGPACK_UNPACKER_INIT_BUFFER_SIZE);
+ rv->id = next_id++;
+ rv->subscribed_events = pmap_new(cstr_t)();
+ rv->next_request_id = 1;
+ kv_init(rv->call_stack);
+ pmap_put(uint64_t)(channels, rv->id, rv);
+ return rv;
+}
+
+static bool is_rpc_response(msgpack_object *obj)
+{
+ return obj->type == MSGPACK_OBJECT_ARRAY
+ && obj->via.array.size == 4
+ && obj->via.array.ptr[0].type == MSGPACK_OBJECT_POSITIVE_INTEGER
+ && obj->via.array.ptr[0].via.u64 == 1
+ && obj->via.array.ptr[1].type == MSGPACK_OBJECT_POSITIVE_INTEGER;
+}
+
+static bool is_valid_rpc_response(msgpack_object *obj, Channel *channel)
+{
+ uint64_t response_id = obj->via.array.ptr[1].via.u64;
+ // Must be equal to the frame at the stack's bottom
+ return response_id == kv_A(channel->call_stack,
+ kv_size(channel->call_stack) - 1)->request_id;
+}
+
+static void complete_call(msgpack_object *obj, Channel *channel)
+{
+ ChannelCallFrame *frame = kv_A(channel->call_stack,
+ kv_size(channel->call_stack) - 1);
+ frame->returned = true;
+ frame->errored = obj->via.array.ptr[2].type != MSGPACK_OBJECT_NIL;
+
+ if (frame->errored) {
+ msgpack_rpc_to_object(&obj->via.array.ptr[2], &frame->result);
+ } else {
+ msgpack_rpc_to_object(&obj->via.array.ptr[3], &frame->result);
+ }
+}
+
+static void call_set_error(Channel *channel, char *msg)
+{
+ for (size_t i = 0; i < kv_size(channel->call_stack); i++) {
+ ChannelCallFrame *frame = kv_A(channel->call_stack, i);
+ frame->returned = true;
+ frame->errored = true;
+ frame->result = STRING_OBJ(cstr_to_string(msg));
+ }
+
+ close_channel(channel);
+}
+
/// Packs a msgpack-rpc request/notification into `sbuffer`, snapshots the
/// bytes into a refcounted WBuffer and clears the shared buffer for reuse.
/// Takes ownership of `args` (freed here after packing).
///
/// @param refcount Number of write streams the buffer will be queued on
static WBuffer *serialize_request(uint64_t channel_id,
                                  uint64_t request_id,
                                  String method,
                                  Array args,
                                  msgpack_sbuffer *sbuffer,
                                  size_t refcount)
{
  msgpack_packer pac;
  msgpack_packer_init(&pac, sbuffer, msgpack_sbuffer_write);
  msgpack_rpc_serialize_request(request_id, method, args, &pac);
  // Log before the buffer is cleared, while the bytes are still available
  log_server_msg(channel_id, sbuffer);
  // Copy out of the shared sbuffer so it can be reused immediately
  WBuffer *rv = wstream_new_buffer(xmemdup(sbuffer->data, sbuffer->size),
                                   sbuffer->size,
                                   refcount,
                                   free);
  msgpack_sbuffer_clear(sbuffer);
  api_free_array(args);
  return rv;
}
+
/// Packs a msgpack-rpc response into `sbuffer`, snapshots the bytes into a
/// WBuffer and clears the shared buffer for reuse. Takes ownership of
/// `arg` (freed here after packing).
static WBuffer *serialize_response(uint64_t channel_id,
                                   uint64_t response_id,
                                   Error *err,
                                   Object arg,
                                   msgpack_sbuffer *sbuffer)
{
  msgpack_packer pac;
  msgpack_packer_init(&pac, sbuffer, msgpack_sbuffer_write);
  msgpack_rpc_serialize_response(response_id, err, arg, &pac);
  // Log before the buffer is cleared, while the bytes are still available
  log_server_msg(channel_id, sbuffer);
  WBuffer *rv = wstream_new_buffer(xmemdup(sbuffer->data, sbuffer->size),
                                   sbuffer->size,
                                   1, // responses only go though 1 channel
                                   free);
  msgpack_sbuffer_clear(sbuffer);
  api_free_object(arg);
  return rv;
}
+
#if MIN_LOG_LEVEL <= DEBUG_LOG_LEVEL
#define REQ "[request] "
#define RES "[response] "
#define NOT "[notification] "

/// Logs an outgoing message. `packed` still contains the serialized bytes,
/// which are unpacked again here just for pretty-printing.
static void log_server_msg(uint64_t channel_id,
                           msgpack_sbuffer *packed)
{
  msgpack_unpacked unpacked;
  msgpack_unpacked_init(&unpacked);
  msgpack_unpack_next(&unpacked, packed->data, packed->size, NULL);
  // First array element is the msgpack-rpc type tag: 0=request,
  // 1=response, 2=notification
  uint64_t type = unpacked.data.via.array.ptr[0].via.u64;
  DLOGN("[msgpack-rpc] nvim -> client(%" PRIu64 ") ", channel_id);
  FILE *f = open_log_file();
  fprintf(f, type ? (type == 1 ? RES : NOT) : REQ);
  log_msg_close(f, unpacked.data);
  msgpack_unpacked_destroy(&unpacked);
}

/// Logs an incoming message (already deserialized by the parser).
static void log_client_msg(uint64_t channel_id,
                           bool is_request,
                           msgpack_object msg)
{
  DLOGN("[msgpack-rpc] client(%" PRIu64 ") -> nvim ", channel_id);
  FILE *f = open_log_file();
  fprintf(f, is_request ? REQ : RES);
  log_msg_close(f, msg);
}

/// Pretty-prints `msg` to `f`, then flushes and closes the log file.
static void log_msg_close(FILE *f, msgpack_object msg)
{
  msgpack_object_print(f, msg);
  fputc('\n', f);
  fflush(f);
  fclose(f);
}
#endif