diff options
Diffstat (limited to 'runtime/lua/vim/iter.lua')
-rw-r--r-- | runtime/lua/vim/iter.lua | 274 |
1 files changed, 204 insertions, 70 deletions
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index 874bdfb437..a37b7f7858 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -1,7 +1,7 @@ ----@defgroup vim.iter +--- @brief --- ---- \*vim.iter()\* is an interface for |iterable|s: it wraps a table or function argument into an ---- \*Iter\* object with methods (such as |Iter:filter()| and |Iter:map()|) that transform the +--- [vim.iter()]() is an interface for [iterable]s: it wraps a table or function argument into an +--- [Iter]() object with methods (such as [Iter:filter()] and [Iter:map()]) that transform the --- underlying source data. These methods can be chained to create iterator "pipelines": the output --- of each pipeline stage is input to the next stage. The first stage depends on the type passed to --- `vim.iter()`: @@ -64,10 +64,16 @@ --- In addition to the |vim.iter()| function, the |vim.iter| module provides --- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|. +--- LuaLS is bad at generics which this module mostly deals with +--- @diagnostic disable:no-unknown + +---@nodoc ---@class IterMod ---@operator call:Iter + local M = {} +---@nodoc ---@class Iter local Iter = {} Iter.__index = Iter @@ -76,6 +82,7 @@ Iter.__call = function(self) end --- Special case implementations for iterators on list tables. +---@nodoc ---@class ListIter : Iter ---@field _table table Underlying table data ---@field _head number Index to the front of a table iterator @@ -112,6 +119,35 @@ local function sanitize(t) return t end +--- Flattens a single list-like table. Errors if it attempts to flatten a +--- dict-like table +---@param v table table which should be flattened +---@param max_depth number depth to which the table should be flattened +---@param depth number current iteration depth +---@param result table output table that contains flattened result +---@return table|nil flattened table if it can be flattened, otherwise nil +local function flatten(v, max_depth, depth, result) + if depth < max_depth and type(v) == 'table' then + local i = 0 + for _ in pairs(v) do + i = i + 1 + + if v[i] == nil then + -- short-circuit: this is not a list like table + return nil + end + + if flatten(v[i], max_depth, depth + 1, result) == nil then + return nil + end + end + else + result[#result + 1] = v + end + + return result +end + --- Determine if the current iterator stage should continue. --- --- If any arguments are passed to this function, then return those arguments @@ -152,11 +188,11 @@ end --- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded) --- ``` --- ----@param f function(...):bool Takes all values returned from the previous stage ---- in the pipeline and returns false or nil if the ---- current iterator element should be removed. +---@param f fun(...):boolean Takes all values returned from the previous stage +--- in the pipeline and returns false or nil if the +--- current iterator element should be removed. ---@return Iter -function Iter.filter(self, f) +function Iter:filter(f) return self:map(function(...) if f(...) then return ... @@ -165,7 +201,7 @@ function Iter.filter(self, f) end ---@private -function ListIter.filter(self, f) +function ListIter:filter(f) local inc = self._head < self._tail and 1 or -1 local n = self._head for i = self._head, self._tail - inc, inc do @@ -179,6 +215,55 @@ function ListIter.filter(self, f) return self end +--- Flattens a |list-iterator|, un-nesting nested values up to the given {depth}. +--- Errors if it attempts to flatten a dict-like value. +--- +--- Examples: +--- +--- ```lua +--- vim.iter({ 1, { 2 }, { { 3 } } }):flatten():totable() +--- -- { 1, 2, { 3 } } +--- +--- vim.iter({1, { { a = 2 } }, { 3 } }):flatten():totable() +--- -- { 1, { a = 2 }, 3 } +--- +--- vim.iter({ 1, { { a = 2 } }, { 3 } }):flatten(math.huge):totable() +--- -- error: attempt to flatten a dict-like table +--- ``` +--- +---@param depth? number Depth to which |list-iterator| should be flattened +--- (defaults to 1) +---@return Iter +---@diagnostic disable-next-line:unused-local +function Iter:flatten(depth) -- luacheck: no unused args + error('flatten() requires a list-like table') +end + +---@private +function ListIter:flatten(depth) + depth = depth or 1 + local inc = self._head < self._tail and 1 or -1 + local target = {} + + for i = self._head, self._tail - inc, inc do + local flattened = flatten(self._table[i], depth, 0, {}) + + -- exit early if we try to flatten a dict-like table + if flattened == nil then + error('flatten() requires a list-like table') + end + + for _, v in pairs(flattened) do + target[#target + 1] = v + end + end + + self._head = 1 + self._tail = #target + 1 + self._table = target + return self +end + --- Maps the items of an iterator pipeline to the values returned by `f`. --- --- If the map function returns nil, the value is filtered from the iterator. @@ -195,13 +280,13 @@ end --- -- { 6, 12 } --- ``` --- ----@param f function(...):any Mapping function. Takes all values returned from ---- the previous stage in the pipeline as arguments ---- and returns one or more new values, which are used ---- in the next pipeline stage. Nil return values ---- are filtered from the output. +---@param f fun(...):any Mapping function. Takes all values returned from +--- the previous stage in the pipeline as arguments +--- and returns one or more new values, which are used +--- in the next pipeline stage. Nil return values +--- are filtered from the output. ---@return Iter -function Iter.map(self, f) +function Iter:map(f) -- Implementation note: the reader may be forgiven for observing that this -- function appears excessively convoluted. The problem to solve is that each -- stage of the iterator pipeline can return any number of values, and the @@ -245,7 +330,7 @@ function Iter.map(self, f) end ---@private -function ListIter.map(self, f) +function ListIter:map(f) local inc = self._head < self._tail and 1 or -1 local n = self._head for i = self._head, self._tail - inc, inc do @@ -263,10 +348,10 @@ end --- --- For functions with side effects. To modify the values in the iterator, use |Iter:map()|. --- ----@param f function(...) Function to execute for each item in the pipeline. ---- Takes all of the values returned by the previous stage ---- in the pipeline as arguments. -function Iter.each(self, f) +---@param f fun(...) Function to execute for each item in the pipeline. +--- Takes all of the values returned by the previous stage +--- in the pipeline as arguments. +function Iter:each(f) local function fn(...) if select(1, ...) ~= nil then f(...) @@ -278,7 +363,7 @@ function Iter.each(self, f) end ---@private -function ListIter.each(self, f) +function ListIter:each(f) local inc = self._head < self._tail and 1 or -1 for i = self._head, self._tail - inc, inc do f(unpack(self._table[i])) @@ -311,7 +396,7 @@ end --- --- ---@return table -function Iter.totable(self) +function Iter:totable() local t = {} while true do @@ -326,7 +411,7 @@ function Iter.totable(self) end ---@private -function ListIter.totable(self) +function ListIter:totable() if self.next ~= ListIter.next or self._head >= self._tail then return Iter.totable(self) end @@ -356,6 +441,18 @@ function ListIter.totable(self) return self._table end +--- Collect the iterator into a delimited string. +--- +--- Each element in the iterator is joined into a string separated by {delim}. +--- +--- Consumes the iterator. +--- +--- @param delim string Delimiter +--- @return string +function Iter:join(delim) + return table.concat(self:totable(), delim) +end + --- Folds ("reduces") an iterator into a single value. --- --- Examples: @@ -375,9 +472,9 @@ end ---@generic A --- ---@param init A Initial value of the accumulator. ----@param f function(acc:A, ...):A Accumulation function. +---@param f fun(acc:A, ...):A Accumulation function. ---@return A -function Iter.fold(self, init, f) +function Iter:fold(init, f) local acc = init --- Use a closure to handle var args returned from iterator @@ -394,7 +491,7 @@ function Iter.fold(self, init, f) end ---@private -function ListIter.fold(self, init, f) +function ListIter:fold(init, f) local acc = init local inc = self._head < self._tail and 1 or -1 for i = self._head, self._tail - inc, inc do @@ -420,13 +517,13 @@ end --- ``` --- ---@return any -function Iter.next(self) -- luacheck: no unused args +function Iter:next() -- This function is provided by the source iterator in Iter.new. This definition exists only for -- the docstring end ---@private -function ListIter.next(self) +function ListIter:next() if self._head ~= self._tail then local v = self._table[self._head] local inc = self._head < self._tail and 1 or -1 @@ -448,13 +545,12 @@ end --- ``` --- ---@return Iter -function Iter.rev(self) +function Iter:rev() error('rev() requires a list-like table') - return self end ---@private -function ListIter.rev(self) +function ListIter:rev() local inc = self._head < self._tail and 1 or -1 self._head, self._tail = self._tail - inc, self._head - inc return self @@ -477,12 +573,12 @@ end --- ``` --- ---@return any -function Iter.peek(self) -- luacheck: no unused args +function Iter:peek() error('peek() requires a list-like table') end ---@private -function ListIter.peek(self) +function ListIter:peek() if self._head ~= self._tail then return self._table[self._head] end @@ -509,9 +605,9 @@ end --- -- 12 --- --- ``` ---- +---@param f any ---@return any -function Iter.find(self, f) +function Iter:find(f) if type(f) ~= 'function' then local val = f f = function(v) @@ -555,13 +651,15 @@ end --- ---@see Iter.find --- +---@param f any ---@return any -function Iter.rfind(self, f) -- luacheck: no unused args +---@diagnostic disable-next-line: unused-local +function Iter:rfind(f) -- luacheck: no unused args error('rfind() requires a list-like table') end ---@private -function ListIter.rfind(self, f) -- luacheck: no unused args +function ListIter:rfind(f) if type(f) ~= 'function' then local val = f f = function(v) @@ -580,6 +678,41 @@ function ListIter.rfind(self, f) -- luacheck: no unused args self._head = self._tail end +--- Transforms an iterator to yield only the first n values. +--- +--- Example: +--- +--- ```lua +--- local it = vim.iter({ 1, 2, 3, 4 }):take(2) +--- it:next() +--- -- 1 +--- it:next() +--- -- 2 +--- it:next() +--- -- nil +--- ``` +--- +---@param n integer +---@return Iter +function Iter:take(n) + local next = self.next + local i = 0 + self.next = function() + if i < n then + i = i + 1 + return next(self) + end + end + return self +end + +---@private +function ListIter:take(n) + local inc = self._head < self._tail and 1 or -1 + self._tail = math.min(self._tail, self._head + n * inc) + return self +end + --- "Pops" a value from a |list-iterator| (gets the last value and decrements the tail). --- --- Example: @@ -593,11 +726,12 @@ end --- ``` --- ---@return any -function Iter.nextback(self) -- luacheck: no unused args +function Iter:nextback() error('nextback() requires a list-like table') end -function ListIter.nextback(self) +--- @nodoc +function ListIter:nextback() if self._head ~= self._tail then local inc = self._head < self._tail and 1 or -1 self._tail = self._tail - inc @@ -622,11 +756,12 @@ end --- ``` --- ---@return any -function Iter.peekback(self) -- luacheck: no unused args +function Iter:peekback() error('peekback() requires a list-like table') end -function ListIter.peekback(self) +---@nodoc +function ListIter:peekback() if self._head ~= self._tail then local inc = self._head < self._tail and 1 or -1 return self._table[self._tail - inc] @@ -647,7 +782,7 @@ end --- ---@param n number Number of values to skip. ---@return Iter -function Iter.skip(self, n) +function Iter:skip(n) for _ = 1, n do local _ = self:next() end @@ -655,7 +790,7 @@ function Iter.skip(self, n) end ---@private -function ListIter.skip(self, n) +function ListIter:skip(n) local inc = self._head < self._tail and n or -n self._head = self._head + inc if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then @@ -678,13 +813,13 @@ end --- ---@param n number Number of values to skip. ---@return Iter -function Iter.skipback(self, n) -- luacheck: no unused args +---@diagnostic disable-next-line: unused-local +function Iter:skipback(n) -- luacheck: no unused args error('skipback() requires a list-like table') - return self end ---@private -function ListIter.skipback(self, n) +function ListIter:skipback(n) local inc = self._head < self._tail and n or -n self._tail = self._tail - inc if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then @@ -709,7 +844,7 @@ end --- ---@param n number The index of the value to return. ---@return any -function Iter.nth(self, n) +function Iter:nth(n) if n > 0 then return self:skip(n - 1):next() end @@ -731,7 +866,7 @@ end --- ---@param n number The index of the value to return. ---@return any -function Iter.nthback(self, n) +function Iter:nthback(n) if n > 0 then return self:skipback(n - 1):nextback() end @@ -744,22 +879,22 @@ end ---@param first number ---@param last number ---@return Iter -function Iter.slice(self, first, last) -- luacheck: no unused args +---@diagnostic disable-next-line: unused-local +function Iter:slice(first, last) -- luacheck: no unused args error('slice() requires a list-like table') - return self end ---@private -function ListIter.slice(self, first, last) +function ListIter:slice(first, last) return self:skip(math.max(0, first - 1)):skipback(math.max(0, self._tail - last - 1)) end --- Returns true if any of the items in the iterator match the given predicate. --- ----@param pred function(...):bool Predicate function. Takes all values returned from the previous ---- stage in the pipeline as arguments and returns true if the ---- predicate matches. -function Iter.any(self, pred) +---@param pred fun(...):boolean Predicate function. Takes all values returned from the previous +--- stage in the pipeline as arguments and returns true if the +--- predicate matches. +function Iter:any(pred) local any = false --- Use a closure to handle var args returned from iterator @@ -780,10 +915,10 @@ end --- Returns true if all items in the iterator match the given predicate. --- ----@param pred function(...):bool Predicate function. Takes all values returned from the previous ---- stage in the pipeline as arguments and returns true if the ---- predicate matches. -function Iter.all(self, pred) +---@param pred fun(...):boolean Predicate function. Takes all values returned from the previous +--- stage in the pipeline as arguments and returns true if the +--- predicate matches. +function Iter:all(pred) local all = true local function fn(...) @@ -818,7 +953,7 @@ end --- ``` --- ---@return any -function Iter.last(self) +function Iter:last() local last = self:next() local cur = self:next() while cur do @@ -829,7 +964,7 @@ function Iter.last(self) end ---@private -function ListIter.last(self) +function ListIter:last() local inc = self._head < self._tail and 1 or -1 local v = self._table[self._tail - inc] self._head = self._tail @@ -865,7 +1000,7 @@ end --- ``` --- ---@return Iter -function Iter.enumerate(self) +function Iter:enumerate() local i = 0 return self:map(function(...) i = i + 1 @@ -874,7 +1009,7 @@ function Iter.enumerate(self) end ---@private -function ListIter.enumerate(self) +function ListIter:enumerate() local inc = self._head < self._tail and 1 or -1 for i = self._head, self._tail - inc, inc do local v = self._table[i] @@ -978,9 +1113,9 @@ end --- ---@see |Iter:filter()| --- ----@param f function(...):bool Filter function. Accepts the current iterator or table values as ---- arguments and returns true if those values should be kept in the ---- final table +---@param f fun(...):boolean Filter function. Accepts the current iterator or table values as +--- arguments and returns true if those values should be kept in the +--- final table ---@param src table|function Table or iterator function to filter ---@return table function M.filter(f, src, ...) @@ -996,18 +1131,17 @@ end --- ---@see |Iter:map()| --- ----@param f function(...):?any Map function. Accepts the current iterator or table values as ---- arguments and returns one or more new values. Nil values are removed ---- from the final table. +---@param f fun(...): any? Map function. Accepts the current iterator or table values as +--- arguments and returns one or more new values. Nil values are removed +--- from the final table. ---@param src table|function Table or iterator function to filter ---@return table function M.map(f, src, ...) return Iter.new(src, ...):map(f):totable() end ----@type IterMod return setmetatable(M, { __call = function(_, ...) return Iter.new(...) end, -}) +}) --[[@as IterMod]] |