aboutsummaryrefslogtreecommitdiff
path: root/runtime/lua/vim
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/lua/vim')
-rw-r--r--runtime/lua/vim/iter.lua164
1 files changed, 102 insertions, 62 deletions
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
index c2e2c5bd9f..bda3508262 100644
--- a/runtime/lua/vim/iter.lua
+++ b/runtime/lua/vim/iter.lua
@@ -1,13 +1,14 @@
---@defgroup lua-iter
---
---- The \*vim.iter\* module provides a generic "iterator" interface over tables and iterator
---- functions.
+--- The \*vim.iter\* module provides a generic "iterator" interface over tables
+--- and iterator functions.
---
---- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object with methods (such
---- as |Iter:filter()| and |Iter:map()|) that transform the underlying source data. These methods
---- can be chained together to create iterator "pipelines". Each pipeline stage receives as input
---- the output values from the prior stage. The values used in the first stage of the pipeline
---- depend on the type passed to this function:
+--- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object
+--- with methods (such as |Iter:filter()| and |Iter:map()|) that transform the
+--- underlying source data. These methods can be chained together to create
+--- iterator "pipelines". Each pipeline stage receives as input the output
+--- values from the prior stage. The values used in the first stage of the
+--- pipeline depend on the type passed to this function:
---
--- - List tables pass only the value of each element
--- - Non-list tables pass both the key and value of each element
@@ -47,8 +48,8 @@
--- -- true
--- </pre>
---
---- In addition to the |vim.iter()| function, the |vim.iter| module provides convenience functions
---- like |vim.iter.filter()| and |vim.iter.totable()|.
+--- In addition to the |vim.iter()| function, the |vim.iter| module provides
+--- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|.
local M = {}
@@ -61,9 +62,9 @@ end
--- Special case implementations for iterators on list tables.
---@class ListIter : Iter
----@field _table table Underlying table data (table iterators only)
----@field _head number Index to the front of a table iterator (table iterators only)
----@field _tail number Index to the end of a table iterator (table iterators only)
+---@field _table table Underlying table data
+---@field _head number Index to the front of a table iterator
+---@field _tail number Index to the end of a table iterator
local ListIter = {}
ListIter.__index = setmetatable(ListIter, Iter)
ListIter.__call = function(self)
@@ -75,7 +76,7 @@ local packedmt = {}
---@private
local function unpack(t)
- if getmetatable(t) == packedmt then
+ if type(t) == 'table' and getmetatable(t) == packedmt then
return _G.unpack(t, 1, t.n)
end
return t
@@ -92,13 +93,47 @@ end
---@private
local function sanitize(t)
- if getmetatable(t) == packedmt then
+ if type(t) == 'table' and getmetatable(t) == packedmt then
-- Remove length tag
t.n = nil
end
return t
end
+--- Determine if the current iterator stage should continue.
+---
+--- If any arguments are passed to this function, then return those arguments
+--- and stop the current iterator stage. Otherwise, return true to signal that
+--- the current stage should continue.
+---
+---@param ... any Function arguments.
+---@return boolean True if the iterator stage should continue, false otherwise
+---@return any Function arguments.
+---@private
+local function continue(...)
+ if select('#', ...) > 0 then
+ return false, ...
+ end
+ return true
+end
+
+--- If no input arguments are given return false, indicating the current
+--- iterator stage should stop. Otherwise, apply the arguments to the function
+--- f. If that function returns no values, the current iterator stage continues.
+--- Otherwise, those values are returned.
+---
+---@param f function Function to call with the given arguments
+---@param ... any Arguments to apply to f
+---@return boolean True if the iterator pipeline should continue, false otherwise
+---@return any Return values of f
+---@private
+local function apply(f, ...)
+ if select('#', ...) > 0 then
+ return continue(f(...))
+ end
+ return false
+end
+
--- Add a filter step to the iterator pipeline.
---
--- Example:
@@ -106,33 +141,16 @@ end
--- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded)
--- </pre>
---
----@param f function(...):bool Takes all values returned from the previous stage in the pipeline and
---- returns false or nil if the current iterator element should be
---- removed.
+---@param f function(...):bool Takes all values returned from the previous stage
+--- in the pipeline and returns false or nil if the
+--- current iterator element should be removed.
---@return Iter
function Iter.filter(self, f)
- ---@private
- local function fn(...)
- local result = nil
- if select(1, ...) ~= nil then
- if not f(...) then
- return true, nil
- else
- result = pack(...)
- end
+ return self:map(function(...)
+ if f(...) then
+ return ...
end
- return false, result
- end
-
- local next = self.next
- self.next = function(this)
- local cont, result
- repeat
- cont, result = fn(next(this))
- until not cont
- return unpack(result)
- end
- return self
+ end)
end
---@private
@@ -165,31 +183,52 @@ end
--- -- { 6, 12 }
--- </pre>
---
----@param f function(...):any Mapping function. Takes all values returned from the previous stage
---- in the pipeline as arguments and returns one or more new values,
---- which are used in the next pipeline stage. Nil return values returned
---- are filtered from the output.
+---@param f function(...):any Mapping function. Takes all values returned from
+--- the previous stage in the pipeline as arguments
+--- and returns one or more new values, which are used
+--- in the next pipeline stage. Nil return values
+--- are filtered from the output.
---@return Iter
function Iter.map(self, f)
+ -- Implementation note: the reader may be forgiven for observing that this
+ -- function appears excessively convoluted. The problem to solve is that each
+ -- stage of the iterator pipeline can return any number of values, and the
+ -- number of values could even change per iteration. And the return values
+ -- must be checked to determine if the pipeline has ended, so we cannot
+ -- naively forward them along to the next stage.
+ --
+ -- A simple approach is to pack all of the return values into a table, check
+ -- for nil, then unpack the table for the next stage. However, packing and
+ -- unpacking tables is quite slow. There is no other way in Lua to handle an
+ -- unknown number of function return values than to simply forward those
+ -- values along to another function. Hence the intricate function passing you
+ -- see here.
+
+ local next = self.next
+
+ --- Drain values from the upstream iterator source until a value can be
+ --- returned.
+ ---
+ --- This is a recursive function. The base case is when the first argument is
+ --- false, which indicates that the rest of the arguments should be returned
+ --- as the values for the current iteration stage.
+ ---
+ ---@param cont boolean If true, the current iterator stage should continue to
+ --- pull values from its upstream pipeline stage.
+ --- Otherwise, this stage is complete and returns the
+ --- values passed.
+ ---@param ... any Values to return if cont is false.
+ ---@return any
---@private
- local function fn(...)
- local result = nil
- if select(1, ...) ~= nil then
- result = pack(f(...))
- if result == nil then
- return true, nil
- end
+ local function fn(cont, ...)
+ if cont then
+ return fn(apply(f, next(self)))
end
- return false, result
+ return ...
end
- local next = self.next
- self.next = function(this)
- local cont, result
- repeat
- cont, result = fn(next(this))
- until not cont
- return unpack(result)
+ self.next = function()
+ return fn(apply(f, next(self)))
end
return self
end
@@ -211,17 +250,18 @@ end
--- Call a function once for each item in the pipeline.
---
---- This is used for functions which have side effects. To modify the values in the iterator, use
---- |Iter:map()|.
+--- This is used for functions which have side effects. To modify the values in
+--- the iterator, use |Iter:map()|.
---
--- This function drains the iterator.
---
----@param f function(...) Function to execute for each item in the pipeline. Takes all of the
---- values returned by the previous stage in the pipeline as arguments.
+---@param f function(...) Function to execute for each item in the pipeline.
+--- Takes all of the values returned by the previous stage
+--- in the pipeline as arguments.
function Iter.each(self, f)
---@private
local function fn(...)
- if select(1, ...) ~= nil then
+ if select('#', ...) > 0 then
f(...)
return true
end