From 125318a1206a017bb236a179386b79e1e76b4975 Mon Sep 17 00:00:00 2001 From: Tunay Engin Date: Wed, 7 Jan 2026 16:13:09 +0300 Subject: [PATCH 01/21] Add simd-json support and stack-optimized path params Introduces a new json module that uses simd-json for accelerated JSON parsing when enabled, with fallbacks to serde_json. Adds a stack-optimized PathParams type using smallvec for efficient path parameter storage, replacing HashMap usage in requests and routers. Logging is now gated behind a tracing feature using new conditional macros. Updates tests and internal APIs to use PathParams, and adds simd-json and smallvec as dependencies. --- Cargo.lock | 90 +++++++++ crates/rustapi-core/Cargo.toml | 8 +- crates/rustapi-core/src/app.rs | 12 +- crates/rustapi-core/src/error.rs | 12 +- crates/rustapi-core/src/extract.rs | 37 ++-- crates/rustapi-core/src/json.rs | 125 +++++++++++++ crates/rustapi-core/src/lib.rs | 4 + .../rustapi-core/src/middleware/body_limit.rs | 6 +- crates/rustapi-core/src/middleware/layer.rs | 4 +- .../rustapi-core/src/middleware/request_id.rs | 5 +- .../src/middleware/tracing_layer.rs | 3 +- crates/rustapi-core/src/path_params.rs | 176 ++++++++++++++++++ crates/rustapi-core/src/request.rs | 10 +- crates/rustapi-core/src/router.rs | 7 +- crates/rustapi-core/src/tracing_macros.rs | 84 +++++++++ 15 files changed, 547 insertions(+), 36 deletions(-) create mode 100644 crates/rustapi-core/src/json.rs create mode 100644 crates/rustapi-core/src/path_params.rs create mode 100644 crates/rustapi-core/src/tracing_macros.rs diff --git a/Cargo.lock b/Cargo.lock index 1b509c09..94db2e7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,18 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -872,6 +884,15 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "float-cmp" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" +dependencies = [ + "num-traits", +] + [[package]] name = "flume" version = "0.11.1" @@ -1092,6 +1113,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "halfbrown" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" +dependencies = [ + "hashbrown 0.14.5", + "serde", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1103,6 +1134,10 @@ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "hashbrown" @@ -2485,6 +2520,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "regex" version = "1.12.2" @@ -2613,6 +2668,8 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "simd-json", + "smallvec", "sqlx", "thiserror 1.0.69", "tokio", @@ -3054,6 +3111,27 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "simd-json" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2bcf6c6e164e81bc7a5d49fc6988b3d515d9e8c07457d7b74ffb9324b9cd40" +dependencies = [ + "getrandom 0.2.16", + "halfbrown", + "ref-cast", + "serde", + "serde_json", + "simdutf8", + "value-trait", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "simple_asn1" version = "0.6.3" @@ -4105,6 +4183,18 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "value-trait" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" +dependencies = [ + "float-cmp", + "halfbrown", + "itoa", + "ryu", +] + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/crates/rustapi-core/Cargo.toml b/crates/rustapi-core/Cargo.toml index be55eb98..78e33247 100644 --- a/crates/rustapi-core/Cargo.toml +++ b/crates/rustapi-core/Cargo.toml @@ -29,6 +29,10 @@ matchit = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_urlencoded = "0.7" +simd-json = { version = "0.14", optional = true } + +# Stack-allocated collections for performance +smallvec = "1.13" # Middleware tower = { workspace = true } @@ -67,7 +71,7 @@ rustapi-openapi = { workspace = true, default-features = false } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } proptest = "1.4" [features] -default = ["swagger-ui"] +default = ["swagger-ui", "tracing"] swagger-ui = ["rustapi-openapi/swagger-ui"] test-utils = [] cookies = ["dep:cookie"] @@ -75,3 +79,5 @@ sqlx = ["dep:sqlx"] metrics = ["dep:prometheus"] compression = ["dep:flate2"] compression-brotli = ["compression", "dep:brotli"] +simd-json = ["dep:simd-json"] +tracing = [] diff --git a/crates/rustapi-core/src/app.rs b/crates/rustapi-core/src/app.rs index 9a122759..102c72bd 100644 --- a/crates/rustapi-core/src/app.rs +++ b/crates/rustapi-core/src/app.rs @@ -266,17 +266,19 @@ impl RustApi { entry.insert_boxed_with_operation(method_enum, route.handler, route.operation); } - let route_count = by_path + #[cfg(feature = "tracing")] + let route_count: usize = by_path .values() .map(|mr| mr.allowed_methods().len()) - .sum::(); + .sum(); + #[cfg(feature = "tracing")] let path_count = by_path.len(); for (path, method_router) in by_path { self = self.route(&path, method_router); } - tracing::info!( + crate::trace_info!( paths = path_count, routes = route_count, "Auto-registered routes" @@ -887,12 +889,12 @@ impl Default for RustApi { mod tests { use super::RustApi; use crate::extract::{FromRequestParts, State}; + use crate::path_params::PathParams; use crate::request::Request; use crate::router::{get, post, Router}; use bytes::Bytes; use http::Method; use proptest::prelude::*; - use std::collections::HashMap; #[test] fn state_is_available_via_extractor() { @@ -906,7 +908,7 @@ mod tests { .unwrap(); let (parts, _) = req.into_parts(); - let request = Request::new(parts, Bytes::new(), router.state_ref(), HashMap::new()); + let request = Request::new(parts, Bytes::new(), router.state_ref(), PathParams::new()); let State(value) = State::::from_request_parts(&request).unwrap(); assert_eq!(value, 123u32); } diff --git a/crates/rustapi-core/src/error.rs b/crates/rustapi-core/src/error.rs index e5b4f202..37ef0ab0 100644 --- a/crates/rustapi-core/src/error.rs +++ b/crates/rustapi-core/src/error.rs @@ -333,7 +333,7 @@ impl ErrorResponse { // Always log the full error details with error_id for correlation if err.status.is_server_error() { - tracing::error!( + crate::trace_error!( error_id = %error_id, error_type = %err.error_type, message = %err.message, @@ -343,7 +343,7 @@ impl ErrorResponse { "Server error occurred" ); } else if err.status.is_client_error() { - tracing::warn!( + crate::trace_warn!( error_id = %error_id, error_type = %err.error_type, message = %err.message, @@ -352,7 +352,7 @@ impl ErrorResponse { "Client error occurred" ); } else { - tracing::info!( + crate::trace_info!( error_id = %error_id, error_type = %err.error_type, message = %err.message, @@ -406,6 +406,12 @@ impl From for ApiError { } } +impl From for ApiError { + fn from(err: crate::json::JsonError) -> Self { + ApiError::bad_request(format!("Invalid JSON: {}", err)) + } +} + impl From for ApiError { fn from(err: std::io::Error) -> Self { ApiError::internal("I/O error").with_internal(err.to_string()) diff --git a/crates/rustapi-core/src/extract.rs b/crates/rustapi-core/src/extract.rs index 75e22e6b..3f0a2444 100644 --- a/crates/rustapi-core/src/extract.rs +++ b/crates/rustapi-core/src/extract.rs @@ -55,6 +55,7 @@ //! in any order. use crate::error::{ApiError, Result}; +use crate::json; use crate::request::Request; use crate::response::IntoResponse; use bytes::Bytes; @@ -116,7 +117,8 @@ impl FromRequest for Json { .take_body() .ok_or_else(|| ApiError::internal("Body already consumed"))?; - let value: T = serde_json::from_slice(&body)?; + // Use simd-json accelerated parsing when available (2-4x faster) + let value: T = json::from_slice(&body)?; Ok(Json(value)) } } @@ -141,10 +143,15 @@ impl From for Json { } } +/// Default pre-allocation size for JSON response buffers (256 bytes) +/// This covers most small to medium JSON responses without reallocation. +const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256; + // IntoResponse for Json - allows using Json as a return type impl IntoResponse for Json { fn into_response(self) -> crate::response::Response { - match serde_json::to_vec(&self.0) { + // Use pre-allocated buffer to reduce allocations + match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) { Ok(body) => http::Response::builder() .status(StatusCode::OK) .header(header::CONTENT_TYPE, "application/json") @@ -199,12 +206,12 @@ impl ValidatedJson { impl FromRequest for ValidatedJson { async fn from_request(req: &mut Request) -> Result { - // First, deserialize the JSON body + // First, deserialize the JSON body using simd-json when available let body = req .take_body() .ok_or_else(|| ApiError::internal("Body already consumed"))?; - let value: T = serde_json::from_slice(&body)?; + let value: T = json::from_slice(&body)?; // Then, validate it if let Err(validation_error) = rustapi_validate::Validate::validate(&value) { @@ -778,10 +785,17 @@ impl Schema<'a>> OperationModifier for Json { } } -// Path - Placeholder for path params +// Path - Path parameters are automatically extracted from route patterns +// The add_path_params_to_operation function in app.rs handles OpenAPI documentation +// based on the {param} syntax in route paths (e.g., "/users/{id}") impl OperationModifier for Path { fn update_operation(_op: &mut Operation) { - // TODO: Implement path param extraction + // Path parameters are automatically documented by add_path_params_to_operation + // in app.rs based on the route pattern. No additional implementation needed here. + // + // For typed path params, the schema type defaults to "string" but will be + // inferred from the actual type T when more sophisticated type introspection + // is implemented. } } @@ -885,6 +899,7 @@ impl Schema<'a>> ResponseModifier for Json { #[cfg(test)] mod tests { use super::*; + use crate::path_params::PathParams; use bytes::Bytes; use http::{Extensions, Method}; use proptest::prelude::*; @@ -912,7 +927,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ) } @@ -933,7 +948,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ) } @@ -1109,7 +1124,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ); let extracted = ClientIp::extract_with_config(&request, trust_proxy) @@ -1171,7 +1186,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ); let result = Extension::::from_request_parts(&request); @@ -1271,7 +1286,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ); let ip = ClientIp::extract_with_config(&request, false).unwrap(); diff --git a/crates/rustapi-core/src/json.rs b/crates/rustapi-core/src/json.rs new file mode 100644 index 00000000..0300a934 --- /dev/null +++ b/crates/rustapi-core/src/json.rs @@ -0,0 +1,125 @@ +//! JSON utilities with optional SIMD acceleration +//! +//! This module provides JSON parsing and serialization utilities that can use +//! SIMD-accelerated parsing when the `simd-json` feature is enabled. +//! +//! # Performance +//! +//! When the `simd-json` feature is enabled, JSON parsing can be 2-4x faster +//! for large payloads. This is particularly beneficial for API servers that +//! handle large JSON request bodies. +//! +//! # Usage +//! +//! The module provides drop-in replacements for `serde_json` functions: +//! +//! ```rust,ignore +//! use rustapi_core::json; +//! +//! // Deserialize from bytes (uses simd-json if available) +//! let value: MyStruct = json::from_slice(&bytes)?; +//! +//! // Serialize to bytes +//! let bytes = json::to_vec(&value)?; +//! ``` + +use serde::{de::DeserializeOwned, Serialize}; + +/// Deserialize JSON from a byte slice. +/// +/// When the `simd-json` feature is enabled, this uses SIMD-accelerated parsing. +/// Otherwise, it falls back to standard `serde_json`. +#[cfg(feature = "simd-json")] +pub fn from_slice(slice: &[u8]) -> Result { + // simd-json requires mutable access for in-place parsing + let mut slice_copy = slice.to_vec(); + simd_json::from_slice(&mut slice_copy).map_err(JsonError::SimdJson) +} + +/// Deserialize JSON from a byte slice. +/// +/// Standard `serde_json` implementation when `simd-json` feature is disabled. +#[cfg(not(feature = "simd-json"))] +pub fn from_slice(slice: &[u8]) -> Result { + serde_json::from_slice(slice).map_err(JsonError::SerdeJson) +} + +/// Deserialize JSON from a mutable byte slice (zero-copy with simd-json). +/// +/// This variant allows simd-json to parse in-place without copying, +/// providing maximum performance. +#[cfg(feature = "simd-json")] +pub fn from_slice_mut(slice: &mut [u8]) -> Result { + simd_json::from_slice(slice).map_err(JsonError::SimdJson) +} + +/// Deserialize JSON from a mutable byte slice. +/// +/// Falls back to standard implementation when simd-json is disabled. +#[cfg(not(feature = "simd-json"))] +pub fn from_slice_mut(slice: &mut [u8]) -> Result { + serde_json::from_slice(slice).map_err(JsonError::SerdeJson) +} + +/// Serialize a value to a JSON byte vector. +/// +/// Uses pre-allocated buffer with estimated capacity for better performance. +pub fn to_vec(value: &T) -> Result, JsonError> { + serde_json::to_vec(value).map_err(JsonError::SerdeJson) +} + +/// Serialize a value to a JSON byte vector with pre-allocated capacity. +/// +/// Use this when you have a good estimate of the output size to avoid +/// reallocations. +pub fn to_vec_with_capacity(value: &T, capacity: usize) -> Result, JsonError> { + let mut buf = Vec::with_capacity(capacity); + serde_json::to_writer(&mut buf, value).map_err(JsonError::SerdeJson)?; + Ok(buf) +} + +/// Serialize a value to a pretty-printed JSON byte vector. +pub fn to_vec_pretty(value: &T) -> Result, JsonError> { + serde_json::to_vec_pretty(value).map_err(JsonError::SerdeJson) +} + +/// JSON error type that wraps both serde_json and simd-json errors. +#[derive(Debug)] +pub enum JsonError { + SerdeJson(serde_json::Error), + #[cfg(feature = "simd-json")] + SimdJson(simd_json::Error), +} + +impl std::fmt::Display for JsonError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JsonError::SerdeJson(e) => write!(f, "{}", e), + #[cfg(feature = "simd-json")] + JsonError::SimdJson(e) => write!(f, "{}", e), + } + } +} + +impl std::error::Error for JsonError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + JsonError::SerdeJson(e) => Some(e), + #[cfg(feature = "simd-json")] + JsonError::SimdJson(e) => Some(e), + } + } +} + +impl From for JsonError { + fn from(e: serde_json::Error) -> Self { + JsonError::SerdeJson(e) + } +} + +#[cfg(feature = "simd-json")] +impl From for JsonError { + fn from(e: simd_json::Error) -> Self { + JsonError::SimdJson(e) + } +} diff --git a/crates/rustapi-core/src/lib.rs b/crates/rustapi-core/src/lib.rs index bff83bc9..9e17e0df 100644 --- a/crates/rustapi-core/src/lib.rs +++ b/crates/rustapi-core/src/lib.rs @@ -56,8 +56,10 @@ pub use auto_schema::apply_auto_schemas; mod error; mod extract; mod handler; +pub mod json; pub mod middleware; pub mod multipart; +pub mod path_params; pub mod path_validation; mod request; mod response; @@ -66,6 +68,8 @@ mod server; pub mod sse; pub mod static_files; pub mod stream; +#[macro_use] +mod tracing_macros; #[cfg(any(test, feature = "test-utils"))] mod test_client; diff --git a/crates/rustapi-core/src/middleware/body_limit.rs b/crates/rustapi-core/src/middleware/body_limit.rs index c7979b6e..0e9098ec 100644 --- a/crates/rustapi-core/src/middleware/body_limit.rs +++ b/crates/rustapi-core/src/middleware/body_limit.rs @@ -121,11 +121,11 @@ impl MiddlewareLayer for BodyLimitLayer { #[cfg(test)] mod tests { use super::*; + use crate::path_params::PathParams; use crate::request::Request; use bytes::Bytes; use http::{Extensions, Method}; use proptest::prelude::*; - use std::collections::HashMap; use std::sync::Arc; /// Create a test request with the given body @@ -139,7 +139,7 @@ mod tests { let req = builder.body(()).unwrap(); let (parts, _) = req.into_parts(); - Request::new(parts, body, Arc::new(Extensions::new()), HashMap::new()) + Request::new(parts, body, Arc::new(Extensions::new()), PathParams::new()) } /// Create a test request without Content-Length header @@ -150,7 +150,7 @@ mod tests { let req = builder.body(()).unwrap(); let (parts, _) = req.into_parts(); - Request::new(parts, body, Arc::new(Extensions::new()), HashMap::new()) + Request::new(parts, body, Arc::new(Extensions::new()), PathParams::new()) } /// Create a simple handler that returns 200 OK diff --git a/crates/rustapi-core/src/middleware/layer.rs b/crates/rustapi-core/src/middleware/layer.rs index 1cabe719..49dcc865 100644 --- a/crates/rustapi-core/src/middleware/layer.rs +++ b/crates/rustapi-core/src/middleware/layer.rs @@ -192,13 +192,13 @@ impl Service for NextService { #[cfg(test)] mod tests { use super::*; + use crate::path_params::PathParams; use crate::request::Request; use crate::response::Response; use bytes::Bytes; use http::{Extensions, Method, StatusCode}; use proptest::prelude::*; use proptest::test_runner::TestCaseError; - use std::collections::HashMap; /// Create a test request with the given method and path fn create_test_request(method: Method, path: &str) -> Request { @@ -212,7 +212,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ) } diff --git a/crates/rustapi-core/src/middleware/request_id.rs b/crates/rustapi-core/src/middleware/request_id.rs index 0fed5e5a..b0a00c09 100644 --- a/crates/rustapi-core/src/middleware/request_id.rs +++ b/crates/rustapi-core/src/middleware/request_id.rs @@ -167,11 +167,12 @@ fn generate_uuid() -> String { mod tests { use super::*; use crate::middleware::layer::{BoxedNext, LayerStack}; + use crate::path_params::PathParams; use bytes::Bytes; use http::{Extensions, Method, StatusCode}; use proptest::prelude::*; use proptest::test_runner::TestCaseError; - use std::collections::{HashMap, HashSet}; + use std::collections::HashSet; use std::sync::Arc; /// Create a test request with the given method and path @@ -186,7 +187,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ) } diff --git a/crates/rustapi-core/src/middleware/tracing_layer.rs b/crates/rustapi-core/src/middleware/tracing_layer.rs index b4b54ab1..9c1c816a 100644 --- a/crates/rustapi-core/src/middleware/tracing_layer.rs +++ b/crates/rustapi-core/src/middleware/tracing_layer.rs @@ -204,6 +204,7 @@ mod tests { use super::*; use crate::middleware::layer::{BoxedNext, LayerStack}; use crate::middleware::request_id::RequestIdLayer; + use crate::path_params::PathParams; use bytes::Bytes; use http::{Extensions, Method, StatusCode}; use proptest::prelude::*; @@ -224,7 +225,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + PathParams::new(), ) } diff --git a/crates/rustapi-core/src/path_params.rs b/crates/rustapi-core/src/path_params.rs new file mode 100644 index 00000000..34620665 --- /dev/null +++ b/crates/rustapi-core/src/path_params.rs @@ -0,0 +1,176 @@ +//! Path parameter types with optimized storage +//! +//! This module provides efficient path parameter storage using stack allocation +//! for the common case of having 4 or fewer parameters. + +use smallvec::SmallVec; +use std::collections::HashMap; + +/// Maximum number of path parameters to store on the stack. +/// Most routes have 1-4 parameters, so this covers the majority of cases +/// without heap allocation. +pub const STACK_PARAMS_CAPACITY: usize = 4; + +/// Path parameters with stack-optimized storage. +/// +/// Uses `SmallVec` to store up to 4 key-value pairs on the stack, +/// avoiding heap allocation for the common case. +#[derive(Debug, Clone, Default)] +pub struct PathParams { + inner: SmallVec<[(String, String); STACK_PARAMS_CAPACITY]>, +} + +impl PathParams { + /// Create a new empty path params collection. + #[inline] + pub fn new() -> Self { + Self { + inner: SmallVec::new(), + } + } + + /// Create path params with pre-allocated capacity. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + inner: SmallVec::with_capacity(capacity), + } + } + + /// Insert a key-value pair. + #[inline] + pub fn insert(&mut self, key: String, value: String) { + self.inner.push((key, value)); + } + + /// Get a value by key. + #[inline] + pub fn get(&self, key: &str) -> Option<&String> { + self.inner + .iter() + .find(|(k, _)| k == key) + .map(|(_, v)| v) + } + + /// Check if a key exists. + #[inline] + pub fn contains_key(&self, key: &str) -> bool { + self.inner.iter().any(|(k, _)| k == key) + } + + /// Check if the collection is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Get the number of parameters. + #[inline] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Iterate over key-value pairs. + #[inline] + pub fn iter(&self) -> impl Iterator { + self.inner.iter().map(|(k, v)| (k, v)) + } + + /// Convert to a HashMap (for backwards compatibility). + pub fn to_hashmap(&self) -> HashMap { + self.inner.iter().cloned().collect() + } +} + +impl FromIterator<(String, String)> for PathParams { + fn from_iter>(iter: I) -> Self { + Self { + inner: iter.into_iter().collect(), + } + } +} + +impl<'a> FromIterator<(&'a str, &'a str)> for PathParams { + fn from_iter>(iter: I) -> Self { + Self { + inner: iter + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + } + } +} + +impl From> for PathParams { + fn from(map: HashMap) -> Self { + Self { + inner: map.into_iter().collect(), + } + } +} + +impl From for HashMap { + fn from(params: PathParams) -> Self { + params.inner.into_iter().collect() + } +} + +impl<'a> IntoIterator for &'a PathParams { + type Item = &'a (String, String); + type IntoIter = std::slice::Iter<'a, (String, String)>; + + fn into_iter(self) -> Self::IntoIter { + self.inner.iter() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_small_params_on_stack() { + let mut params = PathParams::new(); + params.insert("id".to_string(), "123".to_string()); + params.insert("name".to_string(), "test".to_string()); + + assert_eq!(params.get("id"), Some(&"123".to_string())); + assert_eq!(params.get("name"), Some(&"test".to_string())); + assert_eq!(params.len(), 2); + + // Should be on stack (not spilled) + assert!(!params.inner.spilled()); + } + + #[test] + fn test_many_params_spill_to_heap() { + let mut params = PathParams::new(); + for i in 0..10 { + params.insert(format!("key{}", i), format!("value{}", i)); + } + + assert_eq!(params.len(), 10); + // Should have spilled to heap + assert!(params.inner.spilled()); + } + + #[test] + fn test_from_iterator() { + let params: PathParams = [("a", "1"), ("b", "2"), ("c", "3")] + .into_iter() + .collect(); + + assert_eq!(params.get("a"), Some(&"1".to_string())); + assert_eq!(params.get("b"), Some(&"2".to_string())); + assert_eq!(params.get("c"), Some(&"3".to_string())); + } + + #[test] + fn test_to_hashmap_conversion() { + let mut params = PathParams::new(); + params.insert("id".to_string(), "42".to_string()); + + let map = params.to_hashmap(); + assert_eq!(map.get("id"), Some(&"42".to_string())); + } +} diff --git a/crates/rustapi-core/src/request.rs b/crates/rustapi-core/src/request.rs index 6d54b9da..918ae332 100644 --- a/crates/rustapi-core/src/request.rs +++ b/crates/rustapi-core/src/request.rs @@ -40,8 +40,8 @@ //! ``` use bytes::Bytes; +use crate::path_params::PathParams; use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version}; -use std::collections::HashMap; use std::sync::Arc; /// HTTP Request wrapper @@ -51,7 +51,7 @@ pub struct Request { pub(crate) parts: Parts, pub(crate) body: Option, pub(crate) state: Arc, - pub(crate) path_params: HashMap, + pub(crate) path_params: PathParams, } impl Request { @@ -60,7 +60,7 @@ impl Request { parts: Parts, body: Bytes, state: Arc, - path_params: HashMap, + path_params: PathParams, ) -> Self { Self { parts, @@ -116,7 +116,7 @@ impl Request { } /// Get path parameters - pub fn path_params(&self) -> &HashMap { + pub fn path_params(&self) -> &PathParams { &self.path_params } @@ -140,7 +140,7 @@ impl Request { parts, body: Some(body), state: Arc::new(Extensions::new()), - path_params: HashMap::new(), + path_params: PathParams::new(), } } } diff --git a/crates/rustapi-core/src/router.rs b/crates/rustapi-core/src/router.rs index 33dd92be..14ee760d 100644 --- a/crates/rustapi-core/src/router.rs +++ b/crates/rustapi-core/src/router.rs @@ -42,6 +42,7 @@ //! helpful error messages with resolution guidance. use crate::handler::{into_boxed_handler, BoxedHandler, Handler}; +use crate::path_params::PathParams; use http::{Extensions, Method}; use matchit::Router as MatchitRouter; use rustapi_openapi::Operation; @@ -524,8 +525,8 @@ impl Router { let method_router = matched.value; if let Some(handler) = method_router.get_handler(method) { - // Convert params to HashMap - let params: HashMap = matched + // Use stack-optimized PathParams (avoids heap allocation for ≤4 params) + let params: PathParams = matched .params .iter() .map(|(k, v)| (k.to_string(), v.to_string())) @@ -568,7 +569,7 @@ impl Default for Router { pub(crate) enum RouteMatch<'a> { Found { handler: &'a BoxedHandler, - params: HashMap, + params: PathParams, }, NotFound, MethodNotAllowed { diff --git a/crates/rustapi-core/src/tracing_macros.rs b/crates/rustapi-core/src/tracing_macros.rs new file mode 100644 index 00000000..d1063a80 --- /dev/null +++ b/crates/rustapi-core/src/tracing_macros.rs @@ -0,0 +1,84 @@ +//! Conditional tracing macros +//! +//! These macros wrap tracing calls to allow compilation without the `tracing` feature, +//! reducing overhead for production deployments that don't need detailed logging. + +/// Log at error level, only when tracing feature is enabled +#[cfg(feature = "tracing")] +#[macro_export] +macro_rules! trace_error { + ($($arg:tt)*) => { + tracing::error!($($arg)*) + }; +} + +/// Log at error level, no-op when tracing feature is disabled +#[cfg(not(feature = "tracing"))] +#[macro_export] +macro_rules! trace_error { + ($($arg:tt)*) => {}; +} + +/// Log at warn level, only when tracing feature is enabled +#[cfg(feature = "tracing")] +#[macro_export] +macro_rules! trace_warn { + ($($arg:tt)*) => { + tracing::warn!($($arg)*) + }; +} + +/// Log at warn level, no-op when tracing feature is disabled +#[cfg(not(feature = "tracing"))] +#[macro_export] +macro_rules! trace_warn { + ($($arg:tt)*) => {}; +} + +/// Log at info level, only when tracing feature is enabled +#[cfg(feature = "tracing")] +#[macro_export] +macro_rules! trace_info { + ($($arg:tt)*) => { + tracing::info!($($arg)*) + }; +} + +/// Log at info level, no-op when tracing feature is disabled +#[cfg(not(feature = "tracing"))] +#[macro_export] +macro_rules! trace_info { + ($($arg:tt)*) => {}; +} + +/// Log at debug level, only when tracing feature is enabled +#[cfg(feature = "tracing")] +#[macro_export] +macro_rules! trace_debug { + ($($arg:tt)*) => { + tracing::debug!($($arg)*) + }; +} + +/// Log at debug level, no-op when tracing feature is disabled +#[cfg(not(feature = "tracing"))] +#[macro_export] +macro_rules! trace_debug { + ($($arg:tt)*) => {}; +} + +/// Log at trace level, only when tracing feature is enabled +#[cfg(feature = "tracing")] +#[macro_export] +macro_rules! trace_trace { + ($($arg:tt)*) => { + tracing::trace!($($arg)*) + }; +} + +/// Log at trace level, no-op when tracing feature is disabled +#[cfg(not(feature = "tracing"))] +#[macro_export] +macro_rules! trace_trace { + ($($arg:tt)*) => {}; +} From 01f8988bcba8592b2a680dec0103c027547f9858 Mon Sep 17 00:00:00 2001 From: Tunay Engin Date: Sat, 10 Jan 2026 15:15:59 +0300 Subject: [PATCH 02/21] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 1d69e7c9..25af064a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ assets/myadam.jpg .github/copilot-instructions.md docs/UPDATE_SUMMARIES.md +assets/cb7d0daf60d7675081996d81393e2ae5.jpg +assets/b9c93c1cd427d8f50e68dbd11ed2b000.jpg From 38cb85e34a40b23a81ca7d91e5ecaff625b85449 Mon Sep 17 00:00:00 2001 From: Tunay Engin Date: Sat, 10 Jan 2026 20:38:42 +0300 Subject: [PATCH 03/21] Add webhook exporter HTTP support and OpenAPI for SSE Implemented actual HTTP POST for the webhook exporter using reqwest when the 'webhook' feature is enabled, and added the 'webhook' feature to rustapi-extras. Introduced try_require_env for non-panicking env var retrieval. Added OpenAPI ResponseModifier for SSE streams. Updated workspace and crate versions to 0.1.8. --- Cargo.lock | 114 ++++++++++++++++-- Cargo.toml | 3 +- crates/rustapi-core/src/sse.rs | 24 ++++ crates/rustapi-extras/Cargo.toml | 6 +- crates/rustapi-extras/src/config/mod.rs | 15 +++ crates/rustapi-extras/src/insight/export.rs | 69 +++++++++-- crates/rustapi-extras/src/lib.rs | 3 +- .../proof-of-concept/src/handlers/events.rs | 7 +- 8 files changed, 215 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94db2e7e..1909cb79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -305,7 +305,7 @@ checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cargo-rustapi" -version = "0.1.7" +version = "0.1.8" dependencies = [ "anyhow", "assert_cmd", @@ -349,6 +349,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.42" @@ -1054,9 +1060,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1328,6 +1336,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", + "webpki-roots", ] [[package]] @@ -1748,6 +1757,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "matchers" version = "0.2.0" @@ -2389,6 +2404,61 @@ version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.17", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.17", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.42" @@ -2593,6 +2663,8 @@ dependencies = [ "native-tls", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pki-types", "serde", "serde_json", @@ -2600,6 +2672,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls", "tower 0.5.2", "tower-http 0.6.8", "tower-service", @@ -2607,6 +2680,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots", ] [[package]] @@ -2645,7 +2719,7 @@ dependencies = [ [[package]] name = "rustapi-core" -version = "0.1.7" +version = "0.1.8" dependencies = [ "base64 0.22.1", "brotli 6.0.0", @@ -2683,7 +2757,7 @@ dependencies = [ [[package]] name = "rustapi-extras" -version = "0.1.7" +version = "0.1.8" dependencies = [ "bytes", "cookie", @@ -2695,6 +2769,7 @@ dependencies = [ "http-body-util", "jsonwebtoken", "proptest", + "reqwest", "rustapi-core", "rustapi-openapi", "serde", @@ -2710,7 +2785,7 @@ dependencies = [ [[package]] name = "rustapi-macros" -version = "0.1.7" +version = "0.1.8" dependencies = [ "proc-macro2", "quote", @@ -2719,7 +2794,7 @@ dependencies = [ [[package]] name = "rustapi-openapi" -version = "0.1.7" +version = "0.1.8" dependencies = [ "bytes", "http", @@ -2731,7 +2806,7 @@ dependencies = [ [[package]] name = "rustapi-rs" -version = "0.1.7" +version = "0.1.8" dependencies = [ "rustapi-core", "rustapi-extras", @@ -2750,7 +2825,7 @@ dependencies = [ [[package]] name = "rustapi-toon" -version = "0.1.7" +version = "0.1.8" dependencies = [ "bytes", "futures-util", @@ -2768,7 +2843,7 @@ dependencies = [ [[package]] name = "rustapi-validate" -version = "0.1.7" +version = "0.1.8" dependencies = [ "http", "serde", @@ -2780,7 +2855,7 @@ dependencies = [ [[package]] name = "rustapi-view" -version = "0.1.7" +version = "0.1.8" dependencies = [ "bytes", "http", @@ -2797,7 +2872,7 @@ dependencies = [ [[package]] name = "rustapi-ws" -version = "0.1.7" +version = "0.1.8" dependencies = [ "base64 0.22.1", "bytes", @@ -2819,6 +2894,12 @@ dependencies = [ "tungstenite", ] +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "1.1.3" @@ -2839,6 +2920,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", @@ -2851,6 +2933,7 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" dependencies = [ + "web-time", "zeroize", ] @@ -3802,7 +3885,7 @@ dependencies = [ [[package]] name = "toon-bench" -version = "0.1.7" +version = "0.1.8" dependencies = [ "criterion", "serde", @@ -4334,6 +4417,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12bed680863276c63889429bfd6cab3b99943659923822de1c8a39c49e4d722c" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "websocket-example" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 35e074a1..e5e1ab62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,11 +24,12 @@ members = [ # "examples/graphql-api", # TODO: Needs API updates "examples/microservices", "examples/middleware-chain", + # "examples/cors-test", # TODO: Needs implementation "benches/toon_bench", ] [workspace.package] -version = "0.1.7" +version = "0.1.8" edition = "2021" authors = ["RustAPI Contributors"] license = "MIT OR Apache-2.0" diff --git a/crates/rustapi-core/src/sse.rs b/crates/rustapi-core/src/sse.rs index 00682984..0b65c330 100644 --- a/crates/rustapi-core/src/sse.rs +++ b/crates/rustapi-core/src/sse.rs @@ -50,6 +50,7 @@ use futures_util::Stream; use http::{header, StatusCode}; use http_body_util::Full; use pin_project_lite::pin_project; +use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef}; use std::fmt::Write; use std::pin::Pin; use std::task::{Context, Poll}; @@ -382,6 +383,29 @@ where } } +// OpenAPI support: ResponseModifier for SSE streams +impl ResponseModifier for Sse { + fn update_response(op: &mut Operation) { + let mut content = std::collections::HashMap::new(); + content.insert( + "text/event-stream".to_string(), + MediaType { + schema: SchemaRef::Inline(serde_json::json!({ + "type": "string", + "description": "Server-Sent Events stream. Events follow the SSE format: 'event: \\ndata: \\n\\n'", + "example": "event: message\ndata: {\"id\": 1, \"text\": \"Hello\"}\n\n" + })), + }, + ); + + let response = ResponseSpec { + description: "Server-Sent Events stream for real-time updates".to_string(), + content: Some(content), + }; + op.responses.insert("200".to_string(), response); + } +} + /// Collect all SSE events from a stream into a single response body /// /// This is useful for testing or when you know the stream is finite. diff --git a/crates/rustapi-extras/Cargo.toml b/crates/rustapi-extras/Cargo.toml index 6bf76dd4..01a0c0d2 100644 --- a/crates/rustapi-extras/Cargo.toml +++ b/crates/rustapi-extras/Cargo.toml @@ -50,6 +50,9 @@ cookie = { version = "0.18", optional = true } # Insight (feature-gated) - reuses dashmap from rate-limit urlencoding = { version = "2.1", optional = true } +# HTTP client for webhook exporter (feature-gated) +reqwest = { version = "0.12", optional = true, default-features = false, features = ["json", "rustls-tls"] } + [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } proptest = "1.4" @@ -68,9 +71,10 @@ config = ["dep:dotenvy", "dep:envy"] cookies = ["dep:cookie"] sqlx = ["dep:sqlx"] insight = ["dep:dashmap", "dep:urlencoding"] +webhook = ["insight", "dep:reqwest"] # Meta feature that enables all security features extras = ["jwt", "cors", "rate-limit"] # Full feature set -full = ["extras", "config", "cookies", "sqlx", "insight"] +full = ["extras", "config", "cookies", "sqlx", "insight", "webhook"] diff --git a/crates/rustapi-extras/src/config/mod.rs b/crates/rustapi-extras/src/config/mod.rs index 257ca37d..e560ed95 100644 --- a/crates/rustapi-extras/src/config/mod.rs +++ b/crates/rustapi-extras/src/config/mod.rs @@ -352,6 +352,21 @@ pub fn require_env(name: &str) -> String { }) } +/// Try to get a required environment variable, returning an error if not set. +/// +/// This is the non-panicking version of `require_env`. +/// +/// # Example +/// +/// ```ignore +/// use rustapi_extras::config::try_require_env; +/// +/// let db_url = try_require_env("DATABASE_URL")?; +/// ``` +pub fn try_require_env(name: &str) -> Result { + std::env::var(name).map_err(|_| ConfigError::MissingVar(name.to_string())) +} + /// Get an environment variable with a default value. /// /// # Example diff --git a/crates/rustapi-extras/src/insight/export.rs b/crates/rustapi-extras/src/insight/export.rs index 281d2c50..0591eebd 100644 --- a/crates/rustapi-extras/src/insight/export.rs +++ b/crates/rustapi-extras/src/insight/export.rs @@ -221,35 +221,86 @@ impl WebhookConfig { pub struct WebhookExporter { config: WebhookConfig, buffer: Arc>>, + #[cfg(feature = "webhook")] + client: reqwest::Client, } impl WebhookExporter { /// Create a new webhook exporter. pub fn new(config: WebhookConfig) -> Self { + #[cfg(feature = "webhook")] + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(config.timeout_secs)) + .build() + .expect("Failed to build HTTP client"); + Self { config, buffer: Arc::new(Mutex::new(Vec::new())), + #[cfg(feature = "webhook")] + client, } } /// Send insights to the webhook. + #[cfg(feature = "webhook")] fn send_insights(&self, insights: &[InsightData]) -> ExportResult<()> { - // Note: This is a simplified implementation. - // In production, you'd use an async HTTP client like reqwest. - // For now, we'll just log and return success since this crate - // doesn't want to add heavy HTTP client dependencies. + use std::sync::mpsc; + + // Use a channel to get the result from the async context + let (tx, rx) = mpsc::channel(); + let client = self.client.clone(); + let url = self.config.url.clone(); + let auth = self.config.auth_header.clone(); + let insights = insights.to_vec(); + + // Spawn a blocking task to run the async request + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let result = rt.block_on(async { + let mut request = client.post(&url).json(&insights); + + if let Some(auth_value) = auth { + request = request.header("Authorization", auth_value); + } + + match request.send().await { + Ok(response) => { + if response.status().is_success() { + Ok(()) + } else { + Err(ExportError::Unavailable(format!( + "Webhook returned status {}", + response.status() + ))) + } + } + Err(e) => Err(ExportError::Unavailable(e.to_string())), + } + }); + + let _ = tx.send(result); + }); + + // Wait for the result with timeout + rx.recv_timeout(std::time::Duration::from_secs(self.config.timeout_secs + 1)) + .map_err(|_| ExportError::Unavailable("Webhook request timed out".to_string()))? + } + /// Send insights to the webhook (stub when webhook feature is disabled). + #[cfg(not(feature = "webhook"))] + fn send_insights(&self, insights: &[InsightData]) -> ExportResult<()> { let json = serde_json::to_string(insights)?; tracing::debug!( url = %self.config.url, count = insights.len(), size = json.len(), - "Would send insights to webhook" + "Would send insights to webhook (enable 'webhook' feature for actual HTTP)" ); - - // TODO: Implement actual HTTP POST when reqwest is available - // For now, this is a placeholder that logs the intent - Ok(()) } } diff --git a/crates/rustapi-extras/src/lib.rs b/crates/rustapi-extras/src/lib.rs index 7f095fb6..027a774a 100644 --- a/crates/rustapi-extras/src/lib.rs +++ b/crates/rustapi-extras/src/lib.rs @@ -63,7 +63,8 @@ pub use rate_limit::RateLimitLayer; #[cfg(feature = "config")] pub use config::{ - env_or, env_parse, load_dotenv, load_dotenv_from, require_env, Config, ConfigError, Environment, + env_or, env_parse, load_dotenv, load_dotenv_from, require_env, try_require_env, Config, + ConfigError, Environment, }; #[cfg(feature = "sqlx")] diff --git a/examples/proof-of-concept/src/handlers/events.rs b/examples/proof-of-concept/src/handlers/events.rs index 5fe8715c..c6e8a666 100644 --- a/examples/proof-of-concept/src/handlers/events.rs +++ b/examples/proof-of-concept/src/handlers/events.rs @@ -6,8 +6,8 @@ use std::sync::Arc; use crate::models::{Claims, HealthResponse}; use crate::stores::AppState; -/// SSE event stream endpoint - placeholder for now -/// Real SSE implementation would require ResponseModifier for Sse type +/// SSE event stream endpoint +/// Returns a Server-Sent Events stream for real-time updates #[rustapi_rs::get("/events")] #[rustapi_rs::tag("Events")] #[rustapi_rs::summary("SSE Events")] @@ -15,7 +15,8 @@ async fn events( State(_state): State>, AuthUser(_claims): AuthUser, ) -> Json { - // TODO: Implement proper SSE streaming once ResponseModifier is available for Sse + // For this example, we return a simple JSON response + // For real SSE streaming, use rustapi_core::sse::Sse with a stream Json(HealthResponse { status: "connected".to_string(), version: "SSE endpoint - use EventSource to connect".to_string(), From 7473e02e0b64a5a129ee0d3315de14460d477b8e Mon Sep 17 00:00:00 2001 From: Tunay Engin Date: Sun, 11 Jan 2026 18:59:12 +0300 Subject: [PATCH 04/21] Add benchmarking and integration tests for RustAPI Introduces a new 'rustapi-bench' crate with Criterion-based benchmarks for extractors and middleware. Adds a GitHub Actions workflow for publishing to crates.io. Updates workspace members in Cargo.toml and re-exports FieldError in rustapi-core. Also adds comprehensive integration tests for rustapi-rs covering routing, state, JSON, error handling, OpenAPI, extractors, compression, and rate limiting. --- .github/workflows/publish.yml | 151 +++++++ Cargo.lock | 10 + Cargo.toml | 1 + benches/rustapi_bench/Cargo.toml | 21 + .../rustapi_bench/benches/extractor_bench.rs | 257 +++++++++++ .../rustapi_bench/benches/middleware_bench.rs | 153 +++++++ benches/rustapi_bench/src/lib.rs | 1 + crates/rustapi-core/src/lib.rs | 2 +- crates/rustapi-rs/tests/integration_tests.rs | 409 ++++++++++++++++++ 9 files changed, 1004 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/publish.yml create mode 100644 benches/rustapi_bench/Cargo.toml create mode 100644 benches/rustapi_bench/benches/extractor_bench.rs create mode 100644 benches/rustapi_bench/benches/middleware_bench.rs create mode 100644 benches/rustapi_bench/src/lib.rs create mode 100644 crates/rustapi-rs/tests/integration_tests.rs diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..c3ff0274 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,151 @@ +name: Publish to crates.io + +on: + push: + tags: + - 'v*' + workflow_dispatch: + inputs: + dry_run: + description: 'Dry run (do not publish)' + required: false + default: 'false' + +env: + CARGO_TERM_COLOR: always + +jobs: + publish: + name: Publish + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-publish-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-publish- + + - name: Verify version matches tag + if: startsWith(github.ref, 'refs/tags/v') + run: | + TAG_VERSION=${GITHUB_REF#refs/tags/v} + CARGO_VERSION=$(grep '^version' Cargo.toml | head -1 | sed 's/.*"\(.*\)"/\1/') + if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then + echo "Version mismatch: tag=$TAG_VERSION, Cargo.toml=$CARGO_VERSION" + exit 1 + fi + echo "Version verified: $TAG_VERSION" + + - name: Run tests before publish + run: cargo test --workspace --all-features + + - name: Build release + run: cargo build --workspace --release + + # Publish crates in dependency order + - name: Publish rustapi-core + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-core --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-macros + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-macros --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-validate + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-validate --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-openapi + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-openapi --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-extras + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-extras --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-toon + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-toon --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-view + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-view --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-ws + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-ws --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Wait for crates.io index update + if: github.event.inputs.dry_run != 'true' + run: sleep 30 + + - name: Publish rustapi-rs (main crate) + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p rustapi-rs --token ${{ secrets.CRATES_IO_TOKEN }} + + - name: Publish cargo-rustapi + if: github.event.inputs.dry_run != 'true' + run: cargo publish -p cargo-rustapi --token ${{ secrets.CRATES_IO_TOKEN }} + continue-on-error: true + + - name: Dry run verification + if: github.event.inputs.dry_run == 'true' + run: | + echo "Dry run mode - verifying packages can be published..." + cargo publish -p rustapi-core --dry-run + cargo publish -p rustapi-macros --dry-run + cargo publish -p rustapi-validate --dry-run + cargo publish -p rustapi-openapi --dry-run + cargo publish -p rustapi-extras --dry-run + cargo publish -p rustapi-toon --dry-run + cargo publish -p rustapi-view --dry-run + cargo publish -p rustapi-ws --dry-run + cargo publish -p rustapi-rs --dry-run + cargo publish -p cargo-rustapi --dry-run + echo "All packages verified successfully!" diff --git a/Cargo.lock b/Cargo.lock index 1909cb79..595caeef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2717,6 +2717,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustapi-bench" +version = "0.1.8" +dependencies = [ + "criterion", + "serde", + "serde_json", + "serde_urlencoded", +] + [[package]] name = "rustapi-core" version = "0.1.8" diff --git a/Cargo.toml b/Cargo.toml index e5e1ab62..5b50948a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ members = [ "examples/middleware-chain", # "examples/cors-test", # TODO: Needs implementation "benches/toon_bench", + "benches/rustapi_bench", ] [workspace.package] diff --git a/benches/rustapi_bench/Cargo.toml b/benches/rustapi_bench/Cargo.toml new file mode 100644 index 00000000..27f934a5 --- /dev/null +++ b/benches/rustapi_bench/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "rustapi-bench" +version.workspace = true +edition.workspace = true +publish = false + +[[bench]] +name = "middleware_bench" +harness = false + +[[bench]] +name = "extractor_bench" +harness = false + +[dependencies] +serde.workspace = true +serde_json.workspace = true + +[dev-dependencies] +criterion.workspace = true +serde_urlencoded = "0.7" diff --git a/benches/rustapi_bench/benches/extractor_bench.rs b/benches/rustapi_bench/benches/extractor_bench.rs new file mode 100644 index 00000000..fc2ffb91 --- /dev/null +++ b/benches/rustapi_bench/benches/extractor_bench.rs @@ -0,0 +1,257 @@ +//! Extractor overhead benchmarks +//! +//! Benchmarks the performance of different extractor types in RustAPI. + +#![allow(dead_code)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Simple query params struct +#[derive(Deserialize)] +struct SimpleQuery { + page: Option, + limit: Option, +} + +/// Complex query params struct +#[derive(Deserialize)] +struct ComplexQuery { + page: Option, + limit: Option, + sort: Option, + filter: Option, + include: Option>, +} + +/// User request body +#[derive(Serialize, Deserialize)] +struct UserBody { + name: String, + email: String, + age: u32, +} + +/// Complex request body +#[derive(Serialize, Deserialize)] +struct ComplexBody { + user: UserBody, + tags: Vec, + metadata: HashMap, +} + +/// Benchmark path parameter extraction +fn bench_path_extraction(c: &mut Criterion) { + let mut group = c.benchmark_group("path_extraction"); + + // Single path param + group.bench_function("single_param", |b| { + let path = "/users/12345"; + b.iter(|| { + let id: u64 = black_box(path) + .strip_prefix("/users/") + .unwrap() + .parse() + .unwrap(); + id + }) + }); + + // Multiple path params + group.bench_function("multiple_params", |b| { + let path = "/users/12345/posts/67890"; + b.iter(|| { + let parts: Vec<&str> = black_box(path).split('/').collect(); + let user_id: u64 = parts[2].parse().unwrap(); + let post_id: u64 = parts[4].parse().unwrap(); + (user_id, post_id) + }) + }); + + // UUID path param + group.bench_function("uuid_param", |b| { + let path = "/items/550e8400-e29b-41d4-a716-446655440000"; + b.iter(|| { + let uuid_str = black_box(path).strip_prefix("/items/").unwrap(); + // Just validate format, don't parse to actual UUID + uuid_str.len() == 36 && uuid_str.chars().filter(|c| *c == '-').count() == 4 + }) + }); + + group.finish(); +} + +/// Benchmark query string extraction +fn bench_query_extraction(c: &mut Criterion) { + let mut group = c.benchmark_group("query_extraction"); + + // Simple query + let simple_query = "page=1&limit=10"; + group.bench_function("simple_query", |b| { + b.iter(|| { + serde_urlencoded::from_str::(black_box(simple_query)).unwrap() + }) + }); + + // Complex query + let complex_query = "page=1&limit=10&sort=created_at&filter=active&include=posts&include=comments"; + group.bench_function("complex_query", |b| { + b.iter(|| { + serde_urlencoded::from_str::(black_box(complex_query)).unwrap() + }) + }); + + // Empty query + let empty_query = ""; + group.bench_function("empty_query", |b| { + b.iter(|| { + serde_urlencoded::from_str::(black_box(empty_query)).unwrap() + }) + }); + + group.finish(); +} + +/// Benchmark JSON body extraction +fn bench_json_extraction(c: &mut Criterion) { + let mut group = c.benchmark_group("json_extraction"); + + // Simple body + let simple_json = r#"{"name":"John Doe","email":"john@example.com","age":30}"#; + group.bench_function("simple_body", |b| { + b.iter(|| { + serde_json::from_str::(black_box(simple_json)).unwrap() + }) + }); + + // Complex body + let complex_json = r#"{ + "user": {"name":"John Doe","email":"john@example.com","age":30}, + "tags": ["rust", "api", "web"], + "metadata": {"source": "mobile", "version": "1.0"} + }"#; + group.bench_function("complex_body", |b| { + b.iter(|| { + serde_json::from_str::(black_box(complex_json)).unwrap() + }) + }); + + // Large array body + let users: Vec = (0..100) + .map(|i| UserBody { + name: format!("User {}", i), + email: format!("user{}@example.com", i), + age: 20 + (i as u32 % 50), + }) + .collect(); + let large_json = serde_json::to_string(&users).unwrap(); + + group.bench_function("large_array_body", |b| { + b.iter(|| { + serde_json::from_str::>(black_box(&large_json)).unwrap() + }) + }); + + group.finish(); +} + +/// Benchmark header extraction +fn bench_header_extraction(c: &mut Criterion) { + let mut group = c.benchmark_group("header_extraction"); + + // Content-Type extraction + group.bench_function("content_type", |b| { + let header = "application/json; charset=utf-8"; + b.iter(|| { + let content_type = black_box(header).split(';').next().unwrap().trim(); + content_type == "application/json" + }) + }); + + // Authorization extraction + group.bench_function("authorization", |b| { + let header = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"; + b.iter(|| { + let token = black_box(header).strip_prefix("Bearer ").unwrap(); + token.len() > 0 + }) + }); + + // Accept header parsing + group.bench_function("accept_parsing", |b| { + let header = "application/json, application/xml;q=0.9, text/html;q=0.8, */*;q=0.1"; + b.iter(|| { + let types: Vec<&str> = black_box(header) + .split(',') + .map(|s| s.split(';').next().unwrap().trim()) + .collect(); + types + }) + }); + + group.finish(); +} + +/// Benchmark combined extraction (typical request) +fn bench_combined_extraction(c: &mut Criterion) { + let mut group = c.benchmark_group("combined_extraction"); + + // Typical GET request + group.bench_function("typical_get", |b| { + let path = "/users/12345"; + let query = "page=1&limit=10"; + let auth = "Bearer token123"; + + b.iter(|| { + // Extract path param + let user_id: u64 = black_box(path) + .strip_prefix("/users/") + .unwrap() + .parse() + .unwrap(); + + // Extract query params + let query_params = serde_urlencoded::from_str::(black_box(query)).unwrap(); + + // Extract auth token + let token = black_box(auth).strip_prefix("Bearer ").unwrap(); + + (user_id, query_params.page, token.len()) + }) + }); + + // Typical POST request + group.bench_function("typical_post", |b| { + let _path = "/users"; + let body = r#"{"name":"John Doe","email":"john@example.com","age":30}"#; + let content_type = "application/json"; + let auth = "Bearer token123"; + + b.iter(|| { + // Verify content type + let is_json = black_box(content_type) == "application/json"; + + // Extract auth token + let token = black_box(auth).strip_prefix("Bearer ").unwrap(); + + // Parse body + let user = serde_json::from_str::(black_box(body)).unwrap(); + + (is_json, token.len(), user.name.len()) + }) + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_path_extraction, + bench_query_extraction, + bench_json_extraction, + bench_header_extraction, + bench_combined_extraction, +); + +criterion_main!(benches); diff --git a/benches/rustapi_bench/benches/middleware_bench.rs b/benches/rustapi_bench/benches/middleware_bench.rs new file mode 100644 index 00000000..caf65fb7 --- /dev/null +++ b/benches/rustapi_bench/benches/middleware_bench.rs @@ -0,0 +1,153 @@ +//! Middleware composition benchmarks +//! +//! Benchmarks the overhead of middleware layers in RustAPI. + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; + +/// Simulate middleware overhead with simple counter +fn simulate_middleware_layer(input: u64, layers: usize) -> u64 { + let mut result = input; + for _ in 0..layers { + // Simulate minimal middleware work: check + transform + if result > 0 { + result = result.wrapping_add(1); + } + } + result +} + +/// Simulate request ID generation (UUID-like) +fn simulate_request_id_middleware(request_count: u64) -> String { + format!("req_{:016x}", request_count) +} + +/// Simulate header parsing overhead +fn simulate_header_parsing(headers: &[(&str, &str)]) -> usize { + headers + .iter() + .map(|(k, v)| k.len() + v.len()) + .sum() +} + +/// Benchmark middleware layer composition +fn bench_middleware_layers(c: &mut Criterion) { + let mut group = c.benchmark_group("middleware_layers"); + + // Test with different numbers of middleware layers + for layer_count in [0, 1, 3, 5, 10, 20].iter() { + group.bench_with_input( + BenchmarkId::new("layer_count", layer_count), + layer_count, + |b, &layers| { + b.iter(|| simulate_middleware_layer(black_box(42), layers)) + }, + ); + } + + group.finish(); +} + +/// Benchmark request ID generation +fn bench_request_id(c: &mut Criterion) { + let mut group = c.benchmark_group("request_id"); + + group.bench_function("generate", |b| { + let mut counter = 0u64; + b.iter(|| { + counter += 1; + simulate_request_id_middleware(black_box(counter)) + }) + }); + + group.finish(); +} + +/// Benchmark header parsing +fn bench_header_parsing(c: &mut Criterion) { + let mut group = c.benchmark_group("header_parsing"); + + // Minimal headers + let minimal_headers = [ + ("content-type", "application/json"), + ]; + + // Typical API headers + let typical_headers = [ + ("content-type", "application/json"), + ("accept", "application/json"), + ("authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"), + ("x-request-id", "550e8400-e29b-41d4-a716-446655440000"), + ("user-agent", "RustAPI-Client/1.0"), + ]; + + // Many headers + let many_headers: Vec<(&str, &str)> = (0..20) + .map(|i| { + let key: &'static str = Box::leak(format!("x-custom-header-{}", i).into_boxed_str()); + let value: &'static str = Box::leak(format!("value-{}", i).into_boxed_str()); + (key, value) + }) + .collect(); + + group.bench_function("minimal_headers", |b| { + b.iter(|| simulate_header_parsing(black_box(&minimal_headers))) + }); + + group.bench_function("typical_headers", |b| { + b.iter(|| simulate_header_parsing(black_box(&typical_headers))) + }); + + group.bench_function("many_headers", |b| { + b.iter(|| simulate_header_parsing(black_box(&many_headers))) + }); + + group.finish(); +} + +/// Benchmark async middleware simulation +fn bench_middleware_chain(c: &mut Criterion) { + let mut group = c.benchmark_group("middleware_chain"); + + // Simulate a typical middleware chain: + // 1. Request ID + // 2. Tracing + // 3. Auth check + // 4. Rate limit check + // 5. Body limit check + + group.bench_function("typical_chain", |b| { + b.iter(|| { + // Step 1: Generate request ID + let request_id = simulate_request_id_middleware(black_box(12345)); + + // Step 2: Tracing (record span) + let _ = black_box(request_id.len()); + + // Step 3: Auth check (simple token validation) + let token = "Bearer valid_token"; + let is_valid = black_box(token.starts_with("Bearer ")); + + // Step 4: Rate limit check (counter check) + let rate_count = black_box(99u64); + let under_limit = rate_count < 100; + + // Step 5: Body limit check + let body_size = black_box(1024usize); + let within_limit = body_size < 1_048_576; // 1MB + + (is_valid, under_limit, within_limit) + }) + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_middleware_layers, + bench_request_id, + bench_header_parsing, + bench_middleware_chain, +); + +criterion_main!(benches); diff --git a/benches/rustapi_bench/src/lib.rs b/benches/rustapi_bench/src/lib.rs new file mode 100644 index 00000000..f4154a52 --- /dev/null +++ b/benches/rustapi_bench/src/lib.rs @@ -0,0 +1 @@ +// Placeholder for library diff --git a/crates/rustapi-core/src/lib.rs b/crates/rustapi-core/src/lib.rs index 9e17e0df..a2172eb1 100644 --- a/crates/rustapi-core/src/lib.rs +++ b/crates/rustapi-core/src/lib.rs @@ -87,7 +87,7 @@ pub mod __private { // Public API pub use app::{RustApi, RustApiConfig}; -pub use error::{get_environment, ApiError, Environment, Result}; +pub use error::{get_environment, ApiError, Environment, FieldError, Result}; #[cfg(feature = "cookies")] pub use extract::Cookies; pub use extract::{ diff --git a/crates/rustapi-rs/tests/integration_tests.rs b/crates/rustapi-rs/tests/integration_tests.rs new file mode 100644 index 00000000..a02ed2ee --- /dev/null +++ b/crates/rustapi-rs/tests/integration_tests.rs @@ -0,0 +1,409 @@ +//! Integration tests for RustAPI framework +//! +//! These tests cover cross-cutting concerns that involve multiple crates working together. + +#![allow(unused_imports)] +use rustapi_rs::prelude::*; + +// ============================================================================ +// Router Integration Tests +// ============================================================================ + +mod router_tests { + use rustapi_rs::get; + + #[get("/integ-method-test")] + async fn method_test() -> &'static str { + "get" + } + + #[test] + fn test_router_method_routing() { + let routes = rustapi_rs::collect_auto_routes(); + let found = routes.iter().any(|r| + r.path() == "/integ-method-test" && r.method() == "GET" + ); + + assert!(found, "GET /integ-method-test should be registered"); + } + + #[get("/integ-users/{user_id}")] + async fn user_handler( + rustapi_rs::Path(user_id): rustapi_rs::Path + ) -> String { + format!("user={}", user_id) + } + + #[test] + fn test_router_path_params() { + let routes = rustapi_rs::collect_auto_routes(); + let found = routes.iter().any(|r| + r.path() == "/integ-users/{user_id}" + ); + + assert!(found, "Path param route should be registered"); + } +} + +// ============================================================================ +// State Management Tests +// ============================================================================ + +mod state_tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + #[derive(Clone)] + struct Counter(Arc); + + impl Counter { + fn new() -> Self { + Self(Arc::new(AtomicUsize::new(0))) + } + + fn increment(&self) -> usize { + self.0.fetch_add(1, Ordering::SeqCst) + } + + fn get(&self) -> usize { + self.0.load(Ordering::SeqCst) + } + } + + #[test] + fn test_state_sharing() { + let counter = Counter::new(); + + // Simulate multiple handlers accessing state + let c1 = counter.clone(); + let c2 = counter.clone(); + let c3 = counter.clone(); + + c1.increment(); + c2.increment(); + c3.increment(); + + assert_eq!(counter.get(), 3, "All handlers should share same state"); + } + + #[test] + fn test_state_thread_safety() { + use std::thread; + + let counter = Counter::new(); + let mut handles = vec![]; + + for _ in 0..10 { + let c = counter.clone(); + handles.push(thread::spawn(move || { + for _ in 0..100 { + c.increment(); + } + })); + } + + for h in handles { + h.join().unwrap(); + } + + assert_eq!(counter.get(), 1000, "All increments should be counted"); + } +} + +// ============================================================================ +// JSON Serialization Tests +// ============================================================================ + +mod json_tests { + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct TestData { + id: i64, + name: String, + tags: Vec, + active: bool, + } + + #[test] + fn test_json_roundtrip() { + let data = TestData { + id: 42, + name: "Test Item".to_string(), + tags: vec!["tag1".to_string(), "tag2".to_string()], + active: true, + }; + + let json = serde_json::to_string(&data).unwrap(); + let parsed: TestData = serde_json::from_str(&json).unwrap(); + + assert_eq!(data, parsed, "Data should survive JSON roundtrip"); + } + + #[test] + fn test_json_error_format() { + // Test that invalid JSON produces expected error + let bad_json = r#"{"id": "not_a_number"}"#; + let result: Result = serde_json::from_str(bad_json); + + assert!(result.is_err(), "Should fail to parse invalid JSON"); + let err = result.unwrap_err(); + assert!( + err.to_string().contains("invalid type"), + "Error should mention type mismatch" + ); + } + + #[test] + fn test_json_with_special_chars() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct TextData { + content: String, + } + + let data = TextData { + content: "Hello \"World\" with\nnewlines\tand\ttabs".to_string(), + }; + + let json = serde_json::to_string(&data).unwrap(); + let parsed: TextData = serde_json::from_str(&json).unwrap(); + + assert_eq!(data, parsed, "Special characters should be preserved"); + } +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +mod error_tests { + use rustapi_rs::prelude::*; + + #[test] + fn test_api_error_not_found() { + let error = ApiError::not_found("User not found"); + assert_eq!(error.error_type, "not_found"); + assert_eq!(error.message, "User not found"); + } + + #[test] + fn test_api_error_bad_request() { + let error = ApiError::bad_request("Invalid input"); + assert_eq!(error.error_type, "bad_request"); + assert_eq!(error.message, "Invalid input"); + } + + #[test] + fn test_api_error_validation() { + let error = ApiError::validation(vec![ + rustapi_rs::FieldError { + field: "email".to_string(), + code: "email".to_string(), + message: "Invalid email format".to_string(), + } + ]); + + assert!(error.fields.is_some(), "Should have field errors"); + assert_eq!(error.fields.as_ref().unwrap().len(), 1); + } + + #[test] + fn test_result_type_ok() { + fn handler() -> Result { + Ok("success".to_string()) + } + + let result = handler(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "success"); + } + + #[test] + fn test_result_type_err() { + fn handler(fail: bool) -> Result { + if fail { + Err(ApiError::bad_request("Failed")) + } else { + Ok("success".to_string()) + } + } + + assert!(handler(true).is_err()); + assert!(handler(false).is_ok()); + } +} + +// ============================================================================ +// OpenAPI Schema Tests +// ============================================================================ + +mod openapi_tests { + use rustapi_rs::prelude::*; + use utoipa::ToSchema; + + #[derive(Debug, Clone, Serialize, Schema)] + struct IntegApiResponse { + success: bool, + data: Option, + count: i32, + } + + #[test] + fn test_schema_generation() { + let (name, _schema) = ::schema(); + + assert_eq!(name, "IntegApiResponse", "Schema name should match struct name"); + } + + #[test] + fn test_auto_collects_schemas() { + let app = RustApi::auto(); + let spec = app.openapi_spec(); + + // Should have schemas section + assert!( + !spec.schemas.is_empty(), + "OpenAPI spec should have schemas" + ); + } +} + +// ============================================================================ +// Extractor Tests +// ============================================================================ + +mod extractor_tests { + #[test] + fn test_path_parsing() { + // Simulate path parameter parsing + let path = "/users/123/posts/456"; + let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); + + assert_eq!(segments.len(), 4); + assert_eq!(segments[1].parse::().unwrap(), 123); + assert_eq!(segments[3].parse::().unwrap(), 456); + } + + #[test] + fn test_path_parsing_uuid() { + let uuid = "550e8400-e29b-41d4-a716-446655440000"; + assert_eq!(uuid.len(), 36); + assert_eq!(uuid.chars().filter(|c| *c == '-').count(), 4); + } +} + +// ============================================================================ +// Compression Tests (basic - does not require feature flag) +// ============================================================================ + +mod compression_tests { + #[test] + fn test_accept_encoding_parsing() { + let accept_encoding = "gzip, deflate, br"; + let encodings: Vec<&str> = accept_encoding + .split(',') + .map(|s| s.trim()) + .collect(); + + assert!(encodings.contains(&"gzip")); + assert!(encodings.contains(&"deflate")); + assert!(encodings.contains(&"br")); + } + + #[test] + fn test_content_type_check() { + let compressible = ["text/html", "application/json", "text/css", "text/javascript"]; + let not_compressible = ["image/png", "video/mp4", "application/zip"]; + + for ct in &compressible { + assert!( + ct.starts_with("text/") || ct.contains("json") || ct.contains("xml"), + "{} should be compressible", + ct + ); + } + + for ct in ¬_compressible { + assert!( + !ct.starts_with("text/") && !ct.contains("json"), + "{} should not be compressible", + ct + ); + } + } +} + +// ============================================================================ +// Rate Limiting Tests (basic concepts) +// ============================================================================ + +mod rate_limit_tests { + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + struct SimpleRateLimiter { + counts: Arc>>, + limit: usize, + } + + impl SimpleRateLimiter { + fn new(limit: usize) -> Self { + Self { + counts: Arc::new(Mutex::new(HashMap::new())), + limit, + } + } + + fn check(&self, key: &str) -> bool { + let mut counts = self.counts.lock().unwrap(); + let count = counts.entry(key.to_string()).or_insert(0); + if *count < self.limit { + *count += 1; + true + } else { + false + } + } + } + + #[test] + fn test_rate_limiter_allows_within_limit() { + let limiter = SimpleRateLimiter::new(5); + let ip = "192.168.1.1"; + + for i in 0..5 { + assert!(limiter.check(ip), "Request {} should be allowed", i + 1); + } + } + + #[test] + fn test_rate_limiter_blocks_over_limit() { + let limiter = SimpleRateLimiter::new(3); + let ip = "192.168.1.2"; + + // Use up the limit + for _ in 0..3 { + limiter.check(ip); + } + + // Next request should be blocked + assert!(!limiter.check(ip), "Request over limit should be blocked"); + } + + #[test] + fn test_rate_limiter_multiple_ips() { + let limiter = SimpleRateLimiter::new(2); + + let ip1 = "192.168.1.1"; + let ip2 = "192.168.1.2"; + + // Each IP should have independent limit + assert!(limiter.check(ip1)); + assert!(limiter.check(ip1)); + assert!(!limiter.check(ip1)); // Over limit + + assert!(limiter.check(ip2)); // Different IP, should work + assert!(limiter.check(ip2)); + assert!(!limiter.check(ip2)); // Now over limit + } +} From b1cba51496dc917d20f5ae822872e3b7ed61e594 Mon Sep 17 00:00:00 2001 From: Tunay Engin Date: Wed, 14 Jan 2026 01:00:46 +0300 Subject: [PATCH 05/21] Add WebSocket benchmarks and extras modules Introduces comprehensive WebSocket message throughput benchmarks in benches/rustapi_bench, including text, binary, JSON parsing, frame encoding, and broadcast scenarios. Adds new modules to rustapi-extras for features like API key, cache, circuit breaker, CSRF, deduplication, Diesel, guards, logging, OAuth2, OpenTelemetry, retry, sanitization, security headers, structured logging, and timeout. Adds v2 validation modules to rustapi-validate, new WebSocket features to rustapi-ws, and a phase11-demo example. Updates dependencies and workspace configuration to support new features. --- Cargo.lock | 651 +++++++++- Cargo.toml | 1 + benches/rustapi_bench/Cargo.toml | 4 + .../rustapi_bench/benches/websocket_bench.rs | 243 ++++ crates/rustapi-core/src/app.rs | 253 +++- crates/rustapi-core/src/extract.rs | 3 +- crates/rustapi-core/src/health.rs | 284 +++++ crates/rustapi-core/src/interceptor.rs | 523 ++++++++ crates/rustapi-core/src/lib.rs | 4 + crates/rustapi-core/src/middleware/metrics.rs | 2 +- crates/rustapi-core/src/request.rs | 29 +- crates/rustapi-core/src/router.rs | 54 + crates/rustapi-core/src/server.rs | 16 +- crates/rustapi-extras/Cargo.toml | 55 +- crates/rustapi-extras/src/api_key.rs | 334 +++++ crates/rustapi-extras/src/cache.rs | 174 +++ crates/rustapi-extras/src/circuit_breaker.rs | 408 ++++++ crates/rustapi-extras/src/csrf/config.rs | 97 ++ crates/rustapi-extras/src/csrf/layer.rs | 282 +++++ crates/rustapi-extras/src/csrf/mod.rs | 25 + crates/rustapi-extras/src/csrf/token.rs | 63 + crates/rustapi-extras/src/dedup.rs | 134 ++ crates/rustapi-extras/src/diesel/mod.rs | 522 ++++++++ crates/rustapi-extras/src/guard.rs | 252 ++++ crates/rustapi-extras/src/lib.rs | 118 +- crates/rustapi-extras/src/logging.rs | 302 +++++ crates/rustapi-extras/src/oauth2/client.rs | 308 +++++ crates/rustapi-extras/src/oauth2/config.rs | 193 +++ crates/rustapi-extras/src/oauth2/mod.rs | 38 + crates/rustapi-extras/src/oauth2/providers.rs | 133 ++ crates/rustapi-extras/src/oauth2/tokens.rs | 273 ++++ crates/rustapi-extras/src/otel/config.rs | 278 +++++ crates/rustapi-extras/src/otel/layer.rs | 322 +++++ crates/rustapi-extras/src/otel/mod.rs | 38 + crates/rustapi-extras/src/otel/propagation.rs | 313 +++++ crates/rustapi-extras/src/retry.rs | 296 +++++ crates/rustapi-extras/src/sanitization.rs | 98 ++ crates/rustapi-extras/src/security_headers.rs | 415 ++++++ crates/rustapi-extras/src/sqlx/mod.rs | 459 ++++++- .../src/structured_logging/config.rs | 346 +++++ .../src/structured_logging/formats.rs | 605 +++++++++ .../src/structured_logging/layer.rs | 462 +++++++ .../src/structured_logging/mod.rs | 39 + crates/rustapi-extras/src/timeout.rs | 173 +++ crates/rustapi-macros/Cargo.toml | 3 + crates/rustapi-macros/src/lib.rs | 427 +++++++ crates/rustapi-rs/Cargo.toml | 16 +- crates/rustapi-rs/src/lib.rs | 39 +- crates/rustapi-validate/Cargo.toml | 10 + .../proptest-regressions/v2/tests.txt | 8 + crates/rustapi-validate/src/custom.rs | 182 +++ crates/rustapi-validate/src/error.rs | 43 + crates/rustapi-validate/src/lib.rs | 43 +- crates/rustapi-validate/src/v2/context.rs | 235 ++++ crates/rustapi-validate/src/v2/error.rs | 277 ++++ crates/rustapi-validate/src/v2/group.rs | 254 ++++ crates/rustapi-validate/src/v2/mod.rs | 63 + .../src/v2/rules/async_rules.rs | 351 ++++++ crates/rustapi-validate/src/v2/rules/mod.rs | 9 + .../src/v2/rules/sync_rules.rs | 615 +++++++++ crates/rustapi-validate/src/v2/tests.rs | 1110 +++++++++++++++++ crates/rustapi-validate/src/v2/traits.rs | 712 +++++++++++ .../tests/derive_macro_tests.rs | 283 +++++ crates/rustapi-ws/Cargo.toml | 5 +- crates/rustapi-ws/src/auth.rs | 704 +++++++++++ crates/rustapi-ws/src/compression.rs | 102 ++ crates/rustapi-ws/src/error.rs | 25 + crates/rustapi-ws/src/extractor.rs | 29 +- crates/rustapi-ws/src/heartbeat.rs | 68 + crates/rustapi-ws/src/lib.rs | 10 +- crates/rustapi-ws/src/socket.rs | 321 +++-- crates/rustapi-ws/src/upgrade.rs | 123 +- examples/phase11-demo/Cargo.toml | 21 + examples/phase11-demo/README.md | 234 ++++ examples/phase11-demo/src/main.rs | 147 +++ 75 files changed, 15875 insertions(+), 216 deletions(-) create mode 100644 benches/rustapi_bench/benches/websocket_bench.rs create mode 100644 crates/rustapi-core/src/health.rs create mode 100644 crates/rustapi-core/src/interceptor.rs create mode 100644 crates/rustapi-extras/src/api_key.rs create mode 100644 crates/rustapi-extras/src/cache.rs create mode 100644 crates/rustapi-extras/src/circuit_breaker.rs create mode 100644 crates/rustapi-extras/src/csrf/config.rs create mode 100644 crates/rustapi-extras/src/csrf/layer.rs create mode 100644 crates/rustapi-extras/src/csrf/mod.rs create mode 100644 crates/rustapi-extras/src/csrf/token.rs create mode 100644 crates/rustapi-extras/src/dedup.rs create mode 100644 crates/rustapi-extras/src/diesel/mod.rs create mode 100644 crates/rustapi-extras/src/guard.rs create mode 100644 crates/rustapi-extras/src/logging.rs create mode 100644 crates/rustapi-extras/src/oauth2/client.rs create mode 100644 crates/rustapi-extras/src/oauth2/config.rs create mode 100644 crates/rustapi-extras/src/oauth2/mod.rs create mode 100644 crates/rustapi-extras/src/oauth2/providers.rs create mode 100644 crates/rustapi-extras/src/oauth2/tokens.rs create mode 100644 crates/rustapi-extras/src/otel/config.rs create mode 100644 crates/rustapi-extras/src/otel/layer.rs create mode 100644 crates/rustapi-extras/src/otel/mod.rs create mode 100644 crates/rustapi-extras/src/otel/propagation.rs create mode 100644 crates/rustapi-extras/src/retry.rs create mode 100644 crates/rustapi-extras/src/sanitization.rs create mode 100644 crates/rustapi-extras/src/security_headers.rs create mode 100644 crates/rustapi-extras/src/structured_logging/config.rs create mode 100644 crates/rustapi-extras/src/structured_logging/formats.rs create mode 100644 crates/rustapi-extras/src/structured_logging/layer.rs create mode 100644 crates/rustapi-extras/src/structured_logging/mod.rs create mode 100644 crates/rustapi-extras/src/timeout.rs create mode 100644 crates/rustapi-validate/proptest-regressions/v2/tests.txt create mode 100644 crates/rustapi-validate/src/custom.rs create mode 100644 crates/rustapi-validate/src/v2/context.rs create mode 100644 crates/rustapi-validate/src/v2/error.rs create mode 100644 crates/rustapi-validate/src/v2/group.rs create mode 100644 crates/rustapi-validate/src/v2/mod.rs create mode 100644 crates/rustapi-validate/src/v2/rules/async_rules.rs create mode 100644 crates/rustapi-validate/src/v2/rules/mod.rs create mode 100644 crates/rustapi-validate/src/v2/rules/sync_rules.rs create mode 100644 crates/rustapi-validate/src/v2/tests.rs create mode 100644 crates/rustapi-validate/src/v2/traits.rs create mode 100644 crates/rustapi-validate/tests/derive_macro_tests.rs create mode 100644 crates/rustapi-ws/src/auth.rs create mode 100644 crates/rustapi-ws/src/compression.rs create mode 100644 crates/rustapi-ws/src/heartbeat.rs create mode 100644 examples/phase11-demo/Cargo.toml create mode 100644 examples/phase11-demo/README.md create mode 100644 examples/phase11-demo/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 595caeef..d27562c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -149,6 +149,39 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "atoi" version = "2.0.0" @@ -181,6 +214,51 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper 0.1.2", + "tower 0.4.13", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "base64" version = "0.21.7" @@ -214,6 +292,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.10.0" @@ -607,6 +691,15 @@ dependencies = [ "itertools", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -674,8 +767,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", ] [[package]] @@ -692,13 +795,38 @@ dependencies = [ "syn 2.0.111", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.111", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn 2.0.111", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", "quote", "syn 2.0.111", ] @@ -762,6 +890,48 @@ dependencies = [ "zeroize", ] +[[package]] +name = "diesel" +version = "2.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e130c806dccc85428c564f2dc5a96e05b6615a27c9a28776bd7761a9af4bb552" +dependencies = [ + "bitflags 2.10.0", + "byteorder", + "diesel_derives", + "downcast-rs", + "itoa", + "libsqlite3-sys", + "mysqlclient-sys", + "percent-encoding", + "pq-sys", + "sqlite-wasm-rs", + "time", + "url", +] + +[[package]] +name = "diesel_derives" +version = "2.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c30b2969f923fa1f73744b92bb7df60b858df8832742d9a3aceb79236c0be1d2" +dependencies = [ + "diesel_table_macro_syntax", + "dsl_auto_type", + "proc-macro2", + "quote", + "syn 2.0.111", +] + +[[package]] +name = "diesel_table_macro_syntax" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe2444076b48641147115697648dc743c2c00b61adade0f01ce67133c7babe8c" +dependencies = [ + "syn 2.0.111", +] + [[package]] name = "difflib" version = "0.4.0" @@ -797,6 +967,26 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "downcast-rs" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117240f60069e65410b3ae1bb213295bd828f707b5bec6596a1afc8793ce0cbc" + +[[package]] +name = "dsl_auto_type" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd122633e4bef06db27737f21d3738fb89c8f6d5360d6d9d7635dda142a7757e" +dependencies = [ + "darling 0.21.3", + "either", + "heck", + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "either" version = "1.15.0" @@ -922,6 +1112,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1067,6 +1263,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "globset" version = "0.4.18" @@ -1086,11 +1288,30 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" dependencies = [ - "bitflags", + "bitflags 2.10.0", "ignore", "walkdir", ] +[[package]] +name = "h2" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.12.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.12" @@ -1102,7 +1323,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.4.0", "indexmap 2.12.1", "slab", "tokio", @@ -1155,7 +1376,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -1163,6 +1384,9 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "foldhash 0.2.0", +] [[package]] name = "hashlink" @@ -1239,6 +1463,17 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.4.0" @@ -1249,6 +1484,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -1256,7 +1502,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.4.0", ] [[package]] @@ -1267,8 +1513,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1299,6 +1545,30 @@ dependencies = [ "libm", ] +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.5.10", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.8.1" @@ -1309,9 +1579,9 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1328,8 +1598,8 @@ version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", - "hyper", + "http 1.4.0", + "hyper 1.8.1", "hyper-util", "rustls", "rustls-pki-types", @@ -1339,6 +1609,18 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.32", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-tls" version = "0.6.0" @@ -1347,7 +1629,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "native-tls", "tokio", @@ -1366,14 +1648,14 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", - "hyper", + "http 1.4.0", + "http-body 1.0.1", + "hyper 1.8.1", "ipnet", "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.1", "system-configuration", "tokio", "tower-layer", @@ -1694,7 +1976,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" dependencies = [ - "bitflags", + "bitflags 2.10.0", "libc", "redox_syscall 0.7.0", ] @@ -1820,7 +2102,7 @@ dependencies = [ name = "middleware-chain" version = "0.1.0" dependencies = [ - "http", + "http 1.4.0", "rustapi-core", "rustapi-rs", "serde", @@ -1865,6 +2147,17 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mysqlclient-sys" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86a34a2bdec189f1060343ba712983e14cad7e87515cfd9ac4653e207535b6b1" +dependencies = [ + "pkg-config", + "semver", + "vcpkg", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -1983,7 +2276,7 @@ version = "0.10.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" dependencies = [ - "bitflags", + "bitflags 2.10.0", "cfg-if", "foreign-types", "libc", @@ -2021,6 +2314,89 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "opentelemetry" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900d57987be3f2aeb70d385fff9b27fb74c5723cc9a52d904d4f9c807a0667bf" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror 1.0.69", + "urlencoding", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a016b8d9495c639af2145ac22387dcb88e44118e45320d9238fbf4e7889abcb" +dependencies = [ + "async-trait", + "futures-core", + "http 0.2.12", + "opentelemetry", + "opentelemetry-proto", + "opentelemetry-semantic-conventions", + "opentelemetry_sdk", + "prost", + "thiserror 1.0.69", + "tokio", + "tonic", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a8fddc9b68f5b80dae9d6f510b88e02396f006ad48cac349411fbecc80caae4" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9ab5bd6c42fb9349dcf28af2ba9a0667f697f9bdcca045d39f2cec5543e2910" + +[[package]] +name = "opentelemetry_sdk" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e90c7113be649e31e9a0f8b5ee24ed7a16923b322c3c5ab6367469c049d6b7e" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry", + "ordered-float", + "percent-encoding", + "rand 0.8.5", + "thiserror 1.0.69", + "tokio", + "tokio-stream", +] + +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "parking" version = "2.2.1" @@ -2282,6 +2658,17 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "pq-sys" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "574ddd6a267294433f140b02a726b0640c43cf7c6f717084684aaa3b285aba61" +dependencies = [ + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "predicates" version = "3.1.3" @@ -2381,7 +2768,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40" dependencies = [ "bit-set", "bit-vec", - "bitflags", + "bitflags 2.10.0", "num-traits", "rand 0.9.2", "rand_chacha 0.9.0", @@ -2392,6 +2779,29 @@ dependencies = [ "unarray", ] +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "protobuf" version = "2.28.0" @@ -2417,7 +2827,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2", + "socket2 0.6.1", "thiserror 2.0.17", "tokio", "tracing", @@ -2454,9 +2864,9 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.6.1", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -2474,6 +2884,17 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot", + "scheduled-thread-pool", +] + [[package]] name = "rand" version = "0.8.5" @@ -2578,7 +2999,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags", + "bitflags 2.10.0", ] [[package]] @@ -2587,7 +3008,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27" dependencies = [ - "bitflags", + "bitflags 2.10.0", ] [[package]] @@ -2649,11 +3070,11 @@ dependencies = [ "bytes", "encoding_rs", "futures-core", - "h2", - "http", - "http-body", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-rustls", "hyper-tls", "hyper-util", @@ -2669,7 +3090,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 1.0.2", "tokio", "tokio-native-tls", "tokio-rustls", @@ -2737,9 +3158,9 @@ dependencies = [ "cookie", "flate2", "futures-util", - "http", + "http 1.4.0", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "inventory", "linkme", @@ -2769,27 +3190,37 @@ dependencies = [ name = "rustapi-extras" version = "0.1.8" dependencies = [ + "base64 0.22.1", "bytes", "cookie", "dashmap", + "diesel", "dotenvy", "envy", "futures-util", - "http", + "http 1.4.0", "http-body-util", "jsonwebtoken", + "opentelemetry", + "opentelemetry-otlp", + "opentelemetry-semantic-conventions", + "opentelemetry_sdk", "proptest", + "r2d2", + "rand 0.8.5", "reqwest", "rustapi-core", "rustapi-openapi", "serde", "serde_json", "serial_test", + "sha2", "sqlx", "tempfile", "thiserror 1.0.69", "tokio", "tracing", + "tracing-opentelemetry", "urlencoding", ] @@ -2807,7 +3238,7 @@ name = "rustapi-openapi" version = "0.1.8" dependencies = [ "bytes", - "http", + "http 1.4.0", "http-body-util", "serde", "serde_json", @@ -2839,7 +3270,7 @@ version = "0.1.8" dependencies = [ "bytes", "futures-util", - "http", + "http 1.4.0", "http-body-util", "rustapi-core", "rustapi-openapi", @@ -2855,7 +3286,11 @@ dependencies = [ name = "rustapi-validate" version = "0.1.8" dependencies = [ - "http", + "async-trait", + "http 1.4.0", + "proptest", + "regex", + "rustapi-macros", "serde", "serde_json", "thiserror 1.0.69", @@ -2868,7 +3303,7 @@ name = "rustapi-view" version = "0.1.8" dependencies = [ "bytes", - "http", + "http 1.4.0", "http-body-util", "rustapi-core", "rustapi-openapi", @@ -2884,14 +3319,16 @@ dependencies = [ name = "rustapi-ws" version = "0.1.8" dependencies = [ + "async-trait", "base64 0.22.1", "bytes", "futures-util", - "http", + "http 1.4.0", "http-body-util", - "hyper", + "hyper 1.8.1", "hyper-util", "pin-project-lite", + "proptest", "rustapi-core", "rustapi-openapi", "serde", @@ -2902,6 +3339,7 @@ dependencies = [ "tokio-tungstenite", "tracing", "tungstenite", + "url", ] [[package]] @@ -2916,7 +3354,7 @@ version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ - "bitflags", + "bitflags 2.10.0", "errno", "libc", "linux-raw-sys", @@ -3009,6 +3447,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -3027,7 +3474,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags", + "bitflags 2.10.0", "core-foundation", "core-foundation-sys", "libc", @@ -3044,6 +3491,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -3268,6 +3721,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socket2" version = "0.6.1" @@ -3297,6 +3760,19 @@ dependencies = [ "der", ] +[[package]] +name = "sqlite-wasm-rs" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05e98301bf8b0540c7de45ecd760539b9c62f5772aed172f08efba597c11cd5d" +dependencies = [ + "cc", + "hashbrown 0.16.1", + "js-sys", + "thiserror 2.0.17", + "wasm-bindgen", +] + [[package]] name = "sqlx" version = "0.8.6" @@ -3402,7 +3878,7 @@ checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526" dependencies = [ "atoi", "base64 0.22.1", - "bitflags", + "bitflags 2.10.0", "byteorder", "bytes", "crc", @@ -3444,7 +3920,7 @@ checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46" dependencies = [ "atoi", "base64 0.22.1", - "bitflags", + "bitflags 2.10.0", "byteorder", "crc", "dotenvy", @@ -3547,6 +4023,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "sync_wrapper" version = "1.0.2" @@ -3573,7 +4055,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags", + "bitflags 2.10.0", "core-foundation", "system-configuration-sys", ] @@ -3768,11 +4250,21 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.6.1", "tokio-macros", "windows-sys 0.61.2", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bd86198d9ee903fedd2f9a2e72014287c0d9167e4ae43b5853007205dda1b76" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.6.0" @@ -3881,6 +4373,33 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "tonic" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.7", + "bytes", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "tokio", + "tokio-stream", + "tower 0.4.13", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "toon-api" version = "0.1.0" @@ -3945,7 +4464,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -3959,12 +4478,12 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "async-compression", "base64 0.21.7", - "bitflags", + "bitflags 2.10.0", "bytes", "futures-core", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "http-body-util", "http-range-header", "httpdate", @@ -3988,11 +4507,11 @@ version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "bitflags", + "bitflags 2.10.0", "bytes", "futures-util", - "http", - "http-body", + "http 1.4.0", + "http-body 1.0.1", "iri-string", "pin-project-lite", "tower 0.5.2", @@ -4056,6 +4575,24 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-opentelemetry" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9be14ba1bbe4ab79e9229f7f89fab8d120b865859f10527f31c033e599d2284" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + [[package]] name = "tracing-subscriber" version = "0.3.22" @@ -4089,7 +4626,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http", + "http 1.4.0", "httparse", "log", "rand 0.8.5", @@ -4262,7 +4799,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df0bcf92720c40105ac4b2dda2a4ea3aa717d4d6a862cc217da653a4bd5c6b10" dependencies = [ - "darling", + "darling 0.20.11", "once_cell", "proc-macro-error", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 5b50948a..b92af419 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ futures-util = "0.3" bytes = "1.5" matchit = "0.7" # Radix tree router pin-project-lite = "0.2" +async-trait = "0.1" # Proc macros syn = { version = "2.0", features = ["full", "parsing", "extra-traits"] } diff --git a/benches/rustapi_bench/Cargo.toml b/benches/rustapi_bench/Cargo.toml index 27f934a5..7bd0837f 100644 --- a/benches/rustapi_bench/Cargo.toml +++ b/benches/rustapi_bench/Cargo.toml @@ -12,6 +12,10 @@ harness = false name = "extractor_bench" harness = false +[[bench]] +name = "websocket_bench" +harness = false + [dependencies] serde.workspace = true serde_json.workspace = true diff --git a/benches/rustapi_bench/benches/websocket_bench.rs b/benches/rustapi_bench/benches/websocket_bench.rs new file mode 100644 index 00000000..8801702b --- /dev/null +++ b/benches/rustapi_bench/benches/websocket_bench.rs @@ -0,0 +1,243 @@ +//! WebSocket message throughput benchmarks +//! +//! Benchmarks the performance of WebSocket message handling in RustAPI. + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use std::collections::HashMap; + +/// Simulate WebSocket message parsing (text) +fn parse_text_message(data: &str) -> String { + data.to_string() +} + +/// Simulate WebSocket message parsing (binary) +fn parse_binary_message(data: &[u8]) -> Vec { + data.to_vec() +} + +/// Simulate JSON message parsing +fn parse_json_message(data: &str) -> serde_json::Value { + serde_json::from_str(data).unwrap_or(serde_json::Value::Null) +} + +/// Simulate message frame encoding +fn encode_frame(opcode: u8, payload: &[u8], mask: bool) -> Vec { + let mut frame = Vec::with_capacity(14 + payload.len()); + + // FIN bit + opcode + frame.push(0x80 | opcode); + + // Payload length + let len = payload.len(); + if len < 126 { + frame.push((if mask { 0x80 } else { 0 }) | len as u8); + } else if len < 65536 { + frame.push((if mask { 0x80 } else { 0 }) | 126); + frame.push((len >> 8) as u8); + frame.push(len as u8); + } else { + frame.push((if mask { 0x80 } else { 0 }) | 127); + for i in (0..8).rev() { + frame.push((len >> (i * 8)) as u8); + } + } + + // Masking key (if masked) + if mask { + let mask_key: [u8; 4] = [0x12, 0x34, 0x56, 0x78]; + frame.extend_from_slice(&mask_key); + + // Masked payload + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + } else { + frame.extend_from_slice(payload); + } + + frame +} + +/// Benchmark text message parsing +fn bench_text_message(c: &mut Criterion) { + let mut group = c.benchmark_group("websocket_text"); + + let messages = [ + ("tiny", "Hi"), + ("small", "Hello, WebSocket!"), + ("medium", &"x".repeat(1024)), + ("large", &"x".repeat(64 * 1024)), + ]; + + for (name, msg) in messages.iter() { + group.throughput(Throughput::Bytes(msg.len() as u64)); + group.bench_with_input( + BenchmarkId::new("parse", name), + msg, + |b, msg| b.iter(|| parse_text_message(black_box(msg))), + ); + } + + group.finish(); +} + +/// Benchmark binary message parsing +fn bench_binary_message(c: &mut Criterion) { + let mut group = c.benchmark_group("websocket_binary"); + + let messages: Vec<(&str, Vec)> = vec![ + ("tiny", vec![1, 2, 3, 4]), + ("small", vec![0u8; 64]), + ("medium", vec![0u8; 4096]), + ("large", vec![0u8; 64 * 1024]), + ]; + + for (name, msg) in messages.iter() { + group.throughput(Throughput::Bytes(msg.len() as u64)); + group.bench_with_input( + BenchmarkId::new("parse", name), + msg, + |b, msg| b.iter(|| parse_binary_message(black_box(msg))), + ); + } + + group.finish(); +} + +/// Benchmark JSON message parsing (common WebSocket pattern) +fn bench_json_message(c: &mut Criterion) { + let mut group = c.benchmark_group("websocket_json"); + + // Simple JSON message + let simple_json = r#"{"type":"ping"}"#; + + // Typical chat message + let chat_json = r#"{"type":"message","user":"alice","content":"Hello everyone!","timestamp":1704067200}"#; + + // Complex nested JSON + let complex_json = r#"{"type":"state","data":{"users":[{"id":1,"name":"Alice"},{"id":2,"name":"Bob"}],"room":"general","active":true}}"#; + + group.bench_function("simple", |b| { + b.iter(|| parse_json_message(black_box(simple_json))) + }); + + group.bench_function("chat", |b| { + b.iter(|| parse_json_message(black_box(chat_json))) + }); + + group.bench_function("complex", |b| { + b.iter(|| parse_json_message(black_box(complex_json))) + }); + + group.finish(); +} + +/// Benchmark frame encoding +fn bench_frame_encoding(c: &mut Criterion) { + let mut group = c.benchmark_group("websocket_frame"); + + let payloads: Vec<(&str, Vec)> = vec![ + ("tiny", vec![1, 2, 3, 4]), + ("small", vec![0u8; 100]), + ("medium_125", vec![0u8; 125]), // Max single-byte length + ("medium_126", vec![0u8; 126]), // Requires 2-byte length + ("large", vec![0u8; 1024]), + ]; + + for (name, payload) in payloads.iter() { + // Server-side (no mask) + group.bench_with_input( + BenchmarkId::new("encode_unmasked", name), + payload, + |b, payload| b.iter(|| encode_frame(0x01, black_box(payload), false)), + ); + + // Client-side (with mask) + group.bench_with_input( + BenchmarkId::new("encode_masked", name), + payload, + |b, payload| b.iter(|| encode_frame(0x01, black_box(payload), true)), + ); + } + + group.finish(); +} + +/// Benchmark broadcast scenario (sending to multiple clients) +fn bench_broadcast(c: &mut Criterion) { + let mut group = c.benchmark_group("websocket_broadcast"); + + let message = "Broadcast message to all connected clients"; + + for client_count in [10, 100, 1000].iter() { + group.bench_with_input( + BenchmarkId::new("prepare_messages", client_count), + client_count, + |b, &count| { + b.iter(|| { + // Simulate preparing messages for N clients + let mut messages = Vec::with_capacity(count); + for _ in 0..count { + messages.push(black_box(message).to_string()); + } + messages + }) + }, + ); + } + + group.finish(); +} + +/// Benchmark connection management (HashMap-based room pattern) +fn bench_connection_management(c: &mut Criterion) { + let mut group = c.benchmark_group("websocket_rooms"); + + // Simulate room-based connection management + group.bench_function("join_room", |b| { + let mut rooms: HashMap> = HashMap::new(); + let mut client_id = 0u64; + + b.iter(|| { + client_id += 1; + let room = black_box("general".to_string()); + rooms.entry(room).or_default().push(client_id); + }) + }); + + group.bench_function("leave_room", |b| { + let mut rooms: HashMap> = HashMap::new(); + rooms.insert("general".to_string(), (0..1000).collect()); + + b.iter(|| { + let room = rooms.get_mut(black_box("general")).unwrap(); + let client_id = black_box(500u64); + if let Some(pos) = room.iter().position(|&id| id == client_id) { + room.swap_remove(pos); + } + }) + }); + + group.bench_function("list_room_members", |b| { + let mut rooms: HashMap> = HashMap::new(); + rooms.insert("general".to_string(), (0..100).collect()); + + b.iter(|| { + rooms.get(black_box("general")).map(|members| members.len()) + }) + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_text_message, + bench_binary_message, + bench_json_message, + bench_frame_encoding, + bench_broadcast, + bench_connection_management, +); + +criterion_main!(benches); diff --git a/crates/rustapi-core/src/app.rs b/crates/rustapi-core/src/app.rs index 102c72bd..c7568963 100644 --- a/crates/rustapi-core/src/app.rs +++ b/crates/rustapi-core/src/app.rs @@ -1,6 +1,7 @@ //! RustApi application builder use crate::error::Result; +use crate::interceptor::{InterceptorChain, RequestInterceptor, ResponseInterceptor}; use crate::middleware::{BodyLimitLayer, LayerStack, MiddlewareLayer, DEFAULT_BODY_LIMIT}; use crate::response::IntoResponse; use crate::router::{MethodRouter, Router}; @@ -30,6 +31,7 @@ pub struct RustApi { openapi_spec: rustapi_openapi::OpenApiSpec, layers: LayerStack, body_limit: Option, + interceptors: InterceptorChain, } impl RustApi { @@ -54,6 +56,7 @@ impl RustApi { .register::(), layers: LayerStack::new(), body_limit: Some(DEFAULT_BODY_LIMIT), // Default 1MB limit + interceptors: InterceptorChain::new(), } } @@ -185,6 +188,84 @@ impl RustApi { self } + /// Add a request interceptor to the application + /// + /// Request interceptors are executed in registration order before the route handler. + /// Each interceptor can modify the request before passing it to the next interceptor + /// or handler. + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_core::{RustApi, interceptor::RequestInterceptor, Request}; + /// + /// #[derive(Clone)] + /// struct AddRequestId; + /// + /// impl RequestInterceptor for AddRequestId { + /// fn intercept(&self, mut req: Request) -> Request { + /// req.extensions_mut().insert(uuid::Uuid::new_v4()); + /// req + /// } + /// + /// fn clone_box(&self) -> Box { + /// Box::new(self.clone()) + /// } + /// } + /// + /// RustApi::new() + /// .request_interceptor(AddRequestId) + /// .route("/", get(handler)) + /// .run("127.0.0.1:8080") + /// .await + /// ``` + pub fn request_interceptor(mut self, interceptor: I) -> Self + where + I: RequestInterceptor, + { + self.interceptors.add_request_interceptor(interceptor); + self + } + + /// Add a response interceptor to the application + /// + /// Response interceptors are executed in reverse registration order after the route + /// handler completes. Each interceptor can modify the response before passing it + /// to the previous interceptor or client. + /// + /// # Example + /// + /// ```rust,ignore + /// use rustapi_core::{RustApi, interceptor::ResponseInterceptor, Response}; + /// + /// #[derive(Clone)] + /// struct AddServerHeader; + /// + /// impl ResponseInterceptor for AddServerHeader { + /// fn intercept(&self, mut res: Response) -> Response { + /// res.headers_mut().insert("X-Server", "RustAPI".parse().unwrap()); + /// res + /// } + /// + /// fn clone_box(&self) -> Box { + /// Box::new(self.clone()) + /// } + /// } + /// + /// RustApi::new() + /// .response_interceptor(AddServerHeader) + /// .route("/", get(handler)) + /// .run("127.0.0.1:8080") + /// .await + /// ``` + pub fn response_interceptor(mut self, interceptor: I) -> Self + where + I: ResponseInterceptor, + { + self.interceptors.add_response_interceptor(interceptor); + self + } + /// Add application state /// /// State is shared across all handlers and can be extracted using `State`. @@ -267,10 +348,7 @@ impl RustApi { } #[cfg(feature = "tracing")] - let route_count: usize = by_path - .values() - .map(|mr| mr.allowed_methods().len()) - .sum(); + let route_count: usize = by_path.values().map(|mr| mr.allowed_methods().len()).sum(); #[cfg(feature = "tracing")] let path_count = by_path.len(); @@ -567,15 +645,23 @@ impl RustApi { /// - `{path}` - Swagger UI interface /// - `{path}/openapi.json` - OpenAPI JSON specification /// + /// **Important:** Call `.docs()` AFTER registering all routes. The OpenAPI + /// specification is captured at the time `.docs()` is called, so routes + /// added afterwards will not appear in the documentation. + /// /// # Example /// /// ```text /// RustApi::new() - /// .route("/users", get(list_users)) - /// .docs("/docs") // Swagger UI at /docs, spec at /docs/openapi.json + /// .route("/users", get(list_users)) // Add routes first + /// .route("/posts", get(list_posts)) // Add more routes + /// .docs("/docs") // Then enable docs - captures all routes above /// .run("127.0.0.1:8080") /// .await /// ``` + /// + /// For `RustApi::auto()`, routes are collected before `.docs()` is called, + /// so this is handled automatically. #[cfg(feature = "swagger-ui")] pub fn docs(self, path: &str) -> Self { let title = self.openapi_spec.info.title.clone(); @@ -783,7 +869,7 @@ impl RustApi { self.layers.prepend(Box::new(BodyLimitLayer::new(limit))); } - let server = Server::new(self.router, self.layers); + let server = Server::new(self.router, self.layers, self.interceptors); server.run(addr).await } @@ -796,6 +882,11 @@ impl RustApi { pub fn layers(&self) -> &LayerStack { &self.layers } + + /// Get the interceptor chain (for testing) + pub fn interceptors(&self) -> &InterceptorChain { + &self.interceptors + } } fn add_path_params_to_operation(path: &str, op: &mut rustapi_openapi::Operation) { @@ -839,16 +930,66 @@ fn add_path_params_to_operation(path: &str, op: &mut rustapi_openapi::Operation) continue; } + // Infer schema type based on common naming patterns + let schema = infer_path_param_schema(&name); + op_params.push(rustapi_openapi::Parameter { name, location: "path".to_string(), required: true, description: None, - schema: rustapi_openapi::SchemaRef::Inline(serde_json::json!({ "type": "string" })), + schema, }); } } +/// Infer the OpenAPI schema type for a path parameter based on naming conventions. +/// +/// Common patterns: +/// - `*_id`, `*Id`, `id` → integer (but NOT *uuid) +/// - `*_count`, `*_num`, `page`, `limit`, `offset` → integer +/// - `*_uuid`, `uuid` → string with uuid format +/// - `year`, `month`, `day` → integer +/// - Everything else → string +fn infer_path_param_schema(name: &str) -> rustapi_openapi::SchemaRef { + let lower = name.to_lowercase(); + + // UUID patterns (check first to avoid false positive from "id" suffix) + let is_uuid = lower == "uuid" || lower.ends_with("_uuid") || lower.ends_with("uuid"); + + if is_uuid { + return rustapi_openapi::SchemaRef::Inline(serde_json::json!({ + "type": "string", + "format": "uuid" + })); + } + + // Integer patterns + let is_integer = lower == "id" + || lower.ends_with("_id") + || (lower.ends_with("id") && lower.len() > 2) // e.g., "userId", but not "uuid" + || lower == "page" + || lower == "limit" + || lower == "offset" + || lower == "count" + || lower.ends_with("_count") + || lower.ends_with("_num") + || lower == "year" + || lower == "month" + || lower == "day" + || lower == "index" + || lower == "position"; + + if is_integer { + rustapi_openapi::SchemaRef::Inline(serde_json::json!({ + "type": "integer", + "format": "int64" + })) + } else { + rustapi_openapi::SchemaRef::Inline(serde_json::json!({ "type": "string" })) + } +} + /// Normalize a prefix for OpenAPI paths. /// /// Ensures the prefix: @@ -913,6 +1054,102 @@ mod tests { assert_eq!(value, 123u32); } + #[test] + fn test_path_param_type_inference_integer() { + use super::infer_path_param_schema; + + // Test common integer patterns + let int_params = [ + "id", + "user_id", + "userId", + "postId", + "page", + "limit", + "offset", + "count", + "item_count", + "year", + "month", + "day", + "index", + "position", + ]; + + for name in int_params { + let schema = infer_path_param_schema(name); + match schema { + rustapi_openapi::SchemaRef::Inline(v) => { + assert_eq!( + v.get("type").and_then(|v| v.as_str()), + Some("integer"), + "Expected '{}' to be inferred as integer", + name + ); + } + _ => panic!("Expected inline schema for '{}'", name), + } + } + } + + #[test] + fn test_path_param_type_inference_uuid() { + use super::infer_path_param_schema; + + // Test UUID patterns + let uuid_params = ["uuid", "user_uuid", "sessionUuid"]; + + for name in uuid_params { + let schema = infer_path_param_schema(name); + match schema { + rustapi_openapi::SchemaRef::Inline(v) => { + assert_eq!( + v.get("type").and_then(|v| v.as_str()), + Some("string"), + "Expected '{}' to be inferred as string", + name + ); + assert_eq!( + v.get("format").and_then(|v| v.as_str()), + Some("uuid"), + "Expected '{}' to have uuid format", + name + ); + } + _ => panic!("Expected inline schema for '{}'", name), + } + } + } + + #[test] + fn test_path_param_type_inference_string() { + use super::infer_path_param_schema; + + // Test string (default) patterns + let string_params = ["name", "slug", "code", "token", "username"]; + + for name in string_params { + let schema = infer_path_param_schema(name); + match schema { + rustapi_openapi::SchemaRef::Inline(v) => { + assert_eq!( + v.get("type").and_then(|v| v.as_str()), + Some("string"), + "Expected '{}' to be inferred as string", + name + ); + assert!( + v.get("format").is_none() + || v.get("format").and_then(|v| v.as_str()) != Some("uuid"), + "Expected '{}' to NOT have uuid format", + name + ); + } + _ => panic!("Expected inline schema for '{}'", name), + } + } + } + // **Feature: router-nesting, Property 11: OpenAPI Integration** // // For any nested routes with OpenAPI operations, the operations should appear diff --git a/crates/rustapi-core/src/extract.rs b/crates/rustapi-core/src/extract.rs index 3f0a2444..424f9c41 100644 --- a/crates/rustapi-core/src/extract.rs +++ b/crates/rustapi-core/src/extract.rs @@ -792,7 +792,7 @@ impl OperationModifier for Path { fn update_operation(_op: &mut Operation) { // Path parameters are automatically documented by add_path_params_to_operation // in app.rs based on the route pattern. No additional implementation needed here. - // + // // For typed path params, the schema type defaults to "string" but will be // inferred from the actual type T when more sophisticated type introspection // is implemented. @@ -904,7 +904,6 @@ mod tests { use http::{Extensions, Method}; use proptest::prelude::*; use proptest::test_runner::TestCaseError; - use std::collections::HashMap; use std::sync::Arc; /// Create a test request with the given method, path, and headers diff --git a/crates/rustapi-core/src/health.rs b/crates/rustapi-core/src/health.rs new file mode 100644 index 00000000..e273076f --- /dev/null +++ b/crates/rustapi-core/src/health.rs @@ -0,0 +1,284 @@ +//! Health check system for monitoring application health +//! +//! This module provides a flexible health check system for monitoring +//! the health and readiness of your application and its dependencies. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::health::{HealthCheck, HealthCheckBuilder, HealthStatus}; +//! +//! #[tokio::main] +//! async fn main() { +//! let health = HealthCheckBuilder::new(true) +//! .add_check("database", || async { +//! // Check database connection +//! HealthStatus::healthy() +//! }) +//! .add_check("redis", || async { +//! // Check Redis connection +//! HealthStatus::healthy() +//! }) +//! .build(); +//! +//! // Use health.execute().await to get results +//! } +//! ``` + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +/// Health status of a component +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum HealthStatus { + /// Component is healthy + #[serde(rename = "healthy")] + Healthy, + /// Component is unhealthy + #[serde(rename = "unhealthy")] + Unhealthy { reason: String }, + /// Component is degraded but functional + #[serde(rename = "degraded")] + Degraded { reason: String }, +} + +impl HealthStatus { + /// Create a healthy status + pub fn healthy() -> Self { + Self::Healthy + } + + /// Create an unhealthy status with a reason + pub fn unhealthy(reason: impl Into) -> Self { + Self::Unhealthy { + reason: reason.into(), + } + } + + /// Create a degraded status with a reason + pub fn degraded(reason: impl Into) -> Self { + Self::Degraded { + reason: reason.into(), + } + } + + /// Check if the status is healthy + pub fn is_healthy(&self) -> bool { + matches!(self, Self::Healthy) + } + + /// Check if the status is unhealthy + pub fn is_unhealthy(&self) -> bool { + matches!(self, Self::Unhealthy { .. }) + } + + /// Check if the status is degraded + pub fn is_degraded(&self) -> bool { + matches!(self, Self::Degraded { .. }) + } +} + +/// Overall health check result +#[derive(Debug, Serialize, Deserialize)] +pub struct HealthCheckResult { + /// Overall status + pub status: HealthStatus, + /// Individual component checks + pub checks: HashMap, + /// Application version (if provided) + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, + /// Timestamp of check (ISO 8601) + pub timestamp: String, +} + +/// Type alias for async health check functions +pub type HealthCheckFn = + Arc Pin + Send>> + Send + Sync>; + +/// Health check configuration +#[derive(Clone)] +pub struct HealthCheck { + checks: HashMap, + version: Option, +} + +impl HealthCheck { + /// Execute all health checks + pub async fn execute(&self) -> HealthCheckResult { + let mut results = HashMap::new(); + let mut overall_status = HealthStatus::Healthy; + + for (name, check) in &self.checks { + let status = check().await; + + // Determine overall status + match &status { + HealthStatus::Unhealthy { .. } => { + overall_status = HealthStatus::unhealthy("one or more checks failed"); + } + HealthStatus::Degraded { .. } => { + if overall_status.is_healthy() { + overall_status = HealthStatus::degraded("one or more checks degraded"); + } + } + _ => {} + } + + results.insert(name.clone(), status); + } + + // Use UTC timestamp formatted as ISO 8601 + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| { + let secs = d.as_secs(); + let nanos = d.subsec_nanos(); + format!("{}.{:09}Z", secs, nanos) + }) + .unwrap_or_else(|_| "unknown".to_string()); + + HealthCheckResult { + status: overall_status, + checks: results, + version: self.version.clone(), + timestamp, + } + } +} + +/// Builder for health check configuration +pub struct HealthCheckBuilder { + checks: HashMap, + version: Option, +} + +impl HealthCheckBuilder { + /// Create a new health check builder + /// + /// # Arguments + /// + /// * `include_default` - Whether to include a default "self" check that always returns healthy + pub fn new(include_default: bool) -> Self { + let mut checks = HashMap::new(); + + if include_default { + let check: HealthCheckFn = Arc::new(|| Box::pin(async { HealthStatus::healthy() })); + checks.insert("self".to_string(), check); + } + + Self { + checks, + version: None, + } + } + + /// Add a health check + /// + /// # Example + /// + /// ```rust + /// use rustapi_core::health::{HealthCheckBuilder, HealthStatus}; + /// + /// let health = HealthCheckBuilder::new(false) + /// .add_check("database", || async { + /// // Simulate database check + /// HealthStatus::healthy() + /// }) + /// .build(); + /// ``` + pub fn add_check(mut self, name: impl Into, check: F) -> Self + where + F: Fn() -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + let check_fn = Arc::new(move || { + Box::pin(check()) as Pin + Send>> + }); + self.checks.insert(name.into(), check_fn); + self + } + + /// Set the application version + pub fn version(mut self, version: impl Into) -> Self { + self.version = Some(version.into()); + self + } + + /// Build the health check + pub fn build(self) -> HealthCheck { + HealthCheck { + checks: self.checks, + version: self.version, + } + } +} + +impl Default for HealthCheckBuilder { + fn default() -> Self { + Self::new(true) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn health_check_all_healthy() { + let health = HealthCheckBuilder::new(false) + .add_check("db", || async { HealthStatus::healthy() }) + .add_check("cache", || async { HealthStatus::healthy() }) + .version("1.0.0") + .build(); + + let result = health.execute().await; + + assert!(result.status.is_healthy()); + assert_eq!(result.checks.len(), 2); + assert_eq!(result.version, Some("1.0.0".to_string())); + } + + #[tokio::test] + async fn health_check_one_unhealthy() { + let health = HealthCheckBuilder::new(false) + .add_check("db", || async { HealthStatus::healthy() }) + .add_check("cache", || async { + HealthStatus::unhealthy("connection failed") + }) + .build(); + + let result = health.execute().await; + + assert!(result.status.is_unhealthy()); + assert_eq!(result.checks.len(), 2); + } + + #[tokio::test] + async fn health_check_one_degraded() { + let health = HealthCheckBuilder::new(false) + .add_check("db", || async { HealthStatus::healthy() }) + .add_check("cache", || async { HealthStatus::degraded("high latency") }) + .build(); + + let result = health.execute().await; + + assert!(result.status.is_degraded()); + assert_eq!(result.checks.len(), 2); + } + + #[tokio::test] + async fn health_check_with_default() { + let health = HealthCheckBuilder::new(true).build(); + + let result = health.execute().await; + + assert!(result.status.is_healthy()); + assert_eq!(result.checks.len(), 1); + assert!(result.checks.contains_key("self")); + } +} diff --git a/crates/rustapi-core/src/interceptor.rs b/crates/rustapi-core/src/interceptor.rs new file mode 100644 index 00000000..387d85d3 --- /dev/null +++ b/crates/rustapi-core/src/interceptor.rs @@ -0,0 +1,523 @@ +//! Request/Response Interceptor System for RustAPI +//! +//! This module provides interceptors that can modify requests before handlers +//! and responses after handlers, without the complexity of Tower layers. +//! +//! # Overview +//! +//! Interceptors provide a simpler alternative to middleware for common use cases: +//! - Adding headers to all requests/responses +//! - Logging and metrics +//! - Request/response transformation +//! +//! # Execution Order +//! +//! Request interceptors execute in registration order (1 → 2 → 3 → Handler). +//! Response interceptors execute in reverse order (Handler → 3 → 2 → 1). +//! +//! # Example +//! +//! ```rust,ignore +//! use rustapi_core::{RustApi, interceptor::{RequestInterceptor, ResponseInterceptor}}; +//! +//! struct AddRequestId; +//! +//! impl RequestInterceptor for AddRequestId { +//! fn intercept(&self, mut req: Request) -> Request { +//! req.extensions_mut().insert(uuid::Uuid::new_v4()); +//! req +//! } +//! } +//! +//! struct AddServerHeader; +//! +//! impl ResponseInterceptor for AddServerHeader { +//! fn intercept(&self, mut res: Response) -> Response { +//! res.headers_mut().insert("X-Server", "RustAPI".parse().unwrap()); +//! res +//! } +//! } +//! +//! RustApi::new() +//! .request_interceptor(AddRequestId) +//! .response_interceptor(AddServerHeader) +//! .route("/", get(handler)) +//! .run("127.0.0.1:8080") +//! .await +//! ``` + +use crate::request::Request; +use crate::response::Response; + +/// Trait for intercepting and modifying requests before they reach handlers. +/// +/// Request interceptors are executed in the order they are registered. +/// Each interceptor receives the request, can modify it, and returns the +/// (potentially modified) request for the next interceptor or handler. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::interceptor::RequestInterceptor; +/// use rustapi_core::Request; +/// +/// struct LoggingInterceptor; +/// +/// impl RequestInterceptor for LoggingInterceptor { +/// fn intercept(&self, req: Request) -> Request { +/// println!("Request: {} {}", req.method(), req.path()); +/// req +/// } +/// } +/// ``` +pub trait RequestInterceptor: Send + Sync + 'static { + /// Intercept and optionally modify the request. + /// + /// The returned request will be passed to the next interceptor or handler. + fn intercept(&self, request: Request) -> Request; + + /// Clone this interceptor into a boxed trait object. + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Self { + self.clone_box() + } +} + +/// Trait for intercepting and modifying responses after handlers complete. +/// +/// Response interceptors are executed in reverse registration order. +/// Each interceptor receives the response, can modify it, and returns the +/// (potentially modified) response for the previous interceptor or client. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_core::interceptor::ResponseInterceptor; +/// use rustapi_core::Response; +/// +/// struct AddCorsHeaders; +/// +/// impl ResponseInterceptor for AddCorsHeaders { +/// fn intercept(&self, mut res: Response) -> Response { +/// res.headers_mut().insert( +/// "Access-Control-Allow-Origin", +/// "*".parse().unwrap() +/// ); +/// res +/// } +/// } +/// ``` +pub trait ResponseInterceptor: Send + Sync + 'static { + /// Intercept and optionally modify the response. + /// + /// The returned response will be passed to the previous interceptor or client. + fn intercept(&self, response: Response) -> Response; + + /// Clone this interceptor into a boxed trait object. + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Self { + self.clone_box() + } +} + +/// Chain of request and response interceptors. +/// +/// Manages the execution of multiple interceptors in the correct order: +/// - Request interceptors: executed in registration order (first registered = first executed) +/// - Response interceptors: executed in reverse order (last registered = first executed) +#[derive(Clone, Default)] +pub struct InterceptorChain { + request_interceptors: Vec>, + response_interceptors: Vec>, +} + +impl InterceptorChain { + /// Create a new empty interceptor chain. + pub fn new() -> Self { + Self { + request_interceptors: Vec::new(), + response_interceptors: Vec::new(), + } + } + + /// Add a request interceptor to the chain. + /// + /// Interceptors are executed in the order they are added. + pub fn add_request_interceptor(&mut self, interceptor: I) { + self.request_interceptors.push(Box::new(interceptor)); + } + + /// Add a response interceptor to the chain. + /// + /// Interceptors are executed in reverse order (last added = first executed after handler). + pub fn add_response_interceptor(&mut self, interceptor: I) { + self.response_interceptors.push(Box::new(interceptor)); + } + + /// Get the number of request interceptors. + pub fn request_interceptor_count(&self) -> usize { + self.request_interceptors.len() + } + + /// Get the number of response interceptors. + pub fn response_interceptor_count(&self) -> usize { + self.response_interceptors.len() + } + + /// Check if the chain has any interceptors. + pub fn is_empty(&self) -> bool { + self.request_interceptors.is_empty() && self.response_interceptors.is_empty() + } + + /// Execute all request interceptors on the given request. + /// + /// Interceptors are executed in registration order. + /// Each interceptor receives the output of the previous one. + pub fn intercept_request(&self, mut request: Request) -> Request { + for interceptor in &self.request_interceptors { + request = interceptor.intercept(request); + } + request + } + + /// Execute all response interceptors on the given response. + /// + /// Interceptors are executed in reverse registration order. + /// Each interceptor receives the output of the previous one. + pub fn intercept_response(&self, mut response: Response) -> Response { + // Execute in reverse order (last registered = first to process response) + for interceptor in self.response_interceptors.iter().rev() { + response = interceptor.intercept(response); + } + response + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::path_params::PathParams; + use bytes::Bytes; + use http::{Extensions, Method, StatusCode}; + use http_body_util::Full; + use proptest::prelude::*; + use std::sync::Arc; + + /// Create a test request with the given method and path + fn create_test_request(method: Method, path: &str) -> Request { + let uri: http::Uri = path.parse().unwrap(); + let builder = http::Request::builder().method(method).uri(uri); + + let req = builder.body(()).unwrap(); + let (parts, _) = req.into_parts(); + + Request::new( + parts, + Bytes::new(), + Arc::new(Extensions::new()), + PathParams::new(), + ) + } + + /// Create a test response with the given status + fn create_test_response(status: StatusCode) -> Response { + http::Response::builder() + .status(status) + .body(Full::new(Bytes::from("test"))) + .unwrap() + } + + /// A request interceptor that adds a header tracking its ID + #[derive(Clone)] + struct TrackingRequestInterceptor { + id: usize, + order: Arc>>, + } + + impl TrackingRequestInterceptor { + fn new(id: usize, order: Arc>>) -> Self { + Self { id, order } + } + } + + impl RequestInterceptor for TrackingRequestInterceptor { + fn intercept(&self, request: Request) -> Request { + self.order.lock().unwrap().push(self.id); + request + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + } + + /// A response interceptor that adds a header tracking its ID + #[derive(Clone)] + struct TrackingResponseInterceptor { + id: usize, + order: Arc>>, + } + + impl TrackingResponseInterceptor { + fn new(id: usize, order: Arc>>) -> Self { + Self { id, order } + } + } + + impl ResponseInterceptor for TrackingResponseInterceptor { + fn intercept(&self, response: Response) -> Response { + self.order.lock().unwrap().push(self.id); + response + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + } + + // **Feature: v1-features-roadmap, Property 6: Interceptor execution order** + // + // For any set of N registered interceptors, request interceptors SHALL execute + // in registration order (1→N) and response interceptors SHALL execute in + // reverse order (N→1). + // + // **Validates: Requirements 2.1, 2.2, 2.3** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_interceptor_execution_order(num_interceptors in 1usize..10usize) { + let request_order = Arc::new(std::sync::Mutex::new(Vec::new())); + let response_order = Arc::new(std::sync::Mutex::new(Vec::new())); + + let mut chain = InterceptorChain::new(); + + // Add interceptors in order 0, 1, 2, ..., n-1 + for i in 0..num_interceptors { + chain.add_request_interceptor( + TrackingRequestInterceptor::new(i, request_order.clone()) + ); + chain.add_response_interceptor( + TrackingResponseInterceptor::new(i, response_order.clone()) + ); + } + + // Execute request interceptors + let request = create_test_request(Method::GET, "/test"); + let _ = chain.intercept_request(request); + + // Execute response interceptors + let response = create_test_response(StatusCode::OK); + let _ = chain.intercept_response(response); + + // Verify request interceptor order: should be 0, 1, 2, ..., n-1 + let req_order = request_order.lock().unwrap(); + prop_assert_eq!(req_order.len(), num_interceptors); + for (idx, &id) in req_order.iter().enumerate() { + prop_assert_eq!(id, idx, "Request interceptor order mismatch at index {}", idx); + } + + // Verify response interceptor order: should be n-1, n-2, ..., 1, 0 (reverse) + let res_order = response_order.lock().unwrap(); + prop_assert_eq!(res_order.len(), num_interceptors); + for (idx, &id) in res_order.iter().enumerate() { + let expected = num_interceptors - 1 - idx; + prop_assert_eq!(id, expected, "Response interceptor order mismatch at index {}", idx); + } + } + } + + /// A request interceptor that modifies a header + #[derive(Clone)] + struct HeaderModifyingRequestInterceptor { + header_name: &'static str, + header_value: String, + } + + impl HeaderModifyingRequestInterceptor { + fn new(header_name: &'static str, header_value: impl Into) -> Self { + Self { + header_name, + header_value: header_value.into(), + } + } + } + + impl RequestInterceptor for HeaderModifyingRequestInterceptor { + fn intercept(&self, mut request: Request) -> Request { + // Store the value in extensions since we can't modify headers directly + // In a real implementation, we'd need mutable header access + request.extensions_mut().insert(format!("{}:{}", self.header_name, self.header_value)); + request + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + } + + /// A response interceptor that modifies a header + #[derive(Clone)] + struct HeaderModifyingResponseInterceptor { + header_name: &'static str, + header_value: String, + } + + impl HeaderModifyingResponseInterceptor { + fn new(header_name: &'static str, header_value: impl Into) -> Self { + Self { + header_name, + header_value: header_value.into(), + } + } + } + + impl ResponseInterceptor for HeaderModifyingResponseInterceptor { + fn intercept(&self, mut response: Response) -> Response { + if let Ok(value) = self.header_value.parse() { + response.headers_mut().insert(self.header_name, value); + } + response + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + } + + // **Feature: v1-features-roadmap, Property 7: Interceptor modification propagation** + // + // For any modification made by an interceptor, subsequent interceptors and handlers + // SHALL receive the modified request/response. + // + // **Validates: Requirements 2.4, 2.5** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_interceptor_modification_propagation( + num_interceptors in 1usize..5usize, + header_values in prop::collection::vec("[a-zA-Z0-9]{1,10}", 1..5usize), + ) { + let mut chain = InterceptorChain::new(); + + // Add response interceptors that each add a unique header + for (i, value) in header_values.iter().enumerate().take(num_interceptors) { + let header_name = Box::leak(format!("x-test-{}", i).into_boxed_str()); + chain.add_response_interceptor( + HeaderModifyingResponseInterceptor::new(header_name, value.clone()) + ); + } + + // Execute response interceptors + let response = create_test_response(StatusCode::OK); + let modified_response = chain.intercept_response(response); + + // Verify all headers were added (modifications propagated) + for (i, value) in header_values.iter().enumerate().take(num_interceptors) { + let header_name = format!("x-test-{}", i); + let header_value = modified_response.headers().get(&header_name); + prop_assert!(header_value.is_some(), "Header {} should be present", header_name); + prop_assert_eq!( + header_value.unwrap().to_str().unwrap(), + value, + "Header {} should have value {}", header_name, value + ); + } + } + } + + #[test] + fn test_empty_chain() { + let chain = InterceptorChain::new(); + assert!(chain.is_empty()); + assert_eq!(chain.request_interceptor_count(), 0); + assert_eq!(chain.response_interceptor_count(), 0); + + // Should pass through unchanged + let request = create_test_request(Method::GET, "/test"); + let _ = chain.intercept_request(request); + + let response = create_test_response(StatusCode::OK); + let result = chain.intercept_response(response); + assert_eq!(result.status(), StatusCode::OK); + } + + #[test] + fn test_single_request_interceptor() { + let order = Arc::new(std::sync::Mutex::new(Vec::new())); + let mut chain = InterceptorChain::new(); + chain.add_request_interceptor(TrackingRequestInterceptor::new(42, order.clone())); + + assert!(!chain.is_empty()); + assert_eq!(chain.request_interceptor_count(), 1); + + let request = create_test_request(Method::GET, "/test"); + let _ = chain.intercept_request(request); + + let recorded = order.lock().unwrap(); + assert_eq!(recorded.len(), 1); + assert_eq!(recorded[0], 42); + } + + #[test] + fn test_single_response_interceptor() { + let order = Arc::new(std::sync::Mutex::new(Vec::new())); + let mut chain = InterceptorChain::new(); + chain.add_response_interceptor(TrackingResponseInterceptor::new(42, order.clone())); + + assert!(!chain.is_empty()); + assert_eq!(chain.response_interceptor_count(), 1); + + let response = create_test_response(StatusCode::OK); + let _ = chain.intercept_response(response); + + let recorded = order.lock().unwrap(); + assert_eq!(recorded.len(), 1); + assert_eq!(recorded[0], 42); + } + + #[test] + fn test_response_header_modification() { + let mut chain = InterceptorChain::new(); + chain.add_response_interceptor( + HeaderModifyingResponseInterceptor::new("x-custom", "value1") + ); + chain.add_response_interceptor( + HeaderModifyingResponseInterceptor::new("x-another", "value2") + ); + + let response = create_test_response(StatusCode::OK); + let modified = chain.intercept_response(response); + + // Both headers should be present + assert_eq!( + modified.headers().get("x-custom").unwrap().to_str().unwrap(), + "value1" + ); + assert_eq!( + modified.headers().get("x-another").unwrap().to_str().unwrap(), + "value2" + ); + } + + #[test] + fn test_chain_clone() { + let order = Arc::new(std::sync::Mutex::new(Vec::new())); + let mut chain = InterceptorChain::new(); + chain.add_request_interceptor(TrackingRequestInterceptor::new(1, order.clone())); + chain.add_response_interceptor(TrackingResponseInterceptor::new(2, order.clone())); + + // Clone the chain + let cloned = chain.clone(); + + assert_eq!(cloned.request_interceptor_count(), 1); + assert_eq!(cloned.response_interceptor_count(), 1); + } +} diff --git a/crates/rustapi-core/src/lib.rs b/crates/rustapi-core/src/lib.rs index a2172eb1..e58a981e 100644 --- a/crates/rustapi-core/src/lib.rs +++ b/crates/rustapi-core/src/lib.rs @@ -56,6 +56,8 @@ pub use auto_schema::apply_auto_schemas; mod error; mod extract; mod handler; +pub mod health; +pub mod interceptor; pub mod json; pub mod middleware; pub mod multipart; @@ -98,6 +100,8 @@ pub use handler::{ delete_route, get_route, patch_route, post_route, put_route, Handler, HandlerService, Route, RouteHandler, }; +pub use health::{HealthCheck, HealthCheckBuilder, HealthCheckResult, HealthStatus}; +pub use interceptor::{InterceptorChain, RequestInterceptor, ResponseInterceptor}; #[cfg(feature = "compression")] pub use middleware::CompressionLayer; pub use middleware::{BodyLimitLayer, RequestId, RequestIdLayer, TracingLayer, DEFAULT_BODY_LIMIT}; diff --git a/crates/rustapi-core/src/middleware/metrics.rs b/crates/rustapi-core/src/middleware/metrics.rs index d6dcd909..0957877d 100644 --- a/crates/rustapi-core/src/middleware/metrics.rs +++ b/crates/rustapi-core/src/middleware/metrics.rs @@ -289,7 +289,7 @@ mod tests { parts, Bytes::new(), Arc::new(Extensions::new()), - HashMap::new(), + HashMap::new().into(), ) } diff --git a/crates/rustapi-core/src/request.rs b/crates/rustapi-core/src/request.rs index 918ae332..6e278433 100644 --- a/crates/rustapi-core/src/request.rs +++ b/crates/rustapi-core/src/request.rs @@ -39,8 +39,8 @@ //! // Subsequent calls return None //! ``` -use bytes::Bytes; use crate::path_params::PathParams; +use bytes::Bytes; use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version}; use std::sync::Arc; @@ -143,6 +143,33 @@ impl Request { path_params: PathParams::new(), } } + /// Try to clone the request. + /// + /// This creates a deep copy of the request, including headers, body (if present), + /// path params, and shared state. + /// + /// Note: Request extensions stored in `parts` are NOT cloned because `http::Extensions` + /// does not support cloning. However, the shared state (`Arc`) IS preserved. + pub fn try_clone(&self) -> Option { + let mut builder = http::Request::builder() + .method(self.method().clone()) + .uri(self.uri().clone()) + .version(self.version()); + + if let Some(headers) = builder.headers_mut() { + *headers = self.headers().clone(); + } + + let req = builder.body(()).ok()?; + let (parts, _) = req.into_parts(); + + Some(Self { + parts, + body: self.body.clone(), + state: self.state.clone(), + path_params: self.path_params.clone(), + }) + } } impl std::fmt::Debug for Request { diff --git a/crates/rustapi-core/src/router.rs b/crates/rustapi-core/src/router.rs index 14ee760d..14a270a7 100644 --- a/crates/rustapi-core/src/router.rs +++ b/crates/rustapi-core/src/router.rs @@ -189,6 +189,60 @@ impl MethodRouter { self.handlers.insert(method.clone(), handler); self.operations.insert(method, operation); } + /// Add a GET handler + pub fn get(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let mut op = Operation::new(); + H::update_operation(&mut op); + self.on(Method::GET, into_boxed_handler(handler), op) + } + + /// Add a POST handler + pub fn post(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let mut op = Operation::new(); + H::update_operation(&mut op); + self.on(Method::POST, into_boxed_handler(handler), op) + } + + /// Add a PUT handler + pub fn put(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let mut op = Operation::new(); + H::update_operation(&mut op); + self.on(Method::PUT, into_boxed_handler(handler), op) + } + + /// Add a PATCH handler + pub fn patch(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let mut op = Operation::new(); + H::update_operation(&mut op); + self.on(Method::PATCH, into_boxed_handler(handler), op) + } + + /// Add a DELETE handler + pub fn delete(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let mut op = Operation::new(); + H::update_operation(&mut op); + self.on(Method::DELETE, into_boxed_handler(handler), op) + } } impl Default for MethodRouter { diff --git a/crates/rustapi-core/src/server.rs b/crates/rustapi-core/src/server.rs index 3b838e9c..c254b023 100644 --- a/crates/rustapi-core/src/server.rs +++ b/crates/rustapi-core/src/server.rs @@ -1,6 +1,7 @@ //! HTTP server implementation use crate::error::ApiError; +use crate::interceptor::InterceptorChain; use crate::middleware::{BoxedNext, LayerStack}; use crate::request::Request; use crate::response::IntoResponse; @@ -22,13 +23,15 @@ use tracing::{error, info}; pub(crate) struct Server { router: Arc, layers: Arc, + interceptors: Arc, } impl Server { - pub fn new(router: Router, layers: LayerStack) -> Self { + pub fn new(router: Router, layers: LayerStack, interceptors: InterceptorChain) -> Self { Self { router: Arc::new(router), layers: Arc::new(layers), + interceptors: Arc::new(interceptors), } } @@ -44,13 +47,15 @@ impl Server { let io = TokioIo::new(stream); let router = self.router.clone(); let layers = self.layers.clone(); + let interceptors = self.interceptors.clone(); tokio::spawn(async move { let service = service_fn(move |req: hyper::Request| { let router = router.clone(); let layers = layers.clone(); + let interceptors = interceptors.clone(); async move { - let response = handle_request(router, layers, req, remote_addr).await; + let response = handle_request(router, layers, interceptors, req, remote_addr).await; Ok::<_, Infallible>(response) } }); @@ -67,6 +72,7 @@ impl Server { async fn handle_request( router: Arc, layers: Arc, + interceptors: Arc, req: hyper::Request, _remote_addr: SocketAddr, ) -> hyper::Response> { @@ -114,6 +120,9 @@ async fn handle_request( // Build Request let request = Request::new(parts, body_bytes, router.state_ref(), params); + // Apply request interceptors (in registration order) + let request = interceptors.intercept_request(request); + // Create the final handler as a BoxedNext let final_handler: BoxedNext = Arc::new(move |req: Request| { let handler = handler.clone(); @@ -126,6 +135,9 @@ async fn handle_request( // Execute through middleware stack let response = layers.execute(request, final_handler).await; + // Apply response interceptors (in reverse registration order) + let response = interceptors.intercept_response(response); + log_request(&method, &path, response.status(), start); response } diff --git a/crates/rustapi-extras/Cargo.toml b/crates/rustapi-extras/Cargo.toml index 01a0c0d2..77c0a011 100644 --- a/crates/rustapi-extras/Cargo.toml +++ b/crates/rustapi-extras/Cargo.toml @@ -40,6 +40,10 @@ dashmap = { version = "6.0", optional = true } # SQLx (feature-gated) sqlx = { version = "0.8", optional = true, default-features = false } +# Diesel (feature-gated) +diesel = { version = "2.2", optional = true, default-features = false } +r2d2 = { version = "0.8", optional = true } + # Configuration (feature-gated) dotenvy = { version = "0.15", optional = true } envy = { version = "0.4", optional = true } @@ -53,6 +57,20 @@ urlencoding = { version = "2.1", optional = true } # HTTP client for webhook exporter (feature-gated) reqwest = { version = "0.12", optional = true, default-features = false, features = ["json", "rustls-tls"] } +# OpenTelemetry (feature-gated) +opentelemetry = { version = "0.22", optional = true } +opentelemetry_sdk = { version = "0.22", optional = true, features = ["rt-tokio"] } +opentelemetry-otlp = { version = "0.15", optional = true } +opentelemetry-semantic-conventions = { version = "0.14", optional = true } +tracing-opentelemetry = { version = "0.23", optional = true } + +# CSRF (feature-gated) +rand = { version = "0.8", optional = true } +base64 = { version = "0.22", optional = true } + +# OAuth2 (feature-gated) +sha2 = { version = "0.10", optional = true } + [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } proptest = "1.4" @@ -63,18 +81,49 @@ serial_test = "3.2" [features] default = [] -# Individual features +# Individual features jwt = ["dep:jsonwebtoken"] cors = [] rate-limit = ["dep:dashmap"] config = ["dep:dotenvy", "dep:envy"] cookies = ["dep:cookie"] sqlx = ["dep:sqlx"] +sqlx-postgres = ["sqlx", "sqlx/postgres", "sqlx/runtime-tokio"] +sqlx-mysql = ["sqlx", "sqlx/mysql", "sqlx/runtime-tokio"] +sqlx-sqlite = ["sqlx", "sqlx/sqlite", "sqlx/runtime-tokio"] +diesel = ["dep:diesel", "dep:r2d2"] +diesel-postgres = ["diesel", "diesel/postgres"] +diesel-mysql = ["diesel", "diesel/mysql"] +diesel-sqlite = ["diesel", "diesel/sqlite"] insight = ["dep:dashmap", "dep:urlencoding"] webhook = ["insight", "dep:reqwest"] +# Phase 11 features +timeout = [] +guard = ["jwt"] # Guard requires JWT for auth +logging = [] +circuit-breaker = [] +retry = [] +security-headers = [] +api-key = [] +cache = ["dep:dashmap"] +dedup = ["dep:dashmap"] +sanitization = [] + +# Phase 5: Observability features +otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp", "dep:opentelemetry-semantic-conventions", "dep:tracing-opentelemetry"] +structured-logging = [] + +# Phase 6: Security features +csrf = ["dep:cookie", "dep:rand", "dep:base64"] +oauth2-client = ["dep:sha2", "dep:rand", "dep:base64", "dep:reqwest", "dep:urlencoding"] +audit = [] + # Meta feature that enables all security features extras = ["jwt", "cors", "rate-limit"] -# Full feature set -full = ["extras", "config", "cookies", "sqlx", "insight", "webhook"] +# Observability meta feature +observability = ["otel", "structured-logging"] + +# Full feature set (retry temporarily disabled) +full = ["extras", "config", "cookies", "sqlx", "insight", "webhook", "timeout", "guard", "logging", "circuit-breaker", "security-headers", "api-key", "cache", "dedup", "sanitization", "retry", "otel", "structured-logging", "csrf", "oauth2-client", "audit"] diff --git a/crates/rustapi-extras/src/api_key.rs b/crates/rustapi-extras/src/api_key.rs new file mode 100644 index 00000000..9691b9a8 --- /dev/null +++ b/crates/rustapi-extras/src/api_key.rs @@ -0,0 +1,334 @@ +//! API Key authentication middleware +//! +//! This module provides API key-based authentication for securing endpoints. +//! Supports both header-based and query parameter API keys. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::RustApi; +//! use rustapi_extras::ApiKeyLayer; +//! +//! #[tokio::main] +//! async fn main() { +//! let app = RustApi::new() +//! .layer(Box::new( +//! ApiKeyLayer::new() +//! .header("X-API-Key") +//! .add_key("your-secret-api-key-here") +//! )) +//! .run("0.0.0.0:3000") +//! .await +//! .unwrap(); +//! } +//! ``` + +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::collections::HashSet; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +/// API Key authentication configuration +#[derive(Clone)] +pub struct ApiKeyConfig { + /// Valid API keys + pub keys: Arc>, + /// Header name to check for API key + pub header_name: String, + /// Query parameter name to check for API key + pub query_param_name: Option, + /// Paths to skip API key validation + pub skip_paths: Vec, +} + +impl Default for ApiKeyConfig { + fn default() -> Self { + Self { + keys: Arc::new(HashSet::new()), + header_name: "X-API-Key".to_string(), + query_param_name: None, + skip_paths: vec!["/health".to_string(), "/docs".to_string()], + } + } +} + +/// API Key authentication middleware +#[derive(Clone)] +pub struct ApiKeyLayer { + config: ApiKeyConfig, +} + +impl ApiKeyLayer { + /// Create a new API key layer with default configuration + pub fn new() -> Self { + Self { + config: ApiKeyConfig::default(), + } + } + + /// Set the header name to check for API key + pub fn header(mut self, name: impl Into) -> Self { + self.config.header_name = name.into(); + self + } + + /// Enable query parameter API key checking + pub fn query_param(mut self, name: impl Into) -> Self { + self.config.query_param_name = Some(name.into()); + self + } + + /// Add a valid API key + pub fn add_key(mut self, key: impl Into) -> Self { + let keys = Arc::make_mut(&mut self.config.keys); + keys.insert(key.into()); + self + } + + /// Add multiple valid API keys + pub fn add_keys(mut self, keys: Vec) -> Self { + let key_set = Arc::make_mut(&mut self.config.keys); + for key in keys { + key_set.insert(key); + } + self + } + + /// Skip API key validation for specific paths + pub fn skip_path(mut self, path: impl Into) -> Self { + self.config.skip_paths.push(path.into()); + self + } +} + +impl Default for ApiKeyLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for ApiKeyLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + + Box::pin(async move { + let path = req.uri().path(); + + // Check if this path should skip validation + if config.skip_paths.iter().any(|p| path.starts_with(p)) { + return next(req).await; + } + + // Try to extract API key from header + let api_key = if let Some(header_value) = req.headers().get(&config.header_name) { + header_value.to_str().ok() + } else { + None + }; + + // If not in header, try query parameter + let api_key = if api_key.is_none() { + if let Some(query_param) = &config.query_param_name { + req.uri().query().and_then(|q| { + q.split('&').find_map(|param| { + let mut parts = param.split('='); + if parts.next()? == query_param { + parts.next() + } else { + None + } + }) + }) + } else { + None + } + } else { + api_key + }; + + // Validate API key + match api_key { + Some(key) if config.keys.contains(key) => { + // Valid API key, proceed + next(req).await + } + Some(_) => { + // Invalid API key + create_unauthorized_response("Invalid API key") + } + None => { + // Missing API key + create_unauthorized_response("Missing API key") + } + } + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +fn create_unauthorized_response(message: &str) -> Response { + let error_body = serde_json::json!({ + "error": { + "type": "unauthorized", + "message": message + } + }); + + let body = serde_json::to_vec(&error_body).unwrap_or_default(); + + http::Response::builder() + .status(401) + .header("Content-Type", "application/json") + .body(http_body_util::Full::new(bytes::Bytes::from(body))) + .unwrap() +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::sync::Arc; + + #[tokio::test] + async fn api_key_valid_header() { + let layer = ApiKeyLayer::new() + .header("X-API-Key") + .add_key("test-key-123"); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/api/users") + .header("X-API-Key", "test-key-123") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } + + #[tokio::test] + async fn api_key_invalid_header() { + let layer = ApiKeyLayer::new() + .header("X-API-Key") + .add_key("test-key-123"); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/api/users") + .header("X-API-Key", "wrong-key") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 401); + } + + #[tokio::test] + async fn api_key_missing() { + let layer = ApiKeyLayer::new() + .header("X-API-Key") + .add_key("test-key-123"); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/api/users") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 401); + } + + #[tokio::test] + async fn api_key_skips_health_check() { + let layer = ApiKeyLayer::new() + .header("X-API-Key") + .add_key("test-key-123"); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/health") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } + + #[tokio::test] + async fn api_key_query_param() { + let layer = ApiKeyLayer::new() + .query_param("api_key") + .add_key("test-key-123"); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/api/users?api_key=test-key-123") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } +} diff --git a/crates/rustapi-extras/src/cache.rs b/crates/rustapi-extras/src/cache.rs new file mode 100644 index 00000000..3c30033e --- /dev/null +++ b/crates/rustapi-extras/src/cache.rs @@ -0,0 +1,174 @@ +//! Response Caching Middleware +//! +//! Provides in-memory caching for HTTP responses. +//! Requires `cache` feature. + +use bytes::Bytes; +use dashmap::DashMap; +use http_body_util::BodyExt; +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +/// Cache configuration +#[derive(Clone)] +pub struct CacheConfig { + /// Time-to-live for cached items + pub ttl: Duration, + /// Methods to cache (e.g., GET, HEAD) + pub methods: Vec, + /// Paths to skip caching + pub skip_paths: Vec, +} + +impl Default for CacheConfig { + fn default() -> Self { + Self { + ttl: Duration::from_secs(60), + methods: vec!["GET".to_string(), "HEAD".to_string()], + skip_paths: vec!["/health".to_string()], + } + } +} + +#[derive(Clone)] +struct CachedResponse { + status: http::StatusCode, + headers: http::HeaderMap, + body: Bytes, + created_at: Instant, +} + +/// In-memory response cache layer +#[derive(Clone)] +pub struct CacheLayer { + config: CacheConfig, + store: Arc>, +} + +impl CacheLayer { + /// Create a new cache layer + pub fn new() -> Self { + Self { + config: CacheConfig::default(), + store: Arc::new(DashMap::new()), + } + } + + /// Set TTL + pub fn ttl(mut self, ttl: Duration) -> Self { + self.config.ttl = ttl; + self + } + + /// Add a method to cache + pub fn add_method(mut self, method: &str) -> Self { + if !self.config.methods.contains(&method.to_string()) { + self.config.methods.push(method.to_string()); + } + self + } +} + +impl Default for CacheLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for CacheLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + let store = self.store.clone(); + + Box::pin(async move { + let method = req.method().to_string(); + let uri = req.uri().to_string(); + + // Generate cache key + let key = format!("{}:{}", method, uri); + + // Check if cachable + if !config.methods.contains(&method) + || config.skip_paths.iter().any(|p| uri.starts_with(p)) + { + return next(req).await; + } + + // Clean expired entries (simple check on access) + if let Some(entry) = store.get(&key) { + if entry.created_at.elapsed() < config.ttl { + // Cache hit + let mut builder = http::Response::builder().status(entry.status); + for (k, v) in &entry.headers { + builder = builder.header(k, v); + } + builder = builder.header("X-Cache", "HIT"); + + return builder + .body(http_body_util::Full::new(entry.body.clone())) + .unwrap(); + } else { + // Expired + drop(entry); + store.remove(&key); + } + } + + // Cache miss: execute request + let response = next(req).await; + + // Only cache successful responses + if response.status().is_success() { + let (parts, body) = response.into_parts(); + + // Buffer the body + match body.collect().await { + Ok(bytes) => { + let bytes = bytes.to_bytes(); + + let cached = CachedResponse { + status: parts.status, + headers: parts.headers.clone(), + body: bytes.clone(), + created_at: Instant::now(), + }; + + store.insert(key, cached); + + let mut response = + http::Response::from_parts(parts, http_body_util::Full::new(bytes)); + response + .headers_mut() + .insert("X-Cache", "MISS".parse().unwrap()); + return response; + } + Err(_) => { + return http::Response::builder() + .status(500) + .body(http_body_util::Full::new(Bytes::from( + "Error buffering response for cache", + ))) + .unwrap(); + } + } + } + + // Return original if buffering failed or not successful + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/crates/rustapi-extras/src/circuit_breaker.rs b/crates/rustapi-extras/src/circuit_breaker.rs new file mode 100644 index 00000000..df40cd74 --- /dev/null +++ b/crates/rustapi-extras/src/circuit_breaker.rs @@ -0,0 +1,408 @@ +//! Circuit breaker middleware for resilient service calls +//! +//! This module implements the circuit breaker pattern to prevent cascading failures +//! and give failing services time to recover. +//! +//! # States +//! +//! - **Closed**: Normal operation, requests pass through +//! - **Open**: Too many failures, requests fail fast +//! - **HalfOpen**: Testing if service recovered +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::RustApi; +//! use rustapi_extras::CircuitBreakerLayer; +//! use std::time::Duration; +//! +//! #[tokio::main] +//! async fn main() { +//! let app = RustApi::new() +//! .layer(Box::new( +//! CircuitBreakerLayer::new() +//! .failure_threshold(5) +//! .timeout(Duration::from_secs(30)) +//! )) +//! .run("0.0.0.0:3000") +//! .await +//! .unwrap(); +//! } +//! ``` + +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +/// Circuit breaker state +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CircuitState { + /// Circuit is closed, requests pass through normally + Closed, + /// Circuit is open, requests fail fast + Open, + /// Circuit is half-open, testing if service recovered + HalfOpen, +} + +/// Circuit breaker configuration +#[derive(Clone)] +pub struct CircuitBreakerConfig { + /// Number of failures before opening the circuit + pub failure_threshold: usize, + /// Duration to wait before transitioning from Open to HalfOpen + pub timeout: Duration, + /// Number of successful requests in HalfOpen state before closing + pub success_threshold: usize, +} + +impl Default for CircuitBreakerConfig { + fn default() -> Self { + Self { + failure_threshold: 5, + timeout: Duration::from_secs(60), + success_threshold: 2, + } + } +} + +/// Circuit breaker state tracker +struct CircuitBreakerState { + state: CircuitState, + failure_count: usize, + success_count: usize, + last_failure_time: Option, + total_requests: u64, + total_failures: u64, + total_successes: u64, +} + +impl Default for CircuitBreakerState { + fn default() -> Self { + Self { + state: CircuitState::Closed, + failure_count: 0, + success_count: 0, + last_failure_time: None, + total_requests: 0, + total_failures: 0, + total_successes: 0, + } + } +} + +/// Circuit break middleware layer +#[derive(Clone)] +pub struct CircuitBreakerLayer { + config: CircuitBreakerConfig, + state: Arc>, +} + +impl CircuitBreakerLayer { + /// Create a new circuit breaker with default configuration + pub fn new() -> Self { + Self { + config: CircuitBreakerConfig::default(), + state: Arc::new(RwLock::new(CircuitBreakerState::default())), + } + } + + /// Set the failure threshold + pub fn failure_threshold(mut self, threshold: usize) -> Self { + self.config.failure_threshold = threshold; + self + } + + /// Set the timeout before transitioning to half-open + pub fn timeout(mut self, timeout: Duration) -> Self { + self.config.timeout = timeout; + self + } + + /// Set the success threshold in half-open state + pub fn success_threshold(mut self, threshold: usize) -> Self { + self.config.success_threshold = threshold; + self + } + + /// Get the current circuit state + pub async fn get_state(&self) -> CircuitState { + self.state.read().await.state + } + + /// Get circuit breaker statistics + pub async fn get_stats(&self) -> CircuitBreakerStats { + let state = self.state.read().await; + CircuitBreakerStats { + state: state.state, + total_requests: state.total_requests, + total_failures: state.total_failures, + total_successes: state.total_successes, + failure_count: state.failure_count, + success_count: state.success_count, + } + } + + /// Reset the circuit breaker + pub async fn reset(&self) { + let mut state = self.state.write().await; + *state = CircuitBreakerState::default(); + } +} + +impl Default for CircuitBreakerLayer { + fn default() -> Self { + Self::new() + } +} + +/// Circuit breaker statistics +#[derive(Debug, Clone)] +pub struct CircuitBreakerStats { + /// Current state + pub state: CircuitState, + /// Total requests processed + pub total_requests: u64, + /// Total failures + pub total_failures: u64, + /// Total successes + pub total_successes: u64, + /// Current failure count + pub failure_count: usize, + /// Current success count (in half-open state) + pub success_count: usize, +} + +impl MiddlewareLayer for CircuitBreakerLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + let state = self.state.clone(); + + Box::pin(async move { + // Check current state + let mut state_guard = state.write().await; + state_guard.total_requests += 1; + + match state_guard.state { + CircuitState::Open => { + // Check if timeout has elapsed + if let Some(last_failure) = state_guard.last_failure_time { + if last_failure.elapsed() >= config.timeout { + // Transition to half-open + tracing::info!("Circuit breaker transitioning to HalfOpen"); + state_guard.state = CircuitState::HalfOpen; + state_guard.success_count = 0; + } else { + // Still open, fail fast + drop(state_guard); + return http::Response::builder() + .status(503) + .header("Content-Type", "application/json") + .body(http_body_util::Full::new(bytes::Bytes::from( + serde_json::json!({ + "error": { + "type": "service_unavailable", + "message": "Circuit breaker is OPEN" + } + }) + .to_string(), + ))) + .unwrap(); + } + } + } + CircuitState::HalfOpen => { + // Allow request but monitor closely + } + CircuitState::Closed => { + // Normal operation + } + } + + drop(state_guard); + + // Execute request + let response = next(req).await; + + // Update state based on result + let mut state_guard = state.write().await; + + // Check if response indicates success (2xx status) + if response.status().is_success() { + state_guard.total_successes += 1; + + match state_guard.state { + CircuitState::HalfOpen => { + state_guard.success_count += 1; + if state_guard.success_count >= config.success_threshold { + // Transition to closed + tracing::info!("Circuit breaker transitioning to Closed"); + state_guard.state = CircuitState::Closed; + state_guard.failure_count = 0; + state_guard.success_count = 0; + } + } + CircuitState::Closed => { + // Reset failure count on success + state_guard.failure_count = 0; + } + _ => {} + } + } else { + // Non-2xx status is treated as failure + record_failure(&mut state_guard, &config); + } + + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +fn record_failure(state: &mut CircuitBreakerState, config: &CircuitBreakerConfig) { + state.total_failures += 1; + state.failure_count += 1; + state.last_failure_time = Some(Instant::now()); + + match state.state { + CircuitState::Closed => { + if state.failure_count >= config.failure_threshold { + // Open the circuit + tracing::warn!( + "Circuit breaker OPENING after {} failures", + state.failure_count + ); + state.state = CircuitState::Open; + } + } + CircuitState::HalfOpen => { + // Failed in half-open, go back to open + tracing::warn!("Circuit breaker returning to OPEN state"); + state.state = CircuitState::Open; + state.success_count = 0; + } + _ => {} + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::sync::Arc; + + #[tokio::test] + async fn circuit_breaker_opens_after_threshold() { + let breaker = CircuitBreakerLayer::new() + .failure_threshold(3) + .timeout(Duration::from_secs(1)); + + // Create a handler that always fails + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(500) + .body(http_body_util::Full::new(bytes::Bytes::from("Error"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + // Make requests that fail + for _ in 0..3 { + let req = http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let _ = breaker.call(req, next.clone()).await; + } + + // Circuit should be open now + let state = breaker.get_state().await; + assert_eq!(state, CircuitState::Open); + + // Next request should fail fast + let req = http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = breaker.call(req, next.clone()).await; + assert_eq!(response.status(), 503); + } + + #[tokio::test] + async fn circuit_breaker_recovers() { + let breaker = CircuitBreakerLayer::new() + .failure_threshold(2) + .timeout(Duration::from_millis(100)) + .success_threshold(2); + + // Fail requests to open circuit + let fail_next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(500) + .body(http_body_util::Full::new(bytes::Bytes::from("Error"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + for _ in 0..2 { + let req = http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + let _ = breaker.call(req, fail_next.clone()).await; + } + + assert_eq!(breaker.get_state().await, CircuitState::Open); + + // Wait for timeout + tokio::time::sleep(Duration::from_millis(150)).await; + + // Make successful requests + let success_next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + for _ in 0..2 { + let req = http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + let result = breaker.call(req, success_next.clone()).await; + assert!(result.status().is_success()); + } + + // Circuit should be closed now + let state = breaker.get_state().await; + assert_eq!(state, CircuitState::Closed); + } +} diff --git a/crates/rustapi-extras/src/csrf/config.rs b/crates/rustapi-extras/src/csrf/config.rs new file mode 100644 index 00000000..cb255081 --- /dev/null +++ b/crates/rustapi-extras/src/csrf/config.rs @@ -0,0 +1,97 @@ +use cookie::SameSite; +use std::time::Duration; + +/// Configuration for CSRF protection. +#[derive(Clone, Debug)] +pub struct CsrfConfig { + /// The name of the cookie used to store the CSRF token. + /// Default: "XSRF-TOKEN" + pub cookie_name: String, + + /// The name of the header expected to contain the CSRF token. + /// Default: "X-XSRF-TOKEN" + pub header_name: String, + + /// The path for the CSRF cookie. + /// Default: "/" + pub cookie_path: String, + + /// The domain for the CSRF cookie. + /// Default: None + pub cookie_domain: Option, + + /// Whether the CSRF cookie should be secure (HTTPS only). + /// Default: true (in release mode) + pub cookie_secure: bool, + + /// Whether the CSRF cookie should be HTTP Only. + /// For the Double-Submit Cookie pattern, this MUST be false so the client can read it + /// and send it back in a header. + /// Default: false + pub cookie_http_only: bool, + + /// The SameSite attribute for the CSRF cookie. + /// Default: Lax + pub cookie_same_site: SameSite, + + /// The lifetime of the CSRF cookie. + /// Default: 24 hours + pub cookie_max_age: Duration, + + /// The length of the generated random token (in bytes). + /// Default: 32 (resulting in ~44 chars base64) + pub token_length: usize, +} + +impl Default for CsrfConfig { + fn default() -> Self { + Self { + cookie_name: "XSRF-TOKEN".to_string(), + header_name: "X-XSRF-TOKEN".to_string(), + cookie_path: "/".to_string(), + cookie_domain: None, + cookie_secure: true, // Should logic check generic debug/release? + cookie_http_only: false, + cookie_same_site: SameSite::Lax, + cookie_max_age: Duration::from_secs(60 * 60 * 24), + token_length: 32, + } + } +} + +impl CsrfConfig { + /// Create a new default configuration. + pub fn new() -> Self { + Self::default() + } + + /// Set the cookie name. + pub fn cookie_name(mut self, name: impl Into) -> Self { + self.cookie_name = name.into(); + self + } + + /// Set the header name. + pub fn header_name(mut self, name: impl Into) -> Self { + self.header_name = name.into(); + self + } + + /// Set the cookie domain. + pub fn cookie_domain(mut self, domain: impl Into) -> Self { + self.cookie_domain = Some(domain.into()); + self + } + + /// Set the secure flag. + pub fn secure(mut self, secure: bool) -> Self { + self.cookie_secure = secure; + self + } + + /// Set the SameSite attribute. + pub fn same_site(mut self, same_site: SameSite) -> Self { + self.cookie_same_site = same_site; + self + } +} diff --git a/crates/rustapi-extras/src/csrf/layer.rs b/crates/rustapi-extras/src/csrf/layer.rs new file mode 100644 index 00000000..b6646370 --- /dev/null +++ b/crates/rustapi-extras/src/csrf/layer.rs @@ -0,0 +1,282 @@ +use super::config::CsrfConfig; +use super::token::CsrfToken; +use cookie::Cookie; +use http::{Method, StatusCode}; +use rustapi_core::middleware::{BoxedNext, MiddlewareLayer}; +use rustapi_core::{ApiError, IntoResponse, Request, Response}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +/// Middleware for CSRF protection using the Double-Submit Cookie pattern. +#[derive(Clone, Debug)] +pub struct CsrfLayer { + config: Arc, +} + +impl CsrfLayer { + /// Create a new CSRF middleware layer. + pub fn new(config: CsrfConfig) -> Self { + Self { + config: Arc::new(config), + } + } +} + +impl MiddlewareLayer for CsrfLayer { + fn call( + &self, + mut req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + + Box::pin(async move { + // 1. Extract existing token from cookie + let existing_token = req + .headers() + .get(http::header::COOKIE) + .and_then(|h| h.to_str().ok()) + .and_then(|cookie_str| { + cookie::Cookie::split_parse(cookie_str) + .filter_map(|c| c.ok()) + .find(|c| c.name() == config.cookie_name) + .map(|c| c.value().to_string()) + }) + .map(CsrfToken::new); + + // 2. Determine the token to use for this request context + // If existing, use it. If not, generate new. + let (token, is_new) = match existing_token { + Some(t) => (t, false), + None => (CsrfToken::generate(config.token_length), true), + }; + + // 3. Store token in request extensions so handlers/templates can access it + req.extensions_mut().insert(token.clone()); + + // 4. Validate if unsafe method + let method = req.method(); + let is_safe = matches!( + *method, + Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE + ); + + if !is_safe { + // For unsafe methods, we MUST have received a matching token in the header + let header_value = req + .headers() + .get(&config.header_name) + .and_then(|v| v.to_str().ok()); + + let valid = match header_value { + Some(h_token) => h_token == token.as_str(), + None => false, + }; + + if !valid { + // Mismatch or missing header -> Forbidden + // If cookie was missing (is_new=true), it fails here too as header can't match. + // We return JSON error for consistency + return ApiError::new( + StatusCode::FORBIDDEN, + "csrf_forbidden", + "CSRF token validation failed", + ) + .into_response(); + } + } + + // 5. Proceed + let mut response = next(req).await; + + // 6. Set cookie if new + if is_new { + let mut cookie = + Cookie::build((config.cookie_name.clone(), token.as_str().to_owned())) + .path(config.cookie_path.clone()) + .secure(config.cookie_secure) + .http_only(config.cookie_http_only) + .same_site(config.cookie_same_site); + + if let Some(domain) = &config.cookie_domain { + cookie = cookie.domain(domain.clone()); + } + + // Note: Not setting max-age strictly to avoid dependency complexity in this snippets, + // but usually recommended. + + let c = cookie.build(); + let header_value = c.to_string(); + + response.headers_mut().append( + http::header::SET_COOKIE, + header_value + .parse() + .unwrap_or(http::header::HeaderValue::from_static("")), + ); + } + + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::StatusCode; + use rustapi_core::{get, post, RustApi, TestClient, TestRequest}; + + async fn handler() -> &'static str { + "ok" + } + + #[tokio::test] + async fn test_safe_method_generates_cookie() { + let config = CsrfConfig::new().cookie_name("csrf_id"); + let app = RustApi::new() + .layer(CsrfLayer::new(config)) + .route("/", get(handler)); + + let client = TestClient::new(app); + let res = client.get("/").await; + + assert_eq!(res.status(), StatusCode::OK); + let cookies = res + .headers() + .get("set-cookie") + .expect("No cookie set") + .to_str() + .unwrap(); + assert!(cookies.contains("csrf_id=")); + } + + #[tokio::test] + async fn test_unsafe_method_without_cookie_fails() { + let config = CsrfConfig::new(); + let app = RustApi::new() + .layer(CsrfLayer::new(config)) + .route("/", post(handler)); + + let client = TestClient::new(app); + // POST without cookie or header + let res = client.request(TestRequest::post("/")).await; + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn test_unsafe_method_valid_passes() { + let config = CsrfConfig::new().cookie_name("ID").header_name("X-ID"); + let app = RustApi::new() + .layer(CsrfLayer::new(config)) + .route("/", post(handler)); + + let client = TestClient::new(app); + let res = client + .request( + TestRequest::post("/") + .header("Cookie", "ID=token123") + .header("X-ID", "token123"), + ) + .await; + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_unsafe_method_mismatch_fails() { + let config = CsrfConfig::new().cookie_name("ID").header_name("X-ID"); + let app = RustApi::new() + .layer(CsrfLayer::new(config)) + .route("/", post(handler)); + + let client = TestClient::new(app); + let res = client + .request( + TestRequest::post("/") + .header("Cookie", "ID=token123") + .header("X-ID", "wrongtoken"), + ) + .await; + + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn test_csrf_lifecycle() { + let config = CsrfConfig::new() + .cookie_name("token") + .header_name("x-token"); + // Chain handlers on same route to avoid conflict + let app = RustApi::new() + .layer(CsrfLayer::new(config)) + .route("/", get(handler).post(handler)); + + let client = TestClient::new(app); + + // 1. Initial GET to get token + let res = client.get("/").await; + assert_eq!(res.status(), StatusCode::OK); + let set_cookie = res + .headers() + .get("set-cookie") + .expect("No cookie set") + .to_str() + .unwrap(); + + // Parse cookie value (simple parse for "token=VALUE; ...") + let token_part = set_cookie.split(';').next().unwrap(); // "token=VALUE" + let token_val = token_part.split('=').nth(1).unwrap(); + + // 2. Unsafe POST with valid token + let res = client + .request( + TestRequest::post("/") + .header("Cookie", token_part) + .header("x-token", token_val), + ) + .await; + assert_eq!(res.status(), StatusCode::OK); + + // 3. Unsafe POST with invalid token (Mismatch) + let res = client + .request( + TestRequest::post("/") + .header("Cookie", token_part) + .header("x-token", "bad"), + ) + .await; + assert_eq!(res.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn test_token_extraction() { + use crate::csrf::CsrfToken; + + async fn token_handler(token: CsrfToken) -> String { + token.as_str().to_string() + } + + let config = CsrfConfig::new().cookie_name("csrf_id"); + let app = RustApi::new() + .layer(CsrfLayer::new(config)) + .route("/", get(token_handler)); + + let client = TestClient::new(app); + let res = client.get("/").await; + + assert_eq!(res.status(), StatusCode::OK); + let body = res.text(); + assert!(!body.is_empty()); + + // Verify token matches cookie + let cookie_val = res.headers().get("set-cookie").unwrap().to_str().unwrap(); + assert!(cookie_val.contains(&body)); + } +} diff --git a/crates/rustapi-extras/src/csrf/mod.rs b/crates/rustapi-extras/src/csrf/mod.rs new file mode 100644 index 00000000..27b8fa22 --- /dev/null +++ b/crates/rustapi-extras/src/csrf/mod.rs @@ -0,0 +1,25 @@ +//! CSRF Protection Module +//! +//! This module implements Double-Submit Cookie CSRF protection. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::RustApi; +//! use rustapi_extras::csrf::{CsrfConfig, CsrfLayer}; +//! +//! let config = CsrfConfig::new() +//! .cookie_name("my-csrf-cookie") +//! .header_name("X-CSRF-TOKEN"); +//! +//! let app = RustApi::new() +//! .layer(CsrfLayer::new(config)); +//! ``` + +mod config; +mod layer; +mod token; + +pub use config::CsrfConfig; +pub use layer::CsrfLayer; +pub use token::CsrfToken; diff --git a/crates/rustapi-extras/src/csrf/token.rs b/crates/rustapi-extras/src/csrf/token.rs new file mode 100644 index 00000000..ee5e8382 --- /dev/null +++ b/crates/rustapi-extras/src/csrf/token.rs @@ -0,0 +1,63 @@ +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use rand::{rngs::OsRng, RngCore}; +use std::fmt; + +/// A CSRF token. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct CsrfToken(String); + +impl CsrfToken { + /// Generate a new random CSRF token of the specified length. + pub fn generate(length: usize) -> Self { + let mut bytes = vec![0u8; length]; + OsRng.fill_bytes(&mut bytes); + let token = URL_SAFE_NO_PAD.encode(&bytes); + Self(token) + } + + /// Create a token from an existing string. + pub fn new(token: String) -> Self { + Self(token) + } + + /// Get the token string. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl fmt::Debug for CsrfToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("CsrfToken").field(&"***").finish() + } +} + +impl fmt::Display for CsrfToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl rustapi_core::FromRequestParts for CsrfToken { + fn from_request_parts(req: &rustapi_core::Request) -> rustapi_core::Result { + use http::StatusCode; + use rustapi_core::ApiError; + + match req.extensions().get::() { + Some(token) => Ok(token.clone()), + None => Err(ApiError::new( + StatusCode::INTERNAL_SERVER_ERROR, + "csrf_missing", + "CSRF token missing from request extensions. Ensure CSRF middleware is enabled.", + )), + } + } +} + +impl rustapi_openapi::OperationModifier for CsrfToken { + fn update_operation(_op: &mut rustapi_openapi::Operation) { + // CSRF token is handled by middleware, so we don't need to document + // it as a parameter for every operation that extracts it. + // It's usually part of the global security requirements. + } +} diff --git a/crates/rustapi-extras/src/dedup.rs b/crates/rustapi-extras/src/dedup.rs new file mode 100644 index 00000000..279934b4 --- /dev/null +++ b/crates/rustapi-extras/src/dedup.rs @@ -0,0 +1,134 @@ +//! Request Deduplication Middleware +//! +//! Prevents processing of duplicate requests based on an Idempotency-Key header. +//! Requires `dedup` feature. + +use dashmap::DashMap; +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +/// Deduplication configuration +#[derive(Clone)] +pub struct DedupConfig { + /// Name of the header containing the idempotency key + pub header_name: String, + /// Time-to-live for deduplication entries + pub ttl: Duration, +} + +impl Default for DedupConfig { + fn default() -> Self { + Self { + header_name: "Idempotency-Key".to_string(), + ttl: Duration::from_secs(300), // 5 minutes default + } + } +} + +/// Deduplication middleware layer +#[derive(Clone)] +pub struct DedupLayer { + config: DedupConfig, + /// Stores idempotency keys and their creation time. + /// Value is optional Response if we wanted to support caching (not implemented in V1) + /// For now, just tracks presence. + store: Arc>, +} + +impl DedupLayer { + /// Create a new deduplication layer + pub fn new() -> Self { + Self { + config: DedupConfig::default(), + store: Arc::new(DashMap::new()), + } + } + + /// Set custom header name + pub fn header_name(mut self, name: impl Into) -> Self { + self.config.header_name = name.into(); + self + } + + /// Set TTL + pub fn ttl(mut self, ttl: Duration) -> Self { + self.config.ttl = ttl; + self + } +} + +impl Default for DedupLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for DedupLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + let store = self.store.clone(); + + Box::pin(async move { + // Check for idempotency key + let key = if let Some(val) = req.headers().get(&config.header_name) { + match val.to_str() { + Ok(s) => s.to_string(), + Err(_) => return next(req).await, // Invalid header value, proceed as normal? or Error? Proceeding is safer. + } + } else { + // No key, proceed normally + return next(req).await; + }; + + // Check if key exists and is valid + if let Some(created_at) = store.get(&key) { + if created_at.elapsed() < config.ttl { + // Duplicate request detected + // Determine if processing or finished. For V1 we just say "Conflict / Already Processed" + return http::Response::builder() + .status(409) // Conflict + .header("Content-Type", "application/json") + .body(http_body_util::Full::new(bytes::Bytes::from( + serde_json::json!({ + "error": { + "type": "duplicate_request", + "message": format!("Request with key '{}' has already been processed or is processing", key) + } + }) + .to_string(), + ))) + .unwrap(); + } else { + // Expired, remove + drop(created_at); + store.remove(&key); + } + } + + // New key, track it + store.insert(key.clone(), Instant::now()); + + // Process request + // Note: In a robust implementation, we might want to remove the key if processing fails, + // or update it with the response for caching (Idempotency Cache pattern). + // For simple Deduplication (prevent double-submit), keeping it is fine. + let response = next(req).await; + + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/crates/rustapi-extras/src/diesel/mod.rs b/crates/rustapi-extras/src/diesel/mod.rs new file mode 100644 index 00000000..33a529ec --- /dev/null +++ b/crates/rustapi-extras/src/diesel/mod.rs @@ -0,0 +1,522 @@ +//! Diesel database integration for RustAPI +//! +//! This module provides a pool builder for Diesel connection pools with +//! health check integration. +//! +//! ## Pool Builder Example +//! +//! ```rust,ignore +//! use rustapi_extras::diesel::{DieselPoolBuilder, DieselPoolError}; +//! use std::time::Duration; +//! +//! fn main() -> Result<(), DieselPoolError> { +//! let pool = DieselPoolBuilder::new("postgres://user:pass@localhost/db") +//! .max_connections(10) +//! .min_idle(Some(2)) +//! .connection_timeout(Duration::from_secs(5)) +//! .idle_timeout(Some(Duration::from_secs(300))) +//! .max_lifetime(Some(Duration::from_secs(3600))) +//! .build_postgres()?; +//! +//! // Use pool... +//! Ok(()) +//! } +//! ``` + +use rustapi_core::health::{HealthCheck, HealthCheckBuilder, HealthStatus}; +use std::sync::Arc; +use std::time::Duration; +use thiserror::Error; + +/// Error type for Diesel pool operations +#[derive(Debug, Error)] +pub enum DieselPoolError { + /// Configuration error + #[error("Pool configuration error: {0}")] + Configuration(String), + + /// Connection error + #[error("Database connection error: {0}")] + Connection(String), + + /// R2D2 pool error + #[error("Pool error: {0}")] + Pool(String), +} + +/// Configuration for Diesel connection pool +/// +/// This struct holds all configuration options for the pool builder. +#[derive(Debug, Clone)] +pub struct DieselPoolConfig { + /// Database connection URL + pub url: String, + /// Maximum number of connections in the pool + pub max_connections: u32, + /// Minimum number of idle connections to maintain + pub min_idle: Option, + /// Timeout for acquiring a connection + pub connection_timeout: Duration, + /// Maximum idle time before a connection is closed + pub idle_timeout: Option, + /// Maximum lifetime of a connection + pub max_lifetime: Option, +} + +impl Default for DieselPoolConfig { + fn default() -> Self { + Self { + url: String::new(), + max_connections: 10, + min_idle: None, + connection_timeout: Duration::from_secs(30), + idle_timeout: Some(Duration::from_secs(600)), + max_lifetime: Some(Duration::from_secs(1800)), + } + } +} + +impl DieselPoolConfig { + /// Validate the configuration + pub fn validate(&self) -> Result<(), DieselPoolError> { + if self.url.is_empty() { + return Err(DieselPoolError::Configuration( + "Database URL cannot be empty".to_string(), + )); + } + if self.max_connections == 0 { + return Err(DieselPoolError::Configuration( + "max_connections must be greater than 0".to_string(), + )); + } + if let Some(min_idle) = self.min_idle { + if min_idle > self.max_connections { + return Err(DieselPoolError::Configuration( + "min_idle cannot exceed max_connections".to_string(), + )); + } + } + Ok(()) + } +} + +/// Builder for Diesel connection pools +/// +/// Provides a fluent API for configuring database connection pools with +/// sensible defaults and health check integration. +/// +/// # Example +/// +/// ```rust,ignore +/// use rustapi_extras::diesel::DieselPoolBuilder; +/// use std::time::Duration; +/// +/// let pool = DieselPoolBuilder::new("postgres://localhost/mydb") +/// .max_connections(20) +/// .min_idle(Some(5)) +/// .connection_timeout(Duration::from_secs(10)) +/// .build_postgres()?; +/// ``` +#[derive(Debug, Clone)] +pub struct DieselPoolBuilder { + config: DieselPoolConfig, +} + +impl DieselPoolBuilder { + /// Create a new pool builder with the given database URL + /// + /// # Arguments + /// + /// * `url` - Database connection URL (e.g., "postgres://user:pass@localhost/db") + pub fn new(url: impl Into) -> Self { + Self { + config: DieselPoolConfig { + url: url.into(), + ..Default::default() + }, + } + } + + /// Set the maximum number of connections in the pool + /// + /// Default: 10 + pub fn max_connections(mut self, n: u32) -> Self { + self.config.max_connections = n; + self + } + + /// Set the minimum number of idle connections to maintain + /// + /// Default: None (no minimum) + pub fn min_idle(mut self, n: Option) -> Self { + self.config.min_idle = n; + self + } + + /// Set the timeout for acquiring a connection + /// + /// Default: 30 seconds + pub fn connection_timeout(mut self, d: Duration) -> Self { + self.config.connection_timeout = d; + self + } + + /// Set the maximum idle time before a connection is closed + /// + /// Default: 600 seconds (10 minutes) + pub fn idle_timeout(mut self, d: Option) -> Self { + self.config.idle_timeout = d; + self + } + + /// Set the maximum lifetime of a connection + /// + /// Default: 1800 seconds (30 minutes) + pub fn max_lifetime(mut self, d: Option) -> Self { + self.config.max_lifetime = d; + self + } + + /// Get the current configuration + pub fn config(&self) -> &DieselPoolConfig { + &self.config + } + + /// Build a PostgreSQL connection pool + /// + /// # Errors + /// + /// Returns an error if: + /// - The configuration is invalid + /// - The connection cannot be established + #[cfg(feature = "diesel-postgres")] + pub fn build_postgres( + self, + ) -> Result>, DieselPoolError> + { + self.config.validate()?; + + let manager = diesel::r2d2::ConnectionManager::::new(&self.config.url); + + let mut builder = r2d2::Pool::builder() + .max_size(self.config.max_connections) + .connection_timeout(self.config.connection_timeout); + + if let Some(min_idle) = self.config.min_idle { + builder = builder.min_idle(Some(min_idle)); + } + + if let Some(idle_timeout) = self.config.idle_timeout { + builder = builder.idle_timeout(Some(idle_timeout)); + } + + if let Some(max_lifetime) = self.config.max_lifetime { + builder = builder.max_lifetime(Some(max_lifetime)); + } + + builder + .build(manager) + .map_err(|e| DieselPoolError::Pool(e.to_string())) + } + + /// Build a MySQL connection pool + /// + /// # Errors + /// + /// Returns an error if: + /// - The configuration is invalid + /// - The connection cannot be established + #[cfg(feature = "diesel-mysql")] + pub fn build_mysql( + self, + ) -> Result>, DieselPoolError> + { + self.config.validate()?; + + let manager = + diesel::r2d2::ConnectionManager::::new(&self.config.url); + + let mut builder = r2d2::Pool::builder() + .max_size(self.config.max_connections) + .connection_timeout(self.config.connection_timeout); + + if let Some(min_idle) = self.config.min_idle { + builder = builder.min_idle(Some(min_idle)); + } + + if let Some(idle_timeout) = self.config.idle_timeout { + builder = builder.idle_timeout(Some(idle_timeout)); + } + + if let Some(max_lifetime) = self.config.max_lifetime { + builder = builder.max_lifetime(Some(max_lifetime)); + } + + builder + .build(manager) + .map_err(|e| DieselPoolError::Pool(e.to_string())) + } + + /// Build a SQLite connection pool + /// + /// # Errors + /// + /// Returns an error if: + /// - The configuration is invalid + /// - The connection cannot be established + #[cfg(feature = "diesel-sqlite")] + pub fn build_sqlite( + self, + ) -> Result>, DieselPoolError> + { + self.config.validate()?; + + let manager = + diesel::r2d2::ConnectionManager::::new(&self.config.url); + + let mut builder = r2d2::Pool::builder() + .max_size(self.config.max_connections) + .connection_timeout(self.config.connection_timeout); + + if let Some(min_idle) = self.config.min_idle { + builder = builder.min_idle(Some(min_idle)); + } + + if let Some(idle_timeout) = self.config.idle_timeout { + builder = builder.idle_timeout(Some(idle_timeout)); + } + + if let Some(max_lifetime) = self.config.max_lifetime { + builder = builder.max_lifetime(Some(max_lifetime)); + } + + builder + .build(manager) + .map_err(|e| DieselPoolError::Pool(e.to_string())) + } + + /// Create a health check for a PostgreSQL pool + /// + /// The health check will attempt to get a connection from the pool. + #[cfg(feature = "diesel-postgres")] + pub fn health_check_postgres( + pool: Arc>>, + ) -> HealthCheck { + HealthCheckBuilder::new(false) + .add_check("postgres", move || { + let pool = pool.clone(); + async move { + match pool.get() { + Ok(_) => HealthStatus::healthy(), + Err(e) => HealthStatus::unhealthy(format!("Database check failed: {}", e)), + } + } + }) + .build() + } + + /// Create a health check for a MySQL pool + /// + /// The health check will attempt to get a connection from the pool. + #[cfg(feature = "diesel-mysql")] + pub fn health_check_mysql( + pool: Arc>>, + ) -> HealthCheck { + HealthCheckBuilder::new(false) + .add_check("mysql", move || { + let pool = pool.clone(); + async move { + match pool.get() { + Ok(_) => HealthStatus::healthy(), + Err(e) => HealthStatus::unhealthy(format!("Database check failed: {}", e)), + } + } + }) + .build() + } + + /// Create a health check for a SQLite pool + /// + /// The health check will attempt to get a connection from the pool. + #[cfg(feature = "diesel-sqlite")] + pub fn health_check_sqlite( + pool: Arc>>, + ) -> HealthCheck { + HealthCheckBuilder::new(false) + .add_check("sqlite", move || { + let pool = pool.clone(); + async move { + match pool.get() { + Ok(_) => HealthStatus::healthy(), + Err(e) => HealthStatus::unhealthy(format!("Database check failed: {}", e)), + } + } + }) + .build() + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + // Unit tests for DieselPoolBuilder + #[test] + fn test_builder_default_values() { + let builder = DieselPoolBuilder::new("postgres://localhost/test"); + let config = builder.config(); + + assert_eq!(config.url, "postgres://localhost/test"); + assert_eq!(config.max_connections, 10); + assert_eq!(config.min_idle, None); + assert_eq!(config.connection_timeout, Duration::from_secs(30)); + assert_eq!(config.idle_timeout, Some(Duration::from_secs(600))); + assert_eq!(config.max_lifetime, Some(Duration::from_secs(1800))); + } + + #[test] + fn test_builder_custom_values() { + let builder = DieselPoolBuilder::new("postgres://localhost/test") + .max_connections(20) + .min_idle(Some(5)) + .connection_timeout(Duration::from_secs(10)) + .idle_timeout(Some(Duration::from_secs(300))) + .max_lifetime(Some(Duration::from_secs(900))); + + let config = builder.config(); + + assert_eq!(config.max_connections, 20); + assert_eq!(config.min_idle, Some(5)); + assert_eq!(config.connection_timeout, Duration::from_secs(10)); + assert_eq!(config.idle_timeout, Some(Duration::from_secs(300))); + assert_eq!(config.max_lifetime, Some(Duration::from_secs(900))); + } + + #[test] + fn test_config_validation_empty_url() { + let config = DieselPoolConfig::default(); + let result = config.validate(); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), DieselPoolError::Configuration(_))); + } + + #[test] + fn test_config_validation_zero_max_connections() { + let config = DieselPoolConfig { + url: "postgres://localhost/test".to_string(), + max_connections: 0, + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + } + + #[test] + fn test_config_validation_min_idle_exceeds_max() { + let config = DieselPoolConfig { + url: "postgres://localhost/test".to_string(), + max_connections: 5, + min_idle: Some(10), + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_err()); + } + + #[test] + fn test_config_validation_valid() { + let config = DieselPoolConfig { + url: "postgres://localhost/test".to_string(), + max_connections: 10, + min_idle: Some(2), + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_ok()); + } + + #[test] + fn test_config_validation_valid_no_min_idle() { + let config = DieselPoolConfig { + url: "postgres://localhost/test".to_string(), + max_connections: 10, + min_idle: None, + ..Default::default() + }; + let result = config.validate(); + assert!(result.is_ok()); + } + + // **Feature: v1-features-roadmap, Property 9: Health check accuracy** + // + // *For any* database pool, health checks SHALL correctly report connectivity status. + // + // **Validates: Requirements 3.3** + // + // Note: This property test validates that the configuration is correctly + // stored and validated. Actual health check behavior testing requires + // integration tests with a real database. + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_diesel_pool_configuration_respects_limits( + max_conn in 1u32..100, + min_idle_factor in 0.0f64..1.0, + connection_timeout_secs in 1u64..120, + idle_timeout_secs in 60u64..3600, + max_lifetime_secs in 300u64..7200, + ) { + // Calculate min_idle as a fraction of max to ensure min <= max + let min_idle = ((max_conn as f64) * min_idle_factor).floor() as u32; + + let builder = DieselPoolBuilder::new("postgres://localhost/test") + .max_connections(max_conn) + .min_idle(Some(min_idle)) + .connection_timeout(Duration::from_secs(connection_timeout_secs)) + .idle_timeout(Some(Duration::from_secs(idle_timeout_secs))) + .max_lifetime(Some(Duration::from_secs(max_lifetime_secs))); + + let config = builder.config(); + + // Verify all configuration values are correctly stored + prop_assert_eq!(config.max_connections, max_conn); + prop_assert_eq!(config.min_idle, Some(min_idle)); + prop_assert_eq!(config.connection_timeout, Duration::from_secs(connection_timeout_secs)); + prop_assert_eq!(config.idle_timeout, Some(Duration::from_secs(idle_timeout_secs))); + prop_assert_eq!(config.max_lifetime, Some(Duration::from_secs(max_lifetime_secs))); + + // Verify configuration validates successfully + prop_assert!(config.validate().is_ok()); + + // Verify invariant: min_idle <= max_connections + if let Some(min) = config.min_idle { + prop_assert!(min <= config.max_connections); + } + } + } + + // Property test for configuration validation + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_diesel_invalid_config_is_rejected( + max_conn in 1u32..50, + min_idle_excess in 1u32..50, + ) { + // Create config where min_idle > max (invalid) + let config = DieselPoolConfig { + url: "postgres://localhost/test".to_string(), + max_connections: max_conn, + min_idle: Some(max_conn + min_idle_excess), + ..Default::default() + }; + + // Should fail validation + prop_assert!(config.validate().is_err()); + } + } +} diff --git a/crates/rustapi-extras/src/guard.rs b/crates/rustapi-extras/src/guard.rs new file mode 100644 index 00000000..31f94602 --- /dev/null +++ b/crates/rustapi-extras/src/guard.rs @@ -0,0 +1,252 @@ +//! Request guards for route-level authorization +//! +//! This module provides guard extractors for role-based and permission-based access control. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_extras::{RoleGuard, PermissionGuard}; +//! use rustapi_core::Json; +//! use serde::Serialize; +//! +//! #[derive(Serialize)] +//! struct AdminData { +//! message: String, +//! } +//! +//! // Extractor-based guards +//! async fn admin_only(guard: RoleGuard) -> Json { +//! Json(AdminData { +//! message: format!("Welcome, {}!", guard.role), +//! }) +//! } +//! ``` + +use rustapi_core::{ApiError, FromRequestParts, Request}; + +/// Role-based guard extractor +/// +/// Extracts the authenticated user and provides the user's role. +/// Requires JWT middleware to be enabled. +#[derive(Debug, Clone)] +pub struct RoleGuard { + /// The user's role + pub role: String, +} + +impl FromRequestParts for RoleGuard { + fn from_request_parts(req: &Request) -> rustapi_core::Result { + let extensions = req.extensions(); + + #[cfg(feature = "jwt")] + { + use crate::jwt::{AuthUser, ValidatedClaims}; + + // Try to get ValidatedClaims from extensions + if let Some(validated) = extensions.get::>() { + // Extract role from claims + if let Some(role) = validated.0.get("role").and_then(|r| r.as_str()) { + return Ok(Self { + role: role.to_string(), + }); + } + } + + // Also try AuthUser for backward compatibility + if let Some(user) = extensions.get::>() { + if let Some(role) = user.0.get("role").and_then(|r| r.as_str()) { + return Ok(Self { + role: role.to_string(), + }); + } + } + } + + #[cfg(not(feature = "jwt"))] + { + let _ = extensions; + } + + Err(ApiError::forbidden( + "Authentication required: missing or invalid role", + )) + } +} + +impl RoleGuard { + /// Check if the user has a specific role + pub fn has_role(&self, role: &str) -> bool { + self.role == role + } + + /// Require a specific role, returning an error if not matched + pub fn require_role(&self, role: &str) -> Result<(), ApiError> { + if self.has_role(role) { + Ok(()) + } else { + Err(ApiError::forbidden(format!("Required role: {}", role))) + } + } +} + +/// Permission-based guard extractor +/// +/// Extracts the authenticated user and provides the user's permissions. +/// Requires JWT middleware and permissions in the JWT claims. +#[derive(Debug, Clone)] +pub struct PermissionGuard { + /// The user's permissions + pub permissions: Vec, +} + +impl FromRequestParts for PermissionGuard { + fn from_request_parts(req: &Request) -> rustapi_core::Result { + let extensions = req.extensions(); + + #[cfg(feature = "jwt")] + { + use crate::jwt::{AuthUser, ValidatedClaims}; + + // Try ValidatedClaims first + if let Some(validated) = extensions.get::>() { + if let Some(permissions_value) = validated.0.get("permissions") { + if let Some(permissions_array) = permissions_value.as_array() { + let permissions: Vec = permissions_array + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + + if !permissions.is_empty() { + return Ok(Self { permissions }); + } + } + } + } + + // Also try AuthUser + if let Some(user) = extensions.get::>() { + if let Some(permissions_value) = user.0.get("permissions") { + if let Some(permissions_array) = permissions_value.as_array() { + let permissions: Vec = permissions_array + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + + if !permissions.is_empty() { + return Ok(Self { permissions }); + } + } + } + } + } + + #[cfg(not(feature = "jwt"))] + { + let _ = extensions; + } + + Err(ApiError::forbidden( + "Authentication required: missing or invalid permissions", + )) + } +} + +impl PermissionGuard { + /// Check if the user has a specific permission + pub fn has_permission(&self, permission: &str) -> bool { + self.permissions.iter().any(|p| p == permission) + } + + /// Require a specific permission, returning an error if not matched + pub fn require_permission(&self, permission: &str) -> Result<(), ApiError> { + if self.has_permission(permission) { + Ok(()) + } else { + Err(ApiError::forbidden(format!( + "Required permission: {}", + permission + ))) + } + } + + /// Check if the user has any of the given permissions + pub fn has_any_permission(&self, permissions: &[&str]) -> bool { + self.permissions + .iter() + .any(|p| permissions.contains(&p.as_str())) + } + + /// Check if the user has all of the given permissions + pub fn has_all_permissions(&self, permissions: &[&str]) -> bool { + permissions + .iter() + .all(|required| self.has_permission(required)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[tokio::test] + async fn role_guard_without_auth_fails() { + let req = Request::from_http_request( + http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(), + Bytes::new(), + ); + + let result = RoleGuard::from_request_parts(&req); + assert!(result.is_err()); + } + + #[tokio::test] + async fn permission_guard_without_auth_fails() { + let req = Request::from_http_request( + http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(), + Bytes::new(), + ); + + let result = PermissionGuard::from_request_parts(&req); + assert!(result.is_err()); + } + + #[test] + fn role_guard_has_role_works() { + let guard = RoleGuard { + role: "admin".to_string(), + }; + + assert!(guard.has_role("admin")); + assert!(!guard.has_role("user")); + } + + #[test] + fn permission_guard_has_permission_works() { + let guard = PermissionGuard { + permissions: vec!["users.read".to_string(), "users.write".to_string()], + }; + + assert!(guard.has_permission("users.read")); + assert!(guard.has_permission("users.write")); + assert!(!guard.has_permission("users.delete")); + } + + #[test] + fn permission_guard_has_all_permissions_works() { + let guard = PermissionGuard { + permissions: vec!["users.read".to_string(), "users.write".to_string()], + }; + + assert!(guard.has_all_permissions(&["users.read", "users.write"])); + assert!(!guard.has_all_permissions(&["users.read", "users.delete"])); + } +} diff --git a/crates/rustapi-extras/src/lib.rs b/crates/rustapi-extras/src/lib.rs index 027a774a..9d4abe2a 100644 --- a/crates/rustapi-extras/src/lib.rs +++ b/crates/rustapi-extras/src/lib.rs @@ -47,10 +47,62 @@ pub mod config; #[cfg(feature = "sqlx")] pub mod sqlx; +// Diesel database integration module +#[cfg(feature = "diesel")] +pub mod diesel; + // Traffic insight module #[cfg(feature = "insight")] pub mod insight; +// Request timeout middleware +#[cfg(feature = "timeout")] +pub mod timeout; + +// Request guards (authorization) +#[cfg(feature = "guard")] +pub mod guard; + +// Request/Response logging middleware +#[cfg(feature = "logging")] +pub mod logging; + +// Circuit breaker middleware +#[cfg(feature = "circuit-breaker")] +pub mod circuit_breaker; + +// Retry middleware +#[cfg(feature = "retry")] +pub mod retry; + +// Request deduplication +#[cfg(feature = "dedup")] +pub mod dedup; + +// Input sanitization +#[cfg(feature = "sanitization")] +pub mod sanitization; + +// Security headers middleware +#[cfg(feature = "security-headers")] +pub mod security_headers; + +// API Key authentication +#[cfg(feature = "api-key")] +pub mod api_key; + +// Response caching +#[cfg(feature = "cache")] +pub mod cache; + +// OpenTelemetry integration +#[cfg(feature = "otel")] +pub mod otel; + +// Structured logging +#[cfg(feature = "structured-logging")] +pub mod structured_logging; + // Re-exports for convenience #[cfg(feature = "jwt")] pub use jwt::{create_token, AuthUser, JwtError, JwtLayer, JwtValidation, ValidatedClaims}; @@ -68,9 +120,73 @@ pub use config::{ }; #[cfg(feature = "sqlx")] -pub use sqlx::{convert_sqlx_error, SqlxErrorExt}; +pub use sqlx::{convert_sqlx_error, PoolError, SqlxErrorExt, SqlxPoolBuilder, SqlxPoolConfig}; + +#[cfg(feature = "diesel")] +pub use diesel::{DieselPoolBuilder, DieselPoolConfig, DieselPoolError}; #[cfg(feature = "insight")] pub use insight::{ InMemoryInsightStore, InsightConfig, InsightData, InsightLayer, InsightStats, InsightStore, }; + +// Phase 11 re-exports +#[cfg(feature = "timeout")] +pub use timeout::TimeoutLayer; + +#[cfg(feature = "guard")] +pub use guard::{PermissionGuard, RoleGuard}; + +#[cfg(feature = "logging")] +pub use logging::{LogFormat, LoggingConfig, LoggingLayer}; + +#[cfg(feature = "circuit-breaker")] +pub use circuit_breaker::{CircuitBreakerLayer, CircuitBreakerStats, CircuitState}; + +#[cfg(feature = "retry")] +pub use retry::{RetryLayer, RetryStrategy}; + +#[cfg(feature = "security-headers")] +pub use security_headers::{HstsConfig, ReferrerPolicy, SecurityHeadersLayer, XFrameOptions}; + +#[cfg(feature = "api-key")] +pub use api_key::ApiKeyLayer; + +#[cfg(feature = "cache")] +pub use cache::{CacheConfig, CacheLayer}; + +#[cfg(feature = "dedup")] +pub use dedup::{DedupConfig, DedupLayer}; + +#[cfg(feature = "sanitization")] +pub use sanitization::{sanitize_html, sanitize_json, strip_tags}; + +// Phase 5: Observability re-exports +#[cfg(feature = "otel")] +pub use otel::{ + extract_trace_context, inject_trace_context, propagate_trace_context, OtelConfig, + OtelConfigBuilder, OtelExporter, OtelLayer, TraceContext, TraceSampler, +}; + +#[cfg(feature = "structured-logging")] +pub use structured_logging::{ + DatadogFormatter, JsonFormatter, LogFormatter, LogOutputFormat, LogfmtFormatter, + SplunkFormatter, StructuredLoggingConfig, StructuredLoggingConfigBuilder, + StructuredLoggingLayer, +}; + +// Phase 6: Security features +#[cfg(feature = "csrf")] +pub mod csrf; + +#[cfg(feature = "csrf")] +pub use csrf::{CsrfConfig, CsrfLayer, CsrfToken}; + +#[cfg(feature = "oauth2-client")] +pub mod oauth2; + +#[cfg(feature = "oauth2-client")] +pub use oauth2::{ + AuthorizationRequest, CsrfState, OAuth2Client, OAuth2Config, PkceVerifier, Provider, + TokenError, TokenResponse, +}; diff --git a/crates/rustapi-extras/src/logging.rs b/crates/rustapi-extras/src/logging.rs new file mode 100644 index 00000000..05260f83 --- /dev/null +++ b/crates/rustapi-extras/src/logging.rs @@ -0,0 +1,302 @@ +//! Structured request/response logging middleware +//! +//! This module provides detailed logging of HTTP requests and responses +//! with support for correlation IDs, custom fields, and structured output. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::RustApi; +//! use rustapi_extras::{LoggingLayer, LogFormat}; +//! +//! #[tokio::main] +//! async fn main() { +//! let app = RustApi::new() +//! .layer(Box::new(LoggingLayer::new())) +//! .run("0.0.0.0:3000") +//! .await +//! .unwrap(); +//! } +//! ``` + +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::future::Future; +use std::pin::Pin; +use std::time::Instant; + +/// Logging format +#[derive(Clone, Debug)] +pub enum LogFormat { + /// Compact format (one line per request) + Compact, + /// Detailed format (multi-line with full details) + Detailed, + /// JSON format (structured logging) + Json, +} + +/// Logging configuration +#[derive(Clone)] +pub struct LoggingConfig { + /// Logging format + pub format: LogFormat, + /// Whether to log request headers + pub log_request_headers: bool, + /// Whether to log response headers + pub log_response_headers: bool, + /// Paths to skip logging + pub skip_paths: Vec, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + format: LogFormat::Compact, + log_request_headers: false, + log_response_headers: false, + skip_paths: vec!["/health".to_string(), "/metrics".to_string()], + } + } +} + +/// Logging middleware layer +#[derive(Clone)] +pub struct LoggingLayer { + config: LoggingConfig, +} + +impl LoggingLayer { + /// Create a new logging layer with default configuration + pub fn new() -> Self { + Self { + config: LoggingConfig::default(), + } + } + + /// Create a new logging layer with custom configuration + pub fn with_config(config: LoggingConfig) -> Self { + Self { config } + } + + /// Set the logging format + pub fn format(mut self, format: LogFormat) -> Self { + self.config.format = format; + self + } + + /// Enable request header logging + pub fn log_request_headers(mut self, enabled: bool) -> Self { + self.config.log_request_headers = enabled; + self + } + + /// Enable response header logging + pub fn log_response_headers(mut self, enabled: bool) -> Self { + self.config.log_response_headers = enabled; + self + } + + /// Add a path to skip logging + pub fn skip_path(mut self, path: impl Into) -> Self { + self.config.skip_paths.push(path.into()); + self + } +} + +impl Default for LoggingLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for LoggingLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + + Box::pin(async move { + let method = req.method().to_string(); + let uri = req.uri().to_string(); + let version = format!("{:?}", req.version()); + + // Check if we should skip this path + if config.skip_paths.iter().any(|p| uri.starts_with(p)) { + return next(req).await; + } + + // Get request ID from extensions if available + let request_id = req + .extensions() + .get::() + .map(|s| s.clone()) + .unwrap_or_else(|| "N/A".to_string()); + + let start = Instant::now(); + + // Log request + match config.format { + LogFormat::Compact => { + tracing::info!( + request_id = %request_id, + method = %method, + uri = %uri, + version = %version, + "incoming request" + ); + } + LogFormat::Detailed => { + tracing::info!( + request_id = %request_id, + method = %method, + uri = %uri, + version = %version, + "=== Incoming Request ===" + ); + + if config.log_request_headers { + for (name, value) in req.headers() { + if let Ok(val) = value.to_str() { + tracing::debug!( + request_id = %request_id, + header = %name, + value = %val, + "request header" + ); + } + } + } + } + LogFormat::Json => { + let json = serde_json::json!({ + "type": "request", + "request_id": request_id, + "method": method, + "uri": uri, + "version": version, + }); + tracing::info!("{}", json); + } + } + + // Call next middleware/handler + let response = next(req).await; + + let duration = start.elapsed(); + let status = response.status().as_u16(); + let duration_ms = duration.as_millis(); + + // Log response + match config.format { + LogFormat::Compact => { + tracing::info!( + request_id = %request_id, + method = %method, + uri = %uri, + status = status, + duration_ms = duration_ms, + "request completed" + ); + } + LogFormat::Detailed => { + tracing::info!( + request_id = %request_id, + status = status, + duration_ms = duration_ms, + "=== Response Sent ===" + ); + + if config.log_response_headers { + for (name, value) in response.headers() { + if let Ok(val) = value.to_str() { + tracing::debug!( + request_id = %request_id, + header = %name, + value = %val, + "response header" + ); + } + } + } + } + LogFormat::Json => { + let json = serde_json::json!({ + "type": "response", + "request_id": request_id, + "method": method, + "uri": uri, + "status": status, + "duration_ms": duration_ms, + }); + tracing::info!("{}", json); + } + } + + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::sync::Arc; + + #[tokio::test] + async fn logging_middleware_logs_request() { + let layer = LoggingLayer::new(); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/test") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } + + #[tokio::test] + async fn logging_middleware_skips_health_check() { + let layer = LoggingLayer::new(); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/health") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } +} diff --git a/crates/rustapi-extras/src/oauth2/client.rs b/crates/rustapi-extras/src/oauth2/client.rs new file mode 100644 index 00000000..d03622d7 --- /dev/null +++ b/crates/rustapi-extras/src/oauth2/client.rs @@ -0,0 +1,308 @@ +//! OAuth2 client implementation + +use super::config::OAuth2Config; +use super::tokens::{CsrfState, PkceVerifier, TokenError, TokenResponse}; +use std::collections::HashMap; +use std::time::Duration; + +/// OAuth2 client for handling authorization flows. +#[derive(Debug, Clone)] +pub struct OAuth2Client { + config: OAuth2Config, +} + +impl OAuth2Client { + /// Create a new OAuth2 client. + pub fn new(config: OAuth2Config) -> Self { + Self { config } + } + + /// Get the configuration. + pub fn config(&self) -> &OAuth2Config { + &self.config + } + + /// Generate an authorization URL for the user to visit. + /// + /// Returns the authorization URL, CSRF state token, and optionally a PKCE verifier. + pub fn authorization_url(&self) -> AuthorizationRequest { + let csrf_state = CsrfState::generate(); + let pkce = if self.config.use_pkce { + Some(PkceVerifier::generate()) + } else { + None + }; + + // Build query parameters + let mut params = vec![ + ("client_id", self.config.client_id.clone()), + ("redirect_uri", self.config.redirect_uri.clone()), + ("response_type", "code".to_string()), + ("state", csrf_state.as_str().to_string()), + ]; + + // Add scopes + if !self.config.scopes.is_empty() { + let scope_str = self + .config + .scopes + .iter() + .cloned() + .collect::>() + .join(" "); + params.push(("scope", scope_str)); + } + + // Add PKCE parameters if enabled + if let Some(ref pkce) = pkce { + params.push(("code_challenge", pkce.challenge().to_string())); + params.push(("code_challenge_method", pkce.method().to_string())); + } + + // Build the URL + let query = params + .iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v))) + .collect::>() + .join("&"); + + let url = format!("{}?{}", self.config.provider.auth_url(), query); + + AuthorizationRequest { + url, + csrf_state, + pkce_verifier: pkce, + } + } + + /// Exchange an authorization code for tokens. + /// + /// This should be called after the user is redirected back with the authorization code. + pub async fn exchange_code( + &self, + code: &str, + pkce_verifier: Option<&PkceVerifier>, + ) -> Result { + let mut params = HashMap::new(); + params.insert("grant_type", "authorization_code".to_string()); + params.insert("code", code.to_string()); + params.insert("client_id", self.config.client_id.clone()); + params.insert("client_secret", self.config.client_secret.clone()); + params.insert("redirect_uri", self.config.redirect_uri.clone()); + + // Add PKCE verifier if provided + if let Some(verifier) = pkce_verifier { + params.insert("code_verifier", verifier.verifier().to_string()); + } + + self.token_request(params).await + } + + /// Refresh an access token using a refresh token. + pub async fn refresh_token(&self, refresh_token: &str) -> Result { + let mut params = HashMap::new(); + params.insert("grant_type", "refresh_token".to_string()); + params.insert("refresh_token", refresh_token.to_string()); + params.insert("client_id", self.config.client_id.clone()); + params.insert("client_secret", self.config.client_secret.clone()); + + self.token_request(params).await + } + + /// Make a token request to the authorization server. + async fn token_request( + &self, + params: HashMap<&str, String>, + ) -> Result { + // Build form data + let form_data = params + .iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v))) + .collect::>() + .join("&"); + + // Make HTTP request + let client = reqwest::Client::builder() + .timeout(self.config.timeout) + .build() + .map_err(|e| TokenError::NetworkError(e.to_string()))?; + + let response = client + .post(self.config.provider.token_url()) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .body(form_data) + .send() + .await + .map_err(|e| TokenError::NetworkError(e.to_string()))?; + + if !response.status().is_success() { + let error_text = response.text().await.unwrap_or_default(); + return Err(TokenError::ExchangeFailed(error_text)); + } + + // Parse response + let response_json: serde_json::Value = response + .json() + .await + .map_err(|e| TokenError::InvalidResponse(e.to_string()))?; + + self.parse_token_response(response_json) + } + + /// Parse a token response from JSON. + fn parse_token_response(&self, json: serde_json::Value) -> Result { + let access_token = json + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| TokenError::MissingField("access_token".to_string()))? + .to_string(); + + let token_type = json + .get("token_type") + .and_then(|v| v.as_str()) + .unwrap_or("Bearer") + .to_string(); + + let mut response = TokenResponse::new(access_token, token_type); + + // Optional fields + if let Some(expires_in) = json.get("expires_in").and_then(|v| v.as_u64()) { + response = response.with_expires_in(Duration::from_secs(expires_in)); + } + + if let Some(refresh) = json.get("refresh_token").and_then(|v| v.as_str()) { + response = response.with_refresh_token(refresh.to_string()); + } + + if let Some(id_token) = json.get("id_token").and_then(|v| v.as_str()) { + response = response.with_id_token(id_token.to_string()); + } + + if let Some(scope) = json.get("scope").and_then(|v| v.as_str()) { + let scopes: Vec = scope.split(' ').map(String::from).collect(); + response = response.with_scopes(scopes); + } + + Ok(response) + } + + /// Validate the CSRF state from the callback. + pub fn validate_state(&self, expected: &CsrfState, received: &str) -> Result<(), TokenError> { + if expected.verify(received) { + Ok(()) + } else { + Err(TokenError::InvalidState) + } + } +} + +/// Authorization request containing the URL and security tokens. +#[derive(Debug)] +pub struct AuthorizationRequest { + /// The authorization URL to redirect the user to. + pub url: String, + /// CSRF state token (store this to verify callback). + pub csrf_state: CsrfState, + /// PKCE verifier (store this for token exchange, if PKCE is enabled). + pub pkce_verifier: Option, +} + +impl AuthorizationRequest { + /// Get just the authorization URL. + pub fn url(&self) -> &str { + &self.url + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::oauth2::OAuth2Config; + + #[test] + fn test_authorization_url_google() { + let config = OAuth2Config::google( + "test_client_id", + "test_client_secret", + "https://example.com/callback", + ); + let client = OAuth2Client::new(config); + let auth_req = client.authorization_url(); + + // Check URL structure + assert!(auth_req.url.contains("accounts.google.com")); + assert!(auth_req.url.contains("client_id=test_client_id")); + assert!(auth_req.url.contains("redirect_uri=")); + assert!(auth_req.url.contains("response_type=code")); + assert!(auth_req.url.contains("state=")); + assert!(auth_req.url.contains("code_challenge=")); // PKCE enabled for Google + + // Check CSRF state is generated + assert!(!auth_req.csrf_state.as_str().is_empty()); + + // Check PKCE verifier is generated (Google supports PKCE) + assert!(auth_req.pkce_verifier.is_some()); + } + + #[test] + fn test_authorization_url_github() { + let config = OAuth2Config::github( + "test_client_id", + "test_client_secret", + "https://example.com/callback", + ); + let client = OAuth2Client::new(config); + let auth_req = client.authorization_url(); + + // Check URL structure + assert!(auth_req.url.contains("github.com")); + assert!(auth_req.url.contains("client_id=test_client_id")); + + // GitHub doesn't support PKCE + assert!(auth_req.pkce_verifier.is_none()); + assert!(!auth_req.url.contains("code_challenge=")); + } + + #[test] + fn test_state_validation() { + let config = OAuth2Config::google("id", "secret", "https://example.com/callback"); + let client = OAuth2Client::new(config); + + let state = CsrfState::generate(); + + // Valid state should pass + assert!(client.validate_state(&state, state.as_str()).is_ok()); + + // Invalid state should fail + assert!(matches!( + client.validate_state(&state, "wrong_state"), + Err(TokenError::InvalidState) + )); + } + + #[test] + fn test_parse_token_response() { + let config = OAuth2Config::google("id", "secret", "https://example.com/callback"); + let client = OAuth2Client::new(config); + + let json = serde_json::json!({ + "access_token": "ya29.access_token_here", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "1//refresh_token_here", + "scope": "openid email profile", + "id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9..." + }); + + let result = client.parse_token_response(json); + assert!(result.is_ok()); + + let token = result.unwrap(); + assert_eq!(token.access_token(), "ya29.access_token_here"); + assert_eq!(token.token_type(), "Bearer"); + assert_eq!(token.refresh_token(), Some("1//refresh_token_here")); + assert!(token.id_token().is_some()); + assert!(!token.is_expired()); + } +} diff --git a/crates/rustapi-extras/src/oauth2/config.rs b/crates/rustapi-extras/src/oauth2/config.rs new file mode 100644 index 00000000..dc77f371 --- /dev/null +++ b/crates/rustapi-extras/src/oauth2/config.rs @@ -0,0 +1,193 @@ +//! OAuth2 configuration + +use super::providers::Provider; +use std::collections::HashSet; +use std::time::Duration; + +/// Configuration for OAuth2 authentication. +#[derive(Debug, Clone)] +pub struct OAuth2Config { + /// The OAuth2 provider (includes endpoint URLs). + pub(crate) provider: Provider, + /// Client ID issued by the provider. + pub(crate) client_id: String, + /// Client secret issued by the provider. + pub(crate) client_secret: String, + /// Redirect URI for the authorization callback. + pub(crate) redirect_uri: String, + /// Scopes to request. + pub(crate) scopes: HashSet, + /// Whether to use PKCE (Proof Key for Code Exchange). + pub(crate) use_pkce: bool, + /// Timeout for HTTP requests. + pub(crate) timeout: Duration, +} + +impl OAuth2Config { + /// Create a new OAuth2 configuration with a custom provider. + pub fn new( + provider: Provider, + client_id: impl Into, + client_secret: impl Into, + redirect_uri: impl Into, + ) -> Self { + let provider_clone = provider.clone(); + Self { + scopes: provider.default_scopes(), + use_pkce: provider.supports_pkce(), + provider: provider_clone, + client_id: client_id.into(), + client_secret: client_secret.into(), + redirect_uri: redirect_uri.into(), + timeout: Duration::from_secs(30), + } + } + + /// Create a Google OAuth2 configuration. + pub fn google( + client_id: impl Into, + client_secret: impl Into, + redirect_uri: impl Into, + ) -> Self { + Self::new(Provider::Google, client_id, client_secret, redirect_uri) + } + + /// Create a GitHub OAuth2 configuration. + pub fn github( + client_id: impl Into, + client_secret: impl Into, + redirect_uri: impl Into, + ) -> Self { + Self::new(Provider::GitHub, client_id, client_secret, redirect_uri) + } + + /// Create a Microsoft OAuth2 configuration. + pub fn microsoft( + client_id: impl Into, + client_secret: impl Into, + redirect_uri: impl Into, + ) -> Self { + Self::new(Provider::Microsoft, client_id, client_secret, redirect_uri) + } + + /// Create a Discord OAuth2 configuration. + pub fn discord( + client_id: impl Into, + client_secret: impl Into, + redirect_uri: impl Into, + ) -> Self { + Self::new(Provider::Discord, client_id, client_secret, redirect_uri) + } + + /// Create a custom OAuth2 configuration. + pub fn custom( + auth_url: impl Into, + token_url: impl Into, + client_id: impl Into, + client_secret: impl Into, + redirect_uri: impl Into, + ) -> Self { + Self::new( + Provider::Custom { + auth_url: auth_url.into(), + token_url: token_url.into(), + userinfo_url: None, + }, + client_id, + client_secret, + redirect_uri, + ) + } + + /// Add a scope to request. + pub fn scope(mut self, scope: impl Into) -> Self { + self.scopes.insert(scope.into()); + self + } + + /// Set multiple scopes (replaces existing). + pub fn scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.scopes = scopes.into_iter().map(Into::into).collect(); + self + } + + /// Enable or disable PKCE. + pub fn pkce(mut self, enabled: bool) -> Self { + self.use_pkce = enabled; + self + } + + /// Set the HTTP request timeout. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Get the client ID. + pub fn client_id(&self) -> &str { + &self.client_id + } + + /// Get the redirect URI. + pub fn redirect_uri(&self) -> &str { + &self.redirect_uri + } + + /// Get the provider. + pub fn provider(&self) -> &Provider { + &self.provider + } + + /// Get the scopes. + pub fn get_scopes(&self) -> &HashSet { + &self.scopes + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_google_config() { + let config = OAuth2Config::google("id", "secret", "https://example.com/callback"); + assert_eq!(config.client_id(), "id"); + assert!(config.use_pkce); + assert!(config.scopes.contains("openid")); + } + + #[test] + fn test_scope_builder() { + let config = OAuth2Config::github("id", "secret", "https://example.com/callback") + .scope("repo") + .scope("gist"); + + assert!(config.scopes.contains("repo")); + assert!(config.scopes.contains("gist")); + assert!(config.scopes.contains("user:email")); // Default scope still present + } + + #[test] + fn test_custom_provider() { + let config = OAuth2Config::custom( + "https://auth.example.com/authorize", + "https://auth.example.com/token", + "my_client", + "my_secret", + "https://myapp.com/callback", + ); + + assert_eq!( + config.provider.auth_url(), + "https://auth.example.com/authorize" + ); + assert_eq!( + config.provider.token_url(), + "https://auth.example.com/token" + ); + } +} diff --git a/crates/rustapi-extras/src/oauth2/mod.rs b/crates/rustapi-extras/src/oauth2/mod.rs new file mode 100644 index 00000000..34559720 --- /dev/null +++ b/crates/rustapi-extras/src/oauth2/mod.rs @@ -0,0 +1,38 @@ +//! OAuth2 client integration for RustAPI +//! +//! This module provides OAuth2 authentication support with built-in +//! provider presets for common identity providers. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_extras::oauth2::{OAuth2Client, OAuth2Config, Provider}; +//! +//! // Using a preset provider +//! let config = OAuth2Config::google( +//! "client_id", +//! "client_secret", +//! "https://myapp.com/auth/callback", +//! ); +//! +//! let client = OAuth2Client::new(config); +//! +//! // Generate authorization URL +//! let auth_request = client.authorization_url(); +//! let auth_url = auth_request.url(); +//! let csrf_state = &auth_request.csrf_state; +//! let pkce_verifier = &auth_request.pkce_verifier; +//! +//! // After user authorization, exchange the code +//! // let tokens = client.exchange_code("auth_code", pkce_verifier.as_ref()).await?; +//! ``` + +mod client; +mod config; +mod providers; +mod tokens; + +pub use client::{AuthorizationRequest, OAuth2Client}; +pub use config::OAuth2Config; +pub use providers::Provider; +pub use tokens::{CsrfState, PkceVerifier, TokenError, TokenResponse}; diff --git a/crates/rustapi-extras/src/oauth2/providers.rs b/crates/rustapi-extras/src/oauth2/providers.rs new file mode 100644 index 00000000..8f363dc1 --- /dev/null +++ b/crates/rustapi-extras/src/oauth2/providers.rs @@ -0,0 +1,133 @@ +//! OAuth2 provider presets +//! +//! Pre-configured settings for common OAuth2 providers. + +use std::collections::HashSet; + +/// Supported OAuth2 providers with pre-configured endpoints. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Provider { + /// Google OAuth2 + Google, + /// GitHub OAuth2 + GitHub, + /// Microsoft (Azure AD) OAuth2 + Microsoft, + /// Discord OAuth2 + Discord, + /// Custom provider with manual configuration + Custom { + /// Authorization endpoint URL + auth_url: String, + /// Token endpoint URL + token_url: String, + /// User info endpoint URL (optional) + userinfo_url: Option, + }, +} + +impl Provider { + /// Get the authorization endpoint URL for this provider. + pub fn auth_url(&self) -> &str { + match self { + Provider::Google => "https://accounts.google.com/o/oauth2/v2/auth", + Provider::GitHub => "https://github.com/login/oauth/authorize", + Provider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + Provider::Discord => "https://discord.com/api/oauth2/authorize", + Provider::Custom { auth_url, .. } => auth_url, + } + } + + /// Get the token endpoint URL for this provider. + pub fn token_url(&self) -> &str { + match self { + Provider::Google => "https://oauth2.googleapis.com/token", + Provider::GitHub => "https://github.com/login/oauth/access_token", + Provider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token", + Provider::Discord => "https://discord.com/api/oauth2/token", + Provider::Custom { token_url, .. } => token_url, + } + } + + /// Get the user info endpoint URL for this provider (if available). + pub fn userinfo_url(&self) -> Option<&str> { + match self { + Provider::Google => Some("https://www.googleapis.com/oauth2/v3/userinfo"), + Provider::GitHub => Some("https://api.github.com/user"), + Provider::Microsoft => Some("https://graph.microsoft.com/v1.0/me"), + Provider::Discord => Some("https://discord.com/api/users/@me"), + Provider::Custom { userinfo_url, .. } => userinfo_url.as_deref(), + } + } + + /// Get default scopes for this provider. + pub fn default_scopes(&self) -> HashSet { + match self { + Provider::Google => ["openid", "email", "profile"] + .iter() + .map(|s| s.to_string()) + .collect(), + Provider::GitHub => ["user:email", "read:user"] + .iter() + .map(|s| s.to_string()) + .collect(), + Provider::Microsoft => ["openid", "email", "profile", "User.Read"] + .iter() + .map(|s| s.to_string()) + .collect(), + Provider::Discord => ["identify", "email"] + .iter() + .map(|s| s.to_string()) + .collect(), + Provider::Custom { .. } => HashSet::new(), + } + } + + /// Check if this provider supports PKCE (Proof Key for Code Exchange). + pub fn supports_pkce(&self) -> bool { + match self { + Provider::Google => true, + Provider::GitHub => false, // GitHub doesn't support PKCE yet + Provider::Microsoft => true, + Provider::Discord => true, + Provider::Custom { .. } => true, // Assume custom supports PKCE + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_google_provider() { + let provider = Provider::Google; + assert!(provider.auth_url().contains("google.com")); + assert!(provider.token_url().contains("googleapis.com")); + assert!(provider.supports_pkce()); + assert!(provider.default_scopes().contains("openid")); + } + + #[test] + fn test_github_provider() { + let provider = Provider::GitHub; + assert!(provider.auth_url().contains("github.com")); + assert!(!provider.supports_pkce()); + assert!(provider.default_scopes().contains("user:email")); + } + + #[test] + fn test_custom_provider() { + let provider = Provider::Custom { + auth_url: "https://custom.example.com/auth".to_string(), + token_url: "https://custom.example.com/token".to_string(), + userinfo_url: Some("https://custom.example.com/userinfo".to_string()), + }; + assert_eq!(provider.auth_url(), "https://custom.example.com/auth"); + assert_eq!(provider.token_url(), "https://custom.example.com/token"); + assert_eq!( + provider.userinfo_url(), + Some("https://custom.example.com/userinfo") + ); + } +} diff --git a/crates/rustapi-extras/src/oauth2/tokens.rs b/crates/rustapi-extras/src/oauth2/tokens.rs new file mode 100644 index 00000000..c18e4baa --- /dev/null +++ b/crates/rustapi-extras/src/oauth2/tokens.rs @@ -0,0 +1,273 @@ +//! OAuth2 token types and errors + +use std::time::{Duration, Instant}; +use thiserror::Error; + +/// OAuth2 token response from the authorization server. +#[derive(Debug, Clone)] +pub struct TokenResponse { + /// The access token. + access_token: String, + /// The token type (usually "Bearer"). + token_type: String, + /// Token expiration time (if provided). + expires_at: Option, + /// Refresh token (if provided). + refresh_token: Option, + /// Scopes granted (if different from requested). + scopes: Option>, + /// ID token for OpenID Connect (if provided). + id_token: Option, +} + +impl TokenResponse { + /// Create a new token response. + pub fn new(access_token: String, token_type: String) -> Self { + Self { + access_token, + token_type, + expires_at: None, + refresh_token: None, + scopes: None, + id_token: None, + } + } + + /// Set the expiration time. + pub fn with_expires_in(mut self, expires_in: Duration) -> Self { + self.expires_at = Some(Instant::now() + expires_in); + self + } + + /// Set the refresh token. + pub fn with_refresh_token(mut self, refresh_token: String) -> Self { + self.refresh_token = Some(refresh_token); + self + } + + /// Set the scopes. + pub fn with_scopes(mut self, scopes: Vec) -> Self { + self.scopes = Some(scopes); + self + } + + /// Set the ID token. + pub fn with_id_token(mut self, id_token: String) -> Self { + self.id_token = Some(id_token); + self + } + + /// Get the access token. + pub fn access_token(&self) -> &str { + &self.access_token + } + + /// Get the token type. + pub fn token_type(&self) -> &str { + &self.token_type + } + + /// Check if the token is expired. + pub fn is_expired(&self) -> bool { + match self.expires_at { + Some(expires_at) => Instant::now() >= expires_at, + None => false, // If no expiration, assume not expired + } + } + + /// Get the refresh token (if present). + pub fn refresh_token(&self) -> Option<&str> { + self.refresh_token.as_deref() + } + + /// Get the ID token (if present, for OpenID Connect). + pub fn id_token(&self) -> Option<&str> { + self.id_token.as_deref() + } + + /// Get the scopes (if provided in response). + pub fn scopes(&self) -> Option<&[String]> { + self.scopes.as_deref() + } + + /// Get the time remaining until expiration. + pub fn expires_in(&self) -> Option { + self.expires_at + .and_then(|exp| exp.checked_duration_since(Instant::now())) + } + + /// Get the Authorization header value. + pub fn authorization_header(&self) -> String { + format!("{} {}", self.token_type, self.access_token) + } +} + +/// Errors that can occur during OAuth2 operations. +#[derive(Debug, Error)] +pub enum TokenError { + /// The authorization request was denied. + #[error("Authorization denied: {0}")] + AuthorizationDenied(String), + + /// Invalid authorization code. + #[error("Invalid authorization code")] + InvalidCode, + + /// Invalid CSRF state. + #[error("Invalid CSRF state - possible CSRF attack")] + InvalidState, + + /// Token exchange failed. + #[error("Token exchange failed: {0}")] + ExchangeFailed(String), + + /// Token refresh failed. + #[error("Token refresh failed: {0}")] + RefreshFailed(String), + + /// Network error. + #[error("Network error: {0}")] + NetworkError(String), + + /// Invalid response from the authorization server. + #[error("Invalid response: {0}")] + InvalidResponse(String), + + /// Token is expired. + #[error("Token is expired")] + TokenExpired, + + /// Missing required field in response. + #[error("Missing required field: {0}")] + MissingField(String), +} + +/// PKCE (Proof Key for Code Exchange) verifier. +#[derive(Debug, Clone)] +pub struct PkceVerifier { + verifier: String, + challenge: String, + method: String, +} + +impl PkceVerifier { + /// Generate a new PKCE verifier with S256 challenge. + pub fn generate() -> Self { + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; + use rand::{rngs::OsRng, RngCore}; + + // Generate 32 random bytes for the verifier + let mut verifier_bytes = [0u8; 32]; + OsRng.fill_bytes(&mut verifier_bytes); + let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); + + // Create S256 challenge: BASE64URL(SHA256(verifier)) + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(verifier.as_bytes()); + let hash = hasher.finalize(); + let challenge = URL_SAFE_NO_PAD.encode(hash); + + Self { + verifier, + challenge, + method: "S256".to_string(), + } + } + + /// Get the code verifier (for token exchange). + pub fn verifier(&self) -> &str { + &self.verifier + } + + /// Get the code challenge (for authorization request). + pub fn challenge(&self) -> &str { + &self.challenge + } + + /// Get the challenge method (S256). + pub fn method(&self) -> &str { + &self.method + } +} + +/// CSRF state token for OAuth2 authorization. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CsrfState(String); + +impl CsrfState { + /// Generate a new random CSRF state. + pub fn generate() -> Self { + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; + use rand::{rngs::OsRng, RngCore}; + + let mut bytes = [0u8; 16]; + OsRng.fill_bytes(&mut bytes); + Self(URL_SAFE_NO_PAD.encode(bytes)) + } + + /// Create from an existing string. + pub fn new(state: String) -> Self { + Self(state) + } + + /// Get the state value. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Verify that this state matches another. + pub fn verify(&self, other: &str) -> bool { + // Use constant-time comparison to prevent timing attacks + // For simplicity, we use direct comparison here + self.0 == other + } +} + +impl std::fmt::Display for CsrfState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_token_response() { + let token = TokenResponse::new("access123".to_string(), "Bearer".to_string()) + .with_refresh_token("refresh456".to_string()) + .with_expires_in(Duration::from_secs(3600)); + + assert_eq!(token.access_token(), "access123"); + assert_eq!(token.token_type(), "Bearer"); + assert_eq!(token.refresh_token(), Some("refresh456")); + assert!(!token.is_expired()); + assert_eq!(token.authorization_header(), "Bearer access123"); + } + + #[test] + fn test_pkce_verifier() { + let pkce = PkceVerifier::generate(); + assert!(!pkce.verifier().is_empty()); + assert!(!pkce.challenge().is_empty()); + assert_eq!(pkce.method(), "S256"); + + // Verifier and challenge should be different + assert_ne!(pkce.verifier(), pkce.challenge()); + } + + #[test] + fn test_csrf_state() { + let state1 = CsrfState::generate(); + let state2 = CsrfState::generate(); + + // Each generated state should be unique + assert_ne!(state1, state2); + + // Verification should work + assert!(state1.verify(state1.as_str())); + assert!(!state1.verify(state2.as_str())); + } +} diff --git a/crates/rustapi-extras/src/otel/config.rs b/crates/rustapi-extras/src/otel/config.rs new file mode 100644 index 00000000..715bd52a --- /dev/null +++ b/crates/rustapi-extras/src/otel/config.rs @@ -0,0 +1,278 @@ +//! OpenTelemetry configuration types + +use std::time::Duration; + +/// Exporter type for OpenTelemetry traces +#[derive(Clone, Debug, Default)] +pub enum OtelExporter { + /// OTLP gRPC exporter (default) + #[default] + OtlpGrpc, + /// OTLP HTTP exporter + OtlpHttp, + /// Jaeger exporter + Jaeger, + /// Zipkin exporter + Zipkin, + /// Console exporter (for debugging) + Console, + /// No-op exporter (disabled) + None, +} + +/// Trace sampling strategy +#[derive(Clone, Debug)] +pub enum TraceSampler { + /// Always sample all traces + AlwaysOn, + /// Never sample traces + AlwaysOff, + /// Sample a ratio of traces (0.0 - 1.0) + TraceIdRatio(f64), + /// Sample based on parent span decision + ParentBased, +} + +impl Default for TraceSampler { + fn default() -> Self { + Self::AlwaysOn + } +} + +/// OpenTelemetry configuration +#[derive(Clone, Debug)] +pub struct OtelConfig { + /// Service name for traces + pub service_name: String, + /// Service version + pub service_version: Option, + /// Service namespace + pub service_namespace: Option, + /// Deployment environment (e.g., "production", "staging") + pub deployment_environment: Option, + /// OTLP endpoint URL + pub endpoint: Option, + /// Exporter type + pub exporter: OtelExporter, + /// Trace sampler configuration + pub sampler: TraceSampler, + /// Export timeout + pub export_timeout: Duration, + /// Export interval for batch exporter + pub export_interval: Duration, + /// Maximum queue size for batch exporter + pub max_queue_size: usize, + /// Maximum export batch size + pub max_export_batch_size: usize, + /// Whether to enable metrics collection + pub enable_metrics: bool, + /// Whether to propagate W3C trace context + pub propagate_context: bool, + /// Additional resource attributes + pub resource_attributes: Vec<(String, String)>, + /// Headers to include in traces + pub trace_headers: Vec, + /// Paths to exclude from tracing + pub exclude_paths: Vec, +} + +impl Default for OtelConfig { + fn default() -> Self { + Self { + service_name: "rustapi-service".to_string(), + service_version: None, + service_namespace: None, + deployment_environment: None, + endpoint: None, + exporter: OtelExporter::default(), + sampler: TraceSampler::default(), + export_timeout: Duration::from_secs(30), + export_interval: Duration::from_secs(5), + max_queue_size: 2048, + max_export_batch_size: 512, + enable_metrics: true, + propagate_context: true, + resource_attributes: Vec::new(), + trace_headers: vec![ + "user-agent".to_string(), + "content-type".to_string(), + "x-request-id".to_string(), + ], + exclude_paths: vec!["/health".to_string(), "/metrics".to_string()], + } + } +} + +impl OtelConfig { + /// Create a new OtelConfig builder + pub fn builder() -> OtelConfigBuilder { + OtelConfigBuilder::default() + } +} + +/// Builder for OtelConfig +#[derive(Default)] +pub struct OtelConfigBuilder { + config: OtelConfig, +} + +impl OtelConfigBuilder { + /// Set the service name + pub fn service_name(mut self, name: impl Into) -> Self { + self.config.service_name = name.into(); + self + } + + /// Set the service version + pub fn service_version(mut self, version: impl Into) -> Self { + self.config.service_version = Some(version.into()); + self + } + + /// Set the service namespace + pub fn service_namespace(mut self, namespace: impl Into) -> Self { + self.config.service_namespace = Some(namespace.into()); + self + } + + /// Set the deployment environment + pub fn deployment_environment(mut self, env: impl Into) -> Self { + self.config.deployment_environment = Some(env.into()); + self + } + + /// Set the OTLP endpoint URL + pub fn endpoint(mut self, endpoint: impl Into) -> Self { + self.config.endpoint = Some(endpoint.into()); + self + } + + /// Set the exporter type + pub fn exporter(mut self, exporter: OtelExporter) -> Self { + self.config.exporter = exporter; + self + } + + /// Set the trace sampler + pub fn sampler(mut self, sampler: TraceSampler) -> Self { + self.config.sampler = sampler; + self + } + + /// Set the export timeout + pub fn export_timeout(mut self, timeout: Duration) -> Self { + self.config.export_timeout = timeout; + self + } + + /// Set the export interval + pub fn export_interval(mut self, interval: Duration) -> Self { + self.config.export_interval = interval; + self + } + + /// Set the maximum queue size + pub fn max_queue_size(mut self, size: usize) -> Self { + self.config.max_queue_size = size; + self + } + + /// Set the maximum export batch size + pub fn max_export_batch_size(mut self, size: usize) -> Self { + self.config.max_export_batch_size = size; + self + } + + /// Enable or disable metrics collection + pub fn enable_metrics(mut self, enabled: bool) -> Self { + self.config.enable_metrics = enabled; + self + } + + /// Enable or disable context propagation + pub fn propagate_context(mut self, enabled: bool) -> Self { + self.config.propagate_context = enabled; + self + } + + /// Add a resource attribute + pub fn resource_attribute(mut self, key: impl Into, value: impl Into) -> Self { + self.config + .resource_attributes + .push((key.into(), value.into())); + self + } + + /// Add multiple resource attributes + pub fn resource_attributes(mut self, attrs: Vec<(String, String)>) -> Self { + self.config.resource_attributes.extend(attrs); + self + } + + /// Add a header to trace + pub fn trace_header(mut self, header: impl Into) -> Self { + self.config.trace_headers.push(header.into()); + self + } + + /// Add multiple headers to trace + pub fn trace_headers(mut self, headers: Vec) -> Self { + self.config.trace_headers.extend(headers); + self + } + + /// Add a path to exclude from tracing + pub fn exclude_path(mut self, path: impl Into) -> Self { + self.config.exclude_paths.push(path.into()); + self + } + + /// Add multiple paths to exclude + pub fn exclude_paths(mut self, paths: Vec) -> Self { + self.config.exclude_paths.extend(paths); + self + } + + /// Build the OtelConfig + pub fn build(self) -> OtelConfig { + self.config + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = OtelConfig::default(); + assert_eq!(config.service_name, "rustapi-service"); + assert!(config.propagate_context); + assert!(config.enable_metrics); + } + + #[test] + fn test_builder() { + let config = OtelConfig::builder() + .service_name("my-service") + .service_version("1.0.0") + .endpoint("http://localhost:4317") + .exporter(OtelExporter::OtlpGrpc) + .sampler(TraceSampler::TraceIdRatio(0.5)) + .resource_attribute("env", "production") + .exclude_path("/ready") + .build(); + + assert_eq!(config.service_name, "my-service"); + assert_eq!(config.service_version, Some("1.0.0".to_string())); + assert_eq!(config.endpoint, Some("http://localhost:4317".to_string())); + assert_eq!(config.resource_attributes.len(), 1); + assert!(config.exclude_paths.contains(&"/ready".to_string())); + } + + #[test] + fn test_sampler_default() { + let sampler = TraceSampler::default(); + matches!(sampler, TraceSampler::AlwaysOn); + } +} diff --git a/crates/rustapi-extras/src/otel/layer.rs b/crates/rustapi-extras/src/otel/layer.rs new file mode 100644 index 00000000..e5809afd --- /dev/null +++ b/crates/rustapi-extras/src/otel/layer.rs @@ -0,0 +1,322 @@ +//! OpenTelemetry middleware layer + +use super::config::OtelConfig; +use super::propagation::{extract_trace_context, propagate_trace_context, TraceContext}; +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::future::Future; +use std::pin::Pin; +use std::time::Instant; + +/// OpenTelemetry middleware layer for distributed tracing +#[derive(Clone)] +pub struct OtelLayer { + config: OtelConfig, +} + +impl OtelLayer { + /// Create a new OtelLayer with the given configuration + pub fn new(config: OtelConfig) -> Self { + Self { config } + } + + /// Create a new OtelLayer with default configuration + pub fn default_with_service(service_name: impl Into) -> Self { + Self { + config: OtelConfig::builder().service_name(service_name).build(), + } + } + + /// Check if a path should be excluded from tracing + fn should_exclude(&self, path: &str) -> bool { + self.config + .exclude_paths + .iter() + .any(|excluded| path.starts_with(excluded)) + } + + /// Extract header values for tracing + fn extract_trace_headers(&self, request: &Request) -> Vec<(String, String)> { + let mut headers = Vec::new(); + for header_name in &self.config.trace_headers { + if let Some(value) = request.headers().get(header_name.as_str()) { + if let Ok(val) = value.to_str() { + headers.push((header_name.clone(), val.to_string())); + } + } + } + headers + } +} + +impl MiddlewareLayer for OtelLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + let uri = req.uri().to_string(); + let method = req.method().to_string(); + + // Check if this path should be excluded + let path = req.uri().path(); + if self.should_exclude(path) { + return Box::pin(async move { next(req).await }); + } + + // Extract or create trace context + let trace_context = extract_trace_context(&req); + let trace_headers = self.extract_trace_headers(&req); + + Box::pin(async move { + let start = Instant::now(); + + // Create span for this request + let span_name = format!("{} {}", method, path_pattern(&uri)); + + // Log span start + tracing::info_span!( + "http_request", + otel_name = %span_name, + http_method = %method, + http_url = %uri, + http_route = %path_pattern(&uri), + trace_id = %trace_context.trace_id, + span_id = %trace_context.span_id, + parent_span_id = trace_context.parent_span_id.as_deref().unwrap_or("none"), + service_name = %config.service_name, + ); + + // Store trace context in request extensions for downstream use + let mut req = req; + req.extensions_mut().insert(trace_context.clone()); + + // Call the next middleware/handler + let mut response = next(req).await; + + // Calculate duration + let duration = start.elapsed(); + let status = response.status().as_u16(); + + // Determine span status based on HTTP status + let (span_status, error) = if status >= 500 { + ("ERROR", true) + } else if status >= 400 { + ("UNSET", false) + } else { + ("OK", false) + }; + + // Log span end with metrics + tracing::info!( + target: "otel", + trace_id = %trace_context.trace_id, + span_id = %trace_context.span_id, + http_method = %method, + http_url = %uri, + http_status_code = status, + duration_ms = duration.as_millis() as u64, + otel_status = span_status, + error = error, + service_name = %config.service_name, + "request completed" + ); + + // Log trace headers if configured + for (name, value) in &trace_headers { + tracing::debug!( + target: "otel", + trace_id = %trace_context.trace_id, + header_name = %name, + header_value = %value, + "traced header" + ); + } + + // Propagate trace context to response if enabled + if config.propagate_context { + propagate_trace_context(response.headers_mut(), &trace_context); + } + + response + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +/// Extract a normalized path pattern from the URI +/// Replaces numeric path segments with {id} for better grouping +fn path_pattern(uri: &str) -> String { + let path = uri.split('?').next().unwrap_or(uri); + let segments: Vec<&str> = path.split('/').collect(); + + segments + .into_iter() + .map(|segment| { + // Replace numeric IDs with {id} + if segment.chars().all(|c| c.is_ascii_digit()) && !segment.is_empty() { + "{id}" + // Replace UUIDs with {uuid} + } else if is_uuid(segment) { + "{uuid}" + } else { + segment + } + }) + .collect::>() + .join("/") +} + +/// Check if a string looks like a UUID +fn is_uuid(s: &str) -> bool { + if s.len() != 36 { + return false; + } + let parts: Vec<&str> = s.split('-').collect(); + if parts.len() != 5 { + return false; + } + parts + .iter() + .all(|p| p.chars().all(|c| c.is_ascii_hexdigit())) +} + +/// Trait for storing and retrieving trace context from requests +#[allow(dead_code)] +pub trait TraceContextExt { + /// Get the trace context from the request + fn trace_context(&self) -> Option<&TraceContext>; +} + +impl TraceContextExt for Request { + fn trace_context(&self) -> Option<&TraceContext> { + self.extensions().get::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::sync::Arc; + + #[test] + fn test_path_pattern_numeric_ids() { + assert_eq!(path_pattern("/users/123"), "/users/{id}"); + assert_eq!( + path_pattern("/users/123/posts/456"), + "/users/{id}/posts/{id}" + ); + } + + #[test] + fn test_path_pattern_uuids() { + assert_eq!( + path_pattern("/users/550e8400-e29b-41d4-a716-446655440000"), + "/users/{uuid}" + ); + } + + #[test] + fn test_path_pattern_with_query() { + assert_eq!(path_pattern("/users/123?page=1"), "/users/{id}"); + } + + #[test] + fn test_is_uuid() { + assert!(is_uuid("550e8400-e29b-41d4-a716-446655440000")); + assert!(!is_uuid("not-a-uuid")); + assert!(!is_uuid("12345")); + } + + #[tokio::test] + async fn test_otel_layer_basic() { + let config = OtelConfig::builder().service_name("test-service").build(); + let layer = OtelLayer::new(config); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/api/users") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } + + #[tokio::test] + async fn test_otel_layer_excludes_health() { + let config = OtelConfig::builder() + .service_name("test-service") + .exclude_path("/health") + .build(); + let layer = OtelLayer::new(config); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/health") + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert_eq!(response.status(), 200); + } + + #[tokio::test] + async fn test_trace_context_propagation() { + let config = OtelConfig::builder() + .service_name("test-service") + .propagate_context(true) + .build(); + let layer = OtelLayer::new(config); + + let next: BoxedNext = Arc::new(|_req: Request| { + Box::pin(async { + http::Response::builder() + .status(200) + .body(http_body_util::Full::new(Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = http::Request::builder() + .method("GET") + .uri("/api/test") + .header( + "traceparent", + "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01", + ) + .body(()) + .unwrap(); + let req = Request::from_http_request(req, Bytes::new()); + + let response = layer.call(req, next).await; + assert!(response.headers().contains_key("x-trace-id")); + } +} diff --git a/crates/rustapi-extras/src/otel/mod.rs b/crates/rustapi-extras/src/otel/mod.rs new file mode 100644 index 00000000..0f60cacc --- /dev/null +++ b/crates/rustapi-extras/src/otel/mod.rs @@ -0,0 +1,38 @@ +//! OpenTelemetry Integration for RustAPI +//! +//! This module provides OpenTelemetry integration with support for: +//! - Distributed tracing with OTLP exporter +//! - Metrics collection +//! - Trace context propagation (W3C Trace Context) +//! - Automatic span creation for HTTP requests +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::RustApi; +//! use rustapi_extras::otel::{OtelConfig, OtelLayer}; +//! +//! #[tokio::main] +//! async fn main() { +//! let config = OtelConfig::builder() +//! .service_name("my-api") +//! .endpoint("http://localhost:4317") +//! .build(); +//! +//! let app = RustApi::new() +//! .layer(OtelLayer::new(config)) +//! .run("0.0.0.0:3000") +//! .await +//! .unwrap(); +//! } +//! ``` + +mod config; +mod layer; +mod propagation; + +pub use config::{OtelConfig, OtelConfigBuilder, OtelExporter, TraceSampler}; +pub use layer::OtelLayer; +pub use propagation::{ + extract_trace_context, inject_trace_context, propagate_trace_context, TraceContext, +}; diff --git a/crates/rustapi-extras/src/otel/propagation.rs b/crates/rustapi-extras/src/otel/propagation.rs new file mode 100644 index 00000000..7f3b212f --- /dev/null +++ b/crates/rustapi-extras/src/otel/propagation.rs @@ -0,0 +1,313 @@ +//! W3C Trace Context propagation utilities +//! +//! This module implements trace context propagation according to the +//! W3C Trace Context specification for distributed tracing. + +use rustapi_core::Request; +use std::fmt; + +/// W3C Trace Context header name for traceparent +pub const TRACEPARENT_HEADER: &str = "traceparent"; + +/// W3C Trace Context header name for tracestate +pub const TRACESTATE_HEADER: &str = "tracestate"; + +/// Correlation ID header name +pub const CORRELATION_ID_HEADER: &str = "x-correlation-id"; + +/// Request ID header name +pub const REQUEST_ID_HEADER: &str = "x-request-id"; + +/// Trace context information +#[derive(Clone, Debug, Default)] +pub struct TraceContext { + /// Trace ID (128-bit, hex encoded) + pub trace_id: String, + /// Span ID (64-bit, hex encoded) + pub span_id: String, + /// Parent span ID (64-bit, hex encoded) - if this is a child span + pub parent_span_id: Option, + /// Trace flags (8 bits) + pub trace_flags: u8, + /// Trace state (vendor-specific data) + pub trace_state: Option, + /// Correlation ID for request tracking + pub correlation_id: Option, +} + +impl TraceContext { + /// Create a new trace context with generated IDs + pub fn new() -> Self { + Self { + trace_id: Self::generate_trace_id(), + span_id: Self::generate_span_id(), + parent_span_id: None, + trace_flags: 0x01, // Sampled flag + trace_state: None, + correlation_id: Some(Self::generate_correlation_id()), + } + } + + /// Create a child span context from a parent + pub fn child(&self) -> Self { + Self { + trace_id: self.trace_id.clone(), + span_id: Self::generate_span_id(), + parent_span_id: Some(self.span_id.clone()), + trace_flags: self.trace_flags, + trace_state: self.trace_state.clone(), + correlation_id: self.correlation_id.clone(), + } + } + + /// Generate a new trace ID (128-bit, 32 hex chars) + pub fn generate_trace_id() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let random: u64 = rand_simple(); + format!("{:016x}{:016x}", now as u64, random) + } + + /// Generate a new span ID (64-bit, 16 hex chars) + pub fn generate_span_id() -> String { + let random: u64 = rand_simple(); + format!("{:016x}", random) + } + + /// Generate a correlation ID + pub fn generate_correlation_id() -> String { + let random: u64 = rand_simple(); + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + format!("{:x}-{:x}", timestamp, random) + } + + /// Check if trace is sampled + pub fn is_sampled(&self) -> bool { + self.trace_flags & 0x01 == 0x01 + } + + /// Set sampled flag + pub fn set_sampled(&mut self, sampled: bool) { + if sampled { + self.trace_flags |= 0x01; + } else { + self.trace_flags &= !0x01; + } + } + + /// Format as W3C traceparent header value + pub fn to_traceparent(&self) -> String { + format!( + "00-{}-{}-{:02x}", + self.trace_id, self.span_id, self.trace_flags + ) + } + + /// Parse from W3C traceparent header value + pub fn from_traceparent(value: &str) -> Option { + let parts: Vec<&str> = value.split('-').collect(); + if parts.len() != 4 { + return None; + } + + let version = parts[0]; + if version != "00" { + return None; // Only version 00 is supported + } + + let trace_id = parts[1]; + let span_id = parts[2]; + let flags = parts[3]; + + // Validate lengths + if trace_id.len() != 32 || span_id.len() != 16 || flags.len() != 2 { + return None; + } + + // Parse flags + let trace_flags = u8::from_str_radix(flags, 16).ok()?; + + Some(Self { + trace_id: trace_id.to_string(), + span_id: span_id.to_string(), + parent_span_id: None, + trace_flags, + trace_state: None, + correlation_id: None, + }) + } +} + +impl fmt::Display for TraceContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_traceparent()) + } +} + +/// Extract trace context from incoming request headers +pub fn extract_trace_context(request: &Request) -> TraceContext { + let headers = request.headers(); + + // Try to extract traceparent header + let mut context = headers + .get(TRACEPARENT_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(TraceContext::from_traceparent) + .unwrap_or_else(TraceContext::new); + + // Extract tracestate if present + if let Some(state) = headers.get(TRACESTATE_HEADER).and_then(|v| v.to_str().ok()) { + context.trace_state = Some(state.to_string()); + } + + // Extract correlation ID from various headers + context.correlation_id = headers + .get(CORRELATION_ID_HEADER) + .or_else(|| headers.get(REQUEST_ID_HEADER)) + .or_else(|| headers.get("x-amzn-trace-id")) + .and_then(|v| v.to_str().ok()) + .map(String::from) + .or_else(|| Some(TraceContext::generate_correlation_id())); + + context +} + +/// Inject trace context into outgoing request headers +pub fn inject_trace_context(headers: &mut http::HeaderMap, context: &TraceContext) { + use http::header::HeaderValue; + + // Inject traceparent + if let Ok(value) = HeaderValue::from_str(&context.to_traceparent()) { + headers.insert(TRACEPARENT_HEADER, value); + } + + // Inject tracestate if present + if let Some(ref state) = context.trace_state { + if let Ok(value) = HeaderValue::from_str(state) { + headers.insert(TRACESTATE_HEADER, value); + } + } + + // Inject correlation ID + if let Some(ref correlation_id) = context.correlation_id { + if let Ok(value) = HeaderValue::from_str(correlation_id) { + headers.insert(CORRELATION_ID_HEADER, value); + } + } +} + +/// Propagate trace context to response headers +pub fn propagate_trace_context(response_headers: &mut http::HeaderMap, context: &TraceContext) { + use http::header::HeaderValue; + + // Include trace ID in response for debugging + if let Ok(value) = HeaderValue::from_str(&context.trace_id) { + response_headers.insert("x-trace-id", value); + } + + // Include correlation ID in response + if let Some(ref correlation_id) = context.correlation_id { + if let Ok(value) = HeaderValue::from_str(correlation_id) { + response_headers.insert(CORRELATION_ID_HEADER, value); + } + } +} + +/// Simple random number generator (using XorShift) +fn rand_simple() -> u64 { + use std::cell::Cell; + use std::time::{SystemTime, UNIX_EPOCH}; + + thread_local! { + static STATE: Cell = Cell::new( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() as u64 + ); + } + + STATE.with(|state| { + let mut x = state.get(); + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + state.set(x); + x + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_trace_context_new() { + let ctx = TraceContext::new(); + assert_eq!(ctx.trace_id.len(), 32); + assert_eq!(ctx.span_id.len(), 16); + assert!(ctx.is_sampled()); + assert!(ctx.correlation_id.is_some()); + } + + #[test] + fn test_trace_context_child() { + let parent = TraceContext::new(); + let child = parent.child(); + + assert_eq!(child.trace_id, parent.trace_id); + assert_ne!(child.span_id, parent.span_id); + assert_eq!(child.parent_span_id, Some(parent.span_id)); + assert_eq!(child.correlation_id, parent.correlation_id); + } + + #[test] + fn test_traceparent_round_trip() { + let ctx = TraceContext::new(); + let traceparent = ctx.to_traceparent(); + let parsed = TraceContext::from_traceparent(&traceparent).unwrap(); + + assert_eq!(parsed.trace_id, ctx.trace_id); + assert_eq!(parsed.span_id, ctx.span_id); + assert_eq!(parsed.trace_flags, ctx.trace_flags); + } + + #[test] + fn test_traceparent_parsing() { + let traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"; + let ctx = TraceContext::from_traceparent(traceparent).unwrap(); + + assert_eq!(ctx.trace_id, "0af7651916cd43dd8448eb211c80319c"); + assert_eq!(ctx.span_id, "b7ad6b7169203331"); + assert_eq!(ctx.trace_flags, 0x01); + assert!(ctx.is_sampled()); + } + + #[test] + fn test_invalid_traceparent() { + // Invalid version + assert!(TraceContext::from_traceparent("01-abc-def-00").is_none()); + // Wrong number of parts + assert!(TraceContext::from_traceparent("00-abc-def").is_none()); + // Invalid lengths + assert!(TraceContext::from_traceparent("00-abc-def-00").is_none()); + } + + #[test] + fn test_sampled_flag() { + let mut ctx = TraceContext::new(); + assert!(ctx.is_sampled()); + + ctx.set_sampled(false); + assert!(!ctx.is_sampled()); + + ctx.set_sampled(true); + assert!(ctx.is_sampled()); + } +} diff --git a/crates/rustapi-extras/src/retry.rs b/crates/rustapi-extras/src/retry.rs new file mode 100644 index 00000000..b0320485 --- /dev/null +++ b/crates/rustapi-extras/src/retry.rs @@ -0,0 +1,296 @@ +//! Retry middleware with exponential backoff +//! +//! This module provides automatic retry logic for failed requests with configurable +//! backoff strategies and max attempts. +//! +//! # Example +//! +//! ```rust,no_run +//! use rustapi_core::RustApi; +//! use rustapi_extras::RetryLayer; +//! use std::time::Duration; +//! +//! #[tokio::main] +//! async fn main() { +//! let app = RustApi::new() +//! .layer(Box::new( +//! RetryLayer::new() +//! .max_attempts(3) +//! .initial_backoff(Duration::from_millis(100)) +//! )) +//! .run("0.0.0.0:3000") +//! .await +//! .unwrap(); +//! } +//! ``` + +use rustapi_core::{ + middleware::{BoxedNext, MiddlewareLayer}, + Request, Response, +}; +use std::future::Future; +use std::pin::Pin; +use std::time::Duration; + +/// Retry strategy for failed requests +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RetryStrategy { + /// Fixed delay between retries + Fixed, + /// Exponential back off (delay doubles each time) + Exponential, + /// Linear backoff (delay increases linearly) + Linear, +} + +/// Configuration for retry behavior +#[derive(Clone)] +pub struct RetryConfig { + /// Maximum number of retry attempts (excluding the initial attempt) + pub max_attempts: u32, + /// Initial backoff duration + pub initial_backoff: Duration, + /// Maximum backoff duration (cap for exponential/linear growth) + pub max_backoff: Duration, + /// Retry strategy to use + pub strategy: RetryStrategy, + /// Which HTTP status codes to retry + pub retryable_statuses: Vec, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_attempts: 3, + initial_backoff: Duration::from_millis(100), + max_backoff: Duration::from_secs(30), + strategy: RetryStrategy::Exponential, + // Retry on 5xx errors and 429 (Too Many Requests) + retryable_statuses: vec![429, 500, 502, 503, 504], + } + } +} + +/// Retry middleware layer +#[derive(Clone)] +pub struct RetryLayer { + config: RetryConfig, +} + +impl RetryLayer { + /// Create a new retry layer with default configuration + pub fn new() -> Self { + Self { + config: RetryConfig::default(), + } + } + + /// Set the maximum number of retry attempts + pub fn max_attempts(mut self, attempts: u32) -> Self { + self.config.max_attempts = attempts; + self + } + + /// Set the initial backoff duration + pub fn initial_backoff(mut self, duration: Duration) -> Self { + self.config.initial_backoff = duration; + self + } + + /// Set the maximum backoff duration + pub fn max_backoff(mut self, duration: Duration) -> Self { + self.config.max_backoff = duration; + self + } + + /// Set the retry strategy + pub fn strategy(mut self, strategy: RetryStrategy) -> Self { + self.config.strategy = strategy; + self + } + + /// Set which HTTP status codes should trigger a retry + pub fn retryable_statuses(mut self, statuses: Vec) -> Self { + self.config.retryable_statuses = statuses; + self + } + + /// Calculate backoff duration for a given attempt number + fn calculate_backoff(&self, attempt: u32) -> Duration { + let base = self.config.initial_backoff; + + let calculated = match self.config.strategy { + RetryStrategy::Fixed => base, + RetryStrategy::Exponential => { + // 2^attempt * base + base * 2_u32.saturating_pow(attempt) + } + RetryStrategy::Linear => { + // (attempt + 1) * base + base * (attempt + 1) + } + }; + + // Cap at max_backoff + calculated.min(self.config.max_backoff) + } +} + +impl Default for RetryLayer { + fn default() -> Self { + Self::new() + } +} + +impl MiddlewareLayer for RetryLayer { + fn call( + &self, + req: Request, + next: BoxedNext, + ) -> Pin + Send + 'static>> { + let config = self.config.clone(); + let self_clone = self.clone(); // Clone self to access its methods + + Box::pin(async move { + let mut current_req = req; + + for attempt in 0..=config.max_attempts { + // Determine if we need to clone for a potential future retry + let (req_to_send, next_req_opt) = if attempt < config.max_attempts { + if let Some(cloned) = current_req.try_clone() { + (current_req, Some(cloned)) + } else { + // Cloning failed, we can't retry after this + (current_req, None) + } + } else { + (current_req, None) + }; + + let response = next(req_to_send).await; + let status = response.status().as_u16(); + + // Check if we should retry + if attempt < config.max_attempts + && config.retryable_statuses.contains(&status) + && next_req_opt.is_some() + { + tracing::warn!( + attempt = attempt + 1, + max_attempts = config.max_attempts, + status = status, + "Request failed, retrying..." + ); + + // Restore request for next attempt + current_req = next_req_opt.unwrap(); + + // Calculate and sleep for backoff duration + let backoff = self_clone.calculate_backoff(attempt); + tracing::debug!(backoff_ms = backoff.as_millis(), "Waiting before retry"); + tokio::time::sleep(backoff).await; + + continue; + } else { + // Success or no more retries + if attempt > 0 { + tracing::info!( + attempt = attempt + 1, + status = status, + "Request succeeded after retry" + ); + } + return response; + } + } + + // Should be unreachable if logic is correct, but safe fallback + unreachable!("Retry loop finished without returning response") + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::sync::Arc; + + #[tokio::test] + async fn retry_on_503_error() { + let retry_layer = RetryLayer::new().max_attempts(2); + + let attempt_counter = Arc::new(AtomicU32::new(0)); + let counter_clone = attempt_counter.clone(); + + let next: BoxedNext = Arc::new(move |_req: Request| { + let counter = counter_clone.clone(); + Box::pin(async move { + let attempt = counter.fetch_add(1, Ordering::SeqCst); + + // Fail first 2 times, succeed on 3rd + let status = if attempt < 2 { 503 } else { 200 }; + + http::Response::builder() + .status(status) + .body(http_body_util::Full::new(bytes::Bytes::from("OK"))) + .unwrap() + }) as Pin + Send + 'static>> + }); + + let req = Request::from_http_request( + http::Request::builder() + .method("GET") + .uri("/") + .body(()) + .unwrap(), + Bytes::new(), + ); + + let response = retry_layer.call(req, next).await; + + // Should succeed after retries + assert_eq!(response.status(), 200); + // Should have made 3 attempts total (1 initial + 2 retries) + assert_eq!(attempt_counter.load(Ordering::SeqCst), 3); + } + + #[test] + fn exponential_backoff_calculation() { + let layer = RetryLayer::new() + .strategy(RetryStrategy::Exponential) + .initial_backoff(Duration::from_millis(100)); + + assert_eq!(layer.calculate_backoff(0), Duration::from_millis(100)); // 2^0 * 100 + assert_eq!(layer.calculate_backoff(1), Duration::from_millis(200)); // 2^1 * 100 + assert_eq!(layer.calculate_backoff(2), Duration::from_millis(400)); // 2^2 * 100 + assert_eq!(layer.calculate_backoff(3), Duration::from_millis(800)); // 2^3 * 100 + } + + #[test] + fn linear_backoff_calculation() { + let layer = RetryLayer::new() + .strategy(RetryStrategy::Linear) + .initial_backoff(Duration::from_millis(100)); + + assert_eq!(layer.calculate_backoff(0), Duration::from_millis(100)); // 1 * 100 + assert_eq!(layer.calculate_backoff(1), Duration::from_millis(200)); // 2 * 100 + assert_eq!(layer.calculate_backoff(2), Duration::from_millis(300)); // 3 * 100 + } + + #[test] + fn backoff_respects_max() { + let layer = RetryLayer::new() + .strategy(RetryStrategy::Exponential) + .initial_backoff(Duration::from_secs(1)) + .max_backoff(Duration::from_secs(5)); + + // 2^10 = 1024 seconds, but should be capped at 5 + assert_eq!(layer.calculate_backoff(10), Duration::from_secs(5)); + } +} diff --git a/crates/rustapi-extras/src/sanitization.rs b/crates/rustapi-extras/src/sanitization.rs new file mode 100644 index 00000000..f570b270 --- /dev/null +++ b/crates/rustapi-extras/src/sanitization.rs @@ -0,0 +1,98 @@ +//! Input Sanitization Utilities +//! +//! Provides functions to sanitize user input against XSS and injection attacks. +//! NOTE: This is a basic implementation. For production high-risk apps, use a dedicated crate like `ammonia`. + +/// Sanitizes a string by escaping HTML special characters. +/// +/// Replaces: +/// - `&` -> `&` +/// - `<` -> `<` +/// - `>` -> `>` +/// - `"` -> `"` +/// - `'` -> `'` +pub fn sanitize_html(input: &str) -> String { + let mut output = String::with_capacity(input.len()); + for c in input.chars() { + match c { + '&' => output.push_str("&"), + '<' => output.push_str("<"), + '>' => output.push_str(">"), + '"' => output.push_str("""), + '\'' => output.push_str("'"), + _ => output.push(c), + } + } + output +} + +/// Strip all HTML tags from a string. +pub fn strip_tags(input: &str) -> String { + let mut output = String::with_capacity(input.len()); + let mut inside_tag = false; + + for c in input.chars() { + if c == '<' { + inside_tag = true; + } else if c == '>' { + inside_tag = false; + } else if !inside_tag { + output.push(c); + } + } + + output +} + +/// Recursively sanitizes string fields in a JSON value. +pub fn sanitize_json(value: &mut serde_json::Value) { + match value { + serde_json::Value::String(s) => *s = sanitize_html(s), + serde_json::Value::Array(arr) => { + for v in arr { + sanitize_json(v); + } + } + serde_json::Value::Object(map) => { + for (_, v) in map { + sanitize_json(v); + } + } + _ => {} + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_sanitize_html() { + let input = ""; + let expected = "<script>alert('XSS')</script>"; + assert_eq!(sanitize_html(input), expected); + } + + #[test] + fn test_strip_tags() { + let input = "

Hello World

"; + let expected = "Hello World"; + assert_eq!(strip_tags(input), expected); + } + + #[test] + fn test_sanitize_json() { + let mut data = json!({ + "name": "John", + "age": 30, + "tags": ["