diff --git a/Cargo.lock b/Cargo.lock index 7b0df6e..fcee756 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -540,6 +540,12 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "prettyplease" version = "0.2.16" @@ -632,6 +638,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -1016,6 +1052,8 @@ dependencies = [ "prost-build", "prost-types", "protoc", + "rand", + "rand_core", "rumqttc", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index d7da985..f69cebf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,8 @@ env_logger = "0.11.0" openssl = "0.10.63" prost = "0.12.3" prost-types = "0.12.3" +rand = "0.8.5" +rand_core = "0.6.4" rumqttc = "0.23.0" serde = { version = "1.0.195", features = ["derive"]} serde_json = "1.0.111" diff --git a/src/main.rs b/src/main.rs index 14226e3..305b597 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,20 +2,18 @@ use rumqttc::{MqttOptions, Client, QoS, LastWill}; use bytes::Bytes; use std::time::Duration; use std::thread; -use std::sync::mpsc; use std::collections::hash_map::HashMap; use prost::Message; use std::io::Cursor; use serde::{Serialize, Deserialize}; +use rand_core::RngCore; use std::io::prelude::*; -use sha2::Digest; -use sha2::Sha256; +use sha2::{Sha256, Digest}; use std::net::TcpStream; -use std::sync::Mutex; -use std::sync::Arc; +use std::sync::mpsc; // MQTT Topics: // - division/{division_id} @@ -160,10 +158,11 @@ struct Match { name: String, } +#[derive(Debug)] struct BackendMessage { status: u8, request_id: u32, - data: Vec, + data: tm::BackendMessageData, } impl BackendMessage { @@ -172,11 +171,20 @@ impl BackendMessage { return None; } - return Some(BackendMessage{ - status: bytes[0], - request_id: u32::from_le_bytes(bytes[0..4].try_into().unwrap()), - data: bytes[5..].to_vec(), - }); + let mut pb_data = Cursor::new(bytes[5..].to_vec()); + match tm::BackendMessageData::decode(&mut pb_data) { + Ok(data) => + Some(BackendMessage{ + status: bytes[0], + request_id: u32::from_le_bytes(bytes[0..4].try_into().unwrap()), + data, + }), + Err(error) => { + println!("Protobuf parse error: {:?}", error); + return None; + }, + } + } fn as_bytes(self: BackendMessage) -> Vec { @@ -184,10 +192,17 @@ impl BackendMessage { bytes.push(self.status); bytes.extend(self.request_id.to_le_bytes()); - bytes.extend(self.data); - + bytes.extend(self.data.encode_to_vec()); return bytes; } + + fn new(request_id: u32, data: tm::BackendMessageData) -> BackendMessage { + BackendMessage{ + status: 0, + request_id, + data, + } + } } const BACKEND_PACKET_HEADER_SIZE: usize = 28; @@ -201,6 +216,7 @@ struct BackendPacket { data: Vec, } +const TM_HEADER: u32 = 0x55D33DAA; impl BackendPacket { fn new(header: u32, timestamp: f64, msg_type: u32, seq_num: u64, data: Vec) -> BackendPacket { return BackendPacket{ @@ -266,14 +282,7 @@ impl ConnectMsg { return ConnectMsg{ version: welcome.version, uuid, - // The TM returns state_valid=0 if last_notice_id < pendingNotices[0].id - // It looks like this check is basically to say "you can't connect if I sent you - // notices you weren't aware of" since in effect it prevents a client from connecting - // if it sends a last_notice_id less than what the TM expects - // To get around it you can just send the (max_u64 - 1) so that the check of - // (last_notice_id + 1) < pendingNotices[0].id always fails. - // The downside is that the TM will not send you any queued notices. - last_notice_id: 0xFFFFFFFFFFFFFFFF - 1, + last_notice_id: 0, username, pass_hash: result.try_into().unwrap(), pw_valid: welcome.pw_valid, @@ -318,7 +327,6 @@ impl ConnectMsg { } } -const NOTICE_MSG_LEN: usize = 8; #[derive(Debug)] struct NoticeMsg { notice_id: u64, @@ -327,28 +335,32 @@ struct NoticeMsg { impl NoticeMsg { fn from_bytes(bytes: Vec) -> Option { - if bytes.len() < NOTICE_MSG_LEN { + if bytes.len() < 8 { return None; } - // TODO: figure out what protobuf is containing the notice so that I don't add a static - // offset - let mut pb_data = Cursor::new(bytes[8..].to_vec()); - match tm::Notice::decode(&mut pb_data) { - Ok(notice) => Some(NoticeMsg{ - notice_id: u64::from_le_bytes(bytes[0..8].try_into().unwrap()), - notice, + let notice_id = u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + + match BackendMessage::from_bytes(bytes[8..].to_vec()) { + Some(message) => { + match message.data.notice { + Some(notice) => Some(NoticeMsg{ + notice_id, + notice, + }), + None => None, } - ), - Err(_) => None, + }, + None => None, } } } struct TMClient { - stream: Arc>>, + stream: openssl::ssl::SslStream, notices: mpsc::Sender, responses: mpsc::Sender, + requests: mpsc::Receiver, uuid: [u8; 16], client_name: [u8; 32], password: String, @@ -357,15 +369,16 @@ struct TMClient { } struct TMConnection { - stream: Arc>>, notices: mpsc::Receiver, responses: mpsc::Receiver, + requests: mpsc::Sender, } impl TMClient { fn new(uuid: [u8; 16], client_name: [u8; 32], password: String, username: [u8; 16]) -> (TMClient, TMConnection) { let (notice_tx, notice_rx) = mpsc::channel(); let (response_tx, response_rx) = mpsc::channel(); + let (request_tx, request_rx) = mpsc::channel(); let mut builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap(); builder.set_ca_file("tm.crt").unwrap(); @@ -382,12 +395,13 @@ impl TMClient { stream_config.set_use_server_name_indication(false); let stream = stream_config.connect("127.0.0.1", stream).unwrap(); - let stream_arc = Arc::new(Mutex::new(stream)); + stream.get_ref().set_read_timeout(Some(Duration::from_millis(100))).expect("Failed to set read timeout on socket"); return (TMClient{ - stream: stream_arc.clone(), + stream, notices: notice_tx, responses: response_tx, + requests: request_rx, uuid, client_name, password, @@ -395,17 +409,22 @@ impl TMClient { username, }, TMConnection{ - stream: stream_arc, + requests: request_tx, notices: notice_rx, responses: response_rx, },); } fn process(self: &mut TMClient) { - + for request in self.requests.try_iter() { + let packet = BackendPacket::new(TM_HEADER, 0.00, 3, self.last_seq_num + 1, request.as_bytes()); + match self.stream.write(&packet.as_bytes()) { + Ok(_) => self.last_seq_num += 1, + Err(error) => println!("Request send error: {:?}", error), + } + } let mut incoming = [0; 2048]; - let mut stream = self.stream.lock().unwrap(); - match stream.read(&mut incoming) { + match self.stream.read(&mut incoming) { Ok(read) => { let data = incoming[0..read].to_vec(); match BackendPacket::from_bytes(data) { @@ -419,33 +438,40 @@ impl TMClient { println!("Received notice: {:?}", notice); let ack = BackendPacket::new(packet.header, packet.timestamp, 5, self.last_seq_num+1, notice.notice_id.to_le_bytes().to_vec()); self.last_seq_num += 1; - match stream.write(&ack.as_bytes()) { + match self.stream.write(&ack.as_bytes()) { Ok(_) => println!("Sent ACK for notice {}", notice.notice_id), Err(error) => println!("ACK error: {:?}", error), } + match self.notices.send(notice.notice) { + Ok(_) => println!("Forwarded notice to callback engine"), + Err(error) => println!("Notice forward error {:?}", error), + } }, - None => println!("Notice error: {:?}", packet), + None => println!("Notice parse error: {:?}", packet), } }, // Response message 3 => { - println!("Received response. TODO: handle these"); + match BackendMessage::from_bytes(packet.data.clone()) { + Some(message) => { + match self.responses.send(message) { + Ok(_) => println!("Forwarded response to callback engine"), + Err(error) => println!("Response forward error {:?}", error), + } + }, + None => println!("BackendMessage parse error: {:?}", packet), + } }, // Server Message 2 => { match ConnectMsg::from_bytes(packet.data) { Some(welcome_msg) => { - println!("Welcome msg: {:?}", welcome_msg); if welcome_msg.pw_valid == 0 { let connect_response = ConnectMsg::from_welcome(welcome_msg, &self.password, self.uuid, self.client_name, self.username); let response = BackendPacket::new(packet.header, packet.timestamp, packet.msg_type, self.last_seq_num+1, connect_response.as_bytes()); - println!("Sending {:X?}", connect_response); - match stream.write(&response.as_bytes()) { + match self.stream.write(&response.as_bytes()) { Err(error) => println!("Send error: {:?}", error), - Ok(sent) => { - println!("Sent {} bytes", sent); - self.last_seq_num += 1; - }, + Ok(_) => self.last_seq_num += 1, } } else if welcome_msg.state_valid == 0 { println!("pw_valid but not state_valid"); @@ -465,27 +491,13 @@ impl TMClient { } } }, - Err(error) => println!("Error: {}", error), + Err(error) => {}, } } } -impl TMConnection { - fn send(self: &mut TMConnection, request: BackendMessage) -> Result { - let mut stream = self.stream.lock().unwrap(); - match stream.write(&request.as_bytes()) { - Ok(_) => - match self.responses.recv() { - Ok(response) => Ok(response), - Err(_) => Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "channel closed")), - }, - Err(error) => Err(error), - } - } -} - -type NoticeCallback = fn(tm::Notice, Event, ) -> (Vec, Event); +type NoticeCallback = fn(tm::Notice, Event) -> (Vec, Event); fn get_game_score(scores: tm::MatchScore) -> Option { if scores.alliances.len() != 2 { @@ -593,8 +605,6 @@ fn main() { div.id = Some(1); test.division = Some(div); - println!("ENCODED_PROTOBUF: {:X?}", test.encode_to_vec()); - let mut event = Event{ name: String::from(""), divisions: Vec::new(), @@ -627,20 +637,24 @@ fn main() { client.publish("bridge/status", QoS::AtLeastOnce, true, "{\"online\": true}").unwrap(); let mqtt_recv_thread = thread::spawn(move || - for _ in connection.iter() { - } - ); + for _ in connection.iter() { + } + ); let running = true; - let uuid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let client_name: [u8;32] = [b'r', b'u', b's', b't', b'-', b'b', b'r', b'i', b'd', b'g', b'e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let mut uuid = [0u8; 16]; + rand::thread_rng().fill_bytes(&mut uuid); + let mut client_name = [0u8;32]; + rand::thread_rng().fill_bytes(&mut client_name); let username: [u8;16] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let (mut tm_client, tm_connection) = TMClient::new(uuid, client_name, String::from(""), username); let tm_thread = thread::spawn(move || - while running { - tm_client.process(); - } - ); + while running { + tm_client.process(); + } + ); + + // TODO: send "get_schedule" message and wait for it's response to know we're connected while running { thread::sleep(Duration::from_millis(1000));