diff --git a/gql.go b/gql.go index 6fd4c74..319dc97 100644 --- a/gql.go +++ b/gql.go @@ -351,7 +351,13 @@ func checkForAuthHeader(header http.Header) (string, bool) { return "", false } -func CheckAuth(server *GQLThread, r *http.Request) (*User, error) { +type ResolveContext struct { + Context *Context + Server *GQLThread + User *User +} + +func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*ResolveContext, error) { username, password, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") @@ -371,13 +377,15 @@ func CheckAuth(server *GQLThread, r *http.Request) (*User, error) { return nil, fmt.Errorf("GQL_AUTH_FAIL") } - return user, nil + return &ResolveContext{ + Context: ctx, + Server: server, + User: user, + }, nil } func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { gql_ctx := context.Background() - gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) - gql_ctx = context.WithValue(gql_ctx, "gql_server", server) return func(w http.ResponseWriter, r * http.Request) { ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) @@ -388,13 +396,15 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - user, err := CheckAuth(server, r) + resolve_context, err := NewResolveContext(ctx, server, r) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) return } - req_ctx := context.WithValue(gql_ctx, "user", user) + + req_ctx := context.Background() + req_ctx = context.WithValue(gql_ctx, "resolve", resolve_context) str, err := io.ReadAll(r.Body) if err != nil { @@ -488,12 +498,13 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, * } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") - user, err := CheckAuth(server, r) + resolve_context, err := NewResolveContext(ctx, server, r) if err != nil { ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) return } - req_ctx := context.WithValue(gql_ctx, "user", user) + req_ctx := context.Background() + req_ctx = context.WithValue(req_ctx, "resolve", resolve_context) u := ws.HTTPUpgrader{ Protocol: func(protocol string) bool { diff --git a/gql_mutation.go b/gql_mutation.go index ed105b7..907d339 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -16,12 +16,12 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field { }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, server, user, err := PrepResolve(p) + ctx, err := GetResolveContext(p) if err != nil { return nil, err } - err = server.Allowed("signal", "self", user) + err = ctx.Server.Allowed("signal", "self", ctx.User) if err != nil { return nil, err } @@ -33,11 +33,11 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field { var signal GraphSignal = nil if signal_map["Direction"] == "up" { - signal = NewSignal(server, signal_map["Type"].(string)) + signal = NewSignal(ctx.Server, signal_map["Type"].(string)) } else if signal_map["Direction"] == "down" { - signal = NewDownSignal(server, signal_map["Type"].(string)) + signal = NewDownSignal(ctx.Server, signal_map["Type"].(string)) } else if signal_map["Direction"] == "direct" { - signal = NewDirectSignal(server, signal_map["Type"].(string)) + signal = NewDirectSignal(ctx.Server, signal_map["Type"].(string)) } else { return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) } @@ -48,12 +48,12 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field { } var node Node = nil - err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ - node = FindChild(ctx, server, id, nodes) + err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){ + node = FindChild(ctx.Context, ctx.Server, id, nodes) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - node.Signal(ctx, signal, nodes) + node.Signal(ctx.Context, signal, nodes) return nil }) if err != nil { @@ -83,12 +83,12 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, server, user, err := PrepResolve(p) + ctx, err := GetResolveContext(p) if err != nil { return nil, err } - err = server.Allowed("start_child", "self", user) + err = ctx.Server.Allowed("start_child", "self", ctx.User) if err != nil { return nil, err } @@ -109,14 +109,14 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ } var signal GraphSignal - err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ - node := FindChild(ctx, server, parent_id, nodes) + err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){ + node := FindChild(ctx.Context, ctx.Server, parent_id, nodes) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) } - return UseMoreStates(ctx, []Node{node}, nodes, func(NodeMap) error { - signal = NewStartChildSignal(server, child_id, action) - return node.Signal(ctx, signal, nodes) + return UseMoreStates(ctx.Context, []Node{node}, nodes, func(NodeMap) error { + signal = NewStartChildSignal(ctx.Server, child_id, action) + return node.Signal(ctx.Context, signal, nodes) }) }) if err != nil { diff --git a/gql_query.go b/gql_query.go index 2225357..8406efb 100644 --- a/gql_query.go +++ b/gql_query.go @@ -6,33 +6,33 @@ import ( var GQLQuerySelf = &graphql.Field{ Type: GQLTypeGQLThread.Type, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, server, user, err := PrepResolve(p) + ctx, err := GetResolveContext(p) if err != nil { return nil, err } - err = server.Allowed("enumerate", "self", user) + err = ctx.Server.Allowed("enumerate", "self", ctx.User) if err != nil { return nil, err } - return server, nil + return ctx.Server, nil }, } var GQLQueryUser = &graphql.Field{ Type: GQLTypeUser.Type, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, _, user, err := PrepResolve(p) + ctx, err := GetResolveContext(p) if err != nil { return nil, err } - err = user.Allowed("enumerate", "self", user) + err = ctx.User.Allowed("enumerate", "self", ctx.User) if err != nil { return nil, err } - return user, nil + return ctx.User, nil }, } diff --git a/gql_resolvers.go b/gql_resolvers.go index 674866e..c6ce781 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -5,23 +5,12 @@ import ( "github.com/graphql-go/graphql" ) -func PrepResolve(p graphql.ResolveParams) (*Context, *GQLThread, *User, error) { - context, ok := p.Context.Value("graph_context").(*Context) +func GetResolveContext(p graphql.ResolveParams) (*ResolveContext, error) { + resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) if ok == false { - return nil, nil, nil, fmt.Errorf("failed to cast graph_context to *Context") + return nil, fmt.Errorf("Bad resolve in params context") } - - server, ok := p.Context.Value("gql_server").(*GQLThread) - if ok == false { - return nil, nil, nil, fmt.Errorf("failed to cast gql_server to *GQLThread") - } - - user, ok := p.Context.Value("user").(*User) - if ok == false { - return nil, nil, nil, fmt.Errorf("failed to cast user to *User") - } - - return context, server, user, nil + return resolve_context, nil } func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {