Add mock.observe_func(), improve mock.lua documentation

observe_func() is similar to func() but passes through all calls to a specified
function.
develop
lethosor 2022-05-27 00:33:44 -04:00
parent c4febc789a
commit 888c531774
No known key found for this signature in database
GPG Key ID: 76A269552F4F58C1
2 changed files with 85 additions and 4 deletions

@ -32,12 +32,17 @@ function _patch_impl(patches_raw, callback, restore_only)
end end
--[[ --[[
Replaces `table[key]` with `value`, calls `callback()`, then restores the
original value of `table[key]`.
Usage: Usage:
patch(table, key, value, callback) patch(table, key, value, callback)
patch({ patch({
{table, key, value}, {table, key, value},
{table2, key2, value2}, {table2, key2, value2},
}, callback) }, callback)
]] ]]
function mock.patch(...) function mock.patch(...)
local args = {...} local args = {...}
@ -57,12 +62,18 @@ function mock.patch(...)
end end
--[[ --[[
Restores the original value of `table[key]` after calling `callback()`.
Equivalent to: patch(table, key, table[key], callback)
Usage: Usage:
restore(table, key, callback) restore(table, key, callback)
restore({ restore({
{table, key}, {table, key},
{table2, key2}, {table2, key2},
}, callback) }, callback)
]] ]]
function mock.restore(...) function mock.restore(...)
local args = {...} local args = {...}
@ -81,9 +92,19 @@ function mock.restore(...)
return _patch_impl(patches, callback, true) return _patch_impl(patches, callback, true)
end end
function mock.func(...) --[[
Returns a callable object that tracks the arguments it is called with, then
passes those arguments to `callback()`.
The returned object has the following properties:
- `call_count`: the number of times the object has been called
- `call_args`: a table of function arguments (shallow-copied) corresponding
to each time the object was called
]]
function mock.observe_func(callback)
local f = { local f = {
return_values = {...},
call_count = 0, call_count = 0,
call_args = {}, call_args = {},
} }
@ -101,11 +122,36 @@ function mock.func(...)
end end
end end
table.insert(self.call_args, args) table.insert(self.call_args, args)
return table.unpack(self.return_values) return callback(...)
end, end,
}) })
return f return f
end end
--[[
Returns a callable object similar to `mock.observe_func()`, but which when
called, only returns the given `return_value`(s) with no additional side effects.
Intended for use by `patch()`.
Usage:
func(return_value [, return_value2 ...])
See `observe_func()` for a description of the return value.
The return value also has an additional `return_values` field, which is a table
of values returned when the object is called. This can be modified.
]]
function mock.func(...)
local f
f = mock.observe_func(function()
return table.unpack(f.return_values)
end)
f.return_values = {...}
return f
end
return mock return mock

@ -208,9 +208,44 @@ function test.func_call_return_value()
end end
function test.func_call_return_multiple_values() function test.func_call_return_multiple_values()
local f = mock.func(7,5,{imatable='snarfsnarf'}) local f = mock.func(7, 5, {imatable='snarfsnarf'})
local a, b, c = f() local a, b, c = f()
expect.eq(7, a) expect.eq(7, a)
expect.eq(5, b) expect.eq(5, b)
expect.table_eq({imatable='snarfsnarf'}, c) expect.table_eq({imatable='snarfsnarf'}, c)
end end
function test.observe_func()
-- basic end-to-end test for common cases;
-- most edge cases are covered by mock.func() tests
local counter = 0
local function target()
counter = counter + 1
return counter
end
local observer = mock.observe_func(target)
expect.eq(observer(), 1)
expect.eq(counter, 1)
expect.eq(observer.call_count, 1)
expect.table_eq(observer.call_args, {{}})
expect.eq(observer('x', 'y'), 2)
expect.eq(counter, 2)
expect.eq(observer.call_count, 2)
expect.table_eq(observer.call_args, {{}, {'x', 'y'}})
end
function test.observe_func_error()
local function target()
error('asdf')
end
local observer = mock.observe_func(target)
expect.error_match('asdf', function()
observer('x')
end)
-- make sure the call was still tracked
expect.eq(observer.call_count, 1)
expect.table_eq(observer.call_args, {{'x'}})
end