242 lines
6.0 KiB
Go
242 lines
6.0 KiB
Go
package pnyx
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
mrand "math/rand"
|
|
"net"
|
|
"os"
|
|
"sync/atomic"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// SERVER_UDP_BUFFER_SIZE is the size, in bytes, of the buffer the server's
// receive loop reads each UDP datagram into. Datagrams larger than this do
// not fit in the buffer.
const SERVER_UDP_BUFFER_SIZE = 2048
|
|
|
|
type SessionID uuid.UUID
|
|
func(id SessionID) String() string {
|
|
return uuid.UUID(id).String()
|
|
}
|
|
|
|
// Session is the server-side state for one client session. Server.run
// creates it when handling a SESSION_OPEN packet and stores it in
// Server.sessions under its ID.
type Session struct {
	// ID is computed in run as ID[SessionID](secret).
	ID SessionID
	// remote is nil when the session is created; run sets it to the address
	// parsed from the client's SESSION_CONNECT packet.
	remote *net.UDPAddr
	// Peer is the client identity, computed in run as ID[ClientID] of the
	// public key carried in SESSION_OPEN.
	Peer ClientID
	// secret is the ECDH shared secret for this session. run uses it as the
	// AES key and passes it to ParseSessionConnect.
	secret []byte
	// cipher is the AES block cipher keyed with secret.
	cipher cipher.Block
	// iv_generator is a math/rand source seeded from 8 bytes of crypto/rand.
	// NOTE(review): math/rand is not cryptographically secure and its output
	// can be predicted after observing enough values. If these IVs must be
	// unpredictable (e.g. for CBC mode), derive them from crypto/rand
	// instead. Confirm how callers use this field.
	iv_generator mrand.Source64
}
|
|
|
|
// Server is a UDP server that negotiates client sessions and holds the
// channel tree consulted by Allowed. Create one with NewServer. The zero
// value has nil maps and a nil stopped channel and is not usable.
//
// Server holds an atomic.Bool and should not be copied; always use *Server.
type Server struct {
	// key is the server's Ed25519 private key, passed to NewSessionOpen.
	key ed25519.PrivateKey
	// active is true from a successful Start until Stop. Both methods use
	// CompareAndSwap on it to reject double starts and double stops.
	active atomic.Bool
	// connection is the UDP socket opened by Start and closed by Stop.
	connection *net.UDPConn
	// stopped receives nil from run after its loop exits; Stop waits on it.
	stopped chan error

	// sessions maps session IDs to the sessions registered by run.
	// NOTE(review): no locking is visible. This presumably relies on only
	// the run goroutine touching the map; confirm before adding accessors.
	sessions map[SessionID]*Session
	// channels maps channel IDs to channels. Allowed walks it through each
	// channel's parent.
	channels map[ChannelID]*Channel
}
|
|
|
|
func NewServer(key ed25519.PrivateKey) (*Server, error) {
|
|
if key == nil {
|
|
var err error
|
|
_, key, err = ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
server := &Server{
|
|
key: key,
|
|
connection: nil,
|
|
active: atomic.Bool{},
|
|
stopped: make(chan error, 0),
|
|
|
|
sessions: map[SessionID]*Session{},
|
|
channels: map[ChannelID]*Channel{},
|
|
}
|
|
server.active.Store(false)
|
|
return server, nil
|
|
}
|
|
|
|
// Check if the client has permission for the command on the channel
|
|
// If it's not specified, check the permission on the parent
|
|
func Allowed(server *Server, client ClientID, channel_id ChannelID, mode ModeID, command CommandID) bool {
|
|
channel, exists := server.channels[channel_id]
|
|
if exists == false {
|
|
return false
|
|
}
|
|
|
|
if channel.permissions != nil {
|
|
client_perms, exists := channel.permissions[client]
|
|
if exists {
|
|
if client_perms == nil {
|
|
return true
|
|
}
|
|
|
|
mode_perms, exists := client_perms[mode]
|
|
if exists {
|
|
if mode_perms == nil {
|
|
return true
|
|
}
|
|
|
|
allowed, exists := mode_perms[command]
|
|
if exists {
|
|
return allowed
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Prevent a cycle on the root channel
|
|
if channel_id == RootChannelID {
|
|
return false
|
|
} else {
|
|
return Allowed(server, client, channel.parent, mode, command)
|
|
}
|
|
}
|
|
|
|
func (server *Server) Log(format string, fields ...interface{}) {
|
|
fmt.Fprint(os.Stderr, fmt.Sprintf(format, fields...) + "\n")
|
|
}
|
|
|
|
func(server *Server) Stop() error {
|
|
was_active := server.active.CompareAndSwap(true, false)
|
|
if was_active {
|
|
err := server.connection.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return <-server.stopped
|
|
} else {
|
|
return fmt.Errorf("Called stop func on stopped server")
|
|
}
|
|
}
|
|
|
|
func(server *Server) run() {
|
|
server.Log("Started server on %s", server.connection.LocalAddr())
|
|
|
|
var buf [SERVER_UDP_BUFFER_SIZE]byte
|
|
for true {
|
|
read, from, err := server.connection.ReadFromUDP(buf[:])
|
|
if err == nil {
|
|
var packet_type PacketType = PacketType(buf[0])
|
|
switch packet_type {
|
|
case SESSION_OPEN:
|
|
client_pubkey, ecdh_public, err := ParseSessionOpen(buf[COMMAND_LENGTH:read])
|
|
if err != nil {
|
|
server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
|
|
continue
|
|
}
|
|
|
|
client_id := ID[ClientID](client_pubkey)
|
|
|
|
session_open, ecdh_private, err := NewSessionOpen(server.key)
|
|
if err != nil {
|
|
server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
|
|
continue
|
|
}
|
|
|
|
_, err = server.connection.WriteToUDP(session_open, from)
|
|
if err != nil {
|
|
server.Log("WriteToUDP error %s", err)
|
|
continue
|
|
}
|
|
|
|
session_secret, err := ECDH(ecdh_public, ecdh_private)
|
|
if err != nil {
|
|
server.Log("ECDH error %s", err)
|
|
continue
|
|
}
|
|
|
|
session_id := ID[SessionID](session_secret)
|
|
|
|
session_cipher, err := aes.NewCipher(session_secret)
|
|
if err != nil {
|
|
server.Log("AES error %s", err)
|
|
continue
|
|
}
|
|
|
|
seed_bytes := make([]byte, 8)
|
|
read, err := rand.Read(seed_bytes)
|
|
if err != nil {
|
|
server.Log("IV Seed error: %s", err)
|
|
continue
|
|
} else if read != 8 {
|
|
server.Log("IV Seed error: not enough bytes read %d/4", read)
|
|
continue
|
|
}
|
|
|
|
session := &Session{
|
|
ID: session_id,
|
|
remote: nil,
|
|
Peer: client_id,
|
|
secret: session_secret,
|
|
cipher: session_cipher,
|
|
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
|
|
}
|
|
server.sessions[session_id] = session
|
|
|
|
server.Log("Started session %s with %s", session_id, client_id)
|
|
case SESSION_CONNECT:
|
|
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
|
|
session, exists := server.sessions[session_id]
|
|
if exists == false {
|
|
server.Log("Session %s does not exist, can't connect", session_id)
|
|
continue
|
|
}
|
|
|
|
client_addr, err := ParseSessionConnect(buf[COMMAND_LENGTH+ID_LENGTH:read], session.secret)
|
|
if err != nil {
|
|
server.Log("Error parsing session connect: %s", err)
|
|
continue
|
|
}
|
|
|
|
session.remote = client_addr
|
|
server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote)
|
|
|
|
// TODO: Send server hello back
|
|
default:
|
|
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
|
|
}
|
|
} else if errors.Is(err, net.ErrClosed) {
|
|
server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())
|
|
break
|
|
} else {
|
|
server.Log("UDP_READ_ERROR: %s", err)
|
|
}
|
|
}
|
|
|
|
server.Log("Shut down server on %s", server.connection.LocalAddr())
|
|
server.stopped <- nil
|
|
}
|
|
|
|
func(server *Server) Start(listen string) error {
|
|
was_inactive := server.active.CompareAndSwap(false, true)
|
|
if was_inactive == false {
|
|
return fmt.Errorf("Server already active")
|
|
}
|
|
|
|
address, err := net.ResolveUDPAddr("udp", listen)
|
|
if err != nil {
|
|
server.active.Store(false)
|
|
return err
|
|
}
|
|
|
|
server.connection, err = net.ListenUDP("udp", address)
|
|
if err != nil {
|
|
server.active.Store(false)
|
|
return fmt.Errorf("Failed to create udp server: %w", err)
|
|
}
|
|
|
|
go server.run()
|
|
|
|
return nil
|
|
}
|