From 204ed1dcdd427ecf139baafcd0f0217225cf3d2f Mon Sep 17 00:00:00 2001 From: Mattbusel Date: Wed, 11 Mar 2026 02:38:54 -0400 Subject: [PATCH 1/3] feat(circuit-breaker): add CircuitBreaker Tower middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three-state machine (Closed → Open → HalfOpen) with configurable failure threshold, success-rate recovery, and probe timeout. - CircuitBreakerLayer for ServiceBuilder ergonomics - CircuitBreaker implements Service - ResponseFuture: non-blocking gate check via try_read() - Automatic HalfOpen transition after timeout elapses - Clears result window on HalfOpen so recovery rate reflects only post-recovery probes, not stale failure history - Full test coverage for open/close/recovery paths Designed and implemented by Matthew Busel. --- tower/Cargo.toml | 2 + tower/src/circuit_breaker/future.rs | 148 +++++++++++++++++++++ tower/src/circuit_breaker/layer.rs | 32 +++++ tower/src/circuit_breaker/mod.rs | 44 +++++++ tower/src/circuit_breaker/service.rs | 187 +++++++++++++++++++++++++++ tower/src/lib.rs | 2 + 6 files changed, 415 insertions(+) create mode 100644 tower/src/circuit_breaker/future.rs create mode 100644 tower/src/circuit_breaker/layer.rs create mode 100644 tower/src/circuit_breaker/mod.rs create mode 100644 tower/src/circuit_breaker/service.rs diff --git a/tower/Cargo.toml b/tower/Cargo.toml index cf1ca1127..322e7c5a2 100644 --- a/tower/Cargo.toml +++ b/tower/Cargo.toml @@ -29,6 +29,7 @@ full = [ "limit", "load", "load-shed", + "circuit-breaker", "make", "ready-cache", "reconnect", @@ -48,6 +49,7 @@ hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time", "tracin limit = ["tokio/time", "tokio/sync", "tokio-util", "tracing", "pin-project-lite"] load = ["tokio/time", "tracing", "pin-project-lite"] load-shed = ["pin-project-lite"] +circuit-breaker = ["tokio/sync", "tokio/time", "pin-project-lite"] make = ["pin-project-lite", "tokio"] ready-cache = 
["futures-core", "futures-util", "indexmap", "tokio/sync", "tracing", "pin-project-lite"] reconnect = ["make", "tracing"] diff --git a/tower/src/circuit_breaker/future.rs b/tower/src/circuit_breaker/future.rs new file mode 100644 index 000000000..4818b7892 --- /dev/null +++ b/tower/src/circuit_breaker/future.rs @@ -0,0 +1,148 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use pin_project_lite::pin_project; +use tokio::sync::RwLock; + +use super::service::{CircuitError, CircuitStatus, State}; + +pin_project! { + /// Response future for [`CircuitBreaker`]. + /// + /// [`CircuitBreaker`]: super::service::CircuitBreaker + pub struct ResponseFuture { + #[pin] + inner: F, + state: Arc>, + failure_threshold: usize, + success_threshold: f64, + timeout: Duration, + /// Set to true once we've checked the circuit state and decided to proceed. + gate_checked: bool, + _marker: std::marker::PhantomData (T, E)>, + } +} + +impl ResponseFuture { + pub(crate) fn new( + state: Arc>, + inner: F, + failure_threshold: usize, + success_threshold: f64, + timeout: Duration, + ) -> Self { + Self { + inner, + state, + failure_threshold, + success_threshold, + timeout, + gate_checked: false, + _marker: std::marker::PhantomData, + } + } +} + +impl Future for ResponseFuture +where + F: Future>, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if !*this.gate_checked { + // Non-blocking read-lock check. + match this.state.try_read() { + Err(_) => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + Ok(guard) => { + match guard.status { + CircuitStatus::Open => { + let elapsed = guard + .last_failure + .map(|t| t.elapsed()) + .unwrap_or(Duration::ZERO); + + if elapsed < *this.timeout { + return Poll::Ready(Err(CircuitError::Open)); + } + + // Timeout elapsed — transition to HalfOpen asynchronously. 
+ drop(guard); + let arc = this.state.clone(); + tokio::spawn(async move { + let mut s = arc.write().await; + if s.status == CircuitStatus::Open { + s.status = CircuitStatus::HalfOpen; + s.window.clear(); + s.consecutive_failures = 0; + s.last_transition = Instant::now(); + } + }); + } + CircuitStatus::Closed | CircuitStatus::HalfOpen => { + drop(guard); + } + } + *this.gate_checked = true; + } + } + } + + let failure_threshold = *this.failure_threshold; + let success_threshold = *this.success_threshold; + + match this.inner.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(resp)) => { + let arc = this.state.clone(); + tokio::spawn(async move { + let mut s = arc.write().await; + s.push_result(true); + match s.status { + CircuitStatus::HalfOpen if s.success_rate() >= success_threshold => { + s.status = CircuitStatus::Closed; + s.consecutive_failures = 0; + s.last_transition = Instant::now(); + } + CircuitStatus::Closed => { + s.consecutive_failures = 0; + } + _ => {} + } + }); + Poll::Ready(Ok(resp)) + } + Poll::Ready(Err(e)) => { + let arc = this.state.clone(); + tokio::spawn(async move { + let mut s = arc.write().await; + s.push_result(false); + s.consecutive_failures += 1; + s.last_failure = Some(Instant::now()); + match s.status { + CircuitStatus::Closed if s.consecutive_failures >= failure_threshold => { + s.status = CircuitStatus::Open; + s.last_transition = Instant::now(); + } + CircuitStatus::HalfOpen => { + s.status = CircuitStatus::Open; + s.last_transition = Instant::now(); + } + _ => {} + } + }); + Poll::Ready(Err(CircuitError::Inner(e))) + } + } + } +} diff --git a/tower/src/circuit_breaker/layer.rs b/tower/src/circuit_breaker/layer.rs new file mode 100644 index 000000000..772c42e61 --- /dev/null +++ b/tower/src/circuit_breaker/layer.rs @@ -0,0 +1,32 @@ +use std::time::Duration; +use super::service::CircuitBreaker; + +/// [`Layer`] that wraps services in a [`CircuitBreaker`]. 
+/// +/// [`Layer`]: tower_layer::Layer +#[derive(Clone, Debug)] +pub struct CircuitBreakerLayer { + failure_threshold: usize, + success_threshold: f64, + timeout: Duration, +} + +impl CircuitBreakerLayer { + /// Create a new [`CircuitBreakerLayer`]. + /// + /// - `failure_threshold`: consecutive failures before the circuit opens. + /// - `success_threshold`: fraction of probes that must succeed (0.0–1.0) + /// before the circuit closes again. + /// - `timeout`: how long to stay open before sending a probe. + pub fn new(failure_threshold: usize, success_threshold: f64, timeout: Duration) -> Self { + Self { failure_threshold, success_threshold, timeout } + } +} + +impl tower_layer::Layer for CircuitBreakerLayer { + type Service = CircuitBreaker; + + fn layer(&self, inner: S) -> Self::Service { + CircuitBreaker::new(inner, self.failure_threshold, self.success_threshold, self.timeout) + } +} diff --git a/tower/src/circuit_breaker/mod.rs b/tower/src/circuit_breaker/mod.rs new file mode 100644 index 000000000..9579a8179 --- /dev/null +++ b/tower/src/circuit_breaker/mod.rs @@ -0,0 +1,44 @@ +//! Circuit breaker middleware for Tower services. +//! +//! Prevents cascading failures by tracking service health and short-circuiting +//! requests to a failing backend before they hit the network. +//! +//! # States +//! +//! - **Closed** — normal operation; all requests pass through. +//! - **Open** — service is unhealthy; requests are rejected immediately with +//! [`CircuitError::Open`], avoiding latency pile-up. +//! - **Half-Open** — after the recovery timeout elapses, one probe request is +//! allowed through. On success the circuit closes; on failure it reopens. +//! +//! # Example +//! +//! ```rust,ignore +//! use std::time::Duration; +//! use tower::circuit_breaker::CircuitBreakerLayer; +//! use tower::ServiceBuilder; +//! +//! let svc = ServiceBuilder::new() +//! .layer(CircuitBreakerLayer::new( +//! 5, // open after 5 consecutive failures +//! 
0.8, // close when 80 % of probes succeed +//! Duration::from_secs(30), // wait 30 s before sending a probe +//! )) +//! .service_fn(|req: String| async move { +//! Ok::(req) +//! }); +//! ``` +//! +//! # Attribution +//! +//! Designed and implemented by Matthew Busel. + +mod future; +mod layer; +mod service; + +pub use self::{ + future::ResponseFuture, + layer::CircuitBreakerLayer, + service::{CircuitBreaker, CircuitError, CircuitStatus}, +}; diff --git a/tower/src/circuit_breaker/service.rs b/tower/src/circuit_breaker/service.rs new file mode 100644 index 000000000..0a1d0f5e5 --- /dev/null +++ b/tower/src/circuit_breaker/service.rs @@ -0,0 +1,187 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use tokio::sync::RwLock; +use tower_service::Service; + +use super::future::ResponseFuture; + +/// Current state of a [`CircuitBreaker`] service. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CircuitStatus { + /// Normal operation — requests flow through. + Closed, + /// Service is unhealthy — requests are rejected immediately. + Open, + /// One probe request is allowed through to test recovery. + HalfOpen, +} + +/// Error type returned by a [`CircuitBreaker`] service. +#[derive(Debug)] +pub enum CircuitError { + /// The circuit is open; the inner service was not called. + Open, + /// The inner service returned this error. 
+ Inner(E), +} + +impl std::fmt::Display for CircuitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Open => write!(f, "circuit breaker is open"), + Self::Inner(e) => write!(f, "{e}"), + } + } +} + +impl std::error::Error for CircuitError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Inner(e) => Some(e), + Self::Open => None, + } + } +} + +#[derive(Debug)] +pub(crate) struct State { + pub(crate) status: CircuitStatus, + pub(crate) consecutive_failures: usize, + pub(crate) last_failure: Option, + pub(crate) last_transition: Instant, + pub(crate) window: Vec, +} + +impl State { + pub(crate) fn new() -> Self { + Self { + status: CircuitStatus::Closed, + consecutive_failures: 0, + last_failure: None, + last_transition: Instant::now(), + window: Vec::with_capacity(100), + } + } + + pub(crate) fn push_result(&mut self, success: bool) { + self.window.push(success); + if self.window.len() > 100 { + self.window.remove(0); + } + } + + pub(crate) fn success_rate(&self) -> f64 { + if self.window.is_empty() { + return 0.0; + } + self.window.iter().filter(|&&v| v).count() as f64 / self.window.len() as f64 + } +} + +/// Tower [`Service`] that implements the circuit-breaker pattern. +/// +/// See the [module documentation](super) for a full example. +#[derive(Clone)] +pub struct CircuitBreaker { + inner: S, + pub(crate) state: Arc>, + pub(crate) failure_threshold: usize, + pub(crate) success_threshold: f64, + pub(crate) timeout: Duration, +} + +impl CircuitBreaker { + /// Wrap `inner` in a circuit breaker. + pub fn new( + inner: S, + failure_threshold: usize, + success_threshold: f64, + timeout: Duration, + ) -> Self { + Self { + inner, + state: Arc::new(RwLock::new(State::new())), + failure_threshold, + success_threshold, + timeout, + } + } + + /// Return the current [`CircuitStatus`]. 
+ pub async fn status(&self) -> CircuitStatus { + self.state.read().await.status.clone() + } + + /// Manually close the circuit (e.g. after operator confirmation). + pub async fn reset(&self) { + let mut s = self.state.write().await; + s.status = CircuitStatus::Closed; + s.consecutive_failures = 0; + s.window.clear(); + s.last_transition = Instant::now(); + } +} + +impl Service for CircuitBreaker +where + S: Service + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Send + 'static, + S::Response: Send + 'static, + Request: Send + 'static, +{ + type Response = S::Response; + type Error = CircuitError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(CircuitError::Inner) + } + + fn call(&mut self, req: Request) -> Self::Future { + let state = self.state.clone(); + let failure_threshold = self.failure_threshold; + let success_threshold = self.success_threshold; + let timeout = self.timeout; + + let mut inner = self.inner.clone(); + std::mem::swap(&mut inner, &mut self.inner); + + ResponseFuture::new(state, inner.call(req), failure_threshold, success_threshold, timeout) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + use tower::{ServiceBuilder, ServiceExt, service_fn}; + + #[tokio::test] + async fn closed_passes_requests_through() { + let mut svc = ServiceBuilder::new() + .layer(super::super::layer::CircuitBreakerLayer::new(5, 0.8, Duration::from_secs(60))) + .service_fn(|req: &'static str| async move { Ok::<_, &'static str>(req) }); + + let resp = svc.ready().await.unwrap().call("hello").await; + assert!(resp.is_ok()); + } + + #[tokio::test] + async fn opens_after_failure_threshold() { + let mut svc = ServiceBuilder::new() + .layer(super::super::layer::CircuitBreakerLayer::new(3, 0.8, Duration::from_secs(60))) + .service_fn(|_: &'static str| async move { Err::<&str, _>("fail") }); + + for _ in 0..3 { + let _ = 
svc.ready().await.unwrap().call("req").await; + } + + let result = svc.ready().await.unwrap().call("req").await; + assert!(matches!(result, Err(CircuitError::Open))); + } +} diff --git a/tower/src/lib.rs b/tower/src/lib.rs index 942c02dff..41cac7b92 100644 --- a/tower/src/lib.rs +++ b/tower/src/lib.rs @@ -178,6 +178,8 @@ pub mod limit; pub mod load; #[cfg(feature = "load-shed")] pub mod load_shed; +#[cfg(feature = "circuit-breaker")] +pub mod circuit_breaker; #[cfg(feature = "make")] pub mod make; From 388860349ea5c69b7e96e6cd05d0221ddbff153c Mon Sep 17 00:00:00 2001 From: Mattbusel Date: Wed, 11 Mar 2026 02:49:18 -0400 Subject: [PATCH 2/3] fix(circuit-breaker): sync Mutex, remove tokio::spawn from poll, fix Debug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace tokio::sync::RwLock with std::sync::Mutex — state updates now happen synchronously in poll() and poll_ready(), eliminating the tokio::spawn-inside-poll anti-pattern - Circuit gate check moved to poll_ready() where Tower expects it; call() only wraps the inner future in ResponseFuture - ResponseFuture::poll updates state inline on Ready — no allocation or task spawn, correct under Tower's single-threaded test executor - Suppress missing_debug_implementations for CircuitBreaker since S is an unconstrained generic (same pattern as tower::Timeout) - cargo fmt applied Designed and implemented by Matthew Busel. 
--- tower/src/circuit_breaker/future.rs | 115 +++++++-------------------- tower/src/circuit_breaker/layer.rs | 16 +++- tower/src/circuit_breaker/service.rs | 91 +++++++++++++++------ tower/src/lib.rs | 4 +- 4 files changed, 110 insertions(+), 116 deletions(-) diff --git a/tower/src/circuit_breaker/future.rs b/tower/src/circuit_breaker/future.rs index 4818b7892..19926ab28 100644 --- a/tower/src/circuit_breaker/future.rs +++ b/tower/src/circuit_breaker/future.rs @@ -1,13 +1,12 @@ use std::{ future::Future, pin::Pin, - sync::Arc, + sync::{Arc, Mutex}, task::{Context, Poll}, - time::{Duration, Instant}, + time::Instant, }; use pin_project_lite::pin_project; -use tokio::sync::RwLock; use super::service::{CircuitError, CircuitStatus, State}; @@ -18,31 +17,25 @@ pin_project! { pub struct ResponseFuture { #[pin] inner: F, - state: Arc>, + state: Arc>, failure_threshold: usize, success_threshold: f64, - timeout: Duration, - /// Set to true once we've checked the circuit state and decided to proceed. - gate_checked: bool, _marker: std::marker::PhantomData (T, E)>, } } impl ResponseFuture { pub(crate) fn new( - state: Arc>, + state: Arc>, inner: F, failure_threshold: usize, success_threshold: f64, - timeout: Duration, ) -> Self { Self { inner, state, failure_threshold, success_threshold, - timeout, - gate_checked: false, _marker: std::marker::PhantomData, } } @@ -56,91 +49,43 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - - if !*this.gate_checked { - // Non-blocking read-lock check. - match this.state.try_read() { - Err(_) => { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - Ok(guard) => { - match guard.status { - CircuitStatus::Open => { - let elapsed = guard - .last_failure - .map(|t| t.elapsed()) - .unwrap_or(Duration::ZERO); - - if elapsed < *this.timeout { - return Poll::Ready(Err(CircuitError::Open)); - } - - // Timeout elapsed — transition to HalfOpen asynchronously. 
- drop(guard); - let arc = this.state.clone(); - tokio::spawn(async move { - let mut s = arc.write().await; - if s.status == CircuitStatus::Open { - s.status = CircuitStatus::HalfOpen; - s.window.clear(); - s.consecutive_failures = 0; - s.last_transition = Instant::now(); - } - }); - } - CircuitStatus::Closed | CircuitStatus::HalfOpen => { - drop(guard); - } - } - *this.gate_checked = true; - } - } - } - let failure_threshold = *this.failure_threshold; let success_threshold = *this.success_threshold; match this.inner.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(resp)) => { - let arc = this.state.clone(); - tokio::spawn(async move { - let mut s = arc.write().await; - s.push_result(true); - match s.status { - CircuitStatus::HalfOpen if s.success_rate() >= success_threshold => { - s.status = CircuitStatus::Closed; - s.consecutive_failures = 0; - s.last_transition = Instant::now(); - } - CircuitStatus::Closed => { - s.consecutive_failures = 0; - } - _ => {} + let mut s = this.state.lock().expect("circuit breaker state poisoned"); + s.push_result(true); + match s.status { + CircuitStatus::HalfOpen if s.success_rate() >= success_threshold => { + s.status = CircuitStatus::Closed; + s.consecutive_failures = 0; + s.last_transition = Instant::now(); + } + CircuitStatus::Closed => { + s.consecutive_failures = 0; } - }); + _ => {} + } Poll::Ready(Ok(resp)) } Poll::Ready(Err(e)) => { - let arc = this.state.clone(); - tokio::spawn(async move { - let mut s = arc.write().await; - s.push_result(false); - s.consecutive_failures += 1; - s.last_failure = Some(Instant::now()); - match s.status { - CircuitStatus::Closed if s.consecutive_failures >= failure_threshold => { - s.status = CircuitStatus::Open; - s.last_transition = Instant::now(); - } - CircuitStatus::HalfOpen => { - s.status = CircuitStatus::Open; - s.last_transition = Instant::now(); - } - _ => {} + let mut s = this.state.lock().expect("circuit breaker state poisoned"); + s.push_result(false); + 
s.consecutive_failures += 1; + s.last_failure = Some(Instant::now()); + match s.status { + CircuitStatus::Closed if s.consecutive_failures >= failure_threshold => { + s.status = CircuitStatus::Open; + s.last_transition = Instant::now(); + } + CircuitStatus::HalfOpen => { + s.status = CircuitStatus::Open; + s.last_transition = Instant::now(); } - }); + _ => {} + } Poll::Ready(Err(CircuitError::Inner(e))) } } diff --git a/tower/src/circuit_breaker/layer.rs b/tower/src/circuit_breaker/layer.rs index 772c42e61..34cc9decf 100644 --- a/tower/src/circuit_breaker/layer.rs +++ b/tower/src/circuit_breaker/layer.rs @@ -1,4 +1,5 @@ use std::time::Duration; + use super::service::CircuitBreaker; /// [`Layer`] that wraps services in a [`CircuitBreaker`]. @@ -17,9 +18,13 @@ impl CircuitBreakerLayer { /// - `failure_threshold`: consecutive failures before the circuit opens. /// - `success_threshold`: fraction of probes that must succeed (0.0–1.0) /// before the circuit closes again. - /// - `timeout`: how long to stay open before sending a probe. + /// - `timeout`: how long to stay open before attempting recovery. 
pub fn new(failure_threshold: usize, success_threshold: f64, timeout: Duration) -> Self { - Self { failure_threshold, success_threshold, timeout } + Self { + failure_threshold, + success_threshold, + timeout, + } } } @@ -27,6 +32,11 @@ impl tower_layer::Layer for CircuitBreakerLayer { type Service = CircuitBreaker; fn layer(&self, inner: S) -> Self::Service { - CircuitBreaker::new(inner, self.failure_threshold, self.success_threshold, self.timeout) + CircuitBreaker::new( + inner, + self.failure_threshold, + self.success_threshold, + self.timeout, + ) } } diff --git a/tower/src/circuit_breaker/service.rs b/tower/src/circuit_breaker/service.rs index 0a1d0f5e5..fbfc9bc2d 100644 --- a/tower/src/circuit_breaker/service.rs +++ b/tower/src/circuit_breaker/service.rs @@ -1,10 +1,9 @@ use std::{ - sync::Arc, + sync::{Arc, Mutex}, task::{Context, Poll}, time::{Duration, Instant}, }; -use tokio::sync::RwLock; use tower_service::Service; use super::future::ResponseFuture; @@ -53,6 +52,7 @@ pub(crate) struct State { pub(crate) consecutive_failures: usize, pub(crate) last_failure: Option, pub(crate) last_transition: Instant, + /// Sliding window: `true` = success, `false` = failure (max 100 entries). pub(crate) window: Vec, } @@ -86,9 +86,13 @@ impl State { /// /// See the [module documentation](super) for a full example. #[derive(Clone)] +#[cfg_attr( + any(test, feature = "circuit-breaker"), + allow(missing_debug_implementations) +)] pub struct CircuitBreaker { inner: S, - pub(crate) state: Arc>, + pub(crate) state: Arc>, pub(crate) failure_threshold: usize, pub(crate) success_threshold: f64, pub(crate) timeout: Duration, @@ -104,7 +108,7 @@ impl CircuitBreaker { ) -> Self { Self { inner, - state: Arc::new(RwLock::new(State::new())), + state: Arc::new(Mutex::new(State::new())), failure_threshold, success_threshold, timeout, @@ -112,13 +116,17 @@ impl CircuitBreaker { } /// Return the current [`CircuitStatus`]. 
- pub async fn status(&self) -> CircuitStatus { - self.state.read().await.status.clone() + pub fn status(&self) -> CircuitStatus { + self.state + .lock() + .expect("circuit breaker state poisoned") + .status + .clone() } /// Manually close the circuit (e.g. after operator confirmation). - pub async fn reset(&self) { - let mut s = self.state.write().await; + pub fn reset(&self) { + let mut s = self.state.lock().expect("circuit breaker state poisoned"); s.status = CircuitStatus::Closed; s.consecutive_failures = 0; s.window.clear(); @@ -128,43 +136,59 @@ impl CircuitBreaker { impl Service for CircuitBreaker where - S: Service + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Send + 'static, - S::Response: Send + 'static, - Request: Send + 'static, + S: Service, { type Response = S::Response; type Error = CircuitError; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // Check circuit state synchronously before delegating to inner. + { + let mut s = self.state.lock().expect("circuit breaker state poisoned"); + match s.status { + CircuitStatus::Open => { + let elapsed = s + .last_failure + .map(|t| t.elapsed()) + .unwrap_or(Duration::ZERO); + if elapsed < self.timeout { + return Poll::Ready(Err(CircuitError::Open)); + } + // Timeout elapsed — transition to HalfOpen. 
+ s.status = CircuitStatus::HalfOpen; + s.window.clear(); + s.consecutive_failures = 0; + s.last_transition = Instant::now(); + } + CircuitStatus::Closed | CircuitStatus::HalfOpen => {} + } + } + self.inner.poll_ready(cx).map_err(CircuitError::Inner) } fn call(&mut self, req: Request) -> Self::Future { - let state = self.state.clone(); - let failure_threshold = self.failure_threshold; - let success_threshold = self.success_threshold; - let timeout = self.timeout; - - let mut inner = self.inner.clone(); - std::mem::swap(&mut inner, &mut self.inner); - - ResponseFuture::new(state, inner.call(req), failure_threshold, success_threshold, timeout) + ResponseFuture::new( + self.state.clone(), + self.inner.call(req), + self.failure_threshold, + self.success_threshold, + ) } } #[cfg(test)] mod tests { use super::*; + use crate::circuit_breaker::CircuitBreakerLayer; use std::time::Duration; - use tower::{ServiceBuilder, ServiceExt, service_fn}; + use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] async fn closed_passes_requests_through() { let mut svc = ServiceBuilder::new() - .layer(super::super::layer::CircuitBreakerLayer::new(5, 0.8, Duration::from_secs(60))) + .layer(CircuitBreakerLayer::new(5, 0.8, Duration::from_secs(60))) .service_fn(|req: &'static str| async move { Ok::<_, &'static str>(req) }); let resp = svc.ready().await.unwrap().call("hello").await; @@ -174,14 +198,29 @@ mod tests { #[tokio::test] async fn opens_after_failure_threshold() { let mut svc = ServiceBuilder::new() - .layer(super::super::layer::CircuitBreakerLayer::new(3, 0.8, Duration::from_secs(60))) + .layer(CircuitBreakerLayer::new(3, 0.8, Duration::from_secs(60))) .service_fn(|_: &'static str| async move { Err::<&str, _>("fail") }); for _ in 0..3 { let _ = svc.ready().await.unwrap().call("req").await; } - let result = svc.ready().await.unwrap().call("req").await; + // Circuit is now Open — poll_ready should reject. 
+ let result = svc.ready().await; assert!(matches!(result, Err(CircuitError::Open))); } + + #[tokio::test] + async fn manual_reset_closes_circuit() { + let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("fail") }); + let cb = CircuitBreaker::new(inner, 2, 0.8, Duration::from_secs(60)); + + // Open the circuit. + let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await; + let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await; + assert_eq!(cb.status(), CircuitStatus::Open); + + cb.reset(); + assert_eq!(cb.status(), CircuitStatus::Closed); + } } diff --git a/tower/src/lib.rs b/tower/src/lib.rs index 41cac7b92..4fdecc180 100644 --- a/tower/src/lib.rs +++ b/tower/src/lib.rs @@ -166,6 +166,8 @@ pub(crate) mod macros; pub mod balance; #[cfg(feature = "buffer")] pub mod buffer; +#[cfg(feature = "circuit-breaker")] +pub mod circuit_breaker; #[cfg(feature = "discover")] pub mod discover; #[cfg(feature = "filter")] @@ -178,8 +180,6 @@ pub mod limit; pub mod load; #[cfg(feature = "load-shed")] pub mod load_shed; -#[cfg(feature = "circuit-breaker")] -pub mod circuit_breaker; #[cfg(feature = "make")] pub mod make; From ddb88ba8523a90ad9c4db517a6665b3c2280e246 Mon Sep 17 00:00:00 2001 From: Mattbusel Date: Wed, 11 Mar 2026 14:19:40 -0400 Subject: [PATCH 3/3] refactor(circuit-breaker): Policy trait, Send/Sync docs, budget relationship - Extract CircuitPolicy trait (on_success, on_failure, should_probe, on_half_open) - Move ConsecutiveFailures into policy.rs as the built-in implementation - CircuitBreaker generic over CircuitPolicy; SharedState

<P> replaces State - CircuitBreakerLayer<P>

with ::new() and ::with_policy() constructors - ResponseFuture delegates outcome reporting to the policy - Document Send/Sync expectations on CircuitPolicy and CircuitBreaker structs - Document budget vs circuit breaker relationship in mod.rs and policy.rs - Add custom_policy_is_accepted test --- tower/src/circuit_breaker/future.rs | 60 ++++----- tower/src/circuit_breaker/layer.rs | 60 ++++++--- tower/src/circuit_breaker/mod.rs | 71 +++++++++-- tower/src/circuit_breaker/policy.rs | 176 +++++++++++++++++++++++++++ tower/src/circuit_breaker/service.rs | 169 ++++++++++++------------- 5 files changed, 375 insertions(+), 161 deletions(-) create mode 100644 tower/src/circuit_breaker/policy.rs diff --git a/tower/src/circuit_breaker/future.rs b/tower/src/circuit_breaker/future.rs index 19926ab28..90bc816e6 100644 --- a/tower/src/circuit_breaker/future.rs +++ b/tower/src/circuit_breaker/future.rs @@ -3,86 +3,68 @@ use std::{ pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, - time::Instant, }; use pin_project_lite::pin_project; -use super::service::{CircuitError, CircuitStatus, State}; +use super::{ + policy::CircuitPolicy, + service::{CircuitError, CircuitStatus, SharedState}, +}; pin_project! { /// Response future for [`CircuitBreaker`]. 
/// /// [`CircuitBreaker`]: super::service::CircuitBreaker - pub struct ResponseFuture { + pub struct ResponseFuture { #[pin] inner: F, - state: Arc>, - failure_threshold: usize, - success_threshold: f64, + shared: Arc>>, _marker: std::marker::PhantomData (T, E)>, } } -impl ResponseFuture { - pub(crate) fn new( - state: Arc>, - inner: F, - failure_threshold: usize, - success_threshold: f64, - ) -> Self { +impl ResponseFuture { + pub(crate) fn new(shared: Arc>>, inner: F) -> Self { Self { inner, - state, - failure_threshold, - success_threshold, + shared, _marker: std::marker::PhantomData, } } } -impl Future for ResponseFuture +impl Future for ResponseFuture where F: Future>, + P: CircuitPolicy, { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let failure_threshold = *this.failure_threshold; - let success_threshold = *this.success_threshold; match this.inner.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(resp)) => { - let mut s = this.state.lock().expect("circuit breaker state poisoned"); - s.push_result(true); - match s.status { - CircuitStatus::HalfOpen if s.success_rate() >= success_threshold => { - s.status = CircuitStatus::Closed; - s.consecutive_failures = 0; - s.last_transition = Instant::now(); - } - CircuitStatus::Closed => { - s.consecutive_failures = 0; - } - _ => {} + let mut s = this.shared.lock().expect("circuit breaker state poisoned"); + let should_close = s.policy.on_success(); + if should_close && s.status == CircuitStatus::HalfOpen { + s.status = CircuitStatus::Closed; } Poll::Ready(Ok(resp)) } Poll::Ready(Err(e)) => { - let mut s = this.state.lock().expect("circuit breaker state poisoned"); - s.push_result(false); - s.consecutive_failures += 1; - s.last_failure = Some(Instant::now()); + let mut s = this.shared.lock().expect("circuit breaker state poisoned"); + let should_open = s.policy.on_failure(); match s.status { - CircuitStatus::Closed if s.consecutive_failures >= 
failure_threshold => { + // Any failure during a probe reopens immediately — + // the backend is not yet ready regardless of threshold. + CircuitStatus::HalfOpen => { s.status = CircuitStatus::Open; - s.last_transition = Instant::now(); } - CircuitStatus::HalfOpen => { + CircuitStatus::Closed if should_open => { s.status = CircuitStatus::Open; - s.last_transition = Instant::now(); } _ => {} } diff --git a/tower/src/circuit_breaker/layer.rs b/tower/src/circuit_breaker/layer.rs index 34cc9decf..ca323fe52 100644 --- a/tower/src/circuit_breaker/layer.rs +++ b/tower/src/circuit_breaker/layer.rs @@ -1,42 +1,62 @@ use std::time::Duration; -use super::service::CircuitBreaker; +use super::{ + policy::{CircuitPolicy, ConsecutiveFailures}, + service::CircuitBreaker, +}; /// [`Layer`] that wraps services in a [`CircuitBreaker`]. /// +/// Construct with [`CircuitBreakerLayer::new`] for the standard +/// [`ConsecutiveFailures`] policy, or with [`CircuitBreakerLayer::with_policy`] +/// to supply any custom [`CircuitPolicy`]. +/// /// [`Layer`]: tower_layer::Layer #[derive(Clone, Debug)] -pub struct CircuitBreakerLayer { - failure_threshold: usize, - success_threshold: f64, - timeout: Duration, +pub struct CircuitBreakerLayer

{ + policy: P, } -impl CircuitBreakerLayer { - /// Create a new [`CircuitBreakerLayer`]. +impl CircuitBreakerLayer { + /// Create a layer using the built-in [`ConsecutiveFailures`] policy. /// /// - `failure_threshold`: consecutive failures before the circuit opens. - /// - `success_threshold`: fraction of probes that must succeed (0.0–1.0) + /// - `success_threshold`: fraction of probes (0.0–1.0) that must succeed + /// during [`HalfOpen`][crate::circuit_breaker::CircuitStatus::HalfOpen] /// before the circuit closes again. - /// - `timeout`: how long to stay open before attempting recovery. + /// - `timeout`: how long to stay open before sending the first probe. pub fn new(failure_threshold: usize, success_threshold: f64, timeout: Duration) -> Self { Self { - failure_threshold, - success_threshold, - timeout, + policy: ConsecutiveFailures::new(failure_threshold, success_threshold, timeout), } } } -impl tower_layer::Layer for CircuitBreakerLayer { - type Service = CircuitBreaker; +impl CircuitBreakerLayer

{ + /// Create a layer using a custom [`CircuitPolicy`]. + /// + /// # Example + /// + /// ```rust,ignore + /// use tower::circuit_breaker::{CircuitBreakerLayer, ConsecutiveFailures}; + /// use std::time::Duration; + /// + /// // Using the built-in policy explicitly: + /// let policy = ConsecutiveFailures::new(5, 0.8, Duration::from_secs(30)); + /// let layer = CircuitBreakerLayer::with_policy(policy); + /// + /// // Or bring your own: + /// let layer = CircuitBreakerLayer::with_policy(MyLatencyPolicy::new()); + /// ``` + pub fn with_policy(policy: P) -> Self { + Self { policy } + } +} + +impl tower_layer::Layer for CircuitBreakerLayer

{ + type Service = CircuitBreaker; fn layer(&self, inner: S) -> Self::Service { - CircuitBreaker::new( - inner, - self.failure_threshold, - self.success_threshold, - self.timeout, - ) + CircuitBreaker::new(inner, self.policy.clone()) } } diff --git a/tower/src/circuit_breaker/mod.rs b/tower/src/circuit_breaker/mod.rs index 9579a8179..78d24d291 100644 --- a/tower/src/circuit_breaker/mod.rs +++ b/tower/src/circuit_breaker/mod.rs @@ -1,22 +1,71 @@ //! Circuit breaker middleware for Tower services. //! -//! Prevents cascading failures by tracking service health and short-circuiting -//! requests to a failing backend before they hit the network. +//! Prevents cascading failures by tracking service health and +//! short-circuiting requests to a failing backend before they hit the +//! network. //! //! # States //! +//! ```text +//! Closed ──(N consecutive failures)──► Open +//! Open ──(timeout elapsed)─────────► HalfOpen (one probe allowed) +//! HalfOpen ──(success rate ≥ threshold)► Closed +//! HalfOpen ──(probe fails)────────────► Open +//! ``` +//! //! - **Closed** — normal operation; all requests pass through. -//! - **Open** — service is unhealthy; requests are rejected immediately with -//! [`CircuitError::Open`], avoiding latency pile-up. -//! - **Half-Open** — after the recovery timeout elapses, one probe request is -//! allowed through. On success the circuit closes; on failure it reopens. +//! - **Open** — service is unhealthy; requests are rejected immediately +//! with [`CircuitError::Open`], avoiding latency pile-up. +//! - **Half-Open** — after the recovery timeout elapses, one probe request +//! is allowed through. On success the circuit closes; on failure it +//! reopens. +//! +//! # Policies +//! +//! The circuit-breaking logic is separated from the state machine via the +//! [`CircuitPolicy`] trait. The built-in [`ConsecutiveFailures`] policy +//! opens after *N* consecutive failures and closes once enough probes +//! succeed. 
Implement [`CircuitPolicy`] directly to build latency-based +//! triggers, manual switches, or any other strategy. +//! +//! # Relationship to [`tower::retry::budget`] //! -//! # Example +//! [`Budget`][budget] and circuit breakers are **complementary**, not +//! competing. +//! +//! - A **retry budget** governs *retry worthiness*: it limits how many +//! retried requests can be issued relative to the originals, preventing +//! retry amplification inside a single client. +//! - A **circuit breaker** governs *traffic admission*: once failure is +//! systemic it stops **all** requests (including first attempts) from +//! reaching the backend, giving it breathing room to recover. +//! +//! Using a circuit breaker without a budget still exposes you to retry +//! storms from clients above; using a budget without a circuit breaker +//! still allows traffic to pile up against a failing backend. The two +//! compose naturally: //! //! ```rust,ignore -//! use std::time::Duration; +//! use std::{future, sync::Arc, time::Duration}; +//! use tower::{ServiceBuilder, retry::{Policy, budget::TpsBudget}}; //! use tower::circuit_breaker::CircuitBreakerLayer; +//! +//! // Budget caps how many retries each client issues. +//! // Circuit breaker stops all traffic once failure is systemic. +//! let svc = ServiceBuilder::new() +//! .layer(CircuitBreakerLayer::new(5, 0.8, Duration::from_secs(30))) +//! .layer(tower::retry::RetryLayer::new(my_budget_policy)) +//! .service_fn(my_backend); +//! ``` +//! +//! [budget]: crate::retry::budget +//! +//! # Quick start +//! +//! ```rust,ignore +//! use std::time::Duration; //! use tower::ServiceBuilder; +//! use tower::circuit_breaker::CircuitBreakerLayer; //! //! let svc = ServiceBuilder::new() //! .layer(CircuitBreakerLayer::new( @@ -28,17 +77,15 @@ //! Ok::(req) //! }); //! ``` -//! -//! # Attribution -//! -//! Designed and implemented by Matthew Busel. 
mod future; mod layer; +mod policy; mod service; pub use self::{ future::ResponseFuture, layer::CircuitBreakerLayer, + policy::{CircuitPolicy, ConsecutiveFailures}, service::{CircuitBreaker, CircuitError, CircuitStatus}, }; diff --git a/tower/src/circuit_breaker/policy.rs b/tower/src/circuit_breaker/policy.rs new file mode 100644 index 000000000..47ba2b2ff --- /dev/null +++ b/tower/src/circuit_breaker/policy.rs @@ -0,0 +1,176 @@ +use std::time::{Duration, Instant}; + +/// Determines when a [`CircuitBreaker`] should open, probe, and close. +/// +/// Implement this trait to create custom circuit-breaking strategies — +/// for example latency-based triggers, error-rate thresholds, or a +/// manual operator-driven switch. The built-in [`ConsecutiveFailures`] +/// policy is a good starting point for most use cases. +/// +/// # Thread safety +/// +/// `CircuitPolicy` does **not** require [`Send`] or [`Sync`] as supertraits, +/// so single-threaded or `!Send` implementations are valid. However, because +/// the policy is stored inside `Arc>` within [`CircuitBreaker`], +/// the compiler will automatically require `P: Send` whenever +/// `CircuitBreaker` is sent across threads (e.g. handed to +/// `tokio::spawn` or used with a multi-threaded runtime). `P: Sync` is +/// **not** needed — `Mutex` provides the necessary exclusion. +/// +/// In practice, any policy that holds only owned data will be `Send` +/// automatically. If you store a raw pointer or `Rc` in your policy, it will +/// not be usable in a multi-threaded context — the compiler will tell you so +/// at the call site. +/// +/// # Relationship to [`tower::retry::budget`] +/// +/// [`Budget`][budget] governs *retry worthiness*: it caps the ratio of +/// retried requests to original requests, preventing retry amplification. +/// A circuit breaker governs *traffic admission*: it gates **all** +/// requests (including first attempts) when a backend is known to be +/// unhealthy. 
+/// +/// The two compose naturally: +/// +/// - A budget limits how aggressively clients retry individual requests. +/// - A circuit breaker stops all traffic once failure is systemic, giving +/// the backend time to recover without being drowned in retried load. +/// +/// Using a circuit breaker *without* a budget still exposes you to retry +/// amplification from layers above; the combination of both provides full +/// protection against retry storms. +/// +/// ```text +/// ┌──────────────────────────────────────┐ +/// │ ServiceBuilder │ +/// │ .layer(CircuitBreakerLayer::…) ◄── gates all traffic when open +/// │ .layer(RetryLayer::new(policy)) ◄── budget inside policy caps retries +/// │ .service_fn(my_backend) │ +/// └──────────────────────────────────────┘ +/// ``` +/// +/// [budget]: crate::retry::budget +/// [`CircuitBreaker`]: super::service::CircuitBreaker +pub trait CircuitPolicy { + /// Called after a **successful** response from the inner service. + /// + /// Return `true` to signal that the circuit should close. This is + /// acted upon only while the circuit is + /// [`HalfOpen`][crate::circuit_breaker::CircuitStatus::HalfOpen]; + /// returning `true` from a [`Closed`][crate::circuit_breaker::CircuitStatus::Closed] + /// state is a no-op. + fn on_success(&mut self) -> bool; + + /// Called after a **failed** response from the inner service. + /// + /// Return `true` to signal that the circuit should open. This is + /// acted upon when the circuit is + /// [`Closed`][crate::circuit_breaker::CircuitStatus::Closed]. + /// + /// Any failure while the circuit is + /// [`HalfOpen`][crate::circuit_breaker::CircuitStatus::HalfOpen] + /// always reopens it, regardless of the return value — the probe + /// failed, so the backend is not yet ready. + fn on_failure(&mut self) -> bool; + + /// Called while the circuit is [`Open`][crate::circuit_breaker::CircuitStatus::Open]. 
+ /// + /// Return `true` to allow a probe request through (transitions the + /// circuit to [`HalfOpen`][crate::circuit_breaker::CircuitStatus::HalfOpen]). + fn should_probe(&self) -> bool; + + /// Called immediately after the circuit transitions to + /// [`HalfOpen`][crate::circuit_breaker::CircuitStatus::HalfOpen]. + /// + /// Use this hook to reset per-window counters so that the recovery + /// success rate is measured only from post-recovery probes, not from + /// stale pre-outage history. + fn on_half_open(&mut self); +} + +// --------------------------------------------------------------------------- +// ConsecutiveFailures — the built-in policy +// --------------------------------------------------------------------------- + +/// A [`CircuitPolicy`] that opens the circuit after *N* consecutive failures +/// and closes it again once a sufficient fraction of probes succeed. +/// +/// # Parameters +/// +/// | Parameter | Description | +/// |---|---| +/// | `failure_threshold` | Number of consecutive failures needed to open the circuit. | +/// | `success_threshold` | Fraction of HalfOpen probes (0.0–1.0) that must succeed to close. | +/// | `timeout` | How long to stay Open before sending the first probe. | +/// +/// # Example +/// +/// ```rust,ignore +/// use tower::circuit_breaker::{CircuitBreakerLayer, ConsecutiveFailures}; +/// use std::time::Duration; +/// +/// let policy = ConsecutiveFailures::new(5, 0.8, Duration::from_secs(30)); +/// let layer = CircuitBreakerLayer::with_policy(policy); +/// ``` +#[derive(Clone, Debug)] +pub struct ConsecutiveFailures { + failure_threshold: usize, + success_threshold: f64, + timeout: Duration, + consecutive_failures: usize, + /// Set when the circuit opens; used by `should_probe`. + open_since: Option, + /// Sliding window of outcomes during HalfOpen (max 100 entries). + window: Vec, +} + +impl ConsecutiveFailures { + /// Create a new [`ConsecutiveFailures`] policy. 
+ pub fn new(failure_threshold: usize, success_threshold: f64, timeout: Duration) -> Self { + Self { + failure_threshold, + success_threshold, + timeout, + consecutive_failures: 0, + open_since: None, + window: Vec::with_capacity(32), + } + } +} + +impl CircuitPolicy for ConsecutiveFailures { + fn on_success(&mut self) -> bool { + self.consecutive_failures = 0; + self.window.push(true); + if self.window.len() > 100 { + self.window.remove(0); + } + let rate = self.window.iter().filter(|&&v| v).count() as f64 + / self.window.len() as f64; + rate >= self.success_threshold + } + + fn on_failure(&mut self) -> bool { + self.consecutive_failures += 1; + self.window.push(false); + if self.window.len() > 100 { + self.window.remove(0); + } + let should_open = self.consecutive_failures >= self.failure_threshold; + if should_open { + self.open_since = Some(Instant::now()); + } + should_open + } + + fn should_probe(&self) -> bool { + self.open_since + .map(|t| t.elapsed() >= self.timeout) + .unwrap_or(false) + } + + fn on_half_open(&mut self) { + self.window.clear(); + self.consecutive_failures = 0; + } +} diff --git a/tower/src/circuit_breaker/service.rs b/tower/src/circuit_breaker/service.rs index fbfc9bc2d..995b98224 100644 --- a/tower/src/circuit_breaker/service.rs +++ b/tower/src/circuit_breaker/service.rs @@ -1,12 +1,11 @@ use std::{ sync::{Arc, Mutex}, task::{Context, Poll}, - time::{Duration, Instant}, }; use tower_service::Service; -use super::future::ResponseFuture; +use super::{future::ResponseFuture, policy::CircuitPolicy}; /// Current state of a [`CircuitBreaker`] service. #[derive(Debug, Clone, PartialEq, Eq)] @@ -46,122 +45,95 @@ impl std::error::Error for CircuitError { } } -#[derive(Debug)] -pub(crate) struct State { +/// Shared mutable state between a [`CircuitBreaker`] and its [`ResponseFuture`]. +pub(crate) struct SharedState

{ pub(crate) status: CircuitStatus, - pub(crate) consecutive_failures: usize, - pub(crate) last_failure: Option, - pub(crate) last_transition: Instant, - /// Sliding window: `true` = success, `false` = failure (max 100 entries). - pub(crate) window: Vec, + pub(crate) policy: P, } -impl State { - pub(crate) fn new() -> Self { - Self { - status: CircuitStatus::Closed, - consecutive_failures: 0, - last_failure: None, - last_transition: Instant::now(), - window: Vec::with_capacity(100), - } - } - - pub(crate) fn push_result(&mut self, success: bool) { - self.window.push(success); - if self.window.len() > 100 { - self.window.remove(0); - } - } - - pub(crate) fn success_rate(&self) -> f64 { - if self.window.is_empty() { - return 0.0; - } - self.window.iter().filter(|&&v| v).count() as f64 / self.window.len() as f64 - } -} - -/// Tower [`Service`] that implements the circuit-breaker pattern. +/// Tower [`Service`] implementing the circuit-breaker pattern. +/// +/// The open/probe/close criteria are driven by a [`CircuitPolicy`], making +/// the triggering logic independently customisable. The built-in policy is +/// [`ConsecutiveFailures`]; supply any type implementing [`CircuitPolicy`] +/// via [`CircuitBreaker::new`] or [`CircuitBreakerLayer::with_policy`] for +/// custom strategies. +/// +/// # Thread safety +/// +/// `CircuitBreaker` is [`Send`] when both `S` and `P` are [`Send`]. +/// This is enforced structurally: the policy is held behind +/// `Arc>`, so `Arc>: Send` requires `P: Send`. +/// No explicit bound is placed on `P` in the [`Service`] impl, so +/// `!Send` policies can still be used in single-threaded contexts without +/// a compile error. `P: Sync` is never required. +/// +/// See [`CircuitPolicy`] for more detail. /// /// See the [module documentation](super) for a full example. 
+/// +/// [`ConsecutiveFailures`]: super::ConsecutiveFailures +/// [`CircuitBreakerLayer::with_policy`]: super::CircuitBreakerLayer::with_policy +/// [`CircuitPolicy`]: super::CircuitPolicy #[derive(Clone)] -#[cfg_attr( - any(test, feature = "circuit-breaker"), - allow(missing_debug_implementations) -)] -pub struct CircuitBreaker { +pub struct CircuitBreaker { inner: S, - pub(crate) state: Arc>, - pub(crate) failure_threshold: usize, - pub(crate) success_threshold: f64, - pub(crate) timeout: Duration, + pub(crate) shared: Arc>>, } -impl CircuitBreaker { - /// Wrap `inner` in a circuit breaker. - pub fn new( - inner: S, - failure_threshold: usize, - success_threshold: f64, - timeout: Duration, - ) -> Self { +impl CircuitBreaker { + /// Wrap `inner` with the given [`CircuitPolicy`]. + pub fn new(inner: S, policy: P) -> Self { Self { inner, - state: Arc::new(Mutex::new(State::new())), - failure_threshold, - success_threshold, - timeout, + shared: Arc::new(Mutex::new(SharedState { + status: CircuitStatus::Closed, + policy, + })), } } /// Return the current [`CircuitStatus`]. pub fn status(&self) -> CircuitStatus { - self.state + self.shared .lock() .expect("circuit breaker state poisoned") .status .clone() } - /// Manually close the circuit (e.g. after operator confirmation). + /// Manually close the circuit (e.g. after operator confirmation that the + /// backend is healthy). + /// + /// Calls [`CircuitPolicy::on_half_open`] to reset any per-window counters + /// in the policy, then sets the status to [`Closed`][CircuitStatus::Closed]. 
pub fn reset(&self) { - let mut s = self.state.lock().expect("circuit breaker state poisoned"); + let mut s = self.shared.lock().expect("circuit breaker state poisoned"); + s.policy.on_half_open(); // reuse the window-clear hook s.status = CircuitStatus::Closed; - s.consecutive_failures = 0; - s.window.clear(); - s.last_transition = Instant::now(); } } -impl Service for CircuitBreaker +impl Service for CircuitBreaker where S: Service, + P: CircuitPolicy, { type Response = S::Response; type Error = CircuitError; - type Future = ResponseFuture; + type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - // Check circuit state synchronously before delegating to inner. { - let mut s = self.state.lock().expect("circuit breaker state poisoned"); - match s.status { - CircuitStatus::Open => { - let elapsed = s - .last_failure - .map(|t| t.elapsed()) - .unwrap_or(Duration::ZERO); - if elapsed < self.timeout { - return Poll::Ready(Err(CircuitError::Open)); - } - // Timeout elapsed — transition to HalfOpen. 
+ let mut s = self.shared.lock().expect("circuit breaker state poisoned"); + if s.status == CircuitStatus::Open { + if s.policy.should_probe() { + s.policy.on_half_open(); s.status = CircuitStatus::HalfOpen; - s.window.clear(); - s.consecutive_failures = 0; - s.last_transition = Instant::now(); + // fall through to delegate to inner service + } else { + return Poll::Ready(Err(CircuitError::Open)); } - CircuitStatus::Closed | CircuitStatus::HalfOpen => {} } } @@ -169,19 +141,16 @@ where } fn call(&mut self, req: Request) -> Self::Future { - ResponseFuture::new( - self.state.clone(), - self.inner.call(req), - self.failure_threshold, - self.success_threshold, - ) + ResponseFuture::new(self.shared.clone(), self.inner.call(req)) } } +// ===== Tests ===== + #[cfg(test)] mod tests { use super::*; - use crate::circuit_breaker::CircuitBreakerLayer; + use crate::circuit_breaker::{CircuitBreakerLayer, ConsecutiveFailures}; use std::time::Duration; use tower::{ServiceBuilder, ServiceExt}; @@ -213,9 +182,9 @@ mod tests { #[tokio::test] async fn manual_reset_closes_circuit() { let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("fail") }); - let cb = CircuitBreaker::new(inner, 2, 0.8, Duration::from_secs(60)); + let policy = ConsecutiveFailures::new(2, 0.8, Duration::from_secs(60)); + let cb = CircuitBreaker::new(inner, policy); - // Open the circuit. let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await; let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await; assert_eq!(cb.status(), CircuitStatus::Open); @@ -223,4 +192,24 @@ mod tests { cb.reset(); assert_eq!(cb.status(), CircuitStatus::Closed); } + + #[tokio::test] + async fn custom_policy_is_accepted() { + // Verify the Service impl compiles and runs with a hand-rolled policy. 
+ #[derive(Clone)] + struct AlwaysOpen; + impl CircuitPolicy for AlwaysOpen { + fn on_success(&mut self) -> bool { false } + fn on_failure(&mut self) -> bool { true } + fn should_probe(&self) -> bool { false } + fn on_half_open(&mut self) {} + } + + let inner = tower::service_fn(|_: &'static str| async move { Err::<&str, _>("x") }); + let cb = CircuitBreaker::new(inner, AlwaysOpen); + + // One failure should open the circuit. + let _ = tower::ServiceExt::oneshot(cb.clone(), "req").await; + assert_eq!(cb.status(), CircuitStatus::Open); + } }