Changed resolves to map back to request channel specifically instead of to the context

gql_cataclysm
noah metz 2023-07-31 20:53:56 -06:00
parent 47151905a0
commit df09433b88
4 changed files with 40 additions and 59 deletions

@ -168,8 +168,8 @@ type ResolveContext struct {
// ID generated for the context so the gql extension can route data to it
ID uuid.UUID
// Channel for the gql extension to route data to this context
Chan chan Signal
// Channels for the gql extension to route data to this context
Chans map[uuid.UUID]chan Signal
// Graph Context this resolver is running under
Context *Context
@ -203,7 +203,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
return &ResolveContext{
Ext: gql_ext,
ID: uuid.New(),
Chan: make(chan Signal, GQL_RESOLVER_CHAN_SIZE),
Chans: map[uuid.UUID]chan Signal{},
Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
Server: server,
@ -254,10 +254,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
params.VariableValues = query.Variables
}
gql_ext.resolver_chans_lock.Lock()
gql_ext.resolver_chans[resolve_context.ID] = resolve_context.Chan
gql_ext.resolver_chans_lock.Unlock()
result := graphql.Do(params)
if len(result.Errors) > 0 {
extra_fields := map[string]interface{}{}
@ -917,12 +913,20 @@ func NewGQLExtContext() *GQLExtContext {
return err, nil
}
response_chan := make(chan Signal, 1)
ctx.Ext.resolver_response_lock.Lock()
ctx.Ext.resolver_response[sig.ID()] = response_chan
ctx.Ext.resolver_response_lock.Unlock()
err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, sig)
if err != nil {
ctx.Ext.resolver_response_lock.Lock()
delete(ctx.Ext.resolver_response, sig.ID())
ctx.Ext.resolver_response_lock.Unlock()
return nil, err
}
resp, err := WaitForResult(ctx.Chan, 100*time.Millisecond, sig.ID())
resp, err := WaitForResult(response_chan, 100*time.Millisecond, sig.ID())
if err != nil {
return nil, err
}
@ -996,12 +1000,9 @@ type GQLExt struct {
http_server *http.Server `json:"-"`
http_done sync.WaitGroup `json:"-"`
// map of read request IDs to gql request ID
resolver_response map[uuid.UUID]uuid.UUID `json:"-"`
// map of read request IDs to response channels
resolver_response map[uuid.UUID]chan Signal `json:"-"`
resolver_response_lock sync.RWMutex `json:"-"`
// map of gql request ID to active channel
resolver_chans map[uuid.UUID]chan Signal `json:"-"`
resolver_chans_lock sync.RWMutex `json:"-"`
State string `json:"state"`
tls_key []byte `json:"tls_key"`
@ -1023,28 +1024,18 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
// TODO: Forward to resolver if waiting for it
sig := signal.(ErrorSignal)
ext.resolver_response_lock.RLock()
resolver_id, exists := ext.resolver_response[sig.UUID]
resolver_chan, exists := ext.resolver_response[sig.UUID]
ext.resolver_response_lock.RUnlock()
if exists == true {
ext.resolver_response_lock.Lock()
delete(ext.resolver_response, sig.UUID)
ext.resolver_response_lock.Unlock()
ext.resolver_chans_lock.RLock()
resolver_chan, exists := ext.resolver_chans[resolver_id]
ext.resolver_chans_lock.RUnlock()
if exists == true {
select {
case resolver_chan <- sig:
ctx.Log.Logf("gql", "Forwarded error to resolver %s, %+v", resolver_id, sig)
default:
ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig)
ext.resolver_chans_lock.Lock()
delete(ext.resolver_chans, resolver_id)
ext.resolver_chans_lock.Unlock()
}
} else {
ctx.Log.Logf("gql", "received error signal response for resolver %s which doesn't exist", resolver_id)
select {
case resolver_chan <- sig:
ctx.Log.Logf("gql", "Forwarded error to resolver, %+v", sig)
default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
}
} else {
@ -1053,7 +1044,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
} else if signal.Type() == ReadResultSignalType {
sig := signal.(ReadResultSignal)
ext.resolver_response_lock.RLock()
resolver_id, exists := ext.resolver_response[sig.UUID]
resolver_chan, exists := ext.resolver_response[sig.ID()]
ext.resolver_response_lock.RUnlock()
if exists == true {
@ -1061,22 +1052,11 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
delete(ext.resolver_response, sig.UUID)
ext.resolver_response_lock.Unlock()
ext.resolver_chans_lock.RLock()
resolver_chan, exists := ext.resolver_chans[resolver_id]
ext.resolver_chans_lock.RUnlock()
if exists == true {
select {
case resolver_chan <- sig:
ctx.Log.Logf("gql", "Forwarded to resolver %s, %+v", resolver_id, sig)
default:
ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig)
ext.resolver_chans_lock.Lock()
delete(ext.resolver_chans, resolver_id)
ext.resolver_chans_lock.Unlock()
}
} else {
ctx.Log.Logf("gql", "Received message for waiting resolver %s which doesn't exist - %+v", resolver_id, sig)
select {
case resolver_chan <- sig:
ctx.Log.Logf("gql", "Forwarded to resolver, %+v", sig)
default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
}
} else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
@ -1195,8 +1175,7 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta
return &GQLExt{
State: state,
Listen: listen,
resolver_response: map[uuid.UUID]uuid.UUID{},
resolver_chans: map[uuid.UUID]chan Signal{},
resolver_response: map[uuid.UUID]chan Signal{},
tls_cert: tls_cert,
tls_key: tls_key,
}, nil

@ -40,7 +40,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
fields := GetResolveFields(ctx.Context, p)
ctx.Context.Log.Logf("gql", "RESOLVE_NODES(%+v): %+v", ids, fields)
read_signals := map[NodeID]uuid.UUID{}
resp_channels := map[uuid.UUID]chan Signal{}
node_ids := map[uuid.UUID]NodeID{}
for _, id := range(ids) {
// Get a list of fields that will be written
ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields)
@ -50,28 +51,31 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields)
response_chan := make(chan Signal, 1)
ctx.Ext.resolver_response_lock.Lock()
ctx.Ext.resolver_response[read_signal.UUID] = ctx.ID
ctx.Ext.resolver_response[read_signal.ID()] = response_chan
ctx.Ext.resolver_response_lock.Unlock()
resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, id, read_signal)
read_signals[id] = read_signal.UUID
if err != nil {
ctx.Ext.resolver_response_lock.Lock()
delete(ctx.Ext.resolver_response, read_signal.UUID)
delete(ctx.Ext.resolver_response, read_signal.ID())
ctx.Ext.resolver_response_lock.Unlock()
return nil, err
}
}
responses := []NodeResult{}
for node_id, sig_id := range(read_signals) {
for sig_id, response_chan := range(resp_channels) {
// Wait for the response, returning an error on timeout
response, err := WaitForResult(ctx.Chan, time.Millisecond*100, sig_id)
response, err := WaitForResult(response_chan, time.Millisecond*100, sig_id)
if err != nil {
return nil, err
}
responses = append(responses, NodeResult{node_id, response.(ReadResultSignal)})
responses = append(responses, NodeResult{node_ids[sig_id], response.(ReadResultSignal)})
}
return responses, nil

@ -249,7 +249,7 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else {
result := ReadNodeFields(ctx, node, source, read_signal.Extensions)
ctx.Send(node.ID, source, NewReadResultSignal(read_signal.UUID, node.Type, result))
ctx.Send(node.ID, source, NewReadResultSignal(read_signal.ID(), node.Type, result))
}
}

@ -97,7 +97,7 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.
type BaseSignal struct {
SignalDirection SignalDirection `json:"direction"`
SignalType SignalType `json:"type"`
uuid.UUID `json:"id"`
UUID uuid.UUID `json:"id"`
}
func (signal BaseSignal) ID() uuid.UUID {
@ -254,7 +254,6 @@ func (signal StringSignal) Permission() Action {
type ReadSignal struct {
BaseSignal
UUID uuid.UUID
Extensions map[ExtType][]string `json:"extensions"`
}
@ -264,7 +263,6 @@ func (signal ReadSignal) Serialize() ([]byte, error) {
func NewReadSignal(exts map[ExtType][]string) ReadSignal {
return ReadSignal{
UUID: uuid.New(),
BaseSignal: NewDirectSignal(ReadSignalType),
Extensions: exts,
}