Fixed lockable to properly fail

gql_cataclysm
noah metz 2023-08-15 18:23:06 -06:00
parent b446c9078a
commit 98c0b7e807
7 changed files with 277 additions and 202 deletions

@ -238,6 +238,9 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
// Route a Signal to dest. Currently only local context routing is supported // Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(messages Messages) error { func (ctx *Context) Send(messages Messages) error {
for _, msg := range(messages) { for _, msg := range(messages) {
if msg.Dest == ZeroID {
panic("Can't send to null ID")
}
target, err := ctx.GetNode(msg.Dest) target, err := ctx.GetNode(msg.Dest)
if err == nil { if err == nil {
select { select {

@ -808,12 +808,17 @@ func NewGQLExtContext() *GQLExtContext {
Fields: graphql.Fields{}, Fields: graphql.Fields{},
}) })
subscription := graphql.NewObject(graphql.ObjectConfig{
Name: "Subscription",
Fields: graphql.Fields{},
})
context := GQLExtContext{ context := GQLExtContext{
Schema: graphql.Schema{}, Schema: graphql.Schema{},
Types: []graphql.Type{}, Types: []graphql.Type{},
Query: query, Query: query,
Mutation: mutation, Mutation: mutation,
Subscription: nil, Subscription: subscription,
NodeTypes: map[NodeType]*graphql.Object{}, NodeTypes: map[NodeType]*graphql.Object{},
Interfaces: map[string]*Interface{}, Interfaces: map[string]*Interface{},
Fields: map[string]Field{}, Fields: map[string]Field{},
@ -872,9 +877,9 @@ func NewGQLExtContext() *GQLExtContext {
"requirements", "requirements",
LockableExtType, LockableExtType,
func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) { func(p graphql.ResolveParams, val interface{}) ([]NodeID, error) {
id_strs, ok := val.(map[NodeID]string) id_strs, ok := val.(map[NodeID]ReqState)
if ok == false { if ok == false {
return nil, fmt.Errorf("can't parse requirements %+v as string, %s", val, reflect.TypeOf(val)) return nil, fmt.Errorf("can't parse requirements %+v as map[NodeID]ReqState, %s", val, reflect.TypeOf(val))
} }
ids := []NodeID{} ids := []NodeID{}
@ -909,6 +914,20 @@ func NewGQLExtContext() *GQLExtContext {
}, },
}) })
context.Subscription.AddFieldConfig("Self", &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return p.Source, nil
},
Subscribe: func(p graphql.ResolveParams) (interface{}, error) {
c := make(chan interface{}, 10)
for i := 0; i < 10; i++ {
c <- fmt.Sprintf("test %d", i)
}
return c, nil
},
})
context.Query.AddFieldConfig("Self", &graphql.Field{ context.Query.AddFieldConfig("Self", &graphql.Field{
Type: context.Interfaces["Node"].Interface, Type: context.Interfaces["Node"].Interface,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
@ -976,7 +995,6 @@ type GQLExt struct {
resolver_response map[uuid.UUID]chan Signal `json:"-"` resolver_response map[uuid.UUID]chan Signal `json:"-"`
resolver_response_lock sync.RWMutex `json:"-"` resolver_response_lock sync.RWMutex `json:"-"`
State string `json:"state"`
TLSKey []byte `json:"tls_key"` TLSKey []byte `json:"tls_key"`
TLSCert []byte `json:"tls_cert"` TLSCert []byte `json:"tls_cert"`
Listen string `json:"listen"` Listen string `json:"listen"`
@ -990,6 +1008,13 @@ func (ext *GQLExt) Field(name string) interface{} {
}) })
} }
func (ext *GQLExt) FindResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock()
response_chan, _ := ext.resolver_response[req_id]
ext.resolver_response_lock.RUnlock()
return response_chan
}
func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal { func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal {
response_chan := make(chan Signal, 1) response_chan := make(chan Signal, 1)
ext.resolver_response_lock.Lock() ext.resolver_response_lock.Lock()
@ -999,18 +1024,14 @@ func (ext *GQLExt) GetResponseChannel(req_id uuid.UUID) chan Signal {
} }
func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
ext.resolver_response_lock.RLock() response_chan := ext.FindResponseChannel(req_id)
response_chan, exists := ext.resolver_response[req_id]
ext.resolver_response_lock.RUnlock()
if exists == true { if response_chan != nil {
ext.resolver_response_lock.Lock() ext.resolver_response_lock.Lock()
delete(ext.resolver_response, req_id) delete(ext.resolver_response, req_id)
ext.resolver_response_lock.Unlock() ext.resolver_response_lock.Unlock()
return response_chan
} else {
return nil
} }
return response_chan
} }
func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
@ -1033,7 +1054,7 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} }
} else if signal.Type() == ReadResultSignalType { } else if signal.Type() == ReadResultSignalType {
sig := signal.(*ReadResultSignal) sig := signal.(*ReadResultSignal)
response_chan := ext.FreeResponseChannel(sig.ReqID()) response_chan := ext.FindResponseChannel(sig.ReqID())
if response_chan != nil { if response_chan != nil {
select { select {
case response_chan <- sig: case response_chan <- sig:
@ -1044,47 +1065,14 @@ func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signa
} else { } else {
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
} }
} else if signal.Type() == GQLStateSignalType {
sig := signal.(*StringSignal)
ctx.Log.Logf("gql", "GQL_STATE_SIGNAL: %s - %+v", node.ID, sig.Str)
switch sig.Str {
case "start_server":
if ext.State == "stopped" {
err := ext.StartGQLServer(ctx, node)
if err == nil {
ext.State = "running"
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_START_ERROR: %s", err)
}
}
case "stop_server":
if ext.State == "running" {
err := ext.StopGQLServer()
if err == nil {
ext.State = "stopped"
node.QueueSignal(time.Now(), NewStatusSignal("server_stopped", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err)
}
}
default:
ctx.Log.Logf("gql", "unknown gql state %s", sig.Str)
}
} else if signal.Type() == StartSignalType { } else if signal.Type() == StartSignalType {
ctx.Log.Logf("gql", "starting with state: %s", ext.State) ctx.Log.Logf("gql", "starting gql server %s", node.ID)
switch ext.State {
case "running":
err := ext.StartGQLServer(ctx, node) err := ext.StartGQLServer(ctx, node)
if err == nil { if err == nil {
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else { } else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
} }
case "stopped":
default:
ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State)
}
} }
return messages return messages
} }
@ -1118,7 +1106,7 @@ func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, &ext) return json.Unmarshal(data, &ext)
} }
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) (*GQLExt, error) {
if tls_cert == nil || tls_key == nil { if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
@ -1159,7 +1147,6 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta
tls_key = ssl_key_pem tls_key = ssl_key_pem
} }
return &GQLExt{ return &GQLExt{
State: state,
Listen: listen, Listen: listen,
resolver_response: map[uuid.UUID]chan Signal{}, resolver_response: map[uuid.UUID]chan Signal{},
TLSCert: tls_cert, TLSCert: tls_cert,

@ -14,10 +14,12 @@ import (
"crypto/rand" "crypto/rand"
"crypto/ed25519" "crypto/ed25519"
"bytes" "bytes"
"golang.org/x/net/websocket"
"github.com/google/uuid"
) )
func TestGQLServer(t *testing.T) { func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"test", "policy", "pending"}) ctx := logTestContext(t, []string{"test", "gql", "policy", "pending"})
TestNodeType := NodeType("TEST") TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
@ -44,7 +46,6 @@ func TestGQLServer(t *testing.T) {
LockSignalType.String(): nil, LockSignalType.String(): nil,
StatusSignalType.String(): nil, StatusSignalType.String(): nil,
ReadSignalType.String(): nil, ReadSignalType.String(): nil,
GQLStateSignalType.String(): nil,
}, },
}) })
@ -60,7 +61,7 @@ func TestGQLServer(t *testing.T) {
}, },
}) })
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
@ -80,11 +81,6 @@ func TestGQLServer(t *testing.T) {
ctx.Log.Logf("test", "GQL: %s", gql.ID) ctx.Log.Logf("test", "GQL: %s", gql.ID)
ctx.Log.Logf("test", "NODE: %s", n1.ID) ctx.Log.Logf("test", "NODE: %s", n1.ID)
msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_started" return sig.Str == "server_started"
}) })
@ -96,6 +92,7 @@ func TestGQLServer(t *testing.T) {
client := &http.Client{Transport: skipVerifyTransport} client := &http.Client{Transport: skipVerifyTransport}
port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port port := gql_ext.tcp_listener.Addr().(*net.TCPAddr).Port
url := fmt.Sprintf("https://localhost:%d/gql", port) url := fmt.Sprintf("https://localhost:%d/gql", port)
ws_url := fmt.Sprintf("wss://127.0.0.1:%d/gqlws", port)
req_1 := GQLPayload{ req_1 := GQLPayload{
Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }", Query: "query Node($id:String) { Node(id:$id) { ID, TypeHash } }",
@ -111,6 +108,12 @@ func TestGQLServer(t *testing.T) {
}, },
} }
auth_username := base64.StdEncoding.EncodeToString(n1.ID.Serialize())
key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key)
fatalErr(t, err)
auth_password := base64.StdEncoding.EncodeToString(key_bytes)
auth_b64 := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", auth_username, auth_password)))
SendGQL := func(payload GQLPayload) []byte { SendGQL := func(payload GQLPayload) []byte {
ser, err := json.MarshalIndent(&payload, "", " ") ser, err := json.MarshalIndent(&payload, "", " ")
fatalErr(t, err) fatalErr(t, err)
@ -119,9 +122,7 @@ func TestGQLServer(t *testing.T) {
req, err := http.NewRequest("GET", url, req_data) req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err) fatalErr(t, err)
key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) req.SetBasicAuth(auth_username, auth_password)
fatalErr(t, err)
req.SetBasicAuth(base64.StdEncoding.EncodeToString(n1.ID.Serialize()), base64.StdEncoding.EncodeToString(key_bytes))
resp, err := client.Do(req) resp, err := client.Do(req)
fatalErr(t, err) fatalErr(t, err)
@ -137,12 +138,70 @@ func TestGQLServer(t *testing.T) {
resp_2 := SendGQL(req_2) resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2)
msgs = Messages{} sub_1 := GQLPayload{
msgs = msgs.Add(gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID) Query: "subscription { Self }",
}
SubGQL := func(payload GQLPayload) {
config, err := websocket.NewConfig(ws_url, url)
fatalErr(t, err)
config.Protocol = append(config.Protocol, "graphql-ws")
config.TlsConfig = &tls.Config{InsecureSkipVerify: true}
config.Header.Add("Authorization", fmt.Sprintf("Basic %s", auth_b64))
ws, err := websocket.DialConfig(config)
fatalErr(t, err)
init := GQLWSMsg{
ID: uuid.New().String(),
Type: "connection_init",
}
ser, err := json.Marshal(&init)
fatalErr(t, err)
_, err = ws.Write(ser)
fatalErr(t, err)
resp := make([]byte, 1024)
n, err := ws.Read(resp)
var init_resp GQLWSMsg
err = json.Unmarshal(resp[:n], &init_resp)
fatalErr(t, err)
if init_resp.Type != "connection_ack" {
t.Fatal("Didn't receive connection_ack")
}
sub := GQLWSMsg{
ID: uuid.New().String(),
Type: "subscribe",
Payload: sub_1,
}
ser, err = json.Marshal(&sub)
fatalErr(t, err)
_, err = ws.Write(ser)
fatalErr(t, err)
for i := 0; i < 10; i++ {
n, err = ws.Read(resp)
fatalErr(t, err)
ctx.Log.Logf("test", "SUB_%d: %s", i, resp[:n])
}
}
SubGQL(sub_1)
msgs := Messages{}
msgs = msgs.Add(gql.ID, gql.Key, &StopSignal, gql.ID)
err = ctx.Send(msgs) err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_stopped" return sig.Str == "stopped"
}) })
fatalErr(t, err) fatalErr(t, err)
} }
@ -157,7 +216,7 @@ func TestGQLDB(t *testing.T) {
ctx.Log.Logf("test", "U1_ID: %s", u1.ID) ctx.Log.Logf("test", "U1_ID: %s", u1.ID)
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "start") gql_ext, err := NewGQLExt(ctx, ":0", nil, nil)
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
gql := NewNode(ctx, nil, GQLNodeType, 10, nil, gql := NewNode(ctx, nil, GQLNodeType, 10, nil,

@ -2,17 +2,21 @@ package graphvent
import ( import (
"encoding/binary" "encoding/binary"
"github.com/google/uuid"
) )
type ReqState int type ReqState byte
const ( const (
Unlocked = ReqState(0) Unlocked = ReqState(0)
Unlocking = ReqState(1) Unlocking = ReqState(1)
Locked = ReqState(2) Locked = ReqState(2)
Locking = ReqState(3) Locking = ReqState(3)
AbortingLock = ReqState(4)
) )
type LockableExt struct{ type LockableExt struct{
State ReqState
ReqID uuid.UUID
Owner *NodeID Owner *NodeID
PendingOwner *NodeID PendingOwner *NodeID
Requirements map[NodeID]ReqState Requirements map[NodeID]ReqState
@ -37,7 +41,7 @@ func (ext *LockableExt) Type() ExtType {
} }
func (ext *LockableExt) Serialize() ([]byte, error) { func (ext *LockableExt) Serialize() ([]byte, error) {
ret := make([]byte, 8 + (16 * 2) + (17 * len(ext.Requirements))) ret := make([]byte, 9 + (16 * 2) + (17 * len(ext.Requirements)))
if ext.Owner != nil { if ext.Owner != nil {
bytes, err := ext.Owner.MarshalBinary() bytes, err := ext.Owner.MarshalBinary()
if err != nil { if err != nil {
@ -55,8 +59,8 @@ func (ext *LockableExt) Serialize() ([]byte, error) {
} }
binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements))) binary.BigEndian.PutUint64(ret[32:40], uint64(len(ext.Requirements)))
ret[40] = byte(ext.State)
cur := 40 cur := 41
for req, state := range(ext.Requirements) { for req, state := range(ext.Requirements) {
bytes, err := req.MarshalBinary() bytes, err := req.MarshalBinary()
if err != nil { if err != nil {
@ -105,6 +109,9 @@ func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error {
num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8])) num_requirements := int(binary.BigEndian.Uint64(data[cur:cur+8]))
cur += 8 cur += 8
ext.State = ReqState(data[cur])
cur += 1
if num_requirements != 0 { if num_requirements != 0 {
ext.Requirements = map[NodeID]ReqState{} ext.Requirements = map[NodeID]ReqState{}
} }
@ -130,6 +137,7 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
} }
} }
return &LockableExt{ return &LockableExt{
State: Unlocked,
Owner: nil, Owner: nil,
PendingOwner: nil, PendingOwner: nil,
Requirements: reqs, Requirements: reqs,
@ -137,162 +145,163 @@ func NewLockableExt(requirements []NodeID) *LockableExt {
} }
// Send the signal to unlock a node from itself // Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error { func UnlockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), node.ID) signal := NewLockSignal("unlock")
return ctx.Send(msgs) msgs = msgs.Add(owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs)
} }
// Send the signal to lock a node from itself // Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error { func LockLockable(ctx *Context, owner *Node, target NodeID) (uuid.UUID, error) {
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("lock"), node.ID) signal := NewLockSignal("lock")
return ctx.Send(msgs) msgs = msgs.Add(owner.ID, owner.Key, signal, target)
return signal.ID(), ctx.Send(msgs)
} }
// Handle a LockSignal and update the extensions owner/requirement states func (ext *LockableExt) HandleErrorSignal(log Logger, node *Node, source NodeID, signal *ErrorSignal) Messages {
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { str := signal.Error
state := signal.Str log.Logf("lockable", "ERROR_SIGNAL: %s->%s %+v", source, node.ID, str)
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
messages := Messages{} msgs := Messages {}
switch state { switch str {
case "unlock": case "not_unlocked":
if ext.Owner == nil { if ext.State == Locking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source) ext.State = AbortingLock
} else if source != *ext.Owner { ext.Requirements[source] = Unlocked
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source)
} else if ext.PendingOwner == nil {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source)
} else {
if len(ext.Requirements) == 0 {
ext.Owner = nil
ext.PendingOwner = nil
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), source)
} else {
ext.PendingOwner = nil
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
if state != Locked { if state == Locked {
panic("NOT_LOCKED")
}
ext.Requirements[id] = Unlocking ext.Requirements[id] = Unlocking
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlock"), id) msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), id)
}
if source != node.ID {
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocking"), source)
}
} }
} }
case "unlocking":
if ext.Requirements != nil {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Unlocking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source)
} }
case "not_locked":
panic("RECEIVED not_locked, meaning a node thought it held a lock it didn't")
case "not_requirement":
} }
case "unlocked": return msgs
if source == node.ID { }
return nil
}
if ext.Requirements != nil {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Unlocking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source)
} else {
ext.Requirements[source] = Unlocked
if ext.PendingOwner == nil { // Handle a LockSignal and update the extensions owner/requirement states
unlocked := 0 func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages {
for _, s := range(ext.Requirements) { state := signal.Str
if s == Unlocked { log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, state)
unlocked += 1
}
}
if len(ext.Requirements) == unlocked { msgs := Messages{}
previous_owner := *ext.Owner switch state {
ext.Owner = nil
messages = messages.Add(node.ID, node.Key, NewLockSignal("unlocked"), previous_owner)
}
}
}
}
case "locked": case "locked":
if source == node.ID { state, found := ext.Requirements[source]
return nil if found == false {
} msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state == Locking {
if ext.Requirements != nil { if ext.State == Locking {
state, exists := ext.Requirements[source]
if exists == false {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state != Locking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source)
} else {
ext.Requirements[source] = Locked ext.Requirements[source] = Locked
reqs := 0
if ext.PendingOwner != nil {
locked := 0 locked := 0
for _, s := range(ext.Requirements) { for _, s := range(ext.Requirements) {
reqs += 1
if s == Locked { if s == Locked {
locked += 1 locked += 1
} }
} }
if len(ext.Requirements) == locked { if locked == reqs {
ext.State = Locked
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
} } else {
log.Logf("lockable", "PARTIAL LOCK: %s - %d/%d", node.ID, locked, reqs)
} }
} else if ext.State == AbortingLock {
ext.Requirements[source] = Unlocking
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlock"), source)
} }
} }
case "locking": case "unlocked":
if ext.Requirements != nil { state, found := ext.Requirements[source]
state, exists := ext.Requirements[source] if found == false {
if exists == false { msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state == Unlocking {
} else if state != Locking { ext.Requirements[source] = Unlocked
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) reqs := 0
unlocked := 0
for _, s := range(ext.Requirements) {
reqs += 1
if s == Unlocked {
unlocked += 1
} }
} }
case "lock": if unlocked == reqs {
if ext.Owner != nil { old_state := ext.State
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) ext.State = Unlocked
} else if ext.PendingOwner != nil { if old_state == Unlocking {
messages = messages.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source) ext.Owner = ext.PendingOwner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), *ext.Owner)
} else if old_state == AbortingLock {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(ext.ReqID, "not_unlocked"), *ext.PendingOwner)
ext.PendingOwner = ext.Owner
}
} else { } else {
owner := source log.Logf("lockable", "PARTIAL UNLOCK: %s - %d/%d", node.ID, unlocked, reqs)
}
}
case "lock":
if ext.State == Unlocked {
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = &owner ext.State = Locked
ext.PendingOwner = ext.Owner new_owner := source
messages = messages.Add(node.ID, node.Key, NewLockSignal("locked"), source) ext.PendingOwner = &new_owner
ext.Owner = &new_owner
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("locked"), new_owner)
} else { } else {
ext.PendingOwner = &owner ext.State = Locking
ext.ReqID = signal.ID()
new_owner := source
ext.PendingOwner = &new_owner
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
log.Logf("lockable_detail", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id)
if state != Unlocked { if state != Unlocked {
panic("NOT_UNLOCKED") log.Logf("lockable", "REQ_NOT_UNLOCKED_WHEN_LOCKING")
} }
ext.Requirements[id] = Locking ext.Requirements[id] = Locking
messages = messages.Add(node.ID, node.Key, NewLockSignal("lock"), id) lock_signal := NewLockSignal("lock")
msgs = msgs.Add(node.ID, node.Key, lock_signal, id)
}
}
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocked"), source)
}
case "unlock":
if ext.State == Locked {
if len(ext.Requirements) == 0 {
ext.State = Unlocked
new_owner := source
ext.PendingOwner = nil
ext.Owner = nil
msgs = msgs.Add(node.ID, node.Key, NewLockSignal("unlocked"), new_owner)
} else if source == *ext.Owner {
ext.State = Unlocking
ext.ReqID = signal.ID()
ext.PendingOwner = nil
for id, state := range(ext.Requirements) {
if state != Locked {
log.Logf("lockable", "REQ_NOT_LOCKED_WHEN_UNLOCKING")
} }
log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %d requirements", node.ID, len(ext.Requirements)) ext.Requirements[id] = Unlocking
if source != node.ID { lock_signal := NewLockSignal("unlock")
messages = messages.Add(node.ID, node.Key, NewLockSignal("locking"), source) msgs = msgs.Add(node.ID, node.Key, lock_signal, id)
} }
} }
} else {
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locked"), source)
} }
default: default:
log.Logf("lockable", "LOCK_ERR: unkown state %s", state) log.Logf("lockable", "LOCK_ERR: unkown state %s", state)
} }
return messages return msgs
} }
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
@ -314,6 +323,8 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal
switch signal.Type() { switch signal.Type() {
case LockSignalType: case LockSignalType:
messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal)) messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal))
case ErrorSignalType:
messages = ext.HandleErrorSignal(ctx.Log, node, source, signal.(*ErrorSignal))
default: default:
} }
default: default:

@ -100,7 +100,7 @@ func TestLink10K(t *testing.T) {
) )
ctx.Log.Logf("test", "CREATED_LISTENER") ctx.Log.Logf("test", "CREATED_LISTENER")
err = LockLockable(ctx, node) _, err = LockLockable(ctx, node, node.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool { _, err = WaitForSignal(listener.Chan, time.Millisecond*1000, LockSignalType, func(sig *StringSignal) bool {
@ -118,7 +118,7 @@ func TestLink10K(t *testing.T) {
} }
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
ctx := lockableTestContext(t, []string{"lockable", "policy"}) ctx := lockableTestContext(t, []string{"lockable"})
policy := NewAllNodesPolicy(nil) policy := NewAllNodesPolicy(nil)
@ -138,26 +138,40 @@ func TestLock(t *testing.T) {
l3, _ := NewLockable(nil) l3, _ := NewLockable(nil)
l4, _ := NewLockable(nil) l4, _ := NewLockable(nil)
l5, _ := NewLockable(nil) l5, _ := NewLockable(nil)
NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) l0, l0_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID}) l1, l1_listener := NewLockable([]NodeID{l2.ID, l3.ID, l4.ID, l5.ID})
locked := func(sig *StringSignal) bool { locked := func(sig *StringSignal) bool {
return sig.Str == "locked" return sig.Str == "locked"
} }
err := LockLockable(ctx, l1) unlocked := func(sig *StringSignal) bool {
return sig.Str == "unlocked"
}
_, err := LockLockable(ctx, l0, l5.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
id, err := LockLockable(ctx, l1, l1.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, ErrorSignalType, func(sig *ErrorSignal) bool {
return sig.Error == "not_unlocked" && sig.ReqID() == id
})
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
_, err = UnlockLockable(ctx, l0, l5.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(l0_listener.Chan, time.Millisecond*10, LockSignalType, unlocked)
fatalErr(t, err) fatalErr(t, err)
err = UnlockLockable(ctx, l1) _, err = LockLockable(ctx, l1, l1.ID)
fatalErr(t, err) fatalErr(t, err)
for i := 0; i < 4; i++ {
_, err = WaitForSignal(l1_listener.Chan, time.Millisecond*10, LockSignalType, func(sig *StringSignal) bool {
return sig.Str == "locked"
})
fatalErr(t, err)
}
} }

@ -254,8 +254,9 @@ func nodeLoop(ctx *Context, node *Node) error {
pends, resp := node.Allows(princ_id, msg.Signal.Permission()) pends, resp := node.Allows(princ_id, msg.Signal.Permission())
if resp == Deny { if resp == Deny {
ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission()) ctx.Log.Logf("policy", "SIGNAL_POLICY_DENY: %s->%s - %s", princ_id, node.ID, msg.Signal.Permission())
ctx.Log.Logf("policy", "SIGNAL_POLICY_SOURCE: %s", msg.Source)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), source) msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(msg.Signal.ID(), "acl denied"), msg.Source)
ctx.Send(msgs) ctx.Send(msgs)
continue continue
} else if resp == Pending { } else if resp == Pending {
@ -369,6 +370,7 @@ func nodeLoop(ctx *Context, node *Node) error {
result := node.ReadFields(read_signal.Extensions) result := node.ReadFields(read_signal.Extensions)
msgs := Messages{} msgs := Messages{}
msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source) msgs = msgs.Add(node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.ID, node.Type, result), source)
msgs = msgs.Add(node.ID, node.Key, NewErrorSignal(read_signal.ID(), "read_done"), source)
ctx.Send(msgs) ctx.Send(msgs)
} }
} }

@ -28,7 +28,6 @@ const (
LinkStartSignalType = SignalType("LINK_START") LinkStartSignalType = SignalType("LINK_START")
ECDHSignalType = SignalType("ECDH") ECDHSignalType = SignalType("ECDH")
ECDHProxySignalType = SignalType("ECDH_PROXY") ECDHProxySignalType = SignalType("ECDH_PROXY")
GQLStateSignalType = SignalType("GQL_STATE")
ACLTimeoutSignalType = SignalType("ACL_TIMEOUT") ACLTimeoutSignalType = SignalType("ACL_TIMEOUT")
Up SignalDirection = iota Up SignalDirection = iota