diff --git a/context.go b/context.go index 2a32958..3d263de 100644 --- a/context.go +++ b/context.go @@ -942,6 +942,100 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } */ + err = ctx.RegisterType(reflect.TypeOf(Tree{}), TreeType, func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + var data []byte + type_stack := []SerializedType{ctx_type} + if value == nil { + data = nil + } else if value.IsZero() { + data = []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} + } else if value.Len() == 0 { + data = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + } else { + data = make([]byte, 8) + map_size := 0 + + map_iter := value.MapRange() + for map_iter.Next() { + map_size += 1 + key_reflect := map_iter.Key() + elem_reflect := map_iter.Value() + + key_value, err := SerializeValue(ctx, key_reflect.Type(), &key_reflect) + if err != nil { + return SerializedValue{}, err + } + elem_value, err := SerializeValue(ctx, elem_reflect.Type(), &elem_reflect) + if err != nil { + return SerializedValue{}, err + } + + data = append(data, key_value.Data...) + data = append(data, elem_value.Data...) + } + + binary.BigEndian.PutUint64(data[0:8], uint64(map_size)) + } + return SerializedValue{ + type_stack, + data, + }, nil + },func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + if value.Data == nil { + return reflect.TypeOf(Tree{}), nil, value, nil + } else if len(value.Data) < 8 { + return nil, nil, value, fmt.Errorf("Not enough data to deserialize Tree") + } else { + var map_size_bytes []byte + var err error + map_size_bytes, value, err = value.PopData(8) + if err != nil { + return nil, nil, value, err + } + + map_size := binary.BigEndian.Uint64(map_size_bytes) + ctx.Log.Logf("serialize", "Deserializing %d elements in Tree", map_size) + + if map_size == 0xFFFFFFFFFFFFFFFF { + reflect_type := reflect.TypeOf(Tree{}) + reflect_value := reflect.New(reflect_type).Elem() + return reflect_type, &reflect_value, value, nil + } else if map_size == 0x00 { + reflect_type := reflect.TypeOf(Tree{}) + reflect_value := reflect.MakeMap(reflect_type) + return reflect_type, &reflect_value, value, nil + } else { + reflect_type := reflect.TypeOf(Tree{}) + reflect_value := reflect.MakeMap(reflect_type) + + tmp_value := value + + for i := 0; i < int(map_size); i += 1 { + tmp_value.TypeStack = append([]SerializedType{SerializedTypeSerialized, TreeType}, value.TypeStack...) + + var key_value, elem_value *reflect.Value + var err error + _, key_value, tmp_value, err = DeserializeValue(ctx, tmp_value) + if err != nil { + return nil, nil, value, err + } + _, elem_value, tmp_value, err = DeserializeValue(ctx, tmp_value) + if err != nil { + return nil, nil, value, err + } + reflect_value.SetMapIndex(*key_value, *elem_value) + } + + return reflect_type, &reflect_value, tmp_value, nil + } + } + }) + + err = ctx.RegisterType(reflect.TypeOf(SerializedType(0)), SerializedTypeSerialized, SerializeUintN(8), DeserializeUintN[SerializedType](8)) + if err != nil { + return nil, err + } + err = ctx.RegisterType(reflect.TypeOf(ExtType(0)), ExtTypeSerialized, SerializeUintN(8), DeserializeUintN[ExtType](8)) if err != nil { return nil, err @@ -1136,6 +1230,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterSignal(reflect.TypeOf(SuccessSignal{}), SuccessSignalType) + if err != nil { + return nil, err + } + err = ctx.RegisterSignal(reflect.TypeOf(ReadResultSignal{}), ReadResultSignalType) if err != nil { return nil, err diff --git a/gql_test.go b/gql_test.go index 04fb06e..1d7b10d 100644 --- a/gql_test.go +++ b/gql_test.go @@ -65,16 +65,11 @@ func TestGQLServer(t *testing.T) { fatalErr(t, err) listener_ext := NewListenerExt(10) - n1, err := NewNode(ctx, nil, TestNodeType, 10, map[PolicyType]Policy{ - MemberOfPolicyType: &user_policy_2, - AllNodesPolicyType: &user_policy_1, - }, NewLockableExt(nil)) + n1, err := NewNode(ctx, nil, TestNodeType, 10, []Policy{user_policy_2, user_policy_1}, NewLockableExt(nil)) fatalErr(t, err) - gql, err := NewNode(ctx, gql_key, GQLNodeType, 10, map[PolicyType]Policy{ - MemberOfPolicyType: &group_policy_2, - AllNodesPolicyType: &group_policy_1, - }, NewLockableExt([]NodeID{n1.ID}), gql_ext, NewGroupExt(map[NodeID]string{ + gql, err := NewNode(ctx, gql_key, GQLNodeType, 10, []Policy{group_policy_2, group_policy_1}, + NewLockableExt([]NodeID{n1.ID}), gql_ext, NewGroupExt(map[NodeID]string{ n1.ID: "user", gql_id: "self", }), listener_ext) diff --git a/lockable_test.go b/lockable_test.go index 172e0f1..1a3d520 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -28,10 +28,7 @@ func TestLink(t *testing.T) { }) l2_listener := NewListenerExt(10) - l2, err := NewNode(ctx, nil, TestLockableType, 10, - map[PolicyType]Policy{ - PerNodePolicyType: &policy, - }, + l2, err := NewNode(ctx, nil, TestLockableType, 10, []Policy{policy}, l2_listener, NewLockableExt(nil), ) @@ -77,10 +74,7 @@ func Test10KLink(t *testing.T) { }, }) NewLockable := func()(*Node) { - l, err := NewNode(ctx, nil, TestLockableType, 10, - map[PolicyType]Policy{ - PerNodePolicyType: &child_policy, - }, + l, err := NewNode(ctx, nil, TestLockableType, 10, []Policy{child_policy}, NewLockableExt(nil), ) fatalErr(t, err) @@ -98,10 +92,7 @@ func Test10KLink(t *testing.T) { SerializedType(LockSignalType): nil, }) listener := NewListenerExt(100000) - node, err := NewNode(ctx, listener_key, TestLockableType, 10000, - map[PolicyType]Policy{ - AllNodesPolicyType: &l_policy, - }, + node, err := NewNode(ctx, listener_key, TestLockableType, 10000, []Policy{l_policy}, listener, NewLockableExt(reqs), ) @@ -132,10 +123,7 @@ func TestLock(t *testing.T) { NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { listener := NewListenerExt(100) - l, err := NewNode(ctx, nil, TestLockableType, 10, - map[PolicyType]Policy{ - AllNodesPolicyType: &policy, - }, + l, err := NewNode(ctx, nil, TestLockableType, 10, []Policy{policy}, listener, NewLockableExt(reqs), ) diff --git a/node.go b/node.go index ebab3a6..befb070 100644 --- a/node.go +++ b/node.go @@ -14,16 +14,6 @@ import ( "crypto/rand" ) -const ( - // Magic first four bytes of serialized DB content, stored big endian - NODE_DB_MAGIC = 0x2491df14 - // Total length of the node database header, has magic to verify and type_hash to map to load function - NODE_DB_HEADER_LEN = 32 - EXTENSION_DB_HEADER_LEN = 16 - QSIGNAL_DB_HEADER_LEN = 24 - POLICY_DB_HEADER_LEN = 16 -) - var ( // Base NodeID, used as a special value ZeroUUID = uuid.UUID{} @@ -84,7 +74,7 @@ type PendingACL struct { } type PendingSignal struct { - Policy PolicyType + Policy uuid.UUID Found bool ID uuid.UUID } @@ -96,7 +86,7 @@ type Node struct { ID NodeID Type NodeType `gv:"type"` Extensions map[ExtType]Extension `gv:"extensions"` - Policies map[PolicyType]Policy `gv:"policies"` + Policies []Policy `gv:"policies"` PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"` PendingSignals map[uuid.UUID]PendingSignal `gv:"pending_signal"` @@ -134,14 +124,14 @@ const ( Pending ) -func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[PolicyType]Messages, RuleResult) { - pends := map[PolicyType]Messages{} - for policy_type, policy := range(node.Policies) { +func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[uuid.UUID]Messages, RuleResult) { + pends := map[uuid.UUID]Messages{} + for _, policy := range(node.Policies) { msgs, resp := policy.Allows(ctx, principal_id, action, node) if resp == Allow { return nil, Allow } else if resp == Pending { - pends[policy_type] = msgs + pends[policy.ID()] = msgs } } if len(pends) != 0 { @@ -370,33 +360,45 @@ func nodeLoop(ctx *Context, node *Node) error { req_info.Counter -= 1 req_info.Responses = append(req_info.Responses, signal) - allowed := node.Policies[info.Policy].ContinueAllows(ctx, req_info, signal) - if allowed == Allow { - ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) - signal = req_info.Signal - source = req_info.Source - err := node.DequeueSignal(req_info.TimeoutID) - if err != nil { - panic("dequeued a passed signal") - } - delete(node.PendingACLs, info.ID) - } else if req_info.Counter == 0 { - ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) - // Send the denied response - msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.Header().ID, "ACL_DENIED"), req_info.Source) - err := ctx.Send(msgs) - if err != nil { - ctx.Log.Logf("signal", "SEND_ERR: %s", err) - } - err = node.DequeueSignal(req_info.TimeoutID) - if err != nil { - panic("dequeued a passed signal") + idx := -1 + for i, p := range(node.Policies) { + if p.ID() == info.Policy { + idx = i + break } + } + if idx == -1 { + ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy) delete(node.PendingACLs, info.ID) } else { - node.PendingACLs[info.ID] = req_info - continue + allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal) + if allowed == Allow { + ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) + signal = req_info.Signal + source = req_info.Source + err := node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") + } + delete(node.PendingACLs, info.ID) + } else if req_info.Counter == 0 { + ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) + // Send the denied response + msgs := Messages{} + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.Header().ID, "ACL_DENIED"), req_info.Source) + err := ctx.Send(msgs) + if err != nil { + ctx.Log.Logf("signal", "SEND_ERR: %s", err) + } + err = node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") + } + delete(node.PendingACLs, info.ID) + } else { + node.PendingACLs[info.ID] = req_info + continue + } } } } @@ -544,7 +546,7 @@ func KeyID(pub ed25519.PublicKey) NodeID { } // Create a new node in memory and start it's event loop -func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies map[PolicyType]Policy, extensions ...Extension) (*Node, error) { +func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies []Policy, extensions ...Extension) (*Node, error) { var err error var public ed25519.PublicKey if key == nil { @@ -586,22 +588,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si } } - if policies == nil { - policies = map[PolicyType]Policy{} - } - - default_policy := NewAllNodesPolicy(Tree{ - SerializedType(ErrorSignalType): nil, - SerializedType(ReadResultSignalType): nil, - SerializedType(StatusSignalType): nil, - }) - - all_nodes_policy, exists := policies[AllNodesPolicyType] - if exists == true { - policies[AllNodesPolicyType] = all_nodes_policy.Merge(&default_policy) - } else { - policies[AllNodesPolicyType] = &default_policy - } + policies = append(policies, DefaultPolicy) node := &Node{ Key: key, diff --git a/node_test.go b/node_test.go index 44c67df..5c0e0e5 100644 --- a/node_test.go +++ b/node_test.go @@ -63,9 +63,7 @@ func TestNodeRead(t *testing.T) { n2, err := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), n2_listener) fatalErr(t, err) - n1, err := NewNode(ctx, n1_key, node_type, 10, map[PolicyType]Policy{ - PerNodePolicyType: &n1_policy, - }, NewGroupExt(nil)) + n1, err := NewNode(ctx, n1_key, node_type, 10, []Policy{n1_policy}, NewGroupExt(nil)) fatalErr(t, err) read_sig := NewReadSignal(map[ExtType][]string{ diff --git a/policy.go b/policy.go index d212a66..0c7e792 100644 --- a/policy.go +++ b/policy.go @@ -1,15 +1,21 @@ package graphvent import ( + "github.com/google/uuid" ) type Policy interface { Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult - // Merge with another policy of the same underlying type - Merge(Policy) Policy - // Make a copy of this policy - Copy() Policy + ID() uuid.UUID +} + +type PolicyHeader struct { + UUID uuid.UUID `gv:"uuid"` +} + +func (header PolicyHeader) ID() uuid.UUID { + return header.UUID } func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { @@ -38,10 +44,6 @@ type RequirementOfPolicy struct { PerNodePolicy } -func (policy RequirementOfPolicy) Type() PolicyType { - return RequirementOfPolicyType -} - func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy { return RequirementOfPolicy { PerNodePolicy: NewPerNodePolicy(dep_rules), @@ -87,10 +89,6 @@ type MemberOfPolicy struct { PerNodePolicy } -func (policy MemberOfPolicy) Type() PolicyType { - return MemberOfPolicyType -} - func NewMemberOfPolicy(group_rules map[NodeID]Tree) MemberOfPolicy { return MemberOfPolicy{ PerNodePolicy: NewPerNodePolicy(group_rules), @@ -156,19 +154,6 @@ func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tr return msgs, Pending } -func (policy MemberOfPolicy) Merge(p Policy) Policy { - other := p.(*MemberOfPolicy) - policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules) - return policy -} - -func (policy MemberOfPolicy) Copy() Policy { - new_rules := CopyNodeRules(policy.NodeRules) - return &MemberOfPolicy{ - PerNodePolicy: NewPerNodePolicy(new_rules), - } -} - func CopyTree(tree Tree) Tree { if tree == nil { return nil @@ -199,56 +184,6 @@ func MergeTrees(first Tree, second Tree) Tree { return ret } -func CopyNodeRules(rules map[NodeID]Tree) map[NodeID]Tree { - ret := map[NodeID]Tree{} - for id, r := range(rules) { - ret[id] = r - } - return ret -} - -func MergeNodeRules(first map[NodeID]Tree, second map[NodeID]Tree) map[NodeID]Tree { - merged := map[NodeID]Tree{} - for id, actions := range(first) { - merged[id] = actions - } - for id, actions := range(second) { - existing, exists := merged[id] - if exists { - merged[id] = MergeTrees(existing, actions) - } else { - merged[id] = actions - } - } - return merged -} - -func (policy PerNodePolicy) Merge(p Policy) Policy { - other := p.(*PerNodePolicy) - policy.NodeRules = MergeNodeRules(policy.NodeRules, other.NodeRules) - return policy -} - -func (policy PerNodePolicy) Copy() Policy { - new_rules := CopyNodeRules(policy.NodeRules) - return &PerNodePolicy{ - NodeRules: new_rules, - } -} - -func (policy AllNodesPolicy) Merge(p Policy) Policy { - other := p.(*AllNodesPolicy) - policy.Rules = MergeTrees(policy.Rules, other.Rules) - return policy -} - -func (policy AllNodesPolicy) Copy() Policy { - new_rules := policy.Rules - return &AllNodesPolicy { - Rules: new_rules, - } -} - type Tree map[SerializedType]Tree func (rule Tree) Allows(action Tree) RuleResult { @@ -278,39 +213,40 @@ func (rule Tree) Allows(action Tree) RuleResult { } } +func NewPolicyHeader() PolicyHeader { + return PolicyHeader{ + UUID: uuid.New(), + } +} + func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy { if node_actions == nil { node_actions = map[NodeID]Tree{} } return PerNodePolicy{ + PolicyHeader: NewPolicyHeader(), NodeRules: node_actions, } } type PerNodePolicy struct { - NodeRules map[NodeID]Tree `json:"node_actions"` -} - -func (policy PerNodePolicy) Type() PolicyType { - return PerNodePolicyType + PolicyHeader + NodeRules map[NodeID]Tree `gv:"node_rules"` } func NewAllNodesPolicy(rules Tree) AllNodesPolicy { return AllNodesPolicy{ + PolicyHeader: NewPolicyHeader(), Rules: rules, } } type AllNodesPolicy struct { - Rules Tree -} - -func (policy AllNodesPolicy) Type() PolicyType { - return AllNodesPolicyType + PolicyHeader + Rules Tree `gv:"rules"` } var DefaultPolicy = NewAllNodesPolicy(Tree{ - SerializedType(ErrorSignalType): nil, - SerializedType(ReadResultSignalType): nil, + ResultType: nil, }) diff --git a/serialize.go b/serialize.go index a32fdd1..e6b4ec3 100644 --- a/serialize.go +++ b/serialize.go @@ -9,13 +9,13 @@ import ( ) const ( - TagBase = "GraphventTag" - ExtTypeBase = "ExtType" - NodeTypeBase = "NodeType" - SignalTypeBase = "SignalType" - PolicyTypeBase = "PolicyType" + TagBase = "GraphventTag" + ExtTypeBase = "ExtType" + NodeTypeBase = "NodeType" + SignalTypeBase = "SignalType" + PolicyTypeBase = "PolicyType" SerializedTypeBase = "SerializedType" - FieldNameBase = "FieldName" + FieldNameBase = "FieldName" ) func Hash(base string, name string) SerializedType { @@ -26,32 +26,37 @@ func Hash(base string, name string) SerializedType { } type SerializedType uint64 + func (t SerializedType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } type ExtType SerializedType + func (t ExtType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } type NodeType SerializedType + func (t NodeType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } type SignalType SerializedType + func (t SignalType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } type PolicyType SerializedType + func (t PolicyType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } -type TypeSerialize func(*Context,SerializedType,reflect.Type,*reflect.Value) (SerializedValue, error) -type TypeDeserialize func(*Context,SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) +type TypeSerialize func(*Context, SerializedType, reflect.Type, *reflect.Value) (SerializedValue, error) +type TypeDeserialize func(*Context, SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) func NewExtType(name string) ExtType { return ExtType(Hash(ExtTypeBase, name)) @@ -84,59 +89,63 @@ var ( StopSignalType = NewSignalType("STOP") CreateSignalType = NewSignalType("CREATE") StartSignalType = NewSignalType("START") - ErrorSignalType = NewSignalType("ERROR") StatusSignalType = NewSignalType("STATUS") LinkSignalType = NewSignalType("LINK") LockSignalType = NewSignalType("LOCK") ReadSignalType = NewSignalType("READ") - ReadResultSignalType = NewSignalType("READ_RESULT") ACLTimeoutSignalType = NewSignalType("ACL_TIMEOUT") + ErrorSignalType = NewSignalType("ERROR") + SuccessSignalType = NewSignalType("SUCCESS") + ReadResultSignalType = NewSignalType("READ_RESULT") MemberOfPolicyType = NewPolicyType("USER_OF") RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") PerNodePolicyType = NewPolicyType("PER_NODE") AllNodesPolicyType = NewPolicyType("ALL_NODES") - ErrorType = NewSerializedType("ERROR") - PointerType = NewSerializedType("POINTER") - SliceType = NewSerializedType("SLICE") - StructType = NewSerializedType("STRUCT") - IntType = NewSerializedType("INT") - UIntType = NewSerializedType("UINT") - BoolType = NewSerializedType("BOOL") - Float64Type = NewSerializedType("FLOAT64") - Float32Type = NewSerializedType("FLOAT32") - UInt8Type = NewSerializedType("UINT8") - UInt16Type = NewSerializedType("UINT16") - UInt32Type = NewSerializedType("UINT32") - UInt64Type = NewSerializedType("UINT64") - Int8Type = NewSerializedType("INT8") - Int16Type = NewSerializedType("INT16") - Int32Type = NewSerializedType("INT32") - Int64Type = NewSerializedType("INT64") - StringType = NewSerializedType("STRING") - ArrayType = NewSerializedType("ARRAY") + ErrorType = NewSerializedType("ERROR") + PointerType = NewSerializedType("POINTER") + SliceType = NewSerializedType("SLICE") + StructType = NewSerializedType("STRUCT") + IntType = NewSerializedType("INT") + UIntType = NewSerializedType("UINT") + BoolType = NewSerializedType("BOOL") + Float64Type = NewSerializedType("FLOAT64") + Float32Type = NewSerializedType("FLOAT32") + UInt8Type = NewSerializedType("UINT8") + UInt16Type = NewSerializedType("UINT16") + UInt32Type = NewSerializedType("UINT32") + UInt64Type = NewSerializedType("UINT64") + Int8Type = NewSerializedType("INT8") + Int16Type = NewSerializedType("INT16") + Int32Type = NewSerializedType("INT32") + Int64Type = NewSerializedType("INT64") + StringType = NewSerializedType("STRING") + ArrayType = NewSerializedType("ARRAY") InterfaceType = NewSerializedType("INTERFACE") - MapType = NewSerializedType("MAP") - - ReqStateType = NewSerializedType("REQ_STATE") - SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") - NodeStructType = NewSerializedType("NODE_STRUCT") - QueuedSignalType = NewSerializedType("QUEUED_SIGNAL") - NodeTypeSerialized = NewSerializedType("NODE_TYPE") - ExtTypeSerialized = NewSerializedType("EXT_TYPE") + MapType = NewSerializedType("MAP") + + ReqStateType = NewSerializedType("REQ_STATE") + SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") + NodeStructType = NewSerializedType("NODE_STRUCT") + QueuedSignalType = NewSerializedType("QUEUED_SIGNAL") + NodeTypeSerialized = NewSerializedType("NODE_TYPE") + ExtTypeSerialized = NewSerializedType("EXT_TYPE") PolicyTypeSerialized = NewSerializedType("POLICY_TYPE") - ExtSerialized = NewSerializedType("EXTENSION") - PolicySerialized = NewSerializedType("POLICY") - SignalSerialized = NewSerializedType("SIGNAL") - NodeIDType = NewSerializedType("NODE_ID") - UUIDType = NewSerializedType("UUID") - PendingACLType = NewSerializedType("PENDING_ACL") - PendingSignalType = NewSerializedType("PENDING_SIGNAL") - TimeType = NewSerializedType("TIME") + ExtSerialized = NewSerializedType("EXTENSION") + PolicySerialized = NewSerializedType("POLICY") + SignalSerialized = NewSerializedType("SIGNAL") + NodeIDType = NewSerializedType("NODE_ID") + UUIDType = NewSerializedType("UUID") + PendingACLType = NewSerializedType("PENDING_ACL") + PendingSignalType = NewSerializedType("PENDING_SIGNAL") + TimeType = NewSerializedType("TIME") + ResultType = NewSerializedType("RESULT") + TreeType = NewSerializedType("TREE") + SerializedTypeSerialized = NewSerializedType("SERIALIZED_TYPE") ) -func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ +func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { type_stack := []SerializedType{ctx_type} if value == nil { return SerializedValue{ @@ -164,7 +173,7 @@ func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect. } } -func DeserializeArray[T any](ctx *Context)(func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ +func DeserializeArray[T any](ctx *Context) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { var zero T array_type := reflect.TypeOf(zero) array_size := array_type.Len() @@ -173,7 +182,7 @@ func DeserializeArray[T any](ctx *Context)(func(ctx *Context, value SerializedVa panic(err) } saved_type_stack := zero_value.TypeStack - return func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { if value.Data == nil { return array_type, nil, value, nil } else { @@ -198,7 +207,7 @@ func DeserializeArray[T any](ctx *Context)(func(ctx *Context, value SerializedVa } } -func SerializeUintN(size int)(func(*Context,SerializedType,reflect.Type,*reflect.Value)(SerializedValue,error)){ +func SerializeUintN(size int) func(*Context, SerializedType, reflect.Type, *reflect.Value) (SerializedValue, error) { var fill_data func([]byte, uint64) = nil switch size { case 1: @@ -220,7 +229,7 @@ func SerializeUintN(size int)(func(*Context,SerializedType,reflect.Type,*reflect default: panic(fmt.Sprintf("Cannot serialize uint of size %d", size)) } - return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { var data []byte = nil if value != nil { data = make([]byte, size) @@ -233,7 +242,9 @@ func SerializeUintN(size int)(func(*Context,SerializedType,reflect.Type,*reflect } } -func DeserializeUintN[T interface{~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64}](size int)(func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ +func DeserializeUintN[T interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 +}](size int) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { var get_uint func([]byte) uint64 switch size { case 1: @@ -257,7 +268,7 @@ func DeserializeUintN[T interface{~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64}] } var zero T uint_type := reflect.TypeOf(zero) - return func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { if value.Data == nil { return uint_type, nil, value, nil } else { @@ -274,7 +285,7 @@ func DeserializeUintN[T interface{~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64}] } } -func SerializeIntN(size int)(func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error)){ +func SerializeIntN(size int) func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { var fill_data func([]byte, int64) = nil switch size { case 1: @@ -296,7 +307,7 @@ func SerializeIntN(size int)(func(ctx *Context, ctx_type SerializedType, reflect default: panic(fmt.Sprintf("Cannot serialize int of size %d", size)) } - return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { var data []byte = nil if value != nil { data = make([]byte, size) @@ -309,7 +320,9 @@ func SerializeIntN(size int)(func(ctx *Context, ctx_type SerializedType, reflect } } -func DeserializeIntN[T interface{~int | ~int8 | ~int16 | ~int32 | ~int64}](size int)(func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ +func DeserializeIntN[T interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +}](size int) func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { var get_int func([]byte) int64 switch size { case 1: @@ -333,7 +346,7 @@ func DeserializeIntN[T interface{~int | ~int8 | ~int16 | ~int32 | ~int64}](size } var zero T int_type := reflect.TypeOf(zero) - return func(ctx *Context, value SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error){ + return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { if value.Data == nil { return int_type, nil, value, nil } else { @@ -351,15 +364,15 @@ func DeserializeIntN[T interface{~int | ~int8 | ~int16 | ~int32 | ~int64}](size } type FieldInfo struct { - Index []int + Index []int TypeStack []SerializedType } type StructInfo struct { - Type reflect.Type - FieldOrder []SerializedType - FieldMap map[SerializedType]FieldInfo - PostDeserialize bool + Type reflect.Type + FieldOrder []SerializedType + FieldMap map[SerializedType]FieldInfo + PostDeserialize bool PostDeserializeIdx int } @@ -370,10 +383,10 @@ type Deserializable interface { var deserializable_zero Deserializable = nil var DeserializableType = reflect.TypeOf(&deserializable_zero).Elem() -func structInfo(ctx *Context, struct_type reflect.Type)StructInfo{ +func structInfo(ctx *Context, struct_type reflect.Type) StructInfo { field_order := []SerializedType{} field_map := map[SerializedType]FieldInfo{} - for _, field := range(reflect.VisibleFields(struct_type)) { + for _, field := range reflect.VisibleFields(struct_type) { gv_tag, tagged_gv := field.Tag.Lookup("gv") if tagged_gv == false { continue @@ -396,7 +409,7 @@ func structInfo(ctx *Context, struct_type reflect.Type)StructInfo{ } } - sort.Slice(field_order, func(i, j int)bool { + sort.Slice(field_order, func(i, j int) bool { return uint64(field_order[i]) < uint64(field_order[j]) }) @@ -423,16 +436,16 @@ func structInfo(ctx *Context, struct_type reflect.Type)StructInfo{ } } -func SerializeStruct(ctx *Context, struct_type reflect.Type)(func(*Context,SerializedType,reflect.Type,*reflect.Value)(SerializedValue,error)){ +func SerializeStruct(ctx *Context, struct_type reflect.Type) func(*Context, SerializedType, reflect.Type, *reflect.Value) (SerializedValue, error) { struct_info := structInfo(ctx, struct_type) - return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + return func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { type_stack := []SerializedType{ctx_type} var data []byte if value == nil { data = nil } else { data = make([]byte, 8) - for _, field_hash := range(struct_info.FieldOrder) { + for _, field_hash := range struct_info.FieldOrder { field_hash_bytes := make([]byte, 8) binary.BigEndian.PutUint64(field_hash_bytes, uint64(field_hash)) field_info := struct_info.FieldMap[field_hash] @@ -453,9 +466,9 @@ func SerializeStruct(ctx *Context, struct_type reflect.Type)(func(*Context,Seria } } -func DeserializeStruct(ctx *Context, struct_type reflect.Type)(func(*Context,SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ +func DeserializeStruct(ctx *Context, struct_type reflect.Type) func(*Context, SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { struct_info := structInfo(ctx, struct_type) - return func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error) { + return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { if value.Data == nil { return struct_info.Type, nil, value, nil } else { @@ -511,7 +524,7 @@ func DeserializeStruct(ctx *Context, struct_type reflect.Type)(func(*Context,Ser } } -func SerializeInterface(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ +func SerializeInterface(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value) (SerializedValue, error) { var data []byte type_stack := []SerializedType{ctx_type} if value == nil { @@ -537,8 +550,8 @@ func SerializeInterface(ctx *Context, ctx_type SerializedType, reflect_type refl }, nil } -func DeserializeInterface[T any]()(func(*Context,SerializedValue)(reflect.Type,*reflect.Value,SerializedValue,error)){ - return func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ +func DeserializeInterface[T any]() func(*Context, SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { + return func(ctx *Context, value SerializedValue) (reflect.Type, *reflect.Value, SerializedValue, error) { var interface_zero T var interface_type = reflect.ValueOf(&interface_zero).Type().Elem() if value.Data == nil { @@ -573,7 +586,7 @@ func DeserializeInterface[T any]()(func(*Context,SerializedValue)(reflect.Type,* type SerializedValue struct { TypeStack []SerializedType - Data []byte + Data []byte } func (value SerializedValue) PopType() (SerializedType, SerializedValue, error) { @@ -627,7 +640,7 @@ func SerializeValue(ctx *Context, t reflect.Type, value *reflect.Value) (Seriali serialize = kind_info.Serialize } - serialized_value, err := serialize(ctx, ctx_type, t, value) + serialized_value, err := serialize(ctx, ctx_type, t, value) if err != nil { return serialized_value, err } @@ -641,7 +654,7 @@ func ExtField(ctx *Context, ext Extension, field_name string) (reflect.Value, er } ext_value := reflect.ValueOf(ext).Elem() - for _, field := range(reflect.VisibleFields(ext_value.Type())) { + for _, field := range reflect.VisibleFields(ext_value.Type()) { gv_tag, tagged := field.Tag.Lookup("gv") if tagged == true && gv_tag == field_name { return ext_value.FieldByIndex(field.Index), nil @@ -665,9 +678,9 @@ func (value SerializedValue) MarshalBinary() ([]byte, error) { binary.BigEndian.PutUint64(data[0:8], uint64(len(value.TypeStack))) binary.BigEndian.PutUint64(data[8:16], uint64(len(value.Data))) - for i, t := range(value.TypeStack) { - type_start := (i+2)*8 - type_end := (i+3)*8 + for i, t := range value.TypeStack { + type_start := (i + 2) * 8 + type_end := (i + 3) * 8 binary.BigEndian.PutUint64(data[type_start:type_end], uint64(t)) } @@ -686,12 +699,12 @@ func ParseSerializedValue(data []byte) (SerializedValue, []byte, error) { data_size := int(binary.BigEndian.Uint64(data[8:16])) type_stack := make([]SerializedType, num_types) for i := 0; i < num_types; i += 1 { - type_start := (i+2) * 8 - type_end := (i+3) * 8 + type_start := (i + 2) * 8 + type_end := (i + 3) * 8 type_stack[i] = SerializedType(binary.BigEndian.Uint64(data[type_start:type_end])) } - types_end := 8*(num_types + 2) + types_end := 8 * (num_types + 2) data_end := types_end + data_size return SerializedValue{ type_stack, diff --git a/serialize_test.go b/serialize_test.go index b75b8bf..04e943d 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -7,7 +7,7 @@ import ( ) func TestSerializeBasic(t *testing.T) { - ctx := logTestContext(t, []string{"test"}) + ctx := logTestContext(t, []string{"test", "serialize"}) testSerializeComparable[string](t, ctx, "test") testSerializeComparable[bool](t, ctx, true) testSerializeComparable[float32](t, ctx, 0.05) @@ -59,6 +59,18 @@ func TestSerializeBasic(t *testing.T) { 12345, "test_string", }) + + testSerialize(t, ctx, Tree{ + TreeType: nil, + }) + + testSerialize(t, ctx, Tree{ + TreeType: { + ErrorType: Tree{}, + MapType: nil, + }, + StringType: nil, + }) } type test struct { diff --git a/signal.go b/signal.go index 04b35ad..65a67aa 100644 --- a/signal.go +++ b/signal.go @@ -144,13 +144,31 @@ func NewStopSignal() *StopSignal { } } +type SuccessSignal struct { + SignalHeader +} +func (signal SuccessSignal) Permission() Tree { + return Tree{ + ResultType: { + SerializedType(SuccessSignalType): nil, + }, + } +} +func NewSuccessSignal(req_id uuid.UUID) Signal { + return &SuccessSignal{ + NewRespHeader(req_id, Direct), + } +} + type ErrorSignal struct { SignalHeader Error string } func (signal ErrorSignal) Permission() Tree { return Tree{ - SerializedType(ErrorSignalType): nil, + ResultType: { + SerializedType(ErrorSignalType): nil, + }, } } func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal { @@ -275,7 +293,9 @@ type ReadResultSignal struct { } func (signal ReadResultSignal) Permission() Tree { return Tree{ - SerializedType(ReadResultSignalType): nil, + ResultType: { + SerializedType(ReadResultSignalType): nil, + }, } } func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal {