From fe29bff845e3cd2494c6a63898c5797b886e2a81 Mon Sep 17 00:00:00 2001 From: Josh Cooper Date: Sat, 30 Apr 2022 21:46:47 -0700 Subject: [PATCH] Adds cxxrandom unit test and fixes interface problems (#2099) * Adds cxxrandom unit test and fixes interface problems * Tightens braces * Adds detection code for Shuffle's seqID/engID * Adds usage examples for cxxrandom * Gives cxxrandom objects id ranges, sort of * Updates changelog * Updates changelog.txt * Increases id space for cxxrandom * Fixes bool distribution error message and improves check * Adds comment explaining the seeded RNG tests for cxxrandom * Fixes type problem for 32bit builds * Reduces loop count a few magnitudes * Fixes a mistake in test.cxxrandom_seed --- docs/Lua API.rst | 39 +++++++- docs/changelog.txt | 3 + plugins/cxxrandom.cpp | 192 ++++++++++++++++--------------------- plugins/lua/cxxrandom.lua | 6 +- test/plugins/cxxrandom.lua | 76 +++++++++++++++ 5 files changed, 202 insertions(+), 114 deletions(-) create mode 100644 test/plugins/cxxrandom.lua diff --git a/docs/Lua API.rst b/docs/Lua API.rst index 4615aaf4f..c0dec4b6e 100644 --- a/docs/Lua API.rst +++ b/docs/Lua API.rst @@ -4401,7 +4401,7 @@ Native functions (exported to Lua) adds a number to the sequence -- ``ShuffleSequence(rngID, seqID)`` +- ``ShuffleSequence(seqID, rngID)`` shuffles the number sequence @@ -4464,7 +4464,7 @@ Lua plugin classes ``bool_distribution`` ~~~~~~~~~~~~~~~~~~~~~ -- ``init(min, max)``: constructor +- ``init(chance)``: constructor - ``next(id)``: returns next boolean in the distribution - ``id``: engine ID to pass to native function @@ -4477,6 +4477,41 @@ Lua plugin classes - ``shuffle()``: shuffles the sequence of numbers - ``next()``: returns next number in the sequence +Usage +----- + +The basic idea is you create a number distribution which you generate random numbers along. The C++ relies +on engines keeping state information to determine the next number along the distribution. +You're welcome to try and (ab)use this knowledge for your RNG purposes. + +Example:: + + local rng = require('plugins.cxxrandom') + local norm_dist = rng.normal_distribution(6820,116) // avg, stddev + local engID = rng.MakeNewEngine(0) + -- somewhat reminiscent of the C++ syntax + print(norm_dist:next(engID)) + + -- a bit more streamlined + local cleanup = true --delete engine on cleanup + local number_generator = rng.crng:new(engID, cleanup, norm_dist) + print(number_generator:next()) + + -- simplified + print(rng.rollNormal(engID,6820,116)) + +The number sequences are much simpler. They're intended for where you need to randomly generate an index, perhaps in a loop for an array. You technically don't need an engine to use it, if you don't mind never shuffling. + +Example:: + + local rng = require('plugins.cxxrandom') + local g = rng.crng:new(rng.MakeNewEngine(0), true, rng.num_sequence:new(0,table_size)) + g:shuffle() + for _ = 1, table_size do + func(array[g:next()]) + end + + dig-now ======= diff --git a/docs/changelog.txt b/docs/changelog.txt index 56d743171..06406087a 100644 --- a/docs/changelog.txt +++ b/docs/changelog.txt @@ -42,6 +42,8 @@ changelog.txt uses a syntax similar to RST, with a few special sequences: - `tweak` partial-items: displays percentages on partially-consumed items such as hospital cloth ## Fixes +- `cxxrandom`: fixed exception when calling ``bool_distribution`` +- `cxxrandom`: fixed id order for ShuffleSequence, but adds code to detect which parameter is which so each id is used correctly. 16000 limit before things get weird (previous was 16 bits) - `autofarm` removed restriction on only planting 'discovered' plants - `luasocket` (and others): return correct status code when closing socket connections @@ -64,6 +66,7 @@ changelog.txt uses a syntax similar to RST, with a few special sequences: - Include recently-added tweaks in example dfhack.init file, clean up dreamfort onMapLoad.init file ## Documentation +- `cxxrandom`: added usage examples - Add more examples to the plugin skeleton files so they are more informative for a newbie - Lua API.rst added: ``isHidden(unit)``, ``isFortControlled(unit)``, ``getOuterContainerRef(unit)``, ``getOuterContainerRef(item)`` - Update download link and installation instructions for Visual C++ 2015 build tools on Windows diff --git a/plugins/cxxrandom.cpp b/plugins/cxxrandom.cpp index 159edaefc..12f043214 100644 --- a/plugins/cxxrandom.cpp +++ b/plugins/cxxrandom.cpp @@ -37,94 +37,82 @@ DFHACK_PLUGIN("cxxrandom"); #define PLUGIN_VERSION 2.0 color_ostream *cout = nullptr; -DFhackCExport command_result plugin_init (color_ostream &out, std::vector &commands) -{ +DFhackCExport command_result plugin_init (color_ostream &out, std::vector &commands) { cout = &out; return CR_OK; } -DFhackCExport command_result plugin_shutdown (color_ostream &out) -{ +DFhackCExport command_result plugin_shutdown (color_ostream &out) { return CR_OK; } -DFhackCExport command_result plugin_onstatechange(color_ostream &out, state_change_event event) -{ +DFhackCExport command_result plugin_onstatechange(color_ostream &out, state_change_event event) { return CR_OK; } +#define EK_ID_BASE (1ll << 40) class EnginesKeeper { private: - EnginesKeeper() {} - std::unordered_map m_engines; - uint16_t counter = 0; + EnginesKeeper() = default; + std::unordered_map m_engines; + uint64_t id_counter = EK_ID_BASE; public: - static EnginesKeeper& Instance() - { + static EnginesKeeper& Instance() { static EnginesKeeper instance; return instance; } - uint16_t NewEngine( uint64_t seed ) - { + uint64_t NewEngine( uint64_t seed ) { + auto id = ++id_counter; + CHECK_INVALID_ARGUMENT(m_engines.count(id) == 0); std::mt19937_64 engine( seed != 0 ? seed : std::chrono::system_clock::now().time_since_epoch().count() ); - m_engines[++counter] = engine; - return counter; + m_engines[id] = engine; + return id; } - void DestroyEngine( uint16_t id ) - { + void DestroyEngine( uint64_t id ) { m_engines.erase( id ); } - void NewSeed( uint16_t id, uint64_t seed ) - { + void NewSeed( uint64_t id, uint64_t seed ) { CHECK_INVALID_ARGUMENT( m_engines.find( id ) != m_engines.end() ); m_engines[id].seed( seed != 0 ? seed : std::chrono::system_clock::now().time_since_epoch().count() ); } - std::mt19937_64& RNG( uint16_t id ) - { + std::mt19937_64& RNG( uint64_t id ) { CHECK_INVALID_ARGUMENT( m_engines.find( id ) != m_engines.end() ); return m_engines[id]; } }; -uint16_t GenerateEngine( uint64_t seed ) -{ +uint64_t GenerateEngine( uint64_t seed ) { return EnginesKeeper::Instance().NewEngine( seed ); } -void DestroyEngine( uint16_t id ) -{ +void DestroyEngine( uint64_t id ) { EnginesKeeper::Instance().DestroyEngine( id ); } -void NewSeed( uint16_t id, uint64_t seed ) -{ +void NewSeed( uint64_t id, uint64_t seed ) { EnginesKeeper::Instance().NewSeed( id, seed ); } -int rollInt(uint16_t id, int min, int max) -{ +int rollInt(uint64_t id, int min, int max) { std::uniform_int_distribution ND(min, max); return ND(EnginesKeeper::Instance().RNG(id)); } -double rollDouble(uint16_t id, double min, double max) -{ +double rollDouble(uint64_t id, double min, double max) { std::uniform_real_distribution ND(min, max); return ND(EnginesKeeper::Instance().RNG(id)); } -double rollNormal(uint16_t id, double mean, double stddev) -{ +double rollNormal(uint64_t id, double mean, double stddev) { std::normal_distribution ND(mean, stddev); return ND(EnginesKeeper::Instance().RNG(id)); } -bool rollBool(uint16_t id, float p) -{ +bool rollBool(uint64_t id, float p) { std::bernoulli_distribution ND(p); return ND(EnginesKeeper::Instance().RNG(id)); } @@ -137,118 +125,104 @@ private: std::vector m_numbers; public: NumberSequence(){} - NumberSequence( int64_t start, int64_t end ) - { - for( int64_t i = start; i <= end; ++i ) - { + NumberSequence( int64_t start, int64_t end ) { + for( int64_t i = start; i <= end; ++i ) { m_numbers.push_back( i ); } } void Add( int64_t num ) { m_numbers.push_back( num ); } - void Reset() { m_numbers.clear(); } - int64_t Next() - { - if(m_position >= m_numbers.size()) - { + void Reset() { m_numbers.clear(); } + int64_t Next() { + if(m_position >= m_numbers.size()) { m_position = 0; } return m_numbers[m_position++]; } - void Shuffle( uint16_t id ) - { - std::shuffle( std::begin( m_numbers ), std::end( m_numbers ), EnginesKeeper::Instance().RNG( id ) ); + void Shuffle( uint64_t engID ) { + std::shuffle( std::begin( m_numbers ), std::end( m_numbers ), EnginesKeeper::Instance().RNG(engID)); } - void Print() - { - for( auto v : m_numbers ) - { + void Print() { + for( auto v : m_numbers ) { cout->print( "%" PRId64 " ", v ); } } }; +#define SK_ID_BASE 0 + class SequenceKeeper { private: - SequenceKeeper() {} - std::unordered_map m_sequences; - uint16_t counter = 0; + SequenceKeeper() = default; + std::unordered_map m_sequences; + uint64_t id_counter = SK_ID_BASE; public: - static SequenceKeeper& Instance() - { + static SequenceKeeper& Instance() { static SequenceKeeper instance; return instance; } - uint16_t MakeNumSequence( int64_t start, int64_t end ) - { - m_sequences[++counter] = NumberSequence( start, end ); - return counter; - } - uint16_t MakeNumSequence() - { - m_sequences[++counter] = NumberSequence(); - return counter; - } - void DestroySequence( uint16_t id ) - { - m_sequences.erase( id ); - } - void AddToSequence( uint16_t id, int64_t num ) - { - CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() ); - m_sequences[id].Add( num ); - } - void Shuffle( uint16_t id, uint16_t rng_id ) - { - CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() ); - m_sequences[id].Shuffle( rng_id ); - } - int64_t NextInSequence( uint16_t id ) - { - CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() ); - return m_sequences[id].Next(); - } - void PrintSequence( uint16_t id ) - { - CHECK_INVALID_ARGUMENT( m_sequences.find( id ) != m_sequences.end() ); - auto seq = m_sequences[id]; + uint64_t MakeNumSequence( int64_t start, int64_t end ) { + auto id = ++id_counter; + CHECK_INVALID_ARGUMENT(m_sequences.count(id) == 0); + m_sequences[id] = NumberSequence(start, end); + return id; + } + uint64_t MakeNumSequence() { + auto id = ++id_counter; + CHECK_INVALID_ARGUMENT(m_sequences.count(id) == 0); + m_sequences[id] = NumberSequence(); + return id; + } + void DestroySequence( uint64_t seqID ) { + m_sequences.erase(seqID); + } + void AddToSequence(uint64_t seqID, int64_t num ) { + CHECK_INVALID_ARGUMENT(m_sequences.find(seqID) != m_sequences.end()); + m_sequences[seqID].Add(num); + } + void Shuffle(uint64_t seqID, uint64_t engID ) { + uint64_t sid = seqID >= SK_ID_BASE ? seqID : engID; + uint64_t eid = engID >= EK_ID_BASE ? engID : seqID; + CHECK_INVALID_ARGUMENT(m_sequences.find(sid) != m_sequences.end()); + m_sequences[sid].Shuffle(eid); + } + int64_t NextInSequence( uint64_t seqID ) { + CHECK_INVALID_ARGUMENT(m_sequences.find(seqID) != m_sequences.end()); + return m_sequences[seqID].Next(); + } + void PrintSequence( uint64_t seqID ) { + CHECK_INVALID_ARGUMENT(m_sequences.find(seqID) != m_sequences.end()); + auto seq = m_sequences[seqID]; seq.Print(); } }; -uint16_t MakeNumSequence( int64_t start, int64_t end ) -{ - if( start == end ) - { +uint64_t MakeNumSequence( int64_t start, int64_t end ) { + if (start == end) { return SequenceKeeper::Instance().MakeNumSequence(); } - return SequenceKeeper::Instance().MakeNumSequence( start, end ); + return SequenceKeeper::Instance().MakeNumSequence(start, end); } -void DestroyNumSequence( uint16_t id ) -{ - SequenceKeeper::Instance().DestroySequence( id ); +void DestroyNumSequence( uint64_t seqID ) { + SequenceKeeper::Instance().DestroySequence(seqID); } -void AddToSequence( uint16_t id, int64_t num ) -{ - SequenceKeeper::Instance().AddToSequence( id, num ); +void AddToSequence(uint64_t seqID, int64_t num ) { + SequenceKeeper::Instance().AddToSequence(seqID, num); } -void ShuffleSequence( uint16_t rngID, uint16_t id ) -{ - SequenceKeeper::Instance().Shuffle( id, rngID ); +void ShuffleSequence(uint64_t seqID, uint64_t engID ) { + SequenceKeeper::Instance().Shuffle(seqID, engID); } -int64_t NextInSequence( uint16_t id ) -{ - return SequenceKeeper::Instance().NextInSequence( id ); +int64_t NextInSequence( uint64_t seqID ) { + return SequenceKeeper::Instance().NextInSequence(seqID); } -void DebugSequence( uint16_t id ) -{ - SequenceKeeper::Instance().PrintSequence( id ); +void DebugSequence( uint64_t seqID ) { + SequenceKeeper::Instance().PrintSequence(seqID); } diff --git a/plugins/lua/cxxrandom.lua b/plugins/lua/cxxrandom.lua index 78e363bef..542575b9c 100644 --- a/plugins/lua/cxxrandom.lua +++ b/plugins/lua/cxxrandom.lua @@ -151,8 +151,8 @@ bool_distribution = {} function bool_distribution:new(chance) local o = {} self.__index = self - if type(min) ~= 'number' or type(max) ~= 'number' then - error("Invalid arguments in bool_distribution construction. min and max must be numbers.") + if type(chance) ~= 'number' or chance < 0 or chance > 1 then + error("Invalid arguments in bool_distribution construction. chance must be a number between 0.0 and 1.0 (both included).") end o.p = chance setmetatable(o,self) @@ -208,7 +208,7 @@ function num_sequence:shuffle() if self.rngID == 'nil' then error("Add num_sequence object to crng as distribution, before attempting to shuffle.") end - ShuffleSequence(self.rngID, self.seqID) + ShuffleSequence(self.seqID, self.rngID) end return _ENV diff --git a/test/plugins/cxxrandom.lua b/test/plugins/cxxrandom.lua new file mode 100644 index 000000000..6b11e1937 --- /dev/null +++ b/test/plugins/cxxrandom.lua @@ -0,0 +1,76 @@ +local rng = require('plugins.cxxrandom') + +function test.cxxrandom_distributions() + rng.normal_distribution:new(0,5) + rng.real_distribution:new(-1,1) + rng.int_distribution:new(-20,20) + rng.bool_distribution:new(0.00000000001) + rng.num_sequence:new(-1000,1000) + -- no errors, no problem +end + +--[[ +The below tests pass with their given seeds, if they begin failing +for a given platform, or all around, new seeds should be found. + +Note: these tests which assert RNG, are mere sanity checks +to ensure things haven't been severely broken by any changes +]] + +function test.cxxrandom_seed() + local nd = rng.normal_distribution:new(0,500000) + local e1 = rng.MakeNewEngine(1) + local e2 = rng.MakeNewEngine(1) + local e3 = rng.MakeNewEngine(2) + local g1 = rng.crng:new(e1, true, nd) + local g2 = rng.crng:new(e2, true, nd) + local g3 = rng.crng:new(e3, true, nd) + local v1 = g1:next() + expect.eq(v1, g2:next()) + expect.ne(v1, g3:next()) +end + +function test.cxxrandom_ranges() + local e1 = rng.MakeNewEngine(1) + local g1 = rng.crng:new(e1, true, rng.normal_distribution:new(0,1)) + local g2 = rng.crng:new(e1, true, rng.real_distribution:new(-5,5)) + local g3 = rng.crng:new(e1, true, rng.int_distribution:new(-5,5)) + local g4 = rng.crng:new(e1, true, rng.num_sequence:new(-5,5)) + for i = 1, 10 do + local a = g1:next() + local b = g2:next() + local c = g3:next() + local d = g4:next() + expect.ge(a, -5) + expect.ge(b, -5) + expect.ge(c, -5) + expect.ge(d, -5) + expect.le(a, 5) + expect.le(b, 5) + expect.le(c, 5) + expect.le(d, 5) + end + local gb = rng.crng:new(e1, true, rng.bool_distribution:new(0.00000000001)) + for i = 1, 10 do + expect.false_(gb:next()) + end +end + +function test.cxxrandom_exports() + local id = rng.GenerateEngine(0) + rng.NewSeed(id, 2022) + expect.ge(rng.rollInt(id, 0, 1000), 0) + expect.ge(rng.rollDouble(id, 0, 1), 0) + expect.ge(rng.rollNormal(id, 5, 1), 0) + expect.true_(rng.rollBool(id, 0.9999999999)) + local sid = rng.MakeNumSequence(0,8) + rng.AddToSequence(sid, 9) + rng.ShuffleSequence(sid, id) + for i = 1, 10 do + local v = rng.NextInSequence(sid) + expect.ge(v, 0) + expect.le(v, 9) + end + rng.DestroyNumSequence(sid) + rng.DestroyEngine(id) +end