From b353128de0a6e13d3cfba3e702865313d1197603 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:56:04 +0200 Subject: [PATCH 01/13] service locations 1 --- src-tauri/Cargo.lock | 157 +- src-tauri/Cargo.toml | 7 + .../20251009102408_service_locations.sql | 4 + src-tauri/src/commands.rs | 212 ++- src-tauri/src/database/models/location.rs | 54 +- src-tauri/src/enterprise/mod.rs | 1 + .../src/enterprise/service_locations/mod.rs | 95 ++ .../enterprise/service_locations/windows.rs | 1261 +++++++++++++++++ src-tauri/src/error.rs | 6 + src-tauri/src/lib.rs | 8 +- src-tauri/src/service/mod.rs | 119 +- src-tauri/src/service/windows.rs | 99 +- src-tauri/src/utils.rs | 5 + 13 files changed, 1970 insertions(+), 58 deletions(-) create mode 100644 src-tauri/migrations/20251009102408_service_locations.sql create mode 100644 src-tauri/src/enterprise/service_locations/mod.rs create mode 100644 src-tauri/src/enterprise/service_locations/windows.rs diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index adbf6915..7abd6fed 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -562,7 +562,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -956,7 +956,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -1403,6 +1403,7 @@ dependencies = [ "defguard_wireguard_rs", "dirs-next", "hyper-util", + "known-folders", "log", "nix", "prost", @@ -1444,6 +1445,8 @@ dependencies = [ "vergen-git2", "webbrowser", "winapi", + "windows 0.62.2", + "windows-acl", "windows-service", "x25519-dalek", ] @@ -2802,7 +2805,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.62.1", + "windows-core 0.62.2", ] [[package]] @@ -3179,6 +3182,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "known-folders" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c644f4623d1c55eb60a9dac35e0858a59f982fb87db6ce34c872372b0a5b728f" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "kuchikiki" version = "0.8.8-speedreader" @@ -3268,7 +3280,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" dependencies = [ "cfg-if", - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -6282,7 +6294,7 @@ dependencies = [ "tao-macros", "unicode-segmentation", "url", - "windows", + "windows 0.61.3", "windows-core 0.61.2", "windows-version", "x11-dl", @@ -6361,7 +6373,7 @@ dependencies = [ "webkit2gtk", "webview2-com", "window-vibrancy", - "windows", + "windows 0.61.3", ] [[package]] @@ -6603,7 +6615,7 @@ dependencies = [ "tauri-plugin", "thiserror 2.0.17", "url", - "windows", + "windows 0.61.3", "zbus", ] @@ -6678,7 +6690,7 @@ dependencies = [ "url", "webkit2gtk", "webview2-com", - "windows", + "windows 0.61.3", ] [[package]] @@ -6704,7 +6716,7 @@ dependencies = [ "url", "webkit2gtk", "webview2-com", - "windows", + "windows 0.61.3", "wry", ] @@ -6764,7 +6776,7 @@ checksum = "0b1e66e07de489fe43a46678dd0b8df65e0c973909df1b60ba33874e297ba9b9" dependencies = [ "quick-xml 0.37.5", "thiserror 2.0.17", - "windows", + "windows 0.61.3", "windows-version", ] @@ -8076,7 +8088,7 @@ checksum = "d4ba622a989277ef3886dd5afb3e280e3dd6d974b766118950a08f8f678ad6a4" dependencies = [ "webview2-com-macros", "webview2-com-sys", - "windows", + "windows 0.61.3", "windows-core 0.61.2", "windows-implement", "windows-interface", @@ -8100,7 +8112,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36695906a1b53a3bf5c4289621efedac12b73eeb0b89e7e1a89b517302d5d75c" dependencies = [ "thiserror 2.0.17", - "windows", + "windows 0.61.3", "windows-core 0.61.2", ] @@ -8129,6 +8141,12 @@ dependencies = [ "wasite", ] +[[package]] +name = "widestring" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c168940144dd21fd8046987c16a46a33d5fc84eec29ef9dcddc2ac9e31526b7c" + [[package]] name = "widestring" version = "1.2.0" @@ -8187,11 +8205,35 @@ version = "0.61.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" dependencies = [ - "windows-collections", + "windows-collections 0.2.0", "windows-core 0.61.2", - "windows-future", + "windows-future 0.2.1", "windows-link 0.1.3", - "windows-numerics", + "windows-numerics 0.2.0", +] + +[[package]] +name = "windows" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" +dependencies = [ + "windows-collections 0.3.2", + "windows-core 0.62.2", + "windows-future 0.3.2", + "windows-numerics 0.3.1", +] + +[[package]] +name = "windows-acl" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "177b1723986bcb4c606058e77f6e8614b51c7f9ad2face6f6fd63dd5c8b3cec3" +dependencies = [ + "field-offset", + "libc", + "widestring 0.4.3", + "winapi", ] [[package]] @@ -8203,6 +8245,15 @@ dependencies = [ "windows-core 0.61.2", ] +[[package]] +name = "windows-collections" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" +dependencies = [ + "windows-core 0.62.2", +] + [[package]] name = "windows-core" version = "0.61.2" @@ -8218,15 +8269,15 @@ dependencies = [ [[package]] name = "windows-core" -version = "0.62.1" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6844ee5416b285084d3d3fffd743b925a6c9385455f64f6d4fa3031c4c2749a9" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.2.0", - "windows-result 0.4.0", - "windows-strings 0.5.0", + "windows-link 0.2.1", + "windows-result 0.4.1", + "windows-strings 0.5.1", ] [[package]] @@ -8237,14 +8288,25 @@ checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" dependencies = [ "windows-core 0.61.2", "windows-link 0.1.3", - "windows-threading", + "windows-threading 0.1.0", +] + +[[package]] +name = "windows-future" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" +dependencies = [ + "windows-core 0.62.2", + "windows-link 0.2.1", + "windows-threading 0.2.1", ] [[package]] name = "windows-implement" -version = "0.60.1" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb307e42a74fb6de9bf3a02d9712678b22399c87e6fa869d6dfcd8c1b7754e0" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", @@ -8253,9 +8315,9 @@ dependencies = [ [[package]] name = "windows-interface" -version = "0.59.2" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0abd1ddbc6964ac14db11c7213d6532ef34bd9aa042c2e5935f59d7908b46a5" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", @@ -8270,9 +8332,9 @@ checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" [[package]] name = "windows-link" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-numerics" @@ -8284,6 +8346,16 @@ dependencies = [ "windows-link 0.1.3", ] +[[package]] +name = "windows-numerics" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" +dependencies = [ + "windows-core 0.62.2", + "windows-link 0.2.1", +] + [[package]] name = "windows-registry" version = "0.5.3" @@ -8306,11 +8378,11 @@ dependencies = [ [[package]] name = "windows-result" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7084dcc306f89883455a206237404d3eaf961e5bd7e0f312f7c91f57eb44167f" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -8320,7 +8392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d24d6bcc7f734a4091ecf8d7a64c5f7d7066f45585c1861eba06449909609c8a" dependencies = [ "bitflags 2.9.4", - "widestring", + "widestring 1.2.0", "windows-sys 0.52.0", ] @@ -8335,11 +8407,11 @@ dependencies = [ [[package]] name = "windows-strings" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7218c655a553b0bed4426cf54b20d7ba363ef543b52d515b3e48d7fd55318dda" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -8393,7 +8465,7 @@ version = "0.61.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" dependencies = [ - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -8448,7 +8520,7 @@ version = "0.53.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d42b7b7f66d2a06854650af09cfdf8713e427a439c97ad65a6375318033ac4b" dependencies = [ - "windows-link 0.2.0", + "windows-link 0.2.1", "windows_aarch64_gnullvm 0.53.0", "windows_aarch64_msvc 0.53.0", "windows_i686_gnu 0.53.0", @@ -8468,13 +8540,22 @@ dependencies = [ "windows-link 0.1.3", ] +[[package]] +name = "windows-threading" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" +dependencies = [ + "windows-link 0.2.1", +] + [[package]] name = "windows-version" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "700dad7c058606087f6fdc1f88da5841e06da40334413c6cd4367b25ef26d24e" dependencies = [ - "windows-link 0.2.0", + "windows-link 0.2.1", ] [[package]] @@ -8765,7 +8846,7 @@ dependencies = [ "webkit2gtk", "webkit2gtk-sys", "webview2-com", - "windows", + "windows 0.61.3", "windows-core 0.61.2", "windows-version", "x11-dl", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 8ca1a3b3..2637e0ea 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -120,6 +120,13 @@ tower = "0.5" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["winsvc", "winerror"] } windows-service = "0.7" +known-folders = "1.3" +windows = { version = "0.62", features = [ + "Win32", + "Win32_System", + "Win32_System_RemoteDesktop", +] } +windows-acl = "0.3" [features] # this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled. diff --git a/src-tauri/migrations/20251009102408_service_locations.sql b/src-tauri/migrations/20251009102408_service_locations.sql new file mode 100644 index 00000000..aa1e67ba --- /dev/null +++ b/src-tauri/migrations/20251009102408_service_locations.sql @@ -0,0 +1,4 @@ +-- 1 - disabled +-- 2 - pre-logon +-- 3 - always-on +ALTER TABLE location ADD COLUMN service_location_mode INTEGER NOT NULL DEFAULT 1; diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index c638e105..a055bff4 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -21,7 +21,7 @@ use crate::{ models::{ connection::{ActiveConnection, Connection, ConnectionInfo}, instance::{Instance, InstanceInfo}, - location::{Location, LocationMfaMode}, + location::{Location, LocationMfaMode, ServiceLocationMode}, location_stats::LocationStats, tunnel::{Tunnel, TunnelConnection, TunnelConnectionInfo, TunnelStats}, wireguard_keys::WireguardKeys, @@ -29,7 +29,7 @@ use crate::{ }, DB_POOL, }, - enterprise::periodic::config::poll_instance, + enterprise::{periodic::config::poll_instance, service_locations}, error::Error, events::EventKey, log_watcher::{ @@ -37,7 +37,13 @@ use crate::{ service_log_watcher::stop_log_watcher_task, }, proto::DeviceConfigResponse, - service::{proto::RemoveInterfaceRequest, utils::DAEMON_CLIENT}, + service::{ + proto::{ + DeleteServiceLocationsRequest, RemoveInterfaceRequest, RestartServiceLocationRequest, + SaveServiceLocationsRequest, ServiceLocation, + }, + utils::DAEMON_CLIENT, + }, tray::{configure_tray_icon, reload_tray_menu}, utils::{ disconnect_interface, execute_command, get_location_interface_details, @@ -286,12 +292,86 @@ pub async fn save_device_config( trace!("Created following instance: {instance:#?}"); let locations = Location::find_by_instance_id(&*DB_POOL, instance.id).await?; trace!("Created following locations: {locations:#?}"); + + let mut service_locations = Vec::::new(); + let mut service_locations_to_restart = Vec::<(String, String)>::new(); + + for saved_location in &locations { + if saved_location.is_service_location() { + debug!( + "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", + saved_location.name, saved_location.id, instance.name, instance.id, + ); + service_locations.push(saved_location.to_service_location()?); + } + } + + if !service_locations.is_empty() { + let save_request = SaveServiceLocationsRequest { + service_locations: service_locations.clone(), + instance_id: instance.uuid.clone(), + private_key: keys.prvkey.clone(), + }; + debug!( + "Saving {} service locations to the daemon for instance {}({}).", + save_request.service_locations.len(), + instance.name, + instance.id, + ); + DAEMON_CLIENT + .clone() + .save_service_locations(save_request) + .await + .map_err(|err| { + error!( + "Error while saving service locations to the daemon for instance {}({}): {err}", + instance.name, instance.id, + ); + Error::InternalError(err.to_string()) + })?; + debug!( + "Saved service locations to the daemon for instance {}({}).", + instance.name, instance.id, + ); + + let locations_pubkeys = service_locations + .iter() + .map(|loc| loc.pubkey.clone()) + .collect::>(); + + for location_pubkey in locations_pubkeys { + let restart_request = RestartServiceLocationRequest { + instance_id: instance.uuid.clone(), + pubkey: location_pubkey.clone(), + }; + debug!( + "Restarting service location with pubkey {} on instance {}.", + restart_request.pubkey, restart_request.instance_id, + ); + DAEMON_CLIENT.clone() + .reset_service_location(restart_request) + .await + .map_err(|err| { + error!( + "Error while restarting service location with pubkey {} on instance {}: {err}", + location_pubkey, instance.uuid, + ); + Error::InternalError(err.to_string()) + })?; + debug!( + "Restarted service location with pubkey {} on instance {}.", + location_pubkey, instance.uuid + ); + } + } + handle.emit(EventKey::InstanceUpdate.into(), ())?; let res: SaveDeviceConfigResponse = SaveDeviceConfigResponse { locations, instance, }; reload_tray_menu(&handle).await; + Ok(res) } @@ -389,6 +469,16 @@ pub async fn all_locations(instance_id: Id) -> Result, Error> let active_locations_ids = get_connection_id_by_type(ConnectionType::Location).await; let mut location_info = Vec::new(); for location in locations { + // Skip service locations, those shouldn't be shown in the UI. + if location.is_service_location() { + debug!( + "Skipping service location {}({}) for instance {}({}) when returning \ + locations to the frontend.", + location.name, location.id, instance.name, instance.id, + ); + continue; + } + let info = LocationInfo { id: location.id, instance_id: location.instance_id, @@ -533,6 +623,9 @@ pub(crate) async fn do_update_instance( "A new base configuration has been applied to instance {instance}, even if nothing changed" ); + let mut service_locations = Vec::::new(); + let mut service_locations_to_reset = Vec::<(String, String)>::new(); + // check if locations have changed if locations_changed { // process locations received in response @@ -546,6 +639,7 @@ pub(crate) async fn do_update_instance( for dev_config in response.configs { // parse device config let new_location = dev_config.into_location(instance.id); + let saved_location: Location; // check if location is already present in current locations if let Some(position) = current_locations @@ -567,13 +661,24 @@ pub(crate) async fn do_update_instance( current_location.keepalive_interval = new_location.keepalive_interval; current_location.dns = new_location.dns; current_location.location_mfa_mode = new_location.location_mfa_mode; + current_location.service_location_mode = new_location.service_location_mode; current_location.save(transaction.as_mut()).await?; info!("Location {current_location} configuration updated for instance {instance}"); + saved_location = current_location; } else { // create new location debug!("Creating new location {new_location} for instance instance {instance}"); let new_location = new_location.save(transaction.as_mut()).await?; info!("New location {new_location} created for instance {instance}"); + saved_location = new_location; + } + + if saved_location.is_service_location() { + debug!( + "Adding service location {}({}) for instance {}({}) to be saved to the daemon.", + saved_location.name, saved_location.id, instance.name, instance.id, + ); + service_locations.push(saved_location.to_service_location()?); } } @@ -590,6 +695,99 @@ pub(crate) async fn do_update_instance( } else { info!("Locations for instance {instance} didn't change. Not updating them."); } + + let private_key = WireguardKeys::find_by_instance_id(transaction.as_mut(), instance.id) + .await? + .ok_or(Error::NotFound)? + .prvkey; + + if !service_locations.is_empty() { + debug!( + "Processing {} service location(s) for instance {}({})", + service_locations.len(), + instance.name, + instance.id + ); + + let save_request = SaveServiceLocationsRequest { + service_locations: service_locations.clone(), + instance_id: instance.uuid.clone(), + private_key: private_key.clone(), + }; + + debug!("Prepared save request: {save_request:#?}"); + + debug!( + "Sending request to daemon to save {} service location(s) for instance {}({})", + save_request.service_locations.len(), + instance.name, + instance.id + ); + + DAEMON_CLIENT + .clone() + .save_service_locations(save_request) + .await + .map_err(|err| { + error!( + "Error while saving service locations to the daemon for instance {}({}): {err}", + instance.name, instance.id, + ); + Error::InternalError(err.to_string()) + })?; + + info!( + "Successfully saved {} service location(s) to daemon for instance {}({})", + service_locations.len(), + instance.name, + instance.id + ); + + let service_locations_pubkeys = service_locations + .iter() + .map(|loc| loc.pubkey.clone()) + .collect::>(); + + let instance_id = instance.uuid.clone(); + + for pubkey in service_locations_pubkeys { + debug!( + "Sending state reset request for service location with pubkey {} on instance {}", + pubkey, instance_id + ); + + DAEMON_CLIENT + .clone() + .reset_service_location(RestartServiceLocationRequest { + instance_id: instance_id.clone(), + pubkey: pubkey.clone(), + }) + .await + .map_err(|err| { + error!( + "Error while restarting service location with pubkey {} on instance {}: {err}", + pubkey, instance_id, + ); + Error::InternalError(err.to_string()) + })?; + + info!( + "Successfully reset the state of service location with pubkey {} on instance {}", + pubkey, instance_id + ); + } + + debug!( + "Completed processing all service locations for instance {}({})", + instance.name, instance.id + ); + } else { + debug!( + "No service locations to process for instance {}({})", + instance.name, instance.id + ); + } + Ok(()) } @@ -851,6 +1049,14 @@ pub async fn delete_instance(instance_id: Id, handle: AppHandle) -> Result<(), E transaction.commit().await?; + DAEMON_CLIENT + .clone() + .delete_service_locations(DeleteServiceLocationsRequest { + instance_id: instance.uuid.clone(), + }) + .await + .unwrap(); + reload_tray_menu(&handle).await; handle.emit(EventKey::InstanceUpdate.into(), ())?; diff --git a/src-tauri/src/database/models/location.rs b/src-tauri/src/database/models/location.rs index 97d20dc6..e7449059 100644 --- a/src-tauri/src/database/models/location.rs +++ b/src-tauri/src/database/models/location.rs @@ -4,7 +4,10 @@ use serde::{Deserialize, Serialize}; use sqlx::{prelude::Type, query, query_as, query_scalar, Error as SqlxError, SqliteExecutor}; use super::{Id, NoId}; -use crate::{error::Error, proto::LocationMfaMode as ProtoLocationMfaMode}; +use crate::{ + error::Error, proto::LocationMfaMode as ProtoLocationMfaMode, + proto::ServiceLocationMode as ProtoServiceLocationMode, +}; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] #[repr(u32)] @@ -27,6 +30,27 @@ impl From for LocationMfaMode { } } +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] +#[repr(u32)] +#[serde(rename_all = "lowercase")] +pub enum ServiceLocationMode { + Disabled = 1, + PreLogon = 2, + AlwaysOn = 3, +} + +impl From for ServiceLocationMode { + fn from(value: ProtoServiceLocationMode) -> Self { + match value { + ProtoServiceLocationMode::Unspecified | ProtoServiceLocationMode::Disabled => { + ServiceLocationMode::Disabled + } + ProtoServiceLocationMode::Prelogon => ServiceLocationMode::PreLogon, + ProtoServiceLocationMode::Alwayson => ServiceLocationMode::AlwaysOn, + } + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct Location { pub id: I, @@ -42,6 +66,7 @@ pub struct Location { pub route_all_traffic: bool, pub keepalive_interval: i64, pub location_mfa_mode: LocationMfaMode, + pub service_location_mode: ServiceLocationMode, } impl fmt::Display for Location { @@ -66,7 +91,7 @@ impl Location { Self, "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,\ route_all_traffic, keepalive_interval, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM location ORDER BY name ASC;" ) .fetch_all(executor) @@ -81,7 +106,7 @@ impl Location { query!( "UPDATE location SET instance_id = $1, name = $2, address = $3, pubkey = $4, \ endpoint = $5, allowed_ips = $6, dns = $7, network_id = $8, route_all_traffic = $9, \ - keepalive_interval = $10, location_mfa_mode = $11 WHERE id = $12", + keepalive_interval = $10, location_mfa_mode = $11, service_location_mode = $12 WHERE id = $13", self.instance_id, self.name, self.address, @@ -93,6 +118,7 @@ impl Location { self.route_all_traffic, self.keepalive_interval, self.location_mfa_mode, + self.service_location_mode, self.id, ) .execute(executor) @@ -112,7 +138,7 @@ impl Location { Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ network_id, route_all_traffic, keepalive_interval, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM location WHERE id = $1", location_id ) @@ -130,7 +156,7 @@ impl Location { query_as!( Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ - network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ + network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM location WHERE instance_id = $1 ORDER BY name ASC", instance_id ) @@ -148,7 +174,7 @@ impl Location { query_as!( Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ - network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" \ + network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM location WHERE pubkey = $1;", pubkey ) @@ -199,8 +225,8 @@ impl Location { // Insert a new record when there is no ID let id = query_scalar!( "INSERT INTO location (instance_id, name, address, pubkey, endpoint, allowed_ips, \ - dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode) \ - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) \ + dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode, service_location_mode) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) \ RETURNING id \"id!\"", self.instance_id, self.name, @@ -212,7 +238,8 @@ impl Location { self.network_id, self.route_all_traffic, self.keepalive_interval, - self.location_mfa_mode + self.location_mfa_mode, + self.service_location_mode, ) .fetch_one(executor) .await?; @@ -230,10 +257,18 @@ impl Location { route_all_traffic: self.route_all_traffic, keepalive_interval: self.keepalive_interval, location_mfa_mode: self.location_mfa_mode, + service_location_mode: self.service_location_mode, }) } } +impl Location { + pub fn is_service_location(&self) -> bool { + self.service_location_mode != ServiceLocationMode::Disabled + && self.location_mfa_mode == LocationMfaMode::Disabled + } +} + impl From> for Location { fn from(location: Location) -> Self { Self { @@ -249,6 +284,7 @@ impl From> for Location { route_all_traffic: location.route_all_traffic, keepalive_interval: location.keepalive_interval, location_mfa_mode: location.location_mfa_mode, + service_location_mode: location.service_location_mode, } } } diff --git a/src-tauri/src/enterprise/mod.rs b/src-tauri/src/enterprise/mod.rs index f9a20825..f8e5f35a 100644 --- a/src-tauri/src/enterprise/mod.rs +++ b/src-tauri/src/enterprise/mod.rs @@ -1,2 +1,3 @@ pub mod models; pub mod periodic; +pub mod service_locations; diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs new file mode 100644 index 00000000..d19324a8 --- /dev/null +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -0,0 +1,95 @@ +use defguard_wireguard_rs::error::WireguardInterfaceError; +use serde::{Deserialize, Serialize}; + +use crate::{ + database::models::{ + location::{Location, LocationMfaMode, ServiceLocationMode}, + Id, + }, + service::proto::{ServiceLocation, ServiceLocationMode as ProtoServiceLocationMode}, +}; + +#[cfg(target_os = "windows")] +pub mod windows; + +#[derive(Debug, thiserror::Error)] +pub enum ServiceLocationError { + #[error("Failed to load service location storage: {0}")] + LoadError(String), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + DecodeError(#[from] base64::DecodeError), + #[error(transparent)] + WireGuardError(#[from] WireguardInterfaceError), + #[error(transparent)] + AddrParseError(#[from] defguard_wireguard_rs::net::IpAddrParseError), + #[error("WireGuard interface error: {0}")] + InterfaceError(String), + #[error(transparent)] + JsonError(#[from] serde_json::Error), + #[error(transparent)] + ProtoEnumError(#[from] prost::UnknownEnumValue), + #[cfg(target_os = "windows")] + #[error(transparent)] + WindowsServiceError(#[from] windows_service::Error), +} + +pub(crate) struct ServiceLocationApi; + +#[derive(Serialize, Deserialize)] +pub(crate) struct ServiceLocationData { + pub service_locations: Vec, + pub instance_id: String, + pub private_key: String, +} + +impl std::fmt::Debug for ServiceLocationData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServiceLocationData") + .field("service_locations", &self.service_locations) + .field("instance_id", &self.instance_id) + .field("private_key", &"***") + .finish() + } +} + +impl Location { + pub fn to_service_location(&self) -> Result { + if !self.is_service_location() { + warn!( + "Location {} is not a service location, so it can't be converted to one.", + self + ); + return Err(crate::error::Error::ConversionError(format!( + "Failed to convert location {} to a service location as it's either not marked as one or has MFA enabled.", + self + ))); + } + + let mode = match self.service_location_mode { + ServiceLocationMode::Disabled => { + warn!( + "Location {} has an invalid service location mode, so it can't be converted to one.", + self + ); + return Err( + crate::error::Error::ConversionError(format!("Location {} has an invalid service location mode ({:?}), so it can't be converted to one.", self, self.service_location_mode)) + ); + } + ServiceLocationMode::PreLogon => 0, + ServiceLocationMode::AlwaysOn => 1, + }; + + Ok(ServiceLocation { + name: self.name.clone(), + address: self.address.clone(), + pubkey: self.pubkey.clone(), + endpoint: self.endpoint.clone(), + allowed_ips: self.allowed_ips.clone(), + dns: self.dns.clone().unwrap_or_default(), + keepalive_interval: self.keepalive_interval.try_into().unwrap_or(0), + mode: mode, + }) + } +} diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs new file mode 100644 index 00000000..e4c9c2b6 --- /dev/null +++ b/src-tauri/src/enterprise/service_locations/windows.rs @@ -0,0 +1,1261 @@ +use std::{ + fs::create_dir_all, net::IpAddr, path::PathBuf, result::Result, str::FromStr, sync::RwLock, + time::Duration, +}; + +use common::{find_free_tcp_port, get_interface_name}; +use defguard_wireguard_rs::{ + host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WireguardInterfaceApi, +}; +use known_folders::get_known_folder_path; +use log::{debug, error, warn}; +use windows::{ + core::PSTR, + Win32::System::RemoteDesktop::{ + self, WTSQuerySessionInformationA, WTSWaitSystemEvent, WTS_CURRENT_SERVER_HANDLE, + WTS_EVENT_LOGOFF, WTS_EVENT_LOGON, WTS_SESSION_INFOA, + }, +}; +use windows_acl::acl::ACL; +use windows_service::{ + service::{ServiceAccess, ServiceState}, + service_manager::{ServiceManager, ServiceManagerAccess}, +}; + +use crate::{ + enterprise::service_locations::{ + ServiceLocationApi, ServiceLocationData, ServiceLocationError, + }, + service::{ + proto::{ServiceLocation, ServiceLocationMode}, + setup_wgapi, + }, +}; + +const LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS: u64 = 5; +const DEFAULT_WIREGUARD_PORT: u16 = 51820; +const CONNECTED_LOCATIONS_FILENAME: &str = "connected_service_locations.json"; +const WIREGUARD_SERVICE_PREFIX: &str = "WireGuardTunnel$"; +const DEFGUARD_DIR: &str = "Defguard"; +const SERVICE_LOCATIONS_SUBDIR: &str = "service_locations"; +const INTERFACE_DOWN_CHECK_INTERVAL_MS: u64 = 100; +const INTERFACE_DOWN_TIMEOUT_MS: u64 = 5000; + +// Tuples of (instance_id, ServiceLocation) - serves as in-memory cache, should be commited to disk first +static CONNECTED_SERVICE_LOCATIONS: RwLock> = + RwLock::new(Vec::new()); + +#[derive(serde::Serialize, serde::Deserialize)] +struct PersistedConnectedLocation { + instance_id: String, + location: ServiceLocation, +} + +fn get_connected_locations_path() -> Result { + let mut path = get_shared_directory()?; + path.push(CONNECTED_LOCATIONS_FILENAME); + Ok(path) +} + +pub(crate) async fn watch_for_login_logoff() -> Result<(), ServiceLocationError> { + unsafe { + loop { + let mut event_mask: u32 = 0; + let success = WTSWaitSystemEvent( + Some(WTS_CURRENT_SERVER_HANDLE), + WTS_EVENT_LOGON | WTS_EVENT_LOGOFF, + &mut event_mask, + ); + + match success { + Ok(_) => { + debug!( + "Waiting for system event returned with event_mask: 0x{:x}", + event_mask, + ); + } + Err(err) => { + error!("Failed waiting for login/logoff event: {:?}", err); + tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)) + .await; + continue; + } + }; + + if event_mask & WTS_EVENT_LOGON != 0 { + debug!( + "Detected user logon, attempting to auto-disconnect from service locations." + ); + ServiceLocationApi::disconnect_service_locations(Some( + ServiceLocationMode::PreLogon, + ))?; + } + if event_mask & WTS_EVENT_LOGOFF != 0 { + debug!("Detected user logoff, attempting to auto-connect to service locations."); + ServiceLocationApi::connect_to_service_locations()?; + } + } + } +} + +fn get_shared_directory() -> Result { + match get_known_folder_path(known_folders::KnownFolder::ProgramData) { + Some(mut path) => { + path.push(DEFGUARD_DIR); + path.push(SERVICE_LOCATIONS_SUBDIR); + Ok(path) + } + None => Err(ServiceLocationError::LoadError( + "Could not find ProgramData known folder".to_string(), + )), + } +} + +fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { + debug!("Setting secure ACLs on: {}", path); + + const SYSTEM_SID: &str = "S-1-5-18"; // NT AUTHORITY\SYSTEM + const ADMINISTRATORS_SID: &str = "S-1-5-32-544"; // BUILTIN\Administrators + + const FILE_ALL_ACCESS: u32 = 0x1F01FF; + + match ACL::from_file_path(path, false) { + Ok(mut acl) => { + // Remove everything else from access + debug!("Removing all existing ACL entries for {}", path); + let all_entries = acl.all().map_err(|e| { + ServiceLocationError::LoadError(format!("Failed to get ACL entries: {}", e)) + })?; + + for entry in all_entries { + if let Some(sid) = entry.sid { + if let Err(e) = acl.remove(sid.as_ptr() as *mut _, None, None) { + debug!( + "Note: Could not remove ACL entry (might be expected): {}", + e + ); + } + } + } + + debug!("Cleared existing ACL entries, now adding secure entries"); + + // Add SYSTEM with full control + debug!("Adding SYSTEM with full control"); + let system_sid_result = windows_acl::helper::string_to_sid(SYSTEM_SID); + match system_sid_result { + Ok(system_sid) => { + acl.allow(system_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) + .map_err(|e| { + ServiceLocationError::LoadError(format!( + "Failed to add SYSTEM ACL: {}", + e + )) + })?; + } + Err(e) => { + return Err(ServiceLocationError::LoadError(format!( + "Failed to convert SYSTEM SID: {}", + e + ))); + } + } + + // Add Administrators with full control + debug!("Adding Administrators with full control"); + let admin_sid_result = windows_acl::helper::string_to_sid(ADMINISTRATORS_SID); + match admin_sid_result { + Ok(admin_sid) => { + acl.allow(admin_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) + .map_err(|e| { + ServiceLocationError::LoadError(format!( + "Failed to add Administrators ACL: {}", + e + )) + })?; + } + Err(e) => { + return Err(ServiceLocationError::LoadError(format!( + "Failed to convert Administrators SID: {}", + e + ))); + } + } + + debug!( + "Successfully set secure ACLs on {} for SYSTEM and Administrators", + path + ); + Ok(()) + } + Err(e) => { + error!("Failed to get ACL for {}: {}", path, e); + Err(ServiceLocationError::LoadError(format!( + "Failed to get ACL for {}: {}", + path, e + ))) + } + } +} + +fn get_instance_file_path(instance_id: &str) -> Result { + let mut path = get_shared_directory()?; + path.push(format!("{}.json", instance_id)); + Ok(path) +} + +pub fn query_connection_status(interface_name: &str) -> Result { + let service_manager = + ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)?; + let service_name = format!("{}{}", WIREGUARD_SERVICE_PREFIX, interface_name); + let service = service_manager.open_service(&service_name, ServiceAccess::QUERY_STATUS)?; + let status = service.query_status()?; + Ok(status.current_state == ServiceState::Running) +} + +/// Wait for an interface to go down (not running) +/// Returns Ok(()) if interface is down within timeout, Err otherwise +async fn wait_for_interface_down( + interface_name: &str, + timeout_ms: u64, +) -> Result<(), ServiceLocationError> { + let start = std::time::Instant::now(); + let timeout = Duration::from_millis(timeout_ms); + let check_interval = Duration::from_millis(INTERFACE_DOWN_CHECK_INTERVAL_MS); + + debug!( + "Waiting for interface '{}' to go down (timeout: {}ms)", + interface_name, timeout_ms + ); + + loop { + match query_connection_status(interface_name) { + Ok(is_running) => { + if !is_running { + debug!("Interface '{}' is now down", interface_name); + return Ok(()); + } + } + Err(_) => { + // If we can't query the status (e.g., service not found), assume it's down + debug!( + "Interface '{}' status query failed, assuming it's down", + interface_name + ); + return Ok(()); + } + } + + if start.elapsed() >= timeout { + let msg = format!( + "Timeout waiting for interface '{}' to go down after {}ms", + interface_name, timeout_ms + ); + error!("{}", msg); + return Err(ServiceLocationError::InterfaceError(msg)); + } + + tokio::time::sleep(check_interval).await; + } +} + +pub(crate) fn is_user_logged_in() -> bool { + debug!("Starting checking if user is logged in..."); + + unsafe { + let mut pp_sessions: *mut WTS_SESSION_INFOA = std::ptr::null_mut(); + let mut count: u32 = 0; + + debug!("Calling WTSEnumerateSessionsA..."); + let ret = RemoteDesktop::WTSEnumerateSessionsA(None, 0, 1, &mut pp_sessions, &mut count); + + match ret { + Ok(_) => { + debug!("WTSEnumerateSessionsA succeeded, found {} sessions", count); + let sessions = std::slice::from_raw_parts(pp_sessions, count as usize); + + for (index, session) in sessions.iter().enumerate() { + debug!( + "Session {}: SessionId={}, State={:?}, WinStationName={:?}", + index, + session.SessionId, + session.State, + std::ffi::CStr::from_ptr(session.pWinStationName.0 as *const i8) + .to_string_lossy() + ); + + if session.State == windows::Win32::System::RemoteDesktop::WTSActive { + let mut buffer = PSTR::null(); + let mut bytes_returned: u32 = 0; + + let result = WTSQuerySessionInformationA( + None, + session.SessionId, + windows::Win32::System::RemoteDesktop::WTSUserName, + &mut buffer, + &mut bytes_returned, + ); + + match result { + Ok(_) => { + if !buffer.is_null() { + let username = std::ffi::CStr::from_ptr(buffer.0 as *const i8) + .to_string_lossy() + .into_owned(); + + debug!( + "Found session {} username: {}", + session.SessionId, username + ); + + windows::Win32::System::RemoteDesktop::WTSFreeMemory( + buffer.0 as *mut _, + ); + + // We found an active session with a username + return true; + } + } + Err(err) => { + debug!( + "Failed to get username for session {}: {:?}", + session.SessionId, err + ); + } + } + } + } + windows::Win32::System::RemoteDesktop::WTSFreeMemory(pp_sessions as _); + debug!("No active sessions found"); + } + Err(err) => { + error!("Failed to enumerate user sessions: {:?}", err); + debug!("WTSEnumerateSessionsA failed: {:?}", err); + } + } + } + + debug!("User is not logged in."); + false +} + +impl ServiceLocationApi { + pub fn init() -> Result<(), ServiceLocationError> { + debug!("Initializing ServiceLocationApi"); + let path = get_shared_directory()?; + + debug!("Creating directory: {:?}", path); + create_dir_all(&path)?; + + if let Some(path_str) = path.to_str() { + debug!("Setting ACLs on service locations directory"); + if let Err(e) = set_protected_acls(path_str) { + warn!( + "Failed to set ACLs on service locations directory: {}. Continuing anyway.", + e + ); + } + } else { + warn!("Failed to convert path to string for ACL setting"); + } + + debug!("Loading and validating connected service locations"); + if let Err(err) = Self::load_and_validate_connected_service_locations() { + debug!( + "Failed to load and validate persisted connected service locations: {:?}", + err + ); + } + + Self::cleanup_invalid_locations()?; + + debug!("ServiceLocationApi initialized successfully"); + Ok(()) + } + + /// Load and validate connected service locations from file into memory cache + /// Verifies that each location is actually running and removes stale entries + fn load_and_validate_connected_service_locations() -> Result<(), ServiceLocationError> { + let path = get_connected_locations_path()?; + if !path.exists() { + CONNECTED_SERVICE_LOCATIONS.write().unwrap().clear(); + return Ok(()); + } + + let data = std::fs::read_to_string(&path)?; + + let persisted: Vec = serde_json::from_str(&data)?; + + let mut validated_locations = Vec::new(); + let mut removed_count = 0; + + for p in persisted { + let interface_name = get_interface_name(&p.location.name); + match query_connection_status(&interface_name) { + Ok(is_running) => { + if is_running { + validated_locations.push((p.instance_id, p.location)); + } else { + debug!( + "Removing stale service location '{}' from connected list - interface not up", + p.location.name + ); + removed_count += 1; + } + } + Err(err) => { + debug!( + "Removing service location '{}' from connected list - failed to query status: {:?}", + p.location.name, err + ); + removed_count += 1; + } + } + } + + // Update memory with validated locations + let mut guard = CONNECTED_SERVICE_LOCATIONS.write().unwrap(); + guard.clear(); + for (instance_id, location) in &validated_locations { + guard.push((instance_id.clone(), location.clone())); + } + drop(guard); + + debug!( + "Loaded {} connected service locations from file into memory ({} stale entries removed)", + validated_locations.len(), + removed_count + ); + + // Save the corrected state back to disk if we removed any stale entries + if removed_count > 0 { + let persisted: Vec = validated_locations + .into_iter() + .map(|(instance_id, location)| PersistedConnectedLocation { + instance_id, + location, + }) + .collect(); + + let json = serde_json::to_string_pretty(&persisted)?; + std::fs::write(&path, json)?; + + // Update ACLs + if let Some(path_str) = path.to_str() { + if let Err(e) = set_protected_acls(path_str) { + warn!( + "Failed to set ACLs on connected service locations file: {}", + e + ); + } + } + + debug!("Saved corrected connected service locations to file"); + } + + Ok(()) + } + + /// Get connected service locations from memory (fast read) + fn get_connected_service_locations() -> Vec<(String, ServiceLocation)> { + CONNECTED_SERVICE_LOCATIONS.read().unwrap().clone() + } + + /// Check if a specific service location is already connected + /// This is a cache lookup, not a live status check + fn is_service_location_connected(instance_id: &str, location_pubkey: &str) -> bool { + CONNECTED_SERVICE_LOCATIONS + .read() + .unwrap() + .iter() + .any(|(inst_id, loc)| inst_id == instance_id && loc.pubkey == location_pubkey) + } + + /// Add a connected service location (writes to disk-first, then memory cache) + fn add_connected_service_location( + instance_id: &str, + location: &ServiceLocation, + ) -> Result<(), ServiceLocationError> { + let mut locations = CONNECTED_SERVICE_LOCATIONS.read().unwrap().clone(); + locations.push((instance_id.to_string(), location.clone())); + + let persisted: Vec = locations + .iter() + .map(|(instance_id, location)| PersistedConnectedLocation { + instance_id: instance_id.clone(), + location: location.clone(), + }) + .collect(); + + let json = serde_json::to_string_pretty(&persisted)?; + let path = get_connected_locations_path()?; + std::fs::write(&path, json)?; + + // Update ACLs + if let Some(path_str) = path.to_str() { + if let Err(e) = set_protected_acls(path_str) { + warn!( + "Failed to set ACLs on connected service locations file after adding: {}", + e + ); + } + } + + // Update memory cache + CONNECTED_SERVICE_LOCATIONS + .write() + .unwrap() + .push((instance_id.to_string(), location.clone())); + + debug!( + "Added connected service location for instance '{}', location '{}'", + instance_id, location.name + ); + Ok(()) + } + + /// Remove connected service locations by filter (write disk-first, then memory) + fn remove_connected_service_locations(filter: F) -> Result<(), ServiceLocationError> + where + F: Fn(&str, &ServiceLocation) -> bool, + { + let mut locations = CONNECTED_SERVICE_LOCATIONS.read().unwrap().clone(); + locations.retain(|(instance_id, location)| !filter(instance_id, location)); + + // Save to disk first + let persisted: Vec = locations + .iter() + .map(|(instance_id, location)| PersistedConnectedLocation { + instance_id: instance_id.clone(), + location: location.clone(), + }) + .collect(); + + let json = serde_json::to_string_pretty(&persisted)?; + let path = get_connected_locations_path()?; + std::fs::write(&path, json)?; + + if let Some(path_str) = path.to_str() { + if let Err(e) = set_protected_acls(path_str) { + warn!( + "Failed to set ACLs on connected service locations file after removing: {}", + e + ); + } + } + + // Then update memory + CONNECTED_SERVICE_LOCATIONS + .write() + .unwrap() + .retain(|(instance_id, location)| !filter(instance_id, location)); + + debug!("Removed connected service locations matching filter"); + Ok(()) + } + + // Resets the state of the service location: + // 1. If it's an always on location, disconnects and reconnects it. + // 2. Otherwise, just disconnects it if the user is not logged in. + pub(crate) async fn reset_service_location_state( + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Reseting the state of service location for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + + let service_location = ServiceLocationApi::load_service_location_by_instance_and_pubkey( + instance_id, + location_pubkey, + )? + .ok_or_else(|| { + ServiceLocationError::LoadError(format!( + "Service location with pubkey {} for instance {} not found", + location_pubkey, instance_id + )) + })?; + + let interface_name = get_interface_name(&service_location.name); + + debug!( + "Disconnecting service location for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + + ServiceLocationApi::disconnect_service_location(instance_id, location_pubkey)?; + + debug!( + "Waiting for interface '{}' to go down before reconnecting...", + interface_name + ); + + // Wait for the interface to actually go down before reconnecting + wait_for_interface_down(&interface_name, INTERFACE_DOWN_TIMEOUT_MS).await?; + + debug!( + "Reconnecting service location if needed for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + + // We should reconnect only if: + // 1. It's an always on location + // 2. It's a pre-logon location and the user is not logged in + if service_location.mode == ServiceLocationMode::AlwaysOn as i32 + || (service_location.mode == ServiceLocationMode::PreLogon as i32 + && !is_user_logged_in()) + { + debug!( + "Reconnecting service location for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + ServiceLocationApi::connect_to_service_location(instance_id, location_pubkey)?; + } + + debug!("Service location state reset completed."); + + Ok(()) + } + + pub(crate) fn disconnect_service_locations_by_instance( + instance_id: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Disconnecting all service locations for instance_id: {}", + instance_id + ); + + let locations = Self::get_connected_service_locations(); + for (connected_instance_id, location) in &locations { + if instance_id == connected_instance_id { + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {}", ifname); + if let Ok(wgapi) = setup_wgapi(&ifname) { + if let Err(err) = wgapi.remove_interface() { + let msg = format!("Failed to remove interface {}: {}", ifname, err); + error!("{}", msg); + debug!("{}", msg); + } else { + debug!("Interface {} removed successfully", ifname); + } + debug!( + "Removing connected service location for instance_id: {}, location_pubkey: {}", + connected_instance_id, location.pubkey + ); + Self::remove_connected_service_locations(|inst_id, loc| { + inst_id == connected_instance_id && loc.pubkey == location.pubkey + })?; + debug!( + "Disconnected service location for instance_id: {}, location_pubkey: {}", + connected_instance_id, location.pubkey + ); + } else { + let msg = format!("Failed to setup WireGuard API for interface {}", ifname); + error!("{}", msg); + debug!("{}", msg); + } + } + } + + debug!( + "Disconnected all service locations for instance_id: {}", + instance_id + ); + + Ok(()) + } + + pub(crate) fn disconnect_service_location( + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Disconnecting service location for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + + let locations = Self::get_connected_service_locations(); + for (connected_instance_id, location) in locations { + if instance_id == connected_instance_id && location.pubkey == location_pubkey { + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {}", ifname); + if let Ok(wgapi) = setup_wgapi(&ifname) { + if let Err(err) = wgapi.remove_interface() { + let msg = format!("Failed to remove interface {}: {}", ifname, err); + error!("{}", msg); + debug!("{}", msg); + } else { + debug!("Interface {} removed successfully.", ifname); + } + debug!( + "Removing connected service location for instance_id: {}, location_pubkey: {}", + connected_instance_id, location.pubkey + ); + Self::remove_connected_service_locations(|inst_id, loc| { + inst_id == connected_instance_id && loc.pubkey == location_pubkey + })?; + debug!( + "Disconnected service location for instance_id: {}, location_pubkey: {}", + connected_instance_id, location.pubkey + ); + } else { + let msg = format!("Failed to setup WireGuard API for interface {}", ifname); + error!("{}", msg); + debug!("{}", msg); + } + + break; + } + } + + debug!( + "Disconnected service location for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + + Ok(()) + } + + /// Helper function to setup a WireGuard interface for a service location + fn setup_service_location_interface( + location: &ServiceLocation, + private_key: &str, + ) -> Result<(), ServiceLocationError> { + let peer_key = Key::from_str(&location.pubkey)?; + + let mut peer = Peer::new(peer_key.clone()); + peer.set_endpoint(&location.endpoint)?; + + peer.persistent_keepalive_interval = location.keepalive_interval.try_into().ok(); + + let allowed_ips = location + .allowed_ips + .split(',') + .map(str::to_string) + .collect::>(); + + for allowed_ip in &allowed_ips { + match IpAddrMask::from_str(allowed_ip) { + Ok(addr) => { + peer.allowed_ips.push(addr); + } + Err(err) => { + error!( + "Error parsing IP address {allowed_ip} while setting up interface for \ + location {location:?}, error details: {err}" + ); + } + } + } + + let mut addresses = Vec::new(); + + for address in location.address.split(',') { + addresses.push(IpAddrMask::from_str(address.trim())?); + } + + let config = InterfaceConfiguration { + name: location.name.clone(), + prvkey: private_key.to_string(), + addresses, + port: find_free_tcp_port().unwrap_or(DEFAULT_WIREGUARD_PORT) as u32, + peers: vec![peer.clone()], + mtu: None, + }; + + let ifname = location.name.clone(); + let ifname = get_interface_name(&ifname); + let wgapi = match setup_wgapi(&ifname) { + Ok(api) => api, + Err(err) => { + let msg = format!( + "Failed to setup WireGuard API for interface {}: {:?}", + ifname, err + ); + debug!("{}", msg); + return Err(ServiceLocationError::InterfaceError(msg)); + } + }; + + // Extract DNS configuration if available + let dns_string = location.dns.clone(); + let dns_entries = dns_string.split(',').map(str::trim).collect::>(); + // We assume that every entry that can't be parsed as an IP address is a domain name. + let mut dns = Vec::new(); + let mut search_domains = Vec::new(); + for entry in dns_entries { + if let Ok(ip) = entry.parse::() { + dns.push(ip); + } else { + search_domains.push(entry); + } + } + + debug!( + "Configuring interface {} with DNS: {:?} and search domains: {:?}", + ifname, dns, search_domains + ); + debug!("Interface Configuration: {:?}", config); + + wgapi.configure_interface(&config, &dns, &search_domains)?; + + debug!("Interface {} configured successfully.", ifname); + Ok(()) + } + + pub(crate) fn connect_to_service_location( + instance_id: &str, + location_pubkey: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Connecting to service location for instance_id: {}, location_pubkey: {}", + instance_id, location_pubkey + ); + + // Check if already connected to this service location + if Self::is_service_location_connected(instance_id, location_pubkey) { + debug!( + "Service location with pubkey {} for instance {} is already connected, skipping", + location_pubkey, instance_id + ); + return Ok(()); + } + + let locations = ServiceLocationApi::load_service_location_by_instance_id(instance_id)?; + let data = ServiceLocationApi::load_service_locations()?; + let instance_data = data + .into_iter() + .find(|d| d.instance_id == instance_id) + .ok_or_else(|| { + ServiceLocationError::LoadError(format!("Instance ID {} not found", instance_id)) + })?; + + for location in locations { + if location.pubkey == location_pubkey { + Self::setup_service_location_interface(&location, &instance_data.private_key)?; + Self::add_connected_service_location(&instance_data.instance_id, &location)?; + let ifname = get_interface_name(&location.name); + debug!("Successfully connected to service location '{}'", ifname); + break; + } + } + + Ok(()) + } + + pub(crate) fn disconnect_service_locations( + mode: Option, + ) -> Result<(), ServiceLocationError> { + debug!("Disconnecting service locations..."); + + let locations = Self::get_connected_service_locations(); + + debug!("Tearing down {} interfaces", locations.len()); + + for (_, location) in &locations { + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location.mode.try_into()?; + if location_mode != m { + debug!( + "Skipping interface {} due to the service location mode doesn't match the requested mode (expected {:?}, found {:?})", + location.name, m, location.mode + ); + continue; + } + } + + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {}", ifname); + if let Ok(wgapi) = setup_wgapi(&ifname) { + if let Err(err) = wgapi.remove_interface() { + let msg = format!("Failed to remove interface {}: {}", ifname, err); + error!("{}", msg); + debug!("{}", msg); + } else { + debug!("Interface {} removed successfully.", ifname); + } + } else { + let msg = format!("Failed to setup WireGuard API for interface {}", ifname); + error!("{}", msg); + debug!("{}", msg); + } + } + + Self::remove_connected_service_locations(|_, location| { + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location + .mode + .try_into() + .unwrap_or(ServiceLocationMode::AlwaysOn); + location_mode == m + } else { + true + } + })?; + + debug!("Service locations disconnected."); + + Ok(()) + } + + pub(crate) fn connect_to_service_locations() -> Result<(), ServiceLocationError> { + debug!("Attempting to auto-connect to VPN..."); + + let data = Self::load_service_locations()?; + debug!("Loaded {} instance(s) from ServiceLocationApi", data.len()); + + for instance_data in data { + debug!( + "Found service locations for instance ID: {}", + instance_data.instance_id + ); + debug!( + "Instance has {} service location(s)", + instance_data.service_locations.len() + ); + for location in instance_data.service_locations { + debug!("Service Location: {:?}", location); + + if location.mode == ServiceLocationMode::PreLogon as i32 { + if is_user_logged_in() { + debug!( + "Skipping pre-logon service location '{}' because user is logged in", + location.name + ); + continue; + } else { + debug!( + "Proceeding to connect pre-logon service location '{}' because no user is logged in", + location.name + ); + } + } + + if Self::is_service_location_connected(&instance_data.instance_id, &location.pubkey) + { + debug!( + "Skipping service location '{}' because it's already connected", + location.name + ); + continue; + } + + if let Err(err) = + Self::setup_service_location_interface(&location, &instance_data.private_key) + { + debug!( + "Failed to setup service location interface for '{}': {:?}", + location.name, err + ); + continue; + } + + if let Err(err) = + Self::add_connected_service_location(&instance_data.instance_id, &location) + { + debug!( + "Failed to persist connected service location after auto-connect: {:?}", + err + ); + } + + let ifname = get_interface_name(&location.name); + debug!("Successfully connected to service location '{}'", ifname); + } + } + + debug!("Auto-connect attempt completed"); + + let current_locations = Self::get_connected_service_locations(); + debug!( + "Currently connected service locations: {:?}", + current_locations + ); + + Ok(()) + } + + pub fn save_service_locations( + service_locations: &[ServiceLocation], + instance_id: &str, + private_key: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Received a request to save {} service location(s) for instance {}", + service_locations.len(), + instance_id + ); + + debug!("Service locations to save: {:?}", service_locations); + + create_dir_all(get_shared_directory()?)?; + + let instance_file_path = get_instance_file_path(instance_id)?; + + let service_location_data = ServiceLocationData { + service_locations: service_locations.to_vec(), + instance_id: instance_id.to_string(), + private_key: private_key.to_string(), + }; + + let json = serde_json::to_string_pretty(&service_location_data)?; + + debug!( + "Writing service location data to file: {:?}", + instance_file_path + ); + + std::fs::write(&instance_file_path, &json)?; + + if let Some(file_path_str) = instance_file_path.to_str() { + debug!("Setting ACLs on service location file: {}", file_path_str); + if let Err(e) = set_protected_acls(file_path_str) { + warn!( + "Failed to set ACLs on service location file {}: {}. File saved but may have insecure permissions.", + file_path_str, e + ); + } else { + debug!("Successfully set ACLs on service location file"); + } + } else { + warn!("Failed to convert file path to string for ACL setting"); + } + + debug!( + "Service locations saved successfully for instance {} to {:?}", + instance_id, instance_file_path + ); + Ok(()) + } + + fn load_service_locations() -> Result, ServiceLocationError> { + let base_dir = get_shared_directory()?; + let mut all_locations_data = Vec::new(); + + if base_dir.exists() { + for entry in std::fs::read_dir(base_dir)? { + let entry = entry?; + let file_path = entry.path(); + + if file_path.is_file() + && file_path.extension().and_then(|s| s.to_str()) == Some("json") + && file_path.file_name().and_then(|s| s.to_str()) + != Some(CONNECTED_LOCATIONS_FILENAME) + { + match std::fs::read_to_string(&file_path) { + Ok(data) => match serde_json::from_str::(&data) { + Ok(locations_data) => { + all_locations_data.push(locations_data); + } + Err(e) => { + error!( + "Failed to parse service locations from file {:?}: {}", + file_path, e + ); + } + }, + Err(e) => { + error!( + "Failed to read service locations file {:?}: {}", + file_path, e + ); + } + } + } + } + } + + debug!( + "Loaded service locations data for {} instances", + all_locations_data.len() + ); + Ok(all_locations_data) + } + + fn load_service_location_by_instance_id( + instance_id: &str, + ) -> Result, ServiceLocationError> { + debug!("Loading service locations for instance {}", instance_id); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + let data = std::fs::read_to_string(&instance_file_path)?; + let service_location_data = serde_json::from_str::(&data)?; + Ok(service_location_data.service_locations) + } else { + debug!( + "No service location file found for instance {}", + instance_id + ); + Ok(Vec::new()) + } + } + + fn load_service_location_by_instance_and_pubkey( + instance_id: &str, + location_pubkey: &str, + ) -> Result, ServiceLocationError> { + debug!( + "Loading service location for instance {} and pubkey {}", + instance_id, location_pubkey + ); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + let data = std::fs::read_to_string(&instance_file_path)?; + let service_location_data = serde_json::from_str::(&data)?; + + for location in service_location_data.service_locations { + if location.pubkey == location_pubkey { + debug!( + "Successfully loaded service location for instance {} and pubkey {}", + instance_id, location_pubkey + ); + return Ok(Some(location)); + } + } + + debug!( + "No service location found for instance {} with pubkey {}", + instance_id, location_pubkey + ); + Ok(None) + } else { + debug!( + "No service location file found for instance {}", + instance_id + ); + Ok(None) + } + } + + pub(crate) fn delete_all_service_locations_for_instance( + instance_id: &str, + ) -> Result<(), ServiceLocationError> { + debug!( + "Deleting all service locations for instance {}", + instance_id + ); + + let instance_file_path = get_instance_file_path(instance_id)?; + + if instance_file_path.exists() { + std::fs::remove_file(&instance_file_path)?; + debug!( + "Successfully deleted all service locations for instance {}", + instance_id + ); + } else { + debug!( + "No service location file found for instance {}", + instance_id + ); + } + + Ok(()) + } + + /// Validates that every running location still exists in the instance files + /// Returns a vector of tuples (instance_id, location_pubkey) for running locations that no longer exist + fn find_invalid_locations() -> Result, ServiceLocationError> { + debug!("Validating that all running locations still exist in instance files"); + + let connected_locations = Self::get_connected_service_locations(); + let mut invalid_locations = Vec::new(); + + for (instance_id, connected_location) in &connected_locations { + debug!( + "Checking if location '{}' (pubkey: {}) for instance '{}' still exists", + connected_location.name, connected_location.pubkey, instance_id + ); + + match ServiceLocationApi::load_service_location_by_instance_and_pubkey( + instance_id, + &connected_location.pubkey, + ) { + Ok(Some(_)) => { + debug!( + "Location '{}' (pubkey: {}) for instance '{}' exists in instance file", + connected_location.name, connected_location.pubkey, instance_id + ); + } + Ok(None) => { + warn!( + "Running location '{}' (pubkey: {}) for instance '{}' no longer exists in instance file", + connected_location.name, connected_location.pubkey, instance_id + ); + invalid_locations + .push((instance_id.clone(), connected_location.pubkey.clone())); + } + Err(err) => { + warn!( + "Failed to load location '{}' (pubkey: {}) for instance '{}': {:?}. Marking as invalid.", + connected_location.name, connected_location.pubkey, instance_id, err + ); + invalid_locations + .push((instance_id.clone(), connected_location.pubkey.clone())); + } + } + } + + if invalid_locations.is_empty() { + debug!("All running locations are valid and exist in instance files"); + } else { + warn!( + "Found {} running location(s) that no longer exist in instance files", + invalid_locations.len() + ); + } + + Ok(invalid_locations) + } + + /// Cleans up invalid running locations that no longer exist in instance files + /// This function will disconnect and remove any running locations that are not found in the instance files + /// Returns the number of locations that were cleaned up + fn cleanup_invalid_locations() -> Result { + debug!("Starting cleanup of invalid running locations"); + + let invalid_locations = Self::find_invalid_locations()?; + + if invalid_locations.is_empty() { + debug!("No invalid locations to clean up"); + return Ok(0); + } + + let cleanup_count = invalid_locations.len(); + debug!("Found {} invalid location(s) to clean up", cleanup_count); + + for (instance_id, location_pubkey) in invalid_locations { + debug!( + "Cleaning up invalid location with pubkey '{}' for instance '{}'", + location_pubkey, instance_id + ); + + match Self::disconnect_service_location(&instance_id, &location_pubkey) { + Ok(_) => { + debug!( + "Successfully cleaned up invalid location '{}' for instance '{}'", + location_pubkey, instance_id + ); + } + Err(err) => { + error!( + "Failed to disconnect invalid location '{}' for instance '{}': {:?}", + location_pubkey, instance_id, err + ); + } + } + } + + debug!( + "Cleanup complete. Disconnected and removed {} invalid location(s)", + cleanup_count + ); + + Ok(cleanup_count) + } +} diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index 4ba25ef5..ac8bbde4 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -2,6 +2,8 @@ use std::net::AddrParseError; use defguard_wireguard_rs::{error::WireguardInterfaceError, net::IpAddrParseError}; +use crate::enterprise::service_locations::ServiceLocationError; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] @@ -44,6 +46,10 @@ pub enum Error { StateLockFail, #[error("Failed to acquire lock on mutex. {0}")] PoisonError(String), + #[error("Failed to convert value. {0}")] + ConversionError(String), + #[error("JSON error: {0}")] + JsonError(#[from] serde_json::Error), } // we must manually implement serde::Serialize diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index e22f4938..d02dbc1b 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -32,7 +32,7 @@ pub mod wg_config; pub mod proto { use crate::database::models::{ - location::{Location, LocationMfaMode as MfaMode}, + location::{Location, LocationMfaMode as MfaMode, ServiceLocationMode as SLocationMode}, Id, NoId, }; @@ -55,6 +55,11 @@ pub mod proto { } }; + let service_location_mode = match self.service_location_mode { + Some(_service_location_mode) => self.service_location_mode().into(), + None => SLocationMode::Disabled, // Default to disabled if not set + }; + Location { id: NoId, instance_id, @@ -68,6 +73,7 @@ pub mod proto { route_all_traffic: false, keepalive_interval: self.keepalive_interval.into(), location_mfa_mode, + service_location_mode, } } } diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 431a79f4..ac306b33 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -48,6 +48,14 @@ use tonic::{ }; use tracing::{debug, error, info, info_span, Instrument}; +use crate::enterprise::service_locations::ServiceLocationError; +#[cfg(not(windows))] +use crate::service::proto::DeleteServiceLocationsRequest; +#[cfg(windows)] +use crate::service::proto::{ + DeleteServiceLocationsRequest, RestartServiceLocationRequest, SaveServiceLocationsRequest, +}; + use self::config::Config; use super::VERSION; @@ -72,6 +80,10 @@ pub enum DaemonError { Unexpected(String), #[error(transparent)] TransportError(#[from] tonic::transport::Error), + #[error(transparent)] + ServiceLocationError(#[from] ServiceLocationError), + #[error(transparent)] + WindowsServiceError(#[from] windows_service::Error), } type IfName = String; @@ -101,7 +113,7 @@ impl DaemonService { type InterfaceDataStream = Pin> + Send>>; -fn setup_wgapi(ifname: &str) -> Result { +pub(crate) fn setup_wgapi(ifname: &str) -> Result { let wgapi = WG::new(ifname).map_err(|err| { let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err}"); error!("{msg}"); @@ -115,6 +127,111 @@ fn setup_wgapi(ifname: &str) -> Result { impl DesktopDaemonService for DaemonService { type ReadInterfaceDataStream = InterfaceDataStream; + #[cfg(not(windows))] + async fn save_service_locations( + &self, + request: tonic::Request, + ) -> Result, Status> { + debug!("Saved service location request received, this is currently not supported on Unix systems"); + Ok(Response::new(())) + } + + #[cfg(not(windows))] + async fn delete_service_locations( + &self, + request: tonic::Request, + ) -> Result, Status> { + debug!("Saved service location request received, this is currently not supported on Unix systems"); + Ok(Response::new(())) + } + + #[cfg(not(windows))] + async fn reset_service_location( + &self, + request: tonic::Request, + ) -> Result, Status> { + debug!("Restart service location request received, this is currently not supported on Unix systems"); + Ok(Response::new(())) + } + + #[cfg(windows)] + async fn save_service_locations( + &self, + request: tonic::Request, + ) -> Result, Status> { + use crate::enterprise::service_locations::ServiceLocationApi; + + debug!("Received a request to save service location"); + let service_location = request.into_inner(); + + match ServiceLocationApi::save_service_locations( + service_location.service_locations.as_slice(), + &service_location.instance_id, + &service_location.private_key, + ) { + Ok(()) => { + debug!("Service location saved successfully"); + Ok(Response::new(())) + } + Err(e) => { + error!("Failed to save service location: {}", e); + Err(Status::internal(format!( + "Failed to save service location: {}", + e + ))) + } + } + } + + #[cfg(windows)] + async fn delete_service_locations( + &self, + request: tonic::Request, + ) -> Result, Status> { + use crate::enterprise::service_locations::ServiceLocationApi; + + debug!("Received a request to delete service location"); + let instance_id = request.into_inner().instance_id; + + ServiceLocationApi::disconnect_service_locations_by_instance(&instance_id).map_err( + |e| { + error!("Failed to disconnect service location: {}", e); + Status::internal(format!("Failed to disconnect service location: {}", e)) + }, + )?; + + match ServiceLocationApi::delete_all_service_locations_for_instance(&instance_id) { + Ok(()) => { + debug!("Service location deleted successfully"); + Ok(Response::new(())) + } + Err(e) => { + error!("Failed to delete service location: {}", e); + Err(Status::internal(format!( + "Failed to delete service location: {}", + e + ))) + } + } + } + + #[cfg(windows)] + async fn reset_service_location( + &self, + request: tonic::Request, + ) -> Result, Status> { + use crate::enterprise::service_locations::ServiceLocationApi; + + let request = request.into_inner(); + ServiceLocationApi::reset_service_location_state(&request.instance_id, &request.pubkey) + .await + .map_err(|e| { + error!("Failed to restart service location: {}", e); + Status::internal(format!("Failed to restart service location: {}", e)) + })?; + Ok(Response::new(())) + } + async fn create_interface( &self, request: tonic::Request, diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index aa79ab5a..7d400a86 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -1,7 +1,21 @@ -use std::{ffi::OsString, sync::mpsc, time::Duration}; +use std::{ + ffi::OsString, + fs::OpenOptions, + net::IpAddr, + result::Result, + str::FromStr, + sync::{mpsc, LazyLock, RwLock}, + time::Duration, +}; +use chrono::Utc; use clap::Parser; -use log::error; +use common::{find_free_tcp_port, get_interface_name}; +use defguard_wireguard_rs::{ + host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WireguardInterfaceApi, +}; +use error; +use std::io::Write; use tokio::runtime::Runtime; use windows_service::{ define_windows_service, @@ -10,15 +24,33 @@ use windows_service::{ ServiceType, }, service_control_handler::{register, ServiceControlHandlerResult}, - service_dispatcher, Result, + service_dispatcher, }; -use crate::service::{run_server, utils::logging_setup, Config}; +use crate::{ + enterprise::service_locations::{windows::watch_for_login_logoff, ServiceLocationApi}, + error::Error, + service::{ + proto::{ServiceLocation, ServiceLocationMode}, + run_server, setup_wgapi, + utils::logging_setup, + Config, DaemonError, + }, + utils::{DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6}, +}; +use windows::{ + core::PSTR, + Win32::System::RemoteDesktop::{ + WTSQuerySessionInformationA, WTSWaitSystemEvent, WTS_CURRENT_SERVER_HANDLE, + WTS_EVENT_LOGOFF, WTS_EVENT_LOGON, WTS_SESSION_INFOA, + }, +}; static SERVICE_NAME: &str = "DefguardService"; const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; +const LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS: u64 = 10; -pub fn run() -> Result<()> { +pub fn run() -> Result<(), windows_service::Error> { // Register generated `ffi_service_main` with the system and start the service, blocking // this thread until the service is stopped. service_dispatcher::start(SERVICE_NAME, ffi_service_main) @@ -33,7 +65,7 @@ pub fn service_main(_arguments: Vec) { } } -fn run_service() -> Result<()> { +fn run_service() -> Result<(), DaemonError> { // Create a channel to be able to poll a stop event from the service worker loop. let (shutdown_tx, shutdown_rx) = mpsc::channel::(); let shutdown_tx_server = shutdown_tx.clone(); @@ -81,6 +113,61 @@ fn run_service() -> Result<()> { std::process::exit(1); })); + + runtime.spawn(async move { + info!("Starting service location management task"); + + match ServiceLocationApi::init() { + Ok(_) => { + info!("Service locations storage initialized successfully"); + } + Err(e) => { + error!( + "Failed to initialize service locations storage: {}. Shutting down service location thread", + e + ); + return; + } + } + + // Attempt to connect to service locations + info!("Attempting to auto-connect to service locations"); + match ServiceLocationApi::connect_to_service_locations() { + Ok(_) => { + info!("Auto-connect to service locations completed successfully"); + } + Err(e) => { + warn!( + "Error while trying to auto-connect to service locations: {}. \ + Will continue monitoring for login/logoff events.", + e + ); + } + } + + // Start watching for login/logoff events with error recovery + info!("Starting login/logoff event monitoring"); + loop { + match watch_for_login_logoff().await { + Ok(_) => { + warn!("Login/logoff event monitoring ended unexpectedly"); + break; + } + Err(e) => { + error!( + "Error in login/logoff event monitoring: {}. Restarting in {} seconds...", + e, LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS + ); + tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS)).await; + info!("Restarting login/logoff event monitoring"); + } + } + } + + warn!("Service location management task terminated"); + }); + + runtime.spawn(async move { let server_result = run_server(config).await; diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index 987c1358..7498fa3b 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -943,6 +943,11 @@ async fn check_connection( pub async fn sync_connections(app_handle: &AppHandle) -> Result<(), Error> { debug!("Synchronizing active connections with the systems' state..."); let all_locations = Location::all(&*DB_POOL).await?; + // filter out service locations as they are managed through the Windows Service + let all_locations: Vec> = all_locations + .into_iter() + .filter(|loc| !loc.is_service_location()) + .collect(); let service_manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT).map_err( |err| { From f0e9f6c07adeecd79adf4a23158e94f874c161bf Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 11:02:49 +0200 Subject: [PATCH 02/13] cleanup, merge with dev --- src-tauri/Cargo.lock | 12 + src-tauri/proto | 2 +- src-tauri/src/commands.rs | 14 +- .../src/enterprise/service_locations/mod.rs | 34 +- .../enterprise/service_locations/windows.rs | 883 +++++------------- src-tauri/src/error.rs | 2 - src-tauri/src/service/mod.rs | 66 +- src-tauri/src/service/windows.rs | 73 +- 8 files changed, 383 insertions(+), 703 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index fbd84c72..8e2748a7 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -8191,6 +8191,18 @@ dependencies = [ "windows-numerics 0.3.1", ] +[[package]] +name = "windows-acl" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "177b1723986bcb4c606058e77f6e8614b51c7f9ad2face6f6fd63dd5c8b3cec3" +dependencies = [ + "field-offset", + "libc", + "widestring 0.4.3", + "winapi", +] + [[package]] name = "windows-collections" version = "0.2.0" diff --git a/src-tauri/proto b/src-tauri/proto index fa9c14ef..3fd150c0 160000 --- a/src-tauri/proto +++ b/src-tauri/proto @@ -1 +1 @@ -Subproject commit fa9c14efd121182ec39c8716370e1250c77fa652 +Subproject commit 3fd150c0245f5ed088ed57ad780a9376e3377ce3 diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 2b602bc2..9ab57f9c 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -21,7 +21,7 @@ use crate::{ models::{ connection::{ActiveConnection, Connection, ConnectionInfo}, instance::{Instance, InstanceInfo}, - location::{Location, LocationMfaMode, ServiceLocationMode}, + location::{Location, LocationMfaMode}, location_stats::LocationStats, tunnel::{Tunnel, TunnelConnection, TunnelConnectionInfo, TunnelStats}, wireguard_keys::WireguardKeys, @@ -29,9 +29,7 @@ use crate::{ }, DB_POOL, }, - enterprise::{ - periodic::config::poll_instance, provisioning::ProvisioningConfig, service_locations, - }, + enterprise::{periodic::config::poll_instance, provisioning::ProvisioningConfig}, error::Error, events::EventKey, log_watcher::{ @@ -41,7 +39,7 @@ use crate::{ proto::DeviceConfigResponse, service::{ proto::{ - DeleteServiceLocationsRequest, RemoveInterfaceRequest, RestartServiceLocationRequest, + DeleteServiceLocationsRequest, RemoveInterfaceRequest, ResetServiceLocationRequest, SaveServiceLocationsRequest, ServiceLocation, }, utils::DAEMON_CLIENT, @@ -296,7 +294,6 @@ pub async fn save_device_config( trace!("Created following locations: {locations:#?}"); let mut service_locations = Vec::::new(); - let mut service_locations_to_restart = Vec::<(String, String)>::new(); for saved_location in &locations { if saved_location.is_service_location() { @@ -342,7 +339,7 @@ pub async fn save_device_config( .collect::>(); for location_pubkey in locations_pubkeys { - let restart_request = RestartServiceLocationRequest { + let restart_request = ResetServiceLocationRequest { instance_id: instance.uuid.clone(), pubkey: location_pubkey.clone(), }; @@ -626,7 +623,6 @@ pub(crate) async fn do_update_instance( ); let mut service_locations = Vec::::new(); - let mut service_locations_to_reset = Vec::<(String, String)>::new(); // check if locations have changed if locations_changed { @@ -760,7 +756,7 @@ pub(crate) async fn do_update_instance( DAEMON_CLIENT .clone() - .reset_service_location(RestartServiceLocationRequest { + .reset_service_location(ResetServiceLocationRequest { instance_id: instance_id.clone(), pubkey: pubkey.clone(), }) diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs index d19324a8..f5008c02 100644 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -1,12 +1,14 @@ -use defguard_wireguard_rs::error::WireguardInterfaceError; +use std::collections::HashMap; + +use defguard_wireguard_rs::{error::WireguardInterfaceError, WGApi}; use serde::{Deserialize, Serialize}; use crate::{ database::models::{ - location::{Location, LocationMfaMode, ServiceLocationMode}, + location::{Location, ServiceLocationMode}, Id, }, - service::proto::{ServiceLocation, ServiceLocationMode as ProtoServiceLocationMode}, + service::proto::ServiceLocation, }; #[cfg(target_os = "windows")] @@ -14,6 +16,8 @@ pub mod windows; #[derive(Debug, thiserror::Error)] pub enum ServiceLocationError { + #[error("Error occurred while initializing service location API: {0}")] + InitError(String), #[error("Failed to load service location storage: {0}")] LoadError(String), #[error(transparent)] @@ -35,7 +39,13 @@ pub enum ServiceLocationError { WindowsServiceError(#[from] windows_service::Error), } -pub(crate) struct ServiceLocationApi; +#[derive(Default)] +pub(crate) struct ServiceLocationManager { + // Interface name: WireGuard API instance + wgapis: HashMap, + // Instance ID: Service locations connected under that instance + connected_service_locations: HashMap>, +} #[derive(Serialize, Deserialize)] pub(crate) struct ServiceLocationData { @@ -44,6 +54,12 @@ pub(crate) struct ServiceLocationData { pub private_key: String, } +pub(crate) struct SingleServiceLocationData { + pub service_location: ServiceLocation, + pub instance_id: String, + pub private_key: String, +} + impl std::fmt::Debug for ServiceLocationData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ServiceLocationData") @@ -54,6 +70,16 @@ impl std::fmt::Debug for ServiceLocationData { } } +impl std::fmt::Debug for SingleServiceLocationData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SingleServiceLocationData") + .field("service_locations", &self.service_location) + .field("instance_id", &self.instance_id) + .field("private_key", &"***") + .finish() + } +} + impl Location { pub fn to_service_location(&self) -> Result { if !self.is_service_location() { diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs index e4c9c2b6..87183260 100644 --- a/src-tauri/src/enterprise/service_locations/windows.rs +++ b/src-tauri/src/enterprise/service_locations/windows.rs @@ -1,5 +1,11 @@ use std::{ - fs::create_dir_all, net::IpAddr, path::PathBuf, result::Result, str::FromStr, sync::RwLock, + collections::HashMap, + fs::{self, create_dir_all}, + net::IpAddr, + path::PathBuf, + result::Result, + str::FromStr, + sync::{Arc, RwLock}, time::Duration, }; @@ -17,14 +23,11 @@ use windows::{ }, }; use windows_acl::acl::ACL; -use windows_service::{ - service::{ServiceAccess, ServiceState}, - service_manager::{ServiceManager, ServiceManagerAccess}, -}; use crate::{ enterprise::service_locations::{ - ServiceLocationApi, ServiceLocationData, ServiceLocationError, + ServiceLocationData, ServiceLocationError, ServiceLocationManager, + SingleServiceLocationData, }, service::{ proto::{ServiceLocation, ServiceLocationMode}, @@ -34,30 +37,12 @@ use crate::{ const LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS: u64 = 5; const DEFAULT_WIREGUARD_PORT: u16 = 51820; -const CONNECTED_LOCATIONS_FILENAME: &str = "connected_service_locations.json"; -const WIREGUARD_SERVICE_PREFIX: &str = "WireGuardTunnel$"; const DEFGUARD_DIR: &str = "Defguard"; const SERVICE_LOCATIONS_SUBDIR: &str = "service_locations"; -const INTERFACE_DOWN_CHECK_INTERVAL_MS: u64 = 100; -const INTERFACE_DOWN_TIMEOUT_MS: u64 = 5000; - -// Tuples of (instance_id, ServiceLocation) - serves as in-memory cache, should be commited to disk first -static CONNECTED_SERVICE_LOCATIONS: RwLock> = - RwLock::new(Vec::new()); - -#[derive(serde::Serialize, serde::Deserialize)] -struct PersistedConnectedLocation { - instance_id: String, - location: ServiceLocation, -} - -fn get_connected_locations_path() -> Result { - let mut path = get_shared_directory()?; - path.push(CONNECTED_LOCATIONS_FILENAME); - Ok(path) -} -pub(crate) async fn watch_for_login_logoff() -> Result<(), ServiceLocationError> { +pub(crate) async fn watch_for_login_logoff( + service_location_manager: Arc>, +) -> Result<(), ServiceLocationError> { unsafe { loop { let mut event_mask: u32 = 0; @@ -69,13 +54,10 @@ pub(crate) async fn watch_for_login_logoff() -> Result<(), ServiceLocationError> match success { Ok(_) => { - debug!( - "Waiting for system event returned with event_mask: 0x{:x}", - event_mask, - ); + debug!("Waiting for system event returned with event_mask: 0x{event_mask:x}"); } Err(err) => { - error!("Failed waiting for login/logoff event: {:?}", err); + error!("Failed waiting for login/logoff event: {err:?}"); tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)) .await; continue; @@ -86,13 +68,19 @@ pub(crate) async fn watch_for_login_logoff() -> Result<(), ServiceLocationError> debug!( "Detected user logon, attempting to auto-disconnect from service locations." ); - ServiceLocationApi::disconnect_service_locations(Some( - ServiceLocationMode::PreLogon, - ))?; + service_location_manager + .clone() + .write() + .unwrap() + .disconnect_service_locations(Some(ServiceLocationMode::PreLogon))?; } if event_mask & WTS_EVENT_LOGOFF != 0 { debug!("Detected user logoff, attempting to auto-connect to service locations."); - ServiceLocationApi::connect_to_service_locations()?; + service_location_manager + .clone() + .write() + .unwrap() + .connect_to_service_locations()?; } } } @@ -112,7 +100,7 @@ fn get_shared_directory() -> Result { } fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { - debug!("Setting secure ACLs on: {}", path); + debug!("Setting secure ACLs on: {path}"); const SYSTEM_SID: &str = "S-1-5-18"; // NT AUTHORITY\SYSTEM const ADMINISTRATORS_SID: &str = "S-1-5-32-544"; // BUILTIN\Administrators @@ -122,18 +110,15 @@ fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { match ACL::from_file_path(path, false) { Ok(mut acl) => { // Remove everything else from access - debug!("Removing all existing ACL entries for {}", path); + debug!("Removing all existing ACL entries for {path}"); let all_entries = acl.all().map_err(|e| { - ServiceLocationError::LoadError(format!("Failed to get ACL entries: {}", e)) + ServiceLocationError::LoadError(format!("Failed to get ACL entries: {e}")) })?; for entry in all_entries { if let Some(sid) = entry.sid { if let Err(e) = acl.remove(sid.as_ptr() as *mut _, None, None) { - debug!( - "Note: Could not remove ACL entry (might be expected): {}", - e - ); + debug!("Note: Could not remove ACL entry (might be expected): {e}"); } } } @@ -148,15 +133,13 @@ fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { acl.allow(system_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) .map_err(|e| { ServiceLocationError::LoadError(format!( - "Failed to add SYSTEM ACL: {}", - e + "Failed to add SYSTEM ACL: {e}" )) })?; } Err(e) => { return Err(ServiceLocationError::LoadError(format!( - "Failed to convert SYSTEM SID: {}", - e + "Failed to convert SYSTEM SID: {e}" ))); } } @@ -169,30 +152,24 @@ fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { acl.allow(admin_sid.as_ptr() as *mut _, true, FILE_ALL_ACCESS) .map_err(|e| { ServiceLocationError::LoadError(format!( - "Failed to add Administrators ACL: {}", - e + "Failed to add Administrators ACL: {e}" )) })?; } Err(e) => { return Err(ServiceLocationError::LoadError(format!( - "Failed to convert Administrators SID: {}", - e + "Failed to convert Administrators SID: {e}" ))); } } - debug!( - "Successfully set secure ACLs on {} for SYSTEM and Administrators", - path - ); + debug!("Successfully set secure ACLs on {path} for SYSTEM and Administrators"); Ok(()) } Err(e) => { - error!("Failed to get ACL for {}: {}", path, e); + error!("Failed to get ACL for {path}: {e}"); Err(ServiceLocationError::LoadError(format!( - "Failed to get ACL for {}: {}", - path, e + "Failed to get ACL for {path}: {e}" ))) } } @@ -200,65 +177,10 @@ fn set_protected_acls(path: &str) -> Result<(), ServiceLocationError> { fn get_instance_file_path(instance_id: &str) -> Result { let mut path = get_shared_directory()?; - path.push(format!("{}.json", instance_id)); + path.push(format!("{instance_id}.json")); Ok(path) } -pub fn query_connection_status(interface_name: &str) -> Result { - let service_manager = - ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)?; - let service_name = format!("{}{}", WIREGUARD_SERVICE_PREFIX, interface_name); - let service = service_manager.open_service(&service_name, ServiceAccess::QUERY_STATUS)?; - let status = service.query_status()?; - Ok(status.current_state == ServiceState::Running) -} - -/// Wait for an interface to go down (not running) -/// Returns Ok(()) if interface is down within timeout, Err otherwise -async fn wait_for_interface_down( - interface_name: &str, - timeout_ms: u64, -) -> Result<(), ServiceLocationError> { - let start = std::time::Instant::now(); - let timeout = Duration::from_millis(timeout_ms); - let check_interval = Duration::from_millis(INTERFACE_DOWN_CHECK_INTERVAL_MS); - - debug!( - "Waiting for interface '{}' to go down (timeout: {}ms)", - interface_name, timeout_ms - ); - - loop { - match query_connection_status(interface_name) { - Ok(is_running) => { - if !is_running { - debug!("Interface '{}' is now down", interface_name); - return Ok(()); - } - } - Err(_) => { - // If we can't query the status (e.g., service not found), assume it's down - debug!( - "Interface '{}' status query failed, assuming it's down", - interface_name - ); - return Ok(()); - } - } - - if start.elapsed() >= timeout { - let msg = format!( - "Timeout waiting for interface '{}' to go down after {}ms", - interface_name, timeout_ms - ); - error!("{}", msg); - return Err(ServiceLocationError::InterfaceError(msg)); - } - - tokio::time::sleep(check_interval).await; - } -} - pub(crate) fn is_user_logged_in() -> bool { debug!("Starting checking if user is logged in..."); @@ -271,13 +193,12 @@ pub(crate) fn is_user_logged_in() -> bool { match ret { Ok(_) => { - debug!("WTSEnumerateSessionsA succeeded, found {} sessions", count); + debug!("WTSEnumerateSessionsA succeeded, found {count} sessions"); let sessions = std::slice::from_raw_parts(pp_sessions, count as usize); for (index, session) in sessions.iter().enumerate() { debug!( - "Session {}: SessionId={}, State={:?}, WinStationName={:?}", - index, + "Session {index}: SessionId={}, State={:?}, WinStationName={:?}", session.SessionId, session.State, std::ffi::CStr::from_ptr(session.pWinStationName.0 as *const i8) @@ -304,8 +225,8 @@ pub(crate) fn is_user_logged_in() -> bool { .into_owned(); debug!( - "Found session {} username: {}", - session.SessionId, username + "Found session {} username: {username}", + session.SessionId ); windows::Win32::System::RemoteDesktop::WTSFreeMemory( @@ -318,8 +239,8 @@ pub(crate) fn is_user_logged_in() -> bool { } Err(err) => { debug!( - "Failed to get username for session {}: {:?}", - session.SessionId, err + "Failed to get username for session {}: {err:?}", + session.SessionId ); } } @@ -329,8 +250,8 @@ pub(crate) fn is_user_logged_in() -> bool { debug!("No active sessions found"); } Err(err) => { - error!("Failed to enumerate user sessions: {:?}", err); - debug!("WTSEnumerateSessionsA failed: {:?}", err); + error!("Failed to enumerate user sessions: {err:?}"); + debug!("WTSEnumerateSessionsA failed: {err:?}"); } } } @@ -339,8 +260,8 @@ pub(crate) fn is_user_logged_in() -> bool { false } -impl ServiceLocationApi { - pub fn init() -> Result<(), ServiceLocationError> { +impl ServiceLocationManager { + pub fn init() -> Result { debug!("Initializing ServiceLocationApi"); let path = get_shared_directory()?; @@ -350,205 +271,75 @@ impl ServiceLocationApi { if let Some(path_str) = path.to_str() { debug!("Setting ACLs on service locations directory"); if let Err(e) = set_protected_acls(path_str) { - warn!( - "Failed to set ACLs on service locations directory: {}. Continuing anyway.", - e - ); + warn!("Failed to set ACLs on service locations directory: {e}. Continuing anyway."); } } else { warn!("Failed to convert path to string for ACL setting"); } - debug!("Loading and validating connected service locations"); - if let Err(err) = Self::load_and_validate_connected_service_locations() { - debug!( - "Failed to load and validate persisted connected service locations: {:?}", - err - ); - } - - Self::cleanup_invalid_locations()?; + let manager = Self { + wgapis: HashMap::new(), + connected_service_locations: HashMap::new(), + }; debug!("ServiceLocationApi initialized successfully"); - Ok(()) + Ok(manager) } - /// Load and validate connected service locations from file into memory cache - /// Verifies that each location is actually running and removes stale entries - fn load_and_validate_connected_service_locations() -> Result<(), ServiceLocationError> { - let path = get_connected_locations_path()?; - if !path.exists() { - CONNECTED_SERVICE_LOCATIONS.write().unwrap().clear(); - return Ok(()); - } - - let data = std::fs::read_to_string(&path)?; - - let persisted: Vec = serde_json::from_str(&data)?; - - let mut validated_locations = Vec::new(); - let mut removed_count = 0; - - for p in persisted { - let interface_name = get_interface_name(&p.location.name); - match query_connection_status(&interface_name) { - Ok(is_running) => { - if is_running { - validated_locations.push((p.instance_id, p.location)); - } else { - debug!( - "Removing stale service location '{}' from connected list - interface not up", - p.location.name - ); - removed_count += 1; - } - } - Err(err) => { - debug!( - "Removing service location '{}' from connected list - failed to query status: {:?}", - p.location.name, err - ); - removed_count += 1; - } - } - } - - // Update memory with validated locations - let mut guard = CONNECTED_SERVICE_LOCATIONS.write().unwrap(); - guard.clear(); - for (instance_id, location) in &validated_locations { - guard.push((instance_id.clone(), location.clone())); - } - drop(guard); - - debug!( - "Loaded {} connected service locations from file into memory ({} stale entries removed)", - validated_locations.len(), - removed_count - ); - - // Save the corrected state back to disk if we removed any stale entries - if removed_count > 0 { - let persisted: Vec = validated_locations - .into_iter() - .map(|(instance_id, location)| PersistedConnectedLocation { - instance_id, - location, - }) - .collect(); - - let json = serde_json::to_string_pretty(&persisted)?; - std::fs::write(&path, json)?; - - // Update ACLs - if let Some(path_str) = path.to_str() { - if let Err(e) = set_protected_acls(path_str) { - warn!( - "Failed to set ACLs on connected service locations file: {}", - e - ); + /// Check if a specific service location is already connected + fn is_service_location_connected(&self, instance_id: &str, location_pubkey: &str) -> bool { + if let Some(locations) = self.connected_service_locations.get(instance_id) { + for location in locations { + if location.pubkey == location_pubkey { + return true; } } - - debug!("Saved corrected connected service locations to file"); } - - Ok(()) + false } - /// Get connected service locations from memory (fast read) - fn get_connected_service_locations() -> Vec<(String, ServiceLocation)> { - CONNECTED_SERVICE_LOCATIONS.read().unwrap().clone() - } - - /// Check if a specific service location is already connected - /// This is a cache lookup, not a live status check - fn is_service_location_connected(instance_id: &str, location_pubkey: &str) -> bool { - CONNECTED_SERVICE_LOCATIONS - .read() - .unwrap() - .iter() - .any(|(inst_id, loc)| inst_id == instance_id && loc.pubkey == location_pubkey) - } - - /// Add a connected service location (writes to disk-first, then memory cache) + /// Add a connected service location fn add_connected_service_location( + &mut self, instance_id: &str, location: &ServiceLocation, ) -> Result<(), ServiceLocationError> { - let mut locations = CONNECTED_SERVICE_LOCATIONS.read().unwrap().clone(); - locations.push((instance_id.to_string(), location.clone())); - - let persisted: Vec = locations - .iter() - .map(|(instance_id, location)| PersistedConnectedLocation { - instance_id: instance_id.clone(), - location: location.clone(), - }) - .collect(); - - let json = serde_json::to_string_pretty(&persisted)?; - let path = get_connected_locations_path()?; - std::fs::write(&path, json)?; - - // Update ACLs - if let Some(path_str) = path.to_str() { - if let Err(e) = set_protected_acls(path_str) { - warn!( - "Failed to set ACLs on connected service locations file after adding: {}", - e - ); - } - } - - // Update memory cache - CONNECTED_SERVICE_LOCATIONS - .write() - .unwrap() - .push((instance_id.to_string(), location.clone())); + self.connected_service_locations + .entry(instance_id.to_string()) + .or_insert_with(Vec::new) + .push(location.clone()); debug!( - "Added connected service location for instance '{}', location '{}'", - instance_id, location.name + "Added connected service location for instance '{instance_id}', location '{}'", + location.name ); Ok(()) } /// Remove connected service locations by filter (write disk-first, then memory) - fn remove_connected_service_locations(filter: F) -> Result<(), ServiceLocationError> + fn remove_connected_service_locations( + &mut self, + filter: F, + ) -> Result<(), ServiceLocationError> where F: Fn(&str, &ServiceLocation) -> bool, { - let mut locations = CONNECTED_SERVICE_LOCATIONS.read().unwrap().clone(); - locations.retain(|(instance_id, location)| !filter(instance_id, location)); - - // Save to disk first - let persisted: Vec = locations - .iter() - .map(|(instance_id, location)| PersistedConnectedLocation { - instance_id: instance_id.clone(), - location: location.clone(), - }) - .collect(); - - let json = serde_json::to_string_pretty(&persisted)?; - let path = get_connected_locations_path()?; - std::fs::write(&path, json)?; + // Iterate through connected_service_locations and remove matching locations + let mut instances_to_remove = Vec::new(); - if let Some(path_str) = path.to_str() { - if let Err(e) = set_protected_acls(path_str) { - warn!( - "Failed to set ACLs on connected service locations file after removing: {}", - e - ); + for (instance_id, locations) in self.connected_service_locations.iter_mut() { + locations.retain(|location| !filter(instance_id, location)); + + // Mark instance for removal if it has no more locations + if locations.is_empty() { + instances_to_remove.push(instance_id.clone()); } } - // Then update memory - CONNECTED_SERVICE_LOCATIONS - .write() - .unwrap() - .retain(|(instance_id, location)| !filter(instance_id, location)); + // Remove instances with no locations + for instance_id in instances_to_remove { + self.connected_service_locations.remove(&instance_id); + } debug!("Removed connected service locations matching filter"); Ok(()) @@ -557,60 +348,53 @@ impl ServiceLocationApi { // Resets the state of the service location: // 1. If it's an always on location, disconnects and reconnects it. // 2. Otherwise, just disconnects it if the user is not logged in. - pub(crate) async fn reset_service_location_state( + pub(crate) fn reset_service_location_state( + &mut self, instance_id: &str, location_pubkey: &str, ) -> Result<(), ServiceLocationError> { debug!( - "Reseting the state of service location for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Reseting the state of service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" ); - let service_location = ServiceLocationApi::load_service_location_by_instance_and_pubkey( - instance_id, - location_pubkey, - )? - .ok_or_else(|| { - ServiceLocationError::LoadError(format!( - "Service location with pubkey {} for instance {} not found", - location_pubkey, instance_id - )) - })?; - - let interface_name = get_interface_name(&service_location.name); + let service_location_data = self + .load_service_location(instance_id, location_pubkey)? + .ok_or_else(|| { + ServiceLocationError::LoadError(format!( + "Service location with pubkey {} for instance {} not found", + location_pubkey, instance_id + )) + })?; debug!( - "Disconnecting service location for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Disconnecting service location for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name ); - ServiceLocationApi::disconnect_service_location(instance_id, location_pubkey)?; + self.disconnect_service_location(instance_id, location_pubkey)?; debug!( - "Waiting for interface '{}' to go down before reconnecting...", - interface_name + "Disconnected service location for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name ); - // Wait for the interface to actually go down before reconnecting - wait_for_interface_down(&interface_name, INTERFACE_DOWN_TIMEOUT_MS).await?; - debug!( - "Reconnecting service location if needed for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Reconnecting service location if needed for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name ); // We should reconnect only if: // 1. It's an always on location // 2. It's a pre-logon location and the user is not logged in - if service_location.mode == ServiceLocationMode::AlwaysOn as i32 - || (service_location.mode == ServiceLocationMode::PreLogon as i32 + if service_location_data.service_location.mode == ServiceLocationMode::AlwaysOn as i32 + || (service_location_data.service_location.mode == ServiceLocationMode::PreLogon as i32 && !is_user_logged_in()) { debug!( - "Reconnecting service location for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Reconnecting service location for instance_id: {instance_id}, location_pubkey: {location_pubkey} ({})", + service_location_data.service_location.name ); - ServiceLocationApi::connect_to_service_location(instance_id, location_pubkey)?; + self.connect_to_service_location(&service_location_data)?; } debug!("Service location state reset completed."); @@ -619,99 +403,91 @@ impl ServiceLocationApi { } pub(crate) fn disconnect_service_locations_by_instance( + &mut self, instance_id: &str, ) -> Result<(), ServiceLocationError> { - debug!( - "Disconnecting all service locations for instance_id: {}", - instance_id - ); + debug!("Disconnecting all service locations for instance_id: {instance_id}"); + + if let Some(locations) = self.connected_service_locations.get(instance_id) { + // Collect locations to disconnect to avoid borrowing issues + let locations_to_disconnect: Vec<_> = locations.iter().cloned().collect(); - let locations = Self::get_connected_service_locations(); - for (connected_instance_id, location) in &locations { - if instance_id == connected_instance_id { + for location in locations_to_disconnect { let ifname = get_interface_name(&location.name); - debug!("Tearing down interface: {}", ifname); - if let Ok(wgapi) = setup_wgapi(&ifname) { + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { if let Err(err) = wgapi.remove_interface() { - let msg = format!("Failed to remove interface {}: {}", ifname, err); - error!("{}", msg); - debug!("{}", msg); + error!("Failed to remove interface {ifname}: {err}"); } else { - debug!("Interface {} removed successfully", ifname); + debug!("Interface {ifname} removed successfully"); } debug!( - "Removing connected service location for instance_id: {}, location_pubkey: {}", - connected_instance_id, location.pubkey - ); - Self::remove_connected_service_locations(|inst_id, loc| { - inst_id == connected_instance_id && loc.pubkey == location.pubkey - })?; + "Removing connected service location for instance_id: {instance_id}, location_pubkey: {}", + location.pubkey + ); debug!( - "Disconnected service location for instance_id: {}, location_pubkey: {}", - connected_instance_id, location.pubkey + "Disconnected service location for instance_id: {instance_id}, location_pubkey: {}", + location.pubkey ); } else { - let msg = format!("Failed to setup WireGuard API for interface {}", ifname); - error!("{}", msg); - debug!("{}", msg); + error!("Failed to find WireGuard API for interface {ifname}"); } } + + self.connected_service_locations.remove(instance_id); + } else { + debug!( + "No connected service locations found for instance_id: {instance_id}. Skipping disconnect" + ); + return Ok(()); } - debug!( - "Disconnected all service locations for instance_id: {}", - instance_id - ); + debug!("Disconnected all service locations for instance_id: {instance_id}"); Ok(()) } pub(crate) fn disconnect_service_location( + &mut self, instance_id: &str, location_pubkey: &str, ) -> Result<(), ServiceLocationError> { debug!( - "Disconnecting service location for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Disconnecting service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" ); - let locations = Self::get_connected_service_locations(); - for (connected_instance_id, location) in locations { - if instance_id == connected_instance_id && location.pubkey == location_pubkey { + if let Some(locations) = self.connected_service_locations.get_mut(instance_id) { + if let Some(pos) = locations + .iter() + .position(|loc| &loc.pubkey == location_pubkey) + { + let location = locations.remove(pos); let ifname = get_interface_name(&location.name); - debug!("Tearing down interface: {}", ifname); - if let Ok(wgapi) = setup_wgapi(&ifname) { + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { if let Err(err) = wgapi.remove_interface() { - let msg = format!("Failed to remove interface {}: {}", ifname, err); - error!("{}", msg); - debug!("{}", msg); + error!("Failed to remove interface {ifname}: {err}"); } else { - debug!("Interface {} removed successfully.", ifname); + debug!("Interface {ifname} removed successfully."); } - debug!( - "Removing connected service location for instance_id: {}, location_pubkey: {}", - connected_instance_id, location.pubkey - ); - Self::remove_connected_service_locations(|inst_id, loc| { - inst_id == connected_instance_id && loc.pubkey == location_pubkey - })?; - debug!( - "Disconnected service location for instance_id: {}, location_pubkey: {}", - connected_instance_id, location.pubkey - ); } else { - let msg = format!("Failed to setup WireGuard API for interface {}", ifname); - error!("{}", msg); - debug!("{}", msg); + error!("Failed to find WireGuard API for interface {ifname}. "); } - - break; + } else { + debug!( + "Service location with pubkey {location_pubkey} for instance {instance_id} is not connected, skipping disconnect" + ); + return Ok(()); } + } else { + debug!( + "No connected service locations found for instance_id: {instance_id}, skipping disconnect" + ); + return Ok(()); } debug!( - "Disconnected service location for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Disconnected service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" ); Ok(()) @@ -719,6 +495,7 @@ impl ServiceLocationApi { /// Helper function to setup a WireGuard interface for a service location fn setup_service_location_interface( + &mut self, location: &ServiceLocation, private_key: &str, ) -> Result<(), ServiceLocationError> { @@ -766,18 +543,17 @@ impl ServiceLocationApi { let ifname = location.name.clone(); let ifname = get_interface_name(&ifname); - let wgapi = match setup_wgapi(&ifname) { + let mut wgapi = match setup_wgapi(&ifname) { Ok(api) => api, Err(err) => { - let msg = format!( - "Failed to setup WireGuard API for interface {}: {:?}", - ifname, err - ); - debug!("{}", msg); + let msg = format!("Failed to setup WireGuard API for interface {ifname}: {err:?}"); + debug!("{msg}"); return Err(ServiceLocationError::InterfaceError(msg)); } }; + wgapi.create_interface()?; + // Extract DNS configuration if available let dns_string = location.dns.clone(); let dns_entries = dns_string.split(',').map(str::trim).collect::>(); @@ -793,96 +569,99 @@ impl ServiceLocationApi { } debug!( - "Configuring interface {} with DNS: {:?} and search domains: {:?}", - ifname, dns, search_domains + "Configuring interface {ifname} with DNS: {:?} and search domains: {:?}", + dns, search_domains ); debug!("Interface Configuration: {:?}", config); - wgapi.configure_interface(&config, &dns, &search_domains)?; + wgapi.configure_interface(&config)?; + wgapi.configure_dns(&dns, &search_domains)?; + + self.wgapis.insert(ifname.clone(), wgapi); - debug!("Interface {} configured successfully.", ifname); + debug!("Interface {ifname} configured successfully."); Ok(()) } pub(crate) fn connect_to_service_location( - instance_id: &str, - location_pubkey: &str, + &mut self, + location_data: &SingleServiceLocationData, ) -> Result<(), ServiceLocationError> { + let instance_id = &location_data.instance_id; + let location_pubkey = &location_data.service_location.pubkey; debug!( - "Connecting to service location for instance_id: {}, location_pubkey: {}", - instance_id, location_pubkey + "Connecting to service location for instance_id: {instance_id}, location_pubkey: {location_pubkey}" ); // Check if already connected to this service location - if Self::is_service_location_connected(instance_id, location_pubkey) { + if self.is_service_location_connected(instance_id, location_pubkey) { debug!( - "Service location with pubkey {} for instance {} is already connected, skipping", - location_pubkey, instance_id + "Service location with pubkey {location_pubkey} for instance {instance_id} is already connected, skipping" ); return Ok(()); } - let locations = ServiceLocationApi::load_service_location_by_instance_id(instance_id)?; - let data = ServiceLocationApi::load_service_locations()?; - let instance_data = data - .into_iter() - .find(|d| d.instance_id == instance_id) + let location_data = self + .load_service_location(instance_id, location_pubkey)? .ok_or_else(|| { - ServiceLocationError::LoadError(format!("Instance ID {} not found", instance_id)) + ServiceLocationError::LoadError(format!( + "Service location with pubkey {} for instance {} not found", + location_pubkey, instance_id + )) })?; - for location in locations { - if location.pubkey == location_pubkey { - Self::setup_service_location_interface(&location, &instance_data.private_key)?; - Self::add_connected_service_location(&instance_data.instance_id, &location)?; - let ifname = get_interface_name(&location.name); - debug!("Successfully connected to service location '{}'", ifname); - break; - } - } + self.setup_service_location_interface( + &location_data.service_location, + &location_data.private_key, + )?; + self.add_connected_service_location( + &location_data.instance_id, + &location_data.service_location, + )?; + let ifname = get_interface_name(&location_data.service_location.name); + debug!("Successfully connected to service location '{ifname}'"); Ok(()) } pub(crate) fn disconnect_service_locations( + &mut self, mode: Option, ) -> Result<(), ServiceLocationError> { - debug!("Disconnecting service locations..."); - - let locations = Self::get_connected_service_locations(); - - debug!("Tearing down {} interfaces", locations.len()); + debug!("Disconnecting service locations with mode: {mode:?}"); - for (_, location) in &locations { - if let Some(m) = mode { - let location_mode: ServiceLocationMode = location.mode.try_into()?; - if location_mode != m { - debug!( - "Skipping interface {} due to the service location mode doesn't match the requested mode (expected {:?}, found {:?})", - location.name, m, location.mode + for (instance, locations) in self.connected_service_locations.iter() { + for location in locations { + debug!( + "Found connected service location for instance_id: {instance}, location_pubkey: {}", + location.pubkey ); - continue; + if let Some(m) = mode { + let location_mode: ServiceLocationMode = location.mode.try_into()?; + if location_mode != m { + debug!( + "Skipping interface {} due to the service location mode doesn't match the requested mode (expected {m:?}, found {:?})", + location.name, location.mode + ); + continue; + } } - } - let ifname = get_interface_name(&location.name); - debug!("Tearing down interface: {}", ifname); - if let Ok(wgapi) = setup_wgapi(&ifname) { - if let Err(err) = wgapi.remove_interface() { - let msg = format!("Failed to remove interface {}: {}", ifname, err); - error!("{}", msg); - debug!("{}", msg); + let ifname = get_interface_name(&location.name); + debug!("Tearing down interface: {ifname}"); + if let Some(mut wgapi) = self.wgapis.remove(&ifname) { + if let Err(err) = wgapi.remove_interface() { + error!("Failed to remove interface {ifname}: {err}"); + } else { + debug!("Interface {ifname} removed successfully."); + } } else { - debug!("Interface {} removed successfully.", ifname); + error!("Failed to find WireGuard API for interface {ifname}"); } - } else { - let msg = format!("Failed to setup WireGuard API for interface {}", ifname); - error!("{}", msg); - debug!("{}", msg); } } - Self::remove_connected_service_locations(|_, location| { + self.remove_connected_service_locations(|_, location| { if let Some(m) = mode { let location_mode: ServiceLocationMode = location .mode @@ -899,10 +678,10 @@ impl ServiceLocationApi { Ok(()) } - pub(crate) fn connect_to_service_locations() -> Result<(), ServiceLocationError> { + pub(crate) fn connect_to_service_locations(&mut self) -> Result<(), ServiceLocationError> { debug!("Attempting to auto-connect to VPN..."); - let data = Self::load_service_locations()?; + let data = self.load_service_locations()?; debug!("Loaded {} instance(s) from ServiceLocationApi", data.len()); for instance_data in data { @@ -932,7 +711,7 @@ impl ServiceLocationApi { } } - if Self::is_service_location_connected(&instance_data.instance_id, &location.pubkey) + if self.is_service_location_connected(&instance_data.instance_id, &location.pubkey) { debug!( "Skipping service location '{}' because it's already connected", @@ -942,37 +721,32 @@ impl ServiceLocationApi { } if let Err(err) = - Self::setup_service_location_interface(&location, &instance_data.private_key) + self.setup_service_location_interface(&location, &instance_data.private_key) { debug!( - "Failed to setup service location interface for '{}': {:?}", - location.name, err + "Failed to setup service location interface for '{}': {err:?}", + location.name ); continue; } if let Err(err) = - Self::add_connected_service_location(&instance_data.instance_id, &location) + self.add_connected_service_location(&instance_data.instance_id, &location) { debug!( - "Failed to persist connected service location after auto-connect: {:?}", - err + "Failed to persist connected service location after auto-connect: {err:?}" ); } - let ifname = get_interface_name(&location.name); - debug!("Successfully connected to service location '{}'", ifname); + debug!( + "Successfully connected to service location '{}'", + location.name + ); } } debug!("Auto-connect attempt completed"); - let current_locations = Self::get_connected_service_locations(); - debug!( - "Currently connected service locations: {:?}", - current_locations - ); - Ok(()) } @@ -982,9 +756,8 @@ impl ServiceLocationApi { private_key: &str, ) -> Result<(), ServiceLocationError> { debug!( - "Received a request to save {} service location(s) for instance {}", + "Received a request to save {} service location(s) for instance {instance_id}", service_locations.len(), - instance_id ); debug!("Service locations to save: {:?}", service_locations); @@ -1006,14 +779,13 @@ impl ServiceLocationApi { instance_file_path ); - std::fs::write(&instance_file_path, &json)?; + fs::write(&instance_file_path, &json)?; if let Some(file_path_str) = instance_file_path.to_str() { - debug!("Setting ACLs on service location file: {}", file_path_str); + debug!("Setting ACLs on service location file: {file_path_str}"); if let Err(e) = set_protected_acls(file_path_str) { warn!( - "Failed to set ACLs on service location file {}: {}. File saved but may have insecure permissions.", - file_path_str, e + "Failed to set ACLs on service location file {file_path_str}: {e}. File saved but may have insecure permissions." ); } else { debug!("Successfully set ACLs on service location file"); @@ -1023,43 +795,38 @@ impl ServiceLocationApi { } debug!( - "Service locations saved successfully for instance {} to {:?}", - instance_id, instance_file_path + "Service locations saved successfully for instance {instance_id} to {:?}", + instance_file_path ); Ok(()) } - fn load_service_locations() -> Result, ServiceLocationError> { + fn load_service_locations(&self) -> Result, ServiceLocationError> { let base_dir = get_shared_directory()?; let mut all_locations_data = Vec::new(); if base_dir.exists() { - for entry in std::fs::read_dir(base_dir)? { + for entry in fs::read_dir(base_dir)? { let entry = entry?; let file_path = entry.path(); if file_path.is_file() && file_path.extension().and_then(|s| s.to_str()) == Some("json") - && file_path.file_name().and_then(|s| s.to_str()) - != Some(CONNECTED_LOCATIONS_FILENAME) { - match std::fs::read_to_string(&file_path) { + match fs::read_to_string(&file_path) { Ok(data) => match serde_json::from_str::(&data) { Ok(locations_data) => { all_locations_data.push(locations_data); } Err(e) => { error!( - "Failed to parse service locations from file {:?}: {}", - file_path, e + "Failed to parse service locations from file {:?}: {e}", + file_path ); } }, Err(e) => { - error!( - "Failed to read service locations file {:?}: {}", - file_path, e - ); + error!("Failed to read service locations file {:?}: {e}", file_path); } } } @@ -1073,189 +840,57 @@ impl ServiceLocationApi { Ok(all_locations_data) } - fn load_service_location_by_instance_id( - instance_id: &str, - ) -> Result, ServiceLocationError> { - debug!("Loading service locations for instance {}", instance_id); - - let instance_file_path = get_instance_file_path(instance_id)?; - - if instance_file_path.exists() { - let data = std::fs::read_to_string(&instance_file_path)?; - let service_location_data = serde_json::from_str::(&data)?; - Ok(service_location_data.service_locations) - } else { - debug!( - "No service location file found for instance {}", - instance_id - ); - Ok(Vec::new()) - } - } - - fn load_service_location_by_instance_and_pubkey( + fn load_service_location( + &self, instance_id: &str, location_pubkey: &str, - ) -> Result, ServiceLocationError> { - debug!( - "Loading service location for instance {} and pubkey {}", - instance_id, location_pubkey - ); + ) -> Result, ServiceLocationError> { + debug!("Loading service location for instance {instance_id} and pubkey {location_pubkey}"); let instance_file_path = get_instance_file_path(instance_id)?; if instance_file_path.exists() { - let data = std::fs::read_to_string(&instance_file_path)?; + let data = fs::read_to_string(&instance_file_path)?; let service_location_data = serde_json::from_str::(&data)?; for location in service_location_data.service_locations { if location.pubkey == location_pubkey { debug!( - "Successfully loaded service location for instance {} and pubkey {}", - instance_id, location_pubkey + "Successfully loaded service location for instance {instance_id} and pubkey {location_pubkey}" ); - return Ok(Some(location)); + return Ok(Some(SingleServiceLocationData { + service_location: location, + instance_id: service_location_data.instance_id, + private_key: service_location_data.private_key, + })); } } debug!( - "No service location found for instance {} with pubkey {}", - instance_id, location_pubkey + "No service location found for instance {instance_id} with pubkey {location_pubkey}" ); Ok(None) } else { - debug!( - "No service location file found for instance {}", - instance_id - ); + debug!("No service location file found for instance {instance_id}"); Ok(None) } } pub(crate) fn delete_all_service_locations_for_instance( + &self, instance_id: &str, ) -> Result<(), ServiceLocationError> { - debug!( - "Deleting all service locations for instance {}", - instance_id - ); + debug!("Deleting all service locations for instance {instance_id}"); let instance_file_path = get_instance_file_path(instance_id)?; if instance_file_path.exists() { - std::fs::remove_file(&instance_file_path)?; - debug!( - "Successfully deleted all service locations for instance {}", - instance_id - ); + fs::remove_file(&instance_file_path)?; + debug!("Successfully deleted all service locations for instance {instance_id}"); } else { - debug!( - "No service location file found for instance {}", - instance_id - ); + debug!("No service location file found for instance {instance_id}"); } Ok(()) } - - /// Validates that every running location still exists in the instance files - /// Returns a vector of tuples (instance_id, location_pubkey) for running locations that no longer exist - fn find_invalid_locations() -> Result, ServiceLocationError> { - debug!("Validating that all running locations still exist in instance files"); - - let connected_locations = Self::get_connected_service_locations(); - let mut invalid_locations = Vec::new(); - - for (instance_id, connected_location) in &connected_locations { - debug!( - "Checking if location '{}' (pubkey: {}) for instance '{}' still exists", - connected_location.name, connected_location.pubkey, instance_id - ); - - match ServiceLocationApi::load_service_location_by_instance_and_pubkey( - instance_id, - &connected_location.pubkey, - ) { - Ok(Some(_)) => { - debug!( - "Location '{}' (pubkey: {}) for instance '{}' exists in instance file", - connected_location.name, connected_location.pubkey, instance_id - ); - } - Ok(None) => { - warn!( - "Running location '{}' (pubkey: {}) for instance '{}' no longer exists in instance file", - connected_location.name, connected_location.pubkey, instance_id - ); - invalid_locations - .push((instance_id.clone(), connected_location.pubkey.clone())); - } - Err(err) => { - warn!( - "Failed to load location '{}' (pubkey: {}) for instance '{}': {:?}. Marking as invalid.", - connected_location.name, connected_location.pubkey, instance_id, err - ); - invalid_locations - .push((instance_id.clone(), connected_location.pubkey.clone())); - } - } - } - - if invalid_locations.is_empty() { - debug!("All running locations are valid and exist in instance files"); - } else { - warn!( - "Found {} running location(s) that no longer exist in instance files", - invalid_locations.len() - ); - } - - Ok(invalid_locations) - } - - /// Cleans up invalid running locations that no longer exist in instance files - /// This function will disconnect and remove any running locations that are not found in the instance files - /// Returns the number of locations that were cleaned up - fn cleanup_invalid_locations() -> Result { - debug!("Starting cleanup of invalid running locations"); - - let invalid_locations = Self::find_invalid_locations()?; - - if invalid_locations.is_empty() { - debug!("No invalid locations to clean up"); - return Ok(0); - } - - let cleanup_count = invalid_locations.len(); - debug!("Found {} invalid location(s) to clean up", cleanup_count); - - for (instance_id, location_pubkey) in invalid_locations { - debug!( - "Cleaning up invalid location with pubkey '{}' for instance '{}'", - location_pubkey, instance_id - ); - - match Self::disconnect_service_location(&instance_id, &location_pubkey) { - Ok(_) => { - debug!( - "Successfully cleaned up invalid location '{}' for instance '{}'", - location_pubkey, instance_id - ); - } - Err(err) => { - error!( - "Failed to disconnect invalid location '{}' for instance '{}': {:?}", - location_pubkey, instance_id, err - ); - } - } - } - - debug!( - "Cleanup complete. Disconnected and removed {} invalid location(s)", - cleanup_count - ); - - Ok(cleanup_count) - } } diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index ac8bbde4..a34c4330 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -2,8 +2,6 @@ use std::net::AddrParseError; use defguard_wireguard_rs::{error::WireguardInterfaceError, net::IpAddrParseError}; -use crate::enterprise::service_locations::ServiceLocationError; - #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 5e07d5ea..9e234de8 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -48,17 +48,19 @@ use tonic::{ }; use tracing::{debug, error, info, info_span, Instrument}; +use self::config::Config; +use super::VERSION; use crate::enterprise::service_locations::ServiceLocationError; #[cfg(not(windows))] use crate::service::proto::DeleteServiceLocationsRequest; #[cfg(windows)] -use crate::service::proto::{ - DeleteServiceLocationsRequest, RestartServiceLocationRequest, SaveServiceLocationsRequest, +use crate::{ + enterprise::service_locations::ServiceLocationManager, + service::proto::{ + DeleteServiceLocationsRequest, ResetServiceLocationRequest, SaveServiceLocationsRequest, + }, }; -use self::config::Config; -use super::VERSION; - #[cfg(windows)] const DAEMON_HTTP_PORT: u16 = 54127; pub(super) const DAEMON_BASE_URL: &str = "http://localhost:54127"; @@ -98,15 +100,22 @@ pub(crate) struct DaemonService { wgapis: Arc>>, stats_period: Duration, stat_tasks: Arc>>>, + #[cfg(windows)] + service_location_manager: Arc>, } impl DaemonService { #[must_use] - pub fn new(config: &Config) -> Self { + pub fn new( + config: &Config, + #[cfg(windows)] service_location_manager: Arc>, + ) -> Self { Self { wgapis: Arc::new(RwLock::new(HashMap::new())), stats_period: Duration::from_secs(config.stats_period), stat_tasks: Arc::new(Mutex::new(HashMap::new())), + #[cfg(windows)] + service_location_manager, } } } @@ -148,7 +157,7 @@ impl DesktopDaemonService for DaemonService { #[cfg(not(windows))] async fn reset_service_location( &self, - request: tonic::Request, + request: tonic::Request, ) -> Result, Status> { debug!("Restart service location request received, this is currently not supported on Unix systems"); Ok(Response::new(())) @@ -159,12 +168,12 @@ impl DesktopDaemonService for DaemonService { &self, request: tonic::Request, ) -> Result, Status> { - use crate::enterprise::service_locations::ServiceLocationApi; + use crate::enterprise::service_locations::ServiceLocationManager; debug!("Received a request to save service location"); let service_location = request.into_inner(); - match ServiceLocationApi::save_service_locations( + match ServiceLocationManager::save_service_locations( service_location.service_locations.as_slice(), &service_location.instance_id, &service_location.private_key, @@ -188,19 +197,26 @@ impl DesktopDaemonService for DaemonService { &self, request: tonic::Request, ) -> Result, Status> { - use crate::enterprise::service_locations::ServiceLocationApi; - debug!("Received a request to delete service location"); let instance_id = request.into_inner().instance_id; - ServiceLocationApi::disconnect_service_locations_by_instance(&instance_id).map_err( - |e| { + self.service_location_manager + .clone() + .write() + .unwrap() + .disconnect_service_locations_by_instance(&instance_id) + .map_err(|e| { error!("Failed to disconnect service location: {}", e); Status::internal(format!("Failed to disconnect service location: {}", e)) - }, - )?; + })?; - match ServiceLocationApi::delete_all_service_locations_for_instance(&instance_id) { + match self + .service_location_manager + .clone() + .read() + .unwrap() + .delete_all_service_locations_for_instance(&instance_id) + { Ok(()) => { debug!("Service location deleted successfully"); Ok(Response::new(())) @@ -218,13 +234,14 @@ impl DesktopDaemonService for DaemonService { #[cfg(windows)] async fn reset_service_location( &self, - request: tonic::Request, + request: tonic::Request, ) -> Result, Status> { - use crate::enterprise::service_locations::ServiceLocationApi; - let request = request.into_inner(); - ServiceLocationApi::reset_service_location_state(&request.instance_id, &request.pubkey) - .await + self.service_location_manager + .clone() + .write() + .unwrap() + .reset_service_location_state(&request.instance_id, &request.pubkey) .map_err(|e| { error!("Failed to restart service location: {}", e); Status::internal(format!("Failed to restart service location: {}", e)) @@ -530,11 +547,14 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { } #[cfg(windows)] -pub async fn run_server(config: Config) -> anyhow::Result<()> { +pub(crate) async fn run_server( + config: Config, + service_location_manager: Arc>, +) -> anyhow::Result<()> { debug!("Starting Defguard interface management daemon"); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), DAEMON_HTTP_PORT); - let daemon_service = DaemonService::new(&config); + let daemon_service = DaemonService::new(&config, service_location_manager); info!("Defguard daemon version {VERSION} started, listening on {addr}",); debug!("Defguard daemon configuration: {config:?}"); diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index 7d400a86..f65d48e5 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -1,21 +1,12 @@ use std::{ ffi::OsString, - fs::OpenOptions, - net::IpAddr, result::Result, - str::FromStr, - sync::{mpsc, LazyLock, RwLock}, + sync::{mpsc, Arc, RwLock}, time::Duration, }; -use chrono::Utc; use clap::Parser; -use common::{find_free_tcp_port, get_interface_name}; -use defguard_wireguard_rs::{ - host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration, WireguardInterfaceApi, -}; use error; -use std::io::Write; use tokio::runtime::Runtime; use windows_service::{ define_windows_service, @@ -28,24 +19,15 @@ use windows_service::{ }; use crate::{ - enterprise::service_locations::{windows::watch_for_login_logoff, ServiceLocationApi}, - error::Error, + enterprise::service_locations::{windows::watch_for_login_logoff, ServiceLocationManager, ServiceLocationError}, service::{ - proto::{ServiceLocation, ServiceLocationMode}, - run_server, setup_wgapi, + run_server, utils::logging_setup, Config, DaemonError, }, - utils::{DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6}, -}; -use windows::{ - core::PSTR, - Win32::System::RemoteDesktop::{ - WTSQuerySessionInformationA, WTSWaitSystemEvent, WTS_CURRENT_SERVER_HANDLE, - WTS_EVENT_LOGOFF, WTS_EVENT_LOGON, WTS_SESSION_INFOA, - }, }; + static SERVICE_NAME: &str = "DefguardService"; const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; const LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS: u64 = 10; @@ -113,26 +95,32 @@ fn run_service() -> Result<(), DaemonError> { std::process::exit(1); })); - + + let service_location_manager = match ServiceLocationManager::init() { + Ok(api) => { + info!("Service locations storage initialized successfully"); + Ok(api) + } + Err(e) => { + error!( + "Failed to initialize service locations storage: {}. Shutting down service location thread", + e + ); + Err(ServiceLocationError::InitError(e.to_string())) + } + }?; + + let service_location_manager = Arc::new(RwLock::new(service_location_manager)); + + let service_location_manager_clone = service_location_manager.clone(); runtime.spawn(async move { info!("Starting service location management task"); - - match ServiceLocationApi::init() { - Ok(_) => { - info!("Service locations storage initialized successfully"); - } - Err(e) => { - error!( - "Failed to initialize service locations storage: {}. Shutting down service location thread", - e - ); - return; - } - } - + + let manager = service_location_manager_clone; + // Attempt to connect to service locations info!("Attempting to auto-connect to service locations"); - match ServiceLocationApi::connect_to_service_locations() { + match manager.write().unwrap().connect_to_service_locations() { Ok(_) => { info!("Auto-connect to service locations completed successfully"); } @@ -148,7 +136,9 @@ fn run_service() -> Result<(), DaemonError> { // Start watching for login/logoff events with error recovery info!("Starting login/logoff event monitoring"); loop { - match watch_for_login_logoff().await { + match watch_for_login_logoff( + manager.clone(), + ).await { Ok(_) => { warn!("Login/logoff event monitoring ended unexpectedly"); break; @@ -168,8 +158,11 @@ fn run_service() -> Result<(), DaemonError> { }); + let service_location_manager_clone = service_location_manager.clone(); runtime.spawn(async move { - let server_result = run_server(config).await; + let server_result = run_server(config, + service_location_manager_clone + ).await; if server_result.is_err() { let _ = shutdown_tx_server.send(2); From 8f754d91f5f0240aff9a4d3db432747d23fca3d6 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 11:08:09 +0200 Subject: [PATCH 03/13] some cleanup --- src-tauri/src/database/models/location.rs | 6 ++++-- src-tauri/src/periodic/mod.rs | 6 +++--- src-tauri/src/service/windows.rs | 21 ++++++--------------- src-tauri/src/utils.rs | 5 ++--- 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src-tauri/src/database/models/location.rs b/src-tauri/src/database/models/location.rs index e7449059..24177bab 100644 --- a/src-tauri/src/database/models/location.rs +++ b/src-tauri/src/database/models/location.rs @@ -5,8 +5,10 @@ use sqlx::{prelude::Type, query, query_as, query_scalar, Error as SqlxError, Sql use super::{Id, NoId}; use crate::{ - error::Error, proto::LocationMfaMode as ProtoLocationMfaMode, - proto::ServiceLocationMode as ProtoServiceLocationMode, + error::Error, + proto::{ + LocationMfaMode as ProtoLocationMfaMode, ServiceLocationMode as ProtoServiceLocationMode, + }, }; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Type)] diff --git a/src-tauri/src/periodic/mod.rs b/src-tauri/src/periodic/mod.rs index 0b4fab97..37daff37 100644 --- a/src-tauri/src/periodic/mod.rs +++ b/src-tauri/src/periodic/mod.rs @@ -1,9 +1,9 @@ -use self::{ - connection::verify_active_connections, purge_stats::purge_stats, version::poll_version, -}; use tauri::AppHandle; use tokio::select; +use self::{ + connection::verify_active_connections, purge_stats::purge_stats, version::poll_version, +}; use crate::enterprise::periodic::config::poll_config; pub mod connection; diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index f65d48e5..aa3cd514 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -19,15 +19,12 @@ use windows_service::{ }; use crate::{ - enterprise::service_locations::{windows::watch_for_login_logoff, ServiceLocationManager, ServiceLocationError}, - service::{ - run_server, - utils::logging_setup, - Config, DaemonError, + enterprise::service_locations::{ + windows::watch_for_login_logoff, ServiceLocationError, ServiceLocationManager, }, + service::{run_server, utils::logging_setup, Config, DaemonError}, }; - static SERVICE_NAME: &str = "DefguardService"; const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; const LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS: u64 = 10; @@ -95,7 +92,6 @@ fn run_service() -> Result<(), DaemonError> { std::process::exit(1); })); - let service_location_manager = match ServiceLocationManager::init() { Ok(api) => { info!("Service locations storage initialized successfully"); @@ -118,7 +114,6 @@ fn run_service() -> Result<(), DaemonError> { let manager = service_location_manager_clone; - // Attempt to connect to service locations info!("Attempting to auto-connect to service locations"); match manager.write().unwrap().connect_to_service_locations() { Ok(_) => { @@ -132,8 +127,7 @@ fn run_service() -> Result<(), DaemonError> { ); } } - - // Start watching for login/logoff events with error recovery + info!("Starting login/logoff event monitoring"); loop { match watch_for_login_logoff( @@ -153,16 +147,13 @@ fn run_service() -> Result<(), DaemonError> { } } } - + warn!("Service location management task terminated"); }); - let service_location_manager_clone = service_location_manager.clone(); runtime.spawn(async move { - let server_result = run_server(config, - service_location_manager_clone - ).await; + let server_result = run_server(config, service_location_manager_clone).await; if server_result.is_err() { let _ = shutdown_tx_server.send(2); diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index 7498fa3b..0b979dbf 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -14,6 +14,8 @@ use windows_service::{ service_manager::{ServiceManager, ServiceManagerAccess}, }; +#[cfg(target_os = "windows")] +use crate::active_connections::find_connection; use crate::{ appstate::AppState, commands::LocationInterfaceDetails, @@ -38,9 +40,6 @@ use crate::{ ConnectionType, }; -#[cfg(target_os = "windows")] -use crate::active_connections::find_connection; - pub(crate) static DEFAULT_ROUTE_IPV4: &str = "0.0.0.0/0"; pub(crate) static DEFAULT_ROUTE_IPV6: &str = "::/0"; From 12e8c9bfded238e05e5b2c1408b932f9b5c1cb2b Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 12:17:15 +0200 Subject: [PATCH 04/13] remove pretty printing --- src-tauri/src/commands.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index 9ab57f9c..c1dfc3fe 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -713,8 +713,6 @@ pub(crate) async fn do_update_instance( private_key: private_key.clone(), }; - debug!("Prepared save request: {save_request:#?}"); - debug!( "Sending request to daemon to save {} service location(s) for instance {}({})", save_request.service_locations.len(), From 2d108a99c2db7f2d2f427a53bc43a06ac235cda3 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 12:34:06 +0200 Subject: [PATCH 05/13] clippy, sqlx --- ...569d3d6f9d8924458c8b357dd400966f4175.json} | 10 ++- ...9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json} | 10 ++- ...5dede3c312162dcc73cea9c883289ba9fa8e.json} | 6 +- ...8c3384709a2bdd3cd769e29b0caf5014d624.json} | 10 ++- ...32a87bd603debccaec23b160150766bdcd9f.json} | 6 +- ...16259e63975a3ec89a3c9b95d833774e9dfef.json | 86 ------------------- .../src/enterprise/service_locations/mod.rs | 6 +- src-tauri/src/service/mod.rs | 18 ++-- 8 files changed, 42 insertions(+), 110 deletions(-) rename src-tauri/.sqlx/{query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json => query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json} (83%) rename src-tauri/.sqlx/{query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json => query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json} (83%) rename src-tauri/.sqlx/{query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json => query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json} (64%) rename src-tauri/.sqlx/{query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json => query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json} (82%) rename src-tauri/.sqlx/{query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json => query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json} (59%) delete mode 100644 src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json diff --git a/src-tauri/.sqlx/query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json b/src-tauri/.sqlx/query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json similarity index 83% rename from src-tauri/.sqlx/query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json rename to src-tauri/.sqlx/query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json index eb35ee44..9c6f65ed 100644 --- a/src-tauri/.sqlx/query-e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11.json +++ b/src-tauri/.sqlx/query-76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location WHERE id = $1", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE id = $1", "describe": { "columns": [ { @@ -62,6 +62,11 @@ "name": "location_mfa_mode: LocationMfaMode", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "service_location_mode: ServiceLocationMode", + "ordinal": 12, + "type_info": "Integer" } ], "parameters": { @@ -79,8 +84,9 @@ false, false, false, + false, false ] }, - "hash": "e91278b90769f39e2cdf1677ffa1193580af693f9871a7162c47393daac8af11" + "hash": "76c5c9b75df39afca9cd07530ab0569d3d6f9d8924458c8b357dd400966f4175" } diff --git a/src-tauri/.sqlx/query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json b/src-tauri/.sqlx/query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json similarity index 83% rename from src-tauri/.sqlx/query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json rename to src-tauri/.sqlx/query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json index f5faadd8..1615e9b4 100644 --- a/src-tauri/.sqlx/query-7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d.json +++ b/src-tauri/.sqlx/query-85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location WHERE instance_id = $1 ORDER BY name ASC", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE pubkey = $1;", "describe": { "columns": [ { @@ -62,6 +62,11 @@ "name": "location_mfa_mode: LocationMfaMode", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "service_location_mode: ServiceLocationMode", + "ordinal": 12, + "type_info": "Integer" } ], "parameters": { @@ -79,8 +84,9 @@ false, false, false, + false, false ] }, - "hash": "7bbc28ee5a141e5b531a6ac5a1cbf120828a0b9c19301c92a3f71531c08c698d" + "hash": "85f8edf373d3bf1d405a8fed804d9d04839e69a6c2c5cb8ad5c2f8e19547a2f6" } diff --git a/src-tauri/.sqlx/query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json b/src-tauri/.sqlx/query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json similarity index 64% rename from src-tauri/.sqlx/query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json rename to src-tauri/.sqlx/query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json index a994e60f..5163c8ca 100644 --- a/src-tauri/.sqlx/query-3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741.json +++ b/src-tauri/.sqlx/query-b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e.json @@ -1,12 +1,12 @@ { "db_name": "SQLite", - "query": "UPDATE location SET instance_id = $1, name = $2, address = $3, pubkey = $4, endpoint = $5, allowed_ips = $6, dns = $7, network_id = $8, route_all_traffic = $9, keepalive_interval = $10, location_mfa_mode = $11 WHERE id = $12", + "query": "UPDATE location SET instance_id = $1, name = $2, address = $3, pubkey = $4, endpoint = $5, allowed_ips = $6, dns = $7, network_id = $8, route_all_traffic = $9, keepalive_interval = $10, location_mfa_mode = $11, service_location_mode = $12 WHERE id = $13", "describe": { "columns": [], "parameters": { - "Right": 12 + "Right": 13 }, "nullable": [] }, - "hash": "3421da72f01d726c2931071203d663b197cb518dd65ec73108f85b2cb7270741" + "hash": "b882379427740576d70c89eaeb815dede3c312162dcc73cea9c883289ba9fa8e" } diff --git a/src-tauri/.sqlx/query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json b/src-tauri/.sqlx/query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json similarity index 82% rename from src-tauri/.sqlx/query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json rename to src-tauri/.sqlx/query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json index 6df78777..d15e39ce 100644 --- a/src-tauri/.sqlx/query-ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f.json +++ b/src-tauri/.sqlx/query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location WHERE pubkey = $1;", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE instance_id = $1 ORDER BY name ASC", "describe": { "columns": [ { @@ -62,6 +62,11 @@ "name": "location_mfa_mode: LocationMfaMode", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "service_location_mode: ServiceLocationMode", + "ordinal": 12, + "type_info": "Integer" } ], "parameters": { @@ -79,8 +84,9 @@ false, false, false, + false, false ] }, - "hash": "ac02b04f6490a768571290d7dc77444eb0ca55a3a7e159c3b2e529ebf75f224f" + "hash": "d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624" } diff --git a/src-tauri/.sqlx/query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json b/src-tauri/.sqlx/query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json similarity index 59% rename from src-tauri/.sqlx/query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json rename to src-tauri/.sqlx/query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json index 3d77f025..a05f49a0 100644 --- a/src-tauri/.sqlx/query-e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58.json +++ b/src-tauri/.sqlx/query-ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "INSERT INTO location (instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id \"id!\"", + "query": "INSERT INTO location (instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode, service_location_mode) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id \"id!\"", "describe": { "columns": [ { @@ -10,11 +10,11 @@ } ], "parameters": { - "Right": 11 + "Right": 12 }, "nullable": [ true ] }, - "hash": "e02047df7deea862cceca537e49ae16a8237e91eff0ee684cacd2ec1c77adb58" + "hash": "ea39145f2cdc783bc78b32363cce32a87bd603debccaec23b160150766bdcd9f" } diff --git a/src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json b/src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json deleted file mode 100644 index e895e552..00000000 --- a/src-tauri/.sqlx/query-f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "db_name": "SQLite", - "query": "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\" FROM location ORDER BY name ASC;", - "describe": { - "columns": [ - { - "name": "id", - "ordinal": 0, - "type_info": "Integer" - }, - { - "name": "instance_id", - "ordinal": 1, - "type_info": "Integer" - }, - { - "name": "name", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "address", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "pubkey", - "ordinal": 4, - "type_info": "Text" - }, - { - "name": "endpoint", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "allowed_ips", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "dns", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "network_id", - "ordinal": 8, - "type_info": "Integer" - }, - { - "name": "route_all_traffic", - "ordinal": 9, - "type_info": "Bool" - }, - { - "name": "keepalive_interval", - "ordinal": 10, - "type_info": "Integer" - }, - { - "name": "location_mfa_mode: LocationMfaMode", - "ordinal": 11, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 0 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - true, - false, - false, - false, - false - ] - }, - "hash": "f660459ee3beed1e88815560c3f16259e63975a3ec89a3c9b95d833774e9dfef" -} diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs index f5008c02..6cd4bba9 100644 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -34,11 +34,12 @@ pub enum ServiceLocationError { JsonError(#[from] serde_json::Error), #[error(transparent)] ProtoEnumError(#[from] prost::UnknownEnumValue), - #[cfg(target_os = "windows")] + #[cfg(windows)] #[error(transparent)] WindowsServiceError(#[from] windows_service::Error), } +#[allow(dead_code)] #[derive(Default)] pub(crate) struct ServiceLocationManager { // Interface name: WireGuard API instance @@ -54,6 +55,7 @@ pub(crate) struct ServiceLocationData { pub private_key: String, } +#[allow(dead_code)] pub(crate) struct SingleServiceLocationData { pub service_location: ServiceLocation, pub instance_id: String, @@ -115,7 +117,7 @@ impl Location { allowed_ips: self.allowed_ips.clone(), dns: self.dns.clone().unwrap_or_default(), keepalive_interval: self.keepalive_interval.try_into().unwrap_or(0), - mode: mode, + mode, }) } } diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 9e234de8..1b0e1ae0 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -51,14 +51,11 @@ use tracing::{debug, error, info, info_span, Instrument}; use self::config::Config; use super::VERSION; use crate::enterprise::service_locations::ServiceLocationError; -#[cfg(not(windows))] -use crate::service::proto::DeleteServiceLocationsRequest; #[cfg(windows)] -use crate::{ - enterprise::service_locations::ServiceLocationManager, - service::proto::{ - DeleteServiceLocationsRequest, ResetServiceLocationRequest, SaveServiceLocationsRequest, - }, +use crate::enterprise::service_locations::ServiceLocationManager; +#[cfg(not(windows))] +use crate::service::proto::{ + DeleteServiceLocationsRequest, ResetServiceLocationRequest, SaveServiceLocationsRequest, }; #[cfg(windows)] @@ -84,6 +81,7 @@ pub enum DaemonError { TransportError(#[from] tonic::transport::Error), #[error(transparent)] ServiceLocationError(#[from] ServiceLocationError), + #[cfg(windows)] #[error(transparent)] WindowsServiceError(#[from] windows_service::Error), } @@ -139,7 +137,7 @@ impl DesktopDaemonService for DaemonService { #[cfg(not(windows))] async fn save_service_locations( &self, - request: tonic::Request, + _request: tonic::Request, ) -> Result, Status> { debug!("Saved service location request received, this is currently not supported on Unix systems"); Ok(Response::new(())) @@ -148,7 +146,7 @@ impl DesktopDaemonService for DaemonService { #[cfg(not(windows))] async fn delete_service_locations( &self, - request: tonic::Request, + _request: tonic::Request, ) -> Result, Status> { debug!("Saved service location request received, this is currently not supported on Unix systems"); Ok(Response::new(())) @@ -157,7 +155,7 @@ impl DesktopDaemonService for DaemonService { #[cfg(not(windows))] async fn reset_service_location( &self, - request: tonic::Request, + _request: tonic::Request, ) -> Result, Status> { debug!("Restart service location request received, this is currently not supported on Unix systems"); Ok(Response::new(())) From 887cf1b89ce4e446473bdee4cd9c87dab349ae0c Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:38:20 +0200 Subject: [PATCH 06/13] review changes --- .../src/enterprise/service_locations/mod.rs | 5 +- .../enterprise/service_locations/windows.rs | 76 +++++++++---------- 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs index 6cd4bba9..7cdb4577 100644 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -85,10 +85,7 @@ impl std::fmt::Debug for SingleServiceLocationData { impl Location { pub fn to_service_location(&self) -> Result { if !self.is_service_location() { - warn!( - "Location {} is not a service location, so it can't be converted to one.", - self - ); + warn!("Location {self} is not a service location, so it can't be converted to one."); return Err(crate::error::Error::ConversionError(format!( "Failed to convert location {} to a service location as it's either not marked as one or has MFA enabled.", self diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs index 87183260..9cca0ae4 100644 --- a/src-tauri/src/enterprise/service_locations/windows.rs +++ b/src-tauri/src/enterprise/service_locations/windows.rs @@ -43,45 +43,42 @@ const SERVICE_LOCATIONS_SUBDIR: &str = "service_locations"; pub(crate) async fn watch_for_login_logoff( service_location_manager: Arc>, ) -> Result<(), ServiceLocationError> { - unsafe { - loop { - let mut event_mask: u32 = 0; - let success = WTSWaitSystemEvent( + loop { + let mut event_flags = 0; + let success = unsafe { + WTSWaitSystemEvent( Some(WTS_CURRENT_SERVER_HANDLE), WTS_EVENT_LOGON | WTS_EVENT_LOGOFF, - &mut event_mask, - ); - - match success { - Ok(_) => { - debug!("Waiting for system event returned with event_mask: 0x{event_mask:x}"); - } - Err(err) => { - error!("Failed waiting for login/logoff event: {err:?}"); - tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)) - .await; - continue; - } - }; + &mut event_flags, + ) + }; - if event_mask & WTS_EVENT_LOGON != 0 { - debug!( - "Detected user logon, attempting to auto-disconnect from service locations." - ); - service_location_manager - .clone() - .write() - .unwrap() - .disconnect_service_locations(Some(ServiceLocationMode::PreLogon))?; + match success { + Ok(_) => { + debug!("Waiting for system event returned with event_flags: 0x{event_flags:x}"); } - if event_mask & WTS_EVENT_LOGOFF != 0 { - debug!("Detected user logoff, attempting to auto-connect to service locations."); - service_location_manager - .clone() - .write() - .unwrap() - .connect_to_service_locations()?; + Err(err) => { + error!("Failed waiting for login/logoff event: {err:?}"); + tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_EVENT_RETRY_DELAY_SECS)).await; + continue; } + }; + + if event_flags & WTS_EVENT_LOGON != 0 { + debug!("Detected user logon, attempting to auto-disconnect from service locations."); + service_location_manager + .clone() + .write() + .unwrap() + .disconnect_service_locations(Some(ServiceLocationMode::PreLogon))?; + } + if event_flags & WTS_EVENT_LOGOFF != 0 { + debug!("Detected user logoff, attempting to auto-connect to service locations."); + service_location_manager + .clone() + .write() + .unwrap() + .connect_to_service_locations()?; } } } @@ -265,7 +262,7 @@ impl ServiceLocationManager { debug!("Initializing ServiceLocationApi"); let path = get_shared_directory()?; - debug!("Creating directory: {:?}", path); + debug!("Creating directory: {path:?}"); create_dir_all(&path)?; if let Some(path_str) = path.to_str() { @@ -694,7 +691,7 @@ impl ServiceLocationManager { instance_data.service_locations.len() ); for location in instance_data.service_locations { - debug!("Service Location: {:?}", location); + debug!("Service Location: {location:?}"); if location.mode == ServiceLocationMode::PreLogon as i32 { if is_user_logged_in() { @@ -760,7 +757,7 @@ impl ServiceLocationManager { service_locations.len(), ); - debug!("Service locations to save: {:?}", service_locations); + debug!("Service locations to save: {service_locations:?}"); create_dir_all(get_shared_directory()?)?; @@ -774,10 +771,7 @@ impl ServiceLocationManager { let json = serde_json::to_string_pretty(&service_location_data)?; - debug!( - "Writing service location data to file: {:?}", - instance_file_path - ); + debug!("Writing service location data to file: {instance_file_path:?}"); fs::write(&instance_file_path, &json)?; From 8257f063f5e2d22d26835a4331cbeb785275c33d Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:39:28 +0200 Subject: [PATCH 07/13] more cleanup --- src-tauri/src/enterprise/service_locations/mod.rs | 2 +- src-tauri/src/utils.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs index 7cdb4577..5fba9784 100644 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -11,7 +11,7 @@ use crate::{ service::proto::ServiceLocation, }; -#[cfg(target_os = "windows")] +#[cfg(windows)] pub mod windows; #[derive(Debug, thiserror::Error)] diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index 0b979dbf..dbb63a88 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -6,15 +6,15 @@ use sqlx::query; use tauri::{AppHandle, Emitter, Manager}; use tonic::Code; use tracing::Level; -#[cfg(target_os = "windows")] +#[cfg(windows)] use winapi::shared::winerror::ERROR_SERVICE_DOES_NOT_EXIST; -#[cfg(target_os = "windows")] +#[cfg(windows)] use windows_service::{ service::{ServiceAccess, ServiceState}, service_manager::{ServiceManager, ServiceManagerAccess}, }; -#[cfg(target_os = "windows")] +#[cfg(windows)] use crate::active_connections::find_connection; use crate::{ appstate::AppState, From b5c60d2d401d6dc160f44792460135117121529c Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:43:29 +0200 Subject: [PATCH 08/13] further review changes --- src-tauri/src/service/mod.rs | 13 ++++++------- src-tauri/src/service/windows.rs | 10 ++++------ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 1b0e1ae0..9aa1a13c 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -181,11 +181,9 @@ impl DesktopDaemonService for DaemonService { Ok(Response::new(())) } Err(e) => { - error!("Failed to save service location: {}", e); - Err(Status::internal(format!( - "Failed to save service location: {}", - e - ))) + let msg = format!("Failed to save service location: {e}"); + error!(msg); + Err(Status::internal(msg)) } } } @@ -204,8 +202,9 @@ impl DesktopDaemonService for DaemonService { .unwrap() .disconnect_service_locations_by_instance(&instance_id) .map_err(|e| { - error!("Failed to disconnect service location: {}", e); - Status::internal(format!("Failed to disconnect service location: {}", e)) + let msg = format!("Failed to disconnect service location: {e}"); + error!(msg); + Status::internal(msg) })?; match self diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index aa3cd514..2f006d38 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -27,7 +27,7 @@ use crate::{ static SERVICE_NAME: &str = "DefguardService"; const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; -const LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS: u64 = 10; +const LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS: Duration = Duration::from_secs(5); pub fn run() -> Result<(), windows_service::Error> { // Register generated `ffi_service_main` with the system and start the service, blocking @@ -121,9 +121,8 @@ fn run_service() -> Result<(), DaemonError> { } Err(e) => { warn!( - "Error while trying to auto-connect to service locations: {}. \ + "Error while trying to auto-connect to service locations: {e}. \ Will continue monitoring for login/logoff events.", - e ); } } @@ -139,10 +138,9 @@ fn run_service() -> Result<(), DaemonError> { } Err(e) => { error!( - "Error in login/logoff event monitoring: {}. Restarting in {} seconds...", - e, LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS + "Error in login/logoff event monitoring: {e}. Restarting in {LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS} seconds...", ); - tokio::time::sleep(Duration::from_secs(LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS)).await; + tokio::time::sleep(LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS).await; info!("Restarting login/logoff event monitoring"); } } From ef97822a7359f5dd0802d9556aef986cb8310d67 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:44:15 +0200 Subject: [PATCH 09/13] clippy --- src-tauri/src/enterprise/service_locations/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src-tauri/src/enterprise/service_locations/mod.rs b/src-tauri/src/enterprise/service_locations/mod.rs index 5fba9784..3a2b1098 100644 --- a/src-tauri/src/enterprise/service_locations/mod.rs +++ b/src-tauri/src/enterprise/service_locations/mod.rs @@ -48,6 +48,7 @@ pub(crate) struct ServiceLocationManager { connected_service_locations: HashMap>, } +#[allow(dead_code)] #[derive(Serialize, Deserialize)] pub(crate) struct ServiceLocationData { pub service_locations: Vec, From 2f9ac4b7b5250eb0146bafa467cc93a925a05ca7 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Thu, 23 Oct 2025 23:12:14 +0200 Subject: [PATCH 10/13] review changes --- src-tauri/src/active_connections.rs | 2 +- src-tauri/src/commands.rs | 112 +++--------------- src-tauri/src/database/models/location.rs | 25 ++-- .../enterprise/service_locations/windows.rs | 6 +- src-tauri/src/service/mod.rs | 77 ++++++------ src-tauri/src/service/windows.rs | 91 ++++++++------ src-tauri/src/utils.rs | 7 +- 7 files changed, 134 insertions(+), 186 deletions(-) diff --git a/src-tauri/src/active_connections.rs b/src-tauri/src/active_connections.rs index ed0ef4b2..970a31bd 100644 --- a/src-tauri/src/active_connections.rs +++ b/src-tauri/src/active_connections.rs @@ -82,7 +82,7 @@ pub(crate) async fn find_connection( pub(crate) async fn active_connections( instance: &Instance, ) -> Result, Error> { - let locations: HashSet = Location::find_by_instance_id(&*DB_POOL, instance.id) + let locations: HashSet = Location::find_by_instance_id(&*DB_POOL, instance.id, false) .await? .iter() .map(|location| location.id) diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index c1dfc3fe..2e43bd75 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -39,8 +39,8 @@ use crate::{ proto::DeviceConfigResponse, service::{ proto::{ - DeleteServiceLocationsRequest, RemoveInterfaceRequest, ResetServiceLocationRequest, - SaveServiceLocationsRequest, ServiceLocation, + DeleteServiceLocationsRequest, RemoveInterfaceRequest, SaveServiceLocationsRequest, + ServiceLocation, }, utils::DAEMON_CLIENT, }, @@ -290,7 +290,7 @@ pub async fn save_device_config( transaction.commit().await?; info!("New instance {instance} created."); trace!("Created following instance: {instance:#?}"); - let locations = Location::find_by_instance_id(&*DB_POOL, instance.id).await?; + let locations = Location::find_by_instance_id(&*DB_POOL, instance.id, true).await?; trace!("Created following locations: {locations:#?}"); let mut service_locations = Vec::::new(); @@ -332,36 +332,6 @@ pub async fn save_device_config( "Saved service locations to the daemon for instance {}({}).", instance.name, instance.id, ); - - let locations_pubkeys = service_locations - .iter() - .map(|loc| loc.pubkey.clone()) - .collect::>(); - - for location_pubkey in locations_pubkeys { - let restart_request = ResetServiceLocationRequest { - instance_id: instance.uuid.clone(), - pubkey: location_pubkey.clone(), - }; - debug!( - "Restarting service location with pubkey {} on instance {}.", - restart_request.pubkey, restart_request.instance_id, - ); - DAEMON_CLIENT.clone() - .reset_service_location(restart_request) - .await - .map_err(|err| { - error!( - "Error while restarting service location with pubkey {} on instance {}: {err}", - location_pubkey, instance.uuid, - ); - Error::InternalError(err.to_string()) - })?; - debug!( - "Restarted service location with pubkey {} on instance {}.", - location_pubkey, instance.uuid - ); - } } handle.emit(EventKey::InstanceUpdate.into(), ())?; @@ -386,7 +356,7 @@ pub async fn all_instances() -> Result>, Error> { let mut instance_info = Vec::new(); let connection_ids = get_connection_id_by_type(ConnectionType::Location).await; for instance in instances { - let locations = Location::find_by_instance_id(&*DB_POOL, instance.id).await?; + let locations = Location::find_by_instance_id(&*DB_POOL, instance.id, false).await?; let location_ids: Vec = locations.iter().map(|location| location.id).collect(); let connected = connection_ids .iter() @@ -460,7 +430,7 @@ pub async fn all_locations(instance_id: Id) -> Result, Error> "Getting information about all locations for instance {}.", instance.name ); - let locations = Location::find_by_instance_id(&*DB_POOL, instance_id).await?; + let locations = Location::find_by_instance_id(&*DB_POOL, instance_id, false).await?; trace!( "Found {} locations for instance {instance} to return information about.", locations.len() @@ -468,16 +438,6 @@ pub async fn all_locations(instance_id: Id) -> Result, Error> let active_locations_ids = get_connection_id_by_type(ConnectionType::Location).await; let mut location_info = Vec::new(); for location in locations { - // Skip service locations, those shouldn't be shown in the UI. - if location.is_service_location() { - debug!( - "Skipping service location {}({}) for instance {}({}) when returning \ - locations to the frontend.", - location.name, location.id, instance.name, instance.id, - ); - continue; - } - let info = LocationInfo { id: location.id, instance_id: location.instance_id, @@ -560,7 +520,7 @@ pub(crate) async fn locations_changed( device_config: &DeviceConfigResponse, ) -> Result { let db_locations: HashSet> = - Location::find_by_instance_id(transaction.as_mut(), instance.id) + Location::find_by_instance_id(transaction.as_mut(), instance.id, true) .await? .into_iter() .map(|location| { @@ -633,14 +593,13 @@ pub(crate) async fn do_update_instance( ); // fetch existing locations for given instance let mut current_locations = - Location::find_by_instance_id(transaction.as_mut(), instance.id).await?; + Location::find_by_instance_id(transaction.as_mut(), instance.id, true).await?; for dev_config in response.configs { // parse device config let new_location = dev_config.into_location(instance.id); - let saved_location: Location; // check if location is already present in current locations - if let Some(position) = current_locations + let saved_location = if let Some(position) = current_locations .iter() .position(|loc| loc.network_id == new_location.network_id) { @@ -662,14 +621,14 @@ pub(crate) async fn do_update_instance( current_location.service_location_mode = new_location.service_location_mode; current_location.save(transaction.as_mut()).await?; info!("Location {current_location} configuration updated for instance {instance}"); - saved_location = current_location; + current_location } else { // create new location debug!("Creating new location {new_location} for instance instance {instance}"); let new_location = new_location.save(transaction.as_mut()).await?; info!("New location {new_location} created for instance {instance}"); - saved_location = new_location; - } + new_location + }; if saved_location.is_service_location() { debug!( @@ -699,7 +658,12 @@ pub(crate) async fn do_update_instance( .ok_or(Error::NotFound)? .prvkey; - if !service_locations.is_empty() { + if service_locations.is_empty() { + debug!( + "No service locations to process for instance {}({})", + instance.name, instance.id + ); + } else { debug!( "Processing {} service location(s) for instance {}({})", service_locations.len(), @@ -739,49 +703,10 @@ pub(crate) async fn do_update_instance( instance.id ); - let service_locations_pubkeys = service_locations - .iter() - .map(|loc| loc.pubkey.clone()) - .collect::>(); - - let instance_id = instance.uuid.clone(); - - for pubkey in service_locations_pubkeys { - debug!( - "Sending state reset request for service location with pubkey {} on instance {}", - pubkey, instance_id - ); - - DAEMON_CLIENT - .clone() - .reset_service_location(ResetServiceLocationRequest { - instance_id: instance_id.clone(), - pubkey: pubkey.clone(), - }) - .await - .map_err(|err| { - error!( - "Error while restarting service location with pubkey {} on instance {}: {err}", - pubkey, instance_id, - ); - Error::InternalError(err.to_string()) - })?; - - info!( - "Successfully reset the state of service location with pubkey {} on instance {}", - pubkey, instance_id - ); - } - debug!( "Completed processing all service locations for instance {}({})", instance.name, instance.id ); - } else { - debug!( - "No service locations to process for instance {}({})", - instance.name, instance.id - ); } Ok(()) @@ -1007,7 +932,8 @@ pub async fn delete_instance(instance_id: Id, handle: AppHandle) -> Result<(), E }; debug!("The instance that is being deleted has been identified as {instance}"); - let instance_locations = Location::find_by_instance_id(&mut *transaction, instance_id).await?; + let instance_locations = + Location::find_by_instance_id(&mut *transaction, instance_id, false).await?; if !instance_locations.is_empty() { debug!( "Found locations associated with the instance {instance}, closing their connections." diff --git a/src-tauri/src/database/models/location.rs b/src-tauri/src/database/models/location.rs index 24177bab..005b5fdb 100644 --- a/src-tauri/src/database/models/location.rs +++ b/src-tauri/src/database/models/location.rs @@ -84,17 +84,24 @@ impl fmt::Display for Location { } impl Location { + /// Ignores service locations #[cfg(windows)] - pub(crate) async fn all<'e, E>(executor: E) -> Result, SqlxError> + pub(crate) async fn all<'e, E>( + executor: E, + include_service_locations: bool, + ) -> Result, SqlxError> where E: SqliteExecutor<'e>, { + let max_mode = if include_service_locations { 2 } else { 0 }; // 0 to exclude service locations, 2 to include them query_as!( Self, - "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,\ - route_all_traffic, keepalive_interval, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ - FROM location ORDER BY name ASC;" + "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,\ + route_all_traffic, keepalive_interval, \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ + FROM location WHERE service_location_mode <= $1 \ + ORDER BY name ASC;", + max_mode ) .fetch_all(executor) .await @@ -151,16 +158,20 @@ impl Location { pub(crate) async fn find_by_instance_id<'e, E>( executor: E, instance_id: Id, + include_service_locations: bool, ) -> Result, SqlxError> where E: SqliteExecutor<'e>, { + let max_mode = if include_service_locations { 2 } else { 0 }; // 0 to exclude service locations, 2 to include them query_as!( Self, "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ - FROM location WHERE instance_id = $1 ORDER BY name ASC", - instance_id + FROM location WHERE instance_id = $1 AND service_location_mode <= $2 \ + ORDER BY name ASC", + instance_id, + max_mode ) .fetch_all(executor) .await diff --git a/src-tauri/src/enterprise/service_locations/windows.rs b/src-tauri/src/enterprise/service_locations/windows.rs index 9cca0ae4..b8eb8b8c 100644 --- a/src-tauri/src/enterprise/service_locations/windows.rs +++ b/src-tauri/src/enterprise/service_locations/windows.rs @@ -700,12 +700,11 @@ impl ServiceLocationManager { location.name ); continue; - } else { - debug!( + } + debug!( "Proceeding to connect pre-logon service location '{}' because no user is logged in", location.name ); - } } if self.is_service_location_connected(&instance_data.instance_id, &location.pubkey) @@ -748,6 +747,7 @@ impl ServiceLocationManager { } pub fn save_service_locations( + &self, service_locations: &[ServiceLocation], instance_id: &str, private_key: &str, diff --git a/src-tauri/src/service/mod.rs b/src-tauri/src/service/mod.rs index 9aa1a13c..c4de34ea 100644 --- a/src-tauri/src/service/mod.rs +++ b/src-tauri/src/service/mod.rs @@ -53,10 +53,7 @@ use super::VERSION; use crate::enterprise::service_locations::ServiceLocationError; #[cfg(windows)] use crate::enterprise::service_locations::ServiceLocationManager; -#[cfg(not(windows))] -use crate::service::proto::{ - DeleteServiceLocationsRequest, ResetServiceLocationRequest, SaveServiceLocationsRequest, -}; +use crate::service::proto::{DeleteServiceLocationsRequest, SaveServiceLocationsRequest}; #[cfg(windows)] const DAEMON_HTTP_PORT: u16 = 54127; @@ -152,40 +149,58 @@ impl DesktopDaemonService for DaemonService { Ok(Response::new(())) } - #[cfg(not(windows))] - async fn reset_service_location( - &self, - _request: tonic::Request, - ) -> Result, Status> { - debug!("Restart service location request received, this is currently not supported on Unix systems"); - Ok(Response::new(())) - } - #[cfg(windows)] async fn save_service_locations( &self, request: tonic::Request, ) -> Result, Status> { - use crate::enterprise::service_locations::ServiceLocationManager; - debug!("Received a request to save service location"); let service_location = request.into_inner(); - match ServiceLocationManager::save_service_locations( - service_location.service_locations.as_slice(), - &service_location.instance_id, - &service_location.private_key, - ) { + match self + .service_location_manager + .clone() + .read() + .unwrap() + .save_service_locations( + service_location.service_locations.as_slice(), + &service_location.instance_id, + &service_location.private_key, + ) { Ok(()) => { debug!("Service location saved successfully"); - Ok(Response::new(())) } Err(e) => { let msg = format!("Failed to save service location: {e}"); error!(msg); - Err(Status::internal(msg)) + return Err(Status::internal(msg)); } } + + for saved_location in service_location.service_locations { + match self + .service_location_manager + .clone() + .write() + .unwrap() + .reset_service_location_state(&service_location.instance_id, &saved_location.pubkey) + { + Ok(()) => { + debug!( + "Service location '{}' state reset successfully", + saved_location.name + ); + } + Err(e) => { + error!( + "Failed to reset state for service location '{}': {e}", + saved_location.name + ); + } + } + } + + Ok(Response::new(())) } #[cfg(windows)] @@ -228,24 +243,6 @@ impl DesktopDaemonService for DaemonService { } } - #[cfg(windows)] - async fn reset_service_location( - &self, - request: tonic::Request, - ) -> Result, Status> { - let request = request.into_inner(); - self.service_location_manager - .clone() - .write() - .unwrap() - .reset_service_location_state(&request.instance_id, &request.pubkey) - .map_err(|e| { - error!("Failed to restart service location: {}", e); - Status::internal(format!("Failed to restart service location: {}", e)) - })?; - Ok(Response::new(())) - } - async fn create_interface( &self, request: tonic::Request, diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index 2f006d38..d0a003be 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -7,7 +7,7 @@ use std::{ use clap::Parser; use error; -use tokio::runtime::Runtime; +use tokio::{runtime::Runtime, select}; use windows_service::{ define_windows_service, service::{ @@ -110,52 +110,71 @@ fn run_service() -> Result<(), DaemonError> { let service_location_manager_clone = service_location_manager.clone(); runtime.spawn(async move { - info!("Starting service location management task"); + let manager = service_location_manager_clone.clone(); - let manager = service_location_manager_clone; + let service_location_task = async move { + info!("Starting service location management task"); - info!("Attempting to auto-connect to service locations"); - match manager.write().unwrap().connect_to_service_locations() { - Ok(_) => { - info!("Auto-connect to service locations completed successfully"); - } - Err(e) => { - warn!( - "Error while trying to auto-connect to service locations: {e}. \ - Will continue monitoring for login/logoff events.", - ); - } - } - - info!("Starting login/logoff event monitoring"); - loop { - match watch_for_login_logoff( - manager.clone(), - ).await { + info!("Attempting to auto-connect to service locations"); + match manager.write().unwrap().connect_to_service_locations() { Ok(_) => { - warn!("Login/logoff event monitoring ended unexpectedly"); - break; + info!("Auto-connect to service locations completed successfully"); } Err(e) => { - error!( - "Error in login/logoff event monitoring: {e}. Restarting in {LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS} seconds...", + warn!( + "Error while trying to auto-connect to service locations: {e}. \ + Will continue monitoring for login/logoff events.", ); - tokio::time::sleep(LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS).await; - info!("Restarting login/logoff event monitoring"); } } - } - warn!("Service location management task terminated"); - }); + info!("Starting login/logoff event monitoring"); + loop { + match watch_for_login_logoff( + manager.clone(), + ).await { + Ok(_) => { + warn!("Login/logoff event monitoring ended unexpectedly"); + break; + } + Err(e) => { + error!( + "Error in login/logoff event monitoring: {e}. Restarting in {LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS:?} seconds...", + ); + tokio::time::sleep(LOGIN_LOGOFF_MONITORING_RESTART_DELAY_SECS).await; + info!("Restarting login/logoff event monitoring"); + } + } + } - let service_location_manager_clone = service_location_manager.clone(); - runtime.spawn(async move { - let server_result = run_server(config, service_location_manager_clone).await; + warn!("Service location management task terminated"); + Ok::<(), ServiceLocationError>(()) + }; - if server_result.is_err() { - let _ = shutdown_tx_server.send(2); - } + let server_task = async move { + run_server(config, service_location_manager_clone).await + }; + + let result = select! { + result = service_location_task => { + warn!("Service location task completed"); + result.map_err(|e| format!("Service location error: {e}")) + } + result = server_task => { + warn!("Server task completed"); + result.map_err(|e| format!("Server error: {e}")) + } + }; + + let signal = if result.is_err() { + error!("Task ended with error: {:?}", result.err()); + 2 + } else { + info!("Task ended without an error."); + 1 + }; + + let _ = shutdown_tx_server.send(signal); }); loop { diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index dbb63a88..92230c75 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -941,12 +941,7 @@ async fn check_connection( #[cfg(target_os = "windows")] pub async fn sync_connections(app_handle: &AppHandle) -> Result<(), Error> { debug!("Synchronizing active connections with the systems' state..."); - let all_locations = Location::all(&*DB_POOL).await?; - // filter out service locations as they are managed through the Windows Service - let all_locations: Vec> = all_locations - .into_iter() - .filter(|loc| !loc.is_service_location()) - .collect(); + let all_locations = Location::all(&*DB_POOL, false).await?; let service_manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT).map_err( |err| { From ce572457faea486a2d73014e2d2f742eaa465b8b Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Fri, 24 Oct 2025 09:40:58 +0200 Subject: [PATCH 11/13] remove spaces --- src-tauri/src/service/windows.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src-tauri/src/service/windows.rs b/src-tauri/src/service/windows.rs index d0a003be..5efd258f 100644 --- a/src-tauri/src/service/windows.rs +++ b/src-tauri/src/service/windows.rs @@ -165,7 +165,7 @@ fn run_service() -> Result<(), DaemonError> { result.map_err(|e| format!("Server error: {e}")) } }; - + let signal = if result.is_err() { error!("Task ended with error: {:?}", result.err()); 2 @@ -173,7 +173,7 @@ fn run_service() -> Result<(), DaemonError> { info!("Task ended without an error."); 1 }; - + let _ = shutdown_tx_server.send(signal); }); From 5071905737afc64b4e02b57de76eeeb2dbb9b687 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Fri, 24 Oct 2025 10:00:53 +0200 Subject: [PATCH 12/13] sqlx, proto --- ...f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json} | 6 +++--- src-tauri/proto | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) rename src-tauri/.sqlx/{query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json => query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json} (90%) diff --git a/src-tauri/.sqlx/query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json b/src-tauri/.sqlx/query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json similarity index 90% rename from src-tauri/.sqlx/query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json rename to src-tauri/.sqlx/query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json index d15e39ce..012a54b3 100644 --- a/src-tauri/.sqlx/query-d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624.json +++ b/src-tauri/.sqlx/query-9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE instance_id = $1 ORDER BY name ASC", + "query": "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM location WHERE instance_id = $1 AND service_location_mode <= $2 ORDER BY name ASC", "describe": { "columns": [ { @@ -70,7 +70,7 @@ } ], "parameters": { - "Right": 1 + "Right": 2 }, "nullable": [ false, @@ -88,5 +88,5 @@ false ] }, - "hash": "d5433e3f04be190a009ac805ba5f8c3384709a2bdd3cd769e29b0caf5014d624" + "hash": "9137d3329ed718f211b5654af41b297c31706f5a5ad9ac400be116db7113a056" } diff --git a/src-tauri/proto b/src-tauri/proto index 3fd150c0..302b9858 160000 --- a/src-tauri/proto +++ b/src-tauri/proto @@ -1 +1 @@ -Subproject commit 3fd150c0245f5ed088ed57ad780a9376e3377ce3 +Subproject commit 302b985859915b1715fefad76ca727e8348453f4 From e395f6317f6a5c074f53678aab32d7d42495aade Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Sun, 26 Oct 2025 15:05:41 +0100 Subject: [PATCH 13/13] Update proto --- src-tauri/proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src-tauri/proto b/src-tauri/proto index 302b9858..fee70601 160000 --- a/src-tauri/proto +++ b/src-tauri/proto @@ -1 +1 @@ -Subproject commit 302b985859915b1715fefad76ca727e8348453f4 +Subproject commit fee706013b3bb5452c3c4dbf35bd973d0637ff25