Added test for SessionData

live
noah metz 2024-04-07 13:27:28 -06:00
parent 118975f490
commit 900b7502a9
5 changed files with 163 additions and 91 deletions

@ -10,7 +10,7 @@ const RootChannelID = 0
type ModeID uint8
type CommandID uint8
type PermissionMap map[ClientID]map[ModeID]map[CommandID]bool
type PermissionMap map[PeerID]map[ModeID]map[CommandID]bool
type Channel struct {
modes map[ModeID]Mode

@ -3,10 +3,7 @@ package pnyx
import (
"crypto/ed25519"
"crypto/rand"
mrand "math/rand"
"encoding/binary"
"crypto/sha512"
"crypto/aes"
"net"
"sync"
"fmt"
@ -14,8 +11,8 @@ import (
"github.com/google/uuid"
)
type ClientID uuid.UUID
func(id ClientID) String() string {
type PeerID uuid.UUID
func(id PeerID) String() string {
return uuid.UUID(id).String()
}
@ -87,22 +84,12 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0])
}
server_pubkey, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read])
session, err := ParseSessionOpen(ecdh_private, response[COMMAND_LENGTH:read])
if err != nil {
return nil, err
}
secret, err := ECDH(ecdh_public, ecdh_private)
if err != nil {
return nil, err
}
session_cipher, err := aes.NewCipher(secret)
if err != nil {
return nil, err
}
session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), secret)
session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), session.secret)
_, err = connection.Write(session_connect)
if err != nil {
return nil, err
@ -115,14 +102,7 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
return &Client{
Key: key,
Session: Session{
ID: ID[SessionID](secret),
remote: address,
Peer: ID[ClientID](server_pubkey),
secret: secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
},
Session: session,
Connection: connection,
}, nil
}

@ -2,12 +2,15 @@ package pnyx
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"crypto/sha512"
"encoding/binary"
"hash/crc32"
"io"
mrand "math/rand"
"net"
"fmt"
@ -19,6 +22,7 @@ import (
type PacketType uint8
const (
ID_LENGTH = 16
IV_LENGTH = aes.BlockSize
PUBKEY_LENGTH = 32
ECDH_LENGTH = 32
SIGNATURE_LENGTH = 64
@ -27,11 +31,11 @@ const (
SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length
SESSION_CLOSE_LENGTH = HMAC_LENGTH
SESSION_OPEN PacketType = iota
SESSION_CONNECT
SESSION_CLOSE
SESSION_CLOSED
SESSION_DATA
)
@ -82,9 +86,9 @@ func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error)
return packet, private, nil
}
func ParseSessionOpen(session_open []byte) (ed25519.PublicKey, ed25519.PublicKey, error) {
func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) {
if len(session_open) != SESSION_OPEN_LENGTH {
return nil, nil, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
}
cur := 0
@ -100,10 +104,38 @@ func ParseSessionOpen(session_open []byte) (ed25519.PublicKey, ed25519.PublicKey
cur += SIGNATURE_LENGTH
if ed25519.Verify(client_pubkey, digest, signature) == false {
return nil, nil, fmt.Errorf("SESSION_OPEN signature verification failed")
return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed")
}
return client_pubkey, client_ecdh, nil
session_secret, err := ECDH(client_ecdh, ecdh)
if err != nil {
return Session{}, err
}
session_id := ID[SessionID](session_secret)
client_id := ID[PeerID](client_pubkey)
session_cipher, err := aes.NewCipher(session_secret)
if err != nil {
return Session{}, err
}
seed_bytes := make([]byte, 8)
read, err := rand.Read(seed_bytes)
if err != nil {
return Session{}, err
} else if read != 8 {
return Session{}, fmt.Errorf("Not enough entropy to create session")
}
return Session{
ID: session_id,
remote: nil,
Peer: client_id,
secret: session_secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
}, nil
}
func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte {
@ -166,20 +198,22 @@ func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UD
}
func NewSessionData(session *Session, packet []byte) ([]byte, error) {
iv := make([]byte, 32)
for i := 0; i < 4; i++ {
iv := make([]byte, IV_LENGTH)
for i := 0; i < IV_LENGTH/8; i++ {
binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64())
}
stream := cipher.NewOFB(session.cipher, iv[:])
header := make([]byte, COMMAND_LENGTH + ID_LENGTH)
header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH)
header[0] = byte(SESSION_DATA)
copy(header[1:], session.ID[:])
copy(header[COMMAND_LENGTH:], session.ID[:])
copy(header[COMMAND_LENGTH+ID_LENGTH:], iv)
packet_encrypted := bytes.NewBuffer(header)
writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}
_, err := io.Copy(writer, bytes.NewBuffer(packet))
// Encrypt the packet with a crc32 checksum appended to the end
_, err := io.Copy(writer, bytes.NewBuffer(binary.BigEndian.AppendUint32(packet, crc32.ChecksumIEEE(packet))))
if err != nil {
return nil, err
}
@ -187,16 +221,44 @@ func NewSessionData(session *Session, packet []byte) ([]byte, error) {
return packet_encrypted.Bytes(), nil
}
func ParseSessionData(session *Session, data []byte) ([]byte, error) {
iv := data[0:32]
func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) {
iv := encrypted[0:IV_LENGTH]
stream := cipher.NewOFB(session.cipher, iv)
var packet_clear bytes.Buffer
reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(data)}
_, err := io.Copy(&packet_clear, reader)
var packet bytes.Buffer
reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(encrypted[IV_LENGTH:])}
_, err := io.Copy(&packet, reader)
if err != nil {
return nil, err
}
return packet_clear.Bytes(), nil
packet_clear := packet.Bytes()
checksum := binary.BigEndian.Uint32(packet_clear[len(packet_clear)-4:])
data := packet_clear[:len(packet_clear)-4]
calculated := crc32.ChecksumIEEE(data)
if checksum != calculated {
return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated)
}
return data, nil
}
func NewSessionClose(session *Session) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH)
packet[0] = COMMAND_LENGTH
copy(packet[1:], session.ID[:])
hmac := sha512.Sum512(append(session.ID[:], session.secret...))
copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:])
return packet
}
func ParseSessionClose(session *Session, hmac []byte) error {
calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return fmt.Errorf("Session Close HMAC validation failed")
} else {
return nil
}
}

@ -17,19 +17,34 @@ func fatalErr(t *testing.T, err error) {
func TestSessionOpen(t *testing.T) {
client_pubkey, client_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
client_id := ID[PeerID](client_pubkey)
session_open, client_ecdh, err := NewSessionOpen(client_key)
server_pubkey, server_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
server_id := ID[PeerID](server_pubkey)
client_pubkey_parsed, client_ecdh_parsed, err := ParseSessionOpen(session_open[COMMAND_LENGTH:])
client_so, client_ecdh, err := NewSessionOpen(client_key)
fatalErr(t, err)
if slices.Compare(client_pubkey, client_pubkey_parsed) != 0 {
t.Fatalf("Client Pubkey %x does not match parsed %x", client_pubkey, client_pubkey_parsed)
server_so, server_ecdh, err := NewSessionOpen(server_key)
fatalErr(t, err)
server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:])
fatalErr(t, err)
client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:])
fatalErr(t, err)
if client_id != server_session.Peer {
t.Fatalf("Server session(%s) has wrong peer ID(%s)", server_session.Peer, client_id)
}
if server_id != client_session.Peer {
t.Fatalf("Client session(%s) has wrong peer ID(%s)", server_session.Peer, client_id)
}
if slices.Compare(client_ecdh.Public().(ed25519.PublicKey), client_ecdh_parsed) != 0 {
t.Fatalf("Client Pubkey %x does not match parsed %x", client_pubkey, client_pubkey_parsed)
if slices.Compare(server_session.secret, client_session.secret) != 0 {
t.Fatalf("Client secret(%x) and server secret(%x) do not match", client_session.secret, server_session.secret)
}
}
@ -64,3 +79,34 @@ func TestSessionConnect(t *testing.T) {
t.Fatalf("Parsed address(%s) does not match test address(%s)", parsed_addr, test_addr)
}
}
func TestSessionData(t *testing.T) {
_, client_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
_, server_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
client_so, client_ecdh, err := NewSessionOpen(client_key)
fatalErr(t, err)
server_so, server_ecdh, err := NewSessionOpen(server_key)
fatalErr(t, err)
server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:])
fatalErr(t, err)
client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:])
fatalErr(t, err)
message := []byte("hello")
server_hello, err := NewSessionData(&server_session, message)
fatalErr(t, err)
parsed_message, err := ParseSessionData(&client_session, server_hello[COMMAND_LENGTH+ID_LENGTH:])
fatalErr(t, err)
if slices.Compare(message, parsed_message) != 0 {
t.Fatalf("Parsed message(%s) != test message(%s)", parsed_message, message)
}
}

@ -1,11 +1,9 @@
package pnyx
import (
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
mrand "math/rand"
@ -28,7 +26,7 @@ func(id SessionID) String() string {
type Session struct {
ID SessionID
remote *net.UDPAddr
Peer ClientID
Peer PeerID
secret []byte
cipher cipher.Block
iv_generator mrand.Source64
@ -67,7 +65,7 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) {
// Check if the client has permission for the command on the channel
// If it's not specified, check the permission on the parent
func Allowed(server *Server, client ClientID, channel_id ChannelID, mode ModeID, command CommandID) bool {
func Allowed(server *Server, client PeerID, channel_id ChannelID, mode ModeID, command CommandID) bool {
channel, exists := server.channels[channel_id]
if exists == false {
return false
@ -129,61 +127,28 @@ func(server *Server) run() {
var packet_type PacketType = PacketType(buf[0])
switch packet_type {
case SESSION_OPEN:
client_pubkey, ecdh_public, err := ParseSessionOpen(buf[COMMAND_LENGTH:read])
if err != nil {
server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
client_id := ID[ClientID](client_pubkey)
session_open, ecdh_private, err := NewSessionOpen(server.key)
if err != nil {
server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
_, err = server.connection.WriteToUDP(session_open, from)
session, err := ParseSessionOpen(ecdh_private, buf[COMMAND_LENGTH:read])
if err != nil {
server.Log("WriteToUDP error %s", err)
continue
}
session_secret, err := ECDH(ecdh_public, ecdh_private)
if err != nil {
server.Log("ECDH error %s", err)
server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
session_id := ID[SessionID](session_secret)
session_cipher, err := aes.NewCipher(session_secret)
if err != nil {
server.Log("AES error %s", err)
continue
}
server.sessions[session.ID] = &session
seed_bytes := make([]byte, 8)
read, err := rand.Read(seed_bytes)
_, err = server.connection.WriteToUDP(session_open, from)
if err != nil {
server.Log("IV Seed error: %s", err)
continue
} else if read != 8 {
server.Log("IV Seed error: not enough bytes read %d/4", read)
server.Log("WriteToUDP error %s", err)
continue
}
session := &Session{
ID: session_id,
remote: nil,
Peer: client_id,
secret: session_secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
}
server.sessions[session_id] = session
server.Log("Started session %s with %s", session.ID, session.Peer)
server.Log("Started session %s with %s", session_id, client_id)
case SESSION_CONNECT:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
@ -201,7 +166,26 @@ func(server *Server) run() {
session.remote = client_addr
server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote)
// TODO: Send server hello back
case SESSION_CLOSE:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't close", session_id)
continue
}
err := ParseSessionClose(session, buf[COMMAND_LENGTH+ID_LENGTH:])
if err != nil {
server.Log("Session close error for %s - %s", session_id, err)
continue
}
delete(server.sessions, session_id)
server.Log("Session %s closed", session_id)
case SESSION_DATA:
default:
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
}