package graphvent

import (
	"context"
	"crypto/ecdh"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net"
	"net/http"
	"reflect"
	"strings"
	"sync"
	"time"

	"github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
	"github.com/graphql-go/graphql"
	"github.com/graphql-go/graphql/language/ast"
	"github.com/graphql-go/graphql/language/parser"
	"github.com/graphql-go/graphql/language/source"
)

// graphiql_page is the HTML document served by GraphiQLHandler.
//
// NOTE(review): in the reviewed copy of this file the page contains only the
// text below; if it is meant to load the GraphiQL IDE, confirm that its HTML
// and script markup has not been lost.
const graphiql_page = ` GraphiQL
Loading...
`

// GraphiQLHandler returns a handler that serves graphiql_page as UTF-8 HTML
// with a 200 status. The page is written verbatim (not through fmt.Sprintf)
// so that any '%' in the markup is preserved.
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusOK)
		io.WriteString(w, graphiql_page)
	}
}

// GQLPayload is the JSON body of a GraphQL request over HTTP and the payload
// of a GraphQL websocket message. Data carries a JSON-encoded result string
// in outgoing "next" messages.
type GQLPayload struct {
	OperationName string                 `json:"operationName,omitempty"`
	Query         string                 `json:"query,omitempty"`
	Variables     map[string]interface{} `json:"variables,omitempty"`
	Extensions    map[string]interface{} `json:"extensions,omitempty"`
	Data          string                 `json:"data,omitempty"`
}

// GQLWSMsg is a single GraphQL-over-websocket protocol message.
// Note that omitempty has no effect on the struct-typed Payload, so it is
// always encoded.
type GQLWSMsg struct {
	ID      string     `json:"id,omitempty"`
	Type    string     `json:"type"`
	Payload GQLPayload `json:"payload,omitempty"`
}

// enableCORS sets permissive CORS response headers on *w.
//
// NOTE(review): browsers reject credentialed requests when
// Access-Control-Allow-Origin is "*", so Allow-Credentials has no effect here;
// echo a specific origin if credentialed cross-origin calls are needed.
func enableCORS(w *http.ResponseWriter) {
	(*w).Header().Set("Access-Control-Allow-Origin", "*")
	(*w).Header().Set("Access-Control-Allow-Credentials", "true")
	(*w).Header().Set("Access-Control-Allow-Headers", "*")
	(*w).Header().Set("Access-Control-Allow-Methods", "*")
}

// GQLUnauthorized is the error reported to clients whose request could not
// be authenticated.
type GQLUnauthorized string

// Is reports whether target's dynamic type is exactly GQLUnauthorized, so
// errors.Is matches any GQLUnauthorized regardless of its message.
func (e GQLUnauthorized) Is(target error) bool {
	error_type := reflect.TypeOf(GQLUnauthorized(""))
	target_type := reflect.TypeOf(target)
	return error_type == target_type
}

// Error implements error, prefixing the message with GQL_UNAUTHORIZED_ERROR.
func (e GQLUnauthorized) Error() string {
	return fmt.Sprintf("GQL_UNAUTHORIZED_ERROR: %s", string(e))
}

// MarshalJSON encodes the error as {"error": "<message>"}.
func (e GQLUnauthorized) MarshalJSON() ([]byte, error) {
	return json.MarshalIndent(&struct {
		Error string `json:"error"`
	}{
		Error: string(e),
	}, "", " ")
}

// checkForAuthHeader returns the credential from the first Authorization
// header value of the form "TM <credential>", and whether one was found.
func checkForAuthHeader(header http.Header) (string, bool) {
	for _, auth := range header["Authorization"] {
		scheme, credential, found := strings.Cut(auth, " ")
		if found && scheme == "TM" {
			return credential, true
		}
	}
	return "", false
}

// ResolveContext is attached to every GraphQL execution (under the context
// key "resolve") and gives resolvers access to the node, extension and
// authenticated user handling the request.
type ResolveContext struct {
	Context    *Context
	GQLContext *GQLExtContext
	Server     *Node
	Ext        *GQLExt
	User       *Node
}

// NewResolveContext authenticates r via HTTP Basic auth, treating the
// username as a node ID that must be present in gql_ext.Users, and builds
// the ResolveContext for it.
//
// NOTE(review): only the Basic-auth username is inspected; the password is
// discarded, so any caller that knows an authorized ID is accepted. Confirm
// this is intended.
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
	username, _, ok := r.BasicAuth()
	if !ok {
		return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
	}

	auth_id, err := ParseID(username)
	if err != nil {
		return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s: %w", username, err)
	}

	user, exists := gql_ext.Users[auth_id]
	if !exists {
		return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id)
	}

	return &ResolveContext{
		Context:    ctx,
		GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
		Server:     server,
		Ext:        gql_ext,
		User:       user,
	}, nil
}

// logRequestHeaders logs every header of r under the "gql" component.
func logRequestHeaders(ctx *Context, r *http.Request) {
	header_map := make(map[string]interface{}, len(r.Header))
	for header, value := range r.Header {
		header_map[header] = value
	}
	ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
}

// GQLHandler returns the HTTP handler for single GraphQL requests.
//
// The request is authenticated with NewResolveContext; on failure a JSON
// GQLUnauthorized body is written (the status is left at 200). Otherwise the
// JSON GQLPayload body is executed against the extension's schema and the
// graphql.Result is written as JSON.
func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
		enableCORS(&w)
		w.Header().Set("Content-Type", "application/json")
		logRequestHeaders(ctx, r)

		resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
			json.NewEncoder(w).Encode(GQLUnauthorized(err.Error()))
			return
		}

		// Resolvers look this value up with the plain string key "resolve", so
		// the key type must stay a string.
		req_ctx := context.WithValue(context.Background(), "resolve", resolve_context)

		str, err := io.ReadAll(r.Body)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_READ_ERR: %s", err)
			json.NewEncoder(w).Encode(err.Error())
			return
		}

		query := GQLPayload{}
		if err := json.Unmarshal(str, &query); err != nil {
			ctx.Log.Logf("gql", "GQL_PARSE_ERR: %s", err)
			json.NewEncoder(w).Encode(err.Error())
			return
		}

		gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext)
		params := graphql.Params{
			Schema:        gql_context.Schema,
			Context:       req_ctx,
			RequestString: query.Query,
		}
		if query.OperationName != "" {
			params.OperationName = query.OperationName
		}
		if len(query.Variables) > 0 {
			params.VariableValues = query.Variables
		}

		result := graphql.Do(params)
		if len(result.Errors) > 0 {
			extra_fields := map[string]interface{}{
				"body":    string(str),
				"headers": r.Header,
			}
			ctx.Log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors)
		}
		json.NewEncoder(w).Encode(result)
	}
}

// sendOneResultAndClose returns a closed channel that yields res exactly
// once. The channel is buffered, so no goroutine is needed and nothing leaks
// if the result is never read.
func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result {
	result_channel := make(chan *graphql.Result, 1)
	result_channel <- res
	close(result_channel)
	return result_channel
}

// getOperationTypeOfReq parses p.RequestString and returns the operation
// type of the operation named p.OperationName, or of the first operation
// when no name is given. It returns "" if the request does not parse or no
// operation matches.
func getOperationTypeOfReq(p graphql.Params) string {
	src := source.NewSource(&source.Source{
		Body: []byte(p.RequestString),
		Name: "GraphQL request",
	})

	document, err := parser.Parse(parser.ParseParams{Source: src})
	if err != nil {
		return ""
	}

	for _, node := range document.Definitions {
		operation_def, ok := node.(*ast.OperationDefinition)
		if !ok {
			continue
		}
		name := ""
		if operation_def.Name != nil {
			name = operation_def.Name.Value
		}
		if name == p.OperationName || p.OperationName == "" {
			return operation_def.Operation
		}
	}
	return ""
}

// GQLWSDo executes p and returns a channel of results: a subscription yields
// a stream from graphql.Subscribe, anything else yields one result and is
// then closed.
func GQLWSDo(ctx *Context, p graphql.Params) chan *graphql.Result {
	operation := getOperationTypeOfReq(p)
	ctx.Log.Logf("gqlws", "GQLWSDO_OPERATION: %s %+v", operation, p.RequestString)

	if operation == ast.OperationTypeSubscription {
		return graphql.Subscribe(p)
	}

	res := graphql.Do(p)
	return sendOneResultAndClose(res)
}

// forwardGQLWSResults relays every result from res_chan to the client as a
// "next" message tagged with id, then sends "complete" once res_chan is
// closed. Results carrying errors are only logged. A nil result stops the
// relay without sending "complete".
func forwardGQLWSResults(ctx *Context, write func([]byte) error, id string, query string, res_chan chan *graphql.Result) {
	for next := range res_chan {
		if next == nil {
			ctx.Log.Logf("gqlws", "NIL_ON_CHANNEL")
			return
		}
		if len(next.Errors) > 0 {
			extra_fields := map[string]interface{}{"query": query}
			ctx.Log.Logm("gqlws", extra_fields, "ERROR: wrong result, unexpected errors: %+v", next.Errors)
			continue
		}

		ctx.Log.Logf("gqlws", "DATA: %+v", next.Data)
		data, err := json.Marshal(next.Data)
		if err != nil {
			ctx.Log.Logf("gqlws", "ERROR: %+v", err)
			continue
		}

		out, err := json.Marshal(GQLWSMsg{
			ID:   id,
			Type: "next",
			Payload: GQLPayload{
				Data: string(data),
			},
		})
		if err != nil {
			ctx.Log.Logf("gqlws", "ERROR: %+v", err)
			continue
		}

		// Keep draining on a failed write: the producer of res_chan may block
		// on an unbuffered send until its result is received.
		if err := write(out); err != nil {
			ctx.Log.Logf("gqlws", "ERROR: %+v", err)
			continue
		}
	}
	ctx.Log.Logf("gqlws", "response channel was closed")

	// The graphql-transport-ws protocol expects a "complete" message once an
	// operation has finished.
	complete, err := json.Marshal(struct {
		ID   string `json:"id"`
		Type string `json:"type"`
	}{ID: id, Type: "complete"})
	if err != nil {
		ctx.Log.Logf("gqlws", "ERROR: %+v", err)
		return
	}
	if err := write(complete); err != nil {
		ctx.Log.Logf("gqlws", "ERROR: %+v", err)
	}
}

// GQLWSHandler returns the handler that upgrades authenticated requests to
// a websocket speaking "graphql-transport-ws" or "graphql-ws" and serves
// connection_init, ping and subscribe messages until the client disconnects.
func GQLWSHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx.Log.Logf("gqlws_new", "HANDLING %s", r.RemoteAddr)
		enableCORS(&w)
		logRequestHeaders(ctx, r)

		resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
			// Refuse the handshake explicitly instead of replying with an empty 200.
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// Cancelled when this handler returns, i.e. when the socket is done, so
		// operations started below (graphql.Subscribe watches Context.Done) stop
		// instead of outliving the connection.
		req_ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)

		upgrader := ws.HTTPUpgrader{
			Protocol: func(protocol string) bool {
				ctx.Log.Logf("gqlws", "UPGRADE_PROTOCOL: %s", protocol)
				return protocol == "graphql-transport-ws" || protocol == "graphql-ws"
			},
		}
		conn, _, _, err := upgrader.Upgrade(r, w)
		if err != nil {
			// Log rather than panic inside the HTTP handler.
			ctx.Log.Logf("gqlws", "WS_UPGRADE_ERROR: %s", err)
			return
		}
		defer conn.Close()

		// This loop and every forwarding goroutine write frames to conn; a frame
		// write is not atomic, so serialize them to avoid interleaved frames.
		var write_lock sync.Mutex
		write := func(data []byte) error {
			write_lock.Lock()
			defer write_lock.Unlock()
			return wsutil.WriteServerMessage(conn, ws.OpText, data)
		}

		conn_state := "init"
		for {
			msg_raw, op, err := wsutil.ReadClientData(conn)
			ctx.Log.Logf("gqlws_hb", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err)
			if err != nil {
				ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR")
				return
			}

			msg := GQLWSMsg{}
			if err := json.Unmarshal(msg_raw, &msg); err != nil {
				// Carry on with whatever decoded: Unmarshal may still have set Type.
				ctx.Log.Logf("gqlws", "WS_MSG_DECODE_ERROR: %s", err)
			}

			switch msg.Type {
			case "connection_init":
				if conn_state != "init" {
					ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR: INIT WHILE IN %s", conn_state)
					return
				}
				conn_state = "ready"
				if err := write([]byte("{\"type\": \"connection_ack\"}")); err != nil {
					ctx.Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND connection_ack")
					return
				}

			case "ping":
				ctx.Log.Logf("gqlws_hb", "PING FROM %s", r.RemoteAddr)
				if err := write([]byte("{\"type\": \"pong\"}")); err != nil {
					ctx.Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND PONG")
				}

			case "subscribe":
				ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
				gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext)
				params := graphql.Params{
					Schema:        gql_context.Schema,
					Context:       req_ctx,
					RequestString: msg.Payload.Query,
				}
				if msg.Payload.OperationName != "" {
					params.OperationName = msg.Payload.OperationName
				}
				if len(msg.Payload.Variables) > 0 {
					params.VariableValues = msg.Payload.Variables
				}

				res_chan := GQLWSDo(ctx, params)
				if res_chan == nil {
					// Receiving from a nil channel blocks forever; don't start a
					// goroutine that could never finish.
					ctx.Log.Logf("gqlws", "res_chan is nil")
					continue
				}
				ctx.Log.Logf("gqlws", "res_chan: %+v", res_chan)
				go forwardGQLWSResults(ctx, write, msg.ID, msg.Payload.Query, res_chan)
			}
		}
	}
}

// GQLInterface bundles a GraphQL interface with its default implementing
// object, a list type of the interface, and the extensions it is built for.
type GQLInterface struct {
	Interface  *graphql.Interface
	Default    *graphql.Object
	List       *graphql.List
	Extensions []ExtType
}

// GQLType bundles a GraphQL object type with a list type of it.
type GQLType struct {
	Type *graphql.Object
	List *graphql.List
}

// NewGQLNodeType creates an object type named after node_type that matches
// *Node values whose Type equals node_type, creates its list type, and then
// calls init so the caller can add fields.
func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*GQLType)) *GQLType {
	var gql GQLType
	gql.Type = graphql.NewObject(graphql.ObjectConfig{
		Name:       string(node_type),
		Interfaces: interfaces,
		IsTypeOf: func(p graphql.IsTypeOfParams) bool {
			node, ok := p.Value.(*Node)
			if !ok {
				return false
			}
			return node.Type == node_type
		},
		Fields: graphql.Fields{},
	})
	gql.List = graphql.NewList(gql.Type)

	init(&gql)
	return &gql
}

// NewGQLInterface creates the interface if_name and its default object
// default_name.
//
// init_1 runs after the interface and its list type exist but before Default
// is created; init_2 runs after Default is created. The interface's
// ResolveType is given &gql.Default, so it observes the later assignment.
func NewGQLInterface(if_name string, default_name string, interfaces []*graphql.Interface, extensions []ExtType, init_1 func(*GQLInterface), init_2 func(*GQLInterface)) *GQLInterface {
	var gql GQLInterface
	gql.Extensions = extensions
	gql.Interface = graphql.NewInterface(graphql.InterfaceConfig{
		Name:        if_name,
		ResolveType: NodeResolver([]ExtType{}, &gql.Default),
		Fields:      graphql.Fields{},
	})
	gql.List = graphql.NewList(gql.Interface)

	init_1(&gql)

	// Build a fresh slice: appending to the caller's slice could overwrite
	// elements in its backing array when it has spare capacity.
	default_interfaces := make([]*graphql.Interface, 0, len(interfaces)+1)
	default_interfaces = append(default_interfaces, interfaces...)
	default_interfaces = append(default_interfaces, gql.Interface)

	gql.Default = graphql.NewObject(graphql.ObjectConfig{
		Name:       default_name,
		Interfaces: default_interfaces,
		IsTypeOf:   GQLNodeHasExtensions([]ExtType{}),
		Fields:     graphql.Fields{},
	})

	init_2(&gql)

	return &gql
}

// GQL Specific Context information
type GQLExtContext struct {
	// Generated GQL schema
	Schema graphql.Schema

	// Custom graphql types, mapped to NodeTypes
	NodeTypes  map[NodeType]*graphql.Object
	Interfaces []*GQLInterface

	// Schema parameters
	Types        []graphql.Type
	Query        *graphql.Object
	Mutation     *graphql.Object
	Subscription *graphql.Object
}

// BuildSchema builds a graphql.Schema from the types and root objects
// collected in ctx.
func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
	schemaConfig := graphql.SchemaConfig{
		Types:        ctx.Types,
		Query:        ctx.Query,
		Mutation:     ctx.Mutation,
		Subscription: ctx.Subscription,
	}

	return graphql.NewSchema(schemaConfig)
}

// AddInterface records i and adds its default object to the schema types.
// It rejects nil or partially initialized interfaces.
func (ctx *GQLExtContext) AddInterface(i *GQLInterface) error {
	if i == nil {
		return fmt.Errorf("interface is nil")
	}

	if i.Interface == nil || i.Extensions == nil || i.Default == nil || i.List == nil {
		return fmt.Errorf("invalid interface, contains nil")
	}

	ctx.Interfaces = append(ctx.Interfaces, i)
	ctx.Types = append(ctx.Types, i.Default)

	return nil
}

// RegisterNodeType maps node_type to gql_type and adds gql_type to the
// schema types. Each node type may be registered only once.
func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, gql_type *graphql.Object) error {
	if gql_type == nil {
		return fmt.Errorf("gql_type is nil")
	}
	if ctx.NodeTypes == nil {
		// Allow a zero-value GQLExtContext; writing to a nil map would panic.
		ctx.NodeTypes = map[NodeType]*graphql.Object{}
	}
	if _, exists := ctx.NodeTypes[node_type]; exists {
		return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type)
	}

	ctx.NodeTypes[node_type] = gql_type
	ctx.Types = append(ctx.Types, gql_type)

	return nil
}

// NewGQLExtContext creates the default context: Query, Mutation and
// Subscription roots with the built-in fields, the Node and Lockable
// interfaces, and a schema built from them. It panics if any step fails.
func NewGQLExtContext() *GQLExtContext {
	query := graphql.NewObject(graphql.ObjectConfig{
		Name:   "Query",
		Fields: graphql.Fields{},
	})
	query.AddFieldConfig("Self", GQLQuerySelf)
	query.AddFieldConfig("User", GQLQueryUser)

	mutation := graphql.NewObject(graphql.ObjectConfig{
		Name:   "Mutation",
		Fields: graphql.Fields{},
	})
	mutation.AddFieldConfig("stop", GQLMutationStop)
	mutation.AddFieldConfig("startChild", GQLMutationStartChild)

	subscription := graphql.NewObject(graphql.ObjectConfig{
		Name:   "Subscription",
		Fields: graphql.Fields{},
	})
	subscription.AddFieldConfig("Self", GQLSubscriptionSelf)
	subscription.AddFieldConfig("Update", GQLSubscriptionUpdate)

	gql_context := GQLExtContext{
		Schema:       graphql.Schema{},
		Types:        []graphql.Type{},
		Query:        query,
		Mutation:     mutation,
		Subscription: subscription,
		NodeTypes:    map[NodeType]*graphql.Object{},
		Interfaces:   []*GQLInterface{},
	}

	if err := gql_context.AddInterface(GQLInterfaceNode); err != nil {
		panic(err)
	}
	if err := gql_context.AddInterface(GQLInterfaceLockable); err != nil {
		panic(err)
	}

	schema, err := BuildSchema(&gql_context)
	if err != nil {
		panic(err)
	}

	gql_context.Schema = schema

	return &gql_context
}

// GQLExt is the extension that runs the GraphQL HTTP/websocket server for a
// node and fans incoming signals out to subscription channels.
type GQLExt struct {
	tcp_listener net.Listener
	http_server  *http.Server
	http_done    sync.WaitGroup

	// PEM-encoded TLS key and certificate used by StartGQLServer.
	tls_key  []byte
	tls_cert []byte

	// Address passed to net.Listen by StartGQLServer.
	Listen string
	// Users accepted by NewResolveContext, keyed by the ID in the Basic-auth username.
	Users NodeMap
	Key   *ecdsa.PrivateKey
	ECDH  ecdh.Curve

	// SubscribeLock guards SubscribeListeners.
	SubscribeLock      sync.Mutex
	SubscribeListeners []chan Signal
}

// Field resolves the named field of ext; only "listen" is supported.
func (ext *GQLExt) Field(name string) interface{} {
	return ResolveFields(ext, name, map[string]func(*GQLExt) interface{}{
		"listen": func(ext *GQLExt) interface{} {
			return ext.Listen
		},
	})
}

// NewSubscriptionChannel registers and returns a new listener channel with
// the given buffer size; Process will send every signal to it.
func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan Signal {
	ext.SubscribeLock.Lock()
	defer ext.SubscribeLock.Unlock()

	new_listener := make(chan Signal, buffer)
	ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener)

	return new_listener
}

// Process forwards signal to every listener without blocking. A listener
// whose buffer is full is dropped: a separate goroutine sends it a final
// "Channel Closed" DirectSignal and then closes it.
//
// NOTE(review): that goroutine blocks until the dropped listener is read, so
// it never exits if the consumer has gone away.
func (ext *GQLExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
	ext.SubscribeLock.Lock()
	defer ext.SubscribeLock.Unlock()

	active_listeners := []chan Signal{}
	for _, listener := range ext.SubscribeListeners {
		select {
		case listener <- signal:
			active_listeners = append(active_listeners, listener)
		default:
			go func(listener chan Signal) {
				listener <- NewDirectSignal("Channel Closed")
				close(listener)
			}(listener)
		}
	}
	ext.SubscribeListeners = active_listeners
}

// Type returns GQLExtType.
func (ext *GQLExt) Type() ExtType {
	return GQLExtType
}

// GQLExtJSON is the serialized form of a GQLExt.
type GQLExtJSON struct {
	Listen  string `json:"listen"`
	Key     []byte `json:"key"`
	ECDH    uint8  `json:"ecdh_curve"`
	TLSKey  []byte `json:"ssl_key"`
	TLSCert []byte `json:"ssl_cert"`
}

// Serialize encodes ext as indented GQLExtJSON, storing Key in SEC 1 DER
// form and ECDH as its numeric curve id. It fails on an unknown curve rather
// than silently recording id 0.
func (ext *GQLExt) Serialize() ([]byte, error) {
	ser_key, err := x509.MarshalECPrivateKey(ext.Key)
	if err != nil {
		return nil, err
	}

	curve_id, known := ecdh_curve_ids[ext.ECDH]
	if !known {
		return nil, fmt.Errorf("%v is not a known ECDH curve", ext.ECDH)
	}

	return json.MarshalIndent(&GQLExtJSON{
		Listen:  ext.Listen,
		Key:     ser_key,
		ECDH:    curve_id,
		TLSKey:  ext.tls_key,
		TLSCert: ext.tls_cert,
	}, "", " ")
}

// Numeric curve ids used in serialized data, in both directions.
var (
	ecdsa_curves = map[uint8]elliptic.Curve{
		0: elliptic.P256(),
	}

	ecdsa_curve_ids = map[elliptic.Curve]uint8{
		elliptic.P256(): 0,
	}

	ecdh_curves = map[uint8]ecdh.Curve{
		0: ecdh.P256(),
	}

	ecdh_curve_ids = map[ecdh.Curve]uint8{
		ecdh.P256(): 0,
	}
)

// LoadGQLExt decodes a GQLExt previously produced by Serialize.
func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
	var j GQLExtJSON
	if err := json.Unmarshal(data, &j); err != nil {
		return nil, err
	}

	ecdh_curve, ok := ecdh_curves[j.ECDH]
	if !ok {
		return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
	}

	key, err := x509.ParseECPrivateKey(j.Key)
	if err != nil {
		return nil, err
	}

	return NewGQLExt(j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey), nil
}

// newSelfSignedTLSPair generates a fresh ECDSA key on curve and a self-signed
// server certificate for it, valid for one year, and returns both
// PEM-encoded (certificate first). It panics if any step fails.
func newSelfSignedTLSPair(curve elliptic.Curve) ([]byte, []byte) {
	ssl_key, err := ecdsa.GenerateKey(curve, rand.Reader)
	if err != nil {
		panic(err)
	}

	ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key)
	if err != nil {
		panic(err)
	}
	// MarshalECPrivateKey produces SEC 1 DER, whose PEM label is
	// "EC PRIVATE KEY" ("PRIVATE KEY" denotes PKCS #8).
	ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: ssl_key_bytes})

	serial_limit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial_number, err := rand.Int(rand.Reader, serial_limit)
	if err != nil {
		panic(err)
	}

	not_before := time.Now()
	template := x509.Certificate{
		SerialNumber: serial_number,
		Subject: pkix.Name{
			Organization: []string{"mekkanized"},
		},
		NotBefore:             not_before,
		NotAfter:              not_before.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key)
	if err != nil {
		panic(err)
	}
	ssl_cert_pem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert})

	return ssl_cert_pem, ssl_key_pem
}

// NewGQLExt creates a GQLExt. If either tls_cert or tls_key is nil, a new
// self-signed pair is generated on key's curve (this panics on failure).
func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) *GQLExt {
	if tls_cert == nil || tls_key == nil {
		tls_cert, tls_key = newSelfSignedTLSPair(key.Curve)
	}
	return &GQLExt{
		Listen:             listen,
		SubscribeListeners: []chan Signal{},
		Key:                key,
		ECDH:               ecdh_curve,
		tls_cert:           tls_cert,
		tls_key:            tls_key,
	}
}

// StartGQLServer starts serving HTTPS on gql_ext.Listen: /gql, /gqlws,
// /graphiql and the ./site directory under /site. Serving runs in a
// goroutine tracked by gql_ext.http_done; it panics if Serve fails with
// anything other than http.ErrServerClosed.
func StartGQLServer(ctx *Context, node *Node, gql_ext *GQLExt) error {
	mux := http.NewServeMux()
	mux.HandleFunc("/gql", GQLHandler(ctx, node, gql_ext))
	mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, gql_ext))

	// Serve a graphiql interface (TODO make configurable whether to start this)
	mux.HandleFunc("/graphiql", GraphiQLHandler())

	// Serve the ./site directory to /site (TODO make configurable with better defaults)
	fs := http.FileServer(http.Dir("./site"))
	mux.Handle("/site/", http.StripPrefix("/site", fs))

	http_server := &http.Server{
		Addr:    gql_ext.Listen,
		Handler: mux,
		// Bound how long a client may take to send request headers. With
		// ReadTimeout unset, net/http clears this deadline once headers are
		// read, so long-lived websocket connections are unaffected.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Parse the key pair before opening the socket so a bad certificate
	// cannot leak the listener.
	cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key)
	if err != nil {
		return err
	}

	l, err := net.Listen("tcp", http_server.Addr)
	if err != nil {
		return fmt.Errorf("Failed to start listener for server on %s: %w", http_server.Addr, err)
	}

	config := tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos:   []string{"http/1.1"},
	}
	listener := tls.NewListener(l, &config)

	gql_ext.http_done.Add(1)
	go func() {
		defer gql_ext.http_done.Done()

		err := http_server.Serve(listener)
		if !errors.Is(err, http.ErrServerClosed) {
			panic(fmt.Sprintf("Failed to start gql server: %s", err))
		}
	}()

	gql_ext.tcp_listener = listener
	gql_ext.http_server = http_server
	return nil
}

// StopGQLServer shuts the server down and waits for the serving goroutine
// to exit. It is a no-op if the server was never started.
func StopGQLServer(gql_ext *GQLExt) {
	if gql_ext.http_server == nil {
		return
	}
	gql_ext.http_server.Shutdown(context.TODO())
	gql_ext.http_done.Wait()
}