Got serialization to the point that TestGQLDB is passing

gql_cataclysm
noah metz 2023-09-12 19:00:48 -06:00
parent c4e5054e07
commit 07ce005365
5 changed files with 122 additions and 13 deletions

@ -9,6 +9,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"sync" "sync"
"github.com/google/uuid"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
) )
@ -127,6 +128,8 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType
return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type) return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type)
} }
ctx.Log.Logf("serialize", "Registered ExtType: %+v - %+v", reflect_type, ext_type)
ctx.Extensions[ext_type] = ExtensionInfo{ ctx.Extensions[ext_type] = ExtensionInfo{
Type: reflect_type, Type: reflect_type,
Data: data, Data: data,
@ -703,6 +706,8 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
binary.BigEndian.PutUint64(data[0:8], uint64(map_size)) binary.BigEndian.PutUint64(data[0:8], uint64(map_size))
ctx.Log.Logf("serialize", "MAP_TYPES: %+v - %+v", key_types, elem_types)
type_stack = append(type_stack, key_types...) type_stack = append(type_stack, key_types...)
type_stack = append(type_stack, elem_types...) type_stack = append(type_stack, elem_types...)
return SerializedValue{ return SerializedValue{
@ -751,6 +756,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
} }
map_size := binary.BigEndian.Uint64(map_size_bytes) map_size := binary.BigEndian.Uint64(map_size_bytes)
ctx.Log.Logf("serialize", "Deserializing %d elements in map", map_size)
if map_size == 0xFFFFFFFFFFFFFFFF { if map_size == 0xFFFFFFFFFFFFFFFF {
var key_type, elem_type reflect.Type var key_type, elem_type reflect.Type
@ -990,12 +996,17 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
} }
*/ */
err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeType, SerializeUintN(4), DeserializeUintN[ExtType](4)) err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeSerialized, SerializeUintN(8), DeserializeUintN[ExtType](8))
if err != nil {
return nil, err
}
err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeSerialized, SerializeUintN(8), DeserializeUintN[NodeType](8))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(NodeType(0)), NodeTypeType, SerializeUintN(4), DeserializeUintN[NodeType](4)) err = ctx.RegisterType(reflect.TypeOf(PolicyType(0)), PolicyTypeSerialized, SerializeUintN(8), DeserializeUintN[PolicyType](8))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1005,9 +1016,30 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(uuid.New()), UUIDType, SerializeArray, DeserializeArray[uuid.UUID](ctx))
if err != nil {
return nil, err
}
err = ctx.RegisterType(reflect.TypeOf(PendingACL{}), PendingACLType, SerializeStruct[PendingACL](ctx), DeserializeStruct[PendingACL](ctx))
if err != nil {
return nil, err
}
err = ctx.RegisterType(reflect.TypeOf(PendingSignal{}), PendingSignalType, SerializeStruct[PendingSignal](ctx), DeserializeStruct[PendingSignal](ctx))
if err != nil {
return nil, err
}
// TODO: Make registering interfaces cleaner // TODO: Make registering interfaces cleaner
var extension Extension = nil var extension Extension = nil
err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtensionType, SerializeInterface, DeserializeInterface[Extension]()) err = ctx.RegisterType(reflect.ValueOf(&extension).Type().Elem(), ExtSerialized, SerializeInterface, DeserializeInterface[Extension]())
if err != nil {
return nil, err
}
var policy Policy = nil
err = ctx.RegisterType(reflect.ValueOf(&policy).Type().Elem(), PolicySerialized, SerializeInterface, DeserializeInterface[Policy]())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1027,11 +1059,20 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) err = ctx.RegisterType(reflect.TypeOf(QueuedSignal{}), QueuedSignalType, SerializeStruct[QueuedSignal](ctx), DeserializeStruct[QueuedSignal](ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(AllNodesPolicy{}), SerializedType(AllNodesPolicyType), SerializeStruct[AllNodesPolicy](ctx), DeserializeStruct[AllNodesPolicy](ctx))
if err != nil {
return nil, err
}
err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx))
if err != nil {
return nil, err
}
err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType, err = ctx.RegisterType(reflect.TypeOf(Up), SignalDirectionType,
func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) { func(ctx *Context, ctx_type SerializedType, t reflect.Type, value *reflect.Value) (SerializedValue, error) {

@ -210,7 +210,7 @@ func TestGQLServer(t *testing.T) {
} }
func TestGQLDB(t *testing.T) { func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{"test", "serialize", "node"}) ctx := logTestContext(t, []string{"test", "node"})
TestUserNodeType := NewNodeType("TEST_USER") TestUserNodeType := NewNodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})
@ -243,7 +243,7 @@ func TestGQLDB(t *testing.T) {
ctx.nodeMap = map[NodeID]*Node{} ctx.nodeMap = map[NodeID]*Node{}
gql_loaded, err := LoadNode(ctx, gql.ID) gql_loaded, err := LoadNode(ctx, gql.ID)
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded, GQLExtType) listener_ext, err = GetExt[*ListenerExt](gql_loaded, ListenerExtType)
fatalErr(t, err) fatalErr(t, err)
msgs = Messages{} msgs = Messages{}
msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID) msgs = msgs.Add(ctx, gql_loaded.ID, gql_loaded.Key, NewStopSignal(), gql_loaded.ID)

@ -7,10 +7,15 @@ import (
// A Listener extension provides a channel that can receive signals on a different thread // A Listener extension provides a channel that can receive signals on a different thread
type ListenerExt struct { type ListenerExt struct {
Buffer int Buffer int `gv:"buffer"`
Chan chan Signal Chan chan Signal
} }
func (ext *ListenerExt) PostDeserialize(ctx *Context) error {
ext.Chan = make(chan Signal, ext.Buffer)
return nil
}
// Create a new listener extension with a given buffer size // Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt { func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{ return &ListenerExt{

@ -110,6 +110,17 @@ type Node struct {
NextSignal *QueuedSignal NextSignal *QueuedSignal
} }
func (node *Node) PostDeserialize(ctx *Context) error {
public := node.Key.Public().(ed25519.PublicKey)
node.ID = KeyID(public)
node.MsgChan = make(chan *Message, node.BufferSize)
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
return nil
}
type RuleResult int type RuleResult int
const ( const (
Allow RuleResult = iota Allow RuleResult = iota
@ -687,9 +698,13 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, fmt.Errorf("Deserialized %+v when expecting *Node", node_val.Type()) return nil, fmt.Errorf("Deserialized %+v when expecting *Node", node_val.Type())
} }
for ext_type, ext := range(node.Extensions){
ctx.Log.Logf("serialize", "Deserialized extension: %+v - %+v", ext_type, ext)
}
ctx.AddNode(id, node) ctx.AddNode(id, node)
ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id) ctx.Log.Logf("db", "DB_NODE_LOADED: %s", id)
go runNode(ctx, node) go runNode(ctx, node)
return nil, nil return node, nil
} }

@ -26,12 +26,15 @@ func Hash(base string, name string) SerializedType {
} }
type SerializedType uint64 type SerializedType uint64
func (t SerializedType) String() string { func (t SerializedType) String() string {
return fmt.Sprintf("0x%x", uint64(t)) return fmt.Sprintf("0x%x", uint64(t))
} }
type ExtType SerializedType type ExtType SerializedType
func (t ExtType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type NodeType SerializedType type NodeType SerializedType
type SignalType SerializedType type SignalType SerializedType
type PolicyType SerializedType type PolicyType SerializedType
@ -109,10 +112,16 @@ var (
ReqStateType = NewSerializedType("REQ_STATE") ReqStateType = NewSerializedType("REQ_STATE")
SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION")
NodeStructType = NewSerializedType("NODE_STRUCT") NodeStructType = NewSerializedType("NODE_STRUCT")
NodeTypeType = NewSerializedType("NODE_TYPE") QueuedSignalType = NewSerializedType("QUEUED_SIGNAL")
ExtTypeType = NewSerializedType("EXT_TYPE") NodeTypeSerialized = NewSerializedType("NODE_TYPE")
ExtensionType = NewSerializedType("EXTENSION") ExtTypeSerialized = NewSerializedType("EXT_TYPE")
PolicyTypeSerialized = NewSerializedType("POLICY_TYPE")
ExtSerialized = NewSerializedType("EXTENSION")
PolicySerialized = NewSerializedType("POLICY")
NodeIDType = NewSerializedType("NODE_ID") NodeIDType = NewSerializedType("NODE_ID")
UUIDType = NewSerializedType("UUID")
PendingACLType = NewSerializedType("PENDING_ACL")
PendingSignalType = NewSerializedType("PENDING_SIGNAL")
) )
func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){
@ -338,8 +347,17 @@ type StructInfo struct {
Type reflect.Type Type reflect.Type
FieldOrder []SerializedType FieldOrder []SerializedType
FieldMap map[SerializedType]FieldInfo FieldMap map[SerializedType]FieldInfo
PostDeserialize bool
PostDeserializeIdx int
} }
type Deserializable interface {
PostDeserialize(*Context) error
}
var deserializable_zero Deserializable = nil
var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem()
func structInfo[T any](ctx *Context)StructInfo{ func structInfo[T any](ctx *Context)StructInfo{
var struct_zero T var struct_zero T
struct_type := reflect.TypeOf(struct_zero) struct_type := reflect.TypeOf(struct_zero)
@ -372,10 +390,26 @@ func structInfo[T any](ctx *Context)StructInfo{
return uint64(field_order[i]) < uint64(field_order[j]) return uint64(field_order[i]) < uint64(field_order[j])
}) })
post_deserialize := false
post_deserialize_idx := 0
ptr_type := reflect.PointerTo(struct_type)
if ptr_type.Implements(DeserializableType) {
post_deserialize = true
for i := 0; i < ptr_type.NumMethod(); i += 1 {
method := ptr_type.Method(i)
if method.Name == "PostDeserialize" {
post_deserialize_idx = i
break
}
}
}
return StructInfo{ return StructInfo{
struct_type, struct_type,
field_order, field_order,
field_map, field_map,
post_deserialize,
post_deserialize_idx,
} }
} }
@ -422,6 +456,7 @@ func DeserializeStruct[T any](ctx *Context)(func(*Context,SerializedValue)(refle
return nil, nil, value, err return nil, nil, value, err
} }
num_fields := int(binary.BigEndian.Uint64(num_fields_bytes)) num_fields := int(binary.BigEndian.Uint64(num_fields_bytes))
ctx.Log.Logf("serialize", "Deserializing %d fields from %+v", num_fields, struct_info)
struct_value := reflect.New(struct_info.Type).Elem() struct_value := reflect.New(struct_info.Type).Elem()
@ -452,6 +487,15 @@ func DeserializeStruct[T any](ctx *Context)(func(*Context,SerializedValue)(refle
field_value.Set(*field_reflect) field_value.Set(*field_reflect)
} }
if struct_info.PostDeserialize == true {
ctx.Log.Logf("serialize", "running post-deserialize for %+v", struct_info.Type)
post_deserialize_method := struct_value.Addr().Method(struct_info.PostDeserializeIdx)
ret := post_deserialize_method.Call([]reflect.Value{reflect.ValueOf(ctx)})
if ret[0].IsZero() == false {
return nil, nil, value, ret[0].Interface().(error)
}
}
return struct_info.Type, &struct_value, value, err return struct_info.Type, &struct_value, value, err
} }
} }
@ -670,6 +714,10 @@ func DeserializeValue(ctx *Context, value SerializedValue) (reflect.Type, *refle
return nil, nil, value, err return nil, nil, value, err
} }
if reflect_value != nil {
ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value.Interface(), err)
} else {
ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value, err) ctx.Log.Logf("serialize", "Deserialized %+v - %+v - %+v", reflect_type, reflect_value, err)
}
return reflect_type, reflect_value, value, nil return reflect_type, reflect_value, value, nil
} }