aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLewis Russell <lewis6991@gmail.com>2024-11-18 17:15:05 +0000
committerLewis Russell <me@lewisr.dev>2024-11-21 11:35:16 +0000
commit534544cbf7ac92aef44336cc9da1bfc02a441e6e (patch)
treed93f781e0925aa84b6c4104ef1fd8a84f98c6629
parent01026ba47ba8a656bb5cd09afbb25b4b33c0b752 (diff)
downloadrneovim-534544cbf7ac92aef44336cc9da1bfc02a441e6e.tar.gz
rneovim-534544cbf7ac92aef44336cc9da1bfc02a441e6e.tar.bz2
rneovim-534544cbf7ac92aef44336cc9da1bfc02a441e6e.zip
test: move exec_lua logic to separate module
By making it a separate module, the embedded Nvim session can require this module directly instead of setup code sending over the module via RPC. Also make exec_lua wrap _G.print so messages can be seen in the test output immediately as the exec_lua returns.
-rw-r--r--test/functional/testnvim.lua122
-rw-r--r--test/functional/testnvim/exec_lua.lua148
2 files changed, 149 insertions, 121 deletions
diff --git a/test/functional/testnvim.lua b/test/functional/testnvim.lua
index 60b2f872fc..43c38d18c0 100644
--- a/test/functional/testnvim.lua
+++ b/test/functional/testnvim.lua
@@ -800,81 +800,6 @@ function M.exec_capture(code)
return M.api.nvim_exec2(code, { output = true }).output
end
---- @param f function
---- @return table<string,any>
-local function get_upvalues(f)
- local i = 1
- local upvalues = {} --- @type table<string,any>
- while true do
- local n, v = debug.getupvalue(f, i)
- if not n then
- break
- end
- upvalues[n] = v
- i = i + 1
- end
- return upvalues
-end
-
---- @param f function
---- @param upvalues table<string,any>
-local function set_upvalues(f, upvalues)
- local i = 1
- while true do
- local n = debug.getupvalue(f, i)
- if not n then
- break
- end
- if upvalues[n] then
- debug.setupvalue(f, i, upvalues[n])
- end
- i = i + 1
- end
-end
-
---- @type fun(f: function): table<string,any>
-_G.__get_upvalues = nil
-
---- @type fun(f: function, upvalues: table<string,any>)
-_G.__set_upvalues = nil
-
---- @param self table<string,function>
---- @param bytecode string
---- @param upvalues table<string,any>
---- @param ... any[]
---- @return any[] result
---- @return table<string,any> upvalues
-local function exec_lua_handler(self, bytecode, upvalues, ...)
- local f = assert(loadstring(bytecode))
- self.set_upvalues(f, upvalues)
- local ret = { f(...) } --- @type any[]
- --- @type table<string,any>
- local new_upvalues = self.get_upvalues(f)
-
- do -- Check return value types for better error messages
- local invalid_types = {
- ['thread'] = true,
- ['function'] = true,
- ['userdata'] = true,
- }
-
- for k, v in pairs(ret) do
- if invalid_types[type(v)] then
- error(
- string.format(
- "Return index %d with value '%s' of type '%s' cannot be serialized over RPC",
- k,
- tostring(v),
- type(v)
- )
- )
- end
- end
- end
-
- return ret, new_upvalues
-end
-
--- Execute Lua code in the wrapped Nvim session.
---
--- When `code` is passed as a function, it is converted into Lua byte code.
@@ -921,52 +846,7 @@ function M.exec_lua(code, ...)
end
assert(session, 'no Nvim session')
-
- if not session.exec_lua_setup then
- assert(
- session:request(
- 'nvim_exec_lua',
- [[
- _G.__test_exec_lua = {
- get_upvalues = loadstring((select(1,...))),
- set_upvalues = loadstring((select(2,...))),
- handler = loadstring((select(3,...)))
- }
- setmetatable(_G.__test_exec_lua, { __index = _G.__test_exec_lua })
- ]],
- { string.dump(get_upvalues), string.dump(set_upvalues), string.dump(exec_lua_handler) }
- )
- )
- session.exec_lua_setup = true
- end
-
- local stat, rv = session:request(
- 'nvim_exec_lua',
- 'return { _G.__test_exec_lua:handler(...) }',
- { string.dump(code), get_upvalues(code), ... }
- )
-
- if not stat then
- error(rv[2])
- end
-
- --- @type any[], table<string,any>
- local ret, upvalues = unpack(rv)
-
- -- Update upvalues
- if next(upvalues) then
- local caller = debug.getinfo(2)
- local f = caller.func
- -- On PUC-Lua, if the function is a tail call, then func will be nil.
- -- In this case we need to use the current function.
- if not f then
- assert(caller.source == '=(tail call)')
- f = debug.getinfo(1).func
- end
- set_upvalues(f, upvalues)
- end
-
- return unpack(ret, 1, table.maxn(ret))
+ return require('test.functional.testnvim.exec_lua')(session, 2, code, ...)
end
function M.get_pathsep()
diff --git a/test/functional/testnvim/exec_lua.lua b/test/functional/testnvim/exec_lua.lua
new file mode 100644
index 0000000000..ddd9905ce7
--- /dev/null
+++ b/test/functional/testnvim/exec_lua.lua
@@ -0,0 +1,148 @@
+--- @param f function
+--- @return table<string,any>
+local function get_upvalues(f)
+ local i = 1
+ local upvalues = {} --- @type table<string,any>
+ while true do
+ local n, v = debug.getupvalue(f, i)
+ if not n then
+ break
+ end
+ upvalues[n] = v
+ i = i + 1
+ end
+ return upvalues
+end
+
+--- @param f function
+--- @param upvalues table<string,any>
+local function set_upvalues(f, upvalues)
+ local i = 1
+ while true do
+ local n = debug.getupvalue(f, i)
+ if not n then
+ break
+ end
+ if upvalues[n] then
+ debug.setupvalue(f, i, upvalues[n])
+ end
+ i = i + 1
+ end
+end
+
+--- @param messages string[]
+--- @param ... ...
+local function add_print(messages, ...)
+ local msg = {} --- @type string[]
+ for i = 1, select('#', ...) do
+ msg[#msg + 1] = tostring(select(i, ...))
+ end
+ table.insert(messages, table.concat(msg, '\t'))
+end
+
+local invalid_types = {
+ ['thread'] = true,
+ ['function'] = true,
+ ['userdata'] = true,
+}
+
+--- @param r any[]
+local function check_returns(r)
+ for k, v in pairs(r) do
+ if invalid_types[type(v)] then
+ error(
+ string.format(
+ "Return index %d with value '%s' of type '%s' cannot be serialized over RPC",
+ k,
+ tostring(v),
+ type(v)
+ ),
+ 2
+ )
+ end
+ end
+end
+
+local M = {}
+
+--- This is run in the context of the remote Nvim instance.
+--- @param bytecode string
+--- @param upvalues table<string,any>
+--- @param ... any[]
+--- @return any[] result
+--- @return table<string,any> upvalues
+--- @return string[] messages
+function M.handler(bytecode, upvalues, ...)
+ local messages = {} --- @type string[]
+ local orig_print = _G.print
+
+ function _G.print(...)
+ add_print(messages, ...)
+ return orig_print(...)
+ end
+
+ local f = assert(loadstring(bytecode))
+
+ set_upvalues(f, upvalues)
+
+ -- Run in pcall so we can return any print messages
+ local ret = { pcall(f, ...) } --- @type any[]
+
+ _G.print = orig_print
+
+ local new_upvalues = get_upvalues(f)
+
+ -- Check return value types for better error messages
+ check_returns(ret)
+
+ return ret, new_upvalues, messages
+end
+
+--- @param session test.Session
+--- @param lvl integer
+--- @param code function
+--- @param ... ...
+local function run(session, lvl, code, ...)
+ local stat, rv = session:request(
+ 'nvim_exec_lua',
+ [[return { require('test.functional.testnvim.exec_lua').handler(...) }]],
+ { string.dump(code), get_upvalues(code), ... }
+ )
+
+ if not stat then
+ error(rv[2], 2)
+ end
+
+ --- @type any[], table<string,any>, string[]
+ local ret, upvalues, messages = unpack(rv)
+
+ for _, m in ipairs(messages) do
+ print(m)
+ end
+
+ if not ret[1] then
+ error(ret[2], 2)
+ end
+
+ -- Update upvalues
+ if next(upvalues) then
+ local caller = debug.getinfo(lvl)
+ local i = 0
+
+ -- On PUC-Lua, if the function is a tail call, then func will be nil.
+ -- In this case we need to use the caller.
+ while not caller.func do
+ i = i + 1
+ caller = debug.getinfo(lvl + i)
+ end
+ set_upvalues(caller.func, upvalues)
+ end
+
+ return unpack(ret, 2, table.maxn(ret))
+end
+
+return setmetatable(M, {
+ __call = function(_, ...)
+ return run(...)
+ end,
+})