Vision Transformers — ViT and Beyond
Vision Transformers (ViT) apply the Transformer architecture to image recognition by treating image patches as tokens. Given enough pretraining data, they achieve performance competitive with, or superior to, CNNs.
From NLP to Vision
Definition: Vision Transformer (ViT)
ViT (Dosovitskiy et al., 2020) processes images as sequences of patches:
- Split image into fixed-size patches (e.g., 16x16)
- Linearly embed each patch
- Add positional embeddings
- Process through standard Transformer encoder
- Use [CLS] token output for classification
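Key insight: When trained on sufficient data, ViT can match or outperform CNNs, in part because self-attention can relate any two patches from the first layer onward, whereas a CNN only reaches a global receptive field after many stacked layers.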
Patch Embedding
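$$z_0 = [x_{\text{class}};\ x_p^1 E;\ x_p^2 E;\ \dots;\ x_p^N E] + E_{\text{pos}}$$
Here,
- $x_p^i$ = Flattened patch $i$
- $E$ = Patch embedding matrix (linear projection)
- $E_{\text{pos}}$ = Positional embedding
- $x_{\text{class}}$ = Classification token embedding
- $N$ = Number of patches ($HW / P^2$)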
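As a quick check of the patch count, the sketch below computes the sequence length and the flattened patch size for a given input. The helper names `vit_sequence_length` and `flattened_patch_dim` are hypothetical, and the 224x224 RGB input is an assumed example.

```python
def vit_sequence_length(height, width, patch_size, with_cls=True):
    """Number of tokens the encoder sees: H*W / P^2 patches, plus one [CLS] token."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    num_patches = (height // patch_size) * (width // patch_size)
    return num_patches + (1 if with_cls else 0)


def flattened_patch_dim(patch_size, channels=3):
    """Length of one flattened patch vector before the linear projection."""
    return patch_size * patch_size * channels


# Assumed example: a 224x224 RGB image with 16x16 patches.
seq_len = vit_sequence_length(224, 224, 16)   # 14 * 14 patches + 1 [CLS] token
patch_dim = flattened_patch_dim(16)           # 16 * 16 * 3 values per patch
print(seq_len, patch_dim)
```

Halving the patch size quadruples the number of patches, which is why small patches are expensive under quadratic self-attention.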
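Patch Projection (Convolutional Shortcut)
$$Z = \text{Flatten}\big(\text{Conv2d}_{P \times P,\ \text{stride}=P}(X)\big)$$
Here,
- $P$ = Patch size (e.g., 16)
- $\text{Conv2d}_{P \times P,\ \text{stride}=P}$ = Convolution with no padding, stride=P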
ℹ️ Conv2d as Patch Embedding
A convolutional layer with kernel size = patch size and stride = patch size is equivalent to splitting the image into patches and applying a linear projection. This is computationally more efficient and is the standard implementation.
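The equivalence can be checked numerically. The sketch below assumes small illustrative sizes (a 64x64 input and D=32) to keep it fast. It compares the strided convolution against explicit patch extraction with `torch.nn.functional.unfold` followed by a matrix multiply that reuses the convolution's own weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
P, C, D = 16, 3, 32                      # patch size, channels, embedding dim (illustrative)
x = torch.randn(2, C, 64, 64)            # 64x64 input -> 4 x 4 = 16 patches
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Path 1: strided convolution, then flatten the spatial grid into a sequence.
out_conv = conv(x).flatten(2).transpose(1, 2)                     # (2, 16, D)

# Path 2: cut out non-overlapping patches, then apply a linear projection
# that uses the convolution kernel reshaped to a (D, C*P*P) matrix.
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)    # (2, 16, C*P*P)
weight = conv.weight.reshape(D, -1)                               # (D, C*P*P)
out_linear = patches @ weight.t() + conv.bias                     # (2, 16, D)

max_diff = (out_conv - out_linear).abs().max().item()
print(out_conv.shape, max_diff)
```

The two paths agree up to floating-point rounding. The convolution avoids materializing the intermediate patch tensor.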
Positional Encoding
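Definition: Learnable Positional Embeddings
Unlike the original Transformer's sinusoidal encoding, ViT uses learned positional embeddings:
$$z_0 = [x_{\text{class}};\ x_p^1 E;\ \dots;\ x_p^N E] + E_{\text{pos}}$$
where $E_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ is a learned matrix. Since patch positions are 2D but the embeddings are indexed in 1D, the model must learn the spatial relationships itself.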
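Sinusoidal Positional Encoding (Alternative)
$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
Here,
- $pos$ = Position index
- $i$ = Dimension index
- $d_{\text{model}}$ = Model dimension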
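For comparison with the learned variant, here is a minimal pure-Python sketch of the fixed sinusoidal table from the original Transformer (sine on even dimensions, cosine on odd ones). The function name is hypothetical, and the sizes 197 and 768 are assumed to match a ViT-B/16 sequence.

```python
import math


def sinusoidal_positional_encoding(num_positions, d_model):
    """Return a num_positions x d_model table of fixed sinusoidal encodings."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    table = []
    for pos in range(num_positions):
        row = [0.0] * d_model
        for i in range(d_model // 2):
            angle = pos / (10000 ** (2 * i / d_model))
            row[2 * i] = math.sin(angle)       # even dimension index
            row[2 * i + 1] = math.cos(angle)   # odd dimension index
        table.append(row)
    return table


# Assumed sizes: 196 patches + 1 [CLS] token, model dimension 768.
pe = sinusoidal_positional_encoding(num_positions=197, d_model=768)
print(len(pe), len(pe[0]))
```

Unlike a learned embedding, this table has no trainable parameters and can be generated for any sequence length.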
ViT Architecture
Definition: ViT Components
- Patch Embedding: Conv2d with kernel_size=P, stride=P, in_channels=3, out_channels=D
- CLS Token: Learnable token prepended to sequence
- Position Embedding: Learnable (N+1) x D matrix
- Transformer Encoder: L layers of multi-head self-attention + FFN
- Classification Head: Linear layer on CLS token output
Typical configurations: ViT-B/16 (L=12, D=768, P=16), ViT-L/16 (L=24, D=1024), ViT-H/14 (L=32, D=1280)
💡 ViT Hyperparameters
- Patch size $P$: Smaller patches = more tokens = more compute. $P=16$ or $P=14$ is standard
- Model dimension $D$: 768 (Base), 1024 (Large), 1280 (Huge)
- Depth $L$: 12 (Base), 24 (Large), 32 (Huge)
- Heads $h$: 12 (Base), 16 (Large), 16 (Huge)
- MLP ratio: Typically 4x (hidden dim = 4D)
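These hyperparameters determine the model size. The sketch below is a rough parameter count. It assumes a standard layout: a biased patch projection, one [CLS] token, learned position embeddings, four D x D attention projections per block, a two-layer MLP, two LayerNorms per block, a final LayerNorm and a linear head. The function name is hypothetical. The number of heads does not appear because it splits D rather than adding parameters.

```python
def vit_param_count(img_size=224, patch_size=16, in_channels=3,
                    num_classes=1000, d_model=768, num_layers=12, mlp_ratio=4):
    """Estimate the number of trainable parameters of a plain ViT classifier."""
    num_patches = (img_size // patch_size) ** 2
    hidden = int(d_model * mlp_ratio)

    patch_embed = in_channels * patch_size ** 2 * d_model + d_model  # projection + bias
    tokens = d_model + (num_patches + 1) * d_model                   # [CLS] + position table
    attention = 4 * (d_model * d_model + d_model)                    # Q, K, V, output projections
    mlp = (d_model * hidden + hidden) + (hidden * d_model + d_model)
    norms = 2 * 2 * d_model                                          # two LayerNorms (scale + shift)
    block = attention + mlp + norms
    final_norm = 2 * d_model
    head = d_model * num_classes + num_classes
    return patch_embed + tokens + num_layers * block + final_norm + head


base = vit_param_count()                                 # ViT-B/16 settings
large = vit_param_count(d_model=1024, num_layers=24)     # ViT-L/16 settings
print(f"ViT-B/16: {base / 1e6:.1f}M, ViT-L/16: {large / 1e6:.1f}M")
```

Almost all parameters sit in the encoder blocks, and each block scales with D squared, so widening the model costs more than deepening it by the same factor.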
ViT vs CNN
| Aspect | CNN | ViT |
|---|---|---|
| Inductive bias | Local connectivity, translation equivariance | None (global attention) |
| Data requirement | Works with less data | Needs large dataset (JFT-300M) |
| Computational cost | $O(k^2 \cdot C_{in} \cdot C_{out} \cdot HW)$ per layer | $O(N^2 \cdot D)$ self-attention |
| Patch size | N/A | Determines sequence length |
| Position info | Built-in (conv) | Explicit positional encoding |
| Transfer learning | Excellent | Excellent with pretraining |
Theorem: ViT Data Efficiency
ViT lacks the built-in inductive biases of locality and translation equivariance that convolutions provide, so it needs significantly more data than a CNN to learn these properties from scratch. When pre-trained on large datasets (>100M images), ViT can surpass CNNs, in part because self-attention models global relationships directly, whereas a CNN only captures them through many stacked local layers.
DeiT (Data-efficient Image Transformers)
Definition: DeiT
DeiT (Touvron et al., 2021) addresses ViT's data hunger through:
- Knowledge distillation: Train with a CNN teacher (RegNetY-16GF)
- Distillation token: Learnable token that mimics CNN predictions
- Strong augmentation: RandAugment, Mixup, CutMix, Erasing
- Regularization: Stochastic depth, label smoothing
DeiT achieves comparable performance to ViT with only ImageNet-1K (1.2M images).
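Knowledge Distillation Loss
$$\mathcal{L} = (1-\alpha)\,\mathcal{L}_{\text{CE}}\big(y,\ \text{softmax}(z_s)\big) + \alpha\,\tau^2\,\text{KL}\big(p_t \,\|\, p_s\big)$$
Here,
- $\alpha$ = Balance between hard and soft loss
- $\tau$ = Temperature for softening probabilities
- $p_t = \text{softmax}(z_t/\tau)$ = Teacher's soft predictions
- $p_s = \text{softmax}(z_s/\tau)$ = Student's soft predictions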
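A minimal PyTorch sketch of a soft distillation loss of this kind follows. It assumes the common convention of (1 - alpha) times the hard cross-entropy plus alpha times tau squared times the KL divergence between the softened teacher and student distributions. The function name and the default values of `alpha` and `tau` are illustrative assumptions, not settings taken from the text.

```python
import torch
import torch.nn.functional as F


def soft_distillation_loss(student_logits, teacher_logits, targets,
                           alpha=0.5, tau=3.0):
    """(1 - alpha) * CE(student, labels) + alpha * tau^2 * KL(teacher || student)."""
    hard = F.cross_entropy(student_logits, targets)
    log_p_student = F.log_softmax(student_logits / tau, dim=-1)
    p_teacher = F.softmax(teacher_logits / tau, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    soft = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * tau ** 2
    return (1.0 - alpha) * hard + alpha * soft


torch.manual_seed(0)
student_logits = torch.randn(4, 10)        # stand-in for ViT student outputs
teacher_logits = torch.randn(4, 10)        # stand-in for frozen CNN teacher outputs
targets = torch.randint(0, 10, (4,))
loss = soft_distillation_loss(student_logits, teacher_logits, targets)
print(loss.item())
```

The factor tau squared keeps the gradient magnitude of the soft term comparable across temperatures. In real training the teacher logits would be computed under `torch.no_grad()`.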
Swin Transformer
Definition: Swin Transformer
Swin (Liu et al., 2021) introduces a hierarchical vision Transformer with:
- Shifted windows: Compute self-attention within local windows, shift between layers
- Hierarchical features: Patch merging creates multi-scale feature maps (like CNN pyramids)
- Linear complexity: $O(N)$ in the number of patches $N$ instead of $O(N^2)$, due to local windows of fixed size
- Versatile: Works for classification, detection, and segmentation
Window size $M$ (typically 7). Complexity: $O(M^2 \cdot N)$.
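The window mechanism can be sketched in a few lines of PyTorch. This is a simplified illustration, assuming a channels-last feature map of 56x56 tokens and window size 7. It shows only the partitioning and the cyclic shift, and omits the attention mask that a full shifted-window implementation needs for the wrapped-around regions.

```python
import torch


def window_partition(x, window_size):
    """(B, H, W, C) -> (B * num_windows, window_size * window_size, C)."""
    B, H, W, C = x.shape
    M = window_size
    x = x.view(B, H // M, M, W // M, M, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()   # group the two window axes together
    return x.view(-1, M * M, C)


def cyclic_shift(x, window_size):
    """Roll the feature map by half a window so the next layer's windows straddle the old borders."""
    shift = window_size // 2
    return torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))


torch.manual_seed(0)
feature_map = torch.randn(2, 56, 56, 96)                          # assumed sizes
windows = window_partition(feature_map, window_size=7)            # regular windows
shifted_windows = window_partition(cyclic_shift(feature_map, 7), 7)
print(windows.shape, shifted_windows.shape)
```

Self-attention is then applied independently to each row of 49 tokens, so the cost per window is constant and the total cost grows with the number of windows.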
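Window-based Self-Attention Complexity
$$\Omega(\text{W-MSA}) = \frac{N}{M^2} \cdot O\big((M^2)^2\big) = O(M^2 \cdot N) \quad \text{vs.} \quad \Omega(\text{MSA}) = O(N^2)$$
Here,
- $N$ = Total number of patches
- $M$ = Window size (typically 7)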
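To make the scaling concrete, the sketch below counts query-key pairs for global and windowed attention. The helper names are hypothetical and the grid sizes are assumed examples. It ignores the channel dimension and the projection layers, which cost the same in both cases.

```python
def attention_pairs_global(num_patches):
    """Every token attends to every token: N^2 pairs."""
    return num_patches ** 2


def attention_pairs_windowed(num_patches, window_size):
    """N / M^2 windows, each with (M^2)^2 pairs: M^2 * N pairs in total."""
    tokens_per_window = window_size ** 2
    num_windows = num_patches // tokens_per_window
    return num_windows * tokens_per_window ** 2


# Assumed example grids of 56x56, 112x112 and 224x224 tokens, window size 7.
for side in (56, 112, 224):
    n = side * side
    g = attention_pairs_global(n)
    w = attention_pairs_windowed(n, 7)
    print(f"{side}x{side}: global={g:,} windowed={w:,} ratio={g // w}")
```

Quadrupling the number of tokens quadruples the windowed count but multiplies the global count by sixteen.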
PyTorch Implementation
📝Example: Vision Transformer
```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, d_model)
        )

    def forward(self, x):
        # x: (batch, 3, 224, 224)
        B = x.shape[0]
        x = self.proj(x)                    # (B, d_model, 14, 14)
        x = x.flatten(2).transpose(1, 2)    # (B, 196, d_model)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 197, d_model)
        x = x + self.pos_embed
        return x


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, int(d_model * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model * mlp_ratio), d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, d_model=768, num_heads=12,
                 num_layers=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, d_model
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_output = x[:, 0]
        return self.head(cls_output)


# ViT-B/16
model = VisionTransformer(
    img_size=224, patch_size=16, num_classes=1000,
    d_model=768, num_heads=12, num_layers=12
)
x = torch.randn(2, 3, 224, 224)
print(f"Output: {model(x).shape}")  # (2, 1000)
params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Parameters: {params:.1f}M")
```
Practice Exercises
- Patch visualization: Split an image into patches and visualize the embedding space using t-SNE.
- Position embedding analysis: Train a ViT and visualize learned positional embeddings. Do they capture spatial structure?
- Swin Transformer: Implement shifted window attention. Verify linear complexity on large images.
- Data augmentation: Compare ViT performance with and without DeiT-style augmentation on ImageNet-1K.
Key Takeaways
📋Summary: Vision Transformers
- ViT treats image patches as tokens for Transformer processing
- Patch embedding: Conv2d with kernel=stride=patch_size
- Positional encoding: Learned embeddings (not sinusoidal)
- No inductive bias: Requires large datasets to learn locality
- DeiT: Knowledge distillation enables training on ImageNet-1K alone
- Swin Transformer: Hierarchical, window-based, linear complexity
- ViT vs CNN: ViT tends to win with enough pretraining data; CNNs are often better with limited data
- Transfer learning: Pre-trained ViTs achieve SOTA on many vision tasks
- See also: Transformers for the original Transformer architecture