RNN Deep Dive — Vanilla RNN, BPTT & Vanishing/Exploding Gradients
Recurrent Neural Networks process sequential data by maintaining a hidden state that carries information across time steps. This tutorial covers vanilla RNNs and their fundamental limitations.
See our RNN/LSTM tutorial for a general overview of recurrent architectures including GRU and LSTM.
Why Recurrence?
Definition: Sequential Data
Many data types are inherently sequential:
- Text: Words follow a specific order
- Speech: Audio signals are time-series
- Video: Frames have temporal relationships
- Time series: Stock prices, sensor readings
Feedforward networks ignore temporal structure. RNNs explicitly model it by maintaining a hidden state that evolves over time.
Vanilla RNN
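Definition: Vanilla RNN
The RNN recurrence relation:
$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h), \qquad y_t = W_{hy} h_t + b_y$$
At each time step:
- Take the previous hidden state $h_{t-1}$
- Combine it with the current input $x_t$
- Apply the tanh activation to obtain the new hidden state $h_t$
- Output a prediction $y_t$
The same parameters ($W_{xh}$, $W_{hh}$, $W_{hy}$, $b_h$, $b_y$) are shared across all time steps. This weight sharing keeps the parameter count independent of sequence length and lets the model generalize to variable-length sequences.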
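To make the recurrence $h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$, $y_t = W_{hy} h_t + b_y$ concrete, here is a minimal NumPy sketch of the forward pass. It assumes a single (unbatched) sequence, $h_0 = 0$, and small random weights chosen only for illustration; the parameter names mirror the PyTorch class later in this tutorial. The point to notice is that one set of weights handles sequences of any length.

```python
import numpy as np

def rnn_forward(xs, W_xh, W_hh, b_h, W_hy, b_y, h0=None):
    """Run h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h), y_t = W_hy h_t + b_y.

    xs: array of shape (T, input_size); T can be any length.
    Returns (hs, ys) with shapes (T, hidden_size) and (T, output_size).
    """
    hidden_size = W_hh.shape[0]
    h = np.zeros(hidden_size) if h0 is None else h0
    hs, ys = [], []
    for x_t in xs:  # one step per input vector
        # The same W_xh, W_hh, b_h are reused at every step (weight sharing)
        h = np.tanh(W_hh @ h + W_xh @ x_t + b_h)
        hs.append(h)
        ys.append(W_hy @ h + b_y)
    return np.array(hs), np.array(ys)

rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 3, 4, 2
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
W_hy = rng.normal(scale=0.5, size=(output_size, hidden_size))
b_h, b_y = np.zeros(hidden_size), np.zeros(output_size)

# The same parameters process sequences of different lengths
hs5, ys5 = rnn_forward(rng.normal(size=(5, input_size)), W_xh, W_hh, b_h, W_hy, b_y)
hs12, ys12 = rnn_forward(rng.normal(size=(12, input_size)), W_xh, W_hh, b_h, W_hy, b_y)
print(hs5.shape, ys5.shape)    # (5, 4) (5, 2)
print(hs12.shape, ys12.shape)  # (12, 4) (12, 2)
```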
Backpropagation Through Time (BPTT)
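Definition: BPTT
BPTT unrolls the RNN through time and applies standard backpropagation:
- Forward pass: Compute $h_1, h_2, \dots, h_T$ sequentially
- Compute loss: $\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t$, where $\mathcal{L}_t$ is the loss on output $y_t$
- Backward pass: Compute gradients by unrolling the recurrence
The gradient of the loss with respect to $h_t$ depends on gradients from all future time steps, because $h_t$ influences every later hidden state.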
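BPTT Gradient Computation
$$\frac{\partial \mathcal{L}}{\partial h_t} = \frac{\partial \mathcal{L}_t}{\partial h_t} + \left(\frac{\partial h_{t+1}}{\partial h_t}\right)^{\top} \frac{\partial \mathcal{L}}{\partial h_{t+1}}$$
Here,
- $h_t$ = Hidden state at time t
- $\frac{\partial \mathcal{L}_t}{\partial h_t}$ = Direct gradient from loss at time t
- $\frac{\partial h_{t+1}}{\partial h_t}$ = Jacobian of next hidden state w.r.t. current
At the last step $t = T$ there is no future state, so only the direct term remains.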
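The backward recursion $\partial \mathcal{L}/\partial h_t = \partial \mathcal{L}_t/\partial h_t + (\partial h_{t+1}/\partial h_t)^{\top}\, \partial \mathcal{L}/\partial h_{t+1}$ can be checked numerically. The NumPy sketch below assumes a toy setup that is not from this tutorial: no biases or output layer, $h_0 = 0$, and a per-step loss $\mathcal{L}_t = \tfrac{1}{2}\|h_t\|^2$ (so $\partial \mathcal{L}_t/\partial h_t = h_t$). It runs BPTT to get $\partial \mathcal{L}/\partial W_{hh}$ and compares it with central finite differences. The helper names `forward`, `loss`, and `bptt_grad_W` are hypothetical.

```python
import numpy as np

def forward(W, U, xs):
    """h_t = tanh(W h_{t-1} + U x_t) with h_0 = 0. Returns [h_0, h_1, ..., h_T]."""
    hs = [np.zeros(W.shape[0])]
    for x_t in xs:
        hs.append(np.tanh(W @ hs[-1] + U @ x_t))
    return hs

def loss(W, U, xs):
    # Toy per-step loss L_t = 0.5 * ||h_t||^2, summed over t = 1..T
    return sum(0.5 * h @ h for h in forward(W, U, xs)[1:])

def bptt_grad_W(W, U, xs):
    hs = forward(W, U, xs)
    grad_W = np.zeros_like(W)
    delta_next = np.zeros(W.shape[0])  # gradient w.r.t. pre-activation at t+1 (zero past T)
    for t in range(len(xs), 0, -1):
        # dL/dh_t = direct term + (dh_{t+1}/dh_t)^T dL/dh_{t+1}
        #         = h_t + W^T [(1 - h_{t+1}^2) * dL/dh_{t+1}] = h_t + W^T delta_{t+1}
        g_t = hs[t] + W.T @ delta_next
        delta = (1 - hs[t] ** 2) * g_t         # back through tanh
        grad_W += np.outer(delta, hs[t - 1])   # a_t = W h_{t-1} + U x_t
        delta_next = delta
    return grad_W

rng = np.random.default_rng(1)
H, D, T = 4, 3, 6
W = rng.normal(scale=0.5, size=(H, H))
U = rng.normal(scale=0.5, size=(H, D))
xs = rng.normal(size=(T, D))

analytic = bptt_grad_W(W, U, xs)

# Central finite differences, one entry of W at a time
eps = 1e-6
numeric = np.zeros_like(W)
for i in range(H):
    for j in range(H):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        numeric[i, j] = (loss(Wp, U, xs) - loss(Wm, U, xs)) / (2 * eps)

print("max abs difference:", np.max(np.abs(analytic - numeric)))
```

The difference should be close to zero; if the Jacobian term is dropped, the check fails for every weight that influences later steps.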
Vanishing and Exploding Gradients
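Theorem: Vanishing Gradient in RNNs
The gradient of $h_T$ with respect to an earlier state $h_k$ is a product of per-step Jacobians:
$$\frac{\partial h_T}{\partial h_k} = \prod_{t=k+1}^{T} \frac{\partial h_t}{\partial h_{t-1}} = \prod_{t=k+1}^{T} \operatorname{diag}\!\left(1 - h_t^2\right) W_{hh}$$
Since $|\tanh'(\cdot)| \le 1$ and $\|W_{hh}\|_2$ may be less than 1, the product shrinks exponentially. Over a gap of $n = T - k$ steps, its norm is bounded by $\sigma_{\max}^{\,n}$, where $\sigma_{\max}$ is the largest singular value of $W_{hh}$.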
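A small NumPy experiment makes the product of Jacobians concrete. The sketch assumes a 16-unit tanh RNN whose $W_{hh}$ is rescaled to spectral norm 0.9 (an arbitrary illustrative value), with random inputs added directly to the pre-activation (an identity input matrix, for simplicity). It accumulates $\partial h_n / \partial h_0$ one step at a time and compares its norm with the bound $\sigma_{\max}^{\,n}$.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 16
W = rng.normal(size=(H, H))
W *= 0.9 / np.linalg.norm(W, 2)   # rescale so the spectral norm is 0.9 (illustrative)
sigma_max = np.linalg.norm(W, 2)

h = np.zeros(H)
J = np.eye(H)                     # running product dh_t/dh_0
norms = []
for t in range(1, 31):
    h = np.tanh(W @ h + rng.normal(size=H))
    # dh_t/dh_{t-1} = diag(1 - h_t^2) W; left-multiply onto the running product
    J = np.diag(1 - h ** 2) @ W @ J
    norms.append(np.linalg.norm(J, 2))

for n in (1, 10, 20, 30):
    print(f"n={n:2d}  ||dh_n/dh_0|| = {norms[n - 1]:.2e}  bound = {sigma_max ** n:.2e}")
```

Because every tanh derivative is at most 1, the measured norm stays at or below the bound and typically falls well below it.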
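Gradient Magnitude Through Time
$$\left\| \frac{\partial h_T}{\partial h_k} \right\| \le \lambda^{n}, \qquad \lambda = \gamma\,\sigma_{\max}(W_{hh}), \quad n = T - k$$
Here,
- $\lambda$ = Maximum singular value of the recurrent weight matrix (times the maximum activation derivative $\gamma$, with $\gamma \le 1$ for tanh)
- $n$ = Number of time steps between gradient source and target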
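The bound $\lambda^n$, with $\lambda$ the largest singular value of the recurrent weights times the maximum activation derivative and $n$ the number of steps, is easy to tabulate. This pure-Python sketch uses arbitrary illustrative values of $\lambda$ and $n$ (not measurements from a trained model) and also computes how many steps it takes for the bound to fall below a threshold. The helper names are hypothetical.

```python
import math

def gradient_bound(lam, n):
    """Upper bound lam**n on ||dh_T/dh_k|| for a gap of n = T - k steps."""
    return lam ** n

def steps_until_below(lam, eps):
    """Smallest integer n with lam**n < eps (only meaningful for 0 < lam < 1)."""
    return math.floor(math.log(eps) / math.log(lam)) + 1

for lam in (0.9, 1.0, 1.1):
    row = ", ".join(f"n={n}: {gradient_bound(lam, n):.3g}" for n in (10, 20, 50, 100))
    print(f"lambda={lam}: {row}")

print(steps_until_below(0.9, 1e-3))  # 66
```

With $\lambda = 0.9$ the bound is already below $10^{-3}$ after 66 steps, while $\lambda = 1.1$ gives roughly $1.4 \times 10^4$ after 100 steps: a small deviation from 1 compounds quickly.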
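Theorem: Gradient Explosion in RNNs
When $\lambda$ (the largest singular value of $W_{hh}$ times the maximum activation derivative) exceeds 1, the bound $\lambda^n$ grows exponentially and gradients can explode; a large $\lambda$ makes explosion possible but does not guarantee it. For example, $\lambda = 1.1$ over $n = 100$ steps gives $1.1^{100} \approx 1.4 \times 10^{4}$. Exploding gradients cause:
- Numerical overflow: Loss becomes NaN
- Unstable training: Weight updates are too large
The standard remedy is gradient clipping (rescale gradients by their norm, or clip them by value).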
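The two flavors of gradient clipping differ in whether they preserve the gradient's direction. The NumPy sketch below follows the idea behind `torch.nn.utils.clip_grad_norm_` (compute one global L2 norm over all gradients, then scale by `max_norm / (norm + eps)` only if that factor is below 1) and contrasts it with clipping each entry by value. `clip_by_global_norm` and `clip_by_value` are hypothetical helpers, not a library API.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        grads = [g * scale for g in grads]  # same direction, smaller length
    return grads, total_norm

def clip_by_value(grads, clip_value):
    """Clamp every entry to [-clip_value, clip_value]; this can change the direction."""
    return [np.clip(g, -clip_value, clip_value) for g in grads]

grads = [np.array([30.0, 40.0]), np.array([[0.0, 0.0]])]  # global norm = 50
clipped, before = clip_by_global_norm(grads, max_norm=1.0)
print(before)                                         # 50.0
print(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))  # ~1.0
print(clip_by_value(grads, 5.0)[0])                   # [5. 5.]
```

Clipping by norm keeps the update pointing the same way (here `[0.6, 0.8]` after scaling), whereas clipping by value turns `[30, 40]` into `[5, 5]`. In PyTorch the corresponding utilities are `torch.nn.utils.clip_grad_norm_` and `torch.nn.utils.clip_grad_value_`.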
ℹ️ Practical Impact
- In practice, vanilla RNNs typically learn dependencies spanning only about 10-20 time steps
- Most real-world sequences are much longer (sentences, documents, videos)
- This fundamental limitation motivated the invention of LSTM and GRU
- Even with gradient clipping, vanilla RNNs struggle with long-range dependencies
Mathematical Analysis
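Definition: Singular Value Analysis
The stability of RNN training depends on the spectral properties of $W_{hh}$:
- If $\|W_{hh}\|_2 < 1$ (spectral norm): Gradients vanish
- If $\|W_{hh}\|_2 > 1$: Gradients can explode
- If all singular values of $W_{hh}$ equal 1 (orthogonal initialization): The recurrent weights preserve gradient norms, although tanh derivatives below 1 can still shrink them
The spectral norm $\|W_{hh}\|_2$ is the largest singular value of $W_{hh}$.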
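The spectral norm can be estimated without a full SVD using power iteration on $W^{\top} W$. The NumPy sketch below is one way to do this, checked against `np.linalg.svd`; the iteration count, tolerance, and the `regime` labels are illustrative choices, and both helper names are hypothetical.

```python
import numpy as np

def spectral_norm(W, iters=100, seed=0):
    """Estimate sigma_max(W) = ||W||_2 by power iteration on W^T W."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=W.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(iters):
        v = W.T @ (W @ v)          # one power-iteration step on W^T W
        v /= np.linalg.norm(v)
    return np.linalg.norm(W @ v)   # ||W v|| for the (approximate) top right-singular vector

def regime(W, tol=1e-6):
    """Classify what the recurrent weights alone do to gradient norms."""
    s = spectral_norm(W)
    if s < 1.0 - tol:
        return s, "shrinks gradients (vanishing)"
    if s > 1.0 + tol:
        return s, "can amplify gradients (possible explosion)"
    return s, "norm-preserving"

W = np.array([[0.5, 0.2], [0.1, 0.4]])
s, label = regime(W)
print(f"power iteration {s:.4f}, SVD {np.linalg.svd(W, compute_uv=False)[0]:.4f}: {label}")
```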
💡 Orthogonal Initialization for RNNs
Initialize $W_{hh}$ as an orthogonal matrix ($W_{hh}^{\top} W_{hh} = I$). This makes all singular values equal to 1, so at the start of training the recurrent weights neither shrink nor amplify gradients (the tanh derivative can still shrink them). PyTorch: nn.init.orthogonal_(weight).
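One common way to sample an orthogonal matrix is the QR decomposition of a Gaussian matrix, with a sign correction on the columns. The NumPy sketch below follows that recipe (PyTorch's `nn.init.orthogonal_` applies a similar QR-plus-sign approach); `orthogonal_init` and its `gain` argument are hypothetical names for illustration.

```python
import numpy as np

def orthogonal_init(n, gain=1.0, seed=0):
    """Sample an n x n orthogonal matrix via QR of a Gaussian matrix."""
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(n, n))
    Q, R = np.linalg.qr(A)
    # Flip column signs so diag(R) would be positive; this removes the bias
    # introduced by QR's sign convention
    Q *= np.sign(np.diag(R))
    return gain * Q

W_hh = orthogonal_init(64)
print(np.allclose(W_hh.T @ W_hh, np.eye(64)))         # True
print(np.linalg.svd(W_hh, compute_uv=False)[[0, -1]])  # largest and smallest, both ~1
```

In PyTorch you would call `nn.init.orthogonal_` directly on the recurrent weight, e.g. `nn.init.orthogonal_(rnn.weight_hh_l0)` for a built-in `nn.RNN`, or `nn.init.orthogonal_(model.W_hh.weight)` for the from-scratch class below.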
PyTorch Implementation
📝Example: Vanilla RNN from Scratch
```python
import torch
import torch.nn as nn


class VanillaRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_xh = nn.Linear(input_size, hidden_size)               # W_xh x_t + b_h
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)  # W_hh h_{t-1}
        self.W_hy = nn.Linear(hidden_size, output_size)              # W_hy h_t + b_y

    def forward(self, x, h_prev=None):
        # x shape: (batch, seq_len, input_size)
        batch_size, seq_len, _ = x.shape
        if h_prev is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)
        outputs = []
        h = h_prev
        for t in range(seq_len):
            # h_t = tanh(W_xh x_t + b_h + W_hh h_{t-1}); same weights at every step
            h = torch.tanh(self.W_xh(x[:, t, :]) + self.W_hh(h))
            y = self.W_hy(h)
            outputs.append(y)
        outputs = torch.stack(outputs, dim=1)  # (batch, seq_len, output_size)
        return outputs, h


# Test
model = VanillaRNN(input_size=10, hidden_size=64, output_size=5)
x = torch.randn(32, 20, 10)  # batch=32, seq_len=20, features=10
outputs, h_final = model(x)
print(f"Output shape: {outputs.shape}")  # [32, 20, 5]
print(f"Final hidden state: {h_final.shape}")  # [32, 64]
```
📝Example: Gradient Monitoring and Clipping
```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=64, num_layers=2, batch_first=True)
head = nn.Linear(64, 5)  # maps each 64-dim hidden state to 5 class logits
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
criterion = nn.CrossEntropyLoss()

# Monitor gradient norms across layers
def monitor_gradients(module):
    for name, param in module.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            print(f"{name}: grad_norm = {grad_norm:.6f}")

# Training with gradient clipping
x = torch.randn(32, 50, 10)  # Longer sequence (50 steps) makes vanishing/exploding more visible
target = torch.randint(0, 5, (32, 50))
for epoch in range(10):
    output, _ = rnn(x)     # (batch, seq_len, hidden_size) = (32, 50, 64)
    logits = head(output)  # (32, 50, 5)
    loss = criterion(logits.reshape(-1, 5), target.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    print(f"\nEpoch {epoch}: Loss={loss.item():.4f}")
    monitor_gradients(rnn)  # raw (unclipped) per-parameter gradient norms
    # Clip gradients to prevent explosion; returns the total norm before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    print(f"Total grad norm before clipping: {float(total_norm):.4f}")
    optimizer.step()
```
⚠️ Truncated BPTT
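For very long sequences, full BPTT is computationally expensive. Truncated BPTT splits the sequence into chunks, runs BPTT within each chunk, and detaches the hidden state between chunks so that gradients do not flow across chunk boundaries. The hidden state still carries information forward, but gradients (and therefore the dependencies the model can learn from them) span at most one chunk, in exchange for lower computation and memory.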
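A minimal PyTorch sketch of truncated BPTT, assuming random data, a 200-step sequence, and a chunk length of 50 (all illustrative choices, not recommendations). The key line is `h = h.detach()`: the hidden state's value is carried into the next chunk, but the autograd graph is cut, so each `backward()` reaches back at most `chunk_len` steps. This is the simple variant where the update interval and the truncation length are the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=10, hidden_size=32, batch_first=True)
head = nn.Linear(32, 5)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 200, 10)             # one long sequence per batch element
target = torch.randint(0, 5, (8, 200))
chunk_len = 50                          # gradients flow through at most this many steps

h = None                                # nn.RNN starts from zeros when h is None
losses = []
for start in range(0, x.size(1), chunk_len):
    x_chunk = x[:, start:start + chunk_len]
    y_chunk = target[:, start:start + chunk_len]
    out, h = rnn(x_chunk, h)            # carry the hidden state forward
    loss = criterion(head(out).reshape(-1, 5), y_chunk.reshape(-1))
    optimizer.zero_grad()
    loss.backward()                     # backprop only within this chunk
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    optimizer.step()
    h = h.detach()                      # keep the value, drop the graph to earlier chunks
    losses.append(loss.item())

print(len(losses), [round(v, 3) for v in losses])
```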
Practical Applications of Vanilla RNNs
Definition: When Vanilla RNNs Work
Despite their limitations, vanilla RNNs are suitable for:
- Short sequences: Tasks with dependencies spanning < 20 time steps
- Simple patterns: When the temporal structure is simple and predictable
- Character-level modeling: Predicting the next character in a short text
- Time series with short-term dependencies: Simple forecasting tasks
- Educational purposes: Understanding the foundations before moving to LSTM/GRU
For most real-world sequence tasks, LSTM or GRU networks are preferred because they handle long-range dependencies much better.
ℹ️ RNN vs. Transformer for Sequences
Modern sequence modeling has shifted from RNNs to Transformers (self-attention). Transformers process all positions in parallel, avoiding the sequential bottleneck of RNNs. However, RNNs still have advantages:
- Streaming data: RNNs process one token at a time with a fixed-size hidden state; Transformers attend over the whole context, so per-token cost grows with context length
- Variable-length sequences: RNNs naturally handle variable lengths; Transformers need padding/masking
- Edge devices: RNNs have lower memory footprint for inference
- Time complexity: RNNs cost $O(1)$ per step (independent of sequence length, $O(T)$ in total); self-attention costs $O(T^2)$ for a length-$T$ sequence
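The streaming advantage can be seen directly with PyTorch's `nn.RNN`. The sketch below (random weights and inputs, purely illustrative) feeds the sequence one step at a time, keeping only the hidden state `h` between calls, and checks that the outputs match processing the whole sequence at once. The state kept between steps has a fixed size no matter how many steps have been processed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(1, 6, 4)  # a 6-step sequence

with torch.no_grad():
    full_out, _ = rnn(x)  # whole sequence at once

    h = None              # streaming: one step at a time
    stream_out = []
    for t in range(x.size(1)):
        out_t, h = rnn(x[:, t:t + 1], h)  # h is the only state carried between steps
        stream_out.append(out_t)
    stream_out = torch.cat(stream_out, dim=1)

print(torch.allclose(full_out, stream_out, atol=1e-6))  # True
print(h.shape)                                           # torch.Size([1, 1, 8])
```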
Summary
📋Summary: RNN Deep Dive
- Vanilla RNN: $h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$
- BPTT: Unrolls RNN through time, applies standard backpropagation
- Vanishing gradients: Product of Jacobians shrinks exponentially, which in practice limits memory to roughly 10-20 steps
- Exploding gradients: Product grows exponentially; use gradient clipping
- Root cause: Repeated multiplication by $W_{hh}$ when its spectral norm is not 1 (compounded by tanh derivatives ≤ 1)
- Solutions: Orthogonal initialization, LSTM/GRU gates, skip connections
- Truncated BPTT: Limits computation by processing sequences in chunks
- Weight sharing: Same parameters across time steps enables variable-length processing
Practice Exercises
- Mathematical: For a vanilla RNN with a given largest singular value $\sigma_{\max}$ of $W_{hh}$ and maximum activation derivative $\gamma$, compute the gradient magnitude bound after 50 time steps. Does it vanish or explode?
- Experiment: Train a vanilla RNN on a long-range dependency task (e.g., copy the last 50 elements of a 100-element sequence). What happens as sequence length increases?
- Coding: Implement gradient monitoring hooks in PyTorch to visualize gradient norms across time steps during BPTT. Plot the gradient magnitude as a function of time step.
- Research: Look up the "gradient highway" in orthogonal RNNs. How does orthogonal initialization create stable gradient paths?
- Application: Train a vanilla RNN for text generation on a small corpus. Does it capture long-range dependencies (e.g., subject-verb agreement across sentences)?