pnyx/server.go

226 lines
5.6 KiB
Go

package pnyx
import (
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"errors"
"fmt"
mrand "math/rand"
"net"
"os"
"sync/atomic"
"github.com/google/uuid"
)
// SERVER_UDP_BUFFER_SIZE is the size, in bytes, of the receive buffer that the
// server's read loop reads each UDP datagram into.
const SERVER_UDP_BUFFER_SIZE = 2048
// SessionID identifies a server session. It shares its underlying
// representation with uuid.UUID, so the two types convert freely.
type SessionID uuid.UUID

// String renders the ID in uuid.UUID's text form, which lets a SessionID be
// printed directly with %s.
func (id SessionID) String() string {
	asUUID := uuid.UUID(id)
	return asUUID.String()
}
// Session is the server-side state for one client session. Sessions are
// stored in Server.sessions under their ID.
type Session struct {
// ID is the key used to look the session up when SESSION_CONNECT and
// SESSION_CLOSE packets arrive.
ID SessionID
// remote is the client address, set once a SESSION_CONNECT payload parses
// successfully; nil before that.
remote *net.UDPAddr
// Peer identifies the remote peer; it is logged when the session starts.
Peer PeerID
// secret is handed to ParseSessionConnect to validate SESSION_CONNECT
// payloads. Presumably a shared key derived during SESSION_OPEN — TODO confirm.
secret []byte
// cipher is not used in this file. Presumably the session's block cipher for
// SESSION_DATA traffic — TODO confirm.
cipher cipher.Block
// iv_generator is not used in this file.
// NOTE(review): math/rand sources are not cryptographically secure. If this
// produces IVs or nonces, confirm that predictability is acceptable for the
// cipher mode in use.
iv_generator mrand.Source64
}
// Server is a UDP server that tracks client sessions and per-channel
// permissions. Create one with NewServer, then call Start and Stop.
type Server struct {
// key is the server's ed25519 private key, handed to NewSessionOpen.
key ed25519.PrivateKey
// active is set by Start and cleared by Stop (or by a failed Start); it
// prevents double starts and double stops.
active atomic.Bool
// connection is the listening UDP socket; nil until Start listens.
connection *net.UDPConn
// stopped carries the read loop's exit report to Stop: run sends nil on it
// after the socket is closed.
stopped chan error
// sessions maps session IDs to live sessions. It is written by the read
// loop goroutine without any locking.
// NOTE(review): reading or writing it from another goroutine would race.
sessions map[SessionID]*Session
// channels holds the channels, linked through their parent field, that
// Allowed consults.
// NOTE(review): also unguarded; concurrent mutation would race with Allowed.
channels map[ChannelID]*Channel
}
// NewServer creates a Server that is not yet started.
//
// If key is empty (nil or zero-length), a fresh ed25519 key pair is generated
// from crypto/rand and its private half is used. A non-empty key must be
// exactly ed25519.PrivateKeySize bytes long. Any other length is rejected here
// with an error, rather than failing later when the key is used
// (ed25519.Sign panics on a wrongly sized private key).
//
// The returned error is non-nil only if key generation fails or key has an
// invalid length; in that case the *Server is nil.
func NewServer(key ed25519.PrivateKey) (*Server, error) {
	switch {
	case len(key) == 0:
		_, generated, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return nil, fmt.Errorf("generating server key: %w", err)
		}
		key = generated
	case len(key) != ed25519.PrivateKeySize:
		return nil, fmt.Errorf("invalid ed25519 private key length %d, want %d", len(key), ed25519.PrivateKeySize)
	}
	// The zero values of active (false) and connection (nil) already describe
	// a server that has not been started.
	return &Server{
		key:      key,
		stopped:  make(chan error),
		sessions: map[SessionID]*Session{},
		channels: map[ChannelID]*Channel{},
	}, nil
}
// Allowed reports whether client may run command in mode on channel channel_id.
//
// The lookup starts at channel_id and, while no decision is found, moves to
// each channel's parent. On a given channel:
//   - a nil permission set for the client grants everything;
//   - a nil permission set for the mode grants every command in that mode;
//   - an explicit entry for the command decides the answer either way.
//
// The result is false in three cases:
//   - a channel on the path does not exist;
//   - RootChannelID is reached without a decision;
//   - the parent chain loops without reaching the root.
func Allowed(server *Server, client PeerID, channel_id ChannelID, mode ModeID, command CommandID) bool {
	// A parent chain without a loop visits each channel at most once, so it
	// needs no more steps than there are channels. Bounding the walk turns a
	// parent cycle that never reaches RootChannelID into a false result,
	// instead of unbounded recursion that overflows the stack.
	for steps := 0; steps < len(server.channels); steps++ {
		channel, exists := server.channels[channel_id]
		if !exists {
			return false
		}
		if allowed, decided := lookupChannelPermission(channel, client, mode, command); decided {
			return allowed
		}
		// Stop at the root rather than following its parent.
		if channel_id == RootChannelID {
			return false
		}
		channel_id = channel.parent
	}
	return false
}

// lookupChannelPermission checks only the permissions stored directly on
// channel. decided is false when channel has no entry that settles the
// question; the caller should then consult the parent channel.
func lookupChannelPermission(channel *Channel, client PeerID, mode ModeID, command CommandID) (allowed, decided bool) {
	// Indexing a nil map yields (zero, false), so a channel with no
	// permissions at all simply leaves the decision to its parent.
	clientPerms, ok := channel.permissions[client]
	if !ok {
		return false, false
	}
	if clientPerms == nil {
		return true, true
	}
	modePerms, ok := clientPerms[mode]
	if !ok {
		return false, false
	}
	if modePerms == nil {
		return true, true
	}
	allowed, decided = modePerms[command]
	return allowed, decided
}
// Log formats a message printf-style and writes it, followed by a newline, to
// standard error.
func (server *Server) Log(format string, fields ...any) {
	message := fmt.Sprintf(format, fields...)
	fmt.Fprintln(os.Stderr, message)
}
// Stop shuts down a running server and waits for its read loop to exit.
//
// Closing the UDP socket unblocks the loop's pending ReadFromUDP with
// net.ErrClosed. The loop then exits and reports on server.stopped. Stop always
// waits for that report, even when Close returns an error: the loop's send on
// the unbuffered channel would otherwise block forever, and a later Stop would
// receive a stale value. If Close failed, its error is returned; otherwise the
// value reported by the loop is returned.
//
// Calling Stop on a server that is not active returns an error.
//
// NOTE(review): Start marks the server active before it assigns
// server.connection. A Stop racing with Start can therefore see a nil
// connection.
func (server *Server) Stop() error {
	if !server.active.CompareAndSwap(true, false) {
		return fmt.Errorf("Called stop func on stopped server")
	}
	closeErr := server.connection.Close()
	runErr := <-server.stopped
	if closeErr != nil {
		return closeErr
	}
	return runErr
}
// run is the server's receive loop; Start launches it on its own goroutine.
//
// Each UDP datagram is read into a SERVER_UDP_BUFFER_SIZE buffer and dispatched
// on its first byte (the packet type). Datagrams are untrusted input, so
// every length is checked before slicing. A short or empty packet is logged
// and dropped instead of panicking the whole process.
//
// The loop ends when the socket is closed (Stop closes it, so ReadFromUDP
// reports net.ErrClosed). run then reports its exit by sending nil on
// server.stopped.
//
// NOTE(review): server.sessions is mutated here without a lock; any access
// from another goroutine would be a data race.
func (server *Server) run() {
	server.Log("Started server on %s", server.connection.LocalAddr())

	// Header boundaries as ints so they can be compared with the byte count
	// returned by ReadFromUDP.
	headerEnd := int(COMMAND_LENGTH)
	sessionIDEnd := int(COMMAND_LENGTH + ID_LENGTH)

	var buf [SERVER_UDP_BUFFER_SIZE]byte
	for {
		read, from, err := server.connection.ReadFromUDP(buf[:])
		if errors.Is(err, net.ErrClosed) {
			server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())
			break
		}
		if err != nil {
			server.Log("UDP_READ_ERROR: %s", err)
			continue
		}
		// An empty datagram would leave buf[0] holding an earlier packet's
		// type, and slicing at headerEnd panics when read < headerEnd.
		if read == 0 || read < headerEnd {
			server.Log("Dropping %d-byte packet from %s: shorter than header", read, from)
			continue
		}
		// Restrict all parsing to this datagram; bytes past read are stale
		// leftovers from earlier, longer packets.
		packet := buf[:read]
		packetType := PacketType(packet[0])
		switch packetType {
		case SESSION_OPEN:
			payload := packet[headerEnd:]
			sessionOpen, ecdhPrivate, err := NewSessionOpen(server.key)
			if err != nil {
				server.Log("NewSessionOpen error - %s: %x", err, payload)
				continue
			}
			session, err := ParseSessionOpen(ecdhPrivate, payload)
			if err != nil {
				server.Log("ParseSessionOpen error - %s: %x", err, payload)
				continue
			}
			if _, err := server.connection.WriteToUDP(sessionOpen, from); err != nil {
				server.Log("WriteToUDP error %s", err)
				continue
			}
			// Register the session only after the reply has been sent; a
			// failed write would otherwise leave an entry no client can use.
			// Packets are handled one at a time, so no SESSION_CONNECT for
			// this ID can be processed before this point.
			server.sessions[session.ID] = &session
			server.Log("Started session %s with %s", session.ID, session.Peer)
		case SESSION_CONNECT:
			if read < sessionIDEnd {
				server.Log("Dropping %d-byte SESSION_CONNECT from %s: missing session ID", read, from)
				continue
			}
			sessionID := SessionID(packet[headerEnd:sessionIDEnd])
			session, exists := server.sessions[sessionID]
			if !exists {
				server.Log("Session %s does not exist, can't connect", sessionID)
				continue
			}
			clientAddr, err := ParseSessionConnect(packet[sessionIDEnd:], session.secret)
			if err != nil {
				server.Log("Error parsing session connect: %s", err)
				continue
			}
			session.remote = clientAddr
			server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote)
			// TODO: Send server hello back
		case SESSION_CLOSE:
			if read < sessionIDEnd {
				server.Log("Dropping %d-byte SESSION_CLOSE from %s: missing session ID", read, from)
				continue
			}
			sessionID := SessionID(packet[headerEnd:sessionIDEnd])
			session, exists := server.sessions[sessionID]
			if !exists {
				server.Log("Session %s does not exist, can't close", sessionID)
				continue
			}
			// The payload is bounded at read, like SESSION_CONNECT's, so no
			// stale buffer contents reach the parser.
			if err := ParseSessionClose(session, packet[sessionIDEnd:]); err != nil {
				server.Log("Session close error for %s - %s", sessionID, err)
				continue
			}
			delete(server.sessions, sessionID)
			server.Log("Session %s closed", sessionID)
		case SESSION_DATA:
			// TODO: SESSION_DATA is not handled yet; such packets are dropped.
		default:
			server.Log("Unhandled packet type 0x%04x from %s: %+v", packetType, from, packet[headerEnd:])
		}
	}
	server.Log("Shut down server on %s", server.connection.LocalAddr())
	server.stopped <- nil
}
// Start listens for UDP on listen, an address accepted by
// net.ResolveUDPAddr such as "host:port". It then runs the receive loop in a
// new goroutine.
//
// Start returns an error if the server is already active. If resolving or
// listening fails, the active flag is cleared again so Start can be retried.
func (server *Server) Start(listen string) error {
	if !server.active.CompareAndSwap(false, true) {
		return fmt.Errorf("Server already active")
	}
	// From here on, every failure path must clear the active flag.
	address, resolveErr := net.ResolveUDPAddr("udp", listen)
	if resolveErr != nil {
		server.active.Store(false)
		return resolveErr
	}
	conn, listenErr := net.ListenUDP("udp", address)
	server.connection = conn
	if listenErr != nil {
		server.active.Store(false)
		return fmt.Errorf("Failed to create udp server: %w", listenErr)
	}
	go server.run()
	return nil
}