Vision Transformers — ViT and Beyond

Vision Transformers (ViT) apply the Transformer architecture to image recognition by treating image patches as tokens. With enough pre-training data, they match or surpass convolutional neural networks (CNNs).


From NLP to Vision

Definition: Vision Transformer (ViT)

ViT (Dosovitskiy et al., 2020) processes images as sequences of patches:

  1. Split image into fixed-size patches (e.g., 16x16)
  2. Linearly embed each patch
  3. Add positional embeddings
  4. Process through standard Transformer encoder
  5. Use [CLS] token output for classification

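Key insight: When pre-trained on sufficiently large datasets (e.g., ImageNet-21k or JFT-300M), ViT matches or outperforms strong CNNs. Self-attention lets every patch attend to every other patch from the first layer, so global relationships do not have to be built up gradually through stacked local filters.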


Patch Embedding

$$\mathbf{z}_0 = [\mathbf{x}_{\text{cls}};\ \mathbf{x}_p^1 E;\ \mathbf{x}_p^2 E;\ \ldots;\ \mathbf{x}_p^N E] + \mathbf{E}_{\text{pos}}$$

Here,

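  • $\mathbf{x}_p^i \in \mathbb{R}^{P^2 C}$ = flattened patch $i$ ($P \times P$ pixels, $C$ channels)
  • $E \in \mathbb{R}^{(P^2 C) \times D}$ = patch embedding matrix (linear projection to model dimension $D$)
  • $\mathbf{E}_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ = positional embedding
  • $\mathbf{x}_{\text{cls}} \in \mathbb{R}^{D}$ = classification token embedding
  • $N$ = number of patches, $N = HW / P^2$ (e.g., $224 \cdot 224 / 16^2 = 196$)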
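To make the equation concrete, here is a minimal PyTorch sketch that builds $\mathbf{z}_0$ term by term. The helper names `patchify` and `build_z0` are hypothetical, and random tensors stand in for the learned $E$, $\mathbf{x}_{\text{cls}}$, and $\mathbf{E}_{\text{pos}}$. Patches are flattened in (channel, row, column) order, an assumption chosen so the layout matches a `Conv2d` kernel.

```python
import torch

def patchify(img: torch.Tensor, P: int) -> torch.Tensor:
    """Split (B, C, H, W) images into (B, N, C*P*P) flattened patches, N = H*W / P^2."""
    B, C, H, W = img.shape
    assert H % P == 0 and W % P == 0, "image size must be divisible by P"
    x = img.reshape(B, C, H // P, P, W // P, P)   # (B, C, gh, P, gw, P)
    x = x.permute(0, 2, 4, 1, 3, 5)               # (B, gh, gw, C, P, P)
    return x.reshape(B, (H // P) * (W // P), C * P * P)

def build_z0(img, E, x_cls, E_pos, P):
    """z_0 = [x_cls; x_p^1 E; ...; x_p^N E] + E_pos"""
    patches = patchify(img, P)                       # (B, N, P^2 C)
    tokens = patches @ E                             # (B, N, D): same E for every patch
    cls = x_cls.expand(img.shape[0], -1, -1)         # (B, 1, D)
    return torch.cat([cls, tokens], dim=1) + E_pos   # (B, N+1, D)

B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
N = (H * W) // P**2                                  # 196 patches
img = torch.randn(B, C, H, W)
E = torch.randn(P * P * C, D) * 0.02                 # stand-in for the learned projection
x_cls = torch.zeros(1, 1, D)
E_pos = torch.randn(1, N + 1, D) * 0.02
z0 = build_z0(img, E, x_cls, E_pos, P)
print(z0.shape)  # torch.Size([2, 197, 768])
```

The extra row in $\mathbf{E}_{\text{pos}}$ (197 rather than 196) is the position of the classification token.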

Patch Projection (Convolutional Shortcut)

$$[\mathbf{x}_p^1 E;\ \ldots;\ \mathbf{x}_p^N E] = \text{Flatten}\big(\text{Conv2d}(\text{image};\ \text{kernel}=P,\ \text{stride}=P)\big)$$

Here,

  • $P$ = patch size (e.g., 16)
  • $\text{Conv2d}$ = convolution with kernel size $P$, stride $P$, no padding, and $D$ output channels

ℹ️ Conv2d as Patch Embedding

A convolutional layer with kernel size = patch size and stride = patch size is mathematically equivalent to splitting the image into non-overlapping patches and applying the same linear projection $E$ to each one. It avoids explicit reshaping and is the standard implementation in common ViT codebases.
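The equivalence can be checked numerically. This sketch uses small, arbitrary sizes and illustrative variable names; it reuses a `Conv2d` layer's weights as the matrix $E$ and compares the two paths.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 64

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)      # kernel = stride = P, no padding
img = torch.randn(B, C, H, W)

# Path 1: convolution, then flatten the (H/P, W/P) grid into a token sequence
conv_tokens = conv(img).flatten(2).transpose(1, 2)   # (B, N, D)

# Path 2: explicit patch split + linear projection with the same weights
x = img.reshape(B, C, H // P, P, W // P, P)
x = x.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * P * P)  # (B, N, C*P*P), flattened as (C, P, P)
E = conv.weight.reshape(D, C * P * P).t()                  # conv kernel viewed as the matrix E
linear_tokens = x @ E + conv.bias                           # (B, N, D)

print(torch.allclose(conv_tokens, linear_tokens, atol=1e-4))  # True
```

The flattening order of each patch must match the `(out_channels, in_channels, P, P)` layout of the convolution weight; with a different order, both paths are still linear projections but use differently arranged weights, so the check would fail.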


Positional Encoding

Definition: Learnable Positional Embeddings

Unlike the original Transformer's sinusoidal encoding, ViT uses learned positional embeddings:

$$\mathbf{z}_0 = [\mathbf{x}_{\text{cls}};\ \mathbf{x}_p^1 E;\ \ldots;\ \mathbf{x}_p^N E] + \mathbf{E}_{\text{pos}}$$

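where $\mathbf{E}_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ is a learned matrix with one row per token (the classification token plus $N$ patches). Its rows are indexed by a 1D sequence position (patches in raster order), not by 2D grid coordinates, so the model must learn the 2D spatial relationships between patches from data. Dosovitskiy et al. report that explicitly 2D-aware position embeddings gave no significant improvement over these 1D embeddings.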
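One way to probe what the learned embeddings encode (as in Practice Exercise 2) is to drop the classification position, reshape the remaining rows back onto the patch grid, and compare every position with every other by cosine similarity. This sketch assumes the `(1, N+1, D)` layout with the classification position first, as in the PyTorch implementation in this lesson; `pos_similarity` is a hypothetical helper, and the random tensor is only a placeholder for trained weights.

```python
import torch
import torch.nn.functional as F

def pos_similarity(pos_embed: torch.Tensor, grid: int) -> torch.Tensor:
    """Cosine similarity between patch position embeddings.

    pos_embed: (1, N+1, D), [CLS] position first (assumed layout), N = grid * grid.
    Returns (grid, grid, grid, grid): entry [r1, c1, r2, c2] compares
    the embeddings of the patches at (r1, c1) and (r2, c2).
    """
    patch_pos = pos_embed[0, 1:]                  # drop [CLS]: (N, D)
    patch_pos = F.normalize(patch_pos, dim=-1)    # unit length per position
    sim = patch_pos @ patch_pos.t()               # (N, N) cosine similarities
    return sim.reshape(grid, grid, grid, grid)

# Random placeholder; for a trained model use model.patch_embed.pos_embed.detach()
pos_embed = torch.randn(1, 14 * 14 + 1, 768)
sim = pos_similarity(pos_embed, grid=14)
# sim[r, c] is a 14x14 map that can be plotted for the patch at row r, column c
print(sim.shape)  # torch.Size([14, 14, 14, 14])
```

With random embeddings the maps show no structure; if a trained model has learned spatial layout, each map should peak at its own patch and stay higher along the same row and column.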

Sinusoidal Positional Encoding (Alternative)

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

Here,

  • $pos$ = position index
  • $i$ = dimension-pair index ($0 \le i < d_{model}/2$)
  • $d_{model}$ = model dimension
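The formula translates directly into a table of fixed (non-learned) vectors. This is a minimal sketch; the function name `sinusoidal_pe` is illustrative, and computing $1/10000^{2i/d_{model}}$ through `exp` and `log` is a common numerical choice rather than part of the formula.

```python
import math
import torch

def sinusoidal_pe(num_positions: int, d_model: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    assert d_model % 2 == 0, "d_model must be even"
    pos = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)  # (num_positions, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)             # 2i = 0, 2, 4, ...
    div = torch.exp(-math.log(10000.0) * two_i / d_model)                # 1 / 10000^(2i/d_model)
    pe = torch.zeros(num_positions, d_model)
    pe[:, 0::2] = torch.sin(pos * div)   # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions
    return pe

pe = sinusoidal_pe(197, 768)   # N + 1 = 197 positions for ViT-B/16 at 224x224
print(pe.shape)  # torch.Size([197, 768])
```

Because the table is fixed, it would typically be stored with `register_buffer` rather than as an `nn.Parameter`, which is the key difference from ViT's learned $\mathbf{E}_{\text{pos}}$.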

ViT Architecture

Definition: ViT Components

  • Patch Embedding: Conv2d with kernel_size=P, stride=P, in_channels=3, out_channels=D
  • CLS Token: Learnable token prepended to sequence
  • Position Embedding: Learnable (N+1) x D matrix
  • Transformer Encoder: L layers of multi-head self-attention + FFN
  • Classification Head: Linear layer on CLS token output

Typical configurations: ViT-B/16 (L=12, D=768, P=16), ViT-L/16 (L=24, D=1024, P=16), ViT-H/14 (L=32, D=1280, P=14). The number after the slash is the patch size.

💡 ViT Hyperparameters

  • Patch size $P$: smaller patches mean more tokens ($N = HW/P^2$) and, because self-attention is quadratic in sequence length, substantially more compute. $P=14$ or $P=16$ is standard
  • Model dimension $D$: 768 (Base), 1024 (Large), 1280 (Huge)
  • Depth $L$: 12 (Base), 24 (Large), 32 (Huge)
  • Heads $H$: 12 (Base), 16 (Large), 16 (Huge)
  • MLP ratio: Typically 4x (hidden dim = 4D)
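A little arithmetic shows how these hyperparameters drive cost. The sketch below is a rough cost model under a stated assumption: it counts only the weight matrices inside the encoder blocks ($4D^2$ for the Q, K, V, and output projections plus $2 \cdot 4D^2$ for the MLP) and ignores biases, LayerNorms, embeddings, and the head. `vit_stats` is a hypothetical helper.

```python
def vit_stats(img_size: int, P: int, D: int, L: int, mlp_ratio: int = 4):
    """Rough ViT cost model (counts only the main weight matrices of the encoder)."""
    N = (img_size // P) ** 2                  # patch tokens
    seq = N + 1                               # plus [CLS]
    attn_entries = seq * seq                  # attention scores per head, per layer
    block_params = (4 + 2 * mlp_ratio) * D * D  # 4D^2 attention + 2 * mlp_ratio * D^2 MLP
    return N, attn_entries, L * block_params

configs = [("ViT-B/16", 16, 768, 12), ("ViT-L/16", 16, 1024, 24), ("ViT-H/14", 14, 1280, 32)]
for name, P, D, L in configs:
    N, attn, params = vit_stats(224, P, D, L)
    print(f"{name}: {N} patches, {attn} scores/head/layer, ~{params / 1e6:.0f}M block weights")
# ViT-B/16: 196 patches, 38809 scores/head/layer, ~85M block weights
# ViT-L/16: 196 patches, 38809 scores/head/layer, ~302M block weights
# ViT-H/14: 256 patches, 66049 scores/head/layer, ~629M block weights
```

The block weights land close to the commonly quoted totals of about 86M, 307M, and 632M parameters, so $12 L D^2$ is a useful mental estimate. Halving $P$ quadruples $N$ and multiplies the number of attention scores by roughly 16.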

ViT vs CNN

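| Aspect | CNN | ViT |
|---|---|---|
| Inductive bias | Local connectivity, translation equivariance | Much weaker: only patch extraction and the per-token MLP are local; self-attention is global |
| Data requirement | Works well with less data | Needs large-scale pre-training (e.g., JFT-300M) to surpass CNNs |
| Computational cost per layer | $O(N)$ in the number of spatial positions $N$ | $O(N^2)$ self-attention over $N$ tokens |
| Patch size | N/A | Determines sequence length |
| Position info | Implicit in local convolutions | Explicit positional embeddings |
| Transfer learning | Excellent | Excellent with pre-training |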

ViT Data Efficiency

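ViT has much weaker built-in inductive biases for locality and translation equivariance than a CNN, so it must learn these properties from data and therefore needs significantly more training data. When pre-trained on large datasets (on the order of 100M+ images, e.g., JFT-300M), ViT surpasses comparable CNNs. A plausible explanation is that self-attention can combine information from the whole image at every layer, whereas a CNN's receptive field grows only gradually with depth.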


DeiT (Data-efficient Image Transformers)

Definition: DeiT

DeiT (Touvron et al., 2021) addresses ViT's data hunger through:

  1. Knowledge distillation: Train with a CNN teacher (RegNetY-16GF)
  2. Distillation token: an extra learnable token, alongside [CLS], whose output is trained to reproduce the teacher's predictions
  3. Strong augmentation: RandAugment, Mixup, CutMix, Random Erasing
  4. Regularization: Stochastic depth, label smoothing

DeiT is trained on ImageNet-1K alone (about 1.28M images) and reaches accuracy competitive with ViT models pre-trained on far larger datasets and with CNNs of similar size.
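The distillation token can be sketched in a few lines. This follows the hard-label variant described in the DeiT paper: the [CLS] head is trained on the true labels, the distillation-token head on the teacher's argmax predictions, and the two heads' softmax outputs are combined at inference. The function names and tensor sizes are illustrative assumptions, random tensors stand in for real patch tokens and teacher outputs, and the Transformer encoder itself is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def deit_hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits):
    """0.5 * CE(cls head, true labels) + 0.5 * CE(distillation head, teacher's argmax labels)."""
    teacher_labels = teacher_logits.detach().argmax(dim=-1)  # hard decision of the frozen teacher
    loss_cls = F.cross_entropy(cls_logits, labels)
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)
    return 0.5 * loss_cls + 0.5 * loss_dist

def deit_predict(cls_logits, dist_logits):
    """At inference, combine the two heads by averaging their softmax outputs."""
    return (cls_logits.softmax(dim=-1) + dist_logits.softmax(dim=-1)) / 2

B, N, D, num_classes = 4, 196, 192, 1000            # illustrative sizes
patch_tokens = torch.randn(B, N, D)
cls_token = nn.Parameter(torch.zeros(1, 1, D))
dist_token = nn.Parameter(torch.zeros(1, 1, D))
pos_embed = nn.Parameter(torch.zeros(1, N + 2, D))   # one extra position for the distillation token

x = torch.cat([cls_token.expand(B, -1, -1),
               dist_token.expand(B, -1, -1),
               patch_tokens], dim=1) + pos_embed     # (B, N + 2, D)
# ... the Transformer encoder blocks would process x here (omitted) ...

head = nn.Linear(D, num_classes)        # reads the [CLS] token (position 0)
head_dist = nn.Linear(D, num_classes)   # reads the distillation token (position 1)
cls_logits, dist_logits = head(x[:, 0]), head_dist(x[:, 1])

labels = torch.randint(0, num_classes, (B,))
teacher_logits = torch.randn(B, num_classes)         # stand-in for a CNN teacher such as RegNetY-16GF
loss = deit_hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits)
probs = deit_predict(cls_logits, dist_logits)
```

The distillation token behaves like a second [CLS] token: it interacts with the patches through self-attention, but its own head is supervised by the teacher instead of by the labels.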

Knowledge Distillation Loss

$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{CE}}(y, p_{\text{student}}) + (1 - \alpha)\, T^2 \, \mathcal{L}_{\text{CE}}(p_{\text{teacher}}^{\text{soft}}, p_{\text{student}}^{\text{soft}})$$

Here,

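  • $\alpha$ = weight balancing the hard-label loss against the soft distillation loss
  • $T$ = temperature for softening probabilities; the $T^2$ factor keeps the soft-loss gradients on a scale comparable to the hard loss
  • $p_{\text{teacher}}^{\text{soft}} = \text{softmax}(z_{\text{teacher}} / T)$ = teacher's softened predictions, where $z$ denotes logits
  • $p_{\text{student}}^{\text{soft}} = \text{softmax}(z_{\text{student}} / T)$ = student's softened predictions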
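The loss above maps term by term onto PyTorch. This is a minimal sketch of that equation; the function name and the default values `alpha=0.5` and `T=3.0` are illustrative assumptions, not values taken from the DeiT paper.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=3.0):
    """alpha * CE(y, p_student) + (1 - alpha) * T^2 * CE(p_teacher^soft, p_student^soft)."""
    hard = F.cross_entropy(student_logits, labels)                     # CE with the true labels
    p_teacher_soft = F.softmax(teacher_logits.detach() / T, dim=-1)   # frozen, softened teacher
    log_p_student_soft = F.log_softmax(student_logits / T, dim=-1)
    soft = -(p_teacher_soft * log_p_student_soft).sum(dim=-1).mean()  # soft-target cross-entropy
    return alpha * hard + (1 - alpha) * T**2 * soft

torch.manual_seed(0)
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=4.0)
loss.backward()   # gradients flow only into the student
```

Many implementations use `F.kl_div` for the soft term instead. KL divergence differs from this soft cross-entropy only by the teacher's entropy, which does not depend on the student, so the gradients are the same.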

Swin Transformer

Definition: Swin Transformer

Swin (Liu et al., 2021) introduces a hierarchical vision Transformer with:

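  1. Shifted windows: compute self-attention within non-overlapping local windows, and shift the window partition between consecutive layers so information flows across window boundaries
  2. Hierarchical features: patch merging creates multi-scale feature maps (like CNN pyramids)
  3. Linear complexity: $O(N)$ in the number of patches instead of $O(N^2)$, because attention is restricted to fixed-size local windows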
  4. Versatile: Works for classification, detection, and segmentation
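The three mechanisms in the list can be sketched with plain tensor operations. In this sketch `window_partition`, `cyclic_shift`, and `PatchMerging` are illustrative names, the sizes are assumed to be Swin-T's first stage (a 56x56 grid of 96-dimensional patch features from a 224x224 image with 4x4 patches), and the attention mask that Swin applies to windows containing wrapped-around regions after the shift is omitted.

```python
import torch
import torch.nn as nn

def window_partition(x: torch.Tensor, M: int) -> torch.Tensor:
    """(B, H, W, C) -> (B * num_windows, M*M, C); attention then runs per window."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // M, M, W // M, M, C)              # split both axes into windows
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M, C)

def cyclic_shift(x: torch.Tensor, M: int) -> torch.Tensor:
    """Roll the feature map by M//2 so the next layer's windows straddle old window borders."""
    return torch.roll(x, shifts=(-(M // 2), -(M // 2)), dims=(1, 2))

class PatchMerging(nn.Module):
    """Concatenate each 2x2 group of neighbouring patches (4C) and project to 2C."""
    def __init__(self, C: int):
        super().__init__()
        self.norm = nn.LayerNorm(4 * C)
        self.reduction = nn.Linear(4 * C, 2 * C, bias=False)

    def forward(self, x):                        # x: (B, H, W, C)
        x0 = x[:, 0::2, 0::2]                    # top-left of each 2x2 group
        x1 = x[:, 1::2, 0::2]                    # bottom-left
        x2 = x[:, 0::2, 1::2]                    # top-right
        x3 = x[:, 1::2, 1::2]                    # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)
        return self.reduction(self.norm(x))      # (B, H/2, W/2, 2C)

B, H, W, C, M = 2, 56, 56, 96, 7
feat = torch.randn(B, H, W, C)
windows = window_partition(feat, M)                    # (2 * 64, 49, 96)
attn = nn.MultiheadAttention(C, num_heads=3, batch_first=True)
out, _ = attn(windows, windows, windows)               # attention only within each 7x7 window
shifted = window_partition(cyclic_shift(feat, M), M)   # windows for the next (shifted) layer
merged = PatchMerging(C)(feat)                         # (2, 28, 28, 192)
```

Each window is treated as an independent short sequence by folding the windows into the batch dimension, which is what keeps the attention cost tied to the window size rather than to the image size.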

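With window size $M$ (typically 7), each window contains $M \times M$ patches and each patch attends only to the $M^2$ patches in its own window, so attention costs $O(N \cdot M^2)$. For fixed $M$ this is linear in $N$.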

Window-based Self-Attention Complexity

$$\text{Complexity} = O(N \cdot M^2)$$

Here,

  • $N$ = total number of patches
  • $M$ = window size in patches (typically 7, i.e., $7 \times 7$ windows)
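A quick count of query-key scores makes the difference concrete. This sketch assumes square patch grids and counts only attention scores per head per layer; the linear projections cost $O(N)$ in both cases and are left out. The helper names are illustrative.

```python
def global_attention_pairs(N: int) -> int:
    """Every patch attends to every patch: N^2 query-key scores."""
    return N * N

def window_attention_pairs(N: int, M: int) -> int:
    """Each patch attends only to the M*M patches of its own window: N * M^2 scores."""
    return N * M * M

M = 7
for side in (56, 112, 224):                    # patch-grid side length (square grid assumed)
    N = side * side
    g, w = global_attention_pairs(N), window_attention_pairs(N, M)
    print(f"{side}x{side} grid, N={N:,}: global {g:,} vs window {w:,} ({g // w}x fewer)")
```

The ratio equals $N / M^2$, so doubling the grid side quadruples the savings: 64x for a 56x56 grid, 256x for 112x112, and 1024x for 224x224.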

PyTorch Implementation

📝Example: Vision Transformer

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, d_model=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, d_model,
            kernel_size=patch_size, stride=patch_size
        )
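        # Learnable [CLS] token and (N+1) x D position embeddings, initialised
        # with a small truncated normal (std 0.02) rather than unit variance
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, d_model))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)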

    def forward(self, x):
        # x: (batch, 3, 224, 224)
        B = x.shape[0]
        x = self.proj(x)                    # (B, d_model, 14, 14)
        x = x.flatten(2).transpose(1, 2)    # (B, 196, d_model)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 197, d_model)
        x = x + self.pos_embed
        return x


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, int(d_model * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model * mlp_ratio), d_model),
            nn.Dropout(dropout),
        )

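    def forward(self, x):
        # Pre-norm residual blocks: x + Attn(LN(x)), then x + MLP(LN(x))
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # skip computing attention weights
        x = x + self.mlp(self.norm2(x))
        return x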


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, d_model=768, num_heads=12,
                 num_layers=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, d_model
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_output = x[:, 0]
        return self.head(cls_output)


# ViT-B/16
model = VisionTransformer(
    img_size=224, patch_size=16, num_classes=1000,
    d_model=768, num_heads=12, num_layers=12
)
x = torch.randn(2, 3, 224, 224)
print(f"Output: {model(x).shape}")  # Output: torch.Size([2, 1000])
params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Parameters: {params:.1f}M")  # Parameters: 86.6M

Practice Exercises

  1. Patch visualization: Split an image into patches and visualize the embedding space using t-SNE.

  2. Position embedding analysis: Train a ViT and visualize learned positional embeddings. Do they capture spatial structure?

  3. Swin Transformer: Implement shifted window attention. Verify linear complexity on large images.

  4. Data augmentation: Compare ViT performance with and without DeiT-style augmentation on ImageNet-1K.


Key Takeaways

📋Summary: Vision Transformers

  • ViT treats image patches as tokens for Transformer processing
  • Patch embedding: Conv2d with kernel=stride=patch_size
  • Positional encoding: Learned embeddings (not sinusoidal)
  • Weak inductive bias: Requires large datasets to learn locality
  • DeiT: Strong augmentation, regularization, and distillation from a CNN teacher enable training on ImageNet-1K alone
  • Swin Transformer: Hierarchical, window-based, linear complexity
  • ViT vs CNN: ViT wins with enough data; CNNs better with limited data
  • Transfer learning: Pre-trained ViTs transfer well and serve as strong backbones for many vision tasks
  • See also: Transformers for the original Transformer architecture
