Made SessionPackeType an fmt.Stringer, removed client state for atomic bool

live
noah metz 2024-04-23 16:57:04 -06:00
parent 73312a37cd
commit 6b7e166bb8
2 changed files with 51 additions and 45 deletions

@ -9,13 +9,6 @@ import (
"time"
)
type ClientState uint8
const (
CLIENT_SESSION_CONNECTING ClientState = iota
CLIENT_SESSION_CONNECTED
)
type Client struct {
Key ed25519.PrivateKey
Session Session
@ -23,10 +16,10 @@ type Client struct {
connection *net.UDPConn
to_send chan []byte
data_fn func(Payload)error
state ClientState
active atomic.Bool
connected atomic.Bool
}
func(client *Client) Remote() string {
@ -54,18 +47,7 @@ func(client *Client) listen_udp() {
client.Log("Client listen error - no data in packet")
} else {
packet_type := SessionPacketType(buffer[0])
switch client.state {
case CLIENT_SESSION_CONNECTING:
switch packet_type {
case SESSION_CONNECTED:
client.state = CLIENT_SESSION_CONNECTED
case SESSION_CLOSED:
client.Log("Server repsonded to session connect with SESSION_CLOSED")
client.active.Store(false)
default:
client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTING", packet_type)
}
case CLIENT_SESSION_CONNECTED:
if client.connected.Load() {
switch packet_type {
case SESSION_DATA:
if len(buffer) < SESSION_ID_LENGTH + 1 {
@ -103,11 +85,22 @@ func(client *Client) listen_udp() {
if err != nil {
client.Log("Error running data_fn: %s", err)
}
case SESSION_CLOSED:
client.Log("Server sent SESSION_CLOSED")
client.active.Store(false)
default:
client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTED", packet_type)
client.Log("Bad session packet type %s for connected session", packet_type)
}
} else {
switch packet_type {
case SESSION_CONNECTED:
client.connected.Store(true)
case SESSION_CLOSED:
client.Log("Server repsonded to session connect with SESSION_CLOSED")
client.active.Store(false)
default:
client.Log("Received packet while in bad client state 0x%02x", client.state)
client.Log("Bad session packet type %s for disconnected session", packet_type)
}
}
}
}
@ -191,33 +184,22 @@ func NewClient(key ed25519.PrivateKey, remote string, data_fn func(Payload)error
remote: remote,
connection: connection,
to_send: make(chan []byte, 1000),
data_fn: data_fn,
state: CLIENT_SESSION_CONNECTING,
}
client.active.Store(true)
client.connected.Store(false)
go client.listen_udp()
go client.send_queue()
return client, nil
}
func(client *Client) send_queue() {
for client.active.Load() {
packet := <- client.to_send
if packet == nil {
break
}
_, err := client.connection.Write(packet)
if err != nil {
client.Log("Send error: %s", err)
}
}
func(client *Client) Send(packet *Packet) error {
if client.active.Load() == false {
} else if client.connected.Load() == false {
}
func(client *Client) Send(packet *Packet) error {
data, err := packet.MarshalBinary()
if err != nil {
return err
@ -228,18 +210,17 @@ func(client *Client) Send(packet *Packet) error {
return err
}
select {
case client.to_send <- wrapped:
return nil
default:
return fmt.Errorf("Channel overflow")
_, err = client.connection.Write(wrapped)
if err != nil {
return err
}
return nil
}
func(client *Client) Close() error {
var err error = fmt.Errorf("Client not active")
if client.active.CompareAndSwap(true, false) {
close(client.to_send)
client.connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now()))
err = client.connection.Close()
}

@ -70,6 +70,31 @@ const (
SESSION_DATA
)
func(packet_type SessionPacketType) String() string {
switch packet_type {
case SESSION_OPEN:
return "SESSION_OPEN"
case SESSION_OPENED:
return "SESSION_OPENED"
case SESSION_INVITE:
return "SESSION_INVITE"
case SESSION_INVITED:
return "SESSION_INVITED"
case SESSION_CONNECT:
return "SESSION_CONNECT"
case SESSION_CONNECTED:
return "SESSION_CONNECTED"
case SESSION_CLOSE:
return "SESSION_CLOSE"
case SESSION_CLOSED:
return "SESSION_CLOSED"
case SESSION_DATA:
return "SESSION_DATA"
default:
return fmt.Sprintf("UNKNOWN 0x%02x", uint8(packet_type))
}
}
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
public_point, err := (&edwards25519.Point{}).SetBytes(public)
if err != nil {