diff --git a/Cargo.lock b/Cargo.lock index 85a96d8..9fbfb7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -924,6 +924,19 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "maxminddb" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a197e44322788858682406c74b0b59bf8d9b4954fe1f224d9a25147f1880bba" +dependencies = [ + "ipnetwork", + "log", + "memchr", + "serde", + "thiserror 2.0.12", +] + [[package]] name = "memchr" version = "2.7.5" @@ -1597,7 +1610,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", ] [[package]] @@ -1611,6 +1633,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -1876,7 +1909,7 @@ dependencies = [ "log", "rand", "sha1", - "thiserror", + "thiserror 1.0.69", "utf-8", ] @@ -1894,7 +1927,7 @@ dependencies = [ "log", "rand", "sha1", - "thiserror", + "thiserror 1.0.69", "utf-8", ] @@ -2345,6 +2378,7 @@ dependencies = [ "dotenvy", "futures-util", "ipnetwork", + "maxminddb", "serde", "serde_json", "tokio", diff --git a/client/src/config.rs b/client/src/config.rs index c6bf7f5..f40bab4 100644 --- a/client/src/config.rs +++ b/client/src/config.rs @@ -11,6 +11,7 @@ pub struct AppConfig { pub target_http_service_url: String, pub allowed_paths: Vec, pub allowed_ips: Vec, + pub allowed_asns: Vec, } impl AppConfig { @@ -35,6 +36,8 @@ impl AppConfig { let allowed_ips = get_allowed_ips(); + let allowed_asns = get_allowed_asns(); + Self { server_ws_url, client_id, @@ -42,6 +45,7 @@ impl AppConfig { target_http_service_url, allowed_paths, allowed_ips, + allowed_asns, } } } @@ -207,3 +211,72 @@ fn get_target_local_url() -> String { } } } + +fn get_allowed_asns() -> Vec { + println!("\n▶ Enter allowed ASNs for the tunnel (e.g., AS15169)."); + println!(" - Press Enter on an empty line to finish. If no ASNs are provided, all ASNs will be allowed."); + + let mut asns = Vec::new(); + loop { + print!("> "); + io::Write::flush(&mut io::stdout()).expect("Failed to flush stdout"); + + let mut asn_input = String::new(); + match io::stdin().read_line(&mut asn_input) { + Ok(0) => break, // EOF + Ok(_) => { + let asn_input = asn_input.trim().to_string(); + if asn_input.is_empty() { + if asns.is_empty() { + print!(" ⚠️ Are you sure you want to allow all ASNs? This is a security risk. (y/N) "); + io::Write::flush(&mut io::stdout()).expect("Failed to flush stdout"); + let mut confirmation = String::new(); + io::stdin() + .read_line(&mut confirmation) + .expect("Failed to read line"); + if confirmation.trim().eq_ignore_ascii_case("y") { + println!(" ✅ All ASNs will be allowed."); + return Vec::new(); + } else { + println!(" Operation cancelled. Please enter at least one ASN."); + continue; + } + } else { + break; + } + } + + match validate_asn(&asn_input) { + Ok(ref asn) => { + if !asns.contains(asn) { + asns.push(*asn); + println!(" ✅ Added."); + } + } + Err(e) => eprintln!(" ❌ Error: {e}. Please try again."), + } + } + Err(_) => { + eprintln!("Error: Failed to read input."); + break; + } + } + } + asns +} + +fn validate_asn(asn_str: &str) -> Result { + let cleaned = asn_str.trim().to_uppercase(); + + let number_str = match cleaned.strip_prefix("AS") { + Some(s) => s, + None => &cleaned, + }; + + match number_str.parse::() { + // The size of an ASN can be up to 32 bits + Ok(asn) if asn >= 1 => Ok(asn), + Ok(_) => Err("ASN out of valid range (1-4294967295)".to_string()), + Err(_) => Err("Invalid ASN format".to_string()), + } +} diff --git a/client/src/websocket_handler.rs b/client/src/websocket_handler.rs index b882b61..b8fe9fb 100644 --- a/client/src/websocket_handler.rs +++ b/client/src/websocket_handler.rs @@ -32,6 +32,18 @@ pub async fn connect_to_websocket( .append_pair("allowed_ips", &config.allowed_ips.join(",")); } + if !config.allowed_asns.is_empty() { + ws_url.query_pairs_mut().append_pair( + "allowed_asns", + &config + .allowed_asns + .iter() + .map(u32::to_string) + .collect::>() + .join(","), + ); + } + let auth_header_value = format!("Bearer {}", config.secret_token); let host = ws_url.host_str().ok_or("Invalid WebSocket URL: no host")?; @@ -139,6 +151,7 @@ impl Clone for AppConfig { target_http_service_url: self.target_http_service_url.clone(), allowed_paths: self.allowed_paths.clone(), allowed_ips: self.allowed_ips.clone(), + allowed_asns: self.allowed_asns.clone(), } } } diff --git a/server/Cargo.toml b/server/Cargo.toml index b7c17da..e98adea 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -20,3 +20,4 @@ dashmap = "5.5.3" futures-util = "0.3.30" tower-http = { version = "0.5", features = ["trace"] } uuid = { version = "1.8.0", features = ["v4"] } +maxminddb = "0.26.0" diff --git a/server/asn-test.mmdb b/server/asn-test.mmdb new file mode 100644 index 0000000..21ce6cc Binary files /dev/null and b/server/asn-test.mmdb differ diff --git a/server/src/access_control.rs b/server/src/access_control.rs index eee4bdd..381437f 100644 --- a/server/src/access_control.rs +++ b/server/src/access_control.rs @@ -5,9 +5,10 @@ use axum::{ }; use axum_extra::{headers::Authorization, TypedHeader}; use ipnetwork::IpNetwork; +use maxminddb::geoip2; use std::net::IpAddr; use std::sync::Arc; -use tracing::error; +use tracing::{error, warn}; pub fn authenticate_client( auth_header: Option>>, @@ -59,13 +60,21 @@ pub fn add_allowed_paths( Ok(()) } +pub fn add_allowed_asns( + app_state: &Arc, + client_id: &str, + asns: Vec, +) -> Result<(), Response> { + app_state.allowed_asns.insert(client_id.to_string(), asns); + Ok(()) +} + pub fn is_ip_allowed( app_state: &Arc, client_id: &str, remote_ip: IpAddr, ) -> Result<(), Response> { if let Some(allowed_ips_ref) = app_state.allowed_ips.get(client_id) { - if allowed_ips_ref.is_empty() { return Ok(()); } @@ -118,3 +127,46 @@ pub fn is_path_allowed( Err((StatusCode::NOT_FOUND).into_response()) } } + +pub async fn is_asn_allowed( + app_state: &Arc, + client_id: &str, + remote_ip: IpAddr, +) -> Result<(), impl IntoResponse> { + if let Some(allowed_asns_ref) = app_state.allowed_asns.get(client_id) { + if allowed_asns_ref.is_empty() { + return Ok(()); + } + + // Allow requests from loopback addresses without further checks. + if remote_ip.is_loopback() { + return Ok(()); + } + + let asn = app_state.db_reader.lookup::(remote_ip); + + let asn = asn + .inspect_err(|_e| error!("Error while doing ASN lookup. IP: {remote_ip}")) + .map_err(|_e| (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response())? + .ok_or_else(|| { + warn!("No ASN connected to IP: {remote_ip}"); + (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response() + })? + .autonomous_system_number + .ok_or_else(|| { + warn!("ASN number is None for IP: {remote_ip}"); + (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response() + })?; + + let is_allowed = allowed_asns_ref.iter().any(|p| p == &asn); + + if is_allowed { + Ok(()) + } else { + error!("Asn '{asn}' is not in the allowed list for client_id '{client_id}'"); + Err((StatusCode::FORBIDDEN, "ASN not allowed").into_response()) + } + } else { + Ok(()) + } +} diff --git a/server/src/config.rs b/server/src/config.rs index b05eff9..0342416 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,9 +1,10 @@ use dotenvy::dotenv; -use std::env; +use std::{env, path::PathBuf}; pub struct Config { pub secret_token: String, pub is_production: bool, + pub asn_db_path: PathBuf, } impl Config { @@ -13,9 +14,13 @@ impl Config { let is_production = env::var("IS_PRODUCTION") .map(|val| val == "true") .unwrap_or(false); + let asn_db_path: PathBuf = env::var("ASN_DB_PATH") + .map(PathBuf::from) + .unwrap_or(PathBuf::from("asn-test.mmdb")); Self { secret_token, is_production, + asn_db_path, } } } diff --git a/server/src/forwarding.rs b/server/src/forwarding.rs index b3ec8f1..301b371 100644 --- a/server/src/forwarding.rs +++ b/server/src/forwarding.rs @@ -40,6 +40,10 @@ async fn handle_forwarding_request( return response.into_response(); } + if let Err(response) = access_control::is_asn_allowed(&app_state, &client_id, remote_ip).await { + return response.into_response(); + } + if let Some(ws_sender) = app_state.active_websockets.get(&client_id) { let headers_map: HashMap = headers .iter() diff --git a/server/src/main.rs b/server/src/main.rs index b3e0348..8d3a428 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -27,6 +27,8 @@ pub struct AppState { pub pending_responses: Arc>>, pub allowed_paths: Arc>>, pub allowed_ips: Arc>>, + pub allowed_asns: Arc>>, + pub db_reader: Arc>>, } impl AppState { @@ -38,6 +40,11 @@ impl AppState { pending_responses: Arc::new(DashMap::new()), allowed_paths: Arc::new(DashMap::new()), allowed_ips: Arc::new(DashMap::new()), + allowed_asns: Arc::new(DashMap::new()), + db_reader: Arc::new( + maxminddb::Reader::open_readfile(config.asn_db_path) + .expect("Failed to open ASN database"), + ), } } } diff --git a/server/src/models.rs b/server/src/models.rs index 66fddbe..700ddf0 100644 --- a/server/src/models.rs +++ b/server/src/models.rs @@ -11,12 +11,33 @@ pub struct ClientParams { default = "default_vec" )] pub allowed_ips: Vec, + #[serde(deserialize_with = "deserialize_u32_vec", default = "default_u32_vec")] + pub allowed_asns: Vec, } fn default_vec() -> Vec { Vec::new() } +fn default_u32_vec() -> Vec { + Vec::new() +} + +fn deserialize_u32_vec<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: Option = Option::deserialize(deserializer)?; + s.map(|s| { + s.split(',') + .map(|s| s.trim().parse::()) + .collect::, _>>() + }) + .transpose() + .map(Option::unwrap_or_default) + .map_err(serde::de::Error::custom) +} + fn deserialize_comma_separated_optional<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, diff --git a/server/src/websocket.rs b/server/src/websocket.rs index a7a6ed0..c10cc3e 100644 --- a/server/src/websocket.rs +++ b/server/src/websocket.rs @@ -41,6 +41,12 @@ pub async fn ws_handler( return e.into_response(); } + let allowed_asns = params.allowed_asns.clone(); + if let Err(e) = access_control::add_allowed_asns(&app_state, &client_id, allowed_asns) { + error!("Failed to add allowed ASNs"); + return e.into_response(); + } + ws.on_upgrade(move |socket| handle_websocket(socket, app_state, client_id)) }