package graphvent
import (
"bytes"
"context"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"
"filippo.io/edwards25519"
"crypto/sha512"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast"
"github.com/graphql-go/graphql/language/parser"
"github.com/graphql-go/graphql/language/source"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"github.com/google/uuid"
)
// NodeInterfaceDefaultIsType builds the IsTypeOf callback for an interface's
// default object type.
//
// The returned callback reports true only when the request context carries a
// *ResolveContext under the "resolve" key, the value is a NodeResult, the
// result's node type is present in the graph context's Nodes map, and that
// node type definition lists every extension in required_extensions.
func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTypeOfParams) bool {
	return func(p graphql.IsTypeOfParams) bool {
		resolve_ctx, has_ctx := p.Context.Value("resolve").(*ResolveContext)
		if !has_ctx {
			return false
		}

		result, is_node := p.Value.(NodeResult)
		if !is_node {
			return false
		}

		type_def, known := resolve_ctx.Context.Nodes[result.NodeType]
		if !known {
			return false
		}

		// Every required extension has to appear in the node type definition
	required:
		for _, want := range required_extensions {
			for _, have := range type_def.Extensions {
				if have == want {
					continue required
				}
			}
			return false
		}
		return true
	}
}
// NodeInterfaceResolveType builds the ResolveType callback for a GQL interface.
//
// The callback returns the GQL object registered for the result's node type
// when there is one. Otherwise it falls back to *default_type, but only if the
// node type is known to the graph context and its definition lists every
// extension in required_extensions; in all other cases it returns nil.
// default_type is dereferenced on every call, so the pointed-to object may be
// assigned after this function returns.
func NodeInterfaceResolveType(required_extensions []ExtType, default_type **graphql.Object) func(graphql.ResolveTypeParams) *graphql.Object {
	return func(p graphql.ResolveTypeParams) *graphql.Object {
		resolve_ctx, has_ctx := p.Context.Value("resolve").(*ResolveContext)
		if !has_ctx {
			return nil
		}

		result, is_node := p.Value.(NodeResult)
		if !is_node {
			return nil
		}

		specific_type, registered := resolve_ctx.GQLContext.NodeTypes[result.NodeType]
		resolve_ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", result, specific_type, registered, required_extensions, *default_type)
		if registered {
			return specific_type
		}

		// No dedicated GQL type: use the default only if the node type qualifies
		type_def, known := resolve_ctx.Context.Nodes[result.NodeType]
		if !known {
			return nil
		}

	required:
		for _, want := range required_extensions {
			for _, have := range type_def.Extensions {
				if have == want {
					continue required
				}
			}
			return nil
		}
		return *default_type
	}
}
// PrepResolve fetches the *ResolveContext stored under the "resolve" key of
// the resolver's request context, or returns an error when it is missing or of
// the wrong type.
func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) {
	if resolve_context, found := p.Context.Value("resolve").(*ResolveContext); found {
		return resolve_context, nil
	}
	return nil, fmt.Errorf("Bad resolve in params context")
}
// TODO: Make composable by checking if K is a slice, then recursing in the same way that ExtractList does
// ExtractParam looks up the resolver argument called name and returns it as a
// K. It fails when the argument is absent or when its dynamic type is not K;
// in both cases the zero value of K is returned alongside the error.
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {
	var zero K

	raw, present := p.Args[name]
	if !present {
		return zero, fmt.Errorf("No Arg of name %s", name)
	}

	if typed, matches := raw.(K); matches {
		return typed, nil
	}
	return zero, fmt.Errorf("Failed to cast arg %s(%+v) to %+v", name, raw, reflect.TypeOf(zero))
}
// ExtractList reads the resolver argument called name as a []interface{} and
// converts each element to K. It fails on the first element whose dynamic
// type is not K, reporting that element's index.
func ExtractList[K interface{}](p graphql.ResolveParams, name string) ([]K, error) {
	raw_list, err := ExtractParam[[]interface{}](p, name)
	if err != nil {
		return nil, err
	}

	converted := make([]K, 0, len(raw_list))
	for idx, item := range raw_list {
		typed, matches := item.(K)
		if !matches {
			var zero K
			return nil, fmt.Errorf("Failed to cast arg %s[%d](%+v) to %+v", name, idx, item, reflect.TypeOf(zero))
		}
		converted = append(converted, typed)
	}
	return converted, nil
}
// ExtractID reads the string argument called name and parses it with ParseID.
// ZeroID is returned together with the error when either step fails.
func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
	raw, extract_err := ExtractParam[string](p, name)
	if extract_err != nil {
		return ZeroID, extract_err
	}

	parsed, parse_err := ParseID(raw)
	if parse_err != nil {
		return ZeroID, parse_err
	}

	return parsed, nil
}
// GraphiQLHandler returns an HTTP handler that answers every request with a
// 200 status and the static GraphiQL page, served as UTF-8 HTML.
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
	// The page never changes, so it is rendered once and shared by all requests
	page := fmt.Sprintf(`
GraphiQL
Loading...
`)

	return func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		w.WriteHeader(http.StatusOK)
		io.WriteString(w, page)
	}
}
// GQLPayload is the JSON body of a GraphQL HTTP request and the payload
// carried inside a GQLWSMsg websocket message. Every field is omitted from
// the encoded JSON when empty.
type GQLPayload struct {
// Name of the operation to execute; forwarded to graphql only when non-empty
OperationName string `json:"operationName,omitempty"`
// GraphQL request document
Query string `json:"query,omitempty"`
// Variable values for the request; forwarded only when non-empty
Variables map[string]interface{} `json:"variables,omitempty"`
// Not read by the handlers visible in this file
Extensions map[string]interface{} `json:"extensions,omitempty"`
// Result data; the websocket handler fills this with the JSON-encoded
// result, stored as a string
Data string `json:"data,omitempty"`
}
// GQLWSMsg is one JSON message exchanged over the GraphQL websocket.
type GQLWSMsg struct {
// Message ID; "next" replies reuse the ID of the subscribe message they answer
ID string `json:"id,omitempty"`
// Message type; the websocket handler reads "connection_init", "ping" and
// "subscribe", and sends "next"
Type string `json:"type"`
// Request or result payload
Payload GQLPayload `json:"payload,omitempty"`
}
// enableCORS sets wildcard CORS response headers (origin, headers, methods)
// and allows credentials on the given response writer.
func enableCORS(w *http.ResponseWriter) {
	cors_headers := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Credentials", "true"},
		{"Access-Control-Allow-Headers", "*"},
		{"Access-Control-Allow-Methods", "*"},
	}

	response_headers := (*w).Header()
	for _, pair := range cors_headers {
		response_headers.Set(pair[0], pair[1])
	}
}
// GQLUnauthorized is the error returned to clients whose request could not be
// authorized. The string value is the human-readable reason.
type GQLUnauthorized string

// Is reports whether target is also a GQLUnauthorized, regardless of its
// message, so that errors.Is matches any unauthorized error.
func (e GQLUnauthorized) Is(target error) bool {
	return reflect.TypeOf(target) == reflect.TypeOf(e)
}

// Error implements the error interface.
func (e GQLUnauthorized) Error() string {
	reason := string(e)
	return fmt.Sprintf("GQL_UNAUTHORIZED_ERROR: %s", reason)
}

// MarshalJSON encodes the error as a JSON object with a single "error" member
// holding the reason.
func (e GQLUnauthorized) MarshalJSON() ([]byte, error) {
	type wire struct {
		Error string `json:"error"`
	}
	encoded := wire{Error: string(e)}
	return json.MarshalIndent(&encoded, "", " ")
}
// ResolveContext is the state passed to each resolve execution. The HTTP and
// websocket handlers store it in the request context under the "resolve" key,
// and resolvers fetch it back with PrepResolve.
type ResolveContext struct {
// Resolution ID, freshly generated by NewResolveContext; the "Self"
// subscription also uses it as the subscription ID
ID uuid.UUID
// Channels for the gql extension to route data to this context
Chans map[uuid.UUID]chan Signal
// Graph Context this resolver is running under
Context *Context
// GQL Extension context this resolver is running under
GQLContext *GQLExtContext
// Pointer to the node that's currently processing this request
Server *Node
// The state data for the node processing this request
Ext *GQLExt
// Cache of resolved nodes
NodeCache map[NodeID]NodeResult
// Authorization from the user that started this request; nil until a
// handler has parsed the client's token
Authorization *ClientAuthorization
}
// AuthB64 builds the authorization token a client presents to a GQL server.
//
// A fresh ed25519 session key is generated. Its seed is encrypted with AES in
// OFB mode under a 32-byte ECDH secret derived from the session key and
// server_pubkey (ECDH is declared elsewhere in the package). client_key signs
// the session public key concatenated with the binary-marshalled start time.
//
// The returned token is the base64 encoding of six base64 fields joined by
// ":": client identity, IV, session public key, encrypted session seed, start
// time and signature. ParseAuthB64 is the matching decoder.
//
// An error is returned when client_key has the wrong size, when server_pubkey
// is not a valid curve point, or when any cryptographic step fails.
func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (string, error) {
	// Sign and Public panic on keys of the wrong size, so reject them up front
	if len(client_key) != ed25519.PrivateKeySize {
		return "", fmt.Errorf("client key wrong length: %d/%d", len(client_key), ed25519.PrivateKeySize)
	}

	token_start := time.Now()
	token_start_bytes, err := token_start.MarshalBinary()
	if err != nil {
		return "", err
	}

	session_key_public, session_key_private, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return "", err
	}
	session_seed := session_key_private.Seed()

	// Derive the ECDH private scalar from the session key's seed hash
	session_h := sha512.Sum512(session_seed)
	ecdh_client, err := ECDH.NewPrivateKey(session_h[:32])
	if err != nil {
		return "", err
	}

	// Convert the server's edwards25519 public key to its Montgomery form
	server_point, err := (&edwards25519.Point{}).SetBytes(server_pubkey)
	if err != nil {
		return "", err
	}
	ecdh_server, err := ECDH.NewPublicKey(server_point.BytesMontgomery())
	if err != nil {
		return "", err
	}

	secret, err := ecdh_client.ECDH(ecdh_server)
	if err != nil {
		return "", err
	}
	if len(secret) != 32 {
		return "", fmt.Errorf("ECDH secret not 32 bytes(for AES-256): %d bytes long", len(secret))
	}

	block, err := aes.NewCipher(secret)
	if err != nil {
		return "", err
	}

	// io.ReadFull either fills the whole IV or reports an error
	iv := make([]byte, block.BlockSize())
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return "", err
	}

	var key_encrypted bytes.Buffer
	stream := cipher.NewOFB(block, iv)
	writer := &cipher.StreamWriter{S: stream, W: &key_encrypted}
	bytes_written, err := writer.Write(session_seed)
	if err != nil {
		return "", err
	} else if bytes_written != len(session_seed) {
		// Compare against the data actually being encrypted (the seed)
		return "", fmt.Errorf("wrong number of bytes encrypted %d/%d", bytes_written, len(session_seed))
	}

	// Build the signed digest in its own buffer so the public key slice is
	// never aliased or modified
	digest := make([]byte, 0, len(session_key_public)+len(token_start_bytes))
	digest = append(digest, session_key_public...)
	digest = append(digest, token_start_bytes...)

	signature, err := client_key.Sign(rand.Reader, digest, crypto.Hash(0))
	if err != nil {
		return "", err
	}

	start_b64 := base64.StdEncoding.EncodeToString(token_start_bytes)
	iv_b64 := base64.StdEncoding.EncodeToString(iv)
	encrypted_b64 := base64.StdEncoding.EncodeToString(key_encrypted.Bytes())
	key_b64 := base64.StdEncoding.EncodeToString(session_key_public)
	sig_b64 := base64.StdEncoding.EncodeToString(signature)
	id_b64 := base64.StdEncoding.EncodeToString(client_key.Public().(ed25519.PublicKey))

	joined := strings.Join([]string{id_b64, iv_b64, key_b64, encrypted_b64, start_b64, sig_b64}, ":")
	return base64.StdEncoding.EncodeToString([]byte(joined)), nil
}
// ParseAuthB64 decodes and verifies a token produced by AuthB64.
//
// The token is base64 of six ":"-joined base64 fields: client identity, IV,
// session public key, encrypted session seed, start time and signature. The
// session seed is decrypted with AES in OFB mode under the ECDH secret derived
// from server_id and the session public key, and the signature over
// (session public key || start time) is verified against the client identity.
//
// The token comes from the client, so every field whose length the crypto
// primitives rely on is validated here; ed25519.Verify,
// ed25519.NewKeyFromSeed and cipher.NewOFB panic on inputs of the wrong size.
// Token age is not checked here.
func ParseAuthB64(auth_base64 string, server_id ed25519.PrivateKey) (*ClientAuthorization, error) {
	joined_b64, err := base64.StdEncoding.DecodeString(auth_base64)
	if err != nil {
		return nil, err
	}

	auth_parts := strings.Split(string(joined_b64), ":")
	if len(auth_parts) != 6 {
		return nil, fmt.Errorf("Wrong number of delimited elements %d", len(auth_parts))
	}

	// Decode the six fields in token order
	decoded := make([][]byte, len(auth_parts))
	for i, part := range auth_parts {
		decoded[i], err = base64.StdEncoding.DecodeString(part)
		if err != nil {
			return nil, err
		}
	}
	id_bytes, iv, public_key := decoded[0], decoded[1], decoded[2]
	key_encrypted, start_bytes, signature := decoded[3], decoded[4], decoded[5]

	var start time.Time
	err = start.UnmarshalBinary(start_bytes)
	if err != nil {
		return nil, err
	}

	if len(id_bytes) != ed25519.PublicKeySize {
		return nil, fmt.Errorf("Client identity wrong length: %d/%d", len(id_bytes), ed25519.PublicKeySize)
	}
	client_id := ed25519.PublicKey(id_bytes)

	// Convert the session's edwards25519 public key to its Montgomery form
	client_point, err := (&edwards25519.Point{}).SetBytes(public_key)
	if err != nil {
		return nil, err
	}
	ecdh_client, err := ECDH.NewPublicKey(client_point.BytesMontgomery())
	if err != nil {
		return nil, err
	}

	h := sha512.Sum512(server_id.Seed())
	ecdh_server, err := ECDH.NewPrivateKey(h[:32])
	if err != nil {
		return nil, err
	}

	secret, err := ecdh_server.ECDH(ecdh_client)
	if err != nil {
		return nil, err
	} else if len(secret) != 32 {
		return nil, fmt.Errorf("Secret wrong length: %d/32", len(secret))
	}

	block, err := aes.NewCipher(secret)
	if err != nil {
		return nil, err
	}
	if len(iv) != block.BlockSize() {
		return nil, fmt.Errorf("IV wrong length: %d/%d", len(iv), block.BlockSize())
	}

	encrypted_reader := bytes.NewReader(key_encrypted)
	stream := cipher.NewOFB(block, iv)
	reader := cipher.StreamReader{S: stream, R: encrypted_reader}
	var decrypted_key bytes.Buffer
	_, err = io.Copy(&decrypted_key, reader)
	if err != nil {
		return nil, err
	}
	if decrypted_key.Len() != ed25519.SeedSize {
		return nil, fmt.Errorf("Session key seed wrong length: %d/%d", decrypted_key.Len(), ed25519.SeedSize)
	}

	session_key := ed25519.NewKeyFromSeed(decrypted_key.Bytes())

	// The client signed the public key of the decrypted session key, so a
	// tampered or mis-decrypted seed fails verification here
	digest := append(session_key.Public().(ed25519.PublicKey), start_bytes...)
	if ed25519.Verify(client_id, digest, signature) == false {
		return nil, fmt.Errorf("Failed to verify digest/signature against client_id")
	}

	return &ClientAuthorization{
		AuthInfo: AuthInfo{
			Identity:  client_id,
			Start:     start,
			Signature: signature,
		},
		Key: session_key,
	}, nil
}
// ValidateAuthorization checks that auth is still within its validity window
// and that its signature is genuine.
//
// The authorization is rejected as expired unless Start + valid lies strictly
// after the current time. The signature must then verify, against Identity,
// over Key concatenated with the binary-marshalled Start.
func ValidateAuthorization(auth Authorization, valid time.Duration) error {
	expires := auth.Start.Add(valid)
	if !expires.After(time.Now()) {
		return fmt.Errorf("authorization expired")
	}

	start_bytes, err := auth.Start.MarshalBinary()
	if err != nil {
		return err
	}

	signed := append(auth.Key, start_bytes...)
	if !ed25519.Verify(auth.Identity, signed, auth.Signature) {
		return fmt.Errorf("verification failed")
	}
	return nil
}
// NewResolveContext creates the per-request resolve state for server.
//
// The context gets a fresh ID, empty channel and node caches, and no
// Authorization; callers attach the authorization once the client's token has
// been parsed. An error is returned when the data registered for GQLExtType
// in ctx.Extensions is not a *GQLExtContext, instead of panicking on the type
// assertion.
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) {
	gql_ctx, ok := ctx.Extensions[GQLExtType].Data.(*GQLExtContext)
	if !ok {
		return nil, fmt.Errorf("data registered for the GQL extension is not a *GQLExtContext")
	}

	return &ResolveContext{
		ID:            uuid.New(),
		Ext:           gql_ext,
		Chans:         map[uuid.UUID]chan Signal{},
		Context:       ctx,
		GQLContext:    gql_ctx,
		NodeCache:     map[NodeID]NodeResult{},
		Server:        server,
		Authorization: nil,
	}, nil
}
// GQLHandler returns the HTTP handler that executes GraphQL requests.
//
// For each request the handler sets CORS headers, parses the Authorization
// header with ParseAuthB64, creates a ResolveContext carrying that
// authorization, decodes the JSON body into a GQLPayload and runs it against
// the GQL extension's schema. The graphql result is written back as JSON.
// Authorization failures are answered with an encoded GQLUnauthorized; body
// read or decode failures are answered with the error text encoded as a JSON
// string.
func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
		enableCORS(&w)

		// reply encodes v as the response body and logs encoding failures,
		// which were previously dropped silently
		reply := func(v interface{}) {
			if err := json.NewEncoder(w).Encode(v); err != nil {
				ctx.Log.Logf("gql", "GQL_ENCODE_ERR: %s", err)
			}
		}

		header_map := map[string]interface{}{}
		for header, value := range r.Header {
			header_map[header] = value
		}
		ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")

		auth, err := ParseAuthB64(r.Header.Get("Authorization"), server.Key)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err)
			reply(GQLUnauthorized(""))
			return
		}

		resolve_context, err := NewResolveContext(ctx, server, gql_ext)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
			reply(GQLUnauthorized(""))
			return
		}
		resolve_context.Authorization = auth

		req_ctx := context.WithValue(context.Background(), "resolve", resolve_context)

		str, err := io.ReadAll(r.Body)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_READ_ERR: %s", err)
			// %s, not %e: %e is a floating-point verb and garbles errors
			reply(fmt.Sprintf("%s", err))
			return
		}

		query := GQLPayload{}
		if err := json.Unmarshal(str, &query); err != nil {
			// Do not execute a request that could not be decoded
			ctx.Log.Logf("gql", "GQL_UNMARSHAL_ERR: %s", err)
			reply(fmt.Sprintf("%s", err))
			return
		}

		// NewResolveContext already resolved the GQL extension context
		params := graphql.Params{
			Schema:        resolve_context.GQLContext.Schema,
			Context:       req_ctx,
			RequestString: query.Query,
		}
		if query.OperationName != "" {
			params.OperationName = query.OperationName
		}
		if len(query.Variables) > 0 {
			params.VariableValues = query.Variables
		}

		result := graphql.Do(params)
		if len(result.Errors) > 0 {
			extra_fields := map[string]interface{}{}
			extra_fields["body"] = string(str)
			extra_fields["headers"] = r.Header
			ctx.Log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors)
		}
		reply(result)
	}
}
// sendOneResultAndClose wraps a single result in a channel: a goroutine sends
// res on a new unbuffered channel and closes the channel afterwards.
func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result {
	single := make(chan *graphql.Result)
	go func() {
		defer close(single)
		single <- res
	}()
	return single
}
// getOperationTypeOfReq parses the request document and returns the operation
// kind of the first operation definition that matches p.OperationName; when
// p.OperationName is empty the first operation definition matches.
//
// If parsing fails the parse error text is returned instead, and
// "END_OF_FUNCTION" is returned when no definition matches.
func getOperationTypeOfReq(p graphql.Params) string {
	document, err := parser.Parse(parser.ParseParams{
		Source: source.NewSource(&source.Source{
			Body: []byte(p.RequestString),
			Name: "GraphQL request",
		}),
	})
	if err != nil {
		return err.Error()
	}

	for _, definition := range document.Definitions {
		operation, is_operation := definition.(*ast.OperationDefinition)
		if !is_operation {
			continue
		}

		var operation_name string
		if operation.Name != nil {
			operation_name = operation.Name.Value
		}

		if p.OperationName == "" || operation_name == p.OperationName {
			return operation.Operation
		}
	}
	return "END_OF_FUNCTION"
}
// GQLWSDo executes a websocket GraphQL request and returns a channel of
// results. Subscriptions are handed to graphql.Subscribe; any other request is
// executed once with graphql.Do and its single result is delivered on a
// channel that is closed afterwards.
func GQLWSDo(ctx *Context, p graphql.Params) chan *graphql.Result {
	operation := getOperationTypeOfReq(p)
	ctx.Log.Logf("gqlws", "GQLWSDO_OPERATION: %s - %+v", operation, p.RequestString)

	switch operation {
	case ast.OperationTypeSubscription:
		return graphql.Subscribe(p)
	default:
		return sendOneResultAndClose(graphql.Do(p))
	}
}
// GQLWSHandler returns the HTTP handler that upgrades a request to a websocket
// and serves GraphQL over it.
//
// The connection starts in the "init" state. A "connection_init" message must
// carry a token in payload.token; once ParseAuthB64 accepts it the connection
// becomes "ready" and a connection_ack is sent. "ping" is answered with
// "pong". "subscribe" runs the request through GQLWSDo and streams every
// result back as a "next" message with the same ID. Messages of any other
// type are ignored. The connection is closed on read errors, undecodable
// messages, protocol violations and failed acknowledgements.
//
// Changes from the previous behaviour:
//   - a failed upgrade is logged instead of panicking;
//   - messages are decoded only after the read succeeded;
//   - "subscribe" is refused until the connection is authorized;
//   - writes to the connection are serialized, because result goroutines
//     write concurrently with the read loop;
//   - the request context is cancelled when the handler returns, so work that
//     honours the context can stop with the connection;
//   - no goroutine is started for a nil result channel.
func GQLWSHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx.Log.Logf("gqlws_new", "HANDLING %s", r.RemoteAddr)
		enableCORS(&w)
		header_map := map[string]interface{}{}
		for header, value := range r.Header {
			header_map[header] = value
		}
		ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")

		resolve_context, err := NewResolveContext(ctx, server, gql_ext)
		if err != nil {
			ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
			return
		}

		base_ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		req_ctx := context.WithValue(base_ctx, "resolve", resolve_context)

		u := ws.HTTPUpgrader{
			Protocol: func(protocol string) bool {
				ctx.Log.Logf("gqlws", "UPGRADE_PROTOCOL: %s", protocol)
				return protocol == "graphql-transport-ws" || protocol == "graphql-ws"
			},
		}
		conn, _, _, err := u.Upgrade(r, w)
		if err != nil {
			ctx.Log.Logf("gqlws", "WS_UPGRADE_ERROR: %s", err)
			return
		}
		defer conn.Close()

		// All writes go through write_message so frames produced by different
		// goroutines never interleave on the connection
		var write_lock sync.Mutex
		write_message := func(data []byte) error {
			write_lock.Lock()
			defer write_lock.Unlock()
			return wsutil.WriteServerMessage(conn, 1, data)
		}

		conn_state := "init"
	read_loop:
		for {
			msg_raw, op, err := wsutil.ReadClientData(conn)
			ctx.Log.Logf("gqlws_hb", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err)
			if err != nil {
				ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR")
				break
			}

			msg := GQLWSMsg{}
			if err := json.Unmarshal(msg_raw, &msg); err != nil {
				ctx.Log.Logf("gqlws", "WS_UNMARSHAL_ERROR: %s - %+v", msg_raw, err)
				break
			}

			switch msg.Type {
			case "connection_init":
				if conn_state != "init" {
					ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR: INIT WHILE IN %s", conn_state)
					break read_loop
				}

				connection_params := struct {
					Payload struct {
						Token string `json:"token"`
					} `json:"payload"`
				}{}
				if err := json.Unmarshal(msg_raw, &connection_params); err != nil {
					ctx.Log.Logf("gqlws", "WS_UNMARSHAL_ERROR: %s - %+v", msg_raw, err)
					break read_loop
				}

				authorization, err := ParseAuthB64(connection_params.Payload.Token, server.Key)
				if err != nil {
					ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err)
					break read_loop
				}
				resolve_context.Authorization = authorization

				conn_state = "ready"
				if err := write_message([]byte("{\"type\": \"connection_ack\"}")); err != nil {
					ctx.Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND connection_ack")
					break read_loop
				}

			case "ping":
				ctx.Log.Logf("gqlws_hb", "PING FROM %s", r.RemoteAddr)
				if err := write_message([]byte("{\"type\": \"pong\"}")); err != nil {
					ctx.Log.Logf("gqlws", "WS_SERVER_ERROR: FAILED TO SEND PONG")
					break read_loop
				}

			case "subscribe":
				// Requests are only executed once connection_init has
				// supplied a valid authorization
				if conn_state != "ready" {
					ctx.Log.Logf("gqlws", "WS_CLIENT_ERROR: SUBSCRIBE WHILE IN %s", conn_state)
					break read_loop
				}

				ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
				params := graphql.Params{
					Schema:        resolve_context.GQLContext.Schema,
					Context:       req_ctx,
					RequestString: msg.Payload.Query,
				}
				if msg.Payload.OperationName != "" {
					params.OperationName = msg.Payload.OperationName
				}
				if len(msg.Payload.Variables) > 0 {
					params.VariableValues = msg.Payload.Variables
				}

				res_chan := GQLWSDo(ctx, params)
				if res_chan == nil {
					// Receiving from a nil channel blocks forever, so there
					// is nothing to forward
					ctx.Log.Logf("gqlws", "res_chan is nil")
					continue
				}
				ctx.Log.Logf("gqlws", "res_chan: %+v", res_chan)

				// Forward every result until the channel is closed. Failures
				// are logged and skipped so the channel keeps being drained.
				go func(res_chan chan *graphql.Result, id string, query string) {
					for {
						next, ok := <-res_chan
						if !ok {
							ctx.Log.Logf("gqlws", "response channel was closed")
							return
						}
						if next == nil {
							ctx.Log.Logf("gqlws", "NIL_ON_CHANNEL")
							return
						}
						if len(next.Errors) > 0 {
							extra_fields := map[string]interface{}{}
							extra_fields["query"] = query
							ctx.Log.Logm("gqlws", extra_fields, "ERROR: wrong result, unexpected errors: %+v", next.Errors)
							continue
						}
						ctx.Log.Logf("gqlws", "DATA: %+v", next.Data)
						data, err := json.Marshal(next.Data)
						if err != nil {
							ctx.Log.Logf("gqlws", "ERROR: %+v", err)
							continue
						}
						encoded, err := json.Marshal(GQLWSMsg{
							ID:   id,
							Type: "next",
							Payload: GQLPayload{
								Data: string(data),
							},
						})
						if err != nil {
							ctx.Log.Logf("gqlws", "ERROR: %+v", err)
							continue
						}
						if err := write_message(encoded); err != nil {
							ctx.Log.Logf("gqlws", "ERROR: %+v", err)
							continue
						}
					}
				}(res_chan, msg.ID, msg.Payload.Query)
			}
		}
	}
}
// Interface bundles the GQL objects created for one registered interface.
type Interface struct {
// The GQL interface itself
Interface *graphql.Interface
// Object type returned by the interface's ResolveType when the node type
// has no GQL type of its own
Default *graphql.Object
// List type wrapping Interface
List *graphql.List
// Not assigned by the code visible in this file
Extensions []ExtType
}
// Type pairs a GQL object type with a list type; it is not used by the code
// visible in this file.
type Type struct {
// GQL object type
Type *graphql.Object
// List type; presumably wraps Type (TODO confirm against callers)
List *graphql.List
}
// Field is a GQL field registered with RegisterField. RegisterField builds it
// positionally, so the field order must not change.
type Field struct {
// Extension whose data the field is read from
Ext ExtType
// Name of the value inside the extension's data (the gv_tag given to
// RegisterField)
Name string
// GQL field definition, including its resolver
Field *graphql.Field
}
// GQLExtContext holds the GQL-specific context information: the registered
// node types, interfaces and fields, and the schema generated from them.
type GQLExtContext struct {
// Generated GQL schema
Schema graphql.Schema
// Custom graphql types, mapped to NodeTypes
NodeTypes map[NodeType]*graphql.Object
// Registered interfaces, keyed by interface name
Interfaces map[string]*Interface
// Registered fields, keyed by GQL field name
Fields map[string]Field
// Schema parameters
// Types receives every default object and node type object as they are
// registered
Types []graphql.Type
Query *graphql.Object
Mutation *graphql.Object
Subscription *graphql.Object
// GQL type information keyed by exact Go type
TypeMap map[reflect.Type]GQLTypeInfo
// GQL type information keyed by Go kind
KindMap map[reflect.Kind]GQLTypeInfo
}
// GetACLFields groups the given GQL field names by the extension they are read
// from, mapping each extension to the names of its values. "ID", "TypeHash"
// and names that are not registered fields are skipped. obj_name is currently
// unused and the returned error is always nil.
func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[ExtType][]string, error) {
	by_extension := map[ExtType][]string{}
	for _, gql_name := range names {
		if gql_name == "ID" || gql_name == "TypeHash" {
			continue
		}

		registered, known := ctx.Fields[gql_name]
		if !known {
			continue
		}

		by_extension[registered.Ext] = append(by_extension[registered.Ext], registered.Name)
	}
	return by_extension, nil
}
// BuildSchema creates a graphql schema from the types and the root query,
// mutation and subscription objects collected in ctx.
func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
	return graphql.NewSchema(graphql.SchemaConfig{
		Query:        ctx.Query,
		Mutation:     ctx.Mutation,
		Subscription: ctx.Subscription,
		Types:        ctx.Types,
	})
}
// RegisterField registers a GQL field named gql_name with type gql_type.
//
// The generated resolver expects p.Source to be the NodeResult read for the
// current node. It looks up the serialized value stored under gv_tag in the
// ext_type extension data, deserializes it and hands the value to resolve_fn.
// A serialized value whose first type is ErrorType is returned as an error
// carrying the serialized data as its message.
//
// An error is returned when ctx or resolve_fn is nil, or when gql_name is
// already registered.
func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, gv_tag string, resolve_fn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (interface{}, error)) error {
	if ctx == nil {
		return fmt.Errorf("ctx is nil")
	}

	if resolve_fn == nil {
		return fmt.Errorf("resolve_fn cannot be nil")
	}

	if _, exists := ctx.Fields[gql_name]; exists {
		return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name)
	}

	// Resolver has p.Source.(NodeResult) = read result of current node
	resolver := func(p graphql.ResolveParams) (interface{}, error) {
		resolve_ctx, err := PrepResolve(p)
		if err != nil {
			return nil, err
		}

		node, ok := p.Source.(NodeResult)
		if !ok {
			return nil, fmt.Errorf("p.Value is not NodeResult")
		}

		ext, ext_exists := node.Data[ext_type]
		if !ext_exists {
			return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type)
		}

		val_ser, field_exists := ext[gv_tag]
		if !field_exists {
			return nil, fmt.Errorf("%s is not in the fields of %+v in the result for %s - %+v", gv_tag, ext_type, gql_name, node)
		}

		// Guard the index below: an empty type stack would panic
		if len(val_ser.TypeStack) == 0 {
			return nil, fmt.Errorf("%s has an empty type stack in the result for %s", gv_tag, gql_name)
		}

		if val_ser.TypeStack[0] == ErrorType {
			// The data is a message, not a format string
			return nil, fmt.Errorf("%s", string(val_ser.Data))
		}

		field_type, field_value, _, err := DeserializeValue(resolve_ctx.Context, val_ser)
		if err != nil {
			return nil, err
		}

		if field_value == nil {
			return nil, fmt.Errorf("%s returned a nil value of %+v type", gv_tag, field_type)
		}

		resolve_ctx.Context.Log.Logf("gql", "Resolving %+v", field_value)

		return resolve_fn(p, resolve_ctx, *field_value)
	}

	ctx.Fields[gql_name] = Field{ext_type, gv_tag, &graphql.Field{
		Type:    gql_type,
		Resolve: resolver,
	}}
	return nil
}
// GQLInterfaces resolves interface names to the GQL interfaces registered in
// ctx, preserving order. It fails on the first name that is not registered.
func GQLInterfaces(ctx *GQLExtContext, interface_names []string) ([]*graphql.Interface, error) {
	resolved := make([]*graphql.Interface, 0, len(interface_names))
	for _, interface_name := range interface_names {
		entry, known := ctx.Interfaces[interface_name]
		if !known {
			return nil, fmt.Errorf("%s is not in GQLExtContext.Interfaces", interface_name)
		}
		resolved = append(resolved, entry.Interface)
	}
	return resolved, nil
}
// GQLFields builds the field set for a GQL type from registered field names.
//
// The set always contains the built-in "ID" and "TypeHash" fields. The second
// result lists each extension the named fields read from, once, in order of
// first appearance. It fails on the first name that is not registered.
func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtType, error) {
	field_set := graphql.Fields{
		"ID": &graphql.Field{
			Type:    graphql.String,
			Resolve: ResolveNodeID,
		},
		"TypeHash": &graphql.Field{
			Type:    graphql.String,
			Resolve: ResolveNodeTypeHash,
		},
	}

	seen := map[ExtType]struct{}{}
	required := []ExtType{}
	for _, field_name := range field_names {
		registered, known := ctx.Fields[field_name]
		if !known {
			return nil, nil, fmt.Errorf("%s is not in GQLExtContext.Fields", field_name)
		}
		field_set[field_name] = registered.Field

		if _, duplicate := seen[registered.Ext]; duplicate {
			continue
		}
		seen[registered.Ext] = struct{}{}
		required = append(required, registered.Ext)
	}

	return field_set, required, nil
}
// NodeResult is the value GQL resolvers operate on for a node: its identity,
// its type and the serialized data read from it.
type NodeResult struct {
// ID of the node
NodeID NodeID
// Type of the node; used to pick the GQL object type
NodeType NodeType
// Serialized values, keyed by extension and then by value name
Data map[ExtType]map[string]SerializedValue
}
// ListField describes an interface field that resolves to a list of nodes.
// It is built positionally elsewhere in this file, so the field order must
// not change.
type ListField struct {
// Name of the value in the extension data; passed to RegisterField as gv_tag
ACLName string
// Extension the value is read from
Extension ExtType
// Converts the deserialized value into the IDs of the nodes to resolve
ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) ([]NodeID, error)
}
// SelfField describes an interface field that resolves to a single node of
// the interface's own type. It is built positionally elsewhere in this file,
// so the field order must not change.
type SelfField struct {
// Name of the value in the extension data; passed to RegisterField as gv_tag
ACLName string
// Extension the value is read from
Extension ExtType
// Converts the deserialized value into the ID of the node to resolve; a nil
// ID makes the field resolve to nil
ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (*NodeID, error)
}
// RegisterInterface registers a GQL interface called name together with its
// default object type called default_name.
//
// interfaces names already-registered interfaces the default object also
// implements, and fields names already-registered fields shared by the
// interface and the default object. self_fields and list_fields are registered
// as new fields, keyed by GQL field name, that resolve to a single node of this
// interface or to a list of them.
//
// An error is returned when interfaces or fields is nil, when name is already
// registered, or when a referenced interface or field cannot be found or
// registered.
func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error {
if interfaces == nil {
return fmt.Errorf("interfaces is nil")
}
if fields == nil {
return fmt.Errorf("fields is nil")
}
_, exists := ctx.Interfaces[name]
if exists == true {
return fmt.Errorf("%s is already an interface in ctx", name)
}
node_interfaces, err := GQLInterfaces(ctx, interfaces)
if err != nil {
return err
}
node_fields, node_exts, err := GQLFields(ctx, fields)
if err != nil {
return err
}
ctx_interface := Interface{}
// ResolveType receives a pointer to Default, which is only assigned further
// down; the resolver dereferences it when it runs
ctx_interface.Interface = graphql.NewInterface(graphql.InterfaceConfig{
Name: name,
ResolveType: NodeInterfaceResolveType(node_exts, &ctx_interface.Default),
Fields: node_fields,
})
ctx_interface.List = graphql.NewList(ctx_interface.Interface)
// Fields resolving to a single node of this interface
for field_name, field := range(self_fields) {
// Copy the loop variable for the closure below
self_field := field
err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
id, err := self_field.ResolveFn(p, ctx, value)
if err != nil {
return nil, err
}
if id != nil {
nodes, err := ResolveNodes(ctx, p, []NodeID{*id})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes[0], nil
} else {
return nil, nil
}
})
if err != nil {
return err
}
// Expose the new field on both the interface and its default object
ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field)
node_fields[field_name] = ctx.Fields[field_name].Field
}
// Fields resolving to a list of nodes of this interface
for field_name, field := range(list_fields) {
// Copy the loop variable for the closure below
list_field := field
resolve_fn := func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
var zero NodeID
ids, err := list_field.ResolveFn(p, ctx, value)
if err != nil {
return zero, err
}
nodes, err := ResolveNodes(ctx, p, ids)
if err != nil {
return nil, err
} else if len(nodes) != len(ids) {
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes, nil
}
err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
if err != nil {
return err
}
ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field)
node_fields[field_name] = ctx.Fields[field_name].Field
}
// The default object implements the listed interfaces plus this one
ctx_interface.Default = graphql.NewObject(graphql.ObjectConfig{
Name: default_name,
Interfaces: append(node_interfaces, ctx_interface.Interface),
IsTypeOf: NodeInterfaceDefaultIsType(node_exts),
Fields: node_fields,
})
ctx.Interfaces[name] = &ctx_interface
ctx.Types = append(ctx.Types, ctx_interface.Default)
return nil
}
// RegisterNodeType registers a GQL object type called name for node_type.
//
// The object implements the named interfaces and exposes the named fields in
// addition to the built-in ones added by GQLFields. Its IsTypeOf matches
// NodeResult values whose NodeType equals node_type. An error is returned when
// field_names is nil, when node_type is already registered, or when a named
// interface or field is unknown.
func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, interface_names []string, field_names []string) error {
	if field_names == nil {
		return fmt.Errorf("fields is nil")
	}

	if _, registered := ctx.NodeTypes[node_type]; registered {
		return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type)
	}

	implemented, err := GQLInterfaces(ctx, interface_names)
	if err != nil {
		return err
	}

	object_fields, _, err := GQLFields(ctx, field_names)
	if err != nil {
		return err
	}

	is_node_type := func(p graphql.IsTypeOfParams) bool {
		result, is_node := p.Value.(NodeResult)
		return is_node && result.NodeType == node_type
	}

	object := graphql.NewObject(graphql.ObjectConfig{
		Name:       name,
		Interfaces: implemented,
		IsTypeOf:   is_node_type,
		Fields:     object_fields,
	})

	ctx.NodeTypes[node_type] = object
	ctx.Types = append(ctx.Types, object)

	return nil
}
// NewGQLExtContext builds the GQL extension context with the built-in schema.
//
// It creates empty Query, Mutation and Subscription root objects and the
// default kind/type conversion maps. It then registers the "Node", "Group"
// and "Lockable" interfaces, the "SubGroup" object type, the "SubGroups" and
// "Listen" fields, the "GQLServer" node type, the "stop", "addMember" and
// "removeMember" signal mutations, the "Self" subscription and the "Self" and
// "Node" queries, and finally builds the schema.
//
// Any registration or schema error panics. Registration order matters: later
// registrations look up names added by earlier ones.
func NewGQLExtContext() *GQLExtContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{},
})
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
// Each GQLTypeInfo entry below holds two functions: the first returns the
// graphql type for a Go type, the second turns a GraphQL input value into a
// reflect.Value
kind_map := map[reflect.Kind]GQLTypeInfo{
reflect.String: {
func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) {
return graphql.String, nil
},
func(ctx *GQLExtContext, value interface{})(reflect.Value, error) {
return reflect.ValueOf(value), nil
},
},
reflect.Bool: {
func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) {
return graphql.Boolean, nil
},
func(ctx *GQLExtContext, value interface{})(reflect.Value, error) {
return reflect.ValueOf(value), nil
},
},
}
type_map := map[reflect.Type]GQLTypeInfo{
// A pair of node IDs is exposed as a list of exactly two ID strings
reflect.TypeOf([2]NodeID{}): {
func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) {
return graphql.NewList(graphql.String), nil
},
func(ctx *GQLExtContext, value interface{})(reflect.Value, error) {
l, ok := value.([]interface{})
if ok == false {
return reflect.Value{}, fmt.Errorf("not list: %s", reflect.TypeOf(value))
} else if len(l) != 2 {
return reflect.Value{}, fmt.Errorf("wrong length: %d/2", len(l))
}
id1_str, ok := l[0].(string)
if ok == false {
return reflect.Value{}, fmt.Errorf("not strg: %s", reflect.TypeOf(l[0]))
}
id1, err := ParseID(id1_str)
if err != nil {
return reflect.Value{}, err
}
id2_str, ok := l[1].(string)
if ok == false {
return reflect.Value{}, fmt.Errorf("not strg: %s", reflect.TypeOf(l[1]))
}
id2, err := ParseID(id2_str)
if err != nil {
return reflect.Value{}, err
}
return_value := reflect.New(reflect.TypeOf([2]NodeID{})).Elem()
return_value.Index(0).Set(reflect.ValueOf(id1))
return_value.Index(1).Set(reflect.ValueOf(id2))
return return_value, nil
},
},
reflect.TypeOf(time.Time{}): {
func(ctx *GQLExtContext, reflect_type reflect.Type) (graphql.Type, error) {
return graphql.DateTime, nil
},
func(ctx *GQLExtContext, value interface{}) (reflect.Value, error) {
return reflect.ValueOf(value), nil
},
},
// Optional node ID: the empty string converts to a nil *NodeID
reflect.TypeOf(&NodeID{}): {
func(ctx *GQLExtContext, reflect_type reflect.Type) (graphql.Type, error) {
return graphql.String, nil
},
func(ctx *GQLExtContext, value interface{}) (reflect.Value, error) {
str, ok := value.(string)
if ok == false {
return reflect.Value{}, fmt.Errorf("value is not string")
}
if str == "" {
return reflect.New(reflect.TypeOf(&NodeID{})).Elem(), nil
}
id_parsed, err := ParseID(str)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(&id_parsed), nil
},
},
reflect.TypeOf(NodeID{}): {
func(ctx *GQLExtContext, reflect_type reflect.Type)(graphql.Type, error) {
return graphql.String, nil
},
func(ctx *GQLExtContext, value interface{})(reflect.Value, error) {
str, ok := value.(string)
if ok == false {
return reflect.Value{}, fmt.Errorf("value is not string")
}
id_parsed, err := ParseID(str)
if err != nil {
return reflect.Value{}, err
}
return reflect.ValueOf(id_parsed), nil
},
},
}
// NOTE: this local shadows the imported "context" package for the rest of
// the function
context := GQLExtContext{
Schema: graphql.Schema{},
Types: []graphql.Type{},
Query: query,
Mutation: mutation,
Subscription: subscription,
NodeTypes: map[NodeType]*graphql.Object{},
Interfaces: map[string]*Interface{},
Fields: map[string]Field{},
KindMap: kind_map,
TypeMap: type_map,
}
var err error
// Base interface; it has no fields beyond the built-in ones
err = context.RegisterInterface("Node", "DefaultNode", []string{}, []string{}, map[string]SelfField{}, map[string]ListField{})
if err != nil {
panic(err)
}
// Object type for SubGroupGQL values: a name plus its member nodes
sub_group_type := graphql.NewObject(graphql.ObjectConfig{
Name: "SubGroup",
Interfaces: nil,
Fields: graphql.Fields{
"Name": &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
val, ok := p.Source.(SubGroupGQL)
if ok == false {
return nil, fmt.Errorf("WRONG_TYPE_RETURNED")
}
return val.Name, nil
},
},
"Members": &graphql.Field{
Type: context.Interfaces["Node"].List,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
val, ok := p.Source.(SubGroupGQL)
if ok == false {
return nil, fmt.Errorf("WRONG_TYPE_RETURNED")
}
nodes, err := ResolveNodes(ctx, p, val.Members)
if err != nil {
return nil, err
}
return nodes, nil
},
},
},
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
return reflect.TypeOf(p.Value) == reflect.TypeOf(SubGroupGQL{})
},
Description: "SubGroup within Group",
})
context.Types = append(context.Types, sub_group_type)
// "SubGroups" converts the group extension's map of sub groups into a slice
// of SubGroupGQL values
err = context.RegisterField(sub_group_type, "SubGroups", GroupExtType, "sub_groups",
func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
node_map, ok := value.Interface().(map[string]SubGroup)
if ok == false {
return nil, fmt.Errorf("value is %+v, not map[string]SubGroup", value.Type())
}
sub_groups := []SubGroupGQL{}
for name, sub_group := range(node_map) {
sub_groups = append(sub_groups, SubGroupGQL{
name,
sub_group.Members,
})
}
return sub_groups, nil
})
if err != nil {
panic(err)
}
err = context.RegisterInterface("Group", "DefaultGroup", []string{"Node"}, []string{"SubGroups"}, map[string]SelfField{}, map[string]ListField{})
if err != nil {
panic(err)
}
// "Lockable" adds an "Owner" node field and a "Requirements" node list, both
// read from the lockable extension
err = context.RegisterInterface("Lockable", "DefaultLockable", []string{"Node"}, []string{}, map[string]SelfField{
"Owner": {
"owner",
LockableExtType,
func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) (*NodeID, error) {
id, ok := value.Interface().(*NodeID)
if ok == false {
return nil, fmt.Errorf("can't parse %+v as *NodeID", value.Type())
}
return id, nil
},
},
}, map[string]ListField{
"Requirements": {
"requirements",
LockableExtType,
func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) ([]NodeID, error) {
id_strs, ok := value.Interface().(map[NodeID]ReqInfo)
if ok == false {
return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState", value.Type())
}
ids := []NodeID{}
for id := range(id_strs) {
ids = append(ids, id)
}
return ids, nil
},
},
})
if err != nil {
panic(err)
}
err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) (interface{}, error) {
return value.String(), nil
})
if err != nil {
panic(err)
}
err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node", "Lockable", "Group"}, []string{"Listen", "Owner", "Requirements", "SubGroups"})
if err != nil {
panic(err)
}
err = context.AddSignalMutation("stop", "node_id", reflect.TypeOf(StopSignal{}))
if err != nil {
panic(err)
}
err = context.AddSignalMutation("addMember", "group_id", reflect.TypeOf(AddMemberSignal{}))
if err != nil {
panic(err)
}
err = context.AddSignalMutation("removeMember", "group_id", reflect.TypeOf(RemoveMemberSignal{}))
if err != nil {
panic(err)
}
// "Self" subscription: Subscribe registers a subscription channel under the
// resolve context's ID and seeds it with the server node's current state;
// Resolve then runs for each value received on that channel
context.Subscription.AddFieldConfig("Self", &graphql.Field{
Type: context.Interfaces["Node"].Interface,
Subscribe: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
c, err := ctx.Ext.AddSubscription(ctx.ID)
if err != nil {
return nil, err
}
nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of nodes returned")
}
c <- nodes[0]
return c, nil
},
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
ctx.Context.Log.Logf("gql_subscribe", "SUBSCRIBE_RESOLVE: %+v", p.Source)
switch source := p.Source.(type) {
case NodeResult:
// Nothing to refresh; the cached server node is returned below
case *StatusSignal:
// Drop the cached entry for the signal's source, and re-read the server
// node when the signal came from the server itself
delete(ctx.NodeCache, source.Source)
ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source)
if source.Source == ctx.Server.ID {
nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of nodes returned")
}
ctx.NodeCache[ctx.Server.ID] = nodes[0]
}
default:
return nil, fmt.Errorf("Don't know how to handle %+v", source)
}
return ctx.NodeCache[ctx.Server.ID], nil
},
})
// "Self" query: the node serving the request
context.Query.AddFieldConfig("Self", &graphql.Field{
Type: context.Interfaces["Node"].Interface,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
nodes, err := ResolveNodes(ctx, p, []NodeID{ctx.Server.ID})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of resolved nodes returned")
}
return nodes[0], nil
},
})
// "Node" query: the node whose ID is given in the "id" argument
context.Query.AddFieldConfig("Node", &graphql.Field{
Type: context.Interfaces["Node"].Interface,
Args: graphql.FieldConfigArgument{
"id": &graphql.ArgumentConfig{
Type: graphql.String,
},
},
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
id, err := ExtractID(p, "id")
if err != nil {
return nil, err
}
nodes, err := ResolveNodes(ctx, p, []NodeID{id})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of resolved nodes returned")
}
return nodes[0], nil
},
})
schema, err := BuildSchema(&context)
if err != nil {
panic(err)
}
context.Schema = schema
return &context
}
// SubscriptionInfo pairs a subscription's ID with the channel that values for
// that subscription are sent on.
//
// The field order is significant: AddSubscription builds values of this type
// with a positional composite literal.
type SubscriptionInfo struct {
	// ID identifies the subscription within GQLExt.subscriptions.
	ID uuid.UUID
	// Channel is the channel Process sends status signals to.
	Channel chan any
}
// GQLExt is the extension that runs the GraphQL HTTP server for a node and
// routes signals received in Process to waiting resolvers and subscriptions.
//
// The unexported fields are runtime state; the exported fields carry `gv`
// tags (presumably consumed by the project's serializer - confirm).
type GQLExt struct {
// Listener created by StartGQLServer; nil while the server is stopped.
tcp_listener net.Listener
// HTTP server created by StartGQLServer; nil while the server is stopped.
http_server *http.Server
// Tracks the goroutine running http_server.Serve; StopGQLServer waits on it.
http_done sync.WaitGroup
// Active subscriptions; every StatusSignal seen by Process is offered to each
// entry's channel. Guarded by subscriptions_lock.
subscriptions []SubscriptionInfo
subscriptions_lock sync.RWMutex
// Map of request IDs to the channel the waiting resolver reads responses
// from. Guarded by resolver_response_lock.
resolver_response map[uuid.UUID]chan Signal
resolver_response_lock sync.RWMutex
// TLS private key; NewGQLExt stores a PEM-encoded PKCS#8 key here when it
// generates one itself.
TLSKey []byte `gv:"tls_key"`
// TLS certificate; NewGQLExt stores a PEM-encoded self-signed certificate here
// when it generates one itself.
TLSCert []byte `gv:"tls_cert"`
// Address handed to net.Listen("tcp", ...) by StartGQLServer.
Listen string `gv:"listen"`
}
// PostDeserialize replaces the runtime-only bookkeeping (the subscription list
// and the map of pending resolver response channels) with fresh, empty
// containers. The context argument is not used and the result is always nil.
func (ext *GQLExt) PostDeserialize(*Context) error {
	ext.subscriptions = make([]SubscriptionInfo, 0)
	ext.resolver_response = make(map[uuid.UUID]chan Signal)
	return nil
}
// AddSubscription registers a new subscription under id and returns the
// channel that will receive its values. The channel has a buffer of one.
//
// An error is returned, and nothing is registered, when id is already present
// in the subscription list.
func (ext *GQLExt) AddSubscription(id uuid.UUID) (chan interface{}, error) {
	ext.subscriptions_lock.Lock()
	defer ext.subscriptions_lock.Unlock()

	for idx := range ext.subscriptions {
		if existing := ext.subscriptions[idx].ID; existing == id {
			return nil, fmt.Errorf("%+v already in subscription list", existing)
		}
	}

	channel := make(chan interface{}, 1)
	ext.subscriptions = append(ext.subscriptions, SubscriptionInfo{ID: id, Channel: channel})
	return channel, nil
}
// RemoveSubscription drops the subscription registered under id.
//
// The entry is removed by moving the last element into its slot and shrinking
// the slice by one, so the order of the remaining subscriptions is not
// preserved. An error is returned when id is not in the subscription list.
func (ext *GQLExt) RemoveSubscription(id uuid.UUID) error {
	ext.subscriptions_lock.Lock()
	defer ext.subscriptions_lock.Unlock()

	for i, info := range ext.subscriptions {
		if info.ID != id {
			continue
		}
		// The last valid index is len-1; indexing with len itself is always
		// out of range and panics.
		last := len(ext.subscriptions) - 1
		ext.subscriptions[i] = ext.subscriptions[last]
		// Clear the vacated slot so the backing array does not keep the
		// removed channel reachable.
		ext.subscriptions[last] = SubscriptionInfo{}
		ext.subscriptions = ext.subscriptions[:last]
		return nil
	}
	return fmt.Errorf("%+v not in subscription list", id)
}
// FindResponseChannel returns the response channel registered for req_id, or
// nil when no channel is registered. The registration is left in place.
func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal {
	ext.resolver_response_lock.RLock()
	defer ext.resolver_response_lock.RUnlock()

	// A missing key yields the zero value, which for a channel is nil.
	return ext.resolver_response[req_id]
}
// GetResponseChannel creates a new response channel with a buffer of one,
// registers it under req_id (replacing any channel already registered for that
// ID), and returns it.
func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal {
	ext.resolver_response_lock.Lock()
	defer ext.resolver_response_lock.Unlock()

	created := make(chan Signal, 1)
	ext.resolver_response[req_id] = created
	return created
}
// FreeResponseChannel unregisters the response channel for req_id and returns
// it, or returns nil when no channel is registered for that ID.
//
// The lookup and the removal happen under a single write lock, so when several
// goroutines free the same req_id concurrently only one of them receives the
// channel.
func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
	ext.resolver_response_lock.Lock()
	defer ext.resolver_response_lock.Unlock()

	response_chan := ext.resolver_response[req_id]
	if response_chan != nil {
		delete(ext.resolver_response, req_id)
	}
	return response_chan
}
// Process handles a signal delivered to the GQL extension.
//
//   - *SuccessSignal / *ErrorSignal: the response channel registered for the
//     signal's ReqID is unregistered and the signal is offered to it.
//   - *ReadResultSignal: the signal is offered to the response channel
//     registered for its ReqID; the registration is kept.
//   - *StopSignal / *StartSignal: the GQL server is stopped / started. A
//     successful start records the "server_started" change.
//   - *StatusSignal: the signal is offered to every subscription channel.
//
// Every send is non-blocking: if the receiving channel is full the signal is
// dropped and an overflow message is logged. Other signal types are ignored.
// The source argument is not used. The returned Messages are always nil.
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
	var changes Changes

	switch sig := signal.(type) {
	case *SuccessSignal:
		resolver := ext.FreeResponseChannel(sig.ReqID)
		if resolver == nil {
			ctx.Log.Logf("gql", "received success signal response %+v with no mapped resolver", sig)
			break
		}
		select {
		case resolver <- sig:
			ctx.Log.Logf("gql", "Forwarded success to resolver, %+v", sig.ReqID)
		default:
			ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
		}

	case *ErrorSignal:
		resolver := ext.FreeResponseChannel(sig.ReqID)
		if resolver == nil {
			ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
			break
		}
		select {
		case resolver <- sig:
			ctx.Log.Logf("gql", "Forwarded error to resolver, %+v", sig.Error)
		default:
			ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
		}

	case *ReadResultSignal:
		// Read results only look the channel up; it stays registered.
		resolver := ext.FindResponseChannel(sig.ReqID)
		if resolver == nil {
			ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
			break
		}
		select {
		case resolver <- sig:
			ctx.Log.Logf("gql", "Forwarded to resolver, %+v", sig)
		default:
			ctx.Log.Logf("gql", "Resolver channel overflow %+v", sig)
		}

	case *StopSignal:
		ctx.Log.Logf("gql", "stopping gql server %s", node.ID)
		if err := ext.StopGQLServer(); err != nil {
			ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err)
		}

	case *StartSignal:
		ctx.Log.Logf("gql", "starting gql server %s", node.ID)
		if err := ext.StartGQLServer(ctx, node); err != nil {
			ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
		} else {
			changes = changes.Add("server_started")
		}

	case *StatusSignal:
		// The read lock is held for the whole fan-out, including the log line
		// that prints the subscription list.
		ext.subscriptions_lock.RLock()
		ctx.Log.Logf("gql", "forwarding status signal from %+v to resolvers %+v", sig.Source, ext.subscriptions)
		for i := range ext.subscriptions {
			subscription := ext.subscriptions[i]
			select {
			case subscription.Channel <- sig:
				ctx.Log.Logf("gql_subscribe", "forwarded status signal to resolver: %+v", subscription.ID)
			default:
				ctx.Log.Logf("gql_subscribe", "resolver channel overflow: %+v", subscription.ID)
			}
		}
		ext.subscriptions_lock.RUnlock()
	}

	return nil, changes
}
// Lookup tables between uint8 curve identifiers and curve implementations.
// Each *_curves map goes from identifier to curve and the matching *_curve_ids
// map is its inverse. Only P-256 (identifier 0) is registered.
var (
	ecdsa_curves    = map[uint8]elliptic.Curve{0: elliptic.P256()}
	ecdsa_curve_ids = map[elliptic.Curve]uint8{elliptic.P256(): 0}
	ecdh_curves     = map[uint8]ecdh.Curve{0: ecdh.P256()}
	ecdh_curve_ids  = map[ecdh.Curve]uint8{ecdh.P256(): 0}
)
// NewGQLExt builds a GQLExt that will listen on the given address.
//
// When either tls_cert or tls_key is nil, BOTH are replaced with a freshly
// generated pair: an ECDSA P-256 private key (PEM-encoded PKCS#8) and a
// self-signed server certificate valid for 365 days from now (PEM-encoded).
// Otherwise the supplied values are stored unchanged.
//
// The ctx argument is currently unused. An error is returned only when key,
// serial number, or certificate generation fails.
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*GQLExt, error) {
	if tls_cert == nil || tls_key == nil {
		ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			return nil, err
		}

		ssl_key_bytes, err := x509.MarshalPKCS8PrivateKey(ssl_key)
		if err != nil {
			return nil, err
		}
		ssl_key_pem := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: ssl_key_bytes})

		// Random serial number in [0, 2^128). The error must be checked: on
		// failure the serial would be nil and must not reach the template.
		serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
		serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
		if err != nil {
			return nil, err
		}

		notBefore := time.Now()
		notAfter := notBefore.Add(365 * 24 * time.Hour)

		template := x509.Certificate{
			SerialNumber: serialNumber,
			Subject: pkix.Name{
				Organization: []string{"mekkanized"},
			},
			NotBefore:             notBefore,
			NotAfter:              notAfter,
			KeyUsage:              x509.KeyUsageDigitalSignature,
			ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
			BasicConstraintsValid: true,
		}

		// Self-signed: the template is used as both certificate and parent.
		ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, ssl_key.Public(), ssl_key)
		if err != nil {
			return nil, err
		}
		ssl_cert_pem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ssl_cert})

		tls_cert = ssl_cert_pem
		tls_key = ssl_key_pem
	}

	return &GQLExt{
		Listen:            listen,
		resolver_response: map[uuid.UUID]chan Signal{},
		subscriptions:     []SubscriptionInfo{},
		TLSCert:           tls_cert,
		TLSKey:            tls_key,
	}, nil
}
// getContentPath resolves the file to serve for a request path.
//
// base and path are concatenated as-is (no separator is inserted between
// them). The result is:
//   - "${base}${path}" when that names a non-directory;
//   - "${base}${path}/index.html" when it names a directory (no second slash
//     is added when path already ends in '/').
//
// Any os.Stat failure, including a missing file or a missing index.html, is
// returned as the error with an empty path. If index.html is itself a
// directory, the index path is returned together with os.ErrInvalid.
func getContentPath(base string, path string) (string, error) {
	full_path := fmt.Sprintf("%s%s", base, path)
	path_info, err := os.Stat(full_path)
	if err != nil {
		// path_info is nil on every error, so it must not be inspected.
		return "", err
	}
	if !path_info.IsDir() {
		return full_path, nil
	}

	// An empty path has no last byte to inspect; treat it like a path without
	// a trailing slash.
	index_path := ""
	if path == "" || path[len(path)-1] != '/' {
		index_path = fmt.Sprintf("%s%s/index.html", base, path)
	} else {
		index_path = fmt.Sprintf("%s%sindex.html", base, path)
	}

	index_info, err := os.Stat(index_path)
	if err != nil {
		return "", err
	}
	if index_info.IsDir() {
		return index_path, os.ErrInvalid
	}
	return index_path, nil
}
// StartGQLServer starts the HTTP server for this extension on ext.Listen.
//
// Routes:
//   - /gql       GQLHandler
//   - /gqlws     GQLWSHandler
//   - /graphiql  GraphiQLHandler
//   - /          static files, looked up first under ./custom and then under
//     ./site (see getContentPath)
//
// The server runs in its own goroutine, tracked by ext.http_done. That
// goroutine panics if Serve stops for any reason other than the server being
// shut down. An error is returned when a listener or server is already
// recorded on ext, or when the TCP listener cannot be created.
func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error {
	if ext.tcp_listener != nil || ext.http_server != nil {
		return fmt.Errorf("listener or server is still running, stop them first")
	}

	// write_server_error sends the generic error body used by the static
	// file route.
	write_server_error := func(w http.ResponseWriter) {
		w.WriteHeader(501)
		w.Write([]byte("{\"error\": \"server_error\"}"))
	}

	// serve_file streams the file at file_path to the client. ServeContent
	// writes the status line itself (200, 206, 304, ...), so no explicit
	// WriteHeader call follows it.
	serve_file := func(w http.ResponseWriter, r *http.Request, file_path string) {
		file, err := os.Open(file_path)
		if err != nil {
			ctx.Log.Logf("gql", "File open error: %s", err)
			write_server_error(w)
			// Stop here: ServeContent must not be handed a nil file.
			return
		}
		defer file.Close()
		http.ServeContent(w, r, file_path, time.Time{}, file)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/gql", GQLHandler(ctx, node, ext))
	mux.HandleFunc("/gqlws", GQLWSHandler(ctx, node, ext))
	mux.HandleFunc("/graphiql", GraphiQLHandler())

	// Serve ./custom, falling back to ./site, at / (TODO make configurable with better defaults)
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		custom_path, err_1 := getContentPath("./custom", r.URL.Path)
		if err_1 == nil {
			ctx.Log.Logf("gql", "CUSTOM_FILE: %s", custom_path)
			serve_file(w, r, custom_path)
			return
		}

		static_path, err_2 := getContentPath("./site", r.URL.Path)
		if err_2 != nil {
			ctx.Log.Logf("gql", "File Resolve errors: %s - %s", err_1, err_2)
			write_server_error(w)
			return
		}

		ctx.Log.Logf("gql", "STATIC_FILE: %s", static_path)
		serve_file(w, r, static_path)
	})

	http_server := &http.Server{
		Addr:    ext.Listen,
		Handler: mux,
	}

	l, err := net.Listen("tcp", http_server.Addr)
	if err != nil {
		return fmt.Errorf("Failed to start listener for server on %s: %w", http_server.Addr, err)
	}

	ext.http_done.Add(1)
	go func() {
		defer ext.http_done.Done()
		err := http_server.Serve(l)
		if err != http.ErrServerClosed {
			panic(fmt.Sprintf("Failed to start gql server: %s", err))
		}
	}()

	ext.tcp_listener = l
	ext.http_server = http_server
	return nil
}
// StopGQLServer shuts down the HTTP server started by StartGQLServer and
// waits for its serving goroutine to exit.
//
// It returns an error when no server is recorded as running. If Shutdown
// itself reports an error, the recorded listener and server are still cleared
// (so the extension can be started again) and that error is returned.
func (ext *GQLExt) StopGQLServer() error {
	if ext.tcp_listener == nil || ext.http_server == nil {
		return fmt.Errorf("already shutdown, cannot shut down again")
	}

	// context.TODO() carries no deadline, so Shutdown blocks until it is done.
	shutdown_err := ext.http_server.Shutdown(context.TODO())
	ext.http_done.Wait()

	ext.tcp_listener = nil
	ext.http_server = nil

	if shutdown_err != nil {
		return fmt.Errorf("shutting down gql server: %w", shutdown_err)
	}
	return nil
}