diff --git a/library/lua/test_util/mock.lua b/library/lua/test_util/mock.lua index 0d6eae546..2e86cbe38 100644 --- a/library/lua/test_util/mock.lua +++ b/library/lua/test_util/mock.lua @@ -1,27 +1,8 @@ local mock = mkmodule('test_util.mock') ---[[ -Usage: - patch(table, key, value, callback) - patch({ - {table, key, value}, - {table2, key2, value2} - }, callback) -]] -function mock.patch(...) - local args = {...} - if #args == 4 then - args = {{ - {args[1], args[2], args[3]} - }, args[4]} - end - if #args ~= 2 then - error('expected 2 or 4 arguments') - end - - local callback = args[2] +function _patch_impl(patches_raw, callback, restore_only) local patches = {} - for _, v in ipairs(args[1]) do + for _, v in ipairs(patches_raw) do local p = { table = v[1], key = v[2], @@ -40,14 +21,66 @@ function mock.patch(...) end end, function() - for _, p in ipairs(patches) do - p.table[p.key] = p.new_value + if not restore_only then + for _, p in ipairs(patches) do + p.table[p.key] = p.new_value + end end return callback() end ) end +--[[ +Usage: + patch(table, key, value, callback) + patch({ + {table, key, value}, + {table2, key2, value2}, + }, callback) +]] +function mock.patch(...) + local args = {...} + local patches + local callback + if #args == 4 then + patches = {{args[1], args[2], args[3]}} + callback = args[4] + elseif #args == 2 then + patches = args[1] + callback = args[2] + else + error('expected 2 or 4 arguments') + end + + return _patch_impl(patches, callback) +end + +--[[ +Usage: + restore(table, key, callback) + restore({ + {table, key}, + {table2, key2}, + }, callback) +]] +function mock.restore(...) + local args = {...} + local patches + local callback + if #args == 3 then + patches = {{args[1], args[2]}} + callback = args[3] + elseif #args == 2 then + patches = args[1] + callback = args[2] + else + error('expected 2 or 3 arguments') + end + + return _patch_impl(patches, callback, true) +end + function mock.func(return_value) local f = { return_value = return_value, diff --git a/test/library/test_util/mock.lua b/test/library/test_util/mock.lua index 584f6c188..a605abc8a 100644 --- a/test/library/test_util/mock.lua +++ b/test/library/test_util/mock.lua @@ -99,6 +99,78 @@ function test.patch_callback_return_value() expect.eq(b, 4) end +function test.patch_invalid_value() + dfhack.with_temp_object(df.new('int8_t'), function(i) + i.value = 1 + local called = false + expect.error_match('integer expected', function() + mock.patch(i, 'value', 2, function() + expect.eq(i.value, 2) + called = true + i.value = 'a' + end) + end) + expect.true_(called) + expect.eq(i.value, 1) + end) +end + +function test.patch_invalid_value_initial() + dfhack.with_temp_object(df.new('int8_t'), function(i) + i.value = 1 + expect.error_match('integer expected', function() + mock.patch(i, 'value', 'a', function() + expect.fail('patch() callback called unexpectedly') + end) + end) + expect.eq(i.value, 1) + end) +end + +function test.patch_invalid_value_initial_multiple() + dfhack.with_temp_object(df.new('int8_t', 2), function(i) + i[0] = 1 + i[1] = 2 + expect.error_match('integer expected', function() + mock.patch({ + {i, 0, 3}, + {i, 1, 'a'}, + }, function() + expect.fail('patch() callback called unexpectedly') + end) + end) + expect.eq(i[0], 1) + expect.eq(i[1], 2) + end) +end + +function test.restore_single() + local t = {k = 1} + mock.restore(t, 'k', function() + expect.eq(t.k, 1) + t.k = 2 + expect.eq(t.k, 2) + end) + expect.eq(t.k, 1) +end + +function test.restore_multiple() + local t = {a = 1, b = 2} + mock.restore({ + {t, 'a'}, + {t, 'b'}, + }, function() + expect.eq(t.a, 1) + expect.eq(t.b, 2) + t.a = 3 + t.b = 4 + expect.eq(t.a, 3) + expect.eq(t.b, 4) + end) + expect.eq(t.a, 1) + expect.eq(t.b, 2) +end + function test.func_call_count() local f = mock.func() expect.eq(f.call_count, 0)