1. Training Loop Anatomy
Almost every deep learning model trained by gradient descent, regardless of architecture, follows the same fundamental loop: forward pass → compute loss → backward pass → update parameters.
For each epoch:
    For each batch:
        1. Forward pass:   ŷ = f(x; θ)
        2. Compute loss:   L = Loss(y, ŷ)
        3. Zero gradients: ∇θ ← 0
        4. Backward pass:  ∇θ = ∂L/∂θ (autograd)
        5. Update:         θ ← θ − η · ∇θ
function TrainingLoopSVG() {
return (
<svg viewBox="0 0 720 520" xmlns="http://www.w3.org/2000/svg" fontFamily="monospace">
<defs>
<marker id="arrow" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#6366f1" />
</marker>
<linearGradient id="bg" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#1e1b4b" />
<stop offset="100%" stopColor="#312e81" />
</linearGradient>
<filter id="glow">
<feGaussianBlur stdDeviation="3" result="blur" />
<feMerge><feMergeNode in="blur" /><feMergeNode in="SourceGraphic" /></feMerge>
</filter>
</defs>
<rect width="720" height="520" rx="12" fill="url(#bg)" />
<text x="360" y="35" textAnchor="middle" fill="#a5b4fc" fontSize="16" fontWeight="bold">Training Loop Flowchart</text>
{/* Start */}
<rect x="280" y="55" width="160" height="40" rx="20" fill="#4338ca" stroke="#818cf8" strokeWidth="2" />
<text x="360" y="80" textAnchor="middle" fill="white" fontSize="13">Start Epoch</text>
{/* Data Loading */}
<rect x="280" y="115" width="160" height="40" rx="8" fill="#3730a3" stroke="#818cf8" strokeWidth="1.5" />
<text x="360" y="140" textAnchor="middle" fill="#c7d2fe" fontSize="12">Load Batch (x, y)</text>
{/* Forward */}
<rect x="280" y="175" width="160" height="40" rx="8" fill="#1d4ed8" stroke="#60a5fa" strokeWidth="2" filter="url(#glow)" />
<text x="360" y="200" textAnchor="middle" fill="white" fontSize="13" fontWeight="bold">Forward Pass</text>
<text x="360" y="215" textAnchor="middle" fill="#93c5fd" fontSize="10">y_hat = f(x; theta)</text>
{/* Loss */}
<rect x="280" y="235" width="160" height="40" rx="8" fill="#b91c1c" stroke="#f87171" strokeWidth="2" />
<text x="360" y="260" textAnchor="middle" fill="white" fontSize="13" fontWeight="bold">Compute Loss</text>
<text x="360" y="275" textAnchor="middle" fill="#fca5a5" fontSize="10">L = Loss(y, y_hat)</text>
{/* Zero Grad */}
<rect x="280" y="295" width="160" height="40" rx="8" fill="#3730a3" stroke="#818cf8" strokeWidth="1.5" />
<text x="360" y="320" textAnchor="middle" fill="#c7d2fe" fontSize="12">Zero Gradients</text>
{/* Backward */}
<rect x="280" y="355" width="160" height="40" rx="8" fill="#7c3aed" stroke="#a78bfa" strokeWidth="2" filter="url(#glow)" />
<text x="360" y="380" textAnchor="middle" fill="white" fontSize="13" fontWeight="bold">Backward Pass</text>
<text x="360" y="395" textAnchor="middle" fill="#ddd6fe" fontSize="10">grad_theta = dL/dtheta</text>
{/* Update */}
<rect x="280" y="415" width="160" height="40" rx="8" fill="#047857" stroke="#34d399" strokeWidth="2" />
<text x="360" y="440" textAnchor="middle" fill="white" fontSize="13" fontWeight="bold">Update Weights</text>
<text x="360" y="455" textAnchor="middle" fill="#a7f3d0" fontSize="10">theta = theta - lr * grad</text>
{/* Check epoch */}
<polygon points="360,470 440,505 360,505 280,505" fill="#92400e" stroke="#fbbf24" strokeWidth="1.5" />
<text x="360" y="500" textAnchor="middle" fill="#fef3c7" fontSize="10">More batches?</text>
{/* Arrows */}
<line x1="360" y1="95" x2="360" y2="115" stroke="#818cf8" strokeWidth="1.5" markerEnd="url(#arrow)" />
<line x1="360" y1="155" x2="360" y2="175" stroke="#60a5fa" strokeWidth="1.5" markerEnd="url(#arrow)" />
<line x1="360" y1="215" x2="360" y2="235" stroke="#f87171" strokeWidth="1.5" markerEnd="url(#arrow)" />
<line x1="360" y1="275" x2="360" y2="295" stroke="#818cf8" strokeWidth="1.5" markerEnd="url(#arrow)" />
<line x1="360" y1="335" x2="360" y2="355" stroke="#a78bfa" strokeWidth="1.5" markerEnd="url(#arrow)" />
<line x1="360" y1="395" x2="360" y2="415" stroke="#34d399" strokeWidth="1.5" markerEnd="url(#arrow)" />
<line x1="360" y1="455" x2="360" y2="470" stroke="#fbbf24" strokeWidth="1.5" markerEnd="url(#arrow)" />
{/* Loop back arrow */}
<path d="M 280,505 Q 180,505 180,440 Q 180,145 280,135" fill="none" stroke="#fbbf24" strokeWidth="1.5" strokeDasharray="5,3" markerEnd="url(#arrow)" />
<text x="150" y="320" fill="#fbbf24" fontSize="10" transform="rotate(-90, 150, 320)">next batch</text>
{/* Converge */}
<line x1="440" y1="505" x2="560" y2="505" stroke="#34d399" strokeWidth="1.5" markerEnd="url(#arrow)" />
<rect x="560" y="485" width="120" height="40" rx="20" fill="#047857" stroke="#34d399" strokeWidth="2" />
<text x="620" y="510" textAnchor="middle" fill="white" fontSize="13">Done</text>
</svg>
);
}
The PyTorch Implementation
import torch
import torch.nn as nn

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        # 1. Forward pass
        predictions = model(batch_x)
        # 2. Compute loss
        loss = criterion(predictions, batch_y)
        # 3. Zero gradients
        optimizer.zero_grad()
        # 4. Backward pass
        loss.backward()
        # 5. Update parameters
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
Key subtleties:
- model.train() enables training-mode behavior for dropout and batch normalization
- optimizer.zero_grad() must be called before .backward() because gradients accumulate by default
- .item() extracts the scalar loss as a Python float, detached from the computation graph
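The five steps can also be traced without any framework. The sketch below is an illustration only: it fits a single weight `w` to data drawn from y = 2x, and the data, learning rate, and epoch count are assumed values, not part of the original recipe.

```python
# Framework-free sketch of the five-step loop for the model y_hat = w * x.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # (x, y) pairs from y = 2x (assumed toy data)
w = 0.0     # the single parameter (theta)
lr = 0.05   # learning rate (eta), chosen for illustration

for epoch in range(200):
    for x, y in data:                # each "batch" is one sample here
        y_hat = w * x                # 1. forward pass
        loss = (y_hat - y) ** 2      # 2. compute loss
        grad = 0.0                   # 3. zero the gradient
        grad += 2 * (y_hat - y) * x  # 4. backward pass: dL/dw accumulates into grad
        w -= lr * grad               # 5. update
```

Step 4 accumulates into `grad` on purpose: this mirrors how autograd adds to existing gradients, which is why step 3 has to reset them first.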
2. Loss Functions
Loss functions quantify the mismatch between predictions and targets. The choice of loss function encodes what "good" means for your task.
2.1 Mean Squared Error (MSE)
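For regression tasks with $N$ samples:
$$L_{\text{MSE}} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$
Gradient (per sample):
$$\frac{\partial L_{\text{MSE}}}{\partial \hat{y}_i} = \frac{2}{N} (\hat{y}_i - y_i)$$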
Properties:
- Convex for linear models → every local minimum is a global minimum
- Penalizes large errors quadratically → sensitive to outliers
- Equivalent to maximizing Gaussian log-likelihood with fixed variance
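The outlier sensitivity listed above is easy to see numerically. The following is a minimal sketch with made-up predictions: both prediction sets have the same total absolute error, but the one that concentrates the error in a single sample has a much larger MSE.

```python
def mse(y_true, y_pred):
    """Mean squared error over paired lists of targets and predictions."""
    n = len(y_true)
    return sum((y - p) ** 2 for y, p in zip(y_true, y_pred)) / n

def mse_grad(y_true, y_pred):
    """Gradient of the MSE with respect to each prediction: 2 * (p - y) / n."""
    n = len(y_true)
    return [2 * (p - y) / n for y, p in zip(y_true, y_pred)]

y_true = [1.0, 2.0, 3.0, 4.0]
spread = [1.5, 2.5, 3.5, 4.5]    # every prediction off by 0.5 (total abs error 2.0)
outlier = [1.0, 2.0, 3.0, 6.0]   # one prediction off by 2.0 (total abs error 2.0)

loss_spread = mse(y_true, spread)
loss_outlier = mse(y_true, outlier)
grad_outlier = mse_grad(y_true, outlier)
```

Here `loss_outlier` is four times `loss_spread`, and the whole gradient lands on the one outlying prediction.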
2.2 Cross-Entropy Loss
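For multi-class classification with $C$ classes:
$$L_{\text{CE}} = -\sum_{c=1}^{C} y_c \log \hat{y}_c$$
where $y$ is one-hot encoded and $\hat{y}$ is the softmax of the logits $z$:
$$\hat{y}_c = \frac{e^{z_c}}{\sum_{j=1}^{C} e^{z_j}}$$
The combined CrossEntropyLoss in PyTorch takes raw logits and applies log-softmax followed by NLL loss in a numerically stable way. For true class $k$:
$$L_{\text{CE}} = -z_k + \log \sum_{j=1}^{C} e^{z_j}$$
Numerical stability: The log-sum-exp trick computes $\log \sum_j e^{z_j} = m + \log \sum_j e^{z_j - m}$ where $m = \max_j z_j$.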
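As a rough illustration of the log-sum-exp trick, here is a plain-Python cross-entropy on raw logits. It is a sketch of the idea, not PyTorch's implementation, and the function names and logits are chosen for this example.

```python
import math

def log_sum_exp(z):
    """Stable log(sum(exp(z))): subtract the maximum before exponentiating."""
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z))

def cross_entropy(logits, target):
    """Cross-entropy for one sample, given raw logits and the true class index."""
    return -logits[target] + log_sum_exp(logits)

loss_small = cross_entropy([2.0, 1.0, 0.1], target=0)
# A naive math.exp(1000.0) raises OverflowError; the shifted form does not.
loss_large = cross_entropy([1000.0, 999.0, 0.0], target=0)
```

After subtracting the maximum, every exponent is at most zero, so the largest term is exactly 1 and the sum cannot overflow.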
2.3 Focal Loss
Addressing class imbalance (e.g., object detection where 99% of anchors are background):
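$$\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
where:
- $p_t$ = model's estimated probability for the correct class
- $\gamma$ = focusing parameter (typically 2)
- $\alpha_t$ = class balancing weight
When $\gamma = 0$, focal loss reduces to standard ($\alpha$-weighted) cross-entropy. When $\gamma = 2$, well-classified examples ($p_t = 0.9$) have their loss scaled by $(1 - 0.9)^2 = 0.01$, a $100\times$ reduction.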
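The down-weighting effect can be checked with a few lines. This sketch assumes the standard focal loss form $-\alpha_t (1 - p_t)^\gamma \log p_t$ for a single example; the probabilities are illustrative.

```python
import math

def focal_loss(p_t, gamma=2.0, alpha_t=1.0):
    """Focal loss for one example, given the probability assigned to the true class."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

ce_easy = focal_loss(0.9, gamma=0.0)   # gamma = 0 gives plain cross-entropy
fl_easy = focal_loss(0.9, gamma=2.0)
reduction_easy = ce_easy / fl_easy     # how much the easy example is down-weighted

# A hard example (p_t = 0.1) keeps most of its loss.
reduction_hard = focal_loss(0.1, gamma=0.0) / focal_loss(0.1, gamma=2.0)
```

The easy example is suppressed by roughly two orders of magnitude while the hard one is reduced only slightly, which is how the loss shifts attention to hard examples.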
2.4 Contrastive Loss
For learning embeddings where similar items are close and dissimilar items are far apart:
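$$L = y \, d^2 + (1 - y) \max(0, m - d)^2$$
where $d$ is the Euclidean distance between embeddings, $y = 1$ for similar pairs (and $y = 0$ for dissimilar pairs), and $m$ is the margin.
Triplet Loss (used in FaceNet):
$$L = \max\left(0, \|f(a) - f(p)\|^2 - \|f(a) - f(n)\|^2 + m\right)$$
where $f$ is the embedding function, $a$ = anchor, $p$ = positive (same class), $n$ = negative (different class), $m$ = margin.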
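Both losses can be sketched on plain lists that stand in for embedding vectors. The margins and the toy embeddings below are assumptions for illustration, and the contrastive form used is the common one without a 1/2 factor.

```python
import math

def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def contrastive_loss(u, v, similar, margin=1.0):
    """Pull similar pairs together; push dissimilar pairs at least `margin` apart."""
    d = euclidean(u, v)
    return d ** 2 if similar else max(0.0, margin - d) ** 2

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Require the negative to be farther from the anchor than the positive, by a margin."""
    d_pos = euclidean(anchor, positive) ** 2
    d_neg = euclidean(anchor, negative) ** 2
    return max(0.0, d_pos - d_neg + margin)

a, p, n = [0.0, 0.0], [0.0, 0.5], [3.0, 4.0]   # toy 2-D embeddings
satisfied = triplet_loss(a, p, n)   # negative is already far away
violated = triplet_loss(a, n, p)    # roles swapped: "positive" is farther than "negative"
```

Once a constraint is satisfied the loss is exactly zero, so only violating pairs or triplets contribute gradient.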
Loss Function Selection Guide
| Task | Loss Function | Why |
|---|---|---|
| Regression | MSE, MAE, Huber | MSE for clean data, Huber for outliers |
| Binary classification | BCE, Focal | Focal for imbalanced data |
| Multi-class classification | CrossEntropy, Focal | Focal for long-tailed distributions |
| Metric learning | Contrastive, Triplet | Learn embedding space structure |
| Segmentation | Dice loss, CE+Dice | Handle severe foreground/background imbalance |
| GANs | Adversarial loss | Minimax game between generator and discriminator |
3. Optimizers
3.1 SGD (Stochastic Gradient Descent)
The simplest update rule:
$$\theta_{t+1} = \theta_t - \eta \, \nabla_\theta L(\theta_t)$$
Problem: Oscillates along high-curvature directions, converges slowly along flat directions.
3.2 SGD with Momentum
Adds a velocity term that accumulates past gradients:
$$v_t = \beta \, v_{t-1} + \nabla_\theta L(\theta_t), \qquad \theta_{t+1} = \theta_t - \eta \, v_t$$
Commonly $\beta = 0.9$. Momentum accelerates convergence in consistent gradient directions and dampens oscillations.
Physical analogy: A ball rolling downhill accumulates velocity. $\beta$ controls friction: a lower $\beta$ means more friction.
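A small experiment shows the effect on an ill-conditioned quadratic. This is a sketch that assumes the common update form v ← βv + g, θ ← θ − ηv; the objective, learning rate, and step count are illustrative choices.

```python
def grad(theta):
    x, y = theta
    return (x, 100.0 * y)   # gradient of f(x, y) = 0.5 * (x**2 + 100 * y**2)

def loss(theta):
    x, y = theta
    return 0.5 * (x ** 2 + 100.0 * y ** 2)

def run(beta, lr=0.01, steps=100):
    """Run gradient descent with momentum coefficient beta; return the final loss."""
    theta, v = (1.0, 1.0), (0.0, 0.0)
    for _ in range(steps):
        g = grad(theta)
        v = tuple(beta * vi + gi for vi, gi in zip(v, g))        # accumulate velocity
        theta = tuple(ti - lr * vi for ti, vi in zip(theta, v))  # step along velocity
    return loss(theta)

loss_sgd = run(beta=0.0)        # beta = 0 recovers plain SGD
loss_momentum = run(beta=0.9)
```

With the same learning rate, plain SGD crawls along the flat x direction while the momentum run builds up speed there, so it ends at a much lower loss.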
3.3 RMSProp
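Adapts the learning rate per parameter based on the magnitude of recent gradients:
$$s_t = \rho \, s_{t-1} + (1 - \rho) \, g_t^2, \qquad \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{s_t} + \epsilon} \, g_t$$
where $g_t = \nabla_\theta L(\theta_t)$. Parameters with large gradients get a smaller effective learning rate; parameters with small gradients get a larger one. Typical values: $\rho = 0.9$ (PyTorch's RMSprop defaults to 0.99), $\epsilon = 10^{-8}$.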
3.4 Adam (Adaptive Moment Estimation)
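Combines momentum (first moment) and RMSProp (second moment). With $g_t = \nabla_\theta L(\theta_t)$:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) \, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) \, g_t^2$$
Bias correction (critical in early steps):
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
Update:
$$\theta_{t+1} = \theta_t - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Defaults: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.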
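To see why bias correction matters in early steps, here is a scalar sketch of one standard Adam update. The function name and the gradient value are chosen for illustration.

```python
import math

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g        # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g    # second moment (RMSProp-style)
    m_hat = m / (1 - beta1 ** t)           # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
theta, m, v = adam_step(theta, g=50.0, m=m, v=v, t=1)
first_step = 1.0 - theta   # size of the first update
```

With bias correction the first update has magnitude close to `lr` regardless of the gradient's scale. Without it, the zero-initialized moments would give a step of `lr * m / sqrt(v)`, about three times larger here.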
3.5 AdamW (Adam with Decoupled Weight Decay)
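In Adam, L2 regularization (adding $\frac{\lambda}{2}\|\theta\|^2$ to the loss, which adds $\lambda\theta$ to the gradient) is absorbed into the adaptive learning rate, so the effective weight decay differs from parameter to parameter. AdamW decouples weight decay from the gradient-based update:
$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \, \theta_t \right)$$
where $\hat{m}_t$ and $\hat{v}_t$ are Adam's bias-corrected first and second moment estimates. This makes weight decay consistent across parameters regardless of gradient magnitude. AdamW is the de facto default optimizer for training transformers.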
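The difference between the two forms of decay shows up clearly when the loss gradient is zero. The scalar sketch below is an illustration under that assumption; the helper name and the values of `lr`, `wd`, and `theta` are hypothetical.

```python
import math

def adam_direction(g, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return Adam's normalized update direction and the new moment estimates."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps), m, v

lr, wd = 1e-3, 0.01
theta, loss_grad = 2.0, 0.0   # zero loss gradient isolates the decay term

# Adam + L2: the penalty gradient wd * theta passes through the adaptive scaling.
d, _, _ = adam_direction(loss_grad + wd * theta, m=0.0, v=0.0, t=1)
theta_l2 = theta - lr * d

# AdamW: decay is applied directly to the parameter, outside the adaptive scaling.
d, _, _ = adam_direction(loss_grad, m=0.0, v=0.0, t=1)
theta_adamw = theta - lr * (d + wd * theta)
```

With L2 inside Adam, the normalization rescales the tiny penalty gradient to a full step of about `lr`. AdamW shrinks the weight by exactly `lr * wd * theta`.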
Optimizer Comparison
function OptimizerComparisonSVG() {
return (
<svg viewBox="0 0 720 400" xmlns="http://www.w3.org/2000/svg" fontFamily="monospace">
<defs>
<linearGradient id="bg2" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#0f172a" />
<stop offset="100%" stopColor="#1e293b" />
</linearGradient>
</defs>
<rect width="720" height="400" rx="12" fill="url(#bg2)" />
<text x="360" y="30" textAnchor="middle" fill="#94a3b8" fontSize="14" fontWeight="bold">Optimizer Convergence Comparison</text>
{/* Axes */}
<line x1="80" y1="350" x2="680" y2="350" stroke="#475569" strokeWidth="1.5" />
<line x1="80" y1="350" x2="80" y2="50" stroke="#475569" strokeWidth="1.5" />
<text x="380" y="390" textAnchor="middle" fill="#94a3b8" fontSize="12">Training Steps</text>
<text x="30" y="200" fill="#94a3b8" fontSize="12" transform="rotate(-90,30,200)">Loss</text>
{/* Grid lines */}
{[100,150,200,250,300].map(y => (
<line key={y} x1="80" y1={y} x2="680" y2={y} stroke="#334155" strokeWidth="0.5" strokeDasharray="4,4" />
))}
{/* SGD - slow, oscillating */}
<polyline points="80,340 120,300 160,320 200,280 240,290 280,260 320,265 360,240 400,245 440,225 480,228 520,215 560,218 600,210 640,205 680,200"
fill="none" stroke="#f87171" strokeWidth="2" strokeDasharray="6,3" />
<text x="685" y="200" fill="#f87171" fontSize="11">SGD</text>
{/* SGD+Momentum - faster, smoother */}
<polyline points="80,340 120,280 160,250 200,210 240,185 280,165 320,150 360,138 400,130 440,124 480,120 520,117 560,115 600,113 640,112 680,111"
fill="none" stroke="#fbbf24" strokeWidth="2" />
<text x="685" y="111" fill="#fbbf24" fontSize="11">SGD+M</text>
{/* Adam - fast convergence */}
<polyline points="80,340 120,220 160,160 200,130 240,112 280,100 320,92 360,86 400,82 440,79 480,77 520,76 560,75 600,74 640,74 680,74"
fill="none" stroke="#34d399" strokeWidth="2.5" />
<text x="685" y="74" fill="#34d399" fontSize="11">Adam</text>
{/* AdamW - similar to Adam, slightly better generalization */}
<polyline points="80,340 120,215 160,155 200,125 240,108 280,96 320,88 360,83 400,79 440,76 480,74 520,73 560,72 600,71 640,71 680,70"
fill="none" stroke="#818cf8" strokeWidth="2.5" strokeDasharray="8,4" />
<text x="685" y="70" fill="#818cf8" fontSize="11">AdamW</text>
{/* Legend */}
<rect x="100" y="60" width="180" height="90" rx="6" fill="#1e293b" stroke="#475569" strokeWidth="1" />
<line x1="115" y1="80" x2="145" y2="80" stroke="#f87171" strokeWidth="2" strokeDasharray="6,3" />
<text x="155" y="84" fill="#cbd5e1" fontSize="10">SGD (constant lr)</text>
<line x1="115" y1="100" x2="145" y2="100" stroke="#fbbf24" strokeWidth="2" />
<text x="155" y="104" fill="#cbd5e1" fontSize="10">SGD + Momentum</text>
<line x1="115" y1="120" x2="145" y2="120" stroke="#34d399" strokeWidth="2.5" />
<text x="155" y="124" fill="#cbd5e1" fontSize="10">Adam / AdamW</text>
</svg>
);
}
Optimizer Selection Decision Tree
Is your model a transformer (or another LayerNorm-based architecture)?
├─ Yes → AdamW (lr=3e-4, weight_decay=0.01)
└─ No
├─ Computer Vision (CNN)?
│ ├─ Yes → SGD+Momentum (lr=0.1, momentum=0.9) with cosine schedule
│ └─ No
│ ├─ Reinforcement Learning? → Adam (lr=3e-4)
│ └─ General deep learning? → Start with Adam, try SGD if generalization gap
4. Learning Rate Schedules
The learning rate is often the single most important hyperparameter. A fixed learning rate is rarely optimal: you want large steps early (fast progress) and small steps later (fine-tuning).
4.1 Step Decay
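Reduce the learning rate by a factor $\gamma$ every $s$ epochs:
$$\eta_t = \eta_0 \cdot \gamma^{\lfloor t / s \rfloor}$$
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# lr is multiplied by 0.1 at epoch 30 and again at epoch 60 (1% of the initial lr from epoch 60 on)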
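The step schedule can be reproduced by hand. This sketch assumes the closed form lr = base_lr · gamma^(epoch // step_size) with the same `step_size` and `gamma` as the StepLR call above; `base_lr=0.1` is an illustrative value.

```python
def step_decay_lr(epoch, base_lr=0.1, step_size=30, gamma=0.1):
    """Learning rate after `epoch` epochs under step decay."""
    return base_lr * gamma ** (epoch // step_size)

# Sample the schedule just before and at each drop.
lrs = {epoch: step_decay_lr(epoch) for epoch in (0, 29, 30, 59, 60)}
```

The rate stays flat within each 30-epoch window and drops by a factor of 10 at each boundary.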
4.2 Cosine Annealing
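Smoothly anneal from $\eta_{\max}$ to $\eta_{\min}$ over $T$ steps following a cosine curve:
$$\eta_t = \eta_{\min} + \frac{1}{2} (\eta_{\max} - \eta_{\min}) \left(1 + \cos\frac{\pi t}{T}\right)$$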
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
Cosine annealing is a common default schedule in modern training pipelines. It decays gently at both ends and fastest in the middle.
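The shape of the curve can be checked with a direct implementation of the cosine formula. This is a sketch; `lr_max=0.1` is an illustrative value, while `T=100` and `lr_min=1e-6` mirror the scheduler call above.

```python
import math

def cosine_lr(t, T, lr_max=0.1, lr_min=1e-6):
    """Cosine-annealed learning rate at step t out of T."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / T))

early_drop = cosine_lr(0, 100) - cosine_lr(10, 100)   # decay over the first 10 steps
mid_drop = cosine_lr(45, 100) - cosine_lr(55, 100)    # decay over the middle 10 steps
```

Comparing `early_drop` with `mid_drop` shows that the schedule moves slowly near the start and much faster around the midpoint.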
4.3 Warmup + Cosine
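Linearly increase the learning rate from 0 to $\eta_{\max}$ over $T_w$ warmup steps, then cosine anneal over the remaining $T - T_w$ steps:
$$\eta_t = \begin{cases} \eta_{\max} \, \dfrac{t}{T_w} & t < T_w \\ \dfrac{\eta_{\max}}{2} \left(1 + \cos\dfrac{\pi (t - T_w)}{T - T_w}\right) & t \ge T_w \end{cases}$$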
Warmup is essential for transformers — training is unstable in early steps when parameters are random and adaptive optimizers have unreliable second-moment estimates.
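A warmup-plus-cosine schedule takes only a few lines. The sketch below assumes linear warmup from 0 and cosine decay to 0; the step counts and peak rate are illustrative.

```python
import math

def warmup_cosine_lr(step, warmup_steps, total_steps, lr_max):
    """Linear warmup to lr_max, then cosine decay to zero."""
    if step < warmup_steps:
        return lr_max * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * lr_max * (1 + math.cos(math.pi * progress))

# Start, mid-warmup, peak, halfway through the decay, and the final step.
schedule = [warmup_cosine_lr(s, 1000, 100000, 3e-4)
            for s in (0, 500, 1000, 50500, 100000)]
```

The rate peaks exactly at the end of warmup and reaches half of the peak halfway through the cosine phase.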
4.4 OneCycle Policy
Cycles the learning rate from low → high → low within a single cycle, with momentum going in reverse (high → low → high).
The schedule is a function of pct, the fraction of training completed, which goes from 0 to 1 over the total training steps. Smith (2018) showed this can converge in fewer epochs than standard schedules.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=0.01, total_steps=total_training_steps
)
Learning Rate Schedule Visualization
function LRScheduleSVG() {
return (
<svg viewBox="0 0 720 380" xmlns="http://www.w3.org/2000/svg" fontFamily="monospace">
<defs>
<linearGradient id="bg3" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#0c0a09" />
<stop offset="100%" stopColor="#1c1917" />
</linearGradient>
</defs>
<rect width="720" height="380" rx="12" fill="url(#bg3)" />
<text x="360" y="25" textAnchor="middle" fill="#a8a29e" fontSize="14" fontWeight="bold">Learning Rate Schedules</text>
{/* Axes */}
<line x1="80" y1="330" x2="680" y2="330" stroke="#57534e" strokeWidth="1.5" />
<line x1="80" y1="330" x2="80" y2="50" stroke="#57534e" strokeWidth="1.5" />
<text x="380" y="370" textAnchor="middle" fill="#a8a29e" fontSize="11">Epoch</text>
<text x="25" y="190" fill="#a8a29e" fontSize="11" transform="rotate(-90,25,190)">Learning Rate</text>
{/* Step Decay */}
<polyline points="80,80 200,80 200,160 320,160 320,240 440,240 440,290 560,290 560,310 680,310"
fill="none" stroke="#f87171" strokeWidth="2" />
<text x="690" y="310" fill="#f87171" fontSize="10">Step</text>
{/* Cosine */}
<path d="M80,80 Q200,85 280,140 Q360,220 440,280 Q520,315 600,325 Q640,328 680,329"
fill="none" stroke="#34d399" strokeWidth="2.5" />
<text x="690" y="325" fill="#34d399" fontSize="10">Cosine</text>
{/* Warmup + Cosine */}
<path d="M80,329 Q110,329 130,280 Q150,220 170,120 Q190,80 200,80 Q280,90 360,160 Q440,240 520,300 Q600,325 680,329"
fill="none" stroke="#818cf8" strokeWidth="2.5" strokeDasharray="8,3" />
<text x="690" y="329" fill="#818cf8" fontSize="10">Warmup+Cos</text>
{/* OneCycle */}
<path d="M80,300 Q160,250 240,100 Q280,60 320,60 Q360,60 400,100 Q480,250 560,310 Q620,325 680,328"
fill="none" stroke="#fbbf24" strokeWidth="2" />
<text x="690" y="318" fill="#fbbf24" fontSize="10">OneCycle</text>
{/* Warmup region */}
<rect x="80" y="50" width="100" height="280" fill="#818cf8" opacity="0.05" />
<text x="130" y="345" textAnchor="middle" fill="#818cf8" fontSize="9">warmup</text>
</svg>
);
}
5. Regularization
Overfitting occurs when the model memorizes training data rather than learning generalizable patterns. Regularization techniques combat this.
5.1 Dropout
During training, each neuron is independently set to zero with probability $p$:
$$\tilde{h}_i = \frac{m_i \, h_i}{1 - p}, \qquad m_i \sim \text{Bernoulli}(1 - p)$$
The $\frac{1}{1 - p}$ scaling (inverted dropout) keeps the expected activation during training equal to the unscaled activation used at test time, where no dropout is applied.
Intuition: Dropout forces the network to learn redundant representations, since no single neuron can be relied upon. It can be interpreted as training an ensemble of $2^n$ sub-networks (where $n$ is the number of neurons).
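The expectation-preserving property of inverted dropout can be checked empirically. This is a plain-Python sketch, not a framework implementation; the function name, seed, and sizes are chosen for illustration.

```python
import random

def inverted_dropout(h, p, training, rng):
    """Drop each activation with probability p; rescale survivors by 1 / (1 - p)."""
    if not training:
        return list(h)   # inference: identity, no rescaling needed
    keep = 1.0 - p
    return [x / keep if rng.random() < keep else 0.0 for x in h]

rng = random.Random(0)   # fixed seed for reproducibility
h = [1.0] * 100_000
train_out = inverted_dropout(h, p=0.5, training=True, rng=rng)
mean_train = sum(train_out) / len(train_out)   # close to 1.0 in expectation
eval_out = inverted_dropout(h, p=0.5, training=False, rng=rng)
```

Each surviving activation is doubled when p = 0.5, so the mean stays near the original value and the inference path can pass activations through unchanged.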
function DropoutSVG() {
return (
<svg viewBox="0 0 720 300" xmlns="http://www.w3.org/2000/svg" fontFamily="monospace">
<defs>
<linearGradient id="bg4" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#0f172a" />
<stop offset="100%" stopColor="#1e293b" />
</linearGradient>
</defs>
<rect width="720" height="300" rx="12" fill="url(#bg4)" />
<text x="360" y="25" textAnchor="middle" fill="#94a3b8" fontSize="14" fontWeight="bold">Dropout Visualization (p = 0.5)</text>
{/* Training side */}
<text x="180" y="55" textAnchor="middle" fill="#60a5fa" fontSize="12" fontWeight="bold">Training</text>
{/* Input layer */}
{[80, 120, 160, 200, 240].map((y, i) => (
<g key={`in-${i}`}>
<circle cx="100" cy={y} r="15" fill="#334155" stroke="#60a5fa" strokeWidth="1.5" />
<text x="100" y={y+4} textAnchor="middle" fill="#e2e8f0" fontSize="9">x{i+1}</text>
</g>
))}
{/* Hidden layer - with dropout */}
{[80, 120, 160, 200, 240].map((y, i) => {
const dropped = [1, 3].includes(i);
return (
<g key={`h-${i}`}>
<circle cx="280" cy={y} r="15"
fill={dropped ? "#1e293b" : "#1d4ed8"}
stroke={dropped ? "#475569" : "#60a5fa"}
strokeWidth={dropped ? "1" : "1.5"}
opacity={dropped ? "0.4" : "1"} />
{!dropped && <text x="280" y={y+4} textAnchor="middle" fill="white" fontSize="9">h{i+1}</text>}
{dropped && <line x1="268" y1={y-8} x2="292" y2={y+8} stroke="#f87171" strokeWidth="2" />}
{dropped && <line x1="292" y1={y-8} x2="268" y2={y+8} stroke="#f87171" strokeWidth="2" />}
</g>
);
})}
{/* Output */}
<circle cx="400" cy="160" r="15" fill="#047857" stroke="#34d399" strokeWidth="1.5" />
<text x="400" y="164" textAnchor="middle" fill="white" fontSize="9">y</text>
{/* Connections - active only (h2 and h4 are dropped) */}
{[80, 160, 240].map((y, i) => (
<line key={`c1-${i}`} x1="115" y1={y} x2="265" y2={y} stroke="#475569" strokeWidth="0.8" opacity="0.5" />
))}
<line x1="295" y1="80" x2="385" y2="160" stroke="#34d399" strokeWidth="1" />
<line x1="295" y1="160" x2="385" y2="160" stroke="#34d399" strokeWidth="1" />
<line x1="295" y1="240" x2="385" y2="160" stroke="#34d399" strokeWidth="1" />
{/* Test side */}
<text x="560" y="55" textAnchor="middle" fill="#fbbf24" fontSize="12" fontWeight="bold">Inference</text>
{[80, 120, 160, 200, 240].map((y, i) => (
<g key={`in2-${i}`}>
<circle cx="480" cy={y} r="15" fill="#334155" stroke="#fbbf24" strokeWidth="1.5" />
<text x="480" y={y+4} textAnchor="middle" fill="#e2e8f0" fontSize="9">x{i+1}</text>
</g>
))}
{/* Hidden - no dropout; inverted dropout needs no rescaling at inference */}
{[80, 120, 160, 200, 240].map((y, i) => (
<g key={`h2-${i}`}>
<circle cx="620" cy={y} r="15" fill="#3730a3" stroke="#fbbf24" strokeWidth="1.5" />
<text x="620" y={y+4} textAnchor="middle" fill="white" fontSize="9">h{i+1}</text>
</g>
))}
{[80, 120, 160, 200, 240].map((y, i) => (
<line key={`c2-${i}`} x1="495" y1={y} x2="605" y2={y} stroke="#fbbf24" strokeWidth="0.8" opacity="0.5" />
))}
</svg>
);
}
5.2 Batch Normalization
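Normalizes activations across the batch dimension for each feature:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma \, \hat{x}_i + \beta$$
where $\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i$ and $\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2$ over the mini-batch of size $m$.
$\gamma$ and $\beta$ are learnable parameters that allow the network to undo the normalization if needed.
During inference: Use running averages of $\mu_B$ and $\sigma_B^2$ accumulated during training (via exponential moving average).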
Benefits:
- Allows higher learning rates
- Reduces sensitivity to initialization
- Provides mild regularization (batch statistics add noise)
Limitation: Depends on batch statistics → problematic for small batches, sequence models, or distributed training where each GPU sees only a small local batch.
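The training-mode forward pass of batch normalization can be written out on nested lists. This sketch omits running statistics and the backward pass; the toy batch and `eps=1e-5` are assumptions.

```python
import math

def batch_norm(batch, gamma, beta, eps=1e-5):
    """Normalize each feature across the batch. `batch` is a list of samples."""
    m, d = len(batch), len(batch[0])
    out = [[0.0] * d for _ in range(m)]
    for j in range(d):
        col = [row[j] for row in batch]
        mu = sum(col) / m
        var = sum((x - mu) ** 2 for x in col) / m   # biased variance over the batch
        for i in range(m):
            out[i][j] = gamma[j] * (batch[i][j] - mu) / math.sqrt(var + eps) + beta[j]
    return out

batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]   # 4 samples, 2 features
normed = batch_norm(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
col0 = [row[0] for row in normed]
col1 = [row[1] for row in normed]
```

Although the two input features differ in scale by a factor of 10, both normalized columns come out with mean 0 and variance close to 1.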
5.3 Layer Normalization
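Normalizes across the feature dimension for each sample (independent of batch size):
$$\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad \mu = \frac{1}{d} \sum_{j=1}^{d} x_j, \qquad \sigma^2 = \frac{1}{d} \sum_{j=1}^{d} (x_j - \mu)^2$$
where $d$ is the number of features of the sample. As with batch normalization, a learnable scale $\gamma$ and shift $\beta$ are applied afterwards.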
function NormComparisonSVG() {
return (
<svg viewBox="0 0 720 350" xmlns="http://www.w3.org/2000/svg" fontFamily="monospace">
<defs>
<linearGradient id="bg5" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#0c0a09" />
<stop offset="100%" stopColor="#1c1917" />
</linearGradient>
</defs>
<rect width="720" height="350" rx="12" fill="url(#bg5)" />
<text x="360" y="25" textAnchor="middle" fill="#a8a29e" fontSize="14" fontWeight="bold">BatchNorm vs LayerNorm</text>
{/* BatchNorm side */}
<text x="180" y="55" textAnchor="middle" fill="#60a5fa" fontSize="13" fontWeight="bold">BatchNorm</text>
<text x="180" y="72" textAnchor="middle" fill="#94a3b8" fontSize="10">Normalizes across batch (feature-wise)</text>
{/* Batch as grid */}
{/* Highlight one feature (f2) across all samples: the axis BatchNorm normalizes over */}
{[0,1,2,3].map(b => (
<g key={`batch-${b}`}>
<rect x={100 + b*50} y={95} width="40" height="120" rx="4" fill="#334155" stroke="#475569" strokeWidth="1" opacity="0.6" />
<text x={120 + b*50} y={90} textAnchor="middle" fill="#94a3b8" fontSize="9">Sample {b+1}</text>
{[0,1,2].map(f => (
<rect key={f} x={105 + b*50} y={100 + f*38} width="30" height="30" rx="3"
fill={f === 1 ? "#2563eb" : "#1e293b"} stroke={f === 1 ? "#60a5fa" : "#334155"} strokeWidth="1" />
))}
</g>
))}
{/* Arrow showing normalization direction */}
<line x1="120" y1="225" x2="260" y2="225" stroke="#60a5fa" strokeWidth="2" markerEnd="url(#arrow)" />
<text x="190" y="245" textAnchor="middle" fill="#60a5fa" fontSize="10">normalize along this axis</text>
{/* Feature labels */}
<text x="85" y="118" fill="#94a3b8" fontSize="8" textAnchor="end">f1</text>
<text x="85" y="156" fill="#94a3b8" fontSize="8" textAnchor="end">f2</text>
<text x="85" y="194" fill="#94a3b8" fontSize="8" textAnchor="end">f3</text>
{/* BN result */}
<text x="190" y="275" textAnchor="middle" fill="#94a3b8" fontSize="10">μ, σ² per feature across batch</text>
<text x="190" y="295" textAnchor="middle" fill="#f87171" fontSize="10">{"Needs batch_size > 1"}</text>
{/* LayerNorm side */}
<text x="540" y="55" textAnchor="middle" fill="#34d399" fontSize="13" fontWeight="bold">LayerNorm</text>
<text x="540" y="72" textAnchor="middle" fill="#94a3b8" fontSize="10">Normalizes across features (sample-wise)</text>
{[0,1,2,3].map(b => (
<g key={`ln-${b}`}>
<rect x={460 + b*50} y={95} width="40" height="120" rx="4"
fill={b === 2 ? "#047857" : "#334155"} stroke="#475569" strokeWidth="1" opacity={b === 2 ? 1 : 0.6} />
<text x={480 + b*50} y={90} textAnchor="middle" fill="#94a3b8" fontSize="9">Sample {b+1}</text>
{[0,1,2].map(f => (
<rect key={f} x={465 + b*50} y={100 + f*38} width="30" height="30" rx="3"
fill={b === 2 ? "#059669" : "#1e293b"} stroke={b === 2 ? "#34d399" : "#334155"} strokeWidth="1" />
))}
</g>
))}
{/* Arrow showing normalization direction */}
<line x1="480" y1="225" x2="480" y2="250" stroke="#34d399" strokeWidth="2" markerEnd="url(#arrow)" />
<text x="540" y="245" textAnchor="middle" fill="#34d399" fontSize="10">normalize along this axis</text>
<text x="540" y="275" textAnchor="middle" fill="#94a3b8" fontSize="10">μ, σ² per sample across features</text>
<text x="540" y="295" textAnchor="middle" fill="#34d399" fontSize="10">Batch-size independent</text>
{/* Bottom comparison */}
<rect x="100" y="310" width="520" height="30" rx="6" fill="#1e293b" stroke="#475569" strokeWidth="1" />
<text x="150" y="330" fill="#60a5fa" fontSize="10">CNNs, ResNets</text>
<text x="300" y="330" fill="#a8a29e" fontSize="10">|</text>
<text x="400" y="330" fill="#34d399" fontSize="10">Transformers, RNNs, Small batches</text>
</svg>
);
}
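The batch-size independence of layer normalization can be verified directly: a sample is normalized using only its own features, so the result is the same alone or inside a batch. The sketch below omits the learnable scale and shift; `eps=1e-5` and the inputs are illustrative.

```python
import math

def layer_norm(sample, eps=1e-5):
    """Normalize one sample across its own features."""
    d = len(sample)
    mu = sum(sample) / d
    var = sum((x - mu) ** 2 for x in sample) / d
    return [(x - mu) / math.sqrt(var + eps) for x in sample]

alone = layer_norm([1.0, 2.0, 3.0])
# The same sample processed next to a very different one gives the same result.
in_batch = [layer_norm(s) for s in [[1.0, 2.0, 3.0], [100.0, -50.0, 7.0]]]
```

Because no statistic crosses sample boundaries, the output does not depend on batch size or on which other samples share the batch.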
5.4 Weight Decay
Adds an L2 penalty to the loss:
$$L_{\text{total}} = L + \frac{\lambda}{2} \|\theta\|^2$$
This pushes weights toward zero, preventing any single weight from growing too large. Typical values: $\lambda$ between $10^{-4}$ and $10^{-2}$.
With AdamW, weight decay is applied directly to parameters without going through the adaptive learning rate, making it more effective than L2 regularization with Adam.
6. Gradient Clipping
Exploding gradients cause numerical instability. Gradient clipping bounds the gradient norm.
Norm Clipping (recommended)
$$\hat{g} = g \cdot \min\left(1, \frac{\tau}{\|g\|}\right)$$
This preserves the gradient direction while limiting its magnitude to at most $\tau$.
Value Clipping
$$\hat{g}_i = \max(-c, \min(c, g_i))$$
Clips each gradient component independently to $[-c, c]$. This can change the gradient direction and is less preferred.
# Norm clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Value clipping
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
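The two clipping modes can be compared on plain lists. This is a simplified sketch of the math, not the PyTorch implementation, and the vectors are illustrative.

```python
import math

def clip_by_norm(g, max_norm):
    """Rescale g so its L2 norm is at most max_norm; the direction is unchanged."""
    norm = math.sqrt(sum(x * x for x in g))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in g]

def clip_by_value(g, clip_value):
    """Clamp each component to [-clip_value, clip_value] independently."""
    return [max(-clip_value, min(clip_value, x)) for x in g]

g = [3.0, 4.0]                               # L2 norm is 5.0
by_norm = clip_by_norm(g, 1.0)               # same 3:4 ratio, norm 1.0
by_value = clip_by_value([-3.0, 8.0], 5.0)   # only the second component is clamped
```

Norm clipping keeps the ratio between components, while value clipping changes it whenever only some components exceed the threshold.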
function GradClippingSVG() {
return (
<svg viewBox="0 0 720 320" xmlns="http://www.w3.org/2000/svg" fontFamily="monospace">
<defs>
<linearGradient id="bg6" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#0f172a" />
<stop offset="100%" stopColor="#1e293b" />
</linearGradient>
</defs>
<rect width="720" height="320" rx="12" fill="url(#bg6)" />
<text x="360" y="25" textAnchor="middle" fill="#94a3b8" fontSize="14" fontWeight="bold">Gradient Clipping (max_norm = τ)</text>
{/* Circle representing gradient norm threshold */}
<circle cx="200" cy="180" r="100" fill="none" stroke="#475569" strokeWidth="1" strokeDasharray="5,5" />
<text x="305" y="180" fill="#475569" fontSize="11">||g|| = τ</text>
{/* Unclipped gradient */}
<line x1="200" y1="180" x2="280" y2="90" stroke="#f87171" strokeWidth="2" />
<circle cx="280" cy="90" r="4" fill="#f87171" />
<text x="290" y="88" fill="#f87171" fontSize="10">{"g (||g|| > τ)"}</text>
{/* Clipped gradient */}
<line x1="200" y1="180" x2="266" y2="105" stroke="#34d399" strokeWidth="2.5" />
<circle cx="266" cy="105" r="5" fill="#34d399" />
<text x="276" y="104" fill="#34d399" fontSize="10">ĝ = (τ/||g||)·g</text>
{/* Small gradient (no clipping) */}
<line x1="200" y1="180" x2="160" y2="130" stroke="#fbbf24" strokeWidth="2" />
<circle cx="160" cy="130" r="4" fill="#fbbf24" />
<text x="115" y="125" fill="#fbbf24" fontSize="10">{"g (||g|| < τ)"}</text>
{/* Right side: Before/After */}
<text x="520" y="60" textAnchor="middle" fill="#94a3b8" fontSize="12" fontWeight="bold">Before Clipping</text>
<text x="520" y="80" textAnchor="middle" fill="#f87171" fontSize="11">gradients = [2.5, -4.0, 8.0, 1.0]</text>
<text x="520" y="100" textAnchor="middle" fill="#f87171" fontSize="11">||g|| = 9.34</text>
<text x="520" y="150" textAnchor="middle" fill="#94a3b8" fontSize="12" fontWeight="bold">After Clipping (τ=5.0)</text>
<text x="520" y="170" textAnchor="middle" fill="#34d399" fontSize="11">gradients = [1.34, -2.14, 4.28, 0.54]</text>
<text x="520" y="190" textAnchor="middle" fill="#34d399" fontSize="11">||ĝ|| = 5.0 (direction preserved)</text>
{/* Note */}
<rect x="370" y="220" width="300" height="60" rx="6" fill="#1e293b" stroke="#475569" strokeWidth="1" />
<text x="520" y="242" textAnchor="middle" fill="#94a3b8" fontSize="10">Value clipping: clip each component independently</text>
<text x="520" y="258" textAnchor="middle" fill="#f87171" fontSize="10">clip([-3, 8], -5, 5) = [-3, 5] (changes direction!)</text>
<text x="520" y="274" textAnchor="middle" fill="#34d399" fontSize="10">Norm clipping preserves direction ✓</text>
</svg>
);
}
When to use gradient clipping:
- Training RNNs/LSTMs (almost always needed)
- Training transformers (especially with large learning rates)
- Large batch training where gradient norms can spike
- Any situation with loss divergence or NaN losses
7. Mixed Precision Training
Uses 16-bit floating point (FP16) for most computations while keeping a 32-bit (FP32) master copy of weights.
Why Mixed Precision?
| Metric | FP32 | FP16 | Speedup |
|---|---|---|---|
| Memory | 4 bytes | 2 bytes | 2× less memory |
| Compute (A100) | 19.5 TFLOPS | 312 TFLOPS | ~16× (with Tensor Cores) |
| Bandwidth | 2 TB/s | 2 TB/s | Same (but less data) |
The Problem: Loss Scaling
FP16 has a much smaller range (largest finite value 65,504; smallest positive normal value about $6 \times 10^{-5}$) and lower precision (about 3 decimal digits) than FP32. Small gradients can underflow to zero. Solution: loss scaling. Multiply the loss by a large factor (e.g., 1024), compute gradients in this scaled space, then unscale before the update.
scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    predictions = model(batch_x)
    loss = criterion(predictions, batch_y)
scaler.scale(loss).backward()   # backward pass on the scaled loss
scaler.unscale_(optimizer)      # divide gradients by the scale factor so clipping sees true magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)          # skips the update if gradients contain inf/nan
scaler.update()                 # adjust the scale factor for the next iteration
Dynamic loss scaling: The GradScaler starts with a large scale factor (65,536 by default) and halves it whenever inf or nan gradients are detected, then doubles it after a run of consecutive stable steps (2,000 by default).
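The underflow problem and the effect of scaling can be reproduced with the standard library alone, since Python's `struct` module can round a float to IEEE half precision with the `"e"` format. This sketch uses a single number in place of a gradient tensor; the value `1e-8` is an illustrative choice.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

grad = 1e-8      # a small gradient value (illustrative)
scale = 1024.0   # loss scale factor, as in the example in the text

unscaled = to_fp16(grad)         # too small for FP16: rounds to zero
scaled = to_fp16(grad * scale)   # representable after scaling
recovered = scaled / scale       # unscale in higher precision
```

The unscaled value is lost entirely, while the scaled one survives the FP16 round trip and is recovered to within a fraction of a percent.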
BFloat16 Alternative
BFloat16 uses 8 exponent bits (same range as FP32) and 7 mantissa bits. Loss scaling is usually unnecessary, but BF16 is less precise than FP16 (7 vs. 10 mantissa bits). Preferred on Ampere and newer GPUs.
8. Distributed Training
Data Parallelism
The most common strategy: replicate the model on $N$ GPUs, split the batch across them, and average gradients:
# PyTorch DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
Gradient synchronization: All-reduce communicates gradients across GPUs, using the NCCL backend (NVIDIA GPUs) or Gloo (CPU). DDP overlaps computation and communication: it starts reducing gradients for layer $l$ while the backward pass is still computing gradients for layer $l - 1$.
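Why averaging per-GPU gradients works can be shown without any GPUs. The sketch below simulates two replicas with equal-sized shards; `all_reduce_mean` is a hypothetical stand-in for a real all-reduce, and the data and model are toy assumptions.

```python
def grad_mse(w, shard):
    """Gradient of the mean squared error of y_hat = w * x over one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for an all-reduce that averages one value per replica."""
    return sum(values) / len(values)

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
shards = [data[0:2], data[2:4]]   # two simulated GPUs with equal shard sizes
w = 0.5

local_grads = [grad_mse(w, shard) for shard in shards]   # computed independently
synced = all_reduce_mean(local_grads)                    # every replica gets this value
full_batch = grad_mse(w, data)                           # single-device reference
```

With equal shard sizes the averaged gradient matches the full-batch gradient, so every replica applies the same update and the model copies stay in sync.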
Model Parallelism
When the model is too large to fit on one GPU:
- Pipeline parallelism: Split model layers across GPUs, micro-batch the pipeline
- Tensor parallelism: Split individual operations (e.g., attention heads) across GPUs
- ZeRO (DeepSpeed): Shard optimizer states, gradients, and parameters across data-parallel GPUs instead of replicating them
Training at Scale
Total batch size = num_GPUs × per_gpu_batch_size × gradient_accumulation_steps
Example: 8 GPUs × 32 samples × 4 accumulations = 1024 effective batch size
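Large batch training requires adjusting the learning rate (linear scaling rule) and using warmup:
$$\eta = \eta_{\text{base}} \times \frac{B}{B_{\text{base}}}$$
where $B$ is the total batch size and $\eta_{\text{base}}$ is the learning rate tuned for the reference batch size $B_{\text{base}}$.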
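The batch-size arithmetic and the linear scaling rule can be combined in a short helper. This is a sketch: the base learning rate of 0.1 at batch size 256 is an assumed reference point for illustration.

```python
def effective_batch_size(num_gpus, per_gpu_batch_size, grad_accum_steps):
    """Total number of samples contributing to one optimizer step."""
    return num_gpus * per_gpu_batch_size * grad_accum_steps

def scaled_lr(base_lr, base_batch_size, batch_size):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return base_lr * batch_size / base_batch_size

batch = effective_batch_size(num_gpus=8, per_gpu_batch_size=32, grad_accum_steps=4)
lr = scaled_lr(base_lr=0.1, base_batch_size=256, batch_size=batch)
```

Quadrupling the batch size from 256 to 1024 quadruples the learning rate, which is one reason warmup is typically paired with this rule.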
Putting It All Together: A Modern Training Recipe
import math
import torch

# 1. Model
model = YourModel().to(device)

# 2. Optimizer (AdamW for transformers, SGD for CNNs)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# 3. Schedule (Warmup + Cosine)
warmup_steps = 1000
total_steps = 100000

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# 4. Loss
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# 5. Mixed precision
scaler = torch.cuda.amp.GradScaler()

# 6. Training loop
model.train()
data_iter = iter(train_loader)  # a DataLoader is iterable, not an iterator
for step in range(total_steps):
    try:
        batch_x, batch_y = next(data_iter)
    except StopIteration:  # end of epoch: restart the loader
        data_iter = iter(train_loader)
        batch_x, batch_y = next(data_iter)
    batch_x, batch_y = batch_x.to(device), batch_y.to(device)
    with torch.cuda.amp.autocast():
        loss = criterion(model(batch_x), batch_y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    scheduler.step()
Hyperparameter Defaults (Common Starting Points)
| Hyperparameter | Transformer | CNN (ResNet) |
|---|---|---|
| Optimizer | AdamW | SGD + Momentum |
| Learning rate | 3e-4 | 0.1 |
| Weight decay | 0.01 | 1e-4 |
| Batch size | 256–2048 | 256 |
| Warmup steps | 2000–4000 | — |
| Schedule | Cosine | Cosine |
| Gradient clip | 1.0 | None |
| Dropout | 0.1 | 0.2 |
| Label smoothing | 0.1 | — |
Summary
| Concept | Key Takeaway |
|---|---|
| Training loop | Forward → loss → zero_grad → backward → step |
| MSE | Regression; penalizes large errors quadratically |
| Cross-entropy | Classification; combined with log-softmax |
| Focal loss | Handles class imbalance via $(1 - p_t)^{\gamma}$ weighting |
| Contrastive/triplet loss | Learn embedding spaces |
| SGD + Momentum | Best for CNNs; fast convergence with proper schedule |
| Adam/AdamW | Best for transformers; adaptive per-parameter lr |
| Cosine annealing | Smooth decay; default schedule in modern training |
| Warmup | Essential for transformers; stabilizes early training |
| Dropout | Ensembles sub-networks; inverted dropout scales by $1/(1-p)$ during training, so no scaling is needed at test time |
| BatchNorm | Normalize across batch; use in CNNs |
| LayerNorm | Normalize across features; use in transformers |
| Gradient clipping | Clip norm to prevent exploding gradients |
| Mixed precision | FP16 (with loss scaling) or BF16 for 2-4× speedup |
| Distributed training | DDP for data parallelism; ZeRO sharding or pipeline/tensor parallelism when the model does not fit on one GPU |