1 min read
Error Handling in LLM Applications: A Comprehensive Guide
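This guide collects practical, production-minded patterns for handling errors in LLM applications: a typed error taxonomy, centralized retries with exponential backoff, graceful degradation to fallback models, and validation of both prompts and model output.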
Error Taxonomy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import field_validator
class ErrorCategory(Enum):
    """Kind of failure an LLM call ran into.

    Each member's value is a short snake_case string.
    """

    # Errors reported by the OpenAI/Azure API
    API_ERROR = "api_error"
    # Rate limiting
    RATE_LIMIT = "rate_limit"
    # Content moderation
    CONTENT_FILTER = "content_filter"
    # Token / context-window limits
    CONTEXT_LENGTH = "context_length"
    # Input/output validation
    VALIDATION = "validation"
    # Request timeout
    TIMEOUT = "timeout"
    # Tool execution failure
    TOOL_ERROR = "tool_error"
    # Response parsing failure
    PARSING = "parsing"
    # Network connectivity
    NETWORK = "network"
class ErrorSeverity(Enum):
    """Severity of an LLM error, from least to most severe.

    Intended reaction per level:

    * LOW      -- log and continue
    * MEDIUM   -- retry with backoff
    * HIGH     -- alert and retry
    * CRITICAL -- stop and escalate
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
@dataclass(eq=False)
class LLMError(Exception):
    """Base class for LLM-related errors.

    Declared as a dataclass so subclasses can pass every field by keyword.
    ``eq=False`` keeps the identity-based ``__eq__``/``__hash__`` inherited
    from ``Exception``. With the dataclass default (``eq=True`` on a
    non-frozen class), ``__hash__`` is set to ``None``, which would make
    every error unhashable (unusable in sets or as dict keys). It would also
    make two distinct errors with equal fields compare equal.
    """

    category: ErrorCategory  # broad failure class; its value prefixes __str__
    severity: ErrorSeverity  # how serious the failure is
    message: str  # human-readable summary
    details: Optional[Dict[str, Any]] = None  # extra structured context, if any
    retryable: bool = True  # whether a retry of the call is allowed
    retry_after: Optional[float] = None  # suggested wait in seconds, if known

    def __str__(self):
        return f"[{self.category.value}] {self.message}"
# Specific error classes
class RateLimitError(LLMError):
    """The provider rejected the call for exceeding its rate limit.

    Always marked retryable; ``retry_after`` is the wait in seconds to use
    before retrying (60 unless the caller supplies another value).
    """

    def __init__(self, retry_after: float = 60):
        fields = dict(
            category=ErrorCategory.RATE_LIMIT,
            severity=ErrorSeverity.MEDIUM,
            message="Rate limit exceeded",
            retryable=True,
            retry_after=retry_after,
        )
        super().__init__(**fields)
class ContentFilterError(LLMError):
    """Content was blocked by a moderation / content filter.

    Not retryable. The triggering filter name is echoed in the message and
    stored under ``details["filter_type"]``.
    """

    def __init__(self, filter_type: str):
        super().__init__(
            message=f"Content filtered: {filter_type}",
            details={"filter_type": filter_type},
            category=ErrorCategory.CONTENT_FILTER,
            severity=ErrorSeverity.HIGH,
            retryable=False,
        )
class ContextLengthError(LLMError):
    """A request does not fit in the model's context window.

    Args:
        tokens_used: Tokens the request needs.
        max_tokens: Tokens actually available for the request.

    Not retryable. Re-sending the identical request fails the same way, so an
    automatic retry loop would only burn attempts and backoff time. The
    caller has to shrink the prompt (or the completion budget) before trying
    again.
    """

    def __init__(self, tokens_used: int, max_tokens: int):
        super().__init__(
            category=ErrorCategory.CONTEXT_LENGTH,
            severity=ErrorSeverity.MEDIUM,
            message=f"Context length exceeded: {tokens_used}/{max_tokens}",
            retryable=False,
            details={"tokens_used": tokens_used, "max_tokens": max_tokens},
        )
Error Handler
from openai import OpenAI, APIError, RateLimitError as OpenAIRateLimitError
from openai import APIConnectionError, APITimeoutError
import time
import logging
logger = logging.getLogger(__name__)


class LLMErrorHandler:
    """Centralized error handling for LLM operations.

    ``handle`` calls a function and retries the OpenAI failures treated as
    transient: rate limits, timeouts, connection errors, and ``APIError``s
    classified as retryable. Retries use exponential backoff. When the
    handler gives up, it raises an ``LLMError`` chained to the original
    OpenAI exception.
    """

    def __init__(self, max_retries: int = 3, base_delay: float = 1.0):
        """Configure retry behaviour.

        Args:
            max_retries: Retries after the first attempt, so the wrapped
                function is called at most ``max_retries + 1`` times.
            base_delay: Backoff base in seconds; retry ``n`` (0-based) waits
                ``base_delay * 2**n`` plus up to 10% random jitter.
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        # Per-category count of errors passed to _record_error.
        self.error_counts: Dict[ErrorCategory, int] = {}

    def handle(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with retries and error translation.

        Returns:
            The value returned by the first successful call.

        Raises:
            LLMError: when retries are exhausted, or immediately for an
                ``APIError`` classified as non-retryable. The OpenAI
                exception that caused it is attached as ``__cause__``.
            Exception: any other exception raised by ``func`` is logged and
                re-raised unchanged, without retrying.
        """
        last_error: Optional[LLMError] = None
        last_cause: Optional[BaseException] = None
        for attempt in range(self.max_retries + 1):
            will_retry = attempt < self.max_retries
            try:
                return func(*args, **kwargs)
            except OpenAIRateLimitError as e:
                last_error, last_cause = self._classify_openai_error(e), e
                if will_retry:
                    # Prefer the server's Retry-After hint over our own backoff.
                    delay = last_error.retry_after or self._calculate_delay(attempt)
                    logger.warning(f"Rate limited. Retrying in {delay}s...")
                    time.sleep(delay)
            except APITimeoutError as e:
                last_error = LLMError(
                    category=ErrorCategory.TIMEOUT,
                    severity=ErrorSeverity.MEDIUM,
                    message="Request timed out",
                    retryable=True,
                )
                last_cause = e
                if will_retry:
                    self._sleep_before_retry(attempt, "Timeout")
            except APIConnectionError as e:
                last_error = LLMError(
                    category=ErrorCategory.NETWORK,
                    severity=ErrorSeverity.MEDIUM,
                    message="Connection failed",
                    retryable=True,
                )
                last_cause = e
                if will_retry:
                    self._sleep_before_retry(attempt, "Connection error")
            except APIError as e:
                error = self._classify_openai_error(e)
                self._record_error(error)
                if not (error.retryable and will_retry):
                    raise error from e
                time.sleep(self._calculate_delay(attempt))
            except Exception as e:
                # Not an OpenAI error: surface it untouched, no retry.
                logger.error(f"Unexpected error: {e}")
                raise
        # All retries exhausted
        if last_error:
            self._record_error(last_error)
            raise last_error from last_cause

    def _sleep_before_retry(self, attempt: int, label: str) -> None:
        """Log ``"<label>. Retrying in <delay>s..."`` and sleep for the backoff delay."""
        delay = self._calculate_delay(attempt)
        logger.warning(f"{label}. Retrying in {delay}s...")
        time.sleep(delay)

    def _classify_openai_error(self, error: APIError) -> LLMError:
        """Map an OpenAI ``APIError`` onto the LLMError taxonomy.

        Checks, in order:
        HTTP 429 -> RateLimitError (wait taken from Retry-After);
        code ``context_length_exceeded`` -> ContextLengthError (token counts
        are not known here, so 0/0 is reported);
        code ``content_filter`` -> ContentFilterError;
        anything else -> generic API_ERROR, retryable only for HTTP
        500/502/503/504.
        """
        error_code = getattr(error, 'code', None)
        status_code = getattr(error, 'status_code', None)
        if status_code == 429:
            return RateLimitError(retry_after=self._parse_retry_after(error))
        if error_code == 'context_length_exceeded':
            return ContextLengthError(tokens_used=0, max_tokens=0)
        if error_code == 'content_filter':
            return ContentFilterError(filter_type="unknown")
        return LLMError(
            category=ErrorCategory.API_ERROR,
            severity=ErrorSeverity.HIGH,
            message=str(error),
            retryable=status_code in (500, 502, 503, 504),
        )

    @staticmethod
    def _parse_retry_after(error, default: float = 60.0) -> float:
        """Return the ``retry-after`` header of ``error.response`` in seconds.

        Falls back to ``default`` when there is no response or header, or when
        the header is not a finite, non-negative number of seconds. That
        includes the HTTP-date form of Retry-After, which is not parsed here.
        """
        response = getattr(error, "response", None)
        headers = getattr(response, "headers", None)
        raw = headers.get("retry-after") if headers is not None else None
        if raw is None:
            return default
        try:
            seconds = float(raw)
        except (TypeError, ValueError):
            return default
        # Negative, NaN or infinite values are not usable as a sleep duration.
        return seconds if 0 <= seconds < float("inf") else default

    def _calculate_delay(self, attempt: int) -> float:
        """Return ``base_delay * 2**attempt`` plus uniform jitter of up to 10%."""
        import random

        delay = self.base_delay * (2 ** attempt)
        # Jitter spreads out retries from concurrent callers.
        jitter = random.uniform(0, delay * 0.1)
        return delay + jitter

    def _record_error(self, error: LLMError):
        """Increment the per-category counter and log the error at ERROR level."""
        self.error_counts[error.category] = self.error_counts.get(error.category, 0) + 1
        logger.error(f"LLM Error: {error}", extra={
            "category": error.category.value,
            "severity": error.severity.value,
            "details": error.details,
        })
Graceful Degradation
from typing import Callable, TypeVar, Generic
T = TypeVar('T')


class GracefulDegradation(Generic[T]):
    """Run a primary callable, falling back to alternatives when it raises.

    Callables are tried in order: ``primary`` first, then each fallback in
    the order it was added. The first one that returns wins. If every one
    raises ``Exception``, the configured default is returned. If no default
    is configured, a ``RuntimeError`` is raised, chained to the last failure.
    Exceptions outside ``Exception`` (e.g. ``KeyboardInterrupt``) are not
    caught.
    """

    def __init__(self, primary: Callable[..., T]):
        self.primary = primary
        self.fallbacks: List[Callable[..., T]] = []
        self.default: Optional[T] = None
        # True once set_default() is called, so an explicit default of None is
        # honoured rather than being mistaken for "no default".
        self._has_default = False

    def add_fallback(self, fallback: Callable[..., T]) -> 'GracefulDegradation':
        """Append a fallback callable; returns ``self`` for chaining."""
        self.fallbacks.append(fallback)
        return self

    def set_default(self, default: T) -> 'GracefulDegradation':
        """Set the value returned when every callable fails; returns ``self``."""
        self.default = default
        self._has_default = True
        return self

    def execute(self, *args, **kwargs) -> T:
        """Call the chain with ``*args``/``**kwargs`` and return the first success.

        Raises:
            RuntimeError: if every callable raised and no default is set; the
                last failure is attached as ``__cause__``.
        """
        last_exc: Optional[Exception] = None
        # Try primary
        try:
            return self.primary(*args, **kwargs)
        except Exception as e:
            logger.warning(f"Primary failed: {e}")
            last_exc = e
        # Try fallbacks (numbered from 1 in the logs)
        for i, fallback in enumerate(self.fallbacks, start=1):
            try:
                logger.info(f"Trying fallback {i}")
                return fallback(*args, **kwargs)
            except Exception as e:
                logger.warning(f"Fallback {i} failed: {e}")
                last_exc = e
        # Return default. A non-None `default` assigned directly (without
        # set_default) is still honoured, as before.
        if self._has_default or self.default is not None:
            logger.warning("All options failed, returning default")
            return self.default
        raise RuntimeError("All execution paths failed") from last_exc
# Usage example
# Shared OpenAI client used by the example model calls below.
client = OpenAI()


def _complete(model: str, prompt: str) -> str:
    """Send ``prompt`` as a single user message to ``model`` and return the reply text.

    Raises:
        LLMError: (PARSING, retryable) if the first choice carries no text
            content. Raising here means callers such as GracefulDegradation
            move on to a fallback instead of receiving ``None`` from a
            function typed ``-> str``.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    content = response.choices[0].message.content
    if content is None:
        raise LLMError(
            category=ErrorCategory.PARSING,
            severity=ErrorSeverity.MEDIUM,
            message=f"No text content in {model} response",
            retryable=True,
        )
    return content


def call_gpt4(prompt: str) -> str:
    """Primary model: gpt-4o."""
    return _complete("gpt-4o", prompt)


def call_gpt4_mini(prompt: str) -> str:
    """Faster, cheaper fallback: gpt-4o-mini."""
    return _complete("gpt-4o-mini", prompt)
def cached_response(prompt: str) -> str:
    """Last-resort fallback that returns a fixed apology.

    Placeholder for a real cache lookup: as written, ``prompt`` is ignored
    and no cache is consulted.
    """
    apology = "I'm having trouble processing your request. Please try again later."
    return apology
# Build the degradation chain step by step: the gpt-4o call first, then the
# gpt-4o-mini call, then the canned reply, and finally a fixed default string.
degraded_call = GracefulDegradation(call_gpt4)
degraded_call.add_fallback(call_gpt4_mini)
degraded_call.add_fallback(cached_response)
degraded_call.set_default("Service temporarily unavailable")

result = degraded_call.execute("Hello, world!")
Input Validation
from pydantic import BaseModel, validator, ValidationError
import tiktoken
class ValidatedPrompt(BaseModel):
    """Validate prompts before sending them to an LLM.

    Attributes:
        content: Prompt text; stripped of surrounding whitespace on validation.
        max_tokens: Tokens reserved for the completion; subtracted from the
            context window in ``validate_token_limit``.
        model: Model name, used to choose the tiktoken encoding.
    """

    content: str
    max_tokens: int = 4096
    model: str = "gpt-4o"

    @field_validator('content')
    @classmethod
    def validate_content(cls, v):
        """Reject empty or oversized prompts and a few known injection phrases.

        The phrase list is a cheap heuristic, not a security boundary: any
        rewording bypasses it.
        """
        if not v or not v.strip():
            raise ValueError("Prompt cannot be empty")
        if len(v) > 100000:
            raise ValueError("Prompt too long")
        # Check for prompt injection patterns
        injection_patterns = [
            "ignore previous instructions",
            "disregard all prior",
            "forget everything",
        ]
        lower_content = v.lower()
        for pattern in injection_patterns:
            if pattern in lower_content:
                raise ValueError("Potentially malicious prompt detected")
        return v.strip()

    def get_token_count(self) -> int:
        """Count tokens in ``content`` with the model's tiktoken encoding.

        Unknown model names fall back to ``o200k_base``, which is only an
        approximation for models that use another encoding.
        """
        try:
            encoding = tiktoken.encoding_for_model(self.model)
        except KeyError:
            # encoding_for_model raises KeyError for names it cannot map.
            encoding = tiktoken.get_encoding("o200k_base")
        # disallowed_special=() counts text such as "<|endoftext|>" as plain
        # text instead of raising ValueError on untrusted user input.
        return len(encoding.encode(self.content, disallowed_special=()))

    def validate_token_limit(self, context_window: int = 128000):
        """Ensure the prompt plus the ``max_tokens`` reservation fits in ``context_window``.

        Raises:
            ContextLengthError: if the prompt needs more tokens than remain.
        """
        tokens = self.get_token_count()
        available = context_window - self.max_tokens
        if tokens > available:
            raise ContextLengthError(
                tokens_used=tokens,
                max_tokens=available,
            )
def validate_and_call(content: str) -> str:
    """Validate ``content`` and, if it passes, send it to the LLM.

    Raises:
        LLMError: (VALIDATION, not retryable) when ``content`` fails
            ValidatedPrompt's checks; the pydantic ValidationError is chained
            as ``__cause__``.
        ContextLengthError: when the prompt does not fit the context window.
    """
    # Only prompt construction can mean "invalid input". A ValidationError
    # raised later (e.g. inside call_llm) must not be misreported as one.
    try:
        prompt = ValidatedPrompt(content=content)
    except ValidationError as e:
        raise LLMError(
            category=ErrorCategory.VALIDATION,
            severity=ErrorSeverity.LOW,
            message="Invalid input",
            details={"errors": e.errors()},
            retryable=False,
        ) from e
    prompt.validate_token_limit()
    # NOTE(review): call_llm is not defined in this file; presumably the
    # caller's own LLM wrapper — confirm.
    return call_llm(prompt.content)
Output Validation
import json
class OutputValidator:
    """Validate LLM outputs.

    Offers JSON parsing, Pydantic schema conformance, and a basic
    pattern-based safety screen. Each check raises ``LLMError`` on failure.
    """

    def validate_json(self, output: str) -> dict:
        """Parse ``output`` as a JSON object.

        Raises:
            LLMError: (PARSING, retryable) if ``output`` is not valid JSON, or
                if it decodes to something other than an object (e.g. a list
                or a bare string). The first 500 characters of ``output`` are
                kept in ``details``.
        """
        try:
            parsed = json.loads(output)
        except json.JSONDecodeError as e:
            raise LLMError(
                category=ErrorCategory.PARSING,
                severity=ErrorSeverity.MEDIUM,
                message="Invalid JSON in response",
                details={"output": output[:500], "error": str(e)},
                retryable=True,
            ) from e
        if not isinstance(parsed, dict):
            # Honour the declared `-> dict` contract that validate_schema relies on.
            raise LLMError(
                category=ErrorCategory.PARSING,
                severity=ErrorSeverity.MEDIUM,
                message="Expected a JSON object in response",
                details={"output": output[:500], "type": type(parsed).__name__},
                retryable=True,
            )
        return parsed

    def validate_schema(self, output: dict, schema: type) -> BaseModel:
        """Validate ``output`` against the Pydantic model class ``schema``.

        Raises:
            LLMError: (VALIDATION, retryable) with the pydantic error list in
                ``details``; the ValidationError is chained as ``__cause__``.
        """
        try:
            return schema.model_validate(output)
        except ValidationError as e:
            raise LLMError(
                category=ErrorCategory.VALIDATION,
                severity=ErrorSeverity.MEDIUM,
                message="Output doesn't match schema",
                details={"errors": e.errors()},
                retryable=True,
            ) from e

    def validate_safety(self, output: str) -> str:
        """Screen ``output`` for a few dangerous substrings; return it unchanged if clean.

        This is in addition to OpenAI's content filtering. Matching is
        case-insensitive, because SQL keywords and HTML tag names are, so
        "drop table" and "<SCRIPT>" are caught. "<script" (no ">") also
        catches tags that carry attributes, such as ``<script src=...>``.
        This remains a coarse heuristic, not a sanitizer.

        Raises:
            LLMError: (CONTENT_FILTER, CRITICAL, not retryable) naming the
                matched pattern in ``details["pattern"]``.
        """
        dangerous_patterns = [
            "sudo rm",
            "DROP TABLE",
            "<script",
        ]
        haystack = output.lower()
        for pattern in dangerous_patterns:
            if pattern.lower() in haystack:
                raise LLMError(
                    category=ErrorCategory.CONTENT_FILTER,
                    severity=ErrorSeverity.CRITICAL,
                    message="Potentially dangerous content in output",
                    details={"pattern": pattern},
                    retryable=False,
                )
        return output
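Robust error handling is what separates prototypes from production systems. Invest in comprehensive error handling early to avoid painful debugging later.

## Takeaways

- Classify failures (rate limits, timeouts, content filters, context length, parsing) so each one gets an explicit severity and retry policy.
- Retry only transient errors, using exponential backoff with jitter and honoring the server's `Retry-After` hint.
- Degrade gracefully: fall back to a cheaper model or a canned response rather than failing outright.
- Validate prompts before spending tokens, and validate model output before trusting it.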