diff --git a/library/lua/test_util/mock.lua b/library/lua/test_util/mock.lua index c60646b77..8d253cc10 100644 --- a/library/lua/test_util/mock.lua +++ b/library/lua/test_util/mock.lua @@ -32,12 +32,17 @@ function _patch_impl(patches_raw, callback, restore_only) end --[[ + +Replaces `table[key]` with `value`, calls `callback()`, then restores the +original value of `table[key]`. + Usage: patch(table, key, value, callback) patch({ {table, key, value}, {table2, key2, value2}, }, callback) + ]] function mock.patch(...) local args = {...} @@ -57,12 +62,18 @@ function mock.patch(...) end --[[ + +Restores the original value of `table[key]` after calling `callback()`. + +Equivalent to: patch(table, key, table[key], callback) + Usage: restore(table, key, callback) restore({ {table, key}, {table2, key2}, }, callback) + ]] function mock.restore(...) local args = {...} @@ -81,9 +92,19 @@ function mock.restore(...) return _patch_impl(patches, callback, true) end -function mock.func(...) +--[[ + +Returns a callable object that tracks the arguments it is called with, then +passes those arguments to `callback()`. + +The returned object has the following properties: +- `call_count`: the number of times the object has been called +- `call_args`: a table of function arguments (shallow-copied) corresponding + to each time the object was called + +]] +function mock.observe_func(callback) local f = { - return_values = {...}, call_count = 0, call_args = {}, } @@ -101,11 +122,36 @@ function mock.func(...) end end table.insert(self.call_args, args) - return table.unpack(self.return_values) + return callback(...) end, }) return f end +--[[ + +Returns a callable object similar to `mock.observe_func()`, but which when +called, only returns the given `return_value`(s) with no additional side effects. + +Intended for use by `patch()`. + +Usage: + func(return_value [, return_value2 ...]) + +See `observe_func()` for a description of the return value. + +The return value also has an additional `return_values` field, which is a table +of values returned when the object is called. This can be modified. + +]] +function mock.func(...) + local f + f = mock.observe_func(function() + return table.unpack(f.return_values) + end) + f.return_values = {...} + return f +end + return mock diff --git a/test/library/test_util/mock.lua b/test/library/test_util/mock.lua index 32aed28e1..1031a496a 100644 --- a/test/library/test_util/mock.lua +++ b/test/library/test_util/mock.lua @@ -208,9 +208,44 @@ function test.func_call_return_value() end function test.func_call_return_multiple_values() - local f = mock.func(7,5,{imatable='snarfsnarf'}) + local f = mock.func(7, 5, {imatable='snarfsnarf'}) local a, b, c = f() expect.eq(7, a) expect.eq(5, b) expect.table_eq({imatable='snarfsnarf'}, c) end + +function test.observe_func() + -- basic end-to-end test for common cases; + -- most edge cases are covered by mock.func() tests + local counter = 0 + local function target() + counter = counter + 1 + return counter + end + local observer = mock.observe_func(target) + + expect.eq(observer(), 1) + expect.eq(counter, 1) + expect.eq(observer.call_count, 1) + expect.table_eq(observer.call_args, {{}}) + + expect.eq(observer('x', 'y'), 2) + expect.eq(counter, 2) + expect.eq(observer.call_count, 2) + expect.table_eq(observer.call_args, {{}, {'x', 'y'}}) +end + +function test.observe_func_error() + local function target() + error('asdf') + end + local observer = mock.observe_func(target) + + expect.error_match('asdf', function() + observer('x') + end) + -- make sure the call was still tracked + expect.eq(observer.call_count, 1) + expect.table_eq(observer.call_args, {{'x'}}) +end