Inference Optimization
Flash Attention — IO-Aware Exact Attention
Flash Attention reformulates the standard attention algorithm to minimize GPU memory transfers, achieving a 2-4x speedup while using less memory. It is a core building block of modern LLM serving systems.
- IO-Aware Tiling — Process attention in blocks that fit in SRAM
- Exact Computation — Not an approximation — produces identical results
- Memory Efficiency — O(N) memory instead of O(N^2) for attention
The bottleneck in attention is not computation — it is memory movement.
Flash Attention and Memory Efficiency
Standard attention computes the full N x N attention matrix, requiring O(N^2) memory and multiple round-trips to GPU HBM (High Bandwidth Memory). Flash Attention (Dao et al., 2022) restructures the computation to minimize memory I/O, achieving the same result with 2-4x speedup.
Definition: Flash Attention
Flash Attention is an IO-aware exact attention algorithm that computes attention by tiling blocks in SRAM (on-chip memory), avoiding materialization of the full N x N attention matrix in HBM. It uses online softmax to compute attention incrementally.
The Memory Bottleneck
Standard Attention Computation
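Standard Attention
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
Here,
- $Q$ = Query matrix (N x d)
- $K$ = Key matrix (N x d)
- $V$ = Value matrix (N x d)
- $d_k$ = Key dimension (the d in the shapes above)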
The standard implementation:
- Step 1: Compute S = QK^T (N x N matrix): O(N^2 d) compute, O(N^2) memory
- Step 2: Compute P = softmax(S), row by row: O(N^2) compute and memory
- Step 3: Compute O = PV: O(N^2 d) compute, O(Nd) memory
Problem: Steps 1 and 2 write the N x N matrices S and P to HBM and read them back. This memory traffic, not the arithmetic, is the bottleneck.
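To make the O(N^2) cost concrete, the arithmetic below estimates how many bytes one N x N matrix (S or P) occupies. This is a back-of-the-envelope sketch: the helper name `score_matrix_bytes`, the fp16 element size, and the 32-head example are assumptions chosen for illustration, not values from a specific model.

```python
def score_matrix_bytes(seq_len, num_heads=1, batch_size=1, bytes_per_element=2):
    """Bytes needed to hold one N x N matrix (S or P) for every head and batch item."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


# One head at N = 4096 in fp16 (2 bytes per element).
per_head = score_matrix_bytes(4096)

# 32 heads is an assumed example value, not taken from a particular model.
all_heads = score_matrix_bytes(4096, num_heads=32)

print(f"{per_head / 2**20:.0f} MiB per head, {all_heads / 2**30:.0f} GiB across 32 heads")
# prints: 32 MiB per head, 1 GiB across 32 heads
```

Standard attention writes and re-reads a matrix of this size for both S and P in every layer, which is why the HBM traffic dominates.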
Flash Attention Tiling
Definition: IO-Aware Tiling
Flash Attention processes attention in blocks that fit in GPU SRAM (about 192 KB of on-chip memory per streaming multiprocessor on A100). Each block computes a portion of the output, accumulating results using online softmax.
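As a rough sizing sketch, the block-size rule from Algorithm 1 of Dao et al. (2022) sets B_c = ceil(M / 4d) rows per K/V block and B_r = min(B_c, d) rows per Q block, so that the Q, K, V, and output tiles fit in SRAM together. The helper name `block_sizes` is hypothetical, and treating M as 192 KB of 2-byte elements is an assumption; production kernels use hand-tuned tile sizes instead.

```python
import math


def block_sizes(sram_elements, head_dim):
    """Return (B_r, B_c): rows per Q block and rows per K/V block."""
    b_c = math.ceil(sram_elements / (4 * head_dim))
    b_r = min(b_c, head_dim)
    return b_r, b_c


# Assumption: 192 KB of SRAM, counted in 2-byte (fp16) elements.
sram_elements = 192 * 1024 // 2
b_r, b_c = block_sizes(sram_elements, head_dim=128)
print(f"B_r = {b_r}, B_c = {b_c}")
```

The factor of 4 reflects the four tiles (Q, K, V, and the output) that are resident at once.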
Standard Attention (HBM access):
- Q, K, V in HBM
- Compute S = QK^T (N x N)
- Write S to HBM
- Load S from HBM
- Compute P = softmax(S)
- Write P to HBM
- Compute O = PV
- Done
- HBM accesses: O(N^2 + Nd)

Flash Attention (SRAM):
- Q, K, V in HBM
- Load block of Q, K, V into SRAM
- Compute block attention in SRAM
- Accumulate with online softmax statistics
- Write output block to HBM
- Done
- HBM accesses: O(N^2 d^2 / M)

Flash Attention reduces HBM accesses from O(N^2 + Nd) to O(N^2 d^2 / M), where M is the SRAM size. For typical values (N=4096, d=128, M=192KB), d^2 is several times smaller than M, so HBM traffic drops by a similar factor; in practice this translates into roughly a 3x speedup.
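The two access counts can be compared numerically. The sketch below plugs the example values into the asymptotic expressions from Dao et al. (2022), O(N^2 + Nd) for standard attention and O(N^2 d^2 / M) for Flash Attention. The function names are hypothetical, M is assumed to be 192 KB of 2-byte elements, and constant factors are dropped, so the ratio is indicative only and is not a wall-clock speedup.

```python
def hbm_accesses_standard(n, d):
    """Asymptotic HBM element accesses for standard attention (constants dropped)."""
    return n * d + n * n


def hbm_accesses_flash(n, d, sram_elements):
    """Asymptotic HBM element accesses for Flash Attention (constants dropped)."""
    return n * n * d * d / sram_elements


N, d = 4096, 128
M = 192 * 1024 // 2  # assumption: 192 KB of SRAM holding 2-byte (fp16) elements

ratio = hbm_accesses_standard(N, d) / hbm_accesses_flash(N, d, M)
print(f"asymptotic HBM access ratio: {ratio:.2f}x")
```

With these inputs the ratio comes out at about 6x; measured speedups are smaller because compute and other overheads do not shrink with the memory traffic.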
Online Softmax
Definition: Online Softmax
Online softmax computes the softmax incrementally, processing one block at a time. It maintains running statistics (max and sum) to correctly normalize the final result without materializing the full softmax matrix.
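Online Softmax Statistics
$$m_{\text{new}} = \max\left(m_{\text{prev}},\, m_{\text{block}}\right)$$
Here,
- $m_{\text{new}}$ = Updated running maximum
- $m_{\text{prev}}$ = Previous running maximum
- $m_{\text{block}}$ = Maximum of current block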
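The running statistics are easiest to see on a plain vector before moving to matrices. The following is a minimal pure-Python sketch, assuming non-empty chunks; `online_softmax` and `full_softmax` are illustrative names. It keeps a running maximum and a running sum, rescales the sum whenever the maximum grows, and checks the result against an ordinary softmax.

```python
import math


def online_softmax(chunks):
    """Softmax over the concatenation of `chunks`, scanning one chunk at a time."""
    m = float("-inf")  # running maximum
    l = 0.0            # running sum of exp(x - m)
    for chunk in chunks:
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new maximum, then add this chunk's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    # The final statistics normalize every element correctly.
    return [math.exp(x - m) / l for chunk in chunks for x in chunk]


def full_softmax(xs):
    """Reference softmax that sees the whole vector at once."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


chunks = [[1.0, 2.0], [3.0, 0.5], [10.0]]
probs = online_softmax(chunks)
reference = full_softmax([x for chunk in chunks for x in chunk])
```

Flash Attention applies the same rescaling per row of the score matrix, and also rescales the partial output so it never needs the full row in memory.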
import math
import torch

def flash_attention_block(Q_block, K_block, V_block, O_prev, l_prev, m_prev):
    """Fold one K/V block into the running, normalized output of a Q block."""
    # Scaled attention scores for this block
    S_block = Q_block @ K_block.T / math.sqrt(Q_block.shape[-1])
    # Update the running row-wise max
    m_block = S_block.max(dim=-1, keepdim=True).values
    m_new = torch.maximum(m_prev, m_block)
    # Exponentiate relative to the updated max
    P_block = torch.exp(S_block - m_new)
    # Rescale the previous sum to the new max, then add this block's sum
    scale = torch.exp(m_prev - m_new)
    l_block = P_block.sum(dim=-1, keepdim=True)
    l_new = l_prev * scale + l_block
    # Undo the old normalization, rescale, add this block, renormalize
    O_new = (l_prev * scale * O_prev + P_block @ V_block) / l_new
    return O_new, l_new, m_new
Flash Attention Variants
Flash Attention 2
Definition: Flash Attention 2
Flash Attention 2 (Dao, 2023) improves on the original by: (1) reducing non-matmul FLOPs, (2) better work partitioning across warps, (3) supporting head dimensions up to 256. It achieves ~2x speedup over Flash Attention 1.
| Feature | Flash Attention 1 | Flash Attention 2 |
|---|---|---|
| Max head dim | 128 | 256 |
| Warp partitioning | K and V split across warps | Q split across warps |
| Non-matmul FLOPs | Higher | Reduced |
| Speedup | Baseline | ~2x faster |
Flash Attention 3
Definition: Flash Attention 3
Flash Attention 3 (Shah et al., 2024) leverages hardware-specific optimizations for Hopper GPUs (H100): asynchronous execution, FP8 support, and warp specialization. It achieves 1.5-2x speedup over Flash Attention 2 on H100.
Flash Attention 3 uses H100's TMA (Tensor Memory Accelerator) for asynchronous data loading and warp specialization to overlap computation with memory operations.
Implementation
import math
import torch

def flash_attention_forward(Q, K, V, block_size=256):
    """Simplified, non-causal Flash Attention forward pass (illustration only).

    Q, K, V have shape (batch, seq_len, num_heads, head_dim).
    """
    batch, seq_len, num_heads, head_dim = Q.shape
    # Put heads before the sequence axis: (batch, num_heads, seq_len, head_dim)
    Q, K, V = (x.transpose(1, 2) for x in (Q, K, V))
    O = torch.zeros_like(Q)
    l = torch.zeros(batch, num_heads, seq_len, 1, device=Q.device, dtype=Q.dtype)
    m = torch.full((batch, num_heads, seq_len, 1), float('-inf'), device=Q.device, dtype=Q.dtype)
    # Process K, V in blocks
    for j in range(0, seq_len, block_size):
        K_block = K[:, :, j:j+block_size, :]
        V_block = V[:, :, j:j+block_size, :]
        # Process Q in blocks
        for i in range(0, seq_len, block_size):
            rows = slice(i, i + block_size)
            Q_block = Q[:, :, rows, :]
            # Compute attention scores
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / math.sqrt(head_dim)
            # Online softmax update
            m_block = S_block.max(dim=-1, keepdim=True).values
            m_new = torch.maximum(m[:, :, rows], m_block)
            P_block = torch.exp(S_block - m_new)
            l_block = P_block.sum(dim=-1, keepdim=True)
            # Rescale the old accumulators to the new max, then add this block
            scale = torch.exp(m[:, :, rows] - m_new)
            O[:, :, rows] = scale * O[:, :, rows] + torch.matmul(P_block, V_block)
            l[:, :, rows] = scale * l[:, :, rows] + l_block
            m[:, :, rows] = m_new
    # Normalize once at the end and restore the input layout
    O = O / l
    return O.transpose(1, 2)
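One way to sanity-check the exactness claim is to compare a tiled computation against a reference that materializes the full score matrix. The sketch below is a deliberately small single-head version that tiles only K and V; `naive_attention` and `tiled_attention` are illustrative names, and float64 on CPU is assumed so that the comparison is tight. It is meant to show numerical agreement, not speed.

```python
import math
import torch


def naive_attention(Q, K, V):
    """Reference: materializes the full N x N score matrix."""
    S = Q @ K.T / math.sqrt(Q.shape[-1])
    return torch.softmax(S, dim=-1) @ V


def tiled_attention(Q, K, V, block_size=4):
    """Single-head attention over K/V blocks with online softmax."""
    N, d = Q.shape
    O = torch.zeros_like(Q)
    l = torch.zeros(N, 1, dtype=Q.dtype)
    m = torch.full((N, 1), float("-inf"), dtype=Q.dtype)
    for j in range(0, N, block_size):
        K_blk, V_blk = K[j:j + block_size], V[j:j + block_size]
        S = Q @ K_blk.T / math.sqrt(d)  # only N x block_size scores at a time
        m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - m_new)
        scale = torch.exp(m - m_new)    # rescales the old accumulators
        O = scale * O + P @ V_blk
        l = scale * l + P.sum(dim=-1, keepdim=True)
        m = m_new
    return O / l


torch.manual_seed(0)
# N = 10 is not a multiple of the block size, so the last block is ragged.
Q, K, V = (torch.randn(10, 8, dtype=torch.float64) for _ in range(3))
max_err = (tiled_attention(Q, K, V) - naive_attention(Q, K, V)).abs().max().item()
print(f"max abs difference: {max_err:.2e}")
```

The two results agree up to floating-point rounding, which is the sense in which Flash Attention is exact: the summation order differs, so bitwise equality is not guaranteed.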
Memory Comparison
| Method | Memory | Compute | Exact? |
|---|---|---|---|
| Standard Attention | O(N^2) | O(N^2 d) | Yes |
| Flash Attention | O(N) | O(N^2 d) | Yes |
| Sparse Attention | O(N sqrt(N)) | O(N sqrt(N) d) | Approximate |
| Linear Attention | O(N) | O(N d^2) | Approximate |
Among the methods compared here, Flash Attention is the only one that achieves both O(N) memory and exact computation. The other sub-quadratic methods in the table are approximations.
Practice Exercises
- Memory Calculation: For a 70B model with 4096 sequence length, calculate the memory savings of Flash Attention vs standard attention.
- Block Size Analysis: How does block size affect the speed of Flash Attention? What is the optimal block size for A100 vs H100 GPUs?
- Implementation: Implement Flash Attention for a simplified 1-layer, 1-head attention mechanism and verify it produces identical results to standard attention.
- Profiling: Profile the HBM access patterns of standard vs Flash Attention. Where does the speedup come from?
Key Takeaways
Summary: Flash Attention
- Standard attention is IO-bound, not compute-bound
- Flash Attention tiles attention blocks in SRAM, minimizing HBM access
- Online softmax enables incremental computation without materializing N x N matrix
- Exact computation — not an approximation, produces identical results
- O(N) memory instead of O(N^2) for attention
- 2-4x speedup over standard attention implementation
- Flash Attention 2 reduces non-matmul FLOPs and improves warp partitioning
- Flash Attention 3 leverages H100 hardware for additional speedup
- Foundation of modern serving — used by vLLM, TensorRT-LLM, TGI
What to Learn Next
-> KV Cache Optimization: Reducing memory usage of the key-value cache.
-> Model Parallelism and Tensor Parallelism: Splitting models across GPUs.
-> Attention Mechanisms Deep Dive: Understanding attention in neural networks.
-> LLM Inference Optimization: Broader inference optimization strategies.
-> Long Context and Context Window: Handling very long sequences.
-> Transformers: The architecture that enables attention.