diff --git a/context.go b/context.go index 83b27f4..0428872 100644 --- a/context.go +++ b/context.go @@ -582,16 +582,6 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - req_info_type := reflect.TypeOf(ReqInfo{}) - req_info_info, err := GetStructInfo(ctx, req_info_type) - if err != nil { - return nil, err - } - err = ctx.RegisterType(req_info_type, ReqInfoType, nil, SerializeStruct(req_info_info), nil, DeserializeStruct(req_info_info)) - if err != nil { - return nil, err - } - err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil) if err != nil { return nil, err diff --git a/gql.go b/gql.go index bb7e2fe..c7a6c9a 100644 --- a/gql.go +++ b/gql.go @@ -1389,7 +1389,7 @@ func NewGQLExtContext() *GQLExtContext { "requirements", LockableExtType, func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) ([]NodeID, error) { - id_strs, ok := value.Interface().(map[NodeID]ReqInfo) + id_strs, ok := value.Interface().(map[NodeID]ReqState) if ok == false { return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState", value.Type()) } diff --git a/lockable.go b/lockable.go index ffde98a..aa4bafc 100644 --- a/lockable.go +++ b/lockable.go @@ -2,6 +2,7 @@ package graphvent import ( "github.com/google/uuid" + "time" ) type ReqState byte @@ -13,9 +14,12 @@ const ( AbortingLock = ReqState(4) ) -type ReqInfo struct { - State ReqState `gv:"state"` - MsgID uuid.UUID `gv:"msg_id"` +var ReqStateStrings = map[ReqState]string { + Unlocked: "Unlocked", + Unlocking: "Unlocking", + Locked: "Locked", + Locking: "Locking", + AbortingLock: "AbortingLock", } type LockableExt struct{ @@ -24,7 +28,8 @@ type LockableExt struct{ Owner *NodeID `gv:"owner"` PendingOwner *NodeID `gv:"pending_owner"` PendingID uuid.UUID `gv:"pending_id"` - Requirements map[NodeID]ReqInfo `gv:"requirements"` + Requirements map[NodeID]ReqState `gv:"requirements"` + WaitInfos WaitMap `gv:"wait_infos"` } func (ext *LockableExt) Type() ExtType { @@ -32,14 +37,11 @@ func (ext *LockableExt) Type() ExtType { } func NewLockableExt(requirements []NodeID) *LockableExt { - var reqs map[NodeID]ReqInfo = nil + var reqs map[NodeID]ReqState = nil if requirements != nil { - reqs = map[NodeID]ReqInfo{} + reqs = map[NodeID]ReqState{} for _, id := range(requirements) { - reqs[id] = ReqInfo{ - Unlocked, - uuid.UUID{}, - } + reqs[id] = Unlocked } } return &LockableExt{ @@ -47,6 +49,7 @@ func NewLockableExt(requirements []NodeID) *LockableExt { Owner: nil, PendingOwner: nil, Requirements: reqs, + WaitInfos: WaitMap{}, } } @@ -65,34 +68,32 @@ func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { } func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeID, signal *ErrorSignal) (Messages, Changes) { - str := signal.Error var messages Messages = nil var changes Changes = nil - switch str { - case "not_unlocked": - changes = changes.Add("requirements") - if ext.State == Locking { - ext.State = AbortingLock - req_info := ext.Requirements[source] - req_info.State = Unlocked - ext.Requirements[source] = req_info - for id, info := range(ext.Requirements) { - if info.State == Locked { - lock_signal := NewLockSignal("unlock") - req_info := ext.Requirements[id] - req_info.State = Unlocking - req_info.MsgID = lock_signal.ID() - ext.Requirements[id] = req_info - ctx.Log.Logf("lockable", "SENT_ABORT_UNLOCK: %s to %s", lock_signal.ID(), id) - - messages = messages.Add(ctx, id, node, nil, lock_signal) + info, info_found := node.ProcessResponse(ext.WaitInfos, signal) + if info_found { + state, found := ext.Requirements[info.NodeID] + if found == true { + ctx.Log.Logf("lockable", "got mapped response %+v for %+v in state %s", signal, info, ReqStateStrings[state]) + switch state { + case Locking: + ext.State = AbortingLock + ext.Requirements[info.NodeID] = Unlocked + for id, state := range(ext.Requirements) { + if state == Locked { + ext.Requirements[id] = Unlocking + lock_signal := NewLockSignal("unlock") + ext.WaitInfos[lock_signal.Id] = node.QueueTimeout(id, lock_signal, 100*time.Millisecond) + messages = messages.Add(ctx, id, node, nil, lock_signal) + ctx.Log.Logf("lockable", "sent abort unlock to %s from %s", id, node.ID) + } } + case Unlocking: } + } else { + ctx.Log.Logf("lockable", "Got mapped error %s, but %s isn't a requirement", signal, info.NodeID) } - case "not_locked": - panic("RECEIVED not_locked, meaning a node thought it held a lock it didn't") - case "not_requirement": } return messages, changes @@ -109,12 +110,9 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "already_requirement")) } else { if ext.Requirements == nil { - ext.Requirements = map[NodeID]ReqInfo{} - } - ext.Requirements[signal.NodeID] = ReqInfo{ - Unlocked, - uuid.UUID{}, + ext.Requirements = map[NodeID]ReqState{} } + ext.Requirements[signal.NodeID] = Unlocked changes = changes.Add("requirement_added") messages = messages.Add(ctx, source, node, nil, NewSuccessSignal(signal.ID())) } @@ -143,74 +141,75 @@ func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source Nod return messages, changes } - info, found := ext.Requirements[source] - if found == false { - ctx.Log.Logf("lockable", "Got success from non-requirement %s", source) - } else if info.MsgID != signal.ReqID { - ctx.Log.Logf("lockable", "Got success for wrong signal for %s: %s, expecting %s", source, signal.ReqID, info.MsgID) - } else { - if info.State == Locking { - if ext.State == Locking { - info.State = Locked - info.MsgID = uuid.UUID{} - ext.Requirements[source] = info - reqs := 0 - locked := 0 - for _, s := range(ext.Requirements) { - reqs += 1 - if s.State == Locked { - locked += 1 + info, info_found := node.ProcessResponse(ext.WaitInfos, signal) + if info_found == true { + state, found := ext.Requirements[info.NodeID] + if found == false { + ctx.Log.Logf("lockable", "Got success signal for requirement that is no longer in the map(%s), ignoring...", info.NodeID) + } else { + ctx.Log.Logf("lockable", "got mapped response %+v for %+v in state %s", signal, info, ReqStateStrings[state]) + switch state { + case Locking: + switch ext.State { + case Locking: + ext.Requirements[info.NodeID] = Locked + locked := 0 + for _, s := range(ext.Requirements) { + if s == Locked { + locked += 1 + } } - } + if locked == len(ext.Requirements) { + ctx.Log.Logf("lockable", "WHOLE LOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) + ext.State = Locked + ext.Owner = ext.PendingOwner + changes = changes.Add("locked") + messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) + } else { + changes = changes.Add("partial_lock") + ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, len(ext.Requirements)) + } + case AbortingLock: + ext.Requirements[info.NodeID] = Unlocking - if locked == reqs { - ctx.Log.Logf("lockable", "WHOLE LOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) - ext.State = Locked - ext.Owner = ext.PendingOwner - changes = changes.Add("locked") - messages = messages.Add(ctx, *ext.Owner, node, nil, NewSuccessSignal(ext.PendingID)) - } else { - changes = changes.Add("partial_lock") - ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) + lock_signal := NewLockSignal("unlock") + ext.WaitInfos[lock_signal.Id] = node.QueueTimeout(info.NodeID, lock_signal, 100*time.Millisecond) + messages = messages.Add(ctx, info.NodeID, node, nil, lock_signal) + + ctx.Log.Logf("lockable", "sending abort_lock to %s for %s", info.NodeID, node.ID) } - } else if ext.State == AbortingLock { - lock_signal := NewLockSignal("unlock") - info.State = Unlocking - info.MsgID = lock_signal.ID() - ext.Requirements[source] = info - messages = messages.Add(ctx, source, node, nil, lock_signal) - } - } else if info.State == Unlocking { - info.State = Unlocked - info.MsgID = uuid.UUID{} - ext.Requirements[source] = info - reqs := 0 - unlocked := 0 - for _, s := range(ext.Requirements) { - reqs += 1 - if s.State == Unlocked { - unlocked += 1 + case AbortingLock: + ctx.Log.Logf("lockable", "Got success signal in AbortingLock %s", node.ID) + fallthrough + case Unlocking: + ext.Requirements[source] = Unlocked + + unlocked := 0 + for _, s := range(ext.Requirements) { + if s == Unlocked { + unlocked += 1 + } } - } - if unlocked == reqs { - old_state := ext.State - ext.State = Unlocked - ctx.Log.Logf("lockable", "WHOLE UNLOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) - if old_state == Unlocking { - previous_owner := *ext.Owner - ext.Owner = ext.PendingOwner - ext.ReqID = nil - changes = changes.Add("unlocked") - messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) - } else if old_state == AbortingLock { - changes = changes.Add("lock_aborted") - messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) - ext.PendingOwner = ext.Owner + if unlocked == len(ext.Requirements) { + old_state := ext.State + ext.State = Unlocked + ctx.Log.Logf("lockable", "WHOLE UNLOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) + if old_state == Unlocking { + previous_owner := *ext.Owner + ext.Owner = ext.PendingOwner + ext.ReqID = nil + changes = changes.Add("unlocked") + messages = messages.Add(ctx, previous_owner, node, nil, NewSuccessSignal(ext.PendingID)) + } else if old_state == AbortingLock { + changes = changes.Add("lock_aborted") + messages = messages.Add(ctx, *ext.PendingOwner, node, nil, NewErrorSignal(*ext.ReqID, "not_unlocked")) + ext.PendingOwner = ext.Owner + } + } else { + changes = changes.Add("partial_unlock") + ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, len(ext.Requirements)) } - } else { - changes = changes.Add("partial_unlock") - ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs) } } } @@ -225,7 +224,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID switch signal.State { case "lock": - if ext.State == Unlocked { + switch ext.State { + case Unlocked: if len(ext.Requirements) == 0 { ext.State = Locked new_owner := source @@ -241,19 +241,21 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.PendingOwner = &new_owner ext.PendingID = signal.ID() changes = changes.Add("locking") - for id, info := range(ext.Requirements) { - if info.State != Unlocked { + for id, state := range(ext.Requirements) { + if state != Unlocked { ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") } + lock_signal := NewLockSignal("lock") - info.State = Locking - info.MsgID = lock_signal.ID() - ext.Requirements[id] = info + ext.WaitInfos[lock_signal.Id] = node.QueueTimeout(id, lock_signal, 100*time.Millisecond) + ext.Requirements[id] = Locking + messages = messages.Add(ctx, id, node, nil, lock_signal) } } - } else { + default: messages = messages.Add(ctx, source, node, nil, NewErrorSignal(signal.ID(), "not_unlocked")) + ctx.Log.Logf("lockable", "Tried to lock %s while locked", node.ID) } case "unlock": if ext.State == Locked { @@ -271,14 +273,15 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ext.PendingOwner = nil ext.PendingID = signal.ID() changes = changes.Add("unlocking") - for id, info := range(ext.Requirements) { - if info.State != Locked { + for id, state := range(ext.Requirements) { + if state != Locked { ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") } + lock_signal := NewLockSignal("unlock") - info.State = Unlocking - info.MsgID = lock_signal.ID() - ext.Requirements[id] = info + ext.WaitInfos[lock_signal.Id] = node.QueueTimeout(id, lock_signal, 100*time.Millisecond) + ext.Requirements[id] = Unlocking + messages = messages.Add(ctx, id, node, nil, lock_signal) } } @@ -291,6 +294,24 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID return messages, changes } +func (ext *LockableExt) HandleTimeoutSignal(ctx *Context, node *Node, source NodeID, signal *TimeoutSignal) (Messages, Changes) { + var messages Messages = nil + var changes Changes = nil + + //TODO: Handle timeout errors better + wait_info, found := node.ProcessResponse(ext.WaitInfos, signal) + if found == true { + state, found := ext.Requirements[wait_info.NodeID] + if found == true { + ctx.Log.Logf("lockable", "%s timed out %s", wait_info.NodeID, ReqStateStrings[state]) + } else { + ctx.Log.Logf("lockable", "%s timed out", wait_info.NodeID) + } + } + + return messages, changes +} + // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) (Messages, Changes) { @@ -320,6 +341,8 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal messages, changes = ext.HandleErrorSignal(ctx, node, source, sig) case *SuccessSignal: messages, changes = ext.HandleSuccessSignal(ctx, node, source, sig) + case *TimeoutSignal: + messages, changes = ext.HandleTimeoutSignal(ctx, node, source, sig) default: } default: diff --git a/lockable_test.go b/lockable_test.go index 8db916b..7b9a938 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -51,11 +51,11 @@ func TestLink(t *testing.T) { _, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, link_signal.ID()) fatalErr(t, err) - info, exists := l1_lockable.Requirements[l2.ID] + state, exists := l1_lockable.Requirements[l2.ID] if exists == false { t.Fatal("l2 not in l1 requirements") - } else if info.State != Unlocked { - t.Fatalf("l2 in bad requirement state in l1: %+v", info.State) + } else if state != Unlocked { + t.Fatalf("l2 in bad requirement state in l1: %+v", state) } msgs = Messages{} @@ -68,8 +68,8 @@ func TestLink(t *testing.T) { fatalErr(t, err) } -func Test10KLink(t *testing.T) { - ctx := lockableTestContext(t, []string{"test"}) +func Test100Lock(t *testing.T) { + ctx := lockableTestContext(t, []string{"test", "lockable"}) l_pub, listener_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) @@ -87,7 +87,7 @@ func Test10KLink(t *testing.T) { return l } - reqs := make([]NodeID, 1000) + reqs := make([]NodeID, 100) for i := range(reqs) { new_lockable := NewLockable() reqs[i] = new_lockable.ID @@ -152,7 +152,7 @@ func TestLock(t *testing.T) { id_2, err := LockLockable(ctx, l1) fatalErr(t, err) - _, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_2) + _, _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*100, id_2) fatalErr(t, err) id_3, err := UnlockLockable(ctx, l0) diff --git a/serialize.go b/serialize.go index 14751ee..a93e737 100644 --- a/serialize.go +++ b/serialize.go @@ -256,7 +256,6 @@ var ( MapType = NewSerializedType("MAP") ReqStateType = NewSerializedType("REQ_STATE") - ReqInfoType = NewSerializedType("REQ_INFO") WaitInfoType = NewSerializedType("WAIT_INFO") SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") NodeStructType = NewSerializedType("NODE_STRUCT") diff --git a/serialize_test.go b/serialize_test.go index ed02eaa..bf5b052 100644 --- a/serialize_test.go +++ b/serialize_test.go @@ -10,9 +10,9 @@ import ( func TestSerializeTest(t *testing.T) { ctx := logTestContext(t, []string{"test", "serialize", "deserialize_types"}) testSerialize(t, ctx, map[string][]NodeID{"test_group": {RandID(), RandID(), RandID()}}) - testSerialize(t, ctx, map[NodeID]ReqInfo{ - RandID(): {}, - RandID(): {}, + testSerialize(t, ctx, map[NodeID]ReqState{ + RandID(): Locked, + RandID(): Unlocked, }) }