From 09c25b1e48532f6d274e018089ddb55a24c1d305 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 31 Jul 2023 18:29:26 -0600 Subject: [PATCH] Moved UUID field to all signals to nodes can wait for responses to specific signals --- gql.go | 61 +++++++++++++++++++++++++++++++++++++++------------- gql_query.go | 16 +++++++------- gql_test.go | 4 ++-- lockable.go | 40 +++++++++++++++++----------------- node.go | 2 ++ signal.go | 51 ++++++++++++++++++++++++++++++------------- 6 files changed, 114 insertions(+), 60 deletions(-) diff --git a/gql.go b/gql.go index d870c8d..5364ded 100644 --- a/gql.go +++ b/gql.go @@ -169,7 +169,7 @@ type ResolveContext struct { ID uuid.UUID // Channel for the gql extension to route data to this context - Chan chan *ReadResultSignal + Chan chan Signal // Graph Context this resolver is running under Context *Context @@ -203,7 +203,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ return &ResolveContext{ Ext: gql_ext, ID: uuid.New(), - Chan: make(chan *ReadResultSignal, GQL_RESOLVER_CHAN_SIZE), + Chan: make(chan Signal, GQL_RESOLVER_CHAN_SIZE), Context: ctx, GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), Server: server, @@ -984,10 +984,10 @@ type GQLExt struct { http_done sync.WaitGroup `json:"-"` // map of read request IDs to gql request ID - resolver_reads map[uuid.UUID]uuid.UUID `json:"-"` - resolver_reads_lock sync.RWMutex `json:"-"` + resolver_response map[uuid.UUID]uuid.UUID `json:"-"` + resolver_response_lock sync.RWMutex `json:"-"` // map of gql request ID to active channel - resolver_chans map[uuid.UUID]chan *ReadResultSignal `json:"-"` + resolver_chans map[uuid.UUID]chan Signal `json:"-"` resolver_chans_lock sync.RWMutex `json:"-"` State string `json:"state"` @@ -1006,16 +1006,47 @@ func (ext *GQLExt) Field(name string) interface{} { func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { // Process ReadResultSignalType by forwarding it to the waiting resolver - if signal.Type() == ReadResultSignalType { + if signal.Type() == ErrorSignalType { + // TODO: Forward to resolver if waiting for it + sig := signal.(ErrorSignal) + ext.resolver_response_lock.RLock() + resolver_id, exists := ext.resolver_response[sig.UUID] + ext.resolver_response_lock.RUnlock() + if exists == true { + ext.resolver_response_lock.Lock() + delete(ext.resolver_response, sig.UUID) + ext.resolver_response_lock.Unlock() + + ext.resolver_chans_lock.RLock() + resolver_chan, exists := ext.resolver_chans[resolver_id] + ext.resolver_chans_lock.RUnlock() + if exists == true { + select { + case resolver_chan <- sig: + ctx.Log.Logf("gql", "Forwarded error to resolver %s, %+v", resolver_id, sig) + default: + ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig) + ext.resolver_chans_lock.Lock() + delete(ext.resolver_chans, resolver_id) + ext.resolver_chans_lock.Unlock() + } + } else { + ctx.Log.Logf("gql", "received error signal response for resolver %s which doesn't exist", resolver_id) + } + + } else { + ctx.Log.Logf("gql", "received error signal response %s with no mapped resolver", sig.UUID) + } + } else if signal.Type() == ReadResultSignalType { sig := signal.(ReadResultSignal) - ext.resolver_reads_lock.RLock() - resolver_id, exists := ext.resolver_reads[sig.UUID] - ext.resolver_reads_lock.RUnlock() + ext.resolver_response_lock.RLock() + resolver_id, exists := ext.resolver_response[sig.UUID] + ext.resolver_response_lock.RUnlock() if exists == true { - ext.resolver_reads_lock.Lock() - delete(ext.resolver_reads, sig.UUID) - ext.resolver_reads_lock.Unlock() + ext.resolver_response_lock.Lock() + delete(ext.resolver_response, sig.UUID) + ext.resolver_response_lock.Unlock() ext.resolver_chans_lock.RLock() resolver_chan, exists := ext.resolver_chans[resolver_id] @@ -1023,7 +1054,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa if exists == true { select { - case resolver_chan <- &sig: + case resolver_chan <- sig: ctx.Log.Logf("gql", "Forwarded to resolver %s, %+v", resolver_id, sig) default: ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig) @@ -1151,8 +1182,8 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta return &GQLExt{ State: state, Listen: listen, - resolver_reads: map[uuid.UUID]uuid.UUID{}, - resolver_chans: map[uuid.UUID]chan *ReadResultSignal{}, + resolver_response: map[uuid.UUID]uuid.UUID{}, + resolver_chans: map[uuid.UUID]chan Signal{}, tls_cert: tls_cert, tls_key: tls_key, }, nil diff --git a/gql_query.go b/gql_query.go index 889c60a..a2cfa6a 100644 --- a/gql_query.go +++ b/gql_query.go @@ -50,16 +50,16 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) - ctx.Ext.resolver_reads_lock.Lock() - ctx.Ext.resolver_reads[read_signal.UUID] = ctx.ID - ctx.Ext.resolver_reads_lock.Unlock() + ctx.Ext.resolver_response_lock.Lock() + ctx.Ext.resolver_response[read_signal.UUID] = ctx.ID + ctx.Ext.resolver_response_lock.Unlock() err = ctx.Context.Send(ctx.Server.ID, id, read_signal) read_signals[id] = read_signal.UUID if err != nil { - ctx.Ext.resolver_reads_lock.Lock() - delete(ctx.Ext.resolver_reads, read_signal.UUID) - ctx.Ext.resolver_reads_lock.Unlock() + ctx.Ext.resolver_response_lock.Lock() + delete(ctx.Ext.resolver_response, read_signal.UUID) + ctx.Ext.resolver_response_lock.Unlock() return nil, err } } @@ -67,11 +67,11 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ responses := []NodeResult{} for node_id, sig_id := range(read_signals) { // Wait for the response, returning an error on timeout - response, err := WaitForReadResult(ctx.Chan, time.Millisecond*100, sig_id) + response, err := WaitForResult(ctx.Chan, time.Millisecond*100, sig_id) if err != nil { return nil, err } - responses = append(responses, NodeResult{node_id, response}) + responses = append(responses, NodeResult{node_id, response.(*ReadResultSignal)}) } return responses, nil diff --git a/gql_test.go b/gql_test.go index 4711bbe..3e41b72 100644 --- a/gql_test.go +++ b/gql_test.go @@ -112,7 +112,7 @@ func TestGQLDB(t *testing.T) { fatalErr(t, err) _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool { - return sig.Str == "stopped" && sig.ID == gql.ID + return sig.Str == "stopped" && sig.NodeID == gql.ID }) fatalErr(t, err) @@ -130,7 +130,7 @@ func TestGQLDB(t *testing.T) { err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal) fatalErr(t, err) _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool { - return sig.Str == "stopped" && sig.ID == gql_loaded.ID + return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID }) fatalErr(t, err) diff --git a/lockable.go b/lockable.go index 100aeea..bbd420b 100644 --- a/lockable.go +++ b/lockable.go @@ -169,11 +169,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node switch state { case "unlock": if ext.Owner == nil { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("already_unlocked"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocked"))) } else if source != *ext.Owner { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_owner"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_owner"))) } else if ext.PendingOwner == nil { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("already_unlocking"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocking"))) } else { if len(ext.Requirements) == 0 { ext.Owner = nil @@ -199,11 +199,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node case "unlocking": state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_requirement"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("node_not_linked"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked"))) } else if state.Lock != "unlocking" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_unlocking"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking"))) } case "unlocked": @@ -213,11 +213,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_requirement"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_linked"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked"))) } else if state.Lock != "unlocking" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_unlocking"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking"))) } else { state.Lock = "unlocked" ext.Requirements[source] = state @@ -248,11 +248,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_requirement"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_linked"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked"))) } else if state.Lock != "locking" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_locking"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking"))) } else { state.Lock = "locked" ext.Requirements[source] = state @@ -278,18 +278,18 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node case "locking": state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_requirement"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("node_not_linked"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked"))) } else if state.Lock != "locking" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("not_locking"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking"))) } case "lock": if ext.Owner != nil { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("already_locked"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locked"))) } else if ext.PendingOwner != nil { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("already_locking"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locking"))) } else { owner := source if len(ext.Requirements) == 0 { @@ -321,7 +321,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal IDStringSignal) { ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal) link_type := signal.Str - target := signal.ID + target := signal.NodeID switch link_type { case "req": state, exists := ext.Requirements[target] @@ -336,9 +336,9 @@ func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node } } else if exists == true { if state.Link == "linking" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("already_linking_req"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_linking_req"))) } else if state.Link == "linked" { - ctx.Send(node.ID, source, NewErrorSignal(fmt.Errorf("already_req"))) + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_req"))) } } else if dep_exists == true { ctx.Send(node.ID, source, NewLinkStartSignal("already_dep", target)) diff --git a/node.go b/node.go index 0583eb5..79eb357 100644 --- a/node.go +++ b/node.go @@ -209,6 +209,7 @@ func nodeLoop(ctx *Context, node *Node) error { err := Allowed(ctx, msg.Source, signal.Permission(), node) if err != nil { ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) + ctx.Send(node.ID, msg.Source, NewErrorSignal(msg.Signal.ID(), err)) continue } case <-node.TimeoutChan: @@ -239,6 +240,7 @@ func nodeLoop(ctx *Context, node *Node) error { // Handle special signal types if signal.Type() == StopSignalType { + ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil)) node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) break } else if signal.Type() == ReadSignalType { diff --git a/signal.go b/signal.go index 443936c..cacfde9 100644 --- a/signal.go +++ b/signal.go @@ -42,17 +42,26 @@ func (signal_type SignalType) Prefix() string { return "SIGNAL: " } type Signal interface { Serializable[SignalType] Direction() SignalDirection + ID() uuid.UUID Permission() Action } -func WaitForReadResult(listener chan *ReadResultSignal, timeout time.Duration, id uuid.UUID) (*ReadResultSignal, error) { +func WaitForResult(listener chan Signal, timeout time.Duration, id uuid.UUID) (Signal, error) { timeout_channel := time.After(timeout) var err error = nil - var result *ReadResultSignal = nil - select { - case result =<-listener: - case <-timeout_channel: - err = fmt.Errorf("timeout waiting for read response to %s", id) + var result Signal = nil + run := true + for run == true { + select { + case result=<-listener: + if result.ID() == id { + run = false + } + case <-timeout_channel: + result = nil + err = fmt.Errorf("timeout waiting for read response to %s", id) + run = false + } } return result, err } @@ -88,6 +97,11 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time. type BaseSignal struct { SignalDirection SignalDirection `json:"direction"` SignalType SignalType `json:"type"` + uuid.UUID `json:"id"` +} + +func (signal BaseSignal) ID() uuid.UUID { + return signal.UUID } func (signal BaseSignal) Type() SignalType { @@ -108,6 +122,7 @@ func (signal BaseSignal) Serialize() ([]byte, error) { func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { signal := BaseSignal{ + UUID: uuid.New(), SignalDirection: direction, SignalType: signal_type, } @@ -133,9 +148,13 @@ type ErrorSignal struct { Error error `json:"error"` } -func NewErrorSignal(err error) ErrorSignal { +func NewErrorSignal(req_id uuid.UUID, err error) ErrorSignal { return ErrorSignal{ - BaseSignal: NewDirectSignal(ErrorSignalType), + BaseSignal: BaseSignal{ + Direct, + ErrorSignalType, + req_id, + }, Error: err, } } @@ -175,8 +194,8 @@ func (signal StringSignal) Serialize() ([]byte, error) { type IDStringSignal struct { BaseSignal - ID NodeID `json:"id"` - Str string `json:"state"` + NodeID `json:"node_id"` + Str string `json:"string"` } func (signal IDStringSignal) Serialize() ([]byte, error) { @@ -194,7 +213,7 @@ func (signal IDStringSignal) String() string { func NewStatusSignal(status string, source NodeID) IDStringSignal { return IDStringSignal{ BaseSignal: NewUpSignal(StatusSignalType), - ID: source, + NodeID: source, Str: status, } } @@ -209,7 +228,7 @@ func NewLinkSignal(state string) StringSignal { func NewIDStringSignal(signal_type SignalType, direction SignalDirection, state string, id NodeID) IDStringSignal { return IDStringSignal{ BaseSignal: NewBaseSignal(signal_type, direction), - ID: id, + NodeID: id, Str: state, } } @@ -249,15 +268,17 @@ func NewReadSignal(exts map[ExtType][]string) ReadSignal { type ReadResultSignal struct { BaseSignal - uuid.UUID NodeType Extensions map[ExtType]map[string]interface{} `json:"extensions"` } func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal { return ReadResultSignal{ - BaseSignal: NewDirectSignal(ReadResultSignalType), - UUID: req_id, + BaseSignal: BaseSignal{ + Direct, + ReadResultSignalType, + req_id, + }, NodeType: node_type, Extensions: exts, }