6 min read
Model Optimization Techniques: From Training to Deployment
Optimizing AI models is essential for production deployment. This guide covers techniques from model architecture to deployment, helping you achieve the best balance of speed, size, and quality.
The Optimization Journey
Training                 Post-Training            Deployment
    │                         │                        │
    ▼                         ▼                        ▼
┌─────────────┐        ┌──────────────┐       ┌─────────────────┐
│ Architecture│        │ Quantization │       │ Runtime Optim   │
│ Distillation│        │ Pruning      │       │ Hardware Accel  │
│ Efficient   │        │ Graph Optim  │       │ Serving Optim   │
│ Attention   │        │ Knowledge    │       │ Caching         │
└─────────────┘        │ Distillation │       └─────────────────┘
                       └──────────────┘
Architecture Optimization
Efficient Attention Mechanisms
import torch
import torch.nn as nn
import math

class MultiQueryAttention(nn.Module):
    """Multi-Query Attention - shared KV heads for efficiency (n_kv_heads > 1 gives grouped-query attention)."""

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int = 1):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        # Expand KV to match query heads
        k = k.repeat_interleave(self.n_heads // self.n_kv_heads, dim=2)
        v = v.repeat_interleave(self.n_heads // self.n_kv_heads, dim=2)

        # Move the head dimension ahead of the sequence dimension for attention
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads back and project
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(out)
class FlashAttention(nn.Module):
    """Flash Attention - IO-aware attention computation."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        # Requires the flash-attn package, a CUDA device, and fp16/bf16 tensors
        from flash_attn import flash_attn_func

        batch, seq_len, _ = x.shape
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)

        # Flash attention - memory efficient, never materializes the full attention matrix
        out = flash_attn_func(q, k, v, causal=True)
        return self.out(out.reshape(batch, seq_len, -1))
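As a quick sanity check, the sketch below runs the MultiQueryAttention module above on a random batch; the dimensions (d_model=512, 8 query heads, 2 KV heads) are arbitrary illustrative values:

# Illustrative shape check - all sizes are placeholders
mqa = MultiQueryAttention(d_model=512, n_heads=8, n_kv_heads=2)
x = torch.randn(4, 128, 512)      # (batch, seq_len, d_model)
print(mqa(x).shape)               # torch.Size([4, 128, 512])

With n_kv_heads=2 the K/V projections are four times smaller than in standard multi-head attention, which shrinks the KV cache proportionally at inference time.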
Model Pruning
import torch.nn.utils.prune as prune

class ModelPruner:
    """Prune model weights for smaller, faster models."""

    def __init__(self, model):
        self.model = model

    def prune_unstructured(self, amount: float = 0.3):
        """Remove individual weights (unstructured pruning)."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                prune.l1_unstructured(module, name='weight', amount=amount)
        return self._get_sparsity()

    def prune_structured(self, amount: float = 0.2):
        """Remove entire neurons/filters (structured pruning)."""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)
        return self._get_sparsity()

    def remove_pruning(self):
        """Make pruning permanent by baking the masks into the weights."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                try:
                    prune.remove(module, 'weight')
                except ValueError:
                    pass  # Module was never pruned

    def _get_sparsity(self) -> float:
        """Calculate weight sparsity (fraction of zeroed weights)."""
        zeros, total = 0, 0
        for module in self.model.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                weight = module.weight  # Reflects the pruning mask while it is applied
                zeros += (weight == 0).sum().item()
                total += weight.numel()
        return zeros / total if total else 0.0

    def iterative_pruning(self, target_sparsity: float, steps: int = 5, fine_tune_fn=None):
        """Gradually prune and fine-tune."""
        current_sparsity = 0.0
        step_amount = target_sparsity / steps
        for step in range(steps):
            # Prune: the amount is relative to the weights that are still non-zero
            self.prune_unstructured(step_amount / (1 - current_sparsity))
            current_sparsity = self._get_sparsity()
            print(f"Step {step+1}: Sparsity = {current_sparsity:.2%}")

            # Fine-tune to recover accuracy lost to pruning
            if fine_tune_fn:
                fine_tune_fn(self.model)

        self.remove_pruning()
        return current_sparsity
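A minimal usage sketch of ModelPruner, assuming a throwaway two-layer model; in practice you would pass a real fine_tune_fn that runs a few recovery epochs on your training data:

# Hypothetical example: prune a small feed-forward model to ~30% sparsity
model = nn.Sequential(nn.Linear(768, 3072), nn.ReLU(), nn.Linear(3072, 768))
pruner = ModelPruner(model)

sparsity = pruner.iterative_pruning(
    target_sparsity=0.3,
    steps=5,
    fine_tune_fn=None,   # Plug in a short fine-tuning loop to recover accuracy
)
print(f"Final sparsity: {sparsity:.2%}")

Keep in mind that unstructured sparsity only yields real speedups on runtimes that exploit sparse weights; structured pruning shrinks the dense matrices themselves and helps on ordinary hardware.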
Knowledge Distillation
import torch.nn.functional as F

class DistillationTrainer:
    """Train smaller model to mimic larger model."""

    def __init__(
        self,
        teacher_model,
        student_model,
        temperature: float = 2.0,
        alpha: float = 0.5
    ):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha  # Balance between hard and soft labels

    def distillation_loss(self, student_logits, teacher_logits, labels):
        """Compute combined distillation loss."""
        # Soft targets (from teacher)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        soft_loss *= self.temperature ** 2

        # Hard targets (ground truth)
        hard_loss = F.cross_entropy(student_logits, labels)

        # Combined loss
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, batch, optimizer):
        """Single training step."""
        inputs, labels = batch

        # Teacher inference (no grad)
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Student inference
        student_logits = self.student(inputs)

        # Loss
        loss = self.distillation_loss(student_logits, teacher_logits, labels)

        # Update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def train(self, train_loader, optimizer, epochs: int = 10):
        """Full training loop."""
        self.teacher.eval()
        self.student.train()
        for epoch in range(epochs):
            total_loss = 0
            for batch in train_loader:
                loss = self.train_step(batch, optimizer)
                total_loss += loss
            print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")
        return self.student
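Putting the trainer to work might look like the sketch below; the teacher, student, dataset, and hyperparameters are all placeholders for your own:

from torch.utils.data import DataLoader, TensorDataset

# Placeholder models and data - substitute your real teacher, student, and dataset
teacher = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 10))
student = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

trainer = DistillationTrainer(teacher, student, temperature=2.0, alpha=0.5)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)
distilled = trainer.train(loader, optimizer, epochs=3)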
Graph Optimization
import os
import onnx
from onnxruntime.transformers import optimizer

class GraphOptimizer:
    """Optimize ONNX computation graph."""

    def __init__(self, model_path: str):
        self.model_path = model_path

    def optimize(self, output_path: str):
        """Apply all graph optimizations."""
        # Load model
        model = onnx.load(self.model_path)

        # Basic optimizations
        from onnxruntime.transformers.onnx_model import OnnxModel
        onnx_model = OnnxModel(model)

        # Fuse operations
        self._fuse_operations(onnx_model)

        # Constant folding
        self._fold_constants(onnx_model)

        # Remove redundant operations
        self._remove_redundant(onnx_model)

        # Save optimized model
        onnx_model.save_model_to_file(output_path)
        return self._compare_models(self.model_path, output_path)

    def optimize_for_inference(self, output_path: str, model_type: str = "bert"):
        """Use ONNX Runtime's built-in optimizer."""
        optimized = optimizer.optimize_model(
            self.model_path,
            model_type=model_type,
            num_heads=12,        # Adjust based on model
            hidden_size=768,
            optimization_options=optimizer.FusionOptions(model_type),
            opt_level=99
        )
        optimized.save_model_to_file(output_path)
        return output_path

    def _fuse_operations(self, model):
        """Fuse consecutive operations (e.g., MatMul + Add -> GEMM, BatchNorm + ReLU)."""
        pass

    def _fold_constants(self, model):
        """Pre-compute constant expressions."""
        pass

    def _remove_redundant(self, model):
        """Remove identity operations and dead code."""
        pass

    def _compare_models(self, original_path: str, optimized_path: str) -> dict:
        """Report the file-size change between the original and optimized models."""
        return {
            "original_mb": os.path.getsize(original_path) / 1e6,
            "optimized_mb": os.path.getsize(optimized_path) / 1e6,
        }
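For transformer models, the built-in optimize_for_inference path is usually the quickest win. A usage sketch with placeholder file paths (note that the num_heads/hidden_size values hard-coded above must match your model):

# Hypothetical paths - point these at your exported ONNX model
graph_opt = GraphOptimizer("model.onnx")
optimized_path = graph_opt.optimize_for_inference("model_graph_opt.onnx", model_type="bert")
print(f"Optimized model written to {optimized_path}")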
Combining Techniques
class OptimizationPipeline:
    """Apply multiple optimization techniques in sequence."""

    def __init__(self, model, calibration_data):
        self.model = model
        self.calibration_data = calibration_data

    def optimize(self, config: dict) -> dict:
        """Run optimization pipeline."""
        results = {"original": self._benchmark()}

        # 1. Pruning (if enabled)
        if config.get("pruning", {}).get("enabled", False):
            pruner = ModelPruner(self.model)
            sparsity = pruner.iterative_pruning(
                target_sparsity=config["pruning"]["target_sparsity"],
                steps=config["pruning"]["steps"]
            )
            results["after_pruning"] = self._benchmark()
            results["sparsity"] = sparsity

        # 2. Export to ONNX
        onnx_path = "model_optimized.onnx"
        self._export_to_onnx(onnx_path)

        # 3. Graph optimization
        if config.get("graph_optimization", True):
            opt = GraphOptimizer(onnx_path)
            opt.optimize(onnx_path)
            results["after_graph_opt"] = self._benchmark_onnx(onnx_path)

        # 4. Quantization (ModelQuantizer is an ONNX quantization helper, not shown in this section)
        if config.get("quantization", {}).get("enabled", False):
            quant = ModelQuantizer(onnx_path)
            quant_type = config["quantization"]["type"]
            if quant_type == "dynamic_int8":
                quant.quantize_dynamic_int8(onnx_path)
            elif quant_type == "static_int8":
                quant.quantize_static_int8(onnx_path, self.calibration_data)
            elif quant_type == "fp16":
                quant.quantize_float16(onnx_path)
            results["after_quantization"] = self._benchmark_onnx(onnx_path)

        # Report
        self._print_report(results)
        return results

    def _benchmark(self) -> dict:
        """Benchmark PyTorch model."""
        pass

    def _benchmark_onnx(self, path: str) -> dict:
        """Benchmark ONNX model."""
        pass

    def _export_to_onnx(self, path: str):
        """Export the PyTorch model to ONNX."""
        pass

    def _print_report(self, results: dict):
        """Print a summary of each optimization stage."""
        pass

# Usage
config = {
    "pruning": {
        "enabled": True,
        "target_sparsity": 0.3,
        "steps": 5
    },
    "graph_optimization": True,
    "quantization": {
        "enabled": True,
        "type": "dynamic_int8"
    }
}

pipeline = OptimizationPipeline(model, calibration_data)
results = pipeline.optimize(config)
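The benchmark methods are left as stubs above. One minimal way to fill in the ONNX side is to time an ONNX Runtime session on dummy input; this sketch assumes a single float32 input and replaces dynamic dimensions with 1:

import time
import numpy as np
import onnxruntime as ort

def benchmark_onnx(path: str, runs: int = 50) -> dict:
    """Rough latency benchmark for an ONNX model using a dummy float input."""
    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    inp = session.get_inputs()[0]
    # Assumes a fixed shape; dynamic dimensions (strings) are replaced with 1
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    dummy = np.random.rand(*shape).astype(np.float32)

    # Warm-up, then timed runs
    for _ in range(5):
        session.run(None, {inp.name: dummy})
    start = time.perf_counter()
    for _ in range(runs):
        session.run(None, {inp.name: dummy})
    latency_ms = (time.perf_counter() - start) / runs * 1000
    return {"latency_ms": latency_ms}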
Best Practices
- Benchmark baseline: Know your starting point before you optimize (a minimal helper is sketched after this list)
- One technique at a time: Isolate impact of each optimization
- Monitor quality: Track accuracy throughout
- Test on target hardware: Results vary by platform
- Profile memory: Size matters for deployment
- Automate: Build optimization into your pipeline
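To make the baseline and memory points concrete, a minimal helper might look like this (CPU timing only; on GPU you would also need torch.cuda.synchronize() around the timers):

import time

def baseline_stats(model: nn.Module, example_input: torch.Tensor, runs: int = 30) -> dict:
    """Rough baseline: parameter memory and average forward-pass latency."""
    param_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6
    model.eval()
    with torch.no_grad():
        model(example_input)  # Warm-up
        start = time.perf_counter()
        for _ in range(runs):
            model(example_input)
    latency_ms = (time.perf_counter() - start) / runs * 1000
    return {"param_mb": param_mb, "latency_ms": latency_ms}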
Model optimization is both science and art. Start with the techniques that offer the best impact for your specific constraints.