Inference Optimization
KV Cache Optimization — The Memory Bottleneck
The KV cache stores key-value tensors for all previous tokens during generation. With long sequences and large batches, it can exceed the memory footprint of the model weights, so optimizing it is essential for serving.
- PagedAttention — Non-contiguous memory allocation eliminates fragmentation
- RadixAttention — Prefix sharing across requests reduces redundant computation
- Cache Quantization — Store KV cache in lower precision for 2-4x memory reduction
The KV cache is the hidden cost of autoregressive generation.
KV Cache Optimization for LLMs
During autoregressive generation, each new token requires attention over all previous tokens. The KV cache stores pre-computed key and value tensors, avoiding redundant computation. However, for long sequences and large models, the KV cache becomes the primary memory bottleneck.
Definition: KV Cache
The KV cache stores the key (K) and value (V) tensors computed during attention for all previous tokens. Without caching, every decoding step would recompute the keys and values of all previous tokens and rerun attention over the whole sequence, so the attention cost of each step would grow as O(n^2) in the sequence length instead of O(n).
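The redundancy that caching removes can be made concrete with a toy decoder. The sketch below is an illustration only: `project` is a hypothetical stand-in for the learned K and V projections, each token is used as its own query, and the counters track K/V projections rather than full FLOPs.

```python
import math

def project(x, scale):
    # Hypothetical stand-in for a learned linear projection.
    return [scale * xi for xi in x]

def attend(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d)]

def decode_without_cache(tokens):
    outputs, projections = [], 0
    for t in range(1, len(tokens) + 1):
        # Recompute K and V for every token seen so far at each step.
        keys = [project(x, 0.5) for x in tokens[:t]]
        values = [project(x, 2.0) for x in tokens[:t]]
        projections += 2 * t
        outputs.append(attend(tokens[t - 1], keys, values))
    return outputs, projections

def decode_with_cache(tokens):
    outputs, projections = [], 0
    k_cache, v_cache = [], []
    for x in tokens:
        # Only the new token's K and V are computed; older ones are reused.
        k_cache.append(project(x, 0.5))
        v_cache.append(project(x, 2.0))
        projections += 2
        outputs.append(attend(x, k_cache, v_cache))
    return outputs, projections
```

Both functions produce the same outputs. The uncached version performs a number of projections that grows quadratically with the sequence length, while the cached version performs two per token.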
KV Cache Memory Requirements
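KV Cache Size
$$\text{KV cache size} = 2 \times L \times H \times d_h \times S \times B \times n_{\text{bytes}}$$
Here,
- $L$ = Number of transformer layers
- $H$ = Number of attention heads
- $d_h$ = Dimension per head
- $S$ = Sequence length
- $B$ = Batch size
- $n_{\text{bytes}}$ = Bytes per element (2 for FP16, 1 for INT8)
The factor of 2 accounts for storing both a key tensor and a value tensor.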
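The size calculation is easy to script. The helper below is a minimal sketch: the function name and parameter names are illustrative, and it assumes one key tensor and one value tensor per layer, hence the leading factor of 2.

```python
GIB = 1024 ** 3

def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # Factor 2: one tensor for keys and one for values in every layer.
    return (2 * num_layers * num_heads * head_dim
            * seq_len * batch_size * bytes_per_elem)

# 80 layers, 64 heads, head dimension 128, 4096 tokens, FP16
size_gib = kv_cache_bytes(80, 64, 128, 4096) / GIB
```

With 80 layers, 64 heads, head dimension 128, and 4096 tokens in FP16, this evaluates to 10 GiB per request. Doubling either the batch size or the sequence length doubles the result.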
KV Cache Memory for Popular Models
For sequence length 4096, batch size 1, FP16, assuming standard multi-head attention (one KV head per query head):
| Model | Layers | Heads | Head Dim | KV Cache Size |
|---|---|---|---|---|
| LLaMA-2 7B | 32 | 32 | 128 | 2 GB |
| LLaMA-2 70B | 80 | 64 | 128 | 10 GB |
| GPT-4 (est.) | 120 | 96 | 128 | 22.5 GB |
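At batch size 1 these caches are smaller than the model weights, but the cache grows linearly with both batch size and sequence length. Across many concurrent or long-context requests, it can exceed the model's parameter memory.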
PagedAttention
The Problem: Memory Fragmentation
Definition: KV Cache Fragmentation
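Traditional KV cache allocation reserves one contiguous memory block per request, usually sized for the maximum possible sequence length. Because requests vary in length and finish at different times, much of each reservation goes unused, and the free memory between blocks is split into chunks too small to allocate. This fragmentation can waste 50-80% of KV cache memory.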
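A rough feel for the waste comes from modelling only the over-reservation part of the problem. The sketch below assumes every request reserves a contiguous region sized for `max_len` tokens; the function name and the request lengths are made up for illustration.

```python
def reserved_waste_fraction(actual_lengths, max_len):
    """Fraction of reserved KV slots left unused when each request
    reserves a contiguous region sized for max_len tokens."""
    reserved = max_len * len(actual_lengths)
    used = sum(min(n, max_len) for n in actual_lengths)
    return 1 - used / reserved

# Hypothetical request lengths, each reserving room for 2048 tokens
waste = reserved_waste_fraction([100, 300, 50, 1200, 400], max_len=2048)
```

For these example lengths about 80% of the reserved slots are never written. External fragmentation between reservations, which this sketch does not model, adds further waste.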
The Solution: PagedAttention
Definition: PagedAttention
PagedAttention (Kwon et al., 2023) borrows from operating system virtual memory. It divides the KV cache into fixed-size "pages" that can be stored non-contiguously in GPU memory. A page table maps logical KV cache positions to physical memory blocks.
```text
Logical KV Cache:  [Page 0]   [Page 1]   [Page 2]   [Page 3]
                      |          |          |          |
Physical Memory:   [Block 5]  [Block 2]  [Block 8]  [Block 1]
                      |          |          |          |
GPU Memory:        Non-contiguous allocation
```
PagedAttention reduces memory waste from 50-80% to less than 4%. This enables serving 2-4x more concurrent requests with the same GPU memory.
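Two small helpers illustrate why paging wastes so little. This is a sketch under simple assumptions (fixed page size of 16 tokens, hypothetical function names): the only unused slots are in each request's last, partially filled page, and a token position is resolved through the page table.

```python
import math

def paged_waste_fraction(actual_lengths, page_size=16):
    """Unused slots exist only in each request's last, partially filled page."""
    allocated = sum(math.ceil(n / page_size) * page_size for n in actual_lengths)
    return 1 - sum(actual_lengths) / allocated

def logical_to_physical(page_table, position, page_size=16):
    """Map a token position to (physical block, offset) through the page table."""
    return page_table[position // page_size], position % page_size

# Page table from the diagram: logical pages 0..3 -> physical blocks 5, 2, 8, 1
block, offset = logical_to_physical([5, 2, 8, 1], position=37)
```

Token 37 lives in logical page 2, which the page table maps to physical block 8 at offset 5. For request lengths such as 100, 300, 50, 1200, and 400 tokens, the wasted fraction is under 2%.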
Implementation
```python
import torch


class PagedKVCache:
    """Toy paged KV cache: fixed-size pages drawn from one pre-allocated pool."""

    def __init__(self, page_size=16, num_pages=1024, num_layers=32, num_heads=32,
                 head_dim=128, dtype=torch.float16, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.page_size = page_size
        self.num_pages = num_pages
        self.page_table = {}  # request_id -> list of physical pages
        self.free_pages = list(range(num_pages))
        # Pre-allocate all pages for every layer.
        # With the defaults this is 4 GiB per tensor in FP16.
        shape = (num_layers, num_pages, page_size, num_heads, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(shape, dtype=dtype, device=device)

    def allocate_page(self, request_id):
        if not self.free_pages:
            raise MemoryError("No free pages available")
        page = self.free_pages.pop()
        self.page_table.setdefault(request_id, []).append(page)
        return page

    def free_request(self, request_id):
        # Return all pages of a finished request to the free list
        self.free_pages.extend(self.page_table.pop(request_id, []))

    def _locate(self, request_id, position):
        # Translate a logical token position into (physical page, offset)
        page_idx, offset = divmod(position, self.page_size)
        return self.page_table[request_id][page_idx], offset

    def get_kv(self, request_id, layer, position):
        page, offset = self._locate(request_id, position)
        return self.k_cache[layer, page, offset], self.v_cache[layer, page, offset]

    def store_kv(self, request_id, layer, position, k, v):
        # Allocate pages on demand until the position is covered
        while position // self.page_size >= len(self.page_table.get(request_id, [])):
            self.allocate_page(request_id)
        page, offset = self._locate(request_id, position)
        self.k_cache[layer, page, offset] = k
        self.v_cache[layer, page, offset] = v
```
RadixAttention
Definition: RadixAttention
RadixAttention (Zheng et al., 2024) uses a radix tree (a compressed trie) to automatically share cached KV entries across requests that begin with the same tokens. This is particularly effective for system prompts shared across many requests.
Radix Tree for KV Cache Sharing:
```text
                [System Prompt]
                       |
            [User A Response Start]
               /                \
[User A Response 1]        [User A Response 2]
        |                          |
  [Generated A1]             [Generated A2]
```
Both leaf nodes share the System Prompt KV cache.
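The lookup behind prefix sharing can be sketched with a plain token-level trie. This is a simplification, not the data structure used by RadixAttention itself: a real radix tree compresses chains of single-child nodes and also handles eviction. The class and method names are hypothetical.

```python
class PrefixCache:
    """Token-level trie; each node stands for one cached token's KV entry."""

    def __init__(self):
        self.root = {}

    def match_prefix(self, tokens):
        """Return how many leading tokens already have cached KV entries."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node:
                break
            node = node[tok]
            matched += 1
        return matched

    def insert(self, tokens):
        """Record that KV entries for this token sequence are now cached."""
        node = self.root
        for tok in tokens:
            node = node.setdefault(tok, {})


cache = PrefixCache()
system_prompt = [1, 2, 3, 4]            # hypothetical token ids
cache.insert(system_prompt + [10, 11])  # first request
reused = cache.match_prefix(system_prompt + [20, 21])  # second request
```

The second request matches the four system-prompt tokens, so prefill only has to compute KV entries for its own two tokens.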
In chat applications with long system prompts, the system prompt can account for 60-80% of the KV cache. RadixAttention shares this prefix across all requests. For example, if the prefix makes up 80% of a 1 GB per-request cache, the request-specific part is only about 200 MB.
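The saving follows from simple token counting. The sketch below assumes every request shares one prefix of `prefix_tokens` tokens and adds `unique_tokens` tokens of its own; the function name and the example numbers are illustrative.

```python
def prefix_sharing_savings(prefix_tokens, unique_tokens, num_requests):
    """Fraction of KV cache entries saved by storing the shared prefix once."""
    without_sharing = num_requests * (prefix_tokens + unique_tokens)
    with_sharing = prefix_tokens + num_requests * unique_tokens
    return 1 - with_sharing / without_sharing

# Prefix is 80% of each request (800 of 1000 tokens), 1000 concurrent requests
saved = prefix_sharing_savings(prefix_tokens=800, unique_tokens=200, num_requests=1000)
```

With a prefix that makes up 80% of each request, the saving approaches 80% as the number of requests grows, because the single shared copy becomes negligible.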
Prefix Caching Benefits
| Scenario | Prefix Size | Requests | Memory Saved | Latency Reduction |
|---|---|---|---|---|
| System prompt sharing | 2000 tokens | 1000 | 99% | 40% (fewer prefill steps) |
| Document QA with context | 8000 tokens | 100 | 95% | 60% |
| Few-shot prompting | 500 tokens | 500 | 80% | 20% |
| Code completion | 1000 tokens | 2000 | 90% | 30% |
KV Cache Quantization
Definition: KV Cache Quantization
KV cache quantization stores key and value tensors in lower precision (INT8 or INT4) instead of FP16. This reduces KV cache memory by 2-4x with minimal quality degradation.
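```python
import torch


def _quantize(x, qmax, group_size=None):
    """Symmetric quantization; returns int8 codes and the scale(s)."""
    groups = x.float().reshape(-1, group_size) if group_size else x.float().reshape(1, -1)
    # Scale so the largest magnitude in each group maps to qmax
    scale = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = (groups / scale).round().clamp(-qmax, qmax).to(torch.int8)
    return q.reshape(x.shape), scale


def quantize_kv_cache(k_cache, v_cache, bits=8):
    """Quantize KV cache to lower precision."""
    if bits == 8:
        # One scale per tensor, codes in [-127, 127]
        k_quantized, k_scale = _quantize(k_cache, 127)
        v_quantized, v_scale = _quantize(v_cache, 127)
    elif bits == 4:
        # Group quantization with group size 128, codes in [-7, 7].
        # Codes are held in int8 here; real kernels pack two per byte.
        k_quantized, k_scale = _quantize(k_cache, 7, group_size=128)
        v_quantized, v_scale = _quantize(v_cache, 7, group_size=128)
    else:
        raise ValueError("bits must be 4 or 8")
    return k_quantized, v_quantized, k_scale, v_scale
```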
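Quantization Error Bound
For symmetric round-to-nearest quantization with one scale per tensor or group:
$$\lVert \hat{K} - K \rVert_\infty \le \frac{\max\lvert K \rvert}{2\,(2^{b-1} - 1)}$$
Here,
- $\hat{K}$ = Quantized key tensor (after dequantization)
- $K$ = Original key tensor
- $b$ = Number of bits for quantization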
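A round trip on a few numbers shows where an error bound of this kind comes from. The sketch assumes symmetric round-to-nearest quantization with a single scale, so each element lands within half a quantization step of its original value. The function names and sample values are illustrative.

```python
def quantize_symmetric(values, bits=8):
    """Map floats to signed integer codes with one shared scale."""
    qmax = 2 ** (bits - 1) - 1          # 127 for 8 bits, 7 for 4 bits
    scale = (max(abs(v) for v in values) / qmax) or 1.0
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return codes, scale

def dequantize(codes, scale):
    return [c * scale for c in codes]

values = [0.12, -1.5, 0.73, 2.0, -0.01]   # hypothetical key entries
codes8, scale8 = quantize_symmetric(values, bits=8)
max_err8 = max(abs(a - b) for a, b in zip(values, dequantize(codes8, scale8)))
```

The maximum round-trip error never exceeds half the scale. Going from 8 to 4 bits makes the scale, and therefore the worst-case error, roughly 18 times larger (127 / 7).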
Grouped-Query Attention (GQA)
Definition: Grouped-Query Attention
GQA shares key-value heads across multiple query heads, reducing the KV cache size proportionally. LLaMA-2 70B uses GQA with 8 KV heads shared across 64 query heads, reducing KV cache by 8x.
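GQA KV Cache Reduction
$$M_{\text{GQA}} = M_{\text{MHA}} \times \frac{H_{kv}}{H_q}$$
Here,
- $M_{\text{GQA}}$ = GQA KV cache size
- $M_{\text{MHA}}$ = Standard MHA KV cache size
- $H_{kv}$ = Number of KV heads
- $H_q$ = Number of query heads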
GQA Memory Savings
LLaMA-2 70B with GQA (8 KV heads, 64 query heads), sequence length 4096, FP16:
- Standard MHA: 2 (K and V) x 80 layers x 64 heads x 128 dim x 4096 seq x 2 bytes = 10 GB
- With GQA (8 KV heads): 2 x 80 x 8 x 128 x 4096 x 2 bytes = 1.25 GB
- Memory reduction: 8x
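The head sharing itself is just integer division. The sketch below assumes query heads are grouped contiguously, which is the usual convention, and uses hypothetical function names.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Index of the KV head that a given query head reads from."""
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

def gqa_cache_ratio(num_q_heads, num_kv_heads):
    """KV cache size with GQA relative to standard multi-head attention."""
    return num_kv_heads / num_q_heads

# 64 query heads sharing 8 KV heads: groups of 8 query heads each
mapping = [kv_head_for_query_head(q, 64, 8) for q in range(64)]
ratio = gqa_cache_ratio(64, 8)
```

Query heads 0 through 7 read KV head 0, heads 8 through 15 read KV head 1, and so on. The cache shrinks to 1/8 of its multi-head size.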
Practice Exercises
- Memory Analysis: Calculate the KV cache memory for a 70B model serving 100 concurrent requests with 2048 token sequences. How much GPU memory is needed?
- PagedAttention Design: Design a page allocation policy that minimizes fragmentation for a workload with 30% short requests (<256 tokens) and 70% long requests (>1024 tokens).
- Prefix Caching: If you serve a chatbot with a 1000-token system prompt and 1000 concurrent users, how much KV cache memory is saved by prefix caching?
- Quantization Tradeoff: Compare INT8 vs INT4 KV cache quantization in terms of memory savings, latency overhead, and quality degradation on a long-context benchmark.
Key Takeaways
Summary: KV Cache Optimization
- KV cache is the primary memory bottleneck in LLM serving
- PagedAttention eliminates fragmentation with non-contiguous memory pages
- RadixAttention shares KV cache prefixes across requests
- KV cache quantization (INT8/INT4) reduces memory by 2-4x
- GQA reduces KV cache by sharing heads across query groups
- Prefix caching eliminates redundant computation for shared prefixes
- Memory waste drops from 50-80% to <4% with PagedAttention
- Combined, these techniques can substantially increase the number of concurrent requests served on the same hardware
What to Learn Next
-> Flash Attention and Memory Efficiency: IO-aware attention algorithms that reduce memory.
-> Continuous Batching for LLMs: Dynamic batching for maximum GPU utilization.
-> Speculative Decoding: Generating multiple tokens per step.
-> LLM Inference Optimization: Broader inference optimization strategies.
-> Long Context and Context Window: Handling very long sequences efficiently.
-> Building Production LLM Applications: End-to-end production systems.