|
|
@ -7,7 +7,6 @@ import (
|
|
|
|
"fmt"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
|
|
|
"net"
|
|
|
|
"os"
|
|
|
|
"os"
|
|
|
|
"reflect"
|
|
|
|
|
|
|
|
"slices"
|
|
|
|
"slices"
|
|
|
|
"sync"
|
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
|
|
|
"sync/atomic"
|
|
|
@ -17,6 +16,7 @@ import (
|
|
|
|
const (
|
|
|
|
const (
|
|
|
|
SERVER_UDP_BUFFER_SIZE = 2048
|
|
|
|
SERVER_UDP_BUFFER_SIZE = 2048
|
|
|
|
SERVER_SEND_BUFFER_SIZE = 2048
|
|
|
|
SERVER_SEND_BUFFER_SIZE = 2048
|
|
|
|
|
|
|
|
SERVER_COMMAND_BUFFER_SIZE = 2048
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
type RoleID uint32
|
|
|
|
type RoleID uint32
|
|
|
@ -25,7 +25,7 @@ type ServerSession struct {
|
|
|
|
Session
|
|
|
|
Session
|
|
|
|
LastSeen time.Time
|
|
|
|
LastSeen time.Time
|
|
|
|
IncomingPackets chan[]byte
|
|
|
|
IncomingPackets chan[]byte
|
|
|
|
OutgoingPackets chan *Packet
|
|
|
|
OutgoingPackets chan Payload
|
|
|
|
Channels []ChannelID
|
|
|
|
Channels []ChannelID
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -34,19 +34,17 @@ type Server struct {
|
|
|
|
active atomic.Bool
|
|
|
|
active atomic.Bool
|
|
|
|
connection *net.UDPConn
|
|
|
|
connection *net.UDPConn
|
|
|
|
stopped chan error
|
|
|
|
stopped chan error
|
|
|
|
|
|
|
|
commands chan Payload
|
|
|
|
modes map[reflect.Type]ModeID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sessions_lock sync.Mutex
|
|
|
|
sessions_lock sync.Mutex
|
|
|
|
sessions map[SessionID]*ServerSession
|
|
|
|
sessions map[SessionID]*ServerSession
|
|
|
|
|
|
|
|
|
|
|
|
channels_lock sync.RWMutex
|
|
|
|
channels atomic.Value
|
|
|
|
channels map[ChannelID]*Channel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
peers map[PeerID][]RoleID
|
|
|
|
peers map[PeerID][]RoleID
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func NewServer(key ed25519.PrivateKey) (*Server, error) {
|
|
|
|
func NewServer(key ed25519.PrivateKey, channels map[ChannelID]*Channel) (*Server, error) {
|
|
|
|
if key == nil {
|
|
|
|
if key == nil {
|
|
|
|
var err error
|
|
|
|
var err error
|
|
|
|
_, key, err = ed25519.GenerateKey(rand.Reader)
|
|
|
|
_, key, err = ed25519.GenerateKey(rand.Reader)
|
|
|
@ -59,67 +57,18 @@ func NewServer(key ed25519.PrivateKey) (*Server, error) {
|
|
|
|
connection: nil,
|
|
|
|
connection: nil,
|
|
|
|
active: atomic.Bool{},
|
|
|
|
active: atomic.Bool{},
|
|
|
|
stopped: make(chan error, 0),
|
|
|
|
stopped: make(chan error, 0),
|
|
|
|
|
|
|
|
commands: make(chan Payload, SERVER_COMMAND_BUFFER_SIZE),
|
|
|
|
modes: map[reflect.Type]ModeID{
|
|
|
|
|
|
|
|
reflect.TypeFor[*RawMode](): MODE_RAW,
|
|
|
|
|
|
|
|
reflect.TypeFor[*AudioMode](): MODE_AUDIO,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sessions: map[SessionID]*ServerSession{},
|
|
|
|
sessions: map[SessionID]*ServerSession{},
|
|
|
|
channels: map[ChannelID]*Channel{},
|
|
|
|
channels: atomic.Value{},
|
|
|
|
|
|
|
|
|
|
|
|
peers: map[PeerID][]RoleID{},
|
|
|
|
peers: map[PeerID][]RoleID{},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
server.channels.Store(channels)
|
|
|
|
server.active.Store(false)
|
|
|
|
server.active.Store(false)
|
|
|
|
return server, nil
|
|
|
|
return server, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func(server *Server) RemoveChannel(id ChannelID) error {
|
|
|
|
|
|
|
|
server.channels_lock.Lock()
|
|
|
|
|
|
|
|
defer server.channels_lock.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_, exists := server.channels[id]
|
|
|
|
|
|
|
|
if exists == false {
|
|
|
|
|
|
|
|
return fmt.Errorf("Channel %x does not exist", id)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
delete(server.channels, id)
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func(server *Server) AddChannel(id ChannelID, modes ...Mode) error {
|
|
|
|
|
|
|
|
server.channels_lock.Lock()
|
|
|
|
|
|
|
|
defer server.channels_lock.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_, exists := server.channels[id]
|
|
|
|
|
|
|
|
if exists {
|
|
|
|
|
|
|
|
return fmt.Errorf("Channel with ID %x already exists", id)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mode_map := map[ModeID]Mode{}
|
|
|
|
|
|
|
|
for _, mode := range(modes) {
|
|
|
|
|
|
|
|
reflect_type := reflect.TypeOf(mode)
|
|
|
|
|
|
|
|
mode_id, known := server.modes[reflect_type]
|
|
|
|
|
|
|
|
if known == false {
|
|
|
|
|
|
|
|
return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_, exists := mode_map[mode_id]
|
|
|
|
|
|
|
|
if exists {
|
|
|
|
|
|
|
|
return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
mode_map[mode_id] = mode
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server.channels[id] = &Channel{
|
|
|
|
|
|
|
|
id: id,
|
|
|
|
|
|
|
|
modes: mode_map,
|
|
|
|
|
|
|
|
members: []*ServerSession{},
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func(server *Server) Log(format string, fields ...interface{}) {
|
|
|
|
func(server *Server) Log(format string, fields ...interface{}) {
|
|
|
|
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
|
|
|
|
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -173,93 +122,79 @@ func handle_session_outgoing(session *ServerSession, server *Server) {
|
|
|
|
server.Log("Stopping session outgoing goroutine %s", session.ID)
|
|
|
|
server.Log("Stopping session outgoing goroutine %s", session.ID)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const SESSION_PING_TIME = time.Minute
|
|
|
|
|
|
|
|
const SESSION_TIMEOUT = 2 * time.Minute
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type SessionChannelCommand struct {
|
|
|
|
|
|
|
|
Session *ServerSession
|
|
|
|
|
|
|
|
Packet *ChannelCommandPacket
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func handle_session_incoming(session *ServerSession, server *Server) {
|
|
|
|
func handle_session_incoming(session *ServerSession, server *Server) {
|
|
|
|
server.Log("Starting session incoming goroutine %s", session.ID)
|
|
|
|
server.Log("Starting session incoming goroutine %s", session.ID)
|
|
|
|
for true {
|
|
|
|
ping_timer := time.After(SESSION_PING_TIME)
|
|
|
|
encrypted := <- session.IncomingPackets
|
|
|
|
running := true
|
|
|
|
if encrypted == nil {
|
|
|
|
for running {
|
|
|
|
break
|
|
|
|
select {
|
|
|
|
}
|
|
|
|
case <- ping_timer:
|
|
|
|
|
|
|
|
if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != 1 {
|
|
|
|
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
|
|
|
|
server.close_session(session)
|
|
|
|
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
|
|
|
running = false
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
session.OutgoingPackets <- NewPingPacket()
|
|
|
|
|
|
|
|
ping_timer = time.After(SESSION_PING_TIME)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
case encrypted := <- session.IncomingPackets:
|
|
|
|
|
|
|
|
if encrypted == nil {
|
|
|
|
|
|
|
|
running = false
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
data, err := ParseSessionData(&session.Session, encrypted)
|
|
|
|
data, err := ParseSessionData(&session.Session, encrypted)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
|
|
|
|
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
packet, err := ParsePacket(data)
|
|
|
|
packet, err := ParsePacket(data)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
|
|
|
|
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
switch packet := packet.(type) {
|
|
|
|
switch packet := packet.(type) {
|
|
|
|
case ServerCommandPacket:
|
|
|
|
case CommandPacket:
|
|
|
|
switch packet.Command {
|
|
|
|
server.commands<-packet
|
|
|
|
case SERVER_COMMAND_JOIN_CHANNEL:
|
|
|
|
case ChannelCommandPacket:
|
|
|
|
server.Log("Got join_channel for %x with %x", session.ID, packet.Data)
|
|
|
|
channels := server.channels.Load().(map[ChannelID]*Channel)
|
|
|
|
if len(packet.Data) == 1 {
|
|
|
|
channel, exists := channels[packet.Channel]
|
|
|
|
server.channels_lock.Lock()
|
|
|
|
if exists == true {
|
|
|
|
channel, exists := server.channels[ChannelID(packet.Data[0])]
|
|
|
|
channel.Commands<-SessionChannelCommand{
|
|
|
|
if exists == true {
|
|
|
|
Session: session,
|
|
|
|
if slices.Contains(channel.members, session) == false {
|
|
|
|
Packet: &packet,
|
|
|
|
channel.members = append(channel.members, session)
|
|
|
|
|
|
|
|
channel.Join(session.Peer, session.ID)
|
|
|
|
|
|
|
|
// TODO: Send message to clients to notify of join
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
server.channels_lock.Unlock()
|
|
|
|
} else {
|
|
|
|
|
|
|
|
server.Log("Command for unknown channel %d", packet.Channel)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
case SERVER_COMMAND_LEAVE_CHANNEL:
|
|
|
|
case DataPacket:
|
|
|
|
server.Log("Got leave_channel for %x with %x", session.ID, packet.Data)
|
|
|
|
channels := server.channels.Load().(map[ChannelID]*Channel)
|
|
|
|
if len(packet.Data) == 1 {
|
|
|
|
channel, exists := channels[packet.Channel]
|
|
|
|
server.channels_lock.Lock()
|
|
|
|
if exists == true {
|
|
|
|
channel, exists := server.channels[ChannelID(packet.Data[0])]
|
|
|
|
members := channel.Members.Load().([]*ServerSession)
|
|
|
|
if exists == true {
|
|
|
|
if slices.Contains(members, session) {
|
|
|
|
idx := slices.Index(channel.members, session)
|
|
|
|
mode, has_mode := channel.Modes[packet.Mode]
|
|
|
|
if idx != -1 {
|
|
|
|
if has_mode {
|
|
|
|
channel.members = slices.Delete(channel.members, idx, idx+1)
|
|
|
|
mode.Load().(Mode).Data(session, packet.Channel, members, data)
|
|
|
|
channel.Leave(session.Peer, session.ID)
|
|
|
|
|
|
|
|
// TODO: Send message to clients to notify of join
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
server.channels_lock.Unlock()
|
|
|
|
} else {
|
|
|
|
}
|
|
|
|
server.Log("Data for unknown channel %d", packet.Channel)
|
|
|
|
case SERVER_COMMAND_ADD_CHANNEL:
|
|
|
|
|
|
|
|
server.Log("Got add_channel with %x", packet.Data)
|
|
|
|
|
|
|
|
case SERVER_COMMAND_DEL_CHANNEL:
|
|
|
|
|
|
|
|
server.Log("Got del_channel with %x", packet.Data)
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
server.Log("Unknown server command %x", packet.Command)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
case ChannelDataPacket:
|
|
|
|
|
|
|
|
server.channels_lock.RLock()
|
|
|
|
|
|
|
|
channel, exists := server.channels[packet.Channel]
|
|
|
|
|
|
|
|
if exists == true {
|
|
|
|
|
|
|
|
if slices.Contains(channel.members, session) {
|
|
|
|
|
|
|
|
channel.Data(session, packet.Mode, packet.Data)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
|
|
|
|
server.Log("Packet for unknown channel %d", packet.Channel)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
server.channels_lock.RUnlock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
case ChannelCommandPacket:
|
|
|
|
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
|
|
|
|
server.channels_lock.RLock()
|
|
|
|
|
|
|
|
channel, exists := server.channels[packet.Channel]
|
|
|
|
|
|
|
|
if exists == true {
|
|
|
|
|
|
|
|
err = channel.Command(session, packet.Command, packet.ReqID, packet.Mode, packet.Data)
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
|
|
server.Log("Error processing %+v - %s", packet, err)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
server.Log("Packet for unknown channel %d", packet.Channel)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
server.channels_lock.RUnlock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -277,7 +212,7 @@ func(server *Server) handle_session_open(client_session_open []byte, from *net.U
|
|
|
|
Session: session,
|
|
|
|
Session: session,
|
|
|
|
LastSeen: time.Now(),
|
|
|
|
LastSeen: time.Now(),
|
|
|
|
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
|
|
|
|
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
|
|
|
|
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
|
|
|
|
OutgoingPackets: make(chan Payload, SESSION_BUFFER_SIZE),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
|
|
|
|
|
|
@ -329,7 +264,7 @@ func(server *Server) handle_session_close(session_close []byte, from *net.UDPAdd
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
server.close_session(session)
|
|
|
|
server.close_session(session)
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
_, err = server.connection.WriteToUDP(session_closed, session.remote)
|
|
|
|
_, err = server.connection.WriteToUDP(session_closed, session.remote)
|
|
|
@ -406,11 +341,17 @@ func(server *Server) listen_udp() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
channels := server.channels.Load().(map[ChannelID]*Channel)
|
|
|
|
|
|
|
|
for _, channel := range(channels) {
|
|
|
|
|
|
|
|
close(channel.Commands)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
sessions := make([]*ServerSession, 0, len(server.sessions))
|
|
|
|
sessions := make([]*ServerSession, 0, len(server.sessions))
|
|
|
|
for _, session := range(server.sessions) {
|
|
|
|
for _, session := range(server.sessions) {
|
|
|
|
sessions = append(sessions, session)
|
|
|
|
sessions = append(sessions, session)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for _, session := range(sessions) {
|
|
|
|
for _, session := range(sessions) {
|
|
|
|
server.close_session(session)
|
|
|
|
server.close_session(session)
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -424,31 +365,16 @@ func(server *Server) close_session(session *ServerSession) {
|
|
|
|
close(session.IncomingPackets)
|
|
|
|
close(session.IncomingPackets)
|
|
|
|
close(session.OutgoingPackets)
|
|
|
|
close(session.OutgoingPackets)
|
|
|
|
delete(server.sessions, session.ID)
|
|
|
|
delete(server.sessions, session.ID)
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const SESSION_TIMEOUT = time.Minute * 5
|
|
|
|
session_closed := NewSessionTimed(SESSION_CLOSED, server.key, &session.Session, time.Now())
|
|
|
|
const SESSION_TIMEOUT_CHECK = time.Minute
|
|
|
|
server.connection.WriteToUDP(session_closed, session.remote)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func(server *Server) cleanup_sessions() {
|
|
|
|
func(server *Server) update_state() {
|
|
|
|
for server.active.Load() {
|
|
|
|
for server.active.Load() {
|
|
|
|
select {
|
|
|
|
select {
|
|
|
|
case <-time.After(SESSION_TIMEOUT_CHECK):
|
|
|
|
case command := <-server.commands:
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
server.Log("Incoming server command %+v", command)
|
|
|
|
now := time.Now()
|
|
|
|
|
|
|
|
stale_sessions := make([]*ServerSession, 0, len(server.sessions))
|
|
|
|
|
|
|
|
for _, session := range(server.sessions) {
|
|
|
|
|
|
|
|
if now.Sub(session.LastSeen) >= SESSION_TIMEOUT {
|
|
|
|
|
|
|
|
server.Log("Closing stale session %s for %s", session.ID, session.Peer)
|
|
|
|
|
|
|
|
stale_sessions = append(stale_sessions, session)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for _, session := range(stale_sessions) {
|
|
|
|
|
|
|
|
server.close_session(session)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
|
|
|
// TODO: add a way for this to be shutdown instantly on server shutdown
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -472,7 +398,7 @@ func(server *Server) Start(listen string) error {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
go server.listen_udp()
|
|
|
|
go server.listen_udp()
|
|
|
|
go server.cleanup_sessions()
|
|
|
|
go server.update_state()
|
|
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|