From 64082144be70c0debde8599b1e1fbe0309043882 Mon Sep 17 00:00:00 2001 From: Speedy_Lex <78314533+speedy-lex@users.noreply.github.com> Date: Fri, 19 Sep 2025 00:05:35 +0200 Subject: [PATCH 1/2] move mutex from Authstate to tokenstore --- src/app.rs | 8 ++++---- src/app/auth.rs | 21 +++++++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/app.rs b/src/app.rs index f6ba671..cad9da4 100644 --- a/src/app.rs +++ b/src/app.rs @@ -6,7 +6,7 @@ use std::{ use argon2::{Algorithm, Argon2, Params, Version}; use axum::{Json, Router, extract::State, routing::post}; use serde::{Deserialize, Serialize}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::RwLock; use tokio_postgres::{Config, config::SslMode}; use crate::app::auth::{AuthState, Token, router}; @@ -16,7 +16,7 @@ mod auth; #[derive(Debug, Clone)] struct AppState { value: Arc>, - auth: Arc>, + auth: Arc, hasher: Argon2<'static>, } impl AppState { @@ -40,12 +40,12 @@ impl AppState { Self { value: Arc::new(RwLock::new(0.0)), - auth: Arc::new(Mutex::new(AuthState::new(&cfg).await)), + auth: Arc::new(AuthState::new(&cfg).await), hasher: Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::DEFAULT), } } async fn validate_token(&self, token: &Token) -> bool { - self.auth.lock().await.tokens.contains_key(token) + self.auth.tokens.lock().await.contains_key(token) } } diff --git a/src/app/auth.rs b/src/app/auth.rs index 6541204..c4d0f0b 100644 --- a/src/app/auth.rs +++ b/src/app/auth.rs @@ -8,6 +8,7 @@ use axum::{Json, Router, extract::State, routing::post}; use rand::Rng; use serde::{Deserialize, Serialize}; use serde_with::{base64::Base64, serde_as}; +use tokio::sync::Mutex; use tokio_postgres::{Client, Config, NoTls}; use uuid::Uuid; @@ -20,7 +21,7 @@ pub struct Token(#[serde_as(as = "Base64")] pub [u8; 32]); #[derive(Debug)] pub struct AuthState { db: Client, - pub tokens: HashMap, + pub tokens: Mutex>, } impl AuthState { pub async fn new(cfg: &Config) -> Self { @@ -56,8 +57,8 @@ async fn register( let salt = SaltString::generate(&mut OsRng); let hash = state.hasher.hash_password(&request.secret, &salt).unwrap(); - let lock = state.auth.lock().await; - let res = lock + let res = state + .auth .db .execute( "WITH new_user AS ( @@ -94,8 +95,8 @@ async fn login( State(state): State, Json(request): Json, ) -> Json> { - let mut lock = state.auth.lock().await; - let res = lock + let res = state + .auth .db .query_one( "SELECT id, password_hash @@ -121,7 +122,15 @@ async fn login( .is_ok() { let token = Token(rand::rng().random::<[u8; 32]>()); - assert!(lock.tokens.insert(token, row.get("id")).is_none()); // we should never see a token collision + assert!( + state + .auth + .tokens + .lock() + .await + .insert(token, row.get("id")) + .is_none() + ); // we should never see a token collision Json(Ok(token)) } else { Json(Err("INVALID_CREDENTIALS".to_string())) From 640fc0decfe2de8aa0f00058dabc21103944b69a Mon Sep 17 00:00:00 2001 From: Speedy_Lex <78314533+speedy-lex@users.noreply.github.com> Date: Fri, 19 Sep 2025 00:23:23 +0200 Subject: [PATCH 2/2] prepare sql queries in advance --- src/app/auth.rs | 41 ++++++-------------------- src/app/auth/db.rs | 49 ++++++++++++++++++++++++++++++++ src/app/auth/sql/create_user.sql | 8 ++++++ src/app/auth/sql/get_user.sql | 3 ++ 4 files changed, 69 insertions(+), 32 deletions(-) create mode 100644 src/app/auth/db.rs create mode 100644 src/app/auth/sql/create_user.sql create mode 100644 src/app/auth/sql/get_user.sql diff --git a/src/app/auth.rs b/src/app/auth.rs index c4d0f0b..e247d5e 100644 --- a/src/app/auth.rs +++ b/src/app/auth.rs @@ -9,28 +9,28 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use serde_with::{base64::Base64, serde_as}; use tokio::sync::Mutex; -use tokio_postgres::{Client, Config, NoTls}; +use tokio_postgres::Config; use uuid::Uuid; +use crate::app::auth::db::Database; + use super::AppState; +mod db; + #[serde_as] #[derive(Serialize, Deserialize, Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] pub struct Token(#[serde_as(as = "Base64")] pub [u8; 32]); #[derive(Debug)] pub struct AuthState { - db: Client, + db: Database, pub tokens: Mutex>, } impl AuthState { pub async fn new(cfg: &Config) -> Self { - let (client, connection) = cfg.connect(NoTls).await.unwrap(); - tokio::spawn(async { - connection.await.unwrap(); // run the connection on a bg task - }); Self { - db: client, + db: Database::new(cfg).await, tokens: Default::default(), } } @@ -60,21 +60,7 @@ async fn register( let res = state .auth .db - .execute( - "WITH new_user AS ( - INSERT INTO public.users (password_hash, email) - VALUES ($2, $3) - RETURNING id - ) - INSERT INTO public.user_accounts (id, username) - SELECT id, $1 - FROM new_user;", - &[ - &request.username, - &hash.serialize().as_str(), - &request.email, - ], - ) + .create_user(&request.username, &request.email, hash.serialize()) .await; let res = res.map_err(|x| Json(Err(x.to_string()))); if let Err(e) = res { @@ -95,16 +81,7 @@ async fn login( State(state): State, Json(request): Json, ) -> Json> { - let res = state - .auth - .db - .query_one( - "SELECT id, password_hash - FROM public.users - WHERE email = $1;", - &[&request.email], - ) - .await; + let res = state.auth.db.get_user(&request.email).await; let res = res.map_err(|x| Json(Err(x.to_string()))); let row = match res { Ok(row) => row, diff --git a/src/app/auth/db.rs b/src/app/auth/db.rs new file mode 100644 index 0000000..ac6a28f --- /dev/null +++ b/src/app/auth/db.rs @@ -0,0 +1,49 @@ +use argon2::password_hash::PasswordHashString; +use tokio_postgres::{Client, Config, Error, NoTls, Row, Statement}; + +#[derive(Debug)] +pub struct Database { + client: Client, + create_user_statement: Statement, + get_user_statement: Statement, +} +impl Database { + pub async fn new(cfg: &Config) -> Self { + let (client, connection) = cfg.connect(NoTls).await.unwrap(); + tokio::spawn(async { + connection.await.unwrap(); // run the connection on a bg task + }); + let create_user_statement = client + .prepare(include_str!("sql/create_user.sql")) + .await + .unwrap(); + let get_user_statement = client + .prepare(include_str!("sql/get_user.sql")) + .await + .unwrap(); + Self { + client, + create_user_statement, + get_user_statement, + } + } + pub async fn create_user( + &self, + username: &str, + email: &str, + password_hash: PasswordHashString, + ) -> Result<(), Error> { + self.client + .execute( + &self.create_user_statement, + &[&username, &email, &password_hash.as_str()], + ) + .await + .map(|_| ()) + } + pub async fn get_user(&self, email: &str) -> Result { + self.client + .query_one(&self.get_user_statement, &[&email]) + .await + } +} diff --git a/src/app/auth/sql/create_user.sql b/src/app/auth/sql/create_user.sql new file mode 100644 index 0000000..9505dff --- /dev/null +++ b/src/app/auth/sql/create_user.sql @@ -0,0 +1,8 @@ +WITH new_user AS ( + INSERT INTO public.users (email, password_hash) + VALUES ($2, $3) + RETURNING id +) +INSERT INTO public.user_accounts (id, username) +SELECT id, $1 +FROM new_user; diff --git a/src/app/auth/sql/get_user.sql b/src/app/auth/sql/get_user.sql new file mode 100644 index 0000000..2823541 --- /dev/null +++ b/src/app/auth/sql/get_user.sql @@ -0,0 +1,3 @@ +SELECT id, password_hash +FROM public.users +WHERE email = $1;