
Transfer Learning: Fine-tuning Pre-trained Models

Module 13: Computer Vision


Transfer learning is one of the most powerful paradigms in modern deep learning. Instead of training networks from scratch — a process demanding millions of labeled images and weeks of GPU time — we repurpose knowledge encoded in models trained on massive datasets. This lecture covers the theory, mechanics, and practical strategies for applying transfer learning to real-world problems.


1. What is Transfer Learning?

Transfer learning is the practice of taking a model trained on one task (the source task) and adapting it to a different but related task (the target task).

1.1 Why Transfer Learning Works

Deep networks learn hierarchical representations. Early layers capture universal visual primitives — edges, textures, color gradients. Middle layers compose these into motifs: corners, contours, repeated patterns. Deeper layers encode task-specific semantics — object parts, faces, scenes.

This hierarchy exhibits a key property: lower layers are more general, higher layers are more specific. The same Gabor-like edge detectors useful for ImageNet classification are equally useful for medical image segmentation or satellite imagery analysis.

Formally, consider a source model trained on dataset $\mathcal{D}_S = \{(x_i^S, y_i^S)\}$ and a target task with dataset $\mathcal{D}_T = \{(x_i^T, y_i^T)\}$. Transfer learning leverages the fact that the feature representation $\phi_S(x)$ learned by the source model shares structure with an optimal representation $\phi_T^*(x)$ for the target task.

The key insight: natural images share statistical structure. A model that has learned to recognize 1,000 ImageNet classes has implicitly learned useful features for many other visual tasks.

1.2 The Taxonomy of Transfer

| Transfer Type | Source → Target | Example |
|---|---|---|
| Inductive | Same domain, different tasks | ImageNet → X-ray classification |
| Unsupervised | Same domain, no target labels | ImageNet → feature extraction for clustering |
| Domain Adaptation | Different domains, same task | Synthetic → real images for segmentation |
| Multi-task | Shared features, multiple tasks | Single backbone for detection + depth |

1.3 When Does Transfer Help?

The benefit of transfer depends on two factors:

$$\text{Transfer Gain} \propto \underbrace{\text{Task Similarity}(\mathcal{T}_S, \mathcal{T}_T)}_{\text{how related are the tasks?}} \times \underbrace{\frac{|\mathcal{D}_S|}{|\mathcal{D}_T|}}_{\text{data scarcity ratio}}$$

  • High similarity + small target dataset: Maximum benefit (e.g., ImageNet → CIFAR-10)
  • Low similarity + large target dataset: Transfer may hurt (negative transfer)
  • Any similarity + tiny dataset: Transfer is almost always beneficial

2. Pre-trained Models

2.1 ImageNet and the ILSVRC Benchmark

The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) provided 1.28 million training images across 1,000 categories. This dataset catalyzed the deep learning revolution and remains the standard source for pre-trained visual features.

2.2 Architecture Evolution

Figure: Pre-trained architecture evolution

| Model | Year | Depth / design | Parameters | Top-5 error (ImageNet) |
|---|---|---|---|---|
| AlexNet | 2012 | 8 layers | 61M | 16.4% |
| VGG-16 | 2014 | 16 layers | 138M | 7.3% |
| ResNet-50 | 2015 | 50 layers | 25M | 3.6% (figure reported for the ILSVRC 2015 ResNet ensemble) |
| EfficientNet-B7 | 2019 | Compound scaling | 66M | about 3.0% |

Key architectural innovations:

  • Residual connections (ResNet): skip connections, y = F(x) + x; enable training of very deep networks
  • Inverted bottlenecks (EfficientNet): compound scaling of depth x width x resolution; NAS-optimized architecture
  • Batch normalization: normalizes layer inputs for stable training; reduces internal covariate shift
  • SE attention (EfficientNet): Squeeze-and-Excitation channel attention; adaptive feature recalibration

2.3 Choosing a Pre-trained Model

| Model | Parameters | Top-1 (ImageNet) | Best For |
|---|---|---|---|
| ResNet-18 | 11.7M | 69.8% | Quick prototyping, edge deployment |
| ResNet-50 | 25.6M | 76.1% | Good balance of speed/accuracy |
| EfficientNet-B0 | 5.3M | 77.1% | Mobile/embedded |
| EfficientNet-B4 | 19.3M | 82.9% | General-purpose |
| EfficientNet-B7 | 66.3M | 84.3% | Maximum accuracy |
| ConvNeXt-B | 89M | 83.8% | Transformer-like performance |
| ViT-B/16 | 86M | 77.9% | Vision Transformer baseline |

3. Feature Extraction vs. Fine-tuning

Figure: Feature extraction vs. fine-tuning

Feature extraction (pre-trained backbone frozen):

  • Conv layers 1-3: edges, textures
  • Conv layers 4-6: patterns, parts
  • Conv layers 7-8: high-level features
  • FC (ImageNet classes): removed
  • All layers frozen (no gradient update); new classifier head is trainable
  • Fast training | Low data | Simple

Fine-tuning (pre-trained backbone partially trainable):

  • Conv layers 1-3: frozen
  • Conv layers 4-6: unfrozen
  • Conv layers 7-8: unfrozen
  • FC (ImageNet classes): removed
  • Upper layers updated with target data; new classifier head is trainable
  • Higher accuracy | Moderate data | More compute

3.1 Feature Extraction

In feature extraction mode, the pre-trained backbone is treated as a fixed feature extractor. Only the newly added classification head is trained.

import torch
import torch.nn as nn
import torchvision.models as models

# Load pre-trained ResNet-50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace classifier head
num_classes = 10
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)

# Only model.fc parameters are trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
# Trainable: 527,114 / 24,035,146 (2.2%)
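
The size of the new head can be checked by hand: a fully connected layer with n_in inputs and n_out outputs has n_in * n_out weights plus n_out biases. A minimal sketch of that arithmetic follows; the helper name `linear_params` is ours for illustration and is not part of PyTorch.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Parameters in a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out


# Head from the snippet above: Linear(2048, 256) -> ReLU -> Dropout -> Linear(256, 10).
# ReLU and Dropout have no learnable parameters.
head_params = linear_params(2048, 256) + linear_params(256, 10)
print(f"{head_params:,}")  # 527,114
```

Because every backbone parameter is frozen, this is also the number of trainable parameters in the model.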

3.2 When to Use Each Approach

| Criterion | Feature Extraction | Fine-tuning |
|---|---|---|
| Training data | Less than 1,000 images | More than 1,000 images |
| Domain similarity | High (e.g., natural images) | Low (e.g., medical, satellite) |
| Compute budget | Low | High |
| Accuracy need | Moderate | High |
| Training time | Minutes to hours | Hours to days |
| Risk of overfitting | Low | Moderate to high |

4. Fine-tuning Strategies

Figure: Fine-tuning strategies (layer freezing)

Full fine-tuning:

  • All layers and the classifier are trainable
  • Maximum flexibility
  • Risk of catastrophic forgetting
  • Requires a large dataset

Partial fine-tuning:

  • Early and middle layers frozen; late layers and classifier trainable
  • Prevents overfitting
  • Good for small to medium datasets
  • Must choose a cutoff point

Gradual unfreezing:

  • Stage 1: classifier only; Stage 2: + Layer 4; Stage 3: + Layer 3; Stage 4: + Layer 2
  • Stable training
  • Best for very small datasets
  • Slower convergence

Decision guide:

  • Dataset size less than 500 images: feature extraction (frozen backbone)
  • Dataset size 500-5000 images: partial fine-tuning or gradual unfreezing
  • Dataset size 5000+ images: full fine-tuning with differential learning rates
  • Domain mismatch between source and target: gradual unfreezing + domain adaptation
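
The decision guide can be written down as a small lookup function. This is only a sketch of the rules of thumb listed above: the function name `choose_strategy` is hypothetical, and giving domain mismatch priority over dataset size is our assumption, since the guide does not say how the two criteria interact.

```python
def choose_strategy(num_images: int, domain_mismatch: bool = False) -> str:
    """Map dataset size and domain mismatch to a starting strategy.

    The thresholds are rules of thumb, not hard limits.
    """
    if domain_mismatch:
        # Assumption: a domain gap overrides the size-based choice
        return "gradual unfreezing + domain adaptation"
    if num_images < 500:
        return "feature extraction"
    if num_images < 5000:
        return "partial fine-tuning or gradual unfreezing"
    return "full fine-tuning with differential learning rates"


print(choose_strategy(300))
print(choose_strategy(2000))
print(choose_strategy(20000))
print(choose_strategy(2000, domain_mismatch=True))
```

Treat the result as a starting point and confirm it on a validation set.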

4.1 Full Fine-tuning

Unfreeze all layers and train the entire network. Use this when you have sufficient data and compute.

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# All parameters trainable with uniform learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

4.2 Partial Fine-tuning

Freeze early layers, fine-tune later layers. This is the most common practical approach.

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze everything except layer4 and fc
for name, param in model.named_parameters():
    if "layer4" not in name and "fc" not in name:
        param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, num_classes)

# Only unfrozen parameters
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

4.3 Gradual Unfreezing

Start by training only the classifier, then progressively unfreeze earlier layers, working from the top of the network down toward the input. Popularized by the ULMFiT paper for NLP.

class GradualUnfreezer:
    def __init__(self, model):
        self.model = model
        # Layer groups ordered top to bottom: the classifier head is unfrozen first
        self.layer_groups = [
            ("fc", model.fc),
            ("layer4", model.layer4),
            ("layer3", model.layer3),
            ("layer2", model.layer2),
            ("layer1", model.layer1),
        ]
        # Start with every parameter frozen, including the stem (conv1, bn1)
        for param in model.parameters():
            param.requires_grad = False
        self.next_group = 0

    def unfreeze_next_group(self):
        """Unfreeze the next layer group (top to bottom)."""
        if self.next_group >= len(self.layer_groups):
            print("All layer groups already unfrozen.")
            return
        name, group = self.layer_groups[self.next_group]
        for param in group.parameters():
            param.requires_grad = True
        self.next_group += 1
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Unfroze: {name} | Trainable params: {trainable:,}")

# Usage: unfreeze one group per epoch.
# If the optimizer was built from trainable parameters only, rebuild it
# (or add a parameter group) after each unfreeze.
freezer = GradualUnfreezer(model)
freezer.unfreeze_next_group()  # Epoch 1: fc only
freezer.unfreeze_next_group()  # Epoch 2: fc + layer4
freezer.unfreeze_next_group()  # Epoch 3: + layer3
# ...

4.4 Layer Freezing Visualization

The update rules below summarize how frozen and unfrozen layers are treated during backpropagation:

$$\text{Frozen layers: } \frac{\partial \mathcal{L}}{\partial \theta_i} = 0 \quad \text{(no gradient update)}$$

$$\text{Unfrozen layers: } \theta_i \leftarrow \theta_i - \eta \frac{\partial \mathcal{L}}{\partial \theta_i} \quad \text{(standard update)}$$
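
A quick way to see the frozen-layer rule in practice is to run one backward pass through a toy network. The sketch below uses two small linear layers as stand-ins for a backbone and a head (they are not the ResNet used elsewhere in this lesson). Note that PyTorch does not store a zero gradient for frozen parameters; it skips computing the gradient, so `.grad` stays `None`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins: the first layer plays the frozen backbone, the second the trainable head
backbone = nn.Linear(4, 3)
head = nn.Linear(3, 2)
for p in backbone.parameters():
    p.requires_grad = False

x = torch.randn(8, 4)
loss = head(torch.relu(backbone(x))).sum()
loss.backward()

frozen_grads = [p.grad for p in backbone.parameters()]
head_grads = [p.grad for p in head.parameters()]
print(frozen_grads)                              # [None, None]
print(all(g is not None for g in head_grads))    # True
```

Since the optimizer skips parameters whose gradient is `None`, frozen weights are left untouched by `optimizer.step()`.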

5. Learning Rate Differentiation

5.1 Differential Learning Rates

Lower layers encode universal features that already generalize well, so they need only slight adjustment. Higher layers encode more task-specific features that need more substantial adaptation. Using a single learning rate for all layers is suboptimal.

$$\eta_l = \eta_{\text{base}} \cdot \alpha^{L - l}$$

where $\eta_l$ is the learning rate for layer $l$, $L$ is the total number of layers, and $\alpha \in (0, 1)$ is a decay factor (typically 0.1 to 0.5).

5.2 Practical Implementation

# ResNet-50: 6 parameter groups, learning rate increasing from stem to head
stem_params = list(model.conv1.parameters()) + list(model.bn1.parameters())
param_groups = [
    {"params": stem_params,                "lr": 1e-6},   # Lowest: universal features
    {"params": model.layer1.parameters(),  "lr": 1e-6},
    {"params": model.layer2.parameters(),  "lr": 1e-5},
    {"params": model.layer3.parameters(),  "lr": 1e-5},
    {"params": model.layer4.parameters(),  "lr": 1e-4},   # Higher: task-specific
    {"params": model.fc.parameters(),      "lr": 1e-3},   # Highest: new classifier
]

optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)

# Or compute the rates from the decay formula in Section 5.1
def get_layer_lr(base_lr, layer_idx, total_layers, decay=0.1):
    """Decaying learning rate: earlier layers get smaller LR.

    layer_idx is 0-indexed from the bottom, so the top layer
    (layer_idx = total_layers - 1) receives base_lr.
    """
    return base_lr * decay ** (total_layers - layer_idx - 1)
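
The formula from Section 5.1 can also be used to build the optimizer's parameter groups automatically. The sketch below is one possible way to do it: `build_param_groups` is a hypothetical helper, and the three small linear layers are stand-ins for real backbone stages. Groups are numbered from 1 (bottom) to L (top), so the top group receives the base learning rate.

```python
import torch
import torch.nn as nn


def build_param_groups(layer_groups, base_lr=1e-3, decay=0.1):
    """Assign lr = base_lr * decay**(L - l) to group l (1-indexed, bottom to top)."""
    total = len(layer_groups)
    return [
        {"params": list(group.parameters()), "lr": base_lr * decay ** (total - l)}
        for l, group in enumerate(layer_groups, start=1)
    ]


# Hypothetical three-stage model standing in for backbone stages plus a head
stages = [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 2)]
groups = build_param_groups(stages, base_lr=1e-3, decay=0.1)
optimizer = torch.optim.AdamW(groups, weight_decay=0.01)

for i, g in enumerate(optimizer.param_groups, start=1):
    print(f"group {i}: lr = {g['lr']:.0e}")
```

For a ResNet, the same helper could be called with a list such as the stem, `layer1` through `layer4`, and `fc`, provided each entry is a module exposing `parameters()`.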

5.3 Learning Rate Scheduling

Combined with differential LRs, schedulers adapt rates during training:

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-7)

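# Training loop. train_one_epoch and evaluate are user-defined helpers;
# num_epochs, warmup_epochs, the data loaders, and the `freezer` object
# from Section 4.3 are assumed to exist.
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, scheduler)
    val_acc = evaluate(model, val_loader)

    # Unfreeze the next layer group once the warm-up phase ends
    # (a plateau in val_acc is another possible trigger)
    if epoch == warmup_epochs:
        freezer.unfreeze_next_group()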
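
To see what the schedule does to a learning rate, the cosine annealing rule with warm restarts can be evaluated by hand: within a cycle of length T_i, the rate is eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * T_cur / T_i)), and each new cycle is T_mult times longer than the last. The function below is a plain-Python sketch of that rule for integer epochs, using the same T_0, T_mult, and eta_min as the snippet above. The name `cosine_warm_restart_lr` is ours, and the sketch assumes the scheduler is stepped once per epoch.

```python
import math


def cosine_warm_restart_lr(epoch, base_lr, t_0=10, t_mult=2, eta_min=1e-7):
    """Learning rate at an integer epoch under cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    # Walk forward through the cycles until the epoch falls inside the current one
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))


for epoch in (0, 5, 9, 10, 20, 30):
    print(epoch, f"{cosine_warm_restart_lr(epoch, base_lr=1e-3):.2e}")
```

With differential learning rates, each parameter group follows this curve starting from its own base rate, so the ratio between groups stays roughly constant while all of them decay and restart together.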

6. Domain Adaptation

When source and target domains differ (e.g., synthetic vs. real images, daytime vs. nighttime), domain adaptation aligns feature distributions.

6.1 Transfer Learning Concept Diagram

Figure: Transfer learning from source to target domain

  • Source domain: large labeled dataset (ImageNet, 1.28M images, 1,000 classes); pre-trained model with rich feature representations
  • Shared feature space: edges and textures, patterns and shapes, semantic parts, object-level features; domain-invariant features
  • Target domain: small labeled dataset (medical images, satellite imagery, industrial inspection); task-specific adaptation

6.2 Maximum Mean Discrepancy (MMD)

Minimize the distance between source and target feature distributions:

$$\text{MMD}^2(\mathcal{D}_s, \mathcal{D}_t) = \left\| \frac{1}{n_s}\sum_{i=1}^{n_s} \phi(x_i^s) - \frac{1}{n_t}\sum_{j=1}^{n_t} \phi(x_j^t) \right\|^2$$

where $\phi(\cdot)$ is a kernel-induced feature mapping, and $n_s, n_t$ are the numbers of source and target samples.

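Intuition: If MMD is zero, the mean embeddings of the two distributions coincide in the feature space; for a characteristic kernel such as the Gaussian kernel, this means the distributions themselves are identical. By minimizing MMD during training, we push the feature extractor toward domain-invariant representations.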
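
In practice the feature map is never computed explicitly. Expanding the squared norm gives MMD² = mean k(s, s') + mean k(t, t') - 2 mean k(s, t), which needs only kernel evaluations. The sketch below implements this biased estimate with a Gaussian (RBF) kernel; the function names and the fixed bandwidth `sigma=1.0` are assumptions for illustration, and the random tensors stand in for backbone features.

```python
import torch


def rbf_kernel(a, b, sigma=1.0):
    """Gaussian kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2 * sigma ** 2))


def mmd2(source, target, sigma=1.0):
    """Biased estimate of squared MMD between two batches of feature vectors."""
    k_ss = rbf_kernel(source, source, sigma).mean()
    k_tt = rbf_kernel(target, target, sigma).mean()
    k_st = rbf_kernel(source, target, sigma).mean()
    return k_ss + k_tt - 2 * k_st


torch.manual_seed(0)
src = torch.randn(64, 16)          # stand-in for source features
same = mmd2(src, src)              # identical batches
shifted = mmd2(src, src + 2.0)     # target features shifted by a constant
print(f"same: {same.item():.4f} | shifted: {shifted.item():.4f}")
```

Because `mmd2` is differentiable, it could be added to the task loss with a weighting coefficient so that the backbone is penalized when source and target features drift apart.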

6.3 Adversarial Domain Adaptation

Train a domain discriminator to distinguish source vs. target features, while the feature extractor learns to fool it:

$$\min_G \max_D \; \mathcal{L}_{\text{task}}(G) - \lambda \mathcal{L}_{\text{domain}}(D, G)$$

where $G$ is the feature generator, $D$ is the domain discriminator, and $\lambda$ controls the trade-off.

class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -lambda_ in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lambda_ is not a tensor, so it gets None
        return -ctx.lambda_ * grad_output, None


class DomainAdversarialNetwork(nn.Module):
    def __init__(self, feature_extractor, num_classes, feature_dim=2048, hidden_dim=256):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.task_classifier = nn.Linear(feature_dim, num_classes)
        self.domain_classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 2),  # source vs target
        )

    def forward(self, x, alpha=1.0):
        # Flatten in case the backbone returns pooled (N, C, 1, 1) feature maps
        features = torch.flatten(self.feature_extractor(x), 1)
        task_output = self.task_classifier(features)
        # Gradient reversal: the domain loss pushes the features toward domain confusion
        reversed_features = GradientReversalFunction.apply(features, alpha)
        domain_output = self.domain_classifier(reversed_features)
        return task_output, domain_output
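
The `alpha` argument of `forward` is usually not held at 1.0 from the start. A common choice, described in the original domain-adversarial training work, is to ramp it from 0 to 1 as training progresses so that the noisy early domain signal does not disturb the features. The sketch below shows one such schedule; the function name `grl_alpha` and the default `gamma=10.0` are assumptions here, not something this lesson prescribes.

```python
import math


def grl_alpha(progress: float, gamma: float = 10.0) -> float:
    """Gradient-reversal strength for training progress in [0, 1].

    Starts at 0 (no adversarial signal) and saturates toward 1.
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


for progress in (0.0, 0.1, 0.25, 0.5, 1.0):
    print(f"progress={progress:.2f} -> alpha={grl_alpha(progress):.3f}")
```

Inside a training loop, `progress` would be the current step divided by the total number of steps, and the result would be passed as `alpha` on each forward call.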

7. Data Augmentation

Data augmentation is critical for transfer learning — it artificially expands small datasets and improves generalization.

7.1 Data Augmentation Examples

Figure: Data augmentation techniques

  • Original: 224 x 224, RGB
  • H-Flip (p=0.5): mirror image; symmetry invariance
  • Rotate 15 deg: rotation invariance, -15 to +15 degrees
  • Color Jitter: brightness, contrast, saturation, hue
  • Rand Crop: scale variation, size 0.8-1.0
  • Cutout: random erasing, hole size 16x16
  • Gaussian Blur: sigma 0.1-2.0; scale invariance
  • RandAugment: N ops, magnitude M; auto-selected ops
  • Mixup: blend two images, lambda ~ Beta(0.2, 0.2)
  • CutMix: cut and paste a region; blend labels by area

7.2 Implementation with torchvision

from torchvision import transforms

# Training augmentation pipeline
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
])

# Validation: no augmentation, only resize and normalize
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

7.3 Advanced Augmentation: Mixup and CutMix

These techniques create virtual training examples by blending existing ones:

Mixup:

$$\tilde{x} = \lambda x_i + (1 - \lambda) x_j, \quad \tilde{y} = \lambda y_i + (1 - \lambda) y_j, \quad \lambda \sim \text{Beta}(\alpha, \alpha)$$

CutMix: Cut a patch from image $j$ and paste it onto image $i$. Labels are mixed proportionally to the area ratio.

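import numpy as np
import torch

def mixup_data(x, y, alpha=0.2):
    """Blend each image with a randomly chosen partner from the same batch."""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangular patch from a partner image onto each image."""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    # Patch size chosen so that its area is roughly (1 - lam) of the image
    _, _, h, w = x.shape
    cut_ratio = np.sqrt(1.0 - lam)
    ch = int(h * cut_ratio)
    cw = int(w * cut_ratio)

    # Random patch center; the box is clipped to the image borders
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip([cy - ch // 2, cy + ch // 2], 0, h)
    x1, x2 = np.clip([cx - cw // 2, cx + cw // 2], 0, w)

    x = x.clone()  # avoid modifying the caller's batch in place
    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    # Recompute lam from the actual (clipped) patch area
    lam = 1 - (y2 - y1) * (x2 - x1) / (h * w)
    return x, y, y[index], lam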
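
Both functions return two label sets and a mixing weight rather than a single target, so the loss has to be mixed in the same proportion as the inputs. A minimal sketch of that step follows; the helper name `mixed_loss` is ours, and the random logits stand in for the model output on the mixed batch.

```python
import torch
import torch.nn as nn


def mixed_loss(criterion, logits, y_a, y_b, lam):
    """Loss for a mixed batch: weight the two label sets by lam and 1 - lam."""
    return lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)


criterion = nn.CrossEntropyLoss()
torch.manual_seed(0)
logits = torch.randn(4, 3)             # stand-in for model(mixed_x)
y_a = torch.tensor([0, 1, 2, 0])       # original labels
y_b = torch.tensor([2, 0, 1, 1])       # labels of the shuffled partners
loss = mixed_loss(criterion, logits, y_a, y_b, lam=0.7)
print(f"{loss.item():.4f}")
```

In a training step this would replace the usual `criterion(outputs, labels)` call whenever the batch was produced by `mixup_data` or `cutmix_data`.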

8. Implementation in PyTorch

8.1 Complete Fine-tuning Pipeline

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

def fine_tune_pipeline(data_dir, num_classes, num_epochs=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Data with augmentation
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.ImageFolder(f'{data_dir}/train', transform=train_transform)
    val_dataset = datasets.ImageFolder(f'{data_dir}/val', transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    # 2. Model with pre-trained weights
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

    # Phase 1: Freeze all but last block + classifier
    for name, param in model.named_parameters():
        if "layer4" not in name and "fc" not in name:
            param.requires_grad = False

    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes),
    )
    model = model.to(device)

    # 3. Differential learning rates
    optimizer = torch.optim.AdamW([
        {"params": model.layer4.parameters(), "lr": 1e-5},
        {"params": model.fc.parameters(), "lr": 1e-3},
    ], weight_decay=0.01)

    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-7)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # 4. Training loop
    best_val_acc = 0.0
    history = {"train_loss": [], "val_acc": []}

    for epoch in range(num_epochs):
        # --- Train ---
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        scheduler.step()
        train_loss = running_loss / len(train_dataset)

        # --- Validate ---
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_acc = 100.0 * correct / total
        history["train_loss"].append(train_loss)
        history["val_acc"].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")

        print(f"Epoch {epoch+1:2d}/{num_epochs} | "
              f"Loss: {train_loss:.4f} | Val Acc: {val_acc:.1f}% | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    # 5. Phase 2: Unfreeze all layers for fine-tuning
    print("\n--- Phase 2: Unfreezing all layers ---")
    for param in model.parameters():
        param.requires_grad = True

    optimizer2 = torch.optim.AdamW([
        {"params": model.parameters(), "lr": 1e-5},
    ], weight_decay=0.01)

    scheduler2 = CosineAnnealingLR(optimizer2, T_max=num_epochs // 2, eta_min=1e-7)

    for epoch in range(num_epochs // 2):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer2.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer2.step()
            running_loss += loss.item() * images.size(0)

        scheduler2.step()
        train_loss = running_loss / len(train_dataset)

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

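        val_acc = 100.0 * correct / total
        # Record Phase 2 metrics too, so the returned history covers both phases
        history["train_loss"].append(train_loss)
        history["val_acc"].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")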

        print(f"Epoch {epoch+1:2d} | Loss: {train_loss:.4f} | Val Acc: {val_acc:.1f}%")

    print(f"\nBest Validation Accuracy: {best_val_acc:.1f}%")
    return model, history

8.2 Gradual Unfreezing with ULMFiT-style Training

class GradualUnfreezer:
    """Unfreeze layer groups one at a time, from top to bottom."""

    def __init__(self, model, layer_groups=None):
        self.model = model
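        # Ordered top to bottom. The classifier head comes first so that
        # the first stage trains the head alone.
        self.layer_groups = layer_groups or [
            list(model.fc.parameters()),
            list(model.layer4.parameters()),
            list(model.layer3.parameters()),
            list(model.layer2.parameters()),
            list(model.layer1.parameters()),
            list(model.conv1.parameters()) + list(model.bn1.parameters()),
        ]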
        # Freeze everything
        for group in self.layer_groups:
            for p in group:
                p.requires_grad = False
        self.group_idx = 0

    def unfreeze_next(self):
        if self.group_idx >= len(self.layer_groups):
            print("All groups unfrozen.")
            return
        for p in self.layer_groups[self.group_idx]:
            p.requires_grad = True
        self.group_idx += 1
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f"Unfroze group {self.group_idx}/{len(self.layer_groups)} "
              f"| Trainable: {trainable:,}/{total:,}")

def train_with_gradual_unfreezing(model, train_loader, val_loader, num_groups, epochs_per_group=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    unfreezer = GradualUnfreezer(model)
    criterion = nn.CrossEntropyLoss()

    for g in range(num_groups):
        unfreezer.unfreeze_next()
        lr = 1e-3 * (0.1 ** g)  # Decreasing LR as more layers unfreeze
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr, weight_decay=0.01
        )

        for epoch in range(epochs_per_group):
            model.train()
            total_loss = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                total_loss += loss.item()

            model.eval()
            correct, total = 0, 0
            with torch.no_grad():  # no autograd graph is needed for evaluation
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    correct += (model(x).argmax(1) == y).sum().item()
                    total += y.size(0)
            print(f"  Group {g+1} Epoch {epoch+1} | "
                  f"Loss: {total_loss/len(train_loader):.4f} | "
                  f"Acc: {100*correct/total:.1f}%")

8.3 Model Export and Inference

def export_and_inference(model, input_tensor, onnx_path="model.onnx"):
    model.eval()

    # ONNX export
    torch.onnx.export(
        model, input_tensor, onnx_path,
        input_names=["input"], output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=17,
    )
    print(f"Exported to {onnx_path}")

    # Inference
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)
        confidence, predicted = probabilities.max(1)
        print(f"Predicted class: {predicted.item()} | Confidence: {confidence.item():.4f}")

    return predicted, confidence

9. Fine-tuning Strategies Comparison

| Strategy | Params Changed | LR Range | Best Data Size | Risk |
|---|---|---|---|---|
| Feature Extraction | 2-5% (head only) | 1e-3 to 1e-2 | 100 - 1K | Low |
| Partial Fine-tuning | 15-40% (top layers) | 1e-5 to 1e-3 | 1K - 10K | Medium |
| Gradual Unfreezing | Progressive 5-100% | 1e-4 to 1e-3 | 500 - 5K | Low |
| Full Fine-tuning | 100% (all layers) | 1e-5 to 1e-4 | 10K+ | High |
| From Scratch | 100% (random init) | 1e-3 to 1e-1 | 100K+ | Very High |

10. Common Pitfalls and Solutions

10.1 Catastrophic Forgetting

Catastrophic forgetting occurs when fine-tuning erases knowledge learned on the source task. Solutions: lower learning rates, freeze early layers, or add an Elastic Weight Consolidation (EWC) penalty:

$$\mathcal{L}_{\text{EWC}} = \mathcal{L}_{\text{task}}(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2$$

where $F_i$ is the Fisher information (importance) of parameter $\theta_i$, and $\theta_i^*$ are the optimal source parameters.
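
The penalty term is straightforward to compute once the source parameters and Fisher values have been stored. The sketch below shows only the penalty, under loud assumptions: `ewc_penalty` is a hypothetical helper, the model is a toy linear layer, and the Fisher values are placeholders set to 1.0. In practice they would be estimated on the source task, typically from squared gradients of the log-likelihood.

```python
import torch
import torch.nn as nn


def ewc_penalty(model, fisher, anchor, lam=1.0):
    """(lam / 2) * sum_i F_i * (theta_i - theta_i*)^2 over all named parameters."""
    penalty = torch.zeros(())
    for name, param in model.named_parameters():
        penalty = penalty + (fisher[name] * (param - anchor[name]) ** 2).sum()
    return 0.5 * lam * penalty


torch.manual_seed(0)
model = nn.Linear(3, 2)

# Snapshot of the "source" parameters (theta*) and placeholder Fisher values
anchor = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()}

at_anchor = ewc_penalty(model, fisher, anchor).item()
with torch.no_grad():
    model.weight.add_(0.1)             # move every weight by 0.1
after_shift = ewc_penalty(model, fisher, anchor).item()
print(at_anchor, round(after_shift, 4))
```

During fine-tuning the penalty would be added to the task loss before calling `backward()`, so parameters with large Fisher values are pulled back toward their source values more strongly.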

10.2 Negative Transfer

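Negative transfer occurs when source knowledge hurts target performance, so that the pre-trained model does worse than one trained from scratch. Solutions: validate the benefit of transfer empirically against a randomly initialized baseline; use similarity metrics to select source tasks.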

10.3 Overfitting on Small Datasets

Solutions: strong augmentation, dropout, weight decay, early stopping, label smoothing.

10.4 Batch Normalization Issues

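Freezing a BatchNorm layer's parameters does not freeze its statistics. In PyTorch, a BN layer in train mode normalizes with batch statistics and keeps updating its running mean and variance, while in eval mode it uses the stored source statistics, which may mismatch target data. Solutions: keep frozen BN layers in eval mode so they consistently use their running statistics, unfreeze them when there is enough target data to re-estimate those statistics, or choose an architecture that uses LayerNorm instead.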
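
Keeping BN layers in eval mode takes a few lines, but it has to be repeated after every call to `model.train()`, because `train()` switches all submodules back to training mode. The sketch below shows one way to do it; `freeze_batchnorm` is a hypothetical helper name, and the small convolutional stack is a stand-in for a real backbone.

```python
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_batchnorm(model: nn.Module) -> None:
    """Put every BatchNorm layer in eval mode and freeze its scale and shift."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()                      # use stored running statistics
            for p in module.parameters():
                p.requires_grad = False        # keep gamma and beta fixed


net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.train()               # train() puts BN back into batch-statistics mode...
freeze_batchnorm(net)     # ...so re-apply the freeze after each call to train()
print(net[1].training, net[0].training)  # False True
```

Whether this helps depends on how far the target statistics are from the source ones, so it is worth comparing against a run with BN left trainable.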


11. Key Takeaways

  • Feature extraction (frozen backbone) is the fastest and safest approach for small datasets (fewer than roughly 1K images)
  • Fine-tuning with differential learning rates gives best results when data is sufficient; use smaller LR for pre-trained layers, larger LR for new layers
  • Progressive unfreezing prevents catastrophic forgetting and stabilizes training by gradually unfreezing layers from top to bottom
  • Domain adaptation (MMD, adversarial training) helps when source and target domains differ
  • EWC penalizes changes to important parameters using Fisher information, preserving source knowledge
  • Data augmentation (Mixup, CutMix, RandAugment) is essential for small datasets
  • Negative transfer occurs when source and target tasks are too dissimilar; always validate empirically

12. Practice Exercises

Exercise 1: Compare Strategies

# Train the same model three ways on a small dataset (e.g., 500 images):
# 1. Feature extraction (all frozen except FC)
# 2. Partial fine-tuning (unfreeze last 2 blocks)
# 3. Full fine-tuning
# Compare accuracy, training time, and overfitting behavior.

Exercise 2: Gradual Unfreezing Schedule

# Implement gradual unfreezing with different schedules:
# - Unfreeze one group every epoch
# - Unfreeze one group every 5 epochs
# - Unfreeze based on validation loss plateau
# Which schedule converges fastest? Which gives best accuracy?

Exercise 3: Domain Adaptation

# Use MNIST as source, MNIST-M (colored MNIST) as target.
# Implement adversarial domain adaptation with gradient reversal.
# Compare accuracy with and without domain adaptation.

Exercise 4: Ablation Study

# For a medical imaging task with 200 images per class:
# - Which pre-trained model? (ResNet vs EfficientNet vs ConvNeXt)
# - Which strategy? (Feature extraction vs partial fine-tuning)
# - Which augmentation? (Basic vs Mixup vs CutMix)
# - Which LR schedule? (Constant vs Cosine vs OneCycle)
# Report results in a table.
