diff --git a/src/routes/websocket.rs b/src/routes/websocket.rs index a13a96c..33aad9a 100644 --- a/src/routes/websocket.rs +++ b/src/routes/websocket.rs @@ -2,7 +2,7 @@ use std::{net::SocketAddr, ops::ControlFlow, sync::Arc}; use axum::{ extract::{ - ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, + ws::{Message, WebSocket, WebSocketUpgrade}, ConnectInfo, State }, response::IntoResponse @@ -38,19 +38,23 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc = None; if let Some(Ok(msg)) = socket.recv().await { - if let Message::Text(txt) = msg && let Ok(auth_payload) = serde_json::from_str::(&txt) { + if let Message::Text(txt) = &msg && let Ok(auth_payload) = serde_json::from_str::(&txt) { if let Ok(claims) = claims_from_token(auth_payload) { claims_t = Some(claims); } else { log::debug!(target: "websocket", "{who} tried to authenticate with wrong token"); } } else { - log::debug!(target: "websocket", "{who} send an invalid payload before logging in") + log::debug!(target: "websocket", "{who} send an invalid payload before logging in: {}", &msg.clone().to_text().unwrap_or("")) } } match claims_t { Some(claims) => { + if let Err(_) = socket.send(WebsocketMessage::AuthSuccess.to_text_message()).await { + log::debug!(target: "websocket", "Could not send auth success message to {who}"); + return; + }; log::debug!(target: "websocket", "{who} successfully authenticated on the websocket"); // Socket is authenticated, go on let (mut sender, mut receiver) = socket.split(); @@ -73,7 +77,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc { // Socket was not authenticated, abort the mission - let _ = socket.send(Message::Text(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_json().to_string().into())).await; + let _ = socket.send(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_text_message()).await; return; } } diff --git a/src/utils/events.rs b/src/utils/events.rs index 348a360..55c0ef0 100644 --- a/src/utils/events.rs +++ b/src/utils/events.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::ws::{Message, Utf8Bytes}; use serde_json::{json, Value}; use crate::entities::owner; @@ -11,6 +12,7 @@ pub enum Event { #[derive(Clone, Debug)] pub enum WebsocketMessage { + AuthSuccess, NewOwner(Arc), Error(String), Ping @@ -20,11 +22,13 @@ impl WebsocketMessage { pub fn to_json(&self) -> Value { json!({ "type": match self { + Self::AuthSuccess => "auth_success", Self::NewOwner(_) => "new_owner", Self::Error(_) => "error", Self::Ping => "ping", }, "data": match self { + Self::AuthSuccess => json!(null), Self::NewOwner(owner) => json!(owner), Self::Error(error) => json!(error), Self::Ping => json!(null), @@ -32,6 +36,10 @@ impl WebsocketMessage { }) } + pub fn to_text_message(&self) -> Message { + Message::Text(Utf8Bytes::from(self.to_json().to_string())) + } + pub fn should_user_receive(&self, user_id: u32) -> bool { match self { Self::NewOwner(owner) => owner.user_id == user_id,