From 8f5a5b244fc0c9a363fc0a3326df86640036d1b6 Mon Sep 17 00:00:00 2001 From: Noah Metz Date: Fri, 12 Apr 2024 18:06:57 -0600 Subject: [PATCH] Protocol improvements --- channel.go | 94 ++++++++++++++++++---- cmd/client/main.go | 190 ++++++++++++++++++++++++++++++--------------- cmd/server/main.go | 4 +- packet.go | 73 +++++++++++++---- server.go | 13 +++- 5 files changed, 282 insertions(+), 92 deletions(-) diff --git a/channel.go b/channel.go index e2399cd..73e893d 100644 --- a/channel.go +++ b/channel.go @@ -5,16 +5,20 @@ import ( "fmt" ) -type ChannelID uint32 +type ChannelID byte const ( MODE_CHANNEL ModeID = iota MODE_RAW + MODE_AUDIO CHANNEL_JOIN byte = iota CHANNEL_LEAVE CHANNEL_MEMBERS + AUDIO_SET_SAMPLE_RATE = iota + AUDIO_GET_SAMPLE_RATE + RAW_DATA = iota ) @@ -38,9 +42,9 @@ func(channel *Channel) Data(session *Session, mode ModeID, data []byte) []SendPa } } -func(channel *Channel) Command(session *Session, mode ModeID, command byte, data []byte) ([]SendPacket, error) { - if mode == MODE_CHANNEL { - switch command { +func(channel *Channel) Command(session *Session, packet ChannelCommandPacket) ([]SendPacket, error) { + if packet.Mode == MODE_CHANNEL { + switch packet.Command { case CHANNEL_JOIN: if slices.Contains(channel.sessions, session.ID) { return nil, fmt.Errorf("Session %s already in channel %d, can't join", session.ID, channel.id) @@ -57,14 +61,14 @@ func(channel *Channel) Command(session *Session, mode ModeID, command byte, data return nil, nil } default: - return nil, fmt.Errorf("Unknown MODE_CHANNEL command: 0x%02x", command) + return nil, fmt.Errorf("Unknown MODE_CHANNEL command: 0x%02x", packet.Command) } } else { - mode, has_mode := channel.modes[mode] + mode, has_mode := channel.modes[packet.Mode] if has_mode == false { return nil, fmt.Errorf("Channel has no mode 0x%02x", mode) } else { - return mode.Command(session, channel, command, data) + return mode.Command(session, channel, packet) } } } @@ -76,14 +80,15 @@ type SendPacket struct { type Mode interface { // Process takes incoming packets from a session and returns a list of packets to send - Command(session *Session, channel *Channel, command byte, data []byte) ([]SendPacket, error) + Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) Data(session *Session, channel *Channel, data []byte) []SendPacket } -func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPacket { - send_packets := make([]SendPacket, len(sessions)) - for i, session_id := range(sessions) { - if session_id == session.ID { +func multiplex_without_sender(origin SessionID, packet *Packet, sessions []SessionID) []SendPacket { + send_packets := make([]SendPacket, len(sessions) - 1) + i := 0 + for _, session_id := range(sessions) { + if session_id == origin { continue } @@ -91,6 +96,19 @@ func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPac Packet: packet, Session: session_id, } + i += 1 + } + + return send_packets +} + +func multiplex(packet *Packet, sessions []SessionID) []SendPacket { + send_packets := make([]SendPacket, len(sessions)) + for i, session_id := range(sessions) { + send_packets[i] = SendPacket{ + Packet: packet, + Session: session_id, + } } return send_packets @@ -99,14 +117,60 @@ func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPac type RawMode struct { } -func(mode *RawMode) Command(session *Session, channel *Channel, command byte, data []byte) ([]SendPacket, error) { - return nil, fmt.Errorf("unknown raw mode command 0x%02x", command) +func(mode *RawMode) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) { + return nil, fmt.Errorf("unknown raw mode command 0x%02x", packet.Command) } func(mode *RawMode) Data(session *Session, channel *Channel, data []byte) []SendPacket { if slices.Contains(channel.sessions, session.ID) { new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_RAW, data) - return multiplex(session, new_packet, channel.sessions) + return multiplex_without_sender(session.ID, new_packet, channel.sessions) + } + return nil +} + +type SampleRate byte +const ( + SAMPLE_RATE_UNSET SampleRate = 0xFF + SAMPLE_RATE_24KHZ = 0x01 + SAMPLE_RATE_48KHZ = 0x02 +) + +type AudioMode struct { + SampleRate SampleRate +} + +func(mode *AudioMode) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) { + switch packet.Command { + case AUDIO_SET_SAMPLE_RATE: + if len(packet.Data) == 1 { + switch SampleRate(packet.Data[0]) { + case SAMPLE_RATE_24KHZ: + fallthrough + case SAMPLE_RATE_48KHZ: + mode.SampleRate = SampleRate(packet.Data[0]) + update_packet := NewChannelCommandPacket(packet.ReqID, channel.id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, packet.Data) + return multiplex(update_packet, channel.sessions), nil + default: + return nil, fmt.Errorf("Invalid sample rate: %x", packet.Data[0]) + } + } else { + return nil, fmt.Errorf("Invalid AUDIO_SET_SAMPLE_RATE payload: %x", packet.Data) + } + case AUDIO_GET_SAMPLE_RATE: + return []SendPacket{{ + Packet: NewChannelCommandPacket(packet.ReqID, channel.id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, []byte{byte(mode.SampleRate)}), + Session: session.ID, + }}, nil + default: + return nil, fmt.Errorf("unknown audio mode command 0x%02x", packet.Command) + } +} + +func(mode *AudioMode) Data(session *Session, channel *Channel, data []byte) []SendPacket { + if slices.Contains(channel.sessions, session.ID) { + new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_AUDIO, data) + return multiplex_without_sender(session.ID, new_packet, channel.sessions) } return nil } diff --git a/cmd/client/main.go b/cmd/client/main.go index 5d55de1..c49b650 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -8,16 +8,83 @@ import ( "git.metznet.ca/MetzNet/pnyx" "github.com/gen2brain/malgo" + "github.com/google/uuid" "github.com/hraban/opus" ) -func main() { - decoders := map[pnyx.PeerID]chan[]byte{} - encoder, err := opus.NewEncoder(48000, 1, opus.AppVoIP) +var decoders = map[pnyx.PeerID]chan[]byte{} +var encoder *opus.Encoder +var sample_rate int = 0 +var mic = make(chan []byte, 0) +var speaker = make(chan []byte, 0) + +func set_sample_rate(new_sample_rate int) error { + sample_rate = new_sample_rate + + var err error + fmt.Printf("Creating encoder with sample_rate %d\n", new_sample_rate) + encoder, err = opus.NewEncoder(new_sample_rate, 1, opus.AppVoIP) + if err != nil { + return err + } + + for peer_id, decoder_chan := range(decoders) { + if decoder_chan != nil { + decoder_chan <- nil + } + new_chan := make(chan[]byte, 1000) + decoders[peer_id] = new_chan + go handle_peer_decode(peer_id, decoders[peer_id], sample_rate) + } + return nil +} + +func handle_peer_decode(peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){ + fmt.Printf("Starting decoder routine for %x with sample_rate %d\n", peer_id, sample_rate) + decoder, err := opus.NewDecoder(sample_rate, 1) if err != nil { panic(err) } + running := true + for running { + select { + case <-time.After(20*time.Millisecond): + pcm := make([]int16, sample_rate/50) + err := decoder.DecodePLC(pcm) + if err != nil { + panic(err) + } + + pcm_bytes := make([]byte, sample_rate/50*2) + for i := 0; i < sample_rate/50; i++ { + binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i])) + } + speaker <- pcm_bytes + + case data := <-decode_chan: + if data == nil { + running = false + } else { + pcm := make([]int16, sample_rate/50*2) + written, err := decoder.Decode(data, pcm) + if err != nil { + panic(err) + } + + pcm_bytes := make([]byte, written*2) + for i := 0; i < written; i++ { + binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i])) + } + + speaker <- pcm_bytes + } + } + } + fmt.Printf("Stopping decoder routine for %x with sample_rate %d\n", peer_id, sample_rate) +} + +func main() { ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) if err != nil { panic(err) @@ -84,9 +151,6 @@ func main() { outDeviceConfig.Alsa.NoMMap = 1 outDeviceConfig.Playback.ShareMode = malgo.Shared - mic := make(chan []byte, 0) - speaker := make(chan []byte, 0) - onSendFrames := func(output_samples []byte, input_samples []byte, framecount uint32) { select { case data := <- speaker: @@ -115,20 +179,21 @@ func main() { defer outDevice.Stop() onRecvFrames := func(output_samples []byte, input_samples []byte, framecount uint32) { - pcm := make([]int16, len(input_samples)/2) - for i := 0; i < len(input_samples)/2; i++ { - pcm[i] = int16(binary.LittleEndian.Uint16(input_samples[2*i:])) - } - - data := make([]byte, len(input_samples)) - written, err := encoder.Encode(pcm, data) - if err != nil { - panic(err) - } + if encoder != nil { + pcm := make([]int16, len(input_samples)/2) + for i := 0; i < len(input_samples)/2; i++ { + pcm[i] = int16(binary.LittleEndian.Uint16(input_samples[2*i:])) + } - select { - case mic <- data[:written]: - default: + data := make([]byte, len(input_samples)) + written, err := encoder.Encode(pcm, data) + if err != nil { + panic(err) + } + select { + case mic <- data[:written]: + default: + } } } @@ -182,51 +247,42 @@ func main() { } switch packet := packet.(type) { - case pnyx.ChannelPeerPacket: + case pnyx.ChannelCommandPacket: if packet.Channel == pnyx.ChannelID(0) { - decode_chan, exists := decoders[packet.Peer] - if exists == false { - decode_chan = make(chan[]byte, 1000) - decoders[packet.Peer] = decode_chan - - go func(decode_chan chan[]byte){ - decoder, err := opus.NewDecoder(48000, 1) + if packet.Mode == pnyx.MODE_AUDIO { + if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE { + fmt.Printf("GOT NEW SAMPLE RATE 0x%02x\n", packet.Data) + var new_sample_rate int + switch packet.Data[0] { + case byte(pnyx.SAMPLE_RATE_24KHZ): + new_sample_rate = 24000 + case byte(pnyx.SAMPLE_RATE_48KHZ): + new_sample_rate = 48000 + default: + continue + } + err := set_sample_rate(new_sample_rate) if err != nil { panic(err) } - - for true { - select { - case <-time.After(20*time.Millisecond): - pcm := make([]int16, 960) - err := decoder.DecodePLC(pcm) - if err != nil { - panic(err) - } - - pcm_bytes := make([]byte, 960*2) - for i := 0; i < 960; i++ { - binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i])) - } - speaker <- pcm_bytes - case data := <-decode_chan: - pcm := make([]int16, 960) - written, err := decoder.Decode(data, pcm) - if err != nil { - panic(err) - } - - pcm_bytes := make([]byte, written*2) - for i := 0; i < written; i++ { - binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i])) - } - speaker <- pcm_bytes - } - } - - }(decoders[packet.Peer]) + } + } + } + case pnyx.ChannelPeerPacket: + if packet.Channel == pnyx.ChannelID(0) { + decode_chan, exists := decoders[packet.Peer] + if exists == false { + if sample_rate != 0 { + decode_chan = make(chan[]byte, 1000) + decoders[packet.Peer] = decode_chan + go handle_peer_decode(packet.Peer, decoders[packet.Peer], sample_rate) + decode_chan <- packet.Data + } else { + decoders[packet.Peer] = nil + } + } else if decode_chan != nil { + decode_chan <- packet.Data } - decode_chan <- packet.Data } default: fmt.Printf("Unhandled packet type: %s\n", packet) @@ -234,15 +290,27 @@ func main() { } }() - join_packet, _ := pnyx.NewChannelCommandPacket(pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_JOIN, nil) + add_packet, _ := pnyx.NewServerCommandPacket(pnyx.SERVER_COMMAND_ADD_CHANNEL, []byte{0xFF}) + err = client.Send(add_packet) + if err != nil { + panic(err) + } + + join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_JOIN, nil) err = client.Send(join_packet) if err != nil { panic(err) } + get_sample_rate_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_AUDIO, pnyx.AUDIO_SET_SAMPLE_RATE, []byte{byte(pnyx.SAMPLE_RATE_48KHZ)}) + err = client.Send(get_sample_rate_packet) + if err != nil { + panic(err) + } + for true { data := <- mic - err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_RAW, data)) + err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) if err != nil { panic(err) } diff --git a/cmd/server/main.go b/cmd/server/main.go index 42b3882..77d63a8 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -23,7 +23,9 @@ func main() { panic(err) } - err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}) + err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}, &pnyx.AudioMode{ + SampleRate: pnyx.SAMPLE_RATE_24KHZ, + }) if err != nil { panic(err) } diff --git a/packet.go b/packet.go index ecb9557..0850460 100644 --- a/packet.go +++ b/packet.go @@ -2,7 +2,6 @@ package pnyx import ( "encoding" - "encoding/binary" "fmt" "github.com/google/uuid" @@ -10,13 +9,17 @@ import ( type PacketType uint8 const ( - PACKET_CHANNEL_COMMAND PacketType = iota + PACKET_SERVER_COMMAND PacketType = iota + PACKET_CHANNEL_COMMAND PACKET_CHANNEL_DATA PACKET_CHANNEL_PEER - CHANNEL_HEADER_LEN int = 5 + CHANNEL_HEADER_LEN int = 2 CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH + + SERVER_COMMAND_ADD_CHANNEL byte = iota + SERVER_COMMAND_DEL_CHANNEL ) type Payload interface { @@ -43,6 +46,8 @@ func ParsePacket(data []byte) (Payload, error) { } switch PacketType(data[0]) { + case PACKET_SERVER_COMMAND: + return ParseServerCommandPacket(data[1:]) case PACKET_CHANNEL_DATA: return ParseChannelDataPacket(data[1:]) case PACKET_CHANNEL_COMMAND: @@ -54,24 +59,62 @@ func ParsePacket(data []byte) (Payload, error) { } } +type ServerCommandPacket struct { + ReqID uuid.UUID + Command byte + Data []byte +} + +func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) { + p := make([]byte, 17 + len(packet.Data)) + copy(p, packet.ReqID[:]) + p[16] = packet.Command + copy(p[17:], packet.Data) + + return p, nil +} + +func NewServerCommandPacket(command byte, data []byte) (*Packet, uuid.UUID) { + req_id := uuid.New() + return &Packet{ + Type: PACKET_SERVER_COMMAND, + Payload: ServerCommandPacket{ + ReqID: req_id, + Command: command, + Data: data, + }, + }, req_id +} + +func ParseServerCommandPacket(data []byte) (ServerCommandPacket, error) { + if len(data) < 17 { + return ServerCommandPacket{}, fmt.Errorf("Not enough data to decode ServerCommandPacket: %d/%d", len(data), 17) + } + + return ServerCommandPacket{ + ReqID: uuid.UUID(data[0:16]), + Command: data[16], + Data: data[17:], + }, nil +} + type ChannelHeader struct { Channel ChannelID Mode ModeID } func(packet ChannelHeader) MarshalBinary() ([]byte, error) { - p := binary.BigEndian.AppendUint32(nil, uint32(packet.Channel)) - return append(p, byte(packet.Mode)), nil + return []byte{byte(packet.Channel), byte(packet.Mode)}, nil } func ParseChannelHeader(data []byte) (ChannelHeader, error) { - if len(data) < 5 { + if len(data) < 2 { return ChannelHeader{}, fmt.Errorf("Not enough bytes to parse ChannelPacket(%d/%d)", len(data), 6) } return ChannelHeader{ - Channel: ChannelID(binary.BigEndian.Uint32(data)), - Mode: ModeID(data[4]), + Channel: ChannelID(data[0]), + Mode: ModeID(data[1]), }, nil } @@ -82,8 +125,7 @@ type ChannelCommandPacket struct { Data []byte } -func NewChannelCommandPacket(channel ChannelID, mode ModeID, command byte, data []byte) (*Packet, uuid.UUID) { - request_id := uuid.New() +func NewChannelCommandPacket(request_id uuid.UUID, channel ChannelID, mode ModeID, command byte, data []byte) *Packet { return &Packet{ Type: PACKET_CHANNEL_COMMAND, Payload: ChannelCommandPacket{ @@ -95,7 +137,7 @@ func NewChannelCommandPacket(channel ChannelID, mode ModeID, command byte, data ReqID: request_id, Data: data, }, - }, request_id + } } func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) { @@ -104,9 +146,12 @@ func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) { return nil, err } - data := append(header, packet.Command) - data = append(data, packet.ReqID[:]...) - return append(data, packet.Data...), nil + data := make([]byte, len(header) + len(packet.Data) + REQ_ID_LENGTH + COMMAND_LENGTH) + copy(data, header) + data[CHANNEL_HEADER_LEN] = packet.Command + copy(data[CHANNEL_HEADER_LEN + COMMAND_LENGTH:], packet.ReqID[:]) + copy(data[CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH:], packet.Data) + return data, nil } func ParseChannelCommandPacket(data []byte) (ChannelCommandPacket, error) { diff --git a/server.go b/server.go index 346eb00..9c5f62b 100644 --- a/server.go +++ b/server.go @@ -64,6 +64,7 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) { modes: map[reflect.Type]ModeID{ reflect.TypeFor[*RawMode](): MODE_RAW, + reflect.TypeFor[*AudioMode](): MODE_AUDIO, }, sessions: map[SessionID]*ServerSession{}, @@ -195,6 +196,15 @@ func handle_session_incoming(session *ServerSession, server *Server) { } switch packet := packet.(type) { + case ServerCommandPacket: + switch packet.Command { + case SERVER_COMMAND_ADD_CHANNEL: + server.Log("Got add_channel with %x", packet.Data) + case SERVER_COMMAND_DEL_CHANNEL: + server.Log("Got del_channel with %x", packet.Data) + default: + server.Log("Unknown server command %x", packet.Command) + } case ChannelDataPacket: var result []SendPacket = nil @@ -218,7 +228,7 @@ func handle_session_incoming(session *ServerSession, server *Server) { server.channels_lock.RLock() channel, exists := server.channels[packet.Channel] if exists == true { - result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data) + result, err = channel.Command(&session.Session, packet) } server.channels_lock.RUnlock() @@ -314,6 +324,7 @@ func(server *Server) handle_session_close(session_close []byte, from *net.UDPAdd return nil } +// TODO: handle packets without creating so many objects(and preventing an extra copy) func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error { session_id := SessionID(data[:SESSION_ID_LENGTH]) session, exists := server.sessions[session_id]