diff --git a/server.go b/server.go index 9e1f7e0..30cfd76 100644 --- a/server.go +++ b/server.go @@ -141,212 +141,235 @@ func(server *Server) Stop() error { const SESSION_BUFFER_SIZE = 256 -func(server *Server) listen_udp() { - server.Log("Started server on %s", server.connection.LocalAddr()) +func(server *Server) handle_session_open(client_session_open []byte, from *net.UDPAddr) error { + session_open, ecdh_private, err := NewSessionOpen(server.key) + if err != nil { + return err + } - var buf [SERVER_UDP_BUFFER_SIZE]byte - for true { - read, from, err := server.connection.ReadFromUDP(buf[:]) - if err == nil { - var packet_type SessionPacketType = SessionPacketType(buf[0]) - switch packet_type { - case SESSION_OPEN: - session_open, ecdh_private, err := NewSessionOpen(server.key) - if err != nil { - server.Log("NewSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) - continue - } + session, err := ParseSessionOpen(ecdh_private, client_session_open) + if err != nil { + return err + } - session, err := ParseSessionOpen(ecdh_private, buf[COMMAND_LENGTH:read]) - if err != nil { - server.Log("ParseSessionOpen error - %s: %x", err, buf[COMMAND_LENGTH:read]) - continue + server.sessions_lock.Lock() + server.sessions[session.ID] = &ServerSession{ + Session: session, + LastSeen: time.Now(), + IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), + OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), + } + server.sessions_lock.Unlock() + + go func(session *ServerSession, server *Server){ + server.Log("Starting session outgoing goroutine %s", session.ID) + for true { + packet := <- session.OutgoingPackets + if packet == nil { + break + } + + if session.remote == nil { + server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID) + continue + } + + packet_data, err := packet.MarshalBinary() + if err != nil { + server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err) + continue + } + + encrypted, err := NewSessionData(&session.Session, packet_data) + if err != nil { + server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err) + continue + } + + _, err = server.connection.WriteToUDP(encrypted, session.remote) + if err != nil { + server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err) + continue + } + } + server.Log("Stopping session outgoing goroutine %s", session.ID) + }(server.sessions[session.ID], server) + + go func(session *ServerSession, server *Server){ + server.Log("Starting session incoming goroutine %s", session.ID) + for true { + encrypted := <- session.IncomingPackets + if encrypted == nil { + break + } + + data, err := ParseSessionData(&session.Session, encrypted) + if err != nil { + server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) + continue + } + + packet, err := ParsePacket(data) + if err != nil { + server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) + } + + switch packet := packet.(type) { + case ChannelDataPacket: + var result []SendPacket = nil + + server.channels_lock.RLock() + channel, exists := server.channels[packet.Channel] + if exists == true { + result = channel.Data(&session.Session, packet.Mode, packet.Data) } + server.channels_lock.RUnlock() - server.sessions_lock.Lock() - server.sessions[session.ID] = &ServerSession{ - Session: session, - LastSeen: time.Now(), - IncomingPackets: make(chan[]byte, SESSION_BUFFER_SIZE), - OutgoingPackets: make(chan *Packet, SESSION_BUFFER_SIZE), + if exists == false { + server.Log("Packet for unknown channel %d", packet.Channel) + } else if len(result) > 0 { + //TODO: handle overflow + server.send_packets<-result } - server.sessions_lock.Unlock() - - go func(session *ServerSession, server *Server){ - server.Log("Starting session outgoing goroutine %s", session.ID) - for true { - packet := <- session.OutgoingPackets - if packet == nil { - break - } - - if session.remote == nil { - server.Log("SESSION_DATA_OUT(%s) error - no remote to send to", session.ID) - continue - } - - packet_data, err := packet.MarshalBinary() - if err != nil { - server.Log("SESSION_DATA_OUT(%s) marshal error - %s", session.ID, err) - continue - } - - encrypted, err := NewSessionData(&session.Session, packet_data) - if err != nil { - server.Log("SESSION_DATA_OUT(%s) error - %s", session.ID, err) - continue - } - - _, err = server.connection.WriteToUDP(encrypted, session.remote) - if err != nil { - server.Log("SESSION_DATA_OUT(%s) write error - %s", session.ID, err) - continue - } - } - server.Log("Stopping session outgoing goroutine %s", session.ID) - }(server.sessions[session.ID], server) - - go func(session *ServerSession, server *Server){ - server.Log("Starting session incoming goroutine %s", session.ID) - for true { - encrypted := <- session.IncomingPackets - if encrypted == nil { - break - } - - data, err := ParseSessionData(&session.Session, encrypted) - if err != nil { - server.Log("SESSION_DATA_IN(%s) error - %s", session.ID, err) - continue - } - - packet, err := ParsePacket(data) - if err != nil { - server.Log("SESSION_DATA_IN(%s) parse error - %s", session.ID, err) - } - - switch packet := packet.(type) { - case ChannelDataPacket: - var result []SendPacket = nil - - server.channels_lock.RLock() - channel, exists := server.channels[packet.Channel] - if exists == true { - result = channel.Data(&session.Session, packet.Mode, packet.Data) - } - server.channels_lock.RUnlock() - - if exists == false { - server.Log("Packet for unknown channel %d", packet.Channel) - } else if len(result) > 0 { - //TODO: handle overflow - server.send_packets<-result - } - - case ChannelCommandPacket: - var result []SendPacket = nil - - server.channels_lock.RLock() - channel, exists := server.channels[packet.Channel] - if exists == true { - result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data) - } - server.channels_lock.RUnlock() - - if exists == false { - server.Log("Packet for unknown channel %d", packet.Channel) - } else if err != nil { - server.Log("Error processing %+v - %s", packet, err) - } else if len(result) > 0 { - //TODO: handle overflow - server.send_packets<-result - } - - default: - server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err) - } - - } - server.Log("Stopping session incoming goroutine %s", session.ID) - }(server.sessions[session.ID], server) - - _, err = server.connection.WriteToUDP(session_open, from) - if err != nil { - server.Log("WriteToUDP error %s", err) - continue + + case ChannelCommandPacket: + var result []SendPacket = nil + + server.channels_lock.RLock() + channel, exists := server.channels[packet.Channel] + if exists == true { + result, err = channel.Command(&session.Session, packet.Mode, packet.Command, packet.Data) } - server.Log("Started session %s with %s", session.ID, session.Peer) + server.channels_lock.RUnlock() - case SESSION_CONNECT: - session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) - session, exists := server.sessions[session_id] if exists == false { - server.Log("Session %s does not exist, can't connect", session_id) - continue + server.Log("Packet for unknown channel %d", packet.Channel) + } else if err != nil { + server.Log("Error processing %+v - %s", packet, err) + } else if len(result) > 0 { + //TODO: handle overflow + server.send_packets<-result } - _, err := ParseSessionConnect(buf[COMMAND_LENGTH+ID_LENGTH:read], session.secret) - if err != nil { - server.Log("Error parsing session connect: %s", err) - continue - } + default: + server.Log("Unhandled packet type from session %s - 0x%02x", session.ID, err) + } + + } + server.Log("Stopping session incoming goroutine %s", session.ID) + }(server.sessions[session.ID], server) + + _, err = server.connection.WriteToUDP(session_open, from) + if err != nil { + return err + } + server.Log("Started session %s with %s", session.ID, session.Peer) + return nil +} + +func(server *Server) handle_session_connect(client_session_connect []byte, from *net.UDPAddr) error { + session_id := SessionID(client_session_connect[:ID_LENGTH]) + session, exists := server.sessions[session_id] + if exists == false { + return fmt.Errorf("Session %s does not exist, can't connect", session_id) + } - // TODO: fix, was client_addr but the client doesnt know it's nat assignment - session.remote = from - session.LastSeen = time.Now() - - // TODO: Make a better server hello - server_hello, err := NewSessionData(&session.Session, []byte("hello")) + _, err := ParseSessionConnect(client_session_connect[ID_LENGTH:], session.secret) + if err != nil { + return err + } + + // TODO: fix, was client_addr but the client doesnt know it's nat assignment + session.remote = from + session.LastSeen = time.Now() + + // TODO: Make a better server hello + server_hello, err := NewSessionData(&session.Session, []byte("hello")) + if err != nil { + return err + } + + _, err = server.connection.WriteToUDP(server_hello, session.remote) + if err != nil { + return err + } + server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from) + return nil +} + +func(server *Server) handle_session_close(client_session_close []byte, from *net.UDPAddr) error { + session_id := SessionID(client_session_close[:ID_LENGTH]) + session, exists := server.sessions[session_id] + if exists == false { + return fmt.Errorf("Session %s does not exist, can't close", session_id) + } + + err := ParseSessionClose(&session.Session, client_session_close[ID_LENGTH:]) + if err != nil { + return err + } + + server.sessions_lock.Lock() + server.close_session(session) + server.sessions_lock.Unlock() + + server.Log("Session %s closed", session_id) + return nil +} + +func(server *Server) handle_session_data(data []byte, from *net.UDPAddr) error { + session_id := SessionID(data[:ID_LENGTH]) + session, exists := server.sessions[session_id] + if exists == false { + return fmt.Errorf("Session %s does not exist, can't receive data", session_id) + } + + session.LastSeen = time.Now() + + buf_copy := make([]byte, len(data) - ID_LENGTH) + copy(buf_copy, data[ID_LENGTH:]) + + select { + case session.IncomingPackets<-buf_copy: + return nil + default: + return fmt.Errorf("Dropped packet to session %s", session_id) + } +} + +func(server *Server) listen_udp() { + server.Log("Started server on %s", server.connection.LocalAddr()) + + var buf [SERVER_UDP_BUFFER_SIZE]byte + for true { + read, from, err := server.connection.ReadFromUDP(buf[:]) + if err == nil { + var packet_type SessionPacketType = SessionPacketType(buf[0]) + switch packet_type { + case SESSION_OPEN: + err := server.handle_session_open(buf[COMMAND_LENGTH:read], from) if err != nil { - server.Log("Error generating server hello: %s", err) - continue + server.Log("handle_session_open erorr - %s", err) } - _, err = server.connection.WriteToUDP(server_hello, session.remote) + case SESSION_CONNECT: + err := server.handle_session_connect(buf[COMMAND_LENGTH:read], from) if err != nil { - server.Log("Error sending server hello: %s", err) - continue + server.Log("handle_session_connect error - %s", err) } - server.Log("Sent server_hello for %s to %s(from %s)", session.ID, session.remote, from) case SESSION_CLOSE: - session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) - session, exists := server.sessions[session_id] - if exists == false { - server.Log("Session %s does not exist, can't close", session_id) - continue - } - - err := ParseSessionClose(&session.Session, buf[COMMAND_LENGTH+ID_LENGTH:read]) + err := server.handle_session_close(buf[COMMAND_LENGTH:read], from) if err != nil { - server.Log("Session close error for %s - %s", session_id, err) - continue + server.Log("handle_session_close error - %s", err) } - server.sessions_lock.Lock() - server.close_session(session) - server.sessions_lock.Unlock() - - server.Log("Session %s closed", session_id) - case SESSION_DATA: - session_id := SessionID(buf[COMMAND_LENGTH:COMMAND_LENGTH+ID_LENGTH]) - session, exists := server.sessions[session_id] - if exists == false { - server.Log("Session %s does not exist, can't receive data", session_id) - continue + err := server.handle_session_data(buf[COMMAND_LENGTH:read], from) + if err != nil { + server.Log("handle_session_data error - %s", err) } - session.LastSeen = time.Now() - - buf_copy := make([]byte, read - COMMAND_LENGTH - ID_LENGTH) - copy(buf_copy, buf[COMMAND_LENGTH+ID_LENGTH:read]) - - select { - case session.IncomingPackets<-buf_copy: - default: - server.Log("Dropped packet to session %s", session_id) - } default: server.Log("Unhandled packet type 0x%04x from %s: %+v", packet_type, from, buf[COMMAND_LENGTH:read]) }