diff options
-rw-r--r-- | runtime/doc/eval.txt | 15 | ||||
-rw-r--r-- | src/nvim/eval.c | 48 | ||||
-rw-r--r-- | src/nvim/eval.lua | 1 | ||||
-rw-r--r-- | src/nvim/event/socket.c | 74 | ||||
-rw-r--r-- | src/nvim/msgpack_rpc/channel.c | 95 | ||||
-rw-r--r-- | src/nvim/msgpack_rpc/server.c | 12 | ||||
-rw-r--r-- | src/nvim/path.c | 2 | ||||
-rw-r--r-- | test/functional/api/server_requests_spec.lua | 73 | ||||
-rw-r--r-- | test/functional/helpers.lua | 5 |
9 files changed, 308 insertions, 17 deletions
diff --git a/runtime/doc/eval.txt b/runtime/doc/eval.txt index 4e71b89067..b692b28418 100644 --- a/runtime/doc/eval.txt +++ b/runtime/doc/eval.txt @@ -6863,6 +6863,21 @@ sinh({expr}) *sinh()* :echo sinh(-0.9) < -1.026517 +sockconnect({mode}, {address}, {opts}) *sockconnect()* + Connect a socket to an address. If {mode} is "pipe" then + {address} should be the path of a named pipe. If {mode} is + "tcp" then {address} should be of the form "host:port" where + the host should be an ip adderess or host name, and port the + port number. Currently only rpc sockets are supported, so + {opts} must be passed with "rpc" set to |TRUE|. + + {opts} is a dictionary with these keys: + rpc : If set, |msgpack-rpc| will be used to communicate + over the socket. + Returns: + - The channel ID on success, which is used by + |rpcnotify()| and |rpcrequest()| and |rpcstop()|. + - 0 on invalid arguments or connection failure. sort({list} [, {func} [, {dict}]]) *sort()* *E702* Sort the items in {list} in-place. Returns {list}. diff --git a/src/nvim/eval.c b/src/nvim/eval.c index 7bfd638e86..4e303414a3 100644 --- a/src/nvim/eval.c +++ b/src/nvim/eval.c @@ -15058,6 +15058,54 @@ static void f_simplify(typval_T *argvars, typval_T *rettv, FunPtr fptr) rettv->v_type = VAR_STRING; } +/// "sockconnect()" function +static void f_sockconnect(typval_T *argvars, typval_T *rettv, FunPtr fptr) +{ + if (argvars[0].v_type != VAR_STRING || argvars[1].v_type != VAR_STRING) { + EMSG(_(e_invarg)); + return; + } + if (argvars[2].v_type != VAR_DICT && argvars[2].v_type != VAR_UNKNOWN) { + // Wrong argument types + EMSG2(_(e_invarg2), "expected dictionary"); + return; + } + + const char *mode = tv_get_string(&argvars[0]); + const char *address = tv_get_string(&argvars[1]); + + bool tcp; + if (strcmp(mode, "tcp") == 0) { + tcp = true; + } else if (strcmp(mode, "pipe") == 0) { + tcp = false; + } else { + EMSG2(_(e_invarg2), "invalid mode"); + return; + } + + bool rpc = false; + if (argvars[2].v_type == VAR_DICT) { + dict_T *opts = argvars[2].vval.v_dict; + rpc = tv_dict_get_number(opts, "rpc") != 0; + } + + if (!rpc) { + EMSG2(_(e_invarg2), "rpc option must be true"); + return; + } + + const char *error = NULL; + uint64_t id = channel_connect(tcp, address, 50, &error); + + if (error) { + EMSG2(_("connection failed: %s"), error); + } + + rettv->vval.v_number = (varnumber_T)id; + rettv->v_type = VAR_NUMBER; +} + /// struct used in the array that's given to qsort() typedef struct { listitem_T *item; diff --git a/src/nvim/eval.lua b/src/nvim/eval.lua index 533403b2b0..334e10eb6c 100644 --- a/src/nvim/eval.lua +++ b/src/nvim/eval.lua @@ -268,6 +268,7 @@ return { simplify={args=1}, sin={args=1, func="float_op_wrapper", data="&sin"}, sinh={args=1, func="float_op_wrapper", data="&sinh"}, + sockconnect={args={2,3}}, sort={args={1, 3}}, soundfold={args=1}, spellbadword={args={0, 1}}, diff --git a/src/nvim/event/socket.c b/src/nvim/event/socket.c index bc5a4ec75e..30a71a5586 100644 --- a/src/nvim/event/socket.c +++ b/src/nvim/event/socket.c @@ -15,6 +15,7 @@ #include "nvim/vim.h" #include "nvim/strings.h" #include "nvim/path.h" +#include "nvim/main.h" #include "nvim/memory.h" #include "nvim/macros.h" #include "nvim/charset.h" @@ -189,3 +190,76 @@ static void close_cb(uv_handle_t *handle) watcher->close_cb(watcher, watcher->data); } } + +static void connect_cb(uv_connect_t *req, int status) +{ + int *ret_status = req->data; + *ret_status = status; + if (status != 0) { + uv_close((uv_handle_t *)req->handle, NULL); + } +} + +bool socket_connect(Loop *loop, Stream *stream, + bool is_tcp, const char *address, + int timeout, const char **error) +{ + bool success = false; + int status; + uv_connect_t req; + req.data = &status; + uv_stream_t *uv_stream; + + uv_tcp_t *tcp = &stream->uv.tcp; + uv_getaddrinfo_t addr_req; + addr_req.addrinfo = NULL; + const struct addrinfo *addrinfo = NULL; + char *addr = NULL; + if (is_tcp) { + addr = xstrdup(address); + char *host_end = strrchr(addr, ':'); + if (!host_end) { + *error = _("tcp address must be host:port"); + goto cleanup; + } + *host_end = NUL; + + const struct addrinfo hints = { .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_flags = AI_NUMERICSERV }; + int retval = uv_getaddrinfo(&loop->uv, &addr_req, NULL, + addr, host_end+1, &hints); + if (retval != 0) { + *error = _("failed to lookup host or port"); + goto cleanup; + } + addrinfo = addr_req.addrinfo; + +tcp_retry: + uv_tcp_init(&loop->uv, tcp); + uv_tcp_connect(&req, tcp, addrinfo->ai_addr, connect_cb); + uv_stream = (uv_stream_t *)tcp; + + } else { + uv_pipe_t *pipe = &stream->uv.pipe; + uv_pipe_init(&loop->uv, pipe, 0); + uv_pipe_connect(&req, pipe, address, connect_cb); + uv_stream = (uv_stream_t *)pipe; + } + status = 1; + LOOP_PROCESS_EVENTS_UNTIL(&main_loop, NULL, timeout, status != 1); + if (status == 0) { + stream_init(NULL, stream, -1, uv_stream); + success = true; + } else if (is_tcp && addrinfo->ai_next) { + addrinfo = addrinfo->ai_next; + goto tcp_retry; + } else { + *error = _("connection refused"); + } + +cleanup: + xfree(addr); + uv_freeaddrinfo(addr_req.addrinfo); + return success; +} diff --git a/src/nvim/msgpack_rpc/channel.c b/src/nvim/msgpack_rpc/channel.c index cd64e14976..e8ee0ede75 100644 --- a/src/nvim/msgpack_rpc/channel.c +++ b/src/nvim/msgpack_rpc/channel.c @@ -12,6 +12,7 @@ #include "nvim/api/vim.h" #include "nvim/api/ui.h" #include "nvim/msgpack_rpc/channel.h" +#include "nvim/msgpack_rpc/server.h" #include "nvim/event/loop.h" #include "nvim/event/libuv_process.h" #include "nvim/event/rstream.h" @@ -28,6 +29,7 @@ #include "nvim/map.h" #include "nvim/log.h" #include "nvim/misc1.h" +#include "nvim/path.h" #include "nvim/lib/kvec.h" #include "nvim/os/input.h" @@ -41,7 +43,8 @@ typedef enum { kChannelTypeSocket, kChannelTypeProc, - kChannelTypeStdio + kChannelTypeStdio, + kChannelTypeInternal } ChannelType; typedef struct { @@ -125,7 +128,7 @@ uint64_t channel_from_process(Process *proc, uint64_t id) wstream_init(proc->in, 0); rstream_init(proc->out, 0); - rstream_start(proc->out, parse_msgpack, channel); + rstream_start(proc->out, receive_msgpack, channel); return channel->id; } @@ -142,7 +145,36 @@ void channel_from_connection(SocketWatcher *watcher) channel->data.stream.internal_data = channel; wstream_init(&channel->data.stream, 0); rstream_init(&channel->data.stream, CHANNEL_BUFFER_SIZE); - rstream_start(&channel->data.stream, parse_msgpack, channel); + rstream_start(&channel->data.stream, receive_msgpack, channel); +} + +uint64_t channel_connect(bool tcp, const char *address, + int timeout, const char **error) +{ + if (!tcp) { + char *path = fix_fname(address); + if (server_owns_pipe_address(path)) { + // avoid deadlock + xfree(path); + return channel_create_internal(); + } + xfree(path); + } + + Channel *channel = register_channel(kChannelTypeSocket, 0, NULL); + if (!socket_connect(&main_loop, &channel->data.stream, + tcp, address, timeout, error)) { + decref(channel); + return 0; + } + + incref(channel); // close channel only after the stream is closed + channel->data.stream.internal_close_cb = close_cb; + channel->data.stream.internal_data = channel; + wstream_init(&channel->data.stream, 0); + rstream_init(&channel->data.stream, CHANNEL_BUFFER_SIZE); + rstream_start(&channel->data.stream, receive_msgpack, channel); + return channel->id; } /// Sends event/arguments to channel @@ -305,11 +337,20 @@ void channel_from_stdio(void) incref(channel); // stdio channels are only closed on exit // read stream rstream_init_fd(&main_loop, &channel->data.std.in, 0, CHANNEL_BUFFER_SIZE); - rstream_start(&channel->data.std.in, parse_msgpack, channel); + rstream_start(&channel->data.std.in, receive_msgpack, channel); // write stream wstream_init_fd(&main_loop, &channel->data.std.out, 1, 0); } +/// Creates a loopback channel. This is used to avoid deadlock +/// when an instance connects to its own named pipe. +uint64_t channel_create_internal(void) +{ + Channel *channel = register_channel(kChannelTypeInternal, 0, NULL); + incref(channel); // internal channel lives until process exit + return channel->id; +} + void channel_process_exit(uint64_t id, int status) { Channel *channel = pmap_get(uint64_t)(channels, id); @@ -318,8 +359,8 @@ void channel_process_exit(uint64_t id, int status) decref(channel); } -static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data, - bool eof) +static void receive_msgpack(Stream *stream, RBuffer *rbuf, size_t c, + void *data, bool eof) { Channel *channel = data; incref(channel); @@ -341,6 +382,14 @@ static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data, rbuffer_read(rbuf, msgpack_unpacker_buffer(channel->unpacker), count); msgpack_unpacker_buffer_consumed(channel->unpacker, count); + parse_msgpack(channel); + +end: + decref(channel); +} + +static void parse_msgpack(Channel *channel) +{ msgpack_unpacked unpacked; msgpack_unpacked_init(&unpacked); msgpack_unpack_return result; @@ -364,7 +413,7 @@ static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data, } msgpack_unpacked_destroy(&unpacked); // Bail out from this event loop iteration - goto end; + return; } handle_request(channel, &unpacked.data); @@ -388,11 +437,9 @@ static void parse_msgpack(Stream *stream, RBuffer *rbuf, size_t c, void *data, "This error can also happen when deserializing " "an object with high level of nesting"); } - -end: - decref(channel); } + static void handle_request(Channel *channel, msgpack_object *request) FUNC_ATTR_NONNULL_ALL { @@ -502,8 +549,11 @@ static bool channel_write(Channel *channel, WBuffer *buffer) case kChannelTypeStdio: success = wstream_write(&channel->data.std.out, buffer); break; - default: - abort(); + case kChannelTypeInternal: + incref(channel); + CREATE_EVENT(channel->events, internal_read_event, 2, channel, buffer); + success = true; + break; } if (!success) { @@ -520,6 +570,22 @@ static bool channel_write(Channel *channel, WBuffer *buffer) return success; } +static void internal_read_event(void **argv) +{ + Channel *channel = argv[0]; + WBuffer *buffer = argv[1]; + + msgpack_unpacker_reserve_buffer(channel->unpacker, buffer->size); + memcpy(msgpack_unpacker_buffer(channel->unpacker), + buffer->data, buffer->size); + msgpack_unpacker_buffer_consumed(channel->unpacker, buffer->size); + + parse_msgpack(channel); + + decref(channel); + wstream_release_wbuffer(buffer); +} + static void send_error(Channel *channel, uint64_t id, char *err) { Error e = ERROR_INIT; @@ -636,8 +702,9 @@ static void close_channel(Channel *channel) stream_close(&channel->data.std.out, NULL, NULL); multiqueue_put(main_loop.fast_events, exit_event, 1, channel); return; - default: - abort(); + case kChannelTypeInternal: + // nothing to free. + break; } decref(channel); diff --git a/src/nvim/msgpack_rpc/server.c b/src/nvim/msgpack_rpc/server.c index bae5a32850..c9edd05dc2 100644 --- a/src/nvim/msgpack_rpc/server.c +++ b/src/nvim/msgpack_rpc/server.c @@ -97,6 +97,18 @@ char *server_address_new(void) #endif } +/// Check if this instance owns a pipe address. +/// The argument must already be resolved to an absolute path! +bool server_owns_pipe_address(const char *path) +{ + for (int i = 0; i < watchers.ga_len; i++) { + if (!strcmp(path, ((SocketWatcher **)watchers.ga_data)[i]->addr)) { + return true; + } + } + return false; +} + /// Starts listening for API calls. /// /// The socket type is determined by parsing `endpoint`: If it's a valid IPv4 diff --git a/src/nvim/path.c b/src/nvim/path.c index 9162b6da4d..f2339c8046 100644 --- a/src/nvim/path.c +++ b/src/nvim/path.c @@ -1715,7 +1715,7 @@ int vim_FullName(const char *fname, char *buf, size_t len, bool force) /// /// @param fname is the filename to expand /// @return [allocated] Full path (NULL for failure). -char *fix_fname(char *fname) +char *fix_fname(const char *fname) { #ifdef UNIX return FullName_save(fname, true); diff --git a/test/functional/api/server_requests_spec.lua b/test/functional/api/server_requests_spec.lua index 658077b112..cf15062325 100644 --- a/test/functional/api/server_requests_spec.lua +++ b/test/functional/api/server_requests_spec.lua @@ -9,6 +9,8 @@ local nvim_prog, command, funcs = helpers.nvim_prog, helpers.command, helpers.fu local source, next_message = helpers.source, helpers.next_message local ok = helpers.ok local meths = helpers.meths +local spawn, nvim_argv = helpers.spawn, helpers.nvim_argv +local set_session = helpers.set_session describe('server -> client', function() local cid @@ -225,4 +227,75 @@ describe('server -> client', function() end) end) + describe('when connecting to another nvim instance', function() + local function connect_test(server, mode, address) + local serverpid = funcs.getpid() + local client = spawn(nvim_argv) + set_session(client, true) + local clientpid = funcs.getpid() + neq(serverpid, clientpid) + local id = funcs.sockconnect(mode, address, {rpc=true}) + ok(id > 0) + + funcs.rpcrequest(id, 'nvim_set_current_line', 'hello') + local client_id = funcs.rpcrequest(id, 'nvim_get_api_info')[1] + + set_session(server, true) + eq(serverpid, funcs.getpid()) + eq('hello', meths.get_current_line()) + + -- method calls work both ways + funcs.rpcrequest(client_id, 'nvim_set_current_line', 'howdy!') + eq(id, funcs.rpcrequest(client_id, 'nvim_get_api_info')[1]) + + set_session(client, true) + eq(clientpid, funcs.getpid()) + eq('howdy!', meths.get_current_line()) + + server:close() + client:close() + end + + it('over a named pipe', function() + local server = spawn(nvim_argv) + set_session(server) + local address = funcs.serverlist()[1] + local first = string.sub(address,1,1) + ok(first == '/' or first == '\\') + connect_test(server, 'pipe', address) + end) + + it('to an ip adress', function() + local server = spawn(nvim_argv) + set_session(server) + local address = funcs.serverstart("127.0.0.1:") + eq('127.0.0.1:', string.sub(address,1,10)) + connect_test(server, 'tcp', address) + end) + + it('to a hostname', function() + local server = spawn(nvim_argv) + set_session(server) + local address = funcs.serverstart("localhost:") + eq('localhost:', string.sub(address,1,10)) + connect_test(server, 'tcp', address) + end) + end) + + describe('when connecting to its own pipe adress', function() + it('it does not deadlock', function() + local address = funcs.serverlist()[1] + local first = string.sub(address,1,1) + ok(first == '/' or first == '\\') + local serverpid = funcs.getpid() + + local id = funcs.sockconnect('pipe', address, {rpc=true}) + + funcs.rpcrequest(id, 'nvim_set_current_line', 'hello') + eq('hello', meths.get_current_line()) + eq(serverpid, funcs.rpcrequest(id, "nvim_eval", "getpid()")) + + eq(id, funcs.rpcrequest(id, 'nvim_get_api_info')[1]) + end) + end) end) diff --git a/test/functional/helpers.lua b/test/functional/helpers.lua index b03840b3fe..62b0ce1200 100644 --- a/test/functional/helpers.lua +++ b/test/functional/helpers.lua @@ -76,8 +76,8 @@ end local session, loop_running, last_error -local function set_session(s) - if session then +local function set_session(s, keep) + if session and not keep then session:close() end session = s @@ -609,6 +609,7 @@ local module = { nvim = nvim, nvim_async = nvim_async, nvim_prog = nvim_prog, + nvim_argv = nvim_argv, nvim_set = nvim_set, nvim_dir = nvim_dir, buffer = buffer, |