diff --git a/gql_graph.go b/gql_graph.go index 9a50d25..e417ea2 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -824,13 +824,13 @@ func GQLMutationSendUpdate() *graphql.Field { if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } + SendUpdate(ctx, node, signal, states) return nil }) if err != nil { return nil, err } - SendUpdate(ctx, node, signal) return signal, nil }, } diff --git a/gql_test.go b/gql_test.go index a01606a..2478b08 100644 --- a/gql_test.go +++ b/gql_test.go @@ -28,7 +28,10 @@ func TestGQLThread(t * testing.T) { go func(thread Thread){ time.Sleep(10*time.Millisecond) - SendUpdate(ctx, thread, CancelSignal(nil)) + err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { + SendUpdate(ctx, thread, CancelSignal(nil), states) + return nil + }) }(gql_thread) err = RunThread(ctx, gql_thread, "start") @@ -36,7 +39,7 @@ func TestGQLThread(t * testing.T) { } func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{"thread"}) + ctx := logTestContext(t, []string{"thread", "update", "gql"}) l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) fatalErr(t, err) @@ -48,14 +51,19 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) info := NewGQLThreadInfo(true, "start", "restore") - err = LinkThreads(ctx, gql, t1, &info) + err = UpdateStates(ctx, []GraphNode{gql, t1}, func(nodes NodeMap) error { + return LinkThreads(ctx, gql, t1, &info) + }) fatalErr(t, err) - - SendUpdate(ctx, gql, CancelSignal(nil)) + err = UseStates(ctx, []GraphNode{gql}, func(states NodeStateMap) error { + SendUpdate(ctx, gql, NewSignal(t1, "child_added"), states) + SendUpdate(ctx, gql, CancelSignal(nil), states) + return nil + }) err = RunThread(ctx, gql, "start") fatalErr(t, err) - (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Dicn't received update_done on t1 from t1") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Didn't receive thread_done from t1 on t1") err = UseStates(ctx, []GraphNode{gql, t1}, func(states NodeStateMap) error { ser1, err := json.MarshalIndent(states[gql.ID()], "", " ") @@ -80,10 +88,10 @@ func TestGQLDBLoad(t * testing.T) { fmt.Printf("\n%s\n\n", ser) return err }) + SendUpdate(ctx, gql_loaded, CancelSignal(nil), states) return err }) - SendUpdate(ctx, gql_loaded, CancelSignal(nil)) err = RunThread(ctx, gql_loaded.(Thread), "restore") fatalErr(t, err) (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1_loaded, 100*time.Millisecond, "Dicn't received update_done on t1_loaded from t1_loaded") diff --git a/graph.go b/graph.go index 260aa24..35b39e2 100644 --- a/graph.go +++ b/graph.go @@ -431,7 +431,7 @@ type GraphNode interface { // Signal propagation function for listener channels UpdateListeners(ctx * GraphContext, update GraphSignal) // Signal propagation function for connected nodes(defined in state) - PropagateUpdate(ctx * GraphContext, update GraphSignal) + PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap) // Get an update channel for the node to be notified of signals UpdateChannel(buffer int) chan GraphSignal @@ -593,6 +593,14 @@ func checkForDuplicate(nodes []GraphNode) error { return nil } +func NodeList[K GraphNode](list []K) []GraphNode { + nodes := make([]GraphNode, len(list)) + for i, node := range(list) { + nodes[i] = node + } + return nodes +} + type NodeStateMap map[NodeID]NodeState type NodeMap map[NodeID]GraphNode type StatesFn func(states NodeStateMap) error @@ -675,7 +683,7 @@ func (node * BaseNode) UpdateListeners(ctx * GraphContext, update GraphSignal) { } } -func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal) { +func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap) { } func (node * BaseNode) RegisterChannel(listener chan GraphSignal) { @@ -710,7 +718,7 @@ func (node * BaseNode) UpdateChannel(buffer int) chan GraphSignal { } // Propogate a signal starting at a node -func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) { +func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal, states NodeStateMap) { if node == nil { panic("Cannot start an update from no node") } @@ -718,6 +726,6 @@ func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) { ctx.Log.Logf("update", "UPDATE %s <- %s: %+v", node.ID(), signal.Source(), signal) node.UpdateListeners(ctx, signal) - node.PropagateUpdate(ctx, signal) + node.PropagateUpdate(ctx, signal, states) } diff --git a/lockable.go b/lockable.go index 968d656..94f87a3 100644 --- a/lockable.go +++ b/lockable.go @@ -313,35 +313,42 @@ type Lockable interface { CanUnlock(node GraphNode, state LockableState) error } -func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { - UseStates(ctx, []GraphNode{lockable}, func(states NodeStateMap) error { - lockable_state := states[lockable.ID()].(LockableState) - if signal.Direction() == Up { - // Child->Parent, lockable updates dependency lockables - owner_sent := false - for _, dependency := range lockable_state.Dependencies() { - SendUpdate(ctx, dependency, signal) +// lockable's state must already be locked for read +func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) { + lockable_state := states[lockable.ID()].(LockableState) + if signal.Direction() == Up { + // Child->Parent, lockable updates dependency lockables + owner_sent := false + UseMoreStates(ctx, NodeList(lockable_state.Dependencies()), states, func(states NodeStateMap) error { + for _, dependency := range(lockable_state.Dependencies()){ + SendUpdate(ctx, dependency, signal, states) if lockable_state.Owner() != nil { if dependency.ID() != lockable_state.Owner().ID() { owner_sent = true } } } - if lockable_state.Owner() != nil && owner_sent == false { - SendUpdate(ctx, lockable_state.Owner(), signal) - } - } else if signal.Direction() == Down { - // Parent->Child, lockable updates lock holder + return nil + }) + if lockable_state.Owner() != nil && owner_sent == false { + UseMoreStates(ctx, []GraphNode{lockable_state.Owner()}, states, func(states NodeStateMap) error { + SendUpdate(ctx, lockable_state.Owner(), signal, states) + return nil + }) + } + } else if signal.Direction() == Down { + // Parent->Child, lockable updates lock holder + UseMoreStates(ctx, NodeList(lockable_state.Requirements()), states, func(states NodeStateMap) error { for _, requirement := range(lockable_state.Requirements()) { - SendUpdate(ctx, requirement, signal) + SendUpdate(ctx, requirement, signal, states) } + return nil + }) - } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) - } - return nil - }) + } else if signal.Direction() == Direct { + } else { + panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + } } func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID, nodes NodeMap) bool { diff --git a/lockable_test.go b/lockable_test.go index a9faae3..3de88d5 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -279,7 +279,10 @@ func TestLockableSimpleUpdate(t * testing.T) { update_channel := l1.UpdateChannel(0) go func() { - SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update")) + UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error { + SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update"), states) + return nil + }) }() (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1") @@ -300,14 +303,17 @@ func TestLockableDownUpdate(t * testing.T) { update_channel := l1.UpdateChannel(0) go func() { - SendUpdate(ctx, l2, NewDownSignal(l2, "test_update")) + UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error { + SendUpdate(ctx, l2, NewDownSignal(l2, "test_update"), states) + return nil + }) }() (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") } func TestLockableUpUpdate(t * testing.T) { - ctx := testContext(t) + ctx := logTestContext(t, []string{"test", "update"}) l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) fatalErr(t, err) @@ -321,7 +327,10 @@ func TestLockableUpUpdate(t * testing.T) { update_channel := l3.UpdateChannel(0) go func() { - SendUpdate(ctx, l2, NewSignal(l2, "test_update")) + UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error { + SendUpdate(ctx, l2, NewSignal(l2, "test_update"), states) + return nil + }) }() (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") @@ -339,7 +348,10 @@ func TestOwnerNotUpdatedTwice(t * testing.T) { update_channel := l2.UpdateChannel(0) go func() { - SendUpdate(ctx, l1, NewSignal(l1, "test_update")) + UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error { + SendUpdate(ctx, l1, NewSignal(l1, "test_update"), states) + return nil + }) }() (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") diff --git a/thread.go b/thread.go index 7d36797..4f2840c 100644 --- a/thread.go +++ b/thread.go @@ -10,35 +10,44 @@ import ( ) // Update the threads listeners, and notify the parent to do the same -func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { - UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { - thread_state := states[thread.ID()].(ThreadState) - if signal.Direction() == Up { - // Child->Parent, thread updates parent and connected requirement - if thread_state.Parent() != nil { - SendUpdate(ctx, thread_state.Parent(), signal) - } +func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) { + thread_state := states[thread.ID()].(ThreadState) + if signal.Direction() == Up { + // Child->Parent, thread updates parent and connected requirement + if thread_state.Parent() != nil { + UseMoreStates(ctx, []GraphNode{thread_state.Parent()}, states, func(states NodeStateMap) (error) { + SendUpdate(ctx, thread_state.Parent(), signal, states) + return nil + }) + } + UseMoreStates(ctx, NodeList(thread_state.Dependencies()), states, func(states NodeStateMap) (error) { for _, dep := range(thread_state.Dependencies()) { - SendUpdate(ctx, dep, signal) + SendUpdate(ctx, dep, signal, states) } - } else if signal.Direction() == Down { - // Parent->Child, updates children and dependencies + return nil + }) + } else if signal.Direction() == Down { + // Parent->Child, updates children and dependencies + UseMoreStates(ctx, NodeList(thread_state.Children()), states, func(states NodeStateMap) (error) { for _, child := range(thread_state.Children()) { - SendUpdate(ctx, child, signal) + SendUpdate(ctx, child, signal, states) } + return nil + }) + UseMoreStates(ctx, NodeList(thread_state.Requirements()), states, func(states NodeStateMap) (error) { for _, requirement := range(thread_state.Requirements()) { - SendUpdate(ctx, requirement, signal) + SendUpdate(ctx, requirement, signal, states) } - } else if signal.Direction() == Direct { + return nil + }) + } else if signal.Direction() == Direct { - } else { - panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) - } + } else { + panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) + } - return nil - }) thread.signal <- signal } @@ -379,8 +388,6 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf return err } - SendUpdate(ctx, thread, NewSignal(child, "child_added")) - return nil } @@ -507,7 +514,10 @@ func RunThread(ctx * GraphContext, thread Thread, first_action string) error { return err } - SendUpdate(ctx, thread, NewSignal(thread, "thread_done")) + err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { + SendUpdate(ctx, thread, NewSignal(thread, "thread_done"), states) + return nil + }) ctx.Log.Logf("thread", "THREAD_RUN_DONE: %s", thread.ID()) diff --git a/thread_test.go b/thread_test.go index 3a9d589..6d1d129 100644 --- a/thread_test.go +++ b/thread_test.go @@ -15,7 +15,10 @@ func TestNewThread(t * testing.T) { go func(thread Thread) { time.Sleep(10*time.Millisecond) - SendUpdate(ctx, t1, CancelSignal(nil)) + UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { + SendUpdate(ctx, t1, CancelSignal(nil), states) + return nil + }) }(t1) err = RunThread(ctx, t1, "start") @@ -41,7 +44,10 @@ func TestThreadWithRequirement(t * testing.T) { go func (thread Thread) { time.Sleep(10*time.Millisecond) - SendUpdate(ctx, t1, CancelSignal(nil)) + UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { + SendUpdate(ctx, t1, CancelSignal(nil), states) + return nil + }) }(t1) fatalErr(t, err) @@ -67,7 +73,10 @@ func TestThreadDBLoad(t * testing.T) { fatalErr(t, err) - SendUpdate(ctx, t1, CancelSignal(nil)) + UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { + SendUpdate(ctx, t1, CancelSignal(nil), states) + return nil + }) err = RunThread(ctx, t1, "start") fatalErr(t, err)