Rate Limit Management for AI APIs

Effective rate limit management separates AI applications that degrade gracefully under load from those that simply fail. Today we explore strategies for handling API rate limits gracefully: client-side throttling, retries with backoff, multi-deployment load balancing, request prioritisation, circuit breaking, and monitoring.

Rate Limiting Patterns

from functools import wraps
import time

# Token bucket algorithm
class TokenBucket:
    def __init__(self, capacity, refill_rate):
        self.capacity = capacity
        self.tokens = capacity
        self.refill_rate = refill_rate
        self.last_refill = time.time()

    def consume(self, tokens=1):
        # Refill based on elapsed time, then take tokens only if enough are available
        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def _refill(self):
        # Add tokens proportional to elapsed time, capped at capacity
        now = time.time()
        elapsed = now - self.last_refill
        self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
        self.last_refill = now

# Decorator for rate limiting
def rate_limited(bucket):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Block (simple busy-wait) until a token is available
            while not bucket.consume():
                time.sleep(0.1)
            return func(*args, **kwargs)
        return wrapper
    return decorator
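
A quick usage sketch: the bucket below allows bursts of up to 10 requests and refills at one token per second (roughly 60 sustained requests per minute). Here call_model is a stand-in for your actual API call.

# Burst capacity of 10, refilling at 1 token/sec (~60 req/min sustained)
bucket = TokenBucket(capacity=10, refill_rate=1)

@rate_limited(bucket)
def call_model(prompt):
    # Stand-in for the real API call
    return f"response to {prompt}"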

Exponential Backoff with Jitter

import random
import time

from openai import RateLimitError  # raised by the OpenAI SDK on HTTP 429

class ExponentialBackoff:
    def __init__(self, base_delay=1, max_delay=60, max_retries=5):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.max_retries = max_retries

    def calculate_delay(self, attempt):
        # Exponential backoff
        delay = self.base_delay * (2 ** attempt)
        # Cap at max delay
        delay = min(delay, self.max_delay)
        # Add jitter (0.5 to 1.5 multiplier)
        jitter = random.uniform(0.5, 1.5)
        return delay * jitter

    def execute_with_retry(self, func, *args, **kwargs):
        last_error = None
        for attempt in range(self.max_retries):
            try:
                return func(*args, **kwargs)
            except RateLimitError as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    delay = self.calculate_delay(attempt)
                    print(f"Rate limited. Retrying in {delay:.2f}s")
                    time.sleep(delay)
        raise last_error
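
Usage looks like this, assuming an OpenAI-style client object named client already exists (the model name is just an example):

backoff = ExponentialBackoff(base_delay=1, max_delay=60, max_retries=5)

response = backoff.execute_with_retry(
    client.chat.completions.create,  # any callable that may raise RateLimitError
    model="gpt-4o",  # example model name
    messages=[{"role": "user", "content": "Hello"}],
)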

Multi-Deployment Load Balancing

import random
import threading  # needed for the auto-recovery timer below

class LoadBalancer:
    def __init__(self, deployments):
        """
        deployments: list of dicts with 'endpoint', 'key', 'weight'
        """
        self.deployments = deployments
        self.health_status = {d['endpoint']: True for d in deployments}

    def get_deployment(self):
        healthy = [d for d in self.deployments if self.health_status[d['endpoint']]]
        if not healthy:
            # Reset health if all unhealthy
            self.health_status = {d['endpoint']: True for d in self.deployments}
            healthy = self.deployments

        # Weighted random selection
        total_weight = sum(d['weight'] for d in healthy)
        r = random.uniform(0, total_weight)
        cumulative = 0
        for d in healthy:
            cumulative += d['weight']
            if r <= cumulative:
                return d
        return healthy[-1]

    def mark_unhealthy(self, endpoint, duration=60):
        self.health_status[endpoint] = False
        # Auto-recover after duration
        threading.Timer(duration, lambda: self.mark_healthy(endpoint)).start()

    def mark_healthy(self, endpoint):
        self.health_status[endpoint] = True

# Usage
from openai import AzureOpenAI

balancer = LoadBalancer([
    {"endpoint": "https://east-us.openai.azure.com", "key": "key1", "weight": 2},
    {"endpoint": "https://west-us.openai.azure.com", "key": "key2", "weight": 1},
])

deployment = balancer.get_deployment()
client = AzureOpenAI(
    azure_endpoint=deployment["endpoint"],
    api_key=deployment["key"],
    api_version="2024-02-01",  # example API version; required by the client
)
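
The balancer pairs naturally with error handling: when one deployment returns a 429, mark it unhealthy and fail over to the next. A minimal sketch, reusing the RateLimitError import from earlier (the api_version value and model name are examples; on Azure, model is your deployment name):

def call_with_failover(balancer, messages, model="gpt-4o"):
    # Try each deployment at most once; a 429 sidelines that region
    # for 60s via mark_unhealthy, so later calls skip it
    for _ in range(len(balancer.deployments)):
        d = balancer.get_deployment()
        client = AzureOpenAI(
            azure_endpoint=d["endpoint"],
            api_key=d["key"],
            api_version="2024-02-01",
        )
        try:
            return client.chat.completions.create(model=model, messages=messages)
        except RateLimitError:
            balancer.mark_unhealthy(d["endpoint"])
    raise RuntimeError("All deployments are currently rate limited")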

Request Queue with Priority

import asyncio
import heapq
import time
from dataclasses import dataclass, field
from typing import Any

@dataclass(order=True)
class PrioritizedRequest:
    priority: int
    timestamp: float = field(compare=False)
    request: Any = field(compare=False)
    future: Any = field(compare=False)

class PriorityQueue:
    def __init__(self, rate_limiter):
        self.heap = []
        self.rate_limiter = rate_limiter

    async def submit(self, request, priority=5):
        # Lower priority numbers are served first
        future = asyncio.Future()
        item = PrioritizedRequest(priority, time.time(), request, future)
        heapq.heappush(self.heap, item)
        return await future

    async def process(self):
        # Background worker: drain the heap as the rate limiter allows
        while True:
            if self.heap and self.rate_limiter.consume():
                item = heapq.heappop(self.heap)
                try:
                    result = await self._execute(item.request)
                    item.future.set_result(result)
                except Exception as e:
                    item.future.set_exception(e)
            await asyncio.sleep(0.01)

    async def _execute(self, request):
        # Placeholder: replace with your actual async API call
        raise NotImplementedError
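
A usage sketch, assuming the TokenBucket from earlier as the rate limiter and _execute wired up to your actual async API call:

async def main():
    queue = PriorityQueue(TokenBucket(capacity=10, refill_rate=1))
    worker = asyncio.create_task(queue.process())
    # A user-facing request (priority 1) jumps ahead of batch work (priority 9)
    result = await queue.submit({"prompt": "Hello"}, priority=1)
    worker.cancel()

asyncio.run(main())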

Circuit Breaker Pattern

import time
from enum import Enum

from openai import RateLimitError  # as in the backoff example

class CircuitState(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

class CircuitBreaker:
    def __init__(self, failure_threshold=5, recovery_timeout=30):
        self.state = CircuitState.CLOSED
        self.failures = 0
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.last_failure_time = None

    def call(self, func, *args, **kwargs):
        if self.state == CircuitState.OPEN:
            if time.time() - self.last_failure_time > self.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
            else:
                raise RuntimeError("Circuit is open; request rejected")

        try:
            result = func(*args, **kwargs)
            self._on_success()
            return result
        except RateLimitError:
            self._on_failure()
            raise

    def _on_success(self):
        self.failures = 0
        self.state = CircuitState.CLOSED

    def _on_failure(self):
        self.failures += 1
        self.last_failure_time = time.time()
        if self.failures >= self.failure_threshold:
            self.state = CircuitState.OPEN
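
A sketch of the breaker wrapping a call, again assuming an OpenAI-style client object and an example model name:

breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30)

try:
    response = breaker.call(
        client.chat.completions.create,
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello"}],
    )
except Exception:
    # Circuit is open or the call was rate limited; serve a fallback
    response = None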

Monitoring Rate Limits

class RateLimitMonitor:
    def __init__(self):
        self.metrics = {
            "requests": 0,
            "rate_limited": 0,
            "retries": 0,
            "failures": 0
        }

    def record_request(self):
        self.metrics["requests"] += 1

    def record_rate_limit(self):
        self.metrics["rate_limited"] += 1

    def record_retry(self):
        self.metrics["retries"] += 1

    def record_failure(self):
        self.metrics["failures"] += 1

    def get_stats(self):
        total = self.metrics["requests"]
        return {
            **self.metrics,
            "rate_limit_pct": self.metrics["rate_limited"] / total * 100 if total else 0,
            "retry_pct": self.metrics["retries"] / total * 100 if total else 0
        }
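
To be useful, the monitor has to be threaded through the retry path. A sketch combining it with the backoff helper from earlier (hypothetical wiring, reusing the time and RateLimitError imports from the previous blocks):

monitor = RateLimitMonitor()
backoff = ExponentialBackoff()

def monitored_call(func, *args, **kwargs):
    monitor.record_request()
    for attempt in range(backoff.max_retries):
        try:
            return func(*args, **kwargs)
        except RateLimitError:
            monitor.record_rate_limit()
            if attempt == backoff.max_retries - 1:
                monitor.record_failure()
                raise
            monitor.record_retry()
            time.sleep(backoff.calculate_delay(attempt))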

Tomorrow we’ll explore token estimation techniques.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.