diff --git a/channel.go b/channel.go index 6767187..062c153 100644 --- a/channel.go +++ b/channel.go @@ -19,6 +19,7 @@ const ( AUDIO_GET_SAMPLE_RATE CHANNEL_COMMAND_BUFFER_SIZE int = 2048 + CHANNEL_CLOSE_BUFFER_SIZE int = 100 ) type ModeID uint8 @@ -27,49 +28,76 @@ type Permission string type Channel struct { Commands chan SessionChannelCommand + ClosedSessions chan *ServerSession Modes map[ModeID]*atomic.Value Members atomic.Value } func(channel *Channel) update_state() { for true { - incoming := <-channel.Commands - if incoming.Packet == nil { - break - } + select { + case session := <-channel.ClosedSessions: + members := channel.Members.Load().([]*ServerSession) + idx := slices.Index(members, session) + if idx != -1 { + new_members := make([]*ServerSession, len(members) - 1) + copy(new_members, members[:idx]) + copy(new_members[idx:], members[idx+1:]) + channel.Members.Store(new_members) + + for _, mode_val := range(channel.Modes) { + mode := mode_val.Load().(Mode) + mode_val.Store(mode.Leave(session)) + } + } + case incoming := <-channel.Commands: + if incoming.Packet == nil { + break + } else if incoming.Session.active.Load() == false { + continue + } - command := incoming.Packet - - if command.Mode == MODE_CHANNEL { - switch command.Command { - case CHANNEL_COMMAND_JOIN: - members := channel.Members.Load().([]*ServerSession) - if slices.Contains(members, incoming.Session) == false { - new_members := make([]*ServerSession, len(members) + 1) - copy(new_members, members) - new_members[len(members)] = incoming.Session - channel.Members.Store(new_members) - fmt.Printf("New Members: %+v\n", channel.Members.Load()) + command := incoming.Packet + + if command.Mode == MODE_CHANNEL { + switch command.Command { + case CHANNEL_COMMAND_JOIN: + members := channel.Members.Load().([]*ServerSession) + if slices.Contains(members, incoming.Session) == false { + new_members := make([]*ServerSession, len(members) + 1) + copy(new_members, members) + new_members[len(members)] = incoming.Session + channel.Members.Store(new_members) + + for _, mode_val := range(channel.Modes) { + mode := mode_val.Load().(Mode) + mode_val.Store(mode.Join(incoming.Session)) + } + } + case CHANNEL_COMMAND_LEAVE: + members := channel.Members.Load().([]*ServerSession) + idx := slices.Index(members, incoming.Session) + if idx != -1 { + new_members := make([]*ServerSession, len(members) - 1) + copy(new_members, members[:idx]) + copy(new_members[idx:], members[idx+1:]) + channel.Members.Store(new_members) + + for _, mode_val := range(channel.Modes) { + mode := mode_val.Load().(Mode) + mode_val.Store(mode.Leave(incoming.Session)) + } + } } - case CHANNEL_COMMAND_LEAVE: - members := channel.Members.Load().([]*ServerSession) - idx := slices.Index(members, incoming.Session) - if idx != -1 { - new_members := make([]*ServerSession, len(members) - 1) - copy(new_members, members[:idx]) - copy(new_members[idx:], members[idx+1:]) - channel.Members.Store(new_members) - fmt.Printf("New Members: %+v\n", channel.Members.Load()) + } else { + mode, has_mode := channel.Modes[command.Mode] + if has_mode { + members := channel.Members.Load().([]*ServerSession) + mode_val := mode.Load().(Mode) + new_mode := mode_val.Command(incoming.Session, command.Command, command.ReqID, command.Channel, members, command.Data) + mode.Store(new_mode) } } - } else { - mode, has_mode := channel.Modes[command.Mode] - if has_mode { - members := channel.Members.Load().([]*ServerSession) - mode_val := mode.Load().(Mode) - new_mode := mode_val.Command(incoming.Session, command.Command, command.ReqID, command.Channel, members, command.Data) - mode.CompareAndSwap(mode_val, new_mode) - } } } } @@ -86,6 +114,7 @@ func NewChannel(modes map[ModeID]Mode) (*Channel, error) { channel := &Channel{ Commands: make(chan SessionChannelCommand, CHANNEL_COMMAND_BUFFER_SIZE), + ClosedSessions: make(chan *ServerSession, CHANNEL_CLOSE_BUFFER_SIZE), Modes: initial_modes, } channel.Members.Store([]*ServerSession{}) @@ -96,6 +125,8 @@ func NewChannel(modes map[ModeID]Mode) (*Channel, error) { } type Mode interface { + Join(session *ServerSession) Mode + Leave(session *ServerSession) Mode Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode Data(session *ServerSession, channel_id ChannelID, members []*ServerSession, data []byte) } @@ -119,6 +150,14 @@ func multiplex(packet *Packet, sessions []*ServerSession) { type RawMode struct { } +func(mode RawMode) Join(session *ServerSession) Mode { + return mode +} + +func(mode RawMode) Leave(session *ServerSession) Mode { + return mode +} + func(mode RawMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode { return mode } @@ -139,6 +178,14 @@ type AudioMode struct { SampleRate SampleRate } +func(mode AudioMode) Join(session *ServerSession) Mode { + return mode +} + +func(mode AudioMode) Leave(session *ServerSession) Mode { + return mode +} + func(mode AudioMode) Command(session *ServerSession, command byte, request_id uuid.UUID, channel_id ChannelID, members []*ServerSession, data []byte) Mode { switch command { case AUDIO_SET_SAMPLE_RATE: diff --git a/client.go b/client.go index c69aa78..b520a74 100644 --- a/client.go +++ b/client.go @@ -26,11 +26,7 @@ type Client struct { func NewClient(key ed25519.PrivateKey, remote string) (*Client, error) { if key == nil { - var err error - _, key, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("Need a key to create a client, passed nil") } seed_bytes := make([]byte, 8) diff --git a/cmd/client/main.go b/cmd/client/main.go index 56fbfe4..ba4ddf3 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -1,15 +1,25 @@ package main import ( + "crypto/ed25519" + "crypto/rand" + "crypto/x509" "encoding/binary" + "encoding/pem" + "flag" "fmt" "os" + "os/signal" + "slices" + "sync/atomic" + "syscall" "time" "git.metznet.ca/MetzNet/pnyx" "github.com/gen2brain/malgo" "github.com/google/uuid" "github.com/hraban/opus" + "seehuhn.de/go/ncurses" ) var decoders = map[pnyx.PeerID]chan[]byte{} @@ -20,7 +30,6 @@ var speaker = make(chan []int16, 0) var audio_data = make(chan []int16, 0) func set_sample_rate(new_sample_rate int) error { - fmt.Printf("Setting sample rate to %d\n", new_sample_rate) sample_rate = new_sample_rate var err error @@ -100,6 +109,10 @@ func mixer(data_chan chan []int16, speaker_chan chan []int16) { func main() { + key_file_arg := flag.String("key", "${HOME}/.pnyx.key", "Path to the private key file to load/save") + generate_key_arg := flag.Bool("genkey", false, "Set to generate a key if none exists") + flag.Parse() + ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, nil) if err != nil { panic(err) @@ -231,13 +244,53 @@ func main() { defer inDevice.Uninit() defer inDevice.Stop() + var key ed25519.PrivateKey = nil + + key_file_path := os.ExpandEnv(*key_file_arg) + key_file_bytes, err := os.ReadFile(key_file_path) + if err == nil { + key_pem, _ := pem.Decode(key_file_bytes) + if key_pem.Type != "PRIVATE KEY" { + panic("Key file has wrong PEM format") + } + + private_key, err := x509.ParsePKCS8PrivateKey(key_pem.Bytes) + if err != nil { + panic(err) + } + + var ok bool + key, ok = private_key.(ed25519.PrivateKey) + if ok == false { + panic("Private key is not ed25519.PrivateKey") + } + } else if *generate_key_arg { + _, key, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + + key_pkcs8, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + panic(err) + } + + key_pem := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: key_pkcs8, + }) + + err = os.WriteFile(key_file_path, key_pem, 0o600) + if err != nil { + panic(err) + } + } - client, err := pnyx.NewClient(nil, os.Args[1]) + client, err := pnyx.NewClient(key, flag.Arg(0)) if err != nil { panic(err) } - - + go func() { var buf [1024]byte for true { @@ -258,9 +311,7 @@ func main() { switch packet := packet.(type) { case pnyx.PingPacket: - fmt.Printf("Ping Packet From Server: %+v\n", packet) case pnyx.ChannelCommandPacket: - fmt.Printf("Channel Command packet: %+v\n", packet) if packet.Channel == pnyx.ChannelID(0) { if packet.Mode == pnyx.MODE_AUDIO { if packet.Command == pnyx.AUDIO_SET_SAMPLE_RATE { @@ -281,7 +332,6 @@ func main() { } } case pnyx.CommandPacket: - fmt.Printf("Command packet: %+v\n", packet) case pnyx.PeerPacket: if packet.Channel == pnyx.ChannelID(0) { decode_chan, exists := decoders[packet.Peer] @@ -315,34 +365,46 @@ func main() { panic(err) } - for true { - data := <- mic - err = client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) - if err != nil { - panic(err) + go func(){ + for true { + data := <- mic + err = client.Send(pnyx.NewDataPacket(pnyx.ChannelID(0), pnyx.MODE_AUDIO, data)) + if err != nil { + panic(err) + } } - } + }() - /*window := ncurses.Init() - defer ncurses.EndWin() + window := ncurses.Init() + active := atomic.Bool{} + active.Store(true) - ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed) - window.AddStr("pnyx client") + go func() { + ncurses.ColorPair(1).Init(ncurses.ColorBlue, ncurses.ColorRed) + window.AddStr("pnyx client") + + for active.Load() { + window.Refresh() + time.Sleep(200*time.Millisecond) + peers := make([]pnyx.PeerID, 0, len(decoders)) + for peer_id := range(decoders) { + peers = append(peers, peer_id) + } - for true { - window.Refresh() - time.Sleep(200*time.Millisecond) - peers := make([]pnyx.PeerID, 0, len(decoders)) - for peer_id := range(decoders) { - peers = append(peers, peer_id) + slices.SortFunc(peers, func(a, b pnyx.PeerID) int { + return slices.Compare(a[:], b[:]) + }) + + for i, peer_id := range(peers) { + window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id)) + } } + }() - slices.SortFunc(peers, func(a, b pnyx.PeerID) int { - return slices.Compare(a[:], b[:]) - }) + os_sigs := make(chan os.Signal, 1) + signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT) - for i, peer_id := range(peers) { - window.MvAddStr(i+1, 0, fmt.Sprintf("%x", peer_id)) - } - }*/ + <-os_sigs + active.Store(false) + ncurses.EndWin() } diff --git a/cmd/server/main.go b/cmd/server/main.go index fba9c95..97d06ea 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -11,7 +11,7 @@ import ( func main() { os_sigs := make(chan os.Signal, 1) - signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGINT) + signal.Notify(os_sigs, syscall.SIGINT, syscall.SIGABRT) channel_0, err := pnyx.NewChannel(map[pnyx.ModeID]pnyx.Mode{ pnyx.MODE_RAW: pnyx.RawMode{}, diff --git a/server.go b/server.go index a1e9698..efa0e0b 100644 --- a/server.go +++ b/server.go @@ -143,11 +143,13 @@ func handle_session_incoming(session *ServerSession, server *Server) { for session.active.Load() { select { case <- ping_timer: - if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != 1 { + if time.Now().Add(-1*SESSION_TIMEOUT).Compare(session.LastSeen) != -1 { + server.Log("Closing %s after being inactive since %s", session.ID, session.LastSeen) server.sessions_lock.Lock() server.close_session(session) server.sessions_lock.Unlock() } else { + server.Log("%s passed keep-alive check, last seen %s", session.ID, session.LastSeen) session.OutgoingPackets <- NewPingPacket() ping_timer = time.After(SESSION_PING_TIME) }