Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion proto
Submodule proto updated 1 files
+13 −1 core/proxy.proto
101 changes: 101 additions & 0 deletions src/enterprise/handlers/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use axum::{extract::State, Json};
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use tracing::{debug, error, info, warn};

use crate::{
enterprise::handlers::openid_login::{
AuthenticationResponse, FlowType, CSRF_COOKIE_NAME, NONCE_COOKIE_NAME,
},
error::ApiError,
handlers::get_core_response,
http::AppState,
proto::{core_request, core_response, ClientMfaOidcAuthenticateRequest, DeviceInfo},
};

#[instrument(level = "debug", skip(state))]
pub(super) async fn mfa_auth_callback(
State(state): State<AppState>,
device_info: DeviceInfo,
mut private_cookies: PrivateCookieJar,
Json(payload): Json<AuthenticationResponse>,
) -> Result<PrivateCookieJar, ApiError> {
info!("Processing MFA authentication callback");
debug!(
"Received payload: state={}, flow_type={}",
payload.state, payload.flow_type
);

let flow_type = payload.flow_type.parse::<FlowType>().map_err(|err| {
warn!("Failed to parse flow type '{}': {err:?}", payload.flow_type);
ApiError::BadRequest("Invalid flow type".into())
})?;

if flow_type != FlowType::Mfa {
warn!("Invalid flow type for MFA callback: {flow_type:?}");
return Err(ApiError::BadRequest(
"Invalid flow type for MFA callback".into(),
));
}

debug!("Flow type validation passed: {flow_type:?}");

let nonce = private_cookies
.get(NONCE_COOKIE_NAME)
.ok_or_else(|| {
warn!("Nonce cookie not found in request");
ApiError::Unauthorized("Nonce cookie not found".into())
})?
.value_trimmed()
.to_string();

let csrf = private_cookies
.get(CSRF_COOKIE_NAME)
.ok_or_else(|| {
warn!("CSRF cookie not found in request");
ApiError::Unauthorized("CSRF cookie not found".into())
})?
.value_trimmed()
.to_string();

debug!("Retrieved cookies successfully");

if payload.state != csrf {
warn!(
"CSRF token mismatch: expected={csrf}, received={}",
payload.state
);
return Err(ApiError::Unauthorized("CSRF token mismatch".into()));
}

debug!("CSRF token validation passed");

private_cookies = private_cookies
.remove(Cookie::from(NONCE_COOKIE_NAME))
.remove(Cookie::from(CSRF_COOKIE_NAME));

debug!("Removed security cookies");

let request = ClientMfaOidcAuthenticateRequest {
code: payload.code,
nonce,
callback_url: state.callback_url(flow_type).to_string(),
state: payload.state,
};

debug!("Sending MFA OIDC authenticate request to core service");

let rx = state.grpc_server.send(
core_request::Payload::ClientMfaOidcAuthenticate(request),
device_info,
)?;

let payload = get_core_response(rx).await?;

if let core_response::Payload::Empty(()) = payload {
info!("MFA authentication callback completed successfully");
Ok(private_cookies)
} else {
error!("Received invalid gRPC response type during handling the MFA OpenID authentication callback: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
1 change: 1 addition & 0 deletions src/enterprise/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod desktop_client_mfa;
pub mod openid_login;
70 changes: 57 additions & 13 deletions src/enterprise/handlers/openid_login.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use axum::{
extract::State,
routing::{get, post},
Json, Router,
};
use axum::{extract::State, routing::post, Json, Router};
use axum_extra::extract::{
cookie::{Cookie, SameSite},
PrivateCookieJar,
Expand All @@ -11,6 +7,7 @@ use serde::{Deserialize, Serialize};
use time::Duration;

use crate::{
enterprise::handlers::desktop_client_mfa::mfa_auth_callback,
error::ApiError,
handlers::get_core_response,
http::AppState,
Expand All @@ -21,13 +18,14 @@ use crate::{
};

const COOKIE_MAX_AGE: Duration = Duration::days(1);
static CSRF_COOKIE_NAME: &str = "csrf_proxy";
static NONCE_COOKIE_NAME: &str = "nonce_proxy";
pub(super) static CSRF_COOKIE_NAME: &str = "csrf_proxy";
pub(super) static NONCE_COOKIE_NAME: &str = "nonce_proxy";

pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/auth_info", get(auth_info))
.route("/auth_info", post(auth_info))
.route("/callback", post(auth_callback))
.route("/callback/mfa", post(mfa_auth_callback))
}

#[derive(Serialize)]
Expand All @@ -46,17 +44,49 @@ impl AuthInfo {
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub(crate) enum FlowType {
Enrollment,
Mfa,
}

impl std::str::FromStr for FlowType {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"enrollment" => Ok(FlowType::Enrollment),
"mfa" => Ok(FlowType::Mfa),
_ => Err(()),
}
}
}

#[derive(Deserialize, Debug)]
struct RequestData {
state: Option<String>,
#[serde(rename = "type")]
flow_type: String,
}

/// Request external OAuth2/OpenID provider details from Defguard Core.
#[instrument(level = "debug", skip(state))]
async fn auth_info(
State(state): State<AppState>,
device_info: DeviceInfo,
private_cookies: PrivateCookieJar,
Json(request_data): Json<RequestData>,
) -> Result<(PrivateCookieJar, Json<AuthInfo>), ApiError> {
debug!("Getting auth info for OAuth2/OpenID login");

let flow_type = request_data
.flow_type
.parse::<FlowType>()
.map_err(|_| ApiError::BadRequest("Invalid flow type".into()))?;

let request = AuthInfoRequest {
redirect_url: state.callback_url().to_string(),
redirect_url: state.callback_url(flow_type).to_string(),
state: request_data.state,
};

let rx = state
Expand Down Expand Up @@ -93,9 +123,11 @@ async fn auth_info(
}

#[derive(Debug, Deserialize)]
pub struct AuthenticationResponse {
code: String,
state: String,
pub(super) struct AuthenticationResponse {
pub(super) code: String,
pub(super) state: String,
#[serde(rename = "type")]
pub(super) flow_type: String,
}

#[derive(Serialize)]
Expand All @@ -111,6 +143,17 @@ async fn auth_callback(
mut private_cookies: PrivateCookieJar,
Json(payload): Json<AuthenticationResponse>,
) -> Result<(PrivateCookieJar, Json<CallbackResponseData>), ApiError> {
let flow_type = payload
.flow_type
.parse::<FlowType>()
.map_err(|_| ApiError::BadRequest("Invalid flow type".into()))?;

if flow_type != FlowType::Enrollment {
return Err(ApiError::BadRequest(
"Invalid flow type for OpenID enrollment callback".into(),
));
}

let nonce = private_cookies
.get(NONCE_COOKIE_NAME)
.ok_or(ApiError::Unauthorized("Nonce cookie not found".into()))?
Expand All @@ -133,13 +176,14 @@ async fn auth_callback(
let request = AuthCallbackRequest {
code: payload.code,
nonce,
callback_url: state.callback_url().to_string(),
callback_url: state.callback_url(flow_type).to_string(),
};

let rx = state
.grpc_server
.send(core_request::Payload::AuthCallback(request), device_info)?;
let payload = get_core_response(rx).await?;

if let core_response::Payload::AuthCallback(AuthCallbackResponse { url, token }) = payload {
debug!("Received auth callback response {url:?} {token:?}");
Ok((private_cookies, Json(CallbackResponseData { url, token })))
Expand Down
6 changes: 5 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub enum ApiError {
PermissionDenied(String),
#[error("Enterprise not enabled")]
EnterpriseNotEnabled,
#[error("Precondition required: {0}")]
PreconditionRequired(String),
}

impl IntoResponse for ApiError {
Expand All @@ -39,6 +41,7 @@ impl IntoResponse for ApiError {
StatusCode::PAYMENT_REQUIRED,
"Enterprise features are not enabled".to_string(),
),
Self::PreconditionRequired(msg) => (StatusCode::PRECONDITION_REQUIRED, msg),
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error".to_string(),
Expand All @@ -64,8 +67,9 @@ impl From<CoreError> for ApiError {
Code::FailedPrecondition => match status.message().to_lowercase().as_str() {
// TODO: find a better way than matching on the error message
"no valid license" => ApiError::EnterpriseNotEnabled,
_ => ApiError::Unexpected(status.to_string()),
_ => ApiError::PreconditionRequired(status.message().to_string()),
},
Code::Unavailable => ApiError::CoreTimeout,
_ => ApiError::Unexpected(status.to_string()),
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use url::Url;
use crate::{
assets::{index, svg, web_asset},
config::Config,
enterprise::handlers::openid_login,
enterprise::handlers::openid_login::{self, FlowType},
error::ApiError,
grpc::ProxyServer,
handlers::{desktop_client_mfa, enrollment, password_reset, polling},
Expand All @@ -49,11 +49,14 @@ pub(crate) struct AppState {
impl AppState {
/// Returns configured URL with "auth/callback" appended to the path.
#[must_use]
pub(crate) fn callback_url(&self) -> Url {
pub(crate) fn callback_url(&self, flow_type: FlowType) -> Url {
let mut url = self.url.clone();
// Append "/openid/callback" to the URL.
if let Ok(mut path_segments) = url.path_segments_mut() {
path_segments.extend(&["openid", "callback"]);
match flow_type {
FlowType::Enrollment => path_segments.extend(&["openid", "callback"]),
FlowType::Mfa => path_segments.extend(&["openid", "mfa", "callback"]),
};
}
url
}
Expand Down
10 changes: 10 additions & 0 deletions web/src/components/App/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import { detectLocale } from '../../i18n/i18n-util';
import { loadLocaleAsync } from '../../i18n/i18n-util.async';
import { EnrollmentPage } from '../../pages/enrollment/EnrollmentPage';
import { MainPage } from '../../pages/main/MainPage';
import { OpenIdMfaCallbackPage } from '../../pages/mfa/OpenIDCallback';
import { OpenIdMfaPage } from '../../pages/mfa/OpenIDRedirect';
import { OpenIDCallbackPage } from '../../pages/openidCallback/OpenIDCallback';
import { PasswordResetPage } from '../../pages/passwordReset/PasswordResetPage';
import { SessionTimeoutPage } from '../../pages/sessionTimeout/SessionTimeoutPage';
Expand Down Expand Up @@ -57,6 +59,14 @@ const router = createBrowserRouter([
path: routes.openidCallback,
element: <OpenIDCallbackPage />,
},
{
path: routes.openidMfa,
element: <OpenIdMfaPage />,
},
{
path: routes.openidMfaCallback,
element: <OpenIdMfaCallbackPage />,
},
{
path: '/*',
element: <Navigate to="/" replace />,
Expand Down
20 changes: 20 additions & 0 deletions web/src/i18n/en/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,26 @@ If you want to disengage your VPN connection, simply press "deactivate".
},
},
},
openidMfaCallback: {
error: {
title: 'Authentication Error',
message:
'There was an error during authentication with the provider. Please go back to the **Defguard VPN Client** and repeat the process.',
detailsTitle: 'Error Details',
},
success: {
title: 'Authentication Completed',
message:
'You have been successfully authenticated. Please close this window and get back to the **Defguard VPN Client**.',
},
},
openidMfaRedirect: {
error: {
title: 'Authentication Error',
message:
'No token provided in the URL. Please ensure you have a valid token to proceed with OpenID authentication.',
},
},
},
} satisfies BaseTranslation;

Expand Down
Loading
Loading