diff --git a/context.go b/context.go index 47b6688..c36141c 100644 --- a/context.go +++ b/context.go @@ -7,7 +7,6 @@ import ( "errors" "runtime" "crypto/sha512" - "crypto/elliptic" "crypto/ecdh" "encoding/binary" ) @@ -112,8 +111,6 @@ type Context struct { Signals map[uint64]SignalInfo // Map between database type hashes and the registered info Types map[uint64]*NodeInfo - // Curve used for signature operations - ECDSA elliptic.Curve // Curve used for ecdh operations ECDH ecdh.Curve // Routing map to all the nodes local to this context @@ -252,8 +249,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { Types: map[uint64]*NodeInfo{}, Signals: map[uint64]SignalInfo{}, Nodes: map[NodeID]*Node{}, - ECDH: ecdh.P256(), - ECDSA: elliptic.P256(), + ECDH: ecdh.X25519(), } var err error @@ -293,6 +289,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) { return nil, err } + err = RegisterSignal[BaseSignal, *BaseSignal](ctx, NewSignalType) + if err != nil { + return nil, err + } + + err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StartSignalType) + if err != nil { + return nil, err + } + err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) if err != nil { return nil, err diff --git a/gql.go b/gql.go index 9415da0..b6b79e7 100644 --- a/gql.go +++ b/gql.go @@ -18,7 +18,7 @@ import ( "github.com/gobwas/ws/wsutil" "strings" "crypto/ecdh" - "crypto/ecdsa" + "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/x509" @@ -185,7 +185,7 @@ type ResolveContext struct { // Key for the user that made this request, to sign resolver requests // TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards - Key *ecdsa.PrivateKey + Key ed25519.PrivateKey } func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { @@ -199,12 +199,20 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) } - key, err := x509.ParseECPrivateKey([]byte(key_bytes)) + key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes)) if err != nil { return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ecdsa key from auth password: %s", key_bytes) } - key_id := KeyID(&key.PublicKey) + var key ed25519.PrivateKey + switch k := key_raw.(type) { + case ed25519.PrivateKey: + key = k + default: + return nil, fmt.Errorf("GQL_REQUEST_ERR: wrong type for key: %s", reflect.TypeOf(key_raw)) + } + + key_id := KeyID(key.Public().(ed25519.PublicKey)) if auth_id != key_id { return nil, fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id) } @@ -1144,12 +1152,12 @@ func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error { func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { if tls_cert == nil || tls_key == nil { - ssl_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + _, ssl_key, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, err } - ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key) + ssl_key_bytes, err := x509.MarshalPKCS8PrivateKey(ssl_key) if err != nil { return nil, err } @@ -1172,7 +1180,7 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta BasicConstraintsValid: true, } - ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key) + ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, ssl_key.Public(), ssl_key) if err != nil { return nil, err } diff --git a/gql_test.go b/gql_test.go index 08a354c..6992ee9 100644 --- a/gql_test.go +++ b/gql_test.go @@ -68,7 +68,7 @@ func TestGQL(t *testing.T) { req, err := http.NewRequest("GET", url, req_data) fatalErr(t, err) - key_bytes, err := x509.MarshalECPrivateKey(n1.Key) + key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key) fatalErr(t, err) req.SetBasicAuth(n1.ID.String(), string(key_bytes)) resp, err := client.Do(req) diff --git a/node.go b/node.go index b8a60f6..db0e436 100644 --- a/node.go +++ b/node.go @@ -10,8 +10,7 @@ import ( "encoding/binary" "encoding/json" "sync/atomic" - "crypto/ecdsa" - "crypto/elliptic" + "crypto/ed25519" "crypto/sha512" "crypto/rand" "crypto/x509" @@ -23,6 +22,7 @@ const ( // Total length of the node database header, has magic to verify and type_hash to map to load function NODE_DB_HEADER_LEN = 28 EXTENSION_DB_HEADER_LEN = 16 + QSIGNAL_DB_HEADER_LEN = 40 ) var ( @@ -106,7 +106,7 @@ type QueuedSignal struct { // Default message channel size for nodes // Nodes represent a group of extensions that can be collectively addressed type Node struct { - Key *ecdsa.PrivateKey + Key ed25519.PrivateKey ID NodeID Type NodeType Extensions map[ExtType]Extension @@ -197,14 +197,15 @@ func nodeLoop(ctx *Context, node *Node) error { return fmt.Errorf("%s is already started, will not start again", node.ID) } - // Queue the signal for extensions to perform startup actions - node.QueueSignal(time.Now(), &StartSignal) + // Perform startup actions + node.Process(ctx, node.ID, &StartSignal) for true { var signal Signal var source NodeID select { case msg := <- node.MsgChan: + ctx.Log.Logf("signal", "NODE_MSG: %s - %+v", node.ID, msg) signal = msg.Signal source = msg.Source err := Allowed(ctx, msg.Source, signal.Permission(), node) @@ -234,7 +235,7 @@ func nodeLoop(ctx *Context, node *Node) error { node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) if node.NextSignal == nil { - ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil", node.ID, t, signal) + ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, t, signal, node.TimeoutChan) } else { ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, t, signal, node.NextSignal, node.NextSignal.Time) } @@ -250,8 +251,7 @@ func nodeLoop(ctx *Context, node *Node) error { sig_data, err := sig.Signal.Serialize() if err != nil { } else { - sig_hash := sha512.Sum512(sig_data) - validated := ecdsa.VerifyASN1(sig.Principal, sig_hash[:], sig.Signature) + validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature) if validated == true { err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) if err != nil { @@ -271,6 +271,8 @@ func nodeLoop(ctx *Context, node *Node) error { } } + ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue) + // Handle special signal types if signal.Type() == StopSignalType { resp := NewErrorSignal(signal.ID(), "stopped") @@ -349,8 +351,9 @@ func GetExt[T Extension](node *Node) (T, error) { func (node *Node) Serialize() ([]byte, error) { extensions := make([]ExtensionDB, len(node.Extensions)) + qsignals := make([]QSignalDB, len(node.SignalQueue)) - key_bytes, err := x509.MarshalECPrivateKey(node.Key) + key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key) if err != nil { return nil, err } @@ -365,7 +368,7 @@ func (node *Node) Serialize() ([]byte, error) { NumQueuedSignals: uint32(len(node.SignalQueue)), }, Extensions: extensions, - QueuedSignals: node.SignalQueue, + QueuedSignals: qsignals, KeyBytes: key_bytes, } @@ -385,26 +388,45 @@ func (node *Node) Serialize() ([]byte, error) { i += 1 } + for i, qsignal := range(node.SignalQueue) { + ser, err := qsignal.Signal.Serialize() + if err != nil { + return nil, err + } + + node_db.QueuedSignals[i] = QSignalDB{ + QSignalDBHeader{ + qsignal.Signal.ID(), + qsignal.Time, + Hash(qsignal.Signal.Type()), + uint64(len(ser)), + }, + ser, + } + } + return node_db.Serialize(), nil } -func KeyID(pub *ecdsa.PublicKey) NodeID { - ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y) - str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3) +func KeyID(pub ed25519.PublicKey) NodeID { + str := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3) return NodeID(str) } // Create a new node in memory and start it's event loop // TODO: Change panics to errors -func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { +func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { var err error + var public ed25519.PublicKey if key == nil { - key, err = ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + public, key, err = ed25519.GenerateKey(rand.Reader) if err != nil { panic(err) } + } else { + public = key.Public().(ed25519.PublicKey) } - id := KeyID(&key.PublicKey) + id := KeyID(public) _, exists := ctx.Node(id) if exists == true { panic("Attempted to create an existing node") @@ -432,9 +454,7 @@ func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_siz } if queued_signals == nil { - queued_signals = []QueuedSignal{ - QueuedSignal{uuid.New(), &NewSignal, time.Now()}, - } + queued_signals = []QueuedSignal{} } next_signal, timeout_chan := SoonestSignal(queued_signals) @@ -456,6 +476,8 @@ func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_siz panic(err) } + node.Process(ctx, node.ID, &NewSignal) + go runNode(ctx, node) return node @@ -497,7 +519,7 @@ type NodeDBHeader struct { type NodeDB struct { Header NodeDBHeader - QueuedSignals []QueuedSignal + QueuedSignals []QSignalDB Extensions []ExtensionDB KeyBytes []byte } @@ -551,9 +573,34 @@ func NewNodeDB(data []byte) (NodeDB, error) { ptr += int(EXTENSION_DB_HEADER_LEN + length) } - queued_signals := make([]QueuedSignal, num_queued_signals) + queued_signals := make([]QSignalDB, num_queued_signals) for i, _ := range(queued_signals) { - queued_signals[i] = QueuedSignal{} + cur := data[ptr:] + // TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented + // Right now causes segfault because any saved signal is loaded as nil + signal_id_bytes := cur[0:16] + unix_milli := binary.BigEndian.Uint64(cur[16:24]) + type_hash := binary.BigEndian.Uint64(cur[24:32]) + signal_size := binary.BigEndian.Uint64(cur[32:40]) + + signal_id, err := uuid.FromBytes(signal_id_bytes) + if err != nil { + return zero, err + } + + signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)] + + queued_signals[i] = QSignalDB{ + QSignalDBHeader{ + signal_id, + time.UnixMilli(int64(unix_milli)), + type_hash, + signal_size, + }, + signal_data, + } + + ptr += QSIGNAL_DB_HEADER_LEN + int(signal_size) } return NodeDB{ @@ -592,10 +639,28 @@ func (node NodeDB) Serialize() []byte { for _, extension := range(node.Extensions) { ser = append(ser, extension.Serialize()...) } + for _, qsignal := range(node.QueuedSignals) { + ser = append(ser, qsignal.Serialize()...) + } return ser } +func (header QSignalDBHeader) Serialize() []byte { + ret := make([]byte, QSIGNAL_DB_HEADER_LEN) + id_ser, _ := header.SignalID.MarshalBinary() + copy(ret, id_ser) + binary.BigEndian.PutUint64(ret[16:24], uint64(header.Time.UnixMilli())) + binary.BigEndian.PutUint64(ret[24:32], header.TypeHash) + binary.BigEndian.PutUint64(ret[32:40], header.Length) + return ret +} + +func (qsignal QSignalDB) Serialize() []byte { + header_bytes := qsignal.Header.Serialize() + return append(header_bytes, qsignal.Data...) +} + func (header ExtensionDBHeader) Serialize() []byte { ret := make([]byte, EXTENSION_DB_HEADER_LEN) binary.BigEndian.PutUint64(ret[0:8], header.TypeHash) @@ -608,6 +673,18 @@ func (extension ExtensionDB) Serialize() []byte { return append(header_bytes, extension.Data...) } +type QSignalDBHeader struct { + SignalID uuid.UUID + Time time.Time + TypeHash uint64 + Length uint64 +} + +type QSignalDB struct { + Header QSignalDBHeader + Data []byte +} + type ExtensionDBHeader struct { TypeHash uint64 Length uint64 @@ -663,16 +740,20 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, err } - key, err := x509.ParseECPrivateKey(node_db.KeyBytes) + key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes) if err != nil { return nil, err } - if key.PublicKey.Curve != ctx.ECDSA { - return nil, fmt.Errorf("%s - wrong ec curve for private key: %+v, expected %+v", id, key.PublicKey.Curve, ctx.ECDSA) + var key ed25519.PrivateKey + switch k := key_raw.(type) { + case ed25519.PrivateKey: + key = k + default: + return nil, fmt.Errorf("Wrong type for private key loaded: %s - %s", id, reflect.TypeOf(k)) } - key_id := KeyID(&key.PublicKey) + key_id := KeyID(key.Public().(ed25519.PublicKey)) if key_id != id { return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id) } @@ -682,7 +763,22 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) } - next_signal, timeout_chan := SoonestSignal(node_db.QueuedSignals) + signal_queue := make([]QueuedSignal, node_db.Header.NumQueuedSignals) + for i, qsignal := range(node_db.QueuedSignals) { + sig_info, exists := ctx.Signals[qsignal.Header.TypeHash] + if exists == false { + return nil, fmt.Errorf("0x%x is not a known signal type", qsignal.Header.TypeHash) + } + + signal, err := sig_info.Load(ctx, qsignal.Data) + if err != nil { + return nil, err + } + + signal_queue[i] = QueuedSignal{qsignal.Header.SignalID, signal, qsignal.Header.Time} + } + + next_signal, timeout_chan := SoonestSignal(signal_queue) node := &Node{ Key: key, ID: key_id, @@ -691,7 +787,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) { MsgChan: make(chan Msg, node_db.Header.BufferSize), BufferSize: node_db.Header.BufferSize, TimeoutChan: timeout_chan, - SignalQueue: node_db.QueuedSignals, + SignalQueue: signal_queue, NextSignal: next_signal, } ctx.AddNode(id, node) diff --git a/node_test.go b/node_test.go index 7fb9c07..cef5c28 100644 --- a/node_test.go +++ b/node_test.go @@ -4,11 +4,11 @@ import ( "testing" "time" "crypto/rand" - "crypto/ecdsa" + "crypto/ed25519" ) func TestNodeDB(t *testing.T) { - ctx := logTestContext(t, []string{}) + ctx := logTestContext(t, []string{"signal", "node"}) node_type := NodeType("test") err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) fatalErr(t, err) @@ -26,13 +26,13 @@ func TestNodeRead(t *testing.T) { err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType}) fatalErr(t, err) - n1_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) - n2_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) + n2_pub, n2_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) - n1_id := KeyID(&n1_key.PublicKey) - n2_id := KeyID(&n2_key.PublicKey) + n1_id := KeyID(n1_pub) + n2_id := KeyID(n2_pub) ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N2: %s", n2_id) diff --git a/signal.go b/signal.go index 7e1e7b4..418171c 100644 --- a/signal.go +++ b/signal.go @@ -5,8 +5,8 @@ import ( "fmt" "encoding/json" "encoding/binary" - "crypto/sha512" - "crypto/ecdsa" + "crypto" + "crypto/ed25519" "crypto/ecdh" "crypto/rand" "crypto/aes" @@ -250,7 +250,7 @@ type ReadSignal struct { type AuthorizedSignal struct { BaseSignal - Principal *ecdsa.PublicKey + Principal ed25519.PublicKey Signal Signal Signature []byte } @@ -259,21 +259,20 @@ func (signal *AuthorizedSignal) Permission() Action { return AuthorizedSignalAction } -func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (AuthorizedSignal, error) { +func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (AuthorizedSignal, error) { sig_data, err := signal.Serialize() if err != nil { return AuthorizedSignal{}, err } - sig_hash := sha512.Sum512(sig_data) - sig, err := ecdsa.SignASN1(rand.Reader, principal, sig_hash[:]) + sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0)) if err != nil { return AuthorizedSignal{}, err } return AuthorizedSignal{ BaseSignal: NewDirectSignal(AuthorizedSignalType), - Principal: &principal.PublicKey, + Principal: principal.Public().(ed25519.PublicKey), Signal: signal, Signature: sig, }, nil @@ -315,7 +314,7 @@ func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType] type ECDHSignal struct { StringSignal Time time.Time - ECDSA *ecdsa.PublicKey + EDDSA ed25519.PublicKey ECDH *ecdh.PublicKey Signature []byte } @@ -323,7 +322,7 @@ type ECDHSignal struct { type ECDHSignalJSON struct { StringSignal Time time.Time `json:"time"` - ECDSA []byte `json:"ecdsa_pubkey"` + EDDSA []byte `json:"ecdsa_pubkey"` ECDH []byte `json:"ecdh_pubkey"` Signature []byte `json:"signature"` } @@ -333,7 +332,7 @@ func (signal *ECDHSignal) MarshalJSON() ([]byte, error) { StringSignal: signal.StringSignal, Time: signal.Time, ECDH: signal.ECDH.Bytes(), - ECDSA: signal.ECDH.Bytes(), + EDDSA: signal.ECDH.Bytes(), Signature: signal.Signature, }) } @@ -342,18 +341,6 @@ func (signal *ECDHSignal) Serialize() ([]byte, error) { return json.Marshal(signal) } -func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) { - time_bytes, err := now.MarshalJSON() - if err != nil { - return nil, err - } - - sig_data := append(ec_key.Bytes(), time_bytes...) - sig_hash := sha512.Sum512(sig_data) - - return sig_hash[:], nil -} - func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) { ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) if err != nil { @@ -361,13 +348,14 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e } now := time.Now() - - sig_hash, err := keyHash(now, ec_key.PublicKey()) + time_bytes, err := now.MarshalJSON() if err != nil { return ECDHSignal{}, nil, err } - sig, err := ecdsa.SignASN1(rand.Reader, node.Key, sig_hash) + sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) + + sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) if err != nil { return ECDHSignal{}, nil, err } @@ -378,7 +366,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e Str: "req", }, Time: now, - ECDSA: &node.Key.PublicKey, + EDDSA: node.Key.Public().(ed25519.PublicKey), ECDH: ec_key.PublicKey(), Signature: sig, }, ec_key, nil @@ -404,12 +392,14 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [ return ECDHSignal{}, nil, err } - key_hash, err := keyHash(now, ec_key.PublicKey()) + time_bytes, err := now.MarshalJSON() if err != nil { return ECDHSignal{}, nil, err } - sig, err := ecdsa.SignASN1(rand.Reader, node.Key, key_hash) + sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...) + + sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0)) if err != nil { return ECDHSignal{}, nil, err } @@ -420,7 +410,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [ Str: "resp", }, Time: now, - ECDSA: &node.Key.PublicKey, + EDDSA: node.Key.Public().(ed25519.PublicKey), ECDH: ec_key.PublicKey(), Signature: sig, }, shared_secret, nil @@ -436,14 +426,16 @@ func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) erro return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time) } - sig_hash, err := keyHash(sig.Time, sig.ECDH) + time_bytes, err := sig.Time.MarshalJSON() if err != nil { return err } - verified := ecdsa.VerifyASN1(sig.ECDSA, sig_hash, sig.Signature) + sig_data := append(sig.ECDH.Bytes(), time_bytes...) + + verified := ed25519.Verify(sig.EDDSA, sig_data, sig.Signature) if verified == false { - return fmt.Errorf("VERIFY_FAIL") + return fmt.Errorf("Failed to verify signature") } return nil