diff --git a/src/client/error.rs b/src/client/error.rs index 04638bb..14c6a4e 100644 --- a/src/client/error.rs +++ b/src/client/error.rs @@ -18,6 +18,9 @@ pub enum HttpError { /// Request could not be cloned (required for retry) CloneFailed, + /// Payment was declined by the on_challenge callback + PaymentDeclined, + /// Payment provider error Payment(MppError), @@ -35,6 +38,7 @@ impl fmt::Display for HttpError { Self::InvalidChallenge(msg) => write!(f, "invalid challenge: {}", msg), Self::InvalidCredential(msg) => write!(f, "invalid credential: {}", msg), Self::CloneFailed => write!(f, "request could not be cloned for retry"), + Self::PaymentDeclined => write!(f, "payment declined by on_challenge callback"), Self::Payment(e) => write!(f, "payment failed: {}", e), #[cfg(feature = "client")] Self::Request(e) => write!(f, "HTTP request failed: {}", e), diff --git a/src/client/fetch.rs b/src/client/fetch.rs index 854737d..d05a015 100644 --- a/src/client/fetch.rs +++ b/src/client/fetch.rs @@ -4,10 +4,70 @@ use reqwest::header::WWW_AUTHENTICATE; use reqwest::{RequestBuilder, Response, StatusCode}; +use std::future::Future; +use std::pin::Pin; use super::error::HttpError; use super::provider::PaymentProvider; -use crate::protocol::core::{format_authorization, parse_www_authenticate, AUTHORIZATION_HEADER}; +use crate::protocol::core::{ + format_authorization, parse_www_authenticate, PaymentChallenge, PaymentCredential, + AUTHORIZATION_HEADER, +}; + +/// Result of an [`OnChallenge`] callback, mirroring mppx's 3-way return type. +/// +/// - `Approve` — proceed with automatic payment via the provider (mppx: `undefined`) +/// - `Credential(c)` — use this pre-built credential, skip the provider (mppx: `string`) +/// - `Decline` — abort with [`HttpError::PaymentDeclined`] (mppx: `throw`) +#[derive(Debug)] +pub enum ChallengeAction { + /// Proceed with automatic payment via the provider. + Approve, + /// Use this credential directly, skipping the provider. + Credential(Box), + /// Decline the payment. + Decline, +} + +/// Callback invoked when a 402 challenge is received, before executing payment. +/// +/// The callback receives the parsed [`PaymentChallenge`] and returns a +/// [`ChallengeAction`] controlling how to proceed. +/// +/// # Examples +/// +/// ```ignore +/// use mpp::client::{Fetch, OnChallenge, ChallengeAction}; +/// +/// // Simple approval gate +/// let on_challenge: OnChallenge = Box::new(|challenge| { +/// Box::pin(async move { +/// if approve_payment(challenge).await { +/// ChallengeAction::Approve +/// } else { +/// ChallengeAction::Decline +/// } +/// }) +/// }); +/// +/// // Or supply a credential directly (e.g. after gathering user input) +/// let on_challenge: OnChallenge = Box::new(|challenge| { +/// Box::pin(async move { +/// let credential = create_credential_with_extra_context(challenge).await; +/// ChallengeAction::Credential(Box::new(credential)) +/// }) +/// }); +/// +/// let resp = client +/// .get("https://api.example.com/paid") +/// .send_with_payment_opts(&provider, Some(&on_challenge)) +/// .await?; +/// ``` +pub type OnChallenge = Box< + dyn Fn(&PaymentChallenge) -> Pin + Send + '_>> + + Send + + Sync, +>; /// Extension trait for `reqwest::RequestBuilder` with payment support. /// @@ -48,12 +108,33 @@ pub trait PaymentExt { self, provider: &P, ) -> impl std::future::Future> + Send; + + /// Like [`send_with_payment`](PaymentExt::send_with_payment), but with an + /// optional [`OnChallenge`] callback invoked before payment execution. + /// + /// The callback returns a [`ChallengeAction`] controlling how to proceed: + /// - [`ChallengeAction::Approve`] — auto-pay via the provider + /// - [`ChallengeAction::Credential`] — use the provided credential directly + /// - [`ChallengeAction::Decline`] — abort with [`HttpError::PaymentDeclined`] + fn send_with_payment_opts( + self, + provider: &P, + on_challenge: Option<&OnChallenge>, + ) -> impl std::future::Future> + Send; } impl PaymentExt for RequestBuilder { async fn send_with_payment( self, provider: &P, + ) -> Result { + self.send_with_payment_opts(provider, None).await + } + + async fn send_with_payment_opts( + self, + provider: &P, + on_challenge: Option<&OnChallenge>, ) -> Result { let retry_builder = self.try_clone().ok_or(HttpError::CloneFailed)?; @@ -73,7 +154,15 @@ impl PaymentExt for RequestBuilder { let challenge = parse_www_authenticate(www_auth) .map_err(|e| HttpError::InvalidChallenge(e.to_string()))?; - let credential = provider.pay(&challenge).await?; + let credential = if let Some(cb) = on_challenge { + match cb(&challenge).await { + ChallengeAction::Approve => provider.pay(&challenge).await?, + ChallengeAction::Credential(c) => *c, + ChallengeAction::Decline => return Err(HttpError::PaymentDeclined), + } + } else { + provider.pay(&challenge).await? + }; let auth_header = format_authorization(&credential) .map_err(|e| HttpError::InvalidCredential(e.to_string()))?; @@ -322,5 +411,93 @@ mod tests { assert!(matches!(err, HttpError::Payment(_))); } + + /// Helper: build a 402 server that accepts payment on retry. + fn paid_app(www_auth: String) -> Router { + Router::new().route( + "/paid", + get(move |req: axum::http::Request| { + let www_auth = www_auth.clone(); + async move { + if req.headers().get("authorization").is_some() { + (AxumStatusCode::OK, "ok").into_response() + } else { + ( + AxumStatusCode::PAYMENT_REQUIRED, + [(WWW_AUTH_NAME, www_auth)], + "pay up", + ) + .into_response() + } + } + }), + ) + } + + #[tokio::test] + async fn test_on_challenge_approve() { + let (_, www_auth) = test_challenge(); + let app = paid_app(www_auth); + let base_url = spawn_server(app).await; + let provider = MockProvider::new(); + + let on_challenge: super::OnChallenge = + Box::new(|_challenge| Box::pin(async { super::ChallengeAction::Approve })); + + let resp = reqwest::Client::new() + .get(format!("{}/paid", base_url)) + .send_with_payment_opts(&provider, Some(&on_challenge)) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(provider.call_count(), 1); + } + + #[tokio::test] + async fn test_on_challenge_credential() { + let (challenge, www_auth) = test_challenge(); + let app = paid_app(www_auth); + let base_url = spawn_server(app).await; + let provider = MockProvider::new(); + + let echo = challenge.to_echo(); + let on_challenge: super::OnChallenge = Box::new(move |_challenge| { + let echo = echo.clone(); + Box::pin(async move { + let cred = PaymentCredential::new(echo, PaymentPayload::hash("0xcustom")); + super::ChallengeAction::Credential(Box::new(cred)) + }) + }); + + let resp = reqwest::Client::new() + .get(format!("{}/paid", base_url)) + .send_with_payment_opts(&provider, Some(&on_challenge)) + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(provider.call_count(), 0); // provider was NOT called + } + + #[tokio::test] + async fn test_on_challenge_decline() { + let (_, www_auth) = test_challenge(); + let app = paid_app(www_auth); + let base_url = spawn_server(app).await; + let provider = MockProvider::new(); + + let on_challenge: super::OnChallenge = + Box::new(|_challenge| Box::pin(async { super::ChallengeAction::Decline })); + + let err = reqwest::Client::new() + .get(format!("{}/paid", base_url)) + .send_with_payment_opts(&provider, Some(&on_challenge)) + .await + .unwrap_err(); + + assert!(matches!(err, HttpError::PaymentDeclined)); + assert_eq!(provider.call_count(), 0); + } } } diff --git a/src/client/middleware.rs b/src/client/middleware.rs index 04b496e..3e6148e 100644 --- a/src/client/middleware.rs +++ b/src/client/middleware.rs @@ -8,6 +8,7 @@ use reqwest::header::WWW_AUTHENTICATE; use reqwest::{Request, Response, StatusCode}; use reqwest_middleware::{Middleware, Next}; +use crate::client::fetch::{ChallengeAction, OnChallenge}; use crate::client::provider::PaymentProvider; use crate::protocol::core::{format_authorization, parse_www_authenticate, AUTHORIZATION_HEADER}; @@ -35,12 +36,27 @@ use crate::protocol::core::{format_authorization, parse_www_authenticate, AUTHOR /// ``` pub struct PaymentMiddleware

{ provider: P, + on_challenge: Option, } impl

PaymentMiddleware

{ /// Create a new payment middleware with the given provider. pub fn new(provider: P) -> Self { - Self { provider } + Self { + provider, + on_challenge: None, + } + } + + /// Set an [`OnChallenge`] callback invoked before payment execution. + /// + /// The callback returns a [`ChallengeAction`] controlling how to proceed: + /// - [`ChallengeAction::Approve`] — auto-pay via the provider + /// - [`ChallengeAction::Credential`] — use the provided credential directly + /// - [`ChallengeAction::Decline`] — abort with a middleware error + pub fn with_on_challenge(mut self, on_challenge: OnChallenge) -> Self { + self.on_challenge = Some(on_challenge); + self } } @@ -79,12 +95,28 @@ where .context("invalid challenge") .map_err(reqwest_middleware::Error::Middleware)?; - let credential = self - .provider - .pay(&challenge) - .await - .context("payment failed") - .map_err(reqwest_middleware::Error::Middleware)?; + let credential = if let Some(ref cb) = self.on_challenge { + match cb(&challenge).await { + ChallengeAction::Approve => self + .provider + .pay(&challenge) + .await + .context("payment failed") + .map_err(reqwest_middleware::Error::Middleware)?, + ChallengeAction::Credential(c) => *c, + ChallengeAction::Decline => { + return Err(reqwest_middleware::Error::Middleware(anyhow::anyhow!( + "payment declined by on_challenge callback" + ))) + } + } + } else { + self.provider + .pay(&challenge) + .await + .context("payment failed") + .map_err(reqwest_middleware::Error::Middleware)? + }; let auth_header = format_authorization(&credential) .context("failed to format credential") @@ -341,5 +373,109 @@ mod tests { err ); } + + /// Helper: build a 402 server that accepts payment on retry. + fn paid_app(www_auth: String) -> Router { + Router::new().route( + "/paid", + get(move |req: axum::http::Request| { + let www_auth = www_auth.clone(); + async move { + if req.headers().get("authorization").is_some() { + (AxumStatusCode::OK, "ok").into_response() + } else { + ( + AxumStatusCode::PAYMENT_REQUIRED, + [(WWW_AUTH_NAME, www_auth)], + "pay up", + ) + .into_response() + } + } + }), + ) + } + + #[tokio::test] + async fn test_middleware_on_challenge_approve() { + let (_, www_auth) = test_challenge(); + let app = paid_app(www_auth); + let base_url = spawn_server(app).await; + let provider = TestProvider::new(); + + let on_challenge: OnChallenge = + Box::new(|_challenge| Box::pin(async { ChallengeAction::Approve })); + + let client = ClientBuilder::new(reqwest::Client::new()) + .with(PaymentMiddleware::new(provider.clone()).with_on_challenge(on_challenge)) + .build(); + + let resp = client + .get(format!("{}/paid", base_url)) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + assert_eq!(provider.call_count(), 1); + } + + #[tokio::test] + async fn test_middleware_on_challenge_credential() { + let (challenge, www_auth) = test_challenge(); + let app = paid_app(www_auth); + let base_url = spawn_server(app).await; + let provider = TestProvider::new(); + + let echo = challenge.to_echo(); + let on_challenge: OnChallenge = Box::new(move |_challenge| { + let echo = echo.clone(); + Box::pin(async move { + let cred = PaymentCredential::new(echo, PaymentPayload::hash("0xcustom")); + ChallengeAction::Credential(Box::new(cred)) + }) + }); + + let client = ClientBuilder::new(reqwest::Client::new()) + .with(PaymentMiddleware::new(provider.clone()).with_on_challenge(on_challenge)) + .build(); + + let resp = client + .get(format!("{}/paid", base_url)) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), reqwest::StatusCode::OK); + assert_eq!(provider.call_count(), 0); // provider was NOT called + } + + #[tokio::test] + async fn test_middleware_on_challenge_decline() { + let (_, www_auth) = test_challenge(); + let app = paid_app(www_auth); + let base_url = spawn_server(app).await; + let provider = TestProvider::new(); + + let on_challenge: OnChallenge = + Box::new(|_challenge| Box::pin(async { ChallengeAction::Decline })); + + let client = ClientBuilder::new(reqwest::Client::new()) + .with(PaymentMiddleware::new(provider.clone()).with_on_challenge(on_challenge)) + .build(); + + let err = client + .get(format!("{}/paid", base_url)) + .send() + .await + .unwrap_err(); + + assert!( + err.to_string().contains("payment declined"), + "expected payment declined error, got: {}", + err + ); + assert_eq!(provider.call_count(), 0); + } } } diff --git a/src/client/mod.rs b/src/client/mod.rs index fd55432..181cc61 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -33,7 +33,7 @@ pub use error::HttpError; pub use provider::{MultiProvider, PaymentProvider}; #[cfg(feature = "client")] -pub use fetch::PaymentExt as Fetch; +pub use fetch::{ChallengeAction, OnChallenge, PaymentExt as Fetch}; #[cfg(feature = "middleware")] pub use middleware::PaymentMiddleware;