Added LoadJSONNode to abstract out repeated json unmarshaling

graph-rework-2
noah metz 2023-07-25 00:19:39 -06:00
parent 59df9f04d5
commit a2395189a8
4 changed files with 56 additions and 92 deletions

@ -733,13 +733,7 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON {
} }
} }
func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error) {
var j GQLThreadJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
ecdh_curve, ok := ecdh_curves[j.ECDH] ecdh_curve, ok := ecdh_curves[j.ECDH]
if ok == false { if ok == false {
return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH) return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH)
@ -751,29 +745,24 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e
} }
thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey) thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey)
nodes[id] = &thread return &thread, nil
}, func(ctx *Context, thread *GQLThread, j GQLThreadJSON, nodes NodeMap) error {
thread.Users = map[NodeID]*User{} thread.Users = map[NodeID]*User{}
for _, id_str := range(j.Users) { for _, id_str := range(j.Users) {
ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str) ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str)
user_id, err := ParseID(id_str) user_id, err := ParseID(id_str)
if err != nil { if err != nil {
return nil, err return err
} }
user, err := LoadNodeRecurse(ctx, user_id, nodes) user, err := LoadNodeRecurse(ctx, user_id, nodes)
if err != nil { if err != nil {
return nil, err return err
} }
thread.Users[user_id] = user.(*User) thread.Users[user_id] = user.(*User)
} }
err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes) return RestoreThread(ctx, thread, j.ThreadJSON, nodes)
if err != nil { })
return nil, err
}
return &thread, nil
}
func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread { func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread {
if tls_cert == nil || tls_key == nil { if tls_cert == nil || tls_key == nil {

@ -33,23 +33,10 @@ func NewListener(id NodeID, name string) Listener {
} }
} }
func LoadListener(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { var LoadListener = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) {
var j LockableJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
listener := NewListener(id, j.Name) listener := NewListener(id, j.Name)
nodes[id] = &listener
err = RestoreLockable(ctx, &listener.Lockable, j, nodes)
if err != nil {
return nil, err
}
return &listener, nil return &listener, nil
} }, RestoreLockable)
type LockableNode interface { type LockableNode interface {
Node Node
@ -462,24 +449,10 @@ func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, o
}) })
} }
// Load function for Lockable var LoadLockable = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) {
func LoadLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j LockableJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
lockable := NewLockable(id, j.Name) lockable := NewLockable(id, j.Name)
nodes[id] = &lockable
err = RestoreLockable(ctx, &lockable, j, nodes)
if err != nil {
return nil, err
}
return &lockable, nil return &lockable, nil
} }, RestoreLockable)
func NewLockable(id NodeID, name string) Lockable { func NewLockable(id NodeID, name string) Lockable {
return Lockable{ return Lockable{
@ -493,7 +466,8 @@ func NewLockable(id NodeID, name string) Lockable {
} }
// Helper function to load links when loading a struct that embeds Lockable // Helper function to load links when loading a struct that embeds Lockable
func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes NodeMap) error { func RestoreLockable(ctx * Context, lockable LockableNode, j LockableJSON, nodes NodeMap) error {
lockable_ptr := lockable.LockableHandle()
if j.Owner != "" { if j.Owner != "" {
owner_id, err := ParseID(j.Owner) owner_id, err := ParseID(j.Owner)
if err != nil { if err != nil {
@ -507,7 +481,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No
if ok == false { if ok == false {
return fmt.Errorf("%s is not a Lockable", j.Owner) return fmt.Errorf("%s is not a Lockable", j.Owner)
} }
lockable.Owner = owner lockable_ptr.Owner = owner
} }
for _, dep_str := range(j.Dependencies) { for _, dep_str := range(j.Dependencies) {
@ -524,7 +498,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No
return fmt.Errorf("%+v is not a Lockable as expected", dep_node) return fmt.Errorf("%+v is not a Lockable as expected", dep_node)
} }
ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep)) ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep))
lockable.Dependencies[dep_id] = dep lockable_ptr.Dependencies[dep_id] = dep
} }
for _, req_str := range(j.Requirements) { for _, req_str := range(j.Requirements) {
@ -540,7 +514,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No
if ok == false { if ok == false {
return fmt.Errorf("%+v is not a Lockable as expected", req_node) return fmt.Errorf("%+v is not a Lockable as expected", req_node)
} }
lockable.Requirements[req_id] = req lockable_ptr.Requirements[req_id] = req
} }
for l_id_str, h_str := range(j.LocksHeld) { for l_id_str, h_str := range(j.LocksHeld) {
@ -570,8 +544,8 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No
} }
h_l = h h_l = h
} }
lockable.RecordLock(l_l, h_l) lockable_ptr.RecordLock(l_l, h_l)
} }
return RestoreSimpleNode(ctx, &lockable.SimpleNode, j.SimpleNodeJSON, nodes) return RestoreSimpleNode(ctx, lockable, j.SimpleNodeJSON, nodes)
} }

@ -137,7 +137,8 @@ func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON {
} }
} }
func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes NodeMap) error { func RestoreSimpleNode(ctx *Context, node Node, j SimpleNodeJSON, nodes NodeMap) error {
node_ptr := node.NodeHandle()
for _, policy_str := range(j.Policies) { for _, policy_str := range(j.Policies) {
policy_id, err := ParseID(policy_str) policy_id, err := ParseID(policy_str)
if err != nil { if err != nil {
@ -153,30 +154,39 @@ func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes N
if ok == false { if ok == false {
return fmt.Errorf("%s is not a Policy", policy_id) return fmt.Errorf("%s is not a Policy", policy_id)
} }
node.PolicyMap[policy_id] = policy node_ptr.PolicyMap[policy_id] = policy
} }
return nil return nil
} }
func LoadSimpleNode(ctx *Context, id NodeID, data []byte, nodes NodeMap)(Node, error) { func LoadJSONNode[J any, N Node](init_func func(NodeID, J)(Node, error), restore_func func(*Context, N, J, NodeMap)error)func(*Context, NodeID, []byte, NodeMap)(Node, error) {
var j SimpleNodeJSON return func(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) {
var j J
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node := NewSimpleNode(id) node, err := init_func(id, j)
nodes[id] = &node if err != nil {
return nil, err
err = RestoreSimpleNode(ctx, &node, j, nodes) }
nodes[id] = node
err = restore_func(ctx, node.(N), j, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &node, nil return node, nil
}
} }
var LoadSimpleNode = LoadJSONNode(func(id NodeID, j SimpleNodeJSON) (Node, error) {
node := NewSimpleNode(id)
return &node, nil
}, RestoreSimpleNode)
func (node *SimpleNode) Policies() []Policy { func (node *SimpleNode) Policies() []Policy {
ret := make([]Policy, len(node.PolicyMap)) ret := make([]Policy, len(node.PolicyMap))
i := 0 i := 0

@ -225,11 +225,12 @@ func (thread *Thread) ChildList() []ThreadNode {
} }
type ThreadJSON struct { type ThreadJSON struct {
LockableJSON
Parent string `json:"parent"` Parent string `json:"parent"`
Children map[string]map[string]interface{} `json:"children"` Children map[string]map[string]interface{} `json:"children"`
ActionQueue []QueuedAction `json:"action_queue"` ActionQueue []QueuedAction `json:"action_queue"`
StateName string `json:"state_name"` StateName string `json:"state_name"`
LockableJSON InfoTypes []InfoType `json:"info_types"`
} }
func NewThreadJSON(thread *Thread) ThreadJSON { func NewThreadJSON(thread *Thread) ThreadJSON {
@ -255,26 +256,14 @@ func NewThreadJSON(thread *Thread) ThreadJSON {
ActionQueue: thread.ActionQueue, ActionQueue: thread.ActionQueue,
StateName: thread.StateName, StateName: thread.StateName,
LockableJSON: lockable_json, LockableJSON: lockable_json,
InfoTypes: thread.InfoTypes,
} }
} }
func LoadThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { var LoadThread = LoadJSONNode(func(id NodeID, j ThreadJSON) (Node, error) {
var j ThreadJSON thread := NewThread(id, j.Name, j.StateName, j.InfoTypes, BaseThreadActions, BaseThreadHandlers)
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
thread := NewThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers)
nodes[id] = &thread
err = RestoreThread(ctx, &thread, j, nodes)
if err != nil {
return nil, err
}
return &thread, nil return &thread, nil
} }, RestoreThread)
func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) { func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) {
var soonest_action *QueuedAction var soonest_action *QueuedAction
@ -292,9 +281,11 @@ func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) {
} }
} }
func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) error { func RestoreThread(ctx *Context, thread ThreadNode, j ThreadJSON, nodes NodeMap) error {
thread.ActionQueue = j.ActionQueue thread_ptr := thread.ThreadHandle()
thread.NextAction, thread.TimeoutChan = thread.SoonestAction()
thread_ptr.ActionQueue = j.ActionQueue
thread_ptr.NextAction, thread_ptr.TimeoutChan = thread_ptr.SoonestAction()
if j.Parent != "" { if j.Parent != "" {
parent_id, err := ParseID(j.Parent) parent_id, err := ParseID(j.Parent)
@ -309,7 +300,7 @@ func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) er
if ok == false { if ok == false {
return err return err
} }
thread.Parent = p_t thread_ptr.Parent = p_t
} }
for id_str, info_raw := range(j.Children) { for id_str, info_raw := range(j.Children) {
@ -333,10 +324,10 @@ func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) er
return err return err
} }
thread.Children[id] = ChildInfo{child_t, parsed_info} thread_ptr.Children[id] = ChildInfo{child_t, parsed_info}
} }
return RestoreLockable(ctx, &thread.Lockable, j.LockableJSON, nodes) return RestoreLockable(ctx, thread, j.LockableJSON, nodes)
} }
var deserializers = map[InfoType]func(interface{})(interface{}, error) { var deserializers = map[InfoType]func(interface{})(interface{}, error) {