From 8d1e2733314db491b8cc89743a3ff3ab7a48837b Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sun, 5 Nov 2023 21:18:14 -0700 Subject: [PATCH] Reworked changes to include map[ExtType][]string --- acl.go | 12 +++++------ context.go | 8 +++++++- event.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++ gql.go | 6 +++++- gql_test.go | 6 ++++-- group.go | 8 ++++---- lockable.go | 24 +++++++++++----------- node.go | 20 ++++++++++++++---- node_test.go | 6 +++++- serialize.go | 13 +++++++++++- serialize_test.go | 2 -- signal.go | 2 +- 12 files changed, 124 insertions(+), 35 deletions(-) diff --git a/acl.go b/acl.go index c9c9665..3e8a5bd 100644 --- a/acl.go +++ b/acl.go @@ -51,7 +51,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa var changes Changes = nil info, waiting := ext.Pending[response.ResponseID()] if waiting == true { - changes = changes.Add("response_processed") + changes.Add(ACLExtType, "pending") delete(ext.Pending, response.ResponseID()) if response.ID() != info.Timeout { err := node.DequeueSignal(info.Timeout) @@ -78,26 +78,26 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } } else { if ext.Policies[policy_index].ContinueAllows(ctx, acl_info, response) == Allow { + changes.Add(ACLExtType, "pending_acls") delete(ext.PendingACLs, info.ID) ctx.Log.Logf("acl", "Request delayed allow") messages = messages.Add(ctx, acl_info.Source, node, nil, NewSuccessSignal(info.ID)) - changes = changes.Add("acl_passed") err := node.DequeueSignal(acl_info.TimeoutID) if err != nil { ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) } } else if acl_info.Counter == 0 { + changes.Add(ACLExtType, "pending_acls") delete(ext.PendingACLs, info.ID) ctx.Log.Logf("acl", "Request delayed deny") messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(info.ID, "acl_denied")) - changes = changes.Add("acl_blocked") err := node.DequeueSignal(acl_info.TimeoutID) if err != nil { ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) } } else { node.PendingACLs[info.ID] = acl_info - changes = changes.Add("acl_processed") + changes.Add(ACLExtType, "pending_acls") } } } @@ -136,7 +136,7 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa messages = messages.Add(ctx, source, node, nil, NewErrorSignal(sig.Id, "acl_denied")) } else if acl_messages != nil { ctx.Log.Logf("acl", "Request pending") - changes = changes.Add("acl_pending") + changes.Add(ACLExtType, "pending") total_messages := 0 // TODO: reasonable timeout/configurable timeout_time := time.Now().Add(time.Second) @@ -175,9 +175,9 @@ func (ext *ACLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa acl_info, exists := ext.PendingACLs[sig.ReqID] if exists == true { delete(ext.PendingACLs, sig.ReqID) + changes.Add(ACLExtType, "pending_acls") ctx.Log.Logf("acl", "Request timeout deny") messages = messages.Add(ctx, acl_info.Source, node, nil, NewErrorSignal(sig.ReqID, "acl_timeout")) - changes = changes.Add("acl_timeout") err := node.DequeueSignal(acl_info.TimeoutID) if err != nil { ctx.Log.Logf("acl", "acl proxy timeout dequeue error: %s", err) diff --git a/context.go b/context.go index 63949f1..b6b7f54 100644 --- a/context.go +++ b/context.go @@ -315,6 +315,7 @@ func (ctx *Context) getNode(id NodeID) (*Node, error) { // Route Messages to dest. Currently only local context routing is supported func (ctx *Context) Send(messages Messages) error { for _, msg := range(messages) { + ctx.Log.Logf("signal", "Sending %s -> %+v", msg.Dest, msg) if msg.Dest == ZeroID { panic("Can't send to null ID") } @@ -480,7 +481,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(Changes{}), ChangesSerialized, SerializeTypeStub, SerializeSlice, DeserializeTypeStub[Changes], DeserializeSlice) + err = ctx.RegisterType(reflect.TypeOf(Changes{}), ChangesSerialized, SerializeTypeStub, SerializeMap, DeserializeTypeStub[Changes], DeserializeMap) if err != nil { return nil, err } @@ -613,6 +614,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterPolicy(reflect.TypeOf(ParentOfPolicy{}), ParentOfPolicyType) + if err != nil { + return nil, err + } + err = ctx.RegisterPolicy(reflect.TypeOf(MemberOfPolicy{}), MemberOfPolicyType) if err != nil { return nil, err diff --git a/event.go b/event.go index b9aa399..bca6fd5 100644 --- a/event.go +++ b/event.go @@ -5,6 +5,40 @@ import ( "fmt" ) +type ParentOfPolicy struct { + PolicyHeader + Policy Tree +} + +func NewParentOfPolicy(policy Tree) *ParentOfPolicy { + return &ParentOfPolicy{ + PolicyHeader: NewPolicyHeader(), + Policy: policy, + } +} + +func (policy ParentOfPolicy) Allows(ctx *Context, principal_id NodeID, action Tree, node *Node)(Messages, RuleResult) { + event_ext, err := GetExt[*EventExt](node, EventExtType) + if err != nil { + ctx.Log.Logf("event", "ParentOfPolicy, node not event %s", node.ID) + return nil, Deny + } + + if event_ext.Parent == principal_id { + return nil, policy.Policy.Allows(action) + } + + return nil, Deny +} + +func (policy ParentOfPolicy) ContinueAllows(ctx *Context, current PendingACL, signal Signal) RuleResult { + return Deny +} + +var DefaultEventPolicy = NewParentOfPolicy(Tree{ + SerializedType(EventControlSignalType): nil, +}) + type EventExt struct { Name string `gv:"name"` State string `gv:"state"` @@ -148,3 +182,21 @@ func (ext *TestEventExt) Process(ctx *Context, node *Node, source NodeID, signal return messages, changes } + +type TransitionValidation struct { + ToState string +} + +func(ext *EventExt) ValidateEventCommand(signal *EventControlSignal, commands map[string]map[string]string) (string, *ErrorSignal) { + transitions, command_mapped := commands[signal.Command] + if command_mapped == false { + return "", NewErrorSignal(signal.Id, "unknown command %s", signal.Command) + } else { + new_state, valid_transition := transitions[ext.State] + if valid_transition == false { + return "", NewErrorSignal(signal.Id, "invalid command state %s(%s)", signal.Command, ext.State) + } else { + return new_state, nil + } + } +} diff --git a/gql.go b/gql.go index c7a6c9a..239c8f2 100644 --- a/gql.go +++ b/gql.go @@ -1562,6 +1562,7 @@ type GQLExt struct { resolver_response map[uuid.UUID]chan Signal resolver_response_lock sync.RWMutex + State string `gv:"state"` TLSKey []byte `gv:"tls_key"` TLSCert []byte `gv:"tls_cert"` Listen string `gv:"listen"` @@ -1692,7 +1693,8 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa ctx.Log.Logf("gql", "starting gql server %s", node.ID) err := ext.StartGQLServer(ctx, node) if err == nil { - changes = changes.Add("server_started") + ctx.Log.Logf("gql", "started gql server on %s", ext.Listen) + changes.Add(GQLExtType, "state") } else { ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) } @@ -1873,6 +1875,7 @@ func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error { ext.tcp_listener = l ext.http_server = http_server + ext.State = "running" return nil } @@ -1884,5 +1887,6 @@ func (ext *GQLExt) StopGQLServer() error { ext.http_done.Wait() ext.tcp_listener = nil ext.http_server = nil + ext.State = "stopped" return nil } diff --git a/gql_test.go b/gql_test.go index fff7759..d63d13e 100644 --- a/gql_test.go +++ b/gql_test.go @@ -43,7 +43,7 @@ func TestGQLAuth(t *testing.T) { } func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "deserialize_types", "serialize_types", "gqlws"}) + ctx := logTestContext(t, []string{"test", "deserialize_types", "serialize_types", "gqlws", "gql"}) TestNodeType := NewNodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) @@ -218,7 +218,9 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "SUB: %s", resp[:n]) msgs := Messages{} - msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, Changes{"test_status"})) + test_changes := Changes{} + test_changes.Add(GQLExtType, "state") + msgs = msgs.Add(ctx, gql.ID, gql, nil, NewStatusSignal(gql.ID, test_changes)) err = ctx.Send(msgs) fatalErr(t, err) diff --git a/group.go b/group.go index b5f12b5..b865e20 100644 --- a/group.go +++ b/group.go @@ -240,7 +240,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig ext.SubGroups[sig.SubGroup] = sub_group messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - changes = changes.Add("member_added") + changes.Add(GroupExtType, "sub_groups") } } @@ -257,7 +257,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig ext.SubGroups[sig.SubGroup] = sub_group messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) - changes = changes.Add("member_removed") + changes.Add(GroupExtType, "sub_groups") } } @@ -268,7 +268,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig } else { ext.SubGroups[sig.Name] = []NodeID{} - changes = changes.Add("subgroup_added") + changes.Add(GroupExtType, "sub_groups") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) } case *RemoveSubGroupSignal: @@ -278,7 +278,7 @@ func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Sig } else { delete(ext.SubGroups, sig.Name) - changes = changes.Add("subgroup_removed") + changes.Add(GroupExtType, "sub_groups") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(sig.Id)) } } diff --git a/lockable.go b/lockable.go index cefa010..b8f0228 100644 --- a/lockable.go +++ b/lockable.go @@ -113,7 +113,7 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID ext.Requirements = map[NodeID]ReqState{} } ext.Requirements[signal.NodeID] = Unlocked - changes = changes.Add("requirement_added") + changes.Add(LockableExtType, "requirements") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } case "remove": @@ -122,7 +122,7 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "can't link: not_requirement")) } else { delete(ext.Requirements, signal.NodeID) - changes = changes.Add("requirement_removed") + changes.Add(LockableExtType, "requirements") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } default: @@ -163,10 +163,10 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod ctx.Log.Logf("lockable", "WHOLE LOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) ext.State = Locked ext.Owner = ext.PendingOwner - changes = changes.Add("locked") + changes.Add(LockableExtType, "state", "owner", "requirements") messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) } else { - changes = changes.Add("partial_lock") + changes.Add(LockableExtType, "requirements") ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements)) } case AbortingLock: @@ -199,15 +199,15 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod previous_owner := *ext.Owner ext.Owner = ext.PendingOwner ext.ReqID = nil - changes = changes.Add("unlocked") + changes.Add(LockableExtType, "state", "owner", "req_id") messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) } else if old_state == AbortingLock { - changes = changes.Add("lock_aborted") + changes.Add(LockableExtType, "state", "pending_owner") messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) ext.PendingOwner = ext.Owner } } else { - changes = changes.Add("partial_unlock") + changes.Add(LockableExtType, "state") ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, len(ext.Requirements)) } } @@ -231,7 +231,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = &new_owner ext.Owner = &new_owner - changes = changes.Add("locked") + changes.Add(LockableExtType, "state", "pending_owner", "owner") messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) } else { ext.State = Locking @@ -240,7 +240,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = &new_owner ext.PendingID = signal.ID() - changes = changes.Add("locking") + changes.Add(LockableExtType, "state", "req_id", "pending_owner", "pending_id") for id, state := range(ext.Requirements) { if state != Unlocked { ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") @@ -264,7 +264,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = nil ext.Owner = nil - changes = changes.Add("unlocked") + changes.Add(LockableExtType, "state", "pending_owner", "owner") messages = messages.Add(ctx, new_owner, node, nil, NewSuccessSignal(signal.ID())) } else if source == *ext.Owner { ext.State = Unlocking @@ -272,7 +272,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.ReqID = &id ext.PendingOwner = nil ext.PendingID = signal.ID() - changes = changes.Add("unlocking") + changes.Add(LockableExtType, "state", "pending_owner", "pending_id", "req_id") for id, state := range(ext.Requirements) { if state != Locked { ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") @@ -405,7 +405,7 @@ func (policy RequirementOfPolicy) ContinueAllows(ctx *Context, current PendingAC return Deny } - for req, _ := range(requirements) { + for req := range(requirements) { if req == current.Principal { return policy.NodeRules[sig.NodeID].Allows(current.Action) } diff --git a/node.go b/node.go index f6e01f3..0291cd4 100644 --- a/node.go +++ b/node.go @@ -46,10 +46,18 @@ func RandID() NodeID { return NodeID(uuid.New()) } -type Changes []string +type Changes map[ExtType][]string -func (changes Changes) Add(detail string) Changes { - return append(changes, detail) +func (changes *Changes) Add(ext ExtType, fields ...string) { + if *changes == nil { + *changes = Changes{} + } + current, exists := (*changes)[ext] + if exists == false { + current = []string{} + } + current = append(current, fields...) + (*changes)[ext] = current } // Extensions are data attached to nodes that process signals @@ -533,14 +541,18 @@ func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { for ext_type, ext := range(node.Extensions) { ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) ext_messages, ext_changes := ext.Process(ctx, node, source, signal) + ctx.Log.Logf("gql", "%s changes %+v", reflect.TypeOf(ext), ext_changes) if len(ext_messages) != 0 { messages = append(messages, ext_messages...) } if len(ext_changes) != 0 { - changes = append(changes, ext_changes...) + for ext, change_list := range(ext_changes) { + changes[ext] = append(changes[ext], change_list...) + } } } + ctx.Log.Logf("gql", "changes after process %+v", changes) if len(messages) != 0 { send_err := ctx.Send(messages) diff --git a/node_test.go b/node_test.go index e2e51c4..69087d5 100644 --- a/node_test.go +++ b/node_test.go @@ -19,7 +19,11 @@ func TestNodeDB(t *testing.T) { fatalErr(t, err) _, err = WaitForSignal(node_listener.Chan, 10*time.Millisecond, func(sig *StatusSignal) bool { - return slices.Contains(sig.Changes, "started") && sig.Source == node.ID + gql_changes, has_gql := sig.Changes[GQLExtType] + if has_gql == true { + return slices.Contains(gql_changes, "state") && sig.Source == node.ID + } + return false }) msgs := Messages{} diff --git a/serialize.go b/serialize.go index 4927b0a..e0020f7 100644 --- a/serialize.go +++ b/serialize.go @@ -59,6 +59,12 @@ func (t PolicyType) String() string { return fmt.Sprintf("0x%x", uint64(t)) } +type FieldTag SerializedType + +func (t FieldTag) String() string { + return fmt.Sprintf("0x%x", uint64(t)) +} + type Chunk struct { Data []byte Next *Chunk @@ -190,6 +196,10 @@ func NewPolicyType(name string) PolicyType { return PolicyType(Hash(PolicyTypeBase, name)) } +func NewFieldTag(tag_string string) FieldTag { + return FieldTag(Hash(FieldNameBase, tag_string)) +} + func NewSerializedType(name string) SerializedType { return Hash(SerializedTypeBase, name) } @@ -227,7 +237,8 @@ var ( EventControlSignalType = NewSignalType("EVENT_CONTORL") EventStateSignalType = NewSignalType("VEX_MATCH_STATUS") - MemberOfPolicyType = NewPolicyType("USER_OF") + MemberOfPolicyType = NewPolicyType("MEMBER_OF") + ParentOfPolicyType = NewPolicyType("PARENT_OF") RequirementOfPolicyType = NewPolicyType("REQUIEMENT_OF") PerNodePolicyType = NewPolicyType("PER_NODE") AllNodesPolicyType = NewPolicyType("ALL_NODES") diff --git a/serialize_test.go b/serialize_test.go index bf5b052..15796b4 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -119,8 +119,6 @@ func TestSerializeBasic(t *testing.T) { testSerialize[[]string](t, ctx, []string{"test_1", "test_2", "test_3"}) testSerialize[test_slice](t, ctx, test_slice{"test_1", "test_2", "test_3"}) - - testSerialize[Changes](t, ctx, Changes{"change_1", "change_2", "change_3"}) } type test struct { diff --git a/signal.go b/signal.go index c1b157f..c56190f 100644 --- a/signal.go +++ b/signal.go @@ -238,7 +238,7 @@ func (signal ErrorSignal) Permission() Tree { }, } } -func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) Signal { +func NewErrorSignal(req_id uuid.UUID, fmt_string string, args ...interface{}) *ErrorSignal { return &ErrorSignal{ NewResponseHeader(req_id, Direct), fmt.Sprintf(fmt_string, args...),