416 lines
12 KiB
Go
416 lines
12 KiB
Go
package pnyx
|
|
|
|
import (
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// SERVER_UDP_BUFFER_SIZE is the size, in bytes, of the buffer listen_udp
// reads each incoming datagram into.
const SERVER_UDP_BUFFER_SIZE = 2048

// SERVER_SEND_BUFFER_SIZE is a send-side buffer size.
// NOTE(review): not referenced in this file; confirm its meaning at its use site.
const SERVER_SEND_BUFFER_SIZE = 2048

// SERVER_COMMAND_BUFFER_SIZE is the capacity of the server-level command
// queue (Server.commands).
const SERVER_COMMAND_BUFFER_SIZE = 2048

// RoleID identifies a role that can be assigned to a peer (see Server.peers).
type RoleID uint32
|
|
|
|
type ServerSession struct {
|
|
Session
|
|
active atomic.Bool
|
|
LastSeen time.Time
|
|
IncomingPackets chan[]byte
|
|
OutgoingPackets chan Payload
|
|
Channels []ChannelID
|
|
}
|
|
|
|
func(session *ServerSession) Send(payload Payload) {
|
|
if session.active.Load() {
|
|
session.OutgoingPackets <- payload
|
|
}
|
|
}
|
|
|
|
type Server struct {
|
|
key ed25519.PrivateKey
|
|
active atomic.Bool
|
|
connection *net.UDPConn
|
|
stopped chan error
|
|
commands chan Payload
|
|
|
|
sessions_lock sync.Mutex
|
|
sessions map[SessionID]*ServerSession
|
|
|
|
channels atomic.Value
|
|
|
|
peers map[PeerID][]RoleID
|
|
|
|
BasePermissions atomic.Value
|
|
RolePermissions atomic.Value
|
|
UserPermissions atomic.Value
|
|
}
|
|
|
|
func NewServer(listen string, key ed25519.PrivateKey, channels map[ChannelID]*Channel) (*Server, error) {
|
|
if key == nil {
|
|
var err error
|
|
_, key, err = ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
address, err := net.ResolveUDPAddr("udp", listen)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
connection, err := net.ListenUDP("udp", address)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Failed to create udp server: %w", err)
|
|
|
|
}
|
|
|
|
server := &Server{
|
|
key: key,
|
|
connection: connection,
|
|
active: atomic.Bool{},
|
|
stopped: make(chan error, 0),
|
|
commands: make(chan Payload, SERVER_COMMAND_BUFFER_SIZE),
|
|
|
|
sessions: map[SessionID]*ServerSession{},
|
|
channels: atomic.Value{},
|
|
|
|
peers: map[PeerID][]RoleID{},
|
|
}
|
|
server.channels.Store(channels)
|
|
|
|
server.BasePermissions.Store(Permissions{})
|
|
server.RolePermissions.Store(map[Role]Permissions{})
|
|
server.UserPermissions.Store(map[PeerID]Permissions{})
|
|
|
|
server.active.Store(true)
|
|
|
|
go server.listen_udp()
|
|
go server.update_state()
|
|
|
|
return server, nil
|
|
}
|
|
|
|
func(server *Server) Log(format string, fields ...interface{}) {
|
|
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
|
|
}
|
|
|
|
func(server *Server) Stop() error {
|
|
if server.active.CompareAndSwap(true, false) {
|
|
err := server.connection.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return <-server.stopped
|
|
} else {
|
|
return fmt.Errorf("Called stop on stopped server")
|
|
}
|
|
}
|
|
|
|
// SESSION_BUFFER_SIZE is the capacity of each session's IncomingPackets and
// OutgoingPackets queues.
const (
	SESSION_BUFFER_SIZE = 256
)
|
|
|
|
func handle_session_outgoing(session *ServerSession, server *Server) {
|
|
server.Log("Starting session outgoing goroutine %s", session.ID)
|
|
for session.active.Load() {
|
|
packet := <- session.OutgoingPackets
|
|
if packet == nil {
|
|
break
|
|
}
|
|
|
|
if session.remote == nil {
|
|
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
|
|
continue
|
|
}
|
|
|
|
packet_data, err := packet.MarshalBinary()
|
|
if err != nil {
|
|
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
|
|
continue
|
|
}
|
|
|
|
encrypted, err := NewSessionData(&session.Session, packet_data)
|
|
if err != nil {
|
|
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
|
|
continue
|
|
}
|
|
|
|
_, err = server.connection.WriteToUDP(encrypted, session.remote)
|
|
if err != nil {
|
|
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
|
|
continue
|
|
}
|
|
}
|
|
server.Log("Stopping session outgoing goroutine %s", session.ID)
|
|
}
|
|
|
|
// Keep-alive timing used by handle_session_incoming. The check runs every
// SESSION_PING_TIME. A session idle for at least SESSION_PING_TIME is pinged,
// and one idle for at least SESSION_TIMEOUT is closed.
const (
	SESSION_PING_TIME = time.Minute
	SESSION_TIMEOUT   = 2 * time.Minute
)
|
|
|
|
type SessionChannelCommand struct {
|
|
Session *ServerSession
|
|
Packet *ChannelCommandPacket
|
|
}
|
|
|
|
func handle_session_incoming(session *ServerSession, server *Server) {
|
|
server.Log("Starting session incoming goroutine %s", session.ID)
|
|
ping_timer := time.After(SESSION_PING_TIME)
|
|
for session.active.Load() {
|
|
select {
|
|
case <- ping_timer:
|
|
if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != -1 {
|
|
server.Log("Closing %s after being inactive since %s", session.ID, session.LastSeen)
|
|
server.sessions_lock.Lock()
|
|
server.close_session(session)
|
|
server.sessions_lock.Unlock()
|
|
} else if time.Now().Add(-1*SESSION_PING_TIME).Compare(session.LastSeen) != -1 {
|
|
server.Log("Pinging %s after being inactive since %s", session.ID, session.LastSeen)
|
|
session.OutgoingPackets <- NewPingPacket()
|
|
ping_timer = time.After(SESSION_PING_TIME)
|
|
} else {
|
|
server.Log("%s passed keep-alive check, last seen %s", session.ID, session.LastSeen)
|
|
ping_timer = time.After(SESSION_PING_TIME)
|
|
}
|
|
case encrypted := <- session.IncomingPackets:
|
|
if encrypted == nil {
|
|
continue
|
|
}
|
|
|
|
data, err := ParseSessionData(&session.Session, encrypted)
|
|
if err != nil {
|
|
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
|
|
continue
|
|
}
|
|
|
|
packet, err := ParsePacket(data)
|
|
if err != nil {
|
|
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
|
|
}
|
|
|
|
switch packet := packet.(type) {
|
|
case CommandPacket:
|
|
server.commands<-packet
|
|
case ChannelCommandPacket:
|
|
channels := server.channels.Load().(map[ChannelID]*Channel)
|
|
channel, exists := channels[packet.Channel]
|
|
if exists == true {
|
|
channel.Commands<-SessionChannelCommand{
|
|
Session: session,
|
|
Packet: &packet,
|
|
}
|
|
} else {
|
|
server.Log("Command for unknown channel %d", packet.Channel)
|
|
}
|
|
case DataPacket:
|
|
channels := server.channels.Load().(map[ChannelID]*Channel)
|
|
channel, exists := channels[packet.Channel]
|
|
if exists == true {
|
|
members := channel.Members.Load().([]*ServerSession)
|
|
if slices.Contains(members, session) {
|
|
mode, has_mode := channel.Modes[packet.Mode]
|
|
if has_mode {
|
|
mode.Load().(Mode).Data(session, packet.Channel, members, packet.Data)
|
|
}
|
|
}
|
|
} else {
|
|
server.Log("Data for unknown channel %d", packet.Channel)
|
|
}
|
|
|
|
default:
|
|
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
|
|
}
|
|
}
|
|
|
|
}
|
|
server.Log("Stopping session incoming goroutine %s", session.ID)
|
|
}
|
|
|
|
func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
|
|
session, session_opened, err := ParseSessionOpen(server.key, client_session_open)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
server.sessions_lock.Lock()
|
|
server.sessions[session.ID] = &ServerSession{
|
|
Session: session,
|
|
LastSeen: time.Now(),
|
|
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
|
|
OutgoingPackets: make(chan Payload, SESSION_BUFFER_SIZE),
|
|
}
|
|
server.sessions[session.ID].active.Store(true)
|
|
server.sessions_lock.Unlock()
|
|
|
|
go handle_session_outgoing(server.sessions[session.ID], server)
|
|
go handle_session_incoming(server.sessions[session.ID], server)
|
|
|
|
_, err = server.connection.WriteToUDP(session_opened, from)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
server.Log("Started session %s with %x", session.ID, session.Peer)
|
|
return nil
|
|
}
|
|
|
|
func(server *Server) handle_session_connect(session_connect []byte, from *net.UDPAddr) error {
|
|
session_id := SessionID(session_connect[:SESSION_ID_LENGTH])
|
|
session, exists := server.sessions[session_id]
|
|
if exists == false {
|
|
return fmt.Errorf("Session %s does not exist, can't connect", session_id)
|
|
}
|
|
|
|
session_connected, err := ParseSessionTimed(SESSION_CONNECTED, server.key, session_connect, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO: fix, was client_addr but the client doesnt know it's nat assignment
|
|
session.remote = from
|
|
session.LastSeen = time.Now()
|
|
|
|
_, err = server.connection.WriteToUDP(session_connected, session.remote)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from)
|
|
return nil
|
|
}
|
|
|
|
func(server *Server) handle_session_close(session_close []byte, from *net.UDPAddr) error {
|
|
session_id := SessionID(session_close[:SESSION_ID_LENGTH])
|
|
session, exists := server.sessions[session_id]
|
|
if exists == false {
|
|
return fmt.Errorf("Session %s does not exist, can't close", session_id)
|
|
}
|
|
|
|
session_closed, err := ParseSessionTimed(SESSION_CLOSED, server.key, session_close, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
server.sessions_lock.Lock()
|
|
server.close_session(session)
|
|
server.sessions_lock.Unlock()
|
|
|
|
_, err = server.connection.WriteToUDP(session_closed, session.remote)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
server.Log("Session %s closed", session_id)
|
|
return nil
|
|
}
|
|
|
|
// TODO: handle packets without creating so many objects(and preventing an extra copy)
|
|
func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
|
|
session_id := SessionID(data[:SESSION_ID_LENGTH])
|
|
session, exists := server.sessions[session_id]
|
|
if exists == false {
|
|
return fmt.Errorf("Session %s does not exist, can't receive data", session_id)
|
|
}
|
|
|
|
session.LastSeen = time.Now()
|
|
|
|
buf_copy := make([]byte, len(data) - SESSION_ID_LENGTH)
|
|
copy(buf_copy, data[SESSION_ID_LENGTH:])
|
|
|
|
select {
|
|
case session.IncomingPackets<-buf_copy:
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("Dropped packet to session %s", session_id)
|
|
}
|
|
}
|
|
|
|
func(server *Server) listen_udp() {
|
|
server.Log("Started server on %s", server.connection.LocalAddr())
|
|
|
|
var buf [SERVER_UDP_BUFFER_SIZE]byte
|
|
for true {
|
|
read, from, err := server.connection.ReadFromUDP(buf[:])
|
|
if err == nil {
|
|
var packet_type SessionPacketType = SessionPacketType(buf[0])
|
|
switch packet_type {
|
|
case SESSION_OPEN:
|
|
err := server.handle_session_open(buf[COMMAND_LENGTH:read], from)
|
|
if err != nil {
|
|
server.Log("handle_session_open erorr - %s", err)
|
|
}
|
|
|
|
case SESSION_CONNECT:
|
|
err := server.handle_session_connect(buf[COMMAND_LENGTH:read], from)
|
|
if err != nil {
|
|
server.Log("handle_session_connect error - %s", err)
|
|
}
|
|
|
|
case SESSION_CLOSE:
|
|
err := server.handle_session_close(buf[COMMAND_LENGTH:read], from)
|
|
if err != nil {
|
|
server.Log("handle_session_close error - %s", err)
|
|
}
|
|
|
|
case SESSION_DATA:
|
|
err := server.handle_session_data(buf[COMMAND_LENGTH:read], from)
|
|
if err != nil {
|
|
server.Log("handle_session_data error - %s", err)
|
|
}
|
|
|
|
default:
|
|
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
|
|
}
|
|
} else if errors.Is(err, net.ErrClosed) {
|
|
server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())
|
|
break
|
|
} else {
|
|
server.Log("UDP_READ_ERROR: %s", err)
|
|
}
|
|
}
|
|
|
|
channels := server.channels.Load().(map[ChannelID]*Channel)
|
|
for _, channel := range(channels) {
|
|
close(channel.Commands)
|
|
}
|
|
|
|
server.sessions_lock.Lock()
|
|
sessions := make([]*ServerSession, 0, len(server.sessions))
|
|
for _, session := range(server.sessions) {
|
|
sessions = append(sessions, session)
|
|
}
|
|
|
|
for _, session := range(sessions) {
|
|
server.close_session(session)
|
|
}
|
|
server.sessions_lock.Unlock()
|
|
|
|
server.Log("Shut down server on %s", server.connection.LocalAddr())
|
|
server.stopped <- nil
|
|
}
|
|
|
|
func(server *Server) close_session(session *ServerSession) {
|
|
session.active.Store(false)
|
|
close(session.IncomingPackets)
|
|
close(session.OutgoingPackets)
|
|
delete(server.sessions, session.ID)
|
|
|
|
session_closed := NewSessionTimed(SESSION_CLOSED, server.key, &session.Session, time.Now())
|
|
server.connection.WriteToUDP(session_closed, session.remote)
|
|
}
|
|
|
|
func(server *Server) update_state() {
|
|
for server.active.Load() {
|
|
select {
|
|
case command := <-server.commands:
|
|
server.Log("Incoming server command %+v", command)
|
|
}
|
|
}
|
|
}
|