Self-Supervised Learning — Contrastive and Masked Methods


Self-Supervised Learning

Self-supervised learning learns useful representations from unlabeled data by solving pretext tasks. It bridges the gap between supervised and unsupervised learning, enabling models to leverage massive unlabeled datasets.


The Self-Supervised Paradigm

Definition: Self-Supervised Learning

Instead of using manual labels, self-supervised learning creates supervisory signals from the data itself:

  1. Pretext task: Predict some part of the input from other parts
  2. Learned representations: Features transfer well to downstream tasks
  3. Fine-tuning: Adapt representations to specific tasks with few labels

Key insight: Good representations capture structure that generalizes across tasks.
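As a minimal sketch of how a pretext task turns unlabeled data into (input, target) pairs, the helper below hides part of a sequence and keeps the hidden values as prediction targets. The function name `make_masked_example` and the `MASK` placeholder are hypothetical, chosen only for illustration.

```python
import random

MASK = None  # hypothetical placeholder marking a hidden position


def make_masked_example(sequence, mask_ratio=0.25, seed=0):
    """Build one self-supervised training pair from an unlabeled sequence.

    Returns (corrupted_input, targets), where targets maps each hidden
    position to the original value the model should predict.
    """
    rng = random.Random(seed)
    num_masked = max(1, int(len(sequence) * mask_ratio))
    masked_positions = sorted(rng.sample(range(len(sequence)), num_masked))

    corrupted = list(sequence)
    targets = {}
    for pos in masked_positions:
        targets[pos] = corrupted[pos]  # supervisory signal comes from the data
        corrupted[pos] = MASK
    return corrupted, targets


corrupted, targets = make_masked_example([3, 1, 4, 1, 5, 9, 2, 6])
```

No human annotation is involved: the targets are simply the values that were hidden from the input.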


Contrastive Learning

Definition: Contrastive Learning

Learn representations by pulling positive pairs (augmented views of the same image) together and pushing negative pairs (views of different images) apart:

$$\text{Similarity}(z_i, z_j^+) > \text{Similarity}(z_i, z_j^-)$$

The model learns to be invariant to augmentations while distinguishing different images.

$$\mathcal{L}_i = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)}$$

Contrastive Loss

$$\mathcal{L}_{\text{contrastive}} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\exp(\text{sim}(z_i, z_j) / \tau) + \sum_{k \neq i,j} \exp(\text{sim}(z_i, z_k) / \tau)}$$

Here,

  • $z_i, z_j$ = Embeddings of a positive pair (two augmentations of the same image)
  • $\text{sim}(\cdot, \cdot)$ = Cosine similarity
  • $\tau$ = Temperature parameter (commonly between 0.07 and 0.5)
  • $N$ = Number of images in the batch ($2N$ augmented views form $N$ positive pairs)
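To make the formula concrete, here is a small worked example for a single anchor, written as a sketch in plain Python. The similarity values are made up for illustration, and `contrastive_loss` is a hypothetical helper that takes precomputed cosine similarities rather than embeddings.

```python
import math


def contrastive_loss(pos_sim, neg_sims, temperature=0.07):
    """Loss for one anchor: -log softmax of the positive among all candidates."""
    pos = math.exp(pos_sim / temperature)
    negs = sum(math.exp(s / temperature) for s in neg_sims)
    return -math.log(pos / (pos + negs))


# Hypothetical cosine similarities for one anchor.
easy = contrastive_loss(pos_sim=0.9, neg_sims=[0.1, 0.0, -0.2])
hard = contrastive_loss(pos_sim=0.3, neg_sims=[0.4, 0.2, 0.1])
print(f"well-separated positive: {easy:.4f}")
print(f"positive mixed with negatives: {hard:.4f}")
```

The loss is close to zero when the positive is much more similar than every negative, and it grows when negatives are as similar as the positive. A smaller temperature amplifies the gap between similarities before the softmax.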

SimCLR

Definition: SimCLR Framework

SimCLR (Chen et al., 2020) consists of four components:

  1. Data augmentation: Random crop, color jitter, Gaussian blur, horizontal flip
  2. Encoder $f(\cdot)$: ResNet-50, produces representation $h_i = f(x_i)$
  3. Projection head $g(\cdot)$: MLP maps to contrastive space $z_i = g(h_i)$
  4. NT-Xent loss: Normalized temperature-scaled cross-entropy

Key insight: The projection head is crucial. Downstream tasks use $h_i$ (not $z_i$), because $g(\cdot)$ discards information that can still be useful for those tasks.
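The data augmentation step can be sketched with plain tensor operations. This is a simplified illustration, not the SimCLR pipeline: it covers only random crop and horizontal flip (no color jitter or blur), and it shares one crop offset and one flip decision across the batch to stay short. The names `random_view` and `two_views` are hypothetical.

```python
import torch


def random_view(images, crop_size=24):
    """Apply a random crop and a random horizontal flip to a batch.

    images: (B, C, H, W). A single crop offset and flip decision is shared
    across the batch to keep the sketch short.
    """
    _, _, height, width = images.shape
    top = torch.randint(0, height - crop_size + 1, (1,)).item()
    left = torch.randint(0, width - crop_size + 1, (1,)).item()
    view = images[:, :, top:top + crop_size, left:left + crop_size]
    if torch.rand(1).item() < 0.5:
        view = torch.flip(view, dims=[3])  # flip along the width axis
    return view


def two_views(images, crop_size=24):
    """Return a (B, 2, C, crop, crop) tensor holding two views per image."""
    return torch.stack(
        [random_view(images, crop_size), random_view(images, crop_size)], dim=1
    )


batch = torch.rand(8, 3, 32, 32)  # stand-in for unlabeled images
views = two_views(batch)
```

The two views of each image form a positive pair; views of different images serve as negatives.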

ℹ️ Why Projection Head?

The projection head $g(\cdot)$ acts as an information bottleneck. The contrastive loss pushes $z_i$ to be invariant to the augmentations, so $z_i$ can lose information (such as color or orientation) that a downstream task may still need. Because the loss is applied after $g(\cdot)$, the encoder output $h_i$ retains more of that information. For downstream tasks, use $h_i$ (before the projection head).
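A minimal sketch of using the encoder output for a downstream task: freeze the encoder, drop the projection head, and train only a linear classifier on $h$. The tiny `encoder` and `projector` modules here are random stand-ins for a pretrained model, and the 10-class setup is an assumption.

```python
import torch
import torch.nn as nn

# Stand-ins for a pretrained SimCLR model (random weights, illustration only).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
projector = nn.Linear(64, 16)  # used during pretraining, dropped afterwards

# Freeze the encoder so only the linear classifier is trained.
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()

classifier = nn.Linear(64, 10)  # 10 classes is an assumption
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)

images = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

with torch.no_grad():
    h = encoder(images)  # representation h, taken before the projection head
logits = classifier(h)
loss = nn.functional.cross_entropy(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that `projector` is never called: only the encoder features feed the classifier.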


MoCo (Momentum Contrast)

Definition: MoCo

MoCo (He et al., 2020) decouples batch size from number of negatives using a queue:

  1. Queue: Maintain a FIFO queue of encoded representations (larger than batch)
  2. Momentum encoder: Update key encoder with momentum: $\theta_k \leftarrow m \theta_k + (1-m) \theta_q$
  3. InfoNCE loss: Same as NT-Xent but with queue as negative pool

This makes tens of thousands of negatives (for example, a 65K-entry queue) available at every training step, even with a small batch, which improves contrastive learning.
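The queue and the loss can be sketched as follows. This is an illustrative simplification, not the reference MoCo code: `KeyQueue` and `info_nce` are hypothetical names, the queue is initialized with random unit vectors, and the queue size is kept small (1024) so the example runs quickly.

```python
import torch
import torch.nn.functional as F


class KeyQueue:
    """Fixed-size FIFO buffer of key embeddings used as negatives."""

    def __init__(self, dim=128, size=1024):
        self.keys = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0
        self.size = size

    @torch.no_grad()
    def enqueue(self, new_keys):
        """Overwrite the oldest entries with the newest batch of keys."""
        batch = new_keys.shape[0]
        idx = (self.ptr + torch.arange(batch)) % self.size  # wrap around
        self.keys[idx] = new_keys
        self.ptr = (self.ptr + batch) % self.size


def info_nce(q, k, queue, temperature=0.07):
    """InfoNCE loss with one positive key per query and queued negatives."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    l_pos = (q * k).sum(dim=1, keepdim=True)   # (B, 1)
    l_neg = q @ queue.keys.T                   # (B, queue size)
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(q.shape[0], dtype=torch.long)  # positive at index 0
    return F.cross_entropy(logits, labels)


queue = KeyQueue(dim=128, size=1024)
q = torch.randn(32, 128)            # query encoder outputs (stand-in)
k = torch.randn(32, 128)            # momentum encoder outputs (stand-in)
loss = info_nce(q, k, queue)
queue.enqueue(F.normalize(k, dim=1))
```

The number of negatives is set by the queue size rather than the batch size. In real training the keys would come from the momentum encoder without gradients.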

Momentum Update

$$\theta_k \leftarrow m \cdot \theta_k + (1 - m) \cdot \theta_q$$

Here,

  • $\theta_k$ = Momentum encoder parameters
  • $\theta_q$ = Query encoder parameters
  • $m$ = Momentum coefficient (typically 0.999)
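The update rule maps directly onto an in-place parameter loop. The sketch below uses a small `nn.Linear` as a stand-in encoder, and `momentum_update` is a hypothetical helper name.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def momentum_update(query_encoder, key_encoder, m=0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q, applied parameter-wise."""
    for p_q, p_k in zip(query_encoder.parameters(), key_encoder.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1.0 - m)


query_encoder = nn.Linear(4, 2)               # stand-in encoder
key_encoder = copy.deepcopy(query_encoder)    # starts as an exact copy
for p in key_encoder.parameters():
    p.requires_grad = False                   # updated by momentum only

# Pretend an optimizer step changed the query encoder.
with torch.no_grad():
    for p in query_encoder.parameters():
        p.add_(1.0)

old_weight = key_encoder.weight.clone()
momentum_update(query_encoder, key_encoder, m=0.999)
```

With $m = 0.999$, the key encoder moves only 0.1% of the way toward the query encoder per step, which keeps the keys in the queue consistent with each other.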

Masked Image Modeling

Definition: Masked Autoencoder (MAE)

MAE (He et al., 2022) masks random patches and reconstructs them:

  1. Mask: Randomly mask 75% of image patches
  2. Encode: Process only visible patches (efficient!)
  3. Decode: Reconstruct masked patches from visible ones + mask tokens
  4. Loss: MSE between reconstructed and original patches

Key insight: High masking ratio (75%) forces the model to learn semantic understanding rather than pixel interpolation.
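A quick arithmetic check shows why encoding only visible patches is efficient. The image and patch sizes (224 and 16) are assumed here as a typical ViT setup, and the attention-cost figure is a rough estimate, not a measured speedup.

```python
img_size, patch_size, mask_ratio = 224, 16, 0.75

num_patches = (img_size // patch_size) ** 2          # 14 * 14 patches
num_visible = int(num_patches * (1 - mask_ratio))    # patches the encoder sees
num_masked = num_patches - num_visible

# Self-attention cost grows roughly with the square of the sequence length,
# so this ratio is a rough estimate that ignores the CLS token and MLP layers.
attention_cost_ratio = (num_visible / num_patches) ** 2

print(num_patches, num_visible, num_masked, attention_cost_ratio)
# prints: 196 49 147 0.0625
```

Under these assumptions the encoder processes 49 of 196 patches, so its self-attention works on a sequence one quarter of the full length.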

MAE Loss

$$\mathcal{L}_{\text{MAE}} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \|\hat{x}_i - x_i\|^2$$

Here,

  • $\mathcal{M}$ = Set of masked patch indices
  • $\hat{x}_i$ = Reconstructed patch $i$
  • $x_i$ = Original patch $i$
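A sketch of this loss in PyTorch, assuming the mask convention 1 = masked and 0 = visible. The helpers `patchify` and `mae_loss` are illustrative names; the per-patch error is averaged over pixels, which differs from the squared norm above only by a constant factor.

```python
import torch


def patchify(images, patch_size=16):
    """Split (B, C, H, W) images into (B, N, patch_size**2 * C) flat patches."""
    b, c, h, w = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, gh, gw, p, p, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


def mae_loss(pred, images, mask, patch_size=16):
    """Mean squared error averaged over masked patches only.

    pred: (B, N, patch_size**2 * C) reconstructed patches
    mask: (B, N) with 1 for masked patches and 0 for visible ones
    """
    target = patchify(images, patch_size)
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N)
    return (per_patch * mask).sum() / mask.sum()


images = torch.rand(2, 3, 32, 32)
pred = torch.zeros(2, 4, 16 * 16 * 3)      # stand-in decoder output
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                     [0.0, 1.0, 1.0, 1.0]])
loss = mae_loss(pred, images, mask)
```

Visible patches contribute nothing to the loss, so the model is only scored on what it had to infer.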

BEiT (Bidirectional Encoder representation from Image Transformers)

Definition: BEiT

BEiT (Bao et al., 2021) uses discrete visual tokens as reconstruction targets:

  1. Tokenizer: dVAE discretizes image patches into visual tokens
  2. Mask: Mask roughly 40% of patches
  3. Predict: Predict visual token IDs for masked patches

The visual tokens provide a learned discrete vocabulary for images, similar to words in NLP.
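The prediction step amounts to classification over the visual vocabulary at masked positions. The sketch below makes several assumptions: the tokenizer and the transformer are replaced by random stand-ins, the vocabulary size of 8192 is assumed, and patches are masked uniformly at random for simplicity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size = 8192      # assumed size of the visual-token vocabulary
num_patches = 196      # 14 x 14 patches for a 224 x 224 image
embed_dim = 64         # small value to keep the sketch light
batch_size = 2

# Stand-in for the tokenizer output: one visual-token ID per patch.
token_ids = torch.randint(0, vocab_size, (batch_size, num_patches))

# Mask roughly 40% of the patches (uniformly at random here; an assumption).
mask = torch.rand(batch_size, num_patches) < 0.4

# Stand-in for the transformer output at every patch position.
features = torch.randn(batch_size, num_patches, embed_dim)
head = nn.Linear(embed_dim, vocab_size)   # predicts a token ID per patch
logits = head(features)                   # (B, N, vocab_size)

# Cross-entropy is computed only at masked positions.
loss = F.cross_entropy(logits[mask], token_ids[mask])
```

Compared with MAE, the target is a discrete class per patch rather than raw pixel values, so the loss is cross-entropy instead of MSE.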


Comparison of Methods

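| Method  | Type              | Masking | Negatives   | Batch Size | Performance |
|---------|-------------------|---------|-------------|------------|-------------|
| SimCLR  | Contrastive       | None    | In-batch    | 4096       | Good        |
| MoCo v2 | Contrastive       | None    | Queue (65K) | 256        | Better      |
| SwAV    | Contrastive       | None    | Prototypes  | 4096       | Better      |
| MAE     | Masked            | 75%     | None        | 1024       | Excellent   |
| BEiT    | Masked            | 40%     | None        | 1024       | Excellent   |
| DINO    | Self-distillation | None    | Self        | 1024       | Excellent   |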

💡 When to Use Which?

  • Contrastive (SimCLR, MoCo): When augmentations are well-defined
  • Masked modeling (MAE, BEiT): When spatial understanding matters
  • Self-distillation (DINO): When you want attention maps without labels
  • MAE is often the default choice for vision — efficient and effective

PyTorch Implementation

📝Example: SimCLR

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class SimCLR(nn.Module):
    def __init__(self, backbone='resnet50', projection_dim=128):
        super().__init__()
        # Encoder (ResNet without final FC)
        # weights=None gives random initialization (replaces the deprecated pretrained=False)
        resnet = torchvision.models.resnet50(weights=None)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        encoder_dim = 2048

        # Projection head (MLP)
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, projection_dim)
        )

    def forward(self, x):
        # x: (batch, 2, C, H, W) — 2 augmented views
        h = self.encoder(x.view(-1, *x.shape[2:])).squeeze(-1).squeeze(-1)
        z = self.projector(h)
        return z


def nt_xent_loss(z1, z2, temperature=0.07):
    """NT-Xent loss for SimCLR."""
    batch_size = z1.shape[0]
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Concatenate: (2*batch, projection_dim)
    z = torch.cat([z1, z2], dim=0)

    # Similarity matrix: (2*batch, 2*batch)
    sim = torch.mm(z, z.T) / temperature

    # Mask out self-similarity so a sample is never its own candidate
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))

    # Positive pairs: row i matches column i+batch, row i+batch matches column i
    idx = torch.arange(batch_size, device=z.device)
    labels = torch.cat([idx + batch_size, idx])

    # Cross-entropy over each row selects the positive among 2*batch - 1 candidates
    return F.cross_entropy(sim, labels)


# Training
model = SimCLR()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4 * 300 / 256)

for epoch in range(100):
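    for images, _ in dataloader:  # dataloader is assumed to yield two views per image
        # images: (batch, 2, C, H, W), 2 augmented views
        z = model(images)  # (2*batch, projection_dim), views interleaved per image
        z1, z2 = z[0::2], z[1::2]  # split into first and second views
        loss = nt_xent_loss(z1, z2)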

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

📝Example: MAE (Simplified)

import torch
import torch.nn as nn

class MAE(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, encoder_depth=12, decoder_depth=4,
                 mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (img_size // patch_size) ** 2

        # Encoder
        self.patch_embed = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.encoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=12, dim_feedforward=embed_dim*4,
                batch_first=True  # inputs are (batch, tokens, dim)
            ) for _ in range(encoder_depth)
        ])

        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, embed_dim // 4)
        self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim // 4))
        self.decoder_pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim // 4)
        )
        self.decoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim // 4, nhead=4, dim_feedforward=embed_dim,
                batch_first=True  # inputs are (batch, tokens, dim)
            ) for _ in range(decoder_depth)
        ])
        self.decoder_pred = nn.Linear(
            embed_dim // 4, patch_size ** 2 * in_chans
        )

    def random_masking(self, x, mask_ratio):
        B, N, D = x.shape
        num_keep = int(N * (1 - mask_ratio))

        # Random permutation
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep first num_keep
        ids_keep = ids_shuffle[:, :num_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        # Binary mask: 0 = keep, 1 = masked
        mask = torch.ones(B, N, device=x.device)
        mask[:, :num_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x):
        # Patchify
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, N, D)
        x = x + self.pos_embed[:, 1:, :]

        # Random masking
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # Add CLS token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Encode
        for block in self.encoder_blocks:
            x = block(x)

        # Decode
        x = self.decoder_embed(x)
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] - x.shape[1] + 1, 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
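        # Unshuffle: the index must cover the full feature dimension
        x_ = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # Add decoder positional embeddings and run the decoder blocks
        x = x + self.decoder_pos_embed
        for block in self.decoder_blocks:
            x = block(x)

        # Reconstruct
        x = self.decoder_pred(x)[:, 1:]  # Remove CLS
        return x, mask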

Practice Exercises

  1. SimCLR ablation: Experiment with different augmentations. Which ones matter most?

  2. Linear evaluation: Train SimCLR on CIFAR-10, then freeze the encoder and train a linear classifier on top.

  3. MAE vs SimCLR: Compare representations using linear probing on CIFAR-10.

  4. Visualization: Plot attention maps from DINO. Does the model attend to objects?


Key Takeaways

📋Summary: Self-Supervised Learning

  • Contrastive learning: Pull positives together, push negatives apart
  • SimCLR: NT-Xent loss with large batches and projection head
  • MoCo: Momentum encoder + queue for more negatives
  • MAE: Mask 75%, reconstruct — efficient and effective
  • BEiT: Predict discrete visual tokens as targets
  • DINO: Self-distillation with no labels; its attention maps can highlight objects without segmentation supervision
  • Projection head: Used for pretraining, discard for downstream
  • Augmentations are crucial — define what the model learns to ignore
  • Linear evaluation: Standard protocol for comparing representations
  • MAE is often among the most effective and efficient choices for vision
  • See also: Self-Supervised in ML for fundamentals
