Added GetResolveContext

graph-rework-2
noah metz 2023-07-21 18:51:42 -06:00
parent 6d6effadec
commit 97815c86ff
4 changed files with 44 additions and 44 deletions

@ -351,7 +351,13 @@ func checkForAuthHeader(header http.Header) (string, bool) {
return "", false return "", false
} }
func CheckAuth(server *GQLThread, r *http.Request) (*User, error) { type ResolveContext struct {
Context *Context
Server *GQLThread
User *User
}
func NewResolveContext(ctx *Context, server *GQLThread, r *http.Request) (*ResolveContext, error) {
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if ok == false { if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
@ -371,13 +377,15 @@ func CheckAuth(server *GQLThread, r *http.Request) (*User, error) {
return nil, fmt.Errorf("GQL_AUTH_FAIL") return nil, fmt.Errorf("GQL_AUTH_FAIL")
} }
return user, nil return &ResolveContext{
Context: ctx,
Server: server,
User: user,
}, nil
} }
func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background() gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
gql_ctx = context.WithValue(gql_ctx, "gql_server", server)
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
@ -388,13 +396,15 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
user, err := CheckAuth(server, r) resolve_context, err := NewResolveContext(ctx, server, r)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err))) json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err)))
return return
} }
req_ctx := context.WithValue(gql_ctx, "user", user)
req_ctx := context.Background()
req_ctx = context.WithValue(gql_ctx, "resolve", resolve_context)
str, err := io.ReadAll(r.Body) str, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@ -488,12 +498,13 @@ func GQLWSHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
user, err := CheckAuth(server, r) resolve_context, err := NewResolveContext(ctx, server, r)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return return
} }
req_ctx := context.WithValue(gql_ctx, "user", user) req_ctx := context.Background()
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
u := ws.HTTPUpgrader{ u := ws.HTTPUpgrader{
Protocol: func(protocol string) bool { Protocol: func(protocol string) bool {

@ -16,12 +16,12 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field {
}, },
}, },
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, server, user, err := PrepResolve(p) ctx, err := GetResolveContext(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = server.Allowed("signal", "self", user) err = ctx.Server.Allowed("signal", "self", ctx.User)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -33,11 +33,11 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field {
var signal GraphSignal = nil var signal GraphSignal = nil
if signal_map["Direction"] == "up" { if signal_map["Direction"] == "up" {
signal = NewSignal(server, signal_map["Type"].(string)) signal = NewSignal(ctx.Server, signal_map["Type"].(string))
} else if signal_map["Direction"] == "down" { } else if signal_map["Direction"] == "down" {
signal = NewDownSignal(server, signal_map["Type"].(string)) signal = NewDownSignal(ctx.Server, signal_map["Type"].(string))
} else if signal_map["Direction"] == "direct" { } else if signal_map["Direction"] == "direct" {
signal = NewDirectSignal(server, signal_map["Type"].(string)) signal = NewDirectSignal(ctx.Server, signal_map["Type"].(string))
} else { } else {
return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"])
} }
@ -48,12 +48,12 @@ var GQLMutationSendUpdate = NewField(func()*graphql.Field {
} }
var node Node = nil var node Node = nil
err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){
node = FindChild(ctx, server, id, nodes) node = FindChild(ctx.Context, ctx.Server, id, nodes)
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", id) return fmt.Errorf("Failed to find ID: %s as child of server thread", id)
} }
node.Signal(ctx, signal, nodes) node.Signal(ctx.Context, signal, nodes)
return nil return nil
}) })
if err != nil { if err != nil {
@ -83,12 +83,12 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
}, },
}, },
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, server, user, err := PrepResolve(p) ctx, err := GetResolveContext(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = server.Allowed("start_child", "self", user) err = ctx.Server.Allowed("start_child", "self", ctx.User)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -109,14 +109,14 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{
} }
var signal GraphSignal var signal GraphSignal
err = UseStates(ctx, []Node{server}, func(nodes NodeMap) (error){ err = UseStates(ctx.Context, []Node{ctx.Server}, func(nodes NodeMap) (error){
node := FindChild(ctx, server, parent_id, nodes) node := FindChild(ctx.Context, ctx.Server, parent_id, nodes)
if node == nil { if node == nil {
return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id) return fmt.Errorf("Failed to find ID: %s as child of server thread", parent_id)
} }
return UseMoreStates(ctx, []Node{node}, nodes, func(NodeMap) error { return UseMoreStates(ctx.Context, []Node{node}, nodes, func(NodeMap) error {
signal = NewStartChildSignal(server, child_id, action) signal = NewStartChildSignal(ctx.Server, child_id, action)
return node.Signal(ctx, signal, nodes) return node.Signal(ctx.Context, signal, nodes)
}) })
}) })
if err != nil { if err != nil {

@ -6,33 +6,33 @@ import (
var GQLQuerySelf = &graphql.Field{ var GQLQuerySelf = &graphql.Field{
Type: GQLTypeGQLThread.Type, Type: GQLTypeGQLThread.Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, server, user, err := PrepResolve(p) ctx, err := GetResolveContext(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = server.Allowed("enumerate", "self", user) err = ctx.Server.Allowed("enumerate", "self", ctx.User)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return server, nil return ctx.Server, nil
}, },
} }
var GQLQueryUser = &graphql.Field{ var GQLQueryUser = &graphql.Field{
Type: GQLTypeUser.Type, Type: GQLTypeUser.Type,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
_, _, user, err := PrepResolve(p) ctx, err := GetResolveContext(p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = user.Allowed("enumerate", "self", user) err = ctx.User.Allowed("enumerate", "self", ctx.User)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return user, nil return ctx.User, nil
}, },
} }

@ -5,23 +5,12 @@ import (
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
) )
func PrepResolve(p graphql.ResolveParams) (*Context, *GQLThread, *User, error) { func GetResolveContext(p graphql.ResolveParams) (*ResolveContext, error) {
context, ok := p.Context.Value("graph_context").(*Context) resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false { if ok == false {
return nil, nil, nil, fmt.Errorf("failed to cast graph_context to *Context") return nil, fmt.Errorf("Bad resolve in params context")
} }
return resolve_context, nil
server, ok := p.Context.Value("gql_server").(*GQLThread)
if ok == false {
return nil, nil, nil, fmt.Errorf("failed to cast gql_server to *GQLThread")
}
user, ok := p.Context.Value("user").(*User)
if ok == false {
return nil, nil, nil, fmt.Errorf("failed to cast user to *User")
}
return context, server, user, nil
} }
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) { func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {