Weight Initialization — Xavier, He, LSUV & Variance Preservation
Weight initialization determines how neural network parameters are set before training begins. Poor initialization leads to vanishing or exploding gradients, making deep networks untrainable.
See our Backpropagation tutorial for why initialization affects gradient flow.
Why Initialization Matters
Definition: The Initialization Problem
Consider a network with $L$ layers. The output variance after $L$ layers depends on the variance of each layer's weights. If the variance shrinks at each layer, activations vanish. If it grows, activations explode. Initialization must preserve the variance of activations and gradients across layers.
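Theorem: Variance Preservation Requirement
For a layer $y = Wx$ with $n_{\text{in}}$ inputs, where the weights are i.i.d. with zero mean and independent of the zero-mean inputs, the variance of each output is:
$$\mathrm{Var}(y_i) = n_{\text{in}} \, \mathrm{Var}(W) \, \mathrm{Var}(x)$$
For variance to be preserved across layers:
$$n_{\text{in}} \, \mathrm{Var}(W) = 1 \quad\Longleftrightarrow\quad \mathrm{Var}(W) = \frac{1}{n_{\text{in}}}$$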
This is the foundation for both Xavier and He initialization.
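The relationship $\mathrm{Var}(y) = n_{\text{in}} \, \mathrm{Var}(W) \, \mathrm{Var}(x)$ can be checked numerically. The following is a minimal sketch, assuming a single bias-free linear layer with Gaussian weights and unit-variance inputs; the helper name `output_variance` and the layer sizes are illustrative choices, not part of the tutorial.

```python
import torch

torch.manual_seed(0)

n_in, n_out, batch = 512, 512, 4096
x = torch.randn(batch, n_in)  # unit-variance, zero-mean inputs

def output_variance(weight_var):
    """Empirical variance of y = x @ W.T for zero-mean Gaussian weights."""
    W = torch.randn(n_out, n_in) * weight_var ** 0.5
    return (x @ W.T).var().item()

var_preserving = output_variance(1.0 / n_in)  # n_in * Var(W) = 1
var_too_small = output_variance(0.5 / n_in)   # n_in * Var(W) = 0.5
var_too_large = output_variance(2.0 / n_in)   # n_in * Var(W) = 2

print(f"Var(W) = 1/n_in   -> Var(y) = {var_preserving:.3f}")
print(f"Var(W) = 0.5/n_in -> Var(y) = {var_too_small:.3f}")
print(f"Var(W) = 2/n_in   -> Var(y) = {var_too_large:.3f}")
```

Only the first setting keeps the output variance close to the input variance; the other two halve or double it at every layer, which compounds with depth.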
Random Initialization
Definition: Naive Random Initialization
Initialize weights from a standard normal distribution: $W_{ij} \sim \mathcal{N}(0, 1)$
Problem: For a layer with 1000 inputs, the output variance is 1000 times the input variance. Activations explode after a few layers, causing numerical overflow and gradient explosion.
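To make the explosion concrete, the sketch below pushes random data through a stack of linear layers with 1000 inputs each and records the activation standard deviation after every layer. It assumes purely linear layers (no activation function) and uses a hypothetical helper `std_per_layer`; the depth and batch size are arbitrary.

```python
import torch

torch.manual_seed(0)

def std_per_layer(weight_std, width=1000, depth=5, batch=256):
    """Return the activation std after each of `depth` linear layers."""
    h = torch.randn(batch, width)
    stds = []
    for _ in range(depth):
        W = torch.randn(width, width) * weight_std
        h = h @ W.T
        stds.append(h.std().item())
    return stds

naive = std_per_layer(weight_std=1.0)            # W ~ N(0, 1)
scaled = std_per_layer(weight_std=1000 ** -0.5)  # W ~ N(0, 1/n_in)

for i, (a, b) in enumerate(zip(naive, scaled), start=1):
    print(f"layer {i}: N(0,1) std = {a:.3e} | N(0,1/n_in) std = {b:.3f}")
```

With standard normal weights the standard deviation grows by roughly a factor of $\sqrt{1000} \approx 32$ per layer, while the scaled version stays near 1.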
⚠️ Why Not Zero Initialization?
Initializing all weights to zero causes all neurons in a layer to compute the same output and receive the same gradient. They remain identical throughout training — the network cannot learn different features. Weights must be asymmetric at initialization.
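The symmetry problem can be observed directly. The sketch below is an illustration under assumptions: it uses a tiny two-layer network and fills every parameter with the same constant (0.1 rather than exactly zero, so that the gradients are non-zero and easier to inspect). The network shape and the random data are made up for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

net = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 1))

# Give every parameter the same constant value (zero is the extreme case).
for p in net.parameters():
    nn.init.constant_(p, 0.1)

x = torch.randn(8, 4)
target = torch.randn(8, 1)
loss = nn.functional.mse_loss(net(x), target)
loss.backward()

hidden_grad = net[0].weight.grad  # one row per hidden neuron
print(hidden_grad)

rows_identical = all(
    torch.allclose(hidden_grad[0], hidden_grad[i]) for i in range(1, 3)
)
print("all hidden neurons receive the same gradient:", rows_identical)
```

Because every row of the gradient is the same, a gradient step keeps the hidden neurons identical. With exactly zero weights and biases in this small network, the weight gradients are also all zero, so only the output bias would change.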
Xavier/Glorot Initialization
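Definition: Xavier (Glorot) Initialization
Xavier initialization preserves variance for linear and sigmoid/tanh activations. It sets weights from:
$$W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}} + n_{\text{out}}}\right) \quad\text{or}\quad W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right)$$
This balances the variance of inputs and outputs. The derivation assumes linear activations and small weights.
Here,
- $n_{\text{in}}$ = Number of input neurons
- $n_{\text{out}}$ = Number of output neurons
- $\mathcal{N}$ = Normal distribution
- $\mathcal{U}$ = Uniform distribution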
ℹ️ Xavier Intuition
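The factor $\frac{2}{n_{\text{in}} + n_{\text{out}}}$ is a compromise between $\frac{1}{n_{\text{in}}}$ (preserves activation variance in the forward pass) and $\frac{1}{n_{\text{out}}}$ (preserves gradient variance in the backward pass). It works well for sigmoid and tanh activations where the function is approximately linear near zero.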
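As a quick check of the Xavier scale, the sketch below compares the empirical statistics of `nn.init.xavier_normal_` and `nn.init.xavier_uniform_` with the target standard deviation $\sqrt{2 / (n_{\text{in}} + n_{\text{out}})}$ and the uniform bound $\sqrt{6 / (n_{\text{in}} + n_{\text{out}})}$. The layer sizes are arbitrary illustrative values.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

n_in, n_out = 400, 600
expected_std = math.sqrt(2.0 / (n_in + n_out))   # normal variant
uniform_bound = math.sqrt(6.0 / (n_in + n_out))  # uniform variant

# PyTorch stores Linear weights with shape (out_features, in_features).
w_normal = torch.empty(n_out, n_in)
w_uniform = torch.empty(n_out, n_in)
nn.init.xavier_normal_(w_normal)
nn.init.xavier_uniform_(w_uniform)

print(f"target std     : {expected_std:.4f}")
print(f"xavier_normal_ : std = {w_normal.std().item():.4f}")
print(f"xavier_uniform_: std = {w_uniform.std().item():.4f}, "
      f"max|w| = {w_uniform.abs().max().item():.4f} (bound {uniform_bound:.4f})")
```

Both variants have the same variance; the uniform bound is $\sqrt{3}$ times the standard deviation because a uniform distribution on $[-a, a]$ has variance $a^2 / 3$.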
He Initialization
Definition: He (Kaiming) Initialization
He initialization accounts for ReLU's non-linearity, which zeros out half the outputs:
$$W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}}}\right)$$
The factor of 2 compensates for ReLU's variance-halving property. This is the standard choice of initialization for ReLU networks.
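Both halves of this argument can be checked numerically. The sketch below first estimates the second moment of ReLU applied to standard normal inputs, which comes out near one half, and then confirms that `nn.init.kaiming_normal_` draws weights with standard deviation $\sqrt{2 / n_{\text{in}}}$. The sample size and layer dimensions are illustrative assumptions.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1) ReLU keeps about half of the second moment of a zero-mean symmetric input.
z = torch.randn(1_000_000)
relu_second_moment = torch.relu(z).pow(2).mean().item()
print(f"E[relu(z)^2] = {relu_second_moment:.3f} (Var(z) = 1)")

# 2) kaiming_normal_ uses std = sqrt(2 / fan_in) for ReLU.
fan_in, fan_out = 512, 256
w = torch.empty(fan_out, fan_in)
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
expected_std = math.sqrt(2.0 / fan_in)
print(f"kaiming_normal_ std = {w.std().item():.4f}, target = {expected_std:.4f}")
```

Doubling the weight variance offsets the halving caused by ReLU, so the product that governs the next layer's variance stays close to 1.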
💡 When to Use Which
- Xavier: Sigmoid, tanh activations (linear regime)
- He: ReLU, Leaky ReLU, PReLU activations
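- Default in PyTorch: nn.Linear and nn.Conv2d both call kaiming_uniform_ with a=sqrt(5) in fan_in mode, which samples from $\mathcal{U}(-1/\sqrt{\text{fan\_in}}, 1/\sqrt{\text{fan\_in}})$. This is not the ReLU gain, so apply nn.init.kaiming_normal_ or nn.init.kaiming_uni_form_ explicitly if you want He initialization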
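The default behavior of `nn.Linear` can be inspected empirically. The sketch below assumes the `reset_parameters` behavior of PyTorch releases at the time of writing (`kaiming_uniform_` with `a=sqrt(5)`), which may differ in other versions. It compares the default weights with the He standard deviation and then opts in to He initialization explicitly.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(1024, 256)             # default initialization
fan_in = layer.in_features
default_bound = 1.0 / math.sqrt(fan_in)  # bound implied by a=sqrt(5)
he_std = math.sqrt(2.0 / fan_in)         # std He initialization would use

default_std = layer.weight.std().item()
default_max = layer.weight.abs().max().item()
print(f"default: std = {default_std:.4f}, max|w| = {default_max:.4f}, "
      f"bound = {default_bound:.4f}")

# Opt in to He initialization explicitly for a ReLU network.
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(layer.bias)
he_actual_std = layer.weight.std().item()
print(f"He     : std = {he_actual_std:.4f}, target = {he_std:.4f}")
```

The default standard deviation is noticeably smaller than the He value, which is why tutorials usually re-initialize ReLU networks by hand.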
LSUV (Layer-Sequential Unit-Variance)
Definition: LSUV Initialization
LSUV (Mishkin & Matas, 2015) iteratively adjusts initialization by:
- Pre-initialize all weights (the original paper uses orthonormal matrices)
- For each layer (in order):
  - Pass a batch of data through the network up to that layer
  - Compute the variance of the layer's output
  - Scale the weights so that the output variance equals 1, repeating until it is within a tolerance
This is data-dependent and works for any activation function.
📝Example: LSUV Initialization
import torch
import torch.nn as nn

def lsuv_init(model, data, std_tol=0.01, max_iter=100):
    """Layer-Sequential Unit-Variance initialization."""
    modules = [m for m in model.modules()
               if isinstance(m, (nn.Conv2d, nn.Linear))]

    def hook_fn(module, inputs, output):
        # Store the latest output of this layer
        module.output = output.detach()

    hooks = [layer.register_forward_hook(hook_fn) for layer in modules]

    # Adjust layers in order, so each one sees already-normalized inputs
    with torch.no_grad():
        for layer in modules:
            # Iterate until the output std is close to 1
            for _ in range(max_iter):
                # Re-run the forward pass so the stored output
                # reflects the current weights
                model(data)
                current_std = layer.output.std().item()
                if abs(current_std - 1.0) < std_tol:
                    break
                # Scale weights (same rule for Conv2d and Linear)
                layer.weight *= 1.0 / current_std

    # Remove hooks
    for h in hooks:
        h.remove()
    return model
Bias Initialization
Definition: Bias Initialization
- Hidden layers: Initialize biases to zero (standard)
- Output layer: Initialize the bias to the inverse sigmoid (logit) of the target base rate
- LSTM forget gate: Initialize forget gate bias to 1.0 (or larger) to encourage information flow
For binary classification with sigmoid output, setting the bias to $b = \log\frac{p}{1 - p}$, where $p$ is the fraction of positive examples, helps early training because the initial predictions already match the class prior.
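The output-bias rule can be sketched as follows. The labels (5% positives), the feature size, and the name `head` are hypothetical, and zeroing the output weights is an extra assumption made here so that the bias alone determines the initial prediction.

```python
import math
import torch
import torch.nn as nn

# Hypothetical labels: 5% positives.
labels = torch.zeros(1000)
labels[:50] = 1.0
p = labels.mean().item()               # base rate of the positive class

head = nn.Linear(128, 1)               # output layer feeding a sigmoid
prior_logit = math.log(p / (1.0 - p))  # inverse sigmoid (logit) of p
nn.init.zeros_(head.weight)            # assumption: isolate the bias effect
nn.init.constant_(head.bias, prior_logit)

features = torch.randn(4, 128)
initial_prob = torch.sigmoid(head(features))
print(f"base rate p = {p:.3f}, bias = {prior_logit:.3f}")
print(initial_prob.squeeze())          # each entry is approximately p
```

Starting from the class prior means the initial loss is not dominated by the network learning the base rate, which matters most for imbalanced data.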
Initialization in Different Architectures
Definition: Architecture-Specific Initialization
- ResNet: He initialization for all conv layers; zero-init for the last BN layer in each residual block
- Transformer: Xavier/Glorot for attention projections; small init for output projection of residual blocks
- LSTM/GRU: Orthogonal initialization for recurrent weights; Xavier for input-to-hidden weights
- BatchNorm: Initialize $\gamma = 1$, $\beta = 0$ (standard normalization)
ℹ️ ResNet Zero Initialization Trick
A common trick for ResNets is to initialize the scale of the last batch normalization layer in each residual block with $\gamma = 0$. This ensures the residual branch initially outputs zero, so each block starts as an identity mapping and the network behaves like a shallow network at the start of training. This helps when training very deep networks (100+ layers).
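The effect of zeroing the last BN scale can be verified on a small block. `TinyResidualBlock` below is a hypothetical, simplified block written for this illustration (it omits the ReLU after the addition so the identity is directly visible); it is not the torchvision implementation.

```python
import torch
import torch.nn as nn

class TinyResidualBlock(nn.Module):
    """Hypothetical minimal block: x + BN(conv(ReLU(BN(conv(x)))))."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # Zero the scale (gamma) of the last BN so the branch starts at zero.
        nn.init.zeros_(self.bn2.weight)

    def forward(self, x):
        branch = self.bn2(self.conv2(torch.relu(self.bn1(self.conv1(x)))))
        return x + branch  # post-addition ReLU omitted for clarity

torch.manual_seed(0)
block = TinyResidualBlock(8)
x = torch.randn(2, 8, 16, 16)
out = block(x)
is_identity = torch.allclose(out, x)
print("block starts as identity:", is_identity)
```

torchvision's ResNet constructors expose a `zero_init_residual` flag for the same purpose, so in practice the trick rarely needs to be written by hand.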
PyTorch Implementation
📝Example: All Initialization Methods
import torch
import torch.nn as nn
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
# ═══════════════════════════════════════════════════
# 1. Xavier/Glorot Initialization
# ═══════════════════════════════════════════════════
for name, param in model.named_parameters():
    if 'weight' in name:
        nn.init.xavier_uniform_(param)
        print(f"{name}: Xavier uniform, std={param.std():.4f}")
# ═══════════════════════════════════════════════════
# 2. He (Kaiming) Initialization
# ═══════════════════════════════════════════════════
for name, param in model.named_parameters():
    if 'weight' in name:
        nn.init.kaiming_normal_(param, mode='fan_in', nonlinearity='relu')
        print(f"{name}: He normal, std={param.std():.4f}")
# ═══════════════════════════════════════════════════
# 3. Orthogonal Initialization (for RNNs)
# ═══════════════════════════════════════════════════
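rnn = nn.LSTM(128, 256, num_layers=2)
for name, param in rnn.named_parameters():
    if 'weight_ih' in name:
        nn.init.xavier_uniform_(param)
    elif 'weight_hh' in name:
        nn.init.orthogonal_(param)
    elif 'bias' in name:
        with torch.no_grad():
            param.zero_()
            # Set forget gate bias to 1. PyTorch packs the gates as
            # (input, forget, cell, output), so the forget gate is the
            # second quarter of the bias vector.
            # Note: bias_ih and bias_hh are both set, so the effective
            # forget bias is 2.0.
            n = param.size(0)
            param[n//4:n//2].fill_(1.0)
    print(f"{name}: shape={param.shape}")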
# ═══════════════════════════════════════════════════
# 4. Inspect Activation Statistics
# ═══════════════════════════════════════════════════
def init_and_check(init_fn, nonlinearity='relu'):
    """Check activation statistics for different initializations."""
    model = nn.Sequential(
        nn.Linear(256, 256),
        nn.ReLU() if nonlinearity == 'relu' else nn.Tanh(),
        nn.Linear(256, 256),
        nn.ReLU() if nonlinearity == 'relu' else nn.Tanh(),
        nn.Linear(256, 256),
        nn.ReLU() if nonlinearity == 'relu' else nn.Tanh(),
    )
    # Apply initialization
    for m in model:
        if isinstance(m, nn.Linear):
            init_fn(m.weight)
    # Forward pass
    x = torch.randn(1000, 256)
    activations = []
    h = x
    for layer in model:
        h = layer(h)
        if isinstance(layer, (nn.ReLU, nn.Tanh)):
            activations.append(h.detach())
    # Print statistics
    for i, act in enumerate(activations):
        print(f"Layer {i}: mean={act.mean():.4f}, std={act.std():.4f}")

print("\nXavier + ReLU:")
init_and_check(lambda w: nn.init.xavier_uniform_(w), 'relu')
print("\nHe + ReLU:")
init_and_check(lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu'), 'relu')
print("\nXavier + Tanh:")
init_and_check(lambda w: nn.init.xavier_uniform_(w), 'tanh')
Summary
📋Summary: Weight Initialization
- Zero initialization: Fails to break symmetry, so all neurons in a layer learn the same features
- Random (N(0,1)): Causes exploding/vanishing activations in deep networks
- Xavier/Glorot: Preserves variance for sigmoid/tanh activations
- He/Kaiming: Preserves variance for ReLU activations (weight variance $\frac{2}{n_{\text{in}}}$, twice the linear-case value)
- LSUV: Data-dependent initialization that works for any architecture
- ResNet trick: Zero-init last BN layer in residual blocks
- Rule of thumb: He init for ReLU networks, Xavier for sigmoid/tanh
- Inspect activation statistics to verify initialization quality
Practice Exercises
- Conceptual: Derive why He initialization uses $\frac{2}{n_{\text{in}}}$ instead of $\frac{1}{n_{\text{in}}}$ for ReLU. What happens to the variance when ReLU zeros out half the activations?
- Experiment: Train a 20-layer MLP with different initializations (zero, random, Xavier, He). Plot the activation variance across layers for each initialization. Which one preserves variance?
- Coding: Implement LSUV initialization from scratch. Apply it to a ResNet-18 and compare with the default PyTorch initialization on CIFAR-10.
- Research: Look up the orthogonal initialization for RNNs (Saxe et al., 2013). Why does orthogonal initialization help with vanishing gradients in recurrent networks?
- Debugging: Create a network with poor initialization that exhibits exploding gradients. Use gradient clipping to fix it. Compare this with fixing the initialization itself.