
Model Distillation: Creating Efficient Student Models

Knowledge distillation trains smaller “student” models to mimic larger “teacher” models, achieving competitive performance with reduced compute requirements.

Distillation Overview

distillation_concept = {
    "teacher": "Large, accurate model",
    "student": "Smaller, faster model",
    "knowledge": "Soft labels (probabilities) from teacher",
    "benefit": "Student learns richer information than hard labels"
}

# Why soft labels help:
# Hard label: [0, 0, 1] (just the class)
# Soft label: [0.1, 0.2, 0.7] (relationships between classes)
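
As a quick, made-up illustration of why soft labels carry more signal: dividing the logits by a temperature before the softmax spreads probability mass across the non-target classes, exposing how the teacher relates them.

import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 4.0])  # hypothetical teacher logits for 3 classes

print(F.softmax(logits, dim=-1))        # T=1: peaked, close to a hard label
print(F.softmax(logits / 4.0, dim=-1))  # T=4: softer, reveals class similarities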

Basic Distillation Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation loss

    def forward(self, student_logits, teacher_logits, labels):
        # Soft loss (distillation)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        distillation_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Hard loss (standard cross-entropy)
        hard_loss = F.cross_entropy(student_logits, labels)

        # Combined loss
        return self.alpha * distillation_loss + (1 - self.alpha) * hard_loss

# Training loop
def train_distillation(teacher, student, train_loader, optimizer, epochs):
    teacher.eval()
    criterion = DistillationLoss(temperature=4.0, alpha=0.7)

    for epoch in range(epochs):
        student.train()
        total_loss = 0

        for batch in train_loader:
            inputs, labels = batch

            # Get teacher predictions
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            # Get student predictions
            student_logits = student(inputs)

            # Compute distillation loss
            loss = criterion(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}: Loss = {total_loss / len(train_loader):.4f}")

Distilling Transformers

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments
)

# Teacher: Large model
teacher = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased")

# Student: Smaller model
student = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Custom trainer for distillation
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, temperature=4.0, alpha=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        # Reuse the combined soft/hard loss defined earlier
        self.criterion = DistillationLoss(temperature=temperature, alpha=alpha)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")

        # Student forward pass
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        # Teacher forward pass (no gradients; teacher must sit on the same device)
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits

        # Combined distillation + cross-entropy loss
        loss = self.criterion(student_logits, teacher_logits, labels)

        return (loss, student_outputs) if return_outputs else loss

# Train with distillation
trainer = DistillationTrainer(
    teacher_model=teacher,
    model=student,
    args=TrainingArguments(output_dir="./distilled"),
    train_dataset=train_dataset,
    temperature=4.0,
    alpha=0.5
)
trainer.train()

Layer-wise Distillation

class LayerDistillationLoss(nn.Module):
    """Distill intermediate layer representations."""

    def __init__(self, teacher_dim, student_dim):
        super().__init__()
        # Projection to match dimensions
        self.projector = nn.Linear(student_dim, teacher_dim)

    def forward(self, teacher_hidden, student_hidden):
        # Project student to teacher dimension
        student_projected = self.projector(student_hidden)

        # MSE loss on hidden states
        return F.mse_loss(student_projected, teacher_hidden)

# Training with layer distillation
def layer_distillation_step(teacher, student, inputs, layer_mapping):
    """
    layer_mapping: dict mapping student layers to teacher layers
    e.g., {0: 0, 1: 2, 2: 4, 3: 6} for 4-layer student from 6-layer teacher
    """
    # Get all hidden states
    with torch.no_grad():
        teacher_outputs = teacher(**inputs, output_hidden_states=True)
    student_outputs = student(**inputs, output_hidden_states=True)

    # Layer-wise loss
    layer_loss = 0
    for student_layer, teacher_layer in layer_mapping.items():
        layer_loss += F.mse_loss(
            student_outputs.hidden_states[student_layer],
            teacher_outputs.hidden_states[teacher_layer]
        )

    return layer_loss
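
The layer loss on its own does not train the classification head, so it is typically added to the logit-level loss from earlier. A hedged sketch of one way to combine them; the weight is illustrative, and the forward passes are repeated here only for clarity.

def combined_distillation_loss(teacher, student, inputs, labels,
                               layer_mapping, layer_weight=0.1):
    # Logit-level (soft + hard) loss, reusing DistillationLoss from above
    with torch.no_grad():
        teacher_logits = teacher(**inputs).logits
    student_logits = student(**inputs).logits
    logit_loss = DistillationLoss(temperature=4.0, alpha=0.7)(
        student_logits, teacher_logits, labels
    )

    # Intermediate-representation loss
    layer_loss = layer_distillation_step(teacher, student, inputs, layer_mapping)

    return logit_loss + layer_weight * layer_loss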

Attention Distillation

def attention_distillation_loss(teacher_attentions, student_attentions, mapping):
    """Distill attention patterns from teacher to student."""
    loss = 0

    for student_layer, teacher_layer in mapping.items():
        teacher_attn = teacher_attentions[teacher_layer]  # [B, H, S, S]
        student_attn = student_attentions[student_layer]

        # If head counts differ, average teacher heads
        if teacher_attn.shape[1] != student_attn.shape[1]:
            teacher_attn = teacher_attn.mean(dim=1, keepdim=True)
            teacher_attn = teacher_attn.expand_as(student_attn)

        loss += F.mse_loss(student_attn, teacher_attn)

    return loss / len(mapping)
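
Obtaining the attention tensors follows the same pattern as the hidden states. A sketch using the bert-large/distilbert pair from earlier; the layer mapping is illustrative, and `inputs` is assumed to be a batch of tokenized tensors (input_ids, attention_mask).

with torch.no_grad():
    teacher_out = teacher(**inputs, output_attentions=True)
student_out = student(**inputs, output_attentions=True)

# Map DistilBERT's 6 layers onto every 4th layer of BERT-large's 24 (illustrative)
mapping = {i: 4 * i + 3 for i in range(6)}

# Head counts differ (12 vs 16), so the head-averaging branch above kicks in
attn_loss = attention_distillation_loss(
    teacher_out.attentions, student_out.attentions, mapping
)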

TinyBERT-style Distillation

class TinyBERTDistillation:
    """
    TinyBERT-style distillation combines four losses:
    1. Embedding layer
    2. Hidden states
    3. Attention matrices
    4. Prediction layer (soft logits)

    This sketch assumes the student has roughly half the teacher's layers
    and the same hidden size; TinyBERT itself learns linear projections
    when the student's hidden size is smaller.
    """

    def __init__(self, teacher, student, config):
        self.teacher = teacher
        self.student = student
        self.config = config

    def compute_loss(self, inputs, labels):
        # Forward passes with all outputs
        with torch.no_grad():
            teacher_out = self.teacher(
                **inputs,
                output_hidden_states=True,
                output_attentions=True
            )
        student_out = self.student(
            **inputs,
            output_hidden_states=True,
            output_attentions=True
        )

        # 1. Embedding loss
        emb_loss = F.mse_loss(
            student_out.hidden_states[0],
            teacher_out.hidden_states[0]
        )

        # 2. Hidden state loss
        hidden_loss = sum(
            F.mse_loss(s, t)
            for s, t in zip(
                student_out.hidden_states[1:],
                teacher_out.hidden_states[1::2]  # Every other layer
            )
        )

        # 3. Attention loss
        attn_loss = sum(
            F.mse_loss(s, t)
            for s, t in zip(
                student_out.attentions,
                teacher_out.attentions[::2]
            )
        )

        # 4. Prediction loss (soft logits at temperature T=4, scaled by T**2 = 16)
        pred_loss = F.kl_div(
            F.log_softmax(student_out.logits / 4, dim=-1),
            F.softmax(teacher_out.logits / 4, dim=-1),
            reduction='batchmean'
        ) * 16

        return emb_loss + hidden_loss + attn_loss + pred_loss

Distillation Best Practices

best_practices = {
    "temperature": {
        "typical": "2-10, higher for softer distributions",
        "start": "4 is a good default"
    },
    "alpha": {
        "typical": "0.5-0.9",
        "more_distillation": "Higher alpha, more weight on teacher"
    },
    "architecture": {
        "student": "Same architecture family, fewer layers/dims",
        "ratio": "3-10x parameter reduction typical"
    },
    "training": {
        "epochs": "May need more than fine-tuning",
        "learning_rate": "Can use higher LR for student",
        "data": "Same or more data than teacher training"
    }
}
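
To sanity-check the parameter-reduction ratio for the bert-large/distilbert pair loaded above (a rough sketch; exact counts depend on the classification head and vocabulary):

def count_params(model):
    return sum(p.numel() for p in model.parameters())

# Roughly ~335M (bert-large) vs ~66M (distilbert-base): about a 5x reduction
print(f"Teacher: {count_params(teacher) / 1e6:.0f}M parameters")
print(f"Student: {count_params(student) / 1e6:.0f}M parameters")
print(f"Ratio:   {count_params(teacher) / count_params(student):.1f}x")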

Tomorrow we’ll dive deeper into knowledge distillation techniques.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.