aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/iter.lua
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim/iter.lua')
-rw-r--r--runtime/lua/vim/iter.lua274
1 files changed, 204 insertions, 70 deletions
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
index 874bdfb437..a37b7f7858 100644
--- a/runtime/lua/vim/iter.lua
+++ b/runtime/lua/vim/iter.lua
@@ -1,7 +1,7 @@
----@defgroup vim.iter
+--- @brief
---
---- \*vim.iter()\* is an interface for |iterable|s: it wraps a table or function argument into an
---- \*Iter\* object with methods (such as |Iter:filter()| and |Iter:map()|) that transform the
+--- [vim.iter()]() is an interface for [iterable]s: it wraps a table or function argument into an
+--- [Iter]() object with methods (such as [Iter:filter()] and [Iter:map()]) that transform the
--- underlying source data. These methods can be chained to create iterator "pipelines": the output
--- of each pipeline stage is input to the next stage. The first stage depends on the type passed to
--- `vim.iter()`:
@@ -64,10 +64,16 @@
--- In addition to the |vim.iter()| function, the |vim.iter| module provides
--- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|.
+--- LuaLS is bad at generics which this module mostly deals with
+--- @diagnostic disable:no-unknown
+
+---@nodoc
---@class IterMod
---@operator call:Iter
+
local M = {}
+---@nodoc
---@class Iter
local Iter = {}
Iter.__index = Iter
@@ -76,6 +82,7 @@ Iter.__call = function(self)
end
--- Special case implementations for iterators on list tables.
+---@nodoc
---@class ListIter : Iter
---@field _table table Underlying table data
---@field _head number Index to the front of a table iterator
@@ -112,6 +119,35 @@ local function sanitize(t)
return t
end
+--- Flattens a single list-like table. Errors if it attempts to flatten a
+--- dict-like table
+---@param v table table which should be flattened
+---@param max_depth number depth to which the table should be flattened
+---@param depth number current iteration depth
+---@param result table output table that contains flattened result
+---@return table|nil flattened table if it can be flattened, otherwise nil
+local function flatten(v, max_depth, depth, result)
+ if depth < max_depth and type(v) == 'table' then
+ local i = 0
+ for _ in pairs(v) do
+ i = i + 1
+
+ if v[i] == nil then
+ -- short-circuit: this is not a list like table
+ return nil
+ end
+
+ if flatten(v[i], max_depth, depth + 1, result) == nil then
+ return nil
+ end
+ end
+ else
+ result[#result + 1] = v
+ end
+
+ return result
+end
+
--- Determine if the current iterator stage should continue.
---
--- If any arguments are passed to this function, then return those arguments
@@ -152,11 +188,11 @@ end
--- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded)
--- ```
---
----@param f function(...):bool Takes all values returned from the previous stage
---- in the pipeline and returns false or nil if the
---- current iterator element should be removed.
+---@param f fun(...):boolean Takes all values returned from the previous stage
+--- in the pipeline and returns false or nil if the
+--- current iterator element should be removed.
---@return Iter
-function Iter.filter(self, f)
+function Iter:filter(f)
return self:map(function(...)
if f(...) then
return ...
@@ -165,7 +201,7 @@ function Iter.filter(self, f)
end
---@private
-function ListIter.filter(self, f)
+function ListIter:filter(f)
local inc = self._head < self._tail and 1 or -1
local n = self._head
for i = self._head, self._tail - inc, inc do
@@ -179,6 +215,55 @@ function ListIter.filter(self, f)
return self
end
+--- Flattens a |list-iterator|, un-nesting nested values up to the given {depth}.
+--- Errors if it attempts to flatten a dict-like value.
+---
+--- Examples:
+---
+--- ```lua
+--- vim.iter({ 1, { 2 }, { { 3 } } }):flatten():totable()
+--- -- { 1, 2, { 3 } }
+---
+--- vim.iter({1, { { a = 2 } }, { 3 } }):flatten():totable()
+--- -- { 1, { a = 2 }, 3 }
+---
+--- vim.iter({ 1, { { a = 2 } }, { 3 } }):flatten(math.huge):totable()
+--- -- error: attempt to flatten a dict-like table
+--- ```
+---
+---@param depth? number Depth to which |list-iterator| should be flattened
+--- (defaults to 1)
+---@return Iter
+---@diagnostic disable-next-line:unused-local
+function Iter:flatten(depth) -- luacheck: no unused args
+ error('flatten() requires a list-like table')
+end
+
+---@private
+function ListIter:flatten(depth)
+ depth = depth or 1
+ local inc = self._head < self._tail and 1 or -1
+ local target = {}
+
+ for i = self._head, self._tail - inc, inc do
+ local flattened = flatten(self._table[i], depth, 0, {})
+
+ -- exit early if we try to flatten a dict-like table
+ if flattened == nil then
+ error('flatten() requires a list-like table')
+ end
+
+ for _, v in pairs(flattened) do
+ target[#target + 1] = v
+ end
+ end
+
+ self._head = 1
+ self._tail = #target + 1
+ self._table = target
+ return self
+end
+
--- Maps the items of an iterator pipeline to the values returned by `f`.
---
--- If the map function returns nil, the value is filtered from the iterator.
@@ -195,13 +280,13 @@ end
--- -- { 6, 12 }
--- ```
---
----@param f function(...):any Mapping function. Takes all values returned from
---- the previous stage in the pipeline as arguments
---- and returns one or more new values, which are used
---- in the next pipeline stage. Nil return values
---- are filtered from the output.
+---@param f fun(...):any Mapping function. Takes all values returned from
+--- the previous stage in the pipeline as arguments
+--- and returns one or more new values, which are used
+--- in the next pipeline stage. Nil return values
+--- are filtered from the output.
---@return Iter
-function Iter.map(self, f)
+function Iter:map(f)
-- Implementation note: the reader may be forgiven for observing that this
-- function appears excessively convoluted. The problem to solve is that each
-- stage of the iterator pipeline can return any number of values, and the
@@ -245,7 +330,7 @@ function Iter.map(self, f)
end
---@private
-function ListIter.map(self, f)
+function ListIter:map(f)
local inc = self._head < self._tail and 1 or -1
local n = self._head
for i = self._head, self._tail - inc, inc do
@@ -263,10 +348,10 @@ end
---
--- For functions with side effects. To modify the values in the iterator, use |Iter:map()|.
---
----@param f function(...) Function to execute for each item in the pipeline.
---- Takes all of the values returned by the previous stage
---- in the pipeline as arguments.
-function Iter.each(self, f)
+---@param f fun(...) Function to execute for each item in the pipeline.
+--- Takes all of the values returned by the previous stage
+--- in the pipeline as arguments.
+function Iter:each(f)
local function fn(...)
if select(1, ...) ~= nil then
f(...)
@@ -278,7 +363,7 @@ function Iter.each(self, f)
end
---@private
-function ListIter.each(self, f)
+function ListIter:each(f)
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
f(unpack(self._table[i]))
@@ -311,7 +396,7 @@ end
---
---
---@return table
-function Iter.totable(self)
+function Iter:totable()
local t = {}
while true do
@@ -326,7 +411,7 @@ function Iter.totable(self)
end
---@private
-function ListIter.totable(self)
+function ListIter:totable()
if self.next ~= ListIter.next or self._head >= self._tail then
return Iter.totable(self)
end
@@ -356,6 +441,18 @@ function ListIter.totable(self)
return self._table
end
+--- Collect the iterator into a delimited string.
+---
+--- Each element in the iterator is joined into a string separated by {delim}.
+---
+--- Consumes the iterator.
+---
+--- @param delim string Delimiter
+--- @return string
+function Iter:join(delim)
+ return table.concat(self:totable(), delim)
+end
+
--- Folds ("reduces") an iterator into a single value.
---
--- Examples:
@@ -375,9 +472,9 @@ end
---@generic A
---
---@param init A Initial value of the accumulator.
----@param f function(acc:A, ...):A Accumulation function.
+---@param f fun(acc:A, ...):A Accumulation function.
---@return A
-function Iter.fold(self, init, f)
+function Iter:fold(init, f)
local acc = init
--- Use a closure to handle var args returned from iterator
@@ -394,7 +491,7 @@ function Iter.fold(self, init, f)
end
---@private
-function ListIter.fold(self, init, f)
+function ListIter:fold(init, f)
local acc = init
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
@@ -420,13 +517,13 @@ end
--- ```
---
---@return any
-function Iter.next(self) -- luacheck: no unused args
+function Iter:next()
-- This function is provided by the source iterator in Iter.new. This definition exists only for
-- the docstring
end
---@private
-function ListIter.next(self)
+function ListIter:next()
if self._head ~= self._tail then
local v = self._table[self._head]
local inc = self._head < self._tail and 1 or -1
@@ -448,13 +545,12 @@ end
--- ```
---
---@return Iter
-function Iter.rev(self)
+function Iter:rev()
error('rev() requires a list-like table')
- return self
end
---@private
-function ListIter.rev(self)
+function ListIter:rev()
local inc = self._head < self._tail and 1 or -1
self._head, self._tail = self._tail - inc, self._head - inc
return self
@@ -477,12 +573,12 @@ end
--- ```
---
---@return any
-function Iter.peek(self) -- luacheck: no unused args
+function Iter:peek()
error('peek() requires a list-like table')
end
---@private
-function ListIter.peek(self)
+function ListIter:peek()
if self._head ~= self._tail then
return self._table[self._head]
end
@@ -509,9 +605,9 @@ end
--- -- 12
---
--- ```
----
+---@param f any
---@return any
-function Iter.find(self, f)
+function Iter:find(f)
if type(f) ~= 'function' then
local val = f
f = function(v)
@@ -555,13 +651,15 @@ end
---
---@see Iter.find
---
+---@param f any
---@return any
-function Iter.rfind(self, f) -- luacheck: no unused args
+---@diagnostic disable-next-line: unused-local
+function Iter:rfind(f) -- luacheck: no unused args
error('rfind() requires a list-like table')
end
---@private
-function ListIter.rfind(self, f) -- luacheck: no unused args
+function ListIter:rfind(f)
if type(f) ~= 'function' then
local val = f
f = function(v)
@@ -580,6 +678,41 @@ function ListIter.rfind(self, f) -- luacheck: no unused args
self._head = self._tail
end
+--- Transforms an iterator to yield only the first n values.
+---
+--- Example:
+---
+--- ```lua
+--- local it = vim.iter({ 1, 2, 3, 4 }):take(2)
+--- it:next()
+--- -- 1
+--- it:next()
+--- -- 2
+--- it:next()
+--- -- nil
+--- ```
+---
+---@param n integer
+---@return Iter
+function Iter:take(n)
+ local next = self.next
+ local i = 0
+ self.next = function()
+ if i < n then
+ i = i + 1
+ return next(self)
+ end
+ end
+ return self
+end
+
+---@private
+function ListIter:take(n)
+ local inc = self._head < self._tail and 1 or -1
+ self._tail = math.min(self._tail, self._head + n * inc)
+ return self
+end
+
--- "Pops" a value from a |list-iterator| (gets the last value and decrements the tail).
---
--- Example:
@@ -593,11 +726,12 @@ end
--- ```
---
---@return any
-function Iter.nextback(self) -- luacheck: no unused args
+function Iter:nextback()
error('nextback() requires a list-like table')
end
-function ListIter.nextback(self)
+--- @nodoc
+function ListIter:nextback()
if self._head ~= self._tail then
local inc = self._head < self._tail and 1 or -1
self._tail = self._tail - inc
@@ -622,11 +756,12 @@ end
--- ```
---
---@return any
-function Iter.peekback(self) -- luacheck: no unused args
+function Iter:peekback()
error('peekback() requires a list-like table')
end
-function ListIter.peekback(self)
+---@nodoc
+function ListIter:peekback()
if self._head ~= self._tail then
local inc = self._head < self._tail and 1 or -1
return self._table[self._tail - inc]
@@ -647,7 +782,7 @@ end
---
---@param n number Number of values to skip.
---@return Iter
-function Iter.skip(self, n)
+function Iter:skip(n)
for _ = 1, n do
local _ = self:next()
end
@@ -655,7 +790,7 @@ function Iter.skip(self, n)
end
---@private
-function ListIter.skip(self, n)
+function ListIter:skip(n)
local inc = self._head < self._tail and n or -n
self._head = self._head + inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
@@ -678,13 +813,13 @@ end
---
---@param n number Number of values to skip.
---@return Iter
-function Iter.skipback(self, n) -- luacheck: no unused args
+---@diagnostic disable-next-line: unused-local
+function Iter:skipback(n) -- luacheck: no unused args
error('skipback() requires a list-like table')
- return self
end
---@private
-function ListIter.skipback(self, n)
+function ListIter:skipback(n)
local inc = self._head < self._tail and n or -n
self._tail = self._tail - inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
@@ -709,7 +844,7 @@ end
---
---@param n number The index of the value to return.
---@return any
-function Iter.nth(self, n)
+function Iter:nth(n)
if n > 0 then
return self:skip(n - 1):next()
end
@@ -731,7 +866,7 @@ end
---
---@param n number The index of the value to return.
---@return any
-function Iter.nthback(self, n)
+function Iter:nthback(n)
if n > 0 then
return self:skipback(n - 1):nextback()
end
@@ -744,22 +879,22 @@ end
---@param first number
---@param last number
---@return Iter
-function Iter.slice(self, first, last) -- luacheck: no unused args
+---@diagnostic disable-next-line: unused-local
+function Iter:slice(first, last) -- luacheck: no unused args
error('slice() requires a list-like table')
- return self
end
---@private
-function ListIter.slice(self, first, last)
+function ListIter:slice(first, last)
return self:skip(math.max(0, first - 1)):skipback(math.max(0, self._tail - last - 1))
end
--- Returns true if any of the items in the iterator match the given predicate.
---
----@param pred function(...):bool Predicate function. Takes all values returned from the previous
---- stage in the pipeline as arguments and returns true if the
---- predicate matches.
-function Iter.any(self, pred)
+---@param pred fun(...):boolean Predicate function. Takes all values returned from the previous
+--- stage in the pipeline as arguments and returns true if the
+--- predicate matches.
+function Iter:any(pred)
local any = false
--- Use a closure to handle var args returned from iterator
@@ -780,10 +915,10 @@ end
--- Returns true if all items in the iterator match the given predicate.
---
----@param pred function(...):bool Predicate function. Takes all values returned from the previous
---- stage in the pipeline as arguments and returns true if the
---- predicate matches.
-function Iter.all(self, pred)
+---@param pred fun(...):boolean Predicate function. Takes all values returned from the previous
+--- stage in the pipeline as arguments and returns true if the
+--- predicate matches.
+function Iter:all(pred)
local all = true
local function fn(...)
@@ -818,7 +953,7 @@ end
--- ```
---
---@return any
-function Iter.last(self)
+function Iter:last()
local last = self:next()
local cur = self:next()
while cur do
@@ -829,7 +964,7 @@ function Iter.last(self)
end
---@private
-function ListIter.last(self)
+function ListIter:last()
local inc = self._head < self._tail and 1 or -1
local v = self._table[self._tail - inc]
self._head = self._tail
@@ -865,7 +1000,7 @@ end
--- ```
---
---@return Iter
-function Iter.enumerate(self)
+function Iter:enumerate()
local i = 0
return self:map(function(...)
i = i + 1
@@ -874,7 +1009,7 @@ function Iter.enumerate(self)
end
---@private
-function ListIter.enumerate(self)
+function ListIter:enumerate()
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
local v = self._table[i]
@@ -978,9 +1113,9 @@ end
---
---@see |Iter:filter()|
---
----@param f function(...):bool Filter function. Accepts the current iterator or table values as
---- arguments and returns true if those values should be kept in the
---- final table
+---@param f fun(...):boolean Filter function. Accepts the current iterator or table values as
+--- arguments and returns true if those values should be kept in the
+--- final table
---@param src table|function Table or iterator function to filter
---@return table
function M.filter(f, src, ...)
@@ -996,18 +1131,17 @@ end
---
---@see |Iter:map()|
---
----@param f function(...):?any Map function. Accepts the current iterator or table values as
---- arguments and returns one or more new values. Nil values are removed
---- from the final table.
+---@param f fun(...): any? Map function. Accepts the current iterator or table values as
+--- arguments and returns one or more new values. Nil values are removed
+--- from the final table.
---@param src table|function Table or iterator function to filter
---@return table
function M.map(f, src, ...)
return Iter.new(src, ...):map(f):totable()
end
----@type IterMod
return setmetatable(M, {
__call = function(_, ...)
return Iter.new(...)
end,
-})
+}) --[[@as IterMod]]