Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repository = "https://github.com/DefGuard/proxy"

[dependencies]
# base `axum` deps
axum = { version = "0.7", features = ["macros", "tracing"] }
axum = { version = "0.7", features = ["macros", "tracing", "ws"] }
axum-client-ip = "0.6"
axum-extra = { version = "0.9", features = [
"cookie",
Expand Down Expand Up @@ -48,6 +48,8 @@ tower_governor = "0.4"
rust-embed = { version = "8.5", features = ["include-exclude"] }
mime_guess = "2.0"
base64 = "0.22.1"
futures = "0.3.31"
futures-util = "0.3.31"

[build-dependencies]
tonic-prost-build = "0.14"
Expand Down
2 changes: 1 addition & 1 deletion proto
Submodule proto updated 1 files
+13 −0 core/proxy.proto
170 changes: 166 additions & 4 deletions src/handlers/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
use axum::{extract::State, routing::post, Json, Router};
use axum::{
extract::{
ws::{Message, WebSocket},
Query, State, WebSocketUpgrade,
},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use futures_util::{sink::SinkExt, stream::StreamExt};
use serde::Deserialize;
use serde_json::json;
use std::collections::hash_map::Entry;
use tokio::{sync::oneshot, task::JoinSet};

use crate::{
error::ApiError,
Expand All @@ -14,9 +27,116 @@ pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/start", post(start_client_mfa))
.route("/finish", post(finish_client_mfa))
.route("/remote", get(await_remote_auth))
.route("/finish-remote", post(finish_remote_mfa))
}

#[instrument(level = "debug", skip(state))]
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct RemoteMfaRequestQuery {
pub token: String,
}

// Allows desktop client to await for another device to complete MFA for it via mobile client
#[instrument(level = "debug", skip(state, req))]
async fn await_remote_auth(
ws: WebSocketUpgrade,
Query(req): Query<RemoteMfaRequestQuery>,
State(state): State<AppState>,
device_info: DeviceInfo,
) -> Result<Response, impl IntoResponse> {
let token = req.token;
// let core validate token first
let rx = state.grpc_server.send(
core_request::Payload::ClientMfaTokenValidation(
crate::proto::ClientMfaTokenValidationRequest {
token: token.clone(),
},
),
device_info,
)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::ClientMfaTokenValidation(response) = payload {
if !response.token_valid {
return Err(ApiError::Unauthorized(String::new()));
}
// check if its already in the map
let contains_key = {
let sessions = state.remote_mfa_sessions.lock().await;
sessions.contains_key(&token)
};
if contains_key {
return Err(ApiError::Unauthorized(String::new()));
};
Ok(ws.on_upgrade(move |socket| handle_remote_auth_socket(socket, state.clone(), token)))
} else {
Err(ApiError::InvalidResponseType)
}
}

// handle axum ws socket upgrade for await_remote_auth
async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: String) {
let (tx, rx) = oneshot::channel::<String>();
let (mut ws_tx, mut ws_rx) = socket.split();

let occupied = {
let mut sessions = state.remote_mfa_sessions.lock().await;
match sessions.entry(token.clone()) {
Entry::Occupied(_) => true,
Entry::Vacant(v) => {
v.insert(tx);
false
}
}
};
if occupied {
let _ = ws_tx.close().await;
return;
}

let mut set = JoinSet::new();

set.spawn(async move {
if let Ok(msg) = rx.await {
let payload = json!({
"type": "mfa_success",
"preshared_key": &msg,
});
if let Ok(serialized) = serde_json::to_string(&payload) {
let message = Message::Text(serialized);
if ws_tx.send(message).await.is_err() {
error!("Failed to send preshared key via ws");
}
} else {
error!("Failed to serialize remote mfa ws client response message");
}
} else {
error!("Failed to receive preshared key from receiver")
}
let _ = ws_tx.close().await;
});
set.spawn(async move {
while let Some(msg_result) = ws_rx.next().await {
match msg_result {
Ok(msg) => {
if let Message::Close(_) = msg {
break;
}
}
Err(e) => {
error!("Remote desktop mfa WS client listen error {e}");
break;
}
}
}
});

let _ = set.join_next().await;
set.shutdown().await;
// will remove token if it's still there
state.remote_mfa_sessions.lock().await.remove(&token);
}

#[instrument(level = "debug", skip(state, req))]
async fn start_client_mfa(
State(state): State<AppState>,
device_info: DeviceInfo,
Expand All @@ -38,7 +158,7 @@ async fn start_client_mfa(
}
}

#[instrument(level = "debug", skip(state))]
#[instrument(level = "debug", skip(state, req))]
async fn finish_client_mfa(
State(state): State<AppState>,
device_info: DeviceInfo,
Expand All @@ -50,10 +170,52 @@ async fn finish_client_mfa(
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::ClientMfaFinish(response) = payload {
info!("Finished desktop client authorization");
Ok(Json(response))
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}

#[instrument(level = "debug", skip(state, req))]
async fn finish_remote_mfa(
State(state): State<AppState>,
device_info: DeviceInfo,
Json(req): Json<ClientMfaFinishRequest>,
) -> Result<Json<serde_json::Value>, ApiError> {
info!("Finishing desktop client authorization");
let rx = state
.grpc_server
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::ClientMfaFinish(response) = payload {
// check if this needs to be forwarded
match response.token {
Some(token) => {
let sender_option = {
let mut sessions = state.remote_mfa_sessions.lock().await;
sessions.remove(&token)
};
match sender_option {
Some(sender) => {
let _ = sender.send(response.preshared_key);
}
// if desktop stopped listening for the result there will be no palce to send the result
None => {
error!("Remote MFA approve finished but session was not found.");
return Err(ApiError::Unexpected(String::new()));
}
}
info!("Finished desktop client authorization via mobile device");
Ok(Json(json!({})))
}
None => {
error!("Remote MFA Unexpected core response, token was not returned");
Err(ApiError::Unexpected(String::new()))
}
}
} else {
error!("Received invalid gRPC response type: {payload:#?}");
Err(ApiError::InvalidResponseType)
}
}
8 changes: 6 additions & 2 deletions src/http.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{
collections::HashMap,
fs::read_to_string,
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::atomic::Ordering,
sync::{atomic::Ordering, Arc},
time::Duration,
};

Expand All @@ -16,7 +17,7 @@ use axum::{
use axum_extra::extract::cookie::Key;
use clap::crate_version;
use serde::Serialize;
use tokio::{net::TcpListener, task::JoinSet};
use tokio::{net::TcpListener, sync::oneshot, task::JoinSet};
use tonic::transport::{Identity, Server, ServerTlsConfig};
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
Expand All @@ -42,6 +43,8 @@ const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60);
#[derive(Clone)]
pub(crate) struct AppState {
pub(crate) grpc_server: ProxyServer,
pub(crate) remote_mfa_sessions:
Arc<tokio::sync::Mutex<HashMap<String, oneshot::Sender<String>>>>,
key: Key,
url: Url,
}
Expand Down Expand Up @@ -129,6 +132,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
debug!("Setting up API server");
let shared_state = AppState {
grpc_server: grpc_server.clone(),
remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
// Generate secret key for encrypting cookies.
key: Key::generate(),
url: config.url.clone(),
Expand Down
7 changes: 4 additions & 3 deletions web/biome.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"$schema": "https://biomejs.dev/schemas/2.1.2/schema.json",
"$schema": "https://biomejs.dev/schemas/2.2.0/schema.json",
"vcs": { "enabled": false, "clientKind": "git", "useIgnoreFile": false },
"files": {
"ignoreUnknown": false,
Expand All @@ -8,7 +8,7 @@
"!src/i18n/*.ts",
"!src/i18n/*.tsx",
"!src/i18n/i18n-util",
"!dist/**"
"!dist"
]
},
"formatter": {
Expand Down Expand Up @@ -40,7 +40,8 @@
"noUnusedVariables": "error",
"useExhaustiveDependencies": "error",
"useHookAtTopLevel": "error",
"useJsxKeyInIterable": "error"
"useJsxKeyInIterable": "error",
"useUniqueElementIds": "off"
},
"security": { "noDangerouslySetInnerHtmlWithChildren": "error" },
"style": {
Expand Down
Loading