diff --git a/context.go b/context.go index c050445..454ea29 100644 --- a/context.go +++ b/context.go @@ -28,12 +28,14 @@ type NodeInfo struct { } type TypeInfo struct { - Type reflect.Type + Reflect reflect.Type + Type SerializedType Serialize TypeSerialize Deserialize TypeDeserialize } type KindInfo struct { + Reflect reflect.Kind Type SerializedType Serialize TypeSerialize Deserialize TypeDeserialize @@ -57,11 +59,11 @@ type Context struct { // Map between database type hashes and the registered info Nodes map[NodeType]NodeInfo // Map between go types and registered info - Types map[SerializedType]TypeInfo - TypeReflects map[reflect.Type]SerializedType + Types map[SerializedType]*TypeInfo + TypeReflects map[reflect.Type]*TypeInfo - Kinds map[reflect.Kind]KindInfo - KindTypes map[SerializedType]reflect.Kind + Kinds map[reflect.Kind]*KindInfo + KindTypes map[SerializedType]*KindInfo // Routing map to all the nodes local to this context nodeMapLock sync.RWMutex @@ -150,12 +152,14 @@ func (ctx *Context)RegisterKind(kind reflect.Kind, ctx_type SerializedType, seri return fmt.Errorf("Cannot register field without serialize function") } - ctx.Kinds[kind] = KindInfo{ + info := KindInfo{ + kind, ctx_type, serialize, deserialize, } - ctx.KindTypes[ctx_type] = kind + ctx.KindTypes[ctx_type] = &info + ctx.Kinds[kind] = &info return nil } @@ -169,19 +173,15 @@ func (ctx *Context)RegisterType(reflect_type reflect.Type, ctx_type SerializedTy if exists == true { return fmt.Errorf("Cannot register field with type %+v, type already registered in context", reflect_type) } - if deserialize == nil { - return fmt.Errorf("Cannot register field without deserialize function") - } - if serialize == nil { - return fmt.Errorf("Cannot register field without serialize function") - } - ctx.Types[ctx_type] = TypeInfo{ - Type: reflect_type, + type_info := TypeInfo{ + Reflect: reflect_type, + Type: ctx_type, Serialize: serialize, Deserialize: deserialize, } - ctx.TypeReflects[reflect_type] = ctx_type + ctx.Types[ctx_type] = &type_info + ctx.TypeReflects[reflect_type] = &type_info return nil } @@ -253,10 +253,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { SignalTypes: map[reflect.Type]SignalType{}, Nodes: map[NodeType]NodeInfo{}, nodeMap: map[NodeID]*Node{}, - Types: map[SerializedType]TypeInfo{}, - TypeReflects: map[reflect.Type]SerializedType{}, - Kinds: map[reflect.Kind]KindInfo{}, - KindTypes: map[SerializedType]reflect.Kind{}, + Types: map[SerializedType]*TypeInfo{}, + TypeReflects: map[reflect.Type]*TypeInfo{}, + Kinds: map[reflect.Kind]*KindInfo{}, + KindTypes: map[SerializedType]*KindInfo{}, } var err error @@ -299,12 +299,14 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { pointer_flags := value.Data[0] value.Data = value.Data[1:] if pointer_flags == 0x00 { - _, elem_value, remaining_data, err := DeserializeValue(ctx, value) + elem_type, elem_value, remaining_data, err := DeserializeValue(ctx, value) if err != nil { return nil, nil, SerializedValue{}, err } - pointer_value := elem_value.Addr() - return pointer_value.Type(), &pointer_value, remaining_data, nil + pointer_type := reflect.PointerTo(elem_type) + pointer_value := reflect.New(pointer_type).Elem() + pointer_value.Set(elem_value.Addr()) + return pointer_type, &pointer_value, remaining_data, nil } else if pointer_flags == 0x01 { elem_type, _, remaining_data, err := DeserializeValue(ctx, value) if err != nil { @@ -323,13 +325,18 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + // TODO: figure out why this doesn't break in the simple test, but breaks in TestGQLDB err = ctx.RegisterKind(reflect.Struct, StructType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ serialized_value := SerializedValue{ []SerializedType{ctx_type}, nil, } - field_values := map[SerializedType]SerializedValue{} + if value != nil { + serialized_value.Data = make([]byte, 8) + } + + num_fields := uint64(0) for _, field := range(reflect.VisibleFields(reflect_type)) { gv_tag, tagged_gv := field.Tag.Lookup("gv") if tagged_gv == false { @@ -337,28 +344,72 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } else if gv_tag == "" { continue } else { - // Add to the type stack and data stack + num_fields += 1 field_hash := Hash(FieldNameBase, gv_tag) + field_hash_bytes := make([]byte, 8) + binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) if value == nil { field_ser, err := SerializeValue(ctx, field.Type, nil) if err != nil { return SerializedValue{}, err } - field_values[field_hash] = field_ser + serialized_value.TypeStack = append(serialized_value.TypeStack, field_ser.TypeStack...) } else { field_value := value.FieldByIndex(field.Index) field_ser, err := SerializeValue(ctx, field.Type, &field_value) if err != nil { return SerializedValue{}, err } - field_values[field_hash] = field_ser + serialized_value.TypeStack = append(serialized_value.TypeStack, field_ser.TypeStack...) + serialized_value.Data = append(serialized_value.Data, field_hash_bytes...) + serialized_value.Data = append(serialized_value.Data, field_ser.Data...) } } } + if value != nil { + binary.BigEndian.PutUint64(serialized_value.Data[0:8], num_fields) + } return serialized_value, nil }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - return nil, nil, SerializedValue{}, fmt.Errorf("deserialize struct not implemented") + if value.Data == nil { + return reflect.TypeOf(map[uint64]reflect.Value{}), nil, value, nil + } else { + var num_fields_data []byte + var err error + num_fields_data, value, err = value.PopData(8) + if err != nil { + return nil, nil, value, err + } + num_fields := int(binary.BigEndian.Uint64(num_fields_data)) + + map_type := reflect.TypeOf(map[uint64]reflect.Value{}) + map_ptr := reflect.New(map_type) + map_ptr.Elem().Set(reflect.MakeMap(map_type)) + map_value := map_ptr.Elem() + if num_fields == 0 { + return map_type, &map_value, value, nil + } else { + tmp_value := value + for i := 0; i < num_fields; i += 1 { + var field_hash_bytes []byte + field_hash_bytes, tmp_value, err = tmp_value.PopData(8) + if err != nil { + return nil, nil, value, err + } + field_hash := binary.BigEndian.Uint64(field_hash_bytes) + field_hash_value := reflect.ValueOf(field_hash) + + var elem_value *reflect.Value + _, elem_value, tmp_value, err = DeserializeValue(ctx, tmp_value) + if err != nil { + return nil, nil, value, err + } + map_value.SetMapIndex(field_hash_value, reflect.ValueOf(*elem_value)) + } + return map_type, &map_value, tmp_value, nil + } + } }) if err != nil { return nil, err @@ -603,26 +654,43 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { err = ctx.RegisterKind(reflect.Interface, InterfaceType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue, error){ var data []byte - type_stack := []SerializedType{} + type_stack := []SerializedType{ctx_type} if value == nil { data = nil } else if value.IsZero() { - return SerializedValue{}, fmt.Errorf("Cannot serialize nil interfaces") } else { elem_value := value.Elem() - elem, err := SerializeValue(ctx, value.Elem().Type(), &elem_value) + elem, err := SerializeValue(ctx, elem_value.Type(), &elem_value) + if err != nil { + return SerializedValue{}, err + } + data, err = elem.MarshalBinary() if err != nil { return SerializedValue{}, err } - data = elem.Data - type_stack = elem.TypeStack } return SerializedValue{ - append([]SerializedType{ctx_type}, type_stack...), + type_stack, data, }, nil }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - return nil, nil, SerializedValue{}, fmt.Errorf("deserialize interface unimplemented") + if value.Data == nil { + return reflect.TypeOf((interface{})(nil)), nil, value, nil + } else { + var elem_value *reflect.Value + var elem_ser SerializedValue + var elem_type reflect.Type + var err error + elem_ser, value.Data, err = ParseSerializedValue(value.Data) + elem_type, elem_value, _, err = DeserializeValue(ctx, elem_ser) + if err != nil { + return nil, nil, value, err + } + ptr_type := reflect.PointerTo(elem_type) + ptr_value := reflect.New(ptr_type).Elem() + ptr_value.Set(elem_value.Addr()) + return ptr_type, &ptr_value, value, nil + } }) if err != nil { return nil, err @@ -764,7 +832,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { reflect_value := reflect.MakeMap(reflect_type) return reflect_type, &reflect_value, new_value, nil } else { - // TODO: basically copy above except instead of getting the key/elem type once, get key/elem values for map_size tmp_value := value var map_value reflect.Value var map_type reflect.Type = nil @@ -1128,26 +1195,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(StringError("")), ErrorType, - func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { - if value == nil { - return SerializedValue{ - []SerializedType{ctx_type}, - nil, - }, nil - } - - data := make([]byte, 8) - err := value.Interface().(StringError) - str := string(err) - binary.BigEndian.PutUint64(data, uint64(len(str))) - return SerializedValue{ - []SerializedType{SerializedType(ctx_type)}, - append(data, []byte(str)...), - }, nil - }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - return nil, nil, SerializedValue{}, fmt.Errorf("unimplemented") - }) + err = ctx.RegisterType(reflect.TypeOf(StringError("")), ErrorType, nil, nil) + if err != nil { + return nil, err + } err = ctx.RegisterType(reflect.TypeOf(RandID()), NodeIDType, func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { @@ -1164,7 +1215,24 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { id_ser, }, nil }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ - return nil, nil, SerializedValue{}, fmt.Errorf("unimplemented") + if value.Data == nil { + return reflect.TypeOf(ZeroID), nil, value, nil + } else { + var err error + var id_bytes []byte + id_bytes, value, err = value.PopData(16) + if err != nil { + return nil, nil, value, err + } + + id, err := IDFromBytes(id_bytes) + if err != nil { + return nil, nil, value, err + } + + id_value := reflect.ValueOf(id) + return id_value.Type(), &id_value, value, nil + } }) if err != nil { return nil, err diff --git a/gql_test.go b/gql_test.go index 29c0b0c..97ff334 100644 --- a/gql_test.go +++ b/gql_test.go @@ -210,7 +210,7 @@ func TestGQLServer(t *testing.T) { } func TestGQLDB(t *testing.T) { - ctx := logTestContext(t, []string{"test", "signal", "node"}) + ctx := logTestContext(t, []string{"test", "serialize", "node"}) TestUserNodeType := NewNodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) @@ -239,13 +239,6 @@ func TestGQLDB(t *testing.T) { }) fatalErr(t, err) - ser1, err := SerializeAny(ctx, gql) - fatalErr(t, err) - ctx.Log.Logf("test", "SER_1: \n%+v\n\n", ser1) - ser2, err := SerializeAny(ctx, u1) - fatalErr(t, err) - ctx.Log.Logf("test", "SER_2: \n%+v\n\n", ser2) - // Clear all loaded nodes from the context so it loads them from the database ctx.nodeMap = map[NodeID]*Node{} gql_loaded, err := LoadNode(ctx, gql.ID) diff --git a/node.go b/node.go index b1c2840..4439cdd 100644 --- a/node.go +++ b/node.go @@ -88,25 +88,25 @@ type PendingSignal struct { // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { - Key ed25519.PrivateKey `gv:"0"` + Key ed25519.PrivateKey `gv:""` ID NodeID - Type NodeType `gv:"1"` - Extensions map[ExtType]Extension `gv:"3"` - Policies map[PolicyType]Policy `gv:"4"` + Type NodeType `gv:""` + Extensions map[ExtType]Extension `gv:"extensions"` + Policies map[PolicyType]Policy `gv:""` - PendingACLs map[uuid.UUID]PendingACL `gv:"6"` - PendingSignals map[uuid.UUID]PendingSignal `gv:"7"` + PendingACLs map[uuid.UUID]PendingACL `gv:""` + PendingSignals map[uuid.UUID]PendingSignal `gv:""` // Channel for this node to receive messages from the Context MsgChan chan *Message // Size of MsgChan - BufferSize uint32 `gv:"2"` + BufferSize uint32 `gv:""` // Channel for this node to process delayed signals TimeoutChan <-chan time.Time Active atomic.Bool - SignalQueue []QueuedSignal `gv:"5"` + SignalQueue []QueuedSignal `gv:""` NextSignal *QueuedSignal } @@ -387,26 +387,31 @@ func nodeLoop(ctx *Context, node *Node) error { switch sig := signal.(type) { case *StopSignal: + node.Process(ctx, source, signal) + err := WriteNode(ctx, node) + if err != nil { + panic(err) + } + msgs := Messages{} msgs = msgs.Add(ctx, node.ID, node.Key, NewStatusSignal(node.ID, "stopped"), source) ctx.Send(msgs) node.Process(ctx, node.ID, NewStatusSignal(node.ID, "stopped")) run = false + case *ReadSignal: result := node.ReadFields(ctx, sig.Extensions) msgs := Messages{} msgs = msgs.Add(ctx, node.ID, node.Key, NewReadResultSignal(sig.ID, node.ID, node.Type, result), source) msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ID, "read_done"), source) ctx.Send(msgs) - } - node.Process(ctx, source, signal) - // assume that processing a signal means that this nodes state changed - // TODO: remove a lot of database writes by only writing when things change, - // so need to have Process return whether or not state changed - err := WriteNode(ctx, node) - if err != nil { - panic(err) + default: + node.Process(ctx, source, signal) + err := WriteNode(ctx, node) + if err != nil { + panic(err) + } } } @@ -679,7 +684,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { node, ok := node_val.Interface().(*Node) if ok == false { - return nil, fmt.Errorf("Deserialized %+v when expecting *Node", reflect.TypeOf(node_val).Elem()) + return nil, fmt.Errorf("Deserialized %+v when expecting *Node", node_val.Type()) } ctx.AddNode(id, node) diff --git a/serialize.go b/serialize.go index 87142b3..c76cfa3 100644 --- a/serialize.go +++ b/serialize.go @@ -136,10 +136,13 @@ func SerializeAny[T any](ctx *Context, value T) (SerializedValue, error) { func SerializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (SerializedValue, error) { ctx.Log.Logf("serialize", "Serializing: %+v - %+v", t, value) - ctx_type, type_exists := ctx.TypeReflects[t] + type_info, type_exists := ctx.TypeReflects[t] + var ctx_type SerializedType + var ctx_name string var serialize TypeSerialize = nil if type_exists == true { - type_info := ctx.Types[ctx_type] + ctx_type = type_info.Type + ctx_name = type_info.Reflect.Name() if type_info.Serialize != nil { serialize = type_info.Serialize } @@ -151,6 +154,7 @@ func SerializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali return SerializedValue{}, fmt.Errorf("Don't know how to serialize kind %+v", kind) } else if type_exists == false { ctx_type = kind_info.Type + ctx_name = kind_info.Reflect.String() } if serialize == nil { @@ -158,7 +162,10 @@ func SerializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali } serialized_value, err := serialize(ctx, ctx_type, t, value) - ctx.Log.Logf("serialize", "Serialized %+v: %+v - %+v", value, serialized_value, err) + if err != nil { + return serialized_value, err + } + ctx.Log.Logf("serialize", "Serialized %+v: %+v", ctx_name, serialized_value) return serialized_value, err } @@ -215,7 +222,6 @@ func ParseSerializedValue(data []byte) (SerializedValue, []byte, error) { } func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { - ctx.Log.Logf("serialize", "Deserializing: %+v", value) var deserialize TypeDeserialize = nil var reflect_type reflect.Type = nil @@ -226,19 +232,23 @@ func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *refle return nil, nil, value, err } - type_info, exists := ctx.Types[SerializedType(ctx_type)] - if exists == true { + var ctx_name string + + type_info, type_exists := ctx.Types[SerializedType(ctx_type)] + if type_exists == true { deserialize = type_info.Deserialize - reflect_type = type_info.Type + ctx_name = type_info.Reflect.Name() } else { - kind, exists := ctx.KindTypes[SerializedType(ctx_type)] + kind_info, exists := ctx.KindTypes[SerializedType(ctx_type)] if exists == false { return nil, nil, value, fmt.Errorf("Cannot deserialize 0x%x: unknown type/kind", ctx_type) } - kind_info := ctx.Kinds[kind] deserialize = kind_info.Deserialize + ctx_name = kind_info.Reflect.String() } + ctx.Log.Logf("serialize", "Deserializing: %+v(0x%d) - %+v", ctx_name, ctx_type, value.TypeStack) + if value.Data == nil { reflect_type, _, value, err = deserialize(ctx, value) } else { @@ -248,6 +258,6 @@ func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *refle return nil, nil, value, err } - ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v - remaining %+v", reflect_type, reflect_value, err, value) + ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value, err) return reflect_type, reflect_value, value, nil } diff --git a/serialize_test.go b/serialize_test.go index 8291763..f49ee57 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -36,7 +36,7 @@ func TestSerializeBasic(t *testing.T) { 6: 1121, }) - testSerialize(t, ctx, struct{ + testSerializeStruct(t, ctx, struct{ int `gv:"0"` string `gv:"1"` }{ @@ -90,6 +90,64 @@ func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) { } } +func testSerializeStruct[T any](t *testing.T, ctx *Context, val T) { + value, err := SerializeAny(ctx, val) + fatalErr(t, err) + ctx.Log.Logf("test", "Serialized %+v to %+v", val, value) + + ser, err := value.MarshalBinary() + fatalErr(t, err) + ctx.Log.Logf("test", "Binary: %+v", ser) + + val_parsed, remaining_parse, err := ParseSerializedValue(ser) + fatalErr(t, err) + ctx.Log.Logf("test", "Parsed: %+v", val_parsed) + + if len(remaining_parse) != 0 { + t.Fatal("Data remaining after deserializing value") + } + + val_type, deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_parsed) + fatalErr(t, err) + + if len(remaining_deserialize.Data) != 0 { + t.Fatal("Data remaining after deserializing value") + } else if len(remaining_deserialize.TypeStack) != 0 { + t.Fatal("TypeStack remaining after deserializing value") + } else if val_type != reflect.TypeOf(map[uint64]reflect.Value{}) { + t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - map[uint64]reflect.Value", val_type)) + } else if deserialized_value == nil { + t.Fatal("DeserializeValue returned no []reflect.Value") + } else if deserialized_value == nil { + t.Fatal("DeserializeValue returned nil *reflect.Value") + } else if deserialized_value.CanConvert(reflect.TypeOf(map[uint64]reflect.Value{})) == false { + t.Fatal("DeserializeValue returned value that can't convert to map[uint64]reflect.Value") + } + + reflect_value := reflect.ValueOf(val) + deserialized_map := deserialized_value.Interface().(map[uint64]reflect.Value) + + for _, field := range(reflect.VisibleFields(reflect_value.Type())) { + gv_tag, tagged_gv := field.Tag.Lookup("gv") + if tagged_gv == false { + continue + } else if gv_tag == "" { + continue + } else { + field_hash := uint64(Hash(FieldNameBase, gv_tag)) + deserialized_field, exists := deserialized_map[field_hash] + if exists == false { + t.Fatal(fmt.Sprintf("field %s is not in deserialized struct", field.Name)) + } + field_value := reflect_value.FieldByIndex(field.Index) + if field_value.Type() != deserialized_field.Type() { + t.Fatal(fmt.Sprintf("Type of %s does not match", field.Name)) + } + ctx.Log.Logf("test", "Field %s matched", field.Name) + } + } +} + func testSerialize[T any](t *testing.T, ctx *Context, val T) T { value, err := SerializeAny(ctx, val) fatalErr(t, err)