Back to Blog
5 min read

Circuit Breakers for AI Applications: Preventing Cascade Failures

Circuit breakers prevent cascading failures in distributed systems: when a downstream service is unhealthy, callers fail fast instead of piling up behind timeouts and retries. Today, I will conclude June with a deep dive into circuit breakers for AI applications.

Circuit Breaker Pattern

┌─────────────────────────────────────────────────────┐
│              Circuit Breaker States                  │
├─────────────────────────────────────────────────────┤
│                                                      │
│      ┌────────────┐                                 │
│      │   CLOSED   │ ◀─────────────────┐            │
│      │  (Normal)  │                   │            │
│      └─────┬──────┘                   │            │
│            │                          │            │
│      failures >= threshold      successes >= threshold
│            │                          │            │
│            ▼                          │            │
│      ┌────────────┐           ┌───────┴──────┐    │
│      │    OPEN    │ ────────▶ │  HALF-OPEN   │    │
│      │  (Failing) │  timeout  │  (Testing)   │    │
│      └────────────┘           └──────────────┘    │
│            │                          │            │
│            │                    failure            │
│            └──────────────────────────┘            │
│                                                     │
└─────────────────────────────────────────────────────┘

Advanced Circuit Breaker

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Callable, Optional
import asyncio
import threading

class CircuitState(Enum):
    """States a circuit breaker moves through."""

    CLOSED = "closed"        # normal operation: every call is admitted
    OPEN = "open"            # calls are rejected until the open timeout elapses
    HALF_OPEN = "half_open"  # a limited number of probe calls are admitted


@dataclass
class CircuitBreakerConfig:
    """Thresholds and timings that drive ``AdvancedCircuitBreaker``."""

    # Consecutive failures that open a closed circuit.
    failure_threshold: int = 5
    # Consecutive probe successes that close a half-open circuit.
    success_threshold: int = 3
    # Seconds an open circuit waits before admitting probe calls.
    timeout_seconds: int = 60
    # Probe calls admitted per half-open window.
    half_open_max_calls: int = 3
    # Failed/total ratio that opens the circuit once minimum_calls is reached.
    failure_rate_threshold: float = 0.5
    # Calls that must be recorded before either rate threshold is evaluated.
    minimum_calls: int = 10
    # Successful calls taking longer than this (ms) are counted as slow.
    slow_call_threshold_ms: int = 5000
    # Slow/total ratio that opens the circuit once minimum_calls is reached.
    slow_call_rate_threshold: float = 0.5


@dataclass
class CircuitMetrics:
    """Counters accumulated since the circuit last closed."""

    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    slow_calls: int = 0
    # Naive UTC timestamp of the most recent recorded failure.
    last_failure_time: Optional[datetime] = None
    consecutive_failures: int = 0
    consecutive_successes: int = 0


class AdvancedCircuitBreaker:
    """Circuit breaker that trips on consecutive failures, failure rate or slow-call rate.

    Callers ask ``can_execute()`` before each call and report the outcome with
    ``record_success()`` / ``record_failure()``.

    State-change callbacks are invoked as ``callback(name, old_state, new_state)``
    while the breaker's (non-reentrant) lock is held, so a callback must not call
    ``can_execute`` or ``record_*`` on the same breaker. ``get_status`` does not
    take the lock and is safe to call from a callback.
    """

    def __init__(self, name: str, config: Optional[CircuitBreakerConfig] = None):
        """Create a closed breaker called ``name``; ``config`` defaults to ``CircuitBreakerConfig()``."""
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.metrics = CircuitMetrics()
        self.lock = threading.Lock()
        self.state_change_callbacks: list[Callable] = []
        # Probe calls admitted during the current half-open window.
        self.half_open_calls = 0
        # When the circuit last opened / entered half-open (naive UTC).
        self._opened_at: Optional[datetime] = None
        self._half_open_since: Optional[datetime] = None

    @staticmethod
    def _now() -> datetime:
        # Naive UTC, matching the values stored in metrics.last_failure_time.
        return datetime.utcnow()

    def can_execute(self) -> bool:
        """Return True if the caller may perform a call now.

        CLOSED always admits. OPEN rejects until ``timeout_seconds`` have passed
        since the circuit opened, then switches to HALF_OPEN. HALF_OPEN admits
        at most ``half_open_max_calls`` probes per window; the permit is taken
        here, at admission, so concurrent callers cannot exceed the limit
        before the first probe finishes.
        """
        with self.lock:
            if self.state == CircuitState.CLOSED:
                return True

            if self.state == CircuitState.OPEN:
                if not self._should_attempt_reset():
                    return False
                self._transition_to(CircuitState.HALF_OPEN)

            if self.state == CircuitState.HALF_OPEN:
                return self._admit_probe()

            return False

    def record_success(self, duration_ms: int):
        """Record a successful call that took ``duration_ms`` milliseconds."""
        with self.lock:
            self.metrics.total_calls += 1
            self.metrics.successful_calls += 1
            self.metrics.consecutive_successes += 1
            self.metrics.consecutive_failures = 0

            if duration_ms > self.config.slow_call_threshold_ms:
                self.metrics.slow_calls += 1

            if self.state == CircuitState.HALF_OPEN:
                if self.metrics.consecutive_successes >= self._required_probe_successes():
                    self._transition_to(CircuitState.CLOSED)
            elif self.state == CircuitState.CLOSED:
                # Rates must be checked on successes too: a service that is
                # slow but never errors would otherwise never trip the breaker.
                if self._should_open():
                    self._transition_to(CircuitState.OPEN)

    def record_failure(self, error: Exception = None):
        """Record a failed call; ``error`` is accepted for caller convenience and not inspected."""
        with self.lock:
            self.metrics.total_calls += 1
            self.metrics.failed_calls += 1
            self.metrics.consecutive_failures += 1
            self.metrics.consecutive_successes = 0
            self.metrics.last_failure_time = self._now()

            if self.state == CircuitState.HALF_OPEN:
                # Any failed probe sends the circuit straight back to OPEN.
                self._transition_to(CircuitState.OPEN)
            elif self.state == CircuitState.CLOSED:
                if self._should_open():
                    self._transition_to(CircuitState.OPEN)

    def _should_open(self) -> bool:
        """Return True if any configured threshold is met by the current metrics."""
        # Check consecutive failures
        if self.metrics.consecutive_failures >= self.config.failure_threshold:
            return True

        # Rates are only meaningful once enough calls have been observed.
        if self.metrics.total_calls >= self.config.minimum_calls:
            failure_rate = self.metrics.failed_calls / self.metrics.total_calls
            if failure_rate >= self.config.failure_rate_threshold:
                return True

            # Check slow call rate
            slow_rate = self.metrics.slow_calls / self.metrics.total_calls
            if slow_rate >= self.config.slow_call_rate_threshold:
                return True

        return False

    def _should_attempt_reset(self) -> bool:
        """Return True once ``timeout_seconds`` have elapsed since the circuit opened."""
        # Falls back to the last failure time when the state was set to OPEN
        # without going through _transition_to.
        reference = self._opened_at or self.metrics.last_failure_time
        if reference is None:
            return True
        elapsed = self._now() - reference
        return elapsed >= timedelta(seconds=self.config.timeout_seconds)

    def _admit_probe(self) -> bool:
        """Take one half-open permit if available; caller must hold the lock."""
        # A limit below 1 would leave the breaker unable to ever probe.
        limit = max(1, self.config.half_open_max_calls)

        if self.half_open_calls >= limit and self._probe_window_expired():
            # Admitted probes never reported back (e.g. cancelled calls).
            # Start a fresh window instead of staying half-open forever.
            self.half_open_calls = 0
            self._half_open_since = self._now()

        if self.half_open_calls < limit:
            self.half_open_calls += 1
            return True
        return False

    def _probe_window_expired(self) -> bool:
        """Return True if the current half-open window is older than ``timeout_seconds``."""
        if self._half_open_since is None:
            return True
        elapsed = self._now() - self._half_open_since
        return elapsed >= timedelta(seconds=self.config.timeout_seconds)

    def _required_probe_successes(self) -> int:
        """Successes needed to close, capped at the number of probes one window admits."""
        # Without the cap, success_threshold > half_open_max_calls could never
        # be satisfied within a single window.
        return min(self.config.success_threshold, max(1, self.config.half_open_max_calls))

    def _transition_to(self, new_state: CircuitState):
        """Switch to ``new_state``, update per-state bookkeeping and notify callbacks."""
        old_state = self.state
        self.state = new_state

        if new_state == CircuitState.OPEN:
            self._opened_at = self._now()

        if new_state == CircuitState.HALF_OPEN:
            self.half_open_calls = 0
            self.metrics.consecutive_successes = 0
            self._half_open_since = self._now()

        if new_state == CircuitState.CLOSED:
            self._reset_metrics()
            self._opened_at = None
            self._half_open_since = None

        for callback in self.state_change_callbacks:
            callback(self.name, old_state, new_state)

    def _reset_metrics(self):
        """Replace the metrics with a zeroed ``CircuitMetrics``."""
        self.metrics = CircuitMetrics()

    def on_state_change(self, callback: Callable):
        """Register ``callback(name, old_state, new_state)`` for every state transition."""
        self.state_change_callbacks.append(callback)

    def get_status(self) -> dict:
        """Return the breaker's name, state value and headline metrics as a dict."""
        return {
            "name": self.name,
            "state": self.state.value,
            "metrics": {
                "total_calls": self.metrics.total_calls,
                "successful_calls": self.metrics.successful_calls,
                "failed_calls": self.metrics.failed_calls,
                # max(..., 1) avoids dividing by zero before any call is recorded.
                "failure_rate": self.metrics.failed_calls / max(self.metrics.total_calls, 1),
                "consecutive_failures": self.metrics.consecutive_failures
            }
        }

Circuit Breaker Decorator

import functools
import time

class CircuitOpenException(Exception):
    """Raised when a circuit breaker refuses to admit a call."""


def circuit_breaker(breaker: "AdvancedCircuitBreaker"):
    """Decorator to apply circuit breaker to async functions.

    The wrapped coroutine function first asks ``breaker.can_execute()``; if the
    call is not admitted it raises ``CircuitOpenException`` without invoking
    the function. Otherwise the function is awaited, its elapsed time in whole
    milliseconds is passed to ``breaker.record_success`` and its result is
    returned. Any ``Exception`` raised by the function is passed to
    ``breaker.record_failure`` and re-raised unchanged.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if not breaker.can_execute():
                raise CircuitOpenException(
                    f"Circuit breaker '{breaker.name}' is open"
                )

            # perf_counter is monotonic; wall-clock adjustments cannot produce
            # negative or inflated durations in the slow-call accounting.
            started = time.perf_counter()
            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                breaker.record_failure(e)
                raise

            # Recorded outside the try so an error raised while recording a
            # success is not also counted as a failure of the call itself.
            elapsed_ms = (time.perf_counter() - started) * 1000
            breaker.record_success(int(elapsed_ms))
            return result

        return wrapper
    return decorator

# Usage: a breaker for OpenAI calls that opens after 3 consecutive failures,
# waits 30 seconds before probing again, and counts calls over 10 s as slow.
openai_breaker = AdvancedCircuitBreaker(
    name="openai",
    config=CircuitBreakerConfig(
        slow_call_threshold_ms=10000,
        timeout_seconds=30,
        failure_threshold=3,
    ),
)

@circuit_breaker(openai_breaker)
async def call_openai(messages: list):
    """Request a gpt-4 chat completion for ``messages``, guarded by ``openai_breaker``."""
    # NOTE(review): `client` is not defined in this snippet; presumably an
    # async OpenAI client created elsewhere. Confirm before use.
    completion = await client.chat.completions.create(
        messages=messages,
        model="gpt-4",
    )
    return completion

Multi-Service Circuit Breaker Registry

class CircuitBreakerRegistry:
    """Holds one named circuit breaker per service and shares callbacks across them."""

    def __init__(self):
        self.breakers: dict[str, AdvancedCircuitBreaker] = {}
        self.global_callbacks: list[Callable] = []

    def register(self, name: str, config: CircuitBreakerConfig = None) -> AdvancedCircuitBreaker:
        """Create a breaker for ``name``, attach all global callbacks, store and return it.

        Registering a name that already exists replaces the stored breaker.
        """
        created = AdvancedCircuitBreaker(name, config)
        for listener in self.global_callbacks:
            created.on_state_change(listener)
        self.breakers[name] = created
        return created

    def get(self, name: str) -> Optional[AdvancedCircuitBreaker]:
        """Return the breaker registered under ``name``, or None if there is none."""
        return self.breakers.get(name, None)

    def on_any_state_change(self, callback: Callable):
        """Attach ``callback`` to every current breaker and to all registered later."""
        # Remember it first so breakers registered afterwards pick it up too.
        self.global_callbacks.append(callback)
        for existing in self.breakers.values():
            existing.on_state_change(callback)

    def get_all_status(self) -> dict:
        """Return ``{name: breaker.get_status()}`` for every registered breaker."""
        report = {}
        for service_name, service_breaker in self.breakers.items():
            report[service_name] = service_breaker.get_status()
        return report

    def get_health(self) -> dict:
        """Summarise how many breakers are closed, plus the per-service statuses.

        An empty registry reports a health percentage of 100.
        """
        services = self.get_all_status()
        service_count = len(services)
        closed_count = len(
            [status for status in services.values() if status["state"] == "closed"]
        )
        if service_count:
            percentage = closed_count / service_count * 100
        else:
            percentage = 100
        return {
            "healthy_services": closed_count,
            "total_services": service_count,
            "health_percentage": percentage,
            "services": services
        }

# Usage
registry = CircuitBreakerRegistry()

# One breaker per AI backend, each with its own consecutive-failure threshold
registry.register(name="openai-gpt4", config=CircuitBreakerConfig(failure_threshold=3))
registry.register(name="openai-gpt35", config=CircuitBreakerConfig(failure_threshold=5))
registry.register(name="embedding-service", config=CircuitBreakerConfig(failure_threshold=10))

# Global monitoring: log every state transition of any registered breaker.
# NOTE(review): `logger` is not defined in this snippet; presumably a
# module-level logging.Logger. Confirm before use.
registry.on_any_state_change(
    callback=lambda name, old, new: logger.warning(f"Circuit {name}: {old.value} -> {new.value}")
)

Integration with Fallback

class ResilientAIClient:
    """Chat client that tries GPT-4 first and falls back to GPT-3.5.

    Each model is guarded by its own circuit breaker, looked up in the
    registry under "openai-gpt4" and "openai-gpt35".

    NOTE(review): ``_call_gpt4`` and ``_call_gpt35`` are not defined in this
    snippet; they are presumably coroutine methods supplied elsewhere (e.g. by
    a subclass) that take ``messages`` and return the model response.
    """

    def __init__(self, registry: "CircuitBreakerRegistry"):
        """Look up both breakers in ``registry``.

        Raises:
            ValueError: if either breaker is not registered. Failing here
                avoids an ``AttributeError`` on ``None`` at the first chat call.
        """
        self.registry = registry
        self.primary_breaker = self._require_breaker("openai-gpt4")
        self.fallback_breaker = self._require_breaker("openai-gpt35")

    def _require_breaker(self, name: str):
        """Return the breaker registered under ``name`` or raise ValueError."""
        breaker = self.registry.get(name)
        if breaker is None:
            raise ValueError(f"Circuit breaker '{name}' is not registered")
        return breaker

    async def _call_with_breaker(self, breaker, call, messages: list):
        """Await ``call(messages)`` and report the outcome and duration to ``breaker``."""
        # perf_counter is monotonic, so clock adjustments cannot skew durations.
        started = time.perf_counter()
        try:
            response = await call(messages)
        except Exception as e:
            breaker.record_failure(e)
            raise
        breaker.record_success(int((time.perf_counter() - started) * 1000))
        return response

    async def chat(self, messages: list) -> dict:
        """Return the first successful response from the primary or fallback model.

        Returns:
            ``{"response": ..., "model": "gpt-4"}`` from the primary, or
            ``{"response": ..., "model": "gpt-35-turbo", "fallback": True}``
            from the fallback.

        Raises:
            CircuitOpenException: if neither model produced a response, either
                because its breaker rejected the call or because the call
                failed. The most recent call error, if any, is attached as
                ``__cause__``.
        """
        last_error = None

        # Try primary with circuit breaker
        if self.primary_breaker.can_execute():
            try:
                response = await self._call_with_breaker(
                    self.primary_breaker, self._call_gpt4, messages
                )
            except Exception as e:
                last_error = e
            else:
                return {"response": response, "model": "gpt-4"}

        # Fallback with its own circuit breaker
        if self.fallback_breaker.can_execute():
            try:
                response = await self._call_with_breaker(
                    self.fallback_breaker, self._call_gpt35, messages
                )
            except Exception as e:
                last_error = e
            else:
                return {"response": response, "model": "gpt-35-turbo", "fallback": True}

        # Every service was either rejected by its breaker or failed; keep the
        # underlying error (when there is one) so callers can see the cause.
        raise CircuitOpenException("All AI services unavailable") from last_error

Summary: June 2023

This concludes our June 2023 series on Azure, Data, and AI topics. We covered:

  • Microsoft Fabric deep dives (Lakehouse, Delta Lake, OneLake)
  • Power BI Direct Lake and Semantic Models
  • Fabric DevOps (Git, Deployment Pipelines, Monitoring)
  • Fabric Security (Roles, Permissions, RLS)
  • AI Agent patterns (Function calling, Tool use)
  • Production AI patterns (Streaming, Caching, Error handling)
  • Resilience patterns (Retry, Fallback, Circuit Breaker)

Building resilient AI applications requires thoughtful architecture and proper error handling. The patterns covered this month provide a foundation for production-ready AI systems.

Resources

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.