Added session open and connect code, and started session data

live
noah metz 2024-04-06 16:38:14 -06:00
parent a52250bcf8
commit a438837c81
8 changed files with 556 additions and 137 deletions

@ -0,0 +1,22 @@
package pnyx
import (
)
type ChannelID uint32
const RootChannelID = 0
type ModeID uint8
type CommandID uint8
type PermissionMap map[ClientID]map[ModeID]map[CommandID]bool
type Channel struct {
modes map[ModeID]Mode
permissions PermissionMap
parent ChannelID
}
type Mode interface {
}

@ -4,15 +4,36 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha512" "crypto/sha512"
"net"
"sync"
"fmt"
"github.com/google/uuid"
)
type ClientID uuid.UUID
func(id ClientID) String() string {
return uuid.UUID(id).String()
}
type ClientState uint8
const (
CLIENT_SESSION_CREATE ClientState = iota
CLIENT_SESSION_CONNECT
CLIENT_SESSION_CONNECTED
) )
type Client struct { type Client struct {
Key ed25519.PrivateKey Key ed25519.PrivateKey
ConnectionLock sync.Mutex
Connection *net.UDPConn
State ClientState
} }
func(client Client) ID() ClientID { func ID[T ~[16]byte, V ~[]byte](data V) T {
hash := sha512.Sum512([]byte(client.Key.Public().(ed25519.PublicKey))) hash := sha512.Sum512(data)
return (ClientID)(hash[0:16]) return (T)(hash[0:16])
} }
func NewClient(key ed25519.PrivateKey) (Client, error) { func NewClient(key ed25519.PrivateKey) (Client, error) {
@ -26,5 +47,68 @@ func NewClient(key ed25519.PrivateKey) (Client, error) {
return Client{ return Client{
Key: key, Key: key,
State: CLIENT_SESSION_CREATE,
}, nil }, nil
} }
func(client *Client) Connect(remote string) (ed25519.PublicKey, []byte, error) {
client.ConnectionLock.Lock()
defer client.ConnectionLock.Unlock()
address, err := net.ResolveUDPAddr("udp", remote)
if err != nil {
return nil, nil, err
}
client.Connection, err = net.DialUDP("udp", nil, address)
if err != nil {
return nil, nil, err
}
session_open, ecdh_private, err := NewSessionOpen(client.Key)
if err != nil {
client.Connection.Close()
client.Connection = nil
return nil, nil, err
}
_, err = client.Connection.Write(session_open)
if err != nil {
return nil, nil, err
}
var response = [512]byte{}
read, _, err := client.Connection.ReadFromUDP(response[:])
if err != nil {
return nil, nil, err
}
if response[0] != byte(SESSION_OPEN) {
return nil, nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0])
}
server_public, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read])
if err != nil {
return nil, nil, err
}
secret, err := ECDH(ecdh_public, ecdh_private)
if err != nil {
return nil, nil, err
}
client.State = CLIENT_SESSION_CONNECT
session_connect := NewSessionConnect(client.Connection.LocalAddr().(*net.UDPAddr), secret)
_, err = client.Connection.Write(session_connect)
if err != nil {
return nil, nil, err
}
read, _, err = client.Connection.ReadFromUDP(response[:])
if err != nil {
return nil, nil, err
}
return server_public, secret, nil
}

@ -1,29 +1,22 @@
package main package main
import ( import (
"net" "fmt"
"os" "os"
"time"
"git.metznet.ca/MetzNet/pnyx"
) )
func main() { func main() {
address, err := net.ResolveUDPAddr("udp", os.Args[1]) client, err := pnyx.NewClient(nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
connection, err := net.DialUDP("udp", nil, address) server_public, secret, err := client.Connect(os.Args[1])
if err != nil { if err != nil {
panic(err) panic(err)
} }
for true { fmt.Printf("Started session %s with %s", pnyx.ID[pnyx.SessionID](secret), pnyx.ID[pnyx.ClientID](server_public))
written, err := connection.Write([]byte(os.Args[2]))
if written != len(os.Args[2]) {
panic(written)
} else if err != nil {
panic(err)
}
time.Sleep(time.Second)
}
} }

@ -13,9 +13,12 @@ func main() {
os_sigs := make(chan os.Signal, 1) os_sigs := make(chan os.Signal, 1)
signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT)
server := pnyx.NewServer() server, err := pnyx.NewServer(nil)
if err != nil {
panic(err)
}
err := server.Start(os.Args[1]) err = server.Start(os.Args[1])
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -0,0 +1,175 @@
So far I've been thinking of this as similar to IRC, but without passing messages between servers(if a user wants messages from a server they need to connect directly).
This complicates DMs since there isn't a message routing path between servers, but a potential workaround is to have a "mailbox" instead of DMs(who stores the mailbox?)
So far this is the architecture I've though of:
- At the core is a session manager on top of a UDP socket. This session manager allows clients to start and reconnect to sessions with the server.
- Once in a session with the server, a user can send commands, and the server streams data to the user depending on the commands sent.
- Similar to IRC a server is split into 'channels', but unlike IRC these channels are confined to a specific server(so can be referenced by the server to be unique)
- Channels are multimodal and when choosing to join a channel clients select which modes they are joining.
With this two problems I can think of are:
1. Text channels retaining data and sending that retained data
2. Sending direct messages to users would be confined to servers, or complicated
To solve the text channel problem I propose that text channels in the discord sense are not the same as the text mode of a pnyx channel since pnyx channels are live-streaming data while discord text channels are more of a forum.
Is there a way to split these in such a way that either can be operated independently, or the application can be expanded in similar ways easily(modular)?
Right now it sounds like the server would act as a multiplexer for data subscribed by channel, and the forum feature doesn't relate to that.
Other than forum, another good analogy is a group chat or slack channel where the messages are asyncrhonous, but not quite as isolated as forum posts.
The problem with 'forum mode' being another mode of a channel is that it requires user commands specific to the mode(which others would to I guess like not seeing everyones video at once).
So if 'forum mode' was the default text chat mode, these would be the channel modes:
- forum
- audio
- video
modes can be implemented modularly, which means that channels will have to store different state objects based on the supported modes
The mode module would have to be responsible for the multiplexing of the packets based off of it's state, so there should be some function that takes in a packet and the channel state, and returns packets to send with the updated channel state
- It could also be responsible for setting up routing tables, so the function would be function(command, state, routing_table) -> (new_state, routing_table)
- This way the state doesn't have to be updated every message, just on commands and the path for data is minimalized
Modes would have to be completely independent for simplicity
So basically:
- A server is a collection of channels
- A channel is a collection of modes
- A mode is a routing table and state information, which is updated by commands
What about channel permissions and creating channels?
Can/Should it be implemented such that channels are to a server as modes are to a channel?
I don't think there's a reason to nest past modes, so there's no reason to make modes and channels the same.
Wait but if channels are a mode on the server, then you can nest channels by making channels with the channel mode, so maybe it is useful?
OK so for permissions it's going to be completely server-based, and in the config you can specify a public key to be a server admin for easy configuration.
How would commands work? I'm thinking it would be a struct like this
struct {
uint8 mode
uint8 command
[]byte data
}
So for example, to join the raw mode of the servers root channel would be something like
{
mode: 0x00 (MODE_RAW)
command: 0x00 (MODE_RAW_COMMAND_JOIN)
data: 0 length byte array
}
To send data to the raw mode of the servers root channel would be
{
mode: 0x00 (MODE_RAW)
command: 0x01 (MODE_RAW_COMMAND_DATA)
data: n length byte array
}
If channels are a tree(with the server being a root channel with only the 'channels' mode), then permissions would similarily be tree-based and defined on a per-mode basis.
e.x. for someone to have all permissions on all modes on the server they would get the 'wildcard'(*) permission. If someone was granted all permissions within the 'test' channel on the server then it would be something like:
`c/test/*`, broken down this is `c` for the 'channels' mode(the typical start for a command), `test` to specify the channel, and then wildcard
The downside of this is how commands would be targetted. For routing purpouses it could make sense to have it be layered. E.x. assuming 0x0C is the 'channel' mode and '0x00' is the 'raw' mode, a data packet could look something like this
{
mode: 0x0C (MODE_CHANNELS)
command: 0x01 (MODE_CHANNELS_DATA)
data: {
channel: 0xDE (sub-channel ID)
command: {
mode: 0x0C (MODE_CHANNELS)
command: 0x01 (MODE_CHANNELS_DATA)
data: {
channel: 0xAD
command: {
mode: 0x00 (MODE_RAW)
command: 0x01 (MODE_RAW_COMMAND_DATA)
data: n length byte array
}
}
}
}
}
It would be better if channels were not nested and instead can be referenced by a global ID(instead of a hierarchical ID), but the tree can still be maintained in memory for orginazation/permissions.
So instead of nested 0x0C commands it would be like this:
{
channel: 0xDEAD (server-unique channel identifier)
mode: 0x00 (MODE_RAW)
command: 0x01
data: ...
}
Would the commands to modify/create/delete these channels then be server-wide, or parsed by the channel?
If the commands are processed by the server state and update the server state then I'd store a map in memory of:
Yea I like that, but how do permissions work?
If I'm client ID X and send a command for channel Y(which hierarchically is A/Y) and I have all permissions on channel A, how is this looked up?
The command would target channel Y, so first lookup would be server.Channels[Y]
First check if the ClientID has the permission on the channel directly with that Permissions that have been looked up
If that returns access denied, go to the channels parent(if it's not the zero ID to signify the root channel) and check if the user has wildcard permissions on the parent
Continue that until the root node, returning access granted if the user has the direct(or wildcard) permissions on the channel itself, or wildcard permissions on any of it's parents
The downside of this is that the first check is likely expensive, while the other checks are cheap(checking if a user is
So right now the server state would be something like
type ChannelID uuid.UUID
type Channel struct{
Children []ChannelID,
Parent ChannelID
}
type Permission string
type Permissions struct {
This map[ClientID]map[ModeID][]Permission
Children map[ClientID]map[ChannelID]
}
type Server struct {
Channels map[ChannelID]Channel
Permissions map[ClientID]map[ChannelID]map[ModeID][]ComandID
}
Permissions with 1 admin user:
{
ADMIN_ID: nil
}
What the difference between that and:
{
ADMIN_ID: {
ZERO_ID: nil
}
}
How is "no commands on the root channel, but all commands on the root channels children" expressed?
{
}
Also this allows for the permission map to look like:
{
ADMIN_ID: {
ZERO_ID: nil
0x0F: {
MODE_RAW: {
JOIN,
SEND,
LEAVE,
}
}
}
}

@ -1,10 +1,14 @@
package pnyx package pnyx
import ( import (
"bytes"
"crypto/cipher"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha512" "crypto/sha512"
"encoding/binary" "encoding/binary"
"io"
"net"
"fmt" "fmt"
"slices" "slices"
@ -12,16 +16,17 @@ import (
"filippo.io/edwards25519" "filippo.io/edwards25519"
) )
type PacketType uint16 type PacketType uint8
const ( const (
ID_LENGTH = 16 ID_LENGTH = 16
PUBKEY_LENGTH = 32 PUBKEY_LENGTH = 32
ECDH_LENGTH = 32 ECDH_LENGTH = 32
SIGNATURE_LENGTH = 64 SIGNATURE_LENGTH = 64
HMAC_LENGTH = 64 HMAC_LENGTH = 64
COMMAND_LENGTH = 1
SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
SESSION_CONNECT_LENGTH = ID_LENGTH + 2 + HMAC_LENGTH // + return addr string length SESSION_CONNECT_LENGTH = 2 + HMAC_LENGTH // + return addr string length
SESSION_OPEN PacketType = iota SESSION_OPEN PacketType = iota
SESSION_CONNECT SESSION_CONNECT
@ -30,63 +35,22 @@ const (
SESSION_DATA SESSION_DATA
) )
func SessionKeyID(session_secret []byte) SessionID { func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
hash := sha512.Sum512(session_secret) public_point, err := (&edwards25519.Point{}).SetBytes(public)
return (SessionID)(hash[0:16]) if err != nil {
} return nil, err
func NewSessionConnect(address string, session_secret []byte) []byte {
packet := make([]byte, SESSION_CONNECT_LENGTH + len(address))
cur := 0
session_id := [16]byte(SessionKeyID(session_secret))
copy(packet[cur:], session_id[:])
cur += ID_LENGTH
binary.BigEndian.PutUint16(packet[cur:], uint16(len(address)))
cur += 2
copy(packet[cur:], []byte(address))
cur += len(address)
hmac := sha512.Sum512(append(packet[:cur], session_secret...))
copy(packet[cur:], hmac[:])
return packet
}
func ParseSessionConnect(session_connect []byte, session_secret []byte) (SessionID, string, error) {
if len(session_connect) < SESSION_CONNECT_LENGTH {
return SessionID{}, "", fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
} }
cur := 0
session_id := SessionID(session_connect[cur:cur+ID_LENGTH])
cur += ID_LENGTH
address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2]))
cur += 2
if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) { h := sha512.Sum512(private.Seed())
return SessionID{}, "", fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length) private_scalar, err := (&edwards25519.Scalar{}).SetBytesWithClamping(h[:32])
if err != nil {
return nil, err
} }
address := string(session_connect[cur:cur+address_length]) shared_point := public_point.ScalarMult(private_scalar, public_point)
cur += address_length
hmac_digest := make([]byte, cur)
copy(hmac_digest, session_connect[:cur])
hmac := session_connect[cur:cur+HMAC_LENGTH]
cur += HMAC_LENGTH
calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return SessionID{}, "", fmt.Errorf("Session connect bad HMAC")
}
return session_id, address, nil return shared_point.BytesMontgomery(), nil
} }
func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) { func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) {
@ -99,16 +63,19 @@ func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error)
return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err) return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err)
} }
packet := make([]byte, SESSION_OPEN_LENGTH) packet := make([]byte, COMMAND_LENGTH + SESSION_OPEN_LENGTH)
cur := 0 cur := 0
packet[0] = byte(SESSION_OPEN)
cur += COMMAND_LENGTH
copy(packet[cur:], []byte(key.Public().(ed25519.PublicKey))) copy(packet[cur:], []byte(key.Public().(ed25519.PublicKey)))
cur += PUBKEY_LENGTH cur += PUBKEY_LENGTH
copy(packet[cur:], []byte(public)) copy(packet[cur:], []byte(public))
cur += PUBKEY_LENGTH cur += PUBKEY_LENGTH
signature := ed25519.Sign(key, packet[:cur]) signature := ed25519.Sign(key, packet[COMMAND_LENGTH:cur])
copy(packet[cur:], signature) copy(packet[cur:], signature)
cur += SIGNATURE_LENGTH cur += SIGNATURE_LENGTH
@ -139,20 +106,97 @@ func ParseSessionOpen(session_open []byte) (ed25519.PublicKey, ed25519.PublicKey
return client_pubkey, client_ecdh, nil return client_pubkey, client_ecdh, nil
} }
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) { func NewSessionConnect(address *net.UDPAddr, session_secret []byte) []byte {
public_point, err := (&edwards25519.Point{}).SetBytes(public) packet := make([]byte, COMMAND_LENGTH + ID_LENGTH + SESSION_CONNECT_LENGTH + len(address.String()))
cur := 0
packet[cur] = byte(SESSION_CONNECT)
cur += COMMAND_LENGTH
session_id := [16]byte(ID[SessionID](session_secret))
copy(packet[cur:], session_id[:])
cur += ID_LENGTH
binary.BigEndian.PutUint16(packet[cur:], uint16(len(address.String())))
cur += 2
copy(packet[cur:], []byte(address.String()))
cur += len(address.String())
hmac := sha512.Sum512(append(packet[COMMAND_LENGTH+ID_LENGTH:cur], session_secret...))
copy(packet[cur:], hmac[:])
return packet
}
func ParseSessionConnect(session_connect []byte, session_secret []byte) (*net.UDPAddr, error) {
if len(session_connect) < SESSION_CONNECT_LENGTH {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
}
cur := 0
address_length := int(binary.BigEndian.Uint16(session_connect[cur:cur+2]))
cur += 2
if len(session_connect) != (SESSION_CONNECT_LENGTH + address_length) {
return nil, fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH + address_length)
}
address := string(session_connect[cur:cur+address_length])
cur += address_length
hmac_digest := make([]byte, cur)
copy(hmac_digest, session_connect[:cur])
hmac := session_connect[cur:cur+HMAC_LENGTH]
cur += HMAC_LENGTH
calculated_hmac := sha512.Sum512(append(hmac_digest, session_secret...))
if slices.Compare(hmac, calculated_hmac[:]) != 0 {
return nil, fmt.Errorf("Session connect bad HMAC")
}
addr, err := net.ResolveUDPAddr("udp", address)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("Error parsing return address: %w", err)
} }
return addr, nil
}
h := sha512.Sum512(private.Seed()) func NewSessionData(session *Session, packet []byte) ([]byte, error) {
private_scalar, err := (&edwards25519.Scalar{}).SetBytesWithClamping(h[:32]) iv := make([]byte, 32)
for i := 0; i < 4; i++ {
binary.BigEndian.PutUint64(iv[i*8:], session.iv_generator.Uint64())
}
stream := cipher.NewOFB(session.cipher, iv[:])
header := make([]byte, COMMAND_LENGTH + ID_LENGTH)
header[0] = byte(SESSION_DATA)
copy(header[1:], session.id[:])
packet_encrypted := bytes.NewBuffer(header)
writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}
_, err := io.Copy(writer, bytes.NewBuffer(packet))
if err != nil { if err != nil {
return nil, err return nil, err
} }
shared_point := public_point.ScalarMult(private_scalar, public_point) return packet_encrypted.Bytes(), nil
}
return shared_point.BytesMontgomery(), nil func ParseSessionData(session *Session, data []byte) ([]byte, error) {
iv := data[0:32]
stream := cipher.NewOFB(session.cipher, iv)
var packet_clear bytes.Buffer
reader := &cipher.StreamReader{S: stream, R: bytes.NewBuffer(data)}
_, err := io.Copy(&packet_clear, reader)
if err != nil {
return nil, err
}
return packet_clear.Bytes(), nil
} }

@ -3,6 +3,7 @@ package pnyx
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"net"
"slices" "slices"
"testing" "testing"
) )
@ -20,7 +21,7 @@ func TestSessionOpen(t *testing.T) {
session_open, client_ecdh, err := NewSessionOpen(client_key) session_open, client_ecdh, err := NewSessionOpen(client_key)
fatalErr(t, err) fatalErr(t, err)
client_pubkey_parsed, client_ecdh_parsed, err := ParseSessionOpen(session_open) client_pubkey_parsed, client_ecdh_parsed, err := ParseSessionOpen(session_open[COMMAND_LENGTH:])
fatalErr(t, err) fatalErr(t, err)
if slices.Compare(client_pubkey, client_pubkey_parsed) != 0 { if slices.Compare(client_pubkey, client_pubkey_parsed) != 0 {
@ -52,17 +53,14 @@ func TestECDH(t *testing.T) {
func TestSessionConnect(t *testing.T) { func TestSessionConnect(t *testing.T) {
secret := make([]byte, 32) secret := make([]byte, 32)
test_addr := "test_addr" test_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8080")
fatalErr(t, err)
session_connect := NewSessionConnect(test_addr, secret) session_connect := NewSessionConnect(test_addr, secret)
parsed_session_id, parsed_addr, err := ParseSessionConnect(session_connect, secret) parsed_addr, err := ParseSessionConnect(session_connect[COMMAND_LENGTH + ID_LENGTH:], secret)
fatalErr(t, err) fatalErr(t, err)
if parsed_addr != test_addr { if parsed_addr.String() != test_addr.String() {
t.Fatalf("Parsed address(%s) does not match test address(%s)", parsed_addr, test_addr) t.Fatalf("Parsed address(%s) does not match test address(%s)", parsed_addr, test_addr)
} }
if parsed_session_id != SessionKeyID(secret) {
t.Fatalf("Parsed session ID %s does not match original %s", parsed_session_id, SessionKeyID(secret))
}
} }

@ -1,9 +1,14 @@
package pnyx package pnyx
import ( import (
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
mrand "math/rand"
"net" "net"
"os" "os"
"sync/atomic" "sync/atomic"
@ -15,65 +20,86 @@ const (
SERVER_UDP_BUFFER_SIZE = 2048 SERVER_UDP_BUFFER_SIZE = 2048
) )
/*
Session Flow:
1. Client send a SESSION_OPEN packet to the server with first half for ECDH and it's public key
2. Server responds with an SESSION_AUTHENTICATE packet with the server's ECDH half and public key
aside - at this point if the client and server both hold the private parts of their public keys,
they both hold the same ECDH secret which they put in a KDF to generate a symmetric session key
aside - at this point the server creates the session in memory, but there is no return address associated
with it yet
3. Client sends a SESSION_CONNECT packet to the server with the session ID in cleartext and the
return address hashed with the key(to prove the return address has not been modified without the key)
4. Server adds the return address to the session info, and maps the address to the session for future packets
5. Server sends the HELLO packet to the client encrypted by the session key
If a client disconnects at any point and gets a new return address:
1. Client sends a SESSION_CONNECT packet to the server from the new socket
2. Server removes the old return address, and fills/maps the new return address
If a client wants to gracefully disconnect and notify the server to close the session:
1. Client sends a SESSION_CLOSE
2. Server responds with SESSION_CLOSED
Session Packets:
1. SESSION_OPEN
Payload is CLIENT_PUBKEY + ECDH_HALF + SIGNATURE
3. SESSION_CONNECT
4. SESSION_CLOSE
5. SESSION_CLOSED
*/
type SessionID uuid.UUID type SessionID uuid.UUID
type ClientID uuid.UUID func(id SessionID) String() string {
return uuid.UUID(id).String()
type Connection struct {
state string
session SessionID
} }
type Session struct { type Session struct {
state string id SessionID
client ClientID remote *net.UDPAddr
peer ClientID
secret []byte
cipher cipher.Block
iv_generator mrand.Source64
} }
type Server struct { type Server struct {
key ed25519.PrivateKey
active atomic.Bool active atomic.Bool
connection *net.UDPConn connection *net.UDPConn
stopped chan error stopped chan error
connections map[string]SessionID sessions map[SessionID]*Session
sessions map[SessionID]Session channels map[ChannelID]*Channel
} }
func NewServer() *Server { func NewServer(key ed25519.PrivateKey) (*Server, error) {
if key == nil {
var err error
_, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
}
server := &Server{ server := &Server{
key: key,
connection: nil, connection: nil,
active: atomic.Bool{}, active: atomic.Bool{},
stopped: make(chan error, 0), stopped: make(chan error, 0),
sessions: map[SessionID]*Session{},
channels: map[ChannelID]*Channel{},
} }
server.active.Store(false) server.active.Store(false)
return server return server, nil
}
// Check if the client has permission for the command on the channel
// If it's not specified, check the permission on the parent
func Allowed(server *Server, client ClientID, channel_id ChannelID, mode ModeID, command CommandID) bool {
channel, exists := server.channels[channel_id]
if exists == false {
return false
}
if channel.permissions != nil {
client_perms, exists := channel.permissions[client]
if exists {
if client_perms == nil {
return true
}
mode_perms, exists := client_perms[mode]
if exists {
if mode_perms == nil {
return true
}
allowed, exists := mode_perms[command]
if exists {
return allowed
}
}
}
}
// Prevent a cycle on the root channel
if channel_id == RootChannelID {
return false
} else {
return Allowed(server, client, channel.parent, mode, command)
}
} }
func (server *Server) Log(format string, fields ...interface{}) { func (server *Server) Log(format string, fields ...interface{}) {
@ -98,12 +124,86 @@ func(server *Server) run() {
var buf [SERVER_UDP_BUFFER_SIZE]byte var buf [SERVER_UDP_BUFFER_SIZE]byte
for true { for true {
read, addr, err := server.connection.ReadFromUDP(buf[:]) read, from, err := server.connection.ReadFromUDP(buf[:])
if err == nil { if err == nil {
var packet_type PacketType = PacketType(binary.BigEndian.Uint16(buf[0:2])) var packet_type PacketType = PacketType(buf[0])
switch packet_type { switch packet_type {
case SESSION_OPEN:
client_pubkey, ecdh_public, err := ParseSessionOpen(buf[COMMAND_LENGTH:read])
if err != nil {
server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
client_id := ID[ClientID](client_pubkey)
session_open, ecdh_private, err := NewSessionOpen(server.key)
if err != nil {
server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
_, err = server.connection.WriteToUDP(session_open, from)
if err != nil {
server.Log("WriteToUDP error %s", err)
continue
}
session_secret, err := ECDH(ecdh_public, ecdh_private)
if err != nil {
server.Log("ECDH error %s", err)
continue
}
session_id := ID[SessionID](session_secret)
session_cipher, err := aes.NewCipher(session_secret)
if err != nil {
server.Log("AES error %s", err)
continue
}
seed_bytes := make([]byte, 8)
read, err := rand.Read(seed_bytes)
if err != nil {
server.Log("IV Seed error: %s", err)
continue
} else if read != 8 {
server.Log("IV Seed error: not enough bytes read %d/4", read)
continue
}
session := &Session{
id: session_id,
remote: nil,
peer: client_id,
secret: session_secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
}
server.sessions[session_id] = session
server.Log("Started session %s with %s", session_id, client_id)
case SESSION_CONNECT:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't connect", session_id)
continue
}
client_addr, err := ParseSessionConnect(buf[COMMAND_LENGTH+ID_LENGTH:read], session.secret)
if err != nil {
server.Log("Error parsing session connect: %s", err)
continue
}
session.remote = client_addr
server.Log("Got SESSION_CONNECT for client %s at address %s", session.peer, session.remote)
// TODO: Send server hello back
default: default:
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, addr, buf[:read]) server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
} }
} else if errors.Is(err, net.ErrClosed) { } else if errors.Is(err, net.ErrClosed) {
server.Log("UDP_CLOSE: %s", server.connection.LocalAddr()) server.Log("UDP_CLOSE: %s", server.connection.LocalAddr())