Flash Attention and Memory Efficiency

Flash Attention — IO-Aware Exact Attention

Flash Attention reformulates the standard attention algorithm to minimize data movement between GPU memory levels, typically making the attention computation 2-4x faster while using less memory. It is a core component of most modern LLM serving systems.

  • IO-Aware Tiling: processes attention in blocks that fit in on-chip SRAM
  • Exact Computation: not an approximation; matches standard attention up to floating-point rounding
  • Memory Efficiency: O(N) extra memory instead of O(N^2) for the attention matrix

For typical sequence lengths, the bottleneck in attention is not arithmetic but memory movement.

Standard attention materializes the full N x N attention matrix, requiring O(N^2) memory and multiple round-trips to GPU HBM (High Bandwidth Memory). Flash Attention (Dao et al., 2022) restructures the computation to minimize memory I/O, producing the same output with a 2-4x speedup.

Definition: Flash Attention

Flash Attention is an IO-aware exact attention algorithm that computes attention by tiling blocks in SRAM (on-chip memory), avoiding materialization of the full N x N attention matrix in HBM. It uses online softmax to compute attention incrementally.

The Memory Bottleneck

Standard Attention Computation

Standard Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Here,

  • Q = query matrix (N x d)
  • K = key matrix (N x d)
  • V = value matrix (N x d)
  • d_k = key dimension

The standard implementation:

  1. Compute S = QK^T / sqrt(d_k), an N x N matrix: O(N^2 d) compute, O(N^2) memory
  2. Compute P = softmax(S), applied row-wise: O(N^2) compute and memory
  3. Compute O = PV: O(N^2 d) compute, O(Nd) memory

Problem: the N x N matrices S and P are written to HBM and read back between steps. This O(N^2) memory traffic, not the arithmetic, is the bottleneck.
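The three steps can be made concrete with a minimal pure-Python sketch (no GPU, no PyTorch; `standard_attention` is an illustrative name, not a library function). It materializes S and P in full, which is exactly the N x N storage that a GPU implementation would round-trip through HBM.

```python
import math

def standard_attention(Q, K, V):
    """Illustrative naive attention on lists of lists (not a library API).

    Q and K are N x d_k, V is N x d_v. Both N x N intermediates S and P
    are built in full, mirroring the three-step standard implementation.
    """
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)

    # Step 1: S = Q K^T / sqrt(d_k), an N x N matrix
    S = [[scale * sum(q_c * k_c for q_c, k_c in zip(q, k)) for k in K] for q in Q]

    # Step 2: P = row-wise softmax(S), another N x N matrix
    P = []
    for row in S:
        row_max = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])

    # Step 3: O = P V, an N x d_v matrix
    d_v = len(V[0])
    return [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(d_v)] for p_row in P]

# All-zero queries give uniform weights, so each output row averages the rows of V
print(standard_attention([[0.0, 0.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
# [[2.0, 3.0], [2.0, 3.0]]
```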

Flash Attention Tiling

Definition: IO-Aware Tiling

Flash Attention processes attention in blocks that fit in on-chip SRAM (shared memory), of which an A100 has 192 KB per streaming multiprocessor and an H100 has 256 KB. Each block computes a portion of the output, and partial results are combined using online softmax.
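To get a feel for tile sizes, here is a small sketch of the block-size rule from Algorithm 1 of the Flash Attention paper, B_c = ceil(M / 4d) and B_r = min(ceil(M / 4d), d). Assumptions: M is counted in fp16 elements, and the factor of 4 roughly reserves room for tiles of Q, K, V, and O at the same time. Production kernels choose tile sizes by tuning (powers of two, register pressure), so treat the result as an order-of-magnitude guide; `flash_block_sizes` is a hypothetical helper.

```python
import math

def flash_block_sizes(sram_bytes, head_dim, bytes_per_elem=2):
    """Hypothetical helper: block sizes from the paper's analysis rule.

    B_c = ceil(M / (4 d)), B_r = min(ceil(M / (4 d)), d), with M the SRAM
    capacity in elements (fp16 assumed by default).
    """
    M = sram_bytes // bytes_per_elem      # SRAM capacity in elements
    b_c = math.ceil(M / (4 * head_dim))   # rows per key/value block
    b_r = min(b_c, head_dim)              # rows per query block
    return b_r, b_c

print(flash_block_sizes(192 * 1024, 128))  # A100-sized SRAM, d = 128 -> (128, 192)
```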

Architecture Diagram
Standard Attention (HBM access):              Flash Attention (SRAM):

Q, K, V in HBM                                Q, K, V in HBM
     |                                             |
     v                                             v
Compute S = QK^T (N x N)                      Load block of Q, K, V
     |                                        into SRAM
     v                                             |
Write S to HBM                                     v
     |                                        Compute block attention
     v                                        in SRAM
Load S from HBM                                    |
     |                                             v
     v                                        Accumulate with online
Compute P = softmax(S)                        softmax statistics
     |                                        (repeat for all blocks)
     v                                             |
Write P to HBM                                     v
     |                                        Write output block
     v                                        to HBM
Load P, V from HBM                                 |
     |                                             v
     v                                        Done!
Compute O = PV, write O to HBM
     |
     v
Done!

HBM Accesses: O(N^2 + Nd)                     HBM Accesses: O(N^2 d^2 / M)

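Flash Attention reduces HBM accesses from O(N^2 + Nd) to O(N^2 d^2 / M), where M is the SRAM size in elements. For typical head dimensions (d = 64-128) and SRAM sizes (around 100K elements), d^2 is several times smaller than M, so Flash Attention needs several times fewer HBM accesses. For example, with N = 4096, d = 128, and M = 192 KB (98,304 fp16 elements), the ratio M / d^2 is 6. This is a reduction in memory traffic rather than a direct speedup; measured speedups also depend on the kernel and the hardware.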
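The asymptotic counts can be plugged in directly. This sketch drops constant factors, so only the ratio between the two numbers is meaningful; `hbm_accesses` is an illustrative helper, and M = 98,304 assumes 192 KB of SRAM holding fp16 values.

```python
def hbm_accesses(N, d, M):
    """Illustrative leading-order HBM access counts (elements, no constants).

    standard: N*d + N^2      (Q, K, V, O traffic plus the N x N matrices)
    flash:    N^2 * d^2 / M  (each K/V block is streamed past every Q block)
    """
    standard = N * d + N ** 2
    flash = N ** 2 * d ** 2 / M
    return standard, flash

standard, flash = hbm_accesses(N=4096, d=128, M=98_304)
print(f"standard ~ {standard:,}  flash ~ {flash:,.0f}  ratio ~ {standard / flash:.1f}")
# standard ~ 17,301,504  flash ~ 2,796,203  ratio ~ 6.2
```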

Online Softmax

Definition: Online Softmax

Online softmax computes the softmax incrementally, processing one block at a time. It maintains running statistics (max and sum) to correctly normalize the final result without materializing the full softmax matrix.
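A minimal pure-Python sketch of the running statistics for one row of scores follows (the name `online_softmax` and the default block size are illustrative assumptions). Whenever a block raises the maximum, the old sum is rescaled by exp(m_old - m_new), so the final m and l equal the global max and sum. The last line revisits the scores only to display probabilities; in Flash Attention the division by l is applied to the output accumulator instead, so the scores never need to be stored.

```python
import math

def online_softmax(scores, block_size=2):
    """Illustrative single-row softmax computed block by block.

    Keeps only a running max m and a running sum l of exp(s - m).
    """
    m, l = float("-inf"), 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    # Normalize with the global statistics
    return [math.exp(s - m) / l for s in scores]

print(online_softmax([1.0, 3.0, 2.0, 5.0, 0.5]))
```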

Online Softmax Statistics

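$$m_{\text{new}} = \max(m_{\text{old}}, m_{\text{block}})$$

$$\ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}}\,\ell_{\text{old}} + \sum_{i \in \text{block}} e^{s_i - m_{\text{new}}}$$

Here,

  • m_new = updated running maximum
  • m_old = previous running maximum
  • m_block = maximum score in the current block
  • ℓ_old, ℓ_new = running sum of exponentials before and after this block (l_prev and l_new in the code)
  • s_i = attention score for key i in the current block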
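The same rescaling factor also corrects the output. Below is a short derivation for a single query row, where ℓ is the running sum of exponentials (`l` in the code), O_old is the output normalized over the key blocks seen so far, s_i is the score of key i, and v_i is its value row:

```latex
\begin{aligned}
\ell_{\text{old}}\,O_{\text{old}} &= \sum_{i \in \text{seen}} e^{s_i - m_{\text{old}}}\, v_i
  && \text{(undo the old normalization)} \\
e^{m_{\text{old}} - m_{\text{new}}}\,\ell_{\text{old}}\,O_{\text{old}} &= \sum_{i \in \text{seen}} e^{s_i - m_{\text{new}}}\, v_i
  && \text{(shift to the new max)} \\
O_{\text{new}} &= \frac{e^{m_{\text{old}} - m_{\text{new}}}\,\ell_{\text{old}}\,O_{\text{old}} + \sum_{i \in \text{block}} e^{s_i - m_{\text{new}}}\, v_i}{\ell_{\text{new}}}
  && \text{(add the block, renormalize)}
\end{aligned}
```

This is the quantity computed as `O_new` in `flash_attention_block`.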
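import math
import torch

def flash_attention_block(Q_block, K_block, V_block, O_prev, l_prev, m_prev):
    """Fold one key/value block into the running output for a block of queries.

    O_prev is the output normalized over all key blocks seen so far; l_prev and
    m_prev are the running softmax sum and max per query row (shape [..., rows, 1]).
    Start with O_prev = 0, l_prev = 0, m_prev = -inf.
    """
    # Scaled attention scores for this block
    S_block = Q_block @ K_block.transpose(-2, -1) / math.sqrt(Q_block.shape[-1])

    # Update running max
    m_block = S_block.max(dim=-1, keepdim=True).values
    m_new = torch.maximum(m_prev, m_block)

    # Exponentiate relative to the new max
    P_block = torch.exp(S_block - m_new)

    # Rescale the old sum to the new max and add this block's contribution
    alpha = torch.exp(m_prev - m_new)
    l_new = alpha * l_prev + P_block.sum(dim=-1, keepdim=True)

    # Un-normalize the old output, shift it to the new max, add this block, renormalize
    O_new = (alpha * l_prev * O_prev + P_block @ V_block) / l_new

    return O_new, l_new, m_new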
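As a sanity check on the block update, the following pure-Python sketch applies the same arithmetic as `flash_attention_block` to a single query vector and compares it against a two-pass softmax. The function names, the small inputs, and `block_size=2` are illustrative assumptions; the point is that streaming key/value blocks while keeping the output normalized gives the same answer up to rounding.

```python
import math

def attend_one_query_blocked(q, K, V, block_size=2):
    """Illustrative single-row version of the block update (not a library API).

    O stays normalized after every block; (m, l) are the running max and sum.
    """
    scale = 1.0 / math.sqrt(len(q))
    d_v = len(V[0])
    O, l, m = [0.0] * d_v, 0.0, float("-inf")
    for start in range(0, len(K), block_size):
        K_blk, V_blk = K[start:start + block_size], V[start:start + block_size]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_blk]
        m_new = max(m, max(s))
        p = [math.exp(x - m_new) for x in s]
        alpha = math.exp(m - m_new)  # rescales statistics from the old max
        l_new = alpha * l + sum(p)
        O = [(alpha * l * O[c] + sum(pi * v[c] for pi, v in zip(p, V_blk))) / l_new
             for c in range(d_v)]
        l, m = l_new, m_new
    return O

def attend_one_query_reference(q, K, V):
    """Two-pass reference: full softmax over all scores, then a weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    tot = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / tot for c in range(len(V[0]))]
```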

Flash Attention Variants

Flash Attention 2

Definition: Flash Attention 2

Flash Attention 2 (Dao, 2023) improves on the original by: (1) reducing non-matmul FLOPs, (2) better work partitioning across warps, (3) supporting head dimensions up to 256. It achieves ~2x speedup over Flash Attention 1.

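| Feature | Flash Attention 1 | Flash Attention 2 |
|---|---|---|
| Max head dim | 128 | 256 |
| Outer loop | Over K/V blocks | Over Q blocks |
| Work split across warps | K and V split across warps | Q split across warps |
| Output rescaling | At every block | Once, at the end |
| Speedup | Baseline | ~2x faster |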

Flash Attention 3

Definition: Flash Attention 3

Flash Attention 3 (Shah et al., 2024) leverages hardware-specific optimizations for Hopper GPUs (H100): asynchronous execution, FP8 support, and warp specialization. It achieves 1.5-2x speedup over Flash Attention 2 on H100.

Flash Attention 3 uses H100's TMA (Tensor Memory Accelerator) for asynchronous data loading and warp specialization to overlap computation with memory operations.

Implementation

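import math
import torch

def flash_attention_forward(Q, K, V, block_size=256):
    """Simplified Flash Attention forward pass (no masking, no dropout).

    Q, K, V have shape (batch, seq_len, num_heads, head_dim). This reproduces
    the tiling and online-softmax arithmetic in plain PyTorch; it does not
    control SRAM placement the way a fused GPU kernel does.
    """
    batch, seq_len, num_heads, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)

    # Put heads before the sequence axis: (batch, num_heads, seq_len, head_dim)
    Q, K, V = (t.transpose(1, 2) for t in (Q, K, V))

    O = torch.zeros_like(Q)  # un-normalized output accumulator
    l = torch.zeros(batch, num_heads, seq_len, 1, device=Q.device, dtype=Q.dtype)  # running sums
    m = torch.full((batch, num_heads, seq_len, 1), float('-inf'), device=Q.device, dtype=Q.dtype)  # running maxima

    # Process K, V in blocks
    for j in range(0, seq_len, block_size):
        K_block = K[:, :, j:j+block_size, :]
        V_block = V[:, :, j:j+block_size, :]

        # Process Q in blocks
        for i in range(0, seq_len, block_size):
            rows = slice(i, i + block_size)
            Q_block = Q[:, :, rows, :]

            # Compute scaled attention scores for this tile
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale

            # Online softmax update: new running max per query row
            m_block = S_block.max(dim=-1, keepdim=True).values
            m_new = torch.maximum(m[:, :, rows], m_block)

            P_block = torch.exp(S_block - m_new)
            alpha = torch.exp(m[:, :, rows] - m_new)  # rescales statistics from the old max

            # Update the un-normalized output and the running sum
            O[:, :, rows] = alpha * O[:, :, rows] + torch.matmul(P_block, V_block)
            l[:, :, rows] = alpha * l[:, :, rows] + P_block.sum(dim=-1, keepdim=True)
            m[:, :, rows] = m_new

    # Normalize once at the end, then restore (batch, seq_len, num_heads, head_dim)
    O = O / l
    return O.transpose(1, 2)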
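For readers without PyTorch at hand, here is a dependency-free sketch of the same loop structure (outer loop over key/value blocks, inner loop over query blocks, a single division by l at the end), checked against a reference that computes the full softmax per row. Deferring the normalization to the end, as the PyTorch code above also does, is one of the changes Flash Attention 2 uses to reduce non-matmul work. Function names and inputs are illustrative assumptions.

```python
import math

def scaled_scores(q, keys, scale):
    """Dot products of one query with a list of keys, times the softmax scale."""
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]

def tiled_attention(Q, K, V, block_size=2):
    """Illustrative tiled attention with online softmax and deferred normalization.

    O holds un-normalized sums until every key/value block has been seen.
    """
    n, d_v = len(Q), len(V[0])
    scale = 1.0 / math.sqrt(len(Q[0]))
    O = [[0.0] * d_v for _ in range(n)]
    l = [0.0] * n
    m = [float("-inf")] * n
    for j in range(0, len(K), block_size):               # key/value blocks
        K_blk, V_blk = K[j:j + block_size], V[j:j + block_size]
        for i in range(0, n, block_size):                # query blocks
            for r in range(i, min(i + block_size, n)):   # rows in this query block
                s = scaled_scores(Q[r], K_blk, scale)
                m_new = max(m[r], max(s))
                p = [math.exp(x - m_new) for x in s]
                alpha = math.exp(m[r] - m_new)
                O[r] = [alpha * O[r][c] + sum(pi * v[c] for pi, v in zip(p, V_blk))
                        for c in range(d_v)]
                l[r] = alpha * l[r] + sum(p)
                m[r] = m_new
    # Normalize once per row, after all key/value blocks
    return [[o / l[r] for o in O[r]] for r in range(n)]

def reference_attention(Q, K, V):
    """Standard attention: full softmax per query row, then a weighted sum of V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = scaled_scores(q, K, scale)
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        tot = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / tot for c in range(len(V[0]))])
    return out
```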

Memory Comparison

| Method | Memory | Compute | Exact? |
|---|---|---|---|
| Standard Attention | O(N^2) | O(N^2 d) | Yes |
| Flash Attention | O(N) | O(N^2 d) | Yes |
| Sparse Attention | O(N sqrt(N)) | O(N sqrt(N) d) | Approximate |
| Linear Attention | O(N) | O(N d^2) | Approximate |

Among the methods in this table, Flash Attention is the only one that combines O(N) memory with exact computation; the sparse and linear variants trade exactness for lower cost.
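To put the Memory column into numbers, this sketch compares the N x N score matrix that a standard implementation stores against the two per-row statistics (m and l) that Flash Attention keeps. Assumptions: scores in fp16, statistics in fp32, one head, one sequence; many standard implementations also store P, which would double the first figure. `attention_memory_bytes` is a hypothetical helper.

```python
def attention_memory_bytes(seq_len, num_heads=1, batch=1):
    """Hypothetical helper: bytes for the attention intermediates of one layer.

    standard: one N x N score matrix per (batch, head), fp16 = 2 bytes
    flash:    running max m and sum l per query row, fp32 = 4 bytes each
    Q, K, V, and O are needed by both and are not counted.
    """
    standard = batch * num_heads * seq_len * seq_len * 2
    flash = batch * num_heads * seq_len * 2 * 4
    return standard, flash

std, fl = attention_memory_bytes(seq_len=4096)
print(std // 2**20, "MiB vs", fl // 2**10, "KiB")  # 32 MiB vs 32 KiB
```

Doubling the sequence length quadruples the first figure but only doubles the second.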

Practice Exercises

  1. Memory Calculation: For a 70B model with 4096 sequence length, calculate the memory savings of Flash Attention vs standard attention.

  2. Block Size Analysis: How does block size affect the speed of Flash Attention? What is the optimal block size for A100 vs H100 GPUs?

  3. Implementation: Implement Flash Attention for a simplified 1-layer, 1-head attention mechanism and verify that it matches standard attention to within floating-point tolerance.

  4. Profiling: Profile the HBM access patterns of standard vs Flash Attention. Where does the speedup come from?

Key Takeaways

Summary: Flash Attention

  • Standard attention is memory-bandwidth (IO) bound, not compute-bound, at typical sequence lengths
  • Flash Attention tiles attention blocks in SRAM, minimizing HBM access
  • Online softmax enables incremental computation without materializing the N x N matrix
  • Exact computation: not an approximation; matches standard attention up to floating-point rounding
  • O(N) extra memory instead of O(N^2) for attention
  • 2-4x speedup over a standard (unfused) attention implementation
  • Flash Attention 2 reduces non-matmul FLOPs and improves warp partitioning
  • Flash Attention 3 leverages H100 hardware for additional speedup
  • Widely used in serving: vLLM, TensorRT-LLM, and TGI rely on Flash Attention or closely related fused attention kernels

What to Learn Next

-> KV Cache Optimization: reducing memory usage of the key-value cache.

-> Model Parallelism and Tensor Parallelism: splitting models across GPUs.

-> Attention Mechanisms Deep Dive: understanding attention in neural networks.

-> LLM Inference Optimization: broader inference optimization strategies.

-> Long Context and Context Window: handling very long sequences.

-> Transformers: the architecture built on attention.
