Added BaseLockable and BaseThread nested saving/loading

graph-rework-2
noah metz 2023-06-30 13:25:35 -06:00
parent 41d88b9d86
commit 04771b7816
5 changed files with 214 additions and 113 deletions

@ -10,7 +10,7 @@ import (
"encoding/json" "encoding/json"
) )
type StateLoadFunc func(*GraphContext, NodeID, []byte, map[NodeID]GraphNode)(NodeState, error) type StateLoadFunc func(*GraphContext, []byte, NodeMap)(NodeState, error)
type StateLoadMap map[string]StateLoadFunc type StateLoadMap map[string]StateLoadFunc
type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error) type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error)
type NodeLoadMap map[string]NodeLoadFunc type NodeLoadMap map[string]NodeLoadFunc
@ -21,15 +21,72 @@ type GraphContext struct {
StateLoadFuncs StateLoadMap StateLoadFuncs StateLoadMap
} }
func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) {
// Initialize an empty list of loaded nodes, then start loading them from id
loaded_nodes := map[NodeID]GraphNode{}
return LoadNodeRecurse(ctx, id, loaded_nodes)
}
type DBJSONBase struct {
Type string `json:"type"`
}
// Check if a node is already loaded, load it's state bytes from the DB and parse the type if it's not already loaded
// Call the node load function related to the type, which will call this parse function recusively as needed
func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]GraphNode) (GraphNode, error) {
node, exists := loaded_nodes[id]
if exists == false {
state_bytes, err := ReadDBState(ctx, id)
if err != nil {
return nil, err
}
var base DBJSONBase
err = json.Unmarshal(state_bytes, &base)
if err != nil {
return nil, err
}
ctx.Log.Logf("graph", "GRAPH_DB_LOAD: %s(%s)", base.Type, id)
node_fn, exists := ctx.NodeLoadFuncs[base.Type]
if exists == false {
return nil, fmt.Errorf("%s is not a known node type", base.Type)
}
node, err = node_fn(ctx, id)
if err != nil {
return nil, err
}
loaded_nodes[id] = node
state_fn, exists := ctx.StateLoadFuncs[base.Type]
if exists == false {
return nil, fmt.Errorf("%s is not a known node state type", base.Type)
}
state, err := state_fn(ctx, state_bytes, loaded_nodes)
if err != nil {
return nil, err
}
node.SetState(state)
}
return node, nil
}
func NewGraphContext(db * badger.DB, log Logger) * GraphContext { func NewGraphContext(db * badger.DB, log Logger) * GraphContext {
ctx := GraphContext{ ctx := GraphContext{
DB: db, DB: db,
Log: log, Log: log,
NodeLoadFuncs: NodeLoadMap{ NodeLoadFuncs: NodeLoadMap{
"base_lockable": LoadBaseLockable, "base_lockable": LoadBaseLockable,
"base_thread": LoadBaseThread,
}, },
StateLoadFuncs: StateLoadMap{ StateLoadFuncs: StateLoadMap{
"base_lockable": LoadBaseLockableState, "base_lockable": LoadBaseLockableState,
"base_thread": LoadBaseThreadState,
}, },
} }
@ -300,29 +357,6 @@ func (node * BaseNode) StateLock() * sync.RWMutex {
return &node.state_lock return &node.state_lock
} }
func ReadDBStateCopy(ctx * GraphContext, id NodeID) ([]byte, error) {
ctx.Log.Logf("db", "DB_READ: %s", id)
var val []byte = nil
err := ctx.DB.View(func(txn *badger.Txn) error {
item, err := txn.Get([]byte(id))
if err != nil {
return err
}
val, err = item.ValueCopy(nil)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return val, nil
}
func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) { func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) {
var bytes []byte var bytes []byte
err := ctx.DB.View(func(txn *badger.Txn) error { err := ctx.DB.View(func(txn *badger.Txn) error {
@ -349,21 +383,20 @@ func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) {
func WriteDBStates(ctx * GraphContext, nodes NodeMap) error{ func WriteDBStates(ctx * GraphContext, nodes NodeMap) error{
ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes)) ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes))
var serialized_states [][]byte = make([][]byte, len(nodes)) serialized_states := map[NodeID][]byte{}
i := 0
for _, node := range(nodes) { for _, node := range(nodes) {
ser, err := json.Marshal(node.State()) ser, err := json.Marshal(node.State())
if err != nil { if err != nil {
return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) return fmt.Errorf("DB_MARSHAL_ERROR: %e", err)
} }
serialized_states[i] = ser serialized_states[node.ID()] = ser
i++
} }
err := ctx.DB.Update(func(txn *badger.Txn) error { err := ctx.DB.Update(func(txn *badger.Txn) error {
i := 0 i := 0
for id, _ := range(nodes) { for id, _ := range(nodes) {
err := txn.Set([]byte(id), serialized_states[i]) ctx.Log.Logf("db", "DB_WRITE: %s - %s", id, string(serialized_states[id]))
err := txn.Set([]byte(id), serialized_states[id])
if err != nil { if err != nil {
return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) return fmt.Errorf("DB_MARSHAL_ERROR: %e", err)
} }

@ -51,7 +51,7 @@ func (state * BaseLockableState) Type() string {
return state._type return state._type
} }
func (state * BaseLockableState) MarshalJSON() ([]byte, error) { func SaveBaseLockableState(state * BaseLockableState) BaseLockableStateJSON {
requirement_ids := make([]NodeID, len(state.requirements)) requirement_ids := make([]NodeID, len(state.requirements))
for i, requirement := range(state.requirements) { for i, requirement := range(state.requirements) {
requirement_ids[i] = requirement.ID() requirement_ids[i] = requirement.ID()
@ -77,15 +77,19 @@ func (state * BaseLockableState) MarshalJSON() ([]byte, error) {
locks_held[lockable_id] = &str locks_held[lockable_id] = &str
} }
} }
return BaseLockableStateJSON{
return json.Marshal(&BaseLockableStateJSON{
Type: state._type, Type: state._type,
Name: state.name, Name: state.name,
Owner: owner_id, Owner: owner_id,
Dependencies: dependency_ids, Dependencies: dependency_ids,
Requirements: requirement_ids, Requirements: requirement_ids,
LocksHeld: locks_held, LocksHeld: locks_held,
}) }
}
func (state * BaseLockableState) MarshalJSON() ([]byte, error) {
lockable_state := SaveBaseLockableState(state)
return json.Marshal(&lockable_state)
} }
func (state * BaseLockableState) Name() string { func (state * BaseLockableState) Name() string {
@ -534,45 +538,110 @@ func NewBaseLockable(ctx * GraphContext, state LockableState) (BaseLockable, err
return lockable, nil return lockable, nil
} }
func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { func LoadBaseThread(ctx * GraphContext, id NodeID) (GraphNode, error) {
// call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id
base_node := RestoreNode(ctx, id) base_node := RestoreNode(ctx, id)
lockable := BaseLockable{ thread := BaseThread{
BaseLockable: BaseLockable{
BaseNode: base_node, BaseNode: base_node,
},
} }
return &lockable, nil return &thread, nil
} }
func LoadBaseLockableState(ctx * GraphContext, id NodeID, data []byte, loaded_nodes map[NodeID]GraphNode)(NodeState, error){ func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) {
var j BaseLockableStateJSON lockable_state, err := RestoreBaseLockableState(ctx, j.LockableState, loaded_nodes)
err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return nil, err
} }
lockable_state._type = "thread_state"
var owner Lockable = nil state := BaseThreadState{
if j.Owner != nil { BaseLockableState: *lockable_state,
o, err := LoadNodeRecurse(ctx, *j.Owner, loaded_nodes) parent: nil,
children: make([]Thread, len(j.Children)),
child_info: map[NodeID]ThreadInfo{},
InfoType: nil,
running: false,
}
if j.Parent != nil {
p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
o_l, ok := o.(Lockable) p_t, ok := p.(Thread)
if ok == false {
return nil, err
}
state.owner = p_t
}
i := 0
for id, info := range(j.Children) {
child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes)
if err != nil {
return nil, err
}
child_t, ok := child_node.(Thread)
if ok == false { if ok == false {
return nil, fmt.Errorf("%+v is not a Thread as expected", child_node)
}
state.children[i] = child_t
state.child_info[id] = info
i++
}
return &state, nil
}
func LoadBaseThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){
var j BaseThreadStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
state, err := RestoreBaseThreadState(ctx, j, loaded_nodes)
if err != nil {
return nil, err return nil, err
} }
owner = o_l
return state, nil
} }
func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) {
// call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id
base_node := RestoreNode(ctx, id)
lockable := BaseLockable{
BaseNode: base_node,
}
return &lockable, nil
}
func RestoreBaseLockableState(ctx * GraphContext, j BaseLockableStateJSON, loaded_nodes NodeMap) (*BaseLockableState, error) {
state := BaseLockableState{ state := BaseLockableState{
_type: "base_lockable", _type: "base_lockable",
name: j.Name, name: j.Name,
owner: owner, owner: nil,
dependencies: make([]Lockable, len(j.Dependencies)), dependencies: make([]Lockable, len(j.Dependencies)),
requirements: make([]Lockable, len(j.Requirements)), requirements: make([]Lockable, len(j.Requirements)),
locks_held: map[NodeID]Lockable{}, locks_held: map[NodeID]Lockable{},
} }
if j.Owner != nil {
o, err := LoadNodeRecurse(ctx, *j.Owner, loaded_nodes)
if err != nil {
return nil, err
}
o_l, ok := o.(Lockable)
if ok == false {
return nil, err
}
state.owner = o_l
}
for i, dep := range(j.Dependencies) { for i, dep := range(j.Dependencies) {
dep_node, err := LoadNodeRecurse(ctx, dep, loaded_nodes) dep_node, err := LoadNodeRecurse(ctx, dep, loaded_nodes)
if err != nil { if err != nil {
@ -616,62 +685,23 @@ func LoadBaseLockableState(ctx * GraphContext, id NodeID, data []byte, loaded_no
} }
state.locks_held[l_id] = h_l state.locks_held[l_id] = h_l
} }
return &state, nil
}
func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { return &state, nil
// Initialize an empty list of loaded nodes, then start loading them from id
loaded_nodes := map[NodeID]GraphNode{}
return LoadNodeRecurse(ctx, id, loaded_nodes)
}
type DBJSONBase struct {
Type string `json:"type"`
}
// Check if a node is already loaded, load it's state bytes from the DB and parse the type if it's not already loaded
// Call the node load function related to the type, which will call this parse function recusively as needed
func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]GraphNode) (GraphNode, error) {
node, exists := loaded_nodes[id]
if exists == false {
state_bytes, err := ReadDBState(ctx, id)
if err != nil {
return nil, err
}
var base DBJSONBase
err = json.Unmarshal(state_bytes, &base)
if err != nil {
return nil, err
}
ctx.Log.Logf("graph", "GRAPH_DB_LOAD: %s(%s)", base.Type, id)
node_fn, exists := ctx.NodeLoadFuncs[base.Type]
if exists == false {
return nil, fmt.Errorf("%s is not a known node type", base.Type)
} }
node, err = node_fn(ctx, id) func LoadBaseLockableState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){
var j BaseLockableStateJSON
err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return nil, err
} }
loaded_nodes[id] = node state, err := RestoreBaseLockableState(ctx, j, loaded_nodes)
state_fn, exists := ctx.StateLoadFuncs[base.Type]
if exists == false {
return nil, fmt.Errorf("%s is not a known node state type", base.Type)
}
state, err := state_fn(ctx, id, state_bytes, loaded_nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node.SetState(state) return state, nil
}
return node, nil
} }
func NewSimpleBaseLockable(ctx * GraphContext, name string, requirements []Lockable) (*BaseLockable, error) { func NewSimpleBaseLockable(ctx * GraphContext, name string, requirements []Lockable) (*BaseLockable, error) {

@ -4,15 +4,43 @@ import (
"testing" "testing"
"fmt" "fmt"
"time" "time"
"encoding/json"
) )
func TestNewSimpleBaseLockable(t * testing.T) { func TestNewSimpleBaseLockable(t * testing.T) {
ctx := testContext(t) ctx := testContext(t)
r1, err := NewSimpleBaseLockable(ctx, "Test lockable 1", []Lockable{}) l1, err := NewSimpleBaseLockable(ctx, "Test lockable 1", []Lockable{})
fatalErr(t, err)
l2, err := NewSimpleBaseLockable(ctx, "Test lockable 2", []Lockable{l1})
fatalErr(t, err) fatalErr(t, err)
_, err = NewSimpleBaseLockable(ctx, "Test lockable 2", []Lockable{r1}) err = UseStates(ctx, []GraphNode{l1, l2}, func(states NodeStateMap) error {
l1_state := states[l1.ID()].(LockableState)
l2_state := states[l2.ID()].(LockableState)
l1_deps := len(l1_state.Dependencies())
if l1_deps != 1 {
return fmt.Errorf("l1 has wront amount of dependencies %d/1", l1_deps)
}
l1_dep1 := l1_state.Dependencies()[0]
if l1_dep1.ID() != l2.ID() {
return fmt.Errorf("Wrong dependency for l1, %s instead of %s", l1_dep1.ID(), l2.ID())
}
l2_reqs := len(l2_state.Requirements())
if l2_reqs != 1 {
return fmt.Errorf("l2 has wrong amount of requirements %d/1", l2_reqs)
}
l2_req1 := l2_state.Requirements()[0]
if l2_req1.ID() != l1.ID() {
return fmt.Errorf("Wrong requirement for l2, %s instead of %s", l2_req1.ID(), l1.ID())
}
return nil
})
fatalErr(t, err) fatalErr(t, err)
} }
@ -334,13 +362,18 @@ func TestLockableDependencyOverlap(t * testing.T) {
} }
func TestLockableDBLoad(t * testing.T){ func TestLockableDBLoad(t * testing.T){
ctx := testContext(t) ctx := logTestContext(t, []string{"db"})
l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{}) l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2})
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error {
ser, err := json.MarshalIndent(states[l3.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser)
return err
})
l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3}) l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3})
fatalErr(t, err) fatalErr(t, err)
_, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4}) _, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4})
@ -352,6 +385,11 @@ func TestLockableDBLoad(t * testing.T){
return err return err
}) })
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error {
ser, err := json.MarshalIndent(states[l3.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser)
return err
})
_, err = LoadNode(ctx, l3.ID()) _, err = LoadNode(ctx, l3.ID())
fatalErr(t, err) fatalErr(t, err)

@ -72,10 +72,10 @@ type BaseThreadState struct {
type BaseThreadStateJSON struct { type BaseThreadStateJSON struct {
Parent *NodeID `json:"parent"` Parent *NodeID `json:"parent"`
Children map[NodeID]interface{} `json:"children"` Children map[NodeID]interface{} `json:"children"`
LockableState *BaseLockableState `json:"lockable"` LockableState BaseLockableStateJSON `json:"lockable"`
} }
func (state * BaseThreadState) MarshalJSON() ([]byte, error) { func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON {
children := map[NodeID]interface{}{} children := map[NodeID]interface{}{}
for _, child := range(state.children) { for _, child := range(state.children) {
children[child.ID()] = state.child_info[child.ID()] children[child.ID()] = state.child_info[child.ID()]
@ -87,11 +87,18 @@ func (state * BaseThreadState) MarshalJSON() ([]byte, error) {
parent_id = &new_str parent_id = &new_str
} }
return json.Marshal(&BaseThreadStateJSON{ lockable_state := SaveBaseLockableState(&state.BaseLockableState)
return BaseThreadStateJSON{
Parent: parent_id, Parent: parent_id,
Children: children, Children: children,
LockableState: &state.BaseLockableState, LockableState: lockable_state,
}) }
}
func (state * BaseThreadState) MarshalJSON() ([]byte, error) {
thread_state := SaveBaseThreadState(state)
return json.Marshal(&thread_state)
} }
func (state * BaseThreadState) Start() error { func (state * BaseThreadState) Start() error {

@ -56,10 +56,3 @@ func TestThreadWithRequirement(t * testing.T) {
}) })
fatalErr(t, err) fatalErr(t, err)
} }
func TestCustomThreadState(t * testing.T ) {
ctx := testContext(t)
_, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{})
fatalErr(t, err)
}