From df09433b88dc54467c7853e1cdb84b2658ace50d Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 31 Jul 2023 20:53:56 -0600 Subject: [PATCH] Changed resolves to map back to request channel specifically instead of to the context --- gql.go | 75 +++++++++++++++++++--------------------------------- gql_query.go | 18 ++++++++----- node.go | 2 +- signal.go | 4 +-- 4 files changed, 40 insertions(+), 59 deletions(-) diff --git a/gql.go b/gql.go index e8d9319..dd946d4 100644 --- a/gql.go +++ b/gql.go @@ -168,8 +168,8 @@ type ResolveContext struct { // ID generated for the context so the gql extension can route data to it ID uuid.UUID - // Channel for the gql extension to route data to this context - Chan chan Signal + // Channels for the gql extension to route data to this context + Chans map[uuid.UUID]chan Signal // Graph Context this resolver is running under Context *Context @@ -203,7 +203,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ return &ResolveContext{ Ext: gql_ext, ID: uuid.New(), - Chan: make(chan Signal, GQL_RESOLVER_CHAN_SIZE), + Chans: map[uuid.UUID]chan Signal{}, Context: ctx, GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), Server: server, @@ -254,10 +254,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW params.VariableValues = query.Variables } - gql_ext.resolver_chans_lock.Lock() - gql_ext.resolver_chans[resolve_context.ID] = resolve_context.Chan - gql_ext.resolver_chans_lock.Unlock() - result := graphql.Do(params) if len(result.Errors) > 0 { extra_fields := map[string]interface{}{} @@ -917,12 +913,20 @@ func NewGQLExtContext() *GQLExtContext { return err, nil } + response_chan := make(chan Signal, 1) + ctx.Ext.resolver_response_lock.Lock() + ctx.Ext.resolver_response[sig.ID()] = response_chan + ctx.Ext.resolver_response_lock.Unlock() + err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, sig) if err != nil { + ctx.Ext.resolver_response_lock.Lock() + delete(ctx.Ext.resolver_response, sig.ID()) + ctx.Ext.resolver_response_lock.Unlock() return nil, err } - resp, err := WaitForResult(ctx.Chan, 100*time.Millisecond, sig.ID()) + resp, err := WaitForResult(response_chan, 100*time.Millisecond, sig.ID()) if err != nil { return nil, err } @@ -996,12 +1000,9 @@ type GQLExt struct { http_server *http.Server `json:"-"` http_done sync.WaitGroup `json:"-"` - // map of read request IDs to gql request ID - resolver_response map[uuid.UUID]uuid.UUID `json:"-"` + // map of read request IDs to response channels + resolver_response map[uuid.UUID]chan Signal `json:"-"` resolver_response_lock sync.RWMutex `json:"-"` - // map of gql request ID to active channel - resolver_chans map[uuid.UUID]chan Signal `json:"-"` - resolver_chans_lock sync.RWMutex `json:"-"` State string `json:"state"` tls_key []byte `json:"tls_key"` @@ -1023,28 +1024,18 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa // TODO: Forward to resolver if waiting for it sig := signal.(ErrorSignal) ext.resolver_response_lock.RLock() - resolver_id, exists := ext.resolver_response[sig.UUID] + resolver_chan, exists := ext.resolver_response[sig.UUID] ext.resolver_response_lock.RUnlock() if exists == true { ext.resolver_response_lock.Lock() delete(ext.resolver_response, sig.UUID) ext.resolver_response_lock.Unlock() - ext.resolver_chans_lock.RLock() - resolver_chan, exists := ext.resolver_chans[resolver_id] - ext.resolver_chans_lock.RUnlock() - if exists == true { - select { - case resolver_chan <- sig: - ctx.Log.Logf("gql", "Forwarded error to resolver %s, %+v", resolver_id, sig) - default: - ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig) - ext.resolver_chans_lock.Lock() - delete(ext.resolver_chans, resolver_id) - ext.resolver_chans_lock.Unlock() - } - } else { - ctx.Log.Logf("gql", "received error signal response for resolver %s which doesn't exist", resolver_id) + select { + case resolver_chan <- sig: + ctx.Log.Logf("gql", "Forwarded error to resolver, %+v", sig) + default: + ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) } } else { @@ -1053,7 +1044,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa } else if signal.Type() == ReadResultSignalType { sig := signal.(ReadResultSignal) ext.resolver_response_lock.RLock() - resolver_id, exists := ext.resolver_response[sig.UUID] + resolver_chan, exists := ext.resolver_response[sig.ID()] ext.resolver_response_lock.RUnlock() if exists == true { @@ -1061,22 +1052,11 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa delete(ext.resolver_response, sig.UUID) ext.resolver_response_lock.Unlock() - ext.resolver_chans_lock.RLock() - resolver_chan, exists := ext.resolver_chans[resolver_id] - ext.resolver_chans_lock.RUnlock() - - if exists == true { - select { - case resolver_chan <- sig: - ctx.Log.Logf("gql", "Forwarded to resolver %s, %+v", resolver_id, sig) - default: - ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig) - ext.resolver_chans_lock.Lock() - delete(ext.resolver_chans, resolver_id) - ext.resolver_chans_lock.Unlock() - } - } else { - ctx.Log.Logf("gql", "Received message for waiting resolver %s which doesn't exist - %+v", resolver_id, sig) + select { + case resolver_chan <- sig: + ctx.Log.Logf("gql", "Forwarded to resolver, %+v", sig) + default: + ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) } } else { ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) @@ -1195,8 +1175,7 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta return &GQLExt{ State: state, Listen: listen, - resolver_response: map[uuid.UUID]uuid.UUID{}, - resolver_chans: map[uuid.UUID]chan Signal{}, + resolver_response: map[uuid.UUID]chan Signal{}, tls_cert: tls_cert, tls_key: tls_key, }, nil diff --git a/gql_query.go b/gql_query.go index 45217c4..9e23d0a 100644 --- a/gql_query.go +++ b/gql_query.go @@ -40,7 +40,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ fields := GetResolveFields(ctx.Context, p) ctx.Context.Log.Logf("gql", "RESOLVE_NODES(%+v): %+v", ids, fields) - read_signals := map[NodeID]uuid.UUID{} + resp_channels := map[uuid.UUID]chan Signal{} + node_ids := map[uuid.UUID]NodeID{} for _, id := range(ids) { // Get a list of fields that will be written ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) @@ -50,28 +51,31 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) + response_chan := make(chan Signal, 1) ctx.Ext.resolver_response_lock.Lock() - ctx.Ext.resolver_response[read_signal.UUID] = ctx.ID + ctx.Ext.resolver_response[read_signal.ID()] = response_chan ctx.Ext.resolver_response_lock.Unlock() + resp_channels[read_signal.ID()] = response_chan + node_ids[read_signal.ID()] = id + err = ctx.Context.Send(ctx.Server.ID, id, read_signal) - read_signals[id] = read_signal.UUID if err != nil { ctx.Ext.resolver_response_lock.Lock() - delete(ctx.Ext.resolver_response, read_signal.UUID) + delete(ctx.Ext.resolver_response, read_signal.ID()) ctx.Ext.resolver_response_lock.Unlock() return nil, err } } responses := []NodeResult{} - for node_id, sig_id := range(read_signals) { + for sig_id, response_chan := range(resp_channels) { // Wait for the response, returning an error on timeout - response, err := WaitForResult(ctx.Chan, time.Millisecond*100, sig_id) + response, err := WaitForResult(response_chan, time.Millisecond*100, sig_id) if err != nil { return nil, err } - responses = append(responses, NodeResult{node_id, response.(ReadResultSignal)}) + responses = append(responses, NodeResult{node_ids[sig_id], response.(ReadResultSignal)}) } return responses, nil diff --git a/node.go b/node.go index 79eb357..f3319a6 100644 --- a/node.go +++ b/node.go @@ -249,7 +249,7 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) } else { result := ReadNodeFields(ctx, node, source, read_signal.Extensions) - ctx.Send(node.ID, source, NewReadResultSignal(read_signal.UUID, node.Type, result)) + ctx.Send(node.ID, source, NewReadResultSignal(read_signal.ID(), node.Type, result)) } } diff --git a/signal.go b/signal.go index e910194..c800639 100644 --- a/signal.go +++ b/signal.go @@ -97,7 +97,7 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time. type BaseSignal struct { SignalDirection SignalDirection `json:"direction"` SignalType SignalType `json:"type"` - uuid.UUID `json:"id"` + UUID uuid.UUID `json:"id"` } func (signal BaseSignal) ID() uuid.UUID { @@ -254,7 +254,6 @@ func (signal StringSignal) Permission() Action { type ReadSignal struct { BaseSignal - UUID uuid.UUID Extensions map[ExtType][]string `json:"extensions"` } @@ -264,7 +263,6 @@ func (signal ReadSignal) Serialize() ([]byte, error) { func NewReadSignal(exts map[ExtType][]string) ReadSignal { return ReadSignal{ - UUID: uuid.New(), BaseSignal: NewDirectSignal(ReadSignalType), Extensions: exts, }