initial websocket implementation

This commit is contained in:
Ninjdai 2025-08-01 01:27:25 +02:00
parent 4aa5cf463f
commit 3e1c744db1
7 changed files with 244 additions and 4 deletions

57
Cargo.lock generated
View file

@ -43,6 +43,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"dotenvy", "dotenvy",
"futures-util",
"reqwest", "reqwest",
"sea-orm", "sea-orm",
"serde", "serde",
@ -158,6 +159,7 @@ checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"axum-macros", "axum-macros",
"base64",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
@ -177,8 +179,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@ -473,6 +477,12 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]] [[package]]
name = "der" name = "der"
version = "0.7.10" version = "0.7.10"
@ -704,6 +714,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.31" version = "0.3.31"
@ -724,6 +745,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr", "memchr",
@ -2875,6 +2897,18 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.15" version = "0.7.15"
@ -2989,6 +3023,23 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand 0.9.2",
"sha1",
"thiserror",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.18.0" version = "1.18.0"
@ -3045,6 +3096,12 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf8_iter" name = "utf8_iter"
version = "1.0.4" version = "1.0.4"

View file

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
axum = { version = "0.8.4", features = [ "macros" ] } axum = { version = "0.8.4", features = [ "macros", "ws", "tokio" ] }
dotenvy = "0.15.7" dotenvy = "0.15.7"
reqwest = "0.12.22" reqwest = "0.12.22"
sea-orm = { version = "1.1.13", features = [ "sqlx-sqlite", "runtime-tokio-rustls", "macros" ] } sea-orm = { version = "1.1.13", features = [ "sqlx-sqlite", "runtime-tokio-rustls", "macros" ] }
@ -15,4 +15,5 @@ utoipa = "5.4.0"
utoipa-axum = "0.2.0" utoipa-axum = "0.2.0"
utoipa-swagger-ui = { version = "9", features = ["axum", "reqwest"] } utoipa-swagger-ui = { version = "9", features = ["axum", "reqwest"] }
utoipa-redoc = { version = "6", features = ["axum"] } utoipa-redoc = { version = "6", features = ["axum"] }
futures-util = "0.3.31"

View file

@ -1,15 +1,16 @@
use std::sync::Arc; use std::{net::SocketAddr, sync::Arc, time::Duration};
use axum::{extract::State, http::HeaderMap, routing::get}; use axum::{extract::State, http::HeaderMap, routing::get};
use reqwest::{header::USER_AGENT}; use reqwest::{header::USER_AGENT};
use sea_orm::{ConnectionTrait, Database, DatabaseConnection, EntityTrait, PaginatorTrait, Schema}; use sea_orm::{ConnectionTrait, Database, DatabaseConnection, EntityTrait, PaginatorTrait, Schema};
use tokio::{sync::broadcast::{self, Sender}, task, time};
use utoipa::openapi::{ContactBuilder, InfoBuilder, LicenseBuilder}; use utoipa::openapi::{ContactBuilder, InfoBuilder, LicenseBuilder};
use utoipa_axum::router::OpenApiRouter; use utoipa_axum::router::OpenApiRouter;
use utoipa_redoc::{Redoc, Servable}; use utoipa_redoc::{Redoc, Servable};
use utoipa_swagger_ui::{Config, SwaggerUi}; use utoipa_swagger_ui::{Config, SwaggerUi};
use utoipa_axum::routes; use utoipa_axum::routes;
use crate::entities::prelude::BookInstance; use crate::{entities::{owner, prelude::BookInstance}, utils::events::Event};
pub mod entities; pub mod entities;
pub mod utils; pub mod utils;
@ -18,6 +19,7 @@ pub mod routes;
pub struct AppState { pub struct AppState {
app_name: String, app_name: String,
db_conn: Arc<DatabaseConnection>, db_conn: Arc<DatabaseConnection>,
event_bus: Sender<Event>,
web_client: reqwest::Client web_client: reqwest::Client
} }
@ -53,11 +55,31 @@ async fn main() {
return; return;
} }
let (event_bus, _) = broadcast::channel(16);
let ntx = event_bus.clone();
let _forever = task::spawn(async move {
let mut interval = time::interval(Duration::from_secs(5));
let mut id = 1;
loop {
interval.tick().await;
let _ = ntx.send(Event::WebsocketBroadcast(utils::events::WebsocketMessage::NewOwner(Arc::new(owner::Model {
id,
first_name: "Avril".to_string(),
last_name: "Papillon".to_string(),
contact: "avril.papillon@proton.me".to_string()
}))));
id += 1;
}
});
let mut default_headers = HeaderMap::new(); let mut default_headers = HeaderMap::new();
default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap()); default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap());
let shared_state = Arc::new(AppState { let shared_state = Arc::new(AppState {
app_name: "Alexandria".to_string(), app_name: "Alexandria".to_string(),
db_conn: db, db_conn: db,
event_bus,
web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed") web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed")
}); });
@ -68,6 +90,7 @@ async fn main() {
.routes(routes!(routes::book_instance::create_book_instance)) .routes(routes!(routes::book_instance::create_book_instance))
.routes(routes!(routes::owner::get_owner_by_id)) .routes(routes!(routes::owner::get_owner_by_id))
.routes(routes!(routes::owner::create_owner)) .routes(routes!(routes::owner::create_owner))
.route("/ws", get(routes::websocket::ws_handler))
.route("/", get(index)) .route("/", get(index))
.with_state(shared_state) .with_state(shared_state)
.split_for_parts(); .split_for_parts();
@ -95,6 +118,10 @@ async fn main() {
let router = router.merge(swagger); let router = router.merge(swagger);
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap(); println!("Running on http://{}", listener.local_addr().unwrap());
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>()
).await.unwrap();
} }

View file

@ -1,4 +1,5 @@
pub mod book; pub mod book;
pub mod book_instance; pub mod book_instance;
pub mod owner; pub mod owner;
pub mod websocket;

124
src/routes/websocket.rs Normal file
View file

@ -0,0 +1,124 @@
use std::{net::SocketAddr, ops::ControlFlow, sync::Arc};
use axum::{
body::Bytes,
extract::{
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
ConnectInfo,
State
},
response::IntoResponse
};
use serde_json::json;
use crate::{utils::events, AppState};
use futures_util::{sink::SinkExt, stream::StreamExt};
#[axum::debug_handler]
pub async fn ws_handler(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<Arc<AppState>>
) -> impl IntoResponse {
println!("`{addr} connected.");
// finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, addr, state))
}
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, state: Arc<AppState>) {
if socket
.send(Message::Ping(Bytes::from_static(&[4, 2])))
.await
.is_ok()
{
println!("WS >>> Pinged {who}...");
} else {
println!("WS >>> Could not send ping {who}!");
return;
}
if let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
if process_message(msg, who).is_break() {
return;
}
} else {
println!("WS >>> Client {who} abruptly disconnected");
return;
}
}
let (mut sender, mut receiver) = socket.split();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if process_message(msg, who).is_break() {
break;
}
}
});
let mut send_task = tokio::spawn(async move {
let mut event_listener = state.event_bus.subscribe();
loop {
match event_listener.recv().await {
Err(_) => (),
Ok(event) => {
match event {
events::Event::WebsocketBroadcast(message) => {
let _ = sender.send(Message::Text(Utf8Bytes::from(message.to_json().to_string()))).await;
}
}
}
}
}
});
tokio::select! {
rv_a = (&mut send_task) => {
match rv_a {
Ok(()) => println!("WS >>> Sender connection with {who} gracefully shut down"),
Err(a) => println!("WS >>> Error sending messages {a:?}")
}
recv_task.abort();
},
rv_b = (&mut recv_task) => {
match rv_b {
Ok(()) => println!("WS >>> Receiver connection with {who} gracefully shut down"),
Err(b) => println!("WS >>> Error receiving messages {b:?}")
}
send_task.abort();
}
}
}
fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
match msg {
Message::Text(t) => {
println!("WS >>> {who} sent str: {t:?}");
}
Message::Binary(d) => {
println!("WS >>> {who} sent {} bytes: {d:?}", d.len());
}
Message::Close(c) => {
if let Some(cf) = c {
println!(
"WS >>> {who} sent close with code {} and reason `{}`",
cf.code, cf.reason
);
} else {
println!("WS >>> {who} somehow sent close message without CloseFrame");
}
return ControlFlow::Break(());
}
Message::Pong(v) => {
println!("WS >>> {who} sent pong with {v:?}");
}
Message::Ping(v) => {
println!("WS >>> {who} sent ping with {v:?}");
}
}
ControlFlow::Continue(())
}

28
src/utils/events.rs Normal file
View file

@ -0,0 +1,28 @@
use std::sync::Arc;
use serde_json::{json, Value};
use crate::entities::owner;
#[derive(Clone)]
pub enum Event {
WebsocketBroadcast(WebsocketMessage)
}
#[derive(Clone)]
pub enum WebsocketMessage {
NewOwner(Arc<owner::Model>)
}
impl WebsocketMessage {
pub fn to_json(self) -> Value {
json!({
"type": match self {
Self::NewOwner(_) => "new_owner",
},
"data": match self {
Self::NewOwner(owner) => json!(owner),
}
})
}
}

View file

@ -1,2 +1,4 @@
pub mod open_library; pub mod open_library;
pub mod serde; pub mod serde;
pub mod events;