diff --git a/graph.go b/graph.go index 477f4f2..0a615cf 100644 --- a/graph.go +++ b/graph.go @@ -10,7 +10,7 @@ import ( "encoding/json" ) -type StateLoadFunc func(*GraphContext, NodeID, []byte, map[NodeID]GraphNode)(NodeState, error) +type StateLoadFunc func(*GraphContext, []byte, NodeMap)(NodeState, error) type StateLoadMap map[string]StateLoadFunc type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error) type NodeLoadMap map[string]NodeLoadFunc @@ -21,15 +21,72 @@ type GraphContext struct { StateLoadFuncs StateLoadMap } +func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { + // Initialize an empty list of loaded nodes, then start loading them from id + loaded_nodes := map[NodeID]GraphNode{} + return LoadNodeRecurse(ctx, id, loaded_nodes) +} + +type DBJSONBase struct { + Type string `json:"type"` +} + +// Check if a node is already loaded, load it's state bytes from the DB and parse the type if it's not already loaded +// Call the node load function related to the type, which will call this parse function recusively as needed +func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]GraphNode) (GraphNode, error) { + node, exists := loaded_nodes[id] + if exists == false { + state_bytes, err := ReadDBState(ctx, id) + if err != nil { + return nil, err + } + + var base DBJSONBase + err = json.Unmarshal(state_bytes, &base) + if err != nil { + return nil, err + } + + ctx.Log.Logf("graph", "GRAPH_DB_LOAD: %s(%s)", base.Type, id) + + node_fn, exists := ctx.NodeLoadFuncs[base.Type] + if exists == false { + return nil, fmt.Errorf("%s is not a known node type", base.Type) + } + + node, err = node_fn(ctx, id) + if err != nil { + return nil, err + } + + loaded_nodes[id] = node + + state_fn, exists := ctx.StateLoadFuncs[base.Type] + if exists == false { + return nil, fmt.Errorf("%s is not a known node state type", base.Type) + } + + state, err := state_fn(ctx, state_bytes, loaded_nodes) + if err != nil { + return nil, err + } + + node.SetState(state) + } + return node, nil +} + func NewGraphContext(db * badger.DB, log Logger) * GraphContext { ctx := GraphContext{ DB: db, Log: log, NodeLoadFuncs: NodeLoadMap{ "base_lockable": LoadBaseLockable, + "base_thread": LoadBaseThread, }, StateLoadFuncs: StateLoadMap{ "base_lockable": LoadBaseLockableState, + "base_thread": LoadBaseThreadState, }, } @@ -300,29 +357,6 @@ func (node * BaseNode) StateLock() * sync.RWMutex { return &node.state_lock } -func ReadDBStateCopy(ctx * GraphContext, id NodeID) ([]byte, error) { - ctx.Log.Logf("db", "DB_READ: %s", id) - - var val []byte = nil - err := ctx.DB.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(id)) - if err != nil { - return err - } - - val, err = item.ValueCopy(nil) - if err != nil { - return err - } - return nil - }) - if err != nil { - return nil, err - } - - return val, nil -} - func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) { var bytes []byte err := ctx.DB.View(func(txn *badger.Txn) error { @@ -349,21 +383,20 @@ func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) { func WriteDBStates(ctx * GraphContext, nodes NodeMap) error{ ctx.Log.Logf("db", "DB_WRITES: %d", len(nodes)) - var serialized_states [][]byte = make([][]byte, len(nodes)) - i := 0 + serialized_states := map[NodeID][]byte{} for _, node := range(nodes) { ser, err := json.Marshal(node.State()) if err != nil { return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) } - serialized_states[i] = ser - i++ + serialized_states[node.ID()] = ser } err := ctx.DB.Update(func(txn *badger.Txn) error { i := 0 for id, _ := range(nodes) { - err := txn.Set([]byte(id), serialized_states[i]) + ctx.Log.Logf("db", "DB_WRITE: %s - %s", id, string(serialized_states[id])) + err := txn.Set([]byte(id), serialized_states[id]) if err != nil { return fmt.Errorf("DB_MARSHAL_ERROR: %e", err) } diff --git a/lockable.go b/lockable.go index 81c4128..90d310b 100644 --- a/lockable.go +++ b/lockable.go @@ -51,7 +51,7 @@ func (state * BaseLockableState) Type() string { return state._type } -func (state * BaseLockableState) MarshalJSON() ([]byte, error) { +func SaveBaseLockableState(state * BaseLockableState) BaseLockableStateJSON { requirement_ids := make([]NodeID, len(state.requirements)) for i, requirement := range(state.requirements) { requirement_ids[i] = requirement.ID() @@ -77,15 +77,19 @@ func (state * BaseLockableState) MarshalJSON() ([]byte, error) { locks_held[lockable_id] = &str } } - - return json.Marshal(&BaseLockableStateJSON{ + return BaseLockableStateJSON{ Type: state._type, Name: state.name, Owner: owner_id, Dependencies: dependency_ids, Requirements: requirement_ids, LocksHeld: locks_held, - }) + } +} + +func (state * BaseLockableState) MarshalJSON() ([]byte, error) { + lockable_state := SaveBaseLockableState(state) + return json.Marshal(&lockable_state) } func (state * BaseLockableState) Name() string { @@ -534,45 +538,110 @@ func NewBaseLockable(ctx * GraphContext, state LockableState) (BaseLockable, err return lockable, nil } -func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { - // call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id +func LoadBaseThread(ctx * GraphContext, id NodeID) (GraphNode, error) { base_node := RestoreNode(ctx, id) - lockable := BaseLockable{ - BaseNode: base_node, + thread := BaseThread{ + BaseLockable: BaseLockable{ + BaseNode: base_node, + }, } - return &lockable, nil + return &thread, nil } -func LoadBaseLockableState(ctx * GraphContext, id NodeID, data []byte, loaded_nodes map[NodeID]GraphNode)(NodeState, error){ - var j BaseLockableStateJSON - err := json.Unmarshal(data, &j) +func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) { + lockable_state, err := RestoreBaseLockableState(ctx, j.LockableState, loaded_nodes) if err != nil { return nil, err } + lockable_state._type = "thread_state" - var owner Lockable = nil - if j.Owner != nil { - o, err := LoadNodeRecurse(ctx, *j.Owner, loaded_nodes) + state := BaseThreadState{ + BaseLockableState: *lockable_state, + parent: nil, + children: make([]Thread, len(j.Children)), + child_info: map[NodeID]ThreadInfo{}, + InfoType: nil, + running: false, + } + + if j.Parent != nil { + p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes) if err != nil { return nil, err } - o_l, ok := o.(Lockable) + p_t, ok := p.(Thread) if ok == false { return nil, err } - owner = o_l + state.owner = p_t } + i := 0 + for id, info := range(j.Children) { + child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes) + if err != nil { + return nil, err + } + child_t, ok := child_node.(Thread) + if ok == false { + return nil, fmt.Errorf("%+v is not a Thread as expected", child_node) + } + state.children[i] = child_t + state.child_info[id] = info + i++ + } + + return &state, nil +} + +func LoadBaseThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){ + var j BaseThreadStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + state, err := RestoreBaseThreadState(ctx, j, loaded_nodes) + if err != nil { + return nil, err + } + + return state, nil +} + +func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { + // call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id + base_node := RestoreNode(ctx, id) + lockable := BaseLockable{ + BaseNode: base_node, + } + + return &lockable, nil +} + +func RestoreBaseLockableState(ctx * GraphContext, j BaseLockableStateJSON, loaded_nodes NodeMap) (*BaseLockableState, error) { state := BaseLockableState{ _type: "base_lockable", name: j.Name, - owner: owner, + owner: nil, dependencies: make([]Lockable, len(j.Dependencies)), requirements: make([]Lockable, len(j.Requirements)), locks_held: map[NodeID]Lockable{}, } + if j.Owner != nil { + o, err := LoadNodeRecurse(ctx, *j.Owner, loaded_nodes) + if err != nil { + return nil, err + } + o_l, ok := o.(Lockable) + if ok == false { + return nil, err + } + state.owner = o_l + } + for i, dep := range(j.Dependencies) { dep_node, err := LoadNodeRecurse(ctx, dep, loaded_nodes) if err != nil { @@ -616,62 +685,23 @@ func LoadBaseLockableState(ctx * GraphContext, id NodeID, data []byte, loaded_no } state.locks_held[l_id] = h_l } - return &state, nil -} - -func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { - // Initialize an empty list of loaded nodes, then start loading them from id - loaded_nodes := map[NodeID]GraphNode{} - return LoadNodeRecurse(ctx, id, loaded_nodes) -} -type DBJSONBase struct { - Type string `json:"type"` + return &state, nil } -// Check if a node is already loaded, load it's state bytes from the DB and parse the type if it's not already loaded -// Call the node load function related to the type, which will call this parse function recusively as needed -func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]GraphNode) (GraphNode, error) { - node, exists := loaded_nodes[id] - if exists == false { - state_bytes, err := ReadDBState(ctx, id) - if err != nil { - return nil, err - } - - var base DBJSONBase - err = json.Unmarshal(state_bytes, &base) - if err != nil { - return nil, err - } - - ctx.Log.Logf("graph", "GRAPH_DB_LOAD: %s(%s)", base.Type, id) - - node_fn, exists := ctx.NodeLoadFuncs[base.Type] - if exists == false { - return nil, fmt.Errorf("%s is not a known node type", base.Type) - } - - node, err = node_fn(ctx, id) - if err != nil { - return nil, err - } - - loaded_nodes[id] = node - - state_fn, exists := ctx.StateLoadFuncs[base.Type] - if exists == false { - return nil, fmt.Errorf("%s is not a known node state type", base.Type) - } - - state, err := state_fn(ctx, id, state_bytes, loaded_nodes) - if err != nil { - return nil, err - } +func LoadBaseLockableState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){ + var j BaseLockableStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } - node.SetState(state) + state, err := RestoreBaseLockableState(ctx, j, loaded_nodes) + if err != nil { + return nil, err } - return node, nil + + return state, nil } func NewSimpleBaseLockable(ctx * GraphContext, name string, requirements []Lockable) (*BaseLockable, error) { diff --git a/lockable_test.go b/lockable_test.go index 14b1a8d..00d2768 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -4,15 +4,43 @@ import ( "testing" "fmt" "time" + "encoding/json" ) func TestNewSimpleBaseLockable(t * testing.T) { ctx := testContext(t) - r1, err := NewSimpleBaseLockable(ctx, "Test lockable 1", []Lockable{}) + l1, err := NewSimpleBaseLockable(ctx, "Test lockable 1", []Lockable{}) + fatalErr(t, err) + + l2, err := NewSimpleBaseLockable(ctx, "Test lockable 2", []Lockable{l1}) fatalErr(t, err) - _, err = NewSimpleBaseLockable(ctx, "Test lockable 2", []Lockable{r1}) + err = UseStates(ctx, []GraphNode{l1, l2}, func(states NodeStateMap) error { + l1_state := states[l1.ID()].(LockableState) + l2_state := states[l2.ID()].(LockableState) + + l1_deps := len(l1_state.Dependencies()) + if l1_deps != 1 { + return fmt.Errorf("l1 has wront amount of dependencies %d/1", l1_deps) + } + + l1_dep1 := l1_state.Dependencies()[0] + if l1_dep1.ID() != l2.ID() { + return fmt.Errorf("Wrong dependency for l1, %s instead of %s", l1_dep1.ID(), l2.ID()) + } + + l2_reqs := len(l2_state.Requirements()) + if l2_reqs != 1 { + return fmt.Errorf("l2 has wrong amount of requirements %d/1", l2_reqs) + } + + l2_req1 := l2_state.Requirements()[0] + if l2_req1.ID() != l1.ID() { + return fmt.Errorf("Wrong requirement for l2, %s instead of %s", l2_req1.ID(), l1.ID()) + } + return nil + }) fatalErr(t, err) } @@ -334,13 +362,18 @@ func TestLockableDependencyOverlap(t * testing.T) { } func TestLockableDBLoad(t * testing.T){ - ctx := testContext(t) + ctx := logTestContext(t, []string{"db"}) l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) fatalErr(t, err) l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{}) fatalErr(t, err) l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) fatalErr(t, err) + err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error { + ser, err := json.MarshalIndent(states[l3.ID()], "", " ") + fmt.Printf("\n%s\n\n", ser) + return err + }) l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3}) fatalErr(t, err) _, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4}) @@ -352,6 +385,11 @@ func TestLockableDBLoad(t * testing.T){ return err }) fatalErr(t, err) + err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error { + ser, err := json.MarshalIndent(states[l3.ID()], "", " ") + fmt.Printf("\n%s\n\n", ser) + return err + }) _, err = LoadNode(ctx, l3.ID()) fatalErr(t, err) diff --git a/thread.go b/thread.go index 36d25e9..06181ff 100644 --- a/thread.go +++ b/thread.go @@ -72,10 +72,10 @@ type BaseThreadState struct { type BaseThreadStateJSON struct { Parent *NodeID `json:"parent"` Children map[NodeID]interface{} `json:"children"` - LockableState *BaseLockableState `json:"lockable"` + LockableState BaseLockableStateJSON `json:"lockable"` } -func (state * BaseThreadState) MarshalJSON() ([]byte, error) { +func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON { children := map[NodeID]interface{}{} for _, child := range(state.children) { children[child.ID()] = state.child_info[child.ID()] @@ -87,11 +87,18 @@ func (state * BaseThreadState) MarshalJSON() ([]byte, error) { parent_id = &new_str } - return json.Marshal(&BaseThreadStateJSON{ + lockable_state := SaveBaseLockableState(&state.BaseLockableState) + + return BaseThreadStateJSON{ Parent: parent_id, Children: children, - LockableState: &state.BaseLockableState, - }) + LockableState: lockable_state, + } +} + +func (state * BaseThreadState) MarshalJSON() ([]byte, error) { + thread_state := SaveBaseThreadState(state) + return json.Marshal(&thread_state) } func (state * BaseThreadState) Start() error { diff --git a/thread_test.go b/thread_test.go index 401363b..5b101e3 100644 --- a/thread_test.go +++ b/thread_test.go @@ -56,10 +56,3 @@ func TestThreadWithRequirement(t * testing.T) { }) fatalErr(t, err) } - -func TestCustomThreadState(t * testing.T ) { - ctx := testContext(t) - - _, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) - fatalErr(t, err) -}