package pnyx

import (
	"crypto/ed25519"
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"encoding/binary"
	"fmt"
	"slices"

	"filippo.io/edwards25519"
)

// PacketType enumerates the session packet kinds declared below.
type PacketType uint16

// Field and packet sizes, in bytes, for the session packets built and parsed
// in this file.
const (
	ID_LENGTH        = 16
	PUBKEY_LENGTH    = 32
	ECDH_LENGTH      = 32
	SIGNATURE_LENGTH = 64
	HMAC_LENGTH      = 64 // size of a SHA-512 digest, used as the SESSION_CONNECT authenticator

	// SESSION_OPEN_LENGTH is the exact size of a SESSION_OPEN packet.
	SESSION_OPEN_LENGTH = PUBKEY_LENGTH + ECDH_LENGTH + SIGNATURE_LENGTH
	// SESSION_CONNECT_LENGTH is the fixed part of a SESSION_CONNECT packet
	// (ID, 2-byte address length, MAC); the address bytes are added on top.
	SESSION_CONNECT_LENGTH = ID_LENGTH + 2 + HMAC_LENGTH
)

// Packet type identifiers.
//
// iota counts ConstSpecs from the start of their const block. When these
// shared a block with the seven length constants above, iota was already 7 at
// SESSION_OPEN, so the values in use are 7..11, not 0..4. The "+ 7" keeps
// those exact values explicit.
// NOTE(review): confirm 7..11 is intended; renumbering would change the wire
// values and break interoperability with peers using the current ones.
const (
	SESSION_OPEN PacketType = iota + 7
	SESSION_CONNECT
	SESSION_CLOSE
	SESSION_CLOSED
	SESSION_DATA
)

// SessionKeyID derives a session ID from a session secret: the first
// ID_LENGTH bytes of SHA-512(session_secret).
func SessionKeyID(session_secret []byte) SessionID {
	hash := sha512.Sum512(session_secret)
	return SessionID(hash[:ID_LENGTH])
}

// sessionConnectMAC returns the 64-byte SESSION_CONNECT authenticator
// SHA-512(body || secret). It hashes incrementally so neither the caller's
// buffer nor a temporary copy has the secret appended to it.
//
// NOTE(review): this is a secret-suffix hash, not RFC 2104 HMAC, despite the
// HMAC_* naming. Switching to crypto/hmac would change the wire format.
func sessionConnectMAC(body, secret []byte) []byte {
	h := sha512.New()
	h.Write(body) // hash.Hash.Write never returns an error
	h.Write(secret)
	return h.Sum(nil)
}

// NewSessionConnect builds a SESSION_CONNECT packet carrying address and
// authenticated with session_secret.
//
// Wire layout (n = len(address)):
//
//	[0:16)       session ID = SessionKeyID(session_secret)
//	[16:18)      n, big-endian uint16
//	[18:18+n)    address bytes
//	[18+n:82+n)  SHA-512(bytes[0:18+n] || session_secret)
//
// The length field is uint16. An address longer than 65535 bytes is encoded
// with a truncated length, and ParseSessionConnect rejects that packet.
func NewSessionConnect(address string, session_secret []byte) []byte {
	packet := make([]byte, SESSION_CONNECT_LENGTH+len(address))
	cur := 0

	id := [ID_LENGTH]byte(SessionKeyID(session_secret))
	copy(packet[cur:], id[:])
	cur += ID_LENGTH

	binary.BigEndian.PutUint16(packet[cur:], uint16(len(address)))
	cur += 2

	copy(packet[cur:], address)
	cur += len(address)

	copy(packet[cur:], sessionConnectMAC(packet[:cur], session_secret))
	return packet
}

// ParseSessionConnect validates a SESSION_CONNECT packet (layout as in
// NewSessionConnect) against session_secret and returns its session ID and
// address.
//
// It returns an error if the packet is shorter than SESSION_CONNECT_LENGTH or
// its total length disagrees with the embedded address length. It also
// returns an error if the trailing MAC does not match. The returned ID is
// read from the packet. It is covered by the MAC but is not compared against
// SessionKeyID(session_secret).
func ParseSessionConnect(session_connect []byte, session_secret []byte) (SessionID, string, error) {
	if len(session_connect) < SESSION_CONNECT_LENGTH {
		return SessionID{}, "", fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), SESSION_CONNECT_LENGTH)
	}
	cur := 0

	id := SessionID(session_connect[cur : cur+ID_LENGTH])
	cur += ID_LENGTH

	addrLen := int(binary.BigEndian.Uint16(session_connect[cur : cur+2]))
	cur += 2

	if want := SESSION_CONNECT_LENGTH + addrLen; len(session_connect) != want {
		return SessionID{}, "", fmt.Errorf("Bad session connect length: %d/%d", len(session_connect), want)
	}

	address := string(session_connect[cur : cur+addrLen])
	cur += addrLen

	expected := sessionConnectMAC(session_connect[:cur], session_secret)
	// The MAC comes from an untrusted peer. Compare in constant time so
	// response timing does not reveal how many leading bytes were right.
	if subtle.ConstantTimeCompare(session_connect[cur:cur+HMAC_LENGTH], expected) != 1 {
		return SessionID{}, "", fmt.Errorf("Session connect bad HMAC")
	}
	return id, address, nil
}

// NewSessionOpen builds a SESSION_OPEN packet signed with the identity key
// and returns it along with the freshly generated ephemeral private key. The
// caller needs that key for ECDH.
//
// Wire layout:
//
//	[0:32)    identity public key (key.Public())
//	[32:64)   ephemeral ECDH public key (an Ed25519 public key)
//	[64:128)  Ed25519 signature by key over bytes [0:64)
//
// It returns an error if key is nil or is not ed25519.PrivateKeySize bytes
// long, or if generating the ephemeral key fails.
func NewSessionOpen(key ed25519.PrivateKey) ([]byte, ed25519.PrivateKey, error) {
	if key == nil {
		return nil, nil, fmt.Errorf("Cannot create a SESSION_OPEN packet without a key")
	}
	// A malformed key would otherwise panic in key.Public() or ed25519.Sign.
	if len(key) != ed25519.PrivateKeySize {
		return nil, nil, fmt.Errorf("Bad SESSION_OPEN key length: %d/%d", len(key), ed25519.PrivateKeySize)
	}

	public, private, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to generate ecdh key: %w", err)
	}

	packet := make([]byte, SESSION_OPEN_LENGTH)
	cur := 0

	copy(packet[cur:], key.Public().(ed25519.PublicKey))
	cur += PUBKEY_LENGTH

	copy(packet[cur:], public)
	cur += ECDH_LENGTH

	copy(packet[cur:], ed25519.Sign(key, packet[:cur]))
	return packet, private, nil
}

// ParseSessionOpen checks a SESSION_OPEN packet (layout as in NewSessionOpen)
// and returns the sender's identity public key and ephemeral ECDH public key.
//
// The signature is verified with the identity key carried in the packet
// itself. This proves the sender holds the matching private key; deciding
// whether that key is trusted is left to the caller. The returned keys are
// copies, so later reuse of session_open by the caller does not alter them.
func ParseSessionOpen(session_open []byte) (ed25519.PublicKey, ed25519.PublicKey, error) {
	if len(session_open) != SESSION_OPEN_LENGTH {
		return nil, nil, fmt.Errorf("Bad SESSION_OPEN length: %d/%d", len(session_open), SESSION_OPEN_LENGTH)
	}
	cur := 0

	client_pubkey := ed25519.PublicKey(session_open[cur : cur+PUBKEY_LENGTH])
	cur += PUBKEY_LENGTH

	client_ecdh := ed25519.PublicKey(session_open[cur : cur+ECDH_LENGTH])
	cur += ECDH_LENGTH

	signed := session_open[:cur]
	signature := session_open[cur : cur+SIGNATURE_LENGTH]
	if !ed25519.Verify(client_pubkey, signed, signature) {
		return nil, nil, fmt.Errorf("SESSION_OPEN signature verification failed")
	}
	return slices.Clone(client_pubkey), slices.Clone(client_ecdh), nil
}

// ECDH computes a Diffie-Hellman shared value between a peer's Ed25519 public
// key and a local Ed25519 private key.
//
// The private scalar is derived from the seed the way Ed25519 derives its
// signing scalar: clamp SHA-512(seed)[:32]. The result is the Montgomery
// u-coordinate of scalar*public, encoded as 32 bytes.
//
// It returns an error if private is not ed25519.PrivateKeySize bytes long or
// public is not a valid point encoding. It also returns an error if the
// result is all zeros, which happens when public is a low-order point; this
// mirrors the X25519 check in RFC 7748 section 6.1.
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
	// private.Seed() panics on a short key; report it as an error instead.
	if len(private) != ed25519.PrivateKeySize {
		return nil, fmt.Errorf("Bad ECDH private key length: %d/%d", len(private), ed25519.PrivateKeySize)
	}

	public_point, err := new(edwards25519.Point).SetBytes(public)
	if err != nil {
		return nil, fmt.Errorf("Bad ECDH public key: %w", err)
	}

	h := sha512.Sum512(private.Seed())
	private_scalar, err := new(edwards25519.Scalar).SetBytesWithClamping(h[:32])
	if err != nil {
		return nil, err
	}

	shared := new(edwards25519.Point).ScalarMult(private_scalar, public_point).BytesMontgomery()

	var zero [32]byte
	if subtle.ConstantTimeCompare(shared, zero[:]) == 1 {
		return nil, fmt.Errorf("ECDH produced an all-zero shared secret")
	}
	return shared, nil
}