Moved type registration to signal/extension/policy registration

gql_cataclysm
noah metz 2023-09-12 20:30:18 -06:00
parent de1a229db6
commit 5c70d1b18d
7 changed files with 131 additions and 147 deletions

@ -107,6 +107,13 @@ func (ctx *Context) RegisterPolicy(reflect_type reflect.Type, policy_type Policy
return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type) return fmt.Errorf("Cannot register policy of type %+v, type already exists in context", policy_type)
} }
err := ctx.RegisterType(reflect_type, SerializedType(policy_type), SerializeStruct(ctx, reflect_type), DeserializeStruct(ctx, reflect_type))
if err != nil {
return err
}
ctx.Log.Logf("serialize", "Registered PolicyType: %+v - %+v", reflect_type, policy_type)
ctx.Policies[policy_type] = reflect_type ctx.Policies[policy_type] = reflect_type
ctx.PolicyTypes[reflect_type] = policy_type ctx.PolicyTypes[reflect_type] = policy_type
return nil return nil
@ -118,6 +125,13 @@ func (ctx *Context)RegisterSignal(reflect_type reflect.Type, signal_type SignalT
return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type) return fmt.Errorf("Cannot register signal of type %+v, type already exists in context", signal_type)
} }
err := ctx.RegisterType(reflect_type, SerializedType(signal_type), SerializeStruct(ctx, reflect_type), DeserializeStruct(ctx, reflect_type))
if err != nil {
return err
}
ctx.Log.Logf("serialize", "Registered SignalType: %+v - %+v", reflect_type, signal_type)
ctx.Signals[signal_type] = reflect_type ctx.Signals[signal_type] = reflect_type
ctx.SignalTypes[reflect_type] = signal_type ctx.SignalTypes[reflect_type] = signal_type
return nil return nil
@ -130,6 +144,12 @@ func (ctx *Context)RegisterExtension(reflect_type reflect.Type, ext_type ExtType
return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type) return fmt.Errorf("Cannot register extension of type %+v, type already exists in context", ext_type)
} }
elem_type := reflect_type.Elem()
err := ctx.RegisterType(elem_type, SerializedType(ext_type), SerializeStruct(ctx, elem_type), DeserializeStruct(ctx, elem_type))
if err != nil {
return err
}
ctx.Log.Logf("serialize", "Registered ExtType: %+v - %+v", reflect_type, ext_type) ctx.Log.Logf("serialize", "Registered ExtType: %+v - %+v", reflect_type, ext_type)
ctx.Extensions[ext_type] = ExtensionInfo{ ctx.Extensions[ext_type] = ExtensionInfo{
@ -309,6 +329,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
} else { } else {
pointer_flags := value.Data[0] pointer_flags := value.Data[0]
value.Data = value.Data[1:] value.Data = value.Data[1:]
ctx.Log.Logf("serialize", "Pointer flags: 0x%x", pointer_flags)
if pointer_flags == 0x00 { if pointer_flags == 0x00 {
elem_type, elem_value, remaining_data, err := DeserializeValue(ctx, value) elem_type, elem_value, remaining_data, err := DeserializeValue(ctx, value)
if err != nil { if err != nil {
@ -319,14 +340,21 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
pointer_value.Set(elem_value.Addr()) pointer_value.Set(elem_value.Addr())
return pointer_type, &pointer_value, remaining_data, nil return pointer_type, &pointer_value, remaining_data, nil
} else if pointer_flags == 0x01 { } else if pointer_flags == 0x01 {
elem_type, _, remaining_data, err := DeserializeValue(ctx, value) tmp_value := SerializedValue{
value.TypeStack,
nil,
}
var elem_type reflect.Type
var err error
elem_type, _, tmp_value, err = DeserializeValue(ctx, tmp_value)
if err != nil { if err != nil {
return nil, nil, SerializedValue{}, err return nil, nil, SerializedValue{}, err
} }
value.TypeStack = tmp_value.TypeStack
pointer_type := reflect.PointerTo(elem_type) pointer_type := reflect.PointerTo(elem_type)
pointer_value := reflect.New(pointer_type).Elem() pointer_value := reflect.New(pointer_type).Elem()
return pointer_type, &pointer_value, remaining_data, nil return pointer_type, &pointer_value, value, nil
} else { } else {
return nil, nil, SerializedValue{}, fmt.Errorf("unknown pointer flags: %d", pointer_flags) return nil, nil, SerializedValue{}, fmt.Errorf("unknown pointer flags: %d", pointer_flags)
} }
@ -1008,103 +1036,102 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(PendingACL{}), PendingACLType, SerializeStruct[PendingACL](ctx), DeserializeStruct[PendingACL](ctx)) pending_acl_type := reflect.TypeOf(PendingACL{})
if err != nil { err = ctx.RegisterType(pending_acl_type, PendingACLType, SerializeStruct(ctx, pending_acl_type), DeserializeStruct(ctx, pending_acl_type))
return nil, err
}
err = ctx.RegisterType(reflect.TypeOf(PendingSignal{}), PendingSignalType, SerializeStruct[PendingSignal](ctx), DeserializeStruct[PendingSignal](ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(ListenerExt{}), SerializedType(ListenerExtType), SerializeStruct[ListenerExt](ctx), DeserializeStruct[ListenerExt](ctx)) pending_signal_type := reflect.TypeOf(PendingSignal{})
err = ctx.RegisterType(pending_signal_type, PendingSignalType, SerializeStruct(ctx, pending_signal_type), DeserializeStruct(ctx, pending_signal_type))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(GroupExt{}), SerializedType(GroupExtType), SerializeStruct[GroupExt](ctx), DeserializeStruct[GroupExt](ctx)) queued_signal_type := reflect.TypeOf(QueuedSignal{})
err = ctx.RegisterType(queued_signal_type, QueuedSignalType, SerializeStruct(ctx, queued_signal_type), DeserializeStruct(ctx, queued_signal_type))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(GQLExt{}), SerializedType(GQLExtType), SerializeStruct[GQLExt](ctx), DeserializeStruct[GQLExt](ctx)) node_type := reflect.TypeOf(Node{})
err = ctx.RegisterType(node_type, NodeStructType, SerializeStruct(ctx, node_type), DeserializeStruct(ctx, node_type))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(QueuedSignal{}), QueuedSignalType, SerializeStruct[QueuedSignal](ctx), DeserializeStruct[QueuedSignal](ctx)) err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(AllNodesPolicy{}), SerializedType(AllNodesPolicyType), SerializeStruct[AllNodesPolicy](ctx), DeserializeStruct[AllNodesPolicy](ctx)) err = ctx.RegisterExtension(reflect.TypeOf((*ListenerExt)(nil)), ListenerExtType, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(StatusSignal{}), SerializedType(StatusSignalType), SerializeStruct[StatusSignal](ctx), DeserializeStruct[StatusSignal](ctx)) err = ctx.RegisterExtension(reflect.TypeOf((*GroupExt)(nil)), GroupExtType, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(StopSignal{}), SerializedType(StopSignalType), SerializeStruct[StopSignal](ctx), DeserializeStruct[StopSignal](ctx)) gql_ctx := NewGQLExtContext()
err = ctx.RegisterExtension(reflect.TypeOf((*GQLExt)(nil)), GQLExtType, gql_ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(StartSignal{}), SerializedType(StartSignalType), SerializeStruct[StartSignal](ctx), DeserializeStruct[StartSignal](ctx)) err = ctx.RegisterPolicy(reflect.TypeOf(AllNodesPolicy{}), AllNodesPolicyType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterType(reflect.TypeOf(Node{}), NodeStructType, SerializeStruct[Node](ctx), DeserializeStruct[Node](ctx)) err = ctx.RegisterPolicy(reflect.TypeOf(PerNodePolicy{}), PerNodePolicyType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil) err = ctx.RegisterSignal(reflect.TypeOf(StopSignal{}), StopSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(reflect.TypeOf((*ListenerExt)(nil)), ListenerExtType, nil) err = ctx.RegisterSignal(reflect.TypeOf(CreateSignal{}), CreateSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(reflect.TypeOf((*GroupExt)(nil)), GroupExtType, nil) err = ctx.RegisterSignal(reflect.TypeOf(StartSignal{}), StartSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
gql_ctx := NewGQLExtContext() err = ctx.RegisterSignal(reflect.TypeOf(StatusSignal{}), StatusSignalType)
err = ctx.RegisterExtension(reflect.TypeOf((*GQLExt)(nil)), GQLExtType, gql_ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(reflect.TypeOf((*StopSignal)(nil)), StopSignalType) err = ctx.RegisterSignal(reflect.TypeOf(ReadSignal{}), ReadSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(reflect.TypeOf((*CreateSignal)(nil)), CreateSignalType) err = ctx.RegisterSignal(reflect.TypeOf(LockSignal{}), LockSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(reflect.TypeOf((*StartSignal)(nil)), StartSignalType) err = ctx.RegisterSignal(reflect.TypeOf(LinkSignal{}), LinkSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(reflect.TypeOf((*ReadSignal)(nil)), ReadSignalType) err = ctx.RegisterSignal(reflect.TypeOf(ErrorSignal{}), ErrorSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(reflect.TypeOf((*ReadResultSignal)(nil)), ReadResultSignalType) err = ctx.RegisterSignal(reflect.TypeOf(ReadResultSignal{}), ReadResultSignalType)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -14,11 +14,11 @@ const (
) )
type LockableExt struct{ type LockableExt struct{
State ReqState `gv:"0"` State ReqState `gv:"state"`
ReqID *uuid.UUID `gv:"1"` ReqID *uuid.UUID `gv:"req_id"`
Owner *NodeID `gv:"2"` Owner *NodeID
PendingOwner *NodeID `gv:"3"` PendingOwner *NodeID
Requirements map[NodeID]ReqState `gv:"4"` Requirements map[NodeID]ReqState
} }
func (ext *LockableExt) Type() ExtType { func (ext *LockableExt) Type() ExtType {

@ -8,7 +8,7 @@ import (
) )
func TestNodeDB(t *testing.T) { func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"signal", "node", "db", "db_data", "serialize", "listener"}) ctx := logTestContext(t, []string{"signal", "node", "db", "serialize", "listener"})
node_type := NewNodeType("test") node_type := NewNodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err) fatalErr(t, err)
@ -39,7 +39,7 @@ func TestNodeDB(t *testing.T) {
func TestNodeRead(t *testing.T) { func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{"test"}) ctx := logTestContext(t, []string{"test"})
node_type := NewNodeType("TEST") node_type := NewNodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err) fatalErr(t, err)
n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader) n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader)

@ -12,15 +12,15 @@ type Policy interface {
Copy() Policy Copy() Policy
} }
func (policy *AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
return nil, policy.Rules.Allows(action) return nil, policy.Rules.Allows(action)
} }
func (policy *AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny return Deny
} }
func (policy *PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
for id, actions := range(policy.NodeRules) { for id, actions := range(policy.NodeRules) {
if id != principal_id { if id != principal_id {
continue continue
@ -30,7 +30,7 @@ func (policy *PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tr
return nil, Deny return nil, Deny
} }
func (policy *PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny return Deny
} }
@ -38,7 +38,7 @@ type RequirementOfPolicy struct {
PerNodePolicy PerNodePolicy
} }
func (policy *RequirementOfPolicy) Type() PolicyType { func (policy RequirementOfPolicy) Type() PolicyType {
return RequirementOfPolicyType return RequirementOfPolicyType
} }
@ -48,7 +48,7 @@ func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy {
} }
} }
func (policy *RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal) sig, ok := signal.(*ReadResultSignal)
if ok == false { if ok == false {
return Deny return Deny
@ -87,7 +87,7 @@ type MemberOfPolicy struct {
PerNodePolicy PerNodePolicy
} }
func (policy *MemberOfPolicy) Type() PolicyType { func (policy MemberOfPolicy) Type() PolicyType {
return MemberOfPolicyType return MemberOfPolicyType
} }
@ -97,7 +97,7 @@ func NewMemberOfPolicy(group_rules map[NodeID]Tree) MemberOfPolicy {
} }
} }
func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal) sig, ok := signal.(*ReadResultSignal)
if ok == false { if ok == false {
return Deny return Deny
@ -133,7 +133,7 @@ func (policy *MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, s
} }
// Send a read signal to Group to check if principal_id is a member of it // Send a read signal to Group to check if principal_id is a member of it
func (policy *MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
msgs := Messages{} msgs := Messages{}
for id, rule := range(policy.NodeRules) { for id, rule := range(policy.NodeRules) {
if id == node.ID { if id == node.ID {
@ -156,13 +156,13 @@ func (policy *MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action T
return msgs, Pending return msgs, Pending
} }
func (policy *MemberOfPolicy) Merge(p Policy) Policy { func (policy MemberOfPolicy) Merge(p Policy) Policy {
other := p.(*MemberOfPolicy) other := p.(*MemberOfPolicy)
policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules) policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules)
return policy return policy
} }
func (policy *MemberOfPolicy) Copy() Policy { func (policy MemberOfPolicy) Copy() Policy {
new_rules := CopyNodeRules(policy.NodeRules) new_rules := CopyNodeRules(policy.NodeRules)
return &MemberOfPolicy{ return &MemberOfPolicy{
PerNodePolicy: NewPerNodePolicy(new_rules), PerNodePolicy: NewPerNodePolicy(new_rules),
@ -223,26 +223,26 @@ func MergeNodeRules(first map[NodeID]Tree, second map[NodeID]Tree) map[NodeID]Tr
return merged return merged
} }
func (policy *PerNodePolicy) Merge(p Policy) Policy { func (policy PerNodePolicy) Merge(p Policy) Policy {
other := p.(*PerNodePolicy) other := p.(*PerNodePolicy)
policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules) policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules)
return policy return policy
} }
func (policy *PerNodePolicy) Copy() Policy { func (policy PerNodePolicy) Copy() Policy {
new_rules := CopyNodeRules(policy.NodeRules) new_rules := CopyNodeRules(policy.NodeRules)
return &PerNodePolicy{ return &PerNodePolicy{
NodeRules: new_rules, NodeRules: new_rules,
} }
} }
func (policy *AllNodesPolicy) Merge(p Policy) Policy { func (policy AllNodesPolicy) Merge(p Policy) Policy {
other := p.(*AllNodesPolicy) other := p.(*AllNodesPolicy)
policy.Rules = MergeTrees(policy.Rules, other.Rules) policy.Rules = MergeTrees(policy.Rules, other.Rules)
return policy return policy
} }
func (policy *AllNodesPolicy) Copy() Policy { func (policy AllNodesPolicy) Copy() Policy {
new_rules := policy.Rules new_rules := policy.Rules
return &AllNodesPolicy { return &AllNodesPolicy {
Rules: new_rules, Rules: new_rules,
@ -292,7 +292,7 @@ type PerNodePolicy struct {
NodeRules map[NodeID]Tree `json:"node_actions"` NodeRules map[NodeID]Tree `json:"node_actions"`
} }
func (policy *PerNodePolicy) Type() PolicyType { func (policy PerNodePolicy) Type() PolicyType {
return PerNodePolicyType return PerNodePolicyType
} }
@ -306,7 +306,7 @@ type AllNodesPolicy struct {
Rules Tree Rules Tree
} }
func (policy *AllNodesPolicy) Type() PolicyType { func (policy AllNodesPolicy) Type() PolicyType {
return AllNodesPolicyType return AllNodesPolicyType
} }

@ -36,8 +36,19 @@ func (t ExtType) String() string {
} }
type NodeType SerializedType type NodeType SerializedType
func (t NodeType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type SignalType SerializedType type SignalType SerializedType
func (t SignalType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type PolicyType SerializedType type PolicyType SerializedType
func (t PolicyType) String() string {
return fmt.Sprintf("0x%x", uint64(t))
}
type TypeSerialize func(*Context,SerializedType,reflect.Type,*reflect.Value) (SerializedValue, error) type TypeSerialize func(*Context,SerializedType,reflect.Type,*reflect.Value) (SerializedValue, error)
type TypeDeserialize func(*Context,SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) type TypeDeserialize func(*Context,SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error)
@ -67,7 +78,6 @@ var (
LockableExtType = NewExtType("LOCKABLE") LockableExtType = NewExtType("LOCKABLE")
GQLExtType = NewExtType("GQL") GQLExtType = NewExtType("GQL")
GroupExtType = NewExtType("GROUP") GroupExtType = NewExtType("GROUP")
ECDHExtType = NewExtType("ECDH")
GQLNodeType = NewNodeType("GQL") GQLNodeType = NewNodeType("GQL")
@ -360,9 +370,7 @@ type Deserializable interface {
var deserializable_zero Deserializable = nil var deserializable_zero Deserializable = nil
var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem() var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem()
func structInfo[T any](ctx *Context)StructInfo{ func structInfo(ctx *Context, struct_type reflect.Type)StructInfo{
var struct_zero T
struct_type := reflect.TypeOf(struct_zero)
field_order := []SerializedType{} field_order := []SerializedType{}
field_map := map[SerializedType]FieldInfo{} field_map := map[SerializedType]FieldInfo{}
for _, field := range(reflect.VisibleFields(struct_type)) { for _, field := range(reflect.VisibleFields(struct_type)) {
@ -415,8 +423,8 @@ func structInfo[T any](ctx *Context)StructInfo{
} }
} }
func SerializeStruct[T any](ctx *Context)(func(*Context,SerializedType,reflect.Type,*reflect.Value)(SerializedValue,error)){ func SerializeStruct(ctx *Context, struct_type reflect.Type)(func(*Context,SerializedType,reflect.Type,*reflect.Value)(SerializedValue,error)){
struct_info := structInfo[T](ctx) struct_info := structInfo(ctx, struct_type)
return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){
type_stack := []SerializedType{ctx_type} type_stack := []SerializedType{ctx_type}
var data []byte var data []byte
@ -445,8 +453,8 @@ func SerializeStruct[T any](ctx *Context)(func(*Context,SerializedType,reflect.T
} }
} }
func DeserializeStruct[T any](ctx *Context)(func(*Context,SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ func DeserializeStruct(ctx *Context, struct_type reflect.Type)(func(*Context,SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){
struct_info := structInfo[T](ctx) struct_info := structInfo(ctx, struct_type)
return func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error) { return func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error) {
if value.Data == nil { if value.Data == nil {
return struct_info.Type, nil, value, nil return struct_info.Type, nil, value, nil

@ -7,7 +7,7 @@ import (
) )
func TestSerializeBasic(t *testing.T) { func TestSerializeBasic(t *testing.T) {
ctx := logTestContext(t, []string{"test", "serialize"}) ctx := logTestContext(t, []string{"test"})
testSerializeComparable[string](t, ctx, "test") testSerializeComparable[string](t, ctx, "test")
testSerializeComparable[bool](t, ctx, true) testSerializeComparable[bool](t, ctx, true)
testSerializeComparable[float32](t, ctx, 0.05) testSerializeComparable[float32](t, ctx, 0.05)
@ -46,10 +46,16 @@ func TestSerializeBasic(t *testing.T) {
6: 1121, 6: 1121,
}) })
testSerializeStruct(t, ctx, struct{ type test_struct struct {
int `gv:"0"` Int int `gv:"int"`
string `gv:"1"` String string `gv:"string"`
}{ }
test_struct_type := reflect.TypeOf(test_struct{})
err := ctx.RegisterType(test_struct_type, NewSerializedType("TEST_STRUCT"), SerializeStruct(ctx, test_struct_type), DeserializeStruct(ctx, test_struct_type))
fatalErr(t, err)
testSerialize(t, ctx, test_struct{
12345, 12345,
"test_string", "test_string",
}) })
@ -65,11 +71,12 @@ func (s test) String() string {
} }
func TestSerializeStructTags(t *testing.T) { func TestSerializeStructTags(t *testing.T) {
ctx := logTestContext(t, []string{"test", "serialize"}) ctx := logTestContext(t, []string{"test"})
test_type := NewSerializedType("TEST_STRUCT") test_type := NewSerializedType("TEST_STRUCT")
test_struct_type := reflect.TypeOf(test{})
ctx.Log.Logf("test", "TEST_TYPE: %+v", test_type) ctx.Log.Logf("test", "TEST_TYPE: %+v", test_type)
ctx.RegisterType(reflect.TypeOf(test{}), test_type, SerializeStruct[test](ctx), DeserializeStruct[test](ctx)) ctx.RegisterType(test_struct_type, test_type, SerializeStruct(ctx, test_struct_type), DeserializeStruct(ctx, test_struct_type))
test_int := 10 test_int := 10
test_string := "test" test_string := "test"
@ -145,64 +152,6 @@ func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) {
} }
} }
func testSerializeStruct[T any](t *testing.T, ctx *Context, val T) {
value, err := SerializeAny(ctx, val)
fatalErr(t, err)
ctx.Log.Logf("test", "Serialized %+v to %+v", val, value)
ser, err := value.MarshalBinary()
fatalErr(t, err)
ctx.Log.Logf("test", "Binary: %+v", ser)
val_parsed, remaining_parse, err := ParseSerializedValue(ser)
fatalErr(t, err)
ctx.Log.Logf("test", "Parsed: %+v", val_parsed)
if len(remaining_parse) != 0 {
t.Fatal("Data remaining after deserializing value")
}
val_type, deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_parsed)
fatalErr(t, err)
if len(remaining_deserialize.Data) != 0 {
t.Fatal("Data remaining after deserializing value")
} else if len(remaining_deserialize.TypeStack) != 0 {
t.Fatal("TypeStack remaining after deserializing value")
} else if val_type != reflect.TypeOf(map[uint64]reflect.Value{}) {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - map[uint64]reflect.Value", val_type))
} else if deserialized_value == nil {
t.Fatal("DeserializeValue returned no []reflect.Value")
} else if deserialized_value == nil {
t.Fatal("DeserializeValue returned nil *reflect.Value")
} else if deserialized_value.CanConvert(reflect.TypeOf(map[uint64]reflect.Value{})) == false {
t.Fatal("DeserializeValue returned value that can't convert to map[uint64]reflect.Value")
}
reflect_value := reflect.ValueOf(val)
deserialized_map := deserialized_value.Interface().(map[uint64]reflect.Value)
for _, field := range(reflect.VisibleFields(reflect_value.Type())) {
gv_tag, tagged_gv := field.Tag.Lookup("gv")
if tagged_gv == false {
continue
} else if gv_tag == "" {
continue
} else {
field_hash := uint64(Hash(FieldNameBase, gv_tag))
deserialized_field, exists := deserialized_map[field_hash]
if exists == false {
t.Fatal(fmt.Sprintf("field %s is not in deserialized struct", field.Name))
}
field_value := reflect_value.FieldByIndex(field.Index)
if field_value.Type() != deserialized_field.Type() {
t.Fatal(fmt.Sprintf("Type of %s does not match", field.Name))
}
ctx.Log.Logf("test", "Field %s matched", field.Name)
}
}
}
func testSerialize[T any](t *testing.T, ctx *Context, val T) T { func testSerialize[T any](t *testing.T, ctx *Context, val T) T {
value := reflect.ValueOf(&val).Elem() value := reflect.ValueOf(&val).Elem()
value_serialized, err := SerializeValue(ctx, value.Type(), &value) value_serialized, err := SerializeValue(ctx, value.Type(), &value)

@ -139,10 +139,10 @@ type CreateSignal struct {
SignalHeader SignalHeader
} }
func (signal *CreateSignal) Header() *SignalHeader { func (signal CreateSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *CreateSignal) Permission() Tree { func (signal CreateSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(CreateSignalType): nil, SerializedType(CreateSignalType): nil,
} }
@ -157,10 +157,10 @@ func NewCreateSignal() *CreateSignal {
type StartSignal struct { type StartSignal struct {
SignalHeader SignalHeader
} }
func (signal *StartSignal) Header() *SignalHeader { func (signal StartSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *StartSignal) Permission() Tree { func (signal StartSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(StartSignalType): nil, SerializedType(StartSignalType): nil,
} }
@ -174,10 +174,10 @@ func NewStartSignal() *StartSignal {
type StopSignal struct { type StopSignal struct {
SignalHeader SignalHeader
} }
func (signal *StopSignal) Header() *SignalHeader { func (signal StopSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *StopSignal) Permission() Tree { func (signal StopSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(StopSignalType): nil, SerializedType(StopSignalType): nil,
} }
@ -192,10 +192,10 @@ type ErrorSignal struct {
SignalHeader SignalHeader
Error string Error string
} }
func (signal *ErrorSignal) Header() *SignalHeader { func (signal ErrorSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *ErrorSignal) MarshalBinary() ([]byte, error) { func (signal ErrorSignal) MarshalBinary() ([]byte, error) {
arena := capnp.SingleSegment(nil) arena := capnp.SingleSegment(nil)
msg, seg, err := capnp.NewMessage(arena) msg, seg, err := capnp.NewMessage(arena)
if err != nil { if err != nil {
@ -221,7 +221,7 @@ func (signal *ErrorSignal) MarshalBinary() ([]byte, error) {
return msg.Marshal() return msg.Marshal()
} }
func (signal *ErrorSignal) Deserialize(ctx *Context, data []byte) error { func (signal ErrorSignal) Deserialize(ctx *Context, data []byte) error {
msg, err := capnp.Unmarshal(data) msg, err := capnp.Unmarshal(data)
if err != nil { if err != nil {
return err return err
@ -248,7 +248,7 @@ func (signal *ErrorSignal) Deserialize(ctx *Context, data []byte) error {
return nil return nil
} }
func (signal *ErrorSignal) Permission() Tree { func (signal ErrorSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(ErrorSignalType): nil, SerializedType(ErrorSignalType): nil,
} }
@ -263,10 +263,10 @@ func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Si
type ACLTimeoutSignal struct { type ACLTimeoutSignal struct {
SignalHeader SignalHeader
} }
func (signal *ACLTimeoutSignal) Header() *SignalHeader { func (signal ACLTimeoutSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *ACLTimeoutSignal) Permission() Tree { func (signal ACLTimeoutSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(ACLTimeoutSignalType): nil, SerializedType(ACLTimeoutSignalType): nil,
} }
@ -283,10 +283,10 @@ type StatusSignal struct {
Source NodeID `gv:"source"` Source NodeID `gv:"source"`
Status string `gv:"status"` Status string `gv:"status"`
} }
func (signal *StatusSignal) Header() *SignalHeader { func (signal StatusSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *StatusSignal) Permission() Tree { func (signal StatusSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(StatusSignalType): nil, SerializedType(StatusSignalType): nil,
} }
@ -304,7 +304,7 @@ type LinkSignal struct {
NodeID NodeID
Action string Action string
} }
func (signal *LinkSignal) Header() *SignalHeader { func (signal LinkSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
@ -313,7 +313,7 @@ const (
LinkActionAdd = "ADD" LinkActionAdd = "ADD"
) )
func (signal *LinkSignal) Permission() Tree { func (signal LinkSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(LinkSignalType): Tree{ SerializedType(LinkSignalType): Tree{
Hash(LinkActionBase, signal.Action): nil, Hash(LinkActionBase, signal.Action): nil,
@ -332,7 +332,7 @@ type LockSignal struct {
SignalHeader SignalHeader
State string State string
} }
func (signal *LockSignal) Header() *SignalHeader { func (signal LockSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
@ -340,7 +340,7 @@ const (
LockStateBase = "LOCK_STATE" LockStateBase = "LOCK_STATE"
) )
func (signal *LockSignal) Permission() Tree { func (signal LockSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(LockSignalType): Tree{ SerializedType(LockSignalType): Tree{
Hash(LockStateBase, signal.State): nil, Hash(LockStateBase, signal.State): nil,
@ -359,7 +359,7 @@ type ReadSignal struct {
SignalHeader SignalHeader
Extensions map[ExtType][]string `json:"extensions"` Extensions map[ExtType][]string `json:"extensions"`
} }
func (signal *ReadSignal) MarshalBinary() ([]byte, error) { func (signal ReadSignal) MarshalBinary() ([]byte, error) {
arena := capnp.SingleSegment(nil) arena := capnp.SingleSegment(nil)
msg, seg, err := capnp.NewMessage(arena) msg, seg, err := capnp.NewMessage(arena)
if err != nil { if err != nil {
@ -407,11 +407,11 @@ func (signal *ReadSignal) MarshalBinary() ([]byte, error) {
return msg.Marshal() return msg.Marshal()
} }
func (signal *ReadSignal) Header() *SignalHeader { func (signal ReadSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *ReadSignal) Permission() Tree { func (signal ReadSignal) Permission() Tree {
ret := Tree{} ret := Tree{}
for ext, fields := range(signal.Extensions) { for ext, fields := range(signal.Extensions) {
field_tree := Tree{} field_tree := Tree{}
@ -435,10 +435,10 @@ type ReadResultSignal struct {
NodeType NodeType NodeType NodeType
Extensions map[ExtType]map[string]SerializedValue Extensions map[ExtType]map[string]SerializedValue
} }
func (signal *ReadResultSignal) Header() *SignalHeader { func (signal ReadResultSignal) Header() *SignalHeader {
return &signal.SignalHeader return &signal.SignalHeader
} }
func (signal *ReadResultSignal) Permission() Tree { func (signal ReadResultSignal) Permission() Tree {
return Tree{ return Tree{
SerializedType(ReadResultSignalType): nil, SerializedType(ReadResultSignalType): nil,
} }