diff --git a/context.go b/context.go index 0dc5423..6ae2000 100644 --- a/context.go +++ b/context.go @@ -104,7 +104,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } var err error - err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExt, NewACLPolicyExtContext()) + err = ctx.RegisterExtension(ACLExtType, LoadACLExt, NewACLExtContext()) if err != nil { return nil, err } @@ -146,7 +146,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) + err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) if err != nil { return nil, err } diff --git a/gql_test.go b/gql_test.go index 7820fb7..ad60782 100644 --- a/gql_test.go +++ b/gql_test.go @@ -19,14 +19,14 @@ func TestGQLDB(t * testing.T) { ctx := logTestContext(t, []string{"test", "signal"}) TestUserNodeType := NodeType("TEST_USER") - err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLPolicyExtType}) + err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType}) fatalErr(t, err) u1 := NewNode(ctx, RandID(), TestUserNodeType) u1_policy := NewPerNodePolicy(NodeActions{ u1.ID: Actions{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"}, }) - u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + u1.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ PerNodePolicyType: &u1_policy, }) @@ -36,14 +36,12 @@ func TestGQLDB(t * testing.T) { ctx.Log.Logf("test", "L1_ID: %s", l1.ID) TestThreadNodeType := NodeType("TEST_THREAD") - err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLPolicyExtType, ThreadExtType, LockableExtType}) + err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) t1 := NewNode(ctx, RandID(), TestThreadNodeType) - t1_policy := NewParentOfPolicy(NodeActions{ - t1.ID: Actions{"signal.abort", "state.write"}, - }) - t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + t1_policy := NewParentOfPolicy(Actions{"signal.abort", "state.write"}) + t1.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ ParentOfPolicyType: &t1_policy, }) t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil) @@ -57,10 +55,8 @@ func TestGQLDB(t * testing.T) { gql, err := NewGQLNode(ctx, NewGQLExt(":0", ecdh.P256(), key, nil, nil)) fatalErr(t, err) - gql_policy := NewChildOfPolicy(NodeActions{ - gql.ID: Actions{"signal.status"}, - }) - gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + gql_policy := NewChildOfPolicy(Actions{"signal.status"}) + gql.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ ChildOfPolicyType: &gql_policy, }) diff --git a/graph_test.go b/graph_test.go index 83c01f6..fec1f17 100644 --- a/graph_test.go +++ b/graph_test.go @@ -59,7 +59,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { policy := NewAllNodesPolicy([]string{"signal.status", "requirements.write", "requirements.read", "dependencies.write", "dependencies.read", "owner.read", "owner.write"}) listener_extension := NewListenerExt(buffer) listener.Extensions[ListenerExtType] = listener_extension - listener.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + listener.Extensions[ACLExtType] = NewACLExt(map[PolicyType]Policy{ AllNodesPolicyType: &policy, }) listener.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) @@ -76,7 +76,7 @@ func logTestContext(t * testing.T, components []string) *Context { ctx, err := NewContext(db, NewConsoleLogger(components)) fatalErr(t, err) - err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ACLPolicyExtType, ListenerExtType, LockableExtType}) + err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) fatalErr(t, err) return ctx diff --git a/node.go b/node.go index 2636f18..b8f9927 100644 --- a/node.go +++ b/node.go @@ -165,17 +165,12 @@ func Allowed(context *StateContext, principal_id NodeID, action string, node *No } // Check if the node has a policy extension itself, and check against the policies in it - policy_ext, err := GetExt[*ACLPolicyExt](node) + policy_ext, err := GetExt[*ACLExt](node) if err != nil { return err } - if policy_ext.Allows(context, principal_id, action, node) == true { - return nil - } - - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s", principal_id, node.ID, action) - return fmt.Errorf("%s is not allowed to perform %s on %s", principal_id, action, node.ID) + return policy_ext.Allows(context, principal_id, action, node) } // Check that princ is allowed to signal this action, diff --git a/policy.go b/policy.go index 4b3e179..6db7de1 100644 --- a/policy.go +++ b/policy.go @@ -7,57 +7,90 @@ import ( type Policy interface { Serializable[PolicyType] - Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool + Allows(context *StateContext, principal_id NodeID, action string, node *Node) error } -//TODO -func (policy *AllNodesPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { +//TODO: Update with change from principal *Node to principal_id so sane policies can still be made +func (policy *AllNodesPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { return policy.Actions.Allows(action) } -func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { - return false -} - -func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { +func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { for id, actions := range(policy.NodeActions) { if id != principal_id { continue } for _, a := range(actions) { if a == action { - return true + return nil } } } - return false + return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) } -func (policy *ParentOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { - return false +func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { + lockable_ext, err := GetExt[*LockableExt](node) + if err != nil { + return err + } + + for id, _ := range(lockable_ext.Requirements) { + if id == principal_id { + return policy.Actions.Allows(action) + } + } + + return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) } -func (policy *ChildOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { - return false +func (policy *ParentOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { + thread_ext, err := GetExt[*ThreadExt](node) + if err != nil { + return err + } + + if thread_ext.Parent != nil { + if thread_ext.Parent.ID == principal_id { + return policy.Actions.Allows(action) + } + } + + return fmt.Errorf("%s is not a parent of %s", principal_id, node.ID) +} + +func (policy *ChildOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { + thread_ext, err := GetExt[*ThreadExt](node) + if err != nil { + return err + } + + for id, _ := range(thread_ext.Children) { + if id == principal_id { + return policy.Actions.Allows(action) + } + } + + return fmt.Errorf("%s is not a child of %s", principal_id, node.ID) } const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") type RequirementOfPolicy struct { - PerNodePolicy + AllNodesPolicy } func (policy *RequirementOfPolicy) Type() PolicyType { return RequirementOfPolicyType } -func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy { +func NewRequirementOfPolicy(actions Actions) RequirementOfPolicy { return RequirementOfPolicy{ - PerNodePolicy: NewPerNodePolicy(nodes), + AllNodesPolicy: NewAllNodesPolicy(actions), } } const ChildOfPolicyType = PolicyType("CHILD_OF") type ChildOfPolicy struct { - PerNodePolicy + AllNodesPolicy } func (policy *ChildOfPolicy) Type() PolicyType { return ChildOfPolicyType @@ -65,17 +98,34 @@ func (policy *ChildOfPolicy) Type() PolicyType { type Actions []string -func (actions Actions) Allows(action string) bool { +func (actions Actions) Allows(action string) error { for _, a := range(actions) { if a == action { - return true + return nil } } - return false + return fmt.Errorf("%s not in allows list", action) } type NodeActions map[NodeID]Actions +type AllNodesPolicyJSON struct { + Actions Actions `json:"actions"` +} + +func AllNodesPolicyLoad(init_fn func(Actions)(Policy, error)) func(*Context, []byte)(Policy, error) { + return func(ctx *Context, data []byte)(Policy, error){ + var j AllNodesPolicyJSON + err := json.Unmarshal(data, &j) + + if err != nil { + return nil, err + } + + return init_fn(j.Actions) + } +} + func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) { return func(ctx *Context, data []byte)(Policy, error){ var j PerNodePolicyJSON @@ -103,23 +153,23 @@ func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, } } -func NewChildOfPolicy(node_actions NodeActions) ChildOfPolicy { +func NewChildOfPolicy(actions Actions) ChildOfPolicy { return ChildOfPolicy{ - PerNodePolicy: NewPerNodePolicy(node_actions), + AllNodesPolicy: NewAllNodesPolicy(actions), } } const ParentOfPolicyType = PolicyType("PARENT_OF") type ParentOfPolicy struct { - PerNodePolicy + AllNodesPolicy } func (policy *ParentOfPolicy) Type() PolicyType { return ParentOfPolicyType } -func NewParentOfPolicy(node_actions NodeActions) ParentOfPolicy { +func NewParentOfPolicy(actions Actions) ParentOfPolicy { return ParentOfPolicy{ - PerNodePolicy: NewPerNodePolicy(node_actions), + AllNodesPolicy: NewAllNodesPolicy(actions), } } @@ -168,7 +218,7 @@ func NewAllNodesPolicy(actions Actions) AllNodesPolicy { } type AllNodesPolicy struct { - Actions Actions `json:"actions"` + Actions Actions } const AllNodesPolicyType = PolicyType("ALL_NODES") @@ -181,7 +231,7 @@ func (policy *AllNodesPolicy) Serialize() ([]byte, error) { } // Extension to allow a node to hold ACL policies -type ACLPolicyExt struct { +type ACLExt struct { Policies map[PolicyType]Policy } @@ -199,13 +249,19 @@ type PolicyInfo struct { Load PolicyLoadFunc } -type ACLPolicyExtContext struct { +type ACLExtContext struct { Types map[PolicyType]PolicyInfo } -func NewACLPolicyExtContext() *ACLPolicyExtContext { - return &ACLPolicyExtContext{ +func NewACLExtContext() *ACLExtContext { + return &ACLExtContext{ Types: map[PolicyType]PolicyInfo{ + AllNodesPolicyType: PolicyInfo{ + Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ + policy := NewAllNodesPolicy(actions) + return &policy, nil + }), + }, PerNodePolicyType: PolicyInfo{ Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ policy := NewPerNodePolicy(nodes) @@ -213,38 +269,28 @@ func NewACLPolicyExtContext() *ACLPolicyExtContext { }), }, ParentOfPolicyType: PolicyInfo{ - Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ - policy := NewParentOfPolicy(nodes) + Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ + policy := NewParentOfPolicy(actions) return &policy, nil }), }, ChildOfPolicyType: PolicyInfo{ - Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ - policy := NewChildOfPolicy(nodes) + Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ + policy := NewChildOfPolicy(actions) return &policy, nil }), }, RequirementOfPolicyType: PolicyInfo{ - Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ - policy := NewRequirementOfPolicy(nodes) + Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ + policy := NewRequirementOfPolicy(actions) return &policy, nil }), }, - AllNodesPolicyType: PolicyInfo{ - Load: func(ctx *Context, data []byte) (Policy, error) { - var policy AllNodesPolicy - err := json.Unmarshal(data, &policy) - if err != nil { - return nil, err - } - return &policy, nil - }, - }, }, } } -func (ext *ACLPolicyExt) Serialize() ([]byte, error) { +func (ext *ACLExt) Serialize() ([]byte, error) { policies := map[string][]byte{} for name, policy := range(ext.Policies) { ser, err := policy.Serialize() @@ -261,11 +307,11 @@ func (ext *ACLPolicyExt) Serialize() ([]byte, error) { }, "", " ") } -func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal Signal) error { +func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error { return nil } -func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt { +func NewACLExt(policies map[PolicyType]Policy) *ACLExt { if policies == nil { policies = map[PolicyType]Policy{} } @@ -276,12 +322,12 @@ func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt { } } - return &ACLPolicyExt{ + return &ACLExt{ Policies: policies, } } -func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) { +func LoadACLExt(ctx *Context, data []byte) (Extension, error) { var j struct { Policies map[string][]byte `json:"policies"` } @@ -291,7 +337,7 @@ func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) { } policies := map[PolicyType]Policy{} - acl_ctx := ctx.ExtByType(ACLPolicyExtType).Data.(*ACLPolicyExtContext) + acl_ctx := ctx.ExtByType(ACLExtType).Data.(*ACLExtContext) for name, ser := range(j.Policies) { policy_def, exists := acl_ctx.Types[PolicyType(name)] if exists == false { @@ -305,22 +351,24 @@ func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) { policies[PolicyType(name)] = policy } - return NewACLPolicyExt(policies), nil + return NewACLExt(policies), nil } -const ACLPolicyExtType = ExtType("ACL_POLICIES") -func (ext *ACLPolicyExt) Type() ExtType { - return ACLPolicyExtType +const ACLExtType = ExtType("ACL_POLICIES") +func (ext *ACLExt) Type() ExtType { + return ACLExtType } // Check if the extension allows the principal to perform action on node -func (ext *ACLPolicyExt) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { +func (ext *ACLExt) Allows(context *StateContext, principal_id NodeID, action string, node *Node) error { context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) + errs := []error{} for _, policy := range(ext.Policies) { - context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy) - if policy.Allows(context, principal_id, action, node) == true { - return true + err := policy.Allows(context, principal_id, action, node) + if err == nil { + return nil } + errs = append(errs, err) } - return false + return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) }