Moved state around

live
noah metz 2024-04-16 15:06:53 -06:00
parent 3ce9c08dff
commit c3f38ef089
5 changed files with 249 additions and 263 deletions

@ -2,6 +2,8 @@ package pnyx
import ( import (
"fmt" "fmt"
"slices"
"sync/atomic"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -9,11 +11,14 @@ import (
type ChannelID byte type ChannelID byte
const ( const (
MODE_RAW ModeID = iota MODE_CHANNEL ModeID = iota
MODE_RAW
MODE_AUDIO MODE_AUDIO
AUDIO_SET_SAMPLE_RATE = iota AUDIO_SET_SAMPLE_RATE = iota
AUDIO_GET_SAMPLE_RATE AUDIO_GET_SAMPLE_RATE
CHANNEL_COMMAND_BUFFER_SIZE int = 2048
) )
type ModeID uint8 type ModeID uint8
@ -21,40 +26,78 @@ type CommandID uint8
type Permission string type Permission string
type Channel struct { type Channel struct {
id ChannelID Commands chan SessionChannelCommand
name string Modes map[ModeID]*atomic.Value
modes map[ModeID]Mode Members atomic.Value
members []*ServerSession
} }
func(channel *Channel) Data(session *ServerSession, mode ModeID, data []byte) { func(channel *Channel) update_state() {
m, has_mode := channel.modes[mode] for true {
if has_mode { incoming := <-channel.Commands
m.Data(session, channel.id, channel.members, data) if incoming.Packet == nil {
break
}
command := incoming.Packet
if command.Mode == MODE_CHANNEL {
switch command.Command {
case CHANNEL_COMMAND_JOIN:
members := channel.Members.Load().([]*ServerSession)
if slices.Contains(members, incoming.Session) == false {
new_members := make([]*ServerSession, len(members) + 1)
copy(new_members, members)
new_members[len(members)] = incoming.Session
channel.Members.Store(new_members)
fmt.Printf("New Members: %+v\n", channel.Members.Load())
}
case CHANNEL_COMMAND_LEAVE:
members := channel.Members.Load().([]*ServerSession)
idx := slices.Index(members, incoming.Session)
if idx != -1 {
new_members := make([]*ServerSession, len(members) - 1)
copy(new_members, members[:idx])
copy(new_members[idx:], members[idx+1:])
channel.Members.Store(new_members)
fmt.Printf("New Members: %+v\n", channel.Members.Load())
}
}
} else {
mode, has_mode := channel.Modes[command.Mode]
if has_mode {
members := channel.Members.Load().([]*ServerSession)
mode_val := mode.Load().(Mode)
new_mode := mode_val.Command(incoming.Session, command.Command, command.ReqID, command.Channel, members, command.Data)
mode.CompareAndSwap(mode_val, new_mode)
}
}
} }
} }
func(channel *Channel) Command(session *ServerSession, command byte, request_id uuid.UUID, mode_id ModeID, data []byte) error { func NewChannel(modes map[ModeID]Mode) (*Channel, error) {
mode, has_mode := channel.modes[mode_id] initial_modes := map[ModeID]*atomic.Value{}
if has_mode == false { for mode_id, mode := range(modes) {
return fmt.Errorf("Channel has no mode 0x%02x", mode) if mode_id == MODE_CHANNEL {
} else { return nil, fmt.Errorf("Cannot create a channel with MODE_CHANNEL(0x%02x) mode", MODE_CHANNEL)
return mode.Command(session, command, request_id, channel.id, channel.members, data) }
initial_modes[mode_id] = new(atomic.Value)
initial_modes[mode_id].Store(mode)
} }
}
channel := &Channel{
Commands: make(chan SessionChannelCommand, CHANNEL_COMMAND_BUFFER_SIZE),
Modes: initial_modes,
}
channel.Members.Store([]*ServerSession{})
func(channel *Channel) Join(client PeerID, session SessionID) { go channel.update_state()
}
func(channel *Channel) Leave(client PeerID, session SessionID) { return channel, nil
} }
type Mode interface { type Mode interface {
Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) error Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode
Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte)
Join(client PeerID, session SessionID)
Leave(client PeerID, session SessionID)
} }
func multiplex_without_sender(origin SessionID, packet *Packet, sessions []*ServerSession) { func multiplex_without_sender(origin SessionID, packet *Packet, sessions []*ServerSession) {
@ -76,20 +119,15 @@ func multiplex(packet *Packet, sessions []*ServerSession) {
type RawMode struct { type RawMode struct {
} }
func(mode *RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) error { func(mode RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode {
return fmt.Errorf("unknown raw mode command 0x%02x", command) return mode
} }
func(mode *RawMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) { func(mode RawMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) {
new_packet := NewChannelPeerPacket(session.Peer, channel_id, MODE_RAW, data) new_packet := NewPeerPacket(session.Peer, channel_id, MODE_RAW, data)
multiplex_without_sender(session.ID, new_packet, members) multiplex_without_sender(session.ID, new_packet, members)
} }
func(mode *RawMode) Join(client PeerID, session SessionID) {
}
func(mode *RawMode) Leave(client PeerID, session SessionID) {
}
type SampleRate byte type SampleRate byte
const ( const (
SAMPLE_RATE_UNSET SampleRate = 0xFF SAMPLE_RATE_UNSET SampleRate = 0xFF
@ -101,7 +139,7 @@ type AudioMode struct {
SampleRate SampleRate SampleRate SampleRate
} }
func(mode *AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) error { func(mode AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode {
switch command { switch command {
case AUDIO_SET_SAMPLE_RATE: case AUDIO_SET_SAMPLE_RATE:
if len(data) == 1 { if len(data) == 1 {
@ -112,27 +150,16 @@ func(mode *AudioMode) Command(session *ServerSession, command byte, request_id u
mode.SampleRate = SampleRate(data[0]) mode.SampleRate = SampleRate(data[0])
update_packet := NewChannelCommandPacket(request_id, channel_id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, data) update_packet := NewChannelCommandPacket(request_id, channel_id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, data)
multiplex(update_packet, members) multiplex(update_packet, members)
return nil
default:
return fmt.Errorf("Invalid sample rate: %x", data[0])
} }
} else {
return fmt.Errorf("Invalid AUDIO_SET_SAMPLE_RATE payload: %x", data)
} }
case AUDIO_GET_SAMPLE_RATE: case AUDIO_GET_SAMPLE_RATE:
session.OutgoingPackets <- NewChannelCommandPacket(request_id, channel_id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, []byte{byte(mode.SampleRate)}) session.OutgoingPackets <- NewChannelCommandPacket(request_id, channel_id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, []byte{byte(mode.SampleRate)})
return nil
default:
return fmt.Errorf("unknown audio mode command 0x%02x", command)
} }
return mode
} }
func(mode *AudioMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) { func(mode AudioMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) {
new_packet := NewChannelPeerPacket(session.Peer, channel_id, MODE_AUDIO, data) new_packet := NewPeerPacket(session.Peer, channel_id, MODE_AUDIO, data)
multiplex_without_sender(session.ID, new_packet, members) multiplex_without_sender(session.ID, new_packet, members)
} }
func(mode *AudioMode) Join(client PeerID, session SessionID) {
}
func(mode *AudioMode) Leave(client PeerID, session SessionID) {
}

@ -4,14 +4,12 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os" "os"
"slices"
"time" "time"
"git.metznet.ca/MetzNet/pnyx" "git.metznet.ca/MetzNet/pnyx"
"github.com/gen2brain/malgo" "github.com/gen2brain/malgo"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hraban/opus" "github.com/hraban/opus"
"seehuhn.de/go/ncurses"
) )
var decoders = map[pnyx.PeerID]chan[]byte{} var decoders = map[pnyx.PeerID]chan[]byte{}
@ -22,6 +20,7 @@ var speaker = make(chan []int16, 0)
var audio_data = make(chan []int16, 0) var audio_data = make(chan []int16, 0)
func set_sample_rate(new_sample_rate int) error { func set_sample_rate(new_sample_rate int) error {
fmt.Printf("Setting sample rate to %d\n", new_sample_rate)
sample_rate = new_sample_rate sample_rate = new_sample_rate
var err error var err error
@ -258,15 +257,18 @@ func main() {
} }
switch packet := packet.(type) { switch packet := packet.(type) {
case pnyx.PingPacket:
fmt.Printf("Ping Packet From Server: %+v\n", packet)
case pnyx.ChannelCommandPacket: case pnyx.ChannelCommandPacket:
fmt.Printf("Channel Command packet: %+v\n", packet)
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
if packet.Mode == pnyx.MODE_AUDIO { if packet.Mode == pnyx.MODE_AUDIO {
if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE { if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE {
var new_sample_rate int var new_sample_rate int
switch packet.Data[0] { switch pnyx.SampleRate(packet.Data[0]) {
case byte(pnyx.SAMPLE_RATE_24KHZ): case pnyx.SAMPLE_RATE_24KHZ:
new_sample_rate = 24000 new_sample_rate = 24000
case byte(pnyx.SAMPLE_RATE_48KHZ): case pnyx.SAMPLE_RATE_48KHZ:
new_sample_rate = 48000 new_sample_rate = 48000
default: default:
continue continue
@ -278,7 +280,9 @@ func main() {
} }
} }
} }
case pnyx.ChannelPeerPacket: case pnyx.CommandPacket:
fmt.Printf("Command packet: %+v\n", packet)
case pnyx.PeerPacket:
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
decode_chan, exists := decoders[packet.Peer] decode_chan, exists := decoders[packet.Peer]
if exists == false { if exists == false {
@ -299,7 +303,7 @@ func main() {
} }
}() }()
join_packet := pnyx.NewServerCommandPacket(uuid.New(), pnyx.SERVER_COMMAND_JOIN_CHANNEL, []byte{0x00}) join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_COMMAND_JOIN, nil)
err = client.Send(join_packet) err = client.Send(join_packet)
if err != nil { if err != nil {
panic(err) panic(err)
@ -311,17 +315,15 @@ func main() {
panic(err) panic(err)
} }
go func () { for true {
for true { data := <- mic
data := <- mic err = client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data))
err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) if err != nil {
if err != nil { panic(err)
panic(err)
}
} }
}() }
window := ncurses.Init() /*window := ncurses.Init()
defer ncurses.EndWin() defer ncurses.EndWin()
ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed) ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed)
@ -342,5 +344,5 @@ func main() {
for i, peer_id := range(peers) { for i, peer_id := range(peers) {
window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id)) window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id))
} }
} }*/
} }

@ -13,19 +13,24 @@ func main() {
os_sigs := make(chan os.Signal, 1) os_sigs := make(chan os.Signal, 1)
signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT)
server, err := pnyx.NewServer(nil) channel_0, err := pnyx.NewChannel(map[pnyx.ModeID]pnyx.Mode{
pnyx.MODE_RAW: pnyx.RawMode{},
pnyx.MODE_AUDIO: pnyx.AudioMode{
SampleRate: pnyx.SAMPLE_RATE_48KHZ,
},
})
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = server.Start(os.Args[1]) server, err := pnyx.NewServer(nil, map[pnyx.ChannelID]*pnyx.Channel{
0: channel_0,
})
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}, &pnyx.AudioMode{ err = server.Start(os.Args[1])
SampleRate: pnyx.SAMPLE_RATE_24KHZ,
})
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -11,16 +11,18 @@ type PacketType uint8
const ( const (
PACKET_SERVER_COMMAND PacketType = iota PACKET_SERVER_COMMAND PacketType = iota
PACKET_CHANNEL_COMMAND PACKET_CHANNEL_COMMAND
PACKET_CHANNEL_DATA PACKET_DATA
PACKET_CHANNEL_PEER PACKET_PEER
PACKET_PING
CHANNEL_HEADER_LEN int = 2 CHANNEL_HEADER_LEN int = 2
CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH
CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH
SERVER_COMMAND_JOIN_CHANNEL byte = iota CHANNEL_COMMAND_JOIN byte = iota
SERVER_COMMAND_LEAVE_CHANNEL CHANNEL_COMMAND_LEAVE
SERVER_COMMAND_ADD_CHANNEL
SERVER_COMMAND_ADD_CHANNEL byte = iota
SERVER_COMMAND_DEL_CHANNEL SERVER_COMMAND_DEL_CHANNEL
) )
@ -49,25 +51,49 @@ func ParsePacket(data []byte) (Payload, error) {
switch PacketType(data[0]) { switch PacketType(data[0]) {
case PACKET_SERVER_COMMAND: case PACKET_SERVER_COMMAND:
return ParseServerCommandPacket(data[1:]) return ParseCommandPacket(data[1:])
case PACKET_CHANNEL_DATA:
return ParseChannelDataPacket(data[1:])
case PACKET_CHANNEL_COMMAND: case PACKET_CHANNEL_COMMAND:
return ParseChannelCommandPacket(data[1:]) return ParseChannelCommandPacket(data[1:])
case PACKET_CHANNEL_PEER: case PACKET_DATA:
return ParseChannelPeerPacket(data[1:]) return ParseDataPacket(data[1:])
case PACKET_PEER:
return ParsePeerPacket(data[1:])
case PACKET_PING:
return ParsePingPacket(data[1:])
default: default:
return nil, fmt.Errorf("Don't know how to parse packet type 0x%02x", data[0]) return nil, fmt.Errorf("Don't know how to parse packet type 0x%02x", data[0])
} }
} }
type ServerCommandPacket struct { type PingPacket struct {
}
func(packet PingPacket) MarshalBinary() ([]byte, error) {
return []byte{}, nil
}
func NewPingPacket() *Packet {
return &Packet{
Type: PACKET_PING,
Payload: PingPacket{},
}
}
func ParsePingPacket(data []byte) (PingPacket, error) {
if len(data) != 0 {
return PingPacket{}, fmt.Errorf("Wrong number of bytes to parse PingPacket %d/0", len(data))
}
return PingPacket{}, nil
}
type CommandPacket struct {
ReqID uuid.UUID ReqID uuid.UUID
Command byte Command byte
Data []byte Data []byte
} }
func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) { func (packet CommandPacket) MarshalBinary() ([]byte, error) {
p := make([]byte, 17 + len(packet.Data)) p := make([]byte, 17 + len(packet.Data))
copy(p, packet.ReqID[:]) copy(p, packet.ReqID[:])
p[16] = packet.Command p[16] = packet.Command
@ -76,10 +102,10 @@ func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) {
return p, nil return p, nil
} }
func NewServerCommandPacket(request_id uuid.UUID, command byte, data []byte) *Packet { func NewCommandPacket(request_id uuid.UUID, command byte, data []byte) *Packet {
return &Packet{ return &Packet{
Type: PACKET_SERVER_COMMAND, Type: PACKET_SERVER_COMMAND,
Payload: ServerCommandPacket{ Payload: CommandPacket{
ReqID: request_id, ReqID: request_id,
Command: command, Command: command,
Data: data, Data: data,
@ -87,12 +113,12 @@ func NewServerCommandPacket(request_id uuid.UUID, command byte, data []byte) *Pa
} }
} }
func ParseServerCommandPacket(data []byte) (ServerCommandPacket, error) { func ParseCommandPacket(data []byte) (CommandPacket, error) {
if len(data) < 17 { if len(data) < 17 {
return ServerCommandPacket{}, fmt.Errorf("Not enough data to decode ServerCommandPacket: %d/%d", len(data), 17) return CommandPacket{}, fmt.Errorf("Not enough data to decode CommandPacket: %d/%d", len(data), 17)
} }
return ServerCommandPacket{ return CommandPacket{
ReqID: uuid.UUID(data[0:16]), ReqID: uuid.UUID(data[0:16]),
Command: data[16], Command: data[16],
Data: data[17:], Data: data[17:],
@ -175,16 +201,16 @@ func ParseChannelCommandPacket(data []byte) (ChannelCommandPacket, error) {
}, nil }, nil
} }
type ChannelPeerPacket struct { type PeerPacket struct {
ChannelHeader ChannelHeader
Peer PeerID Peer PeerID
Data []byte Data []byte
} }
func NewChannelPeerPacket(peer PeerID, channel ChannelID, mode ModeID, data []byte) *Packet { func NewPeerPacket(peer PeerID, channel ChannelID, mode ModeID, data []byte) *Packet {
return &Packet{ return &Packet{
Type: PACKET_CHANNEL_PEER, Type: PACKET_PEER,
Payload: ChannelPeerPacket{ Payload: PeerPacket{
ChannelHeader: ChannelHeader{ ChannelHeader: ChannelHeader{
Channel: channel, Channel: channel,
Mode: mode, Mode: mode,
@ -195,7 +221,7 @@ func NewChannelPeerPacket(peer PeerID, channel ChannelID, mode ModeID, data []by
} }
} }
func(packet ChannelPeerPacket) MarshalBinary() ([]byte, error) { func(packet PeerPacket) MarshalBinary() ([]byte, error) {
header, err := packet.ChannelHeader.MarshalBinary() header, err := packet.ChannelHeader.MarshalBinary()
if err != nil { if err != nil {
return nil, err return nil, err
@ -205,17 +231,17 @@ func(packet ChannelPeerPacket) MarshalBinary() ([]byte, error) {
return append(data, packet.Data...), nil return append(data, packet.Data...), nil
} }
func ParseChannelPeerPacket(data []byte) (ChannelPeerPacket, error) { func ParsePeerPacket(data []byte) (PeerPacket, error) {
if len(data) < CHANNEL_PEER_LEN { if len(data) < CHANNEL_PEER_LEN {
return ChannelPeerPacket{}, fmt.Errorf("Not enough bytes to parse ServerChannelPacket: %d/%d", len(data), PEER_ID_LENGTH) return PeerPacket{}, fmt.Errorf("Not enough bytes to parse ServerChannelPacket: %d/%d", len(data), PEER_ID_LENGTH)
} }
header, err := ParseChannelHeader(data) header, err := ParseChannelHeader(data)
if err != nil { if err != nil {
return ChannelPeerPacket{}, err return PeerPacket{}, err
} }
return ChannelPeerPacket{ return PeerPacket{
ChannelHeader: header, ChannelHeader: header,
Peer: PeerID(data[CHANNEL_HEADER_LEN:]), Peer: PeerID(data[CHANNEL_HEADER_LEN:]),
Data: data[CHANNEL_PEER_LEN:], Data: data[CHANNEL_PEER_LEN:],
@ -223,15 +249,15 @@ func ParseChannelPeerPacket(data []byte) (ChannelPeerPacket, error) {
} }
type ChannelDataPacket struct { type DataPacket struct {
ChannelHeader ChannelHeader
Data []byte Data []byte
} }
func NewChannelDataPacket(channel ChannelID, mode ModeID, data []byte) *Packet { func NewDataPacket(channel ChannelID, mode ModeID, data []byte) *Packet {
return &Packet{ return &Packet{
Type: PACKET_CHANNEL_DATA, Type: PACKET_DATA,
Payload: ChannelDataPacket{ Payload: DataPacket{
ChannelHeader: ChannelHeader{ ChannelHeader: ChannelHeader{
Channel: channel, Channel: channel,
Mode: mode, Mode: mode,
@ -241,7 +267,7 @@ func NewChannelDataPacket(channel ChannelID, mode ModeID, data []byte) *Packet {
} }
} }
func(packet ChannelDataPacket) MarshalBinary() ([]byte, error) { func(packet DataPacket) MarshalBinary() ([]byte, error) {
header, err := packet.ChannelHeader.MarshalBinary() header, err := packet.ChannelHeader.MarshalBinary()
if err != nil { if err != nil {
return nil, err return nil, err
@ -250,17 +276,17 @@ func(packet ChannelDataPacket) MarshalBinary() ([]byte, error) {
return append(header, packet.Data...), nil return append(header, packet.Data...), nil
} }
func ParseChannelDataPacket(data []byte) (ChannelDataPacket, error) { func ParseDataPacket(data []byte) (DataPacket, error) {
if len(data) < CHANNEL_HEADER_LEN { if len(data) < CHANNEL_HEADER_LEN {
return ChannelDataPacket{}, fmt.Errorf("Not enough data to parse ChannelDataPacket") return DataPacket{}, fmt.Errorf("Not enough data to parse DataPacket")
} }
header, err := ParseChannelHeader(data) header, err := ParseChannelHeader(data)
if err != nil { if err != nil {
return ChannelDataPacket{}, nil return DataPacket{}, nil
} }
return ChannelDataPacket{ return DataPacket{
ChannelHeader: header, ChannelHeader: header,
Data: data[CHANNEL_HEADER_LEN:], Data: data[CHANNEL_HEADER_LEN:],
}, nil }, nil

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"reflect"
"slices" "slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -17,6 +16,7 @@ import (
const ( const (
SERVER_UDP_BUFFER_SIZE = 2048 SERVER_UDP_BUFFER_SIZE = 2048
SERVER_SEND_BUFFER_SIZE = 2048 SERVER_SEND_BUFFER_SIZE = 2048
SERVER_COMMAND_BUFFER_SIZE = 2048
) )
type RoleID uint32 type RoleID uint32
@ -25,7 +25,7 @@ type ServerSession struct {
Session Session
LastSeen time.Time LastSeen time.Time
IncomingPackets chan[]byte IncomingPackets chan[]byte
OutgoingPackets chan *Packet OutgoingPackets chan Payload
Channels []ChannelID Channels []ChannelID
} }
@ -34,19 +34,17 @@ type Server struct {
active atomic.Bool active atomic.Bool
connection *net.UDPConn connection *net.UDPConn
stopped chan error stopped chan error
commands chan Payload
modes map[reflect.Type]ModeID
sessions_lock sync.Mutex sessions_lock sync.Mutex
sessions map[SessionID]*ServerSession sessions map[SessionID]*ServerSession
channels_lock sync.RWMutex channels atomic.Value
channels map[ChannelID]*Channel
peers map[PeerID][]RoleID peers map[PeerID][]RoleID
} }
func NewServer(key ed25519.PrivateKey) (*Server, error) { func NewServer(key ed25519.PrivateKey, channels map[ChannelID]*Channel) (*Server, error) {
if key == nil { if key == nil {
var err error var err error
_, key, err = ed25519.GenerateKey(rand.Reader) _, key, err = ed25519.GenerateKey(rand.Reader)
@ -59,67 +57,18 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) {
connection: nil, connection: nil,
active: atomic.Bool{}, active: atomic.Bool{},
stopped: make(chan error, 0), stopped: make(chan error, 0),
commands: make(chan Payload, SERVER_COMMAND_BUFFER_SIZE),
modes: map[reflect.Type]ModeID{
reflect.TypeFor[*RawMode](): MODE_RAW,
reflect.TypeFor[*AudioMode](): MODE_AUDIO,
},
sessions: map[SessionID]*ServerSession{}, sessions: map[SessionID]*ServerSession{},
channels: map[ChannelID]*Channel{}, channels: atomic.Value{},
peers: map[PeerID][]RoleID{}, peers: map[PeerID][]RoleID{},
} }
server.channels.Store(channels)
server.active.Store(false) server.active.Store(false)
return server, nil return server, nil
} }
func(server *Server) RemoveChannel(id ChannelID) error {
server.channels_lock.Lock()
defer server.channels_lock.Unlock()
_, exists := server.channels[id]
if exists == false {
return fmt.Errorf("Channel %x does not exist", id)
}
delete(server.channels, id)
return nil
}
func(server *Server) AddChannel(id ChannelID, modes ...Mode) error {
server.channels_lock.Lock()
defer server.channels_lock.Unlock()
_, exists := server.channels[id]
if exists {
return fmt.Errorf("Channel with ID %x already exists", id)
}
mode_map := map[ModeID]Mode{}
for _, mode := range(modes) {
reflect_type := reflect.TypeOf(mode)
mode_id, known := server.modes[reflect_type]
if known == false {
return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type)
}
_, exists := mode_map[mode_id]
if exists {
return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id)
}
mode_map[mode_id] = mode
}
server.channels[id] = &Channel{
id: id,
modes: mode_map,
members: []*ServerSession{},
}
return nil
}
func(server *Server) Log(format string, fields ...interface{}) { func(server *Server) Log(format string, fields ...interface{}) {
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n") fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
} }
@ -173,93 +122,79 @@ func handle_session_outgoing(session *ServerSession, server *Server) {
server.Log("Stopping session outgoing goroutine %s", session.ID) server.Log("Stopping session outgoing goroutine %s", session.ID)
} }
const SESSION_PING_TIME = time.Minute
const SESSION_TIMEOUT = 2 * time.Minute
type SessionChannelCommand struct {
Session *ServerSession
Packet *ChannelCommandPacket
}
func handle_session_incoming(session *ServerSession, server *Server) { func handle_session_incoming(session *ServerSession, server *Server) {
server.Log("Starting session incoming goroutine %s", session.ID) server.Log("Starting session incoming goroutine %s", session.ID)
for true { ping_timer := time.After(SESSION_PING_TIME)
encrypted := <- session.IncomingPackets running := true
if encrypted == nil { for running {
break select {
} case <- ping_timer:
if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != 1 {
server.sessions_lock.Lock()
server.close_session(session)
server.sessions_lock.Unlock()
running = false
} else {
session.OutgoingPackets <- NewPingPacket()
ping_timer = time.After(SESSION_PING_TIME)
}
case encrypted := <- session.IncomingPackets:
if encrypted == nil {
running = false
continue
}
data, err := ParseSessionData(&session.Session, encrypted) data, err := ParseSessionData(&session.Session, encrypted)
if err != nil { if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue continue
} }
packet, err := ParsePacket(data) packet, err := ParsePacket(data)
if err != nil { if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
} }
switch packet := packet.(type) { switch packet := packet.(type) {
case ServerCommandPacket: case CommandPacket:
switch packet.Command { server.commands<-packet
case SERVER_COMMAND_JOIN_CHANNEL: case ChannelCommandPacket:
server.Log("Got join_channel for %x with %x", session.ID, packet.Data) channels := server.channels.Load().(map[ChannelID]*Channel)
if len(packet.Data) == 1 { channel, exists := channels[packet.Channel]
server.channels_lock.Lock() if exists == true {
channel, exists := server.channels[ChannelID(packet.Data[0])] channel.Commands<-SessionChannelCommand{
if exists == true { Session: session,
if slices.Contains(channel.members, session) == false { Packet: &packet,
channel.members = append(channel.members, session)
channel.Join(session.Peer, session.ID)
// TODO: Send message to clients to notify of join
}
} }
server.channels_lock.Unlock() } else {
server.Log("Command for unknown channel %d", packet.Channel)
} }
case SERVER_COMMAND_LEAVE_CHANNEL: case DataPacket:
server.Log("Got leave_channel for %x with %x", session.ID, packet.Data) channels := server.channels.Load().(map[ChannelID]*Channel)
if len(packet.Data) == 1 { channel, exists := channels[packet.Channel]
server.channels_lock.Lock() if exists == true {
channel, exists := server.channels[ChannelID(packet.Data[0])] members := channel.Members.Load().([]*ServerSession)
if exists == true { if slices.Contains(members, session) {
idx := slices.Index(channel.members, session) mode, has_mode := channel.Modes[packet.Mode]
if idx != -1 { if has_mode {
channel.members = slices.Delete(channel.members, idx, idx+1) mode.Load().(Mode).Data(session, packet.Channel, members, data)
channel.Leave(session.Peer, session.ID)
// TODO: Send message to clients to notify of join
} }
} }
server.channels_lock.Unlock() } else {
} server.Log("Data for unknown channel %d", packet.Channel)
case SERVER_COMMAND_ADD_CHANNEL:
server.Log("Got add_channel with %x", packet.Data)
case SERVER_COMMAND_DEL_CHANNEL:
server.Log("Got del_channel with %x", packet.Data)
default:
server.Log("Unknown server command %x", packet.Command)
}
case ChannelDataPacket:
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
if slices.Contains(channel.members, session) {
channel.Data(session, packet.Mode, packet.Data)
} }
} else {
server.Log("Packet for unknown channel %d", packet.Channel)
}
server.channels_lock.RUnlock()
default:
case ChannelCommandPacket: server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
err = channel.Command(session, packet.Command, packet.ReqID, packet.Mode, packet.Data)
if err != nil {
server.Log("Error processing %+v - %s", packet, err)
}
} else {
server.Log("Packet for unknown channel %d", packet.Channel)
} }
server.channels_lock.RUnlock()
default:
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
} }
} }
@ -277,7 +212,7 @@ func(server *Server) handle_session_open(client_session_open []byte, from *net.U
Session: session, Session: session,
LastSeen: time.Now(), LastSeen: time.Now(),
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), OutgoingPackets: make(chan Payload, SESSION_BUFFER_SIZE),
} }
server.sessions_lock.Unlock() server.sessions_lock.Unlock()
@ -329,7 +264,7 @@ func(server *Server) handle_session_close(session_close []byte, from *net.UDPAdd
} }
server.sessions_lock.Lock() server.sessions_lock.Lock()
server.close_session(session) server.close_session(session)
server.sessions_lock.Unlock() server.sessions_lock.Unlock()
_, err = server.connection.WriteToUDP(session_closed, session.remote) _, err = server.connection.WriteToUDP(session_closed, session.remote)
@ -406,11 +341,17 @@ func(server *Server) listen_udp() {
} }
} }
channels := server.channels.Load().(map[ChannelID]*Channel)
for _, channel := range(channels) {
close(channel.Commands)
}
server.sessions_lock.Lock() server.sessions_lock.Lock()
sessions := make([]*ServerSession, 0, len(server.sessions)) sessions := make([]*ServerSession, 0, len(server.sessions))
for _, session := range(server.sessions) { for _, session := range(server.sessions) {
sessions = append(sessions, session) sessions = append(sessions, session)
} }
for _, session := range(sessions) { for _, session := range(sessions) {
server.close_session(session) server.close_session(session)
} }
@ -424,31 +365,16 @@ func(server *Server) close_session(session *ServerSession) {
close(session.IncomingPackets) close(session.IncomingPackets)
close(session.OutgoingPackets) close(session.OutgoingPackets)
delete(server.sessions, session.ID) delete(server.sessions, session.ID)
}
const SESSION_TIMEOUT = time.Minute * 5 session_closed := NewSessionTimed(SESSION_CLOSED, server.key, &session.Session, time.Now())
const SESSION_TIMEOUT_CHECK = time.Minute server.connection.WriteToUDP(session_closed, session.remote)
}
func(server *Server) cleanup_sessions() { func(server *Server) update_state() {
for server.active.Load() { for server.active.Load() {
select { select {
case <-time.After(SESSION_TIMEOUT_CHECK): case command := <-server.commands:
server.sessions_lock.Lock() server.Log("Incoming server command %+v", command)
now := time.Now()
stale_sessions := make([]*ServerSession, 0, len(server.sessions))
for _, session := range(server.sessions) {
if now.Sub(session.LastSeen) >= SESSION_TIMEOUT {
server.Log("Closing stale session %s for %s", session.ID, session.Peer)
stale_sessions = append(stale_sessions, session)
}
}
for _, session := range(stale_sessions) {
server.close_session(session)
}
server.sessions_lock.Unlock()
// TODO: add a way for this to be shutdown instantly on server shutdown
} }
} }
} }
@ -472,7 +398,7 @@ func(server *Server) Start(listen string) error {
} }
go server.listen_udp() go server.listen_udp()
go server.cleanup_sessions() go server.update_state()
return nil return nil
} }