use std::{net::SocketAddr, ops::ControlFlow, sync::Arc}; use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, ConnectInfo, State }, response::IntoResponse }; use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState}; use futures_util::{sink::SinkExt, stream::StreamExt}; #[axum::debug_handler] #[utoipa::path( get, path = "/ws", security(("jwt" = [])), responses( (status = SWITCHING_PROTOCOLS, description = "Succesfully reached the websocket, now upgrade to establish the connection"), ), summary = "Connect to the websocket", description = "Initial connection to the websocket before upgrading protocols", tag = "websocket", )] pub async fn ws_handler( ws: WebSocketUpgrade, ConnectInfo(addr): ConnectInfo, State(state): State> ) -> impl IntoResponse { log::info!(target: "websocket", "{addr} connected to the websocket."); ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) } async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc) { // Authenticate connected socket let mut claims_t: Option = None; if let Some(Ok(msg)) = socket.recv().await { if let Message::Text(txt) = &msg && let Ok(auth_payload) = serde_json::from_str::(&txt) { if let Ok(claims) = claims_from_token(auth_payload) { claims_t = Some(claims); } else { log::debug!(target: "websocket", "{who} tried to authenticate with wrong token"); } } else { log::debug!(target: "websocket", "{who} send an invalid payload before logging in: {}", &msg.clone().to_text().unwrap_or("")) } } match claims_t { Some(claims) => { if let Err(_) = socket.send(WebsocketMessage::AuthSuccess.to_text_message()).await { log::debug!(target: "websocket", "Could not send auth success message to {who}"); return; }; log::debug!(target: "websocket", "{who} successfully authenticated on the websocket"); // Socket is authenticated, go on let (mut sender, mut receiver) = socket.split(); let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { if process_message(msg, who).is_break() { break; } } }); let mut send_task = tokio::spawn(async move { let mut event_listener = state.event_bus.subscribe(); loop { match event_listener.recv().await { Err(_) => (), Ok(event) => { match event { events::Event::WebsocketBroadcast(message) => { if !message.should_user_receive(claims.user_id) { continue; }; log::debug!(target: "websocket", "Sent {message:?} to {who}"); let _ = sender.send(message.to_text_message()).await; } } } } } }); tokio::select! { rv_a = (&mut send_task) => { match rv_a { Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"), Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}") } recv_task.abort(); }, rv_b = (&mut recv_task) => { match rv_b { Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"), Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}") } send_task.abort(); } } }, None => { // Socket was not authenticated, abort the mission let _ = socket.send(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_text_message()).await; return; } } } fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { match msg { Message::Text(t) => { log::debug!(target: "websocket", "{who} sent str: {t:?}"); } Message::Binary(d) => { log::debug!(target: "websocket", "{who} sent {} bytes: {d:?}", d.len()); } Message::Close(c) => { if let Some(cf) = c { log::debug!(target: "websocket", "{who} sent close with code {} and reason `{}`", cf.code, cf.reason ); } else { log::debug!(target: "websocket", "{who} somehow sent close message without CloseFrame"); } return ControlFlow::Break(()); } Message::Pong(_v) => { //log::debug!(target: "websocket", "{who} sent pong with {v:?}"); } Message::Ping(_v) => { //log::debug!(target: "websocket", "{who} sent ping with {v:?}"); } } ControlFlow::Continue(()) }