Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 72 additions & 58 deletions crates/defguard_core/src/db/models/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ use super::wireguard::{
};
use crate::{
KEY_LENGTH,
db::{User, models::wireguard::ServiceLocationMode},
db::{
User,
models::wireguard::{ServiceLocationMode, get_allowed_ips_for_device},
},
enterprise::db::models::enterprise_settings::EnterpriseSettings,
};

#[derive(Serialize, ToSchema)]
#[derive(Deserialize, Serialize, ToSchema)]
pub struct DeviceConfig {
pub(crate) network_id: Id,
pub(crate) network_name: String,
Expand All @@ -40,7 +44,7 @@ pub struct DeviceConfig {
pub(crate) address: Vec<IpAddr>,
pub(crate) endpoint: String,
#[schema(value_type = String)]
pub(crate) allowed_ips: Vec<IpNetwork>,
pub allowed_ips: Vec<IpNetwork>,
pub(crate) pubkey: String,
pub(crate) dns: Option<String>,
pub(crate) keepalive_interval: i32,
Expand Down Expand Up @@ -568,10 +572,11 @@ impl Device<Id> {
/// Create WireGuard config for device.
#[must_use]
pub(crate) fn create_config(
network: &WireguardNetwork<Id>,
location: &WireguardNetwork<Id>,
wireguard_network_device: &WireguardNetworkDevice,
enterprise_settings: &EnterpriseSettings,
) -> String {
let dns = match &network.dns {
let dns = match &location.dns {
Some(dns) => {
if dns.is_empty() {
String::new()
Expand All @@ -582,10 +587,11 @@ impl Device<Id> {
None => String::new(),
};

let allowed_ips = if network.allowed_ips.is_empty() {
let location_allowed_ips = get_allowed_ips_for_device(enterprise_settings, location);
let allowed_ips = if location_allowed_ips.is_empty() {
String::new()
} else {
format!("AllowedIPs = {}\n", network.allowed_ips.as_csv())
format!("AllowedIPs = {}\n", location_allowed_ips.as_csv())
};

format!(
Expand All @@ -600,9 +606,9 @@ impl Device<Id> {
Endpoint = {}:{}\n\
PersistentKeepalive = 300",
wireguard_network_device.wireguard_ips.as_csv(),
network.pubkey,
network.endpoint,
network.port,
location.pubkey,
location.endpoint,
location.port,
)
}

Expand Down Expand Up @@ -682,67 +688,71 @@ impl Device<Id> {

pub(crate) async fn get_network_configs(
&self,
network: &WireguardNetwork<Id>,
transaction: &mut PgConnection,
location: &WireguardNetwork<Id>,
enterprise_settings: &EnterpriseSettings,
) -> Result<(DeviceNetworkInfo, DeviceConfig), DeviceError> {
let wireguard_network_device =
WireguardNetworkDevice::find(&mut *transaction, self.id, network.id)
WireguardNetworkDevice::find(&mut *transaction, self.id, location.id)
.await?
.ok_or_else(|| DeviceError::Unexpected("Device not found in network".into()))?;
let device_network_info = DeviceNetworkInfo {
network_id: network.id,
network_id: location.id,
device_wireguard_ips: wireguard_network_device.wireguard_ips.clone(),
preshared_key: wireguard_network_device.preshared_key.clone(),
is_authorized: wireguard_network_device.is_authorized,
};

let config = Self::create_config(network, &wireguard_network_device);
let config = Self::create_config(location, &wireguard_network_device, enterprise_settings);
let allowed_ips = get_allowed_ips_for_device(enterprise_settings, location);
let device_config = DeviceConfig {
network_id: network.id,
network_name: network.name.clone(),
network_id: location.id,
network_name: location.name.clone(),
config,
endpoint: format!("{}:{}", network.endpoint, network.port),
endpoint: format!("{}:{}", location.endpoint, location.port),
address: wireguard_network_device.wireguard_ips,
allowed_ips: network.allowed_ips.clone(),
pubkey: network.pubkey.clone(),
dns: network.dns.clone(),
keepalive_interval: network.keepalive_interval,
location_mfa_mode: network.location_mfa_mode.clone(),
service_location_mode: network.service_location_mode.clone(),
allowed_ips,
pubkey: location.pubkey.clone(),
dns: location.dns.clone(),
keepalive_interval: location.keepalive_interval,
location_mfa_mode: location.location_mfa_mode.clone(),
service_location_mode: location.service_location_mode.clone(),
};

Ok((device_network_info, device_config))
}

pub(crate) async fn add_to_network(
&self,
network: &WireguardNetwork<Id>,
ip: &[IpAddr],
transaction: &mut PgConnection,
location: &WireguardNetwork<Id>,
ip: &[IpAddr],
enterprise_settings: &EnterpriseSettings,
) -> Result<(DeviceNetworkInfo, DeviceConfig), DeviceError> {
let wireguard_network_device = self
.assign_network_ips(&mut *transaction, network, ip)
.assign_network_ips(&mut *transaction, location, ip)
.await?;
let device_network_info = DeviceNetworkInfo {
network_id: network.id,
network_id: location.id,
device_wireguard_ips: wireguard_network_device.wireguard_ips.clone(),
preshared_key: wireguard_network_device.preshared_key.clone(),
is_authorized: wireguard_network_device.is_authorized,
};

let config = Self::create_config(network, &wireguard_network_device);
let config = Self::create_config(location, &wireguard_network_device, enterprise_settings);
let allowed_ips = get_allowed_ips_for_device(enterprise_settings, location);
let device_config = DeviceConfig {
network_id: network.id,
network_name: network.name.clone(),
network_id: location.id,
network_name: location.name.clone(),
config,
endpoint: format!("{}:{}", network.endpoint, network.port),
endpoint: format!("{}:{}", location.endpoint, location.port),
address: wireguard_network_device.wireguard_ips,
allowed_ips: network.allowed_ips.clone(),
pubkey: network.pubkey.clone(),
dns: network.dns.clone(),
keepalive_interval: network.keepalive_interval,
location_mfa_mode: network.location_mfa_mode.clone(),
service_location_mode: network.service_location_mode.clone(),
allowed_ips,
pubkey: location.pubkey.clone(),
dns: location.dns.clone(),
keepalive_interval: location.keepalive_interval,
location_mfa_mode: location.location_mfa_mode.clone(),
service_location_mode: location.service_location_mode.clone(),
};

Ok((device_network_info, device_config))
Expand All @@ -754,58 +764,62 @@ impl Device<Id> {
transaction: &mut PgConnection,
) -> Result<(Vec<DeviceNetworkInfo>, Vec<DeviceConfig>), DeviceError> {
info!("Adding device {} to all existing networks", self.name);
let networks = WireguardNetwork::all(&mut *transaction).await?;
let locations = WireguardNetwork::all(&mut *transaction).await?;

let enterprise_settings = EnterpriseSettings::get(&mut *transaction).await?;

let mut configs = Vec::new();
let mut network_info = Vec::new();
for network in networks {
for location in locations {
debug!(
"Assigning IP for device {} (user {}) in network {network}",
"Assigning IP for device {} (user {}) in location {location}",
self.name, self.user_id
);
// check for pubkey conflicts with networks
if network.pubkey == self.wireguard_pubkey {
return Err(DeviceError::PubkeyConflict(self.clone(), network.name));
if location.pubkey == self.wireguard_pubkey {
return Err(DeviceError::PubkeyConflict(self.clone(), location.name));
}
if WireguardNetworkDevice::find(&mut *transaction, self.id, network.id)
if WireguardNetworkDevice::find(&mut *transaction, self.id, location.id)
.await?
.is_some()
{
debug!("Device {self} already has an IP within network {network}. Skipping...",);
debug!("Device {self} already has an IP within location {location}. Skipping...",);
continue;
}

if let Ok(wireguard_network_device) = network
if let Ok(wireguard_network_device) = location
.add_device_to_network(&mut *transaction, self, None)
.await
{
debug!(
"Assigned IPs {} for device {} (user {}) in network {network}",
"Assigned IPs {} for device {} (user {}) in location {location}",
wireguard_network_device.wireguard_ips.as_csv(),
self.name,
self.user_id
);
let device_network_info = DeviceNetworkInfo {
network_id: network.id,
network_id: location.id,
device_wireguard_ips: wireguard_network_device.wireguard_ips.clone(),
preshared_key: wireguard_network_device.preshared_key.clone(),
is_authorized: wireguard_network_device.is_authorized,
};
network_info.push(device_network_info);

let config = Self::create_config(&network, &wireguard_network_device);
let config =
Self::create_config(&location, &wireguard_network_device, &enterprise_settings);
let allowed_ips = get_allowed_ips_for_device(&enterprise_settings, &location);
configs.push(DeviceConfig {
network_id: network.id,
network_name: network.name,
network_id: location.id,
network_name: location.name,
config,
endpoint: format!("{}:{}", network.endpoint, network.port),
endpoint: format!("{}:{}", location.endpoint, location.port),
address: wireguard_network_device.wireguard_ips,
allowed_ips: network.allowed_ips,
pubkey: network.pubkey,
dns: network.dns,
keepalive_interval: network.keepalive_interval,
location_mfa_mode: network.location_mfa_mode.clone(),
service_location_mode: network.service_location_mode.clone(),
allowed_ips,
pubkey: location.pubkey,
dns: location.dns,
keepalive_interval: location.keepalive_interval,
location_mfa_mode: location.location_mfa_mode.clone(),
service_location_mode: location.service_location_mode.clone(),
});
}
}
Expand Down Expand Up @@ -940,7 +954,7 @@ impl Device<Id> {
}

/// Gets the first network of the network device
/// FIXME: Return only one network, not a Vec
// FIXME: Return only one network, not a Vec
pub async fn find_network_device_networks<'e, E>(
&self,
executor: E,
Expand Down
26 changes: 24 additions & 2 deletions crates/defguard_core/src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
collections::HashMap,
fmt::{self, Display},
iter::zip,
net::{IpAddr, Ipv4Addr},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};

use base64::prelude::{BASE64_STANDARD, Engine};
Expand Down Expand Up @@ -40,7 +40,11 @@ use super::{
wireguard_peer_stats::WireguardPeerStats,
};
use crate::{
enterprise::{firewall::FirewallError, is_enterprise_license_active},
enterprise::{
db::models::enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings},
firewall::FirewallError,
is_enterprise_license_active,
},
grpc::gateway::{send_multiple_wireguard_events, state::GatewayState},
wg_config::ImportedDevice,
};
Expand Down Expand Up @@ -1492,6 +1496,24 @@ pub(crate) async fn networks_stats(
})
}

// If `force_all_traffic` setting is enabled we override the allowed_ips
// to also enforce this on legacy clients.
pub fn get_allowed_ips_for_device(
enterprise_settings: &EnterpriseSettings,
location: &WireguardNetwork<Id>,
) -> Vec<IpNetwork> {
if enterprise_settings.client_traffic_policy == ClientTrafficPolicy::ForceAllTraffic {
vec![
IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
.expect("Failed to parse UNSPECIFIED IPv4 constant"),
IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
.expect("Failed to parse UNSPECIFIED IPv6 constant"),
]
} else {
location.allowed_ips.clone()
}
}

#[cfg(test)]
mod test {
use std::str::FromStr;
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_core/src/grpc/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ impl EnrollmentServer {
}

let (network_info, configs) = device
.get_network_configs(&network, &mut transaction)
.get_network_configs(&mut transaction, &network, &enterprise_settings)
.await
.map_err(|err| {
error!(
Expand Down
Loading