LLM Inference Optimization

Efficient inference is critical for deploying language models in production. The autoregressive nature of LLMs (generating one token at a time) creates unique optimization challenges. This tutorial covers the core techniques for maximizing throughput and minimizing latency.

The Inference Challenge

LLM inference has two distinct phases:

  1. Prefill phase: Process the entire input prompt in parallel (compute-bound)
  2. Decode phase: Generate tokens one at a time (memory-bandwidth-bound)

Throughput is the number of tokens generated per second, measured as total output tokens divided by wall-clock time. During the decode phase it is limited by memory bandwidth, because every step must read the model weights and the KV cache from GPU memory.
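A rough back-of-the-envelope sketch shows why decode is memory-bandwidth-bound: at batch size 1, each generated token requires streaming all model weights from GPU memory once, so bandwidth divided by model size gives an upper bound on tokens per second. The function name and the 1000 GB/s bandwidth figure are illustrative assumptions, not measurements of any particular GPU.

```python
def decode_tokens_per_second_bound(
    n_params: float,
    bytes_per_param: float,
    mem_bandwidth_bytes_per_s: float,
) -> float:
    """Upper bound on single-sequence decode speed.

    Assumes each decode step reads every weight from memory exactly once and
    ignores KV cache reads, kernel overheads, and compute time.
    """
    model_bytes = n_params * bytes_per_param
    return mem_bandwidth_bytes_per_s / model_bytes


# Hypothetical setup: 7B parameters in float16 on a GPU with 1000 GB/s bandwidth.
bound = decode_tokens_per_second_bound(7e9, 2, 1000e9)
print(f"at most ~{bound:.0f} tokens/s per sequence")
```

Under these assumptions, halving the bytes per parameter (for example through quantization) doubles the bound, which is why lower precision helps decode speed so directly.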

KV Cache

The KV cache is the most fundamental optimization for autoregressive generation. Instead of recomputing attention for all previous tokens at each step, we cache key-value pairs.
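The idea can be illustrated with a toy single-head attention in plain Python. This is a minimal sketch, not how a framework implements it: the projection matrices and token embeddings are made-up values, and the helper names attend and project are hypothetical. It shows that appending each new token's key and value to a cache yields the same outputs as re-projecting the whole prefix at every step.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def project(x, matrix):
    """Multiply row vector x by a matrix given as a list of rows."""
    return [sum(x[i] * matrix[i][j] for i in range(len(x))) for j in range(len(matrix[0]))]


# Toy projection matrices (hypothetical values) and three token embeddings.
W_Q = [[1.0, 0.0], [0.5, 1.0]]
W_K = [[0.2, 0.7], [1.0, -0.3]]
W_V = [[1.0, 2.0], [0.0, 1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Without a cache: every step re-projects all previous tokens.
no_cache_outputs = []
for t in range(len(tokens)):
    keys = [project(x, W_K) for x in tokens[: t + 1]]
    values = [project(x, W_V) for x in tokens[: t + 1]]
    no_cache_outputs.append(attend(project(tokens[t], W_Q), keys, values))

# With a cache: each step projects only the new token and appends it.
k_cache, v_cache, cached_outputs = [], [], []
for x in tokens:
    k_cache.append(project(x, W_K))
    v_cache.append(project(x, W_V))
    cached_outputs.append(attend(project(x, W_Q), k_cache, v_cache))
```

The uncached loop performs a number of key/value projections that grows quadratically with sequence length, while the cached loop performs one per token at the price of keeping k_cache and v_cache in memory.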

KV Cache Memory

$$\text{KV Cache Size} = 2 \times L \times n_{\text{heads}} \times d_{\text{head}} \times s \times b \times \text{bytes\_per\_param}$$

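Here, the leading factor of 2 accounts for storing both keys and values, and

  • L = number of transformer layers
  • n_heads = number of attention heads
  • d_head = dimension of each attention head
  • s = sequence length in tokens
  • b = batch size
  • bytes_per_param = bytes per stored value (2 for float16)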

For a 7B parameter model with 32 layers, 32 heads, 128 head dimension, and sequence length 2048, the KV cache requires approximately 2 × 32 × 32 × 128 × 2048 × 2 bytes ≈ 1 GB per sequence in float16.
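The arithmetic above can be checked with a small helper. This is a minimal sketch; the function name kv_cache_bytes and the batch size of 8 in the second call are illustrative assumptions.

```python
def kv_cache_bytes(
    n_layers: int,
    n_heads: int,
    d_head: int,
    seq_len: int,
    batch_size: int = 1,
    bytes_per_param: int = 2,
) -> int:
    """KV cache size in bytes; the leading 2 counts keys and values."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch_size * bytes_per_param


# The 7B example from the text: 32 layers, 32 heads, head dimension 128, float16.
size = kv_cache_bytes(n_layers=32, n_heads=32, d_head=128, seq_len=2048)
print(f"{size / 2**30:.2f} GiB per sequence")

# Hypothetical batch of 8 sequences of the same length.
batch_size_8 = kv_cache_bytes(32, 32, 128, 2048, batch_size=8)
print(f"{batch_size_8 / 2**30:.2f} GiB for a batch of 8")
```

Because the size scales linearly with both sequence length and batch size, the KV cache, rather than the weights, often determines how many requests fit on a GPU.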

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class KVCacheModel:
    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.kv_cache = None
    
    def prefill(self, prompt: str):
        """Process entire prompt and cache KV pairs."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=True)
            self.kv_cache = outputs.past_key_values
        
        return outputs.logits[:, -1, :]
    
    def decode_step(self, input_ids: torch.Tensor):
        """Generate next token using cached KV pairs."""
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                past_key_values=self.kv_cache,
                use_cache=True
            )
            self.kv_cache = outputs.past_key_values
        
        next_token_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        return next_token
    
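    def generate(self, prompt: str, max_new_tokens: int = 100):
        """Full generation with KV cache (greedy decoding)."""
        self.kv_cache = None
        logits = self.prefill(prompt)  # shape: (batch, vocab_size)
        
        # prefill() already returns last-position logits, so take the argmax
        # over the vocabulary axis and keep shape (batch, 1) for decode_step()
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        generated = [next_token.item()]
        
        for _ in range(max_new_tokens - 1):
            if generated[-1] == self.tokenizer.eos_token_id:
                break  # stop once the model emits end-of-sequence
            next_token = self.decode_step(next_token)
            generated.append(next_token.item())
        
        return self.tokenizer.decode(generated)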

Quantization

Quantization reduces model size and accelerates inference by using lower-precision representations.

Quantization Methods Comparison

Method        Type            Bits   Speedup    Quality Impact
FP16          Post-training   16     1x         None
INT8          Post-training   8      1.5-2x     Minimal
INT4 (GPTQ)   Post-training   4      2-3x       Small
AWQ           Post-training   4      2-3x       Small
GGUF          Runtime         2-8    Variable   Variable
QLoRA         Training        4      N/A        Minimal
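To make "lower-precision representation" concrete, here is a minimal sketch of symmetric round-to-nearest quantization with a single scale per tensor. It is a simplified illustration rather than any of the specific methods listed; the function names and weight values are made up.

```python
def quantize_symmetric(weights, bits):
    """Map floats to signed integer codes in [-(2^(bits-1) - 1), 2^(bits-1) - 1]."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(w) for w in weights) / qmax  # largest |weight| maps to qmax
    codes = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float weights from integer codes."""
    return [c * scale for c in codes]


weights = [0.42, -1.27, 0.05, 0.88, -0.31]  # hypothetical weight values
for bits in (8, 4):
    codes, scale = quantize_symmetric(weights, bits)
    restored = dequantize(codes, scale)
    max_err = max(abs(w - r) for w, r in zip(weights, restored))
    print(f"INT{bits}: codes={codes}, max abs error={max_err:.4f}")
```

Fewer bits mean a coarser grid and a larger rounding error, which is the quality cost that methods such as GPTQ and AWQ try to reduce.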

GPTQ Quantization

GPTQ (GPT Quantization) quantizes weights one layer at a time, using approximate second-order (Hessian) information to minimize the error introduced in each layer's output:

GPTQ Quantization Objective

$$\min_{\hat{W}} \| WX - \hat{W}X \|_2^2 \quad \text{s.t.} \quad \hat{W} \in \{-2^{b-1}, \ldots, 2^{b-1}-1\}^{m \times n}$$

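Here,

  • W = original full-precision weight matrix of the layer
  • \hat{W} = quantized weight matrix
  • X = layer inputs collected from a calibration set
  • b = number of bits per weight
  • m × n = shape of the weight matrix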
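The objective can be evaluated directly on a toy layer. The sketch below is not GPTQ itself: it uses plain round-to-nearest onto a scaled integer grid as a baseline and only measures the output error that GPTQ aims to minimize. The matrices W and X, the scale choice, and all function names are illustrative assumptions.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def round_to_grid(w, bits, scale):
    """Round each weight to the nearest point of the scaled signed b-bit grid."""
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return [[scale * max(lo, min(hi, round(v / scale))) for v in row] for row in w]


def layer_output_error(w, w_hat, x):
    """Squared norm of WX - W_hat X, the quantity the objective minimizes."""
    y, y_hat = matmul(w, x), matmul(w_hat, x)
    return sum((a - b) ** 2 for ra, rb in zip(y, y_hat) for a, b in zip(ra, rb))


W = [[0.9, -0.4, 0.1], [0.3, 0.7, -0.8]]   # hypothetical 2 x 3 weight matrix
X = [[1.0, 0.5], [0.2, -1.0], [2.0, 0.3]]  # hypothetical calibration inputs (3 x 2)

errors = {}
for bits in (8, 4, 2):
    scale = 0.9 / (2 ** (bits - 1) - 1)  # largest |weight| maps to the top level
    errors[bits] = layer_output_error(W, round_to_grid(W, bits, scale), X)
    print(f"{bits}-bit round-to-nearest: output error = {errors[bits]:.6f}")
```

GPTQ improves on this baseline by adjusting the not-yet-quantized weights to compensate for the error already introduced, so the same bit width gives a lower output error.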

AWQ (Activation-Aware Weight Quantization)

AWQ identifies salient weight channels based on activation magnitudes and preserves them during quantization:

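from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

def quantize_with_awq(model_path: str, output_path: str):
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    
    quant_config = {
        "zero_point": True,    # asymmetric quantization with a zero point
        "q_group_size": 128,   # weights per group sharing one scale
        "w_bit": 4,            # 4-bit weights
        "version": "GEMM"      # kernel variant
    }
    
    # Calibration text is used to measure activation magnitudes;
    # "pileval" is AutoAWQ's default calibration set
    model.quantize(
        tokenizer,
        quant_config=quant_config,
        calib_data="pileval"
    )
    
    model.save_quantized(output_path)
    tokenizer.save_pretrained(output_path)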
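The intuition behind protecting salient channels can be sketched on a single group of weights. This toy is not the AWQ algorithm: AWQ searches for per-channel scales, whereas the fixed factor s = 2.0, the weight values, and the activation magnitudes below are all made-up assumptions. It shows only that scaling a salient weight up before rounding, then scaling back, can shrink that weight's quantization error.

```python
def quantize_group(weights, bits=4):
    """Symmetric round-to-nearest quantization with one step size per group."""
    qmax = 2 ** (bits - 1) - 1
    step = max(abs(w) for w in weights) / qmax
    return [round(w / step) * step for w in weights]


weights = [0.05, 0.52, -0.93, 0.27]    # hypothetical weights, one per input channel
activation_mag = [9.0, 0.4, 0.6, 0.5]  # hypothetical mean |activation| per channel

# Treat the channel with the largest activation magnitude as salient.
salient = max(range(len(weights)), key=lambda i: activation_mag[i])
s = 2.0  # hypothetical scaling factor for the salient channel

# Plain quantization of the group.
plain = quantize_group(weights)

# Activation-aware variant: scale the salient weight up, quantize, scale back down
# (in a real layer the inverse scale is applied to the incoming activations).
scaled = list(weights)
scaled[salient] *= s
aware = quantize_group(scaled)
aware[salient] /= s

err_plain = abs(weights[salient] - plain[salient])
err_aware = abs(weights[salient] - aware[salient])
print(f"salient channel {salient}: error {err_plain:.4f} -> {err_aware:.4f}")
```

The improvement in this example depends on the group's step size staying unchanged after scaling; that does not hold for every choice of weights and s, which is one reason the real method searches for scales instead of fixing them.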

Speculative Decoding

Speculative decoding uses a smaller "draft" model to propose tokens that are verified in parallel by the target model.

Because the target model scores every draft token in a single forward pass, several tokens can be accepted per target-model call. The number of expensive target passes therefore grows more slowly than the number of generated tokens.
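How many tokens each target pass yields depends on how often draft tokens are accepted. Under the simplifying assumption that every draft token is accepted independently with the same probability alpha, the expected number of tokens per target pass with gamma draft tokens is (1 - alpha^(gamma+1)) / (1 - alpha). The function name and the alpha values below are illustrative.

```python
def expected_tokens_per_target_pass(alpha: float, gamma: int) -> float:
    """Expected tokens produced per target-model forward pass.

    Assumes each draft token is accepted independently with probability alpha.
    Counts the accepted draft tokens plus the one token the target model
    contributes on every pass (a correction or a bonus token).
    """
    if alpha >= 1.0:
        return float(gamma + 1)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


for alpha in (0.5, 0.7, 0.9):
    tokens = expected_tokens_per_target_pass(alpha, gamma=5)
    print(f"alpha={alpha}: {tokens:.2f} tokens per target pass")
```

The actual speedup is lower than this ratio because the draft model's own forward passes also take time.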

Speculative Decoding Acceptance

$$P_{\text{accept}}(x) = \min\left(1, \frac{P_{\text{target}}(x)}{P_{\text{draft}}(x)}\right)$$

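Here,

  • P_accept(x) = probability of accepting draft token x
  • P_target(x) = probability the target model assigns to x
  • P_draft(x) = probability the draft model assigns to x
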
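The acceptance rule can be tried on toy distributions. The sketch below assumes the standard rejection step, in which a rejected token is replaced by a sample from the normalized positive part of P_target - P_draft; the two distributions and all function names are made up for illustration.

```python
import random


def acceptance_probability(p_target, p_draft, token):
    """min(1, P_target(x) / P_draft(x)) for a drafted token x."""
    return min(1.0, p_target[token] / p_draft[token])


def residual_distribution(p_target, p_draft):
    """Distribution to resample from after a rejection: normalized max(0, p - q)."""
    diff = [max(0.0, p - q) for p, q in zip(p_target, p_draft)]
    total = sum(diff)
    return [d / total for d in diff]


def speculative_sample(p_target, p_draft, rng):
    """Draw one token via draft-then-verify."""
    vocab = range(len(p_draft))
    token = rng.choices(vocab, weights=p_draft)[0]  # draft proposes a token
    if rng.random() < acceptance_probability(p_target, p_draft, token):
        return token
    # Rejected: resample from the residual distribution instead.
    return rng.choices(vocab, weights=residual_distribution(p_target, p_draft))[0]


p_target = [0.5, 0.3, 0.2]  # hypothetical target distribution over 3 tokens
p_draft = [0.2, 0.5, 0.3]   # hypothetical draft distribution

rng = random.Random(0)
samples = [speculative_sample(p_target, p_draft, rng) for _ in range(20000)]
empirical = [samples.count(t) / len(samples) for t in range(3)]
print("empirical:", [round(p, 2) for p in empirical])
```

Although every proposal comes from the draft distribution, the empirical frequencies track the target distribution, which is why this scheme does not change the quality of the output.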
import torch

class SpeculativeDecoder:
    def __init__(self, target_model, draft_model, tokenizer, gamma: int = 5):
        self.target = target_model
        self.draft = draft_model
        self.tokenizer = tokenizer
        self.gamma = gamma  # max draft tokens per step
    
    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 100):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        generated = input_ids.clone()
        
        while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
            n = generated.shape[1]  # context length before drafting
            
            # Draft phase: sample gamma tokens with the draft model
            draft_tokens = []
            draft_probs = []
            x = generated.clone()
            for _ in range(self.gamma):
                logits = self.draft(x).logits[:, -1, :]
                probs = torch.softmax(logits, dim=-1)
                token = torch.multinomial(probs, 1)
                draft_tokens.append(token)
                draft_probs.append(probs)
                x = torch.cat([x, token], dim=-1)
            
            # Verification phase: one target pass over context + draft tokens.
            # Logits at position p give the distribution of the token at p + 1.
            target_probs = torch.softmax(self.target(x).logits, dim=-1)
            
            # Accept or reject each draft token in order
            n_accepted = 0
            next_token = None
            for i in range(self.gamma):
                p_target = target_probs[:, n - 1 + i, :]
                p_draft = draft_probs[i]
                tok = draft_tokens[i][0, 0]
                
                accept_prob = torch.clamp(p_target[0, tok] / p_draft[0, tok], max=1.0)
                if torch.rand(1).item() < accept_prob.item():
                    n_accepted += 1
                else:
                    # Rejected: resample from the residual max(0, p_target - p_draft)
                    residual = torch.clamp(p_target - p_draft, min=0.0)
                    residual = residual / residual.sum(dim=-1, keepdim=True)
                    next_token = torch.multinomial(residual, 1)
                    break
            
            if next_token is None:
                # All drafts accepted: sample a bonus token from the target
                next_token = torch.multinomial(target_probs[:, -1, :], 1)
            
            # Append accepted draft tokens plus the one target-side token
            new_tokens = draft_tokens[:n_accepted] + [next_token]
            generated = torch.cat([generated] + new_tokens, dim=-1)
        
        # Trim any overshoot past max_new_tokens before decoding
        generated = generated[:, : input_ids.shape[1] + max_new_tokens]
        return self.tokenizer.decode(generated[0])

Continuous Batching

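Continuous batching (also called in-flight batching) schedules at the granularity of a single decode step: finished requests leave the batch and queued requests join it immediately, instead of waiting for an entire static batch to complete.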
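A small simulation illustrates the difference. This is a minimal sketch that counts decode steps only and ignores prefill cost and memory limits; the output lengths and function names are made up for illustration.

```python
def static_batching_steps(output_lengths, batch_size):
    """Each batch runs until its longest request finishes."""
    steps = 0
    for start in range(0, len(output_lengths), batch_size):
        steps += max(output_lengths[start:start + batch_size])
    return steps


def continuous_batching_steps(output_lengths, batch_size):
    """A finished request frees its slot for the next pending one immediately."""
    pending = list(output_lengths)
    active = []  # remaining tokens for each in-flight request
    steps = 0
    while pending or active:
        # Fill free slots from the queue before every decode step.
        while pending and len(active) < batch_size:
            active.append(pending.pop(0))
        active = [remaining - 1 for remaining in active]  # one decode step
        active = [remaining for remaining in active if remaining > 0]
        steps += 1
    return steps


lengths = [2, 8, 3, 7]  # hypothetical output lengths (tokens) for four requests
total_tokens = sum(lengths)
static_steps = static_batching_steps(lengths, batch_size=2)
continuous_steps = continuous_batching_steps(lengths, batch_size=2)
print(f"static: {static_steps} steps, {total_tokens / static_steps:.2f} tokens/step")
print(f"continuous: {continuous_steps} steps, {total_tokens / continuous_steps:.2f} tokens/step")
```

With these lengths the static schedule idles a slot while the short request in each batch waits for the long one, whereas the continuous schedule refills the slot right away.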

Continuous Batching Throughput

$$\text{Throughput} = \frac{\sum_{i=1}^{B} L_i}{T_{\text{total}}}$$

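Here,

  • B = number of requests served
  • L_i = number of output tokens generated for request i
  • T_total = total wall-clock time
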
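import torch

class ContinuousBatchScheduler:
    def __init__(self, model, max_batch_size: int = 32, max_tokens: int = 4096):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_tokens = max_tokens  # token budget across all active requests
        self.pending_requests = []
        self.active_requests = []
    
    def add_request(self, request):
        if request.current_length > self.max_tokens:
            raise ValueError("request exceeds the scheduler's token budget")
        self.pending_requests.append(request)
    
    def schedule_step(self):
        # Fill batch from pending requests (FIFO); stop at the first request
        # that does not fit so it stays queued instead of being dropped
        while (len(self.active_requests) < self.max_batch_size and 
               self.pending_requests and
               self._can_add(self.pending_requests[0])):
            self.active_requests.append(self.pending_requests.pop(0))
        
        # Run one decode step for all active requests
        if self.active_requests:
            self._decode_step()
            
            # Remove completed requests so their slots free up immediately
            self.active_requests = [r for r in self.active_requests if not r.is_done]
    
    def _can_add(self, request):
        total_tokens = sum(r.current_length for r in self.active_requests)
        return total_tokens + request.current_length <= self.max_tokens
    
    def _prepare_batch(self):
        # Assumption: each request exposes `token_ids`, a list of
        # `current_length` ints. Left-pad so that position -1 holds the
        # newest token of every sequence.
        max_len = max(r.current_length for r in self.active_requests)
        input_ids = torch.zeros(len(self.active_requests), max_len, dtype=torch.long)
        attention_mask = torch.zeros_like(input_ids)
        for i, r in enumerate(self.active_requests):
            input_ids[i, max_len - r.current_length:] = torch.tensor(r.token_ids)
            attention_mask[i, max_len - r.current_length:] = 1
        return input_ids, attention_mask
    
    def _decode_step(self):
        # Batch all active requests together (recomputes full sequences for
        # simplicity; a real server would reuse each request's KV cache)
        input_ids, attention_mask = self._prepare_batch()
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask=attention_mask).logits
        
        for i, request in enumerate(self.active_requests):
            next_token = torch.argmax(logits[i, -1, :]).item()
            request.append_token(next_token)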

Deployment Frameworks

vLLM

vLLM pioneered PagedAttention for efficient KV cache management:

PagedAttention stores the KV cache in fixed-size, non-contiguous memory blocks (like virtual memory pages), cutting the memory wasted through fragmentation and over-reservation from 60-80% in earlier systems to under 4%. This enables serving more concurrent requests.

from vllm import LLM, SamplingParams

# Initialize vLLM engine
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    max_model_len=4096
)

# Batch inference
prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to sort a list.",
    "What are the benefits of exercise?"
]

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=256
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
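The block-based bookkeeping behind PagedAttention can be illustrated with a toy allocator. This is a conceptual sketch only, not vLLM's implementation: the class name, its methods, and the block size of 16 tokens are assumptions chosen for illustration.

```python
class PagedKVAllocator:
    """Toy block allocator: each sequence maps logical blocks to physical ones."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # sequence id -> list of physical block ids
        self.lengths = {}       # sequence id -> number of cached tokens

    def append_token(self, seq_id: str) -> None:
        """Reserve room for one more token, taking a new block only when needed."""
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # current block is full (or none yet)
            if not self.free_blocks:
                raise MemoryError("no free KV cache blocks")
            self.block_tables.setdefault(seq_id, []).append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1

    def free(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

    def wasted_slots(self, seq_id: str) -> int:
        """Unused token slots; only the last block can be partly empty."""
        return len(self.block_tables[seq_id]) * self.block_size - self.lengths[seq_id]


alloc = PagedKVAllocator(num_blocks=64, block_size=16)
for _ in range(100):
    alloc.append_token("request-a")
n_blocks = len(alloc.block_tables["request-a"])
print(n_blocks, "blocks,", alloc.wasted_slots("request-a"), "wasted slots")
```

Since memory is claimed one block at a time, a sequence never reserves space for tokens it has not generated, and the waste per sequence is bounded by one block.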

TensorRT-LLM

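TensorRT-LLM provides NVIDIA-optimized inference by compiling the model into a TensorRT engine ahead of time. The snippet below is a schematic outline of that build flow; exact class and method names vary between releases, so check the documentation for your installed version: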

import tensorrt_llm

# Build engine
builder = tensorrt_llm.Builder()
network = builder.create_network()
# ... build network architecture ...

# Optimize for inference
config = builder.create_builder_config()
config.max_batch_size = 32
config.max_input_len = 2048
config.max_output_len = 512

engine = builder.build_serialized_network(network, config)

Latency vs Throughput Tradeoffs

Cost per Token

$$\text{Cost}_{\text{token}} = \frac{\text{GPU\_cost\_per\_second} \times \text{latency}}{\text{tokens\_generated}}$$

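Here,

  • GPU_cost_per_second = price of the GPU per second of use
  • latency = wall-clock time of the generation in seconds
  • tokens_generated = number of tokens produced in that time
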
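A short calculation shows how batching changes the cost per token in this formula. The hourly price, latency, and token counts are hypothetical numbers chosen for illustration, and the function name is an assumption.

```python
def cost_per_token(gpu_cost_per_hour: float, latency_s: float, tokens_generated: int) -> float:
    """Cost of one generated token, following the formula above."""
    gpu_cost_per_second = gpu_cost_per_hour / 3600
    return gpu_cost_per_second * latency_s / tokens_generated


# Hypothetical numbers: a $2.00/hour GPU and 10 s of wall-clock time.
single = cost_per_token(2.00, latency_s=10, tokens_generated=500)    # one request
batched = cost_per_token(2.00, latency_s=10, tokens_generated=8000)  # 16 requests in one batch
print(f"single request: ${single * 1e6:.2f} per 1M tokens")
print(f"batch of 16:    ${batched * 1e6:.2f} per 1M tokens")
```

The GPU costs the same per second either way, so any technique that produces more tokens in the same wall-clock time lowers the cost per token proportionally.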
Optimization           Latency Impact   Throughput Impact   Use Case
KV Cache               ↓↓               ↑↑                  Always use
Quantization (INT4)    ↓                ↑↑↑                 Memory-limited
Speculative Decoding   ↓↓               ↑                   High-latency apps
Continuous Batching    ↓                ↑↑↑                 Multi-user serving
Tensor Parallelism     ↓↓               ↑↑                  Large models

For production deployments, start with quantization (INT4/GPTQ) for immediate memory and throughput gains, then add continuous batching for multi-user scenarios. Speculative decoding is most beneficial for latency-sensitive applications with single-user serving.

Practical Deployment Example

from vllm import LLM, SamplingParams
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI()

# Initialize optimized model
llm = LLM(
    model="TheBloke/Llama-2-7B-Chat-GPTQ",
    quantization="gptq",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.85,
    max_model_len=2048,
    enforce_eager=True  # Disable CUDA graphs: less memory and faster startup, but usually slower decoding
)

class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 256
    temperature: float = 0.7

@app.post("/generate")
async def generate(request: GenerationRequest):
    params = SamplingParams(
        temperature=request.temperature,
        top_p=0.95,
        max_tokens=request.max_tokens
    )
    
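    # llm.generate() is blocking, so this endpoint serves one request at a time;
    # for concurrent traffic, vLLM's OpenAI-compatible server batches requests
    outputs = llm.generate([request.prompt], params)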
    return {
        "text": outputs[0].outputs[0].text,
        "tokens": len(outputs[0].outputs[0].token_ids),
        "usage": {
            "prompt_tokens": len(outputs[0].prompt_token_ids),
            "completion_tokens": len(outputs[0].outputs[0].token_ids)
        }
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
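Once the server is running, the endpoint can be called from any HTTP client. Here is a minimal client sketch using only the standard library; the helper names and the base URL http://localhost:8000 are assumptions matching the example above.

```python
import json
import urllib.request


def build_generate_request(prompt: str, max_tokens: int = 256, temperature: float = 0.7,
                           base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build a POST request whose body matches the GenerationRequest schema."""
    payload = {"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature}
    return urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(prompt: str, **kwargs) -> dict:
    """Send the request and return the parsed JSON response (server must be running)."""
    with urllib.request.urlopen(build_generate_request(prompt, **kwargs)) as response:
        return json.loads(response.read())
```

With the server up, calling generate("Explain KV caching in one sentence.", max_tokens=64) returns a dictionary with the text, tokens, and usage fields produced by the endpoint.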

Summary

  • KV cache eliminates redundant attention computation, requiring 2 × L × n_heads × d_head × s × b × bytes_per_param of memory
  • GPTQ and AWQ provide 2-3x speedup with INT4 quantization and minimal quality loss
  • Speculative decoding yields several tokens per target-model forward pass via draft-then-verify
  • Continuous batching maximizes throughput by dynamically scheduling requests
  • vLLM's PagedAttention reduces KV cache memory fragmentation to near 0%
  • Latency vs throughput tradeoffs depend on use case; optimize accordingly

Practice Exercises

  1. KV Cache Analysis: Calculate KV cache memory for LLaMA-2-7B, LLaMA-2-13B, and LLaMA-2-70B at sequence length 4096. What batch sizes are feasible on a 24GB GPU?

  2. Quantization Benchmark: Compare inference speed and quality between FP16, INT8, and INT4 (GPTQ) for a 7B model. Measure perplexity on WikiText-2 and tokens/second.

  3. Speculative Decoding: Implement speculative decoding with a 7B target model and 1.5B draft model. Measure acceptance rate and speedup.

  4. Throughput Optimization: Deploy a model with vLLM and measure throughput at different batch sizes. Find the optimal configuration for your hardware.

  5. Latency Optimization: Optimize a model for minimum first-token latency using quantization, KV cache, and CUDA graphs. Measure the impact of each optimization.

