Back to Blog
6 min read

Retry Strategies for AI Applications: Beyond Simple Backoff

Retry strategies for AI applications need to be smarter than simple exponential backoff. Let’s explore advanced patterns that improve reliability without wasting resources.

Intelligent Retry Framework

import copy
import logging
import random
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Optional, List, Any

# Module-level logger; handlers/levels are expected to be configured by the application.
logger = logging.getLogger(__name__)

class RetryDecision(Enum):
    """Outcome of a retry decision handler."""
    RETRY = "retry"                        # try again after a backoff delay
    FAIL = "fail"                          # give up and re-raise the error
    MODIFY_AND_RETRY = "modify_and_retry"  # retry after the request is adjusted

@dataclass
class RetryConfig:
    """Configuration for retry behavior.

    Delays follow ``base_delay * exponential_base ** (attempt - 1)``,
    capped at ``max_delay``, with optional jitter.
    """
    max_attempts: int = 3          # total attempts, including the first call
    base_delay: float = 1.0        # seconds before the first retry
    max_delay: float = 60.0        # upper bound on any single delay
    exponential_base: float = 2.0  # backoff multiplier per attempt
    jitter: bool = True            # scale delay by a random factor in [0.5, 1.5)
    retry_on: List[type] = field(default_factory=list)  # empty = no type allow-list

@dataclass
class RetryContext:
    """Context for making retry decisions"""
    attempt: int                      # 1-based number of the attempt that just failed
    error: Exception                  # error raised by the current attempt
    elapsed_time: float               # seconds since the first attempt started
    previous_errors: List[Exception]  # errors from earlier attempts (excludes current)

class SmartRetry:
    """Intelligent retry with adaptive behavior.

    Custom decision handlers are consulted first; the first handler that
    returns something other than RETRY decides the outcome. Otherwise the
    default policy applies: programming errors (ValueError, TypeError,
    KeyError) are never retried, and when ``config.retry_on`` is non-empty
    only the listed error types are retried.
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """Create a retrier; a default RetryConfig is used when none is given."""
        self.config = config or RetryConfig()
        # Handlers run in registration order inside _make_decision.
        self.decision_handlers: List[Callable[[RetryContext], RetryDecision]] = []

    def add_decision_handler(self, handler: Callable[[RetryContext], RetryDecision]):
        """Add custom logic for retry decisions. Returns self for chaining."""
        self.decision_handlers.append(handler)
        return self

    def execute(self, func: Callable, *args, **kwargs) -> Any:
        """Execute ``func(*args, **kwargs)`` with smart retry logic.

        Returns func's result on success. Re-raises the current error when
        a handler decides FAIL, or the last error once max_attempts is
        exhausted.
        """
        errors: List[Exception] = []
        start_time = time.time()

        for attempt in range(1, self.config.max_attempts + 1):
            try:
                return func(*args, **kwargs)

            except Exception as e:
                errors.append(e)

                context = RetryContext(
                    attempt=attempt,
                    error=e,
                    elapsed_time=time.time() - start_time,
                    previous_errors=errors[:-1],  # exclude the current error
                )

                decision = self._make_decision(context)

                if decision == RetryDecision.FAIL:
                    # Lazy %-args: formatting only happens if the record is emitted.
                    logger.error("Failing after %d attempts", attempt)
                    raise

                if decision == RetryDecision.MODIFY_AND_RETRY:
                    # Placeholder: request modification is expected to happen
                    # externally (e.g. via state the caller's closure mutates).
                    pass

                if attempt < self.config.max_attempts:
                    delay = self._calculate_delay(attempt, context)
                    logger.info("Attempt %d failed. Retrying in %.2fs", attempt, delay)
                    time.sleep(delay)

        # All attempts exhausted without a FAIL decision: surface the last error.
        raise errors[-1] if errors else RuntimeError("Retry failed")

    def _make_decision(self, context: RetryContext) -> RetryDecision:
        """Determine whether to retry, consulting custom handlers first."""
        for handler in self.decision_handlers:
            decision = handler(context)
            if decision != RetryDecision.RETRY:
                # The first non-RETRY verdict wins.
                return decision

        error = context.error

        # Programming errors: retrying the identical call cannot succeed.
        if isinstance(error, (ValueError, TypeError, KeyError)):
            return RetryDecision.FAIL

        # A non-empty allow-list restricts retries to the listed types.
        if self.config.retry_on and not any(
            isinstance(error, t) for t in self.config.retry_on
        ):
            return RetryDecision.FAIL

        return RetryDecision.RETRY

    def _calculate_delay(self, attempt: int, context: RetryContext) -> float:
        """Compute the next delay: exponential backoff, cap, jitter, and an
        extra penalty when the same error type keeps recurring."""
        delay = self.config.base_delay * (
            self.config.exponential_base ** (attempt - 1)
        )
        delay = min(delay, self.config.max_delay)

        # Jitter in [0.5, 1.5) desynchronizes concurrent clients
        # (thundering-herd avoidance).
        if self.config.jitter:
            delay *= 0.5 + random.random()

        # Adaptive: once the same error type has occurred at least twice
        # before the current failure, stretch the delay further.
        if context.previous_errors:
            same_error_count = sum(
                1 for prev in context.previous_errors
                if type(prev) is type(context.error)  # `is` for exact type match
            )
            if same_error_count > 1:
                delay *= 1.5

        return delay

LLM-Specific Retry Strategies

from openai import APIError, RateLimitError, APIConnectionError

class LLMRetryStrategy:
    """Retry decision handlers tailored to LLM API error semantics."""

    @staticmethod
    def rate_limit_handler(context: RetryContext) -> RetryDecision:
        """Rate limits: retry unless the server asks for a very long wait."""
        if not isinstance(context.error, RateLimitError):
            return RetryDecision.RETRY
        # Honor a retry-after hint when the provider attaches one.
        wait_hint = getattr(context.error, 'retry_after', None)
        if wait_hint and wait_hint > 120:
            # More than two minutes is too long to block on.
            return RetryDecision.FAIL
        return RetryDecision.RETRY

    @staticmethod
    def context_length_handler(context: RetryContext) -> RetryDecision:
        """Context overflow: the same input can never succeed, so modify it."""
        code = getattr(context.error, 'code', None)
        return (
            RetryDecision.MODIFY_AND_RETRY
            if code == 'context_length_exceeded'
            else RetryDecision.RETRY
        )

    @staticmethod
    def content_filter_handler(context: RetryContext) -> RetryDecision:
        """Content filter: deterministic rejection, retrying is pointless."""
        code = getattr(context.error, 'code', None)
        if code == 'content_filter':
            return RetryDecision.FAIL
        return RetryDecision.RETRY

# Build LLM-optimized retry
def create_llm_retry() -> SmartRetry:
    config = RetryConfig(
        max_attempts=5,
        base_delay=1.0,
        max_delay=60.0,
        retry_on=[APIError, APIConnectionError, TimeoutError]
    )

    retry = SmartRetry(config)
    retry.add_decision_handler(LLMRetryStrategy.rate_limit_handler)
    retry.add_decision_handler(LLMRetryStrategy.context_length_handler)
    retry.add_decision_handler(LLMRetryStrategy.content_filter_handler)

    return retry

# Usage
# Module-level retry policy shared by call_with_retry below.
llm_retry = create_llm_retry()

def call_with_retry(prompt: str) -> str:
    """Send ``prompt`` to the chat API under the LLM-tuned retry policy.

    NOTE(review): ``client`` is not defined in this file — presumably an
    OpenAI client instantiated elsewhere; confirm before use.
    """
    def _call():
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}]
        )
        return response.choices[0].message.content

    return llm_retry.execute(_call)

Request Modification on Retry

class AdaptiveRetry:
    """Retry with automatic request modification.

    Each registered modifier receives the current request dict and the
    error from the failed attempt, and returns the (possibly modified)
    request to use on the next attempt.
    """

    def __init__(self):
        # Modifiers run in registration order after every failed attempt.
        self.modifiers: List[Callable[[dict, Exception], dict]] = []

    def add_modifier(self, modifier: Callable[[dict, Exception], dict]):
        """Add a request modifier for retries. Returns self for chaining."""
        self.modifiers.append(modifier)
        return self

    def execute(self, func: Callable, request: dict, max_attempts: int = 3) -> Any:
        """Execute ``func(**request)``, adapting the request between attempts.

        Raises ValueError for max_attempts < 1; otherwise re-raises the
        last error once attempts are exhausted.
        """
        if max_attempts < 1:
            # Previously the loop silently returned None in this case.
            raise ValueError("max_attempts must be >= 1")

        # Deep copy: modifiers may mutate nested structures (e.g. the
        # messages list), and a shallow copy would corrupt the caller's
        # original request.
        current_request = copy.deepcopy(request)

        for attempt in range(max_attempts):
            try:
                return func(**current_request)

            except Exception as e:
                # Final attempt: propagate the error unchanged.
                if attempt == max_attempts - 1:
                    raise

                # Let every modifier adjust the request for the next try.
                for modifier in self.modifiers:
                    current_request = modifier(current_request, e)

                logger.info("Modified request for retry %d", attempt + 2)
                time.sleep(2 ** attempt)

def truncate_on_context_error(request: dict, error: Exception) -> dict:
    """Halve every user message's content when the error reports that the
    context length was exceeded; other errors leave the request untouched.

    Note: message dicts are modified in place, and the same dict is
    returned for use on the next attempt.
    """
    if getattr(error, 'code', None) != 'context_length_exceeded':
        return request
    for message in request.get('messages', []):
        if message['role'] != 'user':
            continue
        text = message['content']
        message['content'] = text[:len(text) // 2] + "\n[Content truncated]"
    return request

def reduce_max_tokens(request: dict, error: Exception) -> dict:
    """Halve ``max_tokens`` so the next attempt requests a smaller completion.

    Applies on ANY error: the ``error`` argument is accepted only for
    modifier-signature compatibility and is not inspected. The value is
    floored at 1, because repeated halving would otherwise reach 0, which
    is an invalid request parameter. Requests without ``max_tokens`` pass
    through unchanged.
    """
    if 'max_tokens' in request:
        request['max_tokens'] = max(1, request['max_tokens'] // 2)
    return request

def switch_model_on_failure(request: dict, error: Exception) -> dict:
    """Downgrade the request's model along a fixed fallback chain.

    Models without a configured fallback (and requests with no 'model'
    key) are left unchanged.
    """
    fallback_models = {
        'gpt-4o': 'gpt-4o-mini',
        'gpt-4-turbo': 'gpt-4o',
    }
    current_model = request.get('model')
    fallback = fallback_models.get(current_model)
    if fallback is not None:
        request['model'] = fallback
        logger.info(f"Switching model from {current_model} to {request['model']}")
    return request

# Build adaptive retry
# Modifiers run in this order after each failed attempt: truncate oversized
# prompts, halve max_tokens, then fall back to a cheaper model.
adaptive = AdaptiveRetry()
adaptive.add_modifier(truncate_on_context_error)
adaptive.add_modifier(reduce_max_tokens)
adaptive.add_modifier(switch_model_on_failure)

# Usage
# NOTE(review): `client` and `long_prompt` are not defined in this file —
# presumably provided by the surrounding application; confirm before running.
result = adaptive.execute(
    client.chat.completions.create,
    {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": long_prompt}],
        "max_tokens": 4096
    }
)

Retry Budget

from datetime import datetime, timedelta
from collections import defaultdict

class RetryBudget:
    """Limit retries across the application.

    Tracks retry timestamps per key over a sliding one-minute window so
    that bursts of failures cannot trigger unbounded retry storms.

    NOTE(review): not synchronized — concurrent callers may race on the
    underlying lists; confirm single-threaded use.
    """

    def __init__(self, max_retries_per_minute: int = 100):
        self.max_per_minute = max_retries_per_minute
        # key -> timestamps of retries recorded within roughly the last minute
        self.retry_counts: defaultdict = defaultdict(list)

    def _recent(self, key: str) -> List[datetime]:
        """Return retries for key within the last minute, pruning old entries."""
        cutoff = datetime.now() - timedelta(minutes=1)
        self.retry_counts[key] = [t for t in self.retry_counts[key] if t > cutoff]
        return self.retry_counts[key]

    def can_retry(self, key: str = "default") -> bool:
        """Check if retry budget allows another retry for this key."""
        return len(self._recent(key)) < self.max_per_minute

    def record_retry(self, key: str = "default"):
        """Record a retry attempt for this key."""
        self.retry_counts[key].append(datetime.now())

    def get_usage(self, key: str = "default") -> dict:
        """Get current usage stats: used count, limit, and percentage used."""
        used = len(self._recent(key))
        return {
            "used": used,
            "limit": self.max_per_minute,
            # Guard the division: a budget configured as 0 previously raised
            # ZeroDivisionError; report it as fully used instead.
            "percentage": (used / self.max_per_minute) * 100
            if self.max_per_minute
            else 100.0,
        }

# Global retry budget
# Process-wide shared budget instance consumed by retry_with_budget below.
retry_budget = RetryBudget(max_retries_per_minute=50)

def retry_with_budget(func: Callable, *args, **kwargs) -> Any:
    """Execute with retry budget enforcement.

    Note: despite the name, this performs a single attempt — it refuses to
    run when the global budget is exhausted, and on failure it records one
    retry against the budget and re-raises. The actual retry loop is
    expected to live in the caller.
    """

    if not retry_budget.can_retry():
        raise RuntimeError("Retry budget exhausted")

    try:
        return func(*args, **kwargs)
    except Exception as e:
        # Count the failed attempt against the shared budget, then propagate.
        retry_budget.record_retry()
        raise

Smart retry strategies balance reliability with resource efficiency. The key is adapting behavior based on error type, system state, and business requirements.

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.