
Quantization Basics: Reducing Model Size and Improving Speed

Quantization reduces the numerical precision of model weights (and sometimes activations), dramatically decreasing memory usage and often improving inference speed. Today we’ll explore quantization fundamentals.

Understanding Quantization

# Precision levels
precision_levels = {
    "FP32": {"bits": 32, "range": "3.4e38", "typical_use": "Training"},
    "FP16": {"bits": 16, "range": "6.5e4", "typical_use": "Mixed precision training"},
    "BF16": {"bits": 16, "range": "3.4e38", "typical_use": "Training (A100+)"},
    "INT8": {"bits": 8, "range": "-128 to 127", "typical_use": "Inference"},
    "INT4": {"bits": 4, "range": "-8 to 7", "typical_use": "LLM inference"}
}

# Memory savings example (7B parameter model)
memory_comparison = {
    "FP32": "28 GB",
    "FP16": "14 GB",
    "INT8": "7 GB",
    "INT4": "3.5 GB"
}
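
The memory figures are just bytes per parameter. Here is a quick sketch of the arithmetic so you can estimate other model sizes (weight_memory_gb is an illustrative helper, and it only counts weights, not activations or KV cache):

# Weight memory only: parameters × bits ÷ 8, expressed in GB
def weight_memory_gb(num_params: float, bits: int) -> float:
    return num_params * bits / 8 / 1e9

for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"{name}: {weight_memory_gb(7e9, bits):.1f} GB")  # matches the 7B table above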

Quantization Types

# Post-Training Quantization (PTQ)
# - Applied after training
# - No additional training required
# - May have accuracy loss

# Quantization-Aware Training (QAT)
# - Model trained with quantization simulation
# - Better accuracy preservation
# - Requires retraining

# Dynamic Quantization
# - Weights quantized, activations computed in FP32
# - Simple to apply
# - Good for CPU inference

# Static Quantization
# - Both weights and activations quantized
# - Requires calibration data
# - Best performance
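
To see where PTQ accuracy loss comes from, here is a minimal sketch of symmetric per-tensor INT8 quantization and the round-trip error it introduces. This is plain PyTorch for illustration, not what the library kernels do verbatim:

import torch

w = torch.randn(256, 256)                    # stand-in weight matrix
scale = w.abs().max() / 127                  # one scale for the whole tensor
w_int8 = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)
w_dequant = w_int8.float() * scale           # what the matmul effectively sees

print("max abs error: ", (w - w_dequant).abs().max().item())
print("mean abs error:", (w - w_dequant).abs().mean().item())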

PyTorch Dynamic Quantization

import os

import torch
from transformers import AutoModelForSequenceClassification

# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)
model.eval()

# Apply dynamic quantization to the Linear layers
# (nn.Embedding needs the float_qparams qconfig, so it is left out here)
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layer types to quantize
    dtype=torch.qint8
)

# Compare sizes
def model_size_mb(model):
    """Serialize the state dict to disk and report its size in MB."""
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p") / 1e6
    os.remove("temp.p")
    return size

print(f"Original: {model_size_mb(model):.1f} MB")
print(f"Quantized: {model_size_mb(quantized_model):.1f} MB")

Static Quantization with Calibration

import torch
from torch.quantization import prepare, convert, get_default_qconfig
from transformers import AutoModelForSequenceClassification

# Prepare model for static quantization (same DistilBERT checkpoint as above)
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# Set quantization config
model.qconfig = get_default_qconfig('fbgemm')  # or 'qnnpack' for ARM

# Prepare model (inserts observers on supported modules). Note that eager-mode
# static quantization also expects QuantStub/DeQuantStub placement in the model,
# so treat this section as an outline of the workflow rather than a drop-in recipe.
prepared_model = prepare(model)

# Calibrate with representative data
def calibrate(model, data_loader, num_batches=100):
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break
            model(**batch)

calibrate(prepared_model, calibration_loader)

# Convert to quantized model
quantized_model = convert(prepared_model)
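
The calibration_loader above is assumed to yield tokenized batches. A minimal sketch of one, reusing the tokenizer loaded earlier and a handful of placeholder texts, might look like this:

from torch.utils.data import DataLoader

calibration_texts = [
    "This movie was fantastic!",
    "Terrible service, would not recommend.",
    # ...a few hundred representative samples from your own data
]

def collate(texts):
    # Tokenize a batch of raw strings into the dict the model expects
    return tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt")

calibration_loader = DataLoader(calibration_texts, batch_size=8, collate_fn=collate)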

Using bitsandbytes for LLMs

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# 8-bit quantization
config_8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_skip_modules=["lm_head"]
)

model_8bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=config_8bit,
    device_map="auto"
)

# 4-bit quantization (NF4)
config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",           # NormalFloat4
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True       # Nested quantization
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=config_4bit,
    device_map="auto"
)
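
Both variants load as ordinary transformers models. A quick way to confirm the savings and run generation (the tokenizer load is assumed, and the exact footprint will vary by setup):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# get_memory_footprint() reports the memory the loaded weights actually occupy
print(f"4-bit footprint: {model_4bit.get_memory_footprint() / 1e9:.1f} GB")

prompt = "Quantization lets a 7B model fit on"
inputs = tokenizer(prompt, return_tensors="pt").to(model_4bit.device)
outputs = model_4bit.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))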

GPTQ Quantization

# GPTQ: Accurate Post-Training Quantization
# pip install auto-gptq

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Quantization config
quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=False
)

# Load and quantize model
model = AutoGPTQForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantize_config
)

# Prepare examples for calibration
# (calibration_texts: a few hundred representative raw text samples from your own data)
examples = [tokenizer(text, return_tensors="pt") for text in calibration_texts]

# Quantize
model.quantize(examples)

# Save
model.save_quantized("./llama-7b-gptq")

AWQ Quantization

# AWQ: Activation-aware Weight Quantization
# pip install autoawq

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Load model and tokenizer
model_path = "meta-llama/Llama-2-7b-hf"
quant_path = "./llama-7b-awq"

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Quantize (standard AWQ settings: 4-bit weights, group size 128, GEMM kernels)
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
model.quantize(tokenizer, quant_config=quant_config)

# Save
model.save_quantized(quant_path)
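
As with GPTQ, the saved folder can be reloaded with from_quantized; saving the tokenizer alongside the weights keeps the checkpoint self-contained (a sketch):

# Keep the tokenizer with the quantized weights
tokenizer.save_pretrained(quant_path)

# Reload for inference (fuse_layers enables the fused AWQ kernels)
quantized = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)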

Accuracy vs Speed Trade-offs

import time

import torch

# Benchmark function
def evaluate_model(model, tokenizer, test_data):
    correct = 0
    total_time = 0.0

    for text, label in test_data:
        inputs = tokenizer(text, return_tensors="pt")

        start = time.time()
        with torch.no_grad():
            outputs = model(**inputs)
        total_time += time.time() - start

        predicted = outputs.logits.argmax(-1).item()
        if predicted == label:
            correct += 1

    accuracy = correct / len(test_data)
    avg_latency = total_time / len(test_data)

    return {"accuracy": accuracy, "latency_ms": avg_latency * 1000}

# Compare quantization methods (model_fp32/model_fp16/model_int8/model_int4 are
# the FP32 baseline and the quantized variants produced in the sections above)
results = {
    "FP32": evaluate_model(model_fp32, tokenizer, test_data),
    "FP16": evaluate_model(model_fp16, tokenizer, test_data),
    "INT8": evaluate_model(model_int8, tokenizer, test_data),
    "INT4": evaluate_model(model_int4, tokenizer, test_data)
}

for method, metrics in results.items():
    print(f"{method}: Accuracy={metrics['accuracy']:.2%}, Latency={metrics['latency_ms']:.1f}ms")

Best Practices

quantization_best_practices = {
    "choice": {
        "cpu_inference": "INT8 dynamic quantization",
        "gpu_inference": "FP16 or INT8",
        "memory_constrained": "INT4 (GPTQ/AWQ)",
        "accuracy_critical": "FP16 or calibrated INT8"
    },
    "calibration": {
        "data": "Use representative samples",
        "size": "100-1000 samples usually sufficient",
        "diversity": "Cover all input types"
    },
    "validation": {
        "always_test": "Compare accuracy before/after",
        "edge_cases": "Test on difficult examples",
        "latency": "Measure actual speedup"
    }
}

Tomorrow we’ll explore INT8 quantization in more detail.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.