aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRiley Bruins <ribru17@hotmail.com>2024-09-01 16:54:30 -0700
committerChristian Clason <ch.clason+github@icloud.com>2025-01-05 12:28:01 +0100
commitb61051ccb4c23958d43d285b8b801af11620264f (patch)
treeade40815a7e92af29910f3750d37e0192b6f4d7c
parent54ac406649b9e93d756ea62c1a6a587db462039c (diff)
downloadrneovim-b61051ccb4c23958d43d285b8b801af11620264f.tar.gz
rneovim-b61051ccb4c23958d43d285b8b801af11620264f.tar.bz2
rneovim-b61051ccb4c23958d43d285b8b801af11620264f.zip
feat(func): allow manual cache invalidation for _memoize
This commit also adds some tests for the existing memoization functionality.
-rw-r--r--runtime/lua/vim/func.lua14
-rw-r--r--runtime/lua/vim/func/_memoize.lua56
-rw-r--r--runtime/lua/vim/treesitter/query.lua4
-rw-r--r--test/functional/func/memoize_spec.lua142
4 files changed, 191 insertions, 25 deletions
diff --git a/runtime/lua/vim/func.lua b/runtime/lua/vim/func.lua
index f71659ffb4..fc8fa62c71 100644
--- a/runtime/lua/vim/func.lua
+++ b/runtime/lua/vim/func.lua
@@ -3,9 +3,6 @@ local M = {}
-- TODO(lewis6991): Private for now until:
-- - There are other places in the codebase that could benefit from this
-- (e.g. LSP), but might require other changes to accommodate.
--- - Invalidation of the cache needs to be controllable. Using weak tables
--- is an acceptable invalidation policy, but it shouldn't be the only
--- one.
-- - I don't think the story around `hash` is completely thought out. We
-- may be able to have a good default hash by hashing each argument,
-- so basically a better 'concat'.
@@ -17,6 +14,10 @@ local M = {}
--- Internally uses a |lua-weaktable| to cache the results of {fn} meaning the
--- cache will be invalidated whenever Lua does garbage collection.
---
+--- The cache can also be manually invalidated by calling `:clear()` on the returned object.
+--- Calling this function with no arguments clears the entire cache; otherwise, the arguments will
+--- be interpreted as function inputs, and only the cache entry at their hash will be cleared.
+---
--- The memoized function returns shared references so be wary about
--- mutating return values.
---
@@ -32,11 +33,12 @@ local M = {}
--- first n arguments passed to {fn}.
---
--- @param fn F Function to memoize.
---- @param strong? boolean Do not use a weak table
+--- @param weak? boolean Use a weak table (default `true`)
--- @return F # Memoized version of {fn}
--- @nodoc
-function M._memoize(hash, fn, strong)
- return require('vim.func._memoize')(hash, fn, strong)
+function M._memoize(hash, fn, weak)
+ -- this is wrapped in a function to lazily require the module
+ return require('vim.func._memoize')(hash, fn, weak)
end
return M
diff --git a/runtime/lua/vim/func/_memoize.lua b/runtime/lua/vim/func/_memoize.lua
index 6e557905a7..c46f878067 100644
--- a/runtime/lua/vim/func/_memoize.lua
+++ b/runtime/lua/vim/func/_memoize.lua
@@ -1,5 +1,7 @@
--- Module for private utility functions
+--- @alias vim.func.MemoObj { _hash: (fun(...): any), _weak: boolean?, _cache: table<any> }
+
--- @param argc integer?
--- @return fun(...): any
local function concat_hash(argc)
@@ -33,29 +35,49 @@ local function resolve_hash(hash)
return hash
end
+--- @param weak boolean?
+--- @return table
+local create_cache = function(weak)
+ return setmetatable({}, {
+ __mode = weak ~= false and 'kv',
+ })
+end
+
--- @generic F: function
--- @param hash integer|string|fun(...): any
--- @param fn F
---- @param strong? boolean
+--- @param weak? boolean
--- @return F
-return function(hash, fn, strong)
+return function(hash, fn, weak)
vim.validate('hash', hash, { 'number', 'string', 'function' })
vim.validate('fn', fn, 'function')
+ vim.validate('weak', weak, 'boolean', true)
- ---@type table<any,table<any,any>>
- local cache = {}
- if not strong then
- setmetatable(cache, { __mode = 'kv' })
- end
-
- hash = resolve_hash(hash)
+ --- @type vim.func.MemoObj
+ local obj = {
+ _cache = create_cache(weak),
+ _hash = resolve_hash(hash),
+ _weak = weak,
+ --- @param self vim.func.MemoObj
+ clear = function(self, ...)
+ if select('#', ...) == 0 then
+ self._cache = create_cache(self._weak)
+ return
+ end
+ local key = self._hash(...)
+ self._cache[key] = nil
+ end,
+ }
- return function(...)
- local key = hash(...)
- if cache[key] == nil then
- cache[key] = vim.F.pack_len(fn(...))
- end
-
- return vim.F.unpack_len(cache[key])
- end
+ return setmetatable(obj, {
+ --- @param self vim.func.MemoObj
+ __call = function(self, ...)
+ local key = self._hash(...)
+ local cache = self._cache
+ if cache[key] == nil then
+ cache[key] = vim.F.pack_len(fn(...))
+ end
+ return vim.F.unpack_len(cache[key])
+ end,
+ })
end
diff --git a/runtime/lua/vim/treesitter/query.lua b/runtime/lua/vim/treesitter/query.lua
index f9c497337f..2b3b9096a6 100644
--- a/runtime/lua/vim/treesitter/query.lua
+++ b/runtime/lua/vim/treesitter/query.lua
@@ -902,8 +902,8 @@ function Query:iter_captures(node, source, start, stop)
local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
- local apply_directives = memoize(match_id_hash, self.apply_directives, true)
- local match_preds = memoize(match_id_hash, self.match_preds, true)
+ local apply_directives = memoize(match_id_hash, self.apply_directives, false)
+ local match_preds = memoize(match_id_hash, self.match_preds, false)
local function iter(end_line)
local capture, captured_node, match = cursor:next_capture()
diff --git a/test/functional/func/memoize_spec.lua b/test/functional/func/memoize_spec.lua
new file mode 100644
index 0000000000..ca518ab88d
--- /dev/null
+++ b/test/functional/func/memoize_spec.lua
@@ -0,0 +1,142 @@
+local t = require('test.testutil')
+local n = require('test.functional.testnvim')()
+local clear = n.clear
+local exec_lua = n.exec_lua
+local eq = t.eq
+
+describe('vim.func._memoize', function()
+ before_each(clear)
+
+ it('caches function results based on their parameters', function()
+ exec_lua([[
+ _G.count = 0
+
+ local adder = vim.func._memoize('concat', function(arg1, arg2)
+ _G.count = _G.count + 1
+ return arg1 + arg2
+ end)
+
+ collectgarbage('stop')
+ adder(3, -4)
+ adder(3, -4)
+ adder(3, -4)
+ adder(3, -4)
+ adder(3, -4)
+ collectgarbage('restart')
+ ]])
+
+ eq(1, exec_lua([[return _G.count]]))
+ end)
+
+ it('caches function results using a weak table by default', function()
+ exec_lua([[
+ _G.count = 0
+
+ local adder = vim.func._memoize('concat-2', function(arg1, arg2)
+ _G.count = _G.count + 1
+ return arg1 + arg2
+ end)
+
+ adder(3, -4)
+ collectgarbage()
+ adder(3, -4)
+ collectgarbage()
+ adder(3, -4)
+ ]])
+
+ eq(3, exec_lua([[return _G.count]]))
+ end)
+
+ it('can cache using a strong table', function()
+ exec_lua([[
+ _G.count = 0
+
+ local adder = vim.func._memoize('concat-2', function(arg1, arg2)
+ _G.count = _G.count + 1
+ return arg1 + arg2
+ end, false)
+
+ adder(3, -4)
+ collectgarbage()
+ adder(3, -4)
+ collectgarbage()
+ adder(3, -4)
+ ]])
+
+ eq(1, exec_lua([[return _G.count]]))
+ end)
+
+ it('can clear a single cache entry', function()
+ exec_lua([[
+ _G.count = 0
+
+ local adder = vim.func._memoize(function(arg1, arg2)
+ return tostring(arg1) .. '%%' .. tostring(arg2)
+ end, function(arg1, arg2)
+ _G.count = _G.count + 1
+ return arg1 + arg2
+ end)
+
+ collectgarbage('stop')
+ adder(3, -4)
+ adder(3, -4)
+ adder(3, -4)
+ adder(3, -4)
+ adder(3, -4)
+ adder:clear(3, -4)
+ adder(3, -4)
+ collectgarbage('restart')
+ ]])
+
+ eq(2, exec_lua([[return _G.count]]))
+ end)
+
+ it('can clear the entire cache', function()
+ exec_lua([[
+ _G.count = 0
+
+ local adder = vim.func._memoize(function(arg1, arg2)
+ return tostring(arg1) .. '%%' .. tostring(arg2)
+ end, function(arg1, arg2)
+ _G.count = _G.count + 1
+ return arg1 + arg2
+ end)
+
+ collectgarbage('stop')
+ adder(1, 2)
+ adder(3, -4)
+ adder(1, 2)
+ adder(3, -4)
+ adder(1, 2)
+ adder(3, -4)
+ adder:clear()
+ adder(1, 2)
+ adder(3, -4)
+ collectgarbage('restart')
+ ]])
+
+ eq(4, exec_lua([[return _G.count]]))
+ end)
+
+ it('can cache functions that return nil', function()
+ exec_lua([[
+ _G.count = 0
+
+ local adder = vim.func._memoize('concat', function(arg1, arg2)
+ _G.count = _G.count + 1
+ return nil
+ end)
+
+ collectgarbage('stop')
+ adder(1, 2)
+ adder(1, 2)
+ adder(1, 2)
+ adder(1, 2)
+ adder:clear()
+ adder(1, 2)
+ collectgarbage('restart')
+ ]])
+
+ eq(2, exec_lua([[return _G.count]]))
+ end)
+end)