aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--runtime/doc/lua.txt2
-rw-r--r--runtime/lua/vim/iter.lua164
-rw-r--r--test/benchmark/iter_spec.lua215
3 files changed, 318 insertions, 63 deletions
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt
index 465c83e6ab..820bd0eb35 100644
--- a/runtime/doc/lua.txt
+++ b/runtime/doc/lua.txt
@@ -3080,7 +3080,7 @@ Iter:map({self}, {f}) *Iter:map()*
• {f} function(...):any Mapping function. Takes all values returned
from the previous stage in the pipeline as arguments and returns
one or more new values, which are used in the next pipeline
- stage. Nil return values returned are filtered from the output.
+ stage. Nil return values are filtered from the output.
Return: ~
Iter
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
index c2e2c5bd9f..bda3508262 100644
--- a/runtime/lua/vim/iter.lua
+++ b/runtime/lua/vim/iter.lua
@@ -1,13 +1,14 @@
---@defgroup lua-iter
---
---- The \*vim.iter\* module provides a generic "iterator" interface over tables and iterator
---- functions.
+--- The \*vim.iter\* module provides a generic "iterator" interface over tables
+--- and iterator functions.
---
---- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object with methods (such
---- as |Iter:filter()| and |Iter:map()|) that transform the underlying source data. These methods
---- can be chained together to create iterator "pipelines". Each pipeline stage receives as input
---- the output values from the prior stage. The values used in the first stage of the pipeline
---- depend on the type passed to this function:
+--- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object
+--- with methods (such as |Iter:filter()| and |Iter:map()|) that transform the
+--- underlying source data. These methods can be chained together to create
+--- iterator "pipelines". Each pipeline stage receives as input the output
+--- values from the prior stage. The values used in the first stage of the
+--- pipeline depend on the type passed to this function:
---
--- - List tables pass only the value of each element
--- - Non-list tables pass both the key and value of each element
@@ -47,8 +48,8 @@
--- -- true
--- </pre>
---
---- In addition to the |vim.iter()| function, the |vim.iter| module provides convenience functions
---- like |vim.iter.filter()| and |vim.iter.totable()|.
+--- In addition to the |vim.iter()| function, the |vim.iter| module provides
+--- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|.
local M = {}
@@ -61,9 +62,9 @@ end
--- Special case implementations for iterators on list tables.
---@class ListIter : Iter
----@field _table table Underlying table data (table iterators only)
----@field _head number Index to the front of a table iterator (table iterators only)
----@field _tail number Index to the end of a table iterator (table iterators only)
+---@field _table table Underlying table data
+---@field _head number Index to the front of a table iterator
+---@field _tail number Index to the end of a table iterator
local ListIter = {}
ListIter.__index = setmetatable(ListIter, Iter)
ListIter.__call = function(self)
@@ -75,7 +76,7 @@ local packedmt = {}
---@private
local function unpack(t)
- if getmetatable(t) == packedmt then
+ if type(t) == 'table' and getmetatable(t) == packedmt then
return _G.unpack(t, 1, t.n)
end
return t
@@ -92,13 +93,47 @@ end
---@private
local function sanitize(t)
- if getmetatable(t) == packedmt then
+ if type(t) == 'table' and getmetatable(t) == packedmt then
-- Remove length tag
t.n = nil
end
return t
end
+--- Determine if the current iterator stage should continue.
+---
+--- If any arguments are passed to this function, then return those arguments
+--- and stop the current iterator stage. Otherwise, return true to signal that
+--- the current stage should continue.
+---
+---@param ... any Function arguments.
+---@return boolean True if the iterator stage should continue, false otherwise
+---@return any Function arguments.
+---@private
+local function continue(...)
+ if select('#', ...) > 0 then
+ return false, ...
+ end
+ return true
+end
+
+--- If no input arguments are given return false, indicating the current
+--- iterator stage should stop. Otherwise, apply the arguments to the function
+--- f. If that function returns no values, the current iterator stage continues.
+--- Otherwise, those values are returned.
+---
+---@param f function Function to call with the given arguments
+---@param ... any Arguments to apply to f
+---@return boolean True if the iterator pipeline should continue, false otherwise
+---@return any Return values of f
+---@private
+local function apply(f, ...)
+ if select('#', ...) > 0 then
+ return continue(f(...))
+ end
+ return false
+end
+
--- Add a filter step to the iterator pipeline.
---
--- Example:
@@ -106,33 +141,16 @@ end
--- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded)
--- </pre>
---
----@param f function(...):bool Takes all values returned from the previous stage in the pipeline and
---- returns false or nil if the current iterator element should be
---- removed.
+---@param f function(...):bool Takes all values returned from the previous stage
+--- in the pipeline and returns false or nil if the
+--- current iterator element should be removed.
---@return Iter
function Iter.filter(self, f)
- ---@private
- local function fn(...)
- local result = nil
- if select(1, ...) ~= nil then
- if not f(...) then
- return true, nil
- else
- result = pack(...)
- end
+ return self:map(function(...)
+ if f(...) then
+ return ...
end
- return false, result
- end
-
- local next = self.next
- self.next = function(this)
- local cont, result
- repeat
- cont, result = fn(next(this))
- until not cont
- return unpack(result)
- end
- return self
+ end)
end
---@private
@@ -165,31 +183,52 @@ end
--- -- { 6, 12 }
--- </pre>
---
----@param f function(...):any Mapping function. Takes all values returned from the previous stage
---- in the pipeline as arguments and returns one or more new values,
---- which are used in the next pipeline stage. Nil return values returned
---- are filtered from the output.
+---@param f function(...):any Mapping function. Takes all values returned from
+--- the previous stage in the pipeline as arguments
+--- and returns one or more new values, which are used
+--- in the next pipeline stage. Nil return values
+--- are filtered from the output.
---@return Iter
function Iter.map(self, f)
+ -- Implementation note: the reader may be forgiven for observing that this
+ -- function appears excessively convoluted. The problem to solve is that each
+ -- stage of the iterator pipeline can return any number of values, and the
+ -- number of values could even change per iteration. And the return values
+ -- must be checked to determine if the pipeline has ended, so we cannot
+ -- naively forward them along to the next stage.
+ --
+ -- A simple approach is to pack all of the return values into a table, check
+ -- for nil, then unpack the table for the next stage. However, packing and
+ -- unpacking tables is quite slow. There is no other way in Lua to handle an
+ -- unknown number of function return values than to simply forward those
+ -- values along to another function. Hence the intricate function passing you
+ -- see here.
+
+ local next = self.next
+
+ --- Drain values from the upstream iterator source until a value can be
+ --- returned.
+ ---
+ --- This is a recursive function. The base case is when the first argument is
+ --- false, which indicates that the rest of the arguments should be returned
+ --- as the values for the current iteration stage.
+ ---
+ ---@param cont boolean If true, the current iterator stage should continue to
+ --- pull values from its upstream pipeline stage.
+ --- Otherwise, this stage is complete and returns the
+ --- values passed.
+ ---@param ... any Values to return if cont is false.
+ ---@return any
---@private
- local function fn(...)
- local result = nil
- if select(1, ...) ~= nil then
- result = pack(f(...))
- if result == nil then
- return true, nil
- end
+ local function fn(cont, ...)
+ if cont then
+ return fn(apply(f, next(self)))
end
- return false, result
+ return ...
end
- local next = self.next
- self.next = function(this)
- local cont, result
- repeat
- cont, result = fn(next(this))
- until not cont
- return unpack(result)
+ self.next = function()
+ return fn(apply(f, next(self)))
end
return self
end
@@ -211,17 +250,18 @@ end
--- Call a function once for each item in the pipeline.
---
---- This is used for functions which have side effects. To modify the values in the iterator, use
---- |Iter:map()|.
+--- This is used for functions which have side effects. To modify the values in
+--- the iterator, use |Iter:map()|.
---
--- This function drains the iterator.
---
----@param f function(...) Function to execute for each item in the pipeline. Takes all of the
---- values returned by the previous stage in the pipeline as arguments.
+---@param f function(...) Function to execute for each item in the pipeline.
+--- Takes all of the values returned by the previous stage
+--- in the pipeline as arguments.
function Iter.each(self, f)
---@private
local function fn(...)
- if select(1, ...) ~= nil then
+ if select('#', ...) > 0 then
f(...)
return true
end
diff --git a/test/benchmark/iter_spec.lua b/test/benchmark/iter_spec.lua
new file mode 100644
index 0000000000..8d77054e83
--- /dev/null
+++ b/test/benchmark/iter_spec.lua
@@ -0,0 +1,215 @@
+local N = 500
+local test_table_size = 100000
+
+describe('vim.iter perf', function()
+ local function mean(t)
+ assert(#t > 0)
+ local sum = 0
+ for _, v in ipairs(t) do
+ sum = sum + v
+ end
+ return sum / #t
+ end
+
+ local function median(t)
+ local len = #t
+ if len % 2 == 0 then
+ return t[len / 2]
+ end
+ return t[(len + 1) / 2]
+ end
+
+ -- Assert that results are equal between each benchmark
+ local last = nil
+
+ local function reset()
+ last = nil
+ end
+
+ local input = {}
+ for i = 1, test_table_size do
+ input[#input + 1] = i
+ end
+
+ local function measure(f)
+ local stats = {}
+ local result
+ for _ = 1, N do
+ local tic = vim.loop.hrtime()
+ result = f(input)
+ local toc = vim.loop.hrtime()
+ stats[#stats + 1] = (toc - tic) / 1000000
+ end
+ table.sort(stats)
+ print(
+ string.format(
+ '\nMin: %0.6f ms, Max: %0.6f ms, Median: %0.6f ms, Mean: %0.6f ms',
+ math.min(unpack(stats)),
+ math.max(unpack(stats)),
+ median(stats),
+ mean(stats)
+ )
+ )
+
+ if last ~= nil then
+ assert(#result == #last)
+ for i, v in ipairs(result) do
+ if type(v) == 'string' or type(v) == 'number' then
+ assert(last[i] == v)
+ elseif type(v) == 'table' then
+ for k, vv in pairs(v) do
+ assert(last[i][k] == vv)
+ end
+ end
+ end
+ end
+
+ last = result
+ end
+
+ describe('list like table', function()
+ describe('simple map', function()
+ reset()
+
+ it('vim.iter', function()
+ local function f(t)
+ return vim
+ .iter(t)
+ :map(function(v)
+ return v * 2
+ end)
+ :totable()
+ end
+ measure(f)
+ end)
+
+ it('for loop', function()
+ local function f(t)
+ local res = {}
+ for i = 1, #t do
+ res[#res + 1] = t[i] * 2
+ end
+ return res
+ end
+ measure(f)
+ end)
+ end)
+
+ describe('filter, map, skip, reverse', function()
+ reset()
+
+ it('vim.iter', function()
+ local function f(t)
+ local i = 0
+ return vim
+ .iter(t)
+ :map(function(v)
+ i = i + 1
+ if i % 2 == 0 then
+ return v * 2
+ end
+ end)
+ :skip(1000)
+ :rev()
+ :totable()
+ end
+ measure(f)
+ end)
+
+ it('tables', function()
+ local function f(t)
+ local a = {}
+ for i = 1, #t do
+ if i % 2 == 0 then
+ a[#a + 1] = t[i] * 2
+ end
+ end
+
+ local b = {}
+ for i = 1001, #a do
+ b[#b + 1] = a[i]
+ end
+
+ local c = {}
+ for i = 1, #b do
+ c[#c + 1] = b[#b - i + 1]
+ end
+ return c
+ end
+ measure(f)
+ end)
+ end)
+ end)
+
+ describe('iterator', function()
+ describe('simple map', function()
+ reset()
+ it('vim.iter', function()
+ local function f(t)
+ return vim
+ .iter(ipairs(t))
+ :map(function(i, v)
+ return i + v
+ end)
+ :totable()
+ end
+ measure(f)
+ end)
+
+ it('ipairs', function()
+ local function f(t)
+ local res = {}
+ for i, v in ipairs(t) do
+ res[#res + 1] = i + v
+ end
+ return res
+ end
+ measure(f)
+ end)
+ end)
+
+ describe('multiple stages', function()
+ reset()
+ it('vim.iter', function()
+ local function f(t)
+ return vim
+ .iter(ipairs(t))
+ :map(function(i, v)
+ if i % 2 ~= 0 then
+ return v
+ end
+ end)
+ :map(function(v)
+ return v * 3
+ end)
+ :skip(50)
+ :totable()
+ end
+ measure(f)
+ end)
+
+ it('ipairs', function()
+ local function f(t)
+ local a = {}
+ for i, v in ipairs(t) do
+ if i % 2 ~= 0 then
+ a[#a + 1] = v
+ end
+ end
+ local b = {}
+ for _, v in ipairs(a) do
+ b[#b + 1] = v * 3
+ end
+ local c = {}
+ for i, v in ipairs(b) do
+ if i > 50 then
+ c[#c + 1] = v
+ end
+ end
+ return c
+ end
+ measure(f)
+ end)
+ end)
+ end)
+end)