More parsing fixes

master
noah metz 2024-01-20 23:28:58 -07:00
parent b25fa4e36d
commit 9ff15e3864
3 changed files with 131 additions and 77 deletions

38
Cargo.lock generated

@ -540,6 +540,12 @@ version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
version = "0.2.16"
@ -632,6 +638,36 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "redox_syscall"
version = "0.4.1"
@ -1016,6 +1052,8 @@ dependencies = [
"prost-build",
"prost-types",
"protoc",
"rand",
"rand_core",
"rumqttc",
"serde",
"serde_json",

@ -15,6 +15,8 @@ env_logger = "0.11.0"
openssl = "0.10.63"
prost = "0.12.3"
prost-types = "0.12.3"
rand = "0.8.5"
rand_core = "0.6.4"
rumqttc = "0.23.0"
serde = { version = "1.0.195", features = ["derive"]}
serde_json = "1.0.111"

@ -2,20 +2,18 @@ use rumqttc::{MqttOptions, Client, QoS, LastWill};
use bytes::Bytes;
use std::time::Duration;
use std::thread;
use std::sync::mpsc;
use std::collections::hash_map::HashMap;
use prost::Message;
use std::io::Cursor;
use serde::{Serialize, Deserialize};
use rand_core::RngCore;
use std::io::prelude::*;
use sha2::Digest;
use sha2::Sha256;
use sha2::{Sha256, Digest};
use std::net::TcpStream;
use std::sync::Mutex;
use std::sync::Arc;
use std::sync::mpsc;
// MQTT Topics:
// - division/{division_id}
@ -160,10 +158,11 @@ struct Match {
name: String,
}
#[derive(Debug)]
struct BackendMessage {
status: u8,
request_id: u32,
data: Vec<u8>,
data: tm::BackendMessageData,
}
impl BackendMessage {
@ -172,11 +171,20 @@ impl BackendMessage {
return None;
}
return Some(BackendMessage{
status: bytes[0],
request_id: u32::from_le_bytes(bytes[0..4].try_into().unwrap()),
data: bytes[5..].to_vec(),
});
let mut pb_data = Cursor::new(bytes[5..].to_vec());
match tm::BackendMessageData::decode(&mut pb_data) {
Ok(data) =>
Some(BackendMessage{
status: bytes[0],
request_id: u32::from_le_bytes(bytes[0..4].try_into().unwrap()),
data,
}),
Err(error) => {
println!("Protobuf parse error: {:?}", error);
return None;
},
}
}
fn as_bytes(self: BackendMessage) -> Vec<u8> {
@ -184,10 +192,17 @@ impl BackendMessage {
bytes.push(self.status);
bytes.extend(self.request_id.to_le_bytes());
bytes.extend(self.data);
bytes.extend(self.data.encode_to_vec());
return bytes;
}
fn new(request_id: u32, data: tm::BackendMessageData) -> BackendMessage {
BackendMessage{
status: 0,
request_id,
data,
}
}
}
const BACKEND_PACKET_HEADER_SIZE: usize = 28;
@ -201,6 +216,7 @@ struct BackendPacket {
data: Vec<u8>,
}
const TM_HEADER: u32 = 0x55D33DAA;
impl BackendPacket {
fn new(header: u32, timestamp: f64, msg_type: u32, seq_num: u64, data: Vec<u8>) -> BackendPacket {
return BackendPacket{
@ -266,14 +282,7 @@ impl ConnectMsg {
return ConnectMsg{
version: welcome.version,
uuid,
// The TM returns state_valid=0 if last_notice_id < pendingNotices[0].id
// It looks like this check is basically to say "you can't connect if I sent you
// notices you weren't aware of" since in effect it prevents a client from connecting
// if it sends a last_notice_id less than what the TM expects
// To get around it you can just send the (max_u64 - 1) so that the check of
// (last_notice_id + 1) < pendingNotices[0].id always fails.
// The downside is that the TM will not send you any queued notices.
last_notice_id: 0xFFFFFFFFFFFFFFFF - 1,
last_notice_id: 0,
username,
pass_hash: result.try_into().unwrap(),
pw_valid: welcome.pw_valid,
@ -318,7 +327,6 @@ impl ConnectMsg {
}
}
const NOTICE_MSG_LEN: usize = 8;
#[derive(Debug)]
struct NoticeMsg {
notice_id: u64,
@ -327,28 +335,32 @@ struct NoticeMsg {
impl NoticeMsg {
fn from_bytes(bytes: Vec<u8>) -> Option<NoticeMsg> {
if bytes.len() < NOTICE_MSG_LEN {
if bytes.len() < 8 {
return None;
}
// TODO: figure out what protobuf is containing the notice so that I don't add a static
// offset
let mut pb_data = Cursor::new(bytes[8..].to_vec());
match tm::Notice::decode(&mut pb_data) {
Ok(notice) => Some(NoticeMsg{
notice_id: u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
notice,
let notice_id = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
match BackendMessage::from_bytes(bytes[8..].to_vec()) {
Some(message) => {
match message.data.notice {
Some(notice) => Some(NoticeMsg{
notice_id,
notice,
}),
None => None,
}
),
Err(_) => None,
},
None => None,
}
}
}
struct TMClient {
stream: Arc<Mutex<openssl::ssl::SslStream<TcpStream>>>,
stream: openssl::ssl::SslStream<TcpStream>,
notices: mpsc::Sender<tm::Notice>,
responses: mpsc::Sender<BackendMessage>,
requests: mpsc::Receiver<BackendMessage>,
uuid: [u8; 16],
client_name: [u8; 32],
password: String,
@ -357,15 +369,16 @@ struct TMClient {
}
struct TMConnection {
stream: Arc<Mutex<openssl::ssl::SslStream<TcpStream>>>,
notices: mpsc::Receiver<tm::Notice>,
responses: mpsc::Receiver<BackendMessage>,
requests: mpsc::Sender<BackendMessage>,
}
impl TMClient {
fn new(uuid: [u8; 16], client_name: [u8; 32], password: String, username: [u8; 16]) -> (TMClient, TMConnection) {
let (notice_tx, notice_rx) = mpsc::channel();
let (response_tx, response_rx) = mpsc::channel();
let (request_tx, request_rx) = mpsc::channel();
let mut builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap();
builder.set_ca_file("tm.crt").unwrap();
@ -382,12 +395,13 @@ impl TMClient {
stream_config.set_use_server_name_indication(false);
let stream = stream_config.connect("127.0.0.1", stream).unwrap();
let stream_arc = Arc::new(Mutex::new(stream));
stream.get_ref().set_read_timeout(Some(Duration::from_millis(100))).expect("Failed to set read timeout on socket");
return (TMClient{
stream: stream_arc.clone(),
stream,
notices: notice_tx,
responses: response_tx,
requests: request_rx,
uuid,
client_name,
password,
@ -395,17 +409,22 @@ impl TMClient {
username,
},
TMConnection{
stream: stream_arc,
requests: request_tx,
notices: notice_rx,
responses: response_rx,
},);
}
fn process(self: &mut TMClient) {
for request in self.requests.try_iter() {
let packet = BackendPacket::new(TM_HEADER, 0.00, 3, self.last_seq_num + 1, request.as_bytes());
match self.stream.write(&packet.as_bytes()) {
Ok(_) => self.last_seq_num += 1,
Err(error) => println!("Request send error: {:?}", error),
}
}
let mut incoming = [0; 2048];
let mut stream = self.stream.lock().unwrap();
match stream.read(&mut incoming) {
match self.stream.read(&mut incoming) {
Ok(read) => {
let data = incoming[0..read].to_vec();
match BackendPacket::from_bytes(data) {
@ -419,33 +438,40 @@ impl TMClient {
println!("Received notice: {:?}", notice);
let ack = BackendPacket::new(packet.header, packet.timestamp, 5, self.last_seq_num+1, notice.notice_id.to_le_bytes().to_vec());
self.last_seq_num += 1;
match stream.write(&ack.as_bytes()) {
match self.stream.write(&ack.as_bytes()) {
Ok(_) => println!("Sent ACK for notice {}", notice.notice_id),
Err(error) => println!("ACK error: {:?}", error),
}
match self.notices.send(notice.notice) {
Ok(_) => println!("Forwarded notice to callback engine"),
Err(error) => println!("Notice forward error {:?}", error),
}
},
None => println!("Notice error: {:?}", packet),
None => println!("Notice parse error: {:?}", packet),
}
},
// Response message
3 => {
println!("Received response. TODO: handle these");
match BackendMessage::from_bytes(packet.data.clone()) {
Some(message) => {
match self.responses.send(message) {
Ok(_) => println!("Forwarded response to callback engine"),
Err(error) => println!("Response forward error {:?}", error),
}
},
None => println!("BackendMessage parse error: {:?}", packet),
}
},
// Server Message
2 => {
match ConnectMsg::from_bytes(packet.data) {
Some(welcome_msg) => {
println!("Welcome msg: {:?}", welcome_msg);
if welcome_msg.pw_valid == 0 {
let connect_response = ConnectMsg::from_welcome(welcome_msg, &self.password, self.uuid, self.client_name, self.username);
let response = BackendPacket::new(packet.header, packet.timestamp, packet.msg_type, self.last_seq_num+1, connect_response.as_bytes());
println!("Sending {:X?}", connect_response);
match stream.write(&response.as_bytes()) {
match self.stream.write(&response.as_bytes()) {
Err(error) => println!("Send error: {:?}", error),
Ok(sent) => {
println!("Sent {} bytes", sent);
self.last_seq_num += 1;
},
Ok(_) => self.last_seq_num += 1,
}
} else if welcome_msg.state_valid == 0 {
println!("pw_valid but not state_valid");
@ -465,27 +491,13 @@ impl TMClient {
}
}
},
Err(error) => println!("Error: {}", error),
Err(error) => {},
}
}
}
impl TMConnection {
fn send(self: &mut TMConnection, request: BackendMessage) -> Result<BackendMessage, std::io::Error> {
let mut stream = self.stream.lock().unwrap();
match stream.write(&request.as_bytes()) {
Ok(_) =>
match self.responses.recv() {
Ok(response) => Ok(response),
Err(_) => Err(std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "channel closed")),
},
Err(error) => Err(error),
}
}
}
type NoticeCallback = fn(tm::Notice, Event, ) -> (Vec<MQTTMessage>, Event);
type NoticeCallback = fn(tm::Notice, Event) -> (Vec<MQTTMessage>, Event);
fn get_game_score(scores: tm::MatchScore) -> Option<GameScore> {
if scores.alliances.len() != 2 {
@ -593,8 +605,6 @@ fn main() {
div.id = Some(1);
test.division = Some(div);
println!("ENCODED_PROTOBUF: {:X?}", test.encode_to_vec());
let mut event = Event{
name: String::from(""),
divisions: Vec::new(),
@ -627,20 +637,24 @@ fn main() {
client.publish("bridge/status", QoS::AtLeastOnce, true, "{\"online\": true}").unwrap();
let mqtt_recv_thread = thread::spawn(move ||
for _ in connection.iter() {
}
);
for _ in connection.iter() {
}
);
let running = true;
let uuid = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let client_name: [u8;32] = [b'r', b'u', b's', b't', b'-', b'b', b'r', b'i', b'd', b'g', b'e', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let mut uuid = [0u8; 16];
rand::thread_rng().fill_bytes(&mut uuid);
let mut client_name = [0u8;32];
rand::thread_rng().fill_bytes(&mut client_name);
let username: [u8;16] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let (mut tm_client, tm_connection) = TMClient::new(uuid, client_name, String::from(""), username);
let tm_thread = thread::spawn(move ||
while running {
tm_client.process();
}
);
while running {
tm_client.process();
}
);
// TODO: send "get_schedule" message and wait for it's response to know we're connected
while running {
thread::sleep(Duration::from_millis(1000));