diff options
Diffstat (limited to 'runtime/lua/vim/shared.lua')
-rw-r--r-- | runtime/lua/vim/shared.lua | 111 |
1 files changed, 81 insertions, 30 deletions
diff --git a/runtime/lua/vim/shared.lua b/runtime/lua/vim/shared.lua index 6e40b6ca52..f0dc34608c 100644 --- a/runtime/lua/vim/shared.lua +++ b/runtime/lua/vim/shared.lua @@ -1,8 +1,10 @@ -- Functions shared by Nvim and its test-suite. -- --- The singular purpose of this module is to share code with the Nvim --- test-suite. If, in the future, Nvim itself is used to run the test-suite --- instead of "vanilla Lua", these functions could move to src/nvim/lua/vim.lua +-- These are "pure" lua functions not depending of the state of the editor. +-- Thus they should always be available whenever nvim-related lua code is run, +-- regardless if it is code in the editor itself, or in worker threads/processes, +-- or the test suite. (Eventually the test suite will be run in a worker process, +-- so this wouldn't be a separate case to consider) local vim = vim or {} @@ -12,7 +14,7 @@ local vim = vim or {} --- same functions as those in the input table. Userdata and threads are not --- copied and will throw an error. --- ----@param orig Table to copy +---@param orig table Table to copy ---@returns New table of copied keys and (nested) values. function vim.deepcopy(orig) end -- luacheck: no unused vim.deepcopy = (function() @@ -21,17 +23,16 @@ vim.deepcopy = (function() end local deepcopy_funcs = { - table = function(orig) + table = function(orig, cache) + if cache[orig] then return cache[orig] end local copy = {} - if vim._empty_dict_mt ~= nil and getmetatable(orig) == vim._empty_dict_mt then - copy = vim.empty_dict() - end - + cache[orig] = copy + local mt = getmetatable(orig) for k, v in pairs(orig) do - copy[vim.deepcopy(k)] = vim.deepcopy(v) + copy[vim.deepcopy(k, cache)] = vim.deepcopy(v, cache) end - return copy + return setmetatable(copy, mt) end, number = _id, string = _id, @@ -40,10 +41,10 @@ vim.deepcopy = (function() ['function'] = _id, } - return function(orig) + return function(orig, cache) local f = deepcopy_funcs[type(orig)] if f then - return f(orig) + return f(orig, cache or {}) else error("Cannot deepcopy object of type "..type(orig)) end @@ -330,7 +331,7 @@ end --- Add the reverse lookup values to an existing table. --- For example: ---- `tbl_add_reverse_lookup { A = 1 } == { [1] = 'A', A = 1 }` +--- ``tbl_add_reverse_lookup { A = 1 } == { [1] = 'A', A = 1 }`` -- --Do note that it *modifies* the input. ---@param o table The table to add the reverse to. @@ -346,6 +347,33 @@ function vim.tbl_add_reverse_lookup(o) return o end +--- Index into a table (first argument) via string keys passed as subsequent arguments. +--- Return `nil` if the key does not exist. +--_ +--- Examples: +--- <pre> +--- vim.tbl_get({ key = { nested_key = true }}, 'key', 'nested_key') == true +--- vim.tbl_get({ key = {}}, 'key', 'nested_key') == nil +--- </pre> +--- +---@param o Table to index +---@param ... Optional strings (0 or more, variadic) via which to index the table +--- +---@returns nested value indexed by key if it exists, else nil +function vim.tbl_get(o, ...) + local keys = {...} + if #keys == 0 then + return + end + for _, k in ipairs(keys) do + o = o[k] + if o == nil then + return + end + end + return o +end + --- Extends a list-like table with the values of another list-like table. --- --- NOTE: This mutates dst! @@ -527,13 +555,23 @@ end --- => error('arg1: expected even number, got 3') --- </pre> --- ----@param opt Map of parameter names to validations. Each key is a parameter +--- If multiple types are valid they can be given as a list. +--- <pre> +--- vim.validate{arg1={{'foo'}, {'table', 'string'}}, arg2={'foo', {'table', 'string'}}} +--- => NOP (success) +--- +--- vim.validate{arg1={1, {'string', table'}}} +--- => error('arg1: expected string|table, got number') +--- +--- </pre> +--- +---@param opt table of parameter names to validations. Each key is a parameter --- name; each value is a tuple in one of these forms: --- 1. (arg_value, type_name, optional) --- - arg_value: argument value ---- - type_name: string type name, one of: ("table", "t", "string", +--- - type_name: string|table type name, one of: ("table", "t", "string", --- "s", "number", "n", "boolean", "b", "function", "f", "nil", ---- "thread", "userdata") +--- "thread", "userdata") or list of them. --- - optional: (optional) boolean, if true, `nil` is valid --- 2. (arg_value, fn, msg) --- - arg_value: argument value @@ -560,6 +598,7 @@ do return type(val) == t or (t == 'callable' and vim.is_callable(val)) end + ---@private local function is_valid(opt) if type(opt) ~= 'table' then return false, string.format('opt: expected table, got %s', type(opt)) @@ -571,31 +610,43 @@ do end local val = spec[1] -- Argument value. - local t = spec[2] -- Type name, or callable. + local types = spec[2] -- Type name, or callable. local optional = (true == spec[3]) - if type(t) == 'string' then - local t_name = type_names[t] - if not t_name then - return false, string.format('invalid type name: %s', t) - end + if type(types) == 'string' then + types = {types} + end - if (not optional or val ~= nil) and not _is_type(val, t_name) then - return false, string.format("%s: expected %s, got %s", param_name, t_name, type(val)) - end - elseif vim.is_callable(t) then + if vim.is_callable(types) then -- Check user-provided validation function. - local valid, optional_message = t(val) + local valid, optional_message = types(val) if not valid then - local error_message = string.format("%s: expected %s, got %s", param_name, (spec[3] or '?'), val) + local error_message = string.format("%s: expected %s, got %s", param_name, (spec[3] or '?'), tostring(val)) if optional_message ~= nil then error_message = error_message .. string.format(". Info: %s", optional_message) end return false, error_message end + elseif type(types) == 'table' then + local success = false + for i, t in ipairs(types) do + local t_name = type_names[t] + if not t_name then + return false, string.format('invalid type name: %s', t) + end + types[i] = t_name + + if (optional and val == nil) or _is_type(val, t_name) then + success = true + break + end + end + if not success then + return false, string.format("%s: expected %s, got %s", param_name, table.concat(types, '|'), type(val)) + end else - return false, string.format("invalid type name: %s", tostring(t)) + return false, string.format("invalid type name: %s", tostring(types)) end end |