From b66fad2c8edffa2e06c71e74cbcf4d630bec3818 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 24 Jun 2023 19:48:59 -0600 Subject: [PATCH] Fixed update path and added more tests --- graph.go | 5 +++- graph_test.go | 14 +++++++-- lockable.go | 54 +++++++++++++++++---------------- lockable_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ thread.go | 53 ++++++++++++++++++++------------- thread_test.go | 36 ++++++++++++++++++++++ 6 files changed, 190 insertions(+), 49 deletions(-) diff --git a/graph.go b/graph.go index f9cef8f..43da196 100644 --- a/graph.go +++ b/graph.go @@ -192,6 +192,9 @@ type GraphNode interface { // Signal propagation function for connected nodes(defined in state) PropagateUpdate(ctx * GraphContext, update GraphSignal) + // Get an update channel for the node to be notified of signals + UpdateChannel(buffer int) chan GraphSignal + // Register and unregister a channel to propogate updates to RegisterChannel(listener chan GraphSignal) UnregisterChannel(listener chan GraphSignal) @@ -384,7 +387,7 @@ func (node * BaseNode) SignalChannel() chan GraphSignal { } // Create a new GraphSinal channel with a buffer of size buffer and register it to a node -func GetUpdateChannel(node * BaseNode, buffer int) chan GraphSignal { +func (node * BaseNode) UpdateChannel(buffer int) chan GraphSignal { new_listener := make(chan GraphSignal, buffer) node.RegisterChannel(new_listener) return new_listener diff --git a/graph_test.go b/graph_test.go index 6faed47..844c85a 100644 --- a/graph_test.go +++ b/graph_test.go @@ -17,10 +17,20 @@ func (t * GraphTester) WaitForValue(ctx * GraphContext, listener chan GraphSigna for true { select { case signal := <- listener: + if signal == nil { + ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) + t.Fatal(str) + } if signal.Type() == signal_type { ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) - if signal.Source() == source.ID() { - return signal + if source == nil { + if signal.Source() == "" { + return signal + } + } else { + if signal.Source() == source.ID() { + return signal + } } } case <-timeout_channel: diff --git a/lockable.go b/lockable.go index b212b22..75c6d3e 100644 --- a/lockable.go +++ b/lockable.go @@ -7,16 +7,16 @@ import ( // LockHolderState is the interface that any node that wants to posses locks must implement // -// ReturnLock returns the node that held the resource pointed to by ID before this node and -// removes the mapping from it's state, or nil if the resource was unlocked previously +// ReturnLock returns the node that held the lockable pointed to by ID before this node and +// removes the mapping from it's state, or nil if the lockable was unlocked previously // // AllowedToTakeLock returns true if the node pointed to by ID is allowed to take a lock from this node // -// RecordLockHolder records that resource_id needs to be passed back to lock_holder +// RecordLockHolder records that lockable_id needs to be passed back to lock_holder type LockHolderState interface { - ReturnLock(resource_id NodeID) GraphNode - AllowedToTakeLock(node_id NodeID, resource_id NodeID) bool - RecordLockHolder(resource_id NodeID, lock_holder GraphNode) + ReturnLock(lockable_id NodeID) GraphNode + AllowedToTakeLock(node_id NodeID, lockable_id NodeID) bool + RecordLockHolder(lockable_id NodeID, lock_holder GraphNode) } // LockableState is the interface that a lockables state must have to allow it to connect to the DAG @@ -104,31 +104,31 @@ func (state * BaseLockableState) Name() string { // Locks cannot be passed between base lockables, so the answer to // "who used to own this lock held by a base lockable" is always "nobody" -func (state * BaseLockHolderState) ReturnLock(resource_id NodeID) GraphNode { - node, exists := state.delegation_map[resource_id] +func (state * BaseLockHolderState) ReturnLock(lockable_id NodeID) GraphNode { + node, exists := state.delegation_map[lockable_id] if exists == false { - panic("Attempted to take a get the original lock holder of a resource we don't own") + panic("Attempted to take a get the original lock holder of a lockable we don't own") } - delete(state.delegation_map, resource_id) + delete(state.delegation_map, lockable_id) return node } // Nothing can take a lock from a base lockable either -func (state * BaseLockHolderState) AllowedToTakeLock(node_id NodeID, resource_id NodeID) bool { - _, exists := state.delegation_map[resource_id] +func (state * BaseLockHolderState) AllowedToTakeLock(node_id NodeID, lockable_id NodeID) bool { + _, exists := state.delegation_map[lockable_id] if exists == false { panic ("Trying to give away lock we don't own") } return false } -func (state * BaseLockHolderState) RecordLockHolder(resource_id NodeID, lock_holder GraphNode) { - _, exists := state.delegation_map[resource_id] +func (state * BaseLockHolderState) RecordLockHolder(lockable_id NodeID, lock_holder GraphNode) { + _, exists := state.delegation_map[lockable_id] if exists == true { - panic("Attempted to lock a resource we're already holding(lock cycle)") + panic("Attempted to lock a lockable we're already holding(lock cycle)") } - state.delegation_map[resource_id] = lock_holder + state.delegation_map[lockable_id] = lock_holder } func (state * BaseLockableState) Owner() GraphNode { @@ -215,25 +215,29 @@ type Lockable interface { Unlock(node GraphNode, state LockableState) error } -// Lockables propagate update up to multiple dependencies, and not downwards -// (subscriber to team won't get update to alliance, but subscriber to alliance will get update to team) func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { UseStates(ctx, []GraphNode{lockable}, func(states []NodeState) (interface{}, error){ lockable_state := states[0].(LockableState) if signal.Direction() == Up { // Child->Parent, lockable updates dependency lockables + owner_sent := false for _, dependency := range lockable_state.Dependencies() { SendUpdate(ctx, dependency, signal) + if lockable_state.Owner() != nil { + if dependency.ID() != lockable_state.Owner().ID() { + owner_sent = true + } + } } - } else if signal.Direction() == Down { - // Parent->Child, lockable updates lock holder - if lockable_state.Owner() != nil { + if lockable_state.Owner() != nil && owner_sent == false { SendUpdate(ctx, lockable_state.Owner(), signal) } - + } else if signal.Direction() == Down { + // Parent->Child, lockable updates lock holder for _, requirement := range(lockable_state.Requirements()) { SendUpdate(ctx, requirement, signal) } + } else if signal.Direction() == Direct { } else { panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) @@ -317,14 +321,14 @@ func UnlockLockable(ctx * GraphContext, lockable Lockable, node GraphNode, node_ func LockLockable(ctx * GraphContext, lockable Lockable, node GraphNode, node_state LockHolderState) error { if node == nil || lockable == nil { - panic("Cannot lock without a specified node and lockable") + return fmt.Errorf("Cannot lock without a specified node and lockable") } - ctx.Log.Logf("resource", "LOCKING: %s from %s", lockable.ID(), node.ID()) + ctx.Log.Logf("lockable", "LOCKING: %s from %s", lockable.ID(), node.ID()) _, err := UpdateStates(ctx, []GraphNode{lockable}, func(states []NodeState) ([]NodeState, interface{}, error) { if lockable.ID() == node.ID() { if node_state != nil { - panic("node_state must be nil if locking lockable from itself") + return nil, nil, fmt.Errorf("node_state must be nil if locking lockable from itself") } node_state = states[0].(LockHolderState) } diff --git a/lockable_test.go b/lockable_test.go index 9b8022b..829b483 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -4,6 +4,7 @@ import ( "testing" "fmt" "encoding/json" + "time" ) func TestNewLockable(t * testing.T) { @@ -231,3 +232,79 @@ func TestLockableLockTieredConflict(t * testing.T) { t.Fatal("Locked r3 which depends on r1 while r2 which depends on r1 is already locked") } } + +func TestLockableSimpleUpdate(t * testing.T) { + ctx := testContext(t) + + l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + update_channel := l1.UpdateChannel(0) + + go func() { + SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update")) + }() + + (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1") +} + +func TestLockableDownUpdate(t * testing.T) { + ctx := testContext(t) + + l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + l2, err := NewLockable(ctx, "Test Lockable 2", []Lockable{l1}) + fatalErr(t, err) + + _, err = NewLockable(ctx, "Test Lockable 3", []Lockable{l2}) + fatalErr(t, err) + + update_channel := l1.UpdateChannel(0) + + go func() { + SendUpdate(ctx, l2, NewDownSignal(l2, "test_update")) + }() + + (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") +} + +func TestLockableUpUpdate(t * testing.T) { + ctx := testContext(t) + + l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + l2, err := NewLockable(ctx, "Test Lockable 2", []Lockable{l1}) + fatalErr(t, err) + + l3, err := NewLockable(ctx, "Test Lockable 3", []Lockable{l2}) + fatalErr(t, err) + + update_channel := l3.UpdateChannel(0) + + go func() { + SendUpdate(ctx, l2, NewSignal(l2, "test_update")) + }() + + (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2") +} + +func TestOwnerNotUpdatedTwice(t * testing.T) { + ctx := testContext(t) + + l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + l2, err := NewLockable(ctx, "Test Lockable 2", []Lockable{l1}) + fatalErr(t, err) + + update_channel := l2.UpdateChannel(0) + + go func() { + SendUpdate(ctx, l1, NewSignal(l1, "test_update")) + }() + + (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") + (*GraphTester)(t).CheckForNone(update_channel, "Second update received on dependency") +} diff --git a/thread.go b/thread.go index b6475a1..9f06c31 100644 --- a/thread.go +++ b/thread.go @@ -13,22 +13,22 @@ func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSigna UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { thread_state := states[0].(ThreadState) if signal.Direction() == Up { - // Child->Parent, thread updates parent and connected resources + // Child->Parent, thread updates parent and connected requirement if thread_state.Parent() != nil { SendUpdate(ctx, thread_state.Parent(), signal) } - for _, resource := range(thread_state.Requirements()) { - SendUpdate(ctx, resource, signal) + for _, dep := range(thread_state.Dependencies()) { + SendUpdate(ctx, dep, signal) } } else if signal.Direction() == Down { - // Parent->Child, thread updated children + // Parent->Child, updates children and dependencies for _, child := range(thread_state.Children()) { SendUpdate(ctx, child, signal) } - for _, dep := range(thread_state.Dependencies()) { - SendUpdate(ctx, dep, signal) + for _, requirement := range(thread_state.Requirements()) { + SendUpdate(ctx, requirement, signal) } } else if signal.Direction() == Direct { @@ -193,7 +193,7 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf // Thread is the interface that thread tree nodes must implement type Thread interface { - GraphNode + Lockable Action(action string) (ThreadAction, bool) Handler(signal_type string) (ThreadHandler, bool) @@ -231,7 +231,12 @@ func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id N func RunThread(ctx * GraphContext, thread Thread) error { ctx.Log.Logf("thread", "EVENT_RUN: %s", thread.ID()) - _, err := UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { + err := LockLockable(ctx, thread, thread, nil) + if err != nil { + return err + } + + _, err = UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { thread_state := states[0].(ThreadState) if thread_state.Owner() == nil { return nil, fmt.Errorf("EVENT_RUN_NOT_LOCKED: %s", thread_state.Name()) @@ -275,7 +280,7 @@ type ThreadHandlers map[string]ThreadHandler // This node by itself doesn't implement any special behaviours for children, so they will be ignored. // When started, this thread automatically transitions to completion type BaseThread struct { - BaseNode + BaseLockable Actions ThreadActions Handlers ThreadHandlers @@ -327,17 +332,23 @@ var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) - select { - case signal := <- thread.SignalChannel(): - ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) - signal_fn, exists := thread.Handler(signal.Type()) - if exists == true { - ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) - return signal_fn(ctx, thread, signal) - } - case <- thread.Timeout(): - ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction()) - return thread.TimeoutAction(), nil + for { + select { + case signal := <- thread.SignalChannel(): + ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) + if signal.Source() == thread.ID() { + ctx.Log.Logf("thread", "THREAD_SIGNAL_INTERNAL") + continue + } + signal_fn, exists := thread.Handler(signal.Type()) + if exists == true { + ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) + return signal_fn(ctx, thread, signal) + } + case <- thread.Timeout(): + ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction()) + return thread.TimeoutAction(), nil + } } return "wait", nil } @@ -362,7 +373,7 @@ func NewBaseThreadState(name string) BaseThreadState { func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { state := NewBaseThreadState(name) thread := BaseThread{ - BaseNode: NewNode(ctx, RandID(), &state), + BaseLockable: BaseLockable{BaseNode: NewNode(ctx, RandID(), &state)}, Actions: ThreadActions{ "wait": ThreadWait, "start": ThreadDefaultStart, diff --git a/thread_test.go b/thread_test.go index 0eb2d16..392554a 100644 --- a/thread_test.go +++ b/thread_test.go @@ -30,3 +30,39 @@ func TestNewEvent(t * testing.T) { return nil, nil }) } + +func TestEventWithRequirement(t * testing.T) { + ctx := logTestContext(t, []string{"lockable", "thread"}) + + l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + t1, err := NewThread(ctx, "Test Thread 1", []Lockable{l1}, ThreadActions{}, ThreadHandlers{}) + fatalErr(t, err) + + go func (thread Thread) { + time.Sleep(10*time.Millisecond) + _, err := UseStates(ctx, []GraphNode{l1}, func(states []NodeState) (interface{}, error) { + ser, err := json.MarshalIndent(states[0], "", " ") + fatalErr(t, err) + + fmt.Printf("\n%s\n", ser) + return nil, nil + }) + fatalErr(t, err) + SendUpdate(ctx, t1, CancelSignal(nil)) + }(t1) + fatalErr(t, err) + + err = RunThread(ctx, t1) + fatalErr(t, err) + + _, err = UseStates(ctx, []GraphNode{l1}, func(states []NodeState) (interface{}, error) { + ser, err := json.MarshalIndent(states[0], "", " ") + fatalErr(t, err) + + fmt.Printf("\n%s\n", ser) + return nil, nil + }) + fatalErr(t, err) +}