From 045304f9f6edc4786f4ed039e1d99675ce42a96b Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 11 Sep 2023 21:47:53 -0600 Subject: [PATCH] Moved int, struct, and interface serialization to functions to make registering types easy --- context.go | 402 +++++++++++----------------------------------- node.go | 14 +- serialize.go | 352 +++++++++++++++++++++++++++++++++++++++- serialize_test.go | 89 ++++++++-- 4 files changed, 523 insertions(+), 334 deletions(-) diff --git a/context.go b/context.go index 454ea29..40f8693 100644 --- a/context.go +++ b/context.go @@ -174,6 +174,10 @@ func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedTy return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type) } + if serialize == nil || deserialize == nil { + return fmt.Errorf("Cannot register field without serialize/deserialize functions") + } + type_info := TypeInfo{ Reflect: reflect_type, Type: ctx_type, @@ -183,6 +187,8 @@ func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedTy ctx.Types[ctx_type] = &type_info ctx.TypeReflects[reflect_type] = &type_info + ctx.Log.Logf("serialize", "Registered Type: %+v - %+v", reflect_type, ctx_type) + return nil } @@ -325,15 +331,12 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - // TODO: figure out why this doesn't break in the simple test, but breaks in TestGQLDB err = ctx.RegisterKind(reflect.Struct, StructType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - serialized_value := SerializedValue{ - []SerializedType{ctx_type}, - nil, - } + type_stack := []SerializedType{ctx_type} + var data []byte if value != nil { - serialized_value.Data = make([]byte, 8) + data = make([]byte, 8) } num_fields := uint64(0) @@ -348,29 +351,31 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { field_hash := Hash(FieldNameBase, gv_tag) field_hash_bytes := make([]byte, 8) binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) - if value == nil { - field_ser, err := SerializeValue(ctx, field.Type, nil) + if value != nil { + field_value := value.FieldByIndex(field.Index) + field_ser, err := SerializeValue(ctx, field.Type, &field_value) if err != nil { return SerializedValue{}, err } - serialized_value.TypeStack = append(serialized_value.TypeStack, field_ser.TypeStack...) - } else { - field_value := value.FieldByIndex(field.Index) - field_ser, err := SerializeValue(ctx, field.Type, &field_value) + + field_bytes, err := field_ser.MarshalBinary() if err != nil { return SerializedValue{}, err } - serialized_value.TypeStack = append(serialized_value.TypeStack, field_ser.TypeStack...) - serialized_value.Data = append(serialized_value.Data, field_hash_bytes...) - serialized_value.Data = append(serialized_value.Data, field_ser.Data...) + data = append(data, field_hash_bytes...) + data = append(data, field_bytes...) } } } + if value != nil { - binary.BigEndian.PutUint64(serialized_value.Data[0:8], num_fields) + binary.BigEndian.PutUint64(data[0:8], num_fields) } - return serialized_value, nil + return SerializedValue{ + type_stack, + data, + }, nil }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ if value.Data == nil { return reflect.TypeOf(map[uint64]reflect.Value{}), nil, value, nil @@ -390,24 +395,29 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { if num_fields == 0 { return map_type, &map_value, value, nil } else { - tmp_value := value + tmp_data := value.Data for i := 0; i < num_fields; i += 1 { - var field_hash_bytes []byte - field_hash_bytes, tmp_value, err = tmp_value.PopData(8) + if len(tmp_data) < 8 { + return nil, nil, value, fmt.Errorf("Not enough data to deserialize struct field") + } + field_hash := binary.BigEndian.Uint64(tmp_data[0:8]) + tmp_data = tmp_data[8:] + var field_value SerializedValue + field_value, tmp_data, err = ParseSerializedValue(tmp_data) if err != nil { return nil, nil, value, err } - field_hash := binary.BigEndian.Uint64(field_hash_bytes) field_hash_value := reflect.ValueOf(field_hash) var elem_value *reflect.Value - _, elem_value, tmp_value, err = DeserializeValue(ctx, tmp_value) + _, elem_value, _, err = DeserializeValue(ctx, field_value) if err != nil { return nil, nil, value, err } map_value.SetMapIndex(field_hash_value, reflect.ValueOf(*elem_value)) } - return map_type, &map_value, tmp_value, nil + value.Data = tmp_data + return map_type, &map_value, value, nil } } }) @@ -415,32 +425,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterKind(reflect.Int, IntType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Int())) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(0), nil, value, nil - } - if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("invalid length: %d/8", len(value.Data)) - } - int_val := reflect.ValueOf(int(binary.BigEndian.Uint64(value.Data[0:8]))) - value.Data = value.Data[8:] - return int_val.Type(), &int_val, value, nil - }) - if err != nil { - return nil, err - } - err = ctx.RegisterKind(reflect.Bool, BoolType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ var data []byte = nil @@ -544,34 +528,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterKind(reflect.Uint32, UInt32Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - data := make([]byte, 4) - if value != nil { - binary.BigEndian.PutUint32(data, uint32(value.Uint())) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(uint32(0)), nil, value, nil - } else { - if len(value.Data) < 4 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize uint32") - } - val := binary.BigEndian.Uint32(value.Data[0:4]) - value.Data = value.Data[4:] - int_value := reflect.ValueOf(val) - - return int_value.Type(), &int_value, value, nil - } - }) - if err != nil { - return nil, err - } - err = ctx.RegisterKind(reflect.String, StringType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ if value == nil { @@ -651,47 +607,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterKind(reflect.Interface, InterfaceType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte - type_stack := []SerializedType{ctx_type} - if value == nil { - data = nil - } else if value.IsZero() { - } else { - elem_value := value.Elem() - elem, err := SerializeValue(ctx, elem_value.Type(), &elem_value) - if err != nil { - return SerializedValue{}, err - } - data, err = elem.MarshalBinary() - if err != nil { - return SerializedValue{}, err - } - } - return SerializedValue{ - type_stack, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf((interface{})(nil)), nil, value, nil - } else { - var elem_value *reflect.Value - var elem_ser SerializedValue - var elem_type reflect.Type - var err error - elem_ser, value.Data, err = ParseSerializedValue(value.Data) - elem_type, elem_value, _, err = DeserializeValue(ctx, elem_ser) - if err != nil { - return nil, nil, value, err - } - ptr_type := reflect.PointerTo(elem_type) - ptr_value := reflect.New(ptr_type).Elem() - ptr_value.Set(elem_value.Addr()) - return ptr_type, &ptr_value, value, nil - } - }) + err = ctx.RegisterKind(reflect.Interface, InterfaceType, SerializeInterface, DeserializeInterface[interface{}]()) if err != nil { return nil, err } @@ -862,230 +778,52 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterKind(reflect.Int8, Int8Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = []byte{byte(value.Int())} - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(int8(0)), nil, value, nil - } else { - if len(value.Data) < 1 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize int8") - } - i := int8(value.Data[0]) - value.Data = value.Data[1:] - val := reflect.ValueOf(i) - return val.Type(), &val, value, nil - } - }) + err = ctx.RegisterKind(reflect.Int8, Int8Type, SerializeIntN(1), DeserializeIntN[int8](1)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint8, UInt8Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = []byte{uint8(value.Uint())} - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(uint8(0)), nil, value, nil - } else { - if len(value.Data) < 1 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize uint8") - } - i := uint8(value.Data[0]) - value.Data = value.Data[1:] - val := reflect.ValueOf(i) - return val.Type(), &val, value, nil - } - }) + err = ctx.RegisterKind(reflect.Int16, Int16Type, SerializeIntN(2), DeserializeIntN[int16](2)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint16, UInt16Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 2) - binary.BigEndian.PutUint16(data, uint16(value.Uint())) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(uint16(0)), nil, value, nil - } else { - if len(value.Data) < 2 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize uint16") - } - val := binary.BigEndian.Uint16(value.Data[0:2]) - value.Data = value.Data[2:] - i := reflect.ValueOf(val) - - return i.Type(), &i, value, nil - } - }) + err = ctx.RegisterKind(reflect.Int32, Int32Type, SerializeIntN(4), DeserializeIntN[int32](4)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int16, Int16Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 2) - binary.BigEndian.PutUint16(data, uint16(value.Int())) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(int16(0)), nil, value, nil - } else { - if len(value.Data) < 2 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize uint16") - } - val := int16(binary.BigEndian.Uint16(value.Data[0:2])) - value.Data = value.Data[2:] - i := reflect.ValueOf(val) - - return i.Type(), &i, value, nil - } - }) + err = ctx.RegisterKind(reflect.Int64, Int64Type, SerializeIntN(8), DeserializeIntN[int64](8)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int32, Int32Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 4) - binary.BigEndian.PutUint32(data, uint32(value.Int())) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(int32(0)), nil, value, nil - } else { - if len(value.Data) < 4 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize uint16") - } - val := int32(binary.BigEndian.Uint32(value.Data[0:4])) - value.Data = value.Data[4:] - i := reflect.ValueOf(val) - - return i.Type(), &i, value, nil - } - }) + err = ctx.RegisterKind(reflect.Int, IntType, SerializeIntN(8), DeserializeIntN[int](8)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint, UIntType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 8) - binary.BigEndian.PutUint64(data, value.Uint()) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(uint(0)), nil, value, nil - } else { - if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize SerializedType") - } - val := uint(binary.BigEndian.Uint64(value.Data[0:8])) - value.Data = value.Data[8:] - i := reflect.ValueOf(val) - - return i.Type(), &i, value, nil - } - }) + err = ctx.RegisterKind(reflect.Uint8, UInt8Type, SerializeUintN(1), DeserializeUintN[uint8](1)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint64, UInt64Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 8) - binary.BigEndian.PutUint64(data, value.Uint()) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(SerializedType(0)), nil, value, nil - } else { - if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize SerializedType") - } - val := binary.BigEndian.Uint64(value.Data[0:8]) - value.Data = value.Data[8:] - i := reflect.ValueOf(val) + err = ctx.RegisterKind(reflect.Uint16, UInt16Type, SerializeUintN(2), DeserializeUintN[uint16](2)) + if err != nil { + return nil, err + } - return i.Type(), &i, value, nil - } - }) + err = ctx.RegisterKind(reflect.Uint32, UInt32Type, SerializeUintN(4), DeserializeUintN[uint32](4)) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int64, Int64Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Int())) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(int64(0)), nil, value, nil - } else { - if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize SerializedType") - } - val := int64(binary.BigEndian.Uint64(value.Data[0:8])) - value.Data = value.Data[8:] - i := reflect.ValueOf(val) + err = ctx.RegisterKind(reflect.Uint64, UInt64Type, SerializeUintN(8), DeserializeUintN[uint64](8)) + if err != nil { + return nil, err + } - return i.Type(), &i, value, nil - } - }) + err = ctx.RegisterKind(reflect.Uint, UIntType, SerializeUintN(8), DeserializeUintN[uint](8)) if err != nil { return nil, err } @@ -1195,10 +933,50 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + // TODO: move functions for string serialize/deserialize out of RegisterKind + /* err = ctx.RegisterType(reflect.TypeOf(StringError("")), ErrorType, nil, nil) if err != nil { return nil, err } + */ + + err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeType, SerializeUintN(4), DeserializeUintN[ExtType](4)) + if err != nil { + return nil, err + } + + // TODO: Make registering interfaces cleaner + var extension Extension = nil + err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtensionType, SerializeInterface, DeserializeInterface[Extension]()) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(ListenerExt{}), SerializedType(ListenerExtType), SerializeStruct[ListenerExt](ctx), DeserializeStruct[ListenerExt](ctx)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(GroupExt{}), SerializedType(GroupExtType), SerializeStruct[GroupExt](ctx), DeserializeStruct[GroupExt](ctx)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(GQLExt{}), SerializedType(GQLExtType), SerializeStruct[GQLExt](ctx), DeserializeStruct[GQLExt](ctx)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeType, SerializeUintN(4), DeserializeUintN[NodeType](4)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) + if err != nil { + return nil, err + } err = ctx.RegisterType(reflect.TypeOf(RandID()), NodeIDType, func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { diff --git a/node.go b/node.go index 4439cdd..a48db82 100644 --- a/node.go +++ b/node.go @@ -88,25 +88,25 @@ type PendingSignal struct { // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { - Key ed25519.PrivateKey `gv:""` + Key ed25519.PrivateKey `gv:"key"` ID NodeID - Type NodeType `gv:""` + Type NodeType `gv:"type"` Extensions map[ExtType]Extension `gv:"extensions"` - Policies map[PolicyType]Policy `gv:""` + Policies map[PolicyType]Policy `gv:"policies"` - PendingACLs map[uuid.UUID]PendingACL `gv:""` - PendingSignals map[uuid.UUID]PendingSignal `gv:""` + PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"` + PendingSignals map[uuid.UUID]PendingSignal `gv:"pending_signal"` // Channel for this node to receive messages from the Context MsgChan chan *Message // Size of MsgChan - BufferSize uint32 `gv:""` + BufferSize uint32 `gv:"buffer_size"` // Channel for this node to process delayed signals TimeoutChan <-chan time.Time Active atomic.Bool - SignalQueue []QueuedSignal `gv:""` + SignalQueue []QueuedSignal `gv:"signal_queue"` NextSignal *QueuedSignal } diff --git a/serialize.go b/serialize.go index c76cfa3..68b2d86 100644 --- a/serialize.go +++ b/serialize.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "reflect" + "sort" ) const ( @@ -25,6 +26,11 @@ func Hash(base string, name string) SerializedType { } type SerializedType uint64 + +func (t SerializedType) String() string { + return fmt.Sprintf("0x%x", uint64(t)) +} + type ExtType SerializedType type NodeType SerializedType type SignalType SerializedType @@ -102,9 +108,353 @@ var ( ReqStateType = NewSerializedType("REQ_STATE") SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") + NodeStructType = NewSerializedType("NODE_STRUCT") + NodeTypeType = NewSerializedType("NODE_TYPE") + ExtTypeType = NewSerializedType("EXT_TYPE") + ExtensionType = NewSerializedType("EXTENSION") NodeIDType = NewSerializedType("NODE_ID") ) +func SerializeUintN(size int)(func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error)){ + var fill_data func([]byte, uint64) = nil + switch size { + case 1: + fill_data = func(data []byte, val uint64) { + data[0] = byte(val) + } + case 2: + fill_data = func(data []byte, val uint64) { + binary.BigEndian.PutUint16(data, uint16(val)) + } + case 4: + fill_data = func(data []byte, val uint64) { + binary.BigEndian.PutUint32(data, uint32(val)) + } + case 8: + fill_data = func(data []byte, val uint64) { + binary.BigEndian.PutUint64(data, val) + } + default: + panic(fmt.Sprintf("Cannot serialize uint of size %d", size)) + } + return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + var data []byte = nil + if value != nil { + data = make([]byte, size) + fill_data(data, value.Uint()) + } + return SerializedValue{ + []SerializedType{ctx_type}, + data, + }, nil + } +} + +func DeserializeUintN[T interface{~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64}](size int)(func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ + var get_uint func([]byte) uint64 + switch size { + case 1: + get_uint = func(data []byte) uint64 { + return uint64(data[0]) + } + case 2: + get_uint = func(data []byte) uint64 { + return uint64(binary.BigEndian.Uint16(data)) + } + case 4: + get_uint = func(data []byte) uint64 { + return uint64(binary.BigEndian.Uint32(data)) + } + case 8: + get_uint = func(data []byte) uint64 { + return binary.BigEndian.Uint64(data) + } + default: + panic(fmt.Sprintf("Cannot deserialize int of size %d", size)) + } + var zero T + uint_type := reflect.TypeOf(zero) + return func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + if value.Data == nil { + return uint_type, nil, value, nil + } else { + var uint_bytes []byte + var err error + uint_bytes, value, err = value.PopData(size) + if err != nil { + return nil, nil, value, err + } + uint_value := reflect.New(uint_type).Elem() + uint_value.SetUint(get_uint(uint_bytes)) + return uint_type, &uint_value, value, nil + } + } +} + +func SerializeIntN(size int)(func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error)){ + var fill_data func([]byte, int64) = nil + switch size { + case 1: + fill_data = func(data []byte, val int64) { + data[0] = byte(val) + } + case 2: + fill_data = func(data []byte, val int64) { + binary.BigEndian.PutUint16(data, uint16(val)) + } + case 4: + fill_data = func(data []byte, val int64) { + binary.BigEndian.PutUint32(data, uint32(val)) + } + case 8: + fill_data = func(data []byte, val int64) { + binary.BigEndian.PutUint64(data, uint64(val)) + } + default: + panic(fmt.Sprintf("Cannot serialize int of size %d", size)) + } + return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + var data []byte = nil + if value != nil { + data = make([]byte, size) + fill_data(data, value.Int()) + } + return SerializedValue{ + []SerializedType{ctx_type}, + data, + }, nil + } +} + +func DeserializeIntN[T interface{~int | ~int8 | ~int16 | ~int32 | ~int64}](size int)(func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ + var get_int func([]byte) int64 + switch size { + case 1: + get_int = func(data []byte) int64 { + return int64(data[0]) + } + case 2: + get_int = func(data []byte) int64 { + return int64(binary.BigEndian.Uint16(data)) + } + case 4: + get_int = func(data []byte) int64 { + return int64(binary.BigEndian.Uint32(data)) + } + case 8: + get_int = func(data []byte) int64 { + return int64(binary.BigEndian.Uint64(data)) + } + default: + panic(fmt.Sprintf("Cannot deserialize int of size %d", size)) + } + var zero T + int_type := reflect.TypeOf(zero) + return func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + if value.Data == nil { + return int_type, nil, value, nil + } else { + var int_bytes []byte + var err error + int_bytes, value, err = value.PopData(size) + if err != nil { + return nil, nil, value, err + } + int_value := reflect.New(int_type).Elem() + int_value.SetInt(get_int(int_bytes)) + return int_type, &int_value, value, nil + } + } +} + +type FieldInfo struct { + Index []int + TypeStack []SerializedType +} + +type StructInfo struct { + Type reflect.Type + FieldOrder []SerializedType + FieldMap map[SerializedType]FieldInfo +} + +func structInfo[T any](ctx *Context)StructInfo{ + var struct_zero T + struct_type := reflect.TypeOf(struct_zero) + field_order := []SerializedType{} + field_map := map[SerializedType]FieldInfo{} + for _, field := range(reflect.VisibleFields(struct_type)) { + gv_tag, tagged_gv := field.Tag.Lookup("gv") + if tagged_gv == false { + continue + } else { + field_hash := Hash(FieldNameBase, gv_tag) + _, exists := field_map[field_hash] + if exists == true { + panic(fmt.Sprintf("gv tag %s is repeated", gv_tag)) + } else { + field_serialized, err := SerializeValue(ctx, field.Type, nil) + if err != nil { + panic(err) + } + field_map[field_hash] = FieldInfo{ + field.Index, + field_serialized.TypeStack, + } + field_order = append(field_order, field_hash) + } + } + } + + sort.Slice(field_order, func(i, j int)bool { + return uint64(field_order[i]) < uint64(field_order[j]) + }) + + return StructInfo{ + struct_type, + field_order, + field_map, + } +} + +func SerializeStruct[T any](ctx *Context)(func(*Context,SerializedType,reflect.Type,*reflect.Value)(SerializedValue,error)){ + struct_info := structInfo[T](ctx) + return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + type_stack := []SerializedType{ctx_type} + var data []byte + if value == nil { + data = nil + } else { + data = make([]byte, 8) + for _, field_hash := range(struct_info.FieldOrder) { + field_hash_bytes := make([]byte, 8) + binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) + field_info := struct_info.FieldMap[field_hash] + field_value := value.FieldByIndex(field_info.Index) + field_serialized, err := SerializeValue(ctx, field_value.Type(), &field_value) + if err != nil { + return SerializedValue{}, err + } + data = append(data, field_hash_bytes...) + data = append(data, field_serialized.Data...) + } + binary.BigEndian.PutUint64(data[0:8], uint64(len(struct_info.FieldOrder))) + } + return SerializedValue{ + type_stack, + data, + }, nil + } +} + +func DeserializeStruct[T any](ctx *Context)(func(*Context,SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ + struct_info := structInfo[T](ctx) + return func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error) { + if value.Data == nil { + return struct_info.Type, nil, value, nil + } else { + var num_fields_bytes []byte + var err error + num_fields_bytes, value, err = value.PopData(8) + if err != nil { + return nil, nil, value, err + } + num_fields := int(binary.BigEndian.Uint64(num_fields_bytes)) + + struct_value := reflect.New(struct_info.Type).Elem() + + for i := 0; i < num_fields; i += 1 { + var field_hash_bytes []byte + field_hash_bytes, value, err = value.PopData(8) + if err != nil { + return nil, nil, value, err + } + field_hash := SerializedType(binary.BigEndian.Uint64(field_hash_bytes)) + field_info, exists := struct_info.FieldMap[field_hash] + if exists == false { + return nil, nil, value, fmt.Errorf("Field 0x%x is not valid for %+v: %d", field_hash, struct_info.Type, i) + } + field_value := struct_value.FieldByIndex(field_info.Index) + + tmp_value := SerializedValue{ + field_info.TypeStack, + value.Data, + } + + var field_reflect *reflect.Value + _, field_reflect, tmp_value, err = DeserializeValue(ctx, tmp_value) + if err != nil { + return nil, nil, value, err + } + value.Data = tmp_value.Data + field_value.Set(*field_reflect) + } + + return struct_info.Type, &struct_value, value, err + } + } +} + +func SerializeInterface(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + var data []byte + type_stack := []SerializedType{ctx_type} + if value == nil { + data = nil + } else if value.IsZero() { + data = []byte{0x01} + } else { + data = []byte{0x00} + elem_value := value.Elem() + elem, err := SerializeValue(ctx, elem_value.Type(), &elem_value) + if err != nil { + return SerializedValue{}, err + } + elem_data, err := elem.MarshalBinary() + if err != nil { + return SerializedValue{}, err + } + data = append(data, elem_data...) + } + return SerializedValue{ + type_stack, + data, + }, nil +} + +func DeserializeInterface[T any]()(func(*Context,SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ + return func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ + var interface_zero T + var interface_type = reflect.ValueOf(&interface_zero).Type().Elem() + if value.Data == nil { + return interface_type, nil, value, nil + } else { + var flag_bytes []byte + var err error + flag_bytes, value, err = value.PopData(1) + if err != nil { + return nil, nil, value, err + } + + interface_value := reflect.New(interface_type).Elem() + nil_flag := flag_bytes[0] + if nil_flag == 0x01 { + } else if nil_flag == 0x00 { + var elem_value *reflect.Value + var elem_ser SerializedValue + elem_ser, value.Data, err = ParseSerializedValue(value.Data) + _, elem_value, _, err = DeserializeValue(ctx, elem_ser) + if err != nil { + return nil, nil, value, err + } + interface_value.Set(*elem_value) + } else { + return nil, nil, value, fmt.Errorf("Unknown interface nil_flag value 0x%x", nil_flag) + } + return interface_type, &interface_value, value, nil + } + } +} + type SerializedValue struct { TypeStack []SerializedType Data []byte @@ -247,7 +597,7 @@ func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *refle ctx_name = kind_info.Reflect.String() } - ctx.Log.Logf("serialize", "Deserializing: %+v(0x%d) - %+v", ctx_name, ctx_type, value.TypeStack) + ctx.Log.Logf("serialize", "Deserializing: %+v(0x%d) - %+v", ctx_name, ctx_type, deserialize) if value.Data == nil { reflect_type, _, value, err = deserialize(ctx, value) diff --git a/serialize_test.go b/serialize_test.go index f49ee57..373a623 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -29,6 +29,16 @@ func TestSerializeBasic(t *testing.T) { testSerialize(t, ctx, map[int8]map[*int8]string{}) + var i interface{} = nil + testSerialize(t, ctx, i) + + testSerializeMap(t, ctx, map[int8]interface{}{ + 0: "abcd", + 1: uint32(12345678), + 2: i, + 3: 123, + }) + testSerializeMap(t, ctx, map[int8]int32{ 0: 1234, 2: 5678, @@ -45,18 +55,63 @@ func TestSerializeBasic(t *testing.T) { }) } +type test struct { + Int int `gv:"int"` + Str string `gv:"string"` +} + +func (s test) String() string { + return fmt.Sprintf("%d:%s", s.Int, s.Str) +} + +func TestSerializeStructTags(t *testing.T) { + ctx := logTestContext(t, []string{"test", "serialize"}) + + test_type := NewSerializedType("TEST_STRUCT") + ctx.Log.Logf("test", "TEST_TYPE: %+v", test_type) + ctx.RegisterType(reflect.TypeOf(test{}), test_type, SerializeStruct[test](ctx), DeserializeStruct[test](ctx)) + + test_int := 10 + test_string := "test" + + ret := testSerialize(t, ctx, test{ + test_int, + test_string, + }) + if ret.Int != test_int { + t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int) + } else if ret.Str != test_string { + t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string) + } + + testSerialize(t, ctx, []test{ + { + test_int, + test_string, + }, + { + test_int * 2, + fmt.Sprintf("%s%s", test_string, test_string), + }, + { + test_int * 4, + fmt.Sprintf("%s%s%s", test_string, test_string, test_string), + }, + }) +} + func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) { v := testSerialize(t, ctx, val) for key, value := range(val) { recreated, exists := v[key] if exists == false { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) + t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) } else if recreated != value { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) + t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) } } if len(v) != len(val) { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) + t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) } } @@ -64,11 +119,11 @@ func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, v := testSerialize(t, ctx, val) for i, original := range(val) { if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) + t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) } for j, o := range(original) { if v[i][j] != o { - t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val)) + t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val) } } } @@ -149,11 +204,12 @@ func testSerializeStruct[T any](t *testing.T, ctx *Context, val T) { } func testSerialize[T any](t *testing.T, ctx *Context, val T) T { - value, err := SerializeAny(ctx, val) + value := reflect.ValueOf(&val).Elem() + value_serialized, err := SerializeValue(ctx, value.Type(), &value) fatalErr(t, err) - ctx.Log.Logf("test", "Serialized %+v to %+v", val, value) + ctx.Log.Logf("test", "Serialized %+v to %+v", val, value_serialized) - ser, err := value.MarshalBinary() + ser, err := value_serialized.MarshalBinary() fatalErr(t, err) ctx.Log.Logf("test", "Binary: %+v", ser) @@ -165,21 +221,26 @@ func testSerialize[T any](t *testing.T, ctx *Context, val T) T { t.Fatal("Data remaining after deserializing value") } - val_type, deserialized_values, remaining_deserialize, err := DeserializeValue(ctx, val_parsed) + val_type, deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_parsed) fatalErr(t, err) if len(remaining_deserialize.Data) != 0 { t.Fatal("Data remaining after deserializing value") } else if len(remaining_deserialize.TypeStack) != 0 { t.Fatal("TypeStack remaining after deserializing value") - } else if val_type != reflect.TypeOf(val) { + } else if val_type != value.Type() { t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val))) - } else if deserialized_values == nil { + } else if deserialized_value == nil { t.Fatal("DeserializeValue returned no []reflect.Value") - } else if deserialized_values == nil { + } else if deserialized_value == nil { t.Fatal("DeserializeValue returned nil *reflect.Value") - } else if deserialized_values.CanConvert(val_type) == false { + } else if deserialized_value.CanConvert(val_type) == false { t.Fatal("DeserializeValue returned value that can't convert to original value") } - return deserialized_values.Interface().(T) + ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface()) + if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil { + var zero T + return zero + } + return deserialized_value.Interface().(T) }