diff --git a/lua/lazyvim/util/init.lua b/lua/lazyvim/util/init.lua index fb1c9606..4a1478e6 100644 --- a/lua/lazyvim/util/init.lua +++ b/lua/lazyvim/util/init.lua @@ -269,19 +269,18 @@ for _, level in ipairs({ "info", "warn", "error" }) do end end -local cache = {} ---@type table +local cache = {} ---@type table<(fun()), table> ---@generic T: fun() ---@param fn T ---@return T function M.memoize(fn) - local info = debug.getinfo(fn, "S") - local keyprefix = info.source .. ":" .. info.linedefined .. ":" return function(...) - local key = keyprefix .. vim.inspect({ ... }) - if cache[key] == nil then - cache[key] = fn(...) + local key = vim.inspect({ ... }) + cache[fn] = cache[fn] or {} + if cache[fn][key] == nil then + cache[fn][key] = fn(...) end - return cache[key] + return cache[fn][key] end end diff --git a/tests/util/util_spec.lua b/tests/util/util_spec.lua new file mode 100644 index 00000000..0839eefe --- /dev/null +++ b/tests/util/util_spec.lua @@ -0,0 +1,36 @@ +---@module "luassert" + +local LazyVim = require("lazyvim.util") + +describe("util", function() + local t = 0 + local fn = function(a) + t = t + 1 + return a + end + + local m = LazyVim.memoize(fn) + + it("should memoize a function", function() + local a = m(1) + local b = m(1) + local c = m(2) + assert.are.equal(a, b) + assert.are.equal(a, 1) + assert.are.equal(c, 2) + assert.are.equal(t, 2) + assert.are_not.equal(a, c) + end) + + local f1 = LazyVim.memoize(function() + return 1 + end) + local f2 = LazyVim.memoize(function() + return 2 + end) + + it("should memoize based on the correct key", function() + assert.are.equal(f1(), 1) + assert.are.equal(f2(), 2) + end) +end)