Attention Mechanisms — Deep Dive
Attention allows models to dynamically focus on relevant parts of the input when producing each output element. It is the core innovation behind Transformers.
Why Attention?
Classic encoder-decoder (seq2seq) models without attention compress the entire input into a single fixed-length vector, creating an information bottleneck that worsens as inputs get longer. Attention removes this bottleneck by allowing the decoder to "look at" all encoder states at each decoding step.
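Definition: Attention Mechanism
Attention computes a weighted sum of values (encoder states), where the weights are determined by the compatibility between a query (decoder state) and keys (encoder states):
$$c_i = \sum_j \alpha_{ij} h_j, \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}$$
where $c_i$ is the context vector for decoding step $i$, $h_j$ is the encoder state at position $j$, $e_{ij}$ is the compatibility score between them, and $\alpha_{ij}$ are attention weights that sum to 1 over $j$.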
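The weighted sum can be sketched in plain Python. This is a minimal illustration, not a reference implementation: the helper names `softmax` and `attend` are hypothetical, the compatibility function is assumed to be a plain dot product, and the encoder states double as both keys and values.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Return (context, weights) for one query over lists of key/value vectors."""
    # Assumption: compatibility is a plain dot product between query and key.
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    # Context vector: weighted sum of the value vectors.
    context = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return context, weights

# Toy data: three encoder states (keys and values) and one decoder state (query).
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [2.0, 0.0]
context, weights = attend(decoder_state, encoder_states, encoder_states)
print("weights:", weights)
print("context:", context)
```

The first and third encoder states have the same dot product with the query, so they receive equal weight, and the second state receives the least.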
Bahdanau Attention
Definition: Bahdanau (Additive) Attention
Proposed by Bahdanau et al. (2015), this uses a learned feedforward network to compute alignment scores:
Bahdanau Attention Score
$$e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$$
Here,
- $e_{ij}$ = Alignment score between decoder state i and encoder state j
- $v_a$ = Learnable weight vector
- $W_a$ = Decoder state projection
- $U_a$ = Encoder state projection
- $s_{i-1}$ = Previous decoder hidden state
- $h_j$ = Encoder hidden state at position j
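As a rough sketch of the additive score, the function below computes $v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$ with plain Python lists. The names `matvec` and `additive_score` and all parameter values are made up for illustration; in a real model $W_a$, $U_a$, and $v_a$ would be learned.

```python
import math

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]

def additive_score(s_prev, h_j, W_a, U_a, v_a):
    """Additive alignment score: v_a^T tanh(W_a s_prev + U_a h_j)."""
    projected = [a + b for a, b in zip(matvec(W_a, s_prev), matvec(U_a, h_j))]
    return sum(v * math.tanh(p) for v, p in zip(v_a, projected))

# Hypothetical toy parameters: hidden size 2, attention size 3.
W_a = [[0.5, 0.0], [0.0, 0.5], [0.1, -0.1]]   # decoder state projection
U_a = [[0.2, 0.3], [-0.4, 0.1], [0.0, 0.6]]   # encoder state projection
v_a = [1.0, -1.0, 0.5]                        # weight vector

s_prev = [1.0, 2.0]                                     # previous decoder state
encoder_states = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]   # one vector per position

scores = [additive_score(s_prev, h, W_a, U_a, v_a) for h in encoder_states]
print(scores)
```

Because `tanh` is bounded in (-1, 1), each score is bounded in magnitude by the sum of the absolute values of `v_a`, unlike an unscaled dot product.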
Luong Attention
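Definition: Luong (Multiplicative) Attention
Proposed by Luong et al. (2015), this uses simpler dot-product or general scoring:
Dot: $\text{score}(s_i, h_j) = s_i^\top h_j$
General: $\text{score}(s_i, h_j) = s_i^\top W_a h_j$
Concat: $\text{score}(s_i, h_j) = v_a^\top \tanh(W_a [s_i; h_j])$
Luong General Score
$$\text{score}(s_i, h_j) = s_i^\top W_a h_j$$
Here,
- $s_i$ = Decoder hidden state at step i
- $W_a$ = Learnable alignment matrix
- $h_j$ = Encoder hidden state at position j
| Type | Score Function | Complexity | Parameters |
|---|---|---|---|
| Dot | $s_i^\top h_j$ | $O(d)$ | None |
| General | $s_i^\top W_a h_j$ | $O(d^2)$ | $W_a$ |
| Concat | $v_a^\top \tanh(W_a [s_i; h_j])$ | $O(d^2)$ | $W_a$, $v_a$ |
| Additive | $v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$ | $O(d^2)$ | $W_a$, $U_a$, $v_a$ |
Complexity is the cost per query-key pair, assuming the hidden and attention dimensions all equal $d$. The Additive row is the Bahdanau score, which projects the previous decoder state $s_{i-1}$ with $W_a$ and the encoder state $h_j$ with $U_a$.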
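The three Luong scoring functions differ only in how the decoder and encoder states are combined. The sketch below writes them side by side in plain Python; the function names and the toy parameter values are assumptions made for illustration.

```python
import math

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [dot(row, vector) for row in matrix]

def score_dot(s_i, h_j):
    """Dot score: s_i^T h_j (no parameters)."""
    return dot(s_i, h_j)

def score_general(s_i, h_j, W_a):
    """General score: s_i^T W_a h_j."""
    return dot(s_i, matvec(W_a, h_j))

def score_concat(s_i, h_j, W_a, v_a):
    """Concat score: v_a^T tanh(W_a [s_i; h_j])."""
    concatenated = s_i + h_j  # list concatenation stands in for [s_i; h_j]
    return dot(v_a, [math.tanh(x) for x in matvec(W_a, concatenated)])

s_i = [1.0, 2.0]   # decoder hidden state
h_j = [3.0, 4.0]   # encoder hidden state

identity = [[1.0, 0.0], [0.0, 1.0]]
W_concat = [[0.1, 0.2, 0.3, 0.4], [-0.1, 0.0, 0.1, 0.0]]  # hypothetical 2 x 4 matrix
v_a = [1.0, 0.5]                                          # hypothetical weight vector

print(score_dot(s_i, h_j))
print(score_general(s_i, h_j, identity))
print(score_concat(s_i, h_j, W_concat, v_a))
```

With $W_a$ set to the identity, the general score reduces to the dot score, which is one way to see dot-product attention as a parameter-free special case.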
Self-Attention
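Definition: Self-Attention
Self-attention computes attention within a single sequence, allowing each position to attend to all positions (including itself). Queries, keys, and values are all projected from the same input $X$ ($Q = XW^Q$, $K = XW^K$, $V = XW^V$):
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Each token can directly attend to every other token, capturing long-range dependencies in $O(1)$ sequential operations.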
ℹ️ Why Scale by sqrt(d_k)?
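Without scaling, dot products grow large with dimension $d_k$, pushing softmax into regions with extremely small gradients. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$. Dividing by $\sqrt{d_k}$ brings the variance back to 1, which keeps the softmax out of its saturated regions and enables effective training. This scaling was introduced in "Attention Is All You Need" (Vaswani et al., 2017).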
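The variance argument can be checked empirically. The sketch below is a small simulation under the stated assumption of independent standard-normal components; the function name `dot_product_variance` and the trial count are arbitrary choices for illustration.

```python
import math
import random
import statistics

def dot_product_variance(d_k, scale, trials=2000, seed=0):
    """Estimate the variance of q.k for random q, k with N(0, 1) components."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        value = sum(a * b for a, b in zip(q, k))
        if scale:
            value /= math.sqrt(d_k)  # the 1 / sqrt(d_k) factor from the formula
        samples.append(value)
    return statistics.pvariance(samples)

raw = dot_product_variance(64, scale=False)
scaled = dot_product_variance(64, scale=True)
print(f"unscaled variance (d_k=64): {raw:.2f}")   # should land near d_k
print(f"scaled variance   (d_k=64): {scaled:.2f}")  # should land near 1
```

The exact numbers depend on the seed, but the unscaled variance should grow roughly linearly with `d_k` while the scaled variance stays close to 1.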
Multi-Head Attention
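Definition: Multi-Head Attention
Instead of a single attention function, project queries, keys, and values into $h$ different subspaces, compute attention in parallel, and concatenate:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$
where each head:
$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$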
💡 Why Multiple Heads?
Different heads can learn to attend to different types of relationships: syntactic structure, semantic similarity, positional patterns, etc. Typically $h = 8$ or $h = 16$ with $d_k = d_{\text{model}} / h$.
Theorem: Attention as Differentiable Lookup
Theorem: Attention as Soft Lookup
Hard lookup (dictionary retrieval) returns $v_k$ for a specific key $k$. Attention generalizes this to a differentiable soft lookup that returns a weighted average of all values, with weights determined by key-query similarity. As temperature $\tau \to 0$, attention converges to hard lookup: all the weight concentrates on the highest-scoring key (assuming it is unique).
Temperature-Dependent Attention
$$\alpha_i = \frac{\exp(e_i / \tau)}{\sum_j \exp(e_j / \tau)}$$
Here,
- $\tau$ = Temperature parameter
- $e_i$ = Attention score for position i
- $\alpha_i$ = Attention weight
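The limiting behaviour is easy to observe numerically. The sketch below applies a temperature-scaled softmax to a fixed set of scores; the function name `softmax_with_temperature` and the example scores are illustrative assumptions.

```python
import math

def softmax_with_temperature(scores, tau):
    """alpha_i = exp(e_i / tau) / sum_j exp(e_j / tau), computed stably."""
    scaled = [s / tau for s in scores]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 0.5]  # hypothetical attention scores e_i

for tau in (100.0, 1.0, 0.1, 0.01):
    weights = softmax_with_temperature(scores, tau)
    print(tau, [round(w, 4) for w in weights])

sharp = softmax_with_temperature(scores, 0.01)   # close to a one-hot (hard) lookup
flat = softmax_with_temperature(scores, 100.0)   # close to a uniform average
```

At small `tau` nearly all the weight sits on the highest-scoring position, and at large `tau` the weights approach a uniform average over all values.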
Full PyTorch Implementation
📝Example: Multi-Head Attention from Scratch
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, dropout=0.1):
        super().__init__()
        self.scale = math.sqrt(d_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(self.d_k, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections: (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        # Reshape: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention
        output, attn_weights = self.attention(Q, K, V, mask)
        # Reshape: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        # Final linear projection
        output = self.W_o(output)
        return output, attn_weights


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# Test multi-head attention
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # batch=2, seq_len=10, d_model=512
output, weights = mha(x, x, x)  # self-attention
print(f"Output shape: {output.shape}")  # (2, 10, 512)
print(f"Weights shape: {weights.shape}")  # (2, 8, 10, 10)
Attention Patterns
ℹ️ Understanding Attention Patterns
- Local attention: Tokens attend mainly to nearby positions — common in early layers
- Global attention: Some tokens attend to all positions — e.g., [CLS] token in BERT
- Distributed attention: Weights spread across many positions — captures semantic similarity
- Block attention: Attention confined to windows — used in efficient Transformers (Longformer, BigBird)
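The windowed pattern in the last bullet can be expressed as a mask applied before the softmax. The sketch below builds a simple sliding-window mask in plain Python. It is only an illustration of the idea: the names `window_mask` and `masked_softmax` are hypothetical, and models such as Longformer and BigBird combine windows with additional global (and, for BigBird, random) attention that is not shown here.

```python
import math

def window_mask(seq_len, window):
    """mask[i][j] is True when position i may attend to j, i.e. |i - j| <= window."""
    return [[abs(i - j) <= window for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, allowed):
    """Softmax over the allowed positions only; masked positions get weight 0.

    This matches filling masked scores with -inf before a regular softmax.
    Assumes at least one position is allowed (always true here, since i == i).
    """
    m = max(s for s, ok in zip(scores, allowed) if ok)
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]

mask = window_mask(seq_len=6, window=1)
uniform_scores = [0.0] * 6  # equal scores, so only the mask shapes the weights

weights_row_2 = masked_softmax(uniform_scores, mask[2])
print(weights_row_2)
```

Each row of the mask allows at most `2 * window + 1` positions, so the attention cost per token no longer grows with the full sequence length.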
Practice Exercises
- Implement attention variants: Code dot, general, and concat attention. Compare on a fixed sequence.
- Visualize attention: Train a seq2seq model and plot attention heatmaps. What patterns emerge?
- Multi-head analysis: Train with different numbers of heads (1, 4, 8, 16). How does performance change?
- Efficient attention: Implement linear attention (kernel-based) and compare with standard attention on long sequences.
Key Takeaways
📋Summary: Attention Mechanisms
- Bahdanau attention: Additive scoring, learned alignment
- Luong attention: Multiplicative scoring, simpler and faster
- Self-attention: Each position attends to all positions in the same sequence
- Scaled dot-product: $\text{softmax}(QK^\top / \sqrt{d_k})\,V$, the standard attention mechanism
- Multi-head attention: Parallel attention heads capture different relationship types
- Attention as soft lookup: Generalizes hard dictionary lookup to differentiable weighted average
- Attention needs only $O(1)$ sequential operations to connect any two positions, so it is fully parallelizable across the sequence
- Foundation of Transformers, BERT, GPT, and modern NLP
- See also: Transformers for the complete architecture