diff --git a/gql.go b/gql.go index 4c064a1..cdccfcc 100644 --- a/gql.go +++ b/gql.go @@ -1287,7 +1287,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa switch sig := signal.(type) { case *ErrorSignal: // TODO: Forward to resolver if waiting for it - response_chan := ext.FreeResponseChannel(sig.Header().ReqID) + response_chan := ext.FreeResponseChannel(sig.ReqID) if response_chan != nil { select { case response_chan <- sig: diff --git a/gql_node.go b/gql_node.go index f75de80..2dcf2b1 100644 --- a/gql_node.go +++ b/gql_node.go @@ -113,14 +113,14 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ msgs := Messages{} msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id) - response_chan := ctx.Ext.GetResponseChannel(read_signal.ID) - resp_channels[read_signal.ID] = response_chan - indices[read_signal.ID] = i + response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) + resp_channels[read_signal.ID()] = response_chan + indices[read_signal.ID()] = i // TODO: Send all at once instead of creating Messages for each err = ctx.Context.Send(msgs) if err != nil { - ctx.Ext.FreeResponseChannel(read_signal.ID) + ctx.Ext.FreeResponseChannel(read_signal.ID()) return nil, err } diff --git a/lockable.go b/lockable.go index caa17ff..0076c8f 100644 --- a/lockable.go +++ b/lockable.go @@ -46,7 +46,7 @@ func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) msgs := Messages{} signal := NewLockSignal("unlock") msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target) - return signal.Header().ID, ctx.Send(msgs) + return signal.ID(), ctx.Send(msgs) } // Send the signal to lock a node from itself @@ -54,7 +54,7 @@ func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { msgs := Messages{} signal := NewLockSignal("lock") msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target) - return signal.Header().ID, ctx.Send(msgs) + return signal.ID(), ctx.Send(msgs) } func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) Messages { @@ -89,27 +89,27 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID case "add": _, exists := ext.Requirements[signal.NodeID] if exists == true { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "already_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source) } else { if ext.Requirements == nil { ext.Requirements = map[NodeID]ReqState{} } ext.Requirements[signal.NodeID] = Unlocked - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_added"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "req_added"), source) } case "remove": _, exists := ext.Requirements[signal.NodeID] if exists == false { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "can't link: not_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "can't link: not_requirement"), source) } else { delete(ext.Requirements, signal.NodeID) - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "req_removed"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "req_removed"), source) } default: - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "unknown_action"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "unknown_action"), source) } } else { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) } return msgs } @@ -123,7 +123,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID case "locked": state, found := ext.Requirements[source] if found == false && source != node.ID { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "got 'locked' from non-requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "got 'locked' from non-requirement"), source) } else if state == Locking { if ext.State == Locking { ext.Requirements[source] = Locked @@ -151,7 +151,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID case "unlocked": state, found := ext.Requirements[source] if found == false { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_requirement"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state == Unlocking { ext.Requirements[source] = Unlocked reqs := 0 @@ -188,7 +188,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), new_owner) } else { ext.State = Locking - id := signal.ID + id := signal.ID() ext.ReqID = &id new_owner := source ext.PendingOwner = &new_owner @@ -202,7 +202,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID } } } else { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_unlocked"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) } case "unlock": if ext.State == Locked { @@ -214,7 +214,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), new_owner) } else if source == *ext.Owner { ext.State = Unlocking - id := signal.ID + id := signal.ID() ext.ReqID = &id ext.PendingOwner = nil for id, state := range(ext.Requirements) { @@ -227,7 +227,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID } } } else { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID, "not_locked"), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source) } default: ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State) @@ -239,7 +239,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { messages := Messages{} - switch signal.Header().Direction { + switch signal.Direction() { case Up: if ext.Owner != nil { if *ext.Owner != node.ID { diff --git a/lockable_test.go b/lockable_test.go index 1a3d520..0f64d12 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -153,9 +153,7 @@ func TestLock(t *testing.T) { id, err := LockLockable(ctx, l1, l1.ID) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *ErrorSignal) bool { - return sig.Error == "not_unlocked" && sig.Header().ReqID == id - }) + _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id) fatalErr(t, err) _, err = UnlockLockable(ctx, l0, l5.ID) diff --git a/node.go b/node.go index ec5b7d8..fae514d 100644 --- a/node.go +++ b/node.go @@ -148,7 +148,7 @@ func (node *Node) QueueSignal(time time.Time, signal Signal) { func (node *Node) DequeueSignal(id uuid.UUID) error { idx := -1 for i, q := range(node.SignalQueue) { - if q.Signal.Header().ID == id { + if q.Signal.ID() == id { idx = i break } @@ -293,21 +293,21 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal) ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.Header().ID, "acl denied"), msg.Source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source) ctx.Send(msgs) continue } else if resp == Pending { ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends) - timeout_signal := NewACLTimeoutSignal(msg.Signal.Header().ID) + timeout_signal := NewACLTimeoutSignal(msg.Signal.ID()) node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal) msgs := Messages{} for policy_type, sigs := range(pends) { for _, m := range(sigs) { msgs = append(msgs, m) - node.PendingSignals[m.Signal.Header().ID] = PendingSignal{policy_type, false, msg.Signal.Header().ID} + node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()} } } - node.PendingACLs[msg.Signal.Header().ID] = PendingACL{len(msgs), timeout_signal.ID, msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} + node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} ctx.Send(msgs) continue } else if resp == Allow { @@ -327,7 +327,7 @@ func nodeLoop(ctx *Context, node *Node) error { t := node.NextSignal.Time i := -1 for j, queued := range(node.SignalQueue) { - if queued.Signal.Header().ID == node.NextSignal.Signal.Header().ID { + if queued.Signal.ID() == node.NextSignal.Signal.ID() { i = j break } @@ -349,55 +349,58 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) - info, waiting := node.PendingSignals[signal.Header().ReqID] - if waiting == true { - if info.Found == false { - info.Found = true - node.PendingSignals[signal.Header().ReqID] = info - ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) - req_info, exists := node.PendingACLs[info.ID] - if exists == true { - req_info.Counter -= 1 - req_info.Responses = append(req_info.Responses, signal) - - idx := -1 - for i, p := range(node.Policies) { - if p.ID() == info.Policy { - idx = i - break - } - } - if idx == -1 { - ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy) - delete(node.PendingACLs, info.ID) - } else { - allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal) - if allowed == Allow { - ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) - signal = req_info.Signal - source = req_info.Source - err := node.DequeueSignal(req_info.TimeoutID) - if err != nil { - panic("dequeued a passed signal") - } - delete(node.PendingACLs, info.ID) - } else if req_info.Counter == 0 { - ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) - // Send the denied response - msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.Header().ID, "ACL_DENIED"), req_info.Source) - err := ctx.Send(msgs) - if err != nil { - ctx.Log.Logf("signal", "SEND_ERR: %s", err) - } - err = node.DequeueSignal(req_info.TimeoutID) - if err != nil { - panic("dequeued a passed signal") + response, ok := signal.(ResponseSignal) + if ok == true { + info, waiting := node.PendingSignals[response.ResponseID()] + if waiting == true { + if info.Found == false { + info.Found = true + node.PendingSignals[response.ResponseID()] = info + ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) + req_info, exists := node.PendingACLs[info.ID] + if exists == true { + req_info.Counter -= 1 + req_info.Responses = append(req_info.Responses, signal) + + idx := -1 + for i, p := range(node.Policies) { + if p.ID() == info.Policy { + idx = i + break } + } + if idx == -1 { + ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy) delete(node.PendingACLs, info.ID) } else { - node.PendingACLs[info.ID] = req_info - continue + allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal) + if allowed == Allow { + ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) + signal = req_info.Signal + source = req_info.Source + err := node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") + } + delete(node.PendingACLs, info.ID) + } else if req_info.Counter == 0 { + ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) + // Send the denied response + msgs := Messages{} + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) + err := ctx.Send(msgs) + if err != nil { + ctx.Log.Logf("signal", "SEND_ERR: %s", err) + } + err = node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") + } + delete(node.PendingACLs, info.ID) + } else { + node.PendingACLs[info.ID] = req_info + continue + } } } } @@ -421,7 +424,7 @@ func nodeLoop(ctx *Context, node *Node) error { case *ReadSignal: result := node.ReadFields(ctx, sig.Extensions) msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID, node.ID, node.Type, result), source) + msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID(), node.ID, node.Type, result), source) ctx.Send(msgs) default: diff --git a/signal.go b/signal.go index 98f7b2a..17bf402 100644 --- a/signal.go +++ b/signal.go @@ -14,27 +14,64 @@ const ( Direct ) +type TimeoutSignal struct { + SignalHeader +} + +func NewTimeoutSignal() *TimeoutSignal { + return &TimeoutSignal{ + NewSignalHeader(Direct), + } +} + +// Timeouts are internal only, no permission allows sending them +func (signal TimeoutSignal) Permission() Tree { + return nil +} + type SignalHeader struct { - Direction SignalDirection `gv:"direction"` - ID uuid.UUID `gv:"id"` - ReqID uuid.UUID `gv:"req_id"` + Id uuid.UUID `gv:"id"` + Dir SignalDirection `gv:"direction"` } -func (header SignalHeader) Header() SignalHeader { - return header +func (signal SignalHeader) ID() uuid.UUID { + return signal.Id +} + +func (signal SignalHeader) Direction() SignalDirection { + return signal.Dir } func (header SignalHeader) String() string { - return fmt.Sprintf("SignalHeader(%d, %s->%s)", header.Direction, header.ID, header.ReqID) + return fmt.Sprintf("SignalHeader(%d, %s)", header.Dir, header.Id) +} + +type ResponseSignal interface { + Signal + ResponseID() uuid.UUID +} + +type ResponseHeader struct { + SignalHeader + ReqID uuid.UUID `gv:"req_id"` +} + +func (header ResponseHeader) ResponseID() uuid.UUID { + return header.ReqID +} + +func (header ResponseHeader) String() string { + return fmt.Sprintf("ResponseHeader(%d, %s->%s)", header.Dir, header.Id, header.ReqID) } type Signal interface { fmt.Stringer - Header() SignalHeader + ID() uuid.UUID + Direction() SignalDirection Permission() Tree } -func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (Signal, error) { +func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, error) { var timeout_channel <- chan time.Time if timeout > 0 { timeout_channel = time.After(timeout) @@ -46,8 +83,13 @@ func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UU if signal == nil { return nil, fmt.Errorf("LISTENER_CLOSED") } - if signal.Header().ReqID == req_id { - return signal, nil + resp_signal, ok := signal.(ResponseSignal) + if ok == false { + continue + } + + if resp_signal.ResponseID() == req_id { + return resp_signal, nil } case <-timeout_channel: return nil, fmt.Errorf("LISTENER_TIMEOUT") @@ -82,22 +124,17 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check } func NewSignalHeader(direction SignalDirection) SignalHeader { - id := uuid.New() - header := SignalHeader{ - ID: id, - ReqID: id, - Direction: direction, + return SignalHeader{ + uuid.New(), + direction, } - return header } -func NewRespHeader(req_id uuid.UUID, direction SignalDirection) SignalHeader { - header := SignalHeader{ - ID: uuid.New(), - ReqID: req_id, - Direction: direction, +func NewResponseHeader(req_id uuid.UUID, direction SignalDirection) ResponseHeader { + return ResponseHeader{ + NewSignalHeader(direction), + req_id, } - return header } type CreateSignal struct { @@ -145,7 +182,7 @@ func NewStopSignal() *StopSignal { } type SuccessSignal struct { - SignalHeader + ResponseHeader } func (signal SuccessSignal) Permission() Tree { return Tree{ @@ -156,12 +193,12 @@ func (signal SuccessSignal) Permission() Tree { } func NewSuccessSignal(req_id uuid.UUID) Signal { return &SuccessSignal{ - NewRespHeader(req_id, Direct), + NewResponseHeader(req_id, Direct), } } type ErrorSignal struct { - SignalHeader + ResponseHeader Error string } func (signal ErrorSignal) String() string { @@ -176,13 +213,13 @@ func (signal ErrorSignal) Permission() Tree { } func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal { return &ErrorSignal{ - NewRespHeader(req_id, Direct), + NewResponseHeader(req_id, Direct), fmt.Sprintf(fmt_string, args...), } } type ACLTimeoutSignal struct { - SignalHeader + ResponseHeader } func (signal ACLTimeoutSignal) Permission() Tree { return Tree{ @@ -191,7 +228,7 @@ func (signal ACLTimeoutSignal) Permission() Tree { } func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal { sig := &ACLTimeoutSignal{ - NewRespHeader(req_id, Direct), + NewResponseHeader(req_id, Direct), } return sig } @@ -292,7 +329,7 @@ func NewReadSignal(exts map[ExtType][]string) *ReadSignal { } type ReadResultSignal struct { - SignalHeader + ResponseHeader NodeID NodeID NodeType NodeType Extensions map[ExtType]map[string]SerializedValue @@ -306,7 +343,7 @@ func (signal ReadResultSignal) Permission() Tree { } func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal { return &ReadResultSignal{ - NewRespHeader(req_id, Direct), + NewResponseHeader(req_id, Direct), node_id, node_type, exts,