feat: websocket authentication now takes place in the websocket instead of using headers, for broader API compatibility
eg Javascript WebSocket API
This commit is contained in:
parent
fa3017ce44
commit
6ca7916012
4 changed files with 75 additions and 62 deletions
|
|
@ -182,14 +182,14 @@ async fn run_server(db: Arc<DatabaseConnection>) {
|
||||||
.routes(routes!(routes::bal::create_bal))
|
.routes(routes!(routes::bal::create_bal))
|
||||||
.routes(routes!(routes::bal::update_bal))
|
.routes(routes!(routes::bal::update_bal))
|
||||||
.routes(routes!(routes::bal::get_bals))
|
.routes(routes!(routes::bal::get_bals))
|
||||||
// Misc
|
|
||||||
.routes(routes!(routes::websocket::ws_handler))
|
|
||||||
// Authentication
|
// Authentication
|
||||||
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
|
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
|
||||||
.routes(routes!(routes::auth::auth))
|
.routes(routes!(routes::auth::auth))
|
||||||
.routes(routes!(routes::auth::check_token))
|
.routes(routes!(routes::auth::check_token))
|
||||||
|
// Misc
|
||||||
|
.routes(routes!(routes::websocket::ws_handler))
|
||||||
.route("/", get(index))
|
.route("/", get(index))
|
||||||
|
|
||||||
.with_state(shared_state)
|
.with_state(shared_state)
|
||||||
.split_for_parts();
|
.split_for_parts();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,14 @@ pub async fn check_token(Json(payload): Json<TokenPayload>) -> Json<bool> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn claims_from_token(payload: TokenPayload) -> Result<Claims, jsonwebtoken::errors::Error> {
|
||||||
|
let token_data = decode::<Claims>(&payload.token, &KEYS.decoding, &Validation::default());
|
||||||
|
match token_data {
|
||||||
|
Ok(data) => Ok(data.claims),
|
||||||
|
Err(e) => Err(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AuthBody {
|
impl AuthBody {
|
||||||
fn new(access_token: String) -> Self {
|
fn new(access_token: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
|
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Bytes, extract::{
|
extract::{
|
||||||
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
|
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
|
||||||
ConnectInfo,
|
ConnectInfo,
|
||||||
State
|
State
|
||||||
}, response::IntoResponse
|
}, response::IntoResponse
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{routes::auth::Claims, utils::events, AppState};
|
use crate::{routes::{auth::{claims_from_token, Claims, TokenPayload}}, utils::events::{self, WebsocketMessage}, AppState};
|
||||||
|
|
||||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||||
|
|
||||||
|
|
@ -27,79 +27,81 @@ use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||||
pub async fn ws_handler(
|
pub async fn ws_handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
claims: Claims,
|
|
||||||
State(state): State<Arc<AppState>>
|
State(state): State<Arc<AppState>>
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
log::debug!(target: "websocket", "{addr} connected.");
|
log::info!(target: "websocket", "{addr} connected to the websocket.");
|
||||||
ws.on_upgrade(move |socket| handle_socket(socket, addr, state, claims))
|
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
|
||||||
|
// Authenticate connected socket
|
||||||
|
let mut claims_t: Option<Claims> = None;
|
||||||
|
|
||||||
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>, claims: Claims) {
|
if let Some(Ok(msg)) = socket.recv().await {
|
||||||
if socket
|
if let Message::Text(txt) = msg && let Ok(auth_payload) = serde_json::from_str::<TokenPayload>(&txt) {
|
||||||
.send(Message::Ping(Bytes::from_static(&[4, 2])))
|
if let Ok(claims) = claims_from_token(auth_payload) {
|
||||||
.await
|
claims_t = Some(claims);
|
||||||
.is_ok()
|
} else {
|
||||||
{
|
log::debug!(target: "websocket", "{who} tried to authenticate with wrong token");
|
||||||
log::debug!(target: "websocket", "Pinged {who}...");
|
|
||||||
} else {
|
|
||||||
log::debug!(target: "websocket", "Could not send ping to {who}!");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(msg) = socket.recv().await {
|
|
||||||
if let Ok(msg) = msg {
|
|
||||||
if process_message(msg, who).is_break() {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log::debug!(target: "websocket", "Client {who} abruptly disconnected");
|
log::debug!(target: "websocket", "{who} send an invalid payload before logging in")
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut sender, mut receiver) = socket.split();
|
match claims_t {
|
||||||
let mut recv_task = tokio::spawn(async move {
|
Some(claims) => {
|
||||||
while let Some(Ok(msg)) = receiver.next().await {
|
log::debug!(target: "websocket", "{who} successfully authenticated on the websocket");
|
||||||
if process_message(msg, who).is_break() {
|
// Socket is authenticated, go on
|
||||||
break;
|
let (mut sender, mut receiver) = socket.split();
|
||||||
}
|
let mut recv_task = tokio::spawn(async move {
|
||||||
}
|
while let Some(Ok(msg)) = receiver.next().await {
|
||||||
});
|
if process_message(msg, who).is_break() {
|
||||||
let mut send_task = tokio::spawn(async move {
|
break;
|
||||||
let mut event_listener = state.event_bus.subscribe();
|
}
|
||||||
loop {
|
}
|
||||||
match event_listener.recv().await {
|
});
|
||||||
Err(_) => (),
|
let mut send_task = tokio::spawn(async move {
|
||||||
Ok(event) => {
|
let mut event_listener = state.event_bus.subscribe();
|
||||||
match event {
|
loop {
|
||||||
events::Event::WebsocketBroadcast(message) => {
|
match event_listener.recv().await {
|
||||||
if !message.should_user_receive(claims.user_id) {
|
Err(_) => (),
|
||||||
continue;
|
Ok(event) => {
|
||||||
};
|
match event {
|
||||||
log::debug!(target: "websocket", "Sent {message:?} to {who}");
|
events::Event::WebsocketBroadcast(message) => {
|
||||||
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
|
if !message.should_user_receive(claims.user_id) {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
log::debug!(target: "websocket", "Sent {message:?} to {who}");
|
||||||
|
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
rv_a = (&mut send_task) => {
|
rv_a = (&mut send_task) => {
|
||||||
match rv_a {
|
match rv_a {
|
||||||
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
|
Ok(()) => log::debug!(target: "websocket", "Sender connection with {who} gracefully shut down"),
|
||||||
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
|
Err(a) => log::debug!(target: "websocket", "Error sending messages {a:?}")
|
||||||
|
}
|
||||||
|
recv_task.abort();
|
||||||
|
},
|
||||||
|
rv_b = (&mut recv_task) => {
|
||||||
|
match rv_b {
|
||||||
|
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
|
||||||
|
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
|
||||||
|
}
|
||||||
|
send_task.abort();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
recv_task.abort();
|
|
||||||
},
|
},
|
||||||
rv_b = (&mut recv_task) => {
|
None => {
|
||||||
match rv_b {
|
// Socket was not authenticated, abort the mission
|
||||||
Ok(()) => log::debug!(target: "websocket", "Receiver connection with {who} gracefully shut down"),
|
let _ = socket.send(Message::Text(WebsocketMessage::Error(r#"Invalid Authentication. When you connect to the websocket, please send a text message formatted in the following way: {"token": "valid_json_web_token"}"#.to_string()).to_json().to_string().into())).await;
|
||||||
Err(b) => log::debug!(target: "websocket", "Error receiving messages {b:?}")
|
return;
|
||||||
}
|
|
||||||
send_task.abort();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ pub enum Event {
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum WebsocketMessage {
|
pub enum WebsocketMessage {
|
||||||
NewOwner(Arc<owner::Model>),
|
NewOwner(Arc<owner::Model>),
|
||||||
|
Error(String),
|
||||||
Ping
|
Ping
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -20,10 +21,12 @@ impl WebsocketMessage {
|
||||||
json!({
|
json!({
|
||||||
"type": match self {
|
"type": match self {
|
||||||
Self::NewOwner(_) => "new_owner",
|
Self::NewOwner(_) => "new_owner",
|
||||||
|
Self::Error(_) => "error",
|
||||||
Self::Ping => "ping",
|
Self::Ping => "ping",
|
||||||
},
|
},
|
||||||
"data": match self {
|
"data": match self {
|
||||||
Self::NewOwner(owner) => json!(owner),
|
Self::NewOwner(owner) => json!(owner),
|
||||||
|
Self::Error(error) => json!(error),
|
||||||
Self::Ping => json!(null),
|
Self::Ping => json!(null),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Reference in a new issue