Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 73 additions & 0 deletions client/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct AppConfig {
pub target_http_service_url: String,
pub allowed_paths: Vec<String>,
pub allowed_ips: Vec<String>,
pub allowed_asns: Vec<u32>,
}

impl AppConfig {
Expand All @@ -35,13 +36,16 @@ impl AppConfig {

let allowed_ips = get_allowed_ips();

let allowed_asns = get_allowed_asns();

Self {
server_ws_url,
client_id,
secret_token,
target_http_service_url,
allowed_paths,
allowed_ips,
allowed_asns,
}
}
}
Expand Down Expand Up @@ -207,3 +211,72 @@ fn get_target_local_url() -> String {
}
}
}

fn get_allowed_asns() -> Vec<u32> {
println!("\n▶ Enter allowed ASNs for the tunnel (e.g., AS15169).");
println!(" - Press Enter on an empty line to finish. If no ASNs are provided, all ASNs will be allowed.");

let mut asns = Vec::new();
loop {
print!("> ");
io::Write::flush(&mut io::stdout()).expect("Failed to flush stdout");

let mut asn_input = String::new();
match io::stdin().read_line(&mut asn_input) {
Ok(0) => break, // EOF
Ok(_) => {
let asn_input = asn_input.trim().to_string();
if asn_input.is_empty() {
if asns.is_empty() {
print!(" ⚠️ Are you sure you want to allow all ASNs? This is a security risk. (y/N) ");
io::Write::flush(&mut io::stdout()).expect("Failed to flush stdout");
let mut confirmation = String::new();
io::stdin()
.read_line(&mut confirmation)
.expect("Failed to read line");
if confirmation.trim().eq_ignore_ascii_case("y") {
println!(" ✅ All ASNs will be allowed.");
return Vec::new();
} else {
println!(" Operation cancelled. Please enter at least one ASN.");
continue;
}
} else {
break;
}
}

match validate_asn(&asn_input) {
Ok(ref asn) => {
if !asns.contains(asn) {
asns.push(*asn);
println!(" ✅ Added.");
}
}
Err(e) => eprintln!(" ❌ Error: {e}. Please try again."),
}
}
Err(_) => {
eprintln!("Error: Failed to read input.");
break;
}
}
}
asns
}

fn validate_asn(asn_str: &str) -> Result<u32, String> {
let cleaned = asn_str.trim().to_uppercase();

let number_str = match cleaned.strip_prefix("AS") {
Some(s) => s,
None => &cleaned,
};

match number_str.parse::<u32>() {
// The size of an ASN can be up to 32 bits
Ok(asn) if asn >= 1 => Ok(asn),
Ok(_) => Err("ASN out of valid range (1-4294967295)".to_string()),
Err(_) => Err("Invalid ASN format".to_string()),
}
}
13 changes: 13 additions & 0 deletions client/src/websocket_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ pub async fn connect_to_websocket(
.append_pair("allowed_ips", &config.allowed_ips.join(","));
}

if !config.allowed_asns.is_empty() {
ws_url.query_pairs_mut().append_pair(
"allowed_asns",
&config
.allowed_asns
.iter()
.map(u32::to_string)
.collect::<Vec<String>>()
.join(","),
);
}

let auth_header_value = format!("Bearer {}", config.secret_token);
let host = ws_url.host_str().ok_or("Invalid WebSocket URL: no host")?;

Expand Down Expand Up @@ -139,6 +151,7 @@ impl Clone for AppConfig {
target_http_service_url: self.target_http_service_url.clone(),
allowed_paths: self.allowed_paths.clone(),
allowed_ips: self.allowed_ips.clone(),
allowed_asns: self.allowed_asns.clone(),
}
}
}
1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ dashmap = "5.5.3"
futures-util = "0.3.30"
tower-http = { version = "0.5", features = ["trace"] }
uuid = { version = "1.8.0", features = ["v4"] }
maxminddb = "0.26.0"
Binary file added server/asn-test.mmdb
Binary file not shown.
56 changes: 54 additions & 2 deletions server/src/access_control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use axum::{
};
use axum_extra::{headers::Authorization, TypedHeader};
use ipnetwork::IpNetwork;
use maxminddb::geoip2;
use std::net::IpAddr;
use std::sync::Arc;
use tracing::error;
use tracing::{error, warn};

pub fn authenticate_client(
auth_header: Option<TypedHeader<Authorization<axum_extra::headers::authorization::Bearer>>>,
Expand Down Expand Up @@ -59,13 +60,21 @@ pub fn add_allowed_paths(
Ok(())
}

pub fn add_allowed_asns(
app_state: &Arc<AppState>,
client_id: &str,
asns: Vec<u32>,
) -> Result<(), Response> {
app_state.allowed_asns.insert(client_id.to_string(), asns);
Ok(())
}

pub fn is_ip_allowed(
app_state: &Arc<AppState>,
client_id: &str,
remote_ip: IpAddr,
) -> Result<(), Response> {
if let Some(allowed_ips_ref) = app_state.allowed_ips.get(client_id) {

if allowed_ips_ref.is_empty() {
return Ok(());
}
Expand Down Expand Up @@ -118,3 +127,46 @@ pub fn is_path_allowed(
Err((StatusCode::NOT_FOUND).into_response())
}
}

pub async fn is_asn_allowed(
app_state: &Arc<AppState>,
client_id: &str,
remote_ip: IpAddr,
) -> Result<(), impl IntoResponse> {
if let Some(allowed_asns_ref) = app_state.allowed_asns.get(client_id) {
if allowed_asns_ref.is_empty() {
return Ok(());
}

// Allow requests from loopback addresses without further checks.
if remote_ip.is_loopback() {
return Ok(());
}

let asn = app_state.db_reader.lookup::<geoip2::Asn>(remote_ip);

let asn = asn
.inspect_err(|_e| error!("Error while doing ASN lookup. IP: {remote_ip}"))
.map_err(|_e| (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response())?
.ok_or_else(|| {
warn!("No ASN connected to IP: {remote_ip}");
(StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response()
})?
.autonomous_system_number
.ok_or_else(|| {
warn!("ASN number is None for IP: {remote_ip}");
(StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response()
})?;

let is_allowed = allowed_asns_ref.iter().any(|p| p == &asn);

if is_allowed {
Ok(())
} else {
error!("Asn '{asn}' is not in the allowed list for client_id '{client_id}'");
Err((StatusCode::FORBIDDEN, "ASN not allowed").into_response())
}
} else {
Ok(())
}
}
7 changes: 6 additions & 1 deletion server/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use dotenvy::dotenv;
use std::env;
use std::{env, path::PathBuf};

pub struct Config {
pub secret_token: String,
pub is_production: bool,
pub asn_db_path: PathBuf,
}

impl Config {
Expand All @@ -13,9 +14,13 @@ impl Config {
let is_production = env::var("IS_PRODUCTION")
.map(|val| val == "true")
.unwrap_or(false);
let asn_db_path: PathBuf = env::var("ASN_DB_PATH")
.map(PathBuf::from)
.unwrap_or(PathBuf::from("asn-test.mmdb"));
Self {
secret_token,
is_production,
asn_db_path,
}
}
}
4 changes: 4 additions & 0 deletions server/src/forwarding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ async fn handle_forwarding_request(
return response.into_response();
}

if let Err(response) = access_control::is_asn_allowed(&app_state, &client_id, remote_ip).await {
return response.into_response();
}

if let Some(ws_sender) = app_state.active_websockets.get(&client_id) {
let headers_map: HashMap<String, String> = headers
.iter()
Expand Down
7 changes: 7 additions & 0 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub struct AppState {
pub pending_responses: Arc<DashMap<String, oneshot::Sender<TunneledHttpResponse>>>,
pub allowed_paths: Arc<DashMap<String, Vec<String>>>,
pub allowed_ips: Arc<DashMap<String, Vec<String>>>,
pub allowed_asns: Arc<DashMap<String, Vec<u32>>>,
pub db_reader: Arc<maxminddb::Reader<Vec<u8>>>,
}

impl AppState {
Expand All @@ -38,6 +40,11 @@ impl AppState {
pending_responses: Arc::new(DashMap::new()),
allowed_paths: Arc::new(DashMap::new()),
allowed_ips: Arc::new(DashMap::new()),
allowed_asns: Arc::new(DashMap::new()),
db_reader: Arc::new(
maxminddb::Reader::open_readfile(config.asn_db_path)
.expect("Failed to open ASN database"),
),
}
}
}
Expand Down
21 changes: 21 additions & 0 deletions server/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,33 @@ pub struct ClientParams {
default = "default_vec"
)]
pub allowed_ips: Vec<String>,
#[serde(deserialize_with = "deserialize_u32_vec", default = "default_u32_vec")]
pub allowed_asns: Vec<u32>,
}

fn default_vec() -> Vec<String> {
Vec::new()
}

fn default_u32_vec() -> Vec<u32> {
Vec::new()
}

fn deserialize_u32_vec<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: Option<String> = Option::deserialize(deserializer)?;
s.map(|s| {
s.split(',')
.map(|s| s.trim().parse::<u32>())
.collect::<Result<Vec<u32>, _>>()
})
.transpose()
.map(Option::unwrap_or_default)
.map_err(serde::de::Error::custom)
}

fn deserialize_comma_separated_optional<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
Expand Down
6 changes: 6 additions & 0 deletions server/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ pub async fn ws_handler(
return e.into_response();
}

let allowed_asns = params.allowed_asns.clone();
if let Err(e) = access_control::add_allowed_asns(&app_state, &client_id, allowed_asns) {
error!("Failed to add allowed ASNs");
return e.into_response();
}

ws.on_upgrade(move |socket| handle_websocket(socket, app_state, client_id))
}

Expand Down