Training Loops, Loss, Optimizers
Introduction
Training neural networks requires careful selection of loss functions, optimizers, and training strategies. This module covers the essential components that determine how well and how fast your model learns.
Training Loop Overview:
═══════════════════════════════════════════════════════════════════
┌─────────────────────────────────────────────────────────────────┐
│                          TRAINING LOOP                          │
│                                                                 │
│   ┌─────────┐    ┌──────────┐    ┌──────────┐    ┌─────────┐    │
│   │ Forward │───►│ Compute  │───►│ Backward │───►│ Update  │    │
│   │  Pass   │    │   Loss   │    │   Pass   │    │ Weights │    │
│   └─────────┘    └──────────┘    └──────────┘    └─────────┘    │
│        ▲                                              │         │
│        │                                              │         │
│        └──────────────────────────────────────────────┘         │
│                       Repeat for N epochs                       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
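Each stage in the diagram maps onto one PyTorch call. The sketch below makes two assumptions: a made-up linear-regression problem (y = 3x + 1 plus noise) and full-batch training, so each epoch here is a single update step. Real training would iterate over mini-batches from a `DataLoader` inside each epoch.

```python
import torch
import torch.nn as nn

# Assumed toy data for illustration: y = 3x + 1 plus a little noise
torch.manual_seed(0)
X = torch.randn(256, 1)
y = 3 * X + 1 + 0.1 * torch.randn(256, 1)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

num_epochs = 100
losses = []
for epoch in range(num_epochs):
    optimizer.zero_grad()          # clear gradients left over from the previous step
    y_pred = model(X)              # 1. forward pass
    loss = criterion(y_pred, y)    # 2. compute loss
    loss.backward()                # 3. backward pass: fills .grad of every parameter
    optimizer.step()               # 4. update weights
    losses.append(loss.item())

print(f"first loss={losses[0]:.4f}, last loss={losses[-1]:.4f}")
print(f"weight={model.weight.item():.3f}, bias={model.bias.item():.3f}")
```

Calling `optimizer.zero_grad()` at the start of each step matters because `backward()` accumulates into `.grad` rather than overwriting it.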
Optimizer Updates:
──────────────────
w_new = w_old - learning_rate × gradient
Optimizers differ in what they use in place of the raw "gradient" term:
• SGD: the raw gradient itself
• Adam: a momentum-style running mean of gradients, divided by the square root of a running mean of squared gradients (per-parameter adaptive learning rates)
• RMSprop: the gradient divided by the square root of a running average of squared gradients
═══════════════════════════════════════════════════════════════════
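You can check the plain-SGD rule by hand against `torch.optim.SGD`. The sketch below assumes an arbitrary one-example least-squares loss. With no momentum or weight decay, the optimizer should produce exactly `w - learning_rate * w.grad`:

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor(2.0)
learning_rate = 0.01

# One step by hand: w_new = w_old - learning_rate * gradient
loss = (w @ x - target) ** 2
loss.backward()
with torch.no_grad():               # the update itself must not be tracked by autograd
    w_manual = w - learning_rate * w.grad

# The same step with torch.optim.SGD (plain SGD: no momentum, no weight decay)
w_opt = w.detach().clone().requires_grad_(True)
optimizer = torch.optim.SGD([w_opt], lr=learning_rate)
loss_opt = (w_opt @ x - target) ** 2
loss_opt.backward()
optimizer.step()

print(w_manual)
print(w_opt.detach())
```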
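Definition: Loss Function
A loss function $\mathcal{L}(y, \hat{y})$ measures the discrepancy between the true labels $y$ and the model's predictions $\hat{y}$. Training minimizes the expected loss $\mathbb{E}_{(x, y) \sim \mathcal{D}}\left[\mathcal{L}(y, f_\theta(x))\right]$ over the data distribution $\mathcal{D}$. In practice, this expectation is approximated by the average loss over the training examples. The choice of loss function determines what the model learns to optimize.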
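In PyTorch, the per-example losses and their average (the empirical estimate of the expected loss) differ only in the `reduction` argument. The small sketch below reuses the regression values from the next section, which were chosen only for illustration:

```python
import torch
import torch.nn as nn

y_true = torch.tensor([3.0, 5.0, 2.5, 7.0])
y_pred = torch.tensor([2.5, 5.2, 2.0, 8.0])

# reduction='none' keeps one loss value per example
per_example = nn.MSELoss(reduction='none')(y_pred, y_true)
# the default reduction='mean' averages them over the batch
batch_loss = nn.MSELoss()(y_pred, y_true)

print(per_example)                     # squared errors, one per example
print(batch_loss, per_example.mean())  # both 0.385
```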
Loss Functions
Regression Losses
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# ══════════════════════════════════════════════════
# Regression Loss Functions
# ══════════════════════════════════════════════════
# Create sample data
y_true = torch.tensor([3.0, 5.0, 2.5, 7.0])
y_pred = torch.tensor([2.5, 5.2, 2.0, 8.0])
# 1. Mean Squared Error (MSE)
mse_loss = nn.MSELoss()
mse = mse_loss(y_pred, y_true)
print(f"MSE: {mse:.4f}")
# Manual: ((2.5-3)² + (5.2-5)² + (2-2.5)² + (8-7)²) / 4 = 1.54 / 4 = 0.385
# 2. Mean Absolute Error (MAE/L1)
mae_loss = nn.L1Loss()
mae = mae_loss(y_pred, y_true)
print(f"MAE: {mae:.4f}")
# 3. Smooth L1 Loss: equals Huber loss at the default beta=1.0
#    (quadratic for small errors like MSE, linear for large ones like MAE)
huber_loss = nn.SmoothL1Loss()
huber = huber_loss(y_pred, y_true)
print(f"Huber: {huber:.4f}")
# 4. Root Mean Squared Error (RMSE)
rmse = torch.sqrt(mse)
print(f"RMSE: {rmse:.4f}")
# 5. Mean Squared Logarithmic Error (MSLE)
# MSE computed on log(1 + y); only defined for values > -1, so it is
# meant for non-negative targets such as counts or prices
msle_loss = nn.MSELoss()
msle = msle_loss(torch.log1p(y_pred), torch.log1p(y_true))
print(f"MSLE: {msle:.4f}")
# ══════════════════════════════════════════════════
# Loss Function Comparison
# ══════════════════════════════════════════════════
errors = np.linspace(-5, 5, 100)
errors_tensor = torch.tensor(errors, dtype=torch.float32)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# MSE
mse_vals = errors ** 2
axes[0].plot(errors, mse_vals, 'b-', linewidth=2)
axes[0].set_title('MSE Loss')
axes[0].set_xlabel('Error')
axes[0].set_ylabel('Loss')
axes[0].grid(True, alpha=0.3)
# MAE
mae_vals = np.abs(errors)
axes[1].plot(errors, mae_vals, 'r-', linewidth=2)
axes[1].set_title('MAE Loss')
axes[1].set_xlabel('Error')
axes[1].grid(True, alpha=0.3)
# Huber
huber_vals = np.where(np.abs(errors) <= 1,
0.5 * errors**2,
np.abs(errors) - 0.5)
axes[2].plot(errors, huber_vals, 'g-', linewidth=2)
axes[2].set_title('Huber Loss')
axes[2].set_xlabel('Error')
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('regression_losses.png', dpi=150)
plt.show()
Classification Losses
# ══════════════════════════════════════════════════
# Classification Loss Functions
# ══════════════════════════════════════════════════
# Binary Cross-Entropy
# L = -[y·log(p) + (1-y)·log(1-p)]
y_true_binary = torch.tensor([1.0, 0.0, 1.0, 1.0])
y_pred_binary = torch.tensor([0.9, 0.1, 0.8, 0.7])
bce_loss = nn.BCELoss()
bce = bce_loss(y_pred_binary, y_true_binary)
print(f"BCE: {bce:.4f}")
# BCE with logits (numerically stable)
y_logits = torch.tensor([2.0, -2.0, 1.5, 1.0])
bce_logits_loss = nn.BCEWithLogitsLoss()
bce_logits = bce_logits_loss(y_logits, y_true_binary)
print(f"BCE with Logits: {bce_logits:.4f}")
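# Categorical Cross-Entropy
y_true_cat = torch.tensor([0, 1, 2, 1])  # Class indices
# nn.CrossEntropyLoss expects raw, unnormalized scores (logits) and applies
# log-softmax internally. Taking the log of valid probability vectors gives
# logits whose softmax recovers exactly those probabilities.
y_probs_cat = torch.tensor([
    [0.9, 0.05, 0.05],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.05, 0.9, 0.05]
])
y_pred_cat = torch.log(y_probs_cat)  # logits
ce_loss = nn.CrossEntropyLoss()
ce = ce_loss(y_pred_cat, y_true_cat)
print(f"Cross Entropy: {ce:.4f}")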
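# Focal Loss (for imbalanced datasets)
# FL = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted
# probability of the true class; (1 - p_t)^gamma down-weights easy examples
import torch.nn.functional as F
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    def forward(self, inputs, targets):
        # inputs: logits of shape (N, C); targets: class indices of shape (N,)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # p_t, because ce_loss = -log(p_t)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()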
focal_loss = FocalLoss()
focal = focal_loss(y_pred_cat, y_true_cat)
print(f"Focal Loss: {focal:.4f}")
# ══════════════════════════════════════════════════
# Loss Selection Guide
# ══════════════════════════════════════════════════
print("\n" + "=" * 60)
print("LOSS FUNCTION SELECTION GUIDE")
print("=" * 60)
guide = """
Task                     │ Loss Function          │ When to Use
═════════════════════════╪════════════════════════╪════════════════════════════════
Regression               │ MSE / RMSE             │ Normally distributed errors
Regression (outliers)    │ Huber / Smooth L1      │ Robust to outliers
Regression (positive)    │ MSLE                   │ Targets spanning orders of magnitude
Binary Classification    │ BCE / BCEWithLogits    │ Two classes
Multi-class (balanced)   │ CrossEntropy           │ Single label per example
Multi-class (imbalanced) │ Focal Loss             │ Rare classes
Multi-label              │ BCE (per output)       │ Multiple labels per example
Segmentation             │ Dice / Jaccard         │ Overlap metrics
GAN                      │ Adversarial Loss       │ Generator training
"""
print(guide)
ℹ️ Loss Function Selection
- Use MSE/RMSE for regression with normally distributed errors
- Use Huber Loss when data has outliers
- Use CrossEntropy for multi-class classification
- Use Focal Loss for imbalanced datasets
- Use BCE with Logits for binary classification (numerically stable)
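You can check the last point numerically. The sketch below uses made-up multi-label logits. It shows that `BCEWithLogitsLoss` matches `BCELoss` applied after a sigmoid for moderate values. It also shows what happens when the sigmoid saturates. PyTorch documents that `BCELoss` clamps its log terms at -100, so it reports 100 instead of the true loss of about 40:

```python
import torch
import torch.nn as nn

# Made-up multi-label batch: 2 examples, 3 independent binary labels each
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-3.0, 4.0, 0.0]])
targets = torch.tensor([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 1.0]])

# Two steps (sigmoid, then BCE) vs. the fused, numerically stable version
loss_two_step = nn.BCELoss()(torch.sigmoid(logits), targets)
loss_fused = nn.BCEWithLogitsLoss()(logits, targets)
print(f"two-step: {loss_two_step.item():.6f}, fused: {loss_fused.item():.6f}")

# A confidently wrong prediction: sigmoid(40) rounds to exactly 1.0 in float32
big_logit = torch.tensor([40.0])
zero_target = torch.tensor([0.0])
saturated_bce = nn.BCELoss()(torch.sigmoid(big_logit), zero_target)
saturated_bce_logits = nn.BCEWithLogitsLoss()(big_logit, zero_target)
print(f"BCELoss: {saturated_bce.item():.1f}")                   # 100.0 (log(0) clamped at -100)
print(f"BCEWithLogitsLoss: {saturated_bce_logits.item():.1f}")  # 40.0
```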
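Mean Squared Error (MSE)
$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left(y_i - \hat{y}_i\right)^2$$
Here,
- $y_i$ = True value for instance $i$
- $\hat{y}_i$ = Predicted value for instance $i$
- $n$ = Number of instances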
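Cross-Entropy Loss (Classification)
$$\mathcal{L}_{\text{CE}} = -\frac{1}{n} \sum_{i=1}^{n} \sum_{c=1}^{C} y_{i,c} \log\left(\hat{y}_{i,c}\right)$$
Here,
- $n$ = Number of instances
- $C$ = Number of classes
- $y_{i,c}$ = True label (one-hot) for instance $i$, class $c$
- $\hat{y}_{i,c}$ = Predicted probability for instance $i$, class $c$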
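ℹ️ Numerical Stability of Cross-Entropy
Computing $\log(\hat{y}_{i,c})$ directly can cause numerical issues when $\hat{y}_{i,c} \to 0$: the probability underflows to zero and the logarithm becomes $-\infty$. PyTorch's `nn.CrossEntropyLoss` takes raw logits and combines log-softmax and NLL loss in a single numerically stable operation (the log-sum-exp trick), so it never takes the logarithm of a small probability explicitly.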
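The failure mode is easy to reproduce. The minimal sketch below uses deliberately extreme, made-up logits. Taking `softmax` and then `log` yields an infinite loss, while `F.cross_entropy` (log-softmax plus NLL) returns the finite value 2000:

```python
import torch
import torch.nn.functional as F

# Made-up extreme logits: the true class (index 2) has by far the smallest score
logits = torch.tensor([[1000.0, 0.0, -1000.0]])
target = torch.tensor([2])

# Naive: softmax first, then log. exp(-2000) underflows to 0, so log(0) = -inf
probs = torch.softmax(logits, dim=1)
naive_ce = -torch.log(probs[0, target[0]])

# Stable: log_softmax(x) = x - logsumexp(x) never forms the tiny probability
stable_ce = F.cross_entropy(logits, target)

print(naive_ce.item(), stable_ce.item())  # inf 2000.0
```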
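Huber Loss (Robust Regression)
$$L_\delta(y, \hat{y}) = \begin{cases} \frac{1}{2}\left(y - \hat{y}\right)^2 & \text{if } |y - \hat{y}| \le \delta \\ \delta\left(|y - \hat{y}| - \frac{1}{2}\delta\right) & \text{otherwise} \end{cases}$$
Here,
- $\delta$ = Threshold parameter (typically 1.0)
- $y$ = True value
- $\hat{y}$ = Predicted value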
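💡 Huber Loss Interpretation
Huber loss behaves like MSE for small errors (quadratic) and like MAE for large errors (linear). This makes it robust to outliers while still being differentiable everywhere. The parameter $\delta$ controls the transition point between quadratic and linear behavior. In PyTorch, `nn.HuberLoss(delta=δ)` implements this formula, while `nn.SmoothL1Loss(beta=β)` equals it divided by β. The two therefore coincide at the default value of 1.0.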
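The sketch below transcribes the piecewise formula directly and checks it against PyTorch's built-in losses. It is for illustration only: the data, including the outlier, are made up, and `huber` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

def huber(y, y_hat, delta=1.0):
    """Huber loss from the piecewise formula, averaged over elements."""
    err = (y - y_hat).abs()
    quadratic = 0.5 * err ** 2                 # branch used where |error| <= delta
    linear = delta * (err - 0.5 * delta)       # branch used where |error| > delta
    return torch.where(err <= delta, quadratic, linear).mean()

y = torch.tensor([3.0, 5.0, 2.5, 7.0, 0.0])
y_hat = torch.tensor([2.5, 5.2, 2.0, 8.0, 10.0])   # the last prediction is an outlier

for delta in (0.5, 1.0, 2.0):
    ours = huber(y, y_hat, delta)
    builtin = nn.HuberLoss(delta=delta)(y_hat, y)
    # SmoothL1Loss(beta) equals HuberLoss(delta=beta) divided by beta
    smooth_l1_rescaled = nn.SmoothL1Loss(beta=delta)(y_hat, y) * delta
    print(f"delta={delta}: {ours:.4f} {builtin:.4f} {smooth_l1_rescaled:.4f}")
```

Because the outlier's error (10) sits on the linear branch, it contributes about 10δ rather than the 50 it would add under MSE's ½e² scaling.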
Optimizers
Gradient Descent Variants
Optimizer Comparison:
═══════════════════════════════════════════════════════════════════
SGD (Stochastic Gradient Descent):
──────────────────────────────────
w = w - lr × ∇L
Pros: Simple, generalizes well
Cons: Slow convergence, sensitive to lr
┌─────────────────────────────────────────────────────────────────┐
│ SGD Path (erratic):                                             │
│                                                                 │
│   Start ●                                                       │
│          ╲                                                      │
│           ─────                                                 │
│                ╲                                                │
│                 ────                                            │
│                     ╲                                           │
│                      ─────────                                  │
│                               ╲                                 │
│                                ───────────── Finish             │
└─────────────────────────────────────────────────────────────────┘
SGD + Momentum:
───────────────
v = β·v + ∇L          (velocity update)
w = w - lr × v        (weight update)
Pros: Accelerates convergence, dampens oscillations
Cons: Adds hyperparameter β (typically 0.9)
┌─────────────────────────────────────────────────────────────────┐
│ Momentum Path (smoother):                                       │
│                                                                 │
│   Start ●                                                       │
│          ╲                                                      │
│           ╲                                                     │
│            │                                                    │
│            │                                                    │
│            │                                                    │
│            │                                                    │
│             ╲                                                   │
│              ╲                                                  │
│               ────────────── Finish                             │
└─────────────────────────────────────────────────────────────────┘
Adam (Adaptive Moment Estimation):
──────────────────────────────────
m = β₁·m + (1-β₁)·∇L          (first moment)
v = β₂·v + (1-β₂)·(∇L)²       (second moment)
m̂ = m / (1 - β₁ᵗ)             (bias correction)
v̂ = v / (1 - β₂ᵗ)             (bias correction)
w = w - lr × m̂ / (√v̂ + ε)
Pros: Adaptive learning rates, fast convergence
Cons: May not generalize as well as SGD
═══════════════════════════════════════════════════════════════════
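SGD with Momentum
$$v_t = \beta\, v_{t-1} + \nabla_w L(w_{t-1})$$
$$w_t = w_{t-1} - \eta\, v_t$$
Here,
- $v_t$ = Velocity (exponentially decaying accumulation of past gradients)
- $\beta$ = Momentum coefficient (typically 0.9)
- $\eta$ = Learning rate
- $\nabla_w L(w_{t-1})$ = Gradient of the loss function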
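The two momentum equations translate directly into code. The sketch below uses an assumed quadratic toy loss. It checks a hand-written update against `torch.optim.SGD(momentum=0.9)`, which uses the same formulation: the velocity accumulates raw gradients rather than a (1 - β)-weighted average.

```python
import torch

def momentum_step(w, v, grad, lr=0.01, beta=0.9):
    v = beta * v + grad      # velocity update
    w = w - lr * v           # weight update
    return w, v

# Assumed toy loss: 0.5 * ||w - target||^2, whose gradient is (w - target)
target = torch.tensor([1.0, -2.0])

w_manual = torch.zeros(2)
v = torch.zeros(2)

w_ref = torch.zeros(2, requires_grad=True)
opt = torch.optim.SGD([w_ref], lr=0.01, momentum=0.9)

for _ in range(20):
    w_manual, v = momentum_step(w_manual, v, w_manual - target)

    opt.zero_grad()
    loss = 0.5 * ((w_ref - target) ** 2).sum()
    loss.backward()
    opt.step()

print(w_manual)
print(w_ref.detach())
```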
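Adam Optimizer
$$g_t = \nabla_w L(w_{t-1})$$
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
$$w_t = w_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Here,
- $g_t$ = Gradient at step $t$ (squared elementwise in $g_t^2$)
- $m_t$ = First moment estimate (mean of gradients)
- $v_t$ = Second moment estimate (mean of squared gradients)
- $\beta_1$ = Exponential decay rate for first moment (0.9)
- $\beta_2$ = Exponential decay rate for second moment (0.999)
- $\eta$ = Learning rate
- $\epsilon$ = Small constant for numerical stability (1e-8)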
ℹ️ Why Adam Adapts Learning Rates
Adam maintains per-parameter adaptive learning rates by tracking two statistics of each parameter's gradients. These are the mean (first moment) and the uncentered variance (second moment, the mean of squared gradients). Because the update divides by $\sqrt{\hat{v}_t}$, parameters with consistently large gradients get smaller effective learning rates, while parameters with small gradients get larger ones. This makes Adam relatively forgiving of the learning-rate choice and effective across a wide range of problems.
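You can likewise write the Adam equations out and compare them with `torch.optim.Adam`. In the sketch below, the toy objective 0.5·||w - target||² and its deliberately mismatched coordinate scales are assumptions for illustration. On the first step the bias-corrected ratio reduces to m̂/√v̂ = g/|g|. As a result, every coordinate moves by about `lr` on step one, regardless of how large its gradient is:

```python
import torch

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update following the equations above (t counts from 1)."""
    m = beta1 * m + (1 - beta1) * g            # first moment
    v = beta2 * v + (1 - beta2) * g * g        # second moment
    m_hat = m / (1 - beta1 ** t)               # bias corrections
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (v_hat.sqrt() + eps)
    return w, m, v

# Assumed toy loss 0.5 * ||w - target||^2 (gradient: w - target), with very
# different scales per coordinate
target = torch.tensor([1.0, -2.0, 30.0])

w_manual = torch.zeros(3)
m = torch.zeros(3)
v = torch.zeros(3)

w_ref = torch.zeros(3, requires_grad=True)
opt = torch.optim.Adam([w_ref], lr=0.1)

for t in range(1, 51):
    w_manual, m, v = adam_step(w_manual, w_manual - target, m, v, t, lr=0.1)
    if t == 1:
        first_step = w_manual.clone()   # about lr * sign(-gradient) per coordinate

    opt.zero_grad()
    loss = 0.5 * ((w_ref - target) ** 2).sum()
    loss.backward()
    opt.step()

print("after step 1:", first_step)
print("after 50 steps:", w_manual, w_ref.detach())
```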
import torch
import torch.nn as nn
import torch.optim as optim
# ══════════════════════════════════════════════════
# Optimizer Implementations
# ══════════════════════════════════════════════════
# Simple model for demonstration
model = nn.Sequential(
nn.Linear(100, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 10)
)
# 1. SGD
optimizer_sgd = optim.SGD(
model.parameters(),
lr=0.01,
momentum=0.9,
weight_decay=1e-4
)
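# 2. Adam (here weight_decay adds an L2 penalty term to the gradient,
#    which is then rescaled by the adaptive denominator)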
optimizer_adam = optim.Adam(
model.parameters(),
lr=0.001,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-4
)
# 3. AdamW (Adam with decoupled weight decay)
optimizer_adamw = optim.AdamW(
model.parameters(),
lr=0.001,
betas=(0.9, 0.999),
weight_decay=0.01 # Stronger regularization
)
# 4. RMSprop
optimizer_rmsprop = optim.RMSprop(
model.parameters(),
lr=0.001,
alpha=0.99,
momentum=0,
weight_decay=1e-4
)
# 5. RAdam (Rectified Adam)
optimizer_radam = optim.RAdam(
model.parameters(),
lr=0.001
)
# 6. LAMB (layer-wise adaptive moments for large-batch training) is not part
#    of torch.optim; it is only available from third-party packages, so it is
#    not configured here.
print("Optimizers configured:")
for name, opt in [('SGD', optimizer_sgd), ('Adam', optimizer_adam),
                  ('AdamW', optimizer_adamw), ('RMSprop', optimizer_rmsprop)]:
    print(f"  {name}: {len(opt.param_groups)} param groups")
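The comments above distinguish Adam's `weight_decay` from AdamW's decoupled decay. In `torch.optim.Adam`, `weight_decay` adds λ·w to the gradient before the adaptive rescaling. `torch.optim.AdamW` instead shrinks the weights directly by a factor (1 - lr·λ). The one-step sketch below isolates the decay term with an artificial setup that sets the data gradient to zero:

```python
import torch

def one_step(opt_cls):
    """Take a single step with lr=0.01, weight_decay=0.1 and zero data gradient."""
    w = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([w], lr=0.01, weight_decay=0.1)
    w.grad = torch.zeros_like(w)   # the data loss contributes no gradient
    opt.step()
    return w.item()

w_adam = one_step(torch.optim.Adam)
w_adamw = one_step(torch.optim.AdamW)
print(f"Adam  (L2 term inside the gradient): {w_adam:.6f}")   # 0.990000
print(f"AdamW (decoupled decay):             {w_adamw:.6f}")  # 0.999000
```

With Adam, the adaptive denominator normalizes the penalty gradient, so on this first step the weight moves by about `lr` whatever λ is. With AdamW, the shrinkage is lr·λ·w = 0.001 and scales with λ as intended.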
Optimizer Comparison
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# ══════════════════════════════════════════════════
# Compare Optimizers on Same Problem
# ══════════════════════════════════════════════════
def run_optimizer_comparison():
    """Compare different optimizers on a simple regression task."""
    torch.manual_seed(42)
    # Generate synthetic data
    X = torch.randn(1000, 10)
    true_w = torch.randn(10, 1)
    y = X @ true_w + 0.1 * torch.randn(1000, 1)
    # Split data
    X_train, X_val = X[:800], X[800:]
    y_train, y_val = y[:800], y[800:]
    optimizers_config = {
        'SGD': {'lr': 0.01},                           # plain SGD, no momentum
        'SGD+Momentum': {'lr': 0.01, 'momentum': 0.9},
        'Adam': {'lr': 0.001},
        'AdamW': {'lr': 0.001, 'weight_decay': 0.01},
        'RMSprop': {'lr': 0.001},
    }
    results = {}
    for name, config in optimizers_config.items():
        # Fresh model for each optimizer, re-seeded so every run starts
        # from the same initial weights
        torch.manual_seed(0)
        model = nn.Linear(10, 1)
        criterion = nn.MSELoss()
        if name.startswith('SGD'):
            optimizer = optim.SGD(model.parameters(), **config)
        elif name == 'AdamW':
            optimizer = optim.AdamW(model.parameters(), **config)
        elif name == 'Adam':
            optimizer = optim.Adam(model.parameters(), **config)
        else:
            optimizer = optim.RMSprop(model.parameters(), **config)
        # Train (full batch, so each "epoch" is a single update step)
        train_losses = []
        val_losses = []
        for epoch in range(200):
            # Training
            optimizer.zero_grad()
            output = model(X_train)
            loss = criterion(output, y_train)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            # Validation
            with torch.no_grad():
                val_output = model(X_val)
                val_loss = criterion(val_output, y_val)
                val_losses.append(val_loss.item())
        results[name] = {
            'train': train_losses,
            'val': val_losses
        }
        print(f"{name}: Final val loss = {val_losses[-1]:.4f}")
    # Plot results
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for name, losses in results.items():
        axes[0].plot(losses['train'], label=name, alpha=0.8)
        axes[1].plot(losses['val'], label=name, alpha=0.8)
    axes[0].set_title('Training Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    axes[0].set_yscale('log')
    axes[1].set_title('Validation Loss')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    axes[1].set_yscale('log')
    plt.tight_layout()
    plt.savefig('optimizer_comparison.png', dpi=150)
    plt.show()
    return results
results = run_optimizer_comparison()
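Theorem: Convergence of Gradient Descent
Let $f$ be a convex function whose gradient is $L$-Lipschitz continuous ($f$ is $L$-smooth). Then gradient descent with learning rate $\eta = 1/L$ satisfies $f(w_k) - f(w^*) \le \frac{L\,\|w_0 - w^*\|^2}{2k}$, that is, it converges at rate $O(1/k)$ after $k$ iterations. If $f$ is also $\mu$-strongly convex, convergence becomes linear: $f(w_k) - f(w^*) \le \left(1 - \frac{\mu}{L}\right)^k \left(f(w_0) - f(w^*)\right)$.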
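You can check both bounds numerically on a small quadratic. The sketch below assumes f(w) = ½ wᵀAw with A = diag(1, 10). This gives L = 10 (the largest eigenvalue), μ = 1 (the smallest), w* = 0 and f(w*) = 0. With η = 1/L, gradient descent solves the high-curvature coordinate in one step. The low-curvature coordinate shrinks by exactly (1 - μ/L) per step, which is why that direction dominates the rate:

```python
import numpy as np

# Assumed test problem: f(w) = 0.5 * w^T A w, minimizer w* = 0, f(w*) = 0
A = np.diag([1.0, 10.0])
L_smooth, mu = 10.0, 1.0   # largest / smallest eigenvalue of A

def f(w):
    return 0.5 * w @ A @ w

def grad_f(w):
    return A @ w

w0 = np.array([1.0, 1.0])
w = w0.copy()
eta = 1.0 / L_smooth

gaps, sublinear_bound, linear_bound = [], [], []
for k in range(1, 51):
    w = w - eta * grad_f(w)                               # gradient descent step
    gaps.append(f(w))                                     # f(w_k) - f(w*)
    sublinear_bound.append(L_smooth * (w0 @ w0) / (2 * k))
    linear_bound.append((1 - mu / L_smooth) ** k * f(w0))

print(f"k=10: gap={gaps[9]:.2e}, O(1/k) bound={sublinear_bound[9]:.2e}, "
      f"linear bound={linear_bound[9]:.2e}")
```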
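Learning Rate Selection
Scenario: Training a ResNet-50 on ImageNet with SGD + momentum.
Rule of thumb: Start with a learning rate around $0.1$ for SGD (the value commonly used for ResNet-50 at a batch size of 256), or around $10^{-3}$ for Adam (PyTorch's default).
Why this works: The convergence theorem above suggests a step size on the order of $\eta \approx 1/L$. For deep networks, $L$ is not known in advance and varies across the loss surface. These starting values are therefore empirical defaults that still need tuning.
In practice: Use learning rate warmup (start small, increase linearly for the first 5-10 epochs) followed by cosine annealing or step decay. Warmup is commonly used to stabilize the early phase of training, particularly with large batch sizes, when large updates to freshly initialized weights can make the loss diverge.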
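One way to express warmup followed by cosine decay is a `LambdaLR` multiplier. The minimal sketch below assumes per-epoch scheduler steps, a 5-epoch linear warmup, 90 total epochs and a base learning rate of 0.1. All of these values are illustrative, and `warmup_cosine` is a hypothetical helper. PyTorch also ships `LinearLR`, `CosineAnnealingLR` and `SequentialLR`, which can be chained to a similar effect:

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine(warmup_epochs, total_epochs):
    """Hypothetical helper: returns an LR multiplier function for LambdaLR."""
    def lr_factor(epoch):
        if epoch < warmup_epochs:
            # Linear warmup: 1/warmup, 2/warmup, ..., 1.0
            return (epoch + 1) / warmup_epochs
        # Cosine decay from 1.0 towards 0 over the remaining epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_factor

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine(warmup_epochs=5, total_epochs=90))

lrs = []
for epoch in range(90):
    lrs.append(optimizer.param_groups[0]['lr'])
    # ... one epoch of training would go here ...
    optimizer.step()      # no gradients exist yet, so the weights are unchanged
    scheduler.step()

print([round(lr, 4) for lr in lrs[:7]], f"... final lr = {lrs[-1]:.2e}")
```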
Learning Rate Scheduling
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import (
    StepLR, CosineAnnealingLR, OneCycleLR,
    ReduceLROnPlateau, CyclicLR, ExponentialLR
)
# ══════════════════════════════════════════════════
# Learning Rate Schedulers
# ══════════════════════════════════════════════════
model = nn.Linear(10, 1)
num_epochs = 100
def make_optimizer():
    # A separate optimizer per scheduler, so the schedules do not interfere
    return optim.Adam(model.parameters(), lr=0.01)
schedulers_config = {
    'StepLR': StepLR(make_optimizer(), step_size=30, gamma=0.1),
    'ExponentialLR': ExponentialLR(make_optimizer(), gamma=0.95),
    'CosineAnnealing': CosineAnnealingLR(make_optimizer(), T_max=num_epochs),
    # OneCycleLR is usually stepped once per batch; steps_per_epoch=1 fits
    # the whole cycle into the num_epochs steps simulated below
    'OneCycleLR': OneCycleLR(make_optimizer(), max_lr=0.01,
                             steps_per_epoch=1, epochs=num_epochs),
}
# Visualize scheduler behavior
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, (name, scheduler) in zip(axes.flat, schedulers_config.items()):
    lrs = []
    for epoch in range(num_epochs):
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.optimizer.step()  # stands in for a real training epoch
        scheduler.step()
    ax.plot(lrs, linewidth=2)
    ax.set_title(f'{name} Learning Rate Schedule')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('lr_schedulers.png', dpi=150)
plt.show()
# ══════════════════════════════════════════════════
# ReduceLROnPlateau (Validation-based)
# ══════════════════════════════════════════════════
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,    # Halve the LR...
    patience=3,    # ...after more than 3 epochs without improvement
    min_lr=1e-7    # (the deprecated verbose flag is omitted; the loop prints the LR)
)
# Simulated training: the validation loss stalls at 0.76 from epoch 6 on,
# so the LR is halved at epoch 10 (patience=3 keeps this demo short)
val_losses = [1.0, 0.9, 0.85, 0.8, 0.78, 0.77, 0.76,
              0.76, 0.76, 0.76, 0.76, 0.76, 0.75]
for epoch, val_loss in enumerate(val_losses):
    scheduler.step(val_loss)
    print(f"Epoch {epoch}: val_loss={val_loss:.4f}, lr={optimizer.param_groups[0]['lr']:.6f}")
Gradient Clipping
# ══════════════════════════════════════════════════
# Gradient Clipping Methods
# ══════════════════════════════════════════════════
import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Method 1: Clip by norm (most common)
# Rescales all gradients together so their combined L2 norm is ≤ max_norm
max_norm = 1.0
# Method 2: Clip by value
# Clips each gradient element to [-clip_value, clip_value]
clip_value = 0.5
# Training with gradient clipping
def train_with_clipping(model, dataloader, criterion, optimizer,
                        max_norm=None, clip_value=None):
    model.train()
    total_loss = 0
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        # Apply gradient clipping between backward() and step()
        if max_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        elif clip_value is not None:
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
# Monitor gradient norms
def monitor_gradients(model):
    """Global L2 norm over all parameter gradients."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5
# Example training loop with monitoring (random data, one batch per "epoch")
print("Training with gradient clipping:")
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    # Forward pass
    x = torch.randn(32, 100)
    y = torch.randint(0, 10, (32,))
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    # Monitor before clipping
    grad_norm_before = monitor_gradients(model)
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    # Monitor after clipping
    grad_norm_after = monitor_gradients(model)
    optimizer.step()
    print(f"Epoch {epoch+1}: loss={loss.item():.4f}, "
          f"grad_norm: {grad_norm_before:.4f} -> {grad_norm_after:.4f}")
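The two clipping modes behave differently on the same gradient. The tiny sketch below uses a hand-set gradient of [3, 4] (norm 5), assumed for illustration. `clip_grad_norm_` rescales the whole vector and returns the total norm measured before clipping. `clip_grad_value_` clamps each element independently and can change the gradient's direction:

```python
import torch

# A single parameter whose gradient is set by hand (values assumed for illustration)
p = torch.nn.Parameter(torch.zeros(2))

# Clip by norm: [3, 4] has L2 norm 5, so it is rescaled by roughly max_norm / 5
p.grad = torch.tensor([3.0, 4.0])
norm_before = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
grad_after_norm_clip = p.grad.clone()    # about [0.6, 0.8]: same direction, norm 1

# Clip by value: each element is clamped to [-0.5, 0.5] on its own
p.grad = torch.tensor([3.0, 4.0])
torch.nn.utils.clip_grad_value_([p], clip_value=0.5)
grad_after_value_clip = p.grad.clone()   # [0.5, 0.5]: the direction has changed

print(norm_before.item(), grad_after_norm_clip, grad_after_value_clip)
```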
Complete Training Best Practices
# ══════════════════════════════════════════════════
# Production-Ready Training Loop
# ══════════════════════════════════════════════════
import time
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
class Trainer:
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Loss function
        self.criterion = self._get_criterion()
        # Optimizer
        self.optimizer = self._get_optimizer()
        # Scheduler
        self.scheduler = self._get_scheduler()
        # Mixed precision training (enabled only on CUDA)
        self.scaler = torch.cuda.amp.GradScaler() if self.device.type == 'cuda' else None
        # Logging
        self.writer = SummaryWriter(log_dir=config.get('log_dir', 'runs'))
        self.history = defaultdict(list)
        self.best_val_loss = float('inf')
    def _get_criterion(self):
        loss_name = self.config.get('loss', 'cross_entropy')
        if loss_name == 'cross_entropy':
            return nn.CrossEntropyLoss()
        elif loss_name == 'mse':
            return nn.MSELoss()
        elif loss_name == 'bce':
            return nn.BCEWithLogitsLoss()
        else:
            return nn.CrossEntropyLoss()
    def _get_optimizer(self):
        opt_name = self.config.get('optimizer', 'adam')
        lr = self.config.get('lr', 0.001)
        weight_decay = self.config.get('weight_decay', 1e-4)
        if opt_name == 'sgd':
            return optim.SGD(self.model.parameters(), lr=lr,
                             momentum=0.9, weight_decay=weight_decay)
        elif opt_name == 'adam':
            return optim.Adam(self.model.parameters(), lr=lr,
                              weight_decay=weight_decay)
        elif opt_name == 'adamw':
            return optim.AdamW(self.model.parameters(), lr=lr,
                               weight_decay=weight_decay)
        else:
            return optim.Adam(self.model.parameters(), lr=lr)
    def _get_scheduler(self):
        sched_name = self.config.get('scheduler', 'cosine')
        if sched_name == 'cosine':
            return CosineAnnealingLR(self.optimizer,
                                     T_max=self.config.get('epochs', 100))
        elif sched_name == 'plateau':
            return ReduceLROnPlateau(self.optimizer, mode='min',
                                     factor=0.5, patience=10)
        return None
    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            if self.scaler:
                # Mixed precision: forward pass under autocast, scaled backward pass
                with torch.cuda.amp.autocast():
                    output = self.model(data)
                    loss = self.criterion(output, target)
                self.scaler.scale(loss).backward()
                # Gradient clipping: unscale first so max_norm applies to the true gradients
                if self.config.get('grad_clip'):
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config['grad_clip']
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                if self.config.get('grad_clip'):
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config['grad_clip']
                    )
                self.optimizer.step()
            total_loss += loss.item()
            # Accuracy assumes (N, num_classes) logits and integer class targets
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            if batch_idx % 100 == 0:
                print(f"  Batch {batch_idx}/{len(self.train_loader)}: "
                      f"Loss: {loss.item():.4f}")
        accuracy = 100. * correct / total
        avg_loss = total_loss / len(self.train_loader)
        return avg_loss, accuracy
    @torch.no_grad()
    def validate(self):
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        for data, target in self.val_loader:
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        accuracy = 100. * correct / total
        avg_loss = total_loss / len(self.val_loader)
        return avg_loss, accuracy
    def train(self, num_epochs):
        print(f"Training on {self.device}")
        print(f"Config: {self.config}")
        print("=" * 60)
        for epoch in range(num_epochs):
            start_time = time.time()
            # Train
            train_loss, train_acc = self.train_epoch(epoch)
            # Validate
            val_loss, val_acc = self.validate()
            # Update scheduler (ReduceLROnPlateau needs the monitored metric)
            if self.scheduler:
                if isinstance(self.scheduler, ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()
            # Log metrics
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            # TensorBoard logging
            self.writer.add_scalars('Loss', {
                'train': train_loss,
                'val': val_loss
            }, epoch)
            self.writer.add_scalars('Accuracy', {
                'train': train_acc,
                'val': val_acc
            }, epoch)
            self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'], epoch)
            # Save best model
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                }, 'best_model.pth')
                print("  ✓ New best model saved!")
            elapsed = time.time() - start_time
            print(f"Epoch {epoch+1}/{num_epochs} ({elapsed:.1f}s):")
            print(f"  Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
            print(f"  LR: {self.optimizer.param_groups[0]['lr']:.6f}")
        self.writer.close()
        return self.history
# ══════════════════════════════════════════════════
# Usage Example
# ══════════════════════════════════════════════════
config = {
    'optimizer': 'adamw',
    'lr': 0.001,
    'weight_decay': 0.01,
    'loss': 'cross_entropy',
    'scheduler': 'cosine',
    'grad_clip': 1.0,
    'epochs': 50
}
# CNNClassifier, train_loader and val_loader are placeholders that this module
# does not define: substitute your own nn.Module and DataLoaders
model = CNNClassifier(num_classes=10)
trainer = Trainer(model, train_loader, val_loader, config)
history = trainer.train(num_epochs=config['epochs'])
Key Takeaways
- MSE for regression with normal errors, Huber for robust regression, CrossEntropy for classification, Focal Loss for imbalanced data
- Adam/AdamW are good defaults with adaptive learning rates; SGD+momentum often generalizes better but requires careful LR tuning
- Learning rate is often the most important hyperparameter: use scheduling (cosine annealing, warmup + decay)
- Gradient clipping by global norm helps prevent exploding gradients in deep networks
- Mixed precision training (FP16) can provide roughly 2-3x speedups on GPUs with Tensor Cores, typically with minimal accuracy loss
- AdamW decouples weight decay from the adaptive gradient update; this often generalizes better than the L2 penalty applied through Adam's weight_decay
- Always save checkpoints and log metrics for debugging and model selection
Practice Exercises
- Loss Function Comparison: Train the same model with MSE, Huber, and MAE - when does each work best?
- Optimizer Benchmark: Compare optimizers on a convnet with CIFAR-10
- LR Finder: Implement a learning rate finder that plots loss vs learning rate
- Training Diagnostics: Plot gradient norms, weight distributions, and activation statistics during training