Fixed lockable to properly fail

gql_cataclysm
noah metz 2023-08-15 18:23:06 -06:00
parent b446c9078a
commit 98c0b7e807
7 changed files with 277 additions and 202 deletions

@ -238,6 +238,9 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
// Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(messages Messages) error {
for _, msg := range(messages) {
if msg.Dest == ZeroID {
panic("Can't send to null ID")
}
target, err := ctx.GetNode(msg.Dest)
if err == nil {
select {

@ -808,12 +808,17 @@ func NewGQLExtContext() *GQLExtContext {
Fields: graphql.Fields{},
})
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
context := GQLExtContext{
Schema: graphql.Schema{},
Types: []graphql.Type{},
Query: query,
Mutation: mutation,
Subscription: nil,
Subscription: subscription,
NodeTypes: map[NodeType]*graphql.Object{},
Interfaces: map[string]*Interface{},
Fields: map[string]Field{},
@ -872,9 +877,9 @@ func NewGQLExtContext() *GQLExtContext {
"requirements",
LockableExtType,
func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) {
id_strs, ok := val.(map[NodeID]string)
id_strs, ok := val.(map[NodeID]ReqState)
if ok == false {
return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val))
return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState, %s", val, reflect.TypeOf(val))
}
ids := []NodeID{}
@ -909,6 +914,20 @@ func NewGQLExtContext() *GQLExtContext {
},
})
context.Subscription.AddFieldConfig("Self", &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return p.Source, nil
},
Subscribe: func(p graphql.ResolveParams) (interface{}, error) {
c := make(chan interface{}, 10)
for i := 0; i < 10; i++ {
c <- fmt.Sprintf("test %d", i)
}
return c, nil
},
})
context.Query.AddFieldConfig("Self", &graphql.Field{
Type: context.Interfaces["Node"].Interface,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
@ -976,7 +995,6 @@ type GQLExt struct {
resolver_response map[uuid.UUID]chan Signal `json:"-"`
resolver_response_lock sync.RWMutex `json:"-"`
State string `json:"state"`
TLSKey []byte `json:"tls_key"`
TLSCert []byte `json:"tls_cert"`
Listen string `json:"listen"`
@ -990,6 +1008,13 @@ func (ext *GQLExt) Field(name string) interface{} {
})
}
func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock()
response_chan, _ := ext.resolver_response[req_id]
ext.resolver_response_lock.RUnlock()
return response_chan
}
func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal {
response_chan := make(chan Signal, 1)
ext.resolver_response_lock.Lock()
@ -999,18 +1024,14 @@ func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal {
}
func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock()
response_chan, exists := ext.resolver_response[req_id]
ext.resolver_response_lock.RUnlock()
response_chan := ext.FindResponseChannel(req_id)
if exists == true {
if response_chan != nil {
ext.resolver_response_lock.Lock()
delete(ext.resolver_response, req_id)
ext.resolver_response_lock.Unlock()
return response_chan
} else {
return nil
}
return response_chan
}
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
@ -1033,7 +1054,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
}
} else if signal.Type() == ReadResultSignalType {
sig := signal.(*ReadResultSignal)
response_chan := ext.FreeResponseChannel(sig.ReqID())
response_chan := ext.FindResponseChannel(sig.ReqID())
if response_chan != nil {
select {
case response_chan <- sig:
@ -1044,46 +1065,13 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
}
} else if signal.Type() == GQLStateSignalType {
sig := signal.(*StringSignal)
ctx.Log.Logf("gql", "GQL_STATE_SIGNAL: %s - %+v", node.ID, sig.Str)
switch sig.Str {
case "start_server":
if ext.State == "stopped" {
err := ext.StartGQLServer(ctx, node)
if err == nil {
ext.State = "running"
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_START_ERROR: %s", err)
}
}
case "stop_server":
if ext.State == "running" {
err := ext.StopGQLServer()
if err == nil {
ext.State = "stopped"
node.QueueSignal(time.Now(), NewStatusSignal("server_stopped", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err)
}
}
default:
ctx.Log.Logf("gql", "unknown gql state %s", sig.Str)
}
} else if signal.Type() == StartSignalType {
ctx.Log.Logf("gql", "starting with state: %s", ext.State)
switch ext.State {
case "running":
err := ext.StartGQLServer(ctx, node)
if err == nil {
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
}
case "stopped":
default:
ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State)
ctx.Log.Logf("gql", "starting gql server %s", node.ID)
err := ext.StartGQLServer(ctx, node)
if err == nil {
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
}
}
return messages
@ -1118,7 +1106,7 @@ func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, &ext)
}
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) {
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*GQLExt, error) {
if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
@ -1159,7 +1147,6 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta
tls_key = ssl_key_pem
}
return &GQLExt{
State: state,
Listen: listen,
resolver_response: map[uuid.UUID]chan Signal{},
TLSCert: tls_cert,

@ -14,10 +14,12 @@ import (
"crypto/rand"
"crypto/ed25519"
"bytes"
"golang.org/x/net/websocket"
"github.com/google/uuid"
)
func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "policy", "pending"})
ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"})
TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
@ -44,7 +46,6 @@ func TestGQLServer(t *testing.T) {
LockSignalType.String(): nil,
StatusSignalType.String(): nil,
ReadSignalType.String(): nil,
GQLStateSignalType.String(): nil,
},
})
@ -60,7 +61,7 @@ func TestGQLServer(t *testing.T) {
},
})
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped")
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err)
listener_ext := NewListenerExt(10)
@ -80,11 +81,6 @@ func TestGQLServer(t *testing.T) {
ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID)
msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_started"
})
@ -96,6 +92,7 @@ func TestGQLServer(t *testing.T) {
client := &http.Client{Transport: skipVerifyTransport}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
url := fmt.Sprintf("https://localhost:%d/gql", port)
ws_url := fmt.Sprintf("wss://127.0.0.1:%d/gqlws", port)
req_1 := GQLPayload{
Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }",
@ -111,6 +108,12 @@ func TestGQLServer(t *testing.T) {
},
}
auth_username := base64.StdEncoding.EncodeToString(n1.ID.Serialize())
key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key)
fatalErr(t, err)
auth_password := base64.StdEncoding.EncodeToString(key_bytes)
auth_b64 := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", auth_username, auth_password)))
SendGQL := func(payload GQLPayload) []byte {
ser, err := json.MarshalIndent(&payload, "", " ")
fatalErr(t, err)
@ -119,9 +122,7 @@ func TestGQLServer(t *testing.T) {
req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err)
key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key)
fatalErr(t, err)
req.SetBasicAuth(base64.StdEncoding.EncodeToString(n1.ID.Serialize()), base64.StdEncoding.EncodeToString(key_bytes))
req.SetBasicAuth(auth_username, auth_password)
resp, err := client.Do(req)
fatalErr(t, err)
@ -137,12 +138,70 @@ func TestGQLServer(t *testing.T) {
resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
msgs = Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID)
sub_1 := GQLPayload{
Query: "subscription { Self }",
}
SubGQL := func(payload GQLPayload) {
config, err := websocket.NewConfig(ws_url, url)
fatalErr(t, err)
config.Protocol = append(config.Protocol, "graphql-ws")
config.TlsConfig = &tls.Config{InsecureSkipVerify: true}
config.Header.Add("Authorization", fmt.Sprintf("Basic %s", auth_b64))
ws, err := websocket.DialConfig(config)
fatalErr(t, err)
init := GQLWSMsg{
ID: uuid.New().String(),
Type: "connection_init",
}
ser, err := json.Marshal(&init)
fatalErr(t, err)
_, err = ws.Write(ser)
fatalErr(t, err)
resp := make([]byte, 1024)
n, err := ws.Read(resp)
var init_resp GQLWSMsg
err = json.Unmarshal(resp[:n], &init_resp)
fatalErr(t, err)
if init_resp.Type != "connection_ack" {
t.Fatal("Didn't receive connection_ack")
}
sub := GQLWSMsg{
ID: uuid.New().String(),
Type: "subscribe",
Payload: sub_1,
}
ser, err = json.Marshal(&sub)
fatalErr(t, err)
_, err = ws.Write(ser)
fatalErr(t, err)
for i := 0; i < 10; i++ {
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB_%d: %s", i, resp[:n])
}
}
SubGQL(sub_1)
msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_stopped"
return sig.Str == "stopped"
})
fatalErr(t, err)
}
@ -157,7 +216,7 @@ func TestGQLDB(t *testing.T) {
ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "start")
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err)
listener_ext := NewListenerExt(10)
gql := NewNode(ctx, nil, GQLNodeType, 10, nil,

@ -2,17 +2,21 @@ package graphvent
import (
"encoding/binary"
"github.com/google/uuid"
)
type ReqState int
type ReqState byte
const (
Unlocked = ReqState(0)
Unlocking = ReqState(1)
Locked = ReqState(2)
Locking = ReqState(3)
AbortingLock = ReqState(4)
)
type LockableExt struct{
State ReqState
ReqID uuid.UUID
Owner *NodeID
PendingOwner *NodeID
Requirements map[NodeID]ReqState
@ -37,7 +41,7 @@ func (ext *LockableExt) Type() ExtType {
}
func (ext *LockableExt) Serialize() ([]byte, error) {
ret := make([]byte, 8 + (16 * 2) + (17 * len(ext.Requirements)))
ret := make([]byte, 9 + (16 * 2) + (17 * len(ext.Requirements)))
if ext.Owner != nil {
bytes, err := ext.Owner.MarshalBinary()
if err != nil {
@ -55,8 +59,8 @@ func (ext *LockableExt) Serialize() ([]byte, error) {
}
binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements)))
cur := 40
ret[40] = byte(ext.State)
cur := 41
for req, state := range(ext.Requirements) {
bytes, err := req.MarshalBinary()
if err != nil {
@ -105,6 +109,9 @@ func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error {
num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8]))
cur += 8
ext.State = ReqState(data[cur])
cur += 1
if num_requirements != 0 {
ext.Requirements = map[NodeID]ReqState{}
}
@ -130,6 +137,7 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
}
}
return &LockableExt{
State: Unlocked,
Owner: nil,
PendingOwner: nil,
Requirements: reqs,
@ -137,162 +145,163 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
}
// Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error {
func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), node.ID)
return ctx.Send(msgs)
signal := NewLockSignal("unlock")
msgs = msgs.Add(owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs)
}
// Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error {
func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("lock"), node.ID)
return ctx.Send(msgs)
signal := NewLockSignal("lock")
msgs = msgs.Add(owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs)
}
// Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages {
state := signal.Str
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, signal *ErrorSignal) Messages {
str := signal.Error
log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str)
messages := Messages{}
switch state {
case "unlock":
if ext.Owner == nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source)
} else if source != *ext.Owner {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source)
} else if ext.PendingOwner == nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source)
} else {
if len(ext.Requirements) == 0 {
ext.Owner = nil
ext.PendingOwner = nil
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), source)
} else {
ext.PendingOwner = nil
for id, state := range(ext.Requirements) {
if state != Locked {
panic("NOT_LOCKED")
}
msgs := Messages {}
switch str {
case "not_unlocked":
if ext.State == Locking {
ext.State = AbortingLock
ext.Requirements[source] = Unlocked
for id, state := range(ext.Requirements) {
if state == Locked {
ext.Requirements[id] = Unlocking
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id)
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), id)
}
if source != node.ID {
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocking"), source)
}
}
}
case "unlocking":
if ext.Requirements != nil {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Unlocking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source)
}
}
case "not_locked":
panic("RECEIVED not_locked, meaning a node thought it held a lock it didn't")
case "not_requirement":
}
case "unlocked":
if source == node.ID {
return nil
}
if ext.Requirements != nil {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Unlocking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source)
} else {
ext.Requirements[source] = Unlocked
return msgs
}
if ext.PendingOwner == nil {
unlocked := 0
for _, s := range(ext.Requirements) {
if s == Unlocked {
unlocked += 1
}
}
// Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages {
state := signal.Str
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, state)
if len(ext.Requirements) == unlocked {
previous_owner := *ext.Owner
ext.Owner = nil
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner)
}
}
}
}
msgs := Messages{}
switch state {
case "locked":
if source == node.ID {
return nil
}
if ext.Requirements != nil {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Locking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source)
} else {
state, found := ext.Requirements[source]
if found == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state == Locking {
if ext.State == Locking {
ext.Requirements[source] = Locked
if ext.PendingOwner != nil {
locked := 0
for _, s := range(ext.Requirements) {
if s == Locked {
locked += 1
}
reqs := 0
locked := 0
for _, s := range(ext.Requirements) {
reqs += 1
if s == Locked {
locked += 1
}
}
if len(ext.Requirements) == locked {
ext.Owner = ext.PendingOwner
messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
}
if locked == reqs {
ext.State = Locked
ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
} else {
log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs)
}
} else if ext.State == AbortingLock {
ext.Requirements[source] = Unlocking
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), source)
}
}
case "locking":
if ext.Requirements != nil {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Locking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source)
case "unlocked":
state, found := ext.Requirements[source]
if found == false {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state == Unlocking {
ext.Requirements[source] = Unlocked
reqs := 0
unlocked := 0
for _, s := range(ext.Requirements) {
reqs += 1
if s == Unlocked {
unlocked += 1
}
}
}
if unlocked == reqs {
old_state := ext.State
ext.State = Unlocked
if old_state == Unlocking {
ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner)
} else if old_state == AbortingLock {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(ext.ReqID, "not_unlocked"), *ext.PendingOwner)
ext.PendingOwner = ext.Owner
}
} else {
log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs)
}
}
case "lock":
if ext.Owner != nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source)
} else if ext.PendingOwner != nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source)
} else {
owner := source
if ext.State == Unlocked {
if len(ext.Requirements) == 0 {
ext.Owner = &owner
ext.PendingOwner = ext.Owner
messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), source)
ext.State = Locked
new_owner := source
ext.PendingOwner = &new_owner
ext.Owner = &new_owner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), new_owner)
} else {
ext.PendingOwner = &owner
ext.State = Locking
ext.ReqID = signal.ID()
new_owner := source
ext.PendingOwner = &new_owner
for id, state := range(ext.Requirements) {
log.Logf("lockable_detail", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id)
if state != Unlocked {
panic("NOT_UNLOCKED")
log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING")
}
ext.Requirements[id] = Locking
messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id)
lock_signal := NewLockSignal("lock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id)
}
log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %d requirements", node.ID, len(ext.Requirements))
if source != node.ID {
messages = messages.Add(node.ID, node.Key, NewLockSignal("locking"), source)
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source)
}
case "unlock":
if ext.State == Locked {
if len(ext.Requirements) == 0 {
ext.State = Unlocked
new_owner := source
ext.PendingOwner = nil
ext.Owner = nil
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), new_owner)
} else if source == *ext.Owner {
ext.State = Unlocking
ext.ReqID = signal.ID()
ext.PendingOwner = nil
for id, state := range(ext.Requirements) {
if state != Locked {
log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING")
}
ext.Requirements[id] = Unlocking
lock_signal := NewLockSignal("unlock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id)
}
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source)
}
default:
log.Logf("lockable", "LOCK_ERR: unkown state %s", state)
}
return messages
return msgs
}
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
@ -314,6 +323,8 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal
switch signal.Type() {
case LockSignalType:
messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal))
case ErrorSignalType:
messages = ext.HandleErrorSignal(ctx.Log, node, source, signal.(*ErrorSignal))
default:
}
default:

@ -100,7 +100,7 @@ func TestLink10K(t *testing.T) {
)
ctx.Log.Logf("test", "CREATED_LISTENER")
err = LockLockable(ctx, node)
_, err = LockLockable(ctx, node, node.ID)
fatalErr(t, err)
_, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool {
@ -118,7 +118,7 @@ func TestLink10K(t *testing.T) {
}
func TestLock(t *testing.T) {
ctx := lockableTestContext(t, []string{"lockable", "policy"})
ctx := lockableTestContext(t, []string{"lockable"})
policy := NewAllNodesPolicy(nil)
@ -138,26 +138,40 @@ func TestLock(t *testing.T) {
l3, _ := NewLockable(nil)
l4, _ := NewLockable(nil)
l5, _ := NewLockable(nil)
NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
locked := func(sig *StringSignal) bool {
return sig.Str == "locked"
}
err := LockLockable(ctx, l1)
unlocked := func(sig *StringSignal) bool {
return sig.Str == "unlocked"
}
_, err := LockLockable(ctx, l0, l5.ID)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
id, err := LockLockable(ctx, l1, l1.ID)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool {
return sig.Error == "not_unlocked" && sig.ReqID() == id
})
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
_, err = UnlockLockable(ctx, l0, l5.ID)
fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
_, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, unlocked)
fatalErr(t, err)
err = UnlockLockable(ctx, l1)
_, err = LockLockable(ctx, l1, l1.ID)
fatalErr(t, err)
for i := 0; i < 4; i++ {
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, func(sig *StringSignal) bool {
return sig.Str == "locked"
})
fatalErr(t, err)
}
}

@ -254,8 +254,9 @@ func nodeLoop(ctx *Context, node *Node) error {
pends, resp := node.Allows(princ_id, msg.Signal.Permission())
if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission())
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), source)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source)
ctx.Send(msgs)
continue
} else if resp == Pending {
@ -369,6 +370,7 @@ func nodeLoop(ctx *Context, node *Node) error {
result := node.ReadFields(read_signal.Extensions)
msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(read_signal.ID(), "read_done"), source)
ctx.Send(msgs)
}
}

@ -28,7 +28,6 @@ const (
LinkStartSignalType = SignalType("LINK_START")
ECDHSignalType = SignalType("ECDH")
ECDHProxySignalType = SignalType("ECDH_PROXY")
GQLStateSignalType = SignalType("GQL_STATE")
ACLTimeoutSignalType = SignalType("ACL_TIMEOUT")
Up SignalDirection = iota