Modified SendUpdate to require it to be called from inside a UseStates context.

graph-rework-2
noah metz 2023-07-04 18:45:23 -06:00
parent ce831af290
commit c42ca80d47
7 changed files with 116 additions and 62 deletions

@ -824,13 +824,13 @@ func GQLMutationSendUpdate() *graphql.Field {
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }
SendUpdate(ctx, node, signal, states)
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
SendUpdate(ctx, node, signal)
return signal, nil return signal, nil
}, },
} }

@ -28,7 +28,10 @@ func TestGQLThread(t * testing.T) {
go func(thread Thread){ go func(thread Thread){
time.Sleep(10*time.Millisecond) time.Sleep(10*time.Millisecond)
SendUpdate(ctx, thread, CancelSignal(nil)) err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error {
SendUpdate(ctx, thread, CancelSignal(nil), states)
return nil
})
}(gql_thread) }(gql_thread)
err = RunThread(ctx, gql_thread, "start") err = RunThread(ctx, gql_thread, "start")
@ -36,7 +39,7 @@ func TestGQLThread(t * testing.T) {
} }
func TestGQLDBLoad(t * testing.T) { func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"thread"}) ctx := logTestContext(t, []string{"thread", "update", "gql"})
l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
@ -48,14 +51,19 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
info := NewGQLThreadInfo(true, "start", "restore") info := NewGQLThreadInfo(true, "start", "restore")
err = LinkThreads(ctx, gql, t1, &info) err = UpdateStates(ctx, []GraphNode{gql, t1}, func(nodes NodeMap) error {
return LinkThreads(ctx, gql, t1, &info)
})
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{gql}, func(states NodeStateMap) error {
SendUpdate(ctx, gql, CancelSignal(nil)) SendUpdate(ctx, gql, NewSignal(t1, "child_added"), states)
SendUpdate(ctx, gql, CancelSignal(nil), states)
return nil
})
err = RunThread(ctx, gql, "start") err = RunThread(ctx, gql, "start")
fatalErr(t, err) fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Dicn't received update_done on t1 from t1") (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Didn't receive thread_done from t1 on t1")
err = UseStates(ctx, []GraphNode{gql, t1}, func(states NodeStateMap) error { err = UseStates(ctx, []GraphNode{gql, t1}, func(states NodeStateMap) error {
ser1, err := json.MarshalIndent(states[gql.ID()], "", " ") ser1, err := json.MarshalIndent(states[gql.ID()], "", " ")
@ -80,10 +88,10 @@ func TestGQLDBLoad(t * testing.T) {
fmt.Printf("\n%s\n\n", ser) fmt.Printf("\n%s\n\n", ser)
return err return err
}) })
SendUpdate(ctx, gql_loaded, CancelSignal(nil), states)
return err return err
}) })
SendUpdate(ctx, gql_loaded, CancelSignal(nil))
err = RunThread(ctx, gql_loaded.(Thread), "restore") err = RunThread(ctx, gql_loaded.(Thread), "restore")
fatalErr(t, err) fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1_loaded, 100*time.Millisecond, "Dicn't received update_done on t1_loaded from t1_loaded") (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1_loaded, 100*time.Millisecond, "Dicn't received update_done on t1_loaded from t1_loaded")

@ -431,7 +431,7 @@ type GraphNode interface {
// Signal propagation function for listener channels // Signal propagation function for listener channels
UpdateListeners(ctx * GraphContext, update GraphSignal) UpdateListeners(ctx * GraphContext, update GraphSignal)
// Signal propagation function for connected nodes(defined in state) // Signal propagation function for connected nodes(defined in state)
PropagateUpdate(ctx * GraphContext, update GraphSignal) PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap)
// Get an update channel for the node to be notified of signals // Get an update channel for the node to be notified of signals
UpdateChannel(buffer int) chan GraphSignal UpdateChannel(buffer int) chan GraphSignal
@ -593,6 +593,14 @@ func checkForDuplicate(nodes []GraphNode) error {
return nil return nil
} }
func NodeList[K GraphNode](list []K) []GraphNode {
nodes := make([]GraphNode, len(list))
for i, node := range(list) {
nodes[i] = node
}
return nodes
}
type NodeStateMap map[NodeID]NodeState type NodeStateMap map[NodeID]NodeState
type NodeMap map[NodeID]GraphNode type NodeMap map[NodeID]GraphNode
type StatesFn func(states NodeStateMap) error type StatesFn func(states NodeStateMap) error
@ -675,7 +683,7 @@ func (node * BaseNode) UpdateListeners(ctx * GraphContext, update GraphSignal) {
} }
} }
func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal) { func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap) {
} }
func (node * BaseNode) RegisterChannel(listener chan GraphSignal) { func (node * BaseNode) RegisterChannel(listener chan GraphSignal) {
@ -710,7 +718,7 @@ func (node * BaseNode) UpdateChannel(buffer int) chan GraphSignal {
} }
// Propogate a signal starting at a node // Propogate a signal starting at a node
func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) { func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal, states NodeStateMap) {
if node == nil { if node == nil {
panic("Cannot start an update from no node") panic("Cannot start an update from no node")
} }
@ -718,6 +726,6 @@ func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) {
ctx.Log.Logf("update", "UPDATE %s <- %s: %+v", node.ID(), signal.Source(), signal) ctx.Log.Logf("update", "UPDATE %s <- %s: %+v", node.ID(), signal.Source(), signal)
node.UpdateListeners(ctx, signal) node.UpdateListeners(ctx, signal)
node.PropagateUpdate(ctx, signal) node.PropagateUpdate(ctx, signal, states)
} }

@ -313,35 +313,42 @@ type Lockable interface {
CanUnlock(node GraphNode, state LockableState) error CanUnlock(node GraphNode, state LockableState) error
} }
func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { // lockable's state must already be locked for read
UseStates(ctx, []GraphNode{lockable}, func(states NodeStateMap) error { func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) {
lockable_state := states[lockable.ID()].(LockableState) lockable_state := states[lockable.ID()].(LockableState)
if signal.Direction() == Up { if signal.Direction() == Up {
// Child->Parent, lockable updates dependency lockables // Child->Parent, lockable updates dependency lockables
owner_sent := false owner_sent := false
for _, dependency := range lockable_state.Dependencies() { UseMoreStates(ctx, NodeList(lockable_state.Dependencies()), states, func(states NodeStateMap) error {
SendUpdate(ctx, dependency, signal) for _, dependency := range(lockable_state.Dependencies()){
SendUpdate(ctx, dependency, signal, states)
if lockable_state.Owner() != nil { if lockable_state.Owner() != nil {
if dependency.ID() != lockable_state.Owner().ID() { if dependency.ID() != lockable_state.Owner().ID() {
owner_sent = true owner_sent = true
} }
} }
} }
if lockable_state.Owner() != nil && owner_sent == false { return nil
SendUpdate(ctx, lockable_state.Owner(), signal) })
} if lockable_state.Owner() != nil && owner_sent == false {
} else if signal.Direction() == Down { UseMoreStates(ctx, []GraphNode{lockable_state.Owner()}, states, func(states NodeStateMap) error {
// Parent->Child, lockable updates lock holder SendUpdate(ctx, lockable_state.Owner(), signal, states)
return nil
})
}
} else if signal.Direction() == Down {
// Parent->Child, lockable updates lock holder
UseMoreStates(ctx, NodeList(lockable_state.Requirements()), states, func(states NodeStateMap) error {
for _, requirement := range(lockable_state.Requirements()) { for _, requirement := range(lockable_state.Requirements()) {
SendUpdate(ctx, requirement, signal) SendUpdate(ctx, requirement, signal, states)
} }
return nil
})
} else if signal.Direction() == Direct { } else if signal.Direction() == Direct {
} else { } else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
} }
return nil
})
} }
func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID, nodes NodeMap) bool { func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID, nodes NodeMap) bool {

@ -279,7 +279,10 @@ func TestLockableSimpleUpdate(t * testing.T) {
update_channel := l1.UpdateChannel(0) update_channel := l1.UpdateChannel(0)
go func() { go func() {
SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update")) UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error {
SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update"), states)
return nil
})
}() }()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1") (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1")
@ -300,14 +303,17 @@ func TestLockableDownUpdate(t * testing.T) {
update_channel := l1.UpdateChannel(0) update_channel := l1.UpdateChannel(0)
go func() { go func() {
SendUpdate(ctx, l2, NewDownSignal(l2, "test_update")) UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error {
SendUpdate(ctx, l2, NewDownSignal(l2, "test_update"), states)
return nil
})
}() }()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2")
} }
func TestLockableUpUpdate(t * testing.T) { func TestLockableUpUpdate(t * testing.T) {
ctx := testContext(t) ctx := logTestContext(t, []string{"test", "update"})
l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
@ -321,7 +327,10 @@ func TestLockableUpUpdate(t * testing.T) {
update_channel := l3.UpdateChannel(0) update_channel := l3.UpdateChannel(0)
go func() { go func() {
SendUpdate(ctx, l2, NewSignal(l2, "test_update")) UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error {
SendUpdate(ctx, l2, NewSignal(l2, "test_update"), states)
return nil
})
}() }()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2")
@ -339,7 +348,10 @@ func TestOwnerNotUpdatedTwice(t * testing.T) {
update_channel := l2.UpdateChannel(0) update_channel := l2.UpdateChannel(0)
go func() { go func() {
SendUpdate(ctx, l1, NewSignal(l1, "test_update")) UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error {
SendUpdate(ctx, l1, NewSignal(l1, "test_update"), states)
return nil
})
}() }()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1")

@ -10,35 +10,44 @@ import (
) )
// Update the threads listeners, and notify the parent to do the same // Update the threads listeners, and notify the parent to do the same
func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) {
UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { thread_state := states[thread.ID()].(ThreadState)
thread_state := states[thread.ID()].(ThreadState) if signal.Direction() == Up {
if signal.Direction() == Up { // Child->Parent, thread updates parent and connected requirement
// Child->Parent, thread updates parent and connected requirement if thread_state.Parent() != nil {
if thread_state.Parent() != nil { UseMoreStates(ctx, []GraphNode{thread_state.Parent()}, states, func(states NodeStateMap) (error) {
SendUpdate(ctx, thread_state.Parent(), signal) SendUpdate(ctx, thread_state.Parent(), signal, states)
} return nil
})
}
UseMoreStates(ctx, NodeList(thread_state.Dependencies()), states, func(states NodeStateMap) (error) {
for _, dep := range(thread_state.Dependencies()) { for _, dep := range(thread_state.Dependencies()) {
SendUpdate(ctx, dep, signal) SendUpdate(ctx, dep, signal, states)
} }
} else if signal.Direction() == Down { return nil
// Parent->Child, updates children and dependencies })
} else if signal.Direction() == Down {
// Parent->Child, updates children and dependencies
UseMoreStates(ctx, NodeList(thread_state.Children()), states, func(states NodeStateMap) (error) {
for _, child := range(thread_state.Children()) { for _, child := range(thread_state.Children()) {
SendUpdate(ctx, child, signal) SendUpdate(ctx, child, signal, states)
} }
return nil
})
UseMoreStates(ctx, NodeList(thread_state.Requirements()), states, func(states NodeStateMap) (error) {
for _, requirement := range(thread_state.Requirements()) { for _, requirement := range(thread_state.Requirements()) {
SendUpdate(ctx, requirement, signal) SendUpdate(ctx, requirement, signal, states)
} }
} else if signal.Direction() == Direct { return nil
})
} else if signal.Direction() == Direct {
} else { } else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
} }
return nil
})
thread.signal <- signal thread.signal <- signal
} }
@ -379,8 +388,6 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf
return err return err
} }
SendUpdate(ctx, thread, NewSignal(child, "child_added"))
return nil return nil
} }
@ -507,7 +514,10 @@ func RunThread(ctx * GraphContext, thread Thread, first_action string) error {
return err return err
} }
SendUpdate(ctx, thread, NewSignal(thread, "thread_done")) err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error {
SendUpdate(ctx, thread, NewSignal(thread, "thread_done"), states)
return nil
})
ctx.Log.Logf("thread", "THREAD_RUN_DONE: %s", thread.ID()) ctx.Log.Logf("thread", "THREAD_RUN_DONE: %s", thread.ID())

@ -15,7 +15,10 @@ func TestNewThread(t * testing.T) {
go func(thread Thread) { go func(thread Thread) {
time.Sleep(10*time.Millisecond) time.Sleep(10*time.Millisecond)
SendUpdate(ctx, t1, CancelSignal(nil)) UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states)
return nil
})
}(t1) }(t1)
err = RunThread(ctx, t1, "start") err = RunThread(ctx, t1, "start")
@ -41,7 +44,10 @@ func TestThreadWithRequirement(t * testing.T) {
go func (thread Thread) { go func (thread Thread) {
time.Sleep(10*time.Millisecond) time.Sleep(10*time.Millisecond)
SendUpdate(ctx, t1, CancelSignal(nil)) UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states)
return nil
})
}(t1) }(t1)
fatalErr(t, err) fatalErr(t, err)
@ -67,7 +73,10 @@ func TestThreadDBLoad(t * testing.T) {
fatalErr(t, err) fatalErr(t, err)
SendUpdate(ctx, t1, CancelSignal(nil)) UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states)
return nil
})
err = RunThread(ctx, t1, "start") err = RunThread(ctx, t1, "start")
fatalErr(t, err) fatalErr(t, err)