Compare commits

..

2 commits

2 changed files with 34 additions and 15 deletions

View file

@ -1,4 +1,4 @@
use std::{net::SocketAddr, path::PathBuf, sync::{Arc, LazyLock}}; use std::{net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, LazyLock}};
use axum::{extract::State, http::HeaderMap, middleware, routing::get}; use axum::{extract::State, http::HeaderMap, middleware, routing::get};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
@ -21,7 +21,8 @@ pub mod routes;
#[command(version = "1.0")] #[command(version = "1.0")]
#[command(about = "BAL management server", long_about = None)] #[command(about = "BAL management server", long_about = None)]
struct Cli { struct Cli {
#[arg(long, short, value_name = "FILE")] /// Path to the sqlite database [default: ./alexandria.db]
#[arg(long, short, global = true, value_name = "FILE")]
database: Option<PathBuf>, database: Option<PathBuf>,
#[command(subcommand)] #[command(subcommand)]
@ -30,7 +31,16 @@ struct Cli {
#[derive(Subcommand)] #[derive(Subcommand)]
enum Commands { enum Commands {
Run, /// Serves the web server
Run {
/// Port on which to serve the web server
#[arg(short, long, default_value_t = 3000)]
port: u16,
/// How many seconds generated JWTs are valid for. Default equates to 6 months
#[arg(long, default_value_t = 15_778_476)]
token_expiration_time: u64,
},
/// Open a TUI to manage user accounts
User User
} }
@ -59,13 +69,15 @@ static KEYS: LazyLock<Keys> = LazyLock::new(|| {
Keys::new(secret.as_bytes()) Keys::new(secret.as_bytes())
}); });
static CLI: LazyLock<Cli> = LazyLock::new(|| {
Cli::parse()
});
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
pretty_env_logger::init(); pretty_env_logger::init();
let cli = Cli::parse(); let db_path = match &CLI.database {
let db_path = match cli.database {
Some(path) => { Some(path) => {
if path.is_dir() { if path.is_dir() {
log::error!("{path:?} is a directory"); log::error!("{path:?} is a directory");
@ -112,13 +124,13 @@ async fn main() {
return; return;
} }
match cli.command { match &CLI.command {
Commands::Run => run_server(db).await, Commands::Run {port,..} => run_server(db, *port).await,
Commands::User => utils::cli::manage_users(db).await Commands::User => utils::cli::manage_users(db).await
} }
} }
async fn run_server(db: Arc<DatabaseConnection>) { async fn run_server(db: Arc<DatabaseConnection>, port: u16) {
let (event_bus, _) = broadcast::channel(16); let (event_bus, _) = broadcast::channel(16);
if std::env::var("JWT_SECRET").is_err() { if std::env::var("JWT_SECRET").is_err() {
@ -159,7 +171,7 @@ async fn run_server(db: Arc<DatabaseConnection>) {
} }
} }
let (router, mut api) = OpenApiRouter::new() let open_api_router = OpenApiRouter::new()
// Book API // Book API
.routes(routes!(routes::book::get_book_by_ean)) .routes(routes!(routes::book::get_book_by_ean))
.routes(routes!(routes::book::get_book_by_id)) .routes(routes!(routes::book::get_book_by_id))
@ -188,8 +200,12 @@ async fn run_server(db: Arc<DatabaseConnection>) {
.routes(routes!(routes::auth::check_token)) .routes(routes!(routes::auth::check_token))
// Misc // Misc
.routes(routes!(routes::websocket::ws_handler)) .routes(routes!(routes::websocket::ws_handler))
.route("/", get(index))
.with_state(shared_state.clone());
let (router, mut api) = OpenApiRouter::new()
.nest("/api", open_api_router)
.route("/", get(index)) // temporary index page, will redirect/proxy to flutter app
.with_state(shared_state) .with_state(shared_state)
.split_for_parts(); .split_for_parts();
@ -217,7 +233,7 @@ async fn run_server(db: Arc<DatabaseConnection>) {
let router = router.merge(swagger); let router = router.merge(swagger);
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); let listener = tokio::net::TcpListener::bind(SocketAddr::new(std::net::IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)).await.unwrap();
log::info!("Running on http://{}", listener.local_addr().unwrap()); log::info!("Running on http://{}", listener.local_addr().unwrap());
axum::serve( axum::serve(
listener, listener,

View file

@ -7,9 +7,9 @@ use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use crate::{entities::user, AppState, KEYS}; use crate::{entities::user, AppState, Commands, CLI, KEYS};
const TOKEN_EXPIRY_TIME: u64 = 15_778_476; // 6 Months //const TOKEN_EXPIRY_TIME: u64 = 15_778_476; // 6 Months
pub async fn auth_middleware( pub async fn auth_middleware(
_claims: Claims, _claims: Claims,
@ -49,7 +49,10 @@ pub async fn auth(State(state): State<Arc<AppState>>, Json(payload): Json<AuthPa
let claims = Claims { let claims = Claims {
sub: user.username, sub: user.username,
exp: unix_timestamp + TOKEN_EXPIRY_TIME, exp: unix_timestamp + match CLI.command {
Commands::Run { token_expiration_time, .. } => token_expiration_time,
_ => panic!("The auth endpoint cannot be used outside of a Run command")
},
user_id: user.id user_id: user.id
}; };
let token = encode(&Header::default(), &claims, &KEYS.encoding) let token = encode(&Header::default(), &claims, &KEYS.encoding)