-- dfhack/library/lua/test_util/mock.lua
--
-- 158 lines
-- 3.8 KiB
-- Lua

local mock = mkmodule('test_util.mock')
--[[
Shared implementation behind `mock.patch()` and `mock.restore()`.

`patches_raw` is a list of `{table, key, new_value}` entries. The current value
of every `table[key]` is recorded before anything is modified. `callback()` is
then run through `dfhack.with_finalize()`; the first function handed to it (the
finalizer) writes every recorded value back to its `table[key]`.

Unless `restore_only` is truthy, each `new_value` is assigned to its
`table[key]` right before `callback()` is invoked.

Returns whatever `dfhack.with_finalize()` returns.
]]
function _patch_impl(patches_raw, callback, restore_only)
    local patches = {}
    for _, v in ipairs(patches_raw) do
        local p = {
            table = v[1],
            key = v[2],
            new_value = v[3],
        }
        p.old_value = p.table[p.key]
        -- no-op to ensure that the value can be restored by the finalizer
        -- below: if this write is rejected, we fail here, before any patch
        -- has been applied
        p.table[p.key] = p.old_value
        table.insert(patches, p)
    end

    return dfhack.with_finalize(
        -- finalizer: put every original value back
        function()
            for _, p in ipairs(patches) do
                p.table[p.key] = p.old_value
            end
        end,
        -- body: apply the patches (unless only restoring), then run callback
        function()
            if not restore_only then
                for _, p in ipairs(patches) do
                    p.table[p.key] = p.new_value
                end
            end
            return callback()
        end
    )
end
--[[
Replaces `table[key]` with `value`, calls `callback()`, then restores the
original value of `table[key]`.

`value` may be nil in the 4-argument form.

Usage:
    patch(table, key, value, callback)
    patch({
        {table, key, value},
        {table2, key2, value2},
    }, callback)

Returns the result of the underlying `_patch_impl()` call.
Raises an error if called with a number of arguments other than 2 or 4.
]]
function mock.patch(...)
    -- Count with select('#', ...) instead of #{...}: a nil `value` leaves a
    -- hole in the packed table, and the length of a table with holes is not
    -- well defined.
    local nargs = select('#', ...)
    local patches, callback
    if nargs == 4 then
        local tbl, key, value
        tbl, key, value, callback = ...
        patches = {{tbl, key, value}}
    elseif nargs == 2 then
        patches, callback = ...
    else
        error('expected 2 or 4 arguments')
    end
    return _patch_impl(patches, callback)
end
--[[
Runs `callback()` and afterwards puts `table[key]` back to the value it held
before the call. Nothing is replaced up front; only the restoration happens.
Equivalent to: patch(table, key, table[key], callback)

Usage:
    restore(table, key, callback)
    restore({
        {table, key},
        {table2, key2},
    }, callback)

Raises an error unless called with 2 or 3 arguments.
]]
function mock.restore(...)
    local argv = {...}
    local argc = #argv
    if argc ~= 2 and argc ~= 3 then
        error('expected 2 or 3 arguments')
    end

    local targets, callback
    if argc == 3 then
        -- single (table, key) pair given inline
        targets = {{argv[1], argv[2]}}
        callback = argv[3]
    else
        -- list of {table, key} pairs
        targets, callback = argv[1], argv[2]
    end

    -- third argument: restore only, do not assign any new values
    return _patch_impl(targets, callback, true)
end
--[[
Returns a callable object that tracks the arguments it is called with, then
passes those arguments to `callback()` and returns its result(s).

The returned object has the following properties:
- `call_count`: the number of times the object has been called
- `call_args`: a table of function arguments (shallow-copied) corresponding
  to each time the object was called
]]
function mock.observe_func(callback)
    local f = {
        call_count = 0,
        call_args = {},
    }
    setmetatable(f, {
        __call = function(self, ...)
            self.call_count = self.call_count + 1
            local args = {...}
            -- Walk the arguments by explicit count rather than ipairs():
            -- ipairs stops at the first nil, which would leave any table
            -- argument after a nil recorded by reference instead of copied.
            for i = 1, select('#', ...) do
                local v = args[i]
                if type(v) == 'table' then
                    -- just a shallow copy, but it offers some ability to
                    -- inspect original values in tables that were altered
                    -- after the call
                    args[i] = copyall(v)
                end
            end
            table.insert(self.call_args, args)
            -- the callback receives the original (uncopied) arguments
            return callback(...)
        end,
    })
    return f
end
--[[
Builds a callable object like the one from `mock.observe_func()`, except that
calling it does nothing beyond handing back the given `return_value`(s).
Intended for use by `patch()`.

Usage:
    func(return_value [, return_value2 ...])

See `observe_func()` for a description of the return value.
In addition, the returned object carries a `return_values` field: the table of
values that are returned when the object is called. This can be modified.
]]
function mock.func(...)
    -- declared ahead of the closure so the closure can see the finished object
    local tracker

    -- looks up `return_values` at call time, so later edits to it take effect
    local function replay()
        return table.unpack(tracker.return_values)
    end

    tracker = mock.observe_func(replay)
    tracker.return_values = {...}
    return tracker
end
return mock