Compare commits

..

8 Commits

26 changed files with 1774 additions and 4549 deletions

233
acl.go

@ -1,233 +0,0 @@
package graphvent
import (
"github.com/google/uuid"
"slices"
"time"
)
type ACLSignal struct {
SignalHeader
Principal NodeID `gv:"principal"`
Action Tree `gv:"tree"`
}
func NewACLSignal(principal NodeID, action Tree) *ACLSignal {
return &ACLSignal{
SignalHeader: NewSignalHeader(Direct),
Principal: principal,
Action: action,
}
}
var DefaultACLPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ACLSignal]()): nil,
})
func (signal ACLSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ACLSignal]()): nil,
}
}
type ACLExt struct {
Policies []Policy `gv:"policies"`
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
Pending map[uuid.UUID]PendingACLSignal `gv:"pending"`
}
func NewACLExt(policies []Policy) *ACLExt {
return &ACLExt{
Policies: policies,
PendingACLs: map[uuid.UUID]PendingACL{},
Pending: map[uuid.UUID]PendingACLSignal{},
}
}
func (ext *ACLExt) Load(ctx *Context, node *Node) error {
return nil
}
func (ext *ACLExt) Unload(ctx *Context, node *Node) {
}
func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
response, is_response := signal.(ResponseSignal)
if is_response == true {
var messages Messages = nil
var changes = Changes{}
info, waiting := ext.Pending[response.ResponseID()]
if waiting == true {
changes.Add("pending")
delete(ext.Pending, response.ResponseID())
if response.ID() != info.Timeout {
err := node.DequeueSignal(info.Timeout)
if err != nil {
ctx.Log.Logf("acl", "timeout dequeue error: %s", err)
}
}
acl_info, found := ext.PendingACLs[info.ID]
if found == true {
acl_info.Counter -= 1
acl_info.Responses = append(acl_info.Responses, response)
policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool {
return policy.ID() == info.Policy
})
if policy_index == -1 {
ctx.Log.Logf("acl", "pending signal for nonexistent policy")
delete(ext.PendingACLs, info.ID)
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
changes.Add("pending_acls")
delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed allow")
messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID))
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else if acl_info.Counter == 0 {
changes.Add("pending_acls")
delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed deny")
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied"))
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
node.PendingACLs[info.ID] = acl_info
changes.Add("pending_acls")
}
}
}
}
return messages, changes
}
var messages Messages = nil
var changes = Changes{}
switch sig := signal.(type) {
case *ACLSignal:
var acl_messages map[uuid.UUID]Messages = nil
denied := true
for _, policy := range(ext.Policies) {
policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node)
if result == Allow {
denied = false
break
} else if result == Pending {
if len(policy_messages) == 0 {
ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID())
continue
} else if acl_messages == nil {
acl_messages = map[uuid.UUID]Messages{}
denied = false
}
acl_messages[policy.ID()] = policy_messages
ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages)
}
}
if denied == true {
ctx.Log.Logf("acl", "Request denied")
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied"))
} else if acl_messages != nil {
ctx.Log.Logf("acl", "Request pending")
changes.Add("pending")
total_messages := 0
// TODO: reasonable timeout/configurable
timeout_time := time.Now().Add(time.Second)
for policy_id, policy_messages := range(acl_messages) {
total_messages += len(policy_messages)
for _, message := range(policy_messages) {
timeout_signal := NewTimeoutSignal(message.Signal.ID())
ext.Pending[message.Signal.ID()] = PendingACLSignal{
Policy: policy_id,
Timeout: timeout_signal.Id,
ID: sig.Id,
}
node.QueueSignal(timeout_time, timeout_signal)
messages = append(messages, message)
}
}
acl_timeout := NewACLTimeoutSignal(sig.Id)
node.QueueSignal(timeout_time, acl_timeout)
ext.PendingACLs[sig.Id] = PendingACL{
Counter: total_messages,
Responses: []ResponseSignal{},
TimeoutID: acl_timeout.Id,
Action: sig.Action,
Principal: sig.Principal,
Source: source,
Signal: signal,
}
} else {
ctx.Log.Logf("acl", "Request allowed")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
}
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
case *ACLTimeoutSignal:
acl_info, exists := ext.PendingACLs[sig.ReqID]
if exists == true {
delete(ext.PendingACLs, sig.ReqID)
changes.Add("pending_acls")
ctx.Log.Logf("acl", "Request timeout deny")
messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout"))
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl")
}
// Delete from PendingACLs
}
return messages, changes
}
type ACLProxyPolicy struct {
PolicyHeader
Proxies []NodeID `gv:"proxies"`
}
func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy {
return ACLProxyPolicy{
NewPolicyHeader(),
proxies,
}
}
func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
if len(policy.Proxies) == 0 {
return nil, Deny
}
messages := Messages{}
for _, proxy := range(policy.Proxies) {
messages = messages.Add(ctx, proxy, node, nil, NewACLSignal(principal_id, action))
}
return messages, Pending
}
func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
_, is_success := signal.(*SuccessSignal)
if is_success == true {
return Allow
}
return Deny
}

@ -1,141 +0,0 @@
package graphvent
import (
"testing"
"time"
"reflect"
"runtime/debug"
)
func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){
response_casted, cast_ok := signal.(S)
if cast_ok == false {
error_signal, is_error := signal.(*ErrorSignal)
if is_error {
t.Log(string(debug.Stack()))
t.Fatal(error_signal.Error)
}
t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal))
}
check(response_casted)
}
func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){
acl_node, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(policies))
fatalErr(t, err)
acl_signal := NewACLSignal(listener.ID, action)
response, _ := testSend(t, ctx, acl_signal, listener, acl_node)
checkSignal(t, response, check)
}
func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){
return func(response *ErrorSignal) {
if response.Error != error_string {
t.Fatalf("Wrong error: %s", response.Error)
}
}
}
func testSuccess(*SuccessSignal){}
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
source_listener, err := GetExt[ListenerExt](source)
fatalErr(t, err)
messages := Messages{}
messages = messages.Add(ctx, destination.ID, source, nil, signal)
fatalErr(t, ctx.Send(messages))
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
fatalErr(t, err)
return response, signals
}
func TestACLBasic(t *testing.T) {
ctx := logTestContext(t, []string{"test", "acl", "group", "read_field"})
listener, err := NewNode(ctx, nil, "Base", 100, nil, NewListenerExt(100))
fatalErr(t, err)
ctx.Log.Logf("test", "testing fail")
testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied"))
ctx.Log.Logf("test", "testing allow all")
testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess)
group, err := NewNode(ctx, nil, "Base", 100, []Policy{
DefaultGroupPolicy,
NewPerNodePolicy(map[NodeID]Tree{
listener.ID: {
SerializedType(SignalTypeFor[AddSubGroupSignal]()): nil,
SerializedType(SignalTypeFor[AddMemberSignal]()): nil,
},
}),
}, NewGroupExt(nil))
fatalErr(t, err)
ctx.Log.Logf("test", "testing empty groups")
testSendACL(t, ctx, listener, nil, []Policy{
NewMemberOfPolicy(map[NodeID]map[string]Tree{
group.ID: {
"test_group": nil,
},
}),
}, testErrorSignal(t, "acl_denied"))
ctx.Log.Logf("test", "testing adding group")
add_subgroup_signal := NewAddSubGroupSignal("test_group")
add_subgroup_response, _ := testSend(t, ctx, add_subgroup_signal, listener, group)
checkSignal(t, add_subgroup_response, testSuccess)
ctx.Log.Logf("test", "testing adding member")
add_member_signal := NewAddMemberSignal("test_group", listener.ID)
add_member_response, _ := testSend(t, ctx, add_member_signal, listener, group)
checkSignal(t, add_member_response, testSuccess)
ctx.Log.Logf("test", "testing group membership")
testSendACL(t, ctx, listener, nil, []Policy{
NewMemberOfPolicy(map[NodeID]map[string]Tree{
group.ID: {
"test_group": nil,
},
}),
}, testSuccess)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy(nil),
}, testErrorSignal(t, "acl_denied"))
acl_proxy_1, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt(nil))
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}),
}, testErrorSignal(t, "acl_denied"))
acl_proxy_2, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)}))
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}),
}, testSuccess)
acl_proxy_3, err := NewNode(ctx, nil, "Base", 100, []Policy{DefaultACLPolicy},
NewACLExt([]Policy{
NewMemberOfPolicy(map[NodeID]map[string]Tree{
group.ID: {
"test_group": nil,
},
}),
}),
)
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}),
}, testSuccess)
}

@ -0,0 +1,43 @@
package main
import (
"fmt"
badger "github.com/dgraph-io/badger/v3"
gv "github.com/mekkanized/graphvent"
)
func check(err error) {
if err != nil {
panic(err)
}
}
func main() {
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
check(err)
ctx, err := gv.NewContext(db, gv.NewConsoleLogger([]string{"test", "signal"}))
check(err)
gql_ext, err := gv.NewGQLExt(ctx, ":8080", nil, nil)
check(err)
listener_ext := gv.NewListenerExt(1000)
n1, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt(nil))
check(err)
n2, err := gv.NewNode(ctx, nil, "Lockable", 1000, gv.NewLockableExt([]gv.NodeID{n1.ID}))
check(err)
_, err = gv.NewNode(ctx, nil, "Lockable", 1000, gql_ext, listener_ext, gv.NewLockableExt([]gv.NodeID{n2.ID}))
check(err)
for true {
select {
case message := <- listener_ext.Chan:
fmt.Printf("Listener Message: %+v\n", message)
}
}
}

File diff suppressed because it is too large Load Diff

195
db.go

@ -0,0 +1,195 @@
package graphvent
import (
"encoding/binary"
"fmt"
badger "github.com/dgraph-io/badger/v3"
)
func WriteNodeInit(ctx *Context, node *Node) error {
if node == nil {
return fmt.Errorf("Cannot serialize nil *Node")
}
return ctx.DB.Update(func(tx *badger.Txn) error {
// Get the base key bytes
id_ser, err := node.ID.MarshalBinary()
if err != nil {
return err
}
// Write Node value
node_val, err := Serialize(ctx, node)
if err != nil {
return err
}
err = tx.Set(id_ser, node_val)
if err != nil {
return err
}
// Write empty signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_val, err := Serialize(ctx, node.SignalQueue)
if err != nil {
return err
}
err = tx.Set(sigqueue_id, sigqueue_val)
if err != nil {
return err
}
// Write node extension list
ext_list := []ExtType{}
for ext_type := range(node.Extensions) {
ext_list = append(ext_list, ext_type)
}
ext_list_val, err := Serialize(ctx, ext_list)
if err != nil {
return err
}
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
err = tx.Set(ext_list_id, ext_list_val)
if err != nil {
return err
}
// For each extension:
for ext_type, ext := range(node.Extensions) {
// Write each extension's current value
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_val, err := Serialize(ctx, ext)
if err != nil {
return err
}
err = tx.Set(ext_id, ext_val)
}
return nil
})
}
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
return ctx.DB.Update(func(tx *badger.Txn) error {
// Get the base key bytes
id_ser, err := node.ID.MarshalBinary()
if err != nil {
return err
}
// Write the signal queue if it needs to be written
if node.writeSignalQueue {
node.writeSignalQueue = false
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_val, err := Serialize(ctx, node.SignalQueue)
if err != nil {
return err
}
err = tx.Set(sigqueue_id, sigqueue_val)
if err != nil {
return err
}
}
// For each ext in changes
for ext_type := range(changes) {
// Write each ext
ext, exists := node.Extensions[ext_type]
if exists == false {
return fmt.Errorf("%s is not an extension in %s", ext_type, node.ID)
}
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_ser, err := Serialize(ctx, ext)
if err != nil {
return err
}
err = tx.Set(ext_id, ext_ser)
if err != nil {
return err
}
}
return nil
})
}
func LoadNode(ctx *Context, id NodeID) (*Node, error) {
var node *Node = nil
err := ctx.DB.View(func(tx *badger.Txn) error {
// Get the base key bytes
id_ser, err := id.MarshalBinary()
if err != nil {
return err
}
// Get the node value
node_item, err := tx.Get(id_ser)
if err != nil {
return err
}
err = node_item.Value(func(val []byte) error {
node, err = Deserialize[*Node](ctx, val)
return err
})
if err != nil {
return nil
}
// Get the signal queue
sigqueue_id := append(id_ser, []byte(" - SIGQUEUE")...)
sigqueue_item, err := tx.Get(sigqueue_id)
if err != nil {
return err
}
err = sigqueue_item.Value(func(val []byte) error {
node.SignalQueue, err = Deserialize[[]QueuedSignal](ctx, val)
return err
})
if err != nil {
return err
}
// Get the extension list
ext_list_id := append(id_ser, []byte(" - EXTLIST")...)
ext_list_item, err := tx.Get(ext_list_id)
if err != nil {
return err
}
var ext_list []ExtType
ext_list_item.Value(func(val []byte) error {
ext_list, err = Deserialize[[]ExtType](ctx, val)
return err
})
// Get the extensions
for _, ext_type := range(ext_list) {
ext_id := binary.BigEndian.AppendUint64(id_ser, uint64(ext_type))
ext_item, err := tx.Get(ext_id)
if err != nil {
return err
}
var ext Extension
err = ext_item.Value(func(val []byte) error {
ext, err = Deserialize[Extension](ctx, val)
return err
})
if err != nil {
return err
}
node.Extensions[ext_type] = ext
}
return nil
})
if err != nil {
return nil, err
}
return node, nil
}

@ -8,45 +8,11 @@ import (
type EventCommand string type EventCommand string
type EventState string type EventState string
type ParentOfPolicy struct {
PolicyHeader
Policy Tree
}
func NewParentOfPolicy(policy Tree) *ParentOfPolicy {
return &ParentOfPolicy{
PolicyHeader: NewPolicyHeader(),
Policy: policy,
}
}
func (policy ParentOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
event_ext, err := GetExt[EventExt](node)
if err != nil {
ctx.Log.Logf("event", "ParentOfPolicy, node not event %s", node.ID)
return nil, Deny
}
if event_ext.Parent == principal_id {
return nil, policy.Policy.Allows(action)
}
return nil, Deny
}
func (policy ParentOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
var DefaultEventPolicy = NewParentOfPolicy(Tree{
SerializedType(SignalTypeFor[EventControlSignal]()): nil,
})
type EventExt struct { type EventExt struct {
Name string `gv:"name"` Name string `gv:"name"`
State EventState `gv:"state"` State EventState `gv:"state"`
StateStart time.Time `gv:"state_start"` StateStart time.Time `gv:"state_start"`
Parent NodeID `gv:"parent"` Parent NodeID `gv:"parent" node:"Base"`
} }
func (ext *EventExt) Load(ctx *Context, node *Node) error { func (ext *EventExt) Load(ctx *Context, node *Node) error {
@ -71,19 +37,13 @@ type EventStateSignal struct {
Time time.Time `gv:"time"` Time time.Time `gv:"time"`
} }
func (signal EventStateSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[StatusSignal]()): nil,
}
}
func (signal EventStateSignal) String() string { func (signal EventStateSignal) String() string {
return fmt.Sprintf("EventStateSignal(%s, %s, %s, %+v)", signal.SignalHeader, signal.Source, signal.State, signal.Time) return fmt.Sprintf("EventStateSignal(%s, %s, %s, %+v)", signal.SignalHeader, signal.Source, signal.State, signal.Time)
} }
func NewEventStateSignal(source NodeID, state EventState, t time.Time) *EventStateSignal { func NewEventStateSignal(source NodeID, state EventState, t time.Time) *EventStateSignal {
return &EventStateSignal{ return &EventStateSignal{
SignalHeader: NewSignalHeader(Up), SignalHeader: NewSignalHeader(),
Source: source, Source: source,
State: state, State: state,
Time: t, Time: t,
@ -101,19 +61,11 @@ func (signal EventControlSignal) String() string {
func NewEventControlSignal(command EventCommand) *EventControlSignal { func NewEventControlSignal(command EventCommand) *EventControlSignal {
return &EventControlSignal{ return &EventControlSignal{
NewSignalHeader(Direct), NewSignalHeader(),
command, command,
} }
} }
func (signal EventControlSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[EventControlSignal]()): {
Hash("command", string(signal.Command)): nil,
},
}
}
func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, state_start time.Time) { func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState, state_start time.Time) {
if ext.State != state { if ext.State != state {
ext.StateStart = state_start ext.StateStart = state_start
@ -123,14 +75,10 @@ func (ext *EventExt) UpdateState(node *Node, changes Changes, state EventState,
} }
} }
func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { func (ext *EventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
if signal.Direction() == Up && ext.Parent != node.ID {
messages = messages.Add(ctx, ext.Parent, node, nil, signal)
}
return messages, changes return messages, changes
} }
@ -165,27 +113,27 @@ var test_event_commands = EventCommandMap{
} }
func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
switch sig := signal.(type) { switch sig := signal.(type) {
case *EventControlSignal: case *EventControlSignal:
event_ext, err := GetExt[EventExt](node) event_ext, err := GetExt[EventExt](node)
if err != nil { if err != nil {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_event")) messages = append(messages, SendMsg{source, NewErrorSignal(sig.Id, "not_event")})
} else { } else {
ctx.Log.Logf("event", "%s got %s EventControlSignal while in %s", node.ID, sig.Command, event_ext.State) ctx.Log.Logf("event", "%s got %s EventControlSignal while in %s", node.ID, sig.Command, event_ext.State)
new_state, error_signal := event_ext.ValidateEventCommand(sig, test_event_commands) new_state, error_signal := event_ext.ValidateEventCommand(sig, test_event_commands)
if error_signal != nil { if error_signal != nil {
messages = messages.Add(ctx, source, node, nil, error_signal) messages = append(messages, SendMsg{source, error_signal})
} else { } else {
switch sig.Command { switch sig.Command {
case "start": case "start":
node.QueueSignal(time.Now().Add(ext.Length), NewEventControlSignal("finish")) node.QueueSignal(time.Now().Add(ext.Length), NewEventControlSignal("finish"))
} }
event_ext.UpdateState(node, changes, new_state, time.Now()) event_ext.UpdateState(node, changes, new_state, time.Now())
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) messages = append(messages, SendMsg{source, NewSuccessSignal(sig.Id)})
} }
} }
} }

@ -11,11 +11,13 @@ func TestEvent(t *testing.T) {
ctx := logTestContext(t, []string{"event", "listener", "listener_debug"}) ctx := logTestContext(t, []string{"event", "listener", "listener_debug"})
err := RegisterExtension[TestEventExt](ctx, nil) err := RegisterExtension[TestEventExt](ctx, nil)
fatalErr(t, err) fatalErr(t, err)
err = RegisterObject[TestEventExt](ctx)
fatalErr(t, err)
event_public, event_private, err := ed25519.GenerateKey(rand.Reader) event_public, event_private, err := ed25519.GenerateKey(rand.Reader)
event_listener := NewListenerExt(100) event_listener := NewListenerExt(100)
event, err := NewNode(ctx, event_private, "Base", 100, nil, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener) event, err := NewNode(ctx, event_private, "Base", 100, NewEventExt(KeyID(event_public), "Test Event"), &TestEventExt{time.Second}, event_listener)
fatalErr(t, err) fatalErr(t, err)
response, signals := testSend(t, ctx, NewEventControlSignal("ready?"), event, event) response, signals := testSend(t, ctx, NewEventControlSignal("ready?"), event, event)

@ -7,7 +7,7 @@ import (
// Extensions are data attached to nodes that process signals // Extensions are data attached to nodes that process signals
type Extension interface { type Extension interface {
// Called to process incoming signals, returning changes and messages to send // Called to process incoming signals, returning changes and messages to send
Process(*Context, *Node, NodeID, Signal) (Messages, Changes) Process(*Context, *Node, NodeID, Signal) ([]SendMsg, Changes)
// Called when the node is loaded into a context(creation or move), so extension data can be initialized // Called when the node is loaded into a context(creation or move), so extension data can be initialized
Load(*Context, *Node) error Load(*Context, *Node) error

@ -31,6 +31,6 @@ require (
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/stretchr/testify v1.8.2 // indirect github.com/stretchr/testify v1.8.2 // indirect
go.opencensus.io v0.22.5 // indirect go.opencensus.io v0.22.5 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.13.0 // indirect
) )

@ -111,6 +111,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=

684
gql.go

@ -1,18 +1,12 @@
package graphvent package graphvent
import ( import (
"bytes"
"context" "context"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh" "crypto/ecdh"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -20,12 +14,9 @@ import (
"net" "net"
"net/http" "net/http"
"reflect" "reflect"
"strings"
"sync" "sync"
"time" "time"
"filippo.io/edwards25519"
"crypto/sha512"
"github.com/gobwas/ws" "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
@ -40,78 +31,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
func NodeInterfaceDefaultIsType(required_extensions []ExtType) func(graphql.IsTypeOfParams) bool {
return func(p graphql.IsTypeOfParams) bool {
ctx, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return false
}
node, ok := p.Value.(NodeResult)
if ok == false {
return false
}
node_type_def, exists := ctx.Context.Nodes[node.NodeType]
if exists == false {
return false
} else {
for _, ext := range(required_extensions) {
found := false
for _, e := range(node_type_def.Extensions) {
if e == ext {
found = true
break
}
}
if found == false {
return false
}
}
}
return true
}
}
func NodeInterfaceResolveType(required_extensions []ExtType, default_type **graphql.Object)func(graphql.ResolveTypeParams) *graphql.Object {
return func(p graphql.ResolveTypeParams) *graphql.Object {
ctx, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false {
return nil
}
node, ok := p.Value.(NodeResult)
if ok == false {
return nil
}
gql_type, exists := ctx.GQLContext.NodeTypes[node.NodeType]
ctx.Context.Log.Logf("gql", "GQL_INTERFACE_RESOLVE_TYPE(%+v): %+v - %t - %+v - %+v", node, gql_type, exists, required_extensions, *default_type)
if exists == false {
node_type_def, exists := ctx.Context.Nodes[node.NodeType]
if exists == false {
return nil
} else {
for _, ext := range(required_extensions) {
found := false
for _, e := range(node_type_def.Extensions) {
if e == ext {
found = true
break
}
}
if found == false {
return nil
}
}
}
return *default_type
}
return gql_type
}
}
func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) { func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) {
resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) resolve_context, ok := p.Context.Value("resolve").(*ResolveContext)
if ok == false { if ok == false {
@ -121,7 +40,7 @@ func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) {
return resolve_context, nil return resolve_context, nil
} }
// TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does // TODO: Make composabe by checking if K is a slice, then recursing in the same way that ExtractList does
func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) { func ExtractParam[K interface{}](p graphql.ResolveParams, name string) (K, error) {
var zero K var zero K
arg_if, ok := p.Args[name] arg_if, ok := p.Args[name]
@ -157,20 +76,6 @@ func ExtractList[K interface{}](p graphql.ResolveParams, name string) ([]K, erro
return ret, nil return ret, nil
} }
func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) {
id_str, err := ExtractParam[string](p, name)
if err != nil {
return ZeroID, err
}
id, err := ParseID(id_str)
if err != nil {
return ZeroID, err
}
return id, nil
}
func GraphiQLHandler() func(http.ResponseWriter, *http.Request) { func GraphiQLHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
graphiql_string := fmt.Sprintf(` graphiql_string := fmt.Sprintf(`
@ -315,9 +220,6 @@ type ResolveContext struct {
// Graph Context this resolver is running under // Graph Context this resolver is running under
Context *Context Context *Context
// GQL Extension context this resolver is running under
GQLContext *GQLExtContext
// Pointer to the node that's currently processing this request // Pointer to the node that's currently processing this request
Server *Node Server *Node
@ -326,212 +228,6 @@ type ResolveContext struct {
// Cache of resolved nodes // Cache of resolved nodes
NodeCache map[NodeID]NodeResult NodeCache map[NodeID]NodeResult
// Authorization from the user that started this request
Authorization *ClientAuthorization
}
func AuthB64(client_key ed25519.PrivateKey, server_pubkey ed25519.PublicKey) (string, error) {
token_start := time.Now()
token_start_bytes, err := token_start.MarshalBinary()
if err != nil {
return "", err
}
session_key_public, session_key_private, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return "", err
}
session_h := sha512.Sum512(session_key_private.Seed())
ecdh_client, err := ECDH.NewPrivateKey(session_h[:32])
if err != nil {
return "", err
}
server_point, err := (&edwards25519.Point{}).SetBytes(server_pubkey)
if err != nil {
return "", err
}
ecdh_server, err := ECDH.NewPublicKey(server_point.BytesMontgomery())
if err != nil {
return "", err
}
secret, err := ecdh_client.ECDH(ecdh_server)
if err != nil {
return "", err
}
if len(secret) != 32 {
return "", fmt.Errorf("ECDH secret not 32 bytes(for AES-256): %d bytes long", len(secret))
}
block, err := aes.NewCipher(secret)
if err != nil {
return "", err
}
iv := make([]byte, block.BlockSize())
iv_len, err := rand.Reader.Read(iv)
if err != nil {
return "", err
} else if iv_len != block.BlockSize() {
return "", fmt.Errorf("Not enough iv bytes read: %d", iv_len)
}
var key_encrypted bytes.Buffer
stream := cipher.NewOFB(block, iv)
writer := &cipher.StreamWriter{S: stream, W: &key_encrypted}
bytes_written, err := writer.Write(session_key_private.Seed())
if err != nil {
return "", err
} else if bytes_written != len(ecdh_client.Bytes()) {
return "", fmt.Errorf("wrong number of bytes encrypted %d/%d", bytes_written, len(ecdh_client.Bytes()))
}
digest := append(session_key_public, token_start_bytes...)
signature, err := client_key.Sign(rand.Reader, digest, crypto.Hash(0))
if err != nil {
return "", err
}
start_b64 := base64.StdEncoding.EncodeToString(token_start_bytes)
iv_b64 := base64.StdEncoding.EncodeToString(iv)
encrypted_b64 := base64.StdEncoding.EncodeToString(key_encrypted.Bytes())
key_b64 := base64.StdEncoding.EncodeToString(session_key_public)
sig_b64 := base64.StdEncoding.EncodeToString(signature)
id_b64 := base64.StdEncoding.EncodeToString(client_key.Public().(ed25519.PublicKey))
return base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{id_b64, iv_b64, key_b64, encrypted_b64, start_b64, sig_b64}, ":"))), nil
}
func ParseAuthB64(auth_base64 string, server_id ed25519.PrivateKey) (*ClientAuthorization, error) {
joined_b64, err := base64.StdEncoding.DecodeString(auth_base64)
if err != nil {
return nil, err
}
auth_parts := strings.Split(string(joined_b64), ":")
if len(auth_parts) != 6 {
return nil, fmt.Errorf("Wrong number of delimited elements %d", len(auth_parts))
}
id_bytes, err := base64.StdEncoding.DecodeString(auth_parts[0])
if err != nil {
return nil, err
}
iv, err := base64.StdEncoding.DecodeString(auth_parts[1])
if err != nil {
return nil, err
}
public_key, err := base64.StdEncoding.DecodeString(auth_parts[2])
if err != nil {
return nil, err
}
key_encrypted, err := base64.StdEncoding.DecodeString(auth_parts[3])
if err != nil {
return nil, err
}
start_bytes, err := base64.StdEncoding.DecodeString(auth_parts[4])
if err != nil {
return nil, err
}
signature, err := base64.StdEncoding.DecodeString(auth_parts[5])
if err != nil {
return nil, err
}
var start time.Time
err = start.UnmarshalBinary(start_bytes)
if err != nil {
return nil, err
}
client_id := ed25519.PublicKey(id_bytes)
if err != nil {
return nil, err
}
client_point, err := (&edwards25519.Point{}).SetBytes(public_key)
if err != nil {
return nil, err
}
ecdh_client, err := ECDH.NewPublicKey(client_point.BytesMontgomery())
if err != nil {
return nil, err
}
h := sha512.Sum512(server_id.Seed())
ecdh_server, err := ECDH.NewPrivateKey(h[:32])
if err != nil {
return nil, err
}
secret, err := ecdh_server.ECDH(ecdh_client)
if err != nil {
return nil, err
} else if len(secret) != 32 {
return nil, fmt.Errorf("Secret wrong length: %d/32", len(secret))
}
block, err := aes.NewCipher(secret)
if err != nil {
return nil, err
}
encrypted_reader := bytes.NewReader(key_encrypted)
stream := cipher.NewOFB(block, iv)
reader := cipher.StreamReader{S: stream, R: encrypted_reader}
var decrypted_key bytes.Buffer
_, err = io.Copy(&decrypted_key, reader)
if err != nil {
return nil, err
}
session_key := ed25519.NewKeyFromSeed(decrypted_key.Bytes())
digest := append(session_key.Public().(ed25519.PublicKey), start_bytes...)
if ed25519.Verify(client_id, digest, signature) == false {
return nil, fmt.Errorf("Failed to verify digest/signature against client_id")
}
return &ClientAuthorization{
AuthInfo: AuthInfo{
Identity: client_id,
Start: start,
Signature: signature,
},
Key: session_key,
}, nil
}
func ValidateAuthorization(auth Authorization, valid time.Duration) error {
// Check that the time + valid < now
// Check that Signature is public_key + start signed with client_id
if auth.Start.Add(valid).Compare(time.Now()) != 1 {
return fmt.Errorf("authorization expired")
}
time_bytes, err := auth.Start.MarshalBinary()
if err != nil {
return err
}
digest := append(auth.Key, time_bytes...)
if ed25519.Verify(auth.Identity, digest, auth.Signature) != true {
return fmt.Errorf("verification failed")
}
return nil
} }
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveContext, error) {
@ -540,10 +236,8 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt) (*ResolveCon
Ext: gql_ext, Ext: gql_ext,
Chans: map[uuid.UUID]chan Signal{}, Chans: map[uuid.UUID]chan Signal{},
Context: ctx, Context: ctx,
GQLContext: ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext),
NodeCache: map[NodeID]NodeResult{}, NodeCache: map[NodeID]NodeResult{},
Server: server, Server: server,
Authorization: nil,
}, nil }, nil
} }
@ -557,13 +251,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
} }
ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS") ctx.Log.Logm("gql", header_map, "REQUEST_HEADERS")
auth, err := ParseAuthB64(r.Header.Get("Authorization"), server.Key)
if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ID_PARSE_ERROR: %s", err)
json.NewEncoder(w).Encode(GQLUnauthorized(""))
return
}
resolve_context, err := NewResolveContext(ctx, server, gql_ext) resolve_context, err := NewResolveContext(ctx, server, gql_ext)
if err != nil { if err != nil {
ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err) ctx.Log.Logf("gql", "GQL_AUTH_ERR: %s", err)
@ -571,8 +258,6 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
return return
} }
resolve_context.Authorization = auth
req_ctx := context.Background() req_ctx := context.Background()
req_ctx = context.WithValue(req_ctx, "resolve", resolve_context) req_ctx = context.WithValue(req_ctx, "resolve", resolve_context)
@ -585,10 +270,10 @@ func GQLHandler(ctx *Context, server *Node, gql_ext *GQLExt) func(http.ResponseW
query := GQLPayload{} query := GQLPayload{}
json.Unmarshal(str, &query) json.Unmarshal(str, &query)
gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext) schema := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(graphql.Schema)
params := graphql.Params{ params := graphql.Params{
Schema: gql_context.Schema, Schema: schema,
Context: req_ctx, Context: req_ctx,
RequestString: query.Query, RequestString: query.Query,
} }
@ -716,14 +401,6 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
break break
} }
authorization, err := ParseAuthB64(connection_params.Payload.Token, server.Key)
if err != nil {
ctx.Log.Logf("gqlws", "WS_AUTH_PARSE_ERR: %s", err)
break
}
resolve_context.Authorization = authorization
conn_state = "ready" conn_state = "ready"
err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}")) err = wsutil.WriteServerMessage(conn, 1, []byte("{\"type\": \"connection_ack\"}"))
if err != nil { if err != nil {
@ -739,9 +416,9 @@ func GQLWSHandler(ctx * Context, server *Node, gql_ext *GQLExt) func(http.Respon
} }
} else if msg.Type == "subscribe" { } else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
gql_context := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(*GQLExtContext) schema := ctx.Extensions[ExtTypeFor[GQLExt]()].Data.(graphql.Schema)
params := graphql.Params{ params := graphql.Params{
Schema: gql_context.Schema, Schema: schema,
Context: req_ctx, Context: req_ctx,
RequestString: msg.Payload.Query, RequestString: msg.Payload.Query,
} }
@ -829,165 +506,10 @@ type Field struct {
Field *graphql.Field Field *graphql.Field
} }
// GQL Specific Context information
type GQLExtContext struct {
// Generated GQL schema
Schema graphql.Schema
// Custom graphql types, mapped to NodeTypes
NodeTypes map[NodeType]*graphql.Object
Interfaces map[string]*Interface
Fields map[string]Field
// Schema parameters
Types []graphql.Type
Query *graphql.Object
Mutation *graphql.Object
}
func (ctx *GQLExtContext) GetACLFields(obj_name string, names []string) (map[ExtType][]string, error) {
ext_fields := map[ExtType][]string{}
for _, name := range(names) {
switch name {
case "ID":
case "TypeHash":
default:
field, exists := ctx.Fields[name]
if exists == false {
continue
}
ext, exists := ext_fields[field.Ext]
if exists == false {
ext = []string{}
}
ext = append(ext, field.Name)
ext_fields[field.Ext] = ext
}
}
return ext_fields, nil
}
func BuildSchema(ctx *GQLExtContext) (graphql.Schema, error) {
schemaConfig := graphql.SchemaConfig{
Types: ctx.Types,
Query: ctx.Query,
Mutation: ctx.Mutation,
}
return graphql.NewSchema(schemaConfig)
}
func (ctx *GQLExtContext) RegisterField(gql_type graphql.Type, gql_name string, ext_type ExtType, gv_tag string, resolve_fn func(graphql.ResolveParams, *ResolveContext, reflect.Value)(interface{}, error)) error {
if ctx == nil {
return fmt.Errorf("ctx is nil")
}
if resolve_fn == nil {
return fmt.Errorf("resolve_fn cannot be nil")
}
_, exists := ctx.Fields[gql_name]
if exists == true {
return fmt.Errorf("%s is already a field in the context, cannot add again", gql_name)
}
// Resolver has p.Source.(NodeResult) = read result of current node
resolver := func(p graphql.ResolveParams)(interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
node, ok := p.Source.(NodeResult)
if ok == false {
return nil, fmt.Errorf("p.Value is not NodeResult")
}
ext, ext_exists := node.Data[ext_type]
if ext_exists == false {
return nil, fmt.Errorf("%+v is not in the extensions of the result: %+v", ext_type, node.Data)
}
val_ser, field_exists := ext[gv_tag]
if field_exists == false {
return nil, fmt.Errorf("%s is not in the fields of %+v in the result for %s - %+v", gv_tag, ext_type, gql_name, node)
}
if val_ser.TypeStack[0] == SerializedTypeFor[error]() {
return nil, fmt.Errorf(string(val_ser.Data))
}
field_type, _, err := DeserializeType(ctx.Context, val_ser.TypeStack)
if err != nil {
return nil, err
}
field_value, _, err := DeserializeValue(ctx.Context, field_type, val_ser.Data)
if err != nil {
return nil, err
}
ctx.Context.Log.Logf("gql", "Resolving %+v", field_value)
return resolve_fn(p, ctx, field_value)
}
ctx.Fields[gql_name] = Field{ext_type, gv_tag, &graphql.Field{
Type: gql_type,
Resolve: resolver,
}}
return nil
}
func GQLInterfaces(ctx *GQLExtContext, interface_names []string) ([]*graphql.Interface, error) {
ret := make([]*graphql.Interface, len(interface_names))
for i, in := range(interface_names) {
ctx_interface, exists := ctx.Interfaces[in]
if exists == false {
return nil, fmt.Errorf("%s is not in GQLExtContext.Interfaces", in)
}
ret[i] = ctx_interface.Interface
}
return ret, nil
}
func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtType, error) {
fields := graphql.Fields{
"ID": &graphql.Field{
Type: graphql.String,
Resolve: ResolveNodeID,
},
"TypeHash": &graphql.Field{
Type: graphql.String,
Resolve: ResolveNodeTypeHash,
},
}
exts := map[ExtType]ExtType{}
ext_list := []ExtType{}
for _, name := range(field_names) {
field, exists := ctx.Fields[name]
if exists == false {
return nil, nil, fmt.Errorf("%s is not in GQLExtContext.Fields", name)
}
fields[name] = field.Field
_, exists = exts[field.Ext]
if exists == false {
ext_list = append(ext_list, field.Ext)
exts[field.Ext] = field.Ext
}
}
return fields, ext_list, nil
}
type NodeResult struct { type NodeResult struct {
NodeID NodeID NodeID NodeID
NodeType NodeType NodeType NodeType
Data map[ExtType]map[string]SerializedValue Data map[ExtType]map[string]interface{}
} }
type ListField struct { type ListField struct {
@ -1002,193 +524,6 @@ type SelfField struct {
ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (*NodeID, error) ResolveFn func(graphql.ResolveParams, *ResolveContext, reflect.Value) (*NodeID, error)
} }
func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error {
if interfaces == nil {
return fmt.Errorf("interfaces is nil")
}
if fields == nil {
return fmt.Errorf("fields is nil")
}
_, exists := ctx.Interfaces[name]
if exists == true {
return fmt.Errorf("%s is already an interface in ctx", name)
}
node_interfaces, err := GQLInterfaces(ctx, interfaces)
if err != nil {
return err
}
node_fields, node_exts, err := GQLFields(ctx, fields)
if err != nil {
return err
}
ctx_interface := Interface{}
ctx_interface.Interface = graphql.NewInterface(graphql.InterfaceConfig{
Name: name,
ResolveType: NodeInterfaceResolveType(node_exts, &ctx_interface.Default),
Fields: node_fields,
})
ctx_interface.List = graphql.NewList(ctx_interface.Interface)
for field_name, field := range(self_fields) {
self_field := field
err := ctx.RegisterField(ctx_interface.Interface, field_name, self_field.Extension, self_field.ACLName,
func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
id, err := self_field.ResolveFn(p, ctx, value)
if err != nil {
return nil, err
}
if id != nil {
nodes, err := ResolveNodes(ctx, p, []NodeID{*id})
if err != nil {
return nil, err
} else if len(nodes) != 1 {
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes[0], nil
} else {
return nil, nil
}
})
if err != nil {
return err
}
ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field)
node_fields[field_name] = ctx.Fields[field_name].Field
}
for field_name, field := range(list_fields) {
list_field := field
resolve_fn := func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value)(interface{}, error) {
var zero NodeID
ids, err := list_field.ResolveFn(p, ctx, value)
if err != nil {
return zero, err
}
nodes, err := ResolveNodes(ctx, p, ids)
if err != nil {
return nil, err
} else if len(nodes) != len(ids) {
return nil, fmt.Errorf("wrong length of nodes returned")
}
return nodes, nil
}
err := ctx.RegisterField(ctx_interface.List, field_name, list_field.Extension, list_field.ACLName, resolve_fn)
if err != nil {
return err
}
ctx_interface.Interface.AddFieldConfig(field_name, ctx.Fields[field_name].Field)
node_fields[field_name] = ctx.Fields[field_name].Field
}
ctx_interface.Default = graphql.NewObject(graphql.ObjectConfig{
Name: default_name,
Interfaces: append(node_interfaces, ctx_interface.Interface),
IsTypeOf: NodeInterfaceDefaultIsType(node_exts),
Fields: node_fields,
})
ctx.Interfaces[name] = &ctx_interface
ctx.Types = append(ctx.Types, ctx_interface.Default)
return nil
}
func (ctx *GQLExtContext) RegisterNodeType(node_type NodeType, name string, interface_names []string, field_names []string) error {
if field_names == nil {
return fmt.Errorf("fields is nil")
}
_, exists := ctx.NodeTypes[node_type]
if exists == true {
return fmt.Errorf("%+v already in GQLExtContext.NodeTypes", node_type)
}
node_interfaces, err := GQLInterfaces(ctx, interface_names)
if err != nil {
return err
}
gql_fields, _, err := GQLFields(ctx, field_names)
if err != nil {
return err
}
gql_type := graphql.NewObject(graphql.ObjectConfig{
Name: name,
Interfaces: node_interfaces,
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
node, ok := p.Value.(NodeResult)
if ok == false {
return false
}
return node.NodeType == node_type
},
Fields: gql_fields,
})
ctx.NodeTypes[node_type] = gql_type
ctx.Types = append(ctx.Types, gql_type)
return nil
}
func NewGQLExtContext() *GQLExtContext {
query := graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: graphql.Fields{
"Test": &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return "Test Data", nil
},
},
},
})
mutation := graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: graphql.Fields{
"Test": &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return "Test Mutation Data", nil
},
},
},
})
context := GQLExtContext{
Schema: graphql.Schema{},
Types: []graphql.Type{},
Query: query,
Mutation: mutation,
NodeTypes: map[NodeType]*graphql.Object{},
Interfaces: map[string]*Interface{},
Fields: map[string]Field{},
}
schema, err := BuildSchema(&context)
if err != nil {
panic(err)
}
context.Schema = schema
return &context
}
type SubscriptionInfo struct { type SubscriptionInfo struct {
ID uuid.UUID ID uuid.UUID
Channel chan interface{} Channel chan interface{}
@ -1295,9 +630,10 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
return response_chan return response_chan
} }
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
// Process ReadResultSignalType by forwarding it to the waiting resolver // Process ReadResultSignalType by forwarding it to the waiting resolver
var changes = Changes{} var changes Changes = nil
var messages []SendMsg = nil
switch sig := signal.(type) { switch sig := signal.(type) {
case *SuccessSignal: case *SuccessSignal:
@ -1355,7 +691,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
ext.subscriptions_lock.RUnlock() ext.subscriptions_lock.RUnlock()
} }
return nil, changes return messages, changes
} }
var ecdsa_curves = map[uint8]elliptic.Curve{ var ecdsa_curves = map[uint8]elliptic.Curve{

@ -1,11 +1,10 @@
package graphvent package graphvent
import ( import (
"time"
"reflect" "reflect"
"fmt" "fmt"
"time"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/ast"
"github.com/google/uuid"
) )
func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) { func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
@ -17,7 +16,7 @@ func ResolveNodeID(p graphql.ResolveParams) (interface{}, error) {
return node.NodeID, nil return node.NodeID, nil
} }
func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { func ResolveNodeType(p graphql.ResolveParams) (interface{}, error) {
node, ok := p.Source.(NodeResult) node, ok := p.Source.(NodeResult)
if ok == false { if ok == false {
return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source)) return nil, fmt.Errorf("Can't get TypeHash from %+v", reflect.TypeOf(p.Source))
@ -37,7 +36,6 @@ func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string {
case *ast.Field: case *ast.Field:
names = append(names, field.Name.Value) names = append(names, field.Name.Value)
case *ast.InlineFragment: case *ast.InlineFragment:
names = append(names, GetFieldNames(ctx, field.SelectionSet)...)
default: default:
ctx.Log.Logf("gql", "Unknown selection type: %s", reflect.TypeOf(field)) ctx.Log.Logf("gql", "Unknown selection type: %s", reflect.TypeOf(field))
} }
@ -46,144 +44,109 @@ func GetFieldNames(ctx *Context, selection_set *ast.SelectionSet) []string {
return names return names
} }
func GetResolveFields(ctx *Context, p graphql.ResolveParams) []string { // Returns the fields that need to be resolved
func GetResolveFields(id NodeID, ctx *ResolveContext, p graphql.ResolveParams) (map[ExtType][]string, error) {
node_info, mapped := ctx.Context.NodeTypes[p.Info.ReturnType.Name()]
if mapped == false {
return nil, fmt.Errorf("No NodeType %s", p.Info.ReturnType.Name())
}
fields := map[ExtType][]string{}
names := []string{} names := []string{}
for _, field := range(p.Info.FieldASTs) { for _, field := range(p.Info.FieldASTs) {
names = append(names, GetFieldNames(ctx, field.SelectionSet)...) names = append(names, GetFieldNames(ctx.Context, field.SelectionSet)...)
} }
return names cache, node_cached := ctx.NodeCache[id]
} for _, name := range(names) {
if name == "ID" || name == "Type" {
continue
}
func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([]NodeResult, error) { ext_type, field_mapped := node_info.Fields[name]
fields := GetResolveFields(ctx.Context, p) if field_mapped == false {
ctx.Context.Log.Logf("gql_resolve_node", "RESOLVE_NODES(%+v): %+v", ids, fields) return nil, fmt.Errorf("NodeType %s does not have field %s", p.Info.ReturnType.Name(), name)
}
resp_channels := map[uuid.UUID]chan Signal{} ext_fields, exists := fields[ext_type]
indices := map[uuid.UUID]int{} if exists == false {
ext_fields = []string{}
}
// Get a list of fields that will be written if node_cached {
ext_fields, err := ctx.GQLContext.GetACLFields(p.Info.FieldName, fields) ext_cache, ext_cached := cache.Data[ext_type]
if err != nil { if ext_cached {
return nil, err _, field_cached := ext_cache[name]
} if field_cached {
ctx.Context.Log.Logf("gql_resolve_node", "ACL Fields from request: %+v", ext_fields) continue
responses := make([]NodeResult, len(ids))
for i, id := range(ids) {
var read_signal *ReadSignal = nil
node, cached := ctx.NodeCache[id]
if cached == true {
resolve := false
missing_exts := map[ExtType][]string{}
for ext_type, fields := range(ext_fields) {
cached_ext, exists := node.Data[ext_type]
if exists == true {
missing_fields := []string{}
for _, field_name := range(fields) {
_, found := cached_ext[field_name]
if found == false {
missing_fields = append(missing_fields, field_name)
}
}
if len(missing_fields) > 0 {
missing_exts[ext_type] = missing_fields
resolve = true
}
} else {
missing_exts[ext_type] = fields
resolve = true
} }
} }
if resolve == true {
read_signal = NewReadSignal(missing_exts)
ctx.Context.Log.Logf("gql_resolve_node", "sending read for %+v because of missing fields %+v", id, missing_exts)
} else {
ctx.Context.Log.Logf("gql_resolve_node", "Using cached response for %+v(%d)", id, i)
responses[i] = node
continue
}
} else {
ctx.Context.Log.Logf("gql_resolve_node", "sending read for %+v", id)
read_signal = NewReadSignal(ext_fields)
}
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
msgs := Messages{}
msgs = msgs.Add(ctx.Context, id, ctx.Server, ctx.Authorization, read_signal)
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID())
resp_channels[read_signal.ID()] = response_chan
indices[read_signal.ID()] = i
// TODO: Send all at once instead of creating Messages for each
err = ctx.Context.Send(msgs)
if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err
} }
fields[ext_type] = append(ext_fields, name)
} }
errors := "" return fields, nil
for sig_id, response_chan := range(resp_channels) { }
// Wait for the response, returning an error on timeout
response, other, err := WaitForResponse(response_chan, time.Millisecond*100, sig_id)
if err != nil {
return nil, err
}
ctx.Context.Log.Logf("gql_resolve_node", "GQL node response: %+v", response)
ctx.Context.Log.Logf("gql_resolve_node", "GQL node other messages: %+v", other)
// for now, just put signals we didn't want back into the 'queue' func ResolveNode(id NodeID, p graphql.ResolveParams) (NodeResult, error) {
for _, other_signal := range(other) { ctx, err := PrepResolve(p)
response_chan <- other_signal if err != nil {
} return NodeResult{}, err
}
error_signal, is_error := response.(*ErrorSignal) fields, err := GetResolveFields(id, ctx, p)
if is_error { if err != nil {
errors = fmt.Sprintf("%s, %s", errors, error_signal.Error) return NodeResult{}, err
continue }
}
read_response, is_read_response := response.(*ReadResultSignal) ctx.Context.Log.Logf("gql", "Resolving fields %+v on node %s", fields, id)
if is_read_response == false {
errors = fmt.Sprintf("%s, wrong response type %+v", errors, reflect.TypeOf(response))
continue
}
idx := indices[sig_id] signal := NewReadSignal(fields)
responses[idx] = NodeResult{ response_chan := ctx.Ext.GetResponseChannel(signal.ID())
read_response.NodeID, // TODO: TIMEOUT DURATION
read_response.NodeType, err = ctx.Context.Send(ctx.Server, []SendMsg{{
read_response.Extensions, Dest: id,
} Signal: signal,
}})
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return NodeResult{}, err
}
cache, exists := ctx.NodeCache[read_response.NodeID] response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
if exists == true { ctx.Ext.FreeResponseChannel(signal.ID())
ctx.Context.Log.Logf("gql_resolve_node", "Merging new response with cached: %s, %+v - %+v", read_response.NodeID, cache, read_response.Extensions) if err != nil {
for ext_type, fields := range(read_response.Extensions) { return NodeResult{}, err
cached_fields, exists := cache.Data[ext_type] }
if exists == false {
cached_fields = map[string]SerializedValue{} switch response := response.(type) {
cache.Data[ext_type] = cached_fields case *ReadResultSignal:
} cache, node_cached := ctx.NodeCache[id]
for field_name, field_value := range(fields) { if node_cached == false {
cached_fields[field_name] = field_value cache = NodeResult{
} NodeID: id,
NodeType: response.NodeType,
Data: response.Extensions,
} }
responses[idx] = cache
} else { } else {
ctx.Context.Log.Logf("gql_resolve_node", "Adding new response to node cache: %s, %+v", read_response.NodeID, read_response.Extensions) for ext_type, ext_data := range(response.Extensions) {
ctx.NodeCache[read_response.NodeID] = responses[idx] cached_ext, ext_cached := cache.Data[ext_type]
if ext_cached {
for field_name, field := range(ext_data) {
cache.Data[ext_type][field_name] = field
}
} else {
cache.Data[ext_type] = ext_data
}
cache.Data[ext_type] = cached_ext
}
} }
}
if errors != "" { ctx.NodeCache[id] = cache
return nil, fmt.Errorf(errors) return ctx.NodeCache[id], nil
default:
return NodeResult{}, fmt.Errorf("Bad read response: %+v", response)
} }
ctx.Context.Log.Logf("gql_resolve_node", "RESOLVED_NODES %+v - %+v", ids, responses)
return responses, nil
} }

@ -1,145 +0,0 @@
package graphvent
import (
graphql "github.com/graphql-go/graphql"
"github.com/google/uuid"
"reflect"
"fmt"
"time"
)
type StructFieldInfo struct {
Name string
Type *TypeInfo
Index []int
}
func ArgumentInfo(ctx *Context, field reflect.StructField, gv_tag string) (StructFieldInfo, error) {
type_info, mapped := ctx.TypeReflects[field.Type]
if mapped == false {
return StructFieldInfo{}, fmt.Errorf("field %+v is of unregistered type %+v ", field.Name, field.Type)
}
return StructFieldInfo{
Name: gv_tag,
Type: type_info,
Index: field.Index,
}, nil
}
func SignalFromArgs(ctx *Context, signal_type reflect.Type, fields []StructFieldInfo, args map[string]interface{}, id_index, direction_index []int) (Signal, error) {
fmt.Printf("FIELD: %+v\n", fields)
signal_value := reflect.New(signal_type)
id_field := signal_value.Elem().FieldByIndex(id_index)
id_field.Set(reflect.ValueOf(uuid.New()))
direction_field := signal_value.Elem().FieldByIndex(direction_index)
direction_field.Set(reflect.ValueOf(Direct))
for _, field := range(fields) {
arg, arg_exists := args[field.Name]
if arg_exists == false {
return nil, fmt.Errorf("No arg provided named %s", field.Name)
}
field_value := signal_value.Elem().FieldByIndex(field.Index)
if field_value.CanConvert(field.Type.Reflect) == false {
return nil, fmt.Errorf("Arg %s wrong type %s/%s", field.Name, field_value.Type(), field.Type.Reflect)
}
value, err := field.Type.GQLValue(ctx, arg)
if err != nil {
return nil, err
}
fmt.Printf("Setting %s to %+v of type %+v\n", field.Name, value, value.Type())
field_value.Set(value)
}
return signal_value.Interface().(Signal), nil
}
func NewSignalMutation(ctx *Context, name string, send_id_key string, signal_type reflect.Type) (*graphql.Field, error) {
args := graphql.FieldConfigArgument{}
arg_info := []StructFieldInfo{}
var id_index []int = nil
var direction_index []int = nil
for _, field := range(reflect.VisibleFields(signal_type)) {
gv_tag, tagged_gv := field.Tag.Lookup("gv")
if tagged_gv {
if gv_tag == "id" {
id_index = field.Index
} else if gv_tag == "direction" {
direction_index = field.Index
} else {
_, exists := args[gv_tag]
if exists == true {
return nil, fmt.Errorf("Signal has repeated tag %s", gv_tag)
} else {
info, err := ArgumentInfo(ctx, field, gv_tag)
if err != nil {
return nil, err
}
args[gv_tag] = &graphql.ArgumentConfig{
}
arg_info = append(arg_info, info)
}
}
}
}
_, send_exists := args[send_id_key]
if send_exists == false {
args[send_id_key] = &graphql.ArgumentConfig{
Type: graphql.String,
}
}
resolve_signal := func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
if err != nil {
return nil, err
}
send_id, err := ExtractID(p, send_id_key)
if err != nil {
return nil, err
}
signal, err := SignalFromArgs(ctx.Context, signal_type, arg_info, p.Args, id_index, direction_index)
if err != nil {
return nil, err
}
msgs := Messages{}
msgs = msgs.Add(ctx.Context, send_id, ctx.Server, ctx.Authorization, signal)
response_chan := ctx.Ext.GetResponseChannel(signal.ID())
err = ctx.Context.Send(msgs)
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return nil, err
}
response, _, err := WaitForResponse(response_chan, 100*time.Millisecond, signal.ID())
if err != nil {
ctx.Ext.FreeResponseChannel(signal.ID())
return nil, err
}
_, is_success := response.(*SuccessSignal)
if is_success == true {
return "success", nil
}
error_response, is_error := response.(*ErrorSignal)
if is_error == true {
return "error", fmt.Errorf(error_response.Error)
}
return nil, fmt.Errorf("response of unhandled type %s", reflect.TypeOf(response))
}
return &graphql.Field{
Type: graphql.String,
Args: args,
Resolve: resolve_signal,
}, nil
}

@ -3,6 +3,7 @@ package graphvent
import ( import (
"testing" "testing"
"runtime/debug" "runtime/debug"
"time"
badger "github.com/dgraph-io/badger/v3" badger "github.com/dgraph-io/badger/v3"
) )
@ -28,7 +29,7 @@ func logTestContext(t * testing.T, components []string) *Context {
ctx, err := NewContext(db, NewConsoleLogger(components)) ctx, err := NewContext(db, NewConsoleLogger(components))
fatalErr(t, err) fatalErr(t, err)
err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()}, map[string]FieldIndex{}) err = RegisterNodeType(ctx, "LockableListener", []ExtType{ExtTypeFor[ListenerExt](), ExtTypeFor[LockableExt]()})
fatalErr(t, err) fatalErr(t, err)
return ctx return ctx
@ -44,3 +45,16 @@ func fatalErr(t * testing.T, err error) {
t.Fatal(err) t.Fatal(err)
} }
} }
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) (ResponseSignal, []Signal) {
source_listener, err := GetExt[ListenerExt](source)
fatalErr(t, err)
messages := []SendMsg{{destination.ID, signal}}
fatalErr(t, ctx.Send(source, messages))
response, signals, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
fatalErr(t, err)
return response, signals
}

@ -1,296 +0,0 @@
package graphvent
import (
"slices"
)
type AddSubGroupSignal struct {
SignalHeader
Name string `gv:"name"`
}
func NewAddSubGroupSignal(name string) *AddSubGroupSignal {
return &AddSubGroupSignal{
NewSignalHeader(Direct),
name,
}
}
func (signal AddSubGroupSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[AddSubGroupSignal]()): {
Hash("name", signal.Name): nil,
},
}
}
type RemoveSubGroupSignal struct {
SignalHeader
Name string `gv:"name"`
}
func NewRemoveSubGroupSignal(name string) *RemoveSubGroupSignal {
return &RemoveSubGroupSignal{
NewSignalHeader(Direct),
name,
}
}
func (signal RemoveSubGroupSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[RemoveSubGroupSignal]()): {
Hash("command", signal.Name): nil,
},
}
}
type AddMemberSignal struct {
SignalHeader
SubGroup string `gv:"sub_group"`
MemberID NodeID `gv:"member_id"`
}
type SubGroupGQL struct {
Name string
Members []NodeID
}
func (signal AddMemberSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[AddMemberSignal]()): {
Hash("sub_group", signal.SubGroup): nil,
},
}
}
func NewAddMemberSignal(sub_group string, member_id NodeID) *AddMemberSignal {
return &AddMemberSignal{
NewSignalHeader(Direct),
sub_group,
member_id,
}
}
type RemoveMemberSignal struct {
SignalHeader
SubGroup string `gv:"sub_group"`
MemberID NodeID `gv:"member_id"`
}
func (signal RemoveMemberSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[RemoveMemberSignal]()): {
Hash("sub_group", signal.SubGroup): nil,
},
}
}
func NewRemoveMemberSignal(sub_group string, member_id NodeID) *RemoveMemberSignal {
return &RemoveMemberSignal{
NewSignalHeader(Direct),
sub_group,
member_id,
}
}
var DefaultGroupPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ReadSignal]()): {
SerializedType(ExtTypeFor[GroupExt]()): {
SerializedType(GetFieldTag("sub_groups")): nil,
},
},
})
type SubGroup struct {
Members []NodeID
Permissions Tree
}
type MemberOfPolicy struct {
PolicyHeader
Groups map[NodeID]map[string]Tree
}
func NewMemberOfPolicy(groups map[NodeID]map[string]Tree) MemberOfPolicy {
return MemberOfPolicy{
PolicyHeader: NewPolicyHeader(),
Groups: groups,
}
}
func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal)
if ok == false {
return Deny
}
ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions)
group_ext_data, ok := sig.Extensions[ExtTypeFor[GroupExt]()]
if ok == false {
return Deny
}
sub_groups_ser, ok := group_ext_data["sub_groups"]
if ok == false {
return Deny
}
sub_groups_type, _, err := DeserializeType(ctx, sub_groups_ser.TypeStack)
if err != nil {
ctx.Log.Logf("group", "Type deserialize error: %s", err)
return Deny
}
sub_groups_if, _, err := DeserializeValue(ctx, sub_groups_type, sub_groups_ser.Data)
if err != nil {
ctx.Log.Logf("group", "Value deserialize error: %s", err)
return Deny
}
ext_sub_groups, ok := sub_groups_if.Interface().(map[string][]NodeID)
if ok == false {
return Deny
}
group, exists := policy.Groups[sig.NodeID]
if exists == false {
return Deny
}
for sub_group_name, permissions := range(group) {
ext_sub_group, exists := ext_sub_groups[sub_group_name]
if exists == true {
for _, member_id := range(ext_sub_group) {
if member_id == current.Principal {
if permissions.Allows(current.Action) == Allow {
return Allow
}
}
}
}
}
return Deny
}
// Send a read signal to Group to check if principal_id is a member of it
func (policy MemberOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) {
var messages Messages = nil
for group_id, sub_groups := range(policy.Groups) {
if group_id == node.ID {
ext, err := GetExt[GroupExt](node)
if err != nil {
ctx.Log.Logf("group", "MemberOfPolicy with self ID error: %s", err)
} else {
for sub_group_name, permission := range(sub_groups) {
ext_sub_group, exists := ext.SubGroups[sub_group_name]
if exists == true {
for _, member := range(ext_sub_group) {
if member == principal_id {
if permission.Allows(action) == Allow {
return nil, Allow
}
break
}
}
}
}
}
} else {
// Send the read request to the group so that ContinueAllows can parse the response to check membership
messages = messages.Add(ctx, group_id, node, nil, NewReadSignal(map[ExtType][]string{
ExtTypeFor[GroupExt](): {"sub_groups"},
}))
}
}
if len(messages) > 0 {
return messages, Pending
} else {
return nil, Deny
}
}
type GroupExt struct {
SubGroups map[string][]NodeID `gv:"sub_groups"`
}
func NewGroupExt(sub_groups map[string][]NodeID) *GroupExt {
if sub_groups == nil {
sub_groups = map[string][]NodeID{}
}
return &GroupExt{
SubGroups: sub_groups,
}
}
func (ext *GroupExt) Load(ctx *Context, node *Node) error {
return nil
}
func (ext *GroupExt) Unload(ctx *Context, node *Node) {
}
func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
var messages Messages = nil
var changes = Changes{}
switch sig := signal.(type) {
case *AddMemberSignal:
sub_group, exists := ext.SubGroups[sig.SubGroup]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
} else {
if slices.Contains(sub_group, sig.MemberID) == true {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_member"))
} else {
sub_group = append(sub_group, sig.MemberID)
ext.SubGroups[sig.SubGroup] = sub_group
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
changes.Add("sub_groups")
}
}
case *RemoveMemberSignal:
sub_group, exists := ext.SubGroups[sig.SubGroup]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
} else {
idx := slices.Index(sub_group, sig.MemberID)
if idx == -1 {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_member"))
} else {
sub_group = slices.Delete(sub_group, idx, idx+1)
ext.SubGroups[sig.SubGroup] = sub_group
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
changes.Add("sub_groups")
}
}
case *AddSubGroupSignal:
_, exists := ext.SubGroups[sig.Name]
if exists == true {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "already_subgroup"))
} else {
ext.SubGroups[sig.Name] = []NodeID{}
changes.Add("sub_groups")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
}
case *RemoveSubGroupSignal:
_, exists := ext.SubGroups[sig.Name]
if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "not_subgroup"))
} else {
delete(ext.SubGroups, sig.Name)
changes.Add("sub_groups")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id))
}
}
return messages, changes
}

@ -1,94 +0,0 @@
package graphvent
import (
"testing"
"time"
)
func TestGroupAdd(t *testing.T) {
ctx := logTestContext(t, []string{"listener", "test"})
group_listener := NewListenerExt(10)
group, err := NewNode(ctx, nil, "Base", 10, nil, group_listener, NewGroupExt(nil))
fatalErr(t, err)
add_subgroup_signal := NewAddSubGroupSignal("test_group")
messages := Messages{}
messages = messages.Add(ctx, group.ID, group, nil, add_subgroup_signal)
fatalErr(t, ctx.Send(messages))
resp_1, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_subgroup_signal.Id)
fatalErr(t, err)
error_1, is_error := resp_1.(*ErrorSignal)
if is_error {
t.Fatalf("Error returned: %s", error_1.Error)
}
user_id := RandID()
add_member_signal := NewAddMemberSignal("test_group", user_id)
messages = Messages{}
messages = messages.Add(ctx, group.ID, group, nil, add_member_signal)
fatalErr(t, ctx.Send(messages))
resp_2, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, add_member_signal.Id)
fatalErr(t, err)
error_2, is_error := resp_2.(*ErrorSignal)
if is_error {
t.Fatalf("Error returned: %s", error_2.Error)
}
read_signal := NewReadSignal(map[ExtType][]string{
ExtTypeFor[GroupExt](): {"sub_groups"},
})
messages = Messages{}
messages = messages.Add(ctx, group.ID, group, nil, read_signal)
fatalErr(t, ctx.Send(messages))
response, _, err := WaitForResponse(group_listener.Chan, 10*time.Millisecond, read_signal.Id)
fatalErr(t, err)
read_response := response.(*ReadResultSignal)
sub_groups_serialized := read_response.Extensions[ExtTypeFor[GroupExt]()]["sub_groups"]
sub_groups_type, remaining_types, err := DeserializeType(ctx, sub_groups_serialized.TypeStack)
fatalErr(t, err)
if len(remaining_types) > 0 {
t.Fatalf("Types remaining after deserializing subgroups: %d", len(remaining_types))
}
sub_groups_value, remaining, err := DeserializeValue(ctx, sub_groups_type, sub_groups_serialized.Data)
fatalErr(t, err)
if len(remaining) > 0 {
t.Fatalf("Data remaining after deserializing subgroups: %d", len(remaining_types))
}
sub_groups, ok := sub_groups_value.Interface().(map[string][]NodeID)
if ok != true {
t.Fatalf("sub_groups wrong type %s", sub_groups_value.Type())
}
if len(sub_groups) != 1 {
t.Fatalf("sub_groups wrong length %d", len(sub_groups))
}
test_subgroup, exists := sub_groups["test_group"]
if exists == false {
t.Fatal("test_group not in subgroups")
}
if len(test_subgroup) != 1 {
t.Fatalf("test_group wrong size %d/1", len(test_subgroup))
}
if test_subgroup[0] != user_id {
t.Fatalf("sub_groups wrong value %s", test_subgroup[0])
}
ctx.Log.Logf("test", "Read Response: %+v", read_response)
}

@ -11,17 +11,13 @@ type ListenerExt struct {
} }
func (ext *ListenerExt) Load(ctx *Context, node *Node) error { func (ext *ListenerExt) Load(ctx *Context, node *Node) error {
ext.Chan = make(chan Signal, ext.Buffer)
return nil return nil
} }
func (ext *ListenerExt) Unload(ctx *Context, node *Node) { func (ext *ListenerExt) Unload(ctx *Context, node *Node) {
} }
func (ext *ListenerExt) PostDeserialize(ctx *Context) error {
ext.Chan = make(chan Signal, ext.Buffer)
return nil
}
// Create a new listener extension with a given buffer size // Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt { func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{ return &ListenerExt{
@ -31,7 +27,7 @@ func NewListenerExt(buffer int) *ListenerExt {
} }
// Send the signal to the channel, logging an overflow if it occurs // Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal)) ctx.Log.Logf("listener", "%s - %+v", node.ID, reflect.TypeOf(signal))
ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal) ctx.Log.Logf("listener_debug", "%s->%s - %+v", source, node.ID, signal)
select { select {

@ -5,18 +5,6 @@ import (
"time" "time"
) )
var AllowParentUnlockPolicy = NewOwnerOfPolicy(Tree{
SerializedType(SignalTypeFor[LockSignal]()): {
Hash(LockStateBase, "unlock"): nil,
},
})
var AllowAnyLockPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[LockSignal]()): {
Hash(LockStateBase, "lock"): nil,
},
})
type ReqState byte type ReqState byte
const ( const (
Unlocked = ReqState(0) Unlocked = ReqState(0)
@ -37,11 +25,11 @@ var ReqStateStrings = map[ReqState]string {
type LockableExt struct{ type LockableExt struct{
State ReqState `gv:"state"` State ReqState `gv:"state"`
ReqID *uuid.UUID `gv:"req_id"` ReqID *uuid.UUID `gv:"req_id"`
Owner *NodeID `gv:"owner"` Owner *NodeID `gv:"owner" node:"Base"`
PendingOwner *NodeID `gv:"pending_owner"` PendingOwner *NodeID `gv:"pending_owner" node:"Base"`
PendingID uuid.UUID `gv:"pending_id"` PendingID uuid.UUID `gv:"pending_id"`
Requirements map[NodeID]ReqState `gv:"requirements"` Requirements map[NodeID]ReqState `gv:"requirements" node:"Lockable:"`
WaitInfos WaitMap `gv:"wait_infos"` WaitInfos WaitMap `gv:"wait_infos" node:":Base"`
} }
func NewLockableExt(requirements []NodeID) *LockableExt { func NewLockableExt(requirements []NodeID) *LockableExt {
@ -62,17 +50,15 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
} }
func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) { func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
messages := Messages{}
signal := NewLockSignal("unlock") signal := NewLockSignal("unlock")
messages = messages.Add(ctx, node.ID, node, nil, signal) messages := []SendMsg{{node.ID, signal}}
return signal.ID(), ctx.Send(messages) return signal.ID(), ctx.Send(node, messages)
} }
func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) {
messages := Messages{}
signal := NewLockSignal("lock") signal := NewLockSignal("lock")
messages = messages.Add(ctx, node.ID, node, nil, signal) messages := []SendMsg{{node.ID, signal}}
return signal.ID(), ctx.Send(messages) return signal.ID(), ctx.Send(node, messages)
} }
func (ext *LockableExt) Load(ctx *Context, node *Node) error { func (ext *LockableExt) Load(ctx *Context, node *Node) error {
@ -82,8 +68,8 @@ func (ext *LockableExt) Load(ctx *Context, node *Node) error {
func (ext *LockableExt) Unload(ctx *Context, node *Node) { func (ext *LockableExt) Unload(ctx *Context, node *Node) {
} }
func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) (Messages, Changes) { func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes Changes = nil var changes Changes = nil
info, info_found := node.ProcessResponse(ext.WaitInfos, signal) info, info_found := node.ProcessResponse(ext.WaitInfos, signal)
@ -126,7 +112,7 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI
ext.Requirements[id] = Unlocking ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock") lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
messages = messages.Add(ctx, id, node, nil, lock_signal) messages = append(messages, SendMsg{id, lock_signal})
ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID) ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID)
} }
} }
@ -153,43 +139,43 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI
return messages, changes return messages, changes
} }
func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) (Messages, Changes) { func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID, signal *LinkSignal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
if ext.State == Unlocked { if ext.State == Unlocked {
switch signal.Action { switch signal.Action {
case "add": case "add":
_, exists := ext.Requirements[signal.NodeID] _, exists := ext.Requirements[signal.NodeID]
if exists == true { if exists == true {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "already_requirement")) messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "already_requirement")})
} else { } else {
if ext.Requirements == nil { if ext.Requirements == nil {
ext.Requirements = map[NodeID]ReqState{} ext.Requirements = map[NodeID]ReqState{}
} }
ext.Requirements[signal.NodeID] = Unlocked ext.Requirements[signal.NodeID] = Unlocked
changes.Add("requirements") changes.Add("requirements")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())})
} }
case "remove": case "remove":
_, exists := ext.Requirements[signal.NodeID] _, exists := ext.Requirements[signal.NodeID]
if exists == false { if exists == false {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement")) messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "can't link: not_requirement")})
} else { } else {
delete(ext.Requirements, signal.NodeID) delete(ext.Requirements, signal.NodeID)
changes.Add("requirements") changes.Add("requirements")
messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) messages = append(messages, SendMsg{source, NewSuccessSignal(signal.ID())})
} }
default: default:
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "unknown_action")) messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "unknown_action")})
} }
} else { } else {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")})
} }
return messages, changes return messages, changes
} }
func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) (Messages, Changes) { func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
if source == node.ID { if source == node.ID {
return messages, changes return messages, changes
@ -218,7 +204,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
ext.State = Locked ext.State = Locked
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
changes.Add("state", "owner", "requirements") changes.Add("state", "owner", "requirements")
messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) messages = append(messages, SendMsg{*ext.Owner, NewSuccessSignal(ext.PendingID)})
} else { } else {
changes.Add("requirements") changes.Add("requirements")
ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements)) ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements))
@ -228,7 +214,7 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
lock_signal := NewLockSignal("unlock") lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", info.Destination, lock_signal, 100*time.Millisecond) ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", info.Destination, lock_signal, 100*time.Millisecond)
messages = messages.Add(ctx, info.Destination, node, nil, lock_signal) messages = append(messages, SendMsg{info.Destination, lock_signal})
ctx.Log.Logf("lockable", "sending abort_lock to %s for %s", info.Destination, node.ID) ctx.Log.Logf("lockable", "sending abort_lock to %s for %s", info.Destination, node.ID)
} }
@ -254,10 +240,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
ext.ReqID = nil ext.ReqID = nil
changes.Add("state", "owner", "req_id") changes.Add("state", "owner", "req_id")
messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) messages = append(messages, SendMsg{previous_owner, NewSuccessSignal(ext.PendingID)})
} else if old_state == AbortingLock { } else if old_state == AbortingLock {
changes.Add("state", "pending_owner") changes.Add("state", "pending_owner")
messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) messages = append(messages, SendMsg{*ext.PendingOwner, NewErrorSignal(*ext.ReqID, "not_unlocked")})
ext.PendingOwner = ext.Owner ext.PendingOwner = ext.Owner
} }
} else { } else {
@ -272,8 +258,8 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod
} }
// Handle a LockSignal and update the extensions owner/requirement states // Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) (Messages, Changes) { func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
switch signal.State { switch signal.State {
@ -286,7 +272,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.PendingOwner = &new_owner ext.PendingOwner = &new_owner
ext.Owner = &new_owner ext.Owner = &new_owner
changes.Add("state", "pending_owner", "owner") changes.Add("state", "pending_owner", "owner")
messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())})
} else { } else {
ext.State = Locking ext.State = Locking
id := signal.ID() id := signal.ID()
@ -304,11 +290,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("lock", id, lock_signal, 500*time.Millisecond) ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("lock", id, lock_signal, 500*time.Millisecond)
ext.Requirements[id] = Locking ext.Requirements[id] = Locking
messages = messages.Add(ctx, id, node, nil, lock_signal) messages = append(messages, SendMsg{id, lock_signal})
} }
} }
default: default:
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_unlocked")})
ctx.Log.Logf("lockable", "Tried to lock %s while %s", node.ID, ext.State) ctx.Log.Logf("lockable", "Tried to lock %s while %s", node.ID, ext.State)
} }
case "unlock": case "unlock":
@ -319,7 +305,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.PendingOwner = nil ext.PendingOwner = nil
ext.Owner = nil ext.Owner = nil
changes.Add("state", "pending_owner", "owner") changes.Add("state", "pending_owner", "owner")
messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) messages = append(messages, SendMsg{new_owner, NewSuccessSignal(signal.ID())})
} else if source == *ext.Owner { } else if source == *ext.Owner {
ext.State = Unlocking ext.State = Unlocking
id := signal.ID() id := signal.ID()
@ -336,11 +322,11 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
ext.Requirements[id] = Unlocking ext.Requirements[id] = Unlocking
messages = messages.Add(ctx, id, node, nil, lock_signal) messages = append(messages, SendMsg{id, lock_signal})
} }
} }
} else { } else {
messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_locked")) messages = append(messages, SendMsg{source, NewErrorSignal(signal.ID(), "not_locked")})
} }
default: default:
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State) ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", signal.State)
@ -348,8 +334,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID
return messages, changes return messages, changes
} }
func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) (Messages, Changes) { func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
wait_info, found := node.ProcessResponse(ext.WaitInfos, signal) wait_info, found := node.ProcessResponse(ext.WaitInfos, signal)
@ -380,7 +366,7 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod
ext.Requirements[id] = Unlocking ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock") lock_signal := NewLockSignal("unlock")
ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond) ext.WaitInfos[lock_signal.Id] = node.QueueTimeout("unlock", id, lock_signal, 100*time.Millisecond)
messages = messages.Add(ctx, id, node, nil, lock_signal) messages = append(messages, SendMsg{id, lock_signal})
ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID) ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID)
} }
} }
@ -405,124 +391,32 @@ func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source Nod
return messages, changes return messages, changes
} }
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockableExts process status signals by forwarding them to it's owner
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) ([]SendMsg, Changes) {
var messages Messages = nil var messages []SendMsg = nil
var changes = Changes{} var changes = Changes{}
switch signal.Direction() { switch sig := signal.(type) {
case Up: case *StatusSignal:
if ext.Owner != nil { if ext.Owner != nil {
if *ext.Owner != node.ID { if *ext.Owner != node.ID {
messages = messages.Add(ctx, *ext.Owner, node, nil, signal) messages = append(messages, SendMsg{*ext.Owner, signal})
} }
} }
case *LinkSignal:
case Down: messages, changes = ext.HandleLinkSignal(ctx, node, source, sig)
for requirement := range(ext.Requirements) { case *LockSignal:
messages = messages.Add(ctx, requirement, node, nil, signal) messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
} case *ErrorSignal:
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
case Direct: case *SuccessSignal:
switch sig := signal.(type) { messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
case *LinkSignal: case *TimeoutSignal:
messages, changes = ext.HandleLinkSignal(ctx, node, source, sig) messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig)
case *LockSignal:
messages, changes = ext.HandleLockSignal(ctx, node, source, sig)
case *ErrorSignal:
messages, changes = ext.HandleErrorSignal(ctx, node, source, sig)
case *SuccessSignal:
messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig)
case *TimeoutSignal:
messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig)
default:
}
default: default:
} }
return messages, changes
}
type OwnerOfPolicy struct {
PolicyHeader
Rules Tree `gv:"rules"`
}
func NewOwnerOfPolicy(rules Tree) OwnerOfPolicy {
return OwnerOfPolicy{
PolicyHeader: NewPolicyHeader(),
Rules: rules,
}
}
func (policy OwnerOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func (policy OwnerOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
l_ext, err := GetExt[LockableExt](node)
if err != nil {
ctx.Log.Logf("lockable", "OwnerOfPolicy.Allows called on node without LockableExt")
return nil, Deny
}
if l_ext.Owner == nil { return messages, changes
return nil, Deny
}
if principal_id == *l_ext.Owner {
return nil, Allow
}
return nil, Deny
}
type RequirementOfPolicy struct {
PerNodePolicy
}
func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy {
return RequirementOfPolicy {
PerNodePolicy: NewPerNodePolicy(dep_rules),
}
} }
func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
sig, ok := signal.(*ReadResultSignal)
if ok == false {
return Deny
}
ext, ok := sig.Extensions[ExtTypeFor[LockableExt]()]
if ok == false {
return Deny
}
reqs_ser, ok := ext["requirements"]
if ok == false {
return Deny
}
reqs_type, _, err := DeserializeType(ctx, reqs_ser.TypeStack)
if err != nil {
return Deny
}
reqs_if, _, err := DeserializeValue(ctx, reqs_type, reqs_ser.Data)
if err != nil {
return Deny
}
requirements, ok := reqs_if.Interface().(map[NodeID]ReqState)
if ok == false {
return Deny
}
for req := range(requirements) {
if req == current.Principal {
return policy.NodeRules[sig.NodeID].Allows(current.Action)
}
}
return Deny
}

@ -3,48 +3,24 @@ package graphvent
import ( import (
"testing" "testing"
"time" "time"
"crypto/ed25519"
"crypto/rand"
) )
func lockableTestContext(t *testing.T, logs []string) *Context {
ctx := logTestContext(t, logs)
err := RegisterNodeType(ctx, "Lockable", []ExtType{ExtTypeFor[LockableExt]()}, map[string]FieldIndex{})
fatalErr(t, err)
return ctx
}
func TestLink(t *testing.T) { func TestLink(t *testing.T) {
ctx := lockableTestContext(t, []string{"lockable", "listener"}) ctx := logTestContext(t, []string{"lockable", "listener"})
l1_pub, l1_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
l1_id := KeyID(l1_pub)
policy := NewPerNodePolicy(map[NodeID]Tree{
l1_id: nil,
})
l2_listener := NewListenerExt(10) l2_listener := NewListenerExt(10)
l2, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy}, l2, err := NewNode(ctx, nil, "Lockable", 10, l2_listener, NewLockableExt(nil))
l2_listener,
NewLockableExt(nil),
)
fatalErr(t, err) fatalErr(t, err)
l1_lockable := NewLockableExt(nil) l1_lockable := NewLockableExt(nil)
l1_listener := NewListenerExt(10) l1_listener := NewListenerExt(10)
l1, err := NewNode(ctx, l1_key, "Lockable", 10, nil, l1, err := NewNode(ctx, nil, "Lockable", 10, l1_listener, l1_lockable)
l1_listener,
l1_lockable,
)
fatalErr(t, err) fatalErr(t, err)
msgs := Messages{}
link_signal := NewLinkSignal("add", l2.ID) link_signal := NewLinkSignal("add", l2.ID)
msgs = msgs.Add(ctx, l1.ID, l1, nil, link_signal) msgs := []SendMsg{{l1.ID, link_signal}}
err = ctx.Send(msgs) err = ctx.Send(l1, msgs)
fatalErr(t, err) fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID()) _, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID())
@ -57,10 +33,9 @@ func TestLink(t *testing.T) {
t.Fatalf("l2 in bad requirement state in l1: %+v", state) t.Fatalf("l2 in bad requirement state in l1: %+v", state)
} }
msgs = Messages{}
unlink_signal := NewLinkSignal("remove", l2.ID) unlink_signal := NewLinkSignal("remove", l2.ID)
msgs = msgs.Add(ctx, l1.ID, l1, nil, unlink_signal) msgs = []SendMsg{{l1.ID, unlink_signal}}
err = ctx.Send(msgs) err = ctx.Send(l1, msgs)
fatalErr(t, err) fatalErr(t, err)
_, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID()) _, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, unlink_signal.ID())
@ -68,20 +43,10 @@ func TestLink(t *testing.T) {
} }
func Test1000Lock(t *testing.T) { func Test1000Lock(t *testing.T) {
ctx := lockableTestContext(t, []string{"test", "lockable"}) ctx := logTestContext(t, []string{"test"})
l_pub, listener_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
listener_id := KeyID(l_pub)
child_policy := NewPerNodePolicy(map[NodeID]Tree{
listener_id: {
SerializedType(SignalTypeFor[LockSignal]()): nil,
},
})
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{child_policy}, l, err := NewNode(ctx, nil, "Lockable", 10, NewLockableExt(nil))
NewLockableExt(nil),
)
fatalErr(t, err) fatalErr(t, err)
return l return l
} }
@ -93,15 +58,8 @@ func Test1000Lock(t *testing.T) {
} }
ctx.Log.Logf("test", "CREATED_1000") ctx.Log.Logf("test", "CREATED_1000")
l_policy := NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[LockSignal]()): nil,
})
listener := NewListenerExt(5000) listener := NewListenerExt(5000)
node, err := NewNode(ctx, listener_key, "Lockable", 5000, []Policy{l_policy}, node, err := NewNode(ctx, nil, "Lockable", 5000, listener, NewLockableExt(reqs))
listener,
NewLockableExt(reqs),
)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "CREATED_LISTENER") ctx.Log.Logf("test", "CREATED_LISTENER")
@ -121,16 +79,11 @@ func Test1000Lock(t *testing.T) {
} }
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
ctx := lockableTestContext(t, []string{"test", "lockable"}) ctx := logTestContext(t, []string{"test", "lockable"})
policy := NewAllNodesPolicy(nil)
NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) {
listener := NewListenerExt(1000) listener := NewListenerExt(1000)
l, err := NewNode(ctx, nil, "Lockable", 10, []Policy{policy}, l, err := NewNode(ctx, nil, "Lockable", 10, listener, NewLockableExt(reqs))
listener,
NewLockableExt(reqs),
)
fatalErr(t, err) fatalErr(t, err)
return l, listener return l, listener
} }

@ -1,114 +1,11 @@
package graphvent package graphvent
import ( type SendMsg struct {
"time" Dest NodeID
"crypto/ed25519"
"crypto/rand"
"crypto"
)
type AuthInfo struct {
// The Node that issued the authorization
Identity ed25519.PublicKey
// Time the authorization was generated
Start time.Time
// Signature of Start + Principal with Identity private key
Signature []byte
}
type AuthorizationToken struct {
AuthInfo
// The private key generated by the client, encrypted with the servers public key
KeyEncrypted []byte
}
type ClientAuthorization struct {
AuthInfo
// The private key generated by the client
Key ed25519.PrivateKey
}
// Authorization structs can be passed in a message that originated from a different node than the sender
type Authorization struct {
AuthInfo
// The public key generated for this authorization
Key ed25519.PublicKey
}
type Message struct {
Dest NodeID
Source ed25519.PublicKey
Authorization *Authorization
Signal Signal Signal Signal
Signature []byte
}
type Messages []*Message
func (msgs Messages) Add(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) Messages {
msg, err := NewMessage(ctx, dest, source, authorization, signal)
if err != nil {
panic(err)
} else {
msgs = append(msgs, msg)
}
return msgs
}
func NewMessages(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signals... Signal) Messages {
messages := Messages{}
for _, signal := range(signals) {
messages = messages.Add(ctx, dest, source, authorization, signal)
}
return messages
} }
func NewMessage(ctx *Context, dest NodeID, source *Node, authorization *ClientAuthorization, signal Signal) (*Message, error) { type RecvMsg struct {
signal_ser, err := SerializeAny(ctx, signal) Source NodeID
if err != nil { Signal Signal
return nil, err
}
signal_chunks, err := signal_ser.Chunks()
if err != nil {
return nil, err
}
dest_ser, err := dest.MarshalBinary()
if err != nil {
return nil, err
}
source_ser, err := source.ID.MarshalBinary()
if err != nil {
return nil, err
}
sig_data := append(dest_ser, source_ser...)
sig_data = append(sig_data, signal_chunks.Slice()...)
var message_auth *Authorization = nil
if authorization != nil {
sig_data = append(sig_data, authorization.Signature...)
message_auth = &Authorization{
authorization.AuthInfo,
authorization.Key.Public().(ed25519.PublicKey),
}
}
sig, err := source.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return nil, err
}
return &Message{
Dest: dest,
Source: source.Key.Public().(ed25519.PublicKey),
Authorization: message_auth,
Signal: signal,
Signature: sig,
}, nil
} }

@ -10,7 +10,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
badger "github.com/dgraph-io/badger/v3" _ "github.com/dgraph-io/badger/v3"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -33,6 +33,16 @@ func IDFromBytes(bytes []byte) (NodeID, error) {
return NodeID(id), err return NodeID(id), err
} }
func (id NodeID) MarshalText() ([]byte, error) {
return []byte(id.String()), nil
}
func (id *NodeID) UnmarshalText(text []byte) error {
parsed, err := ParseID(string(text))
*id = parsed
return err
}
// Parse an ID from a string // Parse an ID from a string
func ParseID(str string) (NodeID, error) { func ParseID(str string) (NodeID, error) {
id_uuid, err := uuid.Parse(str) id_uuid, err := uuid.Parse(str)
@ -57,24 +67,6 @@ func (q QueuedSignal) String() string {
return fmt.Sprintf("%+v@%s", reflect.TypeOf(q.Signal), q.Time) return fmt.Sprintf("%+v@%s", reflect.TypeOf(q.Signal), q.Time)
} }
type PendingACL struct {
Counter int
Responses []ResponseSignal
TimeoutID uuid.UUID
Action Tree
Principal NodeID
Signal Signal
Source NodeID
}
type PendingACLSignal struct {
Policy uuid.UUID
Timeout uuid.UUID
ID uuid.UUID
}
// Default message channel size for nodes // Default message channel size for nodes
// Nodes represent a group of extensions that can be collectively addressed // Nodes represent a group of extensions that can be collectively addressed
type Node struct { type Node struct {
@ -84,13 +76,8 @@ type Node struct {
// TODO: move each extension to it's own db key, and extend changes to notify which extension was changed // TODO: move each extension to it's own db key, and extend changes to notify which extension was changed
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
Policies []Policy `gv:"policies"`
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
PendingACLSignals map[uuid.UUID]PendingACLSignal `gv:"pending_signal"`
// Channel for this node to receive messages from the Context // Channel for this node to receive messages from the Context
MsgChan chan *Message MsgChan chan RecvMsg
// Size of MsgChan // Size of MsgChan
BufferSize uint32 `gv:"buffer_size"` BufferSize uint32 `gv:"buffer_size"`
// Channel for this node to process delayed signals // Channel for this node to process delayed signals
@ -110,37 +97,14 @@ func (node *Node) PostDeserialize(ctx *Context) error {
public := node.Key.Public().(ed25519.PublicKey) public := node.Key.Public().(ed25519.PublicKey)
node.ID = KeyID(public) node.ID = KeyID(public)
node.MsgChan = make(chan *Message, node.BufferSize) node.MsgChan = make(chan RecvMsg, node.BufferSize)
return nil return nil
} }
type RuleResult int
const (
Allow RuleResult = iota
Deny
Pending
)
func (node *Node) Allows(ctx *Context, principal_id NodeID, action Tree)(map[uuid.UUID]Messages, RuleResult) {
pends := map[uuid.UUID]Messages{}
for _, policy := range(node.Policies) {
msgs, resp := policy.Allows(ctx, principal_id, action, node)
if resp == Allow {
return nil, Allow
} else if resp == Pending {
pends[policy.ID()] = msgs
}
}
if len(pends) != 0 {
return pends, Pending
}
return nil, Deny
}
type WaitReason string type WaitReason string
type WaitInfo struct { type WaitInfo struct {
Destination NodeID `gv:"destination"` Destination NodeID `gv:"destination" node:"Base"`
Timeout uuid.UUID `gv:"timeout"` Timeout uuid.UUID `gv:"timeout"`
Reason WaitReason `gv:"reason"` Reason WaitReason `gv:"reason"`
} }
@ -250,37 +214,23 @@ func (err StringError) MarshalBinary() ([]byte, error) {
return []byte(string(err)), nil return []byte(string(err)), nil
} }
func NewErrorField(fstring string, args ...interface{}) SerializedValue { func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]any {
str := StringError(fmt.Sprintf(fstring, args...))
str_ser, err := str.MarshalBinary()
if err != nil {
panic(err)
}
return SerializedValue{
TypeStack: []SerializedType{SerializedTypeFor[error]()},
Data: str_ser,
}
}
func (node *Node) ReadFields(ctx *Context, reqs map[ExtType][]string)map[ExtType]map[string]SerializedValue {
ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID) ctx.Log.Logf("read_field", "Reading %+v on %+v", reqs, node.ID)
exts := map[ExtType]map[string]SerializedValue{} exts := map[ExtType]map[string]any{}
for ext_type, field_reqs := range(reqs) { for ext_type, field_reqs := range(reqs) {
fields := map[string]SerializedValue{} ext_info, ext_known := ctx.Extensions[ext_type]
for _, req := range(field_reqs) { if ext_known {
ext, exists := node.Extensions[ext_type] fields := map[string]any{}
if exists == false { for _, req := range(field_reqs) {
fields[req] = NewErrorField("%+v does not have %+v extension", node.ID, ext_type) ext, exists := node.Extensions[ext_type]
} else { if exists == false {
f, err := SerializeField(ctx, ext, req) fields[req] = fmt.Errorf("%+v does not have %+v extension", node.ID, ext_type)
if err != nil {
fields[req] = NewErrorField(err.Error())
} else { } else {
fields[req] = f fields[req] = reflect.ValueOf(ext).Elem().FieldByIndex(ext_info.Fields[req].Index).Interface()
} }
} }
exts[ext_type] = fields
} }
exts[ext_type] = fields
} }
return exts return exts
} }
@ -292,106 +242,14 @@ func nodeLoop(ctx *Context, node *Node) error {
return fmt.Errorf("%s is already started, will not start again", node.ID) return fmt.Errorf("%s is already started, will not start again", node.ID)
} }
// Load each extension before starting the main loop
for _, extension := range(node.Extensions) {
err := extension.Load(ctx, node)
if err != nil {
return err
}
}
run := true run := true
for run == true { for run == true {
var signal Signal var signal Signal
var source NodeID var source NodeID
select { select {
case msg := <- node.MsgChan: case msg := <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal)
signal_ser, err := SerializeAny(ctx, msg.Signal)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, msg.Signal)
}
chunks, err := signal_ser.Chunks()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", err, signal_ser)
continue
}
dst_id_ser, err := msg.Dest.MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_DEST_ID_SER_ERR: %e", err)
continue
}
src_id_ser, err := KeyID(msg.Source).MarshalBinary()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SRC_ID_SER_ERR: %e", err)
continue
}
sig_data := append(dst_id_ser, src_id_ser...)
sig_data = append(sig_data, chunks.Slice()...)
if msg.Authorization != nil {
sig_data = append(sig_data, msg.Authorization.Signature...)
}
validated := ed25519.Verify(msg.Source, sig_data, msg.Signature)
if validated == false {
ctx.Log.Logf("signal_verify", "SIGNAL_VERIFY_ERR: %s - %s", node.ID, reflect.TypeOf(msg.Signal))
continue
}
var princ_id NodeID
if msg.Authorization == nil {
princ_id = KeyID(msg.Source)
} else {
err := ValidateAuthorization(*msg.Authorization, time.Hour)
if err != nil {
ctx.Log.Logf("node", "Authorization validation failed: %s", err)
continue
}
princ_id = KeyID(msg.Authorization.Identity)
}
if princ_id != node.ID {
pends, resp := node.Allows(ctx, princ_id, msg.Signal.Permission())
if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %+v(%+s)", princ_id, node.ID, reflect.TypeOf(msg.Signal), msg.Signal)
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{}
msgs = msgs.Add(ctx, KeyID(msg.Source), node, nil, NewErrorSignal(msg.Signal.ID(), "acl denied"))
ctx.Send(msgs)
continue
} else if resp == Pending {
ctx.Log.Logf("policy", "SIGNAL_POLICY_PENDING: %s->%s - %s - %+v", princ_id, node.ID, msg.Signal.Permission(), pends)
timeout_signal := NewACLTimeoutSignal(msg.Signal.ID())
node.QueueSignal(time.Now().Add(100*time.Millisecond), timeout_signal)
msgs := Messages{}
for policy_type, sigs := range(pends) {
for _, m := range(sigs) {
msgs = append(msgs, m)
timeout_signal := NewTimeoutSignal(m.Signal.ID())
node.QueueSignal(time.Now().Add(time.Second), timeout_signal)
node.PendingACLSignals[m.Signal.ID()] = PendingACLSignal{policy_type, timeout_signal.Id, msg.Signal.ID()}
}
}
node.PendingACLs[msg.Signal.ID()] = PendingACL{
Counter: len(msgs),
TimeoutID: timeout_signal.ID(),
Action: msg.Signal.Permission(),
Principal: princ_id,
Responses: []ResponseSignal{},
Signal: msg.Signal,
Source: KeyID(msg.Source),
}
ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs)
ctx.Send(msgs)
continue
} else if resp == Allow {
ctx.Log.Logf("policy", "SIGNAL_POLICY_ALLOW: %s->%s - %s", princ_id, node.ID, reflect.TypeOf(msg.Signal))
}
} else {
ctx.Log.Logf("policy", "SIGNAL_POLICY_SELF: %s - %s", node.ID, reflect.TypeOf(msg.Signal))
}
signal = msg.Signal signal = msg.Signal
source = KeyID(msg.Source) source = msg.Source
case <-node.TimeoutChan: case <-node.TimeoutChan:
signal = node.NextSignal.Signal signal = node.NextSignal.Signal
@ -425,68 +283,12 @@ func nodeLoop(ctx *Context, node *Node) error {
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
response, ok := signal.(ResponseSignal)
if ok == true {
info, waiting := node.PendingACLSignals[response.ResponseID()]
if waiting == true {
delete(node.PendingACLSignals, response.ResponseID())
ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal)
req_info, exists := node.PendingACLs[info.ID]
if exists == true {
req_info.Counter -= 1
req_info.Responses = append(req_info.Responses, response)
idx := -1
for i, p := range(node.Policies) {
if p.ID() == info.Policy {
idx = i
break
}
}
if idx == -1 {
ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy)
delete(node.PendingACLs, info.ID)
} else {
allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal)
if allowed == Allow {
ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal)
signal = req_info.Signal
source = req_info.Source
err := node.DequeueSignal(req_info.TimeoutID)
if err != nil {
ctx.Log.Logf("node", "dequeue error: %s", err)
}
delete(node.PendingACLs, info.ID)
} else if req_info.Counter == 0 {
ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal)
// Send the denied response
msgs := Messages{}
msgs = msgs.Add(ctx, req_info.Source, node, nil, NewErrorSignal(req_info.Signal.ID(), "acl_denied"))
err := ctx.Send(msgs)
if err != nil {
ctx.Log.Logf("signal", "SEND_ERR: %s", err)
}
err = node.DequeueSignal(req_info.TimeoutID)
if err != nil {
ctx.Log.Logf("node", "ACL_DEQUEUE_ERROR: timeout signal not in queue when trying to clear after counter hit 0 %s, %s - %s", err, signal.ID(), req_info.TimeoutID)
}
delete(node.PendingACLs, info.ID)
} else {
node.PendingACLs[info.ID] = req_info
continue
}
}
}
}
}
switch sig := signal.(type) { switch sig := signal.(type) {
case *ReadSignal: case *ReadSignal:
result := node.ReadFields(ctx, sig.Extensions) result := node.ReadFields(ctx, sig.Extensions)
msgs := Messages{} msgs := []SendMsg{}
msgs = msgs.Add(ctx, source, node, nil, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)) msgs = append(msgs, SendMsg{source, NewReadResultSignal(sig.ID(), node.ID, node.Type, result)})
ctx.Send(msgs) ctx.Send(node, msgs)
default: default:
err := node.Process(ctx, source, signal) err := node.Process(ctx, source, signal)
@ -522,7 +324,7 @@ func (node *Node) QueueChanges(ctx *Context, changes map[ExtType]Changes) error
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal) ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal)
messages := Messages{} messages := []SendMsg{}
changes := map[ExtType]Changes{} changes := map[ExtType]Changes{}
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
@ -537,7 +339,7 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("changes", "Changes for %s after %+v - %+v", node.ID, reflect.TypeOf(signal), changes) ctx.Log.Logf("changes", "Changes for %s after %+v - %+v", node.ID, reflect.TypeOf(signal), changes)
if len(messages) != 0 { if len(messages) != 0 {
send_err := ctx.Send(messages) send_err := ctx.Send(node, messages)
if send_err != nil { if send_err != nil {
return send_err return send_err
} }
@ -596,7 +398,7 @@ func KeyID(pub ed25519.PublicKey) NodeID {
} }
// Create a new node in memory and start it's event loop // Create a new node in memory and start it's event loop
func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, policies []Policy, extensions ...Extension) (*Node, error) { func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size uint32, extensions ...Extension) (*Node, error) {
node_type, known_type := ctx.NodeTypes[type_name] node_type, known_type := ctx.NodeTypes[type_name]
if known_type == false { if known_type == false {
return nil, fmt.Errorf("%s is not a known node type", type_name) return nil, fmt.Errorf("%s is not a known node type", type_name)
@ -618,55 +420,48 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, type_name string, buffer_size
return nil, fmt.Errorf("Attempted to create an existing node") return nil, fmt.Errorf("Attempted to create an existing node")
} }
def, exists := ctx.Nodes[node_type]
if exists == false {
return nil, fmt.Errorf("Node type %+v not registered in Context", node_type)
}
ext_map := map[ExtType]Extension{} ext_map := map[ExtType]Extension{}
for _, ext := range(extensions) { for _, ext := range(extensions) {
ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext)] ext_type, exists := ctx.ExtensionTypes[reflect.TypeOf(ext).Elem()]
if exists == false { if exists == false {
return nil, fmt.Errorf(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext))) return nil, fmt.Errorf(fmt.Sprintf("%+v is not a known Extension", reflect.TypeOf(ext)))
} }
_, exists = ext_map[ext_type] _, exists = ext_map[ext_type.ExtType]
if exists == true { if exists == true {
return nil, fmt.Errorf("Cannot add the same extension to a node twice") return nil, fmt.Errorf("Cannot add the same extension to a node twice")
} }
ext_map[ext_type] = ext ext_map[ext_type.ExtType] = ext
} }
for _, required_ext := range(def.Extensions) { for _, required_ext := range(node_type.Extensions) {
_, exists := ext_map[required_ext] _, exists := ext_map[required_ext]
if exists == false { if exists == false {
return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext)) return nil, fmt.Errorf(fmt.Sprintf("%+v requires %+v", node_type, required_ext))
} }
} }
policies = append(policies, DefaultPolicy)
node := &Node{ node := &Node{
Key: key, Key: key,
ID: id, ID: id,
Type: node_type, Type: node_type.NodeType,
Extensions: ext_map, Extensions: ext_map,
Policies: policies, MsgChan: make(chan RecvMsg, buffer_size),
PendingACLs: map[uuid.UUID]PendingACL{},
PendingACLSignals: map[uuid.UUID]PendingACLSignal{},
MsgChan: make(chan *Message, buffer_size),
BufferSize: buffer_size, BufferSize: buffer_size,
SignalQueue: []QueuedSignal{}, SignalQueue: []QueuedSignal{},
writeSignalQueue: false,
} }
err = WriteNodeExtList(ctx, node) err = WriteNodeInit(ctx, node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node.writeSignalQueue = true // Load each extension before starting the main loop
err = WriteNodeInit(ctx, node) for _, extension := range(node.Extensions) {
if err != nil { err := extension.Load(ctx, node)
return nil, err if err != nil {
return nil, err
}
} }
ctx.AddNode(id, node) ctx.AddNode(id, node)
@ -683,258 +478,3 @@ func ExtTypeSuffix(ext_type ExtType) []byte {
binary.BigEndian.PutUint64(ret[4:], uint64(ext_type)) binary.BigEndian.PutUint64(ret[4:], uint64(ext_type))
return ret return ret
} }
func WriteNodeExtList(ctx *Context, node *Node) error {
ext_list := make([]ExtType, len(node.Extensions))
i := 0
for ext_type := range(node.Extensions) {
ext_list[i] = ext_type
i += 1
}
ctx.Log.Logf("db", "Writing ext_list for %s - %+v", node.ID, ext_list)
id_bytes, err := node.ID.MarshalBinary()
if err != nil {
return err
}
ext_list_serialized, err := SerializeAny(ctx, ext_list)
if err != nil {
return err
}
return ctx.DB.Update(func(txn *badger.Txn) error {
return txn.Set(append(id_bytes, extension_suffix...), ext_list_serialized.Data)
})
}
func WriteNodeInit(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "Writing initial entry for %s - %+v", node.ID, node)
ext_serialized := map[ExtType]SerializedValue{}
for ext_type, ext := range(node.Extensions) {
serialized_ext, err := SerializeAny(ctx, ext)
if err != nil {
return err
}
ext_serialized[ext_type] = serialized_ext
}
sq_serialized, err := SerializeAny(ctx, node.SignalQueue)
if err != nil {
return err
}
node_serialized, err := SerializeAny(ctx, node)
if err != nil {
return err
}
id_bytes, err := node.ID.MarshalBinary()
return ctx.DB.Update(func(txn *badger.Txn) error {
err := txn.Set(id_bytes, node_serialized.Data)
if err != nil {
return nil
}
err = txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data)
if err != nil {
return err
}
for ext_type, data := range(ext_serialized) {
err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data)
if err != nil {
return err
}
}
return nil
})
}
func WriteNodeChanges(ctx *Context, node *Node, changes map[ExtType]Changes) error {
ctx.Log.Logf("db", "Writing changes for %s - %+v", node.ID, changes)
ext_serialized := map[ExtType]SerializedValue{}
for ext_type := range(changes) {
ext, ext_exists := node.Extensions[ext_type]
if ext_exists == false {
ctx.Log.Logf("db", "extension 0x%x does not exist for %s", ext_type, node.ID)
} else {
serialized_ext, err := SerializeAny(ctx, ext)
if err != nil {
return err
}
ext_serialized[ext_type] = serialized_ext
}
}
var sq_serialized *SerializedValue = nil
if node.writeSignalQueue == true {
node.writeSignalQueue = false
ser, err := SerializeAny(ctx, node.SignalQueue)
if err != nil {
return err
}
sq_serialized = &ser
}
node_serialized, err := SerializeAny(ctx, node)
if err != nil {
return err
}
id_bytes, err := node.ID.MarshalBinary()
return ctx.DB.Update(func(txn *badger.Txn) error {
err := txn.Set(id_bytes, node_serialized.Data)
if err != nil {
return err
}
if sq_serialized != nil {
err := txn.Set(append(id_bytes, signal_queue_suffix...), sq_serialized.Data)
if err != nil {
return err
}
}
for ext_type, data := range(ext_serialized) {
err := txn.Set(append(id_bytes, ExtTypeSuffix(ext_type)...), data.Data)
if err != nil {
return err
}
}
return nil
})
}
func LoadNode(ctx *Context, id NodeID) (*Node, error) {
ctx.Log.Logf("db", "LOADING_NODE: %s", id)
var node_bytes []byte = nil
var sq_bytes []byte = nil
var ext_bytes = map[ExtType][]byte{}
err := ctx.DB.View(func(txn *badger.Txn) error {
id_bytes, err := id.MarshalBinary()
if err != nil {
return err
}
node_item, err := txn.Get(id_bytes)
if err != nil {
ctx.Log.Logf("db", "node key not found")
return err
}
node_bytes, err = node_item.ValueCopy(nil)
if err != nil {
return err
}
sq_item, err := txn.Get(append(id_bytes, signal_queue_suffix...))
if err != nil {
ctx.Log.Logf("db", "sq key not found")
return err
}
sq_bytes, err = sq_item.ValueCopy(nil)
if err != nil {
return err
}
ext_list_item, err := txn.Get(append(id_bytes, extension_suffix...))
if err != nil {
ctx.Log.Logf("db", "ext_list key not found")
return err
}
ext_list_bytes, err := ext_list_item.ValueCopy(nil)
if err != nil {
return err
}
ext_list_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]ExtType{}), ext_list_bytes)
if err != nil {
return err
} else if len(remaining) > 0 {
return fmt.Errorf("Data remaining after ext_list deserialize %d", len(remaining))
}
ext_list, ok := ext_list_value.Interface().([]ExtType)
if ok == false {
return fmt.Errorf("deserialize returned wrong type %s", ext_list_value.Type())
}
for _, ext_type := range(ext_list) {
ext_item, err := txn.Get(append(id_bytes, ExtTypeSuffix(ext_type)...))
if err != nil {
ctx.Log.Logf("db", "ext %s key not found", ext_type)
return err
}
ext_bytes[ext_type], err = ext_item.ValueCopy(nil)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
node_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf((*Node)(nil)), node_bytes)
if err != nil {
return nil, err
} else if len(remaining) != 0 {
return nil, fmt.Errorf("data left after deserializing node %d", len(remaining))
}
node, node_ok := node_value.Interface().(*Node)
if node_ok == false {
return nil, fmt.Errorf("node wrong type %s", node_value.Type())
}
ctx.Log.Logf("db", "Deserialized node bytes %+v", node)
signal_queue_value, remaining, err := DeserializeValue(ctx, reflect.TypeOf([]QueuedSignal{}), sq_bytes)
if err != nil {
return nil, err
} else if len(remaining) != 0 {
return nil, fmt.Errorf("data left after deserializing signal_queue %d", len(remaining))
}
signal_queue, sq_ok := signal_queue_value.Interface().([]QueuedSignal)
if sq_ok == false {
return nil, fmt.Errorf("signal queue wrong type %s", signal_queue_value.Type())
}
for ext_type, data := range(ext_bytes) {
ext_info, exists := ctx.Extensions[ext_type]
if exists == false {
return nil, fmt.Errorf("0x%0x is not a known extension type", ext_type)
}
ext_value, remaining, err := DeserializeValue(ctx, ext_info.Reflect, data)
if err != nil {
return nil, err
} else if len(remaining) > 0 {
return nil, fmt.Errorf("data left after deserializing ext(0x%x) %d", ext_type, len(remaining))
}
ext, ext_ok := ext_value.Interface().(Extension)
if ext_ok == false {
return nil, fmt.Errorf("extension wrong type %s", ext_value.Type())
}
node.Extensions[ext_type] = ext
}
node.SignalQueue = signal_queue
node.NextSignal, node.TimeoutChan = SoonestSignal(signal_queue)
ctx.AddNode(id, node)
ctx.Log.Logf("db", "loaded %+v", node)
go runNode(ctx, node)
return node, nil
}

@ -12,7 +12,7 @@ func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{"node", "db"}) ctx := logTestContext(t, []string{"node", "db"})
node_listener := NewListenerExt(10) node_listener := NewListenerExt(10)
node, err := NewNode(ctx, nil, "Base", 10, nil, NewGroupExt(nil), NewLockableExt(nil), node_listener) node, err := NewNode(ctx, nil, "Base", 10, NewLockableExt(nil), node_listener)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool {
@ -45,25 +45,18 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N1: %s", n1_id)
ctx.Log.Logf("test", "N2: %s", n2_id) ctx.Log.Logf("test", "N2: %s", n2_id)
n1_policy := NewPerNodePolicy(map[NodeID]Tree{
n2_id: {
SerializedType(SignalTypeFor[ReadSignal]()): nil,
},
})
n2_listener := NewListenerExt(10) n2_listener := NewListenerExt(10)
n2, err := NewNode(ctx, n2_key, "Base", 10, nil, NewGroupExt(nil), n2_listener) n2, err := NewNode(ctx, n2_key, "Base", 10, n2_listener)
fatalErr(t, err) fatalErr(t, err)
n1, err := NewNode(ctx, n1_key, "Base", 10, []Policy{n1_policy}, NewGroupExt(nil)) n1, err := NewNode(ctx, n1_key, "Base", 10, NewListenerExt(10))
fatalErr(t, err) fatalErr(t, err)
read_sig := NewReadSignal(map[ExtType][]string{ read_sig := NewReadSignal(map[ExtType][]string{
ExtTypeFor[GroupExt](): {"members"}, ExtTypeFor[ListenerExt](): {"buffer"},
}) })
msgs := Messages{} msgs := []SendMsg{{n1.ID, read_sig}}
msgs = msgs.Add(ctx, n1.ID, n2, nil, read_sig) err = ctx.Send(n2, msgs)
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool { res, err := WaitForSignal(n2_listener.Chan, 10*time.Millisecond, func(sig *ReadResultSignal) bool {

@ -1,139 +0,0 @@
package graphvent
import (
"github.com/google/uuid"
)
type Policy interface {
Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult)
ContinueAllows(ctx *Context, current PendingACL, signal Signal)RuleResult
ID() uuid.UUID
}
type PolicyHeader struct {
UUID uuid.UUID `gv:"uuid"`
}
func (header PolicyHeader) ID() uuid.UUID {
return header.UUID
}
func (policy AllNodesPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
return nil, policy.Rules.Allows(action)
}
func (policy AllNodesPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func (policy PerNodePolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) {
for id, actions := range(policy.NodeRules) {
if id != principal_id {
continue
}
return nil, actions.Allows(action)
}
return nil, Deny
}
func (policy PerNodePolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult {
return Deny
}
func CopyTree(tree Tree) Tree {
if tree == nil {
return nil
}
ret := Tree{}
for name, sub := range(tree) {
ret[name] = CopyTree(sub)
}
return ret
}
func MergeTrees(first Tree, second Tree) Tree {
if first == nil || second == nil {
return nil
}
ret := CopyTree(first)
for name, sub := range(second) {
current, exists := ret[name]
if exists == true {
ret[name] = MergeTrees(current, sub)
} else {
ret[name] = CopyTree(sub)
}
}
return ret
}
type Tree map[SerializedType]Tree
func (rule Tree) Allows(action Tree) RuleResult {
// If the current rule is nil, it's a wildcard and any action being processed is allowed
if rule == nil {
return Allow
// If the rule isn't "allow all" but the action is "request all", deny
} else if action == nil {
return Deny
// If the current action has no children, it's allowed
} else if len(action) == 0 {
return Allow
// If the current rule has no children but the action goes further, it's not allowed
} else if len(rule) == 0 {
return Deny
// If the current rule and action have children, all the children of action must be allowed by rule
} else {
for sub, subtree := range(action) {
subrule, exists := rule[sub]
if exists == false {
return Deny
} else if subrule.Allows(subtree) == Deny {
return Deny
}
}
return Allow
}
}
func NewPolicyHeader() PolicyHeader {
return PolicyHeader{
UUID: uuid.New(),
}
}
func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy {
if node_actions == nil {
node_actions = map[NodeID]Tree{}
}
return PerNodePolicy{
PolicyHeader: NewPolicyHeader(),
NodeRules: node_actions,
}
}
type PerNodePolicy struct {
PolicyHeader
NodeRules map[NodeID]Tree `gv:"node_rules"`
}
func NewAllNodesPolicy(rules Tree) AllNodesPolicy {
return AllNodesPolicy{
PolicyHeader: NewPolicyHeader(),
Rules: rules,
}
}
type AllNodesPolicy struct {
PolicyHeader
Rules Tree `gv:"rules"`
}
var DefaultPolicy = NewAllNodesPolicy(Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): nil,
SerializedType(SignalTypeFor[StatusSignal]()): nil,
})

File diff suppressed because it is too large Load Diff

@ -1,247 +1,150 @@
package graphvent package graphvent
import ( import (
"fmt" "testing"
"reflect" "reflect"
"testing" "github.com/google/uuid"
"time"
) )
func TestSerializeTest(t *testing.T) { func testTypeStack[T any](t *testing.T, ctx *Context) {
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"}) reflect_type := reflect.TypeFor[T]()
testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}}) stack, err := TypeStack(ctx, reflect_type)
testSerialize(t, ctx, map[NodeID]ReqState{ fatalErr(t, err)
RandID(): Locked,
RandID(): Unlocked,
})
}
func TestSerializeBasic(t *testing.T) { ctx.Log.Logf("test", "TypeStack[%s]: %+v", reflect_type, stack)
ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"})
testSerializeComparable[bool](t, ctx, true)
type bool_wrapped bool unwrapped_type, rest, err := UnwrapStack(ctx, stack)
err := RegisterType[bool_wrapped](ctx, nil, nil, nil, DeserializeBool[bool_wrapped])
fatalErr(t, err) fatalErr(t, err)
testSerializeComparable[bool_wrapped](t, ctx, true)
testSerializeSlice[[]bool](t, ctx, []bool{false, false, true, false})
testSerializeComparable[string](t, ctx, "test")
testSerializeComparable[float32](t, ctx, 0.05)
testSerializeComparable[float64](t, ctx, 0.05)
testSerializeComparable[uint](t, ctx, uint(1234))
testSerializeComparable[uint8] (t, ctx, uint8(123))
testSerializeComparable[uint16](t, ctx, uint16(1234))
testSerializeComparable[uint32](t, ctx, uint32(12345))
testSerializeComparable[uint64](t, ctx, uint64(123456))
testSerializeComparable[int](t, ctx, 1234)
testSerializeComparable[int8] (t, ctx, int8(-123))
testSerializeComparable[int16](t, ctx, int16(-1234))
testSerializeComparable[int32](t, ctx, int32(-12345))
testSerializeComparable[int64](t, ctx, int64(-123456))
testSerializeComparable[time.Duration](t, ctx, time.Duration(100))
testSerializeComparable[time.Time](t, ctx, time.Now().Truncate(0))
testSerializeSlice[[]int](t, ctx, []int{123, 456, 789, 101112})
testSerializeSlice[[]int](t, ctx, ([]int)(nil))
testSerializeSliceSlice[[][]int](t, ctx, [][]int{{123, 456, 789, 101112}, {3253, 2341, 735, 212}, {123, 51}, nil})
testSerializeSliceSlice[[][]string](t, ctx, [][]string{{"123", "456", "789", "101112"}, {"3253", "2341", "735", "212"}, {"123", "51"}, nil})
testSerialize(t, ctx, map[int8]map[*int8]string{})
testSerialize(t, ctx, map[int8]time.Time{
1: time.Now(),
3: time.Now().Add(time.Second),
0: time.Now().Add(time.Second*2),
4: time.Now().Add(time.Second*3),
})
testSerialize(t, ctx, Tree{
SerializedTypeFor[NodeType](): nil,
SerializedTypeFor[SerializedType](): {
SerializedTypeFor[NodeType](): Tree{},
},
})
var i interface{} = nil
testSerialize(t, ctx, i)
testSerializeMap(t, ctx, map[int8]interface{}{
0: "abcd",
1: uint32(12345678),
2: i,
3: 123,
})
testSerializeMap(t, ctx, map[int8]int32{
0: 1234,
2: 5678,
4: 9101,
6: 1121,
})
type test_struct struct {
Int int `gv:"int"`
String string `gv:"string"`
}
err = RegisterStruct[test_struct](ctx) if len(rest) != 0 {
fatalErr(t, err) t.Errorf("Types remaining after unwrapping %s: %+v", unwrapped_type, stack)
}
testSerialize(t, ctx, test_struct{ if unwrapped_type != reflect_type {
12345, t.Errorf("Unwrapped type[%+v] doesn't match original[%+v]", unwrapped_type, reflect_type)
"test_string", }
})
testSerialize(t, ctx, Tree{
SerializedKindFor(reflect.Map): nil,
SerializedKindFor(reflect.String): nil,
})
testSerialize(t, ctx, Tree{
SerializedTypeFor[Tree](): nil,
})
testSerialize(t, ctx, Tree{
SerializedTypeFor[Tree](): {
SerializedTypeFor[error](): Tree{},
SerializedKindFor(reflect.Map): nil,
},
SerializedKindFor(reflect.String): nil,
})
type test_slice []string
err = RegisterType[test_slice](ctx, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[test_slice], DeserializeSlice)
fatalErr(t, err)
testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"}) ctx.Log.Logf("test", "Unwrapped type[%s]: %s", reflect_type, reflect_type)
testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"})
} }
type test struct { func TestSerializeTypes(t *testing.T) {
Int int `gv:"int"` ctx := logTestContext(t, []string{"test"})
Str string `gv:"string"`
}
func (s test) String() string { testTypeStack[int](t, ctx)
return fmt.Sprintf("%d:%s", s.Int, s.Str) testTypeStack[map[int]string](t, ctx)
testTypeStack[string](t, ctx)
testTypeStack[*string](t, ctx)
testTypeStack[*map[string]*map[*string]int](t, ctx)
testTypeStack[[5]int](t, ctx)
testTypeStack[uuid.UUID](t, ctx)
testTypeStack[NodeID](t, ctx)
} }
func TestSerializeStructTags(t *testing.T) { func testSerializeCompare[T comparable](t *testing.T, ctx *Context, value T) {
ctx := logTestContext(t, []string{"test"}) serialized, err := Serialize(ctx, value)
fatalErr(t, err)
ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
err := RegisterStruct[test](ctx) deserialized, err := Deserialize[T](ctx, serialized)
fatalErr(t, err) fatalErr(t, err)
test_int := 10 if value != deserialized {
test_string := "test" t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
ret := testSerialize(t, ctx, test{
test_int,
test_string,
})
if ret.Int != test_int {
t.Fatalf("Deserialized int %d does not equal test %d", ret.Int, test_int)
} else if ret.Str != test_string {
t.Fatalf("Deserialized string %s does not equal test %s", ret.Str, test_string)
} }
testSerialize(t, ctx, []test{ ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
{
test_int,
test_string,
},
{
test_int * 2,
fmt.Sprintf("%s%s", test_string, test_string),
},
{
test_int * 4,
fmt.Sprintf("%s%s%s", test_string, test_string, test_string),
},
})
} }
func testSerializeMap[M map[T]R, T, R comparable](t *testing.T, ctx *Context, val M) { func testSerializeList[L []T, T comparable](t *testing.T, ctx *Context, value L) {
v := testSerialize(t, ctx, val) serialized, err := Serialize(ctx, value)
for key, value := range(val) { fatalErr(t, err)
recreated, exists := v[key]
if exists == false {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
} else if recreated != value {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
}
if len(v) != len(val) {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
}
func testSerializeSliceSlice[S [][]T, T comparable](t *testing.T, ctx *Context, val S) { ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[L](), value, serialized)
v := testSerialize(t, ctx, val)
for i, original := range(val) { deserialized, err := Deserialize[L](ctx, serialized)
if (original == nil && v[i] != nil) || (original != nil && v[i] == nil) { fatalErr(t, err)
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
for j, o := range(original) {
if v[i][j] != o {
t.Fatalf("DeserializeValue returned wrong value %+v != %+v", v, val)
}
}
}
}
func testSerializeSlice[S []T, T comparable](t *testing.T, ctx *Context, val S) { for i, item := range(value) {
v := testSerialize(t, ctx, val) if item != deserialized[i] {
for i, original := range(val) { t.Fatalf("Deserialized list %+v does not match original %+v", value, deserialized)
if v[i] != original {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
} }
} }
}
func testSerializeComparable[T comparable](t *testing.T, ctx *Context, val T) { ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
v := testSerialize(t, ctx, val)
if v != val {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong value %+v != %+v", v, val))
}
} }
func testSerialize[T any](t *testing.T, ctx *Context, val T) T { func testSerializePointer[P interface {*T}, T comparable](t *testing.T, ctx *Context, value P) {
value := reflect.ValueOf(&val).Elem() serialized, err := Serialize(ctx, value)
type_stack, err := SerializeType(ctx, value.Type())
chunks, err := SerializeValue(ctx, value)
value_serialized := SerializedValue{type_stack, chunks.Slice()}
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "Serialized %+v to %+v(%d)", val, value_serialized, len(value_serialized.Data))
value_chunks, err := value_serialized.Chunks() ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[P](), value, serialized)
fatalErr(t, err)
ctx.Log.Logf("test", "Binary: %+v", value_chunks.Slice())
val_parsed, remaining_parse, err := ParseSerializedValue(value_chunks.Slice()) deserialized, err := Deserialize[P](ctx, serialized)
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "Parsed: %+v", val_parsed)
if len(remaining_parse) != 0 { if value == nil && deserialized == nil {
t.Fatal("Data remaining after deserializing value") ctx.Log.Logf("test", "Deserialized nil")
} else if value == nil {
t.Fatalf("Non-nil value[%+v] returned for nil value", deserialized)
} else if deserialized == nil {
t.Fatalf("Nil value returned for non-nil value[%+v]", value)
} else if *deserialized != *value {
t.Fatalf("Deserialized value[%+v] doesn't match original[%+v]", value, deserialized)
} else {
ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", *value, *deserialized)
} }
}
val_type, remaining_types, err := DeserializeType(ctx, val_parsed.TypeStack) func testSerialize[T any](t *testing.T, ctx *Context, value T) {
deserialized_value, remaining_deserialize, err := DeserializeValue(ctx, val_type, val_parsed.Data) serialized, err := Serialize(ctx, value)
fatalErr(t, err) fatalErr(t, err)
if len(remaining_deserialize) != 0 { ctx.Log.Logf("test", "Serialized Value[%s : %+v]: %+v", reflect.TypeFor[T](), value, serialized)
t.Fatal("Data remaining after deserializing value")
} else if len(remaining_types) != 0 { deserialized, err := Deserialize[T](ctx, serialized)
t.Fatal("TypeStack remaining after deserializing value") fatalErr(t, err)
} else if val_type != value.Type() {
t.Fatal(fmt.Sprintf("DeserializeValue returned wrong reflect.Type %+v - %+v", val_type, reflect.TypeOf(val))) ctx.Log.Logf("test", "Deserialized Value[%+v]: %+v", value, deserialized)
} else if deserialized_value.CanConvert(val_type) == false { }
t.Fatal("DeserializeValue returned value that can't convert to original value")
} func TestSerializeValues(t *testing.T) {
ctx.Log.Logf("test", "Value: %+v", deserialized_value.Interface()) ctx := logTestContext(t, []string{"test"})
if val_type.Kind() == reflect.Interface && deserialized_value.Interface() == nil {
var zero T testSerialize(t, ctx, Extension(NewLockableExt(nil)))
return zero
} testSerializeCompare[int8](t, ctx, -64)
return deserialized_value.Interface().(T) testSerializeCompare[int16](t, ctx, -64)
testSerializeCompare[int32](t, ctx, -64)
testSerializeCompare[int64](t, ctx, -64)
testSerializeCompare[int](t, ctx, -64)
testSerializeCompare[uint8](t, ctx, 64)
testSerializeCompare[uint16](t, ctx, 64)
testSerializeCompare[uint32](t, ctx, 64)
testSerializeCompare[uint64](t, ctx, 64)
testSerializeCompare[uint](t, ctx, 64)
testSerializeCompare[string](t, ctx, "test")
a := 12
testSerializePointer[*int](t, ctx, &a)
b := "test"
testSerializePointer[*string](t, ctx, nil)
testSerializePointer[*string](t, ctx, &b)
testSerializeList(t, ctx, []int{1, 2, 3, 4, 5})
testSerializeCompare[bool](t, ctx, true)
testSerializeCompare[bool](t, ctx, false)
testSerializeCompare[int](t, ctx, -1)
testSerializeCompare[uint](t, ctx, 1)
testSerializeCompare[NodeID](t, ctx, RandID())
testSerializeCompare[*int](t, ctx, nil)
testSerializeCompare(t, ctx, "string")
node, err := NewNode(ctx, nil, "Base", 100)
fatalErr(t, err)
testSerialize(t, ctx, node)
} }

@ -7,20 +7,13 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
type SignalDirection uint8
const (
Up SignalDirection = iota
Down
Direct
)
type TimeoutSignal struct { type TimeoutSignal struct {
ResponseHeader ResponseHeader
} }
func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal { func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal {
return &TimeoutSignal{ return &TimeoutSignal{
NewResponseHeader(req_id, Direct), NewResponseHeader(req_id),
} }
} }
@ -28,26 +21,23 @@ func (signal TimeoutSignal) String() string {
return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader) return fmt.Sprintf("TimeoutSignal(%s)", &signal.ResponseHeader)
} }
// Timeouts are internal only, no permission allows sending them type SignalDirection int
func (signal TimeoutSignal) Permission() Tree { const (
return nil Up SignalDirection = iota
} Down
Direct
)
type SignalHeader struct { type SignalHeader struct {
Id uuid.UUID `gv:"id"` Id uuid.UUID `gv:"id"`
Dir SignalDirection `gv:"direction"`
} }
func (signal SignalHeader) ID() uuid.UUID { func (signal SignalHeader) ID() uuid.UUID {
return signal.Id return signal.Id
} }
func (signal SignalHeader) Direction() SignalDirection {
return signal.Dir
}
func (header SignalHeader) String() string { func (header SignalHeader) String() string {
return fmt.Sprintf("SignalHeader(%d, %s)", header.Dir, header.Id) return fmt.Sprintf("SignalHeader(%s)", header.Id)
} }
type ResponseSignal interface { type ResponseSignal interface {
@ -65,14 +55,12 @@ func (header ResponseHeader) ResponseID() uuid.UUID {
} }
func (header ResponseHeader) String() string { func (header ResponseHeader) String() string {
return fmt.Sprintf("ResponseHeader(%d, %s->%s)", header.Dir, header.Id, header.ReqID) return fmt.Sprintf("ResponseHeader(%s, %s)", header.Id, header.ReqID)
} }
type Signal interface { type Signal interface {
fmt.Stringer fmt.Stringer
ID() uuid.UUID ID() uuid.UUID
Direction() SignalDirection
Permission() Tree
} }
func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, []Signal, error) { func WaitForResponse(listener chan Signal, timeout time.Duration, req_id uuid.UUID) (ResponseSignal, []Signal, error) {
@ -129,16 +117,15 @@ func WaitForSignal[S Signal](listener chan Signal, timeout time.Duration, check
return zero, fmt.Errorf("LOOP_ENDED") return zero, fmt.Errorf("LOOP_ENDED")
} }
func NewSignalHeader(direction SignalDirection) SignalHeader { func NewSignalHeader() SignalHeader {
return SignalHeader{ return SignalHeader{
uuid.New(), uuid.New(),
direction,
} }
} }
func NewResponseHeader(req_id uuid.UUID, direction SignalDirection) ResponseHeader { func NewResponseHeader(req_id uuid.UUID) ResponseHeader {
return ResponseHeader{ return ResponseHeader{
NewSignalHeader(direction), NewSignalHeader(),
req_id, req_id,
} }
} }
@ -151,16 +138,9 @@ func (signal SuccessSignal) String() string {
return fmt.Sprintf("SuccessSignal(%s)", signal.ResponseHeader) return fmt.Sprintf("SuccessSignal(%s)", signal.ResponseHeader)
} }
func (signal SuccessSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): {
SerializedType(SignalTypeFor[SuccessSignal]()): nil,
},
}
}
func NewSuccessSignal(req_id uuid.UUID) *SuccessSignal { func NewSuccessSignal(req_id uuid.UUID) *SuccessSignal {
return &SuccessSignal{ return &SuccessSignal{
NewResponseHeader(req_id, Direct), NewResponseHeader(req_id),
} }
} }
@ -171,16 +151,9 @@ type ErrorSignal struct {
func (signal ErrorSignal) String() string { func (signal ErrorSignal) String() string {
return fmt.Sprintf("ErrorSignal(%s, %s)", signal.ResponseHeader, signal.Error) return fmt.Sprintf("ErrorSignal(%s, %s)", signal.ResponseHeader, signal.Error)
} }
func (signal ErrorSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): {
SerializedType(SignalTypeFor[ErrorSignal]()): nil,
},
}
}
func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal { func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal {
return &ErrorSignal{ return &ErrorSignal{
NewResponseHeader(req_id, Direct), NewResponseHeader(req_id),
fmt.Sprintf(fmt_string, args...), fmt.Sprintf(fmt_string, args...),
} }
} }
@ -188,14 +161,9 @@ func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *E
type ACLTimeoutSignal struct { type ACLTimeoutSignal struct {
ResponseHeader ResponseHeader
} }
func (signal ACLTimeoutSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[ACLTimeoutSignal]()): nil,
}
}
func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal { func NewACLTimeoutSignal(req_id uuid.UUID) *ACLTimeoutSignal {
sig := &ACLTimeoutSignal{ sig := &ACLTimeoutSignal{
NewResponseHeader(req_id, Direct), NewResponseHeader(req_id),
} }
return sig return sig
} }
@ -205,17 +173,12 @@ type StatusSignal struct {
Source NodeID `gv:"source"` Source NodeID `gv:"source"`
Changes map[ExtType]Changes `gv:"changes"` Changes map[ExtType]Changes `gv:"changes"`
} }
func (signal StatusSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[StatusSignal]()): nil,
}
}
func (signal StatusSignal) String() string { func (signal StatusSignal) String() string {
return fmt.Sprintf("StatusSignal(%s, %+v)", signal.SignalHeader, signal.Changes) return fmt.Sprintf("StatusSignal(%s, %+v)", signal.SignalHeader, signal.Changes)
} }
func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal { func NewStatusSignal(source NodeID, changes map[ExtType]Changes) *StatusSignal {
return &StatusSignal{ return &StatusSignal{
NewSignalHeader(Up), NewSignalHeader(),
source, source,
changes, changes,
} }
@ -232,17 +195,9 @@ const (
LinkActionAdd = "ADD" LinkActionAdd = "ADD"
) )
func (signal LinkSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[LinkSignal]()): Tree{
Hash(LinkActionBase, signal.Action): nil,
},
}
}
func NewLinkSignal(action string, id NodeID) Signal { func NewLinkSignal(action string, id NodeID) Signal {
return &LinkSignal{ return &LinkSignal{
NewSignalHeader(Direct), NewSignalHeader(),
id, id,
action, action,
} }
@ -256,21 +211,9 @@ func (signal LockSignal) String() string {
return fmt.Sprintf("LockSignal(%s, %s)", signal.SignalHeader, signal.State) return fmt.Sprintf("LockSignal(%s, %s)", signal.SignalHeader, signal.State)
} }
const (
LockStateBase = "LOCK_STATE"
)
func (signal LockSignal) Permission() Tree {
return Tree{
SerializedType(SignalTypeFor[LockSignal]()): Tree{
Hash(LockStateBase, signal.State): nil,
},
}
}
func NewLockSignal(state string) *LockSignal { func NewLockSignal(state string) *LockSignal {
return &LockSignal{ return &LockSignal{
NewSignalHeader(Direct), NewSignalHeader(),
state, state,
} }
} }
@ -284,21 +227,9 @@ func (signal ReadSignal) String() string {
return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Extensions) return fmt.Sprintf("ReadSignal(%s, %+v)", signal.SignalHeader, signal.Extensions)
} }
func (signal ReadSignal) Permission() Tree {
ret := Tree{}
for ext, fields := range(signal.Extensions) {
field_tree := Tree{}
for _, field := range(fields) {
field_tree[SerializedType(GetFieldTag(field))] = nil
}
ret[SerializedType(ext)] = field_tree
}
return Tree{SerializedType(SignalTypeFor[ReadSignal]()): ret}
}
func NewReadSignal(exts map[ExtType][]string) *ReadSignal { func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
return &ReadSignal{ return &ReadSignal{
NewSignalHeader(Direct), NewSignalHeader(),
exts, exts,
} }
} }
@ -307,23 +238,16 @@ type ReadResultSignal struct {
ResponseHeader ResponseHeader
NodeID NodeID NodeID NodeID
NodeType NodeType NodeType NodeType
Extensions map[ExtType]map[string]SerializedValue Extensions map[ExtType]map[string]any
} }
func (signal ReadResultSignal) String() string { func (signal ReadResultSignal) String() string {
return fmt.Sprintf("ReadResultSignal(%s, %s, %+v, %+v)", signal.ResponseHeader, signal.NodeID, signal.NodeType, signal.Extensions) return fmt.Sprintf("ReadResultSignal(%s, %s, %+v)", signal.ResponseHeader, signal.NodeID, signal.Extensions)
} }
func (signal ReadResultSignal) Permission() Tree { func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]any) *ReadResultSignal {
return Tree{
SerializedType(SignalTypeFor[ResponseSignal]()): {
SerializedType(SignalTypeFor[ReadResultSignal]()): nil,
},
}
}
func NewReadResultSignal(req_id uuid.UUID, node_id NodeID, node_type NodeType, exts map[ExtType]map[string]SerializedValue) *ReadResultSignal {
return &ReadResultSignal{ return &ReadResultSignal{
NewResponseHeader(req_id, Direct), NewResponseHeader(req_id),
node_id, node_id,
node_type, node_type,
exts, exts,