@ -101,25 +101,130 @@ function expect.table_eq(a, b, comment)
( ' key %s: "%s" ~= "%s" ' ) : format ( keystr , diff [ 1 ] , diff [ 2 ] )
end
function expect . error ( func , ... )
local ok , ret = pcall ( func , ... )
function expect . error ( func , comment )
local ok = pcall ( func )
if ok then
return false , ' no error raised by function call '
return false , comment , ' no error raised by function call '
else
return true
end
end
function expect . error_match ( func , matcher , ... )
local ok , err = pcall ( func , ... )
if ok then
return false, ' no error raised by function call '
local function matches ( obj , matcher )
if not matcher then return false end
if type( matcher ) == ' boolean ' then
return true
elseif type ( matcher ) == ' string ' then
if not tostring ( err ) : match ( matcher ) then
return false , ( ' error "%s" did not match "%s" ' ) : format ( err , matcher )
return tostring ( obj ) : match ( matcher )
elseif type ( matcher ) == ' function ' then
return matcher ( obj )
end
return false
end
-- matches errors thrown from the specified function. the check passes if an
-- error is thrown and the thrown error matches the specified matcher.
--
-- matcher can be:
-- a string, interpreted as a lua pattern that matches the error message
-- a function that takes an err object and returns a boolean (true means match)
-- the literal value true, which matches any thrown error
function expect . error_match ( matcher , func , comment )
local ok , err = pcall ( func )
if ok then
return false , comment , ' no error raised by function call '
end
if matches ( err , matcher ) then return true end
local matcher_str = ' '
if type ( matcher ) == ' string ' then
matcher_str = ( ' : "%s" ' ) : format ( matcher )
end
return false ,
( ' error "%s" did not satisfy matcher%s ' ) : format ( err , matcher_str )
end
-- matches error messages output from dfhack.printerr() when the specified
-- callback is run. the check passes if all printerr messages are matched by
-- specified matchers and no matchers remain unmatched.
--
-- matcher can be:
-- a string, interpreted as a lua pattern that matches a message
-- a function that takes the string message and returns a boolean (true means
-- match)
-- the literal value true, which matches any message
-- the literal value false, nil, or an empty table, which match the absence of
-- printerr messages
-- a populated table that can be used to match multiple messages (explained
-- in more detail below)
--
-- if matcher is a table, it can contain:
-- a list of strings, literal true values, and/or functions which will be
-- matched in order
-- a map of strings, literal true values, and/or functions to positive
-- integers, which will be matched (in any order) the number of times
-- specified
--
-- when this function attempts to match a message, it will first try the next
-- matcher in the list (that is, the next numeric key). if that matcher doesn't
-- exist or doesn't match, it will try all string and function keys whose values
-- are numeric and > 0. if none of those match, it will check for a key equal to
-- true with a value > 0.
--
-- if a mapped matcher is matched, it will have its value decremented by 1.
function expect . printerr_match ( matcher , func , comment )
local saved_printerr = dfhack.printerr
local messages = { }
dfhack.printerr = function ( msg ) table.insert ( messages , msg ) end
dfhack.with_finalize (
function ( ) dfhack.printerr = saved_printerr end ,
func )
if not matcher then
local num_messages = # messages
if num_messages == 0 then return true end
return false , comment , ( ' expected 0 calls to dfhack.printerr but got ' ..
' %d ' ) : format ( num_messages )
end
if type ( matcher ) ~= ' table ' then matcher = { matcher } end
local true_count = matcher [ true ] or 0
matcher [ true ] = nil
for _ , msg in ipairs ( messages ) do
local m = matcher [ 1 ]
if matches ( msg , m ) then
table.remove ( matcher , 1 )
goto continue
end
elseif not matcher ( err ) then
return false , ( ' error "%s" did not satisfy matcher ' ) : format ( err )
for k , v in pairs ( matcher ) do
if type ( v ) == ' number ' and v > 0 and matches ( msg , k ) then
local remaining = v - 1
if v == 0 then
matcher [ k ] = nil
else
matcher [ k ] = remaining
end
goto continue
end
end
if true_count > 0 then
true_count = true_count - 1
goto continue
end
return false , comment , ( ' unmatched printerr message: "%s" ' ) : format ( msg )
:: continue ::
end
local extra_matchers = { }
for k , v in ipairs ( matcher ) do
table.insert ( extra_matchers , ( ' "%s" ' ) : format ( v ) )
matcher [ k ] = nil
end
for k , v in pairs ( matcher ) do
table.insert ( extra_matchers , ( ' "%s"=%d ' ) : format ( k , v ) )
end
if true_count > 0 then
table.insert ( extra_matchers , ( ' true=%d ' ) : format ( true_count ) )
end
if # extra_matchers > 0 then
return false , comment , ( ' unmatched or invalid matchers: %s ' ) : format (
table.concat ( extra_matchers , ' , ' ) )
end
return true
end