feat: websocket authentication now takes place in the websocket instead of using headers, for broader API compatibility

eg Javascript WebSocket API
This commit is contained in:
Ninjdai 2025-08-07 17:25:55 +02:00
parent fa3017ce44
commit 6ca7916012
4 changed files with 75 additions and 62 deletions

View file

@ -182,14 +182,14 @@ async fn run_server(db: Arc<DatabaseConnection>) {
.routes(routes!(routes::bal::create_bal)) .routes(routes!(routes::bal::create_bal))
.routes(routes!(routes::bal::update_bal)) .routes(routes!(routes::bal::update_bal))
.routes(routes!(routes::bal::get_bals)) .routes(routes!(routes::bal::get_bals))
// Misc
.routes(routes!(routes::websocket::ws_handler))
// Authentication // Authentication
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware)) .route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
.routes(routes!(routes::auth::auth)) .routes(routes!(routes::auth::auth))
.routes(routes!(routes::auth::check_token)) .routes(routes!(routes::auth::check_token))
// Misc
.routes(routes!(routes::websocket::ws_handler))
.route("/", get(index)) .route("/", get(index))
.with_state(shared_state) .with_state(shared_state)
.split_for_parts(); .split_for_parts();

View file

@ -84,6 +84,14 @@ pub async fn check_token(Json(payload): Json<TokenPayload>) -> Json<bool> {
} }
} }
pub fn claims_from_token(payload: TokenPayload) -> Result<Claims, jsonwebtoken::errors::Error> {
let token_data = decode::<Claims>(&payload.token, &KEYS.decoding, &Validation::default());
match token_data {
Ok(data) => Ok(data.claims),
Err(e) => Err(e)
}
}
impl AuthBody { impl AuthBody {
fn new(access_token: String) -> Self { fn new(access_token: String) -> Self {
Self { Self {

View file

@ -1,14 +1,14 @@
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc}; use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
use axum::{ use axum::{
body::Bytes, extract::{ extract::{
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
ConnectInfo, ConnectInfo,
State State
}, response::IntoResponse }, response::IntoResponse
}; };
use crate::{routes::auth::Claims, utils::events, AppState}; use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState};
use futures_util::{sink::SinkExt, stream::StreamExt}; use futures_util::{sink::SinkExt, stream::StreamExt};
@ -27,79 +27,81 @@ use futures_util::{sink::SinkExt, stream::StreamExt};
pub async fn ws_handler( pub async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
claims: Claims,
State(state): State<Arc<AppState>> State(state): State<Arc<AppState>>
) -> impl IntoResponse { ) -> impl IntoResponse {
log::debug!(target: "websocket", "{addr} connected."); log::info!(target: "websocket", "{addr} connected to the websocket.");
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, claims)) ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
} }
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
// Authenticate connected socket
let mut claims_t: Option<Claims> = None;
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>, claims: Claims) { if let Some(Ok(msg)) = socket.recv().await {
if socket if let Message::Text(txt) = msg && let Ok(auth_payload) = serde_json::from_str::<TokenPayload>(&txt) {
.send(Message::Ping(Bytes::from_static(&[4, 2]))) if let Ok(claims) = claims_from_token(auth_payload) {
.await claims_t = Some(claims);
.is_ok() } else {
{ log::debug!(target: "websocket", "{who} tried to authenticate with wrong token");
log::debug!(target: "websocket", "Pinged {who}...");
} else {
log::debug!(target: "websocket", "Could not send ping to {who}!");
return;
}
if let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
if process_message(msg, who).is_break() {
return;
} }
} else { } else {
log::debug!(target: "websocket", "Client {who} abruptly disconnected"); log::debug!(target: "websocket", "{who} send an invalid payload before logging in")
return;
} }
} }
let (mut sender, mut receiver) = socket.split(); match claims_t {
let mut recv_task = tokio::spawn(async move { Some(claims) => {
while let Some(Ok(msg)) = receiver.next().await { log::debug!(target: "websocket", "{who} successfully authenticated on the websocket");
if process_message(msg, who).is_break() { // Socket is authenticated, go on
break; let (mut sender, mut receiver) = socket.split();
} let mut recv_task = tokio::spawn(async move {
} while let Some(Ok(msg)) = receiver.next().await {
}); if process_message(msg, who).is_break() {
let mut send_task = tokio::spawn(async move { break;
let mut event_listener = state.event_bus.subscribe(); }
loop { }
match event_listener.recv().await { });
Err(_) => (), let mut send_task = tokio::spawn(async move {
Ok(event) => { let mut event_listener = state.event_bus.subscribe();
match event { loop {
events::Event::WebsocketBroadcast(message) => { match event_listener.recv().await {
if !message.should_user_receive(claims.user_id) { Err(_) => (),
continue; Ok(event) => {
}; match event {
log::debug!(target: "websocket", "Sent {message:?} to {who}"); events::Event::WebsocketBroadcast(message) => {
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await; if !message.should_user_receive(claims.user_id) {
continue;
};
log::debug!(target: "websocket", "Sent {message:?} to {who}");
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
}
}
} }
} }
} }
} });
}
});
tokio::select! { tokio::select! {
rv_a = (&mut send_task) => { rv_a = (&mut send_task) => {
match rv_a { match rv_a {
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"), Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}") Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
}
recv_task.abort();
},
rv_b = (&mut recv_task) => {
match rv_b {
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
}
send_task.abort();
}
} }
recv_task.abort();
}, },
rv_b = (&mut recv_task) => { None => {
match rv_b { // Socket was not authenticated, abort the mission
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"), let _ = socket.send(Message::Text(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_json().to_string().into())).await;
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}") return;
}
send_task.abort();
} }
} }
} }

View file

@ -12,6 +12,7 @@ pub enum Event {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum WebsocketMessage { pub enum WebsocketMessage {
NewOwner(Arc<owner::Model>), NewOwner(Arc<owner::Model>),
Error(String),
Ping Ping
} }
@ -20,10 +21,12 @@ impl WebsocketMessage {
json!({ json!({
"type": match self { "type": match self {
Self::NewOwner(_) => "new_owner", Self::NewOwner(_) => "new_owner",
Self::Error(_) => "error",
Self::Ping => "ping", Self::Ping => "ping",
}, },
"data": match self { "data": match self {
Self::NewOwner(owner) => json!(owner), Self::NewOwner(owner) => json!(owner),
Self::Error(error) => json!(error),
Self::Ping => json!(null), Self::Ping => json!(null),
} }
}) })