Transfer Learning

Module 3: Advanced ML + Deep Learning


💡 Transfer learning lets you leverage knowledge from one task to improve performance on a different but related task. It can dramatically reduce the data and compute needed to train deep learning models.


1. Why Transfer Learning Works

Deep neural networks learn hierarchical representations. Lower layers capture universal features (edges, textures, colors), while higher layers capture task-specific features (dog ears, car wheels). These universal features transfer across domains.

The Knowledge Transfer Spectrum

Task A (Source)                          Task B (Target)
┌──────────────┐                        ┌──────────────┐
│ Conv Layer 1 │ ──── edges, colors ──→ │ Conv Layer 1 │
│ Conv Layer 2 │ ──── textures ───────→ │ Conv Layer 2 │
│ Conv Layer 3 │ ──── parts ──────────→ │ Conv Layer 3 │
│ Conv Layer 4 │ ──── object parts ───→ │ Conv Layer 4 │
│ FC Layer     │   (ImageNet classes)   │ FC Layer     │ ← REPLACE
│ Output: 1000 │                        │ Output: 10   │
└──────────────┘                        └──────────────┘

Lower layers: GENERAL (transfer well)
Higher layers: SPECIFIC (need adaptation)

Formal Framework

Given:

  • Source task $\mathcal{T}_s$ with dataset $\mathcal{D}_s$
  • Target task $\mathcal{T}_t$ with dataset $\mathcal{D}_t$ (typically $|\mathcal{D}_t| \ll |\mathcal{D}_s|$)

Definition: Transfer Learning Objective

Transfer learning aims to improve the target predictor $f_t$ by leveraging knowledge from the source model $f_s$. The key insight is that features learned on large source datasets often transfer to related target tasks, even when the target dataset is small.

$$\min_{\theta_t} \mathcal{L}_t(f_t(x; \theta_t), y) \quad \text{subject to} \quad \theta_t \approx \theta_s$$
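In practice the constraint $\theta_t \approx \theta_s$ is usually enforced implicitly, by initializing $\theta_t$ from $\theta_s$ before training. It can also be made explicit as a penalty that pulls shared weights back toward their pre-trained values. The sketch below implements that penalty, which is the first term of the L2-SP regularizer from the literature; this lesson does not itself prescribe it. The helper `l2_sp_penalty`, the weight `alpha`, and the toy models are hypothetical, and the sketch assumes source and target parameters can be matched by name:

```python
import copy
import torch
import torch.nn as nn


def l2_sp_penalty(model: nn.Module, source_params: dict, alpha: float = 0.01) -> torch.Tensor:
    """Hypothetical helper: (alpha / 2) * sum ||theta_t - theta_s||^2 over shared parameters.

    Parameters whose name or shape differs from the source (e.g. a new classifier
    head) are skipped, because they have no source counterpart.
    """
    penalty = torch.zeros(())
    for name, param in model.named_parameters():
        src = source_params.get(name)
        if src is not None and src.shape == param.shape:
            penalty = penalty + (param - src).pow(2).sum()
    return 0.5 * alpha * penalty


# Toy example: a "source" model and a target model initialized from it (theta_t <- theta_s)
torch.manual_seed(0)
source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
target = copy.deepcopy(source)
target[2] = nn.Linear(8, 2)  # new task head with a different number of classes

# Snapshot of the source weights, detached so they act as fixed anchors
source_params = {n: p.detach().clone() for n, p in source.named_parameters()}

x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
task_loss = nn.functional.cross_entropy(target(x), y)
loss = task_loss + l2_sp_penalty(target, source_params, alpha=0.01)
loss.backward()
```

The new head is skipped by the shape check, which matches the diagram above, where only the FC layer is replaced.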

ℹ️ Why Transfer Learning Works

Deep neural networks learn hierarchical representations: lower layers capture universal features (edges, textures) that transfer across domains, while higher layers capture task-specific features. By freezing lower layers and training only the classifier head, we preserve these universal features while adapting to the new task. This is why transfer learning can work even when the source and target tasks differ, as long as they share this kind of low-level structure.


2. Transfer Learning Strategies

Strategy 1: Feature Extraction (Frozen Backbone)

Freeze all pre-trained layers and only train a new classifier head.

import torch
import torch.nn as nn
import torchvision.models as models

def create_feature_extractor(num_classes, model_name='resnet50'):
    # torchvision >= 0.13 uses `weights=`; the older `pretrained=True` flag is deprecated
    if model_name == 'resnet50':
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    elif model_name == 'vgg16':
        model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    else:
        raise ValueError(f"Unsupported model_name: {model_name}")

    # Freeze all pre-trained layers
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier; a freshly created layer is trainable by default.
    # (ResNet's head is model.fc, VGG's is model.classifier[6].)
    if model_name == 'resnet50':
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)

    return model

model = create_feature_extractor(num_classes=5)
print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
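Because the backbone is frozen during feature extraction, its outputs never change. A common speed-up is therefore to run the backbone once, cache the features, and train only the head on the cached tensors. Below is a minimal sketch, assuming a backbone that returns flat feature vectors; for the ResNet-50 above you could set `model.fc = nn.Identity()` before extraction. `extract_features` and `train_head` are hypothetical helper names, and the stand-in backbone and random data only make the example runnable. Note that random augmentation is lost once features are cached.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def extract_features(backbone: nn.Module, loader: DataLoader):
    """Run a frozen backbone once over the dataset and return (features, labels)."""
    backbone.eval()  # fixed BatchNorm statistics, no dropout
    feats, labels = [], []
    for x, y in loader:
        feats.append(backbone(x))
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)


def train_head(features: torch.Tensor, labels: torch.Tensor, num_classes: int,
               epochs: int = 50, lr: float = 1e-2) -> nn.Linear:
    """Train a linear classifier on cached features (full-batch, for brevity)."""
    head = nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.Adam(head.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(head(features), labels)
        loss.backward()
        optimizer.step()
    return head


# Toy usage with a stand-in backbone and random "images"
torch.manual_seed(0)
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU())
images = torch.randn(40, 3, 8, 8)
targets = torch.randint(0, 5, (40,))
loader = DataLoader(TensorDataset(images, targets), batch_size=8)

features, labels = extract_features(backbone, loader)
head = train_head(features, labels, num_classes=5)
```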

💡 When to Use Feature Extraction

Use when you have a small dataset (less than 1K images per class) and the domain is similar to the pre-trained data.

Strategy 2: Fine-Tuning (Unfreeze Top Layers)

Unfreeze some or all pre-trained layers and train with a small learning rate.

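def create_finetuner(num_classes, num_unfreeze=3):
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

    # Freeze all but the last `num_unfreeze` top-level children.
    # ResNet-50 has 10 children (conv1, bn1, relu, maxpool, layer1-4, avgpool, fc),
    # so num_unfreeze=3 leaves layer4, avgpool (no params) and fc trainable.
    children = list(model.children())
    for child in children[:-num_unfreeze]:
        for param in child.parameters():
            param.requires_grad = False

    # Replace classifier (the new layer is trainable by default)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    # Differential learning rates: small for unfrozen backbone layers, larger for the new head
    head_params = list(model.fc.parameters())
    head_ids = {id(p) for p in head_params}
    backbone_params = [p for p in model.parameters()
                       if p.requires_grad and id(p) not in head_ids]
    optimizer = torch.optim.Adam([
        {'params': backbone_params, 'lr': 1e-5},
        {'params': head_params, 'lr': 1e-3},
    ], weight_decay=1e-4)

    return model, optimizer

model, optimizer = create_finetuner(num_classes=5)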

Strategy 3: Progressive Unfreezing

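Gradually unfreeze layers during training: start by training only the classifier, then unfreeze the pre-trained layer groups one at a time, working from the top of the network (closest to the output) down toward the input.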

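class ProgressiveUnfreezer:
    def __init__(self, model):
        self.model = model
        # Top-level children that still contain frozen parameters, in forward order
        # (for ResNet-50: conv1, bn1, layer1, ..., layer4; parameter-free children are skipped)
        self.frozen = [(name, m) for name, m in model.named_children()
                       if any(not p.requires_grad for p in m.parameters())]

    def unfreeze_next(self, optimizer=None, lr=1e-5):
        """Unfreeze the frozen layer group closest to the output."""
        if not self.frozen:
            return None
        name, layer = self.frozen.pop()
        params = list(layer.parameters())
        for param in params:
            param.requires_grad = True
        # Parameters that were never registered with the optimizer would not be updated
        if optimizer is not None:
            optimizer.add_param_group({'params': params, 'lr': lr})
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Unfroze layer: {name} | Trainable: {trainable:,}")
        return name

# Usage: start from a model where only the new head is trainable
model = create_feature_extractor(num_classes=5)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
unfreezer = ProgressiveUnfreezer(model)
# After each validation plateau:
unfreezer.unfreeze_next(optimizer)  # layer4 → layer3 → layer2 → ...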

3. Domain Adaptation

When source and target domains differ, domain adaptation aligns their feature distributions.

Maximum Mean Discrepancy (MMD)

Minimize the distance between source and target feature distributions:


$$\text{MMD}^2(\mathcal{D}_s, \mathcal{D}_t) = \left\| \frac{1}{n_s}\sum_{i=1}^{n_s} \phi(x_i^s) - \frac{1}{n_t}\sum_{j=1}^{n_t} \phi(x_j^t) \right\|^2$$

Here,

  • $\mathcal{D}_s$ = Source domain dataset
  • $\mathcal{D}_t$ = Target domain dataset
  • $x_i^s, x_j^t$ = Individual source and target samples
  • $\phi(\cdot)$ = Feature mapping (kernel-induced)
  • $n_s, n_t$ = Number of source and target samples

ℹ️ MMD Intuition

MMD measures the distance between two distributions in a reproducing kernel Hilbert space (RKHS). With a characteristic kernel such as the Gaussian RBF kernel, MMD is zero if and only if the two distributions are identical; with a finite feature map $\phi$, a zero MMD only guarantees that the mean features match. By minimizing MMD during training, we push the feature extractor toward domain-invariant representations: features that look the same regardless of whether the input came from the source or target domain.
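The biased empirical MMD² above can be computed without forming $\phi$ explicitly, by expanding the squared norm into kernel evaluations (the kernel trick). Below is a minimal PyTorch sketch, assuming a Gaussian RBF kernel with a hand-picked bandwidth `sigma`. In practice the bandwidth matters a lot, and a median-distance heuristic or a mixture of bandwidths is common. `gaussian_kernel` and `mmd2` are illustrative names, not a library API.

```python
import torch


def gaussian_kernel(a: torch.Tensor, b: torch.Tensor, sigma: float) -> torch.Tensor:
    """Pairwise RBF kernel k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    # Squared distances via broadcasting: (n_a, 1, d) - (1, n_b, d) -> (n_a, n_b)
    sq_dists = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-sq_dists / (2 * sigma ** 2))


def mmd2(source: torch.Tensor, target: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased empirical MMD^2 between feature batches of shape (n_s, d) and (n_t, d).

    Expanding the squared norm of the difference of mean embeddings gives
    mean k(s, s') + mean k(t, t') - 2 * mean k(s, t).
    """
    k_ss = gaussian_kernel(source, source, sigma).mean()
    k_tt = gaussian_kernel(target, target, sigma).mean()
    k_st = gaussian_kernel(source, target, sigma).mean()
    return k_ss + k_tt - 2 * k_st


torch.manual_seed(0)
src = torch.randn(200, 16)
tgt_same = torch.randn(200, 16)           # same distribution as src
tgt_shifted = torch.randn(200, 16) + 1.0  # mean-shifted "target domain"

# sigma = 4.0 (about sqrt(d)) keeps kernel values away from 0 and 1 for 16-dim standard normals
print(f"MMD^2 same dist:    {mmd2(src, tgt_same, sigma=4.0).item():.4f}")
print(f"MMD^2 shifted dist: {mmd2(src, tgt_shifted, sigma=4.0).item():.4f}")
```

During domain adaptation, the MMD between a batch of source features and a batch of target features would be added to the task loss with a weighting factor. Its gradient then pushes the two feature distributions together.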

Adversarial Domain Adaptation

Train a domain discriminator to distinguish source vs. target features, while the feature extractor learns to fool it:

Adversarial Domain Adaptation Objective

$$\min_G \max_D \; \mathcal{L}_{\text{task}}(G) - \lambda \mathcal{L}_{\text{domain}}(D, G)$$

Here,

  • $G$ = Feature generator (shared backbone)
  • $D$ = Domain discriminator
  • $\mathcal{L}_{\text{task}}$ = Task loss (classification, etc.)
  • $\mathcal{L}_{\text{domain}}$ = Domain classification loss
  • $\lambda$ = Trade-off parameter
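
import torch
import torch.nn as nn

class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by -lambda in the backward pass."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing into the feature extractor;
        # lambda_ itself needs no gradient
        return -ctx.lambda_ * grad_output, None

class DomainAdversarialNetwork(nn.Module):
    def __init__(self, feature_extractor, num_classes, feature_dim=2048, hidden_dim=256):
        super().__init__()
        # feature_extractor must output flat vectors of size feature_dim
        # (2048 for a ResNet-50 whose fc has been replaced by nn.Identity())
        self.feature_extractor = feature_extractor
        self.task_classifier = nn.Linear(feature_dim, num_classes)
        self.domain_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 2),  # source vs target
        )

    def forward(self, x, alpha=1.0):
        features = self.feature_extractor(x)
        task_output = self.task_classifier(features)
        # Gradient reversal: the domain classifier learns to tell the domains apart,
        # while the reversed gradient (scaled by alpha, the lambda above) pushes
        # the feature extractor to make them indistinguishable
        reversed_features = GradientReversalFunction.apply(features, alpha)
        domain_output = self.domain_classifier(reversed_features)
        return task_output, domain_output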
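The following sketch shows how the pieces fit together in one training step. It assumes labeled source batches and unlabeled target batches, and the tiny MLP backbone, random data, and helper `dann_step` are hypothetical stand-ins. The λ ramp-up $2/(1+e^{-10p}) - 1$, where $p \in [0, 1]$ is training progress, is the schedule proposed in the original DANN paper (Ganin and Lempitsky); this lesson does not specify it. The gradient reversal function is repeated so the block runs on its own.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity forward; gradient multiplied by -lambda on the way back."""
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambda_ * grad_output, None


def grl_lambda(progress: float, gamma: float = 10.0) -> float:
    """DANN ramp-up: 0 at the start of training, approaching 1 at the end."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def dann_step(feature_net, task_head, domain_head, optimizer, x_src, y_src, x_tgt, progress):
    """One update: task loss on labeled source data + domain loss on both domains."""
    lambda_ = grl_lambda(progress)
    feats_src, feats_tgt = feature_net(x_src), feature_net(x_tgt)
    task_loss = F.cross_entropy(task_head(feats_src), y_src)

    feats = torch.cat([feats_src, feats_tgt])
    domain_labels = torch.cat([torch.zeros(len(x_src)), torch.ones(len(x_tgt))]).long()
    domain_loss = F.cross_entropy(domain_head(GradReverse.apply(feats, lambda_)), domain_labels)

    # The GRL already flips the gradient reaching feature_net, so the losses are simply added
    loss = task_loss + domain_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return task_loss.item(), domain_loss.item()


torch.manual_seed(0)
feature_net = nn.Sequential(nn.Linear(10, 32), nn.ReLU())
task_head, domain_head = nn.Linear(32, 3), nn.Linear(32, 2)
params = [*feature_net.parameters(), *task_head.parameters(), *domain_head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.01)

x_src, y_src = torch.randn(16, 10), torch.randint(0, 3, (16,))
x_tgt = torch.randn(16, 10) + 0.5  # shifted, unlabeled target batch
for step in range(100):
    dann_step(feature_net, task_head, domain_head, optimizer,
              x_src, y_src, x_tgt, progress=step / 99)
```

Starting λ near zero lets the task classifier learn useful features before the adversarial signal becomes strong.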

4. Pre-trained Model Zoo

Image Models

Model            Parameters   Top-1 Acc (ImageNet)   Best For
─────────────────────────────────────────────────────────────────────
ResNet-50        25.6M        76.1%                  General purpose, fast
EfficientNet-B0  5.3M         77.1%                  Efficiency-focused
EfficientNet-B7  66M          84.3%                  Maximum accuracy
ViT-B/16         86M          77.9%                  Vision Transformer
ConvNeXt-B       89M          83.8%                  Modern CNN

NLP Models

Model            Parameters   Task                  Best For
─────────────────────────────────────────────────────────────────────
BERT-base        110M         Text classification   General NLP
RoBERTa-large    355M         Various               High accuracy
GPT-2 (XL)       1.5B         Text generation       Language modeling
DistilBERT       66M          Distillation          Efficiency
sentence-BERT    110M         Sentence embeddings   Semantic similarity

Using Pre-trained Models

import torch
from transformers import AutoModel, AutoTokenizer

# Load pre-trained BERT
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')

# Tokenize input
text = "Transfer learning is powerful!"
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

# Get embeddings
with torch.no_grad():
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0, :]  # [CLS] token
    print(f"Embedding shape: {embeddings.shape}")  # [1, 768]
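Taking the [CLS] vector is one option. Sentence-BERT-style models instead average the token embeddings, using the attention mask so that padding tokens are excluded. Below is a minimal sketch that works on the `outputs.last_hidden_state` and `inputs['attention_mask']` tensors from the snippet above. The helper name `mean_pool` is ours, and the example uses synthetic tensors so it runs without downloading a model.

```python
import torch


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, ignoring padding positions.

    last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                          # avoid division by zero
    return summed / counts


# Synthetic check: the padded position (mask 0) must not affect the result
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # (1, 3, 2)
mask = torch.tensor([[1, 1, 0]])
print(mean_pool(hidden, mask))  # tensor([[2., 3.]])
```

With the BERT example, `mean_pool(outputs.last_hidden_state, inputs['attention_mask'])` returns a `[1, 768]` sentence embedding.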

5. Practical Workflow

When to Use Which Strategy

Dataset Size (per class)  Domain Similarity    Strategy
───────────────────────────────────────────────────────────────────
Small (<1K)               Very Similar         Feature Extraction
Small (<1K)               Different            Feature Extraction + DA
Medium (1K-10K)           Similar              Fine-tune Top Layers
Medium (1K-10K)           Different            Progressive Unfreezing + DA
Large (10K-100K)          Any                  Full Fine-tuning
Very Large (>100K)        Any                  Train from Scratch (maybe)

DA = domain adaptation (see Section 3)

Complete Fine-tuning Pipeline

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import OneCycleLR

def fine_tune_pipeline(num_epochs=20):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Data: augment the training set only; validate on deterministic center crops
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder('data/train', transform=train_transform)
    val_dataset = datasets.ImageFolder('data/val', transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    # 2. Model
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

    # Freeze everything except layer4 and the classifier
    for name, param in model.named_parameters():
        if 'layer4' not in name and 'fc' not in name:
            param.requires_grad = False

    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(train_dataset.classes)),
    )
    model = model.to(device)

    # 3. Optimizer with differential learning rates.
    # OneCycleLR overrides each group's 'lr'; the per-group peaks come from max_lr.
    optimizer = torch.optim.AdamW([
        {'params': model.layer4.parameters(), 'lr': 1e-5},
        {'params': model.fc.parameters(), 'lr': 1e-3},
    ], weight_decay=0.01)

    scheduler = OneCycleLR(optimizer, max_lr=[1e-4, 1e-2],
                           steps_per_epoch=len(train_loader), epochs=num_epochs)
    criterion = nn.CrossEntropyLoss()

    # 4. Training
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # OneCycleLR is stepped once per batch

        # Validate
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_acc = 100. * correct / total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

        print(f"Epoch {epoch+1:2d} | Val Acc: {val_acc:.1f}%")

    print(f"Best Val Acc: {best_val_acc:.1f}%")

# fine_tune_pipeline()

Negative Transfer

Negative transfer occurs when transfer learning hurts target performance instead of helping it. This happens when:

  1. Source and target domains are too dissimilar
  2. The source model has memorized source-specific features that don't generalize
  3. The fine-tuning learning rate is too high, destroying useful source knowledge

One common theoretical view is that positive transfer requires the source and target tasks to share a common low-dimensional subspace in the feature space. If no such shared structure exists, transfer learning can be harmful.

πŸ“Differential Learning Rates in Practice

Scenario: Fine-tuning ResNet-50 (pre-trained on ImageNet) for medical imaging.

Strategy: Use different learning rates for different layers:

  • Conv layers 1-3 (universal features): $\eta = 10^{-5}$ (nearly frozen)
  • Conv layer 4 (higher-level features): $\eta = 10^{-4}$ (slow adaptation)
  • FC layer (new classifier): $\eta = 10^{-3}$ (fast learning)

Why this works: Lower layers detect edges and textures that transfer well to medical images. Higher layers detect object parts that need mild adaptation. The new classifier learns task-specific decision boundaries. This hierarchical approach reduces the risk of catastrophic forgetting while still allowing adaptation.
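The three learning-rate tiers above can be expressed as optimizer parameter groups. Below is a minimal sketch, assuming torchvision's ResNet-50 naming (`conv1`, `bn1`, `layer1` to `layer4`, `fc`). Reading "Conv layers 1-3" as `layer1` to `layer3`, and placing the `conv1`/`bn1` stem in the lowest tier via `default_lr`, is our interpretation. `param_groups_by_prefix` and `TinyResNetLike` are hypothetical names; the latter is a small stand-in so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


def param_groups_by_prefix(model: nn.Module, lr_by_prefix: dict, default_lr=None):
    """Build optimizer parameter groups from top-level module-name prefixes.

    Each trainable parameter goes to the first matching prefix; unmatched parameters
    get default_lr, or are left out of the optimizer if default_lr is None.
    """
    groups = {prefix: [] for prefix in lr_by_prefix}
    leftovers = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        for prefix in lr_by_prefix:
            if name == prefix or name.startswith(prefix + '.'):
                groups[prefix].append(param)
                break
        else:
            leftovers.append(param)
    param_groups = [{'params': ps, 'lr': lr_by_prefix[p]} for p, ps in groups.items() if ps]
    if default_lr is not None and leftovers:
        param_groups.append({'params': leftovers, 'lr': default_lr})
    return param_groups


class TinyResNetLike(nn.Module):
    # Hypothetical stand-in that only mimics ResNet-50's top-level attribute names;
    # in practice use torchvision.models.resnet50(weights=...) with a new fc
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Linear(4, 4)
        self.layer1, self.layer2 = nn.Linear(4, 4), nn.Linear(4, 4)
        self.layer3, self.layer4 = nn.Linear(4, 4), nn.Linear(4, 4)
        self.fc = nn.Linear(4, 2)


lr_by_prefix = {
    'layer1': 1e-5, 'layer2': 1e-5, 'layer3': 1e-5,  # universal features: nearly frozen
    'layer4': 1e-4,                                  # higher-level features: slow adaptation
    'fc': 1e-3,                                      # new classifier: fast learning
}
model = TinyResNetLike()
optimizer = torch.optim.AdamW(param_groups_by_prefix(model, lr_by_prefix, default_lr=1e-5),
                              weight_decay=0.01)
for group in optimizer.param_groups:
    print(group['lr'], sum(p.numel() for p in group['params']))
```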

6. Common Pitfalls

Catastrophic Forgetting

When fine-tuning erases source knowledge, the model loses its ability to perform well on the original task.

Solution: Use lower learning rates, freeze early layers, employ EWC (Elastic Weight Consolidation).

Elastic Weight Consolidation (EWC)

$$\mathcal{L}_{\text{EWC}} = \mathcal{L}_{\text{task}}(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2$$

Here,

  • $\theta$ = Current model parameters
  • $\theta^*$ = Optimal parameters from source task
  • $F_i$ = Fisher information (importance) of parameter $\theta_i$
  • $\lambda$ = Regularization strength

ℹ️ EWC Intuition

EWC prevents catastrophic forgetting by penalizing changes to parameters that are important for the source task. The Fisher information $F_i$ measures how sensitive the source task loss is to changes in parameter $\theta_i$: parameters with high Fisher information are "important" and should not change much during fine-tuning. This is a form of parameter-space regularization that preserves source knowledge while allowing adaptation.
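Below is a minimal sketch of EWC in PyTorch under stated assumptions. The diagonal of $F$ is approximated by the empirical Fisher, meaning squared gradients of the source-task loss computed with the true labels, and it is estimated from a few source batches. A `batch_size=1` loader gives a per-sample estimate; squaring whole-batch gradients is cheaper but coarser. The helpers `diagonal_fisher` and `ewc_penalty`, the toy model, and the data are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def diagonal_fisher(model: nn.Module, batches, max_batches: int = 10) -> dict:
    """Empirical diagonal Fisher: mean squared gradient of the source-task loss."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    model.eval()
    count = 0
    for x, y in batches:
        if count == max_batches:
            break
        model.zero_grad()
        F.cross_entropy(model(x), y).backward()
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        count += 1
    return {n: f / max(count, 1) for n, f in fisher.items()}


def ewc_penalty(model: nn.Module, fisher: dict, source_params: dict, lam: float) -> torch.Tensor:
    """(lam / 2) * sum_i F_i * (theta_i - theta_i*)^2 over parameters shared with the source."""
    penalty = torch.zeros(())
    for n, p in model.named_parameters():
        # Shape check skips parameters that differ from the source, e.g. a replaced head
        if n in fisher and fisher[n].shape == p.shape:
            penalty = penalty + (fisher[n] * (p - source_params[n]) ** 2).sum()
    return 0.5 * lam * penalty


# Toy usage: estimate Fisher on "source" batches, then regularize target fine-tuning
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
source_batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(5)]
fisher = diagonal_fisher(model, source_batches)
source_params = {n: p.detach().clone() for n, p in model.named_parameters()}  # theta*

model.train()
model.zero_grad()  # clear gradients left over from the Fisher estimate
x_t, y_t = torch.randn(8, 4), torch.randint(0, 3, (8,))
loss = F.cross_entropy(model(x_t), y_t) + ewc_penalty(model, fisher, source_params, lam=100.0)
loss.backward()
```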

Negative Transfer

When source knowledge hurts target performance.

Solution: Use similarity metrics to select source tasks; train with and without transfer and compare.


7. Key Takeaways

📋 Summary: Transfer Learning

  • Feature extraction (frozen backbone) is the fastest and safest approach for small datasets (<1K images per class)
  • Fine-tuning with differential learning rates often gives the best results when data is sufficient; use a smaller LR for pre-trained layers and a larger LR for new layers
  • Progressive unfreezing helps reduce catastrophic forgetting and stabilizes training by gradually unfreezing layers from top to bottom
  • Domain adaptation (MMD, adversarial training) helps when source and target domains differ; MMD minimizes distributional distance in feature space
  • EWC penalizes changes to important parameters using Fisher information, preserving source knowledge during fine-tuning
  • Negative transfer occurs when source and target tasks are too dissimilar; always validate transfer benefits empirically
  • Catastrophic forgetting is mitigated by lower learning rates, early layer freezing, and regularization techniques

8. Practice Exercises

Exercise 1: Compare Strategies

# Train the same model three ways on a small dataset (e.g., 500 images):
# 1. Feature extraction (all frozen except FC)
# 2. Fine-tune last 2 blocks
# 3. Full fine-tuning
# Compare accuracy and training time

Exercise 2: Domain Adaptation

# Use MNIST as source, MNIST-M (colored MNIST) as target
# Implement a simple adversarial domain adaptation network
# Compare with and without domain adaptation

Exercise 3: Model Selection

# For a medical imaging task with 200 images per class:
# - Which pre-trained model would you choose? Why?
# - Which fine-tuning strategy? Justify.
# - Implement and evaluate
