Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ homepage = "https://github.com/DefGuard/proxy"
repository = "https://github.com/DefGuard/proxy"

[dependencies]
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "db678a95398e38b72bbb4ecef36a27caa427e48c" }
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "be3f96ced072ede3ebde72f2f6c6063d2e7f7403" }
# base `axum` deps
axum = { version = "0.8", features = ["macros", "tracing", "ws"] }
axum-client-ip = "0.7"
Expand Down
13 changes: 11 additions & 2 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ use std::{
Arc, Mutex,
},
};

use defguard_version::{get_tracing_variables, parse_metadata, DefguardComponent, Version};
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{Request, Response, Status, Streaming};
use tracing::Instrument;

use defguard_version::{get_tracing_variables, parse_metadata, DefguardComponent};

use crate::{
error::ApiError,
proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo},
Expand All @@ -27,6 +27,7 @@ pub(crate) struct ProxyServer {
clients: Arc<Mutex<ClientMap>>,
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
pub(crate) connected: Arc<AtomicBool>,
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
}

impl ProxyServer {
Expand All @@ -38,6 +39,7 @@ impl ProxyServer {
clients: Arc::new(Mutex::new(HashMap::new())),
results: Arc::new(Mutex::new(HashMap::new())),
connected: Arc::new(AtomicBool::new(false)),
core_version: Arc::new(Mutex::new(None)),
}
}

Expand Down Expand Up @@ -82,6 +84,7 @@ impl Clone for ProxyServer {
clients: Arc::clone(&self.clients),
results: Arc::clone(&self.results),
connected: Arc::clone(&self.connected),
core_version: Arc::clone(&self.core_version),
}
}
}
Expand All @@ -102,6 +105,12 @@ impl proxy_server::Proxy for ProxyServer {
};
let maybe_info = parse_metadata(request.metadata());
let (version, info) = get_tracing_variables(&maybe_info);

if let Ok(ver) = Version::parse(&version) {
let mut core_version = self.core_version.lock().unwrap();
*core_version = Some(ver);
}

let span = tracing::info_span!("core_bidi_stream", component = %DefguardComponent::Core, version, info);
let _guard = span.enter();

Expand Down
29 changes: 27 additions & 2 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ use anyhow::Context;
use axum::{
body::Body,
extract::{ConnectInfo, FromRef, State},
http::{Request, StatusCode},
http::{header::HeaderValue, Request, Response, StatusCode},
middleware::{self, Next},
routing::{get, post},
serve, Json, Router,
};
use axum_extra::extract::cookie::Key;
use clap::crate_version;
use defguard_version::{
server::{DefguardVersionInterceptor, DefguardVersionLayer},
server::{grpc::DefguardVersionInterceptor, DefguardVersionLayer},
DefguardComponent, Version,
};
use serde::Serialize;
Expand All @@ -44,6 +45,7 @@ use crate::{

pub(crate) static ENROLLMENT_COOKIE_NAME: &str = "defguard_proxy";
pub(crate) static PASSWORD_RESET_COOKIE_NAME: &str = "defguard_proxy_password_reset";
const DEFGUARD_CORE_VERSION_HEADER: &str = "defguard-core-version";
const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60);

#[derive(Clone)]
Expand Down Expand Up @@ -125,6 +127,24 @@ fn get_client_addr(request: &Request<Body>) -> String {
)
}

async fn core_version_middleware(
State(app_state): State<AppState>,
request: Request<Body>,
next: Next,
) -> Response<Body> {
let mut response = next.run(request).await;

if let Some(core_version) = app_state.grpc_server.core_version.lock().unwrap().as_ref() {
if let Ok(core_version_header) = HeaderValue::from_str(&core_version.to_string()) {
response
.headers_mut()
.insert(DEFGUARD_CORE_VERSION_HEADER, core_version_header);
}
}

response
}

pub async fn run_server(config: Config) -> anyhow::Result<()> {
info!("Starting Defguard Proxy server");
debug!("Using config: {config:?}");
Expand Down Expand Up @@ -246,6 +266,11 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
.route("/info", get(app_info)),
)
.fallback_service(get(handle_404))
.layer(middleware::from_fn_with_state(
shared_state.clone(),
core_version_middleware,
))
.layer(DefguardVersionLayer::new(Version::parse(VERSION)?))
.with_state(shared_state)
.layer(
TraceLayer::new_for_http()
Expand Down