diff --git a/gql_graph.go b/gql_graph.go index 9067c7d..9a50d25 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -699,22 +699,18 @@ func GQLTypeSignalInput() *graphql.InputObject { } func GQLSubscribeSignal(p graphql.ResolveParams) (interface{}, error) { - return GQLSubscribeFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error) { - if signal == nil { - return nil, nil - } + return GQLSubscribeFn(p, false, func(ctx *GraphContext, server *GQLThread, signal GraphSignal, p graphql.ResolveParams)(interface{}, error) { return signal, nil }) } func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) { - server := p.Context.Value("gql_server").(*GQLThread) - return GQLSubscribeFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error) { + return GQLSubscribeFn(p, true, func(ctx *GraphContext, server *GQLThread, signal GraphSignal, p graphql.ResolveParams)(interface{}, error) { return server, nil }) } -func GQLSubscribeFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { +func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*GraphContext, *GQLThread, GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { server, ok := p.Context.Value("gql_server").(*GQLThread) if ok == false { return nil, fmt.Errorf("Failed to get gql_server from context and cast to GQLServer") @@ -729,13 +725,15 @@ func GQLSubscribeFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.Resolv go func(c chan interface{}, server *GQLThread) { ctx.Log.Logf("gqlws", "GQL_SUBSCRIBE_THREAD_START") sig_c := server.UpdateChannel(1) - sig_c <- nil + if send_nil == true { + sig_c <- nil + } for { val, ok := <- sig_c if ok == false { return } - ret, err := fn(val, p) + ret, err := fn(ctx, server, val, p) if err != nil { ctx.Log.Logf("gqlws", "type convertor error %s", err) return