diff --git a/.sqlx/query-7ddef79c85c3e85b979d5a8a5e50660bcae531c2b8342ae2feffea7454450f10.json b/.sqlx/query-2eeee174a2a68ff5bd35bab32d35e01e900639bc113b7feee2ee52546f8f16b4.json similarity index 96% rename from .sqlx/query-7ddef79c85c3e85b979d5a8a5e50660bcae531c2b8342ae2feffea7454450f10.json rename to .sqlx/query-2eeee174a2a68ff5bd35bab32d35e01e900639bc113b7feee2ee52546f8f16b4.json index 9f60163ea..53fdfdd55 100644 --- a/.sqlx/query-7ddef79c85c3e85b979d5a8a5e50660bcae531c2b8342ae2feffea7454450f10.json +++ b/.sqlx/query-2eeee174a2a68ff5bd35bab32d35e01e900639bc113b7feee2ee52546f8f16b4.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT openid_enabled, wireguard_enabled, webhooks_enabled, worker_enabled, challenge_template, instance_name, main_logo_url, nav_logo_url, smtp_server, smtp_port, smtp_encryption \"smtp_encryption: _\", smtp_user, smtp_password \"smtp_password?: SecretStringWrapper\", smtp_sender, enrollment_vpn_step_optional, enrollment_welcome_message, enrollment_welcome_email, enrollment_welcome_email_subject, enrollment_use_welcome_message_as_email, uuid, ldap_url, ldap_bind_username, ldap_bind_password \"ldap_bind_password?: SecretStringWrapper\", ldap_group_search_base, ldap_user_search_base, ldap_user_obj_class, ldap_group_obj_class, ldap_username_attr, ldap_groupname_attr, ldap_group_member_attr, ldap_member_attr, openid_create_account, license, gateway_disconnect_notifications_enabled, ldap_use_starttls, ldap_tls_verify_cert, gateway_disconnect_notifications_inactivity_threshold, gateway_disconnect_notifications_reconnect_notification_enabled, ldap_sync_status \"ldap_sync_status: SyncStatus\", ldap_enabled, ldap_sync_enabled, ldap_is_authoritative, ldap_sync_interval, ldap_user_auxiliary_obj_classes, ldap_uses_ad, ldap_user_rdn_attr, ldap_sync_groups, openid_username_handling \"openid_username_handling: OpenidUsernameHandling\" FROM \"settings\" WHERE id = 1", + "query": "SELECT openid_enabled, wireguard_enabled, webhooks_enabled, worker_enabled, challenge_template, instance_name, main_logo_url, nav_logo_url, smtp_server, smtp_port, smtp_encryption \"smtp_encryption: _\", smtp_user, smtp_password \"smtp_password?: SecretStringWrapper\", smtp_sender, enrollment_vpn_step_optional, enrollment_welcome_message, enrollment_welcome_email, enrollment_welcome_email_subject, enrollment_use_welcome_message_as_email, uuid, ldap_url, ldap_bind_username, ldap_bind_password \"ldap_bind_password?: SecretStringWrapper\", ldap_group_search_base, ldap_user_search_base, ldap_user_obj_class, ldap_group_obj_class, ldap_username_attr, ldap_groupname_attr, ldap_group_member_attr, ldap_member_attr, openid_create_account, license, gateway_disconnect_notifications_enabled, ldap_use_starttls, ldap_tls_verify_cert, gateway_disconnect_notifications_inactivity_threshold, gateway_disconnect_notifications_reconnect_notification_enabled, ldap_sync_status \"ldap_sync_status: SyncStatus\", ldap_enabled, ldap_sync_enabled, ldap_is_authoritative, ldap_sync_interval, ldap_user_auxiliary_obj_classes, ldap_uses_ad, ldap_user_rdn_attr, ldap_sync_groups, openid_username_handling \"openid_username_handling: OpenidUsernameHandling\", use_openid_for_mfa FROM \"settings\" WHERE id = 1", "describe": { "columns": [ { @@ -274,6 +274,11 @@ } } } + }, + { + "ordinal": 48, + "name": "use_openid_for_mfa", + "type_info": "Bool" } ], "parameters": { @@ -327,8 +332,9 @@ false, true, false, + false, false ] }, - "hash": "7ddef79c85c3e85b979d5a8a5e50660bcae531c2b8342ae2feffea7454450f10" + "hash": "2eeee174a2a68ff5bd35bab32d35e01e900639bc113b7feee2ee52546f8f16b4" } diff --git a/.sqlx/query-3491725f35609e9b219c4d613cffd28a14cf37e546dfcabdfd78889dc1ef247f.json b/.sqlx/query-f3c5a612ced180d9b2014e027d34a20e3de28df8100f7c0d476d4182328daeeb.json similarity index 95% rename from .sqlx/query-3491725f35609e9b219c4d613cffd28a14cf37e546dfcabdfd78889dc1ef247f.json rename to .sqlx/query-f3c5a612ced180d9b2014e027d34a20e3de28df8100f7c0d476d4182328daeeb.json index beabc1823..afb392ccf 100644 --- a/.sqlx/query-3491725f35609e9b219c4d613cffd28a14cf37e546dfcabdfd78889dc1ef247f.json +++ b/.sqlx/query-f3c5a612ced180d9b2014e027d34a20e3de28df8100f7c0d476d4182328daeeb.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE \"settings\" SET openid_enabled = $1, wireguard_enabled = $2, webhooks_enabled = $3, worker_enabled = $4, challenge_template = $5, instance_name = $6, main_logo_url = $7, nav_logo_url = $8, smtp_server = $9, smtp_port = $10, smtp_encryption = $11, smtp_user = $12, smtp_password = $13, smtp_sender = $14, enrollment_vpn_step_optional = $15, enrollment_welcome_message = $16, enrollment_welcome_email = $17, enrollment_welcome_email_subject = $18, enrollment_use_welcome_message_as_email = $19, uuid = $20, ldap_url = $21, ldap_bind_username = $22, ldap_bind_password = $23, ldap_group_search_base = $24, ldap_user_search_base = $25, ldap_user_obj_class = $26, ldap_group_obj_class = $27, ldap_username_attr = $28, ldap_groupname_attr = $29, ldap_group_member_attr = $30, ldap_member_attr = $31, ldap_use_starttls = $32, ldap_tls_verify_cert = $33, openid_create_account = $34, license = $35, gateway_disconnect_notifications_enabled = $36, gateway_disconnect_notifications_inactivity_threshold = $37, gateway_disconnect_notifications_reconnect_notification_enabled = $38, ldap_sync_status = $39, ldap_enabled = $40, ldap_sync_enabled = $41, ldap_is_authoritative = $42, ldap_sync_interval = $43, ldap_user_auxiliary_obj_classes = $44, ldap_uses_ad = $45, ldap_user_rdn_attr = $46, ldap_sync_groups = $47, openid_username_handling = $48 WHERE id = 1", + "query": "UPDATE \"settings\" SET openid_enabled = $1, wireguard_enabled = $2, webhooks_enabled = $3, worker_enabled = $4, challenge_template = $5, instance_name = $6, main_logo_url = $7, nav_logo_url = $8, smtp_server = $9, smtp_port = $10, smtp_encryption = $11, smtp_user = $12, smtp_password = $13, smtp_sender = $14, enrollment_vpn_step_optional = $15, enrollment_welcome_message = $16, enrollment_welcome_email = $17, enrollment_welcome_email_subject = $18, enrollment_use_welcome_message_as_email = $19, uuid = $20, ldap_url = $21, ldap_bind_username = $22, ldap_bind_password = $23, ldap_group_search_base = $24, ldap_user_search_base = $25, ldap_user_obj_class = $26, ldap_group_obj_class = $27, ldap_username_attr = $28, ldap_groupname_attr = $29, ldap_group_member_attr = $30, ldap_member_attr = $31, ldap_use_starttls = $32, ldap_tls_verify_cert = $33, openid_create_account = $34, license = $35, gateway_disconnect_notifications_enabled = $36, gateway_disconnect_notifications_inactivity_threshold = $37, gateway_disconnect_notifications_reconnect_notification_enabled = $38, ldap_sync_status = $39, ldap_enabled = $40, ldap_sync_enabled = $41, ldap_is_authoritative = $42, ldap_sync_interval = $43, ldap_user_auxiliary_obj_classes = $44, ldap_uses_ad = $45, ldap_user_rdn_attr = $46, ldap_sync_groups = $47, openid_username_handling = $48, use_openid_for_mfa = $49 WHERE id = 1", "describe": { "columns": [], "parameters": { @@ -84,10 +84,11 @@ ] } } - } + }, + "Bool" ] }, "nullable": [] }, - "hash": "3491725f35609e9b219c4d613cffd28a14cf37e546dfcabdfd78889dc1ef247f" + "hash": "f3c5a612ced180d9b2014e027d34a20e3de28df8100f7c0d476d4182328daeeb" } diff --git a/crates/defguard_core/migrations/20250612111316_client_oidc_2fa.down.sql b/crates/defguard_core/migrations/20250612111316_client_oidc_2fa.down.sql new file mode 100644 index 000000000..12c99cb1d --- /dev/null +++ b/crates/defguard_core/migrations/20250612111316_client_oidc_2fa.down.sql @@ -0,0 +1 @@ +ALTER TABLE settings DROP COLUMN use_openid_for_mfa; diff --git a/crates/defguard_core/migrations/20250612111316_client_oidc_2fa.up.sql b/crates/defguard_core/migrations/20250612111316_client_oidc_2fa.up.sql new file mode 100644 index 000000000..e937b090f --- /dev/null +++ b/crates/defguard_core/migrations/20250612111316_client_oidc_2fa.up.sql @@ -0,0 +1 @@ +ALTER TABLE settings ADD COLUMN use_openid_for_mfa BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/crates/defguard_core/src/db/models/audit_log/metadata.rs b/crates/defguard_core/src/db/models/audit_log/metadata.rs index 141b7cc48..c8635ed9d 100644 --- a/crates/defguard_core/src/db/models/audit_log/metadata.rs +++ b/crates/defguard_core/src/db/models/audit_log/metadata.rs @@ -13,6 +13,7 @@ use crate::{ audit_stream::{AuditStream, AuditStreamType}, openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProvider}, }, + events::ClientMFAMethod, }; #[derive(Serialize)] @@ -159,7 +160,7 @@ pub struct VpnClientMetadata { pub struct VpnClientMfaMetadata { pub location: WireguardNetwork, pub device: Device, - pub method: MFAMethod, + pub method: ClientMFAMethod, } #[derive(Serialize)] diff --git a/crates/defguard_core/src/db/models/audit_log/mod.rs b/crates/defguard_core/src/db/models/audit_log/mod.rs index 0cf327d92..906a9c11a 100644 --- a/crates/defguard_core/src/db/models/audit_log/mod.rs +++ b/crates/defguard_core/src/db/models/audit_log/mod.rs @@ -1,9 +1,10 @@ -use crate::db::{Id, NoId}; use chrono::NaiveDateTime; use ipnetwork::IpNetwork; use model_derive::Model; use sqlx::{FromRow, Type}; +use crate::db::{Id, NoId}; + pub mod metadata; #[derive(Clone, Debug, Deserialize, Serialize, Type)] diff --git a/crates/defguard_core/src/db/models/settings.rs b/crates/defguard_core/src/db/models/settings.rs index c2622557d..5747eb76b 100644 --- a/crates/defguard_core/src/db/models/settings.rs +++ b/crates/defguard_core/src/db/models/settings.rs @@ -120,6 +120,7 @@ pub struct Settings { // Whether to create a new account when users try to log in with external OpenID pub openid_create_account: bool, pub openid_username_handling: OpenidUsernameHandling, + pub use_openid_for_mfa: bool, pub license: Option, // Gateway disconnect notifications pub gateway_disconnect_notifications_enabled: bool, @@ -152,7 +153,7 @@ impl Settings { ldap_enabled, ldap_sync_enabled, ldap_is_authoritative, \ ldap_sync_interval, ldap_user_auxiliary_obj_classes, ldap_uses_ad, \ ldap_user_rdn_attr, ldap_sync_groups, \ - openid_username_handling \"openid_username_handling: OpenidUsernameHandling\" \ + openid_username_handling \"openid_username_handling: OpenidUsernameHandling\", use_openid_for_mfa \ FROM \"settings\" WHERE id = 1", ) .fetch_optional(executor) @@ -224,7 +225,8 @@ impl Settings { ldap_uses_ad = $45, \ ldap_user_rdn_attr = $46, \ ldap_sync_groups = $47, \ - openid_username_handling = $48 \ + openid_username_handling = $48, \ + use_openid_for_mfa = $49 \ WHERE id = 1", self.openid_enabled, self.wireguard_enabled, @@ -274,6 +276,7 @@ impl Settings { self.ldap_user_rdn_attr, &self.ldap_sync_groups as &Vec, &self.openid_username_handling as &OpenidUsernameHandling, + self.use_openid_for_mfa, ) .execute(executor) .await?; diff --git a/crates/defguard_core/src/db/models/user.rs b/crates/defguard_core/src/db/models/user.rs index e6db06197..097e2a29b 100644 --- a/crates/defguard_core/src/db/models/user.rs +++ b/crates/defguard_core/src/db/models/user.rs @@ -15,6 +15,7 @@ use rand::{ prelude::Distribution, Rng, }; +use serde::Serialize; use sqlx::{ query, query_as, query_scalar, Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, @@ -53,15 +54,7 @@ pub enum MFAMethod { Email, } -impl From for MFAMethod { - fn from(method: MfaMethod) -> Self { - match method { - MfaMethod::Totp => Self::OneTimePassword, - MfaMethod::Email => Self::Email, - } - } -} - +// Web MFA methods impl fmt::Display for MFAMethod { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( @@ -77,6 +70,21 @@ impl fmt::Display for MFAMethod { } } +// Client MFA methods +impl fmt::Display for MfaMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + MfaMethod::Totp => "TOTP", + MfaMethod::Email => "Email", + MfaMethod::Oidc => "OIDC", + } + ) + } +} + // User information ready to be sent as part of diagnostic data. #[derive(Serialize)] pub struct UserDiagnostic { diff --git a/crates/defguard_core/src/enterprise/audit_stream/audit_stream_manager.rs b/crates/defguard_core/src/enterprise/audit_stream/audit_stream_manager.rs index 0d7d3cea0..1550a0643 100644 --- a/crates/defguard_core/src/enterprise/audit_stream/audit_stream_manager.rs +++ b/crates/defguard_core/src/enterprise/audit_stream/audit_stream_manager.rs @@ -4,17 +4,15 @@ use bytes::Bytes; use sqlx::PgPool; use tokio::{sync::broadcast::Receiver, task::JoinSet, time::sleep}; use tokio_util::sync::CancellationToken; - use tracing::debug; +use super::AuditStreamReconfigurationNotification; use crate::enterprise::{ audit_stream::http_stream::{run_http_stream_task, HttpAuditStreamConfig}, db::models::audit_stream::{AuditStream, AuditStreamConfig}, is_enterprise_enabled, }; -use super::AuditStreamReconfigurationNotification; - pub async fn run_audit_stream_manager( pool: PgPool, notification: AuditStreamReconfigurationNotification, diff --git a/crates/defguard_core/src/enterprise/audit_stream/http_stream.rs b/crates/defguard_core/src/enterprise/audit_stream/http_stream.rs index 330af98c5..a2dd03a36 100644 --- a/crates/defguard_core/src/enterprise/audit_stream/http_stream.rs +++ b/crates/defguard_core/src/enterprise/audit_stream/http_stream.rs @@ -5,7 +5,6 @@ use bytes::Bytes; use reqwest::tls; use tokio::sync::broadcast::Receiver; use tokio_util::sync::CancellationToken; - use tracing::{debug, error}; use crate::{ diff --git a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs index f3cb40c7e..44f909ac0 100644 --- a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs +++ b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs @@ -1,7 +1,7 @@ use std::fmt; use model_derive::Model; -use sqlx::{query, query_as, Error as SqlxError, PgPool, Type}; +use sqlx::{query, query_as, Error as SqlxError, PgExecutor, PgPool, Type}; use crate::db::{Id, NoId}; @@ -195,7 +195,10 @@ impl OpenIdProvider { } impl OpenIdProvider { - pub async fn find_by_name(pool: &PgPool, name: &str) -> Result, SqlxError> { + pub async fn find_by_name<'e, E>(executor: E, name: &str) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { query_as!( OpenIdProvider, "SELECT id, name, base_url, client_id, client_secret, display_name, \ @@ -207,11 +210,14 @@ impl OpenIdProvider { FROM openidprovider WHERE name = $1", name ) - .fetch_optional(pool) + .fetch_optional(executor) .await } - pub async fn get_current(pool: &PgPool) -> Result, SqlxError> { + pub async fn get_current<'e, E>(executor: E) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { query_as!( OpenIdProvider, "SELECT id, name, base_url, client_id, client_secret, display_name, \ @@ -222,7 +228,7 @@ impl OpenIdProvider { okta_private_jwk, okta_dirsync_client_id, directory_sync_group_match \ FROM openidprovider LIMIT 1" ) - .fetch_optional(pool) + .fetch_optional(executor) .await } } diff --git a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs new file mode 100644 index 000000000..83d4ecedc --- /dev/null +++ b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs @@ -0,0 +1,146 @@ +use openidconnect::{AuthorizationCode, Nonce}; +use reqwest::Url; +use tonic::Status; + +use crate::{ + enterprise::{ + handlers::openid_login::{extract_state_data, user_from_claims}, + is_enterprise_enabled, + }, + events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, + grpc::{ + desktop_client_mfa::{ClientLoginSession, ClientMfaServer}, + proto::proxy::{ClientMfaOidcAuthenticateRequest, DeviceInfo, MfaMethod}, + utils::parse_client_info, + }, +}; + +impl ClientMfaServer { + #[instrument(skip_all)] + pub async fn auth_mfa_session_with_oidc( + &mut self, + request: ClientMfaOidcAuthenticateRequest, + info: Option, + ) -> Result<(), Status> { + debug!("Received OIDC MFA authentication request: {request:?}"); + if !is_enterprise_enabled() { + error!("OIDC MFA method requires enterprise feature to be enabled"); + return Err(Status::invalid_argument("OIDC MFA method is not supported")); + } + + let token = extract_state_data(&request.state).ok_or_else(|| { + error!( + "Failed to extract state data from state: {:?}", + request.state + ); + Status::invalid_argument("invalid state data") + })?; + if token.is_empty() { + debug!("Empty token provided in request"); + return Err(Status::invalid_argument("empty token provided")); + } + let pubkey = Self::parse_token(&token)?; + + // fetch login session + let Some(session) = self.sessions.get(&pubkey).cloned() else { + debug!("Client login session not found"); + return Err(Status::invalid_argument("login session not found")); + }; + let ClientLoginSession { + method, + device, + location, + user, + openid_auth_completed, + } = session; + + if openid_auth_completed { + debug!("Client login session already completed"); + return Err(Status::invalid_argument("login session already completed")); + } + + if method != MfaMethod::Oidc { + debug!("Invalid MFA method for OIDC authentication: {method:?}"); + self.sessions.remove(&pubkey); + return Err(Status::invalid_argument("invalid MFA method")); + } + + let (ip, user_agent) = parse_client_info(&info).map_err(Status::internal)?; + let context = BidiRequestContext::new(user.id, user.username.clone(), ip, user_agent); + + let code = AuthorizationCode::new(request.code.clone()); + let url = match Url::parse(&request.callback_url).map_err(|err| { + error!("Invalid redirect URL provided: {err:?}"); + Status::invalid_argument("invalid redirect URL") + }) { + Ok(url) => url, + Err(status) => { + self.sessions.remove(&pubkey); + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Failed { + location: location.clone(), + device: device.clone(), + method, + }, + )), + })?; + return Err(status); + } + }; + + match user_from_claims(&self.pool, Nonce::new(request.nonce.clone()), code, url).await { + Ok(claims_user) => { + // if thats not our user, prevent login + if claims_user.id != user.id { + info!("User {claims_user} tried to use OIDC MFA for another user: {user}"); + self.sessions.remove(&pubkey); + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Failed { + location: location.clone(), + device: device.clone(), + method, + }, + )), + })?; + return Err(Status::unauthenticated("unauthorized")); + } + info!( + "OIDC MFA authentication completed successfully for user: {}", + user.username + ); + } + Err(err) => { + info!("Failed to verify OIDC code: {err:?}"); + self.sessions.remove(&pubkey); + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Failed { + location: location.clone(), + device: device.clone(), + method, + }, + )), + })?; + return Err(Status::unauthenticated("unauthorized")); + } + }; + + self.sessions.insert( + pubkey.clone(), + ClientLoginSession { + method, + device: device.clone(), + location: location.clone(), + user: user.clone(), + openid_auth_completed: true, + }, + ); + + Ok(()) + } +} diff --git a/crates/defguard_core/src/enterprise/grpc/mod.rs b/crates/defguard_core/src/enterprise/grpc/mod.rs index 505916a0a..cc68fd70d 100644 --- a/crates/defguard_core/src/enterprise/grpc/mod.rs +++ b/crates/defguard_core/src/enterprise/grpc/mod.rs @@ -1 +1,2 @@ +pub mod desktop_client_mfa; pub mod polling; diff --git a/crates/defguard_core/src/enterprise/handlers/audit_stream.rs b/crates/defguard_core/src/enterprise/handlers/audit_stream.rs index e920801ee..350925886 100644 --- a/crates/defguard_core/src/enterprise/handlers/audit_stream.rs +++ b/crates/defguard_core/src/enterprise/handlers/audit_stream.rs @@ -5,6 +5,7 @@ use axum::{ use reqwest::StatusCode; use serde_json::json; +use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, @@ -14,8 +15,6 @@ use crate::{ handlers::{ApiResponse, ApiResult}, }; -use super::LicenseInfo; - pub async fn get_audit_stream( _admin: AdminRole, State(appstate): State, diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index aa63aff4f..f37f7b3aa 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -8,6 +8,7 @@ use axum_extra::{ headers::UserAgent, TypedHeader, }; +use base64::{prelude::BASE64_STANDARD, Engine}; use openidconnect::{ core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata, CoreUserInfoClaims}, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointMaybeSet, EndpointNotSet, @@ -115,6 +116,34 @@ async fn get_provider_metadata(url: &str) -> Result) -> CsrfToken { + let csrf_token = CsrfToken::new_random(); + if let Some(data) = state_data { + let combined = format!("{}.{data}", csrf_token.secret()); + let encoded = BASE64_STANDARD.encode(combined); + CsrfToken::new(encoded) + } else { + csrf_token + } +} + +/// Extract the state data from the provided state. +pub(crate) fn extract_state_data(state: &str) -> Option { + let decoded = BASE64_STANDARD.decode(state).ok()?; + let decoded_str = String::from_utf8(decoded).ok()?; + let result = decoded_str.split_once('.'); + if let Some((part1, part2)) = result { + if part1.is_empty() { + None + } else { + Some(part2.to_string()) + } + } else { + None + } +} + /// Build OpenID Connect client. /// `url`: redirect/callback URL pub(crate) async fn make_oidc_client( @@ -673,4 +702,54 @@ mod test { "averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee" ); } + + #[test] + fn test_state_build_and_extract() { + // without data + let token = build_state(None); + let decoded = BASE64_STANDARD.decode(token.secret()); + // not base64 encoded + assert!(decoded.is_err()); + assert!(!token.secret().is_empty()); + + // with data + let data = "somedata".to_string(); + let token = build_state(Some(data.clone())); + let decoded = BASE64_STANDARD.decode(token.secret()); + assert!(decoded.is_ok()); + let decoded_str = String::from_utf8(decoded.unwrap()).unwrap(); + let (csrf, state_data) = decoded_str.split_once('.').unwrap(); + assert!(!csrf.is_empty()); + assert_eq!(state_data, data); + + // valid + let data = "my_state_data".to_string(); + let token = build_state(Some(data.clone())); + let extracted = extract_state_data(token.secret()); + assert_eq!(extracted, Some(data)); + + // invalid base64 + let extracted = extract_state_data("not_base64!!"); + assert_eq!(extracted, None); + + // no dot + let encoded = BASE64_STANDARD.encode("no_dot_here"); + let extracted = extract_state_data(&encoded); + assert_eq!(extracted, None); + + // empty first part + let encoded = BASE64_STANDARD.encode(".somedata"); + let extracted = extract_state_data(&encoded); + assert_eq!(extracted, None); + + // empty second part + let encoded = BASE64_STANDARD.encode("csrf."); + let extracted = extract_state_data(&encoded); + assert_eq!(extracted, Some("".to_string())); + + // multiple dots + let encoded = BASE64_STANDARD.encode("csrf.data.with.dots"); + let extracted = extract_state_data(&encoded); + assert_eq!(extracted, Some("data.with.dots".to_string())); + } } diff --git a/crates/defguard_core/src/enterprise/handlers/openid_providers.rs b/crates/defguard_core/src/enterprise/handlers/openid_providers.rs index 92e63f292..4ceb9f124 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_providers.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_providers.rs @@ -41,6 +41,7 @@ pub struct AddProviderData { pub okta_dirsync_client_id: Option, pub directory_sync_group_match: Option, pub username_handling: OpenidUsernameHandling, + pub use_openid_for_mfa: bool, } #[derive(Debug, Deserialize, Serialize)] @@ -117,6 +118,7 @@ pub async fn add_openid_provider( let mut settings = Settings::get_current_settings(); settings.openid_create_account = provider_data.create_account; + settings.use_openid_for_mfa = provider_data.use_openid_for_mfa; settings.openid_username_handling = provider_data.username_handling; update_current_settings(&appstate.pool, settings).await?; @@ -177,7 +179,6 @@ pub async fn get_current_openid_provider( State(appstate): State, ) -> ApiResult { let settings = Settings::get_current_settings(); - let create_account = settings.openid_create_account; match OpenIdProvider::get_current(&appstate.pool).await? { Some(mut provider) => { // Get rid of it, it should stay on the backend only. @@ -186,7 +187,7 @@ pub async fn get_current_openid_provider( Ok(ApiResponse { json: json!({ "provider": json!(provider), - "settings": json!({ "create_account": create_account, "username_handling": settings.openid_username_handling}), + "settings": json!({ "create_account": settings.openid_create_account, "username_handling": settings.openid_username_handling, "use_openid_for_mfa": settings.use_openid_for_mfa }), }), status: StatusCode::OK, }) @@ -194,7 +195,7 @@ pub async fn get_current_openid_provider( None => Ok(ApiResponse { json: json!({ "provider": null, - "settings": json!({ "create_account": create_account }), + "settings": json!({ "create_account": settings.openid_create_account, "username_handling": settings.openid_username_handling, "use_openid_for_mfa": settings.use_openid_for_mfa }), }), status: StatusCode::NO_CONTENT, }), @@ -213,9 +214,14 @@ pub async fn delete_openid_provider( "User {} deleting OpenID provider {}", session.user.username, provider_data.name ); - let provider = OpenIdProvider::find_by_name(&appstate.pool, &provider_data.name).await?; + let mut trasnaction = appstate.pool.begin().await?; + let provider = OpenIdProvider::find_by_name(&mut *trasnaction, &provider_data.name).await?; if let Some(provider) = provider { - provider.clone().delete(&appstate.pool).await?; + let mut settings = Settings::get_current_settings(); + provider.clone().delete(&mut *trasnaction).await?; + settings.use_openid_for_mfa = false; + update_current_settings(&mut *trasnaction, settings).await?; + trasnaction.commit().await?; info!( "User {} deleted OpenID provider {}", session.user.username, provider.name @@ -240,6 +246,44 @@ pub async fn delete_openid_provider( } } +pub async fn modify_openid_provider( + _license: LicenseInfo, + _admin: AdminRole, + session: SessionInfo, + State(appstate): State, + Json(provider_data): Json, +) -> ApiResult { + debug!( + "User {} modifying OpenID provider {}", + session.user.username, provider_data.name + ); + let mut transaction = appstate.pool.begin().await?; + let provider = OpenIdProvider::find_by_name(&mut *transaction, &provider_data.name).await?; + if let Some(mut provider) = provider { + provider.base_url = provider_data.base_url; + provider.client_id = provider_data.client_id; + provider.client_secret = provider_data.client_secret; + provider.save(&mut *transaction).await?; + info!( + "User {} modified OpenID client {}", + session.user.username, provider.name + ); + Ok(ApiResponse { + json: json!({}), + status: StatusCode::OK, + }) + } else { + warn!( + "User {} failed to modify OpenID client {}. Such client does not exist.", + session.user.username, provider_data.name + ); + Ok(ApiResponse { + json: json!({}), + status: StatusCode::NOT_FOUND, + }) + } +} + pub async fn list_openid_providers( _license: LicenseInfo, _admin: AdminRole, diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index 8d557f7f0..9550e1857 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -1,5 +1,8 @@ use std::net::IpAddr; +use chrono::{NaiveDateTime, Utc}; +use serde::Serialize; + use crate::{ db::{ models::{authentication_key::AuthenticationKey, oauth2client::OAuth2Client}, @@ -8,8 +11,8 @@ use crate::{ enterprise::db::models::{ api_tokens::ApiToken, audit_stream::AuditStream, openid_provider::OpenIdProvider, }, + grpc::proto::proxy::MfaMethod, }; -use chrono::{NaiveDateTime, Utc}; /// Shared context that needs to be added to every API event /// @@ -327,17 +330,32 @@ pub enum PasswordResetEvent { PasswordResetCompleted, } +pub type ClientMFAMethod = MfaMethod; + +impl Serialize for ClientMFAMethod { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match *self { + MfaMethod::Totp => serializer.serialize_unit_variant("MfaMethod", 0, "Totp"), + MfaMethod::Email => serializer.serialize_unit_variant("MfaMethod", 1, "Email"), + MfaMethod::Oidc => serializer.serialize_unit_variant("MfaMethod", 2, "Oidc"), + } + } +} + #[derive(Debug)] pub enum DesktopClientMfaEvent { Connected { device: Device, location: WireguardNetwork, - method: MFAMethod, + method: ClientMFAMethod, }, Failed { device: Device, location: WireguardNetwork, - method: MFAMethod, + method: ClientMFAMethod, }, } diff --git a/crates/defguard_core/src/grpc/desktop_client_mfa.rs b/crates/defguard_core/src/grpc/desktop_client_mfa.rs index f4d228903..5ca6e876e 100644 --- a/crates/defguard_core/src/grpc/desktop_client_mfa.rs +++ b/crates/defguard_core/src/grpc/desktop_client_mfa.rs @@ -17,8 +17,9 @@ use crate::{ auth::{Claims, ClaimsType}, db::{ models::device::{DeviceInfo, DeviceNetworkInfo, WireguardNetworkDevice}, - Device, GatewayEvent, Id, User, UserInfo, WireguardNetwork, + Device, GatewayEvent, Id, Settings, User, UserInfo, WireguardNetwork, }, + enterprise::{db::models::openid_provider::OpenIdProvider, is_enterprise_enabled}, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, grpc::utils::parse_client_info, handlers::mail::send_email_mfa_code_email, @@ -39,18 +40,20 @@ impl From for Status { } } -struct ClientLoginSession { - method: MfaMethod, - location: WireguardNetwork, - device: Device, - user: User, +#[derive(Clone)] +pub(crate) struct ClientLoginSession { + pub(crate) method: MfaMethod, + pub(crate) location: WireguardNetwork, + pub(crate) device: Device, + pub(crate) user: User, + pub(crate) openid_auth_completed: bool, } -pub(super) struct ClientMfaServer { - pool: PgPool, +pub(crate) struct ClientMfaServer { + pub(crate) pool: PgPool, mail_tx: UnboundedSender, wireguard_tx: Sender, - sessions: HashMap, + pub(crate) sessions: HashMap, bidi_event_tx: UnboundedSender, } @@ -86,7 +89,7 @@ impl ClientMfaServer { } /// Validate JWT and extract client pubkey - fn parse_token(token: &str) -> Result { + pub(crate) fn parse_token(token: &str) -> Result { let claims = Claims::from_jwt(ClaimsType::DesktopClient, token).map_err(|err| { error!("Failed to parse JWT token: {err:?}"); Status::invalid_argument("invalid token") @@ -94,7 +97,7 @@ impl ClientMfaServer { Ok(claims.client_id) } - fn emit_event(&self, event: BidiStreamEvent) -> Result<(), ClientMfaServerError> { + pub(crate) fn emit_event(&self, event: BidiStreamEvent) -> Result<(), ClientMfaServerError> { Ok(self.bidi_event_tx.send(event)?) } @@ -193,7 +196,37 @@ impl ClientMfaServer { Status::internal("unexpected error") })?; } - } + MfaMethod::Oidc => { + if !is_enterprise_enabled() { + error!("OIDC MFA method requires enterprise feature to be enabled"); + return Err(Status::invalid_argument( + "selected MFA method not available", + )); + } + + let settings = Settings::get_current_settings(); + if !settings.use_openid_for_mfa { + error!("OIDC MFA method is not enabled in settings"); + return Err(Status::invalid_argument( + "selected MFA method not available", + )); + } + + if OpenIdProvider::get_current(&self.pool) + .await + .map_err(|err| { + error!("Failed to get current OpenID provider: {err:?}",); + Status::internal("unexpected error") + })? + .is_none() + { + error!("OIDC provider is not configured"); + return Err(Status::invalid_argument( + "selected MFA method not available", + )); + } + } + }; // generate auth token let token = Self::generate_token(&request.pubkey)?; @@ -211,6 +244,7 @@ impl ClientMfaServer { location, device, user, + openid_auth_completed: false, }, ); @@ -237,6 +271,7 @@ impl ClientMfaServer { device, location, user, + openid_auth_completed, } = session; // Prepare event context @@ -246,7 +281,23 @@ impl ClientMfaServer { // validate code match method { MfaMethod::Totp => { - if !user.verify_totp_code(&request.code.to_string()) { + let code = if let Some(code) = request.code { + code.to_string() + } else { + error!("TOTP code not provided in request"); + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Failed { + location: location.clone(), + device: device.clone(), + method: *method, + }, + )), + })?; + return Err(Status::invalid_argument("TOTP code not provided")); + }; + if !user.verify_totp_code(&code) { error!("Provided TOTP code is not valid"); self.emit_event(BidiStreamEvent { context, @@ -254,7 +305,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: (*method).into(), + method: *method, }, )), })?; @@ -262,7 +313,23 @@ impl ClientMfaServer { } } MfaMethod::Email => { - if !user.verify_email_mfa_code(&request.code.to_string()) { + let code = if let Some(code) = request.code { + code.to_string() + } else { + error!("Email MFA code not provided in request"); + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Failed { + location: location.clone(), + device: device.clone(), + method: *method, + }, + )), + })?; + return Err(Status::invalid_argument("email MFA code not provided")); + }; + if !user.verify_email_mfa_code(&code) { error!("Provided email code is not valid"); self.emit_event(BidiStreamEvent { context, @@ -270,13 +337,35 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: (*method).into(), + method: *method, }, )), })?; return Err(Status::unauthenticated("unauthorized")); } } + MfaMethod::Oidc => { + if !*openid_auth_completed { + debug!( + "User {user} tried to finish OIDC MFA login but they haven't completed the OIDC authentication yet." + ); + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Failed { + location: location.clone(), + device: device.clone(), + method: *method, + }, + )), + })?; + return Err(Status::failed_precondition( + "OIDC authentication not completed yet", + )); + } else { + debug!("User {user} is trying to finish OIDC MFA login and the OIDC authentication has already been completed; proceeding."); + } + } } // begin transaction @@ -337,7 +426,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Connected { location: location.clone(), device: device.clone(), - method: (*method).into(), + method: *method, }, )), })?; diff --git a/crates/defguard_core/src/grpc/enrollment.rs b/crates/defguard_core/src/grpc/enrollment.rs index 4b0205e14..f5aaee110 100644 --- a/crates/defguard_core/src/grpc/enrollment.rs +++ b/crates/defguard_core/src/grpc/enrollment.rs @@ -13,7 +13,6 @@ use super::{ }, InstanceInfo, }; -use crate::grpc::utils::parse_client_info; use crate::{ db::{ models::{ @@ -24,11 +23,12 @@ use crate::{ Device, GatewayEvent, Id, Settings, User, }, enterprise::{ - db::models::enterprise_settings::EnterpriseSettings, ldap::utils::ldap_add_user, + db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, + ldap::utils::ldap_add_user, limits::update_counts, }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, - grpc::utils::{build_device_config_response, new_polling_token}, + grpc::utils::{build_device_config_response, new_polling_token, parse_client_info}, handlers::{mail::send_new_device_added_email, user::check_password_strength}, headers::get_device_info, mail::Mail, @@ -193,7 +193,20 @@ impl EnrollmentServer { "Retrieving instance info for user {}({:?}).", user.username, user.id ); - let instance_info = InstanceInfo::new(settings, &user.username, &enterprise_settings); + + let openid_provider = OpenIdProvider::get_current(&self.pool) + .await + .map_err(|err| { + error!("Failed to get OpenID provider: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + + let instance_info = InstanceInfo::new( + settings, + &user.username, + &enterprise_settings, + openid_provider, + ); debug!("Instance info {instance_info:?}"); debug!( @@ -709,11 +722,24 @@ impl EnrollmentServer { info!("Device {} remote configuration done.", device.name); + let openid_provider = OpenIdProvider::get_current(&self.pool) + .await + .map_err(|err| { + error!("Failed to get OpenID provider: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + let response = DeviceConfigResponse { device: Some(device.clone().into()), configs: configs.into_iter().map(Into::into).collect(), instance: Some( - InstanceInfo::new(settings, &user.username, &enterprise_settings).into(), + InstanceInfo::new( + settings, + &user.username, + &enterprise_settings, + openid_provider, + ) + .into(), ), token: Some(token.token), }; diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ee7d5c7cf..290f6517a 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -186,8 +186,8 @@ impl GatewayServer { } pub fn get_client_state_guard( - &self, - ) -> Result, GatewayServerError> { + &'_ self, + ) -> Result, GatewayServerError> { let client_state = self .client_state .lock() diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 63dd1784e..a17968603 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,7 +10,7 @@ use std::{ }; use chrono::{NaiveDateTime, Utc}; -use openidconnect::{core::CoreAuthenticationFlow, AuthorizationCode, CsrfToken, Nonce, Scope}; +use openidconnect::{core::CoreAuthenticationFlow, AuthorizationCode, Nonce, Scope}; use reqwest::Url; use serde::Serialize; #[cfg(feature = "worker")] @@ -56,7 +56,7 @@ use crate::{ db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, directory_sync::sync_user_groups_if_configured, grpc::polling::PollingServer, - handlers::openid_login::{make_oidc_client, user_from_claims}, + handlers::openid_login::{build_state, make_oidc_client, user_from_claims}, is_enterprise_enabled, ldap::utils::ldap_update_user_state, }, @@ -69,7 +69,7 @@ use crate::{ use crate::{auth::ClaimsType, db::GatewayEvent}; mod auth; -mod desktop_client_mfa; +pub(crate) mod desktop_client_mfa; pub mod enrollment; #[cfg(feature = "wireguard")] pub(crate) mod gateway; @@ -646,7 +646,28 @@ pub async fn run_grpc_bidi_stream( Some(core_response::Payload::ClientMfaFinish(response_payload)) } Err(err) => { - error!("client MFA finish error {err}"); + match err.code() { + Code::FailedPrecondition => { + // User not yet done with OIDC authentication. Don't log it as an error. + debug!("Client MFA finish error: {err}"); + } + _ => { + // Log other errors as errors. + error!("Client MFA finish error: {err}"); + } + } + Some(core_response::Payload::CoreError(err.into())) + } + } + } + Some(core_request::Payload::ClientMfaOidcAuthenticate(request)) => { + match client_mfa_server + .auth_mfa_session_with_oidc(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("client MFA OIDC authenticate error {err}"); Some(core_response::Payload::CoreError(err.into())) } } @@ -686,7 +707,7 @@ pub async fn run_grpc_bidi_stream( let (url, csrf_token, nonce) = client .authorize_url( CoreAuthenticationFlow::AuthorizationCode, - CsrfToken::new_random, + || build_state(request.state), Nonce::new_random, ) .add_scope(Scope::new("email".to_string())) @@ -915,6 +936,8 @@ pub struct InstanceInfo { username: String, disable_all_traffic: bool, enterprise_enabled: bool, + use_openid_for_mfa: bool, + openid_display_name: Option, } impl InstanceInfo { @@ -922,8 +945,13 @@ impl InstanceInfo { settings: Settings, username: S, enterprise_settings: &EnterpriseSettings, + openid_provider: Option>, ) -> Self { let config = server_config(); + let openid_display_name = openid_provider + .as_ref() + .map(|provider| provider.display_name.clone()) + .unwrap_or_default(); InstanceInfo { id: settings.uuid, name: settings.instance_name, @@ -932,6 +960,12 @@ impl InstanceInfo { username: username.into(), disable_all_traffic: enterprise_settings.disable_all_traffic, enterprise_enabled: is_enterprise_enabled(), + use_openid_for_mfa: if is_enterprise_enabled() { + settings.use_openid_for_mfa + } else { + false + }, + openid_display_name, } } } @@ -946,6 +980,8 @@ impl From for proto::proxy::InstanceInfo { username: instance.username, disable_all_traffic: instance.disable_all_traffic, enterprise_enabled: instance.enterprise_enabled, + use_openid_for_mfa: instance.use_openid_for_mfa, + openid_display_name: instance.openid_display_name, } } } diff --git a/crates/defguard_core/src/grpc/utils.rs b/crates/defguard_core/src/grpc/utils.rs index 9df56e499..4ced8db38 100644 --- a/crates/defguard_core/src/grpc/utils.rs +++ b/crates/defguard_core/src/grpc/utils.rs @@ -1,5 +1,6 @@ -use sqlx::PgPool; use std::{net::IpAddr, str::FromStr}; + +use sqlx::PgPool; use tonic::Status; use super::{ @@ -15,7 +16,9 @@ use crate::{ }, Device, Id, Settings, User, }, - enterprise::db::models::enterprise_settings::EnterpriseSettings, + enterprise::db::models::{ + enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider, + }, AsCsv, }; @@ -68,6 +71,11 @@ pub(crate) async fn build_device_config_response( ) -> Result { let settings = Settings::get_current_settings(); + let openid_provider = OpenIdProvider::get_current(pool).await.map_err(|err| { + error!("Failed to get OpenID provider: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + let networks = WireguardNetwork::all(pool).await.map_err(|err| { error!("Failed to fetch all networks: {err}"); Status::internal(format!("unexpected error: {err}")) @@ -162,7 +170,15 @@ pub(crate) async fn build_device_config_response( Ok(DeviceConfigResponse { device: Some(device.into()), configs, - instance: Some(InstanceInfo::new(settings, &user.username, &enterprise_settings).into()), + instance: Some( + InstanceInfo::new( + settings, + &user.username, + &enterprise_settings, + openid_provider, + ) + .into(), + ), token, }) } diff --git a/crates/defguard_core/src/handlers/audit_log.rs b/crates/defguard_core/src/handlers/audit_log.rs index fada147ee..76cf9ef26 100644 --- a/crates/defguard_core/src/handlers/audit_log.rs +++ b/crates/defguard_core/src/handlers/audit_log.rs @@ -7,17 +7,16 @@ use ipnetwork::IpNetwork; use sqlx::{FromRow, Postgres, QueryBuilder, Type}; use tracing::Instrument; +use super::{ + pagination::{PaginatedApiResponse, PaginatedApiResult, PaginationMeta, PaginationParams}, + DEFAULT_API_PAGE_SIZE, +}; use crate::{ appstate::AppState, auth::SessionInfo, db::{models::audit_log::AuditModule, Id}, }; -use super::{ - pagination::{PaginatedApiResponse, PaginatedApiResult, PaginationMeta, PaginationParams}, - DEFAULT_API_PAGE_SIZE, -}; - #[derive(Debug, Deserialize, Default)] pub struct FilterParams { pub from: Option>, diff --git a/crates/defguard_core/tests/integration/openid_login.rs b/crates/defguard_core/tests/integration/openid_login.rs index 7c4bfab65..68832f5a9 100644 --- a/crates/defguard_core/tests/integration/openid_login.rs +++ b/crates/defguard_core/tests/integration/openid_login.rs @@ -56,6 +56,7 @@ async fn test_openid_providers(_: PgPoolOptions, options: PgConnectOptions) { okta_private_jwk: None, directory_sync_group_match: None, username_handling: OpenidUsernameHandling::PruneEmailDomain, + use_openid_for_mfa: false, }; let response = client diff --git a/crates/defguard_event_logger/src/lib.rs b/crates/defguard_event_logger/src/lib.rs index 500e975c0..30e8085e8 100644 --- a/crates/defguard_event_logger/src/lib.rs +++ b/crates/defguard_event_logger/src/lib.rs @@ -1,12 +1,4 @@ use bytes::Bytes; -use error::EventLoggerError; -use message::{ - DefguardEvent, EnrollmentEvent, EventContext, EventLoggerMessage, LoggerEvent, VpnEvent, -}; -use sqlx::PgPool; -use tokio::sync::mpsc::UnboundedReceiver; -use tracing::{debug, error, info, trace}; - use defguard_core::db::{ models::audit_log::{ metadata::{ @@ -27,6 +19,13 @@ use defguard_core::db::{ }, NoId, }; +use error::EventLoggerError; +use message::{ + DefguardEvent, EnrollmentEvent, EventContext, EventLoggerMessage, LoggerEvent, VpnEvent, +}; +use sqlx::PgPool; +use tokio::sync::mpsc::UnboundedReceiver; +use tracing::{debug, error, info, trace}; pub mod error; pub mod message; diff --git a/crates/defguard_event_logger/src/message.rs b/crates/defguard_event_logger/src/message.rs index 476e5c0ae..1ee5d8087 100644 --- a/crates/defguard_event_logger/src/message.rs +++ b/crates/defguard_event_logger/src/message.rs @@ -1,6 +1,6 @@ -use chrono::NaiveDateTime; use std::net::IpAddr; +use chrono::NaiveDateTime; use defguard_core::{ db::{ models::{authentication_key::AuthenticationKey, oauth2client::OAuth2Client}, @@ -9,7 +9,10 @@ use defguard_core::{ enterprise::db::models::{ api_tokens::ApiToken, audit_stream::AuditStream, openid_provider::OpenIdProvider, }, - events::{ApiRequestContext, BidiRequestContext, GrpcRequestContext, InternalEventContext}, + events::{ + ApiRequestContext, BidiRequestContext, ClientMFAMethod, GrpcRequestContext, + InternalEventContext, + }, }; /// Messages that can be sent to the event logger @@ -271,7 +274,7 @@ pub enum VpnEvent { ConnectedToMfaLocation { location: WireguardNetwork, device: Device, - method: MFAMethod, + method: ClientMFAMethod, }, DisconnectedFromMfaLocation { location: WireguardNetwork, @@ -280,7 +283,7 @@ pub enum VpnEvent { MfaFailed { location: WireguardNetwork, device: Device, - method: MFAMethod, + method: ClientMFAMethod, }, ConnectedToLocation { location: WireguardNetwork, diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index c2c1b73de..a423ef3f9 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -28,10 +28,16 @@ //! event_tx.send(event).await.unwrap(); //! ``` -use defguard_core::events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}; +use std::sync::Arc; + +use defguard_core::{ + db::GatewayEvent, + events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, + mail::Mail, +}; +use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use error::EventRouterError; use events::Event; -use std::sync::Arc; use tokio::sync::{ broadcast::Sender, mpsc::{UnboundedReceiver, UnboundedSender}, @@ -39,9 +45,6 @@ use tokio::sync::{ }; use tracing::{debug, error, info}; -use defguard_core::{db::GatewayEvent, mail::Mail}; -use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; - mod error; mod events; mod handlers; diff --git a/proto b/proto index 20fe30dfa..eb4ac0620 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 20fe30dfa1c2985bb7a6afe1c74dd9a709e034c6 +Subproject commit eb4ac0620f54bfa58669f2ac61ea5fce5c55b521 diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index 9de72cd3f..fa09a5949 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -1238,6 +1238,11 @@ Licensing information: [https://docs.defguard.net/enterprise/license](https://do helper: 'If this option is enabled, Defguard automatically creates new accounts for users who log in for the first time using an external OpenID provider. Otherwise, the user account must first be created by an administrator.', }, + useOpenIdForMfa: { + label: 'Use external OpenID for client MFA', + helper: + 'When the external OpenID SSO Multi-Factor (MFA) process is enabled, users connecting to VPN locations that require MFA will need to authenticate via their browser using the configured provider for each connection. If this setting is disabled, MFA for those VPN locations will be handled through the internal Defguard SSO system. In that case, users must have TOTP or email-based MFA configured in their profile.', + }, usernameHandling: { label: 'Username handling', helper: diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index fc1b5afde..d28593106 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -3050,6 +3050,16 @@ type RootTranslation = { */ helper: string } + useOpenIdForMfa: { + /** + * U​s​e​ ​e​x​t​e​r​n​a​l​ ​O​p​e​n​I​D​ ​f​o​r​ ​c​l​i​e​n​t​ ​M​F​A + */ + label: string + /** + * W​h​e​n​ ​t​h​e​ ​e​x​t​e​r​n​a​l​ ​O​p​e​n​I​D​ ​S​S​O​ ​M​u​l​t​i​-​F​a​c​t​o​r​ ​(​M​F​A​)​ ​p​r​o​c​e​s​s​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​c​o​n​n​e​c​t​i​n​g​ ​t​o​ ​V​P​N​ ​l​o​c​a​t​i​o​n​s​ ​t​h​a​t​ ​r​e​q​u​i​r​e​ ​M​F​A​ ​w​i​l​l​ ​n​e​e​d​ ​t​o​ ​a​u​t​h​e​n​t​i​c​a​t​e​ ​v​i​a​ ​t​h​e​i​r​ ​b​r​o​w​s​e​r​ ​u​s​i​n​g​ ​t​h​e​ ​c​o​n​f​i​g​u​r​e​d​ ​p​r​o​v​i​d​e​r​ ​f​o​r​ ​e​a​c​h​ ​c​o​n​n​e​c​t​i​o​n​.​ ​I​f​ ​t​h​i​s​ ​s​e​t​t​i​n​g​ ​i​s​ ​d​i​s​a​b​l​e​d​,​ ​M​F​A​ ​f​o​r​ ​t​h​o​s​e​ ​V​P​N​ ​l​o​c​a​t​i​o​n​s​ ​w​i​l​l​ ​b​e​ ​h​a​n​d​l​e​d​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​i​n​t​e​r​n​a​l​ ​D​e​f​g​u​a​r​d​ ​S​S​O​ ​s​y​s​t​e​m​.​ ​I​n​ ​t​h​a​t​ ​c​a​s​e​,​ ​u​s​e​r​s​ ​m​u​s​t​ ​h​a​v​e​ ​T​O​T​P​ ​o​r​ ​e​m​a​i​l​-​b​a​s​e​d​ ​M​F​A​ ​c​o​n​f​i​g​u​r​e​d​ ​i​n​ ​t​h​e​i​r​ ​p​r​o​f​i​l​e​. + */ + helper: string + } usernameHandling: { /** * U​s​e​r​n​a​m​e​ ​h​a​n​d​l​i​n​g @@ -9391,6 +9401,16 @@ export type TranslationFunctions = { */ helper: () => LocalizedString } + useOpenIdForMfa: { + /** + * Use external OpenID for client MFA + */ + label: () => LocalizedString + /** + * When the external OpenID SSO Multi-Factor (MFA) process is enabled, users connecting to VPN locations that require MFA will need to authenticate via their browser using the configured provider for each connection. If this setting is disabled, MFA for those VPN locations will be handled through the internal Defguard SSO system. In that case, users must have TOTP or email-based MFA configured in their profile. + */ + helper: () => LocalizedString + } usernameHandling: { /** * Username handling diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index 8b2f55424..384cb9935 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -1119,6 +1119,11 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe helper: 'Jeśli ta opcja jest włączona, Defguard automatycznie tworzy nowe konta dla użytkowników, którzy logują się po raz pierwszy za pomocą zewnętrznego dostawcy OpenID. W innym przypadku konto użytkownika musi zostać najpierw utworzone przez administratora.', }, + useOpenIdForMfa: { + label: 'Używaj zewnętrznego OpenID dla MFA klienta', + helper: + 'Gdy zewnętrzny proces Multi-Factor Authentication (MFA) OpenID SSO jest włączony, użytkownicy łączący się z lokalizacjami VPN wymagającymi MFA będą musieli uwierzytelniać się przez swoją przeglądarkę używając skonfigurowanego dostawcy dla każdego połączenia. Jeśli to ustawienie jest wyłączone, MFA dla tych lokalizacji VPN będzie obsługiwane przez wewnętrzny system SSO Defguard. W takim przypadku użytkownicy muszą mieć skonfigurowane TOTP lub MFA oparte na e-mailu.', + }, usernameHandling: { label: 'Obsługa nazw użytkowników', helper: diff --git a/web/src/pages/settings/components/OpenIdSettings/components/OpenIdGeneralSettings.tsx b/web/src/pages/settings/components/OpenIdSettings/components/OpenIdGeneralSettings.tsx index 815b6bcce..369e2b3b4 100644 --- a/web/src/pages/settings/components/OpenIdSettings/components/OpenIdGeneralSettings.tsx +++ b/web/src/pages/settings/components/OpenIdSettings/components/OpenIdGeneralSettings.tsx @@ -22,6 +22,14 @@ export const OpenIdGeneralSettings = ({ isLoading }: { isLoading: boolean }) => control, name: 'create_account', }) as boolean; + const use_openid_for_mfa = useWatch({ + control, + name: 'use_openid_for_mfa', + }) as boolean; + const providerName = useWatch({ + control, + name: 'name', + }) as string; const options: SelectOption[] = useMemo( () => [ @@ -44,13 +52,17 @@ export const OpenIdGeneralSettings = ({ isLoading }: { isLoading: boolean }) => [localLL.general.usernameHandling.options], ); + const providerConfigured = useMemo(() => { + return providerName !== ''; + }, [providerName]); + return (

{localLL.general.title()}

{parse(localLL.general.helper())}
-
+
{/* FIXME: Really buggy when using the controller, investigate why */} /> {localLL.general.createAccount.helper()}
+
+ {/* FIXME: Really buggy when using the controller, investigate why */} + { + setValue('use_openid_for_mfa', e); + }} + disabled={isLoading || !providerConfigured} + /> + {localLL.general.useOpenIdForMfa.helper()} +
{ okta_private_jwk: z.string(), okta_dirsync_client_id: z.string(), directory_sync_group_match: z.string(), + use_openid_for_mfa: z.boolean(), }) .superRefine((val, ctx) => { if (val.name === '') { @@ -175,6 +178,7 @@ export const OpenIdSettingsForm = () => { okta_dirsync_client_id: '', directory_sync_group_match: '', username_handling: 'RemoveForbidden', + use_openid_for_mfa: false, }; if (openidData) { diff --git a/web/src/pages/settings/components/OpenIdSettings/components/style.scss b/web/src/pages/settings/components/OpenIdSettings/components/style.scss index ef6f2e6a8..017b6c9d4 100644 --- a/web/src/pages/settings/components/OpenIdSettings/components/style.scss +++ b/web/src/pages/settings/components/OpenIdSettings/components/style.scss @@ -7,6 +7,10 @@ padding-bottom: var(--spacing-s); } + .checkbox-padding { + padding-bottom: var(--spacing-s); + } + #sync-not-supported { text-align: center; margin: var(--spacing-s) 0; diff --git a/web/src/pages/settings/style.scss b/web/src/pages/settings/style.scss index 8a7133b6d..81d10988a 100644 --- a/web/src/pages/settings/style.scss +++ b/web/src/pages/settings/style.scss @@ -45,6 +45,9 @@ display: flex; align-items: center; gap: var(--spacing-xs); + .labeled-checkbox { + padding-bottom: 0; + } } section {