From 07ce0053650dc7505dc90bb6df8eab8f37f67030 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Tue, 12 Sep 2023 19:00:48 -0600 Subject: [PATCH] Got serialization to the point that TestGQLDB is passing --- context.go | 49 ++++++++++++++++++++++++++++++++++++++++---- gql_test.go | 4 ++-- listener.go | 7 ++++++- node.go | 17 ++++++++++++++- serialize.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 122 insertions(+), 13 deletions(-) diff --git a/context.go b/context.go index 7899d5b..5b56732 100644 --- a/context.go +++ b/context.go @@ -9,6 +9,7 @@ import ( "reflect" "runtime" "sync" + "github.com/google/uuid" badger "github.com/dgraph-io/badger/v3" ) @@ -127,6 +128,8 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type) } + ctx.Log.Logf("serialize", "Registered ExtType: %+v - %+v", reflect_type, ext_type) + ctx.Extensions[ext_type] = ExtensionInfo{ Type: reflect_type, Data: data, @@ -703,6 +706,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { binary.BigEndian.PutUint64(data[0:8], uint64(map_size)) + ctx.Log.Logf("serialize", "MAP_TYPES: %+v - %+v", key_types, elem_types) + type_stack = append(type_stack, key_types...) type_stack = append(type_stack, elem_types...) return SerializedValue{ @@ -751,6 +756,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } map_size := binary.BigEndian.Uint64(map_size_bytes) + ctx.Log.Logf("serialize", "Deserializing %d elements in map", map_size) if map_size == 0xFFFFFFFFFFFFFFFF { var key_type, elem_type reflect.Type @@ -990,12 +996,17 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } */ - err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeType, SerializeUintN(4), DeserializeUintN[ExtType](4)) + err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeSerialized, SerializeUintN(8), DeserializeUintN[ExtType](8)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeSerialized, SerializeUintN(8), DeserializeUintN[NodeType](8)) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeType, SerializeUintN(4), DeserializeUintN[NodeType](4)) + err = ctx.RegisterType(reflect.TypeOf(PolicyType(0)), PolicyTypeSerialized, SerializeUintN(8), DeserializeUintN[PolicyType](8)) if err != nil { return nil, err } @@ -1005,9 +1016,30 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterType(reflect.TypeOf(uuid.New()), UUIDType, SerializeArray, DeserializeArray[uuid.UUID](ctx)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(PendingACL{}), PendingACLType, SerializeStruct[PendingACL](ctx), DeserializeStruct[PendingACL](ctx)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(PendingSignal{}), PendingSignalType, SerializeStruct[PendingSignal](ctx), DeserializeStruct[PendingSignal](ctx)) + if err != nil { + return nil, err + } + // TODO: Make registering interfaces cleaner var extension Extension = nil - err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtensionType, SerializeInterface, DeserializeInterface[Extension]()) + err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtSerialized, SerializeInterface, DeserializeInterface[Extension]()) + if err != nil { + return nil, err + } + + var policy Policy = nil + err = ctx.RegisterType(reflect.ValueOf(&policy).Type().Elem(), PolicySerialized, SerializeInterface, DeserializeInterface[Policy]()) if err != nil { return nil, err } @@ -1027,11 +1059,20 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) + err = ctx.RegisterType(reflect.TypeOf(QueuedSignal{}), QueuedSignalType, SerializeStruct[QueuedSignal](ctx), DeserializeStruct[QueuedSignal](ctx)) if err != nil { return nil, err } + err = ctx.RegisterType(reflect.TypeOf(AllNodesPolicy{}), SerializedType(AllNodesPolicyType), SerializeStruct[AllNodesPolicy](ctx), DeserializeStruct[AllNodesPolicy](ctx)) + if err != nil { + return nil, err + } + + err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) + if err != nil { + return nil, err + } err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType, func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { diff --git a/gql_test.go b/gql_test.go index 97ff334..6543e2a 100644 --- a/gql_test.go +++ b/gql_test.go @@ -210,7 +210,7 @@ func TestGQLServer(t *testing.T) { } func TestGQLDB(t *testing.T) { - ctx := logTestContext(t, []string{"test", "serialize", "node"}) + ctx := logTestContext(t, []string{"test", "node"}) TestUserNodeType := NewNodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) @@ -243,7 +243,7 @@ func TestGQLDB(t *testing.T) { ctx.nodeMap = map[NodeID]*Node{} gql_loaded, err := LoadNode(ctx, gql.ID) fatalErr(t, err) - listener_ext, err = GetExt[*ListenerExt](gql_loaded, GQLExtType) + listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType) fatalErr(t, err) msgs = Messages{} msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID) diff --git a/listener.go b/listener.go index 357c866..1d25075 100644 --- a/listener.go +++ b/listener.go @@ -7,10 +7,15 @@ import ( // A Listener extension provides a channel that can receive signals on a different thread type ListenerExt struct { - Buffer int + Buffer int `gv:"buffer"` Chan chan Signal } +func (ext *ListenerExt) PostDeserialize(ctx *Context) error { + ext.Chan = make(chan Signal, ext.Buffer) + return nil +} + // Create a new listener extension with a given buffer size func NewListenerExt(buffer int) *ListenerExt { return &ListenerExt{ diff --git a/node.go b/node.go index a48db82..44381a0 100644 --- a/node.go +++ b/node.go @@ -110,6 +110,17 @@ type Node struct { NextSignal *QueuedSignal } +func (node *Node) PostDeserialize(ctx *Context) error { + public := node.Key.Public().(ed25519.PublicKey) + node.ID = KeyID(public) + + node.MsgChan = make(chan *Message, node.BufferSize) + + node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) + + return nil +} + type RuleResult int const ( Allow RuleResult = iota @@ -687,9 +698,13 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, fmt.Errorf("Deserialized %+v when expecting *Node", node_val.Type()) } + for ext_type, ext := range(node.Extensions){ + ctx.Log.Logf("serialize", "Deserialized extension: %+v - %+v", ext_type, ext) + } + ctx.AddNode(id, node) ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) go runNode(ctx, node) - return nil, nil + return node, nil } diff --git a/serialize.go b/serialize.go index adf35bd..123b670 100644 --- a/serialize.go +++ b/serialize.go @@ -26,12 +26,15 @@ func Hash(base string, name string) SerializedType { } type SerializedType uint64 - func (t SerializedType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } type ExtType SerializedType +func (t ExtType) String() string { + return fmt.Sprintf("0x%x", uint64(t)) +} + type NodeType SerializedType type SignalType SerializedType type PolicyType SerializedType @@ -109,10 +112,16 @@ var ( ReqStateType = NewSerializedType("REQ_STATE") SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") NodeStructType = NewSerializedType("NODE_STRUCT") - NodeTypeType = NewSerializedType("NODE_TYPE") - ExtTypeType = NewSerializedType("EXT_TYPE") - ExtensionType = NewSerializedType("EXTENSION") + QueuedSignalType = NewSerializedType("QUEUED_SIGNAL") + NodeTypeSerialized = NewSerializedType("NODE_TYPE") + ExtTypeSerialized = NewSerializedType("EXT_TYPE") + PolicyTypeSerialized = NewSerializedType("POLICY_TYPE") + ExtSerialized = NewSerializedType("EXTENSION") + PolicySerialized = NewSerializedType("POLICY") NodeIDType = NewSerializedType("NODE_ID") + UUIDType = NewSerializedType("UUID") + PendingACLType = NewSerializedType("PENDING_ACL") + PendingSignalType = NewSerializedType("PENDING_SIGNAL") ) func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ @@ -338,8 +347,17 @@ type StructInfo struct { Type reflect.Type FieldOrder []SerializedType FieldMap map[SerializedType]FieldInfo + PostDeserialize bool + PostDeserializeIdx int } +type Deserializable interface { + PostDeserialize(*Context) error +} + +var deserializable_zero Deserializable = nil +var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem() + func structInfo[T any](ctx *Context)StructInfo{ var struct_zero T struct_type := reflect.TypeOf(struct_zero) @@ -372,10 +390,26 @@ func structInfo[T any](ctx *Context)StructInfo{ return uint64(field_order[i]) < uint64(field_order[j]) }) + post_deserialize := false + post_deserialize_idx := 0 + ptr_type := reflect.PointerTo(struct_type) + if ptr_type.Implements(DeserializableType) { + post_deserialize = true + for i := 0; i < ptr_type.NumMethod(); i += 1 { + method := ptr_type.Method(i) + if method.Name == "PostDeserialize" { + post_deserialize_idx = i + break + } + } + } + return StructInfo{ struct_type, field_order, field_map, + post_deserialize, + post_deserialize_idx, } } @@ -422,6 +456,7 @@ func DeserializeStruct[T any](ctx *Context)(func(*Context,SerializedValue)(refle return nil, nil, value, err } num_fields := int(binary.BigEndian.Uint64(num_fields_bytes)) + ctx.Log.Logf("serialize", "Deserializing %d fields from %+v", num_fields, struct_info) struct_value := reflect.New(struct_info.Type).Elem() @@ -452,6 +487,15 @@ func DeserializeStruct[T any](ctx *Context)(func(*Context,SerializedValue)(refle field_value.Set(*field_reflect) } + if struct_info.PostDeserialize == true { + ctx.Log.Logf("serialize", "running post-deserialize for %+v", struct_info.Type) + post_deserialize_method := struct_value.Addr().Method(struct_info.PostDeserializeIdx) + ret := post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)}) + if ret[0].IsZero() == false { + return nil, nil, value, ret[0].Interface().(error) + } + } + return struct_info.Type, &struct_value, value, err } } @@ -670,6 +714,10 @@ func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *refle return nil, nil, value, err } - ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value, err) + if reflect_value != nil { + ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value.Interface(), err) + } else { + ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value, err) + } return reflect_type, reflect_value, value, nil }