Made SessionPackeType an fmt.Stringer, removed client state for atomic bool

live
noah metz 2024-04-23 16:57:04 -06:00
parent 73312a37cd
commit 6b7e166bb8
2 changed files with 51 additions and 45 deletions

@ -9,13 +9,6 @@ import (
"time" "time"
) )
type ClientState uint8
const (
CLIENT_SESSION_CONNECTING ClientState = iota
CLIENT_SESSION_CONNECTED
)
type Client struct { type Client struct {
Key ed25519.PrivateKey Key ed25519.PrivateKey
Session Session Session Session
@ -23,10 +16,10 @@ type Client struct {
connection *net.UDPConn connection *net.UDPConn
to_send chan []byte
data_fn func(Payload)error data_fn func(Payload)error
state ClientState
active atomic.Bool active atomic.Bool
connected atomic.Bool
} }
func(client *Client) Remote() string { func(client *Client) Remote() string {
@ -54,18 +47,7 @@ func(client *Client) listen_udp() {
client.Log("Client listen error - no data in packet") client.Log("Client listen error - no data in packet")
} else { } else {
packet_type := SessionPacketType(buffer[0]) packet_type := SessionPacketType(buffer[0])
switch client.state { if client.connected.Load() {
case CLIENT_SESSION_CONNECTING:
switch packet_type {
case SESSION_CONNECTED:
client.state = CLIENT_SESSION_CONNECTED
case SESSION_CLOSED:
client.Log("Server repsonded to session connect with SESSION_CLOSED")
client.active.Store(false)
default:
client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTING", packet_type)
}
case CLIENT_SESSION_CONNECTED:
switch packet_type { switch packet_type {
case SESSION_DATA: case SESSION_DATA:
if len(buffer) < SESSION_ID_LENGTH + 1 { if len(buffer) < SESSION_ID_LENGTH + 1 {
@ -103,11 +85,22 @@ func(client *Client) listen_udp() {
if err != nil { if err != nil {
client.Log("Error running data_fn: %s", err) client.Log("Error running data_fn: %s", err)
} }
case SESSION_CLOSED:
client.Log("Server sent SESSION_CLOSED")
client.active.Store(false)
default: default:
client.Log("Bad session packet type 0x%02x for client in state SESSION_CONNECTED", packet_type) client.Log("Bad session packet type %s for connected session", packet_type)
} }
} else {
switch packet_type {
case SESSION_CONNECTED:
client.connected.Store(true)
case SESSION_CLOSED:
client.Log("Server repsonded to session connect with SESSION_CLOSED")
client.active.Store(false)
default: default:
client.Log("Received packet while in bad client state 0x%02x", client.state) client.Log("Bad session packet type %s for disconnected session", packet_type)
}
} }
} }
} }
@ -191,33 +184,22 @@ func NewClient(key ed25519.PrivateKey, remote string, data_fn func(Payload)error
remote: remote, remote: remote,
connection: connection, connection: connection,
to_send: make(chan []byte, 1000),
data_fn: data_fn, data_fn: data_fn,
state: CLIENT_SESSION_CONNECTING,
} }
client.active.Store(true) client.active.Store(true)
client.connected.Store(false)
go client.listen_udp() go client.listen_udp()
go client.send_queue()
return client, nil return client, nil
} }
func(client *Client) send_queue() { func(client *Client) Send(packet *Packet) error {
for client.active.Load() { if client.active.Load() == false {
packet := <- client.to_send } else if client.connected.Load() == false {
if packet == nil {
break
}
_, err := client.connection.Write(packet)
if err != nil {
client.Log("Send error: %s", err)
}
} }
}
func(client *Client) Send(packet *Packet) error {
data, err := packet.MarshalBinary() data, err := packet.MarshalBinary()
if err != nil { if err != nil {
return err return err
@ -228,18 +210,17 @@ func(client *Client) Send(packet *Packet) error {
return err return err
} }
select { _, err = client.connection.Write(wrapped)
case client.to_send <- wrapped: if err != nil {
return nil return err
default:
return fmt.Errorf("Channel overflow")
} }
return nil
} }
func(client *Client) Close() error { func(client *Client) Close() error {
var err error = fmt.Errorf("Client not active") var err error = fmt.Errorf("Client not active")
if client.active.CompareAndSwap(true, false) { if client.active.CompareAndSwap(true, false) {
close(client.to_send)
client.connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now())) client.connection.Write(NewSessionTimed(SESSION_CLOSE, client.Key, &client.Session, time.Now()))
err = client.connection.Close() err = client.connection.Close()
} }

@ -70,6 +70,31 @@ const (
SESSION_DATA SESSION_DATA
) )
func(packet_type SessionPacketType) String() string {
switch packet_type {
case SESSION_OPEN:
return "SESSION_OPEN"
case SESSION_OPENED:
return "SESSION_OPENED"
case SESSION_INVITE:
return "SESSION_INVITE"
case SESSION_INVITED:
return "SESSION_INVITED"
case SESSION_CONNECT:
return "SESSION_CONNECT"
case SESSION_CONNECTED:
return "SESSION_CONNECTED"
case SESSION_CLOSE:
return "SESSION_CLOSE"
case SESSION_CLOSED:
return "SESSION_CLOSED"
case SESSION_DATA:
return "SESSION_DATA"
default:
return fmt.Sprintf("UNKNOWN 0x%02x", uint8(packet_type))
}
}
func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) { func ECDH(public ed25519.PublicKey, private ed25519.PrivateKey) ([]byte, error) {
public_point, err := (&edwards25519.Point{}).SetBytes(public) public_point, err := (&edwards25519.Point{}).SetBytes(public)
if err != nil { if err != nil {