Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions src-tauri/src/active_connections.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::{collections::HashSet, sync::LazyLock};

use tokio::sync::Mutex;

use crate::{
database::{
models::{connection::ActiveConnection, instance::Instance, location::Location, Id},
DB_POOL,
},
error::Error,
utils::disconnect_interface,
ConnectionType,
};

pub(crate) static ACTIVE_CONNECTIONS: LazyLock<Mutex<Vec<ActiveConnection>>> =
LazyLock::new(|| Mutex::new(Vec::new()));

pub(crate) async fn get_connection_id_by_type(connection_type: ConnectionType) -> Vec<Id> {
let active_connections = ACTIVE_CONNECTIONS.lock().await;

let connection_ids = active_connections
.iter()
.filter_map(|con| {
if con.connection_type == connection_type {
Some(con.location_id)
} else {
None
}
})
.collect();

connection_ids
}

pub async fn close_all_connections() -> Result<(), Error> {
debug!("Closing all active connections");
let active_connections = ACTIVE_CONNECTIONS.lock().await;
let active_connections_count = active_connections.len();
debug!("Found {active_connections_count} active connections");
for connection in active_connections.iter() {
debug!(
"Found active connection with location {}",
connection.location_id
);
trace!("Connection: {connection:#?}");
debug!("Removing interface {}", connection.interface_name);
disconnect_interface(connection).await?;
}
if active_connections_count > 0 {
info!("All active connections ({active_connections_count}) have been closed.");
} else {
debug!("There were no active connections to close, nothing to do.");
}
Ok(())
}

pub(crate) async fn find_connection(
id: Id,
connection_type: ConnectionType,
) -> Option<ActiveConnection> {
let connections = ACTIVE_CONNECTIONS.lock().await;
trace!(
"Checking for active connection with ID {id}, type {connection_type} in active connections."
);

if let Some(connection) = connections
.iter()
.find(|conn| conn.location_id == id && conn.connection_type == connection_type)
{
// 'connection' now contains the first element with the specified id and connection_type
trace!("Found connection: {connection:?}");
Some(connection.to_owned())
} else {
debug!(
"Couldn't find connection with ID {id}, type: {connection_type} in active connections."
);
None
}
}

/// Returns active connections for a given instance.
pub(crate) async fn active_connections(
instance: &Instance<Id>,
) -> Result<Vec<ActiveConnection>, Error> {
let locations: HashSet<Id> = Location::find_by_instance_id(&*DB_POOL, instance.id)
.await?
.iter()
.map(|location| location.id)
.collect();
Ok(ACTIVE_CONNECTIONS
.lock()
.await
.iter()
.filter(|connection| locations.contains(&connection.location_id))
.cloned()
.collect())
}
113 changes: 13 additions & 100 deletions src-tauri/src/appstate.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
use std::collections::{HashMap, HashSet};
use std::{collections::HashMap, sync::Mutex};

use tauri::async_runtime::{spawn, JoinHandle, Mutex};
use tauri::async_runtime::{spawn, JoinHandle};
use tokio_util::sync::CancellationToken;

use crate::{
active_connections::ACTIVE_CONNECTIONS,
app_config::AppConfig,
database::{
models::{connection::ActiveConnection, instance::Instance, location::Location, Id},
models::{connection::ActiveConnection, Id},
DB_POOL,
},
error::Error,
service::utils::DAEMON_CLIENT,
utils::{disconnect_interface, stats_handler},
utils::stats_handler,
ConnectionType,
};

pub struct AppState {
pub active_connections: Mutex<Vec<ActiveConnection>>,
pub log_watchers: std::sync::Mutex<HashMap<String, CancellationToken>>,
pub app_config: std::sync::Mutex<AppConfig>,
stat_threads: std::sync::Mutex<HashMap<Id, JoinHandle<()>>>, // location ID is the key
pub log_watchers: Mutex<HashMap<String, CancellationToken>>,
pub app_config: Mutex<AppConfig>,
stat_threads: Mutex<HashMap<Id, JoinHandle<()>>>, // location ID is the key
}

impl AppState {
#[must_use]
pub fn new(config: AppConfig) -> Self {
AppState {
active_connections: Mutex::new(Vec::new()),
log_watchers: std::sync::Mutex::new(HashMap::new()),
app_config: std::sync::Mutex::new(config),
stat_threads: std::sync::Mutex::new(HashMap::new()),
log_watchers: Mutex::new(HashMap::new()),
app_config: Mutex::new(config),
stat_threads: Mutex::new(HashMap::new()),
}
}

Expand All @@ -42,7 +40,7 @@ impl AppState {
let ifname = interface_name.into();
let connection = ActiveConnection::new(location_id, ifname.clone(), connection_type);
debug!("Adding active connection for location ID: {location_id}");
let mut connections = self.active_connections.lock().await;
let mut connections = ACTIVE_CONNECTIONS.lock().await;
connections.push(connection);
trace!("Current active connections: {connections:?}");
drop(connections);
Expand Down Expand Up @@ -92,7 +90,7 @@ impl AppState {
}
}

let mut connections = self.active_connections.lock().await;
let mut connections = ACTIVE_CONNECTIONS.lock().await;
if let Some(index) = connections.iter().position(|conn| {
conn.location_id == location_id && conn.connection_type == connection_type
}) {
Expand All @@ -105,89 +103,4 @@ impl AppState {
None
}
}

pub(crate) async fn get_connection_id_by_type(
&self,
connection_type: ConnectionType,
) -> Vec<Id> {
let active_connections = self.active_connections.lock().await;

let connection_ids = active_connections
.iter()
.filter_map(|con| {
if con.connection_type == connection_type {
Some(con.location_id)
} else {
None
}
})
.collect();

connection_ids
}

pub async fn close_all_connections(&self) -> Result<(), crate::error::Error> {
debug!("Closing all active connections");
let active_connections = self.active_connections.lock().await;
let active_connections_count = active_connections.len();
debug!("Found {active_connections_count} active connections");
for connection in active_connections.iter() {
debug!(
"Found active connection with location {}",
connection.location_id
);
trace!("Connection: {connection:#?}");
debug!("Removing interface {}", connection.interface_name);
disconnect_interface(connection).await?;
}
if active_connections_count > 0 {
info!("All active connections ({active_connections_count}) have been closed.");
} else {
debug!("There were no active connections to close, nothing to do.");
}
Ok(())
}

pub(crate) async fn find_connection(
&self,
id: Id,
connection_type: ConnectionType,
) -> Option<ActiveConnection> {
let connections = self.active_connections.lock().await;
trace!(
"Checking for active connection with ID {id}, type {connection_type} in active connections."
);

if let Some(connection) = connections
.iter()
.find(|conn| conn.location_id == id && conn.connection_type == connection_type)
{
// 'connection' now contains the first element with the specified id and connection_type
trace!("Found connection: {connection:?}");
Some(connection.to_owned())
} else {
debug!("Couldn't find connection with ID {id}, type: {connection_type} in active connections.");
None
}
}

/// Returns active connections for a given instance.
pub(crate) async fn active_connections(
&self,
instance: &Instance<Id>,
) -> Result<Vec<ActiveConnection>, Error> {
let locations: HashSet<Id> = Location::find_by_instance_id(&*DB_POOL, instance.id)
.await?
.iter()
.map(|location| location.id)
.collect();
Ok(self
.active_connections
.lock()
.await
.iter()
.filter(|connection| locations.contains(&connection.location_id))
.cloned()
.collect())
}
}
28 changes: 12 additions & 16 deletions src-tauri/src/bin/defguard-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{env, str::FromStr, sync::LazyLock};
#[cfg(target_os = "windows")]
use defguard_client::utils::sync_connections;
use defguard_client::{
active_connections::close_all_connections,
app_config::AppConfig,
appstate::AppState,
commands::*,
Expand All @@ -25,7 +26,7 @@ use defguard_client::{
use log::{Level, LevelFilter};
#[cfg(target_os = "macos")]
use tauri::{process, Env};
use tauri::{AppHandle, Builder, Emitter, Manager, RunEvent, State, WindowEvent};
use tauri::{AppHandle, Builder, Emitter, Manager, RunEvent, WindowEvent};
use tauri_plugin_log::{Target, TargetKind};

#[derive(Clone, serde::Serialize)]
Expand Down Expand Up @@ -129,7 +130,7 @@ fn main() {
"PATH",
format!("{current_path}:{}", current_bin_dir.to_str().unwrap()),
);
debug!("Added binary dir {current_bin_dir:?} to PATH");
debug!("Added binary dir {} to PATH", current_bin_dir.display());
}

let app = Builder::default()
Expand Down Expand Up @@ -277,18 +278,18 @@ fn main() {
// Startup tasks
RunEvent::Ready => {
info!(
"Application data (database file) will be stored in: {:?} and application logs in: {:?}. \
"Application data (database file) will be stored in: {} and application logs in: {}. \
Logs of the background Defguard service responsible for managing VPN connections at the \
network level will be stored in: {}.",
// display the path to the app data directory, convert option<pathbuf> to option<&str>
app_handle
.path()
.app_data_dir()
.unwrap_or_else(|_| "UNDEFINED DATA DIRECTORY".into()),
.unwrap_or_else(|_| "UNDEFINED DATA DIRECTORY".into()).display(),
app_handle
.path()
.app_log_dir()
.unwrap_or_else(|_| "UNDEFINED LOG DIRECTORY".into()),
.unwrap_or_else(|_| "UNDEFINED LOG DIRECTORY".into()).display(),
service::config::DEFAULT_LOG_DIR
);
tauri::async_runtime::block_on(startup(app_handle));
Expand All @@ -305,22 +306,17 @@ fn main() {
});
debug!("Ctrl-C handler has been set up successfully");
}
// Prevent shutdown on window close.
RunEvent::ExitRequested { code, api, .. } => {
RunEvent::ExitRequested { api, .. } => {
debug!("Received exit request");
// `None` when the exit is requested by user interaction.
if code.is_none() {
api.prevent_exit();
} else {
let app_state = app_handle.state::<State<AppState>>();
tauri::async_runtime::block_on(async {
let _ = app_state.close_all_connections().await;
});
}
// Prevent shutdown on window close.
api.prevent_exit();
}
// Handle shutdown.
RunEvent::Exit => {
debug!("Exiting the application's main event loop.");
tauri::async_runtime::block_on(async {
let _ = close_all_connections().await;
});
}
_ => {
trace!("Received event: {event:?}");
Expand Down
Loading