From ae8d556e7b82e2bb9a2cd739ba78c8a3a4a2217f Mon Sep 17 00:00:00 2001 From: Joxtacy Date: Sun, 20 Jul 2025 22:45:29 +0200 Subject: [PATCH] feat: add ASN checking resolves: #8 --- Cargo.lock | 40 +++++++++++++++-- client/src/config.rs | 73 ++++++++++++++++++++++++++++++++ client/src/websocket_handler.rs | 13 ++++++ server/Cargo.toml | 1 + server/asn-test.mmdb | Bin 0 -> 12653 bytes server/src/access_control.rs | 56 +++++++++++++++++++++++- server/src/config.rs | 7 ++- server/src/forwarding.rs | 4 ++ server/src/main.rs | 7 +++ server/src/models.rs | 21 +++++++++ server/src/websocket.rs | 6 +++ 11 files changed, 222 insertions(+), 6 deletions(-) create mode 100644 server/asn-test.mmdb diff --git a/Cargo.lock b/Cargo.lock index 85a96d8..9fbfb7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -924,6 +924,19 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "maxminddb" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a197e44322788858682406c74b0b59bf8d9b4954fe1f224d9a25147f1880bba" +dependencies = [ + "ipnetwork", + "log", + "memchr", + "serde", + "thiserror 2.0.12", +] + [[package]] name = "memchr" version = "2.7.5" @@ -1597,7 +1610,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", ] [[package]] @@ -1611,6 +1633,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -1876,7 +1909,7 @@ dependencies = [ "log", "rand", "sha1", - "thiserror", + "thiserror 1.0.69", "utf-8", ] @@ -1894,7 +1927,7 @@ dependencies = [ "log", "rand", "sha1", - "thiserror", + "thiserror 1.0.69", "utf-8", ] @@ -2345,6 +2378,7 @@ dependencies = [ "dotenvy", "futures-util", "ipnetwork", + "maxminddb", "serde", "serde_json", "tokio", diff --git a/client/src/config.rs b/client/src/config.rs index c6bf7f5..f40bab4 100644 --- a/client/src/config.rs +++ b/client/src/config.rs @@ -11,6 +11,7 @@ pub struct AppConfig { pub target_http_service_url: String, pub allowed_paths: Vec, pub allowed_ips: Vec, + pub allowed_asns: Vec, } impl AppConfig { @@ -35,6 +36,8 @@ impl AppConfig { let allowed_ips = get_allowed_ips(); + let allowed_asns = get_allowed_asns(); + Self { server_ws_url, client_id, @@ -42,6 +45,7 @@ impl AppConfig { target_http_service_url, allowed_paths, allowed_ips, + allowed_asns, } } } @@ -207,3 +211,72 @@ fn get_target_local_url() -> String { } } } + +fn get_allowed_asns() -> Vec { + println!("\n▶ Enter allowed ASNs for the tunnel (e.g., AS15169)."); + println!(" - Press Enter on an empty line to finish. If no ASNs are provided, all ASNs will be allowed."); + + let mut asns = Vec::new(); + loop { + print!("> "); + io::Write::flush(&mut io::stdout()).expect("Failed to flush stdout"); + + let mut asn_input = String::new(); + match io::stdin().read_line(&mut asn_input) { + Ok(0) => break, // EOF + Ok(_) => { + let asn_input = asn_input.trim().to_string(); + if asn_input.is_empty() { + if asns.is_empty() { + print!(" ⚠️ Are you sure you want to allow all ASNs? This is a security risk. (y/N) "); + io::Write::flush(&mut io::stdout()).expect("Failed to flush stdout"); + let mut confirmation = String::new(); + io::stdin() + .read_line(&mut confirmation) + .expect("Failed to read line"); + if confirmation.trim().eq_ignore_ascii_case("y") { + println!(" ✅ All ASNs will be allowed."); + return Vec::new(); + } else { + println!(" Operation cancelled. Please enter at least one ASN."); + continue; + } + } else { + break; + } + } + + match validate_asn(&asn_input) { + Ok(ref asn) => { + if !asns.contains(asn) { + asns.push(*asn); + println!(" ✅ Added."); + } + } + Err(e) => eprintln!(" ❌ Error: {e}. Please try again."), + } + } + Err(_) => { + eprintln!("Error: Failed to read input."); + break; + } + } + } + asns +} + +fn validate_asn(asn_str: &str) -> Result { + let cleaned = asn_str.trim().to_uppercase(); + + let number_str = match cleaned.strip_prefix("AS") { + Some(s) => s, + None => &cleaned, + }; + + match number_str.parse::() { + // The size of an ASN can be up to 32 bits + Ok(asn) if asn >= 1 => Ok(asn), + Ok(_) => Err("ASN out of valid range (1-4294967295)".to_string()), + Err(_) => Err("Invalid ASN format".to_string()), + } +} diff --git a/client/src/websocket_handler.rs b/client/src/websocket_handler.rs index b882b61..b8fe9fb 100644 --- a/client/src/websocket_handler.rs +++ b/client/src/websocket_handler.rs @@ -32,6 +32,18 @@ pub async fn connect_to_websocket( .append_pair("allowed_ips", &config.allowed_ips.join(",")); } + if !config.allowed_asns.is_empty() { + ws_url.query_pairs_mut().append_pair( + "allowed_asns", + &config + .allowed_asns + .iter() + .map(u32::to_string) + .collect::>() + .join(","), + ); + } + let auth_header_value = format!("Bearer {}", config.secret_token); let host = ws_url.host_str().ok_or("Invalid WebSocket URL: no host")?; @@ -139,6 +151,7 @@ impl Clone for AppConfig { target_http_service_url: self.target_http_service_url.clone(), allowed_paths: self.allowed_paths.clone(), allowed_ips: self.allowed_ips.clone(), + allowed_asns: self.allowed_asns.clone(), } } } diff --git a/server/Cargo.toml b/server/Cargo.toml index b7c17da..e98adea 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -20,3 +20,4 @@ dashmap = "5.5.3" futures-util = "0.3.30" tower-http = { version = "0.5", features = ["trace"] } uuid = { version = "1.8.0", features = ["v4"] } +maxminddb = "0.26.0" diff --git a/server/asn-test.mmdb b/server/asn-test.mmdb new file mode 100644 index 0000000000000000000000000000000000000000..21ce6cc87c871a1f45e818196be13e2776137cf9 GIT binary patch literal 12653 zcmZvh2YeLO_Q%iMnGI9GqBeo^R6BCH-h>64`qMWE8CKFSLsl+s5djcD*cXvl(Ct_z}7h+dp zH)3~U4+5b&{2#=g#9l-tQAJb}HAF2jov0&b5cR}NqJd~6nuunig=i)A#^c9Gvna)g zHexn`>d@1jOUxtoA?6bch<%BL#C}9Ou|MIVHnH#y5&2N06Sz&J3piiI2bPO0qGz{( z&MZA5abT&)0kriJ38Ifk5&_Xq97v>yG?5{)#A0HJL1bXFmQgy3wu6xU5VH$%uz{Y| zp-`SDau{6>XMmpW5h6zdCo$kCkrlArC~`D$jDcC$u_DKX((xiEgwlx&U&(y8i<}G_ zhChWkl{k&~CviG)263i=&Ma8R*;JlGoSU~{Og(JmDUl1{d6me8z$GFV0T&y+cKMtE zdVQ-*fLtna8Kf&jF3;NvHl^)K!7KCfYPj4$>ViH7upblnu8Ue40JOjKe@+|NoJ)aYKK1}<9 z3>tWFFXfeku%E9W;9t!5Jn_BAYf!!;@;dN_Q5KHim~V=_#k6l56mCQ5UE)3BeZ+mh zxZOq88f76~=?(b^r*RXJk3}pj^%DS<{wc77$Y;Pvk%a_z&x?9MZTf( zTjIOW_C4qM1MwsAZ{j}&g|kI|MxxC`egTjR4*=tup??$kJ(T{S^1sBNB7cRpzm2kx zg<@GovXof1aj_f>T}EYuC?EzAg+viiELMp~p(j=;G1x#)Nh^nN%%Q|a#4uuG0$U1i zIEyu$*wg?@YXp>=OW`8A@Dy94pd2ICXrt7Ltf*L9gwmFrnM-UXg(xmP*jRW@6Km_x za~rY7h0?ZSjSr;>4Bw8JDApv7QXaZgh&36@WSv=4#G0CyptQEnd!kyc9g)3}gY87@ zOza}ou6a*uH!61*YmZQg$^Qf4)5SvnD|}C^y$sCSE5)j!vYMy~ZM9}F9Z+Wm!z;07 zh*cj-GsSAiOL~`@;IhA1&A>vjTEa=RinVtr&Eh;`M4MQsO&vZ5*%yd4mzZZ@=DCkp z^Fs+P`{pGbupjcZbGjc(5o)k#N3nEGbTUH+ZKsOW#Y8@V($JpW@H|kz603)Zi-rDI zgjx3*7+XTDJ}Q%;GGIWz0Rr$NR!XciTrhOTz$mk{Ef#ADZA*z|q04gPqT>$6C|8Je z2ym2Ghl=PbK>xFj;EE16z*5cj94Un#b44rQa+X*}1E-6143)u};R6_F)F>JWmc3Vx1-y`=6F{_!&^1Nl!k2qIR*)hVpW;&H>If`=UM1W5D?a zX4V&&IO{^OE((q?_KA_85;zP{kGqFArDYnEynF3}?pNREoUeY%7KNu|bKa2g(`pPI_Hnwjt z*dJnj3$Wu^-{p3~T4%aHiuG?3U=p$aSwD$|{--niLi|enM*MD2yp@58WB;@M%xCx; zJ$$JCz$a?3S%b1fK~!vY$`Q-#^zmgz&03jf~aw{ZA*;Ly9~THRqTIiqNqv6R@_U^3WV=0YBI3Bs3}xV zC8n+Cxr33+hVa~~og_EPE(qU4)UJ%%E%e;otQKx24sQ>3rwjU@xnm)h5iRO_f~q(1W4>X)J_`Ckp)!5`IMS7FPR9@hZA>gtksmU7_SN zV3DY9diI32xTpg{saI4Yl=>K#6csS8KeQbvDn%vwpPpcb$P$YUiq}wDis_!r!Ip`_ zEv#0EI>^|}Quz5#9YU8wiNlD)33fwuB&DMabXxR3b)u+anCMtYCx|+Z68oRdfd9p> zQn}K=q~(W6oDOvgZKopq3{j_1`e*2Qx)k$9QD?&C98qV{SEv{QI`PEi5hr*m(m6OPdzT`a^P-JS5SE+aTRejaSd@TfkuR6 z>U!b^;zr^o;%4F&41JfVTPd;sDfT~g2bFgk=*i>%02KS5mejpQsqPbXKLZ{h9wZ*h z$Ekl%_|3Hrk)q|0&Oo6FA*>U{=$5C*Ij^p>#FxZZ#McD+pB@GMPvbiSyb{H3sQCF${V0n4PyL6UKjl4j9QvO&+qo; zVvh`Mqr^u4D?y?tu?4ZEftes|TZI8*#Ag4qw`QVkh;hWW#CS}wLhK1xm?!pj26`}i zBIhzm?DD)4F8WieJz4B244-OXl+$S2p4hkA7*gv;W!klx7m_e>Oj$+D%ll|Jf~+ zT8X_0_CK5b&u*h~wt?&~Hv6AFj}zR7m`^Mq*#GQ>l=dUqiTw%eE~adE5S>I9;S-C9 zZi36S?BC?~xXy(0U1v2O^a8>zfW?3+U+zXr8$rR_GcZ#Opkj=YQB zm%A|719Z8YxQDn`?E4sSzd;EfxXJ#YlyE~zuA|FpVhsj+M(jt7WIP`g`!Oi_iO+sq z>?a^S#SBkM$!o^5baP@yj>7(DKPOV!EcWvTcqOGv#eR`4FJY9ICBkQ961^g&HyWGX zCHqyeIkf${Q4W0Wu;0Yv-&5?jILh0^JH)%hd&K+jEEf9%4EC4UYk{A{{t);|?2kl$ zQ*3_>=`*oE0X`M|Eh?1yveLHCi7$vR4fLBR*zhCv*J6J|+qcAb#P`Gx#E%5}pB@wa z4=8=bz>M-U2m3|rU+MB&X#1VEKg9koZGYx%n56x;ksKi`LJ>9rBeFOVqJS7g6cR-S zg9k<_!Jcg;PAM=_9B7@5#VMmF*X|6Zv=K3k;S(usLJTK1B{m~Q5GWK(X>hm5;A71C zoKfP8rY%ZrL2OC527__l^-{)QN)?P7E6&zb^1(UdC~ZrOCngZv5d-s_M5)|h@J*B^ zW9Xg5nL=qQF^$-s*nvQm>$&V?puf=9(k{fV#BKy@$=QPvj=lE$he2*ddzor=Dv2th zny4Xa88@A%!#sP$nL(+Zm4aT+K!LW*&hn#94u(`j3XR$^~rmI2=S7?(B!IEyn| zoHo0G;yju0mU+lk_2xvh(dCB#x<8G-(XSmz)F z93jrZlnx;dB@Qz%YdBoWW{1*|;v7Yn6~xiRF`>(`;v7fi@x%!RW!xTTB`0wbaWZiV zfk%bw(m~Z}{1Ye<=XA{TL2=FiE*Ixa+Rh@*hVmkD&Y^TJr1PW&SChjpPdeC6VM#PQ=CeZ^j zBxWS%A>v_-`HncNDXoF@5~uWtIFCYlUYy4O_CM!wah{;(lf+X7Lr}DO&z>QkC7v@V zWp?KUDql1hf(NbxUdGV;I@5WDY5A_ic~zX(=<+)82IB^v3j3e)HkDko^DaH#!Rl46s%THX<$D;r}PDFzft;%_}XCLt7_+4D!(%r zT1M#y;zvM<^KYd6OPv1zzlie_Z9f|fok;0d*!~b_Akpu%)zkK0;?MPze;X+xghktd zB17>x4*d^Hi8uiM)CV6e70~4dN`*ubQA`{2MM|k0OmIb!A(T+9k)aY{WkrUCN__vh z38mr0ro?6jLkEW5oXU~JC}K1bCAKio7fWPIbGJrZVk?P^fpVNg#+vIQvNf@d!O+JI zOx(5%7*9+fwj(AIlO$4(!KO;2f}>0(rf{$~!@;IWWc$1{^kY*!5q|#};otuv`~)4@ zg|=M@{QYltvlMK;ME0N(FW3L~;XE&!36eoykWIIzm!&g)Uwbm!*Wstu>|VLh`oG2? z)$JwY%e_oINH)|3L3hG;>yw=oAG+2W=P@_h(&i`9nUv?YW|p~)nXa|ak2(NWx66$- zRkn?5b7OvLalF${L%F)njn19nFG={BOnIx<+3Tgc-0GmO-%BpjX4{P}ENt@9X|J;< zn}$;w$-Bl@&=U94lkzhh_K6#vU)@|? zH>+iKt6SX@PkK{#+SRQ>=8hog>p`Zv$XVQy@{--Y+v=y%LDEa;IrU|e@lLbu^tww} zQ~IkLjn(+sOuDnjN8W_r8}zxAb$b5axY1cnekz`EoBhm^Ak~|n$=0_b%@8vKw=M3s z1<7s`+qr(C>LAr0q)b*l-i>Z_c6HG0Co>z&NYDJ%Q5e#0^?Kv!jF-%fmfKFZG2R!? z_+8p_-;=Nnt!)*#*#x_GV|t%=pQRqCK&7op-Bf|T#&_CjlG zOWo++v0zcA!%OzMRaKZ(s=q>)=@C!r-;*B3@6|8Cb+u0F?7hU-3)t|nm=c}jHa1q9 z`8@#FQBBqL8#?A{or`t5CU;Es+&QR-ZLMCJgJ^|zgtsPnNKk{ou%9vJ^j7I@fwdp*w8{Aob zH!c?(I>`FWXpgCdZnUnxwY(*nFh{-g6l{*k&Lf9o%`gr|qC z>+6`I^}~_6tt!Ya>(2V=Tv@pFu~uzNH@nfss+8Y_(&(DzR#s`{xK=E;1YxiwzK%N4wKm^R6cwr)6iZ5S5E9IfWIwhe1#o$EtC z-H97%@6oTq+lgh9b^5lgTrIu9oQcPTxRie2 z8=MGp->n&Y7asQfcz1W;R#nU~O(C+e&W39-?!{0R=nVDxv?k5b)i}Q)oo1$i{o8>%48dMy}pd&-SG^D#_fd*FJHQa zXg00-eoh7{w;7GhcUgYUZ^&!&j?x)&hsK>7q!L|AFu2>F=+s+Rv=*L2nu8@SdQ38d z{}{hx*}(la?*EC@Q8guQbl(|qRDfIGx_GDj%%RZrT`@^Jje7_~7c~1z%ypJcEyKJQ z1$vjU_Rbz%Vz_QlWZJp59L`QHu2$@%Ta^ktQ?vPk*3R(Nu5T5=b$D~UzaKY_UVV4U z>+dl;sU@D(b?U@zo$fYz9lA)@ywUDP=hQd1)y`_JZF8GuH@4MJudS_dt6R)0-q_!b zj+Zgov*=&wObs4SWmaWdeM_^kzK=QBv7j8M!L8MgUvFI>+t|7dZq4%3 z!WO=Mo#g6i=H;dcySn^jB90q;{W{dyMxhVRn~_CXb$YmmY7>5E2CwtCoxaywnp@`J z*wx3%YntYl&o;NvcpU#4+*)RBTcYw~C?Gv!t#eTX)dMA6dORA-oS@5Fr0==h)#7}C zv(i@ESUXKmV)zIw*{VY8kEeTb&t=`W5CawE-iZsTHAvt+L&i_IvHl8oxf@gjW)j;U zuJ=tpCS<~D+&UjUFS*R*bjRtMXqjQ|o8dpW(FNKOvym1&MU8XQr>GREn#6 zw)WANWjziE`gp9Ozrr)Qh98R?Y2a;7ZWBsB!~2BTl6X3;L&oD(W;yS$3Az;Zpd)e7 zHtIq3Ml}y?(A4YDV`qBHm**CqpTo#AQ8L!0I&%3UEUv7<>%g)H-ahHuQMWEUF2BRg zLKwJ^@|!&N9;Bo&df(_3&;Rk}5`Zf&f!ve{f=!*{`|(aX9$y@ph- zw^YVVWsbx5*k{YaVYhX(?t;^vc^ZR zWz1HXK?9yqvNsVwpx1Y`Yi>72#&GYL+a;R{`aRvJnp<>~wBUSB_Y zW`7oEJ&nE_po>qK8Sl?cX;ncz9({M#>-N)@<1U_=47&XG&LEr2tcmS|BN Jm;0+m|1UKy<>>, @@ -59,13 +60,21 @@ pub fn add_allowed_paths( Ok(()) } +pub fn add_allowed_asns( + app_state: &Arc, + client_id: &str, + asns: Vec, +) -> Result<(), Response> { + app_state.allowed_asns.insert(client_id.to_string(), asns); + Ok(()) +} + pub fn is_ip_allowed( app_state: &Arc, client_id: &str, remote_ip: IpAddr, ) -> Result<(), Response> { if let Some(allowed_ips_ref) = app_state.allowed_ips.get(client_id) { - if allowed_ips_ref.is_empty() { return Ok(()); } @@ -118,3 +127,46 @@ pub fn is_path_allowed( Err((StatusCode::NOT_FOUND).into_response()) } } + +pub async fn is_asn_allowed( + app_state: &Arc, + client_id: &str, + remote_ip: IpAddr, +) -> Result<(), impl IntoResponse> { + if let Some(allowed_asns_ref) = app_state.allowed_asns.get(client_id) { + if allowed_asns_ref.is_empty() { + return Ok(()); + } + + // Allow requests from loopback addresses without further checks. + if remote_ip.is_loopback() { + return Ok(()); + } + + let asn = app_state.db_reader.lookup::(remote_ip); + + let asn = asn + .inspect_err(|_e| error!("Error while doing ASN lookup. IP: {remote_ip}")) + .map_err(|_e| (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response())? + .ok_or_else(|| { + warn!("No ASN connected to IP: {remote_ip}"); + (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response() + })? + .autonomous_system_number + .ok_or_else(|| { + warn!("ASN number is None for IP: {remote_ip}"); + (StatusCode::NOT_FOUND, "No ASN found for the given IP").into_response() + })?; + + let is_allowed = allowed_asns_ref.iter().any(|p| p == &asn); + + if is_allowed { + Ok(()) + } else { + error!("Asn '{asn}' is not in the allowed list for client_id '{client_id}'"); + Err((StatusCode::FORBIDDEN, "ASN not allowed").into_response()) + } + } else { + Ok(()) + } +} diff --git a/server/src/config.rs b/server/src/config.rs index b05eff9..0342416 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -1,9 +1,10 @@ use dotenvy::dotenv; -use std::env; +use std::{env, path::PathBuf}; pub struct Config { pub secret_token: String, pub is_production: bool, + pub asn_db_path: PathBuf, } impl Config { @@ -13,9 +14,13 @@ impl Config { let is_production = env::var("IS_PRODUCTION") .map(|val| val == "true") .unwrap_or(false); + let asn_db_path: PathBuf = env::var("ASN_DB_PATH") + .map(PathBuf::from) + .unwrap_or(PathBuf::from("asn-test.mmdb")); Self { secret_token, is_production, + asn_db_path, } } } diff --git a/server/src/forwarding.rs b/server/src/forwarding.rs index b3ec8f1..301b371 100644 --- a/server/src/forwarding.rs +++ b/server/src/forwarding.rs @@ -40,6 +40,10 @@ async fn handle_forwarding_request( return response.into_response(); } + if let Err(response) = access_control::is_asn_allowed(&app_state, &client_id, remote_ip).await { + return response.into_response(); + } + if let Some(ws_sender) = app_state.active_websockets.get(&client_id) { let headers_map: HashMap = headers .iter() diff --git a/server/src/main.rs b/server/src/main.rs index b3e0348..8d3a428 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -27,6 +27,8 @@ pub struct AppState { pub pending_responses: Arc>>, pub allowed_paths: Arc>>, pub allowed_ips: Arc>>, + pub allowed_asns: Arc>>, + pub db_reader: Arc>>, } impl AppState { @@ -38,6 +40,11 @@ impl AppState { pending_responses: Arc::new(DashMap::new()), allowed_paths: Arc::new(DashMap::new()), allowed_ips: Arc::new(DashMap::new()), + allowed_asns: Arc::new(DashMap::new()), + db_reader: Arc::new( + maxminddb::Reader::open_readfile(config.asn_db_path) + .expect("Failed to open ASN database"), + ), } } } diff --git a/server/src/models.rs b/server/src/models.rs index 66fddbe..700ddf0 100644 --- a/server/src/models.rs +++ b/server/src/models.rs @@ -11,12 +11,33 @@ pub struct ClientParams { default = "default_vec" )] pub allowed_ips: Vec, + #[serde(deserialize_with = "deserialize_u32_vec", default = "default_u32_vec")] + pub allowed_asns: Vec, } fn default_vec() -> Vec { Vec::new() } +fn default_u32_vec() -> Vec { + Vec::new() +} + +fn deserialize_u32_vec<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: Option = Option::deserialize(deserializer)?; + s.map(|s| { + s.split(',') + .map(|s| s.trim().parse::()) + .collect::, _>>() + }) + .transpose() + .map(Option::unwrap_or_default) + .map_err(serde::de::Error::custom) +} + fn deserialize_comma_separated_optional<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, diff --git a/server/src/websocket.rs b/server/src/websocket.rs index a7a6ed0..c10cc3e 100644 --- a/server/src/websocket.rs +++ b/server/src/websocket.rs @@ -41,6 +41,12 @@ pub async fn ws_handler( return e.into_response(); } + let allowed_asns = params.allowed_asns.clone(); + if let Err(e) = access_control::add_allowed_asns(&app_state, &client_id, allowed_asns) { + error!("Failed to add allowed ASNs"); + return e.into_response(); + } + ws.on_upgrade(move |socket| handle_websocket(socket, app_state, client_id)) }