From 900b7502a9d0a3b4c17654fbb52e0aeaa9edf583 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Sun, 7 Apr 2024 13:27:28 -0600 Subject: [PATCH] Added test for SessionData --- channel.go | 2 +- client.go | 30 +++------------- packet.go | 94 +++++++++++++++++++++++++++++++++++++++++--------- packet_test.go | 58 +++++++++++++++++++++++++++---- server.go | 70 +++++++++++++++---------------------- 5 files changed, 163 insertions(+), 91 deletions(-) diff --git a/channel.go b/channel.go index 3c9821d..e1ce091 100644 --- a/channel.go +++ b/channel.go @@ -10,7 +10,7 @@ const RootChannelID = 0 type ModeID uint8 type CommandID uint8 -type PermissionMap map[ClientID]map[ModeID]map[CommandID]bool +type PermissionMap map[PeerID]map[ModeID]map[CommandID]bool type Channel struct { modes map[ModeID]Mode diff --git a/client.go b/client.go index a81355e..19d0619 100644 --- a/client.go +++ b/client.go @@ -3,10 +3,7 @@ package pnyx import ( "crypto/ed25519" "crypto/rand" - mrand "math/rand" - "encoding/binary" "crypto/sha512" - "crypto/aes" "net" "sync" "fmt" @@ -14,8 +11,8 @@ import ( "github.com/google/uuid" ) -type ClientID uuid.UUID -func(id ClientID) String() string { +type PeerID uuid.UUID +func(id PeerID) String() string { return uuid.UUID(id).String() } @@ -87,22 +84,12 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) } - server_pubkey, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read]) + session, err := ParseSessionOpen(ecdh_private, response[COMMAND_LENGTH:read]) if err != nil { return nil, err } - secret, err := ECDH(ecdh_public, ecdh_private) - if err != nil { - return nil, err - } - - session_cipher, err := aes.NewCipher(secret) - if err != nil { - return nil, err - } - - session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), secret) + session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), session.secret) _, err = connection.Write(session_connect) if err != nil { return nil, err @@ -115,14 +102,7 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { return &Client{ Key: key, - Session: Session{ - ID: ID[SessionID](secret), - remote: address, - Peer: ID[ClientID](server_pubkey), - secret: secret, - cipher: session_cipher, - iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), - }, + Session: session, Connection: connection, }, nil } diff --git a/packet.go b/packet.go index 4c86883..d3d830f 100644 --- a/packet.go +++ b/packet.go @@ -2,12 +2,15 @@ package pnyx import ( "bytes" + "crypto/aes" "crypto/cipher" "crypto/ed25519" "crypto/rand" "crypto/sha512" "encoding/binary" + "hash/crc32" "io" + mrand "math/rand" "net" "fmt" @@ -19,6 +22,7 @@ import ( type PacketType uint8 const ( ID_LENGTH = 16 + IV_LENGTH = aes.BlockSize PUBKEY_LENGTH = 32 ECDH_LENGTH = 32 SIGNATURE_LENGTH = 64 @@ -27,11 +31,11 @@ const ( SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length + SESSION_CLOSE_LENGTH = HMAC_LENGTH SESSION_OPEN PacketType = iota SESSION_CONNECT SESSION_CLOSE - SESSION_CLOSED SESSION_DATA ) @@ -82,9 +86,9 @@ func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) return packet, private, nil } -func ParseSessionOpen(session_open []byte) (ed25519.PublicKey, ed25519.PublicKey, error) { +func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) { if len(session_open) != SESSION_OPEN_LENGTH { - return nil, nil, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH) + return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH) } cur := 0 @@ -100,10 +104,38 @@ func ParseSessionOpen(session_open []byte) (ed25519.PublicKey, ed25519.PublicKey cur += SIGNATURE_LENGTH if ed25519.Verify(client_pubkey, digest, signature) == false { - return nil, nil, fmt.Errorf("SESSION_OPEN signature verification failed") + return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed") } - return client_pubkey, client_ecdh, nil + session_secret, err := ECDH(client_ecdh, ecdh) + if err != nil { + return Session{}, err + } + + session_id := ID[SessionID](session_secret) + client_id := ID[PeerID](client_pubkey) + + session_cipher, err := aes.NewCipher(session_secret) + if err != nil { + return Session{}, err + } + + seed_bytes := make([]byte, 8) + read, err := rand.Read(seed_bytes) + if err != nil { + return Session{}, err + } else if read != 8 { + return Session{}, fmt.Errorf("Not enough entropy to create session") + } + + return Session{ + ID: session_id, + remote: nil, + Peer: client_id, + secret: session_secret, + cipher: session_cipher, + iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), + }, nil } func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte { @@ -166,20 +198,22 @@ func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UD } func NewSessionData(session *Session, packet []byte) ([]byte, error) { - iv := make([]byte, 32) - for i := 0; i < 4; i++ { + iv := make([]byte, IV_LENGTH) + for i := 0; i < IV_LENGTH/8; i++ { binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64()) } stream := cipher.NewOFB(session.cipher, iv[:]) - header := make([]byte, COMMAND_LENGTH + ID_LENGTH) + header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH) header[0] = byte(SESSION_DATA) - copy(header[1:], session.ID[:]) + copy(header[COMMAND_LENGTH:], session.ID[:]) + copy(header[COMMAND_LENGTH+ID_LENGTH:], iv) packet_encrypted := bytes.NewBuffer(header) writer := &cipher.StreamWriter{S: stream, W: packet_encrypted} - _, err := io.Copy(writer, bytes.NewBuffer(packet)) + // Encrypt the packet with a crc32 checksum appended to the end + _, err := io.Copy(writer, bytes.NewBuffer(binary.BigEndian.AppendUint32(packet, crc32.ChecksumIEEE(packet)))) if err != nil { return nil, err } @@ -187,16 +221,44 @@ func NewSessionData(session *Session, packet []byte) ([]byte, error) { return packet_encrypted.Bytes(), nil } -func ParseSessionData(session *Session, data []byte) ([]byte, error) { - iv := data[0:32] +func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) { + iv := encrypted[0:IV_LENGTH] stream := cipher.NewOFB(session.cipher, iv) - var packet_clear bytes.Buffer - reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(data)} - _, err := io.Copy(&packet_clear, reader) + var packet bytes.Buffer + reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(encrypted[IV_LENGTH:])} + _, err := io.Copy(&packet, reader) if err != nil { return nil, err } - return packet_clear.Bytes(), nil + packet_clear := packet.Bytes() + checksum := binary.BigEndian.Uint32(packet_clear[len(packet_clear)-4:]) + data := packet_clear[:len(packet_clear)-4] + calculated := crc32.ChecksumIEEE(data) + if checksum != calculated { + return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated) + } + + return data, nil +} + +func NewSessionClose(session *Session) []byte { + packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH) + packet[0] = COMMAND_LENGTH + copy(packet[1:], session.ID[:]) + + hmac := sha512.Sum512(append(session.ID[:], session.secret...)) + copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:]) + + return packet +} + +func ParseSessionClose(session *Session, hmac []byte) error { + calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...)) + if slices.Compare(hmac, calculated_hmac[:]) != 0 { + return fmt.Errorf("Session Close HMAC validation failed") + } else { + return nil + } } diff --git a/packet_test.go b/packet_test.go index ed40909..3972f78 100644 --- a/packet_test.go +++ b/packet_test.go @@ -17,19 +17,34 @@ func fatalErr(t *testing.T, err error) { func TestSessionOpen(t *testing.T) { client_pubkey, client_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) + client_id := ID[PeerID](client_pubkey) - session_open, client_ecdh, err := NewSessionOpen(client_key) + server_pubkey, server_key, err := ed25519.GenerateKey(rand.Reader) fatalErr(t, err) + server_id := ID[PeerID](server_pubkey) - client_pubkey_parsed, client_ecdh_parsed, err := ParseSessionOpen(session_open[COMMAND_LENGTH:]) + client_so, client_ecdh, err := NewSessionOpen(client_key) fatalErr(t, err) - if slices.Compare(client_pubkey, client_pubkey_parsed) != 0 { - t.Fatalf("Client Pubkey %x does not match parsed %x", client_pubkey, client_pubkey_parsed) + server_so, server_ecdh, err := NewSessionOpen(server_key) + fatalErr(t, err) + + server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:]) + fatalErr(t, err) + + client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:]) + fatalErr(t, err) + + if client_id != server_session.Peer { + t.Fatalf("Server session(%s) has wrong peer ID(%s)", server_session.Peer, client_id) + } + + if server_id != client_session.Peer { + t.Fatalf("Client session(%s) has wrong peer ID(%s)", server_session.Peer, client_id) } - if slices.Compare(client_ecdh.Public().(ed25519.PublicKey), client_ecdh_parsed) != 0 { - t.Fatalf("Client Pubkey %x does not match parsed %x", client_pubkey, client_pubkey_parsed) + if slices.Compare(server_session.secret, client_session.secret) != 0 { + t.Fatalf("Client secret(%x) and server secret(%x) do not match", client_session.secret, server_session.secret) } } @@ -64,3 +79,34 @@ func TestSessionConnect(t *testing.T) { t.Fatalf("Parsed address(%s) does not match test address(%s)", parsed_addr, test_addr) } } + +func TestSessionData(t *testing.T) { + _, client_key, err := ed25519.GenerateKey(rand.Reader) + fatalErr(t, err) + + _, server_key, err := ed25519.GenerateKey(rand.Reader) + fatalErr(t, err) + + client_so, client_ecdh, err := NewSessionOpen(client_key) + fatalErr(t, err) + + server_so, server_ecdh, err := NewSessionOpen(server_key) + fatalErr(t, err) + + server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:]) + fatalErr(t, err) + + client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:]) + fatalErr(t, err) + + message := []byte("hello") + server_hello, err := NewSessionData(&server_session, message) + fatalErr(t, err) + + parsed_message, err := ParseSessionData(&client_session, server_hello[COMMAND_LENGTH+ID_LENGTH:]) + fatalErr(t, err) + + if slices.Compare(message, parsed_message) != 0 { + t.Fatalf("Parsed message(%s) != test message(%s)", parsed_message, message) + } +} diff --git a/server.go b/server.go index fe80de5..1473918 100644 --- a/server.go +++ b/server.go @@ -1,11 +1,9 @@ package pnyx import ( - "crypto/aes" "crypto/cipher" "crypto/ed25519" "crypto/rand" - "encoding/binary" "errors" "fmt" mrand "math/rand" @@ -28,7 +26,7 @@ func(id SessionID) String() string { type Session struct { ID SessionID remote *net.UDPAddr - Peer ClientID + Peer PeerID secret []byte cipher cipher.Block iv_generator mrand.Source64 @@ -67,7 +65,7 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) { // Check if the client has permission for the command on the channel // If it's not specified, check the permission on the parent -func Allowed(server *Server, client ClientID, channel_id ChannelID, mode ModeID, command CommandID) bool { +func Allowed(server *Server, client PeerID, channel_id ChannelID, mode ModeID, command CommandID) bool { channel, exists := server.channels[channel_id] if exists == false { return false @@ -129,61 +127,28 @@ func(server *Server) run() { var packet_type PacketType = PacketType(buf[0]) switch packet_type { case SESSION_OPEN: - client_pubkey, ecdh_public, err := ParseSessionOpen(buf[COMMAND_LENGTH:read]) - if err != nil { - server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) - continue - } - - client_id := ID[ClientID](client_pubkey) - session_open, ecdh_private, err := NewSessionOpen(server.key) if err != nil { server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) continue } - _, err = server.connection.WriteToUDP(session_open, from) - if err != nil { - server.Log("WriteToUDP error %s", err) - continue - } - - session_secret, err := ECDH(ecdh_public, ecdh_private) + session, err := ParseSessionOpen(ecdh_private, buf[COMMAND_LENGTH:read]) if err != nil { - server.Log("ECDH error %s", err) + server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) continue } - session_id := ID[SessionID](session_secret) - - session_cipher, err := aes.NewCipher(session_secret) - if err != nil { - server.Log("AES error %s", err) - continue - } + server.sessions[session.ID] = &session - seed_bytes := make([]byte, 8) - read, err := rand.Read(seed_bytes) + _, err = server.connection.WriteToUDP(session_open, from) if err != nil { - server.Log("IV Seed error: %s", err) - continue - } else if read != 8 { - server.Log("IV Seed error: not enough bytes read %d/4", read) + server.Log("WriteToUDP error %s", err) continue } - session := &Session{ - ID: session_id, - remote: nil, - Peer: client_id, - secret: session_secret, - cipher: session_cipher, - iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), - } - server.sessions[session_id] = session + server.Log("Started session %s with %s", session.ID, session.Peer) - server.Log("Started session %s with %s", session_id, client_id) case SESSION_CONNECT: session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) session, exists := server.sessions[session_id] @@ -201,7 +166,26 @@ func(server *Server) run() { session.remote = client_addr server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote) + // TODO: Send server hello back + case SESSION_CLOSE: + session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) + session, exists := server.sessions[session_id] + if exists == false { + server.Log("Session %s does not exist, can't close", session_id) + continue + } + + err := ParseSessionClose(session, buf[COMMAND_LENGTH+ID_LENGTH:]) + if err != nil { + server.Log("Session close error for %s - %s", session_id, err) + continue + } + + delete(server.sessions, session_id) + server.Log("Session %s closed", session_id) + + case SESSION_DATA: default: server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read]) }