package pnyx

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"encoding/binary"
	"fmt"
	"hash/crc32"
	"io"
	mrand "math/rand"
	"net"
	"slices"

	"filippo.io/edwards25519"
)

// PacketType is the command byte written at offset 0 of every packet built
// in this file.
type PacketType uint8

const (
	ID_LENGTH        = 16
	IV_LENGTH        = aes.BlockSize
	PUBKEY_LENGTH    = 32
	ECDH_LENGTH      = 32
	SIGNATURE_LENGTH = 64
	HMAC_LENGTH      = 64 // size of a SHA-512 digest
	COMMAND_LENGTH   = 1

	// Fixed body sizes. SESSION_OPEN_LENGTH excludes the command byte;
	// SESSION_CONNECT_LENGTH and SESSION_CLOSE_LENGTH exclude the command
	// byte and the session ID.
	SESSION_OPEN_LENGTH    = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
	SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length
	SESSION_CLOSE_LENGTH   = HMAC_LENGTH

	// NOTE(review): iota counts every spec in this const block, so these
	// are 10, 11, 12 and 13, not 0..3. They are the on-wire command bytes:
	// adding, removing or moving a spec above them changes the protocol.
	SESSION_OPEN PacketType = iota
	SESSION_CONNECT
	SESSION_CLOSE
	SESSION_DATA
)

// ECDH computes a Curve25519 shared secret from a peer's Ed25519 public key
// and a local Ed25519 private key.
//
// The private scalar is the clamped first half of SHA-512(seed), as in
// Ed25519, and the result is the 32-byte Montgomery u-coordinate of
// scalar*public. It returns an error if public is not a valid point encoding,
// or if the result is all zeros (public was a small-order point, which would
// yield a secret anyone can predict).
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
	public_point, err := new(edwards25519.Point).SetBytes(public)
	if err != nil {
		return nil, err
	}

	h := sha512.Sum512(private.Seed())
	private_scalar, err := new(edwards25519.Scalar).SetBytesWithClamping(h[:32])
	if err != nil {
		return nil, err
	}

	shared := new(edwards25519.Point).ScalarMult(private_scalar, public_point).BytesMontgomery()

	// RFC 7748 section 6.1: abort on an all-zero shared secret. The clamped
	// scalar is a multiple of the cofactor, so every small-order public key
	// maps to the identity, whose Montgomery encoding is 32 zero bytes.
	var zero [32]byte
	if subtle.ConstantTimeCompare(shared, zero[:]) == 1 {
		return nil, fmt.Errorf("ECDH produced an all-zero shared secret")
	}
	return shared, nil
}

// NewSessionOpen builds a SESSION_OPEN packet signed with key. It returns the
// packet and a freshly generated ECDH private key, which the caller
// presumably keeps to derive the session secret.
//
// Layout: command byte | key's public key | ephemeral ECDH public key |
// Ed25519 signature by key over the two public keys.
func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) {
	if key == nil {
		return nil, nil, fmt.Errorf("Cannot create a SESSION_OPEN packet without a key")
	}

	public, private, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err)
	}

	packet := make([]byte, COMMAND_LENGTH+SESSION_OPEN_LENGTH)
	cur := 0
	packet[cur] = byte(SESSION_OPEN)
	cur += COMMAND_LENGTH

	copy(packet[cur:], key.Public().(ed25519.PublicKey))
	cur += PUBKEY_LENGTH
	copy(packet[cur:], public)
	cur += ECDH_LENGTH

	// The signature covers both public keys but not the command byte.
	signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur])
	copy(packet[cur:], signature)

	return packet, private, nil
}

// ParseSessionOpen validates a SESSION_OPEN body (the packet without its
// command byte) and derives the Session it opens.
//
// The signature must verify under the embedded client public key. The
// session secret is ECDH(client ECDH key, ecdh); it is used as the AES-256
// key and as the input to the session ID. The session's IV generator is a
// math/rand source seeded with 8 bytes from crypto/rand.
//
// NOTE(review): the raw ECDH output is used as the AES key with no KDF, and
// the same bytes feed ID[SessionID]. The session ID travels in clear in
// SESSION_DATA headers, so ID must be one-way. Confirm.
func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) {
	if len(session_open) != SESSION_OPEN_LENGTH {
		return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
	}

	cur := 0
	client_pubkey := ed25519.PublicKey(session_open[cur : cur+PUBKEY_LENGTH])
	cur += PUBKEY_LENGTH
	client_ecdh := ed25519.PublicKey(session_open[cur : cur+ECDH_LENGTH])
	cur += ECDH_LENGTH

	digest := session_open[:cur]
	signature := session_open[cur : cur+SIGNATURE_LENGTH]
	if !ed25519.Verify(client_pubkey, digest, signature) {
		return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed")
	}

	session_secret, err := ECDH(client_ecdh, ecdh)
	if err != nil {
		return Session{}, err
	}

	session_id := ID[SessionID](session_secret)
	client_id := ID[PeerID](client_pubkey)

	session_cipher, err := aes.NewCipher(session_secret)
	if err != nil {
		return Session{}, err
	}

	seed_bytes := make([]byte, 8)
	read, err := rand.Read(seed_bytes)
	if err != nil {
		return Session{}, err
	}
	if read != 8 {
		return Session{}, fmt.Errorf("Not enough entropy to create session")
	}

	return Session{
		ID:           session_id,
		remote:       nil,
		Peer:         client_id,
		secret:       session_secret,
		cipher:       session_cipher,
		iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
	}, nil
}

// NewSessionConnect builds a SESSION_CONNECT packet. It carries address as the
// return address for the session identified by session_secret.
//
// Layout: command byte | session ID | uint16 big-endian address length |
// address.String() | SHA-512(length | address | session_secret).
//
// NOTE(review): SHA-512(message || secret) is a secret-suffix MAC, not HMAC.
// crypto/hmac would be stronger, but switching changes the wire format, so
// ParseSessionConnect would have to change with it.
func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte {
	address_string := address.String()
	packet := make([]byte, COMMAND_LENGTH+ID_LENGTH+SESSION_CONNECT_LENGTH+len(address_string))

	cur := 0
	packet[cur] = byte(SESSION_CONNECT)
	cur += COMMAND_LENGTH

	session_id := [ID_LENGTH]byte(ID[SessionID](session_secret))
	copy(packet[cur:], session_id[:])
	cur += ID_LENGTH

	binary.BigEndian.PutUint16(packet[cur:], uint16(len(address_string)))
	cur += 2
	copy(packet[cur:], address_string)
	cur += len(address_string)

	// Clip so append allocates a fresh buffer instead of writing the secret
	// into packet's spare capacity (the slot reserved for the MAC).
	hmac := sha512.Sum512(append(slices.Clip(packet[COMMAND_LENGTH+ID_LENGTH:cur]), session_secret...))
	copy(packet[cur:], hmac[:])
	return packet
}

// ParseSessionConnect verifies a SESSION_CONNECT body against session_secret
// and returns the return address it carries. The body starts at the address
// length field, i.e. after the command byte and session ID written by
// NewSessionConnect.
//
// Note that net.ResolveUDPAddr may perform a DNS lookup if the address is a
// host name rather than an IP literal.
func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) {
	if len(session_connect) < SESSION_CONNECT_LENGTH {
		return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
	}

	cur := 0
	address_length := int(binary.BigEndian.Uint16(session_connect[cur : cur+2]))
	cur += 2
	if len(session_connect) != SESSION_CONNECT_LENGTH+address_length {
		return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH+address_length)
	}

	address := string(session_connect[cur : cur+address_length])
	cur += address_length

	hmac := session_connect[cur : cur+HMAC_LENGTH]
	// Clip so append copies instead of writing into the received MAC bytes.
	calculated_hmac := sha512.Sum512(append(slices.Clip(session_connect[:cur]), session_secret...))
	// Constant-time comparison: timing must not reveal how many leading MAC
	// bytes a forger got right.
	if subtle.ConstantTimeCompare(hmac, calculated_hmac[:]) != 1 {
		return nil, fmt.Errorf("Session connect bad HMAC")
	}

	addr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return nil, fmt.Errorf("Error parsing return address: %w", err)
	}
	return addr, nil
}

// NewSessionData encrypts packet for session and returns a SESSION_DATA
// packet. packet itself is not modified.
//
// Layout: command byte | session ID | IV | AES-OFB(packet | CRC-32(packet)),
// where the CRC-32 (IEEE) is appended big-endian and the IV is two
// consecutive outputs of session.iv_generator.
//
// NOTE(review): OFB is unauthenticated and CRC-32 is linear, so an attacker
// who flips ciphertext bits can also fix the checksum. This gives no
// protection against active tampering. An AEAD such as AES-GCM would, at the
// cost of a wire-format change. IV uniqueness also depends on the math/rand
// generator never repeating an output pair within a session.
func NewSessionData(session *Session, packet []byte) ([]byte, error) {
	iv := make([]byte, IV_LENGTH)
	for i := 0; i < IV_LENGTH/8; i++ {
		binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64())
	}
	stream := cipher.NewOFB(session.cipher, iv)

	header := make([]byte, COMMAND_LENGTH+ID_LENGTH+IV_LENGTH)
	header[0] = byte(SESSION_DATA)
	copy(header[COMMAND_LENGTH:], session.ID[:])
	copy(header[COMMAND_LENGTH+ID_LENGTH:], iv)

	// Clip so AppendUint32 allocates rather than writing the checksum into
	// spare capacity of the caller's slice.
	plaintext := binary.BigEndian.AppendUint32(slices.Clip(packet), crc32.ChecksumIEEE(packet))

	packet_encrypted := bytes.NewBuffer(header)
	writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}
	if _, err := io.Copy(writer, bytes.NewReader(plaintext)); err != nil {
		return nil, err
	}
	return packet_encrypted.Bytes(), nil
}

// ParseSessionData decrypts a SESSION_DATA body and returns the payload after
// checking its trailing CRC-32. encrypted must start with the IV, so the
// command byte and session ID written by NewSessionData have already been
// stripped.
//
// Input too short to hold an IV and a checksum is reported as an error
// instead of panicking, since it comes straight off the network.
func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) {
	const checksum_length = 4
	if len(encrypted) < IV_LENGTH+checksum_length {
		return nil, fmt.Errorf("SESSION_DATA too short: %d/%d", len(encrypted), IV_LENGTH+checksum_length)
	}

	iv := encrypted[:IV_LENGTH]
	stream := cipher.NewOFB(session.cipher, iv)

	var packet bytes.Buffer
	reader := &cipher.StreamReader{S: stream, R: bytes.NewReader(encrypted[IV_LENGTH:])}
	if _, err := io.Copy(&packet, reader); err != nil {
		return nil, err
	}

	packet_clear := packet.Bytes()
	split := len(packet_clear) - checksum_length
	data, trailer := packet_clear[:split], packet_clear[split:]

	checksum := binary.BigEndian.Uint32(trailer)
	calculated := crc32.ChecksumIEEE(data)
	if checksum != calculated {
		return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated)
	}
	return data, nil
}

// NewSessionClose builds a SESSION_CLOSE packet for session.
//
// Layout: command byte | session ID | SHA-512(session ID | session secret).
// The MAC covers only fixed per-session values, so every close packet for a
// given session is identical.
func NewSessionClose(session *Session) []byte {
	packet := make([]byte, COMMAND_LENGTH+ID_LENGTH+SESSION_CLOSE_LENGTH)
	packet[0] = byte(SESSION_CLOSE)
	copy(packet[COMMAND_LENGTH:], session.ID[:])

	// Clip so append never writes into the session's ID storage.
	hmac := sha512.Sum512(append(slices.Clip(session.ID[:]), session.secret...))
	copy(packet[COMMAND_LENGTH+ID_LENGTH:], hmac[:])
	return packet
}

// ParseSessionClose reports whether hmac equals SHA-512(session ID | session
// secret), the MAC NewSessionClose places after the command byte and session
// ID. It returns nil on a match.
func ParseSessionClose(session *Session, hmac []byte) error {
	calculated_hmac := sha512.Sum512(append(slices.Clip(session.ID[:]), session.secret...))
	// Constant-time comparison: timing must not leak MAC bytes.
	if subtle.ConstantTimeCompare(hmac, calculated_hmac[:]) != 1 {
		return fmt.Errorf("Session Close HMAC validation failed")
	}
	return nil
}