local mock = mkmodule('test_util.mock')

--[[

Internal helper shared by `mock.patch()` and `mock.restore()` (file-local so it
does not leak into the global environment).

`patches_raw` is a list of `{table, key[, new_value]}` entries.
- First, the current value of every `table[key]` is recorded.
- Then, unless `restore_only` is truthy, each `table[key]` is set to its
  `new_value`, and `callback()` is called.
- The recorded values are written back by the finalizer passed to
  `dfhack.with_finalize()`.

Returns the result of `dfhack.with_finalize()`, i.e. what `callback()` returns.

]]
local function _patch_impl(patches_raw, callback, restore_only)
    local patches = {}
    for _, v in ipairs(patches_raw) do
        local p = {
            table = v[1],
            key = v[2],
            new_value = v[3],
        }
        p.old_value = p.table[p.key]
        -- no-op write: if the field cannot be assigned, fail here, before any
        -- patch has been applied, so the finalizer below can always restore it
        p.table[p.key] = p.old_value
        patches[#patches + 1] = p
    end

    return dfhack.with_finalize(
        function()
            -- undo in reverse order of application (LIFO)
            for i = #patches, 1, -1 do
                local p = patches[i]
                p.table[p.key] = p.old_value
            end
        end,
        function()
            if not restore_only then
                for _, p in ipairs(patches) do
                    p.table[p.key] = p.new_value
                end
            end
            return callback()
        end
    )
end

--[[

Replaces `table[key]` with `value`, calls `callback()`, then restores the
original value of `table[key]`. Returns the result of `callback()`.

Usage:
    patch(table, key, value, callback)
    patch({
        {table, key, value},
        {table2, key2, value2},
    }, callback)

`value` may be nil (e.g. to temporarily remove a field). The arguments are
counted with `select('#', ...)`, so a nil `value` is not mistaken for the
2-argument form.

]]
function mock.patch(...)
    local nargs = select('#', ...)
    local patches, callback
    if nargs == 4 then
        local tbl, key, value
        tbl, key, value, callback = ...
        patches = {{tbl, key, value}}
    elseif nargs == 2 then
        patches, callback = ...
    else
        -- level 2: report the bad call at the caller's location
        error('expected 2 or 4 arguments', 2)
    end

    return _patch_impl(patches, callback)
end

--[[

Restores the original value of `table[key]` after calling `callback()`.
Returns the result of `callback()`.

Equivalent to: patch(table, key, table[key], callback)

Usage:
    restore(table, key, callback)
    restore({
        {table, key},
        {table2, key2},
    }, callback)

]]
function mock.restore(...)
    -- select('#', ...) gives the true argument count even when some arguments
    -- are nil; `#{...}` is unspecified for tables with holes
    local nargs = select('#', ...)
    local patches, callback
    if nargs == 3 then
        local tbl, key
        tbl, key, callback = ...
        patches = {{tbl, key}}
    elseif nargs == 2 then
        patches, callback = ...
    else
        -- level 2: report the bad call at the caller's location
        error('expected 2 or 3 arguments', 2)
    end

    return _patch_impl(patches, callback, true)
end

--[[

Returns a callable object that tracks the arguments it is called with, then
passes those arguments to `callback()` and returns its result.

The returned object has the following properties:
- `call_count`: the number of times the object has been called
- `call_args`: a table of function arguments corresponding to each time the
    object was called. Table arguments are shallow-copied (with `copyall()`) at
    every argument position, including positions after a nil argument. Each
    entry is built with `{...}`, so trailing nil arguments are not recorded.

]]
function mock.observe_func(callback)
    local f = {
        call_count = 0,
        call_args = {},
    }

    setmetatable(f, {
        __call = function(self, ...)
            self.call_count = self.call_count + 1
            local args = {...}
            -- iterate up to select('#', ...) rather than using ipairs(), which
            -- stops at the first nil argument and would leave later tables
            -- uncopied
            for i = 1, select('#', ...) do
                local v = args[i]
                if type(v) == 'table' then
                    -- just a shallow copy, but it offers some ability to
                    -- inspect original values in tables that were altered after
                    -- the call
                    args[i] = copyall(v)
                end
            end
            table.insert(self.call_args, args)
            return callback(...)
        end,
    })

    return f
end

--[[

Creates a stub callable built on `mock.observe_func()`. Calling the stub records
the call (`call_count`, `call_args`) and then returns the values stored in its
`return_values` field. The call arguments have no influence on the result.

Intended for use by `patch()`.

Usage:
    func(return_value [, return_value2 ...])

The stub's `return_values` field starts out as a table of the arguments passed
to `func()`. It is read on every call, so it may be replaced or edited to change
what later calls return.

]]
function mock.func(...)
    local initial_values = {...}
    local stub

    local function produce_return_values()
        return table.unpack(stub.return_values)
    end

    stub = mock.observe_func(produce_return_values)
    stub.return_values = initial_values
    return stub
end

return mock