Skip to content
Back to Blog
1 min read

Efficient Inference: Optimization Techniques for Production AI

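This post collects practical, production-minded techniques for making model inference faster and cheaper. It covers quantization, dynamic batching, response caching, KV-cache management and speculative decoding, and ends with a simple pipeline that applies them against latency and quality targets.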

Optimization Techniques Overview

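Technique              Latency Impact    Cost Impact    Quality Impact
─────────────────────────────────────────────────────────────────────────
Quantization           -40%              -50%           -1% to -3%
Batching               +50%*             -70%           None
Caching                -90%**            -90%**         None***
Model Distillation     -60%              -80%           -5% to -10%
Speculative Decoding   -50%              -30%           None
KV Cache Optimization  -30%              -20%           None

* Per-request latency may increase, but throughput improves
** For cached requests only
*** For exact-match caching; semantic (similarity-based) caching can return an answer written for a slightly different prompt

These figures are indicative ballparks. The actual effect depends heavily on the model, hardware and traffic pattern, so measure on your own workload.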

Quantization

from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType
import onnx

class ModelQuantizer:
    """Create smaller/faster variants of an ONNX model.

    Every ``quantize_*`` method reads the model at ``model_path``, writes the
    converted model to ``output_path`` and prints the before/after file size.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path

    def quantize_dynamic_int8(self, output_path: str):
        """Dynamic quantization - no calibration data needed.

        Weights are stored as signed INT8; per the ONNX Runtime docs,
        activation scales are computed at inference time.
        """
        # `optimize_model=` is no longer accepted by recent ONNX Runtime
        # releases. If graph optimization is wanted, pre-process the model
        # first (python -m onnxruntime.quantization.preprocess ...).
        quantize_dynamic(
            model_input=self.model_path,
            model_output=output_path,
            weight_type=QuantType.QInt8,
        )
        self._print_size_comparison(output_path)

    def quantize_static_int8(self, output_path: str, calibration_data, input_name: str = "input"):
        """Static quantization - better quality with calibration.

        Args:
            output_path: Where to write the quantized model.
            calibration_data: Iterable of sample inputs; each sample is fed
                once to collect activation ranges.
            input_name: Name of the graph input the samples are bound to.
        """
        from onnxruntime.quantization import CalibrationDataReader, QuantFormat

        class DataReader(CalibrationDataReader):
            """Feed calibration samples one at a time, then signal exhaustion."""

            def __init__(self, data, name):
                # iter() lets callers pass lists, arrays or generators.
                self._samples = iter(data)
                self._name = name

            def get_next(self):
                # Return the next remaining sample, if any.
                for sample in self._samples:
                    return {self._name: sample}
                # None tells ONNX Runtime the calibration data is exhausted.
                return None

        quantize_static(
            model_input=self.model_path,
            model_output=output_path,
            calibration_data_reader=DataReader(calibration_data, input_name),
            # quant_format takes a QuantFormat (QDQ/QOperator), not a QuantType;
            # the INT8 choice belongs in activation_type / weight_type.
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QInt8,
            weight_type=QuantType.QInt8,
        )
        self._print_size_comparison(output_path)

    def quantize_float16(self, output_path: str):
        """FP16 quantization - good balance of speed and quality."""
        from onnxruntime.transformers import float16
        model = onnx.load(self.model_path)
        model_fp16 = float16.convert_float_to_float16(model)
        onnx.save(model_fp16, output_path)
        self._print_size_comparison(output_path)

    def _print_size_comparison(self, output_path: str):
        """Print original vs. converted file size in MiB and as a percentage."""
        import os

        original = os.path.getsize(self.model_path) / 1024 / 1024
        quantized = os.path.getsize(output_path) / 1024 / 1024
        # An empty source file would otherwise raise ZeroDivisionError.
        ratio = f"{quantized / original * 100:.1f}%" if original else "n/a"
        print(f"Original: {original:.2f}MB -> Quantized: {quantized:.2f}MB ({ratio})")

Intelligent Batching

import asyncio
from collections import defaultdict
import time

class DynamicBatcher:
    """Batch requests for efficient GPU utilization.

    Callers await :meth:`infer` with a single input. Queued inputs are sent to
    ``model.batch_infer`` as one list. A batch is flushed as soon as
    ``max_batch_size`` inputs are queued, or ``max_wait_ms`` after the first
    input of a partial batch arrived, whichever happens first.
    ``model.batch_infer`` must return one result per input, in input order.
    """

    def __init__(
        self,
        model,
        max_batch_size: int = 32,
        max_wait_ms: int = 50
    ):
        if max_batch_size < 1:
            # With max_batch_size <= 0 queued requests would never be dispatched.
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.pending = []
        # Guards `pending` and `_batch_task`; never held across inference.
        self.lock = asyncio.Lock()
        self._batch_task = None
        # Serializes model.batch_infer calls (one batch in flight at a time)
        # without blocking callers that are only enqueuing.
        self._model_lock = asyncio.Lock()
        # Strong references to dispatched batch tasks so they cannot be
        # garbage-collected before they finish.
        self._inflight = set()

    async def infer(self, input_data) -> dict:
        """Submit one input and wait for its result.

        Raises whatever exception ``model.batch_infer`` raised for the batch
        this input was part of, or RuntimeError if it returned the wrong
        number of results.
        """
        future = asyncio.get_running_loop().create_future()
        batch = None

        async with self.lock:
            self.pending.append((input_data, future))

            if len(self.pending) >= self.max_batch_size:
                # Batch is full: take it now and drop the partial-batch timer,
                # so the next request gets a full max_wait_ms window.
                batch = self._take_batch()
                if self._batch_task is not None:
                    self._batch_task.cancel()
                    self._batch_task = None
            elif self._batch_task is None:
                # Start timer for partial batch
                self._batch_task = asyncio.create_task(self._wait_and_process())

        if batch:
            # Run in a separate task so that cancelling this caller cannot
            # strand the other requests in the same batch.
            task = asyncio.create_task(self._run_batch(batch))
            self._inflight.add(task)
            task.add_done_callback(self._inflight.discard)

        return await future

    async def _wait_and_process(self):
        """Wait for more requests or timeout, then flush the partial batch."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        async with self.lock:
            batch = self._take_batch()
            self._batch_task = None
        if batch:
            await self._run_batch(batch)

    async def _process_batch(self):
        """Take up to ``max_batch_size`` pending requests and run them.

        Kept for callers of the previous API; as before, the caller is
        expected to hold ``self.lock``.
        """
        await self._run_batch(self._take_batch())

    def _take_batch(self):
        """Remove and return up to ``max_batch_size`` pending (input, future) pairs."""
        batch = self.pending[:self.max_batch_size]
        self.pending = self.pending[self.max_batch_size:]
        return batch

    async def _run_batch(self, batch):
        """Run one batch through the model and resolve each caller's future."""
        if not batch:
            return
        inputs = [input_data for input_data, _ in batch]
        futures = [future for _, future in batch]

        try:
            async with self._model_lock:
                results = list(await self.model.batch_infer(inputs))
            if len(results) != len(futures):
                # zip() would silently drop the extras and leave callers hanging.
                raise RuntimeError(
                    f"batch_infer returned {len(results)} results for {len(inputs)} inputs"
                )
        except Exception as exc:
            for future in futures:
                if not future.done():
                    future.set_exception(exc)
            return

        for future, result in zip(futures, results):
            # A caller that was cancelled (e.g. timed out) already has a done
            # future; setting it again would raise InvalidStateError.
            if not future.done():
                future.set_result(result)

Response Caching

import hashlib
import json
import math
from datetime import datetime, timedelta, timezone

from azure.cosmos import CosmosClient
from azure.cosmos.exceptions import CosmosResourceNotFoundError

class SemanticCache:
    """Cache responses with semantic similarity."""

    def __init__(self, cosmos_client: CosmosClient, embedding_model):
        self.container = cosmos_client.get_database_client("cache").get_container_client("responses")
        self.embedding_model = embedding_model
        self.similarity_threshold = 0.95

    async def get(self, prompt: str, model: str) -> dict | None:
        """Get cached response if similar prompt exists."""

        # Exact match first (fast)
        exact_key = self._hash_key(prompt, model)
        try:
            item = self.container.read_item(exact_key, partition_key=model)
            if self._is_valid(item):
                return item["response"]
        except:
            pass

        # Semantic similarity search (slower but catches paraphrases)
        prompt_embedding = await self.embedding_model.embed(prompt)

        query = """
            SELECT TOP 5 c.id, c.prompt, c.response, c.embedding
            FROM c
            WHERE c.model = @model AND c.expires_at > @now
        """

        items = self.container.query_items(
            query=query,
            parameters=[
                {"name": "@model", "value": model},
                {"name": "@now", "value": datetime.utcnow().isoformat()}
            ]
        )

        for item in items:
            similarity = self._cosine_similarity(prompt_embedding, item["embedding"])
            if similarity >= self.similarity_threshold:
                return item["response"]

        return None

    async def set(
        self,
        prompt: str,
        model: str,
        response: dict,
        ttl_hours: int = 24
    ):
        """Cache a response."""

        prompt_embedding = await self.embedding_model.embed(prompt)

        item = {
            "id": self._hash_key(prompt, model),
            "model": model,
            "prompt": prompt,
            "response": response,
            "embedding": prompt_embedding,
            "created_at": datetime.utcnow().isoformat(),
            "expires_at": (datetime.utcnow() + timedelta(hours=ttl_hours)).isoformat()
        }

        self.container.upsert_item(item)

    def _hash_key(self, prompt: str, model: str) -> str:
        content = f"{model}:{prompt}"
        return hashlib.sha256(content.encode()).hexdigest()[:32]

KV Cache Optimization

class KVCacheManager:
    """Manage key-value cache for transformer inference.

    Keeps one KV-cache object per session and evicts least-recently-used
    sessions whenever the summed size reported by ``_cache_size`` exceeds the
    budget. ``_allocate_cache``, ``_cache_size`` and ``_merge_kv`` are
    backend-specific hooks a subclass must implement.
    """

    def __init__(self, max_cache_size_gb: float = 8.0):
        from collections import OrderedDict

        self.max_size = max_cache_size_gb * 1024 * 1024 * 1024  # bytes
        # Iteration order is recency order: least recently used first.
        self.cache = OrderedDict()
        # Wall-clock time of each session's last access.
        self.access_times = {}

    def get_or_create(self, session_id: str, sequence_length: int):
        """Get cached KV state or create new.

        A newly created entry is never evicted by its own insertion; other
        sessions are evicted (LRU first) until the total fits the budget.
        """
        if session_id in self.cache:
            self._touch(session_id)
            return self.cache[session_id]

        # Create new cache
        kv_cache = self._allocate_cache(sequence_length)
        self.cache[session_id] = kv_cache
        self._touch(session_id)

        # Evict after inserting so the new entry's size is counted.
        self._evict_if_needed(protect=session_id)

        return kv_cache

    def update(self, session_id: str, new_kv: dict):
        """Update cache with new key-value pairs.

        Unknown sessions are ignored. The merged entry counts as a use, and
        growth can trigger eviction of other sessions.
        """
        if session_id not in self.cache:
            return
        # Append new KV pairs
        self.cache[session_id] = self._merge_kv(self.cache[session_id], new_kv)
        self._touch(session_id)
        self._evict_if_needed(protect=session_id)

    def _touch(self, session_id: str):
        """Mark ``session_id`` as most recently used."""
        self.cache.move_to_end(session_id)
        self.access_times[session_id] = time.time()

    def _evict_if_needed(self, protect=None):
        """Evict least-recently-used entries (except ``protect``) until within budget."""
        current_size = sum(self._cache_size(c) for c in self.cache.values())

        for sid in list(self.cache):  # least recently used first
            if current_size <= self.max_size:
                break
            if sid == protect:
                continue
            current_size -= self._cache_size(self.cache.pop(sid))
            self.access_times.pop(sid, None)

    def _allocate_cache(self, sequence_length: int):
        """Allocate an empty KV cache for ``sequence_length`` tokens (backend hook)."""
        raise NotImplementedError

    def _cache_size(self, kv_cache) -> int:
        """Return the size of ``kv_cache`` in bytes (backend hook)."""
        raise NotImplementedError

    def _merge_kv(self, existing, new_kv):
        """Return ``existing`` extended with ``new_kv`` (backend hook)."""
        raise NotImplementedError

Speculative Decoding

class SpeculativeDecoder:
    """Use small model to draft, large model to verify.

    Greedy speculative decoding works in three steps:

    1. The draft model proposes up to ``speculation_length`` tokens.
    2. The target model scores context plus draft in a single forward pass.
    3. Draft tokens are kept only while they equal the target's argmax.

    The output is therefore the target model's greedy continuation.
    Subclasses must implement ``_tokenize``, ``_detokenize`` and ``_is_eos``.
    """

    def __init__(self, draft_model, target_model, speculation_length: int = 5):
        self.draft_model = draft_model
        self.target_model = target_model
        self.speculation_length = speculation_length

    async def generate(self, prompt: str, max_tokens: int) -> str:
        """Generate with speculative decoding.

        Stops after at most ``max_tokens`` tokens, or right after the first
        end-of-sequence token (which is included in the output).
        """
        tokens = list(self._tokenize(prompt))
        generated = []

        while len(generated) < max_tokens:
            context = tokens + generated
            # Draft: propose up to speculation_length tokens with the small model
            draft_tokens = await self._draft(context)

            # Verify: keep the agreeing prefix plus the target's next token
            verified, next_token = await self._verify(context, draft_tokens)
            new_tokens = list(verified)
            if next_token is not None:
                new_tokens.append(next_token)
            if not new_tokens:
                # No progress is possible; avoid looping forever.
                break

            # Emit one at a time so EOS inside an accepted run, and the
            # max_tokens limit, are both honoured.
            for token in new_tokens:
                generated.append(token)
                if self._is_eos(token) or len(generated) >= max_tokens:
                    return self._detokenize(generated)

        return self._detokenize(generated)

    async def _draft(self, context: list) -> list:
        """Generate draft tokens with small model."""
        # NOTE(review): assumes draft_model.generate returns only the newly
        # generated tokens (not context + new tokens); confirm with the wrapper.
        return await self.draft_model.generate(
            context,
            max_new_tokens=self.speculation_length,
            do_sample=False  # Greedy for speculation
        )

    async def _verify(self, context: list, draft: list) -> tuple:
        """Verify draft tokens with large model.

        Returns ``(accepted, next_token)``. ``accepted`` is the longest prefix
        of ``draft`` matching the target model's greedy choice. ``next_token``
        is the target's own token for the position after that prefix: the
        correction when a draft token was rejected, or a free "bonus" token
        when all were accepted.

        Raises:
            ValueError: if ``context`` is empty (no position predicts the first token).
        """
        if not context:
            raise ValueError("speculative decoding needs a non-empty context")

        # Assumes forward() returns one row of logits per input position and
        # that row p scores the token at position p + 1 (causal LM
        # convention) — TODO confirm against the target model wrapper.
        logits = await self.target_model.forward(context + draft)
        # Row predicting draft[0] is the last context position.
        offset = len(context) - 1

        verified = []
        for i, draft_token in enumerate(draft):
            target_token = int(logits[offset + i].argmax())
            if target_token != draft_token:
                # Draft was wrong, return correction
                return verified, target_token
            verified.append(draft_token)

        # Every draft token agreed: the last row yields one more token for free.
        return verified, int(logits[offset + len(draft)].argmax())

    def _tokenize(self, prompt: str) -> list:
        """Convert ``prompt`` to a list of token ids (model-specific hook)."""
        raise NotImplementedError

    def _detokenize(self, tokens: list) -> str:
        """Convert token ids back to text (model-specific hook)."""
        raise NotImplementedError

    def _is_eos(self, token) -> bool:
        """Return True if ``token`` is the end-of-sequence token (model-specific hook)."""
        raise NotImplementedError

Optimization Pipeline

class InferenceOptimizer:
    """Complete optimization pipeline.

    Applies quantization, then batching, while the measured latency is above
    the target. INT8 is chosen only when the expected quality drop still
    leaves quality above the target; otherwise FP16 is used. Benchmarking and
    the apply steps are hooks a subclass must implement.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path

    def optimize(
        self,
        target_latency_ms: int,
        target_quality: float,
        *,
        int8_quality_drop: float = 0.02,
        batch_size: int = 16,
    ) -> dict:
        """Apply optimizations to meet targets.

        Args:
            target_latency_ms: Latency, as reported by ``_benchmark_latency``,
                to get under.
            target_quality: Quality floor. INT8 is used only if
                ``quality - int8_quality_drop`` stays strictly above it.
            int8_quality_drop: Expected quality loss from INT8 quantization.
            batch_size: Batch size to configure if latency is still too high.

        Returns:
            dict with ``optimizations`` (names applied, in order),
            ``final_latency_ms`` and ``final_quality``.
        """
        optimizations_applied = []
        # Benchmarks are expensive: measure once up front and re-measure only
        # after something has actually changed.
        current_latency = self._benchmark_latency()
        current_quality = self._benchmark_quality()

        # Apply optimizations progressively
        if current_latency > target_latency_ms:
            # Try quantization first
            if current_quality - int8_quality_drop > target_quality:
                self._apply_quantization("int8")
                optimizations_applied.append("int8_quantization")
            else:
                self._apply_quantization("fp16")
                optimizations_applied.append("fp16_quantization")
            current_latency = self._benchmark_latency()

        if current_latency > target_latency_ms:
            # NOTE(review): batching mainly raises throughput; per-request
            # latency can go up (see the overview table), so confirm that
            # _benchmark_latency measures latency under load before relying
            # on this step to meet a latency target.
            self._configure_batching(batch_size=batch_size)
            optimizations_applied.append("batching")
            current_latency = self._benchmark_latency()

        if optimizations_applied:
            # Quality may have changed; measure it on the final configuration.
            current_quality = self._benchmark_quality()

        # NOTE(review): nothing checks final_quality against target_quality or
        # rolls back an optimization that missed it.
        return {
            "optimizations": optimizations_applied,
            "final_latency_ms": current_latency,
            "final_quality": current_quality
        }

    def _benchmark_latency(self) -> float:
        """Measure latency in milliseconds for the current configuration (hook)."""
        raise NotImplementedError

    def _benchmark_quality(self) -> float:
        """Measure quality for the current configuration (hook)."""
        raise NotImplementedError

    def _apply_quantization(self, mode: str) -> None:
        """Quantize the model; ``mode`` is "int8" or "fp16" (hook)."""
        raise NotImplementedError

    def _configure_batching(self, batch_size: int) -> None:
        """Enable request batching with ``batch_size`` (hook)."""
        raise NotImplementedError

Best Practices

  1. Measure first: Benchmark before optimizing
  2. Start with caching: Highest impact, lowest risk
  3. Batch when possible: Better GPU utilization
  4. Quantize appropriately: FP16 for quality, INT8 for speed
  5. Monitor quality: Track accuracy with each optimization
  6. Use the right hardware: Match model to accelerator

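Efficient inference is crucial for production AI. Apply these techniques systematically and measure the impact at each step.

## Takeaways

Measure latency, throughput, cost and quality on a representative workload before changing anything. Then work from low risk to high risk:

- Start with caching, batching and quantization, and re-check quality after each step.
- Reach for distillation and speculative decoding only when the simpler options are not enough.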

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.