RNN, LSTM & GRU — Complete Guide
Recurrent networks process sequential data by maintaining a hidden state that carries information across time steps.
Vanilla RNN
h_t = tanh(W_hh × h_{t-1} + W_xh × x_t + b)
At each time step:
1. Take previous hidden state h_{t-1}
2. Combine with current input x_t
3. Apply the tanh activation to get the new hidden state h_t
4. Optionally compute a prediction from h_t
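The steps above can be sketched in plain Python. This is a minimal illustration rather than a reference implementation: the helper names (matvec, rnn_step, rnn_forward) and the toy weights are assumptions made for the example, and vectors are ordinary lists.

```python
import math

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_step(h_prev, x_t, W_hh, W_xh, b):
    """One time step: h_t = tanh(W_hh h_{t-1} + W_xh x_t + b)."""
    rec = matvec(W_hh, h_prev)   # contribution of the previous hidden state
    inp = matvec(W_xh, x_t)      # contribution of the current input
    return [math.tanh(r + i + bias) for r, i, bias in zip(rec, inp, b)]

def rnn_forward(xs, h0, W_hh, W_xh, b):
    """Run the cell over a whole sequence and return every hidden state."""
    h, states = h0, []
    for x_t in xs:
        h = rnn_step(h, x_t, W_hh, W_xh, b)
        states.append(h)
    return states

# Toy sizes chosen for illustration: hidden size 2, input size 3
W_hh = [[0.5, -0.2], [0.1, 0.3]]
W_xh = [[0.4, 0.0, -0.1], [0.2, 0.3, 0.0]]
b = [0.0, 0.1]
xs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

states = rnn_forward(xs, [0.0, 0.0], W_hh, W_xh, b)
for t, h in enumerate(states, start=1):
    print(t, [round(v, 4) for v in h])
```

The same weights are reused at every step; only the hidden state changes, which is what lets the network handle sequences of any length.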
Problem: VANISHING GRADIENTS
├─ Gradients shrink exponentially as they are backpropagated through time
├─ Long-range dependencies become very hard to learn
└─ In practice, useful memory is often limited to roughly 10-20 time steps
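To see why the gradient shrinks, consider a scalar RNN h_t = tanh(w·h_{t-1} + u·x_t). By the chain rule, each step multiplies the gradient by w·(1 - h_t²). The sketch below tracks that product; the function name and the values w = 0.5, u = 1.0 are assumptions picked for illustration.

```python
import math

def gradient_through_time(w, u, xs, h0=0.0):
    """Return |d h_t / d h_0| after each step of a scalar tanh RNN."""
    h, grad, history = h0, 1.0, []
    for x_t in xs:
        h = math.tanh(w * h + u * x_t)
        grad *= w * (1.0 - h * h)   # chain rule factor d h_t / d h_{t-1}
        history.append(abs(grad))
    return history

history = gradient_through_time(w=0.5, u=1.0, xs=[0.5] * 50)
for t in (1, 5, 10, 20, 50):
    print(f"after {t:2d} steps: {history[t - 1]:.3e}")
```

Because every factor here has magnitude at most 0.5, the product decays at least as fast as 0.5^t. With a recurrent weight whose magnitude is well above 1, the same product can grow instead, which is the exploding-gradient counterpart.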
LSTM (Long Short-Term Memory)
LSTM Cell:
Forget gate: f = σ(W_f·[h_{t-1}, x_t])  → what to discard
Input gate: i = σ(W_i·[h_{t-1}, x_t])  → what to store
Candidate: c̃ = tanh(W_c·[h_{t-1}, x_t])  → new candidate values
Output gate: o = σ(W_o·[h_{t-1}, x_t])  → what to output
Cell state: c_t = f ⊙ c_{t-1} + i ⊙ c̃  (the information highway)
Hidden state: h_t = o ⊙ tanh(c_t)
Each gate has its own weight matrix; biases are omitted for brevity.
Three gates control information flow:
Forget: What to throw away
Input: What new information to store
Output: What to output based on cell state
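A single LSTM step can be sketched in plain Python as follows. This is an illustrative sketch under stated assumptions: the names lstm_step, affine, and the per-gate matrices W_f, W_i, W_c, W_o are labels chosen here (one matrix per gate), biases are left out, and the final line uses the standard LSTM hidden-state update h_t = o ⊙ tanh(c_t).

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def affine(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def lstm_step(h_prev, c_prev, x_t, W_f, W_i, W_c, W_o):
    """One LSTM step without biases; returns (h_t, c_t)."""
    concat = h_prev + x_t                                  # [h_{t-1}, x_t]
    f = [sigmoid(a) for a in affine(W_f, concat)]          # forget gate
    i = [sigmoid(a) for a in affine(W_i, concat)]          # input gate
    c_tilde = [math.tanh(a) for a in affine(W_c, concat)]  # candidate values
    o = [sigmoid(a) for a in affine(W_o, concat)]          # output gate
    # Keep part of the old cell state, add part of the candidate
    c_t = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c_prev, i, c_tilde)]
    # Expose a gated, squashed view of the cell state
    h_t = [oj * math.tanh(cj) for oj, cj in zip(o, c_t)]
    return h_t, c_t

# Toy sizes chosen for illustration: hidden size 2, input size 1
W_f = [[0.1, 0.2, 0.3], [0.0, -0.1, 0.2]]
W_i = [[0.3, 0.0, 0.1], [0.2, 0.1, -0.3]]
W_c = [[0.5, -0.4, 0.2], [0.1, 0.3, 0.4]]
W_o = [[-0.2, 0.1, 0.0], [0.4, 0.2, 0.1]]

h_t, c_t = lstm_step([0.1, -0.2], [1.0, -1.0], [0.5], W_f, W_i, W_c, W_o)
print(h_t, c_t)
```

The cell state update is additive: when the forget gate stays close to 1, c_{t-1} passes through nearly unchanged, which gives gradients a path that does not shrink at every step.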
GRU (Gated Recurrent Unit)
GRU simplifies LSTM:
├─ Two gates (reset, update) instead of three
├─ No separate cell state
├─ Fewer parameters
└─ Often performs as well as LSTM
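The "fewer parameters" point can be made concrete with a rough count. The sketch below assumes each gate or candidate block has one weight matrix over [h_{t-1}, x_t] plus one bias vector; the function names and the sizes 128 and 256 are illustrative. Library implementations may count slightly differently (PyTorch, for example, keeps two bias vectors per block).

```python
def block_params(input_dim, hidden_dim):
    """Weights over [h_{t-1}, x_t] plus one bias vector for a single block."""
    return hidden_dim * (hidden_dim + input_dim) + hidden_dim

def lstm_params(input_dim, hidden_dim):
    # forget, input, and output gates plus the cell candidate
    return 4 * block_params(input_dim, hidden_dim)

def gru_params(input_dim, hidden_dim):
    # update and reset gates plus the hidden candidate
    return 3 * block_params(input_dim, hidden_dim)

lstm_count = lstm_params(128, 256)   # 394240
gru_count = gru_params(128, 256)     # 295680
print(lstm_count, gru_count, gru_count / lstm_count)
```

Under these assumptions a GRU layer has three quarters of the parameters of an LSTM layer of the same size.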
Update gate: z = σ(W_z·[h_{t-1}, x_t])
Reset gate: r = σ(W_r·[h_{t-1}, x_t])
Candidate: h̃ = tanh(W·[r ⊙ h_{t-1}, x_t])
Output: h_t = (1-z) ⊙ h_{t-1} + z ⊙ h̃
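The four GRU equations above translate almost line for line into plain Python. This is a minimal sketch, assuming no bias terms (as in the equations) and using illustrative names (gru_step, affine) and toy weights.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def affine(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def gru_step(h_prev, x_t, W_z, W_r, W):
    """One GRU step without biases; returns h_t."""
    concat = h_prev + x_t                              # [h_{t-1}, x_t]
    z = [sigmoid(a) for a in affine(W_z, concat)]      # update gate
    r = [sigmoid(a) for a in affine(W_r, concat)]      # reset gate
    gated = [rj * hj for rj, hj in zip(r, h_prev)]     # r ⊙ h_{t-1}
    h_tilde = [math.tanh(a) for a in affine(W, gated + x_t)]
    # Interpolate between the old state and the candidate
    return [(1.0 - zj) * hj + zj * gj
            for zj, hj, gj in zip(z, h_prev, h_tilde)]

# Toy sizes chosen for illustration: hidden size 2, input size 1
W_z = [[0.2, -0.1, 0.4], [0.0, 0.3, -0.2]]
W_r = [[0.1, 0.1, 0.1], [-0.3, 0.2, 0.5]]
W = [[0.6, -0.5, 0.3], [0.2, 0.4, -0.1]]

h_t = gru_step([0.4, -0.6], [1.0], W_z, W_r, W)
print(h_t)
```

When z is close to 0 the previous state is copied forward almost unchanged, and when z is close to 1 it is overwritten by the candidate. A single gate therefore plays the role that the forget and input gates share in an LSTM.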
PyTorch Implementation
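import torch.nn as nn

class LSTMModel(nn.Module):
    """Sequence classifier: embedding -> 2-layer LSTM -> linear head."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # dropout is applied between the stacked LSTM layers, not after the last one
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=2,
                            batch_first=True, dropout=0.3)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x: (batch, seq_len) tensor of integer token ids
        embedded = self.embedding(x)                  # (batch, seq_len, embed_dim)
        output, (hidden, cell) = self.lstm(embedded)  # hidden: (num_layers, batch, hidden_dim)
        # Classify from the top layer's final hidden state
        return self.fc(hidden[-1])                    # (batch, output_dim)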
Key Takeaways
- RNNs process sequential data by carrying a hidden state across time steps
- LSTM mitigates vanishing gradients with a gated, additive cell state
- GRU is simpler and often performs equally well
- Bidirectional RNNs process sequences in both directions
- Seq2Seq (encoder-decoder) models are used for tasks such as translation
- Transformers have replaced LSTMs for most NLP tasks
- Teacher forcing (feeding the ground-truth previous token to the decoder) speeds up seq2seq training
- LSTMs remain useful for time series and edge devices