diff --git a/.spelling b/.spelling index e9ceb2a8..0c96de03 100644 --- a/.spelling +++ b/.spelling @@ -137,6 +137,7 @@ interop Interop interoperability interoperate +jitter IOCP IP Kubernetes @@ -181,6 +182,7 @@ non-mockable ns nuget NUMA +observability ohno ok Ok @@ -212,6 +214,7 @@ repo repos Reqwest Reusability +RPC runtime runtimes rustc diff --git a/CHANGELOG.md b/CHANGELOG.md index 9862cfa6..9a9ad24d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Please see each crate's change log below: - [`ohno`](./crates/ohno/CHANGELOG.md) - [`ohno_macros`](./crates/ohno_macros/CHANGELOG.md) - [`recoverable`](./crates/recoverable/CHANGELOG.md) +- [`seatbelt`](./crates/seatbelt/CHANGELOG.md) - [`thread_aware`](./crates/thread_aware/CHANGELOG.md) - [`thread_aware_macros`](./crates/thread_aware_macros/CHANGELOG.md) - [`thread_aware_macros_impl`](./crates/thread_aware_macros_impl/CHANGELOG.md) diff --git a/Cargo.lock b/Cargo.lock index ad68e1cc..60a7650a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -598,6 +598,16 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -944,6 +954,47 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opentelemetry" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b84bcd6ae87133e903af7ef497404dda70c60d0ea14895fc8a5e6722754fc2a0" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "pin-project-lite", + "thiserror 2.0.17", +] + +[[package]] +name = "opentelemetry-stdout" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc8887887e169414f637b18751487cce4e095be787d23fad13c454e2fb1b3811" +dependencies = [ + "chrono", + "opentelemetry", + "opentelemetry_sdk", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ae4f5991976fd48df6d843de219ca6d31b01daaab2dad5af2badeded372bd" +dependencies = [ + "futures-channel", + "futures-executor", + "futures-util", + "opentelemetry", + "percent-encoding", + "rand 0.9.2", + "thiserror 2.0.17", + "tokio", + "tokio-stream", +] + [[package]] name = "os_pipe" version = "1.2.3" @@ -992,6 +1043,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + [[package]] name = "phf" version = "0.11.3" @@ -1298,6 +1355,29 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "seatbelt" +version = "0.2.0" +dependencies = [ + "alloc_tracker", + "criterion", + "fastrand", + "futures", + "http", + "layered", + "mutants", + "ohno", + "opentelemetry", + "opentelemetry-stdout", + "opentelemetry_sdk", + "recoverable", + "static_assertions", + "tick", + "tokio", + "tracing", + 
"tracing-subscriber", +] + [[package]] name = "semver" version = "1.0.27" @@ -1682,6 +1762,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.17" diff --git a/Cargo.toml b/Cargo.toml index fb58951b..975995d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ layered = { path = "crates/layered", default-features = false, version = "0.3.0" ohno = { path = "crates/ohno", default-features = false, version = "0.2.1" } ohno_macros = { path = "crates/ohno_macros", default-features = false, version = "0.2.0" } recoverable = { path = "crates/recoverable", default-features = false, version = "0.1.0" } +seatbelt = { path = "crates/seatbelt", default-features = false, version = "0.2.0" } testing_aids = { path = "crates/testing_aids", default-features = false } thread_aware = { path = "crates/thread_aware", default-features = false, version = "0.6.1" } thread_aware_macros = { path = "crates/thread_aware_macros", default-features = false, version = "0.6.1" } @@ -51,9 +52,11 @@ criterion = { version = "0.7.0", default-features = false } derive_more = { version = "2.0.1", default-features = false } duct = { version = "1.1.1", default-features = false } dynosaur = { version = "0.3.0", default-features = false } +fastrand = { version = "2.3.0", default-features = false, features = ["std"] } futures = { version = "0.3.31", default-features = false } futures-core = { version = "0.3.31", default-features = false } futures-util = { version = "0.3.31", default-features = false } +http = { version = "1.2.0", default-features = false, features = ["std"] } infinity_pool = { version = "0.8.1", default-features = false } insta = { version = "1.44.1", default-features = false } jiff = { version = "0.2.16", default-features = false } @@ -65,6 +68,9 @@ new_zealand = { version = "1.0.1", default-features = false } nm = { version = "0.1.21", default-features = false } num-traits = { version = "0.2.19", default-features = false } once_cell = { version = "1.21.3", default-features = false } +opentelemetry = { version = "0.31.0", default-features = false } +opentelemetry-stdout = { version = "0.31.0", default-features = false } +opentelemetry_sdk = { version = "0.31.0", default-features = false } pin-project-lite = { version = "0.2.13", default-features = false } pretty_assertions = { version = "1.4.1", default-features = false } prettyplease = { version = "0.2.37", default-features = false } diff --git a/README.md b/README.md index d7f6ea59..31896942 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ These are the primary crates built out of this repo: - [`layered`](./crates/layered/README.md) - A foundational service abstraction for building composable, middleware-driven systems. - [`ohno`](./crates/ohno/README.md) - High-quality Rust error handling. - [`recoverable`](./crates/recoverable/README.md) - Recovery information and classification for resilience patterns. +- [`seatbelt`](./crates/seatbelt/README.md) - Resilience and recovery mechanisms for fallible operations. - [`thread_aware`](./crates/thread_aware/README.md) - Facilities to support thread-isolated state. - [`tick`](./crates/tick/README.md) - Provides primitives to interact with and manipulate machine time. 
diff --git a/crates/seatbelt/CHANGELOG.md b/crates/seatbelt/CHANGELOG.md new file mode 100644 index 00000000..cc03dcc8 --- /dev/null +++ b/crates/seatbelt/CHANGELOG.md @@ -0,0 +1,14 @@ +# Changelog + +## [0.2.0] - 2026-01-20 + +Initial release. + +- ✨ Features + + - Timeout middleware for canceling long-running operations + - Retry middleware with constant, linear, and exponential backoff strategies + - Circuit breaker middleware with health-based failure detection and gradual recovery + - OpenTelemetry metrics integration (`metrics` feature) + - Structured logging via tracing (`logs` feature) + - Shared `Context` for clock and telemetry configuration diff --git a/crates/seatbelt/Cargo.toml b/crates/seatbelt/Cargo.toml new file mode 100644 index 00000000..60add28c --- /dev/null +++ b/crates/seatbelt/Cargo.toml @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +[package] +name = "seatbelt" +description = "Resilience and recovery mechanisms for fallible operations." +version = "0.2.0" +readme = "README.md" +keywords = ["oxidizer", "resilience", "layered", "recovery", "retry", "circuit-breaker"] +categories = ["data-structures"] + +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[package.metadata.cargo_check_external_types] +allowed_external_types = [ + "layered::layer::stack::Stack", + "layered::service::Service", + "opentelemetry::metrics::meter::MeterProvider", + "recoverable::Recovery", + "recoverable::RecoveryInfo", + "recoverable::RecoveryKind", + "tick::clock::Clock", + "tower_layer::Layer", +] + +[package.metadata.docs.rs] +all-features = true + +[features] +default = [] +timeout = [] +retry = ["dep:fastrand"] +circuit-breaker = ["dep:fastrand"] +metrics = ["dep:opentelemetry", "opentelemetry/metrics"] +logs = ["dep:tracing"] + +[dependencies] +fastrand = { workspace = true, optional = true } +futures = { workspace = true } +layered = { workspace = true } +opentelemetry = { workspace = true, optional = true } +recoverable = { workspace = true } +tick = { workspace = true } +tracing = { workspace = true, optional = true } + +[dev-dependencies] +alloc_tracker.workspace = true +criterion.workspace = true +fastrand.workspace = true +futures = { workspace = true, features = ["executor"] } +http.workspace = true +layered = { workspace = true } +mutants.workspace = true +ohno = { workspace = true, features = ["app-err"] } +opentelemetry = { workspace = true, default-features = false, features = ["metrics"] } +opentelemetry-stdout = { workspace = true, default-features = false, features = ["metrics", "logs"] } +opentelemetry_sdk = { workspace = true, default-features = false, features = ["metrics", "testing", "experimental_metrics_custom_reader"] } +static_assertions.workspace = true +tick = { workspace = true, features = ["test-util", "tokio"] } +tokio = { workspace = true, features = ["rt", "macros"] } +tracing.workspace = true +tracing-subscriber = { workspace = true, features = ["fmt", "std"] } + +[[example]] +name = "timeout" +required-features = ["timeout"] + +[[example]] +name = "timeout_advanced" +required-features = ["timeout"] + +[[example]] +name = "retry" +required-features = ["retry"] + +[[example]] +name = "retry_advanced" +required-features = ["retry"] + +[[example]] +name = "retry_outage" +required-features = ["retry"] + +[[example]] +name = "resilience_pipeline" +required-features = ["retry", "timeout"] + 
+[[example]] +name = "circuit_breaker" +required-features = ["circuit-breaker", "metrics"] + +[[bench]] +name = "observability" +harness = false +required-features = ["retry", "logs", "metrics"] + +[[bench]] +name = "timeout" +harness = false +required-features = ["timeout"] + +[[bench]] +name = "retry" +harness = false +required-features = ["retry"] + +[[bench]] +name = "circuit_breaker" +harness = false +required-features = ["circuit-breaker"] + +[lints] +workspace = true diff --git a/crates/seatbelt/README.md b/crates/seatbelt/README.md new file mode 100644 index 00000000..3e97fe98 --- /dev/null +++ b/crates/seatbelt/README.md @@ -0,0 +1,124 @@ +
+ Seatbelt Logo + +# Seatbelt + +[![crate.io](https://img.shields.io/crates/v/seatbelt.svg)](https://crates.io/crates/seatbelt) +[![docs.rs](https://docs.rs/seatbelt/badge.svg)](https://docs.rs/seatbelt) +[![MSRV](https://img.shields.io/crates/msrv/seatbelt)](https://crates.io/crates/seatbelt) +[![CI](https://github.com/microsoft/oxidizer/actions/workflows/main.yml/badge.svg?event=push)](https://github.com/microsoft/oxidizer/actions/workflows/main.yml) +[![Coverage](https://codecov.io/gh/microsoft/oxidizer/graph/badge.svg?token=FCUG0EL5TI)](https://codecov.io/gh/microsoft/oxidizer) +[![License](https://img.shields.io/badge/license-MIT-blue.svg)](../../LICENSE) +This crate was developed as part of the Oxidizer project + +
+ +Resilience and recovery mechanisms for fallible operations. + +## Quick Start + +Add resilience to fallible operations, such as RPC calls over the network, with just a few lines of code. +**Retry** handles transient failures and **Timeout** prevents operations from hanging indefinitely: + +```rust +use seatbelt::retry::Retry; +use seatbelt::timeout::Timeout; +use seatbelt::{RecoveryInfo, ResilienceContext}; + +let context = ResilienceContext::new(&clock); +let service = ( + // Retry middleware: Automatically retries failed operations + Retry::layer("retry", &context) + .clone_input() + .recovery_with(|output: &String, _| match output.as_str() { + "temporary_error" => RecoveryInfo::retry(), + "operation timed out" => RecoveryInfo::retry(), + _ => RecoveryInfo::never(), + }), + // Timeout middleware: Cancels operations that take too long + Timeout::layer("timeout", &context) + .timeout_output(|_| "operation timed out".to_string()) + .timeout(Duration::from_secs(30)), + // Your core business logic + Execute::new(my_string_operation), +) + .into_service(); + +let result = service.execute("input data".to_string()).await; +``` + +## Why? + +Communicating over a network is inherently fraught with problems. The network can go down at any time, +sometimes for a millisecond or two. The endpoint you’re connecting to may crash or be rebooted, +network configuration may change from under you, etc. To deliver a robust experience to users, and to +achieve `5` or more `9s` of availability, it is imperative to implement robust resilience patterns to +mask these transient failures. + +This crate provides production-ready resilience middleware with excellent telemetry for building +robust distributed systems that can automatically handle timeouts, retries, and other failure +scenarios. + +* **Production-ready** - Battle-tested middleware with sensible defaults and comprehensive + configuration options. +* **Excellent telemetry** - Built-in support for metrics and structured logging to monitor + resilience behavior in production. +* **Runtime agnostic** - Works seamlessly across any async runtime. Use the same resilience + patterns across different projects and migrate between runtimes without changes. + +## Overview + +This crate uses the [`layered`][__link0] crate for composing middleware. The middleware layers +can be stacked together using tuples and built into a service using the [`Stack`][__link1] trait. + +Resilience middleware also requires [`Clock`][__link2] from the [`tick`][__link3] crate for timing +operations like delays, timeouts, and backoff calculations. The clock is passed through +[`ResilienceContext`][__link4] when creating middleware layers. + +### Core Types + +* [`ResilienceContext`][__link5] - Holds shared state for resilience middleware, including the clock. +* [`RecoveryInfo`][__link6] - Classifies errors as recoverable (transient) or non-recoverable (permanent). +* [`Recovery`][__link7] - A trait for types that can determine their recoverability. + +### Built-in Middleware + +This crate provides built-in resilience middleware that you can use out of the box. See the documentation +for each module for details on how to use them. + +* [`timeout`][__link8] - Middleware that cancels long-running operations. +* [`retry`][__link9] - Middleware that automatically retries failed operations. +* [`circuit_breaker`][__link10] - Middleware that prevents cascading failures. 
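For illustration, here is a minimal sketch of composing the circuit breaker middleware into a stack. It is modeled on the `circuit_breaker` example added in this PR (the builder methods `recovery_with`, `rejected_input_error`, `min_throughput`, and `break_duration` come from that example); `call_backend` is a hypothetical stand-in for a real fallible operation:

```rust
use std::time::Duration;

use layered::{Execute, Service, Stack};
use seatbelt::circuit_breaker::Circuit;
use seatbelt::{RecoveryInfo, ResilienceContext};
use tick::Clock;

// Hypothetical operation standing in for a real backend call.
async fn call_backend(input: u32) -> Result<String, String> {
    Ok(format!("output-{input}"))
}

#[tokio::main]
async fn main() {
    let clock = Clock::new_tokio();
    let context = ResilienceContext::new(&clock);

    let service = (
        Circuit::layer("my_circuit_breaker", &context)
            // Classify each output so the breaker can track failures
            .recovery_with(|output, _args| match output {
                Ok(_) => RecoveryInfo::never(),
                Err(_) => RecoveryInfo::retry(),
            })
            // Error produced while the circuit is open
            .rejected_input_error(|input, _args| format!("'{input}' rejected because circuit is open"))
            .min_throughput(5)
            .break_duration(Duration::from_secs(2)),
        Execute::new(call_backend),
    )
        .into_service();

    let _ = service.execute(1).await;
}
```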
+ +## Features + +This crate provides several optional features that can be enabled in your `Cargo.toml`: + +* **`timeout`** - Enables the [`timeout`][__link11] middleware for canceling long-running operations. +* **`retry`** - Enables the [`retry`][__link12] middleware for automatically retrying failed operations with + configurable backoff strategies, jitter, and recovery classification. +* **`circuit-breaker`** - Enables the [`circuit_breaker`][__link13] middleware for preventing cascading failures. +* **`metrics`** - Exposes the OpenTelemetry metrics API for collecting and reporting metrics. +* **`logs`** - Enables structured logging for resilience middleware using the `tracing` crate. + + +
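As a sketch of how the `metrics` and `logs` features are wired up, based on the benchmarks and examples in this PR: metrics flow through an OpenTelemetry `MeterProvider` passed to `enable_metrics`, and logs are emitted through `tracing` once `enable_logs` is called. Chaining the two builder calls together and the `"my_pipeline"` name are assumptions for illustration; the stdout exporter comes from the `opentelemetry-stdout` crate used by the examples:

```rust
use opentelemetry_sdk::metrics::SdkMeterProvider;
use opentelemetry_stdout::MetricExporter;
use seatbelt::ResilienceContext;
use tick::Clock;

fn main() {
    // Export metrics to stdout; any OpenTelemetry meter provider works here.
    let meter_provider = SdkMeterProvider::builder()
        .with_periodic_exporter(MetricExporter::default())
        .build();

    let clock = Clock::new_frozen(); // or Clock::new_tokio() inside a Tokio runtime

    // `metrics` feature: report resilience metrics through the provider.
    // `logs` feature: emit structured events through `tracing`.
    // (Chaining both calls is assumed; each is shown individually in the PR.)
    let _context = ResilienceContext::new(&clock)
        .name("my_pipeline")
        .enable_metrics(&meter_provider)
        .enable_logs();
}
```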
+ +This crate was developed as part of The Oxidizer Project. Browse this crate's source code. + + + [__cargo_doc2readme_dependencies_info]: ggGkYW0CYXSEGy4k8ldDFPOhG2VNeXtD5nnKG6EPY6OfW5wBG8g18NOFNdxpYXKEG5bw60hngnQRG76QZSWWI79pG7-oEvcoPz3VGzVGNfifez53YWSEgmdsYXllcmVkZTAuMy4wgmtyZWNvdmVyYWJsZWUwLjEuMIJoc2VhdGJlbHRlMC4yLjCCZHRpY2tlMC4xLjI + [__link0]: https://crates.io/crates/layered/0.3.0 + [__link1]: https://docs.rs/layered/0.3.0/layered/?search=Stack + [__link10]: https://docs.rs/seatbelt/0.2.0/seatbelt/circuit_breaker/index.html + [__link11]: https://docs.rs/seatbelt/0.2.0/seatbelt/timeout/index.html + [__link12]: https://docs.rs/seatbelt/0.2.0/seatbelt/retry/index.html + [__link13]: https://docs.rs/seatbelt/0.2.0/seatbelt/circuit_breaker/index.html + [__link2]: https://docs.rs/tick/0.1.2/tick/?search=Clock + [__link3]: https://crates.io/crates/tick/0.1.2 + [__link4]: https://docs.rs/seatbelt/0.2.0/seatbelt/?search=ResilienceContext + [__link5]: https://docs.rs/seatbelt/0.2.0/seatbelt/?search=ResilienceContext + [__link6]: https://docs.rs/recoverable/0.1.0/recoverable/?search=RecoveryInfo + [__link7]: https://docs.rs/recoverable/0.1.0/recoverable/?search=Recovery + [__link8]: https://docs.rs/seatbelt/0.2.0/seatbelt/timeout/index.html + [__link9]: https://docs.rs/seatbelt/0.2.0/seatbelt/retry/index.html diff --git a/crates/seatbelt/benches/circuit_breaker.rs b/crates/seatbelt/benches/circuit_breaker.rs new file mode 100644 index 00000000..6cbe96a4 --- /dev/null +++ b/crates/seatbelt/benches/circuit_breaker.rs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#![expect(missing_docs, reason = "benchmark code")] +use alloc_tracker::{Allocator, Session}; +use criterion::{Criterion, criterion_group, criterion_main}; +use futures::executor::block_on; +use layered::{Execute, Service, Stack}; +use seatbelt::circuit_breaker::Circuit; +use seatbelt::{RecoveryInfo, ResilienceContext}; +use tick::Clock; + +#[global_allocator] +static ALLOCATOR: Allocator = Allocator::system(); + +fn entry(c: &mut Criterion) { + let mut group = c.benchmark_group("circuit_breaker"); + let session = Session::new(); + + // No circuit breaker + let service = Execute::new(|_input: Input| async move { Output }); + let operation = session.operation("no-circuit-breaker"); + group.bench_function("no-circuit-breaker", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + // With circuit breaker (closed state) + let context = ResilienceContext::new(Clock::new_frozen()); + + let service = ( + Circuit::layer("bench", &context) + .recovery_with(|_, _| RecoveryInfo::never()) + .rejected_input_error(|_input, _args| Output) + .min_throughput(1000), // High threshold to keep circuit closed + Execute::new(|_input: Input| async move { Ok(Output) }), + ) + .into_service(); + + let operation = session.operation("with-circuit-breaker"); + group.bench_function("with-circuit-breaker", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + group.finish(); + session.print_to_stdout(); +} + +criterion_group!(benches, entry); +criterion_main!(benches); + +#[derive(Debug, Clone)] +struct Input; + +#[derive(Debug, Clone)] +struct Output; diff --git a/crates/seatbelt/benches/observability.rs b/crates/seatbelt/benches/observability.rs new file mode 100644 index 00000000..a618d54f --- /dev/null +++ b/crates/seatbelt/benches/observability.rs @@ -0,0 +1,117 @@ +// Copyright (c) 
Microsoft Corporation. +// Licensed under the MIT License. +#![expect(missing_docs, reason = "benchmark code")] +use std::time::Duration; + +use alloc_tracker::{Allocator, Session}; +use criterion::{Criterion, criterion_group, criterion_main}; +use futures::executor::block_on; +use layered::{Execute, Service, Stack}; +use opentelemetry_sdk::error::OTelSdkResult; +use opentelemetry_sdk::metrics::data::ResourceMetrics; +use opentelemetry_sdk::metrics::exporter::PushMetricExporter; +use opentelemetry_sdk::metrics::{SdkMeterProvider, Temporality}; +use seatbelt::retry::Retry; +use seatbelt::{RecoveryInfo, ResilienceContext}; +use tick::Clock; + +#[global_allocator] +static ALLOCATOR: Allocator = Allocator::system(); + +fn entry(c: &mut Criterion) { + let mut group = c.benchmark_group("observability"); + let session = Session::new(); + + // No telemetry + let context = ResilienceContext::new(Clock::new_frozen()); + let service = ( + Retry::layer("bench", &context) + .clone_input() + .base_delay(Duration::ZERO) + .recovery_with(|_, _| RecoveryInfo::retry()), + Execute::new(|v: Input| async move { Output::from(v) }), + ) + .into_service(); + let operation = session.operation("retry-no-telemetry"); + group.bench_function("retry-no-telemetry", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + // Metrics + let meter_provider = SdkMeterProvider::builder().with_periodic_exporter(EmptyExporter).build(); + let context = ResilienceContext::new(Clock::new_frozen()).enable_metrics(&meter_provider); + let service = ( + Retry::layer("bench", &context) + .clone_input() + .base_delay(Duration::ZERO) + .recovery_with(|_, _| RecoveryInfo::retry()), + Execute::new(|v: Input| async move { Output::from(v) }), + ) + .into_service(); + let operation = session.operation("retry-metrics"); + group.bench_function("retry-metrics", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + // Logs + let context = ResilienceContext::new(Clock::new_frozen()).enable_logs(); + let service = ( + Retry::layer("bench", &context) + .clone_input() + .base_delay(Duration::ZERO) + .recovery_with(|_, _| RecoveryInfo::retry()), + Execute::new(|v: Input| async move { Output::from(v) }), + ) + .into_service(); + let operation = session.operation("retry-logs"); + group.bench_function("retry-logs", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + group.finish(); + session.print_to_stdout(); +} + +criterion_group!(benches, entry); +criterion_main!(benches); + +#[derive(Debug, Clone)] +struct Input; + +#[derive(Debug, Clone)] +struct Output; + +impl From for Output { + fn from(_input: Input) -> Self { + Self + } +} + +struct EmptyExporter; + +impl PushMetricExporter for EmptyExporter { + async fn export(&self, _metrics: &ResourceMetrics) -> OTelSdkResult { + Ok(()) + } + + fn force_flush(&self) -> OTelSdkResult { + Ok(()) + } + + fn shutdown_with_timeout(&self, _timeout: Duration) -> OTelSdkResult { + Ok(()) + } + + fn temporality(&self) -> Temporality { + Temporality::Cumulative + } +} diff --git a/crates/seatbelt/benches/retry.rs b/crates/seatbelt/benches/retry.rs new file mode 100644 index 00000000..e7a12bdb --- /dev/null +++ b/crates/seatbelt/benches/retry.rs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+#![expect(missing_docs, reason = "benchmark code")] + +use std::time::Duration; + +use alloc_tracker::{Allocator, Session}; +use criterion::{Criterion, criterion_group, criterion_main}; +use futures::executor::block_on; +use layered::{Execute, Service, Stack}; +use seatbelt::retry::Retry; +use seatbelt::{RecoveryInfo, ResilienceContext}; +use tick::Clock; + +#[global_allocator] +static ALLOCATOR: Allocator = Allocator::system(); + +fn entry(c: &mut Criterion) { + let mut group = c.benchmark_group("retry"); + let session = Session::new(); + + // No retries + let service = Execute::new(|v: Input| async move { Output::from(v) }); + let operation = session.operation("no-retry"); + group.bench_function("no-retry", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + // With retry + let context = ResilienceContext::new(Clock::new_frozen()); + + let service = ( + Retry::layer("bench", &context) + .clone_input() + .recovery_with(|_, _| RecoveryInfo::never()), + Execute::new(|v: Input| async move { Output::from(v) }), + ) + .into_service(); + + let operation = session.operation("with-retry"); + group.bench_function("with-retry", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + // With retry and recovery + let context = ResilienceContext::new(Clock::new_frozen()); + + let service = ( + Retry::layer("bench", &context) + .clone_input() + .max_retry_attempts(1) + .base_delay(Duration::ZERO) + .recovery_with(|_, _| RecoveryInfo::retry()), + Execute::new(|v: Input| async move { Output::from(v) }), + ) + .into_service(); + + let operation = session.operation("with-retry-and-recovery"); + group.bench_function("with-retry-and-recovery", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + group.finish(); + session.print_to_stdout(); +} + +criterion_group!(benches, entry); +criterion_main!(benches); + +#[derive(Debug, Clone)] +struct Input; + +#[derive(Debug, Clone)] +struct Output; + +impl From for Output { + fn from(_input: Input) -> Self { + Self + } +} diff --git a/crates/seatbelt/benches/timeout.rs b/crates/seatbelt/benches/timeout.rs new file mode 100644 index 00000000..25ae351c --- /dev/null +++ b/crates/seatbelt/benches/timeout.rs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+#![expect(missing_docs, reason = "benchmark code")] + +use std::time::Duration; + +use alloc_tracker::{Allocator, Session}; +use criterion::{Criterion, criterion_group, criterion_main}; +use futures::executor::block_on; +use layered::{Execute, Service, Stack}; +use seatbelt::ResilienceContext; +use seatbelt::timeout::Timeout; +use tick::Clock; + +#[global_allocator] +static ALLOCATOR: Allocator = Allocator::system(); + +fn entry(c: &mut Criterion) { + let mut group = c.benchmark_group("timeout"); + let session = Session::new(); + + // No timeout + let service = Execute::new(|v: Input| async move { Output::from(v) }); + let operation = session.operation("no-timeout"); + group.bench_function("no-timeout", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + // With timeout + let context = ResilienceContext::new(Clock::new_frozen()); + + let service = ( + Timeout::layer("bench", &context) + .timeout_output(|_args| Output) + .timeout(Duration::from_secs(10)), + Execute::new(|v: Input| async move { Output::from(v) }), + ) + .into_service(); + + let operation = session.operation("with-timeout"); + group.bench_function("with-timeout", |b| { + b.iter(|| { + let _span = operation.measure_thread(); + _ = block_on(service.execute(Input)); + }); + }); + + group.finish(); + session.print_to_stdout(); +} + +criterion_group!(benches, entry); +criterion_main!(benches); + +struct Input; + +struct Output; + +impl From for Output { + fn from(_input: Input) -> Self { + Self + } +} diff --git a/crates/seatbelt/examples/circuit_breaker.rs b/crates/seatbelt/examples/circuit_breaker.rs new file mode 100644 index 00000000..a8f65007 --- /dev/null +++ b/crates/seatbelt/examples/circuit_breaker.rs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Circuit breaker example that simulates a major service outage and tripping of the +//! circuit breaker by: +//! +//! 1. Monitoring failure rates in real-time +//! 2. Opening the circuit when failure thresholds are exceeded +//! 3. Allowing probe requests to test service recovery +//! 4. 
Automatically closing the circuit when the service recovers
+
+use std::time::Duration;
+
+use layered::{Execute, Service, Stack};
+use ohno::AppError;
+use opentelemetry_sdk::metrics::SdkMeterProvider;
+use opentelemetry_stdout::MetricExporter;
+use seatbelt::circuit_breaker::Circuit;
+use seatbelt::{RecoveryInfo, ResilienceContext};
+use tick::Clock;
+use tracing_subscriber::layer::SubscriberExt;
+use tracing_subscriber::util::SubscriberInitExt;
+
+#[tokio::main]
+async fn main() -> Result<(), AppError> {
+    let meter_provider = configure_telemetry();
+
+    let clock = Clock::new_tokio();
+    let context = ResilienceContext::new(&clock).enable_metrics(&meter_provider);
+
+    // Define stack with circuit breaker layer
+    let stack = (
+        Circuit::layer("my_circuit_breaker", &context)
+            // Required: classify the recoverability of outputs
+            .recovery_with(|output, _args| match output {
+                Ok(_) => RecoveryInfo::never(),
+                Err(_) => RecoveryInfo::retry(),
+            })
+            // Required: provide output when circuit is open
+            .rejected_input_error(|input, _args| format!("rejecting execution of '{input}' because circuit is open"))
+            // Decrease the following values to see the circuit breaker trip faster
+            // and speed up the example
+            .sampling_duration(Duration::from_secs(2))
+            .min_throughput(5)
+            .break_duration(Duration::from_secs(2))
+            .on_probing(|_, _| println!("probing input let in to see if the service has recovered"))
+            .on_opened(|_, _| println!("circuit opened due to exceeding failure threshold"))
+            .on_closed(|_, args| {
+                println!(
+                    "circuit closed because probing succeeded, opened for: {}s",
+                    args.open_duration().as_secs()
+                );
+            }),
+        Execute::new(execute_operation),
+    );
+
+    // Create the service from the stack
+    let service = stack.into_service();
+
+    // Execute multiple attempts; the circuit breaker will eventually open because the
+    // failure rate exceeds the threshold. You can play with this value and increase it to 300
+    // to see how the circuit breaker eventually closes when the service recovers.
+    for attempt in 0..30 {
+        clock.delay(Duration::from_millis(50)).await;
+
+        match service.execute(attempt).await {
+            Ok(output) => println!("{attempt}: {output}"),
+            Err(e) => println!("{attempt}: {e}"),
+        }
+    }
+
+    // Flush metrics to stdout before exiting
+    meter_provider.force_flush()?;
+
+    Ok(())
+}
+
+// Simulate a major service outage: roughly a 40% chance of failing
+async fn execute_operation(input: u32) -> Result<String, String> {
+    // After input 100, the service recovers and always succeeds
+    if input > 100 {
+        return Ok(format!("output-{input}"));
+    }
+
+    if fastrand::i16(0..10) > 5 {
+        Err(format!("transient error for '{input}'"))
+    } else {
+        // Produce some output
+        Ok(format!("output-{input}"))
+    }
+}
+
+fn configure_telemetry() -> SdkMeterProvider {
+    // Set up tracing subscriber for logs to console
+    tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init();
+
+    SdkMeterProvider::builder()
+        .with_periodic_exporter(MetricExporter::default())
+        .build()
+}
diff --git a/crates/seatbelt/examples/resilience_pipeline.rs b/crates/seatbelt/examples/resilience_pipeline.rs
new file mode 100644
index 00000000..6cfb311c
--- /dev/null
+++ b/crates/seatbelt/examples/resilience_pipeline.rs
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! This example demonstrates how to combine multiple resilience middlewares
+//! using the `seatbelt` crate to create a robust execution pipeline with basic
+//! resilience capabilities.
+ +use std::time::Duration; + +use layered::{Execute, Service, Stack}; +use ohno::AppError; +use opentelemetry_sdk::metrics::SdkMeterProvider; +use opentelemetry_stdout::MetricExporter; +use seatbelt::retry::Retry; +use seatbelt::timeout::Timeout; +use seatbelt::{RecoveryInfo, ResilienceContext}; +use tick::Clock; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +#[tokio::main] +async fn main() -> Result<(), AppError> { + let meter_provider = configure_telemetry(); + + let clock = Clock::new_tokio(); + + // Shared options for resilience middleware + let context = ResilienceContext::new(&clock).enable_metrics(&meter_provider).name("my_pipeline"); + + // Define stack with retry and timeout middlewares + let stack = ( + Retry::layer("my_retry", &context) + // automatically clones the input for retries + .clone_input() + // classify the output + .recovery_with(|output: &String, _args| match output.as_str() { + "error" | "timeout" => RecoveryInfo::retry(), + _ => RecoveryInfo::never(), + }), + Timeout::layer("my_timeout", &context) + .timeout(Duration::from_secs(1)) + .timeout_output(|_args| "timeout".to_string()), + Execute::new(execute_operation), + ); + + // Build the service + let service = stack.into_service(); + + // Execute the service with an input + let output = service.execute("value".to_string()).await; + + println!("execution finished, output: {output}"); + + // Flush metrics to stdout before exiting + meter_provider.force_flush()?; + + Ok(()) +} + +async fn execute_operation(input: String) -> String { + if fastrand::i16(0..10) > 4 { "error".to_string() } else { input } +} + +fn configure_telemetry() -> SdkMeterProvider { + // Set up tracing subscriber for logs to console + tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init(); + + SdkMeterProvider::builder() + .with_periodic_exporter(MetricExporter::default()) + .build() +} diff --git a/crates/seatbelt/examples/retry.rs b/crates/seatbelt/examples/retry.rs new file mode 100644 index 00000000..fba73eb9 --- /dev/null +++ b/crates/seatbelt/examples/retry.rs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Basic retry middleware example with automatic input cloning and simple recovery logic. 
+
+use std::io::Error;
+
+use layered::{Execute, Service, Stack};
+use ohno::AppError;
+use seatbelt::retry::Retry;
+use seatbelt::{RecoveryInfo, ResilienceContext};
+use tick::Clock;
+
+#[tokio::main]
+async fn main() -> Result<(), AppError> {
+    let clock = Clock::new_tokio();
+    let context = ResilienceContext::new(&clock);
+
+    // Define stack with retry layer
+    let stack = (
+        Retry::layer("my_retry", &context)
+            .clone_input() // Automatically clone input for retries
+            .recovery_with(|output, _args| match output {
+                Ok(_) => RecoveryInfo::never(),
+                Err(_) => RecoveryInfo::retry(),
+            }),
+        Execute::new(execute_operation),
+    );
+
+    // Create the service from the stack
+    let service = stack.into_service();
+
+    match service.execute("value".to_string()).await {
+        Ok(output) => println!("execution succeeded, result: {output}"),
+        Err(e) => println!("execution failed, error: {e}"),
+    }
+
+    Ok(())
+}
+
+// Roughly a 10% chance of failing with a transient error
+async fn execute_operation(input: String) -> Result<String, Error> {
+    if fastrand::i16(0..10) > 8 {
+        Err(Error::other("transient execution error"))
+    } else {
+        Ok(input)
+    }
+}
diff --git a/crates/seatbelt/examples/retry_advanced.rs b/crates/seatbelt/examples/retry_advanced.rs
new file mode 100644
index 00000000..1fd6cad1
--- /dev/null
+++ b/crates/seatbelt/examples/retry_advanced.rs
@@ -0,0 +1,102 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+//! Advanced retry middleware demonstrating custom input cloning and attempt info forwarding.
+//!
+//! Shows how to inject attempt metadata into requests via `.clone_input_with()`, access it
+//! in the service function, and forward it through to the final output.
+
+use std::io::Error;
+use std::time::Duration;
+
+use http::{Request, Response};
+use layered::{Execute, Service, Stack};
+use ohno::AppError;
+use opentelemetry_sdk::metrics::SdkMeterProvider;
+use opentelemetry_stdout::MetricExporter;
+use seatbelt::retry::{Attempt, Retry};
+use seatbelt::{RecoveryInfo, ResilienceContext};
+use tick::Clock;
+use tracing_subscriber::layer::SubscriberExt;
+use tracing_subscriber::util::SubscriberInitExt;
+
+#[tokio::main]
+async fn main() -> Result<(), AppError> {
+    let meter_provider = configure_telemetry();
+
+    let clock = Clock::new_tokio();
+    let context = ResilienceContext::new(&clock)
+        .name("retry_advanced")
+        .enable_metrics(&meter_provider);
+
+    // Define stack with retry layer
+    let stack = (
+        Retry::layer("my_retry", &context)
+            // Custom input cloning - inject attempt info into request extensions
+            .clone_input_with(|input: &mut Request<String>, args| {
+                let mut cloned = input.clone();
+                cloned.extensions_mut().insert(args.attempt());
+                Some(cloned)
+            })
+            .max_retry_attempts(10)
+            .use_jitter(true)
+            .base_delay(Duration::from_millis(100))
+            .recovery_with(|output, _args| match output {
+                Ok(_) => RecoveryInfo::never(),
+                Err(_) => RecoveryInfo::retry(),
+            })
+            // Register a callback called just before the next retry
+            .on_retry(|_output, args| {
+                println!(
+                    "retrying, attempt {}, delay: {}s",
+                    args.attempt().index(),
+                    args.retry_delay().as_secs_f32(),
+                );
+            }),
+        Execute::new(send_request),
+    );
+
+    // Create the service from the stack
+    let service = stack.into_service();
+
+    let request = Request::builder().uri("https://example.com").body("value".to_string())?;
+
+    match service.execute(request).await {
+        Ok(output) => {
+            // Extract attempt info that was forwarded through the pipeline
+            let attempts = output.extensions().get::<Attempt>().map_or(0, |a| a.index());
+            println!("execution succeeded, result: {}, attempts: {}", output.body(), attempts);
+        }
+        Err(e) => println!("execution failed, error: {e}"),
+    }
+
+    // Flush metrics to stdout before exiting
+    meter_provider.force_flush()?;
+
+    Ok(())
+}
+
+// Only about a 30% chance of success, so retries will be attempted with a high probability
+async fn send_request(input: Request<String>) -> Result<Response<String>, Error> {
+    if fastrand::i16(0..10) > 2 {
+        Err(Error::other("transient execution error"))
+    } else {
+        // Extract attempt info that was injected during custom cloning
+        let attempt = input.extensions().get::<Attempt>().copied().unwrap_or_default();
+
+        // Forward attempt info to output via response extensions
+        Response::builder()
+            .extension(attempt)
+            .body("success".to_string())
+            .map_err(Error::other)
+    }
+}
+
+fn configure_telemetry() -> SdkMeterProvider {
+    // Set up tracing subscriber for logs to console
+    tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init();
+
+    SdkMeterProvider::builder()
+        .with_periodic_exporter(MetricExporter::default())
+        .build()
+}
diff --git a/crates/seatbelt/examples/retry_outage.rs b/crates/seatbelt/examples/retry_outage.rs
new file mode 100644
index 00000000..4e2b7aab
--- /dev/null
+++ b/crates/seatbelt/examples/retry_outage.rs
@@ -0,0 +1,149 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#![expect(clippy::unwrap_used, reason = "sample code")]
+
+//! Demonstrates advanced retry patterns with input restoration from errors.
+//!
+//! This example showcases how to handle outage scenarios where:
+//! - The original input cannot be cloned (expensive request bodies)
+//! - Input must be restored from error information using `restore_input_from_error()`
+//! - Failed requests are automatically retried with a fallback endpoint
+//! - Outage detection and recovery are handled seamlessly
+
+use std::time::Duration;
+
+use http::{Request, Response};
+use layered::{Execute, Service, Stack};
+use ohno::AppError;
+use opentelemetry_sdk::metrics::SdkMeterProvider;
+use opentelemetry_stdout::MetricExporter;
+use seatbelt::retry::Retry;
+use seatbelt::{Recovery, RecoveryInfo, ResilienceContext};
+use tick::Clock;
+use tracing_subscriber::layer::SubscriberExt;
+use tracing_subscriber::util::SubscriberInitExt;
+
+const ENDPOINT_WITH_OUTAGE: &str = "https://example.com";
+const ENDPOINT_ALIVE: &str = "https://fallback.example.com";
+
+#[tokio::main]
+async fn main() -> Result<(), AppError> {
+    let meter_provider = configure_telemetry();
+
+    let clock = Clock::new_tokio();
+    let context = ResilienceContext::new(&clock).enable_metrics(&meter_provider);
+
+    // Configure retry layer for outage handling with input restoration
+    let stack = (
+        Retry::layer("outage_retry", &context)
+            // Disable input cloning - we'll restore from error instead
+            .clone_input_with(|_, _| None)
+            // Configure recovery based on the error type
+            .recovery_with(|output: &Result<_, HttpError>, _| match output {
+                Ok(_) => RecoveryInfo::never(), // Don't retry successful responses
+                Err(error) => error.recovery(), // Use error's recovery strategy
+            })
+            // Enable unavailable detection and handling
+            .handle_unavailable(true)
+            // Restore input from error when retrying (key feature!)
+            .restore_input_from_error(|error: &mut HttpError, _| {
+                // Extract the original request and modify it for fallback endpoint
+                error.try_restore_request()
+            }),
+        Execute::new(send_request),
+    );
+
+    // Create the service from the stack
+    let service = stack.into_service();
+
+    // Create a request that will initially fail but can be restored
+    let request = Request::builder()
+        .uri(ENDPOINT_WITH_OUTAGE)
+        .body("important request data".to_string())?;
+
+    println!("Sending request to: {}", request.uri());
+
+    // The service will:
+    // 1. Try the original endpoint (fails with outage)
+    // 2. Restore input from the error with fallback endpoint
+    // 3. Retry with the modified request (succeeds)
+    let response = service.execute(request).await?;
+
+    println!("Final response: {}", response.body());
+
+    // Flush metrics to stdout before exiting
+    meter_provider.force_flush()?;
+
+    Ok(())
+}
+
+/// Simulates a service that has outages on the primary endpoint but works on fallback.
+///
+/// This demonstrates the input restoration pattern where the original request is preserved
+/// in the error so it can be modified and retried against a different endpoint.
+async fn send_request(input: Request<String>) -> Result<Response<String>, HttpError> {
+    if input.uri() == ENDPOINT_WITH_OUTAGE {
+        println!("Request to {} failed - simulating outage", input.uri());
+        // Store the original request in the error for later restoration
+        Err(HttpError::outage(input))
+    } else {
+        println!("Request to {} succeeded", input.uri());
+        Ok(Response::new(format!("Success! Data from {}", input.uri())))
+    }
+}
+
+/// Custom error type that preserves the original request for restoration.
+///
+/// This pattern allows failed requests to be modified and retried against different
+/// endpoints without requiring the original input to be cloneable.
+#[ohno::error]
+struct HttpError {
+    /// The original request that failed, preserved for input restoration
+    rejected_request: Option<Box<Request<String>>>,
+    /// Recovery strategy (retry vs. never) for this error type
+    recovery: RecoveryInfo,
+}
+
+impl HttpError {
+    /// Creates an outage error that preserves the original request for retry.
+    fn outage(rejected_request: Request<String>) -> Self {
+        Self::caused_by(
+            Some(Box::new(rejected_request)),
+            RecoveryInfo::unavailable().delay(Duration::from_millis(100)),
+            "simulated outage",
+        )
+    }
+
+    /// Restores the original request with a modified endpoint for retry.
+    ///
+    /// This is called by `restore_input_from_error()` to extract and modify the
+    /// original request. It changes the URI to the fallback endpoint and returns
+    /// the modified request for the next retry attempt.
+ fn try_restore_request(&mut self) -> Option> { + self.rejected_request + .take() // Extract the stored request + .map(|boxed_request| *boxed_request) // Unbox it + .map(|mut request| { + // Modify the request to use the fallback endpoint + *request.uri_mut() = ENDPOINT_ALIVE.parse().unwrap(); + println!("Restored request with fallback endpoint: {}", request.uri()); + request + }) + } +} + +impl Recovery for HttpError { + fn recovery(&self) -> RecoveryInfo { + self.recovery.clone() + } +} + +fn configure_telemetry() -> SdkMeterProvider { + // Set up tracing subscriber for logs to console + tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init(); + + SdkMeterProvider::builder() + .with_periodic_exporter(MetricExporter::default()) + .build() +} diff --git a/crates/seatbelt/examples/timeout.rs b/crates/seatbelt/examples/timeout.rs new file mode 100644 index 00000000..03c8294a --- /dev/null +++ b/crates/seatbelt/examples/timeout.rs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#![expect(clippy::unwrap_used, reason = "sample code")] + +//! Simple timeout resilience middleware example. +//! +//! This example demonstrates the basic usage of the timeout middleware to cancel +//! long-running operations. + +use std::time::Duration; + +use layered::{Execute, Service, Stack}; +use ohno::{AppError, app_err}; +use opentelemetry_sdk::metrics::SdkMeterProvider; +use opentelemetry_stdout::MetricExporter; +use seatbelt::ResilienceContext; +use seatbelt::timeout::Timeout; +use tick::Clock; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +const TIMEOUT_DURATION: Duration = Duration::from_millis(100); +const PROCESSING_DELAY: Duration = Duration::from_millis(500); + +#[tokio::main] +async fn main() -> Result<(), AppError> { + let meter_provider = configure_telemetry(); + + let clock = Clock::new_tokio(); + + // Create common options + let context = ResilienceContext::new(&clock).enable_metrics(&meter_provider); + + // Define stack with timeout layer + let stack = ( + Timeout::layer("my_timeout", &context) + // Required: specify the timeout duration + .timeout(TIMEOUT_DURATION) + // Required: create error output for timeouts + .timeout_error(|args| app_err!("timeout occurred, timeout: {}ms", args.timeout().as_millis())), + Execute::new({ + let clock = clock.clone(); + move |_input| { + let clock = clock.clone(); + async move { + clock.delay(PROCESSING_DELAY).await; // Simulate some processing delay so the timeout can trigger + Ok(()) + } + } + }), + ); + + // Create the service from the stack + let service = stack.into_service(); + + for i in 0..10 { + // Execute the service, results in a timeout error + let timeout_error = service.execute(i.to_string()).await.unwrap_err(); + println!("{i} attempt, error: {timeout_error}"); + } + + // Flush metrics to stdout before exiting + meter_provider.force_flush()?; + + Ok(()) +} + +fn configure_telemetry() -> SdkMeterProvider { + // Set up tracing subscriber for logs to console + tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init(); + + SdkMeterProvider::builder() + .with_periodic_exporter(MetricExporter::default()) + .build() +} diff --git a/crates/seatbelt/examples/timeout_advanced.rs b/crates/seatbelt/examples/timeout_advanced.rs new file mode 100644 index 00000000..29a42114 --- /dev/null +++ b/crates/seatbelt/examples/timeout_advanced.rs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. 
+// Licensed under the MIT License. + +//! Advanced timeout resilience middleware example. +//! +//! This example demonstrates advanced usage of the timeout middleware, including working with +//! Result-based outputs, timeout callbacks, and dynamic timeout durations based on input. + +use std::time::Duration; + +use layered::{Execute, Service, Stack}; +use ohno::{AppError, app_err}; +use opentelemetry_sdk::metrics::SdkMeterProvider; +use opentelemetry_stdout::MetricExporter; +use seatbelt::ResilienceContext; +use seatbelt::timeout::Timeout; +use tick::Clock; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +const TIMEOUT_DURATION: Duration = Duration::from_millis(20); +const PROCESSING_DELAY: Duration = Duration::from_secs(1); + +#[tokio::main] +async fn main() -> Result<(), AppError> { + // Configure telemetry to see the timeout metrics and logs + let meter_provider = configure_telemetry(); + + let clock = Clock::new_tokio(); + + // Create service options + let context: ResilienceContext> = + ResilienceContext::new(&clock).name("my_pipeline").enable_metrics(&meter_provider); + + // Define stack with timeout layer + let stack = ( + Timeout::layer("my_timeout", &context) + // Required: specify the timeout duration + .timeout(TIMEOUT_DURATION) + // Required: create error output for timeouts + .timeout_error(|args| app_err!("timeout occurred, timeout: {}ms", args.timeout().as_millis())) + // Optional: callback to be invoked when a timeout occurs + .on_timeout(|_out, args| { + println!("timeout occurred, timeout: {}ms", args.timeout().as_millis()); + }) + // Optional: override the default timeout duration by inspecting the input + .timeout_override(|input, _args| match input.as_str() { + "2" => Some(Duration::from_millis(300)), + _ => None, + }), + Execute::new({ + let clock = clock.clone(); + move |_input| { + let clock = clock.clone(); + async move { + // Simulate some processing delay so the timeout can trigger + clock.delay(PROCESSING_DELAY).await; + Ok(()) + } + } + }), + ); + + // Create the service from the stack + let service = stack.into_service(); + + for i in 0..10 { + // Execute the service, results in a timeout error + match service.execute(i.to_string()).await { + Ok(()) => println!("execute, input: {i}, result: success"), + Err(e) => println!("execute, input: {i}, error: {e}"), + } + } + + // Flush metrics to stdout before exiting + meter_provider.force_flush()?; + + Ok(()) +} + +fn configure_telemetry() -> SdkMeterProvider { + // Set up tracing subscriber for logs to console + tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()).init(); + + SdkMeterProvider::builder() + .with_periodic_exporter(MetricExporter::default()) + .build() +} diff --git a/crates/seatbelt/favicon.ico b/crates/seatbelt/favicon.ico new file mode 100644 index 00000000..e7752941 --- /dev/null +++ b/crates/seatbelt/favicon.ico @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3de762b64903267b6d416cfb2d420b956fa4dfdfcb47ee8b232845da495be8e5 +size 28527 diff --git a/crates/seatbelt/logo.png b/crates/seatbelt/logo.png new file mode 100644 index 00000000..1794a428 --- /dev/null +++ b/crates/seatbelt/logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7992716781c4ace4099acf8d4104ac6ebc0e1cdeb437f43a0a884ead21bfd29e +size 84743 diff --git a/crates/seatbelt/src/circuit_breaker/args.rs b/crates/seatbelt/src/circuit_breaker/args.rs new file mode 100644 index 00000000..f16ced2e --- /dev/null +++ 
b/crates/seatbelt/src/circuit_breaker/args.rs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +use tick::Clock; + +use crate::circuit_breaker::PartitionKey; + +/// Arguments for the [`recovery_with`][super::CircuitLayer::recovery_with] callback function. +/// +/// Provides context for recovery classification in the circuit breaker. +#[derive(Debug)] +#[non_exhaustive] +pub struct RecoveryArgs<'a> { + pub(crate) partition_key: &'a PartitionKey, + pub(crate) clock: &'a Clock, +} + +impl RecoveryArgs<'_> { + /// Returns the partition key associated with the recovery evaluation. + #[must_use] + pub fn partition_key(&self) -> &PartitionKey { + self.partition_key + } + + /// Returns a reference to the clock use by the circuit breaker. + #[must_use] + pub fn clock(&self) -> &Clock { + self.clock + } +} + +/// Arguments for the [`rejected_input`][super::CircuitLayer::rejected_input] callback function. +/// +/// Provides context for generating outputs when the inputs are rejected by the circuit breaker. +#[derive(Debug)] +pub struct RejectedInputArgs<'a> { + pub(crate) partition_key: &'a PartitionKey, +} + +impl RejectedInputArgs<'_> { + /// Returns the partition key associated with the rejected input. + #[must_use] + pub fn partition_key(&self) -> &PartitionKey { + self.partition_key + } +} + +/// Arguments for the [`on_probing`][super::CircuitLayer::on_probing] callback function. +/// +/// Provides context when the circuit breaker enters the probing state to test if the service has recovered. +#[derive(Debug)] +#[non_exhaustive] +pub struct OnProbingArgs<'a> { + pub(crate) partition_key: &'a PartitionKey, +} + +impl OnProbingArgs<'_> { + /// Returns the partition key associated with the probing execution. + #[must_use] + pub fn partition_key(&self) -> &PartitionKey { + self.partition_key + } +} + +/// Arguments for the [`on_closed`][super::CircuitLayer::on_closed] callback function. +/// +/// Provides context when the circuit breaker transitions to the closed state, allowing normal operation. +#[derive(Debug)] +#[non_exhaustive] +pub struct OnClosedArgs<'a> { + pub(crate) partition_key: &'a PartitionKey, + pub(crate) open_duration: std::time::Duration, +} + +impl OnClosedArgs<'_> { + /// Returns the partition key associated with this event. + #[must_use] + pub fn partition_key(&self) -> &PartitionKey { + self.partition_key + } + + /// Returns the duration the circuit was open before closing. + #[must_use] + pub fn open_duration(&self) -> Duration { + self.open_duration + } +} + +/// Arguments for the [`on_opened`][super::CircuitLayer::on_opened] callback function. +/// +/// Provides context when the circuit breaker transitions to the open state, blocking requests due to failures. +#[derive(Debug)] +#[non_exhaustive] +pub struct OnOpenedArgs<'a> { + pub(crate) partition_key: &'a PartitionKey, +} + +impl OnOpenedArgs<'_> { + /// Returns the partition key associated with this event. 
+ #[must_use] + pub fn partition_key(&self) -> &PartitionKey { + self.partition_key + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn recovery_args_accessors() { + let key = PartitionKey::from("test"); + let clock = Clock::new_frozen(); + let args = RecoveryArgs { + partition_key: &key, + clock: &clock, + }; + assert_eq!(args.partition_key(), &key); + let _ = args.clock(); + assert!(format!("{args:?}").contains("RecoveryArgs")); + } + + #[test] + fn rejected_input_args_accessors() { + let key = PartitionKey::from("rejected"); + let args = RejectedInputArgs { partition_key: &key }; + assert_eq!(args.partition_key(), &key); + assert!(format!("{args:?}").contains("RejectedInputArgs")); + } + + #[test] + fn on_probing_args_accessors() { + let key = PartitionKey::from("probing"); + let args = OnProbingArgs { partition_key: &key }; + assert_eq!(args.partition_key(), &key); + assert!(format!("{args:?}").contains("OnProbingArgs")); + } + + #[test] + fn on_closed_args_accessors() { + let key = PartitionKey::from("closed"); + let duration = Duration::from_secs(5); + let args = OnClosedArgs { + partition_key: &key, + open_duration: duration, + }; + assert_eq!(args.partition_key(), &key); + assert_eq!(args.open_duration(), duration); + assert!(format!("{args:?}").contains("OnClosedArgs")); + } + + #[test] + fn on_opened_args_accessors() { + let key = PartitionKey::from("opened"); + let args = OnOpenedArgs { partition_key: &key }; + assert_eq!(args.partition_key(), &key); + assert!(format!("{args:?}").contains("OnOpenedArgs")); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/callbacks.rs b/crates/seatbelt/src/circuit_breaker/callbacks.rs new file mode 100644 index 00000000..ebb579f7 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/callbacks.rs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::{OnClosedArgs, OnOpenedArgs, OnProbingArgs, PartitionKey, RecoveryArgs, RejectedInputArgs}; +use crate::RecoveryInfo; + +crate::utils::define_fn_wrapper!(PartionKeyProvider(Fn(&In) -> PartitionKey)); +crate::utils::define_fn_wrapper!(ShouldRecover(Fn(&Out, RecoveryArgs) -> RecoveryInfo)); +crate::utils::define_fn_wrapper!(RejectedInput(Fn(In, RejectedInputArgs) -> Out)); +crate::utils::define_fn_wrapper!(OnProbing(Fn(&mut In, OnProbingArgs))); +crate::utils::define_fn_wrapper!(OnOpened(Fn(&Out, OnOpenedArgs))); +crate::utils::define_fn_wrapper!(OnClosed(Fn(&Out, OnClosedArgs))); diff --git a/crates/seatbelt/src/circuit_breaker/constants.rs b/crates/seatbelt/src/circuit_breaker/constants.rs new file mode 100644 index 00000000..1030ed01 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/constants.rs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +/// Minimum allowed duration for the circuit breaker's sampling window. +pub(crate) const MIN_SAMPLING_DURATION: Duration = Duration::from_secs(1); + +/// Default minimum throughput (number of requests) in the sampling window before +/// the circuit breaker can evaluate the failure rate and potentially trip the circuit. +/// +/// The defaults taken from `Polly V8`: +/// +pub(crate) const DEFAULT_MIN_THROUGHPUT: u32 = 100; + +/// Default duration of the circuit breaker's sampling window. 
+/// +/// The defaults taken from `Polly V8`: +/// +pub(crate) const DEFAULT_SAMPLING_DURATION: Duration = Duration::from_secs(30); + +/// Default failure threshold (percentage of failed requests) in the sampling window +/// that will trip the circuit breaker. +/// +/// The defaults taken from `Polly V8`: +/// +pub(crate) const DEFAULT_FAILURE_THRESHOLD: f32 = 0.1; + +/// Default duration that the circuit breaker remains open (broken) before +/// transitioning to half-open to test if the service has recovered. +/// +/// The defaults taken from `Polly V8`: +/// +pub(crate) const DEFAULT_BREAK_DURATION: Duration = Duration::from_secs(5); + +pub(crate) const ERR_POISONED_LOCK: &str = + "poisoned lock - cannot continue execution because security and privacy guarantees can no longer be upheld"; diff --git a/crates/seatbelt/src/circuit_breaker/engine/engine_core.rs b/crates/seatbelt/src/circuit_breaker/engine/engine_core.rs new file mode 100644 index 00000000..4edc58c6 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/engine_core.rs @@ -0,0 +1,690 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::sync::Mutex; +use std::time::{Duration, Instant}; + +use tick::Clock; + +use super::{EngineOptions, EnterCircuitResult, ExitCircuitResult}; +use crate::circuit_breaker::constants::ERR_POISONED_LOCK; +use crate::circuit_breaker::engine::probing::{AllowProbeResult, Probes, ProbingResult}; +use crate::circuit_breaker::{CircuitEngine, ExecutionMode, ExecutionResult, HealthMetrics, HealthStatus}; + +/// Engine that manages the state of the circuit breaker. +#[derive(Debug)] +pub(crate) struct EngineCore { + state: Mutex, + options: EngineOptions, + clock: Clock, +} + +impl EngineCore { + pub fn new(options: EngineOptions, clock: Clock) -> Self { + Self { + state: Mutex::new(State::Closed { + health: options.health_metrics_builder.build(), + }), + options, + clock, + } + } +} + +impl CircuitEngine for EngineCore { + fn enter(&self) -> EnterCircuitResult { + let now = self.clock.instant(); + + // NOTE: Remember to execute all expensive operations (like time checks) outside the lock. + self.state.lock().expect(ERR_POISONED_LOCK).enter(now, &self.options) + } + + fn exit(&self, result: ExecutionResult, _mode: ExecutionMode) -> ExitCircuitResult { + let now = self.clock.instant(); + + // NOTE: Remember to execute all expensive operations (like time checks) outside the lock. + self.state.lock().expect(ERR_POISONED_LOCK).exit(result, now, &self.options) + } +} + +#[derive(Debug)] +enum State { + Closed { health: HealthMetrics }, + Open { open_until: Instant, stats: Stats }, + HalfOpen { probes: Probes, stats: Stats }, +} + +impl State { + fn enter(&mut self, now: Instant, settings: &EngineOptions) -> EnterCircuitResult { + match self { + Self::Closed { .. 
} => EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + Self::Open { open_until, stats } => { + if now >= *open_until { + let mut probes = Probes::new(&settings.probes); + let allow = probes.allow_probe(now); + stats.record_allow_result(allow); + + *self = Self::HalfOpen { + probes, + stats: stats.clone(), + }; + EnterCircuitResult::from(allow) + } else { + stats.rejected = stats.rejected.saturating_add(1); + EnterCircuitResult::Rejected + } + } + Self::HalfOpen { probes, stats: info } => { + let allow = probes.allow_probe(now); + info.record_allow_result(allow); + EnterCircuitResult::from(allow) + } + } + } + + fn exit(&mut self, result: ExecutionResult, now: Instant, settings: &EngineOptions) -> ExitCircuitResult { + match self { + Self::Closed { health } => { + // first, record the result and evaluate the health metrics + health.record(result, now); + let health = health.health_info(); + + // decide the next state based on health status + match health.status() { + // Health is good, remain in a closed state + HealthStatus::Healthy => ExitCircuitResult::Unchanged, + // Health is poor, transition to Open state + HealthStatus::Unhealthy => { + *self = Self::Open { + open_until: now + settings.break_duration, + stats: Stats::new(now), + }; + ExitCircuitResult::Opened(health) + } + } + } + Self::Open { stats, .. } => { + // Record lost results for statistics purposes + stats.probes_lost = stats.probes_lost.saturating_add(1); + + // In open state, we don't process results. This can happen when multiple threads are involved and + // the state of circuit breaker changes between enter and exit calls since these are separate + // method calls that could be interleaved with other threads. Ignore the result. + ExitCircuitResult::Unchanged + } + Self::HalfOpen { probes, stats } => { + // record the result of the probe + stats.record_probe_execution_result(result); + + match probes.record(result, now) { + ProbingResult::Success => { + let stats = stats.clone(); + + *self = Self::Closed { + health: settings.health_metrics_builder.build(), + }; + + ExitCircuitResult::Closed(stats) + } + ProbingResult::Failure => { + stats.re_opened = stats.re_opened.saturating_add(1); + + *self = Self::Open { + open_until: now + settings.break_duration, + stats: stats.clone(), + }; + + ExitCircuitResult::Reopened + } + ProbingResult::Pending => ExitCircuitResult::Unchanged, + } + } + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct Stats { + pub opened_at: Instant, + pub re_opened: usize, + pub probes_total: usize, + pub probes_lost: usize, + pub probes_successes: usize, + pub probes_failures: usize, + pub rejected: usize, +} + +impl Stats { + pub fn new(opened_at: Instant) -> Self { + Self { + opened_at, + probes_total: 0, + probes_lost: 0, + probes_successes: 0, + probes_failures: 0, + rejected: 0, + re_opened: 0, + } + } + + pub fn opened_duration(&self, now: Instant) -> Duration { + now.saturating_duration_since(self.opened_at) + } + + fn record_allow_result(&mut self, allow: AllowProbeResult) { + if allow == AllowProbeResult::Accepted { + self.probes_total = self.probes_total.saturating_add(1); + } else { + self.rejected = self.rejected.saturating_add(1); + } + } + + fn record_probe_execution_result(&mut self, result: ExecutionResult) { + match result { + ExecutionResult::Success => { + self.probes_successes = self.probes_successes.saturating_add(1); + } + ExecutionResult::Failure => { + self.probes_failures = self.probes_failures.saturating_add(1); + } + } + } +} + 
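To make the `enter`/`exit` contract concrete, here is a minimal sketch of how a caller is expected to drive a `CircuitEngine` (illustrative only; it assumes nothing beyond the trait and result types defined in this module):

```rust
fn guarded_call<E: CircuitEngine>(engine: &E) -> Option<()> {
    match engine.enter() {
        // Circuit is open: fail fast without touching the downstream dependency.
        EnterCircuitResult::Rejected => None,
        // Closed or half-open: run the operation, then report the outcome so the
        // engine can update its health metrics or advance the probing sequence.
        EnterCircuitResult::Accepted { mode } => {
            let outcome = ExecutionResult::Success; // outcome of the protected operation
            let _ = engine.exit(outcome, mode);
            Some(())
        }
    }
}
```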
+#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::ops::Deref; + + use tick::ClockControl; + + use super::*; + use crate::circuit_breaker::HealthMetricsBuilder; + use crate::circuit_breaker::engine::probing::ProbesOptions; + + fn create_test_settings() -> EngineOptions { + EngineOptions { + break_duration: Duration::from_secs(5), + health_metrics_builder: HealthMetricsBuilder::new( + Duration::from_secs(30), + 0.1, // 10% failure threshold + 10, // minimum 10 requests + ), + probes: ProbesOptions::quick(Duration::from_secs(2)), + } + } + + fn create_test_engine() -> EngineCore { + let settings = create_test_settings(); + let clock = Clock::new_frozen(); + EngineCore::new(settings, clock) + } + + fn open_engine(engine: &EngineCore) { + const MAX_ATTEMPTS: usize = 1000; + + for _attempt in 0..MAX_ATTEMPTS { + engine.enter(); + let result = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + if matches!(result, ExitCircuitResult::Opened(_)) { + return; + } + } + + panic!("failed to open the circuit after {MAX_ATTEMPTS} attempts"); + } + + #[test] + fn new_with_valid_settings_creates_closed_state() { + let engine = create_test_engine(); + + // Verify engine was created (we can't directly inspect the state due to encapsulation) + // but we can verify it starts in closed state by checking enter() behavior + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal + } + )); + } + + #[test] + fn enter_when_closed_accepts_request() { + let engine = create_test_engine(); + + let result = engine.enter(); + + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal + } + )); + } + + #[test] + fn enter_when_open_before_timeout_rejects_request() { + let engine = create_test_engine(); + open_engine(&engine); + + // Verify circuit is now open + let result = engine.enter(); + assert!(matches!(result, EnterCircuitResult::Rejected)); + } + + #[test] + fn enter_when_open_after_timeout_transitions_to_half_open() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force circuit to open + open_engine(&engine); + + // Advance time beyond break duration + control.advance(Duration::from_secs(6)); + + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Probe + } + )); + } + + #[test] + fn enter_when_half_open_within_break_duration_rejects_request() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open then half-open + open_engine(&engine); + control.advance(Duration::from_secs(6)); + engine.enter(); // Transitions to half-open + + // Try entering again immediately (within break duration) + let result = engine.enter(); + assert!(matches!(result, EnterCircuitResult::Rejected)); + } + + #[test] + fn enter_when_half_open_after_break_duration_resets_half_open_timer() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open then half-open + open_engine(&engine); + control.advance(Duration::from_secs(6)); + engine.enter(); // Transitions to half-open + + // Advance time beyond break duration while in half-open + control.advance(Duration::from_secs(6)); 
+ + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Probe + } + )); + } + + #[test] + fn exit_when_closed_with_success_remains_unchanged() { + let engine = create_test_engine(); + engine.enter(); + + let result = engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + + assert!(matches!(result, ExitCircuitResult::Unchanged)); + } + + #[test] + fn exit_when_closed_with_enough_failures_opens_circuit() { + let settings = EngineOptions { + break_duration: Duration::from_secs(5), + health_metrics_builder: HealthMetricsBuilder::new( + Duration::from_secs(30), + 0.1, // 10% failure threshold + 20, // minimum 20 requests (higher than default 10 for this test) + ), + probes: ProbesOptions::quick(Duration::from_secs(2)), + }; + let clock = Clock::new_frozen(); + let engine = EngineCore::new(settings, clock); + + // Record 19 successes and 3 failures = 22 total requests with ~13.6% failure rate + for _ in 0..19 { + engine.enter(); + engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + } + for _ in 0..2 { + engine.enter(); + engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + } + + // One more failure to trigger opening: 3 failures out of 22 total = ~13.6% > 10% + engine.enter(); + let result = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + + assert!(matches!(result, ExitCircuitResult::Opened(_))); + } + + #[test] + fn exit_when_closed_with_insufficient_failures_remains_unchanged() { + let engine = create_test_engine(); + + // Record some failures but not enough to exceed a threshold (need at least 10 requests) + for _ in 0..5 { + engine.enter(); + engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + } + + engine.enter(); + let result = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + + assert!(matches!(result, ExitCircuitResult::Unchanged)); + } + + #[test] + fn exit_when_open_ignores_result() { + let engine = create_test_engine(); + open_engine(&engine); + + // Try to record success in open state + let result = engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + assert!(matches!(result, ExitCircuitResult::Unchanged)); + + if let State::Open { stats, .. 
} = engine.state.lock().unwrap().deref() { + assert_eq!(stats.probes_lost, 1); + } else { + panic!("expected engine to be in Open state"); + } + } + + #[test] + fn exit_when_half_open_with_success_closes_circuit() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open then half-open + open_engine(&engine); + control.advance(Duration::from_secs(6)); + engine.enter(); // Transitions to half-open + + let result = engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + + assert!(matches!(result, ExitCircuitResult::Closed(stats) if stats.probes_successes == 1 && stats.probes_total == 1)); + } + + #[test] + fn exit_when_half_open_with_failure_reopens_circuit() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open then half-open + open_engine(&engine); + control.advance(Duration::from_secs(6)); + engine.enter(); // Transitions to half-open + + let result = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + + assert!(matches!(result, ExitCircuitResult::Reopened)); + } + + #[test] + fn circuit_breaker_full_cycle() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Start in closed state + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal + } + )); + + // Force to open state + open_engine(&engine); + + // Verify open state rejects requests + let result = engine.enter(); + assert!(matches!(result, EnterCircuitResult::Rejected)); + + // Advance time to allow transition to half-open + control.advance(Duration::from_secs(6)); + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Probe + } + )); + + // Successful probe closes the circuit + let result = engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + + if let ExitCircuitResult::Closed(stats) = &result { + assert_eq!(stats.probes_successes, 1); + assert_eq!(stats.probes_total, 1); + assert_eq!(stats.rejected, 1); + assert_eq!(stats.probes_failures, 0); + assert_eq!(stats.probes_lost, 0); + assert_eq!(stats.re_opened, 0); + } else { + panic!("expected circuit to close after successful probe"); + } + + // Verify back to normal operation + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal + } + )); + } + + #[test] + fn circuit_breaker_half_open_failure_cycle() { + let settings = create_test_settings(); + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open state + open_engine(&engine); + + // Transition to half-open + control.advance(Duration::from_secs(6)); + engine.enter(); + + // Failed probe reopens circuit + let result = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + assert!(matches!(result, ExitCircuitResult::Reopened)); + + // Verify circuit is open again + let result = engine.enter(); + assert!(matches!(result, EnterCircuitResult::Rejected)); + + // Transition to half-open + control.advance(Duration::from_secs(6)); + engine.enter(); + + let result = engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + + if let 
ExitCircuitResult::Closed(stats) = &result { + assert_eq!(stats.probes_successes, 1); + assert_eq!(stats.probes_total, 2); + assert_eq!(stats.rejected, 1); + assert_eq!(stats.probes_failures, 1); + assert_eq!(stats.probes_lost, 0); + assert_eq!(stats.re_opened, 1); + } else { + panic!("expected circuit to close after successful probe"); + } + } + + #[test] + fn concurrent_enter_exit_operations() { + let engine = create_test_engine(); + + // Simulate operations where enter and exit are called separately + // (though each method call is atomic due to the internal mutex) + engine.enter(); + let result1 = engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + + engine.enter(); + let result2 = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + + // Both should complete without panicking + assert!(matches!(result1, ExitCircuitResult::Unchanged)); + assert!(matches!(result2, ExitCircuitResult::Unchanged)); + } + + #[test] + fn engine_with_custom_break_duration() { + let settings = EngineOptions { + break_duration: Duration::from_millis(100), + health_metrics_builder: HealthMetricsBuilder::new(Duration::from_secs(30), 0.1, 50), + probes: ProbesOptions::quick(Duration::from_secs(2)), + }; + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open state + open_engine(&engine); + + // Verify still rejected just before timeout + control.advance(Duration::from_millis(99)); + let result = engine.enter(); + assert!(matches!(result, EnterCircuitResult::Rejected)); + + // Verify accepted just after timeout + control.advance(Duration::from_millis(2)); + let result = engine.enter(); + assert!(matches!( + result, + EnterCircuitResult::Accepted { + mode: ExecutionMode::Probe + } + )); + } + + #[test] + fn engine_with_custom_failure_threshold() { + let settings = EngineOptions { + break_duration: Duration::from_secs(5), + health_metrics_builder: HealthMetricsBuilder::new( + Duration::from_secs(30), + 0.5, // 50% failure threshold + 10, // minimum 10 requests + ), + probes: ProbesOptions::quick(Duration::from_secs(2)), + }; + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Record 6 failures and 4 successes (60% failure rate, 10 total requests) + for _ in 0..6 { + engine.enter(); + engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + } + for _ in 0..3 { + engine.enter(); + engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + } + + // Add one more failure to make it 7 failures out of 10 (70% > 50% threshold) + engine.enter(); + let result = engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + + assert!(matches!(result, ExitCircuitResult::Opened(_))); + } + + #[test] + fn stats_record_probe_execution_result_increments_correctly() { + let mut stats = Stats::new(Instant::now()); + + stats.record_probe_execution_result(ExecutionResult::Success); + assert_eq!(stats.probes_successes, 1); + assert_eq!(stats.probes_failures, 0); + + stats.record_probe_execution_result(ExecutionResult::Failure); + assert_eq!(stats.probes_successes, 1); + assert_eq!(stats.probes_failures, 1); + } + + #[test] + fn stats_record_allow_result_increments_correctly() { + let mut stats = Stats::new(Instant::now()); + + stats.record_allow_result(AllowProbeResult::Accepted); + assert_eq!(stats.probes_total, 1); + assert_eq!(stats.rejected, 0); + + stats.record_allow_result(AllowProbeResult::Rejected); + assert_eq!(stats.probes_total, 1); + 
assert_eq!(stats.rejected, 1); + } + + #[test] + fn stats_opened_for_calculates_duration_correctly() { + let opened_at = Instant::now(); + let stats = Stats::new(opened_at); + + // Simulate some time passing + let later = opened_at + Duration::from_secs(10); + + assert_eq!(stats.opened_duration(later), Duration::from_secs(10)); + } + + #[test] + fn exit_when_half_open_with_pending_probe_returns_unchanged() { + use crate::circuit_breaker::engine::probing::{HealthProbeOptions, ProbeOptions}; + + let settings = EngineOptions { + break_duration: Duration::from_secs(5), + health_metrics_builder: HealthMetricsBuilder::new(Duration::from_secs(30), 0.1, 10), + // Use a HealthProbe with long sampling duration so it returns Pending + probes: ProbesOptions::new([ProbeOptions::HealthProbe(HealthProbeOptions::new( + Duration::from_secs(60), + 0.2, + 1.0, + ))]), + }; + let control = ClockControl::new(); + let clock = control.to_clock(); + let engine = EngineCore::new(settings, clock); + + // Force to open state + open_engine(&engine); + + // Advance time to transition to half-open + control.advance(Duration::from_secs(6)); + engine.enter(); + + // Record success - should return Unchanged because HealthProbe is still sampling + let result = engine.exit(ExecutionResult::Success, ExecutionMode::Probe); + assert!(matches!(result, ExitCircuitResult::Unchanged)); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/engine/engine_fake.rs b/crates/seatbelt/src/circuit_breaker/engine/engine_fake.rs new file mode 100644 index 00000000..e35a9cfa --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/engine_fake.rs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::circuit_breaker::{CircuitEngine, EnterCircuitResult, ExecutionMode, ExecutionResult, ExitCircuitResult}; + +/// Fake engine to be used in tests. +#[derive(Debug)] +pub(crate) struct EngineFake { + enter_result: EnterCircuitResult, + exit_result: ExitCircuitResult, +} + +impl EngineFake { + pub fn new(enter_result: EnterCircuitResult, exit_result: ExitCircuitResult) -> Self { + Self { enter_result, exit_result } + } +} + +impl CircuitEngine for EngineFake { + fn enter(&self) -> EnterCircuitResult { + self.enter_result.clone() + } + + fn exit(&self, _result: ExecutionResult, _mode: ExecutionMode) -> ExitCircuitResult { + self.exit_result.clone() + } +} diff --git a/crates/seatbelt/src/circuit_breaker/engine/engine_telemetry.rs b/crates/seatbelt/src/circuit_breaker/engine/engine_telemetry.rs new file mode 100644 index 00000000..1cf92831 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/engine_telemetry.rs @@ -0,0 +1,297 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::borrow::Cow; + +use tick::Clock; + +#[cfg(any(feature = "metrics", feature = "logs", test))] +use crate::circuit_breaker::CircuitState; +#[cfg(any(feature = "metrics", test))] +use crate::circuit_breaker::telemetry::*; +use crate::circuit_breaker::{CircuitEngine, EnterCircuitResult, ExecutionMode, ExecutionResult, ExitCircuitResult}; + +use crate::utils::TelemetryHelper; +#[cfg(any(feature = "metrics", test))] +use crate::utils::{EVENT_NAME, PIPELINE_NAME, STRATEGY_NAME}; + +/// Wrapper around a circuit engine to add telemetry capabilities. 
+#[derive(Debug)]
+pub(crate) struct EngineTelemetry<T> {
+    inner: T,
+    #[cfg(any(feature = "metrics", feature = "logs", test))]
+    pub(super) telemetry: TelemetryHelper,
+    #[cfg(any(feature = "metrics", feature = "logs", test))]
+    pub(super) partition_key: Cow<'static, str>,
+    #[cfg(any(feature = "metrics", feature = "logs", test))]
+    pub(super) clock: Clock,
+}
+
+impl<T> EngineTelemetry<T> {
+    #[cfg(any(feature = "metrics", feature = "logs", test))]
+    pub fn new(inner: T, telemetry: TelemetryHelper, partition_key: Cow<'static, str>, clock: Clock) -> Self {
+        Self {
+            inner,
+            telemetry,
+            partition_key,
+            clock,
+        }
+    }
+
+    #[cfg(not(any(feature = "metrics", feature = "logs", test)))]
+    pub fn new(inner: T, _telemetry: TelemetryHelper, _partition_key: Cow<'static, str>, _clock: Clock) -> Self {
+        Self { inner }
+    }
+}
+
+impl<T: CircuitEngine> CircuitEngine for EngineTelemetry<T> {
+    fn enter(&self) -> EnterCircuitResult {
+        let enter_result = self.inner.enter();
+
+        if matches!(enter_result, EnterCircuitResult::Rejected) {
+            #[cfg(any(feature = "metrics", test))]
+            if self.telemetry.metrics_enabled() {
+                self.telemetry.report_metrics(&[
+                    opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()),
+                    opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()),
+                    opentelemetry::KeyValue::new(EVENT_NAME, CIRCUIT_REJECTED_EVENT_NAME),
+                    opentelemetry::KeyValue::new(CIRCUIT_STATE, CircuitState::Open.as_str()),
+                    opentelemetry::KeyValue::new(CIRCUIT_PARTITION, self.partition_key.clone()),
+                ]);
+            }
+
+            #[cfg(any(feature = "logs", test))]
+            if self.telemetry.logs_enabled {
+                tracing::event!(
+                    name: "seatbelt.circuit_breaker.rejected",
+                    tracing::Level::WARN,
+                    pipeline.name = %self.telemetry.pipeline_name,
+                    strategy.name = %self.telemetry.strategy_name,
+                    circuit_breaker.state = CircuitState::Open.as_str(),
+                    circuit_breaker.partition = %self.partition_key,
+                );
+            }
+        }
+
+        enter_result
+    }
+
+    fn exit(&self, result: ExecutionResult, mode: ExecutionMode) -> ExitCircuitResult {
+        if mode == ExecutionMode::Probe {
+            #[cfg(any(feature = "metrics", test))]
+            if self.telemetry.metrics_enabled() {
+                self.telemetry.report_metrics(&[
+                    opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()),
+                    opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()),
+                    opentelemetry::KeyValue::new(EVENT_NAME, CIRCUIT_PROBE_EVENT_NAME),
+                    opentelemetry::KeyValue::new(CIRCUIT_STATE, CircuitState::HalfOpen.as_str()),
+                    opentelemetry::KeyValue::new(CIRCUIT_PARTITION, self.partition_key.clone()),
+                    opentelemetry::KeyValue::new(CIRCUIT_PROBE_RESULT, result.as_str()),
+                ]);
+            }
+
+            #[cfg(any(feature = "logs", test))]
+            if self.telemetry.logs_enabled {
+                tracing::event!(
+                    name: "seatbelt.circuit_breaker.probe",
+                    tracing::Level::INFO,
+                    pipeline.name = %self.telemetry.pipeline_name,
+                    strategy.name = %self.telemetry.strategy_name,
+                    circuit_breaker.state = CircuitState::HalfOpen.as_str(),
+                    circuit_breaker.partition = %self.partition_key,
+                    circuit_breaker.probe.result = result.as_str(),
+                );
+            }
+        }
+
+        let exit_result = self.inner.exit(result, mode);
+
+        // Emit telemetry events for circuit state changes
+        match exit_result {
+            ExitCircuitResult::Opened(health) => {
+                #[cfg(any(feature = "metrics", test))]
+                if self.telemetry.metrics_enabled() {
+                    self.telemetry.report_metrics(&[
+                        opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()),
+                        opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()),
+                        
opentelemetry::KeyValue::new(EVENT_NAME, CIRCUIT_OPENED_EVENT_NAME), + opentelemetry::KeyValue::new(CIRCUIT_STATE, CircuitState::Open.as_str()), + opentelemetry::KeyValue::new(CIRCUIT_PARTITION, self.partition_key.clone()), + ]); + } + + #[cfg(any(feature = "logs", test))] + if self.telemetry.logs_enabled { + tracing::event!( + name: "seatbelt.circuit_breaker.opened", + tracing::Level::WARN, + pipeline.name = %self.telemetry.pipeline_name, + strategy.name = %self.telemetry.strategy_name, + circuit_breaker.state = CircuitState::Open.as_str(), + circuit_breaker.partition = %self.partition_key, + circuit_breaker.health.failure_rate = health.failure_rate(), + circuit_breaker.health.throughput = health.throughput(), + ); + } + + _ = health; + } + ExitCircuitResult::Closed(ref stats) => { + #[cfg(any(feature = "metrics", test))] + if self.telemetry.metrics_enabled() { + self.telemetry.report_metrics(&[ + opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()), + opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()), + opentelemetry::KeyValue::new(EVENT_NAME, CIRCUIT_CLOSED_EVENT_NAME), + opentelemetry::KeyValue::new(CIRCUIT_STATE, CircuitState::Closed.as_str()), + opentelemetry::KeyValue::new(CIRCUIT_PARTITION, self.partition_key.clone()), + ]); + } + + #[cfg(any(feature = "logs", test))] + if self.telemetry.logs_enabled { + tracing::event!( + name: "seatbelt.circuit_breaker.closed", + tracing::Level::INFO, + pipeline.name = %self.telemetry.pipeline_name, + strategy.name = %self.telemetry.strategy_name, + circuit_breaker.state = CircuitState::Closed.as_str(), + circuit_breaker.open.duration = stats.opened_duration(self.clock.instant()).as_secs(), + circuit_breaker.partition = %self.partition_key, + circuit_breaker.probes.total = stats.probes_total, + circuit_breaker.probes.successfull = stats.probes_successes, + circuit_breaker.probes.failed = stats.probes_failures, + circuit_breaker.probes.lost = stats.probes_lost, + circuit_breaker.rejections = stats.rejected, + circuit_breaker.re_opened = stats.re_opened, + ); + } + + _ = stats; + } + ExitCircuitResult::Reopened | ExitCircuitResult::Unchanged => { + // We do not report a telemetry event for reopening the circuit + // as it is redundant because it is always preceded by an "opened" + // event, or when there is no state change. 
+ } + } + + exit_result + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +#[cfg(not(miri))] +mod tests { + use std::time::Instant; + + use opentelemetry::KeyValue; + + use super::*; + use crate::circuit_breaker::{EngineFake, HealthInfo, Stats}; + use crate::metrics::{create_meter, create_resilience_event_counter}; + use crate::testing::MetricTester; + + #[test] + fn enter_rejected_ensure_telemetry() { + let (tester, telemetry_engine) = create_engine(EngineFake::new( + EnterCircuitResult::Rejected, + ExitCircuitResult::Closed(Stats::new(Instant::now())), + )); + + let _ = telemetry_engine.enter(); + + tester.assert_attributes( + &[ + KeyValue::new(PIPELINE_NAME, "test_pipeline"), + KeyValue::new(STRATEGY_NAME, "test_strategy"), + KeyValue::new(EVENT_NAME, CIRCUIT_REJECTED_EVENT_NAME), + KeyValue::new(CIRCUIT_PARTITION, "test_partition"), + KeyValue::new(CIRCUIT_STATE, CircuitState::Open.as_str()), + ], + Some(5), + ); + } + + #[test] + fn exit_probe_ensure_telemetry() { + let (tester, telemetry_engine) = create_engine(EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Unchanged, + )); + + let _ = telemetry_engine.exit(ExecutionResult::Success, ExecutionMode::Probe); + + tester.assert_attributes( + &[ + KeyValue::new(PIPELINE_NAME, "test_pipeline"), + KeyValue::new(STRATEGY_NAME, "test_strategy"), + KeyValue::new(EVENT_NAME, CIRCUIT_PROBE_EVENT_NAME), + KeyValue::new(CIRCUIT_PARTITION, "test_partition"), + KeyValue::new(CIRCUIT_STATE, CircuitState::HalfOpen.as_str()), + KeyValue::new(CIRCUIT_PROBE_RESULT, ExecutionResult::Success.as_str()), + ], + Some(6), + ); + } + + #[test] + fn circuit_closed_ensure_telemetry() { + let (tester, telemetry_engine) = create_engine(EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Closed(Stats::new(Instant::now())), + )); + + let _ = telemetry_engine.exit(ExecutionResult::Success, ExecutionMode::Normal); + + tester.assert_attributes( + &[ + KeyValue::new(PIPELINE_NAME, "test_pipeline"), + KeyValue::new(STRATEGY_NAME, "test_strategy"), + KeyValue::new(EVENT_NAME, CIRCUIT_CLOSED_EVENT_NAME), + KeyValue::new(CIRCUIT_PARTITION, "test_partition"), + KeyValue::new(CIRCUIT_STATE, CircuitState::Closed.as_str()), + ], + Some(5), + ); + } + + #[test] + fn circuit_opened_ensure_telemetry() { + let (tester, telemetry_engine) = create_engine(EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Opened(HealthInfo::new(1, 0, 0.75, 100)), + )); + + let _ = telemetry_engine.exit(ExecutionResult::Failure, ExecutionMode::Normal); + tester.assert_attributes( + &[ + KeyValue::new(PIPELINE_NAME, "test_pipeline"), + KeyValue::new(STRATEGY_NAME, "test_strategy"), + KeyValue::new(EVENT_NAME, CIRCUIT_OPENED_EVENT_NAME), + KeyValue::new(CIRCUIT_PARTITION, "test_partition"), + KeyValue::new(CIRCUIT_STATE, CircuitState::Open.as_str()), + ], + Some(5), + ); + } + + fn create_engine(engine: EngineFake) -> (MetricTester, EngineTelemetry) { + let tester = MetricTester::new(); + let telemetry = TelemetryHelper { + pipeline_name: "test_pipeline".into(), + strategy_name: "test_strategy".into(), + event_reporter: Some(create_resilience_event_counter(&create_meter(tester.meter_provider()))), + logs_enabled: true, + }; + let telemetry_engine = EngineTelemetry::new(engine, telemetry, "test_partition".into(), Clock::new_frozen()); + (tester, telemetry_engine) + } +} diff --git 
a/crates/seatbelt/src/circuit_breaker/engine/engines.rs b/crates/seatbelt/src/circuit_breaker/engine/engines.rs
new file mode 100644
index 00000000..b4ee1729
--- /dev/null
+++ b/crates/seatbelt/src/circuit_breaker/engine/engines.rs
@@ -0,0 +1,97 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+use tick::Clock;
+
+use crate::circuit_breaker::constants::ERR_POISONED_LOCK;
+use crate::circuit_breaker::{Engine, EngineCore, EngineOptions, EngineTelemetry, PartitionKey};
+use crate::utils::TelemetryHelper;
+
+/// Manages circuit breaker engines for different partition keys.
+#[derive(Debug)]
+pub(crate) struct Engines {
+    map: Mutex<HashMap<PartitionKey, Arc<Engine>>>,
+    engine_options: EngineOptions,
+    clock: Clock,
+    telemetry: TelemetryHelper,
+}
+
+impl Engines {
+    pub fn new(engine_options: EngineOptions, clock: Clock, telemetry: TelemetryHelper) -> Self {
+        Self {
+            map: Mutex::new(HashMap::new()),
+            engine_options,
+            clock,
+            telemetry,
+        }
+    }
+
+    pub fn get_engine(&self, key: &PartitionKey) -> Arc<Engine> {
+        let mut map = self.map.lock().expect(ERR_POISONED_LOCK);
+
+        if let Some(engine) = map.get(key) {
+            return Arc::clone(engine);
+        }
+
+        let engine = Arc::new(self.create_engine(key));
+        map.insert(key.clone(), Arc::clone(&engine));
+        engine
+    }
+
+    #[cfg(test)]
+    fn len(&self) -> usize {
+        let map = self.map.lock().expect(ERR_POISONED_LOCK);
+        map.len()
+    }
+
+    fn create_engine(&self, key: &PartitionKey) -> Engine {
+        EngineTelemetry::new(
+            EngineCore::new(self.engine_options.clone(), self.clock.clone()),
+            self.telemetry.clone(),
+            key.clone().into(),
+            self.clock.clone(),
+        )
+    }
+}
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use super::*;
+    use crate::circuit_breaker::HealthMetricsBuilder;
+    use crate::circuit_breaker::engine::probing::ProbesOptions;
+    use crate::metrics::create_resilience_event_counter;
+
+    #[test]
+    fn get_engine_ok() {
+        let telemetry = TelemetryHelper {
+            pipeline_name: "pipeline".into(),
+            strategy_name: "strategy".into(),
+            event_reporter: Some(create_resilience_event_counter(&opentelemetry::global::meter("test"))),
+            logs_enabled: false,
+        };
+        let engines = Engines::new(
+            EngineOptions {
+                break_duration: Duration::from_secs(60),
+                health_metrics_builder: HealthMetricsBuilder::new(Duration::from_millis(100), 0.5, 5),
+                probes: ProbesOptions::quick(Duration::from_secs(1)),
+            },
+            Clock::new_frozen(),
+            telemetry,
+        );
+
+        assert!(Arc::ptr_eq(
+            &engines.get_engine(&PartitionKey::from("test")),
+            &engines.get_engine(&PartitionKey::from("test"))
+        ));
+        assert_eq!(engines.len(), 1);
+
+        _ = engines.get_engine(&PartitionKey::from("test2"));
+        assert_eq!(engines.len(), 2);
+    }
+}
diff --git a/crates/seatbelt/src/circuit_breaker/engine/mod.rs b/crates/seatbelt/src/circuit_breaker/engine/mod.rs
new file mode 100644
index 00000000..940cfdb8
--- /dev/null
+++ b/crates/seatbelt/src/circuit_breaker/engine/mod.rs
@@ -0,0 +1,112 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+use std::fmt::Debug;
+use std::time::Duration;
+
+use crate::circuit_breaker::{ExecutionResult, HealthInfo, HealthMetricsBuilder};
+
+pub(super) mod probing;
+
+#[cfg(any(feature = "metrics", feature = "logs", test))]
+#[derive(Debug, Copy, Clone)]
+pub(crate) enum CircuitState {
+    Closed,
+    Open,
+    HalfOpen,
+}
+
+#[cfg(any(feature = "metrics", feature = "logs", test))]
+impl CircuitState {
+    pub fn as_str(self) -> &'static str {
+        match self {
+            Self::Closed => "closed",
+            Self::Open => "open",
+            Self::HalfOpen => "half_open",
+        }
+    }
+}
+
+/// Result of attempting to enter the circuit.
+#[derive(Debug, Clone)]
+pub(crate) enum EnterCircuitResult {
+    /// The operation is allowed to proceed.
+    ///
+    /// The [`ExecutionMode::Probe`] mode indicates that this is a test operation used to
+    /// evaluate whether the circuit can be closed again.
+    Accepted { mode: ExecutionMode },
+
+    /// Operation is rejected due to open circuit.
+    Rejected,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) enum ExitCircuitResult {
+    /// The state remains unchanged.
+    Unchanged,
+
+    /// Circuit transitioned to Open state.
+    Opened(HealthInfo),
+
+    /// Circuit re-transitioned to Open state due to a failure in Half-Open state.
+    Reopened,
+
+    /// Circuit transitioned back to Closed state.
+    Closed(Stats),
+}
+
+/// Configuration options for the circuit breaker engine.
+#[derive(Debug, Clone)]
+pub(crate) struct EngineOptions {
+    pub break_duration: Duration,
+    pub health_metrics_builder: HealthMetricsBuilder,
+    pub probes: probing::ProbesOptions,
+}
+
+/// Determines the mode of execution for an operation.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub(crate) enum ExecutionMode {
+    /// Regular operation.
+    Normal,
+
+    /// A probe operation to test the health of the underlying service.
+    Probe,
+}
+
+// Type alias for the default engine with telemetry.
+pub type Engine = EngineTelemetry<EngineCore>;
+
+/// Trait defining the behavior of a circuit breaker engine.
+pub(crate) trait CircuitEngine: Debug + Send + Sync + 'static {
+    fn enter(&self) -> EnterCircuitResult;
+
+    fn exit(&self, result: ExecutionResult, mode: ExecutionMode) -> ExitCircuitResult;
+}
+
+mod engine_core;
+pub(crate) use engine_core::*;
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+#[cfg(all(test, not(miri)))]
+mod engine_fake;
+#[cfg(all(test, not(miri)))]
+pub(crate) use engine_fake::*;
+
+mod engine_telemetry;
+pub(crate) use engine_telemetry::*;
+
+mod engines;
+pub(crate) use engines::*;
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_circuit_state_as_str() {
+        assert_eq!(CircuitState::Closed.as_str(), "closed");
+        assert_eq!(CircuitState::Open.as_str(), "open");
+        assert_eq!(CircuitState::HalfOpen.as_str(), "half_open");
+    }
+}
diff --git a/crates/seatbelt/src/circuit_breaker/engine/probing/health_probe.rs b/crates/seatbelt/src/circuit_breaker/engine/probing/health_probe.rs
new file mode 100644
index 00000000..b1496ec6
--- /dev/null
+++ b/crates/seatbelt/src/circuit_breaker/engine/probing/health_probe.rs
@@ -0,0 +1,201 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
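The probe admission logic in the file below combines a random sampling gate with a per-stage fallback. A simplified model of that decision, with assumed names and plain integers standing in for `Instant` values:

```rust
// Rough model of HealthProbe::allow_probe: admit roughly `probing_ratio` of candidate
// requests as probes; if none were admitted by the time the stage deadline passes,
// let a single fallback probe through and push the deadline out by another stage so
// low-traffic partitions can still make progress.
fn admit_probe(rng_sample: f64, probing_ratio: f64, now: u64, fallback_after: &mut u64, stage: u64) -> bool {
    if rng_sample < probing_ratio {
        return true; // admitted by the sampling gate
    }
    if now > *fallback_after {
        *fallback_after = now + stage; // at most one fallback probe per elapsed stage
        return true;
    }
    false
}
```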
+ +use std::time::Instant; + +use crate::circuit_breaker::engine::probing::{AllowProbeResult, HealthProbeOptions, ProbeOperation, ProbingResult}; +use crate::circuit_breaker::{ExecutionResult, HealthMetrics, HealthStatus}; +use crate::rnd::Rnd; + +#[derive(Debug)] +pub(crate) struct HealthProbe { + options: HealthProbeOptions, + metrics: HealthMetrics, + fallback_after: Option, + sample_until: Option, + rnd: Rnd, +} + +impl ProbeOperation for HealthProbe { + fn allow_probe(&mut self, now: Instant) -> AllowProbeResult { + // Sampling starts with the first probe attempt. Make sure relevant timestamps are set. + let sample_until = *self.sample_until.get_or_insert_with(|| now + self.options.stage_duration()); + + // Fallback probe is allowed only after the sampling duration has elapsed. + let fallback_after = *self.fallback_after.get_or_insert(sample_until); + + // Allow probe based on the probing ratio. + if self.rnd.next_f64() < self.options.probing_ratio { + return AllowProbeResult::Accepted; + } + + // Allow fallback probe to get through if we are past the sampling duration. + // This can happen if the traffic is very low and no probes were allowed + // by the rate sampling. This allows making progress in low-traffic scenarios + // as a last resort. + if now > fallback_after { + // Allow additional fallback probes only after another sampling duration + // in case allowed probe did not result in a recorded execution (e.g., due to timeout). + self.fallback_after = Some(now + self.options.stage_duration()); + return AllowProbeResult::Accepted; + } + + AllowProbeResult::Rejected + } + + fn record(&mut self, result: ExecutionResult, now: Instant) -> ProbingResult { + // Always record the result + self.metrics.record(result, now); + + // If we are still sampling, we cannot make a decision yet + if self.keep_sampling(now) { + return ProbingResult::Pending; + } + + // Sampling duration elapsed, use the health metrics to determine the result + match self.metrics.health_info().status() { + HealthStatus::Healthy => ProbingResult::Success, + HealthStatus::Unhealthy => ProbingResult::Failure, + } + } +} + +impl HealthProbe { + pub fn new(options: HealthProbeOptions) -> Self { + Self { + metrics: options.builder.build(), + options, + fallback_after: None, + sample_until: None, + rnd: Rnd::Real, + } + } + + fn keep_sampling(&self, now: Instant) -> bool { + match self.sample_until { + None => true, + Some(until) if now < until => true, + _ => false, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn allow_probe_fallback() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.5, 0.1); + let mut probe = HealthProbe::new(options); + probe.rnd = Rnd::new_fixed(0.5); + let now = Instant::now(); + + assert_eq!(probe.allow_probe(now), AllowProbeResult::Rejected); + + let later = now + Duration::from_secs(5); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Rejected); + + // Allowed, because we are past the sampling duration and no probes were allowed yet + let later = now + Duration::from_secs(5) + Duration::from_micros(1); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Accepted); + + // Not allowed, because we already let the fallback probe through + assert_eq!(probe.allow_probe(later), AllowProbeResult::Rejected); + + // Allowed again + let later = now + Duration::from_secs(10) + Duration::from_micros(2); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Accepted); + } + + 
#[test] + fn allow_probe_rejected_when_at_ratio() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.5, 0.1); + let mut probe = HealthProbe::new(options); + probe.rnd = Rnd::new_fixed(0.1); + + assert_eq!(probe.allow_probe(Instant::now()), AllowProbeResult::Rejected); + } + + #[test] + fn record_not_allowed_before() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.99, 0.1); + let mut probe = HealthProbe::new(options); + let now = Instant::now(); + + assert_eq!(probe.record(ExecutionResult::Success, now), ProbingResult::Pending,); + + assert_eq!(probe.record(ExecutionResult::Success, now), ProbingResult::Pending,); + + let status = probe.metrics.health_info(); + assert_eq!(status.status(), HealthStatus::Healthy); + assert_eq!(status.throughput(), 2); + } + + #[test] + fn allow_then_record_after_sampling_period_healthy() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.1, 1.0); + let mut probe = HealthProbe::new(options); + let now = Instant::now(); + + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + + // At the edge of a sampling period, success + assert_eq!( + probe.record(ExecutionResult::Success, now + Duration::from_secs(5)), + ProbingResult::Success, + ); + + assert_eq!( + probe.record(ExecutionResult::Success, now + Duration::from_secs(10)), + ProbingResult::Success, + ); + } + + #[test] + fn allow_then_record_after_sampling_period_unhealthy() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.1, 1.0); + let mut probe = HealthProbe::new(options); + let now = Instant::now(); + + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + assert_eq!( + probe.record(ExecutionResult::Failure, now + Duration::from_secs(10)), + ProbingResult::Failure, + ); + } + + #[test] + fn record_multiple_ensure_health_evaluated() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.6, 1.0); + let mut probe = HealthProbe::new(options); + let now = Instant::now(); + + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + assert_eq!( + probe.record(ExecutionResult::Success, now + Duration::from_secs(1)), + ProbingResult::Pending, + ); + assert_eq!( + probe.record(ExecutionResult::Failure, now + Duration::from_secs(2)), + ProbingResult::Pending, + ); + assert_eq!( + probe.record(ExecutionResult::Success, now + Duration::from_secs(6)), + ProbingResult::Success, + ); + + assert_eq!( + probe.record(ExecutionResult::Failure, now + Duration::from_secs(6)), + ProbingResult::Success, + ); + + assert_eq!( + probe.record(ExecutionResult::Failure, now + Duration::from_secs(6)), + ProbingResult::Failure, + ); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/engine/probing/mod.rs b/crates/seatbelt/src/circuit_breaker/engine/probing/mod.rs new file mode 100644 index 00000000..58a90dc4 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/probing/mod.rs @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Probing mechanisms for circuit breakers. +//! +//! Probing is used to test if a service has recovered after a failure. +//! Different probing strategies can be implemented by implementing the `ProbeOperation` trait. +//! +//! - Various probes can be combined in sequence using the [`Probes`] struct. +//! - Unified view over various probe types is provided by the [`Probe`] enum. 
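As a short illustration of how these pieces compose (the values are arbitrary examples; the types are the ones defined in this module):

```rust
use std::time::{Duration, Instant};

fn build_probe_sequence() {
    // A quick SingleProbe stage followed by HealthProbe stages with increasing
    // probing ratios. Each stage must report Success before the next one starts;
    // a single Failure sends the circuit straight back to the open state.
    let options = ProbesOptions::gradual(&[0.01, 0.1, 0.5], Duration::from_secs(30), 0.2);
    let mut probes = Probes::new(&options);
    let _ = probes.allow_probe(Instant::now());
}
```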
+ +use std::fmt::Debug; +use std::time::Instant; + +use crate::circuit_breaker::{EnterCircuitResult, ExecutionMode, ExecutionResult}; + +mod health_probe; +mod options; +mod probes; +mod single_probe; + +pub(crate) use health_probe::*; +pub(crate) use options::*; +pub(crate) use probes::*; +pub(crate) use single_probe::*; + +/// Result of a probing attempt. +#[derive(Debug, Copy, Clone, PartialEq)] +pub(crate) enum ProbingResult { + /// Probing succeeded, no more probing needed. + Success, + + /// Probing failed, circuit should remain open. + Failure, + + /// Probing is still in progress, more probes are needed. + Pending, +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub(crate) enum AllowProbeResult { + Accepted, + Rejected, +} + +impl From for EnterCircuitResult { + fn from(value: AllowProbeResult) -> Self { + match value { + AllowProbeResult::Accepted => Self::Accepted { + mode: ExecutionMode::Probe, + }, + AllowProbeResult::Rejected => Self::Rejected, + } + } +} + +/// Trait defining the behavior of a probing mechanism in a circuit breaker. +pub(crate) trait ProbeOperation: Send + Sync + Debug + 'static { + fn allow_probe(&mut self, now: Instant) -> AllowProbeResult; + + fn record(&mut self, result: ExecutionResult, now: Instant) -> ProbingResult; +} + +/// View over multiple probe types. +#[derive(Debug)] +pub(crate) enum Probe { + Single(SingleProbe), + Health(HealthProbe), +} + +impl Probe { + pub fn new(options: ProbeOptions) -> Self { + match options { + ProbeOptions::SingleProbe { cooldown } => Self::Single(SingleProbe::new(cooldown)), + ProbeOptions::HealthProbe(options) => Self::Health(HealthProbe::new(options)), + } + } +} + +impl ProbeOperation for Probe { + fn allow_probe(&mut self, now: Instant) -> AllowProbeResult { + match self { + Self::Single(probe) => probe.allow_probe(now), + Self::Health(health) => health.allow_probe(now), + } + } + + /// Record the result of a probing attempt. + /// + /// Once the probe reports success or failure, it is considered complete and + /// should never be used again. 
+ fn record(&mut self, result: ExecutionResult, now: Instant) -> ProbingResult { + match self { + Self::Single(probe) => probe.record(result, now), + Self::Health(health) => health.record(result, now), + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn probe_new_creates_single_probe() { + let cooldown = Duration::from_secs(5); + let probe = Probe::new(ProbeOptions::SingleProbe { cooldown }); + assert!(matches!(probe, Probe::Single(duration) if duration.probe_cooldown() == cooldown)); + } + + #[test] + fn probe_allow_probe_delegates_to_inner() { + let mut probe = Probe::new(ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(5), + }); + let now = Instant::now(); + + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + assert_eq!(probe.allow_probe(now), AllowProbeResult::Rejected); + } + + #[test] + fn probe_record_delegates_to_inner() { + let mut probe = Probe::new(ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(5), + }); + let now = Instant::now(); + + assert_eq!(probe.record(ExecutionResult::Success, now), ProbingResult::Success); + assert_eq!(probe.record(ExecutionResult::Failure, now), ProbingResult::Failure); + } + + #[test] + fn allow_probe_result_to_enter_circuit_result_ok() { + assert!(matches!( + EnterCircuitResult::from(AllowProbeResult::Accepted), + EnterCircuitResult::Accepted { + mode: ExecutionMode::Probe + } + )); + + assert!(matches!( + EnterCircuitResult::from(AllowProbeResult::Rejected), + EnterCircuitResult::Rejected + )); + } + + #[test] + fn probe_new_creates_health_probe() { + let options = HealthProbeOptions::new(Duration::from_secs(10), 0.2, 0.5); + let probe = Probe::new(ProbeOptions::HealthProbe(options)); + assert!(matches!(probe, Probe::Health(_))); + } + + #[test] + fn probe_health_allow_probe_delegates_to_inner() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.2, 1.0); + let mut probe = Probe::new(ProbeOptions::HealthProbe(options)); + let now = Instant::now(); + + // With probing_ratio=1.0, all probes should be accepted + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + } + + #[test] + fn probe_health_record_delegates_to_inner() { + let options = HealthProbeOptions::new(Duration::from_secs(5), 0.2, 1.0); + let mut probe = Probe::new(ProbeOptions::HealthProbe(options)); + let now = Instant::now(); + + // allow_probe initializes the sampling period + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + + // Record before sampling period ends returns Pending + assert_eq!(probe.record(ExecutionResult::Success, now), ProbingResult::Pending); + + // Record after sampling period with success returns Success + assert_eq!( + probe.record(ExecutionResult::Success, now + Duration::from_secs(5)), + ProbingResult::Success + ); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/engine/probing/options.rs b/crates/seatbelt/src/circuit_breaker/engine/probing/options.rs new file mode 100644 index 00000000..a06ba905 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/probing/options.rs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; +use std::vec::IntoIter; + +use crate::circuit_breaker::HealthMetricsBuilder; + +/// The minimum throughput during probing stage is set to 1, so at least one request must come +/// through in each probing stage to evaluate the health. 
+const MIN_THROUGHPUT: u32 = 1;
+
+/// Options for a single probe type.
+#[derive(Debug, Clone)]
+pub(crate) enum ProbeOptions {
+    /// A single probe that allows one probe.
+    ///
+    /// After the initial probe is allowed, it enters a cool-down period during which
+    /// no further probes are allowed.
+    SingleProbe { cooldown: Duration },
+
+    /// A health-based probe that uses health metrics to determine the health of the system.
+    HealthProbe(HealthProbeOptions),
+}
+
+/// Configuration options for the probing mechanism.
+#[derive(Debug, Clone)]
+pub(crate) struct ProbesOptions {
+    probes: Vec<ProbeOptions>,
+}
+
+impl ProbesOptions {
+    pub fn quick(cooldown: Duration) -> Self {
+        Self::new([ProbeOptions::SingleProbe { cooldown }])
+    }
+
+    pub fn reliable(stage_duration: Duration, failure_threshold: f32) -> Self {
+        Self::gradual(&[0.001, 0.01, 0.05, 0.1, 0.25, 0.5], stage_duration, failure_threshold)
+    }
+
+    pub fn gradual(probing_ratio: &[f64], stage_duration: Duration, failure_threshold: f32) -> Self {
+        // Start with a single probe
+        let initial = std::iter::once(ProbeOptions::SingleProbe { cooldown: stage_duration });
+
+        // Then continue with health-based probes
+        let health = probing_ratio
+            .iter()
+            .map(|probing_ratio| ProbeOptions::HealthProbe(HealthProbeOptions::new(stage_duration, failure_threshold, *probing_ratio)));
+
+        Self::new(initial.chain(health))
+    }
+
+    pub fn new(probes: impl IntoIterator<Item = ProbeOptions>) -> Self {
+        let probes: Vec<ProbeOptions> = probes.into_iter().collect();
+        assert!(!probes.is_empty(), "the probes list cannot be empty");
+        Self { probes }
+    }
+
+    pub fn probes(&self) -> IntoIter<ProbeOptions> {
+        self.probes.clone().into_iter()
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct HealthProbeOptions {
+    pub(super) builder: HealthMetricsBuilder,
+    pub(super) probing_ratio: f64,
+}
+
+impl HealthProbeOptions {
+    pub fn new(stage_duration: Duration, failure_threshold: f32, probing_ratio: f64) -> Self {
+        assert!(probing_ratio > 0.0 && probing_ratio <= 1.0, "probing_ratio must be in (0.0, 1.0]");
+        assert!((0.0..1.0).contains(&failure_threshold), "failure_threshold must be in [0.0, 1.0)");
+        assert!(stage_duration > Duration::ZERO, "stage_duration must be greater than zero");
+
+        Self {
+            // The min throughput is set to 1 (`MIN_THROUGHPUT`), so at least one request must
+            // come through during the probing stage before the health can be evaluated.
+ builder: HealthMetricsBuilder::new(stage_duration, failure_threshold, MIN_THROUGHPUT), + probing_ratio, + } + } + + pub fn stage_duration(&self) -> Duration { + self.builder.sampling_duration + } + + #[cfg(test)] + pub fn failure_threshold(&self) -> f32 { + self.builder.failure_threshold + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use static_assertions::assert_impl_all; + + use super::*; + + assert_impl_all!(ProbeOptions: Clone, std::fmt::Debug); + assert_impl_all!(ProbesOptions: Clone, std::fmt::Debug); + #[test] + fn single_probe_constructor_creates_correct_options() { + let cooldown = Duration::from_secs(15); + let options = ProbesOptions::quick(cooldown); + let probes: Vec<_> = options.probes().collect(); + + assert_eq!(probes.len(), 1); + assert!(matches!( + &probes[0], + ProbeOptions::SingleProbe { cooldown: c } if *c == Duration::from_secs(15) + )); + } + + #[test] + fn new_accepts_multiple_probes() { + let options = ProbesOptions::new([ + ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(10), + }, + ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(20), + }, + ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(30), + }, + ]); + + let probes: Vec<_> = options.probes().collect(); + assert_eq!(probes.len(), 3); + assert!(matches!(&probes[0], ProbeOptions::SingleProbe { cooldown } if *cooldown == Duration::from_secs(10))); + assert!(matches!(&probes[1], ProbeOptions::SingleProbe { cooldown } if *cooldown == Duration::from_secs(20))); + assert!(matches!(&probes[2], ProbeOptions::SingleProbe { cooldown } if *cooldown == Duration::from_secs(30))); + } + + #[test] + fn clone_preserves_probe_count() { + let options = ProbesOptions::quick(Duration::from_secs(25)); + let cloned = options.clone(); + + assert_eq!(options.probes().count(), cloned.probes().count()); + } + + #[test] + fn probes_iterator_is_reusable() { + let options = ProbesOptions::quick(Duration::from_secs(30)); + + assert_eq!(options.probes().count(), 1); + assert_eq!(options.probes().count(), 1); + } + + #[test] + #[should_panic(expected = "the probes list cannot be empty")] + fn new_panics_with_empty_iterator() { + let _ = ProbesOptions::new(Vec::::new()); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn health_probe_options_ctor_ok() { + let sampling_duration = Duration::from_secs(60); + let failure_threshold = 0.2; + let probing_ratio = 0.1; + + let options = HealthProbeOptions::new(sampling_duration, failure_threshold, probing_ratio); + + assert_eq!(options.stage_duration(), sampling_duration); + assert_eq!(options.probing_ratio, probing_ratio); + assert_eq!(options.builder.failure_threshold, failure_threshold); + assert_eq!(options.builder.min_throughput, 1); + } + + #[should_panic(expected = "stage_duration must be greater than zero")] + #[test] + fn health_probe_options_ctor_sampling_duration() { + let _ = HealthProbeOptions::new(Duration::ZERO, 0.1, 0.5); + } + + #[should_panic(expected = "failure_threshold must be in [0.0, 1.0)")] + #[test] + fn health_probe_options_ctor_failure_threshold() { + let _ = HealthProbeOptions::new(Duration::from_secs(10), 1.0, 0.5); + } + + #[should_panic(expected = "probing_ratio must be in (0.0, 1.0]")] + #[test] + fn health_probe_options_ctor_probing_ratio() { + let _ = HealthProbeOptions::new(Duration::from_secs(10), 0.1, 0.0); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn probes_options_reliable_ok() { + let options = ProbesOptions::reliable(Duration::from_secs(30), 0.2); 
+ let probes: Vec<_> = options.probes().collect(); + + assert_eq!(probes.len(), 7); + assert!(matches!( + &probes[0], + ProbeOptions::SingleProbe { cooldown } if *cooldown == Duration::from_secs(30) + )); + + let expected_ratios = [0.001, 0.01, 0.05, 0.1, 0.25, 0.5]; + for (i, ratio) in expected_ratios.iter().enumerate() { + let probe = &probes[i + 1]; + + match probe { + ProbeOptions::HealthProbe(options) => { + assert_eq!(options.builder.sampling_duration, Duration::from_secs(30)); + assert_eq!(options.builder.failure_threshold, 0.2); + assert_eq!(options.probing_ratio, *ratio); + } + ProbeOptions::SingleProbe { .. } => panic!("expected HealthProbe"), + } + } + } +} diff --git a/crates/seatbelt/src/circuit_breaker/engine/probing/probes.rs b/crates/seatbelt/src/circuit_breaker/engine/probing/probes.rs new file mode 100644 index 00000000..2b9529fd --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/probing/probes.rs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Instant; +use std::vec; + +use super::{AllowProbeResult, Probe, ProbeOperation, ProbeOptions, ProbesOptions, ProbingResult}; +use crate::circuit_breaker::ExecutionResult; + +/// Manages a sequence of probes. +#[derive(Debug)] +pub(crate) struct Probes { + probes: vec::IntoIter, + current: Probe, +} + +impl Probes { + pub fn new(options: &ProbesOptions) -> Self { + let mut probes = options.probes(); + let probe = probes.next().expect("probes are never empty because ProbesOptions enforces that"); + + Self { + probes, + current: Probe::new(probe), + } + } + + pub fn allow_probe(&mut self, now: Instant) -> AllowProbeResult { + self.current.allow_probe(now) + } + + pub fn record(&mut self, result: ExecutionResult, now: Instant) -> ProbingResult { + match self.current.record(result, now) { + ProbingResult::Success => { + // check if there are more probes to try + match self.probes.next() { + Some(probe) => { + self.current = Probe::new(probe); + ProbingResult::Pending + } + None => ProbingResult::Success, + } + } + ProbingResult::Pending => ProbingResult::Pending, + ProbingResult::Failure => ProbingResult::Failure, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + + use std::time::Duration; + + use tick::Clock; + + use super::*; + use crate::circuit_breaker::engine::probing::HealthProbeOptions; + + #[test] + fn multiple_probes_ok() { + let options = ProbesOptions::new([ + ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(1), + }, + ProbeOptions::SingleProbe { + cooldown: Duration::from_secs(2), + }, + ]); + let mut probes = Probes::new(&options); + let now = Instant::now(); + + assert_eq!(probes.allow_probe(now), AllowProbeResult::Accepted); + assert_eq!(probes.allow_probe(now), AllowProbeResult::Rejected); + assert_eq!(probes.record(ExecutionResult::Success, now), ProbingResult::Pending); + + assert_eq!(probes.allow_probe(now), AllowProbeResult::Accepted); + assert_eq!(probes.record(ExecutionResult::Success, now), ProbingResult::Success); + + assert!(probes.probes.next().is_none()); + } + + #[test] + fn record_returns_pending_when_probe_returns_pending() { + let now = Clock::new_frozen().instant(); + + let options = ProbesOptions::new([ProbeOptions::HealthProbe(HealthProbeOptions::new(Duration::from_secs(5), 0.2, 1.0))]); + let mut probes = Probes::new(&options); + + // Initialize sampling period + assert_eq!(probes.allow_probe(now), AllowProbeResult::Accepted); + + // Record during sampling period returns Pending 
+ assert_eq!(probes.record(ExecutionResult::Success, now), ProbingResult::Pending); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/engine/probing/single_probe.rs b/crates/seatbelt/src/circuit_breaker/engine/probing/single_probe.rs new file mode 100644 index 00000000..523b5bd0 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/engine/probing/single_probe.rs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::{Duration, Instant}; + +use super::{AllowProbeResult, ProbeOperation, ProbingResult}; +use crate::circuit_breaker::ExecutionResult; + +/// Allows a single probe to get in and based on the result either closes the circuit +/// or goes back to open state. +#[derive(Debug, Clone)] +pub(crate) struct SingleProbe { + probe_cooldown: Duration, + entered_at: Option, +} + +impl SingleProbe { + pub fn new(probe_cooldown: Duration) -> Self { + Self { + probe_cooldown, + entered_at: None, + } + } + + #[cfg(test)] + pub fn probe_cooldown(&self) -> Duration { + self.probe_cooldown + } +} + +impl ProbeOperation for SingleProbe { + fn allow_probe(&mut self, now: Instant) -> AllowProbeResult { + match self.entered_at { + // First probe attempt - record the timestamp to start the cool-down period + None => { + self.entered_at = Some(now); + AllowProbeResult::Accepted + } + // Cool-down has elapsed, allow the probe and reset the cool-down timer. + // We allow additional probe after the cool-down period to handle the case + // where the probe result is not recorded due to future being dropped. + Some(entered_at) if now.saturating_duration_since(entered_at) > self.probe_cooldown => { + self.entered_at = Some(now); + AllowProbeResult::Accepted + } + Some(_) => AllowProbeResult::Rejected, + } + } + + fn record(&mut self, result: ExecutionResult, _now: Instant) -> ProbingResult { + match result { + ExecutionResult::Success => ProbingResult::Success, + ExecutionResult::Failure => ProbingResult::Failure, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn allow_probe_accepts_single_probe() { + let mut probe = SingleProbe::new(Duration::from_secs(5)); + let now = Instant::now(); + + // The first probe should be accepted + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + + // The second probe immediately should be rejected + assert_eq!(probe.allow_probe(now), AllowProbeResult::Rejected); + + // After 3 seconds, still should be rejected + let later = now + Duration::from_secs(3); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Rejected); + + // After cooldown, the probe should be accepted again + let later = now + Duration::from_secs(6); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Accepted); + } + + #[test] + fn allow_probe_check_bounds() { + let mut probe = SingleProbe::new(Duration::from_secs(5)); + let now = Instant::now(); + + // The first probe should be accepted + assert_eq!(probe.allow_probe(now), AllowProbeResult::Accepted); + + // After exactly cool-down duration, the probe should still be rejected + let later = now + Duration::from_secs(5); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Rejected); + + // After cool-down + 1 microsecond, the probe should be accepted + let later = now + Duration::from_secs(5) + Duration::from_micros(1); + assert_eq!(probe.allow_probe(later), AllowProbeResult::Accepted); + } + + #[test] + fn record_ensure_correct_result() { + let mut probe = SingleProbe::new(Duration::from_secs(5)); + 
let now = Instant::now(); + + // Record a success + assert_eq!(probe.record(ExecutionResult::Success, now), ProbingResult::Success); + + // Record a failure + assert_eq!(probe.record(ExecutionResult::Failure, now), ProbingResult::Failure); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/execution_result.rs b/crates/seatbelt/src/circuit_breaker/execution_result.rs new file mode 100644 index 00000000..2d38f8f7 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/execution_result.rs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::RecoveryInfo; + +/// An evaluated execution result. +/// +/// From the perspective of a circuit breaker, an execution can either +/// succeed or fail. This enum captures that binary outcome. +#[derive(Debug, PartialEq, Copy, Clone)] +pub(crate) enum ExecutionResult { + Success, + Failure, +} + +#[cfg(any(feature = "logs", feature = "metrics", test))] +impl ExecutionResult { + pub fn as_str(self) -> &'static str { + match self { + Self::Success => "success", + Self::Failure => "failure", + } + } +} + +impl ExecutionResult { + pub fn from_recovery(recovery: &RecoveryInfo) -> Self { + match recovery.kind() { + crate::RecoveryKind::Retry | crate::RecoveryKind::Unavailable => Self::Failure, + _ => Self::Success, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_execution_result_from_recovery() { + assert_eq!(ExecutionResult::from_recovery(&RecoveryInfo::retry()), ExecutionResult::Failure); + assert_eq!( + ExecutionResult::from_recovery(&RecoveryInfo::unavailable()), + ExecutionResult::Failure + ); + assert_eq!(ExecutionResult::from_recovery(&RecoveryInfo::never()), ExecutionResult::Success); + assert_eq!(ExecutionResult::from_recovery(&RecoveryInfo::unknown()), ExecutionResult::Success); + } + + #[test] + fn test_execution_result_as_str() { + assert_eq!(ExecutionResult::Success.as_str(), "success"); + assert_eq!(ExecutionResult::Failure.as_str(), "failure"); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/half_open_mode.rs b/crates/seatbelt/src/circuit_breaker/half_open_mode.rs new file mode 100644 index 00000000..13de619b --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/half_open_mode.rs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +use crate::circuit_breaker::constants::MIN_SAMPLING_DURATION; +use crate::circuit_breaker::engine::probing::ProbesOptions; + +/// Defines the behavior of the circuit breaker when transitioning from half-open to closed state. +/// +/// The half-open state is a transitional phase where the circuit breaker allows a limited number of +/// requests to pass through to test if the underlying service has recovered. The chosen mode +/// determines how aggressively the circuit breaker probes the service during this phase. +/// +/// Currently, two modes are supported: +/// +/// - [`HalfOpenMode::quick`]: Allows a single probe request to determine if the service has recovered. +/// - [`HalfOpenMode::reliable`]: Gradually increases the percentage of probing requests over multiple stages (default). +#[derive(Debug, Clone, PartialEq)] +pub struct HalfOpenMode { + inner: Mode, +} + +impl HalfOpenMode { + /// Allow quick recovery from half-open state with a single probe. + /// + /// This approach is less reliable compared to the [`HalfOpenMode::reliable`] mode, but + /// can close the circuit faster. 
+ /// + /// The downside of this approach is that it relies on a single execution to determine + /// the health of the service. If that execution happens to succeed by chance, the circuit + /// closes and later requests may fail again, leading to instability and re-opening the circuit + /// again. + #[must_use] + pub fn quick() -> Self { + Self { inner: Mode::Quick } + } + + /// Gradually increase the percentage of probing requests over multiple stages. + /// + /// This approach allows more requests to pass through in a controlled manner, + /// increasing the probing rate over time. This can help more reliably evaluate the + /// health of the underlying service over time rather than relying on a single execution. + /// + /// The pre-configured ratios for each probing stage are: + /// `0.1%, 1%, 5%, 10%, 25%, 50%` + /// + /// Each probing stage advances after the stage duration has elapsed, and the health + /// metrics indicate that the failure rate is below the configured threshold. If any probing stage + /// fails, the circuit reopens immediately and the cycle starts over. + /// + /// # Arguments + /// + /// - `stage_duration` - Optional custom stage duration for each probing stage. If not provided, + /// the value of [`break_duration`][crate::circuit_breaker::CircuitLayer::break_duration] is used. The provided stage + /// duration is clamped to a minimum of 1 second. + #[must_use] + pub fn reliable(stage_duration: impl Into>) -> Self { + Self { + inner: Mode::Reliable(stage_duration.into().map(|d| d.max(MIN_SAMPLING_DURATION))), + } + } + + pub(super) fn to_options(&self, default_stage_duration: Duration, failure_threshold: f32) -> ProbesOptions { + match self.inner { + Mode::Quick => ProbesOptions::quick(default_stage_duration), + Mode::Reliable(duration) => ProbesOptions::reliable(duration.unwrap_or(default_stage_duration), failure_threshold), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +enum Mode { + Quick, + Reliable(Option), +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit_breaker::engine::probing::ProbeOptions; + + #[test] + fn quick_mode_creates_single_probe() { + let mode = HalfOpenMode::quick(); + let options = mode.to_options(Duration::from_secs(30), 0.1); + let probes: Vec<_> = options.probes().collect(); + + assert_eq!(probes.len(), 1); + assert!(matches!(probes[0], ProbeOptions::SingleProbe { .. })); + } + + #[test] + fn quick_mode_uses_default_duration() { + let mode = HalfOpenMode::quick(); + let default = Duration::from_secs(30); + let options = mode.to_options(default, 0.1); + let probes: Vec<_> = options.probes().collect(); + + assert!(matches!( + &probes[0], + ProbeOptions::SingleProbe { cooldown } if *cooldown == default + )); + } + + #[test] + fn reliable_mode_creates_seven_probes() { + let mode = HalfOpenMode::reliable(None); + let options = mode.to_options(Duration::from_secs(30), 0.1); + + assert_eq!(options.probes().len(), 7); + } + + #[test] + fn reliable_mode_with_custom_duration() { + let custom = Duration::from_secs(45); + let mode = HalfOpenMode::reliable(custom); + let options = mode.to_options(Duration::from_secs(30), 0.1); + let probes: Vec<_> = options.probes().collect(); + + assert!(matches!( + &probes[0], + ProbeOptions::SingleProbe { cooldown } if *cooldown == custom + )); + + #[expect(clippy::float_cmp, reason = "Test")] + for probe in &probes[1..] 
{ + if let ProbeOptions::HealthProbe(h) = probe { + assert_eq!(h.stage_duration(), custom); + + assert_eq!(h.failure_threshold(), 0.1); + } + } + } + + #[test] + fn reliable_mode_with_default_duration() { + let mode = HalfOpenMode::reliable(None); + let default = Duration::from_secs(60); + let options = mode.to_options(default, 0.1); + let probes: Vec<_> = options.probes().collect(); + + assert!(matches!( + &probes[0], + ProbeOptions::SingleProbe { cooldown } if *cooldown == default + )); + + for probe in &probes[1..] { + if let ProbeOptions::HealthProbe(h) = probe { + assert_eq!(h.stage_duration(), default); + } + } + } + + #[test] + fn reliable_mode_accepts_various_inputs() { + let mode1 = HalfOpenMode::reliable(Duration::from_secs(10)); + let mode2 = HalfOpenMode::reliable(Some(Duration::from_secs(10))); + let mode3 = HalfOpenMode::reliable(None); + + assert!(matches!(mode1.inner, Mode::Reliable(Some(_)))); + assert!(matches!(mode2.inner, Mode::Reliable(Some(_)))); + assert!(matches!(mode3.inner, Mode::Reliable(None))); + } + + #[test] + fn reliable_mode_clamps_min_duration() { + let mode = HalfOpenMode::reliable(Duration::from_millis(500)); + + assert!(matches!(mode.inner, Mode::Reliable(duration) if duration == Some(Duration::from_secs(1)))); + } +} diff --git a/crates/seatbelt/src/circuit_breaker/health.rs b/crates/seatbelt/src/circuit_breaker/health.rs new file mode 100644 index 00000000..bd0d4984 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/health.rs @@ -0,0 +1,351 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::collections::VecDeque; +use std::time::{Duration, Instant}; + +use super::ExecutionResult; +use crate::circuit_breaker::constants::MIN_SAMPLING_DURATION; + +const WINDOW_COUNT: u32 = 10; + +#[derive(Debug, Copy, Clone, PartialEq)] +pub(crate) enum HealthStatus { + Healthy, + Unhealthy, +} + +/// Aggregated health information that can be used to determine the failure rate and throughput +/// of a service over a recent sampling period. +#[must_use] +#[derive(Debug, Copy, Clone)] +pub(crate) struct HealthInfo { + throughput: u32, + failure_rate: f32, + health_status: HealthStatus, +} + +impl HealthInfo { + pub fn new(successes: u32, failures: u32, failure_threshold: f32, min_throughput: u32) -> Self { + let throughput = successes.saturating_add(failures); + + if throughput == 0 { + return Self { + throughput: 0, + failure_rate: 0.0, + health_status: HealthStatus::Healthy, + }; + } + + #[expect(clippy::cast_possible_truncation, reason = "Acceptable")] + let failure_rate = (f64::from(failures) / f64::from(throughput)) as f32; + + Self { + throughput, + failure_rate, + health_status: if failure_rate >= failure_threshold && throughput >= min_throughput { + HealthStatus::Unhealthy + } else { + HealthStatus::Healthy + }, + } + } + + #[cfg_attr( + not(any(feature = "logs", test)), + expect(dead_code, reason = "trying to avoid dead code here leads to too much conditionals") + )] + pub fn throughput(&self) -> u32 { + self.throughput + } + + #[cfg_attr( + not(any(feature = "logs", test)), + expect(dead_code, reason = "trying to avoid dead code here leads to too much conditionals") + )] + pub fn failure_rate(&self) -> f32 { + self.failure_rate + } + + pub fn status(&self) -> HealthStatus { + self.health_status + } +} + +/// Pre-configured builder that creates `HealthMetrics` instances with consistent settings. 
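The health classification above reduces to one rule: a partition is unhealthy only when the observed failure rate reaches the threshold *and* enough traffic has been seen. A minimal, self-contained sketch of that rule follows (illustrative only; `Health` and `classify` are assumed names, not items from this diff):

```rust
// Illustrative sketch of the health rule described above; not the crate's implementation.
#[derive(Debug, PartialEq)]
enum Health {
    Healthy,
    Unhealthy,
}

/// Unhealthy only if the failure rate meets the threshold AND the sample is large enough.
fn classify(successes: u32, failures: u32, failure_threshold: f32, min_throughput: u32) -> Health {
    let throughput = successes.saturating_add(failures);
    if throughput == 0 {
        return Health::Healthy; // no traffic, nothing to judge
    }
    let failure_rate = (f64::from(failures) / f64::from(throughput)) as f32;
    if failure_rate >= failure_threshold && throughput >= min_throughput {
        Health::Unhealthy
    } else {
        Health::Healthy
    }
}

fn main() {
    // 5 failures out of 10 at a 50% threshold with min throughput 5: unhealthy.
    assert_eq!(classify(5, 5, 0.5, 5), Health::Unhealthy);
    // Same failure rate but below the minimum throughput: still healthy.
    assert_eq!(classify(1, 1, 0.5, 5), Health::Healthy);
    // High failure rate alone is not enough when traffic is too low.
    assert_eq!(classify(0, 3, 0.5, 5), Health::Healthy);
}
```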
+#[derive(Debug, Clone)] +pub(crate) struct HealthMetricsBuilder { + pub(crate) sampling_duration: Duration, + pub(crate) failure_threshold: f32, + pub(crate) min_throughput: u32, +} + +impl HealthMetricsBuilder { + pub fn new(sampling_duration: Duration, failure_threshold: f32, min_throughput: u32) -> Self { + Self { + sampling_duration: sampling_duration.max(MIN_SAMPLING_DURATION), + failure_threshold, + min_throughput, + } + } + + pub fn build(&self) -> HealthMetrics { + HealthMetrics::new(self.sampling_duration, self.failure_threshold, self.min_throughput) + } +} + +/// Tracks execution results over a sliding time window to provide health metrics. +#[derive(Debug)] +pub(crate) struct HealthMetrics { + sampling_duration: Duration, + window_duration: Duration, + windows: VecDeque, + failure_threshold: f32, + min_throughput: u32, +} + +impl HealthMetrics { + fn new(sampling_duration: Duration, failure_threshold: f32, min_throughput: u32) -> Self { + Self { + sampling_duration, + window_duration: sampling_duration / WINDOW_COUNT, + windows: VecDeque::with_capacity(WINDOW_COUNT as usize), + failure_threshold, + min_throughput, + } + } + + pub fn record(&mut self, result: ExecutionResult, now: Instant) { + // Remove old windows + while let Some(front) = self.windows.front() + && now.duration_since(front.started_at) > self.sampling_duration + { + self.windows.pop_front(); + } + + // Get or create the current window + if let Some(back) = self.windows.back_mut() + && now.duration_since(back.started_at) < self.window_duration + { + // Update the existing window + back.update(result); + } else { + // Create a new window + let mut new_window = Window::new(now); + new_window.update(result); + self.windows.push_back(new_window); + } + } + + pub fn health_info(&self) -> HealthInfo { + let mut successes = 0_u32; + let mut failures = 0_u32; + + for w in &self.windows { + successes = successes.saturating_add(w.successes); + failures = failures.saturating_add(w.failures); + } + + HealthInfo::new(successes, failures, self.failure_threshold, self.min_throughput) + } +} + +#[derive(Debug)] +struct Window { + successes: u32, + failures: u32, + started_at: Instant, +} + +impl Window { + fn new(started_at: Instant) -> Self { + Self { + successes: 0, + failures: 0, + started_at, + } + } + + fn update(&mut self, result: ExecutionResult) { + match result { + ExecutionResult::Success => self.successes += 1, + ExecutionResult::Failure => self.failures += 1, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn factory_ok() { + let builder = HealthMetricsBuilder::new(Duration::from_secs(10), 0.5, 5); + let metrics = builder.build(); + + assert_eq!(metrics.sampling_duration, Duration::from_secs(10)); + assert_eq!(metrics.window_duration, Duration::from_secs(1)); + assert_eq!(metrics.failure_threshold, 0.5); + assert_eq!(metrics.min_throughput, 5); + + // small sampling duration is clamped + let builder = HealthMetricsBuilder::new(Duration::from_millis(500), 0.5, 5); + let metrics = builder.build(); + assert_eq!(metrics.sampling_duration, MIN_SAMPLING_DURATION); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn record_when_empty() { + let mut metrics = HealthMetrics::new(Duration::from_secs(10), 0.5, 5); + let start = Instant::now(); + metrics.record(ExecutionResult::Success, start); + let info = metrics.health_info(); + + assert_eq!(info.throughput(), 1); + 
assert_eq!(info.failure_rate(), 0.0); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn create_health_info_healthy_when_not_throughput() { + let metrics = HealthMetrics::new(Duration::from_secs(10), 0.5, 5); + + let info = metrics.health_info(); + + assert_eq!(info.throughput(), 0); + assert_eq!(info.failure_rate(), 0.0); + assert_eq!(info.status(), HealthStatus::Healthy); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn record_twice() { + let mut metrics = HealthMetrics::new(Duration::from_secs(10), 0.5, 2); + let start = Instant::now(); + metrics.record(ExecutionResult::Success, start); + metrics.record(ExecutionResult::Failure, start); + let info = metrics.health_info(); + + assert_eq!(info.throughput(), 2); + assert_eq!(info.failure_rate(), 0.5); + assert_eq!(info.status(), HealthStatus::Unhealthy); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn record_ensure_old_window_discarded() { + let mut metrics = HealthMetrics::new(Duration::from_secs(10), 0.5, 5); + let start = Instant::now(); + metrics.record(ExecutionResult::Success, start); + + // Advance time beyond the sampling duration + let later = start + Duration::from_secs(11); + metrics.record(ExecutionResult::Success, later); + let info = metrics.health_info(); + + assert_eq!(info.throughput(), 1); + assert_eq!(info.failure_rate(), 0.0); + } + + #[test] + fn new_ensure_initialized_properly() { + let metrics = HealthMetrics::new(Duration::from_secs(10), 0.5, 5); + assert_eq!(metrics.sampling_duration, Duration::from_secs(10)); + assert_eq!(metrics.window_duration, Duration::from_secs(1)); + assert!(metrics.windows.is_empty()); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn ensure_multiple_windows_created() { + let mut metrics = HealthMetrics::new(Duration::from_secs(10), 0.5, 5); + let start = Instant::now(); + for i in 0..30 { + let now = start + Duration::from_millis(i * 100); + metrics.record(ExecutionResult::Success, now); + } + + assert_eq!(metrics.windows.len(), 3); + + let first_window = &metrics.windows[0]; + assert_eq!(first_window.successes, 10); + assert_eq!(first_window.failures, 0); + assert_eq!(first_window.started_at, start); + + // discard the first window + let later = start + Duration::from_secs(12); + metrics.record(ExecutionResult::Success, later); + let info = metrics.health_info(); + + assert_eq!(metrics.windows.len(), 2); + assert_eq!(info.throughput(), 11); + assert_eq!(info.failure_rate(), 0.0); + } + + mod health_info_create_tests { + use super::*; + + #[test] + fn zero_throughput_is_healthy() { + let info = HealthInfo::new(0, 0, 0.5, 10); + assert_eq!( + (info.throughput(), info.failure_rate(), info.status()), + (0, 0.0, HealthStatus::Healthy) + ); + } + + #[test] + fn only_successes_is_healthy() { + let info = HealthInfo::new(10, 0, 0.5, 5); + assert_eq!( + (info.throughput(), info.failure_rate(), info.status()), + (10, 0.0, HealthStatus::Healthy) + ); + } + + #[test] + fn only_failures_above_threshold_is_unhealthy() { + let info = HealthInfo::new(0, 10, 0.5, 5); + assert_eq!( + (info.throughput(), info.failure_rate(), info.status()), + (10, 1.0, HealthStatus::Unhealthy) + ); + } + + #[test] + fn failure_threshold_boundaries() { + // At threshold + let info = HealthInfo::new(5, 5, 0.5, 5); + assert_eq!(info.status(), HealthStatus::Unhealthy); + + // Below threshold + let info = HealthInfo::new(6, 4, 0.5, 5); + assert_eq!(info.status(), HealthStatus::Healthy); + } + + #[test] + fn min_throughput_boundaries() { + // Below min 
throughput - healthy despite high failure rate + let info = HealthInfo::new(0, 3, 0.5, 5); + assert_eq!(info.status(), HealthStatus::Healthy); + + // At min throughput - unhealthy with high failure rate + let info = HealthInfo::new(1, 4, 0.5, 5); + assert_eq!(info.status(), HealthStatus::Unhealthy); + } + + #[test] + fn edge_cases() { + // Saturating add + let info = HealthInfo::new(u32::MAX, 1, 0.5, 5); + assert_eq!(info.throughput(), u32::MAX); + + // Zero threshold + let info = HealthInfo::new(1, 1, 0.0, 0); + assert_eq!(info.status(), HealthStatus::Unhealthy); + } + } +} diff --git a/crates/seatbelt/src/circuit_breaker/layer.rs b/crates/seatbelt/src/circuit_breaker/layer.rs new file mode 100644 index 00000000..e82a96b0 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/layer.rs @@ -0,0 +1,720 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::borrow::Cow; +use std::marker::PhantomData; +use std::time::Duration; + +use super::constants::{DEFAULT_BREAK_DURATION, DEFAULT_FAILURE_THRESHOLD, DEFAULT_MIN_THROUGHPUT, DEFAULT_SAMPLING_DURATION}; +use super::{ + Circuit, Engines, HalfOpenMode, HealthMetricsBuilder, OnClosed, OnClosedArgs, OnOpened, OnOpenedArgs, OnProbing, OnProbingArgs, + PartionKeyProvider, PartitionKey, RejectedInput, RejectedInputArgs, ShouldRecover, +}; +use crate::circuit_breaker::engine::probing::ProbesOptions; +use crate::utils::{EnableIf, TelemetryHelper}; +use crate::{NotSet, Recovery, RecoveryInfo, ResilienceContext, Set}; +use layered::Layer; + +/// Builder for configuring circuit breaker resilience middleware. +/// +/// This type is created by calling [`Circuit::layer`] and uses the +/// type-state pattern to enforce that required properties are configured before the circuit breaker +/// middleware can be built: +/// +/// - [`recovery`][CircuitLayer::recovery]: Required to determine if an output represents a failure +/// - [`rejected_input`][CircuitLayer::rejected_input]: Required to specify the output when the circuit is open and inputs are rejected +/// +/// For comprehensive documentation and examples, see the [`circuit_breaker` module][crate::circuit_breaker] documentation. 
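The "type-state pattern" mentioned above can be made concrete with a small, standalone sketch: the builder carries one marker type per required option, and the final build step only exists once both markers are `Set`. This is a pattern demo under assumed names (`Builder`, `Set`, `NotSet` stand-ins), not the actual `CircuitLayer` definition:

```rust
// Stand-alone illustration of the type-state builder idea described above.
use std::marker::PhantomData;

struct NotSet;
struct Set;

/// Each marker type records whether a required option has been provided.
struct Builder<Recovery, Rejected> {
    recovery: Option<fn(&str) -> bool>, // stand-in for the recovery classifier
    rejected: Option<fn() -> String>,   // stand-in for the rejected-input producer
    _state: PhantomData<(Recovery, Rejected)>,
}

impl Builder<NotSet, NotSet> {
    fn new() -> Self {
        Builder { recovery: None, rejected: None, _state: PhantomData }
    }
}

impl<R> Builder<NotSet, R> {
    /// Setting the classifier moves the first marker from `NotSet` to `Set`.
    fn recovery(self, f: fn(&str) -> bool) -> Builder<Set, R> {
        Builder { recovery: Some(f), rejected: self.rejected, _state: PhantomData }
    }
}

impl<R> Builder<R, NotSet> {
    /// Setting the rejected-input producer moves the second marker to `Set`.
    fn rejected_input(self, f: fn() -> String) -> Builder<R, Set> {
        Builder { recovery: self.recovery, rejected: Some(f), _state: PhantomData }
    }
}

impl Builder<Set, Set> {
    /// Only available once both required options are set, so forgetting one
    /// is a compile-time error rather than a runtime panic.
    fn build(self) -> (fn(&str) -> bool, fn() -> String) {
        (self.recovery.unwrap(), self.rejected.unwrap())
    }
}

fn main() {
    let (is_failure, rejected) = Builder::new()
        .recovery(|out| out.contains("error"))
        .rejected_input(|| "service unavailable".to_string())
        .build();

    assert!(is_failure("error: timeout"));
    assert_eq!(rejected(), "service unavailable");
}
```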
+/// +/// # Type State +/// +/// - `S1`: Tracks whether [`recovery`][CircuitLayer::recovery] has been set +/// - `S2`: Tracks whether [`rejected_input`][CircuitLayer::rejected_input] has been set +#[derive(Debug)] +pub struct CircuitLayer { + context: ResilienceContext, + recovery: Option>, + rejected_input: Option>, + on_opened: Option>, + on_closed: Option>, + on_probing: Option>, + partition_key: Option>, + enable_if: EnableIf, + telemetry: TelemetryHelper, + failure_threshold: f32, + min_throughput: u32, + sampling_duration: Duration, + break_duration: Duration, + half_open_mode: HalfOpenMode, + _state: PhantomData Out>, +} + +impl CircuitLayer { + #[must_use] + pub(crate) fn new(name: Cow<'static, str>, context: &ResilienceContext) -> Self { + Self { + context: context.clone(), + recovery: None, + rejected_input: None, + on_opened: None, + on_closed: None, + on_probing: None, + partition_key: None, + enable_if: EnableIf::always(), + telemetry: context.create_telemetry(name), + failure_threshold: DEFAULT_FAILURE_THRESHOLD, + min_throughput: DEFAULT_MIN_THROUGHPUT, + sampling_duration: DEFAULT_SAMPLING_DURATION, + break_duration: DEFAULT_BREAK_DURATION, + half_open_mode: HalfOpenMode::reliable(None), + _state: PhantomData, + } + } +} + +impl CircuitLayer, S1, S2> { + /// Sets the error to return when the circuit breaker is open for Result-returning services. + /// + /// When the circuit is open, requests are immediately rejected and this function + /// is called to generate the error that should be returned to the caller. + /// The error is automatically wrapped in a `Result::Err`. + /// + /// This is a convenience method for Result-returning services that allows you to + /// provide a meaningful error when the circuit breaker prevents a request from + /// reaching the underlying service. + /// + /// # Arguments + /// + /// * `error_producer` - Function that generates the error to return when the circuit is open + #[must_use] + pub fn rejected_input_error( + self, + error_producer: impl Fn(In, RejectedInputArgs) -> E + Send + Sync + 'static, + ) -> CircuitLayer, S1, Set> { + self.into_state::() + .rejected_input(move |input, args| Err(error_producer(input, args))) + .into_state() + } +} + +impl CircuitLayer { + /// Sets the recovery classification function. + /// + /// This function determines whether a specific output represents a failure + /// by examining the output and returning a [`RecoveryInfo`] classification. + /// + /// The function receives the output and [`RecoveryArgs`][crate::circuit_breaker::RecoveryArgs] + /// with context about the circuit breaker state. + /// + /// # Arguments + /// + /// * `recover_fn` - Function that takes a reference to the output and + /// [`RecoveryArgs`][crate::circuit_breaker::RecoveryArgs] containing circuit breaker context, + /// and returns a [`RecoveryInfo`] decision + #[must_use] + pub fn recovery_with( + mut self, + recover_fn: impl Fn(&Out, crate::circuit_breaker::RecoveryArgs) -> RecoveryInfo + Send + Sync + 'static, + ) -> CircuitLayer { + self.recovery = Some(ShouldRecover::new(recover_fn)); + self.into_state::() + } + + /// Automatically sets the recovery classification function for types that implement [`Recovery`]. + /// + /// This is a convenience method that uses the [`Recovery`] trait to determine + /// whether an output represents a failure. For types that implement [`Recovery`], + /// this provides a simple way to enable circuit breaker behavior without manually + /// implementing a recovery classification function. 
+ /// + /// This is equivalent to calling [`recovery_with`][CircuitLayer::recovery_with] with + /// `|output, _args| output.recovery()`. + /// + /// # Type Requirements + /// + /// This method is only available when the output type `Out` implements [`Recovery`]. + #[must_use] + pub fn recovery(self) -> CircuitLayer + where + Out: Recovery, + { + self.recovery_with(|out, _args| out.recovery()) + } + + /// Sets the output to return when the circuit breaker is open. + /// + /// When the circuit is open, requests are immediately rejected and this function + /// is called to generate the output that should be returned to the caller. + /// + /// This allows you to provide a meaningful error message or fallback value + /// when the circuit breaker prevents a request from reaching the underlying service. + /// + /// # Arguments + /// + /// * `rejected_fn` - Function that generates the output to return when the circuit is open + #[must_use] + pub fn rejected_input( + mut self, + rejected_fn: impl Fn(In, RejectedInputArgs) -> Out + Send + Sync + 'static, + ) -> CircuitLayer { + self.rejected_input = Some(RejectedInput::new(rejected_fn)); + self.into_state::() + } + + /// Sets the failure threshold for the circuit breaker. + /// + /// The circuit breaker will open when the failure rate exceeds this threshold + /// over the sampling duration. The value should be between 0.0 and 1.0, where + /// 0.1 represents a `10%` failure threshold. Values greater than 1.0 will be clamped to 1.0. + /// + /// **Default**: 0.1 (`10%` failure rate) + /// + /// # Arguments + /// + /// * `threshold` - The failure threshold (0.0 to 1.0, values `>` 1.0 are clamped) + #[must_use] + pub fn failure_threshold(mut self, threshold: f32) -> Self { + self.failure_threshold = threshold.min(1.0); + self + } + + /// Sets the minimum throughput required before the circuit breaker can open. + /// + /// The circuit breaker will only consider opening if at least this many requests + /// have been processed during the sampling duration. This prevents the circuit + /// from opening due to a small number of failures when overall traffic is low. + /// + /// **Default**: 100 requests + /// + /// # Arguments + /// + /// * `throughput` - The minimum number of requests required + #[must_use] + pub fn min_throughput(mut self, throughput: u32) -> Self { + self.min_throughput = throughput; + self + } + + /// Sets the sampling duration for calculating failure rates. + /// + /// The circuit breaker calculates failure rates over this time window. + /// Only requests within this duration are considered when determining + /// whether the failure rate exceeds the threshold. + /// + /// **Default**: 30 seconds + /// + /// > **Note**: The sampling duration cannot be lower than 1 second. If value is less + /// > than 1 second, it will be clamped to 1 second. + /// + /// # Arguments + /// + /// * `duration` - The time window for sampling failures + #[must_use] + pub fn sampling_duration(mut self, duration: Duration) -> Self { + self.sampling_duration = duration; + self + } + + /// Sets the break duration for how long the circuit stays open. + /// + /// When the circuit breaker opens due to failures, it will remain open + /// for this duration before transitioning to half-open state to test + /// if the underlying service has recovered. 
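As a rough mental model of the open to half-open hand-off described here (purely illustrative; the crate's actual state handling lives in its engine and is not shown in this form):

```rust
// Minimal sketch of the open -> half-open timing described above.
use std::time::{Duration, Instant};

#[derive(Debug, PartialEq)]
enum State {
    Closed,
    Open { since: Instant },
    HalfOpen,
}

/// Once the break duration has elapsed, an open circuit starts probing (half-open).
fn on_request(state: State, break_duration: Duration, now: Instant) -> State {
    match state {
        State::Open { since } if now.duration_since(since) >= break_duration => State::HalfOpen,
        other => other,
    }
}

fn main() {
    let opened_at = Instant::now();
    let break_duration = Duration::from_secs(5);

    // Still within the break window: stay open and keep rejecting.
    let s = on_request(State::Open { since: opened_at }, break_duration, opened_at + Duration::from_secs(2));
    assert_eq!(s, State::Open { since: opened_at });

    // Break window elapsed: move to half-open and let probes through.
    let s = on_request(State::Open { since: opened_at }, break_duration, opened_at + Duration::from_secs(6));
    assert_eq!(s, State::HalfOpen);
}
```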
+ /// + /// **Default**: 5 seconds + /// + /// # Arguments + /// + /// * `duration` - How long the circuit stays open after breaking + #[must_use] + pub fn break_duration(mut self, duration: Duration) -> Self { + self.break_duration = duration; + self + } + + /// Sets the callback to be invoked when the circuit breaker opens. + /// + /// This callback is called whenever the circuit breaker transitions from + /// closed to open state due to exceeding the failure threshold. + /// + /// **Default**: No callback + /// + /// # Arguments + /// + /// * `callback` - Function that takes a reference to the output and + /// [`OnOpenedArgs`] containing circuit breaker context + #[must_use] + pub fn on_opened(mut self, callback: impl Fn(&Out, OnOpenedArgs) + Send + Sync + 'static) -> Self { + self.on_opened = Some(OnOpened::new(callback)); + self + } + + /// Sets the callback to be invoked when the circuit breaker closes. + /// + /// This callback is called whenever the circuit breaker transitions from + /// half-open state to closed state after successful recovery. + /// + /// **Default**: No callback + /// + /// # Arguments + /// + /// * `callback` - Function that takes a reference to the output and + /// [`OnClosedArgs`] containing circuit breaker context + #[must_use] + pub fn on_closed(mut self, callback: impl Fn(&Out, OnClosedArgs) + Send + Sync + 'static) -> Self { + self.on_closed = Some(OnClosed::new(callback)); + self + } + + /// Sets the callback to be invoked when the circuit breaker is probing. + /// + /// This callback is called when the circuit breaker is in half-open state + /// and is testing whether the underlying service has recovered. + /// + /// **Default**: No callback + /// + /// # Arguments + /// + /// * `callback` - Function that takes a mutable reference to the input and + /// [`OnProbingArgs`] containing circuit breaker context + #[must_use] + pub fn on_probing(mut self, callback: impl Fn(&mut In, OnProbingArgs) + Send + Sync + 'static) -> Self { + self.on_probing = Some(OnProbing::new(callback)); + self + } + + /// Sets the partition key provider function. + /// + /// The partitioning key is used to maintain separate circuit breaker states + /// for different inputs. The idea is to extract the partition key from the input + /// so that requests with the same key share the same circuit breaker state. + /// + /// **Default**: Single global circuit (no partitioning) - all requests share the same circuit breaker state + /// + /// If no partition key provider is set, a default key is used, meaning all requests + /// share the same circuit breaker state. + /// + /// The typical scenario is HTTP request where the partition key could be derived from + /// the combination of scheme and authority (host and port). This allows separate + /// circuit breaker states for different backend services. + /// + /// # Arguments + /// + /// * `key_provider` - Function that takes a reference to the input and returns + /// a [`PartitionKey`] identifying the partition for circuit breaker state + /// + /// # Example + /// + /// ```rust + /// # use seatbelt::circuit_breaker::{CircuitLayer, PartitionKey}; + /// // Example HTTP request structure + /// struct HttpRequest { + /// scheme: String, + /// host: String, + /// port: u16, + /// path: String, + /// } + /// # fn example(circuit_breaker_layer: CircuitLayer) { + /// // Configure circuit breaker with a partition key based on a scheme, host and port. 
+ /// let layer = circuit_breaker_layer.partition_key(|request: &HttpRequest| { + /// let partition = format!("{}://{}:{}", request.scheme, request.host, request.port); + /// PartitionKey::from(partition) + /// }); + /// + /// // This ensures that: + /// // - Requests to https://api.service1.com share one circuit breaker state + /// // - Requests to https://api.service2.com:8080 share another circuit breaker state + /// // - Requests to http://localhost:3000 share yet another circuit breaker state + /// # } + /// ``` + /// + /// # Telemetry + /// + /// The values used to create partition keys are included in telemetry data (logs and metrics) + /// for observability purposes. **Important**: Ensure that the values from which partition keys + /// are created do not contain any sensitive data such as authentication tokens, personal + /// identifiable information (PII), or other confidential data. + #[must_use] + pub fn partition_key(mut self, key_provider: impl Fn(&In) -> PartitionKey + Send + Sync + 'static) -> Self { + self.partition_key = Some(PartionKeyProvider::new(key_provider)); + self + } + + /// Sets the half-open mode for the circuit breaker. + /// + /// This determines how the circuit breaker behaves when transitioning from half-open + /// to a closed state. + /// + /// **Default**: [`HalfOpenMode::reliable`] + #[must_use] + pub fn half_open_mode(mut self, mode: HalfOpenMode) -> Self { + self.half_open_mode = mode; + self + } + + /// Optionally enables the circuit breaker middleware based on a condition. + /// + /// When disabled, requests pass through without circuit breaker protection. + /// This call replaces any previous condition. + /// + /// **Default**: Always enabled + /// + /// # Arguments + /// + /// * `is_enabled` - Function that takes a reference to the input and returns + /// `true` if circuit breaker protection should be enabled for this request + #[must_use] + pub fn enable_if(mut self, is_enabled: impl Fn(&In) -> bool + Send + Sync + 'static) -> Self { + self.enable_if = EnableIf::new(is_enabled); + self + } + + /// Enables the circuit breaker middleware unconditionally. + /// + /// All requests will have circuit breaker protection applied. + /// This call replaces any previous condition. + /// + /// **Note**: This is the default behavior - circuit breaker is enabled by default. + #[must_use] + pub fn enable_always(mut self) -> Self { + self.enable_if = EnableIf::always(); + self + } + + /// Disables the circuit breaker middleware completely. + /// + /// All requests will pass through without circuit breaker protection. + /// This call replaces any previous condition. + /// + /// **Note**: This overrides the default enabled behavior. 
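A small standalone sketch of the partitioning idea from the example above: derive one stable key per downstream endpoint so that each endpoint gets its own breaker state, while different paths on the same endpoint share one. The `HttpRequest` shape and `partition_for` helper below are illustrative assumptions, not crate items:

```rust
// Illustrative endpoint-based partitioning, as described above.
use std::collections::HashMap;

struct HttpRequest {
    scheme: String,
    host: String,
    port: u16,
    path: String,
}

/// One key per scheme://host:port; the path is deliberately excluded.
fn partition_for(request: &HttpRequest) -> String {
    format!("{}://{}:{}", request.scheme, request.host, request.port)
}

fn main() {
    // Stand-in for "one circuit state per partition".
    let mut failures_per_partition: HashMap<String, u32> = HashMap::new();

    let a = HttpRequest { scheme: "https".into(), host: "api.service1.com".into(), port: 443, path: "/users".into() };
    let b = HttpRequest { scheme: "https".into(), host: "api.service1.com".into(), port: 443, path: "/orders".into() };
    let c = HttpRequest { scheme: "http".into(), host: "localhost".into(), port: 3000, path: "/health".into() };

    for req in [&a, &b, &c] {
        *failures_per_partition.entry(partition_for(req)).or_insert(0) += 1;
    }

    // a and b share a partition; c is tracked independently.
    assert_eq!(failures_per_partition.len(), 2);
    assert_eq!(failures_per_partition["https://api.service1.com:443"], 2);
    assert_ne!(a.path, b.path); // different paths, same partition
}
```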
+ #[must_use] + pub fn disable(mut self) -> Self { + self.enable_if = EnableIf::never(); + self + } +} + +impl Layer for CircuitLayer { + type Service = Circuit; + + fn layer(&self, inner: S) -> Self::Service { + Circuit { + inner, + clock: self.context.get_clock().clone(), + recovery: self.recovery.clone().expect("recovery must be set in Ready state"), + rejected_input: self.rejected_input.clone().expect("rejected_input must be set in Ready state"), + enable_if: self.enable_if.clone(), + engines: self.engines(), + on_opened: self.on_opened.clone(), + on_closed: self.on_closed.clone(), + on_probing: self.on_probing.clone(), + partition_key: self.partition_key.clone(), + } + } +} + +impl CircuitLayer { + fn probes_options(&self) -> ProbesOptions { + self.half_open_mode + // we will use break duration as the sampling duration for probes + .to_options(self.break_duration, self.failure_threshold) + } + + fn engines(&self) -> Engines { + Engines::new( + super::engine::EngineOptions { + break_duration: self.break_duration, + health_metrics_builder: HealthMetricsBuilder::new(self.sampling_duration, self.failure_threshold, self.min_throughput), + probes: self.probes_options(), + }, + self.context.get_clock().clone(), + self.telemetry.clone(), + ) + } + + fn into_state(self) -> CircuitLayer { + CircuitLayer { + context: self.context, + recovery: self.recovery, + rejected_input: self.rejected_input, + on_opened: self.on_opened, + on_closed: self.on_closed, + on_probing: self.on_probing, + partition_key: self.partition_key, + enable_if: self.enable_if, + telemetry: self.telemetry.clone(), + failure_threshold: self.failure_threshold, + min_throughput: self.min_throughput, + sampling_duration: self.sampling_duration, + break_duration: self.break_duration, + half_open_mode: self.half_open_mode, + _state: PhantomData, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::fmt::Debug; + + use layered::Execute; + use tick::Clock; + + use super::*; + use crate::circuit_breaker::RecoveryArgs; + use crate::circuit_breaker::engine::probing::ProbeOptions; + use crate::testing::RecoverableType; + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn new_creates_correct_initial_state() { + let context = create_test_context(); + let layer: CircuitLayer<_, _, NotSet, NotSet> = CircuitLayer::new("test_breaker".into(), &context); + + assert!(layer.recovery.is_none()); + assert!(layer.rejected_input.is_none()); + assert_eq!(layer.telemetry.strategy_name.as_ref(), "test_breaker"); + assert!(layer.enable_if.call(&"test_input".to_string())); + assert_eq!(layer.failure_threshold, 0.1); + assert_eq!(layer.min_throughput, 100); + assert_eq!(layer.sampling_duration, Duration::from_secs(30)); + } + + #[test] + fn recovery_sets_correctly() { + let context = create_test_context(); + let layer = CircuitLayer::new("test".into(), &context); + + let layer: CircuitLayer<_, _, Set, NotSet> = layer.recovery_with(|output, _args| { + if output.contains("error") { + RecoveryInfo::retry() + } else { + RecoveryInfo::never() + } + }); + + let result = layer.recovery.as_ref().unwrap().call( + &"error message".to_string(), + RecoveryArgs { + partition_key: &PartitionKey::default(), + clock: &Clock::new_frozen(), + }, + ); + assert_eq!(result, RecoveryInfo::retry()); + + let result = layer.recovery.as_ref().unwrap().call( + &"success".to_string(), + RecoveryArgs { + partition_key: &PartitionKey::default(), + clock: &Clock::new_frozen(), + }, + ); + assert_eq!(result, RecoveryInfo::never()); + 
} + + #[test] + fn recovery_auto_sets_correctly() { + let context = ResilienceContext::::new(Clock::new_frozen()); + let layer = CircuitLayer::new("test".into(), &context); + + let layer: CircuitLayer<_, _, Set, NotSet> = layer.recovery(); + + let result = layer.recovery.as_ref().unwrap().call( + &RecoverableType::from(RecoveryInfo::retry()), + RecoveryArgs { + partition_key: &PartitionKey::default(), + clock: &Clock::new_frozen(), + }, + ); + assert_eq!(result, RecoveryInfo::retry()); + + let result = layer.recovery.as_ref().unwrap().call( + &RecoverableType::from(RecoveryInfo::never()), + RecoveryArgs { + partition_key: &PartitionKey::default(), + clock: &Clock::new_frozen(), + }, + ); + assert_eq!(result, RecoveryInfo::never()); + } + + #[test] + fn rejected_input_sets_correctly() { + let context = create_test_context(); + let layer = CircuitLayer::new("test".into(), &context); + + let layer: CircuitLayer<_, _, NotSet, Set> = layer.rejected_input(|_, _| "rejected".to_string()); + + let result = layer.rejected_input.as_ref().unwrap().call( + "test".to_string(), + RejectedInputArgs { + partition_key: &PartitionKey::default(), + }, + ); + assert_eq!(result, "rejected"); + } + + #[test] + fn rejected_input_error_wraps_in_err() { + let context: ResilienceContext> = ResilienceContext::new(Clock::new_frozen()); + let layer = CircuitLayer::new("test".into(), &context); + + let layer: CircuitLayer<_, _, NotSet, Set> = layer.rejected_input_error(|input, _| format!("rejected: {input}")); + + let result = layer.rejected_input.as_ref().unwrap().call( + "test_input".to_string(), + RejectedInputArgs { + partition_key: &PartitionKey::default(), + }, + ); + assert_eq!(result, Err("rejected: test_input".to_string())); + } + + #[test] + fn enable_disable_conditions_work() { + let layer = create_ready_layer().enable_if(|input| input.contains("enable")); + + assert!(layer.enable_if.call(&"enable_test".to_string())); + assert!(!layer.enable_if.call(&"disable_test".to_string())); + + let layer = layer.disable(); + assert!(!layer.enable_if.call(&"anything".to_string())); + + let layer = layer.enable_always(); + assert!(layer.enable_if.call(&"anything".to_string())); + } + + #[test] + fn layer_builds_service_when_ready() { + let layer = create_ready_layer(); + let _service = layer.layer(Execute::new(|input: String| async move { input })); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn failure_threshold_sets_correctly() { + let layer = create_ready_layer(); + + // Test setting a valid threshold + let layer = layer.failure_threshold(0.25); + assert_eq!(layer.failure_threshold, 0.25); + + // Test clamping values greater than 1.0 + let layer = layer.failure_threshold(1.5); + assert_eq!(layer.failure_threshold, 1.0); + + // Test edge cases + let layer = layer.failure_threshold(0.0); + assert_eq!(layer.failure_threshold, 0.0); + + let layer = layer.failure_threshold(1.0); + assert_eq!(layer.failure_threshold, 1.0); + } + + #[test] + fn min_throughput_sets_correctly() { + let layer = create_ready_layer(); + + // Test setting different throughput values + let layer = layer.min_throughput(50); + assert_eq!(layer.min_throughput, 50); + + let layer = layer.min_throughput(1000); + assert_eq!(layer.min_throughput, 1000); + + let layer = layer.min_throughput(0); + assert_eq!(layer.min_throughput, 0); + } + + #[test] + fn sampling_duration_sets_correctly() { + let layer = create_ready_layer(); + + // Test setting different durations + let layer = layer.sampling_duration(Duration::from_secs(10)); + 
assert_eq!(layer.sampling_duration, Duration::from_secs(10)); + + let layer = layer.sampling_duration(Duration::from_secs(60)); + assert_eq!(layer.sampling_duration, Duration::from_secs(60)); + + let layer = layer.sampling_duration(Duration::from_millis(500)); + assert_eq!(layer.sampling_duration, Duration::from_millis(500)); + } + + #[test] + fn break_duration_sets_correctly() { + let layer = create_ready_layer(); + + // Test setting different break durations + let layer = layer.break_duration(Duration::from_secs(5)); + assert_eq!(layer.break_duration, Duration::from_secs(5)); + + let layer = layer.break_duration(Duration::from_secs(120)); + assert_eq!(layer.break_duration, Duration::from_secs(120)); + + let layer = layer.break_duration(Duration::from_millis(2000)); + assert_eq!(layer.break_duration, Duration::from_millis(2000)); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + fn default_values_are_correct() { + let context = create_test_context(); + let layer = CircuitLayer::new("test".into(), &context); + + assert_eq!(layer.failure_threshold, DEFAULT_FAILURE_THRESHOLD); + assert_eq!(layer.min_throughput, DEFAULT_MIN_THROUGHPUT); + assert_eq!(layer.sampling_duration, DEFAULT_SAMPLING_DURATION); + assert_eq!(layer.break_duration, DEFAULT_BREAK_DURATION); + assert_eq!(layer.half_open_mode, HalfOpenMode::reliable(None)); + } + + #[test] + #[expect(clippy::float_cmp, reason = "Test")] + pub fn half_open_mode_ok() { + let layer = create_ready_layer().half_open_mode(HalfOpenMode::quick()); + assert_eq!(layer.half_open_mode, HalfOpenMode::quick()); + + let probes = layer + .break_duration(Duration::from_secs(234)) + .failure_threshold(0.52) + .half_open_mode(HalfOpenMode::reliable(None)) + .probes_options(); + + // access the last probe which should be the health probe + let probe = probes.probes().last().unwrap(); + + match probe { + ProbeOptions::HealthProbe(health_probe) => { + assert_eq!(health_probe.stage_duration(), Duration::from_secs(234)); + assert_eq!(health_probe.failure_threshold(), 0.52); + } + ProbeOptions::SingleProbe { .. } => panic!("Expected HealthProbe"), + } + } + + #[test] + fn static_assertions() { + static_assertions::assert_impl_all!(CircuitLayer: Layer); + static_assertions::assert_not_impl_all!(CircuitLayer: Layer); + static_assertions::assert_not_impl_all!(CircuitLayer: Layer); + static_assertions::assert_impl_all!(CircuitLayer: Debug); + } + + fn create_test_context() -> ResilienceContext { + ResilienceContext::new(Clock::new_frozen()).name("test_pipeline") + } + + fn create_ready_layer() -> CircuitLayer { + CircuitLayer::new("test".into(), &create_test_context()) + .recovery_with(|output, _args| { + if output.contains("error") { + RecoveryInfo::retry() + } else { + RecoveryInfo::never() + } + }) + .rejected_input(|_, _| "circuit is open".to_string()) + } +} diff --git a/crates/seatbelt/src/circuit_breaker/mod.rs b/crates/seatbelt/src/circuit_breaker/mod.rs new file mode 100644 index 00000000..9677b116 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/mod.rs @@ -0,0 +1,315 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Circuit breaker resilience middleware for preventing cascading failures. +//! +//! This module provides automatic circuit breaking capabilities with configurable failure +//! thresholds, break duration, and comprehensive telemetry. The primary types are: +//! +//! - [`Circuit`] is the middleware that wraps an inner service and monitors failure rates +//! 
- [`CircuitLayer`] is used to configure and construct the circuit breaker middleware +//! +//! A circuit breaker monitors the success and failure rates of operations and can temporarily +//! block requests when the failure rate exceeds a configured threshold. This prevents cascading failures +//! and gives downstream services time to recover. +//! +//! # Quick Start +//! +//! ```rust +//! # use layered::{Execute, Service, Stack}; +//! # use tick::Clock; +//! # use seatbelt::circuit_breaker::Circuit; +//! # use seatbelt::{RecoveryInfo, ResilienceContext}; +//! # async fn example(clock: Clock) -> Result<(), Box> { +//! let context = ResilienceContext::new(&clock); +//! +//! let stack = ( +//! Circuit::layer("circuit_breaker", &context) +//! // Required: determine if output indicates a failure by using recovery metadata +//! .recovery_with(|result: &Result, _| match result { +//! Ok(_) => RecoveryInfo::never(), +//! Err(_) => RecoveryInfo::retry(), +//! }) +//! // Required: provide output when the input is rejected on an open circuit +//! .rejected_input_error(|input, args| "service unavailable".to_string()), +//! Execute::new(my_operation), +//! ); +//! +//! let service = stack.into_service(); +//! let result = service.execute("input".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # async fn my_operation(input: String) -> Result { Ok(input) } +//! ``` +//! +//! # Configuration +//! +//! The [`CircuitLayer`] uses a type state pattern to enforce that all required properties +//! are configured before the layer can be built. This compile-time safety ensures that you cannot +//! accidentally create a circuit breaker layer without properly specifying recovery logic and +//! rejected input handling. The properties that must be configured are: +//! +//! - [`recovery`][CircuitLayer::recovery]: Detects the recovery classification for output. +//! This is used to determine if an operation succeeded or failed. +//! - [`rejected_input`][CircuitLayer::rejected_input]: Provide the output to return when the +//! circuit is open and execution is being rejected. +//! +//! Each circuit breaker layer requires an identifier for telemetry purposes. This identifier +//! should use `snake_case` naming convention to maintain consistency across the telemetry. +//! +//! # Thread Safety +//! +//! The [`Circuit`] type is thread-safe and implements both `Send` and `Sync` as enforced +//! by the `Service` trait it implements. This allows circuit breaker middleware to be safely +//! shared across multiple threads and used in concurrent environments. +//! +//! # Circuit Breaker States and Transitions +//! +//! The circuit breaker operates in three states: +//! +//! - **Closed**: Normal operation. Requests pass through and failures are tracked. +//! - **Open**: The circuit is broken. Requests are immediately rejected without calling +//! the underlying service. +//! - **Half-Open**: Testing if the service has recovered. A limited number of probing requests are +//! allowed through to assess the health of the underlying service. +//! +//! ```text +//! ┌────────┐ Failure threshold exceeded ┌──────────┐ +//! │ Closed │ ────────────────────────────────────▶│ Open │ +//! └────────┘ └──────────┘ +//! ▲ │ +//! │ │ +//! │ ┌────────────────┐ │ +//! └────────────│ Half-Open │◀──────────────────┘ +//! Probing └────────────────┘ Break duration +//! successful elapsed +//! ``` +//! +//! ## Closed State +//! +//! The circuit starts in the closed state and operates normally: +//! +//! 
- All requests pass through to the underlying service +//! - Failures are tracked and evaluated against the failure threshold +//! - When the failure threshold is exceeded, transitions to **Open** +//! - You can observe transitions into the closed state by providing +//! the [`on_closed`][CircuitLayer::on_closed] callback. +//! +//! ## Open State +//! +//! When the circuit is open: +//! +//! - Requests are immediately rejected with the output provided by [`rejected_input`][CircuitLayer::rejected_input] +//! - No calls are made to the underlying service +//! - After the break duration elapsed, transitions to **Half-Open** +//! - You can observe transitions into the open state by providing +//! the [`on_opened`][CircuitLayer::on_opened] callback. +//! +//! ## Half-Open State +//! +//! The circuit enters a testing phase: +//! +//! - A limited number of probing requests are allowed through +//! - Success rate is carefully monitored +//! - If sufficient successful probing requests occur, transitions back to **Closed** +//! - If failures continue, the circuit stays in the Half-Open state until the underlying service recovers. +//! Half-open state respects the break duration before allowing more probing requests. +//! - You can observe when the circuit is probing in half-open state by providing +//! the [`on_probing`][CircuitLayer::on_probing] callback. +//! - You can configure the probing behavior and the sensitivity of how quickly the circuit +//! closes again after successful probes by using [`half_open_mode`][CircuitLayer::half_open_mode] +//! +//! # Recovery Classification +//! +//! The circuit breaker uses [`RecoveryInfo`][crate::RecoveryInfo] to classify operation results. The following +//! recovery kinds are classified as failures that contribute to tripping the circuit: +//! +//! - [`RecoveryKind::Retry`][crate::RecoveryKind::Retry] +//! - [`RecoveryKind::Unavailable`][crate::RecoveryKind::Unavailable] +//! +//! # Partitioning +//! +//! Circuit breakers can maintain separate circuit states for different logical groups of requests +//! by providing a [`partition_key`][CircuitLayer::partition_key] function. This allows +//! the creation of multiple independent circuits based on the input properties. +//! +//! For example, a typical scenario where partitioning is useful is HTTP request where the partition key +//! is extracted from the request scheme, host, and port. This allows isolation of circuit states +//! for different downstream endpoints. +//! +//! > **Note**: Each unique partition key creates a separate circuit state. Be mindful of memory usage +//! > with high-cardinality partition keys. +//! +//! # Defaults +//! +//! The circuit breaker middleware uses the following default values when optional configuration +//! is not provided: +//! +//! | Parameter | Default Value | Description | Configured By | +//! |-----------|---------------|-------------|---------------| +//! | Failure threshold | `0.1` (10%) | Circuit opens when failure rate exceeds this percentage | [`failure_threshold`][CircuitLayer::failure_threshold] | +//! | Minimum throughput | `100` requests | Minimum request volume required before circuit can open | [`min_throughput`][CircuitLayer::min_throughput] | +//! | Sampling duration | `30` seconds | Time window for calculating failure rates | [`sampling_duration`][CircuitLayer::sampling_duration] | +//! | Break duration | `5` seconds | Duration circuit remains open before testing recovery | [`break_duration`][CircuitLayer::break_duration] | +//! 
| Partitioning | Single global circuit | All requests share the same circuit breaker state | [`partition_key`][CircuitLayer::partition_key] | +//! | Half-open mode | `Reliable` | Gradual recovery with increasing probe percentages | [`half_open_mode`][CircuitLayer::half_open_mode] | +//! | Enable condition | Always enabled | Circuit breaker protection is applied to all requests | [`enable_if`][CircuitLayer::enable_if], [`enable_always`][CircuitLayer::enable_always], [`disable`][CircuitLayer::disable] | +//! +//! These defaults provide a reasonable starting point for most use cases, offering a balance +//! between resilience and responsiveness to service recovery. +//! +//! # Telemetry +//! +//! ## Metrics +//! +//! - **Metric**: `resilience.event` (counter) +//! - **When**: Emitted when circuit state transitions occur and when requests are rejected +//! - **Attributes**: +//! - `resilience.pipeline.name`: Pipeline identifier from [`ResilienceContext::name`][crate::ResilienceContext::name] +//! - `resilience.strategy.name`: Circuit breaker identifier from [`Circuit::layer`] +//! - `resilience.event.name`: One of: +//! - `circuit_opened`: When the circuit transitions to open state due to failure threshold being exceeded +//! - `circuit_closed`: When the circuit transitions to closed state after successful probing +//! - `circuit_rejected`: When a request is rejected due to the circuit being in open state +//! - `circuit_probe`: When a probing request is executed in half-open state +//! - `resilience.circuit_breaker.state`: Current circuit state (`closed`, `open`, or `half_open`) +//! - `resilience.circuit_breaker.probe.result`: Result of probe execution (`success` or `failure`, only present for probe events) +//! +//! Additional structured logging events are emitted with detailed health metrics (failure rate, throughput) for circuit state transitions. +//! +//! # Examples +//! +//! ## Basic Usage +//! +//! This example demonstrates the basic usage of configuring and using circuit breaker middleware. +//! +//! ```rust +//! # use layered::{Execute, Service, Stack}; +//! # use tick::Clock; +//! # use seatbelt::circuit_breaker::Circuit; +//! # use seatbelt::{RecoveryInfo, ResilienceContext}; +//! # async fn example(clock: Clock) -> Result<(), Box> { +//! // Define common options for resilience middleware. The clock is runtime-specific and +//! // must be provided. See its documentation for details. +//! let context = ResilienceContext::new(&clock).name("example"); +//! +//! let stack = ( +//! Circuit::layer("my_breaker", &context) +//! // Required: determine if output indicates failure +//! .recovery_with(|result: &Result, _args| match result { +//! // These are demonstrative, real code will have more meaningful recovery detection +//! Ok(_) => RecoveryInfo::never(), +//! Err(msg) if msg.contains("transient") => RecoveryInfo::retry(), +//! Err(_) => RecoveryInfo::never(), +//! }) +//! // Required: provide output when circuit is open +//! .rejected_input_error(|_, _| "service unavailable".to_string()), +//! Execute::new(execute_unreliable_operation), +//! ); +//! +//! // Build the service +//! let service = stack.into_service(); +//! +//! // Execute the service +//! let result = service.execute("test input".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # async fn execute_unreliable_operation(input: String) -> Result { Ok(input) } +//! ``` +//! +//! ## Advanced Usage +//! +//! 
This example demonstrates advanced usage of the circuit breaker middleware, including custom +//! failure thresholds, sampling duration, break duration, and state change callbacks. +//! +//! ```rust +//! # use std::time::Duration; +//! # use layered::{Execute, Service, Stack}; +//! # use tick::Clock; +//! # use seatbelt::circuit_breaker::{Circuit, PartitionKey, HalfOpenMode}; +//! # use seatbelt::{RecoveryInfo, ResilienceContext}; +//! # async fn example(clock: Clock) -> Result<(), Box> { +//! // Define common options for resilience middleware. +//! let context = ResilienceContext::new(&clock).name("advanced_example"); +//! +//! let stack = ( +//! Circuit::layer("advanced_breaker", &context) +//! // Required: determine if output indicates failure +//! .recovery_with(|result: &Result, _args| match result { +//! Err(msg) if msg.contains("rate_limit") => RecoveryInfo::unavailable(), +//! Err(msg) if msg.contains("timeout") => RecoveryInfo::retry(), +//! Err(msg) if msg.contains("server_error") => RecoveryInfo::retry(), +//! Err(_) => RecoveryInfo::never(), // Client errors don't count as failures +//! Ok(_) => RecoveryInfo::never(), +//! }) +//! // Required: provide output when circuit is open +//! .rejected_input_error(|_input, _args| { +//! "service temporarily unavailable due to exceeding failure threshold".to_string() +//! }) +//! // Optional configuration +//! .half_open_mode(HalfOpenMode::reliable(None)) // close the circuit gradually (default) +//! .failure_threshold(0.05) // Trip at 5% failure threshold (less sensitive than default 10%) +//! .min_throughput(50) // Require minimum 50 requests before considering circuit open +//! .sampling_duration(Duration::from_secs(60)) // Evaluate failures over 60-second window +//! .break_duration(Duration::from_secs(30)) // Stay open for 30 seconds before testing +//! // You can provide your own partitioning logic if needed. The default is a single global +//! // circuit. By partitioning, you can have separate circuits for different inputs. +//! .partition_key(|input| PartitionKey::from(detect_partition(input))) +//! // State change callbacks for monitoring and alerting +//! .on_opened(|output, _args| { +//! println!("circuit breaker OPENED due to failure: {:?}", output); +//! // In real code, this would trigger alerts, metrics, logging, etc. +//! }) +//! .on_closed(|output, _args| { +//! println!("circuit breaker CLOSED, service recovered: {:?}", output); +//! // In real code, this would log recovery, update dashboards, etc. +//! }) +//! .on_probing(|input, _args| { +//! println!("circuit breaker PROBING with input: {:?}", input); +//! // Optionally modify input for probing requests +//! }), +//! Execute::new(execute_unreliable_operation), +//! ); +//! +//! // Build and execute the service +//! let service = stack.into_service(); +//! let result = service.execute("test_timeout".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # fn detect_partition(input: &String) -> String { input.to_string() } +//! # async fn execute_unreliable_operation(input: String) -> Result { Ok(input) } +//! 
``` + +mod args; +mod callbacks; +mod layer; +mod service; +#[doc(inline)] +pub use args::{OnClosedArgs, OnOpenedArgs, OnProbingArgs, RecoveryArgs, RejectedInputArgs}; +pub(super) use callbacks::*; +#[doc(inline)] +pub use layer::CircuitLayer; +#[doc(inline)] +pub use service::Circuit; + +mod execution_result; +pub(super) use execution_result::ExecutionResult; + +mod health; +pub(super) use health::*; + +mod constants; +mod engine; + +#[cfg(any(feature = "metrics", test))] +mod telemetry; +pub(super) use engine::*; + +mod partition_key; +pub use partition_key::PartitionKey; + +mod half_open_mode; +pub use half_open_mode::HalfOpenMode; diff --git a/crates/seatbelt/src/circuit_breaker/partition_key.rs b/crates/seatbelt/src/circuit_breaker/partition_key.rs new file mode 100644 index 00000000..3199d161 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/partition_key.rs @@ -0,0 +1,206 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::borrow::Cow; +use std::collections::hash_map::DefaultHasher; +use std::fmt::Display; +use std::hash::Hasher; + +/// Key that identifies a partition for which a separate circuit breaker instance is maintained. +/// +/// Currently, it supports either integer or string keys. For maximum performance, prefer using integer keys +/// or static string keys (i.e. `&'static str`). +/// +/// # Examples +/// +/// ## Creation from a number +/// +/// ```rust +/// use seatbelt::circuit_breaker::PartitionKey; +/// +/// let key = PartitionKey::from(42_u64); +/// assert_eq!(key.to_string(), "42"); +/// ``` +/// +/// ## Creation from HTTP request authority and scheme +/// +/// ```rust +/// use seatbelt::circuit_breaker::PartitionKey; +/// +/// // Simulate extracting authority and scheme from an HTTP request +/// let scheme = "https"; +/// let authority = "api.example.com"; +/// let partition_value = format!("{}://{}", scheme, authority); +/// +/// let key = PartitionKey::from(partition_value); +/// assert_eq!(key.to_string(), "https://api.example.com"); +/// +/// // For better performance, use hashing. Note that you must provide a display label +/// // for the hashed key. +/// let hashed_key = PartitionKey::hashed(&(scheme, authority), "scheme_and_authority"); +/// assert_eq!(hashed_key.to_string(), "scheme_and_authority"); +/// ``` +/// +/// # Telemetry +/// +/// The values used to create partition keys are included in telemetry data (logs and metrics) +/// for observability purposes. **Important**: Ensure that the values from which partition keys +/// are created do not contain any sensitive data such as authentication tokens, personal +/// identifiable information (PII), or other confidential data. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct PartitionKey(PartitionKeyValue); + +impl PartitionKey { + pub(crate) const fn default() -> Self { + Self(PartitionKeyValue::String(Cow::Borrowed("default"))) + } + + /// Create a partition key by hashing the given value. + /// + /// The value must implement the `Hash` trait. This is useful for creating partition + /// keys from complex types. The resulting partition key will be based on the hash of + /// the value. You must provide a `label` that will be used for display and telemetry. 
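+    ///
+    /// # Examples
+    ///
+    /// A short sketch mirroring the type-level docs: hash a tuple of request
+    /// properties and give the key a display label.
+    ///
+    /// ```rust
+    /// use seatbelt::circuit_breaker::PartitionKey;
+    ///
+    /// // The label, not the hash value, is what appears in `Display` output and telemetry.
+    /// let key = PartitionKey::hashed(&("https", "api.example.com"), "scheme_and_authority");
+    /// assert_eq!(key.to_string(), "scheme_and_authority");
+    /// ```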
+ pub fn hashed(value: &T, label: &'static str) -> Self { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + Self(PartitionKeyValue::Hashed(hasher.finish(), label)) + } +} + +impl Display for PartitionKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.0 { + PartitionKeyValue::Number(n) => write!(f, "{n}"), + PartitionKeyValue::String(s) => f.write_str(s), + PartitionKeyValue::Hashed(_, label) => f.write_str(label), + } + } +} + +impl From for PartitionKey { + fn from(value: u64) -> Self { + Self(PartitionKeyValue::Number(value)) + } +} + +impl From<&'static str> for PartitionKey { + fn from(value: &'static str) -> Self { + Self(PartitionKeyValue::String(Cow::Borrowed(value))) + } +} + +impl From for PartitionKey { + fn from(value: String) -> Self { + Self(PartitionKeyValue::String(Cow::Owned(value))) + } +} + +impl From for Cow<'static, str> { + fn from(value: PartitionKey) -> Self { + match value.0 { + PartitionKeyValue::Number(n) => Cow::Owned(n.to_string()), + PartitionKeyValue::String(s) => s, + PartitionKeyValue::Hashed(_, label) => Cow::Borrowed(label), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum PartitionKeyValue { + Number(u64), + Hashed(u64, &'static str), + String(Cow<'static, str>), +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::fmt::Debug; + use std::hash::Hash; + + use static_assertions::assert_impl_all; + + use super::*; + + assert_impl_all!(PartitionKey: Send, Sync, Unpin, Clone, Hash, Display, Debug, PartialEq, Eq); + + #[test] + fn from_u64_and_display() { + let k = PartitionKey::from(42u64); + assert_eq!(k.to_string(), "42"); + assert_eq!(k, PartitionKey::from(42u64)); + } + + #[test] + fn from_static_str_and_string() { + let a: PartitionKey = "hello".into(); + let b: PartitionKey = String::from("hello").into(); + assert_eq!(a.to_string(), "hello"); + assert_eq!(b.to_string(), "hello"); + assert_eq!(a, b); + } + + #[test] + fn hashed_matches_manual_hasher() { + let value = "some value"; + let pk = PartitionKey::hashed(&value, "my_label"); + + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + let expected = hasher.finish(); + + match &pk.0 { + PartitionKeyValue::Hashed(n, _) => { + assert_eq!(*n, expected); + } + _ => panic!("Expected Inner::Hashed variant"), + } + + assert_eq!(pk.to_string(), "my_label"); + } + + #[test] + fn partitionkey_hash_consistent() { + let k1 = PartitionKey::from(123u64); + let k2 = PartitionKey::from(123u64); + + let mut h1 = DefaultHasher::new(); + k1.hash(&mut h1); + let mut h2 = DefaultHasher::new(); + k2.hash(&mut h2); + + assert_eq!(h1.finish(), h2.finish()); + } + + #[test] + fn into_cow_from_number() { + let key = PartitionKey::from(42u64); + let cow: Cow<'static, str> = key.into(); + assert!(matches!(cow, Cow::Owned(_))); + assert_eq!(cow, "42"); + } + + #[test] + fn into_cow_from_string_owned() { + let key = PartitionKey::from(String::from("owned_string")); + let cow: Cow<'static, str> = key.into(); + assert!(matches!(cow, Cow::Owned(_))); + assert_eq!(cow, "owned_string"); + } + + #[test] + fn into_cow_from_string_borrowed() { + let key = PartitionKey::from("static_str"); + let cow: Cow<'static, str> = key.into(); + assert!(matches!(cow, Cow::Borrowed(_))); + assert_eq!(cow, "static_str"); + } + + #[test] + fn into_cow_from_hashed() { + let key = PartitionKey::hashed(&"value", "label"); + let cow: Cow<'static, str> = key.into(); + assert!(matches!(cow, Cow::Borrowed(_))); + assert_eq!(cow, "label"); 
+ } +} diff --git a/crates/seatbelt/src/circuit_breaker/service.rs b/crates/seatbelt/src/circuit_breaker/service.rs new file mode 100644 index 00000000..dd3c4eb8 --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/service.rs @@ -0,0 +1,520 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::ops::ControlFlow; + +use layered::Service; +use tick::Clock; + +use super::{ + CircuitEngine, CircuitLayer, Engines, EnterCircuitResult, ExecutionMode, ExecutionResult, ExitCircuitResult, OnClosed, OnClosedArgs, + OnOpened, OnOpenedArgs, OnProbing, OnProbingArgs, PartionKeyProvider, PartitionKey, RecoveryArgs, RejectedInput, RejectedInputArgs, + ShouldRecover, +}; +use crate::{NotSet, utils::EnableIf}; + +/// Applies circuit breaker logic to prevent cascading failures. +/// +/// `Circuit` wraps an inner [`Service`] and monitors the success and failure rates +/// of operations. When the failure rate exceeds a configured threshold, the circuit breaker opens +/// and temporarily blocks requests to give the downstream service time to recover. +/// +/// This middleware is designed to be used across services, applications, and libraries +/// to prevent cascading failures in distributed systems. +/// +/// `Circuit` is configured by calling [`Circuit::layer`] and using the +/// builder methods on the returned [`CircuitLayer`] instance. +/// +/// For comprehensive examples and usage patterns, see the [`circuit_breaker` module][crate::circuit_breaker] documentation. +#[derive(Debug)] +pub struct Circuit { + pub(super) inner: S, + pub(super) clock: Clock, + pub(super) recovery: ShouldRecover, + pub(super) rejected_input: RejectedInput, + pub(super) enable_if: EnableIf, + pub(super) engines: Engines, + pub(super) partition_key: Option>, + pub(super) on_opened: Option>, + pub(super) on_closed: Option>, + pub(super) on_probing: Option>, +} + +impl Circuit { + /// Creates a new circuit breaker layer with the specified name and options. + /// + /// Returns a [`CircuitLayer`] that must be configured with required parameters + /// before it can be used to build a circuit breaker service. 
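+    ///
+    /// # Examples
+    ///
+    /// A sketch based on the module-level examples; `clock` is assumed to be an
+    /// existing [`Clock`][tick::Clock] instance.
+    ///
+    /// ```rust,ignore
+    /// use seatbelt::circuit_breaker::Circuit;
+    /// use seatbelt::{RecoveryInfo, ResilienceContext};
+    ///
+    /// let context = ResilienceContext::new(&clock).name("example");
+    /// let layer = Circuit::layer("my_breaker", &context)
+    ///     // Required: classify outputs as recoverable failures or successes.
+    ///     .recovery_with(|result: &Result<String, String>, _args| match result {
+    ///         Ok(_) => RecoveryInfo::never(),
+    ///         Err(_) => RecoveryInfo::retry(),
+    ///     })
+    ///     // Required: produce the output returned while the circuit is open.
+    ///     .rejected_input_error(|_, _| "service unavailable".to_string());
+    /// ```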
+ pub fn layer( + name: impl Into>, + context: &crate::ResilienceContext, + ) -> CircuitLayer { + CircuitLayer::new(name.into(), context) + } +} + +impl Service for Circuit +where + In: Send, + S: Service, +{ + type Out = Out; + + async fn execute(&self, input: In) -> Self::Out { + // Check if a circuit breaker is enabled for this input + if !self.enable_if.call(&input) { + return self.inner.execute(input).await; + } + + // Determine the partition key for this input + let partition_key = self + .partition_key + .as_ref() + .map_or_else(PartitionKey::default, |partition_key| partition_key.call(&input)); + + // Retrieve the engine for this partition + let engine = self.engines.get_engine(&partition_key); + + // Before + let (input, mode) = match self.before_execute(engine.as_ref(), input, &partition_key) { + ControlFlow::Continue(input) => input, + ControlFlow::Break(output) => return output, + }; + + // Execute the inner service + let output = self.inner.execute(input).await; + + // After + self.after_execute(engine.as_ref(), &output, mode, &partition_key); + + output + } +} + +impl Circuit { + #[inline] + fn before_execute( + &self, + engine: &impl CircuitEngine, + mut input: In, + partition_key: &PartitionKey, + ) -> ControlFlow { + // Try to enter the circuit + match engine.enter() { + EnterCircuitResult::Accepted { mode } => { + match mode { + // regular execution, do nothing special + ExecutionMode::Normal => ControlFlow::Continue((input, ExecutionMode::Normal)), + // This is a probing execution that happens when the circuit is half-open. + // Invoke the on_probing callback if configured. + ExecutionMode::Probe => { + if let Some(on_probing) = &self.on_probing { + on_probing.call(&mut input, OnProbingArgs { partition_key }); + } + + ControlFlow::Continue((input, ExecutionMode::Probe)) + } + } + } + // Circuit is open, return rejected input output + EnterCircuitResult::Rejected => ControlFlow::Break(self.rejected_input.call(input, RejectedInputArgs { partition_key })), + } + } + + fn after_execute(&self, engine: &impl CircuitEngine, output: &Out, mode: ExecutionMode, partition_key: &PartitionKey) { + let recovery = self.recovery.call( + output, + RecoveryArgs { + partition_key, + clock: &self.clock, + }, + ); + + // Evaluate the execution result based on recovery decision + let execution_result = ExecutionResult::from_recovery(&recovery); + + // Exit the circuit and handle state transitions + match engine.exit(execution_result, mode) { + ExitCircuitResult::Unchanged | ExitCircuitResult::Reopened => { + // we explicitly do nothing here + } + ExitCircuitResult::Opened(_health) => { + if let Some(on_opened) = &self.on_opened { + on_opened.call(output, OnOpenedArgs { partition_key }); + } + } + ExitCircuitResult::Closed(stats) => { + if let Some(on_closed) = &self.on_closed { + on_closed.call( + output, + OnClosedArgs { + partition_key, + open_duration: stats.opened_duration(self.clock.instant()), + }, + ); + } + } + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +#[cfg(not(miri))] +mod tests { + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::time::{Duration, Instant}; + + use layered::Execute; + use tick::ClockControl; + + use super::*; + use crate::circuit_breaker::constants::DEFAULT_BREAK_DURATION; + use crate::circuit_breaker::{EngineFake, HalfOpenMode, HealthInfo, Stats}; + use crate::{RecoveryInfo, ResilienceContext, Set}; + use layered::Layer; + + #[test] + fn layer_ensure_defaults() { + let context = 
ResilienceContext::::new(Clock::new_frozen()).name("test_pipeline"); + let layer: CircuitLayer = Circuit::layer("test_breaker", &context); + let layer = layer + .recovery_with(|_, _| RecoveryInfo::never()) + .rejected_input(|_, _| "rejected".to_string()); + + let breaker = layer.layer(Execute::new(|v: String| async move { v })); + + assert!(breaker.enable_if.call(&"str".to_string())); + } + + #[tokio::test] + async fn circuit_breaker_disabled_no_inner_calls() { + let clock = Clock::new_frozen(); + let service = create_ready_circuit_breaker_layer(&clock) + .disable() + .layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + assert_eq!(result, "test"); + } + + #[tokio::test] + async fn passthrough_behavior() { + let clock = Clock::new_frozen(); + let service = create_ready_circuit_breaker_layer(&clock).layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + assert_eq!(result, "test"); + } + + #[test] + fn before_execute_accepted() { + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .on_probing(|_, _| panic!("should not be called")) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Unchanged, + ); + + let result = service + .before_execute(&engine, "test".to_string(), &PartitionKey::default()) + .continue_value() + .unwrap(); + assert_eq!(result, ("test".to_string(), ExecutionMode::Normal)); + } + + #[test] + fn before_execute_accepted_with_probing() { + let probing_called = Arc::new(AtomicBool::new(false)); + let probing_called_clone = Arc::clone(&probing_called); + + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .on_probing(move |value, _| { + assert_eq!(value, "test"); + probing_called.store(true, std::sync::atomic::Ordering::SeqCst); + }) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Probe, + }, + ExitCircuitResult::Unchanged, + ); + + let result = service + .before_execute(&engine, "test".to_string(), &PartitionKey::default()) + .continue_value() + .unwrap(); + assert_eq!(result, ("test".to_string(), ExecutionMode::Probe)); + assert!(probing_called_clone.load(std::sync::atomic::Ordering::SeqCst)); + } + + #[test] + fn before_execute_rejected() { + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .rejected_input(|_, _| "rejected".to_string()) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new(EnterCircuitResult::Rejected, ExitCircuitResult::Unchanged); + + let result = service + .before_execute(&engine, "test".to_string(), &PartitionKey::default()) + .break_value() + .unwrap(); + assert_eq!(result, "rejected"); + } + + #[test] + fn after_execute_unchanged() { + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .on_opened(|_, _| panic!("should not be called")) + .on_closed(|_, _| panic!("should not be called")) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Unchanged, + ); + + // This should not panic, indicating no callbacks were invoked + service.after_execute(&engine, &"success".to_string(), ExecutionMode::Normal, &PartitionKey::default()); + } + + #[test] + 
fn after_execute_reopened() { + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .on_opened(|_, _| panic!("should not be called")) + .on_closed(|_, _| panic!("should not be called")) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Reopened, + ); + + // This should not panic, indicating no callbacks were invoked + service.after_execute(&engine, &"success".to_string(), ExecutionMode::Normal, &PartitionKey::default()); + } + + #[test] + fn after_execute_opened() { + let opened_called = Arc::new(AtomicBool::new(false)); + let opened_called_clone = Arc::clone(&opened_called); + + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .on_opened(move |output, _| { + assert_eq!(output, "error_response"); + opened_called.store(true, Ordering::SeqCst); + }) + .on_closed(|_, _| panic!("on_closed should not be called")) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Opened(HealthInfo::new(1, 1, 1.0, 1)), + ); + + service.after_execute( + &engine, + &"error_response".to_string(), + ExecutionMode::Normal, + &PartitionKey::default(), + ); + assert!(opened_called_clone.load(Ordering::SeqCst)); + } + + #[test] + fn after_execute_closed() { + let closed_called = Arc::new(AtomicBool::new(false)); + let closed_called_clone = Arc::clone(&closed_called); + + let service = create_ready_circuit_breaker_layer(&Clock::new_frozen()) + .on_opened(|_, _| panic!("on_opened should not be called")) + .on_closed(move |output, _| { + assert_eq!(output, "success_response"); + closed_called.store(true, Ordering::SeqCst); + }) + .layer(Execute::new(move |v: String| async move { v })); + + let engine = EngineFake::new( + EnterCircuitResult::Accepted { + mode: ExecutionMode::Normal, + }, + ExitCircuitResult::Closed(Stats::new(Instant::now())), + ); + + service.after_execute( + &engine, + &"success_response".to_string(), + ExecutionMode::Normal, + &PartitionKey::default(), + ); + assert!(closed_called_clone.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn execute_end_to_end_with_callbacks() { + let probing_called = Arc::new(AtomicBool::new(false)); + let opened_called = Arc::new(AtomicBool::new(false)); + let closed_called = Arc::new(AtomicBool::new(false)); + + let probing_called_clone = Arc::clone(&probing_called); + let opened_called_clone = Arc::clone(&opened_called); + let closed_called_clone = Arc::clone(&closed_called); + + let clock_control = ClockControl::new(); + + // Create a service that transforms input and can trigger different circuit states + let service = create_ready_circuit_breaker_layer(&clock_control.to_clock()) + .min_throughput(5) + .half_open_mode(HalfOpenMode::quick()) + .on_probing(move |input, _| { + assert_eq!(input, "probe_input"); + probing_called.store(true, Ordering::SeqCst); + }) + .on_opened(move |output, _| { + assert_eq!(output, "error_output"); + opened_called.store(true, Ordering::SeqCst); + }) + .on_closed(move |output, args| { + assert_eq!(output, "probe_output"); + assert!(args.open_duration() > Duration::ZERO); + closed_called.store(true, Ordering::SeqCst); + }) + .layer(Execute::new(move |input: String| async move { + // Transform input to simulate different scenarios + match input.as_str() { + "probe_input" => "probe_output".to_string(), + "success_input" => 
"success_output".to_string(), + "error_input" => "error_output".to_string(), + _ => input, + } + })); + + // break the circuit first by simulating failures + for _ in 0..5 { + let result = service.execute("error_input".to_string()).await; + assert_eq!(result, "error_output"); + } + + // rejected input + let result = service.execute("success_input".to_string()).await; + assert_eq!(result, "circuit is open"); + assert!(opened_called_clone.load(Ordering::SeqCst)); + assert!(!closed_called_clone.load(Ordering::SeqCst)); + + // send probe and close the circuit + clock_control.advance(DEFAULT_BREAK_DURATION); + let result = service.execute("probe_input".to_string()).await; + assert_eq!(result, "probe_output"); + assert!(probing_called_clone.load(Ordering::SeqCst)); + assert!(closed_called_clone.load(Ordering::SeqCst)); + + // normal execution should pass through + let result = service.execute("success_input".to_string()).await; + assert_eq!(result, "success_output"); + } + + #[tokio::test] + async fn different_partitions_ensure_isolated() { + let clock = Clock::new_frozen(); + let service = create_ready_circuit_breaker_layer(&clock) + .partition_key(|input| PartitionKey::from(input.clone())) + .min_throughput(3) + .recovery_with(|_, _| RecoveryInfo::retry()) + .rejected_input(|_, args| format!("circuit is open, partition: {}", args.partition_key)) + .layer(Execute::new(|input: String| async move { input })); + + // break the circuit for partition "A" + for _ in 0..3 { + let result = service.execute("A".to_string()).await; + assert_eq!(result, "A"); + } + + let result = service.execute("A".to_string()).await; + assert_eq!(result, "circuit is open, partition: A"); + + // Execute on partition "B" should pass through + let result = service.execute("B".to_string()).await; + assert_eq!(result, "B"); + } + + #[tokio::test] + async fn circuit_breaker_emits_logs() { + use tracing_subscriber::util::SubscriberInitExt; + + use crate::testing::LogCapture; + + let log_capture = LogCapture::new(); + let _guard = log_capture.subscriber().set_default(); + + let clock_control = ClockControl::new(); + let context = ResilienceContext::::new(clock_control.to_clock()) + .name("log_test_pipeline") + .enable_logs(); + + let service = Circuit::layer("log_test_circuit", &context) + .min_throughput(3) + .half_open_mode(HalfOpenMode::quick()) + .recovery_with(|output, _| { + if output.contains("success") { + RecoveryInfo::never() + } else { + RecoveryInfo::retry() + } + }) + .rejected_input(|_, _| "rejected".to_string()) + .layer(Execute::new(|input: String| async move { input })); + + // Trip the circuit by generating failures + for _ in 0..3 { + let _ = service.execute("fail".to_string()).await; + } + + // Verify circuit opened log + log_capture.assert_contains("seatbelt::circuit"); + log_capture.assert_contains("log_test_pipeline"); + log_capture.assert_contains("log_test_circuit"); + log_capture.assert_contains("circuit_breaker.state=\"open\""); + log_capture.assert_contains("circuit_breaker.health.failure_rate"); + + // Request should be rejected (emits another open state log) + let _ = service.execute("test".to_string()).await; + + // Advance time past break duration to allow probing + clock_control.advance(DEFAULT_BREAK_DURATION); + + // Send a successful probe to close circuit + let _ = service.execute("success".to_string()).await; + log_capture.assert_contains("circuit_breaker.probe.result"); + log_capture.assert_contains("circuit_breaker.state=\"closed\""); + 
log_capture.assert_contains("circuit_breaker.open.duration"); + } + + fn create_ready_circuit_breaker_layer(clock: &Clock) -> CircuitLayer { + let context = ResilienceContext::::new(clock.clone()).name("test_pipeline"); + Circuit::layer("test_breaker", &context) + .recovery_with(|output, _| { + if output.contains("error") { + RecoveryInfo::retry() + } else { + RecoveryInfo::never() + } + }) + .rejected_input(|_, _| "circuit is open".to_string()) + } +} diff --git a/crates/seatbelt/src/circuit_breaker/telemetry.rs b/crates/seatbelt/src/circuit_breaker/telemetry.rs new file mode 100644 index 00000000..a6ded27a --- /dev/null +++ b/crates/seatbelt/src/circuit_breaker/telemetry.rs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +pub(super) const CIRCUIT_OPENED_EVENT_NAME: &str = "circuit_opened"; +pub(super) const CIRCUIT_CLOSED_EVENT_NAME: &str = "circuit_closed"; +pub(super) const CIRCUIT_REJECTED_EVENT_NAME: &str = "circuit_rejected"; +pub(super) const CIRCUIT_PROBE_EVENT_NAME: &str = "circuit_probe"; +pub(super) const CIRCUIT_STATE: &str = "resilience.circuit_breaker.state"; +pub(super) const CIRCUIT_PROBE_RESULT: &str = "resilience.circuit_breaker.probe.result"; +pub(super) const CIRCUIT_PARTITION: &str = "resilience.circuit_breaker.partition"; diff --git a/crates/seatbelt/src/context.rs b/crates/seatbelt/src/context.rs new file mode 100644 index 00000000..746916fb --- /dev/null +++ b/crates/seatbelt/src/context.rs @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::borrow::Cow; + +use tick::Clock; + +pub(crate) const DEFAULT_PIPELINE_NAME: &str = "default"; + +/// Shared configuration and dependencies for a pipeline of resilience middleware. +/// +/// Pass a single `ResilienceContext` to all middleware in a pipeline (retry, timeout, +/// circuit breaker, etc.) to share a clock and telemetry configuration. +#[derive(Debug)] +#[non_exhaustive] +pub struct ResilienceContext { + clock: Clock, + name: Cow<'static, str>, + #[cfg(any(feature = "metrics", test))] + meter: Option, + logs_enabled: bool, + _in: std::marker::PhantomData In>, + _out: std::marker::PhantomData Out>, +} + +impl ResilienceContext { + /// Create a context with a clock. Initializes with `name = "default"`. + pub fn new(clock: impl AsRef) -> Self { + Self { + clock: clock.as_ref().clone(), + name: Cow::Borrowed(DEFAULT_PIPELINE_NAME), + #[cfg(any(feature = "metrics", test))] + meter: None, + logs_enabled: false, + _in: std::marker::PhantomData, + _out: std::marker::PhantomData, + } + } + + /// Get the configured clock for timing operations. + #[must_use] + #[cfg(any(feature = "retry", feature = "circuit-breaker", feature = "timeout", test))] + pub(crate) fn get_clock(&self) -> &Clock { + &self.clock + } + + /// Set the pipeline name for telemetry correlation. Prefer `snake_case`. + #[must_use] + pub fn name(mut self, name: impl Into>) -> Self { + self.name = name.into(); + self + } + + /// Enable metrics reporting with the given OpenTelemetry meter provider. + #[must_use] + #[cfg(any(feature = "metrics", test))] + pub fn enable_metrics(self, provider: &dyn opentelemetry::metrics::MeterProvider) -> Self { + Self { + meter: Some(crate::metrics::create_meter(provider)), + ..self + } + } + + /// Enable structured logging for resilience events. 
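+    ///
+    /// When enabled, the resilience middleware emits structured events through the `tracing` crate.
+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch; `clock` stands in for an existing [`Clock`][tick::Clock].
+    ///
+    /// ```rust,ignore
+    /// use seatbelt::ResilienceContext;
+    ///
+    /// let context = ResilienceContext::<String, String>::new(&clock)
+    ///     .name("my_pipeline")
+    ///     .enable_logs();
+    /// ```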
+ #[must_use] + #[cfg(any(feature = "logs", test))] + pub fn enable_logs(self) -> Self { + Self { + logs_enabled: true, + ..self + } + } + + #[cfg_attr( + not(any(feature = "metrics", feature = "logs", test)), + expect(unused_variables, reason = "unused when logs nor metrics are used") + )] + #[cfg(any(feature = "retry", feature = "circuit-breaker", feature = "timeout", test))] + pub(crate) fn create_telemetry(&self, strategy_name: Cow<'static, str>) -> crate::utils::TelemetryHelper { + crate::utils::TelemetryHelper { + #[cfg(any(feature = "metrics", test))] + event_reporter: self.meter.as_ref().map(crate::metrics::create_resilience_event_counter), + #[cfg(any(feature = "metrics", feature = "logs", test))] + pipeline_name: self.name.clone(), + #[cfg(any(feature = "metrics", feature = "logs", test))] + strategy_name, + #[cfg(any(feature = "logs", test))] + logs_enabled: self.logs_enabled, + } + } +} + +impl Clone for ResilienceContext { + fn clone(&self) -> Self { + Self { + clock: self.clock.clone(), + name: self.name.clone(), + #[cfg(any(feature = "metrics", test))] + meter: self.meter.clone(), + _in: std::marker::PhantomData, + _out: std::marker::PhantomData, + logs_enabled: self.logs_enabled, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_new_with_clock_sets_default_pipeline_name() { + let clock = tick::Clock::new_frozen(); + let ctx = ResilienceContext::<(), ()>::new(clock); + let telemetry = ctx.create_telemetry("test".into()); + assert_eq!(telemetry.pipeline_name.as_ref(), DEFAULT_PIPELINE_NAME); + // Ensure clock reference behaves (timestamp monotonic relative behaviour not required, just accessible) + let _ = ctx.get_clock().system_time(); + } + + #[test] + fn test_name_with_custom_value_sets_name_and_is_owned() { + let clock = tick::Clock::new_frozen(); + let ctx = ResilienceContext::<(), ()>::new(clock).name(String::from("custom_pipeline")); + let telemetry = ctx.create_telemetry("test".into()); + assert_eq!(telemetry.pipeline_name.as_ref(), "custom_pipeline"); + assert!(matches!(telemetry.pipeline_name, Cow::Owned(_))); + } + + #[cfg(not(miri))] + #[test] + fn test_create_event_reporter_with_multiple_clones_accumulates_events() { + let clock = tick::Clock::new_frozen(); + let (provider, exporter) = test_meter_provider(); + + let ctx = ResilienceContext::<(), ()>::new(clock).enable_metrics(&provider); + let telemetry1 = ctx.create_telemetry("test1".into()); + let telemetry2 = ctx.create_telemetry("test2".into()); + let c1 = telemetry1.event_reporter.unwrap(); + let c2 = telemetry2.event_reporter.unwrap(); + c1.add(1, &[]); + c2.add(2, &[]); + + provider.force_flush().unwrap(); + let metrics = exporter.get_finished_metrics().unwrap(); + let dump = format!("{metrics:?}"); + assert!(dump.contains("resilience.event")); + // Basic sanity that total of 3 was recorded somewhere in debug output. 
+ assert!(dump.contains('3')); + } + + #[cfg(not(miri))] + fn test_meter_provider() -> ( + opentelemetry_sdk::metrics::SdkMeterProvider, + opentelemetry_sdk::metrics::InMemoryMetricExporter, + ) { + let exporter = opentelemetry_sdk::metrics::InMemoryMetricExporter::default(); + let provider = opentelemetry_sdk::metrics::SdkMeterProvider::builder() + .with_periodic_exporter(exporter.clone()) + .build(); + (provider, exporter) + } +} diff --git a/crates/seatbelt/src/lib.rs b/crates/seatbelt/src/lib.rs new file mode 100644 index 00000000..a1a97bf7 --- /dev/null +++ b/crates/seatbelt/src/lib.rs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![doc(html_logo_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/seatbelt/logo.png")] +#![doc(html_favicon_url = "https://media.githubusercontent.com/media/microsoft/oxidizer/refs/heads/main/crates/seatbelt/favicon.ico")] +#![cfg_attr( + not(all( + feature = "retry", + feature = "timeout", + feature = "circuit-breaker", + feature = "metrics", + feature = "logs" + )), + expect( + rustdoc::broken_intra_doc_links, + reason = "too ugly to make 'live links' possible with the combination of features" + ) +)] + +//! Resilience and recovery mechanisms for fallible operations. +//! +//! # Quick Start +//! +//! Add resilience to fallible operations, such as RPC calls over the network, with just a few lines of code. +//! **Retry** handles transient failures and **Timeout** prevents operations from hanging indefinitely: +//! +//! ```rust +//! # #[cfg(all(feature = "retry", feature = "timeout"))] +//! # { +//! # use std::time::Duration; +//! # use tick::Clock; +//! # use layered::{Execute, Service, Stack}; +//! use seatbelt::retry::Retry; +//! use seatbelt::timeout::Timeout; +//! use seatbelt::{RecoveryInfo, ResilienceContext}; +//! +//! # async fn main(clock: Clock) { +//! let context = ResilienceContext::new(&clock); +//! let service = ( +//! // Retry middleware: Automatically retries failed operations +//! Retry::layer("retry", &context) +//! .clone_input() +//! .recovery_with(|output: &String, _| match output.as_str() { +//! "temporary_error" => RecoveryInfo::retry(), +//! "operation timed out" => RecoveryInfo::retry(), +//! _ => RecoveryInfo::never(), +//! }), +//! // Timeout middleware: Cancels operations that take too long +//! Timeout::layer("timeout", &context) +//! .timeout_output(|_| "operation timed out".to_string()) +//! .timeout(Duration::from_secs(30)), +//! // Your core business logic +//! Execute::new(my_string_operation), +//! ) +//! .into_service(); +//! +//! let result = service.execute("input data".to_string()).await; +//! # } +//! # async fn my_string_operation(input: String) -> String { +//! # // Simulate processing that transforms the input string +//! # format!("processed: {}", input) +//! # } +//! # } +//! ``` +//! +//! # Why? +//! +//! Communicating over a network is inherently fraught with problems. The network can go down at any time, +//! sometimes for a millisecond or two. The endpoint you're connecting to may crash or be rebooted, +//! network configuration may change from under you, etc. To deliver a robust experience to users, and to +//! achieve `5` or more `9s` of availability, it is imperative to implement robust resilience patterns to +//! mask these transient failures. +//! +//! 
This crate provides production-ready resilience middleware with excellent telemetry for building +//! robust distributed systems that can automatically handle timeouts, retries, and other failure +//! scenarios. +//! +//! - **Production-ready** - Battle-tested middleware with sensible defaults and comprehensive +//! configuration options. +//! - **Excellent telemetry** - Built-in support for metrics and structured logging to monitor +//! resilience behavior in production. +//! - **Runtime agnostic** - Works seamlessly across any async runtime. Use the same resilience +//! patterns across different projects and migrate between runtimes without changes. +//! +//! # Overview +//! +//! This crate uses the [`layered`] crate for composing middleware. The middleware layers +//! can be stacked together using tuples and built into a service using the [`Stack`][layered::Stack] trait. +//! +//! Resilience middleware also requires [`Clock`][tick::Clock] from the [`tick`] crate for timing +//! operations like delays, timeouts, and backoff calculations. The clock is passed through +//! [`ResilienceContext`] when creating middleware layers. +//! +//! ## Core Types +//! +//! - [`ResilienceContext`] - Holds shared state for resilience middleware, including the clock. +//! - [`RecoveryInfo`] - Classifies errors as recoverable (transient) or non-recoverable (permanent). +//! - [`Recovery`] - A trait for types that can determine their recoverability. +//! +//! ## Built-in Middleware +//! +//! This crate provides built-in resilience middleware that you can use out of the box. See the documentation +//! for each module for details on how to use them. +//! +//! - [`timeout`] - Middleware that cancels long-running operations. +//! - [`retry`] - Middleware that automatically retries failed operations. +//! - [`circuit_breaker`] - Middleware that prevents cascading failures. +//! +//! # Features +//! +//! This crate provides several optional features that can be enabled in your `Cargo.toml`: +//! +//! - **`timeout`** - Enables the [`timeout`] middleware for canceling long-running operations. +//! - **`retry`** - Enables the [`retry`] middleware for automatically retrying failed operations with +//! configurable backoff strategies, jitter, and recovery classification. +//! - **`circuit-breaker`** - Enables the [`circuit_breaker`] middleware for preventing cascading failures. +//! - **`metrics`** - Exposes the OpenTelemetry metrics API for collecting and reporting metrics. +//! - **`logs`** - Enables structured logging for resilience middleware using the `tracing` crate. + +#[doc(inline)] +pub use recoverable::{Recovery, RecoveryInfo, RecoveryKind}; + +mod context; +pub use context::ResilienceContext; + +mod shared; +pub use crate::shared::{NotSet, Set}; + +#[cfg(any(feature = "timeout", test))] +pub mod timeout; + +#[cfg(any(feature = "retry", test))] +pub mod retry; + +#[cfg(any(feature = "circuit-breaker", test))] +pub mod circuit_breaker; + +#[cfg(any(feature = "retry", feature = "circuit-breaker", test))] +mod rnd; + +#[cfg(any(feature = "retry", feature = "circuit-breaker", feature = "timeout", test))] +pub(crate) mod utils; + +#[cfg(any(feature = "metrics", test))] +mod metrics; + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +pub(crate) mod testing; diff --git a/crates/seatbelt/src/metrics.rs b/crates/seatbelt/src/metrics.rs new file mode 100644 index 00000000..347c7a20 --- /dev/null +++ b/crates/seatbelt/src/metrics.rs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. 
+// Licensed under the MIT License. + +use opentelemetry::InstrumentationScope; +use opentelemetry::metrics::{Meter, MeterProvider}; + +const METER_NAME: &str = "seatbelt"; +const VERSION: &str = "v0.1.0"; +const SCHEMA_URL: &str = "https://opentelemetry.io/schemas/1.47.0"; + +pub(crate) fn create_meter(meter_provider: &dyn MeterProvider) -> Meter { + meter_provider.meter_with_scope( + InstrumentationScope::builder(METER_NAME) + .with_version(VERSION) + .with_schema_url(SCHEMA_URL) + .build(), + ) +} + +#[cfg(any(feature = "retry", feature = "circuit-breaker", feature = "timeout", test))] +pub(crate) fn create_resilience_event_counter(meter: &Meter) -> opentelemetry::metrics::Counter { + meter + .u64_counter("resilience.event") + .with_description("Emitted upon the occurrence of a resilience event.") + .with_unit("u64") + .build() +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +#[cfg(not(miri))] +mod tests { + use opentelemetry_sdk::metrics::{InMemoryMetricExporter, SdkMeterProvider}; + + use super::*; + + #[test] + fn assert_definitions() { + let exporter = InMemoryMetricExporter::default(); + let meter_provider = SdkMeterProvider::builder().with_periodic_exporter(exporter.clone()).build(); + + let meter = create_meter(&meter_provider); + let resilience_events = create_resilience_event_counter(&meter); + resilience_events.add(1, &[]); + + meter_provider.force_flush().unwrap(); + + let metrics = exporter.get_finished_metrics().unwrap(); + let str = format!("{metrics:?}"); + + assert!(str.contains("resilience.event")); + assert!(str.contains("u64")); + assert!(str.contains("seatbelt")); + assert!(str.contains("v0.1.0")); + assert!(str.contains("https://opentelemetry.io/schemas/1.47")); + } +} diff --git a/crates/seatbelt/src/retry/args.rs b/crates/seatbelt/src/retry/args.rs new file mode 100644 index 00000000..bcd52878 --- /dev/null +++ b/crates/seatbelt/src/retry/args.rs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +use tick::Clock; + +use crate::{RecoveryInfo, retry::Attempt}; + +/// Arguments for the [`clone_input_with`][super::RetryLayer::clone_input_with] callback function. +/// +/// Provides context for input cloning operations. +#[derive(Debug)] +pub struct CloneArgs { + pub(super) attempt: Attempt, + pub(super) previous_recovery: Option, +} + +impl CloneArgs { + /// Returns the current attempt information. + #[must_use] + pub fn attempt(&self) -> Attempt { + self.attempt + } + + /// Returns the recovery information from the previous attempt, if any. + #[must_use] + pub fn previous_recovery(&self) -> Option<&RecoveryInfo> { + self.previous_recovery.as_ref() + } +} + +/// Arguments for the [`recovery_with`][super::RetryLayer::recovery_with] callback function. +/// +/// Provides context for recovery classification. +#[derive(Debug)] +pub struct RecoveryArgs<'a> { + pub(super) attempt: Attempt, + pub(super) clock: &'a Clock, +} + +impl RecoveryArgs<'_> { + /// Returns the current attempt information. + #[must_use] + pub fn attempt(&self) -> Attempt { + self.attempt + } + + /// Returns the clock used for time-related operations. + #[must_use] + pub fn clock(&self) -> &Clock { + self.clock + } +} + +/// Arguments for the [`on_retry`][super::RetryLayer::on_retry] callback function. +/// +/// Provides context for retry notifications. 
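+///
+/// The arguments expose the current [`attempt`][OnRetryArgs::attempt], the
+/// [`retry_delay`][OnRetryArgs::retry_delay] that will elapse before the next attempt, and the
+/// [`recovery`][OnRetryArgs::recovery] classification that triggered the retry.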
+#[derive(Debug)] +pub struct OnRetryArgs { + pub(super) attempt: Attempt, + pub(super) retry_delay: Duration, + pub(super) recovery: RecoveryInfo, +} + +impl OnRetryArgs { + /// Returns the current attempt information. + #[must_use] + pub fn attempt(&self) -> Attempt { + self.attempt + } + + /// Returns the delay before the next retry attempt. + #[must_use] + pub fn retry_delay(&self) -> Duration { + self.retry_delay + } + + /// Returns the recovery information that triggered this retry. + #[must_use] + pub fn recovery(&self) -> &RecoveryInfo { + &self.recovery + } +} + +/// Arguments for the [`restore_input`][super::RetryLayer::restore_input] callback function. +/// +/// Provides context for input restoration when cloning is unavailable. +#[derive(Debug)] +pub struct RestoreInputArgs { + pub(super) attempt: Attempt, + pub(super) recovery: RecoveryInfo, +} + +impl RestoreInputArgs { + /// Returns the current attempt information. + #[must_use] + pub fn attempt(&self) -> Attempt { + self.attempt + } + + /// Returns the recovery information that triggered this restoration attempt. + #[must_use] + pub fn recovery(&self) -> &RecoveryInfo { + &self.recovery + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_recover_args() { + let clock = Clock::new_frozen(); + + let args = RecoveryArgs { + attempt: Attempt::new(3, true), + clock: &clock, + }; + + assert_eq!(args.attempt(), Attempt::new(3, true)); + let _clock = args.clock(); + } + + #[test] + fn on_retry_args() { + let args = OnRetryArgs { + attempt: Attempt::new(2, false), + retry_delay: Duration::from_secs(5), + recovery: RecoveryInfo::retry(), + }; + + assert_eq!(args.attempt(), Attempt::new(2, false)); + assert_eq!(args.retry_delay(), Duration::from_secs(5)); + assert_eq!(*args.recovery(), RecoveryInfo::retry()); + } + + #[test] + fn clone_args() { + let args = CloneArgs { + attempt: Attempt::new(1, false), + previous_recovery: Some(RecoveryInfo::retry()), + }; + + assert_eq!(args.attempt(), Attempt::new(1, false)); + assert_eq!(args.previous_recovery(), Some(&RecoveryInfo::retry())); + } + + #[test] + fn restore_input_args() { + let args = RestoreInputArgs { + attempt: Attempt::new(2, true), + recovery: RecoveryInfo::retry(), + }; + + assert_eq!(args.attempt(), Attempt::new(2, true)); + assert_eq!(*args.recovery(), RecoveryInfo::retry()); + } +} diff --git a/crates/seatbelt/src/retry/attempt.rs b/crates/seatbelt/src/retry/attempt.rs new file mode 100644 index 00000000..8c6eb3c7 --- /dev/null +++ b/crates/seatbelt/src/retry/attempt.rs @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::fmt::Display; + +/// Represents a single attempt in a retry operation. +/// +/// This struct tracks the current attempt index, and it provides methods to check if this is the +/// first or last attempt. +/// +/// The default attempt has: +/// +/// - `attempt`: 0 (first attempt, 0-based indexing) +/// - `is_first`: true +/// - `is_last`: true +/// +/// This represents a single-shot operation with no retries, where the first +/// attempt is also the final attempt. 
+/// +/// # Examples +/// +/// ``` +/// use seatbelt::retry::Attempt; +/// +/// // Create the first attempt (attempt 0) +/// let attempt = Attempt::new(0, false); +/// assert!(attempt.is_first()); +/// assert!(!attempt.is_last()); +/// assert_eq!(attempt.index(), 0); +/// +/// // Create the last attempt (attempt 2) +/// let last_attempt = Attempt::new(2, true); +/// assert!(!last_attempt.is_first()); +/// assert!(last_attempt.is_last()); +/// assert_eq!(last_attempt.index(), 2); +/// +/// // Use the default attempt (single-shot operation) +/// let default_attempt = Attempt::default(); +/// assert_eq!(default_attempt.index(), 0); +/// assert!(default_attempt.is_first()); +/// assert!(default_attempt.is_last()); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Attempt { + index: u32, + is_last: bool, +} + +impl Default for Attempt { + fn default() -> Self { + Self::new(0, true) + } +} + +impl Attempt { + /// Creates a new attempt with the given attempt index and maximum attempts. + /// + /// # Examples + /// + /// ``` + /// use seatbelt::retry::Attempt; + /// + /// let attempt = Attempt::new(0, false); + /// assert_eq!(attempt.index(), 0); + /// ``` + #[must_use] + pub fn new(index: u32, is_last: bool) -> Self { + Self { index, is_last } + } + + /// Returns true if this is the first attempt (attempt 0). + /// + /// # Examples + /// + /// ``` + /// use seatbelt::retry::Attempt; + /// + /// let first_attempt = Attempt::new(0, false); + /// assert!(first_attempt.is_first()); + /// + /// let second_attempt = Attempt::new(1, false); + /// assert!(!second_attempt.is_first()); + /// ``` + #[must_use] + pub fn is_first(self) -> bool { + self.index == 0 + } + + /// Returns true if this is the last allowed attempt. + /// + /// # Examples + /// + /// ``` + /// use seatbelt::retry::Attempt; + /// + /// let not_last = Attempt::new(1, false); + /// assert!(!not_last.is_last()); + /// + /// let last = Attempt::new(1, true); + /// assert!(last.is_last()); + /// ``` + #[must_use] + pub fn is_last(self) -> bool { + self.is_last + } + + /// Returns the current attempt index (0-based). + /// + /// # Examples + /// + /// ``` + /// use seatbelt::retry::Attempt; + /// + /// let attempt = Attempt::new(3, false); + /// assert_eq!(attempt.index(), 3); + /// ``` + #[must_use] + pub fn index(self) -> u32 { + self.index + } + + #[cfg_attr(test, mutants::skip)] // causes test timeouts + #[cfg(any(feature = "retry", test))] + pub(crate) fn increment(self, max_attempts: MaxAttempts) -> Option { + let next = self.index.saturating_add(1); + + match max_attempts { + MaxAttempts::Finite(index) => { + // If we've reached or exceeded the maximum number of attempts, return None. + if next >= index { + return None; + } + + let is_last = next == index.saturating_sub(1); + Some(Self::new(next, is_last)) + } + MaxAttempts::Infinite => Some(Self::new(next, false)), + } + } +} + +impl Display for Attempt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.index.fmt(f) + } +} + +/// Represents the maximum number of retry attempts allowed. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg(any(feature = "retry", test))] +#[non_exhaustive] +pub(crate) enum MaxAttempts { + /// A finite number of retry attempts. + Finite(u32), + + /// An infinite number of retry attempts. 
+ #[cfg(any(feature = "retry", test))] // currently, only used with retry feature + Infinite, +} + +#[cfg(any(feature = "retry", test))] +impl MaxAttempts { + #[cfg(any(feature = "retry", test))] + pub fn first_attempt(self) -> Attempt { + Attempt::new(0, matches!(self, Self::Finite(1))) + } +} + +#[cfg(any(feature = "retry", test))] +impl From for MaxAttempts { + fn from(value: u32) -> Self { + Self::Finite(value) + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_with_zero_is_first_and_not_last() { + let a = Attempt::new(0, false); + assert_eq!(a.index(), 0); + assert!(a.is_first()); + assert!(!a.is_last()); + } + + #[test] + fn new_when_equal_to_max_is_last() { + let a = Attempt::new(5, true); + assert!(a.is_last()); + assert!(!a.is_first()); + } + + #[test] + fn new_with_zero_max_is_both_first_and_last() { + let a = Attempt::new(0, true); + assert!(a.is_first()); + assert!(a.is_last()); + } + + #[test] + fn increment_correct_behavior() { + let max_attempts = MaxAttempts::Finite(2); + let a = Attempt::new(0, false); + assert_eq!(a.index(), 0); + assert!(!a.is_last()); + + let a = a.increment(max_attempts).unwrap(); + assert_eq!(a.index(), 1); + assert!(a.is_last()); + + let a = a.increment(max_attempts); + assert!(a.is_none()); + } + + #[test] + fn increment_with_infinite_preserves_number() { + let a = Attempt::new(u32::MAX, false); + let next = a.increment(MaxAttempts::Infinite).unwrap(); + assert!(!next.is_last()); + assert_eq!(next.index(), u32::MAX); + } + + #[test] + fn display_shows_index() { + let a = Attempt::new(42, false); + assert_eq!(format!("{a}"), "42"); + } + + #[test] + fn from_u32_to_max_attempts() { + let finite: MaxAttempts = 5u32.into(); + assert_eq!(finite, MaxAttempts::Finite(5)); + + let infinite: MaxAttempts = u32::MAX.into(); + assert_eq!(infinite, MaxAttempts::Finite(u32::MAX)); + } + + #[test] + fn first_attempt_returns_correct_attempt() { + let max_finite = MaxAttempts::Finite(3); + let first = max_finite.first_attempt(); + assert_eq!(first.index(), 0); + assert!(!first.is_last()); + + let max_one = MaxAttempts::Finite(1); + let first_one = max_one.first_attempt(); + assert_eq!(first_one.index(), 0); + assert!(first_one.is_last()); + + let max_infinite = MaxAttempts::Infinite; + let first_infinite = max_infinite.first_attempt(); + assert_eq!(first_infinite.index(), 0); + assert!(!first_infinite.is_last()); + } + + #[test] + fn default_ok() { + let default = Attempt::default(); + assert_eq!(default.index(), 0); + assert!(default.is_first()); + assert!(default.is_last()); + } +} diff --git a/crates/seatbelt/src/retry/backoff copy.rs b/crates/seatbelt/src/retry/backoff copy.rs new file mode 100644 index 00000000..bd4947e5 --- /dev/null +++ b/crates/seatbelt/src/retry/backoff copy.rs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + + diff --git a/crates/seatbelt/src/retry/backoff.rs b/crates/seatbelt/src/retry/backoff.rs new file mode 100644 index 00000000..3d963a39 --- /dev/null +++ b/crates/seatbelt/src/retry/backoff.rs @@ -0,0 +1,696 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::cmp::min; +use std::time::Duration; + +use crate::retry::constants::{DEFAULT_BACKOFF, DEFAULT_BASE_DELAY, DEFAULT_USE_JITTER}; +use crate::rnd::Rnd; + +/// The factor used to determine the range of jitter applied to delays. 
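+/// With a factor of `0.5`, a jittered delay falls within `[0.75 * delay, 1.25 * delay]`.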
+const JITTER_FACTOR: f64 = 0.5; + +/// The default factor used for exponential backoff calculations for cases where jitter is not applied. +const EXPONENTIAL_FACTOR: f64 = 2.0; + +/// Defines the backoff strategy used by resilience middleware for retry operations. +/// +/// Backoff strategies control how delays between retry attempts are calculated, providing +/// different approaches to spacing out retries to avoid overwhelming failing systems while +/// balancing responsiveness and resource utilization. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Backoff { + /// Constant backoff strategy that maintains consistent delays between attempts. + /// + /// **Example with `2s` base delay:** `2s, 2s, 2s, 2s, ...` + Constant, + + /// Linear backoff strategy that increases delays proportionally with attempt count. + /// + /// **Example with `2s` base delay:** `2s, 4s, 6s, 8s, 10s, ...` + Linear, + + /// Exponential backoff strategy that doubles delays with each attempt. + /// + /// **Example with `2s` base delay:** `2s, 4s, 8s, 16s, 32s, ...` + Exponential, +} + +// The delay generation follows the Polly V8 implementation: +// +// https://github.com/App-vNext/Polly/blob/452b34ee1e3a45ccce156a6980f60901a623ee67/src/Polly.Core/Retry/RetryHelper.cs#L3 +#[derive(Debug)] +pub(crate) struct DelayBackoff(pub(super) BackoffOptions); + +impl From for DelayBackoff { + fn from(props: BackoffOptions) -> Self { + Self(props) + } +} + +impl DelayBackoff { + pub fn delays(&self) -> impl Iterator { + DelaysIter { + props: self.0.clone(), + attempt: 0, + prev: 0.0, + } + } +} + +#[derive(Debug)] +struct DelaysIter { + props: BackoffOptions, + attempt: u32, + // The state that is required to compute the next delay when using + // decorrelated jitter backoff. + prev: f64, +} + +impl Iterator for DelaysIter { + type Item = Duration; + + fn next(&mut self) -> Option { + // zero base delay => always zero + if self.props.base_delay.is_zero() { + return Some(Duration::ZERO); + } + + let next_attempt = self.attempt.saturating_add(1); + let delay = match (self.props.backoff_type, self.props.use_jitter) { + (Backoff::Constant, false) => self.props.base_delay, + (Backoff::Constant, true) => apply_jitter(self.props.base_delay, &self.props.rnd), + (Backoff::Linear, _) => { + let delay = self.props.base_delay.saturating_mul(next_attempt); + if self.props.use_jitter { + apply_jitter(delay, &self.props.rnd) + } else { + delay + } + } + (Backoff::Exponential, false) => duration_mul_pow2(self.props.base_delay, self.attempt), + (Backoff::Exponential, true) => { + decorrelated_jitter_backoff_v2(self.attempt, self.props.base_delay, &mut self.prev, &self.props.rnd) + } + }; + + self.attempt = next_attempt; + Some(clamp_to_max(delay, self.props.max_delay)) + } +} + +fn clamp_to_max(d: Duration, max: Option) -> Duration { + max.map_or(d, |m| min(d, m)) +} + +fn duration_mul_pow2(base: Duration, attempt: u32) -> Duration { + let factor = EXPONENTIAL_FACTOR.powi(i32::try_from(attempt).unwrap_or(i32::MAX)); + secs_to_duration_saturating(base.as_secs_f64() * factor) +} + +/// Adds a symmetric, uniform jitter around the given delay. +/// +/// - Jitter is in both directions and relative to `delay` (centered on it). +/// - With `JITTER_FACTOR = 0.5`, the result lies in `[0.75*delay, 1.25*delay]`. +/// - Randomness comes from [`Rnd`]; conversion saturates on overflow and clamps at zero. 
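+///
+/// For example, with a `1s` delay a random sample of `0.0` yields `0.75s`, a sample of `0.5`
+/// yields exactly `1s`, and a sample of `1.0` yields `1.25s` (matching the unit tests below).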
+#[inline] +fn apply_jitter(delay: Duration, rnd: &Rnd) -> Duration { + let ms = delay.as_secs_f64() * 1000.0; + let offset = (ms * JITTER_FACTOR) / 2.0; + let random_delay = (ms * JITTER_FACTOR).mul_add(rnd.next_f64(), -offset); + let new_ms = ms + random_delay; + + secs_to_duration_saturating(new_ms / 1000.0) +} + +/// De-correlated jitter backoff (`v2`): smooth exponential growth with bounded randomization. +/// +/// De-correlated jitter `V2` spreads retries evenly while preserving exponential backoff +/// (with a configurable first-retry median), reducing synchronized spikes and tail-latency +/// compared to naive random jitter. +/// +/// What does "de-correlated" mean? +/// +/// - Successive delays are not a direct function of the immediately previous +/// delay. Instead, each step samples a random phase (`t = attempt + U[0,1)`) +/// and advances a smooth curve; we only take the delta from the previous +/// position on that curve. This weakens correlation between consecutive +/// samples and reduces synchronization across many callers. +/// +/// What does `v2` mean? +/// +/// - It refers to the second-generation formulation from +/// `Polly.Contrib.WaitAndRetry` (linked below). Compared to the earlier +/// (`v1`) "de-correlated jitter" popularized in the `AWS` blog post, `v2` uses a +/// closed-form function combining exponential growth with a `tanh(sqrt(p*t))` +/// taper to achieve monotonic expected growth, reduced tail latency, and a +/// tighter distribution while still remaining de-correlated. +/// +/// References +/// - [`Polly V8` implementation](https://github.com/App-vNext/Polly/blob/8ba1e3ba295542cbc937d0555fadfa0d23b5c568/src/Polly.Core/Retry/RetryHelper.cs#L96) +/// - [`Polly V7` implementation](https://github.com/Polly-Contrib/Polly.Contrib.WaitAndRetry/blob/7596d2dacf22d88bbd814bc49c28424fb6e921e9/src/Polly.Contrib.WaitAndRetry/Backoff.DecorrelatedJitterV2.cs#L22) +/// - [`Polly.Contrib.WaitAndRetry` repo](https://github.com/Polly-Contrib/Polly.Contrib.WaitAndRetry) +#[inline] +#[cfg_attr(test, mutants::skip)] // Mutating arithmetic causes infinite loops in retry tests +fn decorrelated_jitter_backoff_v2(attempt: u32, base_delay: Duration, prev: &mut f64, rnd: &Rnd) -> Duration { + // The original author/credit for this jitter formula is @george-polevoy . + // Jitter formula used with permission as described at https://github.com/App-vNext/Polly/issues/530#issuecomment-526555979 + // Minor adaptations (pFactor = 4.0 and rpScalingFactor = 1 / 1.4d) by @reisenberger, to scale the formula output for easier parameterization to users. + + // A factor used within the formula to help smooth the first calculated delay. + const P_FACTOR: f64 = 4.0; + + // A factor used to scale the median values of the retry times generated by the formula to be _near_ whole seconds, to aid Polly user comprehension. + // This factor allows the median values to fall approximately at 1, 2, 4 etc seconds, instead of 1.4, 2.8, 5.6, 11.2. 
+ const RP_SCALING: f64 = 1.0 / 1.4; + + let target_secs_first_delay = base_delay.as_secs_f64(); + + let t = f64::from(attempt) + rnd.next_f64(); + let next = t.exp2() * (P_FACTOR * t).sqrt().tanh(); + + if !next.is_finite() { + *prev = next; + return Duration::MAX; + } + + let formula_intrinsic_value = next - *prev; + *prev = next; + + secs_to_duration_saturating(formula_intrinsic_value * RP_SCALING * target_secs_first_delay) +} + +fn secs_to_duration_saturating(secs: f64) -> Duration { + if secs <= 0.0 { + return Duration::ZERO; + } + + Duration::try_from_secs_f64(secs).unwrap_or(Duration::MAX) +} + +#[derive(Debug, Clone)] +pub(super) struct BackoffOptions { + pub backoff_type: Backoff, + pub base_delay: Duration, + pub max_delay: Option, + pub use_jitter: bool, + pub rnd: Rnd, +} + +impl Default for BackoffOptions { + fn default() -> Self { + Self { + backoff_type: DEFAULT_BACKOFF, + base_delay: DEFAULT_BASE_DELAY, + max_delay: None, + use_jitter: DEFAULT_USE_JITTER, + rnd: Rnd::default(), + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + + #[test] + fn default_props() { + let props = BackoffOptions::default(); + + assert_eq!(props.backoff_type, Backoff::Exponential); + assert_eq!(props.base_delay, Duration::from_secs(2)); + assert_eq!(props.max_delay, None); + assert!(props.use_jitter); + } + + #[test] + fn smoke_constant_no_jitter() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Constant, + base_delay: Duration::from_millis(200), + max_delay: None, + use_jitter: false, + rnd: Rnd::default(), + }); + let v: Vec<_> = backoff.delays().take(3).collect(); + assert_eq!(v, vec![Duration::from_millis(200); 3]); + } + + #[test] + fn smoke_linear_no_jitter() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Linear, + base_delay: Duration::from_millis(100), + max_delay: None, + use_jitter: false, + rnd: Rnd::default(), + }); + + let v: Vec<_> = backoff.delays().take(4).collect(); + assert_eq!( + v, + vec![ + Duration::from_millis(100), + Duration::from_millis(200), + Duration::from_millis(300), + Duration::from_millis(400), + ] + ); + } + + #[test] + fn smoke_exponential_cap() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(100), + max_delay: Some(Duration::from_secs(1)), + use_jitter: false, + rnd: Rnd::default(), + }); + + // 100ms, 200ms, 400ms, 800ms, then clamped at 1s + let v: Vec<_> = backoff.delays().take(6).collect(); + assert_eq!(v[0], Duration::from_millis(100)); + assert_eq!(v[1], Duration::from_millis(200)); + assert_eq!(v[2], Duration::from_millis(400)); + assert_eq!(v[3], Duration::from_millis(800)); + assert_eq!(v[4], Duration::from_secs(1)); + assert_eq!(v[5], Duration::from_secs(1)); + } + + #[test] + fn zero_base_delay_always_zero() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::ZERO, + max_delay: None, + use_jitter: true, + rnd: Rnd::default(), + }); + let v: Vec<_> = backoff.delays().take(5).collect(); + assert!(v.iter().all(|d| *d == Duration::ZERO)); + } + + #[test] + fn constant_with_jitter() { + // Test with fixed random values to verify jitter behavior + let rnd = Rnd::new_fixed(0.0); + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Constant, + base_delay: Duration::from_secs(1), + max_delay: None, + use_jitter: true, + rnd, + }); + + let v: Vec<_> = 
backoff.delays().take(3).collect(); + // With random value 0.0, jitter should give us 0.75 seconds + assert_eq!(v[0], Duration::from_millis(750)); + assert_eq!(v[1], Duration::from_millis(750)); + assert_eq!(v[2], Duration::from_millis(750)); + } + + #[test] + fn constant_with_different_jitter_values() { + // Test with random value 0.4 + let rnd = Rnd::new_fixed(0.4); + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Constant, + base_delay: Duration::from_secs(1), + max_delay: None, + use_jitter: true, + rnd, + }); + + let delay = backoff.delays().next().unwrap(); + // With random value 0.4, jitter should give us 0.95 seconds + assert_eq!(delay, Duration::from_millis(950)); + + // Test with random value 1.0 + let rnd = Rnd::new_fixed(1.0); + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Constant, + base_delay: Duration::from_secs(1), + max_delay: None, + use_jitter: true, + rnd, + }); + + let delay = backoff.delays().next().unwrap(); + // With random value 1.0, jitter should give us 1.25 seconds + assert_eq!(delay, Duration::from_millis(1250)); + } + + #[test] + fn linear_with_jitter() { + // Test with fixed random value 0.5 + let rnd = Rnd::new_fixed(0.5); + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Linear, + base_delay: Duration::from_secs(1), + max_delay: None, + use_jitter: true, + rnd, + }); + + let v: Vec<_> = backoff.delays().take(3).collect(); + // attempt 0: base_delay * 1 = 1s, with jitter 0.5 should be exactly 1s + // attempt 1: base_delay * 2 = 2s, with jitter 0.5 should be exactly 2s + // attempt 2: base_delay * 3 = 3s, with jitter 0.5 should be exactly 3s + assert_eq!(v[0], Duration::from_secs(1)); + assert_eq!(v[1], Duration::from_secs(2)); + assert_eq!(v[2], Duration::from_secs(3)); + } + + #[test] + fn linear_with_different_jitter_values() { + // Test linear with various jitter values for attempt 2 (3rd delay) + let test_cases = [ + (0.0, 2250), // 3s * (1 + 0.5 * (0.0 - 0.5)) = 2.25s + (0.4, 2850), // 3s * (1 + 0.5 * (0.4 - 0.5)) = 2.85s + (0.6, 3150), // 3s * (1 + 0.5 * (0.6 - 0.5)) = 3.15s + (1.0, 3750), // 3s * (1 + 0.5 * (1.0 - 0.5)) = 3.75s + ]; + + for (random_val, expected_ms) in test_cases { + let rnd = Rnd::new_fixed(random_val); + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Linear, + base_delay: Duration::from_secs(1), + max_delay: None, + use_jitter: true, + rnd, + }); + + let v: Vec<_> = backoff.delays().take(3).collect(); + assert_eq!(v[2], Duration::from_millis(expected_ms)); + } + } + + #[test] + fn exponential_no_jitter() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_secs(1), + max_delay: None, + use_jitter: false, + rnd: Rnd::default(), + }); + + let v: Vec<_> = backoff.delays().take(3).collect(); + assert_eq!( + v, + vec![ + Duration::from_secs(1), // 2^0 = 1 + Duration::from_secs(2), // 2^1 = 2 + Duration::from_secs(4), // 2^2 = 4 + ] + ); + } + + #[test] + fn max_delay_respected_all_types() { + let max_delay = Duration::from_secs(1); + + // Test constant with large base delay + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Constant, + base_delay: Duration::from_secs(10), + max_delay: Some(max_delay), + use_jitter: false, + rnd: Rnd::default(), + }); + let v: Vec<_> = backoff.delays().take(3).collect(); + assert!(v.iter().all(|d| *d == max_delay)); + + // Test linear with large multiplier + let backoff = DelayBackoff(BackoffOptions { + backoff_type: 
Backoff::Linear, + base_delay: Duration::from_secs(10), + max_delay: Some(max_delay), + use_jitter: false, + rnd: Rnd::default(), + }); + let v: Vec<_> = backoff.delays().take(3).collect(); + assert!(v.iter().all(|d| *d == max_delay)); + + // Test exponential with large base delay + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_secs(10), + max_delay: Some(max_delay), + use_jitter: false, + rnd: Rnd::default(), + }); + let v: Vec<_> = backoff.delays().take(3).collect(); + assert!(v.iter().all(|d| *d == max_delay)); + } + + #[test] + fn max_delay_with_jitter() { + let rnd = Rnd::new_fixed(0.5); + + let max_delay = Duration::from_secs(1); + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Linear, + base_delay: Duration::from_secs(10), + max_delay: Some(max_delay), + use_jitter: true, + rnd, + }); + + let v: Vec<_> = backoff.delays().take(3).collect(); + assert!(v.iter().all(|d| *d == max_delay)); + } + + #[test] + fn delay_less_than_max_delay_respected() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Constant, + base_delay: Duration::from_secs(1), + max_delay: Some(Duration::from_secs(2)), + use_jitter: false, + rnd: Rnd::default(), + }); + + let v: Vec<_> = backoff.delays().take(3).collect(); + assert!(v.iter().all(|d| *d == Duration::from_secs(1))); + } + + #[test] + fn exponential_overflow_returns_max_duration() { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_secs(86400), // 1 day + max_delay: None, + use_jitter: false, + rnd: Rnd::default(), + }); + + // Large attempt should cause overflow and return Duration::MAX + let v: Vec<_> = backoff.delays().skip(1000).take(1).collect(); + assert_eq!(v[0], Duration::MAX); + } + + #[test] + fn exponential_overflow_with_max_delay() { + let max_delay = Duration::from_secs(172_800); // 2 days + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_secs(86400), // 1 day + max_delay: Some(max_delay), + use_jitter: false, + rnd: Rnd::default(), + }); + + // Large attempt should cause overflow but be clamped to max_delay + let v: Vec<_> = backoff.delays().skip(1000).take(1).collect(); + assert_eq!(v[0], max_delay); + } + + #[test] + fn exponential_with_jitter_is_positive() { + let test_attempts = [1, 2, 3, 4, 10, 100, 1000, 1024, 1025]; + + for attempt in test_attempts { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_secs(2), + max_delay: None, + use_jitter: true, + rnd: Rnd::default(), + }); + + let delays: Vec<_> = backoff.delays().skip(attempt).take(2).collect(); + assert!(delays[0] > Duration::ZERO, "Attempt {attempt}: first delay should be positive"); + assert!(delays[1] > Duration::ZERO, "Attempt {attempt}: second delay should be positive"); + } + } + + #[test] + fn exponential_with_jitter_respects_max_delay() { + let test_attempts = [1, 2, 3, 4, 10, 100, 1000, 1024, 1025]; + let max_delay = Duration::from_secs(30); + + for attempt in test_attempts { + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_secs(2), + max_delay: Some(max_delay), + use_jitter: true, + rnd: Rnd::default(), + }); + + let delays: Vec<_> = backoff.delays().skip(attempt).take(2).collect(); + assert!(delays[0] > Duration::ZERO, "Attempt {attempt}: first delay should be positive"); + assert!(delays[0] <= max_delay, 
"Attempt {attempt}: first delay should not exceed max"); + assert!(delays[1] > Duration::ZERO, "Attempt {attempt}: second delay should be positive"); + assert!(delays[1] <= max_delay, "Attempt {attempt}: second delay should not exceed max"); + } + } + + #[test] + fn exponential_with_jitter_reproducible_with_fixed_values() { + let rnd1 = Rnd::new_fixed(0.5); + let backoff1 = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(7800), // 7.8 seconds + max_delay: None, + use_jitter: true, + rnd: rnd1, + }); + + let rnd2 = Rnd::new_fixed(0.5); + let backoff2 = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(7800), // 7.8 seconds + max_delay: None, + use_jitter: true, + rnd: rnd2, + }); + + let delays1: Vec<_> = backoff1.delays().take(10).collect(); + let delays2: Vec<_> = backoff2.delays().take(10).collect(); + + assert_eq!(delays1, delays2); + assert!(delays1.iter().all(|d| *d > Duration::ZERO)); + } + + #[test] + fn exponential_with_jitter_different_values_different_results() { + let rnd1 = Rnd::new_fixed(0.2); + let backoff1 = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(7800), // 7.8 seconds + max_delay: None, + use_jitter: true, + rnd: rnd1, + }); + + let rnd2 = Rnd::new_fixed(0.8); + let backoff2 = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(7800), // 7.8 seconds + max_delay: None, + use_jitter: true, + rnd: rnd2, + }); + + let delays1: Vec<_> = backoff1.delays().take(10).collect(); + let delays2: Vec<_> = backoff2.delays().take(10).collect(); + + assert_ne!(delays1, delays2); + assert!(delays1.iter().all(|d| *d > Duration::ZERO)); + assert!(delays2.iter().all(|d| *d > Duration::ZERO)); + } + + // This test checks that the exponential backoff with jitter produces the same sequence of delays + // as Polly v8: + // + // https://github.com/App-vNext/Polly/blob/452b34ee1e3a45ccce156a6980f60901a623ee67/test/Polly.Core.Tests/Retry/RetryHelperTests.cs#L254 + #[test] + fn exponential_with_jitter_compatibility_with_polly_v8() { + let random_values = Mutex::new( + [ + 0.726_243_269_967_959_8, + 0.817_325_359_590_968_7, + 0.768_022_689_394_663_4, + 0.558_161_191_436_537_2, + 0.206_033_154_021_032_7, + 0.558_884_794_618_415_1, + 0.906_027_066_011_925_7, + 0.442_177_873_310_715_84, + 0.977_549_753_141_379_8, + 0.273_704_457_689_870_34, + ] + .into_iter(), + ); + + let delays_ms = [8_626, 10_830, 18_396, 27_703, 37_213, 159_824, 405_539, 300_743, 1_839_611, 639_970]; + + let backoff = DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(7800), // 7.8 seconds + max_delay: None, + use_jitter: true, + rnd: Rnd::new_function(move || random_values.lock().unwrap().next().unwrap()), + }); + + let computed: Vec<_> = backoff.delays().take(10).map(|v| v.as_millis()).collect(); + assert_eq!(computed, delays_ms); + } + + #[test] + fn exponential_without_jitter_ensure_expected_delays() { + let random_values = Mutex::new( + [ + 0.726_243_269_967_959_8, + 0.817_325_359_590_968_7, + 0.768_022_689_394_663_4, + 0.558_161_191_436_537_2, + 0.206_033_154_021_032_7, + 0.558_884_794_618_415_1, + 0.906_027_066_011_925_7, + 0.442_177_873_310_715_84, + 0.977_549_753_141_379_8, + 0.273_704_457_689_870_34, + ] + .into_iter(), + ); + + let delays_ms = [7800, 15600, 31200, 62400, 124_800, 249_600, 499_200, 998_400, 1_996_800, 3_993_600]; + + let backoff = 
DelayBackoff(BackoffOptions { + backoff_type: Backoff::Exponential, + base_delay: Duration::from_millis(7800), // 7.8 seconds + max_delay: None, + use_jitter: false, + rnd: Rnd::new_function(move || random_values.lock().unwrap().next().unwrap()), + }); + + let computed: Vec<_> = backoff.delays().take(10).map(|v| v.as_millis()).collect(); + assert_eq!(computed, delays_ms); + } + + #[test] + fn secs_to_duration_saturating_zero() { + assert_eq!(secs_to_duration_saturating(0.0), Duration::ZERO); + } + + #[test] + fn secs_to_duration_saturating_negative() { + assert_eq!(secs_to_duration_saturating(-1.0), Duration::ZERO); + assert_eq!(secs_to_duration_saturating(-0.001), Duration::ZERO); + assert_eq!(secs_to_duration_saturating(f64::NEG_INFINITY), Duration::ZERO); + } +} diff --git a/crates/seatbelt/src/retry/callbacks.rs b/crates/seatbelt/src/retry/callbacks.rs new file mode 100644 index 00000000..bedcd076 --- /dev/null +++ b/crates/seatbelt/src/retry/callbacks.rs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::{CloneArgs, OnRetryArgs, RecoveryArgs, RestoreInputArgs}; +use crate::RecoveryInfo; + +crate::utils::define_fn_wrapper!(CloneInput(Fn(&mut In, CloneArgs) -> Option)); +crate::utils::define_fn_wrapper!(ShouldRecover(Fn(&Out, RecoveryArgs) -> RecoveryInfo)); +crate::utils::define_fn_wrapper!(OnRetry(Fn(&Out, OnRetryArgs))); +crate::utils::define_fn_wrapper!(RestoreInput(Fn(&mut Out, RestoreInputArgs) -> Option)); diff --git a/crates/seatbelt/src/retry/constants.rs b/crates/seatbelt/src/retry/constants.rs new file mode 100644 index 00000000..c0777883 --- /dev/null +++ b/crates/seatbelt/src/retry/constants.rs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +use crate::retry::Backoff; + +/// Default backoff strategy: exponential backoff. +/// +/// Exponential backoff quickly reduces request pressure after failures and +/// naturally spaces out subsequent attempts. This is the commonly recommended +/// choice for transient faults in distributed systems and pairs well with jitter +/// to avoid thundering herds. +pub(super) const DEFAULT_BACKOFF: Backoff = Backoff::Exponential; + +/// Base delay for the backoff schedule; conservative 2 seconds by default. +/// +/// A `2s` starting delay prevents aggressive retry storms during partial outages +/// while still enabling fast recovery for short-lived failures. Workloads with +/// different needs can override this via configuration. +pub(super) const DEFAULT_BASE_DELAY: Duration = Duration::from_secs(2); + +/// Enable jitter by default to de-synchronize clients and reduce contention. +/// +/// Randomizing retry delays mitigates correlated bursts and improves tail +/// latency under contention. See [Exponential Backoff and Jitter](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter) for details. +pub(super) const DEFAULT_USE_JITTER: bool = true; + +/// Default maximum retry attempts: 3. +/// +/// The default is inherited from `Polly v8` which is a widely used resilience library +/// and also uses 3 retry attempts. +pub(super) const DEFAULT_RETRY_ATTEMPTS: u32 = 3; diff --git a/crates/seatbelt/src/retry/layer.rs b/crates/seatbelt/src/retry/layer.rs new file mode 100644 index 00000000..51f62939 --- /dev/null +++ b/crates/seatbelt/src/retry/layer.rs @@ -0,0 +1,786 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::borrow::Cow; +use std::marker::PhantomData; +use std::time::Duration; + +use super::MaxAttempts; +use crate::retry::backoff::BackoffOptions; +use crate::retry::constants::DEFAULT_RETRY_ATTEMPTS; +use crate::retry::{CloneArgs, CloneInput, OnRetry, OnRetryArgs, RecoveryArgs, RestoreInput, RestoreInputArgs, Retry, ShouldRecover}; +use crate::utils::EnableIf; +use crate::utils::TelemetryHelper; +use crate::{NotSet, Recovery, RecoveryInfo, ResilienceContext, Set, retry::Backoff}; +use layered::Layer; + +/// Builder for configuring retry resilience middleware. +/// +/// This type is created by calling [`Retry::layer`](crate::retry::Retry::layer) and uses the +/// type-state pattern to enforce that required properties are configured before the retry middleware can be built: +/// +/// - [`clone_input`][RetryLayer::clone_input]: Required to specify how to clone inputs for retry attempts +/// - [`recovery`][RetryLayer::recovery]: Required to determine if an output should trigger a retry +/// +/// For comprehensive examples, see the [retry module][crate::retry] documentation. +/// +/// # Type State +/// +/// - `S1`: Tracks whether [`clone_input`][RetryLayer::clone_input] has been set +/// - `S2`: Tracks whether [`recovery`][RetryLayer::recovery] has been set +#[derive(Debug)] +pub struct RetryLayer { + context: ResilienceContext, + max_attempts: MaxAttempts, + backoff: BackoffOptions, + clone_input: Option>, + should_recover: Option>, + on_retry: Option>, + enable_if: EnableIf, + telemetry: TelemetryHelper, + restore_input: Option>, + handle_unavailable: bool, + _state: PhantomData Out>, +} + +impl RetryLayer { + #[must_use] + pub(crate) fn new(name: Cow<'static, str>, context: &ResilienceContext) -> Self { + Self { + context: context.clone(), + max_attempts: MaxAttempts::Finite(DEFAULT_RETRY_ATTEMPTS.saturating_add(1)), + backoff: BackoffOptions::default(), + clone_input: None, + should_recover: None, + on_retry: None, + enable_if: EnableIf::always(), + telemetry: context.create_telemetry(name), + restore_input: None, + handle_unavailable: false, + _state: PhantomData, + } + } +} + +impl RetryLayer { + /// Sets the maximum number of retry attempts. + /// + /// This specifies the maximum number of retry attempts in addition to the original call. + /// For example, if `max_retry_attempts` is 3, the operation will be attempted up to + /// 4 times total (1 original `+` 3 retries). + /// + /// **Default**: 3 retry attempts + #[must_use] + pub fn max_retry_attempts(mut self, max_attempts: u32) -> Self { + self.max_attempts = MaxAttempts::Finite(max_attempts.saturating_add(1)); + self + } + + /// Configures infinite retry attempts. + /// + /// This setting will cause the operation to be retried indefinitely until it succeeds + /// or the retry is aborted by other means (e.g., cancellation, timeout). + /// + /// **Warning**: Use with caution as this can cause infinite loops if the operation + /// consistently fails. + #[must_use] + pub fn infinite_retry_attempts(mut self) -> Self { + self.max_attempts = MaxAttempts::Infinite; + self + } + + /// Sets the backoff strategy for delay calculation. 
+ /// + /// - [`Backoff::Constant`]: Same delay between all retries + /// - [`Backoff::Linear`]: Linearly increasing delay (`base_delay` `×` attempt) + /// - [`Backoff::Exponential`]: Exponentially increasing delay (`base_delay × 2^attempt`) + /// + /// **Default**: [`Backoff::Exponential`] + #[must_use] + pub fn backoff(mut self, backoff_type: Backoff) -> Self { + self.backoff.backoff_type = backoff_type; + self + } + + /// Sets the base delay used for backoff calculations. + /// + /// The meaning depends on the backoff strategy: + /// - **Constant**: The actual delay between retries + /// - **Linear**: Initial delay; subsequent delays are `base_delay × attempt_number` + /// - **Exponential**: Initial delay; subsequent delays grow exponentially + /// + /// **Default**: 2 seconds + #[must_use] + pub fn base_delay(mut self, delay: Duration) -> Self { + self.backoff.base_delay = delay; + self + } + + /// Sets the maximum allowed delay between retries. + /// + /// This caps the backoff calculation to prevent excessively long delays. + /// If not set, delays can grow indefinitely based on the backoff algorithm. + /// + /// **Default**: None (no limit) + #[must_use] + pub fn max_delay(mut self, max_delay: Duration) -> Self { + self.backoff.max_delay = Some(max_delay); + self + } + + /// Enables or disables jitter to reduce correlation between retries. + /// + /// Jitter adds randomization to delay calculations to prevent thundering herd + /// problems when multiple clients retry simultaneously. This is especially + /// important in distributed systems to avoid synchronized retry storms. + /// + /// **Default**: true (jitter enabled) + #[must_use] + pub fn use_jitter(mut self, use_jitter: bool) -> Self { + self.backoff.use_jitter = use_jitter; + self + } + + /// Sets the input cloning function. + /// + /// This function is called before retry attempts to clone the input. + /// Return `Some(cloned_input)` to proceed with retry, or `None` to abort + /// retry and return the last failed result. + /// + /// This is required because Rust's ownership model doesn't allow reusing + /// the same input value across multiple attempts. + /// + /// # Arguments + /// + /// * `clone_fn` - Function that takes a reference to the input and [`CloneArgs`] + /// containing context about the retry attempt, and returns an optional cloned input + #[must_use] + pub fn clone_input_with( + mut self, + clone_fn: impl Fn(&mut In, CloneArgs) -> Option + Send + Sync + 'static, + ) -> RetryLayer { + self.clone_input = Some(CloneInput::new(clone_fn)); + self.into_state::() + } + + /// Automatically sets the input cloning function for types that implement [`Clone`]. + /// + /// This is a convenience method that uses the standard [`Clone`] trait to clone + /// inputs for retry attempts. For types that implement [`Clone`], this provides + /// a simple way to enable retries without manually implementing a cloning function. + /// + /// This is equivalent to calling [`clone_input_with`][RetryLayer::clone_input_with] with + /// `|input, _args| Some(input.clone())`. + /// + /// # Type Requirements + /// + /// This method is only available when the input type `In` implements [`Clone`]. + #[must_use] + pub fn clone_input(self) -> RetryLayer + where + In: Clone, + { + self.clone_input_with(|input, _args| Some(input.clone())) + } + + /// Sets the recovery classification function. 
+ /// + /// This function determines whether a specific output should trigger a retry + /// by examining the output and returning a [`RecoveryInfo`] classification. + /// + /// The function receives the output and [`RecoveryArgs`] with context + /// about the current attempt. + /// + /// # Arguments + /// + /// * `recover_fn` - Function that takes a reference to the output and + /// [`RecoveryArgs`] containing retry attempt context, and returns + /// a [`RecoveryInfo`] decision + #[must_use] + pub fn recovery_with( + mut self, + recover_fn: impl Fn(&Out, RecoveryArgs) -> RecoveryInfo + Send + Sync + 'static, + ) -> RetryLayer { + self.should_recover = Some(ShouldRecover::new(recover_fn)); + self.into_state::() + } + + /// Automatically sets the recovery classification function for types that implement [`Recovery`]. + /// + /// This is a convenience method that uses the [`Recovery`] trait to determine + /// whether an output should trigger a retry. For types that implement [`Recovery`], + /// this provides a simple way to enable intelligent retry behavior without manually + /// implementing a recovery classification function. + /// + /// This is equivalent to calling [`recovery`][RetryLayer::recovery] with + /// `|output, _args| output.recovery()`. + /// + /// # Type Requirements + /// + /// This method is only available when the output type `Out` implements [`Recovery`]. + #[must_use] + pub fn recovery(self) -> RetryLayer + where + Out: Recovery, + { + self.recovery_with(|out, _args| out.recovery()) + } + + /// Configures a callback invoked before each retry attempt. + /// + /// This callback is useful for logging, metrics, or other observability + /// purposes. It receives the output that triggered the retry and + /// [`OnRetryArgs`] with detailed retry information. + /// + /// The callback does not affect retry behavior - it's purely for observation. + /// + /// **Default**: None (no observability by default) + /// + /// # Arguments + /// + /// * `retry_fn` - Function that takes a reference to the output and + /// [`OnRetryArgs`] containing retry context information + #[must_use] + pub fn on_retry(mut self, retry_fn: impl Fn(&Out, OnRetryArgs) + Send + Sync + 'static) -> Self { + self.on_retry = Some(OnRetry::new(retry_fn)); + self + } + + /// Optionally enables the retry middleware based on a condition. + /// + /// When disabled, requests pass through without retry protection. + /// This call replaces any previous condition. + /// + /// **Default**: Always enabled + /// + /// # Arguments + /// + /// * `is_enabled` - Function that takes a reference to the input and returns + /// `true` if retry protection should be enabled for this request + #[must_use] + pub fn enable_if(mut self, is_enabled: impl Fn(&In) -> bool + Send + Sync + 'static) -> Self { + self.enable_if = EnableIf::new(is_enabled); + self + } + + /// Enables the retry middleware unconditionally. + /// + /// All requests will have retry protection applied. + /// This call replaces any previous condition. + /// + /// **Note**: This is the default behavior - retry is enabled by default. + #[must_use] + pub fn enable_always(mut self) -> Self { + self.enable_if = EnableIf::always(); + self + } + + /// Disables the retry middleware completely. + /// + /// All requests will pass through without retry protection. + /// This call replaces any previous condition. + /// + /// **Note**: This overrides the default enabled behavior. 
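+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (the configuration flag, input type, and operation are illustrative):
+    ///
+    /// ```rust
+    /// # use tick::Clock;
+    /// # use layered::{Execute, Stack};
+    /// # use seatbelt::retry::Retry;
+    /// # use seatbelt::{RecoveryInfo, ResilienceContext};
+    /// # fn example(clock: Clock, retries_enabled: bool) {
+    /// let context = ResilienceContext::new(&clock).name("example");
+    ///
+    /// let mut retry = Retry::layer("my_retry", &context)
+    ///     .clone_input()
+    ///     .recovery_with(|output: &Result<String, String>, _args| match output {
+    ///         Ok(_) => RecoveryInfo::never(),
+    ///         Err(_) => RecoveryInfo::retry(),
+    ///     });
+    ///
+    /// // Turn retries off entirely, e.g. based on configuration.
+    /// if !retries_enabled {
+    ///     retry = retry.disable();
+    /// }
+    ///
+    /// let service = (retry, Execute::new(|input: String| async move { Ok::<_, String>(input) })).into_service();
+    /// # let _ = service;
+    /// # }
+    /// ```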
+ #[must_use] + pub fn disable(mut self) -> Self { + self.enable_if = EnableIf::never(); + self + } + + /// Configures whether the retry middleware should attempt to recover from unavailable services. + /// + /// When enabled, the retry middleware will treat [`RecoveryInfo::unavailable`] classifications + /// as recoverable conditions and attempt retries. When disabled (default), unavailable services + /// are treated as non-recoverable and cause immediate failure without retry attempts. + /// + /// This is particularly useful when you have access to multiple resources + /// or service endpoints. When one resource is unavailable, the retry + /// mechanism can attempt the operation against a different resource in subsequent + /// attempts, potentially allowing the operation to succeed despite the unavailability. + /// + /// **Default**: false (unavailable responses are not retried) + /// + /// # Arguments + /// + /// * `enable` - `true` to enable unavailable recovery, `false` to treat unavailable responses as permanent failures + /// + /// # Example + /// + /// ```rust + /// # use seatbelt::retry::{Retry, Attempt}; + /// # use seatbelt::{RecoveryInfo, ResilienceContext}; + /// # use tick::Clock; + /// # fn example() { + /// # let context = ResilienceContext::>::new(Clock::new_frozen()); + /// // Service with multiple endpoints that can route around unavailable services + /// let layer = Retry::layer("multi_endpoint_retry", &context) + /// .clone_input_with(|input: &mut String, args| { + /// let mut input = input.clone(); + /// update_endpoint(&mut input, args.attempt()); // Modify input to use a different endpoint + /// Some(input) + /// }) + /// .recovery_with(|result, _args| match result { + /// Ok(_) => RecoveryInfo::never(), + /// Err(msg) if msg.contains("service unavailable") => RecoveryInfo::unavailable(), + /// Err(_) => RecoveryInfo::retry(), + /// }) + /// .handle_unavailable(true); // Try different endpoints on unavailable + /// # } + /// # fn update_endpoint(_input : &mut String, _attempt: Attempt) {} + /// ``` + #[must_use] + pub fn handle_unavailable(mut self, enable: bool) -> Self { + self.handle_unavailable = enable; + self + } + + /// Sets the input restoration function. + /// + /// This function is called when the original input could not be cloned for a retry + /// attempt (i.e., when [`clone_input_with`][RetryLayer::clone_input_with] returns `None`). + /// The restore function receives the output from the failed attempt and can attempt + /// to extract and reconstruct the input for the next retry. + /// + /// This is particularly useful when a service is unavailable and the input was not actually + /// consumed by the operation. A common pattern is that error responses contain or reference + /// the original input that can be extracted for retry. For example, an HTTP request that + /// is rejected even before sending, because the remote service is known to be down. + /// + /// The restore function should return: + /// - `Some(Input)` to proceed with retry using the restored input + /// - `None` to abort retry and return the provided output + /// + /// This enables retry scenarios where input cloning is expensive or impossible, but + /// the input can be extracted from error responses or failure contexts. 
+ /// + /// # Arguments + /// + /// * `restore_fn` - Function that takes the output and [`RestoreInputArgs`] containing + /// context about the retry attempt, and returns either a restored input and modified + /// output, or just the output to abort retry + /// + /// # Example + /// + /// ```rust + /// # use std::ops::ControlFlow; + /// # use seatbelt::retry::{Retry, RestoreInputArgs}; + /// # use seatbelt::{RecoveryInfo, ResilienceContext}; + /// # use tick::Clock; + /// # fn example() { + /// # let clock = Clock::new_frozen(); + /// # let context = ResilienceContext::new(&clock); + /// #[derive(Clone)] + /// struct HttpRequest { + /// url: String, + /// body: Vec, + /// } + /// + /// enum HttpResult { + /// Success(String), + /// ConnectionError { original_request: HttpRequest }, + /// ServerError(u16), + /// } + /// + /// let layer = Retry::layer("http_retry", &context) + /// .clone_input_with(|_request, _args| None) // Don't clone expensive request bodies + /// .restore_input(|result: &mut HttpResult, _args| { + /// match result { + /// // Extract the original request from the error for retry + /// HttpResult::ConnectionError { original_request } => { + /// let request = original_request.clone(); + /// *result = HttpResult::ServerError(0); + /// Some(request) + /// } + /// _ => None, + /// } + /// }) + /// .recovery_with(|result, _args| match result { + /// HttpResult::ConnectionError { .. } => RecoveryInfo::retry(), + /// _ => RecoveryInfo::never(), + /// }); + /// # } + /// ``` + #[must_use] + pub fn restore_input(mut self, restore_fn: impl Fn(&mut Out, RestoreInputArgs) -> Option + Send + Sync + 'static) -> Self { + self.restore_input = Some(RestoreInput::new(restore_fn)); + self + } +} + +impl Layer for RetryLayer { + type Service = Retry; + + fn layer(&self, inner: S) -> Self::Service { + Retry { + inner, + clock: self.context.get_clock().clone(), + max_attempts: self.max_attempts, + backoff: self.backoff.clone().into(), + clone_input: self.clone_input.clone().expect("clone_input must be set in Ready state"), + should_recover: self.should_recover.clone().expect("should_recover must be set in Ready state"), + on_retry: self.on_retry.clone(), + enable_if: self.enable_if.clone(), + #[cfg(any(feature = "logs", feature = "metrics", test))] + telemetry: self.telemetry.clone(), + restore_input: self.restore_input.clone(), + handle_unavailable: self.handle_unavailable, + } + } +} + +impl RetryLayer, S1, S2> { + /// Sets a specialized input restoration callback that operates only on error cases. + /// + /// This is a convenience method for working with `Result` outputs, where you + /// only want to restore input when an error occurs. The callback receives a mutable reference + /// to the error and can extract the original input from it, while potentially modifying the + /// error for the next attempt. + /// + /// This method is particularly useful when: + /// - Your service returns `Result` where the error type contains recoverable request data + /// - You want to extract and restore input only from error cases, not successful responses + /// - You need to modify the error (e.g., to remove sensitive data) before the next retry + /// + /// # Parameters + /// + /// * `restore_fn` - A function that takes a mutable reference to the error and restoration + /// arguments, returning `Some(input)` if the input can be restored from the error, or + /// `None` if restoration is not possible or desired. 
+ /// + /// # Example + /// + /// ```rust + /// # use tick::Clock; + /// # use seatbelt::retry::*; + /// # use seatbelt::{RecoveryInfo, ResilienceContext}; + /// # #[derive(Clone)] + /// # struct HttpRequest { url: String, body: Vec } + /// # struct HttpResponse { status: u16 } + /// # enum HttpError { + /// # ConnectionError { original_request: HttpRequest }, + /// # ServerError(u16), + /// # AuthError, + /// # } + /// # impl HttpError { + /// # fn try_restore_request(&mut self) -> Option { + /// # match self { + /// # HttpError::ConnectionError { original_request } => { + /// # Some(original_request.clone()) + /// # }, + /// # _ => None, + /// # } + /// # } + /// # } + /// # fn example(clock: Clock) { + /// # let context = ResilienceContext::>::new(&clock); + /// type HttpResult = Result; + /// + /// let layer = Retry::layer("http_retry", &context).restore_input_from_error( + /// |error: &mut HttpError, _args| { + /// // Only restore input from connection errors that contain the original request + /// error.try_restore_request() + /// }, + /// ); + /// # } + /// ``` + #[must_use] + pub fn restore_input_from_error(self, restore_fn: impl Fn(&mut Error, RestoreInputArgs) -> Option + Send + Sync + 'static) -> Self { + self.restore_input(move |input, args| match input { + Ok(_) => None, + Err(e) => restore_fn(e, args), + }) + } +} + +impl RetryLayer { + fn into_state(self) -> RetryLayer { + RetryLayer { + context: self.context, + max_attempts: self.max_attempts, + backoff: self.backoff, + clone_input: self.clone_input, + should_recover: self.should_recover, + on_retry: self.on_retry, + enable_if: self.enable_if, + telemetry: self.telemetry, + restore_input: self.restore_input, + handle_unavailable: self.handle_unavailable, + _state: PhantomData, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::fmt::Debug; + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + use layered::Execute; + use tick::Clock; + + use super::*; + use crate::retry::Attempt; + use crate::testing::RecoverableType; + + #[test] + fn new_creates_correct_initial_state() { + let context = create_test_context(); + let layer: RetryLayer<_, _, NotSet, NotSet> = RetryLayer::new("test_retry".into(), &context); + + assert_eq!(layer.max_attempts, MaxAttempts::Finite(4)); // 3 retries + 1 original = 4 total + assert!(matches!(layer.backoff.backoff_type, Backoff::Exponential)); + assert_eq!(layer.backoff.base_delay, Duration::from_secs(2)); + assert!(layer.backoff.max_delay.is_none()); + assert!(layer.backoff.use_jitter); // Default is true + assert!(layer.clone_input.is_none()); + assert!(layer.should_recover.is_none()); + assert!(layer.on_retry.is_none()); + assert_eq!(layer.telemetry.strategy_name.as_ref(), "test_retry"); + assert!(layer.enable_if.call(&"test_input".to_string())); + } + + #[test] + fn clone_input_sets_correctly() { + let context = create_test_context(); + let layer = RetryLayer::new("test".into(), &context); + + let layer: RetryLayer<_, _, Set, NotSet> = layer.clone_input_with(|input, _args| Some(input.clone())); + + let result = layer.clone_input.unwrap().call( + &mut "test".to_string(), + CloneArgs { + attempt: Attempt::new(0, false), + previous_recovery: None, + }, + ); + assert_eq!(result, Some("test".to_string())); + } + + #[test] + fn recovery_sets_correctly() { + let context = create_test_context(); + let layer = RetryLayer::new("test".into(), &context); + + let layer: RetryLayer<_, _, NotSet, Set> = layer.recovery_with(|output, _args| { + if 
output.contains("error") { + RecoveryInfo::retry() + } else { + RecoveryInfo::never() + } + }); + + let result = layer.should_recover.as_ref().unwrap().call( + &"error message".to_string(), + RecoveryArgs { + attempt: Attempt::new(1, false), + clock: context.get_clock(), + }, + ); + assert_eq!(result, RecoveryInfo::retry()); + + let result = layer.should_recover.as_ref().unwrap().call( + &"success".to_string(), + RecoveryArgs { + attempt: Attempt::new(1, false), + clock: context.get_clock(), + }, + ); + assert_eq!(result, RecoveryInfo::never()); + } + + #[test] + fn recovery_auto_sets_correctly() { + let context = ResilienceContext::::new(Clock::new_frozen()); + let layer = RetryLayer::new("test".into(), &context); + + let layer: RetryLayer<_, _, NotSet, Set> = layer.recovery(); + + let result = layer.should_recover.as_ref().unwrap().call( + &RecoverableType::from(RecoveryInfo::retry()), + RecoveryArgs { + attempt: Attempt::new(1, false), + clock: context.get_clock(), + }, + ); + assert_eq!(result, RecoveryInfo::retry()); + + let result = layer.should_recover.as_ref().unwrap().call( + &RecoverableType::from(RecoveryInfo::never()), + RecoveryArgs { + attempt: Attempt::new(1, false), + clock: context.get_clock(), + }, + ); + assert_eq!(result, RecoveryInfo::never()); + } + + #[test] + fn configuration_methods_work() { + let layer = create_ready_layer() + .max_retry_attempts(5) + .backoff(Backoff::Exponential) + .base_delay(Duration::from_millis(500)) + .max_delay(Duration::from_secs(30)) + .use_jitter(true); + + assert_eq!(layer.max_attempts, MaxAttempts::Finite(6)); + assert!(matches!(layer.backoff.backoff_type, Backoff::Exponential)); + assert_eq!(layer.backoff.base_delay, Duration::from_millis(500)); + assert_eq!(layer.backoff.max_delay, Some(Duration::from_secs(30))); + assert!(layer.backoff.use_jitter); + } + + #[test] + fn on_retry_works() { + let called = Arc::new(AtomicU32::new(0)); + let called_clone = Arc::clone(&called); + + let layer = create_ready_layer().on_retry(move |_output, _args| { + called_clone.fetch_add(1, Ordering::SeqCst); + }); + + layer.on_retry.unwrap().call( + &"output".to_string(), + OnRetryArgs { + retry_delay: Duration::ZERO, + attempt: Attempt::new(1, false), + recovery: RecoveryInfo::retry(), + }, + ); + + assert_eq!(called.load(Ordering::SeqCst), 1); + } + + #[test] + fn enable_disable_conditions_work() { + let layer = create_ready_layer().enable_if(|input| input.contains("enable")); + + assert!(layer.enable_if.call(&"enable_test".to_string())); + assert!(!layer.enable_if.call(&"disable_test".to_string())); + + let layer = layer.disable(); + assert!(!layer.enable_if.call(&"anything".to_string())); + + let layer = layer.enable_always(); + assert!(layer.enable_if.call(&"anything".to_string())); + } + + #[test] + fn layer_builds_service_when_ready() { + let layer = create_ready_layer(); + let _service = layer.layer(Execute::new(|input: String| async move { input })); + } + + #[test] + fn handle_unavailable_sets_correctly() { + let context = create_test_context(); + let layer = RetryLayer::new("test".into(), &context); + + // Test default value + assert!(!layer.handle_unavailable); + + // Test enabling outage handling + let layer = layer.handle_unavailable(true); + assert!(layer.handle_unavailable); + + // Test disabling outage handling + let layer = layer.handle_unavailable(false); + assert!(!layer.handle_unavailable); + } + + #[test] + fn restore_input_sets_correctly() { + let context = create_test_context(); + let layer = RetryLayer::new("test".into(), 
&context); + + let layer = layer.restore_input(|output: &mut String, _args| { + (output == "restore_me").then(|| { + *output = "modified_output".to_string(); + "restored_input".to_string() + }) + }); + + let mut test_output = "restore_me".to_string(); + let result = layer.restore_input.as_ref().unwrap().call( + &mut test_output, + RestoreInputArgs { + attempt: Attempt::new(1, false), + recovery: RecoveryInfo::retry(), + }, + ); + + match result { + Some(input) => { + assert_eq!(input, "restored_input"); + assert_eq!(test_output, "modified_output"); + } + None => panic!("Expected Some, got None"), + } + + let mut test_output2 = "no_restore".to_string(); + let result = layer.restore_input.as_ref().unwrap().call( + &mut test_output2, + RestoreInputArgs { + attempt: Attempt::new(1, false), + recovery: RecoveryInfo::retry(), + }, + ); + + match result { + None => { + assert_eq!(test_output2, "no_restore"); + } + Some(_) => panic!("Expected None, got Some"), + } + } + + #[test] + fn infinite_retry_attempts_sets_correctly() { + let context = create_test_context(); + let layer = RetryLayer::new("test".into(), &context).infinite_retry_attempts(); + assert_eq!(layer.max_attempts, MaxAttempts::Infinite); + } + + #[test] + fn restore_input_from_error_sets_correctly() { + let context: ResilienceContext> = ResilienceContext::new(Clock::new_frozen()).name("test"); + let layer = RetryLayer::new("test".into(), &context) + .restore_input_from_error(|e: &mut String, _| (e == "restore").then(|| std::mem::take(e))); + + let restore = layer.restore_input.as_ref().unwrap(); + let args = || RestoreInputArgs { + attempt: Attempt::new(1, false), + recovery: RecoveryInfo::retry(), + }; + + assert_eq!(restore.call(&mut Err("restore".into()), args()), Some("restore".to_string())); + assert_eq!(restore.call(&mut Err("other".into()), args()), None); + assert_eq!(restore.call(&mut Ok("success".into()), args()), None); + } + + #[test] + fn static_assertions() { + static_assertions::assert_impl_all!(RetryLayer: Layer); + static_assertions::assert_not_impl_all!(RetryLayer: Layer); + static_assertions::assert_not_impl_all!(RetryLayer: Layer); + static_assertions::assert_impl_all!(RetryLayer: Debug); + } + + fn create_test_context() -> ResilienceContext { + ResilienceContext::new(Clock::new_frozen()).name("test_pipeline") + } + + fn create_ready_layer() -> RetryLayer { + RetryLayer::new("test".into(), &create_test_context()) + .clone_input_with(|input, _args| Some(input.clone())) + .recovery_with(|output, _args| { + if output.contains("error") { + RecoveryInfo::retry() + } else { + RecoveryInfo::never() + } + }) + } +} diff --git a/crates/seatbelt/src/retry/mod.rs b/crates/seatbelt/src/retry/mod.rs new file mode 100644 index 00000000..e7497ebc --- /dev/null +++ b/crates/seatbelt/src/retry/mod.rs @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Retry resilience middleware for services, applications, and libraries. +//! +//! This module provides automatic retry capabilities with configurable backoff strategies, +//! jitter, recovery classification, and comprehensive telemetry. The primary types are +//! [`Retry`] and [`RetryLayer`]: +//! +//! - [`Retry`] is the middleware that wraps an inner service and automatically retries failed operations +//! - [`RetryLayer`] is used to configure and construct the retry middleware +//! +//! # Quick Start +//! +//! ```rust +//! # use tick::Clock; +//! # use layered::{Execute, Service, Stack}; +//! 
# use seatbelt::retry::{Retry, Backoff}; +//! # use seatbelt::{RecoveryInfo, ResilienceContext}; +//! # async fn example(clock: Clock) -> Result<(), Box> { +//! let context = ResilienceContext::new(&clock).name("my_service"); +//! +//! let stack = ( +//! Retry::layer("retry", &context) +//! .clone_input() +//! .recovery_with(|result, _| match result { +//! Ok(_) => RecoveryInfo::never(), +//! Err(_) => RecoveryInfo::retry(), +//! }), +//! Execute::new(my_operation), +//! ); +//! +//! let service = stack.into_service(); +//! let result = service.execute("input".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # async fn my_operation(input: String) -> Result { Ok(input) } +//! ``` +//! +//! # Configuration +//! +//! The [`RetryLayer`] uses a type state pattern to enforce that all required properties are +//! configured before the layer can be built. This compile-time safety ensures that you cannot +//! accidentally create a retry layer without properly specifying input cloning and recovery logic: +//! +//! - [`clone_input_with`][RetryLayer::clone_input_with]: Required function to clone inputs for retries (Rust ownership requirement) +//! - [`recovery`][RetryLayer::recovery]: Required function to determine if an output should trigger a retry +//! +//! Each retry layer requires an identifier for telemetry purposes. This identifier should use +//! `snake_case` naming convention to maintain consistency across the codebase. +//! +//! # Thread Safety +//! +//! The [`Retry`] type is thread-safe and implements both `Send` and `Sync` as enforced by +//! the `Service` trait it implements. This allows retry middleware to be safely shared +//! across multiple threads and used in concurrent environments. +//! +//! # Retry Delay +//! +//! Retry delays are determined by the following priority order: +//! +//! 1. **Recovery Delay Override**: If [`RetryLayer::recovery_with`] returns +//! [`RecoveryInfo::retry().delay()`][crate::RecoveryInfo::delay] or [`RecoveryInfo::unavailable()`][crate::RecoveryInfo::unavailable] with a specific duration, this delay +//! is used directly. +//! +//! 2. **Backoff Strategy**: When no recovery delay is specified, delays are calculated using +//! the configured backoff strategy (Constant, Linear, or Exponential with default `2s` base delay). +//! +//! # Defaults +//! +//! The retry middleware uses the following default values when optional configuration is not provided: +//! +//! | Parameter | Default Value | Description | Configured By | +//! |-----------|---------------|-------------|---------------| +//! | Max retry attempts | `3` (4 total) | Maximum retry attempts plus original call | [`max_retry_attempts`][RetryLayer::max_retry_attempts], [`infinite_retry_attempts`][RetryLayer::infinite_retry_attempts] | +//! | Base delay | `2` seconds | Base delay used for backoff calculations | [`base_delay`][RetryLayer::base_delay] | +//! | Backoff strategy | `Exponential` | Exponential backoff with base multiplier of 2 | [`backoff`][RetryLayer::backoff] | +//! | Jitter | `Enabled` | Adds randomness to delays to prevent thundering herds | [`use_jitter`][RetryLayer::use_jitter] | +//! | Max delay | `None` | No limit on maximum delay between retries | [`max_delay`][RetryLayer::max_delay] | +//! | Enable condition | Always enabled | Retry protection is applied to all requests | [`enable_if`][RetryLayer::enable_if], [`enable_always`][RetryLayer::enable_always], [`disable`][RetryLayer::disable] | +//! +//! 
These defaults provide a reasonable starting point for most use cases, offering a balance +//! between resilience and avoiding an excessive load on downstream services. +//! +//! # Telemetry +//! +//! ## Metrics +//! +//! - **Metric**: `resilience.event` (counter) +//! - **When**: Emitted for each attempt that should be retried (including the final retry attempt) +//! - **Attributes**: +//! - `resilience.pipeline.name`: Pipeline identifier from [`ResilienceContext::name`][crate::ResilienceContext::name] +//! - `resilience.strategy.name`: Timeout identifier from [`Retry::layer`] +//! - `resilience.event.name`: Always `retry` +//! - `resilience.attempt.index`: Attempt index (0-based) +//! - `resilience.attempt.is_last`: Boolean indicating if this is the last retry attempt +//! +//! # Examples +//! +//! ## Basic Usage +//! +//! This example demonstrates the basic usage of configuring and using retry middleware. +//! +//! ```rust +//! # use std::time::Duration; +//! # use tick::Clock; +//! # use layered::{Execute, Service, Stack}; +//! # use seatbelt::retry::{Retry, Backoff}; +//! # use seatbelt::{RecoveryInfo, ResilienceContext}; +//! # async fn example(clock: Clock) -> Result<(), String> { +//! // Define common options for resilience middleware. The clock is runtime-specific and +//! // must be provided. See its documentation for details. +//! let context = ResilienceContext::new(&clock).name("example"); +//! +//! let stack = ( +//! Retry::layer("my_retry", &context) +//! // Required: how to clone inputs for retries +//! .clone_input() +//! // Required: determine if we should retry based on output +//! .recovery_with(|output: &Result, _args| match output { +//! // These are demonstrative, real code will have more meaningful recovery detection +//! Ok(_) => RecoveryInfo::never(), +//! Err(msg) if msg.contains("transient") => RecoveryInfo::retry(), +//! Err(_) => RecoveryInfo::never(), +//! }), +//! Execute::new(execute_unreliable_operation), +//! ); +//! +//! // Build the service +//! let service = stack.into_service(); +//! +//! // Execute the service +//! let result = service.execute("test input".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # async fn execute_unreliable_operation(input: String) -> Result { Ok(input) } +//! ``` +//! +//! ## Advanced Usage +//! +//! This example demonstrates advanced usage of the retry middleware, including custom backoff +//! strategies, delay generators, and retry callbacks. +//! +//! ```rust +//! # use std::time::Duration; +//! # use tick::Clock; +//! # use std::io; +//! # use layered::{Execute, Stack, Service}; +//! # use seatbelt::retry::{Retry, Backoff}; +//! # use seatbelt::{RecoveryInfo, ResilienceContext}; +//! # async fn example(clock: Clock) -> Result<(), String> { +//! // Define common options for resilience middleware. +//! let context = ResilienceContext::new(&clock); +//! +//! let stack = ( +//! Retry::layer("advanced_retry", &context) +//! .clone_input() +//! .recovery_with(|output: &Result, _args| match output { +//! Err(err) if err.kind() == io::ErrorKind::TimedOut => RecoveryInfo::retry().delay(Duration::from_secs(60)), +//! Err(_) => RecoveryInfo::never(), +//! Ok(_) => RecoveryInfo::never(), +//! }) +//! // Optional configuration +//! .max_retry_attempts(5) +//! .base_delay(Duration::from_millis(200)) +//! .backoff(Backoff::Exponential) +//! .use_jitter(true) +//! // Callback called just before the next retry +//! .on_retry(|output, args| { +//! println!( +//! "retrying, attempt: {}, delay: {}ms", +//! 
args.attempt(), +//! args.retry_delay().as_millis(), +//! ); +//! }), +//! Execute::new(execute_unreliable_operation), +//! ); +//! +//! // Build and execute the service +//! let service = stack.into_service(); +//! let result = service.execute("test_timeout".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # async fn execute_unreliable_operation(input: String) -> Result { Ok(input) } +//! ``` + +mod args; +mod attempt; +mod backoff; +mod callbacks; +mod constants; +mod layer; +mod service; +#[cfg(any(feature = "metrics", test))] +mod telemetry; + +pub use args::{CloneArgs, OnRetryArgs, RecoveryArgs, RestoreInputArgs}; +pub use attempt::Attempt; +pub(crate) use attempt::MaxAttempts; +pub use backoff::Backoff; +pub(crate) use backoff::DelayBackoff; +pub(crate) use callbacks::{CloneInput, OnRetry, RestoreInput, ShouldRecover}; +pub use layer::RetryLayer; +pub use service::Retry; diff --git a/crates/seatbelt/src/retry/service.rs b/crates/seatbelt/src/retry/service.rs new file mode 100644 index 00000000..2cdf6b6a --- /dev/null +++ b/crates/seatbelt/src/retry/service.rs @@ -0,0 +1,652 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::ops::ControlFlow; +use std::time::Duration; + +use layered::Service; +use tick::Clock; + +use super::MaxAttempts; +use crate::retry::{ + CloneArgs, CloneInput, DelayBackoff, OnRetry, OnRetryArgs, RecoveryArgs, RestoreInput, RestoreInputArgs, ShouldRecover, +}; +use crate::utils::EnableIf; +use crate::{NotSet, RecoveryInfo, RecoveryKind, retry::Attempt}; + +/// Applies retry logic to service execution for transient error handling. +/// +/// `Retry` wraps an inner [`Service`] and automatically retries failed operations +/// based on configurable recovery classification, backoff strategies, and delay generation. +/// This middleware is designed to be used across services, applications, and libraries +/// to handle transient failures gracefully. +/// +/// This middleware requires input cloning capabilities and recovery classification to determine +/// retry eligibility. +/// +/// Retry is configured by calling [`Retry::layer`] and using the +/// builder methods on the returned [`RetryLayer`][crate::retry::RetryLayer] instance. +/// +/// For comprehensive examples and usage patterns, see the [retry module][crate::retry] documentation. +#[derive(Debug)] +#[expect(clippy::struct_field_names, reason = "Fields are named for clarity")] +pub struct Retry { + pub(super) inner: S, + pub(super) clock: Clock, + pub(super) max_attempts: MaxAttempts, + pub(super) backoff: DelayBackoff, + pub(super) clone_input: CloneInput, + pub(super) should_recover: ShouldRecover, + pub(super) on_retry: Option>, + pub(super) enable_if: EnableIf, + #[cfg(any(feature = "logs", feature = "metrics", test))] + pub(super) telemetry: crate::utils::TelemetryHelper, + pub(super) restore_input: Option>, + pub(super) handle_unavailable: bool, +} + +impl Retry { + /// Creates a new retry layer with the specified name and options. + /// + /// Returns a [`RetryLayer`][crate::retry::RetryLayer] that must be configured with required parameters + /// before it can be used to build a retry service. 
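+    ///
+    /// # Example
+    ///
+    /// A minimal sketch that applies the configured layer directly via [`layered::Layer`]
+    /// (the tuple/stack form shown in the module documentation is equivalent); the operation
+    /// and types here are illustrative:
+    ///
+    /// ```rust
+    /// # use tick::Clock;
+    /// # use layered::{Execute, Layer};
+    /// # use seatbelt::retry::Retry;
+    /// # use seatbelt::{RecoveryInfo, ResilienceContext};
+    /// # fn example(clock: Clock) {
+    /// let context = ResilienceContext::new(&clock).name("example");
+    ///
+    /// let retry = Retry::layer("my_retry", &context)
+    ///     .clone_input()
+    ///     .recovery_with(|output: &Result<String, String>, _args| match output {
+    ///         Ok(_) => RecoveryInfo::never(),
+    ///         Err(_) => RecoveryInfo::retry(),
+    ///     })
+    ///     .layer(Execute::new(|input: String| async move { Ok::<_, String>(input) }));
+    /// # let _ = retry;
+    /// # }
+    /// ```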
+ pub fn layer( + name: impl Into>, + context: &crate::ResilienceContext, + ) -> crate::retry::RetryLayer { + crate::retry::RetryLayer::new(name.into(), context) + } +} + +impl Service for Retry +where + In: Send, + S: Service, +{ + type Out = Out; + + #[cfg_attr(test, mutants::skip)] // Mutating enable_if check causes infinite loops + async fn execute(&self, mut input: In) -> Self::Out { + // Check if retry is enabled for this input + if !self.enable_if.call(&input) { + return self.inner.execute(input).await; + } + + let mut attempt = self.max_attempts.first_attempt(); + let mut delays = self.backoff.delays(); + let mut previous_recovery = None; + + loop { + match self.execute_attempt(input, attempt, &mut delays, previous_recovery).await { + ControlFlow::Continue((next_input, next_attempt, recovery)) => { + input = next_input; + attempt = next_attempt; + previous_recovery = Some(recovery); + } + ControlFlow::Break(out) => return out, + } + } + } +} + +impl Retry +where + In: Send, + S: Service, +{ + async fn execute_attempt( + &self, + mut input: In, + attempt: Attempt, + delays: &mut impl Iterator, + previous_recovery: Option, + ) -> ControlFlow { + let (original_input, attempt_input) = match self.clone_input.call( + &mut input, + CloneArgs { + attempt, + previous_recovery, + }, + ) { + Some(cloned) => (Some(input), cloned), + None => (None, input), + }; + + // Execute the operation + let out = self.inner.execute(attempt_input).await; + + // Check if we should recover from this output + let recovery = self.should_recover.call( + &out, + RecoveryArgs { + attempt, + clock: &self.clock, + }, + ); + + // Detect if we can recover from output + let recovery_kind = match recovery.kind() { + RecoveryKind::Unavailable => { + if self.handle_unavailable { + RecoverableKind::Retry + } else { + return ControlFlow::Break(out); + } + } + RecoveryKind::Retry => RecoverableKind::Retry, + // Handle future variants - treat unknown variants as non-recoverable + RecoveryKind::Never | RecoveryKind::Unknown | _ => return ControlFlow::Break(out), + }; + + // If no more attempts left, report telemetry, and return the last output + let Some(next_attempt) = attempt.increment(self.max_attempts) else { + self.emit_attempt_telemetry(attempt, Duration::ZERO); + return ControlFlow::Break(out); + }; + + // Always get the next delay, even if we won't use it. This is because we want to + // advance the backoff strategy (e.g., exponential backoff), so the next retry uses the + // correct delay when it's not explicitly overridden by the recovery. + let retry_delay = delays.next().unwrap_or(Duration::ZERO); + + // Use the recovery delay if provided, otherwise use the backoff delay + let retry_delay = recovery.get_delay().unwrap_or(retry_delay); + + // At this point, we know that the output is recoverable and that we have more attempts left. + // Determine the delay before the next attempt based on the recovery kind. 
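+        // RecoverableKind currently has a single `Retry` variant, so every recoverable
+        // outcome is finalized through `finalize_retryable_attempt`.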
+ let flow_control = match recovery_kind { + RecoverableKind::Retry => self.finalize_retryable_attempt(original_input, out, attempt, next_attempt, retry_delay, recovery), + }; + + // Only ever delay if we have a next attempt + if matches!(flow_control, ControlFlow::Continue(_)) { + self.clock.delay(retry_delay).await; + } + + flow_control + } +} + +enum RecoverableKind { + Retry, +} + +impl Retry { + #[cfg_attr( + not(any(feature = "logs", test)), + expect(unused_variables, reason = "unused when logs feature not used") + )] + fn emit_attempt_telemetry(&self, attempt: Attempt, retry_delay: Duration) { + #[cfg(any(feature = "logs", test))] + if self.telemetry.logs_enabled { + tracing::event!( + name: "seatbelt.retry", + tracing::Level::WARN, + pipeline.name = %self.telemetry.pipeline_name, + strategy.name = %self.telemetry.strategy_name, + resilience.attempt.index = attempt.index(), + resilience.attempt.is_last = attempt.is_last(), + resilience.retry.delay = retry_delay.as_secs_f32(), + ); + } + + #[cfg(any(feature = "metrics", test))] + if self.telemetry.metrics_enabled() { + use super::telemetry::{ATTEMPT_INDEX, ATTEMPT_NUMBER_IS_LAST, RETRY_EVENT}; + use crate::utils::{EVENT_NAME, PIPELINE_NAME, STRATEGY_NAME}; + + self.telemetry.report_metrics(&[ + opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()), + opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()), + opentelemetry::KeyValue::new(EVENT_NAME, RETRY_EVENT), + opentelemetry::KeyValue::new(ATTEMPT_INDEX, i64::from(attempt.index())), + opentelemetry::KeyValue::new(ATTEMPT_NUMBER_IS_LAST, attempt.is_last()), + ]); + } + } + + #[inline] + fn finalize_retryable_attempt( + &self, + mut original_input: Option, + mut out: Out, + attempt: Attempt, + next_attempt: Attempt, + retry_delay: Duration, + recovery: RecoveryInfo, + ) -> ControlFlow { + // we emit attempt telemetry even if the next attempt does not happen + self.emit_attempt_telemetry(attempt, retry_delay); + + // If we have a restore input callback, we can use it to restore the input for the next attempt if + // the original input was not clonable. + if original_input.is_none() + && let Some(restore) = &self.restore_input + && let Some(input) = restore.call( + &mut out, + RestoreInputArgs { + attempt, + recovery: recovery.clone(), + }, + ) + { + original_input = Some(input); + } + + match original_input { + Some(input) => { + // Only invoke on-retry if there will be next attempt + if let Some(ref on_retry) = self.on_retry { + on_retry.call( + &out, + OnRetryArgs { + attempt, + retry_delay, + recovery: recovery.clone(), + }, + ); + } + ControlFlow::Continue((input, next_attempt, recovery)) + } + None => ControlFlow::Break(out), + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(not(miri))] // Oxidizer runtime does not support Miri. 
+#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicU32, Ordering}; + use std::sync::{Arc, Mutex}; + + use layered::Execute; + use opentelemetry::KeyValue; + use tick::ClockControl; + + use super::*; + use crate::retry::{Backoff, RetryLayer}; + use crate::testing::MetricTester; + use crate::{ResilienceContext, Set}; + use layered::Layer; + + #[test] + fn layer_ensure_defaults() { + let context = ResilienceContext::::new(Clock::new_frozen()).name("test_pipeline"); + let layer: RetryLayer = Retry::layer("test_retry", &context); + let layer = layer.recovery_with(|_, _| RecoveryInfo::never()).clone_input(); + + let retry = layer.layer(Execute::new(|v: String| async move { v })); + + assert_eq!(retry.telemetry.pipeline_name.to_string(), "test_pipeline"); + assert_eq!(retry.telemetry.strategy_name.to_string(), "test_retry"); + assert_eq!(retry.max_attempts, MaxAttempts::Finite(4)); + assert_eq!(retry.backoff.0.base_delay, Duration::from_secs(2)); + assert_eq!(retry.backoff.0.backoff_type, Backoff::Exponential); + assert!(retry.backoff.0.use_jitter); + assert!(retry.on_retry.is_none()); + assert!(retry.enable_if.call(&"str".to_string())); + } + + #[tokio::test] + async fn retry_disabled_no_inner_calls() { + let clock = Clock::new_frozen(); + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| { + counter_clone.fetch_add(1, Ordering::SeqCst); + Some(input.clone()) + }) + .disable() + .layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + assert_eq!(result, "test"); + assert_eq!(counter.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn uncloneable_recovery_called() { + let clock = Clock::new_frozen(); + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |_input, _args| None) + .recovery_with(move |_input, _args| { + counter_clone.fetch_add(1, Ordering::SeqCst); + RecoveryInfo::retry() + }) + .layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + assert_eq!(result, "test"); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn no_recovery_ensure_no_additional_retries() { + let clock = Clock::new_frozen(); + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| { + counter_clone.fetch_add(1, Ordering::SeqCst); + Some(input.clone()) + }) + .recovery_with(move |_input, _args| RecoveryInfo::never()) + .layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + assert_eq!(result, "test"); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn retry_recovery_ensure_retries_exhausted() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let counter = Arc::new(AtomicU32::new(0)); + let counter_clone = std::sync::Arc::clone(&counter); + + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| { + counter_clone.fetch_add(1, Ordering::SeqCst); + Some(input.clone()) + }) + .recovery_with(move |_input, _args| 
RecoveryInfo::retry()) + .max_retry_attempts(4) + .layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + assert_eq!(result, "test"); + assert_eq!(counter.load(Ordering::SeqCst), 5); + } + + #[tokio::test] + async fn retry_recovery_ensure_correct_delays() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let delays = Arc::new(Mutex::new(vec![])); + let delays_clone = Arc::clone(&delays); + + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| Some(input.clone())) + .use_jitter(false) + .backoff(Backoff::Linear) + .recovery_with(move |_input, _args| RecoveryInfo::retry()) + .max_retry_attempts(4) + .on_retry(move |_output, args| { + delays_clone.lock().unwrap().push(args.retry_delay()); + }) + .layer(Execute::new(move |v: String| async move { v })); + + let _result = service.execute("test".to_string()).await; + + assert_eq!( + delays.lock().unwrap().to_vec(), + vec![ + Duration::from_secs(2), + Duration::from_secs(4), + Duration::from_secs(6), + Duration::from_secs(8), + ] + ); + } + + #[tokio::test] + async fn retry_recovery_ensure_correct_attempts() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let attempts = Arc::new(Mutex::new(vec![])); + let attempts_clone = Arc::clone(&attempts); + + let attempts_for_clone = Arc::new(Mutex::new(vec![])); + let attempts_for_clone_clone = Arc::clone(&attempts_for_clone); + + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, args| { + attempts_for_clone_clone.lock().unwrap().push(args.attempt()); + Some(input.clone()) + }) + .recovery_with(move |_input, _args| RecoveryInfo::retry()) + .max_retry_attempts(4) + .on_retry(move |_output, args| { + attempts_clone.lock().unwrap().push(args.attempt()); + }) + .layer(Execute::new(move |v: String| async move { v })); + + let _result = service.execute("test".to_string()).await; + + assert_eq!( + attempts_for_clone.lock().unwrap().to_vec(), + vec![ + Attempt::new(0, false), + Attempt::new(1, false), + Attempt::new(2, false), + Attempt::new(3, false), + Attempt::new(4, true), + ] + ); + + assert_eq!( + attempts.lock().unwrap().to_vec(), + vec![ + Attempt::new(0, false), + Attempt::new(1, false), + Attempt::new(2, false), + Attempt::new(3, false), + ] + ); + } + + #[tokio::test] + async fn restore_input_integration_test() { + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let call_count = Arc::new(AtomicU32::new(0)); + let call_count_clone = std::sync::Arc::clone(&call_count); + let restore_count = Arc::new(AtomicU32::new(0)); + let restore_count_clone = std::sync::Arc::clone(&restore_count); + + // Create a service that fails on first attempt but can restore input + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(|_input, _args| None) // Don't clone - force restore path + .restore_input(move |output: &mut String, _args| { + restore_count_clone.fetch_add(1, Ordering::SeqCst); + output.contains("error:").then(|| { + let input = output.replace("error:", ""); + *output = "restored".to_string(); + input + }) + }) + .recovery_with(|output, _args| { + if output.contains("error:") { + RecoveryInfo::retry() + } else { + RecoveryInfo::never() + } + }) + .max_retry_attempts(2) + .layer(Execute::new(move |input: String| { + let count = 
call_count_clone.fetch_add(1, Ordering::SeqCst); + async move { + if count == 0 { + // First call fails with input stored in error + format!("error:{input}") + } else { + // Subsequent calls succeed + format!("success:{input}") + } + } + })); + + let result = service.execute("test_input".to_string()).await; + + // Verify the restore path was used and retry succeeded + assert_eq!(result, "success:test_input"); + assert_eq!(call_count.load(Ordering::SeqCst), 2); // Original + 1 retry + assert_eq!(restore_count.load(Ordering::SeqCst), 1); // Restore called once + } + + #[tokio::test] + async fn outage_handling_disabled_no_retries() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let call_count = Arc::new(AtomicU32::new(0)); + let call_count_clone = Arc::clone(&call_count); + + // Create a service that returns outage (handle_unavailable disabled by default) + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| { + call_count_clone.fetch_add(1, Ordering::SeqCst); + Some(input.clone()) + }) + .recovery_with(|_output, _args| RecoveryInfo::unavailable()) + .layer(Execute::new(move |v: String| async move { v })); + + let result = service.execute("test".to_string()).await; + + // Should not retry when outage handling is disabled + assert_eq!(result, "test"); + assert_eq!(call_count.load(Ordering::SeqCst), 1); // Only original call, no retries + } + + #[tokio::test] + async fn outage_handling_enabled_with_retries() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let call_count = Arc::new(AtomicU32::new(0)); + let call_count_clone = std::sync::Arc::clone(&call_count); + + // Create a service that returns outage initially, then succeeds + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| { + call_count_clone.fetch_add(1, Ordering::SeqCst); + Some(input.clone()) + }) + .recovery_with(|_output, args| { + // First attempt returns outage, subsequent attempts succeed + if args.attempt().index() == 0 { + RecoveryInfo::unavailable() + } else { + RecoveryInfo::never() + } + }) + .handle_unavailable(true) // Enable outage handling + .max_retry_attempts(2) + .layer(Execute::new(move |input: String| async move { format!("processed_{input}") })); + + let result = service.execute("test".to_string()).await; + + // Should retry when outage handling is enabled + assert_eq!(result, "processed_test"); + assert_eq!(call_count.load(Ordering::SeqCst), 2); // Original + 1 retry + } + + #[tokio::test] + async fn outage_handling_with_recovery_hint() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let delays = Arc::new(Mutex::new(vec![])); + let delays_clone = Arc::clone(&delays); + + // Create a service that returns outage with recovery hint + let service = create_ready_retry_layer(&clock, RecoveryInfo::retry()) + .clone_input_with(move |input, _args| Some(input.clone())) + .recovery_with(|_output, args| { + if args.attempt().index() == 0 { + RecoveryInfo::unavailable().delay(Duration::from_secs(10)) // 10 second recovery hint + } else { + RecoveryInfo::never() + } + }) + .handle_unavailable(true) + .max_retry_attempts(1) + .on_retry(move |_output, args| { + delays_clone.lock().unwrap().push(args.retry_delay()); + }) + .layer(Execute::new(move |v: String| async move { v })); + + let _result = service.execute("test".to_string()).await; + + // Should use the recovery hint as the delay + 
assert_eq!(delays.lock().unwrap().to_vec(), vec![Duration::from_secs(10)]); + } + + #[tokio::test] + async fn retries_exhausted_ensure_telemetry_reported() { + let tester = MetricTester::new(); + let context = ResilienceContext::::new(ClockControl::default().auto_advance_timers(true).to_clock()) + .name("test_pipeline") + .enable_metrics(tester.meter_provider()); + + let service = create_ready_retry_layer_core(RecoveryInfo::retry(), &context) + .clone_input_with(move |input, _args| Some(input.clone())) + .max_retry_attempts(2) + .recovery_with(move |_input, _args| RecoveryInfo::retry()) + .layer(Execute::new(move |v: String| async move { v })); + + let _result = service.execute("test".to_string()).await; + + tester.assert_attributes( + &[ + KeyValue::new("resilience.attempt.index", 0), + KeyValue::new("resilience.attempt.index", 1), + KeyValue::new("resilience.attempt.is_last", false), + KeyValue::new("resilience.attempt.is_last", true), + KeyValue::new("resilience.pipeline.name", "test_pipeline"), + KeyValue::new("resilience.strategy.name", "test_retry"), + KeyValue::new("resilience.event.name", "retry"), + ], + Some(15), + ); + } + + #[tokio::test] + async fn retry_emits_log() { + use tracing_subscriber::util::SubscriberInitExt; + + use crate::testing::LogCapture; + + let log_capture = LogCapture::new(); + let _guard = log_capture.subscriber().set_default(); + + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let context = ResilienceContext::::new(clock) + .name("log_test_pipeline") + .enable_logs(); + + let service = Retry::layer("log_test_retry", &context) + .clone_input() + .recovery_with(|_, _| RecoveryInfo::retry()) + .max_retry_attempts(2) + .layer(Execute::new(|v: String| async move { v })); + + let _ = service.execute("test".to_string()).await; + + log_capture.assert_contains("seatbelt::retry"); + log_capture.assert_contains("log_test_pipeline"); + log_capture.assert_contains("log_test_retry"); + log_capture.assert_contains("resilience.attempt.index"); + log_capture.assert_contains("resilience.retry.delay"); + } + + fn create_ready_retry_layer(clock: &Clock, recover: RecoveryInfo) -> RetryLayer { + let context = ResilienceContext::new(clock.clone()).name("test_pipeline"); + create_ready_retry_layer_core(recover, &context) + } + + fn create_ready_retry_layer_core( + recover: RecoveryInfo, + context: &ResilienceContext, + ) -> RetryLayer { + Retry::layer("test_retry", context) + .recovery_with(move |_, _| recover.clone()) + .clone_input() + .max_delay(Duration::from_secs(9999)) // protect against infinite backoff + } +} diff --git a/crates/seatbelt/src/retry/telemetry.rs b/crates/seatbelt/src/retry/telemetry.rs new file mode 100644 index 00000000..05fc4a38 --- /dev/null +++ b/crates/seatbelt/src/retry/telemetry.rs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/// The name of the retry event for telemetry reporting. +pub(super) const RETRY_EVENT: &str = "retry"; + +/// Attribute key for the retry attempt index. +pub(super) const ATTEMPT_INDEX: &str = "resilience.attempt.index"; + +/// Attribute key for whether this is the last retry attempt. +pub(super) const ATTEMPT_NUMBER_IS_LAST: &str = "resilience.attempt.is_last"; diff --git a/crates/seatbelt/src/rnd.rs b/crates/seatbelt/src/rnd.rs new file mode 100644 index 00000000..326f22ad --- /dev/null +++ b/crates/seatbelt/src/rnd.rs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +use std::fmt::Debug; + +/// Non-cryptographic random number generator used in this crate. +/// +/// This random generator is **NOT cryptography secure** and should only be used for +/// non-security-critical purposes such as load balancing, jitter, sampling, +/// and other scenarios where cryptography guarantees are not required. +/// +/// The `seatbelt` crate does not require cryptography security for its +/// random number generation needs, so this type is provided as a lightweight +/// alternative to more complex `RNG` implementations. +#[derive(Clone, Default)] +pub(crate) enum Rnd { + #[default] + Real, + + #[cfg(test)] + Test(std::sync::Arc f64 + Send + Sync>), +} + +impl Debug for Rnd { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Real => write!(f, "Real"), + #[cfg(test)] + Self::Test(_) => write!(f, "Test"), + } + } +} + +impl Rnd { + #[cfg(test)] + pub fn new_fixed(value: f64) -> Self { + Self::Test(std::sync::Arc::new(move || value)) + } + + #[cfg(test)] + pub fn new_function(f: F) -> Self + where + F: Fn() -> f64 + Send + Sync + 'static, + { + Self::Test(std::sync::Arc::new(f)) + } + + #[cfg_attr(test, mutants::skip)] // Mutating return value causes infinite loops in backoff calculations + pub fn next_f64(&self) -> f64 { + match self { + Self::Real => fastrand::f64(), + #[cfg(test)] + Self::Test(generator) => generator(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn debug_real() { + assert_eq!(format!("{:?}", Rnd::Real), "Real"); + } + + #[test] + fn debug_test() { + assert_eq!(format!("{:?}", Rnd::new_fixed(0.5)), "Test"); + } +} diff --git a/crates/seatbelt/src/shared.rs b/crates/seatbelt/src/shared.rs new file mode 100644 index 00000000..839465a3 --- /dev/null +++ b/crates/seatbelt/src/shared.rs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/// A flag indicating that the required property is set. +#[non_exhaustive] +#[derive(Debug)] +#[doc(hidden)] +pub struct Set; + +/// A flag indicating that the required property has not been set. +#[non_exhaustive] +#[derive(Debug)] +#[doc(hidden)] +pub struct NotSet; diff --git a/crates/seatbelt/src/testing.rs b/crates/seatbelt/src/testing.rs new file mode 100644 index 00000000..42f0f306 --- /dev/null +++ b/crates/seatbelt/src/testing.rs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +#![cfg_attr(miri, expect(dead_code, reason = "too much noise to satisfy Miri's expectations"))] + +use std::io::Write; +use std::sync::{Arc, Mutex}; + +use opentelemetry::KeyValue; +use opentelemetry_sdk::metrics::data::{AggregatedMetrics, Metric, MetricData}; +use opentelemetry_sdk::metrics::{InMemoryMetricExporter, SdkMeterProvider}; +use tracing_subscriber::fmt::MakeWriter; + +use crate::{Recovery, RecoveryInfo}; + +#[derive(Debug)] +pub(crate) struct MetricTester { + exporter: InMemoryMetricExporter, + provider: SdkMeterProvider, +} + +impl Default for MetricTester { + fn default() -> Self { + Self::new() + } +} + +impl MetricTester { + #[must_use] + pub fn new() -> Self { + let in_memory = InMemoryMetricExporter::default(); + + Self { + exporter: in_memory.clone(), + provider: SdkMeterProvider::builder().with_periodic_exporter(in_memory).build(), + } + } + + #[must_use] + pub fn meter_provider(&self) -> &SdkMeterProvider { + &self.provider + } + + #[must_use] + pub fn collect_attributes(&self) -> Vec { + self.provider.force_flush().unwrap(); + collect_attributes(&self.exporter) + } + + pub fn assert_attributes(&self, key_values: &[KeyValue], expected_length: Option) { + let attributes = self.collect_attributes(); + + if let Some(expected_length) = expected_length { + assert_eq!( + attributes.len(), + expected_length, + "expected {} attributes, got {}", + expected_length, + attributes.len() + ); + } + + for attr in key_values { + assert!( + attributes.contains(attr), + "attribute {attr:?} not found in collected attributes: {attributes:?}" + ); + } + } +} + +fn collect_attributes(exporter: &InMemoryMetricExporter) -> Vec { + exporter + .get_finished_metrics() + .unwrap() + .iter() + .flat_map(opentelemetry_sdk::metrics::data::ResourceMetrics::scope_metrics) + .flat_map(opentelemetry_sdk::metrics::data::ScopeMetrics::metrics) + .flat_map(collect_attributes_for_metric) + .collect() +} + +// bleh +fn collect_attributes_for_metric(metric: &Metric) -> impl Iterator { + match metric.data() { + AggregatedMetrics::F64(data) => match data { + MetricData::Gauge(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::Sum(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::Histogram(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::ExponentialHistogram(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + }, + AggregatedMetrics::U64(data) => match data { + MetricData::Gauge(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::Sum(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::Histogram(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::ExponentialHistogram(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + }, + AggregatedMetrics::I64(data) => match data { + MetricData::Gauge(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::Sum(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::Histogram(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + MetricData::ExponentialHistogram(data) => data.data_points().flat_map(|v| v.attributes().cloned()).collect::>(), + }, + } + .into_iter() +} + +#[derive(Debug)] +pub(crate) struct RecoverableType(RecoveryInfo); + +impl Recovery for 
RecoverableType { + fn recovery(&self) -> RecoveryInfo { + self.0.clone() + } +} + +impl From for RecoverableType { + fn from(recovery: RecoveryInfo) -> Self { + Self(recovery) + } +} + +/// Thread-local log capture buffer for testing. +/// +/// Uses `tracing_subscriber::fmt::MakeWriter` to capture formatted log output +/// into a thread-local buffer that can be inspected in tests. +#[derive(Debug, Clone, Default)] +pub(crate) struct LogCapture { + buffer: Arc>>, +} + +impl LogCapture { + #[must_use] + pub fn new() -> Self { + Self { + buffer: Arc::new(Mutex::new(Vec::new())), + } + } + + /// Returns the captured log output as a string. + #[must_use] + pub fn output(&self) -> String { + String::from_utf8_lossy(&self.buffer.lock().unwrap()).to_string() + } + + /// Asserts that the captured log output contains the given string. + pub fn assert_contains(&self, expected: &str) { + let output = self.output(); + assert!( + output.contains(expected), + "log output does not contain '{expected}', got:\n{output}" + ); + } + + /// Creates a `tracing_subscriber` that writes to this capture buffer. + /// Use with `set_default()` for thread-local capture. + #[must_use] + pub fn subscriber(&self) -> impl tracing::Subscriber { + use tracing_subscriber::layer::SubscriberExt; + tracing_subscriber::registry().with(tracing_subscriber::fmt::layer().with_writer(self.clone()).with_ansi(false)) + } +} + +impl<'a> MakeWriter<'a> for LogCapture { + type Writer = LogCaptureWriter; + + fn make_writer(&'a self) -> Self::Writer { + LogCaptureWriter { + buffer: Arc::clone(&self.buffer), + } + } +} + +/// Writer that appends to a shared buffer. +pub(crate) struct LogCaptureWriter { + buffer: Arc>>, +} + +impl Write for LogCaptureWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.buffer.lock().unwrap().extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} diff --git a/crates/seatbelt/src/timeout/args.rs b/crates/seatbelt/src/timeout/args.rs new file mode 100644 index 00000000..a88c381b --- /dev/null +++ b/crates/seatbelt/src/timeout/args.rs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +/// Arguments passed to timeout callback functions. +/// +/// Contains information about the timeout event that can be used for logging, +/// metrics, or other side effects when a timeout occurs. +#[derive(Debug)] +#[non_exhaustive] +pub struct OnTimeoutArgs { + pub(super) timeout: Duration, +} + +impl OnTimeoutArgs { + /// Returns the timeout duration that was exceeded. + /// + /// This is the duration after which the operation was canceled. + #[must_use] + pub fn timeout(&self) -> Duration { + self.timeout + } +} + +/// Arguments passed to timeout override functions. +#[derive(Debug)] +#[non_exhaustive] +pub struct TimeoutOverrideArgs { + pub(super) default_timeout: Duration, +} + +impl TimeoutOverrideArgs { + /// Returns the default timeout duration configured for the middleware. + /// + /// This can be used as a base value when calculating dynamic timeouts, or to + /// explicitly reuse the default by returning `None` from the override closure. + #[must_use] + pub fn default_timeout(&self) -> Duration { + self.default_timeout + } +} + +/// Arguments passed to timeout output functions. +/// +/// Contains information about the timeout event that occurred, +/// which can be used to create appropriate timeout responses. 
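+///
+/// For example, a timeout output callback might fold the exceeded duration into
+/// the value it produces (a minimal sketch; the message format is illustrative):
+///
+/// ```rust
+/// use seatbelt::timeout::TimeoutOutputArgs;
+///
+/// // A closure of the shape expected by `timeout_output`, building a response
+/// // string from the timeout that was exceeded.
+/// let make_timeout_output =
+///     |args: TimeoutOutputArgs| format!("timed out after {}ms", args.timeout().as_millis());
+/// # let _ = make_timeout_output;
+/// ```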
+#[derive(Debug)] +pub struct TimeoutOutputArgs { + pub(super) timeout: Duration, +} + +impl TimeoutOutputArgs { + /// Returns the timeout duration that was exceeded. + /// + /// This is the duration after which the operation was canceled. + /// It can be used to provide detailed timeout information in the response. + #[must_use] + pub fn timeout(&self) -> Duration { + self.timeout + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_timeout_ok() { + let args = TimeoutOverrideArgs { + default_timeout: Duration::from_secs(5), + }; + + assert_eq!(args.default_timeout(), Duration::from_secs(5)); + } +} diff --git a/crates/seatbelt/src/timeout/callbacks.rs b/crates/seatbelt/src/timeout/callbacks.rs new file mode 100644 index 00000000..9c25c69f --- /dev/null +++ b/crates/seatbelt/src/timeout/callbacks.rs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::time::Duration; + +use super::{OnTimeoutArgs, TimeoutOutputArgs, TimeoutOverrideArgs}; + +crate::utils::define_fn_wrapper!(TimeoutOutput(Fn(TimeoutOutputArgs) -> Out)); +crate::utils::define_fn_wrapper!(OnTimeout(Fn(&Out, OnTimeoutArgs))); +crate::utils::define_fn_wrapper!(TimeoutOverride(Fn(&In, TimeoutOverrideArgs) -> Option)); diff --git a/crates/seatbelt/src/timeout/layer.rs b/crates/seatbelt/src/timeout/layer.rs new file mode 100644 index 00000000..2aecc1a1 --- /dev/null +++ b/crates/seatbelt/src/timeout/layer.rs @@ -0,0 +1,415 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::borrow::Cow; +use std::marker::PhantomData; +use std::time::Duration; + +use crate::timeout::{ + OnTimeout, OnTimeoutArgs, Timeout, TimeoutOutput as TimeoutOutputCallback, TimeoutOutputArgs, TimeoutOverride, TimeoutOverrideArgs, +}; +use crate::utils::EnableIf; +use crate::utils::TelemetryHelper; +use crate::{NotSet, ResilienceContext, Set}; +use layered::Layer; + +/// Builder for configuring timeout resilience middleware. +/// +/// This type is created by calling [`Timeout::layer`](crate::timeout::Timeout::layer) and uses the +/// type-state pattern to enforce that required properties are configured before the timeout middleware can be built: +/// +/// - [`timeout_output`][TimeoutLayer::timeout_output]: Required to specify how to represent output values when a timeout occurs +/// - [`timeout`][TimeoutLayer::timeout]: Required to set the timeout duration for operations +/// +/// For comprehensive examples, see the [timeout module][crate::timeout] documentation. +/// +/// # Type State +/// +/// - `S1`: Tracks whether [`timeout`][TimeoutLayer::timeout] has been set +/// - `S2`: Tracks whether [`timeout_output`][TimeoutLayer::timeout_output] has been set +#[derive(Debug)] +pub struct TimeoutLayer { + context: ResilienceContext, + timeout: Option, + timeout_output: Option>, + on_timeout: Option>, + enable_if: EnableIf, + telemetry: TelemetryHelper, + timeout_override: Option>, + _state: PhantomData Out>, +} + +impl TimeoutLayer { + #[must_use] + pub(crate) fn new(name: Cow<'static, str>, context: &ResilienceContext) -> Self { + Self { + timeout: None, + timeout_output: None, + on_timeout: None, + enable_if: EnableIf::always(), + telemetry: context.create_telemetry(name), + context: context.clone(), + timeout_override: None, + _state: PhantomData, + } + } +} + +impl TimeoutLayer, S1, S2> { + /// Configures the error value to return when a timeout occurs for Result types. 
+ /// + /// This is a convenience method for Result types that creates an error value + /// when a timeout occurs instead of requiring you to specify the full Result. + /// The error function receives [`TimeoutOutputArgs`] containing timeout context. + /// + /// # Arguments + /// + /// * `timeout_error` - Function that takes [`TimeoutOutputArgs`] and returns + /// the error value to use when a timeout occurs + pub fn timeout_error( + self, + timeout_error: impl Fn(TimeoutOutputArgs) -> E + Send + Sync + 'static, + ) -> TimeoutLayer, S1, Set> { + self.into_state::() + .timeout_output(move |args| Err(timeout_error(args))) + .into_state() + } +} + +impl TimeoutLayer { + /// Sets the timeout duration. + /// + /// This specifies how long to wait before timing out an operation. + /// This call replaces any previous timeout value. + /// + /// # Arguments + /// + /// * `timeout` - The maximum duration to wait for the operation to complete + #[must_use] + pub fn timeout(mut self, timeout: Duration) -> TimeoutLayer { + self.timeout = Some(timeout); + self.into_state::() + } + + /// Sets the timeout result factory function. + /// + /// This function is called when a timeout occurs to create the output value + /// that will be returned instead of the original operation's result. + /// This call replaces any previous timeout output handler. + /// + /// # Arguments + /// + /// * `output` - Function that takes [`TimeoutOutputArgs`] containing timeout + /// context and returns the output value to use when a timeout occurs + #[must_use] + pub fn timeout_output(mut self, output: impl Fn(TimeoutOutputArgs) -> Out + Send + Sync + 'static) -> TimeoutLayer { + self.timeout_output = Some(TimeoutOutputCallback::new(output)); + self.into_state::() + } + + /// Configures a callback invoked when a timeout occurs. + /// + /// This callback is useful for logging, metrics, or other observability + /// purposes. It receives the timeout output and [`OnTimeoutArgs`] with + /// detailed timeout information. + /// + /// The callback does not affect timeout behavior - it's purely for observation. + /// This call replaces any previous callback. + /// + /// **Default**: None (no observability by default) + /// + /// # Arguments + /// + /// * `on_timeout` - Function that takes a reference to the timeout output and + /// [`OnTimeoutArgs`] containing timeout context information + #[must_use] + pub fn on_timeout(mut self, on_timeout: impl Fn(&Out, OnTimeoutArgs) + Send + Sync + 'static) -> Self { + self.on_timeout = Some(OnTimeout::new(on_timeout)); + self + } + + /// Overrides the default timeout on a per-request basis. + /// + /// Use this to compute a timeout dynamically from the input. Return `Some(Duration)` + /// to apply an override, or `None` to fall back to the default timeout configured via + /// [`timeout`][TimeoutLayer::timeout]. The function receives [`TimeoutOverrideArgs`], + /// which exposes the default via [`TimeoutOverrideArgs::default_timeout`]. + /// + /// This call replaces any previous timeout override. 
+ /// + /// **Default**: None (uses default timeout for all requests) + /// + /// # Arguments + /// + /// * `timeout_override` - Function that takes a reference to the input and + /// [`TimeoutOverrideArgs`] containing the default timeout, and returns + /// an optional override duration + #[must_use] + pub fn timeout_override( + mut self, + timeout_override: impl Fn(&In, TimeoutOverrideArgs) -> Option + Send + Sync + 'static, + ) -> Self { + self.timeout_override = Some(TimeoutOverride::new(timeout_override)); + self + } + + /// Optionally enables the timeout middleware based on a condition. + /// + /// When disabled, requests pass through without timeout protection. + /// This call replaces any previous condition. + /// + /// **Default**: Always enabled + /// + /// # Arguments + /// + /// * `is_enabled` - Function that takes a reference to the input and returns + /// `true` if timeout protection should be enabled for this request + #[must_use] + pub fn enable_if(mut self, is_enabled: impl Fn(&In) -> bool + Send + Sync + 'static) -> Self { + self.enable_if = EnableIf::new(is_enabled); + self + } + + /// Enables the timeout middleware unconditionally. + /// + /// All requests will have timeout protection applied. + /// This call replaces any previous condition. + /// + /// **Note**: This is the default behavior - timeout is enabled by default. + #[must_use] + pub fn enable_always(mut self) -> Self { + self.enable_if = EnableIf::always(); + self + } + + /// Disables the timeout middleware completely. + /// + /// All requests will pass through without timeout protection. + /// This call replaces any previous condition. + /// + /// **Note**: This overrides the default enabled behavior. + #[must_use] + pub fn disable(mut self) -> Self { + self.enable_if = EnableIf::never(); + self + } +} + +impl Layer for TimeoutLayer { + type Service = Timeout; + + fn layer(&self, inner: S) -> Self::Service { + Timeout { + inner, + clock: self.context.get_clock().clone(), + timeout: self.timeout.expect("timeout must be set in Ready state"), + enable_if: self.enable_if.clone(), + on_timeout: self.on_timeout.clone(), + timeout_output: self.timeout_output.clone().expect("timeout_result must be set in Ready state"), + timeout_override: self.timeout_override.clone(), + #[cfg(any(feature = "logs", feature = "metrics", test))] + telemetry: self.telemetry.clone(), + } + } +} + +impl TimeoutLayer { + fn into_state(self) -> TimeoutLayer { + TimeoutLayer { + timeout: self.timeout, + enable_if: self.enable_if, + timeout_output: self.timeout_output, + on_timeout: self.on_timeout, + telemetry: self.telemetry, + context: self.context, + timeout_override: self.timeout_override, + _state: PhantomData, + } + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use std::fmt::Debug; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + + use layered::Execute; + use tick::Clock; + + use super::*; + + #[test] + fn new_needs_timeout_output() { + let context = create_test_context(); + let layer: TimeoutLayer<_, _, NotSet, NotSet> = TimeoutLayer::new("test_timeout".into(), &context); + + assert!(layer.timeout.is_none()); + assert!(layer.timeout_output.is_none()); + assert!(layer.on_timeout.is_none()); + assert!(layer.timeout_override.is_none()); + assert_eq!(layer.telemetry.strategy_name.as_ref(), "test_timeout"); + assert!(layer.enable_if.call(&"test_input".to_string())); + } + + #[test] + fn timeout_output_ensure_set_correctly() { + let context = create_test_context(); + let layer = 
TimeoutLayer::new("test".into(), &context); + + let layer: TimeoutLayer<_, _, NotSet, Set> = layer.timeout_output(|args| format!("timeout: {}", args.timeout().as_millis())); + let result = layer.timeout_output.unwrap().call(TimeoutOutputArgs { + timeout: Duration::from_millis(3), + }); + + assert_eq!(result, "timeout: 3"); + } + + #[test] + fn timeout_error_ensure_set_correctly() { + let context = create_test_context_result(); + let layer = TimeoutLayer::new("test".into(), &context); + + let layer: TimeoutLayer<_, _, NotSet, Set> = layer.timeout_error(|args| format!("timeout: {}", args.timeout().as_millis())); + let result = layer + .timeout_output + .unwrap() + .call(TimeoutOutputArgs { + timeout: Duration::from_millis(3), + }) + .unwrap_err(); + + assert_eq!(result, "timeout: 3"); + } + + #[test] + fn timeout_ensure_set_correctly() { + let layer: TimeoutLayer<_, _, Set, Set> = TimeoutLayer::new("test".into(), &create_test_context()) + .timeout_output(|_args| "timeout: ".to_string()) + .timeout(Duration::from_millis(3)); + + assert_eq!(layer.timeout.unwrap(), Duration::from_millis(3)); + } + + #[test] + fn on_timeout_ok() { + let called = Arc::new(AtomicBool::new(false)); + let called_clone = Arc::clone(&called); + + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().on_timeout(move |_output, _args| { + called_clone.store(true, Ordering::SeqCst); + }); + + layer.on_timeout.unwrap().call( + &"output".to_string(), + OnTimeoutArgs { + timeout: Duration::from_millis(3), + }, + ); + + assert!(called.load(Ordering::SeqCst)); + } + + #[test] + fn timeout_override_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().timeout_override(|_input, _args| Some(Duration::from_secs(3))); + + let result = layer.timeout_override.unwrap().call( + &"a".to_string(), + TimeoutOverrideArgs { + default_timeout: Duration::from_millis(3), + }, + ); + + assert_eq!(result, Some(Duration::from_secs(3))); + } + + #[test] + fn enable_if_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().enable_if(|input| matches!(input.as_ref(), "enable")); + + assert!(layer.enable_if.call(&"enable".to_string())); + assert!(!layer.enable_if.call(&"disable".to_string())); + } + + #[test] + fn disable_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().disable(); + + assert!(!layer.enable_if.call(&"whatever".to_string())); + } + + #[test] + fn enable_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().disable().enable_always(); + + assert!(layer.enable_if.call(&"whatever".to_string())); + } + + #[test] + fn timeout_when_ready_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().timeout(Duration::from_secs(123)); + + assert_eq!(layer.timeout.unwrap(), Duration::from_secs(123)); + } + + #[test] + fn timeout_output_when_ready_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().timeout_output(|_args| "some new value".to_string()); + assert!(layer.timeout_output.is_some()); + let result = layer.timeout_output.unwrap().call(TimeoutOutputArgs { + timeout: Duration::from_secs(123), + }); + + assert_eq!(result, "some new value"); + } + + #[test] + fn timeout_error_when_ready_ok() { + let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer_with_result().timeout_error(|_args| "some error value".to_string()); + assert!(layer.timeout_output.is_some()); + let result = layer + .timeout_output + .unwrap() + .call(TimeoutOutputArgs { + timeout: Duration::from_secs(123), + }) + .unwrap_err(); + + assert_eq!(result, "some 
error value"); + } + + #[test] + fn layer_ok() { + let _layered = create_ready_layer().layer(Execute::new(|input: String| async move { input })); + } + + #[test] + fn static_assertions() { + static_assertions::assert_impl_all!(TimeoutLayer: Layer); + static_assertions::assert_not_impl_all!(TimeoutLayer: Layer); + static_assertions::assert_not_impl_all!(TimeoutLayer: Layer); + static_assertions::assert_impl_all!(TimeoutLayer: Debug); + } + + fn create_test_context() -> ResilienceContext { + ResilienceContext::new(Clock::new_frozen()).name("test_pipeline") + } + + fn create_test_context_result() -> ResilienceContext> { + ResilienceContext::new(Clock::new_frozen()).name("test_pipeline") + } + + fn create_ready_layer() -> TimeoutLayer { + TimeoutLayer::new("test".into(), &create_test_context()) + .timeout_output(|_args| "timeout: ".to_string()) + .timeout(Duration::from_millis(3)) + } + + fn create_ready_layer_with_result() -> TimeoutLayer, Set, Set> { + TimeoutLayer::new("test".into(), &create_test_context_result()) + .timeout_error(|_args| "timeout: ".to_string()) + .timeout(Duration::from_millis(3)) + } +} diff --git a/crates/seatbelt/src/timeout/mod.rs b/crates/seatbelt/src/timeout/mod.rs new file mode 100644 index 00000000..8fc4e238 --- /dev/null +++ b/crates/seatbelt/src/timeout/mod.rs @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Timeout resilience middleware for services, applications, and libraries. +//! +//! This module provides timeout functionality to cancel long-running operations and prevent +//! services from hanging indefinitely when processing requests. The primary types are +//! [`Timeout`] and [`TimeoutLayer`]: +//! +//! - [`Timeout`] is the middleware that wraps an inner service and enforces timeout behavior +//! - [`TimeoutLayer`] is used to configure and construct the timeout middleware +//! +//! # Quick Start +//! +//! ```rust +//! # use std::time::Duration; +//! # use std::io; +//! # use tick::Clock; +//! # use layered::{Execute, Service, Stack}; +//! # use seatbelt::timeout::Timeout; +//! # use seatbelt::ResilienceContext; +//! # async fn example(clock: Clock) -> Result<(), io::Error> { +//! let context = ResilienceContext::new(&clock).name("my_service"); +//! +//! let stack = ( +//! Timeout::layer("timeout", &context) +//! .timeout_error(|_| io::Error::new(io::ErrorKind::TimedOut, "operation timed out")) +//! .timeout(Duration::from_secs(30)), +//! Execute::new(my_operation), +//! ); +//! +//! let service = stack.into_service(); +//! let result = service.execute("input".to_string()).await; +//! # Ok(()) +//! # } +//! # async fn my_operation(input: String) -> Result { Ok(input) } +//! ``` +//! +//! # Configuration +//! +//! The [`TimeoutLayer`] uses a type state pattern to enforce that all required properties are +//! configured before the layer can be built. This compile-time safety ensures that you cannot +//! accidentally create a timeout layer without properly specifying the timeout duration +//! and the timeout output generator: +//! +//! - [`timeout_output`][TimeoutLayer::timeout_output] or [`timeout_error`][TimeoutLayer::timeout_error]: Required function to generate output when timeout occurs +//! - [`timeout`][TimeoutLayer::timeout]: Required timeout duration for operations +//! +//! Each timeout layer requires an identifier for telemetry purposes. This identifier should use +//! `snake_case` naming convention to maintain consistency across the codebase. +//! +//! 
The default timeout is configured via [`TimeoutLayer::timeout`]. You can override that +//! per request with [`TimeoutLayer::timeout_override`]. +//! +//! # Defaults +//! +//! The timeout middleware uses the following default values when optional configuration is not provided: +//! +//! | Parameter | Default Value | Description | Configured By | +//! |-----------|---------------|-------------|---------------| +//! | Timeout duration | `None` (required) | Maximum duration to wait for operation completion | [`timeout`][TimeoutLayer::timeout] | +//! | Timeout output | `None` (required) | Output value to return when timeout occurs | [`timeout_output`][TimeoutLayer::timeout_output], [`timeout_error`][TimeoutLayer::timeout_error] | +//! | Timeout override | `None` | Uses default timeout for all requests | [`timeout_override`][TimeoutLayer::timeout_override] | +//! | On timeout callback | `None` | No observability by default | [`on_timeout`][TimeoutLayer::on_timeout] | +//! | Enable condition | Always enabled | Timeout protection is applied to all requests | [`enable_if`][TimeoutLayer::enable_if], [`enable_always`][TimeoutLayer::enable_always], [`disable`][TimeoutLayer::disable] | +//! +//! Unlike other middleware, timeout requires explicit configuration of both the timeout duration +//! and the output generator function, as there are no reasonable universal defaults for these values. +//! +//! # Thread Safety +//! +//! The [`Timeout`] type is thread-safe and implements both `Send` and `Sync` as enforced by +//! the `Service` trait it implements. This allows timeout middleware to be safely shared +//! across multiple threads and used in concurrent environments. +//! +//! # Telemetry +//! +//! ## Metrics +//! +//! - **Metric**: `resilience.event` (counter) +//! - **When**: Emitted when a timeout occurs +//! - **Attributes**: +//! - `resilience.pipeline.name`: Pipeline identifier from [`ResilienceContext::name`][crate::ResilienceContext::name] +//! - `resilience.strategy.name`: Timeout identifier from [`Timeout::layer`] +//! - `resilience.event.name`: Always `timeout` +//! +//! # Examples +//! +//! ## Basic Usage +//! +//! This example demonstrates the basic usage of configuring and using timeout middleware. +//! +//! ```rust +//! # use std::time::Duration; +//! # use tick::Clock; +//! # use layered::{Execute, Service, Stack}; +//! # use seatbelt::ResilienceContext; +//! # use seatbelt::timeout::Timeout; +//! # async fn example(clock: Clock) -> Result<(), Box> { +//! // Define common options for resilience middleware. The clock is runtime-specific and +//! // must be provided. See its documentation for details. +//! let context = ResilienceContext::new(&clock); +//! +//! let stack = ( +//! Timeout::layer("my_timeout", &context) +//! // Required: timeout middleware needs to know what output to return when timeout occurs +//! .timeout_output(|args| { +//! format!("timeout error, duration: {}ms", args.timeout().as_millis()) +//! }) +//! // Required: timeout duration must be set +//! .timeout(Duration::from_secs(30)), +//! Execute::new(execute_unreliable_operation), +//! ); +//! +//! // Build the service +//! let service = stack.into_service(); +//! +//! // Execute the service +//! let result = service.execute("quick".to_string()).await; +//! # let _result = result; +//! # Ok(()) +//! # } +//! +//! # async fn execute_unreliable_operation(input: String) -> String { input } +//! ``` +//! +//! ## Advanced Usage +//! +//! 
This example demonstrates advanced usage of the timeout middleware, including working with +//! Result-based outputs, custom configurations, and timeout overrides. +//! +//! ```rust +//! # use std::time::Duration; +//! # use std::io; +//! # use tick::Clock; +//! # use layered::{Execute, Service, Stack}; +//! # use seatbelt::ResilienceContext; +//! # use seatbelt::timeout::Timeout; +//! # async fn example(clock: Clock) -> Result<(), io::Error> { +//! // Define common options for resilience middleware. +//! let context = ResilienceContext::new(&clock); +//! +//! let stack = ( +//! Timeout::layer("my_timeout", &context) +//! // Return an error for Result outputs on timeout +//! .timeout_error(|args| io::Error::new(io::ErrorKind::TimedOut, "request timed out")) +//! // Default timeout +//! .timeout(Duration::from_secs(30)) +//! // Callback for when a timeout occurs +//! .on_timeout(|_output: &Result, args| { +//! println!("timeout occurred after {}ms", args.timeout().as_millis()); +//! }) +//! // Provide per-input timeout overrides (fallback to default on None) +//! .timeout_override(|input: &String, args| { +//! match input.as_str() { +//! "quick" => Some(Duration::from_secs(5)), // override +//! "slow" => Some(Duration::from_secs(60)), +//! _ => None, // use default (args.default_timeout())! +//! } +//! }) +//! // Optionally disable timeouts for some inputs +//! .enable_if(|input: &String| !input.starts_with("bypass_")), +//! Execute::new(execute_unreliable_operation), +//! ); +//! +//! // Build and execute the service +//! let service = stack.into_service(); +//! let result = service.execute("quick".to_string()).await?; +//! # let _result = result; +//! # Ok(()) +//! # } +//! # async fn execute_unreliable_operation(input: String) -> Result { Ok(input) } +//! ``` +mod args; +mod callbacks; +mod layer; +mod service; + +#[cfg(any(feature = "metrics", test))] +mod telemetry; + +pub use args::{OnTimeoutArgs, TimeoutOutputArgs, TimeoutOverrideArgs}; +pub(crate) use callbacks::{OnTimeout, TimeoutOutput, TimeoutOverride}; +pub use layer::TimeoutLayer; +pub use service::Timeout; diff --git a/crates/seatbelt/src/timeout/service.rs b/crates/seatbelt/src/timeout/service.rs new file mode 100644 index 00000000..86fa095f --- /dev/null +++ b/crates/seatbelt/src/timeout/service.rs @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::borrow::Cow; +use std::time::Duration; + +use futures::future::Either; +use layered::Service; +use tick::{Clock, FutureExt}; + +use crate::timeout::{OnTimeout, OnTimeoutArgs, TimeoutLayer, TimeoutOutput, TimeoutOutputArgs, TimeoutOverride, TimeoutOverrideArgs}; +use crate::utils::EnableIf; +use crate::{NotSet, ResilienceContext}; + +/// Applies a timeout to service execution for canceling long-running operations. +/// +/// `Timeout` wraps an inner [`Service`] and enforces a maximum duration for +/// each call. If the operation doesn't finish in time, an output that represents +/// the timeout is returned. This middleware is designed to be used across +/// services, applications, and libraries to prevent operations from hanging indefinitely. +/// +/// Timeouts are configured by calling [`Timeout::layer`](crate::timeout::Timeout::layer) +/// and using the builder methods on the returned [`TimeoutLayer`] instance. +/// +/// For comprehensive examples and usage patterns, see the [timeout module] documentation. 
+/// +/// [timeout module]: crate::timeout +#[derive(Debug)] +#[expect(clippy::struct_field_names, reason = "fields are named for clarity")] +pub struct Timeout { + pub(super) inner: S, + pub(super) clock: Clock, + pub(super) timeout: Duration, + pub(super) enable_if: EnableIf, + pub(super) on_timeout: Option>, + pub(super) timeout_override: Option>, + pub(super) timeout_output: TimeoutOutput, + #[cfg(any(feature = "logs", feature = "metrics", test))] + pub(super) telemetry: crate::utils::TelemetryHelper, +} + +impl Timeout { + /// Creates a [`TimeoutLayer`] used to configure the timeout resilience middleware. + /// + /// The instance returned by this call is a builder and cannot be used to build a + /// service until the required properties are set: `timeout_output` and `timeout`. + /// The `name` identifies the timeout strategy in telemetry, while `options` + /// provides configuration shared across multiple resilience middleware. + /// + /// # Example + /// + /// ```rust + /// # use std::time::Duration; + /// # use layered::{Execute, Stack}; + /// # use tick::Clock; + /// # use seatbelt::ResilienceContext; + /// use seatbelt::timeout::Timeout; + /// + /// # fn example(context: ResilienceContext) { + /// let timeout_layer = Timeout::layer("my_timeout", &context) + /// .timeout_output(|args| format!("timed out after {}ms", args.timeout().as_millis())) + /// .timeout(Duration::from_secs(30)); + /// # } + /// ``` + /// + /// For comprehensive examples, see the [timeout module] documentation. + /// + /// [timeout module]: crate::timeout + pub fn layer(name: impl Into>, context: &ResilienceContext) -> TimeoutLayer { + TimeoutLayer::new(name.into(), context) + } +} + +impl Service for Timeout +where + In: Send, + S: Service, +{ + type Out = Out; + + #[cfg_attr(test, mutants::skip)] // causes test timeouts + fn execute(&self, input: In) -> impl Future + Send { + if !self.enable_if.call(&input) { + return Either::Left(self.inner.execute(input)); + } + + let timeout = self + .timeout_override + .as_ref() + .and_then(|provider| { + provider.call( + &input, + TimeoutOverrideArgs { + default_timeout: self.timeout, + }, + ) + }) + .unwrap_or(self.timeout); + + Either::Right(async move { + match self.inner.execute(input).timeout(&self.clock, timeout).await { + Ok(output) => output, + Err(_error) => { + #[cfg(any(feature = "metrics", test))] + if self.telemetry.metrics_enabled() { + use crate::utils::{EVENT_NAME, PIPELINE_NAME, STRATEGY_NAME}; + + self.telemetry.report_metrics(&[ + opentelemetry::KeyValue::new(PIPELINE_NAME, self.telemetry.pipeline_name.clone()), + opentelemetry::KeyValue::new(STRATEGY_NAME, self.telemetry.strategy_name.clone()), + opentelemetry::KeyValue::new(EVENT_NAME, super::telemetry::TIMEOUT_EVENT_NAME), + ]); + } + + let output = self.timeout_output.call(TimeoutOutputArgs { timeout }); + + #[cfg(any(feature = "logs", test))] + if self.telemetry.logs_enabled { + tracing::event!( + name: "seatbelt.timeout", + tracing::Level::WARN, + pipeline.name = %self.telemetry.pipeline_name, + strategy.name = %self.telemetry.strategy_name, + timeout.ms = timeout.as_millis(), + ); + } + + if let Some(on_timeout) = &self.on_timeout { + on_timeout.call(&output, OnTimeoutArgs { timeout }); + } + + output + } + } + }) + } +} + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(not(miri))] // tokio runtime does not support Miri. 
+#[cfg(test)] +mod tests { + use layered::{Execute, Stack}; + use tick::ClockControl; + + use super::*; + + #[tokio::test] + async fn no_timeout() { + let clock = Clock::new_frozen(); + let context = ResilienceContext::new(clock); + + let stack = ( + Timeout::layer("test_timeout", &context) + .timeout_output(|args| format!("timed out after {}ms", args.timeout().as_millis())) + .timeout(Duration::from_secs(5)), + Execute::new(|input: String| async move { input }), + ); + + let service = stack.into_service(); + + let output = service.execute("test input".to_string()).await; + + assert_eq!(output, "test input".to_string()); + } + + #[tokio::test] + async fn timeout() { + let clock = ClockControl::default() + .auto_advance(Duration::from_millis(200)) + .auto_advance_limit(Duration::from_millis(500)) + .to_clock(); + let context = ResilienceContext::new(clock.clone()); + let called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let called_clone = std::sync::Arc::clone(&called); + + let stack = ( + Timeout::layer("test_timeout", &context) + .timeout_output(|args| format!("timed out after {}ms", args.timeout().as_millis())) + .timeout(Duration::from_millis(200)) + .on_timeout(move |out, args| { + assert_eq!("timed out after 200ms", out.as_str()); + assert_eq!(200, args.timeout().as_millis()); + called.store(true, std::sync::atomic::Ordering::SeqCst); + }), + Execute::new(move |input| { + let clock = clock.clone(); + async move { + clock.delay(Duration::from_secs(1)).await; + input + } + }), + ); + + let service = stack.into_service(); + + let output = service.execute("test input".to_string()).await; + + assert_eq!(output, "timed out after 200ms"); + assert!(called_clone.load(std::sync::atomic::Ordering::SeqCst)); + } + + #[tokio::test] + async fn timeout_override_ensure_respected() { + let clock = ClockControl::default() + .auto_advance(Duration::from_millis(200)) + .auto_advance_limit(Duration::from_millis(5000)) + .to_clock(); + + let stack = ( + Timeout::layer("test_timeout", &ResilienceContext::new(clock.clone())) + .timeout_output(|args| format!("timed out after {}ms", args.timeout().as_millis())) + .timeout(Duration::from_millis(200)) + .timeout_override(|input, _args| { + if input == "ignore" { + return None; + } + + Some(Duration::from_millis(150)) + }), + Execute::new(move |input| { + let clock = clock.clone(); + async move { + clock.delay(Duration::from_secs(10)).await; + input + } + }), + ); + + let service = stack.into_service(); + + assert_eq!(service.execute("test input".to_string()).await, "timed out after 150ms"); + assert_eq!(service.execute("ignore".to_string()).await, "timed out after 200ms"); + } + + #[tokio::test] + async fn no_timeout_if_disabled() { + let clock = ClockControl::default().auto_advance_timers(true).to_clock(); + let stack = ( + Timeout::layer("test_timeout", &ResilienceContext::new(&clock)) + .timeout_output(|_args| "timed out".to_string()) + .timeout(Duration::from_millis(200)) + .disable(), + Execute::new({ + let clock = clock.clone(); + move |input| { + let clock = clock.clone(); + async move { + clock.delay(Duration::from_secs(1)).await; + input + } + } + }), + ); + + let service = stack.into_service(); + let output = service.execute("test input".to_string()).await; + + assert_eq!(output, "test input"); + } + + #[tokio::test] + async fn timeout_emits_log() { + use tracing_subscriber::util::SubscriberInitExt; + + use crate::testing::LogCapture; + + let log_capture = LogCapture::new(); + let _guard = 
log_capture.subscriber().set_default(); + + let clock = ClockControl::default() + .auto_advance(Duration::from_millis(200)) + .auto_advance_limit(Duration::from_millis(500)) + .to_clock(); + let context = ResilienceContext::new(clock.clone()).enable_logs().name("log_test_pipeline"); + + let stack = ( + Timeout::layer("log_test_timeout", &context) + .timeout_output(|_| "timed out".to_string()) + .timeout(Duration::from_millis(100)), + Execute::new(move |input| { + let clock = clock.clone(); + async move { + clock.delay(Duration::from_secs(1)).await; + input + } + }), + ); + + let service = stack.into_service(); + let _ = service.execute("test".to_string()).await; + + log_capture.assert_contains("seatbelt::timeout"); + log_capture.assert_contains("log_test_pipeline"); + log_capture.assert_contains("log_test_timeout"); + log_capture.assert_contains("timeout.ms=100"); + } + + #[tokio::test] + async fn timeout_emits_metrics() { + use opentelemetry::KeyValue; + + use crate::testing::MetricTester; + use crate::utils::{EVENT_NAME, PIPELINE_NAME, STRATEGY_NAME}; + + let metrics = MetricTester::new(); + let clock = ClockControl::default() + .auto_advance(Duration::from_millis(200)) + .auto_advance_limit(Duration::from_millis(500)) + .to_clock(); + let context = ResilienceContext::new(clock.clone()) + .enable_metrics(metrics.meter_provider()) + .name("metrics_pipeline"); + + let stack = ( + Timeout::layer("metrics_timeout", &context) + .timeout_output(|_| "timed out".to_string()) + .timeout(Duration::from_millis(100)), + Execute::new(move |input| { + let clock = clock.clone(); + async move { + clock.delay(Duration::from_secs(1)).await; + input + } + }), + ); + + let service = stack.into_service(); + let _ = service.execute("test".to_string()).await; + + metrics.assert_attributes( + &[ + KeyValue::new(PIPELINE_NAME, "metrics_pipeline"), + KeyValue::new(STRATEGY_NAME, "metrics_timeout"), + KeyValue::new(EVENT_NAME, "timeout"), + ], + Some(3), + ); + } +} diff --git a/crates/seatbelt/src/timeout/telemetry.rs b/crates/seatbelt/src/timeout/telemetry.rs new file mode 100644 index 00000000..989a3ec4 --- /dev/null +++ b/crates/seatbelt/src/timeout/telemetry.rs @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/// The name of the timeout event for telemetry reporting. +pub(super) const TIMEOUT_EVENT_NAME: &str = "timeout"; diff --git a/crates/seatbelt/src/utils/attributes.rs b/crates/seatbelt/src/utils/attributes.rs new file mode 100644 index 00000000..51c1e46e --- /dev/null +++ b/crates/seatbelt/src/utils/attributes.rs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/// Key used to annotate the name of a resilience pipeline. +/// +/// Values reported under this dimension should be short and concise, preferably in `snake_case`. +/// Examples: `user_auth`, `data_processing`, `payment_flow`. +#[cfg(any(feature = "metrics", test))] +pub(crate) const PIPELINE_NAME: &str = "resilience.pipeline.name"; + +/// Key used to annotate the name of a resilience strategy. +/// +/// Values reported under this dimension should be short and concise, preferably in `snake_case`. +/// Examples: `retry`, `circuit_breaker`, `timeout`, `bulkhead`. +#[cfg(any(feature = "metrics", test))] +pub(crate) const STRATEGY_NAME: &str = "resilience.strategy.name"; + +/// Key used to annotate the specific resilience event being emitted. +/// +/// Values reported under this dimension should be short and concise, preferably in `snake_case`. 
+/// Examples: `retry`, `timeout`, `circuit_opened`. +#[cfg(any(feature = "metrics", test))] +pub(crate) const EVENT_NAME: &str = "resilience.event.name"; + +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pipeline_name_is_expected() { + assert_eq!(PIPELINE_NAME, "resilience.pipeline.name"); + } + + #[test] + fn test_strategy_name_is_expected() { + assert_eq!(STRATEGY_NAME, "resilience.strategy.name"); + } + + #[test] + fn test_event_name_is_expected() { + assert_eq!(EVENT_NAME, "resilience.event.name"); + } +} diff --git a/crates/seatbelt/src/utils/define_fn_wrapper.rs b/crates/seatbelt/src/utils/define_fn_wrapper.rs new file mode 100644 index 00000000..01a4f576 --- /dev/null +++ b/crates/seatbelt/src/utils/define_fn_wrapper.rs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/// A macro to generate `Fn` like wrapper types with consistent patterns. +/// +/// This macro generates a type that wraps a function in an `Arc`, +/// providing `Clone`, `Debug`, and convenient constructor methods. We need this to allow storing +/// user-provided functions (e.g., predicates) in a thread-safe, clonable way. +/// +/// # Syntax +/// +/// ```rust,ignore +/// define_fn_wrapper!(TypeName(Fn(args) -> ReturnType)); +/// ``` +/// +/// # Example +/// +/// ```rust,ignore +/// define_fn_wrapper!(ShouldRetry(Fn(&Res, ShouldRetryArgs) -> Recovery)); +/// ``` +/// +/// This generates a `ShouldRetry` struct with methods: +/// - `new(predicate: F) -> Self` where `F: Fn(...) + Send + Sync + 'static` +/// - `call(&self, args...) -> ReturnType` to invoke the wrapped function +/// - `Clone` and `Debug` implementations +macro_rules! define_fn_wrapper { + // Match pattern: Name(Fn(param_name: param_type, ...) 
-> return_type)
+    ($name:ident<$($generics:ident),*>(Fn($($param_name:ident: $param_ty:ty),*) -> $return_ty:ty)) => {
+        pub(crate) struct $name<$($generics),*>(std::sync::Arc<dyn Fn($($param_ty),*) -> $return_ty + Send + Sync>);
+
+        impl<$($generics),*> $name<$($generics),*> {
+            pub(crate) fn new<F>(predicate: F) -> Self
+            where
+                F: Fn($($param_ty),*) -> $return_ty + Send + Sync + 'static,
+            {
+                Self(std::sync::Arc::new(predicate))
+            }
+
+            pub(crate) fn call(&self, $($param_name: $param_ty),*) -> $return_ty {
+                (self.0)($($param_name),*)
+            }
+        }
+
+        impl<$($generics),*> Clone for $name<$($generics),*> {
+            fn clone(&self) -> Self {
+                Self(self.0.clone())
+            }
+        }
+
+        impl<$($generics),*> std::fmt::Debug for $name<$($generics),*> {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                f.debug_struct(stringify!($name)).finish()
+            }
+        }
+    };
+
+    // Match pattern without return type (defaults to unit)
+    ($name:ident<$($generics:ident),*>(Fn($($param_name:ident: $param_ty:ty),*))) => {
+        crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn($($param_name: $param_ty),*) -> ()));
+    };
+
+    // Alternative match for simple cases without explicit parameter names
+    // For two parameters
+    ($name:ident<$($generics:ident),*>(Fn($param1:ty, $param2:ty) -> $return_ty:ty)) => {
+        crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn(arg1: $param1, arg2: $param2) -> $return_ty));
+    };
+
+    // For two parameters without return type
+    ($name:ident<$($generics:ident),*>(Fn($param1:ty, $param2:ty))) => {
+        crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn(arg1: $param1, arg2: $param2) -> ()));
+    };
+
+    // For one parameter
+    ($name:ident<$($generics:ident),*>(Fn($param1:ty) -> $return_ty:ty)) => {
+        crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn(arg1: $param1) -> $return_ty));
+    };
+
+    // For one parameter without return type
+    ($name:ident<$($generics:ident),*>(Fn($param1:ty))) => {
+        crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn(arg1: $param1) -> ()));
+    };
+
+    // For zero parameters
+    ($name:ident<$($generics:ident),*>(Fn() -> $return_ty:ty)) => {
+        $crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn() -> $return_ty));
+    };
+
+    // For zero parameters without return type
+    ($name:ident<$($generics:ident),*>(Fn())) => {
+        $crate::utils::define_fn_wrapper!($name<$($generics),*>(Fn() -> ()));
+    };
+
+    // Match pattern without return type (defaults to unit)
+    ($name:ident(Fn($($param_name:ident: $param_ty:ty),*))) => {
+        $crate::utils::define_fn_wrapper!($name(Fn($($param_name: $param_ty),*) -> ()));
+    };
+
+    // Alternative match for simple cases without explicit parameter names
+    // For two parameters
+    ($name:ident(Fn($param1:ty, $param2:ty) -> $return_ty:ty)) => {
+        $crate::utils::define_fn_wrapper!($name(Fn(arg1: $param1, arg2: $param2) -> $return_ty));
+    };
+
+    // For two parameters without return type
+    ($name:ident(Fn($param1:ty, $param2:ty))) => {
+        $crate::utils::define_fn_wrapper!($name(Fn(arg1: $param1, arg2: $param2) -> ()));
+    };
+
+    // For one parameter
+    ($name:ident(Fn($param1:ty) -> $return_ty:ty)) => {
+        $crate::utils::define_fn_wrapper!($name(Fn(arg1: $param1) -> $return_ty));
+    };
+
+    // For one parameter without return type
+    ($name:ident(Fn($param1:ty))) => {
+        $crate::utils::define_fn_wrapper!($name(Fn(arg1: $param1) -> ()));
+    };
+
+    // For zero parameters
+    ($name:ident(Fn() -> $return_ty:ty)) => {
+        $crate::utils::define_fn_wrapper!($name(Fn() -> $return_ty));
+    };
+
+    // For zero parameters without return type
+    ($name:ident(Fn())) => {
+        $crate::utils::define_fn_wrapper!($name(Fn() -> ()));
+    };
+}
+
+pub(crate) use define_fn_wrapper;
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+#[cfg(test)]
+mod tests {
+    use std::fmt::Debug;
+
+    define_fn_wrapper!(InOut<In, Out>(Fn(&In) -> Out));
+
+    #[test]
+    fn static_assertions() {
+        static_assertions::assert_impl_all!(InOut<String, String>: Send, Sync, Debug, Clone);
+    }
+
+    #[test]
+    fn call_ok() {
+        let wrapper = InOut::new(|input: &String| input.clone());
+
+        let result = wrapper.call(&"Hello, World!".to_string());
+        assert_eq!(result, "Hello, World!".to_string());
+
+        // A cloned wrapper invokes the same underlying predicate.
+        let wrapper = wrapper.clone();
+        let result = wrapper.call(&"Hello, World!".to_string());
+        assert_eq!(result, "Hello, World!".to_string());
+    }
+
+    #[test]
+    fn debug_ok() {
+        let wrapper = InOut::new(|input: &String| input.clone());
+
+        let debug_str = format!("{wrapper:?}");
+
+        assert_eq!(debug_str, "InOut");
+    }
+}
diff --git a/crates/seatbelt/src/utils/mod.rs b/crates/seatbelt/src/utils/mod.rs
new file mode 100644
index 00000000..c9459081
--- /dev/null
+++ b/crates/seatbelt/src/utils/mod.rs
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+mod define_fn_wrapper;
+pub(crate) use define_fn_wrapper::define_fn_wrapper;
+
+#[cfg(any(feature = "metrics", test))]
+mod attributes;
+#[cfg(any(feature = "metrics", test))]
+pub(crate) use attributes::*;
+
+mod telemetry_helper;
+pub(crate) use telemetry_helper::TelemetryHelper;
+
+define_fn_wrapper!(EnableIf<In>(Fn(&In) -> bool));
+
+impl<In> EnableIf<In> {
+    /// Creates a new `EnableIf` instance that always returns `true`.
+    pub fn always() -> Self {
+        Self::new(|_| true)
+    }
+
+    /// Creates a new `EnableIf` instance that always returns `false`.
+    pub fn never() -> Self {
+        Self::new(|_| false)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn enable_if_debug() {
+        let enable_if: EnableIf<String> = EnableIf::always();
+        assert_eq!(format!("{enable_if:?}"), "EnableIf");
+    }
+}
diff --git a/crates/seatbelt/src/utils/telemetry_helper.rs b/crates/seatbelt/src/utils/telemetry_helper.rs
new file mode 100644
index 00000000..ab56633f
--- /dev/null
+++ b/crates/seatbelt/src/utils/telemetry_helper.rs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#[derive(Debug, Clone)]
+pub(crate) struct TelemetryHelper {
+    #[cfg(any(feature = "metrics", feature = "logs", test))]
+    pub(crate) pipeline_name: std::borrow::Cow<'static, str>,
+    #[cfg(any(feature = "metrics", feature = "logs", test))]
+    pub(crate) strategy_name: std::borrow::Cow<'static, str>,
+    #[cfg(any(feature = "metrics", test))]
+    pub(crate) event_reporter: Option<opentelemetry::metrics::Counter<u64>>,
+    #[cfg(any(feature = "logs", test))]
+    pub(crate) logs_enabled: bool,
+}
+
+impl TelemetryHelper {
+    #[cfg(any(feature = "metrics", test))]
+    pub fn metrics_enabled(&self) -> bool {
+        self.event_reporter.is_some()
+    }
+
+    #[cfg(any(feature = "metrics", test))]
+    pub fn report_metrics(&self, attributes: &[opentelemetry::KeyValue]) {
+        if let Some(reporter) = &self.event_reporter {
+            reporter.add(1, attributes);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn metrics_enabled_returns_false_when_no_reporter() {
+        let helper = TelemetryHelper {
+            pipeline_name: "test".into(),
+            strategy_name: "test".into(),
+            event_reporter: None,
+            logs_enabled: false,
+        };
+        assert!(!helper.metrics_enabled());
+    }
+}
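Reviewer note (not part of the change): for readers who prefer not to mentally expand `define_fn_wrapper!`, the hand-written sketch below approximates what `define_fn_wrapper!(EnableIf<In>(Fn(&In) -> bool));` produces. It is illustrative only and was not generated from the macro; the real output may differ in minor details.

```rust
use std::sync::Arc;

// Approximate expansion of `define_fn_wrapper!(EnableIf<In>(Fn(&In) -> bool));`.
pub(crate) struct EnableIf<In>(Arc<dyn Fn(&In) -> bool + Send + Sync>);

impl<In> EnableIf<In> {
    // Wraps a user-provided predicate in an `Arc` so it can be stored and shared.
    pub(crate) fn new<F>(predicate: F) -> Self
    where
        F: Fn(&In) -> bool + Send + Sync + 'static,
    {
        Self(Arc::new(predicate))
    }

    // Invokes the wrapped predicate.
    pub(crate) fn call(&self, input: &In) -> bool {
        (self.0)(input)
    }
}

impl<In> Clone for EnableIf<In> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<In> std::fmt::Debug for EnableIf<In> {
    // Closures are not `Debug`, so only the wrapper's type name is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EnableIf").finish()
    }
}
```

Holding the callback behind `Arc<dyn Fn(...) + Send + Sync>` keeps each wrapper cheap to clone and safe to share across threads, which is why the timeout middleware can store user-supplied hooks such as `enable_if`, `timeout_override`, and `on_timeout` by value.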