diff --git a/context.go b/context.go index 2fe2a1b..9590ded 100644 --- a/context.go +++ b/context.go @@ -2,7 +2,6 @@ package graphvent import ( badger "github.com/dgraph-io/badger/v3" - "github.com/graphql-go/graphql" "fmt" "sync" "errors" @@ -274,18 +273,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = RegisterField(gql_ctx, graphql.String, "Listen", GQLExtType, "listen", func(listen string) (interface{}, error) { - return listen, nil - }) - if err != nil { - return nil, err - } - - err = gql_ctx.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node"}, []string{"Listen"}) - if err != nil { - return nil, err - } - schema, err := BuildSchema(gql_ctx) if err != nil { return nil, err diff --git a/gql.go b/gql.go index 93e7ce6..ed93112 100644 --- a/gql.go +++ b/gql.go @@ -517,7 +517,7 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { return graphql.NewSchema(schemaConfig) } -func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(T)(interface{}, error)) error { +func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, T)(interface{}, error)) error { if ctx == nil { return fmt.Errorf("ctx is nil") } @@ -534,7 +534,7 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{ Type: gql_type, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResult(p, func(result NodeResult) (interface{}, error) { + return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) { ext, exists := result.Result.Extensions[ext_type] if exists == false { return nil, fmt.Errorf("%s is not in the extensions of the result", ext_type) @@ -551,7 +551,7 @@ func RegisterField[T any](ctx *GQLExtContext, gql_type graphql.Type, gql_name st return nil, fmt.Errorf("%s.%s is not %s", ext_type, acl_name, reflect.TypeOf(zero)) } - return resolve_fn(val) + return resolve_fn(p, val) }) }, }} @@ -609,13 +609,13 @@ type NodeResult struct { type ListField struct { ACLName string Extension ExtType - ResolveFn func(interface{}) ([]NodeID, error) + ResolveFn func(graphql.ResolveParams, interface{}) ([]NodeID, error) } type SelfField struct { ACLName string Extension ExtType - ResolveFn func(interface{}) (NodeID, error) + ResolveFn func(graphql.ResolveParams, interface{}) (NodeID, error) } func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error { @@ -651,14 +651,63 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in }) ctx_interface.List = graphql.NewList(ctx_interface.Interface) - //TODO finish self_fields and do list_fields for field_name, self_field := range(self_fields) { - err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, func(id_str string) (interface{}, error) { - return nil, nil + err := RegisterField(ctx, ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, + func(p graphql.ResolveParams, val interface{})(interface{}, error) { + ctx, err := PrepResolve(p) + if err != nil { + return nil, err + } + + var zero NodeID + id, err := self_field.ResolveFn(p, val) + if err != nil { + return zero, err + } + + nodes, err := ResolveNodes(ctx, p, []NodeID{id}) + if err != nil { + return nil, err + } else if len(nodes) != 1 { + return nil, fmt.Errorf("wrong length of nodes returned") + } + return nodes[0], nil + }) + if err != nil { + return err + } + + ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field) + node_fields[field_name] = ctx.Fields[field_name].Field + } + + for field_name, list_field := range(list_fields) { + err := RegisterField(ctx, ctx_interface.Interface, field_name, list_field.Extension, list_field.ACLName, + func(p graphql.ResolveParams, val interface{})(interface{}, error) { + ctx, err := PrepResolve(p) + if err != nil { + return nil, err + } + + var zero NodeID + ids, err := list_field.ResolveFn(p, val) + if err != nil { + return zero, err + } + + nodes, err := ResolveNodes(ctx, p, ids) + if err != nil { + return nil, err + } else if len(nodes) != len(ids) { + return nil, fmt.Errorf("wrong length of nodes returned") + } + return nodes, nil }) if err != nil { return err } + ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field) + node_fields[field_name] = ctx.Fields[field_name].Field } ctx_interface.Default = graphql.NewObject(graphql.ObjectConfig{ @@ -737,6 +786,44 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } + err = context.RegisterInterface("Lockable", "DefaultLockable", []string{"Node"}, []string{}, map[string]SelfField{ + "Owner": SelfField{ + "owner", + LockableExtType, + func(p graphql.ResolveParams, val interface{}) (NodeID, error) { + var zero NodeID + id_str, ok := val.(string) + if ok == false { + return zero, fmt.Errorf("can't parse %+v as string", val) + } + + id, err := ParseID(id_str) + if err != nil { + return zero, err + } + + return id, nil + }, + }, + }, map[string]ListField{ + }) + + if err != nil { + panic(err) + } + + err = RegisterField(&context, graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen string) (interface{}, error) { + return listen, nil + }) + if err != nil { + panic(err) + } + + err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node"}, []string{"Listen"}) + if err != nil { + panic(err) + } + context.Query.AddFieldConfig("Node", &graphql.Field{ Type: context.Interfaces["Node"].Interface, Args: graphql.FieldConfigArgument{ @@ -750,12 +837,19 @@ func NewGQLExtContext() *GQLExtContext { return nil, err } - id_str, err := ExtractParam[string](p, "id") + id, err := ExtractID(p, "id") + if err != nil { + return nil, err + } + + nodes, err := ResolveNodes(ctx, p, []NodeID{id}) if err != nil { return nil, err + } else if len(nodes) != 1 { + return nil, fmt.Errorf("wrong length of resolved nodes returned") } - return ResolveNode(ctx, p, id_str) + return nodes[0], nil }, }) diff --git a/gql_query.go b/gql_query.go index 57defba..889c60a 100644 --- a/gql_query.go +++ b/gql_query.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" + "github.com/google/uuid" ) func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string { @@ -35,41 +36,43 @@ func GetResolveFields(ctx *Context, p graphql.ResolveParams) []string { return names } -func ResolveNode(ctx *ResolveContext, p graphql.ResolveParams, id_str string) (NodeResult, error) { - var zero NodeResult +func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([]NodeResult, error) { fields := GetResolveFields(ctx.Context, p) - ctx.Context.Log.Logf("gql", "RESOLVE_NODE(%s): %+v", id_str, fields) + ctx.Context.Log.Logf("gql", "RESOLVE_NODES(%+v): %+v", ids, fields) - id, err := ParseID(id_str) - if err != nil { - return zero, err - } - - // Get a list of fields that will be written - ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) - if err != nil { - return zero, err - } - // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error - read_signal := NewReadSignal(ext_fields) - - ctx.Ext.resolver_reads_lock.Lock() - ctx.Ext.resolver_reads[read_signal.UUID] = ctx.ID - ctx.Ext.resolver_reads_lock.Unlock() + read_signals := map[NodeID]uuid.UUID{} + for _, id := range(ids) { + // Get a list of fields that will be written + ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) + if err != nil { + return nil, err + } + // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error + read_signal := NewReadSignal(ext_fields) - err = ctx.Context.Send(ctx.Server.ID, id, read_signal) - if err != nil { ctx.Ext.resolver_reads_lock.Lock() - delete(ctx.Ext.resolver_reads, read_signal.UUID) + ctx.Ext.resolver_reads[read_signal.UUID] = ctx.ID ctx.Ext.resolver_reads_lock.Unlock() - return zero, err + + err = ctx.Context.Send(ctx.Server.ID, id, read_signal) + read_signals[id] = read_signal.UUID + if err != nil { + ctx.Ext.resolver_reads_lock.Lock() + delete(ctx.Ext.resolver_reads, read_signal.UUID) + ctx.Ext.resolver_reads_lock.Unlock() + return nil, err + } } - // Wait for the response, returning an error on timeout - response, err := WaitForReadResult(ctx.Chan, time.Millisecond*100, read_signal.UUID) - if err != nil { - return zero, err + responses := []NodeResult{} + for node_id, sig_id := range(read_signals) { + // Wait for the response, returning an error on timeout + response, err := WaitForReadResult(ctx.Chan, time.Millisecond*100, sig_id) + if err != nil { + return nil, err + } + responses = append(responses, NodeResult{node_id, response}) } - return NodeResult{id, response}, nil + return responses, nil } diff --git a/gql_resolvers.go b/gql_resolvers.go index 4d0f216..01bda71 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -64,23 +64,23 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { return id, nil } -func ResolveNodeResult(p graphql.ResolveParams, resolve_fn func(NodeResult)(interface{}, error)) (interface{}, error) { +func ResolveNodeResult(p graphql.ResolveParams, resolve_fn func(graphql.ResolveParams, NodeResult)(interface{}, error)) (interface{}, error) { node, ok := p.Source.(NodeResult) if ok == false { return nil, fmt.Errorf("p.Value is not NodeResult") } - return resolve_fn(node) + return resolve_fn(p, node) } func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResult(p, func(node NodeResult) (interface{}, error) { + return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) { return node.ID, nil }) } func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResult(p, func(node NodeResult) (interface{}, error) { + return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) { return Hash(node.Result.NodeType), nil }) } diff --git a/lockable_test.go b/lockable_test.go index dd288e6..5d5b422 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -85,7 +85,7 @@ func TestLink10K(t *testing.T) { } l0, l0_listener := NewListener() - lockables := make([]*Node, 10000) + lockables := make([]*Node, 10) for i, _ := range(lockables) { lockables[i] = NewLockable() LinkRequirement(ctx, l0.ID, lockables[i].ID)