From b9a2cceaf191b2ced318a910a4e2ee3efeb76441 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 31 Jul 2023 21:03:48 -0600 Subject: [PATCH] Moved gql response channel interaction into methods --- gql.go | 58 ++++++++++++++++++++++++++++------------------------ gql_query.go | 9 ++------ 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/gql.go b/gql.go index dd946d4..667e9cd 100644 --- a/gql.go +++ b/gql.go @@ -913,16 +913,10 @@ func NewGQLExtContext() *GQLExtContext { return err, nil } - response_chan := make(chan Signal, 1) - ctx.Ext.resolver_response_lock.Lock() - ctx.Ext.resolver_response[sig.ID()] = response_chan - ctx.Ext.resolver_response_lock.Unlock() - + response_chan := ctx.Ext.GetResponseChannel(sig.ID()) err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, sig) if err != nil { - ctx.Ext.resolver_response_lock.Lock() - delete(ctx.Ext.resolver_response, sig.ID()) - ctx.Ext.resolver_response_lock.Unlock() + ctx.Ext.FreeResponseChannel(sig.ID()) return nil, err } @@ -1018,21 +1012,38 @@ func (ext *GQLExt) Field(name string) interface{} { }) } +func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal { + response_chan := make(chan Signal, 1) + ext.resolver_response_lock.Lock() + ext.resolver_response[req_id] = response_chan + ext.resolver_response_lock.Unlock() + return response_chan +} + +func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { + ext.resolver_response_lock.RLock() + response_chan, exists := ext.resolver_response[req_id] + ext.resolver_response_lock.RUnlock() + + if exists == true { + ext.resolver_response_lock.Lock() + delete(ext.resolver_response, req_id) + ext.resolver_response_lock.Unlock() + return response_chan + } else { + return nil + } +} + func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { // Process ReadResultSignalType by forwarding it to the waiting resolver if signal.Type() == ErrorSignalType { // TODO: Forward to resolver if waiting for it sig := signal.(ErrorSignal) - ext.resolver_response_lock.RLock() - resolver_chan, exists := ext.resolver_response[sig.UUID] - ext.resolver_response_lock.RUnlock() - if exists == true { - ext.resolver_response_lock.Lock() - delete(ext.resolver_response, sig.UUID) - ext.resolver_response_lock.Unlock() - + response_chan := ext.FreeResponseChannel(sig.ID()) + if response_chan != nil { select { - case resolver_chan <- sig: + case response_chan <- sig: ctx.Log.Logf("gql", "Forwarded error to resolver, %+v", sig) default: ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) @@ -1043,17 +1054,10 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa } } else if signal.Type() == ReadResultSignalType { sig := signal.(ReadResultSignal) - ext.resolver_response_lock.RLock() - resolver_chan, exists := ext.resolver_response[sig.ID()] - ext.resolver_response_lock.RUnlock() - - if exists == true { - ext.resolver_response_lock.Lock() - delete(ext.resolver_response, sig.UUID) - ext.resolver_response_lock.Unlock() - + response_chan := ext.FreeResponseChannel(sig.ID()) + if response_chan != nil { select { - case resolver_chan <- sig: + case response_chan <- sig: ctx.Log.Logf("gql", "Forwarded to resolver, %+v", sig) default: ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig) diff --git a/gql_query.go b/gql_query.go index 9e23d0a..9f2c26c 100644 --- a/gql_query.go +++ b/gql_query.go @@ -51,19 +51,14 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) - response_chan := make(chan Signal, 1) - ctx.Ext.resolver_response_lock.Lock() - ctx.Ext.resolver_response[read_signal.ID()] = response_chan - ctx.Ext.resolver_response_lock.Unlock() + response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) resp_channels[read_signal.ID()] = response_chan node_ids[read_signal.ID()] = id err = ctx.Context.Send(ctx.Server.ID, id, read_signal) if err != nil { - ctx.Ext.resolver_response_lock.Lock() - delete(ctx.Ext.resolver_response, read_signal.ID()) - ctx.Ext.resolver_response_lock.Unlock() + ctx.Ext.FreeResponseChannel(read_signal.ID()) return nil, err } }