diff --git a/gql.go b/gql.go index 4b6d999..b66b1a9 100644 --- a/gql.go +++ b/gql.go @@ -342,17 +342,15 @@ type GQLThread struct { type GQLThreadInfo struct { ThreadInfo `json:"-"` Start bool `json:"start"` - Started bool `json:"started"` - FirstAction string `json:"first_action"` - RestoreAction string `json:"restore_action"` + StartState string `json:"start_state"` + RestoreState string `json:"restore_state"` } -func NewGQLThreadInfo(start bool, first_action string, restore_action string) GQLThreadInfo { +func NewGQLThreadInfo(start bool, start_state string, restore_state string) GQLThreadInfo { info := GQLThreadInfo{ Start: start, - Started: false, - FirstAction: first_action, - RestoreAction: restore_action, + StartState: start_state, + RestoreState: restore_state, } return info } @@ -398,9 +396,8 @@ func LoadGQLThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap) ( func LoadGQLThreadInfo(ctx * GraphContext, raw map[string]interface{}) (ThreadInfo, error) { info := GQLThreadInfo{ Start: raw["start"].(bool), - Started: raw["started"].(bool), - FirstAction: raw["first_action"].(string), - RestoreAction: raw["restore_action"].(string), + StartState: raw["start_state"].(string), + RestoreState: raw["restore_state"].(string), } return &info, nil } @@ -429,60 +426,39 @@ var gql_actions ThreadActions = ThreadActions{ "restore": func(ctx * GraphContext, thread Thread) (string, error) { // Start all the threads that should be "started" ctx.Log.Logf("gql", "GQL_THREAD_RESTORE: %s", thread.ID()) - server, ok := thread.(*GQLThread) - if ok == false { - panic("thread is not *GQLThread") - } - - // Serve the GQL http and ws handlers - mux := http.NewServeMux() - mux.HandleFunc("/gql", GQLHandler(ctx, server)) - mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server)) - - // Server a graphiql interface(TODO make configurable whether to start this) - mux.HandleFunc("/graphiql", GraphiQLHandler()) - - // Server the ./site directory to /site (TODO make configurable with better defaults) - fs := http.FileServer(http.Dir("./site")) - mux.Handle("/site/", http.StripPrefix("/site", fs)) - - UseStates(ctx, []GraphNode{server}, func(states NodeStateMap)(error){ - server_state := states[server.ID()].(*GQLThreadState) - server.http_server = &http.Server{ - Addr: server_state.Listen, - Handler: mux, - } - return nil - }) - - server.http_done.Add(1) - go func(server *GQLThread) { - defer server.http_done.Done() - err := server.http_server.ListenAndServe() - if err != http.ErrServerClosed { - panic(fmt.Sprintf("Failed to start gql server: %s", err)) - } - }(server) UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap)(error) { server_state := thread.State().(*GQLThreadState) - for _, child := range(server_state.Children()) { - should_run := (server_state.child_info[child.ID()]).(*GQLThreadInfo) - if should_run.Started == true { - ChildGo(ctx, server_state, thread, child.ID(), should_run.RestoreAction) + return UpdateMoreStates(ctx, NodeList(server_state.Children()), nodes, func(nodes NodeMap) error { + for _, child := range(server_state.Children()) { + child_state := child.State().(ThreadState) + should_run := (server_state.child_info[child.ID()]).(*GQLThreadInfo) + if should_run.Start == true && child_state.State() != "finished" { + ChildGo(ctx, thread, child, should_run.RestoreState) + } } - } - return nil + return nil + }) }) - return "wait", nil + + return "start_server", nil }, "start": func(ctx * GraphContext, thread Thread) (string, error) { - ctx.Log.Logf("gql", "SERVER_STARTED") + ctx.Log.Logf("gql", "GQL_START") + err := ThreadStart(ctx, thread) + if err != nil { + return "", err + } + + return "start_server", nil + }, + "start_server": func(ctx * GraphContext, thread Thread) (string, error) { server, ok := thread.(*GQLThread) if ok == false { - panic(fmt.Sprintf("GQL_THREAD_START: %s is not GQLThread, %+v", thread.ID(), thread.State())) + return "", fmt.Errorf("GQL_THREAD_START: %s is not GQLThread, %+v", thread.ID(), thread.State()) } + ctx.Log.Logf("gql", "GQL_START_SERVER") // Serve the GQL http and ws handlers mux := http.NewServeMux() mux.HandleFunc("/gql", GQLHandler(ctx, server)) @@ -527,9 +503,8 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: tried to start %s whis is not a child") return nil } - if should_run.Start == true && should_run.Started == false { - ChildGo(ctx, server_state, thread, signal.Source(), should_run.FirstAction) - should_run.Started = true + if should_run.Start == true { + ChildGo(ctx, thread, server_state.Child(signal.Source()), should_run.StartState) } return nil }) @@ -540,7 +515,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ server := thread.(*GQLThread) server.http_server.Shutdown(context.TODO()) server.http_done.Wait() - return "", fmt.Errorf("GQLThread aborted by signal") + return "", NewThreadAbortedError(signal.Source()) }, "cancel": func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_CANCEL") diff --git a/gql_test.go b/gql_test.go index 8670575..e6b6892 100644 --- a/gql_test.go +++ b/gql_test.go @@ -5,6 +5,7 @@ import ( "time" "fmt" "encoding/json" + "errors" ) func TestGQLThread(t * testing.T) { @@ -35,7 +36,7 @@ func TestGQLThread(t * testing.T) { fatalErr(t, err) }(gql_thread) - err = RunThread(ctx, gql_thread, "start") + err = ThreadLoop(ctx, gql_thread, "start") fatalErr(t, err) } @@ -58,13 +59,17 @@ func TestGQLDBLoad(t * testing.T) { fatalErr(t, err) err = UseStates(ctx, []GraphNode{gql}, func(states NodeStateMap) error { SendUpdate(ctx, gql, NewSignal(t1, "child_added"), states) - SendUpdate(ctx, gql, CancelSignal(nil), states) + SendUpdate(ctx, gql, AbortSignal(nil), states) return nil }) - err = RunThread(ctx, gql, "start") - fatalErr(t, err) + err = ThreadLoop(ctx, gql, "start") + if errors.Is(err, NewThreadAbortedError("")) { + ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) + } else { + fatalErr(t, err) + } - (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1, 100*time.Millisecond, "Didn't receive thread_done from t1 on t1") + (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", t1, 100*time.Millisecond, "Didn't receive thread_abort from t1 on t1") err = UseStates(ctx, []GraphNode{gql, t1}, func(states NodeStateMap) error { ser1, err := json.MarshalIndent(states[gql.ID()], "", " ") @@ -89,12 +94,16 @@ func TestGQLDBLoad(t * testing.T) { fmt.Printf("\n%s\n\n", ser) return err }) - SendUpdate(ctx, gql_loaded, CancelSignal(nil), states) + SendUpdate(ctx, gql_loaded, AbortSignal(nil), states) return err }) - err = RunThread(ctx, gql_loaded.(Thread), "restore") - fatalErr(t, err) - (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_done", t1_loaded, 100*time.Millisecond, "Dicn't received update_done on t1_loaded from t1_loaded") + err = ThreadLoop(ctx, gql_loaded.(Thread), "restore") + if errors.Is(err, NewThreadAbortedError("")) { + ctx.Log.Logf("test", "Main thread aborted by signal: %s", err) + } else { + fatalErr(t, err) + } + (*GraphTester)(t).WaitForValue(ctx, update_channel, "thread_aborted", t1_loaded, 100*time.Millisecond, "Dicn't received thread_aborted on t1_loaded from t1_loaded") } diff --git a/graph.go b/graph.go index c7bb6d3..3611392 100644 --- a/graph.go +++ b/graph.go @@ -378,7 +378,7 @@ func (signal BaseSignal) Type() string { } func NewBaseSignal(source GraphNode, _type string, direction SignalDirection) BaseSignal { - var source_id NodeID = "" + var source_id NodeID = "nil" if source != nil { source_id = source.ID() } diff --git a/thread.go b/thread.go index a616842..ff43a6a 100644 --- a/thread.go +++ b/thread.go @@ -68,6 +68,7 @@ type ThreadState interface { RemoveChild(child Thread) Start() error Stop() error + State() string TimeoutAction() string SetTimeout(end_time time.Time, action string) @@ -75,11 +76,11 @@ type ThreadState interface { type BaseThreadState struct { BaseLockableState + state_name string parent Thread children []Thread child_info map[NodeID] ThreadInfo InfoType reflect.Type - running bool timeout time.Time timeout_action string } @@ -89,6 +90,7 @@ type BaseThreadStateJSON struct { Children map[NodeID]interface{} `json:"children"` Timeout time.Time `json:"timeout"` TimeoutAction string `json:"timeout_action"` + StateName string `json:"state_name"` BaseLockableStateJSON } @@ -111,6 +113,7 @@ func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON { Children: children, Timeout: state.timeout, TimeoutAction: state.timeout_action, + StateName: state.state_name, BaseLockableStateJSON: lockable_state, } @@ -124,6 +127,8 @@ func RestoreBaseThread(ctx * GraphContext, id NodeID, actions ThreadActions, han Actions: actions, Handlers: handlers, child_waits: &sync.WaitGroup{}, + active: false, + active_lock: &sync.Mutex{}, } return thread @@ -146,7 +151,7 @@ func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_no children: make([]Thread, len(j.Children)), child_info: map[NodeID]ThreadInfo{}, InfoType: nil, - running: false, + state_name: j.StateName, timeout: j.Timeout, timeout_action: j.TimeoutAction, } @@ -228,19 +233,38 @@ func (state * BaseThreadState) MarshalJSON() ([]byte, error) { return json.Marshal(&thread_state) } +func (state * BaseThreadState) State() string { + return state.state_name +} + +func (state * BaseThreadState) SetState(new_state string) error { + if new_state == "init" { + return fmt.Errorf("Cannot set a thread to 'init' with SetState") + } else if new_state == "finished" { + return fmt.Errorf("Cannot set a thread to 'finished' with SetState") + } else if new_state == "started" { + return fmt.Errorf("Cannot set a thread to 'started' with SetState") + } + + state.state_name = new_state + return nil +} + func (state * BaseThreadState) Start() error { - if state.running == true { - return fmt.Errorf("Cannot start a running thread") + if state.state_name != "init" { + return fmt.Errorf("Cannot start a thread that's already started") } - state.running = true + state.state_name = "started" return nil } func (state * BaseThreadState) Stop() error { - if state.running == false { - return fmt.Errorf("Cannot stop a thread that's not running") + if state.state_name == "finished" { + return fmt.Errorf("Cannot stop a finished thread") + } else if state.state_name == "init" { + return fmt.Errorf("Cannot stop a thread that hasn't been started") } - state.running = false + state.state_name = "finished" return nil } @@ -404,6 +428,8 @@ type Thread interface { ClearTimeout() ChildWaits() *sync.WaitGroup + Start() error + Stop() error } // Requires that thread is already locked for read in UseStates @@ -430,16 +456,12 @@ func FindChild(ctx * GraphContext, thread Thread, id NodeID, states NodeStateMap return nil } -func ChildGo(ctx * GraphContext, thread_state ThreadState, thread Thread, child_id NodeID, first_action string) { - child := thread_state.Child(child_id) - if child == nil { - panic(fmt.Errorf("Child not in thread, can't start %s", child_id)) - } +func ChildGo(ctx * GraphContext, thread Thread, child Thread, first_action string) { thread.ChildWaits().Add(1) go func(child Thread) { ctx.Log.Logf("thread", "THREAD_START_CHILD: %s from %s", thread.ID(), child.ID()) defer thread.ChildWaits().Done() - err := RunThread(ctx, child, first_action) + err := ThreadLoop(ctx, child, first_action) if err != nil { ctx.Log.Logf("thread", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err) } else { @@ -448,39 +470,15 @@ func ChildGo(ctx * GraphContext, thread_state ThreadState, thread Thread, child_ }(child) } -func RunThread(ctx * GraphContext, thread Thread, first_action string) error { - ctx.Log.Logf("thread", "THREAD_RUN: %s", thread.ID()) - - err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) { - thread_state := thread.State().(ThreadState) - owner_id := NodeID("") - if thread_state.Owner() != nil { - owner_id = thread_state.Owner().ID() - } - if owner_id != thread.ID() { - return LockLockables(ctx, []Lockable{thread}, thread, nodes) - } - return nil - }) - if err != nil { - return err - } - - err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { - thread_state := states[thread.ID()].(ThreadState) - if thread_state.Owner() == nil { - return fmt.Errorf("THREAD_RUN_NOT_LOCKED: %s", thread_state.Name()) - } else if thread_state.Owner().ID() != thread.ID() { - return fmt.Errorf("THREAD_RUN_RESOURCE_ALREADY_LOCKED: %s, %s", thread_state.Name(), thread_state.Owner().ID()) - } else if err := thread_state.Start(); err != nil { - return fmt.Errorf("THREAD_START_ERR: %e", err) - } - return nil - }) +// Main Loop for Threads +func ThreadLoop(ctx * GraphContext, thread Thread, first_action string) error { + // Start the thread, error if double-started + ctx.Log.Logf("thread", "THREAD_LOOP_START: %s - %s", thread.ID(), first_action) + err := thread.Start() if err != nil { + ctx.Log.Logf("thread", "THREAD_LOOP_START_ERR: %e", err) return err } - next_action := first_action for next_action != "" { action, exists := thread.Action(next_action) @@ -496,31 +494,27 @@ func RunThread(ctx * GraphContext, thread Thread, first_action string) error { } } - err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { - thread_state := states[thread.ID()].(ThreadState) - err := thread_state.Stop() - return err - - }) + err = thread.Stop() if err != nil { - ctx.Log.Logf("thread", "THREAD_RUN_STOP_ERR: %e", err) + ctx.Log.Logf("thread", "THREAD_LOOP_STOP_ERR: %e", err) return err } err = UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) { + thread_state := thread.State().(ThreadState) + err := thread_state.Stop() + if err != nil { + return err + } return UnlockLockables(ctx, []Lockable{thread}, thread, nodes) }) + if err != nil { - ctx.Log.Logf("thread", "THREAD_RUN_UNLOCK_ERR: %e", err) + ctx.Log.Logf("thread", "THREAD_LOOP_UNLOCK_ERR: %e", err) return err } - err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { - SendUpdate(ctx, thread, NewSignal(thread, "thread_done"), states) - return nil - }) - - ctx.Log.Logf("thread", "THREAD_RUN_DONE: %s", thread.ID()) + ctx.Log.Logf("thread", "THREAD_LOOP_DONE: %s", thread.ID()) return nil } @@ -542,12 +536,34 @@ type BaseThread struct { timeout_chan <-chan time.Time child_waits *sync.WaitGroup + active bool + active_lock *sync.Mutex } func (thread * BaseThread) ChildWaits() *sync.WaitGroup { return thread.child_waits } +func (thread * BaseThread) Start() error { + thread.active_lock.Lock() + defer thread.active_lock.Unlock() + if thread.active == true { + return fmt.Errorf("%s is active, cannot start", thread.ID()) + } + thread.active = true + return nil +} + +func (thread * BaseThread) Stop() error { + thread.active_lock.Lock() + defer thread.active_lock.Unlock() + if thread.active == false { + return fmt.Errorf("%s is not active, cannot stop", thread.ID()) + } + thread.active = false + return nil +} + func (thread * BaseThread) CanLock(node GraphNode, state LockableState) error { return nil } @@ -586,8 +602,35 @@ func (thread * BaseThread) SetTimeout(end time.Time) { thread.timeout_chan = time.After(time.Until(end)) } +var ThreadStart = func(ctx * GraphContext, thread Thread) error { + err := UpdateStates(ctx, []GraphNode{thread}, func(nodes NodeMap) (error) { + thread_state := thread.State().(ThreadState) + owner_id := NodeID("") + if thread_state.Owner() != nil { + owner_id = thread_state.Owner().ID() + } + if owner_id != thread.ID() { + err := LockLockables(ctx, []Lockable{thread}, thread, nodes) + if err != nil { + return err + } + } + return thread_state.Start() + }) + + if err != nil { + return err + } + + return nil +} + var ThreadDefaultStart = func(ctx * GraphContext, thread Thread) (string, error) { - ctx.Log.Logf("thread", "THREAD_DEFAUL_START: %s", thread.ID()) + ctx.Log.Logf("thread", "THREAD_DEFAULT_START: %s", thread.ID()) + err := ThreadStart(ctx, thread) + if err != nil { + return "", err + } return "wait", nil } @@ -630,11 +673,35 @@ var ThreadWait = func(ctx * GraphContext, thread Thread) (string, error) { } } -var ThreadAbort = func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { - return "", fmt.Errorf("%s aborted by signal from %s", thread.ID(), signal.Source()) +type ThreadAbortedError NodeID + +func (e ThreadAbortedError) Is(target error) bool { + error_type := reflect.TypeOf(ThreadAbortedError("")) + target_type := reflect.TypeOf(target) + return error_type == target_type +} +func (e ThreadAbortedError) Error() string { + return fmt.Sprintf("Aborted by %s", string(e)) +} +func NewThreadAbortedError(aborter NodeID) ThreadAbortedError { + return ThreadAbortedError(aborter) +} + +// Default thread abort is to return a ThreadAbortedError +func ThreadAbort(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { + SendUpdate(ctx, thread, NewSignal(thread, "thread_aborted"), states) + return nil + }) + return "", NewThreadAbortedError(signal.Source()) } -var ThreadCancel = func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { +// Default thread cancel is to finish the thread +func ThreadCancel(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) error { + SendUpdate(ctx, thread, NewSignal(thread, "thread_cancelled"), states) + return nil + }) return "", nil } @@ -646,6 +713,7 @@ func NewBaseThreadState(name string, _type string) BaseThreadState { parent: nil, timeout: time.Time{}, timeout_action: "wait", + state_name: "init", } } @@ -689,6 +757,8 @@ func NewBaseThread(ctx * GraphContext, actions ThreadActions, handlers ThreadHan Actions: actions, Handlers: handlers, child_waits: &sync.WaitGroup{}, + active: false, + active_lock: &sync.Mutex{}, } return thread, nil diff --git a/thread_test.go b/thread_test.go index 6d1d129..436312d 100644 --- a/thread_test.go +++ b/thread_test.go @@ -21,7 +21,7 @@ func TestNewThread(t * testing.T) { }) }(t1) - err = RunThread(ctx, t1, "start") + err = ThreadLoop(ctx, t1, "start") fatalErr(t, err) err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) (error) { @@ -51,7 +51,7 @@ func TestThreadWithRequirement(t * testing.T) { }(t1) fatalErr(t, err) - err = RunThread(ctx, t1, "start") + err = ThreadLoop(ctx, t1, "start") fatalErr(t, err) err = UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) (error) { @@ -77,7 +77,7 @@ func TestThreadDBLoad(t * testing.T) { SendUpdate(ctx, t1, CancelSignal(nil), states) return nil }) - err = RunThread(ctx, t1, "start") + err = ThreadLoop(ctx, t1, "start") fatalErr(t, err) err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {