diff --git a/acl_test.go b/acl_test.go index 68767c3..39001c4 100644 --- a/acl_test.go +++ b/acl_test.go @@ -56,7 +56,7 @@ func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *No } func TestACLBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test", "acl", "policy"}) + ctx := logTestContext(t, []string{"serialize_types", "deserialize_types", "test", "listener_debug", "group", "acl", "policy"}) listener, err := NewNode(ctx, nil, BaseNodeType, 100, nil, NewListenerExt(100)) fatalErr(t, err) diff --git a/context.go b/context.go index 69a5db5..27d003f 100644 --- a/context.go +++ b/context.go @@ -2,15 +2,12 @@ package graphvent import ( "crypto/ecdh" - "sort" - "time" - "encoding/binary" "errors" "fmt" - "math" "reflect" "runtime" "sync" + "time" "github.com/google/uuid" badger "github.com/dgraph-io/badger/v3" @@ -33,15 +30,20 @@ type NodeInfo struct { type TypeInfo struct { Reflect reflect.Type Type SerializedType - Serialize TypeSerialize - Deserialize TypeDeserialize + TypeSerialize TypeSerializeFn + Serialize SerializeFn + TypeDeserialize TypeDeserializeFn + Deserialize DeserializeFn } type KindInfo struct { Reflect reflect.Kind + Base reflect.Type Type SerializedType - Serialize TypeSerialize - Deserialize TypeDeserialize + TypeSerialize TypeSerializeFn + Serialize SerializeFn + TypeDeserialize TypeDeserializeFn + Deserialize DeserializeFn } // A Context stores all the data to run a graphvent process @@ -107,12 +109,17 @@ func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type Policy return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type) } - err := ctx.RegisterType(reflect_type, SerializedType(policy_type), SerializeStruct(ctx, reflect_type), DeserializeStruct(ctx, reflect_type)) + policy_info, err := GetStructInfo(ctx, reflect_type) + if err != nil { + return err + } + + err = ctx.RegisterType(reflect_type, SerializedType(policy_type), nil, SerializeStruct(policy_info), nil, DeserializeStruct(policy_info)) if err != nil { return err } - ctx.Log.Logf("serialize", "Registered PolicyType: %+v - %+v", reflect_type, policy_type) + ctx.Log.Logf("serialize_types", "Registered PolicyType: %+v - %+v", reflect_type, policy_type) ctx.Policies[policy_type] = reflect_type ctx.PolicyTypes[reflect_type] = policy_type @@ -125,12 +132,17 @@ func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalT return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type) } - err := ctx.RegisterType(reflect_type, SerializedType(signal_type), SerializeStruct(ctx, reflect_type), DeserializeStruct(ctx, reflect_type)) + signal_info, err := GetStructInfo(ctx, reflect_type) if err != nil { return err } - ctx.Log.Logf("serialize", "Registered SignalType: %+v - %+v", reflect_type, signal_type) + err = ctx.RegisterType(reflect_type, SerializedType(signal_type), nil, SerializeStruct(signal_info), nil, DeserializeStruct(signal_info)) + if err != nil { + return err + } + + ctx.Log.Logf("serialize_types", "Registered SignalType: %+v - %+v", reflect_type, signal_type) ctx.Signals[signal_type] = reflect_type ctx.SignalTypes[reflect_type] = signal_type @@ -144,12 +156,17 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType } elem_type := reflect_type.Elem() - err := ctx.RegisterType(elem_type, SerializedType(ext_type), SerializeStruct(ctx, elem_type), DeserializeStruct(ctx, elem_type)) + elem_info, err := GetStructInfo(ctx, elem_type) + if err != nil { + return err + } + + err = ctx.RegisterType(elem_type, SerializedType(ext_type), nil, SerializeStruct(elem_info), nil, DeserializeStruct(elem_info)) if err != nil { return err } - ctx.Log.Logf("serialize", "Registered ExtType: %+v - %+v", reflect_type, ext_type) + ctx.Log.Logf("serialize_types", "Registered ExtType: %+v - %+v", reflect_type, ext_type) ctx.Extensions[ext_type] = ExtensionInfo{ Type: reflect_type, @@ -160,7 +177,7 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType return nil } -func (ctx *Context)RegisterKind(kind reflect.Kind, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error { +func (ctx *Context)RegisterKind(kind reflect.Kind, base reflect.Type, ctx_type SerializedType, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { _, exists := ctx.Kinds[kind] if exists == true { return fmt.Errorf("Cannot register kind %+v, kind already exists in context", kind) @@ -177,18 +194,23 @@ func (ctx *Context)RegisterKind(kind reflect.Kind, ctx_type SerializedType, seri } info := KindInfo{ - kind, - ctx_type, - serialize, - deserialize, + Reflect: kind, + Base: base, + Type: ctx_type, + TypeSerialize: type_serialize, + Serialize: serialize, + TypeDeserialize: type_deserialize, + Deserialize: deserialize, } ctx.KindTypes[ctx_type] = &info ctx.Kinds[kind] = &info + ctx.Log.Logf("serialize_types", "Registered kind %+v, %+v", kind, ctx_type) + return nil } -func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, serialize TypeSerialize, deserialize TypeDeserialize) error { +func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedType, type_serialize TypeSerializeFn, serialize SerializeFn, type_deserialize TypeDeserializeFn, deserialize DeserializeFn) error { _, exists := ctx.Types[ctx_type] if exists == true { return fmt.Errorf("Cannot register field of type %+v, type already exists in context", ctx_type) @@ -198,20 +220,45 @@ func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedTy return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type) } + if type_serialize == nil || type_deserialize == nil { + kind_info, kind_registered := ctx.Kinds[reflect_type.Kind()] + if kind_registered == true { + if type_serialize == nil { + type_serialize = kind_info.TypeSerialize + } + if type_deserialize == nil { + type_deserialize = kind_info.TypeDeserialize + } + } + } + if serialize == nil || deserialize == nil { - return fmt.Errorf("Cannot register field without serialize/deserialize functions") + kind_info, kind_registered := ctx.Kinds[reflect_type.Kind()] + if kind_registered == false { + return fmt.Errorf("No serialize/deserialize passed and none registered for kind %+v", reflect_type.Kind()) + } else { + if serialize == nil { + serialize = kind_info.Serialize + } + if deserialize == nil { + deserialize = kind_info.Deserialize + } + } } type_info := TypeInfo{ Reflect: reflect_type, Type: ctx_type, + TypeSerialize: type_serialize, Serialize: serialize, + TypeDeserialize: type_deserialize, Deserialize: deserialize, } + ctx.Types[ctx_type] = &type_info ctx.TypeReflects[reflect_type] = &type_info - ctx.Log.Logf("serialize", "Registered Type: %+v - %+v", reflect_type, ctx_type) + ctx.Log.Logf("serialize_types", "Registered Type: %+v - %+v", reflect_type, ctx_type) return nil } @@ -312,994 +359,225 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } var err error - err = ctx.RegisterKind(reflect.Pointer, PointerType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - var data []byte - var elem_value *reflect.Value = nil - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0x01} - } else { - data = []byte{0x00} - ev := value.Elem() - elem_value = &ev - } - elem, err := SerializeValue(ctx, reflect_type.Elem(), elem_value) - if err != nil { - return SerializedValue{}, err - } - if elem.Data != nil { - data = append(data, elem.Data...) - } - return SerializedValue{ - append([]SerializedType{ctx_type}, elem.TypeStack...), - data, - }, nil - }, func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - if value.Data == nil { - var elem_type reflect.Type - var err error - elem_type, _, value, err = DeserializeValue(ctx, value) - if err != nil { - return nil, nil, SerializedValue{}, err - } - return reflect.PointerTo(elem_type), nil, value, nil - } else if len(value.Data) < 1 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize pointer") - } else { - pointer_flags := value.Data[0] - value.Data = value.Data[1:] - ctx.Log.Logf("serialize", "Pointer flags: 0x%x", pointer_flags) - if pointer_flags == 0x00 { - elem_type, elem_value, remaining_data, err := DeserializeValue(ctx, value) - if err != nil { - return nil, nil, SerializedValue{}, err - } - pointer_type := reflect.PointerTo(elem_type) - pointer_value := reflect.New(pointer_type).Elem() - pointer_value.Set(elem_value.Addr()) - return pointer_type, &pointer_value, remaining_data, nil - } else if pointer_flags == 0x01 { - tmp_value := SerializedValue{ - value.TypeStack, - nil, - } - var elem_type reflect.Type - var err error - elem_type, _, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, SerializedValue{}, err - } - value.TypeStack = tmp_value.TypeStack - - pointer_type := reflect.PointerTo(elem_type) - pointer_value := reflect.New(pointer_type).Elem() - return pointer_type, &pointer_value, value, nil - } else { - return nil, nil, SerializedValue{}, fmt.Errorf("unknown pointer flags: %d", pointer_flags) - } - } - }) + err = ctx.RegisterKind(reflect.Pointer, nil, PointerType, SerializeTypeElem, SerializePointer, DeserializeTypePointer, DeserializePointer) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Struct, StructType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - return SerializedValue{}, fmt.Errorf("Cannot serialize unregistered struct %+v", reflect_type) - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - return nil, nil, value, fmt.Errorf("Cannot deserialize unregistered struct") - }) + err = ctx.RegisterKind(reflect.Bool, reflect.TypeOf(true), BoolType, nil, SerializeBool, nil, DeserializeBool[bool]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Bool, BoolType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - b := value.Bool() - if b == true { - data = []byte{0x01} - } else { - data = []byte{0x00} - } - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(true), nil, value, nil - } else if len(value.Data) == 0 { - return nil, nil, SerializedValue{}, fmt.Errorf("not enough data to deserialize bool") - } else { - b := value.Data[0] - value.Data = value.Data[1:] - var val reflect.Value - switch b { - case 0x00: - val = reflect.ValueOf(false) - case 0x01: - val = reflect.ValueOf(true) - default: - return nil, nil, SerializedValue{}, fmt.Errorf("unknown boolean 0x%x", b) - } - return reflect.TypeOf(true), &val, value, nil - } - }) + err = ctx.RegisterKind(reflect.String, reflect.TypeOf(""), StringType, nil, SerializeString, nil, DeserializeString) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Float64, Float64Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 8) - val := math.Float64bits(float64(value.Float())) - binary.BigEndian.PutUint64(data, val) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(float64(0)), nil, value, nil - } else { - if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize float32") - } - val_int := binary.BigEndian.Uint64(value.Data[0:8]) - value.Data = value.Data[8:] - val := math.Float64frombits(val_int) - - float_val := reflect.ValueOf(val) - - return float_val.Type(), &float_val, value, nil - } - }) - if err != nil { + err = ctx.RegisterKind(reflect.Float32, reflect.TypeOf(float32(0)), Float32Type, nil, SerializeFloat32, nil, DeserializeFloat32[float32]) + if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Float32, Float32Type, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte = nil - if value != nil { - data = make([]byte, 4) - val := math.Float32bits(float32(value.Float())) - binary.BigEndian.PutUint32(data, val) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(float32(0)), nil, value, nil - } else { - if len(value.Data) < 4 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize float32") - } - val_int := binary.BigEndian.Uint32(value.Data[0:4]) - value.Data = value.Data[4:] - val := math.Float32frombits(val_int) - - float_value := reflect.ValueOf(val) + err = ctx.RegisterKind(reflect.Float64, reflect.TypeOf(float64(0)), Float64Type, nil, SerializeFloat64, nil, DeserializeFloat64[float64]) + if err != nil { + return nil, err + } - return float_value.Type(), &float_value, value, nil - } - }) + err = ctx.RegisterKind(reflect.Uint, reflect.TypeOf(uint(0)), UIntType, nil, SerializeUint32, nil, DeserializeUint32[uint]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.String, StringType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - if value == nil { - return SerializedValue{ - []SerializedType{ctx_type}, - nil, - }, nil - } - - data := make([]byte, 8) - str := value.String() - binary.BigEndian.PutUint64(data, uint64(len(str))) - return SerializedValue{ - []SerializedType{SerializedType(ctx_type)}, - append(data, []byte(str)...), - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return reflect.TypeOf(""), nil, value, nil - } else if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize string") - } else { - str_len := binary.BigEndian.Uint64(value.Data[0:8]) - value.Data = value.Data[8:] - if len(value.Data) < int(str_len) { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize string of length %d(%d)", str_len, len(value.Data)) - } - string_bytes := value.Data[:str_len] - value.Data = value.Data[str_len:] - str_value := reflect.ValueOf(string(string_bytes)) - return reflect.TypeOf(""), &str_value, value, nil - } - }) + err = ctx.RegisterKind(reflect.Uint8, reflect.TypeOf(uint8(0)), UInt8Type, nil, SerializeUint8, nil, DeserializeUint8[uint8]) if err != nil { return nil, err } - - err = ctx.RegisterKind(reflect.Array, ArrayType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte - type_stack := []SerializedType{ctx_type, SerializedType(reflect_type.Len())} - if value == nil { - data = nil - } else if value.IsZero() { - return SerializedValue{}, fmt.Errorf("don't know what zero array means...") - } else { - data := []byte{} - var element SerializedValue - var err error - for i := 0; i < value.Len(); i += 1 { - val := value.Index(i) - element, err = SerializeValue(ctx, reflect_type.Elem(), &val) - if err != nil { - return SerializedValue{}, err - } - data = append(data, element.Data...) - } - return SerializedValue{ - append(type_stack, element.TypeStack...), - data, - }, nil - } - element, err := SerializeValue(ctx, reflect_type.Elem(), nil) - if err != nil { - return SerializedValue{}, err - } - return SerializedValue{ - append(type_stack, element.TypeStack...), - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if len(value.TypeStack) < 1 { - return nil, nil, SerializedValue{}, fmt.Errorf("no array size in type stack") - } - array_size := int(value.TypeStack[0]) - value.TypeStack = value.TypeStack[1:] - if value.Data == nil { - elem_type, _, _, err := DeserializeValue(ctx, value) - if err != nil { - return nil, nil, SerializedValue{}, err - } - return reflect.ArrayOf(array_size, elem_type), nil, value, nil - } else { - if array_size == 0x00 { - elem_type, _, remaining, err := DeserializeValue(ctx, SerializedValue{ - value.TypeStack, - nil, - }) - if err != nil { - return nil, nil, SerializedValue{}, err - } - array_type := reflect.ArrayOf(array_size, elem_type) - array_value := reflect.New(array_type).Elem() - return array_type, &array_value, SerializedValue{ - remaining.TypeStack, - value.Data, - }, nil - } else { - var reflect_value *reflect.Value = nil - var reflect_type reflect.Type = nil - saved_type_stack := value.TypeStack - for i := 0; i < array_size; i += 1 { - var element_type reflect.Type - var element_value *reflect.Value - element_type, element_value, value, err = DeserializeValue(ctx, value) - if err != nil { - return nil, nil, value, err - } - if reflect_value == nil { - reflect_type = reflect.ArrayOf(array_size, element_type) - real_value := reflect.New(reflect_type).Elem() - reflect_value = &real_value - } - if i != (array_size - 1) { - value.TypeStack = saved_type_stack - } - slice_index_ptr := reflect_value.Index(i) - slice_index_ptr.Set(*element_value) - } - return reflect_type, reflect_value, value, nil - } - } - }) + err = ctx.RegisterKind(reflect.Uint16, reflect.TypeOf(uint16(0)), UInt16Type, nil, SerializeUint16, nil, DeserializeUint16[uint16]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Interface, InterfaceType, SerializeInterface, DeserializeInterface[interface{}]()) + err = ctx.RegisterKind(reflect.Uint32, reflect.TypeOf(uint32(0)), UInt32Type, nil, SerializeUint32, nil, DeserializeUint32[uint32]) if err != nil { return nil, err } - - err = ctx.RegisterKind(reflect.Map, MapType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ - var data []byte - type_stack := []SerializedType{ctx_type} - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} - } else if value.Len() == 0 { - data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - } else { - data = make([]byte, 8) - map_size := 0 - var key_types, elem_types []SerializedType - - map_iter := value.MapRange() - for map_iter.Next() { - map_size += 1 - key_reflect := map_iter.Key() - elem_reflect := map_iter.Value() - - key_value, err := SerializeValue(ctx, key_reflect.Type(), &key_reflect) - if err != nil { - return SerializedValue{}, err - } - elem_value, err := SerializeValue(ctx, elem_reflect.Type(), &elem_reflect) - if err != nil { - return SerializedValue{}, err - } - - data = append(data, key_value.Data...) - data = append(data, elem_value.Data...) - - if key_types == nil { - key_types = key_value.TypeStack - elem_types = elem_value.TypeStack - } - } - - binary.BigEndian.PutUint64(data[0:8], uint64(map_size)) - - ctx.Log.Logf("serialize", "MAP_TYPES: %+v - %+v", key_types, elem_types) - - type_stack = append(type_stack, key_types...) - type_stack = append(type_stack, elem_types...) - return SerializedValue{ - type_stack, - data, - }, nil - } - key_value, err := SerializeValue(ctx, reflect_type.Key(), nil) - if err != nil { - return SerializedValue{}, nil - } - elem_value, err := SerializeValue(ctx, reflect_type.Elem(), nil) - if err != nil { - return SerializedValue{}, nil - } - - type_stack = append(type_stack, key_value.TypeStack...) - type_stack = append(type_stack, elem_value.TypeStack...) - - return SerializedValue{ - type_stack, - data, - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - var key_type, elem_type reflect.Type - var err error - key_type, _, value, err = DeserializeValue(ctx, value) - if err != nil { - return nil, nil, value, err - } - elem_type, _, value, err = DeserializeValue(ctx, value) - if err != nil { - return nil, nil, value, err - } - reflect_type := reflect.MapOf(key_type, elem_type) - return reflect_type, nil, value, nil - } else if len(value.Data) < 8 { - return nil, nil, value, fmt.Errorf("Not enough data to deserialize map") - } else { - var map_size_bytes []byte - var err error - map_size_bytes, value, err = value.PopData(8) - if err != nil { - return nil, nil, value, err - } - - map_size := binary.BigEndian.Uint64(map_size_bytes) - ctx.Log.Logf("serialize", "Deserializing %d elements in map", map_size) - - if map_size == 0xFFFFFFFFFFFFFFFF { - var key_type, elem_type reflect.Type - var err error - tmp_value := SerializedValue{ - value.TypeStack, - nil, - } - key_type, _, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - elem_type, _, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - new_value := SerializedValue{ - tmp_value.TypeStack, - value.Data, - } - reflect_type := reflect.MapOf(key_type, elem_type) - reflect_value := reflect.New(reflect_type).Elem() - return reflect_type, &reflect_value, new_value, nil - } else if map_size == 0x00 { - var key_type, elem_type reflect.Type - var err error - tmp_value := SerializedValue{ - value.TypeStack, - nil, - } - key_type, _, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - elem_type, _, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - new_value := SerializedValue{ - tmp_value.TypeStack, - value.Data, - } - reflect_type := reflect.MapOf(key_type, elem_type) - reflect_value := reflect.MakeMap(reflect_type) - return reflect_type, &reflect_value, new_value, nil - } else { - tmp_value := value - var map_value reflect.Value - var map_type reflect.Type = nil - for i := 0; i < int(map_size); i += 1 { - tmp_value.TypeStack = value.TypeStack - var key_type, elem_type reflect.Type - var key_value, elem_value *reflect.Value - var err error - key_type, key_value, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - elem_type, elem_value, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - if map_type == nil { - map_type = reflect.MapOf(key_type, elem_type) - map_value = reflect.MakeMap(map_type) - } - map_value.SetMapIndex(*key_value, *elem_value) - } - return map_type, &map_value, tmp_value, nil - } - } - }) + err = ctx.RegisterKind(reflect.Uint64, reflect.TypeOf(uint64(0)), UInt64Type, nil, SerializeUint64, nil, DeserializeUint64[uint64]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int8, Int8Type, SerializeIntN(1), DeserializeIntN[int8](1)) + err = ctx.RegisterKind(reflect.Int, reflect.TypeOf(int(0)), IntType, nil, SerializeInt32, nil, DeserializeUint32[int]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int16, Int16Type, SerializeIntN(2), DeserializeIntN[int16](2)) + err = ctx.RegisterKind(reflect.Int8, reflect.TypeOf(int8(0)), Int8Type, nil, SerializeInt8, nil, DeserializeUint8[int8]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int32, Int32Type, SerializeIntN(4), DeserializeIntN[int32](4)) + err = ctx.RegisterKind(reflect.Int16, reflect.TypeOf(int16(0)), Int16Type, nil, SerializeInt16, nil, DeserializeUint16[int16]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int64, Int64Type, SerializeIntN(8), DeserializeIntN[int64](8)) + err = ctx.RegisterKind(reflect.Int32, reflect.TypeOf(int32(0)), Int32Type, nil, SerializeInt32, nil, DeserializeUint32[int32]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Int, IntType, SerializeIntN(8), DeserializeIntN[int](8)) + err = ctx.RegisterKind(reflect.Int64, reflect.TypeOf(int64(0)), Int64Type, nil, SerializeInt64, nil, DeserializeUint64[int64]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint8, UInt8Type, SerializeUintN(1), DeserializeUintN[uint8](1)) + err = ctx.RegisterType(reflect.TypeOf(time.Duration(0)), DurationType, nil, nil, nil, DeserializeUint64[time.Duration]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint16, UInt16Type, SerializeUintN(2), DeserializeUintN[uint16](2)) + err = ctx.RegisterType(reflect.TypeOf(time.Time{}), TimeType, nil, SerializeGob, nil, DeserializeGob[time.Time]) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint32, UInt32Type, SerializeUintN(4), DeserializeUintN[uint32](4)) + err = ctx.RegisterKind(reflect.Map, nil, MapType, SerializeTypeMap, SerializeMap, DeserializeTypeMap, DeserializeMap) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint64, UInt64Type, SerializeUintN(8), DeserializeUintN[uint64](8)) + err = ctx.RegisterKind(reflect.Array, nil, ArrayType, SerializeTypeArray, SerializeArray, DeserializeTypeArray, DeserializeArray) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Uint, UIntType, SerializeUintN(8), DeserializeUintN[uint](8)) + err = ctx.RegisterKind(reflect.Slice, nil, SliceType, SerializeTypeElem, SerializeSlice, DeserializeTypeSlice, DeserializeSlice) if err != nil { return nil, err } - err = ctx.RegisterKind(reflect.Slice, SliceType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - type_stack := []SerializedType{ctx_type} - var data []byte - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} - } else if value.Len() == 0 { - data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - } else { - data := make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Len())) - - var element SerializedValue - var err error - for i := 0; i < value.Len(); i += 1 { - val := value.Index(i) - element, err = SerializeValue(ctx, reflect_type.Elem(), &val) - if err != nil { - return SerializedValue{}, err - } - data = append(data, element.Data...) - } - - return SerializedValue{ - append(type_stack, element.TypeStack...), - data, - }, nil - } - - element, err := SerializeValue(ctx, reflect_type.Elem(), nil) - if err != nil { - return SerializedValue{}, err - } - - return SerializedValue{ - append(type_stack, element.TypeStack...), - data, - }, nil - }, - func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - elem_type, _, _, err := DeserializeValue(ctx, value) - if err != nil { - return nil, nil, SerializedValue{}, err - } - return reflect.SliceOf(elem_type), nil, value, nil - } else if len(value.Data) < 8 { - return nil, nil, SerializedValue{}, fmt.Errorf("Not enough data to deserialize slice") - } else { - slice_length := binary.BigEndian.Uint64(value.Data[0:8]) - value.Data = value.Data[8:] - if slice_length == 0xFFFFFFFFFFFFFFFF { - elem_type, _, remaining, err := DeserializeValue(ctx, SerializedValue{ - value.TypeStack, - nil, - }) - if err != nil { - return nil, nil, SerializedValue{}, err - } - reflect_type := reflect.SliceOf(elem_type) - reflect_value := reflect.New(reflect_type).Elem() - return reflect_type, &reflect_value, SerializedValue{ - remaining.TypeStack, - value.Data, - }, nil - } else if slice_length == 0x00 { - elem_type, _, remaining, err := DeserializeValue(ctx, SerializedValue{ - value.TypeStack, - nil, - }) - if err != nil { - return nil, nil, SerializedValue{}, err - } - reflect_value := reflect.MakeSlice(reflect.SliceOf(elem_type), 0, 0) - return reflect_value.Type(), &reflect_value, SerializedValue{ - remaining.TypeStack, - value.Data, - }, nil - } else { - var reflect_value *reflect.Value = nil - var reflect_type reflect.Type = nil - saved_type_stack := value.TypeStack - for i := 0; i < int(slice_length); i += 1 { - var element_type reflect.Type - var element_value *reflect.Value - element_type, element_value, value, err = DeserializeValue(ctx, value) - if err != nil { - return nil, nil, value, err - } - if reflect_value == nil { - reflect_type = reflect.SliceOf(element_type) - real_value := reflect.MakeSlice(reflect_type, int(slice_length), int(slice_length)) - reflect_value = &real_value - } - if i != (int(slice_length) - 1) { - value.TypeStack = saved_type_stack - } - slice_index_ptr := reflect_value.Index(i) - slice_index_ptr.Set(*element_value) - } - return reflect_type, reflect_value, value, nil - } - } - }) + var ptr interface{} = nil + err = ctx.RegisterKind(reflect.Interface, reflect.TypeOf(&ptr).Elem(), InterfaceType, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Tree{}), TreeType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ - var data []byte - type_stack := []SerializedType{ctx_type} - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} - } else if value.Len() == 0 { - data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - } else { - data = make([]byte, 8) - map_size := 0 - - map_iter := value.MapRange() - type TreeMapValue struct{ - Key SerializedType - Data []byte - } - value_stacks := []TreeMapValue{} - for map_iter.Next() { - map_size += 1 - key_reflect := map_iter.Key() - elem_reflect := map_iter.Value() - - key_value, err := SerializeValue(ctx, key_reflect.Type(), &key_reflect) - if err != nil { - return SerializedValue{}, err - } - elem_value, err := SerializeValue(ctx, elem_reflect.Type(), &elem_reflect) - if err != nil { - return SerializedValue{}, err - } - - value_stacks = append(value_stacks, TreeMapValue{ - SerializedType(key_reflect.Uint()), - append(key_value.Data, elem_value.Data...), - }) - } - - // Sort the value_stacks, then add them to `data` - sort.Slice(value_stacks, func(i, j int) bool { - return value_stacks[i].Key > value_stacks[j].Key - }) - - for _, stack := range(value_stacks) { - data = append(data, stack.Data...) - } - - binary.BigEndian.PutUint64(data[0:8], uint64(map_size)) - } - return SerializedValue{ - type_stack, - data, - }, nil - },func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ - if value.Data == nil { - return reflect.TypeOf(Tree{}), nil, value, nil - } else if len(value.Data) < 8 { - return nil, nil, value, fmt.Errorf("Not enough data to deserialize Tree") - } else { - var map_size_bytes []byte - var err error - map_size_bytes, value, err = value.PopData(8) - if err != nil { - return nil, nil, value, err - } - - map_size := binary.BigEndian.Uint64(map_size_bytes) - ctx.Log.Logf("serialize", "Deserializing %d elements in Tree", map_size) - - if map_size == 0xFFFFFFFFFFFFFFFF { - reflect_type := reflect.TypeOf(Tree{}) - reflect_value := reflect.New(reflect_type).Elem() - return reflect_type, &reflect_value, value, nil - } else if map_size == 0x00 { - reflect_type := reflect.TypeOf(Tree{}) - reflect_value := reflect.MakeMap(reflect_type) - return reflect_type, &reflect_value, value, nil - } else { - reflect_type := reflect.TypeOf(Tree{}) - reflect_value := reflect.MakeMap(reflect_type) - - tmp_value := value - - for i := 0; i < int(map_size); i += 1 { - tmp_value.TypeStack = append([]SerializedType{SerializedTypeSerialized, TreeType}, value.TypeStack...) - - var key_value, elem_value *reflect.Value - var err error - _, key_value, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - _, elem_value, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - reflect_value.SetMapIndex(*key_value, *elem_value) - } - - return reflect_type, &reflect_value, tmp_value, nil - } - } - }) - - err = ctx.RegisterType(reflect.TypeOf(SerializedType(0)), SerializedTypeSerialized, SerializeUintN(8), DeserializeUintN[SerializedType](8)) + err = ctx.RegisterType(reflect.TypeOf(SerializedType(0)), SerializedTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[SerializedType]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Changes{}), ChangesSerialized, SerializeSlice, DeserializeSlice[Changes](ctx)) + err = ctx.RegisterType(reflect.TypeOf(Changes{}), ChangesSerialized, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[Changes], DeserializeSlice) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeSerialized, SerializeUintN(8), DeserializeUintN[ExtType](8)) + err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[ExtType]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeSerialized, SerializeUintN(8), DeserializeUintN[NodeType](8)) + err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[NodeType]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(PolicyType(0)), PolicyTypeSerialized, SerializeUintN(8), DeserializeUintN[PolicyType](8)) + err = ctx.RegisterType(reflect.TypeOf(PolicyType(0)), PolicyTypeSerialized, nil, SerializeUint64, nil, DeserializeUint64[PolicyType]) if err != nil { return nil, err } node_id_type := reflect.TypeOf(RandID()) - err = ctx.RegisterType(node_id_type, NodeIDType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ - type_stack := []SerializedType{ctx_type} - if value == nil { - return SerializedValue{ - type_stack, - nil, - }, nil - } else { - data, err := value.Interface().(NodeID).MarshalBinary() - if err != nil { - return SerializedValue{}, err - } - return SerializedValue{ - type_stack, - data, - }, nil - } - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return node_id_type, nil, value, nil - } else { - id_data, value, err := value.PopData(16) - if err != nil { - return nil, nil, value, err - } - - id, err := IDFromBytes(id_data) - if err != nil { - return nil, nil, value, err - } - - id_value := reflect.New(node_id_type).Elem() - id_value.Set(reflect.ValueOf(id)) - return node_id_type, &id_value, value, nil - } - }) + err = ctx.RegisterType(node_id_type, NodeIDType, SerializeTypeStub, SerializeArray, DeserializeTypeStub[NodeID], DeserializeArray) if err != nil { return nil, err } uuid_type := reflect.TypeOf(uuid.UUID{}) - err = ctx.RegisterType(uuid_type, UUIDType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ - type_stack := []SerializedType{ctx_type} - if value == nil { - return SerializedValue{ - type_stack, - nil, - }, nil - } else { - data, err := value.Interface().(uuid.UUID).MarshalBinary() - if err != nil { - return SerializedValue{}, err - } - return SerializedValue{ - type_stack, - data, - }, nil - } - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - if value.Data == nil { - return uuid_type, nil, value, nil - } else { - id_data, value, err := value.PopData(16) - if err != nil { - return nil, nil, value, err - } - - id, err := uuid.FromBytes(id_data) - if err != nil { - return nil, nil, value, err - } - - id_value := reflect.New(uuid_type).Elem() - id_value.Set(reflect.ValueOf(id)) - return uuid_type, &id_value, value, nil - } - }) + err = ctx.RegisterType(uuid_type, UUIDType, SerializeTypeStub, SerializeArray, DeserializeTypeStub[uuid.UUID], DeserializeArray) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType, SerializeUintN(1), DeserializeUintN[SignalDirection](1)) + err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType, nil, SerializeUint8, nil, DeserializeUint8[SignalDirection]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(ReqState(0)), ReqStateType, SerializeUintN(1), DeserializeUintN[ReqState](1)) + err = ctx.RegisterType(reflect.TypeOf(ReqState(0)), ReqStateType, nil, SerializeUint8, nil, DeserializeUint8[ReqState]) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(time.Duration(0)), DurationType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ - var data []byte - type_stack := []SerializedType{ctx_type} - if value == nil { - data = nil - } else { - data = make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Int())) - } - return SerializedValue{ - type_stack, - data, - }, nil - },func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ - if value.Data == nil { - return reflect.TypeOf(time.Duration(0)), nil, value, nil - } else { - var bytes []byte - var err error - bytes, value, err = value.PopData(8) - if err != nil { - return nil, nil, value, err - } - duration := time.Duration(int64(binary.BigEndian.Uint64(bytes))) - duration_value := reflect.ValueOf(duration) - return duration_value.Type(), &duration_value, value, nil - } - }) - - err = ctx.RegisterType(reflect.TypeOf(time.Time{}), TimeType, - func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error) { - var data []byte - type_stack := []SerializedType{ctx_type} - if value == nil { - data = nil - } else { - data = make([]byte, 8) - time_ser, err := value.Interface().(time.Time).GobEncode() - if err != nil { - return SerializedValue{}, err - } - data = append(data, time_ser...) - binary.BigEndian.PutUint64(data[0:8], uint64(len(time_ser))) - } - return SerializedValue{ - type_stack, - data, - }, nil - },func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ - if value.Data == nil { - return reflect.TypeOf(time.Time{}), nil, value, nil - } else { - var ser_size_bytes []byte - ser_size_bytes, value, err = value.PopData(8) - if err != nil { - return nil, nil, value, err - } - ser_size := int(binary.BigEndian.Uint64(ser_size_bytes)) - if ser_size > len(value.Data) { - return nil, nil, value, fmt.Errorf("ser_size %d is larger than remaining data %d", ser_size, len(value.Data)) - } - data := value.Data[0:ser_size] - value.Data = value.Data[ser_size:] - time_value := reflect.New(reflect.TypeOf(time.Time{})) - time_value.Interface().(*time.Time).GobDecode(data) - time_nonptr := time_value.Elem() - return time_nonptr.Type(), &time_nonptr, value, nil - } - }) + err = ctx.RegisterType(reflect.TypeOf(Tree{}), TreeType, SerializeTypeStub, nil, DeserializeTypeStub[Tree], nil) - // TODO: Make registering interfaces cleaner var extension Extension = nil - err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtSerialized, SerializeInterface, DeserializeInterface[Extension]()) + err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtSerialized, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } var policy Policy = nil - err = ctx.RegisterType(reflect.ValueOf(&policy).Type().Elem(), PolicySerialized, SerializeInterface, DeserializeInterface[Policy]()) + err = ctx.RegisterType(reflect.ValueOf(&policy).Type().Elem(), PolicySerialized, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } var signal Signal = nil - err = ctx.RegisterType(reflect.ValueOf(&signal).Type().Elem(), SignalSerialized, SerializeInterface, DeserializeInterface[Signal]()) + err = ctx.RegisterType(reflect.ValueOf(&signal).Type().Elem(), SignalSerialized, nil, SerializeInterface, nil, DeserializeInterface) if err != nil { return nil, err } pending_acl_type := reflect.TypeOf(PendingACL{}) - err = ctx.RegisterType(pending_acl_type, PendingACLType, SerializeStruct(ctx, pending_acl_type), DeserializeStruct(ctx, pending_acl_type)) + pending_acl_info, err := GetStructInfo(ctx, pending_acl_type) + if err != nil { + return nil, err + } + err = ctx.RegisterType(pending_acl_type, PendingACLType, nil, SerializeStruct(pending_acl_info), nil, DeserializeStruct(pending_acl_info)) if err != nil { return nil, err } pending_signal_type := reflect.TypeOf(PendingSignal{}) - err = ctx.RegisterType(pending_signal_type, PendingSignalType, SerializeStruct(ctx, pending_signal_type), DeserializeStruct(ctx, pending_signal_type)) + pending_signal_info, err := GetStructInfo(ctx, pending_signal_type) + if err != nil { + return nil, err + } + err = ctx.RegisterType(pending_signal_type, PendingSignalType, nil, SerializeStruct(pending_signal_info), nil, DeserializeStruct(pending_signal_info)) if err != nil { return nil, err } queued_signal_type := reflect.TypeOf(QueuedSignal{}) - err = ctx.RegisterType(queued_signal_type, QueuedSignalType, SerializeStruct(ctx, queued_signal_type), DeserializeStruct(ctx, queued_signal_type)) + queued_signal_info, err := GetStructInfo(ctx, queued_signal_type) + if err != nil { + return nil, err + } + err = ctx.RegisterType(queued_signal_type, QueuedSignalType, nil, SerializeStruct(queued_signal_info), nil, DeserializeStruct(queued_signal_info)) if err != nil { return nil, err } node_type := reflect.TypeOf(Node{}) - err = ctx.RegisterType(node_type, NodeStructType, SerializeStruct(ctx, node_type), DeserializeStruct(ctx, node_type)) + node_info, err := GetStructInfo(ctx, node_type) + if err != nil { + return nil, err + } + err = ctx.RegisterType(node_type, NodeStructType, nil, SerializeStruct(node_info), nil, DeserializeStruct(node_info)) if err != nil { return nil, err } req_info_type := reflect.TypeOf(ReqInfo{}) - err = ctx.RegisterType(req_info_type, ReqInfoType, SerializeStruct(ctx, req_info_type), DeserializeStruct(ctx, req_info_type)) + req_info_info, err := GetStructInfo(ctx, req_info_type) + if err != nil { + return nil, err + } + err = ctx.RegisterType(req_info_type, ReqInfoType, nil, SerializeStruct(req_info_info), nil, DeserializeStruct(req_info_info)) if err != nil { return nil, err } @@ -1330,6 +608,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterExtension(reflect.TypeOf((*EventExt)(nil)), EventExtType, nil) + if err != nil { + return nil, err + } + err = ctx.RegisterPolicy(reflect.TypeOf(MemberOfPolicy{}), MemberOfPolicyType) if err != nil { return nil, err @@ -1440,6 +723,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterSignal(reflect.TypeOf(EventControlSignal{}), EventControlSignalType) + if err != nil { + return nil, err + } + + err = ctx.RegisterSignal(reflect.TypeOf(EventStateSignal{}), EventStateSignalType) + if err != nil { + return nil, err + } + err = ctx.RegisterNodeType(BaseNodeType, []ExtType{}) if err != nil { return nil, err diff --git a/event.go b/event.go new file mode 100644 index 0000000..a0361fe --- /dev/null +++ b/event.go @@ -0,0 +1,111 @@ +package graphvent + +import ( + "time" + "fmt" +) + +type EventExt struct { + Name string `"name"` + State string `"state"` + Parent *NodeID `"parent"` +} + +func NewEventExt(parent *NodeID, name string) *EventExt { + return &EventExt{ + Name: name, + State: "init", + Parent: parent, + } +} + +type EventStateSignal struct { + SignalHeader + Source NodeID + State string + Time time.Time +} + +func (signal EventStateSignal) Permission() Tree { + return Tree{ + SerializedType(StatusType): nil, + } +} + +func (signal EventStateSignal) String() string { + return fmt.Sprintf("EventStateSignal(%s, %s, %s, %+v)", signal.SignalHeader, signal.Source, signal.State, signal.Time) +} + +func NewEventStateSignal(source NodeID, state string, t time.Time) *EventStateSignal { + return &EventStateSignal{ + SignalHeader: NewSignalHeader(Up), + Source: source, + State: state, + Time: t, + } +} + +type EventControlSignal struct { + SignalHeader + Command string +} + +func NewEventControlSignal(command string) *EventControlSignal { + return &EventControlSignal{ + NewSignalHeader(Direct), + command, + } +} + +func (signal EventControlSignal) Permission() Tree { + return Tree{ + SerializedType(EventControlSignalType): { + Hash("command", signal.Command): nil, + }, + } +} + +var transitions = map[string]struct{ + from_state string + to_state string +}{ + "start": { + "init", + "running", + }, + "stop": { + "running", + "init", + }, + "finish": { + "running", + "done", + }, +} + +func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { + var messages Messages = nil + var changes Changes = nil + + if signal.Direction() == Up && ext.Parent != nil { + messages = messages.Add(ctx, *ext.Parent, node, nil, signal) + } + + switch sig := signal.(type) { + case *EventControlSignal: + info, exists := transitions[sig.Command] + if exists == true { + if ext.State == info.from_state { + ext.State = info.to_state + messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) + node.QueueSignal(time.Now(), NewEventStateSignal(node.ID, ext.State, time.Now())) + } else { + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "bad_state")) + } + } else { + messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "bad_command")) + } + } + + return messages, changes +} diff --git a/event_test.go b/event_test.go new file mode 100644 index 0000000..f9ad1ad --- /dev/null +++ b/event_test.go @@ -0,0 +1,13 @@ +package graphvent + +import ( + "testing" +) + +func TestEvent(t *testing.T) { + ctx := logTestContext(t, []string{"event", "listener"}) + event_listener := NewListenerExt(100) + _, err := NewNode(ctx, nil, BaseNodeType, 100, nil, NewEventExt(nil, "Test Event"), event_listener) + fatalErr(t, err) + +} diff --git a/go.mod b/go.mod index f56e843..65ed97f 100644 --- a/go.mod +++ b/go.mod @@ -31,5 +31,6 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/stretchr/testify v1.8.2 // indirect go.opencensus.io v0.22.5 // indirect - golang.org/x/sys v0.6.0 // indirect + golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect + golang.org/x/sys v0.13.0 // indirect ) diff --git a/go.sum b/go.sum index 97a2b05..623b3f1 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -142,6 +144,8 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/gql.go b/gql.go index 3fdb7c8..3525bb1 100644 --- a/gql.go +++ b/gql.go @@ -924,18 +924,19 @@ func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, return nil, fmt.Errorf(string(val_ser.Data)) } - field_type, field_value, _, err := DeserializeValue(ctx.Context, val_ser) + field_type, _, err := DeserializeType(ctx.Context, val_ser.TypeStack) if err != nil { return nil, err } - if field_value == nil { - return nil, fmt.Errorf("%s returned a nil value of %+v type", gv_tag, field_type) + field_value, _, err := DeserializeValue(ctx.Context, field_type, val_ser.Data) + if err != nil { + return nil, err } ctx.Context.Log.Logf("gql", "Resolving %+v", field_value) - return resolve_fn(p, ctx, *field_value) + return resolve_fn(p, ctx, field_value) } ctx.Fields[gql_name] = Field{ext_type, gv_tag, &graphql.Field{ @@ -1286,6 +1287,21 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } + err = context.RegisterField(graphql.String, "EventName", EventExtType, "name", func(p graphql.ResolveParams, ctx *ResolveContext, val reflect.Value)(interface{}, error) { + name := val.String() + return name, nil + }) + + err = context.RegisterField(graphql.String, "State", EventExtType, "state", func(p graphql.ResolveParams, ctx *ResolveContext, val reflect.Value)(interface{}, error) { + state := val.String() + return state, nil + }) + + err = context.RegisterInterface("Event", "EventNode", []string{"Node"}, []string{"EventName", "State"}, map[string]SelfField{}, map[string]ListField{}) + if err != nil { + panic(err) + } + sub_group_type := graphql.NewObject(graphql.ObjectConfig{ Name: "SubGroup", Interfaces: nil, diff --git a/gql_test.go b/gql_test.go index 5916c87..fff7759 100644 --- a/gql_test.go +++ b/gql_test.go @@ -43,7 +43,7 @@ func TestGQLAuth(t *testing.T) { } func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "gqlws"}) + ctx := logTestContext(t, []string{"test", "deserialize_types", "serialize_types", "gqlws"}) TestNodeType := NewNodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) diff --git a/group.go b/group.go index 09b9867..b5f12b5 100644 --- a/group.go +++ b/group.go @@ -135,8 +135,15 @@ func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si return Deny } - _, sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_ser) + sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack) if err != nil { + ctx.Log.Logf("group", "Type deserialize error: %s", err) + return Deny + } + + sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data) + if err != nil { + ctx.Log.Logf("group", "Value deserialize error: %s", err) return Deny } diff --git a/group_test.go b/group_test.go index ae9bff0..4f87b85 100644 --- a/group_test.go +++ b/group_test.go @@ -55,10 +55,16 @@ func TestGroupAdd(t *testing.T) { sub_groups_serialized := read_response.Extensions[GroupExtType]["sub_groups"] - _, sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_serialized) + sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack) + fatalErr(t, err) + if len(remaining_types) > 0 { + t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types)) + } - if len(remaining.Data) > 0 { - t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining.Data)) + sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data) + fatalErr(t, err) + if len(remaining) > 0 { + t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types)) } sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID) diff --git a/lockable.go b/lockable.go index f06c320..ffde98a 100644 --- a/lockable.go +++ b/lockable.go @@ -353,7 +353,12 @@ func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingAC return Deny } - _, reqs_if, _, err := DeserializeValue(ctx, reqs_ser) + reqs_type, _, err := DeserializeType(ctx, reqs_ser.TypeStack) + if err != nil { + return Deny + } + + reqs_if, _, err := DeserializeValue(ctx, reqs_type, reqs_ser.Data) if err != nil { return Deny } diff --git a/message.go b/message.go new file mode 100644 index 0000000..f6466f7 --- /dev/null +++ b/message.go @@ -0,0 +1,114 @@ +package graphvent + +import ( + "time" + "crypto/ed25519" + "crypto/rand" + "crypto" +) + +type AuthInfo struct { + // The Node that issued the authorization + Identity ed25519.PublicKey + + // Time the authorization was generated + Start time.Time + + // Signature of Start + Principal with Identity private key + Signature []byte +} + +type AuthorizationToken struct { + AuthInfo + + // The private key generated by the client, encrypted with the servers public key + KeyEncrypted []byte +} + +type ClientAuthorization struct { + AuthInfo + + // The private key generated by the client + Key ed25519.PrivateKey +} + +// Authorization structs can be passed in a message that originated from a different node than the sender +type Authorization struct { + AuthInfo + + // The public key generated for this authorization + Key ed25519.PublicKey +} + +type Message struct { + Dest NodeID + Source ed25519.PublicKey + + Authorization *Authorization + + Signal Signal + Signature []byte +} + +type Messages []*Message +func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages { + msg, err := NewMessage(ctx, dest, source, authorization, signal) + if err != nil { + panic(err) + } else { + msgs = append(msgs, msg) + } + return msgs +} + +func NewMessages(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signals... Signal) Messages { + messages := Messages{} + for _, signal := range(signals) { + messages = messages.Add(ctx, dest, source, authorization, signal) + } + return messages +} + +func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) { + signal_ser, err := SerializeAny(ctx, signal) + if err != nil { + return nil, err + } + + ser, err := signal_ser.MarshalBinary() + if err != nil { + return nil, err + } + + dest_ser, err := dest.MarshalBinary() + if err != nil { + return nil, err + } + source_ser, err := source.ID.MarshalBinary() + if err != nil { + return nil, err + } + sig_data := append(dest_ser, source_ser...) + sig_data = append(sig_data, ser...) + var message_auth *Authorization = nil + if authorization != nil { + sig_data = append(sig_data, authorization.Signature...) + message_auth = &Authorization{ + authorization.AuthInfo, + authorization.Key.Public().(ed25519.PublicKey), + } + } + + sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) + if err != nil { + return nil, err + } + + return &Message{ + Dest: dest, + Source: source.Key.Public().(ed25519.PublicKey), + Authorization: message_auth, + Signal: signal, + Signature: sig, + }, nil +} diff --git a/node.go b/node.go index 3ea825c..654ba70 100644 --- a/node.go +++ b/node.go @@ -8,7 +8,6 @@ import ( badger "github.com/dgraph-io/badger/v3" "fmt" "sync/atomic" - "crypto" "crypto/ed25519" "crypto/sha512" "crypto/rand" @@ -476,105 +475,6 @@ func nodeLoop(ctx *Context, node *Node) error { return nil } -type AuthInfo struct { - // The Node that issued the authorization - Identity ed25519.PublicKey - - // Time the authorization was generated - Start time.Time - - // Signature of Start + Principal with Identity private key - Signature []byte -} - -type AuthorizationToken struct { - AuthInfo - - // The private key generated by the client, encrypted with the servers public key - KeyEncrypted []byte -} - -type ClientAuthorization struct { - AuthInfo - - // The private key generated by the client - Key ed25519.PrivateKey -} - -// Authorization structs can be passed in a message that originated from a different node than the sender -type Authorization struct { - AuthInfo - - // The public key generated for this authorization - Key ed25519.PublicKey -} - -type Message struct { - Dest NodeID - Source ed25519.PublicKey - - Authorization *Authorization - - Signal Signal - Signature []byte -} - -type Messages []*Message -func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages { - msg, err := NewMessage(ctx, dest, source, authorization, signal) - if err != nil { - panic(err) - } else { - msgs = append(msgs, msg) - } - return msgs -} - -func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) { - signal_ser, err := SerializeAny(ctx, signal) - if err != nil { - return nil, err - } - - ser, err := signal_ser.MarshalBinary() - if err != nil { - return nil, err - } - - dest_ser, err := dest.MarshalBinary() - if err != nil { - return nil, err - } - source_ser, err := source.ID.MarshalBinary() - if err != nil { - return nil, err - } - sig_data := append(dest_ser, source_ser...) - sig_data = append(sig_data, ser...) - var message_auth *Authorization = nil - if authorization != nil { - sig_data = append(sig_data, authorization.Signature...) - message_auth = &Authorization{ - authorization.AuthInfo, - authorization.Key.Public().(ed25519.PublicKey), - } - } - - sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) - if err != nil { - return nil, err - } - - return &Message{ - Dest: dest, - Source: source.Key.Public().(ed25519.PublicKey), - Authorization: message_auth, - Signal: signal, - Signature: sig, - }, nil -} - - func (node *Node) Stop(ctx *Context) error { if node.Active.Load() { msg, err := NewMessage(ctx, node.ID, node, nil, NewStopSignal()) @@ -805,16 +705,18 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { } else if len(remaining) != 0 { return nil, fmt.Errorf("%d bytes left after parsing node from DB", len(remaining)) } - _, node_val, remaining_data, err := DeserializeValue(ctx, value) + node_type, remaining_types, err := DeserializeType(ctx, value.TypeStack) if err != nil { return nil, err + } else if len(remaining_types) != 0 { + return nil, fmt.Errorf("%d entries left in typestack after deserializing *Node", len(remaining_types)) } - if len(remaining_data.TypeStack) != 0 { - return nil, fmt.Errorf("%d entries left in typestack after deserializing *Node", len(remaining_data.TypeStack)) - } - if len(remaining_data.Data) != 0 { - return nil, fmt.Errorf("%d bytes left after desrializing *Node", len(remaining_data.Data)) + node_val, remaining_data, err := DeserializeValue(ctx, node_type, value.Data) + if err != nil { + return nil, err + } else if len(remaining_data) != 0 { + return nil, fmt.Errorf("%d bytes left after desrializing *Node", len(remaining_data)) } node, ok := node_val.Interface().(*Node) diff --git a/serialize.go b/serialize.go index 7e228c4..d144a55 100644 --- a/serialize.go +++ b/serialize.go @@ -3,7 +3,9 @@ package graphvent import ( "crypto/sha512" "encoding/binary" + "encoding/gob" "fmt" + "math" "reflect" "sort" ) @@ -55,8 +57,10 @@ func (t PolicyType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } -type TypeSerialize func(*Context, SerializedType, reflect.Type, *reflect.Value) (SerializedValue, error) -type TypeDeserialize func(*Context, SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) +type TypeSerializeFn func(*Context, reflect.Type) ([]SerializedType, error) +type SerializeFn func(*Context, reflect.Value) ([]byte, error) +type TypeDeserializeFn func(*Context, []SerializedType) (reflect.Type, []SerializedType, error) +type DeserializeFn func(*Context, reflect.Type, []byte) (reflect.Value, []byte, error) func NewExtType(name string) ExtType { return ExtType(Hash(ExtTypeBase, name)) @@ -84,6 +88,7 @@ var ( GQLExtType = NewExtType("GQL") GroupExtType = NewExtType("GROUP") ACLExtType = NewExtType("ACL") + EventExtType = NewExtType("EVENT") GQLNodeType = NewNodeType("GQL") BaseNodeType = NewNodeType("BASE") @@ -107,6 +112,8 @@ var ( AddSubGroupSignalType = NewSignalType("ADD_SUBGROUP") RemoveSubGroupSignalType = NewSignalType("REMOVE_SUBGROUP") StoppedSignalType = NewSignalType("STOPPED") + EventControlSignalType = NewSignalType("EVENT_CONTORL") + EventStateSignalType = NewSignalType("VEX_MATCH_STATUS") MemberOfPolicyType = NewPolicyType("USER_OF") RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") @@ -160,306 +167,10 @@ var ( SerializedTypeSerialized = NewSerializedType("SERIALIZED_TYPE") ) -func SerializeSlice(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - type_stack := []SerializedType{ctx_type} - var data []byte - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} - } else if value.Len() == 0 { - data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - } else { - data := make([]byte, 8) - binary.BigEndian.PutUint64(data, uint64(value.Len())) - - for i := 0; i < value.Len(); i += 1 { - val := value.Index(i) - element, err := SerializeValue(ctx, reflect_type.Elem(), &val) - if err != nil { - return SerializedValue{}, err - } - data = append(data, element.Data...) - } - - return SerializedValue{ - type_stack, - data, - }, nil - } - return SerializedValue{ - type_stack, - data, - }, nil -} - -func DeserializeSlice[T any](ctx *Context) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - var zero T - slice_type := reflect.TypeOf(zero) - zero_value, err := SerializeValue(ctx, slice_type.Elem(), nil) - if err != nil { - panic(err) - } - saved_type_stack := zero_value.TypeStack - - return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - if value.Data == nil { - return slice_type, nil, value, nil - } else { - var err error - var slice_size_bytes []byte - slice_size_bytes, value, err = value.PopData(8) - if err != nil { - return nil, nil, value, err - } - slice_size := binary.BigEndian.Uint64(slice_size_bytes) - slice_value := reflect.New(slice_type).Elem() - if slice_size != 0xFFFFFFFFFFFFFFFF { - slice_unaddr := reflect.MakeSlice(slice_type, int(slice_size), int(slice_size)) - slice_value.Set(slice_unaddr) - for i := uint64(0); i < slice_size; i += 1 { - var element_value *reflect.Value - var err error - tmp_value := SerializedValue{ - saved_type_stack, - value.Data, - } - - _, element_value, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - - value.Data = tmp_value.Data - slice_elem := slice_value.Index(int(i)) - slice_elem.Set(*element_value) - } - } - - return slice_type, &slice_value, value, nil - } - } -} - -func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - type_stack := []SerializedType{ctx_type} - if value == nil { - return SerializedValue{ - type_stack, - nil, - }, nil - } else { - var element SerializedValue - var err error - var data []byte - for i := 0; i < value.Len(); i += 1 { - val := value.Index(i) - element, err = SerializeValue(ctx, reflect_type.Elem(), &val) - if err != nil { - return SerializedValue{}, err - } - data = append(data, element.Data...) - } - return SerializedValue{ - type_stack, - data, - }, nil - } -} - -func DeserializeArray[T any](ctx *Context) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - var zero T - array_type := reflect.TypeOf(zero) - array_size := array_type.Len() - zero_value, err := SerializeValue(ctx, array_type.Elem(), nil) - if err != nil { - panic(err) - } - saved_type_stack := zero_value.TypeStack - return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - if value.Data == nil { - return array_type, nil, value, nil - } else { - array_value := reflect.New(array_type).Elem() - for i := 0; i < array_size; i += 1 { - var element_value *reflect.Value - var err error - tmp_value := SerializedValue{ - saved_type_stack, - value.Data, - } - _, element_value, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - value.Data = tmp_value.Data - array_elem := array_value.Index(i) - array_elem.Set(*element_value) - } - return array_type, &array_value, value, nil - } - } -} - -func SerializeUintN(size int) func(*Context, SerializedType, reflect.Type, *reflect.Value) (SerializedValue, error) { - var fill_data func([]byte, uint64) = nil - switch size { - case 1: - fill_data = func(data []byte, val uint64) { - data[0] = byte(val) - } - case 2: - fill_data = func(data []byte, val uint64) { - binary.BigEndian.PutUint16(data, uint16(val)) - } - case 4: - fill_data = func(data []byte, val uint64) { - binary.BigEndian.PutUint32(data, uint32(val)) - } - case 8: - fill_data = func(data []byte, val uint64) { - binary.BigEndian.PutUint64(data, val) - } - default: - panic(fmt.Sprintf("Cannot serialize uint of size %d", size)) - } - return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - var data []byte = nil - if value != nil { - data = make([]byte, size) - fill_data(data, value.Uint()) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - } -} - -func DeserializeUintN[T interface { - ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 -}](size int) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - var get_uint func([]byte) uint64 - switch size { - case 1: - get_uint = func(data []byte) uint64 { - return uint64(data[0]) - } - case 2: - get_uint = func(data []byte) uint64 { - return uint64(binary.BigEndian.Uint16(data)) - } - case 4: - get_uint = func(data []byte) uint64 { - return uint64(binary.BigEndian.Uint32(data)) - } - case 8: - get_uint = func(data []byte) uint64 { - return binary.BigEndian.Uint64(data) - } - default: - panic(fmt.Sprintf("Cannot deserialize int of size %d", size)) - } - var zero T - uint_type := reflect.TypeOf(zero) - return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - if value.Data == nil { - return uint_type, nil, value, nil - } else { - var uint_bytes []byte - var err error - uint_bytes, value, err = value.PopData(size) - if err != nil { - return nil, nil, value, err - } - uint_value := reflect.New(uint_type).Elem() - uint_value.SetUint(get_uint(uint_bytes)) - return uint_type, &uint_value, value, nil - } - } -} - -func SerializeIntN(size int) func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - var fill_data func([]byte, int64) = nil - switch size { - case 1: - fill_data = func(data []byte, val int64) { - data[0] = byte(val) - } - case 2: - fill_data = func(data []byte, val int64) { - binary.BigEndian.PutUint16(data, uint16(val)) - } - case 4: - fill_data = func(data []byte, val int64) { - binary.BigEndian.PutUint32(data, uint32(val)) - } - case 8: - fill_data = func(data []byte, val int64) { - binary.BigEndian.PutUint64(data, uint64(val)) - } - default: - panic(fmt.Sprintf("Cannot serialize int of size %d", size)) - } - return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - var data []byte = nil - if value != nil { - data = make([]byte, size) - fill_data(data, value.Int()) - } - return SerializedValue{ - []SerializedType{ctx_type}, - data, - }, nil - } -} - -func DeserializeIntN[T interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 -}](size int) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - var get_int func([]byte) int64 - switch size { - case 1: - get_int = func(data []byte) int64 { - return int64(data[0]) - } - case 2: - get_int = func(data []byte) int64 { - return int64(binary.BigEndian.Uint16(data)) - } - case 4: - get_int = func(data []byte) int64 { - return int64(binary.BigEndian.Uint32(data)) - } - case 8: - get_int = func(data []byte) int64 { - return int64(binary.BigEndian.Uint64(data)) - } - default: - panic(fmt.Sprintf("Cannot deserialize int of size %d", size)) - } - var zero T - int_type := reflect.TypeOf(zero) - return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - if value.Data == nil { - return int_type, nil, value, nil - } else { - var int_bytes []byte - var err error - int_bytes, value, err = value.PopData(size) - if err != nil { - return nil, nil, value, err - } - int_value := reflect.New(int_type).Elem() - int_value.SetInt(get_int(int_bytes)) - return int_type, &int_value, value, nil - } - } -} - type FieldInfo struct { Index []int TypeStack []SerializedType + Type reflect.Type } type StructInfo struct { @@ -477,7 +188,7 @@ type Deserializable interface { var deserializable_zero Deserializable = nil var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem() -func structInfo(ctx *Context, struct_type reflect.Type) StructInfo { +func GetStructInfo(ctx *Context, struct_type reflect.Type) (StructInfo, error) { field_order := []SerializedType{} field_map := map[SerializedType]FieldInfo{} for _, field := range reflect.VisibleFields(struct_type) { @@ -488,15 +199,16 @@ func structInfo(ctx *Context, struct_type reflect.Type) StructInfo { field_hash := Hash(FieldNameBase, gv_tag) _, exists := field_map[field_hash] if exists == true { - panic(fmt.Sprintf("gv tag %s is repeated", gv_tag)) + return StructInfo{}, fmt.Errorf("gv tag %s is repeated", gv_tag) } else { - field_serialized, err := SerializeValue(ctx, field.Type, nil) + field_type_stack, err := SerializeType(ctx, field.Type) if err != nil { - panic(err) + return StructInfo{}, err } field_map[field_hash] = FieldInfo{ field.Index, - field_serialized.TypeStack, + field_type_stack, + field.Type, } field_order = append(field_order, field_hash) } @@ -527,219 +239,719 @@ func structInfo(ctx *Context, struct_type reflect.Type) StructInfo { field_map, post_deserialize, post_deserialize_idx, - } + }, nil } -func SerializeStruct(ctx *Context, struct_type reflect.Type) func(*Context, SerializedType, reflect.Type, *reflect.Value) (SerializedValue, error) { - struct_info := structInfo(ctx, struct_type) - return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - type_stack := []SerializedType{ctx_type} - var data []byte - if value == nil { - data = nil - } else { - data = make([]byte, 8) - for _, field_hash := range struct_info.FieldOrder { - field_hash_bytes := make([]byte, 8) - binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) - field_info := struct_info.FieldMap[field_hash] - field_value := value.FieldByIndex(field_info.Index) - field_serialized, err := SerializeValue(ctx, field_value.Type(), &field_value) - if err != nil { - return SerializedValue{}, err - } - data = append(data, field_hash_bytes...) - data = append(data, field_serialized.Data...) +func SerializeStruct(info StructInfo)func(*Context, reflect.Value)([]byte, error) { + return func(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 8) + for _, field_hash := range(info.FieldOrder) { + field_hash_bytes := make([]byte, 8) + binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) + + field_info := info.FieldMap[field_hash] + field_value := value.FieldByIndex(field_info.Index) + + field_serialized, err := SerializeValue(ctx, field_value) + if err != nil { + return nil, err } - binary.BigEndian.PutUint64(data[0:8], uint64(len(struct_info.FieldOrder))) + + data = append(data, field_hash_bytes...) + data = append(data, field_serialized...) } - return SerializedValue{ - type_stack, - data, - }, nil + binary.BigEndian.PutUint64(data[0:8], uint64(len(info.FieldOrder))) + return data, nil } } -func DeserializeStruct(ctx *Context, struct_type reflect.Type) func(*Context, SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - struct_info := structInfo(ctx, struct_type) - return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - if value.Data == nil { - return struct_info.Type, nil, value, nil - } else { - var num_fields_bytes []byte +func DeserializeStruct(info StructInfo)func(*Context, reflect.Type, []byte)(reflect.Value, []byte, error) { + return func(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize struct %d/8", len(data)) + } + + num_field_bytes := data[:8] + data = data[8:] + + num_fields := binary.BigEndian.Uint64(num_field_bytes) + + struct_value := reflect.New(reflect_type).Elem() + for i := uint64(0); i < num_fields; i ++ { + field_hash_bytes := data[:8] + data = data[8:] + field_hash := SerializedType(binary.BigEndian.Uint64(field_hash_bytes)) + field_info, exists := info.FieldMap[field_hash] + if exists == false { + return reflect.Value{}, nil, fmt.Errorf("0x%x is not a field in %+v", field_hash, info.Type) + } + + var field_value reflect.Value var err error - num_fields_bytes, value, err = value.PopData(8) + field_value, data, err = DeserializeValue(ctx, field_info.Type, data) if err != nil { - return nil, nil, value, err + return reflect.Value{}, nil, err } - num_fields := int(binary.BigEndian.Uint64(num_fields_bytes)) - ctx.Log.Logf("serialize", "Deserializing %d fields from %+v", num_fields, struct_info) - struct_value := reflect.New(struct_info.Type).Elem() + field_reflect := struct_value.FieldByIndex(field_info.Index) + field_reflect.Set(field_value) + } - for i := 0; i < num_fields; i += 1 { - var field_hash_bytes []byte - field_hash_bytes, value, err = value.PopData(8) - if err != nil { - return nil, nil, value, err - } - field_hash := SerializedType(binary.BigEndian.Uint64(field_hash_bytes)) - field_info, exists := struct_info.FieldMap[field_hash] - if exists == false { - return nil, nil, value, fmt.Errorf("Field 0x%x is not valid for %+v: %+v", field_hash, struct_info.Type, struct_info.FieldMap) - } - field_value := struct_value.FieldByIndex(field_info.Index) + return struct_value, data, nil + } +} - tmp_value := SerializedValue{ - field_info.TypeStack, - value.Data, - } +func SerializeGob(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 8) + gob_ser, err := value.Interface().(gob.GobEncoder).GobEncode() + if err != nil { + return nil, err + } - var field_reflect *reflect.Value - _, field_reflect, tmp_value, err = DeserializeValue(ctx, tmp_value) - if err != nil { - return nil, nil, value, err - } - value.Data = tmp_value.Data - field_value.Set(*field_reflect) - } + binary.BigEndian.PutUint64(data, uint64(len(gob_ser))) + return append(data, gob_ser...), nil +} - if struct_info.PostDeserialize == true { - ctx.Log.Logf("serialize", "running post-deserialize for %+v", struct_info.Type) - post_deserialize_method := struct_value.Addr().Method(struct_info.PostDeserializeIdx) - ret := post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)}) - if ret[0].IsZero() == false { - return nil, nil, value, ret[0].Interface().(error) - } - } +func DeserializeGob[T any, PT interface{gob.GobDecoder; *T}](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("Not enough bytes to deserialize gob %d/8", len(data)) + } - return struct_info.Type, &struct_value, value, err - } + size_bytes := data[:8] + size := binary.BigEndian.Uint64(size_bytes) + gob_data := data[8:8+size] + data = data[8+size:] + + gob_ptr := reflect.New(reflect_type) + err := gob_ptr.Interface().(gob.GobDecoder).GobDecode(gob_data) + if err != nil { + return reflect.Value{}, nil, err } + + return gob_ptr.Elem(), data, nil } -func SerializeInterface(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { - var data []byte - type_stack := []SerializedType{ctx_type} - if value == nil { - data = nil - } else if value.IsZero() { - data = []byte{0x01} +func SerializeInt8(ctx *Context, value reflect.Value) ([]byte, error) { + data := []byte{byte(value.Int())} + + return data, nil +} + +func SerializeInt16(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(value.Int())) + + return data, nil +} + +func SerializeInt32(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 4) + binary.BigEndian.PutUint32(data, uint32(value.Int())) + + return data, nil +} + +func SerializeInt64(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, uint64(value.Int())) + + return data, nil +} + +func SerializeUint8(ctx *Context, value reflect.Value) ([]byte, error) { + data := []byte{byte(value.Uint())} + + return data, nil +} + +func SerializeUint16(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(value.Uint())) + + return data, nil +} + +func SerializeUint32(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 4) + binary.BigEndian.PutUint32(data, uint32(value.Uint())) + + return data, nil +} + +func SerializeUint64(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, value.Uint()) + + return data, nil +} + +func DeserializeUint64[T ~uint64 | ~int64](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + uint_size := 8 + if len(data) < uint_size { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) + } + + uint_bytes := data[:uint_size] + data = data[uint_size:] + uint_value := reflect.New(reflect_type).Elem() + + typed_value := T(binary.BigEndian.Uint64(uint_bytes)) + uint_value.Set(reflect.ValueOf(typed_value)) + + return uint_value, data, nil +} + +func DeserializeUint32[T ~uint32 | ~uint | ~int32 | ~int](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + uint_size := 4 + if len(data) < uint_size { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) + } + + uint_bytes := data[:uint_size] + data = data[uint_size:] + uint_value := reflect.New(reflect_type).Elem() + + typed_value := T(binary.BigEndian.Uint32(uint_bytes)) + uint_value.Set(reflect.ValueOf(typed_value)) + + return uint_value, data, nil +} + +func DeserializeUint16[T ~uint16 | ~int16](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + uint_size := 2 + if len(data) < uint_size { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) + } + + uint_bytes := data[:uint_size] + data = data[uint_size:] + uint_value := reflect.New(reflect_type).Elem() + + typed_value := T(binary.BigEndian.Uint16(uint_bytes)) + uint_value.Set(reflect.ValueOf(typed_value)) + + return uint_value, data, nil +} + +func DeserializeUint8[T ~uint8 | ~int8](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + uint_size := 1 + if len(data) < uint_size { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize uint %d/%d", len(data), uint_size) + } + + uint_bytes := data[:uint_size] + data = data[uint_size:] + uint_value := reflect.New(reflect_type).Elem() + + typed_value := T(uint_bytes[0]) + uint_value.Set(reflect.ValueOf(typed_value)) + + return uint_value, data, nil +} + +func SerializeFloat64(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 8) + float_representation := math.Float64bits(value.Float()) + binary.BigEndian.PutUint64(data, float_representation) + return data, nil +} + +func DeserializeFloat64[T ~float64](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize float64 %d/8", len(data)) + } + + float_bytes := data[0:8] + data = data[8:] + + float_representation := binary.BigEndian.Uint64(float_bytes) + float := math.Float64frombits(float_representation) + + float_value := reflect.New(reflect_type).Elem() + float_value.Set(reflect.ValueOf(T(float))) + + return float_value, data, nil +} + +func SerializeFloat32(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 4) + float_representation := math.Float32bits(float32(value.Float())) + binary.BigEndian.PutUint32(data, float_representation) + return data, nil +} + +func DeserializeFloat32[T ~float32](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 4 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize float32 %d/4", len(data)) + } + + float_bytes := data[0:4] + data = data[4:] + + float_representation := binary.BigEndian.Uint32(float_bytes) + float := math.Float32frombits(float_representation) + + float_value := reflect.New(reflect_type).Elem() + float_value.Set(reflect.ValueOf(T(float))) + + return float_value, data, nil +} + +func SerializeString(ctx *Context, value reflect.Value) ([]byte, error) { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, uint64(value.Len())) + + return append(data, []byte(value.String())...), nil +} + +func DeserializeString(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize string %d/8", len(data)) + } + + size_bytes := data[0:8] + data = data[8:] + + size := binary.BigEndian.Uint64(size_bytes) + if len(data) < int(size) { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize string of len %d, %d/%d", size, len(data), size) + } + + string_value := reflect.New(reflect_type).Elem() + string_value.Set(reflect.ValueOf(string(data[:size]))) + data = data[size:] + + return string_value, data, nil +} + +func SerializeBool(ctx *Context, value reflect.Value) ([]byte, error) { + if value.Bool() == true { + return []byte{0xFF}, nil + } else { + return []byte{0x00}, nil + } +} + +func DeserializeBool[T ~bool](ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 1 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize bool %d/1", len(data)) + } + byte := data[0] + data = data[1:] + + bool_value := reflect.New(reflect_type).Elem() + if byte == 0x00 { + bool_value.Set(reflect.ValueOf(T(false))) + } else { + bool_value.Set(reflect.ValueOf(T(true))) + } + + return bool_value, data, nil +} + +func DeserializeTypePointer(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { + elem_type, remaining, err := DeserializeType(ctx, type_stack) + if err != nil { + return nil, nil, err + } + + return reflect.PointerTo(elem_type), remaining, nil +} + +func SerializePointer(ctx *Context, value reflect.Value) ([]byte, error) { + if value.IsZero() { + return []byte{0x00}, nil } else { - data = []byte{0x00} - elem_value := value.Elem() - elem, err := SerializeValue(ctx, elem_value.Type(), &elem_value) + flags := []byte{0x01} + + elem_data, err := SerializeValue(ctx, value.Elem()) if err != nil { - return SerializedValue{}, err + return nil, err } - elem_data, err := elem.MarshalBinary() + + return append(flags, elem_data...), nil + } +} + +func DeserializePointer(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 1 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize pointer %d/1", len(data)) + } + + flags := data[0] + data = data[1:] + + pointer_value := reflect.New(reflect_type).Elem() + + if flags != 0x00 { + var element_value reflect.Value + var err error + element_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) if err != nil { - return SerializedValue{}, err + return reflect.Value{}, nil, err } - data = append(data, elem_data...) + + pointer_value.Set(element_value.Addr()) } - return SerializedValue{ - type_stack, - data, - }, nil + + return pointer_value, data, nil } -func DeserializeInterface[T any]() func(*Context, SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - var interface_zero T - var interface_type = reflect.ValueOf(&interface_zero).Type().Elem() - if value.Data == nil { - return interface_type, nil, value, nil - } else { - var flag_bytes []byte - var err error - flag_bytes, value, err = value.PopData(1) +func SerializeTypeStub(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { + return nil, nil +} + +func DeserializeTypeStub[T any](ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { + var zero T + return reflect.TypeOf(zero), type_stack, nil +} + +func SerializeTypeElem(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { + return SerializeType(ctx, reflect_type.Elem()) +} + +func SerializeSlice(ctx *Context, value reflect.Value) ([]byte, error) { + if value.IsZero() { + return []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, nil + } else if value.Len() == 0 { + return []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, nil + } else { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, uint64(value.Len())) + + for i := 0; i < value.Len(); i += 1 { + val := value.Index(i) + element, err := SerializeValue(ctx, val) if err != nil { - return nil, nil, value, err + return nil, err } + data = append(data, element...) + } - interface_value := reflect.New(interface_type).Elem() - nil_flag := flag_bytes[0] - if nil_flag == 0x01 { - } else if nil_flag == 0x00 { - var elem_value *reflect.Value - var elem_ser SerializedValue - elem_ser, value.Data, err = ParseSerializedValue(value.Data) - _, elem_value, _, err = DeserializeValue(ctx, elem_ser) - if err != nil { - return nil, nil, value, err - } - interface_value.Set(*elem_value) - } else { - return nil, nil, value, fmt.Errorf("Unknown interface nil_flag value 0x%x", nil_flag) + return data, nil + } +} + +func DeserializeTypeSlice(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { + elem_type, remaining, err := DeserializeType(ctx, type_stack) + if err != nil { + return nil, nil, err + } + + reflect_type := reflect.SliceOf(elem_type) + return reflect_type, remaining, nil +} + +func DeserializeSlice(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize slice %d/8", len(data)) + } + + slice_size := binary.BigEndian.Uint64(data[0:8]) + slice_value := reflect.New(reflect_type).Elem() + data = data[8:] + + if slice_size != 0xFFFFFFFFFFFFFFFF { + slice_unaddr := reflect.MakeSlice(reflect_type, int(slice_size), int(slice_size)) + slice_value.Set(slice_unaddr) + for i := uint64(0); i < slice_size; i += 1 { + var element_value reflect.Value + var err error + element_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) + if err != nil { + return reflect.Value{}, nil, err } - return interface_type, &interface_value, value, nil + + slice_elem := slice_value.Index(int(i)) + slice_elem.Set(element_value) } } + + return slice_value, data, nil } -type SerializedValue struct { - TypeStack []SerializedType - Data []byte +func SerializeTypeMap(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { + key_stack, err := SerializeType(ctx, reflect_type.Key()) + if err != nil { + return nil, err + } + + elem_stack, err := SerializeType(ctx, reflect_type.Elem()) + if err != nil { + return nil, err + } + + return append(key_stack, elem_stack...), nil } -func (value SerializedValue) PopType() (SerializedType, SerializedValue, error) { - if len(value.TypeStack) == 0 { - return SerializedType(0), value, fmt.Errorf("No elements in TypeStack") +func DeserializeTypeMap(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { + key_type, after_key, err := DeserializeType(ctx, type_stack) + if err != nil { + return nil, nil, err + } + + elem_type, after_elem, err := DeserializeType(ctx, after_key) + if err != nil { + return nil, nil, err } - ctx_type := value.TypeStack[0] - value.TypeStack = value.TypeStack[1:] - return ctx_type, value, nil + + map_type := reflect.MapOf(key_type, elem_type) + return map_type, after_elem, nil } -func (value SerializedValue) PopData(n int) ([]byte, SerializedValue, error) { - if len(value.Data) < n { - return nil, value, fmt.Errorf("Not enough data %d/%d", len(value.Data), n) +func SerializeMap(ctx *Context, value reflect.Value) ([]byte, error) { + if value.IsZero() == true { + return []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, nil } - data := value.Data[0:n] - value.Data = value.Data[n:] - return data, value, nil + map_data := []byte{} + map_size := uint64(0) + map_iter := value.MapRange() + for map_iter.Next() { + map_size = map_size + 1 + key := map_iter.Key() + val := map_iter.Value() + + key_data, err := SerializeValue(ctx, key) + if err != nil { + return nil, err + } + map_data = append(map_data, key_data...) + + val_data, err := SerializeValue(ctx, val) + if err != nil { + return nil, err + } + map_data = append(map_data, val_data...) + } + + size_data := make([]byte, 8) + binary.BigEndian.PutUint64(size_data, map_size) + + return append(size_data, map_data...), nil +} + +func DeserializeMap(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 8 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize map %d/8", len(data)) + } + + size_bytes := data[:8] + data = data[8:] + + size := binary.BigEndian.Uint64(size_bytes) + + map_value := reflect.New(reflect_type).Elem() + if size == 0xFFFFFFFFFFFFFFFF { + return map_value, data, nil + } + + map_unaddr := reflect.MakeMapWithSize(reflect_type, int(size)) + map_value.Set(map_unaddr) + + for i := uint64(0); i < size; i++ { + var err error + var key_value reflect.Value + key_value, data, err = DeserializeValue(ctx, reflect_type.Key(), data) + if err != nil { + return reflect.Value{}, nil, err + } + + var val_value reflect.Value + val_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) + if err != nil { + return reflect.Value{}, nil, err + } + + map_value.SetMapIndex(key_value, val_value) + } + + return map_value, data, nil +} + +func SerializeTypeArray(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { + size := SerializedType(reflect_type.Len()) + elem_stack, err := SerializeType(ctx, reflect_type.Elem()) + if err != nil { + return nil, err + } + + return append([]SerializedType{size}, elem_stack...), nil +} + +func SerializeArray(ctx *Context, value reflect.Value) ([]byte, error) { + data := []byte{} + for i := 0; i < value.Len(); i += 1 { + element := value.Index(i) + element_data, err := SerializeValue(ctx, element) + if err != nil { + return nil, err + } + data = append(data, element_data...) + } + + return data, nil +} + +func DeserializeTypeArray(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { + if len(type_stack) < 1 { + return nil, nil, fmt.Errorf("Not enough valued in type stack to deserialize array") + } + + size := int(type_stack[0]) + element_type, remaining, err := DeserializeType(ctx, type_stack[1:]) + if err != nil { + return nil, nil, err + } + + array_type := reflect.ArrayOf(size, element_type) + return array_type, remaining, nil +} + +func DeserializeArray(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + array_value := reflect.New(reflect_type).Elem() + for i := 0; i < array_value.Len(); i += 1 { + var element_value reflect.Value + var err error + element_value, data, err = DeserializeValue(ctx, reflect_type.Elem(), data) + if err != nil { + return reflect.Value{}, nil, err + } + + element := array_value.Index(i) + element.Set(element_value) + } + + return array_value, data, nil +} + +func SerializeInterface(ctx *Context, value reflect.Value) ([]byte, error) { + if value.IsZero() == true { + return []byte{0xFF}, nil + } + + type_stack, err := SerializeType(ctx, value.Elem().Type()) + if err != nil { + return nil, err + } + + data, err := SerializeValue(ctx, value.Elem()) + if err != nil { + return nil, err + } + + serialized_value, err := SerializedValue{type_stack, data}.MarshalBinary() + if err != nil { + return nil, err + } + + return append([]byte{0x00}, serialized_value...), nil +} + +func DeserializeInterface(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + if len(data) < 1 { + return reflect.Value{}, nil, fmt.Errorf("Not enough data to deserialize interface %d/1", len(data)) + } + + flags := data[0] + data = data[1:] + if flags == 0xFF { + return reflect.New(reflect_type).Elem(), data, nil + } + + serialized_value, remaining, err := ParseSerializedValue(data) + elem_type, types_remaining, err := DeserializeType(ctx, serialized_value.TypeStack) + if err != nil { + return reflect.Value{}, nil, err + } else if len(types_remaining) > 0 { + return reflect.Value{}, nil, fmt.Errorf("Types remaining in interface stack after deserializing") + } + + elem_value, data_remaining, err := DeserializeValue(ctx, elem_type, serialized_value.Data) + if err != nil { + return reflect.Value{}, nil, err + } else if len(data_remaining) > 0 { + return reflect.Value{}, nil, fmt.Errorf("Data remaining in interface data after deserializing") + } + + interface_value := reflect.New(reflect_type).Elem() + interface_value.Set(elem_value) + + return interface_value, remaining, nil +} + +type SerializedValue struct { + TypeStack []SerializedType + Data []byte } func SerializeAny[T any](ctx *Context, value T) (SerializedValue, error) { reflect_value := reflect.ValueOf(value) - return SerializeValue(ctx, reflect_value.Type(), &reflect_value) + type_stack, err := SerializeType(ctx, reflect_value.Type()) + if err != nil { + return SerializedValue{}, err + } + data, err := SerializeValue(ctx, reflect_value) + if err != nil { + return SerializedValue{}, err + } + return SerializedValue{type_stack, data}, nil } -func SerializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (SerializedValue, error) { - ctx.Log.Logf("serialize", "Serializing: %+v - %+v", t, value) - type_info, type_exists := ctx.TypeReflects[t] +func SerializeType(ctx *Context, reflect_type reflect.Type) ([]SerializedType, error) { + ctx.Log.Logf("serialize", "Serializing type %+v", reflect_type) + + type_info, type_exists := ctx.TypeReflects[reflect_type] + var serialize_type TypeSerializeFn = nil var ctx_type SerializedType - var ctx_name string - var serialize TypeSerialize = nil if type_exists == true { + serialize_type = type_info.TypeSerialize ctx_type = type_info.Type - ctx_name = type_info.Reflect.Name() - if type_info.Serialize != nil { - serialize = type_info.Serialize + } + + if serialize_type == nil { + kind_info, handled := ctx.Kinds[reflect_type.Kind()] + if handled == true { + if type_exists == false { + ctx_type = kind_info.Type + } + serialize_type = kind_info.TypeSerialize } } - kind := t.Kind() - kind_info, handled := ctx.Kinds[kind] - if handled == false { - return SerializedValue{}, fmt.Errorf("Don't know how to serialize kind %+v", kind) - } else if type_exists == false { - ctx_type = kind_info.Type - ctx_name = kind_info.Reflect.String() + type_stack := []SerializedType{ctx_type} + if serialize_type != nil { + extra_types, err := serialize_type(ctx, reflect_type) + if err != nil { + return nil, err + } + return append(type_stack, extra_types...), nil + } else { + return type_stack, nil } +} - if serialize == nil { - serialize = kind_info.Serialize +func SerializeValue(ctx *Context, value reflect.Value) ([]byte, error) { + ctx.Log.Logf("serialize", "Serializing value %+v", value) + + type_info, type_exists := ctx.TypeReflects[value.Type()] + var serialize SerializeFn = nil + if type_exists == true { + if type_info.Serialize != nil { + serialize = type_info.Serialize + } } - serialized_value, err := serialize(ctx, ctx_type, t, value) - if err != nil { - return serialized_value, err + if serialize == nil { + kind_info, handled := ctx.Kinds[value.Kind()] + if handled { + serialize = kind_info.Serialize + } else { + return nil, fmt.Errorf("Don't know how to serialize %+v", value.Type()) + } } - ctx.Log.Logf("serialize", "Serialized %+v: %+v", ctx_name, serialized_value) - return serialized_value, err + + return serialize(ctx, value) } func ExtField(ctx *Context, ext Extension, field_name string) (reflect.Value, error) { @@ -763,8 +975,15 @@ func SerializeField(ctx *Context, ext Extension, field_name string) (SerializedV if err != nil { return SerializedValue{}, err } - - return SerializeValue(ctx, field_value.Type(), &field_value) + type_stack, err := SerializeType(ctx, field_value.Type()) + if err != nil { + return SerializedValue{}, err + } + data, err := SerializeValue(ctx, field_value) + if err != nil { + return SerializedValue{}, err + } + return SerializedValue{type_stack, data}, nil } func (value SerializedValue) MarshalBinary() ([]byte, error) { @@ -806,47 +1025,52 @@ func ParseSerializedValue(data []byte) (SerializedValue, []byte, error) { }, data[data_end:], nil } -func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { +func DeserializeValue(ctx *Context, reflect_type reflect.Type, data []byte) (reflect.Value, []byte, error) { + ctx.Log.Logf("serialize", "Deserializing %+v with %d bytes", reflect_type, len(data)) + var deserialize DeserializeFn = nil - var deserialize TypeDeserialize = nil + type_info, type_exists := ctx.TypeReflects[reflect_type] + if type_exists == true { + deserialize = type_info.Deserialize + } else { + kind_info, exists := ctx.Kinds[reflect_type.Kind()] + if exists == false { + return reflect.Value{}, nil, fmt.Errorf("Cannot deserialize %+v/%+v: unknown type/kind", reflect_type, reflect_type.Kind()) + } + deserialize = kind_info.Deserialize + } + + return deserialize(ctx, reflect_type, data) +} + +func DeserializeType(ctx *Context, type_stack []SerializedType) (reflect.Type, []SerializedType, error) { + ctx.Log.Logf("deserialize_types", "Deserializing type stack %+v", type_stack) + var deserialize_type TypeDeserializeFn = nil var reflect_type reflect.Type = nil - var reflect_value *reflect.Value = nil - ctx_type, value, err := value.PopType() - if err != nil { - return nil, nil, value, err + if len(type_stack) < 1 { + return nil, nil, fmt.Errorf("No elements in type stack to deserialize(DeserializeType)") } - var ctx_name string + ctx_type := type_stack[0] + type_stack = type_stack[1:] type_info, type_exists := ctx.Types[SerializedType(ctx_type)] if type_exists == true { - deserialize = type_info.Deserialize - ctx_name = type_info.Reflect.Name() + deserialize_type = type_info.TypeDeserialize + reflect_type = type_info.Reflect } else { kind_info, exists := ctx.KindTypes[SerializedType(ctx_type)] if exists == false { - return nil, nil, value, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) + return nil, nil, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) } - deserialize = kind_info.Deserialize - ctx_name = kind_info.Reflect.String() - } - - ctx.Log.Logf("serialize", "Deserializing: %+v(%+v) - %+v", ctx_name, ctx_type, value.TypeStack) - - if value.Data == nil { - reflect_type, _, value, err = deserialize(ctx, value) - } else { - reflect_type, reflect_value, value, err = deserialize(ctx, value) - } - if err != nil { - return nil, nil, value, err + deserialize_type = kind_info.TypeDeserialize + reflect_type = kind_info.Base } - if reflect_value != nil { - ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value.Interface(), err) + if deserialize_type == nil { + return reflect_type, type_stack, nil } else { - ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value, err) + return deserialize_type(ctx, type_stack) } - return reflect_type, reflect_value, value, nil } diff --git a/serialize_test.go b/serialize_test.go index 4cd3e30..a78d744 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -8,14 +8,26 @@ import ( ) func TestSerializeTest(t *testing.T) { - ctx := logTestContext(t, []string{"test", "serialize"}) + ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"}) testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}}) + testSerialize(t, ctx, map[NodeID]ReqInfo{ + RandID(): {}, + RandID(): {}, + RandID(): {}, + }) } func TestSerializeBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test"}) - testSerializeComparable[string](t, ctx, "test") + ctx := logTestContext(t, []string{"test", "serialize"}) testSerializeComparable[bool](t, ctx, true) + + type bool_wrapped bool + err := ctx.RegisterType(reflect.TypeOf(bool_wrapped(true)), NewSerializedType("BOOL_WRAPPED"), nil, nil, nil, DeserializeBool[bool_wrapped]) + fatalErr(t, err) + testSerializeComparable[bool_wrapped](t, ctx, true) + + testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false}) + testSerializeComparable[string](t, ctx, "test") testSerializeComparable[float32](t, ctx, 0.05) testSerializeComparable[float64](t, ctx, 0.05) testSerializeComparable[uint](t, ctx, uint(1234)) @@ -36,7 +48,19 @@ func TestSerializeBasic(t *testing.T) { testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil}) testSerialize(t, ctx, map[int8]map[*int8]string{}) + testSerialize(t, ctx, map[int8]time.Time{ + 1: time.Now(), + 3: time.Now().Add(time.Second), + 0: time.Now().Add(time.Second*2), + 4: time.Now().Add(time.Second*3), + }) + testSerialize(t, ctx, Tree{ + NodeTypeSerialized: nil, + SerializedTypeSerialized: Tree{ + NodeTypeSerialized: Tree{}, + }, + }) var i interface{} = nil testSerialize(t, ctx, i) @@ -61,7 +85,10 @@ func TestSerializeBasic(t *testing.T) { } test_struct_type := reflect.TypeOf(test_struct{}) - err := ctx.RegisterType(test_struct_type, NewSerializedType("TEST_STRUCT"), SerializeStruct(ctx, test_struct_type), DeserializeStruct(ctx, test_struct_type)) + test_struct_info, err := GetStructInfo(ctx, test_struct_type) + fatalErr(t, err) + + err = ctx.RegisterType(test_struct_type, NewSerializedType("TEST_STRUCT"), nil, SerializeStruct(test_struct_info), nil, DeserializeStruct(test_struct_info)) fatalErr(t, err) testSerialize(t, ctx, test_struct{ @@ -70,7 +97,6 @@ func TestSerializeBasic(t *testing.T) { }) testSerialize(t, ctx, Tree{ - ErrorType: nil, MapType: nil, StringType: nil, }) @@ -89,9 +115,10 @@ func TestSerializeBasic(t *testing.T) { type test_slice []string test_slice_type := reflect.TypeOf(test_slice{}) - err = ctx.RegisterType(test_slice_type, NewSerializedType("TEST_SLICE"), SerializeSlice, DeserializeSlice[test_slice](ctx)) + err = ctx.RegisterType(test_slice_type, NewSerializedType("TEST_SLICE"), SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice) fatalErr(t, err) + testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"}) testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"}) testSerialize[Changes](t, ctx, Changes{"change_1", "change_2", "change_3"}) @@ -112,7 +139,9 @@ func TestSerializeStructTags(t *testing.T) { test_type := NewSerializedType("TEST_STRUCT") test_struct_type := reflect.TypeOf(test{}) ctx.Log.Logf("test", "TEST_TYPE: %+v", test_type) - ctx.RegisterType(test_struct_type, test_type, SerializeStruct(ctx, test_struct_type), DeserializeStruct(ctx, test_struct_type)) + test_struct_info, err := GetStructInfo(ctx, test_struct_type) + fatalErr(t, err) + ctx.RegisterType(test_struct_type, test_type, nil, SerializeStruct(test_struct_info), nil, DeserializeStruct(test_struct_info)) test_int := 10 test_string := "test" @@ -190,7 +219,9 @@ func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) { func testSerialize[T any](t *testing.T, ctx *Context, val T) T { value := reflect.ValueOf(&val).Elem() - value_serialized, err := SerializeValue(ctx, value.Type(), &value) + type_stack, err := SerializeType(ctx, value.Type()) + data, err := SerializeValue(ctx, value) + value_serialized := SerializedValue{type_stack, data} fatalErr(t, err) ctx.Log.Logf("test", "Serialized %+v to %+v", val, value_serialized) @@ -206,19 +237,16 @@ func testSerialize[T any](t *testing.T, ctx *Context, val T) T { t.Fatal("Data remaining after deserializing value") } - val_type, deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_parsed) + val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack) + deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data) fatalErr(t, err) - if len(remaining_deserialize.Data) != 0 { + if len(remaining_deserialize) != 0 { t.Fatal("Data remaining after deserializing value") - } else if len(remaining_deserialize.TypeStack) != 0 { + } else if len(remaining_types) != 0 { t.Fatal("TypeStack remaining after deserializing value") } else if val_type != value.Type() { t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val))) - } else if deserialized_value == nil { - t.Fatal("DeserializeValue returned no []reflect.Value") - } else if deserialized_value == nil { - t.Fatal("DeserializeValue returned nil *reflect.Value") } else if deserialized_value.CanConvert(val_type) == false { t.Fatal("DeserializeValue returned value that can't convert to original value") }