diff --git a/context.go b/context.go index a0d37b0..c9c8e0b 100644 --- a/context.go +++ b/context.go @@ -2,6 +2,7 @@ package graphvent import ( badger "github.com/dgraph-io/badger/v3" + "github.com/graphql-go/graphql" "fmt" "sync" "errors" @@ -273,12 +274,14 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = gql_ctx.RegisterNodeType(GQLNodeType, TypeGQLNode.Type) + err = RegisterField(gql_ctx, graphql.String, "Listen", GQLExtType, "listen", func(listen string) (interface{}, error) { + return listen, nil + }) if err != nil { return nil, err } - err = gql_ctx.RegisterField("Listen", GQLExtType, "listen") + err = gql_ctx.RegisterNodeType(GQLNodeType, "GQLServer", NodeInterfaces, []string{"Listen"}) if err != nil { return nil, err } diff --git a/gql.go b/gql.go index c8b2b9d..b018be7 100644 --- a/gql.go +++ b/gql.go @@ -459,29 +459,10 @@ type Type struct { List *graphql.List } -func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type { - var gql Type - gql.Type = graphql.NewObject(graphql.ObjectConfig{ - Name: string(node_type), - Interfaces: interfaces, - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - node, ok := p.Value.(NodeResult) - if ok == false { - return false - } - return node.Result.NodeType == node_type - }, - Fields: graphql.Fields{}, - }) - gql.List = graphql.NewList(gql.Type) - - init(&gql) - return &gql -} - type Field struct { Ext ExtType Name string + Field *graphql.Field } // GQL Specific Context information @@ -491,7 +472,7 @@ type GQLExtContext struct { // Custom graphql types, mapped to NodeTypes NodeTypes map[NodeType]*graphql.Object - Interfaces []*Interface + Interfaces map[string]*Interface Fields map[string]Field // Schema parameters @@ -536,17 +517,48 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { return graphql.NewSchema(schemaConfig) } -func (ctx *GQLExtContext) RegisterField(gql_name string, ext ExtType, acl_name string) error { +func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(T)(interface{}, error)) error { + if ctx == nil { + return fmt.Errorf("ctx is nil") + } + + if resolve_fn == nil { + return fmt.Errorf("resolve_fn cannot be nil") + } + _, exists := ctx.Fields[gql_name] if exists == true { return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name) } - ctx.Fields[gql_name] = Field{ext, acl_name} + ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{ + Type: gql_type, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResult(p, func(result NodeResult) (interface{}, error) { + ext, exists := result.Result.Extensions[ext_type] + if exists == false { + return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) + } + + val_if, exists := ext[acl_name] + if exists == false { + return nil, fmt.Errorf("%s is not in the fields of %s in the result", acl_name, ext_type) + } + + var zero T + val, ok := val_if.(T) + if ok == false { + return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero)) + } + + return resolve_fn(val) + }) + }, + }} return nil } -func (ctx *GQLExtContext) AddInterface(i *Interface) error { +func (ctx *GQLExtContext) RegisterInterface(i *Interface) error { if i == nil { return fmt.Errorf("interface is nil") } @@ -555,15 +567,21 @@ func (ctx *GQLExtContext) AddInterface(i *Interface) error { return fmt.Errorf("invalid interface, contains nil") } - ctx.Interfaces = append(ctx.Interfaces, i) + name := i.Interface.PrivateName + _, exists := ctx.Interfaces[name] + if exists == true { + return fmt.Errorf("%s is already an interface in ctx", name) + } + + ctx.Interfaces[name] = i ctx.Types = append(ctx.Types, i.Default) return nil } -func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, gql_type *graphql.Object) error { - if gql_type == nil { - return fmt.Errorf("gql_type is nil") +func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, interfaces []*Interface, field_names []string) error { + if field_names == nil { + return fmt.Errorf("fields is nil") } _, exists := ctx.NodeTypes[node_type] @@ -571,6 +589,50 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, gql_type *graphql return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type) } + node_interfaces := make([]*graphql.Interface, len(interfaces)) + for i, in := range(interfaces) { + if_name := in.Interface.PrivateName + _, found := ctx.Interfaces[if_name] + + if found == false { + return fmt.Errorf("%+v is not in GQLExtContext.Interfaces", in) + } + node_interfaces[i] = in.Interface + } + + fields := graphql.Fields{ + "ID": &graphql.Field{ + Type: graphql.String, + Resolve: ResolveNodeID, + }, + "TypeHash": &graphql.Field{ + Type: graphql.String, + Resolve: ResolveNodeTypeHash, + }, + } + + for _, name := range(field_names) { + field, exists := ctx.Fields[name] + if exists == false { + return fmt.Errorf("%s is not in GQLExtContext.Fields", name) + } + fields[name] = field.Field + } + + gql_type := graphql.NewObject(graphql.ObjectConfig{ + Name: name, + Interfaces: node_interfaces, + IsTypeOf: func(p graphql.IsTypeOfParams) bool { + node, ok := p.Value.(NodeResult) + if ok == false { + return false + } + + return node.Result.NodeType == node_type + }, + Fields: fields, + }) + ctx.NodeTypes[node_type] = gql_type ctx.Types = append(ctx.Types, gql_type) @@ -608,16 +670,16 @@ func NewGQLExtContext() *GQLExtContext { Mutation: mutation, Subscription: subscription, NodeTypes: map[NodeType]*graphql.Object{}, - Interfaces: []*Interface{}, + Interfaces: map[string]*Interface{}, Fields: map[string]Field{}, } var err error - err = context.AddInterface(InterfaceNode) + err = context.RegisterInterface(InterfaceNode) if err != nil { panic(err) } - err = context.AddInterface(InterfaceLockable) + err = context.RegisterInterface(InterfaceLockable) if err != nil { panic(err) } diff --git a/gql_interfaces.go b/gql_interfaces.go index cd8fa2d..f0539e2 100644 --- a/gql_interfaces.go +++ b/gql_interfaces.go @@ -139,6 +139,17 @@ var InterfaceNode = NewInterface("Node", "DefaultNode", []*graphql.Interface{}, var InterfaceLockable = NewInterface("Lockable", "DefaultLockable", []*graphql.Interface{InterfaceNode.Interface}, []ExtType{LockableExtType}, func(gql *Interface) { addLockableInterfaceFields(gql, gql) }, func(gql *Interface) { - addLockableFields(gql.Default, gql.Interface, gql.List) + AddNodeFields(gql.Default) + gql.Default.AddFieldConfig("Requirements", &graphql.Field{ + Type: gql.List, + }) + + gql.Default.AddFieldConfig("Owner", &graphql.Field{ + Type: gql.Interface, + }) + + gql.Default.AddFieldConfig("Dependencies", &graphql.Field{ + Type: gql.List, + }) }) diff --git a/gql_resolvers.go b/gql_resolvers.go index e545510..4d0f216 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -85,70 +85,6 @@ func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { }) } -func ResolveNodeResultExt[T any](p graphql.ResolveParams, ext_type ExtType, field string, resolve_fn func(T)(interface{}, error)) (interface{}, error) { - return ResolveNodeResult(p, func(result NodeResult) (interface{}, error) { - ext, exists := result.Result.Extensions[ext_type] - if exists == false { - return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) - } - - val_if, exists := ext[field] - if exists == false { - return nil, fmt.Errorf("%s is not in the fields of %s in the result", field, ext_type) - } - - var zero T - val, ok := val_if.(T) - if ok == false { - return nil, fmt.Errorf("%s.%s is not %s", ext_type, field, reflect.TypeOf(zero)) - } - - return resolve_fn(val) - }) -} - -func ResolveListen(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResultExt(p, GQLExtType, "listen", func(listen string) (interface{}, error) { - return listen, nil - }) -} - -func ResolveRequirements(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResultExt(p, LockableExtType, "requirements", func(requirements []NodeID) (interface{}, error) { - res := make([]string, len(requirements)) - for i, id := range(requirements) { - res[i] = id.String() - } - return res, nil - }) -} - -func ResolveDependencies(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResultExt(p, LockableExtType, "dependencies", func(dependencies []NodeID) (interface{}, error) { - res := make([]string, len(dependencies)) - for i, id := range(dependencies) { - res[i] = id.String() - } - return res, nil - }) -} - -func ResolveOwner(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResultExt(p, LockableExtType, "owner", func(owner NodeID) (interface{}, error) { - return owner.String(), nil - }) -} - -func ResolveMembers(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResultExt(p, GroupExtType, "members", func(members []NodeID) (interface{}, error) { - res := make([]string, len(members)) - for i, id := range(members) { - res[i] = id.String() - } - return res, nil - }) -} - func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { if signal, ok := p.Source.(Signal); ok { return fn(signal, p) diff --git a/gql_test.go b/gql_test.go index 5c60870..18763a0 100644 --- a/gql_test.go +++ b/gql_test.go @@ -47,7 +47,7 @@ func TestGQL(t *testing.T) { } req_2 := GQLPayload{ - Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQL { Listen } } }", + Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Listen } } }", Variables: map[string]interface{}{ "id": gql.ID.String(), }, diff --git a/gql_types.go b/gql_types.go index 676d646..a540700 100644 --- a/gql_types.go +++ b/gql_types.go @@ -16,40 +16,29 @@ func AddNodeFields(object *graphql.Object) { }) } -func AddLockableFields(object *graphql.Object) { - addLockableFields(object, InterfaceLockable.Interface, InterfaceLockable.List) -} - -func addLockableFields(object *graphql.Object, lockable_interface *graphql.Interface, lockable_list *graphql.List) { - AddNodeFields(object) - object.AddFieldConfig("Requirements", &graphql.Field{ - Type: lockable_list, - Resolve: ResolveRequirements, - }) +var NodeInterfaces = []*Interface{InterfaceNode} +var LockableInterfaces = append(NodeInterfaces, InterfaceLockable) - object.AddFieldConfig("Owner", &graphql.Field{ - Type: lockable_interface, - Resolve: ResolveOwner, +func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*Type)) *Type { + var gql Type + gql.Type = graphql.NewObject(graphql.ObjectConfig{ + Name: string(node_type), + Interfaces: interfaces, + IsTypeOf: func(p graphql.IsTypeOfParams) bool { + node, ok := p.Value.(NodeResult) + if ok == false { + return false + } + return node.Result.NodeType == node_type + }, + Fields: graphql.Fields{}, }) + gql.List = graphql.NewList(gql.Type) - object.AddFieldConfig("Dependencies", &graphql.Field{ - Type: lockable_list, - Resolve: ResolveDependencies, - }) + init(&gql) + return &gql } -var GQLNodeInterfaces = []*graphql.Interface{InterfaceNode.Interface} -var GQLLockableInterfaces = append(GQLNodeInterfaces, InterfaceLockable.Interface) - -var TypeGQLNode = NewGQLNodeType(GQLNodeType, GQLNodeInterfaces, func(gql *Type) { - AddNodeFields(gql.Type) - - gql.Type.AddFieldConfig("Listen", &graphql.Field{ - Type: graphql.String, - Resolve: ResolveListen, - }) -}) - var TypeSignal = NewSingleton(func() *graphql.Object { gql_type_signal := graphql.NewObject(graphql.ObjectConfig{ Name: "Signal",