Moved policies to node instead of an extension, need to fix gql tests

gql_cataclysm
noah metz 2023-08-07 20:26:02 -06:00
parent 1d91854f6f
commit 8770d6f433
13 changed files with 535 additions and 667 deletions

@ -48,9 +48,30 @@ func LoadExtension[T any, E interface {
return e, nil return e, nil
} }
type PolicyType string
func (policy PolicyType) Prefix() string { return "POLICY: " }
func (policy PolicyType) String() string { return string(policy) }
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return p, nil
}
type PolicyInfo struct {
Load PolicyLoadFunc
Type PolicyType
}
// ExtType and NodeType constants // ExtType and NodeType constants
const ( const (
ACLExtType = ExtType("ACL")
ListenerExtType = ExtType("LISTENER") ListenerExtType = ExtType("LISTENER")
LockableExtType = ExtType("LOCKABLE") LockableExtType = ExtType("LOCKABLE")
GQLExtType = ExtType("GQL") GQLExtType = ExtType("GQL")
@ -62,6 +83,7 @@ const (
var ( var (
NodeNotFoundError = errors.New("Node not found in DB") NodeNotFoundError = errors.New("Node not found in DB")
ECDH = ecdh.X25519()
) )
type SignalLoadFunc func(*Context,[]byte) (Signal, error) type SignalLoadFunc func(*Context,[]byte) (Signal, error)
@ -107,12 +129,12 @@ type Context struct {
Log Logger Log Logger
// Map between database extension hashes and the registered info // Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo Extensions map[uint64]ExtensionInfo
// Map between databse policy hashes and the registered info
Policies map[uint64]PolicyInfo
// Map between serialized signal hashes and the registered info // Map between serialized signal hashes and the registered info
Signals map[uint64]SignalInfo Signals map[uint64]SignalInfo
// Map between database type hashes and the registered info // Map between database type hashes and the registered info
Types map[uint64]*NodeInfo Types map[uint64]*NodeInfo
// Curve used for ecdh operations
ECDH ecdh.Curve
// Routing map to all the nodes local to this context // Routing map to all the nodes local to this context
NodesLock sync.RWMutex NodesLock sync.RWMutex
Nodes map[NodeID]*Node Nodes map[NodeID]*Node
@ -216,28 +238,31 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
// Stop every running loop // Stop every running loop
func (ctx *Context) Stop() { func (ctx *Context) Stop() {
for _, node := range(ctx.Nodes) { for _, node := range(ctx.Nodes) {
node.MsgChan <- Msg{ZeroID, &StopSignal} node.MsgChan <- Message{ZeroID, &StopSignal}
} }
} }
// Route a Signal to dest. Currently only local context routing is supported // Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { func (ctx *Context) Send(source NodeID, messages []Message) error {
target, err := ctx.GetNode(dest) for _, msg := range(messages) {
target, err := ctx.GetNode(msg.NodeID)
if err == nil { if err == nil {
select { select {
case target.MsgChan <- Msg{source, signal}: case target.MsgChan <- Message{source, msg.Signal}:
default: default:
buf := make([]byte, 4096) buf := make([]byte, 4096)
n := runtime.Stack(buf, false) n := runtime.Stack(buf, false)
stack_str := string(buf[:n]) stack_str := string(buf[:n])
return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str) return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.NodeID, stack_str)
} }
return nil
} else if errors.Is(err, NodeNotFoundError) { } else if errors.Is(err, NodeNotFoundError) {
// TODO: Handle finding nodes in other contexts // TODO: Handle finding nodes in other contexts
return err return err
} } else {
return err return err
}
}
return nil
} }
// Create a new Context with the base library content added // Create a new Context with the base library content added
@ -249,15 +274,9 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
Types: map[uint64]*NodeInfo{}, Types: map[uint64]*NodeInfo{},
Signals: map[uint64]SignalInfo{}, Signals: map[uint64]SignalInfo{},
Nodes: map[NodeID]*Node{}, Nodes: map[NodeID]*Node{},
ECDH: ecdh.X25519(),
} }
var err error var err error
err = RegisterExtension[ACLExt,*ACLExt](ctx, NewACLExtContext())
if err != nil {
return nil, err
}
err = RegisterExtension[LockableExt,*LockableExt](ctx, nil) err = RegisterExtension[LockableExt,*LockableExt](ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -299,7 +318,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) err = ctx.RegisterNodeType(GQLNodeType, []ExtType{GroupExtType, GQLExtType})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -103,28 +103,30 @@ func (ext *ECDHExt) Field(name string) interface{} {
}) })
} }
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal *ECDHSignal) { func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) []Message {
source := KeyID(signal.EDDSA)
messages := []Message{}
switch signal.Str { switch signal.Str {
case "req": case "req":
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
if exists == false { if exists == false {
state = ECDHState{nil, nil} state = ECDHState{nil, nil}
} }
resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal) resp, shared_secret, err := NewECDHRespSignal(node, signal)
if err == nil { if err == nil {
state.SharedSecret = shared_secret state.SharedSecret = shared_secret
ext.ECDHStates[source] = state ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, &resp) messages = append(messages, Message{source, &resp})
} else { } else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response messages = append(messages, Message{source, NewErrorSignal(signal.ID(), err.Error())})
} }
case "resp": case "resp":
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil { if exists == false || state.ECKey == nil {
resp := NewErrorSignal(signal.ID(), "no_req") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "no_req")})
ctx.Send(node.ID, source, &resp)
} else { } else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil { if err == nil {
@ -133,55 +135,23 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
state.SharedSecret = shared_secret state.SharedSecret = shared_secret
state.ECKey = nil state.ECKey = nil
ext.ECDHStates[source] = state ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
} }
} }
} }
default: default:
ctx.Log.Logf("ecdh", "unknown echd state %s", signal.Str) log.Logf("ecdh", "unknown echd state %s", signal.Str)
}
}
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
}
func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal *ECDHProxySignal) {
state, exists := ext.ECDHStates[source]
if exists == false {
resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
} else if state.SharedSecret == nil {
resp := NewErrorSignal(signal.ID(), "no_shared")
ctx.Send(node.ID, source, &resp)
} else {
unwrapped_signal, err := ParseECDHProxySignal(ctx, signal, state.SharedSecret)
if err != nil {
resp := NewErrorSignal(signal.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
} else {
//TODO: Figure out what I was trying to do here and fix it
ctx.Send(signal.Source, signal.Dest, unwrapped_signal)
}
} }
return messages
} }
func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { func (ext *ECDHExt) Process(ctx *Context, node *Node, msg Message) []Message {
switch signal.Direction() { switch msg.Signal.Type() {
case Direct:
switch signal.Type() {
case ECDHProxySignalType:
ecdh_signal := signal.(*ECDHProxySignal)
ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal)
case ECDHStateSignalType:
ecdh_signal := signal.(*StringSignal)
ext.HandleStateSignal(ctx, source, node, ecdh_signal)
case ECDHSignalType: case ECDHSignalType:
ecdh_signal := signal.(*ECDHSignal) sig := msg.Signal.(*ECDHSignal)
ext.HandleECDHSignal(ctx, source, node, ecdh_signal) return ext.HandleECDHSignal(ctx.Log, node, sig)
default:
}
default:
} }
return nil
} }
func (ext *ECDHExt) Type() ExtType { func (ext *ECDHExt) Type() ExtType {

@ -17,9 +17,10 @@ import (
"github.com/gobwas/ws" "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"strings" "strings"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/ecdh" "crypto/ecdh"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/tls" "crypto/tls"
@ -189,15 +190,16 @@ type ResolveContext struct {
} }
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, key_bytes, ok := r.BasicAuth() id_bytes, key_bytes, ok := r.BasicAuth()
if ok == false { if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
} }
auth_id, err := ParseID(username) auth_uuid, err := uuid.FromBytes([]byte(id_bytes))
if err != nil { if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username")
} }
auth_id := NodeID(auth_uuid)
key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes)) key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes))
if err != nil { if err != nil {
@ -916,7 +918,7 @@ func NewGQLExtContext() *GQLExtContext {
panic(err) panic(err)
} }
context.Mutation.AddFieldConfig("stopServer", &graphql.Field{ context.Mutation.AddFieldConfig("stop", &graphql.Field{
Type: graphql.String, Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) ctx, err := PrepResolve(p)
@ -924,14 +926,13 @@ func NewGQLExtContext() *GQLExtContext {
return nil, err return nil, err
} }
sig := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"} sig, err := NewAuthorizedSignal(ctx.Key, &StopSignal)
err = Allowed(ctx.Context, ctx.User, sig.Permission(), ctx.Server)
if err != nil { if err != nil {
return err, nil return nil, err
} }
response_chan := ctx.Ext.GetResponseChannel(sig.ID()) response_chan := ctx.Ext.GetResponseChannel(sig.ID())
err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, &sig) err = ctx.Context.Send(ctx.Server.ID, []Message{Message{ctx.Server.ID, sig}})
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(sig.ID()) ctx.Ext.FreeResponseChannel(sig.ID())
return nil, err return nil, err
@ -1016,8 +1017,8 @@ type GQLExt struct {
resolver_response_lock sync.RWMutex `json:"-"` resolver_response_lock sync.RWMutex `json:"-"`
State string `json:"state"` State string `json:"state"`
tls_key []byte `json:"tls_key"` TLSKey []byte `json:"tls_key"`
tls_cert []byte `json:"tls_cert"` TLSCert []byte `json:"tls_cert"`
Listen string `json:"listen"` Listen string `json:"listen"`
} }
@ -1052,12 +1053,14 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
} }
} }
func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { func (ext *GQLExt) Process(ctx *Context, node *Node, msg Message) []Message {
// Process ReadResultSignalType by forwarding it to the waiting resolver // Process ReadResultSignalType by forwarding it to the waiting resolver
signal := msg.Signal
messages := []Message{}
if signal.Type() == ErrorSignalType { if signal.Type() == ErrorSignalType {
// TODO: Forward to resolver if waiting for it // TODO: Forward to resolver if waiting for it
sig := signal.(*ErrorSignal) sig := signal.(*ErrorSignal)
response_chan := ext.FreeResponseChannel(sig.ID()) response_chan := ext.FreeResponseChannel(sig.UUID)
if response_chan != nil { if response_chan != nil {
select { select {
case response_chan <- sig: case response_chan <- sig:
@ -1084,14 +1087,16 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
} }
} else if signal.Type() == GQLStateSignalType { } else if signal.Type() == GQLStateSignalType {
sig := signal.(*StringSignal) sig := signal.(*StringSignal)
ctx.Log.Logf("gql", "GQL_STATE_SIGNAL: %s - %+v", node.ID, sig.Str)
switch sig.Str { switch sig.Str {
case "start_server": case "start_server":
if ext.State == "stopped" { if ext.State == "stopped" {
err := ext.StartGQLServer(ctx, node) err := ext.StartGQLServer(ctx, node)
if err == nil { if err == nil {
ext.State = "running" ext.State = "running"
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"} node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
ctx.Send(node.ID, source, &resp) } else {
ctx.Log.Logf("gql", "GQL_START_ERROR: %s", err)
} }
} }
case "stop_server": case "stop_server":
@ -1099,8 +1104,9 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
err := ext.StopGQLServer() err := ext.StopGQLServer()
if err == nil { if err == nil {
ext.State = "stopped" ext.State = "stopped"
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"} node.QueueSignal(time.Now(), NewStatusSignal("server_stopped", node.ID))
ctx.Send(node.ID, source, &resp) } else {
ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err)
} }
} }
default: default:
@ -1112,14 +1118,16 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
case "running": case "running":
err := ext.StartGQLServer(ctx, node) err := ext.StartGQLServer(ctx, node)
if err == nil { if err == nil {
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"} node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
ctx.Send(node.ID, source, &resp) } else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
} }
case "stopped": case "stopped":
default: default:
ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State) ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State)
} }
} }
return messages
} }
func (ext *GQLExt) Type() ExtType { func (ext *GQLExt) Type() ExtType {
@ -1147,12 +1155,13 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
} }
func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error { func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
ext.resolver_response = map[uuid.UUID]chan Signal{}
return json.Unmarshal(data, &ext) return json.Unmarshal(data, &ext)
} }
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) {
if tls_cert == nil || tls_key == nil { if tls_cert == nil || tls_key == nil {
_, ssl_key, err := ed25519.GenerateKey(rand.Reader) ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1194,8 +1203,8 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta
State: state, State: state,
Listen: listen, Listen: listen,
resolver_response: map[uuid.UUID]chan Signal{}, resolver_response: map[uuid.UUID]chan Signal{},
tls_cert: tls_cert, TLSCert: tls_cert,
tls_key: tls_key, TLSKey: tls_key,
}, nil }, nil
} }
@ -1224,7 +1233,7 @@ func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error {
return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
} }
cert, err := tls.X509KeyPair(ext.tls_cert, ext.tls_key) cert, err := tls.X509KeyPair(ext.TLSCert, ext.TLSKey)
if err != nil { if err != nil {
return err return err
} }

@ -61,7 +61,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
resp_channels[read_signal.ID()] = response_chan resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, id, &auth_signal) err = ctx.Context.Send(ctx.Server.ID, []Message{Message{id, auth_signal}})
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID()) ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err return nil, err
@ -79,11 +79,12 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
case *ReadResultSignal: case *ReadResultSignal:
responses = append(responses, NodeResult{node_ids[sig_id], resp}) responses = append(responses, NodeResult{node_ids[sig_id], resp})
case *ErrorSignal: case *ErrorSignal:
return nil, fmt.Errorf(resp.Str) return nil, fmt.Errorf(resp.Error)
default: default:
return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp)) return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp))
} }
} }
ctx.Context.Log.Logf("gql", "RESOLVED_NODES")
return responses, nil return responses, nil
} }

@ -11,30 +11,28 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"bytes" "bytes"
"github.com/google/uuid"
) )
func TestGQL(t *testing.T) { func TestGQL(t *testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{"gql", "lockable", "node_timeout", "listener"})
TestNodeType := NodeType("TEST") TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType}) err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
fatalErr(t, err) fatalErr(t, err)
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped")
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
policy := NewAllNodesPolicy(Actions{MakeAction("+")}) gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext)
start_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"} n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt())
gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{
QueuedSignal{uuid.New(), &start_signal, time.Now()},
}, NewLockableExt(), NewACLExt(&policy), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(&policy))
err = LinkRequirement(ctx, gql.ID, n1.ID) err = LinkRequirement(ctx, gql.ID, n1.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool { err = ctx.Send(gql.ID, []Message{{gql.ID, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}}})
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_started" return sig.Str == "server_started"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -86,8 +84,8 @@ func TestGQL(t *testing.T) {
resp_2 := SendGQL(req_2) resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2)
stop_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"} stop_signal := StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}
ctx.Send(n1.ID, gql.ID, &stop_signal) ctx.Send(n1.ID, []Message{{gql.ID, &stop_signal}})
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool { _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
return sig.Str == "server_stopped" return sig.Str == "server_stopped"
}) })
@ -109,13 +107,10 @@ func TestGQLDB(t *testing.T) {
gql := NewNode(ctx, nil, GQLNodeType, 10, nil, gql := NewNode(ctx, nil, GQLNodeType, 10, nil,
gql_ext, gql_ext,
listener_ext, listener_ext,
NewACLExt(),
NewGroupExt(nil)) NewGroupExt(nil))
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err = ctx.Send(gql.ID, gql.ID, &StopSignal) ctx.Stop()
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID return sig.Str == "stopped" && sig.NodeID == gql.ID
}) })
@ -134,7 +129,7 @@ func TestGQLDB(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded) listener_ext, err = GetExt[*ListenerExt](gql_loaded)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(gql_loaded.ID, gql_loaded.ID, &StopSignal) err = ctx.Send(gql_loaded.ID, []Message{{gql_loaded.ID, &StopSignal}})
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID

@ -9,7 +9,6 @@ import (
const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER")
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
policy := NewAllNodesPolicy(Actions{MakeAction("status")})
listener_extension := NewListenerExt(buffer) listener_extension := NewListenerExt(buffer)
listener := NewNode(ctx, listener := NewNode(ctx,
nil, nil,
@ -17,7 +16,6 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
10, 10,
nil, nil,
listener_extension, listener_extension,
NewACLExt(&policy),
NewLockableExt()) NewLockableExt())
return listener, listener_extension return listener, listener_extension
@ -32,7 +30,7 @@ func logTestContext(t * testing.T, components []string) *Context {
ctx, err := NewContext(db, NewConsoleLogger(components)) ctx, err := NewContext(db, NewConsoleLogger(components))
fatalErr(t, err) fatalErr(t, err)
err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ListenerExtType, LockableExtType})
fatalErr(t, err) fatalErr(t, err)
return ctx return ctx

@ -7,14 +7,14 @@ import (
// A Listener extension provides a channel that can receive signals on a different thread // A Listener extension provides a channel that can receive signals on a different thread
type ListenerExt struct { type ListenerExt struct {
Buffer int Buffer int
Chan chan Signal Chan chan Message
} }
// Create a new listener extension with a given buffer size // Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt { func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{ return &ListenerExt{
Buffer: buffer, Buffer: buffer,
Chan: make(chan Signal, buffer), Chan: make(chan Message, buffer),
} }
} }
@ -32,7 +32,7 @@ func (ext *ListenerExt) Field(name string) interface{} {
// Simple load function, unmarshal the buffer int from json // Simple load function, unmarshal the buffer int from json
func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error {
err := json.Unmarshal(data, &ext.Buffer) err := json.Unmarshal(data, &ext.Buffer)
ext.Chan = make(chan Signal, ext.Buffer) ext.Chan = make(chan Message, ext.Buffer)
return err return err
} }
@ -41,14 +41,14 @@ func (listener *ListenerExt) Type() ExtType {
} }
// Send the signal to the channel, logging an overflow if it occurs // Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { func (ext *ListenerExt) Process(ctx *Context, node *Node, msg Message) []Message {
ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, msg.Signal)
select { select {
case ext.Chan <- signal: case ext.Chan <- msg:
default: default:
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
} }
return return nil
} }
// ReqState holds the multiple states of a requirement // ReqState holds the multiple states of a requirement
@ -138,43 +138,38 @@ func NewLockableExt() *LockableExt {
// Send the signal to unlock a node from itself // Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error { func UnlockLockable(ctx *Context, node *Node) error {
lock_signal := NewLockSignal("unlock") return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("unlock")}})
return ctx.Send(node.ID, node.ID, &lock_signal)
} }
// Send the signal to lock a node from itself // Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error { func LockLockable(ctx *Context, node *Node) error {
lock_signal := NewLockSignal("lock") return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("lock")}})
return ctx.Send(node.ID, node.ID, &lock_signal)
} }
// Setup a node to send the initial requirement link signal, then send the signal // Setup a node to send the initial requirement link signal, then send the signal
func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error { func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error {
start_signal := NewLinkStartSignal("req", requirement) return ctx.Send(dependency, []Message{Message{dependency, NewLinkStartSignal("req", requirement)}})
return ctx.Send(dependency, dependency, &start_signal)
} }
// Handle a LockSignal and update the extensions owner/requirement states // Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message {
ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str state := signal.Str
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
messages := []Message{}
switch state { switch state {
case "unlock": case "unlock":
if ext.Owner == nil { if ext.Owner == nil {
resp := NewErrorSignal(signal.ID(), "already_unlocked") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocked")})
ctx.Send(node.ID, source, &resp)
} else if source != *ext.Owner { } else if source != *ext.Owner {
resp := NewErrorSignal(signal.ID(), "not_owner") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_owner")})
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner == nil { } else if ext.PendingOwner == nil {
resp := NewErrorSignal(signal.ID(), "already_unlocking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocking")})
ctx.Send(node.ID, source, &resp)
} else { } else {
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = nil ext.Owner = nil
ext.PendingOwner = nil ext.PendingOwner = nil
resp := NewLockSignal("unlocked") messages = append(messages, Message{source, NewLockSignal("unlocked")})
ctx.Send(node.ID, source, &resp)
} else { } else {
ext.PendingOwner = nil ext.PendingOwner = nil
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
@ -184,44 +179,36 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
} }
state.Lock = "unlocking" state.Lock = "unlocking"
ext.Requirements[id] = state ext.Requirements[id] = state
resp := NewLockSignal("unlock") messages = append(messages, Message{id, NewLockSignal("unlock")})
ctx.Send(node.ID, id, &resp)
} }
} }
if source != node.ID { if source != node.ID {
resp := NewLockSignal("unlocking") messages = append(messages, Message{source, NewLockSignal("unlocking")})
ctx.Send(node.ID, source, &resp)
} }
} }
} }
case "unlocking": case "unlocking":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "unlocking" { } else if state.Lock != "unlocking" {
resp := NewErrorSignal(signal.ID(), "not_unlocking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")})
ctx.Send(node.ID, source, &resp)
} }
case "unlocked": case "unlocked":
if source == node.ID { if source == node.ID {
return return nil
} }
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "unlocking" { } else if state.Lock != "unlocking" {
resp := NewErrorSignal(signal.ID(), "not_unlocking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")})
ctx.Send(node.ID, source, &resp)
} else { } else {
state.Lock = "unlocked" state.Lock = "unlocked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -241,26 +228,22 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == unlocked { if linked == unlocked {
previous_owner := *ext.Owner previous_owner := *ext.Owner
ext.Owner = nil ext.Owner = nil
resp := NewLockSignal("unlocked") messages = append(messages, Message{previous_owner, NewLockSignal("unlocked")})
ctx.Send(node.ID, previous_owner, &resp)
} }
} }
} }
case "locked": case "locked":
if source == node.ID { if source == node.ID {
return return nil
} }
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "locking" { } else if state.Lock != "locking" {
resp := NewErrorSignal(signal.ID(), "not_locking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")})
ctx.Send(node.ID, source, &resp)
} else { } else {
state.Lock = "locked" state.Lock = "locked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -279,176 +262,142 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == locked { if linked == locked {
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
resp := NewLockSignal("locked") messages = append(messages, Message{*ext.Owner, NewLockSignal("locked")})
ctx.Send(node.ID, *ext.Owner, &resp)
} }
} }
} }
case "locking": case "locking":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "locking" { } else if state.Lock != "locking" {
resp := NewErrorSignal(signal.ID(), "not_locking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")})
ctx.Send(node.ID, source, &resp)
} }
case "lock": case "lock":
if ext.Owner != nil { if ext.Owner != nil {
resp := NewErrorSignal(signal.ID(), "already_locked") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locked")})
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != nil { } else if ext.PendingOwner != nil {
resp := NewErrorSignal(signal.ID(), "already_locking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locking")})
ctx.Send(node.ID, source, &resp)
} else { } else {
owner := source owner := source
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = &owner ext.Owner = &owner
ext.PendingOwner = ext.Owner ext.PendingOwner = ext.Owner
resp := NewLockSignal("locked") messages = append(messages, Message{source, NewLockSignal("locked")})
ctx.Send(node.ID, source, &resp)
} else { } else {
ext.PendingOwner = &owner ext.PendingOwner = &owner
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
if state.Link == "linked" { if state.Link == "linked" {
log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id)
if state.Lock != "unlocked" { if state.Lock != "unlocked" {
panic("NOT_UNLOCKED") panic("NOT_UNLOCKED")
} }
state.Lock = "locking" state.Lock = "locking"
ext.Requirements[id] = state ext.Requirements[id] = state
sub := NewLockSignal("lock") messages = append(messages, Message{id, NewLockSignal("lock")})
ctx.Send(node.ID, id, &sub)
} }
} }
if source != node.ID { if source != node.ID {
resp := NewLockSignal("locking") messages = append(messages, Message{source, NewLockSignal("locking")})
ctx.Send(node.ID, source, &resp)
} }
} }
} }
default: default:
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", state) log.Logf("lockable", "LOCK_ERR: unkown state %s", state)
} }
log.Logf("lockable", "LOCK_MESSAGES: %+v", messages)
return messages
} }
func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal *IDStringSignal) { func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) []Message {
ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal)
link_type := signal.Str link_type := signal.Str
target := signal.NodeID target := signal.NodeID
log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target)
messages := []Message{}
switch link_type { switch link_type {
case "req": case "req":
state, exists := ext.Requirements[target] state, exists := ext.Requirements[target]
_, dep_exists := ext.Dependencies[target] _, dep_exists := ext.Dependencies[target]
if ext.Owner != nil { if ext.Owner != nil {
resp := NewLinkStartSignal("locked", target) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already locked")})
ctx.Send(node.ID, source, &resp)
} else if ext.Owner != ext.PendingOwner { } else if ext.Owner != ext.PendingOwner {
if ext.PendingOwner == nil { if ext.PendingOwner == nil {
resp := NewLinkStartSignal("unlocking", target) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")})
ctx.Send(node.ID, source, &resp)
} else { } else {
resp := NewLinkStartSignal("locking", target) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")})
ctx.Send(node.ID, source, &resp)
} }
} else if exists == true { } else if exists == true {
if state.Link == "linking" { if state.Link == "linking" {
resp := NewErrorSignal(signal.ID(), "already_linking_req") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking_req")})
ctx.Send(node.ID, source, &resp)
} else if state.Link == "linked" { } else if state.Link == "linked" {
resp := NewErrorSignal(signal.ID(), "already_req") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_req")})
ctx.Send(node.ID, source, &resp)
} }
} else if dep_exists == true { } else if dep_exists == true {
resp := NewLinkStartSignal("already_dep", target) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_dep")})
ctx.Send(node.ID, source, &resp)
} else { } else {
ext.Requirements[target] = LinkState{"linking", "unlocked", source} ext.Requirements[target] = LinkState{"linking", "unlocked", source}
resp := NewLinkSignal("linked_as_req") messages = append(messages, Message{target, NewLinkSignal("linked_as_req")})
ctx.Send(node.ID, target, &resp) messages = append(messages, Message{source, NewLinkStartSignal("linking_req", target)})
notify := NewLinkStartSignal("linking_req", target)
ctx.Send(node.ID, source, &notify)
} }
} }
return messages
} }
// Handle LinkSignal, updating the extensions requirements and dependencies as necessary // Handle LinkSignal, updating the extensions requirements and dependencies as necessary
// TODO: Add unlink // TODO: Add unlink
func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message {
ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str state := signal.Str
messages := []Message{}
switch state { switch state {
case "linked_as_dep": case "dep_done":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == true && state.Link == "linked" { if exists == false {
resp := NewLinkStartSignal("linked_as_req", source) messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linking")})
ctx.Send(node.ID, state.Initiator, &resp)
} else if state.Link == "linking" { } else if state.Link == "linking" {
state.Link = "linked" state.Link = "linked"
ext.Requirements[source] = state ext.Requirements[source] = state
resp := NewLinkSignal("linked_as_req") log.Logf("lockable", "FINISHED_LINKING_REQ: %s->%s", node.ID, source)
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil {
resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
} else {
resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
}
} else {
ext.Requirements[source] = LinkState{"linking", "unlocked", source}
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
} }
ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source)
case "linked_as_req": case "linked_as_req":
state, exists := ext.Dependencies[source] state, exists := ext.Dependencies[source]
if exists == true && state.Link == "linked" { if exists == false {
resp := NewLinkStartSignal("linked_as_dep", source) ext.Dependencies[source] = LinkState{"linked", "unlocked", source}
ctx.Send(node.ID, state.Initiator, &resp) messages = append(messages, Message{source, NewLinkSignal("dep_done")})
} else if state.Link == "linking" { } else if state.Link == "linking" {
state.Link = "linked" messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking")})
ext.Dependencies[source] = state } else if state.Link == "linked" {
resp := NewLinkSignal("linked_as_dep") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linked")})
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner { } else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil { if ext.Owner == nil {
resp := NewLinkSignal("locking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")})
ctx.Send(node.ID, source, &resp)
} else { } else {
resp := NewLinkSignal("unlocking") messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")})
ctx.Send(node.ID, source, &resp)
} }
} else {
ext.Dependencies[source] = LinkState{"linking", "unlocked", source}
resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
} }
ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source)
default: default:
ctx.Log.Logf("lockable", "LINK_ERROR: unknown state %s", state) log.Logf("lockable", "LINK_ERROR: unknown state %s", state)
} }
return messages
} }
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message {
switch signal.Direction() { messages := []Message{}
switch msg.Signal.Direction() {
case Up: case Up:
ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies)
owner_sent := false owner_sent := false
for dependency, state := range(ext.Dependencies) { for dependency, state := range(ext.Dependencies) {
if state.Link == "linked" { if state.Link == "linked" {
err := ctx.Send(node.ID, dependency, signal) messages = append(messages, Message{dependency, msg.Signal})
if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err)
}
if ext.Owner != nil { if ext.Owner != nil {
if dependency == *ext.Owner { if dependency == *ext.Owner {
owner_sent = true owner_sent = true
@ -459,32 +408,27 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal
if ext.Owner != nil && owner_sent == false { if ext.Owner != nil && owner_sent == false {
if *ext.Owner != node.ID { if *ext.Owner != node.ID {
err := ctx.Send(node.ID, *ext.Owner, signal) messages = append(messages, Message{*ext.Owner, msg.Signal})
if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, *ext.Owner, err)
}
} }
} }
case Down: case Down:
for requirement, state := range(ext.Requirements) { for requirement, state := range(ext.Requirements) {
if state.Link == "linked" { if state.Link == "linked" {
err := ctx.Send(node.ID, requirement, signal) messages = append(messages, Message{requirement, msg.Signal})
if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err)
}
} }
} }
case Direct: case Direct:
switch signal.Type() { switch msg.Signal.Type() {
case LinkSignalType: case LinkSignalType:
ext.HandleLinkSignal(ctx, source, node, signal.(*StringSignal)) messages = ext.HandleLinkSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal))
case LockSignalType: case LockSignalType:
ext.HandleLockSignal(ctx, source, node, signal.(*StringSignal)) messages = ext.HandleLockSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal))
case LinkStartSignalType: case LinkStartSignalType:
ext.HandleLinkStartSignal(ctx, source, node, signal.(*IDStringSignal)) messages = ext.HandleLinkStartSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*IDStringSignal))
default: default:
} }
default: default:
} }
return messages
} }

@ -9,7 +9,7 @@ const TestLockableType = NodeType("TEST_LOCKABLE")
func lockableTestContext(t *testing.T, logs []string) *Context { func lockableTestContext(t *testing.T, logs []string) *Context {
ctx := logTestContext(t, logs) ctx := logTestContext(t, logs)
err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType}) err := ctx.RegisterNodeType(TestLockableType, []ExtType{LockableExtType})
fatalErr(t, err) fatalErr(t, err)
return ctx return ctx
@ -26,13 +26,11 @@ func TestLink(t *testing.T) {
l1_listener := NewListenerExt(10) l1_listener := NewListenerExt(10)
l1 := NewNode(ctx, nil, TestLockableType, 10, nil, l1 := NewNode(ctx, nil, TestLockableType, 10, nil,
l1_listener, l1_listener,
NewACLExt(&link_policy),
NewLockableExt(), NewLockableExt(),
) )
l2_listener := NewListenerExt(10) l2_listener := NewListenerExt(10)
l2 := NewNode(ctx, nil, TestLockableType, 10, nil, l2 := NewNode(ctx, nil, TestLockableType, 10, nil,
l2_listener, l2_listener,
NewACLExt(&link_policy),
NewLockableExt(), NewLockableExt(),
) )
@ -40,13 +38,13 @@ func TestLink(t *testing.T) {
err := LinkRequirement(ctx, l1.ID, l2.ID) err := LinkRequirement(ctx, l1.ID, l2.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool {
return sig.Str == "linked_as_req" return sig.Str == "dep_done"
}) })
fatalErr(t, err) fatalErr(t, err)
sig1 := NewStatusSignal("TEST", l2.ID) sig1 := NewStatusSignal("TEST", l2.ID)
err = ctx.Send(l2.ID, l2.ID, &sig1) err = ctx.Send(l2.ID, []Message{{l2.ID, sig1}})
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
@ -65,7 +63,6 @@ func TestLink10K(t *testing.T) {
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
l := NewNode(ctx, nil, TestLockableType, 10, nil, l := NewNode(ctx, nil, TestLockableType, 10, nil,
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(), NewLockableExt(),
) )
return l return l
@ -75,7 +72,6 @@ func TestLink10K(t *testing.T) {
listener := NewListenerExt(100000) listener := NewListenerExt(100000)
l := NewNode(ctx, nil, TestLockableType, 256, nil, l := NewNode(ctx, nil, TestLockableType, 256, nil,
listener, listener,
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(), NewLockableExt(),
) )
return l, listener return l, listener
@ -92,8 +88,8 @@ func TestLink10K(t *testing.T) {
for range(lockables) { for range(lockables) {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool { _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool {
return sig.Str == "linked_as_req" return sig.Str == "dep_done"
}) })
fatalErr(t, err) fatalErr(t, err)
} }
@ -102,13 +98,12 @@ func TestLink10K(t *testing.T) {
} }
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
ctx := lockableTestContext(t, []string{}) ctx := lockableTestContext(t, []string{"lockable", "listener"})
NewLockable := func()(*Node, *ListenerExt) { NewLockable := func()(*Node, *ListenerExt) {
listener := NewListenerExt(100) listener := NewListenerExt(100)
l := NewNode(ctx, nil, TestLockableType, 10, nil, l := NewNode(ctx, nil, TestLockableType, 10, nil,
listener, listener,
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(), NewLockableExt(),
) )
return l, listener return l, listener
@ -141,30 +136,30 @@ func TestLock(t *testing.T) {
err = LinkRequirement(ctx, l0.ID, l5.ID) err = LinkRequirement(ctx, l0.ID, l5.ID)
fatalErr(t, err) fatalErr(t, err)
linked_as_req := func(sig *IDStringSignal) bool { linked_as_req := func(sig *StringSignal) bool {
return sig.Str == "linked_as_req" return sig.Str == "dep_done"
} }
locked := func(sig *StringSignal) bool { locked := func(sig *StringSignal) bool {
return sig.Str == "locked" return sig.Str == "locked"
} }
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
err = LockLockable(ctx, l1) err = LockLockable(ctx, l1)

@ -20,9 +20,10 @@ const (
// Magic first four bytes of serialized DB content, stored big endian // Magic first four bytes of serialized DB content, stored big endian
NODE_DB_MAGIC = 0x2491df14 NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function // Total length of the node database header, has magic to verify and type_hash to map to load function
NODE_DB_HEADER_LEN = 28 NODE_DB_HEADER_LEN = 32
EXTENSION_DB_HEADER_LEN = 16 EXTENSION_DB_HEADER_LEN = 16
QSIGNAL_DB_HEADER_LEN = 40 QSIGNAL_DB_HEADER_LEN = 40
POLICY_DB_HEADER_LEN = 16
) )
var ( var (
@ -82,7 +83,7 @@ func RandID() NodeID {
return NodeID(uuid.New()) return NodeID(uuid.New())
} }
// A Serializable has a type that can be used to map to it, and a function to serialize the current state // A Serializable has a type that can be used to map to it, and a function to serialize` the current state
type Serializable[I comparable] interface { type Serializable[I comparable] interface {
Serialize()([]byte,error) Serialize()([]byte,error)
Deserialize(*Context,[]byte)error Deserialize(*Context,[]byte)error
@ -93,7 +94,7 @@ type Serializable[I comparable] interface {
type Extension interface { type Extension interface {
Serializable[ExtType] Serializable[ExtType]
Field(string)interface{} Field(string)interface{}
Process(context *Context, source NodeID, node *Node, signal Signal) Process(ctx *Context, node *Node, message Message)[]Message
} }
// A QueuedSignal is a Signal that has been Queued to trigger at a set time // A QueuedSignal is a Signal that has been Queued to trigger at a set time
@ -110,9 +111,10 @@ type Node struct {
ID NodeID ID NodeID
Type NodeType Type NodeType
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
Policies map[PolicyType]Policy
// Channel for this node to receive messages from the Context // Channel for this node to receive messages from the Context
MsgChan chan Msg MsgChan chan Message
// Size of MsgChan // Size of MsgChan
BufferSize uint32 BufferSize uint32
// Channel for this node to process delayed signals // Channel for this node to process delayed signals
@ -124,6 +126,18 @@ type Node struct {
NextSignal *QueuedSignal NextSignal *QueuedSignal
} }
func (node *Node) Allows(principal_id NodeID, action Action) error {
errs := []error{}
for _, policy := range(node.Policies) {
err := policy.Allows(principal_id, action, node)
if err == nil {
return nil
}
errs = append(errs, err)
}
return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs)
}
func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID { func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID {
id := uuid.New() id := uuid.New()
node.SignalQueue = append(node.SignalQueue, QueuedSignal{id, signal, time}) node.SignalQueue = append(node.SignalQueue, QueuedSignal{id, signal, time})
@ -163,17 +177,12 @@ func runNode(ctx *Context, node *Node) {
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
} }
type Msg struct {
Source NodeID
Signal Signal
}
func ReadNodeFields(ctx *Context, self *Node, princ NodeID, reqs map[ExtType][]string)map[ExtType]map[string]interface{} { func ReadNodeFields(ctx *Context, self *Node, princ NodeID, reqs map[ExtType][]string)map[ExtType]map[string]interface{} {
exts := map[ExtType]map[string]interface{}{} exts := map[ExtType]map[string]interface{}{}
for ext_type, field_reqs := range(reqs) { for ext_type, field_reqs := range(reqs) {
fields := map[string]interface{}{} fields := map[string]interface{}{}
for _, req := range(field_reqs) { for _, req := range(field_reqs) {
err := Allowed(ctx, princ, MakeAction(ReadSignalType, ext_type, req), self) err := self.Allows(princ, MakeAction(ReadSignalType, ext_type, req))
if err != nil { if err != nil {
fields[req] = err fields[req] = err
} else { } else {
@ -198,27 +207,18 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
// Perform startup actions // Perform startup actions
node.Process(ctx, node.ID, &StartSignal) node.Process(ctx, Message{ZeroID, &StartSignal})
for true { for true {
var signal Signal var msg Message
var source NodeID
select { select {
case msg := <- node.MsgChan: case msg = <- node.MsgChan:
ctx.Log.Logf("signal", "NODE_MSG: %s - %+v", node.ID, msg) ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal.Type())
signal = msg.Signal
source = msg.Source
err := Allowed(ctx, msg.Source, signal.Permission(), node)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err)
resp := NewErrorSignal(msg.Signal.ID(), err.Error())
ctx.Send(node.ID, msg.Source, &resp)
continue
}
case <-node.TimeoutChan: case <-node.TimeoutChan:
signal = node.NextSignal.Signal signal := node.NextSignal.Signal
msg = Message{node.ID, signal}
t := node.NextSignal.Time t := node.NextSignal.Time
source = node.ID
i := -1 i := -1
for j, queued := range(node.SignalQueue) { for j, queued := range(node.SignalQueue) {
if queued.UUID == node.NextSignal.UUID { if queued.UUID == node.NextSignal.UUID {
@ -235,17 +235,17 @@ func nodeLoop(ctx *Context, node *Node) error {
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
if node.NextSignal == nil { if node.NextSignal == nil {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, t, signal, node.TimeoutChan) ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan)
} else { } else {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, t, signal, node.NextSignal, node.NextSignal.Time) ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time)
} }
} }
// Unwrap Authorized Signals // Unwrap Authorized Signals
if signal.Type() == AuthorizedSignalType { if msg.Signal.Type() == AuthorizedSignalType {
sig, ok := signal.(*AuthorizedSignal) sig, ok := msg.Signal.(*AuthorizedSignal)
if ok == false { if ok == false {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal)) ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(msg.Signal))
} else { } else {
// Validate // Validate
sig_data, err := sig.Signal.Serialize() sig_data, err := sig.Signal.Serialize()
@ -253,45 +253,40 @@ func nodeLoop(ctx *Context, node *Node) error {
} else { } else {
validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature) validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature)
if validated == true { if validated == true {
err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) err := node.Allows(KeyID(sig.Principal), sig.Signal.Permission())
if err != nil { if err != nil {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err) ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err)
resp := NewErrorSignal(sig.ID(), err.Error()) ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), err.Error())}})
ctx.Send(node.ID, source, &resp)
} else { } else {
// Unwrap the signal without changing the source // Unwrap the signal without changing the source
signal = sig.Signal msg = Message{msg.NodeID, sig.Signal}
} }
} else { } else {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate") ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate")
resp := NewErrorSignal(sig.ID(), "signature validation failed") ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), "signature validation failed")}})
ctx.Send(node.ID, source, &resp)
} }
} }
} }
} }
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
// Handle special signal types // Handle special signal types
if signal.Type() == StopSignalType { if msg.Signal.Type() == StopSignalType {
resp := NewErrorSignal(signal.ID(), "stopped") ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(msg.Signal.ID(), "stopped")}})
ctx.Send(node.ID, source, &resp) node.Process(ctx, Message{node.ID, NewStatusSignal("stopped", node.ID)})
status := NewStatusSignal("stopped", node.ID)
node.Process(ctx, node.ID, &status)
break break
} else if signal.Type() == ReadSignalType { } else if msg.Signal.Type() == ReadSignalType {
read_signal, ok := signal.(*ReadSignal) read_signal, ok := msg.Signal.(*ReadSignal)
if ok == false { if ok == false {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", msg.Signal)
} else { } else {
result := ReadNodeFields(ctx, node, source, read_signal.Extensions) result := ReadNodeFields(ctx, node, msg.NodeID, read_signal.Extensions)
resp := NewReadResultSignal(read_signal.ID(), node.Type, result) ctx.Send(node.ID, []Message{Message{msg.NodeID, NewReadResultSignal(read_signal.ID(), node.Type, result)}})
ctx.Send(node.ID, source, &resp)
} }
} }
node.Process(ctx, source, signal) node.Process(ctx, msg)
// assume that processing a signal means that this nodes state changed // assume that processing a signal means that this nodes state changed
// TODO: remove a lot of database writes by only writing when things change, // TODO: remove a lot of database writes by only writing when things change,
// so need to have Process return whether or not state changed // so need to have Process return whether or not state changed
@ -308,11 +303,24 @@ func nodeLoop(ctx *Context, node *Node) error {
return nil return nil
} }
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) { type Message struct {
NodeID
Signal
}
func (node *Node) Process(ctx *Context, message Message) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, message.Signal.Type())
messages := []Message{}
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("signal", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
ext.Process(ctx, source, node, signal) //TODO: add extension and node info to log
resp := ext.Process(ctx, node, message)
if resp != nil {
messages = append(messages, resp...)
}
} }
return ctx.Send(node.ID, messages)
} }
func GetCtx[T Extension, C any](ctx *Context) (C, error) { func GetCtx[T Extension, C any](ctx *Context) (C, error) {
@ -352,6 +360,7 @@ func GetExt[T Extension](node *Node) (T, error) {
func (node *Node) Serialize() ([]byte, error) { func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.Extensions)) extensions := make([]ExtensionDB, len(node.Extensions))
qsignals := make([]QSignalDB, len(node.SignalQueue)) qsignals := make([]QSignalDB, len(node.SignalQueue))
policies := make([]PolicyDB, len(node.Policies))
key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key) key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key)
if err != nil { if err != nil {
@ -365,6 +374,7 @@ func (node *Node) Serialize() ([]byte, error) {
KeyLength: uint32(len(key_bytes)), KeyLength: uint32(len(key_bytes)),
BufferSize: node.BufferSize, BufferSize: node.BufferSize,
NumExtensions: uint32(len(extensions)), NumExtensions: uint32(len(extensions)),
NumPolicies: uint32(len(policies)),
NumQueuedSignals: uint32(len(node.SignalQueue)), NumQueuedSignals: uint32(len(node.SignalQueue)),
}, },
Extensions: extensions, Extensions: extensions,
@ -405,6 +415,22 @@ func (node *Node) Serialize() ([]byte, error) {
} }
} }
i = 0
for _, policy := range(node.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
node_db.Policies[i] = PolicyDB{
PolicyDBHeader{
Hash(policy.Type()),
uint64(len(ser)),
},
ser,
}
}
return node_db.Serialize(), nil return node_db.Serialize(), nil
} }
@ -415,7 +441,7 @@ func KeyID(pub ed25519.PublicKey) NodeID {
// Create a new node in memory and start it's event loop // Create a new node in memory and start it's event loop
// TODO: Change panics to errors // TODO: Change panics to errors
func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies map[PolicyType]Policy, extensions ...Extension) *Node {
var err error var err error
var public ed25519.PublicKey var public ed25519.PublicKey
if key == nil { if key == nil {
@ -453,22 +479,19 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
} }
} }
if queued_signals == nil { if policies == nil {
queued_signals = []QueuedSignal{} policies = map[PolicyType]Policy{}
} }
next_signal, timeout_chan := SoonestSignal(queued_signals)
node := &Node{ node := &Node{
Key: key, Key: key,
ID: id, ID: id,
Type: node_type, Type: node_type,
Extensions: ext_map, Extensions: ext_map,
MsgChan: make(chan Msg, buffer_size), Policies: policies,
MsgChan: make(chan Message, buffer_size),
BufferSize: buffer_size, BufferSize: buffer_size,
TimeoutChan: timeout_chan, SignalQueue: []QueuedSignal{},
SignalQueue: queued_signals,
NextSignal: next_signal,
} }
ctx.AddNode(id, node) ctx.AddNode(id, node)
err = WriteNode(ctx, node) err = WriteNode(ctx, node)
@ -476,41 +499,50 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic(err) panic(err)
} }
node.Process(ctx, node.ID, &NewSignal) node.Process(ctx, Message{node.ID, &NewSignal})
go runNode(ctx, node) go runNode(ctx, node)
return node return node
} }
func Allowed(ctx *Context, principal_id NodeID, action Action, node *Node) error { type PolicyDBHeader struct {
ctx.Log.Logf("policy", "POLICY_CHECK: %s -> %s.%s", principal_id, node.ID, action) TypeHash uint64
// Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension Length uint64
if principal_id == node.ID { }
ctx.Log.Logf("policy", "POLICY_CHECK_SAME_NODE: %s.%s", principal_id, action)
return nil
}
// Check if the node has a policy extension itself, and check against the policies in it type PolicyDB struct {
policy_ext, err := GetExt[*ACLExt](node) Header PolicyDBHeader
if err != nil { Data []byte
ctx.Log.Logf("policy", "POLICY_CHECK_NO_ACL_EXT: %s", node.ID) }
return err
}
err = policy_ext.Allows(ctx, principal_id, action, node) type QSignalDBHeader struct {
if err != nil { SignalID uuid.UUID
ctx.Log.Logf("policy", "POLICY_CHECK_FAIL: %s -> %s.%s : %s", principal_id, node.ID, action, err) Time time.Time
} else { TypeHash uint64
ctx.Log.Logf("policy", "POLICY_CHECK_PASS: %s -> %s.%s", principal_id, node.ID, action) Length uint64
} }
return err
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
} }
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct { type NodeDBHeader struct {
Magic uint32 Magic uint32
NumExtensions uint32 NumExtensions uint32
NumPolicies uint32
NumQueuedSignals uint32 NumQueuedSignals uint32
BufferSize uint32 BufferSize uint32
KeyLength uint32 KeyLength uint32
@ -519,8 +551,9 @@ type NodeDBHeader struct {
type NodeDB struct { type NodeDB struct {
Header NodeDBHeader Header NodeDBHeader
QueuedSignals []QSignalDB
Extensions []ExtensionDB Extensions []ExtensionDB
Policies []PolicyDB
QueuedSignals []QSignalDB
KeyBytes []byte KeyBytes []byte
} }
@ -532,10 +565,11 @@ func NewNodeDB(data []byte) (NodeDB, error) {
magic := binary.BigEndian.Uint32(data[0:4]) magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8]) num_extensions := binary.BigEndian.Uint32(data[4:8])
num_queued_signals := binary.BigEndian.Uint32(data[8:12]) num_policies := binary.BigEndian.Uint32(data[8:12])
buffer_size := binary.BigEndian.Uint32(data[12:16]) num_queued_signals := binary.BigEndian.Uint32(data[12:16])
key_length := binary.BigEndian.Uint32(data[16:20]) buffer_size := binary.BigEndian.Uint32(data[16:20])
node_type_hash := binary.BigEndian.Uint64(data[20:28]) key_length := binary.BigEndian.Uint32(data[20:24])
node_type_hash := binary.BigEndian.Uint64(data[24:32])
ptr += NODE_DB_HEADER_LEN ptr += NODE_DB_HEADER_LEN
@ -573,6 +607,26 @@ func NewNodeDB(data []byte) (NodeDB, error) {
ptr += int(EXTENSION_DB_HEADER_LEN + length) ptr += int(EXTENSION_DB_HEADER_LEN + length)
} }
policies := make([]PolicyDB, num_policies)
for i, _ := range(policies) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(POLICY_DB_HEADER_LEN)
data_end := data_start + length
policy_data := cur[data_start:data_end]
policies[i] = PolicyDB{
PolicyDBHeader{
type_hash,
length,
},
policy_data,
}
ptr += int(POLICY_DB_HEADER_LEN + length)
}
queued_signals := make([]QSignalDB, num_queued_signals) queued_signals := make([]QSignalDB, num_queued_signals)
for i, _ := range(queued_signals) { for i, _ := range(queued_signals) {
cur := data[ptr:] cur := data[ptr:]
@ -626,10 +680,11 @@ func (header NodeDBHeader) Serialize() []byte {
ret := make([]byte, NODE_DB_HEADER_LEN) ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint32(ret[8:12], header.NumQueuedSignals) binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies)
binary.BigEndian.PutUint32(ret[12:16], header.BufferSize) binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals)
binary.BigEndian.PutUint32(ret[16:20], header.KeyLength) binary.BigEndian.PutUint32(ret[16:20], header.BufferSize)
binary.BigEndian.PutUint64(ret[20:28], header.TypeHash) binary.BigEndian.PutUint32(ret[20:24], header.KeyLength)
binary.BigEndian.PutUint64(ret[24:32], header.TypeHash)
return ret return ret
} }
@ -673,28 +728,6 @@ func (extension ExtensionDB) Serialize() []byte {
return append(header_bytes, extension.Data...) return append(header_bytes, extension.Data...)
} }
type QSignalDBHeader struct {
SignalID uuid.UUID
Time time.Time
TypeHash uint64
Length uint64
}
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
}
// Write a node to the database // Write a node to the database
func WriteNode(ctx *Context, node *Node) error { func WriteNode(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) ctx.Log.Logf("db", "DB_WRITE: %s", node.ID)
@ -740,6 +773,21 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err return nil, err
} }
policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies)
for _, policy_db := range(node_db.Policies) {
policy_info, exists := ctx.Policies[policy_db.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash)
}
policy, err := policy_info.Load(ctx, policy_db.Data)
if err != nil {
return nil, err
}
policies[policy_info.Type] = policy
}
key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes) key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes)
if err != nil { if err != nil {
return nil, err return nil, err
@ -784,7 +832,8 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
ID: key_id, ID: key_id,
Type: node_type.Type, Type: node_type.Type,
Extensions: map[ExtType]Extension{}, Extensions: map[ExtType]Extension{},
MsgChan: make(chan Msg, node_db.Header.BufferSize), Policies: policies,
MsgChan: make(chan Message, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize, BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan, TimeoutChan: timeout_chan,
SignalQueue: signal_queue, SignalQueue: signal_queue,

@ -23,7 +23,7 @@ func TestNodeDB(t *testing.T) {
func TestNodeRead(t *testing.T) { func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{})
node_type := NodeType("TEST") node_type := NodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType})
fatalErr(t, err) fatalErr(t, err)
n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader) n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader)
@ -37,21 +37,15 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N1: %s", n1_id)
ctx.Log.Logf("test", "N2: %s", n2_id) ctx.Log.Logf("test", "N2: %s", n2_id)
n2_policy := NewPerNodePolicy(map[NodeID]Actions{
n1_id: Actions{MakeAction(ReadResultSignalType, "+")},
})
n2_listener := NewListenerExt(10) n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(&n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener) n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener)
n1_policy := NewPerNodePolicy(map[NodeID]Actions{ n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt())
n2_id: Actions{MakeAction(ReadSignalType, "+")},
})
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(&n1_policy), NewGroupExt(nil), NewECDHExt())
read_sig := NewReadSignal(map[ExtType][]string{ read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"}, GroupExtType: []string{"members"},
}) })
ctx.Send(n2.ID, n1.ID, &read_sig) ctx.Send(n2.ID, []Message{{n1.ID, &read_sig}})
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool {
return true return true
@ -64,22 +58,20 @@ func TestECDH(t *testing.T) {
ctx := logTestContext(t, []string{"test", "ecdh", "policy"}) ctx := logTestContext(t, []string{"test", "ecdh", "policy"})
node_type := NodeType("TEST") node_type := NodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, ECDHExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{ECDHExtType})
fatalErr(t, err) fatalErr(t, err)
n1_listener := NewListenerExt(10) n1_listener := NewListenerExt(10)
ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")}) n1 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n1_listener)
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt(), n1_listener) n2 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt())
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt())
n3_listener := NewListenerExt(10) n3_listener := NewListenerExt(10)
n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}}) n3 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n3_listener)
n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy, &n3_policy), NewECDHExt(), n3_listener)
ctx.Log.Logf("test", "N1: %s", n1.ID) ctx.Log.Logf("test", "N1: %s", n1.ID)
ctx.Log.Logf("test", "N2: %s", n2.ID) ctx.Log.Logf("test", "N2: %s", n2.ID)
ecdh_req, n1_ec, err := NewECDHReqSignal(ctx, n1) ecdh_req, n1_ec, err := NewECDHReqSignal(n1)
ecdh_ext, err := GetExt[*ECDHExt](n1) ecdh_ext, err := GetExt[*ECDHExt](n1)
fatalErr(t, err) fatalErr(t, err)
ecdh_ext.ECDHStates[n2.ID] = ECDHState{ ecdh_ext.ECDHStates[n2.ID] = ECDHState{
@ -88,7 +80,7 @@ func TestECDH(t *testing.T) {
} }
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec)
err = ctx.Send(n1.ID, n2.ID, &ecdh_req) err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_req}})
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool {
@ -100,6 +92,6 @@ func TestECDH(t *testing.T) {
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret) ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(n1.ID, n2.ID, &ecdh_sig) err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_sig}})
fatalErr(t, err) fatalErr(t, err)
} }

@ -5,11 +5,8 @@ import (
"fmt" "fmt"
) )
type PolicyType string
func (policy PolicyType) Prefix() string { return "POLICY: " }
func (policy PolicyType) String() string { return string(policy) }
const ( const (
UserOfPolicyType = PolicyType("USER_OF")
RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
PerNodePolicyType = PolicyType("PER_NODE") PerNodePolicyType = PolicyType("PER_NODE")
AllNodesPolicyType = PolicyType("ALL_NODES") AllNodesPolicyType = PolicyType("ALL_NODES")
@ -38,7 +35,7 @@ func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Nod
return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID)
} }
func (policy RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
lockable_ext, err := GetExt[*LockableExt](node) lockable_ext, err := GetExt[*LockableExt](node)
if err != nil { if err != nil {
return err return err
@ -53,10 +50,45 @@ func (policy RequirementOfPolicy) Allows(principal_id NodeID, action Action, nod
return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID)
} }
type UserOfPolicy struct {
PerNodePolicy
}
func (policy *UserOfPolicy) Type() PolicyType {
return UserOfPolicyType
}
func NewUserOfPolicy(group_actions NodeActions) UserOfPolicy {
return UserOfPolicy{
PerNodePolicy: NewPerNodePolicy(group_actions),
}
}
// Send a read signal to Group to check if principal_id is a member of it
func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
// Send a read signal to each of the groups in the map
// Check for principal_id in any of the returned member lists(skipping errors)
// Return an error in the default case
return fmt.Errorf("NOT_IMPLEMENTED")
}
func (policy *UserOfPolicy) Merge(p Policy) Policy {
other := p.(*UserOfPolicy)
policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions)
return policy
}
func (policy *UserOfPolicy) Copy() Policy {
new_actions := CopyNodeActions(policy.NodeActions)
return &UserOfPolicy{
PerNodePolicy: NewPerNodePolicy(new_actions),
}
}
type RequirementOfPolicy struct { type RequirementOfPolicy struct {
AllNodesPolicy AllNodesPolicy
} }
func (policy RequirementOfPolicy) Type() PolicyType { func (policy *RequirementOfPolicy) Type() PolicyType {
return RequirementOfPolicyType return RequirementOfPolicyType
} }
@ -82,20 +114,25 @@ func CopyNodeActions(actions NodeActions) NodeActions {
return ret return ret
} }
func MergeNodeActions(modified NodeActions, read NodeActions) { func MergeNodeActions(first NodeActions, second NodeActions) NodeActions {
for id, actions := range(read) { merged := NodeActions{}
existing, exists := modified[id] for id, actions := range(first) {
merged[id] = actions
}
for id, actions := range(second) {
existing, exists := merged[id]
if exists { if exists {
modified[id] = MergeActions(existing, actions) merged[id] = MergeActions(existing, actions)
} else { } else {
modified[id] = actions merged[id] = actions
} }
} }
return merged
} }
func (policy *PerNodePolicy) Merge(p Policy) Policy { func (policy *PerNodePolicy) Merge(p Policy) Policy {
other := p.(*PerNodePolicy) other := p.(*PerNodePolicy)
MergeNodeActions(policy.NodeActions, other.NodeActions) policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions)
return policy return policy
} }
@ -263,63 +300,6 @@ func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy) return json.Unmarshal(data, policy)
} }
// Extension to allow a node to hold ACL policies
type ACLExt struct {
Policies map[PolicyType]Policy
}
func NodeList(nodes ...*Node) NodeMap {
m := NodeMap{}
for _, node := range(nodes) {
m[node.ID] = node
}
return m
}
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
type ACLExtContext struct {
Loads map[PolicyType]PolicyLoadFunc
}
func NewACLExtContext() *ACLExtContext {
return &ACLExtContext{
Loads: map[PolicyType]PolicyLoadFunc{
AllNodesPolicyType: LoadPolicy[AllNodesPolicy,*AllNodesPolicy],
PerNodePolicyType: LoadPolicy[PerNodePolicy,*PerNodePolicy],
RequirementOfPolicyType: LoadPolicy[RequirementOfPolicy,*RequirementOfPolicy],
},
}
}
func (ext *ACLExt) Serialize() ([]byte, error) {
policies := map[string][]byte{}
for name, policy := range(ext.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
policies[string(name)] = ser
}
return json.MarshalIndent(&struct{
Policies map[string][]byte `json:"policies"`
}{
Policies: policies,
}, "", " ")
}
func (ext *ACLExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
}
func (ext *ACLExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ACLExt)interface{}{
"policies": func(ext *ACLExt) interface{} {
return ext.Policies
},
})
}
var ErrorSignalAction = Action{"ERROR_RESP"} var ErrorSignalAction = Action{"ERROR_RESP"}
var ReadResultSignalAction = Action{"READ_RESULT"} var ReadResultSignalAction = Action{"READ_RESULT"}
var AuthorizedSignalAction = Action{"AUTHORIZED_READ"} var AuthorizedSignalAction = Action{"AUTHORIZED_READ"}
@ -327,82 +307,3 @@ var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSigna
var DefaultACLPolicies = []Policy{ var DefaultACLPolicies = []Policy{
&defaultPolicy, &defaultPolicy,
} }
func NewACLExt(policies ...Policy) *ACLExt {
policy_map := map[PolicyType]Policy{}
for _, policy_arg := range(append(policies, DefaultACLPolicies...)) {
policy := policy_arg.Copy()
existing, exists := policy_map[policy.Type()]
if exists == true {
policy = existing.Merge(policy)
}
policy_map[policy.Type()] = policy
}
return &ACLExt{
Policies: policy_map,
}
}
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return p, nil
}
func (ext *ACLExt) Deserialize(ctx *Context, data []byte) error {
var j struct {
Policies map[string][]byte `json:"policies"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx)
if err != nil {
return err
}
ext.Policies = map[PolicyType]Policy{}
for name, ser := range(j.Policies) {
policy_load, exists := acl_ctx.Loads[PolicyType(name)]
if exists == false {
return fmt.Errorf("%s is not a known policy type", name)
}
policy, err := policy_load(ctx, ser)
if err != nil {
return err
}
ext.Policies[PolicyType(name)] = policy
}
return nil
}
func (ext *ACLExt) Type() ExtType {
return ACLExtType
}
// Check if the extension allows the principal to perform action on node
func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action Action, node *Node) error {
errs := []error{}
for _, policy := range(ext.Policies) {
err := policy.Allows(principal_id, action, node)
if err == nil {
return nil
}
errs = append(errs, err)
}
return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs)
}

@ -28,7 +28,6 @@ const (
ReadResultSignalType = "READ_RESULT" ReadResultSignalType = "READ_RESULT"
LinkStartSignalType = "LINK_START" LinkStartSignalType = "LINK_START"
ECDHSignalType = "ECDH" ECDHSignalType = "ECDH"
ECDHStateSignalType = "ECDH_STATE"
ECDHProxySignalType = "ECDH_PROXY" ECDHProxySignalType = "ECDH_PROXY"
GQLStateSignalType = "GQL_STATE" GQLStateSignalType = "GQL_STATE"
@ -43,6 +42,7 @@ func (signal_type SignalType) Prefix() string { return "SIGNAL: " }
type Signal interface { type Signal interface {
Serializable[SignalType] Serializable[SignalType]
String() string
Direction() SignalDirection Direction() SignalDirection
ID() uuid.UUID ID() uuid.UUID
Permission() Action Permission() Action
@ -70,12 +70,12 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.
} }
for true { for true {
select { select {
case signal := <- listener.Chan: case msg := <- listener.Chan:
if signal == nil { if msg.Signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type)
} }
if signal.Type() == signal_type { if msg.Signal.Type() == signal_type {
sig, ok := signal.(S) sig, ok := msg.Signal.(S)
if ok == true { if ok == true {
if check(sig) == true { if check(sig) == true {
return sig, nil return sig, nil
@ -96,6 +96,11 @@ type BaseSignal struct {
UUID uuid.UUID `json:"id"` UUID uuid.UUID `json:"id"`
} }
func (signal *BaseSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error { func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, signal) return json.Unmarshal(data, signal)
} }
@ -129,21 +134,9 @@ func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal
return signal return signal
} }
func NewDownSignal(signal_type SignalType) BaseSignal { var NewSignal = NewBaseSignal(NewSignalType, Direct)
return NewBaseSignal(signal_type, Down) var StartSignal = NewBaseSignal(StartSignalType, Direct)
} var StopSignal = NewBaseSignal(StopSignalType, Direct)
func NewUpSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Up)
}
func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Direct)
}
var NewSignal = NewDirectSignal(NewSignalType)
var StartSignal = NewDirectSignal(StartSignalType)
var StopSignal = NewDownSignal(StopSignalType)
type IDSignal struct { type IDSignal struct {
BaseSignal BaseSignal
@ -154,88 +147,91 @@ func (signal *IDSignal) Serialize() ([]byte, error) {
return json.Marshal(signal) return json.Marshal(signal)
} }
func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
return IDSignal{
BaseSignal: NewBaseSignal(signal_type, direction),
NodeID: id,
}
}
type StringSignal struct { type StringSignal struct {
BaseSignal BaseSignal
Str string `json:"state"` Str string `json:"state"`
} }
func (signal *StringSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *StringSignal) Serialize() ([]byte, error) { func (signal *StringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(&signal)
} }
type RespSignal struct {
BaseSignal
ReqID uuid.UUID
}
type ErrorSignal struct { type ErrorSignal struct {
StringSignal RespSignal
Error string
}
func (signal *ErrorSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
} }
func (signal *ErrorSignal) Permission() Action { func (signal *ErrorSignal) Permission() Action {
return ErrorSignalAction return ErrorSignalAction
} }
func NewErrorSignal(req_id uuid.UUID, err string) ErrorSignal { func NewErrorSignal(req_id uuid.UUID, err string) Signal {
return ErrorSignal{ return &ErrorSignal{
StringSignal{ RespSignal{
NewDirectSignal(ErrorSignalType), NewBaseSignal(ErrorSignalType, Direct),
err, req_id,
}, },
err,
} }
} }
type IDStringSignal struct { type IDStringSignal struct {
BaseSignal BaseSignal
NodeID `json:"node_id"` NodeID NodeID `json:"node_id"`
Str string `json:"string"` Str string `json:"string"`
} }
func (signal *IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func (signal *IDStringSignal) String() string { func (signal *IDStringSignal) String() string {
ser, err := json.Marshal(signal) ser, _ := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser) return string(ser)
} }
func NewStatusSignal(status string, source NodeID) IDStringSignal { func (signal *IDStringSignal) Serialize() ([]byte, error) {
return IDStringSignal{ return json.Marshal(signal)
BaseSignal: NewUpSignal(StatusSignalType), }
func NewStatusSignal(status string, source NodeID) Signal {
return &IDStringSignal{
BaseSignal: NewBaseSignal(StatusSignalType, Up),
NodeID: source, NodeID: source,
Str: status, Str: status,
} }
} }
func NewLinkSignal(state string) StringSignal { func NewLinkSignal(state string) Signal {
return StringSignal{ return &StringSignal{
BaseSignal: NewDirectSignal(LinkSignalType), BaseSignal: NewBaseSignal(LinkSignalType, Direct),
Str: state, Str: state,
} }
} }
func NewIDStringSignal(signal_type SignalType, direction SignalDirection, state string, id NodeID) IDStringSignal { func NewLinkStartSignal(link_type string, target NodeID) Signal {
return IDStringSignal{ return &IDStringSignal{
BaseSignal: NewBaseSignal(signal_type, direction), NewBaseSignal(LinkStartSignalType, Direct),
NodeID: id, target,
Str: state, link_type,
} }
} }
func NewLinkStartSignal(link_type string, target NodeID) IDStringSignal { func NewLockSignal(state string) Signal {
return NewIDStringSignal(LinkStartSignalType, Direct, link_type, target) return &StringSignal{
} NewBaseSignal(LockSignalType, Direct),
state,
func NewLockSignal(state string) StringSignal {
return StringSignal{
BaseSignal: NewDirectSignal(LockSignalType),
Str: state,
} }
} }
@ -259,22 +255,22 @@ func (signal *AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction return AuthorizedSignalAction
} }
func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (AuthorizedSignal, error) { func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (Signal, error) {
sig_data, err := signal.Serialize() sig_data, err := signal.Serialize()
if err != nil { if err != nil {
return AuthorizedSignal{}, err return nil, err
} }
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil { if err != nil {
return AuthorizedSignal{}, err return nil, err
} }
return AuthorizedSignal{ return &AuthorizedSignal{
BaseSignal: NewDirectSignal(AuthorizedSignalType), NewBaseSignal(AuthorizedSignalType, Direct),
Principal: principal.Public().(ed25519.PublicKey), principal.Public().(ed25519.PublicKey),
Signal: signal, signal,
Signature: sig, sig,
}, nil }, nil
} }
@ -284,13 +280,13 @@ func (signal *ReadSignal) Serialize() ([]byte, error) {
func NewReadSignal(exts map[ExtType][]string) ReadSignal { func NewReadSignal(exts map[ExtType][]string) ReadSignal {
return ReadSignal{ return ReadSignal{
BaseSignal: NewDirectSignal(ReadSignalType), NewBaseSignal(ReadSignalType, Direct),
Extensions: exts, exts,
} }
} }
type ReadResultSignal struct { type ReadResultSignal struct {
BaseSignal RespSignal
NodeType NodeType
Extensions map[ExtType]map[string]interface{} `json:"extensions"` Extensions map[ExtType]map[string]interface{} `json:"extensions"`
} }
@ -299,15 +295,14 @@ func (signal *ReadResultSignal) Permission() Action {
return ReadResultSignalAction return ReadResultSignalAction
} }
func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal { func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal {
return ReadResultSignal{ return &ReadResultSignal{
BaseSignal: BaseSignal{ RespSignal{
Direct, NewBaseSignal(ReadResultSignalType, Direct),
ReadResultSignalType,
req_id, req_id,
}, },
NodeType: node_type, node_type,
Extensions: exts, exts,
} }
} }
@ -341,28 +336,28 @@ func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(signal) return json.Marshal(signal)
} }
func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) { func NewECDHReqSignal(node *Node) (Signal, *ecdh.PrivateKey, error) {
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return nil, nil, err
} }
now := time.Now() now := time.Now()
time_bytes, err := now.MarshalJSON() time_bytes, err := now.MarshalJSON()
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return nil, nil, err
} }
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return nil, nil, err
} }
return ECDHSignal{ return &ECDHSignal{
StringSignal: StringSignal{ StringSignal: StringSignal{
BaseSignal: NewDirectSignal(ECDHSignalType), BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "req", Str: "req",
}, },
Time: now, Time: now,
@ -374,7 +369,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e
const DEFAULT_ECDH_WINDOW = time.Second const DEFAULT_ECDH_WINDOW = time.Second
func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) { func NewECDHRespSignal(node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) {
now := time.Now() now := time.Now()
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW)
@ -382,7 +377,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
@ -406,7 +401,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [
return ECDHSignal{ return ECDHSignal{
StringSignal: StringSignal{ StringSignal: StringSignal{
BaseSignal: NewDirectSignal(ECDHSignalType), BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "resp", Str: "resp",
}, },
Time: now, Time: now,
@ -449,34 +444,34 @@ type ECDHProxySignal struct {
Data []byte Data []byte
} }
func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (ECDHProxySignal, error) { func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (Signal, error) {
if shared_secret == nil { if shared_secret == nil {
return ECDHProxySignal{}, fmt.Errorf("need shared_secret") return nil, fmt.Errorf("need shared_secret")
} }
aes_key, err := aes.NewCipher(shared_secret[:32]) aes_key, err := aes.NewCipher(shared_secret[:32])
if err != nil { if err != nil {
return ECDHProxySignal{}, err return nil, err
} }
ser, err := SerializeSignal(signal, aes_key.BlockSize()) ser, err := SerializeSignal(signal, aes_key.BlockSize())
if err != nil { if err != nil {
return ECDHProxySignal{}, err return nil, err
} }
iv := make([]byte, aes_key.BlockSize()) iv := make([]byte, aes_key.BlockSize())
n, err := rand.Reader.Read(iv) n, err := rand.Reader.Read(iv)
if err != nil { if err != nil {
return ECDHProxySignal{}, err return nil, err
} else if n != len(iv) { } else if n != len(iv) {
return ECDHProxySignal{}, fmt.Errorf("Not enough bytes read for IV") return nil, fmt.Errorf("Not enough bytes read for IV")
} }
encrypter := cipher.NewCBCEncrypter(aes_key, iv) encrypter := cipher.NewCBCEncrypter(aes_key, iv)
encrypter.CryptBlocks(ser, ser) encrypter.CryptBlocks(ser, ser)
return ECDHProxySignal{ return &ECDHProxySignal{
BaseSignal: NewDirectSignal(ECDHProxySignalType), BaseSignal: NewBaseSignal(ECDHProxySignalType, Direct),
Source: source, Source: source,
Dest: dest, Dest: dest,
IV: iv, IV: iv,

@ -48,7 +48,7 @@ func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error {
return err return err
} }
func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { func (ext *GroupExt) Process(ctx *Context, node *Node, msg Message) []Message {
return return nil
} }