Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use argon2::{Algorithm, Argon2, Params, Version};
use axum::{Json, Router, extract::State, routing::post};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::RwLock;
use tokio_postgres::{Config, config::SslMode};

use crate::app::auth::{AuthState, Token, router};
Expand All @@ -16,7 +16,7 @@ mod auth;
#[derive(Debug, Clone)]
struct AppState {
value: Arc<RwLock<f64>>,
auth: Arc<Mutex<AuthState>>,
auth: Arc<AuthState>,
hasher: Argon2<'static>,
}
impl AppState {
Expand All @@ -40,12 +40,12 @@ impl AppState {

Self {
value: Arc::new(RwLock::new(0.0)),
auth: Arc::new(Mutex::new(AuthState::new(&cfg).await)),
auth: Arc::new(AuthState::new(&cfg).await),
hasher: Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::DEFAULT),
}
}
async fn validate_token(&self, token: &Token) -> bool {
self.auth.lock().await.tokens.contains_key(token)
self.auth.tokens.lock().await.contains_key(token)
}
}

Expand Down
58 changes: 22 additions & 36 deletions src/app/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,29 @@ use axum::{Json, Router, extract::State, routing::post};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_with::{base64::Base64, serde_as};
use tokio_postgres::{Client, Config, NoTls};
use tokio::sync::Mutex;
use tokio_postgres::Config;
use uuid::Uuid;

use crate::app::auth::db::Database;

use super::AppState;

mod db;

#[serde_as]
#[derive(Serialize, Deserialize, Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Token(#[serde_as(as = "Base64")] pub [u8; 32]);

#[derive(Debug)]
pub struct AuthState {
db: Client,
pub tokens: HashMap<Token, Uuid>,
db: Database,
pub tokens: Mutex<HashMap<Token, Uuid>>,
}
impl AuthState {
pub async fn new(cfg: &Config) -> Self {
let (client, connection) = cfg.connect(NoTls).await.unwrap();
tokio::spawn(async {
connection.await.unwrap(); // run the connection on a bg task
});
Self {
db: client,
db: Database::new(cfg).await,
tokens: Default::default(),
}
}
Expand All @@ -56,24 +57,10 @@ async fn register(
let salt = SaltString::generate(&mut OsRng);
let hash = state.hasher.hash_password(&request.secret, &salt).unwrap();

let lock = state.auth.lock().await;
let res = lock
let res = state
.auth
.db
.execute(
"WITH new_user AS (
INSERT INTO public.users (password_hash, email)
VALUES ($2, $3)
RETURNING id
)
INSERT INTO public.user_accounts (id, username)
SELECT id, $1
FROM new_user;",
&[
&request.username,
&hash.serialize().as_str(),
&request.email,
],
)
.create_user(&request.username, &request.email, hash.serialize())
.await;
let res = res.map_err(|x| Json(Err(x.to_string())));
if let Err(e) = res {
Expand All @@ -94,16 +81,7 @@ async fn login(
State(state): State<AppState>,
Json(request): Json<LoginRequest>,
) -> Json<Result<Token, String>> {
let mut lock = state.auth.lock().await;
let res = lock
.db
.query_one(
"SELECT id, password_hash
FROM public.users
WHERE email = $1;",
&[&request.email],
)
.await;
let res = state.auth.db.get_user(&request.email).await;
let res = res.map_err(|x| Json(Err(x.to_string())));
let row = match res {
Ok(row) => row,
Expand All @@ -121,7 +99,15 @@ async fn login(
.is_ok()
{
let token = Token(rand::rng().random::<[u8; 32]>());
assert!(lock.tokens.insert(token, row.get("id")).is_none()); // we should never see a token collision
assert!(
state
.auth
.tokens
.lock()
.await
.insert(token, row.get("id"))
.is_none()
); // we should never see a token collision
Json(Ok(token))
} else {
Json(Err("INVALID_CREDENTIALS".to_string()))
Expand Down
49 changes: 49 additions & 0 deletions src/app/auth/db.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use argon2::password_hash::PasswordHashString;
use tokio_postgres::{Client, Config, Error, NoTls, Row, Statement};

#[derive(Debug)]
pub struct Database {
client: Client,
create_user_statement: Statement,
get_user_statement: Statement,
}
impl Database {
pub async fn new(cfg: &Config) -> Self {
let (client, connection) = cfg.connect(NoTls).await.unwrap();
tokio::spawn(async {
connection.await.unwrap(); // run the connection on a bg task
});
let create_user_statement = client
.prepare(include_str!("sql/create_user.sql"))
.await
.unwrap();
let get_user_statement = client
.prepare(include_str!("sql/get_user.sql"))
.await
.unwrap();
Self {
client,
create_user_statement,
get_user_statement,
}
}
pub async fn create_user(
&self,
username: &str,
email: &str,
password_hash: PasswordHashString,
) -> Result<(), Error> {
self.client
.execute(
&self.create_user_statement,
&[&username, &email, &password_hash.as_str()],
)
.await
.map(|_| ())
}
pub async fn get_user(&self, email: &str) -> Result<Row, Error> {
self.client
.query_one(&self.get_user_statement, &[&email])
.await
}
}
8 changes: 8 additions & 0 deletions src/app/auth/sql/create_user.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
WITH new_user AS (
INSERT INTO public.users (email, password_hash)
VALUES ($2, $3)
RETURNING id
)
INSERT INTO public.user_accounts (id, username)
SELECT id, $1
FROM new_user;
3 changes: 3 additions & 0 deletions src/app/auth/sql/get_user.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT id, password_hash
FROM public.users
WHERE email = $1;
Loading