Moved to x25519 for EC operations

gql_cataclysm
noah metz 2023-08-06 12:47:47 -06:00
parent 7d0af0eb5b
commit 1d91854f6f
6 changed files with 181 additions and 79 deletions

@ -7,7 +7,6 @@ import (
"errors" "errors"
"runtime" "runtime"
"crypto/sha512" "crypto/sha512"
"crypto/elliptic"
"crypto/ecdh" "crypto/ecdh"
"encoding/binary" "encoding/binary"
) )
@ -112,8 +111,6 @@ type Context struct {
Signals map[uint64]SignalInfo Signals map[uint64]SignalInfo
// Map between database type hashes and the registered info // Map between database type hashes and the registered info
Types map[uint64]*NodeInfo Types map[uint64]*NodeInfo
// Curve used for signature operations
ECDSA elliptic.Curve
// Curve used for ecdh operations // Curve used for ecdh operations
ECDH ecdh.Curve ECDH ecdh.Curve
// Routing map to all the nodes local to this context // Routing map to all the nodes local to this context
@ -252,8 +249,7 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
Types: map[uint64]*NodeInfo{}, Types: map[uint64]*NodeInfo{},
Signals: map[uint64]SignalInfo{}, Signals: map[uint64]SignalInfo{},
Nodes: map[NodeID]*Node{}, Nodes: map[NodeID]*Node{},
ECDH: ecdh.P256(), ECDH: ecdh.X25519(),
ECDSA: elliptic.P256(),
} }
var err error var err error
@ -293,6 +289,16 @@ func NewContext(db * badger.DB, log Logger) (*Context, error) {
return nil, err return nil, err
} }
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, NewSignalType)
if err != nil {
return nil, err
}
err = RegisterSignal[BaseSignal, *BaseSignal](ctx, StartSignalType)
if err != nil {
return nil, err
}
err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType}) err = ctx.RegisterNodeType(GQLNodeType, []ExtType{ACLExtType, GroupExtType, GQLExtType})
if err != nil { if err != nil {
return nil, err return nil, err

@ -18,7 +18,7 @@ import (
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"strings" "strings"
"crypto/ecdh" "crypto/ecdh"
"crypto/ecdsa" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
@ -185,7 +185,7 @@ type ResolveContext struct {
// Key for the user that made this request, to sign resolver requests // Key for the user that made this request, to sign resolver requests
// TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards // TODO: figure out some way to use a generated key so that the server can't impersonate the user afterwards
Key *ecdsa.PrivateKey Key ed25519.PrivateKey
} }
func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) { func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Request) (*ResolveContext, error) {
@ -199,12 +199,20 @@ func NewResolveContext(ctx *Context, server *Node, gql_ext *GQLExt, r *http.Requ
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ID from auth username: %s", username)
} }
key, err := x509.ParseECPrivateKey([]byte(key_bytes)) key_raw, err := x509.ParsePKCS8PrivateKey([]byte(key_bytes))
if err != nil { if err != nil {
return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ecdsa key from auth password: %s", key_bytes) return nil, fmt.Errorf("GQL_REQUEST_ERR: failed to parse ecdsa key from auth password: %s", key_bytes)
} }
key_id := KeyID(&key.PublicKey) var key ed25519.PrivateKey
switch k := key_raw.(type) {
case ed25519.PrivateKey:
key = k
default:
return nil, fmt.Errorf("GQL_REQUEST_ERR: wrong type for key: %s", reflect.TypeOf(key_raw))
}
key_id := KeyID(key.Public().(ed25519.PublicKey))
if auth_id != key_id { if auth_id != key_id {
return nil, fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id) return nil, fmt.Errorf("GQL_REQUEST_ERR: key_id(%s) != auth_id(%s)", auth_id, key_id)
} }
@ -1144,12 +1152,12 @@ func (ext *GQLExt) Deserialize(ctx *Context, data []byte) error {
func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) { func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, state string) (*GQLExt, error) {
if tls_cert == nil || tls_key == nil { if tls_cert == nil || tls_key == nil {
ssl_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) _, ssl_key, err := ed25519.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ssl_key_bytes, err := x509.MarshalECPrivateKey(ssl_key) ssl_key_bytes, err := x509.MarshalPKCS8PrivateKey(ssl_key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1172,7 +1180,7 @@ func NewGQLExt(ctx *Context, listen string, tls_cert []byte, tls_key []byte, sta
BasicConstraintsValid: true, BasicConstraintsValid: true,
} }
ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, &ssl_key.PublicKey, ssl_key) ssl_cert, err := x509.CreateCertificate(rand.Reader, &template, &template, ssl_key.Public(), ssl_key)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -68,7 +68,7 @@ func TestGQL(t *testing.T) {
req, err := http.NewRequest("GET", url, req_data) req, err := http.NewRequest("GET", url, req_data)
fatalErr(t, err) fatalErr(t, err)
key_bytes, err := x509.MarshalECPrivateKey(n1.Key) key_bytes, err := x509.MarshalPKCS8PrivateKey(n1.Key)
fatalErr(t, err) fatalErr(t, err)
req.SetBasicAuth(n1.ID.String(), string(key_bytes)) req.SetBasicAuth(n1.ID.String(), string(key_bytes))
resp, err := client.Do(req) resp, err := client.Do(req)

@ -10,8 +10,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"sync/atomic" "sync/atomic"
"crypto/ecdsa" "crypto/ed25519"
"crypto/elliptic"
"crypto/sha512" "crypto/sha512"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
@ -23,6 +22,7 @@ const (
// Total length of the node database header, has magic to verify and type_hash to map to load function // Total length of the node database header, has magic to verify and type_hash to map to load function
NODE_DB_HEADER_LEN = 28 NODE_DB_HEADER_LEN = 28
EXTENSION_DB_HEADER_LEN = 16 EXTENSION_DB_HEADER_LEN = 16
QSIGNAL_DB_HEADER_LEN = 40
) )
var ( var (
@ -106,7 +106,7 @@ type QueuedSignal struct {
// Default message channel size for nodes // Default message channel size for nodes
// Nodes represent a group of extensions that can be collectively addressed // Nodes represent a group of extensions that can be collectively addressed
type Node struct { type Node struct {
Key *ecdsa.PrivateKey Key ed25519.PrivateKey
ID NodeID ID NodeID
Type NodeType Type NodeType
Extensions map[ExtType]Extension Extensions map[ExtType]Extension
@ -197,14 +197,15 @@ func nodeLoop(ctx *Context, node *Node) error {
return fmt.Errorf("%s is already started, will not start again", node.ID) return fmt.Errorf("%s is already started, will not start again", node.ID)
} }
// Queue the signal for extensions to perform startup actions // Perform startup actions
node.QueueSignal(time.Now(), &StartSignal) node.Process(ctx, node.ID, &StartSignal)
for true { for true {
var signal Signal var signal Signal
var source NodeID var source NodeID
select { select {
case msg := <- node.MsgChan: case msg := <- node.MsgChan:
ctx.Log.Logf("signal", "NODE_MSG: %s - %+v", node.ID, msg)
signal = msg.Signal signal = msg.Signal
source = msg.Source source = msg.Source
err := Allowed(ctx, msg.Source, signal.Permission(), node) err := Allowed(ctx, msg.Source, signal.Permission(), node)
@ -234,7 +235,7 @@ func nodeLoop(ctx *Context, node *Node) error {
node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue) node.NextSignal, node.TimeoutChan = SoonestSignal(node.SignalQueue)
if node.NextSignal == nil { if node.NextSignal == nil {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil", node.ID, t, signal) ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL nil@%+v", node.ID, t, signal, node.TimeoutChan)
} else { } else {
ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, t, signal, node.NextSignal, node.NextSignal.Time) ctx.Log.Logf("node", "NODE_TIMEOUT(%s) - PROCESSING %+v@%s - NEXT_SIGNAL: %s@%s", node.ID, t, signal, node.NextSignal, node.NextSignal.Time)
} }
@ -250,8 +251,7 @@ func nodeLoop(ctx *Context, node *Node) error {
sig_data, err := sig.Signal.Serialize() sig_data, err := sig.Signal.Serialize()
if err != nil { if err != nil {
} else { } else {
sig_hash := sha512.Sum512(sig_data) validated := ed25519.Verify(sig.Principal, sig_data, sig.Signature)
validated := ecdsa.VerifyASN1(sig.Principal, sig_hash[:], sig.Signature)
if validated == true { if validated == true {
err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node) err := Allowed(ctx, KeyID(sig.Principal), sig.Signal.Permission(), node)
if err != nil { if err != nil {
@ -271,6 +271,8 @@ func nodeLoop(ctx *Context, node *Node) error {
} }
} }
ctx.Log.Logf("node", "NODE_SIGNAL_QUEUE[%s]: %+v", node.ID, node.SignalQueue)
// Handle special signal types // Handle special signal types
if signal.Type() == StopSignalType { if signal.Type() == StopSignalType {
resp := NewErrorSignal(signal.ID(), "stopped") resp := NewErrorSignal(signal.ID(), "stopped")
@ -349,8 +351,9 @@ func GetExt[T Extension](node *Node) (T, error) {
func (node *Node) Serialize() ([]byte, error) { func (node *Node) Serialize() ([]byte, error) {
extensions := make([]ExtensionDB, len(node.Extensions)) extensions := make([]ExtensionDB, len(node.Extensions))
qsignals := make([]QSignalDB, len(node.SignalQueue))
key_bytes, err := x509.MarshalECPrivateKey(node.Key) key_bytes, err := x509.MarshalPKCS8PrivateKey(node.Key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -365,7 +368,7 @@ func (node *Node) Serialize() ([]byte, error) {
NumQueuedSignals: uint32(len(node.SignalQueue)), NumQueuedSignals: uint32(len(node.SignalQueue)),
}, },
Extensions: extensions, Extensions: extensions,
QueuedSignals: node.SignalQueue, QueuedSignals: qsignals,
KeyBytes: key_bytes, KeyBytes: key_bytes,
} }
@ -385,26 +388,45 @@ func (node *Node) Serialize() ([]byte, error) {
i += 1 i += 1
} }
for i, qsignal := range(node.SignalQueue) {
ser, err := qsignal.Signal.Serialize()
if err != nil {
return nil, err
}
node_db.QueuedSignals[i] = QSignalDB{
QSignalDBHeader{
qsignal.Signal.ID(),
qsignal.Time,
Hash(qsignal.Signal.Type()),
uint64(len(ser)),
},
ser,
}
}
return node_db.Serialize(), nil return node_db.Serialize(), nil
} }
func KeyID(pub *ecdsa.PublicKey) NodeID { func KeyID(pub ed25519.PublicKey) NodeID {
ser := elliptic.Marshal(pub.Curve, pub.X, pub.Y) str := uuid.NewHash(sha512.New(), ZeroUUID, pub, 3)
str := uuid.NewHash(sha512.New(), ZeroUUID, ser, 3)
return NodeID(str) return NodeID(str)
} }
// Create a new node in memory and start it's event loop // Create a new node in memory and start it's event loop
// TODO: Change panics to errors // TODO: Change panics to errors
func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node { func NewNode(ctx *Context, key ed25519.PrivateKey, node_type NodeType, buffer_size uint32, queued_signals []QueuedSignal, extensions ...Extension) *Node {
var err error var err error
var public ed25519.PublicKey
if key == nil { if key == nil {
key, err = ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) public, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} else {
public = key.Public().(ed25519.PublicKey)
} }
id := KeyID(&key.PublicKey) id := KeyID(public)
_, exists := ctx.Node(id) _, exists := ctx.Node(id)
if exists == true { if exists == true {
panic("Attempted to create an existing node") panic("Attempted to create an existing node")
@ -432,9 +454,7 @@ func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_siz
} }
if queued_signals == nil { if queued_signals == nil {
queued_signals = []QueuedSignal{ queued_signals = []QueuedSignal{}
QueuedSignal{uuid.New(), &NewSignal, time.Now()},
}
} }
next_signal, timeout_chan := SoonestSignal(queued_signals) next_signal, timeout_chan := SoonestSignal(queued_signals)
@ -456,6 +476,8 @@ func NewNode(ctx *Context, key *ecdsa.PrivateKey, node_type NodeType, buffer_siz
panic(err) panic(err)
} }
node.Process(ctx, node.ID, &NewSignal)
go runNode(ctx, node) go runNode(ctx, node)
return node return node
@ -497,7 +519,7 @@ type NodeDBHeader struct {
type NodeDB struct { type NodeDB struct {
Header NodeDBHeader Header NodeDBHeader
QueuedSignals []QueuedSignal QueuedSignals []QSignalDB
Extensions []ExtensionDB Extensions []ExtensionDB
KeyBytes []byte KeyBytes []byte
} }
@ -551,9 +573,34 @@ func NewNodeDB(data []byte) (NodeDB, error) {
ptr += int(EXTENSION_DB_HEADER_LEN + length) ptr += int(EXTENSION_DB_HEADER_LEN + length)
} }
queued_signals := make([]QueuedSignal, num_queued_signals) queued_signals := make([]QSignalDB, num_queued_signals)
for i, _ := range(queued_signals) { for i, _ := range(queued_signals) {
queued_signals[i] = QueuedSignal{} cur := data[ptr:]
// TODO: load a header for each with the signal type and the signal length, so that it can be deserialized and incremented
// Right now causes segfault because any saved signal is loaded as nil
signal_id_bytes := cur[0:16]
unix_milli := binary.BigEndian.Uint64(cur[16:24])
type_hash := binary.BigEndian.Uint64(cur[24:32])
signal_size := binary.BigEndian.Uint64(cur[32:40])
signal_id, err := uuid.FromBytes(signal_id_bytes)
if err != nil {
return zero, err
}
signal_data := cur[QSIGNAL_DB_HEADER_LEN:(QSIGNAL_DB_HEADER_LEN+signal_size)]
queued_signals[i] = QSignalDB{
QSignalDBHeader{
signal_id,
time.UnixMilli(int64(unix_milli)),
type_hash,
signal_size,
},
signal_data,
}
ptr += QSIGNAL_DB_HEADER_LEN + int(signal_size)
} }
return NodeDB{ return NodeDB{
@ -592,10 +639,28 @@ func (node NodeDB) Serialize() []byte {
for _, extension := range(node.Extensions) { for _, extension := range(node.Extensions) {
ser = append(ser, extension.Serialize()...) ser = append(ser, extension.Serialize()...)
} }
for _, qsignal := range(node.QueuedSignals) {
ser = append(ser, qsignal.Serialize()...)
}
return ser return ser
} }
func (header QSignalDBHeader) Serialize() []byte {
ret := make([]byte, QSIGNAL_DB_HEADER_LEN)
id_ser, _ := header.SignalID.MarshalBinary()
copy(ret, id_ser)
binary.BigEndian.PutUint64(ret[16:24], uint64(header.Time.UnixMilli()))
binary.BigEndian.PutUint64(ret[24:32], header.TypeHash)
binary.BigEndian.PutUint64(ret[32:40], header.Length)
return ret
}
func (qsignal QSignalDB) Serialize() []byte {
header_bytes := qsignal.Header.Serialize()
return append(header_bytes, qsignal.Data...)
}
func (header ExtensionDBHeader) Serialize() []byte { func (header ExtensionDBHeader) Serialize() []byte {
ret := make([]byte, EXTENSION_DB_HEADER_LEN) ret := make([]byte, EXTENSION_DB_HEADER_LEN)
binary.BigEndian.PutUint64(ret[0:8], header.TypeHash) binary.BigEndian.PutUint64(ret[0:8], header.TypeHash)
@ -608,6 +673,18 @@ func (extension ExtensionDB) Serialize() []byte {
return append(header_bytes, extension.Data...) return append(header_bytes, extension.Data...)
} }
type QSignalDBHeader struct {
SignalID uuid.UUID
Time time.Time
TypeHash uint64
Length uint64
}
type QSignalDB struct {
Header QSignalDBHeader
Data []byte
}
type ExtensionDBHeader struct { type ExtensionDBHeader struct {
TypeHash uint64 TypeHash uint64
Length uint64 Length uint64
@ -663,16 +740,20 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, err return nil, err
} }
key, err := x509.ParseECPrivateKey(node_db.KeyBytes) key_raw, err := x509.ParsePKCS8PrivateKey(node_db.KeyBytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if key.PublicKey.Curve != ctx.ECDSA { var key ed25519.PrivateKey
return nil, fmt.Errorf("%s - wrong ec curve for private key: %+v, expected %+v", id, key.PublicKey.Curve, ctx.ECDSA) switch k := key_raw.(type) {
case ed25519.PrivateKey:
key = k
default:
return nil, fmt.Errorf("Wrong type for private key loaded: %s - %s", id, reflect.TypeOf(k))
} }
key_id := KeyID(&key.PublicKey) key_id := KeyID(key.Public().(ed25519.PublicKey))
if key_id != id { if key_id != id {
return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id) return nil, fmt.Errorf("KeyID(%s) != %s", key_id, id)
} }
@ -682,7 +763,22 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash) return nil, fmt.Errorf("Tried to load node %s of type 0x%x, which is not a known node type", id, node_db.Header.TypeHash)
} }
next_signal, timeout_chan := SoonestSignal(node_db.QueuedSignals) signal_queue := make([]QueuedSignal, node_db.Header.NumQueuedSignals)
for i, qsignal := range(node_db.QueuedSignals) {
sig_info, exists := ctx.Signals[qsignal.Header.TypeHash]
if exists == false {
return nil, fmt.Errorf("0x%x is not a known signal type", qsignal.Header.TypeHash)
}
signal, err := sig_info.Load(ctx, qsignal.Data)
if err != nil {
return nil, err
}
signal_queue[i] = QueuedSignal{qsignal.Header.SignalID, signal, qsignal.Header.Time}
}
next_signal, timeout_chan := SoonestSignal(signal_queue)
node := &Node{ node := &Node{
Key: key, Key: key,
ID: key_id, ID: key_id,
@ -691,7 +787,7 @@ func LoadNode(ctx * Context, id NodeID) (*Node, error) {
MsgChan: make(chan Msg, node_db.Header.BufferSize), MsgChan: make(chan Msg, node_db.Header.BufferSize),
BufferSize: node_db.Header.BufferSize, BufferSize: node_db.Header.BufferSize,
TimeoutChan: timeout_chan, TimeoutChan: timeout_chan,
SignalQueue: node_db.QueuedSignals, SignalQueue: signal_queue,
NextSignal: next_signal, NextSignal: next_signal,
} }
ctx.AddNode(id, node) ctx.AddNode(id, node)

@ -4,11 +4,11 @@ import (
"testing" "testing"
"time" "time"
"crypto/rand" "crypto/rand"
"crypto/ecdsa" "crypto/ed25519"
) )
func TestNodeDB(t *testing.T) { func TestNodeDB(t *testing.T) {
ctx := logTestContext(t, []string{}) ctx := logTestContext(t, []string{"signal", "node"})
node_type := NodeType("test") node_type := NodeType("test")
err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{GroupExtType})
fatalErr(t, err) fatalErr(t, err)
@ -26,13 +26,13 @@ func TestNodeRead(t *testing.T) {
err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType}) err := ctx.RegisterNodeType(node_type, []ExtType{ACLExtType, GroupExtType, ECDHExtType})
fatalErr(t, err) fatalErr(t, err)
n1_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) n1_pub, n1_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err) fatalErr(t, err)
n2_key, err := ecdsa.GenerateKey(ctx.ECDSA, rand.Reader) n2_pub, n2_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err) fatalErr(t, err)
n1_id := KeyID(&n1_key.PublicKey) n1_id := KeyID(n1_pub)
n2_id := KeyID(&n2_key.PublicKey) n2_id := KeyID(n2_pub)
ctx.Log.Logf("test", "N1: %s", n1_id) ctx.Log.Logf("test", "N1: %s", n1_id)
ctx.Log.Logf("test", "N2: %s", n2_id) ctx.Log.Logf("test", "N2: %s", n2_id)

@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"encoding/json" "encoding/json"
"encoding/binary" "encoding/binary"
"crypto/sha512" "crypto"
"crypto/ecdsa" "crypto/ed25519"
"crypto/ecdh" "crypto/ecdh"
"crypto/rand" "crypto/rand"
"crypto/aes" "crypto/aes"
@ -250,7 +250,7 @@ type ReadSignal struct {
type AuthorizedSignal struct { type AuthorizedSignal struct {
BaseSignal BaseSignal
Principal *ecdsa.PublicKey Principal ed25519.PublicKey
Signal Signal Signal Signal
Signature []byte Signature []byte
} }
@ -259,21 +259,20 @@ func (signal *AuthorizedSignal) Permission() Action {
return AuthorizedSignalAction return AuthorizedSignalAction
} }
func NewAuthorizedSignal(principal *ecdsa.PrivateKey, signal Signal) (AuthorizedSignal, error) { func NewAuthorizedSignal(principal ed25519.PrivateKey, signal Signal) (AuthorizedSignal, error) {
sig_data, err := signal.Serialize() sig_data, err := signal.Serialize()
if err != nil { if err != nil {
return AuthorizedSignal{}, err return AuthorizedSignal{}, err
} }
sig_hash := sha512.Sum512(sig_data) sig, err := principal.Sign(rand.Reader, sig_data, crypto.Hash(0))
sig, err := ecdsa.SignASN1(rand.Reader, principal, sig_hash[:])
if err != nil { if err != nil {
return AuthorizedSignal{}, err return AuthorizedSignal{}, err
} }
return AuthorizedSignal{ return AuthorizedSignal{
BaseSignal: NewDirectSignal(AuthorizedSignalType), BaseSignal: NewDirectSignal(AuthorizedSignalType),
Principal: &principal.PublicKey, Principal: principal.Public().(ed25519.PublicKey),
Signal: signal, Signal: signal,
Signature: sig, Signature: sig,
}, nil }, nil
@ -315,7 +314,7 @@ func NewReadResultSignal(req_id uuid.UUID, node_type NodeType, exts map[ExtType]
type ECDHSignal struct { type ECDHSignal struct {
StringSignal StringSignal
Time time.Time Time time.Time
ECDSA *ecdsa.PublicKey EDDSA ed25519.PublicKey
ECDH *ecdh.PublicKey ECDH *ecdh.PublicKey
Signature []byte Signature []byte
} }
@ -323,7 +322,7 @@ type ECDHSignal struct {
type ECDHSignalJSON struct { type ECDHSignalJSON struct {
StringSignal StringSignal
Time time.Time `json:"time"` Time time.Time `json:"time"`
ECDSA []byte `json:"ecdsa_pubkey"` EDDSA []byte `json:"ecdsa_pubkey"`
ECDH []byte `json:"ecdh_pubkey"` ECDH []byte `json:"ecdh_pubkey"`
Signature []byte `json:"signature"` Signature []byte `json:"signature"`
} }
@ -333,7 +332,7 @@ func (signal *ECDHSignal) MarshalJSON() ([]byte, error) {
StringSignal: signal.StringSignal, StringSignal: signal.StringSignal,
Time: signal.Time, Time: signal.Time,
ECDH: signal.ECDH.Bytes(), ECDH: signal.ECDH.Bytes(),
ECDSA: signal.ECDH.Bytes(), EDDSA: signal.ECDH.Bytes(),
Signature: signal.Signature, Signature: signal.Signature,
}) })
} }
@ -342,18 +341,6 @@ func (signal *ECDHSignal) Serialize() ([]byte, error) {
return json.Marshal(signal) return json.Marshal(signal)
} }
func keyHash(now time.Time, ec_key *ecdh.PublicKey) ([]byte, error) {
time_bytes, err := now.MarshalJSON()
if err != nil {
return nil, err
}
sig_data := append(ec_key.Bytes(), time_bytes...)
sig_hash := sha512.Sum512(sig_data)
return sig_hash[:], nil
}
func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) { func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, error) {
ec_key, err := ctx.ECDH.GenerateKey(rand.Reader) ec_key, err := ctx.ECDH.GenerateKey(rand.Reader)
if err != nil { if err != nil {
@ -361,13 +348,14 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e
} }
now := time.Now() now := time.Now()
time_bytes, err := now.MarshalJSON()
sig_hash, err := keyHash(now, ec_key.PublicKey())
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
sig, err := ecdsa.SignASN1(rand.Reader, node.Key, sig_hash) sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
@ -378,7 +366,7 @@ func NewECDHReqSignal(ctx *Context, node *Node) (ECDHSignal, *ecdh.PrivateKey, e
Str: "req", Str: "req",
}, },
Time: now, Time: now,
ECDSA: &node.Key.PublicKey, EDDSA: node.Key.Public().(ed25519.PublicKey),
ECDH: ec_key.PublicKey(), ECDH: ec_key.PublicKey(),
Signature: sig, Signature: sig,
}, ec_key, nil }, ec_key, nil
@ -404,12 +392,14 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
key_hash, err := keyHash(now, ec_key.PublicKey()) time_bytes, err := now.MarshalJSON()
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
sig, err := ecdsa.SignASN1(rand.Reader, node.Key, key_hash) sig_data := append(ec_key.PublicKey().Bytes(), time_bytes...)
sig, err := node.Key.Sign(rand.Reader, sig_data, crypto.Hash(0))
if err != nil { if err != nil {
return ECDHSignal{}, nil, err return ECDHSignal{}, nil, err
} }
@ -420,7 +410,7 @@ func NewECDHRespSignal(ctx *Context, node *Node, req *ECDHSignal) (ECDHSignal, [
Str: "resp", Str: "resp",
}, },
Time: now, Time: now,
ECDSA: &node.Key.PublicKey, EDDSA: node.Key.Public().(ed25519.PublicKey),
ECDH: ec_key.PublicKey(), ECDH: ec_key.PublicKey(),
Signature: sig, Signature: sig,
}, shared_secret, nil }, shared_secret, nil
@ -436,14 +426,16 @@ func VerifyECDHSignal(now time.Time, sig *ECDHSignal, window time.Duration) erro
return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time) return fmt.Errorf("TIME_TOO_EARLY: %+v", sig.Time)
} }
sig_hash, err := keyHash(sig.Time, sig.ECDH) time_bytes, err := sig.Time.MarshalJSON()
if err != nil { if err != nil {
return err return err
} }
verified := ecdsa.VerifyASN1(sig.ECDSA, sig_hash, sig.Signature) sig_data := append(sig.ECDH.Bytes(), time_bytes...)
verified := ed25519.Verify(sig.EDDSA, sig_data, sig.Signature)
if verified == false { if verified == false {
return fmt.Errorf("VERIFY_FAIL") return fmt.Errorf("Failed to verify signature")
} }
return nil return nil