aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWill Hopkins <willothyh@gmail.com>2023-12-12 12:27:24 -0800
committerGitHub <noreply@github.com>2023-12-12 14:27:24 -0600
commit69ffbb76c237fcbba24de80f1b5346d92642e800 (patch)
treef6a59d9e31dedaff08b08a183e5484d04ee50d3f
parent1907abb4c27857fe7f4e7394f32e130f9955a2e7 (diff)
downloadrneovim-69ffbb76c237fcbba24de80f1b5346d92642e800.tar.gz
rneovim-69ffbb76c237fcbba24de80f1b5346d92642e800.tar.bz2
rneovim-69ffbb76c237fcbba24de80f1b5346d92642e800.zip
feat(iter): add `Iter.take` (#26525)
-rw-r--r--runtime/doc/lua.txt19
-rw-r--r--runtime/lua/vim/iter.lua35
-rw-r--r--test/functional/lua/iter_spec.lua27
3 files changed, 81 insertions, 0 deletions
diff --git a/runtime/doc/lua.txt b/runtime/doc/lua.txt
index f7f722bc0e..7e0ad5f4c3 100644
--- a/runtime/doc/lua.txt
+++ b/runtime/doc/lua.txt
@@ -3639,6 +3639,25 @@ Iter:slice({first}, {last}) *Iter:slice()*
Return: ~
Iter
+Iter:take({n}) *Iter:take()*
+ Transforms an iterator to yield only the first n values.
+
+ Example: >lua
+ local it = vim.iter({ 1, 2, 3, 4 }):take(2)
+ it:next()
+ -- 1
+ it:next()
+ -- 2
+ it:next()
+ -- nil
+<
+
+ Parameters: ~
+ • {n} (integer)
+
+ Return: ~
+ Iter
+
Iter:totable() *Iter:totable()*
Collect the iterator into a table.
diff --git a/runtime/lua/vim/iter.lua b/runtime/lua/vim/iter.lua
index e9c2b66bf2..8e602c406a 100644
--- a/runtime/lua/vim/iter.lua
+++ b/runtime/lua/vim/iter.lua
@@ -592,6 +592,41 @@ function ListIter.rfind(self, f) -- luacheck: no unused args
self._head = self._tail
end
+--- Transforms an iterator to yield only the first n values.
+---
+--- Example:
+---
+--- ```lua
+--- local it = vim.iter({ 1, 2, 3, 4 }):take(2)
+--- it:next()
+--- -- 1
+--- it:next()
+--- -- 2
+--- it:next()
+--- -- nil
+--- ```
+---
+---@param n integer
+---@return Iter
+function Iter.take(self, n)
+ local next = self.next
+ local i = 0
+ self.next = function()
+ if i < n then
+ i = i + 1
+ return next(self)
+ end
+ end
+ return self
+end
+
+---@private
+function ListIter.take(self, n)
+ local inc = self._head < self._tail and 1 or -1
+ self._tail = math.min(self._tail, self._head + n * inc)
+ return self
+end
+
--- "Pops" a value from a |list-iterator| (gets the last value and decrements the tail).
---
--- Example:
diff --git a/test/functional/lua/iter_spec.lua b/test/functional/lua/iter_spec.lua
index 2d28395c59..a589474262 100644
--- a/test/functional/lua/iter_spec.lua
+++ b/test/functional/lua/iter_spec.lua
@@ -203,6 +203,33 @@ describe('vim.iter', function()
matches('skipback%(%) requires a list%-like table', pcall_err(it.nthback, it, 1))
end)
+ it('take()', function()
+ do
+ local t = { 4, 3, 2, 1 }
+ eq({}, vim.iter(t):take(0):totable())
+ eq({ 4 }, vim.iter(t):take(1):totable())
+ eq({ 4, 3 }, vim.iter(t):take(2):totable())
+ eq({ 4, 3, 2 }, vim.iter(t):take(3):totable())
+ eq({ 4, 3, 2, 1 }, vim.iter(t):take(4):totable())
+ eq({ 4, 3, 2, 1 }, vim.iter(t):take(5):totable())
+ end
+
+ do
+ local t = { 4, 3, 2, 1 }
+ local it = vim.iter(t)
+ eq({ 4, 3 }, it:take(2):totable())
+ -- tail is already set from the previous take()
+ eq({ 4, 3 }, it:take(3):totable())
+ end
+
+ do
+ local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
+ eq({ 'a', 'b' }, it:take(2):totable())
+ -- non-array iterators are consumed by take()
+ eq({}, it:take(2):totable())
+ end
+ end)
+
it('any()', function()
local function odd(v)
return v % 2 ~= 0