Made the default client connect to channel 0 and stream audio data over the raw channel

live
noah metz 2024-04-08 11:28:52 -06:00
parent 37d42b5a9b
commit 4a2ce02617
11 changed files with 770 additions and 311 deletions

@ -1,22 +1,79 @@
package pnyx
import (
"slices"
)
type ChannelID uint32
const RootChannelID = 0
const (
RootChannelID ChannelID = 0
MODE_RAW ModeID = iota
MODE_COMMAND_DATA byte = 0x00
MODE_COMMAND_JOIN = 0x01
MODE_COMMAND_LEAVE = 0x02
)
type ModeID uint8
type CommandID uint8
type PermissionMap map[PeerID]map[ModeID]map[CommandID]bool
type Channel struct {
modes map[ModeID]Mode
permissions PermissionMap
parent ChannelID
}
type SendPacket struct {
Packet *Packet
Session SessionID
}
type Mode interface {
// Process takes incoming packets from a session and returns a list of packets to send
Process(*Session, *Packet) []SendPacket
}
func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPacket {
send_packets := make([]SendPacket, len(sessions))
for i, session_id := range(sessions) {
if session_id == session.ID {
continue
}
send_packets[i] = SendPacket{
Packet: packet,
Session: session_id,
}
}
return send_packets
}
type RawMode struct {
Sessions []SessionID
}
func(mode *RawMode) Process(session *Session, packet *Packet) []SendPacket {
switch packet.Command {
case MODE_COMMAND_JOIN:
if slices.Contains(mode.Sessions, session.ID) == false {
mode.Sessions = append(mode.Sessions, session.ID)
}
case MODE_COMMAND_LEAVE:
idx := slices.Index(mode.Sessions, session.ID)
if idx != -1 {
mode.Sessions = slices.Delete(mode.Sessions, idx, idx+1)
}
case MODE_COMMAND_DATA:
if slices.Contains(mode.Sessions, session.ID) {
new_packet := &Packet{
Channel: packet.Channel,
Mode: packet.Mode,
Command: MODE_COMMAND_DATA,
Data: append(session.Peer[:], packet.Data...),
}
return multiplex(session, new_packet, mode.Sessions)
}
}
return nil
}

@ -106,3 +106,40 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
Connection: connection,
}, nil
}
func(client *Client) Send(packet Packet) error {
client.ConnectionLock.Lock()
defer client.ConnectionLock.Unlock()
if client.Connection == nil {
return fmt.Errorf("Client is not connected")
}
data, err := packet.MarshalBinary()
if err != nil {
return err
}
wrapped, err := NewSessionData(&client.Session, data)
if err != nil {
return err
}
_, err = client.Connection.Write(wrapped)
return err
}
func(client *Client) Close() error {
client.ConnectionLock.Lock()
defer client.ConnectionLock.Unlock()
if client.Connection == nil {
return fmt.Errorf("No connection to close")
}
client.Connection.Write(NewSessionClose(&client.Session))
err := client.Connection.Close()
client.Connection = nil
return err
}

@ -5,13 +5,156 @@ import (
"os"
"git.metznet.ca/MetzNet/pnyx"
"github.com/gen2brain/malgo"
)
func main() {
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil)
if err != nil {
panic(err)
}
defer ctx.Free()
defer ctx.Uninit()
inDeviceConfig := malgo.DefaultDeviceConfig(malgo.Capture)
inDeviceConfig.Capture.Format = malgo.FormatF32
inDeviceConfig.Capture.Channels = 1
inDeviceConfig.SampleRate = 44100
inDeviceConfig.PeriodSizeInFrames = 100
inDeviceConfig.Alsa.NoMMap = 1
inDeviceConfig.Capture.ShareMode = malgo.Shared
outDeviceConfig := malgo.DefaultDeviceConfig(malgo.Playback)
outDeviceConfig.Playback.Format = malgo.FormatF32
outDeviceConfig.Playback.Channels = 1
outDeviceConfig.SampleRate = 44100
outDeviceConfig.PeriodSizeInFrames = 100
outDeviceConfig.Alsa.NoMMap = 1
outDeviceConfig.Playback.ShareMode = malgo.Shared
mic := make(chan []byte, 0)
speaker := make(chan []byte, 0)
onSendFrames := func(output_samples []byte, input_samples []byte, framecount uint32) {
select {
case data := <- speaker:
copy(output_samples, data)
}
}
playbackCallbacks := malgo.DeviceCallbacks{
Data: onSendFrames,
}
fmt.Printf("Creating playback device\n")
outDevice, err := malgo.InitDevice(ctx.Context, outDeviceConfig, playbackCallbacks)
if err != nil {
panic(err)
}
err = outDevice.Start()
if err != nil {
panic(err)
}
defer outDevice.Uninit()
defer outDevice.Stop()
onRecvFrames := func(output_samples []byte, input_samples []byte, framecount uint32) {
data := make([]byte, len(input_samples))
copy(data, input_samples)
select {
case mic <- data:
default:
}
}
captureCallbacks := malgo.DeviceCallbacks{
Data: onRecvFrames,
}
fmt.Printf("Creating capture device\n")
inDevice, err := malgo.InitDevice(ctx.Context, inDeviceConfig, captureCallbacks)
if err != nil {
panic(err)
}
err = inDevice.Start()
if err != nil {
panic(err)
}
defer inDevice.Uninit()
defer inDevice.Stop()
fmt.Printf("Starting pnyx client\n")
client, err := pnyx.NewClient(nil, os.Args[1])
if err != nil {
panic(err)
}
fmt.Printf("Started session %s with %s", client.Session.ID, client.Session.Peer)
go func() {
var buf [1024]byte
for true {
read, _, err := client.Connection.ReadFromUDP(buf[:])
if err != nil {
fmt.Printf("Read Error %s\n", err)
break
}
data, err := pnyx.ParseSessionData(&client.Session, buf[pnyx.COMMAND_LENGTH + pnyx.ID_LENGTH:read])
if err != nil {
fmt.Printf("ParseSessionData Error %s\n", err)
continue
}
packet, err := pnyx.ParsePacket(data)
if err != nil {
fmt.Printf("ParsePacket Error %s - %x\n", err, data)
continue
}
_ = pnyx.PeerID(packet.Data[0:16])
speaker <- packet.Data[16:]
}
}()
err = client.Send(pnyx.Packet{
Channel: pnyx.ChannelID(1),
Mode: pnyx.MODE_RAW,
Command: pnyx.MODE_COMMAND_JOIN,
Data: nil,
})
if err != nil {
panic(err)
}
for true {
data := <- mic
err = client.Send(pnyx.Packet{
Channel: pnyx.ChannelID(1),
Mode: pnyx.MODE_RAW,
Command: pnyx.MODE_COMMAND_DATA,
Data: data,
})
if err != nil {
panic(err)
}
}
err = client.Send(pnyx.Packet{
Channel: pnyx.ChannelID(1),
Mode: pnyx.MODE_RAW,
Command: pnyx.MODE_COMMAND_LEAVE,
Data: nil,
})
if err != nil {
panic(err)
}
err = client.Close()
if err != nil {
panic(err)
}
}

@ -23,6 +23,11 @@ func main() {
panic(err)
}
err = server.AddChannel(pnyx.ChannelID(1), &pnyx.RawMode{})
if err != nil {
panic(err)
}
<-os_sigs
err = server.Stop()
if err != nil {

@ -1,8 +1,13 @@
module git.metznet.ca/MetzNet/pnyx
go 1.21.5
go 1.22.0
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/ebitengine/purego v0.7.1 // indirect
github.com/gen2brain/malgo v0.11.21 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 // indirect
github.com/pion/opus v0.0.0-20240403022900-1c7b6eddc7c9 // indirect
golang.org/x/sys v0.7.0 // indirect
)

@ -1,4 +1,14 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA=
github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ=
github.com/gen2brain/malgo v0.11.21 h1:qsS4Dh6zhZgmvAW5CtKRxDjQzHbc2NJlBG9eE0tgS8w=
github.com/gen2brain/malgo v0.11.21/go.mod h1:f9TtuN7DVrXMiV/yIceMeWpvanyVzJQMlBecJFVMxww=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/pion/opus v0.0.0-20240403022900-1c7b6eddc7c9 h1:/aqYkFcwlpZVXSt1cLDXppeDQlABu9zZq/mBVX3v/5w=
github.com/pion/opus v0.0.0-20240403022900-1c7b6eddc7c9/go.mod h1:APGXJHkH8qlbefy7R7/i6a2w/nvXC85hnHm8FjaGgMo=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

BIN
main

Binary file not shown.

@ -1,264 +1,37 @@
package pnyx
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"crypto/sha512"
"encoding/binary"
"hash/crc32"
"io"
mrand "math/rand"
"net"
"fmt"
"slices"
"filippo.io/edwards25519"
)
type PacketType uint8
const (
ID_LENGTH = 16
IV_LENGTH = aes.BlockSize
PUBKEY_LENGTH = 32
ECDH_LENGTH = 32
SIGNATURE_LENGTH = 64
HMAC_LENGTH = 64
COMMAND_LENGTH = 1
SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length
SESSION_CLOSE_LENGTH = HMAC_LENGTH
SESSION_OPEN PacketType = iota
SESSION_CONNECT
SESSION_CLOSE
SESSION_DATA
"encoding/binary"
"fmt"
)
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
public_point, err := (&edwards25519.Point{}).SetBytes(public)
if err != nil {
return nil, err
}
h := sha512.Sum512(private.Seed())
private_scalar, err := (&edwards25519.Scalar{}).SetBytesWithClamping(h[:32])
if err != nil {
return nil, err
}
shared_point := public_point.ScalarMult(private_scalar, public_point)
return shared_point.BytesMontgomery(), nil
type Packet struct {
Channel ChannelID
Mode ModeID
Command byte
Data []byte
}
func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) {
if key == nil {
return nil, nil, fmt.Errorf("Cannot create a SESSION_OPEN packet without a key")
}
public, private, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err)
}
packet := make([]byte, COMMAND_LENGTH + SESSION_OPEN_LENGTH)
cur := 0
packet[0] = byte(SESSION_OPEN)
cur += COMMAND_LENGTH
copy(packet[cur:], []byte(key.Public().(ed25519.PublicKey)))
cur += PUBKEY_LENGTH
copy(packet[cur:], []byte(public))
cur += PUBKEY_LENGTH
signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur])
copy(packet[cur:], signature)
cur += SIGNATURE_LENGTH
return packet, private, nil
func(packet Packet) String() string {
return fmt.Sprintf("{Channel: %x, Mode: %x, Data: %x}", packet.Channel, packet.Mode, packet.Data)
}
func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) {
if len(session_open) != SESSION_OPEN_LENGTH {
return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
}
cur := 0
client_pubkey := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH])
cur += PUBKEY_LENGTH
client_ecdh := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH])
cur += PUBKEY_LENGTH
digest := session_open[:cur]
signature := session_open[cur:cur+SIGNATURE_LENGTH]
cur += SIGNATURE_LENGTH
if ed25519.Verify(client_pubkey, digest, signature) == false {
return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed")
}
session_secret, err := ECDH(client_ecdh, ecdh)
if err != nil {
return Session{}, err
}
session_id := ID[SessionID](session_secret)
client_id := ID[PeerID](client_pubkey)
session_cipher, err := aes.NewCipher(session_secret)
if err != nil {
return Session{}, err
}
seed_bytes := make([]byte, 8)
read, err := rand.Read(seed_bytes)
if err != nil {
return Session{}, err
} else if read != 8 {
return Session{}, fmt.Errorf("Not enough entropy to create session")
}
return Session{
ID: session_id,
remote: nil,
Peer: client_id,
secret: session_secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
}, nil
func(packet Packet) MarshalBinary() ([]byte, error) {
p := binary.BigEndian.AppendUint32(nil, uint32(packet.Channel))
p = append(p, byte(packet.Mode))
p = append(p, byte(packet.Command))
return append(p, packet.Data...), nil
}
func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CONNECT_LENGTH + len(address.String()))
cur := 0
packet[cur] = byte(SESSION_CONNECT)
cur += COMMAND_LENGTH
session_id := [16]byte(ID[SessionID](session_secret))
copy(packet[cur:], session_id[:])
cur += ID_LENGTH
binary.BigEndian.PutUint16(packet[cur:], uint16(len(address.String())))
cur += 2
copy(packet[cur:], []byte(address.String()))
cur += len(address.String())
hmac := sha512.Sum512(append(packet[COMMAND_LENGTH+ID_LENGTH:cur], session_secret...))
copy(packet[cur:], hmac[:])
return packet
}
func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) {
if len(session_connect) < SESSION_CONNECT_LENGTH {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
}
cur := 0
address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2]))
cur += 2
if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length)
}
address := string(session_connect[cur:cur+address_length])
cur += address_length
hmac_digest := make([]byte, cur)
copy(hmac_digest, session_connect[:cur])
hmac := session_connect[cur:cur+HMAC_LENGTH]
cur += HMAC_LENGTH
calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return nil, fmt.Errorf("Session connect bad HMAC")
}
addr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, fmt.Errorf("Error parsing return address: %w", err)
}
return addr, nil
}
func NewSessionData(session *Session, packet []byte) ([]byte, error) {
iv := make([]byte, IV_LENGTH)
for i := 0; i < IV_LENGTH/8; i++ {
binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64())
func ParsePacket(data []byte) (*Packet, error) {
if len(data) < 6 {
return nil, fmt.Errorf("Not enough bytes to parse Packet(%d/%d)", len(data), 6)
}
stream := cipher.NewOFB(session.cipher, iv[:])
header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH)
header[0] = byte(SESSION_DATA)
copy(header[COMMAND_LENGTH:], session.ID[:])
copy(header[COMMAND_LENGTH+ID_LENGTH:], iv)
packet_encrypted := bytes.NewBuffer(header)
writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}
// Encrypt the packet with a crc32 checksum appended to the end
_, err := io.Copy(writer, bytes.NewBuffer(binary.BigEndian.AppendUint32(packet, crc32.ChecksumIEEE(packet))))
if err != nil {
return nil, err
}
return packet_encrypted.Bytes(), nil
}
func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) {
iv := encrypted[0:IV_LENGTH]
stream := cipher.NewOFB(session.cipher, iv)
var packet bytes.Buffer
reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(encrypted[IV_LENGTH:])}
_, err := io.Copy(&packet, reader)
if err != nil {
return nil, err
}
packet_clear := packet.Bytes()
checksum := binary.BigEndian.Uint32(packet_clear[len(packet_clear)-4:])
data := packet_clear[:len(packet_clear)-4]
calculated := crc32.ChecksumIEEE(data)
if checksum != calculated {
return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated)
}
return data, nil
}
func NewSessionClose(session *Session) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH)
packet[0] = byte(SESSION_CLOSE)
copy(packet[1:], session.ID[:])
hmac := sha512.Sum512(append(session.ID[:], session.secret...))
copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:])
return packet
}
func ParseSessionClose(session *Session, hmac []byte) error {
calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return fmt.Errorf("Session Close HMAC validation failed")
} else {
return nil
}
return &Packet{
Channel: ChannelID(binary.BigEndian.Uint32(data)),
Mode: ModeID(data[4]),
Command: data[5],
Data: data[6:],
}, nil
}

@ -1,35 +1,26 @@
package pnyx
import (
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"errors"
"fmt"
mrand "math/rand"
"net"
"os"
"reflect"
"sync"
"sync/atomic"
"github.com/google/uuid"
)
const (
SERVER_UDP_BUFFER_SIZE = 2048
SERVER_SEND_BUFFER_SIZE = 2048
)
type SessionID uuid.UUID
func(id SessionID) String() string {
return uuid.UUID(id).String()
}
type Session struct {
ID SessionID
remote *net.UDPAddr
Peer PeerID
secret []byte
cipher cipher.Block
iv_generator mrand.Source64
type ServerSession struct {
Session
IncomingPackets chan[]byte
OutgoingPackets chan *Packet
}
type Server struct {
@ -38,7 +29,14 @@ type Server struct {
connection *net.UDPConn
stopped chan error
sessions map[SessionID]*Session
modes map[reflect.Type]ModeID
send_packets chan[]SendPacket
sessions_lock sync.Mutex
sessions map[SessionID]*ServerSession
channels_lock sync.RWMutex
channels map[ChannelID]*Channel
}
@ -56,57 +54,75 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) {
active: atomic.Bool{},
stopped: make(chan error, 0),
sessions: map[SessionID]*Session{},
send_packets: make(chan []SendPacket, SERVER_SEND_BUFFER_SIZE),
modes: map[reflect.Type]ModeID{
reflect.TypeFor[*RawMode](): MODE_RAW,
},
sessions: map[SessionID]*ServerSession{},
channels: map[ChannelID]*Channel{},
}
server.active.Store(false)
return server, nil
}
// Check if the client has permission for the command on the channel
// If it's not specified, check the permission on the parent
func Allowed(server *Server, client PeerID, channel_id ChannelID, mode ModeID, command CommandID) bool {
channel, exists := server.channels[channel_id]
func(server *Server) RemoveChannel(id ChannelID) error {
server.channels_lock.Lock()
defer server.channels_lock.Unlock()
_, exists := server.channels[id]
if exists == false {
return false
return fmt.Errorf("Channel %x does not exist", id)
}
if channel.permissions != nil {
client_perms, exists := channel.permissions[client]
if exists {
if client_perms == nil {
return true
}
delete(server.channels, id)
return nil
}
mode_perms, exists := client_perms[mode]
if exists {
if mode_perms == nil {
return true
}
func(server *Server) AddChannel(id ChannelID, modes ...Mode) error {
if id == RootChannelID {
return fmt.Errorf("Cannot use root channel ID as real channel")
}
allowed, exists := mode_perms[command]
if exists {
return allowed
}
}
server.channels_lock.Lock()
defer server.channels_lock.Unlock()
_, exists := server.channels[id]
if exists {
return fmt.Errorf("Channel with ID %x already exists", id)
}
mode_map := map[ModeID]Mode{}
for _, mode := range(modes) {
reflect_type := reflect.TypeOf(mode)
mode_id, known := server.modes[reflect_type]
if known == false {
return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type)
}
_, exists := mode_map[mode_id]
if exists {
return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id)
}
mode_map[mode_id] = mode
}
// Prevent a cycle on the root channel
if channel_id == RootChannelID {
return false
} else {
return Allowed(server, client, channel.parent, mode, command)
server.channels[id] = &Channel{
modes: mode_map,
}
return nil
}
func (server *Server) Log(format string, fields ...interface{}) {
func(server *Server) Log(format string, fields ...interface{}) {
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
}
func(server *Server) Stop() error {
was_active := server.active.CompareAndSwap(true, false)
if was_active {
close(server.send_packets)
err := server.connection.Close()
if err != nil {
return err
@ -117,7 +133,9 @@ func(server *Server) Stop() error {
}
}
func(server *Server) run() {
const SESSION_BUFFER_SIZE = 256
func(server *Server) listen_udp() {
server.Log("Started server on %s", server.connection.LocalAddr())
var buf [SERVER_UDP_BUFFER_SIZE]byte
@ -139,14 +157,96 @@ func(server *Server) run() {
continue
}
server.sessions[session.ID] = &session
server.sessions_lock.Lock()
server.sessions[session.ID] = &ServerSession{
Session: session,
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
}
server.sessions_lock.Unlock()
go func(session *ServerSession, server *Server){
server.Log("Starting session outgoing goroutine %s", session.ID)
for true {
packet := <- session.OutgoingPackets
if packet == nil {
break
}
if session.remote == nil {
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
continue
}
packet_data, err := packet.MarshalBinary()
if err != nil {
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
continue
}
encrypted, err := NewSessionData(&session.Session, packet_data)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
continue
}
_, err = server.connection.WriteToUDP(encrypted, session.remote)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
continue
}
}
server.Log("Stopping session outgoing goroutine %s", session.ID)
}(server.sessions[session.ID], server)
go func(session *ServerSession, server *Server){
server.Log("Starting session incoming goroutine %s", session.ID)
for true {
encrypted := <- session.IncomingPackets
if encrypted == nil {
break
}
data, err := ParseSessionData(&session.Session, encrypted)
if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue
}
packet, err := ParsePacket(data)
if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
}
if packet.Channel == RootChannelID {
// TODO process commands on the root channel
} else {
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
mode, exists := channel.modes[packet.Mode]
if exists == true {
result = mode.Process(&session.Session, packet)
}
}
server.channels_lock.RUnlock()
if result != nil {
//TODO: handle overflow
server.send_packets<-result
}
}
}
server.Log("Stopping session incoming goroutine %s", session.ID)
}(server.sessions[session.ID], server)
_, err = server.connection.WriteToUDP(session_open, from)
if err != nil {
server.Log("WriteToUDP error %s", err)
continue
}
server.Log("Started session %s with %s", session.ID, session.Peer)
case SESSION_CONNECT:
@ -164,10 +264,9 @@ func(server *Server) run() {
}
session.remote = client_addr
server.Log("Got SESSION_CONNECT for session %s at address %s", session.ID, session.remote)
// TODO: Send server hello back
server_hello, err := NewSessionData(session, []byte("hello"))
// TODO: Make a better server hello
server_hello, err := NewSessionData(&session.Session, []byte("hello"))
if err != nil {
server.Log("Error generating server hello: %s", err)
continue
@ -188,16 +287,37 @@ func(server *Server) run() {
continue
}
err := ParseSessionClose(session, buf[COMMAND_LENGTH+ID_LENGTH:])
err := ParseSessionClose(&session.Session, buf[COMMAND_LENGTH+ID_LENGTH:read])
if err != nil {
server.Log("Session close error for %s - %s", session_id, err)
continue
}
close(session.IncomingPackets)
close(session.OutgoingPackets)
server.sessions_lock.Lock()
delete(server.sessions, session_id)
server.sessions_lock.Unlock()
server.Log("Session %s closed", session_id)
case SESSION_DATA:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't receive data", session_id)
continue
}
buf_copy := make([]byte, read - COMMAND_LENGTH - ID_LENGTH)
copy(buf_copy, buf[COMMAND_LENGTH+ID_LENGTH:read])
select {
case session.IncomingPackets<-buf_copy:
default:
server.Log("Dropped packet to session %s", session_id)
}
default:
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
}
@ -209,10 +329,36 @@ func(server *Server) run() {
}
}
server.sessions_lock.Lock()
for session_id, session := range(server.sessions) {
close(session.IncomingPackets)
close(session.OutgoingPackets)
delete(server.sessions, session_id)
}
server.sessions_lock.Unlock()
server.Log("Shut down server on %s", server.connection.LocalAddr())
server.stopped <- nil
}
func(server *Server) send_sessions() {
for true {
packets := <- server.send_packets
if packets == nil {
break
}
server.sessions_lock.Lock()
for _, packet := range(packets) {
session, exists := server.sessions[packet.Session]
if exists {
session.OutgoingPackets <- packet.Packet
}
}
server.sessions_lock.Unlock()
}
}
func(server *Server) Start(listen string) error {
was_inactive := server.active.CompareAndSwap(false, true)
if was_inactive == false {
@ -231,7 +377,8 @@ func(server *Server) Start(listen string) error {
return fmt.Errorf("Failed to create udp server: %w", err)
}
go server.run()
go server.listen_udp()
go server.send_sessions()
return nil
}

@ -0,0 +1,282 @@
package pnyx
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"crypto/sha512"
"encoding/binary"
"hash/crc32"
"io"
mrand "math/rand"
"net"
"fmt"
"slices"
"filippo.io/edwards25519"
"github.com/google/uuid"
)
type SessionID uuid.UUID
func(id SessionID) String() string {
return uuid.UUID(id).String()
}
type Session struct {
ID SessionID
remote *net.UDPAddr
Peer PeerID
secret []byte
cipher cipher.Block
iv_generator mrand.Source64
}
type PacketType uint8
const (
ID_LENGTH = 16
IV_LENGTH = aes.BlockSize
PUBKEY_LENGTH = 32
ECDH_LENGTH = 32
SIGNATURE_LENGTH = 64
HMAC_LENGTH = 64
COMMAND_LENGTH = 1
SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length
SESSION_CLOSE_LENGTH = HMAC_LENGTH
/*
pnyx session packets
*/
SESSION_OPEN PacketType = iota
SESSION_CONNECT
SESSION_CLOSE
SESSION_DATA
)
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
public_point, err := (&edwards25519.Point{}).SetBytes(public)
if err != nil {
return nil, err
}
h := sha512.Sum512(private.Seed())
private_scalar, err := (&edwards25519.Scalar{}).SetBytesWithClamping(h[:32])
if err != nil {
return nil, err
}
shared_point := public_point.ScalarMult(private_scalar, public_point)
return shared_point.BytesMontgomery(), nil
}
func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) {
if key == nil {
return nil, nil, fmt.Errorf("Cannot create a SESSION_OPEN packet without a key")
}
public, private, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err)
}
packet := make([]byte, COMMAND_LENGTH + SESSION_OPEN_LENGTH)
cur := 0
packet[0] = byte(SESSION_OPEN)
cur += COMMAND_LENGTH
copy(packet[cur:], []byte(key.Public().(ed25519.PublicKey)))
cur += PUBKEY_LENGTH
copy(packet[cur:], []byte(public))
cur += PUBKEY_LENGTH
signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur])
copy(packet[cur:], signature)
cur += SIGNATURE_LENGTH
return packet, private, nil
}
func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) {
if len(session_open) != SESSION_OPEN_LENGTH {
return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
}
cur := 0
client_pubkey := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH])
cur += PUBKEY_LENGTH
client_ecdh := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH])
cur += PUBKEY_LENGTH
digest := session_open[:cur]
signature := session_open[cur:cur+SIGNATURE_LENGTH]
cur += SIGNATURE_LENGTH
if ed25519.Verify(client_pubkey, digest, signature) == false {
return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed")
}
session_secret, err := ECDH(client_ecdh, ecdh)
if err != nil {
return Session{}, err
}
session_id := ID[SessionID](session_secret)
client_id := ID[PeerID](client_pubkey)
session_cipher, err := aes.NewCipher(session_secret)
if err != nil {
return Session{}, err
}
seed_bytes := make([]byte, 8)
read, err := rand.Read(seed_bytes)
if err != nil {
return Session{}, err
} else if read != 8 {
return Session{}, fmt.Errorf("Not enough entropy to create session")
}
return Session{
ID: session_id,
remote: nil,
Peer: client_id,
secret: session_secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
}, nil
}
func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CONNECT_LENGTH + len(address.String()))
cur := 0
packet[cur] = byte(SESSION_CONNECT)
cur += COMMAND_LENGTH
session_id := [16]byte(ID[SessionID](session_secret))
copy(packet[cur:], session_id[:])
cur += ID_LENGTH
binary.BigEndian.PutUint16(packet[cur:], uint16(len(address.String())))
cur += 2
copy(packet[cur:], []byte(address.String()))
cur += len(address.String())
hmac := sha512.Sum512(append(packet[COMMAND_LENGTH+ID_LENGTH:cur], session_secret...))
copy(packet[cur:], hmac[:])
return packet
}
func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) {
if len(session_connect) < SESSION_CONNECT_LENGTH {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
}
cur := 0
address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2]))
cur += 2
if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length)
}
address := string(session_connect[cur:cur+address_length])
cur += address_length
hmac_digest := make([]byte, cur)
copy(hmac_digest, session_connect[:cur])
hmac := session_connect[cur:cur+HMAC_LENGTH]
cur += HMAC_LENGTH
calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return nil, fmt.Errorf("Session connect bad HMAC")
}
addr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, fmt.Errorf("Error parsing return address: %w", err)
}
return addr, nil
}
func NewSessionData(session *Session, packet []byte) ([]byte, error) {
iv := make([]byte, IV_LENGTH)
for i := 0; i < IV_LENGTH/8; i++ {
binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64())
}
stream := cipher.NewOFB(session.cipher, iv[:])
header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH)
header[0] = byte(SESSION_DATA)
copy(header[COMMAND_LENGTH:], session.ID[:])
copy(header[COMMAND_LENGTH+ID_LENGTH:], iv)
packet_encrypted := bytes.NewBuffer(header)
writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}
// Encrypt the packet with a crc32 checksum appended to the end
_, err := io.Copy(writer, bytes.NewBuffer(binary.BigEndian.AppendUint32(packet, crc32.ChecksumIEEE(packet))))
if err != nil {
return nil, err
}
return packet_encrypted.Bytes(), nil
}
func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) {
iv := encrypted[0:IV_LENGTH]
stream := cipher.NewOFB(session.cipher, iv)
var packet bytes.Buffer
reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(encrypted[IV_LENGTH:])}
_, err := io.Copy(&packet, reader)
if err != nil {
return nil, err
}
packet_clear := packet.Bytes()
checksum := binary.BigEndian.Uint32(packet_clear[len(packet_clear)-4:])
data := packet_clear[:len(packet_clear)-4]
calculated := crc32.ChecksumIEEE(data)
if checksum != calculated {
return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated)
}
return data, nil
}
func NewSessionClose(session *Session) []byte {
packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH)
packet[0] = byte(SESSION_CLOSE)
copy(packet[1:], session.ID[:])
hmac := sha512.Sum512(append(session.ID[:], session.secret...))
copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:])
return packet
}
func ParseSessionClose(session *Session, hmac []byte) error {
calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return fmt.Errorf("Session Close HMAC validation failed")
} else {
return nil
}
}