require auth for gql

graph-rework-2
noah metz 2023-07-13 18:21:33 -06:00
parent 5d23646cd5
commit 893fb8c4c4
1 changed files with 41 additions and 3 deletions

@ -14,6 +14,7 @@ import (
"sync" "sync"
"github.com/gobwas/ws" "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"strings"
) )
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
@ -112,6 +113,36 @@ func enableCORS(w *http.ResponseWriter) {
(*w).Header().Set("Access-Control-Allow-Methods", "*") (*w).Header().Set("Access-Control-Allow-Methods", "*")
} }
type GQLUnauthorized string
func (e GQLUnauthorized) Is(target error) bool {
error_type := reflect.TypeOf(GQLUnauthorized(""))
target_type := reflect.TypeOf(target)
return error_type == target_type
}
func (e GQLUnauthorized) Error() string {
return fmt.Sprintf("GQL_UNAUTHORIZED_ERROR: %s", string(e))
}
func checkForAuthHeader(header http.Header) (string, bool) {
auths, ok := header["Authorization"]
if ok == false {
return "", false
}
for _, auth := range(auths) {
parts := strings.SplitN(auth, " ", 2)
if len(parts) != 2 {
continue
}
if parts[0] == "TM" {
return parts[1], true
}
}
return "", false
}
func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) { func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background() gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
@ -125,18 +156,25 @@ func GQLHandler(ctx * Context, server * GQLThread) func(http.ResponseWriter, *ht
header_map[header] = value header_map[header] = value
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
auth, ok := checkForAuthHeader(r.Header)
if ok == false {
ctx.Log.Logf("gql", "GQL_REQUEST_ERR: no auth header included in request header")
return
}
str, err := io.ReadAll(r.Body) str, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "failed to read request body: %s", err) ctx.Log.Logf("gql", "GQL_REQUEST_ERR: failed to read request body: %s", err)
return return
} }
query := GQLWSPayload{} query := GQLWSPayload{}
json.Unmarshal(str, &query) json.Unmarshal(str, &query)
req_ctx := context.WithValue(gql_ctx, "auth", auth)
params := graphql.Params{ params := graphql.Params{
Schema: ctx.GQL.Schema, Schema: ctx.GQL.Schema,
Context: gql_ctx, Context: req_ctx,
RequestString: query.Query, RequestString: query.Query,
} }
if query.OperationName != "" { if query.OperationName != "" {