140 lines
5.6 KiB
Rust
140 lines
5.6 KiB
Rust
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
|
|
|
|
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket, WebSocketUpgrade},
|
|
ConnectInfo,
|
|
State
|
|
}, response::IntoResponse
|
|
};
|
|
|
|
use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState};
|
|
|
|
use futures_util::{sink::SinkExt, stream::StreamExt};
|
|
|
|
#[axum::debug_handler]
|
|
#[utoipa::path(
|
|
get,
|
|
path = "/ws",
|
|
security(("jwt" = [])),
|
|
responses(
|
|
(status = SWITCHING_PROTOCOLS, description = "Succesfully reached the websocket, now upgrade to establish the connection"),
|
|
),
|
|
summary = "Connect to the websocket",
|
|
description = "Initial connection to the websocket before upgrading protocols",
|
|
tag = "websocket",
|
|
)]
|
|
pub async fn ws_handler(
|
|
ws: WebSocketUpgrade,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
State(state): State<Arc<AppState>>
|
|
) -> impl IntoResponse {
|
|
log::info!(target: "websocket", "{addr} connected to the websocket.");
|
|
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
|
|
}
|
|
|
|
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
|
|
// Authenticate connected socket
|
|
let mut claims_t: Option<Claims> = None;
|
|
|
|
if let Some(Ok(msg)) = socket.recv().await {
|
|
if let Message::Text(txt) = &msg && let Ok(auth_payload) = serde_json::from_str::<TokenPayload>(&txt) {
|
|
if let Ok(claims) = claims_from_token(auth_payload) {
|
|
claims_t = Some(claims);
|
|
} else {
|
|
log::debug!(target: "websocket", "{who} tried to authenticate with wrong token");
|
|
}
|
|
} else {
|
|
log::debug!(target: "websocket", "{who} send an invalid payload before logging in: {}", &msg.clone().to_text().unwrap_or("<unable to get payload as text>"))
|
|
}
|
|
}
|
|
|
|
match claims_t {
|
|
Some(claims) => {
|
|
if let Err(_) = socket.send(WebsocketMessage::AuthSuccess.to_text_message()).await {
|
|
log::debug!(target: "websocket", "Could not send auth success message to {who}");
|
|
return;
|
|
};
|
|
log::debug!(target: "websocket", "{who} successfully authenticated on the websocket");
|
|
// Socket is authenticated, go on
|
|
let (mut sender, mut receiver) = socket.split();
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = receiver.next().await {
|
|
if process_message(msg, who).is_break() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
let mut send_task = tokio::spawn(async move {
|
|
let mut event_listener = state.event_bus.subscribe();
|
|
loop {
|
|
match event_listener.recv().await {
|
|
Err(_) => (),
|
|
Ok(event) => {
|
|
match event {
|
|
events::Event::WebsocketBroadcast(message) => {
|
|
if !message.should_user_receive(claims.user_id) {
|
|
continue;
|
|
};
|
|
log::debug!(target: "websocket", "Sent {message:?} to {who}");
|
|
let _ = sender.send(message.to_text_message()).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
tokio::select! {
|
|
rv_a = (&mut send_task) => {
|
|
match rv_a {
|
|
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
|
|
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
|
|
}
|
|
recv_task.abort();
|
|
},
|
|
rv_b = (&mut recv_task) => {
|
|
match rv_b {
|
|
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
|
|
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
|
|
}
|
|
send_task.abort();
|
|
}
|
|
}
|
|
},
|
|
None => {
|
|
// Socket was not authenticated, abort the mission
|
|
let _ = socket.send(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_text_message()).await;
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
|
|
match msg {
|
|
Message::Text(t) => {
|
|
log::debug!(target: "websocket", "{who} sent str: {t:?}");
|
|
}
|
|
Message::Binary(d) => {
|
|
log::debug!(target: "websocket", "{who} sent {} bytes: {d:?}", d.len());
|
|
}
|
|
Message::Close(c) => {
|
|
if let Some(cf) = c {
|
|
log::debug!(target: "websocket",
|
|
"{who} sent close with code {} and reason `{}`",
|
|
cf.code, cf.reason
|
|
);
|
|
} else {
|
|
log::debug!(target: "websocket", "{who} somehow sent close message without CloseFrame");
|
|
}
|
|
return ControlFlow::Break(());
|
|
}
|
|
Message::Pong(_v) => {
|
|
//log::debug!(target: "websocket", "{who} sent pong with {v:?}");
|
|
}
|
|
Message::Ping(_v) => {
|
|
//log::debug!(target: "websocket", "{who} sent ping with {v:?}");
|
|
}
|
|
}
|
|
ControlFlow::Continue(())
|
|
}
|