diff --git a/gql.go b/gql.go index cac3147..0741097 100644 --- a/gql.go +++ b/gql.go @@ -543,7 +543,7 @@ func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) { return graphql.NewSchema(schemaConfig) } -func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, acl_name string, resolve_fn func(graphql.ResolveParams, SerializedValue)(interface{}, error)) error { +func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, gv_tag string, resolve_fn func(graphql.ResolveParams, *ResolveContext, reflect.Value)(interface{}, error)) error { if ctx == nil { return fmt.Errorf("ctx is nil") } @@ -557,27 +557,45 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name) } + // Resolver has p.Source.(NodeResult) = read result of current node resolver := func(p graphql.ResolveParams)(interface{}, error) { - return ResolveNodeResult(p, func(p graphql.ResolveParams, result NodeResult) (interface{}, error) { - ext, exists := result.Result.Extensions[ext_type] - if exists == false { - return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type) - } + ctx, err := PrepResolve(p) + if err != nil { + return nil, err + } - val_ser, exists := ext[acl_name] - if exists == false { - return nil, fmt.Errorf("%s is not in the fields of %+v in the result", acl_name, ext_type) - } + node, ok := p.Source.(NodeResult) + if ok == false { + return nil, fmt.Errorf("p.Value is not NodeResult") + } - if val_ser.TypeStack[0] == ErrorType { - return nil, fmt.Errorf(string(val_ser.Data)) - } + ext, ext_exists := node.Result.Extensions[ext_type] + if ext_exists == false { + return nil, fmt.Errorf("%+v is not in the extensions of the result", ext_type) + } - return resolve_fn(p, val_ser) - }) + val_ser, field_exists := ext[gv_tag] + if field_exists == false { + return nil, fmt.Errorf("%s is not in the fields of %+v in the result", gv_tag, ext_type) + } + + if val_ser.TypeStack[0] == ErrorType { + return nil, fmt.Errorf(string(val_ser.Data)) + } + + field_type, field_value, _, err := DeserializeValue(ctx.Context, val_ser) + if err != nil { + return nil, err + } + + if field_value == nil { + return nil, fmt.Errorf("%s returned a nil value of %+v type", gv_tag, field_type) + } + + return resolve_fn(p, ctx, *field_value) } - ctx.Fields[gql_name] = Field{ext_type, acl_name, &graphql.Field{ + ctx.Fields[gql_name] = Field{ext_type, gv_tag, &graphql.Field{ Type: gql_type, Resolve: resolver, }} @@ -635,13 +653,13 @@ type NodeResult struct { type ListField struct { ACLName string Extension ExtType - ResolveFn func(graphql.ResolveParams, interface{}) ([]NodeID, error) + ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) ([]NodeID, error) } type SelfField struct { ACLName string Extension ExtType - ResolveFn func(graphql.ResolveParams, interface{}) (*NodeID, error) + ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (*NodeID, error) } func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error { @@ -680,13 +698,8 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in for field_name, field := range(self_fields) { self_field := field err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName, - func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - id, err := self_field.ResolveFn(p, val) + func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { + id, err := self_field.ResolveFn(p, ctx, value) if err != nil { return nil, err } @@ -713,14 +726,9 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in for field_name, field := range(list_fields) { list_field := field - resolve_fn := func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - + resolve_fn := func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { var zero NodeID - ids, err := list_field.ResolveFn(p, val) + ids, err := list_field.ResolveFn(p, ctx, value) if err != nil { return zero, err } @@ -829,19 +837,18 @@ func NewGQLExtContext() *GQLExtContext { } err = context.RegisterField(context.Interfaces["Node"].List, "Members", GroupExtType, "members", - func(p graphql.ResolveParams, val SerializedValue)(interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err + func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) { + node_map, ok := value.Interface().(map[NodeID]string) + if ok == false { + return nil, fmt.Errorf("value is %+v, not map[NodeID]string", value.Type()) } - /*node_list := make([]NodeID, len(val)) + node_list := []NodeID{} i := 0 - for id, _ := range(val) { - node_list[i] = id + for id, _ := range(node_map) { + node_list = append(node_list, id) i += 1 } - */ - node_list := []NodeID{} + nodes, err := ResolveNodes(ctx, p, node_list) if err != nil { return nil, err @@ -862,10 +869,10 @@ func NewGQLExtContext() *GQLExtContext { "Owner": SelfField{ "owner", LockableExtType, - func(p graphql.ResolveParams, val interface{}) (*NodeID, error) { - id, ok := val.(*NodeID) + func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) (*NodeID, error) { + id, ok := value.Interface().(*NodeID) if ok == false { - return nil, fmt.Errorf("can't parse %+v as *NodeID", val) + return nil, fmt.Errorf("can't parse %+v as *NodeID", value.Type()) } return id, nil @@ -875,14 +882,14 @@ func NewGQLExtContext() *GQLExtContext { "Requirements": ListField{ "requirements", LockableExtType, - func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { - id_strs, ok := val.(map[NodeID]ReqState) + func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) ([]NodeID, error) { + id_strs, ok := value.Interface().(map[NodeID]ReqState) if ok == false { - return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState, %s", val, reflect.TypeOf(val)) + return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState", value.Type()) } ids := []NodeID{} - for id, _ := range(id_strs) { + for id := range(id_strs) { ids = append(ids, id) } return ids, nil @@ -894,8 +901,8 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } - err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, listen SerializedValue) (interface{}, error) { - return listen, nil + err = context.RegisterField(graphql.String, "Listen", GQLExtType, "listen", func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) (interface{}, error) { + return value.String(), nil }) if err != nil { panic(err) @@ -986,17 +993,17 @@ func NewGQLExtContext() *GQLExtContext { } type GQLExt struct { - tcp_listener net.Listener `json:"-"` - http_server *http.Server `json:"-"` - http_done sync.WaitGroup `json:"-"` + tcp_listener net.Listener + http_server *http.Server + http_done sync.WaitGroup // map of read request IDs to response channels - resolver_response map[uuid.UUID]chan Signal `json:"-"` - resolver_response_lock sync.RWMutex `json:"-"` + resolver_response map[uuid.UUID]chan Signal + resolver_response_lock sync.RWMutex - TLSKey []byte `json:"tls_key"` - TLSCert []byte `json:"tls_cert"` - Listen string `json:"listen"` + TLSKey []byte `gv:"tls_key"` + TLSCert []byte `gv:"tls_cert"` + Listen string `gv:"listen"` } func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal { diff --git a/gql_resolvers.go b/gql_resolvers.go index ccb2d8f..cacc0a7 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -64,23 +64,20 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { return id, nil } -func ResolveNodeResult(p graphql.ResolveParams, resolve_fn func(graphql.ResolveParams, NodeResult)(interface{}, error)) (interface{}, error) { +func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { node, ok := p.Source.(NodeResult) if ok == false { - return nil, fmt.Errorf("p.Value is not NodeResult") + return nil, fmt.Errorf("Can't get NodeID from %+v", reflect.TypeOf(p.Source)) } - return resolve_fn(p, node) -} - -func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) { - return node.ID, nil - }) + return node.ID, nil } func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { - return ResolveNodeResult(p, func(p graphql.ResolveParams, node NodeResult) (interface{}, error) { - return uint64(node.Result.NodeType), nil - }) + node, ok := p.Source.(NodeResult) + if ok == false { + return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source)) + } + + return uint64(node.Result.NodeType), nil } diff --git a/group.go b/group.go index 14321b2..fca4383 100644 --- a/group.go +++ b/group.go @@ -5,7 +5,7 @@ import ( ) type GroupExt struct { - Members map[NodeID]string `gv:"0"` + Members map[NodeID]string `gv:"members"` } func (ext *GroupExt) Type() ExtType { diff --git a/lockable.go b/lockable.go index b009074..d85b8cd 100644 --- a/lockable.go +++ b/lockable.go @@ -16,9 +16,9 @@ const ( type LockableExt struct{ State ReqState `gv:"state"` ReqID *uuid.UUID `gv:"req_id"` - Owner *NodeID - PendingOwner *NodeID - Requirements map[NodeID]ReqState + Owner *NodeID `gv:"owner"` + PendingOwner *NodeID `gv:"pending_owner"` + Requirements map[NodeID]ReqState `gv:"requirements"` } func (ext *LockableExt) Type() ExtType { diff --git a/serialize.go b/serialize.go index 21b1be3..a32fdd1 100644 --- a/serialize.go +++ b/serialize.go @@ -635,17 +635,29 @@ func SerializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali return serialized_value, err } -func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) { +func ExtField(ctx *Context, ext Extension, field_name string) (reflect.Value, error) { if ext == nil { - return SerializedValue{}, fmt.Errorf("Cannot get fields on nil Extension") + return reflect.Value{}, fmt.Errorf("Cannot get fields on nil Extension") } + ext_value := reflect.ValueOf(ext).Elem() - field := ext_value.FieldByName(field_name) - if field.IsValid() == false { - return SerializedValue{}, fmt.Errorf("%s is not a field in %+v", field_name, ext) - } else { - return SerializeValue(ctx, field.Type(), &field) + for _, field := range(reflect.VisibleFields(ext_value.Type())) { + gv_tag, tagged := field.Tag.Lookup("gv") + if tagged == true && gv_tag == field_name { + return ext_value.FieldByIndex(field.Index), nil + } } + + return reflect.Value{}, fmt.Errorf("%s is not a field in %+v", field_name, reflect.TypeOf(ext)) +} + +func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedValue, error) { + field_value, err := ExtField(ctx, ext, field_name) + if err != nil { + return SerializedValue{}, err + } + + return SerializeValue(ctx, field_value.Type(), &field_value) } func (value SerializedValue) MarshalBinary() ([]byte, error) {