Changes to client and channel

live
noah metz 2024-04-18 20:10:01 -06:00
parent caeed001c3
commit 6b3bf3eda2
5 changed files with 177 additions and 70 deletions

@ -19,6 +19,7 @@ const (
AUDIO_GET_SAMPLE_RATE AUDIO_GET_SAMPLE_RATE
CHANNEL_COMMAND_BUFFER_SIZE int = 2048 CHANNEL_COMMAND_BUFFER_SIZE int = 2048
CHANNEL_CLOSE_BUFFER_SIZE int = 100
) )
type ModeID uint8 type ModeID uint8
@ -27,49 +28,76 @@ type Permission string
type Channel struct { type Channel struct {
Commands chan SessionChannelCommand Commands chan SessionChannelCommand
ClosedSessions chan *ServerSession
Modes map[ModeID]*atomic.Value Modes map[ModeID]*atomic.Value
Members atomic.Value Members atomic.Value
} }
func(channel *Channel) update_state() { func(channel *Channel) update_state() {
for true { for true {
incoming := <-channel.Commands select {
if incoming.Packet == nil { case session := <-channel.ClosedSessions:
break members := channel.Members.Load().([]*ServerSession)
} idx := slices.Index(members, session)
if idx != -1 {
new_members := make([]*ServerSession, len(members) - 1)
copy(new_members, members[:idx])
copy(new_members[idx:], members[idx+1:])
channel.Members.Store(new_members)
for _, mode_val := range(channel.Modes) {
mode := mode_val.Load().(Mode)
mode_val.Store(mode.Leave(session))
}
}
case incoming := <-channel.Commands:
if incoming.Packet == nil {
break
} else if incoming.Session.active.Load() == false {
continue
}
command := incoming.Packet command := incoming.Packet
if command.Mode == MODE_CHANNEL { if command.Mode == MODE_CHANNEL {
switch command.Command { switch command.Command {
case CHANNEL_COMMAND_JOIN: case CHANNEL_COMMAND_JOIN:
members := channel.Members.Load().([]*ServerSession) members := channel.Members.Load().([]*ServerSession)
if slices.Contains(members, incoming.Session) == false { if slices.Contains(members, incoming.Session) == false {
new_members := make([]*ServerSession, len(members) + 1) new_members := make([]*ServerSession, len(members) + 1)
copy(new_members, members) copy(new_members, members)
new_members[len(members)] = incoming.Session new_members[len(members)] = incoming.Session
channel.Members.Store(new_members) channel.Members.Store(new_members)
fmt.Printf("New Members: %+v\n", channel.Members.Load())
for _, mode_val := range(channel.Modes) {
mode := mode_val.Load().(Mode)
mode_val.Store(mode.Join(incoming.Session))
}
}
case CHANNEL_COMMAND_LEAVE:
members := channel.Members.Load().([]*ServerSession)
idx := slices.Index(members, incoming.Session)
if idx != -1 {
new_members := make([]*ServerSession, len(members) - 1)
copy(new_members, members[:idx])
copy(new_members[idx:], members[idx+1:])
channel.Members.Store(new_members)
for _, mode_val := range(channel.Modes) {
mode := mode_val.Load().(Mode)
mode_val.Store(mode.Leave(incoming.Session))
}
}
} }
case CHANNEL_COMMAND_LEAVE: } else {
members := channel.Members.Load().([]*ServerSession) mode, has_mode := channel.Modes[command.Mode]
idx := slices.Index(members, incoming.Session) if has_mode {
if idx != -1 { members := channel.Members.Load().([]*ServerSession)
new_members := make([]*ServerSession, len(members) - 1) mode_val := mode.Load().(Mode)
copy(new_members, members[:idx]) new_mode := mode_val.Command(incoming.Session, command.Command, command.ReqID, command.Channel, members, command.Data)
copy(new_members[idx:], members[idx+1:]) mode.Store(new_mode)
channel.Members.Store(new_members)
fmt.Printf("New Members: %+v\n", channel.Members.Load())
} }
} }
} else {
mode, has_mode := channel.Modes[command.Mode]
if has_mode {
members := channel.Members.Load().([]*ServerSession)
mode_val := mode.Load().(Mode)
new_mode := mode_val.Command(incoming.Session, command.Command, command.ReqID, command.Channel, members, command.Data)
mode.CompareAndSwap(mode_val, new_mode)
}
} }
} }
} }
@ -86,6 +114,7 @@ func NewChannel(modes map[ModeID]Mode) (*Channel, error) {
channel := &Channel{ channel := &Channel{
Commands: make(chan SessionChannelCommand, CHANNEL_COMMAND_BUFFER_SIZE), Commands: make(chan SessionChannelCommand, CHANNEL_COMMAND_BUFFER_SIZE),
ClosedSessions: make(chan *ServerSession, CHANNEL_CLOSE_BUFFER_SIZE),
Modes: initial_modes, Modes: initial_modes,
} }
channel.Members.Store([]*ServerSession{}) channel.Members.Store([]*ServerSession{})
@ -96,6 +125,8 @@ func NewChannel(modes map[ModeID]Mode) (*Channel, error) {
} }
type Mode interface { type Mode interface {
Join(session *ServerSession) Mode
Leave(session *ServerSession) Mode
Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode
Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte)
} }
@ -119,6 +150,14 @@ func multiplex(packet *Packet, sessions []*ServerSession) {
type RawMode struct { type RawMode struct {
} }
func(mode RawMode) Join(session *ServerSession) Mode {
return mode
}
func(mode RawMode) Leave(session *ServerSession) Mode {
return mode
}
func(mode RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode { func(mode RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode {
return mode return mode
} }
@ -139,6 +178,14 @@ type AudioMode struct {
SampleRate SampleRate SampleRate SampleRate
} }
func(mode AudioMode) Join(session *ServerSession) Mode {
return mode
}
func(mode AudioMode) Leave(session *ServerSession) Mode {
return mode
}
func(mode AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode { func(mode AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode {
switch command { switch command {
case AUDIO_SET_SAMPLE_RATE: case AUDIO_SET_SAMPLE_RATE:

@ -26,11 +26,7 @@ type Client struct {
func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) {
if key == nil { if key == nil {
var err error return nil, fmt.Errorf("Need a key to create a client, passed nil")
_, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
} }
seed_bytes := make([]byte, 8) seed_bytes := make([]byte, 8)

@ -1,15 +1,25 @@
package main package main
import ( import (
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"encoding/binary" "encoding/binary"
"encoding/pem"
"flag"
"fmt" "fmt"
"os" "os"
"os/signal"
"slices"
"sync/atomic"
"syscall"
"time" "time"
"git.metznet.ca/MetzNet/pnyx" "git.metznet.ca/MetzNet/pnyx"
"github.com/gen2brain/malgo" "github.com/gen2brain/malgo"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hraban/opus" "github.com/hraban/opus"
"seehuhn.de/go/ncurses"
) )
var decoders = map[pnyx.PeerID]chan[]byte{} var decoders = map[pnyx.PeerID]chan[]byte{}
@ -20,7 +30,6 @@ var speaker = make(chan []int16, 0)
var audio_data = make(chan []int16, 0) var audio_data = make(chan []int16, 0)
func set_sample_rate(new_sample_rate int) error { func set_sample_rate(new_sample_rate int) error {
fmt.Printf("Setting sample rate to %d\n", new_sample_rate)
sample_rate = new_sample_rate sample_rate = new_sample_rate
var err error var err error
@ -100,6 +109,10 @@ func mixer(data_chan chan []int16, speaker_chan chan []int16) {
func main() { func main() {
key_file_arg := flag.String("key", "${HOME}/.pnyx.key", "Path to the private key file to load/save")
generate_key_arg := flag.Bool("genkey", false, "Set to generate a key if none exists")
flag.Parse()
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil)
if err != nil { if err != nil {
panic(err) panic(err)
@ -231,13 +244,53 @@ func main() {
defer inDevice.Uninit() defer inDevice.Uninit()
defer inDevice.Stop() defer inDevice.Stop()
var key ed25519.PrivateKey = nil
key_file_path := os.ExpandEnv(*key_file_arg)
key_file_bytes, err := os.ReadFile(key_file_path)
if err == nil {
key_pem, _ := pem.Decode(key_file_bytes)
if key_pem.Type != "PRIVATE KEY" {
panic("Key file has wrong PEM format")
}
private_key, err := x509.ParsePKCS8PrivateKey(key_pem.Bytes)
if err != nil {
panic(err)
}
var ok bool
key, ok = private_key.(ed25519.PrivateKey)
if ok == false {
panic("Private key is not ed25519.PrivateKey")
}
} else if *generate_key_arg {
_, key, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
key_pkcs8, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
panic(err)
}
key_pem := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: key_pkcs8,
})
err = os.WriteFile(key_file_path, key_pem, 0o600)
if err != nil {
panic(err)
}
}
client, err := pnyx.NewClient(nil, os.Args[1]) client, err := pnyx.NewClient(key, flag.Arg(0))
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func() { go func() {
var buf [1024]byte var buf [1024]byte
for true { for true {
@ -258,9 +311,7 @@ func main() {
switch packet := packet.(type) { switch packet := packet.(type) {
case pnyx.PingPacket: case pnyx.PingPacket:
fmt.Printf("Ping Packet From Server: %+v\n", packet)
case pnyx.ChannelCommandPacket: case pnyx.ChannelCommandPacket:
fmt.Printf("Channel Command packet: %+v\n", packet)
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
if packet.Mode == pnyx.MODE_AUDIO { if packet.Mode == pnyx.MODE_AUDIO {
if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE { if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE {
@ -281,7 +332,6 @@ func main() {
} }
} }
case pnyx.CommandPacket: case pnyx.CommandPacket:
fmt.Printf("Command packet: %+v\n", packet)
case pnyx.PeerPacket: case pnyx.PeerPacket:
if packet.Channel == pnyx.ChannelID(0) { if packet.Channel == pnyx.ChannelID(0) {
decode_chan, exists := decoders[packet.Peer] decode_chan, exists := decoders[packet.Peer]
@ -315,34 +365,46 @@ func main() {
panic(err) panic(err)
} }
for true { go func(){
data := <- mic for true {
err = client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) data := <- mic
if err != nil { err = client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data))
panic(err) if err != nil {
panic(err)
}
} }
} }()
/*window := ncurses.Init() window := ncurses.Init()
defer ncurses.EndWin() active := atomic.Bool{}
active.Store(true)
ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed) go func() {
window.AddStr("pnyx client") ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed)
window.AddStr("pnyx client")
for active.Load() {
window.Refresh()
time.Sleep(200*time.Millisecond)
peers := make([]pnyx.PeerID, 0, len(decoders))
for peer_id := range(decoders) {
peers = append(peers, peer_id)
}
for true { slices.SortFunc(peers, func(a, b pnyx.PeerID) int {
window.Refresh() return slices.Compare(a[:], b[:])
time.Sleep(200*time.Millisecond) })
peers := make([]pnyx.PeerID, 0, len(decoders))
for peer_id := range(decoders) { for i, peer_id := range(peers) {
peers = append(peers, peer_id) window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id))
}
} }
}()
slices.SortFunc(peers, func(a, b pnyx.PeerID) int { os_sigs := make(chan os.Signal, 1)
return slices.Compare(a[:], b[:]) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT)
})
for i, peer_id := range(peers) { <-os_sigs
window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id)) active.Store(false)
} ncurses.EndWin()
}*/
} }

@ -11,7 +11,7 @@ import (
func main() { func main() {
os_sigs := make(chan os.Signal, 1) os_sigs := make(chan os.Signal, 1)
signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT) signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT)
channel_0, err := pnyx.NewChannel(map[pnyx.ModeID]pnyx.Mode{ channel_0, err := pnyx.NewChannel(map[pnyx.ModeID]pnyx.Mode{
pnyx.MODE_RAW: pnyx.RawMode{}, pnyx.MODE_RAW: pnyx.RawMode{},

@ -143,11 +143,13 @@ func handle_session_incoming(session *ServerSession, server *Server) {
for session.active.Load() { for session.active.Load() {
select { select {
case <- ping_timer: case <- ping_timer:
if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != 1 { if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != -1 {
server.Log("Closing %s after being inactive since %s", session.ID, session.LastSeen)
server.sessions_lock.Lock() server.sessions_lock.Lock()
server.close_session(session) server.close_session(session)
server.sessions_lock.Unlock() server.sessions_lock.Unlock()
} else { } else {
server.Log("%s passed keep-alive check, last seen %s", session.ID, session.LastSeen)
session.OutgoingPackets <- NewPingPacket() session.OutgoingPackets <- NewPingPacket()
ping_timer = time.After(SESSION_PING_TIME) ping_timer = time.After(SESSION_PING_TIME)
} }