diff --git a/client.go b/client.go index 37c068d..a81355e 100644 --- a/client.go +++ b/client.go @@ -3,7 +3,10 @@ package pnyx import ( "crypto/ed25519" "crypto/rand" + mrand "math/rand" + "encoding/binary" "crypto/sha512" + "crypto/aes" "net" "sync" "fmt" @@ -26,9 +29,9 @@ const ( type Client struct { Key ed25519.PrivateKey + Session Session ConnectionLock sync.Mutex Connection *net.UDPConn - State ClientState } func ID[T ~[16]byte, V ~[]byte](data V) T { @@ -36,79 +39,90 @@ func ID[T ~[16]byte, V ~[]byte](data V) T { return (T)(hash[0:16]) } -func NewClient(key ed25519.PrivateKey) (Client, error) { +func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { if key == nil { var err error _, key, err = ed25519.GenerateKey(rand.Reader) if err != nil { - return Client{}, err + return nil, err } } - return Client{ - Key: key, - State: CLIENT_SESSION_CREATE, - }, nil -} + seed_bytes := make([]byte, 8) + read, err := rand.Read(seed_bytes) + if err != nil { + return nil, err + } else if read != 8 { + return nil, fmt.Errorf("Failed to create IV seed for client") + } -func(client *Client) Connect(remote string) (ed25519.PublicKey, []byte, error) { - client.ConnectionLock.Lock() - defer client.ConnectionLock.Unlock() - address, err := net.ResolveUDPAddr("udp", remote) if err != nil { - return nil, nil, err + return nil, err } - client.Connection, err = net.DialUDP("udp", nil, address) + connection, err := net.DialUDP("udp", nil, address) if err != nil { - return nil, nil, err + return nil, err } - session_open, ecdh_private, err := NewSessionOpen(client.Key) + session_open, ecdh_private, err := NewSessionOpen(key) if err != nil { - client.Connection.Close() - client.Connection = nil - return nil, nil, err + connection.Close() + return nil, err } - _, err = client.Connection.Write(session_open) + _, err = connection.Write(session_open) if err != nil { - return nil, nil, err + return nil, err } var response = [512]byte{} - read, _, err := client.Connection.ReadFromUDP(response[:]) + read, _, err = connection.ReadFromUDP(response[:]) if err != nil { - return nil, nil, err + return nil, err } if response[0] != byte(SESSION_OPEN) { - return nil, nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) + return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) } - server_public, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read]) + server_pubkey, ecdh_public, err := ParseSessionOpen(response[COMMAND_LENGTH:read]) if err != nil { - return nil, nil, err + return nil, err } secret, err := ECDH(ecdh_public, ecdh_private) if err != nil { - return nil, nil, err + return nil, err } - client.State = CLIENT_SESSION_CONNECT + session_cipher, err := aes.NewCipher(secret) + if err != nil { + return nil, err + } - session_connect := NewSessionConnect(client.Connection.LocalAddr().(*net.UDPAddr), secret) - _, err = client.Connection.Write(session_connect) + session_connect := NewSessionConnect(connection.LocalAddr().(*net.UDPAddr), secret) + _, err = connection.Write(session_connect) if err != nil { - return nil, nil, err + return nil, err } - read, _, err = client.Connection.ReadFromUDP(response[:]) + read, _, err = connection.ReadFromUDP(response[:]) if err != nil { - return nil, nil, err + return nil, err } - return server_public, secret, nil + return &Client{ + Key: key, + Session: Session{ + ID: ID[SessionID](secret), + remote: address, + Peer: ID[ClientID](server_pubkey), + secret: secret, + cipher: session_cipher, + iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), + }, + Connection: connection, + }, nil } diff --git a/cmd/client/main.go b/cmd/client/main.go index bfad157..8225285 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -8,15 +8,10 @@ import ( ) func main() { - client, err := pnyx.NewClient(nil) + client, err := pnyx.NewClient(nil, os.Args[1]) if err != nil { panic(err) } - server_public, secret, err := client.Connect(os.Args[1]) - if err != nil { - panic(err) - } - - fmt.Printf("Started session %s with %s", pnyx.ID[pnyx.SessionID](secret), pnyx.ID[pnyx.ClientID](server_public)) + fmt.Printf("Started session %s with %s", client.Session.ID, client.Session.Peer) } diff --git a/packet.go b/packet.go index a83a61c..4c86883 100644 --- a/packet.go +++ b/packet.go @@ -174,7 +174,7 @@ func NewSessionData(session *Session, packet []byte) ([]byte, error) { stream := cipher.NewOFB(session.cipher, iv[:]) header := make([]byte, COMMAND_LENGTH + ID_LENGTH) header[0] = byte(SESSION_DATA) - copy(header[1:], session.id[:]) + copy(header[1:], session.ID[:]) packet_encrypted := bytes.NewBuffer(header) writer := &cipher.StreamWriter{S: stream, W: packet_encrypted} diff --git a/server.go b/server.go index c15dd80..fe80de5 100644 --- a/server.go +++ b/server.go @@ -26,9 +26,9 @@ func(id SessionID) String() string { } type Session struct { - id SessionID + ID SessionID remote *net.UDPAddr - peer ClientID + Peer ClientID secret []byte cipher cipher.Block iv_generator mrand.Source64 @@ -174,9 +174,9 @@ func(server *Server) run() { } session := &Session{ - id: session_id, + ID: session_id, remote: nil, - peer: client_id, + Peer: client_id, secret: session_secret, cipher: session_cipher, iv_generator: mrand.NewSource(int64(binary.BigEndian.Uint64(seed_bytes))).(mrand.Source64), @@ -199,7 +199,7 @@ func(server *Server) run() { } session.remote = client_addr - server.Log("Got SESSION_CONNECT for client %s at address %s", session.peer, session.remote) + server.Log("Got SESSION_CONNECT for client %s at address %s", session.Peer, session.remote) // TODO: Send server hello back default: