package pnyx

import (
	"crypto/ed25519"
	"crypto/rand"
	"errors"
	"fmt"
	"net"
	"os"
	"slices"
	"sync"
	"sync/atomic"
	"time"
)

const (
	SERVER_UDP_BUFFER_SIZE     = 2048
	SERVER_SEND_BUFFER_SIZE    = 2048
	SERVER_COMMAND_BUFFER_SIZE = 2048
)

// RoleID identifies a role; peers map to a list of RoleIDs in Server.peers.
type RoleID uint32

// ServerSession is the server-side state of one client session. Each session
// is serviced by two goroutines started in handle_session_open:
// handle_session_incoming decrypts and dispatches IncomingPackets, and
// handle_session_outgoing encrypts and writes OutgoingPackets.
type ServerSession struct {
	Session

	// active is true from handle_session_open until close_session.
	active atomic.Bool
	// close_lock makes "check active, then send on a packet channel" atomic
	// with respect to close_session closing those channels: senders hold it
	// for reading, close_session holds it for writing. Without it a send can
	// race the close and panic with "send on closed channel".
	close_lock sync.RWMutex

	// NOTE(review): LastSeen is written on the UDP reader goroutine
	// (handle_session_connect/handle_session_data) and read on the session's
	// incoming goroutine without synchronisation; this is a data race.
	LastSeen        time.Time
	IncomingPackets chan []byte
	OutgoingPackets chan Payload
	Channels        []ChannelID
}

// Send queues payload for delivery to the session's remote address. It does
// nothing once the session has been closed, and blocks while OutgoingPackets
// is full.
func (session *ServerSession) Send(payload Payload) {
	session.close_lock.RLock()
	defer session.close_lock.RUnlock()
	if session.active.Load() {
		session.OutgoingPackets <- payload
	}
}

// Server accepts sessions over a single UDP socket and routes their packets
// to server commands and channels.
type Server struct {
	key        ed25519.PrivateKey
	active     atomic.Bool
	connection *net.UDPConn
	// stopped receives nil once listen_udp has finished shutting down.
	stopped chan error
	// done is closed at the end of shutdown so update_state can return.
	done     chan struct{}
	commands chan Payload

	sessions_lock sync.Mutex
	sessions      map[SessionID]*ServerSession
	// session_routines counts the per-session goroutines so shutdown can wait
	// for them before closing the channel command queues they send on.
	session_routines sync.WaitGroup

	// channels holds a map[ChannelID]*Channel.
	channels atomic.Value
	peers    map[PeerID][]RoleID

	// BasePermissions holds a Permissions value.
	BasePermissions atomic.Value
	// RolePermissions holds a map[Role]Permissions.
	RolePermissions atomic.Value
	// UserPermissions holds a map[PeerID]Permissions.
	UserPermissions atomic.Value
}

// NewServer listens for UDP packets on the address listen and starts the
// server's background goroutines. If key is nil a fresh ed25519 key is
// generated. channels is the initial channel table. Permissions start empty.
func NewServer(listen string, key ed25519.PrivateKey, channels map[ChannelID]*Channel) (*Server, error) {
	if key == nil {
		var err error
		_, key, err = ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return nil, err
		}
	}

	address, err := net.ResolveUDPAddr("udp", listen)
	if err != nil {
		return nil, err
	}

	connection, err := net.ListenUDP("udp", address)
	if err != nil {
		return nil, fmt.Errorf("Failed to create udp server: %w", err)
	}

	server := &Server{
		key:        key,
		connection: connection,
		// Buffered so listen_udp can always report completion, even if Stop
		// returned early because Close failed.
		stopped:  make(chan error, 1),
		done:     make(chan struct{}),
		commands: make(chan Payload, SERVER_COMMAND_BUFFER_SIZE),
		sessions: map[SessionID]*ServerSession{},
		peers:    map[PeerID][]RoleID{},
	}

	server.channels.Store(channels)
	server.BasePermissions.Store(Permissions{})
	server.RolePermissions.Store(map[Role]Permissions{})
	server.UserPermissions.Store(map[PeerID]Permissions{})

	server.active.Store(true)
	go server.listen_udp()
	go server.update_state()

	return server, nil
}

// Log writes a formatted, newline-terminated message to stderr.
func (server *Server) Log(format string, fields ...any) {
	fmt.Fprintf(os.Stderr, format+"\n", fields...)
}

// Stop closes the UDP socket and waits for listen_udp to finish shutting
// down. Calling Stop on a server that is already stopped returns an error.
func (server *Server) Stop() error {
	if !server.active.CompareAndSwap(true, false) {
		return fmt.Errorf("Called stop on stopped server")
	}
	if err := server.connection.Close(); err != nil {
		return err
	}
	return <-server.stopped
}

// SESSION_BUFFER_SIZE is the capacity of each session's packet queues.
const SESSION_BUFFER_SIZE = 256

// handle_session_outgoing marshals, encrypts and writes each payload queued
// on session.OutgoingPackets to the session's remote address, until the
// session is closed.
func handle_session_outgoing(session *ServerSession, server *Server) {
	server.Log("Starting session outgoing goroutine %s", session.ID)
	// NOTE(review): session.remote is assigned by handle_session_connect on the
	// UDP reader goroutine and read here without synchronisation.
	for packet := range session.OutgoingPackets {
		if !session.active.Load() {
			break
		}
		if packet == nil {
			continue
		}
		if session.remote == nil {
			server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
			continue
		}

		packet_data, err := packet.MarshalBinary()
		if err != nil {
			server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
			continue
		}

		encrypted, err := NewSessionData(&session.Session, packet_data)
		if err != nil {
			server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
			continue
		}

		if _, err := server.connection.WriteToUDP(encrypted, session.remote); err != nil {
			server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
			continue
		}
	}
	server.Log("Stopping session outgoing goroutine %s", session.ID)
}

// SESSION_PING_TIME is both the keep-alive check interval and the idle time
// after which a session is sent a ping.
const SESSION_PING_TIME = time.Minute

// SESSION_TIMEOUT is the idle time after which a session is closed.
const SESSION_TIMEOUT = 2 * time.Minute

// SessionChannelCommand is queued on a Channel's Commands when a session
// sends a ChannelCommandPacket addressed to that channel.
type SessionChannelCommand struct {
	Session *ServerSession
	Packet  *ChannelCommandPacket
}

// handle_session_incoming runs for the lifetime of session. It dispatches
// packets queued on IncomingPackets and, every SESSION_PING_TIME, pings an
// idle session or closes one that has been idle for SESSION_TIMEOUT.
func handle_session_incoming(session *ServerSession, server *Server) {
	server.Log("Starting session incoming goroutine %s", session.ID)
	defer server.Log("Stopping session incoming goroutine %s", session.ID)

	ping_timer := time.NewTimer(SESSION_PING_TIME)
	defer ping_timer.Stop()

	for session.active.Load() {
		select {
		case <-ping_timer.C:
			idle := time.Since(session.LastSeen)
			switch {
			case idle >= SESSION_TIMEOUT:
				server.Log("Closing %s after being inactive since %s", session.ID, session.LastSeen)
				server.sessions_lock.Lock()
				server.close_session(session)
				server.sessions_lock.Unlock()
				return
			case idle >= SESSION_PING_TIME:
				server.Log("Pinging %s after being inactive since %s", session.ID, session.LastSeen)
				session.Send(NewPingPacket())
			default:
				server.Log("%s passed keep-alive check, last seen %s", session.ID, session.LastSeen)
			}
			ping_timer.Reset(SESSION_PING_TIME)

		case encrypted, ok := <-session.IncomingPackets:
			if !ok {
				// close_session closed the queue; the session is over.
				return
			}
			server.dispatch_session_packet(session, encrypted)
		}
	}
}

// dispatch_session_packet decrypts one queued packet for session and routes
// it:
//   - CommandPackets go to the server command queue.
//   - ChannelCommandPackets go to the target channel's Commands.
//   - DataPackets go to the channel mode's Data handler, if the session is a
//     member of the channel.
//
// Undecryptable or unparsable packets are logged and dropped.
func (server *Server) dispatch_session_packet(session *ServerSession, encrypted []byte) {
	data, err := ParseSessionData(&session.Session, encrypted)
	if err != nil {
		server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
		return
	}

	packet, err := ParsePacket(data)
	if err != nil {
		server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
		return
	}

	switch packet := packet.(type) {
	case CommandPacket:
		server.commands <- packet

	case ChannelCommandPacket:
		channels := server.channels.Load().(map[ChannelID]*Channel)
		channel, exists := channels[packet.Channel]
		if !exists {
			server.Log("Command for unknown channel %d", packet.Channel)
			return
		}
		channel.Commands <- SessionChannelCommand{
			Session: session,
			Packet:  &packet,
		}

	case DataPacket:
		channels := server.channels.Load().(map[ChannelID]*Channel)
		channel, exists := channels[packet.Channel]
		if !exists {
			server.Log("Data for unknown channel %d", packet.Channel)
			return
		}
		members := channel.Members.Load().([]*ServerSession)
		if !slices.Contains(members, session) {
			return
		}
		if mode, has_mode := channel.Modes[packet.Mode]; has_mode {
			mode.Load().(Mode).Data(session, packet.Channel, members, packet.Data)
		}

	default:
		server.Log("Unhandled packet type from session %s - %T", session.ID, packet)
	}
}

// handle_session_open parses a SESSION_OPEN request, registers the new
// session, starts its goroutines and replies to from with the session-opened
// response.
func (server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
	session, session_opened, err := ParseSessionOpen(server.key, client_session_open)
	if err != nil {
		return err
	}

	server_session := &ServerSession{
		Session:         session,
		LastSeen:        time.Now(),
		IncomingPackets: make(chan []byte, SESSION_BUFFER_SIZE),
		OutgoingPackets: make(chan Payload, SESSION_BUFFER_SIZE),
	}
	server_session.active.Store(true)

	server.sessions_lock.Lock()
	server.sessions[session.ID] = server_session
	server.sessions_lock.Unlock()

	// Add runs on the listen_udp goroutine, the same goroutine that later
	// calls Wait, so the WaitGroup is never incremented concurrently with Wait.
	server.session_routines.Add(2)
	go func() {
		defer server.session_routines.Done()
		handle_session_outgoing(server_session, server)
	}()
	go func() {
		defer server.session_routines.Done()
		handle_session_incoming(server_session, server)
	}()

	if _, err := server.connection.WriteToUDP(session_opened, from); err != nil {
		return err
	}

	server.Log("Started session %s with %x", session.ID, session.Peer)
	return nil
}

// find_server_session reads the session ID prefix of data and looks the
// session up under sessions_lock. The returned session is nil when no session
// has that ID. An error is returned only when data is too short to hold an ID.
func (server *Server) find_server_session(data []byte) (SessionID, *ServerSession, error) {
	var session_id SessionID
	if len(data) < int(SESSION_ID_LENGTH) {
		return session_id, nil, fmt.Errorf("packet of %d bytes is too short to hold a session id", len(data))
	}
	session_id = SessionID(data[:SESSION_ID_LENGTH])

	server.sessions_lock.Lock()
	session := server.sessions[session_id]
	server.sessions_lock.Unlock()

	return session_id, session, nil
}

// handle_session_connect validates a SESSION_CONNECT request for an existing
// session, records from as the session's remote address and replies with
// the session-connected response.
func (server *Server) handle_session_connect(session_connect []byte, from *net.UDPAddr) error {
	session_id, session, err := server.find_server_session(session_connect)
	if err != nil {
		return err
	}
	if session == nil {
		return fmt.Errorf("Session %s does not exist, can't connect", session_id)
	}

	session_connected, err := ParseSessionTimed(SESSION_CONNECTED, server.key, session_connect, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
	if err != nil {
		return err
	}

	// TODO: fix, was client_addr but the client doesnt know it's nat assignment
	session.remote = from
	session.LastSeen = time.Now()

	if _, err := server.connection.WriteToUDP(session_connected, session.remote); err != nil {
		return err
	}

	server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from)
	return nil
}

// handle_session_close validates a SESSION_CLOSE request, closes the session
// and replies with the session-closed response.
func (server *Server) handle_session_close(session_close []byte, from *net.UDPAddr) error {
	session_id, session, err := server.find_server_session(session_close)
	if err != nil {
		return err
	}
	if session == nil {
		return fmt.Errorf("Session %s does not exist, can't close", session_id)
	}

	session_closed, err := ParseSessionTimed(SESSION_CLOSED, server.key, session_close, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
	if err != nil {
		return err
	}

	server.sessions_lock.Lock()
	server.close_session(session)
	server.sessions_lock.Unlock()

	if _, err := server.connection.WriteToUDP(session_closed, session.remote); err != nil {
		return err
	}

	server.Log("Session %s closed", session_id)
	return nil
}

// handle_session_data copies an encrypted SESSION_DATA payload and queues it
// on the session's IncomingPackets without blocking. The packet is dropped,
// with an error, when the queue is full or the session is closed.
// TODO: handle packets without creating so many objects(and preventing an extra copy)
func (server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
	session_id, session, err := server.find_server_session(data)
	if err != nil {
		return err
	}
	if session == nil {
		return fmt.Errorf("Session %s does not exist, can't receive data", session_id)
	}

	session.LastSeen = time.Now()

	// data aliases listen_udp's reusable read buffer, so the payload must be
	// copied before it is handed to another goroutine.
	buf_copy := make([]byte, len(data)-SESSION_ID_LENGTH)
	copy(buf_copy, data[SESSION_ID_LENGTH:])

	session.close_lock.RLock()
	defer session.close_lock.RUnlock()
	if !session.active.Load() {
		return fmt.Errorf("Dropped packet to closed session %s", session_id)
	}
	select {
	case session.IncomingPackets <- buf_copy:
		return nil
	default:
		return fmt.Errorf("Dropped packet to session %s", session_id)
	}
}

// listen_udp reads datagrams until the socket is closed, routing each one to
// its handler, and then tears the server down.
func (server *Server) listen_udp() {
	server.Log("Started server on %s", server.connection.LocalAddr())

	var buf [SERVER_UDP_BUFFER_SIZE]byte
	for {
		read, from, err := server.connection.ReadFromUDP(buf[:])
		if errors.Is(err, net.ErrClosed) {
			server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())
			break
		}
		if err != nil {
			server.Log("UDP_READ_ERROR: %s", err)
			continue
		}
		server.handle_datagram(buf[:read], from)
	}

	server.teardown()
}

// handle_datagram routes one received datagram by its leading packet-type
// byte. Handler errors are logged rather than returned.
func (server *Server) handle_datagram(packet []byte, from *net.UDPAddr) {
	// Zero-length or truncated datagrams would otherwise make packet[0] or
	// the header slice below panic.
	if len(packet) == 0 || len(packet) < int(COMMAND_LENGTH) {
		server.Log("Dropped short packet (%d bytes) from %s", len(packet), from)
		return
	}

	packet_type := SessionPacketType(packet[0])
	body := packet[COMMAND_LENGTH:]

	switch packet_type {
	case SESSION_OPEN:
		if err := server.handle_session_open(body, from); err != nil {
			server.Log("handle_session_open error - %s", err)
		}
	case SESSION_CONNECT:
		if err := server.handle_session_connect(body, from); err != nil {
			server.Log("handle_session_connect error - %s", err)
		}
	case SESSION_CLOSE:
		if err := server.handle_session_close(body, from); err != nil {
			server.Log("handle_session_close error - %s", err)
		}
	case SESSION_DATA:
		if err := server.handle_session_data(body, from); err != nil {
			server.Log("handle_session_data error - %s", err)
		}
	default:
		server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, body)
	}
}

// teardown closes every session, waits for the session goroutines to exit,
// closes every channel's Commands, stops update_state and finally reports
// completion on server.stopped.
//
// Sessions are closed and their goroutines awaited before any
// channel.Commands is closed, because incoming goroutines send on those
// channels. Closing them earlier can panic with "send on closed channel".
// NOTE(review): this means Stop waits for any in-flight dispatch, including a
// blocked send on a channel's Commands.
func (server *Server) teardown() {
	server.sessions_lock.Lock()
	sessions := make([]*ServerSession, 0, len(server.sessions))
	for _, session := range server.sessions {
		sessions = append(sessions, session)
	}
	for _, session := range sessions {
		server.close_session(session)
	}
	server.sessions_lock.Unlock()

	server.session_routines.Wait()

	channels := server.channels.Load().(map[ChannelID]*Channel)
	for _, channel := range channels {
		close(channel.Commands)
	}

	close(server.done)

	server.Log("Shut down server on %s", server.connection.LocalAddr())
	server.stopped <- nil
}

// close_session marks session inactive, closes its packet queues (which
// stops its goroutines), removes it from server.sessions and, if it has a
// remote address, sends it a SESSION_CLOSED notice. Calls after the first are
// no-ops. Callers must hold server.sessions_lock.
func (server *Server) close_session(session *ServerSession) {
	session.close_lock.Lock()
	if !session.active.CompareAndSwap(true, false) {
		session.close_lock.Unlock()
		return
	}
	close(session.IncomingPackets)
	close(session.OutgoingPackets)
	session.close_lock.Unlock()

	if server.sessions[session.ID] == session {
		delete(server.sessions, session.ID)
	}

	if session.remote == nil {
		return
	}
	session_closed := NewSessionTimed(SESSION_CLOSED, server.key, &session.Session, time.Now())
	_, err := server.connection.WriteToUDP(session_closed, session.remote)
	// During shutdown the socket is already closed; that is expected.
	if err != nil && !errors.Is(err, net.ErrClosed) {
		server.Log("SESSION_CLOSED(%s) write error - %s", session.ID, err)
	}
}

// update_state consumes server commands until teardown closes server.done.
func (server *Server) update_state() {
	for {
		select {
		case command := <-server.commands:
			server.Log("Incoming server command %+v", command)
		case <-server.done:
			return
		}
	}
}