Major cleanup
parent
e5776e0a14
commit
6942dc02db
@ -1,233 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ACLSignal struct {
|
||||
SignalHeader
|
||||
Principal NodeID `gv:"principal"`
|
||||
Action Tree `gv:"tree"`
|
||||
}
|
||||
|
||||
func NewACLSignal(principal NodeID, action Tree) *ACLSignal {
|
||||
return &ACLSignal{
|
||||
SignalHeader: NewSignalHeader(Direct),
|
||||
Principal: principal,
|
||||
Action: action,
|
||||
}
|
||||
}
|
||||
|
||||
var DefaultACLPolicy = NewAllNodesPolicy(Tree{
|
||||
SerializedType(SignalTypeFor[ACLSignal]()): nil,
|
||||
})
|
||||
|
||||
func (signal ACLSignal) Permission() Tree {
|
||||
return Tree{
|
||||
SerializedType(SignalTypeFor[ACLSignal]()): nil,
|
||||
}
|
||||
}
|
||||
|
||||
type ACLExt struct {
|
||||
Policies []Policy `gv:"policies"`
|
||||
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
|
||||
Pending map[uuid.UUID]PendingACLSignal `gv:"pending"`
|
||||
}
|
||||
|
||||
func NewACLExt(policies []Policy) *ACLExt {
|
||||
return &ACLExt{
|
||||
Policies: policies,
|
||||
PendingACLs: map[uuid.UUID]PendingACL{},
|
||||
Pending: map[uuid.UUID]PendingACLSignal{},
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *ACLExt) Load(ctx *Context, node *Node) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *ACLExt) Unload(ctx *Context, node *Node) {
|
||||
}
|
||||
|
||||
func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
|
||||
response, is_response := signal.(ResponseSignal)
|
||||
if is_response == true {
|
||||
var messages Messages = nil
|
||||
var changes = Changes{}
|
||||
info, waiting := ext.Pending[response.ResponseID()]
|
||||
if waiting == true {
|
||||
changes.Add("pending")
|
||||
delete(ext.Pending, response.ResponseID())
|
||||
if response.ID() != info.Timeout {
|
||||
err := node.DequeueSignal(info.Timeout)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("acl", "timeout dequeue error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
acl_info, found := ext.PendingACLs[info.ID]
|
||||
if found == true {
|
||||
acl_info.Counter -= 1
|
||||
acl_info.Responses = append(acl_info.Responses, response)
|
||||
|
||||
policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool {
|
||||
return policy.ID() == info.Policy
|
||||
})
|
||||
|
||||
if policy_index == -1 {
|
||||
ctx.Log.Logf("acl", "pending signal for nonexistent policy")
|
||||
delete(ext.PendingACLs, info.ID)
|
||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||
}
|
||||
} else {
|
||||
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
|
||||
changes.Add("pending_acls")
|
||||
delete(ext.PendingACLs, info.ID)
|
||||
ctx.Log.Logf("acl", "Request delayed allow")
|
||||
messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID))
|
||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||
}
|
||||
} else if acl_info.Counter == 0 {
|
||||
changes.Add("pending_acls")
|
||||
delete(ext.PendingACLs, info.ID)
|
||||
ctx.Log.Logf("acl", "Request delayed deny")
|
||||
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied"))
|
||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||
}
|
||||
} else {
|
||||
node.PendingACLs[info.ID] = acl_info
|
||||
changes.Add("pending_acls")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return messages, changes
|
||||
}
|
||||
|
||||
var messages Messages = nil
|
||||
var changes = Changes{}
|
||||
|
||||
switch sig := signal.(type) {
|
||||
case *ACLSignal:
|
||||
var acl_messages map[uuid.UUID]Messages = nil
|
||||
denied := true
|
||||
for _, policy := range(ext.Policies) {
|
||||
policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node)
|
||||
if result == Allow {
|
||||
denied = false
|
||||
break
|
||||
} else if result == Pending {
|
||||
if len(policy_messages) == 0 {
|
||||
ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID())
|
||||
continue
|
||||
} else if acl_messages == nil {
|
||||
acl_messages = map[uuid.UUID]Messages{}
|
||||
denied = false
|
||||
}
|
||||
|
||||
acl_messages[policy.ID()] = policy_messages
|
||||
ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages)
|
||||
}
|
||||
}
|
||||
|
||||
if denied == true {
|
||||
ctx.Log.Logf("acl", "Request denied")
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied"))
|
||||
} else if acl_messages != nil {
|
||||
ctx.Log.Logf("acl", "Request pending")
|
||||
changes.Add("pending")
|
||||
total_messages := 0
|
||||
// TODO: reasonable timeout/configurable
|
||||
timeout_time := time.Now().Add(time.Second)
|
||||
for policy_id, policy_messages := range(acl_messages) {
|
||||
total_messages += len(policy_messages)
|
||||
for _, message := range(policy_messages) {
|
||||
timeout_signal := NewTimeoutSignal(message.Signal.ID())
|
||||
ext.Pending[message.Signal.ID()] = PendingACLSignal{
|
||||
Policy: policy_id,
|
||||
Timeout: timeout_signal.Id,
|
||||
ID: sig.Id,
|
||||
}
|
||||
node.QueueSignal(timeout_time, timeout_signal)
|
||||
messages = append(messages, message)
|
||||
}
|
||||
}
|
||||
|
||||
acl_timeout := NewACLTimeoutSignal(sig.Id)
|
||||
node.QueueSignal(timeout_time, acl_timeout)
|
||||
ext.PendingACLs[sig.Id] = PendingACL{
|
||||
Counter: total_messages,
|
||||
Responses: []ResponseSignal{},
|
||||
TimeoutID: acl_timeout.Id,
|
||||
Action: sig.Action,
|
||||
Principal: sig.Principal,
|
||||
|
||||
Source: source,
|
||||
Signal: signal,
|
||||
}
|
||||
} else {
|
||||
ctx.Log.Logf("acl", "Request allowed")
|
||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||
}
|
||||
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
|
||||
case *ACLTimeoutSignal:
|
||||
acl_info, exists := ext.PendingACLs[sig.ReqID]
|
||||
if exists == true {
|
||||
delete(ext.PendingACLs, sig.ReqID)
|
||||
changes.Add("pending_acls")
|
||||
ctx.Log.Logf("acl", "Request timeout deny")
|
||||
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout"))
|
||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||
}
|
||||
} else {
|
||||
ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl")
|
||||
}
|
||||
// Delete from PendingACLs
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
}
|
||||
|
||||
type ACLProxyPolicy struct {
|
||||
PolicyHeader
|
||||
Proxies []NodeID `gv:"proxies"`
|
||||
}
|
||||
|
||||
func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy {
|
||||
return ACLProxyPolicy{
|
||||
NewPolicyHeader(),
|
||||
proxies,
|
||||
}
|
||||
}
|
||||
|
||||
func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
|
||||
if len(policy.Proxies) == 0 {
|
||||
return nil, Deny
|
||||
}
|
||||
|
||||
messages := Messages{}
|
||||
for _, proxy := range(policy.Proxies) {
|
||||
messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action))
|
||||
}
|
||||
|
||||
return messages, Pending
|
||||
}
|
||||
|
||||
func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||
_, is_success := signal.(*SuccessSignal)
|
||||
if is_success == true {
|
||||
return Allow
|
||||
}
|
||||
return Deny
|
||||
}
|
||||
|
@ -1,141 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){
|
||||
response_casted, cast_ok := signal.(S)
|
||||
if cast_ok == false {
|
||||
error_signal, is_error := signal.(*ErrorSignal)
|
||||
if is_error {
|
||||
t.Log(string(debug.Stack()))
|
||||
t.Fatal(error_signal.Error)
|
||||
}
|
||||
t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal))
|
||||
}
|
||||
|
||||
check(response_casted)
|
||||
}
|
||||
|
||||
func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){
|
||||
acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies))
|
||||
fatalErr(t, err)
|
||||
|
||||
acl_signal := NewACLSignal(listener.ID, action)
|
||||
response, _ := testSend(t, ctx, acl_signal, listener, acl_node)
|
||||
|
||||
checkSignal(t, response, check)
|
||||
}
|
||||
|
||||
func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){
|
||||
return func(response *ErrorSignal) {
|
||||
if response.Error != error_string {
|
||||
t.Fatalf("Wrong error: %s", response.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testSuccess(*SuccessSignal){}
|
||||
|
||||
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
|
||||
source_listener, err := GetExt[ListenerExt](source)
|
||||
fatalErr(t, err)
|
||||
|
||||
messages := Messages{}
|
||||
messages = messages.Add(ctx, destination.ID, source, nil, signal)
|
||||
fatalErr(t, ctx.Send(messages))
|
||||
|
||||
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
|
||||
fatalErr(t, err)
|
||||
|
||||
return response, signals
|
||||
}
|
||||
|
||||
func TestACLBasic(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"})
|
||||
|
||||
listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100))
|
||||
fatalErr(t, err)
|
||||
|
||||
ctx.Log.Logf("test", "testing fail")
|
||||
testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied"))
|
||||
|
||||
ctx.Log.Logf("test", "testing allow all")
|
||||
testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess)
|
||||
|
||||
group, err := NewNode(ctx, nil, "Base", 100, []Policy{
|
||||
DefaultGroupPolicy,
|
||||
NewPerNodePolicy(map[NodeID]Tree{
|
||||
listener.ID: {
|
||||
SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil,
|
||||
SerializedType(SignalTypeFor[AddMemberSignal]()): nil,
|
||||
},
|
||||
}),
|
||||
}, NewGroupExt(nil))
|
||||
fatalErr(t, err)
|
||||
|
||||
ctx.Log.Logf("test", "testing empty groups")
|
||||
testSendACL(t, ctx, listener, nil, []Policy{
|
||||
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
||||
group.ID: {
|
||||
"test_group": nil,
|
||||
},
|
||||
}),
|
||||
}, testErrorSignal(t, "acl_denied"))
|
||||
|
||||
ctx.Log.Logf("test", "testing adding group")
|
||||
add_subgroup_signal := NewAddSubGroupSignal("test_group")
|
||||
add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group)
|
||||
checkSignal(t, add_subgroup_response, testSuccess)
|
||||
|
||||
ctx.Log.Logf("test", "testing adding member")
|
||||
add_member_signal := NewAddMemberSignal("test_group", listener.ID)
|
||||
add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group)
|
||||
checkSignal(t, add_member_response, testSuccess)
|
||||
|
||||
ctx.Log.Logf("test", "testing group membership")
|
||||
testSendACL(t, ctx, listener, nil, []Policy{
|
||||
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
||||
group.ID: {
|
||||
"test_group": nil,
|
||||
},
|
||||
}),
|
||||
}, testSuccess)
|
||||
|
||||
testSendACL(t, ctx, listener, nil, []Policy{
|
||||
NewACLProxyPolicy(nil),
|
||||
}, testErrorSignal(t, "acl_denied"))
|
||||
|
||||
acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil))
|
||||
fatalErr(t, err)
|
||||
|
||||
testSendACL(t, ctx, listener, nil, []Policy{
|
||||
NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}),
|
||||
}, testErrorSignal(t, "acl_denied"))
|
||||
|
||||
acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)}))
|
||||
fatalErr(t, err)
|
||||
|
||||
testSendACL(t, ctx, listener, nil, []Policy{
|
||||
NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}),
|
||||
}, testSuccess)
|
||||
|
||||
acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy},
|
||||
NewACLExt([]Policy{
|
||||
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
||||
group.ID: {
|
||||
"test_group": nil,
|
||||
},
|
||||
}),
|
||||
}),
|
||||
)
|
||||
fatalErr(t, err)
|
||||
|
||||
testSendACL(t, ctx, listener, nil, []Policy{
|
||||
NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}),
|
||||
}, testSuccess)
|
||||
}
|
@ -1,145 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
graphql "github.com/graphql-go/graphql"
|
||||
"github.com/google/uuid"
|
||||
"reflect"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type StructFieldInfo struct {
|
||||
Name string
|
||||
Type *TypeInfo
|
||||
Index []int
|
||||
}
|
||||
|
||||
func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) {
|
||||
type_info, mapped := ctx.TypeReflects[field.Type]
|
||||
if mapped == false {
|
||||
return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type)
|
||||
}
|
||||
|
||||
return StructFieldInfo{
|
||||
Name: gv_tag,
|
||||
Type: type_info,
|
||||
Index: field.Index,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) {
|
||||
fmt.Printf("FIELD: %+v\n", fields)
|
||||
signal_value := reflect.New(signal_type)
|
||||
|
||||
id_field := signal_value.Elem().FieldByIndex(id_index)
|
||||
id_field.Set(reflect.ValueOf(uuid.New()))
|
||||
|
||||
direction_field := signal_value.Elem().FieldByIndex(direction_index)
|
||||
direction_field.Set(reflect.ValueOf(Direct))
|
||||
|
||||
for _, field := range(fields) {
|
||||
arg, arg_exists := args[field.Name]
|
||||
if arg_exists == false {
|
||||
return nil, fmt.Errorf("No arg provided named %s", field.Name)
|
||||
}
|
||||
field_value := signal_value.Elem().FieldByIndex(field.Index)
|
||||
if field_value.CanConvert(field.Type.Reflect) == false {
|
||||
return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect)
|
||||
}
|
||||
value, err := field.Type.GQLValue(ctx, arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("Setting %s to %+v of type %+v\n", field.Name, value, value.Type())
|
||||
field_value.Set(value)
|
||||
}
|
||||
return signal_value.Interface().(Signal), nil
|
||||
}
|
||||
|
||||
func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) {
|
||||
args := graphql.FieldConfigArgument{}
|
||||
arg_info := []StructFieldInfo{}
|
||||
var id_index []int = nil
|
||||
var direction_index []int = nil
|
||||
|
||||
for _, field := range(reflect.VisibleFields(signal_type)) {
|
||||
gv_tag, tagged_gv := field.Tag.Lookup("gv")
|
||||
if tagged_gv {
|
||||
if gv_tag == "id" {
|
||||
id_index = field.Index
|
||||
} else if gv_tag == "direction" {
|
||||
direction_index = field.Index
|
||||
} else {
|
||||
_, exists := args[gv_tag]
|
||||
if exists == true {
|
||||
return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag)
|
||||
} else {
|
||||
info, err := ArgumentInfo(ctx, field, gv_tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args[gv_tag] = &graphql.ArgumentConfig{
|
||||
}
|
||||
arg_info = append(arg_info, info)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, send_exists := args[send_id_key]
|
||||
if send_exists == false {
|
||||
args[send_id_key] = &graphql.ArgumentConfig{
|
||||
Type: graphql.String,
|
||||
}
|
||||
}
|
||||
|
||||
resolve_signal := func(p graphql.ResolveParams) (interface{}, error) {
|
||||
ctx, err := PrepResolve(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
send_id, err := ExtractID(p, send_id_key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs := Messages{}
|
||||
msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal)
|
||||
|
||||
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
|
||||
err = ctx.Context.Send(msgs)
|
||||
if err != nil {
|
||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
|
||||
if err != nil {
|
||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, is_success := response.(*SuccessSignal)
|
||||
if is_success == true {
|
||||
return "success", nil
|
||||
}
|
||||
|
||||
error_response, is_error := response.(*ErrorSignal)
|
||||
if is_error == true {
|
||||
return "error", fmt.Errorf(error_response.Error)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response))
|
||||
}
|
||||
|
||||
return &graphql.Field{
|
||||
Type: graphql.String,
|
||||
Args: args,
|
||||
Resolve: resolve_signal,
|
||||
}, nil
|
||||
}
|
@ -1,296 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"slices"
|
||||
)
|
||||
|
||||
type AddSubGroupSignal struct {
|
||||
SignalHeader
|
||||
Name string `gv:"name"`
|
||||
}
|
||||
|
||||
func NewAddSubGroupSignal(name string) *AddSubGroupSignal {
|
||||
return &AddSubGroupSignal{
|
||||
NewSignalHeader(Direct),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
func (signal AddSubGroupSignal) Permission() Tree {
|
||||
return Tree{
|
||||
SerializedType(SignalTypeFor[AddSubGroupSignal]()): {
|
||||
Hash("name", signal.Name): nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type RemoveSubGroupSignal struct {
|
||||
SignalHeader
|
||||
Name string `gv:"name"`
|
||||
}
|
||||
|
||||
func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal {
|
||||
return &RemoveSubGroupSignal{
|
||||
NewSignalHeader(Direct),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
func (signal RemoveSubGroupSignal) Permission() Tree {
|
||||
return Tree{
|
||||
SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): {
|
||||
Hash("command", signal.Name): nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type AddMemberSignal struct {
|
||||
SignalHeader
|
||||
SubGroup string `gv:"sub_group"`
|
||||
MemberID NodeID `gv:"member_id"`
|
||||
}
|
||||
|
||||
type SubGroupGQL struct {
|
||||
Name string
|
||||
Members []NodeID
|
||||
}
|
||||
|
||||
func (signal AddMemberSignal) Permission() Tree {
|
||||
return Tree{
|
||||
SerializedType(SignalTypeFor[AddMemberSignal]()): {
|
||||
Hash("sub_group", signal.SubGroup): nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAddMemberSignal(sub_group string, member_id NodeID) *AddMemberSignal {
|
||||
return &AddMemberSignal{
|
||||
NewSignalHeader(Direct),
|
||||
sub_group,
|
||||
member_id,
|
||||
}
|
||||
}
|
||||
|
||||
type RemoveMemberSignal struct {
|
||||
SignalHeader
|
||||
SubGroup string `gv:"sub_group"`
|
||||
MemberID NodeID `gv:"member_id"`
|
||||
}
|
||||
|
||||
func (signal RemoveMemberSignal) Permission() Tree {
|
||||
return Tree{
|
||||
SerializedType(SignalTypeFor[RemoveMemberSignal]()): {
|
||||
Hash("sub_group", signal.SubGroup): nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSignal {
|
||||
return &RemoveMemberSignal{
|
||||
NewSignalHeader(Direct),
|
||||
sub_group,
|
||||
member_id,
|
||||
}
|
||||
}
|
||||
|
||||
var DefaultGroupPolicy = NewAllNodesPolicy(Tree{
|
||||
SerializedType(SignalTypeFor[ReadSignal]()): {
|
||||
SerializedType(ExtTypeFor[GroupExt]()): {
|
||||
SerializedType(GetFieldTag("sub_groups")): nil,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
type SubGroup struct {
|
||||
Members []NodeID
|
||||
Permissions Tree
|
||||
}
|
||||
|
||||
type MemberOfPolicy struct {
|
||||
PolicyHeader
|
||||
Groups map[NodeID]map[string]Tree
|
||||
}
|
||||
|
||||
func NewMemberOfPolicy(groups map[NodeID]map[string]Tree) MemberOfPolicy {
|
||||
return MemberOfPolicy{
|
||||
PolicyHeader: NewPolicyHeader(),
|
||||
Groups: groups,
|
||||
}
|
||||
}
|
||||
|
||||
func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||
sig, ok := signal.(*ReadResultSignal)
|
||||
if ok == false {
|
||||
return Deny
|
||||
}
|
||||
ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions)
|
||||
|
||||
group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()]
|
||||
if ok == false {
|
||||
return Deny
|
||||
}
|
||||
|
||||
sub_groups_ser, ok := group_ext_data["sub_groups"]
|
||||
if ok == false {
|
||||
return Deny
|
||||
}
|
||||
|
||||
sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("group", "Type deserialize error: %s", err)
|
||||
return Deny
|
||||
}
|
||||
|
||||
sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("group", "Value deserialize error: %s", err)
|
||||
return Deny
|
||||
}
|
||||
|
||||
ext_sub_groups, ok := sub_groups_if.Interface().(map[string][]NodeID)
|
||||
if ok == false {
|
||||
return Deny
|
||||
}
|
||||
|
||||
group, exists := policy.Groups[sig.NodeID]
|
||||
if exists == false {
|
||||
return Deny
|
||||
}
|
||||
|
||||
for sub_group_name, permissions := range(group) {
|
||||
ext_sub_group, exists := ext_sub_groups[sub_group_name]
|
||||
if exists == true {
|
||||
for _, member_id := range(ext_sub_group) {
|
||||
if member_id == current.Principal {
|
||||
if permissions.Allows(current.Action) == Allow {
|
||||
return Allow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Deny
|
||||
}
|
||||
|
||||
// Send a read signal to Group to check if principal_id is a member of it
|
||||
func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
|
||||
var messages Messages = nil
|
||||
for group_id, sub_groups := range(policy.Groups) {
|
||||
if group_id == node.ID {
|
||||
ext, err := GetExt[GroupExt](node)
|
||||
if err != nil {
|
||||
ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err)
|
||||
} else {
|
||||
for sub_group_name, permission := range(sub_groups) {
|
||||
ext_sub_group, exists := ext.SubGroups[sub_group_name]
|
||||
if exists == true {
|
||||
for _, member := range(ext_sub_group) {
|
||||
if member == principal_id {
|
||||
if permission.Allows(action) == Allow {
|
||||
return nil, Allow
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Send the read request to the group so that ContinueAllows can parse the response to check membership
|
||||
messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{
|
||||
ExtTypeFor[GroupExt](): {"sub_groups"},
|
||||
}))
|
||||
}
|
||||
}
|
||||
if len(messages) > 0 {
|
||||
return messages, Pending
|
||||
} else {
|
||||
return nil, Deny
|
||||
}
|
||||
}
|
||||
|
||||
type GroupExt struct {
|
||||
SubGroups map[string][]NodeID `gv:"sub_groups"`
|
||||
}
|
||||
|
||||
func NewGroupExt(sub_groups map[string][]NodeID) *GroupExt {
|
||||
if sub_groups == nil {
|
||||
sub_groups = map[string][]NodeID{}
|
||||
}
|
||||
return &GroupExt{
|
||||
SubGroups: sub_groups,
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *GroupExt) Load(ctx *Context, node *Node) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *GroupExt) Unload(ctx *Context, node *Node) {
|
||||
|
||||
}
|
||||
|
||||
func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
|
||||
var messages Messages = nil
|
||||
var changes = Changes{}
|
||||
|
||||
switch sig := signal.(type) {
|
||||
case *AddMemberSignal:
|
||||
sub_group, exists := ext.SubGroups[sig.SubGroup]
|
||||
if exists == false {
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
||||
} else {
|
||||
if slices.Contains(sub_group, sig.MemberID) == true {
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member"))
|
||||
} else {
|
||||
sub_group = append(sub_group, sig.MemberID)
|
||||
ext.SubGroups[sig.SubGroup] = sub_group
|
||||
|
||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||
changes.Add("sub_groups")
|
||||
}
|
||||
}
|
||||
|
||||
case *RemoveMemberSignal:
|
||||
sub_group, exists := ext.SubGroups[sig.SubGroup]
|
||||
if exists == false {
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
||||
} else {
|
||||
idx := slices.Index(sub_group, sig.MemberID)
|
||||
if idx == -1 {
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member"))
|
||||
} else {
|
||||
sub_group = slices.Delete(sub_group, idx, idx+1)
|
||||
ext.SubGroups[sig.SubGroup] = sub_group
|
||||
|
||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||
changes.Add("sub_groups")
|
||||
}
|
||||
}
|
||||
|
||||
case *AddSubGroupSignal:
|
||||
_, exists := ext.SubGroups[sig.Name]
|
||||
if exists == true {
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_subgroup"))
|
||||
} else {
|
||||
ext.SubGroups[sig.Name] = []NodeID{}
|
||||
|
||||
changes.Add("sub_groups")
|
||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||
}
|
||||
case *RemoveSubGroupSignal:
|
||||
_, exists := ext.SubGroups[sig.Name]
|
||||
if exists == false {
|
||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
||||
} else {
|
||||
delete(ext.SubGroups, sig.Name)
|
||||
|
||||
changes.Add("sub_groups")
|
||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||
}
|
||||
}
|
||||
|
||||
return messages, changes
|
||||
}
|
||||
|
@ -1,94 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGroupAdd(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"listener", "test"})
|
||||
|
||||
group_listener := NewListenerExt(10)
|
||||
group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil))
|
||||
fatalErr(t, err)
|
||||
|
||||
add_subgroup_signal := NewAddSubGroupSignal("test_group")
|
||||
messages := Messages{}
|
||||
messages = messages.Add(ctx, group.ID, group, nil, add_subgroup_signal)
|
||||
fatalErr(t, ctx.Send(messages))
|
||||
|
||||
resp_1, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_subgroup_signal.Id)
|
||||
fatalErr(t, err)
|
||||
|
||||
error_1, is_error := resp_1.(*ErrorSignal)
|
||||
if is_error {
|
||||
t.Fatalf("Error returned: %s", error_1.Error)
|
||||
}
|
||||
|
||||
user_id := RandID()
|
||||
add_member_signal := NewAddMemberSignal("test_group", user_id)
|
||||
|
||||
messages = Messages{}
|
||||
messages = messages.Add(ctx, group.ID, group, nil, add_member_signal)
|
||||
fatalErr(t, ctx.Send(messages))
|
||||
|
||||
resp_2, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id)
|
||||
fatalErr(t, err)
|
||||
|
||||
error_2, is_error := resp_2.(*ErrorSignal)
|
||||
if is_error {
|
||||
t.Fatalf("Error returned: %s", error_2.Error)
|
||||
}
|
||||
|
||||
read_signal := NewReadSignal(map[ExtType][]string{
|
||||
ExtTypeFor[GroupExt](): {"sub_groups"},
|
||||
})
|
||||
|
||||
messages = Messages{}
|
||||
messages = messages.Add(ctx, group.ID, group, nil, read_signal)
|
||||
fatalErr(t, ctx.Send(messages))
|
||||
|
||||
response, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id)
|
||||
fatalErr(t, err)
|
||||
|
||||
read_response := response.(*ReadResultSignal)
|
||||
|
||||
sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"]
|
||||
|
||||
sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack)
|
||||
fatalErr(t, err)
|
||||
if len(remaining_types) > 0 {
|
||||
t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types))
|
||||
}
|
||||
|
||||
sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data)
|
||||
fatalErr(t, err)
|
||||
if len(remaining) > 0 {
|
||||
t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types))
|
||||
}
|
||||
|
||||
sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID)
|
||||
|
||||
if ok != true {
|
||||
t.Fatalf("sub_groups wrong type %s", sub_groups_value.Type())
|
||||
}
|
||||
|
||||
if len(sub_groups) != 1 {
|
||||
t.Fatalf("sub_groups wrong length %d", len(sub_groups))
|
||||
}
|
||||
|
||||
test_subgroup, exists := sub_groups["test_group"]
|
||||
if exists == false {
|
||||
t.Fatal("test_group not in subgroups")
|
||||
}
|
||||
|
||||
if len(test_subgroup) != 1 {
|
||||
t.Fatalf("test_group wrong size %d/1", len(test_subgroup))
|
||||
}
|
||||
|
||||
if test_subgroup[0] != user_id {
|
||||
t.Fatalf("sub_groups wrong value %s", test_subgroup[0])
|
||||
}
|
||||
|
||||
ctx.Log.Logf("test", "Read Response: %+v", read_response)
|
||||
}
|
@ -1,114 +1,11 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"time"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto"
|
||||
)
|
||||
|
||||
type AuthInfo struct {
|
||||
// The Node that issued the authorization
|
||||
Identity ed25519.PublicKey
|
||||
|
||||
// Time the authorization was generated
|
||||
Start time.Time
|
||||
|
||||
// Signature of Start + Principal with Identity private key
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
type AuthorizationToken struct {
|
||||
AuthInfo
|
||||
|
||||
// The private key generated by the client, encrypted with the servers public key
|
||||
KeyEncrypted []byte
|
||||
}
|
||||
|
||||
type ClientAuthorization struct {
|
||||
AuthInfo
|
||||
|
||||
// The private key generated by the client
|
||||
Key ed25519.PrivateKey
|
||||
}
|
||||
|
||||
// Authorization structs can be passed in a message that originated from a different node than the sender
|
||||
type Authorization struct {
|
||||
AuthInfo
|
||||
|
||||
// The public key generated for this authorization
|
||||
Key ed25519.PublicKey
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Dest NodeID
|
||||
Source ed25519.PublicKey
|
||||
|
||||
Authorization *Authorization
|
||||
|
||||
type SendMsg struct {
|
||||
Dest NodeID
|
||||
Signal Signal
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
type Messages []*Message
|
||||
func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages {
|
||||
msg, err := NewMessage(ctx, dest, source, authorization, signal)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
func NewMessages(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signals... Signal) Messages {
|
||||
messages := Messages{}
|
||||
for _, signal := range(signals) {
|
||||
messages = messages.Add(ctx, dest, source, authorization, signal)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) {
|
||||
signal_ser, err := SerializeAny(ctx, signal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signal_chunks, err := signal_ser.Chunks()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dest_ser, err := dest.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
source_ser, err := source.ID.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sig_data := append(dest_ser, source_ser...)
|
||||
sig_data = append(sig_data, signal_chunks.Slice()...)
|
||||
var message_auth *Authorization = nil
|
||||
if authorization != nil {
|
||||
sig_data = append(sig_data, authorization.Signature...)
|
||||
message_auth = &Authorization{
|
||||
authorization.AuthInfo,
|
||||
authorization.Key.Public().(ed25519.PublicKey),
|
||||
}
|
||||
}
|
||||
|
||||
sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Message{
|
||||
Dest: dest,
|
||||
Source: source.Key.Public().(ed25519.PublicKey),
|
||||
Authorization: message_auth,
|
||||
Signal: signal,
|
||||
Signature: sig,
|
||||
}, nil
|
||||
type RecvMsg struct {
|
||||
Source NodeID
|
||||
Signal Signal
|
||||
}
|
||||
|
@ -1,139 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Policy interface {
|
||||
Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
|
||||
ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
|
||||
ID() uuid.UUID
|
||||
}
|
||||
|
||||
type PolicyHeader struct {
|
||||
UUID uuid.UUID `gv:"uuid"`
|
||||
}
|
||||
|
||||
func (header PolicyHeader) ID() uuid.UUID {
|
||||
return header.UUID
|
||||
}
|
||||
|
||||
func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
|
||||
return nil, policy.Rules.Allows(action)
|
||||
}
|
||||
|
||||
func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||
return Deny
|
||||
}
|
||||
|
||||
func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
|
||||
for id, actions := range(policy.NodeRules) {
|
||||
if id != principal_id {
|
||||
continue
|
||||
}
|
||||
return nil, actions.Allows(action)
|
||||
}
|
||||
return nil, Deny
|
||||
}
|
||||
|
||||
func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||
return Deny
|
||||
}
|
||||
|
||||
func CopyTree(tree Tree) Tree {
|
||||
if tree == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := Tree{}
|
||||
for name, sub := range(tree) {
|
||||
ret[name] = CopyTree(sub)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func MergeTrees(first Tree, second Tree) Tree {
|
||||
if first == nil || second == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := CopyTree(first)
|
||||
for name, sub := range(second) {
|
||||
current, exists := ret[name]
|
||||
if exists == true {
|
||||
ret[name] = MergeTrees(current, sub)
|
||||
} else {
|
||||
ret[name] = CopyTree(sub)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
type Tree map[SerializedType]Tree
|
||||
|
||||
func (rule Tree) Allows(action Tree) RuleResult {
|
||||
// If the current rule is nil, it's a wildcard and any action being processed is allowed
|
||||
if rule == nil {
|
||||
return Allow
|
||||
// If the rule isn't "allow all" but the action is "request all", deny
|
||||
} else if action == nil {
|
||||
return Deny
|
||||
// If the current action has no children, it's allowed
|
||||
} else if len(action) == 0 {
|
||||
return Allow
|
||||
// If the current rule has no children but the action goes further, it's not allowed
|
||||
} else if len(rule) == 0 {
|
||||
return Deny
|
||||
// If the current rule and action have children, all the children of action must be allowed by rule
|
||||
} else {
|
||||
for sub, subtree := range(action) {
|
||||
subrule, exists := rule[sub]
|
||||
if exists == false {
|
||||
return Deny
|
||||
} else if subrule.Allows(subtree) == Deny {
|
||||
return Deny
|
||||
}
|
||||
}
|
||||
return Allow
|
||||
}
|
||||
}
|
||||
|
||||
func NewPolicyHeader() PolicyHeader {
|
||||
return PolicyHeader{
|
||||
UUID: uuid.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy {
|
||||
if node_actions == nil {
|
||||
node_actions = map[NodeID]Tree{}
|
||||
}
|
||||
|
||||
return PerNodePolicy{
|
||||
PolicyHeader: NewPolicyHeader(),
|
||||
NodeRules: node_actions,
|
||||
}
|
||||
}
|
||||
|
||||
type PerNodePolicy struct {
|
||||
PolicyHeader
|
||||
NodeRules map[NodeID]Tree `gv:"node_rules"`
|
||||
}
|
||||
|
||||
func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
|
||||
return AllNodesPolicy{
|
||||
PolicyHeader: NewPolicyHeader(),
|
||||
Rules: rules,
|
||||
}
|
||||
}
|
||||
|
||||
type AllNodesPolicy struct {
|
||||
PolicyHeader
|
||||
Rules Tree `gv:"rules"`
|
||||
}
|
||||
|
||||
var DefaultPolicy = NewAllNodesPolicy(Tree{
|
||||
SerializedType(SignalTypeFor[ResponseSignal]()): nil,
|
||||
SerializedType(SignalTypeFor[StatusSignal]()): nil,
|
||||
})
|
File diff suppressed because it is too large
Load Diff
@ -1,247 +0,0 @@
|
||||
package graphvent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSerializeTest(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
|
||||
testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}})
|
||||
testSerialize(t, ctx, map[NodeID]ReqState{
|
||||
RandID(): Locked,
|
||||
RandID(): Unlocked,
|
||||
})
|
||||
}
|
||||
|
||||
func TestSerializeBasic(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
|
||||
testSerializeComparable[bool](t, ctx, true)
|
||||
|
||||
type bool_wrapped bool
|
||||
err := RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped])
|
||||
fatalErr(t, err)
|
||||
testSerializeComparable[bool_wrapped](t, ctx, true)
|
||||
|
||||
testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false})
|
||||
testSerializeComparable[string](t, ctx, "test")
|
||||
testSerializeComparable[float32](t, ctx, 0.05)
|
||||
testSerializeComparable[float64](t, ctx, 0.05)
|
||||
testSerializeComparable[uint](t, ctx, uint(1234))
|
||||
testSerializeComparable[uint8] (t, ctx, uint8(123))
|
||||
testSerializeComparable[uint16](t, ctx, uint16(1234))
|
||||
testSerializeComparable[uint32](t, ctx, uint32(12345))
|
||||
testSerializeComparable[uint64](t, ctx, uint64(123456))
|
||||
testSerializeComparable[int](t, ctx, 1234)
|
||||
testSerializeComparable[int8] (t, ctx, int8(-123))
|
||||
testSerializeComparable[int16](t, ctx, int16(-1234))
|
||||
testSerializeComparable[int32](t, ctx, int32(-12345))
|
||||
testSerializeComparable[int64](t, ctx, int64(-123456))
|
||||
testSerializeComparable[time.Duration](t, ctx, time.Duration(100))
|
||||
testSerializeComparable[time.Time](t, ctx, time.Now().Truncate(0))
|
||||
testSerializeSlice[[]int](t, ctx, []int{123, 456, 789, 101112})
|
||||
testSerializeSlice[[]int](t, ctx, ([]int)(nil))
|
||||
testSerializeSliceSlice[[][]int](t, ctx, [][]int{{123, 456, 789, 101112}, {3253, 2341, 735, 212}, {123, 51}, nil})
|
||||
testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil})
|
||||
|
||||
testSerialize(t, ctx, map[int8]map[*int8]string{})
|
||||
testSerialize(t, ctx, map[int8]time.Time{
|
||||
1: time.Now(),
|
||||
3: time.Now().Add(time.Second),
|
||||
0: time.Now().Add(time.Second*2),
|
||||
4: time.Now().Add(time.Second*3),
|
||||
})
|
||||
|
||||
testSerialize(t, ctx, Tree{
|
||||
SerializedTypeFor[NodeType](): nil,
|
||||
SerializedTypeFor[SerializedType](): {
|
||||
SerializedTypeFor[NodeType](): Tree{},
|
||||
},
|
||||
})
|
||||
|
||||
var i interface{} = nil
|
||||
testSerialize(t, ctx, i)
|
||||
|
||||
testSerializeMap(t, ctx, map[int8]interface{}{
|
||||
0: "abcd",
|
||||
1: uint32(12345678),
|
||||
2: i,
|
||||
3: 123,
|
||||
})
|
||||
|
||||
testSerializeMap(t, ctx, map[int8]int32{
|
||||
0: 1234,
|
||||
2: 5678,
|
||||
4: 9101,
|
||||
6: 1121,
|
||||
})
|
||||
|
||||
type test_struct struct {
|
||||
Int int `gv:"int"`
|
||||
String string `gv:"string"`
|
||||
}
|
||||
|
||||
err = RegisterStruct[test_struct](ctx)
|
||||
fatalErr(t, err)
|
||||
|
||||
testSerialize(t, ctx, test_struct{
|
||||
12345,
|
||||
"test_string",
|
||||
})
|
||||
|
||||
testSerialize(t, ctx, Tree{
|
||||
SerializedKindFor(reflect.Map): nil,
|
||||
SerializedKindFor(reflect.String): nil,
|
||||
})
|
||||
|
||||
testSerialize(t, ctx, Tree{
|
||||
SerializedTypeFor[Tree](): nil,
|
||||
})
|
||||
|
||||
testSerialize(t, ctx, Tree{
|
||||
SerializedTypeFor[Tree](): {
|
||||
SerializedTypeFor[error](): Tree{},
|
||||
SerializedKindFor(reflect.Map): nil,
|
||||
},
|
||||
SerializedKindFor(reflect.String): nil,
|
||||
})
|
||||
|
||||
type test_slice []string
|
||||
err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice)
|
||||
fatalErr(t, err)
|
||||
|
||||
testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"})
|
||||
testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"})
|
||||
}
|
||||
|
||||
type test struct {
|
||||
Int int `gv:"int"`
|
||||
Str string `gv:"string"`
|
||||
}
|
||||
|
||||
func (s test) String() string {
|
||||
return fmt.Sprintf("%d:%s", s.Int, s.Str)
|
||||
}
|
||||
|
||||
func TestSerializeStructTags(t *testing.T) {
|
||||
ctx := logTestContext(t, []string{"test"})
|
||||
|
||||
err := RegisterStruct[test](ctx)
|
||||
fatalErr(t, err)
|
||||
|
||||
test_int := 10
|
||||
test_string := "test"
|
||||
|
||||
ret := testSerialize(t, ctx, test{
|
||||
test_int,
|
||||
test_string,
|
||||
})
|
||||
if ret.Int != test_int {
|
||||
t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int)
|
||||
} else if ret.Str != test_string {
|
||||
t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string)
|
||||
}
|
||||
|
||||
testSerialize(t, ctx, []test{
|
||||
{
|
||||
test_int,
|
||||
test_string,
|
||||
},
|
||||
{
|
||||
test_int * 2,
|
||||
fmt.Sprintf("%s%s", test_string, test_string),
|
||||
},
|
||||
{
|
||||
test_int * 4,
|
||||
fmt.Sprintf("%s%s%s", test_string, test_string, test_string),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) {
|
||||
v := testSerialize(t, ctx, val)
|
||||
for key, value := range(val) {
|
||||
recreated, exists := v[key]
|
||||
if exists == false {
|
||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||
} else if recreated != value {
|
||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||
}
|
||||
}
|
||||
if len(v) != len(val) {
|
||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||
}
|
||||
}
|
||||
|
||||
func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, val S) {
|
||||
v := testSerialize(t, ctx, val)
|
||||
for i, original := range(val) {
|
||||
if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) {
|
||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||
}
|
||||
for j, o := range(original) {
|
||||
if v[i][j] != o {
|
||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testSerializeSlice[S []T, T comparable](t *testing.T, ctx *Context, val S) {
|
||||
v := testSerialize(t, ctx, val)
|
||||
for i, original := range(val) {
|
||||
if v[i] != original {
|
||||
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) {
|
||||
v := testSerialize(t, ctx, val)
|
||||
if v != val {
|
||||
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
|
||||
}
|
||||
}
|
||||
|
||||
func testSerialize[T any](t *testing.T, ctx *Context, val T) T {
|
||||
value := reflect.ValueOf(&val).Elem()
|
||||
type_stack, err := SerializeType(ctx, value.Type())
|
||||
chunks, err := SerializeValue(ctx, value)
|
||||
value_serialized := SerializedValue{type_stack, chunks.Slice()}
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "Serialized %+v to %+v(%d)", val, value_serialized, len(value_serialized.Data))
|
||||
|
||||
value_chunks, err := value_serialized.Chunks()
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "Binary: %+v", value_chunks.Slice())
|
||||
|
||||
val_parsed, remaining_parse, err := ParseSerializedValue(value_chunks.Slice())
|
||||
fatalErr(t, err)
|
||||
ctx.Log.Logf("test", "Parsed: %+v", val_parsed)
|
||||
|
||||
if len(remaining_parse) != 0 {
|
||||
t.Fatal("Data remaining after deserializing value")
|
||||
}
|
||||
|
||||
val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack)
|
||||
deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data)
|
||||
fatalErr(t, err)
|
||||
|
||||
if len(remaining_deserialize) != 0 {
|
||||
t.Fatal("Data remaining after deserializing value")
|
||||
} else if len(remaining_types) != 0 {
|
||||
t.Fatal("TypeStack remaining after deserializing value")
|
||||
} else if val_type != value.Type() {
|
||||
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val)))
|
||||
} else if deserialized_value.CanConvert(val_type) == false {
|
||||
t.Fatal("DeserializeValue returned value that can't convert to original value")
|
||||
}
|
||||
ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface())
|
||||
if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil {
|
||||
var zero T
|
||||
return zero
|
||||
}
|
||||
return deserialized_value.Interface().(T)
|
||||
}
|
Loading…
Reference in New Issue