Attention Mechanism
The Attention Mechanism Explained
The attention mechanism allows models to dynamically focus on relevant parts of the input when producing each element of the output. It computes a weighted sum of values, where weights are determined by the compatibility between queries and keys.
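Scaled Dot-Product Attention
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Here $d_k$ is the dimension of the queries and keys. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension and saturating the softmax.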
Types of Attention
1. Self-Attention
Each position attends to all positions in the same sequence.
import torch
import torch.nn.functional as F
def self_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention.

    Args:
        Q: Query tensor; its last dimension is ``d_k``.
        K: Key tensor; ``Q @ K.transpose(-2, -1)`` must be a valid matmul.
        V: Value tensor; its second-to-last dimension must equal the key length.
        mask: Optional tensor broadcastable to the score shape ``(..., L_q, L_k)``.
            Positions where ``mask == 0`` are excluded from attention.

    Returns:
        ``(output, attention_weights)``, where ``attention_weights`` is the
        softmax of the scaled scores over the last (key) dimension and
        ``output = attention_weights @ V``.
    """
    d_k = Q.size(-1)
    # Scale by sqrt(d_k) so the score magnitude does not grow with the dimension.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        # Use the most negative finite value of the scores' dtype rather than a
        # hard-coded -1e9: -1e9 does not fit in float16, so masked_fill would
        # raise for half-precision inputs. Masked scores still get exactly zero
        # weight after the softmax; a fully masked row yields uniform weights.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output, attention_weights
# Example usage: run self_attention on random inputs and report the result shapes.
batch_size, seq_len, d_model = 2, 10, 64
# Draw three independent random tensors, in Q, K, V order.
Q, K, V = (torch.randn(batch_size, seq_len, d_model) for _ in range(3))
output, weights = self_attention(Q, K, V)
print(f"Output shape: {output.shape}", f"Weights shape: {weights.shape}", sep="\n")
2. Cross-Attention
Queries come from one sequence, while keys and values come from another, for example decoder states attending over encoder outputs.
class CrossAttention(torch.nn.Module):
    """Multi-head attention whose queries and keys/values come from different inputs.

    ``query`` and ``context`` are each linearly projected to ``d_model``
    features and split into ``num_heads`` heads of size ``d_model // num_heads``.
    Each head runs scaled dot-product attention, and the heads are then
    concatenated back to ``d_model`` features. No output projection is applied
    after concatenation.

    Args:
        d_model: Feature size of ``query`` and ``context``.
        num_heads: Number of attention heads; must be positive and divide ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Without this check, the view() in _split_heads can silently produce
        # mis-shaped heads (or fail with an opaque error) at forward time.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)

    def forward(self, query, context, mask=None):
        """Attend from ``query`` positions over ``context`` positions.

        Args:
            query: Tensor of shape ``(batch, L_q, d_model)``.
            context: Tensor of shape ``(batch, L_k, d_model)`` with the same batch size.
            mask: Optional tensor broadcastable to the per-head score shape
                ``(batch, num_heads, L_q, L_k)``; for example ``(batch, 1, 1, L_k)``
                for a key padding mask. Positions where ``mask == 0`` are excluded.

        Returns:
            Tensor of shape ``(batch, L_q, d_model)``.
        """
        batch_size = query.size(0)
        Q = self._split_heads(self.W_q(query), batch_size)
        K = self._split_heads(self.W_k(context), batch_size)
        V = self._split_heads(self.W_v(context), batch_size)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # dtype-aware fill value: a literal -1e9 overflows float16 and makes
            # masked_fill raise; finfo.min still gives zero weight after softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        # Merge heads: (batch, heads, L_q, d_k) -> (batch, L_q, heads * d_k).
        return output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_heads * self.d_k
        )

    def _split_heads(self, x, batch_size):
        """Reshape ``(batch, L, d_model)`` into ``(batch, num_heads, L, d_k)``."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
3. Causal (Masked) Attention
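Prevents each position from attending to later (future) positions. Autoregressive models need this so that a prediction cannot depend on the tokens it is trying to predict.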
def create_causal_mask(seq_len, *, additive=True):
    """Build a ``(seq_len, seq_len)`` causal mask.

    Args:
        seq_len: Number of positions.
        additive: If True (the default), return a float tensor with ``0.0``
            on and below the diagonal and ``-1e9`` strictly above it, meant to
            be *added* to attention scores. If False, return a boolean tensor
            that is True where attention is allowed (on and below the diagonal).
            The boolean form fits the ``mask == 0`` -> excluded convention used
            by ``self_attention`` and ``CrossAttention``. The additive form must
            not be passed there, because its allowed positions are 0 and would
            be masked out.

    Returns:
        The mask tensor described above.
    """
    # True strictly above the diagonal, i.e. at future positions.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    if not additive:
        return ~future
    return torch.zeros(seq_len, seq_len).masked_fill(future, -1e9)
# Example: build a 5x5 causal mask and display it below a label.
mask = create_causal_mask(5)
print("Causal mask:", mask, sep="\n")
Attention Patterns
Efficient Attention Variants
Flash Attention
An optimized, exact attention implementation that processes the computation in tiles without materializing the full score matrix, which reduces memory usage and increases speed.
# Using Flash Attention with PyTorch 2.0+
import torch
from torch.nn.functional import scaled_dot_product_attention
# Scaled dot-product attention through PyTorch's fused-kernel dispatcher
def standard_attention(Q, K, V, *, is_causal=True):
    """Return ``scaled_dot_product_attention(Q, K, V)``, causally masked by default.

    PyTorch selects the backend (Flash, memory-efficient, or the plain math
    implementation) based on the inputs' device, dtype and shapes.

    Args:
        Q, K, V: Query/key/value tensors, typically shaped
            ``(batch, heads, seq_len, head_dim)``.
        is_causal: If True (the default), each query position attends only to
            key positions at or before it. Pass False for unmasked attention.

    Returns:
        The attention output tensor.
    """
    return scaled_dot_product_attention(Q, K, V, is_causal=is_causal)
# Flash Attention (automatic with PyTorch 2.0+)
# The CUDA Flash kernel needs half-precision (fp16/bf16) inputs. Fall back to
# CPU/float32 when no GPU is present so the example still runs; PyTorch then
# picks whichever SDPA backend supports that device and dtype.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32
Q = torch.randn(2, 8, 1024, 64, device=device, dtype=dtype)
K = torch.randn(2, 8, 1024, 64, device=device, dtype=dtype)
V = torch.randn(2, 8, 1024, 64, device=device, dtype=dtype)
# Automatically uses Flash Attention when available
output = scaled_dot_product_attention(Q, K, V, is_causal=True)
Summary
The attention mechanism is the core innovation that lets Transformers directly relate any two positions in a sequence, regardless of the distance between them. Understanding its variants is crucial for working with modern generative AI models.
Next: We'll explore language model fundamentals and training objectives.