package pnyx import ( "crypto/ed25519" "crypto/rand" "errors" "fmt" "net" "os" "reflect" "sync" "sync/atomic" "time" ) const ( SERVER_UDP_BUFFER_SIZE = 2048 SERVER_SEND_BUFFER_SIZE = 2048 ) type ServerSession struct { Session LastSeen time.Time IncomingPackets chan[]byte OutgoingPackets chan *Packet } type Server struct { key ed25519.PrivateKey active atomic.Bool connection *net.UDPConn stopped chan error modes map[reflect.Type]ModeID send_packets chan[]SendPacket sessions_lock sync.Mutex sessions map[SessionID]*ServerSession channels_lock sync.RWMutex channels map[ChannelID]*Channel } func NewServer(key ed25519.PrivateKey) (*Server, error) { if key == nil { var err error _, key, err = ed25519.GenerateKey(rand.Reader) if err != nil { return nil, err } } server := &Server{ key: key, connection: nil, active: atomic.Bool{}, stopped: make(chan error, 0), send_packets: make(chan []SendPacket, SERVER_SEND_BUFFER_SIZE), modes: map[reflect.Type]ModeID{ reflect.TypeFor[*RawMode](): MODE_RAW, }, sessions: map[SessionID]*ServerSession{}, channels: map[ChannelID]*Channel{}, } server.active.Store(false) return server, nil } func(server *Server) RemoveChannel(id ChannelID) error { server.channels_lock.Lock() defer server.channels_lock.Unlock() _, exists := server.channels[id] if exists == false { return fmt.Errorf("Channel %x does not exist", id) } delete(server.channels, id) return nil } func(server *Server) AddChannel(id ChannelID, modes ...Mode) error { if id == RootChannelID { return fmt.Errorf("Cannot use root channel ID as real channel") } server.channels_lock.Lock() defer server.channels_lock.Unlock() _, exists := server.channels[id] if exists { return fmt.Errorf("Channel with ID %x already exists", id) } mode_map := map[ModeID]Mode{} for _, mode := range(modes) { reflect_type := reflect.TypeOf(mode) mode_id, known := server.modes[reflect_type] if known == false { return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type) } _, exists := mode_map[mode_id] if exists { return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id) } mode_map[mode_id] = mode } server.channels[id] = &Channel{ modes: mode_map, } return nil } func(server *Server) Log(format string, fields ...interface{}) { fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n") } func(server *Server) Stop() error { was_active := server.active.CompareAndSwap(true, false) if was_active { close(server.send_packets) err := server.connection.Close() if err != nil { return err } return <-server.stopped } else { return fmt.Errorf("Called stop func on stopped server") } } const SESSION_BUFFER_SIZE = 256 func(server *Server) listen_udp() { server.Log("Started server on %s", server.connection.LocalAddr()) var buf [SERVER_UDP_BUFFER_SIZE]byte for true { read, from, err := server.connection.ReadFromUDP(buf[:]) if err == nil { var packet_type PacketType = PacketType(buf[0]) switch packet_type { case SESSION_OPEN: session_open, ecdh_private, err := NewSessionOpen(server.key) if err != nil { server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) continue } session, err := ParseSessionOpen(ecdh_private, buf[COMMAND_LENGTH:read]) if err != nil { server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) continue } server.sessions_lock.Lock() server.sessions[session.ID] = &ServerSession{ Session: session, LastSeen: time.Now(), IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), } server.sessions_lock.Unlock() go func(session *ServerSession, server *Server){ server.Log("Starting session outgoing goroutine %s", session.ID) for true { packet := <- session.OutgoingPackets if packet == nil { break } if session.remote == nil { server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID) continue } packet_data, err := packet.MarshalBinary() if err != nil { server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err) continue } encrypted, err := NewSessionData(&session.Session, packet_data) if err != nil { server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err) continue } _, err = server.connection.WriteToUDP(encrypted, session.remote) if err != nil { server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err) continue } } server.Log("Stopping session outgoing goroutine %s", session.ID) }(server.sessions[session.ID], server) go func(session *ServerSession, server *Server){ server.Log("Starting session incoming goroutine %s", session.ID) for true { encrypted := <- session.IncomingPackets if encrypted == nil { break } data, err := ParseSessionData(&session.Session, encrypted) if err != nil { server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) continue } packet, err := ParsePacket(data) if err != nil { server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) } if packet.Channel == RootChannelID { // TODO process commands on the root channel } else { var result []SendPacket = nil server.channels_lock.RLock() channel, exists := server.channels[packet.Channel] if exists == true { mode, exists := channel.modes[packet.Mode] if exists == true { result = mode.Process(&session.Session, packet) } } server.channels_lock.RUnlock() if result != nil { //TODO: handle overflow server.send_packets<-result } } } server.Log("Stopping session incoming goroutine %s", session.ID) }(server.sessions[session.ID], server) _, err = server.connection.WriteToUDP(session_open, from) if err != nil { server.Log("WriteToUDP error %s", err) continue } server.Log("Started session %s with %s", session.ID, session.Peer) case SESSION_CONNECT: session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) session, exists := server.sessions[session_id] if exists == false { server.Log("Session %s does not exist, can't connect", session_id) continue } client_addr, err := ParseSessionConnect(buf[COMMAND_LENGTH+ID_LENGTH:read], session.secret) if err != nil { server.Log("Error parsing session connect: %s", err) continue } session.remote = client_addr session.LastSeen = time.Now() // TODO: Make a better server hello server_hello, err := NewSessionData(&session.Session, []byte("hello")) if err != nil { server.Log("Error generating server hello: %s", err) continue } _, err = server.connection.WriteToUDP(server_hello, session.remote) if err != nil { server.Log("Error sending server hello: %s", err) continue } server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from) case SESSION_CLOSE: session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) session, exists := server.sessions[session_id] if exists == false { server.Log("Session %s does not exist, can't close", session_id) continue } err := ParseSessionClose(&session.Session, buf[COMMAND_LENGTH+ID_LENGTH:read]) if err != nil { server.Log("Session close error for %s - %s", session_id, err) continue } server.sessions_lock.Lock() server.close_session(session) server.sessions_lock.Unlock() server.Log("Session %s closed", session_id) case SESSION_DATA: session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) session, exists := server.sessions[session_id] if exists == false { server.Log("Session %s does not exist, can't receive data", session_id) continue } session.LastSeen = time.Now() buf_copy := make([]byte, read - COMMAND_LENGTH - ID_LENGTH) copy(buf_copy, buf[COMMAND_LENGTH+ID_LENGTH:read]) select { case session.IncomingPackets<-buf_copy: default: server.Log("Dropped packet to session %s", session_id) } default: server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read]) } } else if errors.Is(err, net.ErrClosed) { server.Log("UDP_CLOSE: %s", server.connection.LocalAddr()) break } else { server.Log("UDP_READ_ERROR: %s", err) } } server.sessions_lock.Lock() sessions := make([]*ServerSession, 0, len(server.sessions)) for _, session := range(server.sessions) { sessions = append(sessions, session) } for _, session := range(sessions) { server.close_session(session) } server.sessions_lock.Unlock() server.Log("Shut down server on %s", server.connection.LocalAddr()) server.stopped <- nil } func(server *Server) send_sessions() { for server.active.Load() { packets := <- server.send_packets if packets == nil { break } server.sessions_lock.Lock() for _, packet := range(packets) { session, exists := server.sessions[packet.Session] if exists { session.OutgoingPackets <- packet.Packet } } server.sessions_lock.Unlock() } } func(server *Server) close_session(session *ServerSession) { close(session.IncomingPackets) close(session.OutgoingPackets) delete(server.sessions, session.ID) } const SESSION_TIMEOUT = time.Minute * 5 const SESSION_TIMEOUT_CHECK = time.Minute func(server *Server) cleanup_sessions() { for server.active.Load() { select { case <-time.After(SESSION_TIMEOUT_CHECK): server.Log("Running stale session check") server.sessions_lock.Lock() now := time.Now() stale_sessions := make([]*ServerSession, 0, len(server.sessions)) for _, session := range(server.sessions) { if now.Sub(session.LastSeen) >= SESSION_TIMEOUT { server.Log("Closing stale session %s for %s", session.ID, session.Peer) stale_sessions = append(stale_sessions, session) } } for _, session := range(stale_sessions) { server.close_session(session) } server.sessions_lock.Unlock() // TODO: add a way for this to be shutdown instantly on server shutdown } } } func(server *Server) Start(listen string) error { was_inactive := server.active.CompareAndSwap(false, true) if was_inactive == false { return fmt.Errorf("Server already active") } address, err := net.ResolveUDPAddr("udp", listen) if err != nil { server.active.Store(false) return err } server.connection, err = net.ListenUDP("udp", address) if err != nil { server.active.Store(false) return fmt.Errorf("Failed to create udp server: %w", err) } go server.listen_udp() go server.send_sessions() go server.cleanup_sessions() return nil }