
Efficient Inference Patterns: Maximizing AI Throughput

Efficient inference is critical for production AI: serving cost and latency are dominated by how well the hardware is kept busy. Below are three patterns for maximizing throughput and minimizing cost: dynamic batching, continuous batching, and KV cache management.

Inference Optimization Patterns

# efficient_inference.py - Patterns for high-throughput inference

import asyncio
import time
from typing import Any, List

import numpy as np

class BatchingInference:
    """Dynamic batching for efficient inference."""

    def __init__(self, model, max_batch_size: int = 32, max_wait_ms: int = 50):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = asyncio.Queue()
        self.running = False

    async def start(self):
        """Start the background batching loop."""
        self.running = True
        # Keep a reference so the task is not garbage collected mid-flight
        self._task = asyncio.create_task(self._batch_loop())

    async def predict(self, input_data) -> Any:
        """Submit a request and wait for its result."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((input_data, future))
        return await future

    async def _batch_loop(self):
        """Process requests in batches."""
        while self.running:
            batch = []
            futures = []

            # Wait up to max_wait_ms for the first request
            try:
                item = await asyncio.wait_for(
                    self.queue.get(),
                    timeout=self.max_wait_ms / 1000
                )
                batch.append(item[0])
                futures.append(item[1])
            except asyncio.TimeoutError:
                continue

            # Collect more items without waiting
            while len(batch) < self.max_batch_size:
                try:
                    item = self.queue.get_nowait()
                    batch.append(item[0])
                    futures.append(item[1])
                except asyncio.QueueEmpty:
                    break

            # Process batch
            if batch:
                results = await self._process_batch(batch)
                for future, result in zip(futures, results):
                    future.set_result(result)

    async def _process_batch(self, batch: List) -> List:
        """Run the model on a batch without blocking the event loop."""
        return await asyncio.to_thread(self.model.predict_batch, batch)
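
To see the batcher in action, here is a minimal usage sketch. The model object and the demo function are hypothetical; all that is assumed is a predict_batch(inputs) method that returns one result per input.

async def demo_dynamic_batching(model):
    """Issue 100 concurrent requests that get served in batches (sketch)."""
    batcher = BatchingInference(model, max_batch_size=32, max_wait_ms=50)
    await batcher.start()
    # Concurrent callers are transparently grouped into batches
    return await asyncio.gather(*(batcher.predict(i) for i in range(100)))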


class ContinuousBatching:
    """Continuous batching for LLM inference.

    Helper methods (_create_request, _is_complete, _get_result,
    _prepare_batch, _update_request) are placeholders for scheduler- and
    model-specific logic; a sketch of the per-request state they assume
    follows the class.
    """

    def __init__(self, model, max_concurrent: int = 8):
        self.model = model
        self.max_concurrent = max_concurrent
        self.active_requests = {}

    async def generate(self, prompt: str, max_tokens: int) -> str:
        """Generate with continuous batching."""
        request_id = self._create_request(prompt, max_tokens)

        # In this sketch each caller drives the shared step loop; a production
        # scheduler would run _step from a single background task
        while not self._is_complete(request_id):
            await self._step()

        return self._get_result(request_id)

    async def _step(self):
        """Process one step for all active requests."""
        if not self.active_requests:
            return

        # Batch all active sequences
        inputs = self._prepare_batch()
        outputs = self.model.forward(inputs)

        # Update each request; iterate over a snapshot so completed
        # requests can be removed from active_requests safely
        for request_id, output in zip(list(self.active_requests), outputs):
            self._update_request(request_id, output)
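
The helper methods referenced above are deliberately abstract. As a rough illustration of the per-request state they would operate on, a request record might look like the sketch below; the dataclass and field names are assumptions, not part of any particular serving framework.

from dataclasses import dataclass, field

@dataclass
class GenerationRequest:
    """Illustrative per-request state for the continuous batcher."""
    prompt: str
    max_tokens: int
    generated_tokens: List[int] = field(default_factory=list)

    @property
    def done(self) -> bool:
        # A real implementation would also stop on an end-of-sequence token
        return len(self.generated_tokens) >= self.max_tokens

# In this sketch, _create_request stores a GenerationRequest in
# self.active_requests under a fresh id, _is_complete checks .done,
# and _update_request appends the token produced for that sequence.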


class KVCacheManager:
    """Manage KV cache for efficient inference."""

    def __init__(self, max_cache_size_gb: float = 8.0,
                 entry_shape: tuple = (2, 2048, 16, 64)):
        self.max_size = max_cache_size_gb * 1e9
        # (k/v, max_seq_len, num_heads, head_dim) - illustrative default;
        # in practice this comes from the model config
        self.entry_shape = entry_shape
        self.cache = {}
        self.access_times = {}

    def get_or_create(self, sequence_id: str, layer: int) -> np.ndarray:
        """Get the cached KV for (sequence, layer), allocating if needed."""
        key = (sequence_id, layer)

        if key in self.cache:
            self.access_times[key] = time.monotonic()
            return self.cache[key]

        # Evict before allocating so the new entry is not an eviction candidate
        self._evict_if_needed()

        # Create new cache entry
        self.cache[key] = self._allocate_cache()
        self.access_times[key] = time.monotonic()
        return self.cache[key]

    def _evict_if_needed(self):
        """Evict least recently used entries until under 90% of the budget."""
        while self.cache and self._current_size() > self.max_size * 0.9:
            oldest = min(self.access_times, key=self.access_times.get)
            del self.cache[oldest]
            del self.access_times[oldest]

    def _current_size(self) -> int:
        """Total bytes currently held by cached arrays."""
        return sum(arr.nbytes for arr in self.cache.values())

    def _allocate_cache(self) -> np.ndarray:
        """Allocate an empty KV buffer for one (sequence, layer) entry."""
        return np.zeros(self.entry_shape, dtype=np.float16)
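
Tying it together, the cache manager sits inside the decode loop: every layer of every active sequence reuses its KV buffer between steps instead of recomputing past keys and values. A minimal sketch, assuming a 24-layer model and the illustrative entry shape from the constructor:

kv_manager = KVCacheManager(max_cache_size_gb=8.0)

def fetch_layer_caches(sequence_id: str, num_layers: int = 24):
    """Fetch (or allocate) per-layer KV buffers for one generation step."""
    return [kv_manager.get_or_create(sequence_id, layer)
            for layer in range(num_layers)]

# A model forward pass would read past keys/values from these buffers and
# append the new token's keys/values before the next step.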

Applied together, dynamic batching, continuous batching, and careful KV cache management can improve throughput by 5-10x over serving one request at a time, which translates directly into lower cost per request.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.