|
|
|
@ -141,140 +141,138 @@ func(server *Server) Stop() error {
|
|
|
|
|
|
|
|
|
|
const SESSION_BUFFER_SIZE = 256
|
|
|
|
|
|
|
|
|
|
func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
|
|
|
|
|
session_open, ecdh_private, err := NewSessionOpen(server.key)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
func handle_session_outgoing(session *ServerSession, server *Server) {
|
|
|
|
|
server.Log("Starting session outgoing goroutine %s", session.ID)
|
|
|
|
|
for true {
|
|
|
|
|
packet := <- session.OutgoingPackets
|
|
|
|
|
if packet == nil {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
session, err := ParseSessionOpen(ecdh_private, client_session_open)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if session.remote == nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
|
server.sessions[session.ID] = &ServerSession{
|
|
|
|
|
Session: session,
|
|
|
|
|
LastSeen: time.Now(),
|
|
|
|
|
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
|
|
|
|
|
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
|
|
|
|
|
}
|
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
packet_data, err := packet.MarshalBinary()
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
go func(session *ServerSession, server *Server){
|
|
|
|
|
server.Log("Starting session outgoing goroutine %s", session.ID)
|
|
|
|
|
for true {
|
|
|
|
|
packet := <- session.OutgoingPackets
|
|
|
|
|
if packet == nil {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
encrypted, err := NewSessionData(&session.Session, packet_data)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if session.remote == nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
_, err = server.connection.WriteToUDP(encrypted, session.remote)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
server.Log("Stopping session outgoing goroutine %s", session.ID)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
packet_data, err := packet.MarshalBinary()
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
func handle_session_incoming(session *ServerSession, server *Server) {
|
|
|
|
|
server.Log("Starting session incoming goroutine %s", session.ID)
|
|
|
|
|
for true {
|
|
|
|
|
encrypted := <- session.IncomingPackets
|
|
|
|
|
if encrypted == nil {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
encrypted, err := NewSessionData(&session.Session, packet_data)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
data, err := ParseSessionData(&session.Session, encrypted)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = server.connection.WriteToUDP(encrypted, session.remote)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
packet, err := ParsePacket(data)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
|
|
|
|
|
}
|
|
|
|
|
server.Log("Stopping session outgoing goroutine %s", session.ID)
|
|
|
|
|
}(server.sessions[session.ID], server)
|
|
|
|
|
|
|
|
|
|
go func(session *ServerSession, server *Server){
|
|
|
|
|
server.Log("Starting session incoming goroutine %s", session.ID)
|
|
|
|
|
for true {
|
|
|
|
|
encrypted := <- session.IncomingPackets
|
|
|
|
|
if encrypted == nil {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
data, err := ParseSessionData(&session.Session, encrypted)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err)
|
|
|
|
|
continue
|
|
|
|
|
switch packet := packet.(type) {
|
|
|
|
|
case ChannelDataPacket:
|
|
|
|
|
var result []SendPacket = nil
|
|
|
|
|
|
|
|
|
|
server.channels_lock.RLock()
|
|
|
|
|
channel, exists := server.channels[packet.Channel]
|
|
|
|
|
if exists == true {
|
|
|
|
|
result = channel.Data(&session.Session, packet.Mode, packet.Data)
|
|
|
|
|
}
|
|
|
|
|
server.channels_lock.RUnlock()
|
|
|
|
|
|
|
|
|
|
packet, err := ParsePacket(data)
|
|
|
|
|
if err != nil {
|
|
|
|
|
server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err)
|
|
|
|
|
if exists == false {
|
|
|
|
|
server.Log("Packet for unknown channel %d", packet.Channel)
|
|
|
|
|
} else if len(result) > 0 {
|
|
|
|
|
//TODO: handle overflow
|
|
|
|
|
server.send_packets<-result
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch packet := packet.(type) {
|
|
|
|
|
case ChannelDataPacket:
|
|
|
|
|
var result []SendPacket = nil
|
|
|
|
|
case ChannelCommandPacket:
|
|
|
|
|
var result []SendPacket = nil
|
|
|
|
|
|
|
|
|
|
server.channels_lock.RLock()
|
|
|
|
|
channel, exists := server.channels[packet.Channel]
|
|
|
|
|
if exists == true {
|
|
|
|
|
result = channel.Data(&session.Session, packet.Mode, packet.Data)
|
|
|
|
|
}
|
|
|
|
|
server.channels_lock.RUnlock()
|
|
|
|
|
server.channels_lock.RLock()
|
|
|
|
|
channel, exists := server.channels[packet.Channel]
|
|
|
|
|
if exists == true {
|
|
|
|
|
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
|
|
|
|
|
}
|
|
|
|
|
server.channels_lock.RUnlock()
|
|
|
|
|
|
|
|
|
|
if exists == false {
|
|
|
|
|
server.Log("Packet for unknown channel %d", packet.Channel)
|
|
|
|
|
} else if err != nil {
|
|
|
|
|
server.Log("Error processing %+v - %s", packet, err)
|
|
|
|
|
} else if len(result) > 0 {
|
|
|
|
|
//TODO: handle overflow
|
|
|
|
|
server.send_packets<-result
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if exists == false {
|
|
|
|
|
server.Log("Packet for unknown channel %d", packet.Channel)
|
|
|
|
|
} else if len(result) > 0 {
|
|
|
|
|
//TODO: handle overflow
|
|
|
|
|
server.send_packets<-result
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
case ChannelCommandPacket:
|
|
|
|
|
var result []SendPacket = nil
|
|
|
|
|
}
|
|
|
|
|
server.Log("Stopping session incoming goroutine %s", session.ID)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
server.channels_lock.RLock()
|
|
|
|
|
channel, exists := server.channels[packet.Channel]
|
|
|
|
|
if exists == true {
|
|
|
|
|
result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data)
|
|
|
|
|
}
|
|
|
|
|
server.channels_lock.RUnlock()
|
|
|
|
|
|
|
|
|
|
if exists == false {
|
|
|
|
|
server.Log("Packet for unknown channel %d", packet.Channel)
|
|
|
|
|
} else if err != nil {
|
|
|
|
|
server.Log("Error processing %+v - %s", packet, err)
|
|
|
|
|
} else if len(result) > 0 {
|
|
|
|
|
//TODO: handle overflow
|
|
|
|
|
server.send_packets<-result
|
|
|
|
|
}
|
|
|
|
|
func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error {
|
|
|
|
|
session, session_opened, err := ParseSessionOpen(server.key, client_session_open)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err)
|
|
|
|
|
}
|
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
|
server.sessions[session.ID] = &ServerSession{
|
|
|
|
|
Session: session,
|
|
|
|
|
LastSeen: time.Now(),
|
|
|
|
|
IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE),
|
|
|
|
|
OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE),
|
|
|
|
|
}
|
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
server.Log("Stopping session incoming goroutine %s", session.ID)
|
|
|
|
|
}(server.sessions[session.ID], server)
|
|
|
|
|
go handle_session_outgoing(server.sessions[session.ID], server)
|
|
|
|
|
go handle_session_incoming(server.sessions[session.ID], server)
|
|
|
|
|
|
|
|
|
|
_, err = server.connection.WriteToUDP(session_open, from)
|
|
|
|
|
_, err = server.connection.WriteToUDP(session_opened, from)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
server.Log("Started session %s with %s", session.ID, session.Peer)
|
|
|
|
|
server.Log("Started session %s with %x", session.ID, session.Peer)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func(server *Server) handle_session_connect(client_session_connect []byte, from *net.UDPAddr) error {
|
|
|
|
|
session_id := SessionID(client_session_connect[:ID_LENGTH])
|
|
|
|
|
func(server *Server) handle_session_connect(session_connect []byte, from *net.UDPAddr) error {
|
|
|
|
|
session_id := SessionID(session_connect[:SESSION_ID_LENGTH])
|
|
|
|
|
session, exists := server.sessions[session_id]
|
|
|
|
|
if exists == false {
|
|
|
|
|
return fmt.Errorf("Session %s does not exist, can't connect", session_id)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err := ParseSessionConnect(client_session_connect[ID_LENGTH:], session.secret)
|
|
|
|
|
session_connected, err := ParseSessionTimed(SESSION_CONNECTED, server.key, session_connect, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
@ -283,13 +281,7 @@ func(server *Server) handle_session_connect(client_session_connect []byte, from
|
|
|
|
|
session.remote = from
|
|
|
|
|
session.LastSeen = time.Now()
|
|
|
|
|
|
|
|
|
|
// TODO: Make a better server hello
|
|
|
|
|
server_hello, err := NewSessionData(&session.Session, []byte("hello"))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, err = server.connection.WriteToUDP(server_hello, session.remote)
|
|
|
|
|
_, err = server.connection.WriteToUDP(session_connected, session.remote)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
@ -297,14 +289,14 @@ func(server *Server) handle_session_connect(client_session_connect []byte, from
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func(server *Server) handle_session_close(client_session_close []byte, from *net.UDPAddr) error {
|
|
|
|
|
session_id := SessionID(client_session_close[:ID_LENGTH])
|
|
|
|
|
func(server *Server) handle_session_close(session_close []byte, from *net.UDPAddr) error {
|
|
|
|
|
session_id := SessionID(session_close[:SESSION_ID_LENGTH])
|
|
|
|
|
session, exists := server.sessions[session_id]
|
|
|
|
|
if exists == false {
|
|
|
|
|
return fmt.Errorf("Session %s does not exist, can't close", session_id)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err := ParseSessionClose(&session.Session, client_session_close[ID_LENGTH:])
|
|
|
|
|
session_closed, err := ParseSessionTimed(SESSION_CLOSED, server.key, session_close, &session.Session, time.Now().Add(-1*time.Second).UnixMilli())
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
@ -313,12 +305,17 @@ func(server *Server) handle_session_close(client_session_close []byte, from *net
|
|
|
|
|
server.close_session(session)
|
|
|
|
|
server.sessions_lock.Unlock()
|
|
|
|
|
|
|
|
|
|
_, err = server.connection.WriteToUDP(session_closed, session.remote)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
server.Log("Session %s closed", session_id)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
|
|
|
|
|
session_id := SessionID(data[:ID_LENGTH])
|
|
|
|
|
session_id := SessionID(data[:SESSION_ID_LENGTH])
|
|
|
|
|
session, exists := server.sessions[session_id]
|
|
|
|
|
if exists == false {
|
|
|
|
|
return fmt.Errorf("Session %s does not exist, can't receive data", session_id)
|
|
|
|
@ -326,8 +323,8 @@ func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error {
|
|
|
|
|
|
|
|
|
|
session.LastSeen = time.Now()
|
|
|
|
|
|
|
|
|
|
buf_copy := make([]byte, len(data) - ID_LENGTH)
|
|
|
|
|
copy(buf_copy, data[ID_LENGTH:])
|
|
|
|
|
buf_copy := make([]byte, len(data) - SESSION_ID_LENGTH)
|
|
|
|
|
copy(buf_copy, data[SESSION_ID_LENGTH:])
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case session.IncomingPackets<-buf_copy:
|
|
|
|
@ -426,7 +423,6 @@ func(server *Server) cleanup_sessions() {
|
|
|
|
|
for server.active.Load() {
|
|
|
|
|
select {
|
|
|
|
|
case <-time.After(SESSION_TIMEOUT_CHECK):
|
|
|
|
|
server.Log("Running stale session check")
|
|
|
|
|
server.sessions_lock.Lock()
|
|
|
|
|
now := time.Now()
|
|
|
|
|
stale_sessions := make([]*ServerSession, 0, len(server.sessions))
|
|
|
|
|