diff --git a/go.mod b/go.mod
index 2b2dcd2..3dcc673 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/dgraph-io/badger/v3 v3.2103.5
github.com/gobwas/ws v1.2.1
github.com/google/uuid v1.3.0
+ github.com/graphql-go/graphql v0.8.1
github.com/mekkanized/graphvent/signal v0.0.0
github.com/rs/zerolog v1.29.1
golang.org/x/net v0.7.0
@@ -27,7 +28,6 @@ require (
github.com/golang/protobuf v1.3.1 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/google/flatbuffers v1.12.1 // indirect
- github.com/graphql-go/graphql v0.8.1 // indirect
github.com/klauspost/compress v1.12.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
diff --git a/gql.go b/gql.go
index a74560a..4073868 100644
--- a/gql.go
+++ b/gql.go
@@ -42,7 +42,7 @@ func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTy
return false
}
- node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType]
+ node_type_def, exists := ctx.Context.Nodes[node.NodeType]
if exists == false {
return false
} else {
@@ -76,10 +76,10 @@ func NodeInterfaceResolveType(required_extensions []ExtType, default_type **grap
return nil
}
- gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType]
+ gql_type, exists := ctx.GQLContext.NodeTypes[node.NodeType]
ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type)
if exists == false {
- node_type_def, exists := ctx.Context.Nodes[node.Result.NodeType]
+ node_type_def, exists := ctx.Context.Nodes[node.NodeType]
if exists == false {
return nil
} else {
@@ -175,57 +175,57 @@ func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
- GraphiQL
-
-
-
-
-
-
+ GraphiQL
+
+
+
+
+
+
- Loading...
-
-
+ Loading...
+
+
`)
@@ -252,10 +252,10 @@ type GQLWSMsg struct {
}
func enableCORS(w *http.ResponseWriter) {
- (*w).Header().Set("Access-Control-Allow-Origin", "*")
- (*w).Header().Set("Access-Control-Allow-Credentials", "true")
- (*w).Header().Set("Access-Control-Allow-Headers", "*")
- (*w).Header().Set("Access-Control-Allow-Methods", "*")
+ (*w).Header().Set("Access-Control-Allow-Origin", "*")
+ (*w).Header().Set("Access-Control-Allow-Credentials", "true")
+ (*w).Header().Set("Access-Control-Allow-Headers", "*")
+ (*w).Header().Set("Access-Control-Allow-Methods", "*")
}
type GQLUnauthorized string
@@ -317,7 +317,7 @@ type ResolveContext struct {
User NodeID
// Cache of resolved nodes
- NodeCache map[NodeID]interface{}
+ NodeCache map[NodeID]NodeResult
// Key for the user that made this request, to sign resolver requests
// TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards
@@ -369,6 +369,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
Chans: map[uuid.UUID]chan Signal{},
Context: ctx,
GQLContext: ctx.Extensions[GQLExtType].Data.(*GQLExtContext),
+ NodeCache: map[NodeID]NodeResult{},
Server: server,
User: key_id,
Key: key,
@@ -400,7 +401,7 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
ctx.Log.Logf("gql", "GQL_READ_ERR: %s", err)
json.NewEncoder(w).Encode(fmt.Sprintf("%e", err))
return
- }
+ }
query := GQLPayload{}
json.Unmarshal(str, &query)
@@ -649,17 +650,17 @@ func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[Ext
case "ID":
case "TypeHash":
default:
- field, exists := ctx.Fields[name]
- if exists == false {
- return nil, fmt.Errorf("%s is not a know field in GQLContext, cannot resolve", name)
- }
-
- ext, exists := ext_fields[field.Ext]
- if exists == false {
- ext = []string{}
- }
- ext = append(ext, field.Name)
- ext_fields[field.Ext] = ext
+ field, exists := ctx.Fields[name]
+ if exists == false {
+ return nil, fmt.Errorf("%s is not a know field in GQLContext, cannot resolve", name)
+ }
+
+ ext, exists := ext_fields[field.Ext]
+ if exists == false {
+ ext = []string{}
+ }
+ ext = append(ext, field.Name)
+ ext_fields[field.Ext] = ext
}
}
@@ -703,7 +704,7 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string,
return nil, fmt.Errorf("p.Value is not NodeResult")
}
- ext, ext_exists := node.Result.Extensions[ext_type]
+ ext, ext_exists := node.Data[ext_type]
if ext_exists == false {
return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type)
}
@@ -726,6 +727,8 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string,
return nil, fmt.Errorf("%s returned a nil value of %+v type", gv_tag, field_type)
}
+ ctx.Context.Log.Logf("gql", "Resolving %+v", field_value)
+
return resolve_fn(p, ctx, *field_value)
}
@@ -780,8 +783,9 @@ func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtT
}
type NodeResult struct {
- ID NodeID
- Result *ReadResultSignal
+ NodeID NodeID
+ NodeType NodeType
+ Data map[ExtType]map[string]SerializedValue
}
type ListField struct {
@@ -926,7 +930,7 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, inte
return false
}
- return node.Result.NodeType == node_type
+ return node.NodeType == node_type
},
Fields: gql_fields,
})
@@ -971,25 +975,25 @@ func NewGQLExtContext() *GQLExtContext {
}
err = context.RegisterField(context.Interfaces["Node"].List, "Members", GroupExtType, "members",
- func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
- node_map, ok := value.Interface().(map[NodeID]string)
- if ok == false {
- return nil, fmt.Errorf("value is %+v, not map[NodeID]string", value.Type())
- }
- node_list := []NodeID{}
- i := 0
- for id := range(node_map) {
- node_list = append(node_list, id)
- i += 1
- }
-
- nodes, err := ResolveNodes(ctx, p, node_list)
- if err != nil {
- return nil, err
- }
-
- return nodes, nil
- })
+ func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
+ node_map, ok := value.Interface().(map[NodeID]string)
+ if ok == false {
+ return nil, fmt.Errorf("value is %+v, not map[NodeID]string", value.Type())
+ }
+ node_list := []NodeID{}
+ i := 0
+ for id := range(node_map) {
+ node_list = append(node_list, id)
+ i += 1
+ }
+
+ nodes, err := ResolveNodes(ctx, p, node_list)
+ if err != nil {
+ return nil, err
+ }
+
+ return nodes, nil
+ })
if err != nil {
panic(err)
}
@@ -1069,7 +1073,7 @@ func NewGQLExtContext() *GQLExtContext {
return nil, fmt.Errorf("wrong length of nodes returned")
}
- ctx.Context.Log.Logf("gql", "NODES: %+v", nodes[0].Result)
+ ctx.Context.Log.Logf("gql", "NODES: %+v", nodes[0])
c <- nodes[0]
return c, nil
@@ -1079,17 +1083,17 @@ func NewGQLExtContext() *GQLExtContext {
if err != nil {
return nil, err
}
- var self_result NodeResult
+
switch source := p.Source.(type) {
case NodeResult:
- self_result = source
- ctx.Context.Log.Logf("gql_subscribe", "SUBSCRIBE_FIRST_RESULT: %+v", self_result)
case StatusSignal:
+ delete(ctx.NodeCache, source.Source)
+ ctx.Context.Log.Logf("gql_subscribe", "Deleting %+v from NodeCache", source.Source)
default:
return nil, fmt.Errorf("Don't know how to handle %+v", source)
}
- return self_result, nil
+ return ctx.NodeCache[ctx.Server.ID], nil
},
})
@@ -1352,7 +1356,7 @@ func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error {
err := http_server.Serve(listener)
if err != http.ErrServerClosed {
- panic(fmt.Sprintf("Failed to start gql server: %s", err))
+ panic(fmt.Sprintf("Failed to start gql server: %s", err))
}
}(ext)
diff --git a/gql_node.go b/gql_node.go
index 9c0d8ee..4849dcb 100644
--- a/gql_node.go
+++ b/gql_node.go
@@ -14,7 +14,7 @@ func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source))
}
- return node.ID, nil
+ return node.NodeID, nil
}
func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
@@ -23,7 +23,7 @@ func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source))
}
- return uint64(node.Result.NodeType), nil
+ return uint64(node.NodeType), nil
}
func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string {
@@ -60,21 +60,60 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
ctx.Context.Log.Logf("gql", "RESOLVE_NODES(%+v): %+v", ids, fields)
resp_channels := map[uuid.UUID]chan Signal{}
- node_ids := map[uuid.UUID]NodeID{}
- for _, id := range(ids) {
- // Get a list of fields that will be written
- ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields)
- if err != nil {
- return nil, err
+ indices := map[uuid.UUID]int{}
+
+ // Get a list of fields that will be written
+ ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields)
+ if err != nil {
+ return nil, err
+ }
+
+ responses := make([]NodeResult, len(ids))
+
+ for i, id := range(ids) {
+ var read_signal *ReadSignal = nil
+
+ node, cached := ctx.NodeCache[id]
+ if cached == true {
+ resolve := false
+ missing_exts := map[ExtType][]string{}
+ for ext_type, fields := range(ext_fields) {
+ cached_ext, exists := node.Data[ext_type]
+ if exists == true {
+ missing_fields := []string{}
+ for _, field_name := range(fields) {
+ _, found := cached_ext[field_name]
+ if found == false {
+ missing_fields = append(missing_fields, field_name)
+ }
+ }
+ if len(missing_fields) > 0 {
+ missing_exts[ext_type] = missing_fields
+ resolve = true
+ }
+ } else {
+ missing_exts[ext_type] = fields
+ resolve = true
+ }
+ }
+
+ if resolve == true {
+ read_signal = NewReadSignal(missing_exts)
+ } else {
+ ctx.Context.Log.Logf("gql_subscribe", "Using cached response for %+v(%d)", id, i)
+ responses[i] = node
+ continue
+ }
+ } else {
+ read_signal = NewReadSignal(ext_fields)
}
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
- read_signal := NewReadSignal(ext_fields)
msgs := Messages{}
msgs = msgs.Add(ctx.Context, ctx.Server.ID, ctx.Key, read_signal, id)
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID)
resp_channels[read_signal.ID] = response_chan
- node_ids[read_signal.ID] = id
+ indices[read_signal.ID] = i
// TODO: Send all at once instead of creating Messages for each
err = ctx.Context.Send(msgs)
@@ -84,7 +123,6 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
}
}
- responses := []NodeResult{}
for sig_id, response_chan := range(resp_channels) {
// Wait for the response, returning an error on timeout
response, err := WaitForSignal(response_chan, time.Millisecond*100, func(sig *ReadResultSignal)bool{
@@ -93,9 +131,32 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
if err != nil {
return nil, err
}
- responses = append(responses, NodeResult{node_ids[sig_id], response})
+
+ idx := indices[sig_id]
+ responses[idx] = NodeResult{
+ response.NodeID,
+ response.NodeType,
+ response.Extensions,
+ }
+
+ cache, exists := ctx.NodeCache[response.NodeID]
+ if exists == true {
+ for ext_type, fields := range(response.Extensions) {
+ cached_fields, exists := cache.Data[ext_type]
+ if exists == true {
+ for field_name, field_value := range(fields) {
+ cached_fields[field_name] = field_value
+ }
+ }
+ }
+ ctx.Context.Log.Logf("gql_subscribe", "CACHED_EXISTING_NODE: %+v", response.NodeID)
+ } else {
+ ctx.NodeCache[response.NodeID] = responses[idx]
+ ctx.Context.Log.Logf("gql_subscribe", "CACHED_NEW_NODE: %+v", response.NodeID)
+ }
+
}
- ctx.Context.Log.Logf("gql", "RESOLVED_NODES")
+ ctx.Context.Log.Logf("gql", "RESOLVED_NODES %+v - %+v", ids, responses)
return responses, nil
}