🎯 The Interview Question
"Explain the difference between data parallelism and model parallelism. How does PyTorch's DDP (Distributed Data Parallel) work? What is FSDP (Fully Sharded Data Parallelism) and how does it reduce memory usage? What are the communication patterns in distributed training, and how do you optimize for them?"
This question comes up for roles that train large models, for example at NVIDIA (hardware) and Google (TPU infrastructure).
📚 Detailed Answer
Why Distributed Training?
Large models do not fit on a single GPU. Memory for the weights alone:
| Model | Parameters | Memory (FP32) | Memory (FP16) |
|---|---|---|---|
| BERT-Large | 340M | 1.3 GB | 0.7 GB |
| GPT-3 | 175B | 700 GB | 350 GB |
| LLaMA-70B | 70B | 280 GB | 140 GB |
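Even on an 80 GB A100, GPT-3 and LLaMA-70B need multiple GPUs for the weights alone, before counting gradients, optimizer states, and activations.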
Data Parallelism
Basic idea: Replicate model on each GPU, split data:
GPU 0: Model replica + Batch 0 → Gradients 0
GPU 1: Model replica + Batch 1 → Gradients 1
GPU 2: Model replica + Batch 2 → Gradients 2
GPU 3: Model replica + Batch 3 → Gradients 3
↓
AllReduce gradients
↓
Update all replicas
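To make the diagram concrete, here is a minimal sketch in pure Python that simulates the workers in a loop (no real GPUs or PyTorch involved). It assumes equal-sized shards and a toy one-parameter model `y = w * x`; the function names are illustrative only. It shows that averaging the per-shard gradients gives the same gradient as processing the full batch on one device.

```python
# Minimal sketch: data-parallel gradient averaging for a 1-D linear model y = w * x.
# "Workers" are simulated in a loop; all names here are illustrative assumptions.

def mse_grad(w, xs, ys):
    """Gradient of mean squared error (1/n) * sum((w*x - y)^2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def data_parallel_grad(w, xs, ys, num_workers):
    """Split the batch into equal shards, compute local gradients, then average them."""
    assert len(xs) % num_workers == 0, "sketch assumes equal-sized shards"
    shard = len(xs) // num_workers
    local_grads = [
        mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(num_workers)
    ]
    # This average is what an AllReduce (sum) followed by division by world size yields.
    return sum(local_grads) / num_workers


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # data generated with true weight 2.0
w = 0.5

full = mse_grad(w, xs, ys)
sharded = data_parallel_grad(w, xs, ys, num_workers=4)
print(full, sharded)
```

The equivalence holds because the loss is a mean over samples and the shards are the same size; with unequal shards the local gradients would need to be weighted by shard size.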
AllReduce Operation
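Computes the sum (or average) of gradients across all GPUs and leaves the result on every GPU.
Ring AllReduce: GPUs form a logical ring and exchange chunks with their neighbours in two phases (reduce-scatter, then all-gather). For $N$ GPUs and a gradient buffer of $M$ bytes:
- Communication: $2\frac{N-1}{N}M$ bytes sent per GPU, nearly independent of $N$
- Latency: $2(N-1)$ sequential steps, so it grows linearly with $N$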
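The ring algorithm can be sketched as a small simulation. This is a toy model under stated assumptions, not how NCCL is implemented: each rank's buffer is split into exactly `N` chunks, each chunk is a single number, and all sends within a step happen simultaneously. The function name `ring_all_reduce` is hypothetical.

```python
# Toy simulation of ring all-reduce (sum). Assumes N ranks, each holding N chunks,
# where every chunk is one number. Not a real communication library.

def ring_all_reduce(buffers):
    """Return (reduced_buffers, steps). buffers[r] is the list of chunks held by rank r."""
    n = len(buffers)
    data = [list(b) for b in buffers]
    assert all(len(b) == n for b in data), "sketch assumes one chunk per rank"
    steps = 0

    # Phase 1: reduce-scatter. At step s, rank r sends chunk (r - s) mod n to rank r + 1,
    # which adds it to its own copy. Afterwards rank r owns the full sum of chunk (r + 1) mod n.
    for s in range(n - 1):
        sends = [((r + 1) % n, (r - s) % n, data[r][(r - s) % n]) for r in range(n)]
        for dst, chunk, value in sends:
            data[dst][chunk] += value
        steps += 1

    # Phase 2: all-gather. Each rank forwards the fully reduced chunk it received last,
    # and the receiver overwrites its stale copy.
    for s in range(n - 1):
        sends = [((r + 1) % n, (r + 1 - s) % n, data[r][(r + 1 - s) % n]) for r in range(n)]
        for dst, chunk, value in sends:
            data[dst][chunk] = value
        steps += 1

    return data, steps


buffers = [
    [1, 2, 3, 4],
    [10, 20, 30, 40],
    [100, 200, 300, 400],
    [1000, 2000, 3000, 4000],
]
reduced, steps = ring_all_reduce(buffers)
print(reduced[0], steps)
```

Each rank sends one chunk of size $M/N$ per step for $2(N-1)$ steps, which is where the $2\frac{N-1}{N}M$ per-GPU traffic comes from.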
💡
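DDP (Distributed Data Parallel) in PyTorch delegates AllReduce to its communication backend (typically NCCL on GPUs), which chooses between ring and tree algorithms depending on topology and message size. Ring AllReduce is bandwidth-optimal for large messages on homogeneous networks.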
Model Parallelism
Split the model itself across GPUs, either by layers (pipeline parallelism) or within layers (tensor parallelism):
Pipeline Parallelism
Split model by layers:
GPU 0: Layers 0-9
GPU 1: Layers 10-19
GPU 2: Layers 20-29
GPU 3: Layers 30-39
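Problem: Pipeline bubbles. With a single batch, only one GPU works at a time while the others wait for activations (or gradients).
Solution: Micro-batching (GPipe, PipeDream). Split the batch into micro-batches so that stages work on different micro-batches concurrently:
Time →
GPU 0: [Micro-batch 1][Micro-batch 2][Micro-batch 3][Micro-batch 4]
GPU 1: [    idle     ][Micro-batch 1][Micro-batch 2][Micro-batch 3]
GPU 2: [    idle     ][    idle     ][Micro-batch 1][Micro-batch 2]
GPU 3: [    idle     ][    idle     ][    idle     ][Micro-batch 1]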
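The size of the bubble can be estimated with a short sketch. It assumes a GPipe-style schedule, forward pass only, and that every micro-batch takes the same time on every stage. Under those assumptions each of the $p$ stages is idle for $p-1$ of the $m+p-1$ time slots, where $m$ is the number of micro-batches. The helper names are hypothetical.

```python
# Sketch: bubble (idle) fraction of a GPipe-style forward schedule.
# Assumes uniform cost per micro-batch per stage; names are illustrative.

def bubble_fraction(num_stages, num_microbatches):
    """Closed form: (p - 1) / (m + p - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


def forward_schedule(num_stages, num_microbatches):
    """grid[s][t] is the micro-batch index stage s runs at time t, or None if idle."""
    total_steps = num_microbatches + num_stages - 1
    grid = []
    for s in range(num_stages):
        row = []
        for t in range(total_steps):
            mb = t - s  # stage s can start micro-batch mb only after s earlier stages ran it
            row.append(mb if 0 <= mb < num_microbatches else None)
        grid.append(row)
    return grid


def measured_bubble(grid):
    """Fraction of (stage, time) slots in which the stage is idle."""
    slots = sum(len(row) for row in grid)
    idle = sum(1 for row in grid for cell in row if cell is None)
    return idle / slots


for m in (1, 4, 32):
    grid = forward_schedule(num_stages=4, num_microbatches=m)
    print(m, round(measured_bubble(grid), 3), round(bubble_fraction(4, m), 3))
```

Raising the number of micro-batches shrinks the bubble, at the cost of smaller per-step matrix multiplications and more stored activations.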
Tensor Parallelism
Split individual operations across GPUs:
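For a matrix multiplication $Y = XW$, split the weight matrix by columns, $W = [W_1, W_2, \dots, W_N]$, and give GPU $i$ the shard $W_i$.
Each GPU computes its slice of the output, $Y_i = XW_i$, then an AllGather concatenates the slices into $Y = [Y_1, Y_2, \dots, Y_N]$.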
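A column-wise split of the weight matrix is one common way to do this (row-wise splits also exist and end with a reduction instead). The sketch below uses plain Python lists as matrices and a loop in place of GPUs; `column_parallel_matmul` and the other names are illustrative, and it assumes the number of columns divides evenly by the number of GPUs.

```python
# Sketch: column-parallel matrix multiplication Y = X @ W with W split by columns.
# Pure Python lists stand in for tensors; the "GPUs" are loop iterations.

def matmul(a, b):
    """Plain dense matrix product of two list-of-lists matrices."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def split_columns(w, parts):
    """Split matrix w into `parts` column blocks of equal width."""
    cols = len(w[0])
    assert cols % parts == 0, "sketch assumes an even column split"
    width = cols // parts
    return [[row[p * width:(p + 1) * width] for row in w] for p in range(parts)]


def column_parallel_matmul(x, w, num_gpus):
    shards = split_columns(w, num_gpus)
    partials = [matmul(x, w_i) for w_i in shards]  # each "GPU" computes X @ W_i
    # AllGather: concatenate the output slices along the column dimension.
    gathered = []
    for i in range(len(x)):
        row = []
        for part in partials:
            row.extend(part[i])
        gathered.append(row)
    return gathered


x = [[1, 2], [3, 4]]
w = [[1, 0, 2, 1], [0, 1, 1, 3]]
y_parallel = column_parallel_matmul(x, w, num_gpus=2)
y_full = matmul(x, w)
print(y_parallel == y_full)
```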
FSDP (Fully Sharded Data Parallel)
Data parallelism with sharded model state: each GPU still processes its own batch, but no GPU keeps a full replica of the model.
Key insight: Shard model parameters, gradients, and optimizer states:
GPU 0: [Params shard 0][Gradients shard 0][Optimizer shard 0]
GPU 1: [Params shard 1][Gradients shard 1][Optimizer shard 1]
GPU 2: [Params shard 2][Gradients shard 2][Optimizer shard 2]
GPU 3: [Params shard 3][Gradients shard 3][Optimizer shard 3]
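Memory savings (for $P$ parameters, assuming Adam with two FP32 states per parameter):
- Standard DDP: $16P$ bytes per GPU (FP32 params $4P$ + grads $4P$ + optimizer states $8P$)
- FSDP: $16P/N$ bytes per GPU (where $N$ is the number of GPUs), plus the temporarily gathered parameters of the layers currently being computed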
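One training step under this sharding scheme can be sketched as follows. This is a simplified simulation, not PyTorch's FSDP implementation: it gathers the whole model at once (real FSDP gathers and frees parameters unit by unit), uses plain SGD, and runs the ranks in a loop. All function names are hypothetical. The point is that all-gather, local backward, and reduce-scatter together produce the same update as ordinary data parallelism.

```python
# Simplified simulation of one sharded data-parallel step for a linear model.
# Assumptions: whole-model all-gather, plain SGD, ranks simulated in a loop.

def all_gather(shards):
    """Every rank receives the concatenation of all parameter shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def reduce_scatter(full_grads, shard_size):
    """Average gradients across ranks; rank r keeps only its own slice."""
    n = len(full_grads)
    avg = [sum(g[i] for g in full_grads) / n for i in range(len(full_grads[0]))]
    return [avg[r * shard_size:(r + 1) * shard_size] for r in range(n)]


def local_grad(w, x, y):
    """Gradient of (w . x - y)^2 with respect to w for a single sample."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [2.0 * err * xi for xi in x]


def sharded_step(param_shards, batches, lr):
    shard_size = len(param_shards[0])
    full_params = all_gather(param_shards)  # full weights exist only temporarily
    grads = [local_grad(full_params[r], *batches[r]) for r in range(len(param_shards))]
    grad_shards = reduce_scatter(grads, shard_size)
    # Each rank updates only the shard it owns.
    return [[p - lr * g for p, g in zip(ps, gs)] for ps, gs in zip(param_shards, grad_shards)]


def replicated_step(params, batches, lr):
    """Reference: ordinary data parallelism with a full replica and averaged gradients."""
    grads = [local_grad(params, x, y) for x, y in batches]
    avg = [sum(g[i] for g in grads) / len(grads) for i in range(len(params))]
    return [p - lr * g for p, g in zip(params, avg)]


batches = [([1.0, 0.0, 2.0, 1.0], 3.0), ([0.0, 1.0, 1.0, 2.0], 1.0)]
shards = [[0.1, 0.2], [0.3, 0.4]]

new_shards = sharded_step(shards, batches, lr=0.1)
flat_sharded = [v for s in new_shards for v in s]
flat_reference = replicated_step([0.1, 0.2, 0.3, 0.4], batches, lr=0.1)
print(flat_sharded, flat_reference)
```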
Communication Patterns
| Operation | Pattern | Bandwidth | Use Case |
|---|---|---|---|
| AllReduce | Many-to-many | High | Gradient sync (DDP) |
| AllGather | Many-to-many | High | Parameter gathering (FSDP) |
| ReduceScatter | Many-to-many | High | Gradient sharding (FSDP) |
| Broadcast | One-to-all | Low | Parameter initialization |
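A rough way to compare these collectives is to count the bytes each GPU sends. The sketch below assumes ring-based implementations, where all-gather and reduce-scatter each move $\frac{N-1}{N}M$ bytes per GPU and all-reduce moves twice that. It also assumes a fully sharded step performs two parameter all-gathers (forward and backward) plus one gradient reduce-scatter. Broadcast is left out because its cost depends on the algorithm used. Names are illustrative.

```python
# Sketch: per-GPU bytes sent for ring-based collectives, and DDP vs fully sharded traffic.
# Assumes ring algorithms and that parameters and gradients have the same size in bytes.

RING_FACTOR = {"all_reduce": 2.0, "all_gather": 1.0, "reduce_scatter": 1.0}


def ring_bytes_sent_per_gpu(op, tensor_bytes, num_gpus):
    """Bytes one GPU sends for a collective over a full tensor of `tensor_bytes`."""
    if op not in RING_FACTOR:
        raise ValueError(f"unsupported collective: {op}")
    return RING_FACTOR[op] * (num_gpus - 1) / num_gpus * tensor_bytes


tensor_bytes = 4 * 1_000_000_000  # 1B parameters in FP32
num_gpus = 8

ddp_traffic = ring_bytes_sent_per_gpu("all_reduce", tensor_bytes, num_gpus)
sharded_traffic = (
    2 * ring_bytes_sent_per_gpu("all_gather", tensor_bytes, num_gpus)   # forward + backward
    + ring_bytes_sent_per_gpu("reduce_scatter", tensor_bytes, num_gpus)  # gradients
)
print(ddp_traffic / 1e9, sharded_traffic / 1e9, sharded_traffic / ddp_traffic)
```

Under these assumptions full sharding sends about 1.5 times the data of plain DDP, which is the price paid for the memory savings.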
ZeRO Optimization
- Stage 1: Shard optimizer states only
- Stage 2: Shard optimizer states + gradients
- Stage 3: Shard optimizer states + gradients + parameters
Memory comparison (GPT-3, 175B, assuming FP32 Adam at 16 bytes per parameter and 64 GPUs; activations excluded):
| Stage | Memory/GPU |
|---|---|
| DDP | 2,800 GB |
| ZeRO-1 | ~1,422 GB |
| ZeRO-2 | ~733 GB |
| ZeRO-3 (FSDP) | ~44 GB |
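Per-GPU memory for each stage follows from which of the three state groups is divided by the GPU count. The calculator below is a sketch under explicit assumptions: FP32 parameters (4 bytes), FP32 gradients (4 bytes), Adam states (8 bytes), 64 GPUs, and no activations or temporary buffers. The results depend entirely on the byte counts and GPU count passed in, so they match a given table only when those assumptions match. The function name is hypothetical.

```python
# Sketch: model-state memory per GPU for ZeRO stages 0 (plain DDP) through 3.
# Default byte counts assume FP32 params, FP32 grads and Adam; activations are ignored.

def zero_memory_gb(num_params, num_gpus, stage,
                   param_bytes=4, grad_bytes=4, optim_bytes=8):
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    if stage >= 1:
        optim /= num_gpus   # stage 1 shards optimizer states
    if stage >= 2:
        grads /= num_gpus   # stage 2 also shards gradients
    if stage >= 3:
        params /= num_gpus  # stage 3 also shards parameters
    return (params + grads + optim) / 1e9


NUM_PARAMS = 175e9  # GPT-3 scale
NUM_GPUS = 64       # assumed cluster size for this example

for stage in range(4):
    print(stage, round(zero_memory_gb(NUM_PARAMS, NUM_GPUS, stage), 1))
```

Stages 1 and 2 still keep a full copy of the parameters on every GPU, so only stage 3 brings a model of this size anywhere near a single accelerator's memory.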
Communication Optimization
Overlap computation and communication:
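# Overlap gradient all-reduce with the backward pass:
# PyTorch DDP does this automatically by all-reducing gradient buckets as soon as they are ready.
# Keep find_unused_parameters=False (the default) unless needed; True adds a graph traversal every iteration.
# Overlap parameter all-gather with computation:
# FSDP's backward_prefetch=BackwardPrefetch.BACKWARD_PRE prefetches the next parameters during the backward pass;
# forward_prefetch=True does the same for the forward pass.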
Compression techniques:
- Gradient quantization (FP32 → FP16/INT8)
- Gradient sparsification (only send top-k gradients)
- Error feedback (keep the compression error locally and add it to the next step's gradient, so nothing is permanently lost)
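Top-k sparsification and error feedback fit together naturally, as the sketch below illustrates. It is a minimal pure-Python version under simple assumptions: gradients are flat lists, ties are broken by index order, and the function name is hypothetical. Only the `k` largest-magnitude entries are sent each step, and the rest is carried over in a local residual.

```python
# Sketch: top-k gradient sparsification with error feedback on flat Python lists.
# Names are illustrative; real systems operate on tensors and send (index, value) pairs.

def topk_with_error_feedback(grad, residual, k):
    """Return (sparse_update, new_residual) for one step."""
    corrected = [g + r for g, r in zip(grad, residual)]  # add back what was dropped earlier
    order = sorted(range(len(corrected)), key=lambda i: abs(corrected[i]), reverse=True)
    keep = set(order[:k])
    sparse = [c if i in keep else 0.0 for i, c in enumerate(corrected)]
    new_residual = [c - s for c, s in zip(corrected, sparse)]  # everything not sent
    return sparse, new_residual


grads_per_step = [
    [0.5, -2.0, 0.25, 1.0],
    [0.5, 0.0, 0.25, 0.0],
]

residual = [0.0, 0.0, 0.0, 0.0]
total_sent = [0.0, 0.0, 0.0, 0.0]
for grad in grads_per_step:
    sparse, residual = topk_with_error_feedback(grad, residual, k=2)
    total_sent = [t + s for t, s in zip(total_sent, sparse)]

total_grad = [sum(col) for col in zip(*grads_per_step)]
print(total_sent, residual, total_grad)
```

The useful invariant is that everything sent so far plus the current residual equals the sum of all gradients seen, so small entries are delayed rather than discarded.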
Follow-Up Questions
Q: When should you use pipeline parallelism vs tensor parallelism?
A: Pipeline parallelism for very deep models (split by layers). Tensor parallelism for very wide models (split within layers). Tensor parallelism communicates inside every layer, so it is usually kept within a node on fast interconnects such as NVLink, while pipeline parallelism only passes activations between stages and tolerates slower links. Both are usually combined with data parallelism.
Q: How does FSDP compare to DeepSpeed ZeRO?
A: FSDP is PyTorch native; DeepSpeed has more features (for example NVMe offloading and CPUAdam). FSDP is simpler; DeepSpeed is more flexible. Performance is similar.
Q: What is the communication bottleneck in distributed training?
A: Gradient synchronization (AllReduce) during the backward pass. Solutions: overlap it with computation, compress gradients, and synchronize less often (for example with gradient accumulation).