feat: websocket authentication now takes place in the websocket instead of using headers, for broader API compatibility
eg Javascript WebSocket API
This commit is contained in:
parent
fa3017ce44
commit
6ca7916012
4 changed files with 75 additions and 62 deletions
|
|
@ -182,14 +182,14 @@ async fn run_server(db: Arc<DatabaseConnection>) {
|
|||
.routes(routes!(routes::bal::create_bal))
|
||||
.routes(routes!(routes::bal::update_bal))
|
||||
.routes(routes!(routes::bal::get_bals))
|
||||
// Misc
|
||||
.routes(routes!(routes::websocket::ws_handler))
|
||||
// Authentication
|
||||
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
|
||||
.routes(routes!(routes::auth::auth))
|
||||
.routes(routes!(routes::auth::check_token))
|
||||
|
||||
// Misc
|
||||
.routes(routes!(routes::websocket::ws_handler))
|
||||
.route("/", get(index))
|
||||
|
||||
.with_state(shared_state)
|
||||
.split_for_parts();
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,14 @@ pub async fn check_token(Json(payload): Json<TokenPayload>) -> Json<bool> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn claims_from_token(payload: TokenPayload) -> Result<Claims, jsonwebtoken::errors::Error> {
|
||||
let token_data = decode::<Claims>(&payload.token, &KEYS.decoding, &Validation::default());
|
||||
match token_data {
|
||||
Ok(data) => Ok(data.claims),
|
||||
Err(e) => Err(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthBody {
|
||||
fn new(access_token: String) -> Self {
|
||||
Self {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
body::Bytes, extract::{
|
||||
extract::{
|
||||
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
|
||||
ConnectInfo,
|
||||
State
|
||||
}, response::IntoResponse
|
||||
};
|
||||
|
||||
use crate::{routes::auth::Claims, utils::events, AppState};
|
||||
use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState};
|
||||
|
||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||
|
||||
|
|
@ -27,79 +27,81 @@ use futures_util::{sink::SinkExt, stream::StreamExt};
|
|||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
claims: Claims,
|
||||
State(state): State<Arc<AppState>>
|
||||
) -> impl IntoResponse {
|
||||
log::debug!(target: "websocket", "{addr} connected.");
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, claims))
|
||||
log::info!(target: "websocket", "{addr} connected to the websocket.");
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
|
||||
// Authenticate connected socket
|
||||
let mut claims_t: Option<Claims> = None;
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>, claims: Claims) {
|
||||
if socket
|
||||
.send(Message::Ping(Bytes::from_static(&[4, 2])))
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
log::debug!(target: "websocket", "Pinged {who}...");
|
||||
} else {
|
||||
log::debug!(target: "websocket", "Could not send ping to {who}!");
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(msg) = socket.recv().await {
|
||||
if let Ok(msg) = msg {
|
||||
if process_message(msg, who).is_break() {
|
||||
return;
|
||||
if let Some(Ok(msg)) = socket.recv().await {
|
||||
if let Message::Text(txt) = msg && let Ok(auth_payload) = serde_json::from_str::<TokenPayload>(&txt) {
|
||||
if let Ok(claims) = claims_from_token(auth_payload) {
|
||||
claims_t = Some(claims);
|
||||
} else {
|
||||
log::debug!(target: "websocket", "{who} tried to authenticate with wrong token");
|
||||
}
|
||||
} else {
|
||||
log::debug!(target: "websocket", "Client {who} abruptly disconnected");
|
||||
return;
|
||||
log::debug!(target: "websocket", "{who} send an invalid payload before logging in")
|
||||
}
|
||||
}
|
||||
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
if process_message(msg, who).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
let mut event_listener = state.event_bus.subscribe();
|
||||
loop {
|
||||
match event_listener.recv().await {
|
||||
Err(_) => (),
|
||||
Ok(event) => {
|
||||
match event {
|
||||
events::Event::WebsocketBroadcast(message) => {
|
||||
if !message.should_user_receive(claims.user_id) {
|
||||
continue;
|
||||
};
|
||||
log::debug!(target: "websocket", "Sent {message:?} to {who}");
|
||||
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
|
||||
match claims_t {
|
||||
Some(claims) => {
|
||||
log::debug!(target: "websocket", "{who} successfully authenticated on the websocket");
|
||||
// Socket is authenticated, go on
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
if process_message(msg, who).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
let mut event_listener = state.event_bus.subscribe();
|
||||
loop {
|
||||
match event_listener.recv().await {
|
||||
Err(_) => (),
|
||||
Ok(event) => {
|
||||
match event {
|
||||
events::Event::WebsocketBroadcast(message) => {
|
||||
if !message.should_user_receive(claims.user_id) {
|
||||
continue;
|
||||
};
|
||||
log::debug!(target: "websocket", "Sent {message:?} to {who}");
|
||||
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
rv_a = (&mut send_task) => {
|
||||
match rv_a {
|
||||
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
|
||||
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
|
||||
tokio::select! {
|
||||
rv_a = (&mut send_task) => {
|
||||
match rv_a {
|
||||
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
|
||||
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
|
||||
}
|
||||
recv_task.abort();
|
||||
},
|
||||
rv_b = (&mut recv_task) => {
|
||||
match rv_b {
|
||||
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
|
||||
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
|
||||
}
|
||||
send_task.abort();
|
||||
}
|
||||
}
|
||||
recv_task.abort();
|
||||
},
|
||||
rv_b = (&mut recv_task) => {
|
||||
match rv_b {
|
||||
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
|
||||
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
|
||||
}
|
||||
send_task.abort();
|
||||
None => {
|
||||
// Socket was not authenticated, abort the mission
|
||||
let _ = socket.send(Message::Text(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_json().to_string().into())).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ pub enum Event {
|
|||
#[derive(Clone, Debug)]
|
||||
pub enum WebsocketMessage {
|
||||
NewOwner(Arc<owner::Model>),
|
||||
Error(String),
|
||||
Ping
|
||||
}
|
||||
|
||||
|
|
@ -20,10 +21,12 @@ impl WebsocketMessage {
|
|||
json!({
|
||||
"type": match self {
|
||||
Self::NewOwner(_) => "new_owner",
|
||||
Self::Error(_) => "error",
|
||||
Self::Ping => "ping",
|
||||
},
|
||||
"data": match self {
|
||||
Self::NewOwner(owner) => json!(owner),
|
||||
Self::Error(error) => json!(error),
|
||||
Self::Ping => json!(null),
|
||||
}
|
||||
})
|
||||
|
|
|
|||
Reference in a new issue