gql_cataclysm
noah metz 2023-08-01 20:55:15 -06:00
parent 42cd8f4188
commit 147f44e5ff
16 changed files with 387 additions and 407 deletions

@ -36,6 +36,18 @@ func (ext ExtType) String() string { return string(ext) }
//Function to load an extension from bytes //Function to load an extension from bytes
type ExtensionLoadFunc func(*Context,[]byte) (Extension, error) type ExtensionLoadFunc func(*Context,[]byte) (Extension, error)
func LoadExtension[T any, E interface {
*T
Extension
}](ctx *Context, data []byte) (Extension, error) {
e := E(new(T))
err := e.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return e, nil
}
// ExtType and NodeType constants // ExtType and NodeType constants
const ( const (
@ -54,6 +66,18 @@ var (
) )
type SignalLoadFunc func(*Context,[]byte) (Signal, error) type SignalLoadFunc func(*Context,[]byte) (Signal, error)
func LoadSignal[T any, S interface{
*T
Signal
}](ctx *Context, data []byte) (Signal, error) {
s := S(new(T))
err := s.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return s, nil
}
type SignalInfo struct { type SignalInfo struct {
Load SignalLoadFunc Load SignalLoadFunc
@ -127,11 +151,10 @@ func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) e
return nil return nil
} }
func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFunc) error { func RegisterSignal[T any, S interface {
if load_fn == nil { *T
return fmt.Errorf("def has no load function") Signal
} }](ctx *Context, signal_type SignalType) error {
type_hash := Hash(signal_type) type_hash := Hash(signal_type)
_, exists := ctx.Signals[type_hash] _, exists := ctx.Signals[type_hash]
if exists == true { if exists == true {
@ -139,18 +162,19 @@ func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFun
} }
ctx.Signals[type_hash] = SignalInfo{ ctx.Signals[type_hash] = SignalInfo{
Load: load_fn, Load: LoadSignal[T, S],
Type: signal_type, Type: signal_type,
} }
return nil return nil
} }
// Add a node to a context, returns an error if the def is invalid or already exists in the context // Add a node to a context, returns an error if the def is invalid or already exists in the context
func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error { func RegisterExtension[T any, E interface{
if load_fn == nil { *T
return fmt.Errorf("def has no load function") Extension
} }](ctx *Context, data interface{}) error {
var zero E
ext_type := zero.Type()
type_hash := Hash(ext_type) type_hash := Hash(ext_type)
_, exists := ctx.Extensions[type_hash] _, exists := ctx.Extensions[type_hash]
if exists == true { if exists == true {
@ -158,7 +182,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun
} }
ctx.Extensions[type_hash] = ExtensionInfo{ ctx.Extensions[type_hash] = ExtensionInfo{
Load: load_fn, Load: LoadExtension[T,E],
Type: ext_type, Type: ext_type,
Data: data, Data: data,
} }
@ -195,7 +219,7 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) {
// Stop every running loop // Stop every running loop
func (ctx *Context) Stop() { func (ctx *Context) Stop() {
for _, node := range(ctx.Nodes) { for _, node := range(ctx.Nodes) {
node.MsgChan <- Msg{ZeroID, StopSignal} node.MsgChan <- Msg{ZeroID, &StopSignal}
} }
} }
@ -233,40 +257,41 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
} }
var err error var err error
err = ctx.RegisterExtension(ACLExtType, LoadACLExt, NewACLExtContext()) err = RegisterExtension[ACLExt,*ACLExt](ctx, NewACLExtContext())
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(LockableExtType, LoadLockableExt, nil) err = RegisterExtension[LockableExt,*LockableExt](ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(ListenerExtType, LoadListenerExt, nil) err = RegisterExtension[ListenerExt,*ListenerExt](ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(ECDHExtType, LoadECDHExt, nil) err = RegisterExtension[ECDHExt,*ECDHExt](ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterExtension(GroupExtType, LoadGroupExt, nil) err = RegisterExtension[GroupExt,*GroupExt](ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
gql_ctx := NewGQLExtContext() gql_ctx := NewGQLExtContext()
err = ctx.RegisterExtension(GQLExtType, LoadGQLExt, gql_ctx) err = RegisterExtension[GQLExt,*GQLExt](ctx, gql_ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = ctx.RegisterSignal(StopSignalType, func(ctx *Context, data []byte) (Signal, error) { err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StopSignalType)
return StopSignal, nil if err != nil {
}) return nil, err
}
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType})
if err != nil { if err != nil {

@ -103,9 +103,7 @@ func (ext *ECDHExt) Field(name string) interface{} {
}) })
} }
func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) { func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal *ECDHSignal) {
ser, _ := signal.Serialize()
ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %s", source, node.ID, ser)
switch signal.Str { switch signal.Str {
case "req": case "req":
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
@ -117,7 +115,7 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
state.SharedSecret = shared_secret state.SharedSecret = shared_secret
ext.ECDHStates[source] = state ext.ECDHStates[source] = state
ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret)
ctx.Send(node.ID, source, resp) ctx.Send(node.ID, source, &resp)
} else { } else {
ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err)
// TODO: send error response // TODO: send error response
@ -125,7 +123,8 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
case "resp": case "resp":
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
if exists == false || state.ECKey == nil { if exists == false || state.ECKey == nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_req"}) resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
} else { } else {
err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW)
if err == nil { if err == nil {
@ -143,21 +142,22 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si
} }
} }
func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) { func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ser, _ := signal.Serialize()
ctx.Log.Logf("ecdh", "ECHD_STATE: %s->%s - %s", source, node.ID, ser)
} }
func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal ECDHProxySignal) { func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal *ECDHProxySignal) {
state, exists := ext.ECDHStates[source] state, exists := ext.ECDHStates[source]
if exists == false { if exists == false {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_req"}) resp := NewErrorSignal(signal.ID(), "no_req")
ctx.Send(node.ID, source, &resp)
} else if state.SharedSecret == nil { } else if state.SharedSecret == nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_shared"}) resp := NewErrorSignal(signal.ID(), "no_shared")
ctx.Send(node.ID, source, &resp)
} else { } else {
unwrapped_signal, err := ParseECDHProxySignal(ctx, &signal, state.SharedSecret) unwrapped_signal, err := ParseECDHProxySignal(ctx, signal, state.SharedSecret)
if err != nil { if err != nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), err.Error()}) resp := NewErrorSignal(signal.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
} else { } else {
//TODO: Figure out what I was trying to do here and fix it //TODO: Figure out what I was trying to do here and fix it
ctx.Send(signal.Source, signal.Dest, unwrapped_signal) ctx.Send(signal.Source, signal.Dest, unwrapped_signal)
@ -170,13 +170,13 @@ func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Sign
case Direct: case Direct:
switch signal.Type() { switch signal.Type() {
case ECDHProxySignalType: case ECDHProxySignalType:
ecdh_signal := signal.(ECDHProxySignal) ecdh_signal := signal.(*ECDHProxySignal)
ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal) ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal)
case ECDHStateSignalType: case ECDHStateSignalType:
ecdh_signal := signal.(StringSignal) ecdh_signal := signal.(*StringSignal)
ext.HandleStateSignal(ctx, source, node, ecdh_signal) ext.HandleStateSignal(ctx, source, node, ecdh_signal)
case ECDHSignalType: case ECDHSignalType:
ecdh_signal := signal.(ECDHSignal) ecdh_signal := signal.(*ECDHSignal)
ext.HandleECDHSignal(ctx, source, node, ecdh_signal) ext.HandleECDHSignal(ctx, source, node, ecdh_signal)
default: default:
} }
@ -189,15 +189,9 @@ func (ext *ECDHExt) Type() ExtType {
} }
func (ext *ECDHExt) Serialize() ([]byte, error) { func (ext *ECDHExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ") return json.Marshal(ext)
} }
func LoadECDHExt(ctx *Context, data []byte) (Extension, error) { func (ext *ECDHExt) Deserialize(ctx *Context, data []byte) error {
var ext ECDHExt return json.Unmarshal(data, &ext)
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return &ext, nil
} }

@ -610,7 +610,7 @@ func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtT
type NodeResult struct { type NodeResult struct {
ID NodeID ID NodeID
Result ReadResultSignal Result *ReadResultSignal
} }
type ListField struct { type ListField struct {
@ -923,7 +923,7 @@ func NewGQLExtContext() *GQLExtContext {
} }
response_chan := ctx.Ext.GetResponseChannel(sig.ID()) response_chan := ctx.Ext.GetResponseChannel(sig.ID())
err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, sig) err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, &sig)
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(sig.ID()) ctx.Ext.FreeResponseChannel(sig.ID())
return nil, err return nil, err
@ -1048,7 +1048,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
// Process ReadResultSignalType by forwarding it to the waiting resolver // Process ReadResultSignalType by forwarding it to the waiting resolver
if signal.Type() == ErrorSignalType { if signal.Type() == ErrorSignalType {
// TODO: Forward to resolver if waiting for it // TODO: Forward to resolver if waiting for it
sig := signal.(ErrorSignal) sig := signal.(*ErrorSignal)
response_chan := ext.FreeResponseChannel(sig.ID()) response_chan := ext.FreeResponseChannel(sig.ID())
if response_chan != nil { if response_chan != nil {
select { select {
@ -1062,7 +1062,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig) ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig)
} }
} else if signal.Type() == ReadResultSignalType { } else if signal.Type() == ReadResultSignalType {
sig := signal.(ReadResultSignal) sig := signal.(*ReadResultSignal)
response_chan := ext.FreeResponseChannel(sig.ID()) response_chan := ext.FreeResponseChannel(sig.ID())
if response_chan != nil { if response_chan != nil {
select { select {
@ -1075,14 +1075,15 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig)
} }
} else if signal.Type() == GQLStateSignalType { } else if signal.Type() == GQLStateSignalType {
sig := signal.(StringSignal) sig := signal.(*StringSignal)
switch sig.Str { switch sig.Str {
case "start_server": case "start_server":
if ext.State == "stopped" { if ext.State == "stopped" {
err := ext.StartGQLServer(ctx, node) err := ext.StartGQLServer(ctx, node)
if err == nil { if err == nil {
ext.State = "running" ext.State = "running"
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}) resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}
ctx.Send(node.ID, source, &resp)
} }
} }
case "stop_server": case "stop_server":
@ -1090,7 +1091,8 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
err := ext.StopGQLServer() err := ext.StopGQLServer()
if err == nil { if err == nil {
ext.State = "stopped" ext.State = "stopped"
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"}) resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"}
ctx.Send(node.ID, source, &resp)
} }
} }
default: default:
@ -1102,7 +1104,8 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa
case "running": case "running":
err := ext.StartGQLServer(ctx, node) err := ext.StartGQLServer(ctx, node)
if err == nil { if err == nil {
ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}) resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}
ctx.Send(node.ID, source, &resp)
} }
case "stopped": case "stopped":
default: default:
@ -1116,7 +1119,7 @@ func (ext *GQLExt) Type() ExtType {
} }
func (ext *GQLExt) Serialize() ([]byte, error) { func (ext *GQLExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ") return json.Marshal(ext)
} }
var ecdsa_curves = map[uint8]elliptic.Curve{ var ecdsa_curves = map[uint8]elliptic.Curve{
@ -1135,14 +1138,8 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{
ecdh.P256(): 0, ecdh.P256(): 0,
} }
func LoadGQLExt(ctx *Context, data []byte) (Extension, error) { func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
var ext GQLExt return json.Unmarshal(data, &ext)
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return NewGQLExt(ctx, ext.Listen, ext.tls_cert, ext.tls_key, ext.State)
} }
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) {

@ -1,21 +1,4 @@
package graphvent package graphvent
import ( import (
"github.com/graphql-go/graphql"
) )
var MutationStop = NewField(func()*graphql.Field {
mutation_stop := &graphql.Field{
Type: TypeSignal.Type,
Args: graphql.FieldConfigArgument{
"id": &graphql.ArgumentConfig{
Type: graphql.String,
},
},
Resolve: func(p graphql.ResolveParams) (interface{}, error) {
return StopSignal, nil
},
}
return mutation_stop
})

@ -51,7 +51,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
} }
// Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error
read_signal := NewReadSignal(ext_fields) read_signal := NewReadSignal(ext_fields)
auth_signal, err := NewAuthorizedSignal(ctx.Key, read_signal) auth_signal, err := NewAuthorizedSignal(ctx.Key, &read_signal)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,7 +61,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
resp_channels[read_signal.ID()] = response_chan resp_channels[read_signal.ID()] = response_chan
node_ids[read_signal.ID()] = id node_ids[read_signal.ID()] = id
err = ctx.Context.Send(ctx.Server.ID, id, auth_signal) err = ctx.Context.Send(ctx.Server.ID, id, &auth_signal)
if err != nil { if err != nil {
ctx.Ext.FreeResponseChannel(read_signal.ID()) ctx.Ext.FreeResponseChannel(read_signal.ID())
return nil, err return nil, err
@ -76,10 +76,10 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([
return nil, err return nil, err
} }
switch resp := response.(type) { switch resp := response.(type) {
case ReadResultSignal: case *ReadResultSignal:
responses = append(responses, NodeResult{node_ids[sig_id], resp}) responses = append(responses, NodeResult{node_ids[sig_id], resp})
case ErrorSignal: case *ErrorSignal:
return nil, resp.Error return nil, fmt.Errorf(resp.Str)
default: default:
return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp)) return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp))
} }

@ -84,37 +84,3 @@ func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) {
return Hash(node.Result.NodeType), nil return Hash(node.Result.NodeType), nil
}) })
} }
func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) {
if signal, ok := p.Source.(Signal); ok {
return fn(signal, p)
}
return nil, fmt.Errorf("Failed to cast source to event")
}
func GQLSignalType(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
return signal.Type(), nil
})
}
func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
direction := signal.Direction()
if direction == Up {
return "up", nil
} else if direction == Down {
return "down", nil
} else if direction == Direct {
return "direct", nil
}
return nil, fmt.Errorf("Invalid direction: %+v", direction)
})
}
func GQLSignalString(p graphql.ResolveParams) (interface{}, error) {
return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){
ser, err := signal.Serialize()
return string(ser), err
})
}

@ -15,7 +15,7 @@ import (
) )
func TestGQL(t *testing.T) { func TestGQL(t *testing.T) {
ctx := logTestContext(t, []string{"node", "test", "gql", "policy"}) ctx := logTestContext(t, []string{})
TestNodeType := NodeType("TEST") TestNodeType := NodeType("TEST")
err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType}) err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType})
@ -25,15 +25,16 @@ func TestGQL(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
listener_ext := NewListenerExt(10) listener_ext := NewListenerExt(10)
policy := NewAllNodesPolicy(Actions{MakeAction("+")}) policy := NewAllNodesPolicy(Actions{MakeAction("+")})
start_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"}
gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{ gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{
QueuedSignal{uuid.New(), StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"}, time.Now()}, QueuedSignal{uuid.New(), &start_signal, time.Now()},
}, NewLockableExt(), NewACLExt(policy), gql_ext, NewGroupExt(nil), listener_ext) }, NewLockableExt(), NewACLExt(&policy), gql_ext, NewGroupExt(nil), listener_ext)
n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy)) n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(&policy))
err = LinkRequirement(ctx, gql.ID, n1.ID) err = LinkRequirement(ctx, gql.ID, n1.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StringSignal) bool { _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
return sig.Str == "server_started" return sig.Str == "server_started"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -85,14 +86,15 @@ func TestGQL(t *testing.T) {
resp_2 := SendGQL(req_2) resp_2 := SendGQL(req_2)
ctx.Log.Logf("test", "RESP_2: %s", resp_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2)
ctx.Send(n1.ID, gql.ID, StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}) stop_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StringSignal) bool { ctx.Send(n1.ID, gql.ID, &stop_signal)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool {
return sig.Str == "server_stopped" return sig.Str == "server_stopped"
}) })
} }
func TestGQLDB(t *testing.T) { func TestGQLDB(t *testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{"listener"})
TestUserNodeType := NodeType("TEST_USER") TestUserNodeType := NodeType("TEST_USER")
err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{})
@ -111,10 +113,10 @@ func TestGQLDB(t *testing.T) {
NewGroupExt(nil)) NewGroupExt(nil))
ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID)
err = ctx.Send(gql.ID, gql.ID, StopSignal) err = ctx.Send(gql.ID, gql.ID, &StopSignal)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool { _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql.ID return sig.Str == "stopped" && sig.NodeID == gql.ID
}) })
fatalErr(t, err) fatalErr(t, err)
@ -130,9 +132,9 @@ func TestGQLDB(t *testing.T) {
fatalErr(t, err) fatalErr(t, err)
listener_ext, err = GetExt[*ListenerExt](gql_loaded) listener_ext, err = GetExt[*ListenerExt](gql_loaded)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal) err = ctx.Send(gql_loaded.ID, gql_loaded.ID, &StopSignal)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool { _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID
}) })
fatalErr(t, err) fatalErr(t, err)

@ -35,29 +35,3 @@ func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init fu
init(&gql) init(&gql)
return &gql return &gql
} }
var TypeSignal = NewSingleton(func() *graphql.Object {
gql_type_signal := graphql.NewObject(graphql.ObjectConfig{
Name: "Signal",
IsTypeOf: func(p graphql.IsTypeOfParams) bool {
_, ok := p.Value.(Signal)
return ok
},
Fields: graphql.Fields{},
})
gql_type_signal.AddFieldConfig("Type", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalType,
})
gql_type_signal.AddFieldConfig("Direction", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalDirection,
})
gql_type_signal.AddFieldConfig("String", &graphql.Field{
Type: graphql.String,
Resolve: GQLSignalString,
})
return gql_type_signal
}, nil)

@ -17,7 +17,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) {
10, 10,
nil, nil,
listener_extension, listener_extension,
NewACLExt(policy), NewACLExt(&policy),
NewLockableExt()) NewLockableExt())
return listener, listener_extension return listener, listener_extension

@ -2,7 +2,6 @@ package graphvent
import ( import (
"encoding/json" "encoding/json"
"fmt"
) )
// A Listener extension provides a channel that can receive signals on a different thread // A Listener extension provides a channel that can receive signals on a different thread
@ -31,14 +30,10 @@ func (ext *ListenerExt) Field(name string) interface{} {
} }
// Simple load function, unmarshal the buffer int from json // Simple load function, unmarshal the buffer int from json
func LoadListenerExt(ctx *Context, data []byte) (Extension, error) { func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error {
var j int err := json.Unmarshal(data, &ext.Buffer)
err := json.Unmarshal(data, &j) ext.Chan = make(chan Signal, ext.Buffer)
if err != nil { return err
return nil, err
}
return NewListenerExt(j), nil
} }
func (listener *ListenerExt) Type() ExtType { func (listener *ListenerExt) Type() ExtType {
@ -47,7 +42,7 @@ func (listener *ListenerExt) Type() ExtType {
// Send the signal to the channel, logging an overflow if it occurs // Send the signal to the channel, logging an overflow if it occurs
func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {
ctx.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal) ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal)
select { select {
case ext.Chan <- signal: case ext.Chan <- signal:
default: default:
@ -116,18 +111,8 @@ func (ext *LockableExt) Field(name string) interface{} {
}) })
} }
func LoadLockableExt(ctx *Context, data []byte) (Extension, error) {
var ext LockableExt
err := json.Unmarshal(data, &ext)
if err != nil {
return nil, err
}
return &ext, nil
}
func (ext *ListenerExt) Serialize() ([]byte, error) { func (ext *ListenerExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext.Buffer, "", " ") return json.Marshal(ext.Buffer)
} }
func (ext *LockableExt) Type() ExtType { func (ext *LockableExt) Type() ExtType {
@ -135,7 +120,11 @@ func (ext *LockableExt) Type() ExtType {
} }
func (ext *LockableExt) Serialize() ([]byte, error) { func (ext *LockableExt) Serialize() ([]byte, error) {
return json.MarshalIndent(ext, "", " ") return json.Marshal(ext)
}
func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, ext)
} }
func NewLockableExt() *LockableExt { func NewLockableExt() *LockableExt {
@ -149,36 +138,43 @@ func NewLockableExt() *LockableExt {
// Send the signal to unlock a node from itself // Send the signal to unlock a node from itself
func UnlockLockable(ctx *Context, node *Node) error { func UnlockLockable(ctx *Context, node *Node) error {
return ctx.Send(node.ID, node.ID, NewLockSignal("unlock")) lock_signal := NewLockSignal("unlock")
return ctx.Send(node.ID, node.ID, &lock_signal)
} }
// Send the signal to lock a node from itself // Send the signal to lock a node from itself
func LockLockable(ctx *Context, node *Node) error { func LockLockable(ctx *Context, node *Node) error {
return ctx.Send(node.ID, node.ID, NewLockSignal("lock")) lock_signal := NewLockSignal("lock")
return ctx.Send(node.ID, node.ID, &lock_signal)
} }
// Setup a node to send the initial requirement link signal, then send the signal // Setup a node to send the initial requirement link signal, then send the signal
func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error { func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error {
return ctx.Send(dependency, dependency, NewLinkStartSignal("req", requirement)) start_signal := NewLinkStartSignal("req", requirement)
return ctx.Send(dependency, dependency, &start_signal)
} }
// Handle a LockSignal and update the extensions owner/requirement states // Handle a LockSignal and update the extensions owner/requirement states
func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) { func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str state := signal.Str
switch state { switch state {
case "unlock": case "unlock":
if ext.Owner == nil { if ext.Owner == nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocked"))) resp := NewErrorSignal(signal.ID(), "already_unlocked")
ctx.Send(node.ID, source, &resp)
} else if source != *ext.Owner { } else if source != *ext.Owner {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_owner"))) resp := NewErrorSignal(signal.ID(), "not_owner")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner == nil { } else if ext.PendingOwner == nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocking"))) resp := NewErrorSignal(signal.ID(), "already_unlocking")
ctx.Send(node.ID, source, &resp)
} else { } else {
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = nil ext.Owner = nil
ext.PendingOwner = nil ext.PendingOwner = nil
ctx.Send(node.ID, source, NewLockSignal("unlocked")) resp := NewLockSignal("unlocked")
ctx.Send(node.ID, source, &resp)
} else { } else {
ext.PendingOwner = nil ext.PendingOwner = nil
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
@ -188,22 +184,27 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
} }
state.Lock = "unlocking" state.Lock = "unlocking"
ext.Requirements[id] = state ext.Requirements[id] = state
ctx.Send(node.ID, id, NewLockSignal("unlock")) resp := NewLockSignal("unlock")
ctx.Send(node.ID, id, &resp)
} }
} }
if source != node.ID { if source != node.ID {
ctx.Send(node.ID, source, NewLockSignal("unlocking")) resp := NewLockSignal("unlocking")
ctx.Send(node.ID, source, &resp)
} }
} }
} }
case "unlocking": case "unlocking":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked"))) resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "unlocking" { } else if state.Lock != "unlocking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking"))) resp := NewErrorSignal(signal.ID(), "not_unlocking")
ctx.Send(node.ID, source, &resp)
} }
case "unlocked": case "unlocked":
@ -213,11 +214,14 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked"))) resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "unlocking" { } else if state.Lock != "unlocking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking"))) resp := NewErrorSignal(signal.ID(), "not_unlocking")
ctx.Send(node.ID, source, &resp)
} else { } else {
state.Lock = "unlocked" state.Lock = "unlocked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -237,7 +241,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == unlocked { if linked == unlocked {
previous_owner := *ext.Owner previous_owner := *ext.Owner
ext.Owner = nil ext.Owner = nil
ctx.Send(node.ID, previous_owner, NewLockSignal("unlocked")) resp := NewLockSignal("unlocked")
ctx.Send(node.ID, previous_owner, &resp)
} }
} }
} }
@ -248,11 +253,14 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked"))) resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "locking" { } else if state.Lock != "locking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking"))) resp := NewErrorSignal(signal.ID(), "not_locking")
ctx.Send(node.ID, source, &resp)
} else { } else {
state.Lock = "locked" state.Lock = "locked"
ext.Requirements[source] = state ext.Requirements[source] = state
@ -271,31 +279,38 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
if linked == locked { if linked == locked {
ext.Owner = ext.PendingOwner ext.Owner = ext.PendingOwner
ctx.Send(node.ID, *ext.Owner, NewLockSignal("locked")) resp := NewLockSignal("locked")
ctx.Send(node.ID, *ext.Owner, &resp)
} }
} }
} }
case "locking": case "locking":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == false { if exists == false {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) resp := NewErrorSignal(signal.ID(), "not_requirement")
ctx.Send(node.ID, source, &resp)
} else if state.Link != "linked" { } else if state.Link != "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked"))) resp := NewErrorSignal(signal.ID(), "not_linked")
ctx.Send(node.ID, source, &resp)
} else if state.Lock != "locking" { } else if state.Lock != "locking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking"))) resp := NewErrorSignal(signal.ID(), "not_locking")
ctx.Send(node.ID, source, &resp)
} }
case "lock": case "lock":
if ext.Owner != nil { if ext.Owner != nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locked"))) resp := NewErrorSignal(signal.ID(), "already_locked")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != nil { } else if ext.PendingOwner != nil {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locking"))) resp := NewErrorSignal(signal.ID(), "already_locking")
ctx.Send(node.ID, source, &resp)
} else { } else {
owner := source owner := source
if len(ext.Requirements) == 0 { if len(ext.Requirements) == 0 {
ext.Owner = &owner ext.Owner = &owner
ext.PendingOwner = ext.Owner ext.PendingOwner = ext.Owner
ctx.Send(node.ID, source, NewLockSignal("locked")) resp := NewLockSignal("locked")
ctx.Send(node.ID, source, &resp)
} else { } else {
ext.PendingOwner = &owner ext.PendingOwner = &owner
for id, state := range(ext.Requirements) { for id, state := range(ext.Requirements) {
@ -305,11 +320,13 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
} }
state.Lock = "locking" state.Lock = "locking"
ext.Requirements[id] = state ext.Requirements[id] = state
ctx.Send(node.ID, id, NewLockSignal("lock")) sub := NewLockSignal("lock")
ctx.Send(node.ID, id, &sub)
} }
} }
if source != node.ID { if source != node.ID {
ctx.Send(node.ID, source, NewLockSignal("locking")) resp := NewLockSignal("locking")
ctx.Send(node.ID, source, &resp)
} }
} }
} }
@ -318,7 +335,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node
} }
} }
func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal IDStringSignal) { func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal *IDStringSignal) {
ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal) ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal)
link_type := signal.Str link_type := signal.Str
target := signal.NodeID target := signal.NodeID
@ -327,72 +344,90 @@ func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node
state, exists := ext.Requirements[target] state, exists := ext.Requirements[target]
_, dep_exists := ext.Dependencies[target] _, dep_exists := ext.Dependencies[target]
if ext.Owner != nil { if ext.Owner != nil {
ctx.Send(node.ID, source, NewLinkStartSignal("locked", target)) resp := NewLinkStartSignal("locked", target)
ctx.Send(node.ID, source, &resp)
} else if ext.Owner != ext.PendingOwner { } else if ext.Owner != ext.PendingOwner {
if ext.PendingOwner == nil { if ext.PendingOwner == nil {
ctx.Send(node.ID, source, NewLinkStartSignal("unlocking", target)) resp := NewLinkStartSignal("unlocking", target)
ctx.Send(node.ID, source, &resp)
} else { } else {
ctx.Send(node.ID, source, NewLinkStartSignal("locking", target)) resp := NewLinkStartSignal("locking", target)
ctx.Send(node.ID, source, &resp)
} }
} else if exists == true { } else if exists == true {
if state.Link == "linking" { if state.Link == "linking" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_linking_req"))) resp := NewErrorSignal(signal.ID(), "already_linking_req")
ctx.Send(node.ID, source, &resp)
} else if state.Link == "linked" { } else if state.Link == "linked" {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_req"))) resp := NewErrorSignal(signal.ID(), "already_req")
ctx.Send(node.ID, source, &resp)
} }
} else if dep_exists == true { } else if dep_exists == true {
ctx.Send(node.ID, source, NewLinkStartSignal("already_dep", target)) resp := NewLinkStartSignal("already_dep", target)
ctx.Send(node.ID, source, &resp)
} else { } else {
ext.Requirements[target] = LinkState{"linking", "unlocked", source} ext.Requirements[target] = LinkState{"linking", "unlocked", source}
ctx.Send(node.ID, target, NewLinkSignal("linked_as_req")) resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, NewLinkStartSignal("linking_req", target)) ctx.Send(node.ID, target, &resp)
notify := NewLinkStartSignal("linking_req", target)
ctx.Send(node.ID, source, &notify)
} }
} }
} }
// Handle LinkSignal, updating the extensions requirements and dependencies as necessary // Handle LinkSignal, updating the extensions requirements and dependencies as necessary
// TODO: Add unlink // TODO: Add unlink
func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) { func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) {
ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal)
state := signal.Str state := signal.Str
switch state { switch state {
case "linked_as_dep": case "linked_as_dep":
state, exists := ext.Requirements[source] state, exists := ext.Requirements[source]
if exists == true && state.Link == "linked" { if exists == true && state.Link == "linked" {
ctx.Send(node.ID, state.Initiator, NewLinkStartSignal("linked_as_req", source)) resp := NewLinkStartSignal("linked_as_req", source)
ctx.Send(node.ID, state.Initiator, &resp)
} else if state.Link == "linking" { } else if state.Link == "linking" {
state.Link = "linked" state.Link = "linked"
ext.Requirements[source] = state ext.Requirements[source] = state
ctx.Send(node.ID, source, NewLinkSignal("linked_as_req")) resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner { } else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil { if ext.Owner == nil {
ctx.Send(node.ID, source, NewLinkSignal("locking")) resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
} else { } else {
ctx.Send(node.ID, source, NewLinkSignal("unlocking")) resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
} }
} else { } else {
ext.Requirements[source] = LinkState{"linking", "unlocked", source} ext.Requirements[source] = LinkState{"linking", "unlocked", source}
ctx.Send(node.ID, source, NewLinkSignal("linked_as_req")) resp := NewLinkSignal("linked_as_req")
ctx.Send(node.ID, source, &resp)
} }
ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source) ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source)
case "linked_as_req": case "linked_as_req":
state, exists := ext.Dependencies[source] state, exists := ext.Dependencies[source]
if exists == true && state.Link == "linked" { if exists == true && state.Link == "linked" {
ctx.Send(node.ID, state.Initiator, NewLinkStartSignal("linked_as_dep", source)) resp := NewLinkStartSignal("linked_as_dep", source)
ctx.Send(node.ID, state.Initiator, &resp)
} else if state.Link == "linking" { } else if state.Link == "linking" {
state.Link = "linked" state.Link = "linked"
ext.Dependencies[source] = state ext.Dependencies[source] = state
ctx.Send(node.ID, source, NewLinkSignal("linked_as_dep")) resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
} else if ext.PendingOwner != ext.Owner { } else if ext.PendingOwner != ext.Owner {
if ext.Owner == nil { if ext.Owner == nil {
ctx.Send(node.ID, source, NewLinkSignal("locking")) resp := NewLinkSignal("locking")
ctx.Send(node.ID, source, &resp)
} else { } else {
ctx.Send(node.ID, source, NewLinkSignal("unlocking")) resp := NewLinkSignal("unlocking")
ctx.Send(node.ID, source, &resp)
} }
} else { } else {
ext.Dependencies[source] = LinkState{"linking", "unlocked", source} ext.Dependencies[source] = LinkState{"linking", "unlocked", source}
ctx.Send(node.ID, source, NewLinkSignal("linked_as_dep")) resp := NewLinkSignal("linked_as_dep")
ctx.Send(node.ID, source, &resp)
} }
ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source) ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source)
@ -442,11 +477,11 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal
case Direct: case Direct:
switch signal.Type() { switch signal.Type() {
case LinkSignalType: case LinkSignalType:
ext.HandleLinkSignal(ctx, source, node, signal.(StringSignal)) ext.HandleLinkSignal(ctx, source, node, signal.(*StringSignal))
case LockSignalType: case LockSignalType:
ext.HandleLockSignal(ctx, source, node, signal.(StringSignal)) ext.HandleLockSignal(ctx, source, node, signal.(*StringSignal))
case LinkStartSignalType: case LinkStartSignalType:
ext.HandleLinkStartSignal(ctx, source, node, signal.(IDStringSignal)) ext.HandleLinkStartSignal(ctx, source, node, signal.(*IDStringSignal))
default: default:
} }
default: default:

@ -26,13 +26,13 @@ func TestLink(t *testing.T) {
l1_listener := NewListenerExt(10) l1_listener := NewListenerExt(10)
l1 := NewNode(ctx, nil, TestLockableType, 10, nil, l1 := NewNode(ctx, nil, TestLockableType, 10, nil,
l1_listener, l1_listener,
NewACLExt(link_policy), NewACLExt(&link_policy),
NewLockableExt(), NewLockableExt(),
) )
l2_listener := NewListenerExt(10) l2_listener := NewListenerExt(10)
l2 := NewNode(ctx, nil, TestLockableType, 10, nil, l2 := NewNode(ctx, nil, TestLockableType, 10, nil,
l2_listener, l2_listener,
NewACLExt(link_policy), NewACLExt(&link_policy),
NewLockableExt(), NewLockableExt(),
) )
@ -40,20 +40,21 @@ func TestLink(t *testing.T) {
err := LinkRequirement(ctx, l1.ID, l2.ID) err := LinkRequirement(ctx, l1.ID, l2.ID)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig IDStringSignal) bool { _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req" return sig.Str == "linked_as_req"
}) })
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(l2.ID, l2.ID, NewStatusSignal("TEST", l2.ID)) sig1 := NewStatusSignal("TEST", l2.ID)
err = ctx.Send(l2.ID, l2.ID, &sig1)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig IDStringSignal) bool { _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "TEST" return sig.Str == "TEST"
}) })
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig IDStringSignal) bool { _, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "TEST" return sig.Str == "TEST"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -64,7 +65,7 @@ func TestLink10K(t *testing.T) {
NewLockable := func()(*Node) { NewLockable := func()(*Node) {
l := NewNode(ctx, nil, TestLockableType, 10, nil, l := NewNode(ctx, nil, TestLockableType, 10, nil,
NewACLExt(lock_policy, link_policy), NewACLExt(&lock_policy, &link_policy),
NewLockableExt(), NewLockableExt(),
) )
return l return l
@ -74,7 +75,7 @@ func TestLink10K(t *testing.T) {
listener := NewListenerExt(100000) listener := NewListenerExt(100000)
l := NewNode(ctx, nil, TestLockableType, 256, nil, l := NewNode(ctx, nil, TestLockableType, 256, nil,
listener, listener,
NewACLExt(lock_policy, link_policy), NewACLExt(&lock_policy, &link_policy),
NewLockableExt(), NewLockableExt(),
) )
return l, listener return l, listener
@ -91,7 +92,7 @@ func TestLink10K(t *testing.T) {
for range(lockables) { for range(lockables) {
_, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig IDStringSignal) bool { _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req" return sig.Str == "linked_as_req"
}) })
fatalErr(t, err) fatalErr(t, err)
@ -107,7 +108,7 @@ func TestLock(t *testing.T) {
listener := NewListenerExt(100) listener := NewListenerExt(100)
l := NewNode(ctx, nil, TestLockableType, 10, nil, l := NewNode(ctx, nil, TestLockableType, 10, nil,
listener, listener,
NewACLExt(lock_policy, link_policy), NewACLExt(&lock_policy, &link_policy),
NewLockableExt(), NewLockableExt(),
) )
return l, listener return l, listener
@ -140,11 +141,11 @@ func TestLock(t *testing.T) {
err = LinkRequirement(ctx, l0.ID, l5.ID) err = LinkRequirement(ctx, l0.ID, l5.ID)
fatalErr(t, err) fatalErr(t, err)
linked_as_req := func(sig IDStringSignal) bool { linked_as_req := func(sig *IDStringSignal) bool {
return sig.Str == "linked_as_req" return sig.Str == "linked_as_req"
} }
locked := func(sig StringSignal) bool { locked := func(sig *StringSignal) bool {
return sig.Str == "locked" return sig.Str == "locked"
} }

@ -84,8 +84,9 @@ func RandID() NodeID {
// A Serializable has a type that can be used to map to it, and a function to serialize the current state // A Serializable has a type that can be used to map to it, and a function to serialize the current state
type Serializable[I comparable] interface { type Serializable[I comparable] interface {
Type() I
Serialize()([]byte,error) Serialize()([]byte,error)
Deserialize(*Context,[]byte)error
Type() I
} }
// Extensions are data attached to nodes that process signals // Extensions are data attached to nodes that process signals
@ -197,7 +198,7 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
// Queue the signal for extensions to perform startup actions // Queue the signal for extensions to perform startup actions
node.QueueSignal(time.Now(), NewDirectSignal(StartSignalType)) node.QueueSignal(time.Now(), &StartSignal)
for true { for true {
var signal Signal var signal Signal
@ -209,7 +210,8 @@ func nodeLoop(ctx *Context, node *Node) error {
err := Allowed(ctx, msg.Source, signal.Permission(), node) err := Allowed(ctx, msg.Source, signal.Permission(), node)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err)
ctx.Send(node.ID, msg.Source, NewErrorSignal(msg.Signal.ID(), err)) resp := NewErrorSignal(msg.Signal.ID(), err.Error())
ctx.Send(node.ID, msg.Source, &resp)
continue continue
} }
case <-node.TimeoutChan: case <-node.TimeoutChan:
@ -240,7 +242,7 @@ func nodeLoop(ctx *Context, node *Node) error {
// Unwrap Authorized Signals // Unwrap Authorized Signals
if signal.Type() == AuthorizedSignalType { if signal.Type() == AuthorizedSignalType {
sig, ok := signal.(AuthorizedSignal) sig, ok := signal.(*AuthorizedSignal)
if ok == false { if ok == false {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal)) ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal))
} else { } else {
@ -254,14 +256,16 @@ func nodeLoop(ctx *Context, node *Node) error {
err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node)
if err != nil { if err != nil {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err) ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err)
ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), err)) resp := NewErrorSignal(sig.ID(), err.Error())
ctx.Send(node.ID, source, &resp)
} else { } else {
// Unwrap the signal without changing the source // Unwrap the signal without changing the source
signal = sig.Signal signal = sig.Signal
} }
} else { } else {
ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate") ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate")
ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), fmt.Errorf("failed to validate signature"))) resp := NewErrorSignal(sig.ID(), "signature validation failed")
ctx.Send(node.ID, source, &resp)
} }
} }
} }
@ -269,16 +273,19 @@ func nodeLoop(ctx *Context, node *Node) error {
// Handle special signal types // Handle special signal types
if signal.Type() == StopSignalType { if signal.Type() == StopSignalType {
ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil)) resp := NewErrorSignal(signal.ID(), "stopped")
node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) ctx.Send(node.ID, source, &resp)
status := NewStatusSignal("stopped", node.ID)
node.Process(ctx, node.ID, &status)
break break
} else if signal.Type() == ReadSignalType { } else if signal.Type() == ReadSignalType {
read_signal, ok := signal.(ReadSignal) read_signal, ok := signal.(*ReadSignal)
if ok == false { if ok == false {
ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal)
} else { } else {
result := ReadNodeFields(ctx, node, source, read_signal.Extensions) result := ReadNodeFields(ctx, node, source, read_signal.Extensions)
ctx.Send(node.ID, source, NewReadResultSignal(read_signal.ID(), node.Type, result)) resp := NewReadResultSignal(read_signal.ID(), node.Type, result)
ctx.Send(node.ID, source, &resp)
} }
} }

@ -41,18 +41,19 @@ func TestNodeRead(t *testing.T) {
n1_id: Actions{MakeAction(ReadResultSignalType, "+")}, n1_id: Actions{MakeAction(ReadResultSignalType, "+")},
}) })
n2_listener := NewListenerExt(10) n2_listener := NewListenerExt(10)
n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener) n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(&n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener)
n1_policy := NewPerNodePolicy(map[NodeID]Actions{ n1_policy := NewPerNodePolicy(map[NodeID]Actions{
n2_id: Actions{MakeAction(ReadSignalType, "+")}, n2_id: Actions{MakeAction(ReadSignalType, "+")},
}) })
n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil), NewECDHExt()) n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(&n1_policy), NewGroupExt(nil), NewECDHExt())
ctx.Send(n2.ID, n1.ID, NewReadSignal(map[ExtType][]string{ read_sig := NewReadSignal(map[ExtType][]string{
GroupExtType: []string{"members"}, GroupExtType: []string{"members"},
})) })
ctx.Send(n2.ID, n1.ID, &read_sig)
res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig ReadResultSignal) bool { res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool {
return true return true
}) })
fatalErr(t, err) fatalErr(t, err)
@ -68,11 +69,11 @@ func TestECDH(t *testing.T) {
n1_listener := NewListenerExt(10) n1_listener := NewListenerExt(10)
ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")}) ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")})
n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener) n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt(), n1_listener)
n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt()) n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt())
n3_listener := NewListenerExt(10) n3_listener := NewListenerExt(10)
n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}}) n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}})
n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy, n3_policy), NewECDHExt(), n3_listener) n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy, &n3_policy), NewECDHExt(), n3_listener)
ctx.Log.Logf("test", "N1: %s", n1.ID) ctx.Log.Logf("test", "N1: %s", n1.ID)
ctx.Log.Logf("test", "N2: %s", n2.ID) ctx.Log.Logf("test", "N2: %s", n2.ID)
@ -87,18 +88,18 @@ func TestECDH(t *testing.T) {
} }
fatalErr(t, err) fatalErr(t, err)
ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec)
err = ctx.Send(n1.ID, n2.ID, ecdh_req) err = ctx.Send(n1.ID, n2.ID, &ecdh_req)
fatalErr(t, err) fatalErr(t, err)
_, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig ECDHSignal) bool { _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool {
return sig.Str == "resp" return sig.Str == "resp"
}) })
fatalErr(t, err) fatalErr(t, err)
time.Sleep(10*time.Millisecond) time.Sleep(10*time.Millisecond)
ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, NewDirectSignal(StopSignalType), ecdh_ext.ECDHStates[n2.ID].SharedSecret) ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret)
fatalErr(t, err) fatalErr(t, err)
err = ctx.Send(n1.ID, n2.ID, ecdh_sig) err = ctx.Send(n1.ID, n2.ID, &ecdh_sig)
fatalErr(t, err) fatalErr(t, err)
} }

@ -20,6 +20,8 @@ type Policy interface {
Allows(principal_id NodeID, action Action, node *Node) error Allows(principal_id NodeID, action Action, node *Node) error
// Merge with another policy of the same underlying type // Merge with another policy of the same underlying type
Merge(Policy) Policy Merge(Policy) Policy
// Make a copy of this policy
Copy() Policy
} }
func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node) error { func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node) error {
@ -72,6 +74,14 @@ func MergeActions(first Actions, second Actions) Actions {
return ret return ret
} }
func CopyNodeActions(actions NodeActions) NodeActions {
ret := NodeActions{}
for id, a := range(actions) {
ret[id] = a
}
return ret
}
func MergeNodeActions(modified NodeActions, read NodeActions) { func MergeNodeActions(modified NodeActions, read NodeActions) {
for id, actions := range(read) { for id, actions := range(read) {
existing, exists := modified[id] existing, exists := modified[id]
@ -83,24 +93,47 @@ func MergeNodeActions(modified NodeActions, read NodeActions) {
} }
} }
func (policy PerNodePolicy) Merge(p Policy) Policy { func (policy *PerNodePolicy) Merge(p Policy) Policy {
other := p.(PerNodePolicy) other := p.(*PerNodePolicy)
MergeNodeActions(policy.NodeActions, other.NodeActions) MergeNodeActions(policy.NodeActions, other.NodeActions)
return policy return policy
} }
func (policy AllNodesPolicy) Merge(p Policy) Policy { func (policy *PerNodePolicy) Copy() Policy {
other := p.(AllNodesPolicy) new_actions := CopyNodeActions(policy.NodeActions)
return &PerNodePolicy{
NodeActions: new_actions,
}
}
func (policy *AllNodesPolicy) Merge(p Policy) Policy {
other := p.(*AllNodesPolicy)
policy.Actions = MergeActions(policy.Actions, other.Actions) policy.Actions = MergeActions(policy.Actions, other.Actions)
return policy return policy
} }
func (policy RequirementOfPolicy) Merge(p Policy) Policy { func (policy *AllNodesPolicy) Copy() Policy {
other := p.(RequirementOfPolicy) new_actions := policy.Actions
return &AllNodesPolicy {
Actions: new_actions,
}
}
func (policy *RequirementOfPolicy) Merge(p Policy) Policy {
other := p.(*RequirementOfPolicy)
policy.Actions = MergeActions(policy.Actions, other.Actions) policy.Actions = MergeActions(policy.Actions, other.Actions)
return policy return policy
} }
func (policy *RequirementOfPolicy) Copy() Policy {
new_actions := policy.Actions
return &RequirementOfPolicy{
AllNodesPolicy {
Actions: new_actions,
},
}
}
type Action []string type Action []string
func MakeAction(parts ...interface{}) Action { func MakeAction(parts ...interface{}) Action {
@ -178,36 +211,6 @@ func (actions *NodeActions) UnmarshalJSON(data []byte) error {
return nil return nil
} }
type AllNodesPolicyJSON struct {
Actions Actions `json:"actions"`
}
func AllNodesPolicyLoad(init_fn func(Actions)(Policy, error)) func(*Context, []byte)(Policy, error) {
return func(ctx *Context, data []byte)(Policy, error){
var j AllNodesPolicyJSON
err := json.Unmarshal(data, &j)
if err != nil {
return nil, err
}
return init_fn(j.Actions)
}
}
func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) {
return func(ctx *Context, data []byte)(Policy, error){
policy := PerNodePolicy{
NodeActions: NodeActions{},
}
err := json.Unmarshal(data, &policy)
if err != nil {
return nil, err
}
return init_fn(policy.NodeActions)
}
}
func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy {
if node_actions == nil { if node_actions == nil {
node_actions = NodeActions{} node_actions = NodeActions{}
@ -222,14 +225,18 @@ type PerNodePolicy struct {
NodeActions NodeActions `json:"node_actions"` NodeActions NodeActions `json:"node_actions"`
} }
func (policy PerNodePolicy) Type() PolicyType { func (policy *PerNodePolicy) Type() PolicyType {
return PerNodePolicyType return PerNodePolicyType
} }
func (policy PerNodePolicy) Serialize() ([]byte, error) { func (policy *PerNodePolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ") return json.MarshalIndent(policy, "", " ")
} }
func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
func NewAllNodesPolicy(actions Actions) AllNodesPolicy { func NewAllNodesPolicy(actions Actions) AllNodesPolicy {
if actions == nil { if actions == nil {
actions = Actions{} actions = Actions{}
@ -244,14 +251,18 @@ type AllNodesPolicy struct {
Actions Actions Actions Actions
} }
func (policy AllNodesPolicy) Type() PolicyType { func (policy *AllNodesPolicy) Type() PolicyType {
return AllNodesPolicyType return AllNodesPolicyType
} }
func (policy AllNodesPolicy) Serialize() ([]byte, error) { func (policy *AllNodesPolicy) Serialize() ([]byte, error) {
return json.MarshalIndent(policy, "", " ") return json.MarshalIndent(policy, "", " ")
} }
func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, policy)
}
// Extension to allow a node to hold ACL policies // Extension to allow a node to hold ACL policies
type ACLExt struct { type ACLExt struct {
Policies map[PolicyType]Policy Policies map[PolicyType]Policy
@ -267,35 +278,16 @@ func NodeList(nodes ...*Node) NodeMap {
} }
type PolicyLoadFunc func(*Context,[]byte) (Policy, error) type PolicyLoadFunc func(*Context,[]byte) (Policy, error)
type PolicyInfo struct {
Load PolicyLoadFunc
}
type ACLExtContext struct { type ACLExtContext struct {
Types map[PolicyType]PolicyInfo Loads map[PolicyType]PolicyLoadFunc
} }
func NewACLExtContext() *ACLExtContext { func NewACLExtContext() *ACLExtContext {
return &ACLExtContext{ return &ACLExtContext{
Types: map[PolicyType]PolicyInfo{ Loads: map[PolicyType]PolicyLoadFunc{
AllNodesPolicyType: PolicyInfo{ AllNodesPolicyType: LoadPolicy[AllNodesPolicy,*AllNodesPolicy],
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ PerNodePolicyType: LoadPolicy[PerNodePolicy,*PerNodePolicy],
policy := NewAllNodesPolicy(actions) RequirementOfPolicyType: LoadPolicy[RequirementOfPolicy,*RequirementOfPolicy],
return &policy, nil
}),
},
PerNodePolicyType: PolicyInfo{
Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){
policy := NewPerNodePolicy(nodes)
return &policy, nil
}),
},
RequirementOfPolicyType: PolicyInfo{
Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){
policy := NewRequirementOfPolicy(actions)
return &policy, nil
}),
},
}, },
} }
} }
@ -331,13 +323,15 @@ func (ext *ACLExt) Field(name string) interface{} {
var ErrorSignalAction = Action{"ERROR_RESP"} var ErrorSignalAction = Action{"ERROR_RESP"}
var ReadResultSignalAction = Action{"READ_RESULT"} var ReadResultSignalAction = Action{"READ_RESULT"}
var AuthorizedSignalAction = Action{"AUTHORIZED_READ"} var AuthorizedSignalAction = Action{"AUTHORIZED_READ"}
var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction})
var DefaultACLPolicies = []Policy{ var DefaultACLPolicies = []Policy{
NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}), &defaultPolicy,
} }
func NewACLExt(policies ...Policy) *ACLExt { func NewACLExt(policies ...Policy) *ACLExt {
policy_map := map[PolicyType]Policy{} policy_map := map[PolicyType]Policy{}
for _, policy := range(append(policies, DefaultACLPolicies...)) { for _, policy_arg := range(append(policies, DefaultACLPolicies...)) {
policy := policy_arg.Copy()
existing, exists := policy_map[policy.Type()] existing, exists := policy_map[policy.Type()]
if exists == true { if exists == true {
policy = existing.Merge(policy) policy = existing.Merge(policy)
@ -351,36 +345,49 @@ func NewACLExt(policies ...Policy) *ACLExt {
} }
} }
func LoadACLExt(ctx *Context, data []byte) (Extension, error) { func LoadPolicy[T any, P interface {
*T
Policy
}](ctx *Context, data []byte) (Policy, error) {
p := P(new(T))
err := p.Deserialize(ctx, data)
if err != nil {
return nil, err
}
return p, nil
}
func (ext *ACLExt) Deserialize(ctx *Context, data []byte) error {
var j struct { var j struct {
Policies map[string][]byte `json:"policies"` Policies map[string][]byte `json:"policies"`
} }
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
if err != nil { if err != nil {
return nil, err return err
} }
policies := make([]Policy, len(j.Policies))
i := 0
acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx) acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx)
if err != nil { if err != nil {
return nil, err return err
} }
ext.Policies = map[PolicyType]Policy{}
for name, ser := range(j.Policies) { for name, ser := range(j.Policies) {
policy_def, exists := acl_ctx.Types[PolicyType(name)] policy_load, exists := acl_ctx.Loads[PolicyType(name)]
if exists == false { if exists == false {
return nil, fmt.Errorf("%s is not a known policy type", name) return fmt.Errorf("%s is not a known policy type", name)
} }
policy, err := policy_def.Load(ctx, ser) policy, err := policy_load(ctx, ser)
if err != nil { if err != nil {
return nil, err return err
} }
policies[i] = policy ext.Policies[PolicyType(name)] = policy
i++
} }
return NewACLExt(policies...), nil return nil
} }
func (ext *ACLExt) Type() ExtType { func (ext *ACLExt) Type() ExtType {

@ -95,23 +95,27 @@ type BaseSignal struct {
UUID uuid.UUID `json:"id"` UUID uuid.UUID `json:"id"`
} }
func (signal BaseSignal) ID() uuid.UUID { func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error {
return json.Unmarshal(data, signal)
}
func (signal *BaseSignal) ID() uuid.UUID {
return signal.UUID return signal.UUID
} }
func (signal BaseSignal) Type() SignalType { func (signal *BaseSignal) Type() SignalType {
return signal.SignalType return signal.SignalType
} }
func (signal BaseSignal) Permission() Action { func (signal *BaseSignal) Permission() Action {
return MakeAction(signal.Type()) return MakeAction(signal.Type())
} }
func (signal BaseSignal) Direction() SignalDirection { func (signal *BaseSignal) Direction() SignalDirection {
return signal.SignalDirection return signal.SignalDirection
} }
func (signal BaseSignal) Serialize() ([]byte, error) { func (signal *BaseSignal) Serialize() ([]byte, error) {
return json.Marshal(signal) return json.Marshal(signal)
} }
@ -136,43 +140,16 @@ func NewDirectSignal(signal_type SignalType) BaseSignal {
return NewBaseSignal(signal_type, Direct) return NewBaseSignal(signal_type, Direct)
} }
var StartSignal = NewDirectSignal(StartSignalType)
var StopSignal = NewDownSignal(StopSignalType) var StopSignal = NewDownSignal(StopSignalType)
type ErrorSignal struct {
BaseSignal
Error error `json:"error"`
}
func (signal ErrorSignal) Permission() Action {
return ErrorSignalAction
}
func NewErrorSignal(req_id uuid.UUID, err error) ErrorSignal {
return ErrorSignal{
BaseSignal: BaseSignal{
Direct,
ErrorSignalType,
req_id,
},
Error: err,
}
}
type IDSignal struct { type IDSignal struct {
BaseSignal BaseSignal
NodeID `json:"id"` NodeID `json:"id"`
} }
func (signal IDSignal) Serialize() ([]byte, error) { func (signal *IDSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(signal)
}
func (signal IDSignal) String() string {
ser, err := json.Marshal(signal)
if err != nil {
return "STATE_SER_ERR"
}
return string(ser)
} }
func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal { func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal {
@ -187,21 +164,38 @@ type StringSignal struct {
Str string `json:"state"` Str string `json:"state"`
} }
func (signal StringSignal) Serialize() ([]byte, error) { func (signal *StringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(&signal)
} }
type ErrorSignal struct {
StringSignal
}
func (signal *ErrorSignal) Permission() Action {
return ErrorSignalAction
}
func NewErrorSignal(req_id uuid.UUID, err string) ErrorSignal {
return ErrorSignal{
StringSignal{
NewDirectSignal(ErrorSignalType),
err,
},
}
}
type IDStringSignal struct { type IDStringSignal struct {
BaseSignal BaseSignal
NodeID `json:"node_id"` NodeID `json:"node_id"`
Str string `json:"string"` Str string `json:"string"`
} }
func (signal IDStringSignal) Serialize() ([]byte, error) { func (signal *IDStringSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(signal)
} }
func (signal IDStringSignal) String() string { func (signal *IDStringSignal) String() string {
ser, err := json.Marshal(signal) ser, err := json.Marshal(signal)
if err != nil { if err != nil {
return "STATE_SER_ERR" return "STATE_SER_ERR"
@ -243,7 +237,7 @@ func NewLockSignal(state string) StringSignal {
} }
} }
func (signal StringSignal) Permission() Action { func (signal *StringSignal) Permission() Action {
return MakeAction(signal.Type(), signal.Str) return MakeAction(signal.Type(), signal.Str)
} }
@ -259,7 +253,7 @@ type AuthorizedSignal struct {
Signature []byte Signature []byte
} }
func (signal AuthorizedSignal) Permission() Action { func (signal *AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction return AuthorizedSignalAction
} }
@ -283,8 +277,8 @@ func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (Authorized
}, nil }, nil
} }
func (signal ReadSignal) Serialize() ([]byte, error) { func (signal *ReadSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(signal)
} }
func NewReadSignal(exts map[ExtType][]string) ReadSignal { func NewReadSignal(exts map[ExtType][]string) ReadSignal {
@ -300,7 +294,7 @@ type ReadResultSignal struct {
Extensions map[ExtType]map[string]interface{} `json:"extensions"` Extensions map[ExtType]map[string]interface{} `json:"extensions"`
} }
func (signal ReadResultSignal) Permission() Action { func (signal *ReadResultSignal) Permission() Action {
return ReadResultSignalAction return ReadResultSignalAction
} }
@ -342,8 +336,8 @@ func (signal *ECDHSignal) MarshalJSON() ([]byte, error) {
}) })
} }
func (signal ECDHSignal) Serialize() ([]byte, error) { func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(&signal) return json.Marshal(signal)
} }
func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) { func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) {
@ -390,7 +384,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e
const DEFAULT_ECDH_WINDOW = time.Second const DEFAULT_ECDH_WINDOW = time.Second
func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []byte, error) { func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) {
now := time.Now() now := time.Now()
err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW)
@ -430,7 +424,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []
}, shared_secret, nil }, shared_secret, nil
} }
func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error { func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) error {
earliest := now.Add(-window) earliest := now.Add(-window)
latest := now.Add(window) latest := now.Add(window)

@ -17,9 +17,9 @@ func (ext *GroupExt) Type() ExtType {
} }
func (ext *GroupExt) Serialize() ([]byte, error) { func (ext *GroupExt) Serialize() ([]byte, error) {
return json.MarshalIndent(&GroupExtJSON{ return json.Marshal(&GroupExtJSON{
Members: IDMap(ext.Members), Members: IDMap(ext.Members),
}, "", " ") })
} }
func (ext *GroupExt) Field(name string) interface{} { func (ext *GroupExt) Field(name string) interface{} {
@ -40,18 +40,12 @@ func NewGroupExt(members map[NodeID]string) *GroupExt {
} }
} }
func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error {
var j GroupExtJSON var j GroupExtJSON
err := json.Unmarshal(data, &j) err := json.Unmarshal(data, &j)
members, err := LoadIDMap(j.Members) ext.Members, err = LoadIDMap(j.Members)
if err != nil { return err
return nil, err
}
return &GroupExt{
Members: members,
}, nil
} }
func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {