diff --git a/channel.go b/channel.go index e1ce091..a0b806b 100644 --- a/channel.go +++ b/channel.go @@ -1,22 +1,79 @@ package pnyx import ( + "slices" ) type ChannelID uint32 -const RootChannelID = 0 +const ( + RootChannelID ChannelID = 0 + + MODE_RAW ModeID = iota + + MODE_COMMAND_DATA byte = 0x00 + MODE_COMMAND_JOIN = 0x01 + MODE_COMMAND_LEAVE = 0x02 +) type ModeID uint8 type CommandID uint8 -type PermissionMap map[PeerID]map[ModeID]map[CommandID]bool - type Channel struct { modes map[ModeID]Mode - permissions PermissionMap - parent ChannelID +} + +type SendPacket struct { + Packet *Packet + Session SessionID } type Mode interface { + // Process takes incoming packets from a session and returns a list of packets to send + Process(*Session, *Packet) []SendPacket +} + +func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPacket { + send_packets := make([]SendPacket, len(sessions)) + for i, session_id := range(sessions) { + if session_id == session.ID { + continue + } + + send_packets[i] = SendPacket{ + Packet: packet, + Session: session_id, + } + } + + return send_packets +} + +type RawMode struct { + Sessions []SessionID +} + +func(mode *RawMode) Process(session *Session, packet *Packet) []SendPacket { + switch packet.Command { + case MODE_COMMAND_JOIN: + if slices.Contains(mode.Sessions, session.ID) == false { + mode.Sessions = append(mode.Sessions, session.ID) + } + case MODE_COMMAND_LEAVE: + idx := slices.Index(mode.Sessions, session.ID) + if idx != -1 { + mode.Sessions = slices.Delete(mode.Sessions, idx, idx+1) + } + case MODE_COMMAND_DATA: + if slices.Contains(mode.Sessions, session.ID) { + new_packet := &Packet{ + Channel: packet.Channel, + Mode: packet.Mode, + Command: MODE_COMMAND_DATA, + Data: append(session.Peer[:], packet.Data...), + } + return multiplex(session, new_packet, mode.Sessions) + } + } + return nil } diff --git a/client.go b/client.go index 19d0619..77ac7ce 100644 --- a/client.go +++ b/client.go @@ -106,3 +106,40 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { Connection: connection, }, nil } + +func(client *Client) Send(packet Packet) error { + client.ConnectionLock.Lock() + defer client.ConnectionLock.Unlock() + + if client.Connection == nil { + return fmt.Errorf("Client is not connected") + } + + data, err := packet.MarshalBinary() + if err != nil { + return err + } + + wrapped, err := NewSessionData(&client.Session, data) + if err != nil { + return err + } + + _, err = client.Connection.Write(wrapped) + return err +} + +func(client *Client) Close() error { + client.ConnectionLock.Lock() + defer client.ConnectionLock.Unlock() + + if client.Connection == nil { + return fmt.Errorf("No connection to close") + } + + client.Connection.Write(NewSessionClose(&client.Session)) + err := client.Connection.Close() + client.Connection = nil + + return err +} diff --git a/cmd/client/main.go b/cmd/client/main.go index 8225285..7dd4f94 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -5,13 +5,156 @@ import ( "os" "git.metznet.ca/MetzNet/pnyx" + "github.com/gen2brain/malgo" ) func main() { + ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) + if err != nil { + panic(err) + } + + defer ctx.Free() + defer ctx.Uninit() + + inDeviceConfig := malgo.DefaultDeviceConfig(malgo.Capture) + inDeviceConfig.Capture.Format = malgo.FormatF32 + inDeviceConfig.Capture.Channels = 1 + inDeviceConfig.SampleRate = 44100 + inDeviceConfig.PeriodSizeInFrames = 100 + inDeviceConfig.Alsa.NoMMap = 1 + inDeviceConfig.Capture.ShareMode = malgo.Shared + + outDeviceConfig := malgo.DefaultDeviceConfig(malgo.Playback) + outDeviceConfig.Playback.Format = malgo.FormatF32 + outDeviceConfig.Playback.Channels = 1 + outDeviceConfig.SampleRate = 44100 + outDeviceConfig.PeriodSizeInFrames = 100 + outDeviceConfig.Alsa.NoMMap = 1 + outDeviceConfig.Playback.ShareMode = malgo.Shared + + mic := make(chan []byte, 0) + speaker := make(chan []byte, 0) + + onSendFrames := func(output_samples []byte, input_samples []byte, framecount uint32) { + select { + case data := <- speaker: + copy(output_samples, data) + } + } + + playbackCallbacks := malgo.DeviceCallbacks{ + Data: onSendFrames, + } + + fmt.Printf("Creating playback device\n") + outDevice, err := malgo.InitDevice(ctx.Context, outDeviceConfig, playbackCallbacks) + if err != nil { + panic(err) + } + + err = outDevice.Start() + if err != nil { + panic(err) + } + + defer outDevice.Uninit() + defer outDevice.Stop() + + onRecvFrames := func(output_samples []byte, input_samples []byte, framecount uint32) { + data := make([]byte, len(input_samples)) + copy(data, input_samples) + + select { + case mic <- data: + default: + } + } + + captureCallbacks := malgo.DeviceCallbacks{ + Data: onRecvFrames, + } + + fmt.Printf("Creating capture device\n") + inDevice, err := malgo.InitDevice(ctx.Context, inDeviceConfig, captureCallbacks) + if err != nil { + panic(err) + } + + err = inDevice.Start() + if err != nil { + panic(err) + } + + defer inDevice.Uninit() + defer inDevice.Stop() + + fmt.Printf("Starting pnyx client\n") + client, err := pnyx.NewClient(nil, os.Args[1]) if err != nil { panic(err) } - - fmt.Printf("Started session %s with %s", client.Session.ID, client.Session.Peer) + + go func() { + var buf [1024]byte + for true { + read, _, err := client.Connection.ReadFromUDP(buf[:]) + if err != nil { + fmt.Printf("Read Error %s\n", err) + break + } + + data, err := pnyx.ParseSessionData(&client.Session, buf[pnyx.COMMAND_LENGTH + pnyx.ID_LENGTH:read]) + if err != nil { + fmt.Printf("ParseSessionData Error %s\n", err) + continue + } + + packet, err := pnyx.ParsePacket(data) + if err != nil { + fmt.Printf("ParsePacket Error %s - %x\n", err, data) + continue + } + _ = pnyx.PeerID(packet.Data[0:16]) + speaker <- packet.Data[16:] + } + }() + + err = client.Send(pnyx.Packet{ + Channel: pnyx.ChannelID(1), + Mode: pnyx.MODE_RAW, + Command: pnyx.MODE_COMMAND_JOIN, + Data: nil, + }) + if err != nil { + panic(err) + } + + for true { + data := <- mic + err = client.Send(pnyx.Packet{ + Channel: pnyx.ChannelID(1), + Mode: pnyx.MODE_RAW, + Command: pnyx.MODE_COMMAND_DATA, + Data: data, + }) + if err != nil { + panic(err) + } + } + + err = client.Send(pnyx.Packet{ + Channel: pnyx.ChannelID(1), + Mode: pnyx.MODE_RAW, + Command: pnyx.MODE_COMMAND_LEAVE, + Data: nil, + }) + if err != nil { + panic(err) + } + err = client.Close() + if err != nil { + panic(err) + } } diff --git a/cmd/server/main.go b/cmd/server/main.go index cf7815b..b106551 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -23,6 +23,11 @@ func main() { panic(err) } + err = server.AddChannel(pnyx.ChannelID(1), &pnyx.RawMode{}) + if err != nil { + panic(err) + } + <-os_sigs err = server.Stop() if err != nil { diff --git a/go.mod b/go.mod index fdf2b6a..00dd52f 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,13 @@ module git.metznet.ca/MetzNet/pnyx -go 1.21.5 +go 1.22.0 require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/ebitengine/purego v0.7.1 // indirect + github.com/gen2brain/malgo v0.11.21 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 // indirect + github.com/pion/opus v0.0.0-20240403022900-1c7b6eddc7c9 // indirect + golang.org/x/sys v0.7.0 // indirect ) diff --git a/go.sum b/go.sum index 4aae644..c518d29 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,14 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/ebitengine/purego v0.7.1 h1:6/55d26lG3o9VCZX8lping+bZcmShseiqlh2bnUDiPA= +github.com/ebitengine/purego v0.7.1/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= +github.com/gen2brain/malgo v0.11.21 h1:qsS4Dh6zhZgmvAW5CtKRxDjQzHbc2NJlBG9eE0tgS8w= +github.com/gen2brain/malgo v0.11.21/go.mod h1:f9TtuN7DVrXMiV/yIceMeWpvanyVzJQMlBecJFVMxww= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc= +github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es= +github.com/pion/opus v0.0.0-20240403022900-1c7b6eddc7c9 h1:/aqYkFcwlpZVXSt1cLDXppeDQlABu9zZq/mBVX3v/5w= +github.com/pion/opus v0.0.0-20240403022900-1c7b6eddc7c9/go.mod h1:APGXJHkH8qlbefy7R7/i6a2w/nvXC85hnHm8FjaGgMo= +golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main b/main new file mode 100755 index 0000000..e2f5630 Binary files /dev/null and b/main differ diff --git a/packet.go b/packet.go index d58edec..193541e 100644 --- a/packet.go +++ b/packet.go @@ -1,264 +1,37 @@ package pnyx import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/ed25519" - "crypto/rand" - "crypto/sha512" - "encoding/binary" - "hash/crc32" - "io" - mrand "math/rand" - "net" - - "fmt" - "slices" - - "filippo.io/edwards25519" -) - -type PacketType uint8 -const ( - ID_LENGTH = 16 - IV_LENGTH = aes.BlockSize - PUBKEY_LENGTH = 32 - ECDH_LENGTH = 32 - SIGNATURE_LENGTH = 64 - HMAC_LENGTH = 64 - COMMAND_LENGTH = 1 - - SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH - SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length - SESSION_CLOSE_LENGTH = HMAC_LENGTH - - SESSION_OPEN PacketType = iota - SESSION_CONNECT - SESSION_CLOSE - SESSION_DATA + "encoding/binary" + "fmt" ) -func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) { - public_point, err := (&edwards25519.Point{}).SetBytes(public) - if err != nil { - return nil, err - } - - - h := sha512.Sum512(private.Seed()) - private_scalar, err := (&edwards25519.Scalar{}).SetBytesWithClamping(h[:32]) - if err != nil { - return nil, err - } - - shared_point := public_point.ScalarMult(private_scalar, public_point) - - return shared_point.BytesMontgomery(), nil -} - -func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) { - if key == nil { - return nil, nil, fmt.Errorf("Cannot create a SESSION_OPEN packet without a key") - } - - public, private, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err) - } - - packet := make([]byte, COMMAND_LENGTH + SESSION_OPEN_LENGTH) - cur := 0 - - packet[0] = byte(SESSION_OPEN) - cur += COMMAND_LENGTH - - copy(packet[cur:], []byte(key.Public().(ed25519.PublicKey))) - cur += PUBKEY_LENGTH - - copy(packet[cur:], []byte(public)) - cur += PUBKEY_LENGTH - - signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur]) - copy(packet[cur:], signature) - cur += SIGNATURE_LENGTH - - return packet, private, nil -} - -func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) { - if len(session_open) != SESSION_OPEN_LENGTH { - return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH) - } - - cur := 0 - - client_pubkey := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH]) - cur += PUBKEY_LENGTH - - client_ecdh := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH]) - cur += PUBKEY_LENGTH - - digest := session_open[:cur] - signature := session_open[cur:cur+SIGNATURE_LENGTH] - cur += SIGNATURE_LENGTH - - if ed25519.Verify(client_pubkey, digest, signature) == false { - return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed") - } - - session_secret, err := ECDH(client_ecdh, ecdh) - if err != nil { - return Session{}, err - } - - session_id := ID[SessionID](session_secret) - client_id := ID[PeerID](client_pubkey) - - session_cipher, err := aes.NewCipher(session_secret) - if err != nil { - return Session{}, err - } - - seed_bytes := make([]byte, 8) - read, err := rand.Read(seed_bytes) - if err != nil { - return Session{}, err - } else if read != 8 { - return Session{}, fmt.Errorf("Not enough entropy to create session") - } - - return Session{ - ID: session_id, - remote: nil, - Peer: client_id, - secret: session_secret, - cipher: session_cipher, - iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), - }, nil -} - -func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte { - packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CONNECT_LENGTH + len(address.String())) - cur := 0 - - packet[cur] = byte(SESSION_CONNECT) - cur += COMMAND_LENGTH - - session_id := [16]byte(ID[SessionID](session_secret)) - copy(packet[cur:], session_id[:]) - cur += ID_LENGTH - - binary.BigEndian.PutUint16(packet[cur:], uint16(len(address.String()))) - cur += 2 - - copy(packet[cur:], []byte(address.String())) - cur += len(address.String()) - - hmac := sha512.Sum512(append(packet[COMMAND_LENGTH+ID_LENGTH:cur], session_secret...)) - copy(packet[cur:], hmac[:]) - - return packet +type Packet struct { + Channel ChannelID + Mode ModeID + Command byte + Data []byte } -func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) { - if len(session_connect) < SESSION_CONNECT_LENGTH { - return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH) - } - - cur := 0 - - address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2])) - cur += 2 - - if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) { - return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length) - } - - address := string(session_connect[cur:cur+address_length]) - cur += address_length - - hmac_digest := make([]byte, cur) - copy(hmac_digest, session_connect[:cur]) - - hmac := session_connect[cur:cur+HMAC_LENGTH] - cur += HMAC_LENGTH - - calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...)) - if slices.Compare(hmac, calculated_hmac[:]) != 0 { - return nil, fmt.Errorf("Session connect bad HMAC") - } - - addr, err := net.ResolveUDPAddr("udp", address) - if err != nil { - return nil, fmt.Errorf("Error parsing return address: %w", err) - } - - return addr, nil +func(packet Packet) String() string { + return fmt.Sprintf("{Channel: %x, Mode: %x, Data: %x}", packet.Channel, packet.Mode, packet.Data) } -func NewSessionData(session *Session, packet []byte) ([]byte, error) { - iv := make([]byte, IV_LENGTH) - for i := 0; i < IV_LENGTH/8; i++ { - binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64()) - } - - stream := cipher.NewOFB(session.cipher, iv[:]) - header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH) - header[0] = byte(SESSION_DATA) - copy(header[COMMAND_LENGTH:], session.ID[:]) - copy(header[COMMAND_LENGTH+ID_LENGTH:], iv) - - packet_encrypted := bytes.NewBuffer(header) - writer := &cipher.StreamWriter{S: stream, W: packet_encrypted} - - // Encrypt the packet with a crc32 checksum appended to the end - _, err := io.Copy(writer, bytes.NewBuffer(binary.BigEndian.AppendUint32(packet, crc32.ChecksumIEEE(packet)))) - if err != nil { - return nil, err - } - - return packet_encrypted.Bytes(), nil +func(packet Packet) MarshalBinary() ([]byte, error) { + p := binary.BigEndian.AppendUint32(nil, uint32(packet.Channel)) + p = append(p, byte(packet.Mode)) + p = append(p, byte(packet.Command)) + return append(p, packet.Data...), nil } -func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) { - iv := encrypted[0:IV_LENGTH] - - stream := cipher.NewOFB(session.cipher, iv) - var packet bytes.Buffer - reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(encrypted[IV_LENGTH:])} - _, err := io.Copy(&packet, reader) - if err != nil { - return nil, err +func ParsePacket(data []byte) (*Packet, error) { + if len(data) < 6 { + return nil, fmt.Errorf("Not enough bytes to parse Packet(%d/%d)", len(data), 6) } - packet_clear := packet.Bytes() - checksum := binary.BigEndian.Uint32(packet_clear[len(packet_clear)-4:]) - data := packet_clear[:len(packet_clear)-4] - calculated := crc32.ChecksumIEEE(data) - if checksum != calculated { - return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated) - } - - return data, nil -} - -func NewSessionClose(session *Session) []byte { - packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH) - packet[0] = byte(SESSION_CLOSE) - copy(packet[1:], session.ID[:]) - - hmac := sha512.Sum512(append(session.ID[:], session.secret...)) - copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:]) - - return packet -} - -func ParseSessionClose(session *Session, hmac []byte) error { - calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...)) - if slices.Compare(hmac, calculated_hmac[:]) != 0 { - return fmt.Errorf("Session Close HMAC validation failed") - } else { - return nil - } + return &Packet{ + Channel: ChannelID(binary.BigEndian.Uint32(data)), + Mode: ModeID(data[4]), + Command: data[5], + Data: data[6:], + }, nil } diff --git a/server.go b/server.go index 5fb394e..abab5cc 100644 --- a/server.go +++ b/server.go @@ -1,35 +1,26 @@ package pnyx import ( - "crypto/cipher" "crypto/ed25519" "crypto/rand" "errors" "fmt" - mrand "math/rand" "net" "os" + "reflect" + "sync" "sync/atomic" - - "github.com/google/uuid" ) const ( SERVER_UDP_BUFFER_SIZE = 2048 + SERVER_SEND_BUFFER_SIZE = 2048 ) -type SessionID uuid.UUID -func(id SessionID) String() string { - return uuid.UUID(id).String() -} - -type Session struct { - ID SessionID - remote *net.UDPAddr - Peer PeerID - secret []byte - cipher cipher.Block - iv_generator mrand.Source64 +type ServerSession struct { + Session + IncomingPackets chan[]byte + OutgoingPackets chan *Packet } type Server struct { @@ -38,7 +29,14 @@ type Server struct { connection *net.UDPConn stopped chan error - sessions map[SessionID]*Session + modes map[reflect.Type]ModeID + + send_packets chan[]SendPacket + + sessions_lock sync.Mutex + sessions map[SessionID]*ServerSession + + channels_lock sync.RWMutex channels map[ChannelID]*Channel } @@ -56,57 +54,75 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) { active: atomic.Bool{}, stopped: make(chan error, 0), - sessions: map[SessionID]*Session{}, + send_packets: make(chan []SendPacket, SERVER_SEND_BUFFER_SIZE), + + modes: map[reflect.Type]ModeID{ + reflect.TypeFor[*RawMode](): MODE_RAW, + }, + + sessions: map[SessionID]*ServerSession{}, channels: map[ChannelID]*Channel{}, } server.active.Store(false) return server, nil } -// Check if the client has permission for the command on the channel -// If it's not specified, check the permission on the parent -func Allowed(server *Server, client PeerID, channel_id ChannelID, mode ModeID, command CommandID) bool { - channel, exists := server.channels[channel_id] +func(server *Server) RemoveChannel(id ChannelID) error { + server.channels_lock.Lock() + defer server.channels_lock.Unlock() + + _, exists := server.channels[id] if exists == false { - return false + return fmt.Errorf("Channel %x does not exist", id) } - if channel.permissions != nil { - client_perms, exists := channel.permissions[client] - if exists { - if client_perms == nil { - return true - } + delete(server.channels, id) + return nil +} - mode_perms, exists := client_perms[mode] - if exists { - if mode_perms == nil { - return true - } +func(server *Server) AddChannel(id ChannelID, modes ...Mode) error { + if id == RootChannelID { + return fmt.Errorf("Cannot use root channel ID as real channel") + } - allowed, exists := mode_perms[command] - if exists { - return allowed - } - } + server.channels_lock.Lock() + defer server.channels_lock.Unlock() + + _, exists := server.channels[id] + if exists { + return fmt.Errorf("Channel with ID %x already exists", id) + } + + mode_map := map[ModeID]Mode{} + for _, mode := range(modes) { + reflect_type := reflect.TypeOf(mode) + mode_id, known := server.modes[reflect_type] + if known == false { + return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type) + } + + _, exists := mode_map[mode_id] + if exists { + return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id) } + mode_map[mode_id] = mode } - // Prevent a cycle on the root channel - if channel_id == RootChannelID { - return false - } else { - return Allowed(server, client, channel.parent, mode, command) + server.channels[id] = &Channel{ + modes: mode_map, } + + return nil } -func (server *Server) Log(format string, fields ...interface{}) { +func(server *Server) Log(format string, fields ...interface{}) { fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n") } func(server *Server) Stop() error { was_active := server.active.CompareAndSwap(true, false) if was_active { + close(server.send_packets) err := server.connection.Close() if err != nil { return err @@ -117,7 +133,9 @@ func(server *Server) Stop() error { } } -func(server *Server) run() { +const SESSION_BUFFER_SIZE = 256 + +func(server *Server) listen_udp() { server.Log("Started server on %s", server.connection.LocalAddr()) var buf [SERVER_UDP_BUFFER_SIZE]byte @@ -139,14 +157,96 @@ func(server *Server) run() { continue } - server.sessions[session.ID] = &session + server.sessions_lock.Lock() + server.sessions[session.ID] = &ServerSession{ + Session: session, + IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), + OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), + } + server.sessions_lock.Unlock() + + go func(session *ServerSession, server *Server){ + server.Log("Starting session outgoing goroutine %s", session.ID) + for true { + packet := <- session.OutgoingPackets + if packet == nil { + break + } + + if session.remote == nil { + server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID) + continue + } + + packet_data, err := packet.MarshalBinary() + if err != nil { + server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err) + continue + } + + encrypted, err := NewSessionData(&session.Session, packet_data) + if err != nil { + server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err) + continue + } + + _, err = server.connection.WriteToUDP(encrypted, session.remote) + if err != nil { + server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err) + continue + } + } + server.Log("Stopping session outgoing goroutine %s", session.ID) + }(server.sessions[session.ID], server) + + go func(session *ServerSession, server *Server){ + server.Log("Starting session incoming goroutine %s", session.ID) + for true { + encrypted := <- session.IncomingPackets + if encrypted == nil { + break + } + + data, err := ParseSessionData(&session.Session, encrypted) + if err != nil { + server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) + continue + } + + packet, err := ParsePacket(data) + if err != nil { + server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) + } + + if packet.Channel == RootChannelID { + // TODO process commands on the root channel + } else { + var result []SendPacket = nil + server.channels_lock.RLock() + channel, exists := server.channels[packet.Channel] + if exists == true { + mode, exists := channel.modes[packet.Mode] + if exists == true { + result = mode.Process(&session.Session, packet) + } + } + server.channels_lock.RUnlock() + + if result != nil { + //TODO: handle overflow + server.send_packets<-result + } + } + + } + server.Log("Stopping session incoming goroutine %s", session.ID) + }(server.sessions[session.ID], server) _, err = server.connection.WriteToUDP(session_open, from) if err != nil { server.Log("WriteToUDP error %s", err) continue } - server.Log("Started session %s with %s", session.ID, session.Peer) case SESSION_CONNECT: @@ -164,10 +264,9 @@ func(server *Server) run() { } session.remote = client_addr - server.Log("Got SESSION_CONNECT for session %s at address %s", session.ID, session.remote) - // TODO: Send server hello back - server_hello, err := NewSessionData(session, []byte("hello")) + // TODO: Make a better server hello + server_hello, err := NewSessionData(&session.Session, []byte("hello")) if err != nil { server.Log("Error generating server hello: %s", err) continue @@ -188,16 +287,37 @@ func(server *Server) run() { continue } - err := ParseSessionClose(session, buf[COMMAND_LENGTH+ID_LENGTH:]) + err := ParseSessionClose(&session.Session, buf[COMMAND_LENGTH+ID_LENGTH:read]) if err != nil { server.Log("Session close error for %s - %s", session_id, err) continue } + close(session.IncomingPackets) + close(session.OutgoingPackets) + + server.sessions_lock.Lock() delete(server.sessions, session_id) + server.sessions_lock.Unlock() + server.Log("Session %s closed", session_id) case SESSION_DATA: + session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) + session, exists := server.sessions[session_id] + if exists == false { + server.Log("Session %s does not exist, can't receive data", session_id) + continue + } + + buf_copy := make([]byte, read - COMMAND_LENGTH - ID_LENGTH) + copy(buf_copy, buf[COMMAND_LENGTH+ID_LENGTH:read]) + + select { + case session.IncomingPackets<-buf_copy: + default: + server.Log("Dropped packet to session %s", session_id) + } default: server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read]) } @@ -209,10 +329,36 @@ func(server *Server) run() { } } + server.sessions_lock.Lock() + for session_id, session := range(server.sessions) { + close(session.IncomingPackets) + close(session.OutgoingPackets) + delete(server.sessions, session_id) + } + server.sessions_lock.Unlock() + server.Log("Shut down server on %s", server.connection.LocalAddr()) server.stopped <- nil } +func(server *Server) send_sessions() { + for true { + packets := <- server.send_packets + if packets == nil { + break + } + + server.sessions_lock.Lock() + for _, packet := range(packets) { + session, exists := server.sessions[packet.Session] + if exists { + session.OutgoingPackets <- packet.Packet + } + } + server.sessions_lock.Unlock() + } +} + func(server *Server) Start(listen string) error { was_inactive := server.active.CompareAndSwap(false, true) if was_inactive == false { @@ -231,7 +377,8 @@ func(server *Server) Start(listen string) error { return fmt.Errorf("Failed to create udp server: %w", err) } - go server.run() + go server.listen_udp() + go server.send_sessions() return nil } diff --git a/session.go b/session.go new file mode 100644 index 0000000..ead97a3 --- /dev/null +++ b/session.go @@ -0,0 +1,282 @@ +package pnyx + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ed25519" + "crypto/rand" + "crypto/sha512" + "encoding/binary" + "hash/crc32" + "io" + mrand "math/rand" + "net" + + "fmt" + "slices" + + "filippo.io/edwards25519" + "github.com/google/uuid" +) + +type SessionID uuid.UUID +func(id SessionID) String() string { + return uuid.UUID(id).String() +} + +type Session struct { + ID SessionID + remote *net.UDPAddr + Peer PeerID + secret []byte + cipher cipher.Block + iv_generator mrand.Source64 +} + +type PacketType uint8 +const ( + ID_LENGTH = 16 + IV_LENGTH = aes.BlockSize + PUBKEY_LENGTH = 32 + ECDH_LENGTH = 32 + SIGNATURE_LENGTH = 64 + HMAC_LENGTH = 64 + COMMAND_LENGTH = 1 + + SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH + SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length + SESSION_CLOSE_LENGTH = HMAC_LENGTH + + /* + pnyx session packets + */ + SESSION_OPEN PacketType = iota + SESSION_CONNECT + SESSION_CLOSE + SESSION_DATA +) + +func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) { + public_point, err := (&edwards25519.Point{}).SetBytes(public) + if err != nil { + return nil, err + } + + + h := sha512.Sum512(private.Seed()) + private_scalar, err := (&edwards25519.Scalar{}).SetBytesWithClamping(h[:32]) + if err != nil { + return nil, err + } + + shared_point := public_point.ScalarMult(private_scalar, public_point) + + return shared_point.BytesMontgomery(), nil +} + +func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) { + if key == nil { + return nil, nil, fmt.Errorf("Cannot create a SESSION_OPEN packet without a key") + } + + public, private, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err) + } + + packet := make([]byte, COMMAND_LENGTH + SESSION_OPEN_LENGTH) + cur := 0 + + packet[0] = byte(SESSION_OPEN) + cur += COMMAND_LENGTH + + copy(packet[cur:], []byte(key.Public().(ed25519.PublicKey))) + cur += PUBKEY_LENGTH + + copy(packet[cur:], []byte(public)) + cur += PUBKEY_LENGTH + + signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur]) + copy(packet[cur:], signature) + cur += SIGNATURE_LENGTH + + return packet, private, nil +} + +func ParseSessionOpen(ecdh ed25519.PrivateKey, session_open []byte) (Session, error) { + if len(session_open) != SESSION_OPEN_LENGTH { + return Session{}, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH) + } + + cur := 0 + + client_pubkey := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH]) + cur += PUBKEY_LENGTH + + client_ecdh := (ed25519.PublicKey)(session_open[cur:cur+PUBKEY_LENGTH]) + cur += PUBKEY_LENGTH + + digest := session_open[:cur] + signature := session_open[cur:cur+SIGNATURE_LENGTH] + cur += SIGNATURE_LENGTH + + if ed25519.Verify(client_pubkey, digest, signature) == false { + return Session{}, fmt.Errorf("SESSION_OPEN signature verification failed") + } + + session_secret, err := ECDH(client_ecdh, ecdh) + if err != nil { + return Session{}, err + } + + session_id := ID[SessionID](session_secret) + client_id := ID[PeerID](client_pubkey) + + session_cipher, err := aes.NewCipher(session_secret) + if err != nil { + return Session{}, err + } + + seed_bytes := make([]byte, 8) + read, err := rand.Read(seed_bytes) + if err != nil { + return Session{}, err + } else if read != 8 { + return Session{}, fmt.Errorf("Not enough entropy to create session") + } + + return Session{ + ID: session_id, + remote: nil, + Peer: client_id, + secret: session_secret, + cipher: session_cipher, + iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), + }, nil +} + +func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte { + packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CONNECT_LENGTH + len(address.String())) + cur := 0 + + packet[cur] = byte(SESSION_CONNECT) + cur += COMMAND_LENGTH + + session_id := [16]byte(ID[SessionID](session_secret)) + copy(packet[cur:], session_id[:]) + cur += ID_LENGTH + + binary.BigEndian.PutUint16(packet[cur:], uint16(len(address.String()))) + cur += 2 + + copy(packet[cur:], []byte(address.String())) + cur += len(address.String()) + + hmac := sha512.Sum512(append(packet[COMMAND_LENGTH+ID_LENGTH:cur], session_secret...)) + copy(packet[cur:], hmac[:]) + + return packet +} + +func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) { + if len(session_connect) < SESSION_CONNECT_LENGTH { + return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH) + } + + cur := 0 + + address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2])) + cur += 2 + + if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) { + return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length) + } + + address := string(session_connect[cur:cur+address_length]) + cur += address_length + + hmac_digest := make([]byte, cur) + copy(hmac_digest, session_connect[:cur]) + + hmac := session_connect[cur:cur+HMAC_LENGTH] + cur += HMAC_LENGTH + + calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...)) + if slices.Compare(hmac, calculated_hmac[:]) != 0 { + return nil, fmt.Errorf("Session connect bad HMAC") + } + + addr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, fmt.Errorf("Error parsing return address: %w", err) + } + + return addr, nil +} + +func NewSessionData(session *Session, packet []byte) ([]byte, error) { + iv := make([]byte, IV_LENGTH) + for i := 0; i < IV_LENGTH/8; i++ { + binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64()) + } + + stream := cipher.NewOFB(session.cipher, iv[:]) + header := make([]byte, COMMAND_LENGTH + ID_LENGTH + IV_LENGTH) + header[0] = byte(SESSION_DATA) + copy(header[COMMAND_LENGTH:], session.ID[:]) + copy(header[COMMAND_LENGTH+ID_LENGTH:], iv) + + packet_encrypted := bytes.NewBuffer(header) + writer := &cipher.StreamWriter{S: stream, W: packet_encrypted} + + // Encrypt the packet with a crc32 checksum appended to the end + _, err := io.Copy(writer, bytes.NewBuffer(binary.BigEndian.AppendUint32(packet, crc32.ChecksumIEEE(packet)))) + if err != nil { + return nil, err + } + + return packet_encrypted.Bytes(), nil +} + +func ParseSessionData(session *Session, encrypted []byte) ([]byte, error) { + iv := encrypted[0:IV_LENGTH] + + stream := cipher.NewOFB(session.cipher, iv) + var packet bytes.Buffer + reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(encrypted[IV_LENGTH:])} + _, err := io.Copy(&packet, reader) + if err != nil { + return nil, err + } + + packet_clear := packet.Bytes() + checksum := binary.BigEndian.Uint32(packet_clear[len(packet_clear)-4:]) + data := packet_clear[:len(packet_clear)-4] + calculated := crc32.ChecksumIEEE(data) + if checksum != calculated { + return nil, fmt.Errorf("SESSION_DATA checksum mismatch: %x != %x", checksum, calculated) + } + + return data, nil +} + +func NewSessionClose(session *Session) []byte { + packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CLOSE_LENGTH) + packet[0] = byte(SESSION_CLOSE) + copy(packet[1:], session.ID[:]) + + hmac := sha512.Sum512(append(session.ID[:], session.secret...)) + copy(packet[COMMAND_LENGTH + ID_LENGTH:], hmac[:]) + + return packet +} + +func ParseSessionClose(session *Session, hmac []byte) error { + calculated_hmac := sha512.Sum512(append(session.ID[:], session.secret...)) + if slices.Compare(hmac, calculated_hmac[:]) != 0 { + return fmt.Errorf("Session Close HMAC validation failed") + } else { + return nil + } +} diff --git a/packet_test.go b/session_test.go similarity index 100% rename from packet_test.go rename to session_test.go