diff --git a/gql_test.go b/gql_test.go index 9e514ee..54b867a 100644 --- a/gql_test.go +++ b/gql_test.go @@ -11,14 +11,30 @@ import ( ) func TestGQLDBLoad(t * testing.T) { - ctx := logTestContext(t, []string{"test", "db"}) + ctx := logTestContext(t, []string{"test", "db", "policy", "signal"}) + + TestUserNodeType := NodeType("TEST_USER") + err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType, ACLPolicyExtType}) + fatalErr(t, err) + + u1 := NewNode(RandID(), TestUserNodeType) + ctx.Nodes[u1.ID] = &u1 + u1.Extensions[ACLExtType] = NewACLExt(nil) + u1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + PerNodePolicyType: NewPerNodePolicy(map[NodeID][]string{ + u1.ID: []string{"users.write", "children.write", "parent.write", "dependencies.write", "requirements.write"}, + }, nil), + }) + + ctx.Log.Logf("test", "U1_ID: %s", u1.ID) ListenerNodeType := NodeType("LISTENER") - err := ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) + err = ctx.RegisterNodeType(ListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) fatalErr(t, err) l1 := NewNode(RandID(), ListenerNodeType) - l1.Extensions[ACLExtType] = NewACLExt(nil) + ctx.Nodes[l1.ID] = &l1 + l1.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) listener_ext := NewListenerExt(10) l1.Extensions[ListenerExtType] = listener_ext l1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) @@ -30,22 +46,19 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) t1 := NewNode(RandID(), TestThreadNodeType) - t1.Extensions[ACLExtType] = NewACLExt(nil) + ctx.Nodes[t1.ID] = &t1 + t1.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) + t1.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + ParentOfPolicyType: NewParentOfPolicy(map[NodeID][]string{ + t1.ID: []string{"signal.abort", "state.write"}, + }), + }) t1.Extensions[ThreadExtType], err = NewThreadExt(ctx, BaseThreadType, nil, nil, "init", nil) fatalErr(t, err) t1.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) ctx.Log.Logf("test", "T1_ID: %s", t1.ID) - TestUserNodeType := NodeType("TEST_USER") - err = ctx.RegisterNodeType(TestUserNodeType, []ExtType{ACLExtType}) - fatalErr(t, err) - - u1 := NewNode(RandID(), TestUserNodeType) - u1.Extensions[ACLExtType] = NewACLExt(nil) - - ctx.Log.Logf("test", "U1_ID: %s", u1.ID) - TestGQLNodeType := NodeType("TEST_GQL") err = ctx.RegisterNodeType(TestGQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType, ThreadExtType, LockableExtType}) fatalErr(t, err) @@ -54,7 +67,13 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) gql := NewNode(RandID(), TestGQLNodeType) - gql.Extensions[ACLExtType] = NewACLExt(nil) + ctx.Nodes[gql.ID] = &gql + gql.Extensions[ACLExtType] = NewACLExt(NodeList(&u1)) + gql.Extensions[ACLPolicyExtType] = NewACLPolicyExt(map[PolicyType]Policy{ + ChildOfPolicyType: NewChildOfPolicy(map[NodeID][]string{ + gql.ID: []string{"signal.status"}, + }), + }) gql.Extensions[GroupExtType] = NewGroupExt(nil) gql.Extensions[GQLExtType] = NewGQLExt(":0", ecdh.P256(), key, nil, nil) gql.Extensions[ThreadExtType], err = NewThreadExt(ctx, GQLThreadType, nil, nil, "ini", nil) @@ -62,17 +81,16 @@ func TestGQLDBLoad(t * testing.T) { gql.Extensions[LockableExtType] = NewLockableExt(nil, nil, nil, nil) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - info := ParentInfo{true, "start", "restore"} context := NewWriteContext(ctx) - err = UpdateStates(context, &gql, NewACLInfo(&gql, []string{"users"}), func(context *StateContext) error { - err := LinkThreads(context, &gql, &gql, ChildInfo{&t1, map[InfoType]Info{ + err = UpdateStates(context, &u1, NewACLInfo(&gql, []string{"users"}), func(context *StateContext) error { + err := LinkThreads(context, &u1, &gql, ChildInfo{&t1, map[InfoType]Info{ ParentInfoType: &info, }}) if err != nil { return err } - return LinkLockables(context, &gql, &l1, []*Node{&gql}) + return LinkLockables(context, &u1, &l1, []*Node{&gql}) }) fatalErr(t, err) diff --git a/node.go b/node.go index d47b72f..4c77f1f 100644 --- a/node.go +++ b/node.go @@ -150,6 +150,7 @@ func NewNode(id NodeID, node_type NodeType) Node { } func Allowed(context *StateContext, principal *Node, action string, node *Node) error { + context.Graph.Log.Logf("policy", "POLICY_CHECK: %s %s.%s", principal.ID, node.ID, action) if principal == nil { context.Graph.Log.Logf("policy", "POLICY_CHECK_ERR: %s %s.%s", principal.ID, node.ID, action) return fmt.Errorf("nil is not allowed to perform any actions") @@ -160,24 +161,31 @@ func Allowed(context *StateContext, principal *Node, action string, node *Node) return nil } + // Check if the node has a policy extension itself, and check against the policies in it + policy_ext, err := GetExt[*ACLPolicyExt](node) + if err == nil { + if policy_ext.Allows(context, principal, action, node) == true { + return nil + } + } + acl_ext, err := GetExt[*ACLExt](node) if err != nil { return err } for _, policy_node := range(acl_ext.Delegations) { - ext, exists := policy_node.Extensions[ACLPolicyExtType] - if exists == false { - context.Graph.Log.Logf("policy", "WARNING: %s has dependency %s which doesn't have ACLPolicyExt") - continue + context.Graph.Log.Logf("policy", "POLICY_DELEGATION_CHECK: %s->%s", node.ID, policy_node.ID) + policy_ext, err := GetExt[*ACLPolicyExt](policy_node) + if err != nil { + return err } - policy_ext := ext.(ACLPolicyExt) if policy_ext.Allows(context, principal, action, node) == true { context.Graph.Log.Logf("policy", "POLICY_CHECK_PASS: %s %s.%s", principal.ID, node.ID, action) return nil } } - context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s.%s", principal.ID, node.ID, action) + context.Graph.Log.Logf("policy", "POLICY_CHECK_FAIL: %s %s.%s", principal.ID, node.ID, action) return fmt.Errorf("%s is not allowed to perform %s on %s", principal.ID, action, node.ID) } @@ -190,10 +198,14 @@ func Signal(context *StateContext, node *Node, princ *Node, signal GraphSignal) return Allowed(context, princ, fmt.Sprintf("signal.%s", signal.Type()), node) }) + if err != nil { + return err + } + for _, ext := range(node.Extensions) { err = ext.Process(context, node, signal) if err != nil { - return nil + return err } } @@ -324,7 +336,7 @@ func WriteNodes(context *StateContext) error { for id, _ := range(context.Locked) { node, _ := context.Graph.Nodes[id] if node == nil { - return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node, maybe node isn't in the context") + return fmt.Errorf("DB_SERIALIZE_ERROR: cannot serialize nil node(%s), maybe node isn't in the context", id) } ser, err := node.Serialize() diff --git a/policy.go b/policy.go index 9bc6560..9bb9547 100644 --- a/policy.go +++ b/policy.go @@ -6,31 +6,262 @@ import ( ) type Policy interface { + Type() PolicyType Serialize() ([]byte, error) Allows(context *StateContext, principal *Node, action string, node *Node) bool } -func LoadAllNodesPolicy(ctx *Context, data []byte) (Policy, error) { - var policy AllNodesPolicy - err := json.Unmarshal(data, &policy) +const ChildOfPolicyType = PolicyType("CHILD_OF") +type ChildOfPolicy struct { + NodeActions map[NodeID][]string +} + +func (policy *ChildOfPolicy) Type() PolicyType { + return ChildOfPolicyType +} + +func (policy *ChildOfPolicy) Serialize() ([]byte, error) { + node_actions := map[string][]string{} + for id, actions := range(policy.NodeActions) { + node_actions[id.String()] = actions + } + return json.MarshalIndent(&ChildOfPolicyJSON{ + NodeActions: node_actions, + }, "", " ") +} + +func (policy *ChildOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { + context.Graph.Log.Logf("policy", "CHILD_OF_POLICY: %+v", policy) + thread_ext, err := GetExt[*ThreadExt](principal) + if err != nil { + return false + } + + parent := thread_ext.Parent + if parent != nil { + actions, exists := policy.NodeActions[parent.ID] + if exists == false { + return false + } + for _, a := range(actions) { + if a == action { + return true + } + } + } + + return false +} + +type ChildOfPolicyJSON struct { + NodeActions map[string][]string `json:"node_actions"` +} + +func LoadChildOfPolicy(ctx *Context, data []byte) (Policy, error) { + var j ChildOfPolicyJSON + err := json.Unmarshal(data, &j) if err != nil { - return policy, err + return nil, err + } + + node_actions := map[NodeID][]string{} + for id_str, actions := range(j.NodeActions) { + id, err := ParseID(id_str) + if err != nil { + return nil, err + } + + _, err = LoadNode(ctx, id) + if err != nil { + return nil, err + } + + node_actions[id] = actions + } + + return NewChildOfPolicy(node_actions), nil +} + +func NewChildOfPolicy(node_actions map[NodeID][]string) *ChildOfPolicy { + if node_actions == nil { + node_actions = map[NodeID][]string{} + } + + return &ChildOfPolicy{ + NodeActions: node_actions, + } +} + +const ParentOfPolicyType = PolicyType("PARENT_OF") +type ParentOfPolicy struct { + NodeActions map[NodeID][]string +} + +func (policy *ParentOfPolicy) Type() PolicyType { + return ParentOfPolicyType +} + +func (policy *ParentOfPolicy) Serialize() ([]byte, error) { + node_actions := map[string][]string{} + for id, actions := range(policy.NodeActions) { + node_actions[id.String()] = actions } - return policy, nil + return json.MarshalIndent(&ParentOfPolicyJSON{ + NodeActions: node_actions, + }, "", " ") } -type AllNodesPolicy struct { - Actions []string `json:"actions"` +func (policy *ParentOfPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { + context.Graph.Log.Logf("policy", "PARENT_OF_POLICY: %+v", policy) + for id, actions := range(policy.NodeActions) { + thread_ext, err := GetExt[*ThreadExt](context.Graph.Nodes[id]) + if err != nil { + continue + } + + context.Graph.Log.Logf("policy", "PARENT_OF_PARENT: %s %+v", id, thread_ext.Parent) + if thread_ext.Parent != nil { + if thread_ext.Parent.ID == principal.ID { + for _, a := range(actions) { + if a == action { + return true + } + } + } + } + } + + return false } -func (policy AllNodesPolicy) Type() PolicyType { - return PolicyType("simple_policy") +type ParentOfPolicyJSON struct { + NodeActions map[string][]string `json:"node_actions"` } -func (policy AllNodesPolicy) Serialize() ([]byte, error) { - return json.MarshalIndent(&policy, "", " ") +func LoadParentOfPolicy(ctx *Context, data []byte) (Policy, error) { + var j ParentOfPolicyJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + node_actions := map[NodeID][]string{} + for id_str, actions := range(j.NodeActions) { + id, err := ParseID(id_str) + if err != nil { + return nil, err + } + + _, err = LoadNode(ctx, id) + if err != nil { + return nil, err + } + + node_actions[id] = actions + } + + return NewParentOfPolicy(node_actions), nil } +func NewParentOfPolicy(node_actions map[NodeID][]string) *ParentOfPolicy { + if node_actions == nil { + node_actions = map[NodeID][]string{} + } + + return &ParentOfPolicy{ + NodeActions: node_actions, + } +} + +func LoadPerNodePolicy(ctx *Context, data []byte) (Policy, error) { + var j PerNodePolicyJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + node_actions := map[NodeID][]string{} + for id_str, actions := range(j.NodeActions) { + id, err := ParseID(id_str) + if err != nil { + return nil, err + } + + _, err = LoadNode(ctx, id) + if err != nil { + return nil, err + } + + node_actions[id] = actions + } + + + return NewPerNodePolicy(node_actions, j.WildcardActions), nil +} + +func NewPerNodePolicy(node_actions map[NodeID][]string, wildcard_actions []string) *PerNodePolicy { + if node_actions == nil { + node_actions = map[NodeID][]string{} + } + + if wildcard_actions == nil { + wildcard_actions = []string{} + } + + return &PerNodePolicy{ + NodeActions: node_actions, + WildcardActions: wildcard_actions, + } +} + +type PerNodePolicy struct { + NodeActions map[NodeID][]string + WildcardActions []string +} + +type PerNodePolicyJSON struct { + NodeActions map[string][]string `json:"node_actions"` + WildcardActions []string `json:"wildcard_actions"` +} + +const PerNodePolicyType = PolicyType("PER_NODE") +func (policy PerNodePolicy) Type() PolicyType { + return PerNodePolicyType +} + +func (policy PerNodePolicy) Serialize() ([]byte, error) { + node_actions := map[string][]string{} + for id, actions := range(policy.NodeActions) { + node_actions[id.String()] = actions + } + + return json.MarshalIndent(&PerNodePolicyJSON{ + NodeActions: node_actions, + WildcardActions: policy.WildcardActions, + }, "", " ") +} + +func (policy PerNodePolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { + for _, a := range(policy.WildcardActions) { + if a == action { + return true + } + } + + for id, actions := range(policy.NodeActions) { + if id != principal.ID { + continue + } + for _, a := range(actions) { + if a == action { + return true + } + } + } + return false +} + + // Extension to allow a node to hold ACL policies type ACLPolicyExt struct { Policies map[PolicyType]Policy @@ -65,6 +296,14 @@ func LoadACLExt(ctx *Context, data []byte) (Extension, error) { }, nil } +func NodeList(nodes ...*Node) NodeMap { + m := NodeMap{} + for _, node := range(nodes) { + m[node.ID] = node + } + return m +} + func NewACLExt(delegations NodeMap) *ACLExt { if delegations == nil { delegations = NodeMap{} @@ -98,7 +337,6 @@ func (ext *ACLExt) Type() ExtType { type PolicyLoadFunc func(*Context, []byte) (Policy, error) type PolicyInfo struct { Load PolicyLoadFunc - Type PolicyType } type ACLPolicyExtContext struct { @@ -106,10 +344,22 @@ type ACLPolicyExtContext struct { } func NewACLPolicyExtContext() *ACLPolicyExtContext { - return nil + return &ACLPolicyExtContext{ + Types: map[PolicyType]PolicyInfo{ + PerNodePolicyType: PolicyInfo{ + Load: LoadPerNodePolicy, + }, + ParentOfPolicyType: PolicyInfo{ + Load: LoadParentOfPolicy, + }, + ChildOfPolicyType: PolicyInfo{ + Load: LoadChildOfPolicy, + }, + }, + } } -func (ext ACLPolicyExt) Serialize() ([]byte, error) { +func (ext *ACLPolicyExt) Serialize() ([]byte, error) { policies := map[string][]byte{} for name, policy := range(ext.Policies) { ser, err := policy.Serialize() @@ -126,10 +376,20 @@ func (ext ACLPolicyExt) Serialize() ([]byte, error) { }, "", " ") } -func (ext ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error { +func (ext *ACLPolicyExt) Process(context *StateContext, node *Node, signal GraphSignal) error { return nil } +func NewACLPolicyExt(policies map[PolicyType]Policy) *ACLPolicyExt { + if policies == nil { + policies = map[PolicyType]Policy{} + } + + return &ACLPolicyExt{ + Policies: policies, + } +} + func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) { var j struct { Policies map[string][]byte `json:"policies"` @@ -154,32 +414,22 @@ func LoadACLPolicyExt(ctx *Context, data []byte) (Extension, error) { policies[PolicyType(name)] = policy } - return ACLPolicyExt{ - Policies: policies, - }, nil + return NewACLPolicyExt(policies), nil } const ACLPolicyExtType = ExtType("ACL_POLICIES") -func (ext ACLPolicyExt) Type() ExtType { +func (ext *ACLPolicyExt) Type() ExtType { return ACLPolicyExtType } // Check if the extension allows the principal to perform action on node -func (ext ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool { +func (ext *ACLPolicyExt) Allows(context *StateContext, principal *Node, action string, node *Node) bool { + context.Graph.Log.Logf("policy", "POLICY_EXT_ALLOWED: %+v", ext) for _, policy := range(ext.Policies) { + context.Graph.Log.Logf("policy", "POLICY_CHECK_POLICY: %+v", policy) if policy.Allows(context, principal, action, node) == true { return true } } return false } - -func (policy AllNodesPolicy) Allows(context *StateContext, principal *Node, action string, node *Node) bool { - for _, a := range(policy.Actions) { - if a == action { - return true - } - } - return false -} - diff --git a/thread.go b/thread.go index 1464502..98c59ed 100644 --- a/thread.go +++ b/thread.go @@ -142,12 +142,31 @@ type ThreadExtJSON struct { State string `json:"state"` Type string `json:"type"` Parent string `json:"parent"` - Children map[string][]byte `json:"children"` + Children map[string]map[string][]byte `json:"children"` ActionQueue []QueuedAction } func (ext *ThreadExt) Serialize() ([]byte, error) { - return nil, fmt.Errorf("NOT_IMPLEMENTED") + children := map[string]map[string][]byte{} + for id, child := range(ext.Children) { + id_str := id.String() + children[id_str] = map[string][]byte{} + for info_type, info := range(child.Infos) { + var err error + children[id_str][string(info_type)], err = info.Serialize() + if err != nil { + return nil, err + } + } + } + + return json.MarshalIndent(&ThreadExtJSON{ + State: ext.State, + Type: string(ext.ThreadType), + Parent: SaveNode(ext.Parent), + Children: children, + ActionQueue: ext.ActionQueue, + }, "", " ") } func NewThreadExt(ctx*Context, thread_type ThreadType, parent *Node, children map[NodeID]ChildInfo, state string, action_queue []QueuedAction) (*ThreadExt, error) { @@ -265,10 +284,11 @@ func (ext *ThreadExt) Process(context *StateContext, node *Node, signal GraphSig case Up: err = UseStates(context, node, NewACLInfo(node, []string{"parent"}), func(context *StateContext) error { if ext.Parent != nil { - return Signal(context, ext.Parent, node, signal) - } else { - return nil + if ext.Parent.ID != node.ID { + return Signal(context, ext.Parent, node, signal) + } } + return nil }) case Down: err = UseStates(context, node, NewACLInfo(node, []string{"children"}), func(context *StateContext) error { @@ -360,7 +380,7 @@ func LinkThreads(context *StateContext, principal *Node, thread *Node, info Chil return err } - child_ext, err := GetExt[*ThreadExt](thread) + child_ext, err := GetExt[*ThreadExt](child) if err != nil { return err } @@ -562,17 +582,15 @@ func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal if exists == false { return fmt.Errorf("%s is not a child of %s", sig.ID, thread.ID) } - return UpdateStates(context, thread, NewACLInfo(info.Child, []string{"start"}), func(context *StateContext) error { - parent_info, exists := info.Infos["parent"].(*ParentInfo) - if exists == false { - return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info") - } - parent_info.Start = true - ChildGo(ctx, thread_ext, info.Child, sig.Action) + parent_info, exists := info.Infos["parent"].(*ParentInfo) + if exists == false { + return fmt.Errorf("Called ThreadStartChild from a thread that doesn't require parent child info") + } + parent_info.Start = true + ChildGo(ctx, thread_ext, info.Child, sig.Action) - return nil - }) + return nil }) } @@ -581,14 +599,14 @@ func ThreadStartChild(ctx *Context, thread *Node, thread_ext *ThreadExt, signal func ThreadRestore(ctx * Context, thread *Node, thread_ext *ThreadExt, start bool) error { context := NewWriteContext(ctx) return UpdateStates(context, thread, NewACLInfo(thread, []string{"children"}), func(context *StateContext) error { - return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"start", "state"}), func(context *StateContext) error { + return UpdateStates(context, thread, ACLList(thread_ext.ChildList(), []string{"state"}), func(context *StateContext) error { for _, info := range(thread_ext.Children) { child_ext, err := GetExt[*ThreadExt](info.Child) if err != nil { return err } - parent_info := info.Infos["parent"].(*ParentInfo) + parent_info := info.Infos[ParentInfoType].(*ParentInfo) if parent_info.Start == true && child_ext.State != "finished" { ctx.Log.Logf("thread", "THREAD_RESTORED: %s -> %s", thread.ID, info.Child.ID) if start == true {