package pnyx import ( "crypto/ed25519" "crypto/rand" "errors" "fmt" "net" "os" "reflect" "sync" "sync/atomic" "time" ) const ( SERVER_UDP_BUFFER_SIZE = 2048 SERVER_SEND_BUFFER_SIZE = 2048 ) type RoleID uint32 type ServerSession struct { Session LastSeen time.Time IncomingPackets chan[]byte OutgoingPackets chan *Packet } type Server struct { key ed25519.PrivateKey active atomic.Bool connection *net.UDPConn stopped chan error modes map[reflect.Type]ModeID send_packets chan[]SendPacket sessions_lock sync.Mutex sessions map[SessionID]*ServerSession channels_lock sync.RWMutex channels map[ChannelID]*Channel peers map[PeerID][]RoleID } func NewServer(key ed25519.PrivateKey) (*Server, error) { if key == nil { var err error _, key, err = ed25519.GenerateKey(rand.Reader) if err != nil { return nil, err } } server := &Server{ key: key, connection: nil, active: atomic.Bool{}, stopped: make(chan error, 0), send_packets: make(chan []SendPacket, SERVER_SEND_BUFFER_SIZE), modes: map[reflect.Type]ModeID{ reflect.TypeFor[*RawMode](): MODE_RAW, reflect.TypeFor[*AudioMode](): MODE_AUDIO, }, sessions: map[SessionID]*ServerSession{}, channels: map[ChannelID]*Channel{}, peers: map[PeerID][]RoleID{}, } server.active.Store(false) return server, nil } func(server *Server) RemoveChannel(id ChannelID) error { server.channels_lock.Lock() defer server.channels_lock.Unlock() _, exists := server.channels[id] if exists == false { return fmt.Errorf("Channel %x does not exist", id) } delete(server.channels, id) return nil } func(server *Server) AddChannel(id ChannelID, modes ...Mode) error { server.channels_lock.Lock() defer server.channels_lock.Unlock() _, exists := server.channels[id] if exists { return fmt.Errorf("Channel with ID %x already exists", id) } mode_map := map[ModeID]Mode{} for _, mode := range(modes) { reflect_type := reflect.TypeOf(mode) mode_id, known := server.modes[reflect_type] if known == false { return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type) } _, exists := mode_map[mode_id] if exists { return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id) } mode_map[mode_id] = mode } server.channels[id] = &Channel{ id: id, modes: mode_map, sessions: []SessionID{}, } return nil } func(server *Server) Log(format string, fields ...interface{}) { fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n") } func(server *Server) Stop() error { was_active := server.active.CompareAndSwap(true, false) if was_active { close(server.send_packets) err := server.connection.Close() if err != nil { return err } return <-server.stopped } else { return fmt.Errorf("Called stop func on stopped server") } } const SESSION_BUFFER_SIZE = 256 func handle_session_outgoing(session *ServerSession, server *Server) { server.Log("Starting session outgoing goroutine %s", session.ID) for true { packet := <- session.OutgoingPackets if packet == nil { break } if session.remote == nil { server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID) continue } packet_data, err := packet.MarshalBinary() if err != nil { server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err) continue } encrypted, err := NewSessionData(&session.Session, packet_data) if err != nil { server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err) continue } _, err = server.connection.WriteToUDP(encrypted, session.remote) if err != nil { server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err) continue } } server.Log("Stopping session outgoing goroutine %s", session.ID) } func handle_session_incoming(session *ServerSession, server *Server) { server.Log("Starting session incoming goroutine %s", session.ID) for true { encrypted := <- session.IncomingPackets if encrypted == nil { break } data, err := ParseSessionData(&session.Session, encrypted) if err != nil { server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) continue } packet, err := ParsePacket(data) if err != nil { server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) } switch packet := packet.(type) { case ServerCommandPacket: switch packet.Command { case SERVER_COMMAND_ADD_CHANNEL: server.Log("Got add_channel with %x", packet.Data) case SERVER_COMMAND_DEL_CHANNEL: server.Log("Got del_channel with %x", packet.Data) default: server.Log("Unknown server command %x", packet.Command) } case ChannelDataPacket: var result []SendPacket = nil server.channels_lock.RLock() channel, exists := server.channels[packet.Channel] if exists == true { result = channel.Data(&session.Session, packet.Mode, packet.Data) } server.channels_lock.RUnlock() if exists == false { server.Log("Packet for unknown channel %d", packet.Channel) } else if len(result) > 0 { //TODO: handle overflow server.send_packets<-result } case ChannelCommandPacket: var result []SendPacket = nil server.channels_lock.RLock() channel, exists := server.channels[packet.Channel] if exists == true { result, err = channel.Command(&session.Session, packet) } server.channels_lock.RUnlock() if exists == false { server.Log("Packet for unknown channel %d", packet.Channel) } else if err != nil { server.Log("Error processing %+v - %s", packet, err) } else if len(result) > 0 { //TODO: handle overflow server.send_packets<-result } default: server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err) } } server.Log("Stopping session incoming goroutine %s", session.ID) } func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error { session, session_opened, err := ParseSessionOpen(server.key, client_session_open) if err != nil { return err } server.sessions_lock.Lock() server.sessions[session.ID] = &ServerSession{ Session: session, LastSeen: time.Now(), IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), } server.sessions_lock.Unlock() go handle_session_outgoing(server.sessions[session.ID], server) go handle_session_incoming(server.sessions[session.ID], server) _, err = server.connection.WriteToUDP(session_opened, from) if err != nil { return err } server.Log("Started session %s with %x", session.ID, session.Peer) return nil } func(server *Server) handle_session_connect(session_connect []byte, from *net.UDPAddr) error { session_id := SessionID(session_connect[:SESSION_ID_LENGTH]) session, exists := server.sessions[session_id] if exists == false { return fmt.Errorf("Session %s does not exist, can't connect", session_id) } session_connected, err := ParseSessionTimed(SESSION_CONNECTED, server.key, session_connect, &session.Session, time.Now().Add(-1*time.Second).UnixMilli()) if err != nil { return err } // TODO: fix, was client_addr but the client doesnt know it's nat assignment session.remote = from session.LastSeen = time.Now() _, err = server.connection.WriteToUDP(session_connected, session.remote) if err != nil { return err } server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from) return nil } func(server *Server) handle_session_close(session_close []byte, from *net.UDPAddr) error { session_id := SessionID(session_close[:SESSION_ID_LENGTH]) session, exists := server.sessions[session_id] if exists == false { return fmt.Errorf("Session %s does not exist, can't close", session_id) } session_closed, err := ParseSessionTimed(SESSION_CLOSED, server.key, session_close, &session.Session, time.Now().Add(-1*time.Second).UnixMilli()) if err != nil { return err } server.sessions_lock.Lock() server.close_session(session) server.sessions_lock.Unlock() _, err = server.connection.WriteToUDP(session_closed, session.remote) if err != nil { return err } server.Log("Session %s closed", session_id) return nil } // TODO: handle packets without creating so many objects(and preventing an extra copy) func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error { session_id := SessionID(data[:SESSION_ID_LENGTH]) session, exists := server.sessions[session_id] if exists == false { return fmt.Errorf("Session %s does not exist, can't receive data", session_id) } session.LastSeen = time.Now() buf_copy := make([]byte, len(data) - SESSION_ID_LENGTH) copy(buf_copy, data[SESSION_ID_LENGTH:]) select { case session.IncomingPackets<-buf_copy: return nil default: return fmt.Errorf("Dropped packet to session %s", session_id) } } func(server *Server) listen_udp() { server.Log("Started server on %s", server.connection.LocalAddr()) var buf [SERVER_UDP_BUFFER_SIZE]byte for true { read, from, err := server.connection.ReadFromUDP(buf[:]) if err == nil { var packet_type SessionPacketType = SessionPacketType(buf[0]) switch packet_type { case SESSION_OPEN: err := server.handle_session_open(buf[COMMAND_LENGTH:read], from) if err != nil { server.Log("handle_session_open erorr - %s", err) } case SESSION_CONNECT: err := server.handle_session_connect(buf[COMMAND_LENGTH:read], from) if err != nil { server.Log("handle_session_connect error - %s", err) } case SESSION_CLOSE: err := server.handle_session_close(buf[COMMAND_LENGTH:read], from) if err != nil { server.Log("handle_session_close error - %s", err) } case SESSION_DATA: err := server.handle_session_data(buf[COMMAND_LENGTH:read], from) if err != nil { server.Log("handle_session_data error - %s", err) } default: server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read]) } } else if errors.Is(err, net.ErrClosed) { server.Log("UDP_CLOSE: %s", server.connection.LocalAddr()) break } else { server.Log("UDP_READ_ERROR: %s", err) } } server.sessions_lock.Lock() sessions := make([]*ServerSession, 0, len(server.sessions)) for _, session := range(server.sessions) { sessions = append(sessions, session) } for _, session := range(sessions) { server.close_session(session) } server.sessions_lock.Unlock() server.Log("Shut down server on %s", server.connection.LocalAddr()) server.stopped <- nil } func(server *Server) send_sessions() { for server.active.Load() { packets := <- server.send_packets if packets == nil { break } server.sessions_lock.Lock() for _, packet := range(packets) { session, exists := server.sessions[packet.Session] if exists { session.OutgoingPackets <- packet.Packet } } server.sessions_lock.Unlock() } } func(server *Server) close_session(session *ServerSession) { close(session.IncomingPackets) close(session.OutgoingPackets) delete(server.sessions, session.ID) } const SESSION_TIMEOUT = time.Minute * 5 const SESSION_TIMEOUT_CHECK = time.Minute func(server *Server) cleanup_sessions() { for server.active.Load() { select { case <-time.After(SESSION_TIMEOUT_CHECK): server.sessions_lock.Lock() now := time.Now() stale_sessions := make([]*ServerSession, 0, len(server.sessions)) for _, session := range(server.sessions) { if now.Sub(session.LastSeen) >= SESSION_TIMEOUT { server.Log("Closing stale session %s for %s", session.ID, session.Peer) stale_sessions = append(stale_sessions, session) } } for _, session := range(stale_sessions) { server.close_session(session) } server.sessions_lock.Unlock() // TODO: add a way for this to be shutdown instantly on server shutdown } } } func(server *Server) Start(listen string) error { was_inactive := server.active.CompareAndSwap(false, true) if was_inactive == false { return fmt.Errorf("Server already active") } address, err := net.ResolveUDPAddr("udp", listen) if err != nil { server.active.Store(false) return err } server.connection, err = net.ListenUDP("udp", address) if err != nil { server.active.Store(false) return fmt.Errorf("Failed to create udp server: %w", err) } go server.listen_udp() go server.send_sessions() go server.cleanup_sessions() return nil }