Loss Functions for Deep Learning — MSE, Cross-Entropy, Focal Loss & Beyond
Loss functions quantify how wrong a model's predictions are. Choosing the right loss function is critical for effective training.
See our Training Loops tutorial for a broader overview of loss functions and optimizers.
Loss Function Taxonomy
Types of Loss Functions
| Category | Loss Function | Use Case |
|---|---|---|
| Regression | MSE, MAE, Huber | Continuous output prediction |
| Classification | Cross-Entropy, Focal | Discrete class prediction |
| Ranking | Triplet, Contrastive | Similarity learning |
| Generative | Reconstruction, Adversarial | Data generation |
| Segmentation | Dice, IoU | Pixel-level prediction |
Mean Squared Error (MSE)
MSE Loss
MSE measures the average squared difference between predictions and targets.
- Penalizes errors quadratically: small errors contribute little, while large errors (outliers) dominate the loss
- Differentiable everywhere with smooth gradients
- Corresponds to maximum likelihood under Gaussian errors with constant variance
Mean Squared Error
$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2$$
Here,
- $y_i$ = True value for instance $i$
- $\hat{y}_i$ = Predicted value for instance $i$
- $n$ = Number of instances
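As a quick illustration of the outlier sensitivity noted above, the following sketch computes MSE and MAE by hand in plain Python. The values in `y_true` and `y_pred` are made up for illustration; they are not data from this tutorial.

```python
# Hypothetical toy data: the last prediction is off by 10, the others by 0.5.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.5, 3.5, 14.0]

n = len(y_true)
errors = [t - p for t, p in zip(y_true, y_pred)]

mse = sum(e ** 2 for e in errors) / n   # average squared error
mae = sum(abs(e) for e in errors) / n   # average absolute error

# Fraction of each loss that comes from the single outlier
outlier_share_mse = errors[-1] ** 2 / sum(e ** 2 for e in errors)
outlier_share_mae = abs(errors[-1]) / sum(abs(e) for e in errors)

print(f"MSE = {mse:.4f}, MAE = {mae:.4f}")
print(f"Outlier share of MSE: {outlier_share_mse:.1%}")
print(f"Outlier share of MAE: {outlier_share_mae:.1%}")
```

The single bad prediction accounts for almost all of the MSE but a noticeably smaller share of the MAE, which is the practical meaning of "sensitive to outliers".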
Cross-Entropy Loss
Cross-Entropy Loss
Cross-entropy measures the difference between the true label distribution $p$ and the predicted distribution $q$:
$$H(p, q) = -\sum_{c=1}^{C} p_c \log q_c$$
For hard labels (one-hot $p$), this simplifies to:
$$\mathcal{L}_{\mathrm{CE}} = -\log q_y$$
where $y$ is the true class. Minimizing cross-entropy is equivalent to maximizing the likelihood of the correct class under the model's predicted distribution.
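To make the hard-label simplification concrete, here is a minimal plain-Python sketch. The probability vector `q` and the class index `y` are made-up values; the point is only that the full sum against a one-hot target equals the negative log-probability of the true class.

```python
import math

q = [0.7, 0.2, 0.1]   # hypothetical predicted distribution over 3 classes
y = 0                 # hypothetical true class index

# One-hot target distribution p
p = [1.0 if c == y else 0.0 for c in range(len(q))]

# Full cross-entropy: -sum_c p_c * log(q_c)
full_ce = -sum(p_c * math.log(q_c) for p_c, q_c in zip(p, q))

# Hard-label shortcut: -log(q_y)
hard_ce = -math.log(q[y])

print(f"full sum: {full_ce:.4f}, shortcut: {hard_ce:.4f}")
```

Every term with a zero target drops out of the sum, so only the true class contributes to the loss.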
ℹ️ Numerical Stability of Cross-Entropy
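Computing $\log(\mathrm{softmax}(z))$ directly causes numerical issues when the logits $z$ are large in magnitude: the exponentials can overflow, or a probability can underflow to 0 so that its logarithm is $-\infty$. PyTorch's nn.CrossEntropyLoss combines log-softmax and NLL loss in a single numerically stable operation. Always pass raw logits to nn.CrossEntropyLoss()(logits, targets) instead of manually computing softmax + log + NLL.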
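The failure mode can be reproduced without PyTorch. The sketch below contrasts a naive log-softmax with the log-sum-exp trick (subtracting the maximum logit before exponentiating), a standard way to stabilize this computation. The function names are hypothetical, and this illustrates the idea rather than PyTorch's internal code.

```python
import math

def naive_log_softmax(logits):
    # exp() of a large logit overflows a Python float
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def stable_log_softmax(logits):
    # Shift by the max so the largest exponent is exp(0) = 1
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]

logits = [1000.0, 0.0, -1000.0]  # hypothetical extreme logits

try:
    naive_log_softmax(logits)
    naive_ok = True
except (OverflowError, ValueError):
    naive_ok = False

stable = stable_log_softmax(logits)
print("naive succeeded:", naive_ok)
print("stable log-probs:", stable)
```

In plain Python the overflow raises an exception. With floating-point tensors the same naive computation typically surfaces as `-inf` or `nan` values in the loss instead.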
Binary Cross-Entropy (BCE)
Binary Cross-Entropy
BCE is the cross-entropy loss for binary classification. Use it with a sigmoid output; PyTorch provides nn.BCEWithLogitsLoss, which combines sigmoid and BCE in one operation for numerical stability.
Binary Cross-Entropy
$$\mathrm{BCE} = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log \hat{y}_i + (1 - y_i) \log (1 - \hat{y}_i) \right]$$
Here,
- $y_i$ = True label (0 or 1)
- $\hat{y}_i$ = Predicted probability
- $n$ = Number of instances
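A minimal plain-Python sketch of why fusing the sigmoid into the loss helps. The rearranged form `max(x, 0) - x*y + log(1 + exp(-|x|))` is a common algebraic identity for BCE on a logit `x`. The function names are hypothetical, and this illustrates the idea rather than reproducing PyTorch's implementation.

```python
import math

def naive_bce(logit, target):
    # Sigmoid followed by log: breaks once the sigmoid saturates to 0.0 or 1.0
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def stable_bce_with_logits(logit, target):
    # Equivalent form that never takes log(0) or exponentiates a large positive number
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# Moderate logit: both versions agree
a = naive_bce(2.0, 1.0)
b = stable_bce_with_logits(2.0, 1.0)

# Saturated logit with the wrong label: the naive version fails
try:
    naive_bce(40.0, 0.0)
    naive_ok = True
except ValueError:
    naive_ok = False
c = stable_bce_with_logits(40.0, 0.0)

print(f"moderate logit: naive={a:.6f}, stable={b:.6f}")
print(f"saturated logit: naive succeeded={naive_ok}, stable={c:.4f}")
```

At a logit of 40 the sigmoid rounds to exactly 1.0 in double precision, so the naive version ends up taking the log of zero, while the rearranged form returns a finite loss close to the logit itself.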
Focal Loss
Focal Loss
Focal loss addresses class imbalance by down-weighting easy examples. When $\gamma = 0$, the modulating factor disappears and focal loss reduces to (class-weighted) cross-entropy.
Focal Loss
$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
Here,
- $p_t$ = Model's predicted probability for the correct class
- $\gamma$ = Focusing parameter (typically 2.0)
- $\alpha_t$ = Class balancing weight
ℹ️ Focal Loss Intuition
When $p_t$ is large (easy example), $(1 - p_t)^{\gamma}$ is small, reducing the loss contribution. When $p_t$ is small (hard example), $(1 - p_t)^{\gamma}$ is close to 1, keeping the loss significant. This pushes the model to focus on hard examples, which is critical for object detection with extreme class imbalance.
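The effect of the modulating factor is easy to tabulate. The sketch below is plain Python with the class weight left out (effectively set to 1) and three made-up probabilities standing in for an easy, a borderline, and a hard example.

```python
import math

gamma = 2.0  # focusing parameter, the typical value quoted above

def cross_entropy_term(p_t):
    # Standard cross-entropy for the true class
    return -math.log(p_t)

def focal_term(p_t, gamma):
    # The modulating factor (1 - p_t)^gamma scales the cross-entropy term
    return (1.0 - p_t) ** gamma * cross_entropy_term(p_t)

for p_t in (0.9, 0.5, 0.1):  # easy, borderline, hard (hypothetical values)
    ce = cross_entropy_term(p_t)
    fl = focal_term(p_t, gamma)
    print(f"p_t={p_t:.1f}  CE={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.2f}")
```

With $\gamma = 2$, the easy example keeps about 1% of its cross-entropy loss while the hard example keeps about 81%, so hard examples dominate the gradient.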
Huber Loss (Smooth L1)
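Huber Loss
Huber loss combines MSE and MAE, being quadratic for small errors and linear for large errors:
- Robust to outliers (linear region)
- Differentiable everywhere (smooth transition at $|y - \hat{y}| = \delta$)
- Parameter $\delta$ controls the transition point (typically 1.0)
Huber Loss
$$L_\delta(y, \hat{y}) = \begin{cases} \frac{1}{2}(y - \hat{y})^2 & \text{if } |y - \hat{y}| \le \delta \\ \delta \left( |y - \hat{y}| - \frac{1}{2}\delta \right) & \text{otherwise} \end{cases}$$
Here,
- $\delta$ = Threshold parameter (typically 1.0)
- $y$ = True value
- $\hat{y}$ = Predicted value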
💡 Huber Loss Interpretation
Huber loss behaves like MSE for small errors (quadratic, smooth gradients) and like MAE for large errors (linear, bounded gradients). This makes it robust to outliers while still being differentiable everywhere. Use it for regression tasks where outliers are present.
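The "bounded gradients" claim can be checked with a few lines of plain Python. This sketch uses the usual piecewise definition (half the squared error inside the threshold, linear outside); the function names and residual values are illustrative assumptions.

```python
def huber(error, delta=1.0):
    # Quadratic inside [-delta, delta], linear outside
    if abs(error) <= delta:
        return 0.5 * error ** 2
    return delta * (abs(error) - 0.5 * delta)

def huber_grad(error, delta=1.0):
    # Derivative with respect to the error: the error itself, clipped to [-delta, delta]
    return max(-delta, min(delta, error))

for e in (0.5, 1.0, 10.0):  # hypothetical residuals
    print(f"error={e:5.1f}  huber={huber(e):.3f}  "
          f"huber_grad={huber_grad(e):.1f}  squared_error_grad={2 * e:.1f}")
```

The gradient of the squared error grows without bound as the residual grows, while the Huber gradient saturates at $\delta$, so a single outlier cannot dominate a parameter update.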
Contrastive Loss
Contrastive Loss
Contrastive loss learns embeddings by pulling similar pairs together and pushing dissimilar pairs apart. It is used in Siamese networks and metric learning.
Contrastive Loss
$$L = y \, d^2 + (1 - y) \, \max(0, m - d)^2$$
Here,
- $d$ = Distance between embeddings
- $y$ = 1 if similar, 0 if dissimilar
- $m$ = Margin for negative pairs
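A minimal single-pair sketch in plain Python, assuming Euclidean distance and the common squared form of the loss (some formulations add a factor of 1/2). The embeddings below are made-up 2-D points chosen so the distances are easy to read.

```python
import math

def contrastive_loss(emb_a, emb_b, is_similar, margin=1.0):
    d = math.dist(emb_a, emb_b)        # Euclidean distance between the two embeddings
    if is_similar:
        return d ** 2                  # similar pair: any distance is penalized
    return max(0.0, margin - d) ** 2   # dissimilar pair: penalized only inside the margin

anchor = [0.0, 0.0]        # hypothetical embeddings
close_point = [0.3, 0.4]   # distance 0.5 from the anchor
far_point = [3.0, 4.0]     # distance 5.0 from the anchor

pull = contrastive_loss(anchor, close_point, is_similar=True)
push = contrastive_loss(anchor, close_point, is_similar=False)
no_penalty = contrastive_loss(anchor, far_point, is_similar=False)

print(f"similar, d=0.5:    {pull:.4f}")
print(f"dissimilar, d=0.5: {push:.4f}")
print(f"dissimilar, d=5.0: {no_penalty:.4f}")
```

A dissimilar pair that is already farther apart than the margin contributes nothing, so training effort goes to the pairs that are still confusable.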
Dice Loss
Dice Loss
Dice loss measures overlap between predicted and target masks and is commonly used in segmentation. It is less sensitive to class imbalance than pixel-wise losses, and it optimizes a soft version of the Dice coefficient, an overlap metric closely related to IoU.
Dice Loss
$$L_{\mathrm{Dice}} = 1 - \frac{2 \sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}$$
Here,
- $p_i$ = Predicted probability at pixel $i$
- $g_i$ = Ground truth at pixel $i$ (0 or 1)
- $\epsilon$ = Smoothing term for numerical stability
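Dice and IoU measure the same overlap but are not the same number. For binary masks they are related by Dice = 2·IoU / (1 + IoU), which the following plain-Python sketch checks on a made-up pair of flattened masks (smoothing term omitted for clarity).

```python
pred = [1, 1, 1, 0, 0, 0, 0, 0]     # hypothetical binarized prediction (flattened mask)
target = [1, 1, 0, 1, 0, 0, 0, 0]   # hypothetical ground truth

intersection = sum(p * g for p, g in zip(pred, target))
pred_sum, target_sum = sum(pred), sum(target)
union = pred_sum + target_sum - intersection

dice = 2 * intersection / (pred_sum + target_sum)
iou = intersection / union

print(f"Dice = {dice:.4f}, IoU = {iou:.4f}")
print(f"2*IoU/(1+IoU) = {2 * iou / (1 + iou):.4f}")
```

Because the relation is monotonic, improving one score improves the other, but the values differ and Dice is always at least as large as IoU.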
PyTorch Implementation
📝 Example: Loss Functions in PyTorch
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ═══════════════════════════════════════════════════
# Classification Losses
# ═══════════════════════════════════════════════════
# Multi-class classification
logits = torch.randn(32, 10)  # Batch of 32, 10 classes
targets = torch.randint(0, 10, (32,))
ce_loss = nn.CrossEntropyLoss()(logits, targets)
print(f"Cross-Entropy Loss: {ce_loss.item():.4f}")

# Binary classification
binary_logits = torch.randn(32, 1)
binary_targets = torch.randint(0, 2, (32, 1)).float()
bce_loss = nn.BCEWithLogitsLoss()(binary_logits, binary_targets)
print(f"BCE Loss: {bce_loss.item():.4f}")

# Focal loss (custom implementation)
def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    # Per-sample cross-entropy, i.e. -log(p_t)
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)  # Recover p_t, the probability of the true class
    # Note: alpha is a single scalar here, not a per-class weight
    return (alpha * (1 - pt) ** gamma * ce).mean()

focal = focal_loss(logits, targets)
print(f"Focal Loss: {focal.item():.4f}")

# ═══════════════════════════════════════════════════
# Regression Losses
# ═══════════════════════════════════════════════════
y_pred = torch.randn(32, 1)
y_true = torch.randn(32, 1)
mse_loss = nn.MSELoss()(y_pred, y_true)
mae_loss = nn.L1Loss()(y_pred, y_true)
huber_loss = nn.HuberLoss(delta=1.0)(y_pred, y_true)
print(f"\nMSE Loss: {mse_loss.item():.4f}")
print(f"MAE Loss: {mae_loss.item():.4f}")
print(f"Huber Loss: {huber_loss.item():.4f}")

# ═══════════════════════════════════════════════════
# Dice Loss (Segmentation)
# ═══════════════════════════════════════════════════
def dice_loss(pred, target, epsilon=1e-6):
    pred = torch.sigmoid(pred)  # Logits -> probabilities
    intersection = (pred * target).sum()
    # Sum of the two mask totals (the Dice denominator, not the set union)
    total = pred.sum() + target.sum()
    return 1 - (2.0 * intersection + epsilon) / (total + epsilon)

pred_mask = torch.randn(1, 1, 64, 64)
target_mask = torch.randint(0, 2, (1, 1, 64, 64)).float()
d_loss = dice_loss(pred_mask, target_mask)
print(f"\nDice Loss: {d_loss.item():.4f}")
```
When to Use Which
💡 Loss Function Selection Guide
- Binary classification: BCEWithLogitsLoss (not sigmoid + BCE)
- Multi-class classification: CrossEntropyLoss (not softmax + NLL)
- Multi-label classification: BCEWithLogitsLoss with each output independent
- Regression (clean data): MSELoss
- Regression (outliers): HuberLoss
- Segmentation: Dice loss + BCE combination (Dice loss is a custom function, not a built-in torch.nn module)
- Object detection: Focal Loss for class imbalance
- Metric learning: Contrastive Loss, Triplet Loss
- Imbalanced data: Focal Loss or class-weighted cross-entropy
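For the class-weighted option in the last bullet, a rough plain-Python sketch of how the weights might be derived and applied. The inverse-frequency formula is one common heuristic, not the only choice, and the labels and probabilities are made up. In PyTorch, weights like these would be passed as the `weight` tensor of nn.CrossEntropyLoss.

```python
import math
from collections import Counter

labels = [0] * 90 + [1] * 10   # hypothetical imbalanced label set (9:1)
counts = Counter(labels)
num_classes = len(counts)

# Inverse-frequency weights: n_samples / (n_classes * count_c)
weights = {c: len(labels) / (num_classes * n) for c, n in counts.items()}

def weighted_ce(probs_true_class, targets, weights):
    # probs_true_class[i] is the predicted probability of sample i's true class
    num = sum(weights[t] * -math.log(p) for p, t in zip(probs_true_class, targets))
    den = sum(weights[t] for t in targets)  # normalize by the total weight
    return num / den

# Hypothetical model: confident on the majority class, poor on the minority class
probs = [0.95] * 90 + [0.30] * 10
plain = sum(-math.log(p) for p in probs) / len(probs)
weighted = weighted_ce(probs, labels, weights)

print("class weights:", weights)
print(f"plain CE = {plain:.4f}, weighted CE = {weighted:.4f}")
```

With these weights each class contributes equally to the loss regardless of its size, so poor minority-class predictions are no longer averaged away by the majority class.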
Summary
📋 Summary: Loss Functions for Deep Learning
- MSE: Regression, assumes Gaussian errors, sensitive to outliers
- Cross-Entropy: Classification, equivalent to maximum likelihood
- Binary CE: Binary classification, always with sigmoid output
- Focal Loss: Down-weights easy examples, critical for imbalanced data
- Huber Loss: Robust regression, combines MSE and MAE
- Contrastive Loss: Metric learning, pull similar pairs together
- Dice Loss: Segmentation, optimizes overlap directly
- Always use numerically stable variants: CrossEntropyLoss, BCEWithLogitsLoss
- Loss choice shapes what the model learns; choose based on your objective
Practice Exercises
- Conceptual: Why is nn.CrossEntropyLoss numerically stable while nn.NLLLoss()(torch.log(F.softmax(logits, dim=1)), targets) is not? What happens when logits contains very large values?
- Coding: Implement focal loss from scratch with class-specific weights. Test it on CIFAR-10 with a class imbalance (reduce training samples of one class by 10x).
- Experiment: Compare MSE vs. Huber loss for regression on a dataset with 10% outlier labels. Which converges faster? Which gives better final performance?
- Visualization: Plot cross-entropy and focal loss as functions of $p_t$ for different values of $\gamma$. How does the shape of the loss change?
- Application: Build a binary classifier for imbalanced data (1:100 ratio). Compare standard BCE, weighted BCE, and focal loss. Which performs best on minority class recall?