From 3e1c744db1c0f564cc46beeceeda14116a28a199 Mon Sep 17 00:00:00 2001 From: Ninjdai Date: Fri, 1 Aug 2025 01:27:25 +0200 Subject: [PATCH] initial websocket implementation --- Cargo.lock | 57 ++++++++++++++++++ Cargo.toml | 3 +- src/main.rs | 33 ++++++++++- src/routes/mod.rs | 1 + src/routes/websocket.rs | 124 ++++++++++++++++++++++++++++++++++++++++ src/utils/events.rs | 28 +++++++++ src/utils/mod.rs | 2 + 7 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 src/routes/websocket.rs create mode 100644 src/utils/events.rs diff --git a/Cargo.lock b/Cargo.lock index fa6cf27..05f2a35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,6 +43,7 @@ version = "0.1.0" dependencies = [ "axum", "dotenvy", + "futures-util", "reqwest", "sea-orm", "serde", @@ -158,6 +159,7 @@ checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ "axum-core", "axum-macros", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -177,8 +179,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -473,6 +477,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "der" version = "0.7.10" @@ -704,6 +714,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -724,6 +745,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -2875,6 +2897,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.15" @@ -2989,6 +3023,23 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.18.0" @@ -3045,6 +3096,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 51c16d2..052dc7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -axum = { version = "0.8.4", features = [ "macros" ] } +axum = { version = "0.8.4", features = [ "macros", "ws", "tokio" ] } dotenvy = "0.15.7" reqwest = "0.12.22" sea-orm = { version = "1.1.13", features = [ "sqlx-sqlite", "runtime-tokio-rustls", "macros" ] } @@ -15,4 +15,5 @@ utoipa = "5.4.0" utoipa-axum = "0.2.0" utoipa-swagger-ui = { version = "9", features = ["axum", "reqwest"] } utoipa-redoc = { version = "6", features = ["axum"] } +futures-util = "0.3.31" diff --git a/src/main.rs b/src/main.rs index c3192e4..9ab5ac9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,16 @@ -use std::sync::Arc; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use axum::{extract::State, http::HeaderMap, routing::get}; use reqwest::{header::USER_AGENT}; use sea_orm::{ConnectionTrait, Database, DatabaseConnection, EntityTrait, PaginatorTrait, Schema}; +use tokio::{sync::broadcast::{self, Sender}, task, time}; use utoipa::openapi::{ContactBuilder, InfoBuilder, LicenseBuilder}; use utoipa_axum::router::OpenApiRouter; use utoipa_redoc::{Redoc, Servable}; use utoipa_swagger_ui::{Config, SwaggerUi}; use utoipa_axum::routes; -use crate::entities::prelude::BookInstance; +use crate::{entities::{owner, prelude::BookInstance}, utils::events::Event}; pub mod entities; pub mod utils; @@ -18,6 +19,7 @@ pub mod routes; pub struct AppState { app_name: String, db_conn: Arc, + event_bus: Sender, web_client: reqwest::Client } @@ -53,11 +55,31 @@ async fn main() { return; } + let (event_bus, _) = broadcast::channel(16); + + let ntx = event_bus.clone(); + let _forever = task::spawn(async move { + let mut interval = time::interval(Duration::from_secs(5)); + + let mut id = 1; + loop { + interval.tick().await; + let _ = ntx.send(Event::WebsocketBroadcast(utils::events::WebsocketMessage::NewOwner(Arc::new(owner::Model { + id, + first_name: "Avril".to_string(), + last_name: "Papillon".to_string(), + contact: "avril.papillon@proton.me".to_string() + })))); + id += 1; + } + }); + let mut default_headers = HeaderMap::new(); default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap()); let shared_state = Arc::new(AppState { app_name: "Alexandria".to_string(), db_conn: db, + event_bus, web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed") }); @@ -68,6 +90,7 @@ async fn main() { .routes(routes!(routes::book_instance::create_book_instance)) .routes(routes!(routes::owner::get_owner_by_id)) .routes(routes!(routes::owner::create_owner)) + .route("/ws", get(routes::websocket::ws_handler)) .route("/", get(index)) .with_state(shared_state) .split_for_parts(); @@ -95,6 +118,10 @@ async fn main() { let router = router.merge(swagger); let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - axum::serve(listener, router).await.unwrap(); + println!("Running on http://{}", listener.local_addr().unwrap()); + axum::serve( + listener, + router.into_make_service_with_connect_info::() + ).await.unwrap(); } diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 326031a..d7e6167 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,4 +1,5 @@ pub mod book; pub mod book_instance; pub mod owner; +pub mod websocket; diff --git a/src/routes/websocket.rs b/src/routes/websocket.rs new file mode 100644 index 0000000..f9faa94 --- /dev/null +++ b/src/routes/websocket.rs @@ -0,0 +1,124 @@ +use std::{net::SocketAddr, ops::ControlFlow, sync::Arc}; + +use axum::{ + body::Bytes, + extract::{ + ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, + ConnectInfo, + State + }, + response::IntoResponse +}; +use serde_json::json; + +use crate::{utils::events, AppState}; + +use futures_util::{sink::SinkExt, stream::StreamExt}; + +#[axum::debug_handler] +pub async fn ws_handler( + ws: WebSocketUpgrade, + ConnectInfo(addr): ConnectInfo, + State(state): State> +) -> impl IntoResponse { + println!("`{addr} connected."); + // finalize the upgrade process by returning upgrade callback. + // we can customize the callback by sending additional info such as address. + ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) +} + + +async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc) { + if socket + .send(Message::Ping(Bytes::from_static(&[4, 2]))) + .await + .is_ok() + { + println!("WS >>> Pinged {who}..."); + } else { + println!("WS >>> Could not send ping {who}!"); + return; + } + + if let Some(msg) = socket.recv().await { + if let Ok(msg) = msg { + if process_message(msg, who).is_break() { + return; + } + } else { + println!("WS >>> Client {who} abruptly disconnected"); + return; + } + } + + let (mut sender, mut receiver) = socket.split(); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + if process_message(msg, who).is_break() { + break; + } + } + }); + let mut send_task = tokio::spawn(async move { + let mut event_listener = state.event_bus.subscribe(); + loop { + match event_listener.recv().await { + Err(_) => (), + Ok(event) => { + match event { + events::Event::WebsocketBroadcast(message) => { + let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await; + } + } + } + } + } + }); + + tokio::select! { + rv_a = (&mut send_task) => { + match rv_a { + Ok(()) => println!("WS >>> Sender connection with {who} gracefully shut down"), + Err(a) => println!("WS >>> Error sending messages {a:?}") + } + recv_task.abort(); + }, + rv_b = (&mut recv_task) => { + match rv_b { + Ok(()) => println!("WS >>> Receiver connection with {who} gracefully shut down"), + Err(b) => println!("WS >>> Error receiving messages {b:?}") + } + send_task.abort(); + } + } +} + +fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { + match msg { + Message::Text(t) => { + println!("WS >>> {who} sent str: {t:?}"); + } + Message::Binary(d) => { + println!("WS >>> {who} sent {} bytes: {d:?}", d.len()); + } + Message::Close(c) => { + if let Some(cf) = c { + println!( + "WS >>> {who} sent close with code {} and reason `{}`", + cf.code, cf.reason + ); + } else { + println!("WS >>> {who} somehow sent close message without CloseFrame"); + } + return ControlFlow::Break(()); + } + + Message::Pong(v) => { + println!("WS >>> {who} sent pong with {v:?}"); + } + Message::Ping(v) => { + println!("WS >>> {who} sent ping with {v:?}"); + } + } + ControlFlow::Continue(()) +} diff --git a/src/utils/events.rs b/src/utils/events.rs new file mode 100644 index 0000000..6ca8285 --- /dev/null +++ b/src/utils/events.rs @@ -0,0 +1,28 @@ +use std::sync::Arc; + +use serde_json::{json, Value}; + +use crate::entities::owner; + +#[derive(Clone)] +pub enum Event { + WebsocketBroadcast(WebsocketMessage) +} + +#[derive(Clone)] +pub enum WebsocketMessage { + NewOwner(Arc) +} + +impl WebsocketMessage { + pub fn to_json(self) -> Value { + json!({ + "type": match self { + Self::NewOwner(_) => "new_owner", + }, + "data": match self { + Self::NewOwner(owner) => json!(owner), + } + }) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a9f042c..82c29d2 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,4 @@ pub mod open_library; pub mod serde; +pub mod events; +