Moved global variables in client example to local

live
noah metz 2024-04-23 16:38:08 -06:00
parent f7ef36ba55
commit 73312a37cd
3 changed files with 208 additions and 116 deletions

@ -3,30 +3,121 @@ package pnyx
import ( import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"fmt"
"net" "net"
"sync" "sync/atomic"
"time" "time"
"fmt"
) )
type ClientState uint8 type ClientState uint8
const ( const (
CLIENT_SESSION_CREATE ClientState = iota CLIENT_SESSION_CONNECTING ClientState = iota
CLIENT_SESSION_CONNECT
CLIENT_SESSION_CONNECTED CLIENT_SESSION_CONNECTED
) )
type Client struct { type Client struct {
Key ed25519.PrivateKey Key ed25519.PrivateKey
Session Session Session Session
ConnectionLock sync.Mutex remote string
Connection *net.UDPConn
connection *net.UDPConn
to_send chan []byte
data_fn func(Payload)error
state ClientState
active atomic.Bool
} }
func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { func(client *Client) Remote() string {
return client.remote
}
func(client *Client) Active() bool {
return client.active.Load()
}
const CLIENT_UDP_BUFFER = 2048
const CLIENT_MAX_CONNECT_ATTEMPTS = 5
func(client *Client) Log(format string, vals ...any) {
fmt.Printf("%s\n", fmt.Sprintf(format, vals...))
}
func(client *Client) listen_udp() {
buffer := [CLIENT_UDP_BUFFER]byte{}
for client.active.Load() {
read, err := client.connection.Read(buffer[:])
if err != nil {
client.Log("Client listen error - %s", err)
} else if read == 0 {
client.Log("Client listen error - no data in packet")
} else {
packet_type := SessionPacketType(buffer[0])
switch client.state {
case CLIENT_SESSION_CONNECTING:
switch packet_type {
case SESSION_CONNECTED:
client.state = CLIENT_SESSION_CONNECTED
case SESSION_CLOSED:
client.Log("Server repsonded to session connect with SESSION_CLOSED")
client.active.Store(false)
default:
client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTING", packet_type)
}
case CLIENT_SESSION_CONNECTED:
switch packet_type {
case SESSION_DATA:
if len(buffer) < SESSION_ID_LENGTH + 1 {
client.Log("Not enough data to decode SESSION_DATA packet %d/%d", len(buffer), SESSION_ID_LENGTH + 1)
continue
}
session_id := SessionID(buffer[1:1+SESSION_ID_LENGTH])
if session_id != client.Session.ID {
client.Log("Session ID of data packet does not match client.Session %s =/= %s", session_id, client.Session.ID)
continue
}
data, err := ParseSessionData(&client.Session, buffer[1+SESSION_ID_LENGTH:read])
if err != nil {
client.Log("Error parsing session data: %s", err)
continue
}
payload, err := ParsePacket(data)
if err != nil {
client.Log("Error parsing packet from session data: %s", err)
continue
}
switch payload.(type) {
case PingPacket:
err = client.Send(NewPingPacket())
if err != nil {
client.Log("Error sending ping packet: %s", err)
}
}
err = client.data_fn(payload)
if err != nil {
client.Log("Error running data_fn: %s", err)
}
default:
client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTED", packet_type)
}
default:
client.Log("Received packet while in bad client state 0x%02x", client.state)
}
}
}
}
func NewClient(key ed25519.PrivateKey, remote string, data_fn func(Payload)error) (*Client, error) {
if key == nil { if key == nil {
return nil, fmt.Errorf("Need a key to create a client, passed nil") return nil, fmt.Errorf("Need a key to create a client, passed nil")
} else if data_fn == nil {
return nil, fmt.Errorf("Need a function to run with session data")
} }
seed_bytes := make([]byte, 8) seed_bytes := make([]byte, 8)
@ -53,12 +144,15 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
return nil, err return nil, err
} }
var response = [512]byte{}
var session Session
for attempts := 0; attempts <= CLIENT_MAX_CONNECT_ATTEMPTS; attempts++ {
_, err = connection.Write(session_open) _, err = connection.Write(session_open)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var response = [512]byte{} // TODO: handle timeout
read, _, err = connection.ReadFromUDP(response[:]) read, _, err = connection.ReadFromUDP(response[:])
if err != nil { if err != nil {
return nil, err return nil, err
@ -68,37 +162,62 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0])
} }
session, err := ParseSessionOpened(nil, ecdh_private, response[COMMAND_LENGTH:read]) session, err = ParseSessionOpened(nil, ecdh_private, response[COMMAND_LENGTH:read])
if err != nil { if err == nil {
return nil, err break
} }
session_connect := NewSessionTimed(SESSION_CONNECT, key, &session, time.Now()) net_error, ok := err.(net.Error)
_, err = connection.Write(session_connect) if ok == false {
if err != nil { return nil, err
} else if net_error.Timeout() {
if attempts == CLIENT_MAX_CONNECT_ATTEMPTS {
return nil, fmt.Errorf("Failed to connect to server at %s", remote)
}
} else {
return nil, err return nil, err
} }
}
read, _, err = connection.ReadFromUDP(response[:]) session_connect := NewSessionTimed(SESSION_CONNECT, key, &session, time.Now())
_, err = connection.Write(session_connect)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Client{ client := &Client{
Key: key, Key: key,
Session: session, Session: session,
Connection: connection, remote: remote,
}, nil
}
func(client *Client) Send(packet *Packet) error { connection: connection,
client.ConnectionLock.Lock() to_send: make(chan []byte, 1000),
defer client.ConnectionLock.Unlock() data_fn: data_fn,
state: CLIENT_SESSION_CONNECTING,
}
client.active.Store(true)
go client.listen_udp()
go client.send_queue()
if client.Connection == nil { return client, nil
return fmt.Errorf("Client is not connected") }
func(client *Client) send_queue() {
for client.active.Load() {
packet := <- client.to_send
if packet == nil {
break
}
_, err := client.connection.Write(packet)
if err != nil {
client.Log("Send error: %s", err)
} }
}
}
func(client *Client) Send(packet *Packet) error {
data, err := packet.MarshalBinary() data, err := packet.MarshalBinary()
if err != nil { if err != nil {
return err return err
@ -109,21 +228,20 @@ func(client *Client) Send(packet *Packet) error {
return err return err
} }
_, err = client.Connection.Write(wrapped) select {
return err case client.to_send <- wrapped:
return nil
default:
return fmt.Errorf("Channel overflow")
}
} }
func(client *Client) Close() error { func(client *Client) Close() error {
client.ConnectionLock.Lock() var err error = fmt.Errorf("Client not active")
defer client.ConnectionLock.Unlock() if client.active.CompareAndSwap(true, false) {
close(client.to_send)
if client.Connection == nil { client.connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now()))
return fmt.Errorf("No connection to close") err = client.connection.Close()
} }
client.Connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now()))
err := client.Connection.Close()
client.Connection = nil
return err return err
} }

@ -10,25 +10,21 @@ import (
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"sync/atomic"
"syscall" "syscall"
"time" "time"
"git.metznet.ca/MetzNet/go-ncurses"
"git.metznet.ca/MetzNet/pnyx" "git.metznet.ca/MetzNet/pnyx"
"github.com/gen2brain/malgo" "github.com/gen2brain/malgo"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hraban/opus" "github.com/hraban/opus"
"git.metznet.ca/MetzNet/go-ncurses"
) )
var decoders = map[pnyx.PeerID]chan[]byte{} var decoders = map[pnyx.PeerID]chan[]byte{}
var encoder *opus.Encoder var encoder *opus.Encoder
var sample_rate int = 0 var sample_rate int = 0
var mic = make(chan []byte, 0)
var speaker = make(chan []int16, 0)
var audio_data = make(chan []int16, 0)
func set_sample_rate(new_sample_rate int) error { func set_sample_rate(audio_data chan []int16, new_sample_rate int) error {
sample_rate = new_sample_rate sample_rate = new_sample_rate
var err error var err error
@ -43,12 +39,12 @@ func set_sample_rate(new_sample_rate int) error {
} }
new_chan := make(chan[]byte, 1000) new_chan := make(chan[]byte, 1000)
decoders[peer_id] = new_chan decoders[peer_id] = new_chan
go handle_peer_decode(peer_id, decoders[peer_id], sample_rate) go handle_peer_decode(audio_data, peer_id, decoders[peer_id], sample_rate)
} }
return nil return nil
} }
func handle_peer_decode(peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){ func handle_peer_decode(audio_data chan []int16, peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){
decoder, err := opus.NewDecoder(sample_rate, 1) decoder, err := opus.NewDecoder(sample_rate, 1)
if err != nil { if err != nil {
panic(err) panic(err)
@ -106,14 +102,12 @@ func mixer(data_chan chan []int16, speaker_chan chan []int16) {
} }
} }
func setup_audio() (*malgo.AllocatedContext, *malgo.Device, *malgo.Device) { func setup_audio(mic chan []byte, speaker chan []int16) (*malgo.AllocatedContext, *malgo.Device, *malgo.Device) {
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
go mixer(audio_data, speaker)
infos, err := ctx.Devices(malgo.Playback) infos, err := ctx.Devices(malgo.Playback)
if err != nil { if err != nil {
panic(err) panic(err)
@ -277,25 +271,21 @@ func get_private_key(path string, generate bool) ed25519.PrivateKey {
} }
} }
func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool, packet_chan chan pnyx.Payload, user_chan chan rune) { func main_loop(client *pnyx.Client, audio_data chan []int16, window ncurses.Window, packet_chan chan pnyx.Payload, user_chan chan rune) {
max_y := ncurses.GetMaxY(window) max_y := ncurses.GetMaxY(window)
max_x := ncurses.GetMaxX(window) max_x := ncurses.GetMaxX(window)
titlebar := ncurses.NewWin(1, max_x, 0, 0) titlebar := ncurses.NewWin(1, max_x, 0, 0)
channels := ncurses.NewWin(max_y - 1, max_x / 3, 1, 0) channels := ncurses.NewWin(max_y - 1, max_x / 3, 1, 0)
body := ncurses.NewWin(max_y - 1, max_x * 2 / 3, 1, max_x / 3) body := ncurses.NewWin(max_y - 1, max_x * 2 / 3, 1, max_x / 3)
server_name := client.Connection.RemoteAddr().String()
ncurses.MvWAddStr(titlebar, 0, 0, fmt.Sprintf("pnyx client %X:%X", client.Key.Public().(ed25519.PublicKey)[:2], client.Session.ID[:2])) ncurses.MvWAddStr(titlebar, 0, 0, fmt.Sprintf("pnyx client %X:%X", client.Key.Public().(ed25519.PublicKey)[:2], client.Session.ID[:2]))
ncurses.MvWAddStr(body, 0, max_x-len(server_name), server_name) ncurses.MvWAddStr(body, 0, max_x-len(client.Remote()), client.Remote())
ncurses.WRefresh(titlebar) ncurses.WRefresh(titlebar)
for active.Load() { for client.Active() {
select { select {
case packet := <-packet_chan: case packet := <-packet_chan:
switch packet := packet.(type) { switch packet := packet.(type) {
case pnyx.PingPacket:
_ = client.Send(pnyx.NewPingPacket())
case pnyx.ChannelCommandPacket: case pnyx.ChannelCommandPacket:
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
if packet.Mode == pnyx.MODE_AUDIO { if packet.Mode == pnyx.MODE_AUDIO {
@ -309,14 +299,14 @@ func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool,
default: default:
continue continue
} }
err := set_sample_rate(new_sample_rate)
err := set_sample_rate(audio_data, new_sample_rate)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
} }
} }
case pnyx.CommandPacket:
case pnyx.PeerPacket: case pnyx.PeerPacket:
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
decode_chan, exists := decoders[packet.Peer] decode_chan, exists := decoders[packet.Peer]
@ -324,7 +314,7 @@ func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool,
if sample_rate != 0 { if sample_rate != 0 {
decode_chan = make(chan[]byte, 1000) decode_chan = make(chan[]byte, 1000)
decoders[packet.Peer] = decode_chan decoders[packet.Peer] = decode_chan
go handle_peer_decode(packet.Peer, decoders[packet.Peer], sample_rate) go handle_peer_decode(audio_data, packet.Peer, decoders[packet.Peer], sample_rate)
decode_chan <- packet.Data decode_chan <- packet.Data
} else { } else {
decoders[packet.Peer] = nil decoders[packet.Peer] = nil
@ -333,7 +323,6 @@ func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool,
decode_chan <- packet.Data decode_chan <- packet.Data
} }
} }
default:
} }
case char := <-user_chan: case char := <-user_chan:
ncurses.MvWAddStr(body, 0, 0, string(char)) ncurses.MvWAddStr(body, 0, 0, string(char))
@ -350,9 +339,9 @@ func bitmatch(b byte, pattern byte, length int) bool {
return (b ^ pattern) & mask == 0 return (b ^ pattern) & mask == 0
} }
func ch_listen(active *atomic.Bool, user_chan chan rune) { func ch_listen(client *pnyx.Client, user_chan chan rune) {
b := [4]byte{} b := [4]byte{}
for active.Load() { for client.Active() {
os.Stdin.Read(b[0:1]) os.Stdin.Read(b[0:1])
if bitmatch(b[0], 0b00000000, 1) { if bitmatch(b[0], 0b00000000, 1) {
user_chan <- rune(b[0]) user_chan <- rune(b[0])
@ -360,29 +349,7 @@ func ch_listen(active *atomic.Bool, user_chan chan rune) {
} }
} }
func udp_listen(client *pnyx.Client, active *atomic.Bool, packet_chan chan pnyx.Payload) { func process_mic(client *pnyx.Client, mic chan []byte) {
var buf [1024]byte
for active.Load() {
read, _, err := client.Connection.ReadFromUDP(buf[:])
if err != nil {
break
}
data, err := pnyx.ParseSessionData(&client.Session, buf[pnyx.COMMAND_LENGTH + pnyx.SESSION_ID_LENGTH:read])
if err != nil {
continue
}
packet, err := pnyx.ParsePacket(data)
if err != nil {
continue
}
packet_chan <- packet
}
}
func process_mic(client *pnyx.Client) {
for true { for true {
data := <- mic data := <- mic
err := client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) err := client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data))
@ -397,7 +364,13 @@ func main() {
generate_key_arg := flag.Bool("genkey", false, "Set to generate a key if none exists") generate_key_arg := flag.Bool("genkey", false, "Set to generate a key if none exists")
flag.Parse() flag.Parse()
ctx, outDevice, inDevice := setup_audio() var audio_data = make(chan []int16, 0)
var speaker = make(chan []int16, 0)
go mixer(audio_data, speaker)
var mic = make(chan []byte, 0)
ctx, outDevice, inDevice := setup_audio(mic, speaker)
defer ctx.Free() defer ctx.Free()
defer ctx.Uninit() defer ctx.Uninit()
@ -406,21 +379,23 @@ func main() {
defer inDevice.Uninit() defer inDevice.Uninit()
defer inDevice.Stop() defer inDevice.Stop()
key := get_private_key(os.ExpandEnv(*key_file_arg), *generate_key_arg)
client, err := pnyx.NewClient(key, flag.Arg(0)) user_chan := make(chan rune, 1024)
packet_chan := make(chan pnyx.Payload, 1024)
key := get_private_key(os.ExpandEnv(*key_file_arg), *generate_key_arg)
client, err := pnyx.NewClient(key, flag.Arg(0), func(payload pnyx.Payload) error {
select {
case packet_chan <- payload:
return nil
default:
return fmt.Errorf("Channel overflow")
}
})
if err != nil { if err != nil {
panic(err) panic(err)
} }
packet_chan := make(chan pnyx.Payload, 1024)
user_chan := make(chan rune, 1024)
active := atomic.Bool{}
active.Store(true)
go udp_listen(client, &active, packet_chan)
join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_COMMAND_JOIN, nil) join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_COMMAND_JOIN, nil)
err = client.Send(join_packet) err = client.Send(join_packet)
if err != nil { if err != nil {
@ -433,17 +408,16 @@ func main() {
panic(err) panic(err)
} }
go process_mic(client) go process_mic(client, mic)
go ch_listen(client, user_chan)
window := ncurses.InitScr() window := ncurses.InitScr()
go main_loop(client, audio_data, window, packet_chan, user_chan)
go ch_listen(&active, user_chan)
go main_loop(client, window, &active, packet_chan, user_chan)
os_sigs := make(chan os.Signal, 1) os_sigs := make(chan os.Signal, 1)
signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT)
<-os_sigs <-os_sigs
active.Store(false) client.Close()
ncurses.EndWin() ncurses.EndWin()
} }