Updated thread.go and thread_test.go

graph-rework-2
noah metz 2023-07-09 15:59:41 -06:00
parent 5c416a4a3f
commit b846bbb812
2 changed files with 326 additions and 461 deletions

@ -10,278 +10,76 @@ import (
) )
// Update the threads listeners, and notify the parent to do the same // Update the threads listeners, and notify the parent to do the same
func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) { func (thread * SimpleThread) Signal(ctx * Context, signal GraphSignal, nodes NodeMap) error {
thread_state := states[thread.ID()].(ThreadState) err := thread.SimpleLockable.Signal(ctx, signal, nodes)
if err != nil {
return err
}
if signal.Direction() == Up { if signal.Direction() == Up {
// Child->Parent, thread updates parent and connected requirement // Child->Parent, thread updates parent and connected requirement
if thread_state.Parent() != nil { if thread.parent != nil {
UseMoreStates(ctx, []GraphNode{thread_state.Parent()}, states, func(states NodeStateMap) (error) { UseMoreStates(ctx, []Node{thread.parent}, nodes, func(nodes NodeMap) error {
SendUpdate(ctx, thread_state.Parent(), signal, states) thread.parent.Signal(ctx, signal, nodes)
return nil return nil
}) })
} }
UseMoreStates(ctx, NodeList(thread_state.Dependencies()), states, func(states NodeStateMap) (error) {
for _, dep := range(thread_state.Dependencies()) {
SendUpdate(ctx, dep, signal, states)
}
return nil
})
} else if signal.Direction() == Down { } else if signal.Direction() == Down {
// Parent->Child, updates children and dependencies // Parent->Child, updates children and dependencies
UseMoreStates(ctx, NodeList(thread_state.Children()), states, func(states NodeStateMap) (error) { UseMoreStates(ctx, NodeList(thread.children), nodes, func(nodes NodeMap) error {
for _, child := range(thread_state.Children()) { for _, child := range(thread.children) {
SendUpdate(ctx, child, signal, states) child.Signal(ctx, signal, nodes)
}
return nil
})
UseMoreStates(ctx, NodeList(thread_state.Requirements()), states, func(states NodeStateMap) (error) {
for _, requirement := range(thread_state.Requirements()) {
SendUpdate(ctx, requirement, signal, states)
} }
return nil return nil
}) })
} else if signal.Direction() == Direct { } else if signal.Direction() == Direct {
} else { } else {
panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction()))
} }
thread.signal <- signal thread.signal <- signal
return nil
} }
// Interface to represent any type of thread information
type ThreadInfo interface { type ThreadInfo interface {
} }
// An Thread is a lockable that has an additional parent->child relationship with other Threads func (thread * SimpleThread) SetTimeout(timeout time.Time, action string) {
// This relationship allows the thread tree to be modified independent of the lockable state thread.timeout = timeout
type ThreadState interface { thread.timeout_action = action
LockableState thread.timeout_chan = time.After(time.Until(timeout))
Parent() Thread
SetParent(parent Thread)
Children() []Thread
Child(id NodeID) Thread
ChildInfo(child NodeID) ThreadInfo
AddChild(child Thread, info ThreadInfo) error
RemoveChild(child Thread)
Start() error
Stop() error
State() string
TimeoutAction() string
SetTimeout(end_time time.Time, action string)
}
type BaseThreadState struct {
BaseLockableState
state_name string
parent Thread
children []Thread
child_info map[NodeID] ThreadInfo
InfoType reflect.Type
timeout time.Time
timeout_action string
}
type BaseThreadStateJSON struct {
Parent *NodeID `json:"parent"`
Children map[NodeID]interface{} `json:"children"`
Timeout time.Time `json:"timeout"`
TimeoutAction string `json:"timeout_action"`
StateName string `json:"state_name"`
BaseLockableStateJSON
}
func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON {
children := map[NodeID]interface{}{}
for _, child := range(state.children) {
children[child.ID()] = state.child_info[child.ID()]
}
var parent_id *NodeID = nil
if state.parent != nil {
new_str := state.parent.ID()
parent_id = &new_str
}
lockable_state := SaveBaseLockableState(&state.BaseLockableState)
ret := BaseThreadStateJSON{
Parent: parent_id,
Children: children,
Timeout: state.timeout,
TimeoutAction: state.timeout_action,
StateName: state.state_name,
BaseLockableStateJSON: lockable_state,
}
return ret
}
func RestoreBaseThread(ctx * GraphContext, id NodeID, actions ThreadActions, handlers ThreadHandlers) BaseThread {
base_lockable := RestoreBaseLockable(ctx, id)
thread := BaseThread{
BaseLockable: base_lockable,
Actions: actions,
Handlers: handlers,
child_waits: &sync.WaitGroup{},
active: false,
active_lock: &sync.Mutex{},
}
return thread
}
func LoadSimpleThread(ctx * GraphContext, id NodeID) (GraphNode, error) {
thread := RestoreBaseThread(ctx, id, BaseThreadActions, BaseThreadHandlers)
return &thread, nil
}
func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) {
lockable_state, err := RestoreBaseLockableState(ctx, j.BaseLockableStateJSON, loaded_nodes)
if err != nil {
return nil, err
}
state := BaseThreadState{
BaseLockableState: *lockable_state,
parent: nil,
children: make([]Thread, len(j.Children)),
child_info: map[NodeID]ThreadInfo{},
InfoType: nil,
state_name: j.StateName,
timeout: j.Timeout,
timeout_action: j.TimeoutAction,
}
if j.Parent != nil {
p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes)
if err != nil {
return nil, err
}
p_t, ok := p.(Thread)
if ok == false {
return nil, err
}
state.parent = p_t
}
// TODO: Call different loading functions(to return different ThreadInfo types, based on the j.Type,
// Will probably have to add another set of callbacks to the context for this, and since there's now 3 sets that need to be matching it could be useful to move them to a struct so it's easier to keep in sync
i := 0
for id, info_raw := range(j.Children) {
child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes)
if err != nil {
return nil, err
}
child_t, ok := child_node.(Thread)
if ok == false {
return nil, fmt.Errorf("%+v is not a Thread as expected", child_node)
}
state.children[i] = child_t
info_map, ok := info_raw.(map[string]interface{})
if ok == false && info_raw != nil {
return nil, fmt.Errorf("Parsed map wrong type: %+v", info_raw)
}
info_fn, exists := ctx.InfoLoadFuncs[j.Type]
var parsed_info ThreadInfo
if exists == false {
parsed_info = nil
} else {
parsed_info, err = info_fn(ctx, info_map)
if err != nil {
return nil, err
}
}
state.child_info[id] = parsed_info
i++
}
return &state, nil
}
func LoadSimpleThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){
var j BaseThreadStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
state, err := RestoreBaseThreadState(ctx, j, loaded_nodes)
if err != nil {
return nil, err
}
return state, nil
} }
func (state * BaseThreadState) SetTimeout(timeout time.Time, action string) { func (thread * SimpleThread) TimeoutAction() string {
state.timeout = timeout return thread.timeout_action
state.timeout_action = action
} }
func (state * BaseThreadState) TimeoutAction() string { func (thread * SimpleThread) State() string {
return state.timeout_action return thread.state_name
} }
func (state * BaseThreadState) MarshalJSON() ([]byte, error) { func (thread * SimpleThread) SetState(new_state string) error {
thread_state := SaveBaseThreadState(state) if new_state == "" {
return json.Marshal(&thread_state) return fmt.Errorf("Cannot set state to '' with SetState")
}
func (state * BaseThreadState) State() string {
return state.state_name
}
func (state * BaseThreadState) SetState(new_state string) error {
if new_state == "init" {
return fmt.Errorf("Cannot set a thread to 'init' with SetState")
} else if new_state == "finished" {
return fmt.Errorf("Cannot set a thread to 'finished' with SetState")
} else if new_state == "started" {
return fmt.Errorf("Cannot set a thread to 'started' with SetState")
} }
state.state_name = new_state thread.state_name = new_state
return nil return nil
} }
func (state * BaseThreadState) Start() error { func (thread * SimpleThread) Parent() Thread {
if state.state_name != "init" { return thread.parent
return fmt.Errorf("Cannot start a thread that's already started")
}
state.state_name = "started"
return nil
}
func (state * BaseThreadState) Stop() error {
if state.state_name == "finished" {
return fmt.Errorf("Cannot stop a finished thread")
} else if state.state_name == "init" {
return fmt.Errorf("Cannot stop a thread that hasn't been started")
}
state.state_name = "finished"
return nil
} }
func (state * BaseThreadState) Parent() Thread { func (thread * SimpleThread) SetParent(parent Thread) {
return state.parent thread.parent = parent
} }
func (state * BaseThreadState) SetParent(parent Thread) { func (thread * SimpleThread) Children() []Thread {
state.parent = parent return thread.children
} }
func (state * BaseThreadState) Children() []Thread { func (thread * SimpleThread) Child(id NodeID) Thread {
return state.children for _, child := range(thread.children) {
}
func (state * BaseThreadState) Child(id NodeID) Thread {
for _, child := range(state.children) {
if child.ID() == id { if child.ID() == id {
return child return child
} }
@ -291,15 +89,14 @@ func (state * BaseThreadState) Child(id NodeID) Thread {
func (state * BaseThreadState) ChildInfo(child NodeID) ThreadInfo { func (thread * SimpleThread) ChildInfo(child NodeID) ThreadInfo {
return state.child_info[child] return thread.child_info[child]
} }
// Requires thread and childs state to be locked for write // Requires thread and childs thread to be locked for write
func UnlinkThreads(ctx * GraphContext, thread Thread, child Thread) error { func UnlinkThreads(ctx * Context, thread Thread, child Thread) error {
state := thread.State().(ThreadState) var found Node = nil
var found GraphNode = nil for _, c := range(thread.Children()) {
for _, c := range(state.Children()) {
if child.ID() == c.ID() { if child.ID() == c.ID() {
found = c found = c
break break
@ -310,16 +107,15 @@ func UnlinkThreads(ctx * GraphContext, thread Thread, child Thread) error {
return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID())
} }
child_state := child.State().(ThreadState) child.SetParent(nil)
child_state.SetParent(nil) thread.RemoveChild(child)
state.RemoveChild(child)
return nil return nil
} }
func (state * BaseThreadState) RemoveChild(child Thread) { func (thread * SimpleThread) RemoveChild(child Thread) {
idx := -1 idx := -1
for i, c := range(state.children) { for i, c := range(thread.children) {
if c.ID() == child.ID() { if c.ID() == child.ID() {
idx = i idx = i
break break
@ -327,47 +123,46 @@ func (state * BaseThreadState) RemoveChild(child Thread) {
} }
if idx == -1 { if idx == -1 {
panic(fmt.Sprintf("%s is not a child of %s", child.ID(), state.Name())) panic(fmt.Sprintf("%s is not a child of %s", child.ID(), thread.Name()))
} }
child_len := len(state.children) child_len := len(thread.children)
state.children[idx] = state.children[child_len-1] thread.children[idx] = thread.children[child_len-1]
state.children = state.children[0:child_len-1] thread.children = thread.children[0:child_len-1]
} }
func (state * BaseThreadState) AddChild(child Thread, info ThreadInfo) error { func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error {
if child == nil { if child == nil {
return fmt.Errorf("Will not connect nil to the thread tree") return fmt.Errorf("Will not connect nil to the thread tree")
} }
_, exists := state.child_info[child.ID()] _, exists := thread.child_info[child.ID()]
if exists == true { if exists == true {
return fmt.Errorf("Will not connect the same child twice") return fmt.Errorf("Will not connect the same child twice")
} }
if info == nil && state.InfoType != nil { if info == nil && thread.InfoType != nil {
return fmt.Errorf("nil info passed when expecting info") return fmt.Errorf("nil info passed when expecting info")
} else if info != nil { } else if info != nil {
if reflect.TypeOf(info) != state.InfoType { if reflect.TypeOf(info) != thread.InfoType {
return fmt.Errorf("info type mismatch, expecting %+v", state.InfoType) return fmt.Errorf("info type mismatch, expecting %+v", thread.InfoType)
} }
} }
state.children = append(state.children, child) thread.children = append(thread.children, child)
state.child_info[child.ID()] = info thread.child_info[child.ID()] = info
return nil return nil
} }
func checkIfChild(ctx * GraphContext, thread_id NodeID, cur_state ThreadState, cur_id NodeID) bool { func checkIfChild(ctx * Context, target Thread, cur Thread, nodes NodeMap) bool {
for _, child := range(cur_state.Children()) { for _, child := range(cur.Children()) {
if child.ID() == thread_id { if child.ID() == target.ID() {
return true return true
} }
is_child := false is_child := false
UseStates(ctx, []GraphNode{child}, func(states NodeStateMap) (error) { UpdateMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error {
child_state := states[child.ID()].(ThreadState) is_child = checkIfChild(ctx, target, child, nodes)
is_child = checkIfChild(ctx, cur_id, child_state, child.ID())
return nil return nil
}) })
if is_child { if is_child {
@ -378,8 +173,8 @@ func checkIfChild(ctx * GraphContext, thread_id NodeID, cur_state ThreadState, c
return false return false
} }
// Requires thread and childs state to be locked for write // Requires thread and childs thread to be locked for write
func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInfo) error { func LinkThreads(ctx * Context, thread Thread, child Thread, info ThreadInfo, nodes NodeMap) error {
if ctx == nil || thread == nil || child == nil { if ctx == nil || thread == nil || child == nil {
return fmt.Errorf("invalid input") return fmt.Errorf("invalid input")
} }
@ -388,26 +183,23 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf
return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) return fmt.Errorf("Will not link %s as a child of itself", thread.ID())
} }
thread_state := thread.State().(ThreadState) if child.Parent() != nil {
child_state := child.State().(ThreadState)
if child_state.Parent() != nil {
return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID())
} }
if checkIfChild(ctx, thread.ID(), child_state, child.ID()) == true { if checkIfChild(ctx, thread, child, nodes) == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID())
} }
if checkIfChild(ctx, child.ID(), thread_state, thread.ID()) == true { if checkIfChild(ctx, child, thread, nodes) == true {
return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID())
} }
err := thread_state.AddChild(child, info) err := thread.AddChild(child, info)
if err != nil { if err != nil {
return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err)
} }
child_state.SetParent(thread) child.SetParent(thread)
if err != nil { if err != nil {
return err return err
@ -416,24 +208,195 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf
return nil return nil
} }
// Thread is the interface that thread tree nodes must implement type ThreadAction func(* Context, Thread)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(* Context, Thread, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler
type Thread interface { type Thread interface {
// All Threads are Lockables
Lockable Lockable
/// State Modification Functions
SetParent(parent Thread)
AddChild(child Thread, info ThreadInfo) error
RemoveChild(child Thread)
SetState(new_thread string) error
SetTimeout(end_time time.Time, action string)
/// State Reading Functions
Parent() Thread
Children() []Thread
Child(id NodeID) Thread
ChildInfo(child NodeID) ThreadInfo
State() string
TimeoutAction() string
/// Functions that dont read/write thread
// Deserialize the attribute map from json.Unmarshal
DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error)
SetActive(active bool) error
Action(action string) (ThreadAction, bool) Action(action string) (ThreadAction, bool)
Handler(signal_type string) (ThreadHandler, bool) Handler(signal_type string) (ThreadHandler, bool)
SetTimeout(end time.Time) // Internal timeout channel for thread
Timeout() <-chan time.Time Timeout() <-chan time.Time
// Internal signal channel for thread
SignalChannel() <-chan GraphSignal
ClearTimeout() ClearTimeout()
ChildWaits() *sync.WaitGroup ChildWaits() *sync.WaitGroup
Start() error }
Stop() error
type SimpleThread struct {
SimpleLockable
actions ThreadActions
handlers ThreadHandlers
timeout_chan <-chan time.Time
signal chan GraphSignal
child_waits *sync.WaitGroup
active bool
active_lock *sync.Mutex
state_name string
parent Thread
children []Thread
child_info map[NodeID] ThreadInfo
InfoType reflect.Type
timeout time.Time
timeout_action string
}
func (thread * SimpleThread) Serialize() ([]byte, error) {
thread_json := NewSimpleThreadJSON(thread)
return json.MarshalIndent(&thread_json, "", " ")
}
func (thread * SimpleThread) SignalChannel() <-chan GraphSignal {
return thread.signal
}
type SimpleThreadJSON struct {
Parent *NodeID `json:"parent"`
Children map[NodeID]interface{} `json:"children"`
Timeout time.Time `json:"timeout"`
TimeoutAction string `json:"timeout_action"`
StateName string `json:"state_name"`
SimpleLockableJSON
}
func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON {
children := map[NodeID]interface{}{}
for _, child := range(thread.children) {
children[child.ID()] = thread.child_info[child.ID()]
}
var parent_id *NodeID = nil
if thread.parent != nil {
new_str := thread.parent.ID()
parent_id = &new_str
}
lockable_json := NewSimpleLockableJSON(&thread.SimpleLockable)
return SimpleThreadJSON{
Parent: parent_id,
Children: children,
Timeout: thread.timeout,
TimeoutAction: thread.timeout_action,
StateName: thread.state_name,
SimpleLockableJSON: lockable_json,
}
}
func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j SimpleThreadJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
thread := NewSimpleThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers)
nodes[id] = &thread
err = RestoreSimpleThread(ctx, &thread, j, nodes)
if err != nil {
return nil, err
}
return &thread, nil
}
// SimpleThread as no associated info with children
func (thread * SimpleThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) {
if len(data) > 0 {
return nil, fmt.Errorf("SimpleThread expected to deserialize no info but got %d length data: %s", len(data), string(data))
}
return nil, nil
}
func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes NodeMap) error {
thread.SetTimeout(j.Timeout, j.TimeoutAction)
if j.Parent != nil {
p, err := LoadNodeRecurse(ctx, *j.Parent, nodes)
if err != nil {
return err
}
p_t, ok := p.(Thread)
if ok == false {
return err
}
thread.SetParent(p_t)
}
// TODO: Call different loading functions(to return different ThreadInfo types, based on the j.Type,
// Will probably have to add another set of callbacks to the context for this, and since there's now 3 sets that need to be matching it could be useful to move them to a struct so it's easier to keep in sync
i := 0
for id, info_raw := range(j.Children) {
child_node, err := LoadNodeRecurse(ctx, id, nodes)
if err != nil {
return err
}
child_t, ok := child_node.(Thread)
if ok == false {
return fmt.Errorf("%+v is not a Thread as expected", child_node)
}
info_ser, err := json.Marshal(info_raw)
if err != nil {
return err
}
parsed_info, err := thread.DeserializeInfo(ctx, info_ser)
if err != nil {
return err
}
thread.AddChild(child_t, parsed_info)
i++
}
return RestoreSimpleLockable(ctx, thread, j.SimpleLockableJSON, nodes)
}
const THREAD_SIGNAL_BUFFER_SIZE = 128
func NewSimpleThread(id NodeID, name string, state_name string, info_type reflect.Type, actions ThreadActions, handlers ThreadHandlers) SimpleThread {
return SimpleThread{
SimpleLockable: NewSimpleLockable(id, name),
state_name: state_name,
signal: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE),
children: []Thread{},
child_info: map[NodeID]ThreadInfo{},
actions: actions,
handlers: handlers,
child_waits: &sync.WaitGroup{},
active_lock: &sync.Mutex{},
}
} }
// Requires that thread is already locked for read in UseStates // Requires that thread is already locked for read in UseStates
func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap) Thread { func FindChild(ctx * Context, thread Thread, id NodeID, nodes NodeMap) Thread {
if thread == nil { if thread == nil {
panic("cannot recurse through nil") panic("cannot recurse through nil")
} }
@ -441,11 +404,10 @@ func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap
return thread return thread
} }
thread_state := thread.State().(ThreadState) for _, child := range thread.Children() {
for _, child := range thread_state.Children() {
var result Thread = nil var result Thread = nil
UseMoreStates(ctx, []GraphNode{child}, states, func(states NodeStateMap) (error) { UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error {
result = FindChild(ctx, child, id, states) result = FindChild(ctx, child, id, nodes)
return nil return nil
}) })
if result != nil { if result != nil {
@ -456,7 +418,7 @@ func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap
return nil return nil
} }
func ChildGo(ctx * GraphContext, thread Thread, child Thread, first_action string) { func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) {
thread.ChildWaits().Add(1) thread.ChildWaits().Add(1)
go func(child Thread) { go func(child Thread) {
ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID())
@ -471,10 +433,10 @@ func ChildGo(ctx * GraphContext, thread Thread, child Thread, first_action strin
} }
// Main Loop for Threads // Main Loop for Threads
func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error { func ThreadLoop(ctx * Context, thread Thread, first_action string) error {
// Start the thread, error if double-started // Start the thread, error if double-started
ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action)
err := thread.Start() err := thread.SetActive(true)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err) ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err)
return err return err
@ -494,15 +456,14 @@ func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error {
} }
} }
err = thread.Stop() err = thread.SetActive(false)
if err != nil { if err != nil {
ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err) ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err)
return err return err
} }
err = UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) { err = UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error {
thread_state := thread.State().(ThreadState) err := thread.SetState("finished")
err := thread_state.Stop()
if err != nil { if err != nil {
return err return err
} }
@ -519,95 +480,46 @@ func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error {
return nil return nil
} }
type ThreadAction func(* GraphContext, Thread)(string, error)
type ThreadActions map[string]ThreadAction
type ThreadHandler func(* GraphContext, Thread, GraphSignal)(string, error)
type ThreadHandlers map[string]ThreadHandler
// Thread is the most basic thread that can exist in the thread tree.
// On start it automatically transitions to completion.
// This node by itself doesn't implement any special behaviours for children, so they will be ignored.
// When started, this thread automatically transitions to completion
type BaseThread struct {
BaseLockable
Actions ThreadActions
Handlers ThreadHandlers
timeout_chan <-chan time.Time
child_waits *sync.WaitGroup
active bool
active_lock *sync.Mutex
}
func (thread * BaseThread) ChildWaits() *sync.WaitGroup { func (thread * SimpleThread) ChildWaits() *sync.WaitGroup {
return thread.child_waits return thread.child_waits
} }
func (thread * BaseThread) Start() error { func (thread * SimpleThread) SetActive(active bool) error {
thread.active_lock.Lock() thread.active_lock.Lock()
defer thread.active_lock.Unlock() defer thread.active_lock.Unlock()
if thread.active == true { if thread.active == true && active == true {
return fmt.Errorf("%s is active, cannot start", thread.ID()) return fmt.Errorf("%s is active, cannot set active", thread.ID())
} else if thread.active == false && active == false {
return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID())
} }
thread.active = true thread.active = active
return nil return nil
} }
func (thread * BaseThread) Stop() error { func (thread * SimpleThread) Action(action string) (ThreadAction, bool) {
thread.active_lock.Lock() action_fn, exists := thread.actions[action]
defer thread.active_lock.Unlock()
if thread.active == false {
return fmt.Errorf("%s is not active, cannot stop", thread.ID())
}
thread.active = false
return nil
}
func (thread * BaseThread) CanLock(node GraphNode, state LockableState) error {
return nil
}
func (thread * BaseThread) CanUnlock(node GraphNode, state LockableState) error {
return nil
}
func (thread * BaseThread) Lock(node GraphNode, state LockableState) {
return
}
func (thread * BaseThread) Unlock(node GraphNode, state LockableState) {
return
}
func (thread * BaseThread) Action(action string) (ThreadAction, bool) {
action_fn, exists := thread.Actions[action]
return action_fn, exists return action_fn, exists
} }
func (thread * BaseThread) Handler(signal_type string) (ThreadHandler, bool) { func (thread * SimpleThread) Handler(signal_type string) (ThreadHandler, bool) {
handler, exists := thread.Handlers[signal_type] handler, exists := thread.handlers[signal_type]
return handler, exists return handler, exists
} }
func (thread * BaseThread) Timeout() <-chan time.Time { func (thread * SimpleThread) Timeout() <-chan time.Time {
return thread.timeout_chan return thread.timeout_chan
} }
func (thread * BaseThread) ClearTimeout() { func (thread * SimpleThread) ClearTimeout() {
thread.timeout_chan = nil thread.timeout_chan = nil
} }
func (thread * BaseThread) SetTimeout(end time.Time) { var ThreadStart = func(ctx * Context, thread Thread) error {
thread.timeout_chan = time.After(time.Until(end)) err := UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error {
}
var ThreadStart = func(ctx * GraphContext, thread Thread) error {
err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) {
thread_state := thread.State().(ThreadState)
owner_id := NodeID("") owner_id := NodeID("")
if thread_state.Owner() != nil { if thread.Owner() != nil {
owner_id = thread_state.Owner().ID() owner_id = thread.Owner().ID()
} }
if owner_id != thread.ID() { if owner_id != thread.ID() {
err := LockLockables(ctx, []Lockable{thread}, thread, nodes) err := LockLockables(ctx, []Lockable{thread}, thread, nodes)
@ -615,7 +527,7 @@ var ThreadStart = func(ctx * GraphContext, thread Thread) error {
return err return err
} }
} }
return thread_state.Start() return thread.SetState("started")
}) })
if err != nil { if err != nil {
@ -625,7 +537,7 @@ var ThreadStart = func(ctx * GraphContext, thread Thread) error {
return nil return nil
} }
var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) { var ThreadDefaultStart = func(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID()) ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID())
err := ThreadStart(ctx, thread) err := ThreadStart(ctx, thread)
if err != nil { if err != nil {
@ -634,12 +546,12 @@ var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error)
return "wait", nil return "wait", nil
} }
var ThreadDefaultRestore = func(ctx * GraphContext, thread Thread) (string, error) { var ThreadDefaultRestore = func(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_DEFAULT_RESTORE: %s", thread.ID()) ctx.Log.Logf("thread", "THREAD_DEFAULT_RESTORE: %s", thread.ID())
return "wait", nil return "wait", nil
} }
var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { var ThreadWait = func(ctx * Context, thread Thread) (string, error) {
ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout())
for { for {
select { select {
@ -658,9 +570,8 @@ var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) {
} }
case <- thread.Timeout(): case <- thread.Timeout():
timeout_action := "" timeout_action := ""
err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) error { err := UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error {
thread_state := thread.State().(ThreadState) timeout_action = thread.TimeoutAction()
timeout_action = thread_state.TimeoutAction()
thread.ClearTimeout() thread.ClearTimeout()
return nil return nil
}) })
@ -688,35 +599,23 @@ func NewThreadAbortedError(aborter NodeID) ThreadAbortedError {
} }
// Default thread abort is to return a ThreadAbortedError // Default thread abort is to return a ThreadAbortedError
func ThreadAbort(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { UseStates(ctx, []Node{thread}, func(nodes NodeMap) error {
SendUpdate(ctx, thread, NewSignal(thread, "thread_aborted"), states) thread.Signal(ctx, NewSignal(thread, "thread_aborted"), nodes)
return nil return nil
}) })
return "", NewThreadAbortedError(signal.Source()) return "", NewThreadAbortedError(signal.Source())
} }
// Default thread cancel is to finish the thread // Default thread cancel is to finish the thread
func ThreadCancel(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) {
UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { UseStates(ctx, []Node{thread}, func(nodes NodeMap) error {
SendUpdate(ctx, thread, NewSignal(thread, "thread_cancelled"), states) thread.Signal(ctx, NewSignal(thread, "thread_cancelled"), nodes)
return nil return nil
}) })
return "", nil return "", nil
} }
func NewBaseThreadState(name string, _type string) BaseThreadState {
return BaseThreadState{
BaseLockableState: NewBaseLockableState(name, _type),
children: []Thread{},
child_info: map[NodeID]ThreadInfo{},
parent: nil,
timeout: time.Time{},
timeout_action: "wait",
state_name: "init",
}
}
func NewThreadActions() ThreadActions{ func NewThreadActions() ThreadActions{
actions := ThreadActions{} actions := ThreadActions{}
for k, v := range(BaseThreadActions) { for k, v := range(BaseThreadActions) {
@ -745,48 +644,3 @@ var BaseThreadHandlers = ThreadHandlers{
"abort": ThreadAbort, "abort": ThreadAbort,
"cancel": ThreadCancel, "cancel": ThreadCancel,
} }
func NewBaseThread(ctx * GraphContext, actions ThreadActions, handlers ThreadHandlers, state ThreadState) (BaseThread, error) {
lockable, err := NewBaseLockable(ctx, state)
if err != nil {
return BaseThread{}, err
}
thread := BaseThread{
BaseLockable: lockable,
Actions: actions,
Handlers: handlers,
child_waits: &sync.WaitGroup{},
active: false,
active_lock: &sync.Mutex{},
}
return thread, nil
}
func NewSimpleThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) {
state := NewBaseThreadState(name, "simple_thread")
thread, err := NewBaseThread(ctx, actions, handlers, &state)
if err != nil {
return nil, err
}
thread_ptr := &thread
if len(requirements) > 0 {
req_nodes := make([]GraphNode, len(requirements))
for i, req := range(requirements) {
req_nodes[i] = req
}
err = UpdateStates(ctx, req_nodes, func(nodes NodeMap) error {
return LinkLockables(ctx, thread_ptr, requirements, nodes)
})
if err != nil {
return nil, err
}
}
return thread_ptr, nil
}

@ -10,22 +10,21 @@ import (
func TestNewThread(t * testing.T) { func TestNewThread(t * testing.T) {
ctx := testContext(t) ctx := testContext(t)
t1, err := NewSimpleThread(ctx, "Test thread 1", []Lockable{}, BaseThreadActions, BaseThreadHandlers) t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
fatalErr(t, err) t1 := &t1_r
go func(thread Thread) { go func(thread Thread) {
time.Sleep(10*time.Millisecond) time.Sleep(10*time.Millisecond)
UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { UseStates(ctx, []Node{t1}, func(nodes NodeMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states) return t1.Signal(ctx, CancelSignal(nil), nodes)
return nil
}) })
}(t1) }(t1)
err = ThreadLoop(ctx, t1, "start") err := ThreadLoop(ctx, t1, "start")
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) (error) { err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) (error) {
owner := states[t1.ID()].(ThreadState).Owner() owner := t1.owner
if owner != nil { if owner != nil {
return fmt.Errorf("Wrong owner %+v", owner) return fmt.Errorf("Wrong owner %+v", owner)
} }
@ -36,17 +35,21 @@ func TestNewThread(t * testing.T) {
func TestThreadWithRequirement(t * testing.T) { func TestThreadWithRequirement(t * testing.T) {
ctx := testContext(t) ctx := testContext(t)
l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
fatalErr(t, err) l1 := &l1_r
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1 := &t1_r
t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{l1}, BaseThreadActions, BaseThreadHandlers) err := UpdateStates(ctx, []Node{l1, t1}, func(nodes NodeMap) error {
return LinkLockables(ctx, t1, []Lockable{l1}, nodes)
})
fatalErr(t, err) fatalErr(t, err)
go func (thread Thread) { go func (thread Thread) {
time.Sleep(10*time.Millisecond) time.Sleep(10*time.Millisecond)
UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { UseStates(ctx, []Node{t1}, func(nodes NodeMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states) return t1.Signal(ctx, CancelSignal(nil), nodes)
return nil
}) })
}(t1) }(t1)
fatalErr(t, err) fatalErr(t, err)
@ -54,8 +57,8 @@ func TestThreadWithRequirement(t * testing.T) {
err = ThreadLoop(ctx, t1, "start") err = ThreadLoop(ctx, t1, "start")
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) (error) { err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) {
owner := states[l1.ID()].(LockableState).Owner() owner := l1.owner
if owner != nil { if owner != nil {
return fmt.Errorf("Wrong owner %+v", owner) return fmt.Errorf("Wrong owner %+v", owner)
} }
@ -66,22 +69,25 @@ func TestThreadWithRequirement(t * testing.T) {
func TestThreadDBLoad(t * testing.T) { func TestThreadDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{})
l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) l1_r := NewSimpleLockable(RandID(), "Test Lockable 1")
fatalErr(t, err) l1 := &l1_r
t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{l1}, BaseThreadActions, BaseThreadHandlers) t1 := &t1_r
fatalErr(t, err)
err := UpdateStates(ctx, []Node{t1, l1}, func(nodes NodeMap) error {
return LinkLockables(ctx, t1, []Lockable{l1}, nodes)
})
UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error {
SendUpdate(ctx, t1, CancelSignal(nil), states) return t1.Signal(ctx, CancelSignal(nil), nodes)
return nil
}) })
fatalErr(t, err)
err = ThreadLoop(ctx, t1, "start") err = ThreadLoop(ctx, t1, "start")
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error {
ser, err := json.MarshalIndent(states[t1.ID()], "", " ") ser, err := json.MarshalIndent(nodes[t1.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser) fmt.Printf("\n%s\n\n", ser)
return err return err
}) })
@ -89,8 +95,8 @@ func TestThreadDBLoad(t * testing.T) {
t1_loaded, err := LoadNode(ctx, t1.ID()) t1_loaded, err := LoadNode(ctx, t1.ID())
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{t1_loaded}, func(states NodeStateMap) error { err = UseStates(ctx, []Node{t1_loaded}, func(nodes NodeMap) error {
ser, err := json.MarshalIndent(states[t1_loaded.ID()], "", " ") ser, err := json.MarshalIndent(nodes[t1_loaded.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser) fmt.Printf("\n%s\n\n", ser)
return err return err
}) })
@ -98,15 +104,20 @@ func TestThreadDBLoad(t * testing.T) {
func TestThreadUnlink(t * testing.T) { func TestThreadUnlink(t * testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{})
t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{}, BaseThreadActions, BaseThreadHandlers) t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers)
fatalErr(t, err) t1 := &t1_r
t2_r := NewSimpleThread(RandID(), "Test Thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers)
t2 := &t2_r
t2, err := NewSimpleThread(ctx, "Test Thread 2", []Lockable{}, BaseThreadActions, BaseThreadHandlers)
fatalErr(t, err)
err = LinkThreads(ctx, t1, t2, nil) err := UpdateStates(ctx, []Node{t1, t2}, func(nodes NodeMap) error {
fatalErr(t, err) err := LinkThreads(ctx, t1, t2, nil, nodes)
if err != nil {
return err
}
err = UnlinkThreads(ctx, t1, t2) return UnlinkThreads(ctx, t1, t2)
})
fatalErr(t, err) fatalErr(t, err)
} }