diff --git a/gql.go b/gql.go index 48117c4..78e2055 100644 --- a/gql.go +++ b/gql.go @@ -534,8 +534,11 @@ func MakeGQLHandlers(ctx * GraphContext, server * GQLThread) (func(http.Response } gql_ctx := context.Background() gql_ctx = context.WithValue(gql_ctx, "valid_nodes", valid_nodes) + gql_ctx = context.WithValue(gql_ctx, "node_type", node_type) gql_ctx = context.WithValue(gql_ctx, "valid_lockables", valid_lockables) + gql_ctx = context.WithValue(gql_ctx, "lockable_type", lockable_type) gql_ctx = context.WithValue(gql_ctx, "valid_threads", valid_threads) + gql_ctx = context.WithValue(gql_ctx, "thread_type", thread_type) gql_ctx = context.WithValue(gql_ctx, "gql_server", server) gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) return GQLHandler(ctx, schema, gql_ctx), GQLWSHandler(ctx, schema, gql_ctx) diff --git a/gql_graph.go b/gql_graph.go index 9e3ea79..4cfadf7 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -17,11 +17,23 @@ func GQLInterfaceGraphNode() *graphql.Interface { return nil } + node_type, ok := p.Context.Value("node_type").(reflect.Type) + if ok == false { + return nil + } + + p_type := reflect.TypeOf(p.Value) + for key, value := range(valid_nodes) { - if reflect.TypeOf(p.Value) == key { + if p_type == key { return value } } + + if p_type.Implements(node_type) { + return GQLTypeBaseNode() + } + return nil }, Fields: graphql.Fields{}, @@ -53,16 +65,29 @@ func GQLInterfaceThread() *graphql.Interface { gql_interface_thread = graphql.NewInterface(graphql.InterfaceConfig{ Name: "Thread", ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - valid_nodes, ok := p.Context.Value("valid_threads").(map[reflect.Type]*graphql.Object) + valid_threads, ok := p.Context.Value("valid_threads").(map[reflect.Type]*graphql.Object) if ok == false { return nil } - for key, value := range(valid_nodes) { - if reflect.TypeOf(p.Value) == key { + thread_type, ok := p.Context.Value("thread_type").(reflect.Type) + if ok == false { + return nil + } + + p_type := reflect.TypeOf(p.Value) + + + for key, value := range(valid_threads) { + if p_type == key { return value } } + + if p_type.Implements(thread_type) { + return GQLTypeBaseThread() + } + return nil }, Fields: graphql.Fields{}, @@ -102,16 +127,27 @@ func GQLInterfaceLockable() *graphql.Interface { gql_interface_lockable = graphql.NewInterface(graphql.InterfaceConfig{ Name: "Lockable", ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { - valid_nodes, ok := p.Context.Value("valid_lockables").(map[reflect.Type]*graphql.Object) + valid_lockables, ok := p.Context.Value("valid_lockables").(map[reflect.Type]*graphql.Object) if ok == false { return nil } - for key, value := range(valid_nodes) { - if reflect.TypeOf(p.Value) == key { + lockable_type, ok := p.Context.Value("lockable_type").(reflect.Type) + if ok == false { + return nil + } + + p_type := reflect.TypeOf(p.Value) + + for key, value := range(valid_lockables) { + if p_type == key { return value } } + + if p_type.Implements(lockable_type) { + return GQLTypeBaseLockable() + } return nil }, Fields: graphql.Fields{},