Speculative Decoding — Generating Multiple Tokens Per Step
Speculative decoding breaks the sequential bottleneck of autoregressive generation by using a small "draft" model to propose multiple tokens that are then verified by the large model in parallel.
- Draft Model — A small, fast model proposes token sequences
- Parallel Verification — The large model verifies all proposals simultaneously
- Lossless Speedup — Output distribution is identical to the large model alone
Why wait for one token when you can verify five?
Speculative Decoding
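Autoregressive LLM inference is inherently sequential: each token depends on all previous tokens. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) exploits the fact that scoring several candidate tokens in one forward pass of the large model takes about as long as generating a single token, because decoding is bound by memory bandwidth rather than compute. This yields a 2-3x speedup while leaving the output distribution unchanged.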
Definition: Speculative Decoding
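Speculative decoding uses a small, fast "draft" model to generate K candidate tokens, then the large "target" model scores all K tokens in a single parallel forward pass. Each draft token is accepted or rejected by a rejection-sampling test against the target model's distribution. At the first rejection, the remaining drafts are discarded, a correction token is sampled, and the next round starts from the corrected sequence.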
The Algorithm
Draft-then-Verify
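1. Draft model generates K tokens autoregressively: x_1, x_2, ..., x_K
2. Target model scores all K tokens in one parallel forward pass, producing its distributions for positions 1 through K+1
3. For each position i = 1 to K:
   a. Compute acceptance probability: min(1, p_target(x_i) / p_draft(x_i))
   b. Sample accept/reject from Bernoulli(acceptance_probability)
   c. If rejected, sample a correction token from the adjusted distribution and stop
4. Return the accepted tokens plus the correction token; if all K were accepted, append a bonus token sampled from the target distribution at position K+1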
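The verification loop can be sketched in plain Python over toy probability lists, with no models involved. This is a minimal sketch, assuming each distribution is a list of probabilities indexed by token id; `verify_drafts` is a hypothetical helper, not a library function. It also includes the bonus token that Leviathan et al. (2023) draw from the target when every draft is accepted.

```python
import random


def verify_drafts(draft_tokens, q_dists, p_dists, rng=None):
    """One verification round of speculative sampling over toy distributions.

    draft_tokens: K token ids proposed by the draft model.
    q_dists[i]:   draft distribution (list of probabilities) draft_tokens[i] was sampled from.
    p_dists:      K + 1 target distributions; p_dists[i] scores draft position i and
                  p_dists[K] is the target's distribution after all K drafts.
    Returns the tokens emitted this round (hypothetical helper for illustration).
    """
    if rng is None:
        rng = random.Random()
    emitted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_dists[i], q_dists[i]
        # Steps 3a-3b: keep the draft with probability min(1, p/q).
        # q[tok] > 0 because tok was drawn from q.
        if rng.random() < min(1.0, p[tok] / q[tok]):
            emitted.append(tok)
            continue
        # Step 3c: on rejection, draw a correction from max(0, p - q) and stop.
        # random.choices normalizes the weights, so no explicit division is needed.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
        emitted.append(rng.choices(range(len(p)), weights=residual)[0])
        return emitted
    # Every draft accepted: draw one bonus token from the target's next distribution
    last = p_dists[len(draft_tokens)]
    emitted.append(rng.choices(range(len(last)), weights=last)[0])
    return emitted
```

Passing a seeded `random.Random` makes a round reproducible. The loop stops at the first rejection because every later draft was conditioned on the rejected token, so it cannot be reused.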
Acceptance Probability
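$$\alpha_i = \min\left(1, \frac{p_{\text{target}}(x_i)}{p_{\text{draft}}(x_i)}\right)$$
Here,
- $\alpha_i$ = Acceptance probability for token x_i
- $p_{\text{target}}(x_i)$ = Target (large) model probability
- $p_{\text{draft}}(x_i)$ = Draft (small) model probability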
Correction Distribution
When a token is rejected, the correction token is sampled from:
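$$p'(x) = \frac{\max\left(0,\; p_{\text{target}}(x) - p_{\text{draft}}(x)\right)}{\sum_{x'} \max\left(0,\; p_{\text{target}}(x') - p_{\text{draft}}(x')\right)}$$
Here,
- $p'(x)$ = Distribution for correction token
- $p_{\text{target}}$, $p_{\text{draft}}$ = Target and draft distributions at the rejected position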
The key insight: speculative decoding with rejection sampling produces output tokens that are exactly distributed according to the target model. There is no quality loss — only speed gain.
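The lossless claim can be checked exactly on a toy vocabulary. The sketch below, a hypothetical helper `speculative_output_distribution` written with Python's `fractions` module, computes the probability that the accept/reject rule emits each token at one position, assuming target `p` and draft `q` are lists over the same vocabulary.

```python
from fractions import Fraction


def speculative_output_distribution(p, q):
    """Exact probability of each token being emitted at a single position
    by the accept/reject rule, given target p and draft q (same vocabulary)."""
    vocab = range(len(p))
    # The draft proposes x with probability q[x] and it is kept with
    # probability min(1, p[x]/q[x]), so the accepted mass for x is min(p[x], q[x]).
    accepted = [q[x] * min(1, p[x] / q[x]) if q[x] > 0 else Fraction(0) for x in vocab]
    reject_mass = 1 - sum(accepted)
    # On rejection, the correction token follows normalize(max(0, p - q))
    residual = [max(Fraction(0), p[x] - q[x]) for x in vocab]
    total = sum(residual)
    if total == 0:  # p == q, so rejection never happens
        return accepted
    return [accepted[x] + reject_mass * residual[x] / total for x in vocab]
```

The algebra behind the result: the accepted mass for x is min(p(x), q(x)), and the total rejection mass 1 - Σ min(p, q) equals Σ max(0, p - q), which is exactly the correction distribution's normalizer. Each token therefore ends up with min(p(x), q(x)) + max(0, p(x) - q(x)) = p(x).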
Expected Speedup
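Expected Tokens per Target Pass
$$\mathbb{E}[\text{tokens per target pass}] = \sum_{j=0}^{K} \alpha^{j} = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$
Here,
- $\alpha$ = Average acceptance rate (the expected per-token acceptance probability), assumed independent across positions
- $K$ = Number of draft tokens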
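The closed form follows from counting: draft j survives only if drafts 1 through j are all accepted, which happens with probability α^j under the independence assumption, and every round adds exactly one target-sampled token (correction or bonus). A small sketch comparing the formula with a Monte Carlo simulation; both function names are hypothetical.

```python
import random


def expected_tokens_per_pass(alpha, K):
    """Closed form: 1 + alpha + ... + alpha^K = (1 - alpha^(K+1)) / (1 - alpha)."""
    if alpha == 1.0:
        return K + 1.0
    return (1 - alpha ** (K + 1)) / (1 - alpha)


def simulate_tokens_per_pass(alpha, K, rounds=200_000, seed=0):
    """Monte Carlo check: accept drafts left to right with probability alpha each,
    stop at the first rejection, and count the one extra target token
    (correction or bonus) that every round produces."""
    rng = random.Random(seed)
    total = 0
    for _ in range(rounds):
        accepted = 0
        while accepted < K and rng.random() < alpha:
            accepted += 1
        total += accepted + 1
    return total / rounds
```

With α = 0.7 and K = 5 both give roughly 2.94 tokens per target pass.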
Speedup Calculation
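If the average acceptance rate is α = 0.7 and K = 5:
- Expected tokens per target forward pass = 1 + 0.7 + 0.49 + 0.343 + 0.2401 + 0.16807 ≈ 2.94
- That would be a 2.94x speedup only if draft steps were free; each round also runs 5 draft forward passes
- If one draft step costs 5% of a target pass, the estimate drops to 2.94 / (1 + 5 × 0.05) ≈ 2.35x; further system overhead puts typical measured speedups at 2-2.5x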
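Leviathan et al. (2023) fold the draft cost into the estimate: if one draft step costs c times a target forward pass, a round costs K·c + 1 target-pass equivalents, giving a walltime improvement of (1 - α^(K+1)) / ((1 - α)(K·c + 1)). A minimal sketch of that estimate plus a brute-force search over K; the function names are hypothetical, and the model ignores KV-cache and batching effects.

```python
def speculative_speedup(alpha, K, c):
    """Estimated walltime speedup over plain decoding: expected tokens per round
    divided by the cost of a round (K draft steps at c each, plus one target pass)."""
    tokens = K + 1.0 if alpha == 1.0 else (1 - alpha ** (K + 1)) / (1 - alpha)
    return tokens / (K * c + 1)


def best_k(alpha, c, k_max=16):
    """Draft length in 1..k_max that maximizes the estimated speedup."""
    return max(range(1, k_max + 1), key=lambda k: speculative_speedup(alpha, k, c))
```

For example, with α = 0.7 and c = 0.1 the search picks K = 4 (about 1.98x), while a draft costing half a target pass (c = 0.5) makes K = 1 optimal at only about 1.13x. Under this model the draft's cost matters as much as its acceptance rate.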
Implementation
import torch
import torch.nn.functional as F

@torch.no_grad()
def speculative_decode(draft_model, target_model, prompt, K=5, max_tokens=100):
    """Speculative decoding with rejection sampling (batch size 1, no KV cache).

    Assumes Hugging Face-style causal LMs that share a tokenizer and whose outputs
    expose `.logits` of shape (batch, seq_len, vocab); `prompt` has shape (1, prompt_len).
    Returns a (1, <= max_tokens) tensor of generated token ids.
    """
    input_ids = prompt.clone()
    prompt_len = prompt.shape[1]
    while input_ids.shape[1] - prompt_len < max_tokens:
        n = input_ids.shape[1]
        # Step 1: draft model proposes K tokens autoregressively
        draft_tokens, draft_probs = [], []
        draft_input = input_ids
        for _ in range(K):
            logits = draft_model(draft_input).logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)        # (1, vocab)
            token = torch.multinomial(probs, 1)      # (1, 1)
            draft_tokens.append(token)
            draft_probs.append(probs)
            draft_input = torch.cat([draft_input, token], dim=-1)
        # Step 2: a single target forward pass scores all K drafts.
        # Logits at position n - 1 + i give the target distribution for draft i;
        # the final position (i = K) gives the distribution for a bonus token.
        target_logits = target_model(draft_input).logits
        target_probs = F.softmax(target_logits[:, n - 1:, :], dim=-1)  # (1, K + 1, vocab)
        # Step 3: accept/reject drafts left to right
        new_tokens = []
        for i in range(K):
            p, q, tok = target_probs[:, i, :], draft_probs[i], draft_tokens[i]
            p_tok = p.gather(1, tok).item()
            q_tok = q.gather(1, tok).item()  # > 0, since tok was sampled from q
            if torch.rand(1).item() < min(1.0, p_tok / q_tok):
                new_tokens.append(tok)
            else:
                # Rejected: sample a correction from normalize(max(0, p - q)) and stop
                adjusted = torch.clamp(p - q, min=0.0)
                adjusted = adjusted / adjusted.sum()
                new_tokens.append(torch.multinomial(adjusted, 1))
                break
        else:
            # All K drafts accepted: sample one bonus token from the target
            new_tokens.append(torch.multinomial(target_probs[:, K, :], 1))
        # Extend with accepted drafts plus the correction/bonus token (never the rejected draft)
        input_ids = torch.cat([input_ids] + new_tokens, dim=-1)
    return input_ids[:, prompt_len:prompt_len + max_tokens]
Draft Model Selection
| Draft Model Type | Draft Size (% of Target) | Acceptance Rate | Speedup |
|---|---|---|---|
| Same architecture, fewer layers | 50% | 70-80% | 2.0-2.5x |
| Smaller vocabulary | 30% | 60-70% | 1.8-2.2x |
| n-gram model | <1% | 40-50% | 1.5-1.8x |
| Trained draft head | ~5% | 75-85% | 2.2-2.8x |
Advanced Variants
Medusa: Draft Head Approach
Definition: Medusa
Medusa (Cai et al., 2024) adds multiple "draft heads" to the target model itself. Each head predicts a different future position, eliminating the need for a separate draft model.
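import torch.nn as nn

class MedusaModel(nn.Module):
    """Simplified Medusa: extra prediction heads on top of a base causal LM.

    Assumes a Hugging Face-style model exposing `.config.hidden_size`,
    `.config.vocab_size`, `.logits`, and `output_hidden_states=True`.
    """
    def __init__(self, base_model, num_heads=3):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        # Head k (k = 1..num_heads) predicts the token k positions beyond the next one.
        # The Medusa paper puts a residual SiLU layer before each projection;
        # a single Linear layer is used here for brevity.
        self.draft_heads = nn.ModuleList([
            nn.Linear(hidden_size, base_model.config.vocab_size)
            for _ in range(num_heads)
        ])

    def forward(self, input_ids):
        base_output = self.base_model(input_ids, output_hidden_states=True)
        hidden = base_output.hidden_states[-1][:, -1, :]  # last-layer state at the final position
        # Main logits: the base model's own next-token prediction (position t+1)
        main_logits = base_output.logits[:, -1, :]
        # Draft logits for positions t+2, ..., t+1+num_heads
        draft_logits = [head(hidden) for head in self.draft_heads]
        return main_logits, draft_logits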
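At decode time, the head outputs become candidates that the base model checks in its next forward pass. Medusa itself verifies a tree of top-k candidates per head using tree attention; the sketch below shows only the simplest single-path, greedy version. It assumes each head's top-1 token (for example, an argmax over the `draft_logits` returned by `MedusaModel`) forms the candidate list and that `target_argmax` holds the base model's greedy predictions at those positions. `greedy_verify` is a hypothetical helper.

```python
def greedy_verify(candidates, target_argmax):
    """Single-path greedy verification (a simplification of Medusa's tree scheme).

    candidates[i]:    token proposed for the i-th future position.
    target_argmax[i]: base model's greedy choice for that same position, read off the
                      verification forward pass; it has len(candidates) + 1 entries.
    Returns the tokens to append: the matching prefix plus one token from the base model.
    """
    accepted = []
    for cand, tgt in zip(candidates, target_argmax):
        if cand != tgt:
            break
        accepted.append(cand)
    # The base model's own prediction at the first mismatch (or after the last
    # candidate) is always valid, so each step emits at least one token.
    accepted.append(target_argmax[len(accepted)])
    return accepted
```

Under greedy decoding this reproduces the base model's output exactly, since every emitted token is either confirmed by or taken from the base model's own argmax.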
EAGLE: Autoregressive Drafts
Definition: EAGLE
EAGLE (Li et al., 2024) uses autoregressive draft generation with feature-level speculation. It generates draft tokens by conditioning on the target model's hidden states, achieving higher acceptance rates than Medusa.
EAGLE achieves 2.5-3.5x speedup on code generation tasks, where token patterns are more predictable. The key advantage is that draft generation uses the target model's internal representations, leading to higher acceptance rates.
Practice Exercises
- Acceptance Rate Analysis: For a given prompt, measure the acceptance rate of speculative decoding with K=5 using a 7B draft and 70B target model. How does acceptance rate vary across domains (code vs. prose)?
- Draft Model Comparison: Compare the speedup of using a 3B draft model vs. a 1.5B draft model with a 70B target. What is the optimal draft size?
- Medusa Implementation: Implement a 3-head Medusa model. How does the acceptance rate of head 3 compare to head 1?
- Cost Analysis: Calculate the cost-per-token of speculative decoding vs. standard autoregressive decoding, accounting for draft model inference overhead.
Key Takeaways
Summary: Speculative Decoding
- Draft-then-verify enables generating multiple tokens per step
- Rejection sampling ensures output distribution matches target model exactly
- Typical speedup is 2-3x with 70-85% acceptance rate
- Medusa eliminates separate draft model by adding heads to target
- EAGLE uses autoregressive drafts for higher acceptance rates
- Best when a draft step costs a small fraction of a target forward pass, so drafting overhead stays well below the target passes it saves
- Not beneficial when draft model is too large or acceptance rate is too low
What to Learn Next
-> LLM Inference Optimization: Broader strategies for making LLM inference faster.
-> Flash Attention and Memory Efficiency: IO-aware attention algorithms that reduce memory.
-> KV Cache Optimization: Reducing memory usage of the key-value cache.
-> Continuous Batching for LLMs: Maximizing GPU utilization with dynamic batching.
-> Quantization Techniques Deep Dive: Reducing model size through quantization.
-> Model Parallelism and Tensor Parallelism: Splitting models across GPUs for inference.