Back to Blog
5 min read

Retry Strategies for AI Applications

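Proper retry strategies are essential for reliable AI applications: model APIs fail transiently through rate limits, timeouts, and dropped connections, and a well-designed retry policy turns those blips into short delays instead of user-facing errors. Today, I will cover advanced retry patterns and how to implement them.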

Retry Strategy Types

import random
import time
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

class RetryStrategy(ABC):
    """Interface for policies that decide how long to wait before a retry.

    Concrete strategies implement ``get_delay``; the class itself cannot be
    instantiated because that method is abstract.
    """

    @abstractmethod
    def get_delay(self, attempt: int) -> float:
        """Return the number of seconds to wait after failed attempt ``attempt``.

        ``attempt`` is the zero-based index of the attempt that just failed.
        """

class ConstantRetry(RetryStrategy):
    """Wait the same fixed number of seconds before every retry.

    Args:
        delay: seconds to wait between attempts.
    """

    def __init__(self, delay: float = 1.0):
        self.delay = delay

    def get_delay(self, attempt: int) -> float:
        """Return ``self.delay`` regardless of which attempt failed."""
        fixed_wait = self.delay
        return fixed_wait

class LinearRetry(RetryStrategy):
    """Grow the wait by a fixed step per attempt: ``initial + attempt * increment``.

    Args:
        initial_delay: seconds to wait after the first failure (attempt 0).
        increment: seconds added for every further attempt.
    """

    def __init__(self, initial_delay: float = 1.0, increment: float = 1.0):
        self.initial = initial_delay
        self.increment = increment

    def get_delay(self, attempt: int) -> float:
        """Return the linearly increasing delay for zero-based ``attempt``."""
        added = attempt * self.increment
        return self.initial + added

class ExponentialRetry(RetryStrategy):
    """Exponential backoff: ``base_delay * multiplier ** attempt``, capped at ``max_delay``.

    Args:
        base_delay: seconds to wait after the first failure (attempt 0).
        multiplier: growth factor applied once per attempt.
        max_delay: upper bound, in seconds, on the returned delay.
    """

    def __init__(self, base_delay: float = 1.0, multiplier: float = 2.0, max_delay: float = 60.0):
        self.base = base_delay
        self.multiplier = multiplier
        self.max_delay = max_delay

    def get_delay(self, attempt: int) -> float:
        """Return the capped exponential delay for zero-based ``attempt``."""
        try:
            delay = self.base * (self.multiplier ** attempt)
        except OverflowError:
            # float ** int raises OverflowError once the power leaves the float
            # range (e.g. 2.0 ** 1100). With a positive base such a delay is far
            # past any cap, so return the cap instead of crashing the retry loop.
            return self.max_delay
        return min(delay, self.max_delay)

class ExponentialWithJitter(ExponentialRetry):
    """Exponential backoff with random jitter that never exceeds ``max_delay``.

    Jitter de-synchronizes clients that failed at the same moment. Each call
    draws ``jitter = delay * jitter_range * U[0, 1)``, where ``delay`` is the
    capped exponential delay from ``ExponentialRetry``. The jitter is normally
    added to ``delay``. If adding it would exceed ``max_delay``, it is
    subtracted instead. That keeps retries spread out once the backoff reaches
    the cap; plain clamping would make every client wait exactly ``max_delay``.

    Args:
        base_delay, multiplier, max_delay: see ``ExponentialRetry``.
        jitter_range: maximum jitter as a fraction of the exponential delay.
    """

    def __init__(self, base_delay: float = 1.0, multiplier: float = 2.0, max_delay: float = 60.0, jitter_range: float = 0.5):
        super().__init__(base_delay, multiplier, max_delay)
        self.jitter_range = jitter_range

    def get_delay(self, attempt: int) -> float:
        """Return a jittered delay in seconds that is at most ``max_delay``."""
        delay = super().get_delay(attempt)
        jitter = delay * self.jitter_range * random.random()
        if delay + jitter <= self.max_delay:
            return delay + jitter
        # Reflect the jitter below the cap rather than clamping onto it.
        return max(delay - jitter, 0.0)

Configurable Retry Handler

from dataclasses import dataclass
from typing import Callable, List, Type
import asyncio

@dataclass
class RetryConfig:
    """Settings consumed by the retry helpers in this module.

    ``strategy`` and ``retryable_exceptions`` are filled with defaults in
    ``__post_init__`` when left as ``None``.

    Raises:
        ValueError: if ``max_attempts`` is less than 1.
    """

    # Total number of calls to make (first try plus retries); must be >= 1.
    max_attempts: int = 3
    # Delay policy between attempts; defaults to ExponentialWithJitter().
    strategy: "Optional[RetryStrategy]" = None
    # Exception types that trigger a retry; defaults to [Exception] (retry everything).
    retryable_exceptions: Optional[List[Type[Exception]]] = None
    # Optional callback(attempt, error, delay); RetryHandler calls it before each
    # backoff sleep, with ``attempt`` being the zero-based failed attempt.
    on_retry: Optional[Callable] = None
    # Optional callback(error); RetryHandler calls it when the final attempt fails.
    on_failure: Optional[Callable] = None

    def __post_init__(self):
        # Fail fast: with zero attempts the handlers would have no error to re-raise.
        if self.max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1, got {self.max_attempts}")
        if self.strategy is None:
            self.strategy = ExponentialWithJitter()
        if self.retryable_exceptions is None:
            self.retryable_exceptions = [Exception]

class RetryHandler:
    """Execute an async callable, retrying failures that match a RetryConfig."""

    def __init__(self, config: "RetryConfig"):
        self.config = config

    async def execute(self, func: Callable, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying on configured exceptions.

        Exceptions that are not instances of ``config.retryable_exceptions``
        propagate immediately. Before each retry, the handler calls
        ``config.on_retry(attempt, error, delay)`` when set, where ``attempt``
        is the zero-based failed attempt. It then sleeps for
        ``config.strategy.get_delay(attempt)`` seconds. When the last attempt
        fails, it calls ``config.on_failure(error)`` when set and re-raises
        that error.

        Returns:
            The value returned by the first successful call.

        Raises:
            ValueError: if ``config.max_attempts`` is less than 1.
            Exception: the last retryable error, or the first non-retryable one.
        """
        max_attempts = self.config.max_attempts
        if max_attempts < 1:
            # Without at least one attempt there is no error to re-raise.
            raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
        retryable = tuple(self.config.retryable_exceptions)

        for attempt in range(max_attempts):
            try:
                return await func(*args, **kwargs)
            except retryable as exc:
                if attempt == max_attempts - 1:
                    if self.config.on_failure:
                        self.config.on_failure(exc)
                    raise
                delay = self.config.strategy.get_delay(attempt)
                if self.config.on_retry:
                    self.config.on_retry(attempt, exc, delay)
            # Sleep outside the except block so a cancellation during backoff
            # is not reported as happening "during handling of" the last error.
            await asyncio.sleep(delay)

Context-Aware Retry

class ContextAwareRetry:
    """Retry an async operation while recording context about every attempt.

    Each attempt appends a dict to ``attempt_history``. The dict holds the
    operation id, the 1-based attempt number, ``max_attempts``, timezone-aware
    UTC ISO-8601 ``started_at``/``completed_at`` timestamps, and ``success``.
    Failed attempts also carry ``error``/``error_type`` and, when a retry
    follows, ``retry_delay``. The history is append-only and never trimmed, so
    a long-lived instance keeps every record.
    """

    # Lower-cased substrings that mark an error as transient even when its
    # type is not in ``config.retryable_exceptions``. Subclasses may override.
    retryable_patterns = ("rate limit", "timeout", "connection", "503", "429")

    def __init__(self):
        self.attempt_history = []

    async def execute_with_context(
        self,
        func: Callable,
        context: dict,
        config: "RetryConfig"
    ):
        """Await ``func()`` up to ``config.max_attempts`` times, recording each try.

        Args:
            func: zero-argument coroutine function to invoke on every attempt.
            context: caller metadata. Only ``"operation_id"`` is read, and a
                random UUID string is generated when that key is absent.
            config: retry settings. Reads ``max_attempts``, ``strategy``,
                ``retryable_exceptions`` and the optional ``on_retry`` /
                ``on_failure`` callbacks.

        Returns:
            The result of the first successful ``func()`` call.

        Raises:
            ValueError: if ``config.max_attempts`` is less than 1.
            Exception: the error from ``func()`` when it is not retryable or
                when the final attempt fails.
        """
        max_attempts = config.max_attempts
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

        # Only generate a UUID when the caller did not supply an id.
        if "operation_id" in context:
            operation_id = context["operation_id"]
        else:
            operation_id = str(uuid.uuid4())

        for attempt in range(max_attempts):
            attempt_context = {
                "operation_id": operation_id,
                "attempt": attempt + 1,
                "max_attempts": max_attempts,
                "started_at": datetime.now(timezone.utc).isoformat()
            }

            try:
                result = await func()
            except Exception as e:
                attempt_context["error"] = str(e)
                attempt_context["error_type"] = type(e).__name__

                # Non-retryable errors propagate immediately.
                if not self._is_retryable(e, config):
                    self._record_attempt(attempt_context, success=False)
                    raise

                if attempt == max_attempts - 1:
                    self._record_attempt(attempt_context, success=False)
                    if config.on_failure:
                        config.on_failure(e)
                    raise

                delay = self._calculate_adaptive_delay(e, attempt, config)
                attempt_context["retry_delay"] = delay
                self._record_attempt(attempt_context, success=False)
                if config.on_retry:
                    config.on_retry(attempt, e, delay)
            else:
                self._record_attempt(attempt_context, success=True)
                return result

            # Back off outside the except block so a cancellation during the
            # sleep is not chained to the error that triggered the retry.
            await asyncio.sleep(delay)

    def _is_retryable(self, error: Exception, config: "RetryConfig") -> bool:
        """Return True if ``error`` is a configured type or its message looks transient."""
        if isinstance(error, tuple(config.retryable_exceptions)):
            return True

        error_msg = str(error).lower()
        return any(pattern in error_msg for pattern in self.retryable_patterns)

    def _calculate_adaptive_delay(self, error: Exception, attempt: int, config: "RetryConfig") -> float:
        """Return the backoff for ``attempt``, adjusted by hints on ``error``.

        A usable numeric ``error.retry_after`` wins when it exceeds the
        strategy's delay. Rate-limit messages double the strategy's delay.
        """
        base_delay = config.strategy.get_delay(attempt)

        # Some clients set retry_after to None or to a header string; only a
        # value convertible to float is treated as a server hint.
        retry_after = getattr(error, "retry_after", None)
        if retry_after is not None:
            try:
                return max(float(retry_after), base_delay)
            except (TypeError, ValueError):
                pass  # unusable hint: fall back to the rules below

        if "rate limit" in str(error).lower():
            return base_delay * 2

        return base_delay

    def _record_attempt(self, context: dict, success: bool):
        """Stamp ``context`` with the outcome and completion time, then store it."""
        context["success"] = success
        context["completed_at"] = datetime.now(timezone.utc).isoformat()
        self.attempt_history.append(context)

Circuit Breaker Integration

from datetime import datetime, timedelta
from enum import Enum

class CircuitState(Enum):
    """Lifecycle states of a circuit breaker."""

    # Requests flow normally.
    CLOSED = "closed"
    # Too many failures: requests are rejected.
    OPEN = "open"
    # Probing whether the dependency has recovered.
    HALF_OPEN = "half_open"

class CircuitBreaker:
    """Track failures and stop calling a dependency that keeps failing.

    States (``CircuitState``):

    - CLOSED: calls are allowed. ``failure_threshold`` failures with no
      success in between open the circuit.
    - OPEN: calls are rejected until more than ``recovery_timeout`` seconds
      have passed since the last recorded failure. The breaker then moves to
      HALF_OPEN.
    - HALF_OPEN: calls are allowed while fewer than ``half_open_max_calls``
      successes have been recorded. Reaching that many successes closes the
      circuit, and any failure reopens it.

    Timestamps come from ``datetime.utcnow()`` (naive UTC).

    NOTE(review): ``half_open_calls`` only advances in ``record_success``, so
    ``can_execute`` does not reserve a probe slot. Concurrent callers in
    HALF_OPEN can all be admitted before any of them reports back. Confirm
    whether that matters at your concurrency level.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_calls: int = 3
    ):
        # Configuration
        self.failure_threshold = failure_threshold
        self.recovery_timeout = timedelta(seconds=recovery_timeout)
        self.half_open_max_calls = half_open_max_calls

        # Mutable breaker state
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.last_failure_time = None
        self.half_open_calls = 0

    def can_execute(self) -> bool:
        """Return whether a call may proceed (may move OPEN -> HALF_OPEN)."""
        current = self.state

        if current == CircuitState.OPEN:
            waited = datetime.utcnow() - self.last_failure_time
            if waited <= self.recovery_timeout:
                return False
            # Cool-down elapsed: start probing.
            self.state = CircuitState.HALF_OPEN
            self.half_open_calls = 0
            return True

        if current == CircuitState.HALF_OPEN:
            return self.half_open_calls < self.half_open_max_calls

        return current == CircuitState.CLOSED

    def record_success(self):
        """Count a success: clears failures when CLOSED, advances probing when HALF_OPEN."""
        if self.state == CircuitState.CLOSED:
            self.failure_count = 0
            return
        if self.state != CircuitState.HALF_OPEN:
            return  # successes reported while OPEN are ignored

        self.half_open_calls += 1
        if self.half_open_calls >= self.half_open_max_calls:
            self.state = CircuitState.CLOSED
            self.failure_count = 0

    def record_failure(self):
        """Count a failure, opening the circuit on a failed probe or at the threshold."""
        self.failure_count += 1
        self.last_failure_time = datetime.utcnow()

        probe_failed = self.state == CircuitState.HALF_OPEN
        if probe_failed or self.failure_count >= self.failure_threshold:
            self.state = CircuitState.OPEN

class CircuitOpenError(Exception):
    """Raised when a call is rejected because the circuit breaker is open."""


class RetryWithCircuitBreaker:
    """Guard a RetryHandler with a CircuitBreaker.

    The breaker is consulted once per ``execute`` call. The whole retry
    sequence then counts as a single success or failure for the breaker.
    """

    def __init__(self, circuit_breaker: "CircuitBreaker", retry_config: "RetryConfig"):
        self.circuit = circuit_breaker
        self.retry_handler = RetryHandler(retry_config)

    async def execute(self, func: Callable, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` with retries unless the circuit is open.

        Returns:
            The result of the retry handler's successful call.

        Raises:
            CircuitOpenError: if the breaker rejects the call. It subclasses
                ``Exception`` and keeps the original message, so existing
                ``except Exception`` handlers still work.
            Exception: whatever the retry handler finally raises; it is
                recorded as a breaker failure first.
        """
        if not self.circuit.can_execute():
            raise CircuitOpenError("Circuit breaker is open")

        try:
            result = await self.retry_handler.execute(func, *args, **kwargs)
        except Exception:
            self.circuit.record_failure()
            raise
        self.circuit.record_success()
        return result

Usage Example

# Configure retry for OpenAI calls: up to 3 attempts with jittered
# exponential backoff (max 30s), retrying only transient API errors.
def _log_retry(attempt, error, delay):
    """Log a scheduled retry; ``attempt`` is the zero-based failed attempt."""
    logger.warning(
        f"Retry {attempt + 1}: {error}. Waiting {delay:.2f}s"
    )


retry_config = RetryConfig(
    max_attempts=3,
    strategy=ExponentialWithJitter(base_delay=1.0, multiplier=2.0, max_delay=30.0),
    retryable_exceptions=[RateLimitError, APITimeoutError, APIConnectionError],
    on_retry=_log_retry,
)

# Open after 5 failed executions (each counted once its retries are exhausted);
# allow probing again once more than 60 seconds have passed.
circuit_breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=60)
resilient_executor = RetryWithCircuitBreaker(circuit_breaker, retry_config)

async def call_openai(messages: list, model: str = "gpt-4"):
    """Send a chat-completion request through the retry + circuit-breaker executor.

    Args:
        messages: chat messages forwarded unchanged to
            ``client.chat.completions.create``.
        model: model name to request; defaults to "gpt-4" as before.

    Returns:
        Whatever ``client.chat.completions.create`` returns.
    """
    async def _call():
        return await client.chat.completions.create(
            model=model,
            messages=messages
        )

    # Pass a zero-argument coroutine function (not a coroutine object) so that
    # every retry awaits a fresh request.
    return await resilient_executor.execute(_call)

Combining backoff strategies, error-aware retry decisions, and a circuit breaker keeps AI applications reliable under a wide range of failure conditions. Tomorrow, I will cover fallback patterns.

Resources

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.