From 82f2bb4a6180022d980c3ac3dde55bd7d2273407 Mon Sep 17 00:00:00 2001 From: Ninjdai Date: Sat, 9 Aug 2025 17:43:31 +0200 Subject: [PATCH] feat: tests and big refactor --- Cargo.toml | 2 +- src/lib.rs | 225 ++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 223 ++----------------------------------------- src/routes/auth.rs | 27 +++--- src/utils/auth.rs | 7 ++ src/utils/cli.rs | 9 +- tests/auth.rs | 61 ++++++++++++ tests/common/mod.rs | 86 +++++++++++++++++ 8 files changed, 402 insertions(+), 238 deletions(-) create mode 100644 src/lib.rs create mode 100644 tests/auth.rs create mode 100644 tests/common/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 57d09c9..c335d2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ strip = true # Strip symbols from binary* axum = { version = "0.8.4", features = [ "macros", "ws", "tokio" ] } axum-extra = { version = "0.10.1", features = ["typed-header"] } dotenvy = "0.15.7" -reqwest = "0.12.22" +reqwest = { version = "0.12.22", features = ["json"] } sea-orm = { version = "1.1.13", features = [ "sqlx-sqlite", "runtime-tokio-native-tls", "macros" ] } serde = "1.0.219" serde_json = "1.0.140" diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..4c824de --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,225 @@ +use std::{net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, LazyLock}}; + +use axum::{extract::State, http::HeaderMap, middleware, routing::get}; +use clap::{Parser, Subcommand}; +use reqwest::{header::USER_AGENT}; +use sea_orm::{ConnectionTrait, DatabaseConnection, DbErr, EntityTrait, PaginatorTrait, Schema}; +use tokio::{sync::broadcast::{self, Sender}}; +use utoipa::{openapi::{security::{HttpAuthScheme, HttpBuilder, SecurityScheme}, ContactBuilder, InfoBuilder, LicenseBuilder}, Modify, OpenApi}; +use utoipa_axum::router::OpenApiRouter; +use utoipa_swagger_ui::{Config, SwaggerUi}; +use utoipa_axum::routes; + +use crate::{entities::prelude::BookInstance, routes::auth::{Keys, DEFAULT_TOKEN_EXPIRY_TIME}, utils::events::Event}; + +pub mod entities; +pub mod utils; +pub mod routes; + +#[derive(Parser)] +#[command(name = "Alexandria")] +#[command(version = "1.0")] +#[command(about = "BAL management server", long_about = None)] +pub struct Cli { + /// Path to the sqlite database [default: ./alexandria.db] + #[arg(long, short, global = true, value_name = "FILE")] + pub database: Option, + + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand, Clone)] +pub enum Command { + /// Serves the web server + Run { + /// Port on which to serve the web server + #[arg(short, long, default_value_t = 3000)] + port: u16, + /// How many seconds generated JWTs are valid for. Default equates to 6 months + #[arg(long, default_value_t = DEFAULT_TOKEN_EXPIRY_TIME)] + token_expiration_time: u64, + }, + /// Open a TUI to manage user accounts + User +} + +impl Cli { + pub fn command(&self) -> Command { + self.command.clone().unwrap_or(Command::Run { port: 3000, token_expiration_time: DEFAULT_TOKEN_EXPIRY_TIME }) + } +} + +pub struct AppState { + app_name: String, + db_conn: Arc, + event_bus: Sender, + web_client: reqwest::Client +} + +async fn index( + State(state): State> +) ->String { + let app_name = &state.app_name; + let db_conn = &state.db_conn; + let status: &str = match db_conn.ping().await { + Ok(_) => "working", + Err(_) => "erroring" + }; + let book_count = BookInstance::find().count(db_conn.as_ref()).await.unwrap(); + format!("Hello from {app_name}! Database is {status}. We currently have {book_count} books in stock !") +} + +static KEYS: LazyLock = LazyLock::new(|| { + let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); + Keys::new(secret.as_bytes()) +}); + +pub static CLI: LazyLock = LazyLock::new(|| { + Cli::parse() +}); + +pub async fn create_tables(db_conn: &C) -> Result<(), DbErr> + where C: ConnectionTrait,{ + let builder = db_conn.get_database_backend(); + let schema = Schema::new(builder); + if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Book).if_not_exists())).await { + log::error!(target: "database", "Error while creating book table: {err:?}"); + return Err(err); + } + if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::BookInstance).if_not_exists())).await { + log::error!(target: "database", "Error while creating book_instance table: {err:?}"); + return Err(err); + } + if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Owner).if_not_exists())).await { + log::error!(target: "database", "Error while creating owner table: {err:?}"); + return Err(err); + } + if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::User).if_not_exists())).await { + log::error!(target: "database", "Error while creating user table: {err:?}"); + return Err(err); + } + if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Bal).if_not_exists())).await { + log::error!(target: "database", "Error while creating bal table: {err:?}"); + return Err(err); + } + Ok(()) +} + +pub async fn run_server(db: Arc, port: u16, serve_docs: bool) { + let (event_bus, _) = broadcast::channel(16); + + if std::env::var("JWT_SECRET").is_err() { + log::error!("JWT_SECRET is not set"); + return; + } + + let mut default_headers = HeaderMap::new(); + default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap()); + let shared_state = Arc::new(AppState { + app_name: "Alexandria".to_string(), + db_conn: db, + event_bus, + web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed") + }); + + + #[derive(OpenApi)] + #[openapi( + tags( + (name = "book-api", description = "Book management endpoints.") + ), + modifiers(&SecurityAddon) + )] + struct ApiDoc; + + struct SecurityAddon; + + impl Modify for SecurityAddon { + fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { + let components = openapi.components.as_mut().unwrap(); + components.add_security_scheme( + "jwt", + SecurityScheme::Http( + HttpBuilder::new().scheme(HttpAuthScheme::Bearer).bearer_format("JWT").build() + ) + ) + } + } + + let open_api_router = OpenApiRouter::new() + // Book API + .routes(routes!(routes::book::get_book_by_ean)) + .routes(routes!(routes::book::get_book_by_id)) + .routes(routes!(routes::book::create_book)) + // Book Instance API + .routes(routes!(routes::book_instance::get_book_instance_by_id)) + .routes(routes!(routes::book_instance::create_book_instance)) + .routes(routes!(routes::book_instance::update_book_instance)) + .routes(routes!(routes::book_instance::sell_book_instance)) + .routes(routes!(routes::book_instance::bulk_create_book_instance)) + .routes(routes!(routes::book_instance::get_bal_owner_book_instances)) + .routes(routes!(routes::book_instance::get_bal_book_instances_by_ean)) + .routes(routes!(routes::book_instance::search_bal_book_instances)) + // Owner API + .routes(routes!(routes::owner::get_owner_by_id)) + .routes(routes!(routes::owner::create_owner)) + .routes(routes!(routes::owner::update_owner)) + .routes(routes!(routes::owner::get_owners)) + // Bal API + .routes(routes!(routes::bal::get_bal_by_id)) + .routes(routes!(routes::bal::create_bal)) + .routes(routes!(routes::bal::update_bal)) + .routes(routes!(routes::bal::get_bals)) + .routes(routes!(routes::bal::get_current_bal)) + .routes(routes!(routes::bal::set_current_bal)) + // Authentication + .route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware)) + .routes(routes!(routes::auth::auth)) + .routes(routes!(routes::auth::check_token)) + // Misc + .routes(routes!(routes::misc::current_api_version)) + .routes(routes!(routes::websocket::ws_handler)) + + .with_state(shared_state.clone()); + + let (mut router, mut api) = OpenApiRouter::new() + .nest("/api", open_api_router) + .route("/", get(index)) // temporary index page, will redirect/proxy to flutter app + .with_state(shared_state) + .split_for_parts(); + + if serve_docs { + api.info = InfoBuilder::new() + .title("Alexandria") + .description(Some("Alexandria is a server that manages books and users for Union Étudiante's book exchange")) + .contact(Some(ContactBuilder::new() + .url(Some("https://ueauvergne.fr")) + .name(Some("Union Étudiante Auvergne")) + .email(Some("unionetudianteauvergne@gmail.com")) + .build())) + .license(Some(LicenseBuilder::new().name("MIT").url(Some("https://spdx.org/licenses/MIT.html")).build())) + .version("1.0.0") + .build(); + + api.merge(ApiDoc::openapi()); + + let swagger = SwaggerUi::new("/docs/") + .url("/docs/openapi.json", api) + .config(Config::default() + .try_it_out_enabled(true) + .filter(true) + .display_request_duration(true) + .persist_authorization(true) + ); + + router = router.merge(swagger); + } + + let listener = tokio::net::TcpListener::bind(SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)).await.unwrap(); + log::info!("Running on http://{}", listener.local_addr().unwrap()); + axum::serve( + listener, + router.into_make_service_with_connect_info::() + ).await.unwrap() +} diff --git a/src/main.rs b/src/main.rs index d5a8b86..af9dc3d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,77 +1,8 @@ -use std::{net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, LazyLock}}; +use std::sync::Arc; -use axum::{extract::State, http::HeaderMap, middleware, routing::get}; -use clap::{Parser, Subcommand}; -use reqwest::{header::USER_AGENT}; -use sea_orm::{ConnectionTrait, Database, DatabaseConnection, DbErr, EntityTrait, PaginatorTrait, Schema}; -use tokio::{sync::broadcast::{self, Sender}}; -use utoipa::{openapi::{security::{HttpAuthScheme, HttpBuilder, SecurityScheme}, ContactBuilder, InfoBuilder, LicenseBuilder}, Modify, OpenApi}; -use utoipa_axum::router::OpenApiRouter; -use utoipa_swagger_ui::{Config, SwaggerUi}; -use utoipa_axum::routes; +use alexandria::{create_tables, run_server, utils, Command, CLI}; +use sea_orm::{Database, DatabaseConnection}; -use crate::{entities::prelude::BookInstance, routes::auth::Keys, utils::events::Event}; - -pub mod entities; -pub mod utils; -pub mod routes; - -#[derive(Parser)] -#[command(name = "Alexandria")] -#[command(version = "1.0")] -#[command(about = "BAL management server", long_about = None)] -struct Cli { - /// Path to the sqlite database [default: ./alexandria.db] - #[arg(long, short, global = true, value_name = "FILE")] - database: Option, - - #[command(subcommand)] - command: Commands, -} - -#[derive(Subcommand)] -enum Commands { - /// Serves the web server - Run { - /// Port on which to serve the web server - #[arg(short, long, default_value_t = 3000)] - port: u16, - /// How many seconds generated JWTs are valid for. Default equates to 6 months - #[arg(long, default_value_t = 15_778_476)] - token_expiration_time: u64, - }, - /// Open a TUI to manage user accounts - User -} - -pub struct AppState { - app_name: String, - db_conn: Arc, - event_bus: Sender, - web_client: reqwest::Client -} - -async fn index( - State(state): State> -) ->String { - let app_name = &state.app_name; - let db_conn = &state.db_conn; - let status: &str = match db_conn.ping().await { - Ok(_) => "working", - Err(_) => "erroring" - }; - let book_count = BookInstance::find().count(db_conn.as_ref()).await.unwrap(); - format!("Hello from {app_name}! Database is {status}. We currently have {book_count} books in stock !") -} - -static KEYS: LazyLock = LazyLock::new(|| { - let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); - Keys::new(secret.as_bytes()) -}); - -static CLI: LazyLock = LazyLock::new(|| { - Cli::parse() -}); #[tokio::main] async fn main() { @@ -105,151 +36,9 @@ async fn main() { return; }; - match &CLI.command { - Commands::Run {port,..} => run_server(db, *port).await, - Commands::User => utils::cli::manage_users(db).await + match &CLI.command() { + Command::Run {port,..} => run_server(db, *port, true).await, + Command::User => utils::cli::manage_users(db).await } } -async fn create_tables(db_conn: &C) -> Result<(), DbErr> - where C: ConnectionTrait,{ - let builder = db_conn.get_database_backend(); - let schema = Schema::new(builder); - if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Book).if_not_exists())).await { - log::error!(target: "database", "Error while creating book table: {err:?}"); - return Err(err); - } - if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::BookInstance).if_not_exists())).await { - log::error!(target: "database", "Error while creating book_instance table: {err:?}"); - return Err(err); - } - if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Owner).if_not_exists())).await { - log::error!(target: "database", "Error while creating owner table: {err:?}"); - return Err(err); - } - if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::User).if_not_exists())).await { - log::error!(target: "database", "Error while creating user table: {err:?}"); - return Err(err); - } - if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Bal).if_not_exists())).await { - log::error!(target: "database", "Error while creating bal table: {err:?}"); - return Err(err); - } - Ok(()) -} - -async fn run_server(db: Arc, port: u16) { - let (event_bus, _) = broadcast::channel(16); - - if std::env::var("JWT_SECRET").is_err() { - log::error!("JWT_SECRET is not set"); - return; - } - - let mut default_headers = HeaderMap::new(); - default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap()); - let shared_state = Arc::new(AppState { - app_name: "Alexandria".to_string(), - db_conn: db, - event_bus, - web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed") - }); - - - #[derive(OpenApi)] - #[openapi( - tags( - (name = "book-api", description = "Book management endpoints.") - ), - modifiers(&SecurityAddon) - )] - struct ApiDoc; - - struct SecurityAddon; - - impl Modify for SecurityAddon { - fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { - let components = openapi.components.as_mut().unwrap(); - components.add_security_scheme( - "jwt", - SecurityScheme::Http( - HttpBuilder::new().scheme(HttpAuthScheme::Bearer).bearer_format("JWT").build() - ) - ) - } - } - - let open_api_router = OpenApiRouter::new() - // Book API - .routes(routes!(routes::book::get_book_by_ean)) - .routes(routes!(routes::book::get_book_by_id)) - .routes(routes!(routes::book::create_book)) - // Book Instance API - .routes(routes!(routes::book_instance::get_book_instance_by_id)) - .routes(routes!(routes::book_instance::create_book_instance)) - .routes(routes!(routes::book_instance::update_book_instance)) - .routes(routes!(routes::book_instance::sell_book_instance)) - .routes(routes!(routes::book_instance::bulk_create_book_instance)) - .routes(routes!(routes::book_instance::get_bal_owner_book_instances)) - .routes(routes!(routes::book_instance::get_bal_book_instances_by_ean)) - .routes(routes!(routes::book_instance::search_bal_book_instances)) - // Owner API - .routes(routes!(routes::owner::get_owner_by_id)) - .routes(routes!(routes::owner::create_owner)) - .routes(routes!(routes::owner::update_owner)) - .routes(routes!(routes::owner::get_owners)) - // Bal API - .routes(routes!(routes::bal::get_bal_by_id)) - .routes(routes!(routes::bal::create_bal)) - .routes(routes!(routes::bal::update_bal)) - .routes(routes!(routes::bal::get_bals)) - .routes(routes!(routes::bal::get_current_bal)) - .routes(routes!(routes::bal::set_current_bal)) - // Authentication - .route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware)) - .routes(routes!(routes::auth::auth)) - .routes(routes!(routes::auth::check_token)) - // Misc - .routes(routes!(routes::misc::current_api_version)) - .routes(routes!(routes::websocket::ws_handler)) - - .with_state(shared_state.clone()); - - let (router, mut api) = OpenApiRouter::new() - .nest("/api", open_api_router) - .route("/", get(index)) // temporary index page, will redirect/proxy to flutter app - .with_state(shared_state) - .split_for_parts(); - - api.info = InfoBuilder::new() - .title("Alexandria") - .description(Some("Alexandria is a server that manages books and users for Union Étudiante's book exchange")) - .contact(Some(ContactBuilder::new() - .url(Some("https://ueauvergne.fr")) - .name(Some("Union Étudiante Auvergne")) - .email(Some("unionetudianteauvergne@gmail.com")) - .build())) - .license(Some(LicenseBuilder::new().name("MIT").url(Some("https://spdx.org/licenses/MIT.html")).build())) - .version("1.0.0") - .build(); - - api.merge(ApiDoc::openapi()); - - let swagger = SwaggerUi::new("/docs/") - .url("/docs/openapi.json", api) - .config(Config::default() - .try_it_out_enabled(true) - .filter(true) - .display_request_duration(true) - .persist_authorization(true) - ); - - let router = router.merge(swagger); - - let listener = tokio::net::TcpListener::bind(SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)).await.unwrap(); - log::info!("Running on http://{}", listener.local_addr().unwrap()); - axum::serve( - listener, - router.into_make_service_with_connect_info::() - ).await.unwrap(); -} diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 8cbe842..af671aa 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -7,9 +7,9 @@ use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::{entities::user, AppState, Commands, CLI, KEYS}; +use crate::{entities::user, AppState, KEYS}; -//const TOKEN_EXPIRY_TIME: u64 = 15_778_476; // 6 Months +pub const DEFAULT_TOKEN_EXPIRY_TIME: u64 = 15_778_476; // 6 Months pub async fn auth_middleware( _claims: Claims, @@ -44,15 +44,14 @@ pub async fn auth(State(state): State>, Json(payload): Json return Err(AuthError::WrongCredentials), Ok(Some(user)) => { - user.verify_password(payload.password); + if !user.verify_password(payload.password) { + return Err(AuthError::WrongCredentials); + }; let unix_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time should go forward").as_secs(); let claims = Claims { sub: user.username, - exp: unix_timestamp + match CLI.command { - Commands::Run { token_expiration_time, .. } => token_expiration_time, - _ => panic!("The auth endpoint cannot be used outside of a Run command") - }, + exp: unix_timestamp + DEFAULT_TOKEN_EXPIRY_TIME, user_id: user.id }; let token = encode(&Header::default(), &claims, &KEYS.encoding) @@ -63,6 +62,10 @@ pub async fn auth(State(state): State>, Json(payload): Json String { + encode(&Header::default(), &claims, &KEYS.encoding).unwrap() +} + #[derive(Deserialize, utoipa::ToSchema)] pub struct TokenPayload { token: String @@ -152,15 +155,15 @@ impl Keys { #[derive(Debug, Serialize, Deserialize)] pub struct Claims { - sub: String, - exp: u64, + pub sub: String, + pub exp: u64, pub user_id: u32, } -#[derive(Debug, Serialize, utoipa::ToSchema)] +#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)] pub struct AuthBody { - access_token: String, - token_type: String, + pub access_token: String, + pub token_type: String, } #[derive(Debug, Deserialize, utoipa::ToSchema)] diff --git a/src/utils/auth.rs b/src/utils/auth.rs index 1bed6e5..5e44b3c 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -1,3 +1,5 @@ +use argon2::{Argon2, PasswordHasher}; +use password_hash::{rand_core::OsRng, SaltString}; use sea_orm::ConnectionTrait; use crate::entities::prelude::*; @@ -40,3 +42,8 @@ where C: ConnectionTrait, None => false } } + +pub fn hash_password(password: String) -> String { + let salt = SaltString::generate(&mut OsRng); + Argon2::default().hash_password(&password.clone().into_bytes(), &salt).unwrap().to_string() +} diff --git a/src/utils/cli.rs b/src/utils/cli.rs index 8956350..dbc0097 100644 --- a/src/utils/cli.rs +++ b/src/utils/cli.rs @@ -1,11 +1,9 @@ use std::{fmt::Display, sync::Arc}; -use argon2::{password_hash::{SaltString}, Argon2, PasswordHasher}; use inquire::{min_length, prompt_text, Confirm, Password, Select, Text}; -use password_hash::rand_core::OsRng; use sea_orm::{ActiveModelTrait, ActiveValue::{NotSet, Set}, ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter}; -use crate::entities::{owner, prelude::User, user::{self, ActiveModel}}; +use crate::{entities::{owner, prelude::User, user::{self, ActiveModel}}, utils::auth::hash_password}; #[derive(Debug, Copy, Clone)] enum Action { @@ -146,11 +144,6 @@ pub async fn manage_users(db: Arc) { } } -fn hash_password(password: String) -> String { - let salt = SaltString::generate(&mut OsRng); - Argon2::default().hash_password(&password.clone().into_bytes(), &salt).unwrap().to_string() -} - async fn select_user(db: Arc) -> Option { let users = User::find().all(db.as_ref()).await.unwrap(); if users.is_empty() { diff --git a/tests/auth.rs b/tests/auth.rs new file mode 100644 index 0000000..2222738 --- /dev/null +++ b/tests/auth.rs @@ -0,0 +1,61 @@ +use std::collections::HashMap; + +use alexandria::routes::auth::AuthBody; +use reqwest::StatusCode; + +mod common; + +#[tokio::test] +async fn auth_wrong_password() { + let data = common::setup().await; + let client = reqwest::Client::new(); + + let mut wrong_pwd_auth_map = HashMap::new(); + wrong_pwd_auth_map.insert("username", "test_username"); + wrong_pwd_auth_map.insert("password", "pwd"); + let wrong_pwd_auth_res = client.execute(client.post(format!("{}/auth", data.api_path)).json(&wrong_pwd_auth_map).build().unwrap()).await.unwrap(); + assert_eq!(wrong_pwd_auth_res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn auth_wrong_username() { + let data = common::setup().await; + let client = reqwest::Client::new(); + + let mut wrong_username_auth_map = HashMap::new(); + wrong_username_auth_map.insert("username", "wrong_username"); + wrong_username_auth_map.insert("password", "test_password"); + let wrong_username_auth_res = client.execute(client.post(format!("{}/auth", data.api_path)).json(&wrong_username_auth_map).build().unwrap()).await.unwrap(); + assert_eq!(wrong_username_auth_res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn auth_correct_credentials() { + let data = common::setup().await; + let client = reqwest::Client::new(); + + let mut auth_map = HashMap::new(); + auth_map.insert("username", "test_username"); + auth_map.insert("password", "test_password"); + let auth_res = client.execute(client.post(format!("{}/auth", data.api_path)).json(&auth_map).build().unwrap()).await.unwrap(); + assert_eq!(auth_res.status(), StatusCode::OK); + let auth_body = auth_res.json::().await.unwrap(); + + let mut check_token_map = HashMap::new(); + check_token_map.insert("token", auth_body.access_token); + let check_token_res = client.execute(client.post(format!("{}/token-check", data.api_path)).json(&check_token_map).build().unwrap()).await.unwrap(); + let valid_token = check_token_res.json::().await.unwrap(); + assert_eq!(valid_token, true); +} + +#[tokio::test] +async fn auth_wrong_token_check() { + let data = common::setup().await; + let client = reqwest::Client::new(); + + let mut check_wrong_token_map = HashMap::new(); + check_wrong_token_map.insert("token", "this-is-definitely-not-a-wrong-token"); + let check_wrong_token_res = client.execute(client.post(format!("{}/token-check", data.api_path)).json(&check_wrong_token_map).build().unwrap()).await.unwrap(); + let invalid_token = check_wrong_token_res.json::().await.unwrap(); + assert_eq!(invalid_token, false); +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..1831471 --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,86 @@ +use std::{net::{Ipv4Addr, SocketAddrV4, TcpListener}, sync::Arc, time::{SystemTime, UNIX_EPOCH}}; + +use sea_orm::{ActiveModelTrait, ActiveValue::{NotSet, Set}, ColumnTrait, ConnectionTrait, Database, DatabaseConnection, EntityTrait, QueryFilter}; + +use alexandria::{create_tables, entities::{owner, prelude::*, user}, routes::auth::{generate_token_from_claims, Claims, DEFAULT_TOKEN_EXPIRY_TIME}, run_server, utils::auth::hash_password}; + +pub struct SetupData { + /// A valid JWT for testing features that need authentication + pub jwt: String, + pub api_path: String +} + +/// Common setup function for tests that require a database and server setup +pub async fn setup() -> SetupData { + let _ = pretty_env_logger::try_init(); + + let db: Arc = Arc::new( + match Database::connect(format!("sqlite::memory:?mode=rwc")).await { + Ok(c) => c, + Err(e) => { + panic!("Error while opening fatabase: {}", e.to_string()) + } + }); + + create_tables(db.as_ref()).await.expect("Create tables should not fail"); + + let port = free_local_ipv4_port().expect("Could not get a free port"); + let db_c = db.clone(); + tokio::spawn(async move { + run_server(db_c, port, false).await; + }); + + create_user(db.as_ref(), "test_username", "test_password").await; + + + let unix_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time should go forward").as_secs(); + + let claims = Claims { + sub: "test_username".to_string(), + exp: unix_timestamp + DEFAULT_TOKEN_EXPIRY_TIME, + user_id: 1 + }; + + SetupData { + jwt: generate_token_from_claims(claims), + api_path: format!("http://0.0.0.0:{port}/api") + } +} + + +fn free_local_ipv4_port() -> Option { + let socket = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0); + TcpListener::bind(socket) + .and_then(|listener| listener.local_addr()) + .map(|addr| addr.port()) + .ok() +} + +async fn create_user(db_conn: &C, username_t: impl ToString, password_t: impl ToString) +where C: ConnectionTrait { + let username = username_t.to_string(); + if User::find().filter(user::Column::Username.eq(username.clone())).one(db_conn).await.is_ok_and(|r| r.is_some()) { + panic!("Username {username} already in use"); + } else { + let password = password_t.to_string(); + let mut new_user = user::ActiveModel { + id: NotSet, + username: Set(username.clone()), + hashed_password: Set(hash_password(password)), + current_bal_id: Set(None), + owner_id: Set(None) + }; + let res = new_user.clone().insert(db_conn).await.unwrap(); + + let new_owner = owner::ActiveModel { + id: NotSet, + user_id: Set(res.id), + first_name: Set(format!("{username} first name")), + last_name: Set(format!("{username} last name")), + contact: Set(format!("{username}@mail.com")) + }; + let owner_res = new_owner.insert(db_conn).await.unwrap(); + new_user.owner_id = Set(Some(owner_res.id)); + let _ = new_user.update(db_conn); + } +}