From 1598e2939a049d62bf6b01d1be18df2c5f9f605e Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sun, 25 Jun 2023 13:39:00 -0600 Subject: [PATCH] Cleaned up initializers --- graph.go | 27 +++++++++++++++- lockable.go | 86 +++++++++++++++++++++++++++++++------------------- thread.go | 46 ++++++++++++--------------- thread_test.go | 12 +++++-- 4 files changed, 110 insertions(+), 61 deletions(-) diff --git a/graph.go b/graph.go index 43da196..491453b 100644 --- a/graph.go +++ b/graph.go @@ -272,8 +272,28 @@ func (node * BaseNode) SetState(new_state NodeState) { node.state = new_state } -// How to prevent the states from being modified if they're pointer receivers? +func checkForDuplicate(nodes []GraphNode) error { + found := map[NodeID]bool{} + for _, node := range(nodes) { + if node == nil { + return fmt.Errorf("Cannot get state of nil node") + } + + _, exists := found[node.ID()] + if exists == true { + return fmt.Errorf("Attempted to get state of %s twice", node.ID()) + } + found[node.ID()] = true + } + return nil +} + func UseStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []NodeState)(interface{}, error)) (interface{}, error) { + err := checkForDuplicate(nodes) + if err != nil { + return nil, err + } + for _, node := range(nodes) { node.StateLock().RLock() } @@ -293,6 +313,11 @@ func UseStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []No } func UpdateStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []NodeState)([]NodeState, interface{}, error)) (interface{}, error) { + err := checkForDuplicate(nodes) + if err != nil { + return nil, err + } + for _, node := range(nodes) { node.StateLock().Lock() } diff --git a/lockable.go b/lockable.go index 75c6d3e..1eee207 100644 --- a/lockable.go +++ b/lockable.go @@ -168,45 +168,67 @@ func NewLockHolderState() BaseLockHolderState { } } -func NewLockableState(name string) BaseLockableState { - return BaseLockableState{ - BaseLockHolderState: NewLockHolderState(), - name: name, - owner: nil, - requirements: []Lockable{}, - dependencies: []Lockable{}, +func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockable) error { + if lockable == nil { + return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } -} -// Link a lockable with a requirement -func LinkLockables(ctx * GraphContext, lockable Lockable, requirement Lockable) error { - if lockable == nil || requirement == nil { - return fmt.Errorf("Will not connect nil to DAG") - } + for _, requirement := range(requirements) { + if requirement == nil { + return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement") + } - if lockable.ID() == requirement.ID() { - return fmt.Errorf("Will not link %s as requirement of itself", lockable.ID()) + if lockable.ID() == requirement.ID() { + return fmt.Errorf("LOCKABLE_LINK_ERR: cannot link %s to itself", lockable.ID()) + } } - _, err := UpdateStates(ctx, []GraphNode{lockable, requirement}, func(states []NodeState) ([]NodeState, interface{}, error) { + nodes := make([]GraphNode, len(requirements) + 1) + nodes[0] = lockable + for i, node := range(requirements) { + nodes[i+1] = node + } + _, err := UpdateStates(ctx, nodes, func(states []NodeState) ([]NodeState, interface{}, error) { + // Check that all the requirements can be added lockable_state := states[0].(LockableState) - requirement_state := states[1].(LockableState) + for i, requirement := range(requirements) { + requirement_state := states[i+1].(LockableState) - if checkIfRequirement(ctx, lockable_state, lockable.ID(), requirement_state, requirement.ID()) == true { - return nil, nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) - } - if checkIfRequirement(ctx, requirement_state, requirement.ID(), lockable_state, lockable.ID()) == true { - return nil, nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) + if checkIfRequirement(ctx, lockable.ID(), requirement_state, requirement.ID()) == true { + return nil, nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) + } + + if checkIfRequirement(ctx, requirement.ID(), lockable_state, lockable.ID()) == true { + return nil, nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) + } + } + // Update the states of the requirements + for i, requirement := range(requirements) { + requirement_state := states[i+1].(LockableState) + requirement_state.AddDependency(lockable) + lockable_state.AddRequirement(requirement) } - lockable_state.AddRequirement(requirement) - requirement_state.AddDependency(lockable) - return []NodeState{lockable_state, requirement_state}, nil, nil + // Return no error + return states, nil, nil }) + return err } +func NewBaseLockableState(name string) BaseLockableState { + state := BaseLockableState{ + BaseLockHolderState: NewLockHolderState(), + name: name, + owner: nil, + requirements: []Lockable{}, + dependencies: []Lockable{}, + } + + return state +} + type Lockable interface { GraphNode // Called when locking the node to allow for custom lock behaviour @@ -246,14 +268,14 @@ func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphS }) } -func checkIfRequirement(ctx * GraphContext, r LockableState, r_id NodeID, cur LockableState, cur_id NodeID) bool { +func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID) bool { for _, c := range(cur.Requirements()) { if c.ID() == r_id { return true } val, _ := UseStates(ctx, []GraphNode{c}, func(states []NodeState) (interface{}, error) { requirement_state := states[0].(LockableState) - return checkIfRequirement(ctx, cur, cur_id, requirement_state, c.ID()), nil + return checkIfRequirement(ctx, cur_id, requirement_state, c.ID()), nil }) is_requirement := val.(bool) @@ -408,16 +430,14 @@ func (lockable * BaseLockable) Unlock(node GraphNode, state LockableState) error } func NewLockable(ctx * GraphContext, name string, requirements []Lockable) (* BaseLockable, error) { - state := NewLockableState(name) + state := NewBaseLockableState(name) lockable := &BaseLockable{ BaseNode: NewNode(ctx, RandID(), &state), } - for _, requirement := range(requirements) { - err := LinkLockables(ctx, lockable, requirement) - if err != nil { - return nil, err - } + err := LinkLockables(ctx, lockable, requirements) + if err != nil { + return nil, err } return lockable, nil diff --git a/thread.go b/thread.go index 9f06c31..3aaa37b 100644 --- a/thread.go +++ b/thread.go @@ -130,14 +130,14 @@ func (state * BaseThreadState) AddChild(child Thread, info ThreadInfo) error { return nil } -func checkIfChild(ctx * GraphContext, thread_state ThreadState, thread_id NodeID, cur_state ThreadState, cur_id NodeID) bool { +func checkIfChild(ctx * GraphContext, thread_id NodeID, cur_state ThreadState, cur_id NodeID) bool { for _, child := range(cur_state.Children()) { if child.ID() == thread_id { return true } val, _ := UseStates(ctx, []GraphNode{child}, func(states []NodeState) (interface{}, error) { child_state := states[0].(ThreadState) - return checkIfRequirement(ctx, cur_state, cur_id, child_state, child.ID()), nil + return checkIfChild(ctx, cur_id, child_state, child.ID()), nil }) is_child := val.(bool) @@ -167,11 +167,11 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf return nil, nil, fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) } - if checkIfChild(ctx, thread_state, thread.ID(), child_state, child.ID()) == true { + if checkIfChild(ctx, thread.ID(), child_state, child.ID()) == true { return nil, nil, fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) } - if checkIfChild(ctx, child_state, child.ID(), thread_state, thread.ID()) == true { + if checkIfChild(ctx, child.ID(), thread_state, thread.ID()) == true { return nil, nil, fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) } @@ -363,17 +363,16 @@ var ThreadCancel = func(ctx * GraphContext, thread Thread, signal GraphSignal) ( func NewBaseThreadState(name string) BaseThreadState { return BaseThreadState{ - BaseLockableState: NewLockableState(name), + BaseLockableState: NewBaseLockableState(name), children: []Thread{}, child_info: map[NodeID]ThreadInfo{}, parent: nil, } } -func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { - state := NewBaseThreadState(name) +func NewBaseThread(ctx * GraphContext, name string, actions ThreadActions, handlers ThreadHandlers, state ThreadState) (BaseThread, error) { thread := BaseThread{ - BaseLockable: BaseLockable{BaseNode: NewNode(ctx, RandID(), &state)}, + BaseLockable: BaseLockable{BaseNode: NewNode(ctx, RandID(), state)}, Actions: ThreadActions{ "wait": ThreadWait, "start": ThreadDefaultStart, @@ -386,32 +385,29 @@ func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { timeout_action: "", } + for key, fn := range(actions) { + thread.Actions[key] = fn + } + + for key, fn := range(handlers) { + thread.Handlers[key] = fn + } + return thread, nil } -func NewThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) { - thread, err := NewBaseThread(ctx, name) +func NewSimpleBaseThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) { + state := NewBaseThreadState(name) + thread, err := NewBaseThread(ctx, name, actions, handlers, &state) if err != nil { return nil, err } thread_ptr := &thread - for _, requirement := range(requirements) { - err := LinkLockables(ctx, thread_ptr, requirement) - if err != nil { - return nil, err - } - } - - for key, fn := range(actions) { - thread.Actions[key] = fn - } - - for key, fn := range(handlers) { - thread.Handlers[key] = fn + err = LinkLockables(ctx, thread_ptr, requirements) + if err != nil { + return nil, err } - return thread_ptr, nil } - diff --git a/thread_test.go b/thread_test.go index 392554a..1f97bb9 100644 --- a/thread_test.go +++ b/thread_test.go @@ -10,7 +10,7 @@ import ( func TestNewEvent(t * testing.T) { ctx := testContext(t) - t1, err := NewThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) + t1, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) fatalErr(t, err) go func(thread Thread) { @@ -37,7 +37,7 @@ func TestEventWithRequirement(t * testing.T) { l1, err := NewLockable(ctx, "Test Lockable 1", []Lockable{}) fatalErr(t, err) - t1, err := NewThread(ctx, "Test Thread 1", []Lockable{l1}, ThreadActions{}, ThreadHandlers{}) + t1, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{l1}, ThreadActions{}, ThreadHandlers{}) fatalErr(t, err) go func (thread Thread) { @@ -66,3 +66,11 @@ func TestEventWithRequirement(t * testing.T) { }) fatalErr(t, err) } + +func TestCustomEventState(t * testing.T ) { + ctx := logTestContext(t, []string{"lockable", "thread"}) + + t1, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) + fatalErr(t, err) + println(t1) +}