Forward status signals to resolvers

gql_cataclysm
noah metz 2023-09-18 19:55:55 -06:00
parent d34304f6ad
commit ab5b922a7d
4 changed files with 92 additions and 9 deletions

@ -298,6 +298,9 @@ func checkForAuthHeader(header http.Header) (string, bool) {
// Context passed to each resolve execution // Context passed to each resolve execution
type ResolveContext struct { type ResolveContext struct {
// Resolution ID
ID uuid.UUID
// Channels for the gql extension to route data to this context // Channels for the gql extension to route data to this context
Chans map[uuid.UUID]chan Signal Chans map[uuid.UUID]chan Signal
@ -324,7 +327,7 @@ type ResolveContext struct {
Key ed25519.PrivateKey Key ed25519.PrivateKey
} }
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request, id uuid.UUID) (*ResolveContext, error) {
id_b64, key_b64, ok := r.BasicAuth() id_b64, key_b64, ok := r.BasicAuth()
if ok == false { if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
@ -365,6 +368,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
} }
return &ResolveContext{ return &ResolveContext{
ID: id,
Ext: gql_ext, Ext: gql_ext,
Chans: map[uuid.UUID]chan Signal{}, Chans: map[uuid.UUID]chan Signal{},
Context: ctx, Context: ctx,
@ -386,7 +390,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext, r) resolve_context, err := NewResolveContext(ctx, server, gql_ext, r, uuid.New())
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err)))
@ -487,7 +491,7 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext, r) resolve_context, err := NewResolveContext(ctx, server, gql_ext, r, uuid.New())
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return return
@ -1065,7 +1069,12 @@ func NewGQLExtContext() *GQLExtContext {
if err != nil { if err != nil {
return nil, err return nil, err
} }
c := make(chan interface{}, 1)
c, err := ctx.Ext.AddSubscription(ctx.ID)
if err != nil {
return nil, err
}
nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID}) nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID})
if err != nil { if err != nil {
return nil, err return nil, err
@ -1073,7 +1082,6 @@ func NewGQLExtContext() *GQLExtContext {
return nil, fmt.Errorf("wrong length of nodes returned") return nil, fmt.Errorf("wrong length of nodes returned")
} }
ctx.Context.Log.Logf("gql", "NODES: %+v", nodes[0])
c <- nodes[0] c <- nodes[0]
return c, nil return c, nil
@ -1083,12 +1091,22 @@ func NewGQLExtContext() *GQLExtContext {
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx.Context.Log.Logf("gql_subscribe", "SUBSCRIBE_RESOLVE: %+v", p.Source)
switch source := p.Source.(type) { switch source := p.Source.(type) {
case NodeResult: case NodeResult:
case StatusSignal: case *StatusSignal:
delete(ctx.NodeCache, source.Source) delete(ctx.NodeCache, source.Source)
ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source) ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source)
if source.Source == ctx.Server.ID {
nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of nodes returned")
}
ctx.NodeCache[ctx.Server.ID] = nodes[0]
}
default: default:
return nil, fmt.Errorf("Don't know how to handle %+v", source) return nil, fmt.Errorf("Don't know how to handle %+v", source)
} }
@ -1155,11 +1173,19 @@ func NewGQLExtContext() *GQLExtContext {
return &context return &context
} }
type SubscriptionInfo struct {
ID uuid.UUID
Channel chan interface{}
}
type GQLExt struct { type GQLExt struct {
tcp_listener net.Listener tcp_listener net.Listener
http_server *http.Server http_server *http.Server
http_done sync.WaitGroup http_done sync.WaitGroup
subscriptions []SubscriptionInfo
subscriptions_lock sync.RWMutex
// map of read request IDs to response channels // map of read request IDs to response channels
resolver_response map[uuid.UUID]chan Signal resolver_response map[uuid.UUID]chan Signal
resolver_response_lock sync.RWMutex resolver_response_lock sync.RWMutex
@ -1169,6 +1195,41 @@ type GQLExt struct {
Listen string `gv:"listen"` Listen string `gv:"listen"`
} }
func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) {
ext.subscriptions_lock.Lock()
defer ext.subscriptions_lock.Unlock()
for _, info := range(ext.subscriptions) {
if info.ID == id {
return nil, fmt.Errorf("%+v already in subscription list", info.ID)
}
}
c := make(chan interface{}, 1)
ext.subscriptions = append(ext.subscriptions, SubscriptionInfo{
id,
c,
})
return c, nil
}
func (ext *GQLExt) RemoveSubscription(id uuid.UUID) error {
ext.subscriptions_lock.Lock()
defer ext.subscriptions_lock.Unlock()
for i, info := range(ext.subscriptions) {
if info.ID == id {
ext.subscriptions[i] = ext.subscriptions[len(ext.subscriptions)]
ext.subscriptions = ext.subscriptions[:len(ext.subscriptions)-1]
return nil
}
}
return fmt.Errorf("%+v not in subscription list", id)
}
func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal { func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock() ext.resolver_response_lock.RLock()
response_chan, _ := ext.resolver_response[req_id] response_chan, _ := ext.resolver_response[req_id]
@ -1232,6 +1293,18 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else { } else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
} }
case *StatusSignal:
ext.subscriptions_lock.RLock()
ctx.Log.Logf("gql", "forwarding status signal from %+v to resolvers %+v", sig.Source, ext.subscriptions)
for _, resolver := range(ext.subscriptions) {
select {
case resolver.Channel <- sig:
ctx.Log.Logf("gql_subscribe", "forwarded status signal to resolver: %+v", resolver.ID)
default:
ctx.Log.Logf("gql_subscribe", "resolver channel overflow: %+v", resolver.ID)
}
}
ext.subscriptions_lock.RUnlock()
} }
return nil return nil
} }
@ -1308,6 +1381,7 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*G
return &GQLExt{ return &GQLExt{
Listen: listen, Listen: listen,
resolver_response: map[uuid.UUID]chan Signal{}, resolver_response: map[uuid.UUID]chan Signal{},
subscriptions: []SubscriptionInfo{},
TLSCert: tls_cert, TLSCert: tls_cert,
TLSKey: tls_key, TLSKey: tls_key,
}, nil }, nil

@ -149,10 +149,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
} }
} }
} }
ctx.Context.Log.Logf("gql_subscribe", "CACHED_EXISTING_NODE: %+v", response.NodeID)
} else { } else {
ctx.NodeCache[response.NodeID] = responses[idx] ctx.NodeCache[response.NodeID] = responses[idx]
ctx.Context.Log.Logf("gql_subscribe", "CACHED_NEW_NODE: %+v", response.NodeID)
} }
} }

@ -19,7 +19,7 @@ import (
) )
func TestGQLServer(t *testing.T) { func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "gql_subscribe"}) ctx := logTestContext(t, []string{"test"})
TestNodeType := NewNodeType("TEST") TestNodeType := NewNodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
@ -57,6 +57,7 @@ func TestGQLServer(t *testing.T) {
gql_id: { gql_id: {
SerializedType(LinkSignalType): nil, SerializedType(LinkSignalType): nil,
SerializedType(ReadSignalType): nil, SerializedType(ReadSignalType): nil,
SerializedType(LockSignalType): nil,
}, },
}) })
@ -192,6 +193,15 @@ func TestGQLServer(t *testing.T) {
n, err = ws.Read(resp) n, err = ws.Read(resp)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "SUB: %s", resp[:n]) ctx.Log.Logf("test", "SUB: %s", resp[:n])
msgs := Messages{}
msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStatusSignal(gql.ID, "test_status"), gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB: %s", resp[:n])
} }
SubGQL(sub_1) SubGQL(sub_1)

@ -30,6 +30,7 @@ func (listener *ListenerExt) Type() ExtType {
// Send the signal to the channel, logging an overflow if it occurs // Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, reflect.TypeOf(signal)) ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, reflect.TypeOf(signal))
ctx.Log.Logf("listener_debug", "LISTENER_DETAIL %+v", signal)
select { select {
case ext.Chan <- signal: case ext.Chan <- signal:
default: default: