
KV Cache Optimization


Inference Optimization

KV Cache Optimization — The Memory Bottleneck

The KV cache stores the key and value tensors of all previous tokens during generation. At long context lengths and large batch sizes, its total size can exceed the memory taken by the model weights, so optimizing it is essential for serving.

  • PagedAttention — Non-contiguous memory allocation eliminates fragmentation
  • RadixAttention — Prefix sharing across requests reduces redundant computation
  • Cache Quantization — Store KV cache in lower precision for 2-4x memory reduction

The KV cache is the hidden cost of autoregressive generation.

KV Cache Optimization for LLMs

During autoregressive generation, each new token requires attention over all previous tokens. The KV cache stores pre-computed key and value tensors, avoiding redundant computation. However, for long sequences and large models, the KV cache becomes the primary memory bottleneck.

Definition: KV Cache

The KV cache stores the key (K) and value (V) tensors computed during attention for all previous tokens. Without caching, every generation step would have to recompute K and V for the entire prefix and rerun attention over it, making the attention cost of each step O(n^2) in the sequence length n instead of O(n).
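To make the saving concrete, here is a minimal single-head sketch in plain Python. The `project` function is a hypothetical stand-in for the learned K and V projections (it only scales the input), and the counters track how many projections each strategy performs; this is an illustration, not how a real model computes attention.

```python
import math


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def project(x, scale):
    """Hypothetical stand-in for a learned K or V projection."""
    return [scale * xi for xi in x]


def decode_without_cache(tokens):
    """Recompute K and V for the whole prefix at every step."""
    outputs, projections = [], 0
    for t in range(1, len(tokens) + 1):
        keys = [project(x, 0.5) for x in tokens[:t]]
        values = [project(x, 2.0) for x in tokens[:t]]
        projections += 2 * t
        outputs.append(attend(tokens[t - 1], keys, values))
    return outputs, projections


def decode_with_cache(tokens):
    """Project only the newest token and append it to the cache."""
    outputs, projections = [], 0
    k_cache, v_cache = [], []
    for x in tokens:
        k_cache.append(project(x, 0.5))
        v_cache.append(project(x, 2.0))
        projections += 2
        outputs.append(attend(x, k_cache, v_cache))
    return outputs, projections


tokens = [[0.1 * (i + j) for j in range(4)] for i in range(6)]
out_plain, n_plain = decode_without_cache(tokens)
out_cached, n_cached = decode_with_cache(tokens)
print(n_plain, n_cached)
```

Both strategies produce the same outputs; only the number of projections differs, growing quadratically with the number of tokens in the uncached loop and linearly in the cached one.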

KV Cache Memory Requirements

KV Cache Size

M_{\text{KV}} = 2 \times L \times n_{\text{heads}} \times d_{\text{head}} \times S \times B \times \text{bytes}

Here,

  • L = Number of transformer layers
  • n_{\text{heads}} = Number of attention heads
  • d_{\text{head}} = Dimension per head
  • S = Sequence length
  • B = Batch size
  • \text{bytes} = Bytes per element (2 for FP16, 1 for INT8)

KV Cache Memory for Popular Models

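For sequence length 4096, batch size 1, FP16 (with 1 GB taken as 2^30 bytes):

Model           Layers  Heads  Head Dim  KV Cache Size
LLaMA-2 7B      32      32     128       2 GB
LLaMA-2 70B     80      64     128       10 GB
GPT-4 (est.)    120     96     128       22.5 GB

The LLaMA-2 70B row counts all 64 heads as KV heads (standard multi-head attention); the GPT-4 row is an estimate.

These figures are per request, so the total grows linearly with batch size. Eight concurrent 4096-token requests to LLaMA-2 7B need about 16 GB of KV cache, more than the roughly 13 GB taken by the model's FP16 weights.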
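The size formula can be evaluated directly. The sketch below is a plain-Python transcription of it; the function name `kv_cache_bytes` and the convention 1 GiB = 2^30 bytes are choices made for this example.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """KV cache size in bytes; the leading 2 accounts for keys and values."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)


GIB = 1024 ** 3

# Sequence length 4096, batch size 1, FP16 (2 bytes per element)
llama2_7b = kv_cache_bytes(32, 32, 128, 4096)
llama2_70b_mha = kv_cache_bytes(80, 64, 128, 4096)   # all 64 heads as KV heads
per_token_7b = kv_cache_bytes(32, 32, 128, 1)        # cost of one extra token

print(f"LLaMA-2 7B:        {llama2_7b / GIB:.2f} GiB")
print(f"LLaMA-2 70B (MHA): {llama2_70b_mha / GIB:.2f} GiB")
print(f"Per token (7B):    {per_token_7b / 1024:.0f} KiB")
```

The per-token figure is often the more useful one for capacity planning, since it multiplies directly with the total number of tokens held across all active requests.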

PagedAttention

The Problem: Memory Fragmentation

Definition: KV Cache Fragmentation

Traditional KV cache allocation reserves one contiguous memory block per request, typically sized for the longest sequence the request might reach. As requests of varying lengths are served, memory is lost both inside over-sized reservations and in free gaps too small to allocate, leading to 50-80% memory waste.

The Solution: PagedAttention

Definition: PagedAttention

PagedAttention (Kwon et al., 2023) borrows from operating system virtual memory. It divides the KV cache into fixed-size "pages" that can be stored non-contiguously in GPU memory. A page table maps logical KV cache positions to physical memory blocks.

Architecture Diagram
Logical KV Cache:  [Page 0]  [Page 1]  [Page 2]  [Page 3]
                       |         |         |         |
Physical Memory:   [Block 5] [Block 2] [Block 8] [Block 1]
                       |         |         |         |
                   GPU Memory: Non-contiguous allocation

PagedAttention reduces memory waste from 50-80% to less than 4%. This enables serving 2-4x more concurrent requests with the same GPU memory.
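A rough way to see where the savings come from is to compare reserved versus used token slots under both schemes. The sketch below models only over-reservation (every request reserving `max_seq_len` slots up front) and ignores gaps between blocks; the request lengths are made up for illustration.

```python
import math


def contiguous_waste(request_lengths, max_seq_len):
    """Unused fraction when every request reserves max_seq_len slots up front."""
    reserved = len(request_lengths) * max_seq_len
    return 1 - sum(request_lengths) / reserved


def paged_waste(request_lengths, page_size):
    """Unused fraction when pages are allocated on demand.

    Only the last, partially filled page of each request is wasted.
    """
    reserved = sum(math.ceil(n / page_size) * page_size for n in request_lengths)
    return 1 - sum(request_lengths) / reserved


lengths = [120, 900, 350, 2048, 64, 1500]  # hypothetical token counts

waste_contiguous = contiguous_waste(lengths, max_seq_len=2048)
waste_paged = paged_waste(lengths, page_size=16)
print(f"contiguous: {waste_contiguous:.1%}, paged: {waste_paged:.1%}")
```

With these lengths the contiguous scheme leaves roughly 59% of its reserved slots unused, while on-demand pages leave under 1%. Smaller pages reduce the waste further, at the cost of a longer page table.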

Implementation

import torch


class PagedKVCache:
    """Toy paged KV cache for a single layer: fixed-size pages plus a page table."""

    def __init__(self, page_size=16, num_pages=1024, num_layers=32, num_heads=32,
                 head_dim=128, dtype=torch.float16, device=None):
        self.page_size = page_size
        self.num_pages = num_pages
        # num_layers is kept for reference only: the tensors below hold one
        # layer's cache, and a full model would allocate one pair per layer.
        self.num_layers = num_layers
        self.page_table = {}  # request_id -> list of physical pages
        self.free_pages = list(range(num_pages))
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Pre-allocate all pages: [page, slot within page, head, head_dim]
        shape = (num_pages, page_size, num_heads, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(shape, dtype=dtype, device=device)

    def allocate_page(self, request_id):
        """Take one free physical page and append it to the request's page list."""
        if not self.free_pages:
            raise MemoryError("No free pages available")
        page = self.free_pages.pop()
        self.page_table.setdefault(request_id, []).append(page)
        return page

    def free_request(self, request_id):
        """Return all pages of a finished request to the free list."""
        self.free_pages.extend(self.page_table.pop(request_id, []))

    def _locate(self, request_id, position):
        """Translate a logical token position to (physical page, offset)."""
        page_idx, offset = divmod(position, self.page_size)
        return self.page_table[request_id][page_idx], offset

    def get_kv(self, request_id, position):
        physical_page, offset = self._locate(request_id, position)
        return self.k_cache[physical_page, offset], self.v_cache[physical_page, offset]

    def store_kv(self, request_id, position, k, v):
        # Allocate pages until the one covering this position exists
        while position // self.page_size >= len(self.page_table.get(request_id, [])):
            self.allocate_page(request_id)

        physical_page, offset = self._locate(request_id, position)
        self.k_cache[physical_page, offset] = k
        self.v_cache[physical_page, offset] = v

RadixAttention

Definition: RadixAttention

RadixAttention (Zheng et al., 2024) uses a radix tree (a compressed trie) keyed on token sequences to automatically share KV cache entries across requests that begin with the same tokens. This is particularly effective for system prompts shared across many requests.

Architecture Diagram
Radix Tree for KV Cache Sharing:
              [System Prompt]
                    |
          [User A Response Start]
         /                        \
[User A Response 1]    [User A Response 2]
         |                        |
[Generated A1]           [Generated A2]

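Both leaf nodes share the KV cache of the System Prompt node and of the [User A Response Start] node.

In chat applications with a long system prompt, that prompt can account for most of each request's KV cache. RadixAttention stores it once and shares it across all requests: if the system prompt makes up 80% of a request's tokens, the memory unique to each request drops from 1 GB to about 200 MB.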
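The prefix lookup behind this kind of sharing can be sketched with a token-level trie. This is a simplification, not the implementation from the paper: a real radix tree merges single-child chains into one edge, and a production cache also needs eviction and reference counting. The class and method names here are hypothetical.

```python
class PrefixNode:
    def __init__(self):
        self.children = {}  # token id -> PrefixNode


class PrefixCache:
    """Token-level trie that tracks which token prefixes already have KV entries."""

    def __init__(self):
        self.root = PrefixNode()
        self.cached_tokens = 0  # distinct token positions stored

    def match(self, tokens):
        """Return the length of the longest cached prefix of `tokens`."""
        node, matched = self.root, 0
        for t in tokens:
            node = node.children.get(t)
            if node is None:
                break
            matched += 1
        return matched

    def insert(self, tokens):
        """Insert a sequence; return how many tokens needed new KV entries."""
        node, new = self.root, 0
        for t in tokens:
            if t not in node.children:
                node.children[t] = PrefixNode()
                new += 1
            node = node.children[t]
        self.cached_tokens += new
        return new


system_prompt = list(range(100))              # 100 shared token ids
request_a = system_prompt + [1000, 1001]
request_b = system_prompt + [2000, 2001, 2002]

cache = PrefixCache()
new_a = cache.insert(request_a)    # everything is new
reused_b = cache.match(request_b)  # the system prompt is already cached
new_b = cache.insert(request_b)    # only the suffix is new
print(new_a, reused_b, new_b, cache.cached_tokens)
```

The second request reuses all 100 system-prompt tokens and adds only its own 3, so the cache holds 105 token positions instead of the 207 that two independent caches would need.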

Prefix Caching Benefits

Scenario                  Prefix Size  Requests  Memory Saved  Latency Reduction
System prompt sharing     2000 tokens  1000      99%           40% (fewer prefill steps)
Document QA with context  8000 tokens  100       95%           60%
Few-shot prompting        500 tokens   500       80%           20%
Code completion           1000 tokens  2000      90%           30%
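How much memory prefix sharing saves depends on how long the unshared part of each request is. The sketch below counts stored token positions with and without sharing; the suffix lengths are assumptions for illustration and are not taken from the table.

```python
def prefix_sharing_savings(prefix_tokens, suffix_tokens, num_requests):
    """Fraction of KV cache token slots saved by storing the prefix once."""
    without_sharing = num_requests * (prefix_tokens + suffix_tokens)
    with_sharing = prefix_tokens + num_requests * suffix_tokens
    return 1 - with_sharing / without_sharing


# Counting only the shared prefix (suffix of 0 tokens)
prefix_only = prefix_sharing_savings(2000, 0, 1000)

# Same prefix, but each request also holds 500 unshared tokens (assumed)
with_suffix = prefix_sharing_savings(2000, 500, 1000)

print(f"{prefix_only:.1%} vs {with_suffix:.1%}")
```

Measured against the prefix alone, the saving approaches 100% as the number of requests grows. Measured against the whole request, it is capped by the share of tokens that the prefix represents.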

KV Cache Quantization

Definition: KV Cache Quantization

KV cache quantization stores key and value tensors in lower precision (INT8 or INT4) instead of FP16. This reduces KV cache memory by 2-4x with minimal quality degradation.

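import torch


def _quantize(x, qmax, group_size=None):
    """Symmetric quantization: x is approximated by q * scale, q in [-qmax, qmax]."""
    # Work in FP32; group_size must divide the number of elements
    groups = x.float().reshape(-1, group_size) if group_size else x.float().reshape(1, -1)
    # One scale per group; the clamp avoids division by zero for all-zero groups
    scale = (groups.abs().amax(dim=1, keepdim=True) / qmax).clamp(min=1e-8)
    q = (groups / scale).round().clamp(-qmax, qmax).to(torch.int8)
    return q.reshape(x.shape), scale


def quantize_kv_cache(k_cache, v_cache, bits=8):
    """Quantize KV cache to lower precision.

    Returns (k_quantized, v_quantized, k_scale, v_scale).
    """
    if bits == 8:
        # Per-tensor scale that maps the largest magnitude to 127
        k_quantized, k_scale = _quantize(k_cache, 127)
        v_quantized, v_scale = _quantize(v_cache, 127)
    elif bits == 4:
        # Group quantization with group size 128; values lie in [-7, 7] but are
        # held in int8 here (a real kernel would pack two values per byte)
        k_quantized, k_scale = _quantize(k_cache, 7, group_size=128)
        v_quantized, v_scale = _quantize(v_cache, 7, group_size=128)
    else:
        raise ValueError(f"Unsupported bit width: {bits}")
    return k_quantized, v_quantized, k_scale, v_scale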

Quantization Error Bound

\| \hat{K} - K \|_{\infty} \leq \frac{\|K\|_{\infty}}{2^{b-1}}

Here,

  • \hat{K} = Quantized key tensor (after dequantization back to floating point)
  • K = Original key tensor
  • b = Number of bits for quantization
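The bound can be checked numerically with a small round trip. The sketch below assumes symmetric per-tensor quantization with round-to-nearest, uses plain Python lists in place of tensors, and the sample values are arbitrary.

```python
def quantize_symmetric(values, bits):
    """Map floats to integers in [-qmax, qmax] with a single shared scale."""
    qmax = 2 ** (bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    return [qi * scale for qi in q]


def max_abs_error(values, bits):
    """Largest elementwise error after a quantize/dequantize round trip."""
    q, scale = quantize_symmetric(values, bits)
    restored = dequantize(q, scale)
    return max(abs(a - b) for a, b in zip(values, restored))


keys = [0.03, -1.2, 0.75, 0.4, -0.66, 1.9, -0.01, 0.5]  # arbitrary sample
results = {}
for bits in (8, 4):
    bound = max(abs(v) for v in keys) / 2 ** (bits - 1)
    results[bits] = (max_abs_error(keys, bits), bound)
    print(bits, results[bits])
```

Round-to-nearest keeps each element within half a quantization step of its original value, which is why the observed error stays below the bound for both bit widths.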

Grouped-Query Attention (GQA)

Definition: Grouped-Query Attention

GQA shares key-value heads across multiple query heads, reducing the KV cache size proportionally. LLaMA-2 70B uses GQA with 8 KV heads shared across 64 query heads, reducing KV cache by 8x.

GQA KV Cache Reduction

M_{\text{GQA}} = M_{\text{MHA}} \times \frac{n_{\text{kv\_heads}}}{n_{\text{query\_heads}}}

Here,

  • M_{\text{GQA}} = GQA KV cache size
  • M_{\text{MHA}} = Standard MHA KV cache size
  • n_{\text{kv\_heads}} = Number of KV heads
  • n_{\text{query\_heads}} = Number of query heads

GQA Memory Savings

LLaMA-2 70B with GQA (8 KV heads, 64 query heads):

  • Standard MHA: 2 (K and V) x 80 layers x 64 heads x 128 dim x 4096 seq x 2 bytes (FP16) = 10 GB
  • With GQA (8 KV heads): 2 x 80 x 8 x 128 x 4096 x 2 bytes = 1.25 GB
  • Memory reduction: 8x
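The head sharing and the resulting cache size can be sketched in a few lines. The mapping below assumes consecutive query heads form a group, which is a common convention but an assumption here; the function names are illustrative.

```python
def kv_head_for_query_head(q_head, n_query_heads, n_kv_heads):
    """Index of the KV head that a given query head reads from."""
    group_size = n_query_heads // n_kv_heads  # query heads per KV head
    return q_head // group_size


def gqa_cache_bytes(mha_bytes, n_kv_heads, n_query_heads):
    """Scale the MHA cache size by the ratio of KV heads to query heads."""
    return mha_bytes * n_kv_heads // n_query_heads


GIB = 1024 ** 3

# LLaMA-2 70B: 64 query heads, 8 KV heads
mapping = [kv_head_for_query_head(h, 64, 8) for h in range(64)]
gqa_size = gqa_cache_bytes(10 * GIB, n_kv_heads=8, n_query_heads=64)

print(mapping[:10])
print(f"{gqa_size / GIB:.2f} GiB")
```

Setting `n_kv_heads` equal to `n_query_heads` recovers standard multi-head attention, and setting it to 1 gives multi-query attention.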

Practice Exercises

  1. Memory Analysis: Calculate the KV cache memory for a 70B model serving 100 concurrent requests with 2048 token sequences. How much GPU memory is needed?

  2. PagedAttention Design: Design a page allocation policy that minimizes fragmentation for a workload with 30% short requests (<256 tokens) and 70% long requests (>1024 tokens).

  3. Prefix Caching: If you serve a chatbot with a 1000-token system prompt and 1000 concurrent users, how much KV cache memory is saved by prefix caching?

  4. Quantization Tradeoff: Compare INT8 vs INT4 KV cache quantization in terms of memory savings, latency overhead, and quality degradation on a long-context benchmark.

Key Takeaways

Summary: KV Cache Optimization

  • KV cache is the primary memory bottleneck in LLM serving
  • PagedAttention eliminates fragmentation with non-contiguous memory pages
  • RadixAttention shares KV cache prefixes across requests
  • KV cache quantization (INT8/INT4) reduces memory by 2-4x
  • GQA reduces KV cache by sharing heads across query groups
  • Prefix caching eliminates redundant computation for shared prefixes
  • Memory waste drops from 50-80% to <4% with PagedAttention
  • Combined, these techniques can enable serving up to roughly 10x more concurrent requests, depending on the model and workload

What to Learn Next

-> Flash Attention and Memory Efficiency: IO-aware attention algorithms that reduce memory.

-> Continuous Batching for LLMs: Dynamic batching for maximum GPU utilization.

-> Speculative Decoding: Generating multiple tokens per step.

-> LLM Inference Optimization: Broader inference optimization strategies.

-> Long Context and Context Window: Handling very long sequences efficiently.

-> Building Production LLM Applications: End-to-end production systems.
