diff --git a/acl.go b/acl.go index a3b85b1..0d93dac 100644 --- a/acl.go +++ b/acl.go @@ -143,7 +143,6 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa for policy_id, policy_messages := range(acl_messages) { total_messages += len(policy_messages) for _, message := range(policy_messages) { - // Create timeout signal and add the ID to Pending timeout_signal := NewTimeoutSignal(message.Signal.ID()) ext.Pending[message.Signal.ID()] = PendingSignal{ Policy: policy_id, @@ -191,3 +190,37 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa return messages, changes } + +type ACLProxyPolicy struct { + PolicyHeader + Proxies []NodeID +} + +func NewACLProxyPolicy(proxies []NodeID) ACLProxyPolicy { + return ACLProxyPolicy{ + NewPolicyHeader(), + proxies, + } +} + +func (policy ACLProxyPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node) (Messages, RuleResult) { + if len(policy.Proxies) == 0 { + return nil, Deny + } + + messages := Messages{} + for _, proxy := range(policy.Proxies) { + messages = messages.Add(ctx, node.ID, node.Key, NewACLSignal(principal_id, action), proxy) + } + + return messages, Pending +} + +func (policy ACLProxyPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { + _, is_success := signal.(*SuccessSignal) + if is_success == true { + return Allow + } + return Deny +} + diff --git a/acl_test.go b/acl_test.go index b28ea5d..dda44b7 100644 --- a/acl_test.go +++ b/acl_test.go @@ -88,4 +88,29 @@ func TestACLBasic(t *testing.T) { group.ID: nil, }), }, testSuccess) + + testSendACL(t, ctx, listener, nil, []Policy{ + NewACLProxyPolicy(nil), + }, testErrorSignal(t, "acl_denied")) + + acl_proxy_1, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt(nil)) + fatalErr(t, err) + + testSendACL(t, ctx, listener, nil, []Policy{ + NewACLProxyPolicy([]NodeID{acl_proxy_1.ID}), + }, testErrorSignal(t, "acl_denied")) + + acl_proxy_2, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewAllNodesPolicy(nil)})) + fatalErr(t, err) + + testSendACL(t, ctx, listener, nil, []Policy{ + NewACLProxyPolicy([]NodeID{acl_proxy_2.ID}), + }, testSuccess) + + acl_proxy_3, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt([]Policy{NewMemberOfPolicy(map[NodeID]Tree{group.ID: nil})})) + fatalErr(t, err) + + testSendACL(t, ctx, listener, nil, []Policy{ + NewACLProxyPolicy([]NodeID{acl_proxy_3.ID}), + }, testSuccess) } diff --git a/context.go b/context.go index 68fbbd9..6c63f88 100644 --- a/context.go +++ b/context.go @@ -1299,6 +1299,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterPolicy(reflect.TypeOf(ACLProxyPolicy{}), ACLProxyPolicyType) + if err != nil { + return nil, err + } + err = ctx.RegisterSignal(reflect.TypeOf(ACLTimeoutSignal{}), ACLTimeoutSignalType) if err != nil { return nil, err diff --git a/serialize.go b/serialize.go index f760e0f..a70df42 100644 --- a/serialize.go +++ b/serialize.go @@ -109,6 +109,7 @@ var ( RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") PerNodePolicyType = NewPolicyType("PER_NODE") AllNodesPolicyType = NewPolicyType("ALL_NODES") + ACLProxyPolicyType = NewPolicyType("ACL_PROXY") ErrorType = NewSerializedType("ERROR") PointerType = NewSerializedType("POINTER")