Compare commits
No commits in common. "c591fa5ace9964e9db626c0d85aee3052ec55e08" and "e5776e0a1450b4b1f4c7734bbd596a4ae10abd5d" have entirely different histories.
c591fa5ace
...
e5776e0a14
@ -0,0 +1,233 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ACLSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
Principal NodeID `gv:"principal"`
|
||||||
|
Action Tree `gv:"tree"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLSignal(principal NodeID, action Tree) *ACLSignal {
|
||||||
|
return &ACLSignal{
|
||||||
|
SignalHeader: NewSignalHeader(Direct),
|
||||||
|
Principal: principal,
|
||||||
|
Action: action,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultACLPolicy = NewAllNodesPolicy(Tree{
|
||||||
|
SerializedType(SignalTypeFor[ACLSignal]()): nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
func (signal ACLSignal) Permission() Tree {
|
||||||
|
return Tree{
|
||||||
|
SerializedType(SignalTypeFor[ACLSignal]()): nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ACLExt struct {
|
||||||
|
Policies []Policy `gv:"policies"`
|
||||||
|
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
|
||||||
|
Pending map[uuid.UUID]PendingACLSignal `gv:"pending"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLExt(policies []Policy) *ACLExt {
|
||||||
|
return &ACLExt{
|
||||||
|
Policies: policies,
|
||||||
|
PendingACLs: map[uuid.UUID]PendingACL{},
|
||||||
|
Pending: map[uuid.UUID]PendingACLSignal{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLExt) Load(ctx *Context, node *Node) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLExt) Unload(ctx *Context, node *Node) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
|
||||||
|
response, is_response := signal.(ResponseSignal)
|
||||||
|
if is_response == true {
|
||||||
|
var messages Messages = nil
|
||||||
|
var changes = Changes{}
|
||||||
|
info, waiting := ext.Pending[response.ResponseID()]
|
||||||
|
if waiting == true {
|
||||||
|
changes.Add("pending")
|
||||||
|
delete(ext.Pending, response.ResponseID())
|
||||||
|
if response.ID() != info.Timeout {
|
||||||
|
err := node.DequeueSignal(info.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("acl", "timeout dequeue error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acl_info, found := ext.PendingACLs[info.ID]
|
||||||
|
if found == true {
|
||||||
|
acl_info.Counter -= 1
|
||||||
|
acl_info.Responses = append(acl_info.Responses, response)
|
||||||
|
|
||||||
|
policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool {
|
||||||
|
return policy.ID() == info.Policy
|
||||||
|
})
|
||||||
|
|
||||||
|
if policy_index == -1 {
|
||||||
|
ctx.Log.Logf("acl", "pending signal for nonexistent policy")
|
||||||
|
delete(ext.PendingACLs, info.ID)
|
||||||
|
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
|
||||||
|
changes.Add("pending_acls")
|
||||||
|
delete(ext.PendingACLs, info.ID)
|
||||||
|
ctx.Log.Logf("acl", "Request delayed allow")
|
||||||
|
messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID))
|
||||||
|
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||||
|
}
|
||||||
|
} else if acl_info.Counter == 0 {
|
||||||
|
changes.Add("pending_acls")
|
||||||
|
delete(ext.PendingACLs, info.ID)
|
||||||
|
ctx.Log.Logf("acl", "Request delayed deny")
|
||||||
|
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied"))
|
||||||
|
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
node.PendingACLs[info.ID] = acl_info
|
||||||
|
changes.Add("pending_acls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return messages, changes
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages Messages = nil
|
||||||
|
var changes = Changes{}
|
||||||
|
|
||||||
|
switch sig := signal.(type) {
|
||||||
|
case *ACLSignal:
|
||||||
|
var acl_messages map[uuid.UUID]Messages = nil
|
||||||
|
denied := true
|
||||||
|
for _, policy := range(ext.Policies) {
|
||||||
|
policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node)
|
||||||
|
if result == Allow {
|
||||||
|
denied = false
|
||||||
|
break
|
||||||
|
} else if result == Pending {
|
||||||
|
if len(policy_messages) == 0 {
|
||||||
|
ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID())
|
||||||
|
continue
|
||||||
|
} else if acl_messages == nil {
|
||||||
|
acl_messages = map[uuid.UUID]Messages{}
|
||||||
|
denied = false
|
||||||
|
}
|
||||||
|
|
||||||
|
acl_messages[policy.ID()] = policy_messages
|
||||||
|
ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if denied == true {
|
||||||
|
ctx.Log.Logf("acl", "Request denied")
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied"))
|
||||||
|
} else if acl_messages != nil {
|
||||||
|
ctx.Log.Logf("acl", "Request pending")
|
||||||
|
changes.Add("pending")
|
||||||
|
total_messages := 0
|
||||||
|
// TODO: reasonable timeout/configurable
|
||||||
|
timeout_time := time.Now().Add(time.Second)
|
||||||
|
for policy_id, policy_messages := range(acl_messages) {
|
||||||
|
total_messages += len(policy_messages)
|
||||||
|
for _, message := range(policy_messages) {
|
||||||
|
timeout_signal := NewTimeoutSignal(message.Signal.ID())
|
||||||
|
ext.Pending[message.Signal.ID()] = PendingACLSignal{
|
||||||
|
Policy: policy_id,
|
||||||
|
Timeout: timeout_signal.Id,
|
||||||
|
ID: sig.Id,
|
||||||
|
}
|
||||||
|
node.QueueSignal(timeout_time, timeout_signal)
|
||||||
|
messages = append(messages, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acl_timeout := NewACLTimeoutSignal(sig.Id)
|
||||||
|
node.QueueSignal(timeout_time, acl_timeout)
|
||||||
|
ext.PendingACLs[sig.Id] = PendingACL{
|
||||||
|
Counter: total_messages,
|
||||||
|
Responses: []ResponseSignal{},
|
||||||
|
TimeoutID: acl_timeout.Id,
|
||||||
|
Action: sig.Action,
|
||||||
|
Principal: sig.Principal,
|
||||||
|
|
||||||
|
Source: source,
|
||||||
|
Signal: signal,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.Log.Logf("acl", "Request allowed")
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||||
|
}
|
||||||
|
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
|
||||||
|
case *ACLTimeoutSignal:
|
||||||
|
acl_info, exists := ext.PendingACLs[sig.ReqID]
|
||||||
|
if exists == true {
|
||||||
|
delete(ext.PendingACLs, sig.ReqID)
|
||||||
|
changes.Add("pending_acls")
|
||||||
|
ctx.Log.Logf("acl", "Request timeout deny")
|
||||||
|
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout"))
|
||||||
|
err := node.DequeueSignal(acl_info.TimeoutID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl")
|
||||||
|
}
|
||||||
|
// Delete from PendingACLs
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, changes
|
||||||
|
}
|
||||||
|
|
||||||
|
type ACLProxyPolicy struct {
|
||||||
|
PolicyHeader
|
||||||
|
Proxies []NodeID `gv:"proxies"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy {
|
||||||
|
return ACLProxyPolicy{
|
||||||
|
NewPolicyHeader(),
|
||||||
|
proxies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
|
||||||
|
if len(policy.Proxies) == 0 {
|
||||||
|
return nil, Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := Messages{}
|
||||||
|
for _, proxy := range(policy.Proxies) {
|
||||||
|
messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action))
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||||
|
_, is_success := signal.(*SuccessSignal)
|
||||||
|
if is_success == true {
|
||||||
|
return Allow
|
||||||
|
}
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,141 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"reflect"
|
||||||
|
"runtime/debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){
|
||||||
|
response_casted, cast_ok := signal.(S)
|
||||||
|
if cast_ok == false {
|
||||||
|
error_signal, is_error := signal.(*ErrorSignal)
|
||||||
|
if is_error {
|
||||||
|
t.Log(string(debug.Stack()))
|
||||||
|
t.Fatal(error_signal.Error)
|
||||||
|
}
|
||||||
|
t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal))
|
||||||
|
}
|
||||||
|
|
||||||
|
check(response_casted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){
|
||||||
|
acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
acl_signal := NewACLSignal(listener.ID, action)
|
||||||
|
response, _ := testSend(t, ctx, acl_signal, listener, acl_node)
|
||||||
|
|
||||||
|
checkSignal(t, response, check)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){
|
||||||
|
return func(response *ErrorSignal) {
|
||||||
|
if response.Error != error_string {
|
||||||
|
t.Fatalf("Wrong error: %s", response.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSuccess(*SuccessSignal){}
|
||||||
|
|
||||||
|
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
|
||||||
|
source_listener, err := GetExt[ListenerExt](source)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
messages := Messages{}
|
||||||
|
messages = messages.Add(ctx, destination.ID, source, nil, signal)
|
||||||
|
fatalErr(t, ctx.Send(messages))
|
||||||
|
|
||||||
|
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
return response, signals
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestACLBasic(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"})
|
||||||
|
|
||||||
|
listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "testing fail")
|
||||||
|
testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied"))
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "testing allow all")
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess)
|
||||||
|
|
||||||
|
group, err := NewNode(ctx, nil, "Base", 100, []Policy{
|
||||||
|
DefaultGroupPolicy,
|
||||||
|
NewPerNodePolicy(map[NodeID]Tree{
|
||||||
|
listener.ID: {
|
||||||
|
SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil,
|
||||||
|
SerializedType(SignalTypeFor[AddMemberSignal]()): nil,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}, NewGroupExt(nil))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "testing empty groups")
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{
|
||||||
|
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
||||||
|
group.ID: {
|
||||||
|
"test_group": nil,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}, testErrorSignal(t, "acl_denied"))
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "testing adding group")
|
||||||
|
add_subgroup_signal := NewAddSubGroupSignal("test_group")
|
||||||
|
add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group)
|
||||||
|
checkSignal(t, add_subgroup_response, testSuccess)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "testing adding member")
|
||||||
|
add_member_signal := NewAddMemberSignal("test_group", listener.ID)
|
||||||
|
add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group)
|
||||||
|
checkSignal(t, add_member_response, testSuccess)
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "testing group membership")
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{
|
||||||
|
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
||||||
|
group.ID: {
|
||||||
|
"test_group": nil,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}, testSuccess)
|
||||||
|
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{
|
||||||
|
NewACLProxyPolicy(nil),
|
||||||
|
}, testErrorSignal(t, "acl_denied"))
|
||||||
|
|
||||||
|
acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{
|
||||||
|
NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}),
|
||||||
|
}, testErrorSignal(t, "acl_denied"))
|
||||||
|
|
||||||
|
acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)}))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{
|
||||||
|
NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}),
|
||||||
|
}, testSuccess)
|
||||||
|
|
||||||
|
acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy},
|
||||||
|
NewACLExt([]Policy{
|
||||||
|
NewMemberOfPolicy(map[NodeID]map[string]Tree{
|
||||||
|
group.ID: {
|
||||||
|
"test_group": nil,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
testSendACL(t, ctx, listener, nil, []Policy{
|
||||||
|
NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}),
|
||||||
|
}, testSuccess)
|
||||||
|
}
|
@ -1,43 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
badger "github.com/dgraph-io/badger/v3"
|
|
||||||
gv "github.com/mekkanized/graphvent"
|
|
||||||
)
|
|
||||||
|
|
||||||
func check(err error) {
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "signal"}))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
listener_ext := gv.NewListenerExt(1000)
|
|
||||||
|
|
||||||
n1, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt(nil))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
n2, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
_, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID}))
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
for true {
|
|
||||||
select {
|
|
||||||
case message := <- listener_ext.Chan:
|
|
||||||
fmt.Printf("Listener Message: %+v\n", message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
@ -1,195 +0,0 @@
|
|||||||
package graphvent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
badger "github.com/dgraph-io/badger/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
func WriteNodeInit(ctx *Context, node *Node) error {
|
|
||||||
if node == nil {
|
|
||||||
return fmt.Errorf("Cannot serialize nil *Node")
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.DB.Update(func(tx *badger.Txn) error {
|
|
||||||
// Get the base key bytes
|
|
||||||
id_ser, err := node.ID.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write Node value
|
|
||||||
node_val, err := Serialize(ctx, node)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.Set(id_ser, node_val)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write empty signal queue
|
|
||||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
|
||||||
sigqueue_val, err := Serialize(ctx, node.SignalQueue)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.Set(sigqueue_id, sigqueue_val)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write node extension list
|
|
||||||
ext_list := []ExtType{}
|
|
||||||
for ext_type := range(node.Extensions) {
|
|
||||||
ext_list = append(ext_list, ext_type)
|
|
||||||
}
|
|
||||||
ext_list_val, err := Serialize(ctx, ext_list)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
|
||||||
err = tx.Set(ext_list_id, ext_list_val)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each extension:
|
|
||||||
for ext_type, ext := range(node.Extensions) {
|
|
||||||
// Write each extension's current value
|
|
||||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
|
||||||
ext_val, err := Serialize(ctx, ext)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.Set(ext_id, ext_val)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
|
|
||||||
return ctx.DB.Update(func(tx *badger.Txn) error {
|
|
||||||
// Get the base key bytes
|
|
||||||
id_ser, err := node.ID.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write the signal queue if it needs to be written
|
|
||||||
if node.writeSignalQueue {
|
|
||||||
node.writeSignalQueue = false
|
|
||||||
|
|
||||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
|
||||||
sigqueue_val, err := Serialize(ctx, node.SignalQueue)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.Set(sigqueue_id, sigqueue_val)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each ext in changes
|
|
||||||
for ext_type := range(changes) {
|
|
||||||
// Write each ext
|
|
||||||
ext, exists := node.Extensions[ext_type]
|
|
||||||
if exists == false {
|
|
||||||
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
|
|
||||||
}
|
|
||||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
|
||||||
ext_ser, err := Serialize(ctx, ext)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Set(ext_id, ext_ser)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadNode(ctx *Context, id NodeID) (*Node, error) {
|
|
||||||
var node *Node = nil
|
|
||||||
err := ctx.DB.View(func(tx *badger.Txn) error {
|
|
||||||
// Get the base key bytes
|
|
||||||
id_ser, err := id.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the node value
|
|
||||||
node_item, err := tx.Get(id_ser)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = node_item.Value(func(val []byte) error {
|
|
||||||
node, err = Deserialize[*Node](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the signal queue
|
|
||||||
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
|
|
||||||
sigqueue_item, err := tx.Get(sigqueue_id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = sigqueue_item.Value(func(val []byte) error {
|
|
||||||
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the extension list
|
|
||||||
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
|
|
||||||
ext_list_item, err := tx.Get(ext_list_id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ext_list []ExtType
|
|
||||||
ext_list_item.Value(func(val []byte) error {
|
|
||||||
ext_list, err = Deserialize[[]ExtType](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
// Get the extensions
|
|
||||||
for _, ext_type := range(ext_list) {
|
|
||||||
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
|
|
||||||
ext_item, err := tx.Get(ext_id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ext Extension
|
|
||||||
err = ext_item.Value(func(val []byte) error {
|
|
||||||
ext, err = Deserialize[Extension](ctx, val)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
node.Extensions[ext_type] = ext
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return node, nil
|
|
||||||
}
|
|
@ -0,0 +1,145 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
graphql "github.com/graphql-go/graphql"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"reflect"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StructFieldInfo struct {
|
||||||
|
Name string
|
||||||
|
Type *TypeInfo
|
||||||
|
Index []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) {
|
||||||
|
type_info, mapped := ctx.TypeReflects[field.Type]
|
||||||
|
if mapped == false {
|
||||||
|
return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return StructFieldInfo{
|
||||||
|
Name: gv_tag,
|
||||||
|
Type: type_info,
|
||||||
|
Index: field.Index,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) {
|
||||||
|
fmt.Printf("FIELD: %+v\n", fields)
|
||||||
|
signal_value := reflect.New(signal_type)
|
||||||
|
|
||||||
|
id_field := signal_value.Elem().FieldByIndex(id_index)
|
||||||
|
id_field.Set(reflect.ValueOf(uuid.New()))
|
||||||
|
|
||||||
|
direction_field := signal_value.Elem().FieldByIndex(direction_index)
|
||||||
|
direction_field.Set(reflect.ValueOf(Direct))
|
||||||
|
|
||||||
|
for _, field := range(fields) {
|
||||||
|
arg, arg_exists := args[field.Name]
|
||||||
|
if arg_exists == false {
|
||||||
|
return nil, fmt.Errorf("No arg provided named %s", field.Name)
|
||||||
|
}
|
||||||
|
field_value := signal_value.Elem().FieldByIndex(field.Index)
|
||||||
|
if field_value.CanConvert(field.Type.Reflect) == false {
|
||||||
|
return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect)
|
||||||
|
}
|
||||||
|
value, err := field.Type.GQLValue(ctx, arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Printf("Setting %s to %+v of type %+v\n", field.Name, value, value.Type())
|
||||||
|
field_value.Set(value)
|
||||||
|
}
|
||||||
|
return signal_value.Interface().(Signal), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) {
|
||||||
|
args := graphql.FieldConfigArgument{}
|
||||||
|
arg_info := []StructFieldInfo{}
|
||||||
|
var id_index []int = nil
|
||||||
|
var direction_index []int = nil
|
||||||
|
|
||||||
|
for _, field := range(reflect.VisibleFields(signal_type)) {
|
||||||
|
gv_tag, tagged_gv := field.Tag.Lookup("gv")
|
||||||
|
if tagged_gv {
|
||||||
|
if gv_tag == "id" {
|
||||||
|
id_index = field.Index
|
||||||
|
} else if gv_tag == "direction" {
|
||||||
|
direction_index = field.Index
|
||||||
|
} else {
|
||||||
|
_, exists := args[gv_tag]
|
||||||
|
if exists == true {
|
||||||
|
return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag)
|
||||||
|
} else {
|
||||||
|
info, err := ArgumentInfo(ctx, field, gv_tag)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
args[gv_tag] = &graphql.ArgumentConfig{
|
||||||
|
}
|
||||||
|
arg_info = append(arg_info, info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, send_exists := args[send_id_key]
|
||||||
|
if send_exists == false {
|
||||||
|
args[send_id_key] = &graphql.ArgumentConfig{
|
||||||
|
Type: graphql.String,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resolve_signal := func(p graphql.ResolveParams) (interface{}, error) {
|
||||||
|
ctx, err := PrepResolve(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
send_id, err := ExtractID(p, send_id_key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgs := Messages{}
|
||||||
|
msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal)
|
||||||
|
|
||||||
|
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
|
||||||
|
err = ctx.Context.Send(msgs)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
|
||||||
|
if err != nil {
|
||||||
|
ctx.Ext.FreeResponseChannel(signal.ID())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, is_success := response.(*SuccessSignal)
|
||||||
|
if is_success == true {
|
||||||
|
return "success", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
error_response, is_error := response.(*ErrorSignal)
|
||||||
|
if is_error == true {
|
||||||
|
return "error", fmt.Errorf(error_response.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &graphql.Field{
|
||||||
|
Type: graphql.String,
|
||||||
|
Args: args,
|
||||||
|
Resolve: resolve_signal,
|
||||||
|
}, nil
|
||||||
|
}
|
@ -0,0 +1,296 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AddSubGroupSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
Name string `gv:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAddSubGroupSignal(name string) *AddSubGroupSignal {
|
||||||
|
return &AddSubGroupSignal{
|
||||||
|
NewSignalHeader(Direct),
|
||||||
|
name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal AddSubGroupSignal) Permission() Tree {
|
||||||
|
return Tree{
|
||||||
|
SerializedType(SignalTypeFor[AddSubGroupSignal]()): {
|
||||||
|
Hash("name", signal.Name): nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RemoveSubGroupSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
Name string `gv:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal {
|
||||||
|
return &RemoveSubGroupSignal{
|
||||||
|
NewSignalHeader(Direct),
|
||||||
|
name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal RemoveSubGroupSignal) Permission() Tree {
|
||||||
|
return Tree{
|
||||||
|
SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): {
|
||||||
|
Hash("command", signal.Name): nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddMemberSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
SubGroup string `gv:"sub_group"`
|
||||||
|
MemberID NodeID `gv:"member_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SubGroupGQL struct {
|
||||||
|
Name string
|
||||||
|
Members []NodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal AddMemberSignal) Permission() Tree {
|
||||||
|
return Tree{
|
||||||
|
SerializedType(SignalTypeFor[AddMemberSignal]()): {
|
||||||
|
Hash("sub_group", signal.SubGroup): nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAddMemberSignal(sub_group string, member_id NodeID) *AddMemberSignal {
|
||||||
|
return &AddMemberSignal{
|
||||||
|
NewSignalHeader(Direct),
|
||||||
|
sub_group,
|
||||||
|
member_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RemoveMemberSignal struct {
|
||||||
|
SignalHeader
|
||||||
|
SubGroup string `gv:"sub_group"`
|
||||||
|
MemberID NodeID `gv:"member_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (signal RemoveMemberSignal) Permission() Tree {
|
||||||
|
return Tree{
|
||||||
|
SerializedType(SignalTypeFor[RemoveMemberSignal]()): {
|
||||||
|
Hash("sub_group", signal.SubGroup): nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSignal {
|
||||||
|
return &RemoveMemberSignal{
|
||||||
|
NewSignalHeader(Direct),
|
||||||
|
sub_group,
|
||||||
|
member_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultGroupPolicy = NewAllNodesPolicy(Tree{
|
||||||
|
SerializedType(SignalTypeFor[ReadSignal]()): {
|
||||||
|
SerializedType(ExtTypeFor[GroupExt]()): {
|
||||||
|
SerializedType(GetFieldTag("sub_groups")): nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
type SubGroup struct {
|
||||||
|
Members []NodeID
|
||||||
|
Permissions Tree
|
||||||
|
}
|
||||||
|
|
||||||
|
type MemberOfPolicy struct {
|
||||||
|
PolicyHeader
|
||||||
|
Groups map[NodeID]map[string]Tree
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMemberOfPolicy(groups map[NodeID]map[string]Tree) MemberOfPolicy {
|
||||||
|
return MemberOfPolicy{
|
||||||
|
PolicyHeader: NewPolicyHeader(),
|
||||||
|
Groups: groups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||||
|
sig, ok := signal.(*ReadResultSignal)
|
||||||
|
if ok == false {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions)
|
||||||
|
|
||||||
|
group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()]
|
||||||
|
if ok == false {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_groups_ser, ok := group_ext_data["sub_groups"]
|
||||||
|
if ok == false {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("group", "Type deserialize error: %s", err)
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("group", "Value deserialize error: %s", err)
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
ext_sub_groups, ok := sub_groups_if.Interface().(map[string][]NodeID)
|
||||||
|
if ok == false {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
group, exists := policy.Groups[sig.NodeID]
|
||||||
|
if exists == false {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
for sub_group_name, permissions := range(group) {
|
||||||
|
ext_sub_group, exists := ext_sub_groups[sub_group_name]
|
||||||
|
if exists == true {
|
||||||
|
for _, member_id := range(ext_sub_group) {
|
||||||
|
if member_id == current.Principal {
|
||||||
|
if permissions.Allows(current.Action) == Allow {
|
||||||
|
return Allow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a read signal to Group to check if principal_id is a member of it
|
||||||
|
func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
|
||||||
|
var messages Messages = nil
|
||||||
|
for group_id, sub_groups := range(policy.Groups) {
|
||||||
|
if group_id == node.ID {
|
||||||
|
ext, err := GetExt[GroupExt](node)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err)
|
||||||
|
} else {
|
||||||
|
for sub_group_name, permission := range(sub_groups) {
|
||||||
|
ext_sub_group, exists := ext.SubGroups[sub_group_name]
|
||||||
|
if exists == true {
|
||||||
|
for _, member := range(ext_sub_group) {
|
||||||
|
if member == principal_id {
|
||||||
|
if permission.Allows(action) == Allow {
|
||||||
|
return nil, Allow
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Send the read request to the group so that ContinueAllows can parse the response to check membership
|
||||||
|
messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{
|
||||||
|
ExtTypeFor[GroupExt](): {"sub_groups"},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(messages) > 0 {
|
||||||
|
return messages, Pending
|
||||||
|
} else {
|
||||||
|
return nil, Deny
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupExt struct {
|
||||||
|
SubGroups map[string][]NodeID `gv:"sub_groups"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGroupExt(sub_groups map[string][]NodeID) *GroupExt {
|
||||||
|
if sub_groups == nil {
|
||||||
|
sub_groups = map[string][]NodeID{}
|
||||||
|
}
|
||||||
|
return &GroupExt{
|
||||||
|
SubGroups: sub_groups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *GroupExt) Load(ctx *Context, node *Node) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *GroupExt) Unload(ctx *Context, node *Node) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
|
||||||
|
var messages Messages = nil
|
||||||
|
var changes = Changes{}
|
||||||
|
|
||||||
|
switch sig := signal.(type) {
|
||||||
|
case *AddMemberSignal:
|
||||||
|
sub_group, exists := ext.SubGroups[sig.SubGroup]
|
||||||
|
if exists == false {
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
||||||
|
} else {
|
||||||
|
if slices.Contains(sub_group, sig.MemberID) == true {
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member"))
|
||||||
|
} else {
|
||||||
|
sub_group = append(sub_group, sig.MemberID)
|
||||||
|
ext.SubGroups[sig.SubGroup] = sub_group
|
||||||
|
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||||
|
changes.Add("sub_groups")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case *RemoveMemberSignal:
|
||||||
|
sub_group, exists := ext.SubGroups[sig.SubGroup]
|
||||||
|
if exists == false {
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
||||||
|
} else {
|
||||||
|
idx := slices.Index(sub_group, sig.MemberID)
|
||||||
|
if idx == -1 {
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member"))
|
||||||
|
} else {
|
||||||
|
sub_group = slices.Delete(sub_group, idx, idx+1)
|
||||||
|
ext.SubGroups[sig.SubGroup] = sub_group
|
||||||
|
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||||
|
changes.Add("sub_groups")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case *AddSubGroupSignal:
|
||||||
|
_, exists := ext.SubGroups[sig.Name]
|
||||||
|
if exists == true {
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_subgroup"))
|
||||||
|
} else {
|
||||||
|
ext.SubGroups[sig.Name] = []NodeID{}
|
||||||
|
|
||||||
|
changes.Add("sub_groups")
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||||
|
}
|
||||||
|
case *RemoveSubGroupSignal:
|
||||||
|
_, exists := ext.SubGroups[sig.Name]
|
||||||
|
if exists == false {
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
|
||||||
|
} else {
|
||||||
|
delete(ext.SubGroups, sig.Name)
|
||||||
|
|
||||||
|
changes.Add("sub_groups")
|
||||||
|
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, changes
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,94 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGroupAdd(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"listener", "test"})
|
||||||
|
|
||||||
|
group_listener := NewListenerExt(10)
|
||||||
|
group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil))
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
add_subgroup_signal := NewAddSubGroupSignal("test_group")
|
||||||
|
messages := Messages{}
|
||||||
|
messages = messages.Add(ctx, group.ID, group, nil, add_subgroup_signal)
|
||||||
|
fatalErr(t, ctx.Send(messages))
|
||||||
|
|
||||||
|
resp_1, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_subgroup_signal.Id)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
error_1, is_error := resp_1.(*ErrorSignal)
|
||||||
|
if is_error {
|
||||||
|
t.Fatalf("Error returned: %s", error_1.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
user_id := RandID()
|
||||||
|
add_member_signal := NewAddMemberSignal("test_group", user_id)
|
||||||
|
|
||||||
|
messages = Messages{}
|
||||||
|
messages = messages.Add(ctx, group.ID, group, nil, add_member_signal)
|
||||||
|
fatalErr(t, ctx.Send(messages))
|
||||||
|
|
||||||
|
resp_2, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
error_2, is_error := resp_2.(*ErrorSignal)
|
||||||
|
if is_error {
|
||||||
|
t.Fatalf("Error returned: %s", error_2.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
read_signal := NewReadSignal(map[ExtType][]string{
|
||||||
|
ExtTypeFor[GroupExt](): {"sub_groups"},
|
||||||
|
})
|
||||||
|
|
||||||
|
messages = Messages{}
|
||||||
|
messages = messages.Add(ctx, group.ID, group, nil, read_signal)
|
||||||
|
fatalErr(t, ctx.Send(messages))
|
||||||
|
|
||||||
|
response, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
|
read_response := response.(*ReadResultSignal)
|
||||||
|
|
||||||
|
sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"]
|
||||||
|
|
||||||
|
sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack)
|
||||||
|
fatalErr(t, err)
|
||||||
|
if len(remaining_types) > 0 {
|
||||||
|
t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types))
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data)
|
||||||
|
fatalErr(t, err)
|
||||||
|
if len(remaining) > 0 {
|
||||||
|
t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types))
|
||||||
|
}
|
||||||
|
|
||||||
|
sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID)
|
||||||
|
|
||||||
|
if ok != true {
|
||||||
|
t.Fatalf("sub_groups wrong type %s", sub_groups_value.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sub_groups) != 1 {
|
||||||
|
t.Fatalf("sub_groups wrong length %d", len(sub_groups))
|
||||||
|
}
|
||||||
|
|
||||||
|
test_subgroup, exists := sub_groups["test_group"]
|
||||||
|
if exists == false {
|
||||||
|
t.Fatal("test_group not in subgroups")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(test_subgroup) != 1 {
|
||||||
|
t.Fatalf("test_group wrong size %d/1", len(test_subgroup))
|
||||||
|
}
|
||||||
|
|
||||||
|
if test_subgroup[0] != user_id {
|
||||||
|
t.Fatalf("sub_groups wrong value %s", test_subgroup[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Log.Logf("test", "Read Response: %+v", read_response)
|
||||||
|
}
|
@ -1,11 +1,114 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
type SendMsg struct {
|
import (
|
||||||
Dest NodeID
|
"time"
|
||||||
Signal Signal
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthInfo struct {
|
||||||
|
// The Node that issued the authorization
|
||||||
|
Identity ed25519.PublicKey
|
||||||
|
|
||||||
|
// Time the authorization was generated
|
||||||
|
Start time.Time
|
||||||
|
|
||||||
|
// Signature of Start + Principal with Identity private key
|
||||||
|
Signature []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthorizationToken struct {
|
||||||
|
AuthInfo
|
||||||
|
|
||||||
|
// The private key generated by the client, encrypted with the servers public key
|
||||||
|
KeyEncrypted []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientAuthorization struct {
|
||||||
|
AuthInfo
|
||||||
|
|
||||||
|
// The private key generated by the client
|
||||||
|
Key ed25519.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorization structs can be passed in a message that originated from a different node than the sender
|
||||||
|
type Authorization struct {
|
||||||
|
AuthInfo
|
||||||
|
|
||||||
|
// The public key generated for this authorization
|
||||||
|
Key ed25519.PublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
type RecvMsg struct {
|
type Message struct {
|
||||||
Source NodeID
|
Dest NodeID
|
||||||
|
Source ed25519.PublicKey
|
||||||
|
|
||||||
|
Authorization *Authorization
|
||||||
|
|
||||||
Signal Signal
|
Signal Signal
|
||||||
|
Signature []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Messages []*Message
|
||||||
|
func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages {
|
||||||
|
msg, err := NewMessage(ctx, dest, source, authorization, signal)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
} else {
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
}
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessages(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signals... Signal) Messages {
|
||||||
|
messages := Messages{}
|
||||||
|
for _, signal := range(signals) {
|
||||||
|
messages = messages.Add(ctx, dest, source, authorization, signal)
|
||||||
|
}
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) {
|
||||||
|
signal_ser, err := SerializeAny(ctx, signal)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signal_chunks, err := signal_ser.Chunks()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dest_ser, err := dest.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
source_ser, err := source.ID.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sig_data := append(dest_ser, source_ser...)
|
||||||
|
sig_data = append(sig_data, signal_chunks.Slice()...)
|
||||||
|
var message_auth *Authorization = nil
|
||||||
|
if authorization != nil {
|
||||||
|
sig_data = append(sig_data, authorization.Signature...)
|
||||||
|
message_auth = &Authorization{
|
||||||
|
authorization.AuthInfo,
|
||||||
|
authorization.Key.Public().(ed25519.PublicKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Message{
|
||||||
|
Dest: dest,
|
||||||
|
Source: source.Key.Public().(ed25519.PublicKey),
|
||||||
|
Authorization: message_auth,
|
||||||
|
Signal: signal,
|
||||||
|
Signature: sig,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,139 @@
|
|||||||
|
package graphvent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Policy interface {
|
||||||
|
Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
|
||||||
|
ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
|
||||||
|
ID() uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type PolicyHeader struct {
|
||||||
|
UUID uuid.UUID `gv:"uuid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (header PolicyHeader) ID() uuid.UUID {
|
||||||
|
return header.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
|
||||||
|
return nil, policy.Rules.Allows(action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
|
||||||
|
for id, actions := range(policy.NodeRules) {
|
||||||
|
if id != principal_id {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, actions.Allows(action)
|
||||||
|
}
|
||||||
|
return nil, Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
|
||||||
|
func CopyTree(tree Tree) Tree {
|
||||||
|
if tree == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := Tree{}
|
||||||
|
for name, sub := range(tree) {
|
||||||
|
ret[name] = CopyTree(sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func MergeTrees(first Tree, second Tree) Tree {
|
||||||
|
if first == nil || second == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := CopyTree(first)
|
||||||
|
for name, sub := range(second) {
|
||||||
|
current, exists := ret[name]
|
||||||
|
if exists == true {
|
||||||
|
ret[name] = MergeTrees(current, sub)
|
||||||
|
} else {
|
||||||
|
ret[name] = CopyTree(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tree map[SerializedType]Tree
|
||||||
|
|
||||||
|
func (rule Tree) Allows(action Tree) RuleResult {
|
||||||
|
// If the current rule is nil, it's a wildcard and any action being processed is allowed
|
||||||
|
if rule == nil {
|
||||||
|
return Allow
|
||||||
|
// If the rule isn't "allow all" but the action is "request all", deny
|
||||||
|
} else if action == nil {
|
||||||
|
return Deny
|
||||||
|
// If the current action has no children, it's allowed
|
||||||
|
} else if len(action) == 0 {
|
||||||
|
return Allow
|
||||||
|
// If the current rule has no children but the action goes further, it's not allowed
|
||||||
|
} else if len(rule) == 0 {
|
||||||
|
return Deny
|
||||||
|
// If the current rule and action have children, all the children of action must be allowed by rule
|
||||||
|
} else {
|
||||||
|
for sub, subtree := range(action) {
|
||||||
|
subrule, exists := rule[sub]
|
||||||
|
if exists == false {
|
||||||
|
return Deny
|
||||||
|
} else if subrule.Allows(subtree) == Deny {
|
||||||
|
return Deny
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Allow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPolicyHeader() PolicyHeader {
|
||||||
|
return PolicyHeader{
|
||||||
|
UUID: uuid.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy {
|
||||||
|
if node_actions == nil {
|
||||||
|
node_actions = map[NodeID]Tree{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return PerNodePolicy{
|
||||||
|
PolicyHeader: NewPolicyHeader(),
|
||||||
|
NodeRules: node_actions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PerNodePolicy struct {
|
||||||
|
PolicyHeader
|
||||||
|
NodeRules map[NodeID]Tree `gv:"node_rules"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
|
||||||
|
return AllNodesPolicy{
|
||||||
|
PolicyHeader: NewPolicyHeader(),
|
||||||
|
Rules: rules,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AllNodesPolicy struct {
|
||||||
|
PolicyHeader
|
||||||
|
Rules Tree `gv:"rules"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultPolicy = NewAllNodesPolicy(Tree{
|
||||||
|
SerializedType(SignalTypeFor[ResponseSignal]()): nil,
|
||||||
|
SerializedType(SignalTypeFor[StatusSignal]()): nil,
|
||||||
|
})
|
File diff suppressed because it is too large
Load Diff
@ -1,150 +1,247 @@
|
|||||||
package graphvent
|
package graphvent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"github.com/google/uuid"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testTypeStack[T any](t *testing.T, ctx *Context) {
|
func TestSerializeTest(t *testing.T) {
|
||||||
reflect_type := reflect.TypeFor[T]()
|
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
|
||||||
stack, err := TypeStack(ctx, reflect_type)
|
testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}})
|
||||||
fatalErr(t, err)
|
testSerialize(t, ctx, map[NodeID]ReqState{
|
||||||
|
RandID(): Locked,
|
||||||
|
RandID(): Unlocked,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
|
func TestSerializeBasic(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
|
||||||
|
testSerializeComparable[bool](t, ctx, true)
|
||||||
|
|
||||||
unwrapped_type, rest, err := UnwrapStack(ctx, stack)
|
type bool_wrapped bool
|
||||||
|
err := RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped])
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
testSerializeComparable[bool_wrapped](t, ctx, true)
|
||||||
if len(rest) != 0 {
|
|
||||||
t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
|
testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false})
|
||||||
|
testSerializeComparable[string](t, ctx, "test")
|
||||||
|
testSerializeComparable[float32](t, ctx, 0.05)
|
||||||
|
testSerializeComparable[float64](t, ctx, 0.05)
|
||||||
|
testSerializeComparable[uint](t, ctx, uint(1234))
|
||||||
|
testSerializeComparable[uint8] (t, ctx, uint8(123))
|
||||||
|
testSerializeComparable[uint16](t, ctx, uint16(1234))
|
||||||
|
testSerializeComparable[uint32](t, ctx, uint32(12345))
|
||||||
|
testSerializeComparable[uint64](t, ctx, uint64(123456))
|
||||||
|
testSerializeComparable[int](t, ctx, 1234)
|
||||||
|
testSerializeComparable[int8] (t, ctx, int8(-123))
|
||||||
|
testSerializeComparable[int16](t, ctx, int16(-1234))
|
||||||
|
testSerializeComparable[int32](t, ctx, int32(-12345))
|
||||||
|
testSerializeComparable[int64](t, ctx, int64(-123456))
|
||||||
|
testSerializeComparable[time.Duration](t, ctx, time.Duration(100))
|
||||||
|
testSerializeComparable[time.Time](t, ctx, time.Now().Truncate(0))
|
||||||
|
testSerializeSlice[[]int](t, ctx, []int{123, 456, 789, 101112})
|
||||||
|
testSerializeSlice[[]int](t, ctx, ([]int)(nil))
|
||||||
|
testSerializeSliceSlice[[][]int](t, ctx, [][]int{{123, 456, 789, 101112}, {3253, 2341, 735, 212}, {123, 51}, nil})
|
||||||
|
testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, map[int8]map[*int8]string{})
|
||||||
|
testSerialize(t, ctx, map[int8]time.Time{
|
||||||
|
1: time.Now(),
|
||||||
|
3: time.Now().Add(time.Second),
|
||||||
|
0: time.Now().Add(time.Second*2),
|
||||||
|
4: time.Now().Add(time.Second*3),
|
||||||
|
})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, Tree{
|
||||||
|
SerializedTypeFor[NodeType](): nil,
|
||||||
|
SerializedTypeFor[SerializedType](): {
|
||||||
|
SerializedTypeFor[NodeType](): Tree{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
var i interface{} = nil
|
||||||
|
testSerialize(t, ctx, i)
|
||||||
|
|
||||||
|
testSerializeMap(t, ctx, map[int8]interface{}{
|
||||||
|
0: "abcd",
|
||||||
|
1: uint32(12345678),
|
||||||
|
2: i,
|
||||||
|
3: 123,
|
||||||
|
})
|
||||||
|
|
||||||
|
testSerializeMap(t, ctx, map[int8]int32{
|
||||||
|
0: 1234,
|
||||||
|
2: 5678,
|
||||||
|
4: 9101,
|
||||||
|
6: 1121,
|
||||||
|
})
|
||||||
|
|
||||||
|
type test_struct struct {
|
||||||
|
Int int `gv:"int"`
|
||||||
|
String string `gv:"string"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if unwrapped_type != reflect_type {
|
err = RegisterStruct[test_struct](ctx)
|
||||||
t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
|
fatalErr(t, err)
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
|
testSerialize(t, ctx, test_struct{
|
||||||
}
|
12345,
|
||||||
|
"test_string",
|
||||||
|
})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, Tree{
|
||||||
|
SerializedKindFor(reflect.Map): nil,
|
||||||
|
SerializedKindFor(reflect.String): nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, Tree{
|
||||||
|
SerializedTypeFor[Tree](): nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
testSerialize(t, ctx, Tree{
|
||||||
|
SerializedTypeFor[Tree](): {
|
||||||
|
SerializedTypeFor[error](): Tree{},
|
||||||
|
SerializedKindFor(reflect.Map): nil,
|
||||||
|
},
|
||||||
|
SerializedKindFor(reflect.String): nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
type test_slice []string
|
||||||
|
err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice)
|
||||||
|
fatalErr(t, err)
|
||||||
|
|
||||||
func TestSerializeTypes(t *testing.T) {
|
testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"})
|
||||||
ctx := logTestContext(t, []string{"test"})
|
testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"})
|
||||||
|
}
|
||||||
|
|
||||||
testTypeStack[int](t, ctx)
|
type test struct {
|
||||||
testTypeStack[map[int]string](t, ctx)
|
Int int `gv:"int"`
|
||||||
testTypeStack[string](t, ctx)
|
Str string `gv:"string"`
|
||||||
testTypeStack[*string](t, ctx)
|
|
||||||
testTypeStack[*map[string]*map[*string]int](t, ctx)
|
|
||||||
testTypeStack[[5]int](t, ctx)
|
|
||||||
testTypeStack[uuid.UUID](t, ctx)
|
|
||||||
testTypeStack[NodeID](t, ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
|
func (s test) String() string {
|
||||||
serialized, err := Serialize(ctx, value)
|
return fmt.Sprintf("%d:%s", s.Int, s.Str)
|
||||||
fatalErr(t, err)
|
}
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
func TestSerializeStructTags(t *testing.T) {
|
||||||
|
ctx := logTestContext(t, []string{"test"})
|
||||||
|
|
||||||
deserialized, err := Deserialize[T](ctx, serialized)
|
err := RegisterStruct[test](ctx)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
|
||||||
if value != deserialized {
|
test_int := 10
|
||||||
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
test_string := "test"
|
||||||
|
|
||||||
|
ret := testSerialize(t, ctx, test{
|
||||||
|
test_int,
|
||||||
|
test_string,
|
||||||
|
})
|
||||||
|
if ret.Int != test_int {
|
||||||
|
t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int)
|
||||||
|
} else if ret.Str != test_string {
|
||||||
|
t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
testSerialize(t, ctx, []test{
|
||||||
|
{
|
||||||
|
test_int,
|
||||||
|
test_string,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
test_int * 2,
|
||||||
|
fmt.Sprintf("%s%s", test_string, test_string),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
test_int * 4,
|
||||||
|
fmt.Sprintf("%s%s%s", test_string, test_string, test_string),
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
|
func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) {
|
||||||
serialized, err := Serialize(ctx, value)
|
v := testSerialize(t, ctx, val)
|
||||||
fatalErr(t, err)
|
for key, value := range(val) {
|
||||||
|
recreated, exists := v[key]
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
|
if exists == false {
|
||||||
|
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||||
deserialized, err := Deserialize[L](ctx, serialized)
|
} else if recreated != value {
|
||||||
fatalErr(t, err)
|
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||||
|
|
||||||
for i, item := range(value) {
|
|
||||||
if item != deserialized[i] {
|
|
||||||
t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(v) != len(val) {
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
|
func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, val S) {
|
||||||
serialized, err := Serialize(ctx, value)
|
v := testSerialize(t, ctx, val)
|
||||||
fatalErr(t, err)
|
for i, original := range(val) {
|
||||||
|
if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) {
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
|
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||||
|
}
|
||||||
|
for j, o := range(original) {
|
||||||
|
if v[i][j] != o {
|
||||||
|
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
deserialized, err := Deserialize[P](ctx, serialized)
|
func testSerializeSlice[S []T, T comparable](t *testing.T, ctx *Context, val S) {
|
||||||
fatalErr(t, err)
|
v := testSerialize(t, ctx, val)
|
||||||
|
for i, original := range(val) {
|
||||||
|
if v[i] != original {
|
||||||
|
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if value == nil && deserialized == nil {
|
func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) {
|
||||||
ctx.Log.Logf("test", "Deserialized nil")
|
v := testSerialize(t, ctx, val)
|
||||||
} else if value == nil {
|
if v != val {
|
||||||
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
|
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
|
||||||
} else if deserialized == nil {
|
|
||||||
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
|
|
||||||
} else if *deserialized != *value {
|
|
||||||
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
|
|
||||||
} else {
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSerialize[T any](t *testing.T, ctx *Context, value T) {
|
func testSerialize[T any](t *testing.T, ctx *Context, val T) T {
|
||||||
serialized, err := Serialize(ctx, value)
|
value := reflect.ValueOf(&val).Elem()
|
||||||
|
type_stack, err := SerializeType(ctx, value.Type())
|
||||||
|
chunks, err := SerializeValue(ctx, value)
|
||||||
|
value_serialized := SerializedValue{type_stack, chunks.Slice()}
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "Serialized %+v to %+v(%d)", val, value_serialized, len(value_serialized.Data))
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
|
value_chunks, err := value_serialized.Chunks()
|
||||||
|
|
||||||
deserialized, err := Deserialize[T](ctx, serialized)
|
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "Binary: %+v", value_chunks.Slice())
|
||||||
|
|
||||||
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
|
val_parsed, remaining_parse, err := ParseSerializedValue(value_chunks.Slice())
|
||||||
}
|
fatalErr(t, err)
|
||||||
|
ctx.Log.Logf("test", "Parsed: %+v", val_parsed)
|
||||||
func TestSerializeValues(t *testing.T) {
|
|
||||||
ctx := logTestContext(t, []string{"test"})
|
|
||||||
|
|
||||||
testSerialize(t, ctx, Extension(NewLockableExt(nil)))
|
|
||||||
|
|
||||||
testSerializeCompare[int8](t, ctx, -64)
|
|
||||||
testSerializeCompare[int16](t, ctx, -64)
|
|
||||||
testSerializeCompare[int32](t, ctx, -64)
|
|
||||||
testSerializeCompare[int64](t, ctx, -64)
|
|
||||||
testSerializeCompare[int](t, ctx, -64)
|
|
||||||
|
|
||||||
testSerializeCompare[uint8](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint16](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint32](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint64](t, ctx, 64)
|
|
||||||
testSerializeCompare[uint](t, ctx, 64)
|
|
||||||
|
|
||||||
testSerializeCompare[string](t, ctx, "test")
|
|
||||||
|
|
||||||
a := 12
|
|
||||||
testSerializePointer[*int](t, ctx, &a)
|
|
||||||
|
|
||||||
b := "test"
|
|
||||||
testSerializePointer[*string](t, ctx, nil)
|
|
||||||
testSerializePointer[*string](t, ctx, &b)
|
|
||||||
|
|
||||||
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
|
|
||||||
|
|
||||||
testSerializeCompare[bool](t, ctx, true)
|
if len(remaining_parse) != 0 {
|
||||||
testSerializeCompare[bool](t, ctx, false)
|
t.Fatal("Data remaining after deserializing value")
|
||||||
testSerializeCompare[int](t, ctx, -1)
|
}
|
||||||
testSerializeCompare[uint](t, ctx, 1)
|
|
||||||
testSerializeCompare[NodeID](t, ctx, RandID())
|
|
||||||
testSerializeCompare[*int](t, ctx, nil)
|
|
||||||
testSerializeCompare(t, ctx, "string")
|
|
||||||
|
|
||||||
node, err := NewNode(ctx, nil, "Base", 100)
|
val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack)
|
||||||
|
deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data)
|
||||||
fatalErr(t, err)
|
fatalErr(t, err)
|
||||||
testSerialize(t, ctx, node)
|
|
||||||
|
if len(remaining_deserialize) != 0 {
|
||||||
|
t.Fatal("Data remaining after deserializing value")
|
||||||
|
} else if len(remaining_types) != 0 {
|
||||||
|
t.Fatal("TypeStack remaining after deserializing value")
|
||||||
|
} else if val_type != value.Type() {
|
||||||
|
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val)))
|
||||||
|
} else if deserialized_value.CanConvert(val_type) == false {
|
||||||
|
t.Fatal("DeserializeValue returned value that can't convert to original value")
|
||||||
|
}
|
||||||
|
ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface())
|
||||||
|
if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil {
|
||||||
|
var zero T
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
return deserialized_value.Interface().(T)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue