Protocol improvements

live
noah metz 2024-04-12 18:06:57 -06:00
parent 21e9794747
commit 8f5a5b244f
5 changed files with 282 additions and 92 deletions

@ -5,16 +5,20 @@ import (
"fmt" "fmt"
) )
type ChannelID uint32 type ChannelID byte
const ( const (
MODE_CHANNEL ModeID = iota MODE_CHANNEL ModeID = iota
MODE_RAW MODE_RAW
MODE_AUDIO
CHANNEL_JOIN byte = iota CHANNEL_JOIN byte = iota
CHANNEL_LEAVE CHANNEL_LEAVE
CHANNEL_MEMBERS CHANNEL_MEMBERS
AUDIO_SET_SAMPLE_RATE = iota
AUDIO_GET_SAMPLE_RATE
RAW_DATA = iota RAW_DATA = iota
) )
@ -38,9 +42,9 @@ func(channel *Channel) Data(session *Session, mode ModeID, data []byte) []SendPa
} }
} }
func(channel *Channel) Command(session *Session, mode ModeID, command byte, data []byte) ([]SendPacket, error) { func(channel *Channel) Command(session *Session, packet ChannelCommandPacket) ([]SendPacket, error) {
if mode == MODE_CHANNEL { if packet.Mode == MODE_CHANNEL {
switch command { switch packet.Command {
case CHANNEL_JOIN: case CHANNEL_JOIN:
if slices.Contains(channel.sessions, session.ID) { if slices.Contains(channel.sessions, session.ID) {
return nil, fmt.Errorf("Session %s already in channel %d, can't join", session.ID, channel.id) return nil, fmt.Errorf("Session %s already in channel %d, can't join", session.ID, channel.id)
@ -57,14 +61,14 @@ func(channel *Channel) Command(session *Session, mode ModeID, command byte, data
return nil, nil return nil, nil
} }
default: default:
return nil, fmt.Errorf("Unknown MODE_CHANNEL command: 0x%02x", command) return nil, fmt.Errorf("Unknown MODE_CHANNEL command: 0x%02x", packet.Command)
} }
} else { } else {
mode, has_mode := channel.modes[mode] mode, has_mode := channel.modes[packet.Mode]
if has_mode == false { if has_mode == false {
return nil, fmt.Errorf("Channel has no mode 0x%02x", mode) return nil, fmt.Errorf("Channel has no mode 0x%02x", mode)
} else { } else {
return mode.Command(session, channel, command, data) return mode.Command(session, channel, packet)
} }
} }
} }
@ -76,14 +80,15 @@ type SendPacket struct {
type Mode interface { type Mode interface {
// Process takes incoming packets from a session and returns a list of packets to send // Process takes incoming packets from a session and returns a list of packets to send
Command(session *Session, channel *Channel, command byte, data []byte) ([]SendPacket, error) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error)
Data(session *Session, channel *Channel, data []byte) []SendPacket Data(session *Session, channel *Channel, data []byte) []SendPacket
} }
func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPacket { func multiplex_without_sender(origin SessionID, packet *Packet, sessions []SessionID) []SendPacket {
send_packets := make([]SendPacket, len(sessions)) send_packets := make([]SendPacket, len(sessions) - 1)
for i, session_id := range(sessions) { i := 0
if session_id == session.ID { for _, session_id := range(sessions) {
if session_id == origin {
continue continue
} }
@ -91,6 +96,19 @@ func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPac
Packet: packet, Packet: packet,
Session: session_id, Session: session_id,
} }
i += 1
}
return send_packets
}
func multiplex(packet *Packet, sessions []SessionID) []SendPacket {
send_packets := make([]SendPacket, len(sessions))
for i, session_id := range(sessions) {
send_packets[i] = SendPacket{
Packet: packet,
Session: session_id,
}
} }
return send_packets return send_packets
@ -99,14 +117,60 @@ func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPac
type RawMode struct { type RawMode struct {
} }
func(mode *RawMode) Command(session *Session, channel *Channel, command byte, data []byte) ([]SendPacket, error) { func(mode *RawMode) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) {
return nil, fmt.Errorf("unknown raw mode command 0x%02x", command) return nil, fmt.Errorf("unknown raw mode command 0x%02x", packet.Command)
} }
func(mode *RawMode) Data(session *Session, channel *Channel, data []byte) []SendPacket { func(mode *RawMode) Data(session *Session, channel *Channel, data []byte) []SendPacket {
if slices.Contains(channel.sessions, session.ID) { if slices.Contains(channel.sessions, session.ID) {
new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_RAW, data) new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_RAW, data)
return multiplex(session, new_packet, channel.sessions) return multiplex_without_sender(session.ID, new_packet, channel.sessions)
}
return nil
}
type SampleRate byte
const (
SAMPLE_RATE_UNSET SampleRate = 0xFF
SAMPLE_RATE_24KHZ = 0x01
SAMPLE_RATE_48KHZ = 0x02
)
type AudioMode struct {
SampleRate SampleRate
}
func(mode *AudioMode) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) {
switch packet.Command {
case AUDIO_SET_SAMPLE_RATE:
if len(packet.Data) == 1 {
switch SampleRate(packet.Data[0]) {
case SAMPLE_RATE_24KHZ:
fallthrough
case SAMPLE_RATE_48KHZ:
mode.SampleRate = SampleRate(packet.Data[0])
update_packet := NewChannelCommandPacket(packet.ReqID, channel.id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, packet.Data)
return multiplex(update_packet, channel.sessions), nil
default:
return nil, fmt.Errorf("Invalid sample rate: %x", packet.Data[0])
}
} else {
return nil, fmt.Errorf("Invalid AUDIO_SET_SAMPLE_RATE payload: %x", packet.Data)
}
case AUDIO_GET_SAMPLE_RATE:
return []SendPacket{{
Packet: NewChannelCommandPacket(packet.ReqID, channel.id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, []byte{byte(mode.SampleRate)}),
Session: session.ID,
}}, nil
default:
return nil, fmt.Errorf("unknown audio mode command 0x%02x", packet.Command)
}
}
func(mode *AudioMode) Data(session *Session, channel *Channel, data []byte) []SendPacket {
if slices.Contains(channel.sessions, session.ID) {
new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_AUDIO, data)
return multiplex_without_sender(session.ID, new_packet, channel.sessions)
} }
return nil return nil
} }

@ -8,16 +8,83 @@ import (
"git.metznet.ca/MetzNet/pnyx" "git.metznet.ca/MetzNet/pnyx"
"github.com/gen2brain/malgo" "github.com/gen2brain/malgo"
"github.com/google/uuid"
"github.com/hraban/opus" "github.com/hraban/opus"
) )
func main() { var decoders = map[pnyx.PeerID]chan[]byte{}
decoders := map[pnyx.PeerID]chan[]byte{} var encoder *opus.Encoder
encoder, err := opus.NewEncoder(48000, 1, opus.AppVoIP) var sample_rate int = 0
var mic = make(chan []byte, 0)
var speaker = make(chan []byte, 0)
func set_sample_rate(new_sample_rate int) error {
sample_rate = new_sample_rate
var err error
fmt.Printf("Creating encoder with sample_rate %d\n", new_sample_rate)
encoder, err = opus.NewEncoder(new_sample_rate, 1, opus.AppVoIP)
if err != nil {
return err
}
for peer_id, decoder_chan := range(decoders) {
if decoder_chan != nil {
decoder_chan <- nil
}
new_chan := make(chan[]byte, 1000)
decoders[peer_id] = new_chan
go handle_peer_decode(peer_id, decoders[peer_id], sample_rate)
}
return nil
}
func handle_peer_decode(peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){
fmt.Printf("Starting decoder routine for %x with sample_rate %d\n", peer_id, sample_rate)
decoder, err := opus.NewDecoder(sample_rate, 1)
if err != nil { if err != nil {
panic(err) panic(err)
} }
running := true
for running {
select {
case <-time.After(20*time.Millisecond):
pcm := make([]int16, sample_rate/50)
err := decoder.DecodePLC(pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, sample_rate/50*2)
for i := 0; i < sample_rate/50; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
case data := <-decode_chan:
if data == nil {
running = false
} else {
pcm := make([]int16, sample_rate/50*2)
written, err := decoder.Decode(data, pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, written*2)
for i := 0; i < written; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
}
}
}
fmt.Printf("Stopping decoder routine for %x with sample_rate %d\n", peer_id, sample_rate)
}
func main() {
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil)
if err != nil { if err != nil {
panic(err) panic(err)
@ -84,9 +151,6 @@ func main() {
outDeviceConfig.Alsa.NoMMap = 1 outDeviceConfig.Alsa.NoMMap = 1
outDeviceConfig.Playback.ShareMode = malgo.Shared outDeviceConfig.Playback.ShareMode = malgo.Shared
mic := make(chan []byte, 0)
speaker := make(chan []byte, 0)
onSendFrames := func(output_samples []byte, input_samples []byte, framecount uint32) { onSendFrames := func(output_samples []byte, input_samples []byte, framecount uint32) {
select { select {
case data := <- speaker: case data := <- speaker:
@ -115,20 +179,21 @@ func main() {
defer outDevice.Stop() defer outDevice.Stop()
onRecvFrames := func(output_samples []byte, input_samples []byte, framecount uint32) { onRecvFrames := func(output_samples []byte, input_samples []byte, framecount uint32) {
pcm := make([]int16, len(input_samples)/2) if encoder != nil {
for i := 0; i < len(input_samples)/2; i++ { pcm := make([]int16, len(input_samples)/2)
pcm[i] = int16(binary.LittleEndian.Uint16(input_samples[2*i:])) for i := 0; i < len(input_samples)/2; i++ {
} pcm[i] = int16(binary.LittleEndian.Uint16(input_samples[2*i:]))
}
data := make([]byte, len(input_samples))
written, err := encoder.Encode(pcm, data)
if err != nil {
panic(err)
}
select { data := make([]byte, len(input_samples))
case mic <- data[:written]: written, err := encoder.Encode(pcm, data)
default: if err != nil {
panic(err)
}
select {
case mic <- data[:written]:
default:
}
} }
} }
@ -182,51 +247,42 @@ func main() {
} }
switch packet := packet.(type) { switch packet := packet.(type) {
case pnyx.ChannelPeerPacket: case pnyx.ChannelCommandPacket:
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
decode_chan, exists := decoders[packet.Peer] if packet.Mode == pnyx.MODE_AUDIO {
if exists == false { if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE {
decode_chan = make(chan[]byte, 1000) fmt.Printf("GOT NEW SAMPLE RATE 0x%02x\n", packet.Data)
decoders[packet.Peer] = decode_chan var new_sample_rate int
switch packet.Data[0] {
go func(decode_chan chan[]byte){ case byte(pnyx.SAMPLE_RATE_24KHZ):
decoder, err := opus.NewDecoder(48000, 1) new_sample_rate = 24000
case byte(pnyx.SAMPLE_RATE_48KHZ):
new_sample_rate = 48000
default:
continue
}
err := set_sample_rate(new_sample_rate)
if err != nil { if err != nil {
panic(err) panic(err)
} }
}
for true { }
select { }
case <-time.After(20*time.Millisecond): case pnyx.ChannelPeerPacket:
pcm := make([]int16, 960) if packet.Channel == pnyx.ChannelID(0) {
err := decoder.DecodePLC(pcm) decode_chan, exists := decoders[packet.Peer]
if err != nil { if exists == false {
panic(err) if sample_rate != 0 {
} decode_chan = make(chan[]byte, 1000)
decoders[packet.Peer] = decode_chan
pcm_bytes := make([]byte, 960*2) go handle_peer_decode(packet.Peer, decoders[packet.Peer], sample_rate)
for i := 0; i < 960; i++ { decode_chan <- packet.Data
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i])) } else {
} decoders[packet.Peer] = nil
speaker <- pcm_bytes }
case data := <-decode_chan: } else if decode_chan != nil {
pcm := make([]int16, 960) decode_chan <- packet.Data
written, err := decoder.Decode(data, pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, written*2)
for i := 0; i < written; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
}
}
}(decoders[packet.Peer])
} }
decode_chan <- packet.Data
} }
default: default:
fmt.Printf("Unhandled packet type: %s\n", packet) fmt.Printf("Unhandled packet type: %s\n", packet)
@ -234,15 +290,27 @@ func main() {
} }
}() }()
join_packet, _ := pnyx.NewChannelCommandPacket(pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_JOIN, nil) add_packet, _ := pnyx.NewServerCommandPacket(pnyx.SERVER_COMMAND_ADD_CHANNEL, []byte{0xFF})
err = client.Send(add_packet)
if err != nil {
panic(err)
}
join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_JOIN, nil)
err = client.Send(join_packet) err = client.Send(join_packet)
if err != nil { if err != nil {
panic(err) panic(err)
} }
get_sample_rate_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_AUDIO, pnyx.AUDIO_SET_SAMPLE_RATE, []byte{byte(pnyx.SAMPLE_RATE_48KHZ)})
err = client.Send(get_sample_rate_packet)
if err != nil {
panic(err)
}
for true { for true {
data := <- mic data := <- mic
err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_RAW, data)) err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data))
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -23,7 +23,9 @@ func main() {
panic(err) panic(err)
} }
err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}) err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}, &pnyx.AudioMode{
SampleRate: pnyx.SAMPLE_RATE_24KHZ,
})
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -2,7 +2,6 @@ package pnyx
import ( import (
"encoding" "encoding"
"encoding/binary"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
@ -10,13 +9,17 @@ import (
type PacketType uint8 type PacketType uint8
const ( const (
PACKET_CHANNEL_COMMAND PacketType = iota PACKET_SERVER_COMMAND PacketType = iota
PACKET_CHANNEL_COMMAND
PACKET_CHANNEL_DATA PACKET_CHANNEL_DATA
PACKET_CHANNEL_PEER PACKET_CHANNEL_PEER
CHANNEL_HEADER_LEN int = 5 CHANNEL_HEADER_LEN int = 2
CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH
CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH
SERVER_COMMAND_ADD_CHANNEL byte = iota
SERVER_COMMAND_DEL_CHANNEL
) )
type Payload interface { type Payload interface {
@ -43,6 +46,8 @@ func ParsePacket(data []byte) (Payload, error) {
} }
switch PacketType(data[0]) { switch PacketType(data[0]) {
case PACKET_SERVER_COMMAND:
return ParseServerCommandPacket(data[1:])
case PACKET_CHANNEL_DATA: case PACKET_CHANNEL_DATA:
return ParseChannelDataPacket(data[1:]) return ParseChannelDataPacket(data[1:])
case PACKET_CHANNEL_COMMAND: case PACKET_CHANNEL_COMMAND:
@ -54,24 +59,62 @@ func ParsePacket(data []byte) (Payload, error) {
} }
} }
type ServerCommandPacket struct {
ReqID uuid.UUID
Command byte
Data []byte
}
func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) {
p := make([]byte, 17 + len(packet.Data))
copy(p, packet.ReqID[:])
p[16] = packet.Command
copy(p[17:], packet.Data)
return p, nil
}
func NewServerCommandPacket(command byte, data []byte) (*Packet, uuid.UUID) {
req_id := uuid.New()
return &Packet{
Type: PACKET_SERVER_COMMAND,
Payload: ServerCommandPacket{
ReqID: req_id,
Command: command,
Data: data,
},
}, req_id
}
func ParseServerCommandPacket(data []byte) (ServerCommandPacket, error) {
if len(data) < 17 {
return ServerCommandPacket{}, fmt.Errorf("Not enough data to decode ServerCommandPacket: %d/%d", len(data), 17)
}
return ServerCommandPacket{
ReqID: uuid.UUID(data[0:16]),
Command: data[16],
Data: data[17:],
}, nil
}
type ChannelHeader struct { type ChannelHeader struct {
Channel ChannelID Channel ChannelID
Mode ModeID Mode ModeID
} }
func(packet ChannelHeader) MarshalBinary() ([]byte, error) { func(packet ChannelHeader) MarshalBinary() ([]byte, error) {
p := binary.BigEndian.AppendUint32(nil, uint32(packet.Channel)) return []byte{byte(packet.Channel), byte(packet.Mode)}, nil
return append(p, byte(packet.Mode)), nil
} }
func ParseChannelHeader(data []byte) (ChannelHeader, error) { func ParseChannelHeader(data []byte) (ChannelHeader, error) {
if len(data) < 5 { if len(data) < 2 {
return ChannelHeader{}, fmt.Errorf("Not enough bytes to parse ChannelPacket(%d/%d)", len(data), 6) return ChannelHeader{}, fmt.Errorf("Not enough bytes to parse ChannelPacket(%d/%d)", len(data), 6)
} }
return ChannelHeader{ return ChannelHeader{
Channel: ChannelID(binary.BigEndian.Uint32(data)), Channel: ChannelID(data[0]),
Mode: ModeID(data[4]), Mode: ModeID(data[1]),
}, nil }, nil
} }
@ -82,8 +125,7 @@ type ChannelCommandPacket struct {
Data []byte Data []byte
} }
func NewChannelCommandPacket(channel ChannelID, mode ModeID, command byte, data []byte) (*Packet, uuid.UUID) { func NewChannelCommandPacket(request_id uuid.UUID, channel ChannelID, mode ModeID, command byte, data []byte) *Packet {
request_id := uuid.New()
return &Packet{ return &Packet{
Type: PACKET_CHANNEL_COMMAND, Type: PACKET_CHANNEL_COMMAND,
Payload: ChannelCommandPacket{ Payload: ChannelCommandPacket{
@ -95,7 +137,7 @@ func NewChannelCommandPacket(channel ChannelID, mode ModeID, command byte, data
ReqID: request_id, ReqID: request_id,
Data: data, Data: data,
}, },
}, request_id }
} }
func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) { func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) {
@ -104,9 +146,12 @@ func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) {
return nil, err return nil, err
} }
data := append(header, packet.Command) data := make([]byte, len(header) + len(packet.Data) + REQ_ID_LENGTH + COMMAND_LENGTH)
data = append(data, packet.ReqID[:]...) copy(data, header)
return append(data, packet.Data...), nil data[CHANNEL_HEADER_LEN] = packet.Command
copy(data[CHANNEL_HEADER_LEN + COMMAND_LENGTH:], packet.ReqID[:])
copy(data[CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH:], packet.Data)
return data, nil
} }
func ParseChannelCommandPacket(data []byte) (ChannelCommandPacket, error) { func ParseChannelCommandPacket(data []byte) (ChannelCommandPacket, error) {

@ -64,6 +64,7 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) {
modes: map[reflect.Type]ModeID{ modes: map[reflect.Type]ModeID{
reflect.TypeFor[*RawMode](): MODE_RAW, reflect.TypeFor[*RawMode](): MODE_RAW,
reflect.TypeFor[*AudioMode](): MODE_AUDIO,
}, },
sessions: map[SessionID]*ServerSession{}, sessions: map[SessionID]*ServerSession{},
@ -195,6 +196,15 @@ func handle_session_incoming(session *ServerSession, server *Server) {
} }
switch packet := packet.(type) { switch packet := packet.(type) {
case ServerCommandPacket:
switch packet.Command {
case SERVER_COMMAND_ADD_CHANNEL:
server.Log("Got add_channel with %x", packet.Data)
case SERVER_COMMAND_DEL_CHANNEL:
server.Log("Got del_channel with %x", packet.Data)
default:
server.Log("Unknown server command %x", packet.Command)
}
case ChannelDataPacket: case ChannelDataPacket:
var result []SendPacket = nil var result []SendPacket = nil
@ -218,7 +228,7 @@ func handle_session_incoming(session *ServerSession, server *Server) {
server.channels_lock.RLock() server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel] channel, exists := server.channels[packet.Channel]
if exists == true { if exists == true {
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data) result, err = channel.Command(&session.Session, packet)
} }
server.channels_lock.RUnlock() server.channels_lock.RUnlock()
@ -314,6 +324,7 @@ func(server *Server) handle_session_close(session_close []byte, from *net.UDPAdd
return nil return nil
} }
// TODO: handle packets without creating so many objects(and preventing an extra copy)
func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error { func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
session_id := SessionID(data[:SESSION_ID_LENGTH]) session_id := SessionID(data[:SESSION_ID_LENGTH])
session, exists := server.sessions[session_id] session, exists := server.sessions[session_id]