diff --git a/Cargo.lock b/Cargo.lock index 2cf63480..30646972 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -550,8 +550,9 @@ dependencies = [ [[package]] name = "defguard_version" version = "0.0.0" -source = "git+https://github.com/DefGuard/defguard.git?rev=db678a95398e38b72bbb4ecef36a27caa427e48c#db678a95398e38b72bbb4ecef36a27caa427e48c" +source = "git+https://github.com/DefGuard/defguard.git?rev=be3f96ced072ede3ebde72f2f6c6063d2e7f7403#be3f96ced072ede3ebde72f2f6c6063d2e7f7403" dependencies = [ + "axum", "http", "os_info", "semver", diff --git a/Cargo.toml b/Cargo.toml index 3bef7e64..cf7b9611 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ homepage = "https://github.com/DefGuard/proxy" repository = "https://github.com/DefGuard/proxy" [dependencies] -defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "db678a95398e38b72bbb4ecef36a27caa427e48c" } +defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "be3f96ced072ede3ebde72f2f6c6063d2e7f7403" } # base `axum` deps axum = { version = "0.8", features = ["macros", "tracing", "ws"] } axum-client-ip = "0.7" diff --git a/src/grpc.rs b/src/grpc.rs index f69a724a..dbd27d7a 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -6,13 +6,13 @@ use std::{ Arc, Mutex, }, }; + +use defguard_version::{get_tracing_variables, parse_metadata, DefguardComponent, Version}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Request, Response, Status, Streaming}; use tracing::Instrument; -use defguard_version::{get_tracing_variables, parse_metadata, DefguardComponent}; - use crate::{ error::ApiError, proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo}, @@ -27,6 +27,7 @@ pub(crate) struct ProxyServer { clients: Arc>, results: Arc>>>, pub(crate) connected: Arc, + pub(crate) core_version: Arc>>, } impl ProxyServer { @@ -38,6 +39,7 @@ impl ProxyServer { clients: Arc::new(Mutex::new(HashMap::new())), results: Arc::new(Mutex::new(HashMap::new())), connected: Arc::new(AtomicBool::new(false)), + core_version: Arc::new(Mutex::new(None)), } } @@ -82,6 +84,7 @@ impl Clone for ProxyServer { clients: Arc::clone(&self.clients), results: Arc::clone(&self.results), connected: Arc::clone(&self.connected), + core_version: Arc::clone(&self.core_version), } } } @@ -102,6 +105,12 @@ impl proxy_server::Proxy for ProxyServer { }; let maybe_info = parse_metadata(request.metadata()); let (version, info) = get_tracing_variables(&maybe_info); + + if let Ok(ver) = Version::parse(&version) { + let mut core_version = self.core_version.lock().unwrap(); + *core_version = Some(ver); + } + let span = tracing::info_span!("core_bidi_stream", component = %DefguardComponent::Core, version, info); let _guard = span.enter(); diff --git a/src/http.rs b/src/http.rs index 382ba71b..db0cd948 100644 --- a/src/http.rs +++ b/src/http.rs @@ -10,14 +10,15 @@ use anyhow::Context; use axum::{ body::Body, extract::{ConnectInfo, FromRef, State}, - http::{Request, StatusCode}, + http::{header::HeaderValue, Request, Response, StatusCode}, + middleware::{self, Next}, routing::{get, post}, serve, Json, Router, }; use axum_extra::extract::cookie::Key; use clap::crate_version; use defguard_version::{ - server::{DefguardVersionInterceptor, DefguardVersionLayer}, + server::{grpc::DefguardVersionInterceptor, DefguardVersionLayer}, DefguardComponent, Version, }; use serde::Serialize; @@ -44,6 +45,7 @@ use crate::{ pub(crate) static ENROLLMENT_COOKIE_NAME: &str = "defguard_proxy"; pub(crate) static PASSWORD_RESET_COOKIE_NAME: &str = "defguard_proxy_password_reset"; +const DEFGUARD_CORE_VERSION_HEADER: &str = "defguard-core-version"; const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60); #[derive(Clone)] @@ -125,6 +127,24 @@ fn get_client_addr(request: &Request) -> String { ) } +async fn core_version_middleware( + State(app_state): State, + request: Request, + next: Next, +) -> Response { + let mut response = next.run(request).await; + + if let Some(core_version) = app_state.grpc_server.core_version.lock().unwrap().as_ref() { + if let Ok(core_version_header) = HeaderValue::from_str(&core_version.to_string()) { + response + .headers_mut() + .insert(DEFGUARD_CORE_VERSION_HEADER, core_version_header); + } + } + + response +} + pub async fn run_server(config: Config) -> anyhow::Result<()> { info!("Starting Defguard Proxy server"); debug!("Using config: {config:?}"); @@ -246,6 +266,11 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .route("/info", get(app_info)), ) .fallback_service(get(handle_404)) + .layer(middleware::from_fn_with_state( + shared_state.clone(), + core_version_middleware, + )) + .layer(DefguardVersionLayer::new(Version::parse(VERSION)?)) .with_state(shared_state) .layer( TraceLayer::new_for_http()