Adds cxxrandom unit test and fixes interface problems (#2099)

* Adds cxxrandom unit test and fixes interface problems

* Tightens braces

* Adds detection code for Shuffle's seqID/engID

* Adds usage examples for cxxrandom

* Gives cxxrandom objects id ranges, sort of

* Updates changelog

* Updates changelog.txt

* Increases id space for cxxrandom

* Fixes bool distribution error message and improves check

* Adds comment explaining the seeded RNG tests for cxxrandom

* Fixes type problem for 32bit builds

* Reduces loop count a few magnitudes

* Fixes a mistake in test.cxxrandom_seed
develop
Josh Cooper 2022-04-30 21:46:47 -07:00 committed by GitHub
parent 9643246b18
commit fe29bff845
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 202 additions and 114 deletions

@ -4401,7 +4401,7 @@ Native functions (exported to Lua)
adds a number to the sequence
- ``ShuffleSequence(rngID, seqID)``
- ``ShuffleSequence(seqID, rngID)``
shuffles the number sequence
@ -4464,7 +4464,7 @@ Lua plugin classes
``bool_distribution``
~~~~~~~~~~~~~~~~~~~~~
- ``init(min, max)``: constructor
- ``init(chance)``: constructor
- ``next(id)``: returns next boolean in the distribution
- ``id``: engine ID to pass to native function
@ -4477,6 +4477,41 @@ Lua plugin classes
- ``shuffle()``: shuffles the sequence of numbers
- ``next()``: returns next number in the sequence
Usage
-----
The basic idea is you create a number distribution which you generate random numbers along. The C++ relies
on engines keeping state information to determine the next number along the distribution.
You're welcome to try and (ab)use this knowledge for your RNG purposes.
Example::
local rng = require('plugins.cxxrandom')
local norm_dist = rng.normal_distribution(6820,116) // avg, stddev
local engID = rng.MakeNewEngine(0)
-- somewhat reminiscent of the C++ syntax
print(norm_dist:next(engID))
-- a bit more streamlined
local cleanup = true --delete engine on cleanup
local number_generator = rng.crng:new(engID, cleanup, norm_dist)
print(number_generator:next())
-- simplified
print(rng.rollNormal(engID,6820,116))
The number sequences are much simpler. They're intended for where you need to randomly generate an index, perhaps in a loop for an array. You technically don't need an engine to use it, if you don't mind never shuffling.
Example::
local rng = require('plugins.cxxrandom')
local g = rng.crng:new(rng.MakeNewEngine(0), true, rng.num_sequence:new(0,table_size))
g:shuffle()
for _ = 1, table_size do
func(array[g:next()])
end
dig-now
=======

@ -42,6 +42,8 @@ changelog.txt uses a syntax similar to RST, with a few special sequences:
- `tweak` partial-items: displays percentages on partially-consumed items such as hospital cloth
## Fixes
- `cxxrandom`: fixed exception when calling ``bool_distribution``
- `cxxrandom`: fixed id order for ShuffleSequence, but adds code to detect which parameter is which so each id is used correctly. 16000 limit before things get weird (previous was 16 bits)
- `autofarm` removed restriction on only planting 'discovered' plants
- `luasocket` (and others): return correct status code when closing socket connections
@ -64,6 +66,7 @@ changelog.txt uses a syntax similar to RST, with a few special sequences:
- Include recently-added tweaks in example dfhack.init file, clean up dreamfort onMapLoad.init file
## Documentation
- `cxxrandom`: added usage examples
- Add more examples to the plugin skeleton files so they are more informative for a newbie
- Lua API.rst added: ``isHidden(unit)``, ``isFortControlled(unit)``, ``getOuterContainerRef(unit)``, ``getOuterContainerRef(item)``
- Update download link and installation instructions for Visual C++ 2015 build tools on Windows

@ -37,94 +37,82 @@ DFHACK_PLUGIN("cxxrandom");
#define PLUGIN_VERSION 2.0
color_ostream *cout = nullptr;
DFhackCExport command_result plugin_init (color_ostream &out, std::vector <PluginCommand> &commands)
{
DFhackCExport command_result plugin_init (color_ostream &out, std::vector <PluginCommand> &commands) {
cout = &out;
return CR_OK;
}
DFhackCExport command_result plugin_shutdown (color_ostream &out)
{
DFhackCExport command_result plugin_shutdown (color_ostream &out) {
return CR_OK;
}
DFhackCExport command_result plugin_onstatechange(color_ostream &out, state_change_event event)
{
DFhackCExport command_result plugin_onstatechange(color_ostream &out, state_change_event event) {
return CR_OK;
}
#define EK_ID_BASE (1ll << 40)
class EnginesKeeper
{
private:
EnginesKeeper() {}
std::unordered_map<uint16_t, std::mt19937_64> m_engines;
uint16_t counter = 0;
EnginesKeeper() = default;
std::unordered_map<uint64_t, std::mt19937_64> m_engines;
uint64_t id_counter = EK_ID_BASE;
public:
static EnginesKeeper& Instance()
{
static EnginesKeeper& Instance() {
static EnginesKeeper instance;
return instance;
}
uint16_t NewEngine( uint64_t seed )
{
uint64_t NewEngine( uint64_t seed ) {
auto id = ++id_counter;
CHECK_INVALID_ARGUMENT(m_engines.count(id) == 0);
std::mt19937_64 engine( seed != 0 ? seed : std::chrono::system_clock::now().time_since_epoch().count() );
m_engines[++counter] = engine;
return counter;
m_engines[id] = engine;
return id;
}
void DestroyEngine( uint16_t id )
{
void DestroyEngine( uint64_t id ) {
m_engines.erase( id );
}
void NewSeed( uint16_t id, uint64_t seed )
{
void NewSeed( uint64_t id, uint64_t seed ) {
CHECK_INVALID_ARGUMENT( m_engines.find( id ) != m_engines.end() );
m_engines[id].seed( seed != 0 ? seed : std::chrono::system_clock::now().time_since_epoch().count() );
}
std::mt19937_64& RNG( uint16_t id )
{
std::mt19937_64& RNG( uint64_t id ) {
CHECK_INVALID_ARGUMENT( m_engines.find( id ) != m_engines.end() );
return m_engines[id];
}
};
uint16_t GenerateEngine( uint64_t seed )
{
uint64_t GenerateEngine( uint64_t seed ) {
return EnginesKeeper::Instance().NewEngine( seed );
}
void DestroyEngine( uint16_t id )
{
void DestroyEngine( uint64_t id ) {
EnginesKeeper::Instance().DestroyEngine( id );
}
void NewSeed( uint16_t id, uint64_t seed )
{
void NewSeed( uint64_t id, uint64_t seed ) {
EnginesKeeper::Instance().NewSeed( id, seed );
}
int rollInt(uint16_t id, int min, int max)
{
int rollInt(uint64_t id, int min, int max) {
std::uniform_int_distribution<int> ND(min, max);
return ND(EnginesKeeper::Instance().RNG(id));
}
double rollDouble(uint16_t id, double min, double max)
{
double rollDouble(uint64_t id, double min, double max) {
std::uniform_real_distribution<double> ND(min, max);
return ND(EnginesKeeper::Instance().RNG(id));
}
double rollNormal(uint16_t id, double mean, double stddev)
{
double rollNormal(uint64_t id, double mean, double stddev) {
std::normal_distribution<double> ND(mean, stddev);
return ND(EnginesKeeper::Instance().RNG(id));
}
bool rollBool(uint16_t id, float p)
{
bool rollBool(uint64_t id, float p) {
std::bernoulli_distribution ND(p);
return ND(EnginesKeeper::Instance().RNG(id));
}
@ -137,118 +125,104 @@ private:
std::vector<int64_t> m_numbers;
public:
NumberSequence(){}
NumberSequence( int64_t start, int64_t end )
{
for( int64_t i = start; i <= end; ++i )
{
NumberSequence( int64_t start, int64_t end ) {
for( int64_t i = start; i <= end; ++i ) {
m_numbers.push_back( i );
}
}
void Add( int64_t num ) { m_numbers.push_back( num ); }
void Reset() { m_numbers.clear(); }
int64_t Next()
{
if(m_position >= m_numbers.size())
{
void Reset() { m_numbers.clear(); }
int64_t Next() {
if(m_position >= m_numbers.size()) {
m_position = 0;
}
return m_numbers[m_position++];
}
void Shuffle( uint16_t id )
{
std::shuffle( std::begin( m_numbers ), std::end( m_numbers ), EnginesKeeper::Instance().RNG( id ) );
void Shuffle( uint64_t engID ) {
std::shuffle( std::begin( m_numbers ), std::end( m_numbers ), EnginesKeeper::Instance().RNG(engID));
}
void Print()
{
for( auto v : m_numbers )
{
void Print() {
for( auto v : m_numbers ) {
cout->print( "%" PRId64 " ", v );
}
}
};
#define SK_ID_BASE 0
class SequenceKeeper
{
private:
SequenceKeeper() {}
std::unordered_map<uint16_t, NumberSequence> m_sequences;
uint16_t counter = 0;
SequenceKeeper() = default;
std::unordered_map<uint64_t, NumberSequence> m_sequences;
uint64_t id_counter = SK_ID_BASE;
public:
static SequenceKeeper& Instance()
{
static SequenceKeeper& Instance() {
static SequenceKeeper instance;
return instance;
}
uint16_t MakeNumSequence( int64_t start, int64_t end )
{
m_sequences[++counter] = NumberSequence( start, end );
return counter;
}
uint16_t MakeNumSequence()
{
m_sequences[++counter] = NumberSequence();
return counter;
}
void DestroySequence( uint16_t id )
{
m_sequences.erase( id );
}
void AddToSequence( uint16_t id, int64_t num )
{
CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() );
m_sequences[id].Add( num );
}
void Shuffle( uint16_t id, uint16_t rng_id )
{
CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() );
m_sequences[id].Shuffle( rng_id );
}
int64_t NextInSequence( uint16_t id )
{
CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() );
return m_sequences[id].Next();
}
void PrintSequence( uint16_t id )
{
CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() );
auto seq = m_sequences[id];
uint64_t MakeNumSequence( int64_t start, int64_t end ) {
auto id = ++id_counter;
CHECK_INVALID_ARGUMENT(m_sequences.count(id) == 0);
m_sequences[id] = NumberSequence(start, end);
return id;
}
uint64_t MakeNumSequence() {
auto id = ++id_counter;
CHECK_INVALID_ARGUMENT(m_sequences.count(id) == 0);
m_sequences[id] = NumberSequence();
return id;
}
void DestroySequence( uint64_t seqID ) {
m_sequences.erase(seqID);
}
void AddToSequence(uint64_t seqID, int64_t num ) {
CHECK_INVALID_ARGUMENT(m_sequences.find(seqID) != m_sequences.end());
m_sequences[seqID].Add(num);
}
void Shuffle(uint64_t seqID, uint64_t engID ) {
uint64_t sid = seqID >= SK_ID_BASE ? seqID : engID;
uint64_t eid = engID >= EK_ID_BASE ? engID : seqID;
CHECK_INVALID_ARGUMENT(m_sequences.find(sid) != m_sequences.end());
m_sequences[sid].Shuffle(eid);
}
int64_t NextInSequence( uint64_t seqID ) {
CHECK_INVALID_ARGUMENT(m_sequences.find(seqID) != m_sequences.end());
return m_sequences[seqID].Next();
}
void PrintSequence( uint64_t seqID ) {
CHECK_INVALID_ARGUMENT(m_sequences.find(seqID) != m_sequences.end());
auto seq = m_sequences[seqID];
seq.Print();
}
};
uint16_t MakeNumSequence( int64_t start, int64_t end )
{
if( start == end )
{
uint64_t MakeNumSequence( int64_t start, int64_t end ) {
if (start == end) {
return SequenceKeeper::Instance().MakeNumSequence();
}
return SequenceKeeper::Instance().MakeNumSequence( start, end );
return SequenceKeeper::Instance().MakeNumSequence(start, end);
}
void DestroyNumSequence( uint16_t id )
{
SequenceKeeper::Instance().DestroySequence( id );
void DestroyNumSequence( uint64_t seqID ) {
SequenceKeeper::Instance().DestroySequence(seqID);
}
void AddToSequence( uint16_t id, int64_t num )
{
SequenceKeeper::Instance().AddToSequence( id, num );
void AddToSequence(uint64_t seqID, int64_t num ) {
SequenceKeeper::Instance().AddToSequence(seqID, num);
}
void ShuffleSequence( uint16_t rngID, uint16_t id )
{
SequenceKeeper::Instance().Shuffle( id, rngID );
void ShuffleSequence(uint64_t seqID, uint64_t engID ) {
SequenceKeeper::Instance().Shuffle(seqID, engID);
}
int64_t NextInSequence( uint16_t id )
{
return SequenceKeeper::Instance().NextInSequence( id );
int64_t NextInSequence( uint64_t seqID ) {
return SequenceKeeper::Instance().NextInSequence(seqID);
}
void DebugSequence( uint16_t id )
{
SequenceKeeper::Instance().PrintSequence( id );
void DebugSequence( uint64_t seqID ) {
SequenceKeeper::Instance().PrintSequence(seqID);
}

@ -151,8 +151,8 @@ bool_distribution = {}
function bool_distribution:new(chance)
local o = {}
self.__index = self
if type(min) ~= 'number' or type(max) ~= 'number' then
error("Invalid arguments in bool_distribution construction. min and max must be numbers.")
if type(chance) ~= 'number' or chance < 0 or chance > 1 then
error("Invalid arguments in bool_distribution construction. chance must be a number between 0.0 and 1.0 (both included).")
end
o.p = chance
setmetatable(o,self)
@ -208,7 +208,7 @@ function num_sequence:shuffle()
if self.rngID == 'nil' then
error("Add num_sequence object to crng as distribution, before attempting to shuffle.")
end
ShuffleSequence(self.rngID, self.seqID)
ShuffleSequence(self.seqID, self.rngID)
end
return _ENV

@ -0,0 +1,76 @@
local rng = require('plugins.cxxrandom')
function test.cxxrandom_distributions()
rng.normal_distribution:new(0,5)
rng.real_distribution:new(-1,1)
rng.int_distribution:new(-20,20)
rng.bool_distribution:new(0.00000000001)
rng.num_sequence:new(-1000,1000)
-- no errors, no problem
end
--[[
The below tests pass with their given seeds, if they begin failing
for a given platform, or all around, new seeds should be found.
Note: these tests which assert RNG, are mere sanity checks
to ensure things haven't been severely broken by any changes
]]
function test.cxxrandom_seed()
local nd = rng.normal_distribution:new(0,500000)
local e1 = rng.MakeNewEngine(1)
local e2 = rng.MakeNewEngine(1)
local e3 = rng.MakeNewEngine(2)
local g1 = rng.crng:new(e1, true, nd)
local g2 = rng.crng:new(e2, true, nd)
local g3 = rng.crng:new(e3, true, nd)
local v1 = g1:next()
expect.eq(v1, g2:next())
expect.ne(v1, g3:next())
end
function test.cxxrandom_ranges()
local e1 = rng.MakeNewEngine(1)
local g1 = rng.crng:new(e1, true, rng.normal_distribution:new(0,1))
local g2 = rng.crng:new(e1, true, rng.real_distribution:new(-5,5))
local g3 = rng.crng:new(e1, true, rng.int_distribution:new(-5,5))
local g4 = rng.crng:new(e1, true, rng.num_sequence:new(-5,5))
for i = 1, 10 do
local a = g1:next()
local b = g2:next()
local c = g3:next()
local d = g4:next()
expect.ge(a, -5)
expect.ge(b, -5)
expect.ge(c, -5)
expect.ge(d, -5)
expect.le(a, 5)
expect.le(b, 5)
expect.le(c, 5)
expect.le(d, 5)
end
local gb = rng.crng:new(e1, true, rng.bool_distribution:new(0.00000000001))
for i = 1, 10 do
expect.false_(gb:next())
end
end
function test.cxxrandom_exports()
local id = rng.GenerateEngine(0)
rng.NewSeed(id, 2022)
expect.ge(rng.rollInt(id, 0, 1000), 0)
expect.ge(rng.rollDouble(id, 0, 1), 0)
expect.ge(rng.rollNormal(id, 5, 1), 0)
expect.true_(rng.rollBool(id, 0.9999999999))
local sid = rng.MakeNumSequence(0,8)
rng.AddToSequence(sid, 9)
rng.ShuffleSequence(sid, id)
for i = 1, 10 do
local v = rng.NextInSequence(sid)
expect.ge(v, 0)
expect.le(v, 9)
end
rng.DestroyNumSequence(sid)
rng.DestroyEngine(id)
end