add safe_pairs, update unit tests

develop
myk002 2021-08-14 17:36:38 -07:00
parent 6d0f7e40a9
commit d9c6c2dde3
No known key found for this signature in database
GPG Key ID: 8A39CA0FA0C16E78
3 changed files with 55 additions and 22 deletions

@ -162,11 +162,33 @@ NEWLINE = "\n"
COMMA = "," COMMA = ","
PERIOD = "." PERIOD = "."
local function _wrap_iterator(next_fn, ...)
local wrapped_iter = function(...)
local ret = {pcall(next_fn, ...)}
local ok = table.remove(ret, 1)
if ok then
return table.unpack(ret)
end
end
return wrapped_iter, ...
end
function safe_pairs(t, iterator_fn)
iterator_fn = iterator_fn or pairs
if (pcall(pairs, t)) then
return _wrap_iterator(iterator_fn(t))
else
return function() end
end
end
-- calls elem_cb(k, v) for each element of the table -- calls elem_cb(k, v) for each element of the table
-- returns true if we iterated successfully, false if not -- returns true if we iterated successfully, false if not
local function safe_iterate(table, iterate_fn, elem_cb) -- this differs from safe_pairs() above in that it only calls pcall() once per
-- full iteration and it returns whether iteration succeeded or failed.
local function safe_iterate(table, iterator_fn, elem_cb)
local function iterate() local function iterate()
for k,v in iterate_fn(table) do elem_cb(k, v) end for k,v in iterator_fn(table) do elem_cb(k, v) end
end end
return pcall(iterate) return pcall(iterate)
end end

@ -0,0 +1,21 @@
-- tests misc functions added by dfhack.lua
function test.safe_pairs()
for k,v in safe_pairs(nil) do
expect.fail('nil should not be iterable')
end
for k,v in safe_pairs('a') do
expect.fail('a string should not be iterable')
end
for k,v in safe_pairs({}) do
expect.fail('an empty table should not be iterable')
end
local iterated = 0
local t = {a='hello', b='world', [1]='extra'}
for k,v in safe_pairs(t) do
expect.eq(t[k], v)
iterated = iterated + 1
end
expect.eq(3, iterated)
end

@ -19,32 +19,27 @@ end
function test.printall_string() function test.printall_string()
printall('a') printall('a')
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('a', mock_print.call_args[1][1])
end end
function test.printall_number() function test.printall_number()
printall(10) printall(10)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('10', mock_print.call_args[1][1])
end end
function test.printall_nil() function test.printall_nil()
printall(nil) printall(nil)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('nil', mock_print.call_args[1][1])
end end
function test.printall_boolean() function test.printall_boolean()
printall(false) printall(false)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('false', mock_print.call_args[1][1])
end end
function test.printall_function() function test.printall_function()
printall(function() end) printall(function() end)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.true_(mock_print.call_args[1][1]:find('^function: 0x'))
end end
function test.printall_userdata() function test.printall_userdata()
@ -60,32 +55,27 @@ end
function test.printall_ipairs_string() function test.printall_ipairs_string()
printall_ipairs('a') printall_ipairs('a')
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('a', mock_print.call_args[1][1])
end end
function test.printall_ipairs_number() function test.printall_ipairs_number()
printall_ipairs(10) printall_ipairs(10)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('10', mock_print.call_args[1][1])
end end
function test.printall_ipairs_nil() function test.printall_ipairs_nil()
printall_ipairs(nil) printall_ipairs(nil)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('nil', mock_print.call_args[1][1])
end end
function test.printall_ipairs_boolean() function test.printall_ipairs_boolean()
printall_ipairs(false) printall_ipairs(false)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.eq('false', mock_print.call_args[1][1])
end end
function test.printall_ipairs_function() function test.printall_ipairs_function()
printall_ipairs(function() end) printall_ipairs(function() end)
expect.eq(1, mock_print.call_count) expect.eq(0, mock_print.call_count)
expect.true_(mock_print.call_args[1][1]:find('^function: 0x'))
end end
function test.printall_ipairs_userdata() function test.printall_ipairs_userdata()