package pnyx

import (
	"crypto/cipher"
	"crypto/ed25519"
	"crypto/rand"
	"errors"
	"fmt"
	mrand "math/rand"
	"net"
	"os"
	"sync/atomic"

	"github.com/google/uuid"
)

const (
	// SERVER_UDP_BUFFER_SIZE is the size, in bytes, of the buffer each UDP
	// datagram is read into.
	SERVER_UDP_BUFFER_SIZE = 2048
)

// SessionID identifies a Session. It has the same representation as a
// uuid.UUID.
type SessionID uuid.UUID

// String returns the session ID in canonical UUID text form.
func (id SessionID) String() string {
	return uuid.UUID(id).String()
}

// Session is the per-client state recorded when a SESSION_OPEN packet is
// accepted.
type Session struct {
	ID           SessionID
	remote       *net.UDPAddr // client address, as parsed from SESSION_CONNECT
	Peer         PeerID
	secret       []byte // passed to ParseSessionConnect to validate SESSION_CONNECT
	cipher       cipher.Block
	iv_generator mrand.Source64
}

// Server is a UDP server that tracks sessions and channels. Create one with
// NewServer, then call Start to begin serving and Stop to shut down.
type Server struct {
	key        ed25519.PrivateKey // identity key handed to NewSessionOpen
	active     atomic.Bool        // set by Start, cleared by Stop (or a failed Start)
	connection *net.UDPConn
	stopped    chan error // run reports its exit here; Stop waits on it
	// NOTE(review): sessions and channels are plain maps with no locking;
	// sessions is mutated from the run goroutine.
	sessions map[SessionID]*Session
	channels map[ChannelID]*Channel
}

// NewServer returns an inactive Server. If key is nil, a new ed25519 key is
// generated from crypto/rand; a generation failure is returned as the error.
func NewServer(key ed25519.PrivateKey) (*Server, error) {
	if key == nil {
		var err error
		if _, key, err = ed25519.GenerateKey(rand.Reader); err != nil {
			return nil, err
		}
	}
	// The zero value of atomic.Bool is false, so active starts inactive.
	return &Server{
		key:      key,
		stopped:  make(chan error),
		sessions: make(map[SessionID]*Session),
		channels: make(map[ChannelID]*Channel),
	}, nil
}

// Allowed reports whether client may run command in mode on the channel
// identified by channelID.
//
// The channel's own permissions are consulted first: a nil entry for the
// client grants everything on that channel, a nil entry for the mode grants
// every command in that mode, and an explicit command entry decides the
// result. If the channel has no decision for this (client, mode, command),
// the check is repeated on the channel's parent, ending at RootChannelID.
// Unknown channels are never allowed.
func Allowed(server *Server, client PeerID, channelID ChannelID, mode ModeID, command CommandID) bool {
	channel, ok := server.channels[channelID]
	if !ok {
		return false
	}

	// Indexing a nil map yields ok == false, so a channel with no
	// permissions simply falls through to its parent.
	if clientPerms, ok := channel.permissions[client]; ok {
		if clientPerms == nil {
			return true
		}
		if modePerms, ok := clientPerms[mode]; ok {
			if modePerms == nil {
				return true
			}
			if allowed, ok := modePerms[command]; ok {
				return allowed
			}
		}
	}

	// Stop at the root so the walk cannot cycle on it (presumably the root's
	// parent refers back to itself — TODO confirm).
	if channelID == RootChannelID {
		return false
	}
	return Allowed(server, client, channel.parent, mode, command)
}

// Log writes a printf-style formatted message, followed by a newline, to
// standard error.
func (server *Server) Log(format string, fields ...any) {
	fmt.Fprintln(os.Stderr, fmt.Sprintf(format, fields...))
}

// Stop closes the listening socket and waits for the receive loop to exit,
// returning the value it reports. It returns an error if the server is not
// active or if closing the socket fails.
func (server *Server) Stop() error {
	if !server.active.CompareAndSwap(true, false) {
		return fmt.Errorf("Called stop func on stopped server")
	}
	// Closing the socket makes the pending ReadFromUDP in run fail with
	// net.ErrClosed, which is run's only exit path.
	if err := server.connection.Close(); err != nil {
		return err
	}
	return <-server.stopped
}

// run is the receive loop launched by Start. It reads datagrams until the
// socket is closed, hands each one to handlePacket, and finally reports on
// server.stopped so Stop can return.
func (server *Server) run() {
	server.Log("Started server on %s", server.connection.LocalAddr())

	var buf [SERVER_UDP_BUFFER_SIZE]byte
	for {
		n, from, err := server.connection.ReadFromUDP(buf[:])
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())
				break
			}
			server.Log("UDP_READ_ERROR: %s", err)
			continue
		}
		// Bound the slice to this datagram so no handler can see bytes left
		// over in buf from an earlier, longer packet.
		server.handlePacket(buf[:n], from)
	}

	server.Log("Shut down server on %s", server.connection.LocalAddr())
	server.stopped <- nil
}

// handlePacket dispatches one datagram on its leading packet-type byte.
// packet aliases the receive buffer and is only valid until the next read.
func (server *Server) handlePacket(packet []byte, from *net.UDPAddr) {
	// Reject truncated datagrams up front: slicing off the command header
	// would otherwise panic and take down the receive loop.
	if len(packet) == 0 || len(packet) < int(COMMAND_LENGTH) {
		server.Log("Packet from %s too short (%d bytes), dropping", from, len(packet))
		return
	}

	packetType := PacketType(packet[0])
	payload := packet[COMMAND_LENGTH:]

	switch packetType {
	case SESSION_OPEN:
		server.handleSessionOpen(payload, from)
	case SESSION_CONNECT:
		server.handleSessionConnect(payload, from)
	case SESSION_CLOSE:
		server.handleSessionClose(payload, from)
	case SESSION_DATA:
		// Accepted but not processed yet.
	default:
		server.Log("Unhandled packet type 0x%04x from %s: %+v", packetType, from, payload)
	}
}

// handleSessionOpen builds the server's SESSION_OPEN reply and an ECDH
// private key, parses the client's request with that key, records the
// resulting session and sends the reply to from. Failures are logged and the
// packet is dropped.
func (server *Server) handleSessionOpen(payload []byte, from *net.UDPAddr) {
	reply, ecdhPrivate, err := NewSessionOpen(server.key)
	if err != nil {
		server.Log("NewSessionOpen error - %s: %x", err, payload)
		return
	}
	session, err := ParseSessionOpen(ecdhPrivate, payload)
	if err != nil {
		server.Log("ParseSessionOpen error - %s: %x", err, payload)
		return
	}
	// NOTE(review): the session is recorded before the reply is sent, so it
	// stays registered even if the write below fails.
	server.sessions[session.ID] = &session
	if _, err := server.connection.WriteToUDP(reply, from); err != nil {
		server.Log("WriteToUDP error %s", err)
		return
	}
	server.Log("Started session %s with %s", session.ID, session.Peer)
}

// handleSessionConnect validates a SESSION_CONNECT for a known session and
// records the client address it carries.
func (server *Server) handleSessionConnect(payload []byte, from *net.UDPAddr) {
	sessionID, body, ok := splitSessionID(payload)
	if !ok {
		server.Log("SESSION_CONNECT from %s too short (%d bytes), dropping", from, len(payload))
		return
	}
	session, exists := server.sessions[sessionID]
	if !exists {
		server.Log("Session %s does not exist, can't connect", sessionID)
		return
	}
	clientAddr, err := ParseSessionConnect(body, session.secret)
	if err != nil {
		server.Log("Error parsing session connect: %s", err)
		return
	}
	session.remote = clientAddr
	server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote)
	// TODO: Send server hello back
}

// handleSessionClose validates a SESSION_CLOSE for a known session and, if
// ParseSessionClose accepts it, forgets the session.
func (server *Server) handleSessionClose(payload []byte, from *net.UDPAddr) {
	sessionID, body, ok := splitSessionID(payload)
	if !ok {
		server.Log("SESSION_CLOSE from %s too short (%d bytes), dropping", from, len(payload))
		return
	}
	session, exists := server.sessions[sessionID]
	if !exists {
		server.Log("Session %s does not exist, can't close", sessionID)
		return
	}
	// body ends at the datagram's length, so ParseSessionClose never sees
	// stale bytes from the rest of the receive buffer.
	if err := ParseSessionClose(session, body); err != nil {
		server.Log("Session close error for %s - %s", sessionID, err)
		return
	}
	delete(server.sessions, sessionID)
	server.Log("Session %s closed", sessionID)
}

// splitSessionID separates the leading ID_LENGTH-byte session ID from the
// rest of payload. ok is false when payload is too short to hold an ID.
func splitSessionID(payload []byte) (id SessionID, rest []byte, ok bool) {
	if len(payload) < int(ID_LENGTH) {
		return SessionID{}, nil, false
	}
	return SessionID(payload[:ID_LENGTH]), payload[ID_LENGTH:], true
}

// Start resolves listen as a UDP address, binds a socket to it and launches
// the receive loop in a new goroutine. It fails if the server is already
// active. If resolving or binding fails, the server is marked inactive again
// so Start may be retried.
func (server *Server) Start(listen string) error {
	if !server.active.CompareAndSwap(false, true) {
		return fmt.Errorf("Server already active")
	}

	address, err := net.ResolveUDPAddr("udp", listen)
	if err != nil {
		server.active.Store(false)
		return err
	}

	conn, err := net.ListenUDP("udp", address)
	if err != nil {
		server.active.Store(false)
		return fmt.Errorf("Failed to create udp server: %w", err)
	}
	server.connection = conn

	go server.run()
	return nil
}