aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim/iter.lua
diff options
context:
space:
mode:
authorGregory Anders <8965202+gpanders@users.noreply.github.com>2023-04-17 12:54:19 -0600
committerGitHub <noreply@github.com>2023-04-17 12:54:19 -0600
commitab1edecfb7c73c82c2d5886cb8e270b44aca7d01 (patch)
treeebb6a3218cc2e9c3135fe5c791a794e3c580df62 /runtime/lua/vim/iter.lua
parent6cc76011ca28ff61f1c2f8de6d895d4c6d0a1ad8 (diff)
downloadrneovim-ab1edecfb7c73c82c2d5886cb8e270b44aca7d01.tar.gz
rneovim-ab1edecfb7c73c82c2d5886cb8e270b44aca7d01.tar.bz2
rneovim-ab1edecfb7c73c82c2d5886cb8e270b44aca7d01.zip
feat(lua): add vim.iter (#23029)
vim.iter wraps a table or iterator function into an `Iter` object with methods such as `filter`, `map`, and `fold` which can be chained to produce iterator pipelines that do not create new tables at each step.
Diffstat (limited to 'runtime/lua/vim/iter.lua')
-rw-r--r--runtime/lua/vim/iter.lua836
1 files changed, 836 insertions, 0 deletions
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
new file mode 100644
index 0000000000..b73c03ba9a
--- /dev/null
+++ b/runtime/lua/vim/iter.lua
@@ -0,0 +1,836 @@
+--- Iterator implementation.
+
+---@class Iter
+local Iter = {}
+Iter.__index = Iter
+Iter.__call = function(self)
+ return self:next()
+end
+
+--- Special case implementations for iterators on list tables.
+---@class ListIter : Iter
+---@field _table table Underlying table data (table iterators only)
+---@field _head number Index to the front of a table iterator (table iterators only)
+---@field _tail number Index to the end of a table iterator (table iterators only)
+local ListIter = {}
+ListIter.__index = setmetatable(ListIter, Iter)
+ListIter.__call = function(self)
+ return self:next()
+end
+
+--- Special case implementations for iterators on non-list tables.
+---@class TableIter : Iter
+local TableIter = {}
+TableIter.__index = setmetatable(TableIter, Iter)
+TableIter.__call = function(self)
+ return self:next()
+end
+
+---@private
+local function unpack(t)
+ if type(t) == 'table' then
+ return _G.unpack(t)
+ end
+ return t
+end
+
+---@private
+local function pack(...)
+ if select('#', ...) > 1 then
+ return { ... }
+ end
+ return ...
+end
+
+--- Add a filter step to the iterator pipeline.
+---
+--- Example:
+--- <pre>lua
+--- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded)
+--- </pre>
+---
+---@param f function(...):bool Takes all values returned from the previous stage in the pipeline and
+--- returns false or nil if the current iterator element should be
+--- removed.
+---@return Iter
+function Iter.filter(self, f)
+ ---@private
+ local function fn(...)
+ local result = nil
+ if select(1, ...) ~= nil then
+ if not f(...) then
+ return true, nil
+ else
+ result = pack(...)
+ end
+ end
+ return false, result
+ end
+
+ local next = self.next
+ self.next = function(this)
+ local cont, result
+ repeat
+ cont, result = fn(next(this))
+ until not cont
+ return unpack(result)
+ end
+ return self
+end
+
+---@private
+function ListIter.filter(self, f)
+ local inc = self._head < self._tail and 1 or -1
+ local n = self._head
+ for i = self._head, self._tail - inc, inc do
+ local v = self._table[i]
+ if f(unpack(v)) then
+ self._table[n] = v
+ n = n + inc
+ end
+ end
+ self._tail = n
+ return self
+end
+
+--- Add a map step to the iterator pipeline.
+---
+--- If the map function returns nil, the value is filtered from the iterator.
+---
+--- Example:
+--- <pre>lua
+--- local it = vim.iter({ 1, 2, 3, 4 }):map(function(v)
+--- if v % 2 == 0 then
+--- return v * 3
+--- end
+--- end)
+--- it:totable()
+--- -- { 6, 12 }
+--- </pre>
+---
+---@param f function(...):any Mapping function. Takes all values returned from the previous stage
+--- in the pipeline as arguments and returns one or more new values,
+--- which are used in the next pipeline stage. Nil return values returned
+--- are filtered from the output.
+---@return Iter
+function Iter.map(self, f)
+ ---@private
+ local function fn(...)
+ local result = nil
+ if select(1, ...) ~= nil then
+ result = pack(f(...))
+ if result == nil then
+ return true, nil
+ end
+ end
+ return false, result
+ end
+
+ local next = self.next
+ self.next = function(this)
+ local cont, result
+ repeat
+ cont, result = fn(next(this))
+ until not cont
+ return unpack(result)
+ end
+ return self
+end
+
+---@private
+function ListIter.map(self, f)
+ local inc = self._head < self._tail and 1 or -1
+ local n = self._head
+ for i = self._head, self._tail - inc, inc do
+ local v = pack(f(unpack(self._table[i])))
+ if v ~= nil then
+ self._table[n] = v
+ n = n + inc
+ end
+ end
+ self._tail = n
+ return self
+end
+
+--- Call a function once for each item in the pipeline.
+---
+--- This is used for functions which have side effects. To modify the values in the iterator, use
+--- |Iter:map()|.
+---
+--- This function drains the iterator.
+---
+---@param f function(...) Function to execute for each item in the pipeline. Takes all of the
+--- values returned by the previous stage in the pipeline as arguments.
+function Iter.each(self, f)
+ ---@private
+ local function fn(...)
+ if select(1, ...) ~= nil then
+ f(...)
+ return true
+ end
+ end
+ while fn(self:next()) do
+ end
+end
+
+---@private
+function ListIter.each(self, f)
+ local inc = self._head < self._tail and 1 or -1
+ for i = self._head, self._tail - inc, inc do
+ f(unpack(self._table[i]))
+ end
+ self._head = self._tail
+end
+
+--- Collect the iterator into a table.
+---
+--- The resulting table depends on the initial source in the iterator pipeline. List-like tables
+--- and function iterators will be collected into a list-like table. If multiple values are returned
+--- from the final stage in the iterator pipeline, each value will be included in a table. If a
+--- map-like table was used as the initial source, then a map-like table is returned.
+---
+--- Examples:
+--- <pre>lua
+--- vim.iter(string.gmatch('100 20 50', '%d+')):map(tonumber):totable()
+--- -- { 100, 20, 50 }
+---
+--- vim.iter({ 1, 2, 3 }):map(function(v) return v, 2 * v end):totable()
+--- -- { { 1, 2 }, { 2, 4 }, { 3, 6 } }
+---
+--- vim.iter({ a = 1, b = 2, c = 3 }):filter(function(k, v) return v % 2 ~= 0 end):totable()
+--- -- { a = 1, c = 3 }
+--- </pre>
+---
+---@return table
+function Iter.totable(self)
+ local t = {}
+
+ while true do
+ local args = pack(self:next())
+ if args == nil then
+ break
+ end
+ t[#t + 1] = args
+ end
+ return t
+end
+
+---@private
+function ListIter.totable(self)
+ if self._head == 1 and self._tail == #self._table + 1 and self.next == ListIter.next then
+ return self._table
+ end
+
+ return Iter.totable(self)
+end
+
+---@private
+function TableIter.totable(self)
+ local t = {}
+ for k, v in self do
+ t[k] = v
+ end
+ return t
+end
+
+--- Fold an iterator or table into a single value.
+---
+---@generic A
+---
+---@param init A Initial value of the accumulator.
+---@param f function(acc:A, ...):A Accumulation function.
+---@return A
+function Iter.fold(self, init, f)
+ local acc = init
+
+ --- Use a closure to handle var args returned from iterator
+ ---@private
+ local function fn(...)
+ if select(1, ...) ~= nil then
+ acc = f(acc, ...)
+ return true
+ end
+ end
+
+ while fn(self:next()) do
+ end
+ return acc
+end
+
+---@private
+function ListIter.fold(self, init, f)
+ local acc = init
+ local inc = self._head < self._tail and 1 or -1
+ for i = self._head, self._tail - inc, inc do
+ acc = f(acc, unpack(self._table[i]))
+ end
+ return acc
+end
+
+--- Return the next value from the iterator.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter(string.gmatch('1 2 3', '%d+')):map(tonumber)
+--- it:next()
+--- -- 1
+--- it:next()
+--- -- 2
+--- it:next()
+--- -- 3
+---
+--- </pre>
+---
+---@return any
+function Iter.next(self) -- luacheck: no unused args
+ -- This function is provided by the source iterator in Iter.new. This definition exists only for
+ -- the docstring
+end
+
+---@private
+function ListIter.next(self)
+ if self._head ~= self._tail then
+ local v = self._table[self._head]
+ local inc = self._head < self._tail and 1 or -1
+ self._head = self._head + inc
+ return unpack(v)
+ end
+end
+
+--- Reverse an iterator.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter({ 3, 6, 9, 12 }):rev()
+--- it:totable()
+--- -- { 12, 9, 6, 3 }
+---
+--- </pre>
+---
+---@return Iter
+function Iter.rev(self)
+ error('rev() requires a list-like table')
+ return self
+end
+
+---@private
+function ListIter.rev(self)
+ local inc = self._head < self._tail and 1 or -1
+ self._head, self._tail = self._tail - inc, self._head - inc
+ return self
+end
+
+--- Peek at the next value in the iterator without consuming it.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter({ 3, 6, 9, 12 })
+--- it:peek()
+--- -- 3
+--- it:peek()
+--- -- 3
+--- it:next()
+--- -- 3
+---
+--- </pre>
+---
+---@return any
+function Iter.peek(self) -- luacheck: no unused args
+ error('peek() requires a list-like table')
+end
+
+---@private
+function ListIter.peek(self)
+ if self._head ~= self._tail then
+ return self._table[self._head]
+ end
+end
+
+--- Find the first value in the iterator that satisfies the given predicate.
+---
+--- Advances the iterator. Returns nil and drains the iterator if no value is found.
+---
+--- Examples:
+--- <pre>lua
+---
+--- local it = vim.iter({ 3, 6, 9, 12 })
+--- it:find(12)
+--- -- 12
+---
+--- local it = vim.iter({ 3, 6, 9, 12 })
+--- it:find(20)
+--- -- nil
+---
+--- local it = vim.iter({ 3, 6, 9, 12 })
+--- it:find(function(v) return v % 4 == 0 end)
+--- -- 12
+---
+--- </pre>
+---
+---@return any
+function Iter.find(self, f)
+ if type(f) ~= 'function' then
+ local val = f
+ f = function(v)
+ return v == val
+ end
+ end
+
+ local result = nil
+
+ --- Use a closure to handle var args returned from iterator
+ ---@private
+ local function fn(...)
+ if select(1, ...) ~= nil then
+ if f(...) then
+ result = pack(...)
+ else
+ return true
+ end
+ end
+ end
+
+ while fn(self:next()) do
+ end
+ return unpack(result)
+end
+
+--- Find the first value in the iterator that satisfies the given predicate, starting from the end.
+---
+--- Advances the iterator. Returns nil and drains the iterator if no value is found.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Examples:
+--- <pre>lua
+---
+--- local it = vim.iter({ 1, 2, 3, 2, 1 }):enumerate()
+--- it:rfind(1)
+--- -- 5 1
+--- it:rfind(1)
+--- -- 1 1
+---
+--- </pre>
+---
+---@see Iter.find
+---
+---@return any
+function Iter.rfind(self, f) -- luacheck: no unused args
+ error('rfind() requires a list-like table')
+end
+
+---@private
+function ListIter.rfind(self, f) -- luacheck: no unused args
+ if type(f) ~= 'function' then
+ local val = f
+ f = function(v)
+ return v == val
+ end
+ end
+
+ local inc = self._head < self._tail and 1 or -1
+ for i = self._tail - inc, self._head, -inc do
+ local v = self._table[i]
+ if f(unpack(v)) then
+ self._tail = i
+ return unpack(v)
+ end
+ end
+ self._head = self._tail
+end
+
+--- Return the next value from the end of the iterator.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Example:
+--- <pre>lua
+--- local it = vim.iter({1, 2, 3, 4})
+--- it:nextback()
+--- -- 4
+--- it:nextback()
+--- -- 3
+--- </pre>
+---
+---@return any
+function Iter.nextback(self) -- luacheck: no unused args
+ error('nextback() requires a list-like table')
+end
+
+function ListIter.nextback(self)
+ if self._head ~= self._tail then
+ local inc = self._head < self._tail and 1 or -1
+ self._tail = self._tail - inc
+ return self._table[self._tail]
+ end
+end
+
+--- Return the next value from the end of the iterator without consuming it.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Example:
+--- <pre>lua
+--- local it = vim.iter({1, 2, 3, 4})
+--- it:peekback()
+--- -- 4
+--- it:peekback()
+--- -- 4
+--- it:nextback()
+--- -- 4
+--- </pre>
+---
+---@return any
+function Iter.peekback(self) -- luacheck: no unused args
+ error('peekback() requires a list-like table')
+end
+
+function ListIter.peekback(self)
+ if self._head ~= self._tail then
+ local inc = self._head < self._tail and 1 or -1
+ return self._table[self._tail - inc]
+ end
+end
+
+--- Skip values in the iterator.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter({ 3, 6, 9, 12 }):skip(2)
+--- it:next()
+--- -- 9
+---
+--- </pre>
+---
+---@param n number Number of values to skip.
+---@return Iter
+function Iter.skip(self, n)
+ for _ = 1, n do
+ local _ = self:next()
+ end
+ return self
+end
+
+---@private
+function ListIter.skip(self, n)
+ local inc = self._head < self._tail and n or -n
+ self._head = self._head + inc
+ if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
+ self._head = self._tail
+ end
+ return self
+end
+
+--- Skip values in the iterator starting from the end.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Example:
+--- <pre>lua
+--- local it = vim.iter({ 1, 2, 3, 4, 5 }):skipback(2)
+--- it:next()
+--- -- 1
+--- it:nextback()
+--- -- 3
+--- </pre>
+---
+---@param n number Number of values to skip.
+---@return Iter
+function Iter.skipback(self, n) -- luacheck: no unused args
+ error('skipback() requires a list-like table')
+ return self
+end
+
+---@private
+function ListIter.skipback(self, n)
+ local inc = self._head < self._tail and n or -n
+ self._tail = self._tail - inc
+ if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
+ self._head = self._tail
+ end
+ return self
+end
+
+--- Return the nth value in the iterator.
+---
+--- This function advances the iterator.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter({ 3, 6, 9, 12 })
+--- it:nth(2)
+--- -- 6
+--- it:nth(2)
+--- -- 12
+---
+--- </pre>
+---
+---@param n number The index of the value to return.
+---@return any
+function Iter.nth(self, n)
+ if n > 0 then
+ return self:skip(n - 1):next()
+ end
+end
+
+--- Return the nth value from the end of the iterator.
+---
+--- This function advances the iterator.
+---
+--- Only supported for iterators on list-like tables.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter({ 3, 6, 9, 12 })
+--- it:nthback(2)
+--- -- 9
+--- it:nthback(2)
+--- -- 3
+---
+--- </pre>
+---
+---@param n number The index of the value to return.
+---@return any
+function Iter.nthback(self, n)
+ if n > 0 then
+ return self:skipback(n - 1):nextback()
+ end
+end
+
+--- Slice an iterator, changing its start and end positions.
+---
+--- This is equivalent to :skip(first - 1):skipback(len - last + 1)
+---
+--- Only supported for iterators on list-like tables.
+---
+---@param first number
+---@param last number
+---@return Iter
+function Iter.slice(self, first, last) -- luacheck: no unused args
+ return self:skip(math.max(0, first - 1)):skipback(math.max(0, self._tail - last - 1))
+end
+
+--- Return true if any of the items in the iterator match the given predicate.
+---
+---@param pred function(...):bool Predicate function. Takes all values returned from the previous
+--- stage in the pipeline as arguments and returns true if the
+--- predicate matches.
+function Iter.any(self, pred)
+ local any = false
+
+ --- Use a closure to handle var args returned from iterator
+ ---@private
+ local function fn(...)
+ if select(1, ...) ~= nil then
+ if pred(...) then
+ any = true
+ else
+ return true
+ end
+ end
+ end
+
+ while fn(self:next()) do
+ end
+ return any
+end
+
+--- Return true if all of the items in the iterator match the given predicate.
+---
+---@param pred function(...):bool Predicate function. Takes all values returned from the previous
+--- stage in the pipeline as arguments and returns true if the
+--- predicate matches.
+function Iter.all(self, pred)
+ local all = true
+
+ ---@private
+ local function fn(...)
+ if select(1, ...) ~= nil then
+ if not pred(...) then
+ all = false
+ else
+ return true
+ end
+ end
+ end
+
+ while fn(self:next()) do
+ end
+ return all
+end
+
+--- Return the last item in the iterator.
+---
+--- Drains the iterator.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter(vim.gsplit('abcdefg', ''))
+--- it:last()
+--- -- 'g'
+---
+--- local it = vim.iter({ 3, 6, 9, 12, 15 })
+--- it:last()
+--- -- 15
+---
+--- </pre>
+---
+---@return any
+function Iter.last(self)
+ local last = self:next()
+ local cur = self:next()
+ while cur do
+ last = cur
+ cur = self:next()
+ end
+ return last
+end
+
+---@private
+function ListIter.last(self)
+ local inc = self._head < self._tail and 1 or -1
+ local v = self._table[self._tail - inc]
+ self._head = self._tail
+ return v
+end
+
+--- Add an iterator stage that returns the current iterator count as well as the iterator value.
+---
+--- For list tables, prefer
+--- <pre>lua
+--- vim.iter(ipairs(t))
+--- </pre>
+---
+--- over
+---
+--- <pre>lua
+--- vim.iter(t):enumerate()
+--- </pre>
+---
+--- as the former is faster.
+---
+--- Example:
+--- <pre>lua
+---
+--- local it = vim.iter(vim.gsplit('abc', '')):enumerate()
+--- it:next()
+--- -- 1 'a'
+--- it:next()
+--- -- 2 'b'
+--- it:next()
+--- -- 3 'c'
+---
+--- </pre>
+---
+---@return Iter
+function Iter.enumerate(self)
+ local i = 0
+ return self:map(function(...)
+ i = i + 1
+ return i, ...
+ end)
+end
+
+---@private
+function ListIter.enumerate(self)
+ local inc = self._head < self._tail and 1 or -1
+ for i = self._head, self._tail - inc, inc do
+ local v = self._table[i]
+ self._table[i] = { i, v }
+ end
+ return self
+end
+
+--- Create a new Iter object from a table or iterator.
+---
+---@param src table|function Table or iterator to drain values from
+---@return Iter
+function Iter.new(src, ...)
+ local it = {}
+ if type(src) == 'table' then
+ local t = {}
+
+ -- Check if source table can be treated like a list (indices are consecutive integers
+ -- starting from 1)
+ local count = 0
+ for _ in pairs(src) do
+ count = count + 1
+ local v = src[count]
+ if v == nil then
+ return TableIter.new(src)
+ end
+ t[count] = v
+ end
+ return ListIter.new(t)
+ end
+
+ if type(src) == 'function' then
+ local s, var = ...
+
+ --- Use a closure to handle var args returned from iterator
+ ---@private
+ local function fn(...)
+ if select(1, ...) ~= nil then
+ var = select(1, ...)
+ return ...
+ end
+ end
+
+ function it.next()
+ return fn(src(s, var))
+ end
+
+ setmetatable(it, Iter)
+ else
+ error('src must be a table or function')
+ end
+ return it
+end
+
+--- Create a new ListIter
+---
+---@param t table List-like table. Caller guarantees that this table is a valid list.
+---@return Iter
+---@private
+function ListIter.new(t)
+ local it = {}
+ it._table = t
+ it._head = 1
+ it._tail = #t + 1
+ setmetatable(it, ListIter)
+ return it
+end
+
+--- Create a new TableIter
+---
+---@param t table Table to iterate over. For list-like tables, use ListIter.new instead.
+---@return Iter
+---@private
+function TableIter.new(t)
+ local it = {}
+
+ local index = nil
+ function it.next()
+ local k, v = next(t, index)
+ if k ~= nil then
+ index = k
+ return k, v
+ end
+ end
+
+ setmetatable(it, TableIter)
+ return it
+end
+
+return Iter