From cf11176cffb00e679be7cf453a488deff07d44ed Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Tue, 27 Jun 2023 13:07:23 -0600 Subject: [PATCH] Fix gql base types --- gql_graph.go | 51 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/gql_graph.go b/gql_graph.go index 56abade..c21fc7c 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -496,16 +496,23 @@ func GQLTypeBaseThread() * graphql.Object { GQLInterfaceLockable(), }, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - valid_threads, ok := p.Context.Value("valid_threads").(map[reflect.Type]*graphql.Object) + ctx, ok := p.Context.Value("graph_context").(*GraphContext) if ok == false { return false } + + thread_type, ok := p.Context.Value("thread_type").(*reflect.Type) + if ok == false { + ctx.Log.Logf("gql", "Failed to get thread_type from Context: %+v", p.Context.Value("thread_type")) + return false + } + value_type := reflect.TypeOf(p.Value) - for go_type, _ := range(valid_threads) { - if value_type == go_type { - return true - } + + if value_type.Implements(*thread_type) { + return true } + return false }, Fields: graphql.Fields{}, @@ -558,16 +565,23 @@ func GQLTypeBaseLockable() * graphql.Object { GQLInterfaceLockable(), }, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - valid_lockables, ok := p.Context.Value("valid_lockables").(map[reflect.Type]*graphql.Object) + ctx, ok := p.Context.Value("graph_context").(*GraphContext) + if ok == false { + return false + } + + lockable_type, ok := p.Context.Value("lockable_type").(*reflect.Type) if ok == false { + ctx.Log.Logf("gql", "Failed to get lockable_type from Context: %+v", p.Context.Value("lockable_type")) return false } + value_type := reflect.TypeOf(p.Value) - for go_type, _ := range(valid_lockables) { - if value_type == go_type { - return true - } + + if value_type.Implements(*lockable_type) { + return true } + return false }, Fields: graphql.Fields{}, @@ -611,16 +625,23 @@ func GQLTypeBaseNode() * graphql.Object { GQLInterfaceGraphNode(), }, IsTypeOf: func(p graphql.IsTypeOfParams) bool { - valid_nodes, ok := p.Context.Value("valid_nodes").(map[reflect.Type]*graphql.Object) + ctx, ok := p.Context.Value("graph_context").(*GraphContext) + if ok == false { + return false + } + + node_type, ok := p.Context.Value("node_type").(*reflect.Type) if ok == false { + ctx.Log.Logf("gql", "Failed to get node_type from Context: %+v", p.Context.Value("node_type")) return false } + value_type := reflect.TypeOf(p.Value) - for go_type, _ := range(valid_nodes) { - if value_type == go_type { - return true - } + + if value_type.Implements(*node_type) { + return true } + return false }, Fields: graphql.Fields{},