Changed NewNode to return a pointer and add the node to the context

graph-rework-2
noah metz 2023-07-26 15:08:14 -06:00
parent 755edf8558
commit d2f3daf5a6
12 changed files with 185 additions and 249 deletions

@ -714,30 +714,30 @@ type GQLExt struct {
Key *ecdsa.PrivateKey Key *ecdsa.PrivateKey
ECDH ecdh.Curve ECDH ecdh.Curve
SubscribeLock sync.Mutex SubscribeLock sync.Mutex
SubscribeListeners []chan GraphSignal SubscribeListeners []chan Signal
} }
func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan GraphSignal { func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan Signal {
ext.SubscribeLock.Lock() ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock() defer ext.SubscribeLock.Unlock()
new_listener := make(chan GraphSignal, buffer) new_listener := make(chan Signal, buffer)
ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener) ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener)
return new_listener return new_listener
} }
func (ext *GQLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *GQLExt) Process(context *StateContext, node *Node, signal Signal) error {
ext.SubscribeLock.Lock() ext.SubscribeLock.Lock()
defer ext.SubscribeLock.Unlock() defer ext.SubscribeLock.Unlock()
active_listeners := []chan GraphSignal{} active_listeners := []chan Signal{}
for _, listener := range(ext.SubscribeListeners) { for _, listener := range(ext.SubscribeListeners) {
select { select {
case listener <- signal: case listener <- signal:
active_listeners = append(active_listeners, listener) active_listeners = append(active_listeners, listener)
default: default:
go func(listener chan GraphSignal) { go func(listener chan Signal) {
listener <- NewDirectSignal("Channel Closed") listener <- NewDirectSignal("Channel Closed")
close(listener) close(listener)
}(listener) }(listener)
@ -853,7 +853,7 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_
} }
return &GQLExt{ return &GQLExt{
Listen: listen, Listen: listen,
SubscribeListeners: []chan GraphSignal{}, SubscribeListeners: []chan Signal{},
Key: key, Key: key,
ECDH: ecdh_curve, ECDH: ecdh_curve,
tls_cert: tls_cert, tls_cert: tls_cert,
@ -936,7 +936,7 @@ var gql_actions ThreadActions = ThreadActions{
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, thread, thread, NewStatusSignal("server_started", thread.ID)) err = SendSignal(context, thread, thread, NewStatusSignal("server_started", thread.ID))
if err != nil { if err != nil {
return "", err return "", err
} }

@ -303,21 +303,21 @@ func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) {
return members, nil return members, nil
} }
func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
if signal, ok := p.Source.(GraphSignal); ok { if signal, ok := p.Source.(Signal); ok {
return fn(signal, p) return fn(signal, p)
} }
return nil, fmt.Errorf("Failed to cast source to event") return nil, fmt.Errorf("Failed to cast source to event")
} }
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) { func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
return signal.Type(), nil return signal.Type(), nil
}) })
} }
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) { func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
direction := signal.Direction() direction := signal.Direction()
if direction == Up { if direction == Up {
return "up", nil return "up", nil
@ -331,7 +331,8 @@ func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
} }
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) { func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
return signal.String(), nil ser, err := signal.Serialize()
return string(ser), err
}) })
} }

@ -17,41 +17,47 @@ func TestGQLDBLoad(t * testing.T) {
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType}) err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType})
fatalErr(t, err) fatalErr(t, err)
u1 := NewNode(RandID(), TestUserNodeType) u1 := NewNode(ctx, RandID(), TestUserNodeType)
ctx.Nodes[u1.ID] = &u1 u1_policy := NewPerNodePolicy(map[NodeID][]string{
u1.ID: []string{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"},
})
u1.Extensions[ACLExtType] = NewACLExt(nil) u1.Extensions[ACLExtType] = NewACLExt(nil)
u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
PerNodePolicyType: NewPerNodePolicy(map[NodeID][]string{ PerNodePolicyType: &u1_policy,
u1.ID: []string{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"},
}, nil),
}) })
ctx.Log.Logf("test", "U1_ID: %s", u1.ID) ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
ListenerNodeType := NodeType("LISTENER") ListenerNodeType := NodeType("LISTENER")
err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ListenerExtType, LockableExtType})
fatalErr(t, err) fatalErr(t, err)
l1 := NewNode(RandID(), ListenerNodeType) l1 := NewNode(ctx, RandID(), ListenerNodeType)
ctx.Nodes[l1.ID] = &l1 l1_policy := NewRequirementOfPolicy(map[NodeID][]string{
l1.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) l1.ID: []string{"signal.status"},
})
l1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
l1.Extensions[ListenerExtType] = listener_ext l1.Extensions[ListenerExtType] = listener_ext
l1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
RequirementOfPolicyType: &l1_policy,
})
l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil)
ctx.Log.Logf("test", "L1_ID: %s", l1.ID) ctx.Log.Logf("test", "L1_ID: %s", l1.ID)
TestThreadNodeType := NodeType("TEST_THREAD") TestThreadNodeType := NodeType("TEST_THREAD")
err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType, LockableExtType}) err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ThreadExtType, LockableExtType})
fatalErr(t, err) fatalErr(t, err)
t1 := NewNode(RandID(), TestThreadNodeType) t1 := NewNode(ctx, RandID(), TestThreadNodeType)
ctx.Nodes[t1.ID] = &t1 t1_policy := NewParentOfPolicy(map[NodeID][]string{
t1.Extensions[ACLExtType] = NewACLExt(NodeList(&u1))
t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
ParentOfPolicyType: NewParentOfPolicy(map[NodeID][]string{
t1.ID: []string{"signal.abort", "state.write"}, t1.ID: []string{"signal.abort", "state.write"},
}), })
t1.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
ParentOfPolicyType: &t1_policy,
}) })
t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil) t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil)
fatalErr(t, err) fatalErr(t, err)
@ -60,19 +66,19 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "T1_ID: %s", t1.ID) ctx.Log.Logf("test", "T1_ID: %s", t1.ID)
TestGQLNodeType := NodeType("TEST_GQL") TestGQLNodeType := NodeType("TEST_GQL")
err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType})
fatalErr(t, err) fatalErr(t, err)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
fatalErr(t, err) fatalErr(t, err)
gql := NewNode(RandID(), TestGQLNodeType) gql := NewNode(ctx, RandID(), TestGQLNodeType)
ctx.Nodes[gql.ID] = &gql gql_policy := NewChildOfPolicy(map[NodeID][]string{
gql.Extensions[ACLExtType] = NewACLExt(NodeList(&u1))
gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
ChildOfPolicyType: NewChildOfPolicy(map[NodeID][]string{
gql.ID: []string{"signal.status"}, gql.ID: []string{"signal.status"},
}), })
gql.Extensions[ACLExtType] = NewACLExt(NodeList(u1))
gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{
ChildOfPolicyType: &gql_policy,
}) })
gql.Extensions[GroupExtType] = NewGroupExt(nil) gql.Extensions[GroupExtType] = NewGroupExt(nil)
gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil) gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil)
@ -83,25 +89,25 @@ func TestGQLDBLoad(t * testing.T) {
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
info := ParentInfo{true, "start", "restore"} info := ParentInfo{true, "start", "restore"}
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err = UpdateStates(context, &u1, NewACLInfo(&gql, []string{"users"}), func(context *StateContext) error { err = UpdateStates(context, u1, NewACLInfo(gql, []string{"users"}), func(context *StateContext) error {
err := LinkThreads(context, &u1, &gql, ChildInfo{&t1, map[InfoType]Info{ err := LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{
ParentInfoType: &info, ParentInfoType: &info,
}}) }})
if err != nil { if err != nil {
return err return err
} }
return LinkLockables(context, &u1, &l1, []*Node{&gql}) return LinkLockables(context, u1, l1, []*Node{gql})
}) })
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID)) err = SendSignal(context, gql, gql, NewStatusSignal("child_linked", t1.ID))
fatalErr(t, err) fatalErr(t, err)
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = Signal(context, &gql, &gql, AbortSignal) err = SendSignal(context, gql, gql, AbortSignal)
fatalErr(t, err) fatalErr(t, err)
err = ThreadLoop(ctx, &gql, "start") err = ThreadLoop(ctx, gql, "start")
if errors.Is(err, ThreadAbortedError) == false { if errors.Is(err, ThreadAbortedError) == false {
fatalErr(t, err) fatalErr(t, err)
} }
@ -109,7 +115,7 @@ func TestGQLDBLoad(t * testing.T) {
(*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener") (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener")
context = NewReadContext(ctx) context = NewReadContext(ctx)
err = UseStates(context, &gql, ACLList([]*Node{&gql, &u1}, nil), func(context *StateContext) error { err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error {
ser1, err := gql.Serialize() ser1, err := gql.Serialize()
ser2, err := u1.Serialize() ser2, err := u1.Serialize()
ctx.Log.Logf("test", "\n%s\n\n", ser1) ctx.Log.Logf("test", "\n%s\n\n", ser1)
@ -134,7 +140,7 @@ func TestGQLDBLoad(t * testing.T) {
if err != nil { if err != nil {
return err return err
} }
Signal(context, gql_loaded, gql_loaded, StopSignal) SendSignal(context, gql_loaded, gql_loaded, StopSignal)
return err return err
}) })

@ -125,7 +125,7 @@ var GQLTypeSignal = NewSingleton(func() *graphql.Object {
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{ gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
Name: "Signal", Name: "Signal",
IsTypeOf: func(p graphql.IsTypeOfParams) bool { IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(GraphSignal) _, ok := p.Value.(Signal)
return ok return ok
}, },
Fields: graphql.Fields{}, Fields: graphql.Fields{},

@ -13,7 +13,7 @@ import (
type GraphTester testing.T type GraphTester testing.T
const listner_timeout = 50 * time.Millisecond const listner_timeout = 50 * time.Millisecond
func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, status string, timeout time.Duration, str string) GraphSignal { func (t * GraphTester) WaitForStatus(ctx * Context, listener chan Signal, status string, timeout time.Duration, str string) Signal {
timeout_channel := time.After(timeout) timeout_channel := time.After(timeout)
for true { for true {
select { select {
@ -42,7 +42,7 @@ func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, s
return nil return nil
} }
func (t * GraphTester) CheckForNone(listener chan GraphSignal, str string) { func (t * GraphTester) CheckForNone(listener chan Signal, str string) {
timeout := time.After(listner_timeout) timeout := time.After(listner_timeout)
select { select {
case sig := <- listener: case sig := <- listener:

@ -7,13 +7,13 @@ import (
type ListenerExt struct { type ListenerExt struct {
Buffer int Buffer int
Chan chan GraphSignal Chan chan Signal
} }
func NewListenerExt(buffer int) *ListenerExt { func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{ return &ListenerExt{
Buffer: buffer, Buffer: buffer,
Chan: make(chan GraphSignal, buffer), Chan: make(chan Signal, buffer),
} }
} }
@ -32,7 +32,7 @@ func (listener ListenerExt) Type() ExtType {
return ListenerExtType return ListenerExtType
} }
func (ext ListenerExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext ListenerExt) Process(context *StateContext, node *Node, signal Signal) error {
select { select {
case ext.Chan <- signal: case ext.Chan <- signal:
default: default:
@ -125,7 +125,7 @@ func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
return NewLockableExt(owner, requirements, dependencies, locks_held), nil return NewLockableExt(owner, requirements, dependencies, locks_held), nil
} }
func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error {
context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID) context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID)
var err error var err error
@ -136,7 +136,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphS
owner_sent := false owner_sent := false
for _, dependency := range(ext.Dependencies) { for _, dependency := range(ext.Dependencies) {
context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID) context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID)
Signal(context, dependency, node, signal) SendSignal(context, dependency, node, signal)
if ext.Owner != nil { if ext.Owner != nil {
if dependency.ID == ext.Owner.ID { if dependency.ID == ext.Owner.ID {
owner_sent = true owner_sent = true
@ -146,7 +146,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphS
if ext.Owner != nil && owner_sent == false { if ext.Owner != nil && owner_sent == false {
if ext.Owner.ID != node.ID { if ext.Owner.ID != node.ID {
context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID) context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID)
return Signal(context, ext.Owner, node, signal) return SendSignal(context, ext.Owner, node, signal)
} }
} }
return nil return nil
@ -154,7 +154,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphS
case Down: case Down:
err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error { err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error {
for _, requirement := range(ext.Requirements) { for _, requirement := range(ext.Requirements) {
err := Signal(context, requirement, node, signal) err := SendSignal(context, requirement, node, signal)
if err != nil { if err != nil {
return err return err
} }

@ -67,7 +67,7 @@ type Extension interface {
Serializable[ExtType] Serializable[ExtType]
// Send a signal to this extension to process, // Send a signal to this extension to process,
// this typically triggers signals to be sent to nodes linked in the extension // this typically triggers signals to be sent to nodes linked in the extension
Process(context *StateContext, node *Node, signal GraphSignal) error Process(context *StateContext, node *Node, signal Signal) error
} }
// Nodes represent an addressible group of extensions // Nodes represent an addressible group of extensions
@ -141,12 +141,20 @@ func (node *Node) Serialize() ([]byte, error) {
return node_db.Serialize(), nil return node_db.Serialize(), nil
} }
func NewNode(id NodeID, node_type NodeType) Node { func NewNode(ctx *Context, id NodeID, node_type NodeType) *Node {
return Node{ _, exists := ctx.Nodes[id]
if exists == true {
panic("Attempted to create an existing node")
}
node := &Node{
ID: id, ID: id,
Type: node_type, Type: node_type,
Extensions: map[ExtType]Extension{}, Extensions: map[ExtType]Extension{},
} }
ctx.Nodes[id] = node
return node
} }
func Allowed(context *StateContext, principal *Node, action string, node *Node) error { func Allowed(context *StateContext, principal *Node, action string, node *Node) error {
@ -191,8 +199,9 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node)
// Check that princ is allowed to signal this action, // Check that princ is allowed to signal this action,
// then send the signal to all the extensions of the node // then send the signal to all the extensions of the node
func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) error { func SendSignal(context *StateContext, node *Node, princ *Node, signal Signal) error {
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, signal.String()) ser, _ := signal.Serialize()
context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, string(ser))
err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error { err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error {
return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node)
@ -398,9 +407,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
} }
// Create the blank node with the ID, and add it to the context // Create the blank node with the ID, and add it to the context
new_node := NewNode(id, node_type.Type) node = NewNode(ctx, id, node_type.Type)
node = &new_node
ctx.Nodes[id] = node
found_extensions := []ExtType{} found_extensions := []ExtType{}
// Parse each of the extensions from the db // Parse each of the extensions from the db

@ -9,21 +9,20 @@ func TestNodeDB(t *testing.T) {
node_type := NodeType("test") node_type := NodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"}) err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"})
fatalErr(t, err) fatalErr(t, err)
node := NewNode(RandID(), node_type) node := NewNode(ctx, RandID(), node_type)
node.Extensions[ACLExtType] = &ACLExt{ node.Extensions[ACLExtType] = &ACLExt{
Delegations: NodeMap{}, Delegations: NodeMap{},
} }
ctx.Nodes[node.ID] = &node
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err = UpdateStates(context, &node, NewACLInfo(&node, []string{"test"}), func(context *StateContext) error { err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error {
ser, err := node.Serialize() ser, err := node.Serialize()
ctx.Log.Logf("test", "NODE_SER: %+v", ser) ctx.Log.Logf("test", "NODE_SER: %+v", ser)
return err return err
}) })
fatalErr(t, err) fatalErr(t, err)
delete(ctx.Nodes, node.ID) ctx.Nodes = NodeMap{}
_, err = LoadNode(ctx, node.ID) _, err = LoadNode(ctx, node.ID)
fatalErr(t, err) fatalErr(t, err)
} }

@ -6,30 +6,32 @@ import (
) )
type Policy interface { type Policy interface {
Type() PolicyType Serializable[PolicyType]
Serialize() ([]byte, error)
Allows(context *StateContext, principal *Node, action string, node *Node) bool Allows(context *StateContext, principal *Node, action string, node *Node) bool
} }
const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
type RequirementOfPolicy struct {
PerNodePolicy
}
func (policy *RequirementOfPolicy) Type() PolicyType {
return RequirementOfPolicyType
}
func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy {
return RequirementOfPolicy{
PerNodePolicy: NewPerNodePolicy(nodes),
}
}
const ChildOfPolicyType = PolicyType("CHILD_OF") const ChildOfPolicyType = PolicyType("CHILD_OF")
type ChildOfPolicy struct { type ChildOfPolicy struct {
NodeActions map[NodeID][]string PerNodePolicy
} }
func (policy *ChildOfPolicy) Type() PolicyType { func (policy *ChildOfPolicy) Type() PolicyType {
return ChildOfPolicyType return ChildOfPolicyType
} }
func (policy *ChildOfPolicy) Serialize() ([]byte, error) {
node_actions := map[string][]string{}
for id, actions := range(policy.NodeActions) {
node_actions[id.String()] = actions
}
return json.MarshalIndent(&ChildOfPolicyJSON{
NodeActions: node_actions,
}, "", " ")
}
func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy) context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy)
thread_ext, err := GetExt[*ThreadExt](principal) thread_ext, err := GetExt[*ThreadExt](principal)
@ -53,18 +55,17 @@ func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, acti
return false return false
} }
type ChildOfPolicyJSON struct { type NodeActions map[NodeID][]string
NodeActions map[string][]string `json:"node_actions"`
}
func LoadChildOfPolicy(ctx *Context, data []byte) (Policy, error) { func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
var j ChildOfPolicyJSON return func(ctx *Context, data []byte)(Policy, error){
var j PerNodePolicyJSON
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node_actions := map[NodeID][]string{} node_actions := NodeActions{}
for id_str, actions := range(j.NodeActions) { for id_str, actions := range(j.NodeActions) {
id, err := ParseID(id_str) id, err := ParseID(id_str)
if err != nil { if err != nil {
@ -79,38 +80,24 @@ func LoadChildOfPolicy(ctx *Context, data []byte) (Policy, error) {
node_actions[id] = actions node_actions[id] = actions
} }
return NewChildOfPolicy(node_actions), nil return init_fn(node_actions)
}
func NewChildOfPolicy(node_actions map[NodeID][]string) *ChildOfPolicy {
if node_actions == nil {
node_actions = map[NodeID][]string{}
} }
}
return &ChildOfPolicy{ func NewChildOfPolicy(node_actions map[NodeID][]string) ChildOfPolicy {
NodeActions: node_actions, return ChildOfPolicy{
PerNodePolicy: NewPerNodePolicy(node_actions),
} }
} }
const ParentOfPolicyType = PolicyType("PARENT_OF") const ParentOfPolicyType = PolicyType("PARENT_OF")
type ParentOfPolicy struct { type ParentOfPolicy struct {
NodeActions map[NodeID][]string PerNodePolicy
} }
func (policy *ParentOfPolicy) Type() PolicyType { func (policy *ParentOfPolicy) Type() PolicyType {
return ParentOfPolicyType return ParentOfPolicyType
} }
func (policy *ParentOfPolicy) Serialize() ([]byte, error) {
node_actions := map[string][]string{}
for id, actions := range(policy.NodeActions) {
node_actions[id.String()] = actions
}
return json.MarshalIndent(&ParentOfPolicyJSON{
NodeActions: node_actions,
}, "", " ")
}
func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy) context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy)
for id, actions := range(policy.NodeActions) { for id, actions := range(policy.NodeActions) {
@ -134,102 +121,36 @@ func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, act
return false return false
} }
type ParentOfPolicyJSON struct { func NewParentOfPolicy(node_actions map[NodeID][]string) ParentOfPolicy {
NodeActions map[string][]string `json:"node_actions"` return ParentOfPolicy{
} PerNodePolicy: NewPerNodePolicy(node_actions),
func LoadParentOfPolicy(ctx *Context, data []byte) (Policy, error) {
var j ParentOfPolicyJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
node_actions := map[NodeID][]string{}
for id_str, actions := range(j.NodeActions) {
id, err := ParseID(id_str)
if err != nil {
return nil, err
} }
_, err = LoadNode(ctx, id)
if err != nil {
return nil, err
}
node_actions[id] = actions
}
return NewParentOfPolicy(node_actions), nil
} }
func NewParentOfPolicy(node_actions map[NodeID][]string) *ParentOfPolicy { func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
if node_actions == nil { if node_actions == nil {
node_actions = map[NodeID][]string{} node_actions = map[NodeID][]string{}
} }
return &ParentOfPolicy{ return PerNodePolicy{
NodeActions: node_actions, NodeActions: node_actions,
} }
} }
func LoadPerNodePolicy(ctx *Context, data []byte) (Policy, error) {
var j PerNodePolicyJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
node_actions := map[NodeID][]string{}
for id_str, actions := range(j.NodeActions) {
id, err := ParseID(id_str)
if err != nil {
return nil, err
}
_, err = LoadNode(ctx, id)
if err != nil {
return nil, err
}
node_actions[id] = actions
}
return NewPerNodePolicy(node_actions, j.WildcardActions), nil
}
func NewPerNodePolicy(node_actions map[NodeID][]string, wildcard_actions []string) *PerNodePolicy {
if node_actions == nil {
node_actions = map[NodeID][]string{}
}
if wildcard_actions == nil {
wildcard_actions = []string{}
}
return &PerNodePolicy{
NodeActions: node_actions,
WildcardActions: wildcard_actions,
}
}
type PerNodePolicy struct { type PerNodePolicy struct {
NodeActions map[NodeID][]string NodeActions map[NodeID][]string
WildcardActions []string
} }
type PerNodePolicyJSON struct { type PerNodePolicyJSON struct {
NodeActions map[string][]string `json:"node_actions"` NodeActions map[string][]string `json:"node_actions"`
WildcardActions []string `json:"wildcard_actions"`
} }
const PerNodePolicyType = PolicyType("PER_NODE") const PerNodePolicyType = PolicyType("PER_NODE")
func (policy PerNodePolicy) Type() PolicyType { func (policy *PerNodePolicy) Type() PolicyType {
return PerNodePolicyType return PerNodePolicyType
} }
func (policy PerNodePolicy) Serialize() ([]byte, error) { func (policy *PerNodePolicy) Serialize() ([]byte, error) {
node_actions := map[string][]string{} node_actions := map[string][]string{}
for id, actions := range(policy.NodeActions) { for id, actions := range(policy.NodeActions) {
node_actions[id.String()] = actions node_actions[id.String()] = actions
@ -237,17 +158,10 @@ func (policy PerNodePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(&PerNodePolicyJSON{ return json.MarshalIndent(&PerNodePolicyJSON{
NodeActions: node_actions, NodeActions: node_actions,
WildcardActions: policy.WildcardActions,
}, "", " ") }, "", " ")
} }
func (policy PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool {
for _, a := range(policy.WildcardActions) {
if a == action {
return true
}
}
for id, actions := range(policy.NodeActions) { for id, actions := range(policy.NodeActions) {
if id != principal.ID { if id != principal.ID {
continue continue
@ -272,7 +186,7 @@ type ACLExt struct {
Delegations NodeMap Delegations NodeMap
} }
func (ext *ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil return nil
} }
@ -347,13 +261,28 @@ func NewACLPolicyExtContext() *ACLPolicyExtContext {
return &ACLPolicyExtContext{ return &ACLPolicyExtContext{
Types: map[PolicyType]PolicyInfo{ Types: map[PolicyType]PolicyInfo{
PerNodePolicyType: PolicyInfo{ PerNodePolicyType: PolicyInfo{
Load: LoadPerNodePolicy, Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewPerNodePolicy(nodes)
return &policy, nil
}),
}, },
ParentOfPolicyType: PolicyInfo{ ParentOfPolicyType: PolicyInfo{
Load: LoadParentOfPolicy, Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewParentOfPolicy(nodes)
return &policy, nil
}),
}, },
ChildOfPolicyType: PolicyInfo{ ChildOfPolicyType: PolicyInfo{
Load: LoadChildOfPolicy, Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewChildOfPolicy(nodes)
return &policy, nil
}),
},
RequirementOfPolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewRequirementOfPolicy(nodes)
return &policy, nil
}),
}, },
}, },
} }
@ -376,7 +305,7 @@ func (ext *ACLPolicyExt) Serialize() ([]byte, error) {
}, "", " ") }, "", " ")
} }
func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil return nil
} }

@ -11,54 +11,48 @@ const (
Direct Direct
) )
// GraphSignals are passed around the event tree/resource DAG and cast by Type() type SignalType string
type GraphSignal interface {
// How to propogate the signal type Signal interface {
Serializable[SignalType]
Direction() SignalDirection Direction() SignalDirection
Type() string
String() string
} }
// BaseSignal is the most basic type of signal, it has no additional data
type BaseSignal struct { type BaseSignal struct {
FDirection SignalDirection `json:"direction"` SignalDirection SignalDirection `json:"direction"`
FType string `json:"type"` SignalType SignalType `json:"type"`
} }
func (signal BaseSignal) String() string { func (signal BaseSignal) Type() SignalType {
ser, err := json.Marshal(signal) return signal.SignalType
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
} }
func (signal BaseSignal) Direction() SignalDirection { func (signal BaseSignal) Direction() SignalDirection {
return signal.FDirection return signal.SignalDirection
} }
func (signal BaseSignal) Type() string { func (signal BaseSignal) Serialize() ([]byte, error) {
return signal.FType return json.MarshalIndent(signal, "", " ")
} }
func NewBaseSignal(_type string, direction SignalDirection) BaseSignal { func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal {
signal := BaseSignal{ signal := BaseSignal{
FDirection: direction, SignalDirection: direction,
FType: _type, SignalType: signal_type,
} }
return signal return signal
} }
func NewDownSignal(_type string) BaseSignal { func NewDownSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(_type, Down) return NewBaseSignal(signal_type, Down)
} }
func NewSignal(_type string) BaseSignal { func NewUpSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(_type, Up) return NewBaseSignal(signal_type, Up)
} }
func NewDirectSignal(_type string) BaseSignal { func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(_type, Direct) return NewBaseSignal(signal_type, Direct)
} }
var AbortSignal = NewBaseSignal("abort", Down) var AbortSignal = NewBaseSignal("abort", Down)
@ -77,9 +71,9 @@ func (signal IDSignal) String() string {
return string(ser) return string(ser)
} }
func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal { func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
return IDSignal{ return IDSignal{
BaseSignal: NewBaseSignal(_type, direction), BaseSignal: NewBaseSignal(signal_type, direction),
ID: id, ID: id,
} }
} }

@ -12,8 +12,8 @@ import (
type ThreadAction func(*Context, *Node, *ThreadExt)(string, error) type ThreadAction func(*Context, *Node, *ThreadExt)(string, error)
type ThreadActions map[string]ThreadAction type ThreadActions map[string]ThreadAction
type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error) type ThreadHandler func(*Context, *Node, *ThreadExt, Signal)(string, error)
type ThreadHandlers map[string]ThreadHandler type ThreadHandlers map[SignalType]ThreadHandler
type InfoType string type InfoType string
func (t InfoType) String() string { func (t InfoType) String() string {
@ -122,7 +122,7 @@ type ThreadExt struct {
ThreadType ThreadType ThreadType ThreadType
SignalChan chan GraphSignal SignalChan chan Signal
TimeoutChan <-chan time.Time TimeoutChan <-chan time.Time
ChildWaits sync.WaitGroup ChildWaits sync.WaitGroup
@ -191,7 +191,7 @@ func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children ma
return &ThreadExt{ return &ThreadExt{
Actions: type_info.Actions, Actions: type_info.Actions,
Handlers: type_info.Handlers, Handlers: type_info.Handlers,
SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE), SignalChan: make(chan Signal, THREAD_BUFFER_SIZE),
TimeoutChan: timeout_chan, TimeoutChan: timeout_chan,
Active: false, Active: false,
State: state, State: state,
@ -276,7 +276,7 @@ func (ext *ThreadExt) ChildList() []*Node {
} }
// Assumed that thread is already locked for signal // Assumed that thread is already locked for signal
func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) error {
context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID) context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID)
var err error var err error
@ -285,7 +285,7 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSig
err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error {
if ext.Parent != nil { if ext.Parent != nil {
if ext.Parent.ID != node.ID { if ext.Parent.ID != node.ID {
return Signal(context, ext.Parent, node, signal) return SendSignal(context, ext.Parent, node, signal)
} }
} }
return nil return nil
@ -293,7 +293,7 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSig
case Down: case Down:
err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error {
for _, info := range(ext.Children) { for _, info := range(ext.Children) {
err := Signal(context, info.Child, node, signal) err := SendSignal(context, info.Child, node, signal)
if err != nil { if err != nil {
return err return err
} }
@ -535,7 +535,7 @@ func ThreadLoop(ctx * Context, thread *Node, first_action string) error {
return nil return nil
} }
func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal) ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal)
context := NewWriteContext(ctx) context := NewWriteContext(ctx)
err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error {
@ -570,7 +570,7 @@ func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal
// Helper function to start a child from a thread during a signal handler // Helper function to start a child from a thread during a signal handler
// Starts a write context, so cannot be called from either a write or read context // Starts a write context, so cannot be called from either a write or read context
func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
sig, ok := signal.(StartChildSignal) sig, ok := signal.(StartChildSignal)
if ok == false { if ok == false {
return "wait", nil return "wait", nil
@ -638,7 +638,7 @@ func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, er
} }
context = NewReadContext(ctx) context = NewReadContext(ctx)
return "wait", Signal(context, thread, thread, NewStatusSignal("started", thread.ID)) return "wait", SendSignal(context, thread, thread, NewStatusSignal("started", thread.ID))
} }
func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) {
@ -685,9 +685,9 @@ func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, er
var ThreadAbortedError = errors.New("Thread aborted by signal") var ThreadAbortedError = errors.New("Thread aborted by signal")
// Default thread action function for "abort", sends a signal and returns a ThreadAbortedError // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError
func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID)) err := SendSignal(context, thread, thread, NewStatusSignal("aborted", thread.ID))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -695,9 +695,9 @@ func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Grap
} }
// Default thread action for "stop", sends a signal and returns no error // Default thread action for "stop", sends a signal and returns no error
func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) {
context := NewReadContext(ctx) context := NewReadContext(ctx)
err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID)) err := SendSignal(context, thread, thread, NewStatusSignal("stopped", thread.ID))
return "finish", err return "finish", err
} }

@ -20,7 +20,7 @@ type ECDHExtJSON struct {
Shared []byte `json:"shared"` Shared []byte `json:"shared"`
} }
func (ext *ECDHExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil return nil
} }
@ -115,6 +115,6 @@ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
return NewGroupExt(members), nil return NewGroupExt(members), nil
} }
func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error { func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error {
return nil return nil
} }