Transfer Learning: Fine-tuning Pre-trained Models
Transfer learning is one of the most powerful paradigms in modern deep learning. Instead of training networks from scratch — a process demanding millions of labeled images and weeks of GPU time — we repurpose knowledge encoded in models trained on massive datasets. This lecture covers the theory, mechanics, and practical strategies for applying transfer learning to real-world problems.
1. What is Transfer Learning?
Transfer learning is the practice of taking a model trained on one task (the source task) and adapting it to a different but related task (the target task).
1.1 Why Transfer Learning Works
Deep networks learn hierarchical representations. Early layers capture universal visual primitives — edges, textures, color gradients. Middle layers compose these into motifs: corners, contours, repeated patterns. Deeper layers encode task-specific semantics — object parts, faces, scenes.
This hierarchy exhibits a key property: lower layers are more general, higher layers are more specific. The same Gabor-like edge detectors that help ImageNet classification are also useful for medical image segmentation or satellite imagery analysis.
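Formally, consider a source dataset $\mathcal{D}_S = \{(x_i^S, y_i^S)\}_{i=1}^{n_S}$ and a target dataset $\mathcal{D}_T = \{(x_i^T, y_i^T)\}_{i=1}^{n_T}$. Training on $\mathcal{D}_S$ yields a feature map $\phi_S(x)$. Transfer learning assumes that $\phi_S(x)$ shares structure with an optimal representation for the target task, so adapting it needs far less target data than learning a representation from scratch.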
The key insight: natural images share statistical structure. A model that has learned to recognize 1,000 ImageNet classes has implicitly learned useful features for many other visual tasks.
1.2 The Taxonomy of Transfer
| Transfer Type | Relationship between source and target | Example |
|---|---|---|
| Inductive | Different tasks, labeled target data (domains may differ) | ImageNet → X-ray classification |
| Unsupervised | Same domain, no target labels | ImageNet → feature extraction for clustering |
| Domain Adaptation | Different domains, same task | Synthetic → real images for segmentation |
| Multi-task | Shared features, multiple tasks | Single backbone for detection + depth |
1.3 When Does Transfer Help?
As a rough heuristic rather than a precise law, the benefit of transfer depends on two factors:

$$\text{Transfer Gain} \propto \underbrace{\text{Task Similarity}(\mathcal{T}_S, \mathcal{T}_T)}_{\text{how related are the tasks?}} \times \underbrace{\frac{|\mathcal{D}_S|}{|\mathcal{D}_T|}}_{\text{data scarcity ratio}}$$
- High similarity + small target dataset: Maximum benefit (e.g., ImageNet → CIFAR-10)
- Low similarity + large target dataset: Transfer offers little gain and may even hurt (negative transfer)
- Any similarity + tiny dataset: Transfer is usually beneficial, though still worth validating against a from-scratch baseline
2. Pre-trained Models
2.1 ImageNet and the ILSVRC Benchmark
The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) provided 1.28 million training images across 1,000 categories. This dataset catalyzed the deep learning revolution and remains the standard source for pre-trained visual features.
2.2 Architecture Evolution
2.3 Choosing a Pre-trained Model
| Model | Parameters | Top-1 (ImageNet) | Best For |
|---|---|---|---|
| ResNet-18 | 11.7M | 69.8% | Quick prototyping, edge deployment |
| ResNet-50 | 25.6M | 76.1% (80.9% with `IMAGENET1K_V2` weights) | Good balance of speed/accuracy |
| EfficientNet-B0 | 5.3M | 77.1% | Mobile/embedded |
| EfficientNet-B4 | 19.3M | 82.9% | General-purpose |
| EfficientNet-B7 | 66.3M | 84.3% | Maximum accuracy |
| ConvNeXt-B | 89M | 83.8% | Transformer-like performance |
| ViT-B/16 | 86M | 77.9% | Vision Transformer baseline |
3. Feature Extraction vs. Fine-tuning
3.1 Feature Extraction
In feature extraction mode, the pre-trained backbone is treated as a fixed feature extractor. Only the newly added classification head is trained.
```python
import torch
import torch.nn as nn
import torchvision.models as models

# Load pre-trained ResNet-50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace classifier head (newly created layers are trainable by default)
num_classes = 10
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes),
)

# Only model.fc parameters are trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
# Trainable: 527,114 / 24,035,146 (2.2%)
```
3.2 When to Use Each Approach
| Criterion | Feature Extraction | Fine-tuning |
|---|---|---|
| Training data | Fewer than 1,000 images | More than 1,000 images |
| Domain similarity | High (e.g., natural images) | Low (e.g., medical, satellite) |
| Compute budget | Low | High |
| Accuracy need | Moderate | High |
| Training time | Minutes to hours | Hours to days |
| Risk of overfitting | Low | Moderate to high |
4. Fine-tuning Strategies
4.1 Full Fine-tuning
Unfreeze all layers and train the entire network. Use this when you have sufficient data and compute.
```python
import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 10  # number of target classes

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# All parameters trainable with a uniform learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
```
4.2 Partial Fine-tuning
Freeze early layers, fine-tune later layers. This is the most common practical approach.
```python
import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 10

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze everything except layer4 and fc
for name, param in model.named_parameters():
    if not name.startswith(("layer4", "fc")):
        param.requires_grad = False

# The new head is created after freezing, so it is trainable by default
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Pass only the unfrozen parameters to the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
```
4.3 Gradual Unfreezing
Start by training only the classifier, then progressively unfreeze deeper layers. The technique was popularized for NLP by the ULMFiT paper.
```python
import torch.nn as nn

class GradualUnfreezer:
    def __init__(self, model):
        self.model = model
        # Layer groups ordered bottom-to-top, each stored with a readable name
        self.layer_groups = [
            ("stem", nn.ModuleList([model.conv1, model.bn1])),
            ("layer1", model.layer1),
            ("layer2", model.layer2),
            ("layer3", model.layer3),
            ("layer4", model.layer4),
            ("fc", model.fc),
        ]
        # Start with every parameter frozen, including the stem
        for param in model.parameters():
            param.requires_grad = False

    def unfreeze_next_group(self):
        """Unfreeze the next still-frozen layer group, working top to bottom."""
        for name, group in reversed(self.layer_groups):
            if all(not p.requires_grad for p in group.parameters()):
                for param in group.parameters():
                    param.requires_grad = True
                trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
                print(f"Unfroze: {name} | Trainable params: {trainable:,}")
                return
        print("All layers already unfrozen.")

# Usage (model: a torchvision ResNet with its new fc head): unfreeze one group per epoch
freezer = GradualUnfreezer(model)
freezer.unfreeze_next_group()  # Epoch 1: fc only
freezer.unfreeze_next_group()  # Epoch 2: fc + layer4
freezer.unfreeze_next_group()  # Epoch 3: + layer3
# ...
```
4.4 Layer Freezing Visualization
The diagram below shows how gradients flow through frozen and unfrozen layers:
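You can observe the same gradient flow directly in code. This minimal sketch uses a toy `nn.Sequential` as a stand-in for a pre-trained backbone, an assumption made purely for illustration. Frozen layers receive no `.grad` at all. PyTorch optimizers skip parameters whose `.grad` is `None`, so you can register every parameter with the optimizer up front. Layers unfrozen later are then picked up without rebuilding the optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for "pre-trained backbone + new head" (hypothetical layer sizes)
model = nn.Sequential(
    nn.Linear(8, 16),   # early layer: frozen
    nn.ReLU(),
    nn.Linear(16, 16),  # later layer: frozen for now
    nn.ReLU(),
    nn.Linear(16, 3),   # new classification head: trainable
)
for p in list(model[0].parameters()) + list(model[2].parameters()):
    p.requires_grad = False

# Registering ALL parameters is safe: optimizers skip parameters whose .grad is None
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 8), torch.tensor([0, 1, 2, 0])
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

for idx in (0, 2, 4):
    print(idx, model[idx].weight.grad is None)
# 0 True
# 2 True
# 4 False   -> gradients stop at the frozen layers

frozen_before = model[0].weight.clone()
head_before = model[4].weight.clone()
optimizer.step()
print(torch.equal(frozen_before, model[0].weight))  # True: frozen weights unchanged
```

Setting `requires_grad = True` on `model[2]` later would let the next `backward()` fill its gradients, and the same optimizer would then update it.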
5. Learning Rate Differentiation
5.1 Differential Learning Rates
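Lower layers encode universal features that already generalize well, so they need only slight adjustment. Higher layers encode more task-specific features that need more substantial adaptation. A single learning rate for all layers is therefore often suboptimal: a rate large enough to adapt the top layers can overwrite the general features below. A common remedy is to decay the learning rate geometrically from the top layer down:

$$\eta_l = \eta_{\text{base}} \cdot \gamma^{\,L - l}, \qquad l = 1, \dots, L$$

where $\eta_l$ is the learning rate for layer $l$, $\eta_{\text{base}}$ is the learning rate of the top layer, $L$ is the total number of layers, and $\gamma$ is a decay factor (typically 0.1 to 0.5).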
5.2 Practical Implementation
```python
import torch

# `model`: ResNet-50 with its new fc head (see Section 4)
# 6 parameter groups with decreasing learning rates, bottom to top
param_groups = [
    # Lowest: universal features (stem = conv1 + bn1, so bn1 is not left out of the optimizer)
    {"params": list(model.conv1.parameters()) + list(model.bn1.parameters()), "lr": 1e-6},
    {"params": model.layer1.parameters(), "lr": 1e-6},
    {"params": model.layer2.parameters(), "lr": 1e-5},
    {"params": model.layer3.parameters(), "lr": 1e-5},
    {"params": model.layer4.parameters(), "lr": 1e-4},  # Higher: task-specific
    {"params": model.fc.parameters(), "lr": 1e-3},      # Highest: new classifier
]
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)

# Or use a simple ratio-based approach
def get_layer_lr(base_lr, layer_idx, total_layers, decay=0.1):
    """Geometric decay: layer_idx 0 is the lowest layer, total_layers - 1 the top.

    The top layer gets base_lr; each layer below gets `decay` times the LR of the one above.
    """
    return base_lr * decay ** (total_layers - 1 - layer_idx)
```
5.3 Learning Rate Scheduling
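Schedulers can be combined with differential LRs. A scheduler computes each param group's rate from that group's own initial LR, so lower layers keep smaller rates throughout training:
```python
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# `optimizer`: the differential-LR optimizer from Section 5.2
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-7)

unfreezer = GradualUnfreezer(model)  # from Section 4.3; freezes every group
unfreezer.unfreeze_next_group()      # begin by training the head only

# Training loop. train_one_epoch and evaluate stand in for your own training and
# validation functions; num_epochs and warmup_epochs are values you choose.
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_acc = evaluate(model, val_loader)
    scheduler.step()  # one scheduler step per epoch
    # Unfreeze another group once warm-up ends; a val_acc plateau is an alternative trigger
    if epoch == warmup_epochs:
        unfreezer.unfreeze_next_group()
```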
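To see how the warm-restart schedule treats several param groups, the sketch below steps it with two toy groups standing in for "backbone" and "head". The LR values are illustrative only. Each group follows its own cosine curve and jumps back to its initial LR at epochs 10 and 30 (cycle lengths 10, then 20, because `T_mult=2`). One design point matters here: `eta_min` is a single value shared by every group. Keep it below your smallest group LR, or the lowest layers will be pushed *up* toward the floor near the end of each cycle.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Two toy parameter groups standing in for "backbone" and "head" (illustrative values)
backbone_w = torch.nn.Parameter(torch.zeros(3))
head_w = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD(
    [{"params": [backbone_w], "lr": 1e-4}, {"params": [head_w], "lr": 1e-3}],
    lr=1e-3,  # default, overridden by each group's own lr
)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-7)

history = []  # history[epoch] = [backbone_lr, head_lr] used during that epoch
for epoch in range(31):
    history.append([g["lr"] for g in optimizer.param_groups])
    optimizer.step()   # a real loop would compute gradients before this
    scheduler.step()

for epoch in (0, 5, 9, 10, 30):
    backbone_lr, head_lr = history[epoch]
    print(f"epoch {epoch:2d}: backbone {backbone_lr:.2e}  head {head_lr:.2e}")
```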
6. Domain Adaptation
When source and target domains differ (e.g., synthetic vs. real images, daytime vs. nighttime), domain adaptation aligns feature distributions.
6.1 Transfer Learning Concept Diagram
6.2 Maximum Mean Discrepancy (MMD)
Minimize the distance between source and target feature distributions:
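$$\text{MMD}^2(\mathcal{D}_S, \mathcal{D}_T) = \left\| \frac{1}{n_S} \sum_{i=1}^{n_S} \phi(x_i^S) - \frac{1}{n_T} \sum_{j=1}^{n_T} \phi(x_j^T) \right\|_{\mathcal{H}}^2$$

where $\phi$ is a kernel-induced feature mapping into a reproducing kernel Hilbert space $\mathcal{H}$, and $n_S$ and $n_T$ are the numbers of source and target samples.
Intuition: with a characteristic kernel such as the Gaussian RBF kernel, MMD is zero only when the two distributions are identical. Minimizing MMD on the network's features during training therefore pushes the feature extractor toward domain-invariant representations.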
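The kernel trick turns the squared norm above into three averages of kernel values: $\overline{k(S,S)} + \overline{k(T,T)} - 2\,\overline{k(S,T)}$. The sketch below computes this biased estimate with a Gaussian RBF kernel on toy 2-D samples. NumPy is used here only to show the arithmetic. Inside a network you would compute the same quantity on mini-batch features in PyTorch so it can be backpropagated. The bandwidth `sigma=1.0` is an arbitrary assumption; the median pairwise distance is a common alternative.

```python
import numpy as np

def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian RBF kernel matrix: k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    sq_dists = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma**2))

def mmd2(source: np.ndarray, target: np.ndarray, sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD via the kernel trick."""
    k_ss = rbf_kernel(source, source, sigma).mean()
    k_tt = rbf_kernel(target, target, sigma).mean()
    k_st = rbf_kernel(source, target, sigma).mean()
    return float(k_ss + k_tt - 2.0 * k_st)

rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(200, 2))
target_same = rng.normal(0.0, 1.0, size=(200, 2))     # same distribution as source
target_shifted = rng.normal(1.5, 1.0, size=(200, 2))  # mean-shifted "target domain"

print(f"same distribution:    {mmd2(source, target_same):.4f}")
print(f"shifted distribution: {mmd2(source, target_shifted):.4f}")
```

The same-distribution value is small but not exactly zero because of sampling noise. The shifted one is clearly larger, and that gap is the signal an MMD loss term drives down.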
6.3 Adversarial Domain Adaptation
Train a domain discriminator to distinguish source vs. target features, while the feature extractor learns to fool it:
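$$\min_{G, C} \max_{D} \; \mathcal{L}_{\text{task}}(C \circ G) - \lambda \, \mathcal{L}_{\text{domain}}(D \circ G)$$

where $G$ is the feature generator, $C$ is the task classifier trained on labeled source data, $D$ is the domain discriminator, $\mathcal{L}_{\text{domain}}$ is the discriminator's source-vs-target cross-entropy, and $\lambda$ controls the trade-off. A gradient reversal layer implements this min-max in a single backward pass: it is the identity in the forward pass and multiplies gradients by $-\lambda$ in the backward pass.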
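```python
import torch
import torch.nn as nn

class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by -lambda_ in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # lambda_ needs no gradient, hence the trailing None
        return -ctx.lambda_ * grad_output, None

class DomainAdversarialNetwork(nn.Module):
    def __init__(self, feature_extractor, num_classes, feature_dim=2048, hidden_dim=256):
        super().__init__()
        # feature_extractor must return flattened features of shape (N, feature_dim),
        # e.g. a ResNet-50 with model.fc = nn.Identity() gives feature_dim=2048
        self.feature_extractor = feature_extractor
        self.task_classifier = nn.Linear(feature_dim, num_classes)
        self.domain_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 2),  # source vs target
        )

    def forward(self, x, alpha=1.0):
        features = self.feature_extractor(x)
        task_output = self.task_classifier(features)
        # Gradient reversal: the domain classifier learns to separate the domains,
        # while the reversed gradient pushes the feature extractor to confuse it
        reversed_features = GradientReversalFunction.apply(features, alpha)
        domain_output = self.domain_classifier(reversed_features)
        return task_output, domain_output
```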
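A training step ties these pieces together. The sketch below is self-contained and makes several simplifying assumptions: tiny stand-in layers with hypothetical sizes, random data, and the λ ramp $\lambda_p = 2/(1 + e^{-10p}) - 1$ from the DANN paper by Ganin and Lempitsky, where $p$ is training progress from 0 to 1. The ramp keeps the noisy early domain signal from disrupting the features. Both loss terms are simply added and minimized; the reversal layer supplies the sign flip for the feature extractor. Treat this as an illustration, not a tuned recipe.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradReverse(torch.autograd.Function):
    # Same gradient reversal as above, repeated so this sketch runs on its own
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None

def grl_lambda(progress: float, gamma: float = 10.0) -> float:
    """Ramp lambda from 0 toward 1 as training progress goes from 0 to 1."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0

torch.manual_seed(0)
# Tiny stand-ins for the real networks (hypothetical sizes)
features = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
task_head = nn.Linear(32, 4)
domain_head = nn.Linear(32, 2)
params = [*features.parameters(), *task_head.parameters(), *domain_head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.01)

def dann_step(x_src, y_src, x_tgt, lambda_):
    """One update: task loss on labeled source data + domain loss on both domains."""
    optimizer.zero_grad()
    f_src, f_tgt = features(x_src), features(x_tgt)
    task_loss = F.cross_entropy(task_head(f_src), y_src)
    f_all = torch.cat([f_src, f_tgt])
    domain_labels = torch.cat([
        torch.zeros(len(x_src), dtype=torch.long),  # 0 = source
        torch.ones(len(x_tgt), dtype=torch.long),   # 1 = target
    ])
    domain_loss = F.cross_entropy(domain_head(GradReverse.apply(f_all, lambda_)), domain_labels)
    (task_loss + domain_loss).backward()  # the GRL flips the domain gradient reaching `features`
    optimizer.step()
    return task_loss.item(), domain_loss.item()

x_src, y_src = torch.randn(8, 16), torch.randint(0, 4, (8,))
x_tgt = torch.randn(8, 16)  # target batch: no labels needed
total_steps = 100
for step in range(total_steps):
    task_loss, domain_loss = dann_step(x_src, y_src, x_tgt, grl_lambda(step / total_steps))
print(f"final task loss {task_loss:.3f}, domain loss {domain_loss:.3f}")

# Check the reversal itself: d(sum(GRL(x)))/dx equals -lambda everywhere
x = torch.ones(3, requires_grad=True)
GradReverse.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000])
```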
7. Data Augmentation
Data augmentation is critical for transfer learning — it artificially expands small datasets and improves generalization.
7.1 Data Augmentation Examples
7.2 Implementation with torchvision
```python
from torchvision import transforms

# Training augmentation pipeline
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
])

# Validation: no augmentation, only resize and normalize
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
7.3 Advanced Augmentation: Mixup and CutMix
These techniques create virtual training examples by blending existing ones:
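Mixup: sample $\lambda \sim \text{Beta}(\alpha, \alpha)$ and blend two examples and their labels:

$$\tilde{x} = \lambda x_i + (1 - \lambda) x_j, \qquad \tilde{y} = \lambda y_i + (1 - \lambda) y_j$$

CutMix: cut a rectangular patch from image $x_j$ and paste it onto image $x_i$. Labels are mixed in proportion to the area each image contributes.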
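```python
import numpy as np
import torch

def mixup_data(x, y, alpha=0.2):
    """Blend each example with a randomly chosen partner from the same batch."""
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0):
    """Paste a random patch from a partner image; lam is recomputed from the actual patch area."""
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0), device=x.device)
    _, _, h, w = x.shape
    cut_ratio = np.sqrt(1.0 - lam)
    ch, cw = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = (int(v) for v in np.clip([cy - ch // 2, cy + ch // 2], 0, h))
    x1, x2 = (int(v) for v in np.clip([cx - cw // 2, cx + cw // 2], 0, w))
    mixed_x = x.clone()  # avoid modifying the caller's batch in place
    mixed_x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    # Correct lam for patches clipped at the image border
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return mixed_x, y, y[index], lam
```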
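Both helpers return two label tensors and a mixing weight, so the loss must combine them. The sketch below shows a common pattern: interpolate the losses against `y_a` and `y_b`. It checks that this equals cross-entropy against the mixed soft label, which holds because cross-entropy is linear in the target distribution. `mixed_criterion` is a hypothetical helper name, and the toy classifier and shapes are illustrative. Soft-label targets in `nn.CrossEntropyLoss` assume PyTorch 1.10 or later.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def mixed_criterion(criterion, outputs, y_a, y_b, lam):
    """Loss for a Mixup/CutMix batch: interpolate the losses against both label sets."""
    return lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)

torch.manual_seed(0)
np.random.seed(0)
num_classes = 5
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, num_classes))  # toy classifier
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 3, 8, 8)
y = torch.tensor([0, 1, 2, 3])
lam = float(np.random.beta(0.2, 0.2))
index = torch.randperm(x.size(0))
mixed_x = lam * x + (1 - lam) * x[index]  # same blend as mixup_data
y_a, y_b = y, y[index]

outputs = model(mixed_x)
loss = mixed_criterion(criterion, outputs, y_a, y_b, lam)
loss.backward()

# Equivalent formulation: cross-entropy against the mixed soft label
soft_targets = (lam * F.one_hot(y_a, num_classes).float()
                + (1 - lam) * F.one_hot(y_b, num_classes).float())
soft_loss = criterion(outputs.detach(), soft_targets)
print(f"lam={lam:.3f}  interpolated={loss.item():.4f}  soft-label={soft_loss.item():.4f}")
```

The interpolated form is convenient because it reuses integer labels and any existing criterion unchanged.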
8. Implementation in PyTorch
8.1 Complete Fine-tuning Pipeline
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import CosineAnnealingLR

def fine_tune_pipeline(data_dir, num_classes, num_epochs=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Data with augmentation
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder(f"{data_dir}/train", transform=train_transform)
    val_dataset = datasets.ImageFolder(f"{data_dir}/val", transform=val_transform)
    # drop_last avoids a size-1 final batch, which BatchNorm1d in the head rejects in train mode
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    # 2. Model with pre-trained weights
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    # Phase 1: freeze all but the last block + classifier
    for name, param in model.named_parameters():
        if not name.startswith(("layer4", "fc")):
            param.requires_grad = False
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    )
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    def train_one_epoch(optimizer):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        # With drop_last=True, this is the number of samples actually seen
        return running_loss / (len(train_loader) * train_loader.batch_size)

    def evaluate():
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        return 100.0 * correct / total

    best_val_acc = 0.0
    history = {"train_loss": [], "val_acc": []}

    def run_phase(optimizer, scheduler, epochs, tag):
        nonlocal best_val_acc
        for epoch in range(epochs):
            train_loss = train_one_epoch(optimizer)
            scheduler.step()
            val_acc = evaluate()
            history["train_loss"].append(train_loss)
            history["val_acc"].append(val_acc)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), "best_model.pth")
            print(f"[{tag}] Epoch {epoch+1:2d}/{epochs} | "
                  f"Loss: {train_loss:.4f} | Val Acc: {val_acc:.1f}% | "
                  f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    # 3. Phase 1: differential learning rates for the unfrozen block and the new head
    optimizer = torch.optim.AdamW([
        {"params": model.layer4.parameters(), "lr": 1e-5},
        {"params": model.fc.parameters(), "lr": 1e-3},
    ], weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-7)
    run_phase(optimizer, scheduler, num_epochs, "Phase 1")

    # 4. Phase 2: unfreeze all layers and fine-tune with a small uniform LR
    print("\n--- Phase 2: Unfreezing all layers ---")
    for param in model.parameters():
        param.requires_grad = True
    optimizer2 = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    scheduler2 = CosineAnnealingLR(optimizer2, T_max=num_epochs // 2, eta_min=1e-7)
    run_phase(optimizer2, scheduler2, num_epochs // 2, "Phase 2")

    print(f"\nBest Validation Accuracy: {best_val_acc:.1f}%")
    # Return the best checkpoint rather than the last epoch's weights
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    return model, history
```
8.2 Gradual Unfreezing with ULMFiT-style Training
```python
import torch
import torch.nn as nn

class GradualUnfreezer:
    """Unfreeze layer groups one at a time, from top to bottom.

    The new head (model.fc) is not in any group, so it stays trainable throughout.
    """
    def __init__(self, model, layer_groups=None):
        self.model = model
        self.layer_groups = layer_groups or [
            list(model.layer4.parameters()),
            list(model.layer3.parameters()),
            list(model.layer2.parameters()),
            list(model.layer1.parameters()),
            list(model.conv1.parameters()) + list(model.bn1.parameters()),
        ]
        # Freeze every group (the head is left trainable)
        for group in self.layer_groups:
            for p in group:
                p.requires_grad = False
        self.group_idx = 0

    def unfreeze_next(self):
        if self.group_idx >= len(self.layer_groups):
            print("All groups unfrozen.")
            return
        for p in self.layer_groups[self.group_idx]:
            p.requires_grad = True
        self.group_idx += 1
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f"Unfroze group {self.group_idx}/{len(self.layer_groups)} "
              f"| Trainable: {trainable:,}/{total:,}")

def evaluate_accuracy(model, loader, device):
    """Top-1 accuracy (%) in a single pass, without building an autograd graph."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total

def train_with_gradual_unfreezing(model, train_loader, val_loader, num_groups, epochs_per_group=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    unfreezer = GradualUnfreezer(model)
    criterion = nn.CrossEntropyLoss()
    for g in range(num_groups):
        unfreezer.unfreeze_next()
        lr = 1e-3 * (0.1 ** g)  # Decreasing LR as more layers unfreeze
        # Rebuild the optimizer so it covers the newly unfrozen parameters
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=lr, weight_decay=0.01,
        )
        for epoch in range(epochs_per_group):
            model.train()
            total_loss = 0.0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                total_loss += loss.item()
            acc = evaluate_accuracy(model, val_loader, device)
            print(f"  Group {g+1} Epoch {epoch+1} | "
                  f"Loss: {total_loss/len(train_loader):.4f} | "
                  f"Acc: {acc:.1f}%")
    return model
```
8.3 Model Export and Inference
```python
import torch

def export_and_inference(model, input_tensor, onnx_path="model.onnx"):
    model.eval()
    # ONNX export; dynamic_axes lets the exported graph accept any batch size
    torch.onnx.export(
        model, input_tensor, onnx_path,
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=17,
    )
    print(f"Exported to {onnx_path}")
    # Inference (works for any batch size, not only a single image)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)
        confidence, predicted = probabilities.max(dim=1)
    for cls, conf in zip(predicted.tolist(), confidence.tolist()):
        print(f"Predicted class: {cls} | Confidence: {conf:.4f}")
    return predicted, confidence
```
9. Fine-tuning Strategies Comparison
10. Common Pitfalls and Solutions
10.1 Catastrophic Forgetting
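Catastrophic forgetting occurs when fine-tuning overwrites knowledge learned on the source task. Solutions: lower learning rates, freezing early layers, or Elastic Weight Consolidation (EWC), which adds a quadratic penalty to the target loss:

$$\mathcal{L}(\theta) = \mathcal{L}_T(\theta) + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta_{S,i}^*\right)^2$$

where $F_i$ is the Fisher information (importance) of parameter $\theta_i$, $\theta_S^*$ are the optimal source parameters, and $\lambda$ sets the strength of the penalty.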
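The penalty translates directly into code once $F_i$ is estimated. The sketch below makes a coarse simplifying assumption. It approximates the diagonal Fisher by the squared mini-batch gradients of the loss on source data with the true labels, the so-called empirical Fisher. Per-sample gradients would give a more faithful estimate. `estimate_diagonal_fisher` and `ewc_penalty` are hypothetical helper names, and the toy linear model stands in for a network trained on the source task.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def estimate_diagonal_fisher(model, data_loader):
    """Approximate diagonal Fisher: mean squared mini-batch gradient of the source loss."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    model.eval()
    num_batches = 0
    for x, y in data_loader:
        model.zero_grad()
        F.cross_entropy(model(x), y).backward()
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        num_batches += 1
    model.zero_grad()
    return {n: f / max(num_batches, 1) for n, f in fisher.items()}

def ewc_penalty(model, fisher, source_params, lam):
    """(lam / 2) * sum_i F_i * (theta_i - theta*_S,i)^2 over the parameters tracked in `fisher`."""
    penalty = torch.zeros(())
    for n, p in model.named_parameters():
        if n in fisher:
            penalty = penalty + (fisher[n] * (p - source_params[n]) ** 2).sum()
    return 0.5 * lam * penalty

torch.manual_seed(0)
model = nn.Linear(4, 3)  # toy stand-in for a network trained on the source task
source_batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(5)]
fisher = estimate_diagonal_fisher(model, source_batches)
source_params = {n: p.detach().clone() for n, p in model.named_parameters()}

before = float(ewc_penalty(model, fisher, source_params, lam=100.0))
with torch.no_grad():
    model.weight.add_(0.1)  # simulate drift during target fine-tuning
after = float(ewc_penalty(model, fisher, source_params, lam=100.0))
print(before, after)  # before is 0.0; after is positive

# During target training: loss = target_loss + ewc_penalty(model, fisher, source_params, lam)
```

Parameters with large $F_i$ pay more for each unit of drift, so the optimizer is steered toward changing the weights the source task cared least about.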
10.2 Negative Transfer
When source knowledge hurts target performance. Solutions: validate transfer benefits empirically; use similarity metrics to select source tasks.
10.3 Overfitting on Small Datasets
Solutions: strong augmentation, dropout, weight decay, early stopping, label smoothing.
10.4 Batch Normalization Issues
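Frozen BN layers normalize with source-domain running statistics, which may not match the target data. Note also that `requires_grad = False` freezes only BN's learnable scale and shift: while the model is in `train()` mode, the running mean and variance keep updating from target batches. Decide explicitly which behavior you want. Keeping frozen BN layers in eval mode uses the stored running statistics consistently and is the safer choice with small batches. Letting BN re-estimate its statistics on target data can help when the domain shift is large. Architectures that normalize with LayerNorm instead of BN (such as ViT and ConvNeXt) avoid batch-statistics issues altogether.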
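The sketch below demonstrates the running-statistics trap and the eval-mode fix on a toy conv + BN block. It assumes standard `BatchNorm1d/2d/3d` layers; `SyncBatchNorm` would need to be added to the tuple. `set_bn_eval` is a hypothetical helper name. Because `model.train()` resets every submodule to training mode, call the helper again after each `model.train()`.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(model: nn.Module) -> None:
    """Switch every BatchNorm layer to eval mode so its running statistics stop updating."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
bn = model[1]
for p in bn.parameters():
    p.requires_grad = False  # freezes only the affine weight and bias

batch = torch.randn(4, 3, 16, 16) + 2.0  # stand-in for target-domain data

model.train()
before = bn.running_mean.clone()
model(batch)
stats_moved_in_train = not torch.equal(before, bn.running_mean)

model.train()
set_bn_eval(model)  # re-apply after every model.train() call
before = bn.running_mean.clone()
model(batch)
stats_moved_after_fix = not torch.equal(before, bn.running_mean)

print(stats_moved_in_train, stats_moved_after_fix)  # True False
```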
11. Key Takeaways
- Feature extraction (frozen backbone) is the fastest and safest approach for small datasets (on the order of 1,000 images or fewer)
- Fine-tuning with differential learning rates gives best results when data is sufficient; use smaller LR for pre-trained layers, larger LR for new layers
- Progressive unfreezing reduces the risk of catastrophic forgetting and stabilizes training by unfreezing layers one group at a time, from top to bottom
- Domain adaptation (MMD, adversarial training) helps when source and target domains differ
- EWC penalizes changes to important parameters using Fisher information, preserving source knowledge
- Data augmentation (Mixup, CutMix, RandAugment) is essential for small datasets
- Negative transfer occurs when source and target tasks are too dissimilar; always validate empirically
12. Practice Exercises
Exercise 1: Compare Strategies
```python
# Train the same model three ways on a small dataset (e.g., 500 images):
# 1. Feature extraction (all frozen except FC)
# 2. Partial fine-tuning (unfreeze last 2 blocks)
# 3. Full fine-tuning
# Compare accuracy, training time, and overfitting behavior.
```
Exercise 2: Gradual Unfreezing Schedule
```python
# Implement gradual unfreezing with different schedules:
# - Unfreeze one group every epoch
# - Unfreeze one group every 5 epochs
# - Unfreeze based on validation loss plateau
# Which schedule converges fastest? Which gives best accuracy?
```
Exercise 3: Domain Adaptation
```python
# Use MNIST as source and MNIST-M (MNIST digits blended with color photo patches) as target.
# Implement adversarial domain adaptation with gradient reversal.
# Compare accuracy with and without domain adaptation.
```
Exercise 4: Ablation Study
```python
# For a medical imaging task with 200 images per class:
# - Which pre-trained model? (ResNet vs EfficientNet vs ConvNeXt)
# - Which strategy? (Feature extraction vs partial fine-tuning)
# - Which augmentation? (Basic vs Mixup vs CutMix)
# - Which LR schedule? (Constant vs Cosine vs OneCycle)
# Report results in a table.
```