diff --git a/graph_test.go b/graph_test.go index 0562b1a..1ea2b98 100644 --- a/graph_test.go +++ b/graph_test.go @@ -13,7 +13,7 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForLinkState(ctx * Context, listener *ListenerExt, state string, timeout time.Duration, str string) Signal { +func (t * GraphTester) WaitForState(ctx * Context, listener *ListenerExt, stype SignalType, state string, timeout time.Duration, str string) Signal { timeout_channel := time.After(timeout) for true { select { @@ -22,15 +22,13 @@ func (t * GraphTester) WaitForLinkState(ctx * Context, listener *ListenerExt, st ctx.Log.Logf("test", "SIGNAL_CHANNEL_CLOSED: %s", listener) t.Fatal(str) } - if signal.Type() == LinkSignalType { - sig, ok := signal.(LinkSignal) + if signal.Type() == stype { + sig, ok := signal.(StateSignal) if ok == true { - ctx.Log.Logf("test", "Link state received: %s", sig.State) + ctx.Log.Logf("test", "%s state received: %s", stype, sig.State) if sig.State == state { return signal } - } else { - ctx.Log.Logf("test", "Failed to cast signal to LinkSignal: %+v", signal) } } case <-timeout_channel: @@ -94,7 +92,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { nil, listener_extension, NewACLExt(&policy), - NewLockableExt(nil, nil, nil, nil)) + NewLockableExt()) return listener, listener_extension } diff --git a/lockable.go b/lockable.go index 2288bee..ab7b267 100644 --- a/lockable.go +++ b/lockable.go @@ -67,24 +67,13 @@ func (ext *LockableExt) Serialize() ([]byte, error) { return json.MarshalIndent(ext, "", " ") } -func NewLockableExt(owner *NodeID, requirements map[NodeID]string, dependencies map[NodeID]string, locks_held map[NodeID]*NodeID) *LockableExt { - if requirements == nil { - requirements = map[NodeID]string{} - } - - if dependencies == nil { - dependencies = map[NodeID]string{} - } - - if locks_held == nil { - locks_held = map[NodeID]*NodeID{} - } - +func NewLockableExt() *LockableExt { return &LockableExt{ - Owner: owner, - Requirements: requirements, - Dependencies: dependencies, - LocksHeld: locks_held, + Owner: nil, + Requirements: map[NodeID]string{}, + Dependencies: map[NodeID]string{}, + LocksHeld: map[NodeID]*NodeID{}, + LockStates: map[NodeID]string{}, } } @@ -92,9 +81,25 @@ type LockableExt struct { Owner *NodeID `json:"owner"` Requirements map[NodeID]string `json:"requirements"` Dependencies map[NodeID]string `json:"dependencies"` + LockStates map[NodeID]string `json:"lock_states"` LocksHeld map[NodeID]*NodeID `json:"locks_held"` } +func LockLockable(ctx *Context, node *Node) error { + ext, err := GetExt[*LockableExt](node) + if err != nil { + return err + } + + _, exists := ext.LockStates[node.ID] + if exists == true { + return fmt.Errorf("%s is already being locked, cannot lock again", node.ID) + } + + ext.LockStates[node.ID] = "start" + return ctx.Send(node.ID, node.ID, NewLockSignal("lock")) +} + func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error { dep_ext, err := GetExt[*LockableExt](dependency) if err != nil { @@ -115,7 +120,16 @@ func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error { return ctx.Send(dependency.ID, requirement, NewLinkSignal("req_link")) } -func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal LinkSignal) { +func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) { + ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) + state := signal.State + switch state { + default: + ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", state) + } +} + +func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) { ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) state := signal.State switch state { @@ -212,7 +226,9 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal case Direct: switch signal.Type() { case LinkSignalType: - ext.HandleLinkSignal(ctx, source, node, signal.(LinkSignal)) + ext.HandleLinkSignal(ctx, source, node, signal.(StateSignal)) + case LockSignalType: + ext.HandleLockSignal(ctx, source, node, signal.(StateSignal)) default: } default: diff --git a/lockable_test.go b/lockable_test.go index 8f79047..616f2ec 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -6,8 +6,8 @@ import ( ) const TestLockableType = NodeType("TEST_LOCKABLE") -func lockableTestContext(t *testing.T) *Context { - ctx := logTestContext(t, []string{"lockable", "test"}) +func lockableTestContext(t *testing.T, logs []string) *Context { + ctx := logTestContext(t, logs) err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType, ListenerExtType}) fatalErr(t, err) @@ -17,33 +17,60 @@ func lockableTestContext(t *testing.T) *Context { var link_policy = NewAllNodesPolicy([]SignalType{LinkSignalType, StatusSignalType}) +var lock_policy = NewAllNodesPolicy([]SignalType{LinkSignalType, LockSignalType, StatusSignalType}) -func TestLinkStatus(t *testing.T) { - ctx := lockableTestContext(t) +func TestLink(t *testing.T) { + ctx := lockableTestContext(t, []string{}) l1_listener := NewListenerExt(10) l1 := NewNode(ctx, RandID(), TestLockableType, nil, l1_listener, NewACLExt(&link_policy), - NewLockableExt(nil, nil, nil, nil), + NewLockableExt(), ) l2_listener := NewListenerExt(10) l2 := NewNode(ctx, RandID(), TestLockableType, nil, l2_listener, NewACLExt(&link_policy), - NewLockableExt(nil, nil, nil, nil), + NewLockableExt(), ) // Link l2 as a requirement of l1 err := LinkRequirement(ctx, l1, l2.ID) fatalErr(t, err) - (*GraphTester)(t).WaitForLinkState(ctx, l1_listener, "dep_linked", time.Millisecond*100, "No dep_link") - (*GraphTester)(t).WaitForLinkState(ctx, l2_listener, "req_linked", time.Millisecond*100, "No req_linked") + (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "dep_linked", time.Millisecond*10, "No dep_link") + (*GraphTester)(t).WaitForState(ctx, l2_listener, LinkSignalType, "req_linked", time.Millisecond*10, "No req_linked") err = ctx.Send(l2.ID, l2.ID, NewStatusSignal("TEST", l2.ID)) fatalErr(t, err) - (*GraphTester)(t).WaitForStatus(ctx, l1_listener, "TEST", time.Millisecond*100, "No TEST on l1") - (*GraphTester)(t).WaitForStatus(ctx, l2_listener, "TEST", time.Millisecond*100, "No TEST on l2") + (*GraphTester)(t).WaitForStatus(ctx, l1_listener, "TEST", time.Millisecond*10, "No TEST on l1") + (*GraphTester)(t).WaitForStatus(ctx, l2_listener, "TEST", time.Millisecond*10, "No TEST on l2") +} + +func TestLock(t *testing.T) { + ctx := lockableTestContext(t, []string{"test", "lockable"}) + + l1_listener := NewListenerExt(10) + l1 := NewNode(ctx, RandID(), TestLockableType, nil, + l1_listener, + NewACLExt(&link_policy), + NewLockableExt(), + ) + l2_listener := NewListenerExt(10) + l2 := NewNode(ctx, RandID(), TestLockableType, nil, + l2_listener, + NewACLExt(&link_policy), + NewLockableExt(), + ) + + err := LinkRequirement(ctx, l1, l2.ID) + fatalErr(t, err) + (*GraphTester)(t).WaitForState(ctx, l1_listener, LinkSignalType, "dep_linked", time.Millisecond*10, "No dep_link") + (*GraphTester)(t).WaitForState(ctx, l2_listener, LinkSignalType, "req_linked", time.Millisecond*10, "No req_linked") + + err = LockLockable(ctx, l1) + fatalErr(t, err) + (*GraphTester)(t).WaitForState(ctx, l1_listener, LockSignalType, "locked", time.Millisecond*10, "No locked") } diff --git a/signal.go b/signal.go index 6d7c9a6..e717f33 100644 --- a/signal.go +++ b/signal.go @@ -100,27 +100,35 @@ func NewStatusSignal(status string, source NodeID) StatusSignal { } const LinkSignalType = SignalType("LINK") -type LinkSignal struct { +const LockSignalType = SignalType("LOCK") +type StateSignal struct { BaseSignal State string `json:"state"` } -func (signal LinkSignal) Serialize() ([]byte, error) { +func (signal StateSignal) Serialize() ([]byte, error) { return json.MarshalIndent(signal, "", " ") } -func (signal LinkSignal) String() string { +func (signal StateSignal) String() string { ser, _ := signal.Serialize() return string(ser) } -func NewLinkSignal(state string) LinkSignal { - return LinkSignal{ +func NewLinkSignal(state string) StateSignal { + return StateSignal{ BaseSignal: NewDirectSignal(LinkSignalType), State: state, } } +func NewLockSignal(state string) StateSignal { + return StateSignal{ + BaseSignal: NewDirectSignal(LockSignalType), + State: state, + } +} + type StartChildSignal struct { IDSignal Action string `json:"action"`