From c3058fbd3d63fcec37a7c93fe28d16dcdf4a418a Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 21 Jul 2023 00:02:53 -0600 Subject: [PATCH] Added more to policy, and updated lockable to use better IDs --- context.go | 4 -- gql_test.go | 11 +++-- lockable.go | 76 ++++++++++++++++------------ node.go | 18 ++++--- policy.go | 139 +++++++++++++++++++++++----------------------------- 5 files changed, 122 insertions(+), 126 deletions(-) diff --git a/context.go b/context.go index 37b820e..6cee835 100644 --- a/context.go +++ b/context.go @@ -192,10 +192,6 @@ func NewContext(db * badger.DB, log Logger) * Context { if err != nil { panic(err) } - err = ctx.RegisterNodeType(NewNodeDef((*AllNodePolicy)(nil), LoadAllNodePolicy, GQLTypeGraphNode())) - if err != nil { - panic(err) - } ctx.AddGQLType(GQLTypeSignal()) diff --git a/gql_test.go b/gql_test.go index cfbdb30..29dfb03 100644 --- a/gql_test.go +++ b/gql_test.go @@ -70,7 +70,7 @@ func TestGQLDBLoad(t * testing.T) { u1_r := NewUser("Test User", time.Now(), &u1_key.PublicKey, u1_shared) u1 := &u1_r - p1_r := NewAllNodePolicy(RandID(), []string{"*"}) + p1_r := NewPerNodePolicy(RandID(), nil, NewNodeActions([]string{"*"})) p1 := &p1_r key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -111,9 +111,9 @@ func TestGQLDBLoad(t * testing.T) { (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", t1, 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") - err = UseStates(ctx, []Node{gql, t1}, func(nodes NodeMap) error { + err = UseStates(ctx, []Node{gql, u1}, func(nodes NodeMap) error { ser1, err := gql.Serialize() - ser2, err := t1.Serialize() + ser2, err := u1.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser1) ctx.Log.Logf("test", "\n%s\n\n", ser2) return err @@ -127,11 +127,12 @@ func TestGQLDBLoad(t * testing.T) { err = UseStates(ctx, []Node{gql_loaded}, func(nodes NodeMap) error { ser, err := gql_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) + u_loaded := gql_loaded.(*GQLThread).Users[u1.ID()] child := gql_loaded.(Thread).Children()[0].(*SimpleThread) t1_loaded = child update_channel_2 = UpdateChannel(t1_loaded, 10, NodeID{}) - err = UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { - ser, err := child.Serialize() + err = UseMoreStates(ctx, []Node{u_loaded}, nodes, func(nodes NodeMap) error { + ser, err := u_loaded.Serialize() ctx.Log.Logf("test", "\n%s\n\n", ser) return err }) diff --git a/lockable.go b/lockable.go index b66e09d..08b70b4 100644 --- a/lockable.go +++ b/lockable.go @@ -62,10 +62,10 @@ func (state * SimpleLockable) Type() NodeType { type SimpleLockableJSON struct { GraphNodeJSON Name string `json:"name"` - Owner *NodeID `json:"owner"` - Dependencies []NodeID `json:"dependencies"` - Requirements []NodeID `json:"requirements"` - LocksHeld map[string]*NodeID `json:"locks_held"` + Owner string `json:"owner"` + Dependencies []string `json:"dependencies"` + Requirements []string `json:"requirements"` + LocksHeld map[string]string `json:"locks_held"` } func (lockable * SimpleLockable) Serialize() ([]byte, error) { @@ -74,29 +74,27 @@ func (lockable * SimpleLockable) Serialize() ([]byte, error) { } func NewSimpleLockableJSON(lockable *SimpleLockable) SimpleLockableJSON { - requirement_ids := make([]NodeID, len(lockable.requirements)) + requirement_ids := make([]string, len(lockable.requirements)) for i, requirement := range(lockable.requirements) { - requirement_ids[i] = requirement.ID() + requirement_ids[i] = requirement.ID().String() } - dependency_ids := make([]NodeID, len(lockable.dependencies)) + dependency_ids := make([]string, len(lockable.dependencies)) for i, dependency := range(lockable.dependencies) { - dependency_ids[i] = dependency.ID() + dependency_ids[i] = dependency.ID().String() } - var owner_id *NodeID = nil + owner_id := "" if lockable.owner != nil { - new_str := lockable.owner.ID() - owner_id = &new_str + owner_id = lockable.owner.ID().String() } - locks_held := map[string]*NodeID{} + locks_held := map[string]string{} for lockable_id, node := range(lockable.locks_held) { if node == nil { - locks_held[lockable_id.String()] = nil + locks_held[lockable_id.String()] = "" } else { - str := node.ID() - locks_held[lockable_id.String()] = &str + locks_held[lockable_id.String()] = node.ID().String() } } @@ -558,43 +556,55 @@ func NewSimpleLockable(id NodeID, name string) SimpleLockable { // Helper function to load links when loading a struct that embeds SimpleLockable func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSON, nodes NodeMap) error { - if j.Owner != nil { - o, err := LoadNodeRecurse(ctx, *j.Owner, nodes) + if j.Owner != "" { + owner_id, err := ParseID(j.Owner) if err != nil { return err } - o_l, ok := o.(Lockable) + owner_node, err := LoadNodeRecurse(ctx, owner_id, nodes) + if err != nil { + return err + } + owner, ok := owner_node.(Lockable) if ok == false { - return fmt.Errorf("%s is not a Lockable", *j.Owner) + return fmt.Errorf("%s is not a Lockable", j.Owner) } - lockable.SetOwner(o_l) + lockable.SetOwner(owner) } - for _, dep := range(j.Dependencies) { - dep_node, err := LoadNodeRecurse(ctx, dep, nodes) + for _, dep_str := range(j.Dependencies) { + dep_id, err := ParseID(dep_str) + if err != nil { + return err + } + dep_node, err := LoadNodeRecurse(ctx, dep_id, nodes) if err != nil { return err } - dep_l, ok := dep_node.(Lockable) + dep, ok := dep_node.(Lockable) if ok == false { return fmt.Errorf("%+v is not a Lockable as expected", dep_node) } - lockable.AddDependency(dep_l) + lockable.AddDependency(dep) } - for _, req := range(j.Requirements) { - req_node, err := LoadNodeRecurse(ctx, req, nodes) + for _, req_str := range(j.Requirements) { + req_id, err := ParseID(req_str) if err != nil { return err } - req_l, ok := req_node.(Lockable) + req_node, err := LoadNodeRecurse(ctx, req_id, nodes) + if err != nil { + return err + } + req, ok := req_node.(Lockable) if ok == false { return fmt.Errorf("%+v is not a Lockable as expected", req_node) } - lockable.AddRequirement(req_l) + lockable.AddRequirement(req) } - for l_id_str, h_id := range(j.LocksHeld) { + for l_id_str, h_str := range(j.LocksHeld) { l_id, err := ParseID(l_id_str) l, err := LoadNodeRecurse(ctx, l_id, nodes) if err != nil { @@ -606,8 +616,12 @@ func RestoreSimpleLockable(ctx * Context, lockable Lockable, j SimpleLockableJSO } var h_l Lockable = nil - if h_id != nil { - h_node, err := LoadNodeRecurse(ctx, *h_id, nodes) + if h_str != "" { + h_id, err := ParseID(h_str) + if err != nil { + return err + } + h_node, err := LoadNodeRecurse(ctx, h_id, nodes) if err != nil { return err } diff --git a/node.go b/node.go index 36cb837..cbdbb77 100644 --- a/node.go +++ b/node.go @@ -69,7 +69,7 @@ type Node interface { ID() NodeID Type() NodeType - Allowed(action string, principal NodeID) bool + Allowed(action string, resource string, principal NodeID) bool AddPolicy(Policy) error RemovePolicy(Policy) error @@ -92,7 +92,7 @@ type GraphNode struct { } type GraphNodeJSON struct { - Policies []NodeID `json:"policies"` + Policies []string `json:"policies"` } func (node * GraphNode) Serialize() ([]byte, error) { @@ -100,9 +100,9 @@ func (node * GraphNode) Serialize() ([]byte, error) { return json.MarshalIndent(&node_json, "", " ") } -func (node *GraphNode) Allowed(action string, principal NodeID) bool { +func (node *GraphNode) Allowed(action string, resource string, principal NodeID) bool { for _, policy := range(node.policies) { - if policy.Allows(action, principal) == true { + if policy.Allows(action, resource, principal) == true { return true } } @@ -138,10 +138,10 @@ func (node *GraphNode) RemovePolicy(policy Policy) error { } func NewGraphNodeJSON(node *GraphNode) GraphNodeJSON { - policies := make([]NodeID, len(node.policies)) + policies := make([]string, len(node.policies)) i := 0 for _, policy := range(node.policies) { - policies[i] = policy.ID() + policies[i] = policy.ID().String() i += 1 } return GraphNodeJSON{ @@ -150,7 +150,11 @@ func NewGraphNodeJSON(node *GraphNode) GraphNodeJSON { } func RestoreGraphNode(ctx *Context, node Node, j GraphNodeJSON, nodes NodeMap) error { - for _, policy_id := range(j.Policies) { + for _, policy_str := range(j.Policies) { + policy_id, err := ParseID(policy_str) + if err != nil { + return err + } policy_ptr, err := LoadNodeRecurse(ctx, policy_id, nodes) if err != nil { return err diff --git a/policy.go b/policy.go index a7d0a04..15b972c 100644 --- a/policy.go +++ b/policy.go @@ -8,17 +8,49 @@ import ( type Policy interface { Node // Returns true if the policy allows the action on the given principal - Allows(action string, principal NodeID) bool + Allows(action string, resource string, principal NodeID) bool +} + +type NodeActions map[string][]string +func (actions NodeActions) Allows(action string, resource string) bool { + for _, a := range(actions[""]) { + if a == action { + return true + } + } + + resource_actions, exists := actions[resource] + if exists == true { + for _, a := range(resource_actions) { + if a == action { + return true + } + } + } + + return false +} + +func NewNodeActions(wildcard_actions []string) NodeActions { + actions := NodeActions{} + // Wildcard actions, all actions in "" will be allowed on all resources + if wildcard_actions == nil { + wildcard_actions = []string{} + } + actions[""] = wildcard_actions + return actions } type PerNodePolicy struct { GraphNode - AllowedActions map[NodeID][]string + NodeActions map[NodeID]NodeActions + WildcardActions NodeActions } type PerNodePolicyJSON struct { GraphNodeJSON - AllowedActions map[string][]string `json:"allowed_actions"` + NodeActions map[string]map[string][]string `json:"allowed_actions"` + WildcardActions map[string][]string `json:"wildcard_actions"` } func (policy *PerNodePolicy) Type() NodeType { @@ -26,25 +58,31 @@ func (policy *PerNodePolicy) Type() NodeType { } func (policy *PerNodePolicy) Serialize() ([]byte, error) { - allowed_actions := map[string][]string{} - for principal, actions := range(policy.AllowedActions) { + allowed_actions := map[string]map[string][]string{} + for principal, actions := range(policy.NodeActions) { allowed_actions[principal.String()] = actions } return json.MarshalIndent(&PerNodePolicyJSON{ GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), - AllowedActions: allowed_actions, + NodeActions: allowed_actions, + WildcardActions: policy.WildcardActions, }, "", " ") } -func NewPerNodePolicy(id NodeID, allowed_actions map[NodeID][]string) PerNodePolicy { - if allowed_actions == nil { - allowed_actions = map[NodeID][]string{} +func NewPerNodePolicy(id NodeID, node_actions map[NodeID]NodeActions, wildcard_actions NodeActions) PerNodePolicy { + if node_actions == nil { + node_actions = map[NodeID]NodeActions{} + } + + if wildcard_actions == nil { + wildcard_actions = NewNodeActions(nil) } return PerNodePolicy{ GraphNode: NewGraphNode(id), - AllowedActions: allowed_actions, + NodeActions: node_actions, + WildcardActions: wildcard_actions, } } @@ -55,8 +93,8 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod return nil, err } - allowed_actions := map[NodeID][]string{} - for principal_str, actions := range(j.AllowedActions) { + allowed_actions := map[NodeID]NodeActions{} + for principal_str, actions := range(j.NodeActions) { principal_id, err := ParseID(principal_str) if err != nil { return nil, err @@ -65,7 +103,7 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod allowed_actions[principal_id] = actions } - policy := NewPerNodePolicy(id, allowed_actions) + policy := NewPerNodePolicy(id, allowed_actions, j.WildcardActions) nodes[id] = &policy err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) @@ -76,77 +114,20 @@ func LoadPerNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Nod return &policy, nil } -func (policy *PerNodePolicy) Allows(action string, principal NodeID) bool { - actions, exists := policy.AllowedActions[principal] - if exists == false { - return false - } - - for _, a := range(actions) { - if a == action { - return true - } - } - - return false -} - -type AllNodePolicy struct { - GraphNode - AllowedActions []string -} - -type AllNodePolicyJSON struct { - GraphNodeJSON - AllowedActions []string `json:"allowed_actions"` -} - -func (policy *AllNodePolicy) Type() NodeType { - return NodeType("all_node_policy") -} - -func (policy *AllNodePolicy) Serialize() ([]byte, error) { - return json.MarshalIndent(&AllNodePolicyJSON{ - GraphNodeJSON: NewGraphNodeJSON(&policy.GraphNode), - AllowedActions: policy.AllowedActions, - }, "", " ") -} - -func NewAllNodePolicy(id NodeID, allowed_actions []string) AllNodePolicy { - if allowed_actions == nil { - allowed_actions = []string{} - } - - return AllNodePolicy{ - GraphNode: NewGraphNode(id), - AllowedActions: allowed_actions, +func (policy *PerNodePolicy) Allows(action string, resource string, principal NodeID) bool { + // Check wildcard actions + if policy.WildcardActions.Allows(action, resource) == true { + return true } -} -func LoadAllNodePolicy(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j AllNodePolicyJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err + node_actions, exists := policy.NodeActions[principal] + if exists == false { + return false } - policy := NewAllNodePolicy(id, j.AllowedActions) - nodes[id] = &policy - - err = RestoreGraphNode(ctx, &policy.GraphNode, j.GraphNodeJSON, nodes) - if err != nil { - return nil, err + if node_actions.Allows(action, resource) == true { + return true } - return &policy, nil -} - -func (policy *AllNodePolicy) Allows(action string, principal NodeID) bool { - for _, a := range(policy.AllowedActions) { - if a == action { - return true - } - } return false } -