Speculative Decoding

Speculative Decoding — Generating Multiple Tokens Per Step

Speculative decoding breaks the sequential bottleneck of autoregressive generation by using a small "draft" model to propose multiple tokens that are then verified by the large model in parallel.

  • Draft Model — A small, fast model proposes token sequences
  • Parallel Verification — The large model verifies all proposals simultaneously
  • Lossless Speedup — Output distribution is identical to the large model alone

Why wait for one token when you can verify five?

Speculative Decoding

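Autoregressive LLM inference is inherently sequential: each token depends on all previous tokens. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) exploits the fact that scoring K already-proposed tokens takes a single parallel forward pass of the large model, whereas generating them would take K sequential passes. Reported speedups are typically 2-3x, with an output distribution identical to that of the large model.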

Definition: Speculative Decoding

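Speculative decoding uses a small, fast "draft" model to generate K candidate tokens, which the large "target" model then scores in a single parallel forward pass. Each candidate is accepted with a probability based on the ratio of its target probability to its draft probability. At the first rejection, the target model supplies a replacement token, the remaining candidates are discarded, and drafting resumes from that point.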

The Algorithm

Draft-then-Verify

1. Draft model generates K tokens: x_1, x_2, ..., x_K
2. Target model processes all K tokens in parallel
3. For each position i = 1 to K:
   a. Compute acceptance probability: min(1, p_target(x_i) / p_draft(x_i))
   b. Sample accept/reject from Bernoulli(acceptance_probability)
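   c. If rejected, sample a correction token from the adjusted distribution, discard the remaining draft tokens, and stop
4. Return the accepted tokens plus one extra token: the correction token if a draft was rejected, or a bonus token sampled from the target model if all K drafts were accepted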

Acceptance Probability

A(x_i) = \min\left(1, \frac{p_{\text{target}}(x_i \mid x_{<i})}{p_{\text{draft}}(x_i \mid x_{<i})}\right)

Here,

  • A(x_i) = Acceptance probability for token x_i
  • p_{\text{target}} = Target (large) model probability
  • p_{\text{draft}} = Draft (small) model probability
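
As a small illustration of the acceptance rule, the sketch below evaluates A(x) for each token of a made-up three-token vocabulary. It also computes the average acceptance rate over tokens drawn from the draft distribution, which simplifies to the sum of min(p_target, p_draft) over the vocabulary. The probability values and function names are assumptions chosen for the example, not taken from any particular model.

```python
from typing import List, Sequence


def acceptance_prob(p_target: float, p_draft: float) -> float:
    """A(x) = min(1, p_target(x) / p_draft(x)) for a token the draft sampled."""
    if p_draft <= 0.0:
        raise ValueError("a sampled draft token must have p_draft > 0")
    return min(1.0, p_target / p_draft)


def expected_acceptance_rate(p_target: Sequence[float],
                             p_draft: Sequence[float]) -> float:
    """Average of A(x) over x ~ p_draft.

    sum_x p_draft(x) * min(1, p_target(x) / p_draft(x))
    simplifies to sum_x min(p_target(x), p_draft(x)).
    """
    return sum(min(pt, pd) for pt, pd in zip(p_target, p_draft))


# Hypothetical 3-token vocabulary, for illustration only.
p_target = [0.6, 0.3, 0.1]
p_draft = [0.3, 0.5, 0.2]

per_token: List[float] = [acceptance_prob(pt, pd)
                          for pt, pd in zip(p_target, p_draft)]
rate = expected_acceptance_rate(p_target, p_draft)

print([round(a, 3) for a in per_token])  # [1.0, 0.6, 0.5]
print(round(rate, 3))                    # 0.7
```

Tokens the draft under-weights (the first one here) are always accepted. Tokens it over-weights are accepted only part of the time, so the closer the two distributions are, the higher the average rate.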

Correction Distribution

When a token is rejected, the correction token is sampled from:

p_{\text{correction}}(x) = \frac{\max(0, p_{\text{target}}(x) - p_{\text{draft}}(x))}{\sum_{x'} \max(0, p_{\text{target}}(x') - p_{\text{draft}}(x'))}

Here,

  • p_{\text{correction}} = Distribution for correction token

The key insight: speculative decoding with rejection sampling produces output tokens that are exactly distributed according to the target model. There is no quality loss — only speed gain.
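
The exactness claim can be checked numerically for a single position. The sketch below, using the same kind of made-up toy distributions, computes the exact probability that each token is emitted: either the draft proposes it and it is accepted, or some draft token is rejected and it is drawn from the correction distribution. The function names and probability values are illustrative assumptions.

```python
from typing import List, Sequence


def correction_distribution(p_target: Sequence[float],
                            p_draft: Sequence[float]) -> List[float]:
    """Normalized max(0, p_target - p_draft), used after a rejection."""
    residual = [max(0.0, pt - pd) for pt, pd in zip(p_target, p_draft)]
    total = sum(residual)
    if total == 0.0:
        # Identical distributions: a rejection can never occur.
        return list(p_target)
    return [r / total for r in residual]


def output_distribution(p_target: Sequence[float],
                        p_draft: Sequence[float]) -> List[float]:
    """Exact marginal of the emitted token for one draft position."""
    accept = [min(1.0, pt / pd) if pd > 0 else 0.0
              for pt, pd in zip(p_target, p_draft)]
    # Probability that the sampled draft token is rejected.
    p_reject = 1.0 - sum(pd * a for pd, a in zip(p_draft, accept))
    p_corr = correction_distribution(p_target, p_draft)
    # Token x is emitted by accepting a draft of x, or by rejecting some
    # draft token and then sampling x from the correction distribution.
    return [pd * a + p_reject * pc
            for pd, a, pc in zip(p_draft, accept, p_corr)]


# Hypothetical distributions, for illustration only.
p_target = [0.6, 0.3, 0.1]
p_draft = [0.3, 0.5, 0.2]

emitted = output_distribution(p_target, p_draft)
print([round(p, 6) for p in emitted])  # [0.6, 0.3, 0.1]
```

The emitted distribution equals p_target even though the draft distribution is quite different; a worse draft only lowers the acceptance rate, not the output quality.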

Expected Speedup

Expected Accepted Tokens

E[\text{accepted}] = \sum_{i=1}^{K} \prod_{j=1}^{i} P(\text{accept at } j)

Here,

  • K = Number of draft tokens

Speedup Calculation

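If each draft token is accepted independently with probability 0.7 and K=5:

  • Expected accepted draft tokens = 0.7 + 0.49 + 0.343 + 0.240 + 0.168 ≈ 1.94
  • Tokens per target forward pass ≈ 1.94 + 1 = 2.94, because every pass also yields one token from the target itself (a correction token, or a bonus token when all K drafts are accepted)
  • Upper bound on speedup ≈ 2.94x, assuming the draft model costs nothing
  • Actual speedup is lower because of draft-model and verification overhead, typically 2-2.5x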
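
The arithmetic can be sketched in a few lines. The estimate below counts the accepted draft tokens plus the one token the target contributes on every pass, and assumes each draft token is accepted independently with probability alpha. The cost_ratio parameter (time for one draft step divided by time for one target step) is an assumption introduced here to model draft overhead, following the walltime analysis in Leviathan et al. (2023); it ignores any other overhead.

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass.

    Sums alpha^0 + alpha^1 + ... + alpha^k: the alpha^0 term is the token
    the target always contributes (correction or bonus), and alpha^i is the
    probability that the first i draft tokens are all accepted.
    """
    return sum(alpha ** i for i in range(k + 1))


def estimated_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Estimated walltime speedup over plain autoregressive decoding.

    One speculative step costs k draft passes plus one target pass,
    measured in units of a single target pass.
    """
    return expected_tokens_per_step(alpha, k) / (k * cost_ratio + 1.0)


tokens = expected_tokens_per_step(0.7, 5)
speedup = estimated_speedup(0.7, 5, cost_ratio=0.05)  # hypothetical 5% draft cost

print(round(tokens, 2))   # 2.94
print(round(speedup, 2))  # 2.35
```

With a draft step costing 5% of a target step, the estimate lands inside the 2-2.5x range quoted above.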

Implementation

import torch
import torch.nn.functional as F

def speculative_decode(draft_model, target_model, prompt, K=5, max_tokens=100):
    """Speculative decoding with rejection sampling.

    Assumes batch size 1, a shared tokenizer, and models whose outputs expose
    `.logits` of shape (batch, seq_len, vocab). No KV cache is used, so every
    call re-encodes the full sequence.
    """
    input_ids = prompt.clone()
    prompt_len = input_ids.shape[1]

    while input_ids.shape[1] - prompt_len < max_tokens:
        n = input_ids.shape[1]

        # Step 1: Draft model generates K tokens autoregressively
        draft_tokens = []
        draft_probs = []
        draft_input = input_ids
        for _ in range(K):
            with torch.no_grad():
                draft_out = draft_model(draft_input)
            probs = F.softmax(draft_out.logits[:, -1, :], dim=-1)
            token = torch.multinomial(probs, 1)  # shape (1, 1)
            draft_tokens.append(token)
            draft_probs.append(probs)
            draft_input = torch.cat([draft_input, token], dim=-1)

        # Step 2: Target model scores all K draft tokens in one forward pass
        with torch.no_grad():
            target_out = target_model(draft_input)
        # Logits at position n - 1 + i predict the token at position n + i,
        # so this slice holds K + 1 distributions (index K is the bonus slot)
        target_probs = F.softmax(target_out.logits[:, n - 1:, :], dim=-1)

        # Step 3: Accept/reject each draft token in order
        new_tokens = []
        for i in range(K):
            p_target = target_probs[:, i, :]
            # draft_tokens[i] already has shape (1, 1), as gather expects
            draft_prob = draft_probs[i].gather(1, draft_tokens[i]).squeeze()
            target_prob = p_target.gather(1, draft_tokens[i]).squeeze()

            # A sampled token always has draft_prob > 0
            accept_prob = min(1.0, (target_prob / draft_prob).item())

            if torch.rand(1).item() < accept_prob:
                new_tokens.append(draft_tokens[i])
            else:
                # Sample correction from the adjusted (residual) distribution
                adjusted = F.relu(p_target - draft_probs[i])
                adjusted = adjusted / (adjusted.sum() + 1e-10)
                new_tokens.append(torch.multinomial(adjusted, 1))
                break
        else:
            # All K drafts accepted: sample one bonus token from the target
            new_tokens.append(torch.multinomial(target_probs[:, K, :], 1))

        # Extend the context with the tokens actually emitted, so a rejected
        # draft token is replaced by its correction
        input_ids = torch.cat([input_ids] + new_tokens, dim=-1)

    return input_ids[:, prompt_len:prompt_len + max_tokens]

Draft Model Selection

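| Draft Model Type | Size Ratio (draft / target) | Acceptance Rate | Speedup |
| --- | --- | --- | --- |
| Same architecture, fewer layers | 50% | 70-80% | 2.0-2.5x |
| Smaller vocabulary | 30% | 60-70% | 1.8-2.2x |
| n-gram model | <1% | 40-50% | 1.5-1.8x |
| Trained draft head | ~5% | 75-85% | 2.2-2.8x |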
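
Acceptance rate and draft cost also determine how many tokens to draft per step. The sketch below searches for the draft length K that maximizes a simple speedup estimate. It assumes each draft token is accepted independently with probability alpha and that one draft step costs cost_ratio target steps. Both parameter values in the example are hypothetical, and measured latency ratio is not the same thing as the parameter-count ratios in the table.

```python
from typing import Tuple


def best_draft_length(alpha: float, cost_ratio: float,
                      max_k: int = 16) -> Tuple[int, float]:
    """Return (k, estimated speedup) for the best k in 0..max_k.

    k = 0 means no drafting at all, i.e. plain autoregressive decoding
    with a speedup of exactly 1.0.
    """
    def speedup(k: int) -> float:
        # Expected tokens per target pass: alpha^0 + ... + alpha^k
        expected_tokens = sum(alpha ** i for i in range(k + 1))
        # Cost of one step: k draft passes plus one target pass
        return expected_tokens / (k * cost_ratio + 1.0)

    return max(((k, speedup(k)) for k in range(max_k + 1)),
               key=lambda pair: pair[1])


# Hypothetical cheap, accurate draft: long drafts pay off.
k_cheap, s_cheap = best_draft_length(alpha=0.8, cost_ratio=0.05)
# Hypothetical expensive, inaccurate draft: drafting never pays off.
k_costly, s_costly = best_draft_length(alpha=0.45, cost_ratio=0.5)

print(k_cheap, round(s_cheap, 2))    # 8 3.09
print(k_costly, round(s_costly, 2))  # 0 1.0
```

Under these assumptions, a draft that is too slow relative to its acceptance rate yields no benefit at any K, which is why both columns matter when comparing draft options.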

Advanced Variants

Medusa: Draft Head Approach

Definition: Medusa

Medusa (Cai et al., 2024) adds multiple "draft heads" to the target model itself. Each head predicts a different future position, eliminating the need for a separate draft model.

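import torch.nn as nn

class MedusaModel(nn.Module):
    """Simplified Medusa-style wrapper: extra linear heads on a causal LM.

    Assumes `base_model` exposes `config.hidden_size`, `config.vocab_size`,
    an `lm_head` module, and accepts `output_hidden_states=True`.
    """
    def __init__(self, base_model, num_heads=3):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        
        # Draft heads for positions 1, 2, ..., num_heads
        self.draft_heads = nn.ModuleList([
            nn.Linear(hidden_size, base_model.config.vocab_size)
            for _ in range(num_heads)
        ])
    
    def forward(self, input_ids):
        base_output = self.base_model(input_ids, output_hidden_states=True)
        # Final-layer hidden states, shape (batch, seq_len, hidden_size)
        hidden = base_output.hidden_states[-1]
        last_hidden = hidden[:, -1, :]
        
        # Main logits (position 0): the ordinary next-token prediction
        main_logits = self.base_model.lm_head(last_hidden)
        
        # Draft logits (positions 1, 2, ..., num_heads), one tensor per head
        draft_logits = [head(last_hidden) for head in self.draft_heads]
        
        return main_logits, draft_logits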
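
Once the heads have produced logits, their guesses still have to be checked against the base model. The sketch below shows one simplified way to do that: take each head's top-k tokens, enumerate every combination as a candidate continuation, and keep the candidate whose longest prefix agrees with the base model's greedy choice. This is a toy stand-in, not the tree-attention verification described in the Medusa paper. The logits, the `toy_target` lookup, and all function names are hypothetical.

```python
from itertools import product
from typing import Callable, List, Sequence, Tuple


def top_k_ids(logits: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest logits, highest first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def build_candidates(head_logits: Sequence[Sequence[float]],
                     k: int = 2) -> List[Tuple[int, ...]]:
    """Every combination of one top-k token per head, in head order."""
    return list(product(*(top_k_ids(logits, k) for logits in head_logits)))


def accepted_length(candidate: Sequence[int],
                    target_next_token: Callable[[Tuple[int, ...]], int]) -> int:
    """Length of the longest prefix matching the target's greedy choices."""
    prefix: List[int] = []
    for token in candidate:
        if target_next_token(tuple(prefix)) != token:
            break
        prefix.append(token)
    return len(prefix)


# Toy logits for two heads over a 4-token vocabulary (illustration only).
head_logits = [[0.1, 2.0, 0.3, 1.5], [1.2, 0.0, 2.2, 0.1]]

# Stand-in for the base model: greedy next token given the tokens so far.
greedy_table = {(): 1, (1,): 0}


def toy_target(prefix: Tuple[int, ...]) -> int:
    return greedy_table.get(prefix, -1)


candidates = build_candidates(head_logits, k=2)
best = max(candidates, key=lambda c: accepted_length(c, toy_target))

print(candidates)                            # [(1, 2), (1, 0), (3, 2), (3, 0)]
print(best, accepted_length(best, toy_target))  # (1, 0) 2
```

Keeping several candidates per head raises the chance that at least one long prefix survives verification, at the cost of scoring more continuations per step.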

EAGLE: Autoregressive Drafts

Definition: EAGLE

EAGLE (Li et al., 2024) uses autoregressive draft generation with feature-level speculation. It generates draft tokens by conditioning on the target model's hidden states, achieving higher acceptance rates than Medusa.

EAGLE achieves 2.5-3.5x speedup on code generation tasks, where token patterns are more predictable. Because the draft is conditioned on the target model's internal representations rather than on tokens alone, its proposals track the target model more closely.

Practice Exercises

  1. Acceptance Rate Analysis: For a given prompt, measure the acceptance rate of speculative decoding with K=5 using a 7B draft and 70B target model. How does acceptance rate vary across domains (code vs. prose)?

  2. Draft Model Comparison: Compare the speedup of using a 3B draft model vs. a 1.5B draft model with a 70B target. What is the optimal draft size?

  3. Medusa Implementation: Implement a 3-head Medusa model. How does the acceptance rate of head 3 compare to head 1?

  4. Cost Analysis: Calculate the cost-per-token of speculative decoding vs. standard autoregressive decoding, accounting for draft model inference overhead.

Key Takeaways

Summary: Speculative Decoding

  • Draft-then-verify enables generating multiple tokens per step
  • Rejection sampling ensures output distribution matches target model exactly
  • Typical speedup is 2-3x with 70-85% acceptance rate
  • Medusa eliminates separate draft model by adding heads to target
  • EAGLE uses autoregressive drafts for higher acceptance rates
  • Best for scenarios where draft model cost << target model savings
  • Not beneficial when draft model is too large or acceptance rate is too low

What to Learn Next

-> LLM Inference Optimization Broader strategies for making LLM inference faster.

-> Flash Attention and Memory Efficiency IO-aware attention algorithms that reduce memory.

-> KV Cache Optimization Reducing memory usage of the key-value cache.

-> Continuous Batching for LLMs Maximizing GPU utilization with dynamic batching.

-> Quantization Techniques Deep Dive Reducing model size through quantization.

-> Model Parallelism and Tensor Parallelism Splitting models across GPUs for inference.
