diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 59fc639..e8196b1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,10 +52,10 @@ jobs: - run: cargo update -p native-tls - uses: taiki-e/install-action@cargo-hack - name: Tests - run: cargo test --features tempo,stripe,server,client,axum,middleware,tower,utils,integration-stripe + run: cargo test --features tempo,stripe,ws,server,client,axum,middleware,tower,utils,integration-stripe,integration-ws env: STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }} - - run: cargo hack check --each-feature --no-dev-deps --skip integration + - run: cargo hack check --each-feature --no-dev-deps --skip integration,integration-stripe,integration-ws - name: Check examples run: find examples -name Cargo.toml -exec cargo check --manifest-path {} \; diff --git a/Cargo.toml b/Cargo.toml index 9d0f8ad..413b250 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ default = ["reqwest-default-tls"] # Side selection client = ["dep:reqwest"] -server = ["tokio", "futures-core", "async-stream"] +server = ["tokio", "futures-core", "async-stream", "http-types"] # Method implementations evm = ["alloy", "hex", "rand"] @@ -35,12 +35,16 @@ tower = ["dep:tower-layer", "dep:tower-service", "http-types", "dep:http-body", # Axum middleware support (server-side convenience) axum = ["dep:axum-core", "server", "http-types"] +# WebSocket support +ws = ["server", "client", "dep:tokio-tungstenite", "dep:futures-util"] + reqwest-default-tls = ["reqwest?/default-tls"] reqwest-native-tls = ["reqwest?/native-tls"] reqwest-rustls-tls = ["reqwest?/rustls-tls"] # Integration tests (requires a running Tempo localnet) integration = ["tempo", "server", "client", "axum"] +integration-ws = ["ws", "tempo", "server", "client", "axum"] integration-stripe = ["stripe", "server", "client", "axum"] [dependencies] @@ -82,9 +86,15 @@ http-body = { version = "1", optional = true } # Axum dependencies (optional) axum-core = { version = "0.5", optional = true } +# WebSocket dependencies (optional) +tokio-tungstenite = { version = "0.26", optional = true } +futures-util = { version = "0.3", optional = true } + [dev-dependencies] tokio = { version = "1", features = ["rt-multi-thread", "macros", "net"] } -axum = { version = "0.8" } +axum = { version = "0.8", features = ["ws"] } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } hex = "0.4" +tokio-tungstenite = "0.26" +futures-util = "0.3" diff --git a/README.md b/README.md index b9712ed..ddec0f6 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,39 @@ let resp = reqwest::Client::new() .await?; ``` +### WebSocket + +```rust +use mpp::server::ws::{WsMessage, WsResponse}; + +// Server: parse incoming WS message, send challenge/receipt +let msg: WsMessage = serde_json::from_str(&text)?; +if let WsMessage::Credential { credential } = msg { + let parsed = mpp::parse_authorization(&credential)?; + let receipt = mpp.verify_credential(&parsed).await?; + let resp = WsResponse::Receipt { + receipt: serde_json::to_value(&receipt)?, + }; + socket.send(resp.to_text()).await; +} + +// Client: detect challenge, send credential +let msg: mpp::client::ws::WsServerMessage = serde_json::from_str(&text)?; +if let WsServerMessage::Challenge { challenge, .. } = msg { + let cred_msg = serde_json::json!({ + "type": "credential", + "credential": auth_string, + }); + ws.send(cred_msg.to_string()).await; +} +``` + +WSS (WebSocket Secure) is handled at the connection layer — the transport itself is protocol-agnostic. On the server, terminate TLS via a reverse proxy (nginx, Cloudflare) or use `axum-server` with rustls. On the client, `tokio-tungstenite` supports `wss://` URLs via its `native-tls` or `rustls` features: + +```toml +tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } +``` + ## Feature Flags | Feature | Description | @@ -114,6 +147,7 @@ let resp = reqwest::Client::new() | `middleware` | reqwest-middleware support with `PaymentMiddleware` (implies `client`) | | `tower` | Tower middleware for server-side integration | | `axum` | Axum extractor support for server-side convenience | +| `ws` | WebSocket transport for bidirectional session payments | | `utils` | Hex/random utilities for development and testing | ## Payment Methods diff --git a/examples/ws/Cargo.toml b/examples/ws/Cargo.toml new file mode 100644 index 0000000..c218f1f --- /dev/null +++ b/examples/ws/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "ws-example" +version = "0.1.0" +edition = "2021" +publish = false + +[[bin]] +name = "ws-server" +path = "src/server.rs" + +[[bin]] +name = "ws-client" +path = "src/client.rs" + +[dependencies] +mpp = { path = "../..", features = ["server", "client", "ws", "tempo"] } +axum = { version = "0.8", features = ["ws"] } +tokio = { version = "1", features = ["full"] } +tokio-tungstenite = "0.26" +futures-util = "0.3" +serde_json = "1" diff --git a/examples/ws/README.md b/examples/ws/README.md new file mode 100644 index 0000000..4faf972 --- /dev/null +++ b/examples/ws/README.md @@ -0,0 +1,25 @@ +# WebSocket Payment Example + +Demonstrates the MPP WebSocket payment flow with a server that streams +fortunes after payment verification. + +## Running + +```bash +# Start the server +cargo run --bin ws-server + +# In another terminal, start the client +cargo run --bin ws-client +``` + +## Protocol + +1. Client connects via WebSocket +2. Server sends `{ "type": "challenge", ... }` +3. Client responds with `{ "type": "credential", "credential": "Payment ..." }` +4. Server verifies payment and streams data as `{ "type": "message", "data": "..." }` +5. Server sends final `{ "type": "receipt", ... }` and closes + +**Note:** This example uses a mock credential. In production, use +`TempoProvider` to sign real transactions. diff --git a/examples/ws/src/client.rs b/examples/ws/src/client.rs new file mode 100644 index 0000000..403f4c7 --- /dev/null +++ b/examples/ws/src/client.rs @@ -0,0 +1,111 @@ +//! # WebSocket Payment Client +//! +//! Connects to the WS payment server, handles the challenge/credential +//! flow, and prints received fortunes. +//! +//! ## Running +//! +//! ```bash +//! # First start the server: +//! cargo run --bin ws-server +//! +//! # Then in another terminal: +//! cargo run --bin ws-client +//! ``` + +use futures_util::{SinkExt, StreamExt}; +use mpp::client::ws::WsServerMessage; +use mpp::protocol::core::{format_authorization, PaymentPayload}; +use tokio_tungstenite::tungstenite; + +#[tokio::main] +async fn main() { + let url = std::env::args() + .nth(1) + .unwrap_or_else(|| "ws://127.0.0.1:3000/ws".to_string()); + + println!("Connecting to {url} ..."); + + let (mut ws, _) = tokio_tungstenite::connect_async(&url) + .await + .expect("failed to connect"); + + println!("Connected!"); + + while let Some(msg) = ws.next().await { + let msg = match msg { + Ok(tungstenite::Message::Text(text)) => text, + Ok(tungstenite::Message::Close(_)) => { + println!("Server closed connection"); + break; + } + Err(e) => { + eprintln!("WS error: {e}"); + break; + } + _ => continue, + }; + + let server_msg: WsServerMessage = match serde_json::from_str(&msg) { + Ok(m) => m, + Err(e) => { + eprintln!("Failed to parse server message: {e}"); + continue; + } + }; + + match server_msg { + WsServerMessage::Challenge { challenge, .. } => { + println!("Received payment challenge"); + + // Parse the challenge + let parsed: mpp::PaymentChallenge = + serde_json::from_value(challenge).expect("parse challenge"); + + // Create a mock credential (in real use, sign a transaction) + let credential = mpp::PaymentCredential::new( + parsed.to_echo(), + PaymentPayload::hash( + "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef", + ), + ); + let auth_str = format_authorization(&credential).unwrap(); + + // Send credential + let cred_msg = serde_json::json!({ + "type": "credential", + "credential": auth_str, + }); + ws.send(tungstenite::Message::Text(cred_msg.to_string().into())) + .await + .unwrap(); + println!("Sent credential"); + } + WsServerMessage::Data { data } => { + println!(" {data}"); + } + WsServerMessage::NeedVoucher { + channel_id, + required_cumulative, + .. + } => { + println!( + "Server needs voucher for channel {channel_id} (required: {required_cumulative})" + ); + // In real use: sign and send a new voucher + } + WsServerMessage::Receipt { receipt } => { + println!("\nPayment receipt:"); + println!(" Status: {}", receipt["status"]); + println!(" Reference: {}", receipt["reference"]); + break; + } + WsServerMessage::Error { error } => { + eprintln!("Server error: {error}"); + // In this demo, the mock credential will fail verification. + // A real client would use TempoProvider to sign a transaction. + break; + } + } + } +} diff --git a/examples/ws/src/server.rs b/examples/ws/src/server.rs new file mode 100644 index 0000000..1932927 --- /dev/null +++ b/examples/ws/src/server.rs @@ -0,0 +1,155 @@ +//! # WebSocket Payment Server +//! +//! A payment-gated WebSocket server that streams fortunes after payment. +//! +//! ## Running +//! +//! ```bash +//! cargo run --bin ws-server +//! ``` +//! +//! The server listens on `ws://localhost:3000/ws`. + +use std::future::Future; +use std::sync::Arc; + +use axum::extract::ws::{Message, WebSocket}; +use axum::{extract::ws::WebSocketUpgrade, routing::get, Router}; +use mpp::protocol::core::Receipt; +use mpp::protocol::intents::ChargeRequest; +use mpp::protocol::traits::{ChargeMethod, VerificationError}; +use mpp::server::ws::{WsMessage, WsResponse}; +use mpp::server::Mpp; +use mpp::PaymentCredential; + +const FORTUNES: &[&str] = &[ + "A beautiful day awaits you.", + "Good things come to those who pay.", + "Your code will compile on the first try.", + "A WebSocket connection is worth a thousand HTTP requests.", + "Fortune favors the persistent.", +]; + +/// Mock charge method that accepts any credential — for demo purposes only. +#[derive(Clone)] +struct MockMethod; + +#[allow(clippy::manual_async_fn)] +impl ChargeMethod for MockMethod { + fn method(&self) -> &str { + "mock" + } + + fn verify( + &self, + _credential: &PaymentCredential, + _request: &ChargeRequest, + ) -> impl Future> + Send { + async { Ok(Receipt::success("mock", "mock-ws-receipt")) } + } +} + +type Payment = Mpp; + +#[tokio::main] +async fn main() { + let mpp = Mpp::new(MockMethod, "ws-example.local", "ws-example-secret"); + + let mpp = Arc::new(mpp); + + let app = Router::new() + .route("/ws", get(ws_handler)) + .with_state(mpp); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .expect("failed to bind"); + println!("WebSocket server listening on ws://127.0.0.1:3000/ws"); + + axum::serve(listener, app).await.expect("server error"); +} + +async fn ws_handler( + ws: WebSocketUpgrade, + axum::extract::State(mpp): axum::extract::State>, +) -> impl axum::response::IntoResponse { + ws.on_upgrade(move |mut socket| async move { + // 1. Send challenge + let Ok(challenge) = mpp.charge_challenge("10000", "0x0", "0x0") else { + let _ = send_error(&mut socket, "Failed to create challenge").await; + return; + }; + + println!("Sending challenge..."); + let challenge_resp = WsResponse::Challenge { + challenge: serde_json::to_value(&challenge).unwrap(), + error: None, + }; + if socket + .send(Message::Text(challenge_resp.to_text().into())) + .await + .is_err() + { + return; + } + + // 2. Wait for credential + let receipt = loop { + let Some(Ok(Message::Text(msg))) = socket.recv().await else { + return; + }; + + let Ok(WsMessage::Credential { credential }) = serde_json::from_str(&msg) else { + let _ = send_error(&mut socket, "Expected credential message").await; + continue; + }; + + let Ok(parsed) = mpp::parse_authorization(&credential) else { + let _ = send_error(&mut socket, "Malformed credential").await; + continue; + }; + + match mpp.verify_credential(&parsed).await { + Ok(receipt) => { + println!("Payment verified: {}", receipt.reference); + break receipt; + } + Err(e) => { + let _ = send_error(&mut socket, &e.message).await; + } + } + }; + + // 3. Stream fortunes + for i in 1..=3 { + let fortune = FORTUNES[i % FORTUNES.len()]; + let msg = WsResponse::Data { + data: format!("Fortune #{i}: {fortune}"), + }; + if socket + .send(Message::Text(msg.to_text().into())) + .await + .is_err() + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } + + // 4. Send receipt + let receipt_msg = WsResponse::Receipt { + receipt: serde_json::to_value(&receipt).unwrap(), + }; + let _ = socket + .send(Message::Text(receipt_msg.to_text().into())) + .await; + println!("Session complete"); + }) +} + +async fn send_error(socket: &mut WebSocket, error: &str) { + let msg = WsResponse::Error { + error: error.to_string(), + }; + let _ = socket.send(Message::Text(msg.to_text().into())).await; +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 3a0d888..8f7a6c1 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -19,6 +19,10 @@ mod error; mod provider; +pub mod transport; + +#[cfg(feature = "ws")] +pub mod ws; #[cfg(feature = "tempo")] pub mod tempo; diff --git a/src/client/transport.rs b/src/client/transport.rs new file mode 100644 index 0000000..cdc6e18 --- /dev/null +++ b/src/client/transport.rs @@ -0,0 +1,110 @@ +//! Client-side transport abstraction. +//! +//! Abstracts how challenges are received and credentials are sent +//! across different transport protocols (HTTP, WebSocket, MCP, etc.). +//! +//! This matches the mppx `Transport` interface from `mppx/client`. +//! +//! # Built-in transports +//! +//! - [`http()`]: HTTP transport (Authorization/WWW-Authenticate headers) +//! +//! # Custom transports +//! +//! Implement [`Transport`] for custom protocols: +//! +//! ```ignore +//! use mpp::client::transport::{Transport}; +//! use mpp::protocol::core::PaymentChallenge; +//! +//! struct MyTransport; +//! +//! impl Transport for MyTransport { +//! type Request = MyRequest; +//! type Response = MyResponse; +//! +//! fn name(&self) -> &str { "custom" } +//! // ... +//! } +//! ``` + +use crate::error::MppError; +use crate::protocol::core::PaymentChallenge; + +/// Client-side transport trait. +/// +/// Abstracts how the client detects payment-required responses, extracts +/// challenges, and attaches credentials to requests. +pub trait Transport: Send + Sync { + /// The outgoing request type. + type Request; + /// The incoming response type. + type Response; + + /// Transport name for identification (e.g., "http", "ws", "mcp"). + fn name(&self) -> &str; + + /// Check if a response indicates payment is required. + fn is_payment_required(&self, response: &Self::Response) -> bool; + + /// Extract the payment challenge from a payment-required response. + fn get_challenge(&self, response: &Self::Response) -> Result; + + /// Attach a credential string to a request. + fn set_credential(&self, request: Self::Request, credential: &str) -> Self::Request; +} + +/// Reqwest HTTP transport for client-side payment handling. +/// +/// - Detects payment required via 402 status +/// - Extracts challenges from `WWW-Authenticate` header +/// - Sends credentials via `Authorization` header +/// +/// This is the default transport, matching mppx's `Transport.http()`. +pub struct HttpTransport; + +/// Create an HTTP transport instance. +pub fn http() -> HttpTransport { + HttpTransport +} + +impl Transport for HttpTransport { + type Request = reqwest::RequestBuilder; + type Response = reqwest::Response; + + fn name(&self) -> &str { + "http" + } + + fn is_payment_required(&self, response: &Self::Response) -> bool { + response.status() == reqwest::StatusCode::PAYMENT_REQUIRED + } + + fn get_challenge(&self, response: &Self::Response) -> Result { + let header = response + .headers() + .get(reqwest::header::WWW_AUTHENTICATE) + .ok_or_else(|| MppError::MissingHeader("WWW-Authenticate".to_string()))?; + + let header_str = header.to_str().map_err(|e| { + MppError::MalformedCredential(Some(format!("invalid WWW-Authenticate header: {e}"))) + })?; + + crate::protocol::core::parse_www_authenticate(header_str) + } + + fn set_credential(&self, request: Self::Request, credential: &str) -> Self::Request { + request.header(reqwest::header::AUTHORIZATION, credential) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http_transport_name() { + let transport = http(); + assert_eq!(transport.name(), "http"); + } +} diff --git a/src/client/ws.rs b/src/client/ws.rs new file mode 100644 index 0000000..7451acc --- /dev/null +++ b/src/client/ws.rs @@ -0,0 +1,227 @@ +//! WebSocket transport for client-side session payments. +//! +//! Provides a WebSocket transport that implements [`Transport`](super::transport::Transport) +//! for bidirectional payment flows. The client can send vouchers inline over the +//! same WebSocket connection (no separate HTTP request needed). +//! +//! # Message Protocol +//! +//! Uses the same JSON message format as [`server::ws`](crate::server::ws): +//! +//! **Client → Server:** +//! - `{ "type": "credential", "credential": "Payment ..." }` +//! +//! **Server → Client:** +//! - `{ "type": "challenge", "challenge": { ... } }` +//! - `{ "type": "message", "data": "..." }` +//! - `{ "type": "needVoucher", ... }` +//! - `{ "type": "receipt", ... }` +//! +//! # Example +//! +//! ```ignore +//! use mpp::client::ws::WsTransport; +//! use mpp::client::transport::Transport; +//! +//! let transport = WsTransport; +//! ``` + +use serde::{Deserialize, Serialize}; + +use crate::error::MppError; +use crate::protocol::core::PaymentChallenge; + +use super::transport::Transport; + +/// Outgoing WebSocket message from client. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum WsClientMessage { + /// Client sends a payment credential. + Credential { + /// The serialized credential string. + credential: String, + }, + /// Client sends application data. + #[serde(rename = "message")] + Data { + /// Application payload. + data: serde_json::Value, + }, +} + +impl WsClientMessage { + /// Serialize this message to a JSON string for sending over WebSocket. + pub fn to_text(&self) -> String { + serde_json::to_string(self).expect("WsClientMessage serialization cannot fail") + } +} + +/// Incoming WebSocket message from server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum WsServerMessage { + /// Server issues a payment challenge. + Challenge { + /// The payment challenge. + challenge: serde_json::Value, + /// Optional error context. + #[serde(default)] + error: Option, + }, + /// Server sends application data. + #[serde(rename = "message")] + Data { + /// Application payload. + data: String, + }, + /// Server signals balance exhausted. + NeedVoucher { + /// Channel identifier. + #[serde(rename = "channelId")] + channel_id: String, + /// Minimum cumulative amount required. + #[serde(rename = "requiredCumulative")] + required_cumulative: String, + /// Current highest accepted cumulative amount. + #[serde(rename = "acceptedCumulative")] + accepted_cumulative: String, + /// Current on-chain deposit. + deposit: String, + }, + /// Server sends final payment receipt. + Receipt { + /// The payment receipt. + receipt: serde_json::Value, + }, + /// Server sends an error. + Error { + /// Error message. + error: String, + }, +} + +/// WebSocket transport for client-side payment handling. +/// +/// Detects payment challenges from JSON WebSocket messages and attaches +/// credentials as JSON messages (no HTTP headers involved). +pub struct WsTransport; + +/// Create a WebSocket transport instance. +pub fn ws() -> WsTransport { + WsTransport +} + +impl Transport for WsTransport { + type Request = WsClientMessage; + type Response = WsServerMessage; + + fn name(&self) -> &str { + "ws" + } + + fn is_payment_required(&self, response: &Self::Response) -> bool { + matches!(response, WsServerMessage::Challenge { .. }) + } + + fn get_challenge(&self, response: &Self::Response) -> Result { + let WsServerMessage::Challenge { challenge, .. } = response else { + return Err(MppError::MissingHeader( + "no challenge in WS message".to_string(), + )); + }; + + serde_json::from_value(challenge.clone()).map_err(|e| { + MppError::MalformedCredential(Some(format!("failed to parse WS challenge: {e}"))) + }) + } + + fn set_credential(&self, _request: Self::Request, credential: &str) -> Self::Request { + WsClientMessage::Credential { + credential: credential.to_string(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ws_transport_name() { + let transport = ws(); + assert_eq!(transport.name(), "ws"); + } + + #[test] + fn test_ws_client_message_credential_serde() { + let msg = WsClientMessage::Credential { + credential: "Payment id=\"abc\"".to_string(), + }; + let json = msg.to_text(); + assert!(json.contains("\"type\":\"credential\"")); + + let parsed: WsClientMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(parsed, WsClientMessage::Credential { .. })); + } + + #[test] + fn test_ws_client_message_data_serde() { + let msg = WsClientMessage::Data { + data: serde_json::json!({"prompt": "hello"}), + }; + let json = msg.to_text(); + assert!(json.contains("\"type\":\"message\"")); + } + + #[test] + fn test_ws_server_message_challenge() { + let json = r#"{"type":"challenge","challenge":{"id":"ch-1","realm":"test","method":"tempo","intent":"charge","request":"eyJ0ZXN0IjoidmFsdWUifQ"}}"#; + let parsed: WsServerMessage = serde_json::from_str(json).unwrap(); + assert!(matches!(parsed, WsServerMessage::Challenge { .. })); + } + + #[test] + fn test_ws_server_message_need_voucher() { + let json = r#"{"type":"needVoucher","channelId":"0xabc","requiredCumulative":"2000","acceptedCumulative":"1000","deposit":"5000"}"#; + let parsed: WsServerMessage = serde_json::from_str(json).unwrap(); + match parsed { + WsServerMessage::NeedVoucher { channel_id, .. } => { + assert_eq!(channel_id, "0xabc"); + } + _ => panic!("expected NeedVoucher"), + } + } + + #[test] + fn test_is_payment_required() { + let transport = ws(); + + let challenge = WsServerMessage::Challenge { + challenge: serde_json::json!({}), + error: None, + }; + assert!(transport.is_payment_required(&challenge)); + + let data = WsServerMessage::Data { + data: "hello".into(), + }; + assert!(!transport.is_payment_required(&data)); + } + + #[test] + fn test_set_credential() { + let transport = ws(); + let dummy = WsClientMessage::Data { + data: serde_json::json!({}), + }; + + let result = transport.set_credential(dummy, "Payment id=\"abc\""); + match result { + WsClientMessage::Credential { credential } => { + assert_eq!(credential, "Payment id=\"abc\""); + } + _ => panic!("expected Credential message"), + } + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 6266346..f2e2b78 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -29,6 +29,13 @@ mod amount; mod mpp; pub mod sse; +pub mod transport; + +#[cfg(feature = "ws")] +pub mod ws; + +#[cfg(all(feature = "ws", feature = "tempo"))] +pub mod ws_session; #[cfg(feature = "tower")] pub mod middleware; diff --git a/src/server/transport.rs b/src/server/transport.rs new file mode 100644 index 0000000..8ed3ebb --- /dev/null +++ b/src/server/transport.rs @@ -0,0 +1,314 @@ +//! Server-side transport abstraction. +//! +//! Abstracts how challenges are issued and credentials are received +//! across different transport protocols (HTTP, WebSocket, MCP, etc.). +//! +//! This matches the mppx `Transport` interface from `mppx/server`. +//! +//! # Built-in transports +//! +//! - [`http()`]: HTTP transport (Authorization/WWW-Authenticate headers) +//! +//! # Custom transports +//! +//! Implement [`Transport`] for custom protocols: +//! +//! ```ignore +//! use mpp::server::transport::{Transport, ChallengeContext}; +//! use mpp::protocol::core::{PaymentChallenge, PaymentCredential, Receipt}; +//! +//! struct MyTransport; +//! +//! impl Transport for MyTransport { +//! type Input = MyRequest; +//! type ChallengeOutput = MyResponse; +//! type ReceiptOutput = MyResponse; +//! +//! fn name(&self) -> &str { "custom" } +//! // ... +//! } +//! ``` + +use crate::error::MppError; +use crate::protocol::core::{PaymentChallenge, PaymentCredential, Receipt}; + +/// Context passed to [`Transport::respond_challenge`]. +pub struct ChallengeContext<'a, I> { + /// The payment challenge to send to the client. + pub challenge: &'a PaymentChallenge, + /// The original transport input (e.g., HTTP request). + pub input: &'a I, + /// Optional error message for the client. + pub error: Option<&'a str>, +} + +/// Context passed to [`Transport::respond_receipt`]. +pub struct ReceiptContext<'a, R> { + /// The challenge ID this receipt corresponds to. + pub challenge_id: &'a str, + /// The payment receipt. + pub receipt: &'a Receipt, + /// The application response to attach the receipt to. + pub response: R, +} + +/// Server-side transport trait. +/// +/// Abstracts how the server extracts credentials from incoming requests, +/// issues payment challenges, and attaches receipts to responses. +pub trait Transport: Send + Sync { + /// The incoming request/message type (e.g., `http::Request`). + type Input; + /// The response type for payment challenges (e.g., `http::Response`). + type ChallengeOutput; + /// The response type after attaching a receipt. + type ReceiptOutput; + + /// Transport name for identification (e.g., "http", "ws", "mcp"). + fn name(&self) -> &str; + + /// Extract a payment credential from the transport input. + /// + /// Returns `Ok(Some(credential))` if a valid credential is present, + /// `Ok(None)` if no credential was provided (trigger challenge), + /// or `Err` if the credential is malformed. + fn get_credential(&self, input: &Self::Input) -> Result, MppError>; + + /// Create a transport response for a payment challenge. + fn respond_challenge(&self, ctx: ChallengeContext<'_, Self::Input>) -> Self::ChallengeOutput; + + /// Attach a receipt to a successful response. + fn respond_receipt(&self, ctx: ReceiptContext<'_, Self::ReceiptOutput>) -> Self::ReceiptOutput; +} + +/// HTTP transport for server-side payment handling. +/// +/// - Reads credentials from the `Authorization` header +/// - Issues challenges via `WWW-Authenticate` header with 402 status +/// - Attaches receipts via `Payment-Receipt` header +/// +/// This is the default transport, matching mppx's `Transport.http()`. +pub struct HttpTransport; + +/// Create an HTTP transport instance. +pub fn http() -> HttpTransport { + HttpTransport +} + +impl Transport for HttpTransport { + type Input = http_types::Request<()>; + type ChallengeOutput = http_types::Response; + type ReceiptOutput = http_types::Response; + + fn name(&self) -> &str { + "http" + } + + fn get_credential(&self, input: &Self::Input) -> Result, MppError> { + let Some(header) = input.headers().get(http_types::header::AUTHORIZATION) else { + return Ok(None); + }; + + let header_str = header + .to_str() + .map_err(|e| MppError::MalformedCredential(Some(format!("invalid header: {e}"))))?; + + let Some(payment) = crate::protocol::core::extract_payment_scheme(header_str) else { + return Ok(None); + }; + + // extract_payment_scheme returns the full "Payment ..." fragment + let credential = crate::protocol::core::parse_authorization(payment).map_err(|e| { + MppError::MalformedCredential(Some(format!("failed to parse credential: {e}"))) + })?; + + Ok(Some(credential)) + } + + fn respond_challenge(&self, ctx: ChallengeContext<'_, Self::Input>) -> Self::ChallengeOutput { + let www_auth = crate::protocol::core::format_www_authenticate(ctx.challenge) + .unwrap_or_else(|_| "Payment".to_string()); + + let body = match ctx.error { + Some(msg) => serde_json::json!({ "error": msg }).to_string(), + None => serde_json::json!({ "error": "Payment Required" }).to_string(), + }; + + let mut resp = http_types::Response::builder() + .status(http_types::StatusCode::PAYMENT_REQUIRED) + .header(http_types::header::WWW_AUTHENTICATE, &www_auth) + .header(http_types::header::CONTENT_TYPE, "application/json") + .body(body) + .expect("response builder cannot fail"); + + // Add Cache-Control: no-store to prevent caching of challenges + resp.headers_mut().insert( + http_types::header::CACHE_CONTROL, + http_types::HeaderValue::from_static("no-store"), + ); + + resp + } + + fn respond_receipt(&self, ctx: ReceiptContext<'_, Self::ReceiptOutput>) -> Self::ReceiptOutput { + let receipt_header = + crate::protocol::core::format_receipt(ctx.receipt).unwrap_or_else(|_| String::new()); + + let mut resp = ctx.response; + if let Ok(value) = http_types::HeaderValue::from_str(&receipt_header) { + resp.headers_mut() + .insert(crate::protocol::core::PAYMENT_RECEIPT_HEADER, value); + } + resp + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http_transport_name() { + let transport = http(); + assert_eq!(transport.name(), "http"); + } + + #[test] + fn test_http_get_credential_none() { + let transport = http(); + let req = http_types::Request::builder() + .uri("/test") + .body(()) + .unwrap(); + let result = transport.get_credential(&req).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_http_get_credential_non_payment_auth() { + let transport = http(); + let req = http_types::Request::builder() + .uri("/test") + .header("Authorization", "Bearer some-token") + .body(()) + .unwrap(); + let result = transport.get_credential(&req).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_http_get_credential_valid_payment() { + let transport = http(); + + // Build a valid Payment authorization header + let challenge = PaymentChallenge::new( + "test-id", + "test.example.com", + "tempo", + "charge", + crate::protocol::core::Base64UrlJson::from_value( + &serde_json::json!({"amount": "1000"}), + ) + .unwrap(), + ); + let credential = crate::protocol::core::PaymentCredential::new( + challenge.to_echo(), + crate::protocol::core::PaymentPayload::hash("0xdeadbeef"), + ); + let auth_header = crate::protocol::core::format_authorization(&credential).unwrap(); + + let req = http_types::Request::builder() + .uri("/test") + .header("Authorization", &auth_header) + .body(()) + .unwrap(); + + let result = transport.get_credential(&req).unwrap(); + assert!(result.is_some(), "should parse valid Payment credential"); + let parsed = result.unwrap(); + assert_eq!(parsed.challenge.id, "test-id"); + } + + #[test] + fn test_http_respond_challenge() { + let transport = http(); + let challenge = PaymentChallenge::new( + "test-id", + "test.example.com", + "tempo", + "charge", + crate::protocol::core::Base64UrlJson::from_value( + &serde_json::json!({"amount": "1000"}), + ) + .unwrap(), + ); + let req = http_types::Request::builder() + .uri("/test") + .body(()) + .unwrap(); + + let resp = transport.respond_challenge(ChallengeContext { + challenge: &challenge, + input: &req, + error: None, + }); + + assert_eq!(resp.status(), http_types::StatusCode::PAYMENT_REQUIRED); + assert!(resp + .headers() + .get(http_types::header::WWW_AUTHENTICATE) + .is_some()); + assert!(resp.body().contains("Payment Required")); + } + + #[test] + fn test_http_respond_challenge_with_error() { + let transport = http(); + let challenge = PaymentChallenge::new( + "test-id", + "test.example.com", + "tempo", + "charge", + crate::protocol::core::Base64UrlJson::from_value( + &serde_json::json!({"amount": "1000"}), + ) + .unwrap(), + ); + let req = http_types::Request::builder() + .uri("/test") + .body(()) + .unwrap(); + + let resp = transport.respond_challenge(ChallengeContext { + challenge: &challenge, + input: &req, + error: Some("Verification failed"), + }); + + assert_eq!(resp.status(), http_types::StatusCode::PAYMENT_REQUIRED); + assert!(resp.body().contains("Verification failed")); + } + + #[test] + fn test_http_respond_receipt() { + let transport = http(); + let receipt = Receipt::success("tempo", "0xabc123"); + + let resp = http_types::Response::builder() + .status(http_types::StatusCode::OK) + .body("ok".to_string()) + .unwrap(); + + let resp = transport.respond_receipt(ReceiptContext { + challenge_id: "ch-1", + receipt: &receipt, + response: resp, + }); + + assert_eq!(resp.status(), http_types::StatusCode::OK); + assert!(resp + .headers() + .get(crate::protocol::core::PAYMENT_RECEIPT_HEADER) + .is_some()); + } +} diff --git a/src/server/ws.rs b/src/server/ws.rs new file mode 100644 index 0000000..5317041 --- /dev/null +++ b/src/server/ws.rs @@ -0,0 +1,313 @@ +//! WebSocket transport for server-side session payments. +//! +//! Provides a WebSocket transport that implements [`Transport`](super::transport::Transport) +//! for bidirectional payment flows. Unlike SSE (server→client only), WebSocket +//! allows the client to send vouchers inline without a separate HTTP request. +//! +//! # Message Protocol +//! +//! All messages are JSON-encoded with a `type` discriminator: +//! +//! **Client → Server:** +//! - `{ "type": "credential", "credential": "Payment ..." }` — payment credential +//! +//! **Server → Client:** +//! - `{ "type": "challenge", "challenge": { ... } }` — payment challenge +//! - `{ "type": "message", "data": "..." }` — application data +//! - `{ "type": "needVoucher", ... }` — balance exhausted, send new voucher +//! - `{ "type": "receipt", "receipt": { ... } }` — final payment receipt +//! - `{ "type": "error", "error": "..." }` — error message +//! +//! # Example +//! +//! ```ignore +//! use mpp::server::ws::{WsTransport, WsMessage}; +//! +//! let transport = WsTransport; +//! +//! // Parse incoming WS message +//! let msg: WsMessage = serde_json::from_str(&text)?; +//! let credential = transport.get_credential(&msg)?; +//! ``` + +use serde::{Deserialize, Serialize}; + +use crate::error::MppError; +use crate::protocol::core::PaymentCredential; + +use super::transport::{ChallengeContext, ReceiptContext, Transport}; + +/// Incoming WebSocket message from client. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum WsMessage { + /// Client sends a payment credential. + Credential { + /// The serialized credential string (e.g., "Payment id=..., ..."). + credential: String, + }, + /// Client sends application data. + #[serde(rename = "message")] + Data { + /// Application payload. + data: serde_json::Value, + }, +} + +/// Outgoing WebSocket message from server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum WsResponse { + /// Server issues a payment challenge. + Challenge { + /// The payment challenge. + challenge: serde_json::Value, + /// Optional error context. + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + }, + /// Server sends application data. + #[serde(rename = "message")] + Data { + /// Application payload. + data: String, + }, + /// Server signals balance exhausted — client should send a new voucher. + NeedVoucher { + /// Channel identifier. + #[serde(rename = "channelId")] + channel_id: String, + /// Minimum cumulative amount required for next voucher. + #[serde(rename = "requiredCumulative")] + required_cumulative: String, + /// Current highest accepted cumulative amount. + #[serde(rename = "acceptedCumulative")] + accepted_cumulative: String, + /// Current on-chain deposit. + deposit: String, + }, + /// Server sends final payment receipt. + Receipt { + /// The payment receipt. + receipt: serde_json::Value, + }, + /// Server sends an error. + Error { + /// Error message. + error: String, + }, +} + +impl WsResponse { + /// Serialize this response to a JSON string for sending over WebSocket. + pub fn to_text(&self) -> String { + serde_json::to_string(self).expect("WsResponse serialization cannot fail") + } +} + +/// WebSocket transport for server-side payment handling. +/// +/// Messages are JSON-encoded WebSocket text frames with a `type` discriminator. +/// The client sends credentials as `{ "type": "credential", "credential": "Payment ..." }`, +/// and the server responds with challenges, data, and receipts. +pub struct WsTransport; + +/// Create a WebSocket transport instance. +pub fn ws() -> WsTransport { + WsTransport +} + +impl Transport for WsTransport { + type Input = WsMessage; + type ChallengeOutput = WsResponse; + type ReceiptOutput = WsResponse; + + fn name(&self) -> &str { + "ws" + } + + fn get_credential(&self, input: &Self::Input) -> Result, MppError> { + match input { + WsMessage::Credential { credential } => { + let parsed = + crate::protocol::core::parse_authorization(credential).map_err(|e| { + MppError::MalformedCredential(Some(format!( + "failed to parse WS credential: {e}" + ))) + })?; + Ok(Some(parsed)) + } + WsMessage::Data { .. } => Ok(None), + } + } + + fn respond_challenge(&self, ctx: ChallengeContext<'_, Self::Input>) -> Self::ChallengeOutput { + let challenge_json = serde_json::to_value(ctx.challenge) + .unwrap_or_else(|_| serde_json::json!({"error": "serialization failed"})); + + WsResponse::Challenge { + challenge: challenge_json, + error: ctx.error.map(|s| s.to_string()), + } + } + + fn respond_receipt(&self, ctx: ReceiptContext<'_, Self::ReceiptOutput>) -> Self::ReceiptOutput { + let receipt_json = serde_json::to_value(ctx.receipt) + .unwrap_or_else(|_| serde_json::json!({"error": "serialization failed"})); + + WsResponse::Receipt { + receipt: receipt_json, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::core::{Base64UrlJson, PaymentChallenge, Receipt}; + + #[test] + fn test_ws_transport_name() { + let transport = ws(); + assert_eq!(transport.name(), "ws"); + } + + #[test] + fn test_ws_message_credential_serde() { + let msg = WsMessage::Credential { + credential: "Payment id=\"abc\"".to_string(), + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("\"type\":\"credential\"")); + assert!(json.contains("\"credential\":\"Payment id=\\\"abc\\\"\"")); + + let parsed: WsMessage = serde_json::from_str(&json).unwrap(); + match parsed { + WsMessage::Credential { credential } => { + assert_eq!(credential, "Payment id=\"abc\"") + } + _ => panic!("expected Credential variant"), + } + } + + #[test] + fn test_ws_message_data_serde() { + let msg = WsMessage::Data { + data: serde_json::json!({"prompt": "hello"}), + }; + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("\"type\":\"message\"")); + + let parsed: WsMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(parsed, WsMessage::Data { .. })); + } + + #[test] + fn test_ws_response_challenge_serde() { + let resp = WsResponse::Challenge { + challenge: serde_json::json!({"id": "ch-1", "method": "tempo"}), + error: None, + }; + let json = resp.to_text(); + assert!(json.contains("\"type\":\"challenge\"")); + assert!(json.contains("\"ch-1\"")); + } + + #[test] + fn test_ws_response_need_voucher_serde() { + let resp = WsResponse::NeedVoucher { + channel_id: "0xabc".into(), + required_cumulative: "2000".into(), + accepted_cumulative: "1000".into(), + deposit: "5000".into(), + }; + let json = resp.to_text(); + assert!(json.contains("\"type\":\"needVoucher\"")); + assert!(json.contains("\"channelId\":\"0xabc\"")); + } + + #[test] + fn test_ws_response_receipt_serde() { + let resp = WsResponse::Receipt { + receipt: serde_json::json!({"status": "success", "reference": "0x123"}), + }; + let json = resp.to_text(); + assert!(json.contains("\"type\":\"receipt\"")); + assert!(json.contains("\"0x123\"")); + } + + #[test] + fn test_ws_response_error_serde() { + let resp = WsResponse::Error { + error: "payment failed".into(), + }; + let json = resp.to_text(); + assert!(json.contains("\"type\":\"error\"")); + assert!(json.contains("payment failed")); + } + + #[test] + fn test_ws_get_credential_none_for_data() { + let transport = ws(); + let msg = WsMessage::Data { + data: serde_json::json!({"prompt": "hello"}), + }; + let result = transport.get_credential(&msg).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_ws_respond_challenge() { + let transport = ws(); + let challenge = PaymentChallenge::new( + "test-id", + "test.example.com", + "tempo", + "charge", + Base64UrlJson::from_value(&serde_json::json!({"amount": "1000"})).unwrap(), + ); + let msg = WsMessage::Data { + data: serde_json::json!({}), + }; + + let resp = transport.respond_challenge(ChallengeContext { + challenge: &challenge, + input: &msg, + error: None, + }); + + match resp { + WsResponse::Challenge { + challenge: ch, + error, + } => { + assert!(ch.get("id").is_some()); + assert!(error.is_none()); + } + _ => panic!("expected Challenge response"), + } + } + + #[test] + fn test_ws_respond_receipt() { + let transport = ws(); + let receipt = Receipt::success("tempo", "0xabc123"); + + // The response input doesn't matter for receipts — it's replaced + let dummy = WsResponse::Data { data: "ok".into() }; + + let resp = transport.respond_receipt(ReceiptContext { + challenge_id: "ch-1", + receipt: &receipt, + response: dummy, + }); + + match resp { + WsResponse::Receipt { receipt } => { + assert_eq!(receipt["status"], "success"); + assert_eq!(receipt["reference"], "0xabc123"); + } + _ => panic!("expected Receipt response"), + } + } +} diff --git a/src/server/ws_session.rs b/src/server/ws_session.rs new file mode 100644 index 0000000..69a2ee1 --- /dev/null +++ b/src/server/ws_session.rs @@ -0,0 +1,183 @@ +//! WebSocket session handler with metered streaming. +//! +//! Implements the full session payment flow over WebSocket, equivalent to +//! the SSE metering loop in [`sse::serve`](super::sse::serve) but with +//! bidirectional communication — clients send vouchers inline as WS frames +//! instead of separate HTTP requests. +//! +//! # Flow +//! +//! 1. Server sends session challenge +//! 2. Client sends open credential (with deposit transaction) +//! 3. Server verifies, begins streaming data +//! 4. Per tick: deduct from channel balance +//! 5. When exhausted: send `needVoucher`, wait for voucher frame +//! 6. Client sends voucher credential → server verifies, resumes +//! 7. On completion: send session receipt, close +//! +//! # Example +//! +//! ```ignore +//! use mpp::server::ws_session::{WsSessionOptions, ws_session}; +//! +//! ws_session(socket, WsSessionOptions { +//! store, +//! mpp: &mpp, +//! channel_id: "0xabc", +//! challenge_id: "ch-1", +//! tick_cost: 1000, +//! generate: my_stream, +//! poll_interval_ms: 100, +//! }).await; +//! ``` + +use std::sync::Arc; + +use futures_util::{SinkExt, StreamExt}; +use time::format_description::well_known::Iso8601; +use time::OffsetDateTime; + +use super::ws::{WsMessage, WsResponse}; +use crate::protocol::core::parse_authorization; +use crate::protocol::methods::tempo::session_method::deduct_from_channel; +use crate::protocol::methods::tempo::session_receipt::SessionReceipt; +use crate::protocol::traits::{ChargeMethod, SessionMethod}; + +/// Options for [`ws_session`]. +pub struct WsSessionOptions { + /// Channel store for balance tracking. + pub store: Arc, + /// Channel ID (hex). + pub channel_id: String, + /// Challenge ID for the receipt. + pub challenge_id: String, + /// Cost per tick (emitted value) in base units. + pub tick_cost: u128, + /// The async generator producing application data. + pub generate: G, + /// Polling interval in ms when `wait_for_update` is not available. Default: 100. + pub poll_interval_ms: u64, +} + +/// Run a metered session over a split WebSocket connection. +/// +/// `sender` emits data frames and payment control messages (needVoucher, receipt). +/// `receiver` listens for incoming voucher credentials and updates the channel store. +/// +/// This is the WebSocket equivalent of [`sse::serve`](super::sse::serve), with the +/// key advantage that vouchers arrive on the same connection (no separate HTTP POST). +pub async fn ws_session(sender: &mut S, options: WsSessionOptions) +where + G: futures_core::Stream + Send + Unpin + 'static, + S: futures_util::Sink> + Send + Unpin, +{ + let WsSessionOptions { + store, + channel_id, + challenge_id, + tick_cost, + generate, + poll_interval_ms, + } = options; + + let mut stream = std::pin::pin!(generate); + + while let Some(value) = stream.next().await { + // Deduct, waiting for voucher top-up if insufficient + loop { + match deduct_from_channel(&*store, &channel_id, tick_cost).await { + Ok(_state) => break, + Err(_) => { + // Emit needVoucher frame + if let Ok(Some(ch)) = store.get_channel(&channel_id).await { + let msg = WsResponse::NeedVoucher { + channel_id: channel_id.clone(), + required_cumulative: (ch.spent + tick_cost).to_string(), + accepted_cumulative: ch.highest_voucher_amount.to_string(), + deposit: ch.deposit.to_string(), + }; + let _ = sender.send(msg.to_text()).await; + } + + // Wait for channel update (voucher from receiver) or poll + tokio::select! { + _ = store.wait_for_update(&channel_id) => {}, + _ = tokio::time::sleep(tokio::time::Duration::from_millis(poll_interval_ms)) => {}, + } + } + } + } + + // Send data frame + let msg = WsResponse::Data { data: value }; + if sender.send(msg.to_text()).await.is_err() { + break; + } + } + + // Emit final session receipt + if let Ok(Some(ch)) = store.get_channel(&channel_id).await { + let timestamp = OffsetDateTime::now_utc() + .format(&Iso8601::DEFAULT) + .expect("ISO 8601 formatting cannot fail"); + + let mut receipt = SessionReceipt::new( + timestamp, + &challenge_id, + &channel_id, + ch.highest_voucher_amount.to_string(), + ch.spent.to_string(), + ); + receipt.units = Some(ch.units); + + let msg = WsResponse::Receipt { + receipt: serde_json::to_value(&receipt) + .unwrap_or_else(|_| serde_json::json!({"error": "serialization failed"})), + }; + let _ = sender.send(msg.to_text()).await; + } +} + +/// Process incoming WebSocket messages for voucher credentials. +/// +/// Call this concurrently with [`ws_session`] on the receiver half of a +/// split WebSocket. When a voucher credential arrives, it's verified via +/// the session method, which updates the channel store and unblocks the +/// sender's `wait_for_update`. +pub async fn process_incoming_vouchers(receiver: &mut R, mpp: &crate::server::Mpp) +where + M: ChargeMethod, + S: SessionMethod, + R: futures_util::Stream>> + + Send + + Unpin, +{ + while let Some(Ok(text)) = receiver.next().await { + let Ok(WsMessage::Credential { credential }) = serde_json::from_str(&text) else { + continue; + }; + let Ok(parsed) = parse_authorization(&credential) else { + continue; + }; + let _ = mpp.verify_session(&parsed).await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::methods::tempo::session_method::InMemoryChannelStore; + + #[test] + fn test_ws_session_options_fields() { + let store = Arc::new(InMemoryChannelStore::new()); + let _opts = WsSessionOptions { + store, + channel_id: "0xabc".to_string(), + challenge_id: "ch-1".to_string(), + tick_cost: 1000, + generate: futures_util::stream::empty::(), + poll_interval_ms: 100, + }; + } +} diff --git a/tests/integration_ws.rs b/tests/integration_ws.rs new file mode 100644 index 0000000..481d0a8 --- /dev/null +++ b/tests/integration_ws.rs @@ -0,0 +1,342 @@ +//! Integration tests for the WebSocket transport. +//! +//! Spins up an axum server with a WS endpoint and tests the full +//! challenge → credential → data → receipt flow over WebSocket. +//! +//! # Running +//! +//! ```bash +//! cargo test --features ws,tempo,server,client,axum --test integration_ws +//! ``` + +#![cfg(all(feature = "ws", feature = "tempo", feature = "axum"))] + +use axum::{routing::get, Router}; +use futures_util::{SinkExt, StreamExt}; +use mpp::protocol::core::{format_authorization, PaymentPayload}; +use mpp::server::ws::{WsMessage, WsResponse}; +use mpp::server::{tempo, Mpp, TempoConfig}; +use tokio_tungstenite::tungstenite; + +/// Start an axum server with a WS payment endpoint. +async fn start_ws_server() -> (String, tokio::task::JoinHandle<()>) { + let mpp = Mpp::create( + tempo(TempoConfig { + recipient: "0x742d35Cc6634C0532925a3b844Bc9e7595f1B0F2", + }) + .secret_key("ws-test-secret"), + ) + .expect("failed to create Mpp"); + + let mpp = std::sync::Arc::new(mpp); + + let app = Router::new().route( + "/ws", + get({ + let mpp = mpp.clone(); + move |ws: axum::extract::ws::WebSocketUpgrade| { + let mpp = mpp.clone(); + async move { + ws.on_upgrade(move |mut socket| async move { + use axum::extract::ws::Message; + + // Send challenge + let challenge = mpp.charge("0.01").expect("challenge"); + let challenge_resp = WsResponse::Challenge { + challenge: serde_json::to_value(&challenge).unwrap(), + error: None, + }; + let _ = socket + .send(Message::Text(challenge_resp.to_text().into())) + .await; + + // Wait for credential + while let Some(Ok(Message::Text(text))) = socket.recv().await { + let Ok(WsMessage::Credential { credential }) = + serde_json::from_str(&text) + else { + continue; + }; + + let Ok(parsed) = mpp::parse_authorization(&credential) else { + continue; + }; + + match mpp.verify_credential(&parsed).await { + Ok(receipt) => { + let data = WsResponse::Data { + data: "hello from ws".into(), + }; + let _ = socket.send(Message::Text(data.to_text().into())).await; + + let receipt_msg = WsResponse::Receipt { + receipt: serde_json::to_value(&receipt).unwrap(), + }; + let _ = socket + .send(Message::Text(receipt_msg.to_text().into())) + .await; + break; + } + Err(e) => { + let err = WsResponse::Error { error: e.message }; + let _ = socket.send(Message::Text(err.to_text().into())).await; + } + } + } + }) + } + } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("failed to bind"); + let addr = listener.local_addr().unwrap(); + let url = format!("ws://127.0.0.1:{}", addr.port()); + + let handle = tokio::spawn(async move { + axum::serve(listener, app).await.expect("server error"); + }); + + (url, handle) +} + +/// Full e2e: connect WS → receive challenge → send credential → receive data + receipt. +#[tokio::test] +async fn test_ws_e2e_challenge_credential_flow() { + let (url, handle) = start_ws_server().await; + + let (mut ws, _) = tokio_tungstenite::connect_async(format!("{url}/ws")) + .await + .expect("ws connect failed"); + + // 1. Receive challenge + let msg = ws.next().await.unwrap().unwrap(); + let text = msg.into_text().unwrap(); + let server_msg: WsResponse = serde_json::from_str(&text).unwrap(); + + let WsResponse::Challenge { challenge, .. } = server_msg else { + panic!("expected Challenge, got: {server_msg:?}"); + }; + let challenge: mpp::PaymentChallenge = + serde_json::from_value(challenge).expect("parse challenge"); + + assert_eq!(challenge.method.as_str(), "tempo"); + assert_eq!(challenge.intent.as_str(), "charge"); + + // 2. Send credential (mock — use a hash payload) + let credential = + mpp::PaymentCredential::new(challenge.to_echo(), PaymentPayload::hash("0xdeadbeef")); + let auth_str = format_authorization(&credential).unwrap(); + let cred_msg = WsMessage::Credential { + credential: auth_str, + }; + ws.send(tungstenite::Message::Text( + serde_json::to_string(&cred_msg).unwrap().into(), + )) + .await + .unwrap(); + + // 3. Receive response (either data+receipt or error — depends on mock verify) + // With a mock hash, verify will likely fail. That's fine — we're testing the protocol. + let msg = ws.next().await.unwrap().unwrap(); + let text = msg.into_text().unwrap(); + let response: WsResponse = serde_json::from_str(&text).unwrap(); + + // We accept either an error (mock verify fails) or data (if mock verify somehow passes) + match response { + WsResponse::Error { error } => { + // Expected — mock credential won't pass real tempo verification + assert!(!error.is_empty()); + } + WsResponse::Data { data } => { + assert_eq!(data, "hello from ws"); + } + other => panic!("unexpected response: {other:?}"), + } + + handle.abort(); +} + +/// WS message serialization roundtrip. +#[tokio::test] +async fn test_ws_message_types_over_wire() { + let (url, handle) = start_ws_server().await; + + let (mut ws, _) = tokio_tungstenite::connect_async(format!("{url}/ws")) + .await + .expect("ws connect failed"); + + // Should receive a challenge as first message + let msg = ws.next().await.unwrap().unwrap(); + let text = msg.into_text().unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&text).unwrap(); + + assert_eq!(parsed["type"], "challenge"); + assert!(parsed["challenge"].is_object()); + assert!(parsed["challenge"]["id"].is_string()); + assert!(parsed["challenge"]["method"].is_string()); + + // Send garbage — should get error back + ws.send(tungstenite::Message::Text("not json".into())) + .await + .unwrap(); + + // Server may ignore or send error — just verify connection stays alive + // Send valid but non-credential message + let data_msg = serde_json::json!({"type": "message", "data": {"foo": "bar"}}); + ws.send(tungstenite::Message::Text(data_msg.to_string().into())) + .await + .unwrap(); + + handle.abort(); +} + +/// Credential with wrong challenge ID should be rejected. +#[tokio::test] +async fn test_ws_challenge_id_mismatch_rejected() { + let (url, handle) = start_ws_server().await; + + let (mut ws, _) = tokio_tungstenite::connect_async(format!("{url}/ws")) + .await + .expect("ws connect failed"); + + // Receive challenge + let msg = ws.next().await.unwrap().unwrap(); + let text = msg.into_text().unwrap(); + let server_msg: WsResponse = serde_json::from_str(&text).unwrap(); + let WsResponse::Challenge { .. } = server_msg else { + panic!("expected Challenge, got: {server_msg:?}"); + }; + + // Send credential with a DIFFERENT challenge ID (forged echo) + let fake_challenge = mpp::PaymentChallenge::new( + "wrong-challenge-id", + "test.example.com", + "tempo", + "charge", + mpp::Base64UrlJson::from_value(&serde_json::json!({"amount": "999"})).unwrap(), + ); + let credential = + mpp::PaymentCredential::new(fake_challenge.to_echo(), PaymentPayload::hash("0xdeadbeef")); + let auth_str = format_authorization(&credential).unwrap(); + let cred_msg = WsMessage::Credential { + credential: auth_str, + }; + ws.send(tungstenite::Message::Text( + serde_json::to_string(&cred_msg).unwrap().into(), + )) + .await + .unwrap(); + + // Should get error about challenge ID mismatch + let msg = ws.next().await.unwrap().unwrap(); + let text = msg.into_text().unwrap(); + let response: WsResponse = serde_json::from_str(&text).unwrap(); + + // Credential with wrong challenge ID should be rejected (HMAC mismatch + // or decode failure — either way it must not succeed) + match response { + WsResponse::Error { error } => { + assert!(!error.is_empty(), "error should not be empty"); + } + WsResponse::Challenge { error: Some(e), .. } => { + assert!(!e.is_empty()); + } + WsResponse::Data { .. } | WsResponse::Receipt { .. } => { + panic!("credential with wrong challenge ID should not succeed"); + } + other => panic!("unexpected response: {other:?}"), + } + + handle.abort(); +} + +/// Server/client wire types are cross-compatible. +#[test] +fn test_server_client_wire_type_compat() { + use mpp::client::ws::WsServerMessage; + + // Serialize with server types, deserialize with client types + let server_challenge = WsResponse::Challenge { + challenge: serde_json::json!({"id": "ch-1", "method": "tempo", "intent": "charge", "realm": "test", "request": "eyJ0ZXN0Ijp0cnVlfQ"}), + error: None, + }; + let json = server_challenge.to_text(); + let client_parsed: WsServerMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(client_parsed, WsServerMessage::Challenge { .. })); + + let server_data = WsResponse::Data { + data: "hello".into(), + }; + let json = server_data.to_text(); + let client_parsed: WsServerMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(client_parsed, WsServerMessage::Data { .. })); + + let server_nv = WsResponse::NeedVoucher { + channel_id: "0xabc".into(), + required_cumulative: "2000".into(), + accepted_cumulative: "1000".into(), + deposit: "5000".into(), + }; + let json = server_nv.to_text(); + let client_parsed: WsServerMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(client_parsed, WsServerMessage::NeedVoucher { .. })); + + let server_receipt = WsResponse::Receipt { + receipt: serde_json::json!({"status": "success"}), + }; + let json = server_receipt.to_text(); + let client_parsed: WsServerMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(client_parsed, WsServerMessage::Receipt { .. })); + + let server_err = WsResponse::Error { + error: "bad".into(), + }; + let json = server_err.to_text(); + let client_parsed: WsServerMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(client_parsed, WsServerMessage::Error { .. })); + + // Serialize with client types, deserialize with server types + use mpp::client::ws::WsClientMessage; + let client_cred = WsClientMessage::Credential { + credential: "Payment id=\"abc\"".into(), + }; + let json = client_cred.to_text(); + let server_parsed: WsMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(server_parsed, WsMessage::Credential { .. })); + + let client_data = WsClientMessage::Data { + data: serde_json::json!({"prompt": "hello"}), + }; + let json = client_data.to_text(); + let server_parsed: WsMessage = serde_json::from_str(&json).unwrap(); + assert!(matches!(server_parsed, WsMessage::Data { .. })); +} + +/// NeedVoucher message serde works over the wire. +#[test] +fn test_need_voucher_roundtrip() { + let resp = WsResponse::NeedVoucher { + channel_id: "0xabc123".into(), + required_cumulative: "2000000".into(), + accepted_cumulative: "1000000".into(), + deposit: "5000000".into(), + }; + + let json = resp.to_text(); + let parsed: WsResponse = serde_json::from_str(&json).unwrap(); + + match parsed { + WsResponse::NeedVoucher { + channel_id, + required_cumulative, + .. + } => { + assert_eq!(channel_id, "0xabc123"); + assert_eq!(required_cumulative, "2000000"); + } + _ => panic!("expected NeedVoucher"), + } +}