Compare commits
8 Commits
e5776e0a14
...
c591fa5ace
Author | SHA1 | Date |
---|---|---|
noah metz | c591fa5ace | |
noah metz | f8dad12fdb | |
noah metz | eef8451566 | |
noah metz | 7e143c9d93 | |
noah metz | 7314c74087 | |
noah metz | 1eb6479169 | |
noah metz | e16bec3997 | |
noah metz | 6942dc02db |
@ -1,233 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"slices"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ACLSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
Principal NodeID `gv:"principal"`
|
|
||||||
Action Tree `gv:"tree"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewACLSignal(principal NodeID, action Tree) *ACLSignal {
|
|
||||||
return &ACLSignal{
|
|
||||||
SignalHeader: NewSignalHeader(Direct),
|
|
||||||
Principal: principal,
|
|
||||||
Action: action,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultACLPolicy = NewAllNodesPolicy(Tree{
|
|
||||||
SerializedType(SignalTypeFor[ACLSignal]()): nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
func (signal ACLSignal) Permission() Tree {
|
|
||||||
return Tree{
|
|
||||||
SerializedType(SignalTypeFor[ACLSignal]()): nil,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ACLExt struct {
|
|
||||||
Policies []Policy `gv:"policies"`
|
|
||||||
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
|
|
||||||
Pending map[uuid.UUID]PendingACLSignal `gv:"pending"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewACLExt(policies []Policy) *ACLExt {
|
|
||||||
return &ACLExt{
|
|
||||||
Policies: policies,
|
|
||||||
PendingACLs: map[uuid.UUID]PendingACL{},
|
|
||||||
Pending: map[uuid.UUID]PendingACLSignal{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLExt) Load(ctx *Context, node *Node) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLExt) Unload(ctx *Context, node *Node) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
|
|
||||||
response, is_response := signal.(ResponseSignal)
|
|
||||||
if is_response == true {
|
|
||||||
var messages Messages = nil
|
|
||||||
var changes = Changes{}
|
|
||||||
info, waiting := ext.Pending[response.ResponseID()]
|
|
||||||
if waiting == true {
|
|
||||||
changes.Add("pending")
|
|
||||||
delete(ext.Pending, response.ResponseID())
|
|
||||||
if response.ID() != info.Timeout {
|
|
||||||
err := node.DequeueSignal(info.Timeout)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("acl", "timeout dequeue error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
acl_info, found := ext.PendingACLs[info.ID]
|
|
||||||
if found == true {
|
|
||||||
acl_info.Counter -= 1
|
|
||||||
acl_info.Responses = append(acl_info.Responses, response)
|
|
||||||
|
|
||||||
policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool {
|
|
||||||
return policy.ID() == info.Policy
|
|
||||||
})
|
|
||||||
|
|
||||||
if policy_index == -1 {
|
|
||||||
ctx.Log.Logf("acl", "pending signal for nonexistent policy")
|
|
||||||
delete(ext.PendingACLs, info.ID)
|
|
||||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
|
|
||||||
changes.Add("pending_acls")
|
|
||||||
delete(ext.PendingACLs, info.ID)
|
|
||||||
ctx.Log.Logf("acl", "Request delayed allow")
|
|
||||||
messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID))
|
|
||||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
|
||||||
}
|
|
||||||
} else if acl_info.Counter == 0 {
|
|
||||||
changes.Add("pending_acls")
|
|
||||||
delete(ext.PendingACLs, info.ID)
|
|
||||||
ctx.Log.Logf("acl", "Request delayed deny")
|
|
||||||
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied"))
|
|
||||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
node.PendingACLs[info.ID] = acl_info
|
|
||||||
changes.Add("pending_acls")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return messages, changes
|
|
||||||
}
|
|
||||||
|
|
||||||
var messages Messages = nil
|
|
||||||
var changes = Changes{}
|
|
||||||
|
|
||||||
switch sig := signal.(type) {
|
|
||||||
case *ACLSignal:
|
|
||||||
var acl_messages map[uuid.UUID]Messages = nil
|
|
||||||
denied := true
|
|
||||||
for _, policy := range(ext.Policies) {
|
|
||||||
policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node)
|
|
||||||
if result == Allow {
|
|
||||||
denied = false
|
|
||||||
break
|
|
||||||
} else if result == Pending {
|
|
||||||
if len(policy_messages) == 0 {
|
|
||||||
ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID())
|
|
||||||
continue
|
|
||||||
} else if acl_messages == nil {
|
|
||||||
acl_messages = map[uuid.UUID]Messages{}
|
|
||||||
denied = false
|
|
||||||
}
|
|
||||||
|
|
||||||
acl_messages[policy.ID()] = policy_messages
|
|
||||||
ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if denied == true {
|
|
||||||
ctx.Log.Logf("acl", "Request denied")
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied"))
|
|
||||||
} else if acl_messages != nil {
|
|
||||||
ctx.Log.Logf("acl", "Request pending")
|
|
||||||
changes.Add("pending")
|
|
||||||
total_messages := 0
|
|
||||||
// TODO: reasonable timeout/configurable
|
|
||||||
timeout_time := time.Now().Add(time.Second)
|
|
||||||
for policy_id, policy_messages := range(acl_messages) {
|
|
||||||
total_messages += len(policy_messages)
|
|
||||||
for _, message := range(policy_messages) {
|
|
||||||
timeout_signal := NewTimeoutSignal(message.Signal.ID())
|
|
||||||
ext.Pending[message.Signal.ID()] = PendingACLSignal{
|
|
||||||
Policy: policy_id,
|
|
||||||
Timeout: timeout_signal.Id,
|
|
||||||
ID: sig.Id,
|
|
||||||
}
|
|
||||||
node.QueueSignal(timeout_time, timeout_signal)
|
|
||||||
messages = append(messages, message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
acl_timeout := NewACLTimeoutSignal(sig.Id)
|
|
||||||
node.QueueSignal(timeout_time, acl_timeout)
|
|
||||||
ext.PendingACLs[sig.Id] = PendingACL{
|
|
||||||
Counter: total_messages,
|
|
||||||
Responses: []ResponseSignal{},
|
|
||||||
TimeoutID: acl_timeout.Id,
|
|
||||||
Action: sig.Action,
|
|
||||||
Principal: sig.Principal,
|
|
||||||
|
|
||||||
Source: source,
|
|
||||||
Signal: signal,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.Log.Logf("acl", "Request allowed")
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
|
||||||
}
|
|
||||||
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
|
|
||||||
case *ACLTimeoutSignal:
|
|
||||||
acl_info, exists := ext.PendingACLs[sig.ReqID]
|
|
||||||
if exists == true {
|
|
||||||
delete(ext.PendingACLs, sig.ReqID)
|
|
||||||
changes.Add("pending_acls")
|
|
||||||
ctx.Log.Logf("acl", "Request timeout deny")
|
|
||||||
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout"))
|
|
||||||
err := node.DequeueSignal(acl_info.TimeoutID)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl")
|
|
||||||
}
|
|
||||||
// Delete from PendingACLs
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages, changes
|
|
||||||
}
|
|
||||||
|
|
||||||
type ACLProxyPolicy struct {
|
|
||||||
PolicyHeader
|
|
||||||
Proxies []NodeID `gv:"proxies"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy {
|
|
||||||
return ACLProxyPolicy{
|
|
||||||
NewPolicyHeader(),
|
|
||||||
proxies,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
|
|
||||||
if len(policy.Proxies) == 0 {
|
|
||||||
return nil, Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := Messages{}
|
|
||||||
for _, proxy := range(policy.Proxies) {
|
|
||||||
messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action))
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages, Pending
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
|
||||||
_, is_success := signal.(*SuccessSignal)
|
|
||||||
if is_success == true {
|
|
||||||
return Allow
|
|
||||||
}
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
@ -1,141 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
"reflect"
|
|
||||||
"runtime/debug"
|
|
||||||
)
|
|
||||||
|
|
||||||
func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){
|
|
||||||
response_casted, cast_ok := signal.(S)
|
|
||||||
if cast_ok == false {
|
|
||||||
error_signal, is_error := signal.(*ErrorSignal)
|
|
||||||
if is_error {
|
|
||||||
t.Log(string(debug.Stack()))
|
|
||||||
t.Fatal(error_signal.Error)
|
|
||||||
}
|
|
||||||
t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal))
|
|
||||||
}
|
|
||||||
|
|
||||||
check(response_casted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){
|
|
||||||
acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
acl_signal := NewACLSignal(listener.ID, action)
|
|
||||||
response, _ := testSend(t, ctx, acl_signal, listener, acl_node)
|
|
||||||
|
|
||||||
checkSignal(t, response, check)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){
|
|
||||||
return func(response *ErrorSignal) {
|
|
||||||
if response.Error != error_string {
|
|
||||||
t.Fatalf("Wrong error: %s", response.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSuccess(*SuccessSignal){}
|
|
||||||
|
|
||||||
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
|
|
||||||
source_listener, err := GetExt[ListenerExt](source)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
messages := Messages{}
|
|
||||||
messages = messages.Add(ctx, destination.ID, source, nil, signal)
|
|
||||||
fatalErr(t, ctx.Send(messages))
|
|
||||||
|
|
||||||
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
return response, signals
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestACLBasic(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"})
|
|
||||||
|
|
||||||
listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "testing fail")
|
|
||||||
testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied"))
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "testing allow all")
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess)
|
|
||||||
|
|
||||||
group, err := NewNode(ctx, nil, "Base", 100, []Policy{
|
|
||||||
DefaultGroupPolicy,
|
|
||||||
NewPerNodePolicy(map[NodeID]Tree{
|
|
||||||
listener.ID: {
|
|
||||||
SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil,
|
|
||||||
SerializedType(SignalTypeFor[AddMemberSignal]()): nil,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
}, NewGroupExt(nil))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "testing empty groups")
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{
|
|
||||||
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
|
||||||
group.ID: {
|
|
||||||
"test_group": nil,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
}, testErrorSignal(t, "acl_denied"))
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "testing adding group")
|
|
||||||
add_subgroup_signal := NewAddSubGroupSignal("test_group")
|
|
||||||
add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group)
|
|
||||||
checkSignal(t, add_subgroup_response, testSuccess)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "testing adding member")
|
|
||||||
add_member_signal := NewAddMemberSignal("test_group", listener.ID)
|
|
||||||
add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group)
|
|
||||||
checkSignal(t, add_member_response, testSuccess)
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "testing group membership")
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{
|
|
||||||
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
|
||||||
group.ID: {
|
|
||||||
"test_group": nil,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
}, testSuccess)
|
|
||||||
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{
|
|
||||||
NewACLProxyPolicy(nil),
|
|
||||||
}, testErrorSignal(t, "acl_denied"))
|
|
||||||
|
|
||||||
acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{
|
|
||||||
NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}),
|
|
||||||
}, testErrorSignal(t, "acl_denied"))
|
|
||||||
|
|
||||||
acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)}))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{
|
|
||||||
NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}),
|
|
||||||
}, testSuccess)
|
|
||||||
|
|
||||||
acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy},
|
|
||||||
NewACLExt([]Policy{
|
|
||||||
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
|
||||||
group.ID: {
|
|
||||||
"test_group": nil,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
testSendACL(t, ctx, listener, nil, []Policy{
|
|
||||||
NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}),
|
|
||||||
}, testSuccess)
|
|
||||||
}
|
|
@ -0,0 +1,43 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
badger "github.com/dgraph-io/badger/v3"
|
||||||
|
gv "github.com/mekkanized/graphvent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func check(err error) {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "signal"}))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
listener_ext := gv.NewListenerExt(1000)
|
||||||
|
|
||||||
|
n1, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt(nil))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
n2, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
_, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID}))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
for true {
|
||||||
|
select {
|
||||||
|
case message := <- listener_ext.Chan:
|
||||||
|
fmt.Printf("Listener Message: %+v\n", message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,195 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
badger "github.com/dgraph-io/badger/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WriteNodeInit(ctx *Context, node *Node) error {
|
||||||
|
if node == nil {
|
||||||
|
return fmt.Errorf("Cannot serialize nil *Node")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.DB.Update(func(tx *badger.Txn) error {
|
||||||
|
// Get the base key bytes
|
||||||
|
id_ser, err := node.ID.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write Node value
|
||||||
|
node_val, err := Serialize(ctx, node)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = tx.Set(id_ser, node_val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write empty signal queue
|
||||||
|
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||||
|
sigqueue_val, err := Serialize(ctx, node.SignalQueue)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = tx.Set(sigqueue_id, sigqueue_val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write node extension list
|
||||||
|
ext_list := []ExtType{}
|
||||||
|
for ext_type := range(node.Extensions) {
|
||||||
|
ext_list = append(ext_list, ext_type)
|
||||||
|
}
|
||||||
|
ext_list_val, err := Serialize(ctx, ext_list)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
||||||
|
err = tx.Set(ext_list_id, ext_list_val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each extension:
|
||||||
|
for ext_type, ext := range(node.Extensions) {
|
||||||
|
// Write each extension's current value
|
||||||
|
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||||
|
ext_val, err := Serialize(ctx, ext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = tx.Set(ext_id, ext_val)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
|
||||||
|
return ctx.DB.Update(func(tx *badger.Txn) error {
|
||||||
|
// Get the base key bytes
|
||||||
|
id_ser, err := node.ID.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the signal queue if it needs to be written
|
||||||
|
if node.writeSignalQueue {
|
||||||
|
node.writeSignalQueue = false
|
||||||
|
|
||||||
|
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||||
|
sigqueue_val, err := Serialize(ctx, node.SignalQueue)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = tx.Set(sigqueue_id, sigqueue_val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each ext in changes
|
||||||
|
for ext_type := range(changes) {
|
||||||
|
// Write each ext
|
||||||
|
ext, exists := node.Extensions[ext_type]
|
||||||
|
if exists == false {
|
||||||
|
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
|
||||||
|
}
|
||||||
|
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||||
|
ext_ser, err := Serialize(ctx, ext)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Set(ext_id, ext_ser)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadNode(ctx *Context, id NodeID) (*Node, error) {
|
||||||
|
var node *Node = nil
|
||||||
|
err := ctx.DB.View(func(tx *badger.Txn) error {
|
||||||
|
// Get the base key bytes
|
||||||
|
id_ser, err := id.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the node value
|
||||||
|
node_item, err := tx.Get(id_ser)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = node_item.Value(func(val []byte) error {
|
||||||
|
node, err = Deserialize[*Node](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the signal queue
|
||||||
|
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
||||||
|
sigqueue_item, err := tx.Get(sigqueue_id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = sigqueue_item.Value(func(val []byte) error {
|
||||||
|
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the extension list
|
||||||
|
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
||||||
|
ext_list_item, err := tx.Get(ext_list_id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ext_list []ExtType
|
||||||
|
ext_list_item.Value(func(val []byte) error {
|
||||||
|
ext_list, err = Deserialize[[]ExtType](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get the extensions
|
||||||
|
for _, ext_type := range(ext_list) {
|
||||||
|
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
||||||
|
ext_item, err := tx.Get(ext_id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ext Extension
|
||||||
|
err = ext_item.Value(func(val []byte) error {
|
||||||
|
ext, err = Deserialize[Extension](ctx, val)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
node.Extensions[ext_type] = ext
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return node, nil
|
||||||
|
}
|
@ -1,145 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
graphql "github.com/graphql-go/graphql"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"reflect"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type StructFieldInfo struct {
|
|
||||||
Name string
|
|
||||||
Type *TypeInfo
|
|
||||||
Index []int
|
|
||||||
}
|
|
||||||
|
|
||||||
func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) {
|
|
||||||
type_info, mapped := ctx.TypeReflects[field.Type]
|
|
||||||
if mapped == false {
|
|
||||||
return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return StructFieldInfo{
|
|
||||||
Name: gv_tag,
|
|
||||||
Type: type_info,
|
|
||||||
Index: field.Index,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) {
|
|
||||||
fmt.Printf("FIELD: %+v\n", fields)
|
|
||||||
signal_value := reflect.New(signal_type)
|
|
||||||
|
|
||||||
id_field := signal_value.Elem().FieldByIndex(id_index)
|
|
||||||
id_field.Set(reflect.ValueOf(uuid.New()))
|
|
||||||
|
|
||||||
direction_field := signal_value.Elem().FieldByIndex(direction_index)
|
|
||||||
direction_field.Set(reflect.ValueOf(Direct))
|
|
||||||
|
|
||||||
for _, field := range(fields) {
|
|
||||||
arg, arg_exists := args[field.Name]
|
|
||||||
if arg_exists == false {
|
|
||||||
return nil, fmt.Errorf("No arg provided named %s", field.Name)
|
|
||||||
}
|
|
||||||
field_value := signal_value.Elem().FieldByIndex(field.Index)
|
|
||||||
if field_value.CanConvert(field.Type.Reflect) == false {
|
|
||||||
return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect)
|
|
||||||
}
|
|
||||||
value, err := field.Type.GQLValue(ctx, arg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
fmt.Printf("Setting %s to %+v of type %+v\n", field.Name, value, value.Type())
|
|
||||||
field_value.Set(value)
|
|
||||||
}
|
|
||||||
return signal_value.Interface().(Signal), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) {
|
|
||||||
args := graphql.FieldConfigArgument{}
|
|
||||||
arg_info := []StructFieldInfo{}
|
|
||||||
var id_index []int = nil
|
|
||||||
var direction_index []int = nil
|
|
||||||
|
|
||||||
for _, field := range(reflect.VisibleFields(signal_type)) {
|
|
||||||
gv_tag, tagged_gv := field.Tag.Lookup("gv")
|
|
||||||
if tagged_gv {
|
|
||||||
if gv_tag == "id" {
|
|
||||||
id_index = field.Index
|
|
||||||
} else if gv_tag == "direction" {
|
|
||||||
direction_index = field.Index
|
|
||||||
} else {
|
|
||||||
_, exists := args[gv_tag]
|
|
||||||
if exists == true {
|
|
||||||
return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag)
|
|
||||||
} else {
|
|
||||||
info, err := ArgumentInfo(ctx, field, gv_tag)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
args[gv_tag] = &graphql.ArgumentConfig{
|
|
||||||
}
|
|
||||||
arg_info = append(arg_info, info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, send_exists := args[send_id_key]
|
|
||||||
if send_exists == false {
|
|
||||||
args[send_id_key] = &graphql.ArgumentConfig{
|
|
||||||
Type: graphql.String,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resolve_signal := func(p graphql.ResolveParams) (interface{}, error) {
|
|
||||||
ctx, err := PrepResolve(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
send_id, err := ExtractID(p, send_id_key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
msgs := Messages{}
|
|
||||||
msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal)
|
|
||||||
|
|
||||||
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
|
|
||||||
err = ctx.Context.Send(msgs)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
|
|
||||||
if err != nil {
|
|
||||||
ctx.Ext.FreeResponseChannel(signal.ID())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, is_success := response.(*SuccessSignal)
|
|
||||||
if is_success == true {
|
|
||||||
return "success", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
error_response, is_error := response.(*ErrorSignal)
|
|
||||||
if is_error == true {
|
|
||||||
return "error", fmt.Errorf(error_response.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &graphql.Field{
|
|
||||||
Type: graphql.String,
|
|
||||||
Args: args,
|
|
||||||
Resolve: resolve_signal,
|
|
||||||
}, nil
|
|
||||||
}
|
|
@ -1,296 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"slices"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AddSubGroupSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
Name string `gv:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAddSubGroupSignal(name string) *AddSubGroupSignal {
|
|
||||||
return &AddSubGroupSignal{
|
|
||||||
NewSignalHeader(Direct),
|
|
||||||
name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (signal AddSubGroupSignal) Permission() Tree {
|
|
||||||
return Tree{
|
|
||||||
SerializedType(SignalTypeFor[AddSubGroupSignal]()): {
|
|
||||||
Hash("name", signal.Name): nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type RemoveSubGroupSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
Name string `gv:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal {
|
|
||||||
return &RemoveSubGroupSignal{
|
|
||||||
NewSignalHeader(Direct),
|
|
||||||
name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (signal RemoveSubGroupSignal) Permission() Tree {
|
|
||||||
return Tree{
|
|
||||||
SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): {
|
|
||||||
Hash("command", signal.Name): nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type AddMemberSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
SubGroup string `gv:"sub_group"`
|
|
||||||
MemberID NodeID `gv:"member_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SubGroupGQL struct {
|
|
||||||
Name string
|
|
||||||
Members []NodeID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (signal AddMemberSignal) Permission() Tree {
|
|
||||||
return Tree{
|
|
||||||
SerializedType(SignalTypeFor[AddMemberSignal]()): {
|
|
||||||
Hash("sub_group", signal.SubGroup): nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAddMemberSignal(sub_group string, member_id NodeID) *AddMemberSignal {
|
|
||||||
return &AddMemberSignal{
|
|
||||||
NewSignalHeader(Direct),
|
|
||||||
sub_group,
|
|
||||||
member_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type RemoveMemberSignal struct {
|
|
||||||
SignalHeader
|
|
||||||
SubGroup string `gv:"sub_group"`
|
|
||||||
MemberID NodeID `gv:"member_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (signal RemoveMemberSignal) Permission() Tree {
|
|
||||||
return Tree{
|
|
||||||
SerializedType(SignalTypeFor[RemoveMemberSignal]()): {
|
|
||||||
Hash("sub_group", signal.SubGroup): nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSignal {
|
|
||||||
return &RemoveMemberSignal{
|
|
||||||
NewSignalHeader(Direct),
|
|
||||||
sub_group,
|
|
||||||
member_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultGroupPolicy = NewAllNodesPolicy(Tree{
|
|
||||||
SerializedType(SignalTypeFor[ReadSignal]()): {
|
|
||||||
SerializedType(ExtTypeFor[GroupExt]()): {
|
|
||||||
SerializedType(GetFieldTag("sub_groups")): nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
type SubGroup struct {
|
|
||||||
Members []NodeID
|
|
||||||
Permissions Tree
|
|
||||||
}
|
|
||||||
|
|
||||||
type MemberOfPolicy struct {
|
|
||||||
PolicyHeader
|
|
||||||
Groups map[NodeID]map[string]Tree
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMemberOfPolicy(groups map[NodeID]map[string]Tree) MemberOfPolicy {
|
|
||||||
return MemberOfPolicy{
|
|
||||||
PolicyHeader: NewPolicyHeader(),
|
|
||||||
Groups: groups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
|
||||||
sig, ok := signal.(*ReadResultSignal)
|
|
||||||
if ok == false {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions)
|
|
||||||
|
|
||||||
group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()]
|
|
||||||
if ok == false {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
sub_groups_ser, ok := group_ext_data["sub_groups"]
|
|
||||||
if ok == false {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("group", "Type deserialize error: %s", err)
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("group", "Value deserialize error: %s", err)
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
ext_sub_groups, ok := sub_groups_if.Interface().(map[string][]NodeID)
|
|
||||||
if ok == false {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
group, exists := policy.Groups[sig.NodeID]
|
|
||||||
if exists == false {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
for sub_group_name, permissions := range(group) {
|
|
||||||
ext_sub_group, exists := ext_sub_groups[sub_group_name]
|
|
||||||
if exists == true {
|
|
||||||
for _, member_id := range(ext_sub_group) {
|
|
||||||
if member_id == current.Principal {
|
|
||||||
if permissions.Allows(current.Action) == Allow {
|
|
||||||
return Allow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send a read signal to Group to check if principal_id is a member of it
|
|
||||||
func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
|
|
||||||
var messages Messages = nil
|
|
||||||
for group_id, sub_groups := range(policy.Groups) {
|
|
||||||
if group_id == node.ID {
|
|
||||||
ext, err := GetExt[GroupExt](node)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err)
|
|
||||||
} else {
|
|
||||||
for sub_group_name, permission := range(sub_groups) {
|
|
||||||
ext_sub_group, exists := ext.SubGroups[sub_group_name]
|
|
||||||
if exists == true {
|
|
||||||
for _, member := range(ext_sub_group) {
|
|
||||||
if member == principal_id {
|
|
||||||
if permission.Allows(action) == Allow {
|
|
||||||
return nil, Allow
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Send the read request to the group so that ContinueAllows can parse the response to check membership
|
|
||||||
messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{
|
|
||||||
ExtTypeFor[GroupExt](): {"sub_groups"},
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(messages) > 0 {
|
|
||||||
return messages, Pending
|
|
||||||
} else {
|
|
||||||
return nil, Deny
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type GroupExt struct {
|
|
||||||
SubGroups map[string][]NodeID `gv:"sub_groups"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGroupExt(sub_groups map[string][]NodeID) *GroupExt {
|
|
||||||
if sub_groups == nil {
|
|
||||||
sub_groups = map[string][]NodeID{}
|
|
||||||
}
|
|
||||||
return &GroupExt{
|
|
||||||
SubGroups: sub_groups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *GroupExt) Load(ctx *Context, node *Node) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *GroupExt) Unload(ctx *Context, node *Node) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
|
|
||||||
var messages Messages = nil
|
|
||||||
var changes = Changes{}
|
|
||||||
|
|
||||||
switch sig := signal.(type) {
|
|
||||||
case *AddMemberSignal:
|
|
||||||
sub_group, exists := ext.SubGroups[sig.SubGroup]
|
|
||||||
if exists == false {
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
|
||||||
} else {
|
|
||||||
if slices.Contains(sub_group, sig.MemberID) == true {
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member"))
|
|
||||||
} else {
|
|
||||||
sub_group = append(sub_group, sig.MemberID)
|
|
||||||
ext.SubGroups[sig.SubGroup] = sub_group
|
|
||||||
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
|
||||||
changes.Add("sub_groups")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case *RemoveMemberSignal:
|
|
||||||
sub_group, exists := ext.SubGroups[sig.SubGroup]
|
|
||||||
if exists == false {
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
|
||||||
} else {
|
|
||||||
idx := slices.Index(sub_group, sig.MemberID)
|
|
||||||
if idx == -1 {
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member"))
|
|
||||||
} else {
|
|
||||||
sub_group = slices.Delete(sub_group, idx, idx+1)
|
|
||||||
ext.SubGroups[sig.SubGroup] = sub_group
|
|
||||||
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
|
||||||
changes.Add("sub_groups")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case *AddSubGroupSignal:
|
|
||||||
_, exists := ext.SubGroups[sig.Name]
|
|
||||||
if exists == true {
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_subgroup"))
|
|
||||||
} else {
|
|
||||||
ext.SubGroups[sig.Name] = []NodeID{}
|
|
||||||
|
|
||||||
changes.Add("sub_groups")
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
|
||||||
}
|
|
||||||
case *RemoveSubGroupSignal:
|
|
||||||
_, exists := ext.SubGroups[sig.Name]
|
|
||||||
if exists == false {
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
|
||||||
} else {
|
|
||||||
delete(ext.SubGroups, sig.Name)
|
|
||||||
|
|
||||||
changes.Add("sub_groups")
|
|
||||||
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages, changes
|
|
||||||
}
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGroupAdd(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"listener", "test"})
|
|
||||||
|
|
||||||
group_listener := NewListenerExt(10)
|
|
||||||
group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil))
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
add_subgroup_signal := NewAddSubGroupSignal("test_group")
|
|
||||||
messages := Messages{}
|
|
||||||
messages = messages.Add(ctx, group.ID, group, nil, add_subgroup_signal)
|
|
||||||
fatalErr(t, ctx.Send(messages))
|
|
||||||
|
|
||||||
resp_1, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_subgroup_signal.Id)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
error_1, is_error := resp_1.(*ErrorSignal)
|
|
||||||
if is_error {
|
|
||||||
t.Fatalf("Error returned: %s", error_1.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
user_id := RandID()
|
|
||||||
add_member_signal := NewAddMemberSignal("test_group", user_id)
|
|
||||||
|
|
||||||
messages = Messages{}
|
|
||||||
messages = messages.Add(ctx, group.ID, group, nil, add_member_signal)
|
|
||||||
fatalErr(t, ctx.Send(messages))
|
|
||||||
|
|
||||||
resp_2, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
error_2, is_error := resp_2.(*ErrorSignal)
|
|
||||||
if is_error {
|
|
||||||
t.Fatalf("Error returned: %s", error_2.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
read_signal := NewReadSignal(map[ExtType][]string{
|
|
||||||
ExtTypeFor[GroupExt](): {"sub_groups"},
|
|
||||||
})
|
|
||||||
|
|
||||||
messages = Messages{}
|
|
||||||
messages = messages.Add(ctx, group.ID, group, nil, read_signal)
|
|
||||||
fatalErr(t, ctx.Send(messages))
|
|
||||||
|
|
||||||
response, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id)
|
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
read_response := response.(*ReadResultSignal)
|
|
||||||
|
|
||||||
sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"]
|
|
||||||
|
|
||||||
sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack)
|
|
||||||
fatalErr(t, err)
|
|
||||||
if len(remaining_types) > 0 {
|
|
||||||
t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types))
|
|
||||||
}
|
|
||||||
|
|
||||||
sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data)
|
|
||||||
fatalErr(t, err)
|
|
||||||
if len(remaining) > 0 {
|
|
||||||
t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types))
|
|
||||||
}
|
|
||||||
|
|
||||||
sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID)
|
|
||||||
|
|
||||||
if ok != true {
|
|
||||||
t.Fatalf("sub_groups wrong type %s", sub_groups_value.Type())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(sub_groups) != 1 {
|
|
||||||
t.Fatalf("sub_groups wrong length %d", len(sub_groups))
|
|
||||||
}
|
|
||||||
|
|
||||||
test_subgroup, exists := sub_groups["test_group"]
|
|
||||||
if exists == false {
|
|
||||||
t.Fatal("test_group not in subgroups")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(test_subgroup) != 1 {
|
|
||||||
t.Fatalf("test_group wrong size %d/1", len(test_subgroup))
|
|
||||||
}
|
|
||||||
|
|
||||||
if test_subgroup[0] != user_id {
|
|
||||||
t.Fatalf("sub_groups wrong value %s", test_subgroup[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Read Response: %+v", read_response)
|
|
||||||
}
|
|
@ -1,114 +1,11 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
type SendMsg struct {
|
||||||
"time"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthInfo struct {
|
|
||||||
// The Node that issued the authorization
|
|
||||||
Identity ed25519.PublicKey
|
|
||||||
|
|
||||||
// Time the authorization was generated
|
|
||||||
Start time.Time
|
|
||||||
|
|
||||||
// Signature of Start + Principal with Identity private key
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthorizationToken struct {
|
|
||||||
AuthInfo
|
|
||||||
|
|
||||||
// The private key generated by the client, encrypted with the servers public key
|
|
||||||
KeyEncrypted []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClientAuthorization struct {
|
|
||||||
AuthInfo
|
|
||||||
|
|
||||||
// The private key generated by the client
|
|
||||||
Key ed25519.PrivateKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authorization structs can be passed in a message that originated from a different node than the sender
|
|
||||||
type Authorization struct {
|
|
||||||
AuthInfo
|
|
||||||
|
|
||||||
// The public key generated for this authorization
|
|
||||||
Key ed25519.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Dest NodeID
|
Dest NodeID
|
||||||
Source ed25519.PublicKey
|
|
||||||
|
|
||||||
Authorization *Authorization
|
|
||||||
|
|
||||||
Signal Signal
|
Signal Signal
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type Messages []*Message
|
|
||||||
func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages {
|
|
||||||
msg, err := NewMessage(ctx, dest, source, authorization, signal)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
} else {
|
|
||||||
msgs = append(msgs, msg)
|
|
||||||
}
|
|
||||||
return msgs
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMessages(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signals... Signal) Messages {
|
|
||||||
messages := Messages{}
|
|
||||||
for _, signal := range(signals) {
|
|
||||||
messages = messages.Add(ctx, dest, source, authorization, signal)
|
|
||||||
}
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) {
|
|
||||||
signal_ser, err := SerializeAny(ctx, signal)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
signal_chunks, err := signal_ser.Chunks()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dest_ser, err := dest.MarshalBinary()
|
type RecvMsg struct {
|
||||||
if err != nil {
|
Source NodeID
|
||||||
return nil, err
|
Signal Signal
|
||||||
}
|
|
||||||
source_ser, err := source.ID.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sig_data := append(dest_ser, source_ser...)
|
|
||||||
sig_data = append(sig_data, signal_chunks.Slice()...)
|
|
||||||
var message_auth *Authorization = nil
|
|
||||||
if authorization != nil {
|
|
||||||
sig_data = append(sig_data, authorization.Signature...)
|
|
||||||
message_auth = &Authorization{
|
|
||||||
authorization.AuthInfo,
|
|
||||||
authorization.Key.Public().(ed25519.PublicKey),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Message{
|
|
||||||
Dest: dest,
|
|
||||||
Source: source.Key.Public().(ed25519.PublicKey),
|
|
||||||
Authorization: message_auth,
|
|
||||||
Signal: signal,
|
|
||||||
Signature: sig,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
@ -1,139 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Policy interface {
|
|
||||||
Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
|
|
||||||
ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
|
|
||||||
ID() uuid.UUID
|
|
||||||
}
|
|
||||||
|
|
||||||
type PolicyHeader struct {
|
|
||||||
UUID uuid.UUID `gv:"uuid"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (header PolicyHeader) ID() uuid.UUID {
|
|
||||||
return header.UUID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
|
|
||||||
return nil, policy.Rules.Allows(action)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
|
|
||||||
for id, actions := range(policy.NodeRules) {
|
|
||||||
if id != principal_id {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, actions.Allows(action)
|
|
||||||
}
|
|
||||||
return nil, Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
|
|
||||||
func CopyTree(tree Tree) Tree {
|
|
||||||
if tree == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := Tree{}
|
|
||||||
for name, sub := range(tree) {
|
|
||||||
ret[name] = CopyTree(sub)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func MergeTrees(first Tree, second Tree) Tree {
|
|
||||||
if first == nil || second == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := CopyTree(first)
|
|
||||||
for name, sub := range(second) {
|
|
||||||
current, exists := ret[name]
|
|
||||||
if exists == true {
|
|
||||||
ret[name] = MergeTrees(current, sub)
|
|
||||||
} else {
|
|
||||||
ret[name] = CopyTree(sub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tree map[SerializedType]Tree
|
|
||||||
|
|
||||||
func (rule Tree) Allows(action Tree) RuleResult {
|
|
||||||
// If the current rule is nil, it's a wildcard and any action being processed is allowed
|
|
||||||
if rule == nil {
|
|
||||||
return Allow
|
|
||||||
// If the rule isn't "allow all" but the action is "request all", deny
|
|
||||||
} else if action == nil {
|
|
||||||
return Deny
|
|
||||||
// If the current action has no children, it's allowed
|
|
||||||
} else if len(action) == 0 {
|
|
||||||
return Allow
|
|
||||||
// If the current rule has no children but the action goes further, it's not allowed
|
|
||||||
} else if len(rule) == 0 {
|
|
||||||
return Deny
|
|
||||||
// If the current rule and action have children, all the children of action must be allowed by rule
|
|
||||||
} else {
|
|
||||||
for sub, subtree := range(action) {
|
|
||||||
subrule, exists := rule[sub]
|
|
||||||
if exists == false {
|
|
||||||
return Deny
|
|
||||||
} else if subrule.Allows(subtree) == Deny {
|
|
||||||
return Deny
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Allow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPolicyHeader() PolicyHeader {
|
|
||||||
return PolicyHeader{
|
|
||||||
UUID: uuid.New(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy {
|
|
||||||
if node_actions == nil {
|
|
||||||
node_actions = map[NodeID]Tree{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return PerNodePolicy{
|
|
||||||
PolicyHeader: NewPolicyHeader(),
|
|
||||||
NodeRules: node_actions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type PerNodePolicy struct {
|
|
||||||
PolicyHeader
|
|
||||||
NodeRules map[NodeID]Tree `gv:"node_rules"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
|
|
||||||
return AllNodesPolicy{
|
|
||||||
PolicyHeader: NewPolicyHeader(),
|
|
||||||
Rules: rules,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type AllNodesPolicy struct {
|
|
||||||
PolicyHeader
|
|
||||||
Rules Tree `gv:"rules"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultPolicy = NewAllNodesPolicy(Tree{
|
|
||||||
SerializedType(SignalTypeFor[ResponseSignal]()): nil,
|
|
||||||
SerializedType(SignalTypeFor[StatusSignal]()): nil,
|
|
||||||
})
|
|
File diff suppressed because it is too large
Load Diff
@ -1,247 +1,150 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"reflect"
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSerializeTest(t *testing.T) {
|
func testTypeStack[T any](t *testing.T, ctx *Context) {
|
||||||
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
|
reflect_type := reflect.TypeFor[T]()
|
||||||
testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}})
|
stack, err := TypeStack(ctx, reflect_type)
|
||||||
testSerialize(t, ctx, map[NodeID]ReqState{
|
|
||||||
RandID(): Locked,
|
|
||||||
RandID(): Unlocked,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSerializeBasic(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
|
|
||||||
testSerializeComparable[bool](t, ctx, true)
|
|
||||||
|
|
||||||
type bool_wrapped bool
|
|
||||||
err := RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped])
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
testSerializeComparable[bool_wrapped](t, ctx, true)
|
|
||||||
|
|
||||||
testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false})
|
|
||||||
testSerializeComparable[string](t, ctx, "test")
|
|
||||||
testSerializeComparable[float32](t, ctx, 0.05)
|
|
||||||
testSerializeComparable[float64](t, ctx, 0.05)
|
|
||||||
testSerializeComparable[uint](t, ctx, uint(1234))
|
|
||||||
testSerializeComparable[uint8] (t, ctx, uint8(123))
|
|
||||||
testSerializeComparable[uint16](t, ctx, uint16(1234))
|
|
||||||
testSerializeComparable[uint32](t, ctx, uint32(12345))
|
|
||||||
testSerializeComparable[uint64](t, ctx, uint64(123456))
|
|
||||||
testSerializeComparable[int](t, ctx, 1234)
|
|
||||||
testSerializeComparable[int8] (t, ctx, int8(-123))
|
|
||||||
testSerializeComparable[int16](t, ctx, int16(-1234))
|
|
||||||
testSerializeComparable[int32](t, ctx, int32(-12345))
|
|
||||||
testSerializeComparable[int64](t, ctx, int64(-123456))
|
|
||||||
testSerializeComparable[time.Duration](t, ctx, time.Duration(100))
|
|
||||||
testSerializeComparable[time.Time](t, ctx, time.Now().Truncate(0))
|
|
||||||
testSerializeSlice[[]int](t, ctx, []int{123, 456, 789, 101112})
|
|
||||||
testSerializeSlice[[]int](t, ctx, ([]int)(nil))
|
|
||||||
testSerializeSliceSlice[[][]int](t, ctx, [][]int{{123, 456, 789, 101112}, {3253, 2341, 735, 212}, {123, 51}, nil})
|
|
||||||
testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, map[int8]map[*int8]string{})
|
|
||||||
testSerialize(t, ctx, map[int8]time.Time{
|
|
||||||
1: time.Now(),
|
|
||||||
3: time.Now().Add(time.Second),
|
|
||||||
0: time.Now().Add(time.Second*2),
|
|
||||||
4: time.Now().Add(time.Second*3),
|
|
||||||
})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, Tree{
|
|
||||||
SerializedTypeFor[NodeType](): nil,
|
|
||||||
SerializedTypeFor[SerializedType](): {
|
|
||||||
SerializedTypeFor[NodeType](): Tree{},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
var i interface{} = nil
|
|
||||||
testSerialize(t, ctx, i)
|
|
||||||
|
|
||||||
testSerializeMap(t, ctx, map[int8]interface{}{
|
|
||||||
0: "abcd",
|
|
||||||
1: uint32(12345678),
|
|
||||||
2: i,
|
|
||||||
3: 123,
|
|
||||||
})
|
|
||||||
|
|
||||||
testSerializeMap(t, ctx, map[int8]int32{
|
|
||||||
0: 1234,
|
|
||||||
2: 5678,
|
|
||||||
4: 9101,
|
|
||||||
6: 1121,
|
|
||||||
})
|
|
||||||
|
|
||||||
type test_struct struct {
|
|
||||||
Int int `gv:"int"`
|
|
||||||
String string `gv:"string"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err = RegisterStruct[test_struct](ctx)
|
ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
|
||||||
fatalErr(t, err)
|
|
||||||
|
|
||||||
testSerialize(t, ctx, test_struct{
|
unwrapped_type, rest, err := UnwrapStack(ctx, stack)
|
||||||
12345,
|
|
||||||
"test_string",
|
|
||||||
})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, Tree{
|
|
||||||
SerializedKindFor(reflect.Map): nil,
|
|
||||||
SerializedKindFor(reflect.String): nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, Tree{
|
|
||||||
SerializedTypeFor[Tree](): nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, Tree{
|
|
||||||
SerializedTypeFor[Tree](): {
|
|
||||||
SerializedTypeFor[error](): Tree{},
|
|
||||||
SerializedKindFor(reflect.Map): nil,
|
|
||||||
},
|
|
||||||
SerializedKindFor(reflect.String): nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
type test_slice []string
|
|
||||||
err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice)
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"})
|
if len(rest) != 0 {
|
||||||
testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"})
|
t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
|
||||||
}
|
}
|
||||||
|
|
||||||
type test struct {
|
if unwrapped_type != reflect_type {
|
||||||
Int int `gv:"int"`
|
t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
|
||||||
Str string `gv:"string"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s test) String() string {
|
ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
|
||||||
return fmt.Sprintf("%d:%s", s.Int, s.Str)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSerializeStructTags(t *testing.T) {
|
func TestSerializeTypes(t *testing.T) {
|
||||||
ctx := logTestContext(t, []string{"test"})
|
ctx := logTestContext(t, []string{"test"})
|
||||||
|
|
||||||
err := RegisterStruct[test](ctx)
|
testTypeStack[int](t, ctx)
|
||||||
|
testTypeStack[map[int]string](t, ctx)
|
||||||
|
testTypeStack[string](t, ctx)
|
||||||
|
testTypeStack[*string](t, ctx)
|
||||||
|
testTypeStack[*map[string]*map[*string]int](t, ctx)
|
||||||
|
testTypeStack[[5]int](t, ctx)
|
||||||
|
testTypeStack[uuid.UUID](t, ctx)
|
||||||
|
testTypeStack[NodeID](t, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
|
||||||
|
serialized, err := Serialize(ctx, value)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
test_int := 10
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
||||||
test_string := "test"
|
|
||||||
|
|
||||||
ret := testSerialize(t, ctx, test{
|
|
||||||
test_int,
|
|
||||||
test_string,
|
|
||||||
})
|
|
||||||
if ret.Int != test_int {
|
|
||||||
t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int)
|
|
||||||
} else if ret.Str != test_string {
|
|
||||||
t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string)
|
|
||||||
}
|
|
||||||
|
|
||||||
testSerialize(t, ctx, []test{
|
deserialized, err := Deserialize[T](ctx, serialized)
|
||||||
{
|
fatalErr(t, err)
|
||||||
test_int,
|
|
||||||
test_string,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
test_int * 2,
|
|
||||||
fmt.Sprintf("%s%s", test_string, test_string),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
test_int * 4,
|
|
||||||
fmt.Sprintf("%s%s%s", test_string, test_string, test_string),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) {
|
if value != deserialized {
|
||||||
v := testSerialize(t, ctx, val)
|
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
||||||
for key, value := range(val) {
|
|
||||||
recreated, exists := v[key]
|
|
||||||
if exists == false {
|
|
||||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
|
||||||
} else if recreated != value {
|
|
||||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(v) != len(val) {
|
|
||||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, val S) {
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||||
v := testSerialize(t, ctx, val)
|
|
||||||
for i, original := range(val) {
|
|
||||||
if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) {
|
|
||||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
|
||||||
}
|
|
||||||
for j, o := range(original) {
|
|
||||||
if v[i][j] != o {
|
|
||||||
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerializeSlice[S []T, T comparable](t *testing.T, ctx *Context, val S) {
|
func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
|
||||||
v := testSerialize(t, ctx, val)
|
serialized, err := Serialize(ctx, value)
|
||||||
for i, original := range(val) {
|
fatalErr(t, err)
|
||||||
if v[i] != original {
|
|
||||||
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
|
||||||
}
|
|
||||||
|
deserialized, err := Deserialize[L](ctx, serialized)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
for i, item := range(value) {
|
||||||
|
if item != deserialized[i] {
|
||||||
|
t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) {
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||||
v := testSerialize(t, ctx, val)
|
|
||||||
if v != val {
|
|
||||||
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerialize[T any](t *testing.T, ctx *Context, val T) T {
|
func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
|
||||||
value := reflect.ValueOf(&val).Elem()
|
serialized, err := Serialize(ctx, value)
|
||||||
type_stack, err := SerializeType(ctx, value.Type())
|
|
||||||
chunks, err := SerializeValue(ctx, value)
|
|
||||||
value_serialized := SerializedValue{type_stack, chunks.Slice()}
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
ctx.Log.Logf("test", "Serialized %+v to %+v(%d)", val, value_serialized, len(value_serialized.Data))
|
|
||||||
|
|
||||||
value_chunks, err := value_serialized.Chunks()
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
|
||||||
fatalErr(t, err)
|
|
||||||
ctx.Log.Logf("test", "Binary: %+v", value_chunks.Slice())
|
|
||||||
|
|
||||||
val_parsed, remaining_parse, err := ParseSerializedValue(value_chunks.Slice())
|
deserialized, err := Deserialize[P](ctx, serialized)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
ctx.Log.Logf("test", "Parsed: %+v", val_parsed)
|
|
||||||
|
|
||||||
if len(remaining_parse) != 0 {
|
if value == nil && deserialized == nil {
|
||||||
t.Fatal("Data remaining after deserializing value")
|
ctx.Log.Logf("test", "Deserialized nil")
|
||||||
|
} else if value == nil {
|
||||||
|
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
|
||||||
|
} else if deserialized == nil {
|
||||||
|
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
|
||||||
|
} else if *deserialized != *value {
|
||||||
|
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
||||||
|
} else {
|
||||||
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack)
|
func testSerialize[T any](t *testing.T, ctx *Context, value T) {
|
||||||
deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data)
|
serialized, err := Serialize(ctx, value)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
if len(remaining_deserialize) != 0 {
|
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
||||||
t.Fatal("Data remaining after deserializing value")
|
|
||||||
} else if len(remaining_types) != 0 {
|
deserialized, err := Deserialize[T](ctx, serialized)
|
||||||
t.Fatal("TypeStack remaining after deserializing value")
|
fatalErr(t, err)
|
||||||
} else if val_type != value.Type() {
|
|
||||||
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val)))
|
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
||||||
} else if deserialized_value.CanConvert(val_type) == false {
|
|
||||||
t.Fatal("DeserializeValue returned value that can't convert to original value")
|
|
||||||
}
|
|
||||||
ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface())
|
|
||||||
if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil {
|
|
||||||
var zero T
|
|
||||||
return zero
|
|
||||||
}
|
}
|
||||||
return deserialized_value.Interface().(T)
|
|
||||||
|
func TestSerializeValues(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test"})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, Extension(NewLockableExt(nil)))
|
||||||
|
|
||||||
|
testSerializeCompare[int8](t, ctx, -64)
|
||||||
|
testSerializeCompare[int16](t, ctx, -64)
|
||||||
|
testSerializeCompare[int32](t, ctx, -64)
|
||||||
|
testSerializeCompare[int64](t, ctx, -64)
|
||||||
|
testSerializeCompare[int](t, ctx, -64)
|
||||||
|
|
||||||
|
testSerializeCompare[uint8](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint16](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint32](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint64](t, ctx, 64)
|
||||||
|
testSerializeCompare[uint](t, ctx, 64)
|
||||||
|
|
||||||
|
testSerializeCompare[string](t, ctx, "test")
|
||||||
|
|
||||||
|
a := 12
|
||||||
|
testSerializePointer[*int](t, ctx, &a)
|
||||||
|
|
||||||
|
b := "test"
|
||||||
|
testSerializePointer[*string](t, ctx, nil)
|
||||||
|
testSerializePointer[*string](t, ctx, &b)
|
||||||
|
|
||||||
|
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
|
||||||
|
|
||||||
|
testSerializeCompare[bool](t, ctx, true)
|
||||||
|
testSerializeCompare[bool](t, ctx, false)
|
||||||
|
testSerializeCompare[int](t, ctx, -1)
|
||||||
|
testSerializeCompare[uint](t, ctx, 1)
|
||||||
|
testSerializeCompare[NodeID](t, ctx, RandID())
|
||||||
|
testSerializeCompare[*int](t, ctx, nil)
|
||||||
|
testSerializeCompare(t, ctx, "string")
|
||||||
|
|
||||||
|
node, err := NewNode(ctx, nil, "Base", 100)
|
||||||
|
fatalErr(t, err)
|
||||||
|
testSerialize(t, ctx, node)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue