From c4e5054e072613634947cd8bd65ff3e8f1f1907b Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Tue, 12 Sep 2023 16:56:01 -0600 Subject: [PATCH] Fixed array serialization and added NodeID serialization --- context.go | 80 +++++++++++++++------------------------------------- serialize.go | 66 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 59 deletions(-) diff --git a/context.go b/context.go index 07d8213..7899d5b 100644 --- a/context.go +++ b/context.go @@ -569,16 +569,13 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { err = ctx.RegisterKind(reflect.Array, ArrayType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ var data []byte - type_stack := []SerializedType{ctx_type} + type_stack := []SerializedType{ctx_type, SerializedType(reflect_type.Len())} if value == nil { data = nil } else if value.IsZero() { return SerializedValue{}, fmt.Errorf("don't know what zero array means...") - } else if value.Len() == 0 { - data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} } else { - data := make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Len())) + data := []byte{} var element SerializedValue var err error for i := 0; i < value.Len(); i += 1 { @@ -603,18 +600,19 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { data, }, nil }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ + if len(value.TypeStack) < 1 { + return nil, nil, SerializedValue{}, fmt.Errorf("no array size in type stack") + } + array_size := int(value.TypeStack[0]) + value.TypeStack = value.TypeStack[1:] if value.Data == nil { elem_type, _, _, err := DeserializeValue(ctx, value) if err != nil { return nil, nil, SerializedValue{}, err } - return reflect.SliceOf(elem_type), nil, value, nil - } else if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize slice") + return reflect.ArrayOf(array_size, elem_type), nil, value, nil } else { - slice_length := binary.BigEndian.Uint64(value.Data[0:8]) - value.Data = value.Data[8:] - if slice_length == 0x00 { + if array_size == 0x00 { elem_type, _, remaining, err := DeserializeValue(ctx, SerializedValue{ value.TypeStack, nil, @@ -622,7 +620,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { if err != nil { return nil, nil, SerializedValue{}, err } - array_type := reflect.ArrayOf(0, elem_type) + array_type := reflect.ArrayOf(array_size, elem_type) array_value := reflect.New(array_type).Elem() return array_type, &array_value, SerializedValue{ remaining.TypeStack, @@ -632,7 +630,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { var reflect_value *reflect.Value = nil var reflect_type reflect.Type = nil saved_type_stack := value.TypeStack - for i := 0; i < int(slice_length); i += 1 { + for i := 0; i < array_size; i += 1 { var element_type reflect.Type var element_value *reflect.Value element_type, element_value, value, err = DeserializeValue(ctx, value) @@ -640,11 +638,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, nil, value, err } if reflect_value == nil { - reflect_type = reflect.ArrayOf(int(slice_length), element_type) + reflect_type = reflect.ArrayOf(array_size, element_type) real_value := reflect.New(reflect_type).Elem() reflect_value = &real_value } - if i != (int(slice_length) - 1) { + if i != (array_size - 1) { value.TypeStack = saved_type_stack } slice_index_ptr := reflect_value.Index(i) @@ -997,76 +995,44 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - // TODO: Make registering interfaces cleaner - var extension Extension = nil - err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtensionType, SerializeInterface, DeserializeInterface[Extension]()) + err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeType, SerializeUintN(4), DeserializeUintN[NodeType](4)) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(ListenerExt{}), SerializedType(ListenerExtType), SerializeStruct[ListenerExt](ctx), DeserializeStruct[ListenerExt](ctx)) + err = ctx.RegisterType(reflect.TypeOf(RandID()), NodeIDType, SerializeArray, DeserializeArray[NodeID](ctx)) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(GroupExt{}), SerializedType(GroupExtType), SerializeStruct[GroupExt](ctx), DeserializeStruct[GroupExt](ctx)) + // TODO: Make registering interfaces cleaner + var extension Extension = nil + err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtensionType, SerializeInterface, DeserializeInterface[Extension]()) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(GQLExt{}), SerializedType(GQLExtType), SerializeStruct[GQLExt](ctx), DeserializeStruct[GQLExt](ctx)) + err = ctx.RegisterType(reflect.TypeOf(ListenerExt{}), SerializedType(ListenerExtType), SerializeStruct[ListenerExt](ctx), DeserializeStruct[ListenerExt](ctx)) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeType, SerializeUintN(4), DeserializeUintN[NodeType](4)) + err = ctx.RegisterType(reflect.TypeOf(GroupExt{}), SerializedType(GroupExtType), SerializeStruct[GroupExt](ctx), DeserializeStruct[GroupExt](ctx)) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) + err = ctx.RegisterType(reflect.TypeOf(GQLExt{}), SerializedType(GQLExtType), SerializeStruct[GQLExt](ctx), DeserializeStruct[GQLExt](ctx)) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(RandID()), NodeIDType, - func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { - var id_ser []byte = nil - if value != nil { - var err error = nil - id_ser, err = value.Interface().(NodeID).MarshalBinary() - if err != nil { - return SerializedValue{}, err - } - } - return SerializedValue{ - []SerializedType{ctx_type}, - id_ser, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(ZeroID), nil, value, nil - } else { - var err error - var id_bytes []byte - id_bytes, value, err = value.PopData(16) - if err != nil { - return nil, nil, value, err - } - - id, err := IDFromBytes(id_bytes) - if err != nil { - return nil, nil, value, err - } - - id_value := reflect.ValueOf(id) - return id_value.Type(), &id_value, value, nil - } - }) + err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) if err != nil { return nil, err } + err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType, func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { var data []byte = nil diff --git a/serialize.go b/serialize.go index 68b2d86..adf35bd 100644 --- a/serialize.go +++ b/serialize.go @@ -115,7 +115,69 @@ var ( NodeIDType = NewSerializedType("NODE_ID") ) -func SerializeUintN(size int)(func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error)){ +func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + type_stack := []SerializedType{ctx_type} + if value == nil { + return SerializedValue{ + type_stack, + nil, + }, nil + } else if value.IsZero() { + return SerializedValue{}, fmt.Errorf("don't know what zero array means...") + } else { + var element SerializedValue + var err error + var data []byte + for i := 0; i < value.Len(); i += 1 { + val := value.Index(i) + element, err = SerializeValue(ctx, reflect_type.Elem(), &val) + if err != nil { + return SerializedValue{}, err + } + data = append(data, element.Data...) + } + return SerializedValue{ + type_stack, + data, + }, nil + } +} + +func DeserializeArray[T any](ctx *Context)(func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ + var zero T + array_type := reflect.TypeOf(zero) + array_size := array_type.Len() + zero_value, err := SerializeValue(ctx, array_type.Elem(), nil) + if err != nil { + panic(err) + } + saved_type_stack := zero_value.TypeStack + return func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + if value.Data == nil { + return array_type, nil, value, nil + } else { + array_value := reflect.New(array_type).Elem() + for i := 0; i < array_size; i += 1 { + var element_value *reflect.Value + var err error + tmp_value := SerializedValue{ + saved_type_stack, + value.Data, + } + _, element_value, tmp_value, err = DeserializeValue(ctx, tmp_value) + if err != nil { + return nil, nil, value, err + } + value.Data = tmp_value.Data + array_elem := array_value.Index(i) + array_elem.Set(*element_value) + } + return array_type, &array_value, value, nil + } + } +} + +func SerializeUintN(size int)(func(*Context,SerializedType,reflect.Type,*reflect.Value)(SerializedValue,error)){ var fill_data func([]byte, uint64) = nil switch size { case 1: @@ -597,7 +659,7 @@ func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *refle ctx_name = kind_info.Reflect.String() } - ctx.Log.Logf("serialize", "Deserializing: %+v(0x%d) - %+v", ctx_name, ctx_type, deserialize) + ctx.Log.Logf("serialize", "Deserializing: %+v(%+v) - %+v", ctx_name, ctx_type, value.TypeStack) if value.Data == nil { reflect_type, _, value, err = deserialize(ctx, value)