aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/shared.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/shared.lua')
-rw-r--r--runtime/lua/vim/shared.lua111
1 files changed, 81 insertions, 30 deletions
diff --git a/runtime/lua/vim/shared.lua b/runtime/lua/vim/shared.lua
index 6e40b6ca52..f0dc34608c 100644
--- a/runtime/lua/vim/shared.lua
+++ b/runtime/lua/vim/shared.lua
@@ -1,8 +1,10 @@
-- Functions shared by Nvim and its test-suite.
--
--- The singular purpose of this module is to share code with the Nvim
--- test-suite. If, in the future, Nvim itself is used to run the test-suite
--- instead of "vanilla Lua", these functions could move to src/nvim/lua/vim.lua
+-- These are "pure" lua functions not depending of the state of the editor.
+-- Thus they should always be available whenever nvim-related lua code is run,
+-- regardless if it is code in the editor itself, or in worker threads/processes,
+-- or the test suite. (Eventually the test suite will be run in a worker process,
+-- so this wouldn't be a separate case to consider)
local vim = vim or {}
@@ -12,7 +14,7 @@ local vim = vim or {}
--- same functions as those in the input table. Userdata and threads are not
--- copied and will throw an error.
---
----@param orig Table to copy
+---@param orig table Table to copy
---@returns New table of copied keys and (nested) values.
function vim.deepcopy(orig) end -- luacheck: no unused
vim.deepcopy = (function()
@@ -21,17 +23,16 @@ vim.deepcopy = (function()
end
local deepcopy_funcs = {
- table = function(orig)
+ table = function(orig, cache)
+ if cache[orig] then return cache[orig] end
local copy = {}
- if vim._empty_dict_mt ~= nil and getmetatable(orig) == vim._empty_dict_mt then
- copy = vim.empty_dict()
- end
-
+ cache[orig] = copy
+ local mt = getmetatable(orig)
for k, v in pairs(orig) do
- copy[vim.deepcopy(k)] = vim.deepcopy(v)
+ copy[vim.deepcopy(k, cache)] = vim.deepcopy(v, cache)
end
- return copy
+ return setmetatable(copy, mt)
end,
number = _id,
string = _id,
@@ -40,10 +41,10 @@ vim.deepcopy = (function()
['function'] = _id,
}
- return function(orig)
+ return function(orig, cache)
local f = deepcopy_funcs[type(orig)]
if f then
- return f(orig)
+ return f(orig, cache or {})
else
error("Cannot deepcopy object of type "..type(orig))
end
@@ -330,7 +331,7 @@ end
--- Add the reverse lookup values to an existing table.
--- For example:
---- `tbl_add_reverse_lookup { A = 1 } == { [1] = 'A', A = 1 }`
+--- ``tbl_add_reverse_lookup { A = 1 } == { [1] = 'A', A = 1 }``
--
--Do note that it *modifies* the input.
---@param o table The table to add the reverse to.
@@ -346,6 +347,33 @@ function vim.tbl_add_reverse_lookup(o)
return o
end
+--- Index into a table (first argument) via string keys passed as subsequent arguments.
+--- Return `nil` if the key does not exist.
+--_
+--- Examples:
+--- <pre>
+--- vim.tbl_get({ key = { nested_key = true }}, 'key', 'nested_key') == true
+--- vim.tbl_get({ key = {}}, 'key', 'nested_key') == nil
+--- </pre>
+---
+---@param o Table to index
+---@param ... Optional strings (0 or more, variadic) via which to index the table
+---
+---@returns nested value indexed by key if it exists, else nil
+function vim.tbl_get(o, ...)
+ local keys = {...}
+ if #keys == 0 then
+ return
+ end
+ for _, k in ipairs(keys) do
+ o = o[k]
+ if o == nil then
+ return
+ end
+ end
+ return o
+end
+
--- Extends a list-like table with the values of another list-like table.
---
--- NOTE: This mutates dst!
@@ -527,13 +555,23 @@ end
--- => error('arg1: expected even number, got 3')
--- </pre>
---
----@param opt Map of parameter names to validations. Each key is a parameter
+--- If multiple types are valid they can be given as a list.
+--- <pre>
+--- vim.validate{arg1={{'foo'}, {'table', 'string'}}, arg2={'foo', {'table', 'string'}}}
+--- => NOP (success)
+---
+--- vim.validate{arg1={1, {'string', table'}}}
+--- => error('arg1: expected string|table, got number')
+---
+--- </pre>
+---
+---@param opt table of parameter names to validations. Each key is a parameter
--- name; each value is a tuple in one of these forms:
--- 1. (arg_value, type_name, optional)
--- - arg_value: argument value
---- - type_name: string type name, one of: ("table", "t", "string",
+--- - type_name: string|table type name, one of: ("table", "t", "string",
--- "s", "number", "n", "boolean", "b", "function", "f", "nil",
---- "thread", "userdata")
+--- "thread", "userdata") or list of them.
--- - optional: (optional) boolean, if true, `nil` is valid
--- 2. (arg_value, fn, msg)
--- - arg_value: argument value
@@ -560,6 +598,7 @@ do
return type(val) == t or (t == 'callable' and vim.is_callable(val))
end
+ ---@private
local function is_valid(opt)
if type(opt) ~= 'table' then
return false, string.format('opt: expected table, got %s', type(opt))
@@ -571,31 +610,43 @@ do
end
local val = spec[1] -- Argument value.
- local t = spec[2] -- Type name, or callable.
+ local types = spec[2] -- Type name, or callable.
local optional = (true == spec[3])
- if type(t) == 'string' then
- local t_name = type_names[t]
- if not t_name then
- return false, string.format('invalid type name: %s', t)
- end
+ if type(types) == 'string' then
+ types = {types}
+ end
- if (not optional or val ~= nil) and not _is_type(val, t_name) then
- return false, string.format("%s: expected %s, got %s", param_name, t_name, type(val))
- end
- elseif vim.is_callable(t) then
+ if vim.is_callable(types) then
-- Check user-provided validation function.
- local valid, optional_message = t(val)
+ local valid, optional_message = types(val)
if not valid then
- local error_message = string.format("%s: expected %s, got %s", param_name, (spec[3] or '?'), val)
+ local error_message = string.format("%s: expected %s, got %s", param_name, (spec[3] or '?'), tostring(val))
if optional_message ~= nil then
error_message = error_message .. string.format(". Info: %s", optional_message)
end
return false, error_message
end
+ elseif type(types) == 'table' then
+ local success = false
+ for i, t in ipairs(types) do
+ local t_name = type_names[t]
+ if not t_name then
+ return false, string.format('invalid type name: %s', t)
+ end
+ types[i] = t_name
+
+ if (optional and val == nil) or _is_type(val, t_name) then
+ success = true
+ break
+ end
+ end
+ if not success then
+ return false, string.format("%s: expected %s, got %s", param_name, table.concat(types, '|'), type(val))
+ end
else
- return false, string.format("invalid type name: %s", tostring(t))
+ return false, string.format("invalid type name: %s", tostring(types))
end
end