feat: websocket authentication now takes place in the websocket instead of using headers, for broader API compatibility

eg Javascript WebSocket API
This commit is contained in:
Ninjdai 2025-08-07 17:25:55 +02:00
parent fa3017ce44
commit 6ca7916012
4 changed files with 75 additions and 62 deletions

View file

@ -182,14 +182,14 @@ async fn run_server(db: Arc<DatabaseConnection>) {
.routes(routes!(routes::bal::create_bal))
.routes(routes!(routes::bal::update_bal))
.routes(routes!(routes::bal::get_bals))
// Misc
.routes(routes!(routes::websocket::ws_handler))
// Authentication
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
.routes(routes!(routes::auth::auth))
.routes(routes!(routes::auth::check_token))
// Misc
.routes(routes!(routes::websocket::ws_handler))
.route("/", get(index))
.with_state(shared_state)
.split_for_parts();

View file

@ -84,6 +84,14 @@ pub async fn check_token(Json(payload): Json<TokenPayload>) -> Json<bool> {
}
}
pub fn claims_from_token(payload: TokenPayload) -> Result<Claims, jsonwebtoken::errors::Error> {
let token_data = decode::<Claims>(&payload.token, &KEYS.decoding, &Validation::default());
match token_data {
Ok(data) => Ok(data.claims),
Err(e) => Err(e)
}
}
impl AuthBody {
fn new(access_token: String) -> Self {
Self {

View file

@ -1,14 +1,14 @@
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
use axum::{
body::Bytes, extract::{
extract::{
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
ConnectInfo,
State
}, response::IntoResponse
};
use crate::{routes::auth::Claims, utils::events, AppState};
use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState};
use futures_util::{sink::SinkExt, stream::StreamExt};
@ -27,79 +27,81 @@ use futures_util::{sink::SinkExt, stream::StreamExt};
pub async fn ws_handler(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
claims: Claims,
State(state): State<Arc<AppState>>
) -> impl IntoResponse {
log::debug!(target: "websocket", "{addr} connected.");
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, claims))
log::info!(target: "websocket", "{addr} connected to the websocket.");
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
}
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
// Authenticate connected socket
let mut claims_t: Option<Claims> = None;
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>, claims: Claims) {
if socket
.send(Message::Ping(Bytes::from_static(&[4, 2])))
.await
.is_ok()
{
log::debug!(target: "websocket", "Pinged {who}...");
} else {
log::debug!(target: "websocket", "Could not send ping to {who}!");
return;
}
if let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
if process_message(msg, who).is_break() {
return;
if let Some(Ok(msg)) = socket.recv().await {
if let Message::Text(txt) = msg && let Ok(auth_payload) = serde_json::from_str::<TokenPayload>(&txt) {
if let Ok(claims) = claims_from_token(auth_payload) {
claims_t = Some(claims);
} else {
log::debug!(target: "websocket", "{who} tried to authenticate with wrong token");
}
} else {
log::debug!(target: "websocket", "Client {who} abruptly disconnected");
return;
log::debug!(target: "websocket", "{who} send an invalid payload before logging in")
}
}
let (mut sender, mut receiver) = socket.split();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if process_message(msg, who).is_break() {
break;
}
}
});
let mut send_task = tokio::spawn(async move {
let mut event_listener = state.event_bus.subscribe();
loop {
match event_listener.recv().await {
Err(_) => (),
Ok(event) => {
match event {
events::Event::WebsocketBroadcast(message) => {
if !message.should_user_receive(claims.user_id) {
continue;
};
log::debug!(target: "websocket", "Sent {message:?} to {who}");
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
match claims_t {
Some(claims) => {
log::debug!(target: "websocket", "{who} successfully authenticated on the websocket");
// Socket is authenticated, go on
let (mut sender, mut receiver) = socket.split();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if process_message(msg, who).is_break() {
break;
}
}
});
let mut send_task = tokio::spawn(async move {
let mut event_listener = state.event_bus.subscribe();
loop {
match event_listener.recv().await {
Err(_) => (),
Ok(event) => {
match event {
events::Event::WebsocketBroadcast(message) => {
if !message.should_user_receive(claims.user_id) {
continue;
};
log::debug!(target: "websocket", "Sent {message:?} to {who}");
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
}
}
}
}
}
}
}
});
});
tokio::select! {
rv_a = (&mut send_task) => {
match rv_a {
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
tokio::select! {
rv_a = (&mut send_task) => {
match rv_a {
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
}
recv_task.abort();
},
rv_b = (&mut recv_task) => {
match rv_b {
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
}
send_task.abort();
}
}
recv_task.abort();
},
rv_b = (&mut recv_task) => {
match rv_b {
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
}
send_task.abort();
None => {
// Socket was not authenticated, abort the mission
let _ = socket.send(Message::Text(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_json().to_string().into())).await;
return;
}
}
}

View file

@ -12,6 +12,7 @@ pub enum Event {
#[derive(Clone, Debug)]
pub enum WebsocketMessage {
NewOwner(Arc<owner::Model>),
Error(String),
Ping
}
@ -20,10 +21,12 @@ impl WebsocketMessage {
json!({
"type": match self {
Self::NewOwner(_) => "new_owner",
Self::Error(_) => "error",
Self::Ping => "ping",
},
"data": match self {
Self::NewOwner(owner) => json!(owner),
Self::Error(error) => json!(error),
Self::Ping => json!(null),
}
})