From 8770d6f433e9f04b83933c69c5798bb63b73342e Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Mon, 7 Aug 2023 20:26:02 -0600 Subject: [PATCH] Moved policies to node instead of an extension, need to fix gql tests --- context.go | 71 +++++++----- ecdh.go | 68 ++++-------- gql.go | 55 ++++++---- gql_query.go | 5 +- gql_test.go | 29 ++--- graph_test.go | 4 +- lockable.go | 232 +++++++++++++++------------------------ lockable_test.go | 39 +++---- node.go | 279 ++++++++++++++++++++++++++++------------------- node_test.go | 30 ++--- policy.go | 197 +++++++++------------------------ signal.go | 189 ++++++++++++++++---------------- user.go | 4 +- 13 files changed, 535 insertions(+), 667 deletions(-) diff --git a/context.go b/context.go index c36141c..6048e50 100644 --- a/context.go +++ b/context.go @@ -48,9 +48,30 @@ func LoadExtension[T any, E interface { return e, nil } +type PolicyType string +func (policy PolicyType) Prefix() string { return "POLICY: " } +func (policy PolicyType) String() string { return string(policy) } + +type PolicyLoadFunc func(*Context,[]byte) (Policy, error) +func LoadPolicy[T any, P interface { + *T + Policy +}](ctx *Context, data []byte) (Policy, error) { + p := P(new(T)) + err := p.Deserialize(ctx, data) + if err != nil { + return nil, err + } + return p, nil +} + +type PolicyInfo struct { + Load PolicyLoadFunc + Type PolicyType +} + // ExtType and NodeType constants const ( - ACLExtType = ExtType("ACL") ListenerExtType = ExtType("LISTENER") LockableExtType = ExtType("LOCKABLE") GQLExtType = ExtType("GQL") @@ -62,6 +83,7 @@ const ( var ( NodeNotFoundError = errors.New("Node not found in DB") + ECDH = ecdh.X25519() ) type SignalLoadFunc func(*Context,[]byte) (Signal, error) @@ -107,12 +129,12 @@ type Context struct { Log Logger // Map between database extension hashes and the registered info Extensions map[uint64]ExtensionInfo + // Map between databse policy hashes and the registered info + Policies map[uint64]PolicyInfo // Map between serialized signal hashes and the registered info Signals map[uint64]SignalInfo // Map between database type hashes and the registered info Types map[uint64]*NodeInfo - // Curve used for ecdh operations - ECDH ecdh.Curve // Routing map to all the nodes local to this context NodesLock sync.RWMutex Nodes map[NodeID]*Node @@ -216,28 +238,31 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) { // Stop every running loop func (ctx *Context) Stop() { for _, node := range(ctx.Nodes) { - node.MsgChan <- Msg{ZeroID, &StopSignal} + node.MsgChan <- Message{ZeroID, &StopSignal} } } // Route a Signal to dest. Currently only local context routing is supported -func (ctx *Context) Send(source NodeID, dest NodeID, signal Signal) error { - target, err := ctx.GetNode(dest) - if err == nil { - select { - case target.MsgChan <- Msg{source, signal}: - default: - buf := make([]byte, 4096) - n := runtime.Stack(buf, false) - stack_str := string(buf[:n]) - return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", dest, stack_str) +func (ctx *Context) Send(source NodeID, messages []Message) error { + for _, msg := range(messages) { + target, err := ctx.GetNode(msg.NodeID) + if err == nil { + select { + case target.MsgChan <- Message{source, msg.Signal}: + default: + buf := make([]byte, 4096) + n := runtime.Stack(buf, false) + stack_str := string(buf[:n]) + return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.NodeID, stack_str) + } + } else if errors.Is(err, NodeNotFoundError) { + // TODO: Handle finding nodes in other contexts + return err + } else { + return err } - return nil - } else if errors.Is(err, NodeNotFoundError) { - // TODO: Handle finding nodes in other contexts - return err } - return err + return nil } // Create a new Context with the base library content added @@ -249,15 +274,9 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Types: map[uint64]*NodeInfo{}, Signals: map[uint64]SignalInfo{}, Nodes: map[NodeID]*Node{}, - ECDH: ecdh.X25519(), } var err error - err = RegisterExtension[ACLExt,*ACLExt](ctx, NewACLExtContext()) - if err != nil { - return nil, err - } - err = RegisterExtension[LockableExt,*LockableExt](ctx, nil) if err != nil { return nil, err @@ -299,7 +318,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } - err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) + err = ctx.RegisterNodeType(GQLNodeType, []ExtType{GroupExtType, GQLExtType}) if err != nil { return nil, err } diff --git a/ecdh.go b/ecdh.go index e517917..401f813 100644 --- a/ecdh.go +++ b/ecdh.go @@ -103,28 +103,30 @@ func (ext *ECDHExt) Field(name string) interface{} { }) } -func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal *ECDHSignal) { +func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) []Message { + source := KeyID(signal.EDDSA) + + messages := []Message{} switch signal.Str { case "req": state, exists := ext.ECDHStates[source] if exists == false { state = ECDHState{nil, nil} } - resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal) + resp, shared_secret, err := NewECDHRespSignal(node, signal) if err == nil { state.SharedSecret = shared_secret ext.ECDHStates[source] = state - ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - ctx.Send(node.ID, source, &resp) + log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) + messages = append(messages, Message{source, &resp}) } else { - ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) - // TODO: send error response + log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), err.Error())}) } case "resp": state, exists := ext.ECDHStates[source] if exists == false || state.ECKey == nil { - resp := NewErrorSignal(signal.ID(), "no_req") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "no_req")}) } else { err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) if err == nil { @@ -133,55 +135,23 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si state.SharedSecret = shared_secret state.ECKey = nil ext.ECDHStates[source] = state - ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) + log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) } } } default: - ctx.Log.Logf("ecdh", "unknown echd state %s", signal.Str) + log.Logf("ecdh", "unknown echd state %s", signal.Str) } + return messages } -func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { -} - -func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal *ECDHProxySignal) { - state, exists := ext.ECDHStates[source] - if exists == false { - resp := NewErrorSignal(signal.ID(), "no_req") - ctx.Send(node.ID, source, &resp) - } else if state.SharedSecret == nil { - resp := NewErrorSignal(signal.ID(), "no_shared") - ctx.Send(node.ID, source, &resp) - } else { - unwrapped_signal, err := ParseECDHProxySignal(ctx, signal, state.SharedSecret) - if err != nil { - resp := NewErrorSignal(signal.ID(), err.Error()) - ctx.Send(node.ID, source, &resp) - } else { - //TODO: Figure out what I was trying to do here and fix it - ctx.Send(signal.Source, signal.Dest, unwrapped_signal) - } - } -} - -func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { - switch signal.Direction() { - case Direct: - switch signal.Type() { - case ECDHProxySignalType: - ecdh_signal := signal.(*ECDHProxySignal) - ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal) - case ECDHStateSignalType: - ecdh_signal := signal.(*StringSignal) - ext.HandleStateSignal(ctx, source, node, ecdh_signal) - case ECDHSignalType: - ecdh_signal := signal.(*ECDHSignal) - ext.HandleECDHSignal(ctx, source, node, ecdh_signal) - default: - } - default: +func (ext *ECDHExt) Process(ctx *Context, node *Node, msg Message) []Message { + switch msg.Signal.Type() { + case ECDHSignalType: + sig := msg.Signal.(*ECDHSignal) + return ext.HandleECDHSignal(ctx.Log, node, sig) } + return nil } func (ext *ECDHExt) Type() ExtType { diff --git a/gql.go b/gql.go index b6b79e7..d57fe01 100644 --- a/gql.go +++ b/gql.go @@ -17,9 +17,10 @@ import ( "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "strings" + "crypto/ecdsa" + "crypto/elliptic" "crypto/ecdh" "crypto/ed25519" - "crypto/elliptic" "crypto/rand" "crypto/x509" "crypto/tls" @@ -189,15 +190,16 @@ type ResolveContext struct { } func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { - username, key_bytes, ok := r.BasicAuth() + id_bytes, key_bytes, ok := r.BasicAuth() if ok == false { return nil, fmt.Errorf("GQL_REQUEST_ERR: no auth header included in request header") } - auth_id, err := ParseID(username) + auth_uuid, err := uuid.FromBytes([]byte(id_bytes)) if err != nil { - return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) + return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username") } + auth_id := NodeID(auth_uuid) key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes)) if err != nil { @@ -916,7 +918,7 @@ func NewGQLExtContext() *GQLExtContext { panic(err) } - context.Mutation.AddFieldConfig("stopServer", &graphql.Field{ + context.Mutation.AddFieldConfig("stop", &graphql.Field{ Type: graphql.String, Resolve: func(p graphql.ResolveParams) (interface{}, error) { ctx, err := PrepResolve(p) @@ -924,14 +926,13 @@ func NewGQLExtContext() *GQLExtContext { return nil, err } - sig := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"} - err = Allowed(ctx.Context, ctx.User, sig.Permission(), ctx.Server) + sig, err := NewAuthorizedSignal(ctx.Key, &StopSignal) if err != nil { - return err, nil + return nil, err } response_chan := ctx.Ext.GetResponseChannel(sig.ID()) - err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, &sig) + err = ctx.Context.Send(ctx.Server.ID, []Message{Message{ctx.Server.ID, sig}}) if err != nil { ctx.Ext.FreeResponseChannel(sig.ID()) return nil, err @@ -1016,8 +1017,8 @@ type GQLExt struct { resolver_response_lock sync.RWMutex `json:"-"` State string `json:"state"` - tls_key []byte `json:"tls_key"` - tls_cert []byte `json:"tls_cert"` + TLSKey []byte `json:"tls_key"` + TLSCert []byte `json:"tls_cert"` Listen string `json:"listen"` } @@ -1052,12 +1053,14 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { } } -func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { +func (ext *GQLExt) Process(ctx *Context, node *Node, msg Message) []Message { // Process ReadResultSignalType by forwarding it to the waiting resolver + signal := msg.Signal + messages := []Message{} if signal.Type() == ErrorSignalType { // TODO: Forward to resolver if waiting for it sig := signal.(*ErrorSignal) - response_chan := ext.FreeResponseChannel(sig.ID()) + response_chan := ext.FreeResponseChannel(sig.UUID) if response_chan != nil { select { case response_chan <- sig: @@ -1084,14 +1087,16 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa } } else if signal.Type() == GQLStateSignalType { sig := signal.(*StringSignal) + ctx.Log.Logf("gql", "GQL_STATE_SIGNAL: %s - %+v", node.ID, sig.Str) switch sig.Str { case "start_server": if ext.State == "stopped" { err := ext.StartGQLServer(ctx, node) if err == nil { ext.State = "running" - resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"} - ctx.Send(node.ID, source, &resp) + node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) + } else { + ctx.Log.Logf("gql", "GQL_START_ERROR: %s", err) } } case "stop_server": @@ -1099,8 +1104,9 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa err := ext.StopGQLServer() if err == nil { ext.State = "stopped" - resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"} - ctx.Send(node.ID, source, &resp) + node.QueueSignal(time.Now(), NewStatusSignal("server_stopped", node.ID)) + } else { + ctx.Log.Logf("gql", "GQL_STOP_ERROR: %s", err) } } default: @@ -1112,14 +1118,16 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa case "running": err := ext.StartGQLServer(ctx, node) if err == nil { - resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"} - ctx.Send(node.ID, source, &resp) + node.QueueSignal(time.Now(), NewStatusSignal("server_started", node.ID)) + } else { + ctx.Log.Logf("gql", "GQL_RESTART_ERROR: %s", err) } case "stopped": default: ctx.Log.Logf("gql", "unknown state to restore from: %s", ext.State) } } + return messages } func (ext *GQLExt) Type() ExtType { @@ -1147,12 +1155,13 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{ } func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error { + ext.resolver_response = map[uuid.UUID]chan Signal{} return json.Unmarshal(data, &ext) } func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { if tls_cert == nil || tls_key == nil { - _, ssl_key, err := ed25519.GenerateKey(rand.Reader) + ssl_key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, err } @@ -1194,8 +1203,8 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta State: state, Listen: listen, resolver_response: map[uuid.UUID]chan Signal{}, - tls_cert: tls_cert, - tls_key: tls_key, + TLSCert: tls_cert, + TLSKey: tls_key, }, nil } @@ -1224,7 +1233,7 @@ func (ext *GQLExt) StartGQLServer(ctx *Context, node *Node) error { return fmt.Errorf("Failed to start listener for server on %s", http_server.Addr) } - cert, err := tls.X509KeyPair(ext.tls_cert, ext.tls_key) + cert, err := tls.X509KeyPair(ext.TLSCert, ext.TLSKey) if err != nil { return err } diff --git a/gql_query.go b/gql_query.go index 84a9d10..407c85d 100644 --- a/gql_query.go +++ b/gql_query.go @@ -61,7 +61,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ resp_channels[read_signal.ID()] = response_chan node_ids[read_signal.ID()] = id - err = ctx.Context.Send(ctx.Server.ID, id, &auth_signal) + err = ctx.Context.Send(ctx.Server.ID, []Message{Message{id, auth_signal}}) if err != nil { ctx.Ext.FreeResponseChannel(read_signal.ID()) return nil, err @@ -79,11 +79,12 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ case *ReadResultSignal: responses = append(responses, NodeResult{node_ids[sig_id], resp}) case *ErrorSignal: - return nil, fmt.Errorf(resp.Str) + return nil, fmt.Errorf(resp.Error) default: return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp)) } } + ctx.Context.Log.Logf("gql", "RESOLVED_NODES") return responses, nil } diff --git a/gql_test.go b/gql_test.go index 6992ee9..41df6ac 100644 --- a/gql_test.go +++ b/gql_test.go @@ -11,30 +11,28 @@ import ( "crypto/tls" "crypto/x509" "bytes" - "github.com/google/uuid" ) func TestGQL(t *testing.T) { - ctx := logTestContext(t, []string{}) + ctx := logTestContext(t, []string{"gql", "lockable", "node_timeout", "listener"}) TestNodeType := NodeType("TEST") - err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType}) + err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) fatalErr(t, err) gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") fatalErr(t, err) listener_ext := NewListenerExt(10) - policy := NewAllNodesPolicy(Actions{MakeAction("+")}) - start_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"} - gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{ - QueuedSignal{uuid.New(), &start_signal, time.Now()}, - }, NewLockableExt(), NewACLExt(&policy), gql_ext, NewGroupExt(nil), listener_ext) - n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(&policy)) + gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext) + n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt()) err = LinkRequirement(ctx, gql.ID, n1.ID) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool { + err = ctx.Send(gql.ID, []Message{{gql.ID, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}}}) + fatalErr(t, err) + + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "server_started" }) fatalErr(t, err) @@ -86,8 +84,8 @@ func TestGQL(t *testing.T) { resp_2 := SendGQL(req_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2) - stop_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"} - ctx.Send(n1.ID, gql.ID, &stop_signal) + stop_signal := StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"} + ctx.Send(n1.ID, []Message{{gql.ID, &stop_signal}}) _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool { return sig.Str == "server_stopped" }) @@ -109,13 +107,10 @@ func TestGQLDB(t *testing.T) { gql := NewNode(ctx, nil, GQLNodeType, 10, nil, gql_ext, listener_ext, - NewACLExt(), NewGroupExt(nil)) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - err = ctx.Send(gql.ID, gql.ID, &StopSignal) - fatalErr(t, err) - + ctx.Stop() _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql.ID }) @@ -134,7 +129,7 @@ func TestGQLDB(t *testing.T) { fatalErr(t, err) listener_ext, err = GetExt[*ListenerExt](gql_loaded) fatalErr(t, err) - err = ctx.Send(gql_loaded.ID, gql_loaded.ID, &StopSignal) + err = ctx.Send(gql_loaded.ID, []Message{{gql_loaded.ID, &StopSignal}}) fatalErr(t, err) _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID diff --git a/graph_test.go b/graph_test.go index 9ad43d2..f13e528 100644 --- a/graph_test.go +++ b/graph_test.go @@ -9,7 +9,6 @@ import ( const SimpleListenerNodeType = NodeType("SIMPLE_LISTENER") func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { - policy := NewAllNodesPolicy(Actions{MakeAction("status")}) listener_extension := NewListenerExt(buffer) listener := NewNode(ctx, nil, @@ -17,7 +16,6 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { 10, nil, listener_extension, - NewACLExt(&policy), NewLockableExt()) return listener, listener_extension @@ -32,7 +30,7 @@ func logTestContext(t * testing.T, components []string) *Context { ctx, err := NewContext(db, NewConsoleLogger(components)) fatalErr(t, err) - err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ACLExtType, ListenerExtType, LockableExtType}) + err = ctx.RegisterNodeType(SimpleListenerNodeType, []ExtType{ListenerExtType, LockableExtType}) fatalErr(t, err) return ctx diff --git a/lockable.go b/lockable.go index a775f6e..359ab86 100644 --- a/lockable.go +++ b/lockable.go @@ -7,14 +7,14 @@ import ( // A Listener extension provides a channel that can receive signals on a different thread type ListenerExt struct { Buffer int - Chan chan Signal + Chan chan Message } // Create a new listener extension with a given buffer size func NewListenerExt(buffer int) *ListenerExt { return &ListenerExt{ Buffer: buffer, - Chan: make(chan Signal, buffer), + Chan: make(chan Message, buffer), } } @@ -32,7 +32,7 @@ func (ext *ListenerExt) Field(name string) interface{} { // Simple load function, unmarshal the buffer int from json func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { err := json.Unmarshal(data, &ext.Buffer) - ext.Chan = make(chan Signal, ext.Buffer) + ext.Chan = make(chan Message, ext.Buffer) return err } @@ -41,14 +41,14 @@ func (listener *ListenerExt) Type() ExtType { } // Send the signal to the channel, logging an overflow if it occurs -func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { - ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) +func (ext *ListenerExt) Process(ctx *Context, node *Node, msg Message) []Message { + ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, msg.Signal) select { - case ext.Chan <- signal: + case ext.Chan <- msg: default: ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) } - return + return nil } // ReqState holds the multiple states of a requirement @@ -138,43 +138,38 @@ func NewLockableExt() *LockableExt { // Send the signal to unlock a node from itself func UnlockLockable(ctx *Context, node *Node) error { - lock_signal := NewLockSignal("unlock") - return ctx.Send(node.ID, node.ID, &lock_signal) + return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("unlock")}}) } // Send the signal to lock a node from itself func LockLockable(ctx *Context, node *Node) error { - lock_signal := NewLockSignal("lock") - return ctx.Send(node.ID, node.ID, &lock_signal) + return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("lock")}}) } // Setup a node to send the initial requirement link signal, then send the signal func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error { - start_signal := NewLinkStartSignal("req", requirement) - return ctx.Send(dependency, dependency, &start_signal) + return ctx.Send(dependency, []Message{Message{dependency, NewLinkStartSignal("req", requirement)}}) } // Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { - ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) +func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message { state := signal.Str + log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) + + messages := []Message{} switch state { case "unlock": if ext.Owner == nil { - resp := NewErrorSignal(signal.ID(), "already_unlocked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocked")}) } else if source != *ext.Owner { - resp := NewErrorSignal(signal.ID(), "not_owner") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_owner")}) } else if ext.PendingOwner == nil { - resp := NewErrorSignal(signal.ID(), "already_unlocking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocking")}) } else { if len(ext.Requirements) == 0 { ext.Owner = nil ext.PendingOwner = nil - resp := NewLockSignal("unlocked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewLockSignal("unlocked")}) } else { ext.PendingOwner = nil for id, state := range(ext.Requirements) { @@ -184,44 +179,36 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node } state.Lock = "unlocking" ext.Requirements[id] = state - resp := NewLockSignal("unlock") - ctx.Send(node.ID, id, &resp) + messages = append(messages, Message{id, NewLockSignal("unlock")}) } } if source != node.ID { - resp := NewLockSignal("unlocking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewLockSignal("unlocking")}) } } } case "unlocking": state, exists := ext.Requirements[source] if exists == false { - resp := NewErrorSignal(signal.ID(), "not_requirement") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) } else if state.Link != "linked" { - resp := NewErrorSignal(signal.ID(), "not_linked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) } else if state.Lock != "unlocking" { - resp := NewErrorSignal(signal.ID(), "not_unlocking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")}) } case "unlocked": if source == node.ID { - return + return nil } state, exists := ext.Requirements[source] if exists == false { - resp := NewErrorSignal(signal.ID(), "not_requirement") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) } else if state.Link != "linked" { - resp := NewErrorSignal(signal.ID(), "not_linked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) } else if state.Lock != "unlocking" { - resp := NewErrorSignal(signal.ID(), "not_unlocking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")}) } else { state.Lock = "unlocked" ext.Requirements[source] = state @@ -241,26 +228,22 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node if linked == unlocked { previous_owner := *ext.Owner ext.Owner = nil - resp := NewLockSignal("unlocked") - ctx.Send(node.ID, previous_owner, &resp) + messages = append(messages, Message{previous_owner, NewLockSignal("unlocked")}) } } } case "locked": if source == node.ID { - return + return nil } state, exists := ext.Requirements[source] if exists == false { - resp := NewErrorSignal(signal.ID(), "not_requirement") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) } else if state.Link != "linked" { - resp := NewErrorSignal(signal.ID(), "not_linked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) } else if state.Lock != "locking" { - resp := NewErrorSignal(signal.ID(), "not_locking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")}) } else { state.Lock = "locked" ext.Requirements[source] = state @@ -279,176 +262,142 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node if linked == locked { ext.Owner = ext.PendingOwner - resp := NewLockSignal("locked") - ctx.Send(node.ID, *ext.Owner, &resp) + messages = append(messages, Message{*ext.Owner, NewLockSignal("locked")}) } } } case "locking": state, exists := ext.Requirements[source] if exists == false { - resp := NewErrorSignal(signal.ID(), "not_requirement") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) } else if state.Link != "linked" { - resp := NewErrorSignal(signal.ID(), "not_linked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) } else if state.Lock != "locking" { - resp := NewErrorSignal(signal.ID(), "not_locking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")}) } case "lock": if ext.Owner != nil { - resp := NewErrorSignal(signal.ID(), "already_locked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locked")}) } else if ext.PendingOwner != nil { - resp := NewErrorSignal(signal.ID(), "already_locking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locking")}) } else { owner := source if len(ext.Requirements) == 0 { ext.Owner = &owner ext.PendingOwner = ext.Owner - resp := NewLockSignal("locked") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewLockSignal("locked")}) } else { ext.PendingOwner = &owner for id, state := range(ext.Requirements) { if state.Link == "linked" { + log.Logf("lockable", "LOCK_REQ: %s sending 'lock' to %s", node.ID, id) if state.Lock != "unlocked" { panic("NOT_UNLOCKED") } state.Lock = "locking" ext.Requirements[id] = state - sub := NewLockSignal("lock") - ctx.Send(node.ID, id, &sub) + messages = append(messages, Message{id, NewLockSignal("lock")}) } } if source != node.ID { - resp := NewLockSignal("locking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewLockSignal("locking")}) } } } default: - ctx.Log.Logf("lockable", "LOCK_ERR: unkown state %s", state) + log.Logf("lockable", "LOCK_ERR: unkown state %s", state) } + log.Logf("lockable", "LOCK_MESSAGES: %+v", messages) + return messages } -func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal *IDStringSignal) { - ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal) +func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) []Message { link_type := signal.Str target := signal.NodeID + log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target) + + messages := []Message{} switch link_type { case "req": state, exists := ext.Requirements[target] _, dep_exists := ext.Dependencies[target] if ext.Owner != nil { - resp := NewLinkStartSignal("locked", target) - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already locked")}) } else if ext.Owner != ext.PendingOwner { if ext.PendingOwner == nil { - resp := NewLinkStartSignal("unlocking", target) - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")}) } else { - resp := NewLinkStartSignal("locking", target) - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")}) } } else if exists == true { if state.Link == "linking" { - resp := NewErrorSignal(signal.ID(), "already_linking_req") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking_req")}) } else if state.Link == "linked" { - resp := NewErrorSignal(signal.ID(), "already_req") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_req")}) } } else if dep_exists == true { - resp := NewLinkStartSignal("already_dep", target) - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_dep")}) } else { ext.Requirements[target] = LinkState{"linking", "unlocked", source} - resp := NewLinkSignal("linked_as_req") - ctx.Send(node.ID, target, &resp) - notify := NewLinkStartSignal("linking_req", target) - ctx.Send(node.ID, source, ¬ify) + messages = append(messages, Message{target, NewLinkSignal("linked_as_req")}) + messages = append(messages, Message{source, NewLinkStartSignal("linking_req", target)}) } } + return messages } // Handle LinkSignal, updating the extensions requirements and dependencies as necessary // TODO: Add unlink -func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { - ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) +func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message { + log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) state := signal.Str + + messages := []Message{} switch state { - case "linked_as_dep": + case "dep_done": state, exists := ext.Requirements[source] - if exists == true && state.Link == "linked" { - resp := NewLinkStartSignal("linked_as_req", source) - ctx.Send(node.ID, state.Initiator, &resp) + if exists == false { + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linking")}) } else if state.Link == "linking" { state.Link = "linked" ext.Requirements[source] = state - resp := NewLinkSignal("linked_as_req") - ctx.Send(node.ID, source, &resp) - } else if ext.PendingOwner != ext.Owner { - if ext.Owner == nil { - resp := NewLinkSignal("locking") - ctx.Send(node.ID, source, &resp) - } else { - resp := NewLinkSignal("unlocking") - ctx.Send(node.ID, source, &resp) - } - } else { - ext.Requirements[source] = LinkState{"linking", "unlocked", source} - resp := NewLinkSignal("linked_as_req") - ctx.Send(node.ID, source, &resp) + log.Logf("lockable", "FINISHED_LINKING_REQ: %s->%s", node.ID, source) } - ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source) - case "linked_as_req": state, exists := ext.Dependencies[source] - if exists == true && state.Link == "linked" { - resp := NewLinkStartSignal("linked_as_dep", source) - ctx.Send(node.ID, state.Initiator, &resp) + if exists == false { + ext.Dependencies[source] = LinkState{"linked", "unlocked", source} + messages = append(messages, Message{source, NewLinkSignal("dep_done")}) } else if state.Link == "linking" { - state.Link = "linked" - ext.Dependencies[source] = state - resp := NewLinkSignal("linked_as_dep") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking")}) + } else if state.Link == "linked" { + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linked")}) } else if ext.PendingOwner != ext.Owner { if ext.Owner == nil { - resp := NewLinkSignal("locking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")}) } else { - resp := NewLinkSignal("unlocking") - ctx.Send(node.ID, source, &resp) + messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")}) } - } else { - ext.Dependencies[source] = LinkState{"linking", "unlocked", source} - resp := NewLinkSignal("linked_as_dep") - ctx.Send(node.ID, source, &resp) } - ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source) default: - ctx.Log.Logf("lockable", "LINK_ERROR: unknown state %s", state) + log.Logf("lockable", "LINK_ERROR: unknown state %s", state) } + return messages } // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state -func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { - switch signal.Direction() { +func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message { + messages := []Message{} + switch msg.Signal.Direction() { case Up: + ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies) owner_sent := false for dependency, state := range(ext.Dependencies) { if state.Link == "linked" { - err := ctx.Send(node.ID, dependency, signal) - if err != nil { - ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, dependency, err) - } - + messages = append(messages, Message{dependency, msg.Signal}) if ext.Owner != nil { if dependency == *ext.Owner { owner_sent = true @@ -459,32 +408,27 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal if ext.Owner != nil && owner_sent == false { if *ext.Owner != node.ID { - err := ctx.Send(node.ID, *ext.Owner, signal) - if err != nil { - ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, *ext.Owner, err) - } + messages = append(messages, Message{*ext.Owner, msg.Signal}) } } case Down: for requirement, state := range(ext.Requirements) { if state.Link == "linked" { - err := ctx.Send(node.ID, requirement, signal) - if err != nil { - ctx.Log.Logf("signal", "LOCKABLE_SIGNAL_ERR: %s->%s - %e", node.ID, requirement, err) - } + messages = append(messages, Message{requirement, msg.Signal}) } } case Direct: - switch signal.Type() { + switch msg.Signal.Type() { case LinkSignalType: - ext.HandleLinkSignal(ctx, source, node, signal.(*StringSignal)) + messages = ext.HandleLinkSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal)) case LockSignalType: - ext.HandleLockSignal(ctx, source, node, signal.(*StringSignal)) + messages = ext.HandleLockSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal)) case LinkStartSignalType: - ext.HandleLinkStartSignal(ctx, source, node, signal.(*IDStringSignal)) + messages = ext.HandleLinkStartSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*IDStringSignal)) default: } default: } + return messages } diff --git a/lockable_test.go b/lockable_test.go index bf77350..33d6bd9 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -9,7 +9,7 @@ const TestLockableType = NodeType("TEST_LOCKABLE") func lockableTestContext(t *testing.T, logs []string) *Context { ctx := logTestContext(t, logs) - err := ctx.RegisterNodeType(TestLockableType, []ExtType{ACLExtType, LockableExtType}) + err := ctx.RegisterNodeType(TestLockableType, []ExtType{LockableExtType}) fatalErr(t, err) return ctx @@ -26,13 +26,11 @@ func TestLink(t *testing.T) { l1_listener := NewListenerExt(10) l1 := NewNode(ctx, nil, TestLockableType, 10, nil, l1_listener, - NewACLExt(&link_policy), NewLockableExt(), ) l2_listener := NewListenerExt(10) l2 := NewNode(ctx, nil, TestLockableType, 10, nil, l2_listener, - NewACLExt(&link_policy), NewLockableExt(), ) @@ -40,13 +38,13 @@ func TestLink(t *testing.T) { err := LinkRequirement(ctx, l1.ID, l2.ID) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "linked_as_req" + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { + return sig.Str == "dep_done" }) fatalErr(t, err) sig1 := NewStatusSignal("TEST", l2.ID) - err = ctx.Send(l2.ID, l2.ID, &sig1) + err = ctx.Send(l2.ID, []Message{{l2.ID, sig1}}) fatalErr(t, err) _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { @@ -65,7 +63,6 @@ func TestLink10K(t *testing.T) { NewLockable := func()(*Node) { l := NewNode(ctx, nil, TestLockableType, 10, nil, - NewACLExt(&lock_policy, &link_policy), NewLockableExt(), ) return l @@ -75,7 +72,6 @@ func TestLink10K(t *testing.T) { listener := NewListenerExt(100000) l := NewNode(ctx, nil, TestLockableType, 256, nil, listener, - NewACLExt(&lock_policy, &link_policy), NewLockableExt(), ) return l, listener @@ -92,8 +88,8 @@ func TestLink10K(t *testing.T) { for range(lockables) { - _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool { - return sig.Str == "linked_as_req" + _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { + return sig.Str == "dep_done" }) fatalErr(t, err) } @@ -102,13 +98,12 @@ func TestLink10K(t *testing.T) { } func TestLock(t *testing.T) { - ctx := lockableTestContext(t, []string{}) + ctx := lockableTestContext(t, []string{"lockable", "listener"}) NewLockable := func()(*Node, *ListenerExt) { listener := NewListenerExt(100) l := NewNode(ctx, nil, TestLockableType, 10, nil, listener, - NewACLExt(&lock_policy, &link_policy), NewLockableExt(), ) return l, listener @@ -141,30 +136,30 @@ func TestLock(t *testing.T) { err = LinkRequirement(ctx, l0.ID, l5.ID) fatalErr(t, err) - linked_as_req := func(sig *IDStringSignal) bool { - return sig.Str == "linked_as_req" + linked_as_req := func(sig *StringSignal) bool { + return sig.Str == "dep_done" } locked := func(sig *StringSignal) bool { return sig.Str == "locked" } - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) err = LockLockable(ctx, l1) diff --git a/node.go b/node.go index db0e436..b86134b 100644 --- a/node.go +++ b/node.go @@ -20,9 +20,10 @@ const ( // Magic first four bytes of serialized DB content, stored big endian NODE_DB_MAGIC = 0x2491df14 // Total length of the node database header, has magic to verify and type_hash to map to load function - NODE_DB_HEADER_LEN = 28 + NODE_DB_HEADER_LEN = 32 EXTENSION_DB_HEADER_LEN = 16 QSIGNAL_DB_HEADER_LEN = 40 + POLICY_DB_HEADER_LEN = 16 ) var ( @@ -82,7 +83,7 @@ func RandID() NodeID { return NodeID(uuid.New()) } -// A Serializable has a type that can be used to map to it, and a function to serialize the current state +// A Serializable has a type that can be used to map to it, and a function to serialize` the current state type Serializable[I comparable] interface { Serialize()([]byte,error) Deserialize(*Context,[]byte)error @@ -93,7 +94,7 @@ type Serializable[I comparable] interface { type Extension interface { Serializable[ExtType] Field(string)interface{} - Process(context *Context, source NodeID, node *Node, signal Signal) + Process(ctx *Context, node *Node, message Message)[]Message } // A QueuedSignal is a Signal that has been Queued to trigger at a set time @@ -110,9 +111,10 @@ type Node struct { ID NodeID Type NodeType Extensions map[ExtType]Extension + Policies map[PolicyType]Policy // Channel for this node to receive messages from the Context - MsgChan chan Msg + MsgChan chan Message // Size of MsgChan BufferSize uint32 // Channel for this node to process delayed signals @@ -124,6 +126,18 @@ type Node struct { NextSignal *QueuedSignal } +func (node *Node) Allows(principal_id NodeID, action Action) error { + errs := []error{} + for _, policy := range(node.Policies) { + err := policy.Allows(principal_id, action, node) + if err == nil { + return nil + } + errs = append(errs, err) + } + return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) +} + func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID { id := uuid.New() node.SignalQueue = append(node.SignalQueue, QueuedSignal{id, signal, time}) @@ -163,17 +177,12 @@ func runNode(ctx *Context, node *Node) { ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) } -type Msg struct { - Source NodeID - Signal Signal -} - func ReadNodeFields(ctx *Context, self *Node, princ NodeID, reqs map[ExtType][]string)map[ExtType]map[string]interface{} { exts := map[ExtType]map[string]interface{}{} for ext_type, field_reqs := range(reqs) { fields := map[string]interface{}{} for _, req := range(field_reqs) { - err := Allowed(ctx, princ, MakeAction(ReadSignalType, ext_type, req), self) + err := self.Allows(princ, MakeAction(ReadSignalType, ext_type, req)) if err != nil { fields[req] = err } else { @@ -198,27 +207,18 @@ func nodeLoop(ctx *Context, node *Node) error { } // Perform startup actions - node.Process(ctx, node.ID, &StartSignal) + node.Process(ctx, Message{ZeroID, &StartSignal}) for true { - var signal Signal - var source NodeID + var msg Message select { - case msg := <- node.MsgChan: - ctx.Log.Logf("signal", "NODE_MSG: %s - %+v", node.ID, msg) - signal = msg.Signal - source = msg.Source - err := Allowed(ctx, msg.Source, signal.Permission(), node) - if err != nil { - ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) - resp := NewErrorSignal(msg.Signal.ID(), err.Error()) - ctx.Send(node.ID, msg.Source, &resp) - continue - } + case msg = <- node.MsgChan: + ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal.Type()) case <-node.TimeoutChan: - signal = node.NextSignal.Signal + signal := node.NextSignal.Signal + msg = Message{node.ID, signal} + t := node.NextSignal.Time - source = node.ID i := -1 for j, queued := range(node.SignalQueue) { if queued.UUID == node.NextSignal.UUID { @@ -235,17 +235,17 @@ func nodeLoop(ctx *Context, node *Node) error { node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) if node.NextSignal == nil { - ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, t, signal, node.TimeoutChan) + ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL nil@%+v", node.ID, signal.Type(), t, node.TimeoutChan) } else { - ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, t, signal, node.NextSignal, node.NextSignal.Time) + ctx.Log.Logf("node_timeout", "NODE_TIMEOUT(%s) - PROCESSING %s@%s - NEXT_SIGNAL: %s@%s", node.ID, signal.Type(), t, node.NextSignal, node.NextSignal.Time) } } // Unwrap Authorized Signals - if signal.Type() == AuthorizedSignalType { - sig, ok := signal.(*AuthorizedSignal) + if msg.Signal.Type() == AuthorizedSignalType { + sig, ok := msg.Signal.(*AuthorizedSignal) if ok == false { - ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal)) + ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(msg.Signal)) } else { // Validate sig_data, err := sig.Signal.Serialize() @@ -253,45 +253,40 @@ func nodeLoop(ctx *Context, node *Node) error { } else { validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature) if validated == true { - err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) + err := node.Allows(KeyID(sig.Principal), sig.Signal.Permission()) if err != nil { ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err) - resp := NewErrorSignal(sig.ID(), err.Error()) - ctx.Send(node.ID, source, &resp) + ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), err.Error())}}) } else { // Unwrap the signal without changing the source - signal = sig.Signal + msg = Message{msg.NodeID, sig.Signal} } } else { ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate") - resp := NewErrorSignal(sig.ID(), "signature validation failed") - ctx.Send(node.ID, source, &resp) + ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), "signature validation failed")}}) } } } } - ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) + ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) // Handle special signal types - if signal.Type() == StopSignalType { - resp := NewErrorSignal(signal.ID(), "stopped") - ctx.Send(node.ID, source, &resp) - status := NewStatusSignal("stopped", node.ID) - node.Process(ctx, node.ID, &status) + if msg.Signal.Type() == StopSignalType { + ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(msg.Signal.ID(), "stopped")}}) + node.Process(ctx, Message{node.ID, NewStatusSignal("stopped", node.ID)}) break - } else if signal.Type() == ReadSignalType { - read_signal, ok := signal.(*ReadSignal) + } else if msg.Signal.Type() == ReadSignalType { + read_signal, ok := msg.Signal.(*ReadSignal) if ok == false { - ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) + ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", msg.Signal) } else { - result := ReadNodeFields(ctx, node, source, read_signal.Extensions) - resp := NewReadResultSignal(read_signal.ID(), node.Type, result) - ctx.Send(node.ID, source, &resp) + result := ReadNodeFields(ctx, node, msg.NodeID, read_signal.Extensions) + ctx.Send(node.ID, []Message{Message{msg.NodeID, NewReadResultSignal(read_signal.ID(), node.Type, result)}}) } } - node.Process(ctx, source, signal) + node.Process(ctx, msg) // assume that processing a signal means that this nodes state changed // TODO: remove a lot of database writes by only writing when things change, // so need to have Process return whether or not state changed @@ -308,11 +303,24 @@ func nodeLoop(ctx *Context, node *Node) error { return nil } -func (node *Node) Process(ctx *Context, source NodeID, signal Signal) { +type Message struct { + NodeID + Signal +} + +func (node *Node) Process(ctx *Context, message Message) error { + ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, message.Signal.Type()) + messages := []Message{} for ext_type, ext := range(node.Extensions) { - ctx.Log.Logf("signal", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) - ext.Process(ctx, source, node, signal) + ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) + //TODO: add extension and node info to log + resp := ext.Process(ctx, node, message) + if resp != nil { + messages = append(messages, resp...) + } } + + return ctx.Send(node.ID, messages) } func GetCtx[T Extension, C any](ctx *Context) (C, error) { @@ -352,6 +360,7 @@ func GetExt[T Extension](node *Node) (T, error) { func (node *Node) Serialize() ([]byte, error) { extensions := make([]ExtensionDB, len(node.Extensions)) qsignals := make([]QSignalDB, len(node.SignalQueue)) + policies := make([]PolicyDB, len(node.Policies)) key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key) if err != nil { @@ -365,6 +374,7 @@ func (node *Node) Serialize() ([]byte, error) { KeyLength: uint32(len(key_bytes)), BufferSize: node.BufferSize, NumExtensions: uint32(len(extensions)), + NumPolicies: uint32(len(policies)), NumQueuedSignals: uint32(len(node.SignalQueue)), }, Extensions: extensions, @@ -405,6 +415,22 @@ func (node *Node) Serialize() ([]byte, error) { } } + i = 0 + for _, policy := range(node.Policies) { + ser, err := policy.Serialize() + if err != nil { + return nil, err + } + + node_db.Policies[i] = PolicyDB{ + PolicyDBHeader{ + Hash(policy.Type()), + uint64(len(ser)), + }, + ser, + } + } + return node_db.Serialize(), nil } @@ -415,7 +441,7 @@ func KeyID(pub ed25519.PublicKey) NodeID { // Create a new node in memory and start it's event loop // TODO: Change panics to errors -func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { +func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, policies map[PolicyType]Policy, extensions ...Extension) *Node { var err error var public ed25519.PublicKey if key == nil { @@ -453,22 +479,19 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si } } - if queued_signals == nil { - queued_signals = []QueuedSignal{} + if policies == nil { + policies = map[PolicyType]Policy{} } - next_signal, timeout_chan := SoonestSignal(queued_signals) - node := &Node{ Key: key, ID: id, Type: node_type, Extensions: ext_map, - MsgChan: make(chan Msg, buffer_size), + Policies: policies, + MsgChan: make(chan Message, buffer_size), BufferSize: buffer_size, - TimeoutChan: timeout_chan, - SignalQueue: queued_signals, - NextSignal: next_signal, + SignalQueue: []QueuedSignal{}, } ctx.AddNode(id, node) err = WriteNode(ctx, node) @@ -476,41 +499,50 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si panic(err) } - node.Process(ctx, node.ID, &NewSignal) + node.Process(ctx, Message{node.ID, &NewSignal}) go runNode(ctx, node) return node } -func Allowed(ctx *Context, principal_id NodeID, action Action, node *Node) error { - ctx.Log.Logf("policy", "POLICY_CHECK: %s -> %s.%s", principal_id, node.ID, action) - // Nodes are allowed to perform all actions on themselves regardless of whether or not they have an ACL extension - if principal_id == node.ID { - ctx.Log.Logf("policy", "POLICY_CHECK_SAME_NODE: %s.%s", principal_id, action) - return nil - } +type PolicyDBHeader struct { + TypeHash uint64 + Length uint64 +} - // Check if the node has a policy extension itself, and check against the policies in it - policy_ext, err := GetExt[*ACLExt](node) - if err != nil { - ctx.Log.Logf("policy", "POLICY_CHECK_NO_ACL_EXT: %s", node.ID) - return err - } +type PolicyDB struct { + Header PolicyDBHeader + Data []byte +} - err = policy_ext.Allows(ctx, principal_id, action, node) - if err != nil { - ctx.Log.Logf("policy", "POLICY_CHECK_FAIL: %s -> %s.%s : %s", principal_id, node.ID, action, err) - } else { - ctx.Log.Logf("policy", "POLICY_CHECK_PASS: %s -> %s.%s", principal_id, node.ID, action) - } - return err +type QSignalDBHeader struct { + SignalID uuid.UUID + Time time.Time + TypeHash uint64 + Length uint64 +} + +type QSignalDB struct { + Header QSignalDBHeader + Data []byte +} + +type ExtensionDBHeader struct { + TypeHash uint64 + Length uint64 +} + +type ExtensionDB struct { + Header ExtensionDBHeader + Data []byte } // A DBHeader is parsed from the first NODE_DB_HEADER_LEN bytes of a serialized DB node type NodeDBHeader struct { Magic uint32 NumExtensions uint32 + NumPolicies uint32 NumQueuedSignals uint32 BufferSize uint32 KeyLength uint32 @@ -519,8 +551,9 @@ type NodeDBHeader struct { type NodeDB struct { Header NodeDBHeader - QueuedSignals []QSignalDB Extensions []ExtensionDB + Policies []PolicyDB + QueuedSignals []QSignalDB KeyBytes []byte } @@ -532,10 +565,11 @@ func NewNodeDB(data []byte) (NodeDB, error) { magic := binary.BigEndian.Uint32(data[0:4]) num_extensions := binary.BigEndian.Uint32(data[4:8]) - num_queued_signals := binary.BigEndian.Uint32(data[8:12]) - buffer_size := binary.BigEndian.Uint32(data[12:16]) - key_length := binary.BigEndian.Uint32(data[16:20]) - node_type_hash := binary.BigEndian.Uint64(data[20:28]) + num_policies := binary.BigEndian.Uint32(data[8:12]) + num_queued_signals := binary.BigEndian.Uint32(data[12:16]) + buffer_size := binary.BigEndian.Uint32(data[16:20]) + key_length := binary.BigEndian.Uint32(data[20:24]) + node_type_hash := binary.BigEndian.Uint64(data[24:32]) ptr += NODE_DB_HEADER_LEN @@ -573,6 +607,26 @@ func NewNodeDB(data []byte) (NodeDB, error) { ptr += int(EXTENSION_DB_HEADER_LEN + length) } + policies := make([]PolicyDB, num_policies) + for i, _ := range(policies) { + cur := data[ptr:] + type_hash := binary.BigEndian.Uint64(cur[0:8]) + length := binary.BigEndian.Uint64(cur[8:16]) + + data_start := uint64(POLICY_DB_HEADER_LEN) + data_end := data_start + length + policy_data := cur[data_start:data_end] + + policies[i] = PolicyDB{ + PolicyDBHeader{ + type_hash, + length, + }, + policy_data, + } + ptr += int(POLICY_DB_HEADER_LEN + length) + } + queued_signals := make([]QSignalDB, num_queued_signals) for i, _ := range(queued_signals) { cur := data[ptr:] @@ -626,10 +680,11 @@ func (header NodeDBHeader) Serialize() []byte { ret := make([]byte, NODE_DB_HEADER_LEN) binary.BigEndian.PutUint32(ret[0:4], header.Magic) binary.BigEndian.PutUint32(ret[4:8], header.NumExtensions) - binary.BigEndian.PutUint32(ret[8:12], header.NumQueuedSignals) - binary.BigEndian.PutUint32(ret[12:16], header.BufferSize) - binary.BigEndian.PutUint32(ret[16:20], header.KeyLength) - binary.BigEndian.PutUint64(ret[20:28], header.TypeHash) + binary.BigEndian.PutUint32(ret[8:12], header.NumPolicies) + binary.BigEndian.PutUint32(ret[12:16], header.NumQueuedSignals) + binary.BigEndian.PutUint32(ret[16:20], header.BufferSize) + binary.BigEndian.PutUint32(ret[20:24], header.KeyLength) + binary.BigEndian.PutUint64(ret[24:32], header.TypeHash) return ret } @@ -673,28 +728,6 @@ func (extension ExtensionDB) Serialize() []byte { return append(header_bytes, extension.Data...) } -type QSignalDBHeader struct { - SignalID uuid.UUID - Time time.Time - TypeHash uint64 - Length uint64 -} - -type QSignalDB struct { - Header QSignalDBHeader - Data []byte -} - -type ExtensionDBHeader struct { - TypeHash uint64 - Length uint64 -} - -type ExtensionDB struct { - Header ExtensionDBHeader - Data []byte -} - // Write a node to the database func WriteNode(ctx *Context, node *Node) error { ctx.Log.Logf("db", "DB_WRITE: %s", node.ID) @@ -740,6 +773,21 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } + policies := make(map[PolicyType]Policy, node_db.Header.NumPolicies) + for _, policy_db := range(node_db.Policies) { + policy_info, exists := ctx.Policies[policy_db.Header.TypeHash] + if exists == false { + return nil, fmt.Errorf("0x%x is not a known policy type", policy_db.Header.TypeHash) + } + + policy, err := policy_info.Load(ctx, policy_db.Data) + if err != nil { + return nil, err + } + + policies[policy_info.Type] = policy + } + key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes) if err != nil { return nil, err @@ -784,7 +832,8 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { ID: key_id, Type: node_type.Type, Extensions: map[ExtType]Extension{}, - MsgChan: make(chan Msg, node_db.Header.BufferSize), + Policies: policies, + MsgChan: make(chan Message, node_db.Header.BufferSize), BufferSize: node_db.Header.BufferSize, TimeoutChan: timeout_chan, SignalQueue: signal_queue, diff --git a/node_test.go b/node_test.go index cef5c28..44e524b 100644 --- a/node_test.go +++ b/node_test.go @@ -23,7 +23,7 @@ func TestNodeDB(t *testing.T) { func TestNodeRead(t *testing.T) { ctx := logTestContext(t, []string{}) node_type := NodeType("TEST") - err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType}) + err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType, ECDHExtType}) fatalErr(t, err) n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader) @@ -37,21 +37,15 @@ func TestNodeRead(t *testing.T) { ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N2: %s", n2_id) - n2_policy := NewPerNodePolicy(map[NodeID]Actions{ - n1_id: Actions{MakeAction(ReadResultSignalType, "+")}, - }) n2_listener := NewListenerExt(10) - n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(&n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener) + n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt(), n2_listener) - n1_policy := NewPerNodePolicy(map[NodeID]Actions{ - n2_id: Actions{MakeAction(ReadSignalType, "+")}, - }) - n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(&n1_policy), NewGroupExt(nil), NewECDHExt()) + n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewGroupExt(nil), NewECDHExt()) read_sig := NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, }) - ctx.Send(n2.ID, n1.ID, &read_sig) + ctx.Send(n2.ID, []Message{{n1.ID, &read_sig}}) res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { return true @@ -64,22 +58,20 @@ func TestECDH(t *testing.T) { ctx := logTestContext(t, []string{"test", "ecdh", "policy"}) node_type := NodeType("TEST") - err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, ECDHExtType}) + err := ctx.RegisterNodeType(node_type, []ExtType{ECDHExtType}) fatalErr(t, err) n1_listener := NewListenerExt(10) - ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")}) - n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt(), n1_listener) - n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt()) + n1 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n1_listener) + n2 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt()) n3_listener := NewListenerExt(10) - n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}}) - n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy, &n3_policy), NewECDHExt(), n3_listener) + n3 := NewNode(ctx, nil, node_type, 10, nil, NewECDHExt(), n3_listener) ctx.Log.Logf("test", "N1: %s", n1.ID) ctx.Log.Logf("test", "N2: %s", n2.ID) - ecdh_req, n1_ec, err := NewECDHReqSignal(ctx, n1) + ecdh_req, n1_ec, err := NewECDHReqSignal(n1) ecdh_ext, err := GetExt[*ECDHExt](n1) fatalErr(t, err) ecdh_ext.ECDHStates[n2.ID] = ECDHState{ @@ -88,7 +80,7 @@ func TestECDH(t *testing.T) { } fatalErr(t, err) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) - err = ctx.Send(n1.ID, n2.ID, &ecdh_req) + err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_req}}) fatalErr(t, err) _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { @@ -100,6 +92,6 @@ func TestECDH(t *testing.T) { ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret) fatalErr(t, err) - err = ctx.Send(n1.ID, n2.ID, &ecdh_sig) + err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_sig}}) fatalErr(t, err) } diff --git a/policy.go b/policy.go index 7a5ca1f..2245176 100644 --- a/policy.go +++ b/policy.go @@ -5,11 +5,8 @@ import ( "fmt" ) -type PolicyType string -func (policy PolicyType) Prefix() string { return "POLICY: " } -func (policy PolicyType) String() string { return string(policy) } - const ( + UserOfPolicyType = PolicyType("USER_OF") RequirementOfPolicyType = PolicyType("REQUIREMENT_OF") PerNodePolicyType = PolicyType("PER_NODE") AllNodesPolicyType = PolicyType("ALL_NODES") @@ -38,7 +35,7 @@ func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Nod return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) } -func (policy RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { +func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { lockable_ext, err := GetExt[*LockableExt](node) if err != nil { return err @@ -53,10 +50,45 @@ func (policy RequirementOfPolicy) Allows(principal_id NodeID, action Action, nod return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) } +type UserOfPolicy struct { + PerNodePolicy +} + +func (policy *UserOfPolicy) Type() PolicyType { + return UserOfPolicyType +} + +func NewUserOfPolicy(group_actions NodeActions) UserOfPolicy { + return UserOfPolicy{ + PerNodePolicy: NewPerNodePolicy(group_actions), + } +} + +// Send a read signal to Group to check if principal_id is a member of it +func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { + // Send a read signal to each of the groups in the map + // Check for principal_id in any of the returned member lists(skipping errors) + // Return an error in the default case + return fmt.Errorf("NOT_IMPLEMENTED") +} + +func (policy *UserOfPolicy) Merge(p Policy) Policy { + other := p.(*UserOfPolicy) + policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions) + return policy +} + +func (policy *UserOfPolicy) Copy() Policy { + new_actions := CopyNodeActions(policy.NodeActions) + return &UserOfPolicy{ + PerNodePolicy: NewPerNodePolicy(new_actions), + } +} + type RequirementOfPolicy struct { AllNodesPolicy } -func (policy RequirementOfPolicy) Type() PolicyType { +func (policy *RequirementOfPolicy) Type() PolicyType { return RequirementOfPolicyType } @@ -82,20 +114,25 @@ func CopyNodeActions(actions NodeActions) NodeActions { return ret } -func MergeNodeActions(modified NodeActions, read NodeActions) { - for id, actions := range(read) { - existing, exists := modified[id] +func MergeNodeActions(first NodeActions, second NodeActions) NodeActions { + merged := NodeActions{} + for id, actions := range(first) { + merged[id] = actions + } + for id, actions := range(second) { + existing, exists := merged[id] if exists { - modified[id] = MergeActions(existing, actions) + merged[id] = MergeActions(existing, actions) } else { - modified[id] = actions + merged[id] = actions } } + return merged } func (policy *PerNodePolicy) Merge(p Policy) Policy { other := p.(*PerNodePolicy) - MergeNodeActions(policy.NodeActions, other.NodeActions) + policy.NodeActions = MergeNodeActions(policy.NodeActions, other.NodeActions) return policy } @@ -263,63 +300,6 @@ func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error { return json.Unmarshal(data, policy) } -// Extension to allow a node to hold ACL policies -type ACLExt struct { - Policies map[PolicyType]Policy -} - - -func NodeList(nodes ...*Node) NodeMap { - m := NodeMap{} - for _, node := range(nodes) { - m[node.ID] = node - } - return m -} - -type PolicyLoadFunc func(*Context,[]byte) (Policy, error) -type ACLExtContext struct { - Loads map[PolicyType]PolicyLoadFunc -} - -func NewACLExtContext() *ACLExtContext { - return &ACLExtContext{ - Loads: map[PolicyType]PolicyLoadFunc{ - AllNodesPolicyType: LoadPolicy[AllNodesPolicy,*AllNodesPolicy], - PerNodePolicyType: LoadPolicy[PerNodePolicy,*PerNodePolicy], - RequirementOfPolicyType: LoadPolicy[RequirementOfPolicy,*RequirementOfPolicy], - }, - } -} - -func (ext *ACLExt) Serialize() ([]byte, error) { - policies := map[string][]byte{} - for name, policy := range(ext.Policies) { - ser, err := policy.Serialize() - if err != nil { - return nil, err - } - policies[string(name)] = ser - } - - return json.MarshalIndent(&struct{ - Policies map[string][]byte `json:"policies"` - }{ - Policies: policies, - }, "", " ") -} - -func (ext *ACLExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { -} - -func (ext *ACLExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*ACLExt)interface{}{ - "policies": func(ext *ACLExt) interface{} { - return ext.Policies - }, - }) -} - var ErrorSignalAction = Action{"ERROR_RESP"} var ReadResultSignalAction = Action{"READ_RESULT"} var AuthorizedSignalAction = Action{"AUTHORIZED_READ"} @@ -327,82 +307,3 @@ var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSigna var DefaultACLPolicies = []Policy{ &defaultPolicy, } - -func NewACLExt(policies ...Policy) *ACLExt { - policy_map := map[PolicyType]Policy{} - for _, policy_arg := range(append(policies, DefaultACLPolicies...)) { - policy := policy_arg.Copy() - existing, exists := policy_map[policy.Type()] - if exists == true { - policy = existing.Merge(policy) - } - - policy_map[policy.Type()] = policy - } - - return &ACLExt{ - Policies: policy_map, - } -} - -func LoadPolicy[T any, P interface { - *T - Policy -}](ctx *Context, data []byte) (Policy, error) { - p := P(new(T)) - err := p.Deserialize(ctx, data) - if err != nil { - return nil, err - } - - return p, nil -} - -func (ext *ACLExt) Deserialize(ctx *Context, data []byte) error { - var j struct { - Policies map[string][]byte `json:"policies"` - } - - err := json.Unmarshal(data, &j) - if err != nil { - return err - } - - acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx) - if err != nil { - return err - } - ext.Policies = map[PolicyType]Policy{} - - for name, ser := range(j.Policies) { - policy_load, exists := acl_ctx.Loads[PolicyType(name)] - if exists == false { - return fmt.Errorf("%s is not a known policy type", name) - } - policy, err := policy_load(ctx, ser) - if err != nil { - return err - } - - ext.Policies[PolicyType(name)] = policy - } - - return nil -} - -func (ext *ACLExt) Type() ExtType { - return ACLExtType -} - -// Check if the extension allows the principal to perform action on node -func (ext *ACLExt) Allows(ctx *Context, principal_id NodeID, action Action, node *Node) error { - errs := []error{} - for _, policy := range(ext.Policies) { - err := policy.Allows(principal_id, action, node) - if err == nil { - return nil - } - errs = append(errs, err) - } - return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) -} diff --git a/signal.go b/signal.go index 418171c..b950c79 100644 --- a/signal.go +++ b/signal.go @@ -28,7 +28,6 @@ const ( ReadResultSignalType = "READ_RESULT" LinkStartSignalType = "LINK_START" ECDHSignalType = "ECDH" - ECDHStateSignalType = "ECDH_STATE" ECDHProxySignalType = "ECDH_PROXY" GQLStateSignalType = "GQL_STATE" @@ -43,6 +42,7 @@ func (signal_type SignalType) Prefix() string { return "SIGNAL: " } type Signal interface { Serializable[SignalType] + String() string Direction() SignalDirection ID() uuid.UUID Permission() Action @@ -70,12 +70,12 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time. } for true { select { - case signal := <- listener.Chan: - if signal == nil { + case msg := <- listener.Chan: + if msg.Signal == nil { return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) } - if signal.Type() == signal_type { - sig, ok := signal.(S) + if msg.Signal.Type() == signal_type { + sig, ok := msg.Signal.(S) if ok == true { if check(sig) == true { return sig, nil @@ -96,6 +96,11 @@ type BaseSignal struct { UUID uuid.UUID `json:"id"` } +func (signal *BaseSignal) String() string { + ser, _ := json.Marshal(signal) + return string(ser) +} + func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error { return json.Unmarshal(data, signal) } @@ -129,21 +134,9 @@ func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal return signal } -func NewDownSignal(signal_type SignalType) BaseSignal { - return NewBaseSignal(signal_type, Down) -} - -func NewUpSignal(signal_type SignalType) BaseSignal { - return NewBaseSignal(signal_type, Up) -} - -func NewDirectSignal(signal_type SignalType) BaseSignal { - return NewBaseSignal(signal_type, Direct) -} - -var NewSignal = NewDirectSignal(NewSignalType) -var StartSignal = NewDirectSignal(StartSignalType) -var StopSignal = NewDownSignal(StopSignalType) +var NewSignal = NewBaseSignal(NewSignalType, Direct) +var StartSignal = NewBaseSignal(StartSignalType, Direct) +var StopSignal = NewBaseSignal(StopSignalType, Direct) type IDSignal struct { BaseSignal @@ -154,88 +147,91 @@ func (signal *IDSignal) Serialize() ([]byte, error) { return json.Marshal(signal) } -func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal { - return IDSignal{ - BaseSignal: NewBaseSignal(signal_type, direction), - NodeID: id, - } -} - type StringSignal struct { BaseSignal Str string `json:"state"` } +func (signal *StringSignal) String() string { + ser, _ := json.Marshal(signal) + return string(ser) +} + func (signal *StringSignal) Serialize() ([]byte, error) { return json.Marshal(&signal) } +type RespSignal struct { + BaseSignal + ReqID uuid.UUID +} + type ErrorSignal struct { - StringSignal + RespSignal + Error string +} + +func (signal *ErrorSignal) String() string { + ser, _ := json.Marshal(signal) + return string(ser) } func (signal *ErrorSignal) Permission() Action { return ErrorSignalAction } -func NewErrorSignal(req_id uuid.UUID, err string) ErrorSignal { - return ErrorSignal{ - StringSignal{ - NewDirectSignal(ErrorSignalType), - err, +func NewErrorSignal(req_id uuid.UUID, err string) Signal { + return &ErrorSignal{ + RespSignal{ + NewBaseSignal(ErrorSignalType, Direct), + req_id, }, + err, } } type IDStringSignal struct { BaseSignal - NodeID `json:"node_id"` + NodeID NodeID `json:"node_id"` Str string `json:"string"` } -func (signal *IDStringSignal) Serialize() ([]byte, error) { - return json.Marshal(signal) -} - func (signal *IDStringSignal) String() string { - ser, err := json.Marshal(signal) - if err != nil { - return "STATE_SER_ERR" - } + ser, _ := json.Marshal(signal) return string(ser) } -func NewStatusSignal(status string, source NodeID) IDStringSignal { - return IDStringSignal{ - BaseSignal: NewUpSignal(StatusSignalType), +func (signal *IDStringSignal) Serialize() ([]byte, error) { + return json.Marshal(signal) +} + +func NewStatusSignal(status string, source NodeID) Signal { + return &IDStringSignal{ + BaseSignal: NewBaseSignal(StatusSignalType, Up), NodeID: source, Str: status, } } -func NewLinkSignal(state string) StringSignal { - return StringSignal{ - BaseSignal: NewDirectSignal(LinkSignalType), +func NewLinkSignal(state string) Signal { + return &StringSignal{ + BaseSignal: NewBaseSignal(LinkSignalType, Direct), Str: state, } } -func NewIDStringSignal(signal_type SignalType, direction SignalDirection, state string, id NodeID) IDStringSignal { - return IDStringSignal{ - BaseSignal: NewBaseSignal(signal_type, direction), - NodeID: id, - Str: state, +func NewLinkStartSignal(link_type string, target NodeID) Signal { + return &IDStringSignal{ + NewBaseSignal(LinkStartSignalType, Direct), + target, + link_type, } } -func NewLinkStartSignal(link_type string, target NodeID) IDStringSignal { - return NewIDStringSignal(LinkStartSignalType, Direct, link_type, target) -} - -func NewLockSignal(state string) StringSignal { - return StringSignal{ - BaseSignal: NewDirectSignal(LockSignalType), - Str: state, +func NewLockSignal(state string) Signal { + return &StringSignal{ + NewBaseSignal(LockSignalType, Direct), + state, } } @@ -259,22 +255,22 @@ func (signal *AuthorizedSignal) Permission() Action { return AuthorizedSignalAction } -func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (AuthorizedSignal, error) { +func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (Signal, error) { sig_data, err := signal.Serialize() if err != nil { - return AuthorizedSignal{}, err + return nil, err } sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) if err != nil { - return AuthorizedSignal{}, err + return nil, err } - return AuthorizedSignal{ - BaseSignal: NewDirectSignal(AuthorizedSignalType), - Principal: principal.Public().(ed25519.PublicKey), - Signal: signal, - Signature: sig, + return &AuthorizedSignal{ + NewBaseSignal(AuthorizedSignalType, Direct), + principal.Public().(ed25519.PublicKey), + signal, + sig, }, nil } @@ -284,13 +280,13 @@ func (signal *ReadSignal) Serialize() ([]byte, error) { func NewReadSignal(exts map[ExtType][]string) ReadSignal { return ReadSignal{ - BaseSignal: NewDirectSignal(ReadSignalType), - Extensions: exts, + NewBaseSignal(ReadSignalType, Direct), + exts, } } type ReadResultSignal struct { - BaseSignal + RespSignal NodeType Extensions map[ExtType]map[string]interface{} `json:"extensions"` } @@ -299,15 +295,14 @@ func (signal *ReadResultSignal) Permission() Action { return ReadResultSignalAction } -func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal { - return ReadResultSignal{ - BaseSignal: BaseSignal{ - Direct, - ReadResultSignalType, +func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]map[string]interface{}) Signal { + return &ReadResultSignal{ + RespSignal{ + NewBaseSignal(ReadResultSignalType, Direct), req_id, }, - NodeType: node_type, - Extensions: exts, + node_type, + exts, } } @@ -341,28 +336,28 @@ func (signal *ECDHSignal) Serialize() ([]byte, error) { return json.Marshal(signal) } -func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) { - ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) +func NewECDHReqSignal(node *Node) (Signal, *ecdh.PrivateKey, error) { + ec_key, err := ECDH.GenerateKey(rand.Reader) if err != nil { - return ECDHSignal{}, nil, err + return nil, nil, err } now := time.Now() time_bytes, err := now.MarshalJSON() if err != nil { - return ECDHSignal{}, nil, err + return nil, nil, err } sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) if err != nil { - return ECDHSignal{}, nil, err + return nil, nil, err } - return ECDHSignal{ + return &ECDHSignal{ StringSignal: StringSignal{ - BaseSignal: NewDirectSignal(ECDHSignalType), + BaseSignal: NewBaseSignal(ECDHSignalType, Direct), Str: "req", }, Time: now, @@ -374,7 +369,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e const DEFAULT_ECDH_WINDOW = time.Second -func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) { +func NewECDHRespSignal(node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) { now := time.Now() err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) @@ -382,7 +377,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [ return ECDHSignal{}, nil, err } - ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) + ec_key, err := ECDH.GenerateKey(rand.Reader) if err != nil { return ECDHSignal{}, nil, err } @@ -406,7 +401,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [ return ECDHSignal{ StringSignal: StringSignal{ - BaseSignal: NewDirectSignal(ECDHSignalType), + BaseSignal: NewBaseSignal(ECDHSignalType, Direct), Str: "resp", }, Time: now, @@ -449,34 +444,34 @@ type ECDHProxySignal struct { Data []byte } -func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (ECDHProxySignal, error) { +func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (Signal, error) { if shared_secret == nil { - return ECDHProxySignal{}, fmt.Errorf("need shared_secret") + return nil, fmt.Errorf("need shared_secret") } aes_key, err := aes.NewCipher(shared_secret[:32]) if err != nil { - return ECDHProxySignal{}, err + return nil, err } ser, err := SerializeSignal(signal, aes_key.BlockSize()) if err != nil { - return ECDHProxySignal{}, err + return nil, err } iv := make([]byte, aes_key.BlockSize()) n, err := rand.Reader.Read(iv) if err != nil { - return ECDHProxySignal{}, err + return nil, err } else if n != len(iv) { - return ECDHProxySignal{}, fmt.Errorf("Not enough bytes read for IV") + return nil, fmt.Errorf("Not enough bytes read for IV") } encrypter := cipher.NewCBCEncrypter(aes_key, iv) encrypter.CryptBlocks(ser, ser) - return ECDHProxySignal{ - BaseSignal: NewDirectSignal(ECDHProxySignalType), + return &ECDHProxySignal{ + BaseSignal: NewBaseSignal(ECDHProxySignalType, Direct), Source: source, Dest: dest, IV: iv, diff --git a/user.go b/user.go index 7371af6..94be32f 100644 --- a/user.go +++ b/user.go @@ -48,7 +48,7 @@ func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error { return err } -func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { - return +func (ext *GroupExt) Process(ctx *Context, node *Node, msg Message) []Message { + return nil }