Added basic ws subscription

graph-rework
noah metz 2023-06-16 16:53:29 -06:00
parent 5f5916010d
commit bfcf06b190
5 changed files with 252 additions and 67 deletions

208
gql.go

@ -3,6 +3,9 @@ package main
import ( import (
"net/http" "net/http"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/parser"
"github.com/graphql-go/graphql/language/source"
"github.com/graphql-go/graphql/language/ast"
"context" "context"
"encoding/json" "encoding/json"
"io" "io"
@ -192,7 +195,90 @@ type GQLWSMsg struct {
Payload GQLWSPayload `json:"payload,omitempty"` Payload GQLWSPayload `json:"payload,omitempty"`
} }
func enableCORS(w *http.ResponseWriter) {
(*w).Header().Set("Access-Control-Allow-Origin", "*")
}
func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) { func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) {
log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
enableCORS(&w)
header_map := map[string]interface{}{}
for header, value := range(r.Header) {
header_map[header] = value
}
log.Logm("gql", header_map, "REQUEST_HEADERS")
str, err := io.ReadAll(r.Body)
if err != nil {
log.Logf("gql", "failed to read request body: %s", err)
return
}
query := GQLWSPayload{}
json.Unmarshal(str, &query)
params := graphql.Params{
Schema: schema,
Context: ctx,
RequestString: query.Query,
}
if query.OperationName != "" {
params.OperationName = query.OperationName
}
if len(query.Variables) > 0 {
params.VariableValues = query.Variables
}
result := graphql.Do(params)
if len(result.Errors) > 0 {
extra_fields := map[string]interface{}{}
extra_fields["body"] = string(str)
extra_fields["headers"] = r.Header
log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors)
}
json.NewEncoder(w).Encode(result)
}
}
func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result {
resultChannel := make(chan *graphql.Result)
resultChannel <- res
close(resultChannel)
return resultChannel
}
func getOperationTypeOfReq(p graphql.Params) string{
source := source.NewSource(&source.Source{
Body: []byte(p.RequestString),
Name: "GraphQL request",
})
AST, err := parser.Parse(parser.ParseParams{Source: source})
if err != nil {
return ""
}
for _, node := range AST.Definitions {
if operationDef, ok := node.(*ast.OperationDefinition); ok {
if operationDef.Name.Value == p.OperationName || p.OperationName == "" {
return operationDef.Operation
}
}
}
return ""
}
func GQLWSDo(p graphql.Params) chan *graphql.Result {
operation := getOperationTypeOfReq(p)
if operation == ast.OperationTypeSubscription {
return graphql.Subscribe(p)
}
return sendOneResultAndClose(graphql.Do(p))
}
func GQLWSHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
log.Logf("gqlws", "HANDLING %s",r.RemoteAddr) log.Logf("gqlws", "HANDLING %s",r.RemoteAddr)
header_map := map[string]interface{}{} header_map := map[string]interface{}{}
@ -211,6 +297,7 @@ func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWr
defer conn.Close() defer conn.Close()
conn_state := "init" conn_state := "init"
for { for {
// TODO: Make this a select between reading client data and getting updates from the event to push to clients"
msg_raw, op, err := wsutil.ReadClientData(conn) msg_raw, op, err := wsutil.ReadClientData(conn)
log.Logf("gqlws", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err) log.Logf("gqlws", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err)
msg := GQLWSMsg{} msg := GQLWSMsg{}
@ -249,18 +336,32 @@ func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWr
if len(msg.Payload.Variables) > 0 { if len(msg.Payload.Variables) > 0 {
params.VariableValues = msg.Payload.Variables params.VariableValues = msg.Payload.Variables
} }
result := graphql.Do(params)
if len(result.Errors) > 0 { res_chan := GQLWSDo(params)
go func(res_chan chan *graphql.Result) {
for {
next, ok := <-res_chan
if ok == false {
log.Logf("gqlws", "response channel was closed")
return
}
if next == nil {
log.Logf("gqlws", "NIL_ON_CHANNEL")
return
}
if len(next.Errors) > 0 {
extra_fields := map[string]interface{}{} extra_fields := map[string]interface{}{}
extra_fields["query"] = string(msg.Payload.Query) extra_fields["query"] = string(msg.Payload.Query)
log.Logm("gql", extra_fields, "ERROR: wrong result, unexpected errors: %v", result.Errors) log.Logm("gqlws", extra_fields, "ERROR: wrong result, unexpected errors: %+v", next.Errors)
break continue
}
log.Logf("gqlws", "DATA: %+v", next.Data)
data, err := json.Marshal(next.Data)
if err != nil {
log.Logf("gqlws", "ERROR: %+v", err)
continue
} }
log.Logf("gqlws", "DATA: %+v", result.Data)
data, err := json.Marshal(result.Data)
msg, err := json.Marshal(GQLWSMsg{ msg, err := json.Marshal(GQLWSMsg{
ID: msg.ID, ID: msg.ID,
Type: "next", Type: "next",
@ -270,46 +371,22 @@ func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWr
}) })
if err != nil { if err != nil {
log.Logf("gqlws", "ERROR: %+v", err) log.Logf("gqlws", "ERROR: %+v", err)
break continue
} }
log.Logf("gqlws", "WRITING_GQLWS: %s", msg)
err = wsutil.WriteServerMessage(conn, 1, msg) err = wsutil.WriteServerMessage(conn, 1, msg)
if err != nil { if err != nil {
log.Logf("gqlws", "ERROR: %+v", err) log.Logf("gqlws", "ERROR: %+v", err)
break continue
} }
}
}(res_chan)
} else { } else {
} }
} }
return return
} else { } else {
str, err := io.ReadAll(r.Body) panic("Failed to upgrade websocket")
if err != nil {
log.Logf("gql", "failed to read request body: %s", err)
return
}
query := GQLWSPayload{}
json.Unmarshal(str, &query)
params := graphql.Params{
Schema: schema,
Context: ctx,
RequestString: query.Query,
}
if query.OperationName != "" {
params.OperationName = query.OperationName
}
if len(query.Variables) > 0 {
params.VariableValues = query.Variables
}
result := graphql.Do(params)
if len(result.Errors) > 0 {
extra_fields := map[string]interface{}{}
extra_fields["body"] = string(str)
extra_fields["headers"] = r.Header
log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors)
}
json.NewEncoder(w).Encode(result)
} }
} }
} }
@ -763,6 +840,7 @@ type GQLServer struct {
gql_channel chan error gql_channel chan error
extended_types map[reflect.Type]*graphql.Object extended_types map[reflect.Type]*graphql.Object
extended_queries map[string]*graphql.Field extended_queries map[string]*graphql.Field
extended_subscriptions map[string]*graphql.Field
extended_mutations map[string]*graphql.Field extended_mutations map[string]*graphql.Field
} }
@ -785,7 +863,7 @@ func (server * GQLServer) update(signal GraphSignal) {
server.BaseResource.update(signal) server.BaseResource.update(signal)
} }
func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) { func MakeGQLHandlers(server * GQLServer) (func(http.ResponseWriter, *http.Request), func(http.ResponseWriter, *http.Request)) {
valid_events := map[reflect.Type]*graphql.Object{} valid_events := map[reflect.Type]*graphql.Object{}
valid_events[reflect.TypeOf((*BaseEvent)(nil))] = GQLTypeBaseEvent() valid_events[reflect.TypeOf((*BaseEvent)(nil))] = GQLTypeBaseEvent()
valid_events[reflect.TypeOf((*EventQueue)(nil))] = GQLTypeEventQueue() valid_events[reflect.TypeOf((*EventQueue)(nil))] = GQLTypeEventQueue()
@ -807,6 +885,48 @@ func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) {
gql_queries[key] = value gql_queries[key] = value
} }
gql_subscriptions := graphql.Fields{
"Test": &graphql.Field{
Type: GQLTypeSignal(),
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return p.Source, nil
},
Subscribe: func(p graphql.ResolveParams) (interface{}, error) {
/*c := make(chan interface{})
go func() {
elements := []string{"a", "b", "c"}
for _, r := range elements {
select {
case <-p.Context.Done():
close(c)
return
case c <- r:
}
}
close(c)
}()
return c, nil*/
server, ok := p.Context.Value("gql_server").(*GQLServer)
if ok == false {
return nil, fmt.Errorf("Failed to get gql_server from context and cast")
}
c := make(chan interface{})
go func(c chan interface{}) {
sig_c := server.UpdateChannel()
for {
val, _ := <- sig_c
c <- val
}
}(c)
return c, nil
},
},
}
for key, value := range(server.extended_subscriptions) {
gql_subscriptions[key] = value
}
gql_mutations := graphql.Fields{ gql_mutations := graphql.Fields{
"updateEvent": GQLMutationUpdateEvent(), "updateEvent": GQLMutationUpdateEvent(),
} }
@ -825,6 +945,10 @@ func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) {
Name: "Mutation", Name: "Mutation",
Fields: gql_mutations, Fields: gql_mutations,
}), }),
Subscription: graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: gql_subscriptions,
}),
} }
schema, err := graphql.NewSchema(schemaConfig) schema, err := graphql.NewSchema(schemaConfig)
@ -835,7 +959,7 @@ func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) {
ctx = context.WithValue(ctx, "valid_events", valid_events) ctx = context.WithValue(ctx, "valid_events", valid_events)
ctx = context.WithValue(ctx, "valid_resources", valid_resources) ctx = context.WithValue(ctx, "valid_resources", valid_resources)
ctx = context.WithValue(ctx, "gql_server", server) ctx = context.WithValue(ctx, "gql_server", server)
return GQLHandler(schema, ctx) return GQLHandler(schema, ctx), GQLWSHandler(schema, ctx)
} }
var gql_query_owner *graphql.Field = nil var gql_query_owner *graphql.Field = nil
@ -863,7 +987,9 @@ func (server * GQLServer) Init(abort chan error) bool {
log.Logf("gql", "GOROUTINE_START for %s", server.ID()) log.Logf("gql", "GOROUTINE_START for %s", server.ID())
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/gql", server.Handler()) http_handler, ws_handler := MakeGQLHandlers(server)
mux.HandleFunc("/gql", http_handler)
mux.HandleFunc("/gqlws", ws_handler)
mux.HandleFunc("/", GraphiQLHandler()) mux.HandleFunc("/", GraphiQLHandler())
srv := &http.Server{ srv := &http.Server{

@ -23,7 +23,7 @@ type DefaultLogger struct {
} }
var log DefaultLogger = DefaultLogger{loggers: map[string]zerolog.Logger{}} var log DefaultLogger = DefaultLogger{loggers: map[string]zerolog.Logger{}}
var all_components = []string{"update", "graph", "event", "resource", "manager", "test", "gql", "vex", "gqlws"} var all_components = []string{"update", "graph", "event", "resource", "manager", "test", "gql", "vex", "gqlws", "listeners"}
func (logger * DefaultLogger) Init(components []string) error { func (logger * DefaultLogger) Init(components []string) error {
logger.init_lock.Lock() logger.init_lock.Lock()
@ -227,11 +227,11 @@ func (node * BaseNode) UpdateListeners(update GraphSignal) {
closed := []chan GraphSignal{} closed := []chan GraphSignal{}
for _, listener := range node.listeners { for _, listener := range node.listeners {
log.Logf("update", "UPDATE_LISTENER %s: %p", node.Name(), listener) log.Logf("listeners", "UPDATE_LISTENER %s: %p", node.Name(), listener)
select { select {
case listener <- update: case listener <- update:
default: default:
log.Logf("update", "CLOSED_LISTENER: %s: %p", node.Name(), listener) log.Logf("listeners", "CLOSED_LISTENER: %s: %p", node.Name(), listener)
go func(node GraphNode, listener chan GraphSignal) { go func(node GraphNode, listener chan GraphSignal) {
listener <- NewSignal(node, "listener_closed") listener <- NewSignal(node, "listener_closed")
close(listener) close(listener)

@ -200,6 +200,7 @@ func main() {
for true { for true {
select { select {
case <-sigs: case <-sigs:
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
signal := NewSignal(nil, "abort") signal := NewSignal(nil, "abort")
signal.description = event_manager.root_event.ID() signal.description = event_manager.root_event.ID()
SendUpdate(event_manager.root_event, signal) SendUpdate(event_manager.root_event, signal)

@ -1,33 +1,63 @@
<script lang="ts"> <script lang="ts">
import Button from "./Button.svelte"
import { createClient } from 'graphql-ws'; import { createClient } from 'graphql-ws';
import { WebSocket } from 'ws'; import { WebSocket } from 'ws';
const client = createClient({ const client = createClient({
url: "ws://localhost:8080/gql", url: "ws://localhost:8080/gqlws",
webSocketImpl: WebSocket, webSocketImpl: WebSocket,
keepAlive: 10_000, keepAlive: 10_000,
}); });
var game_id = null
console.log("STARTING_CLIENT") console.log("STARTING_CLIENT")
client.subscribe({ client.subscribe({
query: "{ Arenas { Name Owner { ... on Match { Name, ID } } } }", operationName: "Sub",
query: "query GetArenas { Arenas { Name Owner { ... on Match { Name, ID } } } } subscription Sub { Test { String } }",
}, },
{ {
next: (data) => { next: (data) => {
console.log("NEXT") console.log("NEXT")
let r = JSON.parse(data.data) console.log(data)
let game_id = r.Arenas[0].Owner.ID
console.log(game_id)
}, },
error: (err) => { error: (err) => {
console.log("ERROR") console.log("ERROR")
console.log(err)
}, },
complete: () => { complete: () => {
console.log("COMPLETED") console.log("COMPLETED")
}, },
});
async function match_state(match_id, state) {
let url = "http://localhost:8080/gql"
let data = {
operationName: "MatchState",
query: "mutation MatchState($match_id:String, $match_state:String) { setMatchState(id:$match_id, state:$match_state) { String } }",
variables: {
match_id: match_id,
match_state: state,
}
}
const response = await fetch(url, {
method: "POST",
mode: "same-origin",
cache: "no-cache",
credentials: "include",
headers: {
"Content-Type": "applicaton/json",
},
redirect: "follow",
referrerPolicy: "no-referrer",
body: JSON.stringify(data),
}); });
console.log(response.json())
}
</script> </script>
<h1></h1> <Button on:click={()=>match_state("eafd0201-caa4-4b35-a99b-869aac9455fa", "queue_autonomous")}>Queue Autonomous</Button>

@ -0,0 +1,28 @@
<script>
let buttonProps = {
class:[$$restProps.class]
}
</script>
<button on:click
on:mouseover
on:mouseenter
on:mouseleave
{...buttonProps}>
<slot/>
</button>
<style>
.primary{
color:green;
}
.danger {
color:red;
}
.sm {
font-size:1em;
padding:0.1em;
}
.lg {
font-size:2em
}
</style>