Reorganized client code

live
noah metz 2024-04-06 17:03:31 -06:00
parent a438837c81
commit 118975f490
4 changed files with 56 additions and 47 deletions

@ -3,7 +3,10 @@ package pnyx
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
mrand "math/rand"
"encoding/binary"
"crypto/sha512" "crypto/sha512"
"crypto/aes"
"net" "net"
"sync" "sync"
"fmt" "fmt"
@ -26,9 +29,9 @@ const (
type Client struct { type Client struct {
Key ed25519.PrivateKey Key ed25519.PrivateKey
Session Session
ConnectionLock sync.Mutex ConnectionLock sync.Mutex
Connection *net.UDPConn Connection *net.UDPConn
State ClientState
} }
func ID[T ~[16]byte, V ~[]byte](data V) T { func ID[T ~[16]byte, V ~[]byte](data V) T {
@ -36,79 +39,90 @@ func ID[T ~[16]byte, V ~[]byte](data V) T {
return (T)(hash[0:16]) return (T)(hash[0:16])
} }
func NewClient(key ed25519.PrivateKey) (Client, error) { func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
if key == nil { if key == nil {
var err error var err error
_, key, err = ed25519.GenerateKey(rand.Reader) _, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil { if err != nil {
return Client{}, err return nil, err
} }
} }
return Client{ seed_bytes := make([]byte, 8)
Key: key, read, err := rand.Read(seed_bytes)
State: CLIENT_SESSION_CREATE, if err != nil {
}, nil return nil, err
} else if read != 8 {
return nil, fmt.Errorf("Failed to create IV seed for client")
} }
func(client *Client) Connect(remote string) (ed25519.PublicKey, []byte, error) {
client.ConnectionLock.Lock()
defer client.ConnectionLock.Unlock()
address, err := net.ResolveUDPAddr("udp", remote) address, err := net.ResolveUDPAddr("udp", remote)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
client.Connection, err = net.DialUDP("udp", nil, address) connection, err := net.DialUDP("udp", nil, address)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
session_open, ecdh_private, err := NewSessionOpen(client.Key) session_open, ecdh_private, err := NewSessionOpen(key)
if err != nil { if err != nil {
client.Connection.Close() connection.Close()
client.Connection = nil return nil, err
return nil, nil, err
} }
_, err = client.Connection.Write(session_open) _, err = connection.Write(session_open)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
var response = [512]byte{} var response = [512]byte{}
read, _, err := client.Connection.ReadFromUDP(response[:]) read, _, err = connection.ReadFromUDP(response[:])
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
if response[0] != byte(SESSION_OPEN) { if response[0] != byte(SESSION_OPEN) {
return nil, nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0])
} }
server_public, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read]) server_pubkey, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read])
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
secret, err := ECDH(ecdh_public, ecdh_private) secret, err := ECDH(ecdh_public, ecdh_private)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
client.State = CLIENT_SESSION_CONNECT session_cipher, err := aes.NewCipher(secret)
if err != nil {
return nil, err
}
session_connect := NewSessionConnect(client.Connection.LocalAddr().(*net.UDPAddr), secret) session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), secret)
_, err = client.Connection.Write(session_connect) _, err = connection.Write(session_connect)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
read, _, err = client.Connection.ReadFromUDP(response[:]) read, _, err = connection.ReadFromUDP(response[:])
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return server_public, secret, nil return &Client{
Key: key,
Session: Session{
ID: ID[SessionID](secret),
remote: address,
Peer: ID[ClientID](server_pubkey),
secret: secret,
cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
},
Connection: connection,
}, nil
} }

@ -8,15 +8,10 @@ import (
) )
func main() { func main() {
client, err := pnyx.NewClient(nil) client, err := pnyx.NewClient(nil, os.Args[1])
if err != nil { if err != nil {
panic(err) panic(err)
} }
server_public, secret, err := client.Connect(os.Args[1]) fmt.Printf("Started session %s with %s", client.Session.ID, client.Session.Peer)
if err != nil {
panic(err)
}
fmt.Printf("Started session %s with %s", pnyx.ID[pnyx.SessionID](secret), pnyx.ID[pnyx.ClientID](server_public))
} }

@ -174,7 +174,7 @@ func NewSessionData(session *Session, packet []byte) ([]byte, error) {
stream := cipher.NewOFB(session.cipher, iv[:]) stream := cipher.NewOFB(session.cipher, iv[:])
header := make([]byte, COMMAND_LENGTH + ID_LENGTH) header := make([]byte, COMMAND_LENGTH + ID_LENGTH)
header[0] = byte(SESSION_DATA) header[0] = byte(SESSION_DATA)
copy(header[1:], session.id[:]) copy(header[1:], session.ID[:])
packet_encrypted := bytes.NewBuffer(header) packet_encrypted := bytes.NewBuffer(header)
writer := &cipher.StreamWriter{S: stream, W: packet_encrypted} writer := &cipher.StreamWriter{S: stream, W: packet_encrypted}

@ -26,9 +26,9 @@ func(id SessionID) String() string {
} }
type Session struct { type Session struct {
id SessionID ID SessionID
remote *net.UDPAddr remote *net.UDPAddr
peer ClientID Peer ClientID
secret []byte secret []byte
cipher cipher.Block cipher cipher.Block
iv_generator mrand.Source64 iv_generator mrand.Source64
@ -174,9 +174,9 @@ func(server *Server) run() {
} }
session := &Session{ session := &Session{
id: session_id, ID: session_id,
remote: nil, remote: nil,
peer: client_id, Peer: client_id,
secret: session_secret, secret: session_secret,
cipher: session_cipher, cipher: session_cipher,
iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64),
@ -199,7 +199,7 @@ func(server *Server) run() {
} }
session.remote = client_addr session.remote = client_addr
server.Log("Got SESSION_CONNECT for client %s at address %s", session.peer, session.remote) server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote)
// TODO: Send server hello back // TODO: Send server hello back
default: default: