diff --git a/gql.go b/gql.go index 4073868..8b31955 100644 --- a/gql.go +++ b/gql.go @@ -298,6 +298,9 @@ func checkForAuthHeader(header http.Header) (string, bool) { // Context passed to each resolve execution type ResolveContext struct { + // Resolution ID + ID uuid.UUID + // Channels for the gql extension to route data to this context Chans map[uuid.UUID]chan Signal @@ -324,7 +327,7 @@ type ResolveContext struct { Key ed25519.PrivateKey } -func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { +func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request, id uuid.UUID) (*ResolveContext, error) { id_b64, key_b64, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") @@ -365,6 +368,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ } return &ResolveContext{ + ID: id, Ext: gql_ext, Chans: map[uuid.UUID]chan Signal{}, Context: ctx, @@ -386,7 +390,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - resolve_context, err := NewResolveContext(ctx, server, gql_ext, r) + resolve_context, err := NewResolveContext(ctx, server, gql_ext, r, uuid.New()) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) @@ -487,7 +491,7 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - resolve_context, err := NewResolveContext(ctx, server, gql_ext, r) + resolve_context, err := NewResolveContext(ctx, server, gql_ext, r, uuid.New()) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) return @@ -1065,7 +1069,12 @@ func NewGQLExtContext() *GQLExtContext { if err != nil { return nil, err } - c := make(chan interface{}, 1) + + c, err := ctx.Ext.AddSubscription(ctx.ID) + if err != nil { + return nil, err + } + nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID}) if err != nil { return nil, err @@ -1073,7 +1082,6 @@ func NewGQLExtContext() *GQLExtContext { return nil, fmt.Errorf("wrong length of nodes returned") } - ctx.Context.Log.Logf("gql", "NODES: %+v", nodes[0]) c <- nodes[0] return c, nil @@ -1083,12 +1091,22 @@ func NewGQLExtContext() *GQLExtContext { if err != nil { return nil, err } + ctx.Context.Log.Logf("gql_subscribe", "SUBSCRIBE_RESOLVE: %+v", p.Source) switch source := p.Source.(type) { case NodeResult: - case StatusSignal: + case *StatusSignal: delete(ctx.NodeCache, source.Source) ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source) + if source.Source == ctx.Server.ID { + nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID}) + if err != nil { + return nil, err + } else if len(nodes) != 1 { + return nil, fmt.Errorf("wrong length of nodes returned") + } + ctx.NodeCache[ctx.Server.ID] = nodes[0] + } default: return nil, fmt.Errorf("Don't know how to handle %+v", source) } @@ -1155,11 +1173,19 @@ func NewGQLExtContext() *GQLExtContext { return &context } +type SubscriptionInfo struct { + ID uuid.UUID + Channel chan interface{} +} + type GQLExt struct { tcp_listener net.Listener http_server *http.Server http_done sync.WaitGroup + subscriptions []SubscriptionInfo + subscriptions_lock sync.RWMutex + // map of read request IDs to response channels resolver_response map[uuid.UUID]chan Signal resolver_response_lock sync.RWMutex @@ -1169,6 +1195,41 @@ type GQLExt struct { Listen string `gv:"listen"` } +func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) { + ext.subscriptions_lock.Lock() + defer ext.subscriptions_lock.Unlock() + + for _, info := range(ext.subscriptions) { + if info.ID == id { + return nil, fmt.Errorf("%+v already in subscription list", info.ID) + } + } + + c := make(chan interface{}, 1) + + ext.subscriptions = append(ext.subscriptions, SubscriptionInfo{ + id, + c, + }) + + return c, nil +} + +func (ext *GQLExt) RemoveSubscription(id uuid.UUID) error { + ext.subscriptions_lock.Lock() + defer ext.subscriptions_lock.Unlock() + + for i, info := range(ext.subscriptions) { + if info.ID == id { + ext.subscriptions[i] = ext.subscriptions[len(ext.subscriptions)] + ext.subscriptions = ext.subscriptions[:len(ext.subscriptions)-1] + return nil + } + } + + return fmt.Errorf("%+v not in subscription list", id) +} + func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal { ext.resolver_response_lock.RLock() response_chan, _ := ext.resolver_response[req_id] @@ -1232,6 +1293,18 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } else { ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) } + case *StatusSignal: + ext.subscriptions_lock.RLock() + ctx.Log.Logf("gql", "forwarding status signal from %+v to resolvers %+v", sig.Source, ext.subscriptions) + for _, resolver := range(ext.subscriptions) { + select { + case resolver.Channel <- sig: + ctx.Log.Logf("gql_subscribe", "forwarded status signal to resolver: %+v", resolver.ID) + default: + ctx.Log.Logf("gql_subscribe", "resolver channel overflow: %+v", resolver.ID) + } + } + ext.subscriptions_lock.RUnlock() } return nil } @@ -1308,6 +1381,7 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*G return &GQLExt{ Listen: listen, resolver_response: map[uuid.UUID]chan Signal{}, + subscriptions: []SubscriptionInfo{}, TLSCert: tls_cert, TLSKey: tls_key, }, nil diff --git a/gql_node.go b/gql_node.go index 4849dcb..8591206 100644 --- a/gql_node.go +++ b/gql_node.go @@ -149,10 +149,8 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ } } } - ctx.Context.Log.Logf("gql_subscribe", "CACHED_EXISTING_NODE: %+v", response.NodeID) } else { ctx.NodeCache[response.NodeID] = responses[idx] - ctx.Context.Log.Logf("gql_subscribe", "CACHED_NEW_NODE: %+v", response.NodeID) } } diff --git a/gql_test.go b/gql_test.go index 4b10e03..04fb06e 100644 --- a/gql_test.go +++ b/gql_test.go @@ -19,7 +19,7 @@ import ( ) func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "gql_subscribe"}) + ctx := logTestContext(t, []string{"test"}) TestNodeType := NewNodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) @@ -57,6 +57,7 @@ func TestGQLServer(t *testing.T) { gql_id: { SerializedType(LinkSignalType): nil, SerializedType(ReadSignalType): nil, + SerializedType(LockSignalType): nil, }, }) @@ -192,6 +193,15 @@ func TestGQLServer(t *testing.T) { n, err = ws.Read(resp) fatalErr(t, err) ctx.Log.Logf("test", "SUB: %s", resp[:n]) + + msgs := Messages{} + msgs = msgs.Add(ctx, gql.ID, gql.Key, NewStatusSignal(gql.ID, "test_status"), gql.ID) + err = ctx.Send(msgs) + fatalErr(t, err) + + n, err = ws.Read(resp) + fatalErr(t, err) + ctx.Log.Logf("test", "SUB: %s", resp[:n]) } SubGQL(sub_1) diff --git a/listener.go b/listener.go index 9bbfb74..65cacdf 100644 --- a/listener.go +++ b/listener.go @@ -30,6 +30,7 @@ func (listener *ListenerExt) Type() ExtType { // Send the signal to the channel, logging an overflow if it occurs func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, reflect.TypeOf(signal)) + ctx.Log.Logf("listener_debug", "LISTENER_DETAIL %+v", signal) select { case ext.Chan <- signal: default: