Added signature to all signals(signature of serialized signal + source + dest so technically vulnerable to replay) to use for ACL

gql_cataclysm
noah metz 2023-08-08 14:00:17 -06:00
parent 96e842decf
commit f31beade29
12 changed files with 303 additions and 311 deletions

@ -235,25 +235,19 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
return target, nil return target, nil
} }
// Stop every running loop
func (ctx *Context) Stop() {
for _, node := range(ctx.Nodes) {
node.MsgChan <- Message{ZeroID, &StopSignal}
}
}
// Route a Signal to dest. Currently only local context routing is supported // Route a Signal to dest. Currently only local context routing is supported
func (ctx *Context) Send(source NodeID, messages []Message) error { func (ctx *Context) Send(messages Messages) error {
for _, msg := range(messages) { for _, msg := range(messages) {
target, err := ctx.GetNode(msg.NodeID) target, err := ctx.GetNode(msg.Dest)
if err == nil { if err == nil {
select { select {
case target.MsgChan <- Message{source, msg.Signal}: case target.MsgChan <- msg:
ctx.Log.Logf("signal", "Sent %s -> %+v", target.ID, msg)
default: default:
buf := make([]byte, 4096) buf := make([]byte, 4096)
n := runtime.Stack(buf, false) n := runtime.Stack(buf, false)
stack_str := string(buf[:n]) stack_str := string(buf[:n])
return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.NodeID, stack_str) return fmt.Errorf("SIGNAL_OVERFLOW: %s - %s", msg.Dest, stack_str)
} }
} else if errors.Is(err, NodeNotFoundError) { } else if errors.Is(err, NodeNotFoundError) {
// TODO: Handle finding nodes in other contexts // TODO: Handle finding nodes in other contexts

@ -103,10 +103,10 @@ func (ext *ECDHExt) Field(name string) interface{} {
}) })
} }
func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) []Message { func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal) Messages {
source := KeyID(signal.EDDSA) source := KeyID(signal.EDDSA)
messages := []Message{} messages := Messages{}
switch signal.Str { switch signal.Str {
case "req": case "req":
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
@ -118,15 +118,15 @@ func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal)
state.SharedSecret = shared_secret state.SharedSecret = shared_secret
ext.ECDHStates[source] = state ext.ECDHStates[source] = state
log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
messages = append(messages, Message{source, &resp}) messages = messages.Add(log, node.ID, node.Key, &resp, source)
} else { } else {
log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), err.Error())}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), err.Error()), source)
} }
case "resp": case "resp":
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil { if exists == false || state.ECKey == nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "no_req")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "no_req"), source)
} else { } else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil { if err == nil {
@ -145,10 +145,10 @@ func (ext *ECDHExt) HandleECDHSignal(log Logger, node *Node, signal *ECDHSignal)
return messages return messages
} }
func (ext *ECDHExt) Process(ctx *Context, node *Node, msg Message) []Message { func (ext *ECDHExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
switch msg.Signal.Type() { switch signal.Type() {
case ECDHSignalType: case ECDHSignalType:
sig := msg.Signal.(*ECDHSignal) sig := signal.(*ECDHSignal)
return ext.HandleECDHSignal(ctx.Log, node, sig) return ext.HandleECDHSignal(ctx.Log, node, sig)
} }
return nil return nil

@ -921,31 +921,7 @@ func NewGQLExtContext() *GQLExtContext {
context.Mutation.AddFieldConfig("stop", &graphql.Field{ context.Mutation.AddFieldConfig("stop", &graphql.Field{
Type: graphql.String, Type: graphql.String,
Resolve: func(p graphql.ResolveParams) (interface{}, error) { Resolve: func(p graphql.ResolveParams) (interface{}, error) {
ctx, err := PrepResolve(p) return nil, fmt.Errorf("NOT_IMPLEMENTED")
if err != nil {
return nil, err
}
sig, err := NewAuthorizedSignal(ctx.Key, &StopSignal)
if err != nil {
return nil, err
}
response_chan := ctx.Ext.GetResponseChannel(sig.ID())
err = ctx.Context.Send(ctx.Server.ID, []Message{Message{ctx.Server.ID, sig}})
if err != nil {
ctx.Ext.FreeResponseChannel(sig.ID())
return nil, err
}
resp, err := WaitForResult(response_chan, 100*time.Millisecond, sig.ID())
if err != nil {
return nil, err
}
ser, err := resp.Serialize()
return string(ser), err
}, },
}) })
@ -1053,10 +1029,9 @@ func (ext *GQLExt) FreeResponseChannel(req_id uuid.UUID) chan Signal {
} }
} }
func (ext *GQLExt) Process(ctx *Context, node *Node, msg Message) []Message { func (ext *GQLExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
// Process ReadResultSignalType by forwarding it to the waiting resolver // Process ReadResultSignalType by forwarding it to the waiting resolver
signal := msg.Signal messages := Messages{}
messages := []Message{}
if signal.Type() == ErrorSignalType { if signal.Type() == ErrorSignalType {
// TODO: Forward to resolver if waiting for it // TODO: Forward to resolver if waiting for it
sig := signal.(*ErrorSignal) sig := signal.(*ErrorSignal)

@ -2,7 +2,6 @@ package graphvent
import ( import (
"time" "time"
"reflect" "reflect"
"fmt"
"github.com/graphql-go/graphql" "github.com/graphql-go/graphql"
"github.com/graphql-go/graphql/language/ast" "github.com/graphql-go/graphql/language/ast"
"github.com/google/uuid" "github.com/google/uuid"
@ -51,17 +50,15 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
} }
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields) read_signal := NewReadSignal(ext_fields)
auth_signal, err := NewAuthorizedSignal(ctx.Key, &read_signal) msgs := Messages{}
if err != nil { msgs = msgs.Add(ctx.Context.Log, ctx.Server.ID, ctx.Key, read_signal, id)
return nil, err
}
response_chan := ctx.Ext.GetResponseChannel(read_signal.ID()) response_chan := ctx.Ext.GetResponseChannel(read_signal.ID())
resp_channels[read_signal.ID()] = response_chan resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, []Message{Message{id, auth_signal}}) // TODO: Send all at once instead of createing Messages for each
err = ctx.Context.Send(msgs)
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID()) ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err return nil, err
@ -71,18 +68,13 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
responses := []NodeResult{} responses := []NodeResult{}
for sig_id, response_chan := range(resp_channels) { for sig_id, response_chan := range(resp_channels) {
// Wait for the response, returning an error on timeout // Wait for the response, returning an error on timeout
response, err := WaitForResult(response_chan, time.Millisecond*100, sig_id) response, err := WaitForSignal(ctx.Context, response_chan, time.Millisecond*100, ReadResultSignalType, func(sig *ReadResultSignal)bool{
return sig.ReqID == sig_id
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch resp := response.(type) { responses = append(responses, NodeResult{node_ids[sig_id], response})
case *ReadResultSignal:
responses = append(responses, NodeResult{node_ids[sig_id], resp})
case *ErrorSignal:
return nil, fmt.Errorf(resp.Error)
default:
return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp))
}
} }
ctx.Context.Log.Logf("gql", "RESOLVED_NODES") ctx.Context.Log.Logf("gql", "RESOLVED_NODES")

@ -14,25 +14,42 @@ import (
) )
func TestGQLServer(t *testing.T) { func TestGQLServer(t *testing.T) {
ctx := logTestContext(t, []string{"gql", "lockable", "signal"}) ctx := logTestContext(t, []string{"test"})
TestNodeType := NodeType("TEST") TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType}) err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType})
fatalErr(t, err) fatalErr(t, err)
policy := NewAllNodesPolicy(Actions{
MakeAction(LinkSignalType, "+"),
MakeAction(LinkStartSignalType, "+"),
MakeAction(LockSignalType, "+"),
MakeAction(StatusSignalType, "+"),
MakeAction(ErrorSignalType, "+"),
MakeAction(ReadSignalType, "+"),
MakeAction(ReadResultSignalType, "+"),
MakeAction(GQLStateSignalType, "+"),
})
gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped") gql_ext, err := NewGQLExt(ctx, ":0", nil, nil, "stopped")
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
gql := NewNode(ctx, nil, GQLNodeType, 10, nil, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext) gql := NewNode(ctx, nil, GQLNodeType, 10, map[PolicyType]Policy{
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt()) AllNodesPolicyType: &policy,
}, NewLockableExt(), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, map[PolicyType]Policy{
AllNodesPolicyType: &policy,
}, NewLockableExt())
err = LinkRequirement(ctx, gql.ID, n1.ID) err = LinkRequirement(ctx, gql, n1.ID)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(gql.ID, []Message{{gql.ID, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "start_server"}, gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_started" return sig.Str == "server_started"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -84,9 +101,11 @@ func TestGQLServer(t *testing.T) {
resp_2 := SendGQL(req_2) resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2)
stop_signal := StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"} msgs = Messages{}
ctx.Send(gql.ID, []Message{{gql.ID, &stop_signal}}) msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StringSignal{NewBaseSignal(GQLStateSignalType, Direct), "stop_server"}, gql.ID)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "server_stopped" return sig.Str == "server_stopped"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -111,8 +130,11 @@ func TestGQLDB(t *testing.T) {
NewGroupExt(nil)) NewGroupExt(nil))
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
ctx.Stop() msgs := Messages{}
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { msgs = msgs.Add(ctx.Log, gql.ID, gql.Key, &StopSignal, gql.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID return sig.Str == "stopped" && sig.NodeID == gql.ID
}) })
fatalErr(t, err) fatalErr(t, err)
@ -130,9 +152,11 @@ func TestGQLDB(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded) listener_ext, err = GetExt[*ListenerExt](gql_loaded)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(gql_loaded.ID, []Message{{gql_loaded.ID, &StopSignal}}) msgs = Messages{}
msgs = msgs.Add(ctx.Log, gql_loaded.ID, gql_loaded.Key, &StopSignal, gql_loaded.ID)
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, listener_ext.Chan, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID
}) })
fatalErr(t, err) fatalErr(t, err)

@ -7,14 +7,14 @@ import (
// A Listener extension provides a channel that can receive signals on a different thread // A Listener extension provides a channel that can receive signals on a different thread
type ListenerExt struct { type ListenerExt struct {
Buffer int Buffer int
Chan chan Message Chan chan Signal
} }
// Create a new listener extension with a given buffer size // Create a new listener extension with a given buffer size
func NewListenerExt(buffer int) *ListenerExt { func NewListenerExt(buffer int) *ListenerExt {
return &ListenerExt{ return &ListenerExt{
Buffer: buffer, Buffer: buffer,
Chan: make(chan Message, buffer), Chan: make(chan Signal, buffer),
} }
} }
@ -32,7 +32,7 @@ func (ext *ListenerExt) Field(name string) interface{} {
// Simple load function, unmarshal the buffer int from json // Simple load function, unmarshal the buffer int from json
func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error {
err := json.Unmarshal(data, &ext.Buffer) err := json.Unmarshal(data, &ext.Buffer)
ext.Chan = make(chan Message, ext.Buffer) ext.Chan = make(chan Signal, ext.Buffer)
return err return err
} }
@ -41,10 +41,10 @@ func (listener *ListenerExt) Type() ExtType {
} }
// Send the signal to the channel, logging an overflow if it occurs // Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, node *Node, msg Message) []Message { func (ext *ListenerExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, msg.Signal) ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal)
select { select {
case ext.Chan <- msg: case ext.Chan <- signal:
default: default:
ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID) ctx.Log.Logf("listener", "LISTENER_OVERFLOW: %s", node.ID)
} }
@ -138,38 +138,44 @@ func NewLockableExt() *LockableExt {
// Send the signal to unlock a node from itself // Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error { func UnlockLockable(ctx *Context, node *Node) error {
return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("unlock")}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewLockSignal("unlock"), node.ID)
return ctx.Send(msgs)
} }
// Send the signal to lock a node from itself // Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error { func LockLockable(ctx *Context, node *Node) error {
return ctx.Send(node.ID, []Message{Message{node.ID, NewLockSignal("lock")}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewLockSignal("lock"), node.ID)
return ctx.Send(msgs)
} }
// Setup a node to send the initial requirement link signal, then send the signal // Setup a node to send the initial requirement link signal, then send the signal
func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error { func LinkRequirement(ctx *Context, dependency *Node, requirement NodeID) error {
return ctx.Send(dependency, []Message{Message{dependency, NewLinkStartSignal("req", requirement)}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, dependency.ID, dependency.Key, NewLinkStartSignal("req", requirement), dependency.ID)
return ctx.Send(msgs)
} }
// Handle a LockSignal and update the extensions owner/requirement states // Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message { func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages {
state := signal.Str state := signal.Str
log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
messages := []Message{} messages := Messages{}
switch state { switch state {
case "unlock": case "unlock":
if ext.Owner == nil { if ext.Owner == nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocked"), source)
} else if source != *ext.Owner { } else if source != *ext.Owner {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_owner")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_owner"), source)
} else if ext.PendingOwner == nil { } else if ext.PendingOwner == nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_unlocking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_unlocking"), source)
} else { } else {
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = nil ext.Owner = nil
ext.PendingOwner = nil ext.PendingOwner = nil
messages = append(messages, Message{source, NewLockSignal("unlocked")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocked"), source)
} else { } else {
ext.PendingOwner = nil ext.PendingOwner = nil
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
@ -179,22 +185,22 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
} }
state.Lock = "unlocking" state.Lock = "unlocking"
ext.Requirements[id] = state ext.Requirements[id] = state
messages = append(messages, Message{id, NewLockSignal("unlock")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlock"), id)
} }
} }
if source != node.ID { if source != node.ID {
messages = append(messages, Message{source, NewLockSignal("unlocking")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocking"), source)
} }
} }
} }
case "unlocking": case "unlocking":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state.Link != "linked" { } else if state.Link != "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source)
} else if state.Lock != "unlocking" { } else if state.Lock != "unlocking" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source)
} }
case "unlocked": case "unlocked":
@ -204,11 +210,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state.Link != "linked" { } else if state.Link != "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source)
} else if state.Lock != "unlocking" { } else if state.Lock != "unlocking" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_unlocking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_unlocking"), source)
} else { } else {
state.Lock = "unlocked" state.Lock = "unlocked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -228,7 +234,7 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
if linked == unlocked { if linked == unlocked {
previous_owner := *ext.Owner previous_owner := *ext.Owner
ext.Owner = nil ext.Owner = nil
messages = append(messages, Message{previous_owner, NewLockSignal("unlocked")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("unlocked"), previous_owner)
} }
} }
} }
@ -239,11 +245,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state.Link != "linked" { } else if state.Link != "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source)
} else if state.Lock != "locking" { } else if state.Lock != "locking" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source)
} else { } else {
state.Lock = "locked" state.Lock = "locked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -262,31 +268,31 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
if linked == locked { if linked == locked {
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
messages = append(messages, Message{*ext.Owner, NewLockSignal("locked")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locked"), *ext.Owner)
} }
} }
} }
case "locking": case "locking":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_requirement")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_requirement"), source)
} else if state.Link != "linked" { } else if state.Link != "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linked"), source)
} else if state.Lock != "locking" { } else if state.Lock != "locking" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_locking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_locking"), source)
} }
case "lock": case "lock":
if ext.Owner != nil { if ext.Owner != nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source)
} else if ext.PendingOwner != nil { } else if ext.PendingOwner != nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_locking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locking"), source)
} else { } else {
owner := source owner := source
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = &owner ext.Owner = &owner
ext.PendingOwner = ext.Owner ext.PendingOwner = ext.Owner
messages = append(messages, Message{source, NewLockSignal("locked")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locked"), source)
} else { } else {
ext.PendingOwner = &owner ext.PendingOwner = &owner
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
@ -297,11 +303,11 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
} }
state.Lock = "locking" state.Lock = "locking"
ext.Requirements[id] = state ext.Requirements[id] = state
messages = append(messages, Message{id, NewLockSignal("lock")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("lock"), id)
} }
} }
if source != node.ID { if source != node.ID {
messages = append(messages, Message{source, NewLockSignal("locking")}) messages = messages.Add(log, node.ID, node.Key, NewLockSignal("locking"), source)
} }
} }
} }
@ -312,36 +318,36 @@ func (ext *LockableExt) HandleLockSignal(log Logger, node *Node, source NodeID,
return messages return messages
} }
func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) []Message { func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source NodeID, signal *IDStringSignal) Messages {
link_type := signal.Str link_type := signal.Str
target := signal.NodeID target := signal.NodeID
log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target) log.Logf("lockable", "LINK_START_SIGNAL: %s->%s %s %s", source, node.ID, link_type, target)
messages := []Message{} messages := Messages{}
switch link_type { switch link_type {
case "req": case "req":
state, exists := ext.Requirements[target] state, exists := ext.Requirements[target]
_, dep_exists := ext.Dependencies[target] _, dep_exists := ext.Dependencies[target]
if ext.Owner != nil { if ext.Owner != nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already locked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_locked"), source)
} else if ext.Owner != ext.PendingOwner { } else if ext.Owner != ext.PendingOwner {
if ext.PendingOwner == nil { if ext.PendingOwner == nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source)
} else { } else {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source)
} }
} else if exists == true { } else if exists == true {
if state.Link == "linking" { if state.Link == "linking" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking_req")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking_req"), source)
} else if state.Link == "linked" { } else if state.Link == "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_req")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_req"), source)
} }
} else if dep_exists == true { } else if dep_exists == true {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_dep")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_dep"), source)
} else { } else {
ext.Requirements[target] = LinkState{"linking", "unlocked", source} ext.Requirements[target] = LinkState{"linking", "unlocked", source}
messages = append(messages, Message{target, NewLinkSignal("linked_as_req")}) messages = messages.Add(log, node.ID, node.Key, NewLinkSignal("linked_as_req"), target)
messages = append(messages, Message{source, NewLinkStartSignal("linking_req", target)}) messages = messages.Add(log, node.ID, node.Key, NewLinkStartSignal("linking_req", target), source)
} }
} }
return messages return messages
@ -349,16 +355,16 @@ func (ext *LockableExt) HandleLinkStartSignal(log Logger, node *Node, source Nod
// Handle LinkSignal, updating the extensions requirements and dependencies as necessary // Handle LinkSignal, updating the extensions requirements and dependencies as necessary
// TODO: Add unlink // TODO: Add unlink
func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) []Message { func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID, signal *StringSignal) Messages {
log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str state := signal.Str
messages := []Message{} messages := Messages{}
switch state { switch state {
case "dep_done": case "dep_done":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "not_linking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "not_linking"), source)
} else if state.Link == "linking" { } else if state.Link == "linking" {
state.Link = "linked" state.Link = "linked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -368,16 +374,16 @@ func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID,
state, exists := ext.Dependencies[source] state, exists := ext.Dependencies[source]
if exists == false { if exists == false {
ext.Dependencies[source] = LinkState{"linked", "unlocked", source} ext.Dependencies[source] = LinkState{"linked", "unlocked", source}
messages = append(messages, Message{source, NewLinkSignal("dep_done")}) messages = messages.Add(log, node.ID, node.Key, NewLinkSignal("dep_done"), source)
} else if state.Link == "linking" { } else if state.Link == "linking" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linking"), source)
} else if state.Link == "linked" { } else if state.Link == "linked" {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "already_linked")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "already_linked"), source)
} else if ext.PendingOwner != ext.Owner { } else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil { if ext.Owner == nil {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "locking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "locking"), source)
} else { } else {
messages = append(messages, Message{source, NewErrorSignal(signal.ID(), "unlocking")}) messages = messages.Add(log, node.ID, node.Key, NewErrorSignal(signal.ID(), "unlocking"), source)
} }
} }
@ -389,15 +395,15 @@ func (ext *LockableExt) HandleLinkSignal(log Logger, node *Node, source NodeID,
// LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes // LockableExts process Up/Down signals by forwarding them to owner, dependency, and requirement nodes
// LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state // LockSignal and LinkSignal Direct signals are processed to update the requirement/dependency/lock state
func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message { func (ext *LockableExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
messages := []Message{} messages := Messages{}
switch msg.Signal.Direction() { switch signal.Direction() {
case Up: case Up:
ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies) ctx.Log.Logf("lockable", "LOCKABLE_DEPENDENCIES: %+v", ext.Dependencies)
owner_sent := false owner_sent := false
for dependency, state := range(ext.Dependencies) { for dependency, state := range(ext.Dependencies) {
if state.Link == "linked" { if state.Link == "linked" {
messages = append(messages, Message{dependency, msg.Signal}) messages = messages.Add(ctx.Log, node.ID, node.Key, signal, dependency)
if ext.Owner != nil { if ext.Owner != nil {
if dependency == *ext.Owner { if dependency == *ext.Owner {
owner_sent = true owner_sent = true
@ -408,23 +414,23 @@ func (ext *LockableExt) Process(ctx *Context, node *Node, msg Message) []Message
if ext.Owner != nil && owner_sent == false { if ext.Owner != nil && owner_sent == false {
if *ext.Owner != node.ID { if *ext.Owner != node.ID {
messages = append(messages, Message{*ext.Owner, msg.Signal}) messages = messages.Add(ctx.Log, node.ID, node.Key, signal, *ext.Owner)
} }
} }
case Down: case Down:
for requirement, state := range(ext.Requirements) { for requirement, state := range(ext.Requirements) {
if state.Link == "linked" { if state.Link == "linked" {
messages = append(messages, Message{requirement, msg.Signal}) messages = messages.Add(ctx.Log, node.ID, node.Key, signal, requirement)
} }
} }
case Direct: case Direct:
switch msg.Signal.Type() { switch signal.Type() {
case LinkSignalType: case LinkSignalType:
messages = ext.HandleLinkSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal)) messages = ext.HandleLinkSignal(ctx.Log, node, source, signal.(*StringSignal))
case LockSignalType: case LockSignalType:
messages = ext.HandleLockSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*StringSignal)) messages = ext.HandleLockSignal(ctx.Log, node, source, signal.(*StringSignal))
case LinkStartSignalType: case LinkStartSignalType:
messages = ext.HandleLinkStartSignal(ctx.Log, node, msg.NodeID, msg.Signal.(*IDStringSignal)) messages = ext.HandleLinkStartSignal(ctx.Log, node, source, signal.(*IDStringSignal))
default: default:
} }
default: default:

@ -35,24 +35,25 @@ func TestLink(t *testing.T) {
) )
// Link l2 as a requirement of l1 // Link l2 as a requirement of l1
err := LinkRequirement(ctx, l1.ID, l2.ID) err := LinkRequirement(ctx, l1, l2.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool {
return sig.Str == "dep_done" return sig.Str == "dep_done"
}) })
fatalErr(t, err) fatalErr(t, err)
sig1 := NewStatusSignal("TEST", l2.ID) msgs := Messages{}
err = ctx.Send(l2.ID, []Message{{l2.ID, sig1}}) msgs = msgs.Add(ctx.Log, l2.ID, l2.Key, NewStatusSignal("TEST", l2.ID), l2.ID)
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "TEST" return sig.Str == "TEST"
}) })
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { _, err = WaitForSignal(ctx, l2_listener.Chan, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "TEST" return sig.Str == "TEST"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -81,14 +82,14 @@ func TestLink10K(t *testing.T) {
lockables := make([]*Node, 10) lockables := make([]*Node, 10)
for i, _ := range(lockables) { for i, _ := range(lockables) {
lockables[i] = NewLockable() lockables[i] = NewLockable()
LinkRequirement(ctx, l0.ID, lockables[i].ID) LinkRequirement(ctx, l0, lockables[i].ID)
} }
ctx.Log.Logf("test", "CREATED_10K") ctx.Log.Logf("test", "CREATED_10K")
for range(lockables) { for range(lockables) {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool { _, err := WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, func(sig *StringSignal) bool {
return sig.Str == "dep_done" return sig.Str == "dep_done"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -118,22 +119,22 @@ func TestLock(t *testing.T) {
var err error var err error
err = LinkRequirement(ctx, l1.ID, l2.ID) err = LinkRequirement(ctx, l1, l2.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l1.ID, l3.ID) err = LinkRequirement(ctx, l1, l3.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l1.ID, l4.ID) err = LinkRequirement(ctx, l1, l4.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l1.ID, l5.ID) err = LinkRequirement(ctx, l1, l5.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l0.ID, l2.ID) err = LinkRequirement(ctx, l0, l2.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l0.ID, l3.ID) err = LinkRequirement(ctx, l0, l3.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l0.ID, l4.ID) err = LinkRequirement(ctx, l0, l4.ID)
fatalErr(t, err) fatalErr(t, err)
err = LinkRequirement(ctx, l0.ID, l5.ID) err = LinkRequirement(ctx, l0, l5.ID)
fatalErr(t, err) fatalErr(t, err)
linked_as_req := func(sig *StringSignal) bool { linked_as_req := func(sig *StringSignal) bool {
@ -144,35 +145,35 @@ func TestLock(t *testing.T) {
return sig.Str == "locked" return sig.Str == "locked"
} }
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkSignalType, linked_as_req) _, err = WaitForSignal(ctx, l0_listener.Chan, time.Millisecond*10, LinkSignalType, linked_as_req)
fatalErr(t, err) fatalErr(t, err)
err = LockLockable(ctx, l1) err = LockLockable(ctx, l1)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LockSignalType, locked) _, err = WaitForSignal(ctx, l1_listener.Chan, time.Millisecond*10, LockSignalType, locked)
fatalErr(t, err) fatalErr(t, err)
err = UnlockLockable(ctx, l1) err = UnlockLockable(ctx, l1)

@ -10,6 +10,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"sync/atomic" "sync/atomic"
"crypto"
"crypto/ed25519" "crypto/ed25519"
"crypto/sha512" "crypto/sha512"
"crypto/rand" "crypto/rand"
@ -94,7 +95,7 @@ type Serializable[I comparable] interface {
type Extension interface { type Extension interface {
Serializable[ExtType] Serializable[ExtType]
Field(string)interface{} Field(string)interface{}
Process(ctx *Context, node *Node, message Message)[]Message Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages
} }
// A QueuedSignal is a Signal that has been Queued to trigger at a set time // A QueuedSignal is a Signal that has been Queued to trigger at a set time
@ -114,7 +115,7 @@ type Node struct {
Policies map[PolicyType]Policy Policies map[PolicyType]Policy
// Channel for this node to receive messages from the Context // Channel for this node to receive messages from the Context
MsgChan chan Message MsgChan chan *Message
// Size of MsgChan // Size of MsgChan
BufferSize uint32 BufferSize uint32
// Channel for this node to process delayed signals // Channel for this node to process delayed signals
@ -126,16 +127,20 @@ type Node struct {
NextSignal *QueuedSignal NextSignal *QueuedSignal
} }
func (node *Node) Allows(principal_id NodeID, action Action) error { func (node *Node) Allows(principal_id NodeID, action Action)(Messages, error) {
errs := []error{} errs := []error{}
var pends Messages = nil
for _, policy := range(node.Policies) { for _, policy := range(node.Policies) {
err := policy.Allows(principal_id, action, node) msgs, err := policy.Allows(principal_id, action, node)
if err == nil { if err == nil {
return nil return nil, nil
} }
errs = append(errs, err) errs = append(errs, err)
if msgs != nil {
pends = append(pends, msgs...)
}
} }
return fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs) return pends, fmt.Errorf("POLICY_CHECK_ERRORS: %s %s.%s - %+v", principal_id, node.ID, action, errs)
} }
func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID { func (node *Node) QueueSignal(time time.Time, signal Signal) uuid.UUID {
@ -177,23 +182,18 @@ func runNode(ctx *Context, node *Node) {
ctx.Log.Logf("node", "RUN_STOP: %s", node.ID) ctx.Log.Logf("node", "RUN_STOP: %s", node.ID)
} }
func ReadNodeFields(ctx *Context, self *Node, princ NodeID, reqs map[ExtType][]string)map[ExtType]map[string]interface{} { func (node *Node) ReadFields(reqs map[ExtType][]string)map[ExtType]map[string]interface{} {
exts := map[ExtType]map[string]interface{}{} exts := map[ExtType]map[string]interface{}{}
for ext_type, field_reqs := range(reqs) { for ext_type, field_reqs := range(reqs) {
fields := map[string]interface{}{} fields := map[string]interface{}{}
for _, req := range(field_reqs) { for _, req := range(field_reqs) {
err := self.Allows(princ, MakeAction(ReadSignalType, ext_type, req)) ext, exists := node.Extensions[ext_type]
if err != nil {
fields[req] = err
} else {
ext, exists := self.Extensions[ext_type]
if exists == false { if exists == false {
fields[req] = fmt.Errorf("%s does not have %s extension", self.ID, ext_type) fields[req] = fmt.Errorf("%s does not have %s extension", node.ID, ext_type)
} else { } else {
fields[req] = ext.Field(req) fields[req] = ext.Field(req)
} }
} }
}
exts[ext_type] = fields exts[ext_type] = fields
} }
return exts return exts
@ -207,16 +207,41 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
// Perform startup actions // Perform startup actions
node.Process(ctx, Message{ZeroID, &StartSignal}) node.Process(ctx, ZeroID, &StartSignal)
for true { for true {
var msg Message var signal Signal
var source NodeID
select { select {
case msg = <- node.MsgChan: case msg := <- node.MsgChan:
ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal.Type()) ctx.Log.Logf("node_msg", "NODE_MSG: %s - %+v", node.ID, msg.Signal)
ser, err := msg.Signal.Serialize()
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_SERIALIZE_ERR: %s - %+v", node.ID, msg.Signal)
continue
}
sig_data := append(msg.Dest.Serialize(), msg.Source.Serialize()...)
sig_data = append(sig_data, ser...)
validated := ed25519.Verify(msg.Principal, sig_data, msg.Signature)
if validated == false {
ctx.Log.Logf("signal", "SIGNAL_VERIFY_ERR: %s - %+v", node.ID, msg)
continue
}
_, err = node.Allows(KeyID(msg.Principal), msg.Signal.Permission())
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s - %s - %e", node.ID, msg.Signal, err)
// TODO: send the msgs and set the state so that getting a response triggers a potential processing of the original signal
continue
}
signal = msg.Signal
source = msg.Source
case <-node.TimeoutChan: case <-node.TimeoutChan:
signal := node.NextSignal.Signal signal = node.NextSignal.Signal
msg = Message{node.ID, signal} source = node.ID
t := node.NextSignal.Time t := node.NextSignal.Time
i := -1 i := -1
@ -241,52 +266,28 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
} }
// Unwrap Authorized Signals
if msg.Signal.Type() == AuthorizedSignalType {
sig, ok := msg.Signal.(*AuthorizedSignal)
if ok == false {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(msg.Signal))
} else {
// Validate
sig_data, err := sig.Signal.Serialize()
if err != nil {
} else {
validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature)
if validated == true {
err := node.Allows(KeyID(sig.Principal), sig.Signal.Permission())
if err != nil {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err)
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.Signal.ID(), err.Error())}})
} else {
// Unwrap the signal without changing the source
msg = Message{msg.NodeID, sig.Signal}
}
} else {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate")
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(sig.ID(), "signature validation failed")}})
}
}
}
}
ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) ctx.Log.Logf("node_signal_queue", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
// Handle special signal types // Handle special signal types
if msg.Signal.Type() == StopSignalType { if signal.Type() == StopSignalType {
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewErrorSignal(msg.Signal.ID(), "stopped")}}) msgs := Messages{}
node.Process(ctx, Message{node.ID, NewStatusSignal("stopped", node.ID)}) msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewErrorSignal(signal.ID(), "stopped"), source)
ctx.Send(msgs)
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID))
break break
} else if msg.Signal.Type() == ReadSignalType { } else if signal.Type() == ReadSignalType {
read_signal, ok := msg.Signal.(*ReadSignal) read_signal, ok := signal.(*ReadSignal)
if ok == false { if ok == false {
ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", msg.Signal) ctx.Log.Logf("signal_read", "READ_SIGNAL: bad cast %+v", signal)
} else { } else {
result := ReadNodeFields(ctx, node, msg.NodeID, read_signal.Extensions) result := node.ReadFields(read_signal.Extensions)
ctx.Send(node.ID, []Message{Message{msg.NodeID, NewReadResultSignal(read_signal.ID(), node.Type, result)}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, node.ID, node.Key, NewReadResultSignal(read_signal.ID(), node.Type, result), source)
ctx.Send(msgs)
} }
} }
node.Process(ctx, msg) node.Process(ctx, source, signal)
// assume that processing a signal means that this nodes state changed // assume that processing a signal means that this nodes state changed
// TODO: remove a lot of database writes by only writing when things change, // TODO: remove a lot of database writes by only writing when things change,
// so need to have Process return whether or not state changed // so need to have Process return whether or not state changed
@ -304,23 +305,60 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
type Message struct { type Message struct {
NodeID Source NodeID
Signal Dest NodeID
Principal ed25519.PublicKey
Signal Signal
Signature []byte
}
type Messages []*Message
func (msgs Messages) Add(log Logger, source NodeID, principal ed25519.PrivateKey, signal Signal, dest NodeID) Messages {
msg, err := NewMessage(dest, source, principal, signal)
if err != nil {
log.Logf("signal", "MESSAGE_CREATE_ERR: %s", err)
} else {
msgs = append(msgs, msg)
}
return msgs
}
func NewMessage(dest NodeID, source NodeID, principal ed25519.PrivateKey, signal Signal) (*Message, error) {
ser, err := signal.Serialize()
if err != nil {
return nil, err
}
sig_data := append(dest.Serialize(), source.Serialize()...)
sig_data = append(sig_data, ser...)
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return nil, err
}
return &Message{
Dest: dest,
Source: source,
Principal: principal.Public().(ed25519.PublicKey),
Signal: signal,
Signature: sig,
}, nil
} }
func (node *Node) Process(ctx *Context, message Message) error { func (node *Node) Process(ctx *Context, source NodeID, signal Signal) error {
ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, message.Signal.Type()) ctx.Log.Logf("node_process", "PROCESSING MESSAGE: %s - %+v", node.ID, signal.Type())
messages := []Message{} messages := Messages{}
for ext_type, ext := range(node.Extensions) { for ext_type, ext := range(node.Extensions) {
ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type) ctx.Log.Logf("node_process", "PROCESSING_EXTENSION: %s/%s", node.ID, ext_type)
//TODO: add extension and node info to log //TODO: add extension and node info to log
resp := ext.Process(ctx, node, message) resp := ext.Process(ctx, node, source, signal)
if resp != nil { if resp != nil {
messages = append(messages, resp...) messages = append(messages, resp...)
} }
} }
return ctx.Send(node.ID, messages) return ctx.Send(messages)
} }
func GetCtx[T Extension, C any](ctx *Context) (C, error) { func GetCtx[T Extension, C any](ctx *Context) (C, error) {
@ -378,6 +416,7 @@ func (node *Node) Serialize() ([]byte, error) {
NumQueuedSignals: uint32(len(node.SignalQueue)), NumQueuedSignals: uint32(len(node.SignalQueue)),
}, },
Extensions: extensions, Extensions: extensions,
Policies: policies,
QueuedSignals: qsignals, QueuedSignals: qsignals,
KeyBytes: key_bytes, KeyBytes: key_bytes,
} }
@ -489,7 +528,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
Type: node_type, Type: node_type,
Extensions: ext_map, Extensions: ext_map,
Policies: policies, Policies: policies,
MsgChan: make(chan Message, buffer_size), MsgChan: make(chan *Message, buffer_size),
BufferSize: buffer_size, BufferSize: buffer_size,
SignalQueue: []QueuedSignal{}, SignalQueue: []QueuedSignal{},
} }
@ -499,7 +538,7 @@ func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_si
panic(err) panic(err)
} }
node.Process(ctx, Message{node.ID, &NewSignal}) node.Process(ctx, ZeroID, &NewSignal)
go runNode(ctx, node) go runNode(ctx, node)
@ -833,7 +872,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
Type: node_type.Type, Type: node_type.Type,
Extensions: map[ExtType]Extension{}, Extensions: map[ExtType]Extension{},
Policies: policies, Policies: policies,
MsgChan: make(chan Message, node_db.Header.BufferSize), MsgChan: make(chan *Message, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize, BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan, TimeoutChan: timeout_chan,
SignalQueue: signal_queue, SignalQueue: signal_queue,

@ -45,9 +45,12 @@ func TestNodeRead(t *testing.T) {
read_sig := NewReadSignal(map[ExtType][]string{ read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"}, GroupExtType: []string{"members"},
}) })
ctx.Send(n2.ID, []Message{{n1.ID, &read_sig}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, n2.ID, n2.Key, read_sig, n1.ID)
err = ctx.Send(msgs)
fatalErr(t, err)
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { res, err := WaitForSignal(ctx, n2_listener.Chan, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool {
return true return true
}) })
fatalErr(t, err) fatalErr(t, err)
@ -80,10 +83,12 @@ func TestECDH(t *testing.T) {
} }
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec)
err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_req}}) msgs := Messages{}
msgs = msgs.Add(ctx.Log, n1.ID, n1.Key, ecdh_req, n2.ID)
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { _, err = WaitForSignal(ctx, n1_listener.Chan, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool {
return sig.Str == "resp" return sig.Str == "resp"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -92,6 +97,8 @@ func TestECDH(t *testing.T) {
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret) ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(n1.ID, []Message{{n2.ID, ecdh_sig}}) msgs = Messages{}
msgs = msgs.Add(ctx.Log, n1.ID, n1.Key, ecdh_sig, n2.ID)
err = ctx.Send(msgs)
fatalErr(t, err) fatalErr(t, err)
} }

@ -14,40 +14,40 @@ const (
type Policy interface { type Policy interface {
Serializable[PolicyType] Serializable[PolicyType]
Allows(principal_id NodeID, action Action, node *Node) error Allows(principal_id NodeID, action Action, node *Node)(Messages, error)
// Merge with another policy of the same underlying type // Merge with another policy of the same underlying type
Merge(Policy) Policy Merge(Policy) Policy
// Make a copy of this policy // Make a copy of this policy
Copy() Policy Copy() Policy
} }
func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node) error { func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) {
return policy.Actions.Allows(action) return nil, policy.Actions.Allows(action)
} }
func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Node) error { func (policy PerNodePolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) {
for id, actions := range(policy.NodeActions) { for id, actions := range(policy.NodeActions) {
if id != principal_id { if id != principal_id {
continue continue
} }
return actions.Allows(action) return nil, actions.Allows(action)
} }
return fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID) return nil, fmt.Errorf("%s is not in per node policy of %s", principal_id, node.ID)
} }
func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { func (policy *RequirementOfPolicy) Allows(principal_id NodeID, action Action, node *Node)(Messages, error) {
lockable_ext, err := GetExt[*LockableExt](node) lockable_ext, err := GetExt[*LockableExt](node)
if err != nil { if err != nil {
return err return nil, err
} }
for id, _ := range(lockable_ext.Requirements) { for id, _ := range(lockable_ext.Requirements) {
if id == principal_id { if id == principal_id {
return policy.Actions.Allows(action) return nil, policy.Actions.Allows(action)
} }
} }
return fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID) return nil, fmt.Errorf("%s is not a requirement of %s", principal_id, node.ID)
} }
type UserOfPolicy struct { type UserOfPolicy struct {
@ -65,11 +65,11 @@ func NewUserOfPolicy(group_actions NodeActions) UserOfPolicy {
} }
// Send a read signal to Group to check if principal_id is a member of it // Send a read signal to Group to check if principal_id is a member of it
func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) error { func (policy *UserOfPolicy) Allows(principal_id NodeID, action Action, node *Node) (Messages, error) {
// Send a read signal to each of the groups in the map // Send a read signal to each of the groups in the map
// Check for principal_id in any of the returned member lists(skipping errors) // Check for principal_id in any of the returned member lists(skipping errors)
// Return an error in the default case // Return an error in the default case
return fmt.Errorf("NOT_IMPLEMENTED") return nil, fmt.Errorf("NOT_IMPLEMENTED")
} }
func (policy *UserOfPolicy) Merge(p Policy) Policy { func (policy *UserOfPolicy) Merge(p Policy) Policy {

@ -24,7 +24,6 @@ const (
LinkSignalType = "LINK" LinkSignalType = "LINK"
LockSignalType = "LOCK" LockSignalType = "LOCK"
ReadSignalType = "READ" ReadSignalType = "READ"
AuthorizedSignalType = "AUTHORIZED"
ReadResultSignalType = "READ_RESULT" ReadResultSignalType = "READ_RESULT"
LinkStartSignalType = "LINK_START" LinkStartSignalType = "LINK_START"
ECDHSignalType = "ECDH" ECDHSignalType = "ECDH"
@ -48,21 +47,7 @@ type Signal interface {
Permission() Action Permission() Action
} }
func WaitForResult(listener chan Signal, timeout time.Duration, id uuid.UUID) (Signal, error) { func WaitForSignal[S Signal](ctx * Context, listener chan Signal, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) {
timeout_channel := time.After(timeout)
select {
case result:=<-listener:
if result.ID() == id {
return result, nil
} else {
return result, fmt.Errorf("WRONG_ID: %s", result.ID())
}
case <-timeout_channel:
return nil, fmt.Errorf("timeout waiting for read response to %s", id)
}
}
func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.Duration, signal_type SignalType, check func(S)bool) (S, error) {
var zero S var zero S
var timeout_channel <- chan time.Time var timeout_channel <- chan time.Time
if timeout > 0 { if timeout > 0 {
@ -70,12 +55,12 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.
} }
for true { for true {
select { select {
case msg := <- listener.Chan: case signal := <- listener:
if msg.Signal == nil { if signal == nil {
return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type) return zero, fmt.Errorf("LISTENER_CLOSED: %s", signal_type)
} }
if msg.Signal.Type() == signal_type { if signal.Type() == signal_type {
sig, ok := msg.Signal.(S) sig, ok := signal.(S)
if ok == true { if ok == true {
if check(sig) == true { if check(sig) == true {
return sig, nil return sig, nil
@ -89,7 +74,6 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time.
return zero, fmt.Errorf("LOOP_ENDED") return zero, fmt.Errorf("LOOP_ENDED")
} }
type BaseSignal struct { type BaseSignal struct {
SignalDirection SignalDirection `json:"direction"` SignalDirection SignalDirection `json:"direction"`
SignalType SignalType `json:"type"` SignalType SignalType `json:"type"`
@ -244,42 +228,12 @@ type ReadSignal struct {
Extensions map[ExtType][]string `json:"extensions"` Extensions map[ExtType][]string `json:"extensions"`
} }
type AuthorizedSignal struct {
BaseSignal
Principal ed25519.PublicKey
Signal Signal
Signature []byte
}
func (signal *AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction
}
func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (Signal, error) {
sig_data, err := signal.Serialize()
if err != nil {
return nil, err
}
sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil {
return nil, err
}
return &AuthorizedSignal{
NewBaseSignal(AuthorizedSignalType, Direct),
principal.Public().(ed25519.PublicKey),
signal,
sig,
}, nil
}
func (signal *ReadSignal) Serialize() ([]byte, error) { func (signal *ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(signal) return json.Marshal(signal)
} }
func NewReadSignal(exts map[ExtType][]string) ReadSignal { func NewReadSignal(exts map[ExtType][]string) *ReadSignal {
return ReadSignal{ return &ReadSignal{
NewBaseSignal(ReadSignalType, Direct), NewBaseSignal(ReadSignalType, Direct),
exts, exts,
} }

@ -48,7 +48,7 @@ func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error {
return err return err
} }
func (ext *GroupExt) Process(ctx *Context, node *Node, msg Message) []Message { func (ext *GroupExt) Process(ctx *Context, node *Node, source NodeID, signal Signal) Messages {
return nil return nil
} }