Continuous Batching for LLMs

Continuous Batching — Maximizing GPU Utilization

Static batching wastes GPU capacity while each batch waits for its longest request to finish. Continuous batching adds and removes requests during generation, keeping the GPU busy.

  • Dynamic Scheduling — Add new requests as soon as slots free up
  • Iteration-Level Scheduling — Make scheduling decisions every generation step
  • Throughput Optimization — 3-5x throughput improvement over static batching

The GPU should never wait for a request, and no request should wait for the GPU.

Traditional static batching requires every request in a batch to finish before the next batch starts. This wastes capacity: when one request finishes early, its slot in the batch sits idle until the longest request in that batch completes.

Definition: Continuous Batching

Continuous batching (also called iteration-level scheduling or in-flight batching) allows new requests to join a batch as soon as any request in the batch completes its generation. This maintains high GPU utilization throughout the serving process.

Static vs Continuous Batching

The Problem with Static Batching

Static Batching:
Request 1: [====done====]......waiting......
Request 2: [========done========]...waiting..
Request 3: [================done================]
                                  ^^^^^^^^^^
                                  GPU idle time (wasted)

Continuous Batching Solution

Continuous Batching:
Step 1: [Req1, Req2, Req3]
Step 2: [Req1, Req2, Req3] -> Req1 done, slot free
Step 3: [Req4, Req2, Req3] -> Req4 joins immediately
Step 4: [Req4, Req5, Req3] -> Req5 joins when Req2 done

Continuous batching requires iteration-level scheduling: the scheduler makes its decisions at every generation step rather than once per batch. It is implemented in vLLM, TensorRT-LLM (where it is called in-flight batching), and Hugging Face Text Generation Inference (TGI).
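The step-by-step picture above can be reproduced with a small simulation. It is a sketch under simplifying assumptions: every request is already queued before step 1, each active request emits exactly one token per step, and prefill is ignored. `continuous_batching_trace` is a hypothetical helper written for this lesson, not part of any serving framework; the output lengths are chosen so the trace matches the diagram.

```python
from collections import deque


def continuous_batching_trace(output_lengths, max_batch_size):
    """Return the request ids occupying each batch slot at every decode step."""
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be at least 1")
    # Request ids start at 1 to match the diagram; values are tokens still to generate
    pending = deque(enumerate(output_lengths, start=1))
    slots = [None] * max_batch_size  # each slot is None or [request_id, tokens_left]
    trace = []
    while pending or any(slot is not None for slot in slots):
        # Iteration-level scheduling: refill every free slot before the step runs
        for i in range(max_batch_size):
            if slots[i] is None and pending:
                request_id, tokens_left = pending.popleft()
                slots[i] = [request_id, tokens_left]
        trace.append([slot[0] if slot else None for slot in slots])
        # One decode step: every active request generates one token
        for i, slot in enumerate(slots):
            if slot is not None:
                slot[1] -= 1
                if slot[1] <= 0:
                    slots[i] = None  # finished; the slot is free at the next step
    return trace


for step, batch in enumerate(continuous_batching_trace([2, 3, 4, 2, 1], 3), start=1):
    print(f"Step {step}: {batch}")
# Step 1: [1, 2, 3]
# Step 2: [1, 2, 3]
# Step 3: [4, 2, 3]
# Step 4: [4, 5, 3]
```

Request 4 takes over request 1's slot at step 3 and request 5 takes over request 2's slot at step 4, with no step spent waiting for request 3.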

Throughput Analysis

Batch Throughput

$$\text{Throughput} = \frac{N_{\text{requests}} \times \bar{T}_{\text{output}}}{T_{\text{total}}}$$

Here,

  • $N_{\text{requests}}$ = number of requests processed
  • $\bar{T}_{\text{output}}$ = average output tokens per request
  • $T_{\text{total}}$ = total processing time in seconds, so throughput is measured in output tokens per second
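To connect GPU step counts to this formula, the sketch below turns a number of decode steps into $T_{\text{total}}$ using a hypothetical constant latency of 20 ms per step (a simplification: real step latency depends on batch size and context length) and then applies the formula. `throughput_tokens_per_s` and `STEP_LATENCY_S` are names introduced here for illustration.

```python
def throughput_tokens_per_s(n_requests: int, avg_output_tokens: float, total_time_s: float) -> float:
    # Throughput = N_requests * mean output tokens per request / T_total
    return n_requests * avg_output_tokens / total_time_s


STEP_LATENCY_S = 0.020  # hypothetical constant time per decode step

# 100 requests x 50 output tokens, finished in 500 decode steps
total_time_s = 500 * STEP_LATENCY_S  # 10 seconds
print(throughput_tokens_per_s(100, 50, total_time_s))  # 500.0 tokens/s
```

Under this assumption throughput is inversely proportional to the number of steps, so any reduction in GPU steps for the same work translates directly into higher throughput.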

Throughput Comparison

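With 100 requests averaging 50 output tokens each (5,000 output tokens in total), 10 batch slots, and prefill and scheduling overhead ignored:

  • Static batching: each batch of 10 runs for as many steps as its longest request. If output lengths vary so that the longest request in each batch is about 200 tokens, that is 10 batches x 200 steps = 2,000 GPU steps
  • Continuous batching: all 10 slots stay busy, so about 5,000 tokens / 10 slots = 500 GPU steps
  • Speedup: 2,000/500 = 4x throughput improvement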
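The step counts above can be checked with a small simulation. This sketch assumes every request is queued at step 0, each active request produces exactly one token per step, and prefill is ignored; `static_batching_steps` and `continuous_batching_steps` are hypothetical helpers. Continuous batching is modeled as FIFO admission into the earliest free slot, tracked with a min-heap of the steps at which slots become free.

```python
import heapq


def static_batching_steps(output_lengths, batch_size):
    # Each batch occupies the GPU until its longest request finishes.
    return sum(
        max(output_lengths[i:i + batch_size])
        for i in range(0, len(output_lengths), batch_size)
    )


def continuous_batching_steps(output_lengths, batch_size):
    # Min-heap of the step at which each slot becomes free; each request
    # (in arrival order) takes the earliest free slot as soon as it opens.
    if not output_lengths:
        return 0
    free_at = [0] * min(batch_size, len(output_lengths))
    for n in output_lengths:
        start = heapq.heappop(free_at)
        heapq.heappush(free_at, start + n)
    return max(free_at)


lengths = [1, 5, 5]
print(static_batching_steps(lengths, 2))      # 10: batches [1, 5] and [5] take 5 steps each
print(continuous_batching_steps(lengths, 2))  # 6: request 3 starts at step 1 in request 1's freed slot
```

Feeding in output lengths with a long tail widens the gap, because static batching pays for the longest request in every batch while continuous batching only needs roughly the total token count divided by the number of slots.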

Implementation Architecture

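import asyncio
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Request:
    request_id: str
    input_ids: torch.Tensor  # 1-D LongTensor: prompt tokens plus tokens generated so far
    max_tokens: int
    priority: int = 0
    tokens_generated: int = 0
    done: bool = False


def sample_token(logits: torch.Tensor) -> torch.Tensor:
    # Greedy decoding: return the id of the highest-scoring token (0-d tensor)
    return torch.argmax(logits, dim=-1)


class ContinuousBatchScheduler:
    # Assumes a Hugging Face-style causal LM: model(input_ids=..., attention_mask=...)
    # returns an object whose .logits has shape [batch, seq_len, vocab_size].
    def __init__(self, model, eos_token_id: int, pad_token_id: int, max_batch_size: int = 32):
        self.model = model
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.max_batch_size = max_batch_size
        self.active_requests: List[Request] = []
        self.pending_requests: asyncio.Queue = asyncio.Queue()

    async def add_request(self, request: Request):
        await self.pending_requests.put(request)

    def _build_batch(self):
        # Requests have different lengths, so left-pad them to a common length;
        # the last column then holds every request's most recent token.
        # (Some models also need position_ids derived from the mask.)
        device = self.active_requests[0].input_ids.device
        max_len = max(r.input_ids.size(0) for r in self.active_requests)
        batch = torch.full((len(self.active_requests), max_len), self.pad_token_id,
                           dtype=torch.long, device=device)
        mask = torch.zeros_like(batch)
        for i, r in enumerate(self.active_requests):
            n = r.input_ids.size(0)
            batch[i, max_len - n:] = r.input_ids
            mask[i, max_len - n:] = 1
        return batch, mask

    async def generate(self):
        while True:
            # Iteration-level scheduling: fill free slots before every step
            while len(self.active_requests) < self.max_batch_size and not self.pending_requests.empty():
                self.active_requests.append(self.pending_requests.get_nowait())

            if not self.active_requests:
                await asyncio.sleep(0.001)
                continue

            # Run one generation step. This sketch has no KV cache, so the
            # full sequence is recomputed at every step.
            input_batch, attention_mask = self._build_batch()
            with torch.no_grad():
                outputs = self.model(input_ids=input_batch, attention_mask=attention_mask)

            # Sample one token per request from the logits at the last position
            for i, request in enumerate(self.active_requests):
                next_token = sample_token(outputs.logits[i, -1])
                request.input_ids = torch.cat([request.input_ids, next_token.view(1)])
                request.tokens_generated += 1
                if next_token.item() == self.eos_token_id or request.tokens_generated >= request.max_tokens:
                    request.done = True

            # Drop completed requests; their slots are refilled on the next iteration
            self.active_requests = [r for r in self.active_requests if not r.done]

            # Yield to the event loop so add_request() callers can enqueue work
            await asyncio.sleep(0)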

Scheduling Policies

First-In-First-Out (FIFO)

class FIFOScheduler:
    def select_next(self, pending, active, max_batch):
        if len(active) >= max_batch:
            return None
        if not pending:
            return None
        return pending.pop(0)

Shortest-Job-First (SJF)

Definition: Shortest-Job-First Scheduling

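SJF scheduling prioritizes requests expected to produce shorter outputs. This reduces average latency, because short requests no longer wait behind long ones, and completes more requests early in a busy period. Output lengths are not known in advance, so SJF relies on a length estimate, and long requests can starve unless the scheduler raises their priority as they wait.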

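class SJFScheduler:
    def select_next(self, pending, active, max_batch):
        if len(active) >= max_batch:
            return None
        if not pending:
            return None
        # Pick the request with the smallest estimated output length.
        # estimated_output_length is assumed to be set by a length predictor;
        # the Request dataclass above does not define it.
        idx = min(range(len(pending)), key=lambda i: pending[i].estimated_output_length)
        return pending.pop(idx)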
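To see the latency effect, the sketch below uses a slot-based continuous-batching model (all requests queued at step 0, one token per step, and true output lengths known, which a real SJF scheduler has to estimate) and compares FIFO order with shortest-first order. `completion_times` and `mean` are hypothetical helpers written for this example.

```python
import heapq


def completion_times(output_lengths, batch_size):
    # Each request, in the given order, takes the earliest free slot
    # and finishes n steps after it starts.
    free_at = [0] * min(batch_size, len(output_lengths))
    finished = []
    for n in output_lengths:
        start = heapq.heappop(free_at)
        finished.append(start + n)
        heapq.heappush(free_at, start + n)
    return finished


def mean(xs):
    return sum(xs) / len(xs)


workload = [100, 100, 1, 1, 1, 1]            # two long requests arrive first
fifo = completion_times(workload, 2)
sjf = completion_times(sorted(workload), 2)  # shortest output first
print(mean(fifo), mean(sjf))                 # 101.0 35.0
print(max(fifo), max(sjf))                   # 102 102
```

Mean completion time drops from 101 to 35 steps while the last request finishes at step 102 either way: SJF reorders the work rather than reducing it.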

Preemptive Scheduling

Definition: Preemptive Scheduling

Preemptive scheduling can pause low-priority requests to make room for high-priority ones. A paused request's KV cache is either swapped out to CPU memory or discarded and recomputed later, and the request resumes when GPU resources become available.

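class PreemptiveScheduler:
    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        self.swapped_out = []  # preempted requests waiting to resume

    def move_to_cpu(self, request):
        # Placeholder: a real server would swap the request's KV cache to CPU here
        self.swapped_out.append(request)

    def maybe_preempt(self, active_requests, new_request) -> bool:
        if len(active_requests) < self.max_batch_size:  # free slot: no preemption needed
            active_requests.append(new_request)
            return True
        victim = min(active_requests, key=lambda r: r.priority)
        if new_request.priority > victim.priority:  # preempt the lowest-priority request
            self.move_to_cpu(victim)
            active_requests.remove(victim)
            active_requests.append(new_request)
            return True
        return False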
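The other half of preemption is resuming paused requests once slots free up. This is a minimal sketch, assuming each request carries a numeric `priority` (higher is more important) and that swapped-out requests sit in a plain list; `SimpleRequest` and `resume_swapped` are hypothetical names, and the KV cache copy-back is only marked by a comment.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SimpleRequest:
    request_id: str
    priority: int


def resume_swapped(active: List[SimpleRequest], swapped: List[SimpleRequest],
                   max_batch_size: int) -> List[str]:
    """Re-admit preempted requests, highest priority first, while slots are free."""
    swapped.sort(key=lambda r: r.priority, reverse=True)
    resumed = []
    while swapped and len(active) < max_batch_size:
        request = swapped.pop(0)
        # Hypothetical step: a real server would copy this request's KV cache
        # back from CPU memory (or recompute it) before it decodes again.
        active.append(request)
        resumed.append(request.request_id)
    return resumed


active = [SimpleRequest("a", 1)]
swapped = [SimpleRequest("b", 0), SimpleRequest("c", 5), SimpleRequest("d", 3)]
print(resume_swapped(active, swapped, max_batch_size=3))  # ['c', 'd']
print([r.request_id for r in swapped])                    # ['b']
```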

Memory Management

KV Cache Management

Definition: KV Cache Memory

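The KV cache stores the key and value tensors of every previous token so they are not recomputed at each step. For a 70B-class model with 80 layers and 64 attention heads of dimension 128 in FP16, full multi-head attention needs about 2.5 MiB per token, or about 10 GiB for a single 4096-token sequence, so four such sequences already take ~40 GiB. Models with grouped-query attention, such as Llama 2 70B with 8 key/value heads, need one eighth of that. Either way, the KV cache is often the memory bottleneck in serving.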

KV Cache Memory per Request

$$M_{\text{KV}} = 2 \times L \times H \times D \times S \times \text{bytes\_per\_param}$$

Here,

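  • $L$ = number of layers
  • $H$ = number of key/value heads (equal to the number of attention heads unless the model uses grouped-query or multi-query attention)
  • $D$ = head dimension
  • $S$ = sequence length in tokens
  • $\text{bytes\_per\_param}$ = bytes per stored value (2 for FP16/BF16)
  • The factor 2 counts one key and one value tensor per layer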
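Plugging numbers into the formula is simple arithmetic. The helper below is a sketch that assumes a 70B-class shape (80 layers, head dimension 128, 2-byte values) and compares 64 key/value heads with 8; `kv_cache_bytes` is a name introduced here for illustration.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    # M_KV = 2 (one key and one value tensor) x L x H x D x S x bytes per value
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value


GIB = 1024 ** 3

# Assumed 70B-class shape: 80 layers, head dim 128, FP16/BF16 values (2 bytes)
full_mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)

print(full_mha / GIB)  # 10.0 GiB for one 4096-token sequence with 64 KV heads
print(gqa / GIB)       # 1.25 GiB with 8 KV heads (grouped-query attention)
```

Dividing the GPU memory left over after loading the weights by the per-sequence figure gives a rough upper bound on how many full-length requests fit at once.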

For continuous batching, KV cache must be managed dynamically as requests join and leave the batch. PagedAttention (vLLM) and RadixAttention (SGLang) are key innovations for efficient KV cache management.
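To make the idea behind PagedAttention concrete, here is a toy block allocator. It is a simplified sketch written for this lesson, not vLLM's implementation or API: the KV cache is divided into fixed-size blocks, each request keeps a block table listing the physical blocks it owns, blocks are allocated one at a time as tokens are generated, and they return to the pool when the request finishes. `PagedKVAllocator`, `append_token`, and `free` are hypothetical names.

```python
class PagedKVAllocator:
    """Toy block allocator in the spirit of PagedAttention (hypothetical API)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size                # tokens per block
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}                      # request id -> [block ids]
        self.num_tokens = {}                        # request id -> tokens stored

    def append_token(self, request_id: str) -> bool:
        """Reserve space for one more token; return False if memory is exhausted."""
        table = self.block_tables.setdefault(request_id, [])
        used = self.num_tokens.get(request_id, 0)
        if used % self.block_size == 0:             # current block is full (or none yet)
            if not self.free_blocks:
                return False                        # caller must preempt or wait
            table.append(self.free_blocks.pop())
        self.num_tokens[request_id] = used + 1
        return True

    def free(self, request_id: str) -> None:
        """Return a finished request's blocks to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(request_id, []))
        self.num_tokens.pop(request_id, None)


alloc = PagedKVAllocator(num_blocks=4, block_size=16)
for _ in range(20):
    alloc.append_token("req-a")                     # 20 tokens need 2 blocks of 16
print(len(alloc.block_tables["req-a"]), len(alloc.free_blocks))  # 2 2
alloc.free("req-a")
print(len(alloc.free_blocks))                       # 4
```

Because memory is reserved one block at a time instead of for the maximum sequence length up front, each request wastes at most `block_size - 1` token slots, which leaves room for more concurrent requests as they join and leave the batch.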

Production Systems

vLLM Architecture

User Request -> [API Server] -> [Scheduler] -> [Model Worker]
                       |              |              |
                       v              v              v
                  [Request Queue] [Batch Manager] [KV Cache Manager]
                                           |
                                           v
                                    [Token Sampler]
                                           |
                                           v
                                    [Response Stream]

Performance Comparison

System               | Throughput (tokens/s) | Latency (p99) | GPU Utilization
Hugging Face (naive) | 1,000                 | 500 ms        | 30-40%
vLLM (continuous)    | 5,000                 | 200 ms        | 80-90%
TensorRT-LLM         | 8,000                 | 150 ms        | 85-95%
SGLang               | 6,000                 | 180 ms        | 82-92%

Practice Exercises

  1. Throughput Analysis: Compare the throughput of static vs continuous batching for a workload with 50% short (<10 tokens) and 50% long (>500 tokens) requests.

  2. Scheduler Design: Implement a priority-based scheduler that minimizes tail latency (p99) while maintaining high throughput.

  3. KV Cache Budget: If you have 80GB GPU memory and the model uses 40GB, how many concurrent requests can you serve with continuous batching?

  4. Preemption Policy: Design a preemption policy that minimizes the number of preempted requests while ensuring high-priority requests complete on time.

Key Takeaways

Summary: Continuous Batching

  • Static batching wastes GPU resources waiting for the longest request
  • Continuous batching adds/removes requests at every generation step
  • 3-5x throughput improvement over static batching
  • Scheduling policies (FIFO, SJF, preemptive) affect latency and fairness
  • The KV cache is the main memory bottleneck, so managing it well is key
  • PagedAttention and RadixAttention enable efficient memory sharing
  • Production systems (vLLM, TensorRT-LLM, SGLang) implement continuous batching
  • GPU utilization increases from 30-40% to 80-95% with continuous batching

What to Learn Next

-> KV Cache Optimization: reducing memory usage of the key-value cache.

-> Speculative Decoding: generating multiple tokens per step for faster inference.

-> LLM Inference Optimization: broader strategies for making LLM inference faster.

-> Flash Attention and Memory Efficiency: IO-aware attention algorithms.

-> Building Production LLM Applications: end-to-end production LLM systems.

-> Model Parallelism and Tensor Parallelism: splitting models across GPUs.
