diff --git a/go.mod b/go.mod index 2b2dcd2..3dcc673 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/dgraph-io/badger/v3 v3.2103.5 github.com/gobwas/ws v1.2.1 github.com/google/uuid v1.3.0 + github.com/graphql-go/graphql v0.8.1 github.com/mekkanized/graphvent/signal v0.0.0 github.com/rs/zerolog v1.29.1 golang.org/x/net v0.7.0 @@ -27,7 +28,6 @@ require ( github.com/golang/protobuf v1.3.1 // indirect github.com/golang/snappy v0.0.3 // indirect github.com/google/flatbuffers v1.12.1 // indirect - github.com/graphql-go/graphql v0.8.1 // indirect github.com/klauspost/compress v1.12.3 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect diff --git a/gql.go b/gql.go index a74560a..4073868 100644 --- a/gql.go +++ b/gql.go @@ -42,7 +42,7 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy return false } - node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType] + node_type_def, exists := ctx.Context.Nodes[node.NodeType] if exists == false { return false } else { @@ -76,10 +76,10 @@ func NodeInterfaceResolveType(required_extensions []ExtType, default_type **grap return nil } - gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType] + gql_type, exists := ctx.GQLContext.NodeTypes[node.NodeType] ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type) if exists == false { - node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType] + node_type_def, exists := ctx.Context.Nodes[node.NodeType] if exists == false { return nil } else { @@ -175,57 +175,57 @@ func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { - GraphiQL - - - - - - + GraphiQL + + + + + + -
Loading...
- - +
Loading...
+ + `) @@ -252,10 +252,10 @@ type GQLWSMsg struct { } func enableCORS(w *http.ResponseWriter) { - (*w).Header().Set("Access-Control-Allow-Origin", "*") - (*w).Header().Set("Access-Control-Allow-Credentials", "true") - (*w).Header().Set("Access-Control-Allow-Headers", "*") - (*w).Header().Set("Access-Control-Allow-Methods", "*") + (*w).Header().Set("Access-Control-Allow-Origin", "*") + (*w).Header().Set("Access-Control-Allow-Credentials", "true") + (*w).Header().Set("Access-Control-Allow-Headers", "*") + (*w).Header().Set("Access-Control-Allow-Methods", "*") } type GQLUnauthorized string @@ -317,7 +317,7 @@ type ResolveContext struct { User NodeID // Cache of resolved nodes - NodeCache map[NodeID]interface{} + NodeCache map[NodeID]NodeResult // Key for the user that made this request, to sign resolver requests // TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards @@ -369,6 +369,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ Chans: map[uuid.UUID]chan Signal{}, Context: ctx, GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext), + NodeCache: map[NodeID]NodeResult{}, Server: server, User: key_id, Key: key, @@ -400,7 +401,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW ctx.Log.Logf("gql", "GQL_READ_ERR: %s", err) json.NewEncoder(w).Encode(fmt.Sprintf("%e", err)) return - } + } query := GQLPayload{} json.Unmarshal(str, &query) @@ -649,17 +650,17 @@ func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[Ext case "ID": case "TypeHash": default: - field, exists := ctx.Fields[name] - if exists == false { - return nil, fmt.Errorf("%s is not a know field in GQLContext, cannot resolve", name) - } - - ext, exists := ext_fields[field.Ext] - if exists == false { - ext = []string{} - } - ext = append(ext, field.Name) - ext_fields[field.Ext] = ext + field, exists := ctx.Fields[name] + if exists == false { + return nil, fmt.Errorf("%s is not a know field in GQLContext, cannot resolve", name) + } + + ext, exists := ext_fields[field.Ext] + if exists == false { + ext = []string{} + } + ext = append(ext, field.Name) + ext_fields[field.Ext] = ext } } @@ -703,7 +704,7 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, return nil, fmt.Errorf("p.Value is not NodeResult") } - ext, ext_exists := node.Result.Extensions[ext_type] + ext, ext_exists := node.Data[ext_type] if ext_exists == false { return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type) } @@ -726,6 +727,8 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, return nil, fmt.Errorf("%s returned a nil value of %+v type", gv_tag, field_type) } + ctx.Context.Log.Logf("gql", "Resolving %+v", field_value) + return resolve_fn(p, ctx, *field_value) } @@ -780,8 +783,9 @@ func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtT } type NodeResult struct { - ID NodeID - Result *ReadResultSignal + NodeID NodeID + NodeType NodeType + Data map[ExtType]map[string]SerializedValue } type ListField struct { @@ -926,7 +930,7 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, inte return false } - return node.Result.NodeType == node_type + return node.NodeType == node_type }, Fields: gql_fields, }) @@ -971,25 +975,25 @@ func NewGQLExtContext() *GQLExtContext { } err = context.RegisterField(context.Interfaces["Node"].List, "Members", GroupExtType, "members", - func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { - node_map, ok := value.Interface().(map[NodeID]string) - if ok == false { - return nil, fmt.Errorf("value is %+v, not map[NodeID]string", value.Type()) - } - node_list := []NodeID{} - i := 0 - for id := range(node_map) { - node_list = append(node_list, id) - i += 1 - } - - nodes, err := ResolveNodes(ctx, p, node_list) - if err != nil { - return nil, err - } - - return nodes, nil - }) + func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { + node_map, ok := value.Interface().(map[NodeID]string) + if ok == false { + return nil, fmt.Errorf("value is %+v, not map[NodeID]string", value.Type()) + } + node_list := []NodeID{} + i := 0 + for id := range(node_map) { + node_list = append(node_list, id) + i += 1 + } + + nodes, err := ResolveNodes(ctx, p, node_list) + if err != nil { + return nil, err + } + + return nodes, nil + }) if err != nil { panic(err) } @@ -1069,7 +1073,7 @@ func NewGQLExtContext() *GQLExtContext { return nil, fmt.Errorf("wrong length of nodes returned") } - ctx.Context.Log.Logf("gql", "NODES: %+v", nodes[0].Result) + ctx.Context.Log.Logf("gql", "NODES: %+v", nodes[0]) c <- nodes[0] return c, nil @@ -1079,17 +1083,17 @@ func NewGQLExtContext() *GQLExtContext { if err != nil { return nil, err } - var self_result NodeResult + switch source := p.Source.(type) { case NodeResult: - self_result = source - ctx.Context.Log.Logf("gql_subscribe", "SUBSCRIBE_FIRST_RESULT: %+v", self_result) case StatusSignal: + delete(ctx.NodeCache, source.Source) + ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source) default: return nil, fmt.Errorf("Don't know how to handle %+v", source) } - return self_result, nil + return ctx.NodeCache[ctx.Server.ID], nil }, }) @@ -1352,7 +1356,7 @@ func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error { err := http_server.Serve(listener) if err != http.ErrServerClosed { - panic(fmt.Sprintf("Failed to start gql server: %s", err)) + panic(fmt.Sprintf("Failed to start gql server: %s", err)) } }(ext) diff --git a/gql_node.go b/gql_node.go index 9c0d8ee..4849dcb 100644 --- a/gql_node.go +++ b/gql_node.go @@ -14,7 +14,7 @@ func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source)) } - return node.ID, nil + return node.NodeID, nil } func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { @@ -23,7 +23,7 @@ func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source)) } - return uint64(node.Result.NodeType), nil + return uint64(node.NodeType), nil } func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string { @@ -60,21 +60,60 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ ctx.Context.Log.Logf("gql", "RESOLVE_NODES(%+v): %+v", ids, fields) resp_channels := map[uuid.UUID]chan Signal{} - node_ids := map[uuid.UUID]NodeID{} - for _, id := range(ids) { - // Get a list of fields that will be written - ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) - if err != nil { - return nil, err + indices := map[uuid.UUID]int{} + + // Get a list of fields that will be written + ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) + if err != nil { + return nil, err + } + + responses := make([]NodeResult, len(ids)) + + for i, id := range(ids) { + var read_signal *ReadSignal = nil + + node, cached := ctx.NodeCache[id] + if cached == true { + resolve := false + missing_exts := map[ExtType][]string{} + for ext_type, fields := range(ext_fields) { + cached_ext, exists := node.Data[ext_type] + if exists == true { + missing_fields := []string{} + for _, field_name := range(fields) { + _, found := cached_ext[field_name] + if found == false { + missing_fields = append(missing_fields, field_name) + } + } + if len(missing_fields) > 0 { + missing_exts[ext_type] = missing_fields + resolve = true + } + } else { + missing_exts[ext_type] = fields + resolve = true + } + } + + if resolve == true { + read_signal = NewReadSignal(missing_exts) + } else { + ctx.Context.Log.Logf("gql_subscribe", "Using cached response for %+v(%d)", id, i) + responses[i] = node + continue + } + } else { + read_signal = NewReadSignal(ext_fields) } // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error - read_signal := NewReadSignal(ext_fields) msgs := Messages{} msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID) resp_channels[read_signal.ID] = response_chan - node_ids[read_signal.ID] = id + indices[read_signal.ID] = i // TODO: Send all at once instead of creating Messages for each err = ctx.Context.Send(msgs) @@ -84,7 +123,6 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ } } - responses := []NodeResult{} for sig_id, response_chan := range(resp_channels) { // Wait for the response, returning an error on timeout response, err := WaitForSignal(response_chan, time.Millisecond*100, func(sig *ReadResultSignal)bool{ @@ -93,9 +131,32 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ if err != nil { return nil, err } - responses = append(responses, NodeResult{node_ids[sig_id], response}) + + idx := indices[sig_id] + responses[idx] = NodeResult{ + response.NodeID, + response.NodeType, + response.Extensions, + } + + cache, exists := ctx.NodeCache[response.NodeID] + if exists == true { + for ext_type, fields := range(response.Extensions) { + cached_fields, exists := cache.Data[ext_type] + if exists == true { + for field_name, field_value := range(fields) { + cached_fields[field_name] = field_value + } + } + } + ctx.Context.Log.Logf("gql_subscribe", "CACHED_EXISTING_NODE: %+v", response.NodeID) + } else { + ctx.NodeCache[response.NodeID] = responses[idx] + ctx.Context.Log.Logf("gql_subscribe", "CACHED_NEW_NODE: %+v", response.NodeID) + } + } - ctx.Context.Log.Logf("gql", "RESOLVED_NODES") + ctx.Context.Log.Logf("gql", "RESOLVED_NODES %+v - %+v", ids, responses) return responses, nil }