diff options
-rw-r--r-- | runtime/doc/lua.txt | 2 | ||||
-rw-r--r-- | runtime/lua/vim/iter.lua | 164 | ||||
-rw-r--r-- | test/benchmark/iter_spec.lua | 215 |
3 files changed, 318 insertions, 63 deletions
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt index 465c83e6ab..820bd0eb35 100644 --- a/runtime/doc/lua.txt +++ b/runtime/doc/lua.txt @@ -3080,7 +3080,7 @@ Iter:map({self}, {f}) *Iter:map()* • {f} function(...):any Mapping function. Takes all values returned from the previous stage in the pipeline as arguments and returns one or more new values, which are used in the next pipeline - stage. Nil return values returned are filtered from the output. + stage. Nil return values are filtered from the output. Return: ~ Iter diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua index c2e2c5bd9f..bda3508262 100644 --- a/runtime/lua/vim/iter.lua +++ b/runtime/lua/vim/iter.lua @@ -1,13 +1,14 @@ ---@defgroup lua-iter --- ---- The \*vim.iter\* module provides a generic "iterator" interface over tables and iterator ---- functions. +--- The \*vim.iter\* module provides a generic "iterator" interface over tables +--- and iterator functions. --- ---- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object with methods (such ---- as |Iter:filter()| and |Iter:map()|) that transform the underlying source data. These methods ---- can be chained together to create iterator "pipelines". Each pipeline stage receives as input ---- the output values from the prior stage. The values used in the first stage of the pipeline ---- depend on the type passed to this function: +--- \*vim.iter()\* wraps its table or function argument into an \*Iter\* object +--- with methods (such as |Iter:filter()| and |Iter:map()|) that transform the +--- underlying source data. These methods can be chained together to create +--- iterator "pipelines". Each pipeline stage receives as input the output +--- values from the prior stage. The values used in the first stage of the +--- pipeline depend on the type passed to this function: --- --- - List tables pass only the value of each element --- - Non-list tables pass both the key and value of each element @@ -47,8 +48,8 @@ --- -- true --- </pre> --- ---- In addition to the |vim.iter()| function, the |vim.iter| module provides convenience functions ---- like |vim.iter.filter()| and |vim.iter.totable()|. +--- In addition to the |vim.iter()| function, the |vim.iter| module provides +--- convenience functions like |vim.iter.filter()| and |vim.iter.totable()|. local M = {} @@ -61,9 +62,9 @@ end --- Special case implementations for iterators on list tables. ---@class ListIter : Iter ----@field _table table Underlying table data (table iterators only) ----@field _head number Index to the front of a table iterator (table iterators only) ----@field _tail number Index to the end of a table iterator (table iterators only) +---@field _table table Underlying table data +---@field _head number Index to the front of a table iterator +---@field _tail number Index to the end of a table iterator local ListIter = {} ListIter.__index = setmetatable(ListIter, Iter) ListIter.__call = function(self) @@ -75,7 +76,7 @@ local packedmt = {} ---@private local function unpack(t) - if getmetatable(t) == packedmt then + if type(t) == 'table' and getmetatable(t) == packedmt then return _G.unpack(t, 1, t.n) end return t @@ -92,13 +93,47 @@ end ---@private local function sanitize(t) - if getmetatable(t) == packedmt then + if type(t) == 'table' and getmetatable(t) == packedmt then -- Remove length tag t.n = nil end return t end +--- Determine if the current iterator stage should continue. +--- +--- If any arguments are passed to this function, then return those arguments +--- and stop the current iterator stage. Otherwise, return true to signal that +--- the current stage should continue. +--- +---@param ... any Function arguments. +---@return boolean True if the iterator stage should continue, false otherwise +---@return any Function arguments. +---@private +local function continue(...) + if select('#', ...) > 0 then + return false, ... + end + return true +end + +--- If no input arguments are given return false, indicating the current +--- iterator stage should stop. Otherwise, apply the arguments to the function +--- f. If that function returns no values, the current iterator stage continues. +--- Otherwise, those values are returned. +--- +---@param f function Function to call with the given arguments +---@param ... any Arguments to apply to f +---@return boolean True if the iterator pipeline should continue, false otherwise +---@return any Return values of f +---@private +local function apply(f, ...) + if select('#', ...) > 0 then + return continue(f(...)) + end + return false +end + --- Add a filter step to the iterator pipeline. --- --- Example: @@ -106,33 +141,16 @@ end --- local bufs = vim.iter(vim.api.nvim_list_bufs()):filter(vim.api.nvim_buf_is_loaded) --- </pre> --- ----@param f function(...):bool Takes all values returned from the previous stage in the pipeline and ---- returns false or nil if the current iterator element should be ---- removed. +---@param f function(...):bool Takes all values returned from the previous stage +--- in the pipeline and returns false or nil if the +--- current iterator element should be removed. ---@return Iter function Iter.filter(self, f) - ---@private - local function fn(...) - local result = nil - if select(1, ...) ~= nil then - if not f(...) then - return true, nil - else - result = pack(...) - end + return self:map(function(...) + if f(...) then + return ... end - return false, result - end - - local next = self.next - self.next = function(this) - local cont, result - repeat - cont, result = fn(next(this)) - until not cont - return unpack(result) - end - return self + end) end ---@private @@ -165,31 +183,52 @@ end --- -- { 6, 12 } --- </pre> --- ----@param f function(...):any Mapping function. Takes all values returned from the previous stage ---- in the pipeline as arguments and returns one or more new values, ---- which are used in the next pipeline stage. Nil return values returned ---- are filtered from the output. +---@param f function(...):any Mapping function. Takes all values returned from +--- the previous stage in the pipeline as arguments +--- and returns one or more new values, which are used +--- in the next pipeline stage. Nil return values +--- are filtered from the output. ---@return Iter function Iter.map(self, f) + -- Implementation note: the reader may be forgiven for observing that this + -- function appears excessively convoluted. The problem to solve is that each + -- stage of the iterator pipeline can return any number of values, and the + -- number of values could even change per iteration. And the return values + -- must be checked to determine if the pipeline has ended, so we cannot + -- naively forward them along to the next stage. + -- + -- A simple approach is to pack all of the return values into a table, check + -- for nil, then unpack the table for the next stage. However, packing and + -- unpacking tables is quite slow. There is no other way in Lua to handle an + -- unknown number of function return values than to simply forward those + -- values along to another function. Hence the intricate function passing you + -- see here. + + local next = self.next + + --- Drain values from the upstream iterator source until a value can be + --- returned. + --- + --- This is a recursive function. The base case is when the first argument is + --- false, which indicates that the rest of the arguments should be returned + --- as the values for the current iteration stage. + --- + ---@param cont boolean If true, the current iterator stage should continue to + --- pull values from its upstream pipeline stage. + --- Otherwise, this stage is complete and returns the + --- values passed. + ---@param ... any Values to return if cont is false. + ---@return any ---@private - local function fn(...) - local result = nil - if select(1, ...) ~= nil then - result = pack(f(...)) - if result == nil then - return true, nil - end + local function fn(cont, ...) + if cont then + return fn(apply(f, next(self))) end - return false, result + return ... end - local next = self.next - self.next = function(this) - local cont, result - repeat - cont, result = fn(next(this)) - until not cont - return unpack(result) + self.next = function() + return fn(apply(f, next(self))) end return self end @@ -211,17 +250,18 @@ end --- Call a function once for each item in the pipeline. --- ---- This is used for functions which have side effects. To modify the values in the iterator, use ---- |Iter:map()|. +--- This is used for functions which have side effects. To modify the values in +--- the iterator, use |Iter:map()|. --- --- This function drains the iterator. --- ----@param f function(...) Function to execute for each item in the pipeline. Takes all of the ---- values returned by the previous stage in the pipeline as arguments. +---@param f function(...) Function to execute for each item in the pipeline. +--- Takes all of the values returned by the previous stage +--- in the pipeline as arguments. function Iter.each(self, f) ---@private local function fn(...) - if select(1, ...) ~= nil then + if select('#', ...) > 0 then f(...) return true end diff --git a/test/benchmark/iter_spec.lua b/test/benchmark/iter_spec.lua new file mode 100644 index 0000000000..8d77054e83 --- /dev/null +++ b/test/benchmark/iter_spec.lua @@ -0,0 +1,215 @@ +local N = 500 +local test_table_size = 100000 + +describe('vim.iter perf', function() + local function mean(t) + assert(#t > 0) + local sum = 0 + for _, v in ipairs(t) do + sum = sum + v + end + return sum / #t + end + + local function median(t) + local len = #t + if len % 2 == 0 then + return t[len / 2] + end + return t[(len + 1) / 2] + end + + -- Assert that results are equal between each benchmark + local last = nil + + local function reset() + last = nil + end + + local input = {} + for i = 1, test_table_size do + input[#input + 1] = i + end + + local function measure(f) + local stats = {} + local result + for _ = 1, N do + local tic = vim.loop.hrtime() + result = f(input) + local toc = vim.loop.hrtime() + stats[#stats + 1] = (toc - tic) / 1000000 + end + table.sort(stats) + print( + string.format( + '\nMin: %0.6f ms, Max: %0.6f ms, Median: %0.6f ms, Mean: %0.6f ms', + math.min(unpack(stats)), + math.max(unpack(stats)), + median(stats), + mean(stats) + ) + ) + + if last ~= nil then + assert(#result == #last) + for i, v in ipairs(result) do + if type(v) == 'string' or type(v) == 'number' then + assert(last[i] == v) + elseif type(v) == 'table' then + for k, vv in pairs(v) do + assert(last[i][k] == vv) + end + end + end + end + + last = result + end + + describe('list like table', function() + describe('simple map', function() + reset() + + it('vim.iter', function() + local function f(t) + return vim + .iter(t) + :map(function(v) + return v * 2 + end) + :totable() + end + measure(f) + end) + + it('for loop', function() + local function f(t) + local res = {} + for i = 1, #t do + res[#res + 1] = t[i] * 2 + end + return res + end + measure(f) + end) + end) + + describe('filter, map, skip, reverse', function() + reset() + + it('vim.iter', function() + local function f(t) + local i = 0 + return vim + .iter(t) + :map(function(v) + i = i + 1 + if i % 2 == 0 then + return v * 2 + end + end) + :skip(1000) + :rev() + :totable() + end + measure(f) + end) + + it('tables', function() + local function f(t) + local a = {} + for i = 1, #t do + if i % 2 == 0 then + a[#a + 1] = t[i] * 2 + end + end + + local b = {} + for i = 1001, #a do + b[#b + 1] = a[i] + end + + local c = {} + for i = 1, #b do + c[#c + 1] = b[#b - i + 1] + end + return c + end + measure(f) + end) + end) + end) + + describe('iterator', function() + describe('simple map', function() + reset() + it('vim.iter', function() + local function f(t) + return vim + .iter(ipairs(t)) + :map(function(i, v) + return i + v + end) + :totable() + end + measure(f) + end) + + it('ipairs', function() + local function f(t) + local res = {} + for i, v in ipairs(t) do + res[#res + 1] = i + v + end + return res + end + measure(f) + end) + end) + + describe('multiple stages', function() + reset() + it('vim.iter', function() + local function f(t) + return vim + .iter(ipairs(t)) + :map(function(i, v) + if i % 2 ~= 0 then + return v + end + end) + :map(function(v) + return v * 3 + end) + :skip(50) + :totable() + end + measure(f) + end) + + it('ipairs', function() + local function f(t) + local a = {} + for i, v in ipairs(t) do + if i % 2 ~= 0 then + a[#a + 1] = v + end + end + local b = {} + for _, v in ipairs(a) do + b[#b + 1] = v * 3 + end + local c = {} + for i, v in ipairs(b) do + if i > 50 then + c[#c + 1] = v + end + end + return c + end + measure(f) + end) + end) + end) +end) |