Added thread_test.go

graph-rework
noah metz 2023-06-23 21:21:14 -06:00
parent 30971f00bd
commit b31de3418b
3 changed files with 105 additions and 71 deletions

@ -165,11 +165,11 @@ func NewDirectSignal(source GraphNode, _type string) BaseSignal {
return NewBaseSignal(source, _type, Direct) return NewBaseSignal(source, _type, Direct)
} }
func NewAbortSignal(source GraphNode) BaseSignal { func AbortSignal(source GraphNode) BaseSignal {
return NewBaseSignal(source, "abort", Down) return NewBaseSignal(source, "abort", Down)
} }
func NewCancelSignal(source GraphNode) BaseSignal { func CancelSignal(source GraphNode) BaseSignal {
return NewBaseSignal(source, "cancel", Down) return NewBaseSignal(source, "cancel", Down)
} }

@ -19,7 +19,7 @@ func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSigna
SendUpdate(ctx, thread_state.Parent(), signal) SendUpdate(ctx, thread_state.Parent(), signal)
} }
for _, resource := range(thread_state.Lockables()) { for _, resource := range(thread_state.Requirements()) {
SendUpdate(ctx, resource, signal) SendUpdate(ctx, resource, signal)
} }
} else if signal.Direction() == Down { } else if signal.Direction() == Down {
@ -27,6 +27,10 @@ func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSigna
for _, child := range(thread_state.Children()) { for _, child := range(thread_state.Children()) {
SendUpdate(ctx, child, signal) SendUpdate(ctx, child, signal)
} }
for _, dep := range(thread_state.Dependencies()) {
SendUpdate(ctx, dep, signal)
}
} else if signal.Direction() == Direct { } else if signal.Direction() == Direct {
} else { } else {
@ -71,8 +75,6 @@ type ThreadState interface {
Children() []Thread Children() []Thread
ChildInfo(child NodeID) ThreadInfo ChildInfo(child NodeID) ThreadInfo
AddChild(child Thread, info ThreadInfo) error AddChild(child Thread, info ThreadInfo) error
Lockables() []Lockable
AddLockable(resource Lockable) error
} }
type BaseThreadState struct { type BaseThreadState struct {
@ -249,8 +251,8 @@ func (state * BaseThreadState) RecordLockHolder(id NodeID, lock_holder GraphNode
type Thread interface { type Thread interface {
GraphNode GraphNode
Action(action string) (func(* GraphContext)(string, error), bool) Action(action string) (ThreadAction, bool)
Handler(signal_type string) (func(* GraphContext, GraphSignal) (string, error), bool) Handler(signal_type string) (ThreadHandler, bool)
SetTimeout(end_time time.Time, action string) SetTimeout(end_time time.Time, action string)
ClearTimeout() ClearTimeout()
@ -258,29 +260,6 @@ type Thread interface {
TimeoutAction() string TimeoutAction() string
} }
func (thread * BaseThread) TimeoutAction() string {
return thread.timeout_action
}
func (thread * BaseThread) Timeout() <-chan time.Time {
return thread.timeout
}
func (thread * BaseThread) ClearTimeout() {
thread.timeout_action = ""
thread.timeout = nil
}
func (thread * BaseThread) SetTimeout(end_time time.Time, action string) {
thread.timeout_action = action
thread.timeout = time.After(time.Until(end_time))
}
func (thread * BaseThread) Handler(signal_type string) (func(* GraphContext, GraphSignal)(string, error), bool) {
handler, exists := thread.Handlers[signal_type]
return handler, exists
}
func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id NodeID) Thread { func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id NodeID) Thread {
if thread == nil { if thread == nil {
panic("cannot recurse through nil") panic("cannot recurse through nil")
@ -329,7 +308,7 @@ func RunThread(ctx * GraphContext, thread Thread) error {
} }
ctx.Log.Logf("thread", "EVENT_ACTION: %s - %s", thread.ID(), next_action) ctx.Log.Logf("thread", "EVENT_ACTION: %s - %s", thread.ID(), next_action)
next_action, err = action(ctx) next_action, err = action(ctx, thread)
if err != nil { if err != nil {
return err return err
} }
@ -342,17 +321,10 @@ func RunThread(ctx * GraphContext, thread Thread) error {
return nil return nil
} }
func ThreadAbort(thread Thread) func(*GraphContext, GraphSignal) (string, error) { type ThreadAction func(* GraphContext, Thread)(string, error)
return func(ctx * GraphContext, signal GraphSignal) (string, error) { type ThreadActions map[string]ThreadAction
return "", errors.New(fmt.Sprintf("%s aborted by signal", thread.ID())) type ThreadHandler func(* GraphContext, Thread, GraphSignal)(string, error)
} type ThreadHandlers map[string]ThreadHandler
}
func ThreadCancel(thread Thread) func(*GraphContext, GraphSignal) (string, error) {
return func(ctx * GraphContext, signal GraphSignal) (string, error) {
return "", nil
}
}
// Thread is the most basic thread that can exist in the thread tree. // Thread is the most basic thread that can exist in the thread tree.
// On start it automatically transitions to completion. // On start it automatically transitions to completion.
@ -366,8 +338,8 @@ type BaseThread struct {
info_lock sync.Mutex info_lock sync.Mutex
parent_lock sync.Mutex parent_lock sync.Mutex
Actions map[string]func(* GraphContext) (string, error) Actions ThreadActions
Handlers map[string]func(* GraphContext, GraphSignal) (string, error) Handlers ThreadHandlers
timeout <-chan time.Time timeout <-chan time.Time
timeout_action string timeout_action string
@ -381,40 +353,62 @@ func (thread * BaseThread) Unlock(node GraphNode, state LockableState) error {
return nil return nil
} }
func (thread * BaseThread) Action(action string) (func(ctx * GraphContext) (string, error), bool) { func (thread * BaseThread) Action(action string) (ThreadAction, bool) {
action_fn, exists := thread.Actions[action] action_fn, exists := thread.Actions[action]
return action_fn, exists return action_fn, exists
} }
func ThreadWait(thread Thread) (func(*GraphContext) (string, error)) { func (thread * BaseThread) Handler(signal_type string) (ThreadHandler, bool) {
return func(ctx * GraphContext) (string, error) { handler, exists := thread.Handlers[signal_type]
ctx.Log.Logf("thread", "EVENT_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) return handler, exists
}
func (thread * BaseThread) TimeoutAction() string {
return thread.timeout_action
}
func (thread * BaseThread) Timeout() <-chan time.Time {
return thread.timeout
}
func (thread * BaseThread) ClearTimeout() {
thread.timeout_action = ""
thread.timeout = nil
}
func (thread * BaseThread) SetTimeout(end_time time.Time, action string) {
thread.timeout_action = action
thread.timeout = time.After(time.Until(end_time))
}
var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAUL_START: %s", thread.ID())
return "wait", nil
}
var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout())
select { select {
case signal := <- thread.SignalChannel(): case signal := <- thread.SignalChannel():
ctx.Log.Logf("thread", "EVENT_SIGNAL: %s %+v", thread.ID(), signal) ctx.Log.Logf("thread", "THREAD_SIGNAL: %s %+v", thread.ID(), signal)
signal_fn, exists := thread.Handler(signal.Type()) signal_fn, exists := thread.Handler(signal.Type())
if exists == true { if exists == true {
ctx.Log.Logf("thread", "EVENT_HANDLER: %s - %s", thread.ID(), signal.Type()) ctx.Log.Logf("thread", "THREAD_HANDLER: %s - %s", thread.ID(), signal.Type())
return signal_fn(ctx, signal) return signal_fn(ctx, thread, signal)
} }
return "wait", nil
case <- thread.Timeout(): case <- thread.Timeout():
ctx.Log.Logf("thread", "EVENT_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction()) ctx.Log.Logf("thread", "THREAD_TIMEOUT %s - NEXT_STATE: %s", thread.ID(), thread.TimeoutAction())
return thread.TimeoutAction(), nil return thread.TimeoutAction(), nil
} }
} return "wait", nil
} }
func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) { var ThreadAbort = func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) {
state := NewBaseThreadState(name) return "", fmt.Errorf("%s aborted by signal from %s", thread.ID(), signal.Source())
thread := BaseThread{
BaseNode: NewNode(ctx, RandID(), &state),
Actions: map[string]func(*GraphContext)(string, error){},
Handlers: map[string]func(*GraphContext,GraphSignal)(string, error){},
timeout: nil,
timeout_action: "",
} }
return thread, nil
var ThreadCancel = func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) {
return "", nil
} }
func NewBaseThreadState(name string) BaseThreadState { func NewBaseThreadState(name string) BaseThreadState {
@ -428,7 +422,26 @@ func NewBaseThreadState(name string) BaseThreadState {
} }
} }
func NewThread(ctx * GraphContext, name string, requirements []Lockable) (* BaseThread, error) { func NewBaseThread(ctx * GraphContext, name string) (BaseThread, error) {
state := NewBaseThreadState(name)
thread := BaseThread{
BaseNode: NewNode(ctx, RandID(), &state),
Actions: ThreadActions{
"wait": ThreadWait,
"start": ThreadDefaultStart,
},
Handlers: ThreadHandlers{
"abort": ThreadAbort,
"cancel": ThreadCancel,
},
timeout: nil,
timeout_action: "",
}
return thread, nil
}
func NewThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) {
thread, err := NewBaseThread(ctx, name) thread, err := NewBaseThread(ctx, name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -443,12 +456,12 @@ func NewThread(ctx * GraphContext, name string, requirements []Lockable) (* Base
} }
} }
thread_ptr.Actions["wait"] = ThreadWait(thread_ptr) for key, fn := range(actions) {
thread_ptr.Handlers["abort"] = ThreadAbort(thread_ptr) thread.Actions[key] = fn
thread_ptr.Handlers["cancel"] = ThreadCancel(thread_ptr) }
thread_ptr.Actions["start"] = func(ctx * GraphContext) (string, error) { for key, fn := range(handlers) {
return "", nil thread.Handlers[key] = fn
} }
return thread_ptr, nil return thread_ptr, nil

@ -0,0 +1,21 @@
package graphvent
import (
"testing"
"time"
)
func TestNewEvent(t * testing.T) {
ctx := testContext(t)
t1, err := NewThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{})
fatalErr(t, err)
go func(thread Thread) {
time.Sleep(1*time.Second)
SendUpdate(ctx, t1, CancelSignal(nil))
}(t1)
err = RunThread(ctx, t1)
fatalErr(t, err)
}