gql_cataclysm
noah metz 2023-08-01 20:55:15 -06:00
parent 42cd8f4188
commit 147f44e5ff
16 changed files with 387 additions and 407 deletions

@ -35,7 +35,19 @@ func (ext ExtType) Prefix() string { return "EXTENSION: " }
func (ext ExtType) String() string { return string(ext) }
//Function to load an extension from bytes
type ExtensionLoadFunc func(*Context, []byte) (Extension, error)
type ExtensionLoadFunc func(*Context,[]byte) (Extension, error)
func LoadExtension[T any, E interface {
*T
Extension
}](ctx *Context, data []byte) (Extension, error) {
e := E(new(T))
err := e.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return e, nil
}
// ExtType and NodeType constants
const (
@ -53,7 +65,19 @@ var (
NodeNotFoundError = errors.New("Node not found in DB")
)
type SignalLoadFunc func(*Context, []byte) (Signal, error)
type SignalLoadFunc func(*Context,[]byte) (Signal, error)
func LoadSignal[T any, S interface{
*T
Signal
}](ctx *Context, data []byte) (Signal, error) {
s := S(new(T))
err := s.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return s, nil
}
type SignalInfo struct {
Load SignalLoadFunc
@ -127,11 +151,10 @@ func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) e
return nil
}
func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFunc) error {
if load_fn == nil {
return fmt.Errorf("def has no load function")
}
func RegisterSignal[T any, S interface {
*T
Signal
}](ctx *Context, signal_type SignalType) error {
type_hash := Hash(signal_type)
_, exists := ctx.Signals[type_hash]
if exists == true {
@ -139,18 +162,19 @@ func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFun
}
ctx.Signals[type_hash] = SignalInfo{
Load: load_fn,
Load: LoadSignal[T, S],
Type: signal_type,
}
return nil
}
// Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error {
if load_fn == nil {
return fmt.Errorf("def has no load function")
}
func RegisterExtension[T any, E interface{
*T
Extension
}](ctx *Context, data interface{}) error {
var zero E
ext_type := zero.Type()
type_hash := Hash(ext_type)
_, exists := ctx.Extensions[type_hash]
if exists == true {
@ -158,7 +182,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun
}
ctx.Extensions[type_hash] = ExtensionInfo{
Load: load_fn,
Load: LoadExtension[T,E],
Type: ext_type,
Data: data,
}
@ -195,7 +219,7 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
// Stop every running loop
func (ctx *Context) Stop() {
for _, node := range(ctx.Nodes) {
node.MsgChan <- Msg{ZeroID, StopSignal}
node.MsgChan <- Msg{ZeroID, &StopSignal}
}
}
@ -233,40 +257,41 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
}
var err error
err = ctx.RegisterExtension(ACLExtType, LoadACLExt, NewACLExtContext())
err = RegisterExtension[ACLExt,*ACLExt](ctx, NewACLExtContext())
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(LockableExtType, LoadLockableExt, nil)
err = RegisterExtension[LockableExt,*LockableExt](ctx, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ListenerExtType, LoadListenerExt, nil)
err = RegisterExtension[ListenerExt,*ListenerExt](ctx, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(ECDHExtType, LoadECDHExt, nil)
err = RegisterExtension[ECDHExt,*ECDHExt](ctx, nil)
if err != nil {
return nil, err
}
err = ctx.RegisterExtension(GroupExtType, LoadGroupExt, nil)
err = RegisterExtension[GroupExt,*GroupExt](ctx, nil)
if err != nil {
return nil, err
}
gql_ctx := NewGQLExtContext()
err = ctx.RegisterExtension(GQLExtType, LoadGQLExt, gql_ctx)
err = RegisterExtension[GQLExt,*GQLExt](ctx, gql_ctx)
if err != nil {
return nil, err
}
err = ctx.RegisterSignal(StopSignalType, func(ctx *Context, data []byte) (Signal, error) {
return StopSignal, nil
})
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StopSignalType)
if err != nil {
return nil, err
}
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType})
if err != nil {

@ -103,9 +103,7 @@ func (ext *ECDHExt) Field(name string) interface{} {
})
}
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) {
ser, _ := signal.Serialize()
ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %s", source, node.ID, ser)
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal *ECDHSignal) {
switch signal.Str {
case "req":
state, exists := ext.ECDHStates[source]
@ -117,7 +115,7 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
state.SharedSecret = shared_secret
ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, resp)
ctx.Send(node.ID, source, &resp)
} else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response
@ -125,7 +123,8 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
case "resp":
state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_req"})
resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
} else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil {
@ -143,21 +142,22 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
}
}
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) {
ser, _ := signal.Serialize()
ctx.Log.Logf("ecdh", "ECHD_STATE: %s->%s - %s", source, node.ID, ser)
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
}
func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal ECDHProxySignal) {
func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal *ECDHProxySignal) {
state, exists := ext.ECDHStates[source]
if exists == false {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_req"})
resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
} else if state.SharedSecret == nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_shared"})
resp := NewErrorSignal(signal.ID(), "no_shared")
ctx.Send(node.ID, source, &resp)
} else {
unwrapped_signal, err := ParseECDHProxySignal(ctx, &signal, state.SharedSecret)
unwrapped_signal, err := ParseECDHProxySignal(ctx, signal, state.SharedSecret)
if err != nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), err.Error()})
resp := NewErrorSignal(signal.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
} else {
//TODO: Figure out what I was trying to do here and fix it
ctx.Send(signal.Source, signal.Dest, unwrapped_signal)
@ -170,13 +170,13 @@ func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Sign
case Direct:
switch signal.Type() {
case ECDHProxySignalType:
ecdh_signal := signal.(ECDHProxySignal)
ecdh_signal := signal.(*ECDHProxySignal)
ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal)
case ECDHStateSignalType:
ecdh_signal := signal.(StringSignal)
ecdh_signal := signal.(*StringSignal)
ext.HandleStateSignal(ctx, source, node, ecdh_signal)
case ECDHSignalType:
ecdh_signal := signal.(ECDHSignal)
ecdh_signal := signal.(*ECDHSignal)
ext.HandleECDHSignal(ctx, source, node, ecdh_signal)
default:
}
@ -189,15 +189,9 @@ func (ext *ECDHExt) Type() ExtType {
}
func (ext *ECDHExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ")
return json.Marshal(ext)
}
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) {
var ext ECDHExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return &ext, nil
func (ext *ECDHExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, &ext)
}

@ -610,7 +610,7 @@ func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtT
type NodeResult struct {
ID NodeID
Result ReadResultSignal
Result *ReadResultSignal
}
type ListField struct {
@ -923,7 +923,7 @@ func NewGQLExtContext() *GQLExtContext {
}
response_chan := ctx.Ext.GetResponseChannel(sig.ID())
err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, sig)
err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, &sig)
if err != nil {
ctx.Ext.FreeResponseChannel(sig.ID())
return nil, err
@ -1048,7 +1048,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
// Process ReadResultSignalType by forwarding it to the waiting resolver
if signal.Type() == ErrorSignalType {
// TODO: Forward to resolver if waiting for it
sig := signal.(ErrorSignal)
sig := signal.(*ErrorSignal)
response_chan := ext.FreeResponseChannel(sig.ID())
if response_chan != nil {
select {
@ -1062,7 +1062,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
}
} else if signal.Type() == ReadResultSignalType {
sig := signal.(ReadResultSignal)
sig := signal.(*ReadResultSignal)
response_chan := ext.FreeResponseChannel(sig.ID())
if response_chan != nil {
select {
@ -1075,14 +1075,15 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
}
} else if signal.Type() == GQLStateSignalType {
sig := signal.(StringSignal)
sig := signal.(*StringSignal)
switch sig.Str {
case "start_server":
if ext.State == "stopped" {
err := ext.StartGQLServer(ctx, node)
if err == nil {
ext.State = "running"
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"})
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}
ctx.Send(node.ID, source, &resp)
}
}
case "stop_server":
@ -1090,7 +1091,8 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
err := ext.StopGQLServer()
if err == nil {
ext.State = "stopped"
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"})
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"}
ctx.Send(node.ID, source, &resp)
}
}
default:
@ -1102,7 +1104,8 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
case "running":
err := ext.StartGQLServer(ctx, node)
if err == nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"})
resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}
ctx.Send(node.ID, source, &resp)
}
case "stopped":
default:
@ -1116,7 +1119,7 @@ func (ext *GQLExt) Type() ExtType {
}
func (ext *GQLExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ")
return json.Marshal(ext)
}
var ecdsa_curves = map[uint8]elliptic.Curve{
@ -1135,14 +1138,8 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
ecdh.P256(): 0,
}
func LoadGQLExt(ctx *Context, data []byte) (Extension, error) {
var ext GQLExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return NewGQLExt(ctx, ext.Listen, ext.tls_cert, ext.tls_key, ext.State)
func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, &ext)
}
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) {

@ -1,21 +1,4 @@
package graphvent
import (
"github.com/graphql-go/graphql"
)
var MutationStop = NewField(func()*graphql.Field {
mutation_stop := &graphql.Field{
Type: TypeSignal.Type,
Args: graphql.FieldConfigArgument{
"id": &graphql.ArgumentConfig{
Type: graphql.String,
},
},
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return StopSignal, nil
},
}
return mutation_stop
})

@ -51,7 +51,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
}
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields)
auth_signal, err := NewAuthorizedSignal(ctx.Key, read_signal)
auth_signal, err := NewAuthorizedSignal(ctx.Key, &read_signal)
if err != nil {
return nil, err
}
@ -61,7 +61,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, id, auth_signal)
err = ctx.Context.Send(ctx.Server.ID, id, &auth_signal)
if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err
@ -76,10 +76,10 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
return nil, err
}
switch resp := response.(type) {
case ReadResultSignal:
case *ReadResultSignal:
responses = append(responses, NodeResult{node_ids[sig_id], resp})
case ErrorSignal:
return nil, resp.Error
case *ErrorSignal:
return nil, fmt.Errorf(resp.Str)
default:
return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp))
}

@ -84,37 +84,3 @@ func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
return Hash(node.Result.NodeType), nil
})
}
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
if signal, ok := p.Source.(Signal); ok {
return fn(signal, p)
}
return nil, fmt.Errorf("Failed to cast source to event")
}
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
return signal.Type(), nil
})
}
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
direction := signal.Direction()
if direction == Up {
return "up", nil
} else if direction == Down {
return "down", nil
} else if direction == Direct {
return "direct", nil
}
return nil, fmt.Errorf("Invalid direction: %+v", direction)
})
}
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
ser, err := signal.Serialize()
return string(ser), err
})
}

@ -15,7 +15,7 @@ import (
)
func TestGQL(t *testing.T) {
ctx := logTestContext(t, []string{"node", "test", "gql", "policy"})
ctx := logTestContext(t, []string{})
TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType})
@ -25,15 +25,16 @@ func TestGQL(t *testing.T) {
fatalErr(t, err)
listener_ext := NewListenerExt(10)
policy := NewAllNodesPolicy(Actions{MakeAction("+")})
start_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"}
gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{
QueuedSignal{uuid.New(), StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"}, time.Now()},
}, NewLockableExt(), NewACLExt(policy), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy))
QueuedSignal{uuid.New(), &start_signal, time.Now()},
}, NewLockableExt(), NewACLExt(&policy), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(&policy))
err = LinkRequirement(ctx, gql.ID, n1.ID)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StringSignal) bool {
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
return sig.Str == "server_started"
})
fatalErr(t, err)
@ -85,14 +86,15 @@ func TestGQL(t *testing.T) {
resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2)
ctx.Send(n1.ID, gql.ID, StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"})
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StringSignal) bool {
stop_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}
ctx.Send(n1.ID, gql.ID, &stop_signal)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
return sig.Str == "server_stopped"
})
}
func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{})
ctx := logTestContext(t, []string{"listener"})
TestUserNodeType := NodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})
@ -111,10 +113,10 @@ func TestGQLDB(t *testing.T) {
NewGroupExt(nil))
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err = ctx.Send(gql.ID, gql.ID, StopSignal)
err = ctx.Send(gql.ID, gql.ID, &StopSignal)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool {
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID
})
fatalErr(t, err)
@ -130,9 +132,9 @@ func TestGQLDB(t *testing.T) {
fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded)
fatalErr(t, err)
err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal)
err = ctx.Send(gql_loaded.ID, gql_loaded.ID, &StopSignal)
fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool {
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID
})
fatalErr(t, err)

@ -35,29 +35,3 @@ func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init fu
init(&gql)
return &gql
}
var TypeSignal = NewSingleton(func() *graphql.Object {
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
Name: "Signal",
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Signal)
return ok
},
Fields: graphql.Fields{},
})
gql_type_signal.AddFieldConfig("Type", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalType,
})
gql_type_signal.AddFieldConfig("Direction", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalDirection,
})
gql_type_signal.AddFieldConfig("String", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalString,
})
return gql_type_signal
}, nil)

@ -17,7 +17,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
10,
nil,
listener_extension,
NewACLExt(policy),
NewACLExt(&policy),
NewLockableExt())
return listener, listener_extension

@ -2,7 +2,6 @@ package graphvent
import (
"encoding/json"
"fmt"
)
// A Listener extension provides a channel that can receive signals on a different thread
@ -31,14 +30,10 @@ func (ext *ListenerExt) Field(name string) interface{} {
}
// Simple load function, unmarshal the buffer int from json
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) {
var j int
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
return NewListenerExt(j), nil
func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error {
err := json.Unmarshal(data, &ext.Buffer)
ext.Chan = make(chan Signal, ext.Buffer)
return err
}
func (listener *ListenerExt) Type() ExtType {
@ -47,7 +42,7 @@ func (listener *ListenerExt) Type() ExtType {
// Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
ctx.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal)
ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal)
select {
case ext.Chan <- signal:
default:
@ -116,18 +111,8 @@ func (ext *LockableExt) Field(name string) interface{} {
})
}
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
var ext LockableExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return &ext, nil
}
func (ext *ListenerExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext.Buffer, "", " ")
return json.Marshal(ext.Buffer)
}
func (ext *LockableExt) Type() ExtType {
@ -135,7 +120,11 @@ func (ext *LockableExt) Type() ExtType {
}
func (ext *LockableExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ")
return json.Marshal(ext)
}
func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, ext)
}
func NewLockableExt() *LockableExt {
@ -149,36 +138,43 @@ func NewLockableExt() *LockableExt {
// Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error {
return ctx.Send(node.ID, node.ID, NewLockSignal("unlock"))
lock_signal := NewLockSignal("unlock")
return ctx.Send(node.ID, node.ID, &lock_signal)
}
// Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error {
return ctx.Send(node.ID, node.ID, NewLockSignal("lock"))
lock_signal := NewLockSignal("lock")
return ctx.Send(node.ID, node.ID, &lock_signal)
}
// Setup a node to send the initial requirement link signal, then send the signal
func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error {
return ctx.Send(dependency, dependency, NewLinkStartSignal("req", requirement))
start_signal := NewLinkStartSignal("req", requirement)
return ctx.Send(dependency, dependency, &start_signal)
}
// Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) {
func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str
switch state {
case "unlock":
if ext.Owner == nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocked")))
resp := NewErrorSignal(signal.ID(), "already_unlocked")
ctx.Send(node.ID, source, &resp)
} else if source != *ext.Owner {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_owner")))
resp := NewErrorSignal(signal.ID(), "not_owner")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner == nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocking")))
resp := NewErrorSignal(signal.ID(), "already_unlocking")
ctx.Send(node.ID, source, &resp)
} else {
if len(ext.Requirements) == 0 {
ext.Owner = nil
ext.PendingOwner = nil
ctx.Send(node.ID, source, NewLockSignal("unlocked"))
resp := NewLockSignal("unlocked")
ctx.Send(node.ID, source, &resp)
} else {
ext.PendingOwner = nil
for id, state := range(ext.Requirements) {
@ -188,22 +184,27 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
}
state.Lock = "unlocking"
ext.Requirements[id] = state
ctx.Send(node.ID, id, NewLockSignal("unlock"))
resp := NewLockSignal("unlock")
ctx.Send(node.ID, id, &resp)
}
}
if source != node.ID {
ctx.Send(node.ID, source, NewLockSignal("unlocking"))
resp := NewLockSignal("unlocking")
ctx.Send(node.ID, source, &resp)
}
}
}
case "unlocking":
state, exists := ext.Requirements[source]
if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement")))
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked")))
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "unlocking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking")))
resp := NewErrorSignal(signal.ID(), "not_unlocking")
ctx.Send(node.ID, source, &resp)
}
case "unlocked":
@ -213,11 +214,14 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
state, exists := ext.Requirements[source]
if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement")))
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked")))
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "unlocking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking")))
resp := NewErrorSignal(signal.ID(), "not_unlocking")
ctx.Send(node.ID, source, &resp)
} else {
state.Lock = "unlocked"
ext.Requirements[source] = state
@ -237,7 +241,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == unlocked {
previous_owner := *ext.Owner
ext.Owner = nil
ctx.Send(node.ID, previous_owner, NewLockSignal("unlocked"))
resp := NewLockSignal("unlocked")
ctx.Send(node.ID, previous_owner, &resp)
}
}
}
@ -248,11 +253,14 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
state, exists := ext.Requirements[source]
if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement")))
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked")))
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "locking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking")))
resp := NewErrorSignal(signal.ID(), "not_locking")
ctx.Send(node.ID, source, &resp)
} else {
state.Lock = "locked"
ext.Requirements[source] = state
@ -271,31 +279,38 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == locked {
ext.Owner = ext.PendingOwner
ctx.Send(node.ID, *ext.Owner, NewLockSignal("locked"))
resp := NewLockSignal("locked")
ctx.Send(node.ID, *ext.Owner, &resp)
}
}
}
case "locking":
state, exists := ext.Requirements[source]
if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement")))
resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked")))
resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "locking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking")))
resp := NewErrorSignal(signal.ID(), "not_locking")
ctx.Send(node.ID, source, &resp)
}
case "lock":
if ext.Owner != nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locked")))
resp := NewErrorSignal(signal.ID(), "already_locked")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locking")))
resp := NewErrorSignal(signal.ID(), "already_locking")
ctx.Send(node.ID, source, &resp)
} else {
owner := source
if len(ext.Requirements) == 0 {
ext.Owner = &owner
ext.PendingOwner = ext.Owner
ctx.Send(node.ID, source, NewLockSignal("locked"))
resp := NewLockSignal("locked")
ctx.Send(node.ID, source, &resp)
} else {
ext.PendingOwner = &owner
for id, state := range(ext.Requirements) {
@ -305,11 +320,13 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
}
state.Lock = "locking"
ext.Requirements[id] = state
ctx.Send(node.ID, id, NewLockSignal("lock"))
sub := NewLockSignal("lock")
ctx.Send(node.ID, id, &sub)
}
}
if source != node.ID {
ctx.Send(node.ID, source, NewLockSignal("locking"))
resp := NewLockSignal("locking")
ctx.Send(node.ID, source, &resp)
}
}
}
@ -318,7 +335,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
}
}
func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal IDStringSignal) {
func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal *IDStringSignal) {
ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal)
link_type := signal.Str
target := signal.NodeID
@ -327,72 +344,90 @@ func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node
state, exists := ext.Requirements[target]
_, dep_exists := ext.Dependencies[target]
if ext.Owner != nil {
ctx.Send(node.ID, source, NewLinkStartSignal("locked", target))
resp := NewLinkStartSignal("locked", target)
ctx.Send(node.ID, source, &resp)
} else if ext.Owner != ext.PendingOwner {
if ext.PendingOwner == nil {
ctx.Send(node.ID, source, NewLinkStartSignal("unlocking", target))
resp := NewLinkStartSignal("unlocking", target)
ctx.Send(node.ID, source, &resp)
} else {
ctx.Send(node.ID, source, NewLinkStartSignal("locking", target))
resp := NewLinkStartSignal("locking", target)
ctx.Send(node.ID, source, &resp)
}
} else if exists == true {
if state.Link == "linking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_linking_req")))
resp := NewErrorSignal(signal.ID(), "already_linking_req")
ctx.Send(node.ID, source, &resp)
} else if state.Link == "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_req")))
resp := NewErrorSignal(signal.ID(), "already_req")
ctx.Send(node.ID, source, &resp)
}
} else if dep_exists == true {
ctx.Send(node.ID, source, NewLinkStartSignal("already_dep", target))
resp := NewLinkStartSignal("already_dep", target)
ctx.Send(node.ID, source, &resp)
} else {
ext.Requirements[target] = LinkState{"linking", "unlocked", source}
ctx.Send(node.ID, target, NewLinkSignal("linked_as_req"))
ctx.Send(node.ID, source, NewLinkStartSignal("linking_req", target))
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, target, &resp)
notify := NewLinkStartSignal("linking_req", target)
ctx.Send(node.ID, source, &notify)
}
}
}
// Handle LinkSignal, updating the extensions requirements and dependencies as necessary
// TODO: Add unlink
func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) {
func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str
switch state {
case "linked_as_dep":
state, exists := ext.Requirements[source]
if exists == true && state.Link == "linked" {
ctx.Send(node.ID, state.Initiator, NewLinkStartSignal("linked_as_req", source))
resp := NewLinkStartSignal("linked_as_req", source)
ctx.Send(node.ID, state.Initiator, &resp)
} else if state.Link == "linking" {
state.Link = "linked"
ext.Requirements[source] = state
ctx.Send(node.ID, source, NewLinkSignal("linked_as_req"))
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil {
ctx.Send(node.ID, source, NewLinkSignal("locking"))
resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
} else {
ctx.Send(node.ID, source, NewLinkSignal("unlocking"))
resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
}
} else {
ext.Requirements[source] = LinkState{"linking", "unlocked", source}
ctx.Send(node.ID, source, NewLinkSignal("linked_as_req"))
resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
}
ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source)
case "linked_as_req":
state, exists := ext.Dependencies[source]
if exists == true && state.Link == "linked" {
ctx.Send(node.ID, state.Initiator, NewLinkStartSignal("linked_as_dep", source))
resp := NewLinkStartSignal("linked_as_dep", source)
ctx.Send(node.ID, state.Initiator, &resp)
} else if state.Link == "linking" {
state.Link = "linked"
ext.Dependencies[source] = state
ctx.Send(node.ID, source, NewLinkSignal("linked_as_dep"))
resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil {
ctx.Send(node.ID, source, NewLinkSignal("locking"))
resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
} else {
ctx.Send(node.ID, source, NewLinkSignal("unlocking"))
resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
}
} else {
ext.Dependencies[source] = LinkState{"linking", "unlocked", source}
ctx.Send(node.ID, source, NewLinkSignal("linked_as_dep"))
resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
}
ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source)
@ -442,11 +477,11 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal
case Direct:
switch signal.Type() {
case LinkSignalType:
ext.HandleLinkSignal(ctx, source, node, signal.(StringSignal))
ext.HandleLinkSignal(ctx, source, node, signal.(*StringSignal))
case LockSignalType:
ext.HandleLockSignal(ctx, source, node, signal.(StringSignal))
ext.HandleLockSignal(ctx, source, node, signal.(*StringSignal))
case LinkStartSignalType:
ext.HandleLinkStartSignal(ctx, source, node, signal.(IDStringSignal))
ext.HandleLinkStartSignal(ctx, source, node, signal.(*IDStringSignal))
default:
}
default:

@ -26,13 +26,13 @@ func TestLink(t *testing.T) {
l1_listener := NewListenerExt(10)
l1 := NewNode(ctx, nil, TestLockableType, 10, nil,
l1_listener,
NewACLExt(link_policy),
NewACLExt(&link_policy),
NewLockableExt(),
)
l2_listener := NewListenerExt(10)
l2 := NewNode(ctx, nil, TestLockableType, 10, nil,
l2_listener,
NewACLExt(link_policy),
NewACLExt(&link_policy),
NewLockableExt(),
)
@ -40,20 +40,21 @@ func TestLink(t *testing.T) {
err := LinkRequirement(ctx, l1.ID, l2.ID)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig IDStringSignal) bool {
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req"
})
fatalErr(t, err)
err = ctx.Send(l2.ID, l2.ID, NewStatusSignal("TEST", l2.ID))
sig1 := NewStatusSignal("TEST", l2.ID)
err = ctx.Send(l2.ID, l2.ID, &sig1)
fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig IDStringSignal) bool {
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "TEST"
})
fatalErr(t, err)
_, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig IDStringSignal) bool {
_, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "TEST"
})
fatalErr(t, err)
@ -64,7 +65,7 @@ func TestLink10K(t *testing.T) {
NewLockable := func()(*Node) {
l := NewNode(ctx, nil, TestLockableType, 10, nil,
NewACLExt(lock_policy, link_policy),
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(),
)
return l
@ -74,7 +75,7 @@ func TestLink10K(t *testing.T) {
listener := NewListenerExt(100000)
l := NewNode(ctx, nil, TestLockableType, 256, nil,
listener,
NewACLExt(lock_policy, link_policy),
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(),
)
return l, listener
@ -91,7 +92,7 @@ func TestLink10K(t *testing.T) {
for range(lockables) {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig IDStringSignal) bool {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req"
})
fatalErr(t, err)
@ -107,7 +108,7 @@ func TestLock(t *testing.T) {
listener := NewListenerExt(100)
l := NewNode(ctx, nil, TestLockableType, 10, nil,
listener,
NewACLExt(lock_policy, link_policy),
NewACLExt(&lock_policy, &link_policy),
NewLockableExt(),
)
return l, listener
@ -140,11 +141,11 @@ func TestLock(t *testing.T) {
err = LinkRequirement(ctx, l0.ID, l5.ID)
fatalErr(t, err)
linked_as_req := func(sig IDStringSignal) bool {
linked_as_req := func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req"
}
locked := func(sig StringSignal) bool {
locked := func(sig *StringSignal) bool {
return sig.Str == "locked"
}

@ -84,8 +84,9 @@ func RandID() NodeID {
// A Serializable has a type that can be used to map to it, and a function to serialize the current state
type Serializable[I comparable] interface {
Serialize()([]byte,error)
Deserialize(*Context,[]byte)error
Type() I
Serialize() ([]byte, error)
}
// Extensions are data attached to nodes that process signals
@ -197,7 +198,7 @@ func nodeLoop(ctx *Context, node *Node) error {
}
// Queue the signal for extensions to perform startup actions
node.QueueSignal(time.Now(), NewDirectSignal(StartSignalType))
node.QueueSignal(time.Now(), &StartSignal)
for true {
var signal Signal
@ -209,7 +210,8 @@ func nodeLoop(ctx *Context, node *Node) error {
err := Allowed(ctx, msg.Source, signal.Permission(), node)
if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err)
ctx.Send(node.ID, msg.Source, NewErrorSignal(msg.Signal.ID(), err))
resp := NewErrorSignal(msg.Signal.ID(), err.Error())
ctx.Send(node.ID, msg.Source, &resp)
continue
}
case <-node.TimeoutChan:
@ -240,7 +242,7 @@ func nodeLoop(ctx *Context, node *Node) error {
// Unwrap Authorized Signals
if signal.Type() == AuthorizedSignalType {
sig, ok := signal.(AuthorizedSignal)
sig, ok := signal.(*AuthorizedSignal)
if ok == false {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal))
} else {
@ -254,14 +256,16 @@ func nodeLoop(ctx *Context, node *Node) error {
err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node)
if err != nil {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err)
ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), err))
resp := NewErrorSignal(sig.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
} else {
// Unwrap the signal without changing the source
signal = sig.Signal
}
} else {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate")
ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), fmt.Errorf("failed to validate signature")))
resp := NewErrorSignal(sig.ID(), "signature validation failed")
ctx.Send(node.ID, source, &resp)
}
}
}
@ -269,16 +273,19 @@ func nodeLoop(ctx *Context, node *Node) error {
// Handle special signal types
if signal.Type() == StopSignalType {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil))
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID))
resp := NewErrorSignal(signal.ID(), "stopped")
ctx.Send(node.ID, source, &resp)
status := NewStatusSignal("stopped", node.ID)
node.Process(ctx, node.ID, &status)
break
} else if signal.Type() == ReadSignalType {
read_signal, ok := signal.(ReadSignal)
read_signal, ok := signal.(*ReadSignal)
if ok == false {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else {
result := ReadNodeFields(ctx, node, source, read_signal.Extensions)
ctx.Send(node.ID, source, NewReadResultSignal(read_signal.ID(), node.Type, result))
resp := NewReadResultSignal(read_signal.ID(), node.Type, result)
ctx.Send(node.ID, source, &resp)
}
}

@ -41,18 +41,19 @@ func TestNodeRead(t *testing.T) {
n1_id: Actions{MakeAction(ReadResultSignalType, "+")},
})
n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(&n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener)
n1_policy := NewPerNodePolicy(map[NodeID]Actions{
n2_id: Actions{MakeAction(ReadSignalType, "+")},
})
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil), NewECDHExt())
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(&n1_policy), NewGroupExt(nil), NewECDHExt())
ctx.Send(n2.ID, n1.ID, NewReadSignal(map[ExtType][]string{
read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"},
}))
})
ctx.Send(n2.ID, n1.ID, &read_sig)
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig ReadResultSignal) bool {
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool {
return true
})
fatalErr(t, err)
@ -68,11 +69,11 @@ func TestECDH(t *testing.T) {
n1_listener := NewListenerExt(10)
ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")})
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt())
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt())
n3_listener := NewListenerExt(10)
n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}})
n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy, n3_policy), NewECDHExt(), n3_listener)
n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy, &n3_policy), NewECDHExt(), n3_listener)
ctx.Log.Logf("test", "N1: %s", n1.ID)
ctx.Log.Logf("test", "N2: %s", n2.ID)
@ -87,18 +88,18 @@ func TestECDH(t *testing.T) {
}
fatalErr(t, err)
ctx.Log.Logf("test", "N1_EC: %+v", n1_ec)
err = ctx.Send(n1.ID, n2.ID, ecdh_req)
err = ctx.Send(n1.ID, n2.ID, &ecdh_req)
fatalErr(t, err)
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig ECDHSignal) bool {
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool {
return sig.Str == "resp"
})
fatalErr(t, err)
time.Sleep(10*time.Millisecond)
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, NewDirectSignal(StopSignalType), ecdh_ext.ECDHStates[n2.ID].SharedSecret)
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret)
fatalErr(t, err)
err = ctx.Send(n1.ID, n2.ID, ecdh_sig)
err = ctx.Send(n1.ID, n2.ID, &ecdh_sig)
fatalErr(t, err)
}

@ -20,6 +20,8 @@ type Policy interface {
Allows(principal_id NodeID, action Action, node *Node) error
// Merge with another policy of the same underlying type
Merge(Policy) Policy
// Make a copy of this policy
Copy() Policy
}
func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
@ -72,6 +74,14 @@ func MergeActions(first Actions, second Actions) Actions {
return ret
}
func CopyNodeActions(actions NodeActions) NodeActions {
ret := NodeActions{}
for id, a := range(actions) {
ret[id] = a
}
return ret
}
func MergeNodeActions(modified NodeActions, read NodeActions) {
for id, actions := range(read) {
existing, exists := modified[id]
@ -83,24 +93,47 @@ func MergeNodeActions(modified NodeActions, read NodeActions) {
}
}
func (policy PerNodePolicy) Merge(p Policy) Policy {
other := p.(PerNodePolicy)
func (policy *PerNodePolicy) Merge(p Policy) Policy {
other := p.(*PerNodePolicy)
MergeNodeActions(policy.NodeActions, other.NodeActions)
return policy
}
func (policy AllNodesPolicy) Merge(p Policy) Policy {
other := p.(AllNodesPolicy)
func (policy *PerNodePolicy) Copy() Policy {
new_actions := CopyNodeActions(policy.NodeActions)
return &PerNodePolicy{
NodeActions: new_actions,
}
}
func (policy *AllNodesPolicy) Merge(p Policy) Policy {
other := p.(*AllNodesPolicy)
policy.Actions = MergeActions(policy.Actions, other.Actions)
return policy
}
func (policy RequirementOfPolicy) Merge(p Policy) Policy {
other := p.(RequirementOfPolicy)
func (policy *AllNodesPolicy) Copy() Policy {
new_actions := policy.Actions
return &AllNodesPolicy {
Actions: new_actions,
}
}
func (policy *RequirementOfPolicy) Merge(p Policy) Policy {
other := p.(*RequirementOfPolicy)
policy.Actions = MergeActions(policy.Actions, other.Actions)
return policy
}
func (policy *RequirementOfPolicy) Copy() Policy {
new_actions := policy.Actions
return &RequirementOfPolicy{
AllNodesPolicy {
Actions: new_actions,
},
}
}
type Action []string
func MakeAction(parts ...interface{}) Action {
@ -178,36 +211,6 @@ func (actions *NodeActions) UnmarshalJSON(data []byte) error {
return nil
}
type AllNodesPolicyJSON struct {
Actions Actions `json:"actions"`
}
func AllNodesPolicyLoad(init_fn func(Actions)(Policy, error)) func(*Context, []byte)(Policy, error) {
return func(ctx *Context, data []byte)(Policy, error){
var j AllNodesPolicyJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
return init_fn(j.Actions)
}
}
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
return func(ctx *Context, data []byte)(Policy, error){
policy := PerNodePolicy{
NodeActions: NodeActions{},
}
err := json.Unmarshal(data, &policy)
if err != nil {
return nil, err
}
return init_fn(policy.NodeActions)
}
}
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
if node_actions == nil {
node_actions = NodeActions{}
@ -222,14 +225,18 @@ type PerNodePolicy struct {
NodeActions NodeActions `json:"node_actions"`
}
func (policy PerNodePolicy) Type() PolicyType {
func (policy *PerNodePolicy) Type() PolicyType {
return PerNodePolicyType
}
func (policy PerNodePolicy) Serialize() ([]byte, error) {
func (policy *PerNodePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ")
}
func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
func NewAllNodesPolicy(actions Actions) AllNodesPolicy {
if actions == nil {
actions = Actions{}
@ -244,14 +251,18 @@ type AllNodesPolicy struct {
Actions Actions
}
func (policy AllNodesPolicy) Type() PolicyType {
func (policy *AllNodesPolicy) Type() PolicyType {
return AllNodesPolicyType
}
func (policy AllNodesPolicy) Serialize() ([]byte, error) {
func (policy *AllNodesPolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ")
}
func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
// Extension to allow a node to hold ACL policies
type ACLExt struct {
Policies map[PolicyType]Policy
@ -266,36 +277,17 @@ func NodeList(nodes ...*Node) NodeMap {
return m
}
type PolicyLoadFunc func(*Context, []byte) (Policy, error)
type PolicyInfo struct {
Load PolicyLoadFunc
}
type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
type ACLExtContext struct {
Types map[PolicyType]PolicyInfo
Loads map[PolicyType]PolicyLoadFunc
}
func NewACLExtContext() *ACLExtContext {
return &ACLExtContext{
Types: map[PolicyType]PolicyInfo{
AllNodesPolicyType: PolicyInfo{
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){
policy := NewAllNodesPolicy(actions)
return &policy, nil
}),
},
PerNodePolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewPerNodePolicy(nodes)
return &policy, nil
}),
},
RequirementOfPolicyType: PolicyInfo{
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){
policy := NewRequirementOfPolicy(actions)
return &policy, nil
}),
},
Loads: map[PolicyType]PolicyLoadFunc{
AllNodesPolicyType: LoadPolicy[AllNodesPolicy,*AllNodesPolicy],
PerNodePolicyType: LoadPolicy[PerNodePolicy,*PerNodePolicy],
RequirementOfPolicyType: LoadPolicy[RequirementOfPolicy,*RequirementOfPolicy],
},
}
}
@ -331,13 +323,15 @@ func (ext *ACLExt) Field(name string) interface{} {
var ErrorSignalAction = Action{"ERROR_RESP"}
var ReadResultSignalAction = Action{"READ_RESULT"}
var AuthorizedSignalAction = Action{"AUTHORIZED_READ"}
var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction})
var DefaultACLPolicies = []Policy{
NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}),
&defaultPolicy,
}
func NewACLExt(policies ...Policy) *ACLExt {
policy_map := map[PolicyType]Policy{}
for _, policy := range(append(policies, DefaultACLPolicies...)) {
for _, policy_arg := range(append(policies, DefaultACLPolicies...)) {
policy := policy_arg.Copy()
existing, exists := policy_map[policy.Type()]
if exists == true {
policy = existing.Merge(policy)
@ -351,36 +345,49 @@ func NewACLExt(policies ...Policy) *ACLExt {
}
}
func LoadACLExt(ctx *Context, data []byte) (Extension, error) {
func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return p, nil
}
func (ext *ACLExt) Deserialize(ctx *Context, data []byte) error {
var j struct {
Policies map[string][]byte `json:"policies"`
}
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
return err
}
policies := make([]Policy, len(j.Policies))
i := 0
acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx)
if err != nil {
return nil, err
return err
}
ext.Policies = map[PolicyType]Policy{}
for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)]
policy_load, exists := acl_ctx.Loads[PolicyType(name)]
if exists == false {
return nil, fmt.Errorf("%s is not a known policy type", name)
return fmt.Errorf("%s is not a known policy type", name)
}
policy, err := policy_def.Load(ctx, ser)
policy, err := policy_load(ctx, ser)
if err != nil {
return nil, err
return err
}
policies[i] = policy
i++
ext.Policies[PolicyType(name)] = policy
}
return NewACLExt(policies...), nil
return nil
}
func (ext *ACLExt) Type() ExtType {

@ -95,23 +95,27 @@ type BaseSignal struct {
UUID uuid.UUID `json:"id"`
}
func (signal BaseSignal) ID() uuid.UUID {
func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, signal)
}
func (signal *BaseSignal) ID() uuid.UUID {
return signal.UUID
}
func (signal BaseSignal) Type() SignalType {
func (signal *BaseSignal) Type() SignalType {
return signal.SignalType
}
func (signal BaseSignal) Permission() Action {
func (signal *BaseSignal) Permission() Action {
return MakeAction(signal.Type())
}
func (signal BaseSignal) Direction() SignalDirection {
func (signal *BaseSignal) Direction() SignalDirection {
return signal.SignalDirection
}
func (signal BaseSignal) Serialize() ([]byte, error) {
func (signal *BaseSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
@ -136,43 +140,16 @@ func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Direct)
}
var StartSignal = NewDirectSignal(StartSignalType)
var StopSignal = NewDownSignal(StopSignalType)
type ErrorSignal struct {
BaseSignal
Error error `json:"error"`
}
func (signal ErrorSignal) Permission() Action {
return ErrorSignalAction
}
func NewErrorSignal(req_id uuid.UUID, err error) ErrorSignal {
return ErrorSignal{
BaseSignal: BaseSignal{
Direct,
ErrorSignalType,
req_id,
},
Error: err,
}
}
type IDSignal struct {
BaseSignal
NodeID `json:"id"`
}
func (signal IDSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
func (signal IDSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
func (signal *IDSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
@ -187,21 +164,38 @@ type StringSignal struct {
Str string `json:"state"`
}
func (signal StringSignal) Serialize() ([]byte, error) {
func (signal *StringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
}
type ErrorSignal struct {
StringSignal
}
func (signal *ErrorSignal) Permission() Action {
return ErrorSignalAction
}
func NewErrorSignal(req_id uuid.UUID, err string) ErrorSignal {
return ErrorSignal{
StringSignal{
NewDirectSignal(ErrorSignalType),
err,
},
}
}
type IDStringSignal struct {
BaseSignal
NodeID `json:"node_id"`
Str string `json:"string"`
}
func (signal IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
func (signal *IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func (signal IDStringSignal) String() string {
func (signal *IDStringSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
@ -243,7 +237,7 @@ func NewLockSignal(state string) StringSignal {
}
}
func (signal StringSignal) Permission() Action {
func (signal *StringSignal) Permission() Action {
return MakeAction(signal.Type(), signal.Str)
}
@ -259,7 +253,7 @@ type AuthorizedSignal struct {
Signature []byte
}
func (signal AuthorizedSignal) Permission() Action {
func (signal *AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction
}
@ -283,8 +277,8 @@ func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (Authorized
}, nil
}
func (signal ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
func (signal *ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func NewReadSignal(exts map[ExtType][]string) ReadSignal {
@ -300,7 +294,7 @@ type ReadResultSignal struct {
Extensions map[ExtType]map[string]interface{} `json:"extensions"`
}
func (signal ReadResultSignal) Permission() Action {
func (signal *ReadResultSignal) Permission() Action {
return ReadResultSignalAction
}
@ -342,8 +336,8 @@ func (signal *ECDHSignal) MarshalJSON() ([]byte, error) {
})
}
func (signal ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal)
func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(signal)
}
func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) {
@ -390,7 +384,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e
const DEFAULT_ECDH_WINDOW = time.Second
func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []byte, error) {
func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) {
now := time.Now()
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW)
@ -430,7 +424,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []
}, shared_secret, nil
}
func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error {
func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) error {
earliest := now.Add(-window)
latest := now.Add(window)

@ -17,9 +17,9 @@ func (ext *GroupExt) Type() ExtType {
}
func (ext *GroupExt) Serialize() ([]byte, error) {
return json.MarshalIndent(&GroupExtJSON{
return json.Marshal(&GroupExtJSON{
Members: IDMap(ext.Members),
}, "", " ")
})
}
func (ext *GroupExt) Field(name string) interface{} {
@ -40,18 +40,12 @@ func NewGroupExt(members map[NodeID]string) *GroupExt {
}
}
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) {
func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error {
var j GroupExtJSON
err := json.Unmarshal(data, &j)
members, err := LoadIDMap(j.Members)
if err != nil {
return nil, err
}
return &GroupExt{
Members: members,
}, nil
ext.Members, err = LoadIDMap(j.Members)
return err
}
func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {