diff --git a/acl.go b/acl.go new file mode 100644 index 0000000..a3b85b1 --- /dev/null +++ b/acl.go @@ -0,0 +1,193 @@ +package graphvent + +import ( + "github.com/google/uuid" + "slices" + "time" +) + +type ACLSignal struct { + SignalHeader + Principal NodeID `gv:"principal"` + Action Tree `gv:"tree"` +} + +func NewACLSignal(principal NodeID, action Tree) *ACLSignal { + return &ACLSignal{ + SignalHeader: NewSignalHeader(Direct), + Principal: principal, + Action: action, + } +} + +var DefaultACLPolicy = NewAllNodesPolicy(Tree{ + SerializedType(ACLSignalType): nil, +}) + +func (signal ACLSignal) Permission() Tree { + return Tree{ + SerializedType(ACLSignalType): nil, + } +} + +type ACLExt struct { + Policies []Policy `gv:"policies"` + PendingACLs map[uuid.UUID]PendingACL `gv:"pending_acls"` + Pending map[uuid.UUID]PendingSignal `gv:"pending"` +} + +func NewACLExt(policies []Policy) *ACLExt { + return &ACLExt{ + Policies: policies, + PendingACLs: map[uuid.UUID]PendingACL{}, + Pending: map[uuid.UUID]PendingSignal{}, + } +} + +func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { + response, is_response := signal.(ResponseSignal) + if is_response == true { + var messages Messages = nil + var changes Changes = nil + info, waiting := ext.Pending[response.ResponseID()] + if waiting == true { + changes = changes.Add("response_processed") + delete(ext.Pending, response.ResponseID()) + if response.ID() != info.Timeout { + err := node.DequeueSignal(info.Timeout) + if err != nil { + ctx.Log.Logf("acl", "timeout dequeue error: %s", err) + } + } + + acl_info, found := ext.PendingACLs[info.ID] + if found == true { + acl_info.Counter -= 1 + acl_info.Responses = append(acl_info.Responses, response) + + policy_index := slices.IndexFunc(ext.Policies, func(policy Policy) bool { + return policy.ID() == info.Policy + }) + + if policy_index == -1 { + ctx.Log.Logf("acl", "pending signal for nonexistent policy") + delete(ext.PendingACLs, info.ID) + err := node.DequeueSignal(acl_info.TimeoutID) + if err != nil { + ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) + } + } else { + if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow { + delete(ext.PendingACLs, info.ID) + ctx.Log.Logf("acl", "Request delayed allow") + messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(info.ID), acl_info.Source) + changes = changes.Add("acl_passed") + err := node.DequeueSignal(acl_info.TimeoutID) + if err != nil { + ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) + } + } else if acl_info.Counter == 0 { + delete(ext.PendingACLs, info.ID) + ctx.Log.Logf("acl", "Request delayed deny") + messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(info.ID, "acl_denied"), acl_info.Source) + changes = changes.Add("acl_blocked") + err := node.DequeueSignal(acl_info.TimeoutID) + if err != nil { + ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) + } + } else { + node.PendingACLs[info.ID] = acl_info + changes = changes.Add("acl_processed") + } + } + } + } + return messages, changes + } + + var messages Messages = nil + var changes Changes = nil + + switch sig := signal.(type) { + case *ACLSignal: + var acl_messages map[uuid.UUID]Messages = nil + denied := true + for _, policy := range(ext.Policies) { + policy_messages, result := policy.Allows(ctx, sig.Principal, sig.Action, node) + if result == Allow { + denied = false + break + } else if result == Pending { + if len(policy_messages) == 0 { + ctx.Log.Logf("acl", "Pending result for %s with no messages returned", policy.ID()) + continue + } else if acl_messages == nil { + acl_messages = map[uuid.UUID]Messages{} + denied = false + } + + acl_messages[policy.ID()] = policy_messages + ctx.Log.Logf("acl", "Pending result for %s:%s - %+v", node.ID, policy.ID(), acl_messages) + } + } + + if denied == true { + ctx.Log.Logf("acl", "Request denied") + messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.Id, "acl_denied"), source) + } else if acl_messages != nil { + ctx.Log.Logf("acl", "Request pending") + changes = changes.Add("acl_pending") + total_messages := 0 + // TODO: reasonable timeout/configurable + timeout_time := time.Now().Add(time.Second) + for policy_id, policy_messages := range(acl_messages) { + total_messages += len(policy_messages) + for _, message := range(policy_messages) { + // Create timeout signal and add the ID to Pending + timeout_signal := NewTimeoutSignal(message.Signal.ID()) + ext.Pending[message.Signal.ID()] = PendingSignal{ + Policy: policy_id, + Timeout: timeout_signal.Id, + ID: sig.Id, + } + node.QueueSignal(timeout_time, timeout_signal) + messages = append(messages, message) + } + } + + acl_timeout := NewACLTimeoutSignal(sig.Id) + node.QueueSignal(timeout_time, acl_timeout) + ext.PendingACLs[sig.Id] = PendingACL{ + Counter: total_messages, + Responses: []ResponseSignal{}, + TimeoutID: acl_timeout.Id, + Action: sig.Action, + Principal: sig.Principal, + + Source: source, + Signal: signal, + } + } else { + ctx.Log.Logf("acl", "Request allowed") + messages = messages.Add(ctx, node.ID, node.Key, NewSuccessSignal(sig.Id), source) + } + // Test an action against the policy list, sending any intermediate signals necessary and seeting Pending and PendingACLs accordingly. Add a TimeoutSignal for every message awaiting a response, and an ACLTimeoutSignal for the overall request + case *ACLTimeoutSignal: + acl_info, exists := ext.PendingACLs[sig.ReqID] + if exists == true { + delete(ext.PendingACLs, sig.ReqID) + ctx.Log.Logf("acl", "Request timeout deny") + messages = messages.Add(ctx, node.ID, node.Key, NewErrorSignal(sig.ReqID, "acl_timeout"), acl_info.Source) + changes = changes.Add("acl_timeout") + err := node.DequeueSignal(acl_info.TimeoutID) + if err != nil { + ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) + } + } else { + ctx.Log.Logf("acl", "ACL_TIMEOUT_SIGNAL for passed acl") + } + // Delete from PendingACLs + } + + return messages, changes +} diff --git a/acl_test.go b/acl_test.go new file mode 100644 index 0000000..b28ea5d --- /dev/null +++ b/acl_test.go @@ -0,0 +1,91 @@ +package graphvent + +import ( + "testing" + "time" + "reflect" +) + +func checkSignal[S Signal](t *testing.T, signal Signal, check func(S)){ + response_casted, cast_ok := signal.(S) + if cast_ok == false { + error_signal, is_error := signal.(*ErrorSignal) + if is_error { + t.Fatal(error_signal.Error) + } + t.Fatalf("Response of wrong type %s", reflect.TypeOf(signal)) + } + + check(response_casted) +} + +func testSendACL[S Signal](t *testing.T, ctx *Context, listener *Node, action Tree, policies []Policy, check func(S)){ + acl_node, err := NewNode(ctx, nil, BaseNodeType, 100, []Policy{DefaultACLPolicy}, NewACLExt(policies)) + fatalErr(t, err) + + acl_signal := NewACLSignal(listener.ID, action) + response := testSend(t, ctx, acl_signal, listener, acl_node) + + checkSignal(t, response, check) +} + +func testErrorSignal(t *testing.T, error_string string) func(*ErrorSignal){ + return func(response *ErrorSignal) { + if response.Error != error_string { + t.Fatalf("Wrong error: %s", response.Error) + } + } +} + +func testSuccess(*SuccessSignal){} + +func testSend(t *testing.T, ctx *Context, signal Signal, source, destination *Node) ResponseSignal { + source_listener, err := GetExt[*ListenerExt](source, ListenerExtType) + fatalErr(t, err) + + messages := Messages{} + messages = messages.Add(ctx, source.ID, source.Key, signal, destination.ID) + fatalErr(t, ctx.Send(messages)) + + response, err := WaitForResponse(source_listener.Chan, time.Millisecond*10, signal.ID()) + fatalErr(t, err) + + return response +} + +func TestACLBasic(t *testing.T) { + ctx := logTestContext(t, []string{"test", "acl"}) + + listener, err := NewNode(ctx, nil, BaseNodeType, 100, nil, NewListenerExt(100)) + fatalErr(t, err) + + testSendACL(t, ctx, listener, nil, nil, testErrorSignal(t, "acl_denied")) + + testSendACL(t, ctx, listener, nil, []Policy{NewAllNodesPolicy(nil)}, testSuccess) + + group, err := NewNode(ctx, nil, GroupNodeType, 100, []Policy{ + DefaultGroupPolicy, + NewPerNodePolicy(map[NodeID]Tree{ + listener.ID: { + SerializedType(AddMemberSignalType): nil, + }, + }), + }, NewGroupExt(nil)) + fatalErr(t, err) + + testSendACL(t, ctx, listener, nil, []Policy{ + NewMemberOfPolicy(map[NodeID]Tree{ + group.ID: nil, + }), + }, testErrorSignal(t, "acl_denied")) + + add_member_signal := NewAddMemberSignal(listener.ID) + add_member_response := testSend(t, ctx, add_member_signal, listener, group) + checkSignal(t, add_member_response, testSuccess) + + testSendACL(t, ctx, listener, nil, []Policy{ + NewMemberOfPolicy(map[NodeID]Tree{ + group.ID: nil, + }), + }, testSuccess) +} diff --git a/context.go b/context.go index b4320d0..68fbbd9 100644 --- a/context.go +++ b/context.go @@ -1279,6 +1279,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterExtension(reflect.TypeOf((*ACLExt)(nil)), ACLExtType, nil) + if err != nil { + return nil, err + } + err = ctx.RegisterPolicy(reflect.TypeOf(MemberOfPolicy{}), MemberOfPolicyType) if err != nil { return nil, err @@ -1294,6 +1299,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterSignal(reflect.TypeOf(ACLTimeoutSignal{}), ACLTimeoutSignalType) + if err != nil { + return nil, err + } + + err = ctx.RegisterSignal(reflect.TypeOf(ACLSignal{}), ACLSignalType) + if err != nil { + return nil, err + } + err = ctx.RegisterSignal(reflect.TypeOf(RemoveMemberSignal{}), RemoveMemberSignalType) if err != nil { return nil, err diff --git a/group.go b/group.go index 3dfa572..77d81c6 100644 --- a/group.go +++ b/group.go @@ -40,7 +40,7 @@ func NewRemoveMemberSignal(member_id NodeID) *RemoveMemberSignal { } } -var GroupReadPolicy = NewAllNodesPolicy(Tree{ +var DefaultGroupPolicy = NewAllNodesPolicy(Tree{ SerializedType(ReadSignalType): { SerializedType(GroupExtType): { Hash(FieldNameBase, "members"): nil, diff --git a/node.go b/node.go index 7d20bbe..6e28d42 100644 --- a/node.go +++ b/node.go @@ -70,18 +70,19 @@ func (q QueuedSignal) String() string { type PendingACL struct { Counter int + Responses []ResponseSignal + TimeoutID uuid.UUID Action Tree Principal NodeID - Messages Messages - Responses []Signal + Signal Signal Source NodeID } type PendingSignal struct { Policy uuid.UUID - Found bool + Timeout uuid.UUID ID uuid.UUID } @@ -314,10 +315,20 @@ func nodeLoop(ctx *Context, node *Node) error { for policy_type, sigs := range(pends) { for _, m := range(sigs) { msgs = append(msgs, m) - node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, false, msg.Signal.ID()} + timeout_signal := NewTimeoutSignal(m.Signal.ID()) + node.QueueSignal(time.Now().Add(time.Second), timeout_signal) + node.PendingSignals[m.Signal.ID()] = PendingSignal{policy_type, timeout_signal.Id, msg.Signal.ID()} } } - node.PendingACLs[msg.Signal.ID()] = PendingACL{len(msgs), timeout_signal.ID(), msg.Signal.Permission(), princ_id, msgs, []Signal{}, msg.Signal, msg.Source} + node.PendingACLs[msg.Signal.ID()] = PendingACL{ + Counter: len(msgs), + TimeoutID: timeout_signal.ID(), + Action: msg.Signal.Permission(), + Principal: princ_id, + Responses: []ResponseSignal{}, + Signal: msg.Signal, + Source: msg.Source, + } ctx.Log.Logf("policy", "Sending signals for pending ACL: %+v", msgs) ctx.Send(msgs) continue @@ -364,54 +375,52 @@ func nodeLoop(ctx *Context, node *Node) error { if ok == true { info, waiting := node.PendingSignals[response.ResponseID()] if waiting == true { - if info.Found == false { - info.Found = true - node.PendingSignals[response.ResponseID()] = info - ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) - req_info, exists := node.PendingACLs[info.ID] - if exists == true { - req_info.Counter -= 1 - req_info.Responses = append(req_info.Responses, signal) - - idx := -1 - for i, p := range(node.Policies) { - if p.ID() == info.Policy { - idx = i - break - } + delete(node.PendingSignals, response.ResponseID()) + ctx.Log.Logf("pending", "FOUND_PENDING_SIGNAL: %s - %s", node.ID, signal) + + req_info, exists := node.PendingACLs[info.ID] + if exists == true { + req_info.Counter -= 1 + req_info.Responses = append(req_info.Responses, response) + + idx := -1 + for i, p := range(node.Policies) { + if p.ID() == info.Policy { + idx = i + break } - if idx == -1 { - ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy) + } + if idx == -1 { + ctx.Log.Logf("policy", "PENDING_FOR_NONEXISTENT_POLICY: %s - %s", node.ID, info.Policy) + delete(node.PendingACLs, info.ID) + } else { + allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal) + if allowed == Allow { + ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) + signal = req_info.Signal + source = req_info.Source + err := node.DequeueSignal(req_info.TimeoutID) + if err != nil { + ctx.Log.Logf("node", "dequeue error: %s", err) + } delete(node.PendingACLs, info.ID) - } else { - allowed := node.Policies[idx].ContinueAllows(ctx, req_info, signal) - if allowed == Allow { - ctx.Log.Logf("policy", "DELAYED_POLICY_ALLOW: %s - %s", node.ID, req_info.Signal) - signal = req_info.Signal - source = req_info.Source - err := node.DequeueSignal(req_info.TimeoutID) - if err != nil { - panic("dequeued a passed signal") - } - delete(node.PendingACLs, info.ID) - } else if req_info.Counter == 0 { - ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) - // Send the denied response - msgs := Messages{} - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) - err := ctx.Send(msgs) - if err != nil { - ctx.Log.Logf("signal", "SEND_ERR: %s", err) - } - err = node.DequeueSignal(req_info.TimeoutID) - if err != nil { - panic("dequeued a passed signal") - } - delete(node.PendingACLs, info.ID) - } else { - node.PendingACLs[info.ID] = req_info - continue + } else if req_info.Counter == 0 { + ctx.Log.Logf("policy", "DELAYED_POLICY_DENY: %s - %s", node.ID, req_info.Signal) + // Send the denied response + msgs := Messages{} + msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(req_info.Signal.ID(), "ACL_DENIED"), req_info.Source) + err := ctx.Send(msgs) + if err != nil { + ctx.Log.Logf("signal", "SEND_ERR: %s", err) + } + err = node.DequeueSignal(req_info.TimeoutID) + if err != nil { + panic("dequeued a passed signal") } + delete(node.PendingACLs, info.ID) + } else { + node.PendingACLs[info.ID] = req_info + continue } } } diff --git a/policy.go b/policy.go index c262247..69bd5f9 100644 --- a/policy.go +++ b/policy.go @@ -100,6 +100,7 @@ func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si if ok == false { return Deny } + ctx.Log.Logf("group", "member_of_read_result: %+v", sig.Extensions) group_ext_data, ok := sig.Extensions[GroupExtType] if ok == false { @@ -116,12 +117,12 @@ func (policy MemberOfPolicy) ContinueAllows(ctx *Context, current PendingACL, si return Deny } - members, ok := members_if.Interface().(map[NodeID]string) + members, ok := members_if.Interface().([]NodeID) if ok == false { return Deny } - for member := range(members) { + for _, member := range(members) { if member == current.Principal { return policy.NodeRules[sig.NodeID].Allows(current.Action) } diff --git a/serialize.go b/serialize.go index 9c17076..f760e0f 100644 --- a/serialize.go +++ b/serialize.go @@ -83,6 +83,7 @@ var ( LockableExtType = NewExtType("LOCKABLE") GQLExtType = NewExtType("GQL") GroupExtType = NewExtType("GROUP") + ACLExtType = NewExtType("ACL") GQLNodeType = NewNodeType("GQL") BaseNodeType = NewNodeType("BASE") @@ -102,6 +103,7 @@ var ( ReadResultSignalType = NewSignalType("READ_RESULT") RemoveMemberSignalType = NewSignalType("REMOVE_MEMBER") AddMemberSignalType = NewSignalType("ADD_MEMBER") + ACLSignalType = NewSignalType("ACL") MemberOfPolicyType = NewPolicyType("USER_OF") RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") diff --git a/signal.go b/signal.go index a027963..e6d1c51 100644 --- a/signal.go +++ b/signal.go @@ -15,12 +15,12 @@ const ( ) type TimeoutSignal struct { - SignalHeader + ResponseHeader } -func NewTimeoutSignal() *TimeoutSignal { +func NewTimeoutSignal(req_id uuid.UUID) *TimeoutSignal { return &TimeoutSignal{ - NewSignalHeader(Direct), + NewResponseHeader(req_id, Direct), } }