LSTM Networks — Gates, Cell State & Bidirectional Architectures




LSTM (Long Short-Term Memory) networks largely overcome the vanishing gradient problem of vanilla RNNs by using gating mechanisms to control how information flows across time steps.

See our RNN Deep Dive tutorial for the fundamentals of vanilla RNNs and why LSTMs were invented.


Why LSTMs?

The Problem LSTMs Solve

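Vanilla RNNs suffer from vanishing gradients, which limits their ability to learn long-range dependencies. During backpropagation through time, the gradient is multiplied by the recurrent Jacobian once per time step; when these Jacobians have norm below 1, the gradient shrinks exponentially with distance. In practice, dependencies more than a few tens of steps apart then contribute almost nothing to the updates of early weights.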

LSTMs solve this by introducing a cell state — an information highway that allows gradients to flow through time with minimal degradation. The gates control what information is added to, kept in, or removed from this highway.
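To make the vanishing-gradient argument concrete, here is a minimal PyTorch sketch. It assumes a toy tanh RNN with no input weights, $\mathbf{h}_t = \tanh(\mathbf{W}\mathbf{h}_{t-1} + \mathbf{x}_t)$, whose recurrent matrix is rescaled to spectral norm 0.9 (an illustrative choice); grad_norm is a hypothetical helper. Each step multiplies the gradient by a Jacobian of norm at most 0.9, so the gradient reaching $\mathbf{h}_0$ shrinks roughly geometrically with the sequence length.

```python
import torch

torch.manual_seed(0)
d = 16
W = torch.randn(d, d)
W = 0.9 * W / torch.linalg.matrix_norm(W, ord=2)  # assumption: spectral norm set to 0.9

def grad_norm(T):
    """Hypothetical helper: norm of d(sum of h_T)/d(h_0) for h_t = tanh(W h_{t-1} + x_t)."""
    h0 = torch.randn(d, requires_grad=True)
    h = h0
    for _ in range(T):
        h = torch.tanh(W @ h + torch.randn(d))  # random inputs x_t
    h.sum().backward()
    return h0.grad.norm().item()

for T in (1, 10, 50):
    # Each step's Jacobian diag(1 - h_t^2) W has norm <= 0.9,
    # so the result is bounded by 0.9**T * sqrt(d)
    print(T, grad_norm(T))
```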


LSTM Cell Architecture

LSTM Cell

An LSTM cell has four components:

  1. Forget Gate: Decides what to throw away from the cell state
  2. Input Gate: Decides what new information to store
  3. Cell Candidate: Creates candidate values to add
  4. Output Gate: Decides what to output based on cell state

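The cell state $\mathbf{c}_t$ acts as a conveyor belt, carrying information across time steps. Along this path it is only scaled element-wise by the forget gate and added to (no weight-matrix multiplication, no squashing nonlinearity), which greatly reduces vanishing gradients.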
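In PyTorch's nn.LSTM the four components are not separate modules: their weights are stacked into one matrix per input source. This short sketch (sizes chosen arbitrarily) prints the parameter shapes; the row dimension is 4 × hidden_size, one block per component, stacked in PyTorch's documented order of input gate, forget gate, cell candidate, output gate.

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=32)
for name, p in lstm.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 (128, 10)   4 * hidden_size rows: one block per component
# weight_hh_l0 (128, 32)
# bias_ih_l0 (128,)
# bias_hh_l0 (128,)

# Split the stacked input weights back into the four per-component blocks
W_i, W_f, W_c, W_o = lstm.weight_ih_l0.chunk(4, dim=0)
print(W_f.shape)  # torch.Size([32, 10])
```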


The Three Gates

Forget Gate

Forget Gate

The forget gate decides what to discard from the cell state:

$$\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$$

Output: a vector in $(0, 1)^d$, where $d$ is the hidden dimension. Values near 0 mean "forget this"; values near 1 mean "keep this."

where:

  • $\mathbf{f}_t$: forget gate activation, between 0 and 1 in each dimension
  • $\sigma$: the sigmoid function, with outputs in (0, 1)
  • $[\mathbf{h}_{t-1}, \mathbf{x}_t]$: concatenation of the previous hidden state and the current input
  • $\mathbf{W}_f, \mathbf{b}_f$: forget gate weights and bias
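The concatenation $[\mathbf{h}_{t-1}, \mathbf{x}_t]$ is shorthand: multiplying $\mathbf{W}_f$ by the concatenated vector is the same as splitting $\mathbf{W}_f$ into a recurrent block and an input block and summing the two products. A minimal sketch with random weights (the sizes and the names W_fh and W_fx are illustrative assumptions):

```python
import torch

torch.manual_seed(0)
d, d_x = 4, 3
W_f = torch.randn(d, d + d_x)   # acts on the concatenation [h_{t-1}, x_t]
b_f = torch.zeros(d)
h_prev, x_t = torch.randn(d), torch.randn(d_x)

f_concat = torch.sigmoid(W_f @ torch.cat([h_prev, x_t]) + b_f)

# Same gate, with the weight split into a recurrent part and an input part
W_fh, W_fx = W_f[:, :d], W_f[:, d:]
f_split = torch.sigmoid(W_fh @ h_prev + W_fx @ x_t + b_f)

print(f_concat)                            # entries lie in (0, 1)
print(torch.allclose(f_concat, f_split))   # True
```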

Input Gate

Input Gate

The input gate decides what new information to store:

$$\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$$
$$\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$$

$\mathbf{i}_t$ controls which values to update; $\tilde{\mathbf{c}}_t$ creates the candidate values to add.

where:

  • $\mathbf{i}_t$: input gate activation (what to update)
  • $\tilde{\mathbf{c}}_t$: candidate cell state values
  • $\tanh$: the hyperbolic tangent, with outputs in (-1, 1)
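A small numerical sketch of how $\mathbf{i}_t$ gates the candidate. The pre-activation values are hand-picked for illustration, not taken from a trained model: where the gate is nearly closed the candidate is blocked, and where it is open the candidate passes through almost unchanged.

```python
import torch

# Hand-picked pre-activations for a 3-dimensional cell (illustrative values)
i_pre = torch.tensor([-10.0, 0.0, 10.0])   # gate nearly closed, half open, nearly open
c_pre = torch.tensor([2.0, 2.0, 2.0])      # same candidate pre-activation in every dimension

i_t = torch.sigmoid(i_pre)        # ≈ [0.0000, 0.5000, 1.0000]
c_tilde = torch.tanh(c_pre)       # ≈ [0.9640, 0.9640, 0.9640]
write = i_t * c_tilde             # the amount actually added to the cell state
print(write)                      # ≈ [0.0000, 0.4820, 0.9640]
```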

Cell State Update

Cell State Update

The cell state is updated by forgetting old information and adding new information:

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

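The key insight: along the direct path from $\mathbf{c}_{t-1}$ to $\mathbf{c}_t$ there is no matrix multiplication and no squashing nonlinearity, only element-wise scaling by $\mathbf{f}_t$ and addition. Backpropagating through this path therefore multiplies the gradient by $\mathbf{f}_t$ at each step rather than by a recurrent weight matrix, so when the forget gate stays close to 1 the gradient survives many steps with little decay.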

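The update rule can be exercised directly on hand-picked vectors. In this sketch (cell_update is a hypothetical helper), a forget gate of 1 with an input gate of 0 copies the old cell state, the reverse setting overwrites it with the candidate, and intermediate values choose per dimension.

```python
import torch

def cell_update(f, c_prev, i, c_tilde):
    """Hypothetical helper: c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t (all element-wise)."""
    return f * c_prev + i * c_tilde

c_prev = torch.tensor([0.8, -0.5, 0.3])
c_tilde = torch.tensor([-0.9, 0.9, 0.1])

keep = cell_update(torch.ones(3), c_prev, torch.zeros(3), c_tilde)       # pure copy
overwrite = cell_update(torch.zeros(3), c_prev, torch.ones(3), c_tilde)  # pure replace
mix = cell_update(torch.tensor([1.0, 0.0, 0.5]), c_prev,
                  torch.tensor([0.0, 1.0, 0.5]), c_tilde)               # per-dimension choice
print(keep)       # ≈ [0.8, -0.5, 0.3]
print(overwrite)  # ≈ [-0.9, 0.9, 0.1]
print(mix)        # ≈ [0.8, 0.9, 0.2]
```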

Output Gate

Output Gate

The output gate decides what to output from the cell state:

$$\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$$
$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$

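The hidden state is a filtered version of the cell state: $\tanh$ squashes $\mathbf{c}_t$ into (-1, 1), and $\mathbf{o}_t$ decides how much of each dimension is exposed.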

where:

  • $\mathbf{o}_t$: output gate activation (what to output)
  • $\mathbf{h}_t$: final hidden state
  • $\mathbf{c}_t$: updated cell state
  • $\odot$: element-wise multiplication
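A short sketch of why the hidden state stays bounded even though the cell state itself is never squashed: $\tanh$ maps $\mathbf{c}_t$ into (-1, 1), and $\mathbf{o}_t \in (0, 1)$ can only shrink it further, so each entry satisfies $|h_{t,j}| \le o_{t,j}$. The cell-state values below are arbitrary illustrations.

```python
import torch

torch.manual_seed(0)
c_t = torch.tensor([0.5, 3.0, -25.0])     # illustrative cell state; nothing keeps it inside (-1, 1)
o_t = torch.sigmoid(torch.randn(3))       # output gate activations in (0, 1)
h_t = o_t * torch.tanh(c_t)               # h_t = o_t ⊙ tanh(c_t)
print(torch.tanh(c_t))                    # ≈ [0.4621, 0.9951, -1.0000]
print(h_t.abs() <= o_t)                   # tensor([True, True, True])
```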

Complete LSTM Equations

Complete LSTM Forward Pass

At each time step tt:

  1. $\mathbf{f}_t = \sigma(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)$ (forget gate)
  2. $\mathbf{i}_t = \sigma(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i)$ (input gate)
  3. $\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)$ (cell candidate)
  4. $\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$ (cell state update)
  5. $\mathbf{o}_t = \sigma(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o)$ (output gate)
  6. $\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$ (hidden state)
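The six equations translate almost line for line into code. This is a minimal sketch, assuming a single unbatched example and a stacked weight matrix W whose row blocks are ordered [W_f; W_i; W_c; W_o]. That order is an arbitrary choice for the sketch (nn.LSTM stores input, forget, cell, output); lstm_step is a hypothetical name, and the 0.1 weight scale is only for illustration.

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step following equations 1-6.

    W: (4d, d + d_x), row blocks [W_f; W_i; W_c; W_o], applied to [h_{t-1}, x_t]
    b: (4d,), stacked in the same order (an assumption of this sketch)
    """
    z = W @ torch.cat([h_prev, x_t]) + b
    z_f, z_i, z_c, z_o = z.chunk(4)
    f_t = torch.sigmoid(z_f)                 # 1. forget gate
    i_t = torch.sigmoid(z_i)                 # 2. input gate
    c_tilde = torch.tanh(z_c)                # 3. cell candidate
    c_t = f_t * c_prev + i_t * c_tilde       # 4. cell state update
    o_t = torch.sigmoid(z_o)                 # 5. output gate
    h_t = o_t * torch.tanh(c_t)              # 6. hidden state
    return h_t, c_t

torch.manual_seed(0)
d, d_x, T = 8, 5, 12
W = 0.1 * torch.randn(4 * d, d + d_x)
b = torch.zeros(4 * d)
h, c = torch.zeros(d), torch.zeros(d)
for x_t in torch.randn(T, d_x):              # unroll over T time steps
    h, c = lstm_step(x_t, h, c, W, b)
print(h.shape, c.shape)  # torch.Size([8]) torch.Size([8])
```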

Why LSTMs Solve Vanishing Gradients

LSTM Gradient Flow

The cell state gradient through time satisfies:

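$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{f}_t) + \text{(terms via } \mathbf{h}_{t-1}\text{)}$$

The first term comes from the direct path $\mathbf{f}_t \odot \mathbf{c}_{t-1}$; the remaining terms arise because $\mathbf{f}_t$, $\mathbf{i}_t$ and $\tilde{\mathbf{c}}_t$ depend on $\mathbf{h}_{t-1} = \mathbf{o}_{t-1} \odot \tanh(\mathbf{c}_{t-1})$. When $\mathbf{f}_t \approx \mathbf{1}$, the direct term is approximately the identity matrix, so gradient magnitude is largely preserved across time steps. The forget gate thus acts as a gradient highway: when it learns to keep values (output near 1), gradients flow back through many steps with little attenuation.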
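The direct-path term can be checked numerically. This sketch (direct_path_grad is a hypothetical helper) unrolls only the cell-state recurrence and holds the gates fixed as constants; that assumption deliberately drops the indirect terms through $\mathbf{h}_{t-1}$. Autograd then returns exactly the product of the forget-gate values along the path.

```python
import torch

def direct_path_grad(f_value, T, d=4):
    """Hypothetical helper: d c_T / d c_0 along the cell-state path only.

    Assumption: f_t, i_t and c̃_t are constants (no dependence on h),
    so only the diag(f_t) term of the Jacobian remains.
    """
    c0 = torch.ones(d, requires_grad=True)
    f = torch.full((d,), f_value)
    i = torch.full((d,), 0.5)
    c_tilde = torch.full((d,), 0.1)
    c = c0
    for _ in range(T):
        c = f * c + i * c_tilde     # c_t = f ⊙ c_{t-1} + i ⊙ c̃
    c.sum().backward()
    return c0.grad

print(direct_path_grad(0.99, T=50))   # each entry = 0.99**50 ≈ 0.605
print(direct_path_grad(0.50, T=50))   # each entry = 0.5**50 ≈ 8.9e-16
```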

ℹ️ Forget Gate Bias Initialization

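Initialize the forget gate bias to 1.0 (or larger, e.g., 2.0) instead of 0. While the weighted inputs are still small, the gate then starts near $\sigma(1) \approx 0.73$, encouraging the network to keep most information early on. With a bias of 0, the gate starts near $\sigma(0) = 0.5$, which roughly halves the cell state at every step and causes aggressive forgetting early in training. In PyTorch, each LSTM layer has two bias vectors (bias_ih and bias_hh) that are added together, so the effective bias is their sum.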
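Because PyTorch adds the two bias vectors, writing 1.0 into both doubles the effective forget bias. The sketch below zeroes every parameter of an nn.LSTMCell, puts 1.0 into the forget slice of both bias_ih and bias_hh, and runs one step from $\mathbf{c}_0 = \mathbf{1}$ with zero input and zero hidden state. The resulting cell state equals $\sigma(2)$, not $\sigma(1)$. Sizes are arbitrary.

```python
import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=3, hidden_size=2)
d = cell.hidden_size
with torch.no_grad():
    for p in cell.parameters():
        p.zero_()
    # PyTorch gate order: input, forget, cell candidate, output -> forget slice is [d:2d]
    cell.bias_ih[d:2 * d].fill_(1.0)
    cell.bias_hh[d:2 * d].fill_(1.0)

x = torch.zeros(1, 3)
h0, c0 = torch.zeros(1, d), torch.ones(1, d)
with torch.no_grad():
    h1, c1 = cell(x, (h0, c0))
# f = sigmoid(1 + 1), i = sigmoid(0), candidate = tanh(0) = 0, so c1 = sigmoid(2) * 1
print(c1)                                # every entry ≈ 0.8808
print(torch.sigmoid(torch.tensor(2.0)))  # tensor(0.8808)
```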


Bidirectional LSTM

Bidirectional LSTM

A bidirectional LSTM processes the sequence in both directions:

$$\overrightarrow{\mathbf{h}_t} = \overrightarrow{\text{LSTM}}(\mathbf{x}_t, \overrightarrow{\mathbf{h}_{t-1}})$$
$$\overleftarrow{\mathbf{h}_t} = \overleftarrow{\text{LSTM}}(\mathbf{x}_t, \overleftarrow{\mathbf{h}_{t+1}})$$
$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}_t}; \overleftarrow{\mathbf{h}_t}]$$

Each output is the concatenation of the forward and backward hidden states, so the representation at position $t$ draws on both past and future context. This is critical for tasks like NER and POS tagging, and for the encoder in machine translation.

Bidirectional LSTM Output

$$\mathbf{h}_t = [\overrightarrow{\mathbf{h}_t}; \overleftarrow{\mathbf{h}_t}] \in \mathbb{R}^{2d}$$

where:

  • $\overrightarrow{\mathbf{h}_t}$: forward hidden state
  • $\overleftarrow{\mathbf{h}_t}$: backward hidden state
  • $d$: hidden dimension of each direction, so the output has dimension $2d$
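One way to see what bidirectional=True computes: a forward LSTM over the sequence, plus a second LSTM over the time-reversed sequence whose outputs are flipped back, concatenated along the feature axis. This sketch copies the weights of a bidirectional nn.LSTM into two unidirectional ones (fwd, bwd and manual_bilstm are illustrative names) and compares the results. It assumes unpadded, equal-length sequences.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_x, d = 4, 8
bi = nn.LSTM(d_x, d, batch_first=True, bidirectional=True)

# Copy the forward-direction and reverse-direction weights into two plain LSTMs
fwd = nn.LSTM(d_x, d, batch_first=True)
bwd = nn.LSTM(d_x, d, batch_first=True)
sd = bi.state_dict()
fwd.load_state_dict({k: v for k, v in sd.items() if not k.endswith('_reverse')})
bwd.load_state_dict({k.replace('_reverse', ''): v for k, v in sd.items() if k.endswith('_reverse')})

def manual_bilstm(x):
    out_f, _ = fwd(x)                   # left-to-right pass
    out_b, _ = bwd(x.flip(1))           # right-to-left pass on the reversed sequence
    out_b = out_b.flip(1)               # re-align backward outputs with positions t
    return torch.cat([out_f, out_b], dim=-1)

x = torch.randn(2, 5, d_x)              # (batch, seq_len, features)
with torch.no_grad():
    ref, _ = bi(x)
    mine = manual_bilstm(x)
print(mine.shape)                            # torch.Size([2, 5, 16])
print(torch.allclose(mine, ref, atol=1e-5))  # True
```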

💡 When to Use Bidirectional

  • Use bidirectional when you have the full sequence available (NER, sentiment analysis, machine translation encoder)
  • Use unidirectional when you must predict online (time series forecasting, streaming data, decoder)
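  • Bidirectional roughly doubles parameters and computation (somewhat more in stacked models, because each upper layer receives the 2d-dimensional concatenated output as input)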
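The parameter cost is easy to check with arbitrary sizes (d_x = 128, d = 256; n_params is a hypothetical helper). For one layer the bidirectional model has exactly twice the parameters; with two layers the ratio is larger, because the second layer's input is the 2d-dimensional output of the first.

```python
import torch.nn as nn

def n_params(module):
    """Hypothetical helper: total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())

d_x, d = 128, 256
for layers in (1, 2):
    uni = nn.LSTM(d_x, d, num_layers=layers)
    bi = nn.LSTM(d_x, d, num_layers=layers, bidirectional=True)
    ratio = n_params(bi) / n_params(uni)
    print(f"{layers} layer(s): uni={n_params(uni):,} bi={n_params(bi):,} ratio={ratio:.2f}")
# 1 layer(s): uni=395,264 bi=790,528 ratio=2.00
# 2 layer(s): uni=921,600 bi=2,367,488 ratio=2.57
```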

Stacked LSTM

Stacked LSTM

Stack multiple LSTM layers by feeding the output of one layer as input to the next:

$$\mathbf{h}_t^{(l)} = \text{LSTM}^{(l)}(\mathbf{h}_t^{(l-1)}, \mathbf{h}_{t-1}^{(l)})$$

where $\mathbf{h}_t^{(0)} = \mathbf{x}_t$. Deeper LSTMs can learn more abstract representations, but are harder to train and more prone to overfitting.
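The stacking rule can be checked against nn.LSTM(num_layers=2): with dropout left at its default of 0, it should produce the same outputs as two single-layer LSTMs chained together, layer 2 reading layer 1's hidden states. The weight copying below relies on PyTorch's _l0/_l1 parameter-name suffixes; layer1 and layer2 are illustrative names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_x, d = 4, 8
stacked = nn.LSTM(d_x, d, num_layers=2, batch_first=True)  # dropout defaults to 0

layer1 = nn.LSTM(d_x, d, batch_first=True)
layer2 = nn.LSTM(d, d, batch_first=True)                    # input is h^{(1)}, size d
sd = stacked.state_dict()
layer1.load_state_dict({k: v for k, v in sd.items() if k.endswith('_l0')})
layer2.load_state_dict({k.replace('_l1', '_l0'): v for k, v in sd.items() if k.endswith('_l1')})

x = torch.randn(2, 5, d_x)
with torch.no_grad():
    ref, _ = stacked(x)
    h1, _ = layer1(x)      # h^{(1)}_t for every t (h^{(0)}_t = x_t)
    h2, _ = layer2(h1)     # h^{(2)}_t, fed by the layer-1 outputs
print(torch.allclose(h2, ref, atol=1e-5))  # True
```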


PyTorch Implementation

📝Example: Complete LSTM in PyTorch

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim,
                 output_dim, num_layers=2, bidirectional=True, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0
        )
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.fc = nn.Linear(lstm_output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

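        # Initialize the forget gate bias so that its effective value is 1.0.
        # PyTorch packs gate parameters as [input | forget | cell | output],
        # so the forget slice is [n//4 : n//2]. Each layer has two bias
        # vectors (bias_ih_* and bias_hh_*) that are summed, so put 1.0 in
        # bias_ih and 0.0 in bias_hh (1.0 in both would give an effective 2.0).
        for name, param in self.lstm.named_parameters():
            if name.startswith('bias'):
                n = param.size(0)
                with torch.no_grad():
                    param[n//4:n//2].fill_(1.0 if name.startswith('bias_ih') else 0.0)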

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.dropout(self.embedding(x))
        # embedded: (batch, seq_len, embed_dim)

        output, (hidden, cell) = self.lstm(embedded)
        # output: (batch, seq_len, hidden_dim * num_directions)
        # hidden: (num_layers * num_directions, batch, hidden_dim)

        # Use last hidden state from the last layer
        if self.lstm.bidirectional:
            hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            hidden = hidden[-1]

        # hidden: (batch, hidden_dim * num_directions)
        return self.fc(self.dropout(hidden))

# Test
model = LSTMClassifier(
    vocab_size=10000,
    embed_dim=128,
    hidden_dim=256,
    output_dim=5,
    num_layers=2,
    bidirectional=True
)

x = torch.randint(0, 10000, (32, 100))  # batch=32, seq_len=100
out = model(x)
print(f"Output shape: {out.shape}")  # [32, 5]
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
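Why does the classifier concatenate hidden[-2] and hidden[-1] instead of taking output[:, -1]? For a bidirectional LSTM, the backward direction finishes at the first time step, so its final state sits at output[:, 0], not output[:, -1]. This sketch (arbitrary sizes, unpadded sequences assumed) checks both correspondences.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 8
lstm = nn.LSTM(4, d, num_layers=2, batch_first=True, bidirectional=True)
x = torch.randn(3, 6, 4)
with torch.no_grad():
    output, (hidden, cell) = lstm(x)

# hidden[-2]: last layer, forward direction, which ends at the final time step
print(torch.allclose(hidden[-2], output[:, -1, :d]))  # True
# hidden[-1]: last layer, backward direction, which ends at time step 0
print(torch.allclose(hidden[-1], output[:, 0, d:]))   # True
```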

📝Example: LSTM for Time Series Forecasting

import torch
import torch.nn as nn

class LSTMForecaster(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=2, forecast_horizon=24):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0.0  # dropout is applied only between stacked layers
        )
        self.fc = nn.Linear(hidden_size, forecast_horizon)

    def forward(self, x):
        # x: (batch, lookback, features)
        lstm_out, (h_n, c_n) = self.lstm(x)
        # Use last time step's output
        last_hidden = lstm_out[:, -1, :]
        return self.fc(last_hidden)

# Test with synthetic time series
model = LSTMForecaster(input_size=3, hidden_size=64, forecast_horizon=24)
x = torch.randn(32, 168, 3)  # 168 hours of history, 3 features
pred = model(x)
print(f"Prediction shape: {pred.shape}")  # [32, 24] (next 24 hours)
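The forecaster expects inputs of shape (batch, lookback, features) and targets of length forecast_horizon. One way to build such pairs from a single long series is a sliding window. The helper below (make_windows, hypothetical) assumes the series is a (T, F) tensor and that column 0 is the variable being forecast.

```python
import torch

def make_windows(series, lookback, horizon, target_col=0):
    """Hypothetical helper: slice a (T, F) series into supervised pairs.

    Returns X: (N, lookback, F) and Y: (N, horizon), with N = T - lookback - horizon + 1.
    """
    T = series.size(0)
    n = T - lookback - horizon + 1
    X = torch.stack([series[s:s + lookback] for s in range(n)])
    Y = torch.stack([series[s + lookback:s + lookback + horizon, target_col] for s in range(n)])
    return X, Y

series = torch.randn(1000, 3)                     # 1000 hourly steps, 3 features
X, Y = make_windows(series, lookback=168, horizon=24)
print(X.shape, Y.shape)  # torch.Size([809, 168, 3]) torch.Size([809, 24])
```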

📝Example: Bidirectional LSTM for NER

import torch
import torch.nn as nn

class BiLSTM_NER(nn.Module):
    def __init__(self, vocab_size, tagset_size, embed_dim=128,
                 hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim // 2,  # per direction; the two directions concatenate back to hidden_dim
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        return tag_space  # (batch, seq_len, num_tags)

# Test for NER (token-level classification)
model = BiLSTM_NER(vocab_size=5000, tagset_size=9)  # 9 NER tags
tokens = torch.randint(0, 5000, (16, 50))  # batch=16, seq_len=50
logits = model(tokens)
print(f"NER output shape: {logits.shape}")  # [16, 50, 9]

# Cross-entropy loss for token classification
tags = torch.randint(0, 9, (16, 50))
loss = nn.CrossEntropyLoss()(logits.view(-1, 9), tags.view(-1))
print(f"NER loss: {loss.item():.4f}")
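Real NER batches contain sentences of different lengths, padded to a common length. A common approach, sketched here, is to label padded positions with the index that nn.CrossEntropyLoss ignores (ignore_index, default -100), so they contribute nothing to the loss. The tensors are random stand-ins for model outputs and gold tags, and PAD_TAG is an illustrative name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD_TAG = -100                      # assumption: padded positions are labeled -100
batch, seq_len, num_tags = 2, 5, 9
logits = torch.randn(batch, seq_len, num_tags)
tags = torch.randint(0, num_tags, (batch, seq_len))
tags[0, 3:] = PAD_TAG               # sentence 0 has 3 real tokens
tags[1, 4:] = PAD_TAG               # sentence 1 has 4 real tokens

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TAG)
loss = loss_fn(logits.reshape(-1, num_tags), tags.reshape(-1))

# Same value computed only over the real (non-padded) tokens
mask = tags.reshape(-1) != PAD_TAG
manual = nn.CrossEntropyLoss()(logits.reshape(-1, num_tags)[mask], tags.reshape(-1)[mask])
print(loss.item(), manual.item())   # identical values
```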

LSTM vs. GRU Comparison

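| Feature | LSTM | GRU |
|---|---|---|
| Gates | 3 (forget, input, output) | 2 (update, reset) |
| Separate cell state | Yes | No |
| Parameters per layer | More: 4 weight blocks, about $4d(d + d_x)$ | Fewer: 3 weight blocks, about $3d(d + d_x)$ |
| Speed | Slower | Faster |
| Performance | Similar | Similar |
| Long-range memory | Often better for very long sequences | Competitive |

Here $d$ is the hidden dimension and $d_x$ the input dimension; bias terms are omitted.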
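The 4:3 parameter ratio can be confirmed with PyTorch's built-in modules (sizes are arbitrary; n_params is a hypothetical helper). PyTorch gives each gate block the same weight and bias shapes in both modules, so the ratio is exactly 4/3.

```python
import torch.nn as nn

def n_params(module):
    """Hypothetical helper: total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())

d_x, d = 128, 256
lstm = nn.LSTM(d_x, d)   # 4 blocks: input, forget, cell candidate, output
gru = nn.GRU(d_x, d)     # 3 blocks: reset, update, new
print(n_params(lstm), n_params(gru))    # 395264 296448
print(n_params(lstm) / n_params(gru))   # ≈ 1.333
```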

ℹ️ When to Use Which

  • LSTM: Long sequences, when you need fine-grained control over memory
  • GRU: Shorter sequences, faster training, fewer parameters
  • Transformers: Now dominant for most sequence tasks, but LSTMs still useful for streaming/online settings

Summary

📋Summary: LSTM Networks

  • LSTMs mitigate vanishing gradients via gating mechanisms and a cell state
  • Forget gate: Controls what to discard from memory ($\mathbf{f}_t = \sigma(\cdot)$)
  • Input gate: Controls what to add to memory ($\mathbf{i}_t = \sigma(\cdot)$)
  • Cell state update: Element-wise, additive update with no weight matrix on the cell path, which preserves gradient flow ($\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$)
  • Output gate: Controls what to output ($\mathbf{o}_t = \sigma(\cdot)$)
  • Bidirectional: Processes sequences in both directions, uses past and future context
  • Stacked: Multiple layers for abstract representations
  • Forget gate bias: Initialize to 1.0 for better gradient flow
  • LSTMs still useful: Streaming data, edge devices, when Transformers are too large

Practice Exercises

  1. Mathematical: Derive the gradient $\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}}$ for the LSTM cell state update. Why does this prevent vanishing gradients?

  2. Coding: Implement an LSTM cell from scratch using only torch.tensor operations (no nn.LSTM). Verify that your implementation matches PyTorch's output.

  3. Experiment: Compare LSTM vs. GRU vs. Transformer on a long-range dependency task (e.g., copying the first element of a 500-element sequence). Which handles long sequences best?

  4. Application: Build a bidirectional LSTM for sentiment analysis on IMDB reviews. Compare with a Transformer-based model. What are the tradeoffs?

  5. Research: Read the original LSTM paper (Hochreiter & Schmidhuber, 1997). What was the original motivation? How has the architecture evolved?
