Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ tracing = "0.1"
tracing-subscriber = "0.3"
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "uuid", "time"] }
uuid = { version = "1", features = ["serde", "v4"] }
oauth2 = "4.4"
oauth2 = "5.0"
tower-sessions = "0.15"
axum-extra = { version = "0.12", features = ["cookie"] }
dotenvy = "0.15" # To load client secrets
reqwest = { version = "0.11", features = ["json", "rustls-tls"] }
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
time = "0.3.46"
openssl = { version = "0.10", features = ["vendored"] }
openssl-sys = { version = "0.9", features = ["vendored"] }
Expand Down
28 changes: 17 additions & 11 deletions server/src/handlers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use axum::{
};
use oauth2::{
basic::BasicClient, AuthUrl, ClientId, ClientSecret, CsrfToken,
RedirectUrl, Scope, TokenUrl, TokenResponse,
RedirectUrl, Scope, TokenUrl, TokenResponse, EndpointSet, EndpointNotSet,
};
use serde::Deserialize;
use tower_sessions::Session;
Expand All @@ -17,6 +17,9 @@ use crate::models::User;
pub const AUTH_URL: &str = "https://github.com/login/oauth/authorize";
pub const TOKEN_URL: &str = "https://github.com/login/oauth/access_token";

// Type alias for a fully configured OAuth client with both auth and token endpoints set
type ConfiguredClient = BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;

// Routes for Auth
pub fn routes() -> Router<AppState> {
Router::new()
Expand Down Expand Up @@ -50,17 +53,22 @@ async fn github_callback(
// FIX: Handle config errors gracefully
let client = make_client(&state).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;

// Create a stateful HTTP client with no redirects (for SSRF protection)
let http_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("HTTP Client Error: {}", e)))?;

// 1. Exchange Code
let token = client
.exchange_code(oauth2::AuthorizationCode::new(query.code))
.request_async(oauth2::reqwest::async_http_client)
.request_async(&http_client)
.await
.map_err(|e| {
(StatusCode::INTERNAL_SERVER_ERROR, format!("Token Error: {}", e))
})?;

// 2. Fetch Profile
let http_client = reqwest::Client::new();
let user_data: User = http_client
.get("https://api.github.com/user")
.header("User-Agent", "TryCli Studio")
Expand Down Expand Up @@ -99,7 +107,7 @@ async fn get_me(session: Session) -> Result<impl IntoResponse, (StatusCode, Stri
}

// 4. Helper to create OAuth client (Now returns Result)
fn make_client(state: &AppState) -> Result<BasicClient, String> {
fn make_client(state: &AppState) -> Result<ConfiguredClient, String> {
let auth_url = AuthUrl::new(AUTH_URL.to_string())
.map_err(|e| format!("Invalid Auth URL: {}", e))?;

Expand All @@ -112,13 +120,11 @@ fn make_client(state: &AppState) -> Result<BasicClient, String> {
let redirect_url = RedirectUrl::new(format!("{}/auth/callback", api_url))
.map_err(|e| format!("Invalid Redirect URL: {}", e))?;

Ok(BasicClient::new(
ClientId::new(state.github_id.clone()),
Some(ClientSecret::new(state.github_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(redirect_url))
Ok(BasicClient::new(ClientId::new(state.github_id.clone()))
.set_client_secret(ClientSecret::new(state.github_secret.clone()))
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url))
}

async fn logout(session: Session) -> Result<impl IntoResponse, (StatusCode, String)> {
Expand Down