Updated gql subscriptions to be send less messages

master
noah metz 2024-03-25 18:49:52 -06:00
parent ab76f09923
commit a4115a4f99
5 changed files with 157 additions and 205 deletions

@ -47,8 +47,6 @@ type TypeInfo struct {
Serialize SerializeFn
Deserialize DeserializeFn
Resolve func(interface{},graphql.ResolveParams)(interface{},error)
}
type ExtensionInfo struct {
@ -135,7 +133,6 @@ func (ctx *Context) GQLType(t reflect.Type, node_type string) (graphql.Type, err
return nil, err
}
map_type := ctx.TypeTypes[t].Type
ctx.Log.Logf("gql", "Getting type for %s: %s", t, map_type)
return map_type, nil
}
case reflect.Pointer:
@ -224,24 +221,6 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro
Serialized: serialized_type,
Reflect: reflect_type,
Type: gql_map,
Resolve: func(v interface{},p graphql.ResolveParams) (interface{}, error) {
val := reflect.ValueOf(v)
if val.Type() != (reflect_type) {
return nil, fmt.Errorf("%s is not %s", val.Type(), reflect_type)
} else {
pairs := make([]Pair, val.Len())
iter := val.MapRange()
i := 0
for iter.Next() {
pairs[i] = Pair{
Key: iter.Key().Interface(),
Val: iter.Value().Interface(),
}
i += 1
}
return pairs, nil
}
},
}
ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type]
@ -286,17 +265,12 @@ func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema,
return nil, err
}
c, err := ctx.Ext.AddSubscription(ctx.ID)
if err != nil {
return nil, err
}
first_result, err := query.Resolve(p)
c, err := ctx.Ext.AddSubscription(ctx.ID, ctx)
if err != nil {
return nil, err
}
c <- first_result
c <- nil
return c, nil
},
Resolve: query.Resolve,
@ -477,11 +451,7 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error {
return nil, fmt.Errorf("Can't resolve unknown NodeType %s", node.NodeType)
}
if gql_resolve != nil {
return gql_resolve(node.Data[node_info.Fields[field_name]][field_name], p)
} else {
return node.Data[node_info.Fields[field_name]][field_name], nil
}
return gql_resolve(node.Data[node_info.Fields[field_name]][field_name], p)
},
})
}
@ -588,7 +558,6 @@ func RegisterObject[T any](ctx *Context) error {
Reflect: reflect_type,
Fields: field_infos,
Type: gql,
Resolve: nil,
}
ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type]
@ -631,7 +600,6 @@ func RegisterObjectNoGQL[T any](ctx *Context) error {
Reflect: reflect_type,
Fields: field_infos,
Type: nil,
Resolve: nil,
}
ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type]
@ -781,14 +749,13 @@ func RegisterEnum[E comparable](ctx *Context, str_map map[E]string) error {
Serialized: serialized_type,
Reflect: reflect_type,
Type: gql,
Resolve: nil,
}
ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type]
return nil
}
func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}, resolve func(interface{},graphql.ResolveParams)(interface{},error)) error {
func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}) error {
reflect_type := reflect.TypeFor[S]()
serialized_type := SerializedTypeFor[S]()
@ -809,7 +776,6 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f
Serialized: serialized_type,
Reflect: reflect_type,
Type: gql,
Resolve: resolve,
}
ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type]
@ -920,18 +886,46 @@ func (ctx *Context) Send(node *Node, messages []SendMsg) error {
return nil
}
func resolveNodeID(val interface{}, p graphql.ResolveParams) (interface{}, error) {
id, ok := p.Source.(NodeID)
if ok == false {
return nil, fmt.Errorf("%+v is not NodeID", p.Source)
}
return ResolveNode(id, p)
}
// TODO: Cache these functions so they're not duplicated when called with the same t
func (ctx *Context)GQLResolve(t reflect.Type, node_type string) (func(interface{},graphql.ResolveParams)(interface{},error)) {
info, mapped := ctx.TypeTypes[t]
if mapped {
return info.Resolve
if t == reflect.TypeFor[NodeID]() {
return resolveNodeID
} else {
switch t.Kind() {
//case reflect.Array:
//case reflect.Slice:
case reflect.Map:
return func(v interface{}, p graphql.ResolveParams) (interface{}, error) {
val := reflect.ValueOf(v)
if val.Type() != t {
return nil, fmt.Errorf("%s is not %s", reflect.TypeOf(val), t)
} else {
pairs := make([]Pair, val.Len())
iter := val.MapRange()
i := 0
for iter.Next() {
pairs[i] = Pair{
Key: iter.Key().Interface(),
Val: iter.Value().Interface(),
}
i += 1
}
return pairs, nil
}
}
case reflect.Pointer:
return ctx.GQLResolve(t.Elem(), node_type)
default:
return nil
return func(v interface{}, p graphql.ResolveParams) (interface{}, error) {
return v, nil
}
}
}
}
@ -976,20 +970,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
}
var err error
err = RegisterScalar[NodeID](ctx, stringify, unstringify[NodeID], unstringifyAST[NodeID], func(v interface{}, p graphql.ResolveParams)(interface{}, error) {
id, ok := v.(NodeID)
if ok == false {
return nil, fmt.Errorf("%+v is not NodeID", v)
}
node, err := ResolveNode(id, p)
if err != nil {
return nil, err
}
return node, nil
})
err = RegisterScalar[NodeID](ctx, stringify, unstringify[NodeID], unstringifyAST[NodeID])
if err != nil {
return nil, err
}
@ -1004,12 +986,12 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType], nil)
err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType])
if err != nil {
return nil, err
}
err = RegisterScalar[ExtType](ctx, identity, coerce[ExtType], astInt[ExtType], nil)
err = RegisterScalar[ExtType](ctx, identity, coerce[ExtType], astInt[ExtType])
if err != nil {
return nil, err
}
@ -1019,37 +1001,32 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool], nil)
err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool])
if err != nil {
return nil, err
}
err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int], nil)
err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int])
if err != nil {
return nil, err
}
err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32], nil)
err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32])
if err != nil {
return nil, err
}
err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8], nil)
err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8])
if err != nil {
return nil, err
}
err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time], nil)
err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time])
if err != nil {
return nil, err
}
err = RegisterScalar[string](ctx, identity, coerce[string], astString[string], nil)
if err != nil {
return nil, err
}
err = RegisterScalar[EventState](ctx, identity, coerce[EventState], astString[EventState], nil)
err = RegisterScalar[string](ctx, identity, coerce[string], astString[string])
if err != nil {
return nil, err
}
@ -1059,12 +1036,12 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterScalar[uuid.UUID](ctx, stringify, unstringify[uuid.UUID], unstringifyAST[uuid.UUID], nil)
err = RegisterScalar[uuid.UUID](ctx, stringify, unstringify[uuid.UUID], unstringifyAST[uuid.UUID])
if err != nil {
return nil, err
}
err = RegisterScalar[Change](ctx, identity, coerce[Change], astString[Change], nil)
err = RegisterScalar[Change](ctx, identity, coerce[Change], astString[Change])
if err != nil {
return nil, err
}
@ -1095,11 +1072,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterExtension[EventExt](ctx, nil)
if err != nil {
return nil, err
}
err = RegisterExtension[ListenerExt](ctx, nil)
if err != nil {
return nil, err
@ -1120,11 +1092,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = RegisterObject[EventExt](ctx)
if err != nil {
return nil, err
}
err = RegisterObject[ListenerExt](ctx)
if err != nil {
return nil, err
@ -1145,24 +1112,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
if err != nil {
return nil, err
}
switch source := p.Source.(type) {
case *StatusSignal:
ctx.Context.Log.Logf("test", "StatusSignal: %+v", source)
cached_node, cached := ctx.NodeCache[source.Source]
ctx.Context.Log.Logf("test", "Cached: %t", cached)
if cached {
for ext_type, ext_changes := range(source.Changes) {
cached_ext, cached := cached_node.Data[ext_type]
if cached {
for _, field := range(ext_changes) {
delete(cached_ext, string(field))
}
cached_node.Data[ext_type] = cached_ext
}
}
ctx.NodeCache[source.Source] = cached_node
}
}
return ResolveNode(ctx.Server.ID, p)
},
},

@ -249,13 +249,14 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
for header, value := range(r.Header) {
header_map[header] = value
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(""))
return
} else {
ctx.Log.Logf("gql", "New Query: %s", resolve_context.ID)
}
req_ctx := context.Background()
@ -304,7 +305,6 @@ func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result {
return resultChannel
}
func getOperationTypeOfReq(p graphql.Params) string{
source := source.NewSource(&source.Source{
Body: []byte(p.RequestString),
@ -330,18 +330,6 @@ func getOperationTypeOfReq(p graphql.Params) string{
return "END_OF_FUNCTION"
}
func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result {
operation := getOperationTypeOfReq(p)
ctx.Log.Logf("gqlws", "GQLWSDO_OPERATION: %s - %+v", operation, p.RequestString)
if operation == ast.OperationTypeSubscription {
return graphql.Subscribe(p)
} else {
res := graphql.Do(p)
return sendOneResultAndClose(res)
}
}
func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr)
@ -351,11 +339,12 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
header_map[header] = value
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return
} else {
ctx.Log.Logf("gql", "New Subscription: %s", resolve_context.ID)
}
req_ctx := context.Background()
@ -429,11 +418,14 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
params.VariableValues = msg.Payload.Variables
}
res_chan := GQLWSDo(ctx, params)
if res_chan == nil {
ctx.Log.Logf("gqlws", "res_chan is nil")
var res_chan chan *graphql.Result
operation := getOperationTypeOfReq(params)
if operation == ast.OperationTypeSubscription {
res_chan = graphql.Subscribe(params)
} else {
ctx.Log.Logf("gqlws", "res_chan: %+v", res_chan)
res := graphql.Do(params)
res_chan = sendOneResultAndClose(res)
}
go func(res_chan chan *graphql.Result) {
@ -526,6 +518,7 @@ type SelfField struct {
type SubscriptionInfo struct {
ID uuid.UUID
NodeCache *map[NodeID]NodeResult
Channel chan interface{}
}
@ -569,7 +562,7 @@ func (ext *GQLExt) PostDeserialize(*Context) error {
return nil
}
func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) {
func (ext *GQLExt) AddSubscription(id uuid.UUID, ctx *ResolveContext) (chan interface{}, error) {
ext.subscriptions_lock.Lock()
defer ext.subscriptions_lock.Unlock()
@ -583,6 +576,7 @@ func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) {
ext.subscriptions = append(ext.subscriptions, SubscriptionInfo{
id,
&ctx.NodeCache,
c,
})
@ -645,8 +639,6 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
}
} else {
ctx.Log.Logf("gql", "received success signal response %+v with no mapped resolver", sig)
}
case *ErrorSignal:
@ -659,9 +651,6 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
}
} else {
ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
}
case *ReadResultSignal:
@ -669,23 +658,22 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
if response_chan != nil {
select {
case response_chan <- sig:
ctx.Log.Logf("gql", "Forwarded to resolver, %+v", sig)
default:
ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
}
} else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
}
case *StatusSignal:
ext.subscriptions_lock.RLock()
ctx.Log.Logf("gql", "forwarding status signal from %+v to resolvers %+v", sig.Source, ext.subscriptions)
for _, resolver := range(ext.subscriptions) {
select {
case resolver.Channel <- sig:
ctx.Log.Logf("gql_subscribe", "forwarded status signal to resolver: %+v", resolver.ID)
default:
ctx.Log.Logf("gql_subscribe", "resolver channel overflow: %+v", resolver.ID)
for _, sub := range(ext.subscriptions) {
_, cached := (*sub.NodeCache)[sig.Source]
if cached {
select {
case sub.Channel <- sig:
ctx.Log.Logf("gql", "forwarded status signal %+v to subscription: %s", sig, sub.ID)
default:
ctx.Log.Logf("gql", "subscription channel overflow: %s", sub.ID)
}
}
}
ext.subscriptions_lock.RUnlock()

@ -64,16 +64,43 @@ func GetFields(ctx *Context, node_type string, selection_set *ast.SelectionSet)
}
// Returns the fields that need to be resolved
func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) (map[ExtType][]string, error) {
m := map[ExtType][]string{}
func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) []FieldIndex {
fields := []FieldIndex{}
for _, field := range(p.Info.FieldASTs) {
fields = append(fields, GetFields(ctx.Context, p.Info.ReturnType.Name(), field.SelectionSet)...)
}
return fields
}
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
ctx, err := PrepResolve(p)
if err != nil {
return NodeResult{}, err
}
switch source := p.Source.(type) {
case *StatusSignal:
cached_node, cached := ctx.NodeCache[source.Source]
if cached {
for ext_type, ext_changes := range(source.Changes) {
cached_ext, cached := cached_node.Data[ext_type]
if cached {
for _, field := range(ext_changes) {
delete(cached_ext, string(field))
}
cached_node.Data[ext_type] = cached_ext
}
}
ctx.NodeCache[source.Source] = cached_node
}
}
cache, node_cached := ctx.NodeCache[id]
fields := GetResolveFields(id, ctx, p)
not_cached := map[ExtType][]string{}
for _, field := range(fields) {
ext_fields, exists := m[field.Extension]
ext_fields, exists := not_cached[field.Extension]
if exists == false {
ext_fields = []string{}
}
@ -88,71 +115,60 @@ func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) (
}
}
m[field.Extension] = append(ext_fields, field.Tag)
not_cached[field.Extension] = append(ext_fields, field.Tag)
}
return m, nil
}
// TODO: instead of doing the read right away, check if any new fields need to be read
func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
ctx, err := PrepResolve(p)
if err != nil {
return NodeResult{}, err
}
if (len(not_cached) == 0) && (node_cached == true) {
ctx.Context.Log.Logf("gql", "No new fields to resolve for %s", id)
return cache, nil
} else {
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", not_cached, id)
signal := NewReadSignal(not_cached)
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
// TODO: TIMEOUT DURATION
err = ctx.Context.Send(ctx.Server, []SendMsg{{
Dest: id,
Signal: signal,
}})
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return NodeResult{}, err
}
fields, err := GetResolveFields(id, ctx, p)
if err != nil {
return NodeResult{}, err
}
ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", fields, id)
signal := NewReadSignal(fields)
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
// TODO: TIMEOUT DURATION
err = ctx.Context.Send(ctx.Server, []SendMsg{{
Dest: id,
Signal: signal,
}})
if err != nil {
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
ctx.Ext.FreeResponseChannel(signal.ID())
return NodeResult{}, err
}
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
ctx.Ext.FreeResponseChannel(signal.ID())
if err != nil {
return NodeResult{}, err
}
if err != nil {
return NodeResult{}, err
}
switch response := response.(type) {
case *ReadResultSignal:
cache, node_cached := ctx.NodeCache[id]
if node_cached == false {
cache = NodeResult{
NodeID: id,
NodeType: response.NodeType,
Data: response.Extensions,
}
} else {
for ext_type, ext_data := range(response.Extensions) {
cached_ext, ext_cached := cache.Data[ext_type]
if ext_cached {
for field_name, field := range(ext_data) {
cache.Data[ext_type][field_name] = field
}
} else {
cache.Data[ext_type] = ext_data
switch response := response.(type) {
case *ReadResultSignal:
if node_cached == false {
cache = NodeResult{
NodeID: id,
NodeType: response.NodeType,
Data: response.Extensions,
}
} else {
for ext_type, ext_data := range(response.Extensions) {
cached_ext, ext_cached := cache.Data[ext_type]
if ext_cached {
for field_name, field := range(ext_data) {
cache.Data[ext_type][field_name] = field
}
} else {
cache.Data[ext_type] = ext_data
}
cache.Data[ext_type] = cached_ext
cache.Data[ext_type] = cached_ext
}
}
}
ctx.NodeCache[id] = cache
return ctx.NodeCache[id], nil
default:
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
ctx.NodeCache[id] = cache
return ctx.NodeCache[id], nil
default:
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
}
}
}

@ -16,9 +16,8 @@ import (
"golang.org/x/net/websocket"
)
func TestGQLSubscribe(t *testing.T) {
ctx := logTestContext(t, []string{"test"})
ctx := logTestContext(t, []string{"test", "gql"})
n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil))
fatalErr(t, err)
@ -31,11 +30,14 @@ func TestGQLSubscribe(t *testing.T) {
gql, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext)
fatalErr(t, err)
query := "subscription { Self { ID, Type ... on Lockable { lockable_state } } }"
ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID)
ctx.Log.Logf("test", "Node: %s", n1.ID)
ctx.Log.Logf("test", "Query: %s", query)
sub_1 := GQLPayload{
Query: "subscription { Self { ID, Type ... on Lockable { lockable_state } } }",
Query: query,
}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
@ -117,10 +119,6 @@ func TestGQLSubscribe(t *testing.T) {
fatalErr(t, err)
ctx.Log.Logf("test", "SUB3: %s", resp[:n])
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB4: %s", resp[:n])
// TODO: check that there are no more messages sent to ws within a timeout
}

@ -167,7 +167,7 @@ type StatusSignal struct {
Changes map[ExtType]Changes `gv:"changes"`
}
func (signal StatusSignal) String() string {
return fmt.Sprintf("StatusSignal(%s, %+v)", signal.SignalHeader, signal.Changes)
return fmt.Sprintf("StatusSignal(%s: %+v)", signal.Source, signal.Changes)
}
func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal {
return &StatusSignal{