use std::{net::SocketAddr, ops::ControlFlow, sync::Arc}; use axum::{ body::Bytes, extract::{ ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, ConnectInfo, State }, response::IntoResponse }; use serde_json::json; use crate::{utils::events, AppState}; use futures_util::{sink::SinkExt, stream::StreamExt}; #[axum::debug_handler] pub async fn ws_handler( ws: WebSocketUpgrade, ConnectInfo(addr): ConnectInfo, State(state): State> ) -> impl IntoResponse { println!("`{addr} connected."); // finalize the upgrade process by returning upgrade callback. // we can customize the callback by sending additional info such as address. ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) } async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc) { if socket .send(Message::Ping(Bytes::from_static(&[4, 2]))) .await .is_ok() { println!("WS >>> Pinged {who}..."); } else { println!("WS >>> Could not send ping {who}!"); return; } if let Some(msg) = socket.recv().await { if let Ok(msg) = msg { if process_message(msg, who).is_break() { return; } } else { println!("WS >>> Client {who} abruptly disconnected"); return; } } let (mut sender, mut receiver) = socket.split(); let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { if process_message(msg, who).is_break() { break; } } }); let mut send_task = tokio::spawn(async move { let mut event_listener = state.event_bus.subscribe(); loop { match event_listener.recv().await { Err(_) => (), Ok(event) => { match event { events::Event::WebsocketBroadcast(message) => { let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await; } } } } } }); tokio::select! { rv_a = (&mut send_task) => { match rv_a { Ok(()) => println!("WS >>> Sender connection with {who} gracefully shut down"), Err(a) => println!("WS >>> Error sending messages {a:?}") } recv_task.abort(); }, rv_b = (&mut recv_task) => { match rv_b { Ok(()) => println!("WS >>> Receiver connection with {who} gracefully shut down"), Err(b) => println!("WS >>> Error receiving messages {b:?}") } send_task.abort(); } } } fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { match msg { Message::Text(t) => { println!("WS >>> {who} sent str: {t:?}"); } Message::Binary(d) => { println!("WS >>> {who} sent {} bytes: {d:?}", d.len()); } Message::Close(c) => { if let Some(cf) = c { println!( "WS >>> {who} sent close with code {} and reason `{}`", cf.code, cf.reason ); } else { println!("WS >>> {who} somehow sent close message without CloseFrame"); } return ControlFlow::Break(()); } Message::Pong(v) => { println!("WS >>> {who} sent pong with {v:?}"); } Message::Ping(v) => { println!("WS >>> {who} sent ping with {v:?}"); } } ControlFlow::Continue(()) }