diff --git a/thread.go b/thread.go index ff43a6a..a68ed66 100644 --- a/thread.go +++ b/thread.go @@ -10,278 +10,76 @@ import ( ) // Update the threads listeners, and notify the parent to do the same -func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal, states NodeStateMap) { - thread_state := states[thread.ID()].(ThreadState) +func (thread * SimpleThread) Signal(ctx * Context, signal GraphSignal, nodes NodeMap) error { + err := thread.SimpleLockable.Signal(ctx, signal, nodes) + if err != nil { + return err + } if signal.Direction() == Up { // Child->Parent, thread updates parent and connected requirement - if thread_state.Parent() != nil { - UseMoreStates(ctx, []GraphNode{thread_state.Parent()}, states, func(states NodeStateMap) (error) { - SendUpdate(ctx, thread_state.Parent(), signal, states) + if thread.parent != nil { + UseMoreStates(ctx, []Node{thread.parent}, nodes, func(nodes NodeMap) error { + thread.parent.Signal(ctx, signal, nodes) return nil }) } - - UseMoreStates(ctx, NodeList(thread_state.Dependencies()), states, func(states NodeStateMap) (error) { - for _, dep := range(thread_state.Dependencies()) { - SendUpdate(ctx, dep, signal, states) - } - return nil - }) } else if signal.Direction() == Down { // Parent->Child, updates children and dependencies - UseMoreStates(ctx, NodeList(thread_state.Children()), states, func(states NodeStateMap) (error) { - for _, child := range(thread_state.Children()) { - SendUpdate(ctx, child, signal, states) - } - return nil - }) - - UseMoreStates(ctx, NodeList(thread_state.Requirements()), states, func(states NodeStateMap) (error) { - for _, requirement := range(thread_state.Requirements()) { - SendUpdate(ctx, requirement, signal, states) + UseMoreStates(ctx, NodeList(thread.children), nodes, func(nodes NodeMap) error { + for _, child := range(thread.children) { + child.Signal(ctx, signal, nodes) } return nil }) } else if signal.Direction() == Direct { - } else { panic(fmt.Sprintf("Invalid signal direction: %d", signal.Direction())) } - thread.signal <- signal + return nil } +// Interface to represent any type of thread information type ThreadInfo interface { } -// An Thread is a lockable that has an additional parent->child relationship with other Threads -// This relationship allows the thread tree to be modified independent of the lockable state -type ThreadState interface { - LockableState - - Parent() Thread - SetParent(parent Thread) - Children() []Thread - Child(id NodeID) Thread - ChildInfo(child NodeID) ThreadInfo - AddChild(child Thread, info ThreadInfo) error - RemoveChild(child Thread) - Start() error - Stop() error - State() string - - TimeoutAction() string - SetTimeout(end_time time.Time, action string) -} - -type BaseThreadState struct { - BaseLockableState - state_name string - parent Thread - children []Thread - child_info map[NodeID] ThreadInfo - InfoType reflect.Type - timeout time.Time - timeout_action string -} - -type BaseThreadStateJSON struct { - Parent *NodeID `json:"parent"` - Children map[NodeID]interface{} `json:"children"` - Timeout time.Time `json:"timeout"` - TimeoutAction string `json:"timeout_action"` - StateName string `json:"state_name"` - BaseLockableStateJSON +func (thread * SimpleThread) SetTimeout(timeout time.Time, action string) { + thread.timeout = timeout + thread.timeout_action = action + thread.timeout_chan = time.After(time.Until(timeout)) } -func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON { - children := map[NodeID]interface{}{} - for _, child := range(state.children) { - children[child.ID()] = state.child_info[child.ID()] - } - - var parent_id *NodeID = nil - if state.parent != nil { - new_str := state.parent.ID() - parent_id = &new_str - } - - lockable_state := SaveBaseLockableState(&state.BaseLockableState) - - ret := BaseThreadStateJSON{ - Parent: parent_id, - Children: children, - Timeout: state.timeout, - TimeoutAction: state.timeout_action, - StateName: state.state_name, - BaseLockableStateJSON: lockable_state, - } - - return ret +func (thread * SimpleThread) TimeoutAction() string { + return thread.timeout_action } -func RestoreBaseThread(ctx * GraphContext, id NodeID, actions ThreadActions, handlers ThreadHandlers) BaseThread { - base_lockable := RestoreBaseLockable(ctx, id) - thread := BaseThread{ - BaseLockable: base_lockable, - Actions: actions, - Handlers: handlers, - child_waits: &sync.WaitGroup{}, - active: false, - active_lock: &sync.Mutex{}, - } - - return thread -} - -func LoadSimpleThread(ctx * GraphContext, id NodeID) (GraphNode, error) { - thread := RestoreBaseThread(ctx, id, BaseThreadActions, BaseThreadHandlers) - return &thread, nil +func (thread * SimpleThread) State() string { + return thread.state_name } -func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) { - lockable_state, err := RestoreBaseLockableState(ctx, j.BaseLockableStateJSON, loaded_nodes) - if err != nil { - return nil, err - } - - state := BaseThreadState{ - BaseLockableState: *lockable_state, - parent: nil, - children: make([]Thread, len(j.Children)), - child_info: map[NodeID]ThreadInfo{}, - InfoType: nil, - state_name: j.StateName, - timeout: j.Timeout, - timeout_action: j.TimeoutAction, - } - - if j.Parent != nil { - p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes) - if err != nil { - return nil, err - } - p_t, ok := p.(Thread) - if ok == false { - return nil, err - } - state.parent = p_t +func (thread * SimpleThread) SetState(new_state string) error { + if new_state == "" { + return fmt.Errorf("Cannot set state to '' with SetState") } - // TODO: Call different loading functions(to return different ThreadInfo types, based on the j.Type, - // Will probably have to add another set of callbacks to the context for this, and since there's now 3 sets that need to be matching it could be useful to move them to a struct so it's easier to keep in sync - i := 0 - for id, info_raw := range(j.Children) { - child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes) - if err != nil { - return nil, err - } - child_t, ok := child_node.(Thread) - if ok == false { - return nil, fmt.Errorf("%+v is not a Thread as expected", child_node) - } - state.children[i] = child_t - - info_map, ok := info_raw.(map[string]interface{}) - if ok == false && info_raw != nil { - return nil, fmt.Errorf("Parsed map wrong type: %+v", info_raw) - } - info_fn, exists := ctx.InfoLoadFuncs[j.Type] - var parsed_info ThreadInfo - if exists == false { - parsed_info = nil - } else { - parsed_info, err = info_fn(ctx, info_map) - if err != nil { - return nil, err - } - } - - state.child_info[id] = parsed_info - i++ - } - - return &state, nil -} - -func LoadSimpleThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){ - var j BaseThreadStateJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - state, err := RestoreBaseThreadState(ctx, j, loaded_nodes) - if err != nil { - return nil, err - } - - return state, nil -} - -func (state * BaseThreadState) SetTimeout(timeout time.Time, action string) { - state.timeout = timeout - state.timeout_action = action -} - -func (state * BaseThreadState) TimeoutAction() string { - return state.timeout_action -} - -func (state * BaseThreadState) MarshalJSON() ([]byte, error) { - thread_state := SaveBaseThreadState(state) - return json.Marshal(&thread_state) -} - -func (state * BaseThreadState) State() string { - return state.state_name -} - -func (state * BaseThreadState) SetState(new_state string) error { - if new_state == "init" { - return fmt.Errorf("Cannot set a thread to 'init' with SetState") - } else if new_state == "finished" { - return fmt.Errorf("Cannot set a thread to 'finished' with SetState") - } else if new_state == "started" { - return fmt.Errorf("Cannot set a thread to 'started' with SetState") - } - - state.state_name = new_state - return nil -} - -func (state * BaseThreadState) Start() error { - if state.state_name != "init" { - return fmt.Errorf("Cannot start a thread that's already started") - } - state.state_name = "started" - return nil -} - -func (state * BaseThreadState) Stop() error { - if state.state_name == "finished" { - return fmt.Errorf("Cannot stop a finished thread") - } else if state.state_name == "init" { - return fmt.Errorf("Cannot stop a thread that hasn't been started") - } - state.state_name = "finished" + thread.state_name = new_state return nil } -func (state * BaseThreadState) Parent() Thread { - return state.parent +func (thread * SimpleThread) Parent() Thread { + return thread.parent } -func (state * BaseThreadState) SetParent(parent Thread) { - state.parent = parent +func (thread * SimpleThread) SetParent(parent Thread) { + thread.parent = parent } -func (state * BaseThreadState) Children() []Thread { - return state.children +func (thread * SimpleThread) Children() []Thread { + return thread.children } -func (state * BaseThreadState) Child(id NodeID) Thread { - for _, child := range(state.children) { +func (thread * SimpleThread) Child(id NodeID) Thread { + for _, child := range(thread.children) { if child.ID() == id { return child } @@ -291,15 +89,14 @@ func (state * BaseThreadState) Child(id NodeID) Thread { -func (state * BaseThreadState) ChildInfo(child NodeID) ThreadInfo { - return state.child_info[child] +func (thread * SimpleThread) ChildInfo(child NodeID) ThreadInfo { + return thread.child_info[child] } -// Requires thread and childs state to be locked for write -func UnlinkThreads(ctx * GraphContext, thread Thread, child Thread) error { - state := thread.State().(ThreadState) - var found GraphNode = nil - for _, c := range(state.Children()) { +// Requires thread and childs thread to be locked for write +func UnlinkThreads(ctx * Context, thread Thread, child Thread) error { + var found Node = nil + for _, c := range(thread.Children()) { if child.ID() == c.ID() { found = c break @@ -310,16 +107,15 @@ func UnlinkThreads(ctx * GraphContext, thread Thread, child Thread) error { return fmt.Errorf("UNLINK_THREADS_ERR: %s is not a child of %s", child.ID(), thread.ID()) } - child_state := child.State().(ThreadState) - child_state.SetParent(nil) - state.RemoveChild(child) + child.SetParent(nil) + thread.RemoveChild(child) return nil } -func (state * BaseThreadState) RemoveChild(child Thread) { +func (thread * SimpleThread) RemoveChild(child Thread) { idx := -1 - for i, c := range(state.children) { + for i, c := range(thread.children) { if c.ID() == child.ID() { idx = i break @@ -327,47 +123,46 @@ func (state * BaseThreadState) RemoveChild(child Thread) { } if idx == -1 { - panic(fmt.Sprintf("%s is not a child of %s", child.ID(), state.Name())) + panic(fmt.Sprintf("%s is not a child of %s", child.ID(), thread.Name())) } - child_len := len(state.children) - state.children[idx] = state.children[child_len-1] - state.children = state.children[0:child_len-1] + child_len := len(thread.children) + thread.children[idx] = thread.children[child_len-1] + thread.children = thread.children[0:child_len-1] } -func (state * BaseThreadState) AddChild(child Thread, info ThreadInfo) error { +func (thread * SimpleThread) AddChild(child Thread, info ThreadInfo) error { if child == nil { return fmt.Errorf("Will not connect nil to the thread tree") } - _, exists := state.child_info[child.ID()] + _, exists := thread.child_info[child.ID()] if exists == true { return fmt.Errorf("Will not connect the same child twice") } - if info == nil && state.InfoType != nil { + if info == nil && thread.InfoType != nil { return fmt.Errorf("nil info passed when expecting info") } else if info != nil { - if reflect.TypeOf(info) != state.InfoType { - return fmt.Errorf("info type mismatch, expecting %+v", state.InfoType) + if reflect.TypeOf(info) != thread.InfoType { + return fmt.Errorf("info type mismatch, expecting %+v", thread.InfoType) } } - state.children = append(state.children, child) - state.child_info[child.ID()] = info + thread.children = append(thread.children, child) + thread.child_info[child.ID()] = info return nil } -func checkIfChild(ctx * GraphContext, thread_id NodeID, cur_state ThreadState, cur_id NodeID) bool { - for _, child := range(cur_state.Children()) { - if child.ID() == thread_id { +func checkIfChild(ctx * Context, target Thread, cur Thread, nodes NodeMap) bool { + for _, child := range(cur.Children()) { + if child.ID() == target.ID() { return true } is_child := false - UseStates(ctx, []GraphNode{child}, func(states NodeStateMap) (error) { - child_state := states[child.ID()].(ThreadState) - is_child = checkIfChild(ctx, cur_id, child_state, child.ID()) + UpdateMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { + is_child = checkIfChild(ctx, target, child, nodes) return nil }) if is_child { @@ -378,8 +173,8 @@ func checkIfChild(ctx * GraphContext, thread_id NodeID, cur_state ThreadState, c return false } -// Requires thread and childs state to be locked for write -func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInfo) error { +// Requires thread and childs thread to be locked for write +func LinkThreads(ctx * Context, thread Thread, child Thread, info ThreadInfo, nodes NodeMap) error { if ctx == nil || thread == nil || child == nil { return fmt.Errorf("invalid input") } @@ -388,26 +183,23 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf return fmt.Errorf("Will not link %s as a child of itself", thread.ID()) } - thread_state := thread.State().(ThreadState) - child_state := child.State().(ThreadState) - - if child_state.Parent() != nil { + if child.Parent() != nil { return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) } - if checkIfChild(ctx, thread.ID(), child_state, child.ID()) == true { + if checkIfChild(ctx, thread, child, nodes) == true { return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) } - if checkIfChild(ctx, child.ID(), thread_state, thread.ID()) == true { + if checkIfChild(ctx, child, thread, nodes) == true { return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) } - err := thread_state.AddChild(child, info) + err := thread.AddChild(child, info) if err != nil { return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) } - child_state.SetParent(thread) + child.SetParent(thread) if err != nil { return err @@ -416,24 +208,195 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf return nil } -// Thread is the interface that thread tree nodes must implement +type ThreadAction func(* Context, Thread)(string, error) +type ThreadActions map[string]ThreadAction +type ThreadHandler func(* Context, Thread, GraphSignal)(string, error) +type ThreadHandlers map[string]ThreadHandler + type Thread interface { + // All Threads are Lockables Lockable + /// State Modification Functions + SetParent(parent Thread) + AddChild(child Thread, info ThreadInfo) error + RemoveChild(child Thread) + SetState(new_thread string) error + SetTimeout(end_time time.Time, action string) + /// State Reading Functions + Parent() Thread + Children() []Thread + Child(id NodeID) Thread + ChildInfo(child NodeID) ThreadInfo + State() string + TimeoutAction() string + /// Functions that dont read/write thread + // Deserialize the attribute map from json.Unmarshal + DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) + SetActive(active bool) error Action(action string) (ThreadAction, bool) Handler(signal_type string) (ThreadHandler, bool) - SetTimeout(end time.Time) + // Internal timeout channel for thread Timeout() <-chan time.Time + // Internal signal channel for thread + SignalChannel() <-chan GraphSignal ClearTimeout() ChildWaits() *sync.WaitGroup - Start() error - Stop() error +} + +type SimpleThread struct { + SimpleLockable + + actions ThreadActions + handlers ThreadHandlers + + timeout_chan <-chan time.Time + signal chan GraphSignal + child_waits *sync.WaitGroup + active bool + active_lock *sync.Mutex + + state_name string + parent Thread + children []Thread + child_info map[NodeID] ThreadInfo + InfoType reflect.Type + timeout time.Time + timeout_action string +} + +func (thread * SimpleThread) Serialize() ([]byte, error) { + thread_json := NewSimpleThreadJSON(thread) + return json.MarshalIndent(&thread_json, "", " ") +} + +func (thread * SimpleThread) SignalChannel() <-chan GraphSignal { + return thread.signal +} + +type SimpleThreadJSON struct { + Parent *NodeID `json:"parent"` + Children map[NodeID]interface{} `json:"children"` + Timeout time.Time `json:"timeout"` + TimeoutAction string `json:"timeout_action"` + StateName string `json:"state_name"` + SimpleLockableJSON +} + +func NewSimpleThreadJSON(thread *SimpleThread) SimpleThreadJSON { + children := map[NodeID]interface{}{} + for _, child := range(thread.children) { + children[child.ID()] = thread.child_info[child.ID()] + } + + var parent_id *NodeID = nil + if thread.parent != nil { + new_str := thread.parent.ID() + parent_id = &new_str + } + + lockable_json := NewSimpleLockableJSON(&thread.SimpleLockable) + + return SimpleThreadJSON{ + Parent: parent_id, + Children: children, + Timeout: thread.timeout, + TimeoutAction: thread.timeout_action, + StateName: thread.state_name, + SimpleLockableJSON: lockable_json, + } +} + +func LoadSimpleThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j SimpleThreadJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + thread := NewSimpleThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers) + nodes[id] = &thread + + err = RestoreSimpleThread(ctx, &thread, j, nodes) + if err != nil { + return nil, err + } + + return &thread, nil +} + +// SimpleThread as no associated info with children +func (thread * SimpleThread) DeserializeInfo(ctx *Context, data []byte) (ThreadInfo, error) { + if len(data) > 0 { + return nil, fmt.Errorf("SimpleThread expected to deserialize no info but got %d length data: %s", len(data), string(data)) + } + return nil, nil +} + +func RestoreSimpleThread(ctx *Context, thread Thread, j SimpleThreadJSON, nodes NodeMap) error { + thread.SetTimeout(j.Timeout, j.TimeoutAction) + + if j.Parent != nil { + p, err := LoadNodeRecurse(ctx, *j.Parent, nodes) + if err != nil { + return err + } + p_t, ok := p.(Thread) + if ok == false { + return err + } + thread.SetParent(p_t) + } + + // TODO: Call different loading functions(to return different ThreadInfo types, based on the j.Type, + // Will probably have to add another set of callbacks to the context for this, and since there's now 3 sets that need to be matching it could be useful to move them to a struct so it's easier to keep in sync + i := 0 + for id, info_raw := range(j.Children) { + child_node, err := LoadNodeRecurse(ctx, id, nodes) + if err != nil { + return err + } + child_t, ok := child_node.(Thread) + if ok == false { + return fmt.Errorf("%+v is not a Thread as expected", child_node) + } + + info_ser, err := json.Marshal(info_raw) + if err != nil { + return err + } + + parsed_info, err := thread.DeserializeInfo(ctx, info_ser) + if err != nil { + return err + } + + thread.AddChild(child_t, parsed_info) + i++ + } + + return RestoreSimpleLockable(ctx, thread, j.SimpleLockableJSON, nodes) +} + +const THREAD_SIGNAL_BUFFER_SIZE = 128 +func NewSimpleThread(id NodeID, name string, state_name string, info_type reflect.Type, actions ThreadActions, handlers ThreadHandlers) SimpleThread { + return SimpleThread{ + SimpleLockable: NewSimpleLockable(id, name), + state_name: state_name, + signal: make(chan GraphSignal, THREAD_SIGNAL_BUFFER_SIZE), + children: []Thread{}, + child_info: map[NodeID]ThreadInfo{}, + actions: actions, + handlers: handlers, + child_waits: &sync.WaitGroup{}, + active_lock: &sync.Mutex{}, + } } // Requires that thread is already locked for read in UseStates -func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap) Thread { +func FindChild(ctx * Context, thread Thread, id NodeID, nodes NodeMap) Thread { if thread == nil { panic("cannot recurse through nil") } @@ -441,11 +404,10 @@ func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap return thread } - thread_state := thread.State().(ThreadState) - for _, child := range thread_state.Children() { + for _, child := range thread.Children() { var result Thread = nil - UseMoreStates(ctx, []GraphNode{child}, states, func(states NodeStateMap) (error) { - result = FindChild(ctx, child, id, states) + UseMoreStates(ctx, []Node{child}, nodes, func(nodes NodeMap) error { + result = FindChild(ctx, child, id, nodes) return nil }) if result != nil { @@ -456,7 +418,7 @@ func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap return nil } -func ChildGo(ctx * GraphContext, thread Thread, child Thread, first_action string) { +func ChildGo(ctx * Context, thread Thread, child Thread, first_action string) { thread.ChildWaits().Add(1) go func(child Thread) { ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) @@ -471,10 +433,10 @@ func ChildGo(ctx * GraphContext, thread Thread, child Thread, first_action strin } // Main Loop for Threads -func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error { +func ThreadLoop(ctx * Context, thread Thread, first_action string) error { // Start the thread, error if double-started ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) - err := thread.Start() + err := thread.SetActive(true) if err != nil { ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err) return err @@ -494,15 +456,14 @@ func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error { } } - err = thread.Stop() + err = thread.SetActive(false) if err != nil { ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err) return err } - err = UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) { - thread_state := thread.State().(ThreadState) - err := thread_state.Stop() + err = UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + err := thread.SetState("finished") if err != nil { return err } @@ -519,95 +480,46 @@ func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error { return nil } -type ThreadAction func(* GraphContext, Thread)(string, error) -type ThreadActions map[string]ThreadAction -type ThreadHandler func(* GraphContext, Thread, GraphSignal)(string, error) -type ThreadHandlers map[string]ThreadHandler - -// Thread is the most basic thread that can exist in the thread tree. -// On start it automatically transitions to completion. -// This node by itself doesn't implement any special behaviours for children, so they will be ignored. -// When started, this thread automatically transitions to completion -type BaseThread struct { - BaseLockable - Actions ThreadActions - Handlers ThreadHandlers - - timeout_chan <-chan time.Time - child_waits *sync.WaitGroup - active bool - active_lock *sync.Mutex -} - -func (thread * BaseThread) ChildWaits() *sync.WaitGroup { +func (thread * SimpleThread) ChildWaits() *sync.WaitGroup { return thread.child_waits } -func (thread * BaseThread) Start() error { - thread.active_lock.Lock() - defer thread.active_lock.Unlock() - if thread.active == true { - return fmt.Errorf("%s is active, cannot start", thread.ID()) - } - thread.active = true - return nil -} - -func (thread * BaseThread) Stop() error { +func (thread * SimpleThread) SetActive(active bool) error { thread.active_lock.Lock() defer thread.active_lock.Unlock() - if thread.active == false { - return fmt.Errorf("%s is not active, cannot stop", thread.ID()) + if thread.active == true && active == true { + return fmt.Errorf("%s is active, cannot set active", thread.ID()) + } else if thread.active == false && active == false { + return fmt.Errorf("%s is already inactive, canot set inactive", thread.ID()) } - thread.active = false + thread.active = active return nil } -func (thread * BaseThread) CanLock(node GraphNode, state LockableState) error { - return nil -} - -func (thread * BaseThread) CanUnlock(node GraphNode, state LockableState) error { - return nil -} - -func (thread * BaseThread) Lock(node GraphNode, state LockableState) { - return -} - -func (thread * BaseThread) Unlock(node GraphNode, state LockableState) { - return -} - -func (thread * BaseThread) Action(action string) (ThreadAction, bool) { - action_fn, exists := thread.Actions[action] +func (thread * SimpleThread) Action(action string) (ThreadAction, bool) { + action_fn, exists := thread.actions[action] return action_fn, exists } -func (thread * BaseThread) Handler(signal_type string) (ThreadHandler, bool) { - handler, exists := thread.Handlers[signal_type] +func (thread * SimpleThread) Handler(signal_type string) (ThreadHandler, bool) { + handler, exists := thread.handlers[signal_type] return handler, exists } -func (thread * BaseThread) Timeout() <-chan time.Time { +func (thread * SimpleThread) Timeout() <-chan time.Time { return thread.timeout_chan } -func (thread * BaseThread) ClearTimeout() { +func (thread * SimpleThread) ClearTimeout() { thread.timeout_chan = nil } -func (thread * BaseThread) SetTimeout(end time.Time) { - thread.timeout_chan = time.After(time.Until(end)) -} - -var ThreadStart = func(ctx * GraphContext, thread Thread) error { - err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) { - thread_state := thread.State().(ThreadState) +var ThreadStart = func(ctx * Context, thread Thread) error { + err := UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { owner_id := NodeID("") - if thread_state.Owner() != nil { - owner_id = thread_state.Owner().ID() + if thread.Owner() != nil { + owner_id = thread.Owner().ID() } if owner_id != thread.ID() { err := LockLockables(ctx, []Lockable{thread}, thread, nodes) @@ -615,7 +527,7 @@ var ThreadStart = func(ctx * GraphContext, thread Thread) error { return err } } - return thread_state.Start() + return thread.SetState("started") }) if err != nil { @@ -625,7 +537,7 @@ var ThreadStart = func(ctx * GraphContext, thread Thread) error { return nil } -var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) { +var ThreadDefaultStart = func(ctx * Context, thread Thread) (string, error) { ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID()) err := ThreadStart(ctx, thread) if err != nil { @@ -634,12 +546,12 @@ var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) return "wait", nil } -var ThreadDefaultRestore = func(ctx * GraphContext, thread Thread) (string, error) { +var ThreadDefaultRestore = func(ctx * Context, thread Thread) (string, error) { ctx.Log.Logf("thread", "THREAD_DEFAULT_RESTORE: %s", thread.ID()) return "wait", nil } -var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { +var ThreadWait = func(ctx * Context, thread Thread) (string, error) { ctx.Log.Logf("thread", "THREAD_WAIT: %s TIMEOUT: %+v", thread.ID(), thread.Timeout()) for { select { @@ -658,9 +570,8 @@ var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { } case <- thread.Timeout(): timeout_action := "" - err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) error { - thread_state := thread.State().(ThreadState) - timeout_action = thread_state.TimeoutAction() + err := UpdateStates(ctx, []Node{thread}, func(nodes NodeMap) error { + timeout_action = thread.TimeoutAction() thread.ClearTimeout() return nil }) @@ -688,35 +599,23 @@ func NewThreadAbortedError(aborter NodeID) ThreadAbortedError { } // Default thread abort is to return a ThreadAbortedError -func ThreadAbort(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { - UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { - SendUpdate(ctx, thread, NewSignal(thread, "thread_aborted"), states) +func ThreadAbort(ctx * Context, thread Thread, signal GraphSignal) (string, error) { + UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { + thread.Signal(ctx, NewSignal(thread, "thread_aborted"), nodes) return nil }) return "", NewThreadAbortedError(signal.Source()) } // Default thread cancel is to finish the thread -func ThreadCancel(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { - UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { - SendUpdate(ctx, thread, NewSignal(thread, "thread_cancelled"), states) +func ThreadCancel(ctx * Context, thread Thread, signal GraphSignal) (string, error) { + UseStates(ctx, []Node{thread}, func(nodes NodeMap) error { + thread.Signal(ctx, NewSignal(thread, "thread_cancelled"), nodes) return nil }) return "", nil } -func NewBaseThreadState(name string, _type string) BaseThreadState { - return BaseThreadState{ - BaseLockableState: NewBaseLockableState(name, _type), - children: []Thread{}, - child_info: map[NodeID]ThreadInfo{}, - parent: nil, - timeout: time.Time{}, - timeout_action: "wait", - state_name: "init", - } -} - func NewThreadActions() ThreadActions{ actions := ThreadActions{} for k, v := range(BaseThreadActions) { @@ -745,48 +644,3 @@ var BaseThreadHandlers = ThreadHandlers{ "abort": ThreadAbort, "cancel": ThreadCancel, } - -func NewBaseThread(ctx * GraphContext, actions ThreadActions, handlers ThreadHandlers, state ThreadState) (BaseThread, error) { - lockable, err := NewBaseLockable(ctx, state) - if err != nil { - return BaseThread{}, err - } - - thread := BaseThread{ - BaseLockable: lockable, - Actions: actions, - Handlers: handlers, - child_waits: &sync.WaitGroup{}, - active: false, - active_lock: &sync.Mutex{}, - } - - return thread, nil -} - -func NewSimpleThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) { - state := NewBaseThreadState(name, "simple_thread") - - thread, err := NewBaseThread(ctx, actions, handlers, &state) - if err != nil { - return nil, err - } - - thread_ptr := &thread - - - if len(requirements) > 0 { - req_nodes := make([]GraphNode, len(requirements)) - for i, req := range(requirements) { - req_nodes[i] = req - } - err = UpdateStates(ctx, req_nodes, func(nodes NodeMap) error { - return LinkLockables(ctx, thread_ptr, requirements, nodes) - }) - if err != nil { - return nil, err - } - } - - return thread_ptr, nil -} diff --git a/thread_test.go b/thread_test.go index 436312d..54613b8 100644 --- a/thread_test.go +++ b/thread_test.go @@ -10,22 +10,21 @@ import ( func TestNewThread(t * testing.T) { ctx := testContext(t) - t1, err := NewSimpleThread(ctx, "Test thread 1", []Lockable{}, BaseThreadActions, BaseThreadHandlers) - fatalErr(t, err) + t1_r := NewSimpleThread(RandID(), "Test thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) + t1 := &t1_r go func(thread Thread) { time.Sleep(10*time.Millisecond) - UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { - SendUpdate(ctx, t1, CancelSignal(nil), states) - return nil + UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { + return t1.Signal(ctx, CancelSignal(nil), nodes) }) }(t1) - err = ThreadLoop(ctx, t1, "start") + err := ThreadLoop(ctx, t1, "start") fatalErr(t, err) - err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) (error) { - owner := states[t1.ID()].(ThreadState).Owner() + err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) (error) { + owner := t1.owner if owner != nil { return fmt.Errorf("Wrong owner %+v", owner) } @@ -36,17 +35,21 @@ func TestNewThread(t * testing.T) { func TestThreadWithRequirement(t * testing.T) { ctx := testContext(t) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r - t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{l1}, BaseThreadActions, BaseThreadHandlers) + t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) + t1 := &t1_r + + err := UpdateStates(ctx, []Node{l1, t1}, func(nodes NodeMap) error { + return LinkLockables(ctx, t1, []Lockable{l1}, nodes) + }) fatalErr(t, err) go func (thread Thread) { time.Sleep(10*time.Millisecond) - UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { - SendUpdate(ctx, t1, CancelSignal(nil), states) - return nil + UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { + return t1.Signal(ctx, CancelSignal(nil), nodes) }) }(t1) fatalErr(t, err) @@ -54,8 +57,8 @@ func TestThreadWithRequirement(t * testing.T) { err = ThreadLoop(ctx, t1, "start") fatalErr(t, err) - err = UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) (error) { - owner := states[l1.ID()].(LockableState).Owner() + err = UseStates(ctx, []Node{l1}, func(nodes NodeMap) (error) { + owner := l1.owner if owner != nil { return fmt.Errorf("Wrong owner %+v", owner) } @@ -66,22 +69,25 @@ func TestThreadWithRequirement(t * testing.T) { func TestThreadDBLoad(t * testing.T) { ctx := logTestContext(t, []string{}) - l1, err := NewSimpleLockable(ctx, "Test Lockable 1", []Lockable{}) - fatalErr(t, err) - - t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{l1}, BaseThreadActions, BaseThreadHandlers) - fatalErr(t, err) + l1_r := NewSimpleLockable(RandID(), "Test Lockable 1") + l1 := &l1_r + t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) + t1 := &t1_r + err := UpdateStates(ctx, []Node{t1, l1}, func(nodes NodeMap) error { + return LinkLockables(ctx, t1, []Lockable{l1}, nodes) + }) - UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { - SendUpdate(ctx, t1, CancelSignal(nil), states) - return nil + err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { + return t1.Signal(ctx, CancelSignal(nil), nodes) }) + fatalErr(t, err) + err = ThreadLoop(ctx, t1, "start") fatalErr(t, err) - err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { - ser, err := json.MarshalIndent(states[t1.ID()], "", " ") + err = UseStates(ctx, []Node{t1}, func(nodes NodeMap) error { + ser, err := json.MarshalIndent(nodes[t1.ID()], "", " ") fmt.Printf("\n%s\n\n", ser) return err }) @@ -89,8 +95,8 @@ func TestThreadDBLoad(t * testing.T) { t1_loaded, err := LoadNode(ctx, t1.ID()) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{t1_loaded}, func(states NodeStateMap) error { - ser, err := json.MarshalIndent(states[t1_loaded.ID()], "", " ") + err = UseStates(ctx, []Node{t1_loaded}, func(nodes NodeMap) error { + ser, err := json.MarshalIndent(nodes[t1_loaded.ID()], "", " ") fmt.Printf("\n%s\n\n", ser) return err }) @@ -98,15 +104,20 @@ func TestThreadDBLoad(t * testing.T) { func TestThreadUnlink(t * testing.T) { ctx := logTestContext(t, []string{}) - t1, err := NewSimpleThread(ctx, "Test Thread 1", []Lockable{}, BaseThreadActions, BaseThreadHandlers) - fatalErr(t, err) + t1_r := NewSimpleThread(RandID(), "Test Thread 1", "init", nil, BaseThreadActions, BaseThreadHandlers) + t1 := &t1_r + t2_r := NewSimpleThread(RandID(), "Test Thread 2", "init", nil, BaseThreadActions, BaseThreadHandlers) + t2 := &t2_r - t2, err := NewSimpleThread(ctx, "Test Thread 2", []Lockable{}, BaseThreadActions, BaseThreadHandlers) - fatalErr(t, err) - err = LinkThreads(ctx, t1, t2, nil) - fatalErr(t, err) + err := UpdateStates(ctx, []Node{t1, t2}, func(nodes NodeMap) error { + err := LinkThreads(ctx, t1, t2, nil, nodes) + if err != nil { + return err + } - err = UnlinkThreads(ctx, t1, t2) + return UnlinkThreads(ctx, t1, t2) + }) fatalErr(t, err) } +