From 0428645be3307840a9793f701603e48eb10b0e58 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 21 Jul 2023 14:28:53 -0600 Subject: [PATCH] Added ExtractParam and ExtractID --- gql.go | 17 ++++++++ gql_graph.go | 115 +++++++++++++++++++++++++++++---------------------- 2 files changed, 83 insertions(+), 49 deletions(-) diff --git a/gql.go b/gql.go index 9673bd6..6fd4c74 100644 --- a/gql.go +++ b/gql.go @@ -879,6 +879,23 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ }) return "wait", nil }, + "start_child": func(ctx *Context, thread Thread, signal GraphSignal) (string, error) { + ctx.Log.Logf("gql", "GQL_START_CHILD") + sig, ok := signal.(StartChildSignal) + if ok == false { + ctx.Log.Logf("gql", "GQL_START_CHILD_BAD_SIGNAL: %+v", signal) + return "wait", nil + } + + err := ThreadStartChild(ctx, thread, sig) + if err != nil { + ctx.Log.Logf("gql", "GQL_START_CHILD_ERR: %s", err) + } else { + ctx.Log.Logf("gql", "GQL_START_CHILD: %s", sig.ChildID.String()) + } + + return "wait", nil + }, "abort": func(ctx * Context, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_ABORT") server := thread.(*GQLThread) diff --git a/gql_graph.go b/gql_graph.go index 1703be8..d97530d 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -880,6 +880,54 @@ func GQLSubscriptionUpdate() * graphql.Field { return gql_subscription_update } +func PrepResolve(p graphql.ResolveParams) (*Context, *GQLThread, *User, error) { + context, ok := p.Context.Value("graph_context").(*Context) + if ok == false { + return nil, nil, nil, fmt.Errorf("failed to cast graph_context to *Context") + } + + server, ok := p.Context.Value("gql_server").(*GQLThread) + if ok == false { + return nil, nil, nil, fmt.Errorf("failed to cast gql_server to *GQLThread") + } + + user, ok := p.Context.Value("user").(*User) + if ok == false { + return nil, nil, nil, fmt.Errorf("failed to cast user to *User") + } + + return context, server, user, nil +} + +func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) { + var zero K + arg_if, ok := p.Args[name] + if ok == false { + return zero, fmt.Errorf("No Arg of name %s", name) + } + + arg, ok := arg_if.(K) + if ok == false { + return zero, fmt.Errorf("Failed to cast arg %s(%+v) to %+v", name, arg_if, reflect.TypeOf(zero)) + } + + return arg, nil +} + +func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { + id_str, err := ExtractParam[string](p, name) + if err != nil { + return ZeroID, err + } + + id, err := ParseID(id_str) + if err != nil { + return ZeroID, err + } + + return id, nil +} + var gql_mutation_send_update *graphql.Field = nil func GQLMutationSendUpdate() *graphql.Field { if gql_mutation_send_update == nil { @@ -894,7 +942,7 @@ func GQLMutationSendUpdate() *graphql.Field { }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, server, user, err := GQLPrepResolve(p) + ctx, server, user, err := PrepResolve(p) if err != nil { return nil, err } @@ -904,10 +952,11 @@ func GQLMutationSendUpdate() *graphql.Field { return nil, err } - signal_map, ok := p.Args["signal"].(map[string]interface{}) - if ok == false { - return nil, fmt.Errorf("Failed to cast arg signal to GraphSignal: %+v", p.Args["signal"]) + signal_map, err := ExtractParam[map[string]interface{}](p, "signal") + if err != nil { + return nil, err } + var signal GraphSignal = nil if signal_map["Direction"] == "up" { signal = NewSignal(server, signal_map["Type"].(string)) @@ -919,12 +968,7 @@ func GQLMutationSendUpdate() *graphql.Field { return nil, fmt.Errorf("Bad direction: %d", signal_map["Direction"]) } - id_str, ok := p.Args["id"].(string) - if ok == false { - return nil, fmt.Errorf("Failed to cast arg id to string") - } - - id, err := ParseID(id_str) + id, err := ExtractID(p, "id") if err != nil { return nil, err } @@ -968,37 +1012,29 @@ func GQLMutationStartChild() *graphql.Field { }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - server, ok := p.Context.Value("gql_server").(*GQLThread) - if ok == false { - return nil, fmt.Errorf("Failed to cast context gql_server to GQLServer: %+v", p.Context.Value("gql_server")) + ctx, server, user, err := PrepResolve(p) + if err != nil { + return nil, err } - ctx, ok := p.Context.Value("graph_context").(*Context) - if ok == false { - return nil, fmt.Errorf("Failed to cast context graph_context to Context: %+v", p.Context.Value("graph_context")) + err = server.Allowed("start_child", "self", user) + if err != nil { + return nil, err } - parent_str, ok := p.Args["parent_id"].(string) - if ok == false { - return nil, fmt.Errorf("Failed to cast arg parent_id to string: %+v", p.Args["parent_id"]) - } - parent_id, err := ParseID(parent_str) + parent_id, err := ExtractID(p, "parent_id") if err != nil { return nil, err } - child_str, ok := p.Args["child_id"].(string) - if ok == false { - return nil, fmt.Errorf("Failed to cast arg child_id to string: %+v", p.Args["child_id"]) - } - child_id, err := ParseID(child_str) + child_id, err := ExtractID(p, "child_id") if err != nil { return nil, err } - action, ok := p.Args["action"].(string) - if ok == false { - return nil, fmt.Errorf("Failed to cast arg action to string: %+v", p.Args["action"]) + action, err := ExtractParam[string](p, "action") + if err != nil { + return nil, err } var signal GraphSignal @@ -1025,32 +1061,13 @@ func GQLMutationStartChild() *graphql.Field { return gql_mutation_start_child } -func GQLPrepResolve(p graphql.ResolveParams) (*Context, *GQLThread, *User, error) { - context, ok := p.Context.Value("graph_context").(*Context) - if ok == false { - return nil, nil, nil, fmt.Errorf("failed to cast graph_context to *Context") - } - - server, ok := p.Context.Value("gql_server").(*GQLThread) - if ok == false { - return nil, nil, nil, fmt.Errorf("failed to cast gql_server to *GQLThread") - } - - user, ok := p.Context.Value("user").(*User) - if ok == false { - return nil, nil, nil, fmt.Errorf("failed to cast user to *User") - } - - return context, server, user, nil -} - var gql_query_self *graphql.Field = nil func GQLQuerySelf() *graphql.Field { if gql_query_self == nil { gql_query_self = &graphql.Field{ Type: GQLTypeGQLThread(), Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, server, user, err := GQLPrepResolve(p) + _, server, user, err := PrepResolve(p) if err != nil { return nil, err