From b31de3418b490c4b913949e03aa39f63fb3d7cf5 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 23 Jun 2023 21:21:14 -0600 Subject: [PATCH] Added thread_test.go --- graph.go | 4 +- thread.go | 151 +++++++++++++++++++++++++++---------------------- thread_test.go | 21 +++++++ 3 files changed, 105 insertions(+), 71 deletions(-) create mode 100644 thread_test.go diff --git a/graph.go b/graph.go index 7b1eed1..f9cef8f 100644 --- a/graph.go +++ b/graph.go @@ -165,11 +165,11 @@ func NewDirectSignal(source GraphNode, _type string) BaseSignal { return NewBaseSignal(source, _type, Direct) } -func NewAbortSignal(source GraphNode) BaseSignal { +func AbortSignal(source GraphNode) BaseSignal { return NewBaseSignal(source, "abort", Down) } -func NewCancelSignal(source GraphNode) BaseSignal { +func CancelSignal(source GraphNode) BaseSignal { return NewBaseSignal(source, "cancel", Down) } diff --git a/thread.go b/thread.go index 602d530..8e9d24a 100644 --- a/thread.go +++ b/thread.go @@ -19,7 +19,7 @@ func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSigna SendUpdate(ctx, thread_state.Parent(), signal) } - for _, resource := range(thread_state.Lockables()) { + for _, resource := range(thread_state.Requirements()) { SendUpdate(ctx, resource, signal) } } else if signal.Direction() == Down { @@ -27,6 +27,10 @@ func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSigna for _, child := range(thread_state.Children()) { SendUpdate(ctx, child, signal) } + + for _, dep := range(thread_state.Dependencies()) { + SendUpdate(ctx, dep, signal) + } } else if signal.Direction() == Direct { } else { @@ -71,8 +75,6 @@ type ThreadState interface { Children() []Thread ChildInfo(child NodeID) ThreadInfo AddChild(child Thread, info ThreadInfo) error - Lockables() []Lockable - AddLockable(resource Lockable) error } type BaseThreadState struct { @@ -249,8 +251,8 @@ func (state * BaseThreadState) RecordLockHolder(id NodeID, lock_holder GraphNode type Thread interface { GraphNode - Action(action string) (func(* GraphContext)(string, error), bool) - Handler(signal_type string) (func(* GraphContext, GraphSignal) (string, error), bool) + Action(action string) (ThreadAction, bool) + Handler(signal_type string) (ThreadHandler, bool) SetTimeout(end_time time.Time, action string) ClearTimeout() @@ -258,29 +260,6 @@ type Thread interface { TimeoutAction() string } -func (thread * BaseThread) TimeoutAction() string { - return thread.timeout_action -} - -func (thread * BaseThread) Timeout() <-chan time.Time { - return thread.timeout -} - -func (thread * BaseThread) ClearTimeout() { - thread.timeout_action = "" - thread.timeout = nil -} - -func (thread * BaseThread) SetTimeout(end_time time.Time, action string) { - thread.timeout_action = action - thread.timeout = time.After(time.Until(end_time)) -} - -func (thread * BaseThread) Handler(signal_type string) (func(* GraphContext, GraphSignal)(string, error), bool) { - handler, exists := thread.Handlers[signal_type] - return handler, exists -} - func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id NodeID) Thread { if thread == nil { panic("cannot recurse through nil") @@ -329,7 +308,7 @@ func RunThread(ctx * GraphContext, thread Thread) error { } ctx.Log.Logf("thread", "EVENT_ACTION: %s - %s", thread.ID(), next_action) - next_action, err = action(ctx) + next_action, err = action(ctx, thread) if err != nil { return err } @@ -342,17 +321,10 @@ func RunThread(ctx * GraphContext, thread Thread) error { return nil } -func ThreadAbort(thread Thread) func(*GraphContext, GraphSignal) (string, error) { - return func(ctx * GraphContext, signal GraphSignal) (string, error) { - return "", errors.New(fmt.Sprintf("%s aborted by signal", thread.ID())) - } -} - -func ThreadCancel(thread Thread) func(*GraphContext, GraphSignal) (string, error) { - return func(ctx * GraphContext, signal GraphSignal) (string, error) { - return "", nil - } -} +type ThreadAction func(* GraphContext, Thread)(string, error) +type ThreadActions map[string]ThreadAction +type ThreadHandler func(* GraphContext, Thread, GraphSignal)(string, error) +type ThreadHandlers map[string]ThreadHandler // Thread is the most basic thread that can exist in the thread tree. // On start it automatically transitions to completion. @@ -366,8 +338,8 @@ type BaseThread struct { info_lock sync.Mutex parent_lock sync.Mutex - Actions map[string]func(* GraphContext) (string, error) - Handlers map[string]func(* GraphContext, GraphSignal) (string, error) + Actions ThreadActions + Handlers ThreadHandlers timeout <-chan time.Time timeout_action string @@ -381,40 +353,62 @@ func (thread * BaseThread) Unlock(node GraphNode, state LockableState) error { return nil } -func (thread * BaseThread) Action(action string) (func(ctx * GraphContext) (string, error), bool) { +func (thread * BaseThread) Action(action string) (ThreadAction, bool) { action_fn, exists := thread.Actions[action] return action_fn, exists } -func ThreadWait(thread Thread) (func(*GraphContext) (string, error)) { - return func(ctx * GraphContext) (string, error) { - ctx.Log.Logf("thread", "EVENT_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) - select { +func (thread * BaseThread) Handler(signal_type string) (ThreadHandler, bool) { + handler, exists := thread.Handlers[signal_type] + return handler, exists +} + +func (thread * BaseThread) TimeoutAction() string { + return thread.timeout_action +} + +func (thread * BaseThread) Timeout() <-chan time.Time { + return thread.timeout +} + +func (thread * BaseThread) ClearTimeout() { + thread.timeout_action = "" + thread.timeout = nil +} + +func (thread * BaseThread) SetTimeout(end_time time.Time, action string) { + thread.timeout_action = action + thread.timeout = time.After(time.Until(end_time)) +} + +var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) { + ctx.Log.Logf("thread", "THREAD_DEFAUL_START: %s", thread.ID()) + return "wait", nil +} + +var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { + ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) + select { case signal := <- thread.SignalChannel(): - ctx.Log.Logf("thread", "EVENT_SIGNAL: %s %+v", thread.ID(), signal) + ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal) signal_fn, exists := thread.Handler(signal.Type()) if exists == true { - ctx.Log.Logf("thread", "EVENT_HANDLER: %s - %s", thread.ID(), signal.Type()) - return signal_fn(ctx, signal) + ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type()) + return signal_fn(ctx, thread, signal) } - return "wait", nil case <- thread.Timeout(): - ctx.Log.Logf("thread", "EVENT_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction()) + ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction()) return thread.TimeoutAction(), nil - } } + return "wait", nil } -func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { - state := NewBaseThreadState(name) - thread := BaseThread{ - BaseNode: NewNode(ctx, RandID(), &state), - Actions: map[string]func(*GraphContext)(string, error){}, - Handlers: map[string]func(*GraphContext,GraphSignal)(string, error){}, - timeout: nil, - timeout_action: "", - } - return thread, nil +var ThreadAbort = func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + return "", fmt.Errorf("%s aborted by signal from %s", thread.ID(), signal.Source()) +} + +var ThreadCancel = func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + return "", nil } func NewBaseThreadState(name string) BaseThreadState { @@ -428,7 +422,26 @@ func NewBaseThreadState(name string) BaseThreadState { } } -func NewThread(ctx * GraphContext, name string, requirements []Lockable) (* BaseThread, error) { +func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { + state := NewBaseThreadState(name) + thread := BaseThread{ + BaseNode: NewNode(ctx, RandID(), &state), + Actions: ThreadActions{ + "wait": ThreadWait, + "start": ThreadDefaultStart, + }, + Handlers: ThreadHandlers{ + "abort": ThreadAbort, + "cancel": ThreadCancel, + }, + timeout: nil, + timeout_action: "", + } + + return thread, nil +} + +func NewThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) { thread, err := NewBaseThread(ctx, name) if err != nil { return nil, err @@ -443,12 +456,12 @@ func NewThread(ctx * GraphContext, name string, requirements []Lockable) (* Base } } - thread_ptr.Actions["wait"] = ThreadWait(thread_ptr) - thread_ptr.Handlers["abort"] = ThreadAbort(thread_ptr) - thread_ptr.Handlers["cancel"] = ThreadCancel(thread_ptr) + for key, fn := range(actions) { + thread.Actions[key] = fn + } - thread_ptr.Actions["start"] = func(ctx * GraphContext) (string, error) { - return "", nil + for key, fn := range(handlers) { + thread.Handlers[key] = fn } return thread_ptr, nil diff --git a/thread_test.go b/thread_test.go new file mode 100644 index 0000000..d06921d --- /dev/null +++ b/thread_test.go @@ -0,0 +1,21 @@ +package graphvent + +import ( + "testing" + "time" +) + +func TestNewEvent(t * testing.T) { + ctx := testContext(t) + + t1, err := NewThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) + fatalErr(t, err) + + go func(thread Thread) { + time.Sleep(1*time.Second) + SendUpdate(ctx, t1, CancelSignal(nil)) + }(t1) + + err = RunThread(ctx, t1) + fatalErr(t, err) +}