From 1a3a07336a6c16f9b50406d41a99c67e165674ec Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 29 Jul 2023 16:00:01 -0600 Subject: [PATCH] Added back GQL functionality, TODO pool channels for resolve executions instead of creating new ones and GCing them every time --- gql.go | 136 ++++++++++++++++++++++++++++++++++++---------- gql_interfaces.go | 66 ++++++++++++++-------- gql_query.go | 38 ++++++++++++- gql_resolvers.go | 89 ++++++++++++++++++++++-------- gql_test.go | 5 +- gql_types.go | 12 ++-- node.go | 2 +- signal.go | 21 ++++++- 8 files changed, 283 insertions(+), 86 deletions(-) diff --git a/gql.go b/gql.go index 4dda9d4..0dc4c89 100644 --- a/gql.go +++ b/gql.go @@ -26,6 +26,7 @@ import ( "crypto/x509/pkix" "math/big" "encoding/pem" + "github.com/google/uuid" ) func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { @@ -162,14 +163,32 @@ func checkForAuthHeader(header http.Header) (string, bool) { return "", false } +// Context passed to each resolve execution type ResolveContext struct { + // ID generated for the context so the gql extension can route data to it + ID uuid.UUID + + // Channel for the gql extension to route data to this context + Chan chan *ReadResultSignal + + // Graph Context this resolver is running under Context *Context + + // GQL Extension context this resolver is running under GQLContext *GQLExtContext + + // Pointer to the node that's currently processing this request Server *Node + + // The state data for the node processing this request Ext *GQLExt + + // ID of the user that made this request + // TODO: figure out auth User NodeID } +const GQL_RESOLVER_CHAN_SIZE = 10 func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { username, _, ok := r.BasicAuth() if ok == false { @@ -182,6 +201,9 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ } return &ResolveContext{ + Ext: gql_ext, + ID: uuid.New(), + Chan: make(chan *ReadResultSignal, GQL_RESOLVER_CHAN_SIZE), Context: ctx, GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), Server: server, @@ -231,6 +253,11 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW if len(query.Variables) > 0 { params.VariableValues = query.Variables } + + gql_ext.resolver_chans_lock.Lock() + gql_ext.resolver_chans[resolve_context.ID] = resolve_context.Chan + gql_ext.resolver_chans_lock.Unlock() + result := graphql.Do(params) if len(result.Errors) > 0 { extra_fields := map[string]interface{}{} @@ -452,33 +479,9 @@ func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init fu return &gql } -func NewInterface(if_name string, default_name string, interfaces []*graphql.Interface, extensions []ExtType, init_1 func(*Interface), init_2 func(*Interface)) *Interface { - var gql Interface - gql.Extensions = extensions - gql.Interface = graphql.NewInterface(graphql.InterfaceConfig{ - Name: if_name, - ResolveType: NodeResolver([]ExtType{}, &gql.Default), - Fields: graphql.Fields{}, - }) - gql.List = graphql.NewList(gql.Interface) - - init_1(&gql) - - gql.Default = graphql.NewObject(graphql.ObjectConfig{ - Name: default_name, - Interfaces: append(interfaces, gql.Interface), - IsTypeOf: GQLNodeHasExtensions([]ExtType{}), - Fields: graphql.Fields{}, - }) - - init_2(&gql) - - return &gql -} - -type GQLNode struct { - ID NodeID - Type NodeType +type Field struct { + Ext ExtType + Name string } // GQL Specific Context information @@ -489,6 +492,7 @@ type GQLExtContext struct { // Custom graphql types, mapped to NodeTypes NodeTypes map[NodeType]*graphql.Object Interfaces []*Interface + Fields map[string]Field // Schema parameters Types []graphql.Type @@ -497,6 +501,30 @@ type GQLExtContext struct { Subscription *graphql.Object } +func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[ExtType][]string, error) { + ext_fields := map[ExtType][]string{} + for _, name := range(names) { + switch name { + case "ID": + case "TypeHash": + default: + field, exists := ctx.Fields[name] + if exists == false { + return nil, fmt.Errorf("%s is not a know field in GQLContext, cannot resolve", name) + } + + ext, exists := ext_fields[field.Ext] + if exists == false { + ext = []string{} + } + ext = append(ext, field.Name) + ext_fields[field.Ext] = ext + } + } + + return ext_fields, nil +} + func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { schemaConfig := graphql.SchemaConfig{ Types: ctx.Types, @@ -508,6 +536,16 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { return graphql.NewSchema(schemaConfig) } +func (ctx *GQLExtContext) RegisterField(gql_name string, ext ExtType, acl_name string) error { + _, exists := ctx.Fields[gql_name] + if exists == true { + return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name) + } + + ctx.Fields[gql_name] = Field{ext, acl_name} + return nil +} + func (ctx *GQLExtContext) AddInterface(i *Interface) error { if i == nil { return fmt.Errorf("interface is nil") @@ -527,6 +565,7 @@ func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, gql_type *graphql if gql_type == nil { return fmt.Errorf("gql_type is nil") } + _, exists := ctx.NodeTypes[node_type] if exists == true { return fmt.Errorf("%s already in GQLExtContext.NodeTypes", node_type) @@ -597,6 +636,13 @@ type GQLExt struct { http_server *http.Server http_done sync.WaitGroup + // map of read request IDs to gql request ID + resolver_reads map[uuid.UUID]uuid.UUID + resolver_reads_lock sync.RWMutex + // map of gql request ID to active channel + resolver_chans map[uuid.UUID]chan *ReadResultSignal + resolver_chans_lock sync.RWMutex + tls_key []byte tls_cert []byte Listen string @@ -611,7 +657,39 @@ func (ext *GQLExt) Field(name string) interface{} { } func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { - if signal.Type() == GQLStateSignalType { + // Process ReadResultSignalType by forwarding it to the waiting resolver + if signal.Type() == ReadResultSignalType { + sig := signal.(ReadResultSignal) + ext.resolver_reads_lock.RLock() + resolver_id, exists := ext.resolver_reads[sig.UUID] + ext.resolver_reads_lock.RUnlock() + + if exists == true { + ext.resolver_reads_lock.Lock() + delete(ext.resolver_reads, sig.UUID) + ext.resolver_reads_lock.Unlock() + + ext.resolver_chans_lock.RLock() + resolver_chan, exists := ext.resolver_chans[resolver_id] + ext.resolver_chans_lock.RUnlock() + + if exists == true { + select { + case resolver_chan <- &sig: + ctx.Log.Logf("gql", "Forwarded to resolver %s, %+v", resolver_id, sig) + default: + ctx.Log.Logf("gql", "Resolver %s channel overflow %+v", resolver_id, sig) + ext.resolver_chans_lock.Lock() + delete(ext.resolver_chans, resolver_id) + ext.resolver_chans_lock.Unlock() + } + } else { + ctx.Log.Logf("gql", "Received message for waiting resolver %s which doesn't exist - %+v", resolver_id, sig) + } + } else { + ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) + } + } else if signal.Type() == GQLStateSignalType { sig := signal.(StateSignal) switch sig.State { case "start_server": @@ -716,6 +794,8 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*G } return &GQLExt{ Listen: listen, + resolver_reads: map[uuid.UUID]uuid.UUID{}, + resolver_chans: map[uuid.UUID]chan *ReadResultSignal{}, tls_cert: tls_cert, tls_key: tls_key, }, nil diff --git a/gql_interfaces.go b/gql_interfaces.go index 1083e76..cd8fa2d 100644 --- a/gql_interfaces.go +++ b/gql_interfaces.go @@ -55,48 +55,40 @@ func addLockableInterfaceFields(gql *Interface, gql_lockable *Interface) { }) } -func NodeHasExtensions(node *Node, extensions []ExtType) bool { - if node == nil { - return false - } - - for _, ext := range(extensions) { - _, has := node.Extensions[ext] - if has == false { - return false - } - } - - return true -} - -func GQLNodeHasExtensions(extensions []ExtType) func(graphql.IsTypeOfParams) bool { +func NodeIsTypeResolver(extensions []ExtType) func(graphql.IsTypeOfParams) bool { return func(p graphql.IsTypeOfParams) bool { - node, ok := p.Value.(*Node) + node, ok := p.Value.(NodeResult) if ok == false { return false } - return NodeHasExtensions(node, extensions) + for _, ext := range(extensions) { + _, has := node.Result.Extensions[ext] + if has == false { + return false + } + } + + return true } } -func NodeResolver(required_extensions []ExtType, default_type **graphql.Object)func(graphql.ResolveTypeParams) *graphql.Object { +func NodeTypeResolver(required_extensions []ExtType, default_type **graphql.Object)func(graphql.ResolveTypeParams) *graphql.Object { return func(p graphql.ResolveTypeParams) *graphql.Object { ctx, ok := p.Context.Value("resolve").(*ResolveContext) if ok == false { return nil } - node, ok := p.Value.(*Node) + node, ok := p.Value.(NodeResult) if ok == false { return nil } - gql_type, exists := ctx.GQLContext.NodeTypes[node.Type] + gql_type, exists := ctx.GQLContext.NodeTypes[node.Result.NodeType] if exists == false { for _, ext := range(required_extensions) { - _, exists := node.Extensions[ext] + _, exists := node.Result.Extensions[ext] if exists == false { return nil } @@ -108,6 +100,36 @@ func NodeResolver(required_extensions []ExtType, default_type **graphql.Object)f } } +type NodeResult struct { + ID NodeID + Result *ReadResultSignal +} + +func NewInterface(if_name string, default_name string, interfaces []*graphql.Interface, extensions []ExtType, init_1 func(*Interface), init_2 func(*Interface)) *Interface { + var gql Interface + gql.Extensions = extensions + gql.Interface = graphql.NewInterface(graphql.InterfaceConfig{ + Name: if_name, + ResolveType: NodeTypeResolver([]ExtType{}, &gql.Default), + Fields: graphql.Fields{}, + }) + gql.List = graphql.NewList(gql.Interface) + + init_1(&gql) + + gql.Default = graphql.NewObject(graphql.ObjectConfig{ + Name: default_name, + Interfaces: append(interfaces, gql.Interface), + IsTypeOf: NodeIsTypeResolver([]ExtType{}), + Fields: graphql.Fields{}, + }) + + init_2(&gql) + + return &gql +} + + var InterfaceNode = NewInterface("Node", "DefaultNode", []*graphql.Interface{}, []ExtType{}, func(gql *Interface) { AddNodeInterfaceFields(gql) }, func(gql *Interface) { diff --git a/gql_query.go b/gql_query.go index e3b8502..184ca03 100644 --- a/gql_query.go +++ b/gql_query.go @@ -1,5 +1,6 @@ package graphvent import ( + "time" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" ) @@ -28,12 +29,43 @@ var QueryNode = &graphql.Field{ if err != nil { return nil, err } - ctx.Context.Log.Logf("gql", "FIELDS: %+v", GetFieldNames(p)) + + + id, err := ExtractID(p, "id") + if err != nil { + return nil, err + } + + fields := GetFieldNames(p) + ctx.Context.Log.Logf("gql", "RESOLVE_NODE(%s): %+v", id, fields) + // Get a list of fields that will be written - // Send the read signal + ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) + if err != nil { + return nil, err + } + // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error + read_signal := NewReadSignal(ext_fields) + + ctx.Ext.resolver_reads_lock.Lock() + ctx.Ext.resolver_reads[read_signal.UUID] = ctx.ID + ctx.Ext.resolver_reads_lock.Unlock() + + err = ctx.Context.Send(ctx.Server.ID, id, read_signal) + if err != nil { + ctx.Ext.resolver_reads_lock.Lock() + delete(ctx.Ext.resolver_reads, read_signal.UUID) + ctx.Ext.resolver_reads_lock.Unlock() + return nil, err + } + // Wait for the response, returning an error on timeout + response, err := WaitForReadResult(ctx.Chan, time.Millisecond*100, read_signal.UUID) + if err != nil { + return nil, err + } - return nil, nil + return NodeResult{id, response}, nil }, } diff --git a/gql_resolvers.go b/gql_resolvers.go index 4ba12e8..e545510 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -64,46 +64,89 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { return id, nil } -// TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted -func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { - return nil, nil -} +func ResolveNodeResult(p graphql.ResolveParams, resolve_fn func(NodeResult)(interface{}, error)) (interface{}, error) { + node, ok := p.Source.(NodeResult) + if ok == false { + return nil, fmt.Errorf("p.Value is not NodeResult") + } -func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { - return nil, nil + return resolve_fn(node) } -func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) { - // TODO figure out how nodes can read eachother - return "", nil +func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResult(p, func(node NodeResult) (interface{}, error) { + return node.ID, nil + }) } -func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { - return nil, nil +func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResult(p, func(node NodeResult) (interface{}, error) { + return Hash(node.Result.NodeType), nil + }) } -func GQLThreadState(p graphql.ResolveParams) (interface{}, error) { - return "", nil +func ResolveNodeResultExt[T any](p graphql.ResolveParams, ext_type ExtType, field string, resolve_fn func(T)(interface{}, error)) (interface{}, error) { + return ResolveNodeResult(p, func(result NodeResult) (interface{}, error) { + ext, exists := result.Result.Extensions[ext_type] + if exists == false { + return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) + } + + val_if, exists := ext[field] + if exists == false { + return nil, fmt.Errorf("%s is not in the fields of %s in the result", field, ext_type) + } + + var zero T + val, ok := val_if.(T) + if ok == false { + return nil, fmt.Errorf("%s.%s is not %s", ext_type, field, reflect.TypeOf(zero)) + } + + return resolve_fn(val) + }) } -func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { - return nil, nil +func ResolveListen(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResultExt(p, GQLExtType, "listen", func(listen string) (interface{}, error) { + return listen, nil + }) } -func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { - return nil, nil +func ResolveRequirements(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResultExt(p, LockableExtType, "requirements", func(requirements []NodeID) (interface{}, error) { + res := make([]string, len(requirements)) + for i, id := range(requirements) { + res[i] = id.String() + } + return res, nil + }) } -func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { - return nil, nil +func ResolveDependencies(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResultExt(p, LockableExtType, "dependencies", func(dependencies []NodeID) (interface{}, error) { + res := make([]string, len(dependencies)) + for i, id := range(dependencies) { + res[i] = id.String() + } + return res, nil + }) } -func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { - return nil, nil +func ResolveOwner(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResultExt(p, LockableExtType, "owner", func(owner NodeID) (interface{}, error) { + return owner.String(), nil + }) } -func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) { - return nil, nil +func ResolveMembers(p graphql.ResolveParams) (interface{}, error) { + return ResolveNodeResultExt(p, GroupExtType, "members", func(members []NodeID) (interface{}, error) { + res := make([]string, len(members)) + for i, id := range(members) { + res[i] = id.String() + } + return res, nil + }) } func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { diff --git a/gql_test.go b/gql_test.go index f407fef..5887856 100644 --- a/gql_test.go +++ b/gql_test.go @@ -40,7 +40,10 @@ func TestGQL(t *testing.T) { url := fmt.Sprintf("https://localhost:%d/gql", port) ser, err := json.MarshalIndent(&GQLPayload{ - Query: "query { Self { ID } }", + Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }", + Variables: map[string]interface{}{ + "id": n1.ID.String(), + }, }, "", " ") fatalErr(t, err) diff --git a/gql_types.go b/gql_types.go index 5d81d61..676d646 100644 --- a/gql_types.go +++ b/gql_types.go @@ -7,12 +7,12 @@ import ( func AddNodeFields(object *graphql.Object) { object.AddFieldConfig("ID", &graphql.Field{ Type: graphql.String, - Resolve: GQLNodeID, + Resolve: ResolveNodeID, }) object.AddFieldConfig("TypeHash", &graphql.Field{ Type: graphql.String, - Resolve: GQLNodeTypeHash, + Resolve: ResolveNodeTypeHash, }) } @@ -24,17 +24,17 @@ func addLockableFields(object *graphql.Object, lockable_interface *graphql.Inter AddNodeFields(object) object.AddFieldConfig("Requirements", &graphql.Field{ Type: lockable_list, - Resolve: GQLLockableRequirements, + Resolve: ResolveRequirements, }) object.AddFieldConfig("Owner", &graphql.Field{ Type: lockable_interface, - Resolve: GQLLockableOwner, + Resolve: ResolveOwner, }) object.AddFieldConfig("Dependencies", &graphql.Field{ Type: lockable_list, - Resolve: GQLLockableDependencies, + Resolve: ResolveDependencies, }) } @@ -46,7 +46,7 @@ var TypeGQLNode = NewGQLNodeType(GQLNodeType, GQLNodeInterfaces, func(gql *Type) gql.Type.AddFieldConfig("Listen", &graphql.Field{ Type: graphql.String, - Resolve: GQLNodeListen, + Resolve: ResolveListen, }) }) diff --git a/node.go b/node.go index 4b84751..1b1b02b 100644 --- a/node.go +++ b/node.go @@ -218,7 +218,7 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) } else { result := ReadNodeFields(ctx, node, source, read_signal.Extensions) - ctx.Send(node.ID, source, NewReadResultSignal(node.Type, result)) + ctx.Send(node.ID, source, NewReadResultSignal(read_signal.UUID, node.Type, result)) } } diff --git a/signal.go b/signal.go index cbac046..72e91c3 100644 --- a/signal.go +++ b/signal.go @@ -11,6 +11,7 @@ import ( "crypto/rand" "crypto/aes" "crypto/cipher" + "github.com/google/uuid" ) type SignalDirection int @@ -42,6 +43,18 @@ type Signal interface { Permission() Action } +func WaitForReadResult(listener chan *ReadResultSignal, timeout time.Duration, id uuid.UUID) (*ReadResultSignal, error) { + timeout_channel := time.After(timeout) + var err error = nil + var result *ReadResultSignal = nil + select { + case result =<-listener: + case <-timeout_channel: + err = fmt.Errorf("timeout waiting for read response to %s", id) + } + return result, err +} + func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { var zero S timeout_channel := time.After(timeout) @@ -201,6 +214,7 @@ func (signal StateSignal) Permission() Action { type ReadSignal struct { BaseSignal + UUID uuid.UUID Extensions map[ExtType][]string `json:"extensions"` } @@ -210,6 +224,7 @@ func (signal ReadSignal) Serialize() ([]byte, error) { func NewReadSignal(exts map[ExtType][]string) ReadSignal { return ReadSignal{ + UUID: uuid.New(), BaseSignal: NewDirectSignal(ReadSignalType), Extensions: exts, } @@ -217,13 +232,15 @@ func NewReadSignal(exts map[ExtType][]string) ReadSignal { type ReadResultSignal struct { BaseSignal - NodeType NodeType + uuid.UUID + NodeType Extensions map[ExtType]map[string]interface{} `json:"extensions"` } -func NewReadResultSignal(node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal { +func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal { return ReadResultSignal{ BaseSignal: NewDirectSignal(ReadResultSignalType), + UUID: req_id, NodeType: node_type, Extensions: exts, }