diff --git a/gql.go b/gql.go index f072e37..e75cd2c 100644 --- a/gql.go +++ b/gql.go @@ -205,7 +205,7 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ if err != nil { return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from id_bytes %+v", id_bytes) } - auth_id := NodeID(auth_uuid) + auth_id := NodeID{auth_uuid} key_bytes, err := base64.StdEncoding.DecodeString(key_b64) if err != nil { diff --git a/gql_test.go b/gql_test.go index 3aa304d..9f8c675 100644 --- a/gql_test.go +++ b/gql_test.go @@ -37,7 +37,7 @@ func TestGQLServer(t *testing.T) { ErrorSignalType.String(): nil, }) - group_policy_2 := NewMemberOfPolicy(NodeRules{ + group_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{ gql_id: Tree{ LinkSignalType.String(): nil, LinkStartSignalType.String(): nil, @@ -53,7 +53,7 @@ func TestGQLServer(t *testing.T) { ErrorSignalType.String(): nil, }) - user_policy_2 := NewMemberOfPolicy(NodeRules{ + user_policy_2 := NewMemberOfPolicy(map[NodeID]Tree{ gql_id: Tree{ LinkSignalType.String(): nil, ReadSignalType.String(): nil, diff --git a/lockable.go b/lockable.go index 46ca4eb..afe958f 100644 --- a/lockable.go +++ b/lockable.go @@ -1,13 +1,21 @@ package graphvent import ( - "encoding/json" + "encoding/binary" ) -type LockableExt struct { - Owner *NodeID `json:"owner"` - PendingOwner *NodeID `json:"pending_owner"` - Requirements map[NodeID]string `json:"requirements"` +type ReqState int +const ( + Unlocked = ReqState(0) + Unlocking = ReqState(1) + Locked = ReqState(2) + Locking = ReqState(3) +) + +type LockableExt struct{ + Owner *NodeID + PendingOwner *NodeID + Requirements map[NodeID]ReqState } func (ext *LockableExt) Field(name string) interface{} { @@ -29,17 +37,97 @@ func (ext *LockableExt) Type() ExtType { } func (ext *LockableExt) Serialize() ([]byte, error) { - return json.Marshal(ext) + ret := make([]byte, 8 + (16 * 2) + (17 * len(ext.Requirements))) + if ext.Owner != nil { + bytes, err := ext.Owner.MarshalBinary() + if err != nil { + return nil, err + } + copy(ret[0:16], bytes) + } + + if ext.PendingOwner != nil { + bytes, err := ext.PendingOwner.MarshalBinary() + if err != nil { + return nil, err + } + copy(ret[16:32], bytes) + } + + binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements))) + + cur := 40 + for req, state := range(ext.Requirements) { + bytes, err := req.MarshalBinary() + if err != nil { + return nil, err + } + copy(ret[cur:cur+16], bytes) + ret[cur+16] = byte(state) + cur += 17 + } + + return ret, nil } func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error { - return json.Unmarshal(data, ext) + cur := 0 + all_zero := true + for _, b := range(data[cur:cur+16]) { + if all_zero == true && b != 0x00 { + all_zero = false + } + } + if all_zero == false { + tmp, err := IDFromBytes(data[cur:cur+16]) + if err != nil { + return err + } + ext.Owner = &tmp + } + cur += 16 + + all_zero = true + for _, b := range(data[cur:cur+16]) { + if all_zero == true && b != 0x00 { + all_zero = false + } + } + if all_zero == false { + tmp, err := IDFromBytes(data[cur:cur+16]) + if err != nil { + return err + } + ext.PendingOwner = &tmp + } + cur += 16 + + num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8])) + cur += 8 + + if num_requirements != 0 { + ext.Requirements = map[NodeID]ReqState{} + } + for i := 0; i < num_requirements; i++ { + id, err := IDFromBytes(data[cur:cur+16]) + if err != nil { + return err + } + cur += 16 + state := ReqState(data[cur]) + cur += 1 + ext.Requirements[id] = state + } + return nil } func NewLockableExt(requirements []NodeID) *LockableExt { - reqs := map[NodeID]string{} - for _, id := range(requirements) { - reqs[id] = "unlocked" + var reqs map[NodeID]ReqState = nil + if requirements != nil { + reqs = map[NodeID]ReqState{} + for _, id := range(requirements) { + reqs[id] = Unlocked + } } return &LockableExt{ Owner: nil, @@ -84,10 +172,10 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } else { ext.PendingOwner = nil for id, state := range(ext.Requirements) { - if state != "locked" { + if state != Locked { panic("NOT_LOCKED") } - ext.Requirements[id] = "unlocking" + ext.Requirements[id] = Unlocking messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id) } if source != node.ID { @@ -96,11 +184,13 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } } case "unlocking": - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != "unlocking" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) + if ext.Requirements != nil { + state, exists := ext.Requirements[source] + if exists == false { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + } else if state != Unlocking { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) + } } case "unlocked": @@ -108,26 +198,28 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, return nil } - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != "unlocking" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) - } else { - ext.Requirements[source] = "unlocked" + if ext.Requirements != nil { + state, exists := ext.Requirements[source] + if exists == false { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + } else if state != Unlocking { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) + } else { + ext.Requirements[source] = Unlocked - if ext.PendingOwner == nil { - unlocked := 0 - for _, s := range(ext.Requirements) { - if s == "unlocked" { - unlocked += 1 + if ext.PendingOwner == nil { + unlocked := 0 + for _, s := range(ext.Requirements) { + if s == Unlocked { + unlocked += 1 + } } - } - if len(ext.Requirements) == unlocked { - previous_owner := *ext.Owner - ext.Owner = nil - messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) + if len(ext.Requirements) == unlocked { + previous_owner := *ext.Owner + ext.Owner = nil + messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) + } } } } @@ -136,34 +228,38 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, return nil } - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != "locking" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) - } else { - ext.Requirements[source] = "locked" + if ext.Requirements != nil { + state, exists := ext.Requirements[source] + if exists == false { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + } else if state != Locking { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) + } else { + ext.Requirements[source] = Locked - if ext.PendingOwner != nil { - locked := 0 - for _, s := range(ext.Requirements) { - if s == "locked" { - locked += 1 + if ext.PendingOwner != nil { + locked := 0 + for _, s := range(ext.Requirements) { + if s == Locked { + locked += 1 + } } - } - if len(ext.Requirements) == locked { - ext.Owner = ext.PendingOwner - messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) + if len(ext.Requirements) == locked { + ext.Owner = ext.PendingOwner + messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) + } } } } case "locking": - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != "locking" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) + if ext.Requirements != nil { + state, exists := ext.Requirements[source] + if exists == false { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + } else if state != Locking { + messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) + } } case "lock": @@ -180,13 +276,14 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } else { ext.PendingOwner = &owner for id, state := range(ext.Requirements) { - log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id) - if state != "unlocked" { + log.Logf("lockable_detail", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id) + if state != Unlocked { panic("NOT_UNLOCKED") } - ext.Requirements[id] = "locking" + ext.Requirements[id] = Locking messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id) } + log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %d requirements", node.ID, len(ext.Requirements)) if source != node.ID { messages = messages.Add(node.ID, node.Key, NewLockSignal("locking"), source) } @@ -195,7 +292,6 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, default: log.Logf("lockable", "LOCK_ERR: unkown state %s", state) } - log.Logf("lockable", "LOCK_MESSAGES: %+v", messages) return messages } diff --git a/lockable_test.go b/lockable_test.go index bcbcb46..a7f1f75 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -3,6 +3,8 @@ package graphvent import ( "testing" "time" + "crypto/ed25519" + "crypto/rand" ) const TestLockableType = NodeType("TEST_LOCKABLE") @@ -16,73 +18,116 @@ func lockableTestContext(t *testing.T, logs []string) *Context { } func TestLink(t *testing.T) { - ctx := lockableTestContext(t, []string{"lockable"}) + ctx := lockableTestContext(t, []string{"listener"}) + + l1_pub, l1_key, err := ed25519.GenerateKey(rand.Reader) + fatalErr(t, err) + l1_id := KeyID(l1_pub) + policy := NewPerNodePolicy(map[NodeID]Tree{ + l1_id: nil, + }) l2_listener := NewListenerExt(10) - l2 := NewNode(ctx, nil, TestLockableType, 10, nil, - l2_listener, - NewLockableExt(nil), - ) + l2 := NewNode(ctx, nil, TestLockableType, 10, + map[PolicyType]Policy{ + PerNodePolicyType: &policy, + }, + l2_listener, + NewLockableExt(nil), + ) + l1_listener := NewListenerExt(10) - NewNode(ctx, nil, TestLockableType, 10, nil, + l1 := NewNode(ctx, l1_key, TestLockableType, 10, nil, l1_listener, NewLockableExt([]NodeID{l2.ID}), ) msgs := Messages{} - msgs = msgs.Add(l2.ID, l2.Key, NewStatusSignal("TEST", l2.ID), l2.ID) - err := ctx.Send(msgs) + s := NewBaseSignal("TEST", Down) + msgs = msgs.Add(l1.ID, l1.Key, &s, l1.ID) + err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "TEST" + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { + return sig.ID() == s.ID() }) fatalErr(t, err) - _, err = WaitForSignal(ctx, l2_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "TEST" + _, err = WaitForSignal(ctx, l2_listener.Chan, time.Millisecond*10, "TEST", func(sig *BaseSignal) bool { + return sig.ID() == s.ID() }) fatalErr(t, err) } func TestLink10K(t *testing.T) { - ctx := lockableTestContext(t, []string{}) + ctx := lockableTestContext(t, []string{"test"}) + l_pub, listener_key, err := ed25519.GenerateKey(rand.Reader) + fatalErr(t, err) + listener_id := KeyID(l_pub) + child_policy := NewPerNodePolicy(map[NodeID]Tree{ + listener_id: Tree{ + LockSignalType.String(): nil, + }, + }) NewLockable := func()(*Node) { - l := NewNode(ctx, nil, TestLockableType, 10, nil, + l := NewNode(ctx, nil, TestLockableType, 10, + map[PolicyType]Policy{ + PerNodePolicyType: &child_policy, + }, NewLockableExt(nil), ) return l } - reqs := make([]NodeID, 10000) + reqs := make([]NodeID, 1000) for i, _ := range(reqs) { new_lockable := NewLockable() reqs[i] = new_lockable.ID } ctx.Log.Logf("test", "CREATED_10K") - NewListener := func()(*ListenerExt) { - listener := NewListenerExt(100000) - NewNode(ctx, nil, TestLockableType, 256, nil, - listener, - NewLockableExt(reqs), - ) - return listener - } - NewListener() + l_policy := NewAllNodesPolicy(Tree{ + LockSignalType.String(): nil, + }) + listener := NewListenerExt(100000) + node := NewNode(ctx, listener_key, TestLockableType, 10000, + map[PolicyType]Policy{ + AllNodesPolicyType: &l_policy, + }, + listener, + NewLockableExt(reqs), + ) ctx.Log.Logf("test", "CREATED_LISTENER") - // TODO: Lock listener and wait for all the lock signals - //ctx.Log.Logf("test", "LOCKED_10K") + err = LockLockable(ctx, node) + fatalErr(t, err) + + _, err = WaitForSignal(ctx, listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { + return sig.Str == "locked" + }) + fatalErr(t, err) + + for _, _ = range(reqs) { + _, err := WaitForSignal(ctx, listener.Chan, time.Millisecond*100, LockSignalType, func(sig *StringSignal) bool { + return sig.Str == "locked" + }) + fatalErr(t, err) + } + ctx.Log.Logf("test", "LOCKED_10K") } func TestLock(t *testing.T) { - ctx := lockableTestContext(t, []string{}) + ctx := lockableTestContext(t, []string{"lockable", "policy"}) + + policy := NewAllNodesPolicy(nil) NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { listener := NewListenerExt(100) - l := NewNode(ctx, nil, TestLockableType, 10, nil, + l := NewNode(ctx, nil, TestLockableType, 10, + map[PolicyType]Policy{ + AllNodesPolicyType: &policy, + }, listener, NewLockableExt(reqs), ) diff --git a/node.go b/node.go index b150e7e..6eb48a7 100644 --- a/node.go +++ b/node.go @@ -8,7 +8,6 @@ import ( badger "github.com/dgraph-io/badger/v3" "fmt" "encoding/binary" - "encoding/json" "sync/atomic" "crypto" "crypto/ed25519" @@ -30,51 +29,22 @@ const ( var ( // Base NodeID, used as a special value ZeroUUID = uuid.UUID{} - ZeroID = NodeID(ZeroUUID) + ZeroID = NodeID{ZeroUUID} ) // A NodeID uniquely identifies a Node -type NodeID uuid.UUID - -func (id NodeID) MarshalText() ([]byte, error) { - return json.Marshal(id.String()) -} - -func (id *NodeID) UnmarshalText(data []byte) error { - return json.Unmarshal(data, id) -} - -func (id *NodeID) MarshalJSON() ([]byte, error) { - return json.Marshal(id.String()) -} - -func (id *NodeID) UnmarshalJSON(bytes []byte) error { - var id_str string - err := json.Unmarshal(bytes, &id_str) - if err != nil { - return err - } - - *id, err = ParseID(id_str) - return err +type NodeID struct { + uuid.UUID } func (id NodeID) Serialize() []byte { - ser, _ := (uuid.UUID)(id).MarshalBinary() + ser, _ := id.MarshalBinary() return ser } - - -func (id NodeID) String() string { - return (uuid.UUID)(id).String() -} - -// Create an ID from a fixed length byte array -// Ignore the error since we're enforcing 16 byte length at compile time -func IDFromBytes(bytes [16]byte) NodeID { - id, _ := uuid.FromBytes(bytes[:]) - return NodeID(id) +func IDFromBytes(bytes []byte) (NodeID, error) { + id, err := uuid.FromBytes(bytes[:]) + return NodeID{id}, err } // Parse an ID from a string @@ -83,12 +53,12 @@ func ParseID(str string) (NodeID, error) { if err != nil { return NodeID{}, err } - return NodeID(id_uuid), nil + return NodeID{id_uuid}, nil } // Generate a random NodeID func RandID() NodeID { - return NodeID(uuid.New()) + return NodeID{uuid.New()} } // A Serializable has a type that can be used to map to it, and a function to serialize` the current state @@ -249,7 +219,7 @@ func (node *Node) ReadFields(reqs map[ExtType][]string)map[ExtType]map[string]in return exts } -// Main Loop for Threads, starts a write context, so cannot be called from a write or read context +// Main Loop for nodes func nodeLoop(ctx *Context, node *Node) error { started := node.Active.CompareAndSwap(false, true) if started == false { @@ -352,7 +322,6 @@ func nodeLoop(ctx *Context, node *Node) error { req_info.Counter -= 1 req_info.Responses = append(req_info.Responses, signal) - // TODO: call the right policy ParseResponse to check if the updated state passes the ACL check allowed := node.Policies[info.Policy].ContinueAllows(req_info, signal) if allowed == Allow { ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) @@ -385,7 +354,7 @@ func nodeLoop(ctx *Context, node *Node) error { } } - // Handle special signal types + // Handle node signals if signal.Type() == StopSignalType { msgs := Messages{} msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source) @@ -591,7 +560,7 @@ func (node *Node) Serialize() ([]byte, error) { func KeyID(pub ed25519.PublicKey) NodeID { str := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) - return NodeID(str) + return NodeID{str} } // Create a new node in memory and start it's event loop @@ -652,6 +621,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si SignalQueue: []QueuedSignal{}, } ctx.AddNode(id, node) + err = WriteNode(ctx, node) if err != nil { panic(err) diff --git a/node_test.go b/node_test.go index 0f7bfca..a9179e7 100644 --- a/node_test.go +++ b/node_test.go @@ -37,10 +37,16 @@ func TestNodeRead(t *testing.T) { ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N2: %s", n2_id) + policy := NewAllNodesPolicy(nil) + n2_listener := NewListenerExt(10) - n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener) + n2 := NewNode(ctx, n2_key, node_type, 10, map[PolicyType]Policy{ + AllNodesPolicyType: &policy, + }, NewGroupExt(nil), NewECDHExt(), n2_listener) - n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt()) + n1 := NewNode(ctx, n1_key, node_type, 10, map[PolicyType]Policy{ + AllNodesPolicyType: &policy, + }, NewGroupExt(nil), NewECDHExt()) read_sig := NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, @@ -56,49 +62,3 @@ func TestNodeRead(t *testing.T) { fatalErr(t, err) ctx.Log.Logf("test", "READ_RESULT: %+v", res) } - -func TestECDH(t *testing.T) { - ctx := logTestContext(t, []string{"test", "ecdh", "policy"}) - - node_type := NodeType("TEST") - err := ctx.RegisterNodeType(node_type, []ExtType{ECDHExtType}) - fatalErr(t, err) - - n1_listener := NewListenerExt(10) - n1 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n1_listener) - n2 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt()) - n3_listener := NewListenerExt(10) - n3 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n3_listener) - - ctx.Log.Logf("test", "N1: %s", n1.ID) - ctx.Log.Logf("test", "N2: %s", n2.ID) - - - ecdh_req, n1_ec, err := NewECDHReqSignal(n1) - ecdh_ext, err := GetExt[*ECDHExt](n1) - fatalErr(t, err) - ecdh_ext.ECDHStates[n2.ID] = ECDHState{ - ECKey: n1_ec, - SharedSecret: nil, - } - fatalErr(t, err) - ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) - msgs := Messages{} - msgs = msgs.Add(n1.ID, n1.Key, ecdh_req, n2.ID) - err = ctx.Send(msgs) - fatalErr(t, err) - - _, err = WaitForSignal(ctx, n1_listener.Chan, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { - return sig.Str == "resp" - }) - fatalErr(t, err) - time.Sleep(10*time.Millisecond) - - ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret) - fatalErr(t, err) - - msgs = Messages{} - msgs = msgs.Add(n1.ID, n1.Key, ecdh_sig, n2.ID) - err = ctx.Send(msgs) - fatalErr(t, err) -} diff --git a/policy.go b/policy.go index a5e92c1..59a74a2 100644 --- a/policy.go +++ b/policy.go @@ -6,6 +6,7 @@ import ( const ( MemberOfPolicyType = PolicyType("USER_OF") + RequirementOfPolicyType = PolicyType("REQUIEMENT_OF") PerNodePolicyType = PolicyType("PER_NODE") AllNodesPolicyType = PolicyType("ALL_NODES") ) @@ -42,6 +43,45 @@ func (policy *PerNodePolicy) ContinueAllows(current PendingACL, signal Signal) R return Deny } +type RequirementOfPolicy struct { + PerNodePolicy +} + +func (policy *RequirementOfPolicy) Type() PolicyType { + return RequirementOfPolicyType +} + +func NewRequirementOfPolicy(dep_rules map[NodeID]Tree) RequirementOfPolicy { + return RequirementOfPolicy { + PerNodePolicy: NewPerNodePolicy(dep_rules), + } +} + +func (policy *RequirementOfPolicy) ContinueAllows(current PendingACL, signal Signal) RuleResult { + sig, ok := signal.(*ReadResultSignal) + if ok == false { + return Deny + } + + ext, ok := sig.Extensions[LockableExtType] + if ok == false { + return Deny + } + + requirements, ok := ext["requirements"].(map[NodeID]string) + if ok == false { + return Deny + } + + for req, _ := range(requirements) { + if req == current.Principal { + return policy.NodeRules[sig.NodeID].Allows(current.Action) + } + } + + return Deny +} + type MemberOfPolicy struct { PerNodePolicy } @@ -50,7 +90,7 @@ func (policy *MemberOfPolicy) Type() PolicyType { return MemberOfPolicyType } -func NewMemberOfPolicy(group_rules NodeRules) MemberOfPolicy { +func NewMemberOfPolicy(group_rules map[NodeID]Tree) MemberOfPolicy { return MemberOfPolicy{ PerNodePolicy: NewPerNodePolicy(group_rules), } @@ -148,16 +188,16 @@ func MergeTrees(first Tree, second Tree) Tree { return ret } -func CopyNodeRules(rules NodeRules) NodeRules { - ret := NodeRules{} +func CopyNodeRules(rules map[NodeID]Tree) map[NodeID]Tree { + ret := map[NodeID]Tree{} for id, r := range(rules) { ret[id] = r } return ret } -func MergeNodeRules(first NodeRules, second NodeRules) NodeRules { - merged := NodeRules{} +func MergeNodeRules(first map[NodeID]Tree, second map[NodeID]Tree) map[NodeID]Tree { + merged := map[NodeID]Tree{} for id, actions := range(first) { merged[id] = actions } @@ -227,38 +267,9 @@ func (rule Tree) Allows(action Tree) RuleResult { } } -type NodeRules map[NodeID]Tree - -func (rules NodeRules) MarshalJSON() ([]byte, error) { - tmp := map[string]Tree{} - for id, r := range(rules) { - tmp[id.String()] = r - } - return json.Marshal(tmp) -} - -func (rules *NodeRules) UnmarshalJSON(data []byte) error { - tmp := map[string]Tree{} - err := json.Unmarshal(data, &tmp) - if err != nil { - return err - } - - for id_str, r := range(tmp) { - id, err := ParseID(id_str) - if err != nil { - return err - } - ru := *rules - ru[id] = r - } - - return nil -} - -func NewPerNodePolicy(node_actions NodeRules) PerNodePolicy { +func NewPerNodePolicy(node_actions map[NodeID]Tree) PerNodePolicy { if node_actions == nil { - node_actions = NodeRules{} + node_actions = map[NodeID]Tree{} } return PerNodePolicy{ @@ -267,7 +278,7 @@ func NewPerNodePolicy(node_actions NodeRules) PerNodePolicy { } type PerNodePolicy struct { - NodeRules NodeRules `json:"node_actions"` + NodeRules map[NodeID]Tree `json:"node_actions"` } func (policy *PerNodePolicy) Type() PolicyType {