diff --git a/channel.go b/channel.go index 3c82e0b..9a522d1 100644 --- a/channel.go +++ b/channel.go @@ -2,6 +2,8 @@ package pnyx import ( "fmt" + "slices" + "sync/atomic" "github.com/google/uuid" ) @@ -9,11 +11,14 @@ import ( type ChannelID byte const ( - MODE_RAW ModeID = iota + MODE_CHANNEL ModeID = iota + MODE_RAW MODE_AUDIO AUDIO_SET_SAMPLE_RATE = iota AUDIO_GET_SAMPLE_RATE + + CHANNEL_COMMAND_BUFFER_SIZE int = 2048 ) type ModeID uint8 @@ -21,40 +26,78 @@ type CommandID uint8 type Permission string type Channel struct { - id ChannelID - name string - modes map[ModeID]Mode - members []*ServerSession + Commands chan SessionChannelCommand + Modes map[ModeID]*atomic.Value + Members atomic.Value } -func(channel *Channel) Data(session *ServerSession, mode ModeID, data []byte) { - m, has_mode := channel.modes[mode] - if has_mode { - m.Data(session, channel.id, channel.members, data) +func(channel *Channel) update_state() { + for true { + incoming := <-channel.Commands + if incoming.Packet == nil { + break + } + + command := incoming.Packet + + if command.Mode == MODE_CHANNEL { + switch command.Command { + case CHANNEL_COMMAND_JOIN: + members := channel.Members.Load().([]*ServerSession) + if slices.Contains(members, incoming.Session) == false { + new_members := make([]*ServerSession, len(members) + 1) + copy(new_members, members) + new_members[len(members)] = incoming.Session + channel.Members.Store(new_members) + fmt.Printf("New Members: %+v\n", channel.Members.Load()) + } + case CHANNEL_COMMAND_LEAVE: + members := channel.Members.Load().([]*ServerSession) + idx := slices.Index(members, incoming.Session) + if idx != -1 { + new_members := make([]*ServerSession, len(members) - 1) + copy(new_members, members[:idx]) + copy(new_members[idx:], members[idx+1:]) + channel.Members.Store(new_members) + fmt.Printf("New Members: %+v\n", channel.Members.Load()) + } + } + } else { + mode, has_mode := channel.Modes[command.Mode] + if has_mode { + members := channel.Members.Load().([]*ServerSession) + mode_val := mode.Load().(Mode) + new_mode := mode_val.Command(incoming.Session, command.Command, command.ReqID, command.Channel, members, command.Data) + mode.CompareAndSwap(mode_val, new_mode) + } + } } } -func(channel *Channel) Command(session *ServerSession, command byte, request_id uuid.UUID, mode_id ModeID, data []byte) error { - mode, has_mode := channel.modes[mode_id] - if has_mode == false { - return fmt.Errorf("Channel has no mode 0x%02x", mode) - } else { - return mode.Command(session, command, request_id, channel.id, channel.members, data) +func NewChannel(modes map[ModeID]Mode) (*Channel, error) { + initial_modes := map[ModeID]*atomic.Value{} + for mode_id, mode := range(modes) { + if mode_id == MODE_CHANNEL { + return nil, fmt.Errorf("Cannot create a channel with MODE_CHANNEL(0x%02x) mode", MODE_CHANNEL) + } + initial_modes[mode_id] = new(atomic.Value) + initial_modes[mode_id].Store(mode) } -} + + channel := &Channel{ + Commands: make(chan SessionChannelCommand, CHANNEL_COMMAND_BUFFER_SIZE), + Modes: initial_modes, + } + channel.Members.Store([]*ServerSession{}) -func(channel *Channel) Join(client PeerID, session SessionID) { -} + go channel.update_state() -func(channel *Channel) Leave(client PeerID, session SessionID) { + return channel, nil } type Mode interface { - Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) error + Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) - - Join(client PeerID, session SessionID) - Leave(client PeerID, session SessionID) } func multiplex_without_sender(origin SessionID, packet *Packet, sessions []*ServerSession) { @@ -76,20 +119,15 @@ func multiplex(packet *Packet, sessions []*ServerSession) { type RawMode struct { } -func(mode *RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) error { - return fmt.Errorf("unknown raw mode command 0x%02x", command) +func(mode RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode { + return mode } -func(mode *RawMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) { - new_packet := NewChannelPeerPacket(session.Peer, channel_id, MODE_RAW, data) +func(mode RawMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) { + new_packet := NewPeerPacket(session.Peer, channel_id, MODE_RAW, data) multiplex_without_sender(session.ID, new_packet, members) } -func(mode *RawMode) Join(client PeerID, session SessionID) { -} -func(mode *RawMode) Leave(client PeerID, session SessionID) { -} - type SampleRate byte const ( SAMPLE_RATE_UNSET SampleRate = 0xFF @@ -101,7 +139,7 @@ type AudioMode struct { SampleRate SampleRate } -func(mode *AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) error { +func(mode AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode { switch command { case AUDIO_SET_SAMPLE_RATE: if len(data) == 1 { @@ -112,27 +150,16 @@ func(mode *AudioMode) Command(session *ServerSession, command byte, request_id u mode.SampleRate = SampleRate(data[0]) update_packet := NewChannelCommandPacket(request_id, channel_id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, data) multiplex(update_packet, members) - return nil - default: - return fmt.Errorf("Invalid sample rate: %x", data[0]) } - } else { - return fmt.Errorf("Invalid AUDIO_SET_SAMPLE_RATE payload: %x", data) } case AUDIO_GET_SAMPLE_RATE: session.OutgoingPackets <- NewChannelCommandPacket(request_id, channel_id, MODE_AUDIO, AUDIO_SET_SAMPLE_RATE, []byte{byte(mode.SampleRate)}) - return nil - default: - return fmt.Errorf("unknown audio mode command 0x%02x", command) } + + return mode } -func(mode *AudioMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) { - new_packet := NewChannelPeerPacket(session.Peer, channel_id, MODE_AUDIO, data) +func(mode AudioMode) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) { + new_packet := NewPeerPacket(session.Peer, channel_id, MODE_AUDIO, data) multiplex_without_sender(session.ID, new_packet, members) } - -func(mode *AudioMode) Join(client PeerID, session SessionID) { -} -func(mode *AudioMode) Leave(client PeerID, session SessionID) { -} diff --git a/cmd/client/main.go b/cmd/client/main.go index f94a19e..56fbfe4 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -4,14 +4,12 @@ import ( "encoding/binary" "fmt" "os" - "slices" "time" "git.metznet.ca/MetzNet/pnyx" "github.com/gen2brain/malgo" "github.com/google/uuid" "github.com/hraban/opus" - "seehuhn.de/go/ncurses" ) var decoders = map[pnyx.PeerID]chan[]byte{} @@ -22,6 +20,7 @@ var speaker = make(chan []int16, 0) var audio_data = make(chan []int16, 0) func set_sample_rate(new_sample_rate int) error { + fmt.Printf("Setting sample rate to %d\n", new_sample_rate) sample_rate = new_sample_rate var err error @@ -258,15 +257,18 @@ func main() { } switch packet := packet.(type) { + case pnyx.PingPacket: + fmt.Printf("Ping Packet From Server: %+v\n", packet) case pnyx.ChannelCommandPacket: + fmt.Printf("Channel Command packet: %+v\n", packet) if packet.Channel == pnyx.ChannelID(0) { if packet.Mode == pnyx.MODE_AUDIO { if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE { var new_sample_rate int - switch packet.Data[0] { - case byte(pnyx.SAMPLE_RATE_24KHZ): + switch pnyx.SampleRate(packet.Data[0]) { + case pnyx.SAMPLE_RATE_24KHZ: new_sample_rate = 24000 - case byte(pnyx.SAMPLE_RATE_48KHZ): + case pnyx.SAMPLE_RATE_48KHZ: new_sample_rate = 48000 default: continue @@ -278,7 +280,9 @@ func main() { } } } - case pnyx.ChannelPeerPacket: + case pnyx.CommandPacket: + fmt.Printf("Command packet: %+v\n", packet) + case pnyx.PeerPacket: if packet.Channel == pnyx.ChannelID(0) { decode_chan, exists := decoders[packet.Peer] if exists == false { @@ -299,7 +303,7 @@ func main() { } }() - join_packet := pnyx.NewServerCommandPacket(uuid.New(), pnyx.SERVER_COMMAND_JOIN_CHANNEL, []byte{0x00}) + join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_COMMAND_JOIN, nil) err = client.Send(join_packet) if err != nil { panic(err) @@ -311,17 +315,15 @@ func main() { panic(err) } - go func () { - for true { - data := <- mic - err = client.Send(pnyx.NewChannelDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) - if err != nil { - panic(err) - } + for true { + data := <- mic + err = client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) + if err != nil { + panic(err) } - }() + } - window := ncurses.Init() + /*window := ncurses.Init() defer ncurses.EndWin() ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed) @@ -342,5 +344,5 @@ func main() { for i, peer_id := range(peers) { window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id)) } - } + }*/ } diff --git a/cmd/server/main.go b/cmd/server/main.go index 77d63a8..fba9c95 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -13,19 +13,24 @@ func main() { os_sigs := make(chan os.Signal, 1) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT) - server, err := pnyx.NewServer(nil) + channel_0, err := pnyx.NewChannel(map[pnyx.ModeID]pnyx.Mode{ + pnyx.MODE_RAW: pnyx.RawMode{}, + pnyx.MODE_AUDIO: pnyx.AudioMode{ + SampleRate: pnyx.SAMPLE_RATE_48KHZ, + }, + }) if err != nil { panic(err) } - err = server.Start(os.Args[1]) + server, err := pnyx.NewServer(nil, map[pnyx.ChannelID]*pnyx.Channel{ + 0: channel_0, + }) if err != nil { panic(err) } - err = server.AddChannel(pnyx.ChannelID(0), &pnyx.RawMode{}, &pnyx.AudioMode{ - SampleRate: pnyx.SAMPLE_RATE_24KHZ, - }) + err = server.Start(os.Args[1]) if err != nil { panic(err) } diff --git a/packet.go b/packet.go index 5679a92..5297c91 100644 --- a/packet.go +++ b/packet.go @@ -11,16 +11,18 @@ type PacketType uint8 const ( PACKET_SERVER_COMMAND PacketType = iota PACKET_CHANNEL_COMMAND - PACKET_CHANNEL_DATA - PACKET_CHANNEL_PEER + PACKET_DATA + PACKET_PEER + PACKET_PING CHANNEL_HEADER_LEN int = 2 CHANNEL_COMMAND_LEN = CHANNEL_HEADER_LEN + COMMAND_LENGTH + REQ_ID_LENGTH CHANNEL_PEER_LEN = CHANNEL_HEADER_LEN + PEER_ID_LENGTH - SERVER_COMMAND_JOIN_CHANNEL byte = iota - SERVER_COMMAND_LEAVE_CHANNEL - SERVER_COMMAND_ADD_CHANNEL + CHANNEL_COMMAND_JOIN byte = iota + CHANNEL_COMMAND_LEAVE + + SERVER_COMMAND_ADD_CHANNEL byte = iota SERVER_COMMAND_DEL_CHANNEL ) @@ -49,25 +51,49 @@ func ParsePacket(data []byte) (Payload, error) { switch PacketType(data[0]) { case PACKET_SERVER_COMMAND: - return ParseServerCommandPacket(data[1:]) - case PACKET_CHANNEL_DATA: - return ParseChannelDataPacket(data[1:]) + return ParseCommandPacket(data[1:]) case PACKET_CHANNEL_COMMAND: return ParseChannelCommandPacket(data[1:]) - case PACKET_CHANNEL_PEER: - return ParseChannelPeerPacket(data[1:]) + case PACKET_DATA: + return ParseDataPacket(data[1:]) + case PACKET_PEER: + return ParsePeerPacket(data[1:]) + case PACKET_PING: + return ParsePingPacket(data[1:]) default: return nil, fmt.Errorf("Don't know how to parse packet type 0x%02x", data[0]) } } -type ServerCommandPacket struct { +type PingPacket struct { +} + +func(packet PingPacket) MarshalBinary() ([]byte, error) { + return []byte{}, nil +} + +func NewPingPacket() *Packet { + return &Packet{ + Type: PACKET_PING, + Payload: PingPacket{}, + } +} + +func ParsePingPacket(data []byte) (PingPacket, error) { + if len(data) != 0 { + return PingPacket{}, fmt.Errorf("Wrong number of bytes to parse PingPacket %d/0", len(data)) + } + + return PingPacket{}, nil +} + +type CommandPacket struct { ReqID uuid.UUID Command byte Data []byte } -func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) { +func (packet CommandPacket) MarshalBinary() ([]byte, error) { p := make([]byte, 17 + len(packet.Data)) copy(p, packet.ReqID[:]) p[16] = packet.Command @@ -76,10 +102,10 @@ func (packet ServerCommandPacket) MarshalBinary() ([]byte, error) { return p, nil } -func NewServerCommandPacket(request_id uuid.UUID, command byte, data []byte) *Packet { +func NewCommandPacket(request_id uuid.UUID, command byte, data []byte) *Packet { return &Packet{ Type: PACKET_SERVER_COMMAND, - Payload: ServerCommandPacket{ + Payload: CommandPacket{ ReqID: request_id, Command: command, Data: data, @@ -87,12 +113,12 @@ func NewServerCommandPacket(request_id uuid.UUID, command byte, data []byte) *Pa } } -func ParseServerCommandPacket(data []byte) (ServerCommandPacket, error) { +func ParseCommandPacket(data []byte) (CommandPacket, error) { if len(data) < 17 { - return ServerCommandPacket{}, fmt.Errorf("Not enough data to decode ServerCommandPacket: %d/%d", len(data), 17) + return CommandPacket{}, fmt.Errorf("Not enough data to decode CommandPacket: %d/%d", len(data), 17) } - return ServerCommandPacket{ + return CommandPacket{ ReqID: uuid.UUID(data[0:16]), Command: data[16], Data: data[17:], @@ -175,16 +201,16 @@ func ParseChannelCommandPacket(data []byte) (ChannelCommandPacket, error) { }, nil } -type ChannelPeerPacket struct { +type PeerPacket struct { ChannelHeader Peer PeerID Data []byte } -func NewChannelPeerPacket(peer PeerID, channel ChannelID, mode ModeID, data []byte) *Packet { +func NewPeerPacket(peer PeerID, channel ChannelID, mode ModeID, data []byte) *Packet { return &Packet{ - Type: PACKET_CHANNEL_PEER, - Payload: ChannelPeerPacket{ + Type: PACKET_PEER, + Payload: PeerPacket{ ChannelHeader: ChannelHeader{ Channel: channel, Mode: mode, @@ -195,7 +221,7 @@ func NewChannelPeerPacket(peer PeerID, channel ChannelID, mode ModeID, data []by } } -func(packet ChannelPeerPacket) MarshalBinary() ([]byte, error) { +func(packet PeerPacket) MarshalBinary() ([]byte, error) { header, err := packet.ChannelHeader.MarshalBinary() if err != nil { return nil, err @@ -205,17 +231,17 @@ func(packet ChannelPeerPacket) MarshalBinary() ([]byte, error) { return append(data, packet.Data...), nil } -func ParseChannelPeerPacket(data []byte) (ChannelPeerPacket, error) { +func ParsePeerPacket(data []byte) (PeerPacket, error) { if len(data) < CHANNEL_PEER_LEN { - return ChannelPeerPacket{}, fmt.Errorf("Not enough bytes to parse ServerChannelPacket: %d/%d", len(data), PEER_ID_LENGTH) + return PeerPacket{}, fmt.Errorf("Not enough bytes to parse ServerChannelPacket: %d/%d", len(data), PEER_ID_LENGTH) } header, err := ParseChannelHeader(data) if err != nil { - return ChannelPeerPacket{}, err + return PeerPacket{}, err } - return ChannelPeerPacket{ + return PeerPacket{ ChannelHeader: header, Peer: PeerID(data[CHANNEL_HEADER_LEN:]), Data: data[CHANNEL_PEER_LEN:], @@ -223,15 +249,15 @@ func ParseChannelPeerPacket(data []byte) (ChannelPeerPacket, error) { } -type ChannelDataPacket struct { +type DataPacket struct { ChannelHeader Data []byte } -func NewChannelDataPacket(channel ChannelID, mode ModeID, data []byte) *Packet { +func NewDataPacket(channel ChannelID, mode ModeID, data []byte) *Packet { return &Packet{ - Type: PACKET_CHANNEL_DATA, - Payload: ChannelDataPacket{ + Type: PACKET_DATA, + Payload: DataPacket{ ChannelHeader: ChannelHeader{ Channel: channel, Mode: mode, @@ -241,7 +267,7 @@ func NewChannelDataPacket(channel ChannelID, mode ModeID, data []byte) *Packet { } } -func(packet ChannelDataPacket) MarshalBinary() ([]byte, error) { +func(packet DataPacket) MarshalBinary() ([]byte, error) { header, err := packet.ChannelHeader.MarshalBinary() if err != nil { return nil, err @@ -250,17 +276,17 @@ func(packet ChannelDataPacket) MarshalBinary() ([]byte, error) { return append(header, packet.Data...), nil } -func ParseChannelDataPacket(data []byte) (ChannelDataPacket, error) { +func ParseDataPacket(data []byte) (DataPacket, error) { if len(data) < CHANNEL_HEADER_LEN { - return ChannelDataPacket{}, fmt.Errorf("Not enough data to parse ChannelDataPacket") + return DataPacket{}, fmt.Errorf("Not enough data to parse DataPacket") } header, err := ParseChannelHeader(data) if err != nil { - return ChannelDataPacket{}, nil + return DataPacket{}, nil } - return ChannelDataPacket{ + return DataPacket{ ChannelHeader: header, Data: data[CHANNEL_HEADER_LEN:], }, nil diff --git a/server.go b/server.go index bc29fde..3f35907 100644 --- a/server.go +++ b/server.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "os" - "reflect" "slices" "sync" "sync/atomic" @@ -17,6 +16,7 @@ import ( const ( SERVER_UDP_BUFFER_SIZE = 2048 SERVER_SEND_BUFFER_SIZE = 2048 + SERVER_COMMAND_BUFFER_SIZE = 2048 ) type RoleID uint32 @@ -25,7 +25,7 @@ type ServerSession struct { Session LastSeen time.Time IncomingPackets chan[]byte - OutgoingPackets chan *Packet + OutgoingPackets chan Payload Channels []ChannelID } @@ -34,19 +34,17 @@ type Server struct { active atomic.Bool connection *net.UDPConn stopped chan error - - modes map[reflect.Type]ModeID + commands chan Payload sessions_lock sync.Mutex sessions map[SessionID]*ServerSession - channels_lock sync.RWMutex - channels map[ChannelID]*Channel + channels atomic.Value peers map[PeerID][]RoleID } -func NewServer(key ed25519.PrivateKey) (*Server, error) { +func NewServer(key ed25519.PrivateKey, channels map[ChannelID]*Channel) (*Server, error) { if key == nil { var err error _, key, err = ed25519.GenerateKey(rand.Reader) @@ -59,67 +57,18 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) { connection: nil, active: atomic.Bool{}, stopped: make(chan error, 0), - - modes: map[reflect.Type]ModeID{ - reflect.TypeFor[*RawMode](): MODE_RAW, - reflect.TypeFor[*AudioMode](): MODE_AUDIO, - }, + commands: make(chan Payload, SERVER_COMMAND_BUFFER_SIZE), sessions: map[SessionID]*ServerSession{}, - channels: map[ChannelID]*Channel{}, + channels: atomic.Value{}, peers: map[PeerID][]RoleID{}, } + server.channels.Store(channels) server.active.Store(false) return server, nil } -func(server *Server) RemoveChannel(id ChannelID) error { - server.channels_lock.Lock() - defer server.channels_lock.Unlock() - - _, exists := server.channels[id] - if exists == false { - return fmt.Errorf("Channel %x does not exist", id) - } - - delete(server.channels, id) - return nil -} - -func(server *Server) AddChannel(id ChannelID, modes ...Mode) error { - server.channels_lock.Lock() - defer server.channels_lock.Unlock() - - _, exists := server.channels[id] - if exists { - return fmt.Errorf("Channel with ID %x already exists", id) - } - - mode_map := map[ModeID]Mode{} - for _, mode := range(modes) { - reflect_type := reflect.TypeOf(mode) - mode_id, known := server.modes[reflect_type] - if known == false { - return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type) - } - - _, exists := mode_map[mode_id] - if exists { - return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id) - } - mode_map[mode_id] = mode - } - - server.channels[id] = &Channel{ - id: id, - modes: mode_map, - members: []*ServerSession{}, - } - - return nil -} - func(server *Server) Log(format string, fields ...interface{}) { fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n") } @@ -173,93 +122,79 @@ func handle_session_outgoing(session *ServerSession, server *Server) { server.Log("Stopping session outgoing goroutine %s", session.ID) } +const SESSION_PING_TIME = time.Minute +const SESSION_TIMEOUT = 2 * time.Minute + +type SessionChannelCommand struct { + Session *ServerSession + Packet *ChannelCommandPacket +} + func handle_session_incoming(session *ServerSession, server *Server) { server.Log("Starting session incoming goroutine %s", session.ID) - for true { - encrypted := <- session.IncomingPackets - if encrypted == nil { - break - } + ping_timer := time.After(SESSION_PING_TIME) + running := true + for running { + select { + case <- ping_timer: + if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != 1 { + server.sessions_lock.Lock() + server.close_session(session) + server.sessions_lock.Unlock() + running = false + } else { + session.OutgoingPackets <- NewPingPacket() + ping_timer = time.After(SESSION_PING_TIME) + } + case encrypted := <- session.IncomingPackets: + if encrypted == nil { + running = false + continue + } - data, err := ParseSessionData(&session.Session, encrypted) - if err != nil { - server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) - continue - } + data, err := ParseSessionData(&session.Session, encrypted) + if err != nil { + server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) + continue + } - packet, err := ParsePacket(data) - if err != nil { - server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) - } + packet, err := ParsePacket(data) + if err != nil { + server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) + } - switch packet := packet.(type) { - case ServerCommandPacket: - switch packet.Command { - case SERVER_COMMAND_JOIN_CHANNEL: - server.Log("Got join_channel for %x with %x", session.ID, packet.Data) - if len(packet.Data) == 1 { - server.channels_lock.Lock() - channel, exists := server.channels[ChannelID(packet.Data[0])] - if exists == true { - if slices.Contains(channel.members, session) == false { - channel.members = append(channel.members, session) - channel.Join(session.Peer, session.ID) - // TODO: Send message to clients to notify of join - } + switch packet := packet.(type) { + case CommandPacket: + server.commands<-packet + case ChannelCommandPacket: + channels := server.channels.Load().(map[ChannelID]*Channel) + channel, exists := channels[packet.Channel] + if exists == true { + channel.Commands<-SessionChannelCommand{ + Session: session, + Packet: &packet, } - server.channels_lock.Unlock() + } else { + server.Log("Command for unknown channel %d", packet.Channel) } - case SERVER_COMMAND_LEAVE_CHANNEL: - server.Log("Got leave_channel for %x with %x", session.ID, packet.Data) - if len(packet.Data) == 1 { - server.channels_lock.Lock() - channel, exists := server.channels[ChannelID(packet.Data[0])] - if exists == true { - idx := slices.Index(channel.members, session) - if idx != -1 { - channel.members = slices.Delete(channel.members, idx, idx+1) - channel.Leave(session.Peer, session.ID) - // TODO: Send message to clients to notify of join + case DataPacket: + channels := server.channels.Load().(map[ChannelID]*Channel) + channel, exists := channels[packet.Channel] + if exists == true { + members := channel.Members.Load().([]*ServerSession) + if slices.Contains(members, session) { + mode, has_mode := channel.Modes[packet.Mode] + if has_mode { + mode.Load().(Mode).Data(session, packet.Channel, members, data) } } - server.channels_lock.Unlock() - } - case SERVER_COMMAND_ADD_CHANNEL: - server.Log("Got add_channel with %x", packet.Data) - case SERVER_COMMAND_DEL_CHANNEL: - server.Log("Got del_channel with %x", packet.Data) - default: - server.Log("Unknown server command %x", packet.Command) - } - case ChannelDataPacket: - server.channels_lock.RLock() - channel, exists := server.channels[packet.Channel] - if exists == true { - if slices.Contains(channel.members, session) { - channel.Data(session, packet.Mode, packet.Data) + } else { + server.Log("Data for unknown channel %d", packet.Channel) } - } else { - server.Log("Packet for unknown channel %d", packet.Channel) - } - server.channels_lock.RUnlock() - - case ChannelCommandPacket: - server.channels_lock.RLock() - channel, exists := server.channels[packet.Channel] - if exists == true { - err = channel.Command(session, packet.Command, packet.ReqID, packet.Mode, packet.Data) - if err != nil { - server.Log("Error processing %+v - %s", packet, err) - } - } else { - server.Log("Packet for unknown channel %d", packet.Channel) + default: + server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err) } - server.channels_lock.RUnlock() - - - default: - server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err) } } @@ -277,7 +212,7 @@ func(server *Server) handle_session_open(client_session_open []byte, from *net.U Session: session, LastSeen: time.Now(), IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), - OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), + OutgoingPackets: make(chan Payload, SESSION_BUFFER_SIZE), } server.sessions_lock.Unlock() @@ -329,7 +264,7 @@ func(server *Server) handle_session_close(session_close []byte, from *net.UDPAdd } server.sessions_lock.Lock() - server.close_session(session) + server.close_session(session) server.sessions_lock.Unlock() _, err = server.connection.WriteToUDP(session_closed, session.remote) @@ -406,11 +341,17 @@ func(server *Server) listen_udp() { } } + channels := server.channels.Load().(map[ChannelID]*Channel) + for _, channel := range(channels) { + close(channel.Commands) + } + server.sessions_lock.Lock() sessions := make([]*ServerSession, 0, len(server.sessions)) for _, session := range(server.sessions) { sessions = append(sessions, session) } + for _, session := range(sessions) { server.close_session(session) } @@ -424,31 +365,16 @@ func(server *Server) close_session(session *ServerSession) { close(session.IncomingPackets) close(session.OutgoingPackets) delete(server.sessions, session.ID) -} -const SESSION_TIMEOUT = time.Minute * 5 -const SESSION_TIMEOUT_CHECK = time.Minute + session_closed := NewSessionTimed(SESSION_CLOSED, server.key, &session.Session, time.Now()) + server.connection.WriteToUDP(session_closed, session.remote) +} -func(server *Server) cleanup_sessions() { +func(server *Server) update_state() { for server.active.Load() { select { - case <-time.After(SESSION_TIMEOUT_CHECK): - server.sessions_lock.Lock() - now := time.Now() - stale_sessions := make([]*ServerSession, 0, len(server.sessions)) - for _, session := range(server.sessions) { - if now.Sub(session.LastSeen) >= SESSION_TIMEOUT { - server.Log("Closing stale session %s for %s", session.ID, session.Peer) - stale_sessions = append(stale_sessions, session) - } - } - - for _, session := range(stale_sessions) { - server.close_session(session) - } - - server.sessions_lock.Unlock() - // TODO: add a way for this to be shutdown instantly on server shutdown + case command := <-server.commands: + server.Log("Incoming server command %+v", command) } } } @@ -472,7 +398,7 @@ func(server *Server) Start(listen string) error { } go server.listen_udp() - go server.cleanup_sessions() + go server.update_state() return nil }