From 147f44e5ff0b0b56756ea5b64fc8278ba93710e5 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Tue, 1 Aug 2023 20:55:15 -0600 Subject: [PATCH] hehe --- context.go | 73 ++++++++++++------- ecdh.go | 44 +++++------- gql.go | 31 ++++---- gql_mutation.go | 17 ----- gql_query.go | 10 +-- gql_resolvers.go | 34 --------- gql_test.go | 26 +++---- gql_types.go | 26 ------- graph_test.go | 2 +- lockable.go | 183 ++++++++++++++++++++++++++++------------------- lockable_test.go | 25 +++---- node.go | 27 ++++--- node_test.go | 25 +++---- policy.go | 165 ++++++++++++++++++++++-------------------- signal.go | 90 +++++++++++------------ user.go | 16 ++--- 16 files changed, 387 insertions(+), 407 deletions(-) diff --git a/context.go b/context.go index 9590ded..47b6688 100644 --- a/context.go +++ b/context.go @@ -35,7 +35,19 @@ func (ext ExtType) Prefix() string { return "EXTENSION: " } func (ext ExtType) String() string { return string(ext) } //Function to load an extension from bytes -type ExtensionLoadFunc func(*Context, []byte) (Extension, error) +type ExtensionLoadFunc func(*Context,[]byte) (Extension, error) +func LoadExtension[T any, E interface { + *T + Extension +}](ctx *Context, data []byte) (Extension, error) { + e := E(new(T)) + err := e.Deserialize(ctx, data) + if err != nil { + return nil, err + } + + return e, nil +} // ExtType and NodeType constants const ( @@ -53,7 +65,19 @@ var ( NodeNotFoundError = errors.New("Node not found in DB") ) -type SignalLoadFunc func(*Context, []byte) (Signal, error) +type SignalLoadFunc func(*Context,[]byte) (Signal, error) +func LoadSignal[T any, S interface{ + *T + Signal +}](ctx *Context, data []byte) (Signal, error) { + s := S(new(T)) + err := s.Deserialize(ctx, data) + if err != nil { + return nil, err + } + + return s, nil +} type SignalInfo struct { Load SignalLoadFunc @@ -127,11 +151,10 @@ func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) e return nil } -func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFunc) error { - if load_fn == nil { - return fmt.Errorf("def has no load function") - } - +func RegisterSignal[T any, S interface { + *T + Signal +}](ctx *Context, signal_type SignalType) error { type_hash := Hash(signal_type) _, exists := ctx.Signals[type_hash] if exists == true { @@ -139,18 +162,19 @@ func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFun } ctx.Signals[type_hash] = SignalInfo{ - Load: load_fn, + Load: LoadSignal[T, S], Type: signal_type, } return nil } // Add a node to a context, returns an error if the def is invalid or already exists in the context -func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error { - if load_fn == nil { - return fmt.Errorf("def has no load function") - } - +func RegisterExtension[T any, E interface{ + *T + Extension +}](ctx *Context, data interface{}) error { + var zero E + ext_type := zero.Type() type_hash := Hash(ext_type) _, exists := ctx.Extensions[type_hash] if exists == true { @@ -158,7 +182,7 @@ func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFun } ctx.Extensions[type_hash] = ExtensionInfo{ - Load: load_fn, + Load: LoadExtension[T,E], Type: ext_type, Data: data, } @@ -195,7 +219,7 @@ func (ctx *Context) GetNode(id NodeID) (*Node, error) { // Stop every running loop func (ctx *Context) Stop() { for _, node := range(ctx.Nodes) { - node.MsgChan <- Msg{ZeroID, StopSignal} + node.MsgChan <- Msg{ZeroID, &StopSignal} } } @@ -233,40 +257,41 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { } var err error - err = ctx.RegisterExtension(ACLExtType, LoadACLExt, NewACLExtContext()) + err = RegisterExtension[ACLExt,*ACLExt](ctx, NewACLExtContext()) if err != nil { return nil, err } - err = ctx.RegisterExtension(LockableExtType, LoadLockableExt, nil) + err = RegisterExtension[LockableExt,*LockableExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterExtension(ListenerExtType, LoadListenerExt, nil) + err = RegisterExtension[ListenerExt,*ListenerExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterExtension(ECDHExtType, LoadECDHExt, nil) + err = RegisterExtension[ECDHExt,*ECDHExt](ctx, nil) if err != nil { return nil, err } - err = ctx.RegisterExtension(GroupExtType, LoadGroupExt, nil) + err = RegisterExtension[GroupExt,*GroupExt](ctx, nil) if err != nil { return nil, err } gql_ctx := NewGQLExtContext() - err = ctx.RegisterExtension(GQLExtType, LoadGQLExt, gql_ctx) + err = RegisterExtension[GQLExt,*GQLExt](ctx, gql_ctx) if err != nil { return nil, err } - err = ctx.RegisterSignal(StopSignalType, func(ctx *Context, data []byte) (Signal, error) { - return StopSignal, nil - }) + err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StopSignalType) + if err != nil { + return nil, err + } err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) if err != nil { diff --git a/ecdh.go b/ecdh.go index 9fda69d..e517917 100644 --- a/ecdh.go +++ b/ecdh.go @@ -103,9 +103,7 @@ func (ext *ECDHExt) Field(name string) interface{} { }) } -func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) { - ser, _ := signal.Serialize() - ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %s", source, node.ID, ser) +func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal *ECDHSignal) { switch signal.Str { case "req": state, exists := ext.ECDHStates[source] @@ -117,7 +115,7 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si state.SharedSecret = shared_secret ext.ECDHStates[source] = state ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - ctx.Send(node.ID, source, resp) + ctx.Send(node.ID, source, &resp) } else { ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) // TODO: send error response @@ -125,7 +123,8 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si case "resp": state, exists := ext.ECDHStates[source] if exists == false || state.ECKey == nil { - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_req"}) + resp := NewErrorSignal(signal.ID(), "no_req") + ctx.Send(node.ID, source, &resp) } else { err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) if err == nil { @@ -143,21 +142,22 @@ func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, si } } -func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) { - ser, _ := signal.Serialize() - ctx.Log.Logf("ecdh", "ECHD_STATE: %s->%s - %s", source, node.ID, ser) +func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { } -func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal ECDHProxySignal) { +func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal *ECDHProxySignal) { state, exists := ext.ECDHStates[source] if exists == false { - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_req"}) + resp := NewErrorSignal(signal.ID(), "no_req") + ctx.Send(node.ID, source, &resp) } else if state.SharedSecret == nil { - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), "no_shared"}) + resp := NewErrorSignal(signal.ID(), "no_shared") + ctx.Send(node.ID, source, &resp) } else { - unwrapped_signal, err := ParseECDHProxySignal(ctx, &signal, state.SharedSecret) + unwrapped_signal, err := ParseECDHProxySignal(ctx, signal, state.SharedSecret) if err != nil { - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(ECDHStateSignalType), err.Error()}) + resp := NewErrorSignal(signal.ID(), err.Error()) + ctx.Send(node.ID, source, &resp) } else { //TODO: Figure out what I was trying to do here and fix it ctx.Send(signal.Source, signal.Dest, unwrapped_signal) @@ -170,13 +170,13 @@ func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Sign case Direct: switch signal.Type() { case ECDHProxySignalType: - ecdh_signal := signal.(ECDHProxySignal) + ecdh_signal := signal.(*ECDHProxySignal) ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal) case ECDHStateSignalType: - ecdh_signal := signal.(StringSignal) + ecdh_signal := signal.(*StringSignal) ext.HandleStateSignal(ctx, source, node, ecdh_signal) case ECDHSignalType: - ecdh_signal := signal.(ECDHSignal) + ecdh_signal := signal.(*ECDHSignal) ext.HandleECDHSignal(ctx, source, node, ecdh_signal) default: } @@ -189,15 +189,9 @@ func (ext *ECDHExt) Type() ExtType { } func (ext *ECDHExt) Serialize() ([]byte, error) { - return json.MarshalIndent(ext, "", " ") + return json.Marshal(ext) } -func LoadECDHExt(ctx *Context, data []byte) (Extension, error) { - var ext ECDHExt - err := json.Unmarshal(data, &ext) - if err != nil { - return nil, err - } - - return &ext, nil +func (ext *ECDHExt) Deserialize(ctx *Context, data []byte) error { + return json.Unmarshal(data, &ext) } diff --git a/gql.go b/gql.go index 85d6498..9415da0 100644 --- a/gql.go +++ b/gql.go @@ -610,7 +610,7 @@ func GQLFields(ctx *GQLExtContext, field_names []string) (graphql.Fields, []ExtT type NodeResult struct { ID NodeID - Result ReadResultSignal + Result *ReadResultSignal } type ListField struct { @@ -923,7 +923,7 @@ func NewGQLExtContext() *GQLExtContext { } response_chan := ctx.Ext.GetResponseChannel(sig.ID()) - err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, sig) + err = ctx.Context.Send(ctx.Server.ID, ctx.Server.ID, &sig) if err != nil { ctx.Ext.FreeResponseChannel(sig.ID()) return nil, err @@ -1048,7 +1048,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa // Process ReadResultSignalType by forwarding it to the waiting resolver if signal.Type() == ErrorSignalType { // TODO: Forward to resolver if waiting for it - sig := signal.(ErrorSignal) + sig := signal.(*ErrorSignal) response_chan := ext.FreeResponseChannel(sig.ID()) if response_chan != nil { select { @@ -1062,7 +1062,7 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa ctx.Log.Logf("gql", "received error signal response %+v with no mapped resolver", sig) } } else if signal.Type() == ReadResultSignalType { - sig := signal.(ReadResultSignal) + sig := signal.(*ReadResultSignal) response_chan := ext.FreeResponseChannel(sig.ID()) if response_chan != nil { select { @@ -1075,14 +1075,15 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa ctx.Log.Logf("gql", "Received read result that wasn't expected - %+v", sig) } } else if signal.Type() == GQLStateSignalType { - sig := signal.(StringSignal) + sig := signal.(*StringSignal) switch sig.Str { case "start_server": if ext.State == "stopped" { err := ext.StartGQLServer(ctx, node) if err == nil { ext.State = "running" - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}) + resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"} + ctx.Send(node.ID, source, &resp) } } case "stop_server": @@ -1090,7 +1091,8 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa err := ext.StopGQLServer() if err == nil { ext.State = "stopped" - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"}) + resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_stopped"} + ctx.Send(node.ID, source, &resp) } } default: @@ -1102,7 +1104,8 @@ func (ext *GQLExt) Process(ctx *Context, source NodeID, node *Node, signal Signa case "running": err := ext.StartGQLServer(ctx, node) if err == nil { - ctx.Send(node.ID, source, StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"}) + resp := StringSignal{NewDirectSignal(GQLStateSignalType), "server_started"} + ctx.Send(node.ID, source, &resp) } case "stopped": default: @@ -1116,7 +1119,7 @@ func (ext *GQLExt) Type() ExtType { } func (ext *GQLExt) Serialize() ([]byte, error) { - return json.MarshalIndent(ext, "", " ") + return json.Marshal(ext) } var ecdsa_curves = map[uint8]elliptic.Curve{ @@ -1135,14 +1138,8 @@ var ecdh_curve_ids = map[ecdh.Curve]uint8{ ecdh.P256(): 0, } -func LoadGQLExt(ctx *Context, data []byte) (Extension, error) { - var ext GQLExt - err := json.Unmarshal(data, &ext) - if err != nil { - return nil, err - } - - return NewGQLExt(ctx, ext.Listen, ext.tls_cert, ext.tls_key, ext.State) +func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error { + return json.Unmarshal(data, &ext) } func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { diff --git a/gql_mutation.go b/gql_mutation.go index 8d24754..948021c 100644 --- a/gql_mutation.go +++ b/gql_mutation.go @@ -1,21 +1,4 @@ package graphvent import ( - "github.com/graphql-go/graphql" ) -var MutationStop = NewField(func()*graphql.Field { - mutation_stop := &graphql.Field{ - Type: TypeSignal.Type, - Args: graphql.FieldConfigArgument{ - "id": &graphql.ArgumentConfig{ - Type: graphql.String, - }, - }, - Resolve: func(p graphql.ResolveParams) (interface{}, error) { - return StopSignal, nil - }, - } - - return mutation_stop -}) - diff --git a/gql_query.go b/gql_query.go index b68d462..84a9d10 100644 --- a/gql_query.go +++ b/gql_query.go @@ -51,7 +51,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ } // Create a read signal, send it to the specified node, and add the wait to the response map if the send returns no error read_signal := NewReadSignal(ext_fields) - auth_signal, err := NewAuthorizedSignal(ctx.Key, read_signal) + auth_signal, err := NewAuthorizedSignal(ctx.Key, &read_signal) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ resp_channels[read_signal.ID()] = response_chan node_ids[read_signal.ID()] = id - err = ctx.Context.Send(ctx.Server.ID, id, auth_signal) + err = ctx.Context.Send(ctx.Server.ID, id, &auth_signal) if err != nil { ctx.Ext.FreeResponseChannel(read_signal.ID()) return nil, err @@ -76,10 +76,10 @@ func ResolveNodes(ctx *ResolveContext, p graphql.ResolveParams, ids []NodeID) ([ return nil, err } switch resp := response.(type) { - case ReadResultSignal: + case *ReadResultSignal: responses = append(responses, NodeResult{node_ids[sig_id], resp}) - case ErrorSignal: - return nil, resp.Error + case *ErrorSignal: + return nil, fmt.Errorf(resp.Str) default: return nil, fmt.Errorf("BAD_TYPE: %s", reflect.TypeOf(resp)) } diff --git a/gql_resolvers.go b/gql_resolvers.go index 01bda71..658af9e 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -84,37 +84,3 @@ func ResolveNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { return Hash(node.Result.NodeType), nil }) } - -func GQLSignalFn(p graphql.ResolveParams, fn func(Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - if signal, ok := p.Source.(Signal); ok { - return fn(signal, p) - } - return nil, fmt.Errorf("Failed to cast source to event") -} - -func GQLSignalType(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){ - return signal.Type(), nil - }) -} - -func GQLSignalDirection(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){ - direction := signal.Direction() - if direction == Up { - return "up", nil - } else if direction == Down { - return "down", nil - } else if direction == Direct { - return "direct", nil - } - return nil, fmt.Errorf("Invalid direction: %+v", direction) - }) -} - -func GQLSignalString(p graphql.ResolveParams) (interface{}, error) { - return GQLSignalFn(p, func(signal Signal, p graphql.ResolveParams)(interface{}, error){ - ser, err := signal.Serialize() - return string(ser), err - }) -} diff --git a/gql_test.go b/gql_test.go index e01d189..9fd9767 100644 --- a/gql_test.go +++ b/gql_test.go @@ -15,7 +15,7 @@ import ( ) func TestGQL(t *testing.T) { - ctx := logTestContext(t, []string{"node", "test", "gql", "policy"}) + ctx := logTestContext(t, []string{}) TestNodeType := NodeType("TEST") err := ctx.RegisterNodeType(TestNodeType, []ExtType{LockableExtType, ACLExtType}) @@ -25,15 +25,16 @@ func TestGQL(t *testing.T) { fatalErr(t, err) listener_ext := NewListenerExt(10) policy := NewAllNodesPolicy(Actions{MakeAction("+")}) + start_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"} gql := NewNode(ctx, nil, GQLNodeType, 10, []QueuedSignal{ - QueuedSignal{uuid.New(), StringSignal{NewDirectSignal(GQLStateSignalType), "start_server"}, time.Now()}, - }, NewLockableExt(), NewACLExt(policy), gql_ext, NewGroupExt(nil), listener_ext) - n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(policy)) + QueuedSignal{uuid.New(), &start_signal, time.Now()}, + }, NewLockableExt(), NewACLExt(&policy), gql_ext, NewGroupExt(nil), listener_ext) + n1 := NewNode(ctx, nil, TestNodeType, 10, nil, NewLockableExt(), NewACLExt(&policy)) err = LinkRequirement(ctx, gql.ID, n1.ID) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StringSignal) bool { + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool { return sig.Str == "server_started" }) fatalErr(t, err) @@ -85,14 +86,15 @@ func TestGQL(t *testing.T) { resp_2 := SendGQL(req_2) ctx.Log.Logf("test", "RESP_2: %s", resp_2) - ctx.Send(n1.ID, gql.ID, StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"}) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig StringSignal) bool { + stop_signal := StringSignal{NewDirectSignal(GQLStateSignalType), "stop_server"} + ctx.Send(n1.ID, gql.ID, &stop_signal) + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, GQLStateSignalType, func(sig *StringSignal) bool { return sig.Str == "server_stopped" }) } func TestGQLDB(t *testing.T) { - ctx := logTestContext(t, []string{}) + ctx := logTestContext(t, []string{"listener"}) TestUserNodeType := NodeType("TEST_USER") err := ctx.RegisterNodeType(TestUserNodeType, []ExtType{}) @@ -111,10 +113,10 @@ func TestGQLDB(t *testing.T) { NewGroupExt(nil)) ctx.Log.Logf("test", "GQL_ID: %s", gql.ID) - err = ctx.Send(gql.ID, gql.ID, StopSignal) + err = ctx.Send(gql.ID, gql.ID, &StopSignal) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool { + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql.ID }) fatalErr(t, err) @@ -130,9 +132,9 @@ func TestGQLDB(t *testing.T) { fatalErr(t, err) listener_ext, err = GetExt[*ListenerExt](gql_loaded) fatalErr(t, err) - err = ctx.Send(gql_loaded.ID, gql_loaded.ID, StopSignal) + err = ctx.Send(gql_loaded.ID, gql_loaded.ID, &StopSignal) fatalErr(t, err) - _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig IDStringSignal) bool { + _, err = WaitForSignal(ctx, listener_ext, 100*time.Millisecond, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "stopped" && sig.NodeID == gql_loaded.ID }) fatalErr(t, err) diff --git a/gql_types.go b/gql_types.go index f38e9ab..f63678a 100644 --- a/gql_types.go +++ b/gql_types.go @@ -35,29 +35,3 @@ func NewGQLNodeType(node_type NodeType, interfaces []*graphql.Interface, init fu init(&gql) return &gql } - -var TypeSignal = NewSingleton(func() *graphql.Object { - gql_type_signal := graphql.NewObject(graphql.ObjectConfig{ - Name: "Signal", - IsTypeOf: func(p graphql.IsTypeOfParams) bool { - _, ok := p.Value.(Signal) - return ok - }, - Fields: graphql.Fields{}, - }) - - gql_type_signal.AddFieldConfig("Type", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalType, - }) - gql_type_signal.AddFieldConfig("Direction", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalDirection, - }) - gql_type_signal.AddFieldConfig("String", &graphql.Field{ - Type: graphql.String, - Resolve: GQLSignalString, - }) - return gql_type_signal -}, nil) - diff --git a/graph_test.go b/graph_test.go index d1c7168..9ad43d2 100644 --- a/graph_test.go +++ b/graph_test.go @@ -17,7 +17,7 @@ func NewSimpleListener(ctx *Context, buffer int) (*Node, *ListenerExt) { 10, nil, listener_extension, - NewACLExt(policy), + NewACLExt(&policy), NewLockableExt()) return listener, listener_extension diff --git a/lockable.go b/lockable.go index bbd420b..a775f6e 100644 --- a/lockable.go +++ b/lockable.go @@ -2,7 +2,6 @@ package graphvent import ( "encoding/json" - "fmt" ) // A Listener extension provides a channel that can receive signals on a different thread @@ -31,14 +30,10 @@ func (ext *ListenerExt) Field(name string) interface{} { } // Simple load function, unmarshal the buffer int from json -func LoadListenerExt(ctx *Context, data []byte) (Extension, error) { - var j int - err := json.Unmarshal(data, &j) - if err != nil { - return nil, err - } - - return NewListenerExt(j), nil +func (ext *ListenerExt) Deserialize(ctx *Context, data []byte) error { + err := json.Unmarshal(data, &ext.Buffer) + ext.Chan = make(chan Signal, ext.Buffer) + return err } func (listener *ListenerExt) Type() ExtType { @@ -47,7 +42,7 @@ func (listener *ListenerExt) Type() ExtType { // Send the signal to the channel, logging an overflow if it occurs func (ext *ListenerExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) { - ctx.Log.Logf("signal", "LISTENER_PROCESS: %s - %+v", node.ID, signal) + ctx.Log.Logf("listener", "LISTENER_PROCESS: %s - %+v", node.ID, signal) select { case ext.Chan <- signal: default: @@ -116,18 +111,8 @@ func (ext *LockableExt) Field(name string) interface{} { }) } -func LoadLockableExt(ctx *Context, data []byte) (Extension, error) { - var ext LockableExt - err := json.Unmarshal(data, &ext) - if err != nil { - return nil, err - } - - return &ext, nil -} - func (ext *ListenerExt) Serialize() ([]byte, error) { - return json.MarshalIndent(ext.Buffer, "", " ") + return json.Marshal(ext.Buffer) } func (ext *LockableExt) Type() ExtType { @@ -135,7 +120,11 @@ func (ext *LockableExt) Type() ExtType { } func (ext *LockableExt) Serialize() ([]byte, error) { - return json.MarshalIndent(ext, "", " ") + return json.Marshal(ext) +} + +func (ext *LockableExt) Deserialize(ctx *Context, data []byte) error { + return json.Unmarshal(data, ext) } func NewLockableExt() *LockableExt { @@ -149,36 +138,43 @@ func NewLockableExt() *LockableExt { // Send the signal to unlock a node from itself func UnlockLockable(ctx *Context, node *Node) error { - return ctx.Send(node.ID, node.ID, NewLockSignal("unlock")) + lock_signal := NewLockSignal("unlock") + return ctx.Send(node.ID, node.ID, &lock_signal) } // Send the signal to lock a node from itself func LockLockable(ctx *Context, node *Node) error { - return ctx.Send(node.ID, node.ID, NewLockSignal("lock")) + lock_signal := NewLockSignal("lock") + return ctx.Send(node.ID, node.ID, &lock_signal) } // Setup a node to send the initial requirement link signal, then send the signal func LinkRequirement(ctx *Context, dependency NodeID, requirement NodeID) error { - return ctx.Send(dependency, dependency, NewLinkStartSignal("req", requirement)) + start_signal := NewLinkStartSignal("req", requirement) + return ctx.Send(dependency, dependency, &start_signal) } // Handle a LockSignal and update the extensions owner/requirement states -func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) { +func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { ctx.Log.Logf("lockable", "LOCK_SIGNAL: %s->%s %+v", source, node.ID, signal) state := signal.Str switch state { case "unlock": if ext.Owner == nil { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocked"))) + resp := NewErrorSignal(signal.ID(), "already_unlocked") + ctx.Send(node.ID, source, &resp) } else if source != *ext.Owner { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_owner"))) + resp := NewErrorSignal(signal.ID(), "not_owner") + ctx.Send(node.ID, source, &resp) } else if ext.PendingOwner == nil { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_unlocking"))) + resp := NewErrorSignal(signal.ID(), "already_unlocking") + ctx.Send(node.ID, source, &resp) } else { if len(ext.Requirements) == 0 { ext.Owner = nil ext.PendingOwner = nil - ctx.Send(node.ID, source, NewLockSignal("unlocked")) + resp := NewLockSignal("unlocked") + ctx.Send(node.ID, source, &resp) } else { ext.PendingOwner = nil for id, state := range(ext.Requirements) { @@ -188,22 +184,27 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node } state.Lock = "unlocking" ext.Requirements[id] = state - ctx.Send(node.ID, id, NewLockSignal("unlock")) + resp := NewLockSignal("unlock") + ctx.Send(node.ID, id, &resp) } } if source != node.ID { - ctx.Send(node.ID, source, NewLockSignal("unlocking")) + resp := NewLockSignal("unlocking") + ctx.Send(node.ID, source, &resp) } } } case "unlocking": state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) + resp := NewErrorSignal(signal.ID(), "not_requirement") + ctx.Send(node.ID, source, &resp) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked"))) + resp := NewErrorSignal(signal.ID(), "not_linked") + ctx.Send(node.ID, source, &resp) } else if state.Lock != "unlocking" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking"))) + resp := NewErrorSignal(signal.ID(), "not_unlocking") + ctx.Send(node.ID, source, &resp) } case "unlocked": @@ -213,11 +214,14 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) + resp := NewErrorSignal(signal.ID(), "not_requirement") + ctx.Send(node.ID, source, &resp) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked"))) + resp := NewErrorSignal(signal.ID(), "not_linked") + ctx.Send(node.ID, source, &resp) } else if state.Lock != "unlocking" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_unlocking"))) + resp := NewErrorSignal(signal.ID(), "not_unlocking") + ctx.Send(node.ID, source, &resp) } else { state.Lock = "unlocked" ext.Requirements[source] = state @@ -237,7 +241,8 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node if linked == unlocked { previous_owner := *ext.Owner ext.Owner = nil - ctx.Send(node.ID, previous_owner, NewLockSignal("unlocked")) + resp := NewLockSignal("unlocked") + ctx.Send(node.ID, previous_owner, &resp) } } } @@ -248,11 +253,14 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) + resp := NewErrorSignal(signal.ID(), "not_requirement") + ctx.Send(node.ID, source, &resp) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_linked"))) + resp := NewErrorSignal(signal.ID(), "not_linked") + ctx.Send(node.ID, source, &resp) } else if state.Lock != "locking" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking"))) + resp := NewErrorSignal(signal.ID(), "not_locking") + ctx.Send(node.ID, source, &resp) } else { state.Lock = "locked" ext.Requirements[source] = state @@ -271,31 +279,38 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node if linked == locked { ext.Owner = ext.PendingOwner - ctx.Send(node.ID, *ext.Owner, NewLockSignal("locked")) + resp := NewLockSignal("locked") + ctx.Send(node.ID, *ext.Owner, &resp) } } } case "locking": state, exists := ext.Requirements[source] if exists == false { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_requirement"))) + resp := NewErrorSignal(signal.ID(), "not_requirement") + ctx.Send(node.ID, source, &resp) } else if state.Link != "linked" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("node_not_linked"))) + resp := NewErrorSignal(signal.ID(), "not_linked") + ctx.Send(node.ID, source, &resp) } else if state.Lock != "locking" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("not_locking"))) + resp := NewErrorSignal(signal.ID(), "not_locking") + ctx.Send(node.ID, source, &resp) } case "lock": if ext.Owner != nil { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locked"))) + resp := NewErrorSignal(signal.ID(), "already_locked") + ctx.Send(node.ID, source, &resp) } else if ext.PendingOwner != nil { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_locking"))) + resp := NewErrorSignal(signal.ID(), "already_locking") + ctx.Send(node.ID, source, &resp) } else { owner := source if len(ext.Requirements) == 0 { ext.Owner = &owner ext.PendingOwner = ext.Owner - ctx.Send(node.ID, source, NewLockSignal("locked")) + resp := NewLockSignal("locked") + ctx.Send(node.ID, source, &resp) } else { ext.PendingOwner = &owner for id, state := range(ext.Requirements) { @@ -305,11 +320,13 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node } state.Lock = "locking" ext.Requirements[id] = state - ctx.Send(node.ID, id, NewLockSignal("lock")) + sub := NewLockSignal("lock") + ctx.Send(node.ID, id, &sub) } } if source != node.ID { - ctx.Send(node.ID, source, NewLockSignal("locking")) + resp := NewLockSignal("locking") + ctx.Send(node.ID, source, &resp) } } } @@ -318,7 +335,7 @@ func (ext *LockableExt) HandleLockSignal(ctx *Context, source NodeID, node *Node } } -func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal IDStringSignal) { +func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node *Node, signal *IDStringSignal) { ctx.Log.Logf("lockable", "LINK__START_SIGNAL: %s->%s %+v", source, node.ID, signal) link_type := signal.Str target := signal.NodeID @@ -327,72 +344,90 @@ func (ext *LockableExt) HandleLinkStartSignal(ctx *Context, source NodeID, node state, exists := ext.Requirements[target] _, dep_exists := ext.Dependencies[target] if ext.Owner != nil { - ctx.Send(node.ID, source, NewLinkStartSignal("locked", target)) + resp := NewLinkStartSignal("locked", target) + ctx.Send(node.ID, source, &resp) } else if ext.Owner != ext.PendingOwner { if ext.PendingOwner == nil { - ctx.Send(node.ID, source, NewLinkStartSignal("unlocking", target)) + resp := NewLinkStartSignal("unlocking", target) + ctx.Send(node.ID, source, &resp) } else { - ctx.Send(node.ID, source, NewLinkStartSignal("locking", target)) + resp := NewLinkStartSignal("locking", target) + ctx.Send(node.ID, source, &resp) } } else if exists == true { if state.Link == "linking" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_linking_req"))) + resp := NewErrorSignal(signal.ID(), "already_linking_req") + ctx.Send(node.ID, source, &resp) } else if state.Link == "linked" { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), fmt.Errorf("already_req"))) + resp := NewErrorSignal(signal.ID(), "already_req") + ctx.Send(node.ID, source, &resp) } } else if dep_exists == true { - ctx.Send(node.ID, source, NewLinkStartSignal("already_dep", target)) + resp := NewLinkStartSignal("already_dep", target) + ctx.Send(node.ID, source, &resp) } else { ext.Requirements[target] = LinkState{"linking", "unlocked", source} - ctx.Send(node.ID, target, NewLinkSignal("linked_as_req")) - ctx.Send(node.ID, source, NewLinkStartSignal("linking_req", target)) + resp := NewLinkSignal("linked_as_req") + ctx.Send(node.ID, target, &resp) + notify := NewLinkStartSignal("linking_req", target) + ctx.Send(node.ID, source, ¬ify) } } } // Handle LinkSignal, updating the extensions requirements and dependencies as necessary // TODO: Add unlink -func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal StringSignal) { +func (ext *LockableExt) HandleLinkSignal(ctx *Context, source NodeID, node *Node, signal *StringSignal) { ctx.Log.Logf("lockable", "LINK_SIGNAL: %s->%s %+v", source, node.ID, signal) state := signal.Str switch state { case "linked_as_dep": state, exists := ext.Requirements[source] if exists == true && state.Link == "linked" { - ctx.Send(node.ID, state.Initiator, NewLinkStartSignal("linked_as_req", source)) + resp := NewLinkStartSignal("linked_as_req", source) + ctx.Send(node.ID, state.Initiator, &resp) } else if state.Link == "linking" { state.Link = "linked" ext.Requirements[source] = state - ctx.Send(node.ID, source, NewLinkSignal("linked_as_req")) + resp := NewLinkSignal("linked_as_req") + ctx.Send(node.ID, source, &resp) } else if ext.PendingOwner != ext.Owner { if ext.Owner == nil { - ctx.Send(node.ID, source, NewLinkSignal("locking")) + resp := NewLinkSignal("locking") + ctx.Send(node.ID, source, &resp) } else { - ctx.Send(node.ID, source, NewLinkSignal("unlocking")) + resp := NewLinkSignal("unlocking") + ctx.Send(node.ID, source, &resp) } } else { ext.Requirements[source] = LinkState{"linking", "unlocked", source} - ctx.Send(node.ID, source, NewLinkSignal("linked_as_req")) + resp := NewLinkSignal("linked_as_req") + ctx.Send(node.ID, source, &resp) } ctx.Log.Logf("lockable", "%s is a dependency of %s", node.ID, source) case "linked_as_req": state, exists := ext.Dependencies[source] if exists == true && state.Link == "linked" { - ctx.Send(node.ID, state.Initiator, NewLinkStartSignal("linked_as_dep", source)) + resp := NewLinkStartSignal("linked_as_dep", source) + ctx.Send(node.ID, state.Initiator, &resp) } else if state.Link == "linking" { state.Link = "linked" ext.Dependencies[source] = state - ctx.Send(node.ID, source, NewLinkSignal("linked_as_dep")) + resp := NewLinkSignal("linked_as_dep") + ctx.Send(node.ID, source, &resp) } else if ext.PendingOwner != ext.Owner { if ext.Owner == nil { - ctx.Send(node.ID, source, NewLinkSignal("locking")) + resp := NewLinkSignal("locking") + ctx.Send(node.ID, source, &resp) } else { - ctx.Send(node.ID, source, NewLinkSignal("unlocking")) + resp := NewLinkSignal("unlocking") + ctx.Send(node.ID, source, &resp) } } else { ext.Dependencies[source] = LinkState{"linking", "unlocked", source} - ctx.Send(node.ID, source, NewLinkSignal("linked_as_dep")) + resp := NewLinkSignal("linked_as_dep") + ctx.Send(node.ID, source, &resp) } ctx.Log.Logf("lockable", "%s is a requirement of %s", node.ID, source) @@ -442,11 +477,11 @@ func (ext *LockableExt) Process(ctx *Context, source NodeID, node *Node, signal case Direct: switch signal.Type() { case LinkSignalType: - ext.HandleLinkSignal(ctx, source, node, signal.(StringSignal)) + ext.HandleLinkSignal(ctx, source, node, signal.(*StringSignal)) case LockSignalType: - ext.HandleLockSignal(ctx, source, node, signal.(StringSignal)) + ext.HandleLockSignal(ctx, source, node, signal.(*StringSignal)) case LinkStartSignalType: - ext.HandleLinkStartSignal(ctx, source, node, signal.(IDStringSignal)) + ext.HandleLinkStartSignal(ctx, source, node, signal.(*IDStringSignal)) default: } default: diff --git a/lockable_test.go b/lockable_test.go index bc46a9d..bf77350 100644 --- a/lockable_test.go +++ b/lockable_test.go @@ -26,13 +26,13 @@ func TestLink(t *testing.T) { l1_listener := NewListenerExt(10) l1 := NewNode(ctx, nil, TestLockableType, 10, nil, l1_listener, - NewACLExt(link_policy), + NewACLExt(&link_policy), NewLockableExt(), ) l2_listener := NewListenerExt(10) l2 := NewNode(ctx, nil, TestLockableType, 10, nil, l2_listener, - NewACLExt(link_policy), + NewACLExt(&link_policy), NewLockableExt(), ) @@ -40,20 +40,21 @@ func TestLink(t *testing.T) { err := LinkRequirement(ctx, l1.ID, l2.ID) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig IDStringSignal) bool { + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool { return sig.Str == "linked_as_req" }) fatalErr(t, err) - err = ctx.Send(l2.ID, l2.ID, NewStatusSignal("TEST", l2.ID)) + sig1 := NewStatusSignal("TEST", l2.ID) + err = ctx.Send(l2.ID, l2.ID, &sig1) fatalErr(t, err) - _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig IDStringSignal) bool { + _, err = WaitForSignal(ctx, l1_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "TEST" }) fatalErr(t, err) - _, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig IDStringSignal) bool { + _, err = WaitForSignal(ctx, l2_listener, time.Millisecond*10, StatusSignalType, func(sig *IDStringSignal) bool { return sig.Str == "TEST" }) fatalErr(t, err) @@ -64,7 +65,7 @@ func TestLink10K(t *testing.T) { NewLockable := func()(*Node) { l := NewNode(ctx, nil, TestLockableType, 10, nil, - NewACLExt(lock_policy, link_policy), + NewACLExt(&lock_policy, &link_policy), NewLockableExt(), ) return l @@ -74,7 +75,7 @@ func TestLink10K(t *testing.T) { listener := NewListenerExt(100000) l := NewNode(ctx, nil, TestLockableType, 256, nil, listener, - NewACLExt(lock_policy, link_policy), + NewACLExt(&lock_policy, &link_policy), NewLockableExt(), ) return l, listener @@ -91,7 +92,7 @@ func TestLink10K(t *testing.T) { for range(lockables) { - _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig IDStringSignal) bool { + _, err := WaitForSignal(ctx, l0_listener, time.Millisecond*10, LinkStartSignalType, func(sig *IDStringSignal) bool { return sig.Str == "linked_as_req" }) fatalErr(t, err) @@ -107,7 +108,7 @@ func TestLock(t *testing.T) { listener := NewListenerExt(100) l := NewNode(ctx, nil, TestLockableType, 10, nil, listener, - NewACLExt(lock_policy, link_policy), + NewACLExt(&lock_policy, &link_policy), NewLockableExt(), ) return l, listener @@ -140,11 +141,11 @@ func TestLock(t *testing.T) { err = LinkRequirement(ctx, l0.ID, l5.ID) fatalErr(t, err) - linked_as_req := func(sig IDStringSignal) bool { + linked_as_req := func(sig *IDStringSignal) bool { return sig.Str == "linked_as_req" } - locked := func(sig StringSignal) bool { + locked := func(sig *StringSignal) bool { return sig.Str == "locked" } diff --git a/node.go b/node.go index 4000ea4..c8dd96f 100644 --- a/node.go +++ b/node.go @@ -84,8 +84,9 @@ func RandID() NodeID { // A Serializable has a type that can be used to map to it, and a function to serialize the current state type Serializable[I comparable] interface { + Serialize()([]byte,error) + Deserialize(*Context,[]byte)error Type() I - Serialize() ([]byte, error) } // Extensions are data attached to nodes that process signals @@ -197,7 +198,7 @@ func nodeLoop(ctx *Context, node *Node) error { } // Queue the signal for extensions to perform startup actions - node.QueueSignal(time.Now(), NewDirectSignal(StartSignalType)) + node.QueueSignal(time.Now(), &StartSignal) for true { var signal Signal @@ -209,7 +210,8 @@ func nodeLoop(ctx *Context, node *Node) error { err := Allowed(ctx, msg.Source, signal.Permission(), node) if err != nil { ctx.Log.Logf("signal", "SIGNAL_POLICY_ERR: %s", err) - ctx.Send(node.ID, msg.Source, NewErrorSignal(msg.Signal.ID(), err)) + resp := NewErrorSignal(msg.Signal.ID(), err.Error()) + ctx.Send(node.ID, msg.Source, &resp) continue } case <-node.TimeoutChan: @@ -240,7 +242,7 @@ func nodeLoop(ctx *Context, node *Node) error { // Unwrap Authorized Signals if signal.Type() == AuthorizedSignalType { - sig, ok := signal.(AuthorizedSignal) + sig, ok := signal.(*AuthorizedSignal) if ok == false { ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: bad cast %+v", reflect.TypeOf(signal)) } else { @@ -254,14 +256,16 @@ func nodeLoop(ctx *Context, node *Node) error { err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) if err != nil { ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL_POLICY_ERR: %s", err) - ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), err)) + resp := NewErrorSignal(sig.ID(), err.Error()) + ctx.Send(node.ID, source, &resp) } else { // Unwrap the signal without changing the source signal = sig.Signal } } else { ctx.Log.Logf("signal", "AUTHORIZED_SIGNAL: failed to validate") - ctx.Send(node.ID, source, NewErrorSignal(sig.ID(), fmt.Errorf("failed to validate signature"))) + resp := NewErrorSignal(sig.ID(), "signature validation failed") + ctx.Send(node.ID, source, &resp) } } } @@ -269,16 +273,19 @@ func nodeLoop(ctx *Context, node *Node) error { // Handle special signal types if signal.Type() == StopSignalType { - ctx.Send(node.ID, source, NewErrorSignal(signal.ID(), nil)) - node.Process(ctx, node.ID, NewStatusSignal("stopped", node.ID)) + resp := NewErrorSignal(signal.ID(), "stopped") + ctx.Send(node.ID, source, &resp) + status := NewStatusSignal("stopped", node.ID) + node.Process(ctx, node.ID, &status) break } else if signal.Type() == ReadSignalType { - read_signal, ok := signal.(ReadSignal) + read_signal, ok := signal.(*ReadSignal) if ok == false { ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) } else { result := ReadNodeFields(ctx, node, source, read_signal.Extensions) - ctx.Send(node.ID, source, NewReadResultSignal(read_signal.ID(), node.Type, result)) + resp := NewReadResultSignal(read_signal.ID(), node.Type, result) + ctx.Send(node.ID, source, &resp) } } diff --git a/node_test.go b/node_test.go index 3b60d32..7fb9c07 100644 --- a/node_test.go +++ b/node_test.go @@ -41,18 +41,19 @@ func TestNodeRead(t *testing.T) { n1_id: Actions{MakeAction(ReadResultSignalType, "+")}, }) n2_listener := NewListenerExt(10) - n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener) + n2 := NewNode(ctx, n2_key, node_type, 10, nil, NewACLExt(&n2_policy), NewGroupExt(nil), NewECDHExt(), n2_listener) n1_policy := NewPerNodePolicy(map[NodeID]Actions{ n2_id: Actions{MakeAction(ReadSignalType, "+")}, }) - n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(n1_policy), NewGroupExt(nil), NewECDHExt()) + n1 := NewNode(ctx, n1_key, node_type, 10, nil, NewACLExt(&n1_policy), NewGroupExt(nil), NewECDHExt()) - ctx.Send(n2.ID, n1.ID, NewReadSignal(map[ExtType][]string{ + read_sig := NewReadSignal(map[ExtType][]string{ GroupExtType: []string{"members"}, - })) + }) + ctx.Send(n2.ID, n1.ID, &read_sig) - res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig ReadResultSignal) bool { + res, err := WaitForSignal(ctx, n2_listener, 10*time.Millisecond, ReadResultSignalType, func(sig *ReadResultSignal) bool { return true }) fatalErr(t, err) @@ -68,11 +69,11 @@ func TestECDH(t *testing.T) { n1_listener := NewListenerExt(10) ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")}) - n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener) - n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt()) + n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt(), n1_listener) + n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy), NewECDHExt()) n3_listener := NewListenerExt(10) n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}}) - n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy, n3_policy), NewECDHExt(), n3_listener) + n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(&ecdh_policy, &n3_policy), NewECDHExt(), n3_listener) ctx.Log.Logf("test", "N1: %s", n1.ID) ctx.Log.Logf("test", "N2: %s", n2.ID) @@ -87,18 +88,18 @@ func TestECDH(t *testing.T) { } fatalErr(t, err) ctx.Log.Logf("test", "N1_EC: %+v", n1_ec) - err = ctx.Send(n1.ID, n2.ID, ecdh_req) + err = ctx.Send(n1.ID, n2.ID, &ecdh_req) fatalErr(t, err) - _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig ECDHSignal) bool { + _, err = WaitForSignal(ctx, n1_listener, 100*time.Millisecond, ECDHSignalType, func(sig *ECDHSignal) bool { return sig.Str == "resp" }) fatalErr(t, err) time.Sleep(10*time.Millisecond) - ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, NewDirectSignal(StopSignalType), ecdh_ext.ECDHStates[n2.ID].SharedSecret) + ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, &StopSignal, ecdh_ext.ECDHStates[n2.ID].SharedSecret) fatalErr(t, err) - err = ctx.Send(n1.ID, n2.ID, ecdh_sig) + err = ctx.Send(n1.ID, n2.ID, &ecdh_sig) fatalErr(t, err) } diff --git a/policy.go b/policy.go index 93bac51..7a5ca1f 100644 --- a/policy.go +++ b/policy.go @@ -20,6 +20,8 @@ type Policy interface { Allows(principal_id NodeID, action Action, node *Node) error // Merge with another policy of the same underlying type Merge(Policy) Policy + // Make a copy of this policy + Copy() Policy } func (policy AllNodesPolicy) Allows(principal_id NodeID, action Action, node *Node) error { @@ -72,6 +74,14 @@ func MergeActions(first Actions, second Actions) Actions { return ret } +func CopyNodeActions(actions NodeActions) NodeActions { + ret := NodeActions{} + for id, a := range(actions) { + ret[id] = a + } + return ret +} + func MergeNodeActions(modified NodeActions, read NodeActions) { for id, actions := range(read) { existing, exists := modified[id] @@ -83,24 +93,47 @@ func MergeNodeActions(modified NodeActions, read NodeActions) { } } -func (policy PerNodePolicy) Merge(p Policy) Policy { - other := p.(PerNodePolicy) +func (policy *PerNodePolicy) Merge(p Policy) Policy { + other := p.(*PerNodePolicy) MergeNodeActions(policy.NodeActions, other.NodeActions) return policy } -func (policy AllNodesPolicy) Merge(p Policy) Policy { - other := p.(AllNodesPolicy) +func (policy *PerNodePolicy) Copy() Policy { + new_actions := CopyNodeActions(policy.NodeActions) + return &PerNodePolicy{ + NodeActions: new_actions, + } +} + +func (policy *AllNodesPolicy) Merge(p Policy) Policy { + other := p.(*AllNodesPolicy) policy.Actions = MergeActions(policy.Actions, other.Actions) return policy } -func (policy RequirementOfPolicy) Merge(p Policy) Policy { - other := p.(RequirementOfPolicy) +func (policy *AllNodesPolicy) Copy() Policy { + new_actions := policy.Actions + return &AllNodesPolicy { + Actions: new_actions, + } +} + +func (policy *RequirementOfPolicy) Merge(p Policy) Policy { + other := p.(*RequirementOfPolicy) policy.Actions = MergeActions(policy.Actions, other.Actions) return policy } +func (policy *RequirementOfPolicy) Copy() Policy { + new_actions := policy.Actions + return &RequirementOfPolicy{ + AllNodesPolicy { + Actions: new_actions, + }, + } +} + type Action []string func MakeAction(parts ...interface{}) Action { @@ -178,36 +211,6 @@ func (actions *NodeActions) UnmarshalJSON(data []byte) error { return nil } -type AllNodesPolicyJSON struct { - Actions Actions `json:"actions"` -} - -func AllNodesPolicyLoad(init_fn func(Actions)(Policy, error)) func(*Context, []byte)(Policy, error) { - return func(ctx *Context, data []byte)(Policy, error){ - var j AllNodesPolicyJSON - err := json.Unmarshal(data, &j) - - if err != nil { - return nil, err - } - - return init_fn(j.Actions) - } -} - -func PerNodePolicyLoad(init_fn func(NodeActions)(Policy, error)) func(*Context, []byte)(Policy, error) { - return func(ctx *Context, data []byte)(Policy, error){ - policy := PerNodePolicy{ - NodeActions: NodeActions{}, - } - err := json.Unmarshal(data, &policy) - if err != nil { - return nil, err - } - return init_fn(policy.NodeActions) - } -} - func NewPerNodePolicy(node_actions NodeActions) PerNodePolicy { if node_actions == nil { node_actions = NodeActions{} @@ -222,14 +225,18 @@ type PerNodePolicy struct { NodeActions NodeActions `json:"node_actions"` } -func (policy PerNodePolicy) Type() PolicyType { +func (policy *PerNodePolicy) Type() PolicyType { return PerNodePolicyType } -func (policy PerNodePolicy) Serialize() ([]byte, error) { +func (policy *PerNodePolicy) Serialize() ([]byte, error) { return json.MarshalIndent(policy, "", " ") } +func (policy *PerNodePolicy) Deserialize(ctx *Context, data []byte) error { + return json.Unmarshal(data, policy) +} + func NewAllNodesPolicy(actions Actions) AllNodesPolicy { if actions == nil { actions = Actions{} @@ -244,14 +251,18 @@ type AllNodesPolicy struct { Actions Actions } -func (policy AllNodesPolicy) Type() PolicyType { +func (policy *AllNodesPolicy) Type() PolicyType { return AllNodesPolicyType } -func (policy AllNodesPolicy) Serialize() ([]byte, error) { +func (policy *AllNodesPolicy) Serialize() ([]byte, error) { return json.MarshalIndent(policy, "", " ") } +func (policy *AllNodesPolicy) Deserialize(ctx *Context, data []byte) error { + return json.Unmarshal(data, policy) +} + // Extension to allow a node to hold ACL policies type ACLExt struct { Policies map[PolicyType]Policy @@ -266,36 +277,17 @@ func NodeList(nodes ...*Node) NodeMap { return m } -type PolicyLoadFunc func(*Context, []byte) (Policy, error) -type PolicyInfo struct { - Load PolicyLoadFunc -} - +type PolicyLoadFunc func(*Context,[]byte) (Policy, error) type ACLExtContext struct { - Types map[PolicyType]PolicyInfo + Loads map[PolicyType]PolicyLoadFunc } func NewACLExtContext() *ACLExtContext { return &ACLExtContext{ - Types: map[PolicyType]PolicyInfo{ - AllNodesPolicyType: PolicyInfo{ - Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ - policy := NewAllNodesPolicy(actions) - return &policy, nil - }), - }, - PerNodePolicyType: PolicyInfo{ - Load: PerNodePolicyLoad(func(nodes NodeActions)(Policy,error){ - policy := NewPerNodePolicy(nodes) - return &policy, nil - }), - }, - RequirementOfPolicyType: PolicyInfo{ - Load: AllNodesPolicyLoad(func(actions Actions)(Policy, error){ - policy := NewRequirementOfPolicy(actions) - return &policy, nil - }), - }, + Loads: map[PolicyType]PolicyLoadFunc{ + AllNodesPolicyType: LoadPolicy[AllNodesPolicy,*AllNodesPolicy], + PerNodePolicyType: LoadPolicy[PerNodePolicy,*PerNodePolicy], + RequirementOfPolicyType: LoadPolicy[RequirementOfPolicy,*RequirementOfPolicy], }, } } @@ -331,13 +323,15 @@ func (ext *ACLExt) Field(name string) interface{} { var ErrorSignalAction = Action{"ERROR_RESP"} var ReadResultSignalAction = Action{"READ_RESULT"} var AuthorizedSignalAction = Action{"AUTHORIZED_READ"} +var defaultPolicy = NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}) var DefaultACLPolicies = []Policy{ - NewAllNodesPolicy(Actions{ErrorSignalAction, ReadResultSignalAction, AuthorizedSignalAction}), + &defaultPolicy, } func NewACLExt(policies ...Policy) *ACLExt { policy_map := map[PolicyType]Policy{} - for _, policy := range(append(policies, DefaultACLPolicies...)) { + for _, policy_arg := range(append(policies, DefaultACLPolicies...)) { + policy := policy_arg.Copy() existing, exists := policy_map[policy.Type()] if exists == true { policy = existing.Merge(policy) @@ -351,36 +345,49 @@ func NewACLExt(policies ...Policy) *ACLExt { } } -func LoadACLExt(ctx *Context, data []byte) (Extension, error) { +func LoadPolicy[T any, P interface { + *T + Policy +}](ctx *Context, data []byte) (Policy, error) { + p := P(new(T)) + err := p.Deserialize(ctx, data) + if err != nil { + return nil, err + } + + return p, nil +} + +func (ext *ACLExt) Deserialize(ctx *Context, data []byte) error { var j struct { Policies map[string][]byte `json:"policies"` } + err := json.Unmarshal(data, &j) if err != nil { - return nil, err + return err } - policies := make([]Policy, len(j.Policies)) - i := 0 acl_ctx, err := GetCtx[*ACLExt, *ACLExtContext](ctx) if err != nil { - return nil, err + return err } + ext.Policies = map[PolicyType]Policy{} + for name, ser := range(j.Policies) { - policy_def, exists := acl_ctx.Types[PolicyType(name)] + policy_load, exists := acl_ctx.Loads[PolicyType(name)] if exists == false { - return nil, fmt.Errorf("%s is not a known policy type", name) + return fmt.Errorf("%s is not a known policy type", name) } - policy, err := policy_def.Load(ctx, ser) + policy, err := policy_load(ctx, ser) if err != nil { - return nil, err + return err } - policies[i] = policy - i++ + ext.Policies[PolicyType(name)] = policy } - return NewACLExt(policies...), nil + return nil } func (ext *ACLExt) Type() ExtType { diff --git a/signal.go b/signal.go index 8c0d733..902466e 100644 --- a/signal.go +++ b/signal.go @@ -95,23 +95,27 @@ type BaseSignal struct { UUID uuid.UUID `json:"id"` } -func (signal BaseSignal) ID() uuid.UUID { +func (signal *BaseSignal) Deserialize(ctx *Context, data []byte) error { + return json.Unmarshal(data, signal) +} + +func (signal *BaseSignal) ID() uuid.UUID { return signal.UUID } -func (signal BaseSignal) Type() SignalType { +func (signal *BaseSignal) Type() SignalType { return signal.SignalType } -func (signal BaseSignal) Permission() Action { +func (signal *BaseSignal) Permission() Action { return MakeAction(signal.Type()) } -func (signal BaseSignal) Direction() SignalDirection { +func (signal *BaseSignal) Direction() SignalDirection { return signal.SignalDirection } -func (signal BaseSignal) Serialize() ([]byte, error) { +func (signal *BaseSignal) Serialize() ([]byte, error) { return json.Marshal(signal) } @@ -136,43 +140,16 @@ func NewDirectSignal(signal_type SignalType) BaseSignal { return NewBaseSignal(signal_type, Direct) } +var StartSignal = NewDirectSignal(StartSignalType) var StopSignal = NewDownSignal(StopSignalType) -type ErrorSignal struct { - BaseSignal - Error error `json:"error"` -} - -func (signal ErrorSignal) Permission() Action { - return ErrorSignalAction -} - -func NewErrorSignal(req_id uuid.UUID, err error) ErrorSignal { - return ErrorSignal{ - BaseSignal: BaseSignal{ - Direct, - ErrorSignalType, - req_id, - }, - Error: err, - } -} - type IDSignal struct { BaseSignal NodeID `json:"id"` } -func (signal IDSignal) Serialize() ([]byte, error) { - return json.Marshal(&signal) -} - -func (signal IDSignal) String() string { - ser, err := json.Marshal(signal) - if err != nil { - return "STATE_SER_ERR" - } - return string(ser) +func (signal *IDSignal) Serialize() ([]byte, error) { + return json.Marshal(signal) } func NewIDSignal(signal_type SignalType, direction SignalDirection, id NodeID) IDSignal { @@ -187,21 +164,38 @@ type StringSignal struct { Str string `json:"state"` } -func (signal StringSignal) Serialize() ([]byte, error) { +func (signal *StringSignal) Serialize() ([]byte, error) { return json.Marshal(&signal) } +type ErrorSignal struct { + StringSignal +} + +func (signal *ErrorSignal) Permission() Action { + return ErrorSignalAction +} + +func NewErrorSignal(req_id uuid.UUID, err string) ErrorSignal { + return ErrorSignal{ + StringSignal{ + NewDirectSignal(ErrorSignalType), + err, + }, + } +} + type IDStringSignal struct { BaseSignal NodeID `json:"node_id"` Str string `json:"string"` } -func (signal IDStringSignal) Serialize() ([]byte, error) { - return json.Marshal(&signal) +func (signal *IDStringSignal) Serialize() ([]byte, error) { + return json.Marshal(signal) } -func (signal IDStringSignal) String() string { +func (signal *IDStringSignal) String() string { ser, err := json.Marshal(signal) if err != nil { return "STATE_SER_ERR" @@ -243,7 +237,7 @@ func NewLockSignal(state string) StringSignal { } } -func (signal StringSignal) Permission() Action { +func (signal *StringSignal) Permission() Action { return MakeAction(signal.Type(), signal.Str) } @@ -259,7 +253,7 @@ type AuthorizedSignal struct { Signature []byte } -func (signal AuthorizedSignal) Permission() Action { +func (signal *AuthorizedSignal) Permission() Action { return AuthorizedSignalAction } @@ -283,8 +277,8 @@ func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (Authorized }, nil } -func (signal ReadSignal) Serialize() ([]byte, error) { - return json.Marshal(&signal) +func (signal *ReadSignal) Serialize() ([]byte, error) { + return json.Marshal(signal) } func NewReadSignal(exts map[ExtType][]string) ReadSignal { @@ -300,7 +294,7 @@ type ReadResultSignal struct { Extensions map[ExtType]map[string]interface{} `json:"extensions"` } -func (signal ReadResultSignal) Permission() Action { +func (signal *ReadResultSignal) Permission() Action { return ReadResultSignalAction } @@ -342,8 +336,8 @@ func (signal *ECDHSignal) MarshalJSON() ([]byte, error) { }) } -func (signal ECDHSignal) Serialize() ([]byte, error) { - return json.Marshal(&signal) +func (signal *ECDHSignal) Serialize() ([]byte, error) { + return json.Marshal(signal) } func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) { @@ -390,7 +384,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e const DEFAULT_ECDH_WINDOW = time.Second -func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, []byte, error) { +func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, []byte, error) { now := time.Now() err := VerifyECDHSignal(now, req, DEFAULT_ECDH_WINDOW) @@ -430,7 +424,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req ECDHSignal) (ECDHSignal, [] }, shared_secret, nil } -func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error { +func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) error { earliest := now.Add(-window) latest := now.Add(window) diff --git a/user.go b/user.go index b4908d9..7371af6 100644 --- a/user.go +++ b/user.go @@ -17,9 +17,9 @@ func (ext *GroupExt) Type() ExtType { } func (ext *GroupExt) Serialize() ([]byte, error) { - return json.MarshalIndent(&GroupExtJSON{ + return json.Marshal(&GroupExtJSON{ Members: IDMap(ext.Members), - }, "", " ") + }) } func (ext *GroupExt) Field(name string) interface{} { @@ -40,18 +40,12 @@ func NewGroupExt(members map[NodeID]string) *GroupExt { } } -func LoadGroupExt(ctx *Context, data []byte) (Extension, error) { +func (ext *GroupExt) Deserialize(ctx *Context, data []byte) error { var j GroupExtJSON err := json.Unmarshal(data, &j) - members, err := LoadIDMap(j.Members) - if err != nil { - return nil, err - } - - return &GroupExt{ - Members: members, - }, nil + ext.Members, err = LoadIDMap(j.Members) + return err } func (ext *GroupExt) Process(ctx *Context, princ_id NodeID, node *Node, signal Signal) {