diff --git a/context.go b/context.go index 6048e50..835e2ef 100644 --- a/context.go +++ b/context.go @@ -235,25 +235,19 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) { return target, nil } -// Stop every running loop -func (ctx *Context) Stop() { - for _, node := range(ctx.Nodes) { - node.MsgChan <- Message{ZeroID, &StopSignal} - } -} - // Route a Signal to dest. Currently only local context routing is supported -func (ctx *Context) Send(source NodeID, messages []Message) error { +func (ctx *Context) Send(messages Messages) error { for _, msg := range(messages) { - target, err := ctx.GetNode(msg.NodeID) + target, err := ctx.GetNode(msg.Dest) if err == nil { select { - case target.MsgChan <- Message{source, msg.Signal}: + case target.MsgChan <- msg: + ctx.Log.Logf("signal", "Sent %s -> %+v", target.ID, msg) default: buf := make([]byte, 4096) n := runtime.Stack(buf, false) stack_str := string(buf[:n]) - return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.NodeID, stack_str) + return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.Dest, stack_str) } } else if errors.Is(err, NodeNotFoundError) { // TODO: Handle finding nodes in other contexts diff --git a/ecdh.go b/ecdh.go index 401f813..837a856 100644 --- a/ecdh.go +++ b/ecdh.go @@ -103,10 +103,10 @@ func (ext *ECDHExt) Field(name string) interface{} { }) } -func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) []Message { +func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) Messages { source := KeyID(signal.EDDSA) - messages := []Message{} + messages := Messages{} switch signal.Str { case "req": state, exists := ext.ECDHStates[source] @@ -118,15 +118,15 @@ func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) state.SharedSecret = shared_secret ext.ECDHStates[source] = state log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - messages = append(messages, Message{source, &resp}) + messages = messages.Add(log, node.ID, node.Key, &resp, source) } else { log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), err.Error())}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source) } case "resp": state, exists := ext.ECDHStates[source] if exists == false || state.ECKey == nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "no_req")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source) } else { err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) if err == nil { @@ -145,10 +145,10 @@ func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) return messages } -func (ext *ECDHExt) Process(ctx *Context, node *Node, msg Message) []Message { - switch msg.Signal.Type() { +func (ext *ECDHExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { + switch signal.Type() { case ECDHSignalType: - sig := msg.Signal.(*ECDHSignal) + sig := signal.(*ECDHSignal) return ext.HandleECDHSignal(ctx.Log, node, sig) } return nil diff --git a/gql.go b/gql.go index df9dde6..3617baa 100644 --- a/gql.go +++ b/gql.go @@ -921,31 +921,7 @@ func NewGQLExtContext() *GQLExtContext { context.Mutation.AddFieldConfig("stop", &graphql.Field{ Type: graphql.String, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - ctx, err := PrepResolve(p) - if err != nil { - return nil, err - } - - sig, err := NewAuthorizedSignal(ctx.Key, &StopSignal) - if err != nil { - return nil, err - } - - response_chan := ctx.Ext.GetResponseChannel(sig.ID()) - err = ctx.Context.Send(ctx.Server.ID, []Message{Message{ctx.Server.ID, sig}}) - if err != nil { - ctx.Ext.FreeResponseChannel(sig.ID()) - return nil, err - } - - resp, err := WaitForResult(response_chan, 100*time.Millisecond, sig.ID()) - if err != nil { - return nil, err - } - - ser, err := resp.Serialize() - - return string(ser), err + return nil, fmt.Errorf("NOT_IMPLEMENTED") }, }) @@ -1053,10 +1029,9 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal { } } -func (ext *GQLExt) Process(ctx *Context, node *Node, msg Message) []Message { +func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { // Process ReadResultSignalType by forwarding it to the waiting resolver - signal := msg.Signal - messages := []Message{} + messages := Messages{} if signal.Type() == ErrorSignalType { // TODO: Forward to resolver if waiting for it sig := signal.(*ErrorSignal) diff --git a/gql_query.go b/gql_query.go index 407c85d..736537e 100644 --- a/gql_query.go +++ b/gql_query.go @@ -2,7 +2,6 @@ package graphvent import ( "time" "reflect" - "fmt" "github.com/graphql-go/graphql" "github.com/graphql-go/graphql/language/ast" "github.com/google/uuid" @@ -51,17 +50,15 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ } // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) - auth_signal, err := NewAuthorizedSignal(ctx.Key, &read_signal) - if err != nil { - return nil, err - } - + msgs := Messages{} + msgs = msgs.Add(ctx.Context.Log, ctx.Server.ID, ctx.Key, read_signal, id) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) resp_channels[read_signal.ID()] = response_chan node_ids[read_signal.ID()] = id - err = ctx.Context.Send(ctx.Server.ID, []Message{Message{id, auth_signal}}) + // TODO: Send all at once instead of createing Messages for each + err = ctx.Context.Send(msgs) if err != nil { ctx.Ext.FreeResponseChannel(read_signal.ID()) return nil, err @@ -71,18 +68,13 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ responses := []NodeResult{} for sig_id, response_chan := range(resp_channels) { // Wait for the response, returning an error on timeout - response, err := WaitForResult(response_chan, time.Millisecond*100, sig_id) + response, err := WaitForSignal(ctx.Context, response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{ + return sig.ReqID == sig_id + }) if err != nil { return nil, err } - switch resp := response.(type) { - case *ReadResultSignal: - responses = append(responses, NodeResult{node_ids[sig_id], resp}) - case *ErrorSignal: - return nil, fmt.Errorf(resp.Error) - default: - return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp)) - } + responses = append(responses, NodeResult{node_ids[sig_id], response}) } ctx.Context.Log.Logf("gql", "RESOLVED_NODES") diff --git a/gql_test.go b/gql_test.go index 4cfae45..7c161b0 100644 --- a/gql_test.go +++ b/gql_test.go @@ -14,25 +14,42 @@ import ( ) func TestGQLServer(t *testing.T) { - ctx := logTestContext(t, []string{"gql", "lockable", "signal"}) + ctx := logTestContext(t, []string{"test"}) TestNodeType := NodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) fatalErr(t, err) + policy := NewAllNodesPolicy(Actions{ + MakeAction(LinkSignalType, "+"), + MakeAction(LinkStartSignalType, "+"), + MakeAction(LockSignalType, "+"), + MakeAction(StatusSignalType, "+"), + MakeAction(ErrorSignalType, "+"), + MakeAction(ReadSignalType, "+"), + MakeAction(ReadResultSignalType, "+"), + MakeAction(GQLStateSignalType, "+"), + }) + gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") fatalErr(t, err) listener_ext := NewListenerExt(10) - gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext) - n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt()) - - err = LinkRequirement(ctx, gql.ID, n1.ID) + gql := NewNode(ctx, nil, GQLNodeType, 10, map[PolicyType]Policy{ + AllNodesPolicyType: &policy, + }, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext) + n1 := NewNode(ctx, nil, TestNodeType, 10, map[PolicyType]Policy{ + AllNodesPolicyType: &policy, + }, NewLockableExt()) + + err = LinkRequirement(ctx, gql, n1.ID) fatalErr(t, err) - err = ctx.Send(gql.ID, []Message{{gql.ID, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}}}) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID) + err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "server_started" }) fatalErr(t, err) @@ -84,9 +101,11 @@ func TestGQLServer(t *testing.T) { resp_2 := SendGQL(req_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2) - stop_signal := StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"} - ctx.Send(gql.ID, []Message{{gql.ID, &stop_signal}}) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + msgs = Messages{} + msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID) + err = ctx.Send(msgs) + fatalErr(t, err) + _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "server_stopped" }) fatalErr(t, err) @@ -111,8 +130,11 @@ func TestGQLDB(t *testing.T) { NewGroupExt(nil)) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - ctx.Stop() - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + msgs := Messages{} + msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StopSignal, gql.ID) + err = ctx.Send(msgs) + fatalErr(t, err) + _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql.ID }) fatalErr(t, err) @@ -130,9 +152,11 @@ func TestGQLDB(t *testing.T) { fatalErr(t, err) listener_ext, err = GetExt[*ListenerExt](gql_loaded) fatalErr(t, err) - err = ctx.Send(gql_loaded.ID, []Message{{gql_loaded.ID, &StopSignal}}) + msgs = Messages{} + msgs = msgs.Add(ctx.Log, gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID) + err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID }) fatalErr(t, err) diff --git a/lockable.go b/lockable.go index 359ab86..2b0c223 100644 --- a/lockable.go +++ b/lockable.go @@ -7,14 +7,14 @@ import ( // A Listener extension provides a channel that can receive signals on a different thread type ListenerExt struct { Buffer int - Chan chan Message + Chan chan Signal } // Create a new listener extension with a given buffer size func NewListenerExt(buffer int) *ListenerExt { return &ListenerExt{ Buffer: buffer, - Chan: make(chan Message, buffer), + Chan: make(chan Signal, buffer), } } @@ -32,7 +32,7 @@ func (ext *ListenerExt) Field(name string) interface{} { // Simple load function, unmarshal the buffer int from json func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { err := json.Unmarshal(data, &ext.Buffer) - ext.Chan = make(chan Message, ext.Buffer) + ext.Chan = make(chan Signal, ext.Buffer) return err } @@ -41,10 +41,10 @@ func (listener *ListenerExt) Type() ExtType { } // Send the signal to the channel, logging an overflow if it occurs -func (ext *ListenerExt) Process(ctx *Context, node *Node, msg Message) []Message { - ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, msg.Signal) +func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { + ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) select { - case ext.Chan <- msg: + case ext.Chan <- signal: default: ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) } @@ -138,38 +138,44 @@ func NewLockableExt() *LockableExt { // Send the signal to unlock a node from itself func UnlockLockable(ctx *Context, node *Node) error { - return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("unlock")}}) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewLockSignal("unlock"), node.ID) + return ctx.Send(msgs) } // Send the signal to lock a node from itself func LockLockable(ctx *Context, node *Node) error { - return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("lock")}}) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewLockSignal("lock"), node.ID) + return ctx.Send(msgs) } // Setup a node to send the initial requirement link signal, then send the signal -func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error { - return ctx.Send(dependency, []Message{Message{dependency, NewLinkStartSignal("req", requirement)}}) +func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error { + msgs := Messages{} + msgs = msgs.Add(ctx.Log, dependency.ID, dependency.Key, NewLinkStartSignal("req", requirement), dependency.ID) + return ctx.Send(msgs) } // Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message { +func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { state := signal.Str log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) - messages := []Message{} + messages := Messages{} switch state { case "unlock": if ext.Owner == nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source) } else if source != *ext.Owner { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_owner")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source) } else if ext.PendingOwner == nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source) } else { if len(ext.Requirements) == 0 { ext.Owner = nil ext.PendingOwner = nil - messages = append(messages, Message{source, NewLockSignal("unlocked")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocked"), source) } else { ext.PendingOwner = nil for id, state := range(ext.Requirements) { @@ -179,22 +185,22 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } state.Lock = "unlocking" ext.Requirements[id] = state - messages = append(messages, Message{id, NewLockSignal("unlock")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlock"), id) } } if source != node.ID { - messages = append(messages, Message{source, NewLockSignal("unlocking")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocking"), source) } } } case "unlocking": state, exists := ext.Requirements[source] if exists == false { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "unlocking" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } case "unlocked": @@ -204,11 +210,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "unlocking" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source) } else { state.Lock = "unlocked" ext.Requirements[source] = state @@ -228,7 +234,7 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, if linked == unlocked { previous_owner := *ext.Owner ext.Owner = nil - messages = append(messages, Message{previous_owner, NewLockSignal("unlocked")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocked"), previous_owner) } } } @@ -239,11 +245,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, state, exists := ext.Requirements[source] if exists == false { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "locking" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) } else { state.Lock = "locked" ext.Requirements[source] = state @@ -262,31 +268,31 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, if linked == locked { ext.Owner = ext.PendingOwner - messages = append(messages, Message{*ext.Owner, NewLockSignal("locked")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner) } } } case "locking": state, exists := ext.Requirements[source] if exists == false { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source) } else if state.Link != "linked" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source) } else if state.Lock != "locking" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source) } case "lock": if ext.Owner != nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) } else if ext.PendingOwner != nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source) } else { owner := source if len(ext.Requirements) == 0 { ext.Owner = &owner ext.PendingOwner = ext.Owner - messages = append(messages, Message{source, NewLockSignal("locked")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locked"), source) } else { ext.PendingOwner = &owner for id, state := range(ext.Requirements) { @@ -297,11 +303,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, } state.Lock = "locking" ext.Requirements[id] = state - messages = append(messages, Message{id, NewLockSignal("lock")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("lock"), id) } } if source != node.ID { - messages = append(messages, Message{source, NewLockSignal("locking")}) + messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locking"), source) } } } @@ -312,36 +318,36 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, return messages } -func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) []Message { +func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) Messages { link_type := signal.Str target := signal.NodeID log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target) - messages := []Message{} + messages := Messages{} switch link_type { case "req": state, exists := ext.Requirements[target] _, dep_exists := ext.Dependencies[target] if ext.Owner != nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already locked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source) } else if ext.Owner != ext.PendingOwner { if ext.PendingOwner == nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) } else { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) } } else if exists == true { if state.Link == "linking" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking_req")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking_req"), source) } else if state.Link == "linked" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_req")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_req"), source) } } else if dep_exists == true { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_dep")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_dep"), source) } else { ext.Requirements[target] = LinkState{"linking", "unlocked", source} - messages = append(messages, Message{target, NewLinkSignal("linked_as_req")}) - messages = append(messages, Message{source, NewLinkStartSignal("linking_req", target)}) + messages = messages.Add(log, node.ID, node.Key, NewLinkSignal("linked_as_req"), target) + messages = messages.Add(log, node.ID, node.Key, NewLinkStartSignal("linking_req", target), source) } } return messages @@ -349,16 +355,16 @@ func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source Nod // Handle LinkSignal, updating the extensions requirements and dependencies as necessary // TODO: Add unlink -func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message { +func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages { log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) state := signal.Str - messages := []Message{} + messages := Messages{} switch state { case "dep_done": state, exists := ext.Requirements[source] if exists == false { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linking"), source) } else if state.Link == "linking" { state.Link = "linked" ext.Requirements[source] = state @@ -368,16 +374,16 @@ func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, state, exists := ext.Dependencies[source] if exists == false { ext.Dependencies[source] = LinkState{"linked", "unlocked", source} - messages = append(messages, Message{source, NewLinkSignal("dep_done")}) + messages = messages.Add(log, node.ID, node.Key, NewLinkSignal("dep_done"), source) } else if state.Link == "linking" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking"), source) } else if state.Link == "linked" { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linked")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linked"), source) } else if ext.PendingOwner != ext.Owner { if ext.Owner == nil { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source) } else { - messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")}) + messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source) } } @@ -389,15 +395,15 @@ func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state -func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message { - messages := []Message{} - switch msg.Signal.Direction() { +func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { + messages := Messages{} + switch signal.Direction() { case Up: ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies) owner_sent := false for dependency, state := range(ext.Dependencies) { if state.Link == "linked" { - messages = append(messages, Message{dependency, msg.Signal}) + messages = messages.Add(ctx.Log, node.ID, node.Key, signal, dependency) if ext.Owner != nil { if dependency == *ext.Owner { owner_sent = true @@ -408,23 +414,23 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message if ext.Owner != nil && owner_sent == false { if *ext.Owner != node.ID { - messages = append(messages, Message{*ext.Owner, msg.Signal}) + messages = messages.Add(ctx.Log, node.ID, node.Key, signal, *ext.Owner) } } case Down: for requirement, state := range(ext.Requirements) { if state.Link == "linked" { - messages = append(messages, Message{requirement, msg.Signal}) + messages = messages.Add(ctx.Log, node.ID, node.Key, signal, requirement) } } case Direct: - switch msg.Signal.Type() { + switch signal.Type() { case LinkSignalType: - messages = ext.HandleLinkSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal)) + messages = ext.HandleLinkSignal(ctx.Log, node, source, signal.(*StringSignal)) case LockSignalType: - messages = ext.HandleLockSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal)) + messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal)) case LinkStartSignalType: - messages = ext.HandleLinkStartSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*IDStringSignal)) + messages = ext.HandleLinkStartSignal(ctx.Log, node, source, signal.(*IDStringSignal)) default: } default: diff --git a/lockable_test.go b/lockable_test.go index c436176..584fb96 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -35,24 +35,25 @@ func TestLink(t *testing.T) { ) // Link l2 as a requirement of l1 - err := LinkRequirement(ctx, l1.ID, l2.ID) + err := LinkRequirement(ctx, l1, l2.ID) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { return sig.Str == "dep_done" }) fatalErr(t, err) - sig1 := NewStatusSignal("TEST", l2.ID) - err = ctx.Send(l2.ID, []Message{{l2.ID, sig1}}) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, l2.ID, l2.Key, NewStatusSignal("TEST", l2.ID), l2.ID) + err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "TEST" }) fatalErr(t, err) - _, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { + _, err = WaitForSignal(ctx, l2_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "TEST" }) fatalErr(t, err) @@ -81,14 +82,14 @@ func TestLink10K(t *testing.T) { lockables := make([]*Node, 10) for i, _ := range(lockables) { lockables[i] = NewLockable() - LinkRequirement(ctx, l0.ID, lockables[i].ID) + LinkRequirement(ctx, l0, lockables[i].ID) } ctx.Log.Logf("test", "CREATED_10K") for range(lockables) { - _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { + _, err := WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { return sig.Str == "dep_done" }) fatalErr(t, err) @@ -118,22 +119,22 @@ func TestLock(t *testing.T) { var err error - err = LinkRequirement(ctx, l1.ID, l2.ID) + err = LinkRequirement(ctx, l1, l2.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l1.ID, l3.ID) + err = LinkRequirement(ctx, l1, l3.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l1.ID, l4.ID) + err = LinkRequirement(ctx, l1, l4.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l1.ID, l5.ID) + err = LinkRequirement(ctx, l1, l5.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l0.ID, l2.ID) + err = LinkRequirement(ctx, l0, l2.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l0.ID, l3.ID) + err = LinkRequirement(ctx, l0, l3.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l0.ID, l4.ID) + err = LinkRequirement(ctx, l0, l4.ID) fatalErr(t, err) - err = LinkRequirement(ctx, l0.ID, l5.ID) + err = LinkRequirement(ctx, l0, l5.ID) fatalErr(t, err) linked_as_req := func(sig *StringSignal) bool { @@ -144,35 +145,35 @@ func TestLock(t *testing.T) { return sig.Str == "locked" } - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) + _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req) fatalErr(t, err) err = LockLockable(ctx, l1) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) + _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked) fatalErr(t, err) err = UnlockLockable(ctx, l1) diff --git a/node.go b/node.go index a6e4b34..6a66c51 100644 --- a/node.go +++ b/node.go @@ -10,6 +10,7 @@ import ( "encoding/binary" "encoding/json" "sync/atomic" + "crypto" "crypto/ed25519" "crypto/sha512" "crypto/rand" @@ -94,7 +95,7 @@ type Serializable[I comparable] interface { type Extension interface { Serializable[ExtType] Field(string)interface{} - Process(ctx *Context, node *Node, message Message)[]Message + Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages } // A QueuedSignal is a Signal that has been Queued to trigger at a set time @@ -114,7 +115,7 @@ type Node struct { Policies map[PolicyType]Policy // Channel for this node to receive messages from the Context - MsgChan chan Message + MsgChan chan *Message // Size of MsgChan BufferSize uint32 // Channel for this node to process delayed signals @@ -126,16 +127,20 @@ type Node struct { NextSignal *QueuedSignal } -func (node *Node) Allows(principal_id NodeID, action Action) error { +func (node *Node) Allows(principal_id NodeID, action Action)(Messages, error) { errs := []error{} + var pends Messages = nil for _, policy := range(node.Policies) { - err := policy.Allows(principal_id, action, node) + msgs, err := policy.Allows(principal_id, action, node) if err == nil { - return nil + return nil, nil } errs = append(errs, err) + if msgs != nil { + pends = append(pends, msgs...) + } } - return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) + return pends, fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) } func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID { @@ -177,21 +182,16 @@ func runNode(ctx *Context, node *Node) { ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) } -func ReadNodeFields(ctx *Context, self *Node, princ NodeID, reqs map[ExtType][]string)map[ExtType]map[string]interface{} { +func (node *Node) ReadFields(reqs map[ExtType][]string)map[ExtType]map[string]interface{} { exts := map[ExtType]map[string]interface{}{} for ext_type, field_reqs := range(reqs) { fields := map[string]interface{}{} for _, req := range(field_reqs) { - err := self.Allows(princ, MakeAction(ReadSignalType, ext_type, req)) - if err != nil { - fields[req] = err + ext, exists := node.Extensions[ext_type] + if exists == false { + fields[req] = fmt.Errorf("%s does not have %s extension", node.ID, ext_type) } else { - ext, exists := self.Extensions[ext_type] - if exists == false { - fields[req] = fmt.Errorf("%s does not have %s extension", self.ID, ext_type) - } else { - fields[req] = ext.Field(req) - } + fields[req] = ext.Field(req) } } exts[ext_type] = fields @@ -207,16 +207,41 @@ func nodeLoop(ctx *Context, node *Node) error { } // Perform startup actions - node.Process(ctx, Message{ZeroID, &StartSignal}) + node.Process(ctx, ZeroID, &StartSignal) for true { - var msg Message + var signal Signal + var source NodeID select { - case msg = <- node.MsgChan: - ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal.Type()) + case msg := <- node.MsgChan: + ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal) + ser, err := msg.Signal.Serialize() + if err != nil { + ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", node.ID, msg.Signal) + continue + } + + sig_data := append(msg.Dest.Serialize(), msg.Source.Serialize()...) + sig_data = append(sig_data, ser...) + validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature) + if validated == false { + ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg) + continue + } + + _, err = node.Allows(KeyID(msg.Principal), msg.Signal.Permission()) + if err != nil { + ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s - %s - %e", node.ID, msg.Signal, err) + // TODO: send the msgs and set the state so that getting a response triggers a potential processing of the original signal + continue + } + + signal = msg.Signal + source = msg.Source + case <-node.TimeoutChan: - signal := node.NextSignal.Signal - msg = Message{node.ID, signal} + signal = node.NextSignal.Signal + source = node.ID t := node.NextSignal.Time i := -1 @@ -241,52 +266,28 @@ func nodeLoop(ctx *Context, node *Node) error { } } - // Unwrap Authorized Signals - if msg.Signal.Type() == AuthorizedSignalType { - sig, ok := msg.Signal.(*AuthorizedSignal) - if ok == false { - ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(msg.Signal)) - } else { - // Validate - sig_data, err := sig.Signal.Serialize() - if err != nil { - } else { - validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature) - if validated == true { - err := node.Allows(KeyID(sig.Principal), sig.Signal.Permission()) - if err != nil { - ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err) - ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.Signal.ID(), err.Error())}}) - } else { - // Unwrap the signal without changing the source - msg = Message{msg.NodeID, sig.Signal} - } - } else { - ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate") - ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), "signature validation failed")}}) - } - } - } - } - ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) // Handle special signal types - if msg.Signal.Type() == StopSignalType { - ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(msg.Signal.ID(), "stopped")}}) - node.Process(ctx, Message{node.ID, NewStatusSignal("stopped", node.ID)}) + if signal.Type() == StopSignalType { + msgs := Messages{} + msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source) + ctx.Send(msgs) + node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) break - } else if msg.Signal.Type() == ReadSignalType { - read_signal, ok := msg.Signal.(*ReadSignal) + } else if signal.Type() == ReadSignalType { + read_signal, ok := signal.(*ReadSignal) if ok == false { - ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", msg.Signal) + ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", signal) } else { - result := ReadNodeFields(ctx, node, msg.NodeID, read_signal.Extensions) - ctx.Send(node.ID, []Message{Message{msg.NodeID, NewReadResultSignal(read_signal.ID(), node.Type, result)}}) + result := node.ReadFields(read_signal.Extensions) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.Type, result), source) + ctx.Send(msgs) } } - node.Process(ctx, msg) + node.Process(ctx, source, signal) // assume that processing a signal means that this nodes state changed // TODO: remove a lot of database writes by only writing when things change, // so need to have Process return whether or not state changed @@ -304,23 +305,60 @@ func nodeLoop(ctx *Context, node *Node) error { } type Message struct { - NodeID - Signal + Source NodeID + Dest NodeID + Principal ed25519.PublicKey + Signal Signal + Signature []byte +} + +type Messages []*Message +func (msgs Messages) Add(log Logger, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages { + msg, err := NewMessage(dest, source, principal, signal) + if err != nil { + log.Logf("signal", "MESSAGE_CREATE_ERR: %s", err) + } else { + msgs = append(msgs, msg) + } + return msgs } -func (node *Node) Process(ctx *Context, message Message) error { - ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, message.Signal.Type()) - messages := []Message{} +func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) { + ser, err := signal.Serialize() + if err != nil { + return nil, err + } + + sig_data := append(dest.Serialize(), source.Serialize()...) + sig_data = append(sig_data, ser...) + + sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) + if err != nil { + return nil, err + } + + return &Message{ + Dest: dest, + Source: source, + Principal: principal.Public().(ed25519.PublicKey), + Signal: signal, + Signature: sig, + }, nil +} + +func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error { + ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal.Type()) + messages := Messages{} for ext_type, ext := range(node.Extensions) { ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) //TODO: add extension and node info to log - resp := ext.Process(ctx, node, message) + resp := ext.Process(ctx, node, source, signal) if resp != nil { messages = append(messages, resp...) } } - return ctx.Send(node.ID, messages) + return ctx.Send(messages) } func GetCtx[T Extension, C any](ctx *Context) (C, error) { @@ -378,6 +416,7 @@ func (node *Node) Serialize() ([]byte, error) { NumQueuedSignals: uint32(len(node.SignalQueue)), }, Extensions: extensions, + Policies: policies, QueuedSignals: qsignals, KeyBytes: key_bytes, } @@ -489,7 +528,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si Type: node_type, Extensions: ext_map, Policies: policies, - MsgChan: make(chan Message, buffer_size), + MsgChan: make(chan *Message, buffer_size), BufferSize: buffer_size, SignalQueue: []QueuedSignal{}, } @@ -499,7 +538,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si panic(err) } - node.Process(ctx, Message{node.ID, &NewSignal}) + node.Process(ctx, ZeroID, &NewSignal) go runNode(ctx, node) @@ -833,7 +872,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { Type: node_type.Type, Extensions: map[ExtType]Extension{}, Policies: policies, - MsgChan: make(chan Message, node_db.Header.BufferSize), + MsgChan: make(chan *Message, node_db.Header.BufferSize), BufferSize: node_db.Header.BufferSize, TimeoutChan: timeout_chan, SignalQueue: signal_queue, diff --git a/node_test.go b/node_test.go index 44e524b..d19558b 100644 --- a/node_test.go +++ b/node_test.go @@ -45,9 +45,12 @@ func TestNodeRead(t *testing.T) { read_sig := NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, }) - ctx.Send(n2.ID, []Message{{n1.ID, &read_sig}}) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, n2.ID, n2.Key, read_sig, n1.ID) + err = ctx.Send(msgs) + fatalErr(t, err) - res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { + res, err := WaitForSignal(ctx, n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { return true }) fatalErr(t, err) @@ -80,10 +83,12 @@ func TestECDH(t *testing.T) { } fatalErr(t, err) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) - err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_req}}) + msgs := Messages{} + msgs = msgs.Add(ctx.Log, n1.ID, n1.Key, ecdh_req, n2.ID) + err = ctx.Send(msgs) fatalErr(t, err) - _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { + _, err = WaitForSignal(ctx, n1_listener.Chan, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { return sig.Str == "resp" }) fatalErr(t, err) @@ -92,6 +97,8 @@ func TestECDH(t *testing.T) { ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret) fatalErr(t, err) - err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_sig}}) + msgs = Messages{} + msgs = msgs.Add(ctx.Log, n1.ID, n1.Key, ecdh_sig, n2.ID) + err = ctx.Send(msgs) fatalErr(t, err) } diff --git a/policy.go b/policy.go index 2245176..045acd3 100644 --- a/policy.go +++ b/policy.go @@ -14,40 +14,40 @@ const ( type Policy interface { Serializable[PolicyType] - Allows(principal_id NodeID, action Action, node *Node) error + Allows(principal_id NodeID, action Action, node *Node)(Messages, error) // Merge with another policy of the same underlying type Merge(Policy) Policy // Make a copy of this policy Copy() Policy } -func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node) error { - return policy.Actions.Allows(action) +func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) { + return nil, policy.Actions.Allows(action) } -func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Node) error { +func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) { for id, actions := range(policy.NodeActions) { if id != principal_id { continue } - return actions.Allows(action) + return nil, actions.Allows(action) } - return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) + return nil, fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) } -func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { +func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) { lockable_ext, err := GetExt[*LockableExt](node) if err != nil { - return err + return nil, err } for id, _ := range(lockable_ext.Requirements) { if id == principal_id { - return policy.Actions.Allows(action) + return nil, policy.Actions.Allows(action) } } - return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) + return nil, fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) } type UserOfPolicy struct { @@ -65,11 +65,11 @@ func NewUserOfPolicy(group_actions NodeActions) UserOfPolicy { } // Send a read signal to Group to check if principal_id is a member of it -func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { +func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) (Messages, error) { // Send a read signal to each of the groups in the map // Check for principal_id in any of the returned member lists(skipping errors) // Return an error in the default case - return fmt.Errorf("NOT_IMPLEMENTED") + return nil, fmt.Errorf("NOT_IMPLEMENTED") } func (policy *UserOfPolicy) Merge(p Policy) Policy { diff --git a/signal.go b/signal.go index b950c79..71d28ec 100644 --- a/signal.go +++ b/signal.go @@ -24,7 +24,6 @@ const ( LinkSignalType = "LINK" LockSignalType = "LOCK" ReadSignalType = "READ" - AuthorizedSignalType = "AUTHORIZED" ReadResultSignalType = "READ_RESULT" LinkStartSignalType = "LINK_START" ECDHSignalType = "ECDH" @@ -48,21 +47,7 @@ type Signal interface { Permission() Action } -func WaitForResult(listener chan Signal, timeout time.Duration, id uuid.UUID) (Signal, error) { - timeout_channel := time.After(timeout) - select { - case result:=<-listener: - if result.ID() == id { - return result, nil - } else { - return result, fmt.Errorf("WRONG_ID: %s", result.ID()) - } - case <-timeout_channel: - return nil, fmt.Errorf("timeout waiting for read response to %s", id) - } -} - -func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { +func WaitForSignal[S Signal](ctx * Context, listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) { var zero S var timeout_channel <- chan time.Time if timeout > 0 { @@ -70,12 +55,12 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time. } for true { select { - case msg := <- listener.Chan: - if msg.Signal == nil { + case signal := <- listener: + if signal == nil { return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) } - if msg.Signal.Type() == signal_type { - sig, ok := msg.Signal.(S) + if signal.Type() == signal_type { + sig, ok := signal.(S) if ok == true { if check(sig) == true { return sig, nil @@ -89,7 +74,6 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time. return zero, fmt.Errorf("LOOP_ENDED") } - type BaseSignal struct { SignalDirection SignalDirection `json:"direction"` SignalType SignalType `json:"type"` @@ -244,42 +228,12 @@ type ReadSignal struct { Extensions map[ExtType][]string `json:"extensions"` } -type AuthorizedSignal struct { - BaseSignal - Principal ed25519.PublicKey - Signal Signal - Signature []byte -} - -func (signal *AuthorizedSignal) Permission() Action { - return AuthorizedSignalAction -} - -func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (Signal, error) { - sig_data, err := signal.Serialize() - if err != nil { - return nil, err - } - - sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) - if err != nil { - return nil, err - } - - return &AuthorizedSignal{ - NewBaseSignal(AuthorizedSignalType, Direct), - principal.Public().(ed25519.PublicKey), - signal, - sig, - }, nil -} - func (signal *ReadSignal) Serialize() ([]byte, error) { return json.Marshal(signal) } -func NewReadSignal(exts map[ExtType][]string) ReadSignal { - return ReadSignal{ +func NewReadSignal(exts map[ExtType][]string) *ReadSignal { + return &ReadSignal{ NewBaseSignal(ReadSignalType, Direct), exts, } diff --git a/user.go b/user.go index 94be32f..d14eb8c 100644 --- a/user.go +++ b/user.go @@ -48,7 +48,7 @@ func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error { return err } -func (ext *GroupExt) Process(ctx *Context, node *Node, msg Message) []Message { +func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages { return nil }