diff --git a/context.go b/context.go index dc7ed29..2882636 100644 --- a/context.go +++ b/context.go @@ -7,7 +7,6 @@ import ( "fmt" "reflect" "runtime" - "slices" "strconv" "strings" "sync" @@ -30,10 +29,14 @@ var ( type SerializeFn func(ctx *Context, value reflect.Value) ([]byte, error) type DeserializeFn func(ctx *Context, data []byte) (reflect.Value, []byte, error) -type FieldInfo struct { +type NodeFieldInfo struct { + Extension ExtType + Index []int + Type graphql.Type +} + +type StructFieldInfo struct { Index []int - Tag string - NodeTag string Type reflect.Type } @@ -42,30 +45,37 @@ type TypeInfo struct { Reflect reflect.Type Type graphql.Type - Fields map[FieldTag]FieldInfo + Fields map[FieldTag]StructFieldInfo PostDeserializeIndex int Serialize SerializeFn Deserialize DeserializeFn } +type ExtensionFieldInfo struct { + Index []int + Type reflect.Type + NodeTag string +} + type ExtensionInfo struct { ExtType - Fields map[string]FieldInfo + Type reflect.Type + Fields map[string]ExtensionFieldInfo Data interface{} } type NodeInfo struct { NodeType Type *graphql.Object - Interface *graphql.Interface - Extensions []ExtType - Fields map[string]ExtType + RequiredExtensions []ExtType + Fields map[string]NodeFieldInfo + ReverseFields map[ExtType]map[Tag]string } type InterfaceInfo struct { - Serialized SerializedType - Reflect reflect.Type + Type *graphql.Interface + Fields map[string]graphql.Type } // A Context stores all the data to run a graphvent process @@ -77,23 +87,30 @@ type Context struct { Log Logger // Mapped types - TypeMap map[SerializedType]*TypeInfo - TypeTypes map[reflect.Type]*TypeInfo + Types map[reflect.Type]*TypeInfo + TypesReverse map[SerializedType]*TypeInfo // Map between database extension hashes and the registered info - Extensions map[ExtType]*ExtensionInfo - ExtensionTypes map[reflect.Type]*ExtensionInfo + Extensions map[ExtType]ExtensionInfo - Interfaces map[SerializedType]*InterfaceInfo - InterfaceTypes map[reflect.Type]*InterfaceInfo + // Map between GQL interface name and the registered info + Interfaces map[string]InterfaceInfo - // Map between database type hashes and the registered info - Nodes map[NodeType]*NodeInfo - NodeTypes map[string]*NodeInfo + // Map between database node type hashes and the registered info + NodeTypes map[NodeType]NodeInfo // Routing map to all the nodes local to this context - nodeMapLock sync.RWMutex - nodeMap map[NodeID]*Node + nodesLock sync.RWMutex + nodes map[NodeID]*Node +} + +func gqltype(ctx *Context, t reflect.Type, node_type string) graphql.Type { + gql, err := ctx.GQLType(t, node_type) + if err != nil { + panic(err) + } else { + return gql + } } func (ctx *Context) GQLType(t reflect.Type, node_type string) (graphql.Type, error) { @@ -101,30 +118,37 @@ func (ctx *Context) GQLType(t reflect.Type, node_type string) (graphql.Type, err if node_type == "" { node_type = "Base" } - node_info, mapped := ctx.NodeTypes[node_type] + + interface_info, mapped := ctx.Interfaces[node_type] if mapped == false { - return nil, fmt.Errorf("Cannot get GQL type for unregistered Node Type \"%s\"", node_type) + type_info, mapped := ctx.NodeTypes[NodeTypeFor(node_type)] + if mapped { + return type_info.Type, nil + } else { + return nil, fmt.Errorf("Cannot get GQL type for unregistered Node Type \"%s\"", node_type) + } } else { - return node_info.Interface, nil + return interface_info.Type, nil } + } else { - info, mapped := ctx.TypeTypes[t] + info, mapped := ctx.Types[t] if mapped { return info.Type, nil } else { switch t.Kind() { case reflect.Array: - info, mapped := ctx.TypeTypes[t.Elem()] + info, mapped := ctx.Types[t.Elem()] if mapped { return graphql.NewList(info.Type), nil } case reflect.Slice: - info, mapped := ctx.TypeTypes[t.Elem()] + info, mapped := ctx.Types[t.Elem()] if mapped { return graphql.NewList(info.Type), nil } case reflect.Map: - info, exists := ctx.TypeTypes[t] + info, exists := ctx.Types[t] if exists { return info.Type, nil } else { @@ -132,7 +156,7 @@ func (ctx *Context) GQLType(t reflect.Type, node_type string) (graphql.Type, err if err != nil { return nil, err } - map_type := ctx.TypeTypes[t].Type + map_type := ctx.Types[t].Type return map_type, nil } case reflect.Pointer: @@ -217,12 +241,12 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type, node_type string) erro gql_map := graphql.NewList(gql_pair) serialized_type := SerializeType(reflect_type) - ctx.TypeMap[serialized_type] = &TypeInfo{ + ctx.Types[reflect_type] = &TypeInfo{ Serialized: serialized_type, Reflect: reflect_type, Type: gql_map, } - ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] + ctx.TypesReverse[serialized_type] = ctx.Types[reflect_type] return nil } @@ -231,15 +255,18 @@ func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, types := []graphql.Type{} ctx.Log.Logf("gql", "Building Schema") - for _, info := range(ctx.TypeMap) { + for _, info := range(ctx.Types) { if info.Type != nil { types = append(types, info.Type) } } - for _, info := range(ctx.Nodes) { + for _, info := range(ctx.NodeTypes) { types = append(types, info.Type) - types = append(types, info.Interface) + } + + for _, info := range(ctx.Interfaces) { + types = append(types, info.Type) } subscription := graphql.NewObject(graphql.ObjectConfig{ @@ -287,78 +314,69 @@ func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data interface{}) error { reflect_type := reflect.TypeFor[E]() - ext_type := ExtType(SerializedTypeFor[E]()) + ext_type := ExtTypeFor[E, T]() _, exists := ctx.Extensions[ext_type] if exists == true { return fmt.Errorf("Cannot register extension %+v of type %+v, type already exists in context", reflect_type, ext_type) } - fields := map[string]FieldInfo{} + fields := map[string]ExtensionFieldInfo{} for _, field := range(reflect.VisibleFields(reflect_type)) { gv_tag, tagged_gv := field.Tag.Lookup("gv") + node_tag := field.Tag.Get("node") if tagged_gv { - fields[gv_tag] = FieldInfo{ + fields[gv_tag] = ExtensionFieldInfo{ Index: field.Index, - Tag: gv_tag, - NodeTag: field.Tag.Get("node"), Type: field.Type, + NodeTag: node_tag, } } } - ctx.Extensions[ext_type] = &ExtensionInfo{ + ctx.Extensions[ext_type] = ExtensionInfo{ ExtType: ext_type, + Type: reflect_type, Data: data, Fields: fields, } - ctx.ExtensionTypes[reflect_type] = ctx.Extensions[ext_type] return nil } -func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { - node_type := NodeTypeFor(extensions) - _, exists := ctx.Nodes[node_type] - if exists == true { - return fmt.Errorf("Cannot register node type %+v, type already exists in context", node_type) +type FieldMapping struct { + Extension ExtType + Tag string +} + +func RegisterNodeInterface(ctx *Context, name string, fields map[string]graphql.Type) error { + _, exists := ctx.Interfaces[name] + if exists { + return fmt.Errorf("Cannot register Node Interface %s, already registered", name) } - fields := map[string]ExtType{} + gql_fields := graphql.Fields{ + "ID": &graphql.Field{ + Type: ctx.Types[reflect.TypeFor[NodeID]()].Type, + }, + "Type": &graphql.Field{ + Type: ctx.Types[reflect.TypeFor[NodeType]()].Type, + }, + } - ext_found := map[ExtType]bool{} - for _, extension := range(extensions) { - ext_info, in_ctx := ctx.Extensions[extension] - if in_ctx == false { - return fmt.Errorf("Cannot register node type %+v, required extension %+v not in context", name, extension) + for field_name, field_type := range(fields) { + _, exists := gql_fields[field_name] + if exists { + return fmt.Errorf("Cannot register interface %s with duplicate field %s", name, field_name) } - - _, duplicate := ext_found[extension] - if duplicate == true { - return fmt.Errorf("Duplicate extension %+v found in extension list", extension) - } - - ext_found[extension] = true - - for field_name := range(ext_info.Fields) { - _, exists := fields[field_name] - if exists { - return fmt.Errorf("Cannot register NodeType %s with duplicate field name %s", name, field_name) - } - fields[field_name] = extension + gql_fields[field_name] = &graphql.Field{ + Type: field_type, } } - - gql_interface := graphql.NewInterface(graphql.InterfaceConfig{ + + gql := graphql.NewInterface(graphql.InterfaceConfig{ Name: name, - Fields: graphql.Fields{ - "ID": &graphql.Field{ - Type: ctx.TypeTypes[reflect.TypeFor[NodeID]()].Type, - }, - "Type": &graphql.Field{ - Type: ctx.TypeTypes[reflect.TypeFor[NodeType]()].Type, - }, - }, + Fields: gql_fields, ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { ctx_val := p.Context.Value("resolve") ctx, ok := ctx_val.(*ResolveContext) @@ -372,36 +390,94 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { return nil } - node_info, exists := ctx.Context.Nodes[val.NodeType] + node_info, exists := ctx.Context.NodeTypes[val.NodeType] if exists == false { ctx.Context.Log.Logf("gql", "Interface ResolveType got bad NodeType", val.NodeType) return nil } - for _, ext_type := range(extensions) { - if slices.Contains(node_info.Extensions, ext_type) == false { - ctx.Context.Log.Logf("gql", "Interface ResolveType for %s missing extensions %s: %+v", name, ext_type, val) - return nil - } - } - return node_info.Type }, }) - gql := graphql.NewObject(graphql.ObjectConfig{ - Name: name + "Node", - Interfaces: ctx.GQLInterfaces(node_type, extensions), - Fields: graphql.Fields{ - "ID": &graphql.Field{ - Type: ctx.TypeTypes[reflect.TypeFor[NodeID]()].Type, - Resolve: ResolveNodeID, - }, - "Type": &graphql.Field{ - Type: ctx.TypeTypes[reflect.TypeFor[NodeType]()].Type, - Resolve: ResolveNodeType, - }, + ctx.Interfaces[name] = InterfaceInfo{ + Type: gql, + Fields: fields, + } + + return nil +} + +func RegisterNodeType(ctx *Context, name string, mappings map[string]FieldMapping) error { + node_type := NodeTypeFor(name) + _, exists := ctx.NodeTypes[node_type] + if exists { + return fmt.Errorf("Cannot register node type %s, already registered", name) + } + + fields := map[string]NodeFieldInfo{} + reverse_fields := map[ExtType]map[Tag]string{} + + gql_fields := graphql.Fields{ + "ID": &graphql.Field{ + Type: ctx.Types[reflect.TypeFor[NodeID]()].Type, + Resolve: ResolveNodeID, + }, + "Type": &graphql.Field{ + Type: ctx.Types[reflect.TypeFor[NodeType]()].Type, + Resolve: ResolveNodeType, }, + } + + ext_map := map[ExtType]bool{} + for field_name, mapping := range(mappings) { + _, duplicate := fields[field_name] + if duplicate { + return fmt.Errorf("Cannot register node type %s, contains duplicate field %s", name, field_name) + } + + ext_info, exists := ctx.Extensions[mapping.Extension] + if exists == false { + return fmt.Errorf("Cannot register node type %s, unknown extension %s", name, mapping.Extension) + } + + ext_map[mapping.Extension] = true + + ext_field, exists := ext_info.Fields[mapping.Tag] + if exists == false { + return fmt.Errorf("Cannot register node type %s, extension %s has no field %s", name, mapping.Extension, mapping.Tag) + } + + gql_type, err := ctx.GQLType(ext_field.Type, ext_field.NodeTag) + if err != nil { + return fmt.Errorf("Cannot register node type %s, GQLType error: %w", name, err) + } + + gql_resolve := ctx.GQLResolve(ext_field.Type, ext_field.NodeTag) + + fields[field_name] = NodeFieldInfo{ + Extension: mapping.Extension, + Index: ext_field.Index, + Type: gql_type, + } + + gql_fields[field_name] = &graphql.Field{ + Type: gql_type, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + node, ok := p.Source.(NodeResult) + if ok == false { + return nil, fmt.Errorf("Can't resolve Node field on non-Node %s", reflect.TypeOf(p.Source)) + } + + return gql_resolve(node.Data[field_name], p) + }, + } + } + + gql := graphql.NewObject(graphql.ObjectConfig{ + Name: name, + Interfaces: ctx.GQLInterfaces(fields), + Fields: gql_fields, IsTypeOf: func(p graphql.IsTypeOfParams) bool { source, ok := p.Value.(NodeResult) if ok == false { @@ -411,73 +487,41 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType) error { }, }) - ctx.Nodes[node_type] = &NodeInfo{ + extensions := []ExtType{} + for ext_type := range(ext_map) { + extensions = append(extensions, ext_type) + } + + ctx.NodeTypes[node_type] = NodeInfo{ NodeType: node_type, - Interface: gql_interface, Type: gql, - Extensions: extensions, Fields: fields, - } - ctx.NodeTypes[name] = ctx.Nodes[node_type] - - for _, ext_type := range(extensions) { - ext_info, ext_found := ctx.Extensions[ext_type] - if ext_found == false { - return fmt.Errorf("Extension %s not found", ext_type) - } - - for field_name, field_info := range(ext_info.Fields) { - gql_type, err := ctx.GQLType(field_info.Type, field_info.NodeTag) - if err != nil { - return err - } - - gql_resolve := ctx.GQLResolve(field_info.Type, field_info.NodeTag) - - gql_interface.AddFieldConfig(field_name, &graphql.Field{ - Type: gql_type, - }) - - gql.AddFieldConfig(field_name, &graphql.Field{ - Type: gql_type, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - node, ok := p.Source.(NodeResult) - if ok == false { - return nil, fmt.Errorf("Can't resolve Node field on non-Node %s", reflect.TypeOf(p.Source)) - } - - node_info, mapped := ctx.Nodes[node.NodeType] - if mapped == false { - return nil, fmt.Errorf("Can't resolve unknown NodeType %s", node.NodeType) - } - - return gql_resolve(node.Data[node_info.Fields[field_name]][field_name], p) - }, - }) - } - + ReverseFields: reverse_fields, + RequiredExtensions: extensions, } return nil } -func (ctx *Context) GQLInterfaces(known_type NodeType, extensions []ExtType) graphql.InterfacesThunk { +// Returns a function which returns a list of interfaces from the context whose fields are a subset of fields +func (ctx *Context) GQLInterfaces(fields map[string]NodeFieldInfo) graphql.InterfacesThunk { return func() []*graphql.Interface { interfaces := []*graphql.Interface{} - for node_type, node_info := range(ctx.Nodes) { - if node_type != known_type { - has_ext := true - for _, ext := range(node_info.Extensions) { - if slices.Contains(extensions, ext) == false { - has_ext = false - break - } - } - if has_ext == false { - continue + for _, interface_info := range(ctx.Interfaces) { + match := true + for field_name, field_type := range(interface_info.Fields) { + field, exists := fields[field_name] + if exists == false { + match = false + break + } else if field.Type != field_type { + match = false + break } } - interfaces = append(interfaces, node_info.Interface) + if match { + interfaces = append(interfaces, interface_info.Type) + } } return interfaces } @@ -491,7 +535,7 @@ func RegisterObject[T any](ctx *Context) error { reflect_type := reflect.TypeFor[T]() serialized_type := SerializedTypeFor[T]() - _, exists := ctx.TypeTypes[reflect_type] + _, exists := ctx.Types[reflect_type] if exists { return fmt.Errorf("%+v already registered in TypeMap", reflect_type) } @@ -505,7 +549,7 @@ func RegisterObject[T any](ctx *Context) error { Fields: graphql.Fields{}, }) - field_infos := map[FieldTag]FieldInfo{} + field_infos := map[FieldTag]StructFieldInfo{} post_deserialize, post_deserialize_exists := reflect.PointerTo(reflect_type).MethodByName("PostDeserialize") post_deserialize_index := -1 @@ -517,12 +561,11 @@ func RegisterObject[T any](ctx *Context) error { gv_tag, tagged_gv := field.Tag.Lookup("gv") if tagged_gv { node_tag := field.Tag.Get("node") - field_infos[GetFieldTag(gv_tag)] = FieldInfo{ + field_infos[GetFieldTag(gv_tag)] = StructFieldInfo{ Type: field.Type, Index: field.Index, - NodeTag: node_tag, - Tag: gv_tag, } + gql_type, err := ctx.GQLType(field.Type, node_tag) if err != nil { return err @@ -552,14 +595,14 @@ func RegisterObject[T any](ctx *Context) error { } } - ctx.TypeMap[serialized_type] = &TypeInfo{ + ctx.Types[reflect_type] = &TypeInfo{ PostDeserializeIndex: post_deserialize_index, Serialized: serialized_type, Reflect: reflect_type, Fields: field_infos, Type: gql, } - ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] + ctx.TypesReverse[serialized_type] = ctx.Types[reflect_type] return nil } @@ -568,12 +611,12 @@ func RegisterObjectNoGQL[T any](ctx *Context) error { reflect_type := reflect.TypeFor[T]() serialized_type := SerializedTypeFor[T]() - _, exists := ctx.TypeTypes[reflect_type] + _, exists := ctx.Types[reflect_type] if exists { return fmt.Errorf("%+v already registered in TypeMap", reflect_type) } - field_infos := map[FieldTag]FieldInfo{} + field_infos := map[FieldTag]StructFieldInfo{} post_deserialize, post_deserialize_exists := reflect.PointerTo(reflect_type).MethodByName("PostDeserialize") post_deserialize_index := -1 @@ -584,24 +627,21 @@ func RegisterObjectNoGQL[T any](ctx *Context) error { for _, field := range(reflect.VisibleFields(reflect_type)) { gv_tag, tagged_gv := field.Tag.Lookup("gv") if tagged_gv { - node_tag := field.Tag.Get("node") - field_infos[GetFieldTag(gv_tag)] = FieldInfo{ + field_infos[GetFieldTag(gv_tag)] = StructFieldInfo{ Type: field.Type, Index: field.Index, - NodeTag: node_tag, - Tag: gv_tag, } } } - ctx.TypeMap[serialized_type] = &TypeInfo{ + ctx.Types[reflect_type] = &TypeInfo{ PostDeserializeIndex: post_deserialize_index, Serialized: serialized_type, Reflect: reflect_type, Fields: field_infos, Type: nil, } - ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] + ctx.TypesReverse[serialized_type] = ctx.Types[reflect_type] return nil } @@ -726,7 +766,7 @@ func RegisterEnum[E comparable](ctx *Context, str_map map[E]string) error { reflect_type := reflect.TypeFor[E]() serialized_type := SerializedTypeFor[E]() - _, exists := ctx.TypeTypes[reflect_type] + _, exists := ctx.Types[reflect_type] if exists { return fmt.Errorf("%+v already registered in TypeMap", reflect_type) } @@ -745,21 +785,21 @@ func RegisterEnum[E comparable](ctx *Context, str_map map[E]string) error { Values: value_config, }) - ctx.TypeMap[serialized_type] = &TypeInfo{ + ctx.Types[reflect_type] = &TypeInfo{ Serialized: serialized_type, Reflect: reflect_type, Type: gql, } - ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] + ctx.TypesReverse[serialized_type] = ctx.Types[reflect_type] return nil } -func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}) error { +func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, from_json func(interface{})interface{}, from_ast func(ast.Value)interface{}, serialize SerializeFn, deserialize DeserializeFn) error { reflect_type := reflect.TypeFor[S]() serialized_type := SerializedTypeFor[S]() - _, exists := ctx.TypeTypes[reflect_type] + _, exists := ctx.Types[reflect_type] if exists { return fmt.Errorf("%+v already registered in TypeMap", reflect_type) } @@ -772,26 +812,29 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f ParseLiteral: from_ast, }) - ctx.TypeMap[serialized_type] = &TypeInfo{ + ctx.Types[reflect_type] = &TypeInfo{ Serialized: serialized_type, Reflect: reflect_type, Type: gql, + + Serialize: serialize, + Deserialize: deserialize, } - ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] + ctx.TypesReverse[serialized_type] = ctx.Types[reflect_type] return nil } func (ctx *Context) AddNode(id NodeID, node *Node) { - ctx.nodeMapLock.Lock() - ctx.nodeMap[id] = node - ctx.nodeMapLock.Unlock() + ctx.nodesLock.Lock() + ctx.nodes[id] = node + ctx.nodesLock.Unlock() } func (ctx *Context) Node(id NodeID) (*Node, bool) { - ctx.nodeMapLock.RLock() - node, exists := ctx.nodeMap[id] - ctx.nodeMapLock.RUnlock() + ctx.nodesLock.RLock() + node, exists := ctx.nodes[id] + ctx.nodesLock.RUnlock() return node, exists } @@ -805,25 +848,25 @@ func (ctx *Context) Delete(id NodeID) error { } func (ctx *Context) Unload(id NodeID) error { - ctx.nodeMapLock.Lock() - defer ctx.nodeMapLock.Unlock() - node, exists := ctx.nodeMap[id] + ctx.nodesLock.Lock() + defer ctx.nodesLock.Unlock() + node, exists := ctx.nodes[id] if exists == false { return fmt.Errorf("%s is not a node in ctx", id) } err := node.Unload(ctx) - delete(ctx.nodeMap, id) + delete(ctx.nodes, id) return err } func (ctx *Context) Stop() { - ctx.nodeMapLock.Lock() - for id, node := range(ctx.nodeMap) { + ctx.nodesLock.Lock() + for id, node := range(ctx.nodes) { node.Unload(ctx) - delete(ctx.nodeMap, id) + delete(ctx.nodes, id) } - ctx.nodeMapLock.Unlock() + ctx.nodesLock.Unlock() } func (ctx *Context) Load(id NodeID) (*Node, error) { @@ -835,7 +878,6 @@ func (ctx *Context) Load(id NodeID) (*Node, error) { ctx.AddNode(id, node) started := make(chan error, 1) go runNode(ctx, node, started) - err = <- started if err != nil { return nil, err @@ -852,7 +894,7 @@ func (ctx *Context) getNode(id NodeID) (*Node, error) { var err error target, err = ctx.Load(id) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to load node %s: %w", id, err) } } return target, nil @@ -930,183 +972,203 @@ func (ctx *Context)GQLResolve(t reflect.Type, node_type string) (func(interface{ } } -func RegisterInterface[T any](ctx *Context) error { - serialized_type := SerializeType(reflect.TypeFor[T]()) - reflect_type := reflect.TypeFor[T]() - - _, exists := ctx.Interfaces[serialized_type] - if exists == true { - return fmt.Errorf("Interface %+v already exists in context", reflect_type) - } - - ctx.Interfaces[serialized_type] = &InterfaceInfo{ - Serialized: serialized_type, - Reflect: reflect_type, - } - ctx.InterfaceTypes[reflect_type] = ctx.Interfaces[serialized_type] - - return nil -} - // Create a new Context with the base library content added func NewContext(db * badger.DB, log Logger) (*Context, error) { + uuid.EnableRandPool() + ctx := &Context{ DB: db, Log: log, - TypeMap: map[SerializedType]*TypeInfo{}, - TypeTypes: map[reflect.Type]*TypeInfo{}, - - Extensions: map[ExtType]*ExtensionInfo{}, - ExtensionTypes: map[reflect.Type]*ExtensionInfo{}, - - Interfaces: map[SerializedType]*InterfaceInfo{}, - InterfaceTypes: map[reflect.Type]*InterfaceInfo{}, - - Nodes: map[NodeType]*NodeInfo{}, - NodeTypes: map[string]*NodeInfo{}, + Types: map[reflect.Type]*TypeInfo{}, + TypesReverse: map[SerializedType]*TypeInfo{}, + Extensions: map[ExtType]ExtensionInfo{}, + Interfaces: map[string]InterfaceInfo{}, + NodeTypes: map[NodeType]NodeInfo{}, - nodeMap: map[NodeID]*Node{}, + nodes: map[NodeID]*Node{}, } var err error - err = RegisterScalar[NodeID](ctx, stringify, unstringify[NodeID], unstringifyAST[NodeID]) + err = RegisterScalar[NodeID](ctx, stringify, unstringify[NodeID], unstringifyAST[NodeID], + func(ctx *Context, value reflect.Value) ([]byte, error) { + return value.Bytes(), nil + }, func(ctx *Context, data []byte) (reflect.Value, []byte, error) { + if len(data) < 16 { + return reflect.Value{}, nil, fmt.Errorf("Not enough bytes to decode NodeID(got %d, want 16)", len(data)) + } + + id := new(NodeID) + err := id.UnmarshalBinary(data[0:16]) + if err != nil { + return reflect.Value{}, nil, err + } + + return reflect.ValueOf(id).Elem(), data[16:], nil + }) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register NodeID: %w", err) } - - err = RegisterInterface[Extension](ctx) + + err = RegisterScalar[uuid.UUID](ctx, stringify, unstringify[uuid.UUID], unstringifyAST[uuid.UUID], + func(ctx *Context, value reflect.Value) ([]byte, error) { + return value.Bytes(), nil + }, func(ctx *Context, data []byte) (reflect.Value, []byte, error) { + if len(data) < 16 { + return reflect.Value{}, nil, fmt.Errorf("Not enough bytes to decode uuid.UUID(got %d, want 16)", len(data)) + } + + id := new(uuid.UUID) + err := id.UnmarshalBinary(data[0:16]) + if err != nil { + return reflect.Value{}, nil, err + } + + return reflect.ValueOf(id).Elem(), data[16:], nil + }) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register uuid.UUID: %w", err) } - - err = RegisterInterface[Signal](ctx) + + err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register NodeType: %w", err) } - err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType]) + err = RegisterScalar[ExtType](ctx, identity, coerce[ExtType], astInt[ExtType], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register ExtType: %w", err) } - err = RegisterScalar[ExtType](ctx, identity, coerce[ExtType], astInt[ExtType]) + err = RegisterNodeInterface(ctx, "Base", map[string]graphql.Type{}) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register NodeInterface Base: %w", err) } - err = RegisterNodeType(ctx, "Base", []ExtType{}) + err = RegisterNodeType(ctx, "Node", map[string]FieldMapping{}) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register NodeType Node: %w", err) } - err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool]) + err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register bool: %w", err) } - err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int]) + err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register int: %w", err) } - err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32]) + err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register uint32: %w", err) } - err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8]) + err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register uint8: %w", err) } - err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time]) + err = RegisterScalar[time.Time](ctx, stringify, unstringify[time.Time], unstringifyAST[time.Time], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register time.Time: %w", err) } - err = RegisterScalar[string](ctx, identity, coerce[string], astString[string]) + err = RegisterScalar[string](ctx, identity, coerce[string], astString[string], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register string: %w", err) } err = RegisterEnum[ReqState](ctx, ReqStateStrings) if err != nil { - return nil, err - } - - err = RegisterScalar[uuid.UUID](ctx, stringify, unstringify[uuid.UUID], unstringifyAST[uuid.UUID]) - if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register ReqState: %w", err) } - err = RegisterScalar[Change](ctx, identity, coerce[Change], astString[Change]) + err = RegisterScalar[Tag](ctx, identity, coerce[Tag], astString[Tag], nil, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register Tag: %w", err) } // TODO: Register as a GQL type with Signal as an interface err = RegisterObjectNoGQL[QueuedSignal](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register QueuedSignal: %w", err) } err = RegisterSignal[TimeoutSignal](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register TimeoutSignal: %w", err) } err = RegisterSignal[StatusSignal](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register StatusSignal: %w", err) } err = RegisterObject[Node](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register Node: %w", err) } err = RegisterExtension[LockableExt](ctx, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register LockableExt extension: %w", err) } err = RegisterExtension[ListenerExt](ctx, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register ListenerExt extension: %w", err) } err = RegisterExtension[GQLExt](ctx, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register GQLExt extension: %w", err) } - err = RegisterNodeType(ctx, "Lockable", []ExtType{ExtTypeFor[LockableExt]()}) + err = RegisterNodeInterface(ctx, "Lockable", map[string]graphql.Type{ + "LockableState": gqltype(ctx, reflect.TypeFor[ReqState](), ""), + "Requirements": gqltype(ctx, reflect.TypeFor[map[NodeID]ReqState](), ":Lockable"), + }) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register NodeInterface Lockable: %w", err) + } + + err = RegisterNodeType(ctx, "LockableNode", map[string]FieldMapping{ + "LockableState": { + Extension: ExtTypeFor[LockableExt](), + Tag: "state", + }, + "Requirements": { + Extension: ExtTypeFor[LockableExt](), + Tag: "requirements", + }, + }) + if err != nil { + return nil, fmt.Errorf("Failed to register NodeType LockableNode: %w", err) } err = RegisterObject[LockableExt](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register LockableExt object: %w", err) } err = RegisterObject[ListenerExt](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register ListenerExt object: %w", err) } err = RegisterObject[GQLExt](ctx) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to register GQLExt object: %w", err) } schema, err := BuildSchema(ctx, graphql.NewObject(graphql.ObjectConfig{ Name: "Query", Fields: graphql.Fields{ "Self": &graphql.Field{ - Type: ctx.NodeTypes["Base"].Interface, + Type: ctx.Interfaces["Base"].Type, Resolve: func(p graphql.ResolveParams) (interface{}, error) { ctx, err := PrepResolve(p) if err != nil { @@ -1117,10 +1179,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { }, }, "Node": &graphql.Field{ - Type: ctx.NodeTypes["Base"].Interface, + Type: ctx.Interfaces["Base"].Type, Args: graphql.FieldConfigArgument{ "id": &graphql.ArgumentConfig{ - Type: ctx.TypeTypes[reflect.TypeFor[NodeID]()].Type, + Type: ctx.Types[reflect.TypeFor[NodeID]()].Type, }, }, Resolve: func(p graphql.ResolveParams) (interface{}, error) { @@ -1145,10 +1207,12 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { }, })) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to build schema: %w", err) } - ctx.ExtensionTypes[reflect.TypeFor[GQLExt]()].Data = schema + ext_info := ctx.Extensions[ExtTypeFor[GQLExt]()] + ext_info.Data = schema + ctx.Extensions[ExtTypeFor[GQLExt]()] = ext_info return ctx, nil } diff --git a/db.go b/db.go index 89ba21e..d53f209 100644 --- a/db.go +++ b/db.go @@ -3,6 +3,7 @@ package graphvent import ( "encoding/binary" "fmt" + "reflect" badger "github.com/dgraph-io/badger/v3" ) @@ -59,11 +60,11 @@ func WriteNodeInit(ctx *Context, node *Node) error { for ext_type, ext := range(node.Extensions) { // Write each extension's current value ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) - ext_val, err := Serialize(ctx, ext) + ext_ser, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem()) if err != nil { return err } - err = tx.Set(ext_id, ext_val) + err = tx.Set(ext_id, ext_ser) } return nil }) @@ -74,7 +75,7 @@ func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) err // Get the base key bytes id_ser, err := node.ID.MarshalBinary() if err != nil { - return err + return fmt.Errorf("Marshal ID error: %+w", err) } // Write the signal queue if it needs to be written @@ -84,11 +85,11 @@ func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) err sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...) sigqueue_val, err := Serialize(ctx, node.SignalQueue) if err != nil { - return err + return fmt.Errorf("SignalQueue Serialize Error: %+v, %w", node.SignalQueue, err) } err = tx.Set(sigqueue_id, sigqueue_val) if err != nil { - return err + return fmt.Errorf("SignalQueue set error: %+v, %w", node.SignalQueue, err) } } @@ -100,14 +101,14 @@ func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) err return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID) } ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type)) - ext_ser, err := Serialize(ctx, ext) + ext_ser, err := SerializeValue(ctx, reflect.ValueOf(ext).Elem()) if err != nil { - return err + return fmt.Errorf("Extension serialize err: %s, %w", reflect.TypeOf(ext), err) } err = tx.Set(ext_id, ext_ser) if err != nil { - return err + return fmt.Errorf("Extension set err: %s, %w", reflect.TypeOf(ext), err) } } return nil @@ -173,10 +174,25 @@ func LoadNode(ctx *Context, id NodeID) (*Node, error) { return err } + ext_info, exists := ctx.Extensions[ext_type] + if exists == false { + return fmt.Errorf("Extension %s not in context", ext_type) + } + var ext Extension + var ok bool err = ext_item.Value(func(val []byte) error { - ext, err = Deserialize[Extension](ctx, val) - return err + value, _, err := DeserializeValue(ctx, val, ext_info.Type) + if err != nil { + return err + } + + ext, ok = value.Addr().Interface().(Extension) + if ok == false { + return fmt.Errorf("Parsed value %+v is not extension", value.Type()) + } + + return nil }) if err != nil { return err diff --git a/extension.go b/extension.go index 0370a0f..cc700a2 100644 --- a/extension.go +++ b/extension.go @@ -1,7 +1,7 @@ package graphvent -type Change string -type Changes []Change +type Tag string +type Changes []Tag // Extensions are data attached to nodes that process signals type Extension interface { diff --git a/gql.go b/gql.go index d76c850..00f8aeb 100644 --- a/gql.go +++ b/gql.go @@ -501,7 +501,7 @@ type Field struct { type NodeResult struct { NodeID NodeID NodeType NodeType - Data map[ExtType]map[string]interface{} + Data map[string]interface{} } type ListField struct { diff --git a/gql_node.go b/gql_node.go index 292f99b..fe4a884 100644 --- a/gql_node.go +++ b/gql_node.go @@ -30,33 +30,21 @@ type FieldIndex struct { Tag string } -func GetFields(ctx *Context, node_type string, selection_set *ast.SelectionSet) []FieldIndex { - names := []FieldIndex{} +func GetFields(selection_set *ast.SelectionSet) []string { + names := []string{} if selection_set == nil { return names } - node_info, mapped := ctx.NodeTypes[node_type] - if mapped == false { - return nil - } - for _, sel := range(selection_set.Selections) { switch field := sel.(type) { case *ast.Field: if field.Name.Value == "ID" || field.Name.Value == "Type" { continue } - - extension, mapped := node_info.Fields[field.Name.Value] - if mapped == false { - continue - } - names = append(names, FieldIndex{extension, field.Name.Value}) + names = append(names, field.Name.Value) case *ast.InlineFragment: - names = append(names, GetFields(ctx, field.TypeCondition.Name.Value, field.SelectionSet)...) - default: - ctx.Log.Logf("gql", "Unknown selection type: %s", reflect.TypeOf(field)) + names = append(names, GetFields(field.SelectionSet)...) } } @@ -64,10 +52,10 @@ func GetFields(ctx *Context, node_type string, selection_set *ast.SelectionSet) } // Returns the fields that need to be resolved -func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) []FieldIndex { - fields := []FieldIndex{} +func GetResolveFields(p graphql.ResolveParams) []string { + fields := []string{} for _, field := range(p.Info.FieldASTs) { - fields = append(fields, GetFields(ctx.Context, p.Info.ReturnType.Name(), field.SelectionSet)...) + fields = append(fields, GetFields(field.SelectionSet)...) } return fields @@ -83,13 +71,10 @@ func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) { case *StatusSignal: cached_node, cached := ctx.NodeCache[source.Source] if cached { - for ext_type, ext_changes := range(source.Changes) { - cached_ext, cached := cached_node.Data[ext_type] + for _, field_name := range(source.Fields) { + _, cached := cached_node.Data[field_name] if cached { - for _, field := range(ext_changes) { - delete(cached_ext, string(field)) - } - cached_node.Data[ext_type] = cached_ext + delete(cached_node.Data, field_name) } } ctx.NodeCache[source.Source] = cached_node @@ -97,25 +82,22 @@ func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) { } cache, node_cached := ctx.NodeCache[id] - fields := GetResolveFields(id, ctx, p) - not_cached := map[ExtType][]string{} - for _, field := range(fields) { - ext_fields, exists := not_cached[field.Extension] - if exists == false { - ext_fields = []string{} - } - - if node_cached { - ext_cache, ext_cached := cache.Data[field.Extension] - if ext_cached { - _, field_cached := ext_cache[field.Tag] + fields := GetResolveFields(p) + var not_cached []string + if node_cached { + not_cached = []string{} + for _, field := range(fields) { + if node_cached { + _, field_cached := cache.Data[field] if field_cached { continue } } - } - not_cached[field.Extension] = append(ext_fields, field.Tag) + not_cached = append(not_cached, field) + } + } else { + not_cached = fields } if (len(not_cached) == 0) && (node_cached == true) { @@ -148,20 +130,11 @@ func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) { cache = NodeResult{ NodeID: id, NodeType: response.NodeType, - Data: response.Extensions, + Data: response.Fields, } } else { - for ext_type, ext_data := range(response.Extensions) { - cached_ext, ext_cached := cache.Data[ext_type] - if ext_cached { - for field_name, field := range(ext_data) { - cache.Data[ext_type][field_name] = field - } - } else { - cache.Data[ext_type] = ext_data - } - - cache.Data[ext_type] = cached_ext + for field_name, field_value := range(response.Fields) { + cache.Data[field_name] = field_value } } diff --git a/gql_test.go b/gql_test.go index 09eb9df..b9f1de7 100644 --- a/gql_test.go +++ b/gql_test.go @@ -19,7 +19,7 @@ import ( func TestGQLSubscribe(t *testing.T) { ctx := logTestContext(t, []string{"test", "gql"}) - n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) + n1, err := NewNode(ctx, nil, "LockableNode", 10, NewLockableExt(nil)) fatalErr(t, err) listener_ext := NewListenerExt(10) @@ -27,10 +27,10 @@ func TestGQLSubscribe(t *testing.T) { gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) - gql, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext) + gql, err := NewNode(ctx, nil, "LockableNode", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, listener_ext) fatalErr(t, err) - query := "subscription { Self { ID, Type ... on Lockable { lockable_state } } }" + query := "subscription { Self { ID, Type ... on Lockable { LockableState } } }" ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "Node: %s", n1.ID) @@ -129,14 +129,14 @@ func TestGQLQuery(t *testing.T) { ctx := logTestContext(t, []string{"test", "lockable"}) n1_listener := NewListenerExt(10) - n1, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil), n1_listener) + n1, err := NewNode(ctx, nil, "LockableNode", 10, NewLockableExt(nil), n1_listener) fatalErr(t, err) gql_listener := NewListenerExt(10) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) - gql, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener) + gql, err := NewNode(ctx, nil, "LockableNode", 10, NewLockableExt([]NodeID{n1.ID}), gql_ext, gql_listener) fatalErr(t, err) ctx.Log.Logf("test", "GQL: %s", gql.ID) @@ -150,14 +150,14 @@ func TestGQLQuery(t *testing.T) { url := fmt.Sprintf("http://localhost:%d/gql", port) req_1 := GQLPayload{ - Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { lockable_state } } }", + Query: "query Node($id:graphvent_NodeID) { Node(id:$id) { ID, Type, ... on Lockable { LockableState } } }", Variables: map[string]interface{}{ "id": n1.ID.String(), }, } req_2 := GQLPayload{ - Query: "query Self { Self { ID, Type, ... on Lockable { lockable_state, requirements { Key { ID ... on Lockable { lockable_state } } } } } }", + Query: "query Self { Self { ID, Type, ... on Lockable { LockableState, Requirements { Key { ID ... on Lockable { LockableState } } } } } }", } SendGQL := func(payload GQLPayload) []byte { @@ -208,7 +208,7 @@ func TestGQLDB(t *testing.T) { fatalErr(t, err) listener_ext := NewListenerExt(10) - gql, err := NewNode(ctx, nil, "Base", 10, gql_ext, listener_ext) + gql, err := NewNode(ctx, nil, "Node", 10, gql_ext, listener_ext) fatalErr(t, err) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) diff --git a/graph_test.go b/graph_test.go index 5794351..5e8b6a1 100644 --- a/graph_test.go +++ b/graph_test.go @@ -11,7 +11,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt, error) { listener_extension := NewListenerExt(buffer) listener, err := NewNode(ctx, nil, - "LockableListener", + "LockableNode", 10, nil, listener_extension, @@ -29,9 +29,6 @@ func logTestContext(t * testing.T, components []string) *Context { ctx, err := NewContext(db, NewConsoleLogger(components)) fatalErr(t, err) - err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()}) - fatalErr(t, err) - return ctx } diff --git a/listener.go b/listener.go index fbdc4a5..719b6db 100644 --- a/listener.go +++ b/listener.go @@ -60,7 +60,7 @@ func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal } switch sig := signal.(type) { case *StatusSignal: - ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Changes) + ctx.Log.Logf("listener_status", "%s - %+v", sig.Source, sig.Fields) } return nil, nil } diff --git a/lockable.go b/lockable.go index d470655..7f59f77 100644 --- a/lockable.go +++ b/lockable.go @@ -31,7 +31,7 @@ func (state ReqState) String() string { } type LockableExt struct{ - State ReqState `gv:"lockable_state"` + State ReqState `gv:"state"` ReqID *uuid.UUID `gv:"req_id"` Owner *NodeID `gv:"owner"` PendingOwner *NodeID `gv:"pending_owner"` @@ -129,7 +129,7 @@ func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source Node messages = append(messages, SendMsg{source, NewErrorSignal(signal.Id, "not_owner")}) } else { if len(ext.Requirements) == 0 { - changes = append(changes, "lockable_state", "owner", "pending_owner") + changes = append(changes, "state", "owner", "pending_owner") ext.Owner = nil @@ -139,7 +139,7 @@ func (ext *LockableExt) HandleUnlockSignal(ctx *Context, node *Node, source Node messages = append(messages, SendMsg{source, NewSuccessSignal(signal.Id)}) } else { - changes = append(changes, "lockable_state", "waiting", "requirements", "pending_owner") + changes = append(changes, "state", "waiting", "requirements", "pending_owner") ext.PendingOwner = nil @@ -173,7 +173,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID switch ext.State { case Unlocked: if len(ext.Requirements) == 0 { - changes = append(changes, "lockable_state", "owner", "pending_owner") + changes = append(changes, "state", "owner", "pending_owner") ext.Owner = new(NodeID) *ext.Owner = source @@ -184,7 +184,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.State = Locked messages = append(messages, SendMsg{source, NewSuccessSignal(signal.Id)}) } else { - changes = append(changes, "lockable_state", "requirements", "waiting", "pending_owner") + changes = append(changes, "state", "requirements", "waiting", "pending_owner") ext.PendingOwner = new(NodeID) *ext.PendingOwner = source @@ -221,7 +221,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI switch ext.State { case Locking: - changes = append(changes, "lockable_state", "requirements") + changes = append(changes, "state", "requirements") ext.Requirements[id] = Unlocked @@ -242,11 +242,11 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI } if unlocked == len(ext.Requirements) { - changes = append(changes, "owner", "lockable_state") + changes = append(changes, "owner", "state") ext.State = Unlocked ext.Owner = nil } else { - changes = append(changes, "lockable_state") + changes = append(changes, "state") ext.State = AbortingLock } @@ -271,7 +271,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI } if unlocked == len(ext.Requirements) { - changes = append(changes, "owner", "lockable_state") + changes = append(changes, "owner", "state") ext.State = Unlocked ext.Owner = nil } @@ -309,7 +309,8 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod } if locked == len(ext.Requirements) { - changes = append(changes, "lockable_state", "owner", "req_id") + ctx.Log.Logf("lockable", "%s FULL_LOCK: %d", node.ID, locked) + changes = append(changes, "state", "owner", "req_id") ext.State = Locked ext.Owner = new(NodeID) @@ -317,6 +318,8 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)}) ext.ReqID = nil + } else { + ctx.Log.Logf("lockable", "%s PARTIAL_LOCK: %d/%d", node.ID, locked, len(ext.Requirements)) } case AbortingLock: req_state := ext.Requirements[id] @@ -338,7 +341,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod } if unlocked == len(ext.Requirements) { - changes = append(changes, "lockable_state", "pending_owner", "req_id") + changes = append(changes, "state", "pending_owner", "req_id") messages = append(messages, SendMsg{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked: %s", ext.State)}) ext.State = Unlocked @@ -359,7 +362,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod } if unlocked == len(ext.Requirements) { - changes = append(changes, "lockable_state", "owner", "req_id") + changes = append(changes, "state", "owner", "req_id") messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(*ext.ReqID)}) ext.State = Unlocked diff --git a/lockable_test.go b/lockable_test.go index 8f7728c..32d0c8e 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -10,12 +10,12 @@ func TestLink(t *testing.T) { l2_listener := NewListenerExt(10) - l2, err := NewNode(ctx, nil, "Lockable", 10, l2_listener, NewLockableExt(nil)) + l2, err := NewNode(ctx, nil, "LockableNode", 10, l2_listener, NewLockableExt(nil)) fatalErr(t, err) l1_lockable := NewLockableExt(nil) l1_listener := NewListenerExt(10) - l1, err := NewNode(ctx, nil, "Lockable", 10, l1_listener, l1_lockable) + l1, err := NewNode(ctx, nil, "LockableNode", 10, l1_listener, l1_lockable) fatalErr(t, err) link_signal := NewLinkSignal("add", l2.ID) @@ -42,24 +42,24 @@ func TestLink(t *testing.T) { fatalErr(t, err) } -func Test1000Lock(t *testing.T) { +func Test10000Lock(t *testing.T) { ctx := logTestContext(t, []string{"test"}) NewLockable := func()(*Node) { - l, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil)) + l, err := NewNode(ctx, nil, "LockableNode", 10, NewLockableExt(nil)) fatalErr(t, err) return l } - reqs := make([]NodeID, 1000) + reqs := make([]NodeID, 10000) for i := range(reqs) { new_lockable := NewLockable() reqs[i] = new_lockable.ID } - ctx.Log.Logf("test", "CREATED_1000") + ctx.Log.Logf("test", "CREATED_10000") - listener := NewListenerExt(5000) - node, err := NewNode(ctx, nil, "Lockable", 5000, listener, NewLockableExt(reqs)) + listener := NewListenerExt(50000) + node, err := NewNode(ctx, nil, "LockableNode", 500000, listener, NewLockableExt(reqs)) fatalErr(t, err) ctx.Log.Logf("test", "CREATED_LISTENER") @@ -75,15 +75,15 @@ func Test1000Lock(t *testing.T) { t.Fatalf("Unexpected response to lock - %s", resp) } - ctx.Log.Logf("test", "LOCKED_1000") + ctx.Log.Logf("test", "LOCKED_10000") } func TestLock(t *testing.T) { ctx := logTestContext(t, []string{"test", "lockable"}) NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { - listener := NewListenerExt(1000) - l, err := NewNode(ctx, nil, "Lockable", 10, listener, NewLockableExt(reqs)) + listener := NewListenerExt(10000) + l, err := NewNode(ctx, nil, "LockableNode", 10, listener, NewLockableExt(reqs)) fatalErr(t, err) return l, listener } @@ -112,7 +112,7 @@ func TestLock(t *testing.T) { ctx.Log.Logf("test", "locking l1") id_2, err := LockLockable(ctx, l1) fatalErr(t, err) - response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*1000, id_2) + response, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10000, id_2) fatalErr(t, err) ctx.Log.Logf("test", "l1 lock: %+v", response) diff --git a/node.go b/node.go index 2b5d881..adfa72f 100644 --- a/node.go +++ b/node.go @@ -24,6 +24,10 @@ type NodeID uuid.UUID func (id NodeID) MarshalBinary() ([]byte, error) { return (uuid.UUID)(id).MarshalBinary() } +func (id *NodeID) UnmarshalBinary(data []byte) error { + return (*uuid.UUID)(id).UnmarshalBinary(data) +} + func (id NodeID) String() string { return (uuid.UUID)(id).String() } @@ -140,7 +144,13 @@ func SoonestSignal(signals []QueuedSignal) (*QueuedSignal, <-chan time.Time) { } if soonest_signal != nil { - return soonest_signal, time.After(time.Until(soonest_signal.Time)) + if time.Now().Compare(soonest_time) == -1 { + return soonest_signal, time.After(time.Until(soonest_signal.Time)) + } else { + c := make(chan time.Time, 1) + c <- soonest_time + return soonest_signal, c + } } else { return nil, nil } @@ -166,25 +176,23 @@ func (err StringError) MarshalBinary() ([]byte, error) { return []byte(string(err)), nil } -func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]any { - ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID) - exts := map[ExtType]map[string]any{} - for ext_type, field_reqs := range(reqs) { - ext_info, ext_known := ctx.Extensions[ext_type] - if ext_known { - fields := map[string]any{} - for _, req := range(field_reqs) { - ext, exists := node.Extensions[ext_type] - if exists == false { - fields[req] = fmt.Errorf("%+v does not have %+v extension", node.ID, ext_type) - } else { - fields[req] = reflect.ValueOf(ext).Elem().FieldByIndex(ext_info.Fields[req].Index).Interface() - } - } - exts[ext_type] = fields +func (node *Node) ReadFields(ctx *Context, fields []string)map[string]any { + ctx.Log.Logf("read_field", "Reading %+v on %+v", fields, node.ID) + values := map[string]any{} + + node_info := ctx.NodeTypes[node.Type] + + for _, field_name := range(fields) { + field_info, mapped := node_info.Fields[field_name] + if mapped { + ext := node.Extensions[field_info.Extension] + values[field_name] = reflect.ValueOf(ext).Elem().FieldByIndex(field_info.Index).Interface() + } else { + values[field_name] = fmt.Errorf("NodeType %s has no field %s", node.Type, field_name) } } - return exts + + return values } // Main Loop for nodes @@ -196,15 +204,22 @@ func nodeLoop(ctx *Context, node *Node, started chan error) error { ctx.Log.Logf("node", "Set %s active", node.ID) } + ctx.Log.Logf("node_ext", "Loading extensions for %s", node.ID) + for _, extension := range(node.Extensions) { + ctx.Log.Logf("node_ext", "Loading extension %s for %s", reflect.TypeOf(extension), node.ID) err := extension.Load(ctx, node) if err != nil { + ctx.Log.Logf("node_ext", "Failed to load extension %s on node %s", reflect.TypeOf(extension), node.ID) node.Active.Store(false) - ctx.Log.Logf("node", "Failed to load extension %s on node %s", reflect.TypeOf(extension), node.ID) return err + } else { + ctx.Log.Logf("node_ext", "Loaded extension %s on node %s", reflect.TypeOf(extension), node.ID) } } + ctx.Log.Logf("node_ext", "Loaded extensions for %s", node.ID) + started <- nil run := true @@ -212,10 +227,6 @@ func nodeLoop(ctx *Context, node *Node, started chan error) error { var signal Signal var source NodeID select { - case msg := <- node.MsgChan: - signal = msg.Signal - source = msg.Source - case <-node.TimeoutChan: signal = node.NextSignal.Signal source = node.ID @@ -244,13 +255,17 @@ func nodeLoop(ctx *Context, node *Node, started chan error) error { } else { ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, signal, t, node.NextSignal, node.NextSignal.Time) } + case msg := <- node.MsgChan: + signal = msg.Signal + source = msg.Source + } ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) switch sig := signal.(type) { case *ReadSignal: - result := node.ReadFields(ctx, sig.Extensions) + result := node.ReadFields(ctx, sig.Fields) msgs := []SendMsg{} msgs = append(msgs, SendMsg{source, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)}) ctx.Send(node, msgs) @@ -283,8 +298,25 @@ func (node *Node) Unload(ctx *Context) error { } func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error { - node.QueueSignal(time.Now(), NewStatusSignal(node.ID, changes)) - return nil + node_info, exists := ctx.NodeTypes[node.Type] + if exists == false { + return fmt.Errorf("Node type not in context, can't map changes to field names") + } else { + fields := []string{} + for ext_type, ext_changes := range(changes) { + ext_map, ext_mapped := node_info.ReverseFields[ext_type] + if ext_mapped { + for _, ext_tag := range(ext_changes) { + field_name, tag_mapped := ext_map[ext_tag] + if tag_mapped { + fields = append(fields, field_name) + } + } + } + } + node.QueueSignal(time.Time{}, NewStatusSignal(node.ID, fields)) + return nil + } } func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { @@ -311,12 +343,14 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { } } - if len(changes) != 0 { + if (len(changes) != 0) || node.writeSignalQueue { write_err := WriteNodeChanges(ctx, node, changes) if write_err != nil { return write_err } + } + if len(changes) != 0 { status_err := node.QueueChanges(ctx, changes) if status_err != nil { return status_err @@ -365,7 +399,8 @@ func KeyID(pub ed25519.PublicKey) NodeID { // Create a new node in memory and start it's event loop func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, extensions ...Extension) (*Node, error) { - node_type, known_type := ctx.NodeTypes[type_name] + node_type := NodeTypeFor(type_name) + node_info, known_type := ctx.NodeTypes[node_type] if known_type == false { return nil, fmt.Errorf("%s is not a known node type", type_name) } @@ -392,9 +427,9 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size return nil, fmt.Errorf("Cannot create node with nil extension") } - ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext).Elem()] + ext_type, exists := ctx.Extensions[ExtTypeOf(reflect.TypeOf(ext))] if exists == false { - return nil, fmt.Errorf("%+v is not a known Extension", reflect.TypeOf(ext)) + return nil, fmt.Errorf("%+v(%+v) is not a known Extension", reflect.TypeOf(ext), ExtTypeOf(reflect.TypeOf(ext))) } _, exists = ext_map[ext_type.ExtType] if exists == true { @@ -403,7 +438,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size ext_map[ext_type.ExtType] = ext } - for _, required_ext := range(node_type.Extensions) { + for _, required_ext := range(node_info.RequiredExtensions) { _, exists := ext_map[required_ext] if exists == false { return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext)) @@ -413,7 +448,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size node := &Node{ Key: key, ID: id, - Type: node_type.NodeType, + Type: node_type, Extensions: ext_map, MsgChan: make(chan RecvMsg, buffer_size), BufferSize: buffer_size, @@ -429,7 +464,6 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size ctx.AddNode(id, node) started := make(chan error, 1) go runNode(ctx, node, started) - err = <- started if err != nil { return nil, err diff --git a/node_test.go b/node_test.go index 29b9f5e..dd0e75d 100644 --- a/node_test.go +++ b/node_test.go @@ -9,24 +9,20 @@ import ( ) func TestNodeDB(t *testing.T) { - ctx := logTestContext(t, []string{"node", "db"}) + ctx := logTestContext(t, []string{"test", "node", "db"}) node_listener := NewListenerExt(10) - node, err := NewNode(ctx, nil, "Base", 10, NewLockableExt(nil), node_listener) + node, err := NewNode(ctx, nil, "Node", 10, NewLockableExt(nil), node_listener) fatalErr(t, err) _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { - gql_changes, has_gql := sig.Changes[ExtTypeFor[GQLExt]()] - if has_gql == true { - return slices.Contains(gql_changes, "state") && sig.Source == node.ID - } - return false + return slices.Contains(sig.Fields, "state") && sig.Source == node.ID }) err = ctx.Unload(node.ID) fatalErr(t, err) - ctx.nodeMap = map[NodeID]*Node{} + ctx.nodes = map[NodeID]*Node{} _, err = ctx.getNode(node.ID) fatalErr(t, err) } @@ -46,15 +42,13 @@ func TestNodeRead(t *testing.T) { ctx.Log.Logf("test", "N2: %s", n2_id) n2_listener := NewListenerExt(10) - n2, err := NewNode(ctx, n2_key, "Base", 10, n2_listener) + n2, err := NewNode(ctx, n2_key, "Node", 10, n2_listener) fatalErr(t, err) - n1, err := NewNode(ctx, n1_key, "Base", 10, NewListenerExt(10)) + n1, err := NewNode(ctx, n1_key, "Node", 10, NewListenerExt(10)) fatalErr(t, err) - read_sig := NewReadSignal(map[ExtType][]string{ - ExtTypeFor[ListenerExt](): {"buffer"}, - }) + read_sig := NewReadSignal([]string{"buffer"}) msgs := []SendMsg{{n1.ID, read_sig}} err = ctx.Send(n2, msgs) fatalErr(t, err) diff --git a/serialize.go b/serialize.go index 9276db6..db65de9 100644 --- a/serialize.go +++ b/serialize.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" "math" - "slices" ) type SerializedType uint64 @@ -39,14 +38,8 @@ func (t FieldTag) String() string { return fmt.Sprintf("0x%x", uint64(t)) } -func NodeTypeFor(extensions []ExtType) NodeType { - digest := []byte("GRAPHVENT_NODE - ") - - slices.Sort(extensions) - - for _, ext := range(extensions) { - digest = binary.BigEndian.AppendUint64(digest, uint64(ext)) - } +func NodeTypeFor(name string) NodeType { + digest := []byte("GRAPHVENT_NODE - " + name) hash := sha512.Sum512(digest) return NodeType(binary.BigEndian.Uint64(hash[0:8])) @@ -66,6 +59,10 @@ func ExtTypeFor[E any, T interface { *E; Extension}]() ExtType { return ExtType(SerializedTypeFor[E]()) } +func ExtTypeOf(t reflect.Type) ExtType { + return ExtType(SerializeType(t.Elem())) +} + func SignalTypeFor[S Signal]() SignalType { return SignalType(SerializedTypeFor[S]()) } @@ -81,7 +78,7 @@ func GetFieldTag(tag string) FieldTag { } func TypeStack(ctx *Context, t reflect.Type) ([]byte, error) { - info, registered := ctx.TypeTypes[t] + info, registered := ctx.Types[t] if registered { return binary.BigEndian.AppendUint64(nil, uint64(info.Serialized)), nil } else { @@ -131,7 +128,7 @@ func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) { first_bytes, left := split(stack, 8) first := SerializedType(binary.BigEndian.Uint64(first_bytes)) - info, registered := ctx.TypeMap[first] + info, registered := ctx.TypesReverse[first] if registered { return info.Reflect, left, nil } else { @@ -177,13 +174,13 @@ func UnwrapStack(ctx *Context, stack []byte) (reflect.Type, []byte, error) { } func Serialize[T any](ctx *Context, value T) ([]byte, error) { - return serializeValue(ctx, reflect.ValueOf(&value).Elem()) + return SerializeValue(ctx, reflect.ValueOf(&value).Elem()) } func Deserialize[T any](ctx *Context, data []byte) (T, error) { reflect_type := reflect.TypeFor[T]() var zero T - value, left, err := deserializeValue(ctx, data, reflect_type) + value, left, err := DeserializeValue(ctx, data, reflect_type) if err != nil { return zero, err } else if len(left) != 0 { @@ -195,10 +192,10 @@ func Deserialize[T any](ctx *Context, data []byte) (T, error) { return value.Interface().(T), nil } -func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { +func SerializeValue(ctx *Context, value reflect.Value) ([]byte, error) { var serialize SerializeFn = nil - info, registered := ctx.TypeTypes[value.Type()] + info, registered := ctx.Types[value.Type()] if registered { serialize = info.Serialize } @@ -248,7 +245,7 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { if value.IsNil() { return []byte{0x00}, nil } else { - elem, err := serializeValue(ctx, value.Elem()) + elem, err := SerializeValue(ctx, value.Elem()) if err != nil { return nil, err } @@ -265,7 +262,7 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { data := []byte{} for i := 0; i < value.Len(); i++ { - elem, err := serializeValue(ctx, value.Index(i)) + elem, err := SerializeValue(ctx, value.Index(i)) if err != nil { return nil, err } @@ -279,7 +276,7 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { case reflect.Array: data := []byte{} for i := 0; i < value.Len(); i++ { - elem, err := serializeValue(ctx, value.Index(i)) + elem, err := SerializeValue(ctx, value.Index(i)) if err != nil { return nil, err } @@ -293,16 +290,20 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { binary.BigEndian.PutUint64(len_bytes, uint64(value.Len())) data := []byte{} + key := reflect.New(value.Type().Key()).Elem() + val := reflect.New(value.Type().Elem()).Elem() iter := value.MapRange() for iter.Next() { - k, err := serializeValue(ctx, iter.Key()) + key.SetIterKey(iter) + val.SetIterValue(iter) + + k, err := SerializeValue(ctx, key) if err != nil { return nil, err } - data = append(data, k...) - v, err := serializeValue(ctx, iter.Value()) + v, err := SerializeValue(ctx, val) if err != nil { return nil, err } @@ -319,7 +320,7 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { for field_tag, field_info := range(info.Fields) { data = append(data, binary.BigEndian.AppendUint64(nil, uint64(field_tag))...) - field_bytes, err := serializeValue(ctx, value.FieldByIndex(field_info.Index)) + field_bytes, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index)) if err != nil { return nil, err } @@ -332,7 +333,7 @@ func serializeValue(ctx *Context, value reflect.Value) ([]byte, error) { case reflect.Interface: data, err := TypeStack(ctx, value.Elem().Type()) - val_data, err := serializeValue(ctx, value.Elem()) + val_data, err := SerializeValue(ctx, value.Elem()) if err != nil { return nil, err } @@ -352,10 +353,10 @@ func split(data []byte, n int) ([]byte, []byte) { return data[:n], data[n:] } -func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) { +func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) { var deserialize DeserializeFn = nil - info, registered := ctx.TypeTypes[t] + info, registered := ctx.Types[t] if registered { deserialize = info.Deserialize } @@ -439,7 +440,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, value.SetZero() return value, after_flags, nil } else { - elem_value, after_elem, err := deserializeValue(ctx, after_flags, t.Elem()) + elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem()) if err != nil { return reflect.Value{}, nil, err } @@ -454,7 +455,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, for i := 0; i < length; i++ { var elem_value reflect.Value var err error - elem_value, left, err = deserializeValue(ctx, left, t.Elem()) + elem_value, left, err = DeserializeValue(ctx, left, t.Elem()) if err != nil { return reflect.Value{}, nil, err } @@ -468,7 +469,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, for i := 0; i < t.Len(); i++ { var elem_value reflect.Value var err error - elem_value, left, err = deserializeValue(ctx, left, t.Elem()) + elem_value, left, err = DeserializeValue(ctx, left, t.Elem()) if err != nil { return reflect.Value{}, nil, err } @@ -487,12 +488,12 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, var val_value reflect.Value var err error - key_value, left, err = deserializeValue(ctx, left, t.Key()) + key_value, left, err = DeserializeValue(ctx, left, t.Key()) if err != nil { return reflect.Value{}, nil, err } - val_value, left, err = deserializeValue(ctx, left, t.Elem()) + val_value, left, err = DeserializeValue(ctx, left, t.Elem()) if err != nil { return reflect.Value{}, nil, err } @@ -503,7 +504,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, return value, left, nil case reflect.Struct: - info, mapped := ctx.TypeTypes[t] + info, mapped := ctx.Types[t] if mapped { value := reflect.New(t).Elem() @@ -520,7 +521,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, if mapped { var field_val reflect.Value var err error - field_val, left, err = deserializeValue(ctx, left, field_info.Type) + field_val, left, err = DeserializeValue(ctx, left, field_info.Type) if err != nil { return reflect.Value{}, nil, err } @@ -544,7 +545,7 @@ func deserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, return reflect.Value{}, nil, err } - elem_val, left, err := deserializeValue(ctx, rest, elem_type) + elem_val, left, err := DeserializeValue(ctx, rest, elem_type) if err != nil { return reflect.Value{}, nil, err } diff --git a/serialize_test.go b/serialize_test.go index 59e2392..7d2eccd 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -144,7 +144,9 @@ func TestSerializeValues(t *testing.T) { testSerializeCompare[*int](t, ctx, nil) testSerializeCompare(t, ctx, "string") - node, err := NewNode(ctx, nil, "Base", 100) + testSerialize(t, ctx, NewListenerExt(10)) + + node, err := NewNode(ctx, nil, "Node", 100) fatalErr(t, err) testSerialize(t, ctx, node) } diff --git a/signal.go b/signal.go index e543bcf..aa55e98 100644 --- a/signal.go +++ b/signal.go @@ -30,7 +30,7 @@ func (signal SignalHeader) ID() uuid.UUID { } func (header SignalHeader) String() string { - return fmt.Sprintf("SignalHeader(%s)", header.Id) + return fmt.Sprintf("%s", header.Id) } type ResponseSignal interface { @@ -48,7 +48,7 @@ func (header ResponseHeader) ResponseID() uuid.UUID { } func (header ResponseHeader) String() string { - return fmt.Sprintf("ResponseHeader(%s, %s)", header.Id, header.ReqID) + return fmt.Sprintf("%s for %s", header.Id, header.ReqID) } type Signal interface { @@ -164,16 +164,16 @@ func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal { type StatusSignal struct { SignalHeader Source NodeID `gv:"source"` - Changes map[ExtType]Changes `gv:"changes"` + Fields []string `gv:"fields"` } func (signal StatusSignal) String() string { - return fmt.Sprintf("StatusSignal(%s: %+v)", signal.Source, signal.Changes) + return fmt.Sprintf("StatusSignal(%s: %+v)", signal.Source, signal.Fields) } -func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal { +func NewStatusSignal(source NodeID, fields []string) *StatusSignal { return &StatusSignal{ NewSignalHeader(), source, - changes, + fields, } } @@ -225,17 +225,17 @@ func NewUnlockSignal() *UnlockSignal { type ReadSignal struct { SignalHeader - Extensions map[ExtType][]string `json:"extensions"` + Fields []string `json:"extensions"` } func (signal ReadSignal) String() string { - return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Extensions) + return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Fields) } -func NewReadSignal(exts map[ExtType][]string) *ReadSignal { +func NewReadSignal(fields []string) *ReadSignal { return &ReadSignal{ NewSignalHeader(), - exts, + fields, } } @@ -243,19 +243,19 @@ type ReadResultSignal struct { ResponseHeader NodeID NodeID NodeType NodeType - Extensions map[ExtType]map[string]any + Fields map[string]any } func (signal ReadResultSignal) String() string { - return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Extensions) + return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Fields) } -func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]any) *ReadResultSignal { +func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, fields map[string]any) *ReadResultSignal { return &ReadResultSignal{ NewResponseHeader(req_id), node_id, node_type, - exts, + fields, } }