diff --git a/context.go b/context.go index 1279338..42811f1 100644 --- a/context.go +++ b/context.go @@ -26,27 +26,43 @@ var ( ECDH = ecdh.X25519() ) +type SerializeFn func(ctx *Context, value reflect.Value) ([]byte, error) +type DeserializeFn func(ctx *Context, data []byte) (reflect.Value, []byte, error) + +type FieldInfo struct { + Index []int + Tag FieldTag + Type reflect.Type +} + type TypeInfo struct { + Serialized SerializedType + Reflect reflect.Type Type graphql.Type + + Fields map[FieldTag]FieldInfo + PostDeserializeIndex int + + Serialize SerializeFn + Deserialize DeserializeFn } type ExtensionInfo struct { + ExtType Interface *graphql.Interface Fields map[string][]int Data interface{} } -type SignalInfo struct { - Type graphql.Type -} - type FieldIndex struct { + FieldTag Extension ExtType Field string } type NodeInfo struct { - GQL *graphql.Object + NodeType + Type *graphql.Object Extensions []ExtType Fields map[string]FieldIndex } @@ -60,16 +76,16 @@ type Context struct { Log Logger // Mapped types - TypeMap map[SerializedType]TypeInfo - TypeTypes map[reflect.Type]SerializedType + TypeMap map[SerializedType]*TypeInfo + TypeTypes map[reflect.Type]*TypeInfo // Map between database extension hashes and the registered info - Extensions map[ExtType]ExtensionInfo - ExtensionTypes map[reflect.Type]ExtType + Extensions map[ExtType]*ExtensionInfo + ExtensionTypes map[reflect.Type]*ExtensionInfo // Map between database type hashes and the registered info - Nodes map[NodeType]NodeInfo - NodeTypes map[string]NodeType + Nodes map[NodeType]*NodeInfo + NodeTypes map[string]*NodeInfo // Routing map to all the nodes local to this context nodeMapLock sync.RWMutex @@ -77,36 +93,36 @@ type Context struct { } func (ctx *Context) GQLType(t reflect.Type) graphql.Type { - ser, mapped := ctx.TypeTypes[t] + info, mapped := ctx.TypeTypes[t] if mapped { - return ctx.TypeMap[ser].Type + return info.Type } else { switch t.Kind() { case reflect.Array: - ser, mapped := ctx.TypeTypes[t.Elem()] + info, mapped := ctx.TypeTypes[t.Elem()] if mapped { - return graphql.NewList(ctx.TypeMap[ser].Type) + return graphql.NewList(info.Type) } case reflect.Slice: - ser, mapped := ctx.TypeTypes[t.Elem()] + info, mapped := ctx.TypeTypes[t.Elem()] if mapped { - return graphql.NewList(ctx.TypeMap[ser].Type) + return graphql.NewList(info.Type) } case reflect.Map: - ser, exists := ctx.TypeTypes[t] + info, exists := ctx.TypeTypes[t] if exists { - return ctx.TypeMap[ser].Type + return info.Type } else { err := RegisterMap(ctx, t) if err != nil { return nil } - return ctx.TypeMap[ctx.TypeTypes[t]].Type + return ctx.TypeTypes[t].Type } case reflect.Pointer: - ser, mapped := ctx.TypeTypes[t.Elem()] + info, mapped := ctx.TypeTypes[t.Elem()] if mapped { - return ctx.TypeMap[ser].Type + return info.Type } } return nil @@ -144,10 +160,13 @@ func RegisterMap(ctx *Context, reflect_type reflect.Type) error { gql_map := graphql.NewList(gql_pair) - ctx.TypeTypes[reflect_type] = SerializeType(reflect_type) - ctx.TypeMap[SerializeType(reflect_type)] = TypeInfo{ + serialized_type := SerializeType(reflect_type) + ctx.TypeMap[serialized_type] = &TypeInfo{ + Serialized: serialized_type, + Reflect: reflect_type, Type: gql_map, } + ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] return nil } @@ -164,7 +183,7 @@ func BuildSchema(ctx *Context, query, mutation *graphql.Object) (graphql.Schema, } for _, info := range(ctx.Nodes) { - types = append(types, info.GQL) + types = append(types, info.Type) } subscription := graphql.NewObject(graphql.ObjectConfig{ @@ -223,10 +242,10 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i type_info, type_exists := ctx.Context.Nodes[node.NodeType] if type_exists == false { - return ctx.Context.Nodes[ctx.Context.NodeTypes["Base"]].GQL + return ctx.Context.NodeTypes["Base"].Type } - return type_info.GQL + return type_info.Type }, Fields: graphql.Fields{ "ID": &graphql.Field{ @@ -259,12 +278,13 @@ func RegisterExtension[E any, T interface { *E; Extension}](ctx *Context, data i ctx.Log.Logf("serialize_types", "Registered ExtType: %+v - %+v", reflect_type, ext_type) - ctx.Extensions[ext_type] = ExtensionInfo{ + ctx.Extensions[ext_type] = &ExtensionInfo{ + ExtType: ext_type, Interface: gql_interface, Data: data, Fields: fields, } - ctx.ExtensionTypes[reflect_type] = ext_type + ctx.ExtensionTypes[reflect_type] = ctx.Extensions[ext_type] return nil } @@ -290,19 +310,36 @@ func RegisterNodeType(ctx *Context, name string, extensions []ExtType, mappings ext_found[extension] = true } + + gql := graphql.NewObject(graphql.ObjectConfig{ + Name: name, + Interfaces: []*graphql.Interface{}, + Fields: graphql.Fields{ + "ID": &graphql.Field{ + Type: graphql.String, + }, + "Type": &graphql.Field{ + Type: graphql.String, + }, + }, + IsTypeOf: func(p graphql.IsTypeOfParams) bool { + return false + }, + }) - ctx.Nodes[node_type] = NodeInfo{ + ctx.Nodes[node_type] = &NodeInfo{ + NodeType: node_type, + Type: gql, Extensions: extensions, Fields: mappings, } - ctx.NodeTypes[name] = node_type + ctx.NodeTypes[name] = ctx.Nodes[node_type] return nil } func RegisterObject[T any](ctx *Context) error { reflect_type := reflect.TypeFor[T]() - ctx.Log.Logf("test", "registering %+v", reflect_type) serialized_type := SerializedTypeFor[T]() _, exists := ctx.TypeTypes[reflect_type] @@ -317,10 +354,24 @@ func RegisterObject[T any](ctx *Context) error { }, Fields: graphql.Fields{}, }) + + field_infos := map[FieldTag]FieldInfo{} + + post_deserialize, post_deserialize_exists := reflect.PointerTo(reflect_type).MethodByName("PostDeserialize") + post_deserialize_index := -1 + if post_deserialize_exists { + post_deserialize_index = post_deserialize.Index + } for _, field := range(reflect.VisibleFields(reflect_type)) { gv_tag, tagged_gv := field.Tag.Lookup("gv") if tagged_gv { + field_infos[GetFieldTag(gv_tag)] = FieldInfo{ + Type: field.Type, + Tag: GetFieldTag(gv_tag), + Index: field.Index, + } + gql_type := ctx.GQLType(field.Type) if gql_type == nil { return fmt.Errorf("Object %+v has field %s of unknown type %+v", reflect_type, gv_tag, field.Type) @@ -344,10 +395,14 @@ func RegisterObject[T any](ctx *Context) error { } } - ctx.TypeTypes[reflect_type] = serialized_type - ctx.TypeMap[serialized_type] = TypeInfo{ + ctx.TypeMap[serialized_type] = &TypeInfo{ + PostDeserializeIndex: post_deserialize_index, + Serialized: serialized_type, + Reflect: reflect_type, + Fields: field_infos, Type: gql, } + ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] return nil } @@ -416,6 +471,31 @@ func astString[T ~string](value ast.Value) interface{} { return T(str.Value) } +func astBool[T ~bool](value ast.Value) interface{} { + switch value := value.(type) { + case *ast.BooleanValue: + if value.Value { + return T(true) + } else { + return T(false) + } + case *ast.IntValue: + i, err := strconv.Atoi(value.Value) + if err != nil { + return nil + } + return i != 0 + case *ast.StringValue: + b, err := strconv.ParseBool(value.Value) + if err != nil { + return nil + } + return b + default: + return nil + } +} + func astInt[T constraints.Integer](value ast.Value) interface{} { switch value := value.(type) { case *ast.BooleanValue: @@ -459,10 +539,12 @@ func RegisterScalar[S any](ctx *Context, to_json func(interface{})interface{}, f ParseLiteral: from_ast, }) - ctx.TypeTypes[reflect_type] = serialized_type - ctx.TypeMap[serialized_type] = TypeInfo{ + ctx.TypeMap[serialized_type] = &TypeInfo{ + Serialized: serialized_type, + Reflect: reflect_type, Type: gql, } + ctx.TypeTypes[reflect_type] = ctx.TypeMap[serialized_type] return nil } @@ -560,25 +642,35 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { DB: db, Log: log, - TypeMap: map[SerializedType]TypeInfo{}, - TypeTypes: map[reflect.Type]SerializedType{}, + TypeMap: map[SerializedType]*TypeInfo{}, + TypeTypes: map[reflect.Type]*TypeInfo{}, - Extensions: map[ExtType]ExtensionInfo{}, - ExtensionTypes: map[reflect.Type]ExtType{}, + Extensions: map[ExtType]*ExtensionInfo{}, + ExtensionTypes: map[reflect.Type]*ExtensionInfo{}, - Nodes: map[NodeType]NodeInfo{}, - NodeTypes: map[string]NodeType{}, + Nodes: map[NodeType]*NodeInfo{}, + NodeTypes: map[string]*NodeInfo{}, nodeMap: map[NodeID]*Node{}, } var err error + err = RegisterScalar[bool](ctx, identity, coerce[bool], astBool[bool]) + if err != nil { + return nil, err + } + err = RegisterScalar[int](ctx, identity, coerce[int], astInt[int]) if err != nil { return nil, err } + err = RegisterScalar[uint32](ctx, identity, coerce[uint32], astInt[uint32]) + if err != nil { + return nil, err + } + err = RegisterScalar[uint8](ctx, identity, coerce[uint8], astInt[uint8]) if err != nil { return nil, err @@ -619,6 +711,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = RegisterScalar[NodeType](ctx, identity, coerce[NodeType], astInt[NodeType]) + if err != nil { + return nil, err + } + + err = RegisterObject[Node](ctx) + if err != nil { + return nil, err + } + err = RegisterObject[WaitInfo](ctx) if err != nil { return nil, err @@ -649,6 +751,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = RegisterNodeType(ctx, "Base", []ExtType{}, map[string]FieldIndex{}) + if err != nil { + return nil, err + } + _, err = BuildSchema(ctx, graphql.NewObject(graphql.ObjectConfig{ Name: "Query", Fields: graphql.Fields{ diff --git a/db.go b/db.go new file mode 100644 index 0000000..de78492 --- /dev/null +++ b/db.go @@ -0,0 +1,38 @@ +package graphvent + +import ( + badger "github.com/dgraph-io/badger/v3" +) + +func WriteNodeInit(ctx *Context, node *Node) error { + return ctx.DB.Update(func(db *badger.Txn) error { + // Write node private key + // Write node type + // Write Node buffer size + // Write node extension list + // For each extension: + // Write each extension's current value + return nil + }) +} + +func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { + return ctx.DB.Update(func(db *badger.Txn) error { + // Write the signal queue if it needs to be written + // For each ext in changes + // Write each change + return nil + }) +} + +func LoadNode(ctx *Context, id NodeID) (*Node, error) { + err := ctx.DB.Update(func(db *badger.Txn) error { + return nil + }) + + if err != nil { + return nil, err + } + + return nil, nil +} diff --git a/event_test.go b/event_test.go index 1f06104..a7d5f63 100644 --- a/event_test.go +++ b/event_test.go @@ -15,7 +15,7 @@ func TestEvent(t *testing.T) { event_public, event_private, err := ed25519.GenerateKey(rand.Reader) event_listener := NewListenerExt(100) - event, err := NewNode(ctx, event_private, "Base", 100, nil, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener) + event, err := NewNode(ctx, event_private, "Base", 100, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener) fatalErr(t, err) response, signals := testSend(t, ctx, NewEventControlSignal("ready?"), event, event) diff --git a/node.go b/node.go index c9b60f5..eae4100 100644 --- a/node.go +++ b/node.go @@ -226,7 +226,7 @@ func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType if exists == false { fields[req] = fmt.Errorf("%+v does not have %+v extension", node.ID, ext_type) } else { - fields[req] = reflect.ValueOf(ext).FieldByIndex(ext_info.Fields[req]).Interface() + fields[req] = reflect.ValueOf(ext).Elem().FieldByIndex(ext_info.Fields[req]).Interface() } } exts[ext_type] = fields @@ -420,25 +420,20 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size return nil, fmt.Errorf("Attempted to create an existing node") } - def, exists := ctx.Nodes[node_type] - if exists == false { - return nil, fmt.Errorf("Node type %+v not registered in Context", node_type) - } - ext_map := map[ExtType]Extension{} for _, ext := range(extensions) { ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext).Elem()] if exists == false { return nil, fmt.Errorf(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext))) } - _, exists = ext_map[ext_type] + _, exists = ext_map[ext_type.ExtType] if exists == true { return nil, fmt.Errorf("Cannot add the same extension to a node twice") } - ext_map[ext_type] = ext + ext_map[ext_type.ExtType] = ext } - for _, required_ext := range(def.Extensions) { + for _, required_ext := range(node_type.Extensions) { _, exists := ext_map[required_ext] if exists == false { return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext)) @@ -448,19 +443,14 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size node := &Node{ Key: key, ID: id, - Type: node_type, + Type: node_type.NodeType, Extensions: ext_map, MsgChan: make(chan RecvMsg, buffer_size), BufferSize: buffer_size, SignalQueue: []QueuedSignal{}, + writeSignalQueue: false, } - err = WriteNodeExtList(ctx, node) - if err != nil { - return nil, err - } - - node.writeSignalQueue = true err = WriteNodeInit(ctx, node) if err != nil { return nil, err @@ -488,22 +478,3 @@ func ExtTypeSuffix(ext_type ExtType) []byte { binary.BigEndian.PutUint64(ret[4:], uint64(ext_type)) return ret } - -func WriteNodeExtList(ctx *Context, node *Node) error { - ctx.Log.Logf("todo", "write node list") - return nil -} - -func WriteNodeInit(ctx *Context, node *Node) error { - ctx.Log.Logf("todo", "write initial node entry") - return nil -} - -func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error { - ctx.Log.Logf("todo", "write node changes") - return nil -} - -func LoadNode(ctx *Context, id NodeID) (*Node, error) { - return nil, fmt.Errorf("TODO: load node + extensions from DB") -} diff --git a/node_test.go b/node_test.go index da4cd89..29b9f5e 100644 --- a/node_test.go +++ b/node_test.go @@ -12,7 +12,7 @@ func TestNodeDB(t *testing.T) { ctx := logTestContext(t, []string{"node", "db"}) node_listener := NewListenerExt(10) - node, err := NewNode(ctx, nil, "Base", 10, nil, NewLockableExt(nil), node_listener) + node, err := NewNode(ctx, nil, "Base", 10, NewLockableExt(nil), node_listener) fatalErr(t, err) _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { diff --git a/serialize.go b/serialize.go index e5f16af..861fa2b 100644 --- a/serialize.go +++ b/serialize.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "reflect" + "math" "slices" ) @@ -66,7 +67,7 @@ func NodeTypeFor(name string, extensions []ExtType, mappings map[string]FieldInd return NodeType(binary.BigEndian.Uint64(hash[0:8])) } -func SerializeType(t reflect.Type) SerializedType { +func SerializeType(t fmt.Stringer) SerializedType { digest := []byte(t.String()) hash := sha512.Sum512(digest) return SerializedType(binary.BigEndian.Uint64(hash[0:8])) @@ -93,3 +94,439 @@ func Hash(base, data string) SerializedType { func GetFieldTag(tag string) FieldTag { return FieldTag(Hash("GRAPHVENT_FIELD_TAG", tag)) } + +func TypeStack(ctx *Context, t reflect.Type) ([]SerializedType, error) { + info, registered := ctx.TypeTypes[t] + if registered { + return []SerializedType{info.Serialized}, nil + } else { + switch t.Kind() { + case reflect.Map: + key_stack, err := TypeStack(ctx, t.Key()) + if err != nil { + return nil, err + } + + elem_stack, err := TypeStack(ctx, t.Elem()) + if err != nil { + return nil, err + } + + return append([]SerializedType{SerializeType(reflect.Map)}, append(key_stack, elem_stack...)...), nil + case reflect.Pointer: + elem_stack, err := TypeStack(ctx, t.Elem()) + if err != nil { + return nil, err + } + + return append([]SerializedType{SerializeType(reflect.Pointer)}, elem_stack...), nil + case reflect.Slice: + elem_stack, err := TypeStack(ctx, t.Elem()) + if err != nil { + return nil, err + } + + return append([]SerializedType{SerializeType(reflect.Slice)}, elem_stack...), nil + case reflect.Array: + elem_stack, err := TypeStack(ctx, t.Elem()) + if err != nil { + return nil, err + } + + return append([]SerializedType{SerializeType(reflect.Array), SerializedType(t.Len())}, elem_stack...), nil + default: + return nil, fmt.Errorf("Hit %s, which is not a registered type", t.String()) + } + } +} + +func UnwrapStack(ctx *Context, stack []SerializedType) (reflect.Type, []SerializedType, error) { + first := stack[0] + stack = stack[1:] + + info, registered := ctx.TypeMap[first] + if registered { + return info.Reflect, stack, nil + } else { + switch first { + case SerializeType(reflect.Map): + key_type, after_key, err := UnwrapStack(ctx, stack) + if err != nil { + return nil, nil, err + } + + elem_type, after_elem, err := UnwrapStack(ctx, after_key) + if err != nil { + return nil, nil, err + } + + return reflect.MapOf(key_type, elem_type), after_elem, nil + case SerializeType(reflect.Pointer): + elem_type, rest, err := UnwrapStack(ctx, stack) + if err != nil { + return nil, nil, err + } + return reflect.PointerTo(elem_type), rest, nil + case SerializeType(reflect.Slice): + elem_type, rest, err := UnwrapStack(ctx, stack) + if err != nil { + return nil, nil, err + } + return reflect.SliceOf(elem_type), rest, nil + case SerializeType(reflect.Array): + length := int(stack[0]) + + stack = stack[1:] + + elem_type, rest, err := UnwrapStack(ctx, stack) + if err != nil { + return nil, nil, err + } + + return reflect.ArrayOf(length, elem_type), rest, nil + default: + return nil, nil, fmt.Errorf("Type stack %+v not recognized", stack) + } + } +} + +func SerializeValue(ctx *Context, value reflect.Value) ([]byte, error) { + var serialize SerializeFn = nil + + info, registered := ctx.TypeTypes[value.Type()] + if registered { + serialize = info.Serialize + } + + if serialize == nil { + switch value.Type().Kind() { + case reflect.Bool: + if value.Bool() { + return []byte{0xFF}, nil + } else { + return []byte{0x00}, nil + } + + case reflect.Int8: + return []byte{byte(value.Int())}, nil + case reflect.Int16: + return binary.BigEndian.AppendUint16(nil, uint16(value.Int())), nil + case reflect.Int32: + return binary.BigEndian.AppendUint32(nil, uint32(value.Int())), nil + case reflect.Int64: + fallthrough + case reflect.Int: + return binary.BigEndian.AppendUint64(nil, uint64(value.Int())), nil + + case reflect.Uint8: + return []byte{byte(value.Uint())}, nil + case reflect.Uint16: + return binary.BigEndian.AppendUint16(nil, uint16(value.Uint())), nil + case reflect.Uint32: + return binary.BigEndian.AppendUint32(nil, uint32(value.Uint())), nil + case reflect.Uint64: + fallthrough + case reflect.Uint: + return binary.BigEndian.AppendUint64(nil, value.Uint()), nil + + case reflect.Float32: + return binary.BigEndian.AppendUint32(nil, math.Float32bits(float32(value.Float()))), nil + case reflect.Float64: + return binary.BigEndian.AppendUint64(nil, math.Float64bits(value.Float())), nil + + case reflect.String: + len_bytes := make([]byte, 8) + binary.BigEndian.PutUint64(len_bytes, uint64(value.Len())) + return append(len_bytes, []byte(value.String())...), nil + + case reflect.Pointer: + if value.IsNil() { + return []byte{0x00}, nil + } else { + elem, err := SerializeValue(ctx, value.Elem()) + if err != nil { + return nil, err + } + + return append([]byte{0x01}, elem...), nil + } + + case reflect.Slice: + if value.IsNil() { + return []byte{0x00}, nil + } else { + len_bytes := make([]byte, 8) + binary.BigEndian.PutUint64(len_bytes, uint64(value.Len())) + + data := []byte{} + for i := 0; i < value.Len(); i++ { + elem, err := SerializeValue(ctx, value.Index(i)) + if err != nil { + return nil, err + } + + data = append(data, elem...) + } + + return append(len_bytes, data...), nil + } + + case reflect.Array: + data := []byte{} + for i := 0; i < value.Len(); i++ { + elem, err := SerializeValue(ctx, value.Index(i)) + if err != nil { + return nil, err + } + + data = append(data, elem...) + } + return data, nil + + case reflect.Map: + len_bytes := make([]byte, 8) + binary.BigEndian.PutUint64(len_bytes, uint64(value.Len())) + + data := []byte{} + iter := value.MapRange() + for iter.Next() { + k, err := SerializeValue(ctx, iter.Key()) + if err != nil { + return nil, err + } + + data = append(data, k...) + + v, err := SerializeValue(ctx, iter.Value()) + if err != nil { + return nil, err + } + + data = append(data, v...) + } + return append(len_bytes, data...), nil + + case reflect.Struct: + if registered == false { + return nil, fmt.Errorf("Cannot serialize unregistered struct %s", value.Type()) + } else { + data := binary.BigEndian.AppendUint64(nil, uint64(len(info.Fields))) + + for field_tag, field_info := range(info.Fields) { + data = append(data, binary.BigEndian.AppendUint64(nil, uint64(field_tag))...) + field_bytes, err := SerializeValue(ctx, value.FieldByIndex(field_info.Index)) + if err != nil { + return nil, err + } + + data = append(data, field_bytes...) + } + return data, nil + } + + default: + return nil, fmt.Errorf("Don't know how to serialize %s", value.Type()) + } + } else { + return serialize(ctx, value) + } +} + +func split(data []byte, n int) ([]byte, []byte) { + return data[:n], data[n:] +} + +func DeserializeValue(ctx *Context, data []byte, t reflect.Type) (reflect.Value, []byte, error) { + var deserialize DeserializeFn = nil + + info, registered := ctx.TypeTypes[t] + if registered { + deserialize = info.Deserialize + } + + if deserialize == nil { + switch t.Kind() { + case reflect.Bool: + used, left := split(data, 1) + value := reflect.New(t).Elem() + value.SetBool(used[0] != 0x00) + return value, left, nil + + case reflect.Int8: + used, left := split(data, 1) + value := reflect.New(t).Elem() + value.SetInt(int64(used[0])) + return value, left, nil + case reflect.Int16: + used, left := split(data, 2) + value := reflect.New(t).Elem() + value.SetInt(int64(binary.BigEndian.Uint16(used))) + return value, left, nil + case reflect.Int32: + used, left := split(data, 4) + value := reflect.New(t).Elem() + value.SetInt(int64(binary.BigEndian.Uint32(used))) + return value, left, nil + case reflect.Int64: + fallthrough + case reflect.Int: + used, left := split(data, 8) + value := reflect.New(t).Elem() + value.SetInt(int64(binary.BigEndian.Uint64(used))) + return value, left, nil + + case reflect.Uint8: + used, left := split(data, 1) + value := reflect.New(t).Elem() + value.SetUint(uint64(used[0])) + return value, left, nil + case reflect.Uint16: + used, left := split(data, 2) + value := reflect.New(t).Elem() + value.SetUint(uint64(binary.BigEndian.Uint16(used))) + return value, left, nil + case reflect.Uint32: + used, left := split(data, 4) + value := reflect.New(t).Elem() + value.SetUint(uint64(binary.BigEndian.Uint32(used))) + return value, left, nil + case reflect.Uint64: + fallthrough + case reflect.Uint: + used, left := split(data, 8) + value := reflect.New(t).Elem() + value.SetUint(binary.BigEndian.Uint64(used)) + return value, left, nil + + case reflect.Float32: + used, left := split(data, 4) + value := reflect.New(t).Elem() + value.SetFloat(float64(math.Float32frombits(binary.BigEndian.Uint32(used)))) + return value, left, nil + case reflect.Float64: + used, left := split(data, 8) + value := reflect.New(t).Elem() + value.SetFloat(math.Float64frombits(binary.BigEndian.Uint64(used))) + return value, left, nil + + case reflect.String: + length, after_len := split(data, 8) + used, left := split(after_len, int(binary.BigEndian.Uint64(length))) + value := reflect.New(t).Elem() + value.SetString(string(used)) + return value, left, nil + + case reflect.Pointer: + flags, after_flags := split(data, 1) + value := reflect.New(t).Elem() + if flags[0] == 0x00 { + value.SetZero() + return value, after_flags, nil + } else { + elem_value, after_elem, err := DeserializeValue(ctx, after_flags, t.Elem()) + if err != nil { + return reflect.Value{}, nil, err + } + value.Set(elem_value.Addr()) + return value, after_elem, nil + } + + case reflect.Slice: + len_bytes, left := split(data, 8) + length := int(binary.BigEndian.Uint64(len_bytes)) + value := reflect.MakeSlice(t, length, length) + for i := 0; i < length; i++ { + var elem_value reflect.Value + var err error + elem_value, left, err = DeserializeValue(ctx, left, t.Elem()) + if err != nil { + return reflect.Value{}, nil, err + } + value.Index(i).Set(elem_value) + } + return value, left, nil + + case reflect.Array: + value := reflect.New(t).Elem() + left := data + for i := 0; i < t.Len(); i++ { + var elem_value reflect.Value + var err error + elem_value, left, err = DeserializeValue(ctx, left, t.Elem()) + if err != nil { + return reflect.Value{}, nil, err + } + value.Index(i).Set(elem_value) + } + return value, left, nil + + case reflect.Map: + len_bytes, left := split(data, 8) + length := int(binary.BigEndian.Uint64(len_bytes)) + + value := reflect.MakeMapWithSize(t, length) + + for i := 0; i < length; i++ { + var key_value reflect.Value + var val_value reflect.Value + var err error + + key_value, left, err = DeserializeValue(ctx, left, t.Key()) + if err != nil { + return reflect.Value{}, nil, err + } + + val_value, left, err = DeserializeValue(ctx, left, t.Elem()) + if err != nil { + return reflect.Value{}, nil, err + } + + value.SetMapIndex(key_value, val_value) + } + + return value, left, nil + + case reflect.Struct: + info, mapped := ctx.TypeTypes[t] + if mapped { + value := reflect.New(t).Elem() + + num_field_bytes, left := split(data, 8) + num_fields := int(binary.BigEndian.Uint64(num_field_bytes)) + + for i := 0; i < num_fields; i++ { + var tag_bytes []byte + + tag_bytes, left = split(left, 8) + field_tag := FieldTag(binary.BigEndian.Uint64(tag_bytes)) + + field_info, mapped := info.Fields[field_tag] + if mapped { + var field_val reflect.Value + var err error + field_val, left, err = DeserializeValue(ctx, left, field_info.Type) + if err != nil { + return reflect.Value{}, nil, err + } + value.FieldByIndex(field_info.Index).Set(field_val) + } else { + return reflect.Value{}, nil, fmt.Errorf("Unknown field %s on struct %s", field_tag, t) + } + } + if info.PostDeserializeIndex != -1 { + post_deserialize_method := value.Addr().Method(info.PostDeserializeIndex) + post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)}) + } + return value, left, nil + } else { + return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize unregistered struct %s", t) + } + + default: + return reflect.Value{}, nil, fmt.Errorf("Don't know how to deserialize %s", t) + } + } else { + return deserialize(ctx, data) + } +} + diff --git a/serialize_test.go b/serialize_test.go new file mode 100644 index 0000000..7c2f126 --- /dev/null +++ b/serialize_test.go @@ -0,0 +1,200 @@ +package graphvent + +import ( + "testing" + "reflect" + "github.com/google/uuid" +) + +func testTypeStack[T any](t *testing.T, ctx *Context) { + reflect_type := reflect.TypeFor[T]() + stack, err := TypeStack(ctx, reflect_type) + fatalErr(t, err) + + ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack) + + unwrapped_type, rest, err := UnwrapStack(ctx, stack) + fatalErr(t, err) + + if len(rest) != 0 { + t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack) + } + + if unwrapped_type != reflect_type { + t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type) + } + + ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type) +} + +func TestSerializeTypes(t *testing.T) { + ctx := logTestContext(t, []string{"test"}) + + testTypeStack[int](t, ctx) + testTypeStack[map[int]string](t, ctx) + testTypeStack[string](t, ctx) + testTypeStack[*string](t, ctx) + testTypeStack[*map[string]*map[*string]int](t, ctx) + testTypeStack[[5]int](t, ctx) + testTypeStack[uuid.UUID](t, ctx) + testTypeStack[NodeID](t, ctx) +} + +func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) { + serialized, err := SerializeValue(ctx, reflect.ValueOf(value)) + fatalErr(t, err) + + ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized) + + deserialized, left, err := DeserializeValue(ctx, serialized, reflect.TypeFor[T]()) + fatalErr(t, err) + + if len(left) != 0 { + t.Fatalf("Data left after deserialize[%+v]: %+v", deserialized, left) + } + + if reflect.TypeFor[T]() != deserialized.Type() { + t.Fatalf("Type mismatch after deserialize %s != %s", reflect.TypeFor[T](), deserialized.Type()) + } + + val, ok := deserialized.Interface().(T) + if ok == false { + t.Fatalf("Deserialized type[%s] can't cast to type %s", deserialized.Type(), reflect.TypeFor[T]()) + } + + if value != val { + t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized) + } + + ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, val) +} + +func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) { + serialized, err := SerializeValue(ctx, reflect.ValueOf(value)) + fatalErr(t, err) + + ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized) + + deserialized, left, err := DeserializeValue(ctx, serialized, reflect.TypeFor[L]()) + fatalErr(t, err) + + if len(left) != 0 { + t.Fatalf("Data left after deserialize[%+v]: %+v", deserialized, left) + } + + if reflect.TypeFor[L]() != deserialized.Type() { + t.Fatalf("Type mismatch after deserialize %s != %s", reflect.TypeFor[L](), deserialized.Type()) + } + + val, ok := deserialized.Interface().(L) + if ok == false { + t.Fatalf("Deserialized type[%s] can't cast to type %s", deserialized.Type(), reflect.TypeFor[L]()) + } + + for i, item := range(value) { + if item != val[i] { + t.Fatalf("Deserialized list %+v does not match original %+v", value, val) + } + } + + ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, val) +} + +func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) { + serialized, err := SerializeValue(ctx, reflect.ValueOf(value)) + fatalErr(t, err) + + ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized) + + deserialized, left, err := DeserializeValue(ctx, serialized, reflect.TypeFor[P]()) + fatalErr(t, err) + + if len(left) != 0 { + t.Fatalf("Data left after deserialize[%+v]: %+v", deserialized, left) + } + + if reflect.TypeFor[P]() != deserialized.Type() { + t.Fatalf("Type mismatch after deserialize %s != %s", reflect.TypeFor[P](), deserialized.Type()) + } + + val, ok := deserialized.Interface().(P) + if ok == false { + t.Fatalf("Deserialized type[%s] can't cast to type %s", deserialized.Type(), reflect.TypeFor[P]()) + } + + if value == nil && val == nil { + ctx.Log.Logf("test", "Deserialized nil") + } else if value == nil { + t.Fatalf("Non-nil value[%+v] returned for nil value", val) + } else if val == nil { + t.Fatalf("Nil value returned for non-nil value[%+v]", value) + } else if *val != *value { + t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized) + } else { + ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *val) + } +} + +func testSerialize[T any](t *testing.T, ctx *Context, value T) { + serialized, err := SerializeValue(ctx, reflect.ValueOf(value)) + fatalErr(t, err) + + ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized) + + deserialized, left, err := DeserializeValue(ctx, serialized, reflect.TypeFor[T]()) + fatalErr(t, err) + + if len(left) != 0 { + t.Fatalf("Data left after deserialize[%+v]: %+v", deserialized, left) + } + + if reflect.TypeFor[T]() != deserialized.Type() { + t.Fatalf("Type mismatch after deserialize %s != %s", reflect.TypeFor[T](), deserialized.Type()) + } + + val, ok := deserialized.Interface().(T) + if ok == false { + t.Fatalf("Deserialized type[%s] can't cast to type %s", deserialized.Type(), reflect.TypeFor[T]()) + } + + ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, val) +} + +func TestSerializeValues(t *testing.T) { + ctx := logTestContext(t, []string{"test"}) + + testSerializeCompare[int8](t, ctx, -64) + testSerializeCompare[int16](t, ctx, -64) + testSerializeCompare[int32](t, ctx, -64) + testSerializeCompare[int64](t, ctx, -64) + testSerializeCompare[int](t, ctx, -64) + + testSerializeCompare[uint8](t, ctx, 64) + testSerializeCompare[uint16](t, ctx, 64) + testSerializeCompare[uint32](t, ctx, 64) + testSerializeCompare[uint64](t, ctx, 64) + testSerializeCompare[uint](t, ctx, 64) + + testSerializeCompare[string](t, ctx, "test") + + a := 12 + testSerializePointer[*int](t, ctx, &a) + + b := "test" + testSerializePointer[*string](t, ctx, nil) + testSerializePointer[*string](t, ctx, &b) + + testSerializeList(t, ctx, []int{1, 2, 3, 4, 5}) + + testSerializeCompare[bool](t, ctx, true) + testSerializeCompare[bool](t, ctx, false) + testSerializeCompare[int](t, ctx, -1) + testSerializeCompare[uint](t, ctx, 1) + testSerializeCompare[NodeID](t, ctx, RandID()) + testSerializeCompare[*int](t, ctx, nil) + testSerializeCompare(t, ctx, "string") + + node, err := NewNode(ctx, nil, "Base", 100) + fatalErr(t, err) + testSerialize(t, ctx, node) +}