LLM Inference Optimization
Efficient inference is critical for deploying language models in production. The autoregressive nature of LLMs (generating one token at a time) creates unique optimization challenges. This tutorial covers the core techniques for maximizing throughput and minimizing latency.
The Inference Challenge
LLM inference has two distinct phases:
- Prefill phase: Process the entire input prompt in parallel (compute-bound)
- Decode phase: Generate tokens one at a time (memory-bandwidth-bound)
Throughput: the number of tokens generated per second, measured as total output tokens divided by wall-clock time. During the decode phase, throughput is limited by memory bandwidth rather than by compute.
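The throughput definition can be turned into a small calculation. The sketch below also includes a rough single-sequence ceiling for the decode phase, under the simplifying assumption that every decode step reads all model weights from memory once. The function names and the hardware numbers are hypothetical and chosen only for illustration.

```python
def throughput_tokens_per_s(total_output_tokens: int, wall_clock_s: float) -> float:
    """Throughput: total output tokens divided by wall-clock time."""
    if wall_clock_s <= 0:
        raise ValueError("wall_clock_s must be positive")
    return total_output_tokens / wall_clock_s


def decode_ceiling_tokens_per_s(weight_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Rough ceiling for one sequence: each decode step reads all weights once."""
    return bandwidth_bytes_per_s / weight_bytes


# Hypothetical numbers: 7e9 parameters in float16 (2 bytes each), 1e12 B/s of bandwidth.
measured = throughput_tokens_per_s(total_output_tokens=5120, wall_clock_s=10.0)
ceiling = decode_ceiling_tokens_per_s(weight_bytes=7e9 * 2, bandwidth_bytes_per_s=1e12)
print(f"measured: {measured:.1f} tok/s, single-sequence ceiling: {ceiling:.1f} tok/s")
```

Aggregate throughput across a batch can exceed the single-sequence ceiling, because one read of the weights is shared by every sequence in the batch.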
KV Cache
The KV cache is the most fundamental optimization for autoregressive generation. Instead of recomputing the key and value projections for all previous tokens at each step, we cache them and compute keys and values only for the newest token.
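To make the saving concrete, the sketch below counts how many per-token key/value projections a single layer performs over a generation run, with and without caching. It is a counting model only, not an attention implementation, and the function names are hypothetical.

```python
def kv_projections_without_cache(num_steps: int) -> int:
    """Step t recomputes keys/values for all t tokens seen so far."""
    return sum(range(1, num_steps + 1))


def kv_projections_with_cache(num_steps: int) -> int:
    """Step t computes keys/values only for the newest token."""
    return num_steps


for n in (4, 2048):
    print(n, kv_projections_without_cache(n), kv_projections_with_cache(n))
```

Without the cache the count grows quadratically with sequence length; with it, linearly.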
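KV Cache Memory
$$
M_{\text{KV}} = 2 \times L \times n_{\text{heads}} \times d_{\text{head}} \times s \times b \times \text{bytes}
$$
Here,
- $L$ = number of transformer layers
- $n_{\text{heads}}$ = number of attention (key-value) heads
- $d_{\text{head}}$ = dimension of each head
- $s$ = sequence length in tokens
- $b$ = batch size
- $\text{bytes}$ = bytes per element (2 for float16)
The leading factor of 2 accounts for storing both keys and values.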
For a 7B parameter model with 32 layers, 32 heads, 128 head dimension, and sequence length 2048, the KV cache requires approximately 2 × 32 × 32 × 128 × 2048 × 2 bytes ≈ 1 GB per sequence in float16.
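The arithmetic in this example can be checked with a short helper. This is a minimal sketch; the function and parameter names are hypothetical, and it assumes every attention head stores its own keys and values.

```python
def kv_cache_bytes(num_layers: int, n_heads: int, d_head: int,
                   seq_len: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """KV cache size in bytes; the factor 2 covers keys and values."""
    return 2 * num_layers * n_heads * d_head * seq_len * batch_size * bytes_per_elem


# The 7B configuration from the text: 32 layers, 32 heads, head dim 128, 2048 tokens, float16.
size = kv_cache_bytes(num_layers=32, n_heads=32, d_head=128, seq_len=2048)
print(f"{size} bytes = {size / 2**30:.2f} GiB per sequence")
```

Doubling either the sequence length or the batch size doubles the result, which is why long contexts and large batches compete for the same GPU memory.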
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class KVCacheModel:
    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.kv_cache = None

    def prefill(self, prompt: str):
        """Process entire prompt and cache KV pairs."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=True)
        self.kv_cache = outputs.past_key_values
        # Logits for the last prompt position, shape (batch, vocab)
        return outputs.logits[:, -1, :]

    def decode_step(self, input_ids: torch.Tensor):
        """Generate next token using cached KV pairs."""
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                past_key_values=self.kv_cache,
                use_cache=True
            )
        self.kv_cache = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        return next_token

    def generate(self, prompt: str, max_new_tokens: int = 100):
        """Full greedy generation with KV cache."""
        self.kv_cache = None
        logits = self.prefill(prompt)
        # prefill already returns last-position logits, so take argmax over the vocab
        next_token = torch.argmax(logits, dim=-1, keepdim=True)  # shape (1, 1)
        generated = [next_token.item()]
        for _ in range(max_new_tokens - 1):
            if generated[-1] == self.tokenizer.eos_token_id:
                break  # stop once the model emits end-of-sequence
            next_token = self.decode_step(next_token)
            generated.append(next_token.item())
        return self.tokenizer.decode(generated, skip_special_tokens=True)
Quantization
Quantization reduces model size and accelerates inference by using lower-precision representations.
Quantization Methods Comparison
| Method | Type | Bits | Speedup | Quality Impact |
|---|---|---|---|---|
| FP16 | Post-training | 16 | 1x | None |
| INT8 | Post-training | 8 | 1.5-2x | Minimal |
| INT4 (GPTQ) | Post-training | 4 | 2-3x | Small |
| AWQ | Post-training | 4 | 2-3x | Small |
| GGUF (llama.cpp file format) | Post-training | 2-8 | Variable | Variable |
| QLoRA | Training | 4 | N/A | Minimal |
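As a concrete picture of what "lower-precision representation" means, the sketch below applies symmetric absmax quantization to a handful of weights at 8 and 4 bits. This is the simplest round-to-nearest scheme, shown for illustration only; it is not the algorithm used by GPTQ or AWQ, and the function names are hypothetical.

```python
def quantize_absmax(weights, bits=8):
    """Map floats to signed integers using one scale for the whole group."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits, 7 for 4 bits
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Recover approximate floats from the integer codes."""
    return [v * scale for v in q]


weights = [0.6, -1.27, 0.3, 0.0]
q8, s8 = quantize_absmax(weights, bits=8)
q4, s4 = quantize_absmax(weights, bits=4)
err8 = max(abs(w - d) for w, d in zip(weights, dequantize(q8, s8)))
err4 = max(abs(w - d) for w, d in zip(weights, dequantize(q4, s4)))
print(q8, q4, err8, err4)
```

The 4-bit codes have only 15 usable levels, so the worst-case rounding error is much larger than at 8 bits; the methods in the table exist to reduce the accuracy cost of that coarser grid.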
GPTQ Quantization
GPTQ is a post-training method that uses approximate second-order (Hessian) information to quantize weights one layer at a time with minimal accuracy loss:
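GPTQ Quantization Objective
$$
\arg\min_{\hat{W}} \left\lVert W X - \hat{W} X \right\rVert_2^2
$$
Here,
- $W$ = original full-precision weight matrix of a layer
- $\hat{W}$ = quantized weight matrix
- $X$ = layer inputs collected from calibration data
- $\lVert \cdot \rVert_2^2$ = squared error between the original and quantized layer outputs
- $H = 2 X X^{\top}$ = Hessian of this objective, the source of the second-order information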
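GPTQ's layer-wise objective is usually stated as minimizing the squared difference between a layer's outputs before and after quantization. The sketch below only evaluates that error for naive round-to-nearest weights on a tiny hand-made layer; GPTQ itself goes further and uses Hessian information to adjust the remaining weights as it quantizes. The matrices, the grid step, and the helper names are all hypothetical.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def round_to_grid(w, step):
    """Round every weight to the nearest multiple of `step`."""
    return [[round(v / step) * step for v in row] for row in w]


def layer_output_error(w, w_hat, x):
    """Squared error between the outputs W X and W_hat X."""
    y, y_hat = matmul(w, x), matmul(w_hat, x)
    return sum((a - b) ** 2 for ra, rb in zip(y, y_hat) for a, b in zip(ra, rb))


W = [[0.4, -1.2], [0.9, 0.1]]            # toy 2x2 weight matrix
X = [[1.0, 0.0, 2.0], [0.5, 1.0, -1.0]]  # toy calibration inputs, 3 samples
W_hat = round_to_grid(W, step=0.5)
err = layer_output_error(W, W_hat, X)
print(W_hat, err)
```

Measuring error on the layer outputs rather than on the weights themselves is what lets calibration data influence which rounding choices matter.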
AWQ (Activation-Aware Weight Quantization)
AWQ identifies salient weight channels based on activation magnitudes and protects them by rescaling those channels before quantization, rather than keeping them in higher precision:
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

def quantize_with_awq(model_path: str, output_path: str):
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # 4-bit weights, quantized in groups of 128 with a zero point per group
    quant_config = {
        "zero_point": True,
        "q_group_size": 128,
        "w_bit": 4,
        "version": "GEMM"
    }
    # calib_data selects the calibration set; "pileval" is AutoAWQ's default
    model.quantize(
        tokenizer,
        quant_config=quant_config,
        calib_data="pileval"
    )
    model.save_quantized(output_path)
    tokenizer.save_pretrained(output_path)
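The saliency idea behind AWQ can be illustrated without the library. The sketch below ranks input channels by mean absolute activation over a few calibration samples and returns the top fraction. It covers only the selection step, not AWQ's scaling search, and the function name, the sample data, and the `top_fraction` parameter are hypothetical.

```python
def salient_channels(activations, top_fraction=0.01):
    """Return indices of the input channels with the largest mean |activation|.

    activations: list of samples, each a list with one value per input channel.
    """
    n_channels = len(activations[0])
    mean_abs = [sum(abs(sample[c]) for sample in activations) / len(activations)
                for c in range(n_channels)]
    k = max(1, int(n_channels * top_fraction))
    ranked = sorted(range(n_channels), key=lambda c: mean_abs[c], reverse=True)
    return sorted(ranked[:k])


# Two toy calibration samples over four channels; channel 1 carries large activations.
calib = [[0.1, 5.0, -0.2, 0.3],
         [-0.2, -4.0, 0.1, 0.2]]
print(salient_channels(calib, top_fraction=0.25))
```

Weights that multiply large activations contribute most to the layer output, so quantization error on those channels is the most costly.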
Speculative Decoding
Speculative decoding uses a small, fast "draft" model to propose candidate tokens, which the larger target model verifies in a single parallel forward pass. Draft tokens that pass verification are kept in bulk, so the target model runs fewer forward passes than the number of tokens it generates.
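Speculative Decoding Acceptance
$$
P(\text{accept } x) = \min\left(1, \frac{p(x)}{q(x)}\right)
$$
Here,
- $x$ = token proposed by the draft model
- $p(x)$ = probability of $x$ under the target model
- $q(x)$ = probability of $x$ under the draft model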
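The acceptance test used in speculative sampling compares the target and draft probabilities of the proposed token. The sketch below shows that rule, together with the residual distribution that is commonly used to resample after a rejection so that the output still follows the target distribution. The probability vectors and function names are hypothetical.

```python
def acceptance_probability(p_target: float, q_draft: float) -> float:
    """Accept a draft token with probability min(1, p / q)."""
    if q_draft <= 0:
        raise ValueError("a sampled draft token must have q > 0")
    return min(1.0, p_target / q_draft)


def residual_distribution(p, q):
    """Distribution to resample from after a rejection: normalize max(0, p - q)."""
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total == 0:  # p and q are identical, so a rejection cannot occur
        return list(p)
    return [r / total for r in residual]


p = [0.5, 0.3, 0.2]  # target model distribution over a 3-token vocabulary
q = [0.2, 0.6, 0.2]  # draft model distribution
print(acceptance_probability(p[0], q[0]))  # draft under-weights token 0
print(acceptance_probability(p[1], q[1]))  # draft over-weights token 1
print(residual_distribution(p, q))
```

Tokens the draft model under-weights are always accepted, while over-weighted tokens are rejected part of the time and replaced by a sample from the residual.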
import torch

class SpeculativeDecoder:
    def __init__(self, target_model, draft_model, tokenizer, gamma: int = 5):
        self.target = target_model
        self.draft = draft_model
        self.tokenizer = tokenizer
        self.gamma = gamma  # max draft tokens per step

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 100):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        generated = input_ids.clone()
        while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
            n_prefix = generated.shape[1]
            # Draft phase: sample gamma tokens autoregressively from the draft model
            draft_tokens = []
            draft_probs = []
            x = generated.clone()
            for _ in range(self.gamma):
                logits = self.draft(x).logits[:, -1, :]
                probs = torch.softmax(logits, dim=-1)
                token = torch.multinomial(probs, 1)
                draft_tokens.append(token)
                draft_probs.append(probs)
                x = torch.cat([x, token], dim=-1)
            # Verification phase: one target forward pass over prefix + draft tokens
            target_probs = torch.softmax(self.target(x).logits, dim=-1)
            # Accept or reject each draft token in order.
            # The logits at position n_prefix - 1 + i predict draft token i.
            n_accepted = 0
            next_dist = target_probs[:, -1, :]  # used if every draft token is accepted
            for i in range(self.gamma):
                p = target_probs[:, n_prefix - 1 + i, :]
                q = draft_probs[i]
                tok = draft_tokens[i][0, 0]
                accept_prob = torch.clamp(p[0, tok] / q[0, tok], max=1.0)
                if torch.rand(1, device=accept_prob.device) < accept_prob:
                    n_accepted += 1
                else:
                    # On rejection, resample from the residual max(0, p - q)
                    residual = torch.clamp(p - q, min=0.0)
                    next_dist = residual / residual.sum(dim=-1, keepdim=True)
                    break
            # Append the accepted draft tokens plus one token from the target
            new_token = torch.multinomial(next_dist, 1)
            generated = torch.cat(
                [generated] + draft_tokens[:n_accepted] + [new_token], dim=-1
            )
        # Trim any overshoot beyond max_new_tokens
        generated = generated[:, : input_ids.shape[1] + max_new_tokens]
        return self.tokenizer.decode(generated[0])
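How much speculative decoding helps depends on how often draft tokens are accepted. Under the simplifying assumption that each draft token is accepted independently with the same probability `alpha`, the expected number of tokens produced per target forward pass is (1 - alpha^(gamma+1)) / (1 - alpha). The helper below evaluates that expression; its name is hypothetical, and real acceptance rates vary by position and prompt.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens per target forward pass with gamma draft tokens.

    Assumes each draft token is accepted independently with probability alpha.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(gamma + 1)  # all drafts accepted, plus the bonus token
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


for alpha in (0.5, 0.8, 0.95):
    print(alpha, round(expected_tokens_per_step(alpha, gamma=5), 3))
```

The net speedup is lower than this figure, because each step also pays for `gamma` draft-model forward passes.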
Continuous Batching
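Continuous batching (also called in-flight batching) schedules requests at the granularity of a single decode step: new requests join the batch and finished requests leave it between steps, instead of waiting for an entire static batch to complete.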
[Equation: Continuous Batching Throughput]
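from collections import deque
import torch

class ContinuousBatchScheduler:
    def __init__(self, model, max_batch_size: int = 32, max_tokens: int = 4096,
                 pad_token_id: int = 0):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_tokens = max_tokens
        self.pad_token_id = pad_token_id
        self.pending_requests = deque()
        self.active_requests = []

    def add_request(self, request):
        self.pending_requests.append(request)

    def schedule_step(self):
        # Admit pending requests while batch slots and the token budget allow.
        # Peek before popping so a request that does not fit is not dropped.
        while (len(self.active_requests) < self.max_batch_size
               and self.pending_requests
               and self._can_add(self.pending_requests[0])):
            self.active_requests.append(self.pending_requests.popleft())
        # Run one decode step for all active requests
        if self.active_requests:
            self._decode_step()
        # Remove completed requests, freeing their slots for the next step
        self.active_requests = [r for r in self.active_requests if not r.is_done]

    def _can_add(self, request):
        total_tokens = sum(r.current_length for r in self.active_requests)
        return total_tokens + request.current_length <= self.max_tokens

    def _prepare_batch(self):
        # Assumes each request exposes `token_ids` (a list of ints). Left-pad so
        # that position -1 is the last real token of every sequence.
        max_len = max(len(r.token_ids) for r in self.active_requests)
        shape = (len(self.active_requests), max_len)
        input_ids = torch.full(shape, self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)
        for i, r in enumerate(self.active_requests):
            n = len(r.token_ids)
            input_ids[i, max_len - n:] = torch.tensor(r.token_ids, dtype=torch.long)
            attention_mask[i, max_len - n:] = 1
        return input_ids, attention_mask

    def _decode_step(self):
        # Simplified: re-runs the full sequences each step instead of using a KV cache
        input_ids, attention_mask = self._prepare_batch()
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask=attention_mask).logits
        for i, request in enumerate(self.active_requests):
            next_token = torch.argmax(logits[i, -1, :]).item()
            request.append_token(next_token)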
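The benefit of scheduling at step granularity can be seen in a toy simulation that counts decode steps only. In the sketch below each request is represented just by the number of tokens it still has to generate. Static batching holds every slot until the longest request in the batch finishes, while continuous batching refills a slot as soon as it frees up. The function names and request lengths are hypothetical, and prefill cost is ignored.

```python
def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest request finishes."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        total += max(lengths[i:i + batch_size])
    return total


def continuous_batching_steps(lengths, batch_size):
    """Free slots are refilled from the queue before every decode step."""
    pending = list(lengths)
    active = []  # tokens remaining for each running request
    steps = 0
    while pending or active:
        while pending and len(active) < batch_size:
            active.append(pending.pop(0))
        steps += 1  # one decode step emits one token per active request
        active = [r - 1 for r in active if r - 1 > 0]
    return steps


lengths = [2, 8, 3, 5]
print(static_batching_steps(lengths, batch_size=2))
print(continuous_batching_steps(lengths, batch_size=2))
```

With these lengths the continuous schedule finishes in fewer steps, because the slot freed by the two-token request is reused immediately rather than sitting idle behind the eight-token one.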
Deployment Frameworks
vLLM
vLLM pioneered PagedAttention for efficient KV cache management:
PagedAttention stores the KV cache in non-contiguous, fixed-size memory blocks (like virtual memory pages). The vLLM authors report that this cuts KV cache memory waste from 60-80% in earlier systems to under 4%, which enables serving more concurrent requests.
from vllm import LLM, SamplingParams

# Initialize vLLM engine
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    max_model_len=4096
)

# Batch inference
prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to sort a list.",
    "What are the benefits of exercise?"
]
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=256
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
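The paging idea behind PagedAttention can be sketched with a small block allocator. This is a conceptual toy, not vLLM's implementation: the class name, the block size of 16 tokens, and the method names are all assumptions made for illustration. Each sequence owns a block table that maps its logical blocks to whichever physical blocks happen to be free, so memory is only wasted in the last, partially filled block.

```python
BLOCK_SIZE = 16  # tokens per KV block (hypothetical value)


class BlockAllocator:
    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))  # physical block ids
        self.block_tables = {}               # seq_id -> list of physical block ids
        self.lengths = {}                    # seq_id -> tokens stored

    def append_token(self, seq_id: str) -> None:
        """Reserve room for one more token, allocating a block only when needed."""
        length = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if length % BLOCK_SIZE == 0:  # no block yet, or the last block is full
            if not self.free:
                raise MemoryError("no free KV blocks")
            table.append(self.free.pop())
        self.lengths[seq_id] = length + 1

    def release(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free list."""
        self.free.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

    def wasted_slots(self, seq_id: str) -> int:
        """Unused token slots, which can only occur in the last block."""
        return len(self.block_tables[seq_id]) * BLOCK_SIZE - self.lengths[seq_id]


alloc = BlockAllocator(num_blocks=8)
for _ in range(20):
    alloc.append_token("a")
for _ in range(5):
    alloc.append_token("b")
print(alloc.block_tables, alloc.wasted_slots("a"), alloc.wasted_slots("b"))
```

Because blocks are allocated on demand, a sequence never reserves memory for its maximum possible length up front, and a finished sequence's blocks can be reused by any other sequence.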
TensorRT-LLM
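TensorRT-LLM provides NVIDIA-optimized inference. The outline below is schematic: the engine-building API differs between TensorRT-LLM releases, so check the documentation for the installed version: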
import tensorrt_llm
# Build engine
builder = tensorrt_llm.Builder()
network = builder.create_network()
# ... build network architecture ...
# Optimize for inference
config = builder.create_builder_config()
config.max_batch_size = 32
config.max_input_len = 2048
config.max_output_len = 512
engine = builder.build_serialized_network(network, config)
Latency vs Throughput Tradeoffs
[Equation: Cost per Token]
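A common back-of-the-envelope estimate of serving cost, assumed here rather than taken from this section, divides the hourly hardware price by the number of tokens generated in an hour. The function name, the price, and the throughput figure below are hypothetical.

```python
def cost_per_million_tokens(hourly_cost_usd: float, tokens_per_s: float) -> float:
    """Estimated USD per one million generated tokens at a sustained throughput."""
    if tokens_per_s <= 0:
        raise ValueError("tokens_per_s must be positive")
    tokens_per_hour = tokens_per_s * 3600
    return hourly_cost_usd / tokens_per_hour * 1_000_000


# Hypothetical: a GPU rented at 2 USD/hour sustaining 1000 tokens/s.
cost = cost_per_million_tokens(hourly_cost_usd=2.0, tokens_per_s=1000.0)
print(f"{cost:.4f} USD per million tokens")
```

Under this estimate, any optimization that doubles sustained throughput on the same hardware halves the cost per token, which is why throughput-oriented techniques dominate cost discussions.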
| Optimization | Latency Impact | Throughput Impact | Use Case |
|---|---|---|---|
| KV Cache | High | Low | Always use |
| Quantization (INT4) | Low | High | Memory-limited |
| Speculative Decoding | Medium | Low | High-latency apps |
| Continuous Batching | Low | High | Multi-user serving |
| Tensor Parallelism | Medium | Medium | Large models |
For production deployments, start with quantization (INT4/GPTQ) for immediate memory and throughput gains, then add continuous batching for multi-user scenarios. Speculative decoding is most beneficial for latency-sensitive applications with single-user serving.
Practical Deployment Example
from vllm import LLM, SamplingParams
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI()

# Initialize optimized model
llm = LLM(
    model="TheBloke/Llama-2-7B-Chat-GPTQ",
    quantization="gptq",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.85,
    max_model_len=2048,
    # Disables CUDA graphs: saves memory and startup time,
    # but decode latency is usually higher than with graphs enabled
    enforce_eager=True
)

class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 256
    temperature: float = 0.7

@app.post("/generate")
async def generate(request: GenerationRequest):
    params = SamplingParams(
        temperature=request.temperature,
        top_p=0.95,
        max_tokens=request.max_tokens
    )
    # llm.generate is a blocking call, so this handler serves one request at a time
    outputs = llm.generate([request.prompt], params)
    return {
        "text": outputs[0].outputs[0].text,
        "tokens": len(outputs[0].outputs[0].token_ids),
        "usage": {
            "prompt_tokens": len(outputs[0].prompt_token_ids),
            "completion_tokens": len(outputs[0].outputs[0].token_ids)
        }
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
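A client for the `/generate` endpoint can be written with the standard library alone. The sketch below assumes the server above is running on `localhost:8000`; the helper names `build_generate_request` and `call_generate` are hypothetical. Building the request is kept separate from sending it so the payload can be inspected without a live server.

```python
import json
import urllib.request


def build_generate_request(prompt: str, max_tokens: int = 256, temperature: float = 0.7,
                           base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build a POST request whose JSON body matches the GenerationRequest fields."""
    payload = {"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature}
    return urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call_generate(prompt: str, **kwargs) -> dict:
    """Send the request and return the parsed JSON response (needs a running server)."""
    with urllib.request.urlopen(build_generate_request(prompt, **kwargs)) as resp:
        return json.loads(resp.read().decode("utf-8"))


req = build_generate_request("Explain KV caching in one sentence.", max_tokens=64)
print(req.get_method(), req.full_url, json.loads(req.data))
```

Calling `call_generate("...")` against a running server returns the dictionary produced by the handler, including the `usage` token counts.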
Summary
- KV cache eliminates redundant key-value computation, requiring 2 × L × n_heads × d_head × s × b × bytes of memory for a batch of b sequences
- GPTQ and AWQ provide 2-3x speedup with INT4 quantization and minimal quality loss
- Speculative decoding produces several tokens per target-model forward pass via draft-then-verify
- Continuous batching maximizes throughput by dynamically scheduling requests
- vLLM's PagedAttention reduces KV cache memory waste from fragmentation to a few percent
- Latency vs throughput tradeoffs depend on use case; optimize accordingly
Practice Exercises
- KV Cache Analysis: Calculate KV cache memory for LLaMA-2-7B, LLaMA-2-13B, and LLaMA-2-70B at sequence length 4096. What batch sizes are feasible on a 24GB GPU?
- Quantization Benchmark: Compare inference speed and quality between FP16, INT8, and INT4 (GPTQ) for a 7B model. Measure perplexity on WikiText-2 and tokens/second.
- Speculative Decoding: Implement speculative decoding with a 7B target model and 1.5B draft model. Measure acceptance rate and speedup.
- Throughput Optimization: Deploy a model with vLLM and measure throughput at different batch sizes. Find the optimal configuration for your hardware.
- Latency Optimization: Optimize a model for minimum first-token latency using quantization, KV cache, and CUDA graphs. Measure the impact of each optimization.