RNN Deep Dive — Vanilla RNN, BPTT & Vanishing/Exploding Gradients


Recurrent Neural Networks process sequential data by maintaining a hidden state that carries information across time steps. This tutorial covers vanilla RNNs and their fundamental limitations.

See our RNN/LSTM tutorial for a general overview of recurrent architectures including GRU and LSTM.


Why Recurrence?

Definition: Sequential Data

Many data types are inherently sequential:

  • Text: Words follow a specific order
  • Speech: Audio signals are time-series
  • Video: Frames have temporal relationships
  • Time series: Stock prices, sensor readings

Feedforward networks ignore temporal structure. RNNs explicitly model it by maintaining a hidden state that evolves over time.


Vanilla RNN

Definition: Vanilla RNN

The RNN recurrence relation:

$$\mathbf{h}_t = \tanh(\mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{W}_{xh} \mathbf{x}_t + \mathbf{b})$$
$$\mathbf{y}_t = \mathbf{W}_{hy} \mathbf{h}_t$$

At each time step:

  1. Take the previous hidden state $\mathbf{h}_{t-1}$
  2. Combine it with the current input $\mathbf{x}_t$
  3. Apply the tanh activation to obtain $\mathbf{h}_t$
  4. Output the prediction $\mathbf{y}_t$

The same parameters ($\mathbf{W}_{hh}, \mathbf{W}_{xh}, \mathbf{W}_{hy}, \mathbf{b}$) are shared across all time steps. This weight sharing keeps the parameter count independent of sequence length and lets the model process variable-length sequences.
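
As a rough illustration of weight sharing, the sketch below applies one `nn.RNN` layer to sequences of different lengths and shows that the parameter count does not change. The sizes are arbitrary choices for the example, and note that PyTorch's `nn.RNN` keeps two bias vectors (`bias_ih`, `bias_hh`) where the formula above uses a single $\mathbf{b}$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One recurrent layer; sizes are illustrative assumptions
rnn = nn.RNN(input_size=10, hidden_size=64, batch_first=True)

# weight_ih (64x10) + weight_hh (64x64) + bias_ih (64) + bias_hh (64)
n_params = sum(p.numel() for p in rnn.parameters())

shapes = []
for seq_len in (5, 50, 500):
    out, h = rnn(torch.randn(2, seq_len, 10))  # same weights, any length
    shapes.append(tuple(out.shape))
    print(f"seq_len={seq_len}: output {tuple(out.shape)}, parameters {n_params}")
```

The output grows with the sequence length while the number of parameters stays fixed.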


Backpropagation Through Time (BPTT)

Definition: BPTT

BPTT unrolls the RNN through time and applies standard backpropagation:

  1. Forward pass: Compute $\mathbf{h}_1, \mathbf{h}_2, \ldots, \mathbf{h}_T$ sequentially
  2. Compute loss: $\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t(\mathbf{y}_t, \hat{\mathbf{y}}_t)$
  3. Backward pass: Propagate gradients from $t = T$ back to $t = 1$ through the unrolled recurrence

The gradient of the loss with respect to $\mathbf{h}_t$ depends on gradients from all future time steps.

BPTT Gradient Computation

$$\frac{\partial \mathcal{L}}{\partial \mathbf{h}_t} = \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t} + \frac{\partial \mathcal{L}}{\partial \mathbf{h}_{t+1}} \cdot \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t}$$

Here,

  • $\mathbf{h}_t$ = Hidden state at time $t$
  • $\frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t}$ = Direct gradient from the loss at time $t$
  • $\frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t}$ = Jacobian of the next hidden state w.r.t. the current one
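
The recursion above can be checked numerically. The following is a minimal sketch, not a full BPTT implementation: it assumes a toy per-step loss $\mathcal{L}_t = \tfrac{1}{2}\|\mathbf{h}_t\|^2$ applied directly to the hidden state (so the direct gradient is simply $\mathbf{h}_t$), random weights, and small illustrative sizes. It runs the backward recursion by hand and compares the result with PyTorch autograd.

```python
import torch

torch.manual_seed(0)
hidden, inp, T = 4, 3, 6  # illustrative sizes
W_hh = torch.randn(hidden, hidden) * 0.5
W_xh = torch.randn(hidden, inp) * 0.5
b = torch.zeros(hidden)
xs = torch.randn(T, inp)

# Forward pass; retain_grad lets autograd report dL/dh_t for every step
h = torch.zeros(hidden, requires_grad=True)
states = []
loss = torch.tensor(0.0)
for t in range(T):
    h = torch.tanh(W_hh @ h + W_xh @ xs[t] + b)
    h.retain_grad()
    states.append(h)
    loss = loss + 0.5 * (h ** 2).sum()  # assumed toy loss L_t
loss.backward()

# Manual recursion: dL/dh_t = dL_t/dh_t + J_{t+1}^T dL/dh_{t+1}
manual = [None] * T
carried = torch.zeros(hidden)  # nothing flows into the last step
with torch.no_grad():
    for t in reversed(range(T)):
        manual[t] = states[t] + carried  # direct term is h_t for the toy loss
        # Jacobian dh_t/dh_{t-1} = diag(1 - h_t^2) W_hh, since tanh' = 1 - tanh^2
        J = torch.diag(1 - states[t] ** 2) @ W_hh
        carried = J.T @ manual[t]

max_err = max((manual[t] - states[t].grad).abs().max().item() for t in range(T))
print(f"max abs difference vs autograd: {max_err:.2e}")
```

The transpose appears because the code stores gradients as column vectors, while the formula writes the gradient as a row vector multiplied by the Jacobian.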

Vanishing and Exploding Gradients

Theorem: Vanishing Gradient in RNNs

The Jacobian of $\mathbf{h}_T$ with respect to $\mathbf{h}_t$ is:

$$\frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_t} = \prod_{k=t}^{T-1} \frac{\partial \mathbf{h}_{k+1}}{\partial \mathbf{h}_k} = \prod_{k=t}^{T-1} \text{diag}(\tanh'(\mathbf{z}_{k+1})) \cdot \mathbf{W}_{hh}$$

Here $\mathbf{z}_{k+1} = \mathbf{W}_{hh} \mathbf{h}_k + \mathbf{W}_{xh} \mathbf{x}_{k+1} + \mathbf{b}$ is the pre-activation. Since $0 < \tanh'(x) \leq 1$, the product shrinks exponentially whenever $\|\mathbf{W}_{hh}\| < 1$. More generally, the norm of the gradient is bounded by $(\lambda_{\max})^{T-t}$, where $\lambda_{\max}$ bounds the largest singular value of the per-step Jacobian $\text{diag}(\tanh'(\mathbf{z}_{k+1})) \cdot \mathbf{W}_{hh}$.

Gradient Magnitude Through Time

$$\left\|\frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_t}\right\| \leq \prod_{k=t}^{T-1} \|\text{diag}(\tanh'(\mathbf{z}_{k+1}))\| \cdot \|\mathbf{W}_{hh}\| \leq \lambda_{\max}^{T-t}$$

Here,

  • $\lambda_{\max}$ = Largest singular value of the recurrent weight matrix times the largest tanh derivative along the trajectory (an upper bound on the norm of each per-step Jacobian)
  • $T - t$ = Number of time steps between gradient source and target
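
To make the bound concrete, here is a small sketch that multiplies per-step Jacobians along a simulated trajectory and compares the resulting norm with $\lambda_{\max}^{T-t}$. The setup is an assumption for illustration: $\mathbf{W}_{hh}$ is a random matrix rescaled to spectral norm 0.9, and the input term $\mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b}$ is replaced by small random noise.

```python
import torch

torch.manual_seed(0)
hidden, T = 32, 60  # illustrative sizes

# Rescale a random matrix so its spectral norm (largest singular value) is 0.9
W = torch.randn(hidden, hidden)
W_hh = 0.9 * W / torch.linalg.matrix_norm(W, ord=2)

h = torch.zeros(hidden)
prod = torch.eye(hidden)  # running product of per-step Jacobians
norms, bounds = [], []
for k in range(T):
    # Random drive stands in for W_xh x_t + b (assumption)
    h = torch.tanh(W_hh @ h + 0.1 * torch.randn(hidden))
    J = torch.diag(1 - h ** 2) @ W_hh  # diag(tanh'(z)) W_hh
    prod = J @ prod
    norms.append(torch.linalg.matrix_norm(prod, ord=2).item())
    bounds.append(0.9 ** (k + 1))  # tanh' <= 1, so lambda_max <= 0.9

for k in (0, 9, 29, 59):
    print(f"steps={k + 1:2d}  norm={norms[k]:.3e}  bound={bounds[k]:.3e}")
```

The measured norm typically falls well below the bound, because the tanh derivative is strictly less than 1 away from zero and most singular values of $\mathbf{W}_{hh}$ are smaller than the largest one.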

Theorem: Gradient Explosion in RNNs

When the per-step Jacobians have singular values above 1 ($\lambda_{\max} > 1$), gradients can grow exponentially. For a per-step growth factor of $\lambda_{\max} = 1.1$ and $T - t = 100$, the gradient is scaled by $1.1^{100} \approx 13{,}780$. Exploding gradients cause:

  • Numerical overflow: Loss becomes NaN
  • Unstable training: Weight updates are too large

The standard remedy is gradient clipping (clip gradients by norm or by value).
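
The arithmetic above, and the idea behind clipping by norm, can be sketched in plain Python. The helper `clip_by_norm` is a hypothetical name written for this example. It mirrors the usual rescaling rule (scale the gradient so its L2 norm does not exceed a threshold) and is not a library function.

```python
# Growth or decay of lambda_max ** (T - t) over 100 steps
steps = 100
factors = {lam: lam ** steps for lam in (0.9, 1.0, 1.1)}
for lam, value in factors.items():
    print(f"lambda_max={lam}: factor after {steps} steps = {value:.3e}")

def clip_by_norm(grad, max_norm, eps=1e-6):
    """Rescale a gradient (list of floats) so its L2 norm is at most max_norm."""
    norm = sum(g * g for g in grad) ** 0.5
    scale = min(1.0, max_norm / (norm + eps))
    return [g * scale for g in grad]

# A gradient with norm 5 is rescaled to norm ~1; its direction is unchanged
clipped = clip_by_norm([3.0, 4.0], max_norm=1.0)
clipped_norm = sum(g * g for g in clipped) ** 0.5
print(clipped, clipped_norm)
```

Clipping by norm preserves the gradient direction, whereas clipping by value clamps each component independently and can change it.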

ℹ️ Practical Impact

  • In practice, vanilla RNNs typically learn dependencies spanning only about 10-20 time steps
  • Most real-world sequences are much longer (sentences, documents, videos)
  • This fundamental limitation motivated the invention of LSTM and GRU
  • Even with gradient clipping, vanilla RNNs struggle with long-range dependencies

Mathematical Analysis

Definition: Singular Value Analysis

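The stability of RNN training depends on the spectral properties of $\mathbf{W}_{hh}$:

  • If $\|\mathbf{W}_{hh}\|_2 < 1$ (spectral norm): Gradients vanish, because every per-step Jacobian then has norm below 1
  • If $\|\mathbf{W}_{hh}\|_2 > 1$: Gradients can explode, although saturated tanh units (small $\tanh'$) may offset the growth
  • If $\|\mathbf{W}_{hh}\|_2 = 1$: Multiplication by $\mathbf{W}_{hh}$ alone neither shrinks nor amplifies gradients (the motivation for orthogonal initialization), though $\tanh' < 1$ can still shrink them

The spectral norm $\|\mathbf{W}\|_2$ is the largest singular value of $\mathbf{W}$.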

💡 Orthogonal Initialization for RNNs

Initialize $\mathbf{W}_{hh}$ as an orthogonal matrix ($\mathbf{W}^T \mathbf{W} = \mathbf{I}$). All of its singular values are then 1, so multiplication by $\mathbf{W}_{hh}$ neither shrinks nor amplifies gradients at the start of training, which helps against both vanishing and exploding gradients. PyTorch: nn.init.orthogonal_(weight).
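
A quick way to see the effect is to inspect the singular values before and after initialization. This sketch assumes a single-layer `nn.RNN`, whose recurrent matrix PyTorch exposes as `weight_hh_l0`. The layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=10, hidden_size=64, batch_first=True)

# Singular values under PyTorch's default (uniform) initialization
before = torch.linalg.svdvals(rnn.weight_hh_l0.detach())

# Re-initialize only the recurrent matrix as an orthogonal matrix
nn.init.orthogonal_(rnn.weight_hh_l0)
after = torch.linalg.svdvals(rnn.weight_hh_l0.detach())

print(f"default:    min={before.min().item():.3f}  max={before.max().item():.3f}")
print(f"orthogonal: min={after.min().item():.3f}  max={after.max().item():.3f}")
```

This property holds only at initialization. Gradient updates move the matrix away from orthogonality unless it is explicitly constrained.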


PyTorch Implementation

📝Example: Vanilla RNN from Scratch

import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_xh = nn.Linear(input_size, hidden_size)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_hy = nn.Linear(hidden_size, output_size)

    def forward(self, x, h_prev=None):
        # x shape: (batch, seq_len, input_size)
        batch_size, seq_len, _ = x.shape

        if h_prev is None:
            h_prev = torch.zeros(batch_size, self.hidden_size, device=x.device)

        outputs = []
        h = h_prev

        for t in range(seq_len):
            h = torch.tanh(self.W_xh(x[:, t, :]) + self.W_hh(h))
            y = self.W_hy(h)
            outputs.append(y)

        outputs = torch.stack(outputs, dim=1)
        return outputs, h

# Test
model = VanillaRNN(input_size=10, hidden_size=64, output_size=5)
x = torch.randn(32, 20, 10)  # batch=32, seq_len=20, features=10
outputs, h_final = model(x)
print(f"Output shape: {outputs.shape}")  # [32, 20, 5]
print(f"Final hidden state: {h_final.shape}")  # [32, 64]

📝Example: Gradient Monitoring and Clipping

import torch
import torch.nn as nn

# nn.RNN returns hidden states of size 64, so a linear head maps them to 5 classes
model = nn.RNN(input_size=10, hidden_size=64, num_layers=2, batch_first=True)
head = nn.Linear(64, 5)
params = list(model.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
criterion = nn.CrossEntropyLoss()

# Print the gradient norm of every parameter tensor (one entry per layer and weight)
def monitor_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            print(f"{name}: grad_norm = {grad_norm:.6f}")

# Training with gradient clipping
x = torch.randn(32, 50, 10)  # batch=32, seq_len=50, features=10
target = torch.randint(0, 5, (32, 50))  # one class label per time step

for epoch in range(10):
    output, _ = model(x)    # (32, 50, 64)
    logits = head(output)   # (32, 50, 5)
    loss = criterion(logits.reshape(-1, 5), target.reshape(-1))

    optimizer.zero_grad()
    loss.backward()

    # Report the raw gradient norms before they are clipped
    print(f"\nEpoch {epoch}: Loss={loss.item():.4f}")
    monitor_gradients(model)

    # Clip the global gradient norm to prevent explosion
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

    optimizer.step()

⚠️ Truncated BPTT

For very long sequences, full BPTT is computationally expensive. Truncated BPTT splits the sequence into chunks and runs BPTT within each chunk, detaching gradients between chunks. This limits the effective memory to the chunk size but reduces computation.
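
The chunk-and-detach pattern might look like the sketch below. The data, the regression head, and the chunk length of 25 are all assumptions chosen to keep the example short. The key line is `h.detach()`, which carries the hidden state's value into the next chunk while cutting the computation graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=10, hidden_size=32, batch_first=True)
head = nn.Linear(32, 1)  # illustrative regression head
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

x = torch.randn(8, 200, 10)  # batch=8, seq_len=200 (synthetic data)
y = torch.randn(8, 200, 1)   # synthetic targets
chunk_len = 25               # assumed truncation length

h = None  # nn.RNN treats None as a zero initial state
losses = []
for start in range(0, x.size(1), chunk_len):
    x_chunk = x[:, start:start + chunk_len]
    y_chunk = y[:, start:start + chunk_len]

    out, h = rnn(x_chunk, h)
    loss = nn.functional.mse_loss(head(out), y_chunk)

    optimizer.zero_grad()
    loss.backward()  # gradients flow back at most chunk_len steps
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    optimizer.step()

    h = h.detach()  # keep the state value, drop its gradient history
    losses.append(loss.item())

print(f"{len(losses)} chunks, last loss {losses[-1]:.4f}")
```

Without the detach, the second call to `backward()` would try to propagate through the first chunk's graph, which has already been freed.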


Practical Applications of Vanilla RNNs

Definition: When Vanilla RNNs Work

Despite their limitations, vanilla RNNs are suitable for:

  1. Short sequences: Tasks with dependencies spanning < 20 time steps
  2. Simple patterns: When the temporal structure is simple and predictable
  3. Character-level modeling: Predicting the next character in a short text
  4. Time series with short-term dependencies: Simple forecasting tasks
  5. Educational purposes: Understanding the foundations before moving to LSTM/GRU

For most real-world sequence tasks, LSTM or GRU networks are preferred because they handle long-range dependencies much better.

ℹ️ RNN vs. Transformer for Sequences

Modern sequence modeling has shifted from RNNs to Transformers (self-attention). Transformers process all positions in parallel, avoiding the sequential bottleneck of RNNs. However, RNNs still have advantages:

  • Streaming data: RNNs process one token at a time with a fixed-size state; Transformers must attend over (or cache) all previous tokens
  • Variable-length sequences: RNNs handle variable lengths step by step; Transformers typically need padding/masking within a batch
  • Edge devices: RNNs have a lower memory footprint for inference
  • Time complexity: RNNs cost $O(1)$ per step and $O(T)$ for the full sequence; Transformer self-attention costs $O(T^2)$ for the full sequence
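
The streaming point can be illustrated with a short sketch: feeding tokens to an `nn.RNN` one at a time, while carrying the hidden state forward, should reproduce the outputs of a single full-sequence pass. The sizes are arbitrary, and the comparison uses a small tolerance because the two code paths may differ by floating-point rounding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=10, hidden_size=16, batch_first=True)
x = torch.randn(1, 30, 10)  # one sequence of 30 tokens

with torch.no_grad():
    # Offline: process the whole sequence at once
    full_out, _ = rnn(x)

    # Streaming: one token per call; only the hidden state h is kept
    h = None
    pieces = []
    for t in range(x.size(1)):
        out_t, h = rnn(x[:, t:t + 1], h)
        pieces.append(out_t)
    stream_out = torch.cat(pieces, dim=1)

matches = torch.allclose(full_out, stream_out, atol=1e-5)
print(f"streaming matches full pass: {matches}")
```

The memory held between calls is just the hidden state, whose size does not depend on how many tokens have been seen.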

Summary

📋Summary: RNN Deep Dive

  • Vanilla RNN: $\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b})$
  • BPTT: Unrolls RNN through time, applies standard backpropagation
  • Vanishing gradients: Product of Jacobians shrinks exponentially, which in practice limits memory to roughly 10-20 steps
  • Exploding gradients: Product of Jacobians grows exponentially; use gradient clipping
  • Root cause: Repeated multiplication by per-step Jacobians whose singular values differ from 1
  • Solutions: Orthogonal initialization, LSTM/GRU gates, skip connections
  • Truncated BPTT: Limits computation by processing sequences in chunks
  • Weight sharing: Same parameters across time steps enables variable-length processing

Practice Exercises

  1. Mathematical: For a vanilla RNN with $\mathbf{W}_{hh} = \text{diag}(0.9, 1.1)$ and $\tanh'(0) = 1$, compute the gradient magnitude after 50 time steps. Does it vanish or explode?

  2. Experiment: Train a vanilla RNN on a long-range dependency task (e.g., copy last 50 elements of a 100-element sequence). What happens as sequence length increases?

  3. Coding: Implement gradient monitoring hooks in PyTorch to visualize gradient norms across time steps during BPTT. Plot the gradient magnitude as a function of time step.

  4. Research: Look up the "gradient highway" in orthogonal RNNs. How does orthogonal initialization create stable gradient paths?

  5. Application: Train a vanilla RNN for text generation on a small corpus. Does it capture long-range dependencies (e.g., subject-verb agreement across sentences)?
