Self-Supervised Learning
Self-supervised learning learns useful representations from unlabeled data by solving pretext tasks. It bridges the gap between supervised and unsupervised learning, enabling models to leverage massive unlabeled datasets.
The Self-Supervised Paradigm
Definition: Self-Supervised Learning
Instead of using manual labels, self-supervised learning creates supervisory signals from the data itself:
- Pretext task: Predict some part of the input from other parts
- Learned representations: Features transfer well to downstream tasks
- Fine-tuning: Adapt representations to specific tasks with few labels
Key insight: Good representations capture structure that generalizes across tasks.
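A pretext task needs no annotation because the target is carved out of the input itself. The sketch below assumes tabular or flattened inputs of shape (batch, features). It hides a random subset of entries and scores a model only on the entries it could not see. The helper names `make_masked_pretext_pair` and `masked_mse` and the 50% mask ratio are illustrative assumptions, not a standard API.

```python
import torch


def make_masked_pretext_pair(x, mask_ratio=0.5, generator=None):
    """Build a self-supervised (input, target, mask) triple from unlabeled x.

    x: (batch, features) tensor. No labels are used: the target is x itself.
    """
    # 1 where an entry is hidden from the model, 0 where it stays visible
    mask = (torch.rand(x.shape, generator=generator, device=x.device) < mask_ratio).float()
    corrupted = x * (1 - mask)  # hidden entries are zeroed out
    return corrupted, x, mask


def masked_mse(pred, target, mask):
    # Score the model only on the entries it could not see
    return ((pred - target) ** 2 * mask).sum() / mask.sum().clamp(min=1)
```

With any model that maps `(batch, features)` to the same shape, pretraining minimizes `masked_mse(model(corrupted), target, mask)`. Masked image modeling applies the same idea to image patches.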
Contrastive Learning
Definition: Contrastive Learning
Learn representations by pulling positive pairs (two augmented views of the same image) together in embedding space and pushing negative pairs (views of different images) apart.
The model learns to be invariant to augmentations while distinguishing different images.
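Contrastive Loss
$$\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}$$
Here,
- $z_i, z_j$ = Embeddings of a positive pair (two augmentations of the same image)
- $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \, \lVert v \rVert)$ = Cosine similarity
- $\tau$ = Temperature parameter (typically 0.07)
- $N$ = Batch size ($2N$ augmented samples forming $N$ positive pairs)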
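As a sanity check on the definition, the sketch below evaluates the per-anchor NT-Xent term, $\ell_{i,j} = -\log\big(\exp(\mathrm{sim}(z_i,z_j)/\tau) / \sum_{k \neq i} \exp(\mathrm{sim}(z_i,z_k)/\tau)\big)$, literally, one similarity at a time. It is written for readability rather than speed. The helper name `pair_loss` is hypothetical, and indices are 0-based in the code.

```python
import torch
import torch.nn.functional as F


def pair_loss(z, i, j, tau=0.07):
    """Direct evaluation of l_{i,j} for the 2N embeddings stacked in z, shape (2N, d)."""
    def sim(u, v):
        return F.cosine_similarity(u, v, dim=0)  # cosine similarity of two vectors

    numerator = torch.exp(sim(z[i], z[j]) / tau)
    # Denominator: all k except the anchor itself (the indicator 1[k != i])
    denominator = sum(
        torch.exp(sim(z[i], z[k]) / tau) for k in range(z.shape[0]) if k != i
    )
    return -torch.log(numerator / denominator)
```

Averaging this term over all $2N$ anchors gives the batch loss. Practical implementations compute all anchors at once as a cross-entropy over the temperature-scaled similarity matrix with its diagonal excluded.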
SimCLR
Definition: SimCLR Framework
SimCLR (Chen et al., 2020) consists of four components:
- Data augmentation: Random crop, color jitter, Gaussian blur, horizontal flip
- Encoder $f(\cdot)$: ResNet-50, produces the representation $h = f(\tilde{x})$
- Projection head $g(\cdot)$: MLP that maps $h$ to the contrastive space, $z = g(h)$
- NT-Xent loss: Normalized temperature-scaled cross-entropy
Key insight: The projection head is crucial. Downstream tasks use the encoder output $h$ (not the projection $z = g(h)$), because the projection head $g$ discards information that is useful downstream.
ℹ️ Why a Projection Head?
The contrastive loss trains the projected embeddings to be invariant to the augmentations, so information that the augmentations change (such as color or orientation) is thrown away. Placing a projection head between the encoder and the loss lets the head absorb this loss of information, so the encoder output keeps richer features. For downstream tasks, use the encoder output (before the projection head).
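In code, discarding the projection head just means reading features from the encoder and never calling the head. The following is a minimal linear-evaluation step, assuming `encoder` is the pretrained encoder (for example the ResNet trunk without its FC layer) and `classifier` is a single `nn.Linear`. The function name `linear_probe_step` is hypothetical.

```python
import torch
import torch.nn as nn


def linear_probe_step(encoder, classifier, optimizer, images, labels):
    """One linear-evaluation step: frozen encoder, trainable linear classifier."""
    encoder.eval()  # frozen: no dropout or batch-norm statistic updates
    with torch.no_grad():
        h = encoder(images).flatten(1)  # encoder output, no projection head
    logits = classifier(h)
    loss = nn.functional.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()  # gradients reach only the classifier
    optimizer.step()
    return loss.item()
```

The optimizer should be built over `classifier.parameters()` only, so the representation itself is never modified.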
MoCo (Momentum Contrast)
Definition: MoCo
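MoCo (He et al., 2020) decouples the batch size from the number of negatives by using a queue:
- Queue: Maintain a FIFO queue of encoded keys from previous batches (much larger than a batch)
- Momentum encoder: Update the key encoder as a momentum (exponential moving) average of the query encoder instead of by backpropagation, so keys encoded at different steps stay consistent
- InfoNCE loss: Same form as NT-Xent, but the negatives come from the queue
This makes 65,536 negatives per query (the default queue size) affordable with a batch size of only 256, improving contrastive learning.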
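A minimal sketch of the queue and the InfoNCE logits, assuming L2-normalized query and key embeddings of shape (batch, dim). `KeyQueue` and `moco_infonce` are hypothetical names. The queue overwrites its oldest entries in place, which assumes the queue size is a multiple of the batch size.

```python
import torch
import torch.nn.functional as F


class KeyQueue:
    """FIFO queue of past normalized keys, used as negatives."""

    def __init__(self, dim, size=65536):
        # Random normalized keys until real keys fill the queue
        self.keys = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, k):
        # Replace the oldest entries (assumes size % batch == 0)
        n = k.shape[0]
        self.keys[self.ptr:self.ptr + n] = k
        self.ptr = (self.ptr + n) % self.keys.shape[0]


def moco_infonce(q, k, queue, temperature=0.07):
    """q: queries from the query encoder; k: keys from the momentum encoder."""
    l_pos = (q * k).sum(dim=1, keepdim=True)  # (batch, 1) positive logits
    l_neg = q @ queue.keys.T                  # (batch, K) negatives from the queue
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    # The positive key sits at column 0 for every query
    labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
    return F.cross_entropy(logits, labels)
```

In a training step, `k` is computed by the momentum encoder under `torch.no_grad()`, the loss is backpropagated through `q` only, and `queue.enqueue(k)` is called afterwards so the current keys serve as negatives for later batches.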
Momentum Update
$$\theta_k \leftarrow m\,\theta_k + (1 - m)\,\theta_q$$
Here,
- $\theta_k$ = Momentum (key) encoder parameters
- $\theta_q$ = Query encoder parameters
- $m$ = Momentum coefficient (typically 0.999)
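Because the update $\theta_k \leftarrow m\,\theta_k + (1-m)\,\theta_q$ is an exponential moving average rather than a gradient step, it runs under `torch.no_grad()`. A minimal sketch, assuming both encoders have identical architectures so their parameters pair up one-to-one. Buffers such as batch-norm running statistics are not touched here, and `momentum_update` is a hypothetical name.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def momentum_update(query_encoder: nn.Module, key_encoder: nn.Module, m: float = 0.999):
    """theta_k <- m * theta_k + (1 - m) * theta_q, parameter by parameter."""
    for p_q, p_k in zip(query_encoder.parameters(), key_encoder.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1 - m)
```

The key encoder is typically created as a copy of the query encoder (for example with `copy.deepcopy`) with `requires_grad` disabled, and `momentum_update` is called once per training step.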
Masked Image Modeling
Definition: Masked Autoencoder (MAE)
MAE (He et al., 2022) masks random patches and reconstructs them:
- Mask: Randomly mask 75% of image patches
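- Encode: Process only the visible patches (25% of them), which makes the encoder cheap
- Decode: A lightweight decoder reconstructs the masked patches from the encoded visible patches plus mask tokens
- Loss: MSE between reconstructed and original pixels, computed on masked patches only
Key insight: A high masking ratio (75%) makes the task too hard to solve by interpolating from neighboring patches, forcing the model to learn semantic understanding rather than pixel interpolation.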
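For a 224×224 image split into 16×16 patches (the ViT setting used in the MAE example in this section), the saving is easy to quantify:

$$N = \left(\frac{224}{16}\right)^2 = 196, \qquad N_{\text{visible}} = 196 \times (1 - 0.75) = 49$$

The encoder therefore runs on 49 tokens instead of 196, about 4× fewer. Since self-attention cost grows quadratically with sequence length, its attention layers save even more. Mask tokens are introduced only in the lightweight decoder.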
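MAE Loss
$$\mathcal{L}_{\text{MAE}} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \lVert \hat{x}_i - x_i \rVert_2^2$$
Here,
- $\mathcal{M}$ = Set of masked patch indices
- $\hat{x}_i$ = Reconstructed patch $i$
- $x_i$ = Original patch $i$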
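A minimal sketch of this loss: the squared error summed over each patch's pixels, averaged over the masked patches only. It assumes predictions of shape (B, N, p·p·C) and a mask of shape (B, N) with 1 marking masked patches (the same convention as `random_masking` in the implementation section). `patchify` and `mae_loss` are hypothetical helper names. The patch order matches a stride-p convolution followed by flattening.

```python
import torch


def patchify(imgs, patch_size=16):
    """(B, C, H, W) -> (B, N, patch_size**2 * C), patches in row-major order."""
    B, C, H, W = imgs.shape
    h, w = H // patch_size, W // patch_size
    x = imgs.reshape(B, C, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, h, w, p, p, C)
    return x.reshape(B, h * w, patch_size * patch_size * C)


def mae_loss(pred, imgs, mask, patch_size=16):
    """Squared error per masked patch, averaged over the masked set M.

    pred: (B, N, p*p*C) reconstructed patches; mask: (B, N), 1 = masked.
    """
    target = patchify(imgs, patch_size)
    per_patch = ((pred - target) ** 2).sum(dim=-1)  # ||x_hat_i - x_i||^2
    return (per_patch * mask).sum() / mask.sum()    # (1/|M|) * sum over i in M
```

Many implementations average over the pixels within each patch instead of summing. This only rescales the loss by the constant number of pixels per patch.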
BEiT (Bidirectional Encoder representation from Image Transformers)
Definition: BEiT
BEiT (Bao et al., 2021) uses discrete visual tokens as reconstruction targets:
- Tokenizer: dVAE discretizes image patches into visual tokens
- Mask: Mask about 40% of patches, using blockwise masking
- Predict: Predict visual token IDs for masked patches
The visual tokens provide a learned discrete vocabulary for images, similar to words in NLP.
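BEiT turns reconstruction into classification over the tokenizer's vocabulary. A minimal sketch of the objective, assuming the frozen tokenizer has already produced one token ID per patch of the unmasked image. The function name `beit_mim_loss` is hypothetical. The vocabulary size of 8192 used in the usage note (the size of the DALL-E dVAE codebook that BEiT adopted) is stated here only for illustration.

```python
import torch
import torch.nn.functional as F


def beit_mim_loss(logits, token_ids, mask):
    """Cross-entropy on visual-token IDs at masked positions only.

    logits:    (B, N, vocab_size) predictions from the transformer
    token_ids: (B, N) target IDs from the tokenizer for the original image
    mask:      (B, N) bool, True where the patch was masked
    """
    # Boolean indexing keeps only masked positions: (M, vocab_size) and (M,)
    return F.cross_entropy(logits[mask], token_ids[mask])
```

With 196 patches and an 8192-token vocabulary, `logits` has shape `(B, 196, 8192)`; unmasked positions contribute nothing to the loss, mirroring how MAE scores only masked patches.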
Comparison of Methods
| Method | Type | Masking | Negatives | Typical pretraining batch size | Performance (qualitative) |
|---|---|---|---|---|---|
| SimCLR | Contrastive | None | In-batch | 4096 | Good |
| MoCo v2 | Contrastive | None | Queue (65K) | 256 | Better |
| SwAV | Clustering (swapped prediction) | None | None (prototypes) | 4096 | Better |
| MAE | Masked | 75% | None | 4096 | Excellent |
| BEiT | Masked | 40% | None | 2048 | Excellent |
| DINO | Self-distillation | None | None (momentum teacher) | 1024 | Excellent |
💡 When to Use Which?
- Contrastive (SimCLR, MoCo): When augmentations are well-defined
- Masked modeling (MAE, BEiT): When spatial understanding matters
- Self-distillation (DINO): When you want attention maps without labels
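- MAE is a common default for pretraining Vision Transformers: it is compute-efficient and fine-tunes well, although its linear-probe accuracy is lower than that of contrastive or self-distillation methods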
PyTorch Implementation
📝Example: SimCLR
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class SimCLR(nn.Module):
    def __init__(self, projection_dim=128):
        super().__init__()
        # Encoder f: ResNet-50 without its final FC layer
        resnet = torchvision.models.resnet50(weights=None)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        encoder_dim = 2048
        # Projection head g: 2-layer MLP
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, projection_dim),
        )

    def forward(self, x):
        # x: (batch, 2, C, H, W), two augmented views per image
        batch = x.shape[0]
        h = self.encoder(x.flatten(0, 1)).flatten(1)  # (2*batch, 2048)
        z = self.projector(h)                          # (2*batch, projection_dim)
        return z.view(batch, 2, -1)                    # (batch, 2, projection_dim)


def nt_xent_loss(z1, z2, temperature=0.07):
    """NT-Xent loss for SimCLR; z1[i] and z2[i] form a positive pair."""
    batch_size = z1.shape[0]
    # Concatenate and L2-normalize: (2*batch, projection_dim)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    # Cosine similarity matrix scaled by temperature: (2*batch, 2*batch)
    sim = z @ z.T / temperature
    # Exclude self-similarity from the softmax denominator
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float('-inf'))
    # Positive for row i is i + batch (first half) or i - batch (second half)
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size, device=z.device),
        torch.arange(0, batch_size, device=z.device),
    ])
    return F.cross_entropy(sim, labels)


# Training (dataloader is assumed to yield (images, labels) with
# images of shape (batch, 2, C, H, W): 2 augmented views)
model = SimCLR()
# Adam for simplicity; the SimCLR paper used LARS with lr = 0.3 * batch_size / 256
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
for epoch in range(100):
    for images, _ in dataloader:
        z1, z2 = model(images).unbind(dim=1)  # split the two views
        loss = nt_xent_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
📝Example: MAE (Simplified)
```python
import torch
import torch.nn as nn


class MAE(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, encoder_depth=12, decoder_depth=4,
                 mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (img_size // patch_size) ** 2
        decoder_dim = embed_dim // 4

        # Encoder
        self.patch_embed = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.encoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=12, dim_feedforward=embed_dim * 4,
                batch_first=True  # inputs are (B, N, D)
            ) for _ in range(encoder_depth)
        ])

        # Lightweight decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, decoder_dim)
        )
        self.decoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=decoder_dim, nhead=4, dim_feedforward=decoder_dim * 4,
                batch_first=True
            ) for _ in range(decoder_depth)
        ])
        self.decoder_pred = nn.Linear(decoder_dim, patch_size ** 2 * in_chans)

    def random_masking(self, x, mask_ratio):
        B, N, D = x.shape
        num_keep = int(N * (1 - mask_ratio))
        # Random permutation per sample
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # Keep the first num_keep patches of the permutation
        ids_keep = ids_shuffle[:, :num_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        # Binary mask in original patch order: 0 = keep, 1 = masked
        mask = torch.ones(B, N, device=x.device)
        mask[:, :num_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)
        return x_masked, mask, ids_restore

    def forward(self, x):
        # Patchify and add positional embeddings (skipping the CLS position)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, N, D)
        x = x + self.pos_embed[:, 1:, :]
        # Random masking: keep only the visible patches
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
        # Prepend the CLS token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)
        # Encode visible patches only
        for block in self.encoder_blocks:
            x = block(x)

        # Decode: project, append mask tokens, restore original patch order
        x = self.decoder_embed(x)
        num_masked = ids_restore.shape[1] + 1 - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], num_masked, 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1) + self.decoder_pos_embed
        for block in self.decoder_blocks:
            x = block(x)

        # Reconstruct pixels for every patch, dropping the CLS token
        pred = self.decoder_pred(x)[:, 1:]  # (B, N, patch_size**2 * in_chans)
        return pred, mask
```
Practice Exercises
- SimCLR ablation: Experiment with different augmentations. Which ones matter most?
- Linear evaluation: Train SimCLR on CIFAR-10, then freeze the encoder and train a linear classifier on its features.
- MAE vs SimCLR: Compare representations using linear probing on CIFAR-10.
- Visualization: Plot attention maps from DINO. Does the model attend to objects?
Key Takeaways
📋Summary: Self-Supervised Learning
- Contrastive learning: Pull positives together, push negatives apart
- SimCLR: NT-Xent loss with large batches and projection head
- MoCo: Momentum encoder + queue for more negatives
- MAE: Mask 75%, reconstruct — efficient and effective
- BEiT: Predict discrete visual tokens as targets
- DINO: Self-distillation with no labels; its ViT attention maps highlight object regions, giving rough segmentations without supervision
- Projection head: Used for pretraining, discard for downstream
- Augmentations are crucial — define what the model learns to ignore
- Linear evaluation: Standard protocol for comparing representations
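- MAE is among the most compute-efficient pretraining methods for vision and fine-tunes strongly, but frozen (linear-probe) features favor contrastive and self-distillation methods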
- See also: Self-Supervised in ML for fundamentals