Protocol improvements

live
noah metz 2024-04-12 18:06:57 -06:00
parent 21e9794747
commit 8f5a5b244f
5 changed files with 282 additions and 92 deletions

@ -5,16 +5,20 @@ import (
"fmt"
)
type ChannelID uint32
type ChannelID byte
const (
MODE_CHANNEL ModeID = iota
MODE_RAW
MODE_AUDIO
CHANNEL_JOIN byte = iota
CHANNEL_LEAVE
CHANNEL_MEMBERS
AUDIO_SET_SAMPLE_RATE = iota
AUDIO_GET_SAMPLE_RATE
RAW_DATA = iota
)
@ -38,9 +42,9 @@ func(channel *Channel) Data(session *Session, mode ModeID, data []byte) []SendPa
}
}
func(channel *Channel) Command(session *Session, mode ModeID, command byte, data []byte) ([]SendPacket, error) {
if mode == MODE_CHANNEL {
switch command {
func(channel *Channel) Command(session *Session, packet ChannelCommandPacket) ([]SendPacket, error) {
if packet.Mode == MODE_CHANNEL {
switch packet.Command {
case CHANNEL_JOIN:
if slices.Contains(channel.sessions, session.ID) {
return nil, fmt.Errorf("Session %s already in channel %d, can't join", session.ID, channel.id)
@ -57,14 +61,14 @@ func(channel *Channel) Command(session *Session, mode ModeID, command byte, data
return nil, nil
}
default:
return nil, fmt.Errorf("Unknown MODE_CHANNEL command: 0x%02x", command)
return nil, fmt.Errorf("Unknown MODE_CHANNEL command: 0x%02x", packet.Command)
}
} else {
mode, has_mode := channel.modes[mode]
mode, has_mode := channel.modes[packet.Mode]
if has_mode == false {
return nil, fmt.Errorf("Channel has no mode 0x%02x", mode)
} else {
return mode.Command(session, channel, command, data)
return mode.Command(session, channel, packet)
}
}
}
@ -76,14 +80,15 @@ type SendPacket struct {
type Mode interface {
// Process takes incoming packets from a session and returns a list of packets to send
Command(session *Session, channel *Channel, command byte, data []byte) ([]SendPacket, error)
Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error)
Data(session *Session, channel *Channel, data []byte) []SendPacket
}
func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPacket {
send_packets := make([]SendPacket, len(sessions))
for i, session_id := range(sessions) {
if session_id == session.ID {
func multiplex_without_sender(origin SessionID, packet *Packet, sessions []SessionID) []SendPacket {
send_packets := make([]SendPacket, len(sessions) - 1)
i := 0
for _, session_id := range(sessions) {
if session_id == origin {
continue
}
@ -91,6 +96,19 @@ func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPac
Packet: packet,
Session: session_id,
}
i += 1
}
return send_packets
}
func multiplex(packet *Packet, sessions []SessionID) []SendPacket {
send_packets := make([]SendPacket, len(sessions))
for i, session_id := range(sessions) {
send_packets[i] = SendPacket{
Packet: packet,
Session: session_id,
}
}
return send_packets
@ -99,14 +117,60 @@ func multiplex(session *Session, packet *Packet, sessions []SessionID) []SendPac
type RawMode struct {
}
func(mode *RawMode) Command(session *Session, channel *Channel, command byte, data []byte) ([]SendPacket, error) {
return nil, fmt.Errorf("unknown raw mode command 0x%02x", command)
func(mode *RawMode) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) {
return nil, fmt.Errorf("unknown raw mode command 0x%02x", packet.Command)
}
func(mode *RawMode) Data(session *Session, channel *Channel, data []byte) []SendPacket {
if slices.Contains(channel.sessions, session.ID) {
new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_RAW, data)
return multiplex(session, new_packet, channel.sessions)
return multiplex_without_sender(session.ID, new_packet, channel.sessions)
}
return nil
}
type SampleRate byte
const (
SAMPLE_RATE_UNSET SampleRate = 0xFF
SAMPLE_RATE_24KHZ = 0x01
SAMPLE_RATE_48KHZ = 0x02
)
type AudioMode struct {
SampleRate SampleRate
}
func(mode *AudioMode) Command(session *Session, channel *Channel, packet ChannelCommandPacket) ([]SendPacket, error) {
switch packet.Command {
case AUDIO_SET_SAMPLE_RATE:
if len(packet.Data) == 1 {
switch SampleRate(packet.Data[0]) {
case SAMPLE_RATE_24KHZ:
fallthrough
case SAMPLE_RATE_48KHZ:
mode.SampleRate = SampleRate(packet.Data[0])
update_packet := NewChannelCommandPacket(packet.ReqID, channel.id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, packet.Data)
return multiplex(update_packet, channel.sessions), nil
default:
return nil, fmt.Errorf("Invalid sample rate: %x", packet.Data[0])
}
} else {
return nil, fmt.Errorf("Invalid AUDIO_SET_SAMPLE_RATE payload: %x", packet.Data)
}
case AUDIO_GET_SAMPLE_RATE:
return []SendPacket{{
Packet: NewChannelCommandPacket(packet.ReqID, channel.id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, []byte{byte(mode.SampleRate)}),
Session: session.ID,
}}, nil
default:
return nil, fmt.Errorf("unknown audio mode command 0x%02x", packet.Command)
}
}
func(mode *AudioMode) Data(session *Session, channel *Channel, data []byte) []SendPacket {
if slices.Contains(channel.sessions, session.ID) {
new_packet := NewChannelPeerPacket(session.Peer, channel.id, MODE_AUDIO, data)
return multiplex_without_sender(session.ID, new_packet, channel.sessions)
}
return nil
}

@ -8,16 +8,83 @@ import (
"git.metznet.ca/MetzNet/pnyx"
"github.com/gen2brain/malgo"
"github.com/google/uuid"
"github.com/hraban/opus"
)
func main() {
decoders := map[pnyx.PeerID]chan[]byte{}
encoder, err := opus.NewEncoder(48000, 1, opus.AppVoIP)
var decoders = map[pnyx.PeerID]chan[]byte{}
var encoder *opus.Encoder
var sample_rate int = 0
var mic = make(chan []byte, 0)
var speaker = make(chan []byte, 0)
func set_sample_rate(new_sample_rate int) error {
sample_rate = new_sample_rate
var err error
fmt.Printf("Creating encoder with sample_rate %d\n", new_sample_rate)
encoder, err = opus.NewEncoder(new_sample_rate, 1, opus.AppVoIP)
if err != nil {
return err
}
for peer_id, decoder_chan := range(decoders) {
if decoder_chan != nil {
decoder_chan <- nil
}
new_chan := make(chan[]byte, 1000)
decoders[peer_id] = new_chan
go handle_peer_decode(peer_id, decoders[peer_id], sample_rate)
}
return nil
}
func handle_peer_decode(peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){
fmt.Printf("Starting decoder routine for %x with sample_rate %d\n", peer_id, sample_rate)
decoder, err := opus.NewDecoder(sample_rate, 1)
if err != nil {
panic(err)
}
running := true
for running {
select {
case <-time.After(20*time.Millisecond):
pcm := make([]int16, sample_rate/50)
err := decoder.DecodePLC(pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, sample_rate/50*2)
for i := 0; i < sample_rate/50; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
case data := <-decode_chan:
if data == nil {
running = false
} else {
pcm := make([]int16, sample_rate/50*2)
written, err := decoder.Decode(data, pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, written*2)
for i := 0; i < written; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
}
}
}
fmt.Printf("Stopping decoder routine for %x with sample_rate %d\n", peer_id, sample_rate)
}
func main() {
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil)
if err != nil {
panic(err)
@ -84,9 +151,6 @@ func main() {
outDeviceConfig.Alsa.NoMMap = 1
outDeviceConfig.Playback.ShareMode = malgo.Shared
mic := make(chan []byte, 0)
speaker := make(chan []byte, 0)
onSendFrames := func(output_samples []byte, input_samples []byte, framecount uint32) {
select {
case data := <- speaker:
@ -115,20 +179,21 @@ func main() {
defer outDevice.Stop()
onRecvFrames := func(output_samples []byte, input_samples []byte, framecount uint32) {
pcm := make([]int16, len(input_samples)/2)
for i := 0; i < len(input_samples)/2; i++ {
pcm[i] = int16(binary.LittleEndian.Uint16(input_samples[2*i:]))
}
data := make([]byte, len(input_samples))
written, err := encoder.Encode(pcm, data)
if err != nil {
panic(err)
}
if encoder != nil {
pcm := make([]int16, len(input_samples)/2)
for i := 0; i < len(input_samples)/2; i++ {
pcm[i] = int16(binary.LittleEndian.Uint16(input_samples[2*i:]))
}
select {
case mic <- data[:written]:
default:
data := make([]byte, len(input_samples))
written, err := encoder.Encode(pcm, data)
if err != nil {
panic(err)
}
select {
case mic <- data[:written]:
default:
}
}
}
@ -182,51 +247,42 @@ func main() {
}
switch packet := packet.(type) {
case pnyx.ChannelPeerPacket:
case pnyx.ChannelCommandPacket:
if packet.Channel == pnyx.ChannelID(0) {
decode_chan, exists := decoders[packet.Peer]
if exists == false {
decode_chan = make(chan[]byte, 1000)
decoders[packet.Peer] = decode_chan
go func(decode_chan chan[]byte){
decoder, err := opus.NewDecoder(48000, 1)
if packet.Mode == pnyx.MODE_AUDIO {
if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE {
fmt.Printf("GOT NEW SAMPLE RATE 0x%02x\n", packet.Data)
var new_sample_rate int
switch packet.Data[0] {
case byte(pnyx.SAMPLE_RATE_24KHZ):
new_sample_rate = 24000
case byte(pnyx.SAMPLE_RATE_48KHZ):
new_sample_rate = 48000
default:
continue
}
err := set_sample_rate(new_sample_rate)
if err != nil {
panic(err)
}
for true {
select {
case <-time.After(20*time.Millisecond):
pcm := make([]int16, 960)
err := decoder.DecodePLC(pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, 960*2)
for i := 0; i < 960; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
case data := <-decode_chan:
pcm := make([]int16, 960)
written, err := decoder.Decode(data, pcm)
if err != nil {
panic(err)
}
pcm_bytes := make([]byte, written*2)
for i := 0; i < written; i++ {
binary.LittleEndian.PutUint16(pcm_bytes[i*2:], uint16(pcm[i]))
}
speaker <- pcm_bytes
}
}
}(decoders[packet.Peer])
}
}
}
case pnyx.ChannelPeerPacket:
if packet.Channel == pnyx.ChannelID(0) {
decode_chan, exists := decoders[packet.Peer]
if exists == false {
if sample_rate != 0 {
decode_chan = make(chan[]byte, 1000)
decoders[packet.Peer] = decode_chan
go handle_peer_decode(packet.Peer, decoders[packet.Peer], sample_rate)
decode_chan <- packet.Data
} else {
decoders[packet.Peer] = nil
}
} else if decode_chan != nil {
decode_chan <- packet.Data
}
decode_chan <- packet.Data
}
default:
fmt.Printf("Unhandled packet type: %s\n", packet)
@ -234,15 +290,27 @@ func main() {
}
}()
join_packet, _ := pnyx.NewChannelCommandPacket(pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_JOIN, nil)
add_packet, _ := pnyx.NewServerCommandPacket(pnyx.SERVER_COMMAND_ADD_CHANNEL, []byte{0xFF})
err = client.Send(add_packet)
if err != nil {
panic(err)
}
join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_JOIN, nil)
err = client.Send(join_packet)
if err != nil {
panic(err)
}
get_sample_rate_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_AUDIO, pnyx.AUDIO_SET_SAMPLE_RATE, []byte{byte(pnyx.SAMPLE_RATE_48KHZ)})
err = client.Send(get_sample_rate_packet)
if err != nil {
panic(err)
}
for true {
data := <- mic
err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_RAW, data))
err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data))
if err != nil {
panic(err)
}

@ -23,7 +23,9 @@ func main() {
panic(err)
}
err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{})
err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}, &pnyx.AudioMode{
SampleRate: pnyx.SAMPLE_RATE_24KHZ,
})
if err != nil {
panic(err)
}

@ -2,7 +2,6 @@ package pnyx
import (
"encoding"
"encoding/binary"
"fmt"
"github.com/google/uuid"
@ -10,13 +9,17 @@ import (
type PacketType uint8
const (
PACKET_CHANNEL_COMMAND PacketType = iota
PACKET_SERVER_COMMAND PacketType = iota
PACKET_CHANNEL_COMMAND
PACKET_CHANNEL_DATA
PACKET_CHANNEL_PEER
CHANNEL_HEADER_LEN int = 5
CHANNEL_HEADER_LEN int = 2
CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH
CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH
SERVER_COMMAND_ADD_CHANNEL byte = iota
SERVER_COMMAND_DEL_CHANNEL
)
type Payload interface {
@ -43,6 +46,8 @@ func ParsePacket(data []byte) (Payload, error) {
}
switch PacketType(data[0]) {
case PACKET_SERVER_COMMAND:
return ParseServerCommandPacket(data[1:])
case PACKET_CHANNEL_DATA:
return ParseChannelDataPacket(data[1:])
case PACKET_CHANNEL_COMMAND:
@ -54,24 +59,62 @@ func ParsePacket(data []byte) (Payload, error) {
}
}
type ServerCommandPacket struct {
ReqID uuid.UUID
Command byte
Data []byte
}
func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) {
p := make([]byte, 17 + len(packet.Data))
copy(p, packet.ReqID[:])
p[16] = packet.Command
copy(p[17:], packet.Data)
return p, nil
}
func NewServerCommandPacket(command byte, data []byte) (*Packet, uuid.UUID) {
req_id := uuid.New()
return &Packet{
Type: PACKET_SERVER_COMMAND,
Payload: ServerCommandPacket{
ReqID: req_id,
Command: command,
Data: data,
},
}, req_id
}
func ParseServerCommandPacket(data []byte) (ServerCommandPacket, error) {
if len(data) < 17 {
return ServerCommandPacket{}, fmt.Errorf("Not enough data to decode ServerCommandPacket: %d/%d", len(data), 17)
}
return ServerCommandPacket{
ReqID: uuid.UUID(data[0:16]),
Command: data[16],
Data: data[17:],
}, nil
}
type ChannelHeader struct {
Channel ChannelID
Mode ModeID
}
func(packet ChannelHeader) MarshalBinary() ([]byte, error) {
p := binary.BigEndian.AppendUint32(nil, uint32(packet.Channel))
return append(p, byte(packet.Mode)), nil
return []byte{byte(packet.Channel), byte(packet.Mode)}, nil
}
func ParseChannelHeader(data []byte) (ChannelHeader, error) {
if len(data) < 5 {
if len(data) < 2 {
return ChannelHeader{}, fmt.Errorf("Not enough bytes to parse ChannelPacket(%d/%d)", len(data), 6)
}
return ChannelHeader{
Channel: ChannelID(binary.BigEndian.Uint32(data)),
Mode: ModeID(data[4]),
Channel: ChannelID(data[0]),
Mode: ModeID(data[1]),
}, nil
}
@ -82,8 +125,7 @@ type ChannelCommandPacket struct {
Data []byte
}
func NewChannelCommandPacket(channel ChannelID, mode ModeID, command byte, data []byte) (*Packet, uuid.UUID) {
request_id := uuid.New()
func NewChannelCommandPacket(request_id uuid.UUID, channel ChannelID, mode ModeID, command byte, data []byte) *Packet {
return &Packet{
Type: PACKET_CHANNEL_COMMAND,
Payload: ChannelCommandPacket{
@ -95,7 +137,7 @@ func NewChannelCommandPacket(channel ChannelID, mode ModeID, command byte, data
ReqID: request_id,
Data: data,
},
}, request_id
}
}
func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) {
@ -104,9 +146,12 @@ func(packet ChannelCommandPacket) MarshalBinary() ([]byte, error) {
return nil, err
}
data := append(header, packet.Command)
data = append(data, packet.ReqID[:]...)
return append(data, packet.Data...), nil
data := make([]byte, len(header) + len(packet.Data) + REQ_ID_LENGTH + COMMAND_LENGTH)
copy(data, header)
data[CHANNEL_HEADER_LEN] = packet.Command
copy(data[CHANNEL_HEADER_LEN + COMMAND_LENGTH:], packet.ReqID[:])
copy(data[CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH:], packet.Data)
return data, nil
}
func ParseChannelCommandPacket(data []byte) (ChannelCommandPacket, error) {

@ -64,6 +64,7 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) {
modes: map[reflect.Type]ModeID{
reflect.TypeFor[*RawMode](): MODE_RAW,
reflect.TypeFor[*AudioMode](): MODE_AUDIO,
},
sessions: map[SessionID]*ServerSession{},
@ -195,6 +196,15 @@ func handle_session_incoming(session *ServerSession, server *Server) {
}
switch packet := packet.(type) {
case ServerCommandPacket:
switch packet.Command {
case SERVER_COMMAND_ADD_CHANNEL:
server.Log("Got add_channel with %x", packet.Data)
case SERVER_COMMAND_DEL_CHANNEL:
server.Log("Got del_channel with %x", packet.Data)
default:
server.Log("Unknown server command %x", packet.Command)
}
case ChannelDataPacket:
var result []SendPacket = nil
@ -218,7 +228,7 @@ func handle_session_incoming(session *ServerSession, server *Server) {
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
result, err = channel.Command(&session.Session, packet)
}
server.channels_lock.RUnlock()
@ -314,6 +324,7 @@ func(server *Server) handle_session_close(session_close []byte, from *net.UDPAdd
return nil
}
// TODO: handle packets without creating so many objects(and preventing an extra copy)
func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
session_id := SessionID(data[:SESSION_ID_LENGTH])
session, exists := server.sessions[session_id]