1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
|
--- Snapshot the upvalues of `f` as a name -> value map.
--- Upvalues that currently hold nil do not appear in the result, since a
--- table cannot store a nil value.
--- @param f function
--- @return table<string,any>
local function get_upvalues(f)
  local snapshot = {} --- @type table<string,any>
  local idx = 1
  local name, value = debug.getupvalue(f, idx)
  while name do
    snapshot[name] = value
    idx = idx + 1
    name, value = debug.getupvalue(f, idx)
  end
  return snapshot
end
--- Overwrite upvalues of `f` with values from `upvalues`, matched by name.
--- Upvalues of `f` with no entry in `upvalues` are left untouched. An entry
--- whose value is `false` is still applied; only absent (nil) entries are
--- skipped, because nil cannot be stored in the map.
--- @param f function
--- @param upvalues table<string,any>
local function set_upvalues(f, upvalues)
  local i = 1
  while true do
    local n = debug.getupvalue(f, i)
    if not n then
      break
    end
    local v = upvalues[n]
    -- Test against nil rather than truthiness so that an upvalue set to
    -- `false` is propagated instead of silently ignored.
    if v ~= nil then
      debug.setupvalue(f, i, v)
    end
    i = i + 1
  end
end
--- Render a print()-style argument list as one line and append it to
--- `messages`. Every argument (including nil) is passed through tostring()
--- and the pieces are joined with tab characters.
--- @param messages string[]
--- @param ... ...
local function add_print(messages, ...)
  local argc = select('#', ...)
  local args = { ... }
  local parts = {} --- @type string[]
  for idx = 1, argc do
    parts[idx] = tostring(args[idx])
  end
  messages[#messages + 1] = table.concat(parts, '\t')
end
-- Lua value types that cannot be serialized over RPC.
local invalid_types = {
  ['thread'] = true,
  ['function'] = true,
  ['userdata'] = true,
}

--- Whether `v` is Neovim's `vim.NIL` sentinel. `vim.NIL` is a userdata, but it
--- is the documented stand-in for NIL over RPC, so it must not be rejected.
--- Looks `vim` up with rawget so this also works where no `vim` global exists.
--- @param v any
--- @return boolean
local function is_vim_nil(v)
  local vim_tbl = rawget(_G, 'vim')
  return type(vim_tbl) == 'table' and rawequal(v, vim_tbl.NIL)
end

--- Raise an error naming the first value in `r` whose type cannot be sent
--- back over RPC, so the failure is reported clearly instead of as an opaque
--- conversion error.
--- @param r any[] values to check (M.handler passes the packed pcall() results)
local function check_returns(r)
  for k, v in pairs(r) do
    if invalid_types[type(v)] and not is_vim_nil(v) then
      error(
        string.format(
          "Return index %d with value '%s' of type '%s' cannot be serialized over RPC",
          k,
          tostring(v),
          type(v)
        ),
        2
      )
    end
  end
end
local M = {}

--- This is run in the context of the remote Nvim instance.
--- Loads `bytecode`, seeds its upvalues from `upvalues`, and calls it with
--- `...` while capturing every print() call made during the call.
--- @param bytecode string
--- @param upvalues table<string,any>
--- @param ... any[]
--- @return any[] result packed pcall() results: status, then returns or error
--- @return table<string,any> upvalues upvalues of the function after it ran
--- @return string[] messages one tab-joined line per print() call
function M.handler(bytecode, upvalues, ...)
  -- Load before replacing _G.print: if loading fails, assert() raises and the
  -- global print must not be left hijacked.
  local f = assert((loadstring or load)(bytecode))
  set_upvalues(f, upvalues)

  local messages = {} --- @type string[]
  local orig_print = _G.print
  function _G.print(...)
    add_print(messages, ...)
    return orig_print(...)
  end
  -- Run in pcall so print is always restored and captured messages are
  -- returned even if f raises.
  local ret = { pcall(f, ...) } --- @type any[]
  _G.print = orig_print

  local new_upvalues = get_upvalues(f)
  -- Check return value types for better error messages
  check_returns(ret)
  return ret, new_upvalues, messages
end
--- Execute `code` in the Nvim instance behind `session` and return its results.
--- Lines printed remotely are re-printed locally, a remote error is re-raised
--- at the caller, and upvalues modified remotely are written back locally.
--- @param session test.Session
--- @param lvl integer caller stack level; unused, kept for API compatibility
--- @param code function
--- @param ... ...
local function run(session, lvl, code, ...)
  local _ = lvl -- kept so existing callers' argument order still works

  local stat, rv = session:request(
    'nvim_exec_lua',
    [[return { require('test.functional.testnvim.exec_lua').handler(...) }]],
    { string.dump(code), get_upvalues(code), ... }
  )
  if not stat then
    error(rv[2], 2)
  end

  --- @type any[], table<string,any>, string[]
  local ret, upvalues, messages = unpack(rv)
  for _, m in ipairs(messages) do
    print(m)
  end

  if not ret[1] then
    error(ret[2], 2)
  end

  -- Write modified upvalues back through `code` itself. Its upvalues are the
  -- very variables it captured from the caller (locals still on the caller's
  -- stack, or the caller's own upvalues), so debug.setupvalue updates them in
  -- place. Walking the stack to the caller's function instead could only
  -- reach the caller's upvalues, missed the caller's plain locals, and could
  -- hit a same-named but different (shadowed) variable.
  if next(upvalues) then
    set_upvalues(code, upvalues)
  end

  return unpack(ret, 2, table.maxn(ret))
end
-- Make the module table callable: M(session, lvl, code, ...) forwards to run.
-- The forward is a tail call, so the stack levels seen inside run are the
-- same as with a direct `return run(...)` forward.
local call_mt = {}
function call_mt.__call(_self, session, lvl, code, ...)
  return run(session, lvl, code, ...)
end
return setmetatable(M, call_mt)
|