From e92b2e508da0e424e61d7aad539f2f0bcdf67d75 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 29 Jul 2023 22:16:54 -0600 Subject: [PATCH] Added tests for lockable gql fields --- gql.go | 93 ++++++++++++++++++++++++++++++++++++++--------------- gql_test.go | 10 +++++- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/gql.go b/gql.go index ed93112..7d56e0f 100644 --- a/gql.go +++ b/gql.go @@ -531,29 +531,31 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name) } - ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{ - Type: gql_type, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) { - ext, exists := result.Result.Extensions[ext_type] - if exists == false { - return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) - } + resolver := func(p graphql.ResolveParams)(interface{}, error) { + return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) { + ext, exists := result.Result.Extensions[ext_type] + if exists == false { + return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) + } - val_if, exists := ext[acl_name] - if exists == false { - return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type) - } + val_if, exists := ext[acl_name] + if exists == false { + return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type) + } - var zero T - val, ok := val_if.(T) - if ok == false { - return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero)) - } + var zero T + val, ok := val_if.(T) + if ok == false { + return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero)) + } - return resolve_fn(p, val) - }) - }, + return resolve_fn(p, val) + }) + } + + ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{ + Type: gql_type, + Resolve: resolver, }} return nil } @@ -651,7 +653,8 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in }) ctx_interface.List = graphql.NewList(ctx_interface.Interface) - for field_name, self_field := range(self_fields) { + for field_name, field := range(self_fields) { + self_field := field err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, func(p graphql.ResolveParams, val interface{})(interface{}, error) { ctx, err := PrepResolve(p) @@ -681,9 +684,9 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in node_fields[field_name] = ctx.Fields[field_name].Field } - for field_name, list_field := range(list_fields) { - err := RegisterField(ctx, ctx_interface.Interface, field_name, list_field.Extension, list_field.ACLName, - func(p graphql.ResolveParams, val interface{})(interface{}, error) { + for field_name, field := range(list_fields) { + list_field := field + resolve_fn := func(p graphql.ResolveParams, val interface{})(interface{}, error) { ctx, err := PrepResolve(p) if err != nil { return nil, err @@ -702,7 +705,9 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in return nil, fmt.Errorf("wrong length of nodes returned") } return nodes, nil - }) + } + + err := RegisterField(ctx, ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn) if err != nil { return err } @@ -806,6 +811,42 @@ func NewGQLExtContext() *GQLExtContext { }, }, }, map[string]ListField{ + "Requirements": ListField{ + "requirements", + LockableExtType, + func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { + id_strs, ok := val.(map[NodeID]ReqState) + if ok == false { + return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val)) + } + + ids := make([]NodeID, len(id_strs)) + i := 0 + for id, _ := range(id_strs) { + ids[i] = id + i++ + } + return ids, nil + }, + }, + "Dependencies": ListField{ + "dependencies", + LockableExtType, + func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { + id_strs, ok := val.(map[NodeID]string) + if ok == false { + return nil, fmt.Errorf("can't parse dependencies %+v as string, %s", val, reflect.TypeOf(val)) + } + + ids := make([]NodeID, len(id_strs)) + i := 0 + for id, _ := range(id_strs) { + ids[i] = id + i++ + } + return ids, nil + }, + }, }) if err != nil { @@ -819,7 +860,7 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } - err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node"}, []string{"Listen"}) + err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node", "Lockable"}, []string{"Listen", "Owner", "Requirements", "Dependencies"}) if err != nil { panic(err) } diff --git a/gql_test.go b/gql_test.go index 18763a0..31127ec 100644 --- a/gql_test.go +++ b/gql_test.go @@ -26,6 +26,14 @@ func TestGQL(t *testing.T) { gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), gql_ext, NewGroupExt(nil)) n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), listener_ext) + err = LinkRequirement(ctx, gql.ID, n1.ID) + fatalErr(t, err) + + _, err = WaitForSignal(ctx, listener_ext, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool { + return sig.State == "linked_as_dep" + }) + fatalErr(t, err) + ctx.Send(n1.ID, gql.ID, StateSignal{NewDirectSignal(GQLStateSignalType), "start_server"}) _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StateSignal) bool { return sig.State == "server_started" @@ -47,7 +55,7 @@ func TestGQL(t *testing.T) { } req_2 := GQLPayload{ - Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Listen } } }", + Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Listen, Requirements { ID, TypeHash, Dependencies { ID } } } } }", Variables: map[string]interface{}{ "id": gql.ID.String(), },