Moved session packet handlers to functions for reasability

live
noah metz 2024-04-09 17:20:37 -06:00
parent 17990e35d9
commit f32087f729
1 changed files with 205 additions and 182 deletions

@ -141,212 +141,235 @@ func(server *Server) Stop() error {
const SESSION_BUFFER_SIZE = 256 const SESSION_BUFFER_SIZE = 256
func(server *Server) listen_udp() { func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
server.Log("Started server on %s", server.connection.LocalAddr()) session_open, ecdh_private, err := NewSessionOpen(server.key)
if err != nil {
return err
}
var buf [SERVER_UDP_BUFFER_SIZE]byte session, err := ParseSessionOpen(ecdh_private, client_session_open)
for true { if err != nil {
read, from, err := server.connection.ReadFromUDP(buf[:]) return err
if err == nil { }
var packet_type SessionPacketType = SessionPacketType(buf[0])
switch packet_type {
case SESSION_OPEN:
session_open, ecdh_private, err := NewSessionOpen(server.key)
if err != nil {
server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read])
continue
}
session, err := ParseSessionOpen(ecdh_private, buf[COMMAND_LENGTH:read]) server.sessions_lock.Lock()
if err != nil { server.sessions[session.ID] = &ServerSession{
server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) Session: session,
continue LastSeen: time.Now(),
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
}
server.sessions_lock.Unlock()
go func(session *ServerSession, server *Server){
server.Log("Starting session outgoing goroutine %s", session.ID)
for true {
packet := <- session.OutgoingPackets
if packet == nil {
break
}
if session.remote == nil {
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
continue
}
packet_data, err := packet.MarshalBinary()
if err != nil {
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
continue
}
encrypted, err := NewSessionData(&session.Session, packet_data)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
continue
}
_, err = server.connection.WriteToUDP(encrypted, session.remote)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
continue
}
}
server.Log("Stopping session outgoing goroutine %s", session.ID)
}(server.sessions[session.ID], server)
go func(session *ServerSession, server *Server){
server.Log("Starting session incoming goroutine %s", session.ID)
for true {
encrypted := <- session.IncomingPackets
if encrypted == nil {
break
}
data, err := ParseSessionData(&session.Session, encrypted)
if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue
}
packet, err := ParsePacket(data)
if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
}
switch packet := packet.(type) {
case ChannelDataPacket:
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result = channel.Data(&session.Session, packet.Mode, packet.Data)
} }
server.channels_lock.RUnlock()
server.sessions_lock.Lock() if exists == false {
server.sessions[session.ID] = &ServerSession{ server.Log("Packet for unknown channel %d", packet.Channel)
Session: session, } else if len(result) > 0 {
LastSeen: time.Now(), //TODO: handle overflow
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), server.send_packets<-result
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
} }
server.sessions_lock.Unlock()
case ChannelCommandPacket:
go func(session *ServerSession, server *Server){ var result []SendPacket = nil
server.Log("Starting session outgoing goroutine %s", session.ID)
for true { server.channels_lock.RLock()
packet := <- session.OutgoingPackets channel, exists := server.channels[packet.Channel]
if packet == nil { if exists == true {
break result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
}
if session.remote == nil {
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
continue
}
packet_data, err := packet.MarshalBinary()
if err != nil {
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
continue
}
encrypted, err := NewSessionData(&session.Session, packet_data)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
continue
}
_, err = server.connection.WriteToUDP(encrypted, session.remote)
if err != nil {
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
continue
}
}
server.Log("Stopping session outgoing goroutine %s", session.ID)
}(server.sessions[session.ID], server)
go func(session *ServerSession, server *Server){
server.Log("Starting session incoming goroutine %s", session.ID)
for true {
encrypted := <- session.IncomingPackets
if encrypted == nil {
break
}
data, err := ParseSessionData(&session.Session, encrypted)
if err != nil {
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
continue
}
packet, err := ParsePacket(data)
if err != nil {
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
}
switch packet := packet.(type) {
case ChannelDataPacket:
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result = channel.Data(&session.Session, packet.Mode, packet.Data)
}
server.channels_lock.RUnlock()
if exists == false {
server.Log("Packet for unknown channel %d", packet.Channel)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
}
case ChannelCommandPacket:
var result []SendPacket = nil
server.channels_lock.RLock()
channel, exists := server.channels[packet.Channel]
if exists == true {
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
}
server.channels_lock.RUnlock()
if exists == false {
server.Log("Packet for unknown channel %d", packet.Channel)
} else if err != nil {
server.Log("Error processing %+v - %s", packet, err)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
}
default:
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
}
}
server.Log("Stopping session incoming goroutine %s", session.ID)
}(server.sessions[session.ID], server)
_, err = server.connection.WriteToUDP(session_open, from)
if err != nil {
server.Log("WriteToUDP error %s", err)
continue
} }
server.Log("Started session %s with %s", session.ID, session.Peer) server.channels_lock.RUnlock()
case SESSION_CONNECT:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false { if exists == false {
server.Log("Session %s does not exist, can't connect", session_id) server.Log("Packet for unknown channel %d", packet.Channel)
continue } else if err != nil {
server.Log("Error processing %+v - %s", packet, err)
} else if len(result) > 0 {
//TODO: handle overflow
server.send_packets<-result
} }
_, err := ParseSessionConnect(buf[COMMAND_LENGTH+ID_LENGTH:read], session.secret) default:
if err != nil { server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
server.Log("Error parsing session connect: %s", err) }
continue
} }
server.Log("Stopping session incoming goroutine %s", session.ID)
}(server.sessions[session.ID], server)
_, err = server.connection.WriteToUDP(session_open, from)
if err != nil {
return err
}
server.Log("Started session %s with %s", session.ID, session.Peer)
return nil
}
func(server *Server) handle_session_connect(client_session_connect []byte, from *net.UDPAddr) error {
session_id := SessionID(client_session_connect[:ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
return fmt.Errorf("Session %s does not exist, can't connect", session_id)
}
_, err := ParseSessionConnect(client_session_connect[ID_LENGTH:], session.secret)
if err != nil {
return err
}
// TODO: fix, was client_addr but the client doesnt know it's nat assignment
session.remote = from
session.LastSeen = time.Now()
// TODO: Make a better server hello
server_hello, err := NewSessionData(&session.Session, []byte("hello"))
if err != nil {
return err
}
_, err = server.connection.WriteToUDP(server_hello, session.remote)
if err != nil {
return err
}
server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from)
return nil
}
func(server *Server) handle_session_close(client_session_close []byte, from *net.UDPAddr) error {
session_id := SessionID(client_session_close[:ID_LENGTH])
session, exists := server.sessions[session_id]
if exists == false {
return fmt.Errorf("Session %s does not exist, can't close", session_id)
}
err := ParseSessionClose(&session.Session, client_session_close[ID_LENGTH:])
if err != nil {
return err
}
server.sessions_lock.Lock()
server.close_session(session)
server.sessions_lock.Unlock()
server.Log("Session %s closed", session_id)
return nil
}
// TODO: fix, was client_addr but the client doesnt know it's nat assignment func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
session.remote = from session_id := SessionID(data[:ID_LENGTH])
session.LastSeen = time.Now() session, exists := server.sessions[session_id]
if exists == false {
return fmt.Errorf("Session %s does not exist, can't receive data", session_id)
}
session.LastSeen = time.Now()
buf_copy := make([]byte, len(data) - ID_LENGTH)
copy(buf_copy, data[ID_LENGTH:])
// TODO: Make a better server hello select {
server_hello, err := NewSessionData(&session.Session, []byte("hello")) case session.IncomingPackets<-buf_copy:
return nil
default:
return fmt.Errorf("Dropped packet to session %s", session_id)
}
}
func(server *Server) listen_udp() {
server.Log("Started server on %s", server.connection.LocalAddr())
var buf [SERVER_UDP_BUFFER_SIZE]byte
for true {
read, from, err := server.connection.ReadFromUDP(buf[:])
if err == nil {
var packet_type SessionPacketType = SessionPacketType(buf[0])
switch packet_type {
case SESSION_OPEN:
err := server.handle_session_open(buf[COMMAND_LENGTH:read], from)
if err != nil { if err != nil {
server.Log("Error generating server hello: %s", err) server.Log("handle_session_open erorr - %s", err)
continue
} }
_, err = server.connection.WriteToUDP(server_hello, session.remote) case SESSION_CONNECT:
err := server.handle_session_connect(buf[COMMAND_LENGTH:read], from)
if err != nil { if err != nil {
server.Log("Error sending server hello: %s", err) server.Log("handle_session_connect error - %s", err)
continue
} }
server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from)
case SESSION_CLOSE: case SESSION_CLOSE:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) err := server.handle_session_close(buf[COMMAND_LENGTH:read], from)
session, exists := server.sessions[session_id]
if exists == false {
server.Log("Session %s does not exist, can't close", session_id)
continue
}
err := ParseSessionClose(&session.Session, buf[COMMAND_LENGTH+ID_LENGTH:read])
if err != nil { if err != nil {
server.Log("Session close error for %s - %s", session_id, err) server.Log("handle_session_close error - %s", err)
continue
} }
server.sessions_lock.Lock()
server.close_session(session)
server.sessions_lock.Unlock()
server.Log("Session %s closed", session_id)
case SESSION_DATA: case SESSION_DATA:
session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) err := server.handle_session_data(buf[COMMAND_LENGTH:read], from)
session, exists := server.sessions[session_id] if err != nil {
if exists == false { server.Log("handle_session_data error - %s", err)
server.Log("Session %s does not exist, can't receive data", session_id)
continue
} }
session.LastSeen = time.Now()
buf_copy := make([]byte, read - COMMAND_LENGTH - ID_LENGTH)
copy(buf_copy, buf[COMMAND_LENGTH+ID_LENGTH:read])
select {
case session.IncomingPackets<-buf_copy:
default:
server.Log("Dropped packet to session %s", session_id)
}
default: default:
server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read]) server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read])
} }