From e26ddcae374693fd40a705ad7ab78e8fc845c98e Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 11 Aug 2023 13:01:32 -0600 Subject: [PATCH] Moved listener to listener.go and user.go to group.go. Fixed some GQL resolving --- gql.go | 76 ++++++------- gql_test.go | 12 +- graph_test.go | 2 +- user.go => group.go | 17 +-- listener.go | 56 +++++++++ lockable.go | 271 +++++--------------------------------------- lockable_test.go | 122 ++++++-------------- node.go | 37 ++---- 8 files changed, 171 insertions(+), 422 deletions(-) rename user.go => group.go (72%) create mode 100644 listener.go diff --git a/gql.go b/gql.go index 34585be..f072e37 100644 --- a/gql.go +++ b/gql.go @@ -10,6 +10,7 @@ import ( "github.com/graphql-go/graphql/language/ast" "context" "encoding/json" + "encoding/base64" "io" "reflect" "fmt" @@ -190,17 +191,27 @@ type ResolveContext struct { } func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { - id_bytes, key_bytes, ok := r.BasicAuth() + id_b64, key_b64, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") } + id_bytes, err := base64.StdEncoding.DecodeString(id_b64) + if err != nil { + return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID bytes from auth username: %+v", id_b64) + } + auth_uuid, err := uuid.FromBytes([]byte(id_bytes)) if err != nil { - return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username %+v", id_bytes) + return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from id_bytes %+v", id_bytes) } auth_id := NodeID(auth_uuid) + key_bytes, err := base64.StdEncoding.DecodeString(key_b64) + if err != nil { + return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse key bytes from auth password: %+v", key_b64) + } + key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes)) if err != nil { return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ecdsa key from auth password: %s", key_bytes) @@ -632,7 +643,7 @@ type ListField struct { type SelfField struct { ACLName string Extension ExtType - ResolveFn func(graphql.ResolveParams, interface{}) (NodeID, error) + ResolveFn func(graphql.ResolveParams, interface{}) (*NodeID, error) } func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, interfaces []string, fields []string, self_fields map[string]SelfField, list_fields map[string]ListField) error { @@ -677,19 +688,22 @@ func (ctx *GQLExtContext) RegisterInterface(name string, default_name string, in return nil, err } - var zero NodeID id, err := self_field.ResolveFn(p, val) if err != nil { - return zero, err + return nil, err } - nodes, err := ResolveNodes(ctx, p, []NodeID{id}) - if err != nil { - return nil, err - } else if len(nodes) != 1 { - return nil, fmt.Errorf("wrong length of nodes returned") + if id != nil { + nodes, err := ResolveNodes(ctx, p, []NodeID{*id}) + if err != nil { + return nil, err + } else if len(nodes) != 1 { + return nil, fmt.Errorf("wrong length of nodes returned") + } + return nodes[0], nil + } else { + return nil, nil } - return nodes[0], nil }) if err != nil { return err @@ -844,16 +858,10 @@ func NewGQLExtContext() *GQLExtContext { "Owner": SelfField{ "owner", LockableExtType, - func(p graphql.ResolveParams, val interface{}) (NodeID, error) { - var zero NodeID - id_str, ok := val.(string) + func(p graphql.ResolveParams, val interface{}) (*NodeID, error) { + id, ok := val.(*NodeID) if ok == false { - return zero, fmt.Errorf("can't parse %+v as string", val) - } - - id, err := ParseID(id_str) - if err != nil { - return zero, err + return nil, fmt.Errorf("can't parse %+v as *NodeID", val) } return id, nil @@ -864,34 +872,14 @@ func NewGQLExtContext() *GQLExtContext { "requirements", LockableExtType, func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { - id_strs, ok := val.(LinkMap) + id_strs, ok := val.(map[NodeID]string) if ok == false { return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val)) } ids := []NodeID{} - for id, state := range(id_strs) { - if state.Link == "linked" { - ids = append(ids, id) - } - } - return ids, nil - }, - }, - "Dependencies": ListField{ - "dependencies", - LockableExtType, - func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { - id_strs, ok := val.(LinkMap) - if ok == false { - return nil, fmt.Errorf("can't parse dependencies %+v as string, %s", val, reflect.TypeOf(val)) - } - - ids := []NodeID{} - for id, state := range(id_strs) { - if state.Link == "linked" { - ids = append(ids, id) - } + for id, _ := range(id_strs) { + ids = append(ids, id) } return ids, nil }, @@ -909,7 +897,7 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } - err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node", "Lockable", "Group"}, []string{"Listen", "Owner", "Requirements", "Dependencies", "Members"}) + err = context.RegisterNodeType(GQLNodeType, "GQLServer", []string{"Node", "Lockable", "Group"}, []string{"Listen", "Owner", "Requirements", "Members"}) if err != nil { panic(err) } diff --git a/gql_test.go b/gql_test.go index ac3e1f7..3aa304d 100644 --- a/gql_test.go +++ b/gql_test.go @@ -5,6 +5,7 @@ import ( "time" "fmt" "encoding/json" + "encoding/base64" "io" "net/http" "net" @@ -66,12 +67,12 @@ func TestGQLServer(t *testing.T) { n1 := NewNode(ctx, nil, TestNodeType, 10, map[PolicyType]Policy{ MemberOfPolicyType: &user_policy_2, AllNodesPolicyType: &user_policy_1, - }, NewLockableExt()) + }, NewLockableExt(nil)) gql := NewNode(ctx, gql_key, GQLNodeType, 10, map[PolicyType]Policy{ MemberOfPolicyType: &group_policy_2, AllNodesPolicyType: &group_policy_1, - }, NewLockableExt(), gql_ext, NewGroupExt(map[NodeID]string{ + }, NewLockableExt([]NodeID{n1.ID}), gql_ext, NewGroupExt(map[NodeID]string{ n1.ID: "user", gql_id: "self", }), listener_ext) @@ -79,9 +80,6 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID) - err = LinkRequirement(ctx, gql, n1.ID) - fatalErr(t, err) - msgs := Messages{} msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID) err = ctx.Send(msgs) @@ -107,7 +105,7 @@ func TestGQLServer(t *testing.T) { } req_2 := GQLPayload{ - Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Members { ID } , Listen, Requirements { ID, TypeHash, Dependencies { ID } } } } }", + Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash, ... on GQLServer { Members { ID } , Listen, Requirements { ID, TypeHash Owner { ID } } } } }", Variables: map[string]interface{}{ "id": gql.ID.String(), }, @@ -123,7 +121,7 @@ func TestGQLServer(t *testing.T) { key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) fatalErr(t, err) - req.SetBasicAuth(string(n1.ID.Serialize()), string(key_bytes)) + req.SetBasicAuth(base64.StdEncoding.EncodeToString(n1.ID.Serialize()), base64.StdEncoding.EncodeToString(key_bytes)) resp, err := client.Do(req) fatalErr(t, err) diff --git a/graph_test.go b/graph_test.go index f13e528..f90a7bc 100644 --- a/graph_test.go +++ b/graph_test.go @@ -16,7 +16,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { 10, nil, listener_extension, - NewLockableExt()) + NewLockableExt(nil)) return listener, listener_extension } diff --git a/user.go b/group.go similarity index 72% rename from user.go rename to group.go index d14eb8c..e784fbe 100644 --- a/user.go +++ b/group.go @@ -5,11 +5,7 @@ import ( ) type GroupExt struct { - Members map[NodeID]string -} - -type GroupExtJSON struct { - Members map[string]string `json:"members"` + Members map[NodeID]string `json:"members"` } func (ext *GroupExt) Type() ExtType { @@ -17,9 +13,7 @@ func (ext *GroupExt) Type() ExtType { } func (ext *GroupExt) Serialize() ([]byte, error) { - return json.Marshal(&GroupExtJSON{ - Members: IDMap(ext.Members), - }) + return json.Marshal(ext) } func (ext *GroupExt) Field(name string) interface{} { @@ -41,11 +35,8 @@ func NewGroupExt(members map[NodeID]string) *GroupExt { } func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error { - var j GroupExtJSON - err := json.Unmarshal(data, &j) - - ext.Members, err = LoadIDMap(j.Members) - return err + ext.Members = map[NodeID]string{} + return json.Unmarshal(data, ext) } func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..10d026c --- /dev/null +++ b/listener.go @@ -0,0 +1,56 @@ +package graphvent + +import ( + "encoding/json" +) + +// A Listener extension provides a channel that can receive signals on a different thread +type ListenerExt struct { + Buffer int + Chan chan Signal +} + +// Create a new listener extension with a given buffer size +func NewListenerExt(buffer int) *ListenerExt { + return &ListenerExt{ + Buffer: buffer, + Chan: make(chan Signal, buffer), + } +} + +func (ext *ListenerExt) Field(name string) interface{} { + return ResolveFields(ext, name, map[string]func(*ListenerExt)interface{}{ + "buffer": func(ext *ListenerExt) interface{} { + return ext.Buffer + }, + "chan": func(ext *ListenerExt) interface{} { + return ext.Chan + }, + }) +} + +// Simple load function, unmarshal the buffer int from json +func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { + err := json.Unmarshal(data, &ext.Buffer) + ext.Chan = make(chan Signal, ext.Buffer) + return err +} + +func (listener *ListenerExt) Type() ExtType { + return ListenerExtType +} + +// Send the signal to the channel, logging an overflow if it occurs +func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { + ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) + select { + case ext.Chan <- signal: + default: + ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) + } + return nil +} + +func (ext *ListenerExt) Serialize() ([]byte, error) { + return json.Marshal(ext.Buffer) +} diff --git a/lockable.go b/lockable.go index b5c4b94..46ca4eb 100644 --- a/lockable.go +++ b/lockable.go @@ -4,94 +4,10 @@ import ( "encoding/json" ) -// A Listener extension provides a channel that can receive signals on a different thread -type ListenerExt struct { - Buffer int - Chan chan Signal -} - -// Create a new listener extension with a given buffer size -func NewListenerExt(buffer int) *ListenerExt { - return &ListenerExt{ - Buffer: buffer, - Chan: make(chan Signal, buffer), - } -} - -func (ext *ListenerExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*ListenerExt)interface{}{ - "buffer": func(ext *ListenerExt) interface{} { - return ext.Buffer - }, - "chan": func(ext *ListenerExt) interface{} { - return ext.Chan - }, - }) -} - -// Simple load function, unmarshal the buffer int from json -func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { - err := json.Unmarshal(data, &ext.Buffer) - ext.Chan = make(chan Signal, ext.Buffer) - return err -} - -func (listener *ListenerExt) Type() ExtType { - return ListenerExtType -} - -// Send the signal to the channel, logging an overflow if it occurs -func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { - ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) - select { - case ext.Chan <- signal: - default: - ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) - } - return nil -} - -// ReqState holds the multiple states of a requirement -type LinkState struct { - Link string `json:"link"` - Lock string `json:"lock"` - Initiator NodeID `json:"initiator"` -} - -// A LockableExt allows a node to be linked to other nodes(via LinkSignal) and locked/unlocked(via LockSignal) -type LinkMap map[NodeID]LinkState -func (m LinkMap) MarshalJSON() ([]byte, error) { - tmp := map[string]LinkState{} - for id, state := range(m) { - tmp[id.String()] = state - } - - return json.Marshal(tmp) -} - -func (m LinkMap) UnmarshalJSON(data []byte) error { - tmp := map[string]LinkState{} - err := json.Unmarshal(data, &tmp) - if err != nil { - return err - } - - for id_str, state := range(tmp) { - id, err := ParseID(id_str) - if err != nil { - return err - } - - m[id] = state - } - return nil -} - type LockableExt struct { Owner *NodeID `json:"owner"` PendingOwner *NodeID `json:"pending_owner"` - Requirements LinkMap `json:"requirements"` - Dependencies LinkMap `json:"dependencies"` + Requirements map[NodeID]string `json:"requirements"` } func (ext *LockableExt) Field(name string) interface{} { @@ -105,16 +21,9 @@ func (ext *LockableExt) Field(name string) interface{} { "requirements": func(ext *LockableExt) interface{} { return ext.Requirements }, - "dependencies": func(ext *LockableExt) interface{} { - return ext.Dependencies - }, }) } -func (ext *ListenerExt) Serialize() ([]byte, error) { - return json.Marshal(ext.Buffer) -} - func (ext *LockableExt) Type() ExtType { return LockableExtType } @@ -127,12 +36,15 @@ func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error { return json.Unmarshal(data, ext) } -func NewLockableExt() *LockableExt { +func NewLockableExt(requirements []NodeID) *LockableExt { + reqs := map[NodeID]string{} + for _, id := range(requirements) { + reqs[id] = "unlocked" + } return &LockableExt{ Owner: nil, PendingOwner: nil, - Requirements: map[NodeID]LinkState{}, - Dependencies: map[NodeID]LinkState{}, + Requirements: reqs, } } @@ -150,13 +62,6 @@ func LockLockable(ctx *Context, node *Node) error { return ctx.Send(msgs) } -// Setup a node to send the initial requirement link signal, then send the signal -func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error { - msgs := Messages{} - msgs = msgs.Add(dependency.ID, dependency.Key, NewLinkStartSignal("req", requirement), dependency.ID) - return ctx.Send(msgs) -} - // Handle a LockSignal and update the extensions owner/requirement states func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { state := signal.Str @@ -179,14 +84,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } else { ext.PendingOwner = nil for id, state := range(ext.Requirements) { - if state.Link == "linked" { - if state.Lock != "locked" { - panic("NOT_LOCKED") - } - state.Lock = "unlocking" - ext.Requirements[id] = state - messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id) + if state != "locked" { + panic("NOT_LOCKED") } + ext.Requirements[id] = "unlocking" + messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id) } if source != node.ID { messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocking"), source) @@ -197,9 +99,7 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state.Link != "linked" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) - } else if state.Lock != "unlocking" { + } else if state != "unlocking" { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } @@ -211,27 +111,20 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state.Link != "linked" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) - } else if state.Lock != "unlocking" { + } else if state != "unlocking" { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } else { - state.Lock = "unlocked" - ext.Requirements[source] = state + ext.Requirements[source] = "unlocked" if ext.PendingOwner == nil { - linked := 0 unlocked := 0 for _, s := range(ext.Requirements) { - if s.Link == "linked" { - linked += 1 - } - if s.Lock == "unlocked" { + if s == "unlocked" { unlocked += 1 } } - if linked == unlocked { + if len(ext.Requirements) == unlocked { previous_owner := *ext.Owner ext.Owner = nil messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) @@ -246,27 +139,20 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state.Link != "linked" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) - } else if state.Lock != "locking" { + } else if state != "locking" { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) } else { - state.Lock = "locked" - ext.Requirements[source] = state + ext.Requirements[source] = "locked" if ext.PendingOwner != nil { - linked := 0 locked := 0 for _, s := range(ext.Requirements) { - if s.Link == "linked" { - linked += 1 - } - if s.Lock == "locked" { + if s == "locked" { locked += 1 } } - if linked == locked { + if len(ext.Requirements) == locked { ext.Owner = ext.PendingOwner messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) } @@ -276,9 +162,7 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state.Link != "linked" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) - } else if state.Lock != "locking" { + } else if state != "locking" { messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) } @@ -296,15 +180,12 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } else { ext.PendingOwner = &owner for id, state := range(ext.Requirements) { - if state.Link == "linked" { - log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id) - if state.Lock != "unlocked" { - panic("NOT_UNLOCKED") - } - state.Lock = "locking" - ext.Requirements[id] = state - messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id) + log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id) + if state != "unlocked" { + panic("NOT_UNLOCKED") } + ext.Requirements[id] = "locking" + messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id) } if source != node.ID { messages = messages.Add(node.ID, node.Key, NewLockSignal("locking"), source) @@ -318,119 +199,25 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, return messages } -func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) Messages { - link_type := signal.Str - target := signal.NodeID - log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target) - - messages := Messages{} - switch link_type { - case "req": - state, exists := ext.Requirements[target] - _, dep_exists := ext.Dependencies[target] - if ext.Owner != nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) - } else if ext.Owner != ext.PendingOwner { - if ext.PendingOwner == nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) - } else { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) - } - } else if exists == true { - if state.Link == "linking" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking_req"), source) - } else if state.Link == "linked" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_req"), source) - } - } else if dep_exists == true { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_dep"), source) - } else { - ext.Requirements[target] = LinkState{"linking", "unlocked", source} - messages = messages.Add(node.ID, node.Key, NewLinkSignal("linked_as_req"), target) - messages = messages.Add(node.ID, node.Key, NewLinkStartSignal("linking_req", target), source) - } - } - return messages -} - -// Handle LinkSignal, updating the extensions requirements and dependencies as necessary -// TODO: Add unlink -func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { - log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) - state := signal.Str - - messages := Messages{} - switch state { - case "dep_done": - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linking"), source) - } else if state.Link == "linking" { - state.Link = "linked" - ext.Requirements[source] = state - log.Logf("lockable", "FINISHED_LINKING_REQ: %s->%s", node.ID, source) - } - case "linked_as_req": - state, exists := ext.Dependencies[source] - if exists == false { - ext.Dependencies[source] = LinkState{"linked", "unlocked", source} - messages = messages.Add(node.ID, node.Key, NewLinkSignal("dep_done"), source) - } else if state.Link == "linking" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking"), source) - } else if state.Link == "linked" { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linked"), source) - } else if ext.PendingOwner != ext.Owner { - if ext.Owner == nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) - } else { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) - } - } - - default: - log.Logf("lockable", "LINK_ERROR: unknown state %s", state) - } - return messages -} - // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { messages := Messages{} switch signal.Direction() { case Up: - ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies) - owner_sent := false - for dependency, state := range(ext.Dependencies) { - if state.Link == "linked" { - messages = messages.Add(node.ID, node.Key, signal, dependency) - if ext.Owner != nil { - if dependency == *ext.Owner { - owner_sent = true - } - } - } - } - - if ext.Owner != nil && owner_sent == false { + if ext.Owner != nil { if *ext.Owner != node.ID { messages = messages.Add(node.ID, node.Key, signal, *ext.Owner) } } case Down: - for requirement, state := range(ext.Requirements) { - if state.Link == "linked" { - messages = messages.Add(node.ID, node.Key, signal, requirement) - } + for requirement, _ := range(ext.Requirements) { + messages = messages.Add(node.ID, node.Key, signal, requirement) } case Direct: switch signal.Type() { - case LinkSignalType: - messages = ext.HandleLinkSignal(ctx.Log, node, source, signal.(*StringSignal)) case LockSignalType: messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal)) - case LinkStartSignalType: - messages = ext.HandleLinkStartSignal(ctx.Log, node, source, signal.(*IDStringSignal)) default: } default: diff --git a/lockable_test.go b/lockable_test.go index bd75958..bcbcb46 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -18,29 +18,20 @@ func lockableTestContext(t *testing.T, logs []string) *Context { func TestLink(t *testing.T) { ctx := lockableTestContext(t, []string{"lockable"}) - l1_listener := NewListenerExt(10) - l1 := NewNode(ctx, nil, TestLockableType, 10, nil, - l1_listener, - NewLockableExt(), - ) l2_listener := NewListenerExt(10) l2 := NewNode(ctx, nil, TestLockableType, 10, nil, l2_listener, - NewLockableExt(), + NewLockableExt(nil), + ) + l1_listener := NewListenerExt(10) + NewNode(ctx, nil, TestLockableType, 10, nil, + l1_listener, + NewLockableExt([]NodeID{l2.ID}), ) - - // Link l2 as a requirement of l1 - err := LinkRequirement(ctx, l1, l2.ID) - fatalErr(t, err) - - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { - return sig.Str == "dep_done" - }) - fatalErr(t, err) msgs := Messages{} msgs = msgs.Add(l2.ID, l2.Key, NewStatusSignal("TEST", l2.ID), l2.ID) - err = ctx.Send(msgs) + err := ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { @@ -59,106 +50,57 @@ func TestLink10K(t *testing.T) { NewLockable := func()(*Node) { l := NewNode(ctx, nil, TestLockableType, 10, nil, - NewLockableExt(), + NewLockableExt(nil), ) return l } - NewListener := func()(*Node, *ListenerExt) { - listener := NewListenerExt(100000) - l := NewNode(ctx, nil, TestLockableType, 256, nil, - listener, - NewLockableExt(), - ) - return l, listener + reqs := make([]NodeID, 10000) + for i, _ := range(reqs) { + new_lockable := NewLockable() + reqs[i] = new_lockable.ID } - - l0, l0_listener := NewListener() - lockables := make([]*Node, 10) - for i, _ := range(lockables) { - lockables[i] = NewLockable() - LinkRequirement(ctx, l0, lockables[i].ID) - } - ctx.Log.Logf("test", "CREATED_10K") - - for range(lockables) { - _, err := WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { - return sig.Str == "dep_done" - }) - fatalErr(t, err) + NewListener := func()(*ListenerExt) { + listener := NewListenerExt(100000) + NewNode(ctx, nil, TestLockableType, 256, nil, + listener, + NewLockableExt(reqs), + ) + return listener } + NewListener() + ctx.Log.Logf("test", "CREATED_LISTENER") - ctx.Log.Logf("test", "LINKED_10K") + // TODO: Lock listener and wait for all the lock signals + //ctx.Log.Logf("test", "LOCKED_10K") } func TestLock(t *testing.T) { ctx := lockableTestContext(t, []string{}) - NewLockable := func()(*Node, *ListenerExt) { + NewLockable := func(reqs []NodeID)(*Node, *ListenerExt) { listener := NewListenerExt(100) l := NewNode(ctx, nil, TestLockableType, 10, nil, listener, - NewLockableExt(), + NewLockableExt(reqs), ) return l, listener } - l0, l0_listener := NewLockable() - l1, l1_listener := NewLockable() - l2, _ := NewLockable() - l3, _ := NewLockable() - l4, _ := NewLockable() - l5, _ := NewLockable() - - - var err error - err = LinkRequirement(ctx, l1, l2.ID) - fatalErr(t, err) - err = LinkRequirement(ctx, l1, l3.ID) - fatalErr(t, err) - err = LinkRequirement(ctx, l1, l4.ID) - fatalErr(t, err) - err = LinkRequirement(ctx, l1, l5.ID) - fatalErr(t, err) - - err = LinkRequirement(ctx, l0, l2.ID) - fatalErr(t, err) - err = LinkRequirement(ctx, l0, l3.ID) - fatalErr(t, err) - err = LinkRequirement(ctx, l0, l4.ID) - fatalErr(t, err) - err = LinkRequirement(ctx, l0, l5.ID) - fatalErr(t, err) - - linked_as_req := func(sig *StringSignal) bool { - return sig.Str == "dep_done" - } + l2, _ := NewLockable(nil) + l3, _ := NewLockable(nil) + l4, _ := NewLockable(nil) + l5, _ := NewLockable(nil) + NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) + l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) locked := func(sig *StringSignal) bool { return sig.Str == "locked" } - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - - _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) - fatalErr(t, err) - - err = LockLockable(ctx, l1) + err := LockLockable(ctx, l1) fatalErr(t, err) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) diff --git a/node.go b/node.go index 81f45f5..b150e7e 100644 --- a/node.go +++ b/node.go @@ -35,12 +35,17 @@ var ( // A NodeID uniquely identifies a Node type NodeID uuid.UUID + +func (id NodeID) MarshalText() ([]byte, error) { + return json.Marshal(id.String()) +} + +func (id *NodeID) UnmarshalText(data []byte) error { + return json.Unmarshal(data, id) +} + func (id *NodeID) MarshalJSON() ([]byte, error) { - str := "" - if id != nil { - str = id.String() - } - return json.Marshal(&str) + return json.Marshal(id.String()) } func (id *NodeID) UnmarshalJSON(bytes []byte) error { @@ -59,6 +64,8 @@ func (id NodeID) Serialize() []byte { return ser } + + func (id NodeID) String() string { return (uuid.UUID)(id).String() } @@ -1042,23 +1049,3 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return node, nil } - -func IDMap[S any, T map[NodeID]S](m T)map[string]S { - ret := map[string]S{} - for id, val := range(m) { - ret[id.String()] = val - } - return ret -} - -func LoadIDMap[S any, T map[string]S](m T)(map[NodeID]S, error) { - ret := map[NodeID]S{} - for str, val := range(m) { - id, err := ParseID(str) - if err != nil { - return nil, err - } - ret[id] = val - } - return ret, nil -}