diff --git a/context.go b/context.go index e48e6d2..2ab003f 100644 --- a/context.go +++ b/context.go @@ -1051,12 +1051,84 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(RandID()), NodeIDType, SerializeArray, DeserializeArray[NodeID](ctx)) + node_id_type := reflect.TypeOf(RandID()) + err = ctx.RegisterType(node_id_type, NodeIDType, + func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + type_stack := []SerializedType{ctx_type} + if value == nil { + return SerializedValue{ + type_stack, + nil, + }, nil + } else { + data, err := value.Interface().(NodeID).MarshalBinary() + if err != nil { + return SerializedValue{}, err + } + return SerializedValue{ + type_stack, + data, + }, nil + } + }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ + if value.Data == nil { + return node_id_type, nil, value, nil + } else { + id_data, value, err := value.PopData(16) + if err != nil { + return nil, nil, value, err + } + + id, err := IDFromBytes(id_data) + if err != nil { + return nil, nil, value, err + } + + id_value := reflect.ValueOf(id) + return node_id_type, &id_value, value, nil + } + }) if err != nil { return nil, err } - err = ctx.RegisterType(reflect.TypeOf(uuid.New()), UUIDType, SerializeArray, DeserializeArray[uuid.UUID](ctx)) + uuid_type := reflect.TypeOf(uuid.UUID{}) + err = ctx.RegisterType(uuid_type, UUIDType, + func(ctx *Context, ctx_type SerializedType, reflect_type reflect.Type, value *reflect.Value)(SerializedValue,error){ + type_stack := []SerializedType{ctx_type} + if value == nil { + return SerializedValue{ + type_stack, + nil, + }, nil + } else { + data, err := value.Interface().(uuid.UUID).MarshalBinary() + if err != nil { + return SerializedValue{}, err + } + return SerializedValue{ + type_stack, + data, + }, nil + } + }, func(ctx *Context, value SerializedValue)(reflect.Type, *reflect.Value, SerializedValue, error){ + if value.Data == nil { + return uuid_type, nil, value, nil + } else { + id_data, value, err := value.PopData(16) + if err != nil { + return nil, nil, value, err + } + + id, err := uuid.FromBytes(id_data) + if err != nil { + return nil, nil, value, err + } + + id_value := reflect.ValueOf(id) + return uuid_type, &id_value, value, nil + } + }) if err != nil { return nil, err } @@ -1154,6 +1226,12 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + req_info_type := reflect.TypeOf(ReqInfo{}) + err = ctx.RegisterType(req_info_type, ReqInfoType, SerializeStruct(ctx, req_info_type), DeserializeStruct(ctx, req_info_type)) + if err != nil { + return nil, err + } + err = ctx.RegisterExtension(reflect.TypeOf((*LockableExt)(nil)), LockableExtType, nil) if err != nil { return nil, err @@ -1220,6 +1298,11 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterSignal(reflect.TypeOf(TimeoutSignal{}), TimeoutSignalType) + if err != nil { + return nil, err + } + err = ctx.RegisterSignal(reflect.TypeOf(LinkSignal{}), LinkSignalType) if err != nil { return nil, err diff --git a/gql.go b/gql.go index cdccfcc..38b4df9 100644 --- a/gql.go +++ b/gql.go @@ -1044,7 +1044,7 @@ func NewGQLExtContext() *GQLExtContext { "requirements", LockableExtType, func(p graphql.ResolveParams, ctx *ResolveContext, value reflect.Value) ([]NodeID, error) { - id_strs, ok := value.Interface().(map[NodeID]ReqState) + id_strs, ok := value.Interface().(map[NodeID]ReqInfo) if ok == false { return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState", value.Type()) } diff --git a/lockable.go b/lockable.go index 0076c8f..82a6435 100644 --- a/lockable.go +++ b/lockable.go @@ -13,12 +13,18 @@ const ( AbortingLock = ReqState(4) ) +type ReqInfo struct { + State ReqState `gv:"state"` + MsgID uuid.UUID `gv:"msg_id"` +} + type LockableExt struct{ State ReqState `gv:"state"` ReqID *uuid.UUID `gv:"req_id"` Owner *NodeID `gv:"owner"` PendingOwner *NodeID `gv:"pending_owner"` - Requirements map[NodeID]ReqState `gv:"requirements"` + PendingID uuid.UUID `gv:"pending_id"` + Requirements map[NodeID]ReqInfo `gv:"requirements"` } func (ext *LockableExt) Type() ExtType { @@ -26,11 +32,14 @@ func (ext *LockableExt) Type() ExtType { } func NewLockableExt(requirements []NodeID) *LockableExt { - var reqs map[NodeID]ReqState = nil + var reqs map[NodeID]ReqInfo = nil if requirements != nil { - reqs = map[NodeID]ReqState{} + reqs = map[NodeID]ReqInfo{} for _, id := range(requirements) { - reqs[id] = Unlocked + reqs[id] = ReqInfo{ + Unlocked, + uuid.UUID{}, + } } } return &LockableExt{ @@ -41,19 +50,17 @@ func NewLockableExt(requirements []NodeID) *LockableExt { } } -// Send the signal to unlock a node from itself -func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { +func UnlockLockable(ctx *Context, node *Node) (uuid.UUID, error) { msgs := Messages{} signal := NewLockSignal("unlock") - msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target) + msgs = msgs.Add(ctx, node.ID, node.Key, signal, node.ID) return signal.ID(), ctx.Send(msgs) } -// Send the signal to lock a node from itself -func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { +func LockLockable(ctx *Context, node *Node) (uuid.UUID, error) { msgs := Messages{} signal := NewLockSignal("lock") - msgs = msgs.Add(ctx, owner.ID, owner.Key, signal, target) + msgs = msgs.Add(ctx, node.ID, node.Key, signal, node.ID) return signal.ID(), ctx.Send(msgs) } @@ -66,11 +73,20 @@ func (ext *LockableExt) HandleErrorSignal(ctx *Context, node *Node, source NodeI case "not_unlocked": if ext.State == Locking { ext.State = AbortingLock - ext.Requirements[source] = Unlocked - for id, state := range(ext.Requirements) { - if state == Locked { - ext.Requirements[id] = Unlocking - msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), id) + req_info := ext.Requirements[source] + req_info.State = Unlocked + ext.Requirements[source] = req_info + for id, info := range(ext.Requirements) { + if info.State == Locked { + lock_signal := NewLockSignal("unlock") + + req_info := ext.Requirements[id] + req_info.State = Unlocking + req_info.MsgID = lock_signal.ID() + ext.Requirements[id] = req_info + ctx.Log.Logf("lockable", "SENT_ABORT_UNLOCK: %s to %s", lock_signal.ID(), id) + + msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id) } } } @@ -92,9 +108,12 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_requirement"), source) } else { if ext.Requirements == nil { - ext.Requirements = map[NodeID]ReqState{} + ext.Requirements = map[NodeID]ReqInfo{} + } + ext.Requirements[signal.NodeID] = ReqInfo{ + Unlocked, + uuid.UUID{}, } - ext.Requirements[signal.NodeID] = Unlocked msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "req_added"), source) } case "remove": @@ -114,51 +133,58 @@ func (ext *LockableExt) HandleLinkSignal(ctx *Context, node *Node, source NodeID return msgs } -// Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) Messages { - ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal.State) +func (ext *LockableExt) HandleSuccessSignal(ctx *Context, node *Node, source NodeID, signal *SuccessSignal) Messages { + ctx.Log.Logf("lockable", "SUCCESS_SIGNAL: %+v", signal) msgs := Messages{} - switch signal.State { - case "locked": - state, found := ext.Requirements[source] - if found == false && source != node.ID { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "got 'locked' from non-requirement"), source) - } else if state == Locking { + + info, found := ext.Requirements[source] + ctx.Log.Logf("lockable", "State: %+v", ext.State) + if found == false { + ctx.Log.Logf("lockable", "Got success from non-requirement %s", source) + } else if source == node.ID { + // Do nothing with it + } else if info.MsgID != signal.ReqID { + ctx.Log.Logf("lockable", "Got success for wrong signal for %s: %s, expecting %s", source, signal.ReqID, info.MsgID) + } else { + if info.State == Locking { if ext.State == Locking { - ext.Requirements[source] = Locked + info.State = Locked + info.MsgID = uuid.UUID{} + ext.Requirements[source] = info reqs := 0 locked := 0 for _, s := range(ext.Requirements) { reqs += 1 - if s == Locked { + if s.State == Locked { locked += 1 } } if locked == reqs { + ctx.Log.Logf("lockable", "WHOLE LOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) ext.State = Locked ext.Owner = ext.PendingOwner - msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewSuccessSignal(ext.PendingID), *ext.Owner) } else { ctx.Log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) } } else if ext.State == AbortingLock { - ext.Requirements[source] = Unlocking - msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlock"), source) + lock_signal := NewLockSignal("unlock") + info.State = Unlocking + info.MsgID = lock_signal.ID() + ext.Requirements[source] = info + msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, source) } - } - case "unlocked": - state, found := ext.Requirements[source] - if found == false { - msgs = msgs.Add(ctx, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state == Unlocking { - ext.Requirements[source] = Unlocked + } else if info.State == Unlocking { + info.State = Unlocked + info.MsgID = uuid.UUID{} + ext.Requirements[source] = info reqs := 0 unlocked := 0 for _, s := range(ext.Requirements) { reqs += 1 - if s == Unlocked { + if s.State == Unlocked { unlocked += 1 } } @@ -166,10 +192,12 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID if unlocked == reqs { old_state := ext.State ext.State = Unlocked + ctx.Log.Logf("lockable", "WHOLE UNLOCK: %s - %s - %+v", node.ID, ext.PendingID, ext.PendingOwner) if old_state == Unlocking { + previous_owner := *ext.Owner ext.Owner = ext.PendingOwner ext.ReqID = nil - msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewSuccessSignal(ext.PendingID), previous_owner) } else if old_state == AbortingLock { msgs = msgs.Add(ctx ,node.ID, node.Key, NewErrorSignal(*ext.ReqID, "not_unlocked"), *ext.PendingOwner) ext.PendingOwner = ext.Owner @@ -178,6 +206,17 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID ctx.Log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs) } } + } + + return msgs +} + +// Handle a LockSignal and update the extensions owner/requirement states +func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID, signal *LockSignal) Messages { + ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal.State) + + msgs := Messages{} + switch signal.State { case "lock": if ext.State == Unlocked { if len(ext.Requirements) == 0 { @@ -185,19 +224,22 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = &new_owner ext.Owner = &new_owner - msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("locked"), new_owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), new_owner) } else { ext.State = Locking id := signal.ID() ext.ReqID = &id new_owner := source ext.PendingOwner = &new_owner - for id, state := range(ext.Requirements) { - if state != Unlocked { + ext.PendingID = signal.ID() + for id, info := range(ext.Requirements) { + if info.State != Unlocked { ctx.Log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") } - ext.Requirements[id] = Locking lock_signal := NewLockSignal("lock") + info.State = Locking + info.MsgID = lock_signal.ID() + ext.Requirements[id] = info msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id) } } @@ -211,18 +253,21 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, node *Node, source NodeID new_owner := source ext.PendingOwner = nil ext.Owner = nil - msgs = msgs.Add(ctx, node.ID, node.Key, NewLockSignal("unlocked"), new_owner) + msgs = msgs.Add(ctx, node.ID, node.Key, NewSuccessSignal(signal.ID()), new_owner) } else if source == *ext.Owner { ext.State = Unlocking id := signal.ID() ext.ReqID = &id ext.PendingOwner = nil - for id, state := range(ext.Requirements) { - if state != Locked { + ext.PendingID = signal.ID() + for id, info := range(ext.Requirements) { + if info.State != Locked { ctx.Log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") } - ext.Requirements[id] = Unlocking lock_signal := NewLockSignal("unlock") + info.State = Unlocking + info.MsgID = lock_signal.ID() + ext.Requirements[id] = info msgs = msgs.Add(ctx, node.ID, node.Key, lock_signal, id) } } @@ -258,6 +303,8 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal messages = ext.HandleLockSignal(ctx, node, source, sig) case *ErrorSignal: messages = ext.HandleErrorSignal(ctx, node, source, sig) + case *SuccessSignal: + messages = ext.HandleSuccessSignal(ctx, node, source, sig) default: } default: diff --git a/lockable_test.go b/lockable_test.go index 0f64d12..9974d69 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -62,7 +62,7 @@ func TestLink(t *testing.T) { fatalErr(t, err) } -func Test10KLink(t *testing.T) { +func Test1KLink(t *testing.T) { ctx := lockableTestContext(t, []string{"test"}) l_pub, listener_key, err := ed25519.GenerateKey(rand.Reader) @@ -91,6 +91,7 @@ func Test10KLink(t *testing.T) { l_policy := NewAllNodesPolicy(Tree{ SerializedType(LockSignalType): nil, }) + listener := NewListenerExt(100000) node, err := NewNode(ctx, listener_key, TestLockableType, 10000, []Policy{l_policy}, listener, @@ -99,25 +100,17 @@ func Test10KLink(t *testing.T) { fatalErr(t, err) ctx.Log.Logf("test", "CREATED_LISTENER") - _, err = LockLockable(ctx, node, node.ID) + lock_id, err := LockLockable(ctx, node) fatalErr(t, err) - _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, func(sig *LockSignal) bool { - return sig.State == "locked" - }) + _, err = WaitForResponse(listener.Chan, time.Second*20, lock_id) fatalErr(t, err) - for _, _ = range(reqs) { - _, err := WaitForSignal(listener.Chan, time.Millisecond*100, func(sig *LockSignal) bool { - return sig.State == "locked" - }) - fatalErr(t, err) - } ctx.Log.Logf("test", "LOCKED_10K") } func TestLock(t *testing.T) { - ctx := lockableTestContext(t, []string{"lockable"}) + ctx := lockableTestContext(t, []string{"test", "lockable"}) policy := NewAllNodesPolicy(nil) @@ -135,38 +128,35 @@ func TestLock(t *testing.T) { l3, _ := NewLockable(nil) l4, _ := NewLockable(nil) l5, _ := NewLockable(nil) - l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) + l0, l0_listener := NewLockable([]NodeID{l5.ID}) l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) - locked := func(sig *LockSignal) bool { - return sig.State == "locked" - } - - unlocked := func(sig *LockSignal) bool { - return sig.State == "unlocked" - } + ctx.Log.Logf("test", "l0: %s", l0.ID) + ctx.Log.Logf("test", "l1: %s", l1.ID) + ctx.Log.Logf("test", "l2: %s", l2.ID) + ctx.Log.Logf("test", "l3: %s", l3.ID) + ctx.Log.Logf("test", "l4: %s", l4.ID) + ctx.Log.Logf("test", "l5: %s", l5.ID) - _, err := LockLockable(ctx, l0, l5.ID) + id_1, err := LockLockable(ctx, l0) + ctx.Log.Logf("test", "ID_1: %s", id_1) fatalErr(t, err) - _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, locked) + _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_1) fatalErr(t, err) - id, err := LockLockable(ctx, l1, l1.ID) + id_2, err := LockLockable(ctx, l1) fatalErr(t, err) - _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id) + _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_2) fatalErr(t, err) - _, err = UnlockLockable(ctx, l0, l5.ID) + id_3, err := UnlockLockable(ctx, l0) fatalErr(t, err) - _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, unlocked) + _, err = WaitForResponse(l0_listener.Chan, time.Millisecond*10, id_3) fatalErr(t, err) - _, err = LockLockable(ctx, l1, l1.ID) + id_4, err := LockLockable(ctx, l1) + fatalErr(t, err) + + _, err = WaitForResponse(l1_listener.Chan, time.Millisecond*10, id_4) fatalErr(t, err) - for i := 0; i < 4; i++ { - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, func(sig *LockSignal) bool { - return sig.State == "locked" - }) - fatalErr(t, err) - } } diff --git a/policy.go b/policy.go index 1b31ed8..4d1e616 100644 --- a/policy.go +++ b/policy.go @@ -248,6 +248,6 @@ type AllNodesPolicy struct { } var DefaultPolicy = NewAllNodesPolicy(Tree{ - ResultType: nil, + ResponseType: nil, StatusType: nil, }) diff --git a/serialize.go b/serialize.go index 4112469..ab765ee 100644 --- a/serialize.go +++ b/serialize.go @@ -93,6 +93,7 @@ var ( StatusSignalType = NewSignalType("STATUS") LinkSignalType = NewSignalType("LINK") LockSignalType = NewSignalType("LOCK") + TimeoutSignalType = NewSignalType("TIMEOUT") ReadSignalType = NewSignalType("READ") ACLTimeoutSignalType = NewSignalType("ACL_TIMEOUT") ErrorSignalType = NewSignalType("ERROR") @@ -127,6 +128,7 @@ var ( MapType = NewSerializedType("MAP") ReqStateType = NewSerializedType("REQ_STATE") + ReqInfoType = NewSerializedType("REQ_INFO") SignalDirectionType = NewSerializedType("SIGNAL_DIRECTION") NodeStructType = NewSerializedType("NODE_STRUCT") QueuedSignalType = NewSerializedType("QUEUED_SIGNAL") @@ -141,7 +143,7 @@ var ( PendingACLType = NewSerializedType("PENDING_ACL") PendingSignalType = NewSerializedType("PENDING_SIGNAL") TimeType = NewSerializedType("TIME") - ResultType = NewSerializedType("RESULT") + ResponseType = NewSerializedType("RESPONSE") StatusType = NewSerializedType("STATUS") TreeType = NewSerializedType("TREE") SerializedTypeSerialized = NewSerializedType("SERIALIZED_TYPE") @@ -154,8 +156,6 @@ func SerializeArray(ctx *Context, ctx_type SerializedType, reflect_type reflect. type_stack, nil, }, nil - } else if value.IsZero() { - return SerializedValue{}, fmt.Errorf("don't know what zero array means...") } else { var element SerializedValue var err error diff --git a/signal.go b/signal.go index 17bf402..f848457 100644 --- a/signal.go +++ b/signal.go @@ -186,7 +186,7 @@ type SuccessSignal struct { } func (signal SuccessSignal) Permission() Tree { return Tree{ - ResultType: { + ResponseType: { SerializedType(SuccessSignalType): nil, }, } @@ -206,7 +206,7 @@ func (signal ErrorSignal) String() string { } func (signal ErrorSignal) Permission() Tree { return Tree{ - ResultType: { + ResponseType: { SerializedType(ErrorSignalType): nil, }, } @@ -336,7 +336,7 @@ type ReadResultSignal struct { } func (signal ReadResultSignal) Permission() Tree { return Tree{ - ResultType: { + ResponseType: { SerializedType(ReadResultSignalType): nil, }, }