diff --git a/library/lua/dfhack.lua b/library/lua/dfhack.lua index 431c65b3a..ea6cc3685 100644 --- a/library/lua/dfhack.lua +++ b/library/lua/dfhack.lua @@ -162,11 +162,33 @@ NEWLINE = "\n" COMMA = "," PERIOD = "." +local function _wrap_iterator(next_fn, ...) + local wrapped_iter = function(...) + local ret = {pcall(next_fn, ...)} + local ok = table.remove(ret, 1) + if ok then + return table.unpack(ret) + end + end + return wrapped_iter, ... +end + +function safe_pairs(t, iterator_fn) + iterator_fn = iterator_fn or pairs + if (pcall(pairs, t)) then + return _wrap_iterator(iterator_fn(t)) + else + return function() end + end +end + -- calls elem_cb(k, v) for each element of the table -- returns true if we iterated successfully, false if not -local function safe_iterate(table, iterate_fn, elem_cb) +-- this differs from safe_pairs() above in that it only calls pcall() once per +-- full iteration and it returns whether iteration succeeded or failed. +local function safe_iterate(table, iterator_fn, elem_cb) local function iterate() - for k,v in iterate_fn(table) do elem_cb(k, v) end + for k,v in iterator_fn(table) do elem_cb(k, v) end end return pcall(iterate) end diff --git a/test/library/misc.lua b/test/library/misc.lua new file mode 100644 index 000000000..bd9e5fc5b --- /dev/null +++ b/test/library/misc.lua @@ -0,0 +1,21 @@ +-- tests misc functions added by dfhack.lua + +function test.safe_pairs() + for k,v in safe_pairs(nil) do + expect.fail('nil should not be iterable') + end + for k,v in safe_pairs('a') do + expect.fail('a string should not be iterable') + end + for k,v in safe_pairs({}) do + expect.fail('an empty table should not be iterable') + end + + local iterated = 0 + local t = {a='hello', b='world', [1]='extra'} + for k,v in safe_pairs(t) do + expect.eq(t[k], v) + iterated = iterated + 1 + end + expect.eq(3, iterated) +end diff --git a/test/library/print.lua b/test/library/print.lua index b5e7c9004..179b70784 100644 --- a/test/library/print.lua +++ b/test/library/print.lua @@ -19,32 +19,27 @@ end function test.printall_string() printall('a') - expect.eq(1, mock_print.call_count) - expect.eq('a', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_number() printall(10) - expect.eq(1, mock_print.call_count) - expect.eq('10', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_nil() printall(nil) - expect.eq(1, mock_print.call_count) - expect.eq('nil', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_boolean() printall(false) - expect.eq(1, mock_print.call_count) - expect.eq('false', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_function() printall(function() end) - expect.eq(1, mock_print.call_count) - expect.true_(mock_print.call_args[1][1]:find('^function: 0x')) + expect.eq(0, mock_print.call_count) end function test.printall_userdata() @@ -60,32 +55,27 @@ end function test.printall_ipairs_string() printall_ipairs('a') - expect.eq(1, mock_print.call_count) - expect.eq('a', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_ipairs_number() printall_ipairs(10) - expect.eq(1, mock_print.call_count) - expect.eq('10', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_ipairs_nil() printall_ipairs(nil) - expect.eq(1, mock_print.call_count) - expect.eq('nil', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_ipairs_boolean() printall_ipairs(false) - expect.eq(1, mock_print.call_count) - expect.eq('false', mock_print.call_args[1][1]) + expect.eq(0, mock_print.call_count) end function test.printall_ipairs_function() printall_ipairs(function() end) - expect.eq(1, mock_print.call_count) - expect.true_(mock_print.call_args[1][1]:find('^function: 0x')) + expect.eq(0, mock_print.call_count) end function test.printall_ipairs_userdata()