diff --git a/gql.go b/gql.go index b405570..48117c4 100644 --- a/gql.go +++ b/gql.go @@ -321,79 +321,109 @@ type FieldMap map[string]*graphql.Field type GQLThread struct { BaseThread + http_server *http.Server + http_done *sync.WaitGroup extended_types ObjTypeMap extended_queries FieldMap extended_subscriptions FieldMap extended_mutations FieldMap } +type GQLThreadInfo struct { + ThreadInfo + Start bool + Started bool +} + +func NewGQLThreadInfo(start bool) GQLThreadInfo { + info := GQLThreadInfo{ + Start: start, + Started: false, + } + return info +} + type GQLThreadState struct { BaseThreadState Listen string } func NewGQLThreadState(listen string) GQLThreadState { - return GQLThreadState{ + state := GQLThreadState{ BaseThreadState: NewBaseThreadState("GQL Server"), Listen: listen, } + state.InfoType = reflect.TypeOf((*GQLThreadInfo)(nil)) + return state } var gql_actions ThreadActions = ThreadActions{ "start": func(ctx * GraphContext, thread Thread) (string, error) { ctx.Log.Logf("gql", "SERVER_STARTED") server := thread.(*GQLThread) - go func() { - ctx.Log.Logf("gql", "GOROUTINE_START for %s", server.ID()) - - mux := http.NewServeMux() - http_handler, ws_handler := MakeGQLHandlers(ctx, server) - mux.HandleFunc("/gql", http_handler) - mux.HandleFunc("/gqlws", ws_handler) - mux.HandleFunc("/graphiql", GraphiQLHandler()) - fs := http.FileServer(http.Dir("./site")) - mux.Handle("/site/", http.StripPrefix("/site", fs)) - - srv_if, _ := UseStates(ctx, []GraphNode{server}, func(states []NodeState)(interface{}, error){ - server_state := states[0].(*GQLThreadState) - return &http.Server{ - Addr: server_state.Listen, - Handler: mux, - }, nil - }) - srv := srv_if.(*http.Server) - - http_done := &sync.WaitGroup{} - http_done.Add(1) - go func(srv *http.Server, http_done *sync.WaitGroup) { - defer http_done.Done() - err := srv.ListenAndServe() - if err != http.ErrServerClosed { - panic(fmt.Sprintf("Failed to start gql server: %s", err)) - } - }(srv, http_done) - - for true { - select { - case signal:=<-server.signal: - if signal.Type() == "abort" || signal.Type() == "cancel" { - err := srv.Shutdown(context.Background()) - if err != nil{ - panic(fmt.Sprintf("Failed to shutdown gql server: %s", err)) - } - http_done.Wait() - break - } - ctx.Log.Logf("gql", "GOROUTINE_SIGNAL for %s: %+v", server.ID(), signal) - // Take signals to resource and send to GQL subscriptions - } + + mux := http.NewServeMux() + http_handler, ws_handler := MakeGQLHandlers(ctx, server) + mux.HandleFunc("/gql", http_handler) + mux.HandleFunc("/gqlws", ws_handler) + mux.HandleFunc("/graphiql", GraphiQLHandler()) + fs := http.FileServer(http.Dir("./site")) + mux.Handle("/site/", http.StripPrefix("/site", fs)) + + UseStates(ctx, []GraphNode{server}, func(states []NodeState)(interface{}, error){ + server_state := states[0].(*GQLThreadState) + server.http_server = &http.Server{ + Addr: server_state.Listen, + Handler: mux, } - }() + return nil, nil + }) + + server.http_done.Add(1) + go func(server *GQLThread) { + defer server.http_done.Done() + err := server.http_server.ListenAndServe() + if err != http.ErrServerClosed { + panic(fmt.Sprintf("Failed to start gql server: %s", err)) + } + }(server) + return "wait", nil }, } var gql_handlers ThreadHandlers = ThreadHandlers{ + "child_added": func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) + UseStates(ctx, []GraphNode{thread}, func(states []NodeState)(interface{}, error) { + server_state := states[0].(*GQLThreadState) + should_run, exists := server_state.child_info[signal.Source()].(*GQLThreadInfo) + if exists == false { + ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: tried to start %s whis is not a child") + return nil, nil + } + if should_run.Start == true && should_run.Started == false { + ChildGo(ctx, server_state, thread, signal.Source()) + should_run.Started = false + } + return nil, nil + }) + return "wait", nil + }, + "abort": func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + ctx.Log.Logf("gql", "GQL_ABORT") + server := thread.(*GQLThread) + server.http_server.Shutdown(context.TODO()) + server.http_done.Wait() + return "", fmt.Errorf("GQLThread aborted by signal") + }, + "cancel": func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { + ctx.Log.Logf("gql", "GQL_CANCEL") + server := thread.(*GQLThread) + server.http_server.Shutdown(context.TODO()) + server.http_done.Wait() + return "", nil + }, } func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, extended_types ObjTypeMap, extended_queries FieldMap, extended_mutations FieldMap, extended_subscriptions FieldMap) (*GQLThread, error) { @@ -405,6 +435,8 @@ func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, ex thread := &GQLThread { BaseThread: base_thread, + http_server: nil, + http_done: &sync.WaitGroup{}, extended_types: extended_types, extended_queries: extended_queries, extended_mutations: extended_mutations, diff --git a/gql_test.go b/gql_test.go index 93c627b..e8d67e4 100644 --- a/gql_test.go +++ b/gql_test.go @@ -6,18 +6,28 @@ import ( ) func TestGQLThread(t * testing.T) { + println("TEST_GQL") ctx := logTestContext(t, []string{"gqlws", "gql", "thread", "update"}) gql_thread, err := NewGQLThread(ctx, ":8080", []Lockable{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) fatalErr(t, err) - test_thread, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) + test_thread_1, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) fatalErr(t, err) - err = LinkThreads(ctx, gql_thread, test_thread, nil) + test_thread_2, err := NewSimpleBaseThread(ctx, "Test thread 2", []Lockable{}, ThreadActions{}, ThreadHandlers{}) + fatalErr(t, err) + + i1 := NewGQLThreadInfo(true) + err = LinkThreads(ctx, gql_thread, test_thread_1, &i1) + fatalErr(t, err) + + i2 := NewGQLThreadInfo(false) + err = LinkThreads(ctx, gql_thread, test_thread_2, &i2) fatalErr(t, err) go func(thread Thread){ time.Sleep(10*time.Millisecond) + // Check that test_thread_1 is running and test_thread_2 is not SendUpdate(ctx, thread, CancelSignal(nil)) }(gql_thread) diff --git a/thread.go b/thread.go index f671cfd..33317a1 100644 --- a/thread.go +++ b/thread.go @@ -3,6 +3,7 @@ package graphvent import ( "fmt" "time" + "sync" "errors" "reflect" "encoding/json" @@ -53,8 +54,11 @@ type ThreadState interface { Parent() Thread SetParent(parent Thread) Children() []Thread + Child(id NodeID) Thread ChildInfo(child NodeID) ThreadInfo AddChild(child Thread, info ThreadInfo) error + Start() error + Stop() error } type BaseThreadState struct { @@ -63,6 +67,7 @@ type BaseThreadState struct { children []Thread child_info map[NodeID] ThreadInfo InfoType reflect.Type + running bool } type BaseThreadStateJSON struct { @@ -90,6 +95,22 @@ func (state * BaseThreadState) MarshalJSON() ([]byte, error) { }) } +func (state * BaseThreadState) Start() error { + if state.running == true { + return fmt.Errorf("Cannot start a running thread") + } + state.running = true + return nil +} + +func (state * BaseThreadState) Stop() error { + if state.running == false { + return fmt.Errorf("Cannot stop a thread that's not running") + } + state.running = false + return nil +} + func (state * BaseThreadState) Parent() Thread { return state.parent } @@ -102,6 +123,17 @@ func (state * BaseThreadState) Children() []Thread { return state.children } +func (state * BaseThreadState) Child(id NodeID) Thread { + for _, child := range(state.children) { + if child.ID() == id { + return child + } + } + return nil +} + + + func (state * BaseThreadState) ChildInfo(child NodeID) ThreadInfo { return state.child_info[child] } @@ -188,6 +220,8 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf return err } + SendUpdate(ctx, thread, NewSignal(child, "child_added")) + return nil } @@ -202,6 +236,8 @@ type Thread interface { ClearTimeout() Timeout() <-chan time.Time TimeoutAction() string + + ChildWaits() *sync.WaitGroup } func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id NodeID) Thread { @@ -228,8 +264,26 @@ func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id N return nil } +func ChildGo(ctx * GraphContext, thread_state ThreadState, thread Thread, child_id NodeID) { + child := thread_state.Child(child_id) + if child == nil { + panic(fmt.Errorf("Child not in thread, can't start %s", child_id)) + } + thread.ChildWaits().Add(1) + go func(child Thread) { + ctx.Log.Logf("gql", "THREAD_START_CHILD: %s", child.ID()) + defer thread.ChildWaits().Done() + err := RunThread(ctx, child) + if err != nil { + ctx.Log.Logf("gql", "THREAD_CHILD_RUN_ERR: %s %e", child.ID(), err) + } else { + ctx.Log.Logf("gql", "THREAD_CHILD_RUN_DONE: %s", child.ID()) + } + }(child) +} + func RunThread(ctx * GraphContext, thread Thread) error { - ctx.Log.Logf("thread", "EVENT_RUN: %s", thread.ID()) + ctx.Log.Logf("thread", "THREAD_RUN: %s", thread.ID()) err := LockLockable(ctx, thread, thread, nil) if err != nil { @@ -239,9 +293,11 @@ func RunThread(ctx * GraphContext, thread Thread) error { _, err = UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { thread_state := states[0].(ThreadState) if thread_state.Owner() == nil { - return nil, fmt.Errorf("EVENT_RUN_NOT_LOCKED: %s", thread_state.Name()) + return nil, fmt.Errorf("THREAD_RUN_NOT_LOCKED: %s", thread_state.Name()) } else if thread_state.Owner().ID() != thread.ID() { - return nil, fmt.Errorf("EVENT_RUN_RESOURCE_ALREADY_LOCKED: %s, %s", thread_state.Name(), thread_state.Owner().ID()) + return nil, fmt.Errorf("THREAD_RUN_RESOURCE_ALREADY_LOCKED: %s, %s", thread_state.Name(), thread_state.Owner().ID()) + } else if err := thread_state.Start(); err != nil { + return nil, fmt.Errorf("THREAD_START_ERR: %e", err) } return nil, nil }) @@ -256,16 +312,26 @@ func RunThread(ctx * GraphContext, thread Thread) error { return errors.New(error_str) } - ctx.Log.Logf("thread", "EVENT_ACTION: %s - %s", thread.ID(), next_action) + ctx.Log.Logf("thread", "THREAD_ACTION: %s - %s", thread.ID(), next_action) next_action, err = action(ctx, thread) if err != nil { return err } } + _, err = UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (interface{}, error) { + thread_state := states[0].(ThreadState) + err := thread_state.Stop() + return nil, err + }) + if err != nil { + ctx.Log.Logf("thread", "THREAD_RUN_STOP_ERR: %e", err) + return err + } + SendUpdate(ctx, thread, NewSignal(thread, "thread_done")) - ctx.Log.Logf("thread", "EVENT_RUN_DONE: %s", thread.ID()) + ctx.Log.Logf("thread", "THREAD_RUN_DONE: %s", thread.ID()) return nil } @@ -287,6 +353,12 @@ type BaseThread struct { timeout <-chan time.Time timeout_action string + + child_waits sync.WaitGroup +} + +func (thread * BaseThread) ChildWaits() *sync.WaitGroup { + return &thread.child_waits } func (thread * BaseThread) Lock(node GraphNode, state LockableState) error {