diff options
-rw-r--r-- | runtime/lua/vim/shared.lua | 75 | ||||
-rw-r--r-- | test/functional/lua/vim_spec.lua | 47 | ||||
-rw-r--r-- | test/functional/plugin/lsp_spec.lua | 5 | ||||
-rw-r--r-- | test/helpers.lua | 3 |
4 files changed, 78 insertions, 52 deletions
diff --git a/runtime/lua/vim/shared.lua b/runtime/lua/vim/shared.lua index 5c89c63f7b..750af02f19 100644 --- a/runtime/lua/vim/shared.lua +++ b/runtime/lua/vim/shared.lua @@ -477,48 +477,77 @@ end --- 2. (arg_value, fn, msg) --- - arg_value: argument value --- - fn: any function accepting one argument, returns true if and ---- only if the argument is valid +--- only if the argument is valid. Can optionally return an additional +--- informative error message as the second returned value. --- - msg: (optional) error string if validation fails function vim.validate(opt) end -- luacheck: no unused -vim.validate = (function() + +do local type_names = { - t='table', s='string', n='number', b='boolean', f='function', c='callable', - ['table']='table', ['string']='string', ['number']='number', - ['boolean']='boolean', ['function']='function', ['callable']='callable', - ['nil']='nil', ['thread']='thread', ['userdata']='userdata', + ['table'] = 'table', t = 'table', + ['string'] = 'string', s = 'string', + ['number'] = 'number', n = 'number', + ['boolean'] = 'boolean', b = 'boolean', + ['function'] = 'function', f = 'function', + ['callable'] = 'callable', c = 'callable', + ['nil'] = 'nil', + ['thread'] = 'thread', + ['userdata'] = 'userdata', } - local function _type_name(t) - local tname = type_names[t] - if tname == nil then - error(string.format('invalid type name: %s', tostring(t))) - end - return tname - end + local function _is_type(val, t) return t == 'callable' and vim.is_callable(val) or type(val) == t end - return function(opt) - assert(type(opt) == 'table', string.format('opt: expected table, got %s', type(opt))) + local function is_valid(opt) + if type(opt) ~= 'table' then + return false, string.format('opt: expected table, got %s', type(opt)) + end + for param_name, spec in pairs(opt) do - assert(type(spec) == 'table', string.format('%s: expected table, got %s', param_name, type(spec))) + if type(spec) ~= 'table' then + return false, string.format('opt[%s]: expected table, got %s', param_name, type(spec)) + end local val = spec[1] -- Argument value. local t = spec[2] -- Type name, or callable. local optional = (true == spec[3]) - if not vim.is_callable(t) then -- Check type name. - if (not optional or val ~= nil) and not _is_type(val, _type_name(t)) then - error(string.format("%s: expected %s, got %s", param_name, _type_name(t), type(val))) + if type(t) == 'string' then + local translated_type_name = type_names[t] + if not translated_type_name then + return false, string.format('invalid type name: %s', t) + end + + if (not optional or val ~= nil) and not _is_type(val, translated_type_name) then + return false, string.format("%s: expected %s, got %s", param_name, translated_type_name, type(val)) + end + elseif vim.is_callable(t) then + -- Check user-provided validation function. + local valid, optional_message = t(val) + if not valid then + local error_message = string.format("%s: expected %s, got %s", param_name, (spec[3] or '?'), val) + if not (optional_message == nil) then + error_message = error_message .. string.format(". Info: %s", optional_message) + end + + return false, error_message end - elseif not t(val) then -- Check user-provided validation function. - error(string.format("%s: expected %s, got %s", param_name, (spec[3] or '?'), val)) + else + return false, string.format("invalid type name: %s", tostring(t)) end end - return true + + return true, nil end -end)() + function vim.validate(opt) + local ok, err_msg = is_valid(opt) + if not ok then + error(debug.traceback(err_msg, 2), 2) + end + end +end --- Returns true if object `f` can be called as a function. --- --@param f Any object diff --git a/test/functional/lua/vim_spec.lua b/test/functional/lua/vim_spec.lua index 61447f1152..bbd999ead2 100644 --- a/test/functional/lua/vim_spec.lua +++ b/test/functional/lua/vim_spec.lua @@ -13,6 +13,7 @@ local feed = helpers.feed local pcall_err = helpers.pcall_err local exec_lua = helpers.exec_lua local matches = helpers.matches +local contains = helpers.contains local source = helpers.source local NIL = helpers.NIL local retry = helpers.retry @@ -262,12 +263,9 @@ describe('lua stdlib', function() -- Validates args. eq(true, pcall(split, 'string', 'string')) - eq('Error executing lua: .../shared.lua: s: expected string, got number', - pcall_err(split, 1, 'string')) - eq('Error executing lua: .../shared.lua: sep: expected string, got number', - pcall_err(split, 'string', 1)) - eq('Error executing lua: .../shared.lua: plain: expected boolean, got number', - pcall_err(split, 'string', 'string', 1)) + contains('s: expected string, got number', pcall_err(split, 1, 'string')) + contains('sep: expected string, got number', pcall_err(split, 'string', 1)) + contains('plain: expected boolean, got number', pcall_err(split, 'string', 'string', 1)) end) it('vim.trim', function() @@ -287,8 +285,7 @@ describe('lua stdlib', function() end -- Validates args. - eq('Error executing lua: .../shared.lua: s: expected string, got number', - pcall_err(trim, 2)) + contains('s: expected string, got number', pcall_err(trim, 2)) end) it('vim.inspect', function() @@ -366,7 +363,7 @@ describe('lua stdlib', function() eq('foo%%%-bar', exec_lua([[return vim.pesc(vim.pesc('foo-bar'))]])) -- Validates args. - eq('Error executing lua: .../shared.lua: s: expected string, got number', + contains('s: expected string, got number', pcall_err(exec_lua, [[return vim.pesc(2)]])) end) @@ -624,14 +621,14 @@ describe('lua stdlib', function() it('vim.list_extend', function() eq({1,2,3}, exec_lua [[ return vim.list_extend({1}, {2,3}) ]]) - eq('Error executing lua: .../shared.lua: src: expected table, got nil', - pcall_err(exec_lua, [[ return vim.list_extend({1}, nil) ]])) eq({1,2}, exec_lua [[ return vim.list_extend({1}, {2;a=1}) ]]) eq(true, exec_lua [[ local a = {1} return vim.list_extend(a, {2;a=1}) == a ]]) eq({2}, exec_lua [[ return vim.list_extend({}, {2;a=1}, 1) ]]) eq({}, exec_lua [[ return vim.list_extend({}, {2;a=1}, 2) ]]) eq({}, exec_lua [[ return vim.list_extend({}, {2;a=1}, 1, -1) ]]) eq({2}, exec_lua [[ return vim.list_extend({}, {2;a=1}, -1, 2) ]]) + + contains('src: expected table, got nil', pcall_err(exec_lua, [[ return vim.list_extend({1}, nil) ]])) end) it('vim.tbl_add_reverse_lookup', function() @@ -820,33 +817,29 @@ describe('lua stdlib', function() exec_lua("vim.validate{arg1={{}, 't' }, arg2={ 'foo', 's' }}") exec_lua("vim.validate{arg1={2, function(a) return (a % 2) == 0 end, 'even number' }}") - eq("Error executing lua: .../shared.lua: 1: expected table, got number", - pcall_err(exec_lua, "vim.validate{ 1, 'x' }")) - eq("Error executing lua: .../shared.lua: invalid type name: x", - pcall_err(exec_lua, "vim.validate{ arg1={ 1, 'x' }}")) - eq("Error executing lua: .../shared.lua: invalid type name: 1", - pcall_err(exec_lua, "vim.validate{ arg1={ 1, 1 }}")) - eq("Error executing lua: .../shared.lua: invalid type name: nil", - pcall_err(exec_lua, "vim.validate{ arg1={ 1 }}")) + contains("expected table, got number", pcall_err(exec_lua, "vim.validate{ 1, 'x' }")) + contains("invalid type name: x", pcall_err(exec_lua, "vim.validate{ arg1={ 1, 'x' }}")) + contains("invalid type name: 1", pcall_err(exec_lua, "vim.validate{ arg1={ 1, 1 }}")) + contains("invalid type name: nil", pcall_err(exec_lua, "vim.validate{ arg1={ 1 }}")) -- Validated parameters are required by default. - eq("Error executing lua: .../shared.lua: arg1: expected string, got nil", + contains("arg1: expected string, got nil", pcall_err(exec_lua, "vim.validate{ arg1={ nil, 's' }}")) -- Explicitly required. - eq("Error executing lua: .../shared.lua: arg1: expected string, got nil", + contains("arg1: expected string, got nil", pcall_err(exec_lua, "vim.validate{ arg1={ nil, 's', false }}")) - eq("Error executing lua: .../shared.lua: arg1: expected table, got number", + contains("arg1: expected table, got number", pcall_err(exec_lua, "vim.validate{arg1={1, 't'}}")) - eq("Error executing lua: .../shared.lua: arg2: expected string, got number", + contains("arg2: expected string, got number", pcall_err(exec_lua, "vim.validate{arg1={{}, 't'}, arg2={1, 's'}}")) - eq("Error executing lua: .../shared.lua: arg2: expected string, got nil", + contains("arg2: expected string, got nil", pcall_err(exec_lua, "vim.validate{arg1={{}, 't'}, arg2={nil, 's'}}")) - eq("Error executing lua: .../shared.lua: arg2: expected string, got nil", + contains("arg2: expected string, got nil", pcall_err(exec_lua, "vim.validate{arg1={{}, 't'}, arg2={nil, 's'}}")) - eq("Error executing lua: .../shared.lua: arg1: expected even number, got 3", + contains("arg1: expected even number, got 3", pcall_err(exec_lua, "vim.validate{arg1={3, function(a) return a == 1 end, 'even number'}}")) - eq("Error executing lua: .../shared.lua: arg1: expected ?, got 3", + contains("arg1: expected %?, got 3", pcall_err(exec_lua, "vim.validate{arg1={3, function(a) return a == 1 end}}")) end) diff --git a/test/functional/plugin/lsp_spec.lua b/test/functional/plugin/lsp_spec.lua index f514f4ea6f..3c4c448e7a 100644 --- a/test/functional/plugin/lsp_spec.lua +++ b/test/functional/plugin/lsp_spec.lua @@ -6,6 +6,7 @@ local buf_lines = helpers.buf_lines local dedent = helpers.dedent local exec_lua = helpers.exec_lua local eq = helpers.eq +local contains = helpers.contains local pcall_err = helpers.pcall_err local pesc = helpers.pesc local insert = helpers.insert @@ -747,8 +748,8 @@ describe('LSP', function() end) it('should invalid cmd argument', function() - eq('Error executing lua: .../shared.lua: cmd: expected list, got nvim', pcall_err(_cmd_parts, "nvim")) - eq('Error executing lua: .../shared.lua: cmd argument: expected string, got number', pcall_err(_cmd_parts, {"nvim", 1})) + contains('cmd: expected list, got nvim', pcall_err(_cmd_parts, "nvim")) + contains('cmd argument: expected string, got number', pcall_err(_cmd_parts, {"nvim", 1})) end) end) end) diff --git a/test/helpers.lua b/test/helpers.lua index 5acd2ea0bd..d59ce2a7c7 100644 --- a/test/helpers.lua +++ b/test/helpers.lua @@ -99,6 +99,9 @@ function module.matches(pat, actual) end error(string.format('Pattern does not match.\nPattern:\n%s\nActual:\n%s', pat, actual)) end +function module.contains(pat, actual) + return module.matches(".*" .. pat .. ".*", actual) +end --- Asserts that `pat` matches one or more lines in the tail of $NVIM_LOG_FILE. --- |