diff --git a/gql.go b/gql.go index d8fe0ae..c3dba7d 100644 --- a/gql.go +++ b/gql.go @@ -14,6 +14,7 @@ import ( "sync" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" + "strings" ) func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { @@ -112,6 +113,36 @@ func enableCORS(w *http.ResponseWriter) { (*w).Header().Set("Access-Control-Allow-Methods", "*") } +type GQLUnauthorized string + +func (e GQLUnauthorized) Is(target error) bool { + error_type := reflect.TypeOf(GQLUnauthorized("")) + target_type := reflect.TypeOf(target) + return error_type == target_type +} + +func (e GQLUnauthorized) Error() string { + return fmt.Sprintf("GQL_UNAUTHORIZED_ERROR: %s", string(e)) +} + +func checkForAuthHeader(header http.Header) (string, bool) { + auths, ok := header["Authorization"] + if ok == false { + return "", false + } + for _, auth := range(auths) { + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 { + continue + } + if parts[0] == "TM" { + return parts[1], true + } + } + + return "", false +} + func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { gql_ctx := context.Background() gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) @@ -125,18 +156,25 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht header_map[header] = value } ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") + auth, ok := checkForAuthHeader(r.Header) + if ok == false { + ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no auth header included in request header") + return + } str, err := io.ReadAll(r.Body) if err != nil { - ctx.Log.Logf("gql", "failed to read request body: %s", err) + ctx.Log.Logf("gql", "GQL_REQUEST_ERR: failed to read request body: %s", err) return - } + } query := GQLWSPayload{} json.Unmarshal(str, &query) + req_ctx := context.WithValue(gql_ctx, "auth", auth) + params := graphql.Params{ Schema: ctx.GQL.Schema, - Context: gql_ctx, + Context: req_ctx, RequestString: query.Query, } if query.OperationName != "" {