Transfer Learning
Info: Transfer learning lets you leverage knowledge from one task to improve performance on a different but related task. It often greatly reduces the data and compute needed to train deep learning models.
1. Why Transfer Learning Works
Deep neural networks learn hierarchical representations. Lower layers capture universal features (edges, textures, colors), while higher layers capture task-specific features (dog ears, car wheels). These universal features transfer across domains.
The Knowledge Transfer Spectrum
  Task A (Source)                               Task B (Target)
+----------------+                            +----------------+
| Conv Layer 1   | ---- edges, colors ------> | Conv Layer 1   |
| Conv Layer 2   | ---- textures -----------> | Conv Layer 2   |
| Conv Layer 3   | ---- parts --------------> | Conv Layer 3   |
| Conv Layer 4   | ---- object parts -------> | Conv Layer 4   |
| FC Layer       | ---- ImageNet classes ---> | FC Layer       | <- REPLACE
| Output: 1000   |                            | Output: 10     |
+----------------+                            +----------------+
Lower layers: GENERAL (transfer well)
Higher layers: SPECIFIC (need adaptation)
Formal Framework
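Given:
- Source task $T_S$ with dataset $D_S = \{(x_i^S, y_i^S)\}_{i=1}^{n_S}$
- Target task $T_T$ with dataset $D_T = \{(x_j^T, y_j^T)\}_{j=1}^{n_T}$ (typically $n_T \ll n_S$)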
Definition: Transfer Learning Objective
Transfer learning aims to improve the target predictor $f_T$ by leveraging knowledge from the source model $f_S$. The key insight is that features learned on large source datasets often transfer to related target tasks, even when the target dataset is small.
Info: Why Transfer Learning Works
Deep neural networks learn hierarchical representations: lower layers capture universal features (edges, textures) that transfer across domains, while higher layers capture task-specific features. By freezing the lower layers and training only the classifier head, we preserve these universal features while adapting to the new task. This is why transfer learning can work even when the source and target tasks differ.
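One way to make the objective concrete is sketched below. It is an illustrative formalization, not the document's own formula: the target predictor is written as a new head $h_\psi$ on top of a feature extractor $g_\theta$ whose weights start at the source weights $\theta_S$, and the loss $\ell$ is averaged over the $n_T$ labeled target examples $(x_j^T, y_j^T)$. The symbols $h_\psi$, $g_\theta$, $\theta_S$ and $\ell$ are assumptions introduced here.

```latex
f_T(x) = h_\psi\big(g_\theta(x)\big), \qquad
\min_{\psi,\,\theta}\; \frac{1}{n_T} \sum_{j=1}^{n_T}
  \ell\big(h_\psi(g_\theta(x_j^T)),\, y_j^T\big),
\qquad \theta \text{ initialized at } \theta_S
```

In this notation, feature extraction keeps $\theta$ fixed at $\theta_S$ and optimizes only $\psi$, while fine-tuning also updates $\theta$, usually with a smaller learning rate.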
2. Transfer Learning Strategies
Strategy 1: Feature Extraction (Frozen Backbone)
Freeze all pre-trained layers and only train a new classifier head.
import torch
import torch.nn as nn
import torchvision.models as models

def create_feature_extractor(num_classes, model_name='resnet50'):
    # Note: torchvision >= 0.13 deprecates pretrained=True in favor of weights=...
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
    elif model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    # Freeze all pre-trained layers
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head; newly created layers are trainable by default.
    # ResNet exposes the head as model.fc, VGG as model.classifier[6].
    if model_name == 'resnet50':
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    return model

model = create_feature_extractor(num_classes=5)
print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

When to Use Feature Extraction
Use when you have a small dataset (fewer than 1K images per class) and the domain is similar to the pre-trained data.
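A quick way to confirm that a frozen backbone really stays fixed is to take one optimizer step and compare weights before and after. The sketch below uses a tiny made-up `backbone` and `head` in place of a real pre-trained network, so no weights are downloaded; the layer sizes, learning rate, and random data are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained backbone (no weights are downloaded).
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 3)  # new classifier head for 3 target classes

# Feature extraction: freeze every backbone parameter.
for param in backbone.parameters():
    param.requires_grad = False

model = nn.Sequential(backbone, head)

# Only pass trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

backbone_before = [p.detach().clone() for p in backbone.parameters()]
head_before = head.weight.detach().clone()

# One training step on random data.
x = torch.randn(4, 8)
y = torch.tensor([0, 1, 2, 1])
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = all(
    torch.equal(before, after)
    for before, after in zip(backbone_before, backbone.parameters())
)
head_changed = not torch.equal(head_before, head.weight)
num_trainable = sum(p.numel() for p in trainable)
print(f"backbone unchanged: {backbone_unchanged}, head changed: {head_changed}")
```

The same before/after comparison can be applied to a real model as a sanity check that only the classifier head is being trained.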
Strategy 2: Fine-Tuning (Unfreeze Top Layers)
Unfreeze some or all pre-trained layers and train with a small learning rate.
def create_finetuner(num_classes, num_unfreeze=3):
    model = models.resnet50(pretrained=True)

    # ResNet-50 has 10 top-level children: conv1, bn1, relu, maxpool,
    # layer1..layer4, avgpool, fc. The last 3 are layer4, avgpool and fc,
    # so num_unfreeze=3 freezes everything before layer4.
    children = list(model.children())
    for child in children[:-num_unfreeze]:
        for param in child.parameters():
            param.requires_grad = False

    # Replace classifier
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    # Use differential learning rates: small for pre-trained layer4, larger for the new head
    optimizer = torch.optim.Adam([
        {'params': model.layer4.parameters(), 'lr': 1e-5},
        {'params': model.fc.parameters(), 'lr': 1e-3},
    ], weight_decay=1e-4)
    return model, optimizer

model, optimizer = create_finetuner(num_classes=5)
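To see how per-group learning rates and frozen layers fit together, the sketch below builds the same kind of optimizer on a tiny hypothetical network and then inspects the result. `TinyNet`, `build_optimizer`, and the attribute names `early`, `late`, and `head` are made up for illustration and stand in for the early layers, `layer4`, and `fc` of a real model.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained network with a new head."""
    def __init__(self):
        super().__init__()
        self.early = nn.Linear(8, 8)   # stays frozen
        self.late = nn.Linear(8, 8)    # fine-tuned slowly
        self.head = nn.Linear(8, 3)    # trained from scratch

    def forward(self, x):
        return self.head(torch.relu(self.late(torch.relu(self.early(x)))))

def build_optimizer(model, late_lr=1e-5, head_lr=1e-3):
    # Freeze the early layers and leave them out of the optimizer entirely.
    for param in model.early.parameters():
        param.requires_grad = False
    return torch.optim.Adam([
        {'params': model.late.parameters(), 'lr': late_lr},
        {'params': model.head.parameters(), 'lr': head_lr},
    ], weight_decay=1e-4)

model = TinyNet()
optimizer = build_optimizer(model)

# Inspect what the optimizer will actually do.
group_lrs = [group['lr'] for group in optimizer.param_groups]
group_decay = [group['weight_decay'] for group in optimizer.param_groups]
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"learning rates per group: {group_lrs}, frozen params: {frozen}")
```

Options given outside the group dictionaries, such as `weight_decay` here, apply to every group that does not override them.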
Strategy 3: Progressive Unfreezing
Gradually unfreeze layers during training: start with the classifier, then unfreeze earlier layers one group at a time.
class ProgressiveUnfreezer:
    def __init__(self, model):
        self.model = model
        self.layers = list(model.children())
        self.current_unfreeze = 0

    def unfreeze_next(self):
        """Unfreeze one more layer group from top to bottom."""
        if self.current_unfreeze >= len(self.layers):
            return  # everything is already unfrozen
        layer = self.layers[-(self.current_unfreeze + 1)]
        for param in layer.parameters():
            param.requires_grad = True
        self.current_unfreeze += 1
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Unfroze layer: {layer.__class__.__name__} | "
              f"Trainable: {trainable:,}")

# Usage
unfreezer = ProgressiveUnfreezer(model)
# After each validation plateau:
unfreezer.unfreeze_next()  # for ResNet-50: fc -> avgpool -> layer4 -> layer3 -> ...
# Newly unfrozen parameters must also belong to an optimizer param group to be updated.
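The usage above unfreezes "after each validation plateau" without saying how a plateau is detected. One simple possibility is sketched below; the `PlateauTrigger` class, its `patience` and `min_delta` parameters, and the accuracy values are all illustrative assumptions rather than part of the original recipe.

```python
class PlateauTrigger:
    """Signals when validation accuracy has stopped improving.

    `patience` and `min_delta` are illustrative assumptions, not values from the text.
    """
    def __init__(self, patience=2, min_delta=0.1):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('-inf')
        self.stale_epochs = 0

    def update(self, val_acc):
        """Return True when `patience` epochs pass without a gain of `min_delta`."""
        if val_acc > self.best + self.min_delta:
            self.best = val_acc
            self.stale_epochs = 0
            return False
        self.stale_epochs += 1
        if self.stale_epochs >= self.patience:
            self.stale_epochs = 0  # start counting again for the next stage
            return True
        return False

trigger = PlateauTrigger(patience=2, min_delta=0.1)
val_history = [60.0, 65.0, 65.05, 65.02, 70.0, 70.0, 70.0]  # made-up accuracies
unfreeze_epochs = [epoch for epoch, acc in enumerate(val_history) if trigger.update(acc)]
print(unfreeze_epochs)  # [3, 6]
```

In a training loop, the call would look like `if trigger.update(val_acc): unfreezer.unfreeze_next()`, placed right after validation.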
3. Domain Adaptation
When source and target domains differ, domain adaptation aligns their feature distributions.
Maximum Mean Discrepancy (MMD)
Minimize the distance between source and target feature distributions:
$$\mathrm{MMD}^2(D_S, D_T) = \left\| \frac{1}{n_S} \sum_{i=1}^{n_S} \phi(x_i^S) - \frac{1}{n_T} \sum_{j=1}^{n_T} \phi(x_j^T) \right\|_{\mathcal{H}}^2$$
Here,
- $D_S$ = Source domain dataset
- $D_T$ = Target domain dataset
- $\phi$ = Feature mapping (kernel-induced)
- $n_S, n_T$ = Number of source and target samples
Info: MMD Intuition
MMD measures the distance between two distributions in a reproducing kernel Hilbert space (RKHS) $\mathcal{H}$. For a characteristic kernel such as the Gaussian RBF, MMD is zero only when the two distributions are identical. By minimizing MMD during training, we push the feature extractor toward domain-invariant representations: features that look the same regardless of whether the input came from the source or target domain.
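In practice the feature mapping is never computed explicitly. Expanding the squared norm gives three kernel averages: source-source, target-target, and source-target. The sketch below is a minimal, biased estimator under the assumption of a Gaussian RBF kernel with a fixed bandwidth `sigma`; the function names, the bandwidth, and the random data are illustrative choices.

```python
import torch

def rbf_kernel(a, b, sigma=1.0):
    """Gaussian RBF kernel matrix between the rows of a and the rows of b."""
    sq_dists = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-sq_dists / (2 * sigma ** 2))

def mmd2(source_feats, target_feats, sigma=1.0):
    """Biased estimate of squared MMD between two batches of features."""
    k_ss = rbf_kernel(source_feats, source_feats, sigma).mean()
    k_tt = rbf_kernel(target_feats, target_feats, sigma).mean()
    k_st = rbf_kernel(source_feats, target_feats, sigma).mean()
    return k_ss + k_tt - 2 * k_st

torch.manual_seed(0)
source = torch.randn(64, 4)
target_same = torch.randn(64, 4)            # same distribution as source
target_shifted = torch.randn(64, 4) + 2.0   # mean-shifted distribution

mmd_self = mmd2(source, source).item()
mmd_same = mmd2(source, target_same).item()
mmd_shifted = mmd2(source, target_shifted).item()
print(f"self: {mmd_self:.4f}, same dist: {mmd_same:.4f}, shifted: {mmd_shifted:.4f}")
```

Because every operation is differentiable, a term like `mmd2(source_feats, target_feats)` could be added to the task loss with a weighting factor. The result is sensitive to the bandwidth, so a fixed `sigma=1.0` should be read as a placeholder.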
Adversarial Domain Adaptation
Train a domain discriminator to distinguish source vs. target features, while the feature extractor learns to fool it:
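Adversarial Domain Adaptation Objective
$$\min_{G} \max_{D} \; \mathcal{L}_{task}(G) - \lambda \, \mathcal{L}_{domain}(G, D)$$
Here,
- $G$ = Feature generator (shared backbone)
- $D$ = Domain discriminator
- $\mathcal{L}_{task}$ = Task loss (classification, etc.)
- $\mathcal{L}_{domain}$ = Domain classification loss
- $\lambda$ = Trade-off parameter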
class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by -lambda_ in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient flows to lambda_, hence the trailing None
        return -ctx.lambda_ * grad_output, None

class DomainAdversarialNetwork(nn.Module):
    def __init__(self, feature_extractor, num_classes, feature_dim=2048, hidden_dim=256):
        super().__init__()
        # feature_extractor should produce feature_dim values per sample (2048 for ResNet-50)
        self.feature_extractor = feature_extractor
        self.task_classifier = nn.Linear(feature_dim, num_classes)
        self.domain_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 2),  # source vs target
        )

    def forward(self, x, alpha=1.0):
        features = torch.flatten(self.feature_extractor(x), 1)
        task_output = self.task_classifier(features)
        # Gradient reversal: the domain classifier trains normally, while the
        # feature extractor receives the negated domain gradient scaled by alpha
        reversed_features = GradientReversalFunction.apply(features, alpha)
        domain_output = self.domain_classifier(reversed_features)
        return task_output, domain_output
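The reversal strength is often ramped up from 0 during training so that noisy domain gradients do not dominate the early steps. The sketch below uses the sigmoid-shaped schedule commonly associated with domain-adversarial training; the function name `grl_lambda`, the default `gamma=10.0`, and `total_steps` are assumptions for illustration, not values taken from this document.

```python
import math

def grl_lambda(progress, gamma=10.0):
    """Ramp the gradient-reversal weight from 0 toward 1 as training progresses.

    progress: fraction of training completed, in [0, 1].
    gamma: steepness of the ramp (an assumed default).
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0

total_steps = 1000  # illustrative value
schedule = [grl_lambda(step / total_steps) for step in (0, 250, 500, 1000)]
print([round(value, 3) for value in schedule])  # [0.0, 0.848, 0.987, 1.0]
```

On each training step, the current value would be passed to the network as the gradient-reversal strength.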
4. Pre-trained Model Zoo
Image Models
| Model | Parameters | ImageNet Top-1 Acc | Best For |
|---|---|---|---|
| ResNet-50 | 25.6M | 76.1% | General purpose, fast |
| EfficientNet-B0 | 5.3M | 77.1% | Efficiency-focused |
| EfficientNet-B7 | 66M | 84.3% | Maximum accuracy |
| ViT-B/16 | 86M | 77.9% | Vision Transformer |
| ConvNeXt-B | 89M | 83.8% | Modern CNN |
NLP Models
| Model | Parameters | Task | Best For |
|---|---|---|---|
| BERT-base | 110M | Text classification | General NLP |
| RoBERTa-large | 355M | Various | High accuracy |
| GPT-2 | 1.5B | Text generation | Language modeling |
| DistilBERT | 66M | Distillation | Efficiency |
| sentence-BERT | 110M | Sentence embeddings | Semantic similarity |
Using Pre-trained Models
import torch
from transformers import AutoModel, AutoTokenizer

# Load pre-trained BERT
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')

# Tokenize input
text = "Transfer learning is powerful!"
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

# Get embeddings
with torch.no_grad():
    outputs = model(**inputs)
embeddings = outputs.last_hidden_state[:, 0, :]  # [CLS] token
print(f"Embedding shape: {embeddings.shape}")  # torch.Size([1, 768])
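Taking the `[CLS]` vector is one pooling choice. A common alternative for sentence-level features is to average the token embeddings while ignoring padding, using the attention mask returned by the tokenizer. The sketch below runs on small hand-written tensors in place of real model outputs, so it needs no download; the `mean_pool` helper and the toy values are illustrative.

```python
import torch

def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings, ignoring padded positions."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, T, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                    # [B, H]
    counts = mask.sum(dim=1).clamp(min=1.0)                           # [B, 1]
    return summed / counts

# Toy tensors standing in for model outputs: batch of 2, 3 tokens, hidden size 2.
hidden = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last token is padding
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
])
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
pooled = mean_pool(hidden, mask)
print(pooled)  # rows: [2., 3.] and [2., 2.]
```

With a real model, the call would be `mean_pool(outputs.last_hidden_state, inputs['attention_mask'])`.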
5. Practical Workflow
When to Use Which Strategy
| Dataset Size | Domain Similarity | Strategy |
|---|---|---|
| Small (<1K) | Very Similar | Feature Extraction |
| Small (<1K) | Different | Feature Extraction + DA |
| Medium (1K-10K) | Similar | Fine-tune Top Layers |
| Medium (1K-10K) | Different | Progressive Unfreezing + DA |
| Large (>10K) | Any | Full Fine-tuning |
| Very Large (>100K) | Any | Train from Scratch (maybe) |
DA = domain adaptation.
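The table can be read as a simple lookup. The sketch below encodes it as a function, treating the size boundaries as inclusive for the medium bucket; the function name `choose_strategy` and the boolean `domain_similar` flag are illustrative simplifications, and the thresholds are the rules of thumb from the table rather than hard limits.

```python
def choose_strategy(num_images, domain_similar):
    """Look up a starting strategy from the table's size and similarity buckets.

    num_images: size of the labeled target dataset.
    domain_similar: True if the target domain resembles the pre-training data.
    """
    if num_images > 100_000:
        return "Train from Scratch (maybe)"
    if num_images > 10_000:
        return "Full Fine-tuning"
    if num_images >= 1_000:
        return "Fine-tune Top Layers" if domain_similar else "Progressive Unfreezing + DA"
    return "Feature Extraction" if domain_similar else "Feature Extraction + DA"

print(choose_strategy(500, domain_similar=True))     # Feature Extraction
print(choose_strategy(5_000, domain_similar=False))  # Progressive Unfreezing + DA
```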
Complete Fine-tuning Pipeline
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import OneCycleLR

def fine_tune_pipeline():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # 1. Data: augment the training set only; keep validation deterministic
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder('data/train', transform=train_transform)
    val_dataset = datasets.ImageFolder('data/val', transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # 2. Model
    model = models.resnet50(pretrained=True)
    # Freeze early layers
    for name, param in model.named_parameters():
        if 'layer4' not in name and 'fc' not in name:
            param.requires_grad = False
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(train_dataset.classes)),
    )
    model = model.to(device)

    # 3. Optimizer with differential learning rates
    # Note: OneCycleLR replaces these per-group 'lr' values with its own
    # schedule, which peaks at the matching max_lr entry.
    optimizer = torch.optim.AdamW([
        {'params': model.layer4.parameters(), 'lr': 1e-5},
        {'params': model.fc.parameters(), 'lr': 1e-3},
    ], weight_decay=0.01)
    scheduler = OneCycleLR(optimizer, max_lr=[1e-4, 1e-2],
                           steps_per_epoch=len(train_loader), epochs=20)
    criterion = nn.CrossEntropyLoss()

    # 4. Training
    best_val_acc = 0.0
    for epoch in range(20):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # OneCycleLR steps once per batch

        # Validate
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        val_acc = 100. * correct / total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
        print(f"Epoch {epoch+1:2d} | Val Acc: {val_acc:.1f}%")
    print(f"Best Val Acc: {best_val_acc:.1f}%")

# fine_tune_pipeline()
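One detail of the pipeline is easy to miss: once `OneCycleLR` is attached, the `lr` values given to the optimizer are no longer the rates actually used. With default settings the schedule starts each group at `max_lr / 25` and peaks at `max_lr`. The sketch below checks this on two tiny stand-in layers; `backbone`, `head`, and `total_steps=100` are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

backbone = nn.Linear(4, 4)  # hypothetical stand-in for pre-trained layers
head = nn.Linear(4, 2)      # hypothetical stand-in for the new classifier

optimizer = torch.optim.AdamW([
    {'params': backbone.parameters(), 'lr': 1e-5},
    {'params': head.parameters(), 'lr': 1e-3},
], weight_decay=0.01)
scheduler = OneCycleLR(optimizer, max_lr=[1e-4, 1e-2], total_steps=100)

# The scheduler replaces the optimizer's 'lr' values as soon as it is created.
start_lrs = [group['lr'] for group in optimizer.param_groups]

# Step through the schedule and record the highest rate each group reaches.
peak_lrs = list(start_lrs)
for _ in range(99):
    optimizer.step()   # no gradients were computed, so parameters are untouched
    scheduler.step()
    current = [group['lr'] for group in optimizer.param_groups]
    peak_lrs = [max(p, c) for p, c in zip(peak_lrs, current)]

print(f"start: {start_lrs}, peak: {peak_lrs}")
```

The ratio between the two groups is preserved throughout, so the pre-trained layers still move about 100 times more slowly than the new head.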
Theorem: Negative Transfer
Negative transfer occurs when transfer learning hurts target performance instead of helping it. This happens when:
- Source and target domains are too dissimilar
- The source model has memorized source-specific features that don't generalize
- The fine-tuning learning rate is too high, destroying useful source knowledge
One theoretical condition for positive transfer is that the source and target tasks share a common low-dimensional subspace in the feature space. If no such shared structure exists, transfer learning can be harmful.
Example: Differential Learning Rates in Practice
Scenario: Fine-tuning ResNet-50 (pre-trained on ImageNet) for medical imaging.
Strategy: Use different learning rates for different layers:
- Conv layers 1-3 (universal features): a very small learning rate, or none at all (nearly frozen)
- Conv layer 4 (higher-level features): a small learning rate such as $10^{-5}$ (slow adaptation)
- FC layer (new classifier): a larger learning rate such as $10^{-3}$ (fast learning)
Why this works: Lower layers detect edges and textures that transfer well to medical images. Higher layers detect object parts that need mild adaptation. The new classifier learns task-specific decision boundaries. This hierarchical approach prevents catastrophic forgetting while allowing adaptation.
6. Common Pitfalls
Catastrophic Forgetting
When fine-tuning erases source knowledge, the model loses its ability to perform well on the original task.
Solution: Use lower learning rates, freeze early layers, employ EWC (Elastic Weight Consolidation).
Elastic Weight Consolidation (EWC)
$$\mathcal{L}_{EWC}(\theta) = \mathcal{L}_{target}(\theta) + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta^*_{S,i}\right)^2$$
Here,
- $\theta$ = Current model parameters
- $\theta^*_S$ = Optimal parameters from source task
- $F_i$ = Fisher information (importance) of parameter $\theta_i$
- $\lambda$ = Regularization strength
Info: EWC Intuition
EWC prevents catastrophic forgetting by penalizing changes to parameters that are important for the source task. The Fisher information $F_i$ measures how sensitive the source task loss is to changes in parameter $\theta_i$: parameters with high Fisher information are "important" and should not change much during fine-tuning. This is a form of parameter-space regularization that preserves source knowledge while allowing adaptation.
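The penalty is straightforward to compute once a per-parameter importance estimate is available. The sketch below uses the empirical diagonal Fisher (mean squared per-example gradient on source data), which is a common approximation rather than the only option. The helper names, `lam=100.0`, the tiny linear model, and the random "source" data are all illustrative assumptions.

```python
import torch
import torch.nn as nn

def diagonal_fisher(model, inputs, labels):
    """Estimate diagonal Fisher information as the mean squared per-example gradient."""
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    for x, y in zip(inputs, labels):
        model.zero_grad()
        loss = nn.functional.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        for name, p in model.named_parameters():
            fisher[name] += p.grad.detach() ** 2 / len(inputs)
    model.zero_grad()
    return fisher

def ewc_penalty(model, fisher, source_params, lam=100.0):
    """(lam / 2) * sum_i F_i * (theta_i - theta_source_i)^2."""
    penalty = torch.zeros(())
    for name, p in model.named_parameters():
        penalty = penalty + (fisher[name] * (p - source_params[name]) ** 2).sum()
    return 0.5 * lam * penalty

torch.manual_seed(0)
model = nn.Linear(4, 3)                 # hypothetical stand-in for a source model
source_x = torch.randn(16, 4)           # made-up "source" data
source_y = torch.randint(0, 3, (16,))

fisher = diagonal_fisher(model, source_x, source_y)
source_params = {name: p.detach().clone() for name, p in model.named_parameters()}

penalty_at_source = ewc_penalty(model, fisher, source_params).item()
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)                      # simulate drift during fine-tuning
penalty_after_drift = ewc_penalty(model, fisher, source_params).item()
print(f"at source: {penalty_at_source:.4f}, after drift: {penalty_after_drift:.4f}")
```

During fine-tuning, the total loss would be the target task loss plus `ewc_penalty(model, fisher, source_params)`, with the Fisher estimate and source weights computed once before fine-tuning starts.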
Negative Transfer
When source knowledge hurts target performance.
Solution: Use similarity metrics to select source tasks; train with and without transfer and compare.
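The "train with and without transfer and compare" check is more trustworthy when repeated over a few random seeds, since a single run can differ by chance. The sketch below summarizes such a comparison; the function name `compare_transfer` and all accuracy values are made up for illustration.

```python
from statistics import mean, stdev

def compare_transfer(transfer_accs, scratch_accs):
    """Compare validation accuracies (one per random seed) with and without transfer."""
    gain = mean(transfer_accs) - mean(scratch_accs)
    return {
        'transfer_mean': mean(transfer_accs),
        'scratch_mean': mean(scratch_accs),
        'gain': gain,
        # Seed-to-seed spread, useful for judging whether the gain is meaningful.
        'spread': max(stdev(transfer_accs), stdev(scratch_accs)),
        'negative_transfer': gain < 0,
    }

# Made-up accuracies from three seeds each, for illustration only.
helpful = compare_transfer([82.0, 84.0, 83.0], [75.0, 77.0, 76.0])
harmful = compare_transfer([68.0, 70.0, 69.0], [72.0, 74.0, 73.0])
print(helpful['gain'], harmful['gain'])  # 7.0 -4.0
```

A gain that is small relative to `spread` is weak evidence either way, and more seeds or a larger validation set would be needed to decide.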
7. Key Takeaways
Summary: Transfer Learning
- Feature extraction (frozen backbone) is the fastest and safest approach for small datasets (fewer than 1K images per class)
- Fine-tuning with differential learning rates gives best results when data is sufficient; use smaller LR for pre-trained layers, larger LR for new layers
- Progressive unfreezing prevents catastrophic forgetting and stabilizes training by gradually unfreezing layers from top to bottom
- Domain adaptation (MMD, adversarial training) helps when source and target domains differ; MMD minimizes distributional distance in feature space
- EWC penalizes changes to important parameters using Fisher information, preserving source knowledge during fine-tuning
- Negative transfer occurs when source and target tasks are too dissimilar; always validate transfer benefits empirically
- Catastrophic forgetting is mitigated by lower learning rates, early layer freezing, and regularization techniques
8. Practice Exercises
Exercise 1: Compare Strategies
# Train the same model three ways on a small dataset (e.g., 500 images):
# 1. Feature extraction (all frozen except FC)
# 2. Fine-tune last 2 blocks
# 3. Full fine-tuning
# Compare accuracy and training time
Exercise 2: Domain Adaptation
# Use MNIST as source, MNIST-M (colored MNIST) as target
# Implement a simple adversarial domain adaptation network
# Compare with and without domain adaptation
Exercise 3: Model Selection
# For a medical imaging task with 200 images per class:
# - Which pre-trained model would you choose? Why?
# - Which fine-tuning strategy? Justify.
# - Implement and evaluate