From e12d02eb3fd9bcfc2289d2489d92ff4cfa7bd939 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 10 Jul 2023 21:15:01 -0600 Subject: [PATCH] Cleaned up GQL context --- context.go | 221 +++++++++++++++++++++++--------------------------- gql_graph.go | 64 +++++++-------- graph_test.go | 4 +- 3 files changed, 135 insertions(+), 154 deletions(-) diff --git a/context.go b/context.go index b8ad42b..5ea84d9 100644 --- a/context.go +++ b/context.go @@ -19,12 +19,16 @@ type NodeLoadFunc func(*Context, NodeID, []byte, NodeMap)(Node, error) type NodeDef struct { Load NodeLoadFunc Type NodeType + GQLType *graphql.Object + Reflect reflect.Type } -func NewNodeDef(type_name string, load_func NodeLoadFunc) NodeDef { +func NewNodeDef(type_name string, reflect reflect.Type, load_func NodeLoadFunc, gql_type *graphql.Object) NodeDef { return NodeDef{ Type: NodeType(type_name), Load: load_func, + GQLType: gql_type, + Reflect: reflect, } } @@ -32,26 +36,62 @@ type Context struct { DB * badger.DB Log Logger Types map[uint64]NodeDef - GQL * GQLContext + GQL GQLContext } -func (ctx * Context) RegisterNodeType(type_name string, load_func NodeLoadFunc) error { - if load_func == nil { - return fmt.Errorf("Cannot register a node without a load function") +func (ctx * Context) RebuildSchema() error { + schemaConfig := graphql.SchemaConfig{ + Types: ctx.GQL.TypeList, + Query: ctx.GQL.Query, + Mutation: ctx.GQL.Mutation, + Subscription: ctx.GQL.Subscription, } - def := NodeDef{ - Type: NodeType(type_name), - Load: load_func, + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + return err + } + + ctx.GQL.Schema = schema + return nil +} + +func (ctx * Context) AddGQLType(gql_type graphql.Type) { + ctx.GQL.TypeList = append(ctx.GQL.TypeList, gql_type) +} + +func (ctx * Context) RegisterNodeType(def NodeDef) error { + if def.Load == nil { + return fmt.Errorf("Cannot register a node without a load function: %s", def.Type) + } + + if def.Reflect == nil { + return fmt.Errorf("Cannot register a node without a reflect type: %s", def.Type) + } + + if def.GQLType == nil { + return fmt.Errorf("Cannot register a node without a gql type: %s", def.Type) } type_hash := def.Type.Hash() _, exists := ctx.Types[type_hash] if exists == true { - return fmt.Errorf("Cannot register node of type %s, type already exists in context", type_name) + return fmt.Errorf("Cannot register node of type %s, type already exists in context", def.Type) } ctx.Types[type_hash] = def + + if def.Reflect.Implements(ctx.GQL.NodeType) { + ctx.GQL.ValidNodes[def.Reflect] = def.GQLType + } + if def.Reflect.Implements(ctx.GQL.LockableType) { + ctx.GQL.ValidLockables[def.Reflect] = def.GQLType + } + if def.Reflect.Implements(ctx.GQL.ThreadType) { + ctx.GQL.ValidThreads[def.Reflect] = def.GQLType + } + ctx.GQL.TypeList = append(ctx.GQL.TypeList, def.GQLType) + return nil } @@ -61,151 +101,92 @@ type FieldMap map[string]*graphql.Field type GQLContext struct { Schema graphql.Schema - ValidNodes ObjTypeMap + NodeType reflect.Type - ValidLockables ObjTypeMap LockableType reflect.Type - ValidThreads ObjTypeMap ThreadType reflect.Type -} - -func NewGQLContext(additional_types TypeList, extended_types ObjTypeMap, extended_queries FieldMap, extended_subscriptions FieldMap, extended_mutations FieldMap) (*GQLContext, error) { - type_list := TypeList{ - GQLTypeSignalInput(), - } - - for _, gql_type := range(additional_types) { - type_list = append(type_list, gql_type) - } - - type_map := ObjTypeMap{} - type_map[reflect.TypeOf((*GraphNode)(nil))] = GQLTypeBaseNode() - type_map[reflect.TypeOf((*SimpleLockable)(nil))] = GQLTypeBaseLockable() - type_map[reflect.TypeOf((*SimpleThread)(nil))] = GQLTypeBaseThread() - type_map[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread() - type_map[reflect.TypeOf((*BaseSignal)(nil))] = GQLTypeSignal() - - for go_t, gql_t := range(extended_types) { - type_map[go_t] = gql_t - } - valid_nodes := ObjTypeMap{} - valid_lockables := ObjTypeMap{} - valid_threads := ObjTypeMap{} + TypeList TypeList - node_type := reflect.TypeOf((*Node)(nil)).Elem() - lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() - thread_type := reflect.TypeOf((*Thread)(nil)).Elem() - - for go_t, gql_t := range(type_map) { - if go_t.Implements(node_type) { - valid_nodes[go_t] = gql_t - } - if go_t.Implements(lockable_type) { - valid_lockables[go_t] = gql_t - } - if go_t.Implements(thread_type) { - valid_threads[go_t] = gql_t - } - type_list = append(type_list, gql_t) - } - - queries := graphql.Fields{ - "Self": GQLQuerySelf(), - } - - for key, val := range(extended_queries) { - queries[key] = val - } - - subscriptions := graphql.Fields{ - "Update": GQLSubscriptionUpdate(), - "Self": GQLSubscriptionSelf(), - } - - for key, val := range(extended_subscriptions) { - subscriptions[key] = val - } + ValidNodes ObjTypeMap + ValidLockables ObjTypeMap + ValidThreads ObjTypeMap - mutations := graphql.Fields{ - "SendUpdate": GQLMutationSendUpdate(), - } + Query *graphql.Object + Mutation *graphql.Object + Subscription *graphql.Object +} - for key, val := range(extended_mutations) { - mutations[key] = val - } +func NewGQLContext() GQLContext { + query := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{}, + }) - schemaConfig := graphql.SchemaConfig{ - Types: type_list, - Query: graphql.NewObject(graphql.ObjectConfig{ - Name: "Query", - Fields: queries, - }), - Mutation: graphql.NewObject(graphql.ObjectConfig{ - Name: "Mutation", - Fields: mutations, - }), - Subscription: graphql.NewObject(graphql.ObjectConfig{ - Name: "Subscription", - Fields: subscriptions, - }), - } + mutation := graphql.NewObject(graphql.ObjectConfig{ + Name: "Mutation", + Fields: graphql.Fields{}, + }) - schema, err := graphql.NewSchema(schemaConfig) - if err != nil{ - return nil, err - } + subscription := graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{}, + }) ctx := GQLContext{ - Schema: schema, - ValidNodes: valid_nodes, - NodeType: node_type, - ValidThreads: valid_threads, - ThreadType: thread_type, - ValidLockables: valid_lockables, - LockableType: lockable_type, + Schema: graphql.Schema{}, + TypeList: TypeList{}, + ValidNodes: ObjTypeMap{}, + NodeType: reflect.TypeOf((*Node)(nil)).Elem(), + ValidThreads: ObjTypeMap{}, + ThreadType: reflect.TypeOf((*Thread)(nil)).Elem(), + ValidLockables: ObjTypeMap{}, + LockableType: reflect.TypeOf((*Lockable)(nil)).Elem(), + Query: query, + Mutation: mutation, + Subscription: subscription, } - return &ctx, nil + return ctx } -func NewContext(db * badger.DB, log Logger, extra_nodes map[string]NodeLoadFunc, types TypeList, type_map ObjTypeMap, queries FieldMap, subscriptions FieldMap, mutations FieldMap) * Context { - gql, err := NewGQLContext(types, type_map, queries, subscriptions, mutations) - if err != nil { - panic(err) - } - +func NewContext(db * badger.DB, log Logger) * Context { ctx := &Context{ - GQL: gql, + GQL: NewGQLContext(), DB: db, Log: log, Types: map[uint64]NodeDef{}, } - - - err = ctx.RegisterNodeType("graph_node", LoadGraphNode) + err := ctx.RegisterNodeType(NewNodeDef("graph_node", reflect.TypeOf((*GraphNode)(nil)), LoadGraphNode, GQLTypeGraphNode())) if err != nil { panic(err) } - err = ctx.RegisterNodeType("simple_lockable", LoadSimpleLockable) + err = ctx.RegisterNodeType(NewNodeDef("simple_lockable", reflect.TypeOf((*SimpleLockable)(nil)), LoadSimpleLockable, GQLTypeSimpleLockable())) if err != nil { panic(err) } - err = ctx.RegisterNodeType("simple_thread", LoadSimpleThread) + err = ctx.RegisterNodeType(NewNodeDef("simple_thread", reflect.TypeOf((*SimpleThread)(nil)), LoadSimpleThread, GQLTypeSimpleThread())) if err != nil { panic(err) } - err = ctx.RegisterNodeType("gql_thread", LoadGQLThread) + err = ctx.RegisterNodeType(NewNodeDef("gql_thread", reflect.TypeOf((*GQLThread)(nil)), LoadGQLThread, GQLTypeGQLThread())) if err != nil { panic(err) } - for name, load_fn := range(extra_nodes) { - err := ctx.RegisterNodeType(name, load_fn) - if err != nil { - panic(err) - } + ctx.AddGQLType(GQLTypeSignal()) + + ctx.GQL.Query.AddFieldConfig("Self", GQLQuerySelf()) + + ctx.GQL.Subscription.AddFieldConfig("Update", GQLSubscriptionUpdate()) + ctx.GQL.Subscription.AddFieldConfig("Self", GQLSubscriptionSelf()) + + ctx.GQL.Mutation.AddFieldConfig("SendUpdate", GQLMutationSendUpdate()) + + err = ctx.RebuildSchema() + if err != nil { + panic(err) } return ctx diff --git a/gql_graph.go b/gql_graph.go index 7df4d2f..c1d941a 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -28,7 +28,7 @@ func GQLInterfaceNode() *graphql.Interface { } if p_type.Implements(node_type) { - return GQLTypeBaseNode() + return GQLTypeGraphNode() } return nil @@ -74,7 +74,7 @@ func GQLInterfaceThread() *graphql.Interface { } if p_type.Implements(thread_type) { - return GQLTypeBaseThread() + return GQLTypeSimpleThread() } ctx.Log.Logf("gql", "Found no type that matches %+v: %+v", p_type, p_type.Implements(thread_type)) @@ -145,7 +145,7 @@ func GQLInterfaceLockable() *graphql.Interface { } if p_type.Implements(lockable_type) { - return GQLTypeBaseLockable() + return GQLTypeSimpleLockable() } return nil }, @@ -418,10 +418,10 @@ func GQLTypeGQLThread() * graphql.Object { return gql_type_gql_thread } -var gql_type_base_thread *graphql.Object = nil -func GQLTypeBaseThread() * graphql.Object { - if gql_type_base_thread == nil { - gql_type_base_thread = graphql.NewObject(graphql.ObjectConfig{ +var gql_type_simple_thread *graphql.Object = nil +func GQLTypeSimpleThread() * graphql.Object { + if gql_type_simple_thread == nil { + gql_type_simple_thread = graphql.NewObject(graphql.ObjectConfig{ Name: "BaseThread", Interfaces: []*graphql.Interface{ GQLInterfaceNode(), @@ -446,48 +446,48 @@ func GQLTypeBaseThread() * graphql.Object { }, Fields: graphql.Fields{}, }) - gql_type_base_thread.AddFieldConfig("ID", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("ID", &graphql.Field{ Type: graphql.String, Resolve: GQLNodeID, }) - gql_type_base_thread.AddFieldConfig("Name", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("Name", &graphql.Field{ Type: graphql.String, Resolve: GQLLockableName, }) - gql_type_base_thread.AddFieldConfig("Children", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("Children", &graphql.Field{ Type: GQLListThread(), Resolve: GQLThreadChildren, }) - gql_type_base_thread.AddFieldConfig("Parent", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("Parent", &graphql.Field{ Type: GQLInterfaceThread(), Resolve: GQLThreadParent, }) - gql_type_base_thread.AddFieldConfig("Requirements", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("Requirements", &graphql.Field{ Type: GQLListLockable(), Resolve: GQLLockableRequirements, }) - gql_type_base_thread.AddFieldConfig("Owner", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("Owner", &graphql.Field{ Type: GQLInterfaceLockable(), Resolve: GQLLockableOwner, }) - gql_type_base_thread.AddFieldConfig("Dependencies", &graphql.Field{ + gql_type_simple_thread.AddFieldConfig("Dependencies", &graphql.Field{ Type: GQLListLockable(), Resolve: GQLLockableDependencies, }) } - return gql_type_base_thread + return gql_type_simple_thread } -var gql_type_base_lockable *graphql.Object = nil -func GQLTypeBaseLockable() * graphql.Object { - if gql_type_base_lockable == nil { - gql_type_base_lockable = graphql.NewObject(graphql.ObjectConfig{ +var gql_type_simple_lockable *graphql.Object = nil +func GQLTypeSimpleLockable() * graphql.Object { + if gql_type_simple_lockable == nil { + gql_type_simple_lockable = graphql.NewObject(graphql.ObjectConfig{ Name: "BaseLockable", Interfaces: []*graphql.Interface{ GQLInterfaceNode(), @@ -511,38 +511,38 @@ func GQLTypeBaseLockable() * graphql.Object { Fields: graphql.Fields{}, }) - gql_type_base_lockable.AddFieldConfig("ID", &graphql.Field{ + gql_type_simple_lockable.AddFieldConfig("ID", &graphql.Field{ Type: graphql.String, Resolve: GQLNodeID, }) - gql_type_base_lockable.AddFieldConfig("Name", &graphql.Field{ + gql_type_simple_lockable.AddFieldConfig("Name", &graphql.Field{ Type: graphql.String, Resolve: GQLLockableName, }) - gql_type_base_lockable.AddFieldConfig("Requirements", &graphql.Field{ + gql_type_simple_lockable.AddFieldConfig("Requirements", &graphql.Field{ Type: GQLListLockable(), Resolve: GQLLockableRequirements, }) - gql_type_base_lockable.AddFieldConfig("Owner", &graphql.Field{ + gql_type_simple_lockable.AddFieldConfig("Owner", &graphql.Field{ Type: GQLInterfaceLockable(), Resolve: GQLLockableOwner, }) - gql_type_base_lockable.AddFieldConfig("Dependencies", &graphql.Field{ + gql_type_simple_lockable.AddFieldConfig("Dependencies", &graphql.Field{ Type: GQLListLockable(), Resolve: GQLLockableDependencies, }) } - return gql_type_base_lockable + return gql_type_simple_lockable } -var gql_type_base_node *graphql.Object = nil -func GQLTypeBaseNode() * graphql.Object { - if gql_type_base_node == nil { - gql_type_base_node = graphql.NewObject(graphql.ObjectConfig{ +var gql_type_simple_node *graphql.Object = nil +func GQLTypeGraphNode() * graphql.Object { + if gql_type_simple_node == nil { + gql_type_simple_node = graphql.NewObject(graphql.ObjectConfig{ Name: "BaseNode", Interfaces: []*graphql.Interface{ GQLInterfaceNode(), @@ -565,18 +565,18 @@ func GQLTypeBaseNode() * graphql.Object { Fields: graphql.Fields{}, }) - gql_type_base_node.AddFieldConfig("ID", &graphql.Field{ + gql_type_simple_node.AddFieldConfig("ID", &graphql.Field{ Type: graphql.String, Resolve: GQLNodeID, }) - gql_type_base_node.AddFieldConfig("Name", &graphql.Field{ + gql_type_simple_node.AddFieldConfig("Name", &graphql.Field{ Type: graphql.String, Resolve: GQLLockableName, }) } - return gql_type_base_node + return gql_type_simple_node } func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { diff --git a/graph_test.go b/graph_test.go index 558aef0..54b76dd 100644 --- a/graph_test.go +++ b/graph_test.go @@ -58,7 +58,7 @@ func logTestContext(t * testing.T, components []string) * Context { t.Fatal(err) } - return NewContext(db, NewConsoleLogger(components), map[string]NodeLoadFunc{}, TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) + return NewContext(db, NewConsoleLogger(components)) } func testContext(t * testing.T) * Context { @@ -67,7 +67,7 @@ func testContext(t * testing.T) * Context { t.Fatal(err) } - return NewContext(db, NewConsoleLogger([]string{}), map[string]NodeLoadFunc{}, TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) + return NewContext(db, NewConsoleLogger([]string{})) } func fatalErr(t * testing.T, err error) {