Moved policies to node instead of an extension, need to fix gql tests

gql_cataclysm
noah metz 2023-08-07 20:26:02 -06:00
parent 1d91854f6f
commit 8770d6f433
13 changed files with 535 additions and 667 deletions

@ -48,9 +48,30 @@ func LoadExtension[T any, E interface {
return e, nil
}
type PolicyType string
func (policy PolicyType) Prefix() string { return "POLICY: " }
func (policy PolicyType) String() string { return string(policy) }
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return p, nil
}
type PolicyInfo struct {
Load PolicyLoadFunc
Type PolicyType
}
// ExtType and NodeType constants
const (
ACLExtType = ExtType("ACL")
ListenerExtType = ExtType("LISTENER")
LockableExtType = ExtType("LOCKABLE")
GQLExtType = ExtType("GQL")
@ -62,6 +83,7 @@ const (
var (
NodeNotFoundError = errors.New("Node not found in DB")
ECDH = ecdh.X25519()
)
type SignalLoadFunc func(*Context,[]byte) (Signal, error)
@ -107,12 +129,12 @@ type Context struct {
Log Logger
// Map between database extension hashes and the registered info
Extensions map[uint64]ExtensionInfo
// Map between databse policy hashes and the registered info
Policies map[uint64]PolicyInfo
// Map between serialized signal hashes and the registered info
Signals map[uint64]SignalInfo
// Map between database type hashes and the registered info
Types map[uint64]*NodeInfo
// Curve used for ecdh operations
ECDH ecdh.Curve
// Routing map to all the nodes local to this context
NodesLock sync.RWMutex
Nodes map[NodeID]*Node
@ -216,29 +238,32 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
// Stop every running loop
func (ctx *Context) Stop() {
for _, node := range(ctx.Nodes) {
node.MsgChan <- Msg{ZeroID, &StopSignal}
node.MsgChan <- Message{ZeroID, &StopSignal}
}
}
// Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error {
target, err := ctx.GetNode(dest)
func (ctx *Context) Send(source NodeID, messages []Message) error {
for _, msg := range(messages) {
target, err := ctx.GetNode(msg.NodeID)
if err == nil {
select {
case target.MsgChan <- Msg{source, signal}:
case target.MsgChan <- Message{source, msg.Signal}:
default:
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
stack_str := string(buf[:n])
return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str)
return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.NodeID, stack_str)
}
return nil
} else if errors.Is(err, NodeNotFoundError) {
// TODO: Handle finding nodes in other contexts
return err
}
} else {
return err
}
}
return nil
}
// Create a new Context with the base library content added
func NewContext(db * badger.DB, log Logger) (*Context, error) {
@ -249,15 +274,9 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
Types: map[uint64]*NodeInfo{},
Signals: map[uint64]SignalInfo{},
Nodes: map[NodeID]*Node{},
ECDH: ecdh.X25519(),
}
var err error
err = RegisterExtension[ACLExt,*ACLExt](ctx, NewACLExtContext())
if err != nil {
return nil, err
}
err = RegisterExtension[LockableExt,*LockableExt](ctx, nil)
if err != nil {
return nil, err
@ -299,7 +318,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err
}
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType})
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{GroupExtType, GQLExtType})
if err != nil {
return nil, err
}

@ -103,28 +103,30 @@ func (ext *ECDHExt) Field(name string) interface{} {
})
}
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal *ECDHSignal) {
func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) []Message {
source := KeyID(signal.EDDSA)
messages := []Message{}
switch signal.Str {
case "req":
state, exists := ext.ECDHStates[source]
if exists == false {
state = ECDHState{nil, nil}
}
resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal)
resp, shared_secret, err := NewECDHRespSignal(node, signal)
if err == nil {
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, &resp)
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
messages = append(messages, Message{source, &resp})
} else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response
log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), err.Error())})
}
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "no_req")})
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
@ -133,55 +135,23 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
state.SharedSecret = shared_secret
state.ECKey = nil
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
}
}
}
default:
ctx.Log.Logf("ecdh", "unknown echd state %s", signal.Str)
}
}
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
}
func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal *ECDHProxySignal) {
state, exists := ext.ECDHStates[source]
if exists == false {
resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
} else if state.SharedSecret == nil {
resp := NewErrorSignal(signal.ID(), "no_shared")
ctx.Send(node.ID, source, &resp)
} else {
unwrapped_signal, err := ParseECDHProxySignal(ctx, signal, state.SharedSecret)
if err != nil {
resp := NewErrorSignal(signal.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
} else {
//TODO: Figure out what I was trying to do here and fix it
ctx.Send(signal.Source, signal.Dest, unwrapped_signal)
}
log.Logf("ecdh", "unknown echd state %s", signal.Str)
}
return messages
}
func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
switch signal.Direction() {
case Direct:
switch signal.Type() {
case ECDHProxySignalType:
ecdh_signal := signal.(*ECDHProxySignal)
ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal)
case ECDHStateSignalType:
ecdh_signal := signal.(*StringSignal)
ext.HandleStateSignal(ctx, source, node, ecdh_signal)
func (ext *ECDHExt) Process(ctx *Context, node *Node, msg Message) []Message {
switch msg.Signal.Type() {
case ECDHSignalType:
ecdh_signal := signal.(*ECDHSignal)
ext.HandleECDHSignal(ctx, source, node, ecdh_signal)
default:
}
default:
sig := msg.Signal.(*ECDHSignal)
return ext.HandleECDHSignal(ctx.Log, node, sig)
}
return nil
}
func (ext *ECDHExt) Type() ExtType {

@ -17,9 +17,10 @@ import (
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"strings"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/ecdh"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/tls"
@ -189,15 +190,16 @@ type ResolveContext struct {
}
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
username, key_bytes, ok := r.BasicAuth()
id_bytes, key_bytes, ok := r.BasicAuth()
if ok == false {
return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header")
}
auth_id, err := ParseID(username)
auth_uuid, err := uuid.FromBytes([]byte(id_bytes))
if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username")
}
auth_id := NodeID(auth_uuid)
key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes))
if err != nil {
@ -916,7 +918,7 @@ func NewGQLExtContext() *GQLExtContext {
panic(err)
}
context.Mutation.AddFieldConfig("stopServer", &graphql.Field{
context.Mutation.AddFieldConfig("stop", &graphql.Field{
Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p)
@ -924,14 +926,13 @@ func NewGQLExtContext() *GQLExtContext {
return nil, err
}
sig := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}
err = Allowed(ctx.Context, ctx.User, sig.Permission(), ctx.Server)
sig, err := NewAuthorizedSignal(ctx.Key, &StopSignal)
if err != nil {
return err, nil
return nil, err
}
response_chan := ctx.Ext.GetResponseChannel(sig.ID())
err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, &sig)
err = ctx.Context.Send(ctx.Server.ID, []Message{Message{ctx.Server.ID, sig}})
if err != nil {
ctx.Ext.FreeResponseChannel(sig.ID())
return nil, err
@ -1016,8 +1017,8 @@ type GQLExt struct {
resolver_response_lock sync.RWMutex `json:"-"`
State string `json:"state"`
tls_key []byte `json:"tls_key"`
tls_cert []byte `json:"tls_cert"`
TLSKey []byte `json:"tls_key"`
TLSCert []byte `json:"tls_cert"`
Listen string `json:"listen"`
}
@ -1052,12 +1053,14 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
}
}
func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
func (ext *GQLExt) Process(ctx *Context, node *Node, msg Message) []Message {
// Process ReadResultSignalType by forwarding it to the waiting resolver
signal := msg.Signal
messages := []Message{}
if signal.Type() == ErrorSignalType {
// TODO: Forward to resolver if waiting for it
sig := signal.(*ErrorSignal)
response_chan := ext.FreeResponseChannel(sig.ID())
response_chan := ext.FreeResponseChannel(sig.UUID)
if response_chan != nil {
select {
case response_chan <- sig:
@ -1084,14 +1087,16 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
}
} else if signal.Type() == GQLStateSignalType {
sig := signal.(*StringSignal)
ctx.Log.Logf("gql", "GQL_STATE_SIGNAL: %s - %+v", node.ID, sig.Str)
switch sig.Str {
case "start_server":
if ext.State == "stopped" {
err := ext.StartGQLServer(ctx, node)
if err == nil {
ext.State = "running"
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}
ctx.Send(node.ID, source, &resp)
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_START_ERROR: %s", err)
}
}
case "stop_server":
@ -1099,8 +1104,9 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
err := ext.StopGQLServer()
if err == nil {
ext.State = "stopped"
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"}
ctx.Send(node.ID, source, &resp)
node.QueueSignal(time.Now(), NewStatusSignal("server_stopped", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err)
}
}
default:
@ -1112,14 +1118,16 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
case "running":
err := ext.StartGQLServer(ctx, node)
if err == nil {
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}
ctx.Send(node.ID, source, &resp)
node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID))
} else {
ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err)
}
case "stopped":
default:
ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State)
}
}
return messages
}
func (ext *GQLExt) Type() ExtType {
@ -1147,12 +1155,13 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
}
func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
ext.resolver_response = map[uuid.UUID]chan Signal{}
return json.Unmarshal(data, &ext)
}
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) {
if tls_cert == nil || tls_key == nil {
_, ssl_key, err := ed25519.GenerateKey(rand.Reader)
ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
@ -1194,8 +1203,8 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta
State: state,
Listen: listen,
resolver_response: map[uuid.UUID]chan Signal{},
tls_cert: tls_cert,
tls_key: tls_key,
TLSCert: tls_cert,
TLSKey: tls_key,
}, nil
}
@ -1224,7 +1233,7 @@ func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error {
return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr)
}
cert, err := tls.X509KeyPair(ext.tls_cert, ext.tls_key)
cert, err := tls.X509KeyPair(ext.TLSCert, ext.TLSKey)
if err != nil {
return err
}

@ -61,7 +61,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, id, &auth_signal)
err = ctx.Context.Send(ctx.Server.ID, []Message{Message{id, auth_signal}})
if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err
@ -79,11 +79,12 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
case *ReadResultSignal:
responses = append(responses, NodeResult{node_ids[sig_id], resp})
case *ErrorSignal:
return nil, fmt.Errorf(resp.Str)
return nil, fmt.Errorf(resp.Error)
default:
return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp))
}
}
ctx.Context.Log.Logf("gql", "RESOLVED_NODES")
return responses, nil
}

@ -11,30 +11,28 @@ import (
"crypto/tls"
"crypto/x509"
"bytes"
"github.com/google/uuid"
)
func TestGQL(t *testing.T) {
ctx := logTestContext(t, []string{})
ctx := logTestContext(t, []string{"gql", "lockable", "node_timeout", "listener"})
TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType})
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
fatalErr(t, err)
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped")
fatalErr(t, err)
listener_ext := NewListenerExt(10)
policy := NewAllNodesPolicy(Actions{MakeAction("+")})
start_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"}
gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{
QueuedSignal{uuid.New(), &start_signal, time.Now()},
}, NewLockableExt(), NewACLExt(&policy), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(&policy))
gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt())
err = LinkRequirement(ctx, gql.ID, n1.ID)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
err = ctx.Send(gql.ID, []Message{{gql.ID, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}}})
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_started"
})
fatalErr(t, err)
@ -86,8 +84,8 @@ func TestGQL(t *testing.T) {
resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
stop_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}
ctx.Send(n1.ID, gql.ID, &stop_signal)
stop_signal := StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}
ctx.Send(n1.ID, []Message{{gql.ID, &stop_signal}})
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
return sig.Str == "server_stopped"
})
@ -109,13 +107,10 @@ func TestGQLDB(t *testing.T) {
gql := NewNode(ctx, nil, GQLNodeType, 10, nil,
gql_ext,
listener_ext,
NewACLExt(),
NewGroupExt(nil))
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err = ctx.Send(gql.ID, gql.ID, &StopSignal)
fatalErr(t, err)
ctx.Stop()
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID
})
@ -134,7 +129,7 @@ func TestGQLDB(t *testing.T) {
fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded)
fatalErr(t, err)
err = ctx.Send(gql_loaded.ID, gql_loaded.ID, &StopSignal)
err = ctx.Send(gql_loaded.ID, []Message{{gql_loaded.ID, &StopSignal}})
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID

@ -9,7 +9,6 @@ import (
const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER")
func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
policy := NewAllNodesPolicy(Actions{MakeAction("status")})
listener_extension := NewListenerExt(buffer)
listener := NewNode(ctx,
nil,
@ -17,7 +16,6 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
10,
nil,
listener_extension,
NewACLExt(&policy),
NewLockableExt())
return listener, listener_extension
@ -32,7 +30,7 @@ func logTestContext(t * testing.T, components []string) *Context {
ctx, err := NewContext(db, NewConsoleLogger(components))
fatalErr(t, err)
err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType})
err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ListenerExtType, LockableExtType})
fatalErr(t, err)
return ctx

@ -7,14 +7,14 @@ import (
// A Listener extension provides a channel that can receive signals on a different thread
type ListenerExt struct {
Buffer int
Chan chan Signal
Chan chan Message
}
// Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{
Buffer: buffer,
Chan: make(chan Signal, buffer),
Chan: make(chan Message, buffer),
}
}
@ -32,7 +32,7 @@ func (ext *ListenerExt) Field(name string) interface{} {
// Simple load function, unmarshal the buffer int from json
func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error {
err := json.Unmarshal(data, &ext.Buffer)
ext.Chan = make(chan Signal, ext.Buffer)
ext.Chan = make(chan Message, ext.Buffer)
return err
}
@ -41,14 +41,14 @@ func (listener *ListenerExt) Type() ExtType {
}
// Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal)
func (ext *ListenerExt) Process(ctx *Context, node *Node, msg Message) []Message {
ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, msg.Signal)
select {
case ext.Chan <- signal:
case ext.Chan <- msg:
default:
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
}
return
return nil
}
// ReqState holds the multiple states of a requirement
@ -138,43 +138,38 @@ func NewLockableExt() *LockableExt {
// Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error {
lock_signal := NewLockSignal("unlock")
return ctx.Send(node.ID, node.ID, &lock_signal)
return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("unlock")}})
}
// Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error {
lock_signal := NewLockSignal("lock")
return ctx.Send(node.ID, node.ID, &lock_signal)
return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("lock")}})
}
// Setup a node to send the initial requirement link signal, then send the signal
func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error {
start_signal := NewLinkStartSignal("req", requirement)
return ctx.Send(dependency, dependency, &start_signal)
return ctx.Send(dependency, []Message{Message{dependency, NewLinkStartSignal("req", requirement)}})
}
// Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message {
state := signal.Str
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
messages := []Message{}
switch state {
case "unlock":
if ext.Owner == nil {
resp := NewErrorSignal(signal.ID(), "already_unlocked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocked")})
} else if source != *ext.Owner {
resp := NewErrorSignal(signal.ID(), "not_owner")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_owner")})
} else if ext.PendingOwner == nil {
resp := NewErrorSignal(signal.ID(), "already_unlocking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocking")})
} else {
if len(ext.Requirements) == 0 {
ext.Owner = nil
ext.PendingOwner = nil
resp := NewLockSignal("unlocked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewLockSignal("unlocked")})
} else {
ext.PendingOwner = nil
for id, state := range(ext.Requirements) {
@ -184,44 +179,36 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
}
state.Lock = "unlocking"
ext.Requirements[id] = state
resp := NewLockSignal("unlock")
ctx.Send(node.ID, id, &resp)
messages = append(messages, Message{id, NewLockSignal("unlock")})
}
}
if source != node.ID {
resp := NewLockSignal("unlocking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewLockSignal("unlocking")})
}
}
}
case "unlocking":
state, exists := ext.Requirements[source]
if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
} else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
} else if state.Lock != "unlocking" {
resp := NewErrorSignal(signal.ID(), "not_unlocking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")})
}
case "unlocked":
if source == node.ID {
return
return nil
}
state, exists := ext.Requirements[source]
if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
} else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
} else if state.Lock != "unlocking" {
resp := NewErrorSignal(signal.ID(), "not_unlocking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")})
} else {
state.Lock = "unlocked"
ext.Requirements[source] = state
@ -241,26 +228,22 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == unlocked {
previous_owner := *ext.Owner
ext.Owner = nil
resp := NewLockSignal("unlocked")
ctx.Send(node.ID, previous_owner, &resp)
messages = append(messages, Message{previous_owner, NewLockSignal("unlocked")})
}
}
}
case "locked":
if source == node.ID {
return
return nil
}
state, exists := ext.Requirements[source]
if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
} else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
} else if state.Lock != "locking" {
resp := NewErrorSignal(signal.ID(), "not_locking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")})
} else {
state.Lock = "locked"
ext.Requirements[source] = state
@ -279,176 +262,142 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == locked {
ext.Owner = ext.PendingOwner
resp := NewLockSignal("locked")
ctx.Send(node.ID, *ext.Owner, &resp)
messages = append(messages, Message{*ext.Owner, NewLockSignal("locked")})
}
}
}
case "locking":
state, exists := ext.Requirements[source]
if exists == false {
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")})
} else if state.Link != "linked" {
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")})
} else if state.Lock != "locking" {
resp := NewErrorSignal(signal.ID(), "not_locking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")})
}
case "lock":
if ext.Owner != nil {
resp := NewErrorSignal(signal.ID(), "already_locked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locked")})
} else if ext.PendingOwner != nil {
resp := NewErrorSignal(signal.ID(), "already_locking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locking")})
} else {
owner := source
if len(ext.Requirements) == 0 {
ext.Owner = &owner
ext.PendingOwner = ext.Owner
resp := NewLockSignal("locked")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewLockSignal("locked")})
} else {
ext.PendingOwner = &owner
for id, state := range(ext.Requirements) {
if state.Link == "linked" {
log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id)
if state.Lock != "unlocked" {
panic("NOT_UNLOCKED")
}
state.Lock = "locking"
ext.Requirements[id] = state
sub := NewLockSignal("lock")
ctx.Send(node.ID, id, &sub)
messages = append(messages, Message{id, NewLockSignal("lock")})
}
}
if source != node.ID {
resp := NewLockSignal("locking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewLockSignal("locking")})
}
}
}
default:
ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", state)
log.Logf("lockable", "LOCK_ERR: unkown state %s", state)
}
log.Logf("lockable", "LOCK_MESSAGES: %+v", messages)
return messages
}
func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal *IDStringSignal) {
ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal)
func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) []Message {
link_type := signal.Str
target := signal.NodeID
log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target)
messages := []Message{}
switch link_type {
case "req":
state, exists := ext.Requirements[target]
_, dep_exists := ext.Dependencies[target]
if ext.Owner != nil {
resp := NewLinkStartSignal("locked", target)
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already locked")})
} else if ext.Owner != ext.PendingOwner {
if ext.PendingOwner == nil {
resp := NewLinkStartSignal("unlocking", target)
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")})
} else {
resp := NewLinkStartSignal("locking", target)
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")})
}
} else if exists == true {
if state.Link == "linking" {
resp := NewErrorSignal(signal.ID(), "already_linking_req")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking_req")})
} else if state.Link == "linked" {
resp := NewErrorSignal(signal.ID(), "already_req")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_req")})
}
} else if dep_exists == true {
resp := NewLinkStartSignal("already_dep", target)
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_dep")})
} else {
ext.Requirements[target] = LinkState{"linking", "unlocked", source}
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, target, &resp)
notify := NewLinkStartSignal("linking_req", target)
ctx.Send(node.ID, source, &notify)
messages = append(messages, Message{target, NewLinkSignal("linked_as_req")})
messages = append(messages, Message{source, NewLinkStartSignal("linking_req", target)})
}
}
return messages
}
// Handle LinkSignal, updating the extensions requirements and dependencies as necessary
// TODO: Add unlink
func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message {
log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str
messages := []Message{}
switch state {
case "linked_as_dep":
case "dep_done":
state, exists := ext.Requirements[source]
if exists == true && state.Link == "linked" {
resp := NewLinkStartSignal("linked_as_req", source)
ctx.Send(node.ID, state.Initiator, &resp)
if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linking")})
} else if state.Link == "linking" {
state.Link = "linked"
ext.Requirements[source] = state
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil {
resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
} else {
resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
}
} else {
ext.Requirements[source] = LinkState{"linking", "unlocked", source}
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
log.Logf("lockable", "FINISHED_LINKING_REQ: %s->%s", node.ID, source)
}
ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source)
case "linked_as_req":
state, exists := ext.Dependencies[source]
if exists == true && state.Link == "linked" {
resp := NewLinkStartSignal("linked_as_dep", source)
ctx.Send(node.ID, state.Initiator, &resp)
if exists == false {
ext.Dependencies[source] = LinkState{"linked", "unlocked", source}
messages = append(messages, Message{source, NewLinkSignal("dep_done")})
} else if state.Link == "linking" {
state.Link = "linked"
ext.Dependencies[source] = state
resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking")})
} else if state.Link == "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linked")})
} else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil {
resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")})
} else {
resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")})
}
} else {
ext.Dependencies[source] = LinkState{"linking", "unlocked", source}
resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
}
ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source)
default:
ctx.Log.Logf("lockable", "LINK_ERROR: unknown state %s", state)
log.Logf("lockable", "LINK_ERROR: unknown state %s", state)
}
return messages
}
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) {
switch signal.Direction() {
func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message {
messages := []Message{}
switch msg.Signal.Direction() {
case Up:
ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies)
owner_sent := false
for dependency, state := range(ext.Dependencies) {
if state.Link == "linked" {
err := ctx.Send(node.ID, dependency, signal)
if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err)
}
messages = append(messages, Message{dependency, msg.Signal})
if ext.Owner != nil {
if dependency == *ext.Owner {
owner_sent = true
@ -459,32 +408,27 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal
if ext.Owner != nil && owner_sent == false {
if *ext.Owner != node.ID {
err := ctx.Send(node.ID, *ext.Owner, signal)
if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, *ext.Owner, err)
}
messages = append(messages, Message{*ext.Owner, msg.Signal})
}
}
case Down:
for requirement, state := range(ext.Requirements) {
if state.Link == "linked" {
err := ctx.Send(node.ID, requirement, signal)
if err != nil {
ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err)
}
messages = append(messages, Message{requirement, msg.Signal})
}
}
case Direct:
switch signal.Type() {
switch msg.Signal.Type() {
case LinkSignalType:
ext.HandleLinkSignal(ctx, source, node, signal.(*StringSignal))
messages = ext.HandleLinkSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal))
case LockSignalType:
ext.HandleLockSignal(ctx, source, node, signal.(*StringSignal))
messages = ext.HandleLockSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal))
case LinkStartSignalType:
ext.HandleLinkStartSignal(ctx, source, node, signal.(*IDStringSignal))
messages = ext.HandleLinkStartSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*IDStringSignal))
default:
}
default:
}
return messages
}

@ -9,7 +9,7 @@ const TestLockableType = NodeType("TEST_LOCKABLE")
func lockableTestContext(t *testing.T, logs []string) *Context {
ctx := logTestContext(t, logs)
err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType})
err := ctx.RegisterNodeType(TestLockableType, []ExtType{LockableExtType})
fatalErr(t, err)
return ctx
@ -26,13 +26,11 @@ func TestLink(t *testing.T) {
l1_listener := NewListenerExt(10)
l1 := NewNode(ctx, nil, TestLockableType, 10, nil,
l1_listener,
NewACLExt(&link_policy),
NewLockableExt(),
)
l2_listener := NewListenerExt(10)
l2 := NewNode(ctx, nil, TestLockableType, 10, nil,
l2_listener,
NewACLExt(&link_policy),
NewLockableExt(),
)
@ -40,13 +38,13 @@ func TestLink(t *testing.T) {
err := LinkRequirement(ctx, l1.ID, l2.ID)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req"
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool {
return sig.Str == "dep_done"
})
fatalErr(t, err)
sig1 := NewStatusSignal("TEST", l2.ID)
err = ctx.Send(l2.ID, l2.ID, &sig1)
err = ctx.Send(l2.ID, []Message{{l2.ID, sig1}})
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
@ -65,7 +63,6 @@ func TestLink10K(t *testing.T) {
NewLockable := func()(*Node) {
l := NewNode(ctx, nil, TestLockableType, 10, nil,
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(),
)
return l
@ -75,7 +72,6 @@ func TestLink10K(t *testing.T) {
listener := NewListenerExt(100000)
l := NewNode(ctx, nil, TestLockableType, 256, nil,
listener,
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(),
)
return l, listener
@ -92,8 +88,8 @@ func TestLink10K(t *testing.T) {
for range(lockables) {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req"
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool {
return sig.Str == "dep_done"
})
fatalErr(t, err)
}
@ -102,13 +98,12 @@ func TestLink10K(t *testing.T) {
}
func TestLock(t *testing.T) {
ctx := lockableTestContext(t, []string{})
ctx := lockableTestContext(t, []string{"lockable", "listener"})
NewLockable := func()(*Node, *ListenerExt) {
listener := NewListenerExt(100)
l := NewNode(ctx, nil, TestLockableType, 10, nil,
listener,
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(),
)
return l, listener
@ -141,30 +136,30 @@ func TestLock(t *testing.T) {
err = LinkRequirement(ctx, l0.ID, l5.ID)
fatalErr(t, err)
linked_as_req := func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req"
linked_as_req := func(sig *StringSignal) bool {
return sig.Str == "dep_done"
}
locked := func(sig *StringSignal) bool {
return sig.Str == "locked"
}
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err)
err = LockLockable(ctx, l1)

@ -20,9 +20,10 @@ const (
// Magic first four bytes of serialized DB content, stored big endian
NODE_DB_MAGIC = 0x2491df14
// Total length of the node database header, has magic to verify and type_hash to map to load function
NODE_DB_HEADER_LEN = 28
NODE_DB_HEADER_LEN = 32
EXTENSION_DB_HEADER_LEN = 16
QSIGNAL_DB_HEADER_LEN = 40
POLICY_DB_HEADER_LEN = 16
)
var (
@ -82,7 +83,7 @@ func RandID() NodeID {
return NodeID(uuid.New())
}
// A Serializable has a type that can be used to map to it, and a function to serialize the current state
// A Serializable has a type that can be used to map to it, and a function to serialize` the current state
type Serializable[I comparable] interface {
Serialize()([]byte,error)
Deserialize(*Context,[]byte)error
@ -93,7 +94,7 @@ type Serializable[I comparable] interface {
type Extension interface {
Serializable[ExtType]
Field(string)interface{}
Process(context *Context, source NodeID, node *Node, signal Signal)
Process(ctx *Context, node *Node, message Message)[]Message
}
// A QueuedSignal is a Signal that has been Queued to trigger at a set time
@ -110,9 +111,10 @@ type Node struct {
ID NodeID
Type NodeType
Extensions map[ExtType]Extension
Policies map[PolicyType]Policy
// Channel for this node to receive messages from the Context
MsgChan chan Msg
MsgChan chan Message
// Size of MsgChan
BufferSize uint32
// Channel for this node to process delayed signals
@ -124,6 +126,18 @@ type Node struct {
NextSignal *QueuedSignal
}
func (node *Node) Allows(principal_id NodeID, action Action) error {
errs := []error{}
for _, policy := range(node.Policies) {
err := policy.Allows(principal_id, action, node)
if err == nil {
return nil
}
errs = append(errs, err)
}
return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs)
}
func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID {
id := uuid.New()
node.SignalQueue = append(node.SignalQueue, QueuedSignal{id, signal, time})
@ -163,17 +177,12 @@ func runNode(ctx *Context, node *Node) {
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
}
type Msg struct {
Source NodeID
Signal Signal
}
func ReadNodeFields(ctx *Context, self *Node, princ NodeID, reqs map[ExtType][]string)map[ExtType]map[string]interface{} {
exts := map[ExtType]map[string]interface{}{}
for ext_type, field_reqs := range(reqs) {
fields := map[string]interface{}{}
for _, req := range(field_reqs) {
err := Allowed(ctx, princ, MakeAction(ReadSignalType, ext_type, req), self)
err := self.Allows(princ, MakeAction(ReadSignalType, ext_type, req))
if err != nil {
fields[req] = err
} else {
@ -198,27 +207,18 @@ func nodeLoop(ctx *Context, node *Node) error {
}
// Perform startup actions
node.Process(ctx, node.ID, &StartSignal)
node.Process(ctx, Message{ZeroID, &StartSignal})
for true {
var signal Signal
var source NodeID
var msg Message
select {
case msg := <- node.MsgChan:
ctx.Log.Logf("signal", "NODE_MSG: %s - %+v", node.ID, msg)
signal = msg.Signal
source = msg.Source
err := Allowed(ctx, msg.Source, signal.Permission(), node)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err)
resp := NewErrorSignal(msg.Signal.ID(), err.Error())
ctx.Send(node.ID, msg.Source, &resp)
continue
}
case msg = <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal.Type())
case <-node.TimeoutChan:
signal = node.NextSignal.Signal
signal := node.NextSignal.Signal
msg = Message{node.ID, signal}
t := node.NextSignal.Time
source = node.ID
i := -1
for j, queued := range(node.SignalQueue) {
if queued.UUID == node.NextSignal.UUID {
@ -235,17 +235,17 @@ func nodeLoop(ctx *Context, node *Node) error {
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
if node.NextSignal == nil {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, t, signal, node.TimeoutChan)
ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan)
} else {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, t, signal, node.NextSignal, node.NextSignal.Time)
ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time)
}
}
// Unwrap Authorized Signals
if signal.Type() == AuthorizedSignalType {
sig, ok := signal.(*AuthorizedSignal)
if msg.Signal.Type() == AuthorizedSignalType {
sig, ok := msg.Signal.(*AuthorizedSignal)
if ok == false {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal))
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(msg.Signal))
} else {
// Validate
sig_data, err := sig.Signal.Serialize()
@ -253,45 +253,40 @@ func nodeLoop(ctx *Context, node *Node) error {
} else {
validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature)
if validated == true {
err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node)
err := node.Allows(KeyID(sig.Principal), sig.Signal.Permission())
if err != nil {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err)
resp := NewErrorSignal(sig.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), err.Error())}})
} else {
// Unwrap the signal without changing the source
signal = sig.Signal
msg = Message{msg.NodeID, sig.Signal}
}
} else {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate")
resp := NewErrorSignal(sig.ID(), "signature validation failed")
ctx.Send(node.ID, source, &resp)
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), "signature validation failed")}})
}
}
}
}
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
// Handle special signal types
if signal.Type() == StopSignalType {
resp := NewErrorSignal(signal.ID(), "stopped")
ctx.Send(node.ID, source, &resp)
status := NewStatusSignal("stopped", node.ID)
node.Process(ctx, node.ID, &status)
if msg.Signal.Type() == StopSignalType {
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(msg.Signal.ID(), "stopped")}})
node.Process(ctx, Message{node.ID, NewStatusSignal("stopped", node.ID)})
break
} else if signal.Type() == ReadSignalType {
read_signal, ok := signal.(*ReadSignal)
} else if msg.Signal.Type() == ReadSignalType {
read_signal, ok := msg.Signal.(*ReadSignal)
if ok == false {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", msg.Signal)
} else {
result := ReadNodeFields(ctx, node, source, read_signal.Extensions)
resp := NewReadResultSignal(read_signal.ID(), node.Type, result)
ctx.Send(node.ID, source, &resp)
result := ReadNodeFields(ctx, node, msg.NodeID, read_signal.Extensions)
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewReadResultSignal(read_signal.ID(), node.Type, result)}})
}
}
node.Process(ctx, source, signal)
node.Process(ctx, msg)
// assume that processing a signal means that this nodes state changed
// TODO: remove a lot of database writes by only writing when things change,
// so need to have Process return whether or not state changed
@ -308,13 +303,26 @@ func nodeLoop(ctx *Context, node *Node) error {
return nil
}
func (node *Node) Process(ctx *Context, source NodeID, signal Signal) {
type Message struct {
NodeID
Signal
}
func (node *Node) Process(ctx *Context, message Message) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, message.Signal.Type())
messages := []Message{}
for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("signal", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
ext.Process(ctx, source, node, signal)
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
//TODO: add extension and node info to log
resp := ext.Process(ctx, node, message)
if resp != nil {
messages = append(messages, resp...)
}
}
return ctx.Send(node.ID, messages)
}
func GetCtx[T Extension, C any](ctx *Context) (C, error) {
var zero T
var zero_ctx C
@ -352,6 +360,7 @@ func GetExt[T Extension](node *Node) (T, error) {
func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.Extensions))
qsignals := make([]QSignalDB, len(node.SignalQueue))
policies := make([]PolicyDB, len(node.Policies))
key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key)
if err != nil {
@ -365,6 +374,7 @@ func (node *Node) Serialize() ([]byte, error) {
KeyLength: uint32(len(key_bytes)),
BufferSize: node.BufferSize,
NumExtensions: uint32(len(extensions)),
NumPolicies: uint32(len(policies)),
NumQueuedSignals: uint32(len(node.SignalQueue)),
},
Extensions: extensions,
@ -405,6 +415,22 @@ func (node *Node) Serialize() ([]byte, error) {
}
}
i = 0
for _, policy := range(node.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
node_db.Policies[i] = PolicyDB{
PolicyDBHeader{
Hash(policy.Type()),
uint64(len(ser)),
},
ser,
}
}
return node_db.Serialize(), nil
}
@ -415,7 +441,7 @@ func KeyID(pub ed25519.PublicKey) NodeID {
// Create a new node in memory and start it's event loop
// TODO: Change panics to errors
func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node {
func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies map[PolicyType]Policy, extensions ...Extension) *Node {
var err error
var public ed25519.PublicKey
if key == nil {
@ -453,22 +479,19 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
}
}
if queued_signals == nil {
queued_signals = []QueuedSignal{}
if policies == nil {
policies = map[PolicyType]Policy{}
}
next_signal, timeout_chan := SoonestSignal(queued_signals)
node := &Node{
Key: key,
ID: id,
Type: node_type,
Extensions: ext_map,
MsgChan: make(chan Msg, buffer_size),
Policies: policies,
MsgChan: make(chan Message, buffer_size),
BufferSize: buffer_size,
TimeoutChan: timeout_chan,
SignalQueue: queued_signals,
NextSignal: next_signal,
SignalQueue: []QueuedSignal{},
}
ctx.AddNode(id, node)
err = WriteNode(ctx, node)
@ -476,41 +499,50 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic(err)
}
node.Process(ctx, node.ID, &NewSignal)
node.Process(ctx, Message{node.ID, &NewSignal})
go runNode(ctx, node)
return node
}
func Allowed(ctx *Context, principal_id NodeID, action Action, node *Node) error {
ctx.Log.Logf("policy", "POLICY_CHECK: %s -> %s.%s", principal_id, node.ID, action)
// Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension
if principal_id == node.ID {
ctx.Log.Logf("policy", "POLICY_CHECK_SAME_NODE: %s.%s", principal_id, action)
return nil
type PolicyDBHeader struct {
TypeHash uint64
Length uint64
}
// Check if the node has a policy extension itself, and check against the policies in it
policy_ext, err := GetExt[*ACLExt](node)
if err != nil {
ctx.Log.Logf("policy", "POLICY_CHECK_NO_ACL_EXT: %s", node.ID)
return err
type PolicyDB struct {
Header PolicyDBHeader
Data []byte
}
err = policy_ext.Allows(ctx, principal_id, action, node)
if err != nil {
ctx.Log.Logf("policy", "POLICY_CHECK_FAIL: %s -> %s.%s : %s", principal_id, node.ID, action, err)
} else {
ctx.Log.Logf("policy", "POLICY_CHECK_PASS: %s -> %s.%s", principal_id, node.ID, action)
type QSignalDBHeader struct {
SignalID uuid.UUID
Time time.Time
TypeHash uint64
Length uint64
}
return err
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
}
// A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node
type NodeDBHeader struct {
Magic uint32
NumExtensions uint32
NumPolicies uint32
NumQueuedSignals uint32
BufferSize uint32
KeyLength uint32
@ -519,8 +551,9 @@ type NodeDBHeader struct {
type NodeDB struct {
Header NodeDBHeader
QueuedSignals []QSignalDB
Extensions []ExtensionDB
Policies []PolicyDB
QueuedSignals []QSignalDB
KeyBytes []byte
}
@ -532,10 +565,11 @@ func NewNodeDB(data []byte) (NodeDB, error) {
magic := binary.BigEndian.Uint32(data[0:4])
num_extensions := binary.BigEndian.Uint32(data[4:8])
num_queued_signals := binary.BigEndian.Uint32(data[8:12])
buffer_size := binary.BigEndian.Uint32(data[12:16])
key_length := binary.BigEndian.Uint32(data[16:20])
node_type_hash := binary.BigEndian.Uint64(data[20:28])
num_policies := binary.BigEndian.Uint32(data[8:12])
num_queued_signals := binary.BigEndian.Uint32(data[12:16])
buffer_size := binary.BigEndian.Uint32(data[16:20])
key_length := binary.BigEndian.Uint32(data[20:24])
node_type_hash := binary.BigEndian.Uint64(data[24:32])
ptr += NODE_DB_HEADER_LEN
@ -573,6 +607,26 @@ func NewNodeDB(data []byte) (NodeDB, error) {
ptr += int(EXTENSION_DB_HEADER_LEN + length)
}
policies := make([]PolicyDB, num_policies)
for i, _ := range(policies) {
cur := data[ptr:]
type_hash := binary.BigEndian.Uint64(cur[0:8])
length := binary.BigEndian.Uint64(cur[8:16])
data_start := uint64(POLICY_DB_HEADER_LEN)
data_end := data_start + length
policy_data := cur[data_start:data_end]
policies[i] = PolicyDB{
PolicyDBHeader{
type_hash,
length,
},
policy_data,
}
ptr += int(POLICY_DB_HEADER_LEN + length)
}
queued_signals := make([]QSignalDB, num_queued_signals)
for i, _ := range(queued_signals) {
cur := data[ptr:]
@ -626,10 +680,11 @@ func (header NodeDBHeader) Serialize() []byte {
ret := make([]byte, NODE_DB_HEADER_LEN)
binary.BigEndian.PutUint32(ret[0:4], header.Magic)
binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions)
binary.BigEndian.PutUint32(ret[8:12], header.NumQueuedSignals)
binary.BigEndian.PutUint32(ret[12:16], header.BufferSize)
binary.BigEndian.PutUint32(ret[16:20], header.KeyLength)
binary.BigEndian.PutUint64(ret[20:28], header.TypeHash)
binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies)
binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals)
binary.BigEndian.PutUint32(ret[16:20], header.BufferSize)
binary.BigEndian.PutUint32(ret[20:24], header.KeyLength)
binary.BigEndian.PutUint64(ret[24:32], header.TypeHash)
return ret
}
@ -673,28 +728,6 @@ func (extension ExtensionDB) Serialize() []byte {
return append(header_bytes, extension.Data...)
}
type QSignalDBHeader struct {
SignalID uuid.UUID
Time time.Time
TypeHash uint64
Length uint64
}
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct {
TypeHash uint64
Length uint64
}
type ExtensionDB struct {
Header ExtensionDBHeader
Data []byte
}
// Write a node to the database
func WriteNode(ctx *Context, node *Node) error {
ctx.Log.Logf("db", "DB_WRITE: %s", node.ID)
@ -740,6 +773,21 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err
}
policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies)
for _, policy_db := range(node_db.Policies) {
policy_info, exists := ctx.Policies[policy_db.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash)
}
policy, err := policy_info.Load(ctx, policy_db.Data)
if err != nil {
return nil, err
}
policies[policy_info.Type] = policy
}
key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes)
if err != nil {
return nil, err
@ -784,7 +832,8 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
ID: key_id,
Type: node_type.Type,
Extensions: map[ExtType]Extension{},
MsgChan: make(chan Msg, node_db.Header.BufferSize),
Policies: policies,
MsgChan: make(chan Message, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan,
SignalQueue: signal_queue,

@ -23,7 +23,7 @@ func TestNodeDB(t *testing.T) {
func TestNodeRead(t *testing.T) {
ctx := logTestContext(t, []string{})
node_type := NodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType})
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType})
fatalErr(t, err)
n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader)
@ -37,21 +37,15 @@ func TestNodeRead(t *testing.T) {
ctx.Log.Logf("test", "N1: %s", n1_id)
ctx.Log.Logf("test", "N2: %s", n2_id)
n2_policy := NewPerNodePolicy(map[NodeID]Actions{
n1_id: Actions{MakeAction(ReadResultSignalType, "+")},
})
n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(&n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener)
n1_policy := NewPerNodePolicy(map[NodeID]Actions{
n2_id: Actions{MakeAction(ReadSignalType, "+")},
})
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(&n1_policy), NewGroupExt(nil), NewECDHExt())
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt())
read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"},
})
ctx.Send(n2.ID, n1.ID, &read_sig)
ctx.Send(n2.ID, []Message{{n1.ID, &read_sig}})
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool {
return true
@ -64,22 +58,20 @@ func TestECDH(t *testing.T) {
ctx := logTestContext(t, []string{"test", "ecdh", "policy"})
node_type := NodeType("TEST")
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, ECDHExtType})
err := ctx.RegisterNodeType(node_type, []ExtType{ECDHExtType})
fatalErr(t, err)
n1_listener := NewListenerExt(10)
ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")})
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt())
n1 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt())
n3_listener := NewListenerExt(10)
n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}})
n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy, &n3_policy), NewECDHExt(), n3_listener)
n3 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n3_listener)
ctx.Log.Logf("test", "N1: %s", n1.ID)
ctx.Log.Logf("test", "N2: %s", n2.ID)
ecdh_req, n1_ec, err := NewECDHReqSignal(ctx, n1)
ecdh_req, n1_ec, err := NewECDHReqSignal(n1)
ecdh_ext, err := GetExt[*ECDHExt](n1)
fatalErr(t, err)
ecdh_ext.ECDHStates[n2.ID] = ECDHState{
@ -88,7 +80,7 @@ func TestECDH(t *testing.T) {
}
fatalErr(t, err)
ctx.Log.Logf("test", "N1_EC: %+v", n1_ec)
err = ctx.Send(n1.ID, n2.ID, &ecdh_req)
err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_req}})
fatalErr(t, err)
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool {
@ -100,6 +92,6 @@ func TestECDH(t *testing.T) {
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret)
fatalErr(t, err)
err = ctx.Send(n1.ID, n2.ID, &ecdh_sig)
err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_sig}})
fatalErr(t, err)
}

@ -5,11 +5,8 @@ import (
"fmt"
)
type PolicyType string
func (policy PolicyType) Prefix() string { return "POLICY: " }
func (policy PolicyType) String() string { return string(policy) }
const (
UserOfPolicyType = PolicyType("USER_OF")
RequirementOfPolicyType = PolicyType("REQUIREMENT_OF")
PerNodePolicyType = PolicyType("PER_NODE")
AllNodesPolicyType = PolicyType("ALL_NODES")
@ -38,7 +35,7 @@ func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Nod
return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID)
}
func (policy RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
lockable_ext, err := GetExt[*LockableExt](node)
if err != nil {
return err
@ -53,10 +50,45 @@ func (policy RequirementOfPolicy) Allows(principal_id NodeID, action Action, nod
return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID)
}
type UserOfPolicy struct {
PerNodePolicy
}
func (policy *UserOfPolicy) Type() PolicyType {
return UserOfPolicyType
}
func NewUserOfPolicy(group_actions NodeActions) UserOfPolicy {
return UserOfPolicy{
PerNodePolicy: NewPerNodePolicy(group_actions),
}
}
// Send a read signal to Group to check if principal_id is a member of it
func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
// Send a read signal to each of the groups in the map
// Check for principal_id in any of the returned member lists(skipping errors)
// Return an error in the default case
return fmt.Errorf("NOT_IMPLEMENTED")
}
func (policy *UserOfPolicy) Merge(p Policy) Policy {
other := p.(*UserOfPolicy)
policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions)
return policy
}
func (policy *UserOfPolicy) Copy() Policy {
new_actions := CopyNodeActions(policy.NodeActions)
return &UserOfPolicy{
PerNodePolicy: NewPerNodePolicy(new_actions),
}
}
type RequirementOfPolicy struct {
AllNodesPolicy
}
func (policy RequirementOfPolicy) Type() PolicyType {
func (policy *RequirementOfPolicy) Type() PolicyType {
return RequirementOfPolicyType
}
@ -82,20 +114,25 @@ func CopyNodeActions(actions NodeActions) NodeActions {
return ret
}
func MergeNodeActions(modified NodeActions, read NodeActions) {
for id, actions := range(read) {
existing, exists := modified[id]
func MergeNodeActions(first NodeActions, second NodeActions) NodeActions {
merged := NodeActions{}
for id, actions := range(first) {
merged[id] = actions
}
for id, actions := range(second) {
existing, exists := merged[id]
if exists {
modified[id] = MergeActions(existing, actions)
merged[id] = MergeActions(existing, actions)
} else {
modified[id] = actions
merged[id] = actions
}
}
return merged
}
func (policy *PerNodePolicy) Merge(p Policy) Policy {
other := p.(*PerNodePolicy)
MergeNodeActions(policy.NodeActions, other.NodeActions)
policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions)
return policy
}
@ -263,63 +300,6 @@ func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
// Extension to allow a node to hold ACL policies
type ACLExt struct {
Policies map[PolicyType]Policy
}
func NodeList(nodes ...*Node) NodeMap {
m := NodeMap{}
for _, node := range(nodes) {
m[node.ID] = node
}
return m
}
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
type ACLExtContext struct {
Loads map[PolicyType]PolicyLoadFunc
}
func NewACLExtContext() *ACLExtContext {
return &ACLExtContext{
Loads: map[PolicyType]PolicyLoadFunc{
AllNodesPolicyType: LoadPolicy[AllNodesPolicy,*AllNodesPolicy],
PerNodePolicyType: LoadPolicy[PerNodePolicy,*PerNodePolicy],
RequirementOfPolicyType: LoadPolicy[RequirementOfPolicy,*RequirementOfPolicy],
},
}
}
func (ext *ACLExt) Serialize() ([]byte, error) {
policies := map[string][]byte{}
for name, policy := range(ext.Policies) {
ser, err := policy.Serialize()
if err != nil {
return nil, err
}
policies[string(name)] = ser
}
return json.MarshalIndent(&struct{
Policies map[string][]byte `json:"policies"`
}{
Policies: policies,
}, "", " ")
}
func (ext *ACLExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
}
func (ext *ACLExt) Field(name string) interface{} {
return ResolveFields(ext, name, map[string]func(*ACLExt)interface{}{
"policies": func(ext *ACLExt) interface{} {
return ext.Policies
},
})
}
var ErrorSignalAction = Action{"ERROR_RESP"}
var ReadResultSignalAction = Action{"READ_RESULT"}
var AuthorizedSignalAction = Action{"AUTHORIZED_READ"}
@ -327,82 +307,3 @@ var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSigna
var DefaultACLPolicies = []Policy{
&defaultPolicy,
}
func NewACLExt(policies ...Policy) *ACLExt {
policy_map := map[PolicyType]Policy{}
for _, policy_arg := range(append(policies, DefaultACLPolicies...)) {
policy := policy_arg.Copy()
existing, exists := policy_map[policy.Type()]
if exists == true {
policy = existing.Merge(policy)
}
policy_map[policy.Type()] = policy
}
return &ACLExt{
Policies: policy_map,
}
}
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return p, nil
}
func (ext *ACLExt) Deserialize(ctx *Context, data []byte) error {
var j struct {
Policies map[string][]byte `json:"policies"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return err
}
acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx)
if err != nil {
return err
}
ext.Policies = map[PolicyType]Policy{}
for name, ser := range(j.Policies) {
policy_load, exists := acl_ctx.Loads[PolicyType(name)]
if exists == false {
return fmt.Errorf("%s is not a known policy type", name)
}
policy, err := policy_load(ctx, ser)
if err != nil {
return err
}
ext.Policies[PolicyType(name)] = policy
}
return nil
}
func (ext *ACLExt) Type() ExtType {
return ACLExtType
}
// Check if the extension allows the principal to perform action on node
func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action Action, node *Node) error {
errs := []error{}
for _, policy := range(ext.Policies) {
err := policy.Allows(principal_id, action, node)
if err == nil {
return nil
}
errs = append(errs, err)
}
return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs)
}

@ -28,7 +28,6 @@ const (
ReadResultSignalType = "READ_RESULT"
LinkStartSignalType = "LINK_START"
ECDHSignalType = "ECDH"
ECDHStateSignalType = "ECDH_STATE"
ECDHProxySignalType = "ECDH_PROXY"
GQLStateSignalType = "GQL_STATE"
@ -43,6 +42,7 @@ func (signal_type SignalType) Prefix() string { return "SIGNAL: " }
type Signal interface {
Serializable[SignalType]
String() string
Direction() SignalDirection
ID() uuid.UUID
Permission() Action
@ -70,12 +70,12 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.
}
for true {
select {
case signal := <- listener.Chan:
if signal == nil {
case msg := <- listener.Chan:
if msg.Signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type)
}
if signal.Type() == signal_type {
sig, ok := signal.(S)
if msg.Signal.Type() == signal_type {
sig, ok := msg.Signal.(S)
if ok == true {
if check(sig) == true {
return sig, nil
@ -96,6 +96,11 @@ type BaseSignal struct {
UUID uuid.UUID `json:"id"`
}
func (signal *BaseSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, signal)
}
@ -129,21 +134,9 @@ func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal
return signal
}
func NewDownSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Down)
}
func NewUpSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Up)
}
func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Direct)
}
var NewSignal = NewDirectSignal(NewSignalType)
var StartSignal = NewDirectSignal(StartSignalType)
var StopSignal = NewDownSignal(StopSignalType)
var NewSignal = NewBaseSignal(NewSignalType, Direct)
var StartSignal = NewBaseSignal(StartSignalType, Direct)
var StopSignal = NewBaseSignal(StopSignalType, Direct)
type IDSignal struct {
BaseSignal
@ -154,88 +147,91 @@ func (signal *IDSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
return IDSignal{
BaseSignal: NewBaseSignal(signal_type, direction),
NodeID: id,
}
}
type StringSignal struct {
BaseSignal
Str string `json:"state"`
}
func (signal *StringSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *StringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
type RespSignal struct {
BaseSignal
ReqID uuid.UUID
}
type ErrorSignal struct {
StringSignal
RespSignal
Error string
}
func (signal *ErrorSignal) String() string {
ser, _ := json.Marshal(signal)
return string(ser)
}
func (signal *ErrorSignal) Permission() Action {
return ErrorSignalAction
}
func NewErrorSignal(req_id uuid.UUID, err string) ErrorSignal {
return ErrorSignal{
StringSignal{
NewDirectSignal(ErrorSignalType),
err,
func NewErrorSignal(req_id uuid.UUID, err string) Signal {
return &ErrorSignal{
RespSignal{
NewBaseSignal(ErrorSignalType, Direct),
req_id,
},
err,
}
}
type IDStringSignal struct {
BaseSignal
NodeID `json:"node_id"`
NodeID NodeID `json:"node_id"`
Str string `json:"string"`
}
func (signal *IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func (signal *IDStringSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
ser, _ := json.Marshal(signal)
return string(ser)
}
func NewStatusSignal(status string, source NodeID) IDStringSignal {
return IDStringSignal{
BaseSignal: NewUpSignal(StatusSignalType),
func (signal *IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func NewStatusSignal(status string, source NodeID) Signal {
return &IDStringSignal{
BaseSignal: NewBaseSignal(StatusSignalType, Up),
NodeID: source,
Str: status,
}
}
func NewLinkSignal(state string) StringSignal {
return StringSignal{
BaseSignal: NewDirectSignal(LinkSignalType),
func NewLinkSignal(state string) Signal {
return &StringSignal{
BaseSignal: NewBaseSignal(LinkSignalType, Direct),
Str: state,
}
}
func NewIDStringSignal(signal_type SignalType, direction SignalDirection, state string, id NodeID) IDStringSignal {
return IDStringSignal{
BaseSignal: NewBaseSignal(signal_type, direction),
NodeID: id,
Str: state,
func NewLinkStartSignal(link_type string, target NodeID) Signal {
return &IDStringSignal{
NewBaseSignal(LinkStartSignalType, Direct),
target,
link_type,
}
}
func NewLinkStartSignal(link_type string, target NodeID) IDStringSignal {
return NewIDStringSignal(LinkStartSignalType, Direct, link_type, target)
}
func NewLockSignal(state string) StringSignal {
return StringSignal{
BaseSignal: NewDirectSignal(LockSignalType),
Str: state,
func NewLockSignal(state string) Signal {
return &StringSignal{
NewBaseSignal(LockSignalType, Direct),
state,
}
}
@ -259,22 +255,22 @@ func (signal *AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction
}
func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (AuthorizedSignal, error) {
func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (Signal, error) {
sig_data, err := signal.Serialize()
if err != nil {
return AuthorizedSignal{}, err
return nil, err
}
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return AuthorizedSignal{}, err
return nil, err
}
return AuthorizedSignal{
BaseSignal: NewDirectSignal(AuthorizedSignalType),
Principal: principal.Public().(ed25519.PublicKey),
Signal: signal,
Signature: sig,
return &AuthorizedSignal{
NewBaseSignal(AuthorizedSignalType, Direct),
principal.Public().(ed25519.PublicKey),
signal,
sig,
}, nil
}
@ -284,13 +280,13 @@ func (signal *ReadSignal) Serialize() ([]byte, error) {
func NewReadSignal(exts map[ExtType][]string) ReadSignal {
return ReadSignal{
BaseSignal: NewDirectSignal(ReadSignalType),
Extensions: exts,
NewBaseSignal(ReadSignalType, Direct),
exts,
}
}
type ReadResultSignal struct {
BaseSignal
RespSignal
NodeType
Extensions map[ExtType]map[string]interface{} `json:"extensions"`
}
@ -299,15 +295,14 @@ func (signal *ReadResultSignal) Permission() Action {
return ReadResultSignalAction
}
func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal {
return ReadResultSignal{
BaseSignal: BaseSignal{
Direct,
ReadResultSignalType,
func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal {
return &ReadResultSignal{
RespSignal{
NewBaseSignal(ReadResultSignalType, Direct),
req_id,
},
NodeType: node_type,
Extensions: exts,
node_type,
exts,
}
}
@ -341,28 +336,28 @@ func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) {
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader)
func NewECDHReqSignal(node *Node) (Signal, *ecdh.PrivateKey, error) {
ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil {
return ECDHSignal{}, nil, err
return nil, nil, err
}
now := time.Now()
time_bytes, err := now.MarshalJSON()
if err != nil {
return ECDHSignal{}, nil, err
return nil, nil, err
}
sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return ECDHSignal{}, nil, err
return nil, nil, err
}
return ECDHSignal{
return &ECDHSignal{
StringSignal: StringSignal{
BaseSignal: NewDirectSignal(ECDHSignalType),
BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "req",
},
Time: now,
@ -374,7 +369,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e
const DEFAULT_ECDH_WINDOW = time.Second
func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) {
func NewECDHRespSignal(node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) {
now := time.Now()
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW)
@ -382,7 +377,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [
return ECDHSignal{}, nil, err
}
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader)
ec_key, err := ECDH.GenerateKey(rand.Reader)
if err != nil {
return ECDHSignal{}, nil, err
}
@ -406,7 +401,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [
return ECDHSignal{
StringSignal: StringSignal{
BaseSignal: NewDirectSignal(ECDHSignalType),
BaseSignal: NewBaseSignal(ECDHSignalType, Direct),
Str: "resp",
},
Time: now,
@ -449,34 +444,34 @@ type ECDHProxySignal struct {
Data []byte
}
func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (ECDHProxySignal, error) {
func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (Signal, error) {
if shared_secret == nil {
return ECDHProxySignal{}, fmt.Errorf("need shared_secret")
return nil, fmt.Errorf("need shared_secret")
}
aes_key, err := aes.NewCipher(shared_secret[:32])
if err != nil {
return ECDHProxySignal{}, err
return nil, err
}
ser, err := SerializeSignal(signal, aes_key.BlockSize())
if err != nil {
return ECDHProxySignal{}, err
return nil, err
}
iv := make([]byte, aes_key.BlockSize())
n, err := rand.Reader.Read(iv)
if err != nil {
return ECDHProxySignal{}, err
return nil, err
} else if n != len(iv) {
return ECDHProxySignal{}, fmt.Errorf("Not enough bytes read for IV")
return nil, fmt.Errorf("Not enough bytes read for IV")
}
encrypter := cipher.NewCBCEncrypter(aes_key, iv)
encrypter.CryptBlocks(ser, ser)
return ECDHProxySignal{
BaseSignal: NewDirectSignal(ECDHProxySignalType),
return &ECDHProxySignal{
BaseSignal: NewBaseSignal(ECDHProxySignalType, Direct),
Source: source,
Dest: dest,
IV: iv,

@ -48,7 +48,7 @@ func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error {
return err
}
func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
return
func (ext *GroupExt) Process(ctx *Context, node *Node, msg Message) []Message {
return nil
}