More parsing fixes

master
noah metz 2024-01-20 23:28:58 -07:00
parent b25fa4e36d
commit 9ff15e3864
3 changed files with 131 additions and 77 deletions

38
Cargo.lock generated

@ -540,6 +540,12 @@ version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.16" version = "0.2.16"
@ -632,6 +638,36 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.4.1" version = "0.4.1"
@ -1016,6 +1052,8 @@ dependencies = [
"prost-build", "prost-build",
"prost-types", "prost-types",
"protoc", "protoc",
"rand",
"rand_core",
"rumqttc", "rumqttc",
"serde", "serde",
"serde_json", "serde_json",

@ -15,6 +15,8 @@ env_logger = "0.11.0"
openssl = "0.10.63" openssl = "0.10.63"
prost = "0.12.3" prost = "0.12.3"
prost-types = "0.12.3" prost-types = "0.12.3"
rand = "0.8.5"
rand_core = "0.6.4"
rumqttc = "0.23.0" rumqttc = "0.23.0"
serde = { version = "1.0.195", features = ["derive"]} serde = { version = "1.0.195", features = ["derive"]}
serde_json = "1.0.111" serde_json = "1.0.111"

@ -2,20 +2,18 @@ use rumqttc::{MqttOptions, Client, QoS, LastWill};
use bytes::Bytes; use bytes::Bytes;
use std::time::Duration; use std::time::Duration;
use std::thread; use std::thread;
use std::sync::mpsc;
use std::collections::hash_map::HashMap; use std::collections::hash_map::HashMap;
use prost::Message; use prost::Message;
use std::io::Cursor; use std::io::Cursor;
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use rand_core::RngCore;
use std::io::prelude::*; use std::io::prelude::*;
use sha2::Digest; use sha2::{Sha256, Digest};
use sha2::Sha256;
use std::net::TcpStream; use std::net::TcpStream;
use std::sync::Mutex; use std::sync::mpsc;
use std::sync::Arc;
// MQTT Topics: // MQTT Topics:
// - division/{division_id} // - division/{division_id}
@ -160,10 +158,11 @@ struct Match {
name: String, name: String,
} }
#[derive(Debug)]
struct BackendMessage { struct BackendMessage {
status: u8, status: u8,
request_id: u32, request_id: u32,
data: Vec<u8>, data: tm::BackendMessageData,
} }
impl BackendMessage { impl BackendMessage {
@ -172,11 +171,20 @@ impl BackendMessage {
return None; return None;
} }
return Some(BackendMessage{ let mut pb_data = Cursor::new(bytes[5..].to_vec());
status: bytes[0], match tm::BackendMessageData::decode(&mut pb_data) {
request_id: u32::from_le_bytes(bytes[0..4].try_into().unwrap()), Ok(data) =>
data: bytes[5..].to_vec(), Some(BackendMessage{
}); status: bytes[0],
request_id: u32::from_le_bytes(bytes[0..4].try_into().unwrap()),
data,
}),
Err(error) => {
println!("Protobuf parse error: {:?}", error);
return None;
},
}
} }
fn as_bytes(self: BackendMessage) -> Vec<u8> { fn as_bytes(self: BackendMessage) -> Vec<u8> {
@ -184,10 +192,17 @@ impl BackendMessage {
bytes.push(self.status); bytes.push(self.status);
bytes.extend(self.request_id.to_le_bytes()); bytes.extend(self.request_id.to_le_bytes());
bytes.extend(self.data); bytes.extend(self.data.encode_to_vec());
return bytes; return bytes;
} }
fn new(request_id: u32, data: tm::BackendMessageData) -> BackendMessage {
BackendMessage{
status: 0,
request_id,
data,
}
}
} }
const BACKEND_PACKET_HEADER_SIZE: usize = 28; const BACKEND_PACKET_HEADER_SIZE: usize = 28;
@ -201,6 +216,7 @@ struct BackendPacket {
data: Vec<u8>, data: Vec<u8>,
} }
const TM_HEADER: u32 = 0x55D33DAA;
impl BackendPacket { impl BackendPacket {
fn new(header: u32, timestamp: f64, msg_type: u32, seq_num: u64, data: Vec<u8>) -> BackendPacket { fn new(header: u32, timestamp: f64, msg_type: u32, seq_num: u64, data: Vec<u8>) -> BackendPacket {
return BackendPacket{ return BackendPacket{
@ -266,14 +282,7 @@ impl ConnectMsg {
return ConnectMsg{ return ConnectMsg{
version: welcome.version, version: welcome.version,
uuid, uuid,
// The TM returns state_valid=0 if last_notice_id < pendingNotices[0].id last_notice_id: 0,
// It looks like this check is basically to say "you can't connect if I sent you
// notices you weren't aware of" since in effect it prevents a client from connecting
// if it sends a last_notice_id less than what the TM expects
// To get around it you can just send the (max_u64 - 1) so that the check of
// (last_notice_id + 1) < pendingNotices[0].id always fails.
// The downside is that the TM will not send you any queued notices.
last_notice_id: 0xFFFFFFFFFFFFFFFF - 1,
username, username,
pass_hash: result.try_into().unwrap(), pass_hash: result.try_into().unwrap(),
pw_valid: welcome.pw_valid, pw_valid: welcome.pw_valid,
@ -318,7 +327,6 @@ impl ConnectMsg {
} }
} }
const NOTICE_MSG_LEN: usize = 8;
#[derive(Debug)] #[derive(Debug)]
struct NoticeMsg { struct NoticeMsg {
notice_id: u64, notice_id: u64,
@ -327,28 +335,32 @@ struct NoticeMsg {
impl NoticeMsg { impl NoticeMsg {
fn from_bytes(bytes: Vec<u8>) -> Option<NoticeMsg> { fn from_bytes(bytes: Vec<u8>) -> Option<NoticeMsg> {
if bytes.len() < NOTICE_MSG_LEN { if bytes.len() < 8 {
return None; return None;
} }
// TODO: figure out what protobuf is containing the notice so that I don't add a static let notice_id = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
// offset
let mut pb_data = Cursor::new(bytes[8..].to_vec()); match BackendMessage::from_bytes(bytes[8..].to_vec()) {
match tm::Notice::decode(&mut pb_data) { Some(message) => {
Ok(notice) => Some(NoticeMsg{ match message.data.notice {
notice_id: u64::from_le_bytes(bytes[0..8].try_into().unwrap()), Some(notice) => Some(NoticeMsg{
notice, notice_id,
notice,
}),
None => None,
} }
), },
Err(_) => None, None => None,
} }
} }
} }
struct TMClient { struct TMClient {
stream: Arc<Mutex<openssl::ssl::SslStream<TcpStream>>>, stream: openssl::ssl::SslStream<TcpStream>,
notices: mpsc::Sender<tm::Notice>, notices: mpsc::Sender<tm::Notice>,
responses: mpsc::Sender<BackendMessage>, responses: mpsc::Sender<BackendMessage>,
requests: mpsc::Receiver<BackendMessage>,
uuid: [u8; 16], uuid: [u8; 16],
client_name: [u8; 32], client_name: [u8; 32],
password: String, password: String,
@ -357,15 +369,16 @@ struct TMClient {
} }
struct TMConnection { struct TMConnection {
stream: Arc<Mutex<openssl::ssl::SslStream<TcpStream>>>,
notices: mpsc::Receiver<tm::Notice>, notices: mpsc::Receiver<tm::Notice>,
responses: mpsc::Receiver<BackendMessage>, responses: mpsc::Receiver<BackendMessage>,
requests: mpsc::Sender<BackendMessage>,
} }
impl TMClient { impl TMClient {
fn new(uuid: [u8; 16], client_name: [u8; 32], password: String, username: [u8; 16]) -> (TMClient, TMConnection) { fn new(uuid: [u8; 16], client_name: [u8; 32], password: String, username: [u8; 16]) -> (TMClient, TMConnection) {
let (notice_tx, notice_rx) = mpsc::channel(); let (notice_tx, notice_rx) = mpsc::channel();
let (response_tx, response_rx) = mpsc::channel(); let (response_tx, response_rx) = mpsc::channel();
let (request_tx, request_rx) = mpsc::channel();
let mut builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap(); let mut builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap();
builder.set_ca_file("tm.crt").unwrap(); builder.set_ca_file("tm.crt").unwrap();
@ -382,12 +395,13 @@ impl TMClient {
stream_config.set_use_server_name_indication(false); stream_config.set_use_server_name_indication(false);
let stream = stream_config.connect("127.0.0.1", stream).unwrap(); let stream = stream_config.connect("127.0.0.1", stream).unwrap();
let stream_arc = Arc::new(Mutex::new(stream)); stream.get_ref().set_read_timeout(Some(Duration::from_millis(100))).expect("Failed to set read timeout on socket");
return (TMClient{ return (TMClient{
stream: stream_arc.clone(), stream,
notices: notice_tx, notices: notice_tx,
responses: response_tx, responses: response_tx,
requests: request_rx,
uuid, uuid,
client_name, client_name,
password, password,
@ -395,17 +409,22 @@ impl TMClient {
username, username,
}, },
TMConnection{ TMConnection{
stream: stream_arc, requests: request_tx,
notices: notice_rx, notices: notice_rx,
responses: response_rx, responses: response_rx,
},); },);
} }
fn process(self: &mut TMClient) { fn process(self: &mut TMClient) {
for request in self.requests.try_iter() {
let packet = BackendPacket::new(TM_HEADER, 0.00, 3, self.last_seq_num + 1, request.as_bytes());
match self.stream.write(&packet.as_bytes()) {
Ok(_) => self.last_seq_num += 1,
Err(error) => println!("Request send error: {:?}", error),
}
}
let mut incoming = [0; 2048]; let mut incoming = [0; 2048];
let mut stream = self.stream.lock().unwrap(); match self.stream.read(&mut incoming) {
match stream.read(&mut incoming) {
Ok(read) => { Ok(read) => {
let data = incoming[0..read].to_vec(); let data = incoming[0..read].to_vec();
match BackendPacket::from_bytes(data) { match BackendPacket::from_bytes(data) {
@ -419,33 +438,40 @@ impl TMClient {
println!("Received notice: {:?}", notice); println!("Received notice: {:?}", notice);
let ack = BackendPacket::new(packet.header, packet.timestamp, 5, self.last_seq_num+1, notice.notice_id.to_le_bytes().to_vec()); let ack = BackendPacket::new(packet.header, packet.timestamp, 5, self.last_seq_num+1, notice.notice_id.to_le_bytes().to_vec());
self.last_seq_num += 1; self.last_seq_num += 1;
match stream.write(&ack.as_bytes()) { match self.stream.write(&ack.as_bytes()) {
Ok(_) => println!("Sent ACK for notice {}", notice.notice_id), Ok(_) => println!("Sent ACK for notice {}", notice.notice_id),
Err(error) => println!("ACK error: {:?}", error), Err(error) => println!("ACK error: {:?}", error),
} }
match self.notices.send(notice.notice) {
Ok(_) => println!("Forwarded notice to callback engine"),
Err(error) => println!("Notice forward error {:?}", error),
}
}, },
None => println!("Notice error: {:?}", packet), None => println!("Notice parse error: {:?}", packet),
} }
}, },
// Response message // Response message
3 => { 3 => {
println!("Received response. TODO: handle these"); match BackendMessage::from_bytes(packet.data.clone()) {
Some(message) => {
match self.responses.send(message) {
Ok(_) => println!("Forwarded response to callback engine"),
Err(error) => println!("Response forward error {:?}", error),
}
},
None => println!("BackendMessage parse error: {:?}", packet),
}
}, },
// Server Message // Server Message
2 => { 2 => {
match ConnectMsg::from_bytes(packet.data) { match ConnectMsg::from_bytes(packet.data) {
Some(welcome_msg) => { Some(welcome_msg) => {
println!("Welcome msg: {:?}", welcome_msg);
if welcome_msg.pw_valid == 0 { if welcome_msg.pw_valid == 0 {
let connect_response = ConnectMsg::from_welcome(welcome_msg, &self.password, self.uuid, self.client_name, self.username); let connect_response = ConnectMsg::from_welcome(welcome_msg, &self.password, self.uuid, self.client_name, self.username);
let response = BackendPacket::new(packet.header, packet.timestamp, packet.msg_type, self.last_seq_num+1, connect_response.as_bytes()); let response = BackendPacket::new(packet.header, packet.timestamp, packet.msg_type, self.last_seq_num+1, connect_response.as_bytes());
println!("Sending {:X?}", connect_response); match self.stream.write(&response.as_bytes()) {
match stream.write(&response.as_bytes()) {
Err(error) => println!("Send error: {:?}", error), Err(error) => println!("Send error: {:?}", error),
Ok(sent) => { Ok(_) => self.last_seq_num += 1,
println!("Sent {} bytes", sent);
self.last_seq_num += 1;
},
} }
} else if welcome_msg.state_valid == 0 { } else if welcome_msg.state_valid == 0 {
println!("pw_valid but not state_valid"); println!("pw_valid but not state_valid");
@ -465,27 +491,13 @@ impl TMClient {
} }
} }
}, },
Err(error) => println!("Error: {}", error), Err(error) => {},
} }
} }
} }
impl TMConnection { type NoticeCallback = fn(tm::Notice, Event) -> (Vec<MQTTMessage>, Event);
fn send(self: &mut TMConnection, request: BackendMessage) -> Result<BackendMessage, std::io::Error> {
let mut stream = self.stream.lock().unwrap();
match stream.write(&request.as_bytes()) {
Ok(_) =>
match self.responses.recv() {
Ok(response) => Ok(response),
Err(_) => Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "channel closed")),
},
Err(error) => Err(error),
}
}
}
type NoticeCallback = fn(tm::Notice, Event, ) -> (Vec<MQTTMessage>, Event);
fn get_game_score(scores: tm::MatchScore) -> Option<GameScore> { fn get_game_score(scores: tm::MatchScore) -> Option<GameScore> {
if scores.alliances.len() != 2 { if scores.alliances.len() != 2 {
@ -593,8 +605,6 @@ fn main() {
div.id = Some(1); div.id = Some(1);
test.division = Some(div); test.division = Some(div);
println!("ENCODED_PROTOBUF: {:X?}", test.encode_to_vec());
let mut event = Event{ let mut event = Event{
name: String::from(""), name: String::from(""),
divisions: Vec::new(), divisions: Vec::new(),
@ -627,20 +637,24 @@ fn main() {
client.publish("bridge/status", QoS::AtLeastOnce, true, "{\"online\": true}").unwrap(); client.publish("bridge/status", QoS::AtLeastOnce, true, "{\"online\": true}").unwrap();
let mqtt_recv_thread = thread::spawn(move || let mqtt_recv_thread = thread::spawn(move ||
for _ in connection.iter() { for _ in connection.iter() {
} }
); );
let running = true; let running = true;
let uuid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; let mut uuid = [0u8; 16];
let client_name: [u8;32] = [b'r', b'u', b's', b't', b'-', b'b', b'r', b'i', b'd', b'g', b'e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; rand::thread_rng().fill_bytes(&mut uuid);
let mut client_name = [0u8;32];
rand::thread_rng().fill_bytes(&mut client_name);
let username: [u8;16] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; let username: [u8;16] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let (mut tm_client, tm_connection) = TMClient::new(uuid, client_name, String::from(""), username); let (mut tm_client, tm_connection) = TMClient::new(uuid, client_name, String::from(""), username);
let tm_thread = thread::spawn(move || let tm_thread = thread::spawn(move ||
while running { while running {
tm_client.process(); tm_client.process();
} }
); );
// TODO: send "get_schedule" message and wait for it's response to know we're connected
while running { while running {
thread::sleep(Duration::from_millis(1000)); thread::sleep(Duration::from_millis(1000));