Changed session packet protocol

live
noah metz 2024-04-11 21:05:50 -06:00
parent f32087f729
commit 21e9794747
7 changed files with 281 additions and 291 deletions

@ -24,6 +24,7 @@ type Permission string
type Channel struct {
id ChannelID
name string
modes map[ModeID]Mode
sessions []SessionID
}

@ -3,19 +3,12 @@ package pnyx
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha512"
"net"
"sync"
"time"
"fmt"
"github.com/google/uuid"
)
type PeerID uuid.UUID
func(id PeerID) String() string {
return uuid.UUID(id).String()
}
type ClientState uint8
const (
@ -31,11 +24,6 @@ type Client struct {
Connection *net.UDPConn
}
func ID[T ~[16]byte, V ~[]byte](data V) T {
hash := sha512.Sum512(data)
return (T)(hash[0:16])
}
func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
if key == nil {
var err error
@ -80,16 +68,16 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
return nil, err
}
if response[0] != byte(SESSION_OPEN) {
if response[0] != byte(SESSION_OPENED) {
return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0])
}
session, err := ParseSessionOpen(ecdh_private, response[COMMAND_LENGTH:read])
session, err := ParseSessionOpened(nil, ecdh_private, response[COMMAND_LENGTH:read])
if err != nil {
return nil, err
}
session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), session.secret)
session_connect := NewSessionTimed(SESSION_CONNECT, key, &session, time.Now())
_, err = connection.Write(session_connect)
if err != nil {
return nil, err
@ -99,7 +87,7 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
if err != nil {
return nil, err
}
return &Client{
Key: key,
Session: session,
@ -137,7 +125,7 @@ func(client *Client) Close() error {
return fmt.Errorf("No connection to close")
}
client.Connection.Write(NewSessionClose(&client.Session))
client.Connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now()))
err := client.Connection.Close()
client.Connection = nil

@ -169,7 +169,7 @@ func main() {
break
}
data, err := pnyx.ParseSessionData(&client.Session, buf[pnyx.COMMAND_LENGTH + pnyx.ID_LENGTH:read])
data, err := pnyx.ParseSessionData(&client.Session, buf[pnyx.COMMAND_LENGTH + pnyx.SESSION_ID_LENGTH:read])
if err != nil {
fmt.Printf("ParseSessionData Error %s\n", err)
continue

@ -15,8 +15,8 @@ const (
PACKET_CHANNEL_PEER
CHANNEL_HEADER_LEN int = 5
CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + ID_LENGTH
CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + ID_LENGTH
CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH
CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH
)
type Payload interface {
@ -161,7 +161,7 @@ func(packet ChannelPeerPacket) MarshalBinary() ([]byte, error) {
func ParseChannelPeerPacket(data []byte) (ChannelPeerPacket, error) {
if len(data) < CHANNEL_PEER_LEN {
return ChannelPeerPacket{}, fmt.Errorf("Not enough bytes to parse ServerChannelPacket: %d/%d", len(data), ID_LENGTH)
return ChannelPeerPacket{}, fmt.Errorf("Not enough bytes to parse ServerChannelPacket: %d/%d", len(data), PEER_ID_LENGTH)
}
header, err := ParseChannelHeader(data)

@ -141,140 +141,138 @@ func(server *Server) Stop() error {
const SESSION_BUFFER_SIZE = 256
func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
session_open, ecdh_private, err := NewSessionOpen(server.key)
if err != nil {
return err
}
func handle_session_outgoing(session *ServerSession, server *Server) {
server.Log("Starting session outgoing goroutine %s", session.ID)
for true {
packet := <- session.OutgoingPackets
if packet == nil {
break
}
session, err := ParseSessionOpen(ecdh_private, client_session_open)
if err != nil {
return err
}
if session.remote == nil {
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
continue
}
server.sessions_lock.Lock()
server.sessions[session.ID] = &ServerSession{
Session: session,
LastSeen: time.Now(),
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
}
server.sessions_lock.Unlock()
packet_data, err := packet.MarshalBinary()
if err != nil {
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
continue
}
go func(session *ServerSession, server *Server){
server.Log("Starting session outgoing goroutine %s", session.ID)
for true {
packet := <- session.OutgoingPackets
if packet == nil {
break
}
encrypted, err := NewSessionData(&session.Session, packet_data)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
continue
}
if session.remote == nil {
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
continue
}
_, err = server.connection.WriteToUDP(encrypted, session.remote)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
continue
}
}
server.Log("Stopping session outgoing goroutine %s", session.ID)
}
packet_data, err := packet.MarshalBinary()
if err != nil {
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
continue
}
func handle_session_incoming(session *ServerSession, server *Server) {
server.Log("Starting session incoming goroutine %s", session.ID)
for true {
encrypted := <- session.IncomingPackets
if encrypted == nil {
break
}
encrypted, err := NewSessionData(&session.Session, packet_data)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
continue
}
data, err := ParseSessionData(&session.Session, encrypted)
if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue
}
_, err = server.connection.WriteToUDP(encrypted, session.remote)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
continue
}
packet, err := ParsePacket(data)
if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
}
server.Log("Stopping session outgoing goroutine %s", session.ID)
}(server.sessions[session.ID], server)
go func(session *ServerSession, server *Server){
server.Log("Starting session incoming goroutine %s", session.ID)
for true {
encrypted := <- session.IncomingPackets
if encrypted == nil {
break
}
data, err := ParseSessionData(&session.Session, encrypted)
if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue
switch packet := packet.(type) {
case ChannelDataPacket:
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result = channel.Data(&session.Session, packet.Mode, packet.Data)
}
server.channels_lock.RUnlock()
packet, err := ParsePacket(data)
if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
if exists == false {
server.Log("Packet for unknown channel %d", packet.Channel)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
}
switch packet := packet.(type) {
case ChannelDataPacket:
var result []SendPacket = nil
case ChannelCommandPacket:
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result = channel.Data(&session.Session, packet.Mode, packet.Data)
}
server.channels_lock.RUnlock()
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
}
server.channels_lock.RUnlock()
if exists == false {
server.Log("Packet for unknown channel %d", packet.Channel)
} else if err != nil {
server.Log("Error processing %+v - %s", packet, err)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
}
if exists == false {
server.Log("Packet for unknown channel %d", packet.Channel)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
}
default:
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
}
case ChannelCommandPacket:
var result []SendPacket = nil
}
server.Log("Stopping session incoming goroutine %s", session.ID)
}
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
}
server.channels_lock.RUnlock()
if exists == false {
server.Log("Packet for unknown channel %d", packet.Channel)
} else if err != nil {
server.Log("Error processing %+v - %s", packet, err)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
}
func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
session, session_opened, err := ParseSessionOpen(server.key, client_session_open)
if err != nil {
return err
}
default:
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
}
server.sessions_lock.Lock()
server.sessions[session.ID] = &ServerSession{
Session: session,
LastSeen: time.Now(),
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
}
server.sessions_lock.Unlock()
}
server.Log("Stopping session incoming goroutine %s", session.ID)
}(server.sessions[session.ID], server)
go handle_session_outgoing(server.sessions[session.ID], server)
go handle_session_incoming(server.sessions[session.ID], server)
_, err = server.connection.WriteToUDP(session_open, from)
_, err = server.connection.WriteToUDP(session_opened, from)
if err != nil {
return err
}
server.Log("Started session %s with %s", session.ID, session.Peer)
server.Log("Started session %s with %x", session.ID, session.Peer)
return nil
}
func(server *Server) handle_session_connect(client_session_connect []byte, from *net.UDPAddr) error {
session_id := SessionID(client_session_connect[:ID_LENGTH])
func(server *Server) handle_session_connect(session_connect []byte, from *net.UDPAddr) error {
session_id := SessionID(session_connect[:SESSION_ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
return fmt.Errorf("Session %s does not exist, can't connect", session_id)
}
_, err := ParseSessionConnect(client_session_connect[ID_LENGTH:], session.secret)
session_connected, err := ParseSessionTimed(SESSION_CONNECTED, server.key, session_connect, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
if err != nil {
return err
}
@ -283,13 +281,7 @@ func(server *Server) handle_session_connect(client_session_connect []byte, from
session.remote = from
session.LastSeen = time.Now()
// TODO: Make a better server hello
server_hello, err := NewSessionData(&session.Session, []byte("hello"))
if err != nil {
return err
}
_, err = server.connection.WriteToUDP(server_hello, session.remote)
_, err = server.connection.WriteToUDP(session_connected, session.remote)
if err != nil {
return err
}
@ -297,14 +289,14 @@ func(server *Server) handle_session_connect(client_session_connect []byte, from
return nil
}
func(server *Server) handle_session_close(client_session_close []byte, from *net.UDPAddr) error {
session_id := SessionID(client_session_close[:ID_LENGTH])
func(server *Server) handle_session_close(session_close []byte, from *net.UDPAddr) error {
session_id := SessionID(session_close[:SESSION_ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
return fmt.Errorf("Session %s does not exist, can't close", session_id)
}
err := ParseSessionClose(&session.Session, client_session_close[ID_LENGTH:])
session_closed, err := ParseSessionTimed(SESSION_CLOSED, server.key, session_close, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
if err != nil {
return err
}
@ -313,12 +305,17 @@ func(server *Server) handle_session_close(client_session_close []byte, from *net
server.close_session(session)
server.sessions_lock.Unlock()
_, err = server.connection.WriteToUDP(session_closed, session.remote)
if err != nil {
return err
}
server.Log("Session %s closed", session_id)
return nil
}
func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
session_id := SessionID(data[:ID_LENGTH])
session_id := SessionID(data[:SESSION_ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
return fmt.Errorf("Session %s does not exist, can't receive data", session_id)
@ -326,8 +323,8 @@ func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
session.LastSeen = time.Now()
buf_copy := make([]byte, len(data) - ID_LENGTH)
copy(buf_copy, data[ID_LENGTH:])
buf_copy := make([]byte, len(data) - SESSION_ID_LENGTH)
copy(buf_copy, data[SESSION_ID_LENGTH:])
select {
case session.IncomingPackets<-buf_copy:
@ -426,7 +423,6 @@ func(server *Server) cleanup_sessions() {
for server.active.Load() {
select {
case <-time.After(SESSION_TIMEOUT_CHECK):
server.Log("Running stale session check")
server.sessions_lock.Lock()
now := time.Now()
stale_sessions := make([]*ServerSession, 0, len(server.sessions))

@ -12,6 +12,7 @@ import (
"io"
mrand "math/rand"
"net"
"time"
"fmt"
"slices"
@ -25,35 +26,43 @@ func(id SessionID) String() string {
return uuid.UUID(id).String()
}
type PeerID [32]byte
type Session struct {
ID SessionID
remote *net.UDPAddr
Peer PeerID
secret []byte
cipher cipher.Block
iv_generator mrand.Source64
iv_generator *mrand.Rand
}
type SessionPacketType uint8
const (
ID_LENGTH = 16
REQ_ID_LENGTH = 16
PEER_ID_LENGTH = 32
SESSION_ID_LENGTH = 16
IV_LENGTH = aes.BlockSize
PUBKEY_LENGTH = 32
ECDH_LENGTH = 32
SIGNATURE_LENGTH = 64
HMAC_LENGTH = 64
COMMAND_LENGTH = 1
TIME_LENGTH = 8
SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length
SESSION_CLOSE_LENGTH = HMAC_LENGTH
SESSION_OPEN_LENGTH = PUBKEY_LENGTH + PUBKEY_LENGTH + SIGNATURE_LENGTH
SESSION_OPENED_LENGTH = PUBKEY_LENGTH + PUBKEY_LENGTH + SESSION_ID_LENGTH + SIGNATURE_LENGTH
SESSION_TIMED_LENGTH = SESSION_ID_LENGTH + TIME_LENGTH + SIGNATURE_LENGTH
SESSION_TIMED_RESP_LENGTH = TIME_LENGTH + SIGNATURE_LENGTH
/*
pnyx session packets
*/
SESSION_OPEN SessionPacketType = iota
SESSION_OPENED
SESSION_CONNECT
SESSION_CONNECTED
SESSION_CLOSE
SESSION_CLOSED
SESSION_DATA
)
@ -104,9 +113,9 @@ func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error)
return packet, private, nil
}
func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) {
func ParseSessionOpen(key ed25519.PrivateKey, session_open []byte) (Session, []byte, error) {
if len(session_open) != SESSION_OPEN_LENGTH {
return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
return Session{}, nil, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
}
cur := 0
@ -122,20 +131,92 @@ func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, er
cur += SIGNATURE_LENGTH
if ed25519.Verify(client_pubkey, digest, signature) == false {
return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed")
return Session{}, nil, fmt.Errorf("SESSION_OPEN signature verification failed")
}
session_secret, err := ECDH(client_ecdh, ecdh)
seed_bytes := make([]byte, 8)
read, err := rand.Read(seed_bytes)
if err != nil {
return Session{}, err
return Session{}, nil, err
} else if read != 8 {
return Session{}, nil, fmt.Errorf("Not enough entropy to create session")
}
session_id := ID[SessionID](session_secret)
client_id := ID[PeerID](client_pubkey)
rand_gen := mrand.New(mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64))
session_uuid, err := uuid.NewRandomFromReader(rand_gen)
if err != nil {
return Session{}, nil, err
}
session_id := SessionID(session_uuid)
ecdh_public, ecdh_private, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return Session{}, nil, err
}
session_secret, err := ECDH(client_ecdh, ecdh_private)
if err != nil {
return Session{}, nil, err
}
session_cipher, err := aes.NewCipher(session_secret)
if err != nil {
return Session{}, err
return Session{}, nil, err
}
session_opened := make([]byte, COMMAND_LENGTH + SESSION_OPENED_LENGTH)
cur = 0
session_opened[cur] = byte(SESSION_OPENED)
cur += COMMAND_LENGTH
copy(session_opened[cur:], key.Public().(ed25519.PublicKey))
cur += PUBKEY_LENGTH
copy(session_opened[cur:], ecdh_public)
cur += PUBKEY_LENGTH
copy(session_opened[cur:], session_id[:])
cur += SESSION_ID_LENGTH
signature = ed25519.Sign(key, session_opened[COMMAND_LENGTH:cur])
copy(session_opened[cur:], signature)
cur += SIGNATURE_LENGTH
return Session{
ID: session_id,
remote: nil,
Peer: PeerID(client_pubkey),
iv_generator: rand_gen,
cipher: session_cipher,
secret: session_secret,
}, session_opened, nil
}
func ParseSessionOpened(expected_pubkey ed25519.PublicKey, ecdh_private ed25519.PrivateKey, session_opened []byte) (Session, error) {
if len(session_opened) != SESSION_OPENED_LENGTH {
return Session{}, fmt.Errorf("Wrong SESSION_OPENED length: %d/%d", len(session_opened), SESSION_OPEN_LENGTH)
}
cur := 0
server_pubkey := (ed25519.PublicKey)(session_opened[cur:cur+PUBKEY_LENGTH])
cur += PUBKEY_LENGTH
server_ecdh := (ed25519.PublicKey)(session_opened[cur:cur+PUBKEY_LENGTH])
cur += PUBKEY_LENGTH
session_id := SessionID(session_opened[cur:cur+SESSION_ID_LENGTH])
cur += SESSION_ID_LENGTH
if expected_pubkey != nil && slices.Compare(server_pubkey, expected_pubkey) != 0 {
return Session{}, fmt.Errorf("server public key %x does not match expected %x", server_pubkey, expected_pubkey)
}
if ed25519.Verify(server_pubkey, session_opened[:cur], session_opened[cur:cur+SIGNATURE_LENGTH]) == false {
return Session{}, fmt.Errorf("session opened signature verification failed")
}
seed_bytes := make([]byte, 8)
@ -146,73 +227,77 @@ func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, er
return Session{}, fmt.Errorf("Not enough entropy to create session")
}
session_secret, err := ECDH(server_ecdh, ecdh_private)
if err != nil {
return Session{}, err
}
cipher, err := aes.NewCipher(session_secret)
if err != nil {
return Session{}, nil
}
return Session{
ID: session_id,
remote: nil,
Peer: client_id,
Peer: PeerID(server_pubkey),
secret: session_secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
cipher: cipher,
iv_generator: mrand.New(mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64)),
}, nil
}
func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CONNECT_LENGTH + len(address.String()))
func NewSessionTimed(command SessionPacketType, key ed25519.PrivateKey, session *Session, t time.Time) []byte {
packet := make([]byte, COMMAND_LENGTH + SESSION_TIMED_LENGTH)
cur := 0
packet[cur] = byte(SESSION_CONNECT)
packet[cur] = byte(command)
cur += COMMAND_LENGTH
session_id := [16]byte(ID[SessionID](session_secret))
copy(packet[cur:], session_id[:])
cur += ID_LENGTH
copy(packet[cur:], session.ID[:])
cur += SESSION_ID_LENGTH
binary.BigEndian.PutUint16(packet[cur:], uint16(len(address.String())))
cur += 2
copy(packet[cur:], []byte(address.String()))
cur += len(address.String())
binary.BigEndian.PutUint64(packet[cur:], uint64(t.UnixMilli()))
cur += TIME_LENGTH
hmac := sha512.Sum512(append(packet[COMMAND_LENGTH+ID_LENGTH:cur], session_secret...))
copy(packet[cur:], hmac[:])
signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur])
copy(packet[cur:], signature)
cur += SIGNATURE_LENGTH
return packet
}
func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) {
if len(session_connect) < SESSION_CONNECT_LENGTH {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
func ParseSessionTimed(resp_type SessionPacketType, key ed25519.PrivateKey, packet []byte, session *Session, last_t int64) ([]byte, error) {
if len(packet) != SESSION_TIMED_LENGTH {
return nil, fmt.Errorf("Bad timed packet length: %d/%d", len(packet), SESSION_TIMED_LENGTH)
}
cur := 0
address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2]))
cur += 2
cur := SESSION_ID_LENGTH
t := int64(binary.BigEndian.Uint64(packet[cur:]))
cur += TIME_LENGTH
if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length)
if t < last_t {
return nil, fmt.Errorf("Time in packet to old: %d < %d", t, last_t)
}
address := string(session_connect[cur:cur+address_length])
cur += address_length
hmac_digest := make([]byte, cur)
copy(hmac_digest, session_connect[:cur])
if ed25519.Verify(ed25519.PublicKey(session.Peer[:]), packet[:cur], packet[cur:cur+SIGNATURE_LENGTH]) == false {
return nil, fmt.Errorf("Failed to verify packet signature")
}
hmac := session_connect[cur:cur+HMAC_LENGTH]
cur += HMAC_LENGTH
resp := make([]byte, COMMAND_LENGTH + SESSION_TIMED_RESP_LENGTH)
cur = 0
calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return nil, fmt.Errorf("Session connect bad HMAC")
}
resp[cur] = byte(resp_type)
cur += COMMAND_LENGTH
binary.BigEndian.PutUint64(resp[cur:], uint64(t))
cur += TIME_LENGTH
addr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, fmt.Errorf("Error parsing return address: %w", err)
}
signature := ed25519.Sign(key, resp[COMMAND_LENGTH:cur])
copy(resp[cur:], signature)
cur += SIGNATURE_LENGTH
return addr, nil
return resp, nil
}
func NewSessionData(session *Session, packet []byte) ([]byte, error) {
@ -222,10 +307,10 @@ func NewSessionData(session *Session, packet []byte) ([]byte, error) {
}
stream := cipher.NewOFB(session.cipher, iv[:])
header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH)
header := make([]byte, COMMAND_LENGTH + SESSION_ID_LENGTH + IV_LENGTH)
header[0] = byte(SESSION_DATA)
copy(header[COMMAND_LENGTH:], session.ID[:])
copy(header[COMMAND_LENGTH+ID_LENGTH:], iv)
copy(header[COMMAND_LENGTH+SESSION_ID_LENGTH:], iv)
packet_encrypted := bytes.NewBuffer(header)
writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}
@ -260,23 +345,3 @@ func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) {
return data, nil
}
func NewSessionClose(session *Session) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH)
packet[0] = byte(SESSION_CLOSE)
copy(packet[1:], session.ID[:])
hmac := sha512.Sum512(append(session.ID[:], session.secret...))
copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:])
return packet
}
func ParseSessionClose(session *Session, hmac []byte) error {
calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return fmt.Errorf("Session Close HMAC validation failed")
} else {
return nil
}
}

@ -3,7 +3,6 @@ package pnyx
import (
"crypto/ed25519"
"crypto/rand"
"net"
"slices"
"testing"
)
@ -17,33 +16,28 @@ func fatalErr(t *testing.T, err error) {
func TestSessionOpen(t *testing.T) {
client_pubkey, client_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
client_id := ID[PeerID](client_pubkey)
server_pubkey, server_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
server_id := ID[PeerID](server_pubkey)
client_so, client_ecdh, err := NewSessionOpen(client_key)
session_open, client_ecdh, err := NewSessionOpen(client_key)
fatalErr(t, err)
if client_so[0] != byte(SESSION_OPEN) {
t.Fatalf("Session open command byte mismatch(%x != %x)", client_so[0], SESSION_OPEN)
if session_open[0] != byte(SESSION_OPEN) {
t.Fatalf("Session open command byte mismatch(%x != %x)", session_open[0], SESSION_OPEN)
}
server_so, server_ecdh, err := NewSessionOpen(server_key)
server_session, session_opened, err := ParseSessionOpen(server_key, session_open[COMMAND_LENGTH:])
fatalErr(t, err)
server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:])
client_session, err := ParseSessionOpened(server_pubkey, client_ecdh, session_opened[COMMAND_LENGTH:])
fatalErr(t, err)
client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:])
fatalErr(t, err)
if client_id != server_session.Peer {
t.Fatalf("Server session(%s) has wrong peer ID(%s)", server_session.Peer, client_id)
if PeerID(client_pubkey) != server_session.Peer {
t.Fatalf("Server session(%x) has wrong peer ID(%x)", server_session.Peer, client_pubkey)
}
if server_id != client_session.Peer {
t.Fatalf("Client session(%s) has wrong peer ID(%s)", server_session.Peer, client_id)
if PeerID(server_pubkey) != client_session.Peer {
t.Fatalf("Client session(%x) has wrong peer ID(%x)", client_session.Peer, server_pubkey)
}
if slices.Compare(server_session.secret, client_session.secret) != 0 {
@ -69,44 +63,20 @@ func TestECDH(t *testing.T) {
}
}
func TestSessionConnect(t *testing.T) {
secret := make([]byte, 32)
test_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8080")
fatalErr(t, err)
session_id := ID[SessionID](secret)
session_connect := NewSessionConnect(test_addr, secret)
if session_connect[0] != byte(SESSION_CONNECT) {
t.Fatalf("Session open command byte mismatch(%x != %x)", session_connect[0], SESSION_CONNECT)
} else if slices.Compare(session_connect[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH], session_id[:]) != 0 {
t.Fatal("Session open ID mismatch")
}
parsed_addr, err := ParseSessionConnect(session_connect[COMMAND_LENGTH + ID_LENGTH:], secret)
fatalErr(t, err)
if parsed_addr.String() != test_addr.String() {
t.Fatalf("Parsed address(%s) does not match test address(%s)", parsed_addr, test_addr)
}
}
func TestSessionData(t *testing.T) {
_, client_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
_, server_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
client_so, client_ecdh, err := NewSessionOpen(client_key)
server_pubkey, server_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
server_so, server_ecdh, err := NewSessionOpen(server_key)
session_open, client_ecdh, err := NewSessionOpen(client_key)
fatalErr(t, err)
server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:])
server_session, session_opened, err := ParseSessionOpen(server_key, session_open[COMMAND_LENGTH:])
fatalErr(t, err)
client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:])
client_session, err := ParseSessionOpened(server_pubkey, client_ecdh, session_opened[COMMAND_LENGTH:])
fatalErr(t, err)
message := []byte("hello")
@ -115,44 +85,14 @@ func TestSessionData(t *testing.T) {
if session_data[0] != byte(SESSION_DATA) {
t.Fatalf("Session data command byte mismatch(%x != %x)", session_data[0], SESSION_DATA)
} else if slices.Compare(session_data[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH], server_session.ID[:]) != 0 {
} else if SessionID(session_data[COMMAND_LENGTH:COMMAND_LENGTH+SESSION_ID_LENGTH]) != server_session.ID {
t.Fatal("Session data ID mismatch")
}
parsed_message, err := ParseSessionData(&client_session, session_data[COMMAND_LENGTH+ID_LENGTH:])
parsed_message, err := ParseSessionData(&client_session, session_data[COMMAND_LENGTH+SESSION_ID_LENGTH:])
fatalErr(t, err)
if slices.Compare(message, parsed_message) != 0 {
t.Fatalf("Parsed message(%s) != test message(%s)", parsed_message, message)
}
}
func TestSessionClose(t *testing.T) {
_, client_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
_, server_key, err := ed25519.GenerateKey(rand.Reader)
fatalErr(t, err)
client_so, client_ecdh, err := NewSessionOpen(client_key)
fatalErr(t, err)
server_so, server_ecdh, err := NewSessionOpen(server_key)
fatalErr(t, err)
server_session, err := ParseSessionOpen(server_ecdh, client_so[COMMAND_LENGTH:])
fatalErr(t, err)
client_session, err := ParseSessionOpen(client_ecdh, server_so[COMMAND_LENGTH:])
fatalErr(t, err)
session_close := NewSessionClose(&client_session)
if session_close[0] != byte(SESSION_CLOSE) {
t.Fatalf("Session close command byte mismatch(%x != %x)", session_close[0], SESSION_CLOSE)
} else if slices.Compare(session_close[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH], server_session.ID[:]) != 0 {
t.Fatal("Session close ID mismatch")
}
fatalErr(t, ParseSessionClose(&server_session, session_close[COMMAND_LENGTH+ID_LENGTH:]))
}