package graphvent
import (
"time"
"net"
"net/http"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/parser"
"github.com/graphql-go/graphql/language/source"
"github.com/graphql-go/graphql/language/ast"
"context"
"encoding/json"
"io"
"reflect"
"fmt"
"sync"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"strings"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/tls"
"crypto/x509/pkix"
"math/big"
"encoding/pem"
)
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
graphiql_string := fmt.Sprintf(`
GraphiQL
Loading...
`)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
io.WriteString(w, graphiql_string)
}
}
type GQLPayload struct {
OperationName string `json:"operationName,omitempty"`
Query string `json:"query,omitempty"`
Variables map[string]interface{} `json:"variables,omitempty"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Data string `json:"data,omitempty"`
}
type GQLWSMsg struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
Payload GQLPayload `json:"payload,omitempty"`
}
func enableCORS(w *http.ResponseWriter) {
(*w).Header().Set("Access-Control-Allow-Origin", "*")
(*w).Header().Set("Access-Control-Allow-Credentials", "true")
(*w).Header().Set("Access-Control-Allow-Headers", "*")
(*w).Header().Set("Access-Control-Allow-Methods", "*")
}
type GQLUnauthorized string
func (e GQLUnauthorized) Is(target error) bool {
error_type := reflect.TypeOf(GQLUnauthorized(""))
target_type := reflect.TypeOf(target)
return error_type == target_type
}
func (e GQLUnauthorized) Error() string {
return fmt.Sprintf("GQL_UNAUTHORIZED_ERROR: %s", string(e))
}
func (e GQLUnauthorized) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(&struct{
Error string `json:"error"`
}{
Error: string(e),
}, "", " ")
}
func checkForAuthHeader(header http.Header) (string, bool) {
auths, ok := header["Authorization"]
if ok == false {
return "", false
}
for _, auth := range(auths) {
parts := strings.SplitN(auth, " ", 2)
if len(parts) != 2 {
continue
}
if parts[0] == "TM" {
return parts[1], true
}
}
return "", false
}
type ResolveContext struct {
Context *Context
GQLContext *GQLExtContext
Server *Node
Ext *GQLExt
User *Node
}
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, _, ok := r.BasicAuth()
if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
}
auth_id, err := ParseID(username)
if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
}
user, exists := gql_ext.Users[auth_id]
if exists == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id)
}
return &ResolveContext{
Context: ctx,
GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext),
Server: server,
User: user,
}, nil
}
func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
enableCORS(&w)
header_map := map[string]interface{}{}
for header, value := range(r.Header) {
header_map[header] = value
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(fmt.Sprintf("%s", err)))
return
}
req_ctx := context.Background()
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
str, err := io.ReadAll(r.Body)
if err != nil {
ctx.Log.Logf("gql", "GQL_READ_ERR: %s", err)
json.NewEncoder(w).Encode(fmt.Sprintf("%e", err))
return
}
query := GQLPayload{}
json.Unmarshal(str, &query)
gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext)
params := graphql.Params{
Schema: gql_context.Schema,
Context: req_ctx,
RequestString: query.Query,
}
if query.OperationName != "" {
params.OperationName = query.OperationName
}
if len(query.Variables) > 0 {
params.VariableValues = query.Variables
}
result := graphql.Do(params)
if len(result.Errors) > 0 {
extra_fields := map[string]interface{}{}
extra_fields["body"] = string(str)
extra_fields["headers"] = r.Header
ctx.Log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors)
}
json.NewEncoder(w).Encode(result)
}
}
func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result {
resultChannel := make(chan *graphql.Result)
go func() {
resultChannel <- res
close(resultChannel)
}()
return resultChannel
}
func getOperationTypeOfReq(p graphql.Params) string{
source := source.NewSource(&source.Source{
Body: []byte(p.RequestString),
Name: "GraphQL request",
})
AST, err := parser.Parse(parser.ParseParams{Source: source})
if err != nil {
return ""
}
for _, node := range AST.Definitions {
if operationDef, ok := node.(*ast.OperationDefinition); ok {
name := ""
if operationDef.Name != nil {
name = operationDef.Name.Value
}
if name == p.OperationName || p.OperationName == "" {
return operationDef.Operation
}
}
}
return ""
}
func GQLWSDo(ctx * Context, p graphql.Params) chan *graphql.Result {
operation := getOperationTypeOfReq(p)
ctx.Log.Logf("gqlws", "GQLWSDO_OPERATION: %s %+v", operation, p.RequestString)
if operation == ast.OperationTypeSubscription {
return graphql.Subscribe(p)
}
res := graphql.Do(p)
return sendOneResultAndClose(res)
}
func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr)
enableCORS(&w)
header_map := map[string]interface{}{}
for header, value := range(r.Header) {
header_map[header] = value
}
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
resolve_context, err := NewResolveContext(ctx, server, gql_ext, r)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
return
}
req_ctx := context.Background()
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
u := ws.HTTPUpgrader{
Protocol: func(protocol string) bool {
ctx.Log.Logf("gqlws", "UPGRADE_PROTOCOL: %s", string(protocol))
if string(protocol) == "graphql-transport-ws" || string(protocol) == "graphql-ws" {
return true
}
return false
},
}
conn, _, _, err := u.Upgrade(r, w)
if err == nil {
defer conn.Close()
conn_state := "init"
for {
msg_raw, op, err := wsutil.ReadClientData(conn)
ctx.Log.Logf("gqlws_hb", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err)
msg := GQLWSMsg{}
json.Unmarshal(msg_raw, &msg)
if err != nil {
ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR")
break
}
if msg.Type == "connection_init" {
if conn_state != "init" {
ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR: INIT WHILE IN %s", conn_state)
break
}
conn_state = "ready"
err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}"))
if err != nil {
ctx.Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND connection_ack")
break
}
} else if msg.Type == "ping" {
ctx.Log.Logf("gqlws_hb", "PING FROM %s", r.RemoteAddr)
err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"pong\"}"))
if err != nil {
ctx.Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND PONG")
}
} else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext)
params := graphql.Params{
Schema: gql_context.Schema,
Context: req_ctx,
RequestString: msg.Payload.Query,
}
if msg.Payload.OperationName != "" {
params.OperationName = msg.Payload.OperationName
}
if len(msg.Payload.Variables) > 0 {
params.VariableValues = msg.Payload.Variables
}
res_chan := GQLWSDo(ctx, params)
if res_chan == nil {
ctx.Log.Logf("gqlws", "res_chan is nil")
} else {
ctx.Log.Logf("gqlws", "res_chan: %+v", res_chan)
}
go func(res_chan chan *graphql.Result) {
for {
next, ok := <-res_chan
if ok == false {
ctx.Log.Logf("gqlws", "response channel was closed")
return
}
if next == nil {
ctx.Log.Logf("gqlws", "NIL_ON_CHANNEL")
return
}
if len(next.Errors) > 0 {
extra_fields := map[string]interface{}{}
extra_fields["query"] = string(msg.Payload.Query)
ctx.Log.Logm("gqlws", extra_fields, "ERROR: wrong result, unexpected errors: %+v", next.Errors)
continue
}
ctx.Log.Logf("gqlws", "DATA: %+v", next.Data)
data, err := json.Marshal(next.Data)
if err != nil {
ctx.Log.Logf("gqlws", "ERROR: %+v", err)
continue
}
msg, err := json.Marshal(GQLWSMsg{
ID: msg.ID,
Type: "next",
Payload: GQLPayload{
Data: string(data),
},
})
if err != nil {
ctx.Log.Logf("gqlws", "ERROR: %+v", err)
continue
}
err = wsutil.WriteServerMessage(conn, 1, msg)
if err != nil {
ctx.Log.Logf("gqlws", "ERROR: %+v", err)
continue
}
}
}(res_chan)
} else {
}
}
return
} else {
panic("Failed to upgrade websocket")
}
}
}
type GQLInterface struct {
Interface *graphql.Interface
Default *graphql.Object
List *graphql.List
Extensions []ExtType
}
type GQLType struct {
Type *graphql.Object
List *graphql.List
}
func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init func(*GQLType)) *GQLType {
var gql GQLType
gql.Type = graphql.NewObject(graphql.ObjectConfig{
Name: string(node_type),
Interfaces: interfaces,
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(*Node)
if ok == false {
return false
}
return node.Type == node_type
},
Fields: graphql.Fields{},
})
gql.List = graphql.NewList(gql.Type)
init(&gql)
return &gql
}
func NewGQLInterface(if_name string, default_name string, interfaces []*graphql.Interface, extensions []ExtType, init_1 func(*GQLInterface), init_2 func(*GQLInterface)) *GQLInterface {
var gql GQLInterface
gql.Extensions = extensions
gql.Interface = graphql.NewInterface(graphql.InterfaceConfig{
Name: if_name,
ResolveType: NodeResolver([]ExtType{}, &gql.Default),
Fields: graphql.Fields{},
})
gql.List = graphql.NewList(gql.Interface)
init_1(&gql)
gql.Default = graphql.NewObject(graphql.ObjectConfig{
Name: default_name,
Interfaces: append(interfaces, gql.Interface),
IsTypeOf: GQLNodeHasExtensions([]ExtType{}),
Fields: graphql.Fields{},
})
init_2(&gql)
return &gql
}
// GQL Specific Context information
type GQLExtContext struct {
// Generated GQL schema
Schema graphql.Schema
// Custom graphql types, mapped to NodeTypes
NodeTypes map[NodeType]*graphql.Object
Interfaces []*GQLInterface
// Schema parameters
Types []graphql.Type
Query *graphql.Object
Mutation *graphql.Object
Subscription *graphql.Object
}
func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
schemaConfig := graphql.SchemaConfig{
Types: ctx.Types,
Query: ctx.Query,
Mutation: ctx.Mutation,
Subscription: ctx.Subscription,
}
return graphql.NewSchema(schemaConfig)
}
func (ctx *GQLExtContext) AddInterface(i *GQLInterface) error {
if i == nil {
return fmt.Errorf("interface is nil")
}
if i.Interface == nil || i.Extensions == nil || i.Default == nil || i.List == nil {
return fmt.Errorf("invalid interface, contains nil")
}
ctx.Interfaces = append(ctx.Interfaces, i)
ctx.Types = append(ctx.Types, i.Default)
return nil
}
func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, gql_type *graphql.Object) error {
if gql_type == nil {
return fmt.Errorf("gql_type is nil")
}
_, exists := ctx.NodeTypes[node_type]
if exists == true {
return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type)
}
ctx.NodeTypes[node_type] = gql_type
ctx.Types = append(ctx.Types, gql_type)
return nil
}
func NewGQLExtContext() *GQLExtContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{},
})
query.AddFieldConfig("Self", GQLQuerySelf)
query.AddFieldConfig("User", GQLQueryUser)
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{},
})
mutation.AddFieldConfig("stop", GQLMutationStop)
mutation.AddFieldConfig("startChild", GQLMutationStartChild)
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
subscription.AddFieldConfig("Self", GQLSubscriptionSelf)
subscription.AddFieldConfig("Update", GQLSubscriptionUpdate)
context := GQLExtContext{
Schema: graphql.Schema{},
Types: []graphql.Type{},
Query: query,
Mutation: mutation,
Subscription: subscription,
NodeTypes: map[NodeType]*graphql.Object{},
Interfaces: []*GQLInterface{},
}
var err error
err = context.AddInterface(GQLInterfaceNode)
if err != nil {
panic(err)
}
err = context.AddInterface(GQLInterfaceLockable)
if err != nil {
panic(err)
}
schema, err := BuildSchema(&context)
if err != nil {
panic(err)
}
context.Schema = schema
return &context
}
type GQLExt struct {
tcp_listener net.Listener
http_server *http.Server
http_done sync.WaitGroup
tls_key []byte
tls_cert []byte
Listen string
Users NodeMap
Key *ecdsa.PrivateKey
ECDH ecdh.Curve
SubscribeLock sync.Mutex
SubscribeListeners []chan Signal
}
func (ext *GQLExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*GQLExt)interface{}{
"listen": func(ext *GQLExt) interface{} {
return ext.Listen
},
})
}
func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan Signal {
ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock()
new_listener := make(chan Signal, buffer)
ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener)
return new_listener
}
func (ext *GQLExt) Process(context *Context, princ_id NodeID, node *Node, signal Signal) {
if signal.Type() == ReadResultSignalType {
}
ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock()
active_listeners := []chan Signal{}
for _, listener := range(ext.SubscribeListeners) {
select {
case listener <- signal:
active_listeners = append(active_listeners, listener)
default:
go func(listener chan Signal) {
listener <- NewDirectSignal("Channel Closed")
close(listener)
}(listener)
}
}
ext.SubscribeListeners = active_listeners
return
}
func (ext *GQLExt) Type() ExtType {
return GQLExtType
}
type GQLExtJSON struct {
Listen string `json:"listen"`
Key []byte `json:"key"`
ECDH uint8 `json:"ecdh_curve"`
TLSKey []byte `json:"ssl_key"`
TLSCert []byte `json:"ssl_cert"`
}
func (ext *GQLExt) Serialize() ([]byte, error) {
ser_key, err := x509.MarshalECPrivateKey(ext.Key)
if err != nil {
return nil, err
}
return json.MarshalIndent(&GQLExtJSON{
Listen: ext.Listen,
Key: ser_key,
ECDH: ecdh_curve_ids[ext.ECDH],
TLSKey: ext.tls_key,
TLSCert: ext.tls_cert,
}, "", " ")
}
var ecdsa_curves = map[uint8]elliptic.Curve{
0: elliptic.P256(),
}
var ecdsa_curve_ids = map[elliptic.Curve]uint8{
elliptic.P256(): 0,
}
var ecdh_curves = map[uint8]ecdh.Curve{
0: ecdh.P256(),
}
var ecdh_curve_ids = map[ecdh.Curve]uint8{
ecdh.P256(): 0,
}
func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
var j GQLExtJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
ecdh_curve, ok := ecdh_curves[j.ECDH]
if ok == false {
return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
}
key, err := x509.ParseECPrivateKey(j.Key)
if err != nil {
return nil, err
}
return NewGQLExt(j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey), nil
}
func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) *GQLExt {
if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader)
if err != nil {
panic(err)
}
ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key)
if err != nil {
panic(err)
}
ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: ssl_key_bytes})
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit)
notBefore := time.Now()
notAfter := notBefore.Add(365*24*time.Hour)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"mekkanized"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key)
if err != nil {
panic(err)
}
ssl_cert_pem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert})
tls_cert = ssl_cert_pem
tls_key = ssl_key_pem
}
return &GQLExt{
Listen: listen,
SubscribeListeners: []chan Signal{},
Key: key,
ECDH: ecdh_curve,
tls_cert: tls_cert,
tls_key: tls_key,
}
}
func StartGQLServer(ctx *Context, node *Node, gql_ext *GQLExt) error {
mux := http.NewServeMux()
mux.HandleFunc("/gql", GQLHandler(ctx, node, gql_ext))
mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, gql_ext))
// Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler())
// Server the ./site directory to /site (TODO make configurable with better defaults)
fs := http.FileServer(http.Dir("./site"))
mux.Handle("/site/", http.StripPrefix("/site", fs))
http_server := &http.Server{
Addr: gql_ext.Listen,
Handler: mux,
}
l, err := net.Listen("tcp", http_server.Addr)
if err != nil {
return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
}
cert, err := tls.X509KeyPair(gql_ext.tls_cert, gql_ext.tls_key)
if err != nil {
return err
}
config := tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{"http/1.1"},
}
listener := tls.NewListener(l, &config)
gql_ext.http_done.Add(1)
go func(qql_ext *GQLExt) {
defer gql_ext.http_done.Done()
err := http_server.Serve(listener)
if err != http.ErrServerClosed {
panic(fmt.Sprintf("Failed to start gql server: %s", err))
}
}(gql_ext)
gql_ext.tcp_listener = listener
gql_ext.http_server = http_server
return nil
}
func StopGQLServer(gql_ext *GQLExt) {
gql_ext.http_server.Shutdown(context.TODO())
gql_ext.http_done.Wait()
}