From d2f3daf5a6797856d4f7fd358a63efe538a80991 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Wed, 26 Jul 2023 15:08:14 -0600 Subject: [PATCH] Changed NewNode to return a pointer and add the node to the context --- gql.go | 16 ++-- gql_resolvers.go | 13 +-- gql_test.go | 68 +++++++++------- gql_types.go | 2 +- graph_test.go | 4 +- lockable.go | 14 ++-- node.go | 23 ++++-- node_test.go | 7 +- policy.go | 205 ++++++++++++++++------------------------------- signal.go | 50 +++++------- thread.go | 28 +++---- user.go | 4 +- 12 files changed, 185 insertions(+), 249 deletions(-) diff --git a/gql.go b/gql.go index 2810018..eb3858e 100644 --- a/gql.go +++ b/gql.go @@ -714,30 +714,30 @@ type GQLExt struct { Key *ecdsa.PrivateKey ECDH ecdh.Curve SubscribeLock sync.Mutex - SubscribeListeners []chan GraphSignal + SubscribeListeners []chan Signal } -func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan GraphSignal { +func (ext *GQLExt) NewSubscriptionChannel(buffer int) chan Signal { ext.SubscribeLock.Lock() defer ext.SubscribeLock.Unlock() - new_listener := make(chan GraphSignal, buffer) + new_listener := make(chan Signal, buffer) ext.SubscribeListeners = append(ext.SubscribeListeners, new_listener) return new_listener } -func (ext *GQLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *GQLExt) Process(context *StateContext, node *Node, signal Signal) error { ext.SubscribeLock.Lock() defer ext.SubscribeLock.Unlock() - active_listeners := []chan GraphSignal{} + active_listeners := []chan Signal{} for _, listener := range(ext.SubscribeListeners) { select { case listener <- signal: active_listeners = append(active_listeners, listener) default: - go func(listener chan GraphSignal) { + go func(listener chan Signal) { listener <- NewDirectSignal("Channel Closed") close(listener) }(listener) @@ -853,7 +853,7 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_ } return &GQLExt{ Listen: listen, - SubscribeListeners: []chan GraphSignal{}, + SubscribeListeners: []chan Signal{}, Key: key, ECDH: ecdh_curve, tls_cert: tls_cert, @@ -936,7 +936,7 @@ var gql_actions ThreadActions = ThreadActions{ } context = NewReadContext(ctx) - err = Signal(context, thread, thread, NewStatusSignal("server_started", thread.ID)) + err = SendSignal(context, thread, thread, NewStatusSignal("server_started", thread.ID)) if err != nil { return "", err } diff --git a/gql_resolvers.go b/gql_resolvers.go index e5023cc..827a926 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -303,21 +303,21 @@ func GQLGroupMembers(p graphql.ResolveParams) (interface{}, error) { return members, nil } -func GQLSignalFn(p graphql.ResolveParams, fn func(GraphSignal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - if signal, ok := p.Source.(GraphSignal); ok { +func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { + if signal, ok := p.Source.(Signal); ok { return fn(signal, p) } return nil, fmt.Errorf("Failed to cast source to event") } func GQLSignalType(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ + return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){ return signal.Type(), nil }) } func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ + return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){ direction := signal.Direction() if direction == Up { return "up", nil @@ -331,7 +331,8 @@ func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) { } func GQLSignalString(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal GraphSignal, p graphql.ResolveParams)(interface{}, error){ - return signal.String(), nil + return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){ + ser, err := signal.Serialize() + return string(ser), err }) } diff --git a/gql_test.go b/gql_test.go index 54b867a..49fcc16 100644 --- a/gql_test.go +++ b/gql_test.go @@ -17,41 +17,47 @@ func TestGQLDBLoad(t * testing.T) { err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType}) fatalErr(t, err) - u1 := NewNode(RandID(), TestUserNodeType) - ctx.Nodes[u1.ID] = &u1 + u1 := NewNode(ctx, RandID(), TestUserNodeType) + u1_policy := NewPerNodePolicy(map[NodeID][]string{ + u1.ID: []string{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"}, + }) u1.Extensions[ACLExtType] = NewACLExt(nil) u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ - PerNodePolicyType: NewPerNodePolicy(map[NodeID][]string{ - u1.ID: []string{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"}, - }, nil), + PerNodePolicyType: &u1_policy, }) ctx.Log.Logf("test", "U1_ID: %s", u1.ID) ListenerNodeType := NodeType("LISTENER") - err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) + err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ListenerExtType, LockableExtType}) fatalErr(t, err) - l1 := NewNode(RandID(), ListenerNodeType) - ctx.Nodes[l1.ID] = &l1 - l1.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) + l1 := NewNode(ctx, RandID(), ListenerNodeType) + l1_policy := NewRequirementOfPolicy(map[NodeID][]string{ + l1.ID: []string{"signal.status"}, + }) + + l1.Extensions[ACLExtType] = NewACLExt(NodeList(u1)) listener_ext := NewListenerExt(10) l1.Extensions[ListenerExtType] = listener_ext + l1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + RequirementOfPolicyType: &l1_policy, + }) l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) ctx.Log.Logf("test", "L1_ID: %s", l1.ID) TestThreadNodeType := NodeType("TEST_THREAD") - err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType, LockableExtType}) + err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) - t1 := NewNode(RandID(), TestThreadNodeType) - ctx.Nodes[t1.ID] = &t1 - t1.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) + t1 := NewNode(ctx, RandID(), TestThreadNodeType) + t1_policy := NewParentOfPolicy(map[NodeID][]string{ + t1.ID: []string{"signal.abort", "state.write"}, + }) + t1.Extensions[ACLExtType] = NewACLExt(NodeList(u1)) t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ - ParentOfPolicyType: NewParentOfPolicy(map[NodeID][]string{ - t1.ID: []string{"signal.abort", "state.write"}, - }), + ParentOfPolicyType: &t1_policy, }) t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil) fatalErr(t, err) @@ -60,19 +66,19 @@ func TestGQLDBLoad(t * testing.T) { ctx.Log.Logf("test", "T1_ID: %s", t1.ID) TestGQLNodeType := NodeType("TEST_GQL") - err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) + err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) fatalErr(t, err) - gql := NewNode(RandID(), TestGQLNodeType) - ctx.Nodes[gql.ID] = &gql - gql.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) + gql := NewNode(ctx, RandID(), TestGQLNodeType) + gql_policy := NewChildOfPolicy(map[NodeID][]string{ + gql.ID: []string{"signal.status"}, + }) + gql.Extensions[ACLExtType] = NewACLExt(NodeList(u1)) gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ - ChildOfPolicyType: NewChildOfPolicy(map[NodeID][]string{ - gql.ID: []string{"signal.status"}, - }), + ChildOfPolicyType: &gql_policy, }) gql.Extensions[GroupExtType] = NewGroupExt(nil) gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil) @@ -83,25 +89,25 @@ func TestGQLDBLoad(t * testing.T) { ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) info := ParentInfo{true, "start", "restore"} context := NewWriteContext(ctx) - err = UpdateStates(context, &u1, NewACLInfo(&gql, []string{"users"}), func(context *StateContext) error { - err := LinkThreads(context, &u1, &gql, ChildInfo{&t1, map[InfoType]Info{ + err = UpdateStates(context, u1, NewACLInfo(gql, []string{"users"}), func(context *StateContext) error { + err := LinkThreads(context, u1, gql, ChildInfo{t1, map[InfoType]Info{ ParentInfoType: &info, }}) if err != nil { return err } - return LinkLockables(context, &u1, &l1, []*Node{&gql}) + return LinkLockables(context, u1, l1, []*Node{gql}) }) fatalErr(t, err) context = NewReadContext(ctx) - err = Signal(context, &gql, &gql, NewStatusSignal("child_linked", t1.ID)) + err = SendSignal(context, gql, gql, NewStatusSignal("child_linked", t1.ID)) fatalErr(t, err) context = NewReadContext(ctx) - err = Signal(context, &gql, &gql, AbortSignal) + err = SendSignal(context, gql, gql, AbortSignal) fatalErr(t, err) - err = ThreadLoop(ctx, &gql, "start") + err = ThreadLoop(ctx, gql, "start") if errors.Is(err, ThreadAbortedError) == false { fatalErr(t, err) } @@ -109,7 +115,7 @@ func TestGQLDBLoad(t * testing.T) { (*GraphTester)(t).WaitForStatus(ctx, listener_ext.Chan, "aborted", 100*time.Millisecond, "Didn't receive aborted on listener") context = NewReadContext(ctx) - err = UseStates(context, &gql, ACLList([]*Node{&gql, &u1}, nil), func(context *StateContext) error { + err = UseStates(context, gql, ACLList([]*Node{gql, u1}, nil), func(context *StateContext) error { ser1, err := gql.Serialize() ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) @@ -134,7 +140,7 @@ func TestGQLDBLoad(t * testing.T) { if err != nil { return err } - Signal(context, gql_loaded, gql_loaded, StopSignal) + SendSignal(context, gql_loaded, gql_loaded, StopSignal) return err }) diff --git a/gql_types.go b/gql_types.go index b130118..579c54d 100644 --- a/gql_types.go +++ b/gql_types.go @@ -125,7 +125,7 @@ var GQLTypeSignal = NewSingleton(func() *graphql.Object { gql_type_signal := graphql.NewObject(graphql.ObjectConfig{ Name: "Signal", IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(GraphSignal) + _, ok := p.Value.(Signal) return ok }, Fields: graphql.Fields{}, diff --git a/graph_test.go b/graph_test.go index 258e64f..c0c5e46 100644 --- a/graph_test.go +++ b/graph_test.go @@ -13,7 +13,7 @@ import ( type GraphTester testing.T const listner_timeout = 50 * time.Millisecond -func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, status string, timeout time.Duration, str string) GraphSignal { +func (t * GraphTester) WaitForStatus(ctx * Context, listener chan Signal, status string, timeout time.Duration, str string) Signal { timeout_channel := time.After(timeout) for true { select { @@ -42,7 +42,7 @@ func (t * GraphTester) WaitForStatus(ctx * Context, listener chan GraphSignal, s return nil } -func (t * GraphTester) CheckForNone(listener chan GraphSignal, str string) { +func (t * GraphTester) CheckForNone(listener chan Signal, str string) { timeout := time.After(listner_timeout) select { case sig := <- listener: diff --git a/lockable.go b/lockable.go index ba2ea8f..dd3a24a 100644 --- a/lockable.go +++ b/lockable.go @@ -7,13 +7,13 @@ import ( type ListenerExt struct { Buffer int - Chan chan GraphSignal + Chan chan Signal } func NewListenerExt(buffer int) *ListenerExt { return &ListenerExt{ Buffer: buffer, - Chan: make(chan GraphSignal, buffer), + Chan: make(chan Signal, buffer), } } @@ -32,7 +32,7 @@ func (listener ListenerExt) Type() ExtType { return ListenerExtType } -func (ext ListenerExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext ListenerExt) Process(context *StateContext, node *Node, signal Signal) error { select { case ext.Chan <- signal: default: @@ -125,7 +125,7 @@ func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { return NewLockableExt(owner, requirements, dependencies, locks_held), nil } -func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal) error { context.Graph.Log.Logf("signal", "LOCKABLE_PROCESS: %s", node.ID) var err error @@ -136,7 +136,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphS owner_sent := false for _, dependency := range(ext.Dependencies) { context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID) - Signal(context, dependency, node, signal) + SendSignal(context, dependency, node, signal) if ext.Owner != nil { if dependency.ID == ext.Owner.ID { owner_sent = true @@ -146,7 +146,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphS if ext.Owner != nil && owner_sent == false { if ext.Owner.ID != node.ID { context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID) - return Signal(context, ext.Owner, node, signal) + return SendSignal(context, ext.Owner, node, signal) } } return nil @@ -154,7 +154,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal GraphS case Down: err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error { for _, requirement := range(ext.Requirements) { - err := Signal(context, requirement, node, signal) + err := SendSignal(context, requirement, node, signal) if err != nil { return err } diff --git a/node.go b/node.go index 4c77f1f..a459f88 100644 --- a/node.go +++ b/node.go @@ -67,7 +67,7 @@ type Extension interface { Serializable[ExtType] // Send a signal to this extension to process, // this typically triggers signals to be sent to nodes linked in the extension - Process(context *StateContext, node *Node, signal GraphSignal) error + Process(context *StateContext, node *Node, signal Signal) error } // Nodes represent an addressible group of extensions @@ -141,12 +141,20 @@ func (node *Node) Serialize() ([]byte, error) { return node_db.Serialize(), nil } -func NewNode(id NodeID, node_type NodeType) Node { - return Node{ +func NewNode(ctx *Context, id NodeID, node_type NodeType) *Node { + _, exists := ctx.Nodes[id] + if exists == true { + panic("Attempted to create an existing node") + } + + node := &Node{ ID: id, Type: node_type, Extensions: map[ExtType]Extension{}, } + + ctx.Nodes[id] = node + return node } func Allowed(context *StateContext, principal *Node, action string, node *Node) error { @@ -191,8 +199,9 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node) // Check that princ is allowed to signal this action, // then send the signal to all the extensions of the node -func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) error { - context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, signal.String()) +func SendSignal(context *StateContext, node *Node, princ *Node, signal Signal) error { + ser, _ := signal.Serialize() + context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, string(ser)) err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error { return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) @@ -398,9 +407,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { } // Create the blank node with the ID, and add it to the context - new_node := NewNode(id, node_type.Type) - node = &new_node - ctx.Nodes[id] = node + node = NewNode(ctx, id, node_type.Type) found_extensions := []ExtType{} // Parse each of the extensions from the db diff --git a/node_test.go b/node_test.go index 563ca4c..82f53a6 100644 --- a/node_test.go +++ b/node_test.go @@ -9,21 +9,20 @@ func TestNodeDB(t *testing.T) { node_type := NodeType("test") err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"}) fatalErr(t, err) - node := NewNode(RandID(), node_type) + node := NewNode(ctx, RandID(), node_type) node.Extensions[ACLExtType] = &ACLExt{ Delegations: NodeMap{}, } - ctx.Nodes[node.ID] = &node context := NewWriteContext(ctx) - err = UpdateStates(context, &node, NewACLInfo(&node, []string{"test"}), func(context *StateContext) error { + err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error { ser, err := node.Serialize() ctx.Log.Logf("test", "NODE_SER: %+v", ser) return err }) fatalErr(t, err) - delete(ctx.Nodes, node.ID) + ctx.Nodes = NodeMap{} _, err = LoadNode(ctx, node.ID) fatalErr(t, err) } diff --git a/policy.go b/policy.go index 9bb9547..3547947 100644 --- a/policy.go +++ b/policy.go @@ -6,30 +6,32 @@ import ( ) type Policy interface { - Type() PolicyType - Serialize() ([]byte, error) + Serializable[PolicyType] Allows(context *StateContext, principal *Node, action string, node *Node) bool } +const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") +type RequirementOfPolicy struct { + PerNodePolicy +} +func (policy *RequirementOfPolicy) Type() PolicyType { + return RequirementOfPolicyType +} + +func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy { + return RequirementOfPolicy{ + PerNodePolicy: NewPerNodePolicy(nodes), + } +} + const ChildOfPolicyType = PolicyType("CHILD_OF") type ChildOfPolicy struct { - NodeActions map[NodeID][]string + PerNodePolicy } - func (policy *ChildOfPolicy) Type() PolicyType { return ChildOfPolicyType } -func (policy *ChildOfPolicy) Serialize() ([]byte, error) { - node_actions := map[string][]string{} - for id, actions := range(policy.NodeActions) { - node_actions[id.String()] = actions - } - return json.MarshalIndent(&ChildOfPolicyJSON{ - NodeActions: node_actions, - }, "", " ") -} - func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy) thread_ext, err := GetExt[*ThreadExt](principal) @@ -53,64 +55,49 @@ func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, acti return false } -type ChildOfPolicyJSON struct { - NodeActions map[string][]string `json:"node_actions"` -} - -func LoadChildOfPolicy(ctx *Context, data []byte) (Policy, error) { - var j ChildOfPolicyJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } +type NodeActions map[NodeID][]string - node_actions := map[NodeID][]string{} - for id_str, actions := range(j.NodeActions) { - id, err := ParseID(id_str) +func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) { + return func(ctx *Context, data []byte)(Policy, error){ + var j PerNodePolicyJSON + err := json.Unmarshal(data, &j) if err != nil { return nil, err } - _, err = LoadNode(ctx, id) - if err != nil { - return nil, err - } + node_actions := NodeActions{} + for id_str, actions := range(j.NodeActions) { + id, err := ParseID(id_str) + if err != nil { + return nil, err + } - node_actions[id] = actions - } + _, err = LoadNode(ctx, id) + if err != nil { + return nil, err + } - return NewChildOfPolicy(node_actions), nil -} + node_actions[id] = actions + } -func NewChildOfPolicy(node_actions map[NodeID][]string) *ChildOfPolicy { - if node_actions == nil { - node_actions = map[NodeID][]string{} + return init_fn(node_actions) } +} - return &ChildOfPolicy{ - NodeActions: node_actions, +func NewChildOfPolicy(node_actions map[NodeID][]string) ChildOfPolicy { + return ChildOfPolicy{ + PerNodePolicy: NewPerNodePolicy(node_actions), } } const ParentOfPolicyType = PolicyType("PARENT_OF") type ParentOfPolicy struct { - NodeActions map[NodeID][]string + PerNodePolicy } - func (policy *ParentOfPolicy) Type() PolicyType { return ParentOfPolicyType } -func (policy *ParentOfPolicy) Serialize() ([]byte, error) { - node_actions := map[string][]string{} - for id, actions := range(policy.NodeActions) { - node_actions[id.String()] = actions - } - return json.MarshalIndent(&ParentOfPolicyJSON{ - NodeActions: node_actions, - }, "", " ") -} - func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy) for id, actions := range(policy.NodeActions) { @@ -134,102 +121,36 @@ func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, act return false } -type ParentOfPolicyJSON struct { - NodeActions map[string][]string `json:"node_actions"` -} - -func LoadParentOfPolicy(ctx *Context, data []byte) (Policy, error) { - var j ParentOfPolicyJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - node_actions := map[NodeID][]string{} - for id_str, actions := range(j.NodeActions) { - id, err := ParseID(id_str) - if err != nil { - return nil, err - } - - _, err = LoadNode(ctx, id) - if err != nil { - return nil, err - } - - node_actions[id] = actions +func NewParentOfPolicy(node_actions map[NodeID][]string) ParentOfPolicy { + return ParentOfPolicy{ + PerNodePolicy: NewPerNodePolicy(node_actions), } - - return NewParentOfPolicy(node_actions), nil } -func NewParentOfPolicy(node_actions map[NodeID][]string) *ParentOfPolicy { +func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { if node_actions == nil { node_actions = map[NodeID][]string{} } - return &ParentOfPolicy{ + return PerNodePolicy{ NodeActions: node_actions, } } -func LoadPerNodePolicy(ctx *Context, data []byte) (Policy, error) { - var j PerNodePolicyJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - node_actions := map[NodeID][]string{} - for id_str, actions := range(j.NodeActions) { - id, err := ParseID(id_str) - if err != nil { - return nil, err - } - - _, err = LoadNode(ctx, id) - if err != nil { - return nil, err - } - - node_actions[id] = actions - } - - - return NewPerNodePolicy(node_actions, j.WildcardActions), nil -} - -func NewPerNodePolicy(node_actions map[NodeID][]string, wildcard_actions []string) *PerNodePolicy { - if node_actions == nil { - node_actions = map[NodeID][]string{} - } - - if wildcard_actions == nil { - wildcard_actions = []string{} - } - - return &PerNodePolicy{ - NodeActions: node_actions, - WildcardActions: wildcard_actions, - } -} - type PerNodePolicy struct { NodeActions map[NodeID][]string - WildcardActions []string } type PerNodePolicyJSON struct { NodeActions map[string][]string `json:"node_actions"` - WildcardActions []string `json:"wildcard_actions"` } const PerNodePolicyType = PolicyType("PER_NODE") -func (policy PerNodePolicy) Type() PolicyType { +func (policy *PerNodePolicy) Type() PolicyType { return PerNodePolicyType } -func (policy PerNodePolicy) Serialize() ([]byte, error) { +func (policy *PerNodePolicy) Serialize() ([]byte, error) { node_actions := map[string][]string{} for id, actions := range(policy.NodeActions) { node_actions[id.String()] = actions @@ -237,17 +158,10 @@ func (policy PerNodePolicy) Serialize() ([]byte, error) { return json.MarshalIndent(&PerNodePolicyJSON{ NodeActions: node_actions, - WildcardActions: policy.WildcardActions, }, "", " ") } -func (policy PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - for _, a := range(policy.WildcardActions) { - if a == action { - return true - } - } - +func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { for id, actions := range(policy.NodeActions) { if id != principal.ID { continue @@ -272,7 +186,7 @@ type ACLExt struct { Delegations NodeMap } -func (ext *ACLExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error { return nil } @@ -347,13 +261,28 @@ func NewACLPolicyExtContext() *ACLPolicyExtContext { return &ACLPolicyExtContext{ Types: map[PolicyType]PolicyInfo{ PerNodePolicyType: PolicyInfo{ - Load: LoadPerNodePolicy, + Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ + policy := NewPerNodePolicy(nodes) + return &policy, nil + }), }, ParentOfPolicyType: PolicyInfo{ - Load: LoadParentOfPolicy, + Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ + policy := NewParentOfPolicy(nodes) + return &policy, nil + }), }, ChildOfPolicyType: PolicyInfo{ - Load: LoadChildOfPolicy, + Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ + policy := NewChildOfPolicy(nodes) + return &policy, nil + }), + }, + RequirementOfPolicyType: PolicyInfo{ + Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ + policy := NewRequirementOfPolicy(nodes) + return &policy, nil + }), }, }, } @@ -376,7 +305,7 @@ func (ext *ACLPolicyExt) Serialize() ([]byte, error) { }, "", " ") } -func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error { return nil } diff --git a/signal.go b/signal.go index 2232280..033a81a 100644 --- a/signal.go +++ b/signal.go @@ -11,54 +11,48 @@ const ( Direct ) -// GraphSignals are passed around the event tree/resource DAG and cast by Type() -type GraphSignal interface { - // How to propogate the signal +type SignalType string + +type Signal interface { + Serializable[SignalType] Direction() SignalDirection - Type() string - String() string } -// BaseSignal is the most basic type of signal, it has no additional data type BaseSignal struct { - FDirection SignalDirection `json:"direction"` - FType string `json:"type"` + SignalDirection SignalDirection `json:"direction"` + SignalType SignalType `json:"type"` } -func (signal BaseSignal) String() string { - ser, err := json.Marshal(signal) - if err != nil { - return "STATE_SER_ERR" - } - return string(ser) +func (signal BaseSignal) Type() SignalType { + return signal.SignalType } func (signal BaseSignal) Direction() SignalDirection { - return signal.FDirection + return signal.SignalDirection } -func (signal BaseSignal) Type() string { - return signal.FType +func (signal BaseSignal) Serialize() ([]byte, error) { + return json.MarshalIndent(signal, "", " ") } -func NewBaseSignal(_type string, direction SignalDirection) BaseSignal { +func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { signal := BaseSignal{ - FDirection: direction, - FType: _type, + SignalDirection: direction, + SignalType: signal_type, } return signal } -func NewDownSignal(_type string) BaseSignal { - return NewBaseSignal(_type, Down) +func NewDownSignal(signal_type SignalType) BaseSignal { + return NewBaseSignal(signal_type, Down) } -func NewSignal(_type string) BaseSignal { - return NewBaseSignal(_type, Up) +func NewUpSignal(signal_type SignalType) BaseSignal { + return NewBaseSignal(signal_type, Up) } -func NewDirectSignal(_type string) BaseSignal { - return NewBaseSignal(_type, Direct) +func NewDirectSignal(signal_type SignalType) BaseSignal { + return NewBaseSignal(signal_type, Direct) } var AbortSignal = NewBaseSignal("abort", Down) @@ -77,9 +71,9 @@ func (signal IDSignal) String() string { return string(ser) } -func NewIDSignal(_type string, direction SignalDirection, id NodeID) IDSignal { +func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal { return IDSignal{ - BaseSignal: NewBaseSignal(_type, direction), + BaseSignal: NewBaseSignal(signal_type, direction), ID: id, } } diff --git a/thread.go b/thread.go index 98c59ed..e24a398 100644 --- a/thread.go +++ b/thread.go @@ -12,8 +12,8 @@ import ( type ThreadAction func(*Context, *Node, *ThreadExt)(string, error) type ThreadActions map[string]ThreadAction -type ThreadHandler func(*Context, *Node, *ThreadExt, GraphSignal)(string, error) -type ThreadHandlers map[string]ThreadHandler +type ThreadHandler func(*Context, *Node, *ThreadExt, Signal)(string, error) +type ThreadHandlers map[SignalType]ThreadHandler type InfoType string func (t InfoType) String() string { @@ -122,7 +122,7 @@ type ThreadExt struct { ThreadType ThreadType - SignalChan chan GraphSignal + SignalChan chan Signal TimeoutChan <-chan time.Time ChildWaits sync.WaitGroup @@ -191,7 +191,7 @@ func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children ma return &ThreadExt{ Actions: type_info.Actions, Handlers: type_info.Handlers, - SignalChan: make(chan GraphSignal, THREAD_BUFFER_SIZE), + SignalChan: make(chan Signal, THREAD_BUFFER_SIZE), TimeoutChan: timeout_chan, Active: false, State: state, @@ -276,7 +276,7 @@ func (ext *ThreadExt) ChildList() []*Node { } // Assumed that thread is already locked for signal -func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) error { context.Graph.Log.Logf("signal", "THREAD_PROCESS: %s", node.ID) var err error @@ -285,7 +285,7 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSig err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { if ext.Parent != nil { if ext.Parent.ID != node.ID { - return Signal(context, ext.Parent, node, signal) + return SendSignal(context, ext.Parent, node, signal) } } return nil @@ -293,7 +293,7 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSig case Down: err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { for _, info := range(ext.Children) { - err := Signal(context, info.Child, node, signal) + err := SendSignal(context, info.Child, node, signal) if err != nil { return err } @@ -535,7 +535,7 @@ func ThreadLoop(ctx * Context, thread *Node, first_action string) error { return nil } -func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { +func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) { ctx.Log.Logf("thread", "THREAD_CHILD_LINKED: %s - %+v", thread.ID, signal) context := NewWriteContext(ctx) err := UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { @@ -570,7 +570,7 @@ func ThreadChildLinked(ctx *Context, thread *Node, thread_ext *ThreadExt, signal // Helper function to start a child from a thread during a signal handler // Starts a write context, so cannot be called from either a write or read context -func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { +func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) { sig, ok := signal.(StartChildSignal) if ok == false { return "wait", nil @@ -638,7 +638,7 @@ func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, er } context = NewReadContext(ctx) - return "wait", Signal(context, thread, thread, NewStatusSignal("started", thread.ID)) + return "wait", SendSignal(context, thread, thread, NewStatusSignal("started", thread.ID)) } func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { @@ -685,9 +685,9 @@ func ThreadFinish(ctx *Context, thread *Node, thread_ext *ThreadExt) (string, er var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError -func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { +func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) { context := NewReadContext(ctx) - err := Signal(context, thread, thread, NewStatusSignal("aborted", thread.ID)) + err := SendSignal(context, thread, thread, NewStatusSignal("aborted", thread.ID)) if err != nil { return "", err } @@ -695,9 +695,9 @@ func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Grap } // Default thread action for "stop", sends a signal and returns no error -func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal GraphSignal) (string, error) { +func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) { context := NewReadContext(ctx) - err := Signal(context, thread, thread, NewStatusSignal("stopped", thread.ID)) + err := SendSignal(context, thread, thread, NewStatusSignal("stopped", thread.ID)) return "finish", err } diff --git a/user.go b/user.go index 930fbd9..0be303c 100644 --- a/user.go +++ b/user.go @@ -20,7 +20,7 @@ type ECDHExtJSON struct { Shared []byte `json:"shared"` } -func (ext *ECDHExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *ECDHExt) Process(context *StateContext, node *Node, signal Signal) error { return nil } @@ -115,6 +115,6 @@ func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { return NewGroupExt(members), nil } -func (ext *GroupExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *GroupExt) Process(context *StateContext, node *Node, signal Signal) error { return nil }