Added ACLExt and tests

gql_cataclysm
noah metz 2023-10-13 00:32:24 -06:00
parent c63ad91252
commit 16e25c009f
8 changed files with 367 additions and 56 deletions

193
acl.go

@ -0,0 +1,193 @@
package graphvent
import (
"github.com/google/uuid"
"slices"
"time"
)
type ACLSignal struct {
SignalHeader
Principal NodeID `gv:"principal"`
Action Tree `gv:"tree"`
}
func NewACLSignal(principal NodeID, action Tree) *ACLSignal {
return &ACLSignal{
SignalHeader: NewSignalHeader(Direct),
Principal: principal,
Action: action,
}
}
var DefaultACLPolicy = NewAllNodesPolicy(Tree{
SerializedType(ACLSignalType): nil,
})
func (signal ACLSignal) Permission() Tree {
return Tree{
SerializedType(ACLSignalType): nil,
}
}
type ACLExt struct {
Policies []Policy `gv:"policies"`
PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"`
Pending map[uuid.UUID]PendingSignal `gv:"pending"`
}
func NewACLExt(policies []Policy) *ACLExt {
return &ACLExt{
Policies: policies,
PendingACLs: map[uuid.UUID]PendingACL{},
Pending: map[uuid.UUID]PendingSignal{},
}
}
func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) {
response, is_response := signal.(ResponseSignal)
if is_response == true {
var messages Messages = nil
var changes Changes = nil
info, waiting := ext.Pending[response.ResponseID()]
if waiting == true {
changes = changes.Add("response_processed")
delete(ext.Pending, response.ResponseID())
if response.ID() != info.Timeout {
err := node.DequeueSignal(info.Timeout)
if err != nil {
ctx.Log.Logf("acl", "timeout dequeue error: %s", err)
}
}
acl_info, found := ext.PendingACLs[info.ID]
if found == true {
acl_info.Counter -= 1
acl_info.Responses = append(acl_info.Responses, response)
policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool {
return policy.ID() == info.Policy
})
if policy_index == -1 {
ctx.Log.Logf("acl", "pending signal for nonexistent policy")
delete(ext.PendingACLs, info.ID)
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow {
delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed allow")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(info.ID), acl_info.Source)
changes = changes.Add("acl_passed")
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else if acl_info.Counter == 0 {
delete(ext.PendingACLs, info.ID)
ctx.Log.Logf("acl", "Request delayed deny")
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(info.ID, "acl_denied"), acl_info.Source)
changes = changes.Add("acl_blocked")
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
node.PendingACLs[info.ID] = acl_info
changes = changes.Add("acl_processed")
}
}
}
}
return messages, changes
}
var messages Messages = nil
var changes Changes = nil
switch sig := signal.(type) {
case *ACLSignal:
var acl_messages map[uuid.UUID]Messages = nil
denied := true
for _, policy := range(ext.Policies) {
policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node)
if result == Allow {
denied = false
break
} else if result == Pending {
if len(policy_messages) == 0 {
ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID())
continue
} else if acl_messages == nil {
acl_messages = map[uuid.UUID]Messages{}
denied = false
}
acl_messages[policy.ID()] = policy_messages
ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages)
}
}
if denied == true {
ctx.Log.Logf("acl", "Request denied")
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "acl_denied"), source)
} else if acl_messages != nil {
ctx.Log.Logf("acl", "Request pending")
changes = changes.Add("acl_pending")
total_messages := 0
// TODO: reasonable timeout/configurable
timeout_time := time.Now().Add(time.Second)
for policy_id, policy_messages := range(acl_messages) {
total_messages += len(policy_messages)
for _, message := range(policy_messages) {
// Create timeout signal and add the ID to Pending
timeout_signal := NewTimeoutSignal(message.Signal.ID())
ext.Pending[message.Signal.ID()] = PendingSignal{
Policy: policy_id,
Timeout: timeout_signal.Id,
ID: sig.Id,
}
node.QueueSignal(timeout_time, timeout_signal)
messages = append(messages, message)
}
}
acl_timeout := NewACLTimeoutSignal(sig.Id)
node.QueueSignal(timeout_time, acl_timeout)
ext.PendingACLs[sig.Id] = PendingACL{
Counter: total_messages,
Responses: []ResponseSignal{},
TimeoutID: acl_timeout.Id,
Action: sig.Action,
Principal: sig.Principal,
Source: source,
Signal: signal,
}
} else {
ctx.Log.Logf("acl", "Request allowed")
messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source)
}
// Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request
case *ACLTimeoutSignal:
acl_info, exists := ext.PendingACLs[sig.ReqID]
if exists == true {
delete(ext.PendingACLs, sig.ReqID)
ctx.Log.Logf("acl", "Request timeout deny")
messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ReqID, "acl_timeout"), acl_info.Source)
changes = changes.Add("acl_timeout")
err := node.DequeueSignal(acl_info.TimeoutID)
if err != nil {
ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err)
}
} else {
ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl")
}
// Delete from PendingACLs
}
return messages, changes
}

@ -0,0 +1,91 @@
package graphvent
import (
"testing"
"time"
"reflect"
)
func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){
response_casted, cast_ok := signal.(S)
if cast_ok == false {
error_signal, is_error := signal.(*ErrorSignal)
if is_error {
t.Fatal(error_signal.Error)
}
t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal))
}
check(response_casted)
}
func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){
acl_node, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt(policies))
fatalErr(t, err)
acl_signal := NewACLSignal(listener.ID, action)
response := testSend(t, ctx, acl_signal, listener, acl_node)
checkSignal(t, response, check)
}
func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){
return func(response *ErrorSignal) {
if response.Error != error_string {
t.Fatalf("Wrong error: %s", response.Error)
}
}
}
func testSuccess(*SuccessSignal){}
func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) ResponseSignal {
source_listener, err := GetExt[*ListenerExt](source, ListenerExtType)
fatalErr(t, err)
messages := Messages{}
messages = messages.Add(ctx, source.ID, source.Key, signal, destination.ID)
fatalErr(t, ctx.Send(messages))
response, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID())
fatalErr(t, err)
return response
}
func TestACLBasic(t *testing.T) {
ctx := logTestContext(t, []string{"test", "acl"})
listener, err := NewNode(ctx, nil, BaseNodeType, 100, nil, NewListenerExt(100))
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied"))
testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess)
group, err := NewNode(ctx, nil, GroupNodeType, 100, []Policy{
DefaultGroupPolicy,
NewPerNodePolicy(map[NodeID]Tree{
listener.ID: {
SerializedType(AddMemberSignalType): nil,
},
}),
}, NewGroupExt(nil))
fatalErr(t, err)
testSendACL(t, ctx, listener, nil, []Policy{
NewMemberOfPolicy(map[NodeID]Tree{
group.ID: nil,
}),
}, testErrorSignal(t, "acl_denied"))
add_member_signal := NewAddMemberSignal(listener.ID)
add_member_response := testSend(t, ctx, add_member_signal, listener, group)
checkSignal(t, add_member_response, testSuccess)
testSendACL(t, ctx, listener, nil, []Policy{
NewMemberOfPolicy(map[NodeID]Tree{
group.ID: nil,
}),
}, testSuccess)
}

@ -1279,6 +1279,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(reflect.TypeOf((*ACLExt)(nil)), ACLExtType, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterPolicy(reflect.TypeOf(MemberOfPolicy{}), MemberOfPolicyType) err = ctx.RegisterPolicy(reflect.TypeOf(MemberOfPolicy{}), MemberOfPolicyType)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1294,6 +1299,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(reflect.TypeOf(ACLTimeoutSignal{}), ACLTimeoutSignalType)
if err != nil {
return nil, err
}
err = ctx.RegisterSignal(reflect.TypeOf(ACLSignal{}), ACLSignalType)
if err != nil {
return nil, err
}
err = ctx.RegisterSignal(reflect.TypeOf(RemoveMemberSignal{}), RemoveMemberSignalType) err = ctx.RegisterSignal(reflect.TypeOf(RemoveMemberSignal{}), RemoveMemberSignalType)
if err != nil { if err != nil {
return nil, err return nil, err

@ -40,7 +40,7 @@ func NewRemoveMemberSignal(member_id NodeID) *RemoveMemberSignal {
} }
} }
var GroupReadPolicy = NewAllNodesPolicy(Tree{ var DefaultGroupPolicy = NewAllNodesPolicy(Tree{
SerializedType(ReadSignalType): { SerializedType(ReadSignalType): {
SerializedType(GroupExtType): { SerializedType(GroupExtType): {
Hash(FieldNameBase, "members"): nil, Hash(FieldNameBase, "members"): nil,

@ -70,18 +70,19 @@ func (q QueuedSignal) String() string {
type PendingACL struct { type PendingACL struct {
Counter int Counter int
Responses []ResponseSignal
TimeoutID uuid.UUID TimeoutID uuid.UUID
Action Tree Action Tree
Principal NodeID Principal NodeID
Messages Messages
Responses []Signal
Signal Signal Signal Signal
Source NodeID Source NodeID
} }
type PendingSignal struct { type PendingSignal struct {
Policy uuid.UUID Policy uuid.UUID
Found bool Timeout uuid.UUID
ID uuid.UUID ID uuid.UUID
} }
@ -314,10 +315,20 @@ func nodeLoop(ctx *Context, node *Node) error {
for policy_type, sigs := range(pends) { for policy_type, sigs := range(pends) {
for _, m := range(sigs) { for _, m := range(sigs) {
msgs = append(msgs, m) msgs = append(msgs, m)
node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()} timeout_signal := NewTimeoutSignal(m.Signal.ID())
node.QueueSignal(time.Now().Add(time.Second), timeout_signal)
node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, timeout_signal.Id, msg.Signal.ID()}
}
} }
node.PendingACLs[msg.Signal.ID()] = PendingACL{
Counter: len(msgs),
TimeoutID: timeout_signal.ID(),
Action: msg.Signal.Permission(),
Principal: princ_id,
Responses: []ResponseSignal{},
Signal: msg.Signal,
Source: msg.Source,
} }
node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source}
ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs) ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs)
ctx.Send(msgs) ctx.Send(msgs)
continue continue
@ -364,14 +375,13 @@ func nodeLoop(ctx *Context, node *Node) error {
if ok == true { if ok == true {
info, waiting := node.PendingSignals[response.ResponseID()] info, waiting := node.PendingSignals[response.ResponseID()]
if waiting == true { if waiting == true {
if info.Found == false { delete(node.PendingSignals, response.ResponseID())
info.Found = true
node.PendingSignals[response.ResponseID()] = info
ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal)
req_info, exists := node.PendingACLs[info.ID] req_info, exists := node.PendingACLs[info.ID]
if exists == true { if exists == true {
req_info.Counter -= 1 req_info.Counter -= 1
req_info.Responses = append(req_info.Responses, signal) req_info.Responses = append(req_info.Responses, response)
idx := -1 idx := -1
for i, p := range(node.Policies) { for i, p := range(node.Policies) {
@ -391,7 +401,7 @@ func nodeLoop(ctx *Context, node *Node) error {
source = req_info.Source source = req_info.Source
err := node.DequeueSignal(req_info.TimeoutID) err := node.DequeueSignal(req_info.TimeoutID)
if err != nil { if err != nil {
panic("dequeued a passed signal") ctx.Log.Logf("node", "dequeue error: %s", err)
} }
delete(node.PendingACLs, info.ID) delete(node.PendingACLs, info.ID)
} else if req_info.Counter == 0 { } else if req_info.Counter == 0 {
@ -416,7 +426,6 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
} }
} }
}
switch sig := signal.(type) { switch sig := signal.(type) {
case *StopSignal: case *StopSignal:

@ -100,6 +100,7 @@ func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si
if ok == false { if ok == false {
return Deny return Deny
} }
ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions)
group_ext_data, ok := sig.Extensions[GroupExtType] group_ext_data, ok := sig.Extensions[GroupExtType]
if ok == false { if ok == false {
@ -116,12 +117,12 @@ func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si
return Deny return Deny
} }
members, ok := members_if.Interface().(map[NodeID]string) members, ok := members_if.Interface().([]NodeID)
if ok == false { if ok == false {
return Deny return Deny
} }
for member := range(members) { for _, member := range(members) {
if member == current.Principal { if member == current.Principal {
return policy.NodeRules[sig.NodeID].Allows(current.Action) return policy.NodeRules[sig.NodeID].Allows(current.Action)
} }

@ -83,6 +83,7 @@ var (
LockableExtType = NewExtType("LOCKABLE") LockableExtType = NewExtType("LOCKABLE")
GQLExtType = NewExtType("GQL") GQLExtType = NewExtType("GQL")
GroupExtType = NewExtType("GROUP") GroupExtType = NewExtType("GROUP")
ACLExtType = NewExtType("ACL")
GQLNodeType = NewNodeType("GQL") GQLNodeType = NewNodeType("GQL")
BaseNodeType = NewNodeType("BASE") BaseNodeType = NewNodeType("BASE")
@ -102,6 +103,7 @@ var (
ReadResultSignalType = NewSignalType("READ_RESULT") ReadResultSignalType = NewSignalType("READ_RESULT")
RemoveMemberSignalType = NewSignalType("REMOVE_MEMBER") RemoveMemberSignalType = NewSignalType("REMOVE_MEMBER")
AddMemberSignalType = NewSignalType("ADD_MEMBER") AddMemberSignalType = NewSignalType("ADD_MEMBER")
ACLSignalType = NewSignalType("ACL")
MemberOfPolicyType = NewPolicyType("USER_OF") MemberOfPolicyType = NewPolicyType("USER_OF")
RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF")

@ -15,12 +15,12 @@ const (
) )
type TimeoutSignal struct { type TimeoutSignal struct {
SignalHeader ResponseHeader
} }
func NewTimeoutSignal() *TimeoutSignal { func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal {
return &TimeoutSignal{ return &TimeoutSignal{
NewSignalHeader(Direct), NewResponseHeader(req_id, Direct),
} }
} }