diff --git a/src-tauri/.sqlx/query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json b/src-tauri/.sqlx/query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json similarity index 83% rename from src-tauri/.sqlx/query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json rename to src-tauri/.sqlx/query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json index eb35ee44..9c6f65ed 100644 --- a/src-tauri/.sqlx/query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json +++ b/src-tauri/.sqlx/query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location WHERE id = $1", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE id = $1", "describe": { "columns": [ { @@ -62,6 +62,11 @@ "name": "location_mfa_mode: LocationMfaMode", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "service_location_mode: ServiceLocationMode", + "ordinal": 12, + "type_info": "Integer" } ], "parameters": { @@ -79,8 +84,9 @@ false, false, false, + false, false ] }, - "hash": "e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11" + "hash": "76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175" } diff --git a/src-tauri/.sqlx/query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json b/src-tauri/.sqlx/query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json similarity index 83% rename from src-tauri/.sqlx/query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json rename to src-tauri/.sqlx/query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json index f5faadd8..1615e9b4 100644 --- a/src-tauri/.sqlx/query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json +++ b/src-tauri/.sqlx/query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location WHERE instance_id = $1 ORDER BY name ASC", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE pubkey = $1;", "describe": { "columns": [ { @@ -62,6 +62,11 @@ "name": "location_mfa_mode: LocationMfaMode", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "service_location_mode: ServiceLocationMode", + "ordinal": 12, + "type_info": "Integer" } ], "parameters": { @@ -79,8 +84,9 @@ false, false, false, + false, false ] }, - "hash": "7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d" + "hash": "85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6" } diff --git a/src-tauri/.sqlx/query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json b/src-tauri/.sqlx/query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json similarity index 80% rename from src-tauri/.sqlx/query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json rename to src-tauri/.sqlx/query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json index 6df78777..012a54b3 100644 --- a/src-tauri/.sqlx/query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json +++ b/src-tauri/.sqlx/query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location WHERE pubkey = $1;", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE instance_id = $1 AND service_location_mode <= $2 ORDER BY name ASC", "describe": { "columns": [ { @@ -62,10 +62,15 @@ "name": "location_mfa_mode: LocationMfaMode", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "service_location_mode: ServiceLocationMode", + "ordinal": 12, + "type_info": "Integer" } ], "parameters": { - "Right": 1 + "Right": 2 }, "nullable": [ false, @@ -79,8 +84,9 @@ false, false, false, + false, false ] }, - "hash": "ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f" + "hash": "9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056" } diff --git a/src-tauri/.sqlx/query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json b/src-tauri/.sqlx/query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json similarity index 64% rename from src-tauri/.sqlx/query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json rename to src-tauri/.sqlx/query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json index a994e60f..5163c8ca 100644 --- a/src-tauri/.sqlx/query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json +++ b/src-tauri/.sqlx/query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json @@ -1,12 +1,12 @@ { "db_name": "SQLite", - "query": "UPDATE location SET instance_id = $1, name = $2, address = $3, pubkey = $4, endpoint = $5, allowed_ips = $6, dns = $7, network_id = $8, route_all_traffic = $9, keepalive_interval = $10, location_mfa_mode = $11 WHERE id = $12", + "query": "UPDATE location SET instance_id = $1, name = $2, address = $3, pubkey = $4, endpoint = $5, allowed_ips = $6, dns = $7, network_id = $8, route_all_traffic = $9, keepalive_interval = $10, location_mfa_mode = $11, service_location_mode = $12 WHERE id = $13", "describe": { "columns": [], "parameters": { - "Right": 12 + "Right": 13 }, "nullable": [] }, - "hash": "3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741" + "hash": "b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e" } diff --git a/src-tauri/.sqlx/query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json b/src-tauri/.sqlx/query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json similarity index 59% rename from src-tauri/.sqlx/query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json rename to src-tauri/.sqlx/query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json index 3d77f025..a05f49a0 100644 --- a/src-tauri/.sqlx/query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json +++ b/src-tauri/.sqlx/query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "INSERT INTO location (instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id \"id!\"", + "query": "INSERT INTO location (instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode, service_location_mode) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id \"id!\"", "describe": { "columns": [ { @@ -10,11 +10,11 @@ } ], "parameters": { - "Right": 11 + "Right": 12 }, "nullable": [ true ] }, - "hash": "e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58" + "hash": "ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f" } diff --git a/src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json b/src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json deleted file mode 100644 index e895e552..00000000 --- a/src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "db_name": "SQLite", - "query": "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location ORDER BY name ASC;", - "describe": { - "columns": [ - { - "name": "id", - "ordinal": 0, - "type_info": "Integer" - }, - { - "name": "instance_id", - "ordinal": 1, - "type_info": "Integer" - }, - { - "name": "name", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "address", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "pubkey", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "endpoint", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "allowed_ips", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "dns", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "network_id", - "ordinal": 8, - "type_info": "Integer" - }, - { - "name": "route_all_traffic", - "ordinal": 9, - "type_info": "Bool" - }, - { - "name": "keepalive_interval", - "ordinal": 10, - "type_info": "Integer" - }, - { - "name": "location_mfa_mode: LocationMfaMode", - "ordinal": 11, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 0 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - true, - false, - false, - false, - false - ] - }, - "hash": "f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef" -} diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 95bd80ca..8e2748a7 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1379,6 +1379,7 @@ dependencies = [ "defguard_wireguard_rs", "dirs-next", "hyper-util", + "known-folders", "log", "nix", "prost", @@ -1420,6 +1421,8 @@ dependencies = [ "vergen-git2", "webbrowser", "winapi", + "windows 0.62.2", + "windows-acl", "windows-service", "x25519-dalek", ] @@ -3148,6 +3151,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "known-folders" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c644f4623d1c55eb60a9dac35e0858a59f982fb87db6ce34c872372b0a5b728f" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "kuchikiki" version = "0.8.8-speedreader" @@ -8179,6 +8191,18 @@ dependencies = [ "windows-numerics 0.3.1", ] +[[package]] +name = "windows-acl" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "177b1723986bcb4c606058e77f6e8614b51c7f9ad2face6f6fd63dd5c8b3cec3" +dependencies = [ + "field-offset", + "libc", + "widestring 0.4.3", + "winapi", +] + [[package]] name = "windows-collections" version = "0.2.0" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index e2d361c6..fadc3405 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -120,6 +120,13 @@ tower = "0.5" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["winsvc", "winerror"] } windows-service = "0.7" +known-folders = "1.3" +windows = { version = "0.62", features = [ + "Win32", + "Win32_System", + "Win32_System_RemoteDesktop", +] } +windows-acl = "0.3" [features] # this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled. diff --git a/src-tauri/migrations/20251009102408_service_locations.sql b/src-tauri/migrations/20251009102408_service_locations.sql new file mode 100644 index 00000000..aa1e67ba --- /dev/null +++ b/src-tauri/migrations/20251009102408_service_locations.sql @@ -0,0 +1,4 @@ +-- 1 - disabled +-- 2 - pre-logon +-- 3 - always-on +ALTER TABLE location ADD COLUMN service_location_mode INTEGER NOT NULL DEFAULT 1; diff --git a/src-tauri/proto b/src-tauri/proto index fa9c14ef..fee70601 160000 --- a/src-tauri/proto +++ b/src-tauri/proto @@ -1 +1 @@ -Subproject commit fa9c14efd121182ec39c8716370e1250c77fa652 +Subproject commit fee706013b3bb5452c3c4dbf35bd973d0637ff25 diff --git a/src-tauri/src/active_connections.rs b/src-tauri/src/active_connections.rs index ed0ef4b2..970a31bd 100644 --- a/src-tauri/src/active_connections.rs +++ b/src-tauri/src/active_connections.rs @@ -82,7 +82,7 @@ pub(crate) async fn find_connection( pub(crate) async fn active_connections( instance: &Instance, ) -> Result, Error> { - let locations: HashSet = Location::find_by_instance_id(&*DB_POOL, instance.id) + let locations: HashSet = Location::find_by_instance_id(&*DB_POOL, instance.id, false) .await? .iter() .map(|location| location.id) diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 2c440913..2e43bd75 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -37,7 +37,13 @@ use crate::{ service_log_watcher::stop_log_watcher_task, }, proto::DeviceConfigResponse, - service::{proto::RemoveInterfaceRequest, utils::DAEMON_CLIENT}, + service::{ + proto::{ + DeleteServiceLocationsRequest, RemoveInterfaceRequest, SaveServiceLocationsRequest, + ServiceLocation, + }, + utils::DAEMON_CLIENT, + }, tray::{configure_tray_icon, reload_tray_menu}, utils::{ disconnect_interface, execute_command, get_location_interface_details, @@ -284,14 +290,57 @@ pub async fn save_device_config( transaction.commit().await?; info!("New instance {instance} created."); trace!("Created following instance: {instance:#?}"); - let locations = Location::find_by_instance_id(&*DB_POOL, instance.id).await?; + let locations = Location::find_by_instance_id(&*DB_POOL, instance.id, true).await?; trace!("Created following locations: {locations:#?}"); + + let mut service_locations = Vec::::new(); + + for saved_location in &locations { + if saved_location.is_service_location() { + debug!( + "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", + saved_location.name, saved_location.id, instance.name, instance.id, + ); + service_locations.push(saved_location.to_service_location()?); + } + } + + if !service_locations.is_empty() { + let save_request = SaveServiceLocationsRequest { + service_locations: service_locations.clone(), + instance_id: instance.uuid.clone(), + private_key: keys.prvkey.clone(), + }; + debug!( + "Saving {} service locations to the daemon for instance {}({}).", + save_request.service_locations.len(), + instance.name, + instance.id, + ); + DAEMON_CLIENT + .clone() + .save_service_locations(save_request) + .await + .map_err(|err| { + error!( + "Error while saving service locations to the daemon for instance {}({}): {err}", + instance.name, instance.id, + ); + Error::InternalError(err.to_string()) + })?; + debug!( + "Saved service locations to the daemon for instance {}({}).", + instance.name, instance.id, + ); + } + handle.emit(EventKey::InstanceUpdate.into(), ())?; let res: SaveDeviceConfigResponse = SaveDeviceConfigResponse { locations, instance, }; reload_tray_menu(&handle).await; + Ok(res) } @@ -307,7 +356,7 @@ pub async fn all_instances() -> Result>, Error> { let mut instance_info = Vec::new(); let connection_ids = get_connection_id_by_type(ConnectionType::Location).await; for instance in instances { - let locations = Location::find_by_instance_id(&*DB_POOL, instance.id).await?; + let locations = Location::find_by_instance_id(&*DB_POOL, instance.id, false).await?; let location_ids: Vec = locations.iter().map(|location| location.id).collect(); let connected = connection_ids .iter() @@ -381,7 +430,7 @@ pub async fn all_locations(instance_id: Id) -> Result, Error> "Getting information about all locations for instance {}.", instance.name ); - let locations = Location::find_by_instance_id(&*DB_POOL, instance_id).await?; + let locations = Location::find_by_instance_id(&*DB_POOL, instance_id, false).await?; trace!( "Found {} locations for instance {instance} to return information about.", locations.len() @@ -471,7 +520,7 @@ pub(crate) async fn locations_changed( device_config: &DeviceConfigResponse, ) -> Result { let db_locations: HashSet> = - Location::find_by_instance_id(transaction.as_mut(), instance.id) + Location::find_by_instance_id(transaction.as_mut(), instance.id, true) .await? .into_iter() .map(|location| { @@ -533,6 +582,8 @@ pub(crate) async fn do_update_instance( "A new base configuration has been applied to instance {instance}, even if nothing changed" ); + let mut service_locations = Vec::::new(); + // check if locations have changed if locations_changed { // process locations received in response @@ -542,13 +593,13 @@ pub(crate) async fn do_update_instance( ); // fetch existing locations for given instance let mut current_locations = - Location::find_by_instance_id(transaction.as_mut(), instance.id).await?; + Location::find_by_instance_id(transaction.as_mut(), instance.id, true).await?; for dev_config in response.configs { // parse device config let new_location = dev_config.into_location(instance.id); // check if location is already present in current locations - if let Some(position) = current_locations + let saved_location = if let Some(position) = current_locations .iter() .position(|loc| loc.network_id == new_location.network_id) { @@ -567,13 +618,24 @@ pub(crate) async fn do_update_instance( current_location.keepalive_interval = new_location.keepalive_interval; current_location.dns = new_location.dns; current_location.location_mfa_mode = new_location.location_mfa_mode; + current_location.service_location_mode = new_location.service_location_mode; current_location.save(transaction.as_mut()).await?; info!("Location {current_location} configuration updated for instance {instance}"); + current_location } else { // create new location debug!("Creating new location {new_location} for instance instance {instance}"); let new_location = new_location.save(transaction.as_mut()).await?; info!("New location {new_location} created for instance {instance}"); + new_location + }; + + if saved_location.is_service_location() { + debug!( + "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", + saved_location.name, saved_location.id, instance.name, instance.id, + ); + service_locations.push(saved_location.to_service_location()?); } } @@ -590,6 +652,63 @@ pub(crate) async fn do_update_instance( } else { info!("Locations for instance {instance} didn't change. Not updating them."); } + + let private_key = WireguardKeys::find_by_instance_id(transaction.as_mut(), instance.id) + .await? + .ok_or(Error::NotFound)? + .prvkey; + + if service_locations.is_empty() { + debug!( + "No service locations to process for instance {}({})", + instance.name, instance.id + ); + } else { + debug!( + "Processing {} service location(s) for instance {}({})", + service_locations.len(), + instance.name, + instance.id + ); + + let save_request = SaveServiceLocationsRequest { + service_locations: service_locations.clone(), + instance_id: instance.uuid.clone(), + private_key: private_key.clone(), + }; + + debug!( + "Sending request to daemon to save {} service location(s) for instance {}({})", + save_request.service_locations.len(), + instance.name, + instance.id + ); + + DAEMON_CLIENT + .clone() + .save_service_locations(save_request) + .await + .map_err(|err| { + error!( + "Error while saving service locations to the daemon for instance {}({}): {err}", + instance.name, instance.id, + ); + Error::InternalError(err.to_string()) + })?; + + info!( + "Successfully saved {} service location(s) to daemon for instance {}({})", + service_locations.len(), + instance.name, + instance.id + ); + + debug!( + "Completed processing all service locations for instance {}({})", + instance.name, instance.id + ); + } + Ok(()) } @@ -813,7 +932,8 @@ pub async fn delete_instance(instance_id: Id, handle: AppHandle) -> Result<(), E }; debug!("The instance that is being deleted has been identified as {instance}"); - let instance_locations = Location::find_by_instance_id(&mut *transaction, instance_id).await?; + let instance_locations = + Location::find_by_instance_id(&mut *transaction, instance_id, false).await?; if !instance_locations.is_empty() { debug!( "Found locations associated with the instance {instance}, closing their connections." @@ -851,6 +971,14 @@ pub async fn delete_instance(instance_id: Id, handle: AppHandle) -> Result<(), E transaction.commit().await?; + DAEMON_CLIENT + .clone() + .delete_service_locations(DeleteServiceLocationsRequest { + instance_id: instance.uuid.clone(), + }) + .await + .unwrap(); + reload_tray_menu(&handle).await; handle.emit(EventKey::InstanceUpdate.into(), ())?; diff --git a/src-tauri/src/database/models/location.rs b/src-tauri/src/database/models/location.rs index 97d20dc6..005b5fdb 100644 --- a/src-tauri/src/database/models/location.rs +++ b/src-tauri/src/database/models/location.rs @@ -4,7 +4,12 @@ use serde::{Deserialize, Serialize}; use sqlx::{prelude::Type, query, query_as, query_scalar, Error as SqlxError, SqliteExecutor}; use super::{Id, NoId}; -use crate::{error::Error, proto::LocationMfaMode as ProtoLocationMfaMode}; +use crate::{ + error::Error, + proto::{ + LocationMfaMode as ProtoLocationMfaMode, ServiceLocationMode as ProtoServiceLocationMode, + }, +}; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] #[repr(u32)] @@ -27,6 +32,27 @@ impl From for LocationMfaMode { } } +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] +#[repr(u32)] +#[serde(rename_all = "lowercase")] +pub enum ServiceLocationMode { + Disabled = 1, + PreLogon = 2, + AlwaysOn = 3, +} + +impl From for ServiceLocationMode { + fn from(value: ProtoServiceLocationMode) -> Self { + match value { + ProtoServiceLocationMode::Unspecified | ProtoServiceLocationMode::Disabled => { + ServiceLocationMode::Disabled + } + ProtoServiceLocationMode::Prelogon => ServiceLocationMode::PreLogon, + ProtoServiceLocationMode::Alwayson => ServiceLocationMode::AlwaysOn, + } + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct Location { pub id: I, @@ -42,6 +68,7 @@ pub struct Location { pub route_all_traffic: bool, pub keepalive_interval: i64, pub location_mfa_mode: LocationMfaMode, + pub service_location_mode: ServiceLocationMode, } impl fmt::Display for Location { @@ -57,17 +84,24 @@ impl fmt::Display for Location { } impl Location { + /// Ignores service locations #[cfg(windows)] - pub(crate) async fn all<'e, E>(executor: E) -> Result, SqlxError> + pub(crate) async fn all<'e, E>( + executor: E, + include_service_locations: bool, + ) -> Result, SqlxError> where E: SqliteExecutor<'e>, { + let max_mode = if include_service_locations { 2 } else { 0 }; // 0 to exclude service locations, 2 to include them query_as!( Self, - "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,\ - route_all_traffic, keepalive_interval, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ - FROM location ORDER BY name ASC;" + "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,\ + route_all_traffic, keepalive_interval, \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ + FROM location WHERE service_location_mode <= $1 \ + ORDER BY name ASC;", + max_mode ) .fetch_all(executor) .await @@ -81,7 +115,7 @@ impl Location { query!( "UPDATE location SET instance_id = $1, name = $2, address = $3, pubkey = $4, \ endpoint = $5, allowed_ips = $6, dns = $7, network_id = $8, route_all_traffic = $9, \ - keepalive_interval = $10, location_mfa_mode = $11 WHERE id = $12", + keepalive_interval = $10, location_mfa_mode = $11, service_location_mode = $12 WHERE id = $13", self.instance_id, self.name, self.address, @@ -93,6 +127,7 @@ impl Location { self.route_all_traffic, self.keepalive_interval, self.location_mfa_mode, + self.service_location_mode, self.id, ) .execute(executor) @@ -112,7 +147,7 @@ impl Location { Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ network_id, route_all_traffic, keepalive_interval, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM location WHERE id = $1", location_id ) @@ -123,16 +158,20 @@ impl Location { pub(crate) async fn find_by_instance_id<'e, E>( executor: E, instance_id: Id, + include_service_locations: bool, ) -> Result, SqlxError> where E: SqliteExecutor<'e>, { + let max_mode = if include_service_locations { 2 } else { 0 }; // 0 to exclude service locations, 2 to include them query_as!( Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ - network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ - FROM location WHERE instance_id = $1 ORDER BY name ASC", - instance_id + network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ + FROM location WHERE instance_id = $1 AND service_location_mode <= $2 \ + ORDER BY name ASC", + instance_id, + max_mode ) .fetch_all(executor) .await @@ -148,7 +187,7 @@ impl Location { query_as!( Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ - network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ + network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM location WHERE pubkey = $1;", pubkey ) @@ -199,8 +238,8 @@ impl Location { // Insert a new record when there is no ID let id = query_scalar!( "INSERT INTO location (instance_id, name, address, pubkey, endpoint, allowed_ips, \ - dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode) \ - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) \ + dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode, service_location_mode) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) \ RETURNING id \"id!\"", self.instance_id, self.name, @@ -212,7 +251,8 @@ impl Location { self.network_id, self.route_all_traffic, self.keepalive_interval, - self.location_mfa_mode + self.location_mfa_mode, + self.service_location_mode, ) .fetch_one(executor) .await?; @@ -230,10 +270,18 @@ impl Location { route_all_traffic: self.route_all_traffic, keepalive_interval: self.keepalive_interval, location_mfa_mode: self.location_mfa_mode, + service_location_mode: self.service_location_mode, }) } } +impl Location { + pub fn is_service_location(&self) -> bool { + self.service_location_mode != ServiceLocationMode::Disabled + && self.location_mfa_mode == LocationMfaMode::Disabled + } +} + impl From> for Location { fn from(location: Location) -> Self { Self { @@ -249,6 +297,7 @@ impl From> for Location { route_all_traffic: location.route_all_traffic, keepalive_interval: location.keepalive_interval, location_mfa_mode: location.location_mfa_mode, + service_location_mode: location.service_location_mode, } } } diff --git a/src-tauri/src/enterprise/mod.rs b/src-tauri/src/enterprise/mod.rs index 98f9e5ef..8e1f8e8a 100644 --- a/src-tauri/src/enterprise/mod.rs +++ b/src-tauri/src/enterprise/mod.rs @@ -1,3 +1,4 @@ pub mod models; pub mod periodic; pub mod provisioning; +pub mod service_locations; diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs new file mode 100644 index 00000000..3a2b1098 --- /dev/null +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -0,0 +1,121 @@ +use std::collections::HashMap; + +use defguard_wireguard_rs::{error::WireguardInterfaceError, WGApi}; +use serde::{Deserialize, Serialize}; + +use crate::{ + database::models::{ + location::{Location, ServiceLocationMode}, + Id, + }, + service::proto::ServiceLocation, +}; + +#[cfg(windows)] +pub mod windows; + +#[derive(Debug, thiserror::Error)] +pub enum ServiceLocationError { + #[error("Error occurred while initializing service location API: {0}")] + InitError(String), + #[error("Failed to load service location storage: {0}")] + LoadError(String), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + DecodeError(#[from] base64::DecodeError), + #[error(transparent)] + WireGuardError(#[from] WireguardInterfaceError), + #[error(transparent)] + AddrParseError(#[from] defguard_wireguard_rs::net::IpAddrParseError), + #[error("WireGuard interface error: {0}")] + InterfaceError(String), + #[error(transparent)] + JsonError(#[from] serde_json::Error), + #[error(transparent)] + ProtoEnumError(#[from] prost::UnknownEnumValue), + #[cfg(windows)] + #[error(transparent)] + WindowsServiceError(#[from] windows_service::Error), +} + +#[allow(dead_code)] +#[derive(Default)] +pub(crate) struct ServiceLocationManager { + // Interface name: WireGuard API instance + wgapis: HashMap, + // Instance ID: Service locations connected under that instance + connected_service_locations: HashMap>, +} + +#[allow(dead_code)] +#[derive(Serialize, Deserialize)] +pub(crate) struct ServiceLocationData { + pub service_locations: Vec, + pub instance_id: String, + pub private_key: String, +} + +#[allow(dead_code)] +pub(crate) struct SingleServiceLocationData { + pub service_location: ServiceLocation, + pub instance_id: String, + pub private_key: String, +} + +impl std::fmt::Debug for ServiceLocationData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServiceLocationData") + .field("service_locations", &self.service_locations) + .field("instance_id", &self.instance_id) + .field("private_key", &"***") + .finish() + } +} + +impl std::fmt::Debug for SingleServiceLocationData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SingleServiceLocationData") + .field("service_locations", &self.service_location) + .field("instance_id", &self.instance_id) + .field("private_key", &"***") + .finish() + } +} + +impl Location { + pub fn to_service_location(&self) -> Result { + if !self.is_service_location() { + warn!("Location {self} is not a service location, so it can't be converted to one."); + return Err(crate::error::Error::ConversionError(format!( + "Failed to convert location {} to a service location as it's either not marked as one or has MFA enabled.", + self + ))); + } + + let mode = match self.service_location_mode { + ServiceLocationMode::Disabled => { + warn!( + "Location {} has an invalid service location mode, so it can't be converted to one.", + self + ); + return Err( + crate::error::Error::ConversionError(format!("Location {} has an invalid service location mode ({:?}), so it can't be converted to one.", self, self.service_location_mode)) + ); + } + ServiceLocationMode::PreLogon => 0, + ServiceLocationMode::AlwaysOn => 1, + }; + + Ok(ServiceLocation { + name: self.name.clone(), + address: self.address.clone(), + pubkey: self.pubkey.clone(), + endpoint: self.endpoint.clone(), + allowed_ips: self.allowed_ips.clone(), + dns: self.dns.clone().unwrap_or_default(), + keepalive_interval: self.keepalive_interval.try_into().unwrap_or(0), + mode, + }) + } +} diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs new file mode 100644 index 00000000..b8eb8b8c --- /dev/null +++ b/src-tauri/src/enterprise/service_locations/windows.rs @@ -0,0 +1,890 @@ +use std::{ + collections::HashMap, + fs::{self, create_dir_all}, + net::IpAddr, + path::PathBuf, + result::Result, + str::FromStr, + sync::{Arc, RwLock}, + time::Duration, +}; + +use common::{find_free_tcp_port, get_interface_name}; +use defguard_wireguard_rs::{ + host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WireguardInterfaceApi, +}; +use known_folders::get_known_folder_path; +use log::{debug, error, warn}; +use windows::{ + core::PSTR, + Win32::System::RemoteDesktop::{ + self, WTSQuerySessionInformationA, WTSWaitSystemEvent, WTS_CURRENT_SERVER_HANDLE, + WTS_EVENT_LOGOFF, WTS_EVENT_LOGON, WTS_SESSION_INFOA, + }, +}; +use windows_acl::acl::ACL; + +use crate::{ + enterprise::service_locations::{ + ServiceLocationData, ServiceLocationError, ServiceLocationManager, + SingleServiceLocationData, + }, + service::{ + proto::{ServiceLocation, ServiceLocationMode}, + setup_wgapi, + }, +}; + +const LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS: u64 = 5; +const DEFAULT_WIREGUARD_PORT: u16 = 51820; +const DEFGUARD_DIR: &str = "Defguard"; +const SERVICE_LOCATIONS_SUBDIR: &str = "service_locations"; + +pub(crate) async fn watch_for_login_logoff( + service_location_manager: Arc>, +) -> Result<(), ServiceLocationError> { + loop { + let mut event_flags = 0; + let success = unsafe { + WTSWaitSystemEvent( + Some(WTS_CURRENT_SERVER_HANDLE), + WTS_EVENT_LOGON | WTS_EVENT_LOGOFF, + &mut event_flags, + ) + }; + + match success { + Ok(_) => { + debug!("Waiting for system event returned with event_flags: 0x{event_flags:x}"); + } + Err(err) => { + error!("Failed waiting for login/logoff event: {err:?}"); + tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)).await; + continue; + } + }; + + if event_flags & WTS_EVENT_LOGON != 0 { + debug!("Detected user logon, attempting to auto-disconnect from service locations."); + service_location_manager + .clone() + .write() + .unwrap() + .disconnect_service_locations(Some(ServiceLocationMode::PreLogon))?; + } + if event_flags & WTS_EVENT_LOGOFF != 0 { + debug!("Detected user logoff, attempting to auto-connect to service locations."); + service_location_manager + .clone() + .write() + .unwrap() + .connect_to_service_locations()?; + } + } +} + +fn get_shared_directory() -> Result { + match get_known_folder_path(known_folders::KnownFolder::ProgramData) { + Some(mut path) => { + path.push(DEFGUARD_DIR); + path.push(SERVICE_LOCATIONS_SUBDIR); + Ok(path) + } + None => Err(ServiceLocationError::LoadError( + "Could not find ProgramData known folder".to_string(), + )), + } +} + +fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { + debug!("Setting secure ACLs on: {path}"); + + const SYSTEM_SID: &str = "S-1-5-18"; // NT AUTHORITY\SYSTEM + const ADMINISTRATORS_SID: &str = "S-1-5-32-544"; // BUILTIN\Administrators + + const FILE_ALL_ACCESS: u32 = 0x1F01FF; + + match ACL::from_file_path(path, false) { + Ok(mut acl) => { + // Remove everything else from access + debug!("Removing all existing ACL entries for {path}"); + let all_entries = acl.all().map_err(|e| { + ServiceLocationError::LoadError(format!("Failed to get ACL entries: {e}")) + })?; + + for entry in all_entries { + if let Some(sid) = entry.sid { + if let Err(e) = acl.remove(sid.as_ptr() as *mut _, None, None) { + debug!("Note: Could not remove ACL entry (might be expected): {e}"); + } + } + } + + debug!("Cleared existing ACL entries, now adding secure entries"); + + // Add SYSTEM with full control + debug!("Adding SYSTEM with full control"); + let system_sid_result = windows_acl::helper::string_to_sid(SYSTEM_SID); + match system_sid_result { + Ok(system_sid) => { + acl.allow(system_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) + .map_err(|e| { + ServiceLocationError::LoadError(format!( + "Failed to add SYSTEM ACL: {e}" + )) + })?; + } + Err(e) => { + return Err(ServiceLocationError::LoadError(format!( + "Failed to convert SYSTEM SID: {e}" + ))); + } + } + + // Add Administrators with full control + debug!("Adding Administrators with full control"); + let admin_sid_result = windows_acl::helper::string_to_sid(ADMINISTRATORS_SID); + match admin_sid_result { + Ok(admin_sid) => { + acl.allow(admin_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) + .map_err(|e| { + ServiceLocationError::LoadError(format!( + "Failed to add Administrators ACL: {e}" + )) + })?; + } + Err(e) => { + return Err(ServiceLocationError::LoadError(format!( + "Failed to convert Administrators SID: {e}" + ))); + } + } + + debug!("Successfully set secure ACLs on {path} for SYSTEM and Administrators"); + Ok(()) + } + Err(e) => { + error!("Failed to get ACL for {path}: {e}"); + Err(ServiceLocationError::LoadError(format!( + "Failed to get ACL for {path}: {e}" + ))) + } + } +} + +fn get_instance_file_path(instance_id: &str) -> Result { + let mut path = get_shared_directory()?; + path.push(format!("{instance_id}.json")); + Ok(path) +} + +pub(crate) fn is_user_logged_in() -> bool { + debug!("Starting checking if user is logged in..."); + + unsafe { + let mut pp_sessions: *mut WTS_SESSION_INFOA = std::ptr::null_mut(); + let mut count: u32 = 0; + + debug!("Calling WTSEnumerateSessionsA..."); + let ret = RemoteDesktop::WTSEnumerateSessionsA(None, 0, 1, &mut pp_sessions, &mut count); + + match ret { + Ok(_) => { + debug!("WTSEnumerateSessionsA succeeded, found {count} sessions"); + let sessions = std::slice::from_raw_parts(pp_sessions, count as usize); + + for (index, session) in sessions.iter().enumerate() { + debug!( + "Session {index}: SessionId={}, State={:?}, WinStationName={:?}", + session.SessionId, + session.State, + std::ffi::CStr::from_ptr(session.pWinStationName.0 as *const i8) + .to_string_lossy() + ); + + if session.State == windows::Win32::System::RemoteDesktop::WTSActive { + let mut buffer = PSTR::null(); + let mut bytes_returned: u32 = 0; + + let result = WTSQuerySessionInformationA( + None, + session.SessionId, + windows::Win32::System::RemoteDesktop::WTSUserName, + &mut buffer, + &mut bytes_returned, + ); + + match result { + Ok(_) => { + if !buffer.is_null() { + let username = std::ffi::CStr::from_ptr(buffer.0 as *const i8) + .to_string_lossy() + .into_owned(); + + debug!( + "Found session {} username: {username}", + session.SessionId + ); + + windows::Win32::System::RemoteDesktop::WTSFreeMemory( + buffer.0 as *mut _, + ); + + // We found an active session with a username + return true; + } + } + Err(err) => { + debug!( + "Failed to get username for session {}: {err:?}", + session.SessionId + ); + } + } + } + } + windows::Win32::System::RemoteDesktop::WTSFreeMemory(pp_sessions as _); + debug!("No active sessions found"); + } + Err(err) => { + error!("Failed to enumerate user sessions: {err:?}"); + debug!("WTSEnumerateSessionsA failed: {err:?}"); + } + } + } + + debug!("User is not logged in."); + false +} + +impl ServiceLocationManager { + pub fn init() -> Result { + debug!("Initializing ServiceLocationApi"); + let path = get_shared_directory()?; + + debug!("Creating directory: {path:?}"); + create_dir_all(&path)?; + + if let Some(path_str) = path.to_str() { + debug!("Setting ACLs on service locations directory"); + if let Err(e) = set_protected_acls(path_str) { + warn!("Failed to set ACLs on service locations directory: {e}. Continuing anyway."); + } + } else { + warn!("Failed to convert path to string for ACL setting"); + } + + let manager = Self { + wgapis: HashMap::new(), + connected_service_locations: HashMap::new(), + }; + + debug!("ServiceLocationApi initialized successfully"); + Ok(manager) + } + + /// Check if a specific service location is already connected + fn is_service_location_connected(&self, instance_id: &str, location_pubkey: &str) -> bool { + if let Some(locations) = self.connected_service_locations.get(instance_id) { + for location in locations { + if location.pubkey == location_pubkey { + return true; + } + } + } + false + } + + /// Add a connected service location + fn add_connected_service_location( + &mut self, + instance_id: &str, + location: &ServiceLocation, + ) -> Result<(), ServiceLocationError> { + self.connected_service_locations + .entry(instance_id.to_string()) + .or_insert_with(Vec::new) + .push(location.clone()); + + debug!( + "Added connected service location for instance '{instance_id}', location '{}'", + location.name + ); + Ok(()) + } + + /// Remove connected service locations by filter (write disk-first, then memory) + fn remove_connected_service_locations( + &mut self, + filter: F, + ) -> Result<(), ServiceLocationError> + where + F: Fn(&str, &ServiceLocation) -> bool, + { + // Iterate through connected_service_locations and remove matching locations + let mut instances_to_remove = Vec::new(); + + for (instance_id, locations) in self.connected_service_locations.iter_mut() { + locations.retain(|location| !filter(instance_id, location)); + + // Mark instance for removal if it has no more locations + if locations.is_empty() { + instances_to_remove.push(instance_id.clone()); + } + } + + // Remove instances with no locations + for instance_id in instances_to_remove { + self.connected_service_locations.remove(&instance_id); + } + + debug!("Removed connected service locations matching filter"); + Ok(()) + } + + // Resets the state of the service location: + // 1. If it's an always on location, disconnects and reconnects it. + // 2. Otherwise, just disconnects it if the user is not logged in. + pub(crate) fn reset_service_location_state( + &mut self, + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Reseting the state of service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" + ); + + let service_location_data = self + .load_service_location(instance_id, location_pubkey)? + .ok_or_else(|| { + ServiceLocationError::LoadError(format!( + "Service location with pubkey {} for instance {} not found", + location_pubkey, instance_id + )) + })?; + + debug!( + "Disconnecting service location for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name + ); + + self.disconnect_service_location(instance_id, location_pubkey)?; + + debug!( + "Disconnected service location for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name + ); + + debug!( + "Reconnecting service location if needed for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name + ); + + // We should reconnect only if: + // 1. It's an always on location + // 2. It's a pre-logon location and the user is not logged in + if service_location_data.service_location.mode == ServiceLocationMode::AlwaysOn as i32 + || (service_location_data.service_location.mode == ServiceLocationMode::PreLogon as i32 + && !is_user_logged_in()) + { + debug!( + "Reconnecting service location for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name + ); + self.connect_to_service_location(&service_location_data)?; + } + + debug!("Service location state reset completed."); + + Ok(()) + } + + pub(crate) fn disconnect_service_locations_by_instance( + &mut self, + instance_id: &str, + ) -> Result<(), ServiceLocationError> { + debug!("Disconnecting all service locations for instance_id: {instance_id}"); + + if let Some(locations) = self.connected_service_locations.get(instance_id) { + // Collect locations to disconnect to avoid borrowing issues + let locations_to_disconnect: Vec<_> = locations.iter().cloned().collect(); + + for location in locations_to_disconnect { + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully"); + } + debug!( + "Removing connected service location for instance_id: {instance_id}, location_pubkey: {}", + location.pubkey + ); + debug!( + "Disconnected service location for instance_id: {instance_id}, location_pubkey: {}", + location.pubkey + ); + } else { + error!("Failed to find WireGuard API for interface {ifname}"); + } + } + + self.connected_service_locations.remove(instance_id); + } else { + debug!( + "No connected service locations found for instance_id: {instance_id}. Skipping disconnect" + ); + return Ok(()); + } + + debug!("Disconnected all service locations for instance_id: {instance_id}"); + + Ok(()) + } + + pub(crate) fn disconnect_service_location( + &mut self, + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Disconnecting service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" + ); + + if let Some(locations) = self.connected_service_locations.get_mut(instance_id) { + if let Some(pos) = locations + .iter() + .position(|loc| &loc.pubkey == location_pubkey) + { + let location = locations.remove(pos); + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully."); + } + } else { + error!("Failed to find WireGuard API for interface {ifname}. "); + } + } else { + debug!( + "Service location with pubkey {location_pubkey} for instance {instance_id} is not connected, skipping disconnect" + ); + return Ok(()); + } + } else { + debug!( + "No connected service locations found for instance_id: {instance_id}, skipping disconnect" + ); + return Ok(()); + } + + debug!( + "Disconnected service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" + ); + + Ok(()) + } + + /// Helper function to setup a WireGuard interface for a service location + fn setup_service_location_interface( + &mut self, + location: &ServiceLocation, + private_key: &str, + ) -> Result<(), ServiceLocationError> { + let peer_key = Key::from_str(&location.pubkey)?; + + let mut peer = Peer::new(peer_key.clone()); + peer.set_endpoint(&location.endpoint)?; + + peer.persistent_keepalive_interval = location.keepalive_interval.try_into().ok(); + + let allowed_ips = location + .allowed_ips + .split(',') + .map(str::to_string) + .collect::>(); + + for allowed_ip in &allowed_ips { + match IpAddrMask::from_str(allowed_ip) { + Ok(addr) => { + peer.allowed_ips.push(addr); + } + Err(err) => { + error!( + "Error parsing IP address {allowed_ip} while setting up interface for \ + location {location:?}, error details: {err}" + ); + } + } + } + + let mut addresses = Vec::new(); + + for address in location.address.split(',') { + addresses.push(IpAddrMask::from_str(address.trim())?); + } + + let config = InterfaceConfiguration { + name: location.name.clone(), + prvkey: private_key.to_string(), + addresses, + port: find_free_tcp_port().unwrap_or(DEFAULT_WIREGUARD_PORT) as u32, + peers: vec![peer.clone()], + mtu: None, + }; + + let ifname = location.name.clone(); + let ifname = get_interface_name(&ifname); + let mut wgapi = match setup_wgapi(&ifname) { + Ok(api) => api, + Err(err) => { + let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err:?}"); + debug!("{msg}"); + return Err(ServiceLocationError::InterfaceError(msg)); + } + }; + + wgapi.create_interface()?; + + // Extract DNS configuration if available + let dns_string = location.dns.clone(); + let dns_entries = dns_string.split(',').map(str::trim).collect::>(); + // We assume that every entry that can't be parsed as an IP address is a domain name. + let mut dns = Vec::new(); + let mut search_domains = Vec::new(); + for entry in dns_entries { + if let Ok(ip) = entry.parse::() { + dns.push(ip); + } else { + search_domains.push(entry); + } + } + + debug!( + "Configuring interface {ifname} with DNS: {:?} and search domains: {:?}", + dns, search_domains + ); + debug!("Interface Configuration: {:?}", config); + + wgapi.configure_interface(&config)?; + wgapi.configure_dns(&dns, &search_domains)?; + + self.wgapis.insert(ifname.clone(), wgapi); + + debug!("Interface {ifname} configured successfully."); + Ok(()) + } + + pub(crate) fn connect_to_service_location( + &mut self, + location_data: &SingleServiceLocationData, + ) -> Result<(), ServiceLocationError> { + let instance_id = &location_data.instance_id; + let location_pubkey = &location_data.service_location.pubkey; + debug!( + "Connecting to service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" + ); + + // Check if already connected to this service location + if self.is_service_location_connected(instance_id, location_pubkey) { + debug!( + "Service location with pubkey {location_pubkey} for instance {instance_id} is already connected, skipping" + ); + return Ok(()); + } + + let location_data = self + .load_service_location(instance_id, location_pubkey)? + .ok_or_else(|| { + ServiceLocationError::LoadError(format!( + "Service location with pubkey {} for instance {} not found", + location_pubkey, instance_id + )) + })?; + + self.setup_service_location_interface( + &location_data.service_location, + &location_data.private_key, + )?; + self.add_connected_service_location( + &location_data.instance_id, + &location_data.service_location, + )?; + let ifname = get_interface_name(&location_data.service_location.name); + debug!("Successfully connected to service location '{ifname}'"); + + Ok(()) + } + + pub(crate) fn disconnect_service_locations( + &mut self, + mode: Option, + ) -> Result<(), ServiceLocationError> { + debug!("Disconnecting service locations with mode: {mode:?}"); + + for (instance, locations) in self.connected_service_locations.iter() { + for location in locations { + debug!( + "Found connected service location for instance_id: {instance}, location_pubkey: {}", + location.pubkey + ); + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location.mode.try_into()?; + if location_mode != m { + debug!( + "Skipping interface {} due to the service location mode doesn't match the requested mode (expected {m:?}, found {:?})", + location.name, location.mode + ); + continue; + } + } + + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully."); + } + } else { + error!("Failed to find WireGuard API for interface {ifname}"); + } + } + } + + self.remove_connected_service_locations(|_, location| { + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location + .mode + .try_into() + .unwrap_or(ServiceLocationMode::AlwaysOn); + location_mode == m + } else { + true + } + })?; + + debug!("Service locations disconnected."); + + Ok(()) + } + + pub(crate) fn connect_to_service_locations(&mut self) -> Result<(), ServiceLocationError> { + debug!("Attempting to auto-connect to VPN..."); + + let data = self.load_service_locations()?; + debug!("Loaded {} instance(s) from ServiceLocationApi", data.len()); + + for instance_data in data { + debug!( + "Found service locations for instance ID: {}", + instance_data.instance_id + ); + debug!( + "Instance has {} service location(s)", + instance_data.service_locations.len() + ); + for location in instance_data.service_locations { + debug!("Service Location: {location:?}"); + + if location.mode == ServiceLocationMode::PreLogon as i32 { + if is_user_logged_in() { + debug!( + "Skipping pre-logon service location '{}' because user is logged in", + location.name + ); + continue; + } + debug!( + "Proceeding to connect pre-logon service location '{}' because no user is logged in", + location.name + ); + } + + if self.is_service_location_connected(&instance_data.instance_id, &location.pubkey) + { + debug!( + "Skipping service location '{}' because it's already connected", + location.name + ); + continue; + } + + if let Err(err) = + self.setup_service_location_interface(&location, &instance_data.private_key) + { + debug!( + "Failed to setup service location interface for '{}': {err:?}", + location.name + ); + continue; + } + + if let Err(err) = + self.add_connected_service_location(&instance_data.instance_id, &location) + { + debug!( + "Failed to persist connected service location after auto-connect: {err:?}" + ); + } + + debug!( + "Successfully connected to service location '{}'", + location.name + ); + } + } + + debug!("Auto-connect attempt completed"); + + Ok(()) + } + + pub fn save_service_locations( + &self, + service_locations: &[ServiceLocation], + instance_id: &str, + private_key: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Received a request to save {} service location(s) for instance {instance_id}", + service_locations.len(), + ); + + debug!("Service locations to save: {service_locations:?}"); + + create_dir_all(get_shared_directory()?)?; + + let instance_file_path = get_instance_file_path(instance_id)?; + + let service_location_data = ServiceLocationData { + service_locations: service_locations.to_vec(), + instance_id: instance_id.to_string(), + private_key: private_key.to_string(), + }; + + let json = serde_json::to_string_pretty(&service_location_data)?; + + debug!("Writing service location data to file: {instance_file_path:?}"); + + fs::write(&instance_file_path, &json)?; + + if let Some(file_path_str) = instance_file_path.to_str() { + debug!("Setting ACLs on service location file: {file_path_str}"); + if let Err(e) = set_protected_acls(file_path_str) { + warn!( + "Failed to set ACLs on service location file {file_path_str}: {e}. File saved but may have insecure permissions." + ); + } else { + debug!("Successfully set ACLs on service location file"); + } + } else { + warn!("Failed to convert file path to string for ACL setting"); + } + + debug!( + "Service locations saved successfully for instance {instance_id} to {:?}", + instance_file_path + ); + Ok(()) + } + + fn load_service_locations(&self) -> Result, ServiceLocationError> { + let base_dir = get_shared_directory()?; + let mut all_locations_data = Vec::new(); + + if base_dir.exists() { + for entry in fs::read_dir(base_dir)? { + let entry = entry?; + let file_path = entry.path(); + + if file_path.is_file() + && file_path.extension().and_then(|s| s.to_str()) == Some("json") + { + match fs::read_to_string(&file_path) { + Ok(data) => match serde_json::from_str::(&data) { + Ok(locations_data) => { + all_locations_data.push(locations_data); + } + Err(e) => { + error!( + "Failed to parse service locations from file {:?}: {e}", + file_path + ); + } + }, + Err(e) => { + error!("Failed to read service locations file {:?}: {e}", file_path); + } + } + } + } + } + + debug!( + "Loaded service locations data for {} instances", + all_locations_data.len() + ); + Ok(all_locations_data) + } + + fn load_service_location( + &self, + instance_id: &str, + location_pubkey: &str, + ) -> Result, ServiceLocationError> { + debug!("Loading service location for instance {instance_id} and pubkey {location_pubkey}"); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + let data = fs::read_to_string(&instance_file_path)?; + let service_location_data = serde_json::from_str::(&data)?; + + for location in service_location_data.service_locations { + if location.pubkey == location_pubkey { + debug!( + "Successfully loaded service location for instance {instance_id} and pubkey {location_pubkey}" + ); + return Ok(Some(SingleServiceLocationData { + service_location: location, + instance_id: service_location_data.instance_id, + private_key: service_location_data.private_key, + })); + } + } + + debug!( + "No service location found for instance {instance_id} with pubkey {location_pubkey}" + ); + Ok(None) + } else { + debug!("No service location file found for instance {instance_id}"); + Ok(None) + } + } + + pub(crate) fn delete_all_service_locations_for_instance( + &self, + instance_id: &str, + ) -> Result<(), ServiceLocationError> { + debug!("Deleting all service locations for instance {instance_id}"); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + fs::remove_file(&instance_file_path)?; + debug!("Successfully deleted all service locations for instance {instance_id}"); + } else { + debug!("No service location file found for instance {instance_id}"); + } + + Ok(()) + } +} diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index 4ba25ef5..a34c4330 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -44,6 +44,10 @@ pub enum Error { StateLockFail, #[error("Failed to acquire lock on mutex. {0}")] PoisonError(String), + #[error("Failed to convert value. {0}")] + ConversionError(String), + #[error("JSON error: {0}")] + JsonError(#[from] serde_json::Error), } // we must manually implement serde::Serialize diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index e22f4938..d02dbc1b 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -32,7 +32,7 @@ pub mod wg_config; pub mod proto { use crate::database::models::{ - location::{Location, LocationMfaMode as MfaMode}, + location::{Location, LocationMfaMode as MfaMode, ServiceLocationMode as SLocationMode}, Id, NoId, }; @@ -55,6 +55,11 @@ pub mod proto { } }; + let service_location_mode = match self.service_location_mode { + Some(_service_location_mode) => self.service_location_mode().into(), + None => SLocationMode::Disabled, // Default to disabled if not set + }; + Location { id: NoId, instance_id, @@ -68,6 +73,7 @@ pub mod proto { route_all_traffic: false, keepalive_interval: self.keepalive_interval.into(), location_mfa_mode, + service_location_mode, } } } diff --git a/src-tauri/src/periodic/mod.rs b/src-tauri/src/periodic/mod.rs index 0b4fab97..37daff37 100644 --- a/src-tauri/src/periodic/mod.rs +++ b/src-tauri/src/periodic/mod.rs @@ -1,9 +1,9 @@ -use self::{ - connection::verify_active_connections, purge_stats::purge_stats, version::poll_version, -}; use tauri::AppHandle; use tokio::select; +use self::{ + connection::verify_active_connections, purge_stats::purge_stats, version::poll_version, +}; use crate::enterprise::periodic::config::poll_config; pub mod connection; diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 25c41f1d..c4de34ea 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -50,6 +50,10 @@ use tracing::{debug, error, info, info_span, Instrument}; use self::config::Config; use super::VERSION; +use crate::enterprise::service_locations::ServiceLocationError; +#[cfg(windows)] +use crate::enterprise::service_locations::ServiceLocationManager; +use crate::service::proto::{DeleteServiceLocationsRequest, SaveServiceLocationsRequest}; #[cfg(windows)] const DAEMON_HTTP_PORT: u16 = 54127; @@ -72,6 +76,11 @@ pub enum DaemonError { Unexpected(String), #[error(transparent)] TransportError(#[from] tonic::transport::Error), + #[error(transparent)] + ServiceLocationError(#[from] ServiceLocationError), + #[cfg(windows)] + #[error(transparent)] + WindowsServiceError(#[from] windows_service::Error), } type IfName = String; @@ -86,22 +95,29 @@ pub(crate) struct DaemonService { wgapis: Arc>>, stats_period: Duration, stat_tasks: Arc>>>, + #[cfg(windows)] + service_location_manager: Arc>, } impl DaemonService { #[must_use] - pub fn new(config: &Config) -> Self { + pub fn new( + config: &Config, + #[cfg(windows)] service_location_manager: Arc>, + ) -> Self { Self { wgapis: Arc::new(RwLock::new(HashMap::new())), stats_period: Duration::from_secs(config.stats_period), stat_tasks: Arc::new(Mutex::new(HashMap::new())), + #[cfg(windows)] + service_location_manager, } } } type InterfaceDataStream = Pin> + Send>>; -fn setup_wgapi(ifname: &str) -> Result { +pub(crate) fn setup_wgapi(ifname: &str) -> Result { let wgapi = WG::new(ifname).map_err(|err| { let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err}"); error!("{msg}"); @@ -115,6 +131,118 @@ fn setup_wgapi(ifname: &str) -> Result { impl DesktopDaemonService for DaemonService { type ReadInterfaceDataStream = InterfaceDataStream; + #[cfg(not(windows))] + async fn save_service_locations( + &self, + _request: tonic::Request, + ) -> Result, Status> { + debug!("Saved service location request received, this is currently not supported on Unix systems"); + Ok(Response::new(())) + } + + #[cfg(not(windows))] + async fn delete_service_locations( + &self, + _request: tonic::Request, + ) -> Result, Status> { + debug!("Saved service location request received, this is currently not supported on Unix systems"); + Ok(Response::new(())) + } + + #[cfg(windows)] + async fn save_service_locations( + &self, + request: tonic::Request, + ) -> Result, Status> { + debug!("Received a request to save service location"); + let service_location = request.into_inner(); + + match self + .service_location_manager + .clone() + .read() + .unwrap() + .save_service_locations( + service_location.service_locations.as_slice(), + &service_location.instance_id, + &service_location.private_key, + ) { + Ok(()) => { + debug!("Service location saved successfully"); + } + Err(e) => { + let msg = format!("Failed to save service location: {e}"); + error!(msg); + return Err(Status::internal(msg)); + } + } + + for saved_location in service_location.service_locations { + match self + .service_location_manager + .clone() + .write() + .unwrap() + .reset_service_location_state(&service_location.instance_id, &saved_location.pubkey) + { + Ok(()) => { + debug!( + "Service location '{}' state reset successfully", + saved_location.name + ); + } + Err(e) => { + error!( + "Failed to reset state for service location '{}': {e}", + saved_location.name + ); + } + } + } + + Ok(Response::new(())) + } + + #[cfg(windows)] + async fn delete_service_locations( + &self, + request: tonic::Request, + ) -> Result, Status> { + debug!("Received a request to delete service location"); + let instance_id = request.into_inner().instance_id; + + self.service_location_manager + .clone() + .write() + .unwrap() + .disconnect_service_locations_by_instance(&instance_id) + .map_err(|e| { + let msg = format!("Failed to disconnect service location: {e}"); + error!(msg); + Status::internal(msg) + })?; + + match self + .service_location_manager + .clone() + .read() + .unwrap() + .delete_all_service_locations_for_instance(&instance_id) + { + Ok(()) => { + debug!("Service location deleted successfully"); + Ok(Response::new(())) + } + Err(e) => { + error!("Failed to delete service location: {}", e); + Err(Status::internal(format!( + "Failed to delete service location: {}", + e + ))) + } + } + } + async fn create_interface( &self, request: tonic::Request, @@ -413,11 +541,14 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { } #[cfg(windows)] -pub async fn run_server(config: Config) -> anyhow::Result<()> { +pub(crate) async fn run_server( + config: Config, + service_location_manager: Arc>, +) -> anyhow::Result<()> { debug!("Starting Defguard interface management daemon"); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), DAEMON_HTTP_PORT); - let daemon_service = DaemonService::new(&config); + let daemon_service = DaemonService::new(&config, service_location_manager); info!("Defguard daemon version {VERSION} started, listening on {addr}",); debug!("Defguard daemon configuration: {config:?}"); diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index aa79ab5a..5efd258f 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -1,8 +1,13 @@ -use std::{ffi::OsString, sync::mpsc, time::Duration}; +use std::{ + ffi::OsString, + result::Result, + sync::{mpsc, Arc, RwLock}, + time::Duration, +}; use clap::Parser; -use log::error; -use tokio::runtime::Runtime; +use error; +use tokio::{runtime::Runtime, select}; use windows_service::{ define_windows_service, service::{ @@ -10,15 +15,21 @@ use windows_service::{ ServiceType, }, service_control_handler::{register, ServiceControlHandlerResult}, - service_dispatcher, Result, + service_dispatcher, }; -use crate::service::{run_server, utils::logging_setup, Config}; +use crate::{ + enterprise::service_locations::{ + windows::watch_for_login_logoff, ServiceLocationError, ServiceLocationManager, + }, + service::{run_server, utils::logging_setup, Config, DaemonError}, +}; static SERVICE_NAME: &str = "DefguardService"; const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; +const LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS: Duration = Duration::from_secs(5); -pub fn run() -> Result<()> { +pub fn run() -> Result<(), windows_service::Error> { // Register generated `ffi_service_main` with the system and start the service, blocking // this thread until the service is stopped. service_dispatcher::start(SERVICE_NAME, ffi_service_main) @@ -33,7 +44,7 @@ pub fn service_main(_arguments: Vec) { } } -fn run_service() -> Result<()> { +fn run_service() -> Result<(), DaemonError> { // Create a channel to be able to poll a stop event from the service worker loop. let (shutdown_tx, shutdown_rx) = mpsc::channel::(); let shutdown_tx_server = shutdown_tx.clone(); @@ -81,12 +92,89 @@ fn run_service() -> Result<()> { std::process::exit(1); })); + let service_location_manager = match ServiceLocationManager::init() { + Ok(api) => { + info!("Service locations storage initialized successfully"); + Ok(api) + } + Err(e) => { + error!( + "Failed to initialize service locations storage: {}. Shutting down service location thread", + e + ); + Err(ServiceLocationError::InitError(e.to_string())) + } + }?; + + let service_location_manager = Arc::new(RwLock::new(service_location_manager)); + + let service_location_manager_clone = service_location_manager.clone(); runtime.spawn(async move { - let server_result = run_server(config).await; + let manager = service_location_manager_clone.clone(); + + let service_location_task = async move { + info!("Starting service location management task"); + + info!("Attempting to auto-connect to service locations"); + match manager.write().unwrap().connect_to_service_locations() { + Ok(_) => { + info!("Auto-connect to service locations completed successfully"); + } + Err(e) => { + warn!( + "Error while trying to auto-connect to service locations: {e}. \ + Will continue monitoring for login/logoff events.", + ); + } + } - if server_result.is_err() { - let _ = shutdown_tx_server.send(2); - } + info!("Starting login/logoff event monitoring"); + loop { + match watch_for_login_logoff( + manager.clone(), + ).await { + Ok(_) => { + warn!("Login/logoff event monitoring ended unexpectedly"); + break; + } + Err(e) => { + error!( + "Error in login/logoff event monitoring: {e}. Restarting in {LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS:?} seconds...", + ); + tokio::time::sleep(LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS).await; + info!("Restarting login/logoff event monitoring"); + } + } + } + + warn!("Service location management task terminated"); + Ok::<(), ServiceLocationError>(()) + }; + + let server_task = async move { + run_server(config, service_location_manager_clone).await + }; + + let result = select! { + result = service_location_task => { + warn!("Service location task completed"); + result.map_err(|e| format!("Service location error: {e}")) + } + result = server_task => { + warn!("Server task completed"); + result.map_err(|e| format!("Server error: {e}")) + } + }; + + let signal = if result.is_err() { + error!("Task ended with error: {:?}", result.err()); + 2 + } else { + info!("Task ended without an error."); + 1 + }; + + let _ = shutdown_tx_server.send(signal); }); loop { diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index 987c1358..92230c75 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -6,14 +6,16 @@ use sqlx::query; use tauri::{AppHandle, Emitter, Manager}; use tonic::Code; use tracing::Level; -#[cfg(target_os = "windows")] +#[cfg(windows)] use winapi::shared::winerror::ERROR_SERVICE_DOES_NOT_EXIST; -#[cfg(target_os = "windows")] +#[cfg(windows)] use windows_service::{ service::{ServiceAccess, ServiceState}, service_manager::{ServiceManager, ServiceManagerAccess}, }; +#[cfg(windows)] +use crate::active_connections::find_connection; use crate::{ appstate::AppState, commands::LocationInterfaceDetails, @@ -38,9 +40,6 @@ use crate::{ ConnectionType, }; -#[cfg(target_os = "windows")] -use crate::active_connections::find_connection; - pub(crate) static DEFAULT_ROUTE_IPV4: &str = "0.0.0.0/0"; pub(crate) static DEFAULT_ROUTE_IPV6: &str = "::/0"; @@ -942,7 +941,7 @@ async fn check_connection( #[cfg(target_os = "windows")] pub async fn sync_connections(app_handle: &AppHandle) -> Result<(), Error> { debug!("Synchronizing active connections with the systems' state..."); - let all_locations = Location::all(&*DB_POOL).await?; + let all_locations = Location::all(&*DB_POOL, false).await?; let service_manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT).map_err( |err| {