Fixed update path and added more tests

graph-rework
noah metz 2023-06-24 19:48:59 -06:00
parent 6f83587d7f
commit b66fad2c8e
6 changed files with 190 additions and 49 deletions

@ -192,6 +192,9 @@ type GraphNode interface {
// Signal propagation function for connected nodes(defined in state) // Signal propagation function for connected nodes(defined in state)
PropagateUpdate(ctx * GraphContext, update GraphSignal) PropagateUpdate(ctx * GraphContext, update GraphSignal)
// Get an update channel for the node to be notified of signals
UpdateChannel(buffer int) chan GraphSignal
// Register and unregister a channel to propogate updates to // Register and unregister a channel to propogate updates to
RegisterChannel(listener chan GraphSignal) RegisterChannel(listener chan GraphSignal)
UnregisterChannel(listener chan GraphSignal) UnregisterChannel(listener chan GraphSignal)
@ -384,7 +387,7 @@ func (node * BaseNode) SignalChannel() chan GraphSignal {
} }
// Create a new GraphSinal channel with a buffer of size buffer and register it to a node // Create a new GraphSinal channel with a buffer of size buffer and register it to a node
func GetUpdateChannel(node * BaseNode, buffer int) chan GraphSignal { func (node * BaseNode) UpdateChannel(buffer int) chan GraphSignal {
new_listener := make(chan GraphSignal, buffer) new_listener := make(chan GraphSignal, buffer)
node.RegisterChannel(new_listener) node.RegisterChannel(new_listener)
return new_listener return new_listener

@ -17,10 +17,20 @@ func (t * GraphTester) WaitForValue(ctx * GraphContext, listener chan GraphSigna
for true { for true {
select { select {
case signal := <- listener: case signal := <- listener:
if signal == nil {
ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener)
t.Fatal(str)
}
if signal.Type() == signal_type { if signal.Type() == signal_type {
ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener) ctx.Log.Logf("test", "SIGNAL_TYPE_FOUND: %s - %s %+v\n", signal.Type(), signal.Source(), listener)
if signal.Source() == source.ID() { if source == nil {
return signal if signal.Source() == "" {
return signal
}
} else {
if signal.Source() == source.ID() {
return signal
}
} }
} }
case <-timeout_channel: case <-timeout_channel:

@ -7,16 +7,16 @@ import (
// LockHolderState is the interface that any node that wants to posses locks must implement // LockHolderState is the interface that any node that wants to posses locks must implement
// //
// ReturnLock returns the node that held the resource pointed to by ID before this node and // ReturnLock returns the node that held the lockable pointed to by ID before this node and
// removes the mapping from it's state, or nil if the resource was unlocked previously // removes the mapping from it's state, or nil if the lockable was unlocked previously
// //
// AllowedToTakeLock returns true if the node pointed to by ID is allowed to take a lock from this node // AllowedToTakeLock returns true if the node pointed to by ID is allowed to take a lock from this node
// //
// RecordLockHolder records that resource_id needs to be passed back to lock_holder // RecordLockHolder records that lockable_id needs to be passed back to lock_holder
type LockHolderState interface { type LockHolderState interface {
ReturnLock(resource_id NodeID) GraphNode ReturnLock(lockable_id NodeID) GraphNode
AllowedToTakeLock(node_id NodeID, resource_id NodeID) bool AllowedToTakeLock(node_id NodeID, lockable_id NodeID) bool
RecordLockHolder(resource_id NodeID, lock_holder GraphNode) RecordLockHolder(lockable_id NodeID, lock_holder GraphNode)
} }
// LockableState is the interface that a lockables state must have to allow it to connect to the DAG // LockableState is the interface that a lockables state must have to allow it to connect to the DAG
@ -104,31 +104,31 @@ func (state * BaseLockableState) Name() string {
// Locks cannot be passed between base lockables, so the answer to // Locks cannot be passed between base lockables, so the answer to
// "who used to own this lock held by a base lockable" is always "nobody" // "who used to own this lock held by a base lockable" is always "nobody"
func (state * BaseLockHolderState) ReturnLock(resource_id NodeID) GraphNode { func (state * BaseLockHolderState) ReturnLock(lockable_id NodeID) GraphNode {
node, exists := state.delegation_map[resource_id] node, exists := state.delegation_map[lockable_id]
if exists == false { if exists == false {
panic("Attempted to take a get the original lock holder of a resource we don't own") panic("Attempted to take a get the original lock holder of a lockable we don't own")
} }
delete(state.delegation_map, resource_id) delete(state.delegation_map, lockable_id)
return node return node
} }
// Nothing can take a lock from a base lockable either // Nothing can take a lock from a base lockable either
func (state * BaseLockHolderState) AllowedToTakeLock(node_id NodeID, resource_id NodeID) bool { func (state * BaseLockHolderState) AllowedToTakeLock(node_id NodeID, lockable_id NodeID) bool {
_, exists := state.delegation_map[resource_id] _, exists := state.delegation_map[lockable_id]
if exists == false { if exists == false {
panic ("Trying to give away lock we don't own") panic ("Trying to give away lock we don't own")
} }
return false return false
} }
func (state * BaseLockHolderState) RecordLockHolder(resource_id NodeID, lock_holder GraphNode) { func (state * BaseLockHolderState) RecordLockHolder(lockable_id NodeID, lock_holder GraphNode) {
_, exists := state.delegation_map[resource_id] _, exists := state.delegation_map[lockable_id]
if exists == true { if exists == true {
panic("Attempted to lock a resource we're already holding(lock cycle)") panic("Attempted to lock a lockable we're already holding(lock cycle)")
} }
state.delegation_map[resource_id] = lock_holder state.delegation_map[lockable_id] = lock_holder
} }
func (state * BaseLockableState) Owner() GraphNode { func (state * BaseLockableState) Owner() GraphNode {
@ -215,25 +215,29 @@ type Lockable interface {
Unlock(node GraphNode, state LockableState) error Unlock(node GraphNode, state LockableState) error
} }
// Lockables propagate update up to multiple dependencies, and not downwards
// (subscriber to team won't get update to alliance, but subscriber to alliance will get update to team)
func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) {
UseStates(ctx, []GraphNode{lockable}, func(states []NodeState) (interface{}, error){ UseStates(ctx, []GraphNode{lockable}, func(states []NodeState) (interface{}, error){
lockable_state := states[0].(LockableState) lockable_state := states[0].(LockableState)
if signal.Direction() == Up { if signal.Direction() == Up {
// Child->Parent, lockable updates dependency lockables // Child->Parent, lockable updates dependency lockables
owner_sent := false
for _, dependency := range lockable_state.Dependencies() { for _, dependency := range lockable_state.Dependencies() {
SendUpdate(ctx, dependency, signal) SendUpdate(ctx, dependency, signal)
if lockable_state.Owner() != nil {
if dependency.ID() != lockable_state.Owner().ID() {
owner_sent = true
}
}
} }
} else if signal.Direction() == Down { if lockable_state.Owner() != nil && owner_sent == false {
// Parent->Child, lockable updates lock holder
if lockable_state.Owner() != nil {
SendUpdate(ctx, lockable_state.Owner(), signal) SendUpdate(ctx, lockable_state.Owner(), signal)
} }
} else if signal.Direction() == Down {
// Parent->Child, lockable updates lock holder
for _, requirement := range(lockable_state.Requirements()) { for _, requirement := range(lockable_state.Requirements()) {
SendUpdate(ctx, requirement, signal) SendUpdate(ctx, requirement, signal)
} }
} else if signal.Direction() == Direct { } else if signal.Direction() == Direct {
} else { } else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
@ -317,14 +321,14 @@ func UnlockLockable(ctx * GraphContext, lockable Lockable, node GraphNode, node_
func LockLockable(ctx * GraphContext, lockable Lockable, node GraphNode, node_state LockHolderState) error { func LockLockable(ctx * GraphContext, lockable Lockable, node GraphNode, node_state LockHolderState) error {
if node == nil || lockable == nil { if node == nil || lockable == nil {
panic("Cannot lock without a specified node and lockable") return fmt.Errorf("Cannot lock without a specified node and lockable")
} }
ctx.Log.Logf("resource", "LOCKING: %s from %s", lockable.ID(), node.ID()) ctx.Log.Logf("lockable", "LOCKING: %s from %s", lockable.ID(), node.ID())
_, err := UpdateStates(ctx, []GraphNode{lockable}, func(states []NodeState) ([]NodeState, interface{}, error) { _, err := UpdateStates(ctx, []GraphNode{lockable}, func(states []NodeState) ([]NodeState, interface{}, error) {
if lockable.ID() == node.ID() { if lockable.ID() == node.ID() {
if node_state != nil { if node_state != nil {
panic("node_state must be nil if locking lockable from itself") return nil, nil, fmt.Errorf("node_state must be nil if locking lockable from itself")
} }
node_state = states[0].(LockHolderState) node_state = states[0].(LockHolderState)
} }

@ -4,6 +4,7 @@ import (
"testing" "testing"
"fmt" "fmt"
"encoding/json" "encoding/json"
"time"
) )
func TestNewLockable(t * testing.T) { func TestNewLockable(t * testing.T) {
@ -231,3 +232,79 @@ func TestLockableLockTieredConflict(t * testing.T) {
t.Fatal("Locked r3 which depends on r1 while r2 which depends on r1 is already locked") t.Fatal("Locked r3 which depends on r1 while r2 which depends on r1 is already locked")
} }
} }
func TestLockableSimpleUpdate(t * testing.T) {
ctx := testContext(t)
l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
update_channel := l1.UpdateChannel(0)
go func() {
SendUpdate(ctx, l1, NewDirectSignal(l1, "test_update"))
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Didn't receive test_update sent to l1")
}
func TestLockableDownUpdate(t * testing.T) {
ctx := testContext(t)
l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
l2, err := NewLockable(ctx, "Test Lockable 2", []Lockable{l1})
fatalErr(t, err)
_, err = NewLockable(ctx, "Test Lockable 3", []Lockable{l2})
fatalErr(t, err)
update_channel := l1.UpdateChannel(0)
go func() {
SendUpdate(ctx, l2, NewDownSignal(l2, "test_update"))
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2")
}
func TestLockableUpUpdate(t * testing.T) {
ctx := testContext(t)
l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
l2, err := NewLockable(ctx, "Test Lockable 2", []Lockable{l1})
fatalErr(t, err)
l3, err := NewLockable(ctx, "Test Lockable 3", []Lockable{l2})
fatalErr(t, err)
update_channel := l3.UpdateChannel(0)
go func() {
SendUpdate(ctx, l2, NewSignal(l2, "test_update"))
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l2, 100*time.Millisecond, "Didn't receive test_update on l3 sent on l2")
}
func TestOwnerNotUpdatedTwice(t * testing.T) {
ctx := testContext(t)
l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
l2, err := NewLockable(ctx, "Test Lockable 2", []Lockable{l1})
fatalErr(t, err)
update_channel := l2.UpdateChannel(0)
go func() {
SendUpdate(ctx, l1, NewSignal(l1, "test_update"))
}()
(*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1")
(*GraphTester)(t).CheckForNone(update_channel, "Second update received on dependency")
}

@ -13,22 +13,22 @@ func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSigna
UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) {
thread_state := states[0].(ThreadState) thread_state := states[0].(ThreadState)
if signal.Direction() == Up { if signal.Direction() == Up {
// Child->Parent, thread updates parent and connected resources // Child->Parent, thread updates parent and connected requirement
if thread_state.Parent() != nil { if thread_state.Parent() != nil {
SendUpdate(ctx, thread_state.Parent(), signal) SendUpdate(ctx, thread_state.Parent(), signal)
} }
for _, resource := range(thread_state.Requirements()) { for _, dep := range(thread_state.Dependencies()) {
SendUpdate(ctx, resource, signal) SendUpdate(ctx, dep, signal)
} }
} else if signal.Direction() == Down { } else if signal.Direction() == Down {
// Parent->Child, thread updated children // Parent->Child, updates children and dependencies
for _, child := range(thread_state.Children()) { for _, child := range(thread_state.Children()) {
SendUpdate(ctx, child, signal) SendUpdate(ctx, child, signal)
} }
for _, dep := range(thread_state.Dependencies()) { for _, requirement := range(thread_state.Requirements()) {
SendUpdate(ctx, dep, signal) SendUpdate(ctx, requirement, signal)
} }
} else if signal.Direction() == Direct { } else if signal.Direction() == Direct {
@ -193,7 +193,7 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf
// Thread is the interface that thread tree nodes must implement // Thread is the interface that thread tree nodes must implement
type Thread interface { type Thread interface {
GraphNode Lockable
Action(action string) (ThreadAction, bool) Action(action string) (ThreadAction, bool)
Handler(signal_type string) (ThreadHandler, bool) Handler(signal_type string) (ThreadHandler, bool)
@ -231,7 +231,12 @@ func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id N
func RunThread(ctx * GraphContext, thread Thread) error { func RunThread(ctx * GraphContext, thread Thread) error {
ctx.Log.Logf("thread", "EVENT_RUN: %s", thread.ID()) ctx.Log.Logf("thread", "EVENT_RUN: %s", thread.ID())
_, err := UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { err := LockLockable(ctx, thread, thread, nil)
if err != nil {
return err
}
_, err = UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) {
thread_state := states[0].(ThreadState) thread_state := states[0].(ThreadState)
if thread_state.Owner() == nil { if thread_state.Owner() == nil {
return nil, fmt.Errorf("EVENT_RUN_NOT_LOCKED: %s", thread_state.Name()) return nil, fmt.Errorf("EVENT_RUN_NOT_LOCKED: %s", thread_state.Name())
@ -275,7 +280,7 @@ type ThreadHandlers map[string]ThreadHandler
// This node by itself doesn't implement any special behaviours for children, so they will be ignored. // This node by itself doesn't implement any special behaviours for children, so they will be ignored.
// When started, this thread automatically transitions to completion // When started, this thread automatically transitions to completion
type BaseThread struct { type BaseThread struct {
BaseNode BaseLockable
Actions ThreadActions Actions ThreadActions
Handlers ThreadHandlers Handlers ThreadHandlers
@ -327,17 +332,23 @@ var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error)
var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout())
select { for {
case signal := <- thread.SignalChannel(): select {
ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) case signal := <- thread.SignalChannel():
signal_fn, exists := thread.Handler(signal.Type()) ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal)
if exists == true { if signal.Source() == thread.ID() {
ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) ctx.Log.Logf("thread", "THREAD_SIGNAL_INTERNAL")
return signal_fn(ctx, thread, signal) continue
} }
case <- thread.Timeout(): signal_fn, exists := thread.Handler(signal.Type())
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction()) if exists == true {
return thread.TimeoutAction(), nil ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type())
return signal_fn(ctx, thread, signal)
}
case <- thread.Timeout():
ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction())
return thread.TimeoutAction(), nil
}
} }
return "wait", nil return "wait", nil
} }
@ -362,7 +373,7 @@ func NewBaseThreadState(name string) BaseThreadState {
func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) {
state := NewBaseThreadState(name) state := NewBaseThreadState(name)
thread := BaseThread{ thread := BaseThread{
BaseNode: NewNode(ctx, RandID(), &state), BaseLockable: BaseLockable{BaseNode: NewNode(ctx, RandID(), &state)},
Actions: ThreadActions{ Actions: ThreadActions{
"wait": ThreadWait, "wait": ThreadWait,
"start": ThreadDefaultStart, "start": ThreadDefaultStart,

@ -30,3 +30,39 @@ func TestNewEvent(t * testing.T) {
return nil, nil return nil, nil
}) })
} }
func TestEventWithRequirement(t * testing.T) {
ctx := logTestContext(t, []string{"lockable", "thread"})
l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
t1, err := NewThread(ctx, "Test Thread 1", []Lockable{l1}, ThreadActions{}, ThreadHandlers{})
fatalErr(t, err)
go func (thread Thread) {
time.Sleep(10*time.Millisecond)
_, err := UseStates(ctx, []GraphNode{l1}, func(states []NodeState) (interface{}, error) {
ser, err := json.MarshalIndent(states[0], "", " ")
fatalErr(t, err)
fmt.Printf("\n%s\n", ser)
return nil, nil
})
fatalErr(t, err)
SendUpdate(ctx, t1, CancelSignal(nil))
}(t1)
fatalErr(t, err)
err = RunThread(ctx, t1)
fatalErr(t, err)
_, err = UseStates(ctx, []GraphNode{l1}, func(states []NodeState) (interface{}, error) {
ser, err := json.MarshalIndent(states[0], "", " ")
fatalErr(t, err)
fmt.Printf("\n%s\n", ser)
return nil, nil
})
fatalErr(t, err)
}