diff --git a/gql.go b/gql.go index d047a1b..95d3e40 100644 --- a/gql.go +++ b/gql.go @@ -733,13 +733,7 @@ func NewGQLThreadJSON(thread *GQLThread) GQLThreadJSON { } } -func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j GQLThreadJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - +var LoadGQLThread = LoadJSONNode(func(id NodeID, j GQLThreadJSON) (Node, error) { ecdh_curve, ok := ecdh_curves[j.ECDH] if ok == false { return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH) @@ -751,29 +745,24 @@ func LoadGQLThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, e } thread := NewGQLThread(id, j.Name, j.StateName, j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey) - nodes[id] = &thread - + return &thread, nil +}, func(ctx *Context, thread *GQLThread, j GQLThreadJSON, nodes NodeMap) error { thread.Users = map[NodeID]*User{} for _, id_str := range(j.Users) { ctx.Log.Logf("db", "THREAD_LOAD_USER: %s", id_str) user_id, err := ParseID(id_str) if err != nil { - return nil, err + return err } user, err := LoadNodeRecurse(ctx, user_id, nodes) if err != nil { - return nil, err + return err } thread.Users[user_id] = user.(*User) } - err = RestoreThread(ctx, &thread.Thread, j.ThreadJSON, nodes) - if err != nil { - return nil, err - } - - return &thread, nil -} + return RestoreThread(ctx, thread, j.ThreadJSON, nodes) +}) func NewGQLThread(id NodeID, name string, state_name string, listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) GQLThread { if tls_cert == nil || tls_key == nil { diff --git a/lockable.go b/lockable.go index 8212df7..9f54216 100644 --- a/lockable.go +++ b/lockable.go @@ -33,23 +33,10 @@ func NewListener(id NodeID, name string) Listener { } } -func LoadListener(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j LockableJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - +var LoadListener = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) { listener := NewListener(id, j.Name) - nodes[id] = &listener - - err = RestoreLockable(ctx, &listener.Lockable, j, nodes) - if err != nil { - return nil, err - } - return &listener, nil -} +}, RestoreLockable) type LockableNode interface { Node @@ -462,24 +449,10 @@ func UnlockLockables(context *StateContext, to_unlock map[NodeID]LockableNode, o }) } -// Load function for Lockable -func LoadLockable(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j LockableJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - +var LoadLockable = LoadJSONNode(func(id NodeID, j LockableJSON) (Node, error) { lockable := NewLockable(id, j.Name) - nodes[id] = &lockable - - err = RestoreLockable(ctx, &lockable, j, nodes) - if err != nil { - return nil, err - } - return &lockable, nil -} +}, RestoreLockable) func NewLockable(id NodeID, name string) Lockable { return Lockable{ @@ -493,7 +466,8 @@ func NewLockable(id NodeID, name string) Lockable { } // Helper function to load links when loading a struct that embeds Lockable -func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes NodeMap) error { +func RestoreLockable(ctx * Context, lockable LockableNode, j LockableJSON, nodes NodeMap) error { + lockable_ptr := lockable.LockableHandle() if j.Owner != "" { owner_id, err := ParseID(j.Owner) if err != nil { @@ -507,7 +481,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No if ok == false { return fmt.Errorf("%s is not a Lockable", j.Owner) } - lockable.Owner = owner + lockable_ptr.Owner = owner } for _, dep_str := range(j.Dependencies) { @@ -524,7 +498,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No return fmt.Errorf("%+v is not a Lockable as expected", dep_node) } ctx.Log.Logf("db", "LOCKABLE_LOAD_DEPENDENCY: %s - %s - %+v", lockable.ID(), dep_id, reflect.TypeOf(dep)) - lockable.Dependencies[dep_id] = dep + lockable_ptr.Dependencies[dep_id] = dep } for _, req_str := range(j.Requirements) { @@ -540,7 +514,7 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No if ok == false { return fmt.Errorf("%+v is not a Lockable as expected", req_node) } - lockable.Requirements[req_id] = req + lockable_ptr.Requirements[req_id] = req } for l_id_str, h_str := range(j.LocksHeld) { @@ -570,8 +544,8 @@ func RestoreLockable(ctx * Context, lockable *Lockable, j LockableJSON, nodes No } h_l = h } - lockable.RecordLock(l_l, h_l) + lockable_ptr.RecordLock(l_l, h_l) } - return RestoreSimpleNode(ctx, &lockable.SimpleNode, j.SimpleNodeJSON, nodes) + return RestoreSimpleNode(ctx, lockable, j.SimpleNodeJSON, nodes) } diff --git a/node.go b/node.go index e6d94b9..0114880 100644 --- a/node.go +++ b/node.go @@ -137,7 +137,8 @@ func NewSimpleNodeJSON(node *SimpleNode) SimpleNodeJSON { } } -func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes NodeMap) error { +func RestoreSimpleNode(ctx *Context, node Node, j SimpleNodeJSON, nodes NodeMap) error { + node_ptr := node.NodeHandle() for _, policy_str := range(j.Policies) { policy_id, err := ParseID(policy_str) if err != nil { @@ -153,29 +154,38 @@ func RestoreSimpleNode(ctx *Context, node *SimpleNode, j SimpleNodeJSON, nodes N if ok == false { return fmt.Errorf("%s is not a Policy", policy_id) } - node.PolicyMap[policy_id] = policy + node_ptr.PolicyMap[policy_id] = policy } return nil } -func LoadSimpleNode(ctx *Context, id NodeID, data []byte, nodes NodeMap)(Node, error) { - var j SimpleNodeJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } +func LoadJSONNode[J any, N Node](init_func func(NodeID, J)(Node, error), restore_func func(*Context, N, J, NodeMap)error)func(*Context, NodeID, []byte, NodeMap)(Node, error) { + return func(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { + var j J + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } - node := NewSimpleNode(id) - nodes[id] = &node + node, err := init_func(id, j) + if err != nil { + return nil, err + } + nodes[id] = node + err = restore_func(ctx, node.(N), j, nodes) + if err != nil { + return nil, err + } - err = RestoreSimpleNode(ctx, &node, j, nodes) - if err != nil { - return nil, err + return node, nil } +} +var LoadSimpleNode = LoadJSONNode(func(id NodeID, j SimpleNodeJSON) (Node, error) { + node := NewSimpleNode(id) return &node, nil -} +}, RestoreSimpleNode) func (node *SimpleNode) Policies() []Policy { ret := make([]Policy, len(node.PolicyMap)) diff --git a/thread.go b/thread.go index 5503029..90e5ce0 100644 --- a/thread.go +++ b/thread.go @@ -225,11 +225,12 @@ func (thread *Thread) ChildList() []ThreadNode { } type ThreadJSON struct { + LockableJSON Parent string `json:"parent"` Children map[string]map[string]interface{} `json:"children"` ActionQueue []QueuedAction `json:"action_queue"` StateName string `json:"state_name"` - LockableJSON + InfoTypes []InfoType `json:"info_types"` } func NewThreadJSON(thread *Thread) ThreadJSON { @@ -255,26 +256,14 @@ func NewThreadJSON(thread *Thread) ThreadJSON { ActionQueue: thread.ActionQueue, StateName: thread.StateName, LockableJSON: lockable_json, + InfoTypes: thread.InfoTypes, } } -func LoadThread(ctx *Context, id NodeID, data []byte, nodes NodeMap) (Node, error) { - var j ThreadJSON - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - thread := NewThread(id, j.Name, j.StateName, nil, BaseThreadActions, BaseThreadHandlers) - nodes[id] = &thread - - err = RestoreThread(ctx, &thread, j, nodes) - if err != nil { - return nil, err - } - +var LoadThread = LoadJSONNode(func(id NodeID, j ThreadJSON) (Node, error) { + thread := NewThread(id, j.Name, j.StateName, j.InfoTypes, BaseThreadActions, BaseThreadHandlers) return &thread, nil -} +}, RestoreThread) func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) { var soonest_action *QueuedAction @@ -292,9 +281,11 @@ func (thread *Thread) SoonestAction() (*QueuedAction, <-chan time.Time) { } } -func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) error { - thread.ActionQueue = j.ActionQueue - thread.NextAction, thread.TimeoutChan = thread.SoonestAction() +func RestoreThread(ctx *Context, thread ThreadNode, j ThreadJSON, nodes NodeMap) error { + thread_ptr := thread.ThreadHandle() + + thread_ptr.ActionQueue = j.ActionQueue + thread_ptr.NextAction, thread_ptr.TimeoutChan = thread_ptr.SoonestAction() if j.Parent != "" { parent_id, err := ParseID(j.Parent) @@ -309,7 +300,7 @@ func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) er if ok == false { return err } - thread.Parent = p_t + thread_ptr.Parent = p_t } for id_str, info_raw := range(j.Children) { @@ -333,10 +324,10 @@ func RestoreThread(ctx *Context, thread *Thread, j ThreadJSON, nodes NodeMap) er return err } - thread.Children[id] = ChildInfo{child_t, parsed_info} + thread_ptr.Children[id] = ChildInfo{child_t, parsed_info} } - return RestoreLockable(ctx, &thread.Lockable, j.LockableJSON, nodes) + return RestoreLockable(ctx, thread, j.LockableJSON, nodes) } var deserializers = map[InfoType]func(interface{})(interface{}, error) {