From d2b32bac5e77d65af1b2c123f368e26cb1740236 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 1 Jul 2023 13:03:28 -0600 Subject: [PATCH] Moved GQL context information out of node runtime state and into context --- gql.go | 175 ++++++++++++++++++----------------------------- gql_graph.go | 52 +++----------- gql_test.go | 2 +- graph.go | 131 ++++++++++++++++++++++++++++++++++- lockable.go | 77 ++------------------- lockable_test.go | 16 +++-- thread.go | 79 ++++++++++++++++++++- thread_test.go | 30 ++++++++ 8 files changed, 324 insertions(+), 238 deletions(-) diff --git a/gql.go b/gql.go index 410f1da..7d80e2b 100644 --- a/gql.go +++ b/gql.go @@ -112,7 +112,11 @@ func enableCORS(w *http.ResponseWriter) { (*w).Header().Set("Access-Control-Allow-Methods", "*") } -func GQLHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Context) func(http.ResponseWriter, *http.Request) { +func GQLHandler(ctx * GraphContext, server * GQLThread) func(http.ResponseWriter, *http.Request) { + gql_ctx := context.Background() + gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) + gql_ctx = context.WithValue(gql_ctx, "gql_server", server) + return func(w http.ResponseWriter, r * http.Request) { ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) enableCORS(&w) @@ -131,7 +135,7 @@ func GQLHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Conte json.Unmarshal(str, &query) params := graphql.Params{ - Schema: schema, + Schema: ctx.GQL.Schema, Context: gql_ctx, RequestString: query.Query, } @@ -199,7 +203,11 @@ func GQLWSDo(ctx * GraphContext, p graphql.Params) chan *graphql.Result { return sendOneResultAndClose(res) } -func GQLWSHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Context) func(http.ResponseWriter, *http.Request) { +func GQLWSHandler(ctx * GraphContext, server * GQLThread) func(http.ResponseWriter, *http.Request) { + gql_ctx := context.Background() + gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) + gql_ctx = context.WithValue(gql_ctx, "gql_server", server) + return func(w http.ResponseWriter, r * http.Request) { ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr) enableCORS(&w) @@ -251,7 +259,7 @@ func GQLWSHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Con } else if msg.Type == "subscribe" { ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) params := graphql.Params{ - Schema: schema, + Schema: ctx.GQL.Schema, Context: gql_ctx, RequestString: msg.Payload.Query, } @@ -316,6 +324,7 @@ func GQLWSHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Con } } +type TypeList []graphql.Type type ObjTypeMap map[reflect.Type]*graphql.Object type FieldMap map[string]*graphql.Field @@ -323,10 +332,6 @@ type GQLThread struct { BaseThread http_server *http.Server http_done *sync.WaitGroup - extended_types ObjTypeMap - extended_queries FieldMap - extended_subscriptions FieldMap - extended_mutations FieldMap } type GQLThreadInfo struct { @@ -343,11 +348,55 @@ func NewGQLThreadInfo(start bool) GQLThreadInfo { return info } +type GQLThreadStateJSON struct { + BaseThreadStateJSON + Listen string +} + type GQLThreadState struct { BaseThreadState Listen string } +func (state * GQLThreadState) MarshalJSON() ([]byte, error) { + thread_state := SaveBaseThreadState(&state.BaseThreadState) + return json.Marshal(&GQLThreadStateJSON{ + BaseThreadStateJSON: thread_state, + Listen: state.Listen, + }) +} + +func LoadGQLThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap) (NodeState, error){ + var j GQLThreadStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + thread_state, err := RestoreBaseThreadState(ctx, j.BaseThreadStateJSON, loaded_nodes) + if err != nil { + return nil, err + } + + state := &GQLThreadState{ + BaseThreadState: *thread_state, + Listen: j.Listen, + } + + return state, nil +} + +func LoadGQLThread(ctx * GraphContext, id NodeID) (GraphNode, error) { + thread := RestoreBaseThread(ctx, id) + gql_thread := GQLThread{ + BaseThread: thread, + http_server: nil, + http_done: &sync.WaitGroup{}, + } + + return &gql_thread, nil +} + func NewGQLThreadState(listen string) GQLThreadState { state := GQLThreadState{ BaseThreadState: NewBaseThreadState("GQL Server", "gql_thread"), @@ -362,11 +411,15 @@ var gql_actions ThreadActions = ThreadActions{ ctx.Log.Logf("gql", "SERVER_STARTED") server := thread.(*GQLThread) + // Serve the GQL http and ws handlers mux := http.NewServeMux() - http_handler, ws_handler := MakeGQLHandlers(ctx, server) - mux.HandleFunc("/gql", http_handler) - mux.HandleFunc("/gqlws", ws_handler) + mux.HandleFunc("/gql", GQLHandler(ctx, server)) + mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server)) + + // Server a graphiql interface(TODO make configurable whether to start this) mux.HandleFunc("/graphiql", GraphiQLHandler()) + + // Server the ./site directory to /site (TODO make configurable with better defaults) fs := http.FileServer(http.Dir("./site")) mux.Handle("/site/", http.StripPrefix("/site", fs)) @@ -426,7 +479,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{ }, } -func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, extended_types ObjTypeMap, extended_queries FieldMap, extended_mutations FieldMap, extended_subscriptions FieldMap) (*GQLThread, error) { +func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable) (*GQLThread, error) { state := NewGQLThreadState(listen) base_thread, err := NewBaseThread(ctx, gql_actions, gql_handlers, &state) if err != nil { @@ -437,10 +490,6 @@ func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, ex BaseThread: base_thread, http_server: nil, http_done: &sync.WaitGroup{}, - extended_types: extended_types, - extended_queries: extended_queries, - extended_mutations: extended_mutations, - extended_subscriptions: extended_subscriptions, } err = LinkLockables(ctx, thread, requirements) @@ -449,97 +498,3 @@ func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, ex } return thread, nil } - -func MakeGQLHandlers(ctx * GraphContext, server * GQLThread) (func(http.ResponseWriter, *http.Request), func(http.ResponseWriter, *http.Request)) { - valid_nodes := map[reflect.Type]*graphql.Object{} - valid_lockables := map[reflect.Type]*graphql.Object{} - valid_threads := map[reflect.Type]*graphql.Object{} - valid_lockables[reflect.TypeOf((*BaseLockable)(nil))] = GQLTypeBaseLockable() - for t, v := range(valid_lockables) { - valid_nodes[t] = v - } - valid_threads[reflect.TypeOf((*BaseThread)(nil))] = GQLTypeBaseThread() - valid_threads[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread() - for t, v := range(valid_threads) { - valid_lockables[t] = v - valid_nodes[t] = v - } - - - gql_types := []graphql.Type{GQLTypeSignal(), GQLTypeSignalInput()} - for _, v := range(valid_nodes) { - gql_types = append(gql_types, v) - } - - node_type := reflect.TypeOf((*GraphNode)(nil)).Elem() - lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() - thread_type := reflect.TypeOf((*Thread)(nil)).Elem() - - for go_t, gql_t := range(server.extended_types) { - if go_t.Implements(node_type) { - valid_nodes[go_t] = gql_t - } - if go_t.Implements(lockable_type) { - valid_lockables[go_t] = gql_t - } - if go_t.Implements(thread_type) { - valid_threads[go_t] = gql_t - } - gql_types = append(gql_types, gql_t) - } - - gql_queries := graphql.Fields{ - "Self": GQLQuerySelf(), - } - - for key, value := range(server.extended_queries) { - gql_queries[key] = value - } - - gql_subscriptions := graphql.Fields{ - "Update": GQLSubscriptionUpdate(), - } - - for key, value := range(server.extended_subscriptions) { - gql_subscriptions[key] = value - } - - gql_mutations := graphql.Fields{ - "SendUpdate": GQLMutationSendUpdate(), - } - - for key, value := range(server.extended_mutations) { - gql_mutations[key] = value - } - - schemaConfig := graphql.SchemaConfig{ - Types: gql_types, - Query: graphql.NewObject(graphql.ObjectConfig{ - Name: "Query", - Fields: gql_queries, - }), - Mutation: graphql.NewObject(graphql.ObjectConfig{ - Name: "Mutation", - Fields: gql_mutations, - }), - Subscription: graphql.NewObject(graphql.ObjectConfig{ - Name: "Subscription", - Fields: gql_subscriptions, - }), - } - - schema, err := graphql.NewSchema(schemaConfig) - if err != nil{ - panic(err) - } - gql_ctx := context.Background() - gql_ctx = context.WithValue(gql_ctx, "valid_nodes", valid_nodes) - gql_ctx = context.WithValue(gql_ctx, "node_type", &node_type) - gql_ctx = context.WithValue(gql_ctx, "valid_lockables", valid_lockables) - gql_ctx = context.WithValue(gql_ctx, "lockable_type", &lockable_type) - gql_ctx = context.WithValue(gql_ctx, "valid_threads", valid_threads) - gql_ctx = context.WithValue(gql_ctx, "thread_type", &thread_type) - gql_ctx = context.WithValue(gql_ctx, "gql_server", server) - gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx) - return GQLHandler(ctx, schema, gql_ctx), GQLWSHandler(ctx, schema, gql_ctx) -} diff --git a/gql_graph.go b/gql_graph.go index 291cb8c..a2e9f88 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -16,18 +16,9 @@ func GQLInterfaceGraphNode() *graphql.Interface { if ok == false { return nil } - valid_nodes, ok := p.Context.Value("valid_nodes").(map[reflect.Type]*graphql.Object) - if ok == false { - ctx.Log.Logf("gql", "Failed to get valid_nodes from Context") - return nil - } - - node_type, ok := p.Context.Value("node_type").(*reflect.Type) - if ok == false { - ctx.Log.Logf("gql", "Failed to get node_type from Context: %+v", p.Context.Value("node_type")) - return nil - } + valid_nodes := ctx.GQL.ValidNodes + node_type := ctx.GQL.NodeType p_type := reflect.TypeOf(p.Value) for key, value := range(valid_nodes) { @@ -36,7 +27,7 @@ func GQLInterfaceGraphNode() *graphql.Interface { } } - if p_type.Implements(*node_type) { + if p_type.Implements(node_type) { return GQLTypeBaseNode() } @@ -75,33 +66,22 @@ func GQLInterfaceThread() *graphql.Interface { if ok == false { return nil } - valid_threads, ok := p.Context.Value("valid_threads").(map[reflect.Type]*graphql.Object) - if ok == false { - ctx.Log.Logf("gql", "Failed to get valid_threads from Context") - return nil - } - - thread_type, ok := p.Context.Value("thread_type").(*reflect.Type) - if ok == false { - ctx.Log.Logf("gql", "Failed to get thread_type from Context: %+v", p.Context.Value("thread_type")) - return nil - } + valid_threads := ctx.GQL.ValidThreads + thread_type := ctx.GQL.ThreadType p_type := reflect.TypeOf(p.Value) - for key, value := range(valid_threads) { if p_type == key { return value } } - if p_type.Implements(*thread_type) { + if p_type.Implements(thread_type) { return GQLTypeBaseThread() } - ctx.Log.Logf("gql", "Found no type that matches %+v: %+v", p_type, p_type.Implements(*thread_type)) - + ctx.Log.Logf("gql", "Found no type that matches %+v: %+v", p_type, p_type.Implements(thread_type)) return nil }, Fields: graphql.Fields{}, @@ -157,21 +137,10 @@ func GQLInterfaceLockable() *graphql.Interface { if ok == false { return nil } - ctx.Log.Logf("gql", "LOCKABLE_RESOLVE: %+v", p.Value) - valid_lockables, ok := p.Context.Value("valid_lockables").(map[reflect.Type]*graphql.Object) - if ok == false { - ctx.Log.Logf("gql", "Failed to get valid_lockables from Context") - return nil - } - - lockable_type, ok := p.Context.Value("lockable_type").(*reflect.Type) - if ok == false { - ctx.Log.Logf("gql", "Failed to get lockable_type from Context: %+v", p.Context.Value("lockable_type")) - return nil - } + valid_lockables := ctx.GQL.ValidLockables + lockable_type := ctx.GQL.LockableType p_type := reflect.TypeOf(p.Value) - ctx.Log.Logf("gql", "Value Type: %+v, Lockable Type: %+v", p_type, *lockable_type) for key, value := range(valid_lockables) { if p_type == key { @@ -179,8 +148,7 @@ func GQLInterfaceLockable() *graphql.Interface { } } - if p_type.Implements(*lockable_type) { - ctx.Log.Logf("gql", "LOCKABLE_RESOLVE_DEFAULT") + if p_type.Implements(lockable_type) { return GQLTypeBaseLockable() } return nil diff --git a/gql_test.go b/gql_test.go index 9e9ce78..a06d6a3 100644 --- a/gql_test.go +++ b/gql_test.go @@ -7,7 +7,7 @@ import ( func TestGQLThread(t * testing.T) { ctx := testContext(t) - gql_thread, err := NewGQLThread(ctx, ":8080", []Lockable{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) + gql_thread, err := NewGQLThread(ctx, ":8080", []Lockable{}) fatalErr(t, err) test_thread_1, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) diff --git a/graph.go b/graph.go index 0a615cf..f4319ba 100644 --- a/graph.go +++ b/graph.go @@ -2,7 +2,9 @@ package graphvent import ( "sync" + "reflect" "github.com/google/uuid" + "github.com/graphql-go/graphql" "os" "github.com/rs/zerolog" "fmt" @@ -10,6 +12,14 @@ import ( "encoding/json" ) +// For persistance, each node needs the following functions(* is a placeholder for the node/state type): +// Load*State - StateLoadFunc that returns the NodeState interface to attach to the node +// Load* - NodeLoadFunc that returns the GraphNode restored from it's loaded state + +// For convenience, the following functions are a good idea to define for composability: +// Restore*State - takes in the nodes serialized data to allow for easier nesting of inherited Load*State functions +// Save*State - serialize the node into it's json counterpart to be included as part of a larger json + type StateLoadFunc func(*GraphContext, []byte, NodeMap)(NodeState, error) type StateLoadMap map[string]StateLoadFunc type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error) @@ -19,6 +29,115 @@ type GraphContext struct { Log Logger NodeLoadFuncs NodeLoadMap StateLoadFuncs StateLoadMap + GQL * GQLContext +} + +type GQLContext struct { + Schema graphql.Schema + ValidNodes ObjTypeMap + NodeType reflect.Type + ValidLockables ObjTypeMap + LockableType reflect.Type + ValidThreads ObjTypeMap + ThreadType reflect.Type +} + +func NewGQLContext(additional_types TypeList, extended_types ObjTypeMap, extended_queries FieldMap, extended_mutations FieldMap, extended_subscriptions FieldMap) (*GQLContext, error) { + type_list := TypeList{ + GQLTypeSignalInput(), + } + + for _, gql_type := range(additional_types) { + type_list = append(type_list, gql_type) + } + + type_map := ObjTypeMap{} + type_map[reflect.TypeOf((*BaseLockable)(nil))] = GQLTypeBaseLockable() + type_map[reflect.TypeOf((*BaseThread)(nil))] = GQLTypeBaseThread() + type_map[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread() + type_map[reflect.TypeOf((*BaseSignal)(nil))] = GQLTypeSignal() + + for go_t, gql_t := range(extended_types) { + type_map[go_t] = gql_t + } + + valid_nodes := ObjTypeMap{} + valid_lockables := ObjTypeMap{} + valid_threads := ObjTypeMap{} + + node_type := reflect.TypeOf((*GraphNode)(nil)).Elem() + lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem() + thread_type := reflect.TypeOf((*Thread)(nil)).Elem() + + for go_t, gql_t := range(type_map) { + if go_t.Implements(node_type) { + valid_nodes[go_t] = gql_t + } + if go_t.Implements(lockable_type) { + valid_lockables[go_t] = gql_t + } + if go_t.Implements(thread_type) { + valid_threads[go_t] = gql_t + } + type_list = append(type_list, gql_t) + } + + queries := graphql.Fields{ + "Self": GQLQuerySelf(), + } + + for key, val := range(extended_queries) { + queries[key] = val + } + + mutations := graphql.Fields{ + "SendUpdate": GQLMutationSendUpdate(), + } + + for key, val := range(extended_mutations) { + mutations[key] = val + } + + subscriptions := graphql.Fields{ + "Update": GQLSubscriptionUpdate(), + } + + for key, val := range(extended_subscriptions) { + subscriptions[key] = val + } + + schemaConfig := graphql.SchemaConfig{ + Types: type_list, + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: queries, + }), + Mutation: graphql.NewObject(graphql.ObjectConfig{ + Name: "Mutation", + Fields: mutations, + }), + Subscription: graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: subscriptions, + }), + } + + schema, err := graphql.NewSchema(schemaConfig) + if err != nil{ + return nil, err + } + + ctx := GQLContext{ + Schema: schema, + ValidNodes: valid_nodes, + NodeType: node_type, + ValidThreads: valid_threads, + ThreadType: thread_type, + ValidLockables: valid_lockables, + LockableType: lockable_type, + } + + return &ctx, nil } func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { @@ -51,7 +170,7 @@ func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]Grap node_fn, exists := ctx.NodeLoadFuncs[base.Type] if exists == false { - return nil, fmt.Errorf("%s is not a known node type", base.Type) + return nil, fmt.Errorf("%s is not a known node type: %s", base.Type, state_bytes) } node, err = node_fn(ctx, id) @@ -77,21 +196,27 @@ func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]Grap } func NewGraphContext(db * badger.DB, log Logger) * GraphContext { + gql, err := NewGQLContext(TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) + if err != nil { + panic(err) + } + ctx := GraphContext{ + GQL: gql, DB: db, Log: log, NodeLoadFuncs: NodeLoadMap{ "base_lockable": LoadBaseLockable, "base_thread": LoadBaseThread, + "gql_thread": LoadGQLThread, }, StateLoadFuncs: StateLoadMap{ "base_lockable": LoadBaseLockableState, "base_thread": LoadBaseThreadState, + "gql_thread": LoadGQLThreadState, }, } - - return &ctx } diff --git a/lockable.go b/lockable.go index 90d310b..0af1797 100644 --- a/lockable.go +++ b/lockable.go @@ -538,85 +538,16 @@ func NewBaseLockable(ctx * GraphContext, state LockableState) (BaseLockable, err return lockable, nil } -func LoadBaseThread(ctx * GraphContext, id NodeID) (GraphNode, error) { +func RestoreBaseLockable(ctx * GraphContext, id NodeID) BaseLockable { base_node := RestoreNode(ctx, id) - thread := BaseThread{ - BaseLockable: BaseLockable{ - BaseNode: base_node, - }, - } - - return &thread, nil -} - -func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) { - lockable_state, err := RestoreBaseLockableState(ctx, j.LockableState, loaded_nodes) - if err != nil { - return nil, err - } - lockable_state._type = "thread_state" - - state := BaseThreadState{ - BaseLockableState: *lockable_state, - parent: nil, - children: make([]Thread, len(j.Children)), - child_info: map[NodeID]ThreadInfo{}, - InfoType: nil, - running: false, - } - - if j.Parent != nil { - p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes) - if err != nil { - return nil, err - } - p_t, ok := p.(Thread) - if ok == false { - return nil, err - } - state.owner = p_t - } - - i := 0 - for id, info := range(j.Children) { - child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes) - if err != nil { - return nil, err - } - child_t, ok := child_node.(Thread) - if ok == false { - return nil, fmt.Errorf("%+v is not a Thread as expected", child_node) - } - state.children[i] = child_t - state.child_info[id] = info - i++ - } - - return &state, nil -} - -func LoadBaseThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){ - var j BaseThreadStateJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - state, err := RestoreBaseThreadState(ctx, j, loaded_nodes) - if err != nil { - return nil, err + return BaseLockable{ + BaseNode: base_node, } - - return state, nil } func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { // call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id - base_node := RestoreNode(ctx, id) - lockable := BaseLockable{ - BaseNode: base_node, - } - + lockable := RestoreBaseLockable(ctx, id) return &lockable, nil } diff --git a/lockable_test.go b/lockable_test.go index 00d2768..cb7e636 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -362,18 +362,13 @@ func TestLockableDependencyOverlap(t * testing.T) { } func TestLockableDBLoad(t * testing.T){ - ctx := logTestContext(t, []string{"db"}) + ctx := logTestContext(t, []string{}) l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) fatalErr(t, err) l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{}) fatalErr(t, err) l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error { - ser, err := json.MarshalIndent(states[l3.ID()], "", " ") - fmt.Printf("\n%s\n\n", ser) - return err - }) l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3}) fatalErr(t, err) _, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4}) @@ -391,6 +386,13 @@ func TestLockableDBLoad(t * testing.T){ return err }) - _, err = LoadNode(ctx, l3.ID()) + l3_loaded, err := LoadNode(ctx, l3.ID()) fatalErr(t, err) + + // TODO: add more equivalence checks + err = UseStates(ctx, []GraphNode{l3_loaded}, func(states NodeStateMap) error { + ser, err := json.MarshalIndent(states[l3_loaded.ID()], "", " ") + fmt.Printf("\n%s\n\n", ser) + return err + }) } diff --git a/thread.go b/thread.go index 06181ff..c5c2108 100644 --- a/thread.go +++ b/thread.go @@ -72,7 +72,7 @@ type BaseThreadState struct { type BaseThreadStateJSON struct { Parent *NodeID `json:"parent"` Children map[NodeID]interface{} `json:"children"` - LockableState BaseLockableStateJSON `json:"lockable"` + BaseLockableStateJSON } func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON { @@ -92,10 +92,85 @@ func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON { return BaseThreadStateJSON{ Parent: parent_id, Children: children, - LockableState: lockable_state, + BaseLockableStateJSON: lockable_state, } } +func RestoreBaseThread(ctx * GraphContext, id NodeID) BaseThread { + base_lockable := RestoreBaseLockable(ctx, id) + thread := BaseThread{ + BaseLockable: base_lockable, + } + + return thread +} + +func LoadBaseThread(ctx * GraphContext, id NodeID) (GraphNode, error) { + thread := RestoreBaseThread(ctx, id) + return &thread, nil +} + +func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) { + lockable_state, err := RestoreBaseLockableState(ctx, j.BaseLockableStateJSON, loaded_nodes) + if err != nil { + return nil, err + } + lockable_state._type = "thread_state" + + state := BaseThreadState{ + BaseLockableState: *lockable_state, + parent: nil, + children: make([]Thread, len(j.Children)), + child_info: map[NodeID]ThreadInfo{}, + InfoType: nil, + running: false, + } + + if j.Parent != nil { + p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes) + if err != nil { + return nil, err + } + p_t, ok := p.(Thread) + if ok == false { + return nil, err + } + state.owner = p_t + } + + i := 0 + for id, info := range(j.Children) { + child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes) + if err != nil { + return nil, err + } + child_t, ok := child_node.(Thread) + if ok == false { + return nil, fmt.Errorf("%+v is not a Thread as expected", child_node) + } + state.children[i] = child_t + state.child_info[id] = info + i++ + } + + return &state, nil +} + +func LoadBaseThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){ + var j BaseThreadStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + state, err := RestoreBaseThreadState(ctx, j, loaded_nodes) + if err != nil { + return nil, err + } + + return state, nil +} + func (state * BaseThreadState) MarshalJSON() ([]byte, error) { thread_state := SaveBaseThreadState(state) return json.Marshal(&thread_state) diff --git a/thread_test.go b/thread_test.go index 5b101e3..3d8a79b 100644 --- a/thread_test.go +++ b/thread_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" "fmt" + "encoding/json" ) func TestNewThread(t * testing.T) { @@ -56,3 +57,32 @@ func TestThreadWithRequirement(t * testing.T) { }) fatalErr(t, err) } + +func TestThreadDBLoad(t * testing.T) { + ctx := logTestContext(t, []string{}) + l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + + t1, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{l1}, ThreadActions{}, ThreadHandlers{}) + fatalErr(t, err) + + + SendUpdate(ctx, t1, CancelSignal(nil)) + err = RunThread(ctx, t1) + fatalErr(t, err) + + err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error { + ser, err := json.MarshalIndent(states[t1.ID()], "", " ") + fmt.Printf("\n%s\n\n", ser) + return err + }) + + t1_loaded, err := LoadNode(ctx, t1.ID()) + fatalErr(t, err) + + err = UseStates(ctx, []GraphNode{t1_loaded}, func(states NodeStateMap) error { + ser, err := json.MarshalIndent(states[t1_loaded.ID()], "", " ") + fmt.Printf("\n%s\n\n", ser) + return err + }) +}