Moved GQL context information out of node runtime state and into context

graph-rework-2
noah metz 2023-07-01 13:03:28 -06:00
parent 04771b7816
commit d2b32bac5e
8 changed files with 324 additions and 238 deletions

175
gql.go

@ -112,7 +112,11 @@ func enableCORS(w *http.ResponseWriter) {
(*w).Header().Set("Access-Control-Allow-Methods", "*") (*w).Header().Set("Access-Control-Allow-Methods", "*")
} }
func GQLHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Context) func(http.ResponseWriter, *http.Request) { func GQLHandler(ctx * GraphContext, server * GQLThread) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
gql_ctx = context.WithValue(gql_ctx, "gql_server", server)
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr) ctx.Log.Logf("gql", "GQL REQUEST: %s", r.RemoteAddr)
enableCORS(&w) enableCORS(&w)
@ -131,7 +135,7 @@ func GQLHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Conte
json.Unmarshal(str, &query) json.Unmarshal(str, &query)
params := graphql.Params{ params := graphql.Params{
Schema: schema, Schema: ctx.GQL.Schema,
Context: gql_ctx, Context: gql_ctx,
RequestString: query.Query, RequestString: query.Query,
} }
@ -199,7 +203,11 @@ func GQLWSDo(ctx * GraphContext, p graphql.Params) chan *graphql.Result {
return sendOneResultAndClose(res) return sendOneResultAndClose(res)
} }
func GQLWSHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Context) func(http.ResponseWriter, *http.Request) { func GQLWSHandler(ctx * GraphContext, server * GQLThread) func(http.ResponseWriter, *http.Request) {
gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
gql_ctx = context.WithValue(gql_ctx, "gql_server", server)
return func(w http.ResponseWriter, r * http.Request) { return func(w http.ResponseWriter, r * http.Request) {
ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr) ctx.Log.Logf("gqlws_new", "HANDLING %s",r.RemoteAddr)
enableCORS(&w) enableCORS(&w)
@ -251,7 +259,7 @@ func GQLWSHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Con
} else if msg.Type == "subscribe" { } else if msg.Type == "subscribe" {
ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload) ctx.Log.Logf("gqlws", "SUBSCRIBE: %+v", msg.Payload)
params := graphql.Params{ params := graphql.Params{
Schema: schema, Schema: ctx.GQL.Schema,
Context: gql_ctx, Context: gql_ctx,
RequestString: msg.Payload.Query, RequestString: msg.Payload.Query,
} }
@ -316,6 +324,7 @@ func GQLWSHandler(ctx * GraphContext, schema graphql.Schema, gql_ctx context.Con
} }
} }
type TypeList []graphql.Type
type ObjTypeMap map[reflect.Type]*graphql.Object type ObjTypeMap map[reflect.Type]*graphql.Object
type FieldMap map[string]*graphql.Field type FieldMap map[string]*graphql.Field
@ -323,10 +332,6 @@ type GQLThread struct {
BaseThread BaseThread
http_server *http.Server http_server *http.Server
http_done *sync.WaitGroup http_done *sync.WaitGroup
extended_types ObjTypeMap
extended_queries FieldMap
extended_subscriptions FieldMap
extended_mutations FieldMap
} }
type GQLThreadInfo struct { type GQLThreadInfo struct {
@ -343,11 +348,55 @@ func NewGQLThreadInfo(start bool) GQLThreadInfo {
return info return info
} }
type GQLThreadStateJSON struct {
BaseThreadStateJSON
Listen string
}
type GQLThreadState struct { type GQLThreadState struct {
BaseThreadState BaseThreadState
Listen string Listen string
} }
func (state * GQLThreadState) MarshalJSON() ([]byte, error) {
thread_state := SaveBaseThreadState(&state.BaseThreadState)
return json.Marshal(&GQLThreadStateJSON{
BaseThreadStateJSON: thread_state,
Listen: state.Listen,
})
}
func LoadGQLThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap) (NodeState, error){
var j GQLThreadStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
thread_state, err := RestoreBaseThreadState(ctx, j.BaseThreadStateJSON, loaded_nodes)
if err != nil {
return nil, err
}
state := &GQLThreadState{
BaseThreadState: *thread_state,
Listen: j.Listen,
}
return state, nil
}
func LoadGQLThread(ctx * GraphContext, id NodeID) (GraphNode, error) {
thread := RestoreBaseThread(ctx, id)
gql_thread := GQLThread{
BaseThread: thread,
http_server: nil,
http_done: &sync.WaitGroup{},
}
return &gql_thread, nil
}
func NewGQLThreadState(listen string) GQLThreadState { func NewGQLThreadState(listen string) GQLThreadState {
state := GQLThreadState{ state := GQLThreadState{
BaseThreadState: NewBaseThreadState("GQL Server", "gql_thread"), BaseThreadState: NewBaseThreadState("GQL Server", "gql_thread"),
@ -362,11 +411,15 @@ var gql_actions ThreadActions = ThreadActions{
ctx.Log.Logf("gql", "SERVER_STARTED") ctx.Log.Logf("gql", "SERVER_STARTED")
server := thread.(*GQLThread) server := thread.(*GQLThread)
// Serve the GQL http and ws handlers
mux := http.NewServeMux() mux := http.NewServeMux()
http_handler, ws_handler := MakeGQLHandlers(ctx, server) mux.HandleFunc("/gql", GQLHandler(ctx, server))
mux.HandleFunc("/gql", http_handler) mux.HandleFunc("/gqlws", GQLWSHandler(ctx, server))
mux.HandleFunc("/gqlws", ws_handler)
// Server a graphiql interface(TODO make configurable whether to start this)
mux.HandleFunc("/graphiql", GraphiQLHandler()) mux.HandleFunc("/graphiql", GraphiQLHandler())
// Server the ./site directory to /site (TODO make configurable with better defaults)
fs := http.FileServer(http.Dir("./site")) fs := http.FileServer(http.Dir("./site"))
mux.Handle("/site/", http.StripPrefix("/site", fs)) mux.Handle("/site/", http.StripPrefix("/site", fs))
@ -426,7 +479,7 @@ var gql_handlers ThreadHandlers = ThreadHandlers{
}, },
} }
func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, extended_types ObjTypeMap, extended_queries FieldMap, extended_mutations FieldMap, extended_subscriptions FieldMap) (*GQLThread, error) { func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable) (*GQLThread, error) {
state := NewGQLThreadState(listen) state := NewGQLThreadState(listen)
base_thread, err := NewBaseThread(ctx, gql_actions, gql_handlers, &state) base_thread, err := NewBaseThread(ctx, gql_actions, gql_handlers, &state)
if err != nil { if err != nil {
@ -437,10 +490,6 @@ func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, ex
BaseThread: base_thread, BaseThread: base_thread,
http_server: nil, http_server: nil,
http_done: &sync.WaitGroup{}, http_done: &sync.WaitGroup{},
extended_types: extended_types,
extended_queries: extended_queries,
extended_mutations: extended_mutations,
extended_subscriptions: extended_subscriptions,
} }
err = LinkLockables(ctx, thread, requirements) err = LinkLockables(ctx, thread, requirements)
@ -449,97 +498,3 @@ func NewGQLThread(ctx * GraphContext, listen string, requirements []Lockable, ex
} }
return thread, nil return thread, nil
} }
func MakeGQLHandlers(ctx * GraphContext, server * GQLThread) (func(http.ResponseWriter, *http.Request), func(http.ResponseWriter, *http.Request)) {
valid_nodes := map[reflect.Type]*graphql.Object{}
valid_lockables := map[reflect.Type]*graphql.Object{}
valid_threads := map[reflect.Type]*graphql.Object{}
valid_lockables[reflect.TypeOf((*BaseLockable)(nil))] = GQLTypeBaseLockable()
for t, v := range(valid_lockables) {
valid_nodes[t] = v
}
valid_threads[reflect.TypeOf((*BaseThread)(nil))] = GQLTypeBaseThread()
valid_threads[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread()
for t, v := range(valid_threads) {
valid_lockables[t] = v
valid_nodes[t] = v
}
gql_types := []graphql.Type{GQLTypeSignal(), GQLTypeSignalInput()}
for _, v := range(valid_nodes) {
gql_types = append(gql_types, v)
}
node_type := reflect.TypeOf((*GraphNode)(nil)).Elem()
lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem()
thread_type := reflect.TypeOf((*Thread)(nil)).Elem()
for go_t, gql_t := range(server.extended_types) {
if go_t.Implements(node_type) {
valid_nodes[go_t] = gql_t
}
if go_t.Implements(lockable_type) {
valid_lockables[go_t] = gql_t
}
if go_t.Implements(thread_type) {
valid_threads[go_t] = gql_t
}
gql_types = append(gql_types, gql_t)
}
gql_queries := graphql.Fields{
"Self": GQLQuerySelf(),
}
for key, value := range(server.extended_queries) {
gql_queries[key] = value
}
gql_subscriptions := graphql.Fields{
"Update": GQLSubscriptionUpdate(),
}
for key, value := range(server.extended_subscriptions) {
gql_subscriptions[key] = value
}
gql_mutations := graphql.Fields{
"SendUpdate": GQLMutationSendUpdate(),
}
for key, value := range(server.extended_mutations) {
gql_mutations[key] = value
}
schemaConfig := graphql.SchemaConfig{
Types: gql_types,
Query: graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: gql_queries,
}),
Mutation: graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: gql_mutations,
}),
Subscription: graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: gql_subscriptions,
}),
}
schema, err := graphql.NewSchema(schemaConfig)
if err != nil{
panic(err)
}
gql_ctx := context.Background()
gql_ctx = context.WithValue(gql_ctx, "valid_nodes", valid_nodes)
gql_ctx = context.WithValue(gql_ctx, "node_type", &node_type)
gql_ctx = context.WithValue(gql_ctx, "valid_lockables", valid_lockables)
gql_ctx = context.WithValue(gql_ctx, "lockable_type", &lockable_type)
gql_ctx = context.WithValue(gql_ctx, "valid_threads", valid_threads)
gql_ctx = context.WithValue(gql_ctx, "thread_type", &thread_type)
gql_ctx = context.WithValue(gql_ctx, "gql_server", server)
gql_ctx = context.WithValue(gql_ctx, "graph_context", ctx)
return GQLHandler(ctx, schema, gql_ctx), GQLWSHandler(ctx, schema, gql_ctx)
}

@ -16,18 +16,9 @@ func GQLInterfaceGraphNode() *graphql.Interface {
if ok == false { if ok == false {
return nil return nil
} }
valid_nodes, ok := p.Context.Value("valid_nodes").(map[reflect.Type]*graphql.Object)
if ok == false {
ctx.Log.Logf("gql", "Failed to get valid_nodes from Context")
return nil
}
node_type, ok := p.Context.Value("node_type").(*reflect.Type)
if ok == false {
ctx.Log.Logf("gql", "Failed to get node_type from Context: %+v", p.Context.Value("node_type"))
return nil
}
valid_nodes := ctx.GQL.ValidNodes
node_type := ctx.GQL.NodeType
p_type := reflect.TypeOf(p.Value) p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_nodes) { for key, value := range(valid_nodes) {
@ -36,7 +27,7 @@ func GQLInterfaceGraphNode() *graphql.Interface {
} }
} }
if p_type.Implements(*node_type) { if p_type.Implements(node_type) {
return GQLTypeBaseNode() return GQLTypeBaseNode()
} }
@ -75,33 +66,22 @@ func GQLInterfaceThread() *graphql.Interface {
if ok == false { if ok == false {
return nil return nil
} }
valid_threads, ok := p.Context.Value("valid_threads").(map[reflect.Type]*graphql.Object)
if ok == false {
ctx.Log.Logf("gql", "Failed to get valid_threads from Context")
return nil
}
thread_type, ok := p.Context.Value("thread_type").(*reflect.Type)
if ok == false {
ctx.Log.Logf("gql", "Failed to get thread_type from Context: %+v", p.Context.Value("thread_type"))
return nil
}
valid_threads := ctx.GQL.ValidThreads
thread_type := ctx.GQL.ThreadType
p_type := reflect.TypeOf(p.Value) p_type := reflect.TypeOf(p.Value)
for key, value := range(valid_threads) { for key, value := range(valid_threads) {
if p_type == key { if p_type == key {
return value return value
} }
} }
if p_type.Implements(*thread_type) { if p_type.Implements(thread_type) {
return GQLTypeBaseThread() return GQLTypeBaseThread()
} }
ctx.Log.Logf("gql", "Found no type that matches %+v: %+v", p_type, p_type.Implements(*thread_type)) ctx.Log.Logf("gql", "Found no type that matches %+v: %+v", p_type, p_type.Implements(thread_type))
return nil return nil
}, },
Fields: graphql.Fields{}, Fields: graphql.Fields{},
@ -157,21 +137,10 @@ func GQLInterfaceLockable() *graphql.Interface {
if ok == false { if ok == false {
return nil return nil
} }
ctx.Log.Logf("gql", "LOCKABLE_RESOLVE: %+v", p.Value)
valid_lockables, ok := p.Context.Value("valid_lockables").(map[reflect.Type]*graphql.Object)
if ok == false {
ctx.Log.Logf("gql", "Failed to get valid_lockables from Context")
return nil
}
lockable_type, ok := p.Context.Value("lockable_type").(*reflect.Type)
if ok == false {
ctx.Log.Logf("gql", "Failed to get lockable_type from Context: %+v", p.Context.Value("lockable_type"))
return nil
}
valid_lockables := ctx.GQL.ValidLockables
lockable_type := ctx.GQL.LockableType
p_type := reflect.TypeOf(p.Value) p_type := reflect.TypeOf(p.Value)
ctx.Log.Logf("gql", "Value Type: %+v, Lockable Type: %+v", p_type, *lockable_type)
for key, value := range(valid_lockables) { for key, value := range(valid_lockables) {
if p_type == key { if p_type == key {
@ -179,8 +148,7 @@ func GQLInterfaceLockable() *graphql.Interface {
} }
} }
if p_type.Implements(*lockable_type) { if p_type.Implements(lockable_type) {
ctx.Log.Logf("gql", "LOCKABLE_RESOLVE_DEFAULT")
return GQLTypeBaseLockable() return GQLTypeBaseLockable()
} }
return nil return nil

@ -7,7 +7,7 @@ import (
func TestGQLThread(t * testing.T) { func TestGQLThread(t * testing.T) {
ctx := testContext(t) ctx := testContext(t)
gql_thread, err := NewGQLThread(ctx, ":8080", []Lockable{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{}) gql_thread, err := NewGQLThread(ctx, ":8080", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
test_thread_1, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) test_thread_1, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{})

@ -2,7 +2,9 @@ package graphvent
import ( import (
"sync" "sync"
"reflect"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/graphql-go/graphql"
"os" "os"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"fmt" "fmt"
@ -10,6 +12,14 @@ import (
"encoding/json" "encoding/json"
) )
// For persistance, each node needs the following functions(* is a placeholder for the node/state type):
// Load*State - StateLoadFunc that returns the NodeState interface to attach to the node
// Load* - NodeLoadFunc that returns the GraphNode restored from it's loaded state
// For convenience, the following functions are a good idea to define for composability:
// Restore*State - takes in the nodes serialized data to allow for easier nesting of inherited Load*State functions
// Save*State - serialize the node into it's json counterpart to be included as part of a larger json
type StateLoadFunc func(*GraphContext, []byte, NodeMap)(NodeState, error) type StateLoadFunc func(*GraphContext, []byte, NodeMap)(NodeState, error)
type StateLoadMap map[string]StateLoadFunc type StateLoadMap map[string]StateLoadFunc
type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error) type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error)
@ -19,6 +29,115 @@ type GraphContext struct {
Log Logger Log Logger
NodeLoadFuncs NodeLoadMap NodeLoadFuncs NodeLoadMap
StateLoadFuncs StateLoadMap StateLoadFuncs StateLoadMap
GQL * GQLContext
}
type GQLContext struct {
Schema graphql.Schema
ValidNodes ObjTypeMap
NodeType reflect.Type
ValidLockables ObjTypeMap
LockableType reflect.Type
ValidThreads ObjTypeMap
ThreadType reflect.Type
}
func NewGQLContext(additional_types TypeList, extended_types ObjTypeMap, extended_queries FieldMap, extended_mutations FieldMap, extended_subscriptions FieldMap) (*GQLContext, error) {
type_list := TypeList{
GQLTypeSignalInput(),
}
for _, gql_type := range(additional_types) {
type_list = append(type_list, gql_type)
}
type_map := ObjTypeMap{}
type_map[reflect.TypeOf((*BaseLockable)(nil))] = GQLTypeBaseLockable()
type_map[reflect.TypeOf((*BaseThread)(nil))] = GQLTypeBaseThread()
type_map[reflect.TypeOf((*GQLThread)(nil))] = GQLTypeGQLThread()
type_map[reflect.TypeOf((*BaseSignal)(nil))] = GQLTypeSignal()
for go_t, gql_t := range(extended_types) {
type_map[go_t] = gql_t
}
valid_nodes := ObjTypeMap{}
valid_lockables := ObjTypeMap{}
valid_threads := ObjTypeMap{}
node_type := reflect.TypeOf((*GraphNode)(nil)).Elem()
lockable_type := reflect.TypeOf((*Lockable)(nil)).Elem()
thread_type := reflect.TypeOf((*Thread)(nil)).Elem()
for go_t, gql_t := range(type_map) {
if go_t.Implements(node_type) {
valid_nodes[go_t] = gql_t
}
if go_t.Implements(lockable_type) {
valid_lockables[go_t] = gql_t
}
if go_t.Implements(thread_type) {
valid_threads[go_t] = gql_t
}
type_list = append(type_list, gql_t)
}
queries := graphql.Fields{
"Self": GQLQuerySelf(),
}
for key, val := range(extended_queries) {
queries[key] = val
}
mutations := graphql.Fields{
"SendUpdate": GQLMutationSendUpdate(),
}
for key, val := range(extended_mutations) {
mutations[key] = val
}
subscriptions := graphql.Fields{
"Update": GQLSubscriptionUpdate(),
}
for key, val := range(extended_subscriptions) {
subscriptions[key] = val
}
schemaConfig := graphql.SchemaConfig{
Types: type_list,
Query: graphql.NewObject(graphql.ObjectConfig{
Name: "Query",
Fields: queries,
}),
Mutation: graphql.NewObject(graphql.ObjectConfig{
Name: "Mutation",
Fields: mutations,
}),
Subscription: graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: subscriptions,
}),
}
schema, err := graphql.NewSchema(schemaConfig)
if err != nil{
return nil, err
}
ctx := GQLContext{
Schema: schema,
ValidNodes: valid_nodes,
NodeType: node_type,
ValidThreads: valid_threads,
ThreadType: thread_type,
ValidLockables: valid_lockables,
LockableType: lockable_type,
}
return &ctx, nil
} }
func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) {
@ -51,7 +170,7 @@ func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]Grap
node_fn, exists := ctx.NodeLoadFuncs[base.Type] node_fn, exists := ctx.NodeLoadFuncs[base.Type]
if exists == false { if exists == false {
return nil, fmt.Errorf("%s is not a known node type", base.Type) return nil, fmt.Errorf("%s is not a known node type: %s", base.Type, state_bytes)
} }
node, err = node_fn(ctx, id) node, err = node_fn(ctx, id)
@ -77,21 +196,27 @@ func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]Grap
} }
func NewGraphContext(db * badger.DB, log Logger) * GraphContext { func NewGraphContext(db * badger.DB, log Logger) * GraphContext {
gql, err := NewGQLContext(TypeList{}, ObjTypeMap{}, FieldMap{}, FieldMap{}, FieldMap{})
if err != nil {
panic(err)
}
ctx := GraphContext{ ctx := GraphContext{
GQL: gql,
DB: db, DB: db,
Log: log, Log: log,
NodeLoadFuncs: NodeLoadMap{ NodeLoadFuncs: NodeLoadMap{
"base_lockable": LoadBaseLockable, "base_lockable": LoadBaseLockable,
"base_thread": LoadBaseThread, "base_thread": LoadBaseThread,
"gql_thread": LoadGQLThread,
}, },
StateLoadFuncs: StateLoadMap{ StateLoadFuncs: StateLoadMap{
"base_lockable": LoadBaseLockableState, "base_lockable": LoadBaseLockableState,
"base_thread": LoadBaseThreadState, "base_thread": LoadBaseThreadState,
"gql_thread": LoadGQLThreadState,
}, },
} }
return &ctx return &ctx
} }

@ -538,85 +538,16 @@ func NewBaseLockable(ctx * GraphContext, state LockableState) (BaseLockable, err
return lockable, nil return lockable, nil
} }
func LoadBaseThread(ctx * GraphContext, id NodeID) (GraphNode, error) { func RestoreBaseLockable(ctx * GraphContext, id NodeID) BaseLockable {
base_node := RestoreNode(ctx, id) base_node := RestoreNode(ctx, id)
thread := BaseThread{ return BaseLockable{
BaseLockable: BaseLockable{ BaseNode: base_node,
BaseNode: base_node,
},
}
return &thread, nil
}
func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) {
lockable_state, err := RestoreBaseLockableState(ctx, j.LockableState, loaded_nodes)
if err != nil {
return nil, err
}
lockable_state._type = "thread_state"
state := BaseThreadState{
BaseLockableState: *lockable_state,
parent: nil,
children: make([]Thread, len(j.Children)),
child_info: map[NodeID]ThreadInfo{},
InfoType: nil,
running: false,
}
if j.Parent != nil {
p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes)
if err != nil {
return nil, err
}
p_t, ok := p.(Thread)
if ok == false {
return nil, err
}
state.owner = p_t
}
i := 0
for id, info := range(j.Children) {
child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes)
if err != nil {
return nil, err
}
child_t, ok := child_node.(Thread)
if ok == false {
return nil, fmt.Errorf("%+v is not a Thread as expected", child_node)
}
state.children[i] = child_t
state.child_info[id] = info
i++
}
return &state, nil
}
func LoadBaseThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){
var j BaseThreadStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
state, err := RestoreBaseThreadState(ctx, j, loaded_nodes)
if err != nil {
return nil, err
} }
return state, nil
} }
func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) {
// call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id // call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id
base_node := RestoreNode(ctx, id) lockable := RestoreBaseLockable(ctx, id)
lockable := BaseLockable{
BaseNode: base_node,
}
return &lockable, nil return &lockable, nil
} }

@ -362,18 +362,13 @@ func TestLockableDependencyOverlap(t * testing.T) {
} }
func TestLockableDBLoad(t * testing.T){ func TestLockableDBLoad(t * testing.T){
ctx := logTestContext(t, []string{"db"}) ctx := logTestContext(t, []string{})
l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{}) l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{})
fatalErr(t, err) fatalErr(t, err)
l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2})
fatalErr(t, err) fatalErr(t, err)
err = UseStates(ctx, []GraphNode{l3}, func(states NodeStateMap) error {
ser, err := json.MarshalIndent(states[l3.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser)
return err
})
l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3}) l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3})
fatalErr(t, err) fatalErr(t, err)
_, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4}) _, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4})
@ -391,6 +386,13 @@ func TestLockableDBLoad(t * testing.T){
return err return err
}) })
_, err = LoadNode(ctx, l3.ID()) l3_loaded, err := LoadNode(ctx, l3.ID())
fatalErr(t, err) fatalErr(t, err)
// TODO: add more equivalence checks
err = UseStates(ctx, []GraphNode{l3_loaded}, func(states NodeStateMap) error {
ser, err := json.MarshalIndent(states[l3_loaded.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser)
return err
})
} }

@ -72,7 +72,7 @@ type BaseThreadState struct {
type BaseThreadStateJSON struct { type BaseThreadStateJSON struct {
Parent *NodeID `json:"parent"` Parent *NodeID `json:"parent"`
Children map[NodeID]interface{} `json:"children"` Children map[NodeID]interface{} `json:"children"`
LockableState BaseLockableStateJSON `json:"lockable"` BaseLockableStateJSON
} }
func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON { func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON {
@ -92,10 +92,85 @@ func SaveBaseThreadState(state * BaseThreadState) BaseThreadStateJSON {
return BaseThreadStateJSON{ return BaseThreadStateJSON{
Parent: parent_id, Parent: parent_id,
Children: children, Children: children,
LockableState: lockable_state, BaseLockableStateJSON: lockable_state,
} }
} }
func RestoreBaseThread(ctx * GraphContext, id NodeID) BaseThread {
base_lockable := RestoreBaseLockable(ctx, id)
thread := BaseThread{
BaseLockable: base_lockable,
}
return thread
}
func LoadBaseThread(ctx * GraphContext, id NodeID) (GraphNode, error) {
thread := RestoreBaseThread(ctx, id)
return &thread, nil
}
func RestoreBaseThreadState(ctx * GraphContext, j BaseThreadStateJSON, loaded_nodes NodeMap) (*BaseThreadState, error) {
lockable_state, err := RestoreBaseLockableState(ctx, j.BaseLockableStateJSON, loaded_nodes)
if err != nil {
return nil, err
}
lockable_state._type = "thread_state"
state := BaseThreadState{
BaseLockableState: *lockable_state,
parent: nil,
children: make([]Thread, len(j.Children)),
child_info: map[NodeID]ThreadInfo{},
InfoType: nil,
running: false,
}
if j.Parent != nil {
p, err := LoadNodeRecurse(ctx, *j.Parent, loaded_nodes)
if err != nil {
return nil, err
}
p_t, ok := p.(Thread)
if ok == false {
return nil, err
}
state.owner = p_t
}
i := 0
for id, info := range(j.Children) {
child_node, err := LoadNodeRecurse(ctx, id, loaded_nodes)
if err != nil {
return nil, err
}
child_t, ok := child_node.(Thread)
if ok == false {
return nil, fmt.Errorf("%+v is not a Thread as expected", child_node)
}
state.children[i] = child_t
state.child_info[id] = info
i++
}
return &state, nil
}
func LoadBaseThreadState(ctx * GraphContext, data []byte, loaded_nodes NodeMap)(NodeState, error){
var j BaseThreadStateJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
state, err := RestoreBaseThreadState(ctx, j, loaded_nodes)
if err != nil {
return nil, err
}
return state, nil
}
func (state * BaseThreadState) MarshalJSON() ([]byte, error) { func (state * BaseThreadState) MarshalJSON() ([]byte, error) {
thread_state := SaveBaseThreadState(state) thread_state := SaveBaseThreadState(state)
return json.Marshal(&thread_state) return json.Marshal(&thread_state)

@ -4,6 +4,7 @@ import (
"testing" "testing"
"time" "time"
"fmt" "fmt"
"encoding/json"
) )
func TestNewThread(t * testing.T) { func TestNewThread(t * testing.T) {
@ -56,3 +57,32 @@ func TestThreadWithRequirement(t * testing.T) {
}) })
fatalErr(t, err) fatalErr(t, err)
} }
func TestThreadDBLoad(t * testing.T) {
ctx := logTestContext(t, []string{})
l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{})
fatalErr(t, err)
t1, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{l1}, ThreadActions{}, ThreadHandlers{})
fatalErr(t, err)
SendUpdate(ctx, t1, CancelSignal(nil))
err = RunThread(ctx, t1)
fatalErr(t, err)
err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) error {
ser, err := json.MarshalIndent(states[t1.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser)
return err
})
t1_loaded, err := LoadNode(ctx, t1.ID())
fatalErr(t, err)
err = UseStates(ctx, []GraphNode{t1_loaded}, func(states NodeStateMap) error {
ser, err := json.MarshalIndent(states[t1_loaded.ID()], "", " ")
fmt.Printf("\n%s\n\n", ser)
return err
})
}