From bfcf06b1907168aeb607aceeaf9ffed3e4f36fb9 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 16 Jun 2023 16:53:29 -0600 Subject: [PATCH] Added basic ws subscription --- gql.go | 242 ++++++++++++++++++++++------- graph.go | 6 +- main.go | 1 + test-site/src/routes/+page.svelte | 42 ++++- test-site/src/routes/Button.svelte | 28 ++++ 5 files changed, 252 insertions(+), 67 deletions(-) create mode 100644 test-site/src/routes/Button.svelte diff --git a/gql.go b/gql.go index 4667871..1572b45 100644 --- a/gql.go +++ b/gql.go @@ -3,6 +3,9 @@ package main import ( "net/http" "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/language/parser" + "github.com/graphql-go/graphql/language/source" + "github.com/graphql-go/graphql/language/ast" "context" "encoding/json" "io" @@ -192,7 +195,90 @@ type GQLWSMsg struct { Payload GQLWSPayload `json:"payload,omitempty"` } +func enableCORS(w *http.ResponseWriter) { + (*w).Header().Set("Access-Control-Allow-Origin", "*") +} + func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r * http.Request) { + log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) + enableCORS(&w) + header_map := map[string]interface{}{} + for header, value := range(r.Header) { + header_map[header] = value + } + log.Logm("gql", header_map, "REQUEST_HEADERS") + + str, err := io.ReadAll(r.Body) + if err != nil { + log.Logf("gql", "failed to read request body: %s", err) + return + } + query := GQLWSPayload{} + json.Unmarshal(str, &query) + + params := graphql.Params{ + Schema: schema, + Context: ctx, + RequestString: query.Query, + } + if query.OperationName != "" { + params.OperationName = query.OperationName + } + if len(query.Variables) > 0 { + params.VariableValues = query.Variables + } + result := graphql.Do(params) + if len(result.Errors) > 0 { + extra_fields := map[string]interface{}{} + extra_fields["body"] = string(str) + extra_fields["headers"] = r.Header + log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors) + } + json.NewEncoder(w).Encode(result) + } +} + +func sendOneResultAndClose(res *graphql.Result) chan *graphql.Result { + resultChannel := make(chan *graphql.Result) + resultChannel <- res + close(resultChannel) + return resultChannel +} + + +func getOperationTypeOfReq(p graphql.Params) string{ + source := source.NewSource(&source.Source{ + Body: []byte(p.RequestString), + Name: "GraphQL request", + }) + + AST, err := parser.Parse(parser.ParseParams{Source: source}) + if err != nil { + return "" + } + + for _, node := range AST.Definitions { + if operationDef, ok := node.(*ast.OperationDefinition); ok { + if operationDef.Name.Value == p.OperationName || p.OperationName == "" { + return operationDef.Operation + } + } + } + return "" +} + +func GQLWSDo(p graphql.Params) chan *graphql.Result { + operation := getOperationTypeOfReq(p) + + if operation == ast.OperationTypeSubscription { + return graphql.Subscribe(p) + } + + return sendOneResultAndClose(graphql.Do(p)) +} + +func GQLWSHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r * http.Request) { log.Logf("gqlws", "HANDLING %s",r.RemoteAddr) header_map := map[string]interface{}{} @@ -211,6 +297,7 @@ func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWr defer conn.Close() conn_state := "init" for { + // TODO: Make this a select between reading client data and getting updates from the event to push to clients" msg_raw, op, err := wsutil.ReadClientData(conn) log.Logf("gqlws", "MSG: %s\nOP: 0x%02x\nERR: %+v\n", string(msg_raw), op, err) msg := GQLWSMsg{} @@ -249,67 +336,57 @@ func GQLHandler(schema graphql.Schema, ctx context.Context) func(http.ResponseWr if len(msg.Payload.Variables) > 0 { params.VariableValues = msg.Payload.Variables } - result := graphql.Do(params) - if len(result.Errors) > 0 { - extra_fields := map[string]interface{}{} - extra_fields["query"] = string(msg.Payload.Query) - log.Logm("gql", extra_fields, "ERROR: wrong result, unexpected errors: %v", result.Errors) - break - } - - - log.Logf("gqlws", "DATA: %+v", result.Data) - data, err := json.Marshal(result.Data) - msg, err := json.Marshal(GQLWSMsg{ - ID: msg.ID, - Type: "next", - Payload: GQLWSPayload{ - Data: string(data), - }, - }) - if err != nil { - log.Logf("gqlws", "ERROR: %+v", err) - break - } - log.Logf("gqlws", "WRITING_GQLWS: %s", msg) - err = wsutil.WriteServerMessage(conn, 1, msg) - if err != nil { - log.Logf("gqlws", "ERROR: %+v", err) - break - } + res_chan := GQLWSDo(params) + + go func(res_chan chan *graphql.Result) { + for { + next, ok := <-res_chan + if ok == false { + log.Logf("gqlws", "response channel was closed") + return + } + if next == nil { + log.Logf("gqlws", "NIL_ON_CHANNEL") + return + } + if len(next.Errors) > 0 { + extra_fields := map[string]interface{}{} + extra_fields["query"] = string(msg.Payload.Query) + log.Logm("gqlws", extra_fields, "ERROR: wrong result, unexpected errors: %+v", next.Errors) + continue + } + log.Logf("gqlws", "DATA: %+v", next.Data) + data, err := json.Marshal(next.Data) + if err != nil { + log.Logf("gqlws", "ERROR: %+v", err) + continue + } + msg, err := json.Marshal(GQLWSMsg{ + ID: msg.ID, + Type: "next", + Payload: GQLWSPayload{ + Data: string(data), + }, + }) + if err != nil { + log.Logf("gqlws", "ERROR: %+v", err) + continue + } + + err = wsutil.WriteServerMessage(conn, 1, msg) + if err != nil { + log.Logf("gqlws", "ERROR: %+v", err) + continue + } + } + }(res_chan) } else { } } return } else { - str, err := io.ReadAll(r.Body) - if err != nil { - log.Logf("gql", "failed to read request body: %s", err) - return - } - query := GQLWSPayload{} - json.Unmarshal(str, &query) - - params := graphql.Params{ - Schema: schema, - Context: ctx, - RequestString: query.Query, - } - if query.OperationName != "" { - params.OperationName = query.OperationName - } - if len(query.Variables) > 0 { - params.VariableValues = query.Variables - } - result := graphql.Do(params) - if len(result.Errors) > 0 { - extra_fields := map[string]interface{}{} - extra_fields["body"] = string(str) - extra_fields["headers"] = r.Header - log.Logm("gql", extra_fields, "wrong result, unexpected errors: %v", result.Errors) - } - json.NewEncoder(w).Encode(result) + panic("Failed to upgrade websocket") } } } @@ -763,6 +840,7 @@ type GQLServer struct { gql_channel chan error extended_types map[reflect.Type]*graphql.Object extended_queries map[string]*graphql.Field + extended_subscriptions map[string]*graphql.Field extended_mutations map[string]*graphql.Field } @@ -785,7 +863,7 @@ func (server * GQLServer) update(signal GraphSignal) { server.BaseResource.update(signal) } -func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) { +func MakeGQLHandlers(server * GQLServer) (func(http.ResponseWriter, *http.Request), func(http.ResponseWriter, *http.Request)) { valid_events := map[reflect.Type]*graphql.Object{} valid_events[reflect.TypeOf((*BaseEvent)(nil))] = GQLTypeBaseEvent() valid_events[reflect.TypeOf((*EventQueue)(nil))] = GQLTypeEventQueue() @@ -807,6 +885,48 @@ func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) { gql_queries[key] = value } + gql_subscriptions := graphql.Fields{ + "Test": &graphql.Field{ + Type: GQLTypeSignal(), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return p.Source, nil + }, + Subscribe: func(p graphql.ResolveParams) (interface{}, error) { + /*c := make(chan interface{}) + go func() { + elements := []string{"a", "b", "c"} + for _, r := range elements { + select { + case <-p.Context.Done(): + close(c) + return + case c <- r: + } + } + close(c) + }() + return c, nil*/ + server, ok := p.Context.Value("gql_server").(*GQLServer) + if ok == false { + return nil, fmt.Errorf("Failed to get gql_server from context and cast") + } + c := make(chan interface{}) + go func(c chan interface{}) { + sig_c := server.UpdateChannel() + for { + val, _ := <- sig_c + c <- val + } + }(c) + return c, nil + }, + }, + } + + for key, value := range(server.extended_subscriptions) { + gql_subscriptions[key] = value + } + gql_mutations := graphql.Fields{ "updateEvent": GQLMutationUpdateEvent(), } @@ -825,6 +945,10 @@ func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) { Name: "Mutation", Fields: gql_mutations, }), + Subscription: graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: gql_subscriptions, + }), } schema, err := graphql.NewSchema(schemaConfig) @@ -835,7 +959,7 @@ func (server * GQLServer) Handler() func(http.ResponseWriter, *http.Request) { ctx = context.WithValue(ctx, "valid_events", valid_events) ctx = context.WithValue(ctx, "valid_resources", valid_resources) ctx = context.WithValue(ctx, "gql_server", server) - return GQLHandler(schema, ctx) + return GQLHandler(schema, ctx), GQLWSHandler(schema, ctx) } var gql_query_owner *graphql.Field = nil @@ -863,7 +987,9 @@ func (server * GQLServer) Init(abort chan error) bool { log.Logf("gql", "GOROUTINE_START for %s", server.ID()) mux := http.NewServeMux() - mux.HandleFunc("/gql", server.Handler()) + http_handler, ws_handler := MakeGQLHandlers(server) + mux.HandleFunc("/gql", http_handler) + mux.HandleFunc("/gqlws", ws_handler) mux.HandleFunc("/", GraphiQLHandler()) srv := &http.Server{ diff --git a/graph.go b/graph.go index e023610..67cdc98 100644 --- a/graph.go +++ b/graph.go @@ -23,7 +23,7 @@ type DefaultLogger struct { } var log DefaultLogger = DefaultLogger{loggers: map[string]zerolog.Logger{}} -var all_components = []string{"update", "graph", "event", "resource", "manager", "test", "gql", "vex", "gqlws"} +var all_components = []string{"update", "graph", "event", "resource", "manager", "test", "gql", "vex", "gqlws", "listeners"} func (logger * DefaultLogger) Init(components []string) error { logger.init_lock.Lock() @@ -227,11 +227,11 @@ func (node * BaseNode) UpdateListeners(update GraphSignal) { closed := []chan GraphSignal{} for _, listener := range node.listeners { - log.Logf("update", "UPDATE_LISTENER %s: %p", node.Name(), listener) + log.Logf("listeners", "UPDATE_LISTENER %s: %p", node.Name(), listener) select { case listener <- update: default: - log.Logf("update", "CLOSED_LISTENER: %s: %p", node.Name(), listener) + log.Logf("listeners", "CLOSED_LISTENER: %s: %p", node.Name(), listener) go func(node GraphNode, listener chan GraphSignal) { listener <- NewSignal(node, "listener_closed") close(listener) diff --git a/main.go b/main.go index 9d795cd..43ab945 100644 --- a/main.go +++ b/main.go @@ -200,6 +200,7 @@ func main() { for true { select { case <-sigs: + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) signal := NewSignal(nil, "abort") signal.description = event_manager.root_event.ID() SendUpdate(event_manager.root_event, signal) diff --git a/test-site/src/routes/+page.svelte b/test-site/src/routes/+page.svelte index 9f06ceb..6e1f420 100644 --- a/test-site/src/routes/+page.svelte +++ b/test-site/src/routes/+page.svelte @@ -1,33 +1,63 @@ -

+ diff --git a/test-site/src/routes/Button.svelte b/test-site/src/routes/Button.svelte new file mode 100644 index 0000000..9cf9dd9 --- /dev/null +++ b/test-site/src/routes/Button.svelte @@ -0,0 +1,28 @@ + + + +