
Knowledge Distillation: Advanced Techniques and Applications

Building on yesterday’s introduction, today we explore advanced knowledge distillation techniques and real-world applications.

Advanced Distillation Strategies

# Different types of knowledge to transfer
knowledge_types = {
    "response_based": "Output logits/probabilities",
    "feature_based": "Intermediate representations",
    "relation_based": "Relationships between samples/features"
}
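
The rest of this post builds on these categories. As a refresher, the response-based case (the classic setup from yesterday's introduction) boils down to blending a softened KL term against the teacher with the usual cross-entropy loss. Here is a minimal sketch; the function name and default temperature/alpha values are illustrative, not from any particular library:

import torch.nn.functional as F

def response_distillation_loss(student_logits, teacher_logits, labels,
                               temperature=4.0, alpha=0.7):
    """Blend softened teacher targets (KL) with the hard label loss (CE)."""
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss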

Self-Distillation

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfDistillation(nn.Module):
    """Model distills knowledge from its deeper layers to shallower ones."""

    def __init__(self, base_model, num_classifiers=4):
        super().__init__()
        self.base_model = base_model
        self.hidden_size = base_model.config.hidden_size
        self.num_layers = base_model.config.num_hidden_layers

        # Add classifiers at intermediate layers
        self.classifiers = nn.ModuleList([
            nn.Linear(self.hidden_size, base_model.config.num_labels)
            for _ in range(num_classifiers)
        ])

        # Layers to attach classifiers: evenly spaced, ending at the final layer
        self.classifier_layers = [
            (i + 1) * self.num_layers // num_classifiers
            for i in range(num_classifiers)
        ]

    def forward(self, inputs, labels=None):
        outputs = self.base_model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states

        # Get predictions from each classifier, pooling the [CLS] token so
        # each classifier sees a single vector per sequence
        all_logits = []
        for layer_idx, classifier in zip(self.classifier_layers, self.classifiers):
            pooled = hidden_states[layer_idx][:, 0]  # (batch, hidden_size)
            logits = classifier(pooled)
            all_logits.append(logits)

        if labels is not None:
            # Self-distillation loss
            loss = self.compute_self_distillation_loss(all_logits, labels)
            return loss, all_logits[-1]

        return all_logits[-1]

    def compute_self_distillation_loss(self, all_logits, labels, temperature=3.0):
        # The deepest classifier is the "teacher"
        teacher_logits = all_logits[-1].detach()

        total_loss = 0
        for logits in all_logits:
            # Hard loss
            hard_loss = F.cross_entropy(logits, labels)

            # Soft loss (distill from deepest)
            if logits is not all_logits[-1]:
                soft_loss = F.kl_div(
                    F.log_softmax(logits / temperature, dim=-1),
                    F.softmax(teacher_logits / temperature, dim=-1),
                    reduction='batchmean'
                ) * (temperature ** 2)
                total_loss += hard_loss + soft_loss
            else:
                total_loss += hard_loss

        return total_loss / len(all_logits)
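
Wiring this up looks roughly like the sketch below; the checkpoint, label count, and toy inputs are placeholders for illustration:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder checkpoint and toy inputs, purely for illustration
base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

model = SelfDistillation(base)
inputs = tokenizer(["a short example sentence"], return_tensors="pt")
labels = torch.tensor([1])

loss, final_logits = model(inputs, labels=labels)
loss.backward()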

Online Distillation

class MutualLearning:
    """Two models learn from each other simultaneously."""

    def __init__(self, model1, model2, temperature=4.0):
        self.model1 = model1
        self.model2 = model2
        self.temperature = temperature

    def train_step(self, inputs, labels, optimizer1, optimizer2):
        # Forward pass for both models
        logits1 = self.model1(**inputs).logits
        logits2 = self.model2(**inputs).logits

        # Cross-entropy loss
        ce_loss1 = F.cross_entropy(logits1, labels)
        ce_loss2 = F.cross_entropy(logits2, labels)

        # Mutual distillation loss
        kl_loss1 = F.kl_div(
            F.log_softmax(logits1 / self.temperature, dim=-1),
            F.softmax(logits2.detach() / self.temperature, dim=-1),
            reduction='batchmean'
        )
        kl_loss2 = F.kl_div(
            F.log_softmax(logits2 / self.temperature, dim=-1),
            F.softmax(logits1.detach() / self.temperature, dim=-1),
            reduction='batchmean'
        )

        # Total losses
        loss1 = ce_loss1 + kl_loss1
        loss2 = ce_loss2 + kl_loss2

        # Update both models
        optimizer1.zero_grad()
        loss1.backward()
        optimizer1.step()

        optimizer2.zero_grad()
        loss2.backward()
        optimizer2.step()

        return loss1.item(), loss2.item()
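
A hypothetical training loop might look like the following; the checkpoint, optimizers, and train_dataloader are illustrative assumptions rather than part of the method:

from transformers import AutoModelForSequenceClassification

# Hypothetical peers: two students of the same architecture (different
# architectures also work, as long as they share the same tokenized inputs)
model_a = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
model_b = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
optimizer_a = torch.optim.AdamW(model_a.parameters(), lr=5e-5)
optimizer_b = torch.optim.AdamW(model_b.parameters(), lr=5e-5)

mutual = MutualLearning(model_a, model_b)
for batch in train_dataloader:  # assumed: a DataLoader yielding tokenized batches with a "labels" key
    labels = batch.pop("labels")
    loss_a, loss_b = mutual.train_step(batch, labels, optimizer_a, optimizer_b)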

Feature Map Distillation

class FeatureDistillation(nn.Module):
    """Distill intermediate feature maps with learned projections."""

    def __init__(self, teacher_dims, student_dims):
        super().__init__()
        # Projectors to align dimensions
        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(s_dim, t_dim),
                nn.ReLU(),
                nn.Linear(t_dim, t_dim)
            )
            for s_dim, t_dim in zip(student_dims, teacher_dims)
        ])

    def forward(self, teacher_features, student_features):
        """
        teacher_features: list of tensors from teacher layers
        student_features: list of tensors from student layers
        """
        loss = 0
        for projector, t_feat, s_feat in zip(
            self.projectors, teacher_features, student_features
        ):
            # Project student features
            s_projected = projector(s_feat)

            # Normalize and compute loss
            t_norm = F.normalize(t_feat, p=2, dim=-1)
            s_norm = F.normalize(s_projected, p=2, dim=-1)

            loss += F.mse_loss(s_norm, t_norm)

        return loss / len(self.projectors)
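
A minimal sketch of the wiring, using hypothetical layer widths and random tensors in place of real hidden states (which you would normally collect with output_hidden_states=True on both models):

# Hypothetical widths for three matched teacher/student layers
teacher_dims = [768, 768, 768]
student_dims = [384, 384, 384]
feature_loss_fn = FeatureDistillation(teacher_dims, student_dims)

# Random tensors stand in for hidden states gathered during a forward pass
batch_size, seq_len = 8, 128
teacher_features = [torch.randn(batch_size, seq_len, dim) for dim in teacher_dims]
student_features = [torch.randn(batch_size, seq_len, dim) for dim in student_dims]

loss = feature_loss_fn(teacher_features, student_features)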

Distillation for Generation Tasks

class SequenceDistillation:
    """Distillation for autoregressive generation models."""

    def __init__(self, teacher, student, temperature=2.0):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature

    def compute_loss(self, input_ids, attention_mask, labels):
        # Teacher forward (no gradient)
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            teacher_logits = teacher_outputs.logits

        # Student forward
        student_outputs = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        student_logits = student_outputs.logits

        # Shift for causal LM
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_teacher = teacher_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # KL divergence loss per position
        kl_loss = F.kl_div(
            F.log_softmax(shift_logits / self.temperature, dim=-1),
            F.softmax(shift_teacher / self.temperature, dim=-1),
            reduction='none'
        ).sum(dim=-1)

        # Mask padding
        mask = (shift_labels != -100).float()
        kl_loss = (kl_loss * mask).sum() / mask.sum()

        # Hard loss
        ce_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

        return 0.5 * kl_loss * (self.temperature ** 2) + 0.5 * ce_loss
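
As a hypothetical pairing, GPT-2 variants share a tokenizer, which keeps the teacher and student logits aligned over the same vocabulary:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical teacher/student pairing that shares the GPT-2 tokenizer
teacher = AutoModelForCausalLM.from_pretrained("gpt2-medium").eval()
student = AutoModelForCausalLM.from_pretrained("distilgpt2")
tok = AutoTokenizer.from_pretrained("gpt2")

distiller = SequenceDistillation(teacher, student)
batch = tok(["Knowledge distillation compresses large language models."], return_tensors="pt")
labels = batch["input_ids"].clone()

loss = distiller.compute_loss(batch["input_ids"], batch["attention_mask"], labels)
loss.backward()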

Practical Distillation Pipeline

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
)

def distill_model(
    teacher_model_name,
    student_model_name,
    train_dataset,
    eval_dataset,
    output_dir,
    temperature=4.0,
    alpha=0.7,
    epochs=5
):
    """Complete distillation pipeline."""

    # Load models
    teacher = AutoModelForSequenceClassification.from_pretrained(teacher_model_name)
    student = AutoModelForSequenceClassification.from_pretrained(student_model_name)
    tokenizer = AutoTokenizer.from_pretrained(student_model_name)

    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad = False

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=32,
        learning_rate=5e-5,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True
    )

    # Accuracy metric so evaluate() can report eval_accuracy
    def compute_metrics(eval_pred):
        logits, label_ids = eval_pred
        return {"accuracy": (logits.argmax(axis=-1) == label_ids).mean()}

    # Custom trainer that blends the soft (teacher) and hard (label) losses
    trainer = DistillationTrainer(
        teacher_model=teacher,
        model=student,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        temperature=temperature,
        alpha=alpha
    )

    # Train
    trainer.train()

    # Evaluate
    results = trainer.evaluate()
    print(f"Final accuracy: {results['eval_accuracy']:.4f}")

    # Save
    student.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    return student, results
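
This pipeline relies on the DistillationTrainer introduced in yesterday's post. For readers landing here first, a minimal version consistent with how it is called above might look like the sketch below (a Trainer subclass that blends the softened KL term with the hard cross-entropy); the exact implementation from yesterday may differ:

from transformers import Trainer

class DistillationTrainer(Trainer):
    def __init__(self, teacher_model=None, temperature=4.0, alpha=0.7, **kwargs):
        super().__init__(**kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)

        # Teacher runs without gradients; it must sit on the same device as the student
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        soft_loss = F.kl_div(
            F.log_softmax(outputs.logits / self.temperature, dim=-1),
            F.softmax(teacher_outputs.logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        hard_loss = F.cross_entropy(outputs.logits, labels)

        loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss
        return (loss, outputs) if return_outputs else loss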

When to Use Distillation

distillation_use_cases = {
    "good_for": [
        "Deployment on edge devices",
        "Reducing inference latency",
        "When labeled data is limited",
        "Creating task-specific small models"
    ],
    "consider_alternatives": [
        "Quantization (simpler, often sufficient)",
        "Pruning (keeps architecture)",
        "Direct small model training (if data abundant)"
    ],
    "typical_compression": {
        "bert_large_to_distilbert": "60% size, ~97% accuracy",
        "gpt2_xl_to_gpt2": "75% size, ~95% quality"
    }
}

Tomorrow we’ll explore small language models.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.