feat: tests and big refactor
This commit is contained in:
parent
05e8366611
commit
82f2bb4a61
8 changed files with 402 additions and 238 deletions
|
|
@ -13,7 +13,7 @@ strip = true # Strip symbols from binary*
|
|||
axum = { version = "0.8.4", features = [ "macros", "ws", "tokio" ] }
|
||||
axum-extra = { version = "0.10.1", features = ["typed-header"] }
|
||||
dotenvy = "0.15.7"
|
||||
reqwest = "0.12.22"
|
||||
reqwest = { version = "0.12.22", features = ["json"] }
|
||||
sea-orm = { version = "1.1.13", features = [ "sqlx-sqlite", "runtime-tokio-native-tls", "macros" ] }
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
|
|
|
|||
225
src/lib.rs
Normal file
225
src/lib.rs
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
use std::{net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, LazyLock}};
|
||||
|
||||
use axum::{extract::State, http::HeaderMap, middleware, routing::get};
|
||||
use clap::{Parser, Subcommand};
|
||||
use reqwest::{header::USER_AGENT};
|
||||
use sea_orm::{ConnectionTrait, DatabaseConnection, DbErr, EntityTrait, PaginatorTrait, Schema};
|
||||
use tokio::{sync::broadcast::{self, Sender}};
|
||||
use utoipa::{openapi::{security::{HttpAuthScheme, HttpBuilder, SecurityScheme}, ContactBuilder, InfoBuilder, LicenseBuilder}, Modify, OpenApi};
|
||||
use utoipa_axum::router::OpenApiRouter;
|
||||
use utoipa_swagger_ui::{Config, SwaggerUi};
|
||||
use utoipa_axum::routes;
|
||||
|
||||
use crate::{entities::prelude::BookInstance, routes::auth::{Keys, DEFAULT_TOKEN_EXPIRY_TIME}, utils::events::Event};
|
||||
|
||||
pub mod entities;
|
||||
pub mod utils;
|
||||
pub mod routes;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "Alexandria")]
|
||||
#[command(version = "1.0")]
|
||||
#[command(about = "BAL management server", long_about = None)]
|
||||
pub struct Cli {
|
||||
/// Path to the sqlite database [default: ./alexandria.db]
|
||||
#[arg(long, short, global = true, value_name = "FILE")]
|
||||
pub database: Option<PathBuf>,
|
||||
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Clone)]
|
||||
pub enum Command {
|
||||
/// Serves the web server
|
||||
Run {
|
||||
/// Port on which to serve the web server
|
||||
#[arg(short, long, default_value_t = 3000)]
|
||||
port: u16,
|
||||
/// How many seconds generated JWTs are valid for. Default equates to 6 months
|
||||
#[arg(long, default_value_t = DEFAULT_TOKEN_EXPIRY_TIME)]
|
||||
token_expiration_time: u64,
|
||||
},
|
||||
/// Open a TUI to manage user accounts
|
||||
User
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
pub fn command(&self) -> Command {
|
||||
self.command.clone().unwrap_or(Command::Run { port: 3000, token_expiration_time: DEFAULT_TOKEN_EXPIRY_TIME })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
app_name: String,
|
||||
db_conn: Arc<DatabaseConnection>,
|
||||
event_bus: Sender<Event>,
|
||||
web_client: reqwest::Client
|
||||
}
|
||||
|
||||
async fn index(
|
||||
State(state): State<Arc<AppState>>
|
||||
) ->String {
|
||||
let app_name = &state.app_name;
|
||||
let db_conn = &state.db_conn;
|
||||
let status: &str = match db_conn.ping().await {
|
||||
Ok(_) => "working",
|
||||
Err(_) => "erroring"
|
||||
};
|
||||
let book_count = BookInstance::find().count(db_conn.as_ref()).await.unwrap();
|
||||
format!("Hello from {app_name}! Database is {status}. We currently have {book_count} books in stock !")
|
||||
}
|
||||
|
||||
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
|
||||
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
Keys::new(secret.as_bytes())
|
||||
});
|
||||
|
||||
pub static CLI: LazyLock<Cli> = LazyLock::new(|| {
|
||||
Cli::parse()
|
||||
});
|
||||
|
||||
pub async fn create_tables<C>(db_conn: &C) -> Result<(), DbErr>
|
||||
where C: ConnectionTrait,{
|
||||
let builder = db_conn.get_database_backend();
|
||||
let schema = Schema::new(builder);
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Book).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating book table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::BookInstance).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating book_instance table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Owner).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating owner table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::User).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating user table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Bal).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating bal table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_server(db: Arc<DatabaseConnection>, port: u16, serve_docs: bool) {
|
||||
let (event_bus, _) = broadcast::channel(16);
|
||||
|
||||
if std::env::var("JWT_SECRET").is_err() {
|
||||
log::error!("JWT_SECRET is not set");
|
||||
return;
|
||||
}
|
||||
|
||||
let mut default_headers = HeaderMap::new();
|
||||
default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap());
|
||||
let shared_state = Arc::new(AppState {
|
||||
app_name: "Alexandria".to_string(),
|
||||
db_conn: db,
|
||||
event_bus,
|
||||
web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed")
|
||||
});
|
||||
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
tags(
|
||||
(name = "book-api", description = "Book management endpoints.")
|
||||
),
|
||||
modifiers(&SecurityAddon)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
struct SecurityAddon;
|
||||
|
||||
impl Modify for SecurityAddon {
|
||||
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
|
||||
let components = openapi.components.as_mut().unwrap();
|
||||
components.add_security_scheme(
|
||||
"jwt",
|
||||
SecurityScheme::Http(
|
||||
HttpBuilder::new().scheme(HttpAuthScheme::Bearer).bearer_format("JWT").build()
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let open_api_router = OpenApiRouter::new()
|
||||
// Book API
|
||||
.routes(routes!(routes::book::get_book_by_ean))
|
||||
.routes(routes!(routes::book::get_book_by_id))
|
||||
.routes(routes!(routes::book::create_book))
|
||||
// Book Instance API
|
||||
.routes(routes!(routes::book_instance::get_book_instance_by_id))
|
||||
.routes(routes!(routes::book_instance::create_book_instance))
|
||||
.routes(routes!(routes::book_instance::update_book_instance))
|
||||
.routes(routes!(routes::book_instance::sell_book_instance))
|
||||
.routes(routes!(routes::book_instance::bulk_create_book_instance))
|
||||
.routes(routes!(routes::book_instance::get_bal_owner_book_instances))
|
||||
.routes(routes!(routes::book_instance::get_bal_book_instances_by_ean))
|
||||
.routes(routes!(routes::book_instance::search_bal_book_instances))
|
||||
// Owner API
|
||||
.routes(routes!(routes::owner::get_owner_by_id))
|
||||
.routes(routes!(routes::owner::create_owner))
|
||||
.routes(routes!(routes::owner::update_owner))
|
||||
.routes(routes!(routes::owner::get_owners))
|
||||
// Bal API
|
||||
.routes(routes!(routes::bal::get_bal_by_id))
|
||||
.routes(routes!(routes::bal::create_bal))
|
||||
.routes(routes!(routes::bal::update_bal))
|
||||
.routes(routes!(routes::bal::get_bals))
|
||||
.routes(routes!(routes::bal::get_current_bal))
|
||||
.routes(routes!(routes::bal::set_current_bal))
|
||||
// Authentication
|
||||
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
|
||||
.routes(routes!(routes::auth::auth))
|
||||
.routes(routes!(routes::auth::check_token))
|
||||
// Misc
|
||||
.routes(routes!(routes::misc::current_api_version))
|
||||
.routes(routes!(routes::websocket::ws_handler))
|
||||
|
||||
.with_state(shared_state.clone());
|
||||
|
||||
let (mut router, mut api) = OpenApiRouter::new()
|
||||
.nest("/api", open_api_router)
|
||||
.route("/", get(index)) // temporary index page, will redirect/proxy to flutter app
|
||||
.with_state(shared_state)
|
||||
.split_for_parts();
|
||||
|
||||
if serve_docs {
|
||||
api.info = InfoBuilder::new()
|
||||
.title("Alexandria")
|
||||
.description(Some("Alexandria is a server that manages books and users for Union Étudiante's book exchange"))
|
||||
.contact(Some(ContactBuilder::new()
|
||||
.url(Some("https://ueauvergne.fr"))
|
||||
.name(Some("Union Étudiante Auvergne"))
|
||||
.email(Some("unionetudianteauvergne@gmail.com"))
|
||||
.build()))
|
||||
.license(Some(LicenseBuilder::new().name("MIT").url(Some("https://spdx.org/licenses/MIT.html")).build()))
|
||||
.version("1.0.0")
|
||||
.build();
|
||||
|
||||
api.merge(ApiDoc::openapi());
|
||||
|
||||
let swagger = SwaggerUi::new("/docs/")
|
||||
.url("/docs/openapi.json", api)
|
||||
.config(Config::default()
|
||||
.try_it_out_enabled(true)
|
||||
.filter(true)
|
||||
.display_request_duration(true)
|
||||
.persist_authorization(true)
|
||||
);
|
||||
|
||||
router = router.merge(swagger);
|
||||
}
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)).await.unwrap();
|
||||
log::info!("Running on http://{}", listener.local_addr().unwrap());
|
||||
axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>()
|
||||
).await.unwrap()
|
||||
}
|
||||
223
src/main.rs
223
src/main.rs
|
|
@ -1,77 +1,8 @@
|
|||
use std::{net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, LazyLock}};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{extract::State, http::HeaderMap, middleware, routing::get};
|
||||
use clap::{Parser, Subcommand};
|
||||
use reqwest::{header::USER_AGENT};
|
||||
use sea_orm::{ConnectionTrait, Database, DatabaseConnection, DbErr, EntityTrait, PaginatorTrait, Schema};
|
||||
use tokio::{sync::broadcast::{self, Sender}};
|
||||
use utoipa::{openapi::{security::{HttpAuthScheme, HttpBuilder, SecurityScheme}, ContactBuilder, InfoBuilder, LicenseBuilder}, Modify, OpenApi};
|
||||
use utoipa_axum::router::OpenApiRouter;
|
||||
use utoipa_swagger_ui::{Config, SwaggerUi};
|
||||
use utoipa_axum::routes;
|
||||
use alexandria::{create_tables, run_server, utils, Command, CLI};
|
||||
use sea_orm::{Database, DatabaseConnection};
|
||||
|
||||
use crate::{entities::prelude::BookInstance, routes::auth::Keys, utils::events::Event};
|
||||
|
||||
pub mod entities;
|
||||
pub mod utils;
|
||||
pub mod routes;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "Alexandria")]
|
||||
#[command(version = "1.0")]
|
||||
#[command(about = "BAL management server", long_about = None)]
|
||||
struct Cli {
|
||||
/// Path to the sqlite database [default: ./alexandria.db]
|
||||
#[arg(long, short, global = true, value_name = "FILE")]
|
||||
database: Option<PathBuf>,
|
||||
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Commands {
|
||||
/// Serves the web server
|
||||
Run {
|
||||
/// Port on which to serve the web server
|
||||
#[arg(short, long, default_value_t = 3000)]
|
||||
port: u16,
|
||||
/// How many seconds generated JWTs are valid for. Default equates to 6 months
|
||||
#[arg(long, default_value_t = 15_778_476)]
|
||||
token_expiration_time: u64,
|
||||
},
|
||||
/// Open a TUI to manage user accounts
|
||||
User
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
app_name: String,
|
||||
db_conn: Arc<DatabaseConnection>,
|
||||
event_bus: Sender<Event>,
|
||||
web_client: reqwest::Client
|
||||
}
|
||||
|
||||
async fn index(
|
||||
State(state): State<Arc<AppState>>
|
||||
) ->String {
|
||||
let app_name = &state.app_name;
|
||||
let db_conn = &state.db_conn;
|
||||
let status: &str = match db_conn.ping().await {
|
||||
Ok(_) => "working",
|
||||
Err(_) => "erroring"
|
||||
};
|
||||
let book_count = BookInstance::find().count(db_conn.as_ref()).await.unwrap();
|
||||
format!("Hello from {app_name}! Database is {status}. We currently have {book_count} books in stock !")
|
||||
}
|
||||
|
||||
static KEYS: LazyLock<Keys> = LazyLock::new(|| {
|
||||
let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set");
|
||||
Keys::new(secret.as_bytes())
|
||||
});
|
||||
|
||||
static CLI: LazyLock<Cli> = LazyLock::new(|| {
|
||||
Cli::parse()
|
||||
});
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
|
@ -105,151 +36,9 @@ async fn main() {
|
|||
return;
|
||||
};
|
||||
|
||||
match &CLI.command {
|
||||
Commands::Run {port,..} => run_server(db, *port).await,
|
||||
Commands::User => utils::cli::manage_users(db).await
|
||||
match &CLI.command() {
|
||||
Command::Run {port,..} => run_server(db, *port, true).await,
|
||||
Command::User => utils::cli::manage_users(db).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_tables<C>(db_conn: &C) -> Result<(), DbErr>
|
||||
where C: ConnectionTrait,{
|
||||
let builder = db_conn.get_database_backend();
|
||||
let schema = Schema::new(builder);
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Book).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating book table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::BookInstance).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating book_instance table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Owner).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating owner table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::User).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating user table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
if let Err(err) = db_conn.execute(builder.build(schema.create_table_from_entity(crate::entities::prelude::Bal).if_not_exists())).await {
|
||||
log::error!(target: "database", "Error while creating bal table: {err:?}");
|
||||
return Err(err);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(db: Arc<DatabaseConnection>, port: u16) {
|
||||
let (event_bus, _) = broadcast::channel(16);
|
||||
|
||||
if std::env::var("JWT_SECRET").is_err() {
|
||||
log::error!("JWT_SECRET is not set");
|
||||
return;
|
||||
}
|
||||
|
||||
let mut default_headers = HeaderMap::new();
|
||||
default_headers.append(USER_AGENT, "Alexandria/1.0 (unionetudianteauvergne@gmail.com)".parse().unwrap());
|
||||
let shared_state = Arc::new(AppState {
|
||||
app_name: "Alexandria".to_string(),
|
||||
db_conn: db,
|
||||
event_bus,
|
||||
web_client: reqwest::Client::builder().default_headers(default_headers).build().expect("creating the reqwest client failed")
|
||||
});
|
||||
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
tags(
|
||||
(name = "book-api", description = "Book management endpoints.")
|
||||
),
|
||||
modifiers(&SecurityAddon)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
struct SecurityAddon;
|
||||
|
||||
impl Modify for SecurityAddon {
|
||||
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
|
||||
let components = openapi.components.as_mut().unwrap();
|
||||
components.add_security_scheme(
|
||||
"jwt",
|
||||
SecurityScheme::Http(
|
||||
HttpBuilder::new().scheme(HttpAuthScheme::Bearer).bearer_format("JWT").build()
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let open_api_router = OpenApiRouter::new()
|
||||
// Book API
|
||||
.routes(routes!(routes::book::get_book_by_ean))
|
||||
.routes(routes!(routes::book::get_book_by_id))
|
||||
.routes(routes!(routes::book::create_book))
|
||||
// Book Instance API
|
||||
.routes(routes!(routes::book_instance::get_book_instance_by_id))
|
||||
.routes(routes!(routes::book_instance::create_book_instance))
|
||||
.routes(routes!(routes::book_instance::update_book_instance))
|
||||
.routes(routes!(routes::book_instance::sell_book_instance))
|
||||
.routes(routes!(routes::book_instance::bulk_create_book_instance))
|
||||
.routes(routes!(routes::book_instance::get_bal_owner_book_instances))
|
||||
.routes(routes!(routes::book_instance::get_bal_book_instances_by_ean))
|
||||
.routes(routes!(routes::book_instance::search_bal_book_instances))
|
||||
// Owner API
|
||||
.routes(routes!(routes::owner::get_owner_by_id))
|
||||
.routes(routes!(routes::owner::create_owner))
|
||||
.routes(routes!(routes::owner::update_owner))
|
||||
.routes(routes!(routes::owner::get_owners))
|
||||
// Bal API
|
||||
.routes(routes!(routes::bal::get_bal_by_id))
|
||||
.routes(routes!(routes::bal::create_bal))
|
||||
.routes(routes!(routes::bal::update_bal))
|
||||
.routes(routes!(routes::bal::get_bals))
|
||||
.routes(routes!(routes::bal::get_current_bal))
|
||||
.routes(routes!(routes::bal::set_current_bal))
|
||||
// Authentication
|
||||
.route_layer(middleware::from_fn_with_state(shared_state.clone(), routes::auth::auth_middleware))
|
||||
.routes(routes!(routes::auth::auth))
|
||||
.routes(routes!(routes::auth::check_token))
|
||||
// Misc
|
||||
.routes(routes!(routes::misc::current_api_version))
|
||||
.routes(routes!(routes::websocket::ws_handler))
|
||||
|
||||
.with_state(shared_state.clone());
|
||||
|
||||
let (router, mut api) = OpenApiRouter::new()
|
||||
.nest("/api", open_api_router)
|
||||
.route("/", get(index)) // temporary index page, will redirect/proxy to flutter app
|
||||
.with_state(shared_state)
|
||||
.split_for_parts();
|
||||
|
||||
api.info = InfoBuilder::new()
|
||||
.title("Alexandria")
|
||||
.description(Some("Alexandria is a server that manages books and users for Union Étudiante's book exchange"))
|
||||
.contact(Some(ContactBuilder::new()
|
||||
.url(Some("https://ueauvergne.fr"))
|
||||
.name(Some("Union Étudiante Auvergne"))
|
||||
.email(Some("unionetudianteauvergne@gmail.com"))
|
||||
.build()))
|
||||
.license(Some(LicenseBuilder::new().name("MIT").url(Some("https://spdx.org/licenses/MIT.html")).build()))
|
||||
.version("1.0.0")
|
||||
.build();
|
||||
|
||||
api.merge(ApiDoc::openapi());
|
||||
|
||||
let swagger = SwaggerUi::new("/docs/")
|
||||
.url("/docs/openapi.json", api)
|
||||
.config(Config::default()
|
||||
.try_it_out_enabled(true)
|
||||
.filter(true)
|
||||
.display_request_duration(true)
|
||||
.persist_authorization(true)
|
||||
);
|
||||
|
||||
let router = router.merge(swagger);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)).await.unwrap();
|
||||
log::info!("Running on http://{}", listener.local_addr().unwrap());
|
||||
axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>()
|
||||
).await.unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{entities::user, AppState, Commands, CLI, KEYS};
|
||||
use crate::{entities::user, AppState, KEYS};
|
||||
|
||||
//const TOKEN_EXPIRY_TIME: u64 = 15_778_476; // 6 Months
|
||||
pub const DEFAULT_TOKEN_EXPIRY_TIME: u64 = 15_778_476; // 6 Months
|
||||
|
||||
pub async fn auth_middleware(
|
||||
_claims: Claims,
|
||||
|
|
@ -44,15 +44,14 @@ pub async fn auth(State(state): State<Arc<AppState>>, Json(payload): Json<AuthPa
|
|||
match user::Entity::find().filter(user::Column::Username.eq(payload.username)).one(state.db_conn.as_ref()).await {
|
||||
Err(_) | Ok(None) => return Err(AuthError::WrongCredentials),
|
||||
Ok(Some(user)) => {
|
||||
user.verify_password(payload.password);
|
||||
if !user.verify_password(payload.password) {
|
||||
return Err(AuthError::WrongCredentials);
|
||||
};
|
||||
let unix_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time should go forward").as_secs();
|
||||
|
||||
let claims = Claims {
|
||||
sub: user.username,
|
||||
exp: unix_timestamp + match CLI.command {
|
||||
Commands::Run { token_expiration_time, .. } => token_expiration_time,
|
||||
_ => panic!("The auth endpoint cannot be used outside of a Run command")
|
||||
},
|
||||
exp: unix_timestamp + DEFAULT_TOKEN_EXPIRY_TIME,
|
||||
user_id: user.id
|
||||
};
|
||||
let token = encode(&Header::default(), &claims, &KEYS.encoding)
|
||||
|
|
@ -63,6 +62,10 @@ pub async fn auth(State(state): State<Arc<AppState>>, Json(payload): Json<AuthPa
|
|||
}
|
||||
}
|
||||
|
||||
pub fn generate_token_from_claims(claims: Claims) -> String {
|
||||
encode(&Header::default(), &claims, &KEYS.encoding).unwrap()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::ToSchema)]
|
||||
pub struct TokenPayload {
|
||||
token: String
|
||||
|
|
@ -152,15 +155,15 @@ impl Keys {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
sub: String,
|
||||
exp: u64,
|
||||
pub sub: String,
|
||||
pub exp: u64,
|
||||
pub user_id: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
|
||||
pub struct AuthBody {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use argon2::{Argon2, PasswordHasher};
|
||||
use password_hash::{rand_core::OsRng, SaltString};
|
||||
use sea_orm::ConnectionTrait;
|
||||
|
||||
use crate::entities::prelude::*;
|
||||
|
|
@ -40,3 +42,8 @@ where C: ConnectionTrait,
|
|||
None => false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hash_password(password: String) -> String {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
Argon2::default().hash_password(&password.clone().into_bytes(), &salt).unwrap().to_string()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
use std::{fmt::Display, sync::Arc};
|
||||
|
||||
use argon2::{password_hash::{SaltString}, Argon2, PasswordHasher};
|
||||
use inquire::{min_length, prompt_text, Confirm, Password, Select, Text};
|
||||
use password_hash::rand_core::OsRng;
|
||||
use sea_orm::{ActiveModelTrait, ActiveValue::{NotSet, Set}, ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter};
|
||||
|
||||
use crate::entities::{owner, prelude::User, user::{self, ActiveModel}};
|
||||
use crate::{entities::{owner, prelude::User, user::{self, ActiveModel}}, utils::auth::hash_password};
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
enum Action {
|
||||
|
|
@ -146,11 +144,6 @@ pub async fn manage_users(db: Arc<DatabaseConnection>) {
|
|||
}
|
||||
}
|
||||
|
||||
fn hash_password(password: String) -> String {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
Argon2::default().hash_password(&password.clone().into_bytes(), &salt).unwrap().to_string()
|
||||
}
|
||||
|
||||
async fn select_user(db: Arc<DatabaseConnection>) -> Option<user::Model> {
|
||||
let users = User::find().all(db.as_ref()).await.unwrap();
|
||||
if users.is_empty() {
|
||||
|
|
|
|||
61
tests/auth.rs
Normal file
61
tests/auth.rs
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use alexandria::routes::auth::AuthBody;
|
||||
use reqwest::StatusCode;
|
||||
|
||||
mod common;
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_wrong_password() {
|
||||
let data = common::setup().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let mut wrong_pwd_auth_map = HashMap::new();
|
||||
wrong_pwd_auth_map.insert("username", "test_username");
|
||||
wrong_pwd_auth_map.insert("password", "pwd");
|
||||
let wrong_pwd_auth_res = client.execute(client.post(format!("{}/auth", data.api_path)).json(&wrong_pwd_auth_map).build().unwrap()).await.unwrap();
|
||||
assert_eq!(wrong_pwd_auth_res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_wrong_username() {
|
||||
let data = common::setup().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let mut wrong_username_auth_map = HashMap::new();
|
||||
wrong_username_auth_map.insert("username", "wrong_username");
|
||||
wrong_username_auth_map.insert("password", "test_password");
|
||||
let wrong_username_auth_res = client.execute(client.post(format!("{}/auth", data.api_path)).json(&wrong_username_auth_map).build().unwrap()).await.unwrap();
|
||||
assert_eq!(wrong_username_auth_res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_correct_credentials() {
|
||||
let data = common::setup().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let mut auth_map = HashMap::new();
|
||||
auth_map.insert("username", "test_username");
|
||||
auth_map.insert("password", "test_password");
|
||||
let auth_res = client.execute(client.post(format!("{}/auth", data.api_path)).json(&auth_map).build().unwrap()).await.unwrap();
|
||||
assert_eq!(auth_res.status(), StatusCode::OK);
|
||||
let auth_body = auth_res.json::<AuthBody>().await.unwrap();
|
||||
|
||||
let mut check_token_map = HashMap::new();
|
||||
check_token_map.insert("token", auth_body.access_token);
|
||||
let check_token_res = client.execute(client.post(format!("{}/token-check", data.api_path)).json(&check_token_map).build().unwrap()).await.unwrap();
|
||||
let valid_token = check_token_res.json::<bool>().await.unwrap();
|
||||
assert_eq!(valid_token, true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_wrong_token_check() {
|
||||
let data = common::setup().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let mut check_wrong_token_map = HashMap::new();
|
||||
check_wrong_token_map.insert("token", "this-is-definitely-not-a-wrong-token");
|
||||
let check_wrong_token_res = client.execute(client.post(format!("{}/token-check", data.api_path)).json(&check_wrong_token_map).build().unwrap()).await.unwrap();
|
||||
let invalid_token = check_wrong_token_res.json::<bool>().await.unwrap();
|
||||
assert_eq!(invalid_token, false);
|
||||
}
|
||||
86
tests/common/mod.rs
Normal file
86
tests/common/mod.rs
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
use std::{net::{Ipv4Addr, SocketAddrV4, TcpListener}, sync::Arc, time::{SystemTime, UNIX_EPOCH}};
|
||||
|
||||
use sea_orm::{ActiveModelTrait, ActiveValue::{NotSet, Set}, ColumnTrait, ConnectionTrait, Database, DatabaseConnection, EntityTrait, QueryFilter};
|
||||
|
||||
use alexandria::{create_tables, entities::{owner, prelude::*, user}, routes::auth::{generate_token_from_claims, Claims, DEFAULT_TOKEN_EXPIRY_TIME}, run_server, utils::auth::hash_password};
|
||||
|
||||
pub struct SetupData {
|
||||
/// A valid JWT for testing features that need authentication
|
||||
pub jwt: String,
|
||||
pub api_path: String
|
||||
}
|
||||
|
||||
/// Common setup function for tests that require a database and server setup
|
||||
pub async fn setup() -> SetupData {
|
||||
let _ = pretty_env_logger::try_init();
|
||||
|
||||
let db: Arc<DatabaseConnection> = Arc::new(
|
||||
match Database::connect(format!("sqlite::memory:?mode=rwc")).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
panic!("Error while opening fatabase: {}", e.to_string())
|
||||
}
|
||||
});
|
||||
|
||||
create_tables(db.as_ref()).await.expect("Create tables should not fail");
|
||||
|
||||
let port = free_local_ipv4_port().expect("Could not get a free port");
|
||||
let db_c = db.clone();
|
||||
tokio::spawn(async move {
|
||||
run_server(db_c, port, false).await;
|
||||
});
|
||||
|
||||
create_user(db.as_ref(), "test_username", "test_password").await;
|
||||
|
||||
|
||||
let unix_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time should go forward").as_secs();
|
||||
|
||||
let claims = Claims {
|
||||
sub: "test_username".to_string(),
|
||||
exp: unix_timestamp + DEFAULT_TOKEN_EXPIRY_TIME,
|
||||
user_id: 1
|
||||
};
|
||||
|
||||
SetupData {
|
||||
jwt: generate_token_from_claims(claims),
|
||||
api_path: format!("http://0.0.0.0:{port}/api")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn free_local_ipv4_port() -> Option<u16> {
|
||||
let socket = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
|
||||
TcpListener::bind(socket)
|
||||
.and_then(|listener| listener.local_addr())
|
||||
.map(|addr| addr.port())
|
||||
.ok()
|
||||
}
|
||||
|
||||
async fn create_user<C>(db_conn: &C, username_t: impl ToString, password_t: impl ToString)
|
||||
where C: ConnectionTrait {
|
||||
let username = username_t.to_string();
|
||||
if User::find().filter(user::Column::Username.eq(username.clone())).one(db_conn).await.is_ok_and(|r| r.is_some()) {
|
||||
panic!("Username {username} already in use");
|
||||
} else {
|
||||
let password = password_t.to_string();
|
||||
let mut new_user = user::ActiveModel {
|
||||
id: NotSet,
|
||||
username: Set(username.clone()),
|
||||
hashed_password: Set(hash_password(password)),
|
||||
current_bal_id: Set(None),
|
||||
owner_id: Set(None)
|
||||
};
|
||||
let res = new_user.clone().insert(db_conn).await.unwrap();
|
||||
|
||||
let new_owner = owner::ActiveModel {
|
||||
id: NotSet,
|
||||
user_id: Set(res.id),
|
||||
first_name: Set(format!("{username} first name")),
|
||||
last_name: Set(format!("{username} last name")),
|
||||
contact: Set(format!("{username}@mail.com"))
|
||||
};
|
||||
let owner_res = new_owner.insert(db_conn).await.unwrap();
|
||||
new_user.owner_id = Set(Some(owner_res.id));
|
||||
let _ = new_user.update(db_conn);
|
||||
}
|
||||
}
|
||||
Reference in a new issue