pnyx/server.go

425 lines
12 KiB
Go

2024-04-03 18:52:04 -06:00
package pnyx
import (
"crypto/ed25519"
"crypto/rand"
2024-04-03 18:52:04 -06:00
"errors"
"fmt"
"net"
"os"
"reflect"
"sync"
2024-04-03 18:52:04 -06:00
"sync/atomic"
2024-04-08 17:23:55 -06:00
"time"
2024-04-03 18:52:04 -06:00
)
const (
	// SERVER_UDP_BUFFER_SIZE is the size in bytes of the buffer listen_udp
	// reads each incoming UDP datagram into.
	SERVER_UDP_BUFFER_SIZE = 2048
	// SERVER_SEND_BUFFER_SIZE is the capacity of Server.send_packets, the queue
	// of packet batches waiting for send_sessions to route them to sessions.
	SERVER_SEND_BUFFER_SIZE = 2048
)
type ServerSession struct {
Session
2024-04-08 17:23:55 -06:00
LastSeen time.Time
IncomingPackets chan[]byte
OutgoingPackets chan *Packet
2024-04-03 18:52:04 -06:00
}
type Server struct {
key ed25519.PrivateKey
2024-04-03 18:52:04 -06:00
active atomic.Bool
connection *net.UDPConn
stopped chan error
modes map[reflect.Type]ModeID
send_packets chan[]SendPacket
sessions_lock sync.Mutex
sessions map[SessionID]*ServerSession
channels_lock sync.RWMutex
channels map[ChannelID]*Channel
2024-04-03 18:52:04 -06:00
}
func NewServer(key ed25519.PrivateKey) (*Server, error) {
if key == nil {
var err error
_, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
}
2024-04-03 18:52:04 -06:00
server := &Server{
key: key,
2024-04-03 18:52:04 -06:00
connection: nil,
active: atomic.Bool{},
stopped: make(chan error, 0),
send_packets: make(chan []SendPacket, SERVER_SEND_BUFFER_SIZE),
modes: map[reflect.Type]ModeID{
reflect.TypeFor[*RawMode](): MODE_RAW,
},
sessions: map[SessionID]*ServerSession{},
channels: map[ChannelID]*Channel{},
2024-04-03 18:52:04 -06:00
}
server.active.Store(false)
return server, nil
}
// RemoveChannel unregisters the channel identified by id. It returns an error
// when no channel with that id is registered.
func (server *Server) RemoveChannel(id ChannelID) error {
	server.channels_lock.Lock()
	defer server.channels_lock.Unlock()

	if _, found := server.channels[id]; !found {
		return fmt.Errorf("Channel %x does not exist", id)
	}
	delete(server.channels, id)
	return nil
}
func(server *Server) AddChannel(id ChannelID, modes ...Mode) error {
if id == RootChannelID {
return fmt.Errorf("Cannot use root channel ID as real channel")
}
server.channels_lock.Lock()
defer server.channels_lock.Unlock()
_, exists := server.channels[id]
if exists {
return fmt.Errorf("Channel with ID %x already exists", id)
}
mode_map := map[ModeID]Mode{}
for _, mode := range(modes) {
reflect_type := reflect.TypeOf(mode)
mode_id, known := server.modes[reflect_type]
if known == false {
return fmt.Errorf("Can't create channel with unknown mode: %s", reflect_type)
}
_, exists := mode_map[mode_id]
if exists {
return fmt.Errorf("Can't create channel with duplicate ModeID %x", mode_id)
}
mode_map[mode_id] = mode
}
server.channels[id] = &Channel{
modes: mode_map,
}
return nil
2024-04-03 18:52:04 -06:00
}
func(server *Server) Log(format string, fields ...interface{}) {
2024-04-03 18:52:04 -06:00
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
}
func(server *Server) Stop() error {
was_active := server.active.CompareAndSwap(true, false)
if was_active {
close(server.send_packets)
2024-04-03 18:52:04 -06:00
err := server.connection.Close()
if err != nil {
return err
}
return <-server.stopped
} else {
return fmt.Errorf("Called stop func on stopped server")
}
}
// SESSION_BUFFER_SIZE is the capacity (256) of each session's IncomingPackets
// and OutgoingPackets queues.
const SESSION_BUFFER_SIZE = 1 << 8
func(server *Server) listen_udp() {
2024-04-03 18:52:04 -06:00
server.Log("Started server on %s", server.connection.LocalAddr())
var buf [SERVER_UDP_BUFFER_SIZE]byte
for true {
read, from, err := server.connection.ReadFromUDP(buf[:])
2024-04-03 18:52:04 -06:00
if err == nil {
var packet_type PacketType = PacketType(buf[0])
2024-04-03 18:52:04 -06:00
switch packet_type {
case SESSION_OPEN:
session_open, ecdh_private, err := NewSessionOpen(server.key)
if err != nil {
server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
2024-04-07 13:27:28 -06:00
session, err := ParseSessionOpen(ecdh_private, buf[COMMAND_LENGTH:read])
if err != nil {
2024-04-07 13:27:28 -06:00
server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
server.sessions_lock.Lock()
server.sessions[session.ID] = &ServerSession{
Session: session,
2024-04-08 17:23:55 -06:00
LastSeen: time.Now(),
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
}
server.sessions_lock.Unlock()
go func(session *ServerSession, server *Server){
server.Log("Starting session outgoing goroutine %s", session.ID)
for true {
packet := <- session.OutgoingPackets
if packet == nil {
break
}
if session.remote == nil {
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
continue
}
packet_data, err := packet.MarshalBinary()
if err != nil {
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
continue
}
encrypted, err := NewSessionData(&session.Session, packet_data)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
continue
}
_, err = server.connection.WriteToUDP(encrypted, session.remote)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
continue
}
}
server.Log("Stopping session outgoing goroutine %s", session.ID)
}(server.sessions[session.ID], server)
go func(session *ServerSession, server *Server){
server.Log("Starting session incoming goroutine %s", session.ID)
for true {
encrypted := <- session.IncomingPackets
if encrypted == nil {
break
}
data, err := ParseSessionData(&session.Session, encrypted)
if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue
}
packet, err := ParsePacket(data)
if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
}
if packet.Channel == RootChannelID {
// TODO process commands on the root channel
} else {
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
mode, exists := channel.modes[packet.Mode]
if exists == true {
result = mode.Process(&session.Session, packet)
}
}
server.channels_lock.RUnlock()
if result != nil {
//TODO: handle overflow
server.send_packets<-result
}
}
}
server.Log("Stopping session incoming goroutine %s", session.ID)
}(server.sessions[session.ID], server)
2024-04-07 13:27:28 -06:00
_, err = server.connection.WriteToUDP(session_open, from)
if err != nil {
2024-04-07 13:27:28 -06:00
server.Log("WriteToUDP error %s", err)
continue
}
2024-04-07 13:27:28 -06:00
server.Log("Started session %s with %s", session.ID, session.Peer)
case SESSION_CONNECT:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't connect", session_id)
continue
}
client_addr, err := ParseSessionConnect(buf[COMMAND_LENGTH+ID_LENGTH:read], session.secret)
if err != nil {
server.Log("Error parsing session connect: %s", err)
continue
}
session.remote = client_addr
2024-04-08 17:23:55 -06:00
session.LastSeen = time.Now()
2024-04-07 13:27:28 -06:00
// TODO: Make a better server hello
server_hello, err := NewSessionData(&session.Session, []byte("hello"))
2024-04-07 13:38:15 -06:00
if err != nil {
server.Log("Error generating server hello: %s", err)
continue
}
_, err = server.connection.WriteToUDP(server_hello, session.remote)
if err != nil {
server.Log("Error sending server hello: %s", err)
continue
}
2024-04-08 18:21:17 -06:00
server.Log("Sent server_hello for %s to %s", session.ID, session.remote)
2024-04-07 13:38:15 -06:00
2024-04-07 13:27:28 -06:00
case SESSION_CLOSE:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't close", session_id)
continue
}
err := ParseSessionClose(&session.Session, buf[COMMAND_LENGTH+ID_LENGTH:read])
2024-04-07 13:27:28 -06:00
if err != nil {
server.Log("Session close error for %s - %s", session_id, err)
continue
}
server.sessions_lock.Lock()
2024-04-08 17:23:55 -06:00
server.close_session(session)
server.sessions_lock.Unlock()
2024-04-07 13:27:28 -06:00
server.Log("Session %s closed", session_id)
case SESSION_DATA:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't receive data", session_id)
continue
}
2024-04-08 17:23:55 -06:00
session.LastSeen = time.Now()
buf_copy := make([]byte, read - COMMAND_LENGTH - ID_LENGTH)
copy(buf_copy, buf[COMMAND_LENGTH+ID_LENGTH:read])
select {
case session.IncomingPackets<-buf_copy:
default:
server.Log("Dropped packet to session %s", session_id)
}
2024-04-03 18:52:04 -06:00
default:
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
2024-04-03 18:52:04 -06:00
}
} else if errors.Is(err, net.ErrClosed) {
server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())
break
} else {
server.Log("UDP_READ_ERROR: %s", err)
}
}
server.sessions_lock.Lock()
2024-04-08 17:23:55 -06:00
sessions := make([]*ServerSession, 0, len(server.sessions))
for _, session := range(server.sessions) {
sessions = append(sessions, session)
}
for _, session := range(sessions) {
server.close_session(session)
}
server.sessions_lock.Unlock()
2024-04-03 18:52:04 -06:00
server.Log("Shut down server on %s", server.connection.LocalAddr())
server.stopped <- nil
}
func(server *Server) send_sessions() {
2024-04-08 17:23:55 -06:00
for server.active.Load() {
packets := <- server.send_packets
if packets == nil {
break
}
server.sessions_lock.Lock()
for _, packet := range(packets) {
session, exists := server.sessions[packet.Session]
if exists {
session.OutgoingPackets <- packet.Packet
}
}
server.sessions_lock.Unlock()
}
}
2024-04-08 17:23:55 -06:00
// close_session closes both of the session's packet queues, which makes its
// incoming and outgoing goroutines exit, and removes it from server.sessions.
// The caller must hold server.sessions_lock.
//
// It is a no-op unless session is the one currently registered under its ID.
// This means a session that was already closed (or replaced) never has its
// channels closed a second time, which would panic.
func (server *Server) close_session(session *ServerSession) {
	if current, exists := server.sessions[session.ID]; !exists || current != session {
		return
	}
	close(session.IncomingPackets)
	close(session.OutgoingPackets)
	delete(server.sessions, session.ID)
}
// cleanup_sessions sweeps once per SESSION_TIMEOUT_CHECK and closes every
// session whose LastSeen is at least SESSION_TIMEOUT in the past.
const (
	SESSION_TIMEOUT       = 5 * time.Minute
	SESSION_TIMEOUT_CHECK = 1 * time.Minute
)
// cleanup_sessions runs for the lifetime of the server. Every
// SESSION_TIMEOUT_CHECK it closes each session whose LastSeen is at least
// SESSION_TIMEOUT old. It exits on the first tick after the server has been
// marked inactive.
//
// A single Ticker (stopped on exit) replaces allocating a fresh time.After
// timer on every iteration.
func (server *Server) cleanup_sessions() {
	ticker := time.NewTicker(SESSION_TIMEOUT_CHECK)
	defer ticker.Stop()
	// TODO: add a way for this to be shutdown instantly on server shutdown
	for range ticker.C {
		if !server.active.Load() {
			return
		}
		server.Log("Running stale session check")
		server.sessions_lock.Lock()
		now := time.Now()
		// Deleting the entry currently being visited (close_session removes it)
		// is well-defined during range, so no separate collection pass is needed.
		for _, session := range server.sessions {
			if now.Sub(session.LastSeen) >= SESSION_TIMEOUT {
				server.Log("Closing stale session %s for %s", session.ID, session.Peer)
				server.close_session(session)
			}
		}
		server.sessions_lock.Unlock()
	}
}
2024-04-03 18:52:04 -06:00
func(server *Server) Start(listen string) error {
was_inactive := server.active.CompareAndSwap(false, true)
if was_inactive == false {
return fmt.Errorf("Server already active")
}
address, err := net.ResolveUDPAddr("udp", listen)
if err != nil {
server.active.Store(false)
return err
}
server.connection, err = net.ListenUDP("udp", address)
if err != nil {
server.active.Store(false)
return fmt.Errorf("Failed to create udp server: %w", err)
}
go server.listen_udp()
go server.send_sessions()
2024-04-08 17:23:55 -06:00
go server.cleanup_sessions()
2024-04-03 18:52:04 -06:00
return nil
}