diff --git a/context.go b/context.go index 8a1cd96..0dc5423 100644 --- a/context.go +++ b/context.go @@ -103,11 +103,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Nodes: map[NodeID]*Node{}, } - err := ctx.RegisterExtension(ACLExtType, LoadACLExt, nil) - if err != nil { - return nil, err - } - + var err error err = ctx.RegisterExtension(ACLPolicyExtType, LoadACLPolicyExt, NewACLPolicyExtContext()) if err != nil { return nil, err @@ -150,7 +146,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) + err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLPolicyExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) if err != nil { return nil, err } diff --git a/gql.go b/gql.go index 004ef4f..7d00d26 100644 --- a/gql.go +++ b/gql.go @@ -1053,7 +1053,7 @@ var gql_actions ThreadActions = ThreadActions{ } context = NewReadContext(ctx) - err = thread.Process(context, thread, NewStatusSignal("server_started", thread.ID)) + err = thread.Process(context, thread.ID, NewStatusSignal("server_started", thread.ID)) if err != nil { return "", err } diff --git a/gql_mutation.go b/gql_mutation.go index 00e5165..d961604 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -35,7 +35,7 @@ var GQLMutationAbort = NewField(func()*graphql.Field { if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) } - return node.Process(context, ctx.User, AbortSignal) + return node.Process(context, ctx.User.ID, AbortSignal) }) if err != nil { return nil, err @@ -98,7 +98,7 @@ var GQLMutationStartChild = NewField(func()*graphql.Field{ } signal = NewStartChildSignal(child_id, action) - return parent.Process(context, ctx.User, signal) + return parent.Process(context, ctx.User.ID, signal) }) if err != nil { return nil, err diff --git a/gql_test.go b/gql_test.go index a5efc45..7820fb7 100644 --- a/gql_test.go +++ b/gql_test.go @@ -19,14 +19,13 @@ func TestGQLDB(t * testing.T) { ctx := logTestContext(t, []string{"test", "signal"}) TestUserNodeType := NodeType("TEST_USER") - err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType}) + err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLPolicyExtType}) fatalErr(t, err) u1 := NewNode(ctx, RandID(), TestUserNodeType) u1_policy := NewPerNodePolicy(NodeActions{ u1.ID: Actions{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"}, }) - u1.Extensions[ACLExtType] = NewACLExt(nil) u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ PerNodePolicyType: &u1_policy, }) @@ -37,14 +36,13 @@ func TestGQLDB(t * testing.T) { ctx.Log.Logf("test", "L1_ID: %s", l1.ID) TestThreadNodeType := NodeType("TEST_THREAD") - err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLExtType, ACLPolicyExtType, ThreadExtType, LockableExtType}) + err = ctx.RegisterNodeType(TestThreadNodeType, []ExtType{ACLPolicyExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) t1 := NewNode(ctx, RandID(), TestThreadNodeType) t1_policy := NewParentOfPolicy(NodeActions{ t1.ID: Actions{"signal.abort", "state.write"}, }) - t1.Extensions[ACLExtType] = NewACLExt(NodeList(u1)) t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ ParentOfPolicyType: &t1_policy, }) @@ -82,10 +80,10 @@ func TestGQLDB(t * testing.T) { fatalErr(t, err) context = NewReadContext(ctx) - err = gql.Process(context, gql, NewStatusSignal("child_linked", t1.ID)) + err = gql.Process(context, gql.ID, NewStatusSignal("child_linked", t1.ID)) fatalErr(t, err) context = NewReadContext(ctx) - err = gql.Process(context, gql, AbortSignal) + err = gql.Process(context, gql.ID, AbortSignal) fatalErr(t, err) err = ThreadLoop(ctx, gql, "start") @@ -121,7 +119,7 @@ func TestGQLDB(t * testing.T) { if err != nil { return err } - gql_loaded.Process(context, gql_loaded, StopSignal) + gql_loaded.Process(context, gql_loaded.ID, StopSignal) return err }) diff --git a/lockable.go b/lockable.go index 9141eb8..654b048 100644 --- a/lockable.go +++ b/lockable.go @@ -139,7 +139,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal owner_sent := false for _, dependency := range(ext.Dependencies) { context.Graph.Log.Logf("signal", "SENDING_TO_DEPENDENCY: %s -> %s", node.ID, dependency.ID) - dependency.Process(context, node, signal) + dependency.Process(context, node.ID, signal) if ext.Owner != nil { if dependency.ID == ext.Owner.ID { owner_sent = true @@ -149,7 +149,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal if ext.Owner != nil && owner_sent == false { if ext.Owner.ID != node.ID { context.Graph.Log.Logf("signal", "SENDING_TO_OWNER: %s -> %s", node.ID, ext.Owner.ID) - return ext.Owner.Process(context, node, signal) + return ext.Owner.Process(context, node.ID, signal) } } return nil @@ -157,7 +157,7 @@ func (ext *LockableExt) Process(context *StateContext, node *Node, signal Signal case Down: err = UseStates(context, node, NewACLInfo(node, []string{"requirements"}), func(context *StateContext) error { for _, requirement := range(ext.Requirements) { - err := requirement.Process(context, node, signal) + err := requirement.Process(context, node.ID, signal) if err != nil { return err } diff --git a/node.go b/node.go index b8b7dfe..2636f18 100644 --- a/node.go +++ b/node.go @@ -157,62 +157,34 @@ func NewNode(ctx *Context, id NodeID, node_type NodeType) *Node { return node } -func Allowed(context *StateContext, principal *Node, action string, node *Node) error { - context.Graph.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal.ID, node.ID, action) - if principal == nil { - context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s", principal.ID, node.ID, action) - return fmt.Errorf("nil is not allowed to perform any actions") - } - +func Allowed(context *StateContext, principal_id NodeID, action string, node *Node) error { + context.Graph.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal_id, node.ID, action) // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension - if principal.ID == node.ID { + if principal_id == node.ID { return nil } // Check if the node has a policy extension itself, and check against the policies in it policy_ext, err := GetExt[*ACLPolicyExt](node) - self_tried := false - if err == nil { - if policy_ext.Allows(context, principal, action, node) == true { - return nil - } - self_tried = true - } - - acl_ext, err := GetExt[*ACLExt](node) if err != nil { - if self_tried == true { - return fmt.Errorf("POLICY_SELF: policies on %s do not allow %s to perform %s", node.ID, principal.ID, action) - } else { - return err - } + return err } - for _, policy_node := range(acl_ext.Delegations) { - context.Graph.Log.Logf("policy", "POLICY_DELEGATION_CHECK: %s->%s", node.ID, policy_node.ID) - policy_ext, err := GetExt[*ACLPolicyExt](policy_node) - if err != nil { - return err - } - if policy_ext.Allows(context, principal, action, node) == true { - context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action) - return nil - } + if policy_ext.Allows(context, principal_id, action, node) == true { + return nil } - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s", principal.ID, node.ID, action) - return fmt.Errorf("%s is not allowed to perform %s on %s", principal.ID, action, node.ID) + + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s", principal_id, node.ID, action) + return fmt.Errorf("%s is not allowed to perform %s on %s", principal_id, action, node.ID) } // Check that princ is allowed to signal this action, // then send the signal to all the extensions of the node -func (node *Node) Process(context *StateContext, princ *Node, signal Signal) error { +func (node *Node) Process(context *StateContext, princ_id NodeID, signal Signal) error { ser, _ := signal.Serialize() context.Graph.Log.Logf("signal", "SIGNAL: %s - %s", node.ID, string(ser)) - err := UseStates(context, princ, NewACLInfo(node, []string{}), func(context *StateContext) error { - return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) - }) - + err := Allowed(context, princ_id, fmt.Sprintf("signal.%s", signal.Type()), node) if err != nil { return err } @@ -676,7 +648,7 @@ func UseStates(context *StateContext, principal *Node, new_nodes ACLMap, state_f } if already_granted == false { - err := Allowed(context, principal, fmt.Sprintf("%s.read", resource), node) + err := Allowed(context, principal.ID, fmt.Sprintf("%s.read", resource), node) if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "RUNLOCKING_ON_ERROR %s", id.String()) @@ -772,7 +744,7 @@ func UpdateStates(context *StateContext, principal *Node, new_nodes ACLMap, stat } if already_granted == false { - err := Allowed(context, principal, fmt.Sprintf("%s.write", resource), node) + err := Allowed(context, principal.ID, fmt.Sprintf("%s.write", resource), node) if err != nil { for _, n := range(new_locks) { context.Graph.Log.Logf("mutex", "UNLOCKING_ON_ERROR %s", id.String()) diff --git a/node_test.go b/node_test.go index 82f53a6..588e047 100644 --- a/node_test.go +++ b/node_test.go @@ -7,12 +7,11 @@ import ( func TestNodeDB(t *testing.T) { ctx := logTestContext(t, []string{"test", "db", "node", "policy"}) node_type := NodeType("test") - err := ctx.RegisterNodeType(node_type, []ExtType{"ACL"}) + err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) + node := NewNode(ctx, RandID(), node_type) - node.Extensions[ACLExtType] = &ACLExt{ - Delegations: NodeMap{}, - } + node.Extensions[GroupExtType] = NewGroupExt(nil) context := NewWriteContext(ctx) err = UpdateStates(context, node, NewACLInfo(node, []string{"test"}), func(context *StateContext) error { diff --git a/policy.go b/policy.go index e021f2b..4b3e179 100644 --- a/policy.go +++ b/policy.go @@ -7,7 +7,38 @@ import ( type Policy interface { Serializable[PolicyType] - Allows(context *StateContext, principal *Node, action string, node *Node) bool + Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool +} + +//TODO +func (policy *AllNodesPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { + return policy.Actions.Allows(action) +} + +func (policy *RequirementOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { + return false +} + +func (policy *PerNodePolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { + for id, actions := range(policy.NodeActions) { + if id != principal_id { + continue + } + for _, a := range(actions) { + if a == action { + return true + } + } + } + return false +} + +func (policy *ParentOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { + return false +} + +func (policy *ChildOfPolicy) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { + return false } const RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") @@ -24,27 +55,6 @@ func NewRequirementOfPolicy(nodes NodeActions) RequirementOfPolicy { } } -// Check if any of principals dependencies are in the policy -func (policy *RequirementOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - lockable_ext, err := GetExt[*LockableExt](principal) - - if err != nil { - return false - } - - for dep_id, _ := range(lockable_ext.Dependencies) { - for node_id, actions := range(policy.NodeActions) { - if node_id == dep_id { - if actions.Allows(action) == true { - return true - } - break - } - } - } - return false -} - const ChildOfPolicyType = PolicyType("CHILD_OF") type ChildOfPolicy struct { PerNodePolicy @@ -53,29 +63,6 @@ func (policy *ChildOfPolicy) Type() PolicyType { return ChildOfPolicyType } -func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy) - thread_ext, err := GetExt[*ThreadExt](principal) - if err != nil { - return false - } - - parent := thread_ext.Parent - if parent != nil { - actions, exists := policy.NodeActions[parent.ID] - if exists == false { - return false - } - for _, a := range(actions) { - if a == action { - return true - } - } - } - - return false -} - type Actions []string func (actions Actions) Allows(action string) bool { @@ -130,29 +117,6 @@ func (policy *ParentOfPolicy) Type() PolicyType { return ParentOfPolicyType } -func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy) - for id, actions := range(policy.NodeActions) { - thread_ext, err := GetExt[*ThreadExt](context.Graph.Nodes[id]) - if err != nil { - continue - } - - context.Graph.Log.Logf("policy", "PARENT_OF_PARENT: %s %+v", id, thread_ext.Parent) - if thread_ext.Parent != nil { - if thread_ext.Parent.ID == principal.ID { - for _, a := range(actions) { - if a == action { - return true - } - } - } - } - } - - return false -} - func NewParentOfPolicy(node_actions NodeActions) ParentOfPolicy { return ParentOfPolicy{ PerNodePolicy: NewPerNodePolicy(node_actions), @@ -193,20 +157,6 @@ func (policy *PerNodePolicy) Serialize() ([]byte, error) { }, "", " ") } -func (policy *PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - for id, actions := range(policy.NodeActions) { - if id != principal.ID { - continue - } - for _, a := range(actions) { - if a == action { - return true - } - } - } - return false -} - func NewAllNodesPolicy(actions Actions) AllNodesPolicy { if actions == nil { actions = Actions{} @@ -230,44 +180,11 @@ func (policy *AllNodesPolicy) Serialize() ([]byte, error) { return json.MarshalIndent(policy, "", " ") } -func (policy *AllNodesPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - return policy.Actions.Allows(action) -} - - // Extension to allow a node to hold ACL policies type ACLPolicyExt struct { Policies map[PolicyType]Policy } -// The ACL extension stores a map of nodes to delegate ACL to, and a list of policies -type ACLExt struct { - Delegations NodeMap -} - -func (ext *ACLExt) Process(context *StateContext, node *Node, signal Signal) error { - return nil -} - -func LoadACLExt(ctx *Context, data []byte) (Extension, error) { - var j struct { - Delegations []string `json:"delegation"` - } - - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - delegations, err := RestoreNodeList(ctx, j.Delegations) - if err != nil { - return nil, err - } - - return &ACLExt{ - Delegations: delegations, - }, nil -} func NodeList(nodes ...*Node) NodeMap { m := NodeMap{} @@ -277,36 +194,6 @@ func NodeList(nodes ...*Node) NodeMap { return m } -func NewACLExt(delegations NodeMap) *ACLExt { - if delegations == nil { - delegations = NodeMap{} - } - - return &ACLExt{ - Delegations: delegations, - } -} - -func (ext *ACLExt) Serialize() ([]byte, error) { - delegations := make([]string, len(ext.Delegations)) - i := 0 - for id, _ := range(ext.Delegations) { - delegations[i] = id.String() - i += 1 - } - - return json.MarshalIndent(&struct{ - Delegations []string `json:"delegations"` - }{ - Delegations: delegations, - }, "", " ") -} - -const ACLExtType = ExtType("ACL") -func (ext *ACLExt) Type() ExtType { - return ACLExtType -} - type PolicyLoadFunc func(*Context, []byte) (Policy, error) type PolicyInfo struct { Load PolicyLoadFunc @@ -427,11 +314,11 @@ func (ext *ACLPolicyExt) Type() ExtType { } // Check if the extension allows the principal to perform action on node -func (ext *ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool { +func (ext *ACLPolicyExt) Allows(context *StateContext, principal_id NodeID, action string, node *Node) bool { context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) for _, policy := range(ext.Policies) { context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy) - if policy.Allows(context, principal, action, node) == true { + if policy.Allows(context, principal_id, action, node) == true { return true } } diff --git a/thread.go b/thread.go index 1643054..0a00c82 100644 --- a/thread.go +++ b/thread.go @@ -306,7 +306,7 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { if ext.Parent != nil { if ext.Parent.ID != node.ID { - return ext.Parent.Process(context, node, signal) + return ext.Parent.Process(context, node.ID, signal) } } return nil @@ -314,7 +314,7 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal Signal) case Down: err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { for _, info := range(ext.Children) { - err := info.Child.Process(context, node, signal) + err := info.Child.Process(context, node.ID, signal) if err != nil { return err } @@ -659,7 +659,7 @@ func ThreadStart(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, er } context = NewReadContext(ctx) - return "wait", thread.Process(context, thread, NewStatusSignal("started", thread.ID)) + return "wait", thread.Process(context, thread.ID, NewStatusSignal("started", thread.ID)) } func ThreadWait(ctx * Context, thread *Node, thread_ext *ThreadExt) (string, error) { @@ -708,7 +708,7 @@ var ThreadAbortedError = errors.New("Thread aborted by signal") // Default thread action function for "abort", sends a signal and returns a ThreadAbortedError func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) { context := NewReadContext(ctx) - err := thread.Process(context, thread, NewStatusSignal("aborted", thread.ID)) + err := thread.Process(context, thread.ID, NewStatusSignal("aborted", thread.ID)) if err != nil { return "", err } @@ -718,7 +718,7 @@ func ThreadAbort(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Sign // Default thread action for "stop", sends a signal and returns no error func ThreadStop(ctx * Context, thread *Node, thread_ext *ThreadExt, signal Signal) (string, error) { context := NewReadContext(ctx) - err := thread.Process(context, thread, NewStatusSignal("stopped", thread.ID)) + err := thread.Process(context, thread.ID, NewStatusSignal("stopped", thread.ID)) return "finish", err }