diff --git a/go.mod b/go.mod index b3b3162..ab994a3 100644 --- a/go.mod +++ b/go.mod @@ -2,30 +2,33 @@ module github.com/mekkanized/graphvent go 1.20 +require ( + github.com/dgraph-io/badger/v3 v3.2103.5 + github.com/gobwas/ws v1.2.1 + github.com/google/uuid v1.3.0 + github.com/graphql-go/graphql v0.8.1 + github.com/rs/zerolog v1.29.1 +) + require ( github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/dgraph-io/badger/v3 v3.2103.5 // indirect github.com/dgraph-io/badger/v4 v4.1.0 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gobwas/ws v1.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 // indirect github.com/golang/protobuf v1.3.1 // indirect github.com/golang/snappy v0.0.3 // indirect github.com/google/flatbuffers v1.12.1 // indirect - github.com/google/uuid v1.3.0 // indirect - github.com/graphql-go/graphql v0.8.1 // indirect github.com/graphql-go/handler v0.2.3 // indirect github.com/klauspost/compress v1.12.3 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/rs/zerolog v1.29.1 // indirect go.opencensus.io v0.22.5 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/sys v0.6.0 // indirect diff --git a/gql.go b/gql.go index 1d58915..410f1da 100644 --- a/gql.go +++ b/gql.go @@ -350,7 +350,7 @@ type GQLThreadState struct { func NewGQLThreadState(listen string) GQLThreadState { state := GQLThreadState{ - BaseThreadState: NewBaseThreadState("GQL Server"), + BaseThreadState: NewBaseThreadState("GQL Server", "gql_thread"), Listen: listen, } state.InfoType = reflect.TypeOf((*GQLThreadInfo)(nil)) @@ -370,8 +370,8 @@ var gql_actions ThreadActions = ThreadActions{ fs := http.FileServer(http.Dir("./site")) mux.Handle("/site/", http.StripPrefix("/site", fs)) - UseStates(ctx, []GraphNode{server}, func(states []NodeState)(error){ - server_state := states[0].(*GQLThreadState) + UseStates(ctx, []GraphNode{server}, func(states NodeStateMap)(error){ + server_state := states[server.ID()].(*GQLThreadState) server.http_server = &http.Server{ Addr: server_state.Listen, Handler: mux, @@ -395,8 +395,8 @@ var gql_actions ThreadActions = ThreadActions{ var gql_handlers ThreadHandlers = ThreadHandlers{ "child_added": func(ctx * GraphContext, thread Thread, signal GraphSignal) (string, error) { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: %+v", signal) - UseStates(ctx, []GraphNode{thread}, func(states []NodeState)(error) { - server_state := states[0].(*GQLThreadState) + UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap)(error) { + server_state := states[thread.ID()].(*GQLThreadState) should_run, exists := server_state.child_info[signal.Source()].(*GQLThreadInfo) if exists == false { ctx.Log.Logf("gql", "GQL_THREAD_CHILD_ADDED: tried to start %s whis is not a child") diff --git a/gql_graph.go b/gql_graph.go index 28056b2..291cb8c 100644 --- a/gql_graph.go +++ b/gql_graph.go @@ -237,8 +237,8 @@ func GQLNodeName(p graphql.ResolveParams) (interface{}, error) { } name := "" - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - name = states[0].Name() + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + name = states[node.ID()].Name() return nil }) @@ -261,8 +261,8 @@ func GQLThreadListen(p graphql.ResolveParams) (interface{}, error) { } listen := "" - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - gql_thread, ok := states[0].(*GQLThreadState) + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + gql_thread, ok := states[node.ID()].(*GQLThreadState) if ok == false { return fmt.Errorf("Failed to cast state to GQLThreadState") } @@ -289,8 +289,8 @@ func GQLThreadParent(p graphql.ResolveParams) (interface{}, error) { } var parent Thread = nil - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - gql_thread, ok := states[0].(ThreadState) + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + gql_thread, ok := states[node.ID()].(ThreadState) if ok == false { return fmt.Errorf("Failed to cast state to ThreadState") } @@ -317,8 +317,8 @@ func GQLThreadChildren(p graphql.ResolveParams) (interface{}, error) { } var children []Thread = nil - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - gql_thread, ok := states[0].(ThreadState) + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + gql_thread, ok := states[node.ID()].(ThreadState) if ok == false { return fmt.Errorf("Failed to cast state to ThreadState") } @@ -345,8 +345,8 @@ func GQLLockableRequirements(p graphql.ResolveParams) (interface{}, error) { } var requirements []Lockable = nil - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - gql_thread, ok := states[0].(LockableState) + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + gql_thread, ok := states[node.ID()].(LockableState) if ok == false { return fmt.Errorf("Failed to cast state to LockableState") } @@ -373,8 +373,8 @@ func GQLLockableDependencies(p graphql.ResolveParams) (interface{}, error) { } var dependencies []Lockable = nil - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - gql_thread, ok := states[0].(LockableState) + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + gql_thread, ok := states[node.ID()].(LockableState) if ok == false { return fmt.Errorf("Failed to cast state to LockableState") } @@ -401,8 +401,8 @@ func GQLLockableOwner(p graphql.ResolveParams) (interface{}, error) { } var owner GraphNode = nil - err := UseStates(ctx, []GraphNode{node}, func(states []NodeState) (error) { - gql_thread, ok := states[0].(LockableState) + err := UseStates(ctx, []GraphNode{node}, func(states NodeStateMap) (error) { + gql_thread, ok := states[node.ID()].(LockableState) if ok == false { return fmt.Errorf("Failed to cast state to LockableState") } @@ -841,8 +841,8 @@ func GQLMutationSendUpdate() *graphql.Field { } var node GraphNode = nil - err := UseStates(ctx, []GraphNode{server}, func(states []NodeState) (error){ - server_state := states[0].(*GQLThreadState) + err := UseStates(ctx, []GraphNode{server}, func(states NodeStateMap) (error){ + server_state := states[server.ID()].(*GQLThreadState) node = FindChild(ctx, server, server_state, NodeID(id)) if node == nil { return fmt.Errorf("Failed to find ID: %s as child of server thread", id) diff --git a/graph.go b/graph.go index f82cca7..59e689a 100644 --- a/graph.go +++ b/graph.go @@ -2,7 +2,6 @@ package graphvent import ( "sync" - "reflect" "github.com/google/uuid" "os" "github.com/rs/zerolog" @@ -11,13 +10,32 @@ import ( "encoding/json" ) +type StateLoadFunc func(*GraphContext, NodeID, []byte, map[NodeID]GraphNode)(NodeState, error) +type StateLoadMap map[string]StateLoadFunc +type NodeLoadFunc func(*GraphContext, NodeID)(GraphNode, error) +type NodeLoadMap map[string]NodeLoadFunc type GraphContext struct { DB * badger.DB Log Logger + NodeLoadFuncs NodeLoadMap + StateLoadFuncs StateLoadMap } func NewGraphContext(db * badger.DB, log Logger) * GraphContext { - return &GraphContext{DB: db, Log: log} + ctx := GraphContext{ + DB: db, + Log: log, + NodeLoadFuncs: NodeLoadMap{ + "base_lockable": LoadBaseLockable, + }, + StateLoadFuncs: StateLoadMap{ + "base_lockable": LoadBaseLockableState, + }, + } + + + + return &ctx } // A Logger is passed around to record events happening to components enabled by SetComponents @@ -193,7 +211,10 @@ func CancelSignal(source GraphNode) BaseSignal { } type NodeState interface { + // Human-readable name of the node, not guaranteed to be unique Name() string + // Type of the node this state is attached to. Used to deserialize the state to a node from the database + Type() string } // GraphNode is the interface common to both DAG nodes and Event tree nodes @@ -221,12 +242,25 @@ type GraphNode interface { SignalChannel() chan GraphSignal } +const NODE_SIGNAL_BUFFER = 256 + +func RestoreNode(ctx * GraphContext, id NodeID) BaseNode { + node := BaseNode{ + id: id, + signal: make(chan GraphSignal, NODE_SIGNAL_BUFFER), + listeners: map[chan GraphSignal]chan GraphSignal{}, + state: nil, + } + + ctx.Log.Logf("graph", "RESTORE_NODE: %s", node.id) + return node +} + // Create a new base node with a new ID func NewNode(ctx * GraphContext, state NodeState) (BaseNode, error) { - node := BaseNode{ id: RandID(), - signal: make(chan GraphSignal, 512), + signal: make(chan GraphSignal, NODE_SIGNAL_BUFFER), listeners: map[chan GraphSignal]chan GraphSignal{}, state: state, } @@ -289,6 +323,30 @@ func ReadDBStateCopy(ctx * GraphContext, id NodeID) ([]byte, error) { return val, nil } +func ReadDBState(ctx * GraphContext, id NodeID) ([]byte, error) { + var bytes []byte + err := ctx.DB.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(id)) + if err != nil { + return err + } + + return item.Value(func(val []byte) error { + bytes = append([]byte{}, val...) + return nil + }) + }) + + if err != nil { + ctx.Log.Logf("db", "DB_READ_ERR: %s - %e", id, err) + return nil, err + } + + ctx.Log.Logf("db", "DB_READ: %s - %s", id, string(bytes)) + + return bytes, nil +} + func WriteDBState(ctx * GraphContext, id NodeID, state NodeState) error { ctx.Log.Logf("db", "DB_WRITE: %s - %+v", id, state) @@ -331,72 +389,70 @@ func checkForDuplicate(nodes []GraphNode) error { return nil } -func UseStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []NodeState)(error)) error { +type NodeStateMap map[NodeID]NodeState +type StatesFn func(states NodeStateMap)(error) +func UseStates(ctx * GraphContext, nodes []GraphNode, states_fn StatesFn) error { + states := NodeStateMap{} + return UseMoreStates(ctx, nodes, states, states_fn) +} +func UseMoreStates(ctx * GraphContext, nodes []GraphNode, states NodeStateMap, states_fn StatesFn) error { err := checkForDuplicate(nodes) if err != nil { return err } + locked_nodes := []GraphNode{} for _, node := range(nodes) { - node.StateLock().RLock() - } - - states := make([]NodeState, len(nodes)) - for i, node := range(nodes) { - states[i] = node.State() + _, locked := states[node.ID()] + if locked == false { + node.StateLock().RLock() + states[node.ID()] = node.State() + locked_nodes = append(locked_nodes, node) + } } err = states_fn(states) - for _, node := range(nodes) { + for _, node := range(locked_nodes) { + delete(states, node.ID()) node.StateLock().RUnlock() } return err } -func UpdateStates(ctx * GraphContext, nodes []GraphNode, states_fn func(states []NodeState)([]NodeState, error)) error { +func UpdateStates(ctx * GraphContext, nodes []GraphNode, states_fn StatesFn) error { + states := NodeStateMap{} + return UpdateMoreStates(ctx, nodes, states, states_fn) +} +func UpdateMoreStates(ctx * GraphContext, nodes []GraphNode, states NodeStateMap, states_fn StatesFn) error { err := checkForDuplicate(nodes) if err != nil { return err } + locked_nodes := []GraphNode{} for _, node := range(nodes) { - node.StateLock().Lock() - } - - states := make([]NodeState, len(nodes)) - for i, node := range(nodes) { - states[i] = node.State() - } - - new_states, err := states_fn(states) - - if new_states != nil { - if len(new_states) != len(nodes) { - panic(fmt.Sprintf("NODE_NEW_STATE_LEN_MISMATCH: %d/%d", len(new_states), len(nodes))) + _, locked := states[node.ID()] + if locked == false { + node.StateLock().Lock() + states[node.ID()] = node.State() + locked_nodes = append(locked_nodes, node) } + } - for i, new_state := range(new_states) { - if new_state != nil { - old_state_type := reflect.TypeOf(states[i]) - new_state_type := reflect.TypeOf(new_state) - - if old_state_type != new_state_type { - panic(fmt.Sprintf("NODE_STATE_MISMATCH: old - %+v, new - %+v", old_state_type, new_state_type)) - } - - err := WriteDBState(ctx, nodes[i].ID(), new_state) - if err != nil { - panic(fmt.Sprintf("DB_WRITE_ERROR: %s", err)) - } - - nodes[i].SetState(new_state) + err = states_fn(states) + if err == nil { + for _, node := range(nodes) { + err := WriteDBState(ctx, node.ID(), node.State()) + if err != nil { + panic(fmt.Sprintf("DB_WRITE_ERROR: %s", err)) } } } - for _, node := range(nodes) { + for _, node := range(locked_nodes) { + delete(states, node.ID()) node.StateLock().Unlock() } diff --git a/lockable.go b/lockable.go index 5d99b4d..24b02af 100644 --- a/lockable.go +++ b/lockable.go @@ -30,6 +30,7 @@ type LockableState interface { // BaseLockableStates are a minimum collection of variables for a basic implementation of a LockHolder // Include in any state structs that should be lockable type BaseLockableState struct { + _type string name string owner Lockable requirements []Lockable @@ -38,6 +39,7 @@ type BaseLockableState struct { } type BaseLockableStateJSON struct { + Type string `json:"type"` Name string `json:"name"` Owner *NodeID `json:"owner"` Dependencies []NodeID `json:"dependencies"` @@ -45,6 +47,10 @@ type BaseLockableStateJSON struct { LocksHeld map[NodeID]*NodeID `json:"locks_held"` } +func (state * BaseLockableState) Type() string { + return state._type +} + func (state * BaseLockableState) MarshalJSON() ([]byte, error) { requirement_ids := make([]NodeID, len(state.requirements)) for i, requirement := range(state.requirements) { @@ -73,6 +79,7 @@ func (state * BaseLockableState) MarshalJSON() ([]byte, error) { } return json.Marshal(&BaseLockableStateJSON{ + Type: state._type, Name: state.name, Owner: owner_id, Dependencies: dependency_ids, @@ -150,6 +157,10 @@ func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockabl return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link Lockables to nil as requirements") } + if len(requirements) == 0 { + return nil + } + for _, requirement := range(requirements) { if requirement == nil { return fmt.Errorf("LOCKABLE_LINK_ERR: Will not link nil to a Lockable as a requirement") @@ -165,48 +176,59 @@ func LinkLockables(ctx * GraphContext, lockable Lockable, requirements []Lockabl for i, node := range(requirements) { nodes[i+1] = node } - err := UpdateStates(ctx, nodes, func(states []NodeState) ([]NodeState, error) { + err := UpdateStates(ctx, nodes, func(states NodeStateMap) error { // Check that all the requirements can be added - lockable_state := states[0].(LockableState) + lockable_state := states[lockable.ID()].(LockableState) // If the lockable is already locked, need to lock this resource as well before we can add it - for i, requirement := range(requirements) { - requirement_state := states[i+1].(LockableState) - if checkIfRequirement(ctx, lockable.ID(), requirement_state, requirement.ID()) == true { - return nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) + for _, requirement := range(requirements) { + requirement_state := states[requirement.ID()].(LockableState) + for _, req := range(requirements) { + if req.ID() == requirement.ID() { + continue + } + if checkIfRequirement(ctx, req.ID(), requirement_state, requirement.ID(), states) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependenyc of %s so cannot add the same dependency", req.ID(), requirement.ID()) + } + } + if checkIfRequirement(ctx, lockable.ID(), requirement_state, requirement.ID(), states) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as requirement", requirement.ID(), lockable.ID()) } - if checkIfRequirement(ctx, requirement.ID(), lockable_state, lockable.ID()) == true { - return nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) + if checkIfRequirement(ctx, requirement.ID(), lockable_state, lockable.ID(), states) == true { + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is a dependency of %s so cannot link as dependency again", lockable.ID(), requirement.ID()) } if lockable_state.Owner() == nil { // If the new owner isn't locked, we can add the requirement } else if requirement_state.Owner() == nil { // if the new requirement isn't already locked but the owner is, the requirement needs to be locked first - return nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is locked, %s must be locked to add", lockable.ID(), requirement.ID()) } else { // If the new requirement is already locked and the owner is already locked, their owners need to match if requirement_state.Owner().ID() != lockable_state.Owner().ID() { - return nil, fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) + return fmt.Errorf("LOCKABLE_LINK_ERR: %s is not locked by the same owner as %s, can't link as requirement", requirement.ID(), lockable.ID()) } } } // Update the states of the requirements - for i, requirement := range(requirements) { - requirement_state := states[i+1].(LockableState) + for _, requirement := range(requirements) { + requirement_state := states[requirement.ID()].(LockableState) requirement_state.AddDependency(lockable) lockable_state.AddRequirement(requirement) + ctx.Log.Logf("lockable", "LOCKABLE_LINK: linked %s to %s as a requirement", requirement.ID(), lockable.ID()) } // Return no error - return states, nil + return nil }) + return err } -func NewBaseLockableState(name string) BaseLockableState { +func NewBaseLockableState(name string, _type string) BaseLockableState { state := BaseLockableState{ locks_held: map[NodeID]Lockable{}, + _type: _type, name: name, owner: nil, requirements: []Lockable{}, @@ -229,8 +251,8 @@ type Lockable interface { } func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { - UseStates(ctx, []GraphNode{lockable}, func(states []NodeState) (error){ - lockable_state := states[0].(LockableState) + UseStates(ctx, []GraphNode{lockable}, func(states NodeStateMap) error { + lockable_state := states[lockable.ID()].(LockableState) if signal.Direction() == Up { // Child->Parent, lockable updates dependency lockables owner_sent := false @@ -259,15 +281,15 @@ func (lockable * BaseLockable) PropagateUpdate(ctx * GraphContext, signal GraphS }) } -func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID) bool { +func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_id NodeID, states NodeStateMap) bool { for _, c := range(cur.Requirements()) { if c.ID() == r_id { return true } is_requirement := false - UseStates(ctx, []GraphNode{c}, func(states []NodeState) (error) { - requirement_state := states[0].(LockableState) - is_requirement = checkIfRequirement(ctx, cur_id, requirement_state, c.ID()) + UpdateMoreStates(ctx, []GraphNode{c}, states, func(states NodeStateMap) (error) { + requirement_state := states[c.ID()].(LockableState) + is_requirement = checkIfRequirement(ctx, cur_id, requirement_state, c.ID(), states) return nil }) @@ -279,7 +301,7 @@ func checkIfRequirement(ctx * GraphContext, r_id NodeID, cur LockableState, cur_ return false } -func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable, holder_state LockableState, owner_states map[NodeID]LockableState) error { +func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable, holder_state LockableState, states NodeStateMap) error { if to_lock == nil { return fmt.Errorf("LOCKABLE_LOCK_ERR: no list provided") } @@ -311,20 +333,20 @@ func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable, hold node_list[i] = l } - err := UpdateStates(ctx, node_list, func(states []NodeState) ([]NodeState, error) { + err := UpdateMoreStates(ctx, node_list, states, func(states NodeStateMap) error { // First loop is to check that the states can be locked, and locks all requirements - for i, state := range(states) { - req := to_lock[i] + for _, req := range(to_lock) { + state := states[req.ID()] req_state, ok := state.(LockableState) ctx.Log.Logf("lockable", "LOCKABLE_LOCKING: %s from %s", req.ID(), holder.ID()) if ok == false { - return nil, fmt.Errorf("LOCKABLE_LOCK_ERR: %s(requirement of %s) does not have a LockableState", req.ID(), holder.ID()) + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s(requirement of %s) does not have a LockableState", req.ID(), holder.ID()) } // Check custom lock conditions err := req.CanLock(holder, req_state) if err != nil { - return nil, err + return err } // If req is alreay locked, check that we can pass the lock @@ -338,58 +360,45 @@ func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable, hold // So if the owner is the same node we don't need a new state, but if the owner is a different node then we need to grab it's state and add it to the list if owner.ID() == req.ID() { if req_state.AllowedToTakeLock(holder.ID(), req.ID()) == false { - return nil, fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) } // RECURSE: At this point either: // 1) req has no children and the next LockLockables will return instantly // a) in this case, we're holding every state mutex up to the resource being locked // and all the owners passing a lock, so we can start to change state // 2) req has children, and we will recurse(checking that locking is allowed) until we reach a leaf and can release the locks as we change state. The call will either return nil if state has changed, on an error if no state has changed - err := LockLockables(ctx, req_state.Requirements(), req, req_state, owner_states) + err := LockLockables(ctx, req_state.Requirements(), req, req_state, states) if err != nil { - return nil, err + return err } } else { - owner_state, exists := owner_states[owner.ID()] - if exists == false { - err := UseStates(ctx, []GraphNode{req_state.Owner()}, func(states []NodeState)(error){ - owner_state, ok := states[0].(LockableState) - if ok == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s does not have a LockableState", owner.ID()) - } - - if owner_state.AllowedToTakeLock(holder.ID(), req.ID()) == false { - return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) - } - owner_states[owner.ID()] = owner_state - err := LockLockables(ctx, req_state.Requirements(), req, req_state, owner_states) - return err - }) - if err != nil { - return nil, err + err := UpdateMoreStates(ctx, []GraphNode{owner}, states, func(states NodeStateMap)(error){ + owner_state, ok := states[owner.ID()].(LockableState) + if ok == false { + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s does not have a LockableState", owner.ID()) } - } else { + if owner_state.AllowedToTakeLock(holder.ID(), req.ID()) == false { - return nil, fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) - } - err := LockLockables(ctx, req_state.Requirements(), req, req_state, owner_states) - if err != nil { - return nil, err + return fmt.Errorf("LOCKABLE_LOCK_ERR: %s is not allowed to take %s's lock from %s", holder.ID(), req.ID(), owner.ID()) } + err := LockLockables(ctx, req_state.Requirements(), req, req_state, states) + return err + }) + if err != nil { + return err } } } else { - err := LockLockables(ctx, req_state.Requirements(), req, req_state, owner_states) + err := LockLockables(ctx, req_state.Requirements(), req, req_state, states) if err != nil { - return nil, err + return err } } } // At this point state modification will be started, so no errors can be returned - for i, state := range(states) { - req := to_lock[i] - req_state := state.(LockableState) + for _, req := range(to_lock) { + req_state := states[req.ID()].(LockableState) old_owner := req_state.Owner() req_state.SetOwner(holder) if req.ID() == holder.ID() { @@ -404,12 +413,12 @@ func LockLockables(ctx * GraphContext, to_lock []Lockable, holder Lockable, hold ctx.Log.Logf("lockable", "LOCKABLE_LOCK: %s took lock of %s from %s", holder.ID(), req.ID(), old_owner.ID()) } } - return states, nil + return nil }) return err } -func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, holder_state LockableState, owner_states map[NodeID]LockableState) error { +func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, holder_state LockableState, states NodeStateMap) error { if to_unlock == nil { return fmt.Errorf("LOCKABLE_UNLOCK_ERR: no list provided") } @@ -440,41 +449,39 @@ func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, node_list[i] = l } - err := UpdateStates(ctx, node_list, func(states []NodeState) ([]NodeState, error) { + err := UpdateMoreStates(ctx, node_list, states, func(states NodeStateMap) error { // First loop is to check that the states can be locked, and locks all requirements - for i, state := range(states) { - req := to_unlock[i] - req_state, ok := state.(LockableState) + for _, req := range(to_unlock) { + req_state, ok := states[req.ID()].(LockableState) ctx.Log.Logf("lockable", "LOCKABLE_UNLOCKING: %s from %s", req.ID(), holder.ID()) if ok == false { - return nil, fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s(requirement of %s) does not have a LockableState", req.ID(), holder.ID()) + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s(requirement of %s) does not have a LockableState", req.ID(), holder.ID()) } // Check if the owner is correct if req_state.Owner() != nil { if req_state.Owner().ID() != holder.ID() { - return nil, fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), holder.ID()) + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked by %s", req.ID(), holder.ID()) } } else { - return nil, fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) + return fmt.Errorf("LOCKABLE_UNLOCK_ERR: %s is not locked", req.ID()) } // Check custom unlock conditions err := req.CanUnlock(holder, req_state) if err != nil { - return nil, err + return err } - err = UnlockLockables(ctx, req_state.Requirements(), req, req_state, owner_states) + err = UnlockLockables(ctx, req_state.Requirements(), req, req_state, states) if err != nil { - return nil, err + return err } } // At this point state modification will be started, so no errors can be returned - for i, state := range(states) { - req := to_unlock[i] - req_state := state.(LockableState) + for _, req := range(to_unlock) { + req_state := states[req.ID()].(LockableState) var new_owner Lockable = nil if holder_state == nil { new_owner = req_state.ReturnLock(req.ID()) @@ -489,7 +496,7 @@ func UnlockLockables(ctx * GraphContext, to_unlock []Lockable, holder Lockable, ctx.Log.Logf("lockable", "LOCKABLE_UNLOCK: %s passed lock of %s back to %s", holder.ID(), req.ID(), new_owner.ID()) } } - return states, nil + return nil }) return err } @@ -530,8 +537,148 @@ func NewBaseLockable(ctx * GraphContext, state LockableState) (BaseLockable, err return lockable, nil } +func LoadBaseLockable(ctx * GraphContext, id NodeID) (GraphNode, error) { + // call LoadNodeRecurse on any connected nodes to ensure they're loaded and return the id + base_node := RestoreNode(ctx, id) + lockable := BaseLockable{ + BaseNode: base_node, + } + + return &lockable, nil +} + +func LoadBaseLockableState(ctx * GraphContext, id NodeID, data []byte, loaded_nodes map[NodeID]GraphNode)(NodeState, error){ + var j BaseLockableStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return nil, err + } + + var owner Lockable = nil + if j.Owner != nil { + o, err := LoadNodeRecurse(ctx, *j.Owner, loaded_nodes) + if err != nil { + return nil, err + } + o_l, ok := o.(Lockable) + if ok == false { + return nil, err + } + owner = o_l + } + + state := BaseLockableState{ + _type: "base_lockable", + name: j.Name, + owner: owner, + dependencies: make([]Lockable, len(j.Dependencies)), + requirements: make([]Lockable, len(j.Requirements)), + locks_held: map[NodeID]Lockable{}, + } + + for i, dep := range(j.Dependencies) { + dep_node, err := LoadNodeRecurse(ctx, dep, loaded_nodes) + if err != nil { + return nil, err + } + dep_l, ok := dep_node.(Lockable) + if ok == false { + return nil, fmt.Errorf("%+v is not a Lockable as expected", dep_node) + } + state.dependencies[i] = dep_l + } + + for i, req := range(j.Requirements) { + req_node, err := LoadNodeRecurse(ctx, req, loaded_nodes) + if err != nil { + return nil, err + } + req_l, ok := req_node.(Lockable) + if ok == false { + return nil, fmt.Errorf("%+v is not a Lockable as expected", req_node) + } + state.requirements[i] = req_l + } + + for l_id, h_id := range(j.LocksHeld) { + _, err := LoadNodeRecurse(ctx, l_id, loaded_nodes) + if err != nil { + return nil, err + } + var h_l Lockable = nil + if h_id != nil { + h_node, err := LoadNodeRecurse(ctx, *h_id, loaded_nodes) + if err != nil { + return nil, err + } + h, ok := h_node.(Lockable) + if ok == false { + return nil, err + } + h_l = h + } + state.locks_held[l_id] = h_l + } + return &state, nil +} + +func LoadNode(ctx * GraphContext, id NodeID) (GraphNode, error) { + // Initialize an empty list of loaded nodes, then start loading them from id + loaded_nodes := map[NodeID]GraphNode{} + return LoadNodeRecurse(ctx, id, loaded_nodes) +} + +type DBJSONBase struct { + Type string `json:"type"` +} + +// Check if a node is already loaded, load it's state bytes from the DB and parse the type if it's not already loaded +// Call the node load function related to the type, which will call this parse function recusively as needed +func LoadNodeRecurse(ctx * GraphContext, id NodeID, loaded_nodes map[NodeID]GraphNode) (GraphNode, error) { + node, exists := loaded_nodes[id] + if exists == false { + state_bytes, err := ReadDBState(ctx, id) + if err != nil { + return nil, err + } + + var base DBJSONBase + err = json.Unmarshal(state_bytes, &base) + if err != nil { + return nil, err + } + + ctx.Log.Logf("graph", "GRAPH_DB_LOAD: %s(%s)", base.Type, id) + + node_fn, exists := ctx.NodeLoadFuncs[base.Type] + if exists == false { + return nil, fmt.Errorf("%s is not a known node type", base.Type) + } + + node, err = node_fn(ctx, id) + if err != nil { + return nil, err + } + + loaded_nodes[id] = node + + state_fn, exists := ctx.StateLoadFuncs[base.Type] + if exists == false { + return nil, fmt.Errorf("%s is not a known node state type", base.Type) + } + + state, err := state_fn(ctx, id, state_bytes, loaded_nodes) + if err != nil { + return nil, err + } + + node.SetState(state) + } + return node, nil +} + func NewSimpleBaseLockable(ctx * GraphContext, name string, requirements []Lockable) (*BaseLockable, error) { - state := NewBaseLockableState(name) + state := NewBaseLockableState(name, "base_lockable") lockable, err := NewBaseLockable(ctx, &state) if err != nil { return nil, err diff --git a/lockable_test.go b/lockable_test.go index 12f8cc9..ad86d0d 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -3,7 +3,6 @@ package graphvent import ( "testing" "fmt" - "encoding/json" "time" ) @@ -35,11 +34,11 @@ func TestLockableSelfLock(t * testing.T) { r1, err := NewSimpleBaseLockable(ctx, "Test lockable 1", []Lockable{}) fatalErr(t, err) - err = LockLockables(ctx, []Lockable{r1}, r1, nil, map[NodeID]LockableState{}) + err = LockLockables(ctx, []Lockable{r1}, r1, nil, NodeStateMap{}) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (error) { - owner_id := states[0].(LockableState).Owner().ID() + err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { + owner_id := states[r1.ID()].(LockableState).Owner().ID() if owner_id != r1.ID() { return fmt.Errorf("r1 is owned by %s instead of self", owner_id) } @@ -47,11 +46,11 @@ func TestLockableSelfLock(t * testing.T) { }) fatalErr(t, err) - err = UnlockLockables(ctx, []Lockable{r1}, r1, nil, map[NodeID]LockableState{}) + err = UnlockLockables(ctx, []Lockable{r1}, r1, nil, NodeStateMap{}) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (error) { - owner := states[0].(LockableState).Owner() + err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { + owner := states[r1.ID()].(LockableState).Owner() if owner != nil { return fmt.Errorf("r1 is not unowned after unlock: %s", owner.ID()) } @@ -62,7 +61,7 @@ func TestLockableSelfLock(t * testing.T) { } func TestLockableSelfLockTiered(t * testing.T) { - ctx := logTestContext(t, []string{"lockable"}) + ctx := testContext(t) r1, err := NewSimpleBaseLockable(ctx, "Test lockable 1", []Lockable{}) fatalErr(t, err) @@ -73,41 +72,38 @@ func TestLockableSelfLockTiered(t * testing.T) { r3, err := NewSimpleBaseLockable(ctx, "Test lockable 3", []Lockable{r1, r2}) fatalErr(t, err) - err = LockLockables(ctx, []Lockable{r3}, r3, nil, map[NodeID]LockableState{}) + err = LockLockables(ctx, []Lockable{r3}, r3, nil, NodeStateMap{}) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states []NodeState) (error) { - owner_1_id := states[0].(LockableState).Owner().ID() + err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states NodeStateMap) (error) { + owner_1_id := states[r1.ID()].(LockableState).Owner().ID() if owner_1_id != r3.ID() { return fmt.Errorf("r1 is owned by %s instead of r3", owner_1_id) } - owner_2_id := states[1].(LockableState).Owner().ID() + owner_2_id := states[r2.ID()].(LockableState).Owner().ID() if owner_2_id != r3.ID() { return fmt.Errorf("r2 is owned by %s instead of r3", owner_2_id) } - ser, _ := json.MarshalIndent(states, "", " ") - fmt.Printf("\n\n%s\n\n", ser) - return nil }) fatalErr(t, err) - err = UnlockLockables(ctx, []Lockable{r3}, r3, nil, map[NodeID]LockableState{}) + err = UnlockLockables(ctx, []Lockable{r3}, r3, nil, NodeStateMap{}) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states []NodeState) (error) { - owner_1 := states[0].(LockableState).Owner() + err = UseStates(ctx, []GraphNode{r1, r2, r3}, func(states NodeStateMap) (error) { + owner_1 := states[r1.ID()].(LockableState).Owner() if owner_1 != nil { return fmt.Errorf("r1 is not unowned after unlocking: %s", owner_1.ID()) } - owner_2 := states[1].(LockableState).Owner() + owner_2 := states[r2.ID()].(LockableState).Owner() if owner_2 != nil { return fmt.Errorf("r2 is not unowned after unlocking: %s", owner_2.ID()) } - owner_3 := states[2].(LockableState).Owner() + owner_3 := states[r3.ID()].(LockableState).Owner() if owner_3 != nil { return fmt.Errorf("r3 is not unowned after unlocking: %s", owner_3.ID()) } @@ -126,16 +122,16 @@ func TestLockableLockOther(t * testing.T) { r2, err := NewSimpleBaseLockable(ctx, "Test lockable 2", []Lockable{}) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r2}, func(states []NodeState) ([]NodeState, error) { - node_state := states[0].(LockableState) - err := LockLockables(ctx, []Lockable{r1}, r2, node_state, map[NodeID]LockableState{}) + err = UpdateStates(ctx, []GraphNode{r2}, func(states NodeStateMap) (error) { + node_state := states[r2.ID()].(LockableState) + err := LockLockables(ctx, []Lockable{r1}, r2, node_state, NodeStateMap{}) fatalErr(t, err) - return []NodeState{node_state}, nil + return nil }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (error) { - owner_id := states[0].(LockableState).Owner().ID() + err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { + owner_id := states[r1.ID()].(LockableState).Owner().ID() if owner_id != r2.ID() { return fmt.Errorf("r1 is owned by %s instead of r2", owner_id) } @@ -144,16 +140,16 @@ func TestLockableLockOther(t * testing.T) { }) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r2}, func(states []NodeState) ([]NodeState, error) { - node_state := states[0].(LockableState) - err := UnlockLockables(ctx, []Lockable{r1}, r2, node_state, map[NodeID]LockableState{}) + err = UpdateStates(ctx, []GraphNode{r2}, func(states NodeStateMap) (error) { + node_state := states[r2.ID()].(LockableState) + err := UnlockLockables(ctx, []Lockable{r1}, r2, node_state, NodeStateMap{}) fatalErr(t, err) - return []NodeState{node_state}, nil + return nil }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (error) { - owner := states[0].(LockableState).Owner() + err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { + owner := states[r1.ID()].(LockableState).Owner() if owner != nil { return fmt.Errorf("r1 is owned by %s instead of r2", owner.ID()) } @@ -173,22 +169,22 @@ func TestLockableLockSimpleConflict(t * testing.T) { r2, err := NewSimpleBaseLockable(ctx, "Test lockable 2", []Lockable{}) fatalErr(t, err) - err = LockLockables(ctx, []Lockable{r1}, r1, nil, map[NodeID]LockableState{}) + err = LockLockables(ctx, []Lockable{r1}, r1, nil, NodeStateMap{}) fatalErr(t, err) - err = UpdateStates(ctx, []GraphNode{r2}, func(states []NodeState) ([]NodeState, error) { - node_state := states[0].(LockableState) - err := LockLockables(ctx, []Lockable{r1}, r2, node_state, map[NodeID]LockableState{}) + err = UpdateStates(ctx, []GraphNode{r2}, func(states NodeStateMap) (error) { + node_state := states[r2.ID()].(LockableState) + err := LockLockables(ctx, []Lockable{r1}, r2, node_state, NodeStateMap{}) if err == nil { t.Fatal("r2 took r1's lock from itself") } - return []NodeState{node_state}, nil + return nil }) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (error) { - owner_id := states[0].(LockableState).Owner().ID() + err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { + owner_id := states[r1.ID()].(LockableState).Owner().ID() if owner_id != r1.ID() { return fmt.Errorf("r1 is owned by %s instead of r1", owner_id) } @@ -197,11 +193,11 @@ func TestLockableLockSimpleConflict(t * testing.T) { }) fatalErr(t, err) - err = UnlockLockables(ctx, []Lockable{r1}, r1, nil, map[NodeID]LockableState{}) + err = UnlockLockables(ctx, []Lockable{r1}, r1, nil, NodeStateMap{}) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{r1}, func(states []NodeState) (error) { - owner := states[0].(LockableState).Owner() + err = UseStates(ctx, []GraphNode{r1}, func(states NodeStateMap) (error) { + owner := states[r1.ID()].(LockableState).Owner() if owner != nil { return fmt.Errorf("r1 is owned by %s instead of r1", owner.ID()) } @@ -224,10 +220,10 @@ func TestLockableLockTieredConflict(t * testing.T) { r3, err := NewSimpleBaseLockable(ctx, "Test lockable 3", []Lockable{r1}) fatalErr(t, err) - err = LockLockables(ctx, []Lockable{r2}, r2, nil, map[NodeID]LockableState{}) + err = LockLockables(ctx, []Lockable{r2}, r2, nil, NodeStateMap{}) fatalErr(t, err) - err = LockLockables(ctx, []Lockable{r3}, r3, nil, map[NodeID]LockableState{}) + err = LockLockables(ctx, []Lockable{r3}, r3, nil, NodeStateMap{}) if err == nil { t.Fatal("Locked r3 which depends on r1 while r2 which depends on r1 is already locked") } @@ -308,3 +304,32 @@ func TestOwnerNotUpdatedTwice(t * testing.T) { (*GraphTester)(t).WaitForValue(ctx, update_channel, "test_update", l1, 100*time.Millisecond, "Dicn't received test_update on l2 from l1") (*GraphTester)(t).CheckForNone(update_channel, "Second update received on dependency") } + +func TestLockableDependencyOverlap(t * testing.T) { + ctx := testContext(t) + l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{l1}) + fatalErr(t, err) + _, err = NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) + if err == nil { + t.Fatal("Should have thrown an error because of dependency overlap") + } +} + +func TestLockableDBLoad(t * testing.T){ + ctx := testContext(t) + l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) + fatalErr(t, err) + l2, err := NewSimpleBaseLockable(ctx, "Test Lockable 2", []Lockable{}) + fatalErr(t, err) + l3, err := NewSimpleBaseLockable(ctx, "Test Lockable 3", []Lockable{l1, l2}) + fatalErr(t, err) + l4, err := NewSimpleBaseLockable(ctx, "Test Lockable 4", []Lockable{l3}) + fatalErr(t, err) + _, err = NewSimpleBaseLockable(ctx, "Test Lockable 5", []Lockable{l4}) + fatalErr(t, err) + + _, err = LoadNode(ctx, l3.ID()) + fatalErr(t, err) +} diff --git a/thread.go b/thread.go index 6eee451..d3de178 100644 --- a/thread.go +++ b/thread.go @@ -11,8 +11,8 @@ import ( // Update the threads listeners, and notify the parent to do the same func (thread * BaseThread) PropagateUpdate(ctx * GraphContext, signal GraphSignal) { - UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (error) { - thread_state := states[0].(ThreadState) + UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { + thread_state := states[thread.ID()].(ThreadState) if signal.Direction() == Up { // Child->Parent, thread updates parent and connected requirement if thread_state.Parent() != nil { @@ -167,8 +167,8 @@ func checkIfChild(ctx * GraphContext, thread_id NodeID, cur_state ThreadState, c return true } is_child := false - UseStates(ctx, []GraphNode{child}, func(states []NodeState) (error) { - child_state := states[0].(ThreadState) + UseStates(ctx, []GraphNode{child}, func(states NodeStateMap) (error) { + child_state := states[child.ID()].(ThreadState) is_child = checkIfChild(ctx, cur_id, child_state, child.ID()) return nil }) @@ -190,29 +190,29 @@ func LinkThreads(ctx * GraphContext, thread Thread, child Thread, info ThreadInf } - err := UpdateStates(ctx, []GraphNode{thread, child}, func(states []NodeState) ([]NodeState, error) { - thread_state := states[0].(ThreadState) - child_state := states[1].(ThreadState) + err := UpdateStates(ctx, []GraphNode{thread, child}, func(states NodeStateMap) error { + thread_state := states[thread.ID()].(ThreadState) + child_state := states[child.ID()].(ThreadState) if child_state.Parent() != nil { - return nil, fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) + return fmt.Errorf("EVENT_LINK_ERR: %s already has a parent, cannot link as child", child.ID()) } if checkIfChild(ctx, thread.ID(), child_state, child.ID()) == true { - return nil, fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) + return fmt.Errorf("EVENT_LINK_ERR: %s is a child of %s so cannot add as parent", thread.ID(), child.ID()) } if checkIfChild(ctx, child.ID(), thread_state, thread.ID()) == true { - return nil, fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) + return fmt.Errorf("EVENT_LINK_ERR: %s is already a parent of %s so will not add again", thread.ID(), child.ID()) } err := thread_state.AddChild(child, info) if err != nil { - return nil, fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) + return fmt.Errorf("EVENT_LINK_ERR: error adding %s as child to %s: %e", child.ID(), thread.ID(), err) } child_state.SetParent(thread) - return states, nil + return nil }) if err != nil { @@ -250,8 +250,8 @@ func FindChild(ctx * GraphContext, thread Thread, thread_state ThreadState, id N for _, child := range thread_state.Children() { var result Thread = nil - UseStates(ctx, []GraphNode{child}, func(states []NodeState) (error) { - child_state := states[0].(ThreadState) + UseStates(ctx, []GraphNode{child}, func(states NodeStateMap) (error) { + child_state := states[child.ID()].(ThreadState) result = FindChild(ctx, child, child_state, id) return nil }) @@ -284,13 +284,13 @@ func ChildGo(ctx * GraphContext, thread_state ThreadState, thread Thread, child_ func RunThread(ctx * GraphContext, thread Thread) error { ctx.Log.Logf("thread", "THREAD_RUN: %s", thread.ID()) - err := LockLockables(ctx, []Lockable{thread}, thread, nil, map[NodeID]LockableState{}) + err := LockLockables(ctx, []Lockable{thread}, thread, nil, NodeStateMap{}) if err != nil { return err } - err = UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (error) { - thread_state := states[0].(ThreadState) + err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { + thread_state := states[thread.ID()].(ThreadState) if thread_state.Owner() == nil { return fmt.Errorf("THREAD_RUN_NOT_LOCKED: %s", thread_state.Name()) } else if thread_state.Owner().ID() != thread.ID() { @@ -321,8 +321,8 @@ func RunThread(ctx * GraphContext, thread Thread) error { } } - err = UseStates(ctx, []GraphNode{thread}, func(states []NodeState) (error) { - thread_state := states[0].(ThreadState) + err = UseStates(ctx, []GraphNode{thread}, func(states NodeStateMap) (error) { + thread_state := states[thread.ID()].(ThreadState) err := thread_state.Stop() return err @@ -332,7 +332,7 @@ func RunThread(ctx * GraphContext, thread Thread) error { return err } - err = UnlockLockables(ctx, []Lockable{thread}, thread, nil, map[NodeID]LockableState{}) + err = UnlockLockables(ctx, []Lockable{thread}, thread, nil, NodeStateMap{}) if err != nil { ctx.Log.Logf("thread", "THREAD_RUN_UNLOCK_ERR: %e", err) return err @@ -452,9 +452,9 @@ var ThreadCancel = func(ctx * GraphContext, thread Thread, signal GraphSignal) ( return "", nil } -func NewBaseThreadState(name string) BaseThreadState { +func NewBaseThreadState(name string, _type string) BaseThreadState { return BaseThreadState{ - BaseLockableState: NewBaseLockableState(name), + BaseLockableState: NewBaseLockableState(name, _type), children: []Thread{}, child_info: map[NodeID]ThreadInfo{}, parent: nil, @@ -493,7 +493,7 @@ func NewBaseThread(ctx * GraphContext, actions ThreadActions, handlers ThreadHan } func NewSimpleBaseThread(ctx * GraphContext, name string, requirements []Lockable, actions ThreadActions, handlers ThreadHandlers) (* BaseThread, error) { - state := NewBaseThreadState(name) + state := NewBaseThreadState(name, "base_thread") thread, err := NewBaseThread(ctx, actions, handlers, &state) if err != nil { return nil, err diff --git a/thread_test.go b/thread_test.go index 194108a..9ebf6f0 100644 --- a/thread_test.go +++ b/thread_test.go @@ -3,11 +3,10 @@ package graphvent import ( "testing" "time" - "encoding/json" "fmt" ) -func TestNewEvent(t * testing.T) { +func TestNewThread(t * testing.T) { ctx := testContext(t) t1, err := NewSimpleBaseThread(ctx, "Test thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) @@ -21,18 +20,17 @@ func TestNewEvent(t * testing.T) { err = RunThread(ctx, t1) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{t1}, func(states []NodeState) (error) { - ser, err := json.MarshalIndent(states, "", " ") - fatalErr(t, err) - - fmt.Printf("\n%s\n", ser) - + err = UseStates(ctx, []GraphNode{t1}, func(states NodeStateMap) (error) { + owner := states[t1.ID()].(ThreadState).Owner() + if owner != nil { + return fmt.Errorf("Wrong owner %+v", owner) + } return nil }) } -func TestEventWithRequirement(t * testing.T) { - ctx := logTestContext(t, []string{"lockable", "thread"}) +func TestThreadWithRequirement(t * testing.T) { + ctx := testContext(t) l1, err := NewSimpleBaseLockable(ctx, "Test Lockable 1", []Lockable{}) fatalErr(t, err) @@ -42,14 +40,6 @@ func TestEventWithRequirement(t * testing.T) { go func (thread Thread) { time.Sleep(10*time.Millisecond) - err := UseStates(ctx, []GraphNode{l1}, func(states []NodeState) (error) { - ser, err := json.MarshalIndent(states[0], "", " ") - fatalErr(t, err) - - fmt.Printf("\n%s\n", ser) - return nil - }) - fatalErr(t, err) SendUpdate(ctx, t1, CancelSignal(nil)) }(t1) fatalErr(t, err) @@ -57,18 +47,18 @@ func TestEventWithRequirement(t * testing.T) { err = RunThread(ctx, t1) fatalErr(t, err) - err = UseStates(ctx, []GraphNode{l1}, func(states []NodeState) (error) { - ser, err := json.MarshalIndent(states[0], "", " ") - fatalErr(t, err) - - fmt.Printf("\n%s\n", ser) + err = UseStates(ctx, []GraphNode{l1}, func(states NodeStateMap) (error) { + owner := states[l1.ID()].(LockableState).Owner() + if owner != nil { + return fmt.Errorf("Wrong owner %+v", owner) + } return nil }) fatalErr(t, err) } -func TestCustomEventState(t * testing.T ) { - ctx := logTestContext(t, []string{"lockable", "thread"}) +func TestCustomThreadState(t * testing.T ) { + ctx := testContext(t) t1, err := NewSimpleBaseThread(ctx, "Test Thread 1", []Lockable{}, ThreadActions{}, ThreadHandlers{}) fatalErr(t, err)