diff --git a/client.go b/client.go index b520a74..faa0e6a 100644 --- a/client.go +++ b/client.go @@ -3,30 +3,121 @@ package pnyx import ( "crypto/ed25519" "crypto/rand" + "fmt" "net" - "sync" - "time" - "fmt" + "sync/atomic" + "time" ) type ClientState uint8 const ( - CLIENT_SESSION_CREATE ClientState = iota - CLIENT_SESSION_CONNECT + CLIENT_SESSION_CONNECTING ClientState = iota CLIENT_SESSION_CONNECTED ) type Client struct { Key ed25519.PrivateKey Session Session - ConnectionLock sync.Mutex - Connection *net.UDPConn + remote string + + connection *net.UDPConn + + to_send chan []byte + data_fn func(Payload)error + state ClientState + active atomic.Bool +} + +func(client *Client) Remote() string { + return client.remote +} + +func(client *Client) Active() bool { + return client.active.Load() } -func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { +const CLIENT_UDP_BUFFER = 2048 +const CLIENT_MAX_CONNECT_ATTEMPTS = 5 + +func(client *Client) Log(format string, vals ...any) { + fmt.Printf("%s\n", fmt.Sprintf(format, vals...)) +} + +func(client *Client) listen_udp() { + buffer := [CLIENT_UDP_BUFFER]byte{} + for client.active.Load() { + read, err := client.connection.Read(buffer[:]) + if err != nil { + client.Log("Client listen error - %s", err) + } else if read == 0 { + client.Log("Client listen error - no data in packet") + } else { + packet_type := SessionPacketType(buffer[0]) + switch client.state { + case CLIENT_SESSION_CONNECTING: + switch packet_type { + case SESSION_CONNECTED: + client.state = CLIENT_SESSION_CONNECTED + case SESSION_CLOSED: + client.Log("Server repsonded to session connect with SESSION_CLOSED") + client.active.Store(false) + default: + client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTING", packet_type) + } + case CLIENT_SESSION_CONNECTED: + switch packet_type { + case SESSION_DATA: + if len(buffer) < SESSION_ID_LENGTH + 1 { + client.Log("Not enough data to decode SESSION_DATA packet %d/%d", len(buffer), SESSION_ID_LENGTH + 1) + continue + } + + session_id := SessionID(buffer[1:1+SESSION_ID_LENGTH]) + if session_id != client.Session.ID { + client.Log("Session ID of data packet does not match client.Session %s =/= %s", session_id, client.Session.ID) + continue + } + + data, err := ParseSessionData(&client.Session, buffer[1+SESSION_ID_LENGTH:read]) + if err != nil { + client.Log("Error parsing session data: %s", err) + continue + } + + payload, err := ParsePacket(data) + if err != nil { + client.Log("Error parsing packet from session data: %s", err) + continue + } + + switch payload.(type) { + case PingPacket: + err = client.Send(NewPingPacket()) + if err != nil { + client.Log("Error sending ping packet: %s", err) + } + } + + err = client.data_fn(payload) + if err != nil { + client.Log("Error running data_fn: %s", err) + } + default: + client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTED", packet_type) + } + default: + client.Log("Received packet while in bad client state 0x%02x", client.state) + } + } + } +} + +func NewClient(key ed25519.PrivateKey, remote string, data_fn func(Payload)error) (*Client, error) { if key == nil { return nil, fmt.Errorf("Need a key to create a client, passed nil") + } else if data_fn == nil { + return nil, fmt.Errorf("Need a function to run with session data") } seed_bytes := make([]byte, 8) @@ -53,24 +144,39 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { return nil, err } - _, err = connection.Write(session_open) - if err != nil { - return nil, err - } - var response = [512]byte{} - read, _, err = connection.ReadFromUDP(response[:]) - if err != nil { - return nil, err - } - - if response[0] != byte(SESSION_OPENED) { - return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) - } - - session, err := ParseSessionOpened(nil, ecdh_private, response[COMMAND_LENGTH:read]) - if err != nil { - return nil, err + var session Session + for attempts := 0; attempts <= CLIENT_MAX_CONNECT_ATTEMPTS; attempts++ { + _, err = connection.Write(session_open) + if err != nil { + return nil, err + } + + // TODO: handle timeout + read, _, err = connection.ReadFromUDP(response[:]) + if err != nil { + return nil, err + } + + if response[0] != byte(SESSION_OPENED) { + return nil, fmt.Errorf("Invalid SESSION_OPEN response: %x", response[0]) + } + + session, err = ParseSessionOpened(nil, ecdh_private, response[COMMAND_LENGTH:read]) + if err == nil { + break + } + + net_error, ok := err.(net.Error) + if ok == false { + return nil, err + } else if net_error.Timeout() { + if attempts == CLIENT_MAX_CONNECT_ATTEMPTS { + return nil, fmt.Errorf("Failed to connect to server at %s", remote) + } + } else { + return nil, err + } } session_connect := NewSessionTimed(SESSION_CONNECT, key, &session, time.Now()) @@ -78,27 +184,40 @@ func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { if err != nil { return nil, err } - - read, _, err = connection.ReadFromUDP(response[:]) - if err != nil { - return nil, err - } - - return &Client{ + + client := &Client{ Key: key, Session: session, - Connection: connection, - }, nil -} + remote: remote, -func(client *Client) Send(packet *Packet) error { - client.ConnectionLock.Lock() - defer client.ConnectionLock.Unlock() + connection: connection, + to_send: make(chan []byte, 1000), + data_fn: data_fn, + state: CLIENT_SESSION_CONNECTING, + } + + client.active.Store(true) - if client.Connection == nil { - return fmt.Errorf("Client is not connected") + go client.listen_udp() + go client.send_queue() + + return client, nil +} + +func(client *Client) send_queue() { + for client.active.Load() { + packet := <- client.to_send + if packet == nil { + break + } + _, err := client.connection.Write(packet) + if err != nil { + client.Log("Send error: %s", err) + } } +} +func(client *Client) Send(packet *Packet) error { data, err := packet.MarshalBinary() if err != nil { return err @@ -109,21 +228,20 @@ func(client *Client) Send(packet *Packet) error { return err } - _, err = client.Connection.Write(wrapped) - return err + select { + case client.to_send <- wrapped: + return nil + default: + return fmt.Errorf("Channel overflow") + } } func(client *Client) Close() error { - client.ConnectionLock.Lock() - defer client.ConnectionLock.Unlock() - - if client.Connection == nil { - return fmt.Errorf("No connection to close") + var err error = fmt.Errorf("Client not active") + if client.active.CompareAndSwap(true, false) { + close(client.to_send) + client.connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now())) + err = client.connection.Close() } - - client.Connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now())) - err := client.Connection.Close() - client.Connection = nil - return err } diff --git a/cmd/client/main.go b/cmd/client/main.go index f265d7d..70d5ec4 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -10,25 +10,21 @@ import ( "fmt" "os" "os/signal" - "sync/atomic" "syscall" "time" + "git.metznet.ca/MetzNet/go-ncurses" "git.metznet.ca/MetzNet/pnyx" "github.com/gen2brain/malgo" "github.com/google/uuid" "github.com/hraban/opus" - "git.metznet.ca/MetzNet/go-ncurses" ) var decoders = map[pnyx.PeerID]chan[]byte{} var encoder *opus.Encoder var sample_rate int = 0 -var mic = make(chan []byte, 0) -var speaker = make(chan []int16, 0) -var audio_data = make(chan []int16, 0) -func set_sample_rate(new_sample_rate int) error { +func set_sample_rate(audio_data chan []int16, new_sample_rate int) error { sample_rate = new_sample_rate var err error @@ -43,12 +39,12 @@ func set_sample_rate(new_sample_rate int) error { } new_chan := make(chan[]byte, 1000) decoders[peer_id] = new_chan - go handle_peer_decode(peer_id, decoders[peer_id], sample_rate) + go handle_peer_decode(audio_data, peer_id, decoders[peer_id], sample_rate) } return nil } -func handle_peer_decode(peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){ +func handle_peer_decode(audio_data chan []int16, peer_id pnyx.PeerID, decode_chan chan[]byte, sample_rate int){ decoder, err := opus.NewDecoder(sample_rate, 1) if err != nil { panic(err) @@ -106,14 +102,12 @@ func mixer(data_chan chan []int16, speaker_chan chan []int16) { } } -func setup_audio() (*malgo.AllocatedContext, *malgo.Device, *malgo.Device) { +func setup_audio(mic chan []byte, speaker chan []int16) (*malgo.AllocatedContext, *malgo.Device, *malgo.Device) { ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) if err != nil { panic(err) } - go mixer(audio_data, speaker) - infos, err := ctx.Devices(malgo.Playback) if err != nil { panic(err) @@ -277,25 +271,21 @@ func get_private_key(path string, generate bool) ed25519.PrivateKey { } } -func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool, packet_chan chan pnyx.Payload, user_chan chan rune) { +func main_loop(client *pnyx.Client, audio_data chan []int16, window ncurses.Window, packet_chan chan pnyx.Payload, user_chan chan rune) { max_y := ncurses.GetMaxY(window) max_x := ncurses.GetMaxX(window) titlebar := ncurses.NewWin(1, max_x, 0, 0) channels := ncurses.NewWin(max_y - 1, max_x / 3, 1, 0) body := ncurses.NewWin(max_y - 1, max_x * 2 / 3, 1, max_x / 3) - server_name := client.Connection.RemoteAddr().String() - ncurses.MvWAddStr(titlebar, 0, 0, fmt.Sprintf("pnyx client %X:%X", client.Key.Public().(ed25519.PublicKey)[:2], client.Session.ID[:2])) - ncurses.MvWAddStr(body, 0, max_x-len(server_name), server_name) + ncurses.MvWAddStr(body, 0, max_x-len(client.Remote()), client.Remote()) ncurses.WRefresh(titlebar) - for active.Load() { + for client.Active() { select { case packet := <-packet_chan: switch packet := packet.(type) { - case pnyx.PingPacket: - _ = client.Send(pnyx.NewPingPacket()) case pnyx.ChannelCommandPacket: if packet.Channel == pnyx.ChannelID(0) { if packet.Mode == pnyx.MODE_AUDIO { @@ -309,14 +299,14 @@ func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool, default: continue } - err := set_sample_rate(new_sample_rate) + + err := set_sample_rate(audio_data, new_sample_rate) if err != nil { panic(err) } } } } - case pnyx.CommandPacket: case pnyx.PeerPacket: if packet.Channel == pnyx.ChannelID(0) { decode_chan, exists := decoders[packet.Peer] @@ -324,7 +314,7 @@ func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool, if sample_rate != 0 { decode_chan = make(chan[]byte, 1000) decoders[packet.Peer] = decode_chan - go handle_peer_decode(packet.Peer, decoders[packet.Peer], sample_rate) + go handle_peer_decode(audio_data, packet.Peer, decoders[packet.Peer], sample_rate) decode_chan <- packet.Data } else { decoders[packet.Peer] = nil @@ -333,7 +323,6 @@ func main_loop(client *pnyx.Client, window ncurses.Window, active *atomic.Bool, decode_chan <- packet.Data } } - default: } case char := <-user_chan: ncurses.MvWAddStr(body, 0, 0, string(char)) @@ -350,9 +339,9 @@ func bitmatch(b byte, pattern byte, length int) bool { return (b ^ pattern) & mask == 0 } -func ch_listen(active *atomic.Bool, user_chan chan rune) { +func ch_listen(client *pnyx.Client, user_chan chan rune) { b := [4]byte{} - for active.Load() { + for client.Active() { os.Stdin.Read(b[0:1]) if bitmatch(b[0], 0b00000000, 1) { user_chan <- rune(b[0]) @@ -360,29 +349,7 @@ func ch_listen(active *atomic.Bool, user_chan chan rune) { } } -func udp_listen(client *pnyx.Client, active *atomic.Bool, packet_chan chan pnyx.Payload) { - var buf [1024]byte - for active.Load() { - read, _, err := client.Connection.ReadFromUDP(buf[:]) - if err != nil { - break - } - - data, err := pnyx.ParseSessionData(&client.Session, buf[pnyx.COMMAND_LENGTH + pnyx.SESSION_ID_LENGTH:read]) - if err != nil { - continue - } - - packet, err := pnyx.ParsePacket(data) - if err != nil { - continue - } - - packet_chan <- packet - } -} - -func process_mic(client *pnyx.Client) { +func process_mic(client *pnyx.Client, mic chan []byte) { for true { data := <- mic err := client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) @@ -397,7 +364,13 @@ func main() { generate_key_arg := flag.Bool("genkey", false, "Set to generate a key if none exists") flag.Parse() - ctx, outDevice, inDevice := setup_audio() + var audio_data = make(chan []int16, 0) + var speaker = make(chan []int16, 0) + + go mixer(audio_data, speaker) + + var mic = make(chan []byte, 0) + ctx, outDevice, inDevice := setup_audio(mic, speaker) defer ctx.Free() defer ctx.Uninit() @@ -406,21 +379,23 @@ func main() { defer inDevice.Uninit() defer inDevice.Stop() - key := get_private_key(os.ExpandEnv(*key_file_arg), *generate_key_arg) - client, err := pnyx.NewClient(key, flag.Arg(0)) + user_chan := make(chan rune, 1024) + packet_chan := make(chan pnyx.Payload, 1024) + + key := get_private_key(os.ExpandEnv(*key_file_arg), *generate_key_arg) + client, err := pnyx.NewClient(key, flag.Arg(0), func(payload pnyx.Payload) error { + select { + case packet_chan <- payload: + return nil + default: + return fmt.Errorf("Channel overflow") + } + }) if err != nil { panic(err) } - packet_chan := make(chan pnyx.Payload, 1024) - user_chan := make(chan rune, 1024) - - active := atomic.Bool{} - active.Store(true) - - go udp_listen(client, &active, packet_chan) - join_packet := pnyx.NewChannelCommandPacket(uuid.New(), pnyx.ChannelID(0), pnyx.MODE_CHANNEL, pnyx.CHANNEL_COMMAND_JOIN, nil) err = client.Send(join_packet) if err != nil { @@ -433,17 +408,16 @@ func main() { panic(err) } - go process_mic(client) + go process_mic(client, mic) + go ch_listen(client, user_chan) window := ncurses.InitScr() - - go ch_listen(&active, user_chan) - go main_loop(client, window, &active, packet_chan, user_chan) + go main_loop(client, audio_data, window, packet_chan, user_chan) os_sigs := make(chan os.Signal, 1) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT) <-os_sigs - active.Store(false) + client.Close() ncurses.EndWin() } diff --git a/server.go b/server.go index 4c4eae8..8a0e037 100644 --- a/server.go +++ b/server.go @@ -166,7 +166,7 @@ func handle_session_incoming(session *ServerSession, server *Server) { server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) continue } - + packet, err := ParsePacket(data) if err != nil { server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)