Added tests for lockable gql fields

gql_cataclysm
noah metz 2023-07-29 22:16:54 -06:00
parent dca4de183e
commit e92b2e508d
2 changed files with 76 additions and 27 deletions

@ -531,29 +531,31 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st
return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name)
}
ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{
Type: gql_type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) {
ext, exists := result.Result.Extensions[ext_type]
if exists == false {
return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type)
}
resolver := func(p graphql.ResolveParams)(interface{}, error) {
return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) {
ext, exists := result.Result.Extensions[ext_type]
if exists == false {
return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type)
}
val_if, exists := ext[acl_name]
if exists == false {
return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type)
}
val_if, exists := ext[acl_name]
if exists == false {
return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type)
}
var zero T
val, ok := val_if.(T)
if ok == false {
return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero))
}
var zero T
val, ok := val_if.(T)
if ok == false {
return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero))
}
return resolve_fn(p, val)
})
},
return resolve_fn(p, val)
})
}
ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{
Type: gql_type,
Resolve: resolver,
}}
return nil
}
@ -651,7 +653,8 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
})
ctx_interface.List = graphql.NewList(ctx_interface.Interface)
for field_name, self_field := range(self_fields) {
for field_name, field := range(self_fields) {
self_field := field
err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, val interface{})(interface{}, error) {
ctx, err := PrepResolve(p)
@ -681,9 +684,9 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
node_fields[field_name] = ctx.Fields[field_name].Field
}
for field_name, list_field := range(list_fields) {
err := RegisterField(ctx, ctx_interface.Interface, field_name, list_field.Extension, list_field.ACLName,
func(p graphql.ResolveParams, val interface{})(interface{}, error) {
for field_name, field := range(list_fields) {
list_field := field
resolve_fn := func(p graphql.ResolveParams, val interface{})(interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
@ -702,7 +705,9 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes, nil
})
}
err := RegisterField(ctx, ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
if err != nil {
return err
}
@ -806,6 +811,42 @@ func NewGQLExtContext() *GQLExtContext {
},
},
}, map[string]ListField{
"Requirements": ListField{
"requirements",
LockableExtType,
func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) {
id_strs, ok := val.(map[NodeID]ReqState)
if ok == false {
return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val))
}
ids := make([]NodeID, len(id_strs))
i := 0
for id, _ := range(id_strs) {
ids[i] = id
i++
}
return ids, nil
},
},
"Dependencies": ListField{
"dependencies",
LockableExtType,
func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) {
id_strs, ok := val.(map[NodeID]string)
if ok == false {
return nil, fmt.Errorf("can't parse dependencies %+v as string, %s", val, reflect.TypeOf(val))
}
ids := make([]NodeID, len(id_strs))
i := 0
for id, _ := range(id_strs) {
ids[i] = id
i++
}
return ids, nil
},
},
})
if err != nil {
@ -819,7 +860,7 @@ func NewGQLExtContext() *GQLExtContext {
panic(err)
}
err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node"}, []string{"Listen"})
err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node", "Lockable"}, []string{"Listen", "Owner", "Requirements", "Dependencies"})
if err != nil {
panic(err)
}

@ -26,6 +26,14 @@ func TestGQL(t *testing.T) {
gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), gql_ext, NewGroupExt(nil))
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy), listener_ext)
err = LinkRequirement(ctx, gql.ID, n1.ID)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, time.Millisecond*10, LinkSignalType, func(sig StateSignal) bool {
return sig.State == "linked_as_dep"
})
fatalErr(t, err)
ctx.Send(n1.ID, gql.ID, StateSignal{NewDirectSignal(GQLStateSignalType), "start_server"})
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StateSignal) bool {
return sig.State == "server_started"
@ -47,7 +55,7 @@ func TestGQL(t *testing.T) {
}
req_2 := GQLPayload{
Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Listen } } }",
Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Listen, Requirements { ID, TypeHash, Dependencies { ID } } } } }",
Variables: map[string]interface{}{
"id": gql.ID.String(),
},