Back to Blog
6 min read

Circuit Breakers for AI Systems: Preventing Cascade Failures

Circuit breakers prevent cascade failures by stopping requests to failing services. In AI systems, they’re essential for maintaining stability and managing costs.

Circuit Breaker Pattern

import logging
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional, Callable, Any

class CircuitState(Enum):
    """The three states a circuit breaker can be in."""

    # Normal operation: calls pass straight through.
    CLOSED = "closed"

    # The dependency is failing: calls are rejected.
    OPEN = "open"

    # Probing: a limited number of calls test whether the service recovered.
    HALF_OPEN = "half_open"

@dataclass
class CircuitBreakerConfig:
    """Tunable thresholds and timings for a circuit breaker."""

    # Number of failures that trips the breaker open.
    failure_threshold: int = 5

    # Number of successes needed in half-open before closing again.
    success_threshold: int = 3

    # Seconds to stay open before a half-open probe is attempted.
    timeout_seconds: float = 60

    # Upper bound on calls admitted while half-open.
    half_open_max_calls: int = 3

@dataclass
class CircuitStats:
    """Running counters and history kept by a circuit breaker."""

    # Call counters, all starting at zero.
    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    rejected_calls: int = 0

    # When the most recent failure was recorded; None until one occurs.
    last_failure_time: Optional[datetime] = None

    # One entry per state transition; each instance gets its own list.
    state_changes: list = field(default_factory=list)

logger = logging.getLogger(__name__)


class CircuitBreaker:
    """Thread-safe circuit breaker guarding calls to an unreliable dependency.

    The breaker starts CLOSED and passes every call through. After
    ``config.failure_threshold`` consecutive failures it becomes OPEN and
    rejects calls with ``CircuitOpenError``. Once ``config.timeout_seconds``
    have elapsed, the next call moves it to HALF_OPEN, where a bounded number
    of trial calls are admitted: enough consecutive successes close the
    circuit again, while any failure re-opens it.
    """

    def __init__(self, name: str, config: Optional[CircuitBreakerConfig] = None):
        """Create a closed breaker.

        Args:
            name: Identifier used in error messages, logs and status reports.
            config: Thresholds and timings; defaults to ``CircuitBreakerConfig()``.
        """
        self.name = name
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.stats = CircuitStats()
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.half_open_calls = 0
        self.last_state_change = datetime.now()
        # Guards the counters and the state machine. It is a plain (non
        # re-entrant) lock, so helpers documented as "lock held" must not
        # try to acquire it again.
        self._lock = threading.Lock()

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """Execute ``func(*args, **kwargs)`` through the circuit breaker.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitOpenError: If the breaker is currently rejecting calls.
            Exception: Any exception raised by ``func`` is recorded as a
                failure and re-raised unchanged.
        """
        with self._lock:
            if not self._can_execute():
                self.stats.rejected_calls += 1
                raise CircuitOpenError(f"Circuit {self.name} is open")

            self.stats.total_calls += 1

        # The lock is released while func runs so slow calls are not serialized.
        try:
            result = func(*args, **kwargs)
        except Exception:
            self._record_failure()
            raise

        self._record_success()
        return result

    def _half_open_budget(self) -> int:
        """Return how many trial calls HALF_OPEN admits (always at least one)."""
        return max(1, self.config.half_open_max_calls)

    def _required_successes(self) -> int:
        """Return the successes needed to close from HALF_OPEN.

        Capped at the trial budget: requiring more successes than calls that
        can ever be admitted would leave the breaker stuck in HALF_OPEN.
        """
        return min(self.config.success_threshold, self._half_open_budget())

    def _can_execute(self) -> bool:
        """Decide whether one more call may run. The caller must hold the lock.

        May move the breaker from OPEN to HALF_OPEN and consumes one trial
        slot for every call admitted while HALF_OPEN.
        """
        if self.state == CircuitState.CLOSED:
            return True

        if self.state == CircuitState.OPEN:
            elapsed = datetime.now() - self.last_state_change
            if elapsed.total_seconds() < self.config.timeout_seconds:
                return False
            self._transition_to(CircuitState.HALF_OPEN)
            # Fall through to the HALF_OPEN branch so that the call which
            # triggered the transition is counted against the trial budget.

        if self.state == CircuitState.HALF_OPEN:
            if self.half_open_calls < self._half_open_budget():
                self.half_open_calls += 1
                return True
            return False

        return False

    def _record_success(self):
        """Record a successful call and close the circuit if it has recovered."""
        with self._lock:
            self.stats.successful_calls += 1
            self.consecutive_failures = 0
            self.consecutive_successes += 1

            if (
                self.state == CircuitState.HALF_OPEN
                and self.consecutive_successes >= self._required_successes()
            ):
                self._transition_to(CircuitState.CLOSED)

    def _record_failure(self):
        """Record a failed call and open the circuit when warranted."""
        with self._lock:
            self.stats.failed_calls += 1
            self.stats.last_failure_time = datetime.now()
            self.consecutive_failures += 1
            self.consecutive_successes = 0

            if self.state == CircuitState.CLOSED:
                if self.consecutive_failures >= self.config.failure_threshold:
                    self._transition_to(CircuitState.OPEN)

            elif self.state == CircuitState.HALF_OPEN:
                # A single failed probe is enough to re-open the circuit.
                self._transition_to(CircuitState.OPEN)

    def _transition_to(self, new_state: CircuitState):
        """Switch to ``new_state``, reset the streak counters and log the change.

        The caller must hold the lock.
        """
        old_state = self.state
        now = datetime.now()

        self.state = new_state
        self.last_state_change = now
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.half_open_calls = 0

        self.stats.state_changes.append({
            "timestamp": now.isoformat(),
            "from": old_state.value,
            "to": new_state.value
        })

        logger.info(
            "Circuit %s: %s -> %s", self.name, old_state.value, new_state.value
        )

    def get_status(self) -> dict:
        """Return a snapshot of the current state and counters.

        The values are read without taking the lock, so under concurrent use
        the individual numbers may come from slightly different moments.
        """
        last_failure = self.stats.last_failure_time
        return {
            "name": self.name,
            "state": self.state.value,
            "stats": {
                "total": self.stats.total_calls,
                "successful": self.stats.successful_calls,
                "failed": self.stats.failed_calls,
                "rejected": self.stats.rejected_calls
            },
            "consecutive_failures": self.consecutive_failures,
            "consecutive_successes": self.consecutive_successes,
            "last_failure": last_failure.isoformat() if last_failure else None
        }

class CircuitOpenError(Exception):
    """Signals that a call was rejected because the circuit is open."""

LLM-Specific Circuit Breaker

class LLMCircuitBreaker:
    """Registry holding one circuit breaker per LLM provider."""

    def __init__(self):
        """Create the breakers for the "openai", "azure" and "anthropic" providers."""
        self.breakers = {
            "openai": CircuitBreaker("openai", CircuitBreakerConfig(
                failure_threshold=5,
                timeout_seconds=60
            )),
            "azure": CircuitBreaker("azure", CircuitBreakerConfig(
                failure_threshold=5,
                timeout_seconds=60
            )),
            "anthropic": CircuitBreaker("anthropic", CircuitBreakerConfig(
                failure_threshold=3,
                timeout_seconds=120
            ))
        }

    def call(self, provider: str, func: Callable, *args, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` through the breaker for ``provider``.

        A provider without a registered breaker is called directly, with no
        protection.

        Raises:
            CircuitOpenError: If the provider's breaker rejects the call.
        """
        breaker = self.breakers.get(provider)
        if breaker is None:
            return func(*args, **kwargs)

        return breaker.call(func, *args, **kwargs)

    @staticmethod
    def _is_available(breaker) -> bool:
        """Return True when ``breaker`` would admit a call attempt right now.

        A breaker only leaves OPEN when a call is attempted after its timeout.
        An OPEN breaker whose timeout has already elapsed therefore has to
        count as available; otherwise callers that pick providers from the
        available list would never call it again and it could never recover.
        """
        if breaker.state != CircuitState.OPEN:
            return True
        waited = (datetime.now() - breaker.last_state_change).total_seconds()
        return waited >= breaker.config.timeout_seconds

    def get_available_providers(self) -> list:
        """Return the names of providers that are worth trying.

        Includes every provider whose breaker is not OPEN, plus OPEN breakers
        whose recovery timeout has elapsed (the next call will probe them).
        """
        return [
            name for name, breaker in self.breakers.items()
            if self._is_available(breaker)
        ]

    def get_all_status(self) -> dict:
        """Return ``{provider name: breaker status dict}`` for every breaker."""
        return {
            name: breaker.get_status()
            for name, breaker in self.breakers.items()
        }

# Usage: one module-level registry instance, referenced by call_openai_safe()
# and by the health check set up further down.
llm_breaker = LLMCircuitBreaker()

def call_openai_safe(prompt: str) -> str:
    """Send ``prompt`` to OpenAI via its circuit breaker, with an Anthropic fallback.

    If the OpenAI circuit rejects the call and "anthropic" is listed as
    available, the prompt is handed to ``call_anthropic``; otherwise the
    ``CircuitOpenError`` propagates. Errors raised by the OpenAI request
    itself are not caught here.
    """

    def _request_completion():
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}]
        )
        return completion.choices[0].message.content

    try:
        answer = llm_breaker.call("openai", _request_completion)
    except CircuitOpenError:
        # OpenAI is tripped: fall back only when Anthropic is usable.
        if "anthropic" not in llm_breaker.get_available_providers():
            raise
        return call_anthropic(prompt)

    return answer

Bulkhead Pattern with Circuit Breakers

from concurrent.futures import ThreadPoolExecutor, TimeoutError
from queue import Queue
import threading

class Bulkhead:
    """Isolate failures by capping how many calls may be in flight at once.

    At most ``max_concurrent`` calls run concurrently on a dedicated thread
    pool and up to ``queue_size`` further calls may wait for a free worker.
    Anything beyond that is rejected immediately with ``BulkheadFullError``.
    """

    def __init__(self, name: str, max_concurrent: int, queue_size: int = 100):
        """Create the bulkhead and its worker pool.

        Args:
            name: Identifier used in error messages.
            max_concurrent: Number of worker threads, i.e. calls running at once.
            queue_size: Number of additional calls allowed to wait. A value
                less than or equal to zero means "unbounded", mirroring
                ``queue.Queue(maxsize=...)``.
        """
        self.name = name
        self.max_concurrent = max_concurrent
        self.queue_size = queue_size
        self.executor = ThreadPoolExecutor(max_workers=max_concurrent)
        # Kept so existing code that inspects ``.queue`` keeps working.
        # Waiting calls actually sit in the executor's internal work queue;
        # admission is decided from ``active_calls``.
        self.queue = Queue(maxsize=queue_size)
        # Calls admitted and not yet finished (running or waiting for a worker).
        self.active_calls = 0
        self._lock = threading.Lock()

    def _is_full(self) -> bool:
        """Return True when no further call may be admitted (lock held)."""
        if self.queue_size <= 0:
            return False
        return self.active_calls >= self.max_concurrent + self.queue_size

    def _release_slot(self, _future=None):
        """Give back one slot; used as the future's done-callback."""
        with self._lock:
            self.active_calls -= 1

    def submit(self, func: Callable, *args, timeout: float = 30, **kwargs) -> Any:
        """Run ``func(*args, **kwargs)`` inside the bulkhead and wait for it.

        Args:
            func: Callable to execute on the bulkhead's thread pool.
            timeout: Seconds to wait for the result, including any time the
                call spends waiting for a free worker.

        Returns:
            The value returned by ``func``.

        Raises:
            BulkheadFullError: If all execution and waiting slots are taken.
            BulkheadTimeoutError: If no result arrives within ``timeout``.
            Exception: Anything raised by ``func`` is re-raised unchanged.
        """
        with self._lock:
            if self._is_full():
                raise BulkheadFullError(f"Bulkhead {self.name} is full")
            self.active_calls += 1

        try:
            future = self.executor.submit(func, *args, **kwargs)
        except BaseException:
            # Nothing was scheduled (e.g. the executor is shut down).
            self._release_slot()
            raise

        # Release the slot when the work really ends (finished or cancelled),
        # not when the caller stops waiting: a timed-out call that is still
        # running keeps occupying a worker thread.
        future.add_done_callback(self._release_slot)

        try:
            return future.result(timeout=timeout)
        except TimeoutError as exc:
            if (
                future.done()
                and not future.cancelled()
                and future.exception() is exc
            ):
                # func itself raised a timeout error; not a bulkhead timeout.
                raise
            # Drops the call if it has not started yet; a running call
            # cannot be interrupted and finishes in the background.
            future.cancel()
            raise BulkheadTimeoutError(
                f"Call timed out in bulkhead {self.name}"
            ) from exc


class BulkheadFullError(Exception):
    """Raised when a bulkhead has no free execution or waiting slot."""


class BulkheadTimeoutError(Exception):
    """Raised when a call does not finish within the bulkhead timeout."""

class IsolatedLLMService:
    """LLM service that gives each operation its own bulkhead and circuit breaker."""

    def __init__(self):
        """Create one bulkhead and one breaker per operation type."""
        # Separate bulkheads for different operations
        self.bulkheads = {
            "chat": Bulkhead("chat", max_concurrent=10),
            "embeddings": Bulkhead("embeddings", max_concurrent=20),
            "tools": Bulkhead("tools", max_concurrent=5)
        }

        # Circuit breakers
        self.breakers = {
            "chat": CircuitBreaker("chat"),
            "embeddings": CircuitBreaker("embeddings"),
            "tools": CircuitBreaker("tools")
        }

    def _run_isolated(self, operation: str, func: Callable, timeout: float) -> Any:
        """Run ``func`` for ``operation``: circuit breaker first, then bulkhead.

        The breaker is the outer layer on purpose. An open circuit rejects
        the request before it occupies a bulkhead slot, and bulkhead
        timeouts or rejections are seen by the breaker as failures, so a
        hanging upstream can actually trip it.
        """
        return self.breakers[operation].call(
            self.bulkheads[operation].submit, func, timeout=timeout
        )

    def chat(self, prompt: str) -> str:
        """Return the chat completion text for ``prompt`` (30 second timeout)."""

        def _chat():
            return client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}]
            ).choices[0].message.content

        return self._run_isolated("chat", _chat, timeout=30)

    def embed(self, texts: list) -> list:
        """Return one embedding per entry of ``texts`` (10 second timeout)."""

        def _embed():
            response = client.embeddings.create(
                model="text-embedding-3-small",
                input=texts
            )
            return [d.embedding for d in response.data]

        return self._run_isolated("embeddings", _embed, timeout=10)

Health Check Integration

class CircuitBreakerHealthCheck:
    """Health check derived from the state of a set of circuit breakers."""

    def __init__(self, breakers: "dict[str, CircuitBreaker]"):
        """Store the mapping of breaker name to breaker to report on."""
        self.breakers = breakers

    def is_healthy(self) -> bool:
        """Return False only when more than half of the circuits are open.

        Exactly half open still counts as healthy, and so does an empty set
        of breakers (there is nothing that could be failing).
        """
        open_count = sum(
            1 for b in self.breakers.values()
            if b.state == CircuitState.OPEN
        )

        # Integer form of "open_count <= total / 2".
        return 2 * open_count <= len(self.breakers)

    def get_health_status(self) -> dict:
        """Return overall health plus per-circuit details.

        Returns:
            A dict with ``"healthy"`` (bool), ``"circuits"`` (name mapped to
            state, healthy flag and failure rate) and a human-readable
            ``"summary"`` string.
        """
        circuit_status = {}
        for name, breaker in self.breakers.items():
            status = breaker.get_status()
            circuit_status[name] = {
                "state": status["state"],
                "healthy": status["state"] != "open",
                "failure_rate": self._calculate_failure_rate(status)
            }

        healthy_count = sum(1 for c in circuit_status.values() if c["healthy"])

        return {
            "healthy": self.is_healthy(),
            "circuits": circuit_status,
            "summary": f"{healthy_count}/{len(circuit_status)} circuits healthy"
        }

    def _calculate_failure_rate(self, status: dict) -> float:
        """Return failed / total calls from a status dict, or 0.0 with no calls."""
        total = status["stats"]["total"]
        if total == 0:
            return 0.0
        return status["stats"]["failed"] / total

# FastAPI health endpoint
from fastapi import FastAPI, Response

# Application instance, plus a health check built over the breakers held by
# the module-level llm_breaker registry.
app = FastAPI()
health_check = CircuitBreakerHealthCheck(llm_breaker.breakers)

@app.get("/health")
def health_endpoint():
    """Report circuit health: the status dict when healthy, else a 503 response.

    The unhealthy branch returns the same status payload serialized as JSON
    with HTTP status code 503.
    """
    # Imported here because this file has no module-level ``import json``;
    # without it the unhealthy branch fails with NameError.
    import json

    status = health_check.get_health_status()

    if status["healthy"]:
        return status

    return Response(
        content=json.dumps(status),
        status_code=503,
        media_type="application/json"
    )

Adaptive Circuit Breaker

class AdaptiveCircuitBreaker(CircuitBreaker):
    """Circuit breaker that adapts its thresholds to the recent request rate."""

    # Length of the sliding window used to measure the request rate.
    _RATE_WINDOW_SECONDS = 60

    def __init__(self, name: str):
        """Create a breaker with the default config and an empty rate window."""
        super().__init__(name)
        self.request_rate = 0
        self.rate_samples = []

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """Record the request for rate tracking, then run it through the breaker.

        Every attempt is counted, including ones the breaker then rejects.
        Without this override nothing would ever call ``_record_request`` and
        the thresholds would never adapt.

        Raises:
            CircuitOpenError: If the breaker is currently rejecting calls.
        """
        # The base class reads the config under this same lock, so the
        # thresholds are updated while holding it. The lock is released
        # before delegating because it is not re-entrant.
        with self._lock:
            self._record_request()

        return super().call(func, *args, **kwargs)

    def _record_request(self):
        """Add one request to the sliding window and re-tune the thresholds."""
        now = time.time()
        self.rate_samples.append(now)

        # Keep only samples inside the window
        cutoff = now - self._RATE_WINDOW_SECONDS
        self.rate_samples = [t for t in self.rate_samples if t > cutoff]
        self.request_rate = len(self.rate_samples)

        # Adapt thresholds based on load
        self._adapt_thresholds()

    def _adapt_thresholds(self):
        """Set failure threshold and timeout from the current request rate."""

        if self.request_rate > 100:
            # High load - be more aggressive with circuit breaking
            self.config.failure_threshold = 3
            self.config.timeout_seconds = 30
        elif self.request_rate > 50:
            # Medium load
            self.config.failure_threshold = 5
            self.config.timeout_seconds = 45
        else:
            # Low load - be more tolerant
            self.config.failure_threshold = 10
            self.config.timeout_seconds = 60

Circuit breakers are essential for AI systems that depend on external APIs. They prevent resource exhaustion, protect downstream services, and enable graceful degradation when things go wrong.

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.