diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index 1bb93a409..79c3a8783 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -28,6 +28,8 @@ use crate::{ model::ldap_sync_allowed_for_user, utils::{ldap_add_users_to_groups, ldap_delete_users, ldap_remove_users_from_groups}, }, + license::get_cached_license, + limits::{get_counts, update_counts}, }, grpc::GatewayEvent, handlers::user::check_username, @@ -644,6 +646,11 @@ async fn sync_all_users_state( let mut modified_users = Vec::new(); let mut deleted_users = Vec::new(); let mut created_users = Vec::new(); + let mut user_count = get_counts().user(); + let user_limit = get_cached_license() + .as_ref() + .and_then(|license| license.limits.as_ref().map(|limits| limits.users)); + let mut blocked_import_notification_sent = false; sync_inactive_directory_users( &mut transaction, @@ -725,7 +732,20 @@ async fn sync_all_users_state( details.phone_number.clone(), ); user.openid_sub.clone_from(&directory_user.id); + if let Some(limit) = user_limit.filter(|limit| user_count >= *limit) { + error!( + "Skipping directory sync import of user {} (email: {}) because \ + license user limit has been reached ({}/{})", + user.username, user.email, user_count, limit + ); + if !blocked_import_notification_sent { + blocked_import_notification_sent = true; + // TODO: send emails + } + continue; + } let new_user = user.save(&mut *transaction).await?; + user_count += 1; created_users.push(new_user); } } @@ -860,6 +880,7 @@ async fn sync_all_users_state( debug!("Done processing missing users"); transaction.commit().await?; + update_counts(pool).await?; // trigger LDAP sync ldap_delete_users(deleted_users.iter().collect::>(), pool).await; diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 671a58e3c..321218894 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -19,7 +19,14 @@ mod test { use tokio::sync::broadcast; use super::super::*; - use crate::enterprise::db::models::openid_provider::{DirectorySyncTarget, OpenIdProviderKind}; + use crate::{ + enterprise::{ + db::models::openid_provider::{DirectorySyncTarget, OpenIdProviderKind}, + license::{License, LicenseTier, set_cached_license}, + limits::{get_counts, update_counts}, + }, + grpc::proto::enterprise::license::LicenseLimits, + }; async fn get_test_network(pool: &PgPool) -> WireguardNetwork { WireguardNetwork::find_by_name(pool, "test") @@ -855,4 +862,55 @@ mod test { // No events assert!(wg_rx.try_recv().is_err()); } + + #[sqlx::test] + async fn test_users_prefetch_respects_license_user_limit( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + + let config = DefGuardConfig::new_test_config(); + let _ = SERVER_CONFIG.set(config.clone()); + let (wg_tx, mut wg_rx) = broadcast::channel::(16); + + // enable prefetching users + make_test_provider( + &pool, + DirectorySyncUserBehavior::Keep, + DirectorySyncUserBehavior::Keep, + DirectorySyncTarget::All, + true, + ) + .await; + + let user_limit = 1; + let license = License::new( + "test".to_string(), + false, + None, + Some(LicenseLimits { + users: user_limit, + devices: 100, + locations: 100, + network_devices: Some(100), + }), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + update_counts(&pool).await.unwrap(); + + do_directory_sync(&pool, &wg_tx).await.unwrap(); + update_counts(&pool).await.unwrap(); + + let user_count = get_counts().user(); + assert!(user_count <= user_limit); + + let defguard_users = User::all(&pool).await.unwrap(); + assert_eq!(defguard_users.len(), user_limit as usize); + + // No events + assert!(wg_rx.try_recv().is_err()); + } } diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index 83cc56cea..4d8aea9c7 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -38,8 +38,10 @@ use crate::{ appstate::AppState, enterprise::{ db::models::openid_provider::OpenIdProvider, - directory_sync::sync_user_groups_if_configured, ldap::utils::ldap_update_user_state, - limits::update_counts, + directory_sync::sync_user_groups_if_configured, + ldap::utils::ldap_update_user_state, + license::get_cached_license, + limits::{get_counts, update_counts}, }, error::WebError, handlers::{ @@ -94,6 +96,16 @@ pub fn prune_username(username: &str, handling: OpenIdUsernameHandling) -> Strin result } +fn reached_user_license_limit() -> Option<(u32, u32)> { + let user_count = get_counts().user(); + let user_limit = get_cached_license() + .as_ref() + .and_then(|license| license.limits.as_ref().map(|limits| limits.users)); + user_limit + .filter(|limit| user_count >= *limit) + .map(|limit| (user_count, limit)) +} + /// Create HTTP client and prevent following redirects fn get_async_http_client() -> Result { reqwest::Client::builder() @@ -365,6 +377,19 @@ pub async fn user_from_claims( ))); } + if let Some((user_count, limit)) = reached_user_license_limit() { + error!( + "Skipping OpenID account creation for user {} (email: {}) because \ + license user limit has been reached ({}/{})", + username, + email.as_str(), + user_count, + limit + ); + // TODO: send emails + return Err(WebError::Forbidden("License limit reached.".into())); + } + // Extract all necessary information from the token or call the userinfo endpoint. let given_name = token_claims .given_name() @@ -644,6 +669,14 @@ pub(crate) async fn auth_callback( #[cfg(test)] mod test { + use crate::{ + enterprise::{ + license::{License, LicenseTier, set_cached_license}, + limits::{Counts, set_counts}, + }, + grpc::proto::enterprise::license::LicenseLimits, + }; + use super::*; #[test] @@ -762,4 +795,62 @@ mod test { let extracted = extract_state_data(&encoded); assert_eq!(extracted, Some("data.with.dots".to_string())); } + + #[test] + fn test_reached_user_license_limit_reached() { + set_counts(Counts::new(2, 0, 0, 0)); + let license = License::new( + "test".to_string(), + false, + None, + Some(LicenseLimits { + users: 2, + devices: 100, + locations: 100, + network_devices: Some(100), + }), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + assert_eq!(reached_user_license_limit(), Some((2, 2))); + } + + #[test] + fn test_reached_user_license_limit_not_reached() { + set_counts(Counts::new(1, 0, 0, 0)); + let license = License::new( + "test".to_string(), + false, + None, + Some(LicenseLimits { + users: 2, + devices: 100, + locations: 100, + network_devices: Some(100), + }), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + assert_eq!(reached_user_license_limit(), None); + } + + #[test] + fn test_reached_user_license_limit_unlimited() { + set_counts(Counts::new(100, 0, 0, 0)); + let license = License::new( + "test".to_string(), + false, + None, + None, + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + assert_eq!(reached_user_license_limit(), None); + } } diff --git a/crates/defguard_core/src/enterprise/ldap/error.rs b/crates/defguard_core/src/enterprise/ldap/error.rs index bec3932cf..dfcf9a52c 100644 --- a/crates/defguard_core/src/enterprise/ldap/error.rs +++ b/crates/defguard_core/src/enterprise/ldap/error.rs @@ -28,6 +28,8 @@ pub enum LdapError { InvalidUsername(String), #[error("LDAP object already exists: {0}")] ObjectAlreadyExists(String), + #[error("License user limit reached: {0}/{1}")] + LicenseUserLimitReached(u32, u32), #[error("User {0} does not belong to the defined synchronization groups in {1}")] UserNotInLDAPSyncGroups(String, &'static str), } diff --git a/crates/defguard_core/src/enterprise/ldap/sync.rs b/crates/defguard_core/src/enterprise/ldap/sync.rs index 45416ec14..64086376b 100644 --- a/crates/defguard_core/src/enterprise/ldap/sync.rs +++ b/crates/defguard_core/src/enterprise/ldap/sync.rs @@ -66,9 +66,13 @@ use sqlx::{PgConnection, PgPool}; use super::{LDAPConfig, error::LdapError}; use crate::{ - enterprise::ldap::model::{ - get_users_without_ldap_path, ldap_sync_allowed_for_user, update_from_ldap_user, - user_from_searchentry, + enterprise::{ + ldap::model::{ + get_users_without_ldap_path, ldap_sync_allowed_for_user, update_from_ldap_user, + user_from_searchentry, + }, + license::get_cached_license, + limits::{get_counts, update_counts}, }, hashset, }; @@ -806,6 +810,13 @@ impl super::LDAPConnection { ) -> Result<(), LdapError> { let mut transaction = pool.begin().await?; let mut admin_count = User::find_admins(&mut *transaction).await?.len(); + let mut user_count = get_counts().user(); + + let user_limit = get_cached_license() + .as_ref() + .and_then(|license| license.limits.as_ref().map(|limits| limits.users)); + let mut blocked_import_notification_sent = false; + for user in changes.delete_defguard { if user.is_admin(&mut *transaction).await? { if admin_count == 1 { @@ -849,12 +860,27 @@ impl super::LDAPConnection { "LDAP user {} does not exist in Defguard yet, adding...", user.username ); + if let Some(limit) = user_limit.filter(|limit| user_count >= *limit) { + error!( + "Skipping LDAP import of user {} (email: {}) because license user limit \ + has been reached ({}/{})", + user.username, user.email, user_count, limit + ); + if !blocked_import_notification_sent { + blocked_import_notification_sent = true; + // TODO: send emails + } + continue; + } user.save(&mut *transaction).await?; + user_count += 1; } } transaction.commit().await?; + update_counts(pool).await?; + for user in changes.delete_ldap { debug!("Deleting user {} from LDAP", user.username); self.delete_user(&user).await?; diff --git a/crates/defguard_core/src/enterprise/ldap/tests.rs b/crates/defguard_core/src/enterprise/ldap/tests.rs index d8b01b4c2..ed342fcb1 100644 --- a/crates/defguard_core/src/enterprise/ldap/tests.rs +++ b/crates/defguard_core/src/enterprise/ldap/tests.rs @@ -5,14 +5,19 @@ use ldap3::SearchEntry; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; -use crate::enterprise::ldap::{ - model::{extract_rdn_value, get_users_without_ldap_path, user_from_searchentry}, - sync::{ - Authority, compute_group_sync_changes, compute_user_sync_changes, - extract_intersecting_users, +use crate::enterprise::license::{License, LicenseTier, set_cached_license}; +use crate::enterprise::{ + ldap::{ + model::{extract_rdn_value, get_users_without_ldap_path, user_from_searchentry}, + sync::{ + Authority, compute_group_sync_changes, compute_user_sync_changes, + extract_intersecting_users, + }, + test_client::{LdapEvent, group_to_test_attrs, user_to_test_attrs}, }, - test_client::{LdapEvent, group_to_test_attrs, user_to_test_attrs}, + limits::get_counts, }; +use crate::grpc::proto::enterprise::license::LicenseLimits; const PASSWORD: &str = "test_password"; @@ -2433,6 +2438,128 @@ async fn test_sync_group_membership_with_intersecting_users( assert!(ldap_conn.test_client.get_events().is_empty()); } +#[sqlx::test] +async fn test_sync_ldap_to_defguard_does_not_exceed_user_license_limit( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let _ = initialize_current_settings(&pool).await; + + let user_limit = 1; + // set license with some limits + let license = License::new( + "test".to_string(), + false, + None, + Some(LicenseLimits { + users: user_limit, + devices: 100, + locations: 100, + network_devices: Some(100), + }), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + let existing_user = make_test_user("existing_user", None, None); + existing_user.save(&pool).await.unwrap(); + + crate::enterprise::limits::update_counts(&pool) + .await + .unwrap(); + + let mut ldap_conn = super::LDAPConnection::create().await.unwrap(); + let config = ldap_conn.config.clone(); + + let mut ldap_only_user = make_test_user("ldap_only_user_limit", None, None); + ldap_only_user.ldap_rdn = Some("ldap_only_user_limit".to_string()); + ldap_only_user.ldap_user_path = Some("ou=users,dc=example,dc=com".to_string()); + ldap_conn + .test_client_mut() + .add_test_user(&ldap_only_user, &config); + + ldap_conn.sync(&pool, false).await.unwrap(); + + let user_count_after_sync = get_counts().user(); + + assert!( + user_count_after_sync <= user_limit, + "LDAP sync exceeded user license limit: users={}, limit={}", + user_count_after_sync, + user_limit + ); + + let skipped_user = User::find_by_username(&pool, "ldap_only_user_limit") + .await + .unwrap(); + assert!(skipped_user.is_none()); +} + +#[sqlx::test] +async fn test_ldap_login_does_not_create_user_when_user_license_limit_is_reached( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let _ = initialize_current_settings(&pool).await; + + let user_limit = 1; + let license = License::new( + "test".to_string(), + false, + None, + Some(LicenseLimits { + users: user_limit, + devices: 100, + locations: 100, + network_devices: Some(100), + }), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + let existing_user = make_test_user("existing_user_ldap_login", None, None); + existing_user.save(&pool).await.unwrap(); + crate::enterprise::limits::update_counts(&pool) + .await + .unwrap(); + + let mut ldap_conn = super::LDAPConnection::create().await.unwrap(); + let config = ldap_conn.config.clone(); + + let mut ldap_only_user = make_test_user("ldap_login_only_user_limit", None, None); + ldap_only_user.ldap_rdn = Some("ldap_login_only_user_limit".to_string()); + ldap_only_user.ldap_user_path = Some("ou=users,dc=example,dc=com".to_string()); + ldap_conn + .test_client_mut() + .add_test_user(&ldap_only_user, &config); + + let login_result = super::utils::login_through_ldap_with_connection( + &pool, + &mut ldap_conn, + "ldap_login_only_user_limit", + PASSWORD, + ) + .await; + + assert!( + matches!( + login_result, + Err(LdapError::LicenseUserLimitReached(count, limit)) + if count == user_limit && limit == user_limit + ), + "Expected LDAP login to fail with license limit reached, got: {login_result:?}" + ); + + let skipped_user = User::find_by_username(&pool, "ldap_login_only_user_limit") + .await + .unwrap(); + assert!(skipped_user.is_none()); +} + #[sqlx::test] async fn test_get_empty_user_path(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; diff --git a/crates/defguard_core/src/enterprise/ldap/utils.rs b/crates/defguard_core/src/enterprise/ldap/utils.rs index 2ae32db26..56a0476d3 100644 --- a/crates/defguard_core/src/enterprise/ldap/utils.rs +++ b/crates/defguard_core/src/enterprise/ldap/utils.rs @@ -10,7 +10,21 @@ use defguard_common::db::{ use sqlx::PgPool; use super::{LDAPConnection, error::LdapError}; -use crate::enterprise::ldap::{model::ldap_sync_allowed_for_user, with_ldap_status}; +use crate::enterprise::{ + ldap::{model::ldap_sync_allowed_for_user, with_ldap_status}, + license::get_cached_license, + limits::get_counts, +}; + +fn reached_user_license_limit() -> Option<(u32, u32)> { + let user_count = get_counts().user(); + let user_limit = get_cached_license() + .as_ref() + .and_then(|license| license.limits.as_ref().map(|limits| limits.users)); + user_limit + .filter(|limit| user_count >= *limit) + .map(|limit| (user_count, limit)) +} /// Retrieves a user from LDAP if they are in the configured LDAP sync groups. /// @@ -22,6 +36,16 @@ pub(crate) async fn login_through_ldap( ) -> Result, LdapError> { debug!("Logging in user {username} through LDAP"); let mut ldap_connection = LDAPConnection::create().await?; + login_through_ldap_with_connection(pool, &mut ldap_connection, username, password).await +} + +pub(crate) async fn login_through_ldap_with_connection( + pool: &PgPool, + ldap_connection: &mut LDAPConnection, + username: &str, + password: &str, +) -> Result, LdapError> { + debug!("Logging in user {username} through LDAP"); let mut ldap_user = ldap_connection .get_user_by_credentials(username, password) .await?; @@ -49,6 +73,16 @@ pub(crate) async fn login_through_ldap( debug!( "User {ldap_user} doesn't exist in Defguard, creating them first based on LDAP data" ); + if let Some((user_count, limit)) = reached_user_license_limit() { + error!( + "Skipping LDAP account creation for user {} (email: {}) because license user \ + limit has been reached ({}/{})", + ldap_user.username, ldap_user.email, user_count, limit + ); + // TODO: send email to admins + return Err(LdapError::LicenseUserLimitReached(user_count, limit)); + } + ldap_user.from_ldap = true; ldap_user.save(pool).await? }; diff --git a/crates/defguard_core/src/handlers/auth.rs b/crates/defguard_core/src/handlers/auth.rs index c9839bb1b..dc6cb8494 100644 --- a/crates/defguard_core/src/handlers/auth.rs +++ b/crates/defguard_core/src/handlers/auth.rs @@ -37,7 +37,7 @@ use crate::{ SessionExtractor, SessionInfo, failed_login::{check_failed_logins, log_failed_login_attempt}, }, - enterprise::ldap::utils::login_through_ldap, + enterprise::ldap::{error::LdapError, utils::login_through_ldap}, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, handlers::{ @@ -167,6 +167,9 @@ pub(crate) async fn authenticate( .await { Ok(user) => user, + Err(LdapError::LicenseUserLimitReached(_, _)) => { + return Err(WebError::Forbidden("License limit reached.".into())); + } Err(ldap_err) => { warn!( "Failed to authenticate user {username_or_email} internally and through LDAP. Internal error: {err}, LDAP error: {ldap_err}" @@ -214,6 +217,9 @@ pub(crate) async fn authenticate( debug!("User not found in DB, authenticating user {username_or_email} with LDAP"); match login_through_ldap(&appstate.pool, &username_or_email, &data.password).await { Ok(user) => user, + Err(LdapError::LicenseUserLimitReached(_, _)) => { + return Err(WebError::Forbidden("License limit reached.".into())); + } Err(err) => { info!("Failed to authenticate user {username_or_email} with LDAP: {err}"); log_failed_login_attempt(&appstate.failed_logins, &username_or_email); diff --git a/crates/defguard_core/src/handlers/mail.rs b/crates/defguard_core/src/handlers/mail.rs index 40491ad0f..7865bbf16 100644 --- a/crates/defguard_core/src/handlers/mail.rs +++ b/crates/defguard_core/src/handlers/mail.rs @@ -15,6 +15,8 @@ use defguard_mail::{ }; use reqwest::Url; use serde_json::json; +use sqlx::query_scalar; +use tera::Context; use tokio::fs::read_to_string; use super::{ApiResponse, ApiResult}; @@ -210,6 +212,37 @@ pub async fn send_gateway_reconnected_email( Ok(()) } +pub async fn get_admins_emails(pool: &PgPool) -> Result, sqlx::Error> { + debug!("Getting emails of active admins"); + query_scalar::<_, String>( + " + SELECT u.email + FROM \"user\" u + WHERE u.is_active = true + AND EXISTS ( + SELECT 1 + FROM group_user gu + JOIN \"group\" g ON gu.group_id = g.id + WHERE g.is_admin = true AND gu.user_id = u.id + )", + ) + .fetch_all(pool) + .await +} + +pub async fn send_user_import_blocked_email(pool: &PgPool) -> Result<(), WebError> { + debug!("Sending blocked user import mail to all admin users"); + let admin_emails = get_admins_emails(pool).await?; + let mut conn = pool.acquire().await?; + + for email in admin_emails { + templates::user_import_blocked_mail(&email, &mut conn, Context::new()).await?; + debug!("Scheduled blocked user import mail to admin {}", email); + } + + Ok(()) +} + pub fn send_new_device_login_email( user_email: &str, session: &SessionContext, diff --git a/crates/defguard_mail/src/mail.rs b/crates/defguard_mail/src/mail.rs index 6977e3fb5..62ae0e8a9 100644 --- a/crates/defguard_mail/src/mail.rs +++ b/crates/defguard_mail/src/mail.rs @@ -293,6 +293,7 @@ pub enum MailMessage { MFACode, PasswordReset, PasswordResetDone, + UserImportBlocked, } impl MailMessage { @@ -314,6 +315,7 @@ impl MailMessage { Self::MFACode => "Defguard: Multi-Factor Authentication code for login", Self::PasswordReset => "Password reset", Self::PasswordResetDone => "Password reset success", + Self::UserImportBlocked => "User import blocked", } } @@ -334,6 +336,7 @@ impl MailMessage { Self::MFACode => "mfa-code", Self::PasswordReset => "password-reset", Self::PasswordResetDone => "password-reset-done", + Self::UserImportBlocked => "user-import-blocked", } } @@ -354,6 +357,7 @@ impl MailMessage { Self::MFACode => include_str!("../templates/mfa-code.mjml"), // Self::PasswordReset => "", // Self::PasswordResetDone => "", + Self::UserImportBlocked => include_str!("../templates/plain-notification.mjml"), _ => "", } } diff --git a/crates/defguard_mail/src/templates.rs b/crates/defguard_mail/src/templates.rs index 9b779376f..135f2a191 100644 --- a/crates/defguard_mail/src/templates.rs +++ b/crates/defguard_mail/src/templates.rs @@ -190,6 +190,21 @@ pub fn test_mail(session: Option<&SessionContext>) -> Result Result<(), TemplateError> { + debug!("Render a plain notification mail template for blocked user import."); + let (mut tera, mut context) = get_base_tera_mjml(context, None, None, None)?; + + let message = MailMessage::UserImportBlocked; + message.fill_context(conn, &mut context).await?; + message.mail(&mut tera, &context, to)?.send_and_forget(); + + Ok(()) +} + // Mail with link to enrollment service. pub async fn new_account_mail( to: &str, diff --git a/crates/defguard_mail/templates/plain-notification.mjml b/crates/defguard_mail/templates/plain-notification.mjml new file mode 100644 index 000000000..630ea5b77 --- /dev/null +++ b/crates/defguard_mail/templates/plain-notification.mjml @@ -0,0 +1,17 @@ +{% import "macros.mjml" as macros %} +{% extends "base.mjml" %} +{% block content %} + +{{ macros::email_header() }} + + + + + {{notification_text}} + + + + +{{ macros::footer_divider() }} + +{% endblock content %} diff --git a/migrations/20260223095641_mjml_user_import.down.sql b/migrations/20260223095641_mjml_user_import.down.sql new file mode 100644 index 000000000..017421bcb --- /dev/null +++ b/migrations/20260223095641_mjml_user_import.down.sql @@ -0,0 +1 @@ +DELETE from mail_context WHERE template = 'user-import-blocked'; diff --git a/migrations/20260223095641_mjml_user_import.up.sql b/migrations/20260223095641_mjml_user_import.up.sql new file mode 100644 index 000000000..d976ef13c --- /dev/null +++ b/migrations/20260223095641_mjml_user_import.up.sql @@ -0,0 +1,3 @@ +INSERT INTO mail_context (template, section, language_tag, text) VALUES + ('user-import-blocked', 'title', 'en_US', 'User import blocked'), + ('user-import-blocked', 'notification_text', 'en_US', 'Import of an external user was blocked because it would exceed your current license capacity.'); diff --git a/web/.nvmrc b/web/.nvmrc new file mode 100644 index 000000000..a682cfb97 --- /dev/null +++ b/web/.nvmrc @@ -0,0 +1 @@ +v25