Distributed Training for LLMs — Scaling Beyond Single GPUs
Training modern LLMs requires distributing computation across hundreds or thousands of GPUs. This guide covers the fundamental parallelism strategies, memory optimization techniques, and frameworks that make large-scale training possible.
- Data Parallelism — Replicate model across GPUs, split data batches
- Tensor Parallelism — Split individual layers across multiple GPUs
- Pipeline Parallelism — Split model layers across GPU stages
No single GPU can hold a trillion-parameter model. Distribution is not optional — it is the foundation of modern AI.
Even the 80 GB version of the NVIDIA A100 holds far too little memory for models with hundreds of billions of parameters.
Definition: Distributed Training
Distributed training is the process of splitting model training across multiple compute devices (GPUs/TPUs) to enable training of models that exceed single-device memory capacity and to reduce training time through parallel computation.
Why Distributed Training is Necessary
Memory Requirements
The memory required to train a model depends on its parameter count, batch size, and optimizer state:
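Training Memory Formula
$$M_{\text{train}} \approx N\,(b_w + b_g + b_m + b_v) + B \cdot M_{\text{act}}$$
Here,
- $N$ = Number of model parameters
- $b_w = 4$ = Bytes per parameter (FP32)
- $b_g = 4$ = Bytes for gradients (FP32)
- $b_m = 4$ = Bytes for Adam optimizer state (m)
- $b_v = 4$ = Bytes for Adam optimizer state (v)
- $B$ = Batch size multiplier applied to $M_{\text{act}}$, the activation memory for a single unit of batch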
Memory Calculation for a 7B Model
For a 7B parameter model with mixed-precision training:
- Parameters (FP16): 7B x 2 bytes = 14 GB
- Gradients (FP16): 7B x 2 bytes = 14 GB
- Optimizer states (FP32 master weights + Adam m + Adam v): 7B x 12 bytes = 84 GB
- Activations: ~16 GB (varies with batch size)
- Total: ~128 GB — Requires at least 2x A100 80GB GPUs
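The same arithmetic can be scripted for other model sizes. This is a minimal sketch assuming the mixed-precision Adam layout used in the example (2-byte weights and gradients plus 12 bytes of FP32 optimizer state per parameter); activation memory is passed in rather than modelled, and `training_memory_gb` is a hypothetical helper name. The GPU count assumes memory can be split evenly across 80 GB devices, which itself requires a sharding or model-parallel method.

```python
import math


def training_memory_gb(num_params, activation_gb,
                       weight_bytes=2, grad_bytes=2, optimizer_bytes=12):
    """Rough training-memory breakdown in GB (1 GB = 1e9 bytes).

    Assumes mixed-precision Adam: 16-bit weights and gradients, plus FP32
    master weights, Adam m and Adam v (4 + 4 + 4 = 12 bytes per parameter).
    Activation memory depends on batch size and sequence length, so the
    caller supplies it.
    """
    breakdown = {
        "parameters": num_params * weight_bytes / 1e9,
        "gradients": num_params * grad_bytes / 1e9,
        "optimizer_states": num_params * optimizer_bytes / 1e9,
        "activations": activation_gb,
    }
    breakdown["total"] = sum(breakdown.values())
    return breakdown


mem = training_memory_gb(7e9, activation_gb=16)
for name, gb in mem.items():
    print(f"{name:>16}: {gb:6.1f} GB")

# Minimum number of 80 GB GPUs if the memory could be split perfectly
gpus_needed = math.ceil(mem["total"] / 80)
print("80 GB GPUs needed:", gpus_needed)
```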
Scaling Laws and Training Time
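Training Throughput Scaling
$$T \approx \frac{6\,N D}{G \cdot F}$$
Here,
- $T$ = Training time in seconds
- $N$ = Number of parameters
- $D$ = Number of training tokens
- $G$ = Number of GPUs
- $F$ = Sustained throughput per GPU (FLOP/s); $6ND$ approximates the total training FLOPs (forward plus backward pass)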
Scaling from 1 to 64 GPUs does not give a 64x speedup, because communication overhead lowers the sustained per-GPU throughput as the GPU count grows. Real-world scaling efficiency is typically 60-80% for well-optimized training runs.
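A back-of-the-envelope estimate follows directly from the common ~6ND FLOPs rule. This sketch assumes a single `scaling_efficiency` factor captures all communication losses; the 7B model, 1T tokens and 150 TFLOP/s sustained per-GPU throughput are hypothetical inputs, not measurements.

```python
def training_time_days(num_params, num_tokens, num_gpus, flops_per_gpu,
                       scaling_efficiency=1.0):
    """Estimate wall-clock days using the ~6 * N * D training-FLOPs rule.

    flops_per_gpu: sustained FLOP/s a single GPU achieves on its own.
    scaling_efficiency: fraction of linear speedup retained at num_gpus
    (the text quotes 60-80% for well-optimized runs).
    """
    total_flops = 6 * num_params * num_tokens
    seconds = total_flops / (num_gpus * flops_per_gpu * scaling_efficiency)
    return seconds / 86_400


# Hypothetical run: 7B parameters, 1T tokens, 64 GPUs at 150 TFLOP/s each
for eff in (1.0, 0.7):
    days = training_time_days(7e9, 1e12, 64, 150e12, eff)
    print(f"{eff:.0%} scaling efficiency: {days:.1f} days")  # 50.6 days, then 72.3 days
```

Dividing by the efficiency factor makes the cost of imperfect scaling explicit: at 70% efficiency the same job takes about 1.43x longer than the ideal estimate.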
Parallelism Strategies
Data Parallelism (DP)
The simplest form of distributed training — replicate the model on every GPU and split the data:
Definition: Data Parallelism
In data parallelism, each GPU holds a complete copy of the model. The training batch is split across GPUs, each computes gradients independently, and gradients are averaged via all-reduce before updating weights.
GPU 0: Model Copy + Batch 0 -> Gradients_0 --+
GPU 1: Model Copy + Batch 1 -> Gradients_1 --+-> All-Reduce -> Average -> Update
GPU 2: Model Copy + Batch 2 -> Gradients_2 --+
GPU 3: Model Copy + Batch 3 -> Gradients_3 --+
Data Parallelism Gradient Aggregation
$$\bar{g} = \frac{1}{d}\sum_{i=1}^{d} g_i$$
Here,
- $\bar{g}$ = Averaged gradient across all GPUs
- $d$ = Number of GPUs (data-parallel replicas)
- $g_i$ = Gradient computed on GPU $i$
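Why averaging per-GPU gradients is equivalent to training on the full batch can be checked with a tiny simulation. This sketch assumes a linear model with a mean-squared-error loss and shards of equal size (the equivalence relies on both the mean loss and equal shards); the helper names are hypothetical, and the plain Python average stands in for the all-reduce.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w . x - y)^2) with respect to w."""
    grad = [0.0] * len(w)
    for x, y in zip(xs, ys):
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for j, xj in enumerate(x):
            grad[j] += 2 * err * xj / len(xs)
    return grad


def data_parallel_grad(w, xs, ys, num_gpus):
    """Each simulated GPU computes a gradient on its shard of the batch;
    the local gradients are then averaged (the job all-reduce does)."""
    shard = len(xs) // num_gpus  # assumes the batch divides evenly
    local_grads = [
        mse_grad(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(num_gpus)
    ]
    return [sum(g[j] for g in local_grads) / num_gpus for j in range(len(w))]


xs = [[1.0, float(i)] for i in range(8)]  # bias term + one feature
ys = [3.0 * i + 1.0 for i in range(8)]
w = [0.5, 0.5]

full = mse_grad(w, xs, ys)
dp = data_parallel_grad(w, xs, ys, num_gpus=4)
print(full, dp)  # equal up to floating-point rounding
```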
Limitations:
- Each GPU must hold the entire model (parameters + optimizer states + gradients)
- Communication cost scales with model size (all-reduce of gradients)
- Cannot train models larger than single GPU memory
ZeRO (Zero Redundancy Optimizer)
DeepSpeed ZeRO eliminates memory redundancy in data parallelism:
Definition: ZeRO Optimization
ZeRO (Zero Redundancy Optimizer) partitions optimizer states, gradients, and parameters across data-parallel GPUs, eliminating memory redundancy while maintaining computational equivalence to data parallelism.
| ZeRO Stage | What is Partitioned | Model-State Memory Savings |
|---|---|---|
| Stage 1 | Optimizer states | Up to 4x |
| Stage 2 | Optimizer states + Gradients | Up to 8x |
| Stage 3 | Optimizer states + Gradients + Parameters | Nx (N = number of data-parallel GPUs) |
ZeRO Stage 3 Memory Layout:
GPU 0: Optimizer shard 0 + Gradient shard 0 + Parameter shard 0
GPU 1: Optimizer shard 1 + Gradient shard 1 + Parameter shard 1
GPU 2: Optimizer shard 2 + Gradient shard 2 + Parameter shard 2
GPU 3: Optimizer shard 3 + Gradient shard 3 + Parameter shard 3
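With ZeRO Stage 3 on 1024 GPUs, the 350 GB of FP16 parameters of a 175B parameter model shrink to about 0.34 GB per GPU (350 GB / 1024), and its full ~2.8 TB of mixed-precision model state (16 bytes per parameter) shrinks to about 2.7 GB per GPU. Standard data parallelism keeps all of it on every GPU.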
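Per-GPU model-state memory for each stage can be computed with the accounting used in the ZeRO paper: 2 bytes of FP16 weights, 2 bytes of FP16 gradients and K = 12 bytes of FP32 optimizer state per parameter. This sketch ignores activations, temporary buffers and fragmentation; `zero_memory_per_gpu_gb` is a hypothetical name.

```python
def zero_memory_per_gpu_gb(num_params, num_gpus, stage, k=12):
    """Per-GPU model-state memory in GB (1 GB = 1e9 bytes) under ZeRO.

    stage 0 is plain data parallelism (nothing sharded).
    """
    p, g, o = 2 * num_params, 2 * num_params, k * num_params
    if stage == 0:      # everything replicated on every GPU
        total = p + g + o
    elif stage == 1:    # shard optimizer states
        total = p + g + o / num_gpus
    elif stage == 2:    # shard optimizer states + gradients
        total = p + (g + o) / num_gpus
    elif stage == 3:    # shard optimizer states + gradients + parameters
        total = (p + g + o) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


for stage in range(4):
    print(f"stage {stage}: {zero_memory_per_gpu_gb(175e9, 1024, stage):8.2f} GB")
```

As the GPU count grows, stage 1 approaches 4 bytes per parameter (4x less than 16) and stage 2 approaches 2 bytes (8x less), which is where the "up to 4x" and "up to 8x" figures come from.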
Fully Sharded Data Parallel (FSDP)
PyTorch's native implementation of ZeRO-style sharding (its FULL_SHARD strategy corresponds to ZeRO Stage 3):
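import functools
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Launch with torchrun (one process per GPU), which sets LOCAL_RANK
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
model = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=4096, nhead=32, batch_first=True),
    num_layers=80,
)
# Make each decoder layer its own FSDP unit so layers are gathered one at a time
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerDecoderLayer},
)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads and optimizer state
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    ),
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)
# Build the optimizer after wrapping so it holds the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:  # dataloader is assumed to be defined elsewhere
    optimizer.zero_grad()
    loss = compute_loss(model, batch)  # hypothetical helper returning a scalar loss
    loss.backward()
    optimizer.step()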
Tensor Parallelism (TP)
Split individual layers across multiple GPUs:
Definition: Tensor Parallelism
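In tensor parallelism, individual operations (matrix multiplications) are split across GPUs. Each GPU computes a portion of the output, and the partial results are combined with a collective operation: all-gather for column-split outputs, all-reduce for row-split partial sums.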
For a linear layer Y = XW where W is split column-wise:
W = [W_0 | W_1] (split across 2 GPUs)
GPU 0: Y_0 = X @ W_0
GPU 1: Y_1 = X @ W_1
Y = [Y_0 | Y_1] (concatenated)
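Both ways of splitting a weight matrix can be verified numerically. This plain-Python sketch (no GPUs, hypothetical helper names) checks that a column split followed by concatenation and a row split followed by summing partial outputs (the all-reduce step) both reproduce the full product. Megatron-LM-style layers pair a column-parallel matmul with a row-parallel one so that only one all-reduce is needed per pair.

```python
def matmul(a, b):
    """Plain-Python product of an (n x k) and a (k x m) list-of-lists matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(m, parts):
    """Split a matrix column-wise into `parts` equal blocks (one per GPU)."""
    step = len(m[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in m] for p in range(parts)]


def split_rows(m, parts):
    """Split a matrix row-wise into `parts` equal blocks (one per GPU)."""
    step = len(m) // parts
    return [m[p * step:(p + 1) * step] for p in range(parts)]


X = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]           # (batch=2, in=4)
W = [[float(i * 4 + j) for j in range(4)] for i in range(4)]  # (in=4, out=4)
Y = matmul(X, W)

# Column parallel: GPU i computes Y_i = X @ W_i; outputs are concatenated
col_parts = [matmul(X, w_i) for w_i in split_columns(W, 2)]
Y_col = [r0 + r1 for r0, r1 in zip(*col_parts)]

# Row parallel: GPU i holds a slice of X's columns and of W's rows;
# the partial outputs are summed, which is what an all-reduce does
partials = [matmul(x_i, w_i) for x_i, w_i in zip(split_columns(X, 2), split_rows(W, 2))]
Y_row = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*partials)]

print(Y == Y_col == Y_row)  # True
```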
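Tensor Parallelism Communication
$$V_{\text{TP}} \approx 2 \times \frac{2(t-1)}{t} \times B\,s\,h$$
Here,
- $V_{\text{TP}}$ = Elements each GPU sends per transformer layer in the forward pass, assuming two all-reduces per layer, each a ring all-reduce of a $B \times s \times h$ activation; the backward pass sends about the same again
- $B$ = Batch size
- $s$ = Sequence length
- $h$ = Hidden dimension
- $t$ = Tensor parallelism degree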
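The per-layer traffic can be estimated in a few lines. This sketch assumes Megatron-style tensor parallelism (two all-reduces of a batch x sequence x hidden activation per layer in the forward pass), ring all-reduce, and 2-byte BF16 activations; the GPT-3-like shape in the usage line is illustrative only.

```python
def tp_comm_bytes_per_layer(batch, seq_len, hidden, tp, bytes_per_elem=2):
    """Bytes each GPU sends per transformer layer in the forward pass.

    Assumes two all-reduces per layer, each over a batch x seq_len x hidden
    activation, implemented as ring all-reduces, so each GPU sends
    2 * (tp - 1) / tp of the message per all-reduce.
    """
    message = batch * seq_len * hidden * bytes_per_elem
    per_all_reduce = 2 * (tp - 1) / tp * message
    return 2 * per_all_reduce


# Illustrative shape: hidden size 12288, 2048 tokens, micro-batch 1, TP = 8
gb = tp_comm_bytes_per_layer(1, 2048, 12288, 8) / 1e9
print(f"{gb:.3f} GB per layer per forward pass")
```

Because every layer repeats this exchange in both passes, tensor parallelism is usually kept inside a single node, where high-bandwidth links such as NVLink are available.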
Pipeline Parallelism (PP)
Split model layers across GPU stages:
Definition: Pipeline Parallelism
In pipeline parallelism, different layers of the model reside on different GPUs. Data flows sequentially through the pipeline stages. Micro-batching is used to keep all stages busy simultaneously.
Stage 0 (GPU 0): Layers 0-19 -> Output_0
Stage 1 (GPU 1): Layers 20-39 -> Output_1
Stage 2 (GPU 2): Layers 40-59 -> Output_2
Stage 3 (GPU 3): Layers 60-79 -> Final Output
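Assigning contiguous layer ranges to stages, as in the layout above, takes only a few lines. This is a simple sketch that balances by layer count; a real system might balance by measured per-layer cost instead (for example, embedding and output layers can be heavier). `assign_layers` is a hypothetical name.

```python
def assign_layers(num_layers, num_stages):
    """Split layers into contiguous, near-equal pipeline stages.

    Returns (first_layer, last_layer) for each stage; earlier stages get one
    extra layer when the split is uneven.
    """
    base, extra = divmod(num_layers, num_stages)
    ranges, start = [], 0
    for stage in range(num_stages):
        size = base + (1 if stage < extra else 0)
        ranges.append((start, start + size - 1))
        start += size
    return ranges


print(assign_layers(80, 4))  # [(0, 19), (20, 39), (40, 59), (60, 79)]
print(assign_layers(10, 4))  # [(0, 2), (3, 5), (6, 7), (8, 9)]
```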
Pipeline Bubble Efficiency
$$\eta = \frac{m}{m + p - 1}$$
Here,
- $\eta$ = Pipeline efficiency (fraction of useful compute)
- $m$ = Number of micro-batches
- $p$ = Number of pipeline stages
Pipeline Bubble Calculation
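With 4 pipeline stages and 16 micro-batches: eta = 16 / (16 + 4 - 1) = 16/19 = 84.2%. The pipeline bubble accounts for the remaining ~15.8% of training time. Efficiency exceeds 95% only when m >= 19(p - 1), which for 4 stages means at least 57 micro-batches; the common rule of thumb of at least 4x more micro-batches than stages (m >= 4p) keeps the bubble below about 20%.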
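The efficiency formula can be cross-checked against a direct count of busy and idle slots in a GPipe-style schedule. This sketch assumes every stage takes the same time per micro-batch and models only the forward sweep (the backward sweep has the same shape, so the fraction is unchanged); exact fractions avoid rounding surprises at boundaries like 95%.

```python
import math
from fractions import Fraction


def pipeline_efficiency(m, p):
    """GPipe-style efficiency m / (m + p - 1) as an exact fraction."""
    return Fraction(m, m + p - 1)


def simulate_forward_schedule(m, p):
    """Fraction of busy stage-slots when stage s runs micro-batch j at
    time step s + j (one time step per micro-batch per stage)."""
    total_steps = m + p - 1
    busy = sum(1 for s in range(p) for t in range(total_steps) if 0 <= t - s < m)
    return Fraction(busy, p * total_steps)


def min_microbatches(p, target):
    """Smallest m with m / (m + p - 1) >= target,
    i.e. m >= target * (p - 1) / (1 - target)."""
    target = Fraction(target)
    return max(1, math.ceil(target * (p - 1) / (1 - target)))


eff = pipeline_efficiency(16, 4)
print(eff, f"{float(eff):.1%}")                  # 16/19 84.2%
print(simulate_forward_schedule(16, 4) == eff)   # True
print(min_microbatches(4, "0.95"))               # 57
```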
3D Parallelism
Modern LLM training combines all three strategies:
3D Parallelism GPU Allocation
$$G = d \times t \times p$$
Here,
- $G$ = Total number of GPUs
- $d$ = Data parallel degree
- $t$ = Tensor parallel degree
- $p$ = Pipeline parallel degree
3D Parallelism Configuration
For a 175B parameter model on 1024 GPUs:
- TP = 8 (split each layer across 8 GPUs)
- PP = 16 (16 pipeline stages)
- DP = 8 (8 data-parallel replicas)
- Total: 8 x 16 x 8 = 1024 GPUs
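Each GPU in such a configuration needs to know its coordinate along all three axes. This sketch assumes one common layout in which tensor-parallel ranks are adjacent (so a TP group of 8 can sit inside one 8-GPU node), followed by pipeline and then data-parallel ranks; real frameworks may order the dimensions differently, and the helper names are hypothetical.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp_idx, pp_idx, tp_idx), with TP varying fastest."""
    assert 0 <= rank < tp * pp * dp
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


def tp_group(rank, tp):
    """Ranks that share this rank's tensor-parallel group (all-reduce per layer)."""
    first = rank - rank % tp
    return list(range(first, first + tp))


def dp_group(rank, tp, pp, dp):
    """Ranks holding the same model shard in different data-parallel replicas."""
    offset = rank % (tp * pp)
    return [offset + k * tp * pp for k in range(dp)]


TP, PP, DP = 8, 16, 8
assert TP * PP * DP == 1024
print(rank_to_coords(0, TP, PP, DP))     # (0, 0, 0)
print(rank_to_coords(1023, TP, PP, DP))  # (7, 15, 7)
print(tp_group(9, TP))                   # [8, 9, 10, 11, 12, 13, 14, 15]
```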
Mixed-Precision Training
FP16 and BF16
Definition: Mixed-Precision Training
Mixed-precision training uses lower-precision formats (FP16/BF16) for most computations while maintaining a master copy of weights in FP32 for numerical stability. This halves the memory needed for activations and for the 16-bit working copy of the weights, and increases throughput on tensor cores.
| Format | Bits | Range | Precision | Use Case |
|---|---|---|---|---|
| FP32 | 32 | +/-3.4x10^38 | High | Master weights |
| FP16 | 16 | +/-65504 | Lower | Forward/backward pass |
| BF16 | 16 | +/-3.4x10^38 | Lower | Forward/backward pass |
| INT8 | 8 | -128 to 127 | Very low | Inference |
BF16 is preferred over FP16 for LLM training because it has the same exponent range as FP32, avoiding overflow/underflow issues. However, BF16 has lower mantissa precision (7 bits vs 10 bits for FP16).
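The range/precision trade-off between the two 16-bit formats is easy to see by rounding a few values. This sketch uses Python's `struct` "e" format for IEEE FP16 and emulates BF16 by keeping the top 16 bits of the FP32 encoding with round-to-nearest-even (an emulation that ignores NaN handling); overflow is mapped to infinity, as it would be on a GPU.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision (FP16)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)  # beyond +/-65504: becomes +/-inf


def to_bf16(x):
    """Round to bfloat16: keep the top 16 bits of the FP32 encoding,
    rounding to nearest even on the 16 discarded bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# Large value (FP16 overflows), tiny value (FP16 underflows),
# and a value needing 10 mantissa bits (BF16 rounds it away)
for value in (70000.0, 1e-8, 1 + 2**-10):
    print(f"{value:<12g} fp16={to_fp16(value):<12g} bf16={to_bf16(value):g}")
```

FP16 loses the first two values entirely, while BF16 keeps them to about three significant digits; for the third value the situation reverses, which is the "lower mantissa precision" cost of BF16.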
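Loss Scaling
FP16's narrow range means small gradients can underflow to zero. Loss scaling multiplies the loss by a large factor before the backward pass and divides the gradients by the same factor before the optimizer step, lowering the factor automatically whenever gradients overflow. BF16 shares FP32's exponent range, so BF16 training usually uses autocast without a GradScaler.
import torch
# model, criterion, optimizer and dataloader are assumed to be defined elsewhere
scaler = torch.cuda.amp.GradScaler()
for batch in dataloader:
    optimizer.zero_grad()
    # Loss scaling is needed for FP16 autocast (BF16 usually does not need it)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(batch["input_ids"])
        loss = criterion(outputs, batch["labels"])
    scaler.scale(loss).backward()
    # Unscale first so clipping sees the true gradient magnitudes
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skips the step if gradients contain inf/NaN
    scaler.update()         # adjusts the scale factor for the next iteration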
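The mechanism behind GradScaler can be illustrated without a GPU. This is a simplified sketch, not PyTorch's implementation: `DynamicLossScaler` is a hypothetical class whose defaults (initial scale 2^16, growth every 2000 clean steps, halving on overflow) mirror GradScaler's documented defaults, and FP16 rounding is emulated with Python's `struct` half-precision format.

```python
import math
import struct


def to_fp16(x):
    """Round to FP16; values beyond its range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class DynamicLossScaler:
    """Halve the scale (and skip the step) when gradients overflow;
    double it after `growth_interval` consecutive overflow-free steps."""

    def __init__(self, init_scale=2.0**16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0

    def update(self, grads):
        """Return True if the optimizer step should run for these grads."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale /= 2
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2
        return True


true_grad = 1e-8                          # too small for FP16 on its own
print(to_fp16(true_grad))                 # 0.0: the update would be lost
scaler = DynamicLossScaler()
scaled = to_fp16(true_grad * scaler.scale)  # representable after scaling
print(scaled / scaler.scale)              # close to 1e-8 after unscaling in FP32
```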
Communication Optimization
Gradient Accumulation
Definition: Gradient Accumulation
Gradient accumulation simulates larger batch sizes by accumulating gradients over multiple micro-batches before performing a parameter update.
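import torch
accumulation_steps = 8
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    # Divide so the accumulated gradient is the mean over all micro-batches
    loss = model(batch) / accumulation_steps  # assumes model(batch) returns the loss
    loss.backward()  # gradients add up in .grad across micro-batches
    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()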
Effective Batch Size
$$B_{\text{eff}} = b \times d \times k$$
Here,
- $B_{\text{eff}}$ = Effective batch size
- $b$ = Batch size per GPU
- $d$ = Data parallel degree
- $k$ = Gradient accumulation steps
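In practice the relationship is often used in reverse: fix the global batch, then derive the per-GPU micro-batch. DeepSpeed enforces the same identity (train_batch_size = micro-batch per GPU x gradient_accumulation_steps x data-parallel GPUs). The helper names below are hypothetical, and the 16-GPU setup is an illustrative assumption.

```python
def effective_batch_size(per_gpu_batch, dp_degree, accumulation_steps):
    """B_eff = per-GPU micro-batch x data-parallel degree x accumulation steps."""
    return per_gpu_batch * dp_degree * accumulation_steps


def micro_batch_for(target_batch, dp_degree, accumulation_steps):
    """Per-GPU micro-batch needed to reach a target global batch size."""
    denom = dp_degree * accumulation_steps
    if target_batch % denom:
        raise ValueError(f"{target_batch} is not divisible by {denom}")
    return target_batch // denom


# A global batch of 1024 with 8 accumulation steps on 16 data-parallel GPUs
mb = micro_batch_for(1024, dp_degree=16, accumulation_steps=8)
print(mb, effective_batch_size(mb, 16, 8))  # 8 1024
```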
Communication Overhead
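All-Reduce Communication Cost
$$T_{\text{all-reduce}} = 2(G-1)\,\alpha + \frac{2(G-1)}{G}\, n\,\beta$$
Here (for a ring all-reduce),
- $\alpha$ = Latency per communication hop
- $\beta$ = Inverse bandwidth (seconds per byte)
- $G$ = Number of GPUs
- $n$ = Message size (bytes)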
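The two factors in the ring all-reduce cost (2(G-1) communication steps, and 2(G-1)/G of the message sent by each GPU) fall out of the algorithm itself. This sketch simulates the reduce-scatter and all-gather phases on Python lists, assuming the buffer splits evenly into G chunks; the function names are hypothetical and no real communication library is involved.

```python
def ring_all_reduce(buffers):
    """Simulate a ring all-reduce (sum) over equal-length lists, one per GPU.

    Returns (result on each GPU, communication steps, elements sent per GPU).
    """
    g = len(buffers)
    size = len(buffers[0]) // g
    assert size * g == len(buffers[0]), "sketch assumes length divisible by GPU count"
    # chunks[i][c] is GPU i's current copy of chunk c
    chunks = [[list(buf[c * size:(c + 1) * size]) for c in range(g)] for buf in buffers]
    steps = sent = 0
    # Phase 1, reduce-scatter: after g-1 steps GPU i holds the full sum of chunk (i+1) % g
    for s in range(g - 1):
        msgs = [(i, (i - s) % g, list(chunks[i][(i - s) % g])) for i in range(g)]
        for i, c, payload in msgs:  # all sends happen "simultaneously"
            dst = (i + 1) % g
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]
        steps, sent = steps + 1, sent + size
    # Phase 2, all-gather: pass each finished chunk around the ring
    for s in range(g - 1):
        msgs = [(i, (i + 1 - s) % g, list(chunks[i][(i + 1 - s) % g])) for i in range(g)]
        for i, c, payload in msgs:
            chunks[(i + 1) % g][c] = payload
        steps, sent = steps + 1, sent + size
    results = [[x for c in range(g) for x in chunks[i][c]] for i in range(g)]
    return results, steps, sent


def ring_all_reduce_time(num_gpus, message_bytes, alpha, beta):
    """Latency-bandwidth cost model for the ring all-reduce above."""
    g = num_gpus
    return 2 * (g - 1) * alpha + 2 * (g - 1) / g * message_bytes * beta


buffers = [[i * 10 + j for j in range(6)] for i in range(3)]  # 3 GPUs, 6 elements each
results, steps, sent = ring_all_reduce(buffers)
print(results[0], steps, sent)  # [30, 33, 36, 39, 42, 45] 4 8
```

The per-GPU traffic stays close to 2n regardless of G, which is why the ring algorithm is bandwidth-optimal for large messages, while the latency term grows linearly with the number of GPUs.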
Training Frameworks
DeepSpeed
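import deepspeed
# model and dataloader are assumed to be defined elsewhere;
# launch with the deepspeed launcher, e.g. `deepspeed train.py`
ds_config = {
    # Global batch = micro-batch per GPU x accumulation steps x data-parallel GPUs
    "train_batch_size": 1024,
    "gradient_accumulation_steps": 8,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    # loss_scale 0 selects dynamic loss scaling, starting at 2**initial_scale_power
    "fp16": {"enabled": True, "loss_scale": 0, "initial_scale_power": 16},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,  # overlap gradient communication with the backward pass
        "contiguous_gradients": True,
        "reduce_bucket_size": 5e7,
        "stage3_prefetch_bucket_size": 5e7,
        "stage3_param_persistence_threshold": 1e6,
    },
    # Only applies when the model calls deepspeed.checkpointing.checkpoint
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": False,
        "contiguous_memory_optimization": False,
    },
}
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config,
    model_parameters=model.parameters(),
)
for batch in dataloader:
    loss = model_engine(batch)   # assumes the model's forward returns the loss
    model_engine.backward(loss)  # handles loss scaling and gradient accumulation
    model_engine.step()          # steps the optimizer at accumulation boundaries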
Activation Checkpointing
Definition: Activation Checkpointing
Activation checkpointing (gradient checkpointing) saves memory by discarding intermediate activations during the forward pass and recomputing them during the backward pass. This trades compute for memory — typically a 60-70% memory reduction at a 30-40% compute overhead.
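Memory vs Compute Tradeoff
$$M_{\text{act}}:\ \mathcal{O}(L) \;\rightarrow\; \mathcal{O}(\sqrt{L}), \qquad \text{Compute}:\ \approx 1 \text{ extra forward pass}$$
Here,
- $L$ = Number of layers (with a checkpoint kept every $\sqrt{L}$ layers)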
For a 70B parameter model, activation checkpointing reduces activation memory from ~60GB to ~20GB per GPU, enabling training with larger batch sizes or on fewer GPUs.
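The square-root rule comes from balancing two costs: the stored checkpoints and the one segment that is recomputed and held during its backward pass. PyTorch exposes the technique through torch.utils.checkpoint; the sketch below is only a simplified memory model (units of one layer's activations, equal-cost layers, backward assumed to cost 2x forward), and its helper names are hypothetical.

```python
import math


def peak_activation_units(num_layers, interval):
    """Approximate peak activation memory, in units of one layer's activations,
    when only every `interval`-th layer input is kept as a checkpoint and each
    segment is recomputed (and held) while its backward pass runs."""
    checkpoints = math.ceil(num_layers / interval)
    return checkpoints + interval


def best_interval(num_layers):
    """Checkpoint interval that minimises peak memory (close to sqrt(L))."""
    return min(range(1, num_layers + 1),
               key=lambda k: peak_activation_units(num_layers, k))


L = 80
k = best_interval(L)
print(k, peak_activation_units(L, k), "vs", L, "without checkpointing")  # 8 18 vs 80 ...
# Recomputation adds about one forward pass; if backward costs ~2x forward,
# the overhead is 1 / (1 + 2) of the original step, roughly 33%.
print(f"compute overhead ~ {1 / (1 + 2):.0%}")  # compute overhead ~ 33%
```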
Memory Optimization Techniques
CPU Offloading
Definition: CPU Offloading
CPU offloading moves optimizer states or parameters to CPU memory during training, freeing GPU memory for activations. Data is transferred between CPU and GPU as needed, introducing I/O overhead.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
    }
}
CPU offloading can reduce GPU memory by 4-8x but introduces 20-50% training overhead due to CPU-GPU data transfers. Use only when GPU memory is insufficient even with ZeRO-3.
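The source of that overhead can be estimated from the data each GPU moves per step when optimizer states live on the CPU. This is a deliberately simplified model assuming ZeRO-Offload-style behaviour (each GPU sends its shard of 16-bit gradients to the CPU and receives updated 16-bit parameters back), no overlap with compute, and a single shared link; the 25 GB/s host-device bandwidth and the 7B/8-GPU setup are hypothetical inputs.

```python
def offload_transfer_seconds(num_params, num_gpus, link_gb_per_s,
                             grad_bytes=2, param_bytes=2):
    """Per-step host-device transfer time for optimizer offloading.

    Each GPU moves its 1/num_gpus shard of gradients to the CPU and the
    updated parameters back, over a link of link_gb_per_s (GB/s, 1 GB = 1e9 B).
    Overlap with computation is ignored, so this is a pessimistic estimate.
    """
    shard_bytes = num_params / num_gpus * (grad_bytes + param_bytes)
    return shard_bytes / (link_gb_per_s * 1e9)


# Hypothetical: 7B parameters sharded over 8 GPUs, ~25 GB/s usable bandwidth
print(f"{offload_transfer_seconds(7e9, 8, 25):.2f} s per step")  # 0.14 s per step
```

Whether this time is visible in practice depends on how much of it can be overlapped with the backward pass, which is why the measured slowdown varies so widely between setups.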
Training Stability
Learning Rate Scheduling
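from transformers import get_cosine_with_min_lr_schedule_with_warmup
# get_cosine_schedule_with_warmup decays to 0 and has no minimum-LR argument;
# this variant (available in recent transformers releases) supports a floor
scheduler = get_cosine_with_min_lr_schedule_with_warmup(
    optimizer,
    num_warmup_steps=2000,
    num_training_steps=100000,
    min_lr_rate=0.1,  # final LR = 0.1 x the optimizer's initial LR
)
# Call scheduler.step() once after every optimizer.step()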
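Cosine Learning Rate Schedule
$$\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\frac{\pi t}{T}\right)$$
Here,
- $\eta_t$ = Learning rate at step $t$
- $\eta_{\min}$ = Minimum learning rate
- $\eta_{\max}$ = Maximum learning rate
- $T$ = Total training steps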
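Combining linear warmup with the cosine decay takes only a few lines. This sketch measures cosine progress from the end of warmup (as Hugging Face's cosine schedules do); the peak learning rate of 3e-4 is a hypothetical value, and the 0.1 floor ratio matches the scheduler configuration in this section.

```python
import math


def lr_at(step, max_lr, min_lr, warmup_steps, total_steps):
    """Linear warmup to max_lr, then cosine decay from max_lr down to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Hypothetical peak LR of 3e-4 with a floor of 0.1 x peak
for step in (0, 1000, 2000, 51000, 100000):
    print(step, f"{lr_at(step, 3e-4, 3e-5, 2000, 100000):.2e}")
```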
Gradient Clipping
max_grad_norm = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
Gradient clipping is essential for LLM training stability. Without it, a single large gradient can cause loss spikes or divergence, especially in early training stages.
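Clipping by global norm rescales all gradients together, so their direction is preserved and only the step length shrinks. This dependency-free sketch follows the same rule as torch.nn.utils.clip_grad_norm_ with the default L2 norm (scale by max_norm / (total_norm + 1e-6) when that factor is below 1); `clip_by_global_norm` is a hypothetical name.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors in place so their combined L2 norm is
    at most max_norm. Returns the norm measured before clipping."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters"; global norm = 5
norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm, grads)  # norm 5.0; grads rescaled to about [[0.6], [0.8]]
```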
Practice Exercises
- ZeRO Stage Comparison: Calculate the memory savings for training a 7B model using ZeRO Stage 1, 2, and 3 on 4 GPUs. What is the memory per GPU for each stage?
- Pipeline Bubble Analysis: Compute the pipeline efficiency for a 32-stage pipeline with 8, 16, 32, and 64 micro-batches. At what point does efficiency exceed 90%?
- 3D Parallelism Design: Design a 3D parallelism configuration for a 405B parameter model on 2048 GPUs. Justify your choice of TP, PP, and DP degrees.
- Communication Analysis: For a 70B model with gradient all-reduce on 256 GPUs, calculate the communication volume and estimated latency assuming a 100 Gbps interconnect.
Key Takeaways
Summary: Distributed Training for LLMs
- Data parallelism replicates models across GPUs and splits data — simple but memory-inefficient
- ZeRO/FSDP partitions optimizer states, gradients, and parameters to eliminate redundancy
- Tensor parallelism splits individual layers across GPUs for very large models
- Pipeline parallelism splits model layers across GPU stages with micro-batching
- 3D parallelism combines all three strategies for massive-scale training
- Mixed precision (BF16) halves activation memory and increases throughput on tensor cores
- Activation checkpointing trades 30% more compute for 60% less memory
- Communication overhead is the key bottleneck; ring all-reduce is bandwidth-optimal for large messages
- Gradient accumulation simulates larger batch sizes without extra memory
What to Learn Next
-> Data Quality and Curation for LLMs: How data quality, deduplication, and filtering impact model performance.
-> Knowledge Distillation for LLMs: Compressing large models into smaller, faster ones without losing capability.
-> Flash Attention and Memory Efficiency: IO-aware attention algorithms that reduce memory and increase speed.
-> Model Parallelism and Tensor Parallelism: Deep dive into splitting models across GPUs for training and inference.
-> Training Deep Networks: Foundational techniques for training neural networks effectively.
-> Scaling Laws and Chinchilla: How model performance scales with compute, data, and parameters.