diff --git a/src/main.rs b/src/main.rs index 10980e4..8af7fbb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -182,14 +182,14 @@ async fn run_server(db: Arc) { .routes(routes!(routes::bal::create_bal)) .routes(routes!(routes::bal::update_bal)) .routes(routes!(routes::bal::get_bals)) - // Misc - .routes(routes!(routes::websocket::ws_handler)) // Authentication .route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware)) .routes(routes!(routes::auth::auth)) .routes(routes!(routes::auth::check_token)) - + // Misc + .routes(routes!(routes::websocket::ws_handler)) .route("/", get(index)) + .with_state(shared_state) .split_for_parts(); diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 736d797..f972682 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -84,6 +84,14 @@ pub async fn check_token(Json(payload): Json) -> Json { } } +pub fn claims_from_token(payload: TokenPayload) -> Result { + let token_data = decode::(&payload.token, &KEYS.decoding, &Validation::default()); + match token_data { + Ok(data) => Ok(data.claims), + Err(e) => Err(e) + } +} + impl AuthBody { fn new(access_token: String) -> Self { Self { diff --git a/src/routes/websocket.rs b/src/routes/websocket.rs index 8cf8d98..a13a96c 100644 --- a/src/routes/websocket.rs +++ b/src/routes/websocket.rs @@ -1,14 +1,14 @@ use std::{net::SocketAddr, ops::ControlFlow, sync::Arc}; use axum::{ - body::Bytes, extract::{ + extract::{ ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, ConnectInfo, State }, response::IntoResponse }; -use crate::{routes::auth::Claims, utils::events, AppState}; +use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState}; use futures_util::{sink::SinkExt, stream::StreamExt}; @@ -27,79 +27,81 @@ use futures_util::{sink::SinkExt, stream::StreamExt}; pub async fn ws_handler( ws: WebSocketUpgrade, ConnectInfo(addr): ConnectInfo, - claims: Claims, State(state): State> ) -> impl IntoResponse { - log::debug!(target: "websocket", "{addr} connected."); - ws.on_upgrade(move |socket| handle_socket(socket, addr, state, claims)) + log::info!(target: "websocket", "{addr} connected to the websocket."); + ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) } +async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc) { + // Authenticate connected socket + let mut claims_t: Option = None; -async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc, claims: Claims) { - if socket - .send(Message::Ping(Bytes::from_static(&[4, 2]))) - .await - .is_ok() - { - log::debug!(target: "websocket", "Pinged {who}..."); - } else { - log::debug!(target: "websocket", "Could not send ping to {who}!"); - return; - } - - if let Some(msg) = socket.recv().await { - if let Ok(msg) = msg { - if process_message(msg, who).is_break() { - return; + if let Some(Ok(msg)) = socket.recv().await { + if let Message::Text(txt) = msg && let Ok(auth_payload) = serde_json::from_str::(&txt) { + if let Ok(claims) = claims_from_token(auth_payload) { + claims_t = Some(claims); + } else { + log::debug!(target: "websocket", "{who} tried to authenticate with wrong token"); } } else { - log::debug!(target: "websocket", "Client {who} abruptly disconnected"); - return; + log::debug!(target: "websocket", "{who} send an invalid payload before logging in") } } - let (mut sender, mut receiver) = socket.split(); - let mut recv_task = tokio::spawn(async move { - while let Some(Ok(msg)) = receiver.next().await { - if process_message(msg, who).is_break() { - break; - } - } - }); - let mut send_task = tokio::spawn(async move { - let mut event_listener = state.event_bus.subscribe(); - loop { - match event_listener.recv().await { - Err(_) => (), - Ok(event) => { - match event { - events::Event::WebsocketBroadcast(message) => { - if !message.should_user_receive(claims.user_id) { - continue; - }; - log::debug!(target: "websocket", "Sent {message:?} to {who}"); - let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await; + match claims_t { + Some(claims) => { + log::debug!(target: "websocket", "{who} successfully authenticated on the websocket"); + // Socket is authenticated, go on + let (mut sender, mut receiver) = socket.split(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + if process_message(msg, who).is_break() { + break; + } + } + }); + let mut send_task = tokio::spawn(async move { + let mut event_listener = state.event_bus.subscribe(); + loop { + match event_listener.recv().await { + Err(_) => (), + Ok(event) => { + match event { + events::Event::WebsocketBroadcast(message) => { + if !message.should_user_receive(claims.user_id) { + continue; + }; + log::debug!(target: "websocket", "Sent {message:?} to {who}"); + let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await; + } + } } } } - } - } - }); + }); - tokio::select! { - rv_a = (&mut send_task) => { - match rv_a { - Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"), - Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}") + tokio::select! { + rv_a = (&mut send_task) => { + match rv_a { + Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"), + Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}") + } + recv_task.abort(); + }, + rv_b = (&mut recv_task) => { + match rv_b { + Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"), + Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}") + } + send_task.abort(); + } } - recv_task.abort(); }, - rv_b = (&mut recv_task) => { - match rv_b { - Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"), - Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}") - } - send_task.abort(); + None => { + // Socket was not authenticated, abort the mission + let _ = socket.send(Message::Text(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_json().to_string().into())).await; + return; } } } diff --git a/src/utils/events.rs b/src/utils/events.rs index 56d7381..348a360 100644 --- a/src/utils/events.rs +++ b/src/utils/events.rs @@ -12,6 +12,7 @@ pub enum Event { #[derive(Clone, Debug)] pub enum WebsocketMessage { NewOwner(Arc), + Error(String), Ping } @@ -20,10 +21,12 @@ impl WebsocketMessage { json!({ "type": match self { Self::NewOwner(_) => "new_owner", + Self::Error(_) => "error", Self::Ping => "ping", }, "data": match self { Self::NewOwner(owner) => json!(owner), + Self::Error(error) => json!(error), Self::Ping => json!(null), } })