diff --git a/context.go b/context.go index 65d5b1b..a839c19 100644 --- a/context.go +++ b/context.go @@ -238,6 +238,9 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) { // Route a Signal to dest. Currently only local context routing is supported func (ctx *Context) Send(messages Messages) error { for _, msg := range(messages) { + if msg.Dest == ZeroID { + panic("Can't send to null ID") + } target, err := ctx.GetNode(msg.Dest) if err == nil { select { diff --git a/gql.go b/gql.go index e75cd2c..c6ffc3f 100644 --- a/gql.go +++ b/gql.go @@ -808,12 +808,17 @@ func NewGQLExtContext() *GQLExtContext { Fields: graphql.Fields{}, }) + subscription := graphql.NewObject(graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{}, + }) + context := GQLExtContext{ Schema: graphql.Schema{}, Types: []graphql.Type{}, Query: query, Mutation: mutation, - Subscription: nil, + Subscription: subscription, NodeTypes: map[NodeType]*graphql.Object{}, Interfaces: map[string]*Interface{}, Fields: map[string]Field{}, @@ -872,9 +877,9 @@ func NewGQLExtContext() *GQLExtContext { "requirements", LockableExtType, func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { - id_strs, ok := val.(map[NodeID]string) + id_strs, ok := val.(map[NodeID]ReqState) if ok == false { - return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val)) + return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState, %s", val, reflect.TypeOf(val)) } ids := []NodeID{} @@ -909,6 +914,20 @@ func NewGQLExtContext() *GQLExtContext { }, }) + context.Subscription.AddFieldConfig("Self", &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return p.Source, nil + }, + Subscribe: func(p graphql.ResolveParams) (interface{}, error) { + c := make(chan interface{}, 10) + for i := 0; i < 10; i++ { + c <- fmt.Sprintf("test %d", i) + } + return c, nil + }, + }) + context.Query.AddFieldConfig("Self", &graphql.Field{ Type: context.Interfaces["Node"].Interface, Resolve: func(p graphql.ResolveParams) (interface{}, error) { @@ -976,7 +995,6 @@ type GQLExt struct { resolver_response map[uuid.UUID]chan Signal `json:"-"` resolver_response_lock sync.RWMutex `json:"-"` - State string `json:"state"` TLSKey []byte `json:"tls_key"` TLSCert []byte `json:"tls_cert"` Listen string `json:"listen"` @@ -990,6 +1008,13 @@ func (ext *GQLExt) Field(name string) interface{} { }) } +func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal { + ext.resolver_response_lock.RLock() + response_chan, _ := ext.resolver_response[req_id] + ext.resolver_response_lock.RUnlock() + return response_chan +} + func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal { response_chan := make(chan Signal, 1) ext.resolver_response_lock.Lock() @@ -999,18 +1024,14 @@ func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal { } func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { - ext.resolver_response_lock.RLock() - response_chan, exists := ext.resolver_response[req_id] - ext.resolver_response_lock.RUnlock() + response_chan := ext.FindResponseChannel(req_id) - if exists == true { + if response_chan != nil { ext.resolver_response_lock.Lock() delete(ext.resolver_response, req_id) ext.resolver_response_lock.Unlock() - return response_chan - } else { - return nil } + return response_chan } func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { @@ -1033,7 +1054,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } } else if signal.Type() == ReadResultSignalType { sig := signal.(*ReadResultSignal) - response_chan := ext.FreeResponseChannel(sig.ReqID()) + response_chan := ext.FindResponseChannel(sig.ReqID()) if response_chan != nil { select { case response_chan <- sig: @@ -1044,46 +1065,13 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa } else { ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) } - } else if signal.Type() == GQLStateSignalType { - sig := signal.(*StringSignal) - ctx.Log.Logf("gql", "GQL_STATE_SIGNAL: %s - %+v", node.ID, sig.Str) - switch sig.Str { - case "start_server": - if ext.State == "stopped" { - err := ext.StartGQLServer(ctx, node) - if err == nil { - ext.State = "running" - node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) - } else { - ctx.Log.Logf("gql", "GQL_START_ERROR: %s", err) - } - } - case "stop_server": - if ext.State == "running" { - err := ext.StopGQLServer() - if err == nil { - ext.State = "stopped" - node.QueueSignal(time.Now(), NewStatusSignal("server_stopped", node.ID)) - } else { - ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err) - } - } - default: - ctx.Log.Logf("gql", "unknown gql state %s", sig.Str) - } } else if signal.Type() == StartSignalType { - ctx.Log.Logf("gql", "starting with state: %s", ext.State) - switch ext.State { - case "running": - err := ext.StartGQLServer(ctx, node) - if err == nil { - node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) - } else { - ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) - } - case "stopped": - default: - ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State) + ctx.Log.Logf("gql", "starting gql server %s", node.ID) + err := ext.StartGQLServer(ctx, node) + if err == nil { + node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) + } else { + ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) } } return messages @@ -1118,7 +1106,7 @@ func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error { return json.Unmarshal(data, &ext) } -func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { +func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*GQLExt, error) { if tls_cert == nil || tls_key == nil { ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { @@ -1159,7 +1147,6 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta tls_key = ssl_key_pem } return &GQLExt{ - State: state, Listen: listen, resolver_response: map[uuid.UUID]chan Signal{}, TLSCert: tls_cert, diff --git a/gql_test.go b/gql_test.go index accae7e..a995762 100644 --- a/gql_test.go +++ b/gql_test.go @@ -14,10 +14,12 @@ import ( "crypto/rand" "crypto/ed25519" "bytes" + "golang.org/x/net/websocket" + "github.com/google/uuid" ) func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"test", "policy", "pending"}) + ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"}) TestNodeType := NodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) @@ -44,7 +46,6 @@ func TestGQLServer(t *testing.T) { LockSignalType.String(): nil, StatusSignalType.String(): nil, ReadSignalType.String(): nil, - GQLStateSignalType.String(): nil, }, }) @@ -60,7 +61,7 @@ func TestGQLServer(t *testing.T) { }, }) - gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") + gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) listener_ext := NewListenerExt(10) @@ -80,11 +81,6 @@ func TestGQLServer(t *testing.T) { ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID) - msgs := Messages{} - msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID) - err = ctx.Send(msgs) - fatalErr(t, err) - _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "server_started" }) @@ -96,6 +92,7 @@ func TestGQLServer(t *testing.T) { client := &http.Client{Transport: skipVerifyTransport} port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port url := fmt.Sprintf("https://localhost:%d/gql", port) + ws_url := fmt.Sprintf("wss://127.0.0.1:%d/gqlws", port) req_1 := GQLPayload{ Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }", @@ -111,6 +108,12 @@ func TestGQLServer(t *testing.T) { }, } + auth_username := base64.StdEncoding.EncodeToString(n1.ID.Serialize()) + key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) + fatalErr(t, err) + auth_password := base64.StdEncoding.EncodeToString(key_bytes) + auth_b64 := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", auth_username, auth_password))) + SendGQL := func(payload GQLPayload) []byte { ser, err := json.MarshalIndent(&payload, "", " ") fatalErr(t, err) @@ -119,9 +122,7 @@ func TestGQLServer(t *testing.T) { req, err := http.NewRequest("GET", url, req_data) fatalErr(t, err) - key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) - fatalErr(t, err) - req.SetBasicAuth(base64.StdEncoding.EncodeToString(n1.ID.Serialize()), base64.StdEncoding.EncodeToString(key_bytes)) + req.SetBasicAuth(auth_username, auth_password) resp, err := client.Do(req) fatalErr(t, err) @@ -137,12 +138,70 @@ func TestGQLServer(t *testing.T) { resp_2 := SendGQL(req_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2) - msgs = Messages{} - msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID) + sub_1 := GQLPayload{ + Query: "subscription { Self }", + } + + SubGQL := func(payload GQLPayload) { + config, err := websocket.NewConfig(ws_url, url) + fatalErr(t, err) + config.Protocol = append(config.Protocol, "graphql-ws") + config.TlsConfig = &tls.Config{InsecureSkipVerify: true} + config.Header.Add("Authorization", fmt.Sprintf("Basic %s", auth_b64)) + + ws, err := websocket.DialConfig(config) + + fatalErr(t, err) + + init := GQLWSMsg{ + ID: uuid.New().String(), + Type: "connection_init", + } + + ser, err := json.Marshal(&init) + fatalErr(t, err) + + _, err = ws.Write(ser) + fatalErr(t, err) + + resp := make([]byte, 1024) + n, err := ws.Read(resp) + + var init_resp GQLWSMsg + err = json.Unmarshal(resp[:n], &init_resp) + fatalErr(t, err) + + if init_resp.Type != "connection_ack" { + t.Fatal("Didn't receive connection_ack") + } + + sub := GQLWSMsg{ + ID: uuid.New().String(), + Type: "subscribe", + Payload: sub_1, + } + + ser, err = json.Marshal(&sub) + fatalErr(t, err) + _, err = ws.Write(ser) + fatalErr(t, err) + + for i := 0; i < 10; i++ { + n, err = ws.Read(resp) + fatalErr(t, err) + + ctx.Log.Logf("test", "SUB_%d: %s", i, resp[:n]) + } + } + + SubGQL(sub_1) + + msgs := Messages{} + msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID) err = ctx.Send(msgs) fatalErr(t, err) _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "server_stopped" + return sig.Str == "stopped" }) fatalErr(t, err) } @@ -157,7 +216,7 @@ func TestGQLDB(t *testing.T) { ctx.Log.Logf("test", "U1_ID: %s", u1.ID) - gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "start") + gql_ext, err := NewGQLExt(ctx, ":0", nil, nil) fatalErr(t, err) listener_ext := NewListenerExt(10) gql := NewNode(ctx, nil, GQLNodeType, 10, nil, diff --git a/lockable.go b/lockable.go index afe958f..b51f31d 100644 --- a/lockable.go +++ b/lockable.go @@ -2,17 +2,21 @@ package graphvent import ( "encoding/binary" + "github.com/google/uuid" ) -type ReqState int +type ReqState byte const ( Unlocked = ReqState(0) Unlocking = ReqState(1) Locked = ReqState(2) Locking = ReqState(3) + AbortingLock = ReqState(4) ) type LockableExt struct{ + State ReqState + ReqID uuid.UUID Owner *NodeID PendingOwner *NodeID Requirements map[NodeID]ReqState @@ -37,7 +41,7 @@ func (ext *LockableExt) Type() ExtType { } func (ext *LockableExt) Serialize() ([]byte, error) { - ret := make([]byte, 8 + (16 * 2) + (17 * len(ext.Requirements))) + ret := make([]byte, 9 + (16 * 2) + (17 * len(ext.Requirements))) if ext.Owner != nil { bytes, err := ext.Owner.MarshalBinary() if err != nil { @@ -55,8 +59,8 @@ func (ext *LockableExt) Serialize() ([]byte, error) { } binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements))) - - cur := 40 + ret[40] = byte(ext.State) + cur := 41 for req, state := range(ext.Requirements) { bytes, err := req.MarshalBinary() if err != nil { @@ -105,6 +109,9 @@ func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error { num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8])) cur += 8 + ext.State = ReqState(data[cur]) + cur += 1 + if num_requirements != 0 { ext.Requirements = map[NodeID]ReqState{} } @@ -130,6 +137,7 @@ func NewLockableExt(requirements []NodeID) *LockableExt { } } return &LockableExt{ + State: Unlocked, Owner: nil, PendingOwner: nil, Requirements: reqs, @@ -137,162 +145,163 @@ func NewLockableExt(requirements []NodeID) *LockableExt { } // Send the signal to unlock a node from itself -func UnlockLockable(ctx *Context, node *Node) error { +func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), node.ID) - return ctx.Send(msgs) + signal := NewLockSignal("unlock") + msgs = msgs.Add(owner.ID, owner.Key, signal, target) + return signal.ID(), ctx.Send(msgs) } // Send the signal to lock a node from itself -func LockLockable(ctx *Context, node *Node) error { +func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) { msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewLockSignal("lock"), node.ID) - return ctx.Send(msgs) + signal := NewLockSignal("lock") + msgs = msgs.Add(owner.ID, owner.Key, signal, target) + return signal.ID(), ctx.Send(msgs) } -// Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { - state := signal.Str - log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) +func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, signal *ErrorSignal) Messages { + str := signal.Error + log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str) - messages := Messages{} - switch state { - case "unlock": - if ext.Owner == nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source) - } else if source != *ext.Owner { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source) - } else if ext.PendingOwner == nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source) - } else { - if len(ext.Requirements) == 0 { - ext.Owner = nil - ext.PendingOwner = nil - messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), source) - } else { - ext.PendingOwner = nil - for id, state := range(ext.Requirements) { - if state != Locked { - panic("NOT_LOCKED") - } + msgs := Messages {} + switch str { + case "not_unlocked": + if ext.State == Locking { + ext.State = AbortingLock + ext.Requirements[source] = Unlocked + for id, state := range(ext.Requirements) { + if state == Locked { ext.Requirements[id] = Unlocking - messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id) + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), id) } - if source != node.ID { - messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocking"), source) - } - } - } - case "unlocking": - if ext.Requirements != nil { - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != Unlocking { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } } + case "not_locked": + panic("RECEIVED not_locked, meaning a node thought it held a lock it didn't") + case "not_requirement": + } - case "unlocked": - if source == node.ID { - return nil - } - - if ext.Requirements != nil { - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != Unlocking { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) - } else { - ext.Requirements[source] = Unlocked + return msgs +} - if ext.PendingOwner == nil { - unlocked := 0 - for _, s := range(ext.Requirements) { - if s == Unlocked { - unlocked += 1 - } - } +// Handle a LockSignal and update the extensions owner/requirement states +func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { + state := signal.Str + log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, state) - if len(ext.Requirements) == unlocked { - previous_owner := *ext.Owner - ext.Owner = nil - messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) - } - } - } - } + msgs := Messages{} + switch state { case "locked": - if source == node.ID { - return nil - } - - if ext.Requirements != nil { - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != Locking { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) - } else { + state, found := ext.Requirements[source] + if found == false { + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + } else if state == Locking { + if ext.State == Locking { ext.Requirements[source] = Locked - - if ext.PendingOwner != nil { - locked := 0 - for _, s := range(ext.Requirements) { - if s == Locked { - locked += 1 - } + reqs := 0 + locked := 0 + for _, s := range(ext.Requirements) { + reqs += 1 + if s == Locked { + locked += 1 } + } - if len(ext.Requirements) == locked { - ext.Owner = ext.PendingOwner - messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) - } + if locked == reqs { + ext.State = Locked + ext.Owner = ext.PendingOwner + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) + } else { + log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs) } + } else if ext.State == AbortingLock { + ext.Requirements[source] = Unlocking + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), source) } } - case "locking": - if ext.Requirements != nil { - state, exists := ext.Requirements[source] - if exists == false { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) - } else if state != Locking { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) + case "unlocked": + state, found := ext.Requirements[source] + if found == false { + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) + } else if state == Unlocking { + ext.Requirements[source] = Unlocked + reqs := 0 + unlocked := 0 + for _, s := range(ext.Requirements) { + reqs += 1 + if s == Unlocked { + unlocked += 1 + } } - } + if unlocked == reqs { + old_state := ext.State + ext.State = Unlocked + if old_state == Unlocking { + ext.Owner = ext.PendingOwner + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner) + } else if old_state == AbortingLock { + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(ext.ReqID, "not_unlocked"), *ext.PendingOwner) + ext.PendingOwner = ext.Owner + } + } else { + log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs) + } + } case "lock": - if ext.Owner != nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) - } else if ext.PendingOwner != nil { - messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source) - } else { - owner := source + if ext.State == Unlocked { if len(ext.Requirements) == 0 { - ext.Owner = &owner - ext.PendingOwner = ext.Owner - messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), source) + ext.State = Locked + new_owner := source + ext.PendingOwner = &new_owner + ext.Owner = &new_owner + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), new_owner) } else { - ext.PendingOwner = &owner + ext.State = Locking + ext.ReqID = signal.ID() + new_owner := source + ext.PendingOwner = &new_owner for id, state := range(ext.Requirements) { - log.Logf("lockable_detail", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id) if state != Unlocked { - panic("NOT_UNLOCKED") + log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING") } ext.Requirements[id] = Locking - messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id) + lock_signal := NewLockSignal("lock") + msgs = msgs.Add(node.ID, node.Key, lock_signal, id) } - log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %d requirements", node.ID, len(ext.Requirements)) - if source != node.ID { - messages = messages.Add(node.ID, node.Key, NewLockSignal("locking"), source) + } + } else { + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source) + } + case "unlock": + if ext.State == Locked { + if len(ext.Requirements) == 0 { + ext.State = Unlocked + new_owner := source + ext.PendingOwner = nil + ext.Owner = nil + msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), new_owner) + } else if source == *ext.Owner { + ext.State = Unlocking + ext.ReqID = signal.ID() + ext.PendingOwner = nil + for id, state := range(ext.Requirements) { + if state != Locked { + log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING") + } + ext.Requirements[id] = Unlocking + lock_signal := NewLockSignal("unlock") + msgs = msgs.Add(node.ID, node.Key, lock_signal, id) } } + } else { + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source) } default: log.Logf("lockable", "LOCK_ERR: unkown state %s", state) } - return messages + return msgs } // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes @@ -314,6 +323,8 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal switch signal.Type() { case LockSignalType: messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal)) + case ErrorSignalType: + messages = ext.HandleErrorSignal(ctx.Log, node, source, signal.(*ErrorSignal)) default: } default: diff --git a/lockable_test.go b/lockable_test.go index d68c4f6..9076099 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -100,7 +100,7 @@ func TestLink10K(t *testing.T) { ) ctx.Log.Logf("test", "CREATED_LISTENER") - err = LockLockable(ctx, node) + _, err = LockLockable(ctx, node, node.ID) fatalErr(t, err) _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { @@ -118,7 +118,7 @@ func TestLink10K(t *testing.T) { } func TestLock(t *testing.T) { - ctx := lockableTestContext(t, []string{"lockable", "policy"}) + ctx := lockableTestContext(t, []string{"lockable"}) policy := NewAllNodesPolicy(nil) @@ -138,26 +138,40 @@ func TestLock(t *testing.T) { l3, _ := NewLockable(nil) l4, _ := NewLockable(nil) l5, _ := NewLockable(nil) - NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) + l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) locked := func(sig *StringSignal) bool { return sig.Str == "locked" } - err := LockLockable(ctx, l1) + unlocked := func(sig *StringSignal) bool { + return sig.Str == "unlocked" + } + + _, err := LockLockable(ctx, l0, l5.ID) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + + id, err := LockLockable(ctx, l1, l1.ID) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool { + return sig.Error == "not_unlocked" && sig.ReqID() == id + }) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + + _, err = UnlockLockable(ctx, l0, l5.ID) fatalErr(t, err) - _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, unlocked) fatalErr(t, err) - err = UnlockLockable(ctx, l1) + _, err = LockLockable(ctx, l1, l1.ID) fatalErr(t, err) + for i := 0; i < 4; i++ { + _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, func(sig *StringSignal) bool { + return sig.Str == "locked" + }) + fatalErr(t, err) + } } diff --git a/node.go b/node.go index 6eb48a7..28965ee 100644 --- a/node.go +++ b/node.go @@ -254,8 +254,9 @@ func nodeLoop(ctx *Context, node *Node) error { pends, resp := node.Allows(princ_id, msg.Signal.Permission()) if resp == Deny { ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission()) + ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source) msgs := Messages{} - msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), source) + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source) ctx.Send(msgs) continue } else if resp == Pending { @@ -369,6 +370,7 @@ func nodeLoop(ctx *Context, node *Node) error { result := node.ReadFields(read_signal.Extensions) msgs := Messages{} msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source) + msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(read_signal.ID(), "read_done"), source) ctx.Send(msgs) } } diff --git a/signal.go b/signal.go index 7e023ca..022965c 100644 --- a/signal.go +++ b/signal.go @@ -28,7 +28,6 @@ const ( LinkStartSignalType = SignalType("LINK_START") ECDHSignalType = SignalType("ECDH") ECDHProxySignalType = SignalType("ECDH_PROXY") - GQLStateSignalType = SignalType("GQL_STATE") ACLTimeoutSignalType = SignalType("ACL_TIMEOUT") Up SignalDirection = iota