GRU Networks — Gated Recurrent Units
Gated recurrent units (GRUs) are recurrent cells that mitigate the vanishing gradient problem with a simpler architecture than LSTMs. They use only two gates and keep no separate cell state.
Why GRU?
Vanilla RNNs suffer from vanishing gradients, which makes long-range dependencies very hard to learn. GRUs introduce gating mechanisms to control information flow across time steps.
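To make the vanishing-gradient claim concrete, here is a minimal sketch using a scalar RNN $h_t = \tanh(w \, h_{t-1})$ with no input. The gradient of the final state with respect to the initial state is a product of per-step derivatives, and each factor has magnitude below $|w|$. The function name `gradient_through_time` and the values `w=0.9`, `h0=0.5` are illustrative assumptions, not part of the original material.

```python
import math

def gradient_through_time(w, h0, steps):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1})."""
    h, grad = h0, 1.0
    for _ in range(steps):
        h = math.tanh(w * h)
        # Chain rule: d h_t / d h_{t-1} = w * (1 - tanh(.)^2) = w * (1 - h_t^2)
        grad *= w * (1.0 - h * h)
    return grad

# The gradient shrinks at least geometrically as the number of steps grows.
for steps in (1, 10, 50):
    print(steps, gradient_through_time(w=0.9, h0=0.5, steps=steps))
```

With $|w| < 1$ every factor is smaller than 1, so the signal from early time steps fades at least as fast as $|w|^T$. Gating gives the network a path along which the old state can be carried forward largely unchanged.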
Definition: Gated Recurrent Unit
A GRU maintains a hidden state $h_t$ that is updated at each time step using two gates:
- Update gate $z_t$: Controls how much of the previous hidden state to keep
- Reset gate $r_t$: Controls how much of the previous hidden state to forget when computing the candidate
The hidden state is a convex interpolation between the old state $h_{t-1}$ and a new candidate $\tilde{h}_t$, governed by the update gate: $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$.
Mathematical Formulation
Update Gate
$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)$$
Here,
- $z_t$ = Update gate output (0 = keep old, 1 = use candidate)
- $W_z$ = Weight matrix for update gate
- $h_{t-1}$ = Previous hidden state
- $x_t$ = Current input
- $b_z$ = Update gate bias
Reset Gate
$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)$$
Here,
- $r_t$ = Reset gate output (0 = forget, 1 = remember)
- $W_r$ = Weight matrix for reset gate
- $b_r$ = Reset gate bias
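Candidate Hidden State
$$\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h)$$
Here,
- $\tilde{h}_t$ = Candidate hidden state
- $W_h$ = Weight matrix for candidate
- $r_t \odot h_{t-1}$ = Element-wise product of reset gate and previous state
- $b_h$ = Candidate bias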
ℹ️ Interpretation of Gates
When $z_t \approx 0$, the GRU retains the old hidden state (long-term memory). When $z_t \approx 1$, it replaces it with the candidate (short-term update). The reset gate determines how much past information influences the candidate: when $r_t \approx 0$, the candidate ignores the previous state entirely.
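The three limiting cases can be checked numerically. The sketch below is a scalar GRU step in which the gate values are passed in directly (rather than computed by a sigmoid) so that their effect is isolated. The function `gru_step` and the weights `w_x`, `w_h`, `b_h` are hypothetical names chosen for this illustration.

```python
import math

def gru_step(x_t, h_prev, z_t, r_t, w_x=1.0, w_h=1.0, b_h=0.0):
    """One scalar GRU step with explicitly supplied gate values z_t and r_t."""
    # Candidate: the reset gate scales how much of h_prev enters the tanh
    h_candidate = math.tanh(w_h * (r_t * h_prev) + w_x * x_t + b_h)
    # Convex interpolation: z_t = 0 keeps h_prev, z_t = 1 uses the candidate
    h_t = (1.0 - z_t) * h_prev + z_t * h_candidate
    return h_t, h_candidate

x_t, h_prev = 0.3, 0.8
keep, _ = gru_step(x_t, h_prev, z_t=0.0, r_t=1.0)          # old state is retained
replace, cand = gru_step(x_t, h_prev, z_t=1.0, r_t=1.0)    # state becomes the candidate
_, cand_reset = gru_step(x_t, h_prev, z_t=0.5, r_t=0.0)    # candidate depends on x_t only

print(keep, replace, cand, cand_reset)
```

Intermediate gate values blend these behaviours per hidden unit, which is what lets different units track information on different time scales.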
GRU vs LSTM
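Definition: LSTM (for Comparison)
LSTM maintains a separate cell state $c_t$ and uses three gates:
- Forget gate: $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$
- Input gate: $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
- Output gate: $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
Cell update: $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$, where $\tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$
Output: $h_t = o_t \odot \tanh(c_t)$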
| Feature | GRU | LSTM |
|---|---|---|
| Gates | 2 (update, reset) | 3 (forget, input, output) |
| States | Hidden only | Hidden + Cell |
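| Parameters | Fewer (3 weight blocks) | More (4 weight blocks) |
| Training speed | Usually faster per step | Usually slower per step |
| Performance | Comparable on many tasks | Comparable on many tasks |
| Short sequences | Often better | Often slightly worse |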
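The "fewer parameters" row can be quantified with a small calculation. The sketch below assumes one weight matrix acting on the concatenation $[h_{t-1}, x_t]$ plus one bias vector per block, which gives three blocks for a GRU and four for an LSTM. The helper names and the sizes `d=256`, `h=512` are illustrative assumptions.

```python
def gate_block_params(input_size, hidden_size):
    """Parameters in one block W . [h_{t-1}, x_t] + b."""
    return hidden_size * (input_size + hidden_size) + hidden_size

def gru_params(input_size, hidden_size):
    # Three blocks: update gate, reset gate, candidate
    return 3 * gate_block_params(input_size, hidden_size)

def lstm_params(input_size, hidden_size):
    # Four blocks: forget gate, input gate, output gate, cell candidate
    return 4 * gate_block_params(input_size, hidden_size)

d, h = 256, 512
print(gru_params(d, h), lstm_params(d, h), gru_params(d, h) / lstm_params(d, h))
```

Under this counting a GRU layer has three quarters of the parameters of an LSTM layer of the same size. PyTorch's `nn.GRU` and `nn.LSTM` store two bias vectors per block, so their absolute counts are slightly higher, but the 3:4 ratio per layer is the same.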
💡 When to Choose GRU
- When you have limited training data (fewer parameters = less overfitting)
- When computational efficiency matters
- When sequences are relatively short (< 200 time steps)
- When you want a simpler model that is easier to debug
PyTorch Implementation
📝 Example: GRU from Scratch

    import torch
    import torch.nn as nn


    class GRUCell(nn.Module):
        """One GRU time step, written out gate by gate."""

        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.hidden_size = hidden_size
            # One linear map per gate; each acts on the concatenation [x_t, h_prev]
            self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
            self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
            self.W_h = nn.Linear(input_size + hidden_size, hidden_size)

        def forward(self, x_t, h_prev):
            # x_t: (batch, input_size), h_prev: (batch, hidden_size)
            # Concatenate input and previous hidden state
            combined = torch.cat([x_t, h_prev], dim=1)
            # Update gate: how much to keep vs replace
            z_t = torch.sigmoid(self.W_z(combined))
            # Reset gate: how much past to forget
            r_t = torch.sigmoid(self.W_r(combined))
            # Candidate: new information with reset gating
            combined_r = torch.cat([x_t, r_t * h_prev], dim=1)
            h_candidate = torch.tanh(self.W_h(combined_r))
            # Final hidden state: interpolation
            # (z_t = 0 keeps h_prev, z_t = 1 uses the candidate)
            h_t = (1 - z_t) * h_prev + z_t * h_candidate
            return h_t


    class GRUModel(nn.Module):
        """Embedding, stacked nn.GRU, and a linear head on the last hidden state."""

        def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim,
                     num_layers=2, dropout=0.3):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.gru = nn.GRU(
                embed_dim, hidden_dim,
                num_layers=num_layers,
                batch_first=True,
                # nn.GRU applies dropout between layers only, so it needs > 1 layer
                dropout=dropout if num_layers > 1 else 0
            )
            self.fc = nn.Linear(hidden_dim, output_dim)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x):
            # x: (batch_size, seq_len) of token indices
            embedded = self.dropout(self.embedding(x))
            # gru_out: (batch, seq_len, hidden_dim)
            # hidden: (num_layers, batch, hidden_dim)
            gru_out, hidden = self.gru(embedded)
            # Use last hidden state of final layer
            last_hidden = hidden[-1]
            return self.fc(self.dropout(last_hidden))


    # Initialize model
    model = GRUModel(
        vocab_size=10000,
        embed_dim=256,
        hidden_dim=512,
        output_dim=2,
        num_layers=2,
        dropout=0.3
    )

    # Example forward pass
    x = torch.randint(0, 10000, (32, 50))  # batch=32, seq_len=50
    output = model(x)
    print(f"Output shape: {output.shape}")  # torch.Size([32, 2])
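When checking a hand-written cell against `nn.GRU`, it helps to know that PyTorch's documented equations differ from the formulation in this section in two ways. The update gate has the opposite role (`h_t = (1 - z_t) * n_t + z_t * h_prev`, so `z_t = 1` keeps the old state), and the reset gate multiplies the already-projected hidden term rather than `h_prev` itself. The sketch below unrolls those equations by hand using the weights of a single-layer `nn.GRU` and compares the result with the module's own output. The helper name `pytorch_style_step` and the toy sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch, seq_len = 4, 3, 2, 5

gru = nn.GRU(input_size, hidden_size, batch_first=True)
x = torch.randn(batch, seq_len, input_size)
h0 = torch.zeros(batch, hidden_size)

def pytorch_style_step(x_t, h_prev, gru):
    """One time step using layer-0 weights of nn.GRU and PyTorch's gate layout."""
    # Both weight matrices stack the three blocks in the order (reset, update, new)
    gi = x_t @ gru.weight_ih_l0.T + gru.bias_ih_l0
    gh = h_prev @ gru.weight_hh_l0.T + gru.bias_hh_l0
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)      # reset gate applied after the hidden projection
    return (1 - z) * n + z * h_prev    # here z = 1 keeps the OLD state

with torch.no_grad():
    h = h0
    steps = []
    for t in range(seq_len):
        h = pytorch_style_step(x[:, t], h, gru)
        steps.append(h)
    manual_out = torch.stack(steps, dim=1)           # (batch, seq_len, hidden)
    ref_out, ref_hidden = gru(x, h0.unsqueeze(0))    # h0 as (num_layers, batch, hidden)

print(torch.allclose(manual_out, ref_out, atol=1e-5))
```

Because of these two differences, copying weights from `nn.GRU` into the `GRUCell` above is not expected to reproduce its output. A parity test is easier if the hand-written cell follows PyTorch's layout, as in this sketch.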
📝 Example: Bidirectional GRU for Text Classification

    import torch
    import torch.nn as nn


    class BiGRUClassifier(nn.Module):
        """Bidirectional GRU with attention pooling over time steps."""

        def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim,
                     num_layers=2, dropout=0.3):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.bigru = nn.GRU(
                embed_dim, hidden_dim,
                num_layers=num_layers,
                batch_first=True,
                bidirectional=True,
                dropout=dropout if num_layers > 1 else 0
            )
            # Forward and backward states are concatenated, hence hidden_dim * 2
            self.attention = nn.Linear(hidden_dim * 2, 1)
            self.fc = nn.Linear(hidden_dim * 2, output_dim)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x):
            # x: (batch, seq_len) of token indices
            embedded = self.dropout(self.embedding(x))
            gru_out, _ = self.bigru(embedded)
            # gru_out: (batch, seq_len, hidden_dim * 2)

            # Self-attention pooling: one scalar score per time step,
            # normalized over the sequence dimension
            attn_weights = torch.softmax(
                self.attention(gru_out).squeeze(-1), dim=1
            )
            # Weighted sum of time steps: (batch, 1, seq_len) x (batch, seq_len, 2H)
            context = torch.bmm(
                attn_weights.unsqueeze(1), gru_out
            ).squeeze(1)
            return self.fc(context)
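The pooling step above reduces `gru_out` to one vector per sequence through a softmax-weighted sum over time steps. The sketch below checks on random stand-in tensors that the `torch.bmm` form equals an explicit weighted sum. It also shows one way to exclude padded positions, assuming batches are padded and a boolean `mask` is available. The original classifier does not take a mask, so that part is a hypothetical extension.

```python
import torch

torch.manual_seed(0)
batch, seq_len, feat = 2, 4, 6
gru_out = torch.randn(batch, seq_len, feat)   # stand-in for the BiGRU output
scores = torch.randn(batch, seq_len)          # stand-in for attention(gru_out).squeeze(-1)

# Hypothetical padding mask: True marks real tokens, False marks padding
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])

# Padded positions get -inf so that softmax assigns them zero weight
masked_scores = scores.masked_fill(~mask, float("-inf"))
attn_weights = torch.softmax(masked_scores, dim=1)   # (batch, seq_len)

# The bmm form used in the classifier ...
context_bmm = torch.bmm(attn_weights.unsqueeze(1), gru_out).squeeze(1)
# ... is the same as an explicit weighted sum over time steps
context_sum = (attn_weights.unsqueeze(-1) * gru_out).sum(dim=1)

print(context_bmm.shape, torch.allclose(context_bmm, context_sum, atol=1e-6))
```

Without a mask, padded time steps receive non-zero attention weight and can leak into the pooled representation, which matters most when sequence lengths within a batch vary widely.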
Training Tips
💡 GRU Training Best Practices
- Gradient clipping: Clip the global gradient norm (1.0 is a common starting value) to prevent exploding gradients
- Learning rate scheduling: Use cosine annealing or reduce-on-plateau
- Layer normalization: Apply LayerNorm to GRU outputs for stable training
- Teacher forcing ratio: Decay from 1.0 to 0.0 during training for seq2seq
- Hidden size: Start with 256-512, increase for complex tasks
- Dropout: Apply dropout between layers (not within recurrent connections)
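Several of these tips fit into a few lines of a training loop. The sketch below combines gradient-norm clipping, cosine annealing, and a linearly decaying teacher forcing ratio on a toy `nn.GRU`. The data, model sizes, epoch count, and the `teacher_forcing_ratio` helper are illustrative assumptions. The toy model is not a seq2seq model, so the ratio is only computed here, not used.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.GRU(input_size=8, hidden_size=16, batch_first=True)   # toy model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

def teacher_forcing_ratio(epoch, num_epochs):
    """Linear decay from 1.0 in the first epoch to 0.0 in the last (one possible schedule)."""
    return max(0.0, 1.0 - epoch / max(1, num_epochs - 1))

x = torch.randn(4, 10, 8)          # (batch, seq_len, input_size)
target = torch.randn(4, 10, 16)    # random regression target, for illustration only

for epoch in range(num_epochs):
    optimizer.zero_grad()
    out, _ = model(x)
    loss = nn.functional.mse_loss(out, target)
    loss.backward()
    # Rescale gradients in place so their global L2 norm is at most 1.0;
    # the return value is the norm measured before clipping.
    norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()               # one scheduler step per epoch
    print(epoch, float(norm_before), scheduler.get_last_lr()[0],
          teacher_forcing_ratio(epoch, num_epochs))
```

Clipping must come after `loss.backward()` and before `optimizer.step()`, otherwise the optimizer sees the unclipped gradients.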
Practice Exercises
- Implement a GRU from scratch: Write the forward pass of a GRU cell without using nn.GRU. Verify your implementation matches PyTorch's output.
- GRU vs LSTM comparison: Train both models on a time series prediction task (e.g., predicting a sine wave). Compare training time, parameter count, and final loss.
- Character-level language model: Build a GRU-based language model that predicts the next character. Use temperature sampling for text generation.
- Bidirectional GRU: Implement a bidirectional GRU for sentiment analysis. Compare performance with a unidirectional GRU.
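For the character-level exercise, temperature sampling means dividing the model's logits by a temperature before the softmax. The sketch below is one minimal way to do this in plain Python. The function names and the example `logits` are hypothetical, and in a real model the logits would come from the GRU's output layer.

```python
import math
import random

def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing them by `temperature`."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [value / temperature for value in logits]
    largest = max(scaled)                        # subtract the max for numerical stability
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_index(logits, temperature=1.0, rng=random):
    """Draw one index according to the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]   # hypothetical next-character scores
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

Temperatures below 1 sharpen the distribution toward the most likely character, while temperatures above 1 flatten it and produce more varied text.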
Key Takeaways
📋 Summary: GRU Networks
- GRU uses two gates (update, reset) vs LSTM's three gates
- Update gate controls information retention, playing a role similar to LSTM's forget and input gates combined
- Reset gate controls past information influence on candidate state
- Fewer parameters than LSTM, often comparable performance
- Convex interpolation: $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
- Bidirectional GRUs capture both forward and backward context
- GRU is often preferred when data is limited or sequences are short
- LSTM is often preferred for very long sequences or when a separate cell state helps
- Both are increasingly replaced by Transformers for NLP tasks
- GRUs remain valuable for time series, edge deployment, and low-resource settings