diff --git a/Cargo.lock b/Cargo.lock index ea9b58c..ed74032 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,6 +121,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "serde", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-server" version = "0.7.2" @@ -148,6 +170,7 @@ name = "backend" version = "0.1.0" dependencies = [ "axum", + "axum-extra", "axum-server", "dotenvy", "rustls", diff --git a/Cargo.toml b/Cargo.toml index b1a4c6d..0c7c7b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,10 +4,11 @@ version = "0.1.0" edition = "2024" [features] -https = ["dep:rustls"] +https = ["dep:rustls", "dep:axum-extra"] [dependencies] axum = "0.8.4" +axum-extra = { version = "0.10.1", optional = true } axum-server = { version = "0.7.2", features = ["tls-rustls"] } dotenvy = "0.15.7" rustls = { version = "0.23.31", optional = true } diff --git a/src/main.rs b/src/main.rs index 65e4731..e4cae46 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, SocketAddrV4}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use crate::app::new_app; @@ -8,15 +8,27 @@ compile_error!("Feature `https` must be enabled on release."); mod app; +#[allow(dead_code)] +struct Ports { + http: u16, + https: u16, +} + #[cfg(not(debug_assertions))] -const PORT: u16 = 443; +const PORTS: Ports = Ports { + http: 80, + https: 443, +}; #[cfg(debug_assertions)] -const PORT: u16 = 8080; +const PORTS: Ports = Ports { + http: 8080, + https: 4430, +}; #[cfg(debug_assertions)] -const ADDR: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::LOCALHOST, PORT); +const ADDR: Ipv4Addr = Ipv4Addr::LOCALHOST; #[cfg(not(debug_assertions))] -const ADDR: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, PORT); +const ADDR: Ipv4Addr = Ipv4Addr::UNSPECIFIED; #[tokio::main] async fn main() { @@ -31,7 +43,7 @@ async fn main() { async fn http_main() { let app = new_app(); - axum_server::bind(std::net::SocketAddr::V4(ADDR)) + axum_server::bind(SocketAddr::V4(SocketAddrV4::new(ADDR, PORTS.http))) .serve(app.into_make_service()) .await .unwrap(); @@ -49,11 +61,67 @@ async fn https_main() { .await .unwrap(); + tokio::spawn(redirect_http_to_https(PORTS)); + let app = new_app(); // run https server - axum_server::bind_rustls(std::net::SocketAddr::V4(ADDR), config) + axum_server::bind_rustls(SocketAddr::V4(SocketAddrV4::new(ADDR, PORTS.https)), config) .serve(app.into_make_service()) .await .unwrap(); } + +#[cfg(feature = "https")] +async fn redirect_http_to_https(ports: Ports) { + use axum::{ + BoxError, + handler::HandlerWithoutStateExt, + http::{Uri, uri::Authority}, + }; + use axum_extra::extract::Host; + + fn make_https(host: &str, uri: Uri, https_port: u16) -> Result { + let mut parts = uri.into_parts(); + + parts.scheme = Some(axum::http::uri::Scheme::HTTPS); + + if parts.path_and_query.is_none() { + parts.path_and_query = Some("/".parse().unwrap()); + } + + let authority: Authority = host.parse()?; + let bare_host = match authority.port() { + Some(port_struct) => authority + .as_str() + .strip_suffix(port_struct.as_str()) + .unwrap() + .strip_suffix(':') + .unwrap(), // if authority.port() is Some(port) then we can be sure authority ends with :{port} + None => authority.as_str(), + }; + + parts.authority = Some(format!("{bare_host}:{https_port}").parse()?); + + Ok(Uri::from_parts(parts)?) + } + + let redirect = move |Host(host): Host, uri: Uri| async move { + use axum::response::Redirect; + + match make_https(&host, uri, ports.https) { + Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), + Err(_) => { + use axum::http::StatusCode; + + Err(StatusCode::BAD_REQUEST) + } + } + }; + + let addr = SocketAddr::V4(SocketAddrV4::new(ADDR, ports.http)); + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + axum::serve(listener, redirect.into_make_service()) + .await + .unwrap(); +}