// vex_mqtt_rust/src/main.rs
//
// 784 lines
// 26 KiB
// Rust

use rumqttc::{MqttOptions, Client, QoS, LastWill};
use bytes::Bytes;
use tm::BackendMessageData;
use std::time::Duration;
use std::thread;
use std::collections::hash_map::HashMap;
use prost::Message;
use std::io::Cursor;
use serde::{Serialize, Deserialize};
use rand_core::RngCore;
use std::time::{SystemTime, UNIX_EPOCH};
use std::io::prelude::*;
use sha2::{Sha256, Digest};
use std::net::TcpStream;
use std::sync::mpsc;
// MQTT Topics:
// - division/{division_id}
// - division/{division_id}/ranking
// - arena/{arena_id}/score
// - arena/{arena_id}/state
// - arena/{arena_id}
// - game/{division_id}/{game_id}/score
// - team/{team_string}
/// Protobuf message types for the TM backend protocol, generated at build time into
/// `$OUT_DIR/tm.rs` (presumably by a prost build script — not visible here).
pub mod tm {
include!(concat!(env!("OUT_DIR"), "/tm.rs"));
}
/// JSON payload describing a division. Not constructed anywhere in this file yet;
/// presumably intended for the `division/{division_id}` topic — TODO confirm.
#[derive(Serialize, Deserialize, Debug)]
struct DivisionInfo {
arena: String,
game_id: String,
}
/// JSON payload listing a division's rankings. Not constructed in this file yet;
/// presumably intended for `division/{division_id}/ranking` — TODO confirm.
#[derive(Serialize, Deserialize, Debug)]
struct DivisionRankingInfo {
rankings: Vec<String>,
}
/// Alliance color; used for `GameScore::autonomous_winner`.
#[derive(Serialize, Deserialize, Debug)]
enum GameSide {
Red,
Blue,
}
/// Elevation tier letter (A–J) with explicit discriminants 0–9.
/// NOTE(review): serde's derived impl serializes the variant name (e.g. "C"), not the
/// numeric discriminant.
#[derive(Serialize, Deserialize, Debug)]
enum ElevationTier {
A = 0,
B = 1,
C = 2,
D = 3,
E = 4,
F = 5,
G = 6,
H = 7,
I = 8,
J = 9,
}
/// Per-alliance score breakdown embedded in `GameScore`.
/// Field meanings presumably follow the game's scoring objects (alliance vs. green
/// objects in goals/zones) — TODO confirm against the TM score protobuf.
#[derive(Serialize, Deserialize, Debug)]
struct AllianceScore {
team_goal: usize,
team_zone: usize,
green_goal: usize,
green_zone: usize,
/// Two optional tier slots, presumably one per robot on the alliance — TODO confirm.
elevation_tiers: [Option<ElevationTier>; 2],
}
/// Match score that `on_score_change` serializes to JSON and publishes.
#[derive(Serialize, Deserialize, Debug)]
struct GameScore {
autonomous_winner: Option<GameSide>,
red_score: AllianceScore,
red_total: usize,
blue_score: AllianceScore,
blue_total: usize,
}
/// Match/field state. Not constructed in this file yet; presumably mirrors TM's field
/// states — TODO confirm the mapping.
#[derive(Serialize, Deserialize, Debug)]
enum GameState {
Scheduled,
Timeout,
Driver,
Driverdone,
Autonomous,
AutonomousDone,
Abandoned,
}
/// JSON payload for an arena's state. Not constructed in this file yet; presumably
/// intended for `arena/{arena_id}/state` — TODO confirm.
#[derive(Serialize, Deserialize, Debug)]
struct ArenaStateInfo {
state: Option<GameState>,
/// Presumably the timer start time split into whole seconds and nanoseconds — TODO confirm.
start_s: usize,
start_ns: usize,
}
/// Match round. Each discriminant equals the integer round code that `int_to_round`
/// maps to that variant; `None` (0) is also the fallback for unrecognized codes.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
enum Round {
None = 0,
Practice = 1,
Qualification = 2,
QuarterFinals = 3,
SemiFinals = 4,
Finals = 5,
RoundOf16 = 6,
RoundOf32 = 7,
RoundOf64 = 8,
RoundOf128 = 9,
TopN = 15,
RoundRobin = 16,
PreEliminations = 20,
Eliminations = 21,
}
/// Converts a TM integer round code into a [`Round`].
///
/// A known code maps to the variant whose explicit discriminant equals it; every other
/// value (including 0) yields `Round::None`.
fn int_to_round(round: i32) -> Round {
    const KNOWN: [Round; 13] = [
        Round::Practice,
        Round::Qualification,
        Round::QuarterFinals,
        Round::SemiFinals,
        Round::Finals,
        Round::RoundOf16,
        Round::RoundOf32,
        Round::RoundOf64,
        Round::RoundOf128,
        Round::TopN,
        Round::RoundRobin,
        Round::PreEliminations,
        Round::Eliminations,
    ];
    KNOWN
        .iter()
        .copied()
        .find(|candidate| *candidate as i32 == round)
        .unwrap_or(Round::None)
}
/// Identifies a single match; built from the protobuf match tuple in
/// `Event::parse_match_list`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
struct MatchTuple {
division: i32,
round: Round,
instance: i32,
/// The protobuf `match` field (renamed because `match` is a Rust keyword).
match_num: i32,
session: i32,
}
/// JSON payload describing the teams and match on an arena. Not constructed in this
/// file yet; presumably intended for the `arena/{arena_id}` topic — TODO confirm.
#[derive(Serialize, Deserialize, Debug)]
struct ArenaInfo {
red_teams: [String; 2],
blue_teams: [String; 2],
match_tuple: MatchTuple,
}
/// An MQTT publication returned by a notice callback; `main` publishes it retained
/// with QoS `AtMostOnce`.
struct MQTTMessage {
topic: String,
/// Message body; the callbacks in this file fill it with serialized JSON.
payload: String,
}
/// Bridge-side model of the TM event. `main` passes it by value into each notice
/// callback, which hands back the (possibly updated) event.
#[derive(Debug)]
struct Event {
/// Event name taken from the backend's event config (request 101 in `main`).
name: String,
/// Divisions keyed by the TM division id of their matches.
divisions: HashMap<i32, Division>,
}
impl Event {
    /// Creates an event with the given name and no known divisions.
    fn new(name: String) -> Event {
        Event {
            name,
            divisions: HashMap::new(),
        }
    }

    /// Adds every match in a match-list response to its division.
    ///
    /// The division is created on first sight, with an empty name and no field set.
    /// Matches that lack any of division/round/instance/match/session are skipped with
    /// a warning. Previously they panicked the whole bridge. A response without a
    /// `match_list` is logged and ignored.
    fn parse_match_list(self: &mut Event, msg: BackendMessage) {
        let matches = match msg.data.match_list {
            Some(matches) => matches,
            None => {
                log::warn!("Parsed match list without match_list");
                return;
            }
        };
        for m in matches.matches.iter() {
            let (division, round, instance, match_num, session) =
                match (m.division, m.round, m.instance, m.r#match, m.session) {
                    (Some(d), Some(r), Some(i), Some(n), Some(s)) => (d, r, i, n, s),
                    _ => {
                        log::warn!("Skipping match with missing fields: {:?}", m);
                        continue;
                    }
                };
            let match_tuple = MatchTuple {
                division,
                round: int_to_round(round),
                instance,
                match_num,
                session,
            };
            // Single lookup: create the division only when this id is new.
            self.divisions
                .entry(division)
                .or_insert_with(|| Division {
                    name: String::new(),
                    matches: Vec::new(),
                    field_set: None,
                })
                .matches
                .push(Match { match_tuple });
        }
    }
}
/// One division of the event, stored in `Event::divisions` under its division id.
/// When created by `Event::parse_match_list`, `name` is empty and `field_set` is `None`.
#[derive(Debug)]
struct Division {
name: String,
/// Matches in the order they appeared in the match-list response(s).
matches: Vec<Match>,
field_set: Option<FieldSet>,
}
/// The fields assigned to a division; not populated anywhere in this file yet.
#[derive(Debug)]
struct FieldSet {
fields: Vec<Field>,
}
/// A single competition field; not constructed in this file yet.
/// NOTE(review): unclear whether `current_match` is an index into `Division::matches`
/// or a TM match number — confirm before use.
#[derive(Debug)]
struct Field {
name: String,
current_match: u32,
}
/// A scheduled match, currently identified only by its `MatchTuple`.
#[derive(Debug)]
struct Match {
match_tuple: MatchTuple,
}
/// A request to, or response from, the TM backend.
///
/// Wire layout (see `as_bytes`): 1 status byte, a 4-byte little-endian `request_id`,
/// then the protobuf-encoded `data`.
#[derive(Debug)]
struct BackendMessage {
    status: u8,
    /// NOTE(review): `main` sends 101 (event name) and 1002 (match list) here, so this
    /// looks like a request-type code rather than a unique id — confirm.
    request_id: u32,
    data: tm::BackendMessageData,
}
impl BackendMessage {
    /// Size of the status byte plus the little-endian `request_id`.
    const HEADER_LEN: usize = 5;

    /// Parses a message laid out as produced by `as_bytes`.
    ///
    /// Returns `None` if `bytes` is shorter than the 5-byte header or the remainder is
    /// not a valid `BackendMessageData`.
    fn from_bytes(bytes: Vec<u8>) -> Option<BackendMessage> {
        if bytes.len() < Self::HEADER_LEN {
            return None;
        }
        let (header, payload) = bytes.split_at(Self::HEADER_LEN);
        // Decode straight from the slice; no need to copy the payload first.
        let mut pb_data = Cursor::new(payload);
        let data = tm::BackendMessageData::decode(&mut pb_data).ok()?;
        Some(BackendMessage {
            status: header[0],
            // The id follows the status byte (offsets 1..5). Reading offsets 0..4
            // folded the status byte into the id and dropped its top byte.
            request_id: u32::from_le_bytes(header[1..5].try_into().unwrap()),
            data,
        })
    }

    /// Serializes as status byte, little-endian `request_id`, then the protobuf payload.
    fn as_bytes(self: &BackendMessage) -> Vec<u8> {
        let payload = self.data.encode_to_vec();
        let mut bytes = Vec::with_capacity(Self::HEADER_LEN + payload.len());
        bytes.push(self.status);
        bytes.extend_from_slice(&self.request_id.to_le_bytes());
        bytes.extend_from_slice(&payload);
        bytes
    }

    /// Creates a message with status 0 (the only status this client sends).
    fn new(request_id: u32, data: tm::BackendMessageData) -> BackendMessage {
        BackendMessage {
            status: 0,
            request_id,
            data,
        }
    }
}
/// Size in bytes of the fixed `BackendPacket` header that precedes the payload.
const BACKEND_PACKET_HEADER_SIZE: usize = 28;
/// Framing for every packet exchanged with the TM backend.
///
/// Wire layout (all little-endian): `header` u32, `timestamp` f64, `msg_type` u32,
/// `seq_num` u64, `size` u32, then `size` payload bytes.
#[derive(Debug)]
struct BackendPacket {
    /// Magic value; outgoing requests use `TM_HEADER`.
    header: u32,
    /// Seconds with millisecond resolution; this client stamps requests with time since
    /// the Unix epoch (presumably the backend does the same — TODO confirm).
    timestamp: f64,
    /// Packet type. The client in this file handles 2 (welcome / request),
    /// 3 (response), 4 (notice) and sends 5 (notice ACK).
    msg_type: u32,
    seq_num: u64,
    /// Declared payload length in bytes.
    size: u32,
    data: Vec<u8>,
}
/// Header magic stamped on outgoing request packets.
const TM_HEADER: u32 = 0x55D33DAA;
impl BackendPacket {
    /// Builds a packet around `data`, deriving `size` from the payload length.
    ///
    /// # Panics
    /// Panics if `data` is longer than `u32::MAX` bytes.
    fn new(header: u32, timestamp: f64, msg_type: u32, seq_num: u64, data: Vec<u8>) -> BackendPacket {
        let size: u32 = data
            .len()
            .try_into()
            .expect("BackendPacket payload longer than u32::MAX bytes");
        BackendPacket {
            header,
            timestamp,
            msg_type,
            seq_num,
            size,
            data,
        }
    }

    /// Parses a packet from raw bytes. Returns `None` if the 28-byte header is incomplete.
    ///
    /// Only the declared `size` bytes after the header are taken as the payload. A single
    /// read can deliver more than one packet, and the trailing bytes must not be handed
    /// to the protobuf decoder as part of this one. If fewer bytes than declared arrived,
    /// whatever is present is kept.
    fn from_bytes(bytes: Vec<u8>) -> Option<BackendPacket> {
        if bytes.len() < BACKEND_PACKET_HEADER_SIZE {
            return None;
        }
        let (head, payload) = bytes.split_at(BACKEND_PACKET_HEADER_SIZE);
        let size = u32::from_le_bytes(head[24..28].try_into().unwrap());
        // u32 -> usize is lossless on the 32/64-bit targets this bridge runs on.
        let len = (size as usize).min(payload.len());
        Some(BackendPacket {
            header: u32::from_le_bytes(head[0..4].try_into().unwrap()),
            timestamp: f64::from_le_bytes(head[4..12].try_into().unwrap()),
            msg_type: u32::from_le_bytes(head[12..16].try_into().unwrap()),
            seq_num: u64::from_le_bytes(head[16..24].try_into().unwrap()),
            size,
            data: payload[..len].to_vec(),
        })
    }

    /// Serializes the header fields (little-endian) followed by the payload.
    fn as_bytes(self: &BackendPacket) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(BACKEND_PACKET_HEADER_SIZE + self.data.len());
        bytes.extend_from_slice(&self.header.to_le_bytes());
        bytes.extend_from_slice(&self.timestamp.to_le_bytes());
        bytes.extend_from_slice(&self.msg_type.to_le_bytes());
        bytes.extend_from_slice(&self.seq_num.to_le_bytes());
        bytes.extend_from_slice(&self.size.to_le_bytes());
        bytes.extend_from_slice(&self.data);
        bytes
    }
}
/// Wire size in bytes of a [`ConnectMsg`]: the sum of its fixed-width fields.
const CONNECT_MSG_LEN: usize = 114;
/// Login handshake message. The server sends it as a welcome (packet type 2), and
/// the client answers with one built by `from_welcome`.
///
/// Wire layout (integers little-endian): version u32, uuid [16], last_notice_id u64,
/// username [16], pass_hash [32], pw_valid u8, state_valid u8, client_name [32],
/// server_time_zone i32.
#[derive(Debug)]
struct ConnectMsg {
    version: u32,
    uuid: [u8; 16],
    last_notice_id: u64,
    username: [u8; 16],
    /// In a welcome, the value mixed into the password hash; in the reply, the hash itself.
    pass_hash: [u8; 32],
    pw_valid: u8,
    state_valid: u8,
    client_name: [u8; 32],
    server_time_zone: i32,
}
impl ConnectMsg {
    /// Builds the client's reply to a server welcome.
    ///
    /// The reply's hash is SHA-256 of the welcome's `pass_hash` followed by `password`.
    /// The notice id is reset to 0. Version, validity flags and time zone are echoed
    /// from the welcome. Identity fields come from the arguments.
    fn from_welcome(welcome: ConnectMsg, password: &str, uuid: [u8; 16], client_name: [u8; 32], username: [u8; 16]) -> ConnectMsg {
        let mut hasher = Sha256::new();
        for chunk in [&welcome.pass_hash[..], password.as_bytes()] {
            hasher.update(chunk);
        }
        let digest = hasher.finalize();
        ConnectMsg {
            version: welcome.version,
            uuid,
            last_notice_id: 0,
            username,
            pass_hash: digest.try_into().unwrap(),
            pw_valid: welcome.pw_valid,
            state_valid: welcome.state_valid,
            client_name,
            server_time_zone: welcome.server_time_zone,
        }
    }

    /// Decodes a message. Returns `None` if fewer than `CONNECT_MSG_LEN` bytes are given;
    /// any bytes past that length are ignored.
    fn from_bytes(bytes: Vec<u8>) -> Option<ConnectMsg> {
        if bytes.len() < CONNECT_MSG_LEN {
            return None;
        }
        // Peel fields off the front in wire order.
        let (version, rest) = bytes.split_at(4);
        let (uuid, rest) = rest.split_at(16);
        let (last_notice_id, rest) = rest.split_at(8);
        let (username, rest) = rest.split_at(16);
        let (pass_hash, rest) = rest.split_at(32);
        let (flags, rest) = rest.split_at(2);
        let (client_name, rest) = rest.split_at(32);
        let server_time_zone = &rest[..4];
        Some(ConnectMsg {
            version: u32::from_le_bytes(version.try_into().unwrap()),
            uuid: uuid.try_into().unwrap(),
            last_notice_id: u64::from_le_bytes(last_notice_id.try_into().unwrap()),
            username: username.try_into().unwrap(),
            pass_hash: pass_hash.try_into().unwrap(),
            pw_valid: flags[0],
            state_valid: flags[1],
            client_name: client_name.try_into().unwrap(),
            server_time_zone: i32::from_le_bytes(server_time_zone.try_into().unwrap()),
        })
    }

    /// Encodes the message in the wire layout described on the struct.
    fn as_bytes(self: &ConnectMsg) -> Vec<u8> {
        let mut out = Vec::with_capacity(CONNECT_MSG_LEN);
        out.extend_from_slice(&self.version.to_le_bytes());
        out.extend_from_slice(&self.uuid);
        out.extend_from_slice(&self.last_notice_id.to_le_bytes());
        out.extend_from_slice(&self.username);
        out.extend_from_slice(&self.pass_hash);
        out.push(self.pw_valid);
        out.push(self.state_valid);
        out.extend_from_slice(&self.client_name);
        out.extend_from_slice(&self.server_time_zone.to_le_bytes());
        out
    }
}
/// A notice pushed by the TM backend (packet type 4).
///
/// Wire layout: an 8-byte little-endian `notice_id`, then a `BackendMessage` whose
/// `data.notice` holds the notice itself.
#[derive(Debug)]
struct NoticeMsg {
    /// Id the client echoes back in its ACK.
    notice_id: u64,
    notice: tm::Notice,
}
impl NoticeMsg {
    /// Parses a notice payload. Returns `None` if fewer than 8 bytes are given, the rest
    /// is not a valid `BackendMessage`, or that message carries no notice.
    fn from_bytes(bytes: Vec<u8>) -> Option<NoticeMsg> {
        let id_bytes: [u8; 8] = bytes.get(..8)?.try_into().ok()?;
        let message = BackendMessage::from_bytes(bytes[8..].to_vec())?;
        message.data.notice.map(|notice| NoticeMsg {
            notice_id: u64::from_le_bytes(id_bytes),
            notice,
        })
    }
}
struct TMClient {
stream: openssl::ssl::SslStream<TcpStream>,
notices: mpsc::Sender<Box<tm::Notice>>,
responses: mpsc::Sender<Box<BackendMessage>>,
requests: mpsc::Receiver<Box<BackendMessage>>,
uuid: [u8; 16],
client_name: [u8; 32],
password: String,
last_seq_num: u64,
username: [u8; 16],
connected: bool,
}
struct TMConnection {
notices: mpsc::Receiver<Box<tm::Notice>>,
responses: mpsc::Receiver<Box<BackendMessage>>,
requests: mpsc::Sender<Box<BackendMessage>>,
}
impl TMClient {
fn new(uuid: [u8; 16], client_name: [u8; 32], password: String, username: [u8; 16]) -> (TMClient, TMConnection) {
let (notice_tx, notice_rx) = mpsc::channel();
let (response_tx, response_rx) = mpsc::channel();
let (request_tx, request_rx) = mpsc::channel();
let mut builder = openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls()).unwrap();
builder.set_ca_file("tm.crt").unwrap();
builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
let connector = builder.build();
let stream = TcpStream::connect("127.0.0.1:5000").unwrap();
let mut stream_config = connector.configure().unwrap();
stream_config.set_verify_hostname(false);
stream_config.set_certificate_chain_file("tm.crt").unwrap();
stream_config.set_private_key_file("tm.crt", openssl::ssl::SslFiletype::PEM).unwrap();
stream_config.set_use_server_name_indication(false);
let stream = stream_config.connect("127.0.0.1", stream).unwrap();
stream.get_ref().set_read_timeout(Some(Duration::from_millis(100))).expect("Failed to set read timeout on socket");
return (TMClient{
stream,
notices: notice_tx,
responses: response_tx,
requests: request_rx,
uuid,
client_name,
password,
last_seq_num: 0xFFFFFFFFFFFFFFFF,
username,
connected: false,
},
TMConnection{
requests: request_tx,
notices: notice_rx,
responses: response_rx,
},);
}
fn process(self: &mut TMClient) {
if self.connected == true {
// TODO: right now it's halfway to processing multiple requests at once, but currently
// it only processes a single requests/response at a time. This is fine since there's
// only a single callback thread though.
for request in self.requests.try_iter() {
let time = SystemTime::now();
let millis = time.duration_since(UNIX_EPOCH).unwrap();
let packet = BackendPacket::new(TM_HEADER, (millis.as_millis() as f64)/1000.0, 2, self.last_seq_num + 1, request.as_bytes());
match self.stream.write(&packet.as_bytes()) {
Ok(_) => {
log::debug!("Sent: {:?}", packet);
self.last_seq_num += 1;
},
Err(error) => log::error!("Request send error: {:?}", error),
}
}
}
let mut incoming = [0; 2048];
match self.stream.read(&mut incoming) {
Ok(read) => {
let data = incoming[0..read].to_vec();
match BackendPacket::from_bytes(data) {
Some(packet) => {
log::debug!("Recevied: {:?}", packet);
self.last_seq_num = packet.seq_num;
match packet.msg_type {
// Notice Message
4 => {
match NoticeMsg::from_bytes(packet.data.clone()) {
Some(notice) => {
log::debug!("Received notice: {:?}", notice);
let ack = BackendPacket::new(packet.header, packet.timestamp, 5, self.last_seq_num+1, notice.notice_id.to_le_bytes().to_vec());
self.last_seq_num += 1;
match self.stream.write(&ack.as_bytes()) {
Ok(_) => log::debug!("Sent ACK for notice {}", notice.notice_id),
Err(error) => log::error!("ACK error: {:?}", error),
}
match self.notices.send(Box::new(notice.notice)) {
Ok(_) => log::debug!("Forwarded notice to callback engine"),
Err(error) => log::error!("Notice forward error {:?}", error),
}
},
None => log::error!("Notice parse error: {:?}", packet),
}
},
// Response message
3 => {
match BackendMessage::from_bytes(packet.data.clone()) {
Some(message) => {
match self.responses.send(Box::new(message)) {
Ok(_) => log::debug!("Forwarded response to callback engine"),
Err(error) => log::error!("Response forward error {:?}", error),
}
},
None => log::error!("BackendMessage parse error: {:?}", packet),
}
},
// Server Message
2 => {
match ConnectMsg::from_bytes(packet.data) {
Some(welcome_msg) => {
if welcome_msg.pw_valid == 0 {
let connect_response = ConnectMsg::from_welcome(welcome_msg, &self.password, self.uuid, self.client_name, self.username);
let response = BackendPacket::new(packet.header, packet.timestamp, packet.msg_type, self.last_seq_num+1, connect_response.as_bytes());
match self.stream.write(&response.as_bytes()) {
Err(error) => log::error!("Send error: {:?}", error),
Ok(_) => self.last_seq_num += 1,
}
} else if welcome_msg.state_valid == 0 {
log::error!("pw_valid but not state_valid");
} else {
self.connected = true;
log::info!("Connected to TM backend!");
}
},
None => log::error!("Failed to parse welcome msg"),
}
},
_ => log::warn!("Unhandled message type: {}", packet.msg_type),
}
},
None => {
log::error!("Failed to parse BackendPacket({}): {}", read, String::from_utf8_lossy(&incoming));
// Sleep to prevent busy loop when TM is spamming 0 length packets
thread::sleep(Duration::from_millis(100));
}
}
},
Err(_) => {},
}
}
}
/// Handler for one TM notice type. It takes the notice and the current `Event` by value
/// and returns the MQTT messages to publish together with the (possibly updated) event.
type NoticeCallback = fn(tm::Notice, Event) -> (Vec<MQTTMessage>, Event);
/// Converts a TM match score into the [`GameScore`] published over MQTT.
///
/// Returns `None` unless `scores` contains exactly two alliances. The conversion is
/// still a stub: every count and both totals are 0, and no elevation tiers or
/// autonomous winner are set.
fn get_game_score(scores: tm::MatchScore) -> Option<GameScore> {
    if scores.alliances.len() != 2 {
        return None;
    }
    // TODO: read alliances[0] as red and alliances[1] as blue, then:
    // 1) Get the autonomous winner
    // 2) Get score object and fill AllianceScore struct
    // 3) Compute total scores
    let empty_alliance = || AllianceScore {
        team_goal: 0,
        team_zone: 0,
        green_goal: 0,
        green_zone: 0,
        elevation_tiers: [None, None],
    };
    Some(GameScore {
        autonomous_winner: None,
        red_score: empty_alliance(),
        red_total: 0,
        blue_score: empty_alliance(),
        blue_total: 0,
    })
}
fn on_score_change(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
match notice.match_score {
None => return (Vec::new(), event),
Some(game_scores) => {
match get_game_score(game_scores) {
Some(score_json) => {
let serialized = serde_json::to_string(&score_json).unwrap();
let arena_topic = String::from("arena/TEST/score");
let mut out = Vec::new();
out.push(MQTTMessage{
topic: arena_topic,
payload: serialized,
});
return (out, event);
},
None => return (Vec::new(), event),
}
},
}
}
fn on_match_start(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_match_cancel(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_match_reset(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_match_assigned(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_active_field_changed(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_rankings_updated(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_event_status_updated(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_elim_alliance_update(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_elim_unavail_teams_update(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn on_match_list_update(notice: tm::Notice, event: Event) -> (Vec<MQTTMessage>, Event) {
return (Vec::new(), event);
}
fn main() {
env_logger::init();
let mut callbacks: HashMap<tm::NoticeId, NoticeCallback> = HashMap::new();
callbacks.insert(tm::NoticeId::NoticeRealtimeScoreChanged, on_score_change);
callbacks.insert(tm::NoticeId::NoticeFieldTimerStarted, on_match_start);
callbacks.insert(tm::NoticeId::NoticeFieldTimerStopped, on_match_cancel);
callbacks.insert(tm::NoticeId::NoticeFieldResetTimer, on_match_reset);
callbacks.insert(tm::NoticeId::NoticeFieldMatchAssigned, on_match_assigned);
callbacks.insert(tm::NoticeId::NoticeActiveFieldChanged, on_active_field_changed);
callbacks.insert(tm::NoticeId::NoticeRankingsUpdated, on_rankings_updated);
callbacks.insert(tm::NoticeId::NoticeEventStatusUpdated, on_event_status_updated);
callbacks.insert(tm::NoticeId::NoticeElimAllianceUpdated, on_elim_alliance_update);
callbacks.insert(tm::NoticeId::NoticeElimUnavailTeamsUpdated, on_elim_unavail_teams_update);
callbacks.insert(tm::NoticeId::NoticeMatchListUpdated, on_match_list_update);
let mut mqttoptions = MqttOptions::new("vex-bridge", "localhost", 1883);
mqttoptions.set_keep_alive(Duration::from_secs(5));
mqttoptions.set_last_will(LastWill{
topic: String::from("bridge/status"),
message: Bytes::from("{\"online\": false}"),
qos: QoS::AtLeastOnce,
retain: true,
});
let (mut client, mut connection) = Client::new(mqttoptions, 10);
client.subscribe("bridge", QoS::AtLeastOnce).unwrap();
client.publish("bridge/status", QoS::AtLeastOnce, true, "{\"online\": true}").unwrap();
let mqtt_recv_thread = thread::spawn(move ||
for _ in connection.iter() {
}
);
let running = true;
let mut uuid = [0u8; 16];
rand::thread_rng().fill_bytes(&mut uuid);
let mut client_name = [0u8;32];
rand::thread_rng().fill_bytes(&mut client_name);
let username: [u8;16] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let (mut tm_client, tm_connection) = TMClient::new(uuid, client_name, String::from(""), username);
let tm_thread = thread::spawn(move ||
while running {
tm_client.process();
}
);
let get_event_name_req = BackendMessage::new(101, tm::BackendMessageData::default());
tm_connection.requests.send(Box::new(get_event_name_req)).unwrap();
let get_event_name_resp = tm_connection.responses.recv().unwrap();
let mut event = Event::new(String::from(get_event_name_resp.data.event_config.unwrap().event_name.unwrap()));
let mut get_match_list_tuple = tm::MatchTuple::default();
get_match_list_tuple.division = Some(0);
get_match_list_tuple.round = None;
get_match_list_tuple.instance = Some(0);
get_match_list_tuple.r#match = Some(0);
get_match_list_tuple.session = Some(0);
let mut get_match_list_data = tm::BackendMessageData::default();
get_match_list_data.match_tuple = Some(get_match_list_tuple);
let get_match_list_req = BackendMessage::new(1002, get_match_list_data.clone());
tm_connection.requests.send(Box::new(get_match_list_req)).unwrap();
let get_match_list_resp = tm_connection.responses.recv().unwrap();
event.parse_match_list(*get_match_list_resp);
println!("Event after parse: {:?}", event);
while running {
thread::sleep(Duration::from_millis(1000));
match tm_connection.notices.recv() {
Ok(notice) => {
let callback = callbacks.get(&notice.id());
match callback {
None => {
match notice.id {
None => log::error!("Notice without NoticeId received"),
Some(notice_id) => log::warn!("Unhandled NoticeId: {}", notice_id),
}
},
Some(callback) => {
let (messages, next_event) = callback(*notice, event);
event = next_event;
for message in messages {
let result = client.publish(message.topic, QoS::AtMostOnce, true, message.payload);
match result {
Ok(_) => {},
Err(error) => log::error!("Publish error: {}", error),
}
}
},
}
},
Err(error) => log::error!("Notice recv error: {}", error),
}
}
mqtt_recv_thread.join().expect("Failed to join mqtt thread");
tm_thread.join().expect("Failed to join tm connection thread");
}