From 641bd8febe946f7cfc23d42dbbe0e5fb755239e4 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sat, 29 Jul 2023 00:28:44 -0600 Subject: [PATCH] Moved ecdh to it's own extension --- context.go | 32 ++++++++ ecdh.go | 203 ++++++++++++++++++++++++++++++++++++++++++++++ gql.go | 48 +++-------- gql_query.go | 19 +++-- gql_resolvers.go | 25 ++---- gql_subscribe.go | 2 +- gql_test.go | 9 +-- node.go | 2 +- node_test.go | 12 ++- signal.go | 205 ++++++++++++++++++++++++++++++++++++++++++----- user.go | 179 ----------------------------------------- 11 files changed, 463 insertions(+), 273 deletions(-) create mode 100644 ecdh.go diff --git a/context.go b/context.go index b903844..ee61247 100644 --- a/context.go +++ b/context.go @@ -53,6 +53,13 @@ var ( NodeNotFoundError = errors.New("Node not found in DB") ) +type SignalLoadFunc func(*Context, []byte) (Signal, error) + +type SignalInfo struct { + Load SignalLoadFunc + Type SignalType +} + // Information about a registered extension type ExtensionInfo struct { // Function used to load extensions of this type from the database @@ -77,6 +84,8 @@ type Context struct { Log Logger // Map between database extension hashes and the registered info Extensions map[uint64]ExtensionInfo + // Map between serialized signal hashes and the registered info + Signals map[uint64]SignalInfo // Map between database type hashes and the registered info Types map[uint64]*NodeInfo // Curve used for signature operations @@ -118,6 +127,24 @@ func (ctx *Context) RegisterNodeType(node_type NodeType, extensions []ExtType) e return nil } +func (ctx *Context) RegisterSignal(signal_type SignalType, load_fn SignalLoadFunc) error { + if load_fn == nil { + return fmt.Errorf("def has no load function") + } + + type_hash := Hash(signal_type) + _, exists := ctx.Signals[type_hash] + if exists == true { + return fmt.Errorf("Cannot register signal of type %s, type already exists in context", signal_type) + } + + ctx.Signals[type_hash] = SignalInfo{ + Load: load_fn, + Type: signal_type, + } + return nil +} + // Add a node to a context, returns an error if the def is invalid or already exists in the context func (ctx *Context) RegisterExtension(ext_type ExtType, load_fn ExtensionLoadFunc, data interface{}) error { if load_fn == nil { @@ -199,6 +226,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Log: log, Extensions: map[uint64]ExtensionInfo{}, Types: map[uint64]*NodeInfo{}, + Signals: map[uint64]SignalInfo{}, Nodes: map[NodeID]*Node{}, ECDH: ecdh.P256(), ECDSA: elliptic.P256(), @@ -236,6 +264,10 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = ctx.RegisterSignal(StopSignalType, func(ctx *Context, data []byte) (Signal, error) { + return StopSignal, nil + }) + err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) if err != nil { return nil, err diff --git a/ecdh.go b/ecdh.go new file mode 100644 index 0000000..0297f13 --- /dev/null +++ b/ecdh.go @@ -0,0 +1,203 @@ +package graphvent + +import ( + "fmt" + "time" + "encoding/json" + "crypto/ecdsa" + "crypto/x509" + "crypto/ecdh" +) + +type ECDHState struct { + ECKey *ecdh.PrivateKey + SharedSecret []byte +} + +type ECDHStateJSON struct { + ECKey []byte `json:"ec_key"` + SharedSecret []byte `json:"shared_secret"` +} + +func (state *ECDHState) MarshalJSON() ([]byte, error) { + var key_bytes []byte + var err error + if state.ECKey != nil { + key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey) + if err != nil { + return nil, err + } + } + + return json.Marshal(&ECDHStateJSON{ + ECKey: key_bytes, + SharedSecret: state.SharedSecret, + }) +} + +func (state *ECDHState) UnmarshalJSON(data []byte) error { + var j ECDHStateJSON + err := json.Unmarshal(data, &j) + if err != nil { + return err + } + + state.SharedSecret = j.SharedSecret + if len(j.ECKey) == 0 { + state.ECKey = nil + } else { + tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey) + if err != nil { + return err + } + + ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey) + if ok == false { + return fmt.Errorf("Parsed wrong key type from DB for ECDHState") + } + + state.ECKey, err = ecdsa_key.ECDH() + if err != nil { + return err + } + } + + return nil +} + +type ECDHMap map[NodeID]ECDHState + +func (m ECDHMap) MarshalJSON() ([]byte, error) { + tmp := map[string]ECDHState{} + for id, state := range(m) { + tmp[id.String()] = state + } + + return json.Marshal(tmp) +} + +type ECDHExt struct { + ECDHStates ECDHMap +} + +func NewECDHExt() *ECDHExt { + return &ECDHExt{ + ECDHStates: ECDHMap{}, + } +} + +func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} { + var zero T + field_func, ok := field_funcs[name] + if ok == false { + return fmt.Errorf("%s is not a field of %s", name, zero.Type()) + } + return field_func(t) +} + +func (ext *ECDHExt) Field(name string) interface{} { + return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{ + "ecdh_states": func(ext *ECDHExt) interface{} { + return ext.ECDHStates + }, + }) +} + +func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) { + ser, _ := signal.Serialize() + ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %s", source, node.ID, ser) + switch signal.State { + case "req": + state, exists := ext.ECDHStates[source] + if exists == false { + state = ECDHState{nil, nil} + } + resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal) + if err == nil { + state.SharedSecret = shared_secret + ext.ECDHStates[source] = state + ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) + ctx.Send(node.ID, source, resp) + } else { + ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) + // TODO: send error response + } + case "resp": + state, exists := ext.ECDHStates[source] + if exists == false || state.ECKey == nil { + ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), "no_req"}) + } else { + err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) + if err == nil { + shared_secret, err := state.ECKey.ECDH(signal.ECDH) + if err == nil { + state.SharedSecret = shared_secret + state.ECKey = nil + ext.ECDHStates[source] = state + ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) + } + } + } + default: + ctx.Log.Logf("ecdh", "unknown echd state %s", signal.State) + } +} + +func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) { + ser, _ := signal.Serialize() + ctx.Log.Logf("ecdh", "ECHD_STATE: %s->%s - %s", source, node.ID, ser) +} + +func (ext *ECDHExt) HandleECDHProxySignal(ctx *Context, source NodeID, node *Node, signal ECDHProxySignal) { + state, exists := ext.ECDHStates[source] + if exists == false { + ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), "no_req"}) + } else if state.SharedSecret == nil { + ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), "no_shared"}) + } else { + unwrapped_signal, err := ParseECDHProxySignal(ctx, &signal, state.SharedSecret) + if err != nil { + ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHStateSignalType), err.Error()}) + } else { + //TODO: Figure out what I was trying to do here and fix it + ctx.Send(signal.Source, signal.Dest, unwrapped_signal) + } + } +} + +func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { + switch signal.Direction() { + case Direct: + switch signal.Type() { + case ECDHProxySignalType: + ecdh_signal := signal.(ECDHProxySignal) + ext.HandleECDHProxySignal(ctx, source, node, ecdh_signal) + case ECDHStateSignalType: + ecdh_signal := signal.(StateSignal) + ext.HandleStateSignal(ctx, source, node, ecdh_signal) + case ECDHSignalType: + ecdh_signal := signal.(ECDHSignal) + ext.HandleECDHSignal(ctx, source, node, ecdh_signal) + default: + } + default: + } +} + +func (ext *ECDHExt) Type() ExtType { + return ECDHExtType +} + +func (ext *ECDHExt) Serialize() ([]byte, error) { + return json.MarshalIndent(ext, "", " ") +} + +func LoadECDHExt(ctx *Context, data []byte) (Extension, error) { + var ext ECDHExt + err := json.Unmarshal(data, &ext) + if err != nil { + return nil, err + } + + return &ext, nil +} diff --git a/gql.go b/gql.go index 0b0d887..0983dac 100644 --- a/gql.go +++ b/gql.go @@ -167,7 +167,7 @@ type ResolveContext struct { GQLContext *GQLExtContext Server *Node Ext *GQLExt - User *Node + User NodeID } func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { @@ -181,16 +181,11 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) } - user, exists := gql_ext.Users[auth_id] - if exists == false { - return nil, fmt.Errorf("GQL_REQUEST_ERR: no existing authorization for client %s", auth_id) - } - return &ResolveContext{ Context: ctx, GQLContext: ctx.Extensions[Hash(GQLExtType)].Data.(*GQLExtContext), Server: server, - User: user, + User: auth_id, }, nil } @@ -481,6 +476,11 @@ func NewGQLInterface(if_name string, default_name string, interfaces []*graphql. return &gql } +type GQLNode struct { + ID NodeID + Type NodeType +} + // GQL Specific Context information type GQLExtContext struct { // Generated GQL schema @@ -545,7 +545,7 @@ func NewGQLExtContext() *GQLExtContext { }) query.AddFieldConfig("Self", GQLQuerySelf) - query.AddFieldConfig("User", GQLQueryUser) + query.AddFieldConfig("Node", GQLQueryNode) mutation := graphql.NewObject(graphql.ObjectConfig{ Name: "Mutation", @@ -600,9 +600,7 @@ type GQLExt struct { tls_key []byte tls_cert []byte Listen string - Users NodeMap - Key *ecdsa.PrivateKey - ECDH ecdh.Curve + SubscribeLock sync.Mutex SubscribeListeners []chan Signal } @@ -629,7 +627,6 @@ func (ext *GQLExt) Process(context *Context, princ_id NodeID, node *Node, signal if signal.Type() == ReadResultSignalType { } - ext.SubscribeLock.Lock() defer ext.SubscribeLock.Unlock() @@ -655,22 +652,13 @@ func (ext *GQLExt) Type() ExtType { type GQLExtJSON struct { Listen string `json:"listen"` - Key []byte `json:"key"` - ECDH uint8 `json:"ecdh_curve"` TLSKey []byte `json:"ssl_key"` TLSCert []byte `json:"ssl_cert"` } func (ext *GQLExt) Serialize() ([]byte, error) { - ser_key, err := x509.MarshalECPrivateKey(ext.Key) - if err != nil { - return nil, err - } - return json.MarshalIndent(&GQLExtJSON{ Listen: ext.Listen, - Key: ser_key, - ECDH: ecdh_curve_ids[ext.ECDH], TLSKey: ext.tls_key, TLSCert: ext.tls_cert, }, "", " ") @@ -699,22 +687,12 @@ func LoadGQLExt(ctx *Context, data []byte) (Extension, error) { return nil, err } - ecdh_curve, ok := ecdh_curves[j.ECDH] - if ok == false { - return nil, fmt.Errorf("%d is not a known ECDH curve ID", j.ECDH) - } - - key, err := x509.ParseECPrivateKey(j.Key) - if err != nil { - return nil, err - } - - return NewGQLExt(j.Listen, ecdh_curve, key, j.TLSCert, j.TLSKey), nil + return NewGQLExt(ctx, j.Listen, j.TLSCert, j.TLSKey), nil } -func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_cert []byte, tls_key []byte) *GQLExt { +func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte) *GQLExt { if tls_cert == nil || tls_key == nil { - ssl_key, err := ecdsa.GenerateKey(key.Curve, rand.Reader) + ssl_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) if err != nil { panic(err) } @@ -755,8 +733,6 @@ func NewGQLExt(listen string, ecdh_curve ecdh.Curve, key *ecdsa.PrivateKey, tls_ return &GQLExt{ Listen: listen, SubscribeListeners: []chan Signal{}, - Key: key, - ECDH: ecdh_curve, tls_cert: tls_cert, tls_key: tls_key, } diff --git a/gql_query.go b/gql_query.go index e9b7667..679c43d 100644 --- a/gql_query.go +++ b/gql_query.go @@ -3,26 +3,31 @@ import ( "github.com/graphql-go/graphql" ) -var GQLQuerySelf = &graphql.Field{ - Type: GQLInterfaceNode.Default, +var GQLQueryNode = &graphql.Field{ + Type: GQLInterfaceNode.Interface, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, ctx, err := PrepResolve(p) + ctx, err := PrepResolve(p) if err != nil { return nil, err } + ctx.Context.Log.Logf("gql", "FieldASTs: %+v", p.Info.FieldASTs) + // Get a list of fields that will be written + // Send the read signal + // Wait for the response, returning an error on timeout - return ctx.Server, nil + return nil, nil }, } -var GQLQueryUser = &graphql.Field{ +var GQLQuerySelf = &graphql.Field{ Type: GQLInterfaceNode.Default, Resolve: func(p graphql.ResolveParams) (interface{}, error) { - _, ctx, err := PrepResolve(p) + _, err := PrepResolve(p) if err != nil { return nil, err } - return ctx.User, nil + return nil, nil }, } + diff --git a/gql_resolvers.go b/gql_resolvers.go index aad4e43..4ba12e8 100644 --- a/gql_resolvers.go +++ b/gql_resolvers.go @@ -5,18 +5,13 @@ import ( "github.com/graphql-go/graphql" ) -func PrepResolve(p graphql.ResolveParams) (*Node, *ResolveContext, error) { +func PrepResolve(p graphql.ResolveParams) (*ResolveContext, error) { resolve_context, ok := p.Context.Value("resolve").(*ResolveContext) if ok == false { - return nil, nil, fmt.Errorf("Bad resolve in params context") + return nil, fmt.Errorf("Bad resolve in params context") } - node, ok := p.Source.(*Node) - if ok == false { - return nil, nil, fmt.Errorf("Source is not a *Node in PrepResolve") - } - - return node, resolve_context, nil + return resolve_context, nil } // TODO: Make composabe by checkinf if K is a slice, then recursing in the same way that ExtractList does @@ -71,21 +66,11 @@ func ExtractID(p graphql.ResolveParams, name string) (NodeID, error) { // TODO: think about what permissions should be needed to read ID, and if there's ever a case where they haven't already been granted func GQLNodeID(p graphql.ResolveParams) (interface{}, error) { - node, _, err := PrepResolve(p) - if err != nil { - return nil, err - } - - return node.ID, nil + return nil, nil } func GQLNodeTypeHash(p graphql.ResolveParams) (interface{}, error) { - node, _, err := PrepResolve(p) - if err != nil { - return nil, err - } - - return string(node.Type), nil + return nil, nil } func GQLNodeListen(p graphql.ResolveParams) (interface{}, error) { diff --git a/gql_subscribe.go b/gql_subscribe.go index 63440d6..a83c0a3 100644 --- a/gql_subscribe.go +++ b/gql_subscribe.go @@ -16,7 +16,7 @@ func GQLSubscribeSelf(p graphql.ResolveParams) (interface{}, error) { } func GQLSubscribeFn(p graphql.ResolveParams, send_nil bool, fn func(*Context, *Node, *GQLExt, Signal, graphql.ResolveParams)(interface{}, error))(interface{}, error) { - _, ctx, err := PrepResolve(p) + ctx, err := PrepResolve(p) if err != nil { return nil, err } diff --git a/gql_test.go b/gql_test.go index f6817dd..4f32562 100644 --- a/gql_test.go +++ b/gql_test.go @@ -3,10 +3,6 @@ package graphvent import ( "testing" "time" - "crypto/rand" - "crypto/ecdh" - "crypto/ecdsa" - "crypto/elliptic" ) func TestGQLDB(t * testing.T) { @@ -19,10 +15,7 @@ func TestGQLDB(t * testing.T) { ctx.Log.Logf("test", "U1_ID: %s", u1.ID) - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - fatalErr(t, err) - - gql_ext := NewGQLExt(":0", ecdh.P256(), key, nil, nil) + gql_ext := NewGQLExt(ctx, ":0", nil, nil) listener_ext := NewListenerExt(10) gql := NewNode(ctx, nil, GQLNodeType, 10, nil, gql_ext, diff --git a/node.go b/node.go index 10e2246..4b84751 100644 --- a/node.go +++ b/node.go @@ -218,7 +218,7 @@ func nodeLoop(ctx *Context, node *Node) error { ctx.Log.Logf("signal", "READ_SIGNAL: bad cast %+v", signal) } else { result := ReadNodeFields(ctx, node, source, read_signal.Extensions) - ctx.Send(node.ID, source, NewReadResultSignal(result)) + ctx.Send(node.ID, source, NewReadResultSignal(node.Type, result)) } } diff --git a/node_test.go b/node_test.go index 64f8317..bc79432 100644 --- a/node_test.go +++ b/node_test.go @@ -67,9 +67,12 @@ func TestECDH(t *testing.T) { fatalErr(t, err) n1_listener := NewListenerExt(10) - ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+")}) + ecdh_policy := NewAllNodesPolicy(Actions{MakeAction(ECDHSignalType, "+"), MakeAction(ECDHProxySignalType, "+")}) n1 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt(), n1_listener) n2 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy), NewECDHExt()) + n3_listener := NewListenerExt(10) + n3_policy := NewPerNodePolicy(NodeActions{n1.ID: Actions{MakeAction(StopSignalType)}}) + n3 := NewNode(ctx, nil, node_type, 10, nil, NewACLExt(ecdh_policy, n3_policy), NewECDHExt(), n3_listener) ctx.Log.Logf("test", "N1: %s", n1.ID) ctx.Log.Logf("test", "N2: %s", n2.ID) @@ -91,4 +94,11 @@ func TestECDH(t *testing.T) { return sig.State == "resp" }) fatalErr(t, err) + time.Sleep(10*time.Millisecond) + + ecdh_sig, err := NewECDHProxySignal(n1.ID, n3.ID, NewDirectSignal(StopSignalType), ecdh_ext.ECDHStates[n2.ID].SharedSecret) + fatalErr(t, err) + + err = ctx.Send(n1.ID, n2.ID, ecdh_sig) + fatalErr(t, err) } diff --git a/signal.go b/signal.go index c031603..f0a5fd1 100644 --- a/signal.go +++ b/signal.go @@ -4,10 +4,13 @@ import ( "time" "fmt" "encoding/json" + "encoding/binary" "crypto/sha512" "crypto/ecdsa" "crypto/ecdh" "crypto/rand" + "crypto/aes" + "crypto/cipher" ) type SignalDirection int @@ -20,6 +23,8 @@ const ( ReadResultSignalType = "READ_RESULT" LinkStartSignalType = "LINK_START" ECDHSignalType = "ECDH" + ECDHStateSignalType = "ECDH_STATE" + ECDHProxySignalType = "ECDH_PROXY" Up SignalDirection = iota Down @@ -27,9 +32,8 @@ const ( ) type SignalType string -func (signal_type SignalType) String() string { - return string(signal_type) -} +func (signal_type SignalType) String() string { return string(signal_type) } +func (signal_type SignalType) Prefix() string { return "SIGNAL: " } type Signal interface { Serializable[SignalType] @@ -49,7 +53,6 @@ func WaitForSignal[S Signal](ctx * Context, listener *ListenerExt, timeout time. if signal.Type() == signal_type { sig, ok := signal.(S) if ok == true { - ctx.Log.Logf("test", "received: %+v", sig) if check(sig) == true { return sig, nil } @@ -80,12 +83,8 @@ func (signal BaseSignal) Direction() SignalDirection { return signal.SignalDirection } -func (signal *BaseSignal) MarshalJSON() ([]byte, error) { - return json.Marshal(signal) -} - func (signal BaseSignal) Serialize() ([]byte, error) { - return signal.MarshalJSON() + return json.Marshal(signal) } func NewBaseSignal(signal_type SignalType, direction SignalDirection) BaseSignal { @@ -115,6 +114,10 @@ type IDSignal struct { ID NodeID `json:"id"` } +func (signal IDSignal) Serialize() ([]byte, error) { + return json.Marshal(&signal) +} + func (signal IDSignal) String() string { ser, err := json.Marshal(signal) if err != nil { @@ -135,10 +138,18 @@ type StateSignal struct { State string `json:"state"` } +func (signal StateSignal) Serialize() ([]byte, error) { + return json.Marshal(&signal) +} + type IDStateSignal struct { BaseSignal ID NodeID `json:"id"` - State string `json:"status"` + State string `json:"state"` +} + +func (signal IDStateSignal) Serialize() ([]byte, error) { + return json.Marshal(&signal) } func (signal IDStateSignal) String() string { @@ -157,15 +168,6 @@ func NewStatusSignal(status string, source NodeID) IDStateSignal { } } -func (signal StateSignal) Serialize() ([]byte, error) { - return json.MarshalIndent(signal, "", " ") -} - -func (signal StateSignal) String() string { - ser, _ := signal.Serialize() - return string(ser) -} - func NewLinkSignal(state string) StateSignal { return StateSignal{ BaseSignal: NewDirectSignal(LinkSignalType), @@ -201,6 +203,10 @@ type ReadSignal struct { Extensions map[ExtType][]string `json:"extensions"` } +func (signal ReadSignal) Serialize() ([]byte, error) { + return json.Marshal(&signal) +} + func NewReadSignal(exts map[ExtType][]string) ReadSignal { return ReadSignal{ BaseSignal: NewDirectSignal(ReadSignalType), @@ -210,12 +216,14 @@ func NewReadSignal(exts map[ExtType][]string) ReadSignal { type ReadResultSignal struct { BaseSignal + NodeType NodeType Extensions map[ExtType]map[string]interface{} `json:"extensions"` } -func NewReadResultSignal(exts map[ExtType]map[string]interface{}) ReadResultSignal { +func NewReadResultSignal(node_type NodeType, exts map[ExtType]map[string]interface{}) ReadResultSignal { return ReadResultSignal{ BaseSignal: NewDirectSignal(ReadResultSignalType), + NodeType: node_type, Extensions: exts, } } @@ -228,6 +236,28 @@ type ECDHSignal struct { Signature []byte } +type ECDHSignalJSON struct { + StateSignal + Time time.Time `json:"time"` + ECDSA []byte `json:"ecdsa_pubkey"` + ECDH []byte `json:"ecdh_pubkey"` + Signature []byte `json:"signature"` +} + +func (signal *ECDHSignal) MarshalJSON() ([]byte, error) { + return json.Marshal(&ECDHSignalJSON{ + StateSignal: signal.StateSignal, + Time: signal.Time, + ECDH: signal.ECDH.Bytes(), + ECDSA: signal.ECDH.Bytes(), + Signature: signal.Signature, + }) +} + +func (signal ECDHSignal) Serialize() ([]byte, error) { + return json.Marshal(&signal) +} + func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) { time_bytes, err := now.MarshalJSON() if err != nil { @@ -335,3 +365,138 @@ func VerifyECDHSignal(now time.Time, sig ECDHSignal, window time.Duration) error return nil } +type ECDHProxySignal struct { + BaseSignal + Source NodeID + Dest NodeID + IV []byte + Data []byte +} + +func NewECDHProxySignal(source, dest NodeID, signal Signal, shared_secret []byte) (ECDHProxySignal, error) { + if shared_secret == nil { + return ECDHProxySignal{}, fmt.Errorf("need shared_secret") + } + + aes_key, err := aes.NewCipher(shared_secret[:32]) + if err != nil { + return ECDHProxySignal{}, err + } + + ser, err := SerializeSignal(signal, aes_key.BlockSize()) + if err != nil { + return ECDHProxySignal{}, err + } + + iv := make([]byte, aes_key.BlockSize()) + n, err := rand.Reader.Read(iv) + if err != nil { + return ECDHProxySignal{}, err + } else if n != len(iv) { + return ECDHProxySignal{}, fmt.Errorf("Not enough bytes read for IV") + } + + encrypter := cipher.NewCBCEncrypter(aes_key, iv) + encrypter.CryptBlocks(ser, ser) + + return ECDHProxySignal{ + BaseSignal: NewDirectSignal(ECDHProxySignalType), + Source: source, + Dest: dest, + IV: iv, + Data: ser, + }, nil +} + +type SignalHeader struct { + Magic uint32 + TypeHash uint64 + Length uint64 +} + +const SIGNAL_SER_MAGIC uint32 = 0x753a64de +const SIGNAL_SER_HEADER_LENGTH = 20 +func SerializeSignal(signal Signal, block_size int) ([]byte, error) { + signal_ser, err := signal.Serialize() + if err != nil { + return nil, err + } + + pad_req := 0 + if block_size > 0 { + pad := block_size - ((SIGNAL_SER_HEADER_LENGTH + len(signal_ser)) % block_size) + if pad != block_size { + pad_req = pad + } + } + + header := SignalHeader{ + Magic: SIGNAL_SER_MAGIC, + TypeHash: Hash(signal.Type()), + Length: uint64(len(signal_ser) + pad_req), + } + + ser := make([]byte, SIGNAL_SER_HEADER_LENGTH + len(signal_ser) + pad_req) + binary.BigEndian.PutUint32(ser[0:4], header.Magic) + binary.BigEndian.PutUint64(ser[4:12], header.TypeHash) + binary.BigEndian.PutUint64(ser[12:20], header.Length) + + copy(ser[SIGNAL_SER_HEADER_LENGTH:], signal_ser) + + return ser, nil +} + +func ParseSignal(ctx *Context, data []byte) (Signal, error) { + if len(data) < SIGNAL_SER_HEADER_LENGTH { + return nil, fmt.Errorf("data shorter than header length") + } + + header := SignalHeader{ + Magic: binary.BigEndian.Uint32(data[0:4]), + TypeHash: binary.BigEndian.Uint64(data[4:12]), + Length: binary.BigEndian.Uint64(data[12:20]), + } + + if header.Magic != SIGNAL_SER_MAGIC { + return nil, fmt.Errorf("signal magic mismatch 0x%x", header.Magic) + } + + left := len(data) - SIGNAL_SER_HEADER_LENGTH + if int(header.Length) != left { + return nil, fmt.Errorf("signal length mismatch %d/%d", header.Length, left) + } + + signal_def, exists := ctx.Signals[header.TypeHash] + if exists == false { + return nil, fmt.Errorf("0x%x is not a known signal type", header.TypeHash) + } + + signal, err := signal_def.Load(ctx, data[SIGNAL_SER_HEADER_LENGTH:]) + if err != nil { + return nil, err + } + + return signal, nil +} + +func ParseECDHProxySignal(ctx *Context, signal *ECDHProxySignal, shared_secret []byte) (Signal, error) { + if shared_secret == nil { + return nil, fmt.Errorf("need shared_secret") + } + + aes_key, err := aes.NewCipher(shared_secret[:32]) + if err != nil { + return nil, err + } + + decrypter := cipher.NewCBCDecrypter(aes_key, signal.IV) + decrypted := make([]byte, len(signal.Data)) + decrypter.CryptBlocks(decrypted, signal.Data) + + wrapped_signal, err := ParseSignal(ctx, decrypted) + if err != nil { + return nil, err + } + + return wrapped_signal, nil +} diff --git a/user.go b/user.go index f8f87ba..b4908d9 100644 --- a/user.go +++ b/user.go @@ -1,188 +1,9 @@ package graphvent import ( - "fmt" - "time" "encoding/json" - "crypto/ecdsa" - "crypto/x509" - "crypto/ecdh" ) -type ECDHState struct { - ECKey *ecdh.PrivateKey - SharedSecret []byte -} - -type ECDHStateJSON struct { - ECKey []byte `json:"ec_key"` - SharedSecret []byte `json:"shared_secret"` -} - -func (state *ECDHState) MarshalJSON() ([]byte, error) { - var key_bytes []byte - var err error - if state.ECKey != nil { - key_bytes, err = x509.MarshalPKCS8PrivateKey(state.ECKey) - if err != nil { - return nil, err - } - } - - return json.Marshal(&ECDHStateJSON{ - ECKey: key_bytes, - SharedSecret: state.SharedSecret, - }) -} - -func (state *ECDHState) UnmarshalJSON(data []byte) error { - var j ECDHStateJSON - err := json.Unmarshal(data, &j) - if err != nil { - return err - } - - state.SharedSecret = j.SharedSecret - if len(j.ECKey) == 0 { - state.ECKey = nil - } else { - tmp_key, err := x509.ParsePKCS8PrivateKey(j.ECKey) - if err != nil { - return err - } - - ecdsa_key, ok := tmp_key.(*ecdsa.PrivateKey) - if ok == false { - return fmt.Errorf("Parsed wrong key type from DB for ECDHState") - } - - state.ECKey, err = ecdsa_key.ECDH() - if err != nil { - return err - } - } - - return nil -} - -type ECDHMap map[NodeID]ECDHState - -func (m ECDHMap) MarshalJSON() ([]byte, error) { - tmp := map[string]ECDHState{} - for id, state := range(m) { - tmp[id.String()] = state - } - - return json.Marshal(tmp) -} - -type ECDHExt struct { - ECDHStates ECDHMap -} - -func NewECDHExt() *ECDHExt { - return &ECDHExt{ - ECDHStates: ECDHMap{}, - } -} - -func ResolveFields[T Extension](t T, name string, field_funcs map[string]func(T)interface{})interface{} { - var zero T - field_func, ok := field_funcs[name] - if ok == false { - return fmt.Errorf("%s is not a field of %s", name, zero.Type()) - } - return field_func(t) -} - -func (ext *ECDHExt) Field(name string) interface{} { - return ResolveFields(ext, name, map[string]func(*ECDHExt)interface{}{ - "ecdh_states": func(ext *ECDHExt) interface{} { - return ext.ECDHStates - }, - }) -} - -func (ext *ECDHExt) HandleECDHSignal(ctx *Context, source NodeID, node *Node, signal ECDHSignal) { - ctx.Log.Logf("ecdh", "ECDH_SIGNAL: %s->%s - %+v", source, node, signal) - switch signal.State { - case "req": - state, exists := ext.ECDHStates[source] - if exists == false { - state = ECDHState{nil, nil} - } - resp, shared_secret, err := NewECDHRespSignal(ctx, node, signal) - if err == nil { - state.SharedSecret = shared_secret - ext.ECDHStates[source] = state - ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - ctx.Send(node.ID, source, resp) - } else { - ctx.Log.Logf("ecdh", "ECDH_REQ_ERR: %s", err) - // TODO: send error response - } - case "resp": - state, exists := ext.ECDHStates[source] - if exists == false || state.ECKey == nil { - ctx.Send(node.ID, source, StateSignal{NewDirectSignal(ECDHSignalType), "no_req"}) - } else { - err := VerifyECDHSignal(time.Now(), signal, DEFAULT_ECDH_WINDOW) - if err == nil { - shared_secret, err := state.ECKey.ECDH(signal.ECDH) - if err == nil { - state.SharedSecret = shared_secret - state.ECKey = nil - ext.ECDHStates[source] = state - ctx.Log.Logf("ecdh", "New shared secret for %s<->%s - %+v", node.ID, source, ext.ECDHStates[source].SharedSecret) - } - } - } - default: - ctx.Log.Logf("ecdh", "unknown echd state %s", signal.State) - } -} - -func (ext *ECDHExt) HandleStateSignal(ctx *Context, source NodeID, node *Node, signal StateSignal) { - -} - -func (ext *ECDHExt) Process(ctx *Context, source NodeID, node *Node, signal Signal) { - switch signal.Direction() { - case Direct: - switch signal.Type() { - case ECDHSignalType: - switch ecdh_signal := signal.(type) { - case ECDHSignal: - ext.HandleECDHSignal(ctx, source, node, ecdh_signal) - case StateSignal: - ext.HandleStateSignal(ctx, source, node, ecdh_signal) - default: - ctx.Log.Logf("ecdh", "BAD_SIGNAL_CAST: %+v", signal) - } - default: - } - default: - } -} - -func (ext *ECDHExt) Type() ExtType { - return ECDHExtType -} - -func (ext *ECDHExt) Serialize() ([]byte, error) { - return json.MarshalIndent(ext, "", " ") -} - -func LoadECDHExt(ctx *Context, data []byte) (Extension, error) { - var ext ECDHExt - err := json.Unmarshal(data, &ext) - if err != nil { - return nil, err - } - - return &ext, nil -} - type GroupExt struct { Members map[NodeID]string }