Knowledge Distillation for LLMs — Compressing Intelligence
Knowledge distillation transfers the capabilities of large, expensive models into smaller, faster ones. This guide covers the theory and practice of distilling LLM knowledge for deployment.
- Response-Based Distillation — Train student to match teacher outputs
- Feature-Based Distillation — Transfer internal representations
- Chain-of-Thought Distillation — Distill reasoning capabilities specifically
The goal is not a smaller model — it is a smaller model that thinks like a larger one.
Knowledge Distillation for LLMs
Knowledge distillation enables deploying capable models at a fraction of the cost. A 70B parameter model distilled into a 7B model can retain 80-95% of the teacher's performance on many tasks while being up to 10x faster and cheaper to run.
Definition: Knowledge Distillation
Knowledge distillation is a model compression technique where a smaller "student" model is trained to reproduce the behavior of a larger "teacher" model. The student learns from the teacher's soft probability distributions, which contain richer information than hard labels.
The Distillation Framework
Temperature-Scaled Softmax
Definition: Soft Labels
When a teacher model outputs probabilities, the soft labels (probability distributions over vocabulary) contain more information than hard labels (one-hot). The "dark knowledge" in soft labels captures relationships between classes that hard labels miss.
Temperature-Scaled Softmax
$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
Here,
- $p_i$ = Soft probability for class i
- $z_i$ = Logit for class i
- $T$ = Temperature parameter (higher = softer)
Temperature T controls the "softness" of the probability distribution. T=1 gives the standard softmax. T=2-5 is typical for distillation, making the distribution softer and revealing more about the teacher's knowledge of class relationships.
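To make the effect of T concrete, here is a minimal sketch in plain Python that applies the temperature-scaled softmax to a made-up set of three logits (the values are an assumption for illustration, not taken from any real model) and reports the entropy of the result. Higher entropy means a softer distribution.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

logits = [6.0, 3.0, 1.0]  # hypothetical teacher logits for three classes
results = {}
for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(logits, T)
    results[T] = probs
    print(f"T={T}: probs={[round(p, 3) for p in probs]}, entropy={entropy(probs):.3f}")
```

As T grows, probability mass moves from the top class to the others, so the ranking among the non-top classes becomes visible to the student.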
Distillation Loss
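$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{CE}} + (1 - \alpha) \, T^2 \, \mathrm{KL}(p_t \,\|\, p_s)$$
Here,
- $\mathcal{L}$ = Total distillation loss
- $\mathcal{L}_{\text{CE}}$ = Cross-entropy loss with hard labels
- $\mathrm{KL}(p_t \,\|\, p_s)$ = KL divergence between soft distributions
- $p_t$ = Teacher soft probabilities at temperature T
- $p_s$ = Student soft probabilities at temperature T
- $\alpha$ = Weight balancing hard and soft losses
- $T$ = Temperature parameter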
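import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # weight on the hard-label loss

    def forward(self, student_logits, teacher_logits, labels):
        # Flatten (batch, seq_len, vocab) to (tokens, vocab) so both losses are
        # averaged per token; 2-D inputs pass through unchanged. Labels are
        # assumed to be aligned with the logits already.
        vocab_size = student_logits.size(-1)
        student_logits = student_logits.reshape(-1, vocab_size)
        teacher_logits = teacher_logits.reshape(-1, vocab_size)
        labels = labels.reshape(-1)

        # Soft loss: KL(teacher || student) between softened distributions
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_loss = F.kl_div(
            student_soft, teacher_soft,
            reduction="batchmean"
        ) * (self.temperature ** 2)  # T^2 keeps the gradient scale comparable across T

        # Hard loss: standard cross-entropy with true labels
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss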
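The multiplication by the squared temperature in the soft loss can look arbitrary. The short derivation below sketches where it comes from. The notation is introduced here for illustration: $p_s$ and $p_t$ are the student and teacher soft probabilities, $z_i$ and $v_i$ are the student and teacher logits for class $i$, and $N$ is the number of classes. The second step additionally assumes that $T$ is large relative to the logits and that both sets of logits are zero-mean.

```latex
% Gradient of the soft loss with respect to a student logit
\frac{\partial}{\partial z_i} \, \mathrm{KL}(p_t \,\|\, p_s)
  = \frac{1}{T} \left( p_{s,i} - p_{t,i} \right)

% High-temperature approximation, using exp(x) \approx 1 + x
\frac{1}{T} \left( p_{s,i} - p_{t,i} \right)
  \approx \frac{1}{T} \left( \frac{1 + z_i / T}{N} - \frac{1 + v_i / T}{N} \right)
  = \frac{1}{N T^2} \left( z_i - v_i \right)
```

Under these assumptions the soft-loss gradient shrinks as $1/T^2$, so multiplying the soft loss by $T^2$ keeps its contribution roughly constant relative to the hard loss when the temperature is changed.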
Types of Knowledge Distillation
Response-Based Distillation
Definition: Response-Based Distillation
Response-based distillation trains the student to match the teacher's final output probabilities. This is the most common form — the student learns to produce the same token probabilities as the teacher for each input.
def response_distillation_step(teacher, student, batch, loss_fn):
    # The teacher is frozen, so its forward pass needs no gradients
    with torch.no_grad():
        teacher_outputs = teacher(batch["input_ids"])
    student_outputs = student(batch["input_ids"])
    loss = loss_fn(student_outputs.logits, teacher_outputs.logits, batch["labels"])
    return loss
Feature-Based Distillation
Definition: Feature-Based Distillation
Feature-based distillation transfers knowledge by matching intermediate representations (hidden states) between teacher and student. The student learns not just what the teacher outputs, but how it processes information internally.
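Feature Matching Loss
$$\mathcal{L}_{\text{feat}} = \sum_{l \in \mathcal{M}} \left\| W h_l^{s} - h_{m(l)}^{t} \right\|_2^2$$
Here,
- $h_l^{s}$ = Student hidden state at layer l
- $h_{m(l)}^{t}$ = Teacher hidden state at mapped layer
- $W$ = Learned projection matrix
- $\mathcal{M}$ = Set of layers to match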
class FeatureDistillationLoss(nn.Module):
    def __init__(self, teacher_dim, student_dim):
        super().__init__()
        # Projects student hidden states into the teacher's hidden size
        self.projection = nn.Linear(student_dim, teacher_dim)

    def forward(self, student_hidden, teacher_hidden):
        projected = self.projection(student_hidden)
        return F.mse_loss(projected, teacher_hidden)
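Feature matching also needs a rule for pairing each student layer with a teacher layer. One simple choice, shown here as a sketch rather than the only option, is an evenly spaced mapping in which the last student layer is paired with the last teacher layer. The function name `uniform_layer_map` and the 4-layer and 12-layer depths are assumptions made for this example.

```python
def uniform_layer_map(num_student_layers, num_teacher_layers):
    """Map each student layer l to an evenly spaced teacher layer m(l).

    Layers are 0-indexed; the last student layer maps to the last teacher layer.
    """
    if num_teacher_layers % num_student_layers != 0:
        raise ValueError(
            "this sketch assumes the teacher depth is a multiple of the student depth"
        )
    stride = num_teacher_layers // num_student_layers
    return {l: (l + 1) * stride - 1 for l in range(num_student_layers)}

layer_map = uniform_layer_map(4, 12)
print(layer_map)  # {0: 2, 1: 5, 2: 8, 3: 11}
```

The feature loss is then evaluated once per pair in the map and summed over the matched layers.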
Attention-Based Distillation
Definition: Attention Transfer
Attention transfer matches the attention maps of teacher and student. The student learns to attend to the same parts of the input as the teacher, capturing its "focus" patterns.
def attention_distillation_loss(teacher_attn, student_attn, eps=1e-8):
    """Match attention distributions between teacher and student.

    Expects attention probabilities (already softmax-normalized) of shape
    (batch, heads, query_len, key_len).
    """
    teacher_attn = teacher_attn.mean(dim=1)  # Average over heads
    student_attn = student_attn.mean(dim=1)
    # The maps are already probability distributions, so no second softmax
    # is applied; clamp before log to avoid log(0) at masked positions.
    return F.kl_div(
        student_attn.clamp_min(eps).log(), teacher_attn,
        reduction="batchmean"
    )
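Teacher and student frequently have different numbers of attention heads, which is one reason to average over heads before comparing maps. The sketch below illustrates this with random attention probabilities; the shapes (12 teacher heads, 4 student heads, sequence length 5) and the helper name `attention_kl` are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F

def attention_kl(teacher_attn, student_attn, eps=1e-8):
    """KL(teacher || student) over head-averaged attention maps.

    Both inputs are attention probabilities of shape
    (batch, heads, query_len, key_len); head counts may differ.
    """
    t = teacher_attn.mean(dim=1).clamp_min(eps)
    s = student_attn.mean(dim=1).clamp_min(eps)
    # Sum over keys, then average over batch and query positions
    return (t * (t.log() - s.log())).sum(dim=-1).mean()

torch.manual_seed(0)
teacher_attn = F.softmax(torch.randn(2, 12, 5, 5), dim=-1)  # 12 heads
student_attn = F.softmax(torch.randn(2, 4, 5, 5), dim=-1)   # 4 heads

loss_mismatch = attention_kl(teacher_attn, student_attn)
loss_same = attention_kl(teacher_attn, teacher_attn)
print(loss_mismatch.item(), loss_same.item())
```

The loss is zero when the averaged maps coincide and positive otherwise, regardless of how many heads each model has.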
Chain-of-Thought Distillation
Distilling Reasoning Capabilities
Definition: Chain-of-Thought Distillation
CoT distillation specifically transfers reasoning capabilities by having the teacher generate step-by-step reasoning traces, which the student then learns to reproduce. This is crucial for maintaining reasoning performance in smaller models.
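def cot_distillation(teacher, student, tokenizer, problems):
    """Distill chain-of-thought reasoning from teacher to student.

    Assumes Hugging Face-style causal LMs on the same device that share
    one tokenizer (`tokenizer` is an added argument).
    """
    device = next(student.parameters()).device
    total_loss = 0
    for problem in problems:
        prompt = tokenizer(problem["question"], return_tensors="pt").to(device)
        # Teacher generates a reasoning trace (prompt followed by continuation)
        with torch.no_grad():
            trace = teacher.generate(
                **prompt,
                max_new_tokens=1024,
                do_sample=True,  # sampling must be on for temperature to apply
                temperature=0.3,
            )
        # Student learns to produce the same reasoning; prompt tokens are
        # masked with -100 so the loss covers only the teacher's trace
        labels = trace.clone()
        labels[:, : prompt["input_ids"].shape[1]] = -100
        student_output = student(input_ids=trace, labels=labels)
        total_loss += student_output.loss
    return total_loss / len(problems)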
CoT distillation is particularly effective for mathematical reasoning. Models distilled with CoT traces retain significantly more mathematical capability than those distilled only on final answers.
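One detail that CoT training has to get right is which tokens contribute to the loss: the student should be graded on reproducing the reasoning trace, not on re-predicting the question. The following is a minimal sketch of that computation in plain PyTorch, assuming a causal language model where position t predicts token t+1 and every row has the same prompt length. The function name `trace_only_loss` and the random tensors are illustrative only.

```python
import torch
import torch.nn.functional as F

def trace_only_loss(logits, input_ids, prompt_len):
    """Next-token cross-entropy over the reasoning trace only.

    logits: (batch, seq_len, vocab) from the student
    input_ids: (batch, seq_len), prompt tokens followed by trace tokens
    prompt_len: number of prompt tokens (same for every row in this sketch)
    """
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100      # ignore prompt positions
    shift_logits = logits[:, :-1, :]   # position t predicts token t+1
    shift_labels = labels[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )

torch.manual_seed(0)
vocab, prompt_len, seq_len = 11, 3, 7
input_ids = torch.randint(0, vocab, (2, seq_len))
logits = torch.randn(2, seq_len, vocab)
loss = trace_only_loss(logits, input_ids, prompt_len)
print(loss.item())
```

With a 3-token prompt and a 7-token sequence, only the 4 trace tokens in each row are scored.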
Selective Distillation
Definition: Selective Distillation
Not all teacher outputs are equally valuable. Selective distillation focuses on examples where the teacher is most confident and accurate, avoiding distillation from noisy or incorrect teacher outputs.
def selective_distillation(teacher, student, dataset, loss_fn, confidence_threshold=0.9):
    """Accumulate distillation loss only on batches where the teacher is confident.

    `loss_fn` takes (student_logits, teacher_logits, labels), e.g. DistillationLoss.
    """
    selective_loss = 0
    count = 0
    for batch in dataset:
        # Run the teacher once and reuse its logits for both the
        # confidence check and the loss
        with torch.no_grad():
            teacher_logits = teacher(batch["input_ids"]).logits
            teacher_probs = F.softmax(teacher_logits, dim=-1)
            max_probs, _ = teacher_probs.max(dim=-1)
            confidence = max_probs.mean().item()
        if confidence > confidence_threshold:
            student_out = student(batch["input_ids"])
            loss = loss_fn(student_out.logits, teacher_logits, batch["labels"])
            selective_loss += loss
            count += 1
    return selective_loss / max(count, 1)
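Filtering whole batches by their mean confidence is coarse: a single batch can mix tokens the teacher is sure about with tokens it is guessing. A finer-grained variant, sketched here as one possible alternative rather than the method described above, applies the same threshold per token and averages the KL term only over the tokens that pass. The function name `token_selective_kl` and the toy logits are assumptions.

```python
import torch
import torch.nn.functional as F

def token_selective_kl(student_logits, teacher_logits, threshold=0.9, temperature=1.0):
    """KL(teacher || student) averaged over tokens where the teacher is confident.

    Both logits tensors have shape (batch, seq_len, vocab).
    Returns (loss, fraction_of_tokens_kept).
    """
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    keep = teacher_probs.max(dim=-1).values > threshold  # (batch, seq_len)
    student_log_soft = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    per_token_kl = F.kl_div(student_log_soft, teacher_soft, reduction="none").sum(dim=-1)
    kept = keep.sum().clamp_min(1)  # avoid division by zero when nothing passes
    loss = (per_token_kl * keep).sum() / kept
    return loss, keep.float().mean().item()

torch.manual_seed(0)
teacher_logits = torch.zeros(1, 4, 5)  # 4 tokens, 5-word toy vocabulary
teacher_logits[0, 0, 2] = 10.0         # tokens 0 and 1: sharply peaked
teacher_logits[0, 1, 0] = 10.0         # tokens 2 and 3: uniform, low confidence
student_logits = torch.randn(1, 4, 5)

loss, kept_fraction = token_selective_kl(student_logits, teacher_logits)
loss_same, _ = token_selective_kl(teacher_logits, teacher_logits)
print(loss.item(), kept_fraction)
```

Only the two peaked tokens contribute to the loss here; the two uniform tokens are skipped.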
Multi-Teacher Distillation
Definition: Multi-Teacher Distillation
Use multiple teacher models (e.g., a general model and a specialized model) to provide complementary knowledge to the student. The student learns from the ensemble of teachers.
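Multi-Teacher Loss
$$\mathcal{L}_{\text{MT}} = \sum_{k=1}^{K} w_k \, \mathrm{KL}\left(p_t^{(k)} \,\|\, p_s\right)$$
Here,
- $K$ = Number of teacher models
- $w_k$ = Weight for teacher k
- $p_t^{(k)}$ = Soft probabilities from teacher k
- $p_s$ = Student soft probabilities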
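def multi_teacher_distillation(teachers, student, batch, weights, temperature=4.0):
    # One student forward pass, shared across all teachers
    student_logits = student(batch["input_ids"]).logits
    vocab_size = student_logits.size(-1)
    # Flatten to (tokens, vocab) so "batchmean" averages per token
    student_log_soft = F.log_softmax(
        student_logits / temperature, dim=-1
    ).reshape(-1, vocab_size)
    combined_teacher_loss = 0
    for teacher, weight in zip(teachers, weights):
        with torch.no_grad():
            teacher_logits = teacher(batch["input_ids"]).logits
        teacher_soft = F.softmax(
            teacher_logits / temperature, dim=-1
        ).reshape(-1, vocab_size)
        loss = F.kl_div(student_log_soft, teacher_soft, reduction="batchmean")
        combined_teacher_loss += weight * loss
    return combined_teacher_loss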
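When the teacher weights sum to 1, the weighted sum of per-teacher KL terms differs from the KL divergence against the weighted average of the teacher distributions only by a term that does not depend on the student. The two objectives therefore give the same gradients, which is one way to read the statement that the student learns from the ensemble. The sketch below checks this numerically on random logits; the function names, shapes, and weights are assumptions made for the example.

```python
import torch
import torch.nn.functional as F

def weighted_teacher_kl(student_logits, teacher_logits_list, weights, temperature=4.0):
    """sum_k w_k * KL(p_t^(k) || p_s) on temperature-softened distributions."""
    student_log_soft = F.log_softmax(student_logits / temperature, dim=-1)
    loss = 0.0
    for teacher_logits, w in zip(teacher_logits_list, weights):
        teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
        loss = loss + w * F.kl_div(student_log_soft, teacher_soft, reduction="batchmean")
    return loss

def mixture_teacher_kl(student_logits, teacher_logits_list, weights, temperature=4.0):
    """KL(sum_k w_k p_t^(k) || p_s): distill from the averaged teacher distribution."""
    student_log_soft = F.log_softmax(student_logits / temperature, dim=-1)
    mixture = sum(
        w * F.softmax(t / temperature, dim=-1)
        for t, w in zip(teacher_logits_list, weights)
    )
    return F.kl_div(student_log_soft, mixture, reduction="batchmean")

torch.manual_seed(0)
teachers = [torch.randn(8, 20), torch.randn(8, 20)]  # hypothetical logits from two teachers
weights = [0.7, 0.3]                                  # must sum to 1 for the equivalence

s1 = torch.randn(8, 20, requires_grad=True)
s2 = s1.detach().clone().requires_grad_(True)
loss_weighted = weighted_teacher_kl(s1, teachers, weights)
loss_mixture = mixture_teacher_kl(s2, teachers, weights)
loss_weighted.backward()
loss_mixture.backward()
grads_match = torch.allclose(s1.grad, s2.grad, atol=1e-6)
print(loss_weighted.item(), loss_mixture.item(), grads_match)
```

The loss values differ, but the gradients with respect to the student logits agree, so either form can be used for training when the weights are normalized.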
Practical Distillation Pipeline
Full Training Loop
def distill(teacher, student, train_dataloader, val_dataloader, epochs=3, lr=5e-5):
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
    loss_fn = DistillationLoss(temperature=4.0, alpha=0.3)
    teacher.eval()  # teacher is frozen: no dropout, no gradient updates
    for epoch in range(epochs):
        # Re-enter training mode each epoch in case evaluation switched it off
        student.train()
        total_loss = 0
        for batch in train_dataloader:
            with torch.no_grad():
                teacher_outputs = teacher(batch["input_ids"])
            student_outputs = student(batch["input_ids"])
            loss = loss_fn(
                student_outputs.logits,
                teacher_outputs.logits,
                batch["labels"]
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch + 1}: Loss = {avg_loss:.4f}")
        # Evaluate; `evaluate` is a user-supplied helper not defined in this guide
        evaluate(student, val_dataloader)
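To try the loop end to end without downloading real models, the sketch below runs soft-label distillation between two toy networks on CPU. Everything here is an assumption for illustration: `TinyLM` is a made-up stand-in for a language model, the data is random token IDs, and only the soft loss is used because there are no ground-truth labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    """Toy stand-in for a language model: per-token embedding, then a linear head."""
    def __init__(self, vocab_size, hidden):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))  # (batch, seq_len, vocab)

torch.manual_seed(0)
vocab_size, T = 50, 4.0
teacher = TinyLM(vocab_size, hidden=64).eval()  # larger, frozen
student = TinyLM(vocab_size, hidden=8)          # smaller, trainable
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-2)
input_ids = torch.randint(0, vocab_size, (16, 12))

losses = []
for step in range(200):
    with torch.no_grad():
        teacher_logits = teacher(input_ids)
    student_logits = student(input_ids)
    # Soft loss only: KL(teacher || student) per token, scaled by T^2
    loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1).reshape(-1, vocab_size),
        F.softmax(teacher_logits / T, dim=-1).reshape(-1, vocab_size),
        reduction="batchmean",
    ) * T ** 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first step: {losses[0]:.4f}, last step: {losses[-1]:.4f}")
```

The loss should fall as the student's output distribution moves toward the teacher's. A real run would add the hard-label term, batching over a dataset, and validation.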
Distillation vs Other Compression Methods
| Method | Accuracy Retention | Speedup | Memory Reduction | Training Cost |
|---|---|---|---|---|
| Distillation | 80-95% | 3-10x | 3-10x | High (need teacher) |
| Quantization | 95-99% | 2-4x | 2-4x | Low |
| Pruning | 90-98% | 2-5x | 2-5x | Medium |
| Low-rank | 85-95% | 2-3x | 2-3x | Medium |
The best approach often combines multiple techniques: distill to a smaller architecture, then quantize for further compression. A 7B model distilled from 70B and quantized to INT4 can run on consumer GPUs while retaining significant capability.
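The arithmetic behind that claim is simple to check. The sketch below estimates the memory needed for model weights alone, assuming 16-bit weights as the unquantized baseline (an assumption, since the text does not state the teacher's precision) and ignoring activations, the KV cache, and quantization overhead such as scales and zero points.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed to store the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

teacher_fp16 = weight_memory_gb(70e9, 16)  # 70B teacher, 16-bit weights
student_fp16 = weight_memory_gb(7e9, 16)   # 7B student, 16-bit weights
student_int4 = weight_memory_gb(7e9, 4)    # 7B student, 4-bit weights

print(teacher_fp16, student_fp16, student_int4)  # 140.0 14.0 3.5
```

Distillation accounts for the 10x drop from 140 GB to 14 GB, and INT4 quantization a further 4x to 3.5 GB, for a combined 40x reduction in weight storage.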
Practice Exercises
- Distillation Design: Design a distillation pipeline to compress a 70B teacher into a 7B student. What temperature, alpha, and training data would you use?
- CoT Distillation: Implement chain-of-thought distillation for mathematical reasoning. How would you measure whether reasoning capabilities transfer effectively?
- Multi-Teacher Ensemble: If you have a general-purpose teacher and a code-specialized teacher, how would you combine their knowledge for a student that needs both capabilities?
- Distillation Analysis: Compare distillation from a 70B model vs. training a 7B model from scratch on the same data. What are the tradeoffs?
Key Takeaways
Summary: Knowledge Distillation for LLMs
- Soft labels contain richer information than hard labels (dark knowledge)
- Temperature scaling controls how soft the teacher's distribution is, and so how much of its dark knowledge the student sees
- Response-based distillation matches output probabilities (most common)
- Feature-based distillation transfers internal representations
- Attention transfer matches where the model "looks"
- CoT distillation specifically transfers reasoning capabilities
- Selective distillation focuses on high-confidence teacher outputs
- Multi-teacher distillation combines complementary knowledge sources
- Combined approaches (distillation + quantization) often give the best compression
What to Learn Next
- QLoRA and Quantization: Reducing model size through quantization techniques.
- LoRA and PEFT: Parameter-efficient fine-tuning for large models.
- Distributed Training for LLMs: Scaling training across hundreds of GPUs.
- Curriculum Learning for LLMs: Strategic ordering of training data.
- LLM Inference Optimization: Making LLM inference faster and cheaper.
- Open-Source LLM Ecosystem: Pre-trained and distilled models available today.