Modified SendUpdate to require it to be called from inside a UseStates context.

graph-rework-2
noah metz 2023-07-04 18:45:23 -06:00
parent ce831af290
commit c42ca80d47
7 changed files with 116 additions and 62 deletions

@ -824,13 +824,13 @@ func GQLMutationSendUpdate() *graphql.Field {
if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
}
SendUpdate(ctx, node, signal, states)
return nil
})
if err != nil {
return nil, err
}
SendUpdate(ctx, node, signal)
return signal, nil
},
}

@ -28,7 +28,10 @@ func TestGQLThread(t * testing.T) {
go func(thread Thread){
time.Sleep(10*time.Millisecond)
SendUpdate(ctx, thread, CancelSignal(nil))
err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error {
SendUpdate(ctx, thread, CancelSignal(nil), states)
return nil
})
}(gql_thread)
err = RunThread(ctx, gql_thread, "start")
@ -36,7 +39,7 @@ func TestGQLThread(t * testing.T) {
}
func TestGQLDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{"thread"})
ctx := logTestContext(t, []string{"thread", "update", "gql"})
l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
@ -48,14 +51,19 @@ func TestGQLDBLoad(t * testing.T) {
fatalErr(t, err)
info := NewGQLThreadInfo(true, "start", "restore")
err = LinkThreads(ctx, gql, t1, &info)
err = UpdateStates(ctx, []GraphNode{gql, t1}, func(nodes NodeMap) error {
return LinkThreads(ctx, gql, t1, &info)
})
fatalErr(t, err)
SendUpdate(ctx, gql, CancelSignal(nil))
err = UseStates(ctx, []GraphNode{gql}, func(states NodeStateMap) error {
SendUpdate(ctx, gql, NewSignal(t1, "child_added"), states)
SendUpdate(ctx, gql, CancelSignal(nil), states)
return nil
})
err = RunThread(ctx, gql, "start")
fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Dicn't received update_done on t1 from t1")
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Didn't receive thread_done from t1 on t1")
err = UseStates(ctx, []GraphNode{gql, t1}, func(states NodeStateMap) error {
ser1, err := json.MarshalIndent(states[gql.ID()], "", " ")
@ -80,10 +88,10 @@ func TestGQLDBLoad(t * testing.T) {
fmt.Printf("\n%s\n\n", ser)
return err
})
SendUpdate(ctx, gql_loaded, CancelSignal(nil), states)
return err
})
SendUpdate(ctx, gql_loaded, CancelSignal(nil))
err = RunThread(ctx, gql_loaded.(Thread), "restore")
fatalErr(t, err)
(*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1_loaded, 100*time.Millisecond, "Dicn't received update_done on t1_loaded from t1_loaded")

@ -431,7 +431,7 @@ type GraphNode interface {
// Signal propagation function for listener channels
UpdateListeners(ctx * GraphContext, update GraphSignal)
// Signal propagation function for connected nodes(defined in state)
PropagateUpdate(ctx * GraphContext, update GraphSignal)
PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap)
// Get an update channel for the node to be notified of signals
UpdateChannel(buffer int) chan GraphSignal
@ -593,6 +593,14 @@ func checkForDuplicate(nodes []GraphNode) error {
return nil
}
func NodeList[K GraphNode](list []K) []GraphNode {
nodes := make([]GraphNode, len(list))
for i, node := range(list) {
nodes[i] = node
}
return nodes
}
type NodeStateMap map[NodeID]NodeState
type NodeMap map[NodeID]GraphNode
type StatesFn func(states NodeStateMap) error
@ -675,7 +683,7 @@ func (node * BaseNode) UpdateListeners(ctx * GraphContext, update GraphSignal) {
}
}
func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal) {
func (node * BaseNode) PropagateUpdate(ctx * GraphContext, update GraphSignal, states NodeStateMap) {
}
func (node * BaseNode) RegisterChannel(listener chan GraphSignal) {
@ -710,7 +718,7 @@ func (node * BaseNode) UpdateChannel(buffer int) chan GraphSignal {
}
// Propogate a signal starting at a node
func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) {
func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal, states NodeStateMap) {
if node == nil {
panic("Cannot start an update from no node")
}
@ -718,6 +726,6 @@ func SendUpdate(ctx * GraphContext, node GraphNode, signal GraphSignal) {
ctx.Log.Logf("update", "UPDATE %s <- %s: %+v", node.ID(), signal.Source(), signal)
node.UpdateListeners(ctx, signal)
node.PropagateUpdate(ctx, signal)
node.PropagateUpdate(ctx, signal, states)
}

@ -313,35 +313,42 @@ type Lockable interface {
CanUnlock(node GraphNode, state LockableState) error
}
func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) {
UseStates(ctx, []GraphNode{lockable}, func(states NodeStateMap) error {
lockable_state := states[lockable.ID()].(LockableState)
if signal.Direction() == Up {
// Child->Parent, lockable updates dependency lockables
owner_sent := false
for _, dependency := range lockable_state.Dependencies() {
SendUpdate(ctx, dependency, signal)
// lockable's state must already be locked for read
func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) {
lockable_state := states[lockable.ID()].(LockableState)
if signal.Direction() == Up {
// Child->Parent, lockable updates dependency lockables
owner_sent := false
UseMoreStates(ctx, NodeList(lockable_state.Dependencies()), states, func(states NodeStateMap) error {
for _, dependency := range(lockable_state.Dependencies()){
SendUpdate(ctx, dependency, signal, states)
if lockable_state.Owner() != nil {
if dependency.ID() != lockable_state.Owner().ID() {
owner_sent = true
}
}
}
if lockable_state.Owner() != nil && owner_sent == false {
SendUpdate(ctx, lockable_state.Owner(), signal)
}
} else if signal.Direction() == Down {
// Parent->Child, lockable updates lock holder
return nil
})
if lockable_state.Owner() != nil && owner_sent == false {
UseMoreStates(ctx, []GraphNode{lockable_state.Owner()}, states, func(states NodeStateMap) error {
SendUpdate(ctx, lockable_state.Owner(), signal, states)
return nil
})
}
} else if signal.Direction() == Down {
// Parent->Child, lockable updates lock holder
UseMoreStates(ctx, NodeList(lockable_state.Requirements()), states, func(states NodeStateMap) error {
for _, requirement := range(lockable_state.Requirements()) {
SendUpdate(ctx, requirement, signal)
SendUpdate(ctx, requirement, signal, states)
}
return nil
})
} else if signal.Direction() == Direct {
} else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
}
return nil
})
} else if signal.Direction() == Direct {
} else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
}
}
func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID, nodes NodeMap) bool {

@ -279,7 +279,10 @@ func TestLockableSimpleUpdate(t * testing.T) {
update_channel := l1.UpdateChannel(0)
go func() {
SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update"))
UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error {
SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update"), states)
return nil
})
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1")
@ -300,14 +303,17 @@ func TestLockableDownUpdate(t * testing.T) {
update_channel := l1.UpdateChannel(0)
go func() {
SendUpdate(ctx, l2, NewDownSignal(l2, "test_update"))
UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error {
SendUpdate(ctx, l2, NewDownSignal(l2, "test_update"), states)
return nil
})
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2")
}
func TestLockableUpUpdate(t * testing.T) {
ctx := testContext(t)
ctx := logTestContext(t, []string{"test", "update"})
l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
@ -321,7 +327,10 @@ func TestLockableUpUpdate(t * testing.T) {
update_channel := l3.UpdateChannel(0)
go func() {
SendUpdate(ctx, l2, NewSignal(l2, "test_update"))
UseStates(ctx, []GraphNode{l2}, func(states NodeStateMap) error {
SendUpdate(ctx, l2, NewSignal(l2, "test_update"), states)
return nil
})
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2")
@ -339,7 +348,10 @@ func TestOwnerNotUpdatedTwice(t * testing.T) {
update_channel := l2.UpdateChannel(0)
go func() {
SendUpdate(ctx, l1, NewSignal(l1, "test_update"))
UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) error {
SendUpdate(ctx, l1, NewSignal(l1, "test_update"), states)
return nil
})
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1")

@ -10,35 +10,44 @@ import (
)
// Update the threads listeners, and notify the parent to do the same
func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal) {
UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) {
thread_state := states[thread.ID()].(ThreadState)
if signal.Direction() == Up {
// Child->Parent, thread updates parent and connected requirement
if thread_state.Parent() != nil {
SendUpdate(ctx, thread_state.Parent(), signal)
}
func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) {
thread_state := states[thread.ID()].(ThreadState)
if signal.Direction() == Up {
// Child->Parent, thread updates parent and connected requirement
if thread_state.Parent() != nil {
UseMoreStates(ctx, []GraphNode{thread_state.Parent()}, states, func(states NodeStateMap) (error) {
SendUpdate(ctx, thread_state.Parent(), signal, states)
return nil
})
}
UseMoreStates(ctx, NodeList(thread_state.Dependencies()), states, func(states NodeStateMap) (error) {
for _, dep := range(thread_state.Dependencies()) {
SendUpdate(ctx, dep, signal)
SendUpdate(ctx, dep, signal, states)
}
} else if signal.Direction() == Down {
// Parent->Child, updates children and dependencies
return nil
})
} else if signal.Direction() == Down {
// Parent->Child, updates children and dependencies
UseMoreStates(ctx, NodeList(thread_state.Children()), states, func(states NodeStateMap) (error) {
for _, child := range(thread_state.Children()) {
SendUpdate(ctx, child, signal)
SendUpdate(ctx, child, signal, states)
}
return nil
})
UseMoreStates(ctx, NodeList(thread_state.Requirements()), states, func(states NodeStateMap) (error) {
for _, requirement := range(thread_state.Requirements()) {
SendUpdate(ctx, requirement, signal)
SendUpdate(ctx, requirement, signal, states)
}
} else if signal.Direction() == Direct {
return nil
})
} else if signal.Direction() == Direct {
} else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
}
} else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
}
return nil
})
thread.signal <- signal
}
@ -379,8 +388,6 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf
return err
}
SendUpdate(ctx, thread, NewSignal(child, "child_added"))
return nil
}
@ -507,7 +514,10 @@ func RunThread(ctx * GraphContext, thread Thread, first_action string) error {
return err
}
SendUpdate(ctx, thread, NewSignal(thread, "thread_done"))
err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error {
SendUpdate(ctx, thread, NewSignal(thread, "thread_done"), states)
return nil
})
ctx.Log.Logf("thread", "THREAD_RUN_DONE: %s", thread.ID())

@ -15,7 +15,10 @@ func TestNewThread(t * testing.T) {
go func(thread Thread) {
time.Sleep(10*time.Millisecond)
SendUpdate(ctx, t1, CancelSignal(nil))
UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states)
return nil
})
}(t1)
err = RunThread(ctx, t1, "start")
@ -41,7 +44,10 @@ func TestThreadWithRequirement(t * testing.T) {
go func (thread Thread) {
time.Sleep(10*time.Millisecond)
SendUpdate(ctx, t1, CancelSignal(nil))
UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states)
return nil
})
}(t1)
fatalErr(t, err)
@ -67,7 +73,10 @@ func TestThreadDBLoad(t * testing.T) {
fatalErr(t, err)
SendUpdate(ctx, t1, CancelSignal(nil))
UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states)
return nil
})
err = RunThread(ctx, t1, "start")
fatalErr(t, err)