
Distributed Training: Data Parallelism, Model Parallelism, FSDP — Asked at NVIDIA & Google


Premium Interview Preparation — Distributed Training Mastery

🎯 The Interview Question

"Explain the difference between data parallelism and model parallelism. How does PyTorch's DDP (Distributed Data Parallel) work? What is FSDP (Fully Sharded Data Parallelism) and how does it reduce memory usage? What are the communication patterns in distributed training, and how do you optimize for them?"

This question comes up in interviews for large-model training roles at NVIDIA (GPU hardware) and Google (TPU infrastructure).


📚 Detailed Answer

Why Distributed Training?

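The weights of large models alone can exceed the memory of a single GPU:

Model        Parameters   Memory (FP32)   Memory (FP16)
BERT-Large   340M         1.4 GB          0.7 GB
GPT-3        175B         700 GB          350 GB
LLaMA-70B    70B          280 GB          140 GB

GPT-3 and LLaMA-70B exceed the 80 GB of an A100 even in FP16, before counting gradients, optimizer states, and activations, so training them requires multiple GPUs.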
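The figures in the table follow from multiplying the parameter count by the bytes per parameter. A minimal sketch, assuming decimal gigabytes (1 GB = 1e9 bytes) and counting weights only; the function name `weight_memory_gb` is made up for illustration:

```python
# Bytes needed to store one parameter in each numeric format.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Memory for the weights alone, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


models = {"BERT-Large": 340e6, "GPT-3": 175e9, "LLaMA-70B": 70e9}
for name, n in models.items():
    fp32 = weight_memory_gb(n, "fp32")
    fp16 = weight_memory_gb(n, "fp16")
    fits = "fits" if fp16 <= 80 else "does not fit"
    print(f"{name}: {fp32:.1f} GB (FP32), {fp16:.1f} GB (FP16), "
          f"{fits} on one 80 GB GPU in FP16")
```

Training needs considerably more than this, because gradients, optimizer states, and activations are not included.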

Data Parallelism

Basic idea: Replicate model on each GPU, split data:

Architecture Diagram
GPU 0: Model replica + Batch 0 → Gradients 0
GPU 1: Model replica + Batch 1 → Gradients 1
GPU 2: Model replica + Batch 2 → Gradients 2
GPU 3: Model replica + Batch 3 → Gradients 3
                        ↓
              AllReduce gradients
                        ↓
              Update all replicas

AllReduce Operation

Computes sum/average of gradients across all GPUs:

$$\bar{g} = \frac{1}{N}\sum_{i=1}^{N} g_i$$
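Averaging works because, with equal-sized shards, the mean of the per-GPU gradients equals the gradient of the full batch. A small sketch with a one-parameter model $y = wx$ and mean squared error; the model, the data, and the helper `mse_grad` are illustrative assumptions:

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

num_gpus = 4
shard = len(xs) // num_gpus  # equal-sized shards, one per simulated GPU

# Each simulated GPU computes a gradient on its own shard of the batch.
local_grads = [
    mse_grad(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
    for i in range(num_gpus)
]

averaged = sum(local_grads) / num_gpus  # what AllReduce (average) produces
full_batch = mse_grad(w, xs, ys)        # single-GPU reference

print(averaged, full_batch)
```

If shards have different sizes, a plain average no longer matches the full-batch gradient and the shards would need to be weighted by size.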

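Ring AllReduce: GPUs form a logical ring. The gradient buffer of size $D$ is split into $N$ chunks, and in each step every GPU sends one chunk to its right neighbor and receives one from its left neighbor: $N-1$ steps of reduce-scatter followed by $N-1$ steps of all-gather.

  • Communication: $\frac{2(N-1)}{N} \times D$ sent per GPU, which is nearly independent of $N$
  • Latency: $O(N)$ steps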
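The two phases of Ring AllReduce can be simulated in plain Python. This is a single-process sketch of the algorithm, not how NCCL or PyTorch implement it; it assumes the vector length is divisible by the number of ranks, and `ring_allreduce` is a hypothetical helper name:

```python
def ring_allreduce(buffers):
    """Sum equal-length vectors across ranks with the ring algorithm.

    buffers[r] is rank r's local vector. Returns (data, sent) where data[r]
    is rank r's result and sent[r] counts the elements rank r transmitted.
    """
    n = len(buffers)
    size = len(buffers[0])
    assert size % n == 0, "vector length must be divisible by the number of ranks"
    chunk = size // n
    data = [list(b) for b in buffers]
    sent = [0] * n

    def span(c):
        return range(c * chunk, (c + 1) * chunk)

    def exchange(chunk_for_rank, accumulate):
        # Snapshot all outgoing messages first so the sends are simultaneous.
        msgs = []
        for r in range(n):
            c = chunk_for_rank(r)
            msgs.append((c, [data[r][i] for i in span(c)]))
            sent[r] += chunk
        for r, (c, payload) in enumerate(msgs):
            dst = (r + 1) % n  # right neighbor in the ring
            for i, v in zip(span(c), payload):
                data[dst][i] = data[dst][i] + v if accumulate else v

    # Phase 1 (reduce-scatter): after n-1 steps, rank r owns the full sum
    # of chunk (r + 1) % n.
    for step in range(n - 1):
        exchange(lambda r: (r - step) % n, accumulate=True)

    # Phase 2 (all-gather): pass the finished chunks around the ring.
    for step in range(n - 1):
        exchange(lambda r: (r + 1 - step) % n, accumulate=False)

    return data, sent


grads = [[float(r + 1)] * 8 for r in range(4)]  # 4 ranks, D = 8 elements
result, sent = ring_allreduce(grads)
print(result[0], sent)
```

With 4 ranks and 8 elements, each rank sends 2 × 3 steps × 2 elements = 12 elements, matching $\frac{2(N-1)}{N} \times D$. Dividing the result by the number of ranks gives the averaged gradient.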

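💡 PyTorch DDP (Distributed Data Parallel) delegates AllReduce to its communication backend (typically NCCL on GPUs), which selects an algorithm such as ring or tree based on topology and message size. Ring AllReduce is bandwidth-optimal, but its latency grows linearly with the number of GPUs.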

Model Parallelism

Split model across GPUs:

Pipeline Parallelism

Split model by layers:

Architecture Diagram
GPU 0: Layers 0-9
GPU 1: Layers 10-19
GPU 2: Layers 20-29
GPU 3: Layers 30-39

Problem: Pipeline bubbles — GPUs idle while waiting for activations.

Solution: Micro-batching (GPipe, PipeDream):

Architecture Diagram
Time →
GPU 0: [Micro-batch 1][Micro-batch 2][Micro-batch 3][Micro-batch 4]
GPU 1:                 [Micro-batch 1][Micro-batch 2][Micro-batch 3]
GPU 2:                                 [Micro-batch 1][Micro-batch 2]
GPU 3:                                                 [Micro-batch 1]
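The diagram above can be generated programmatically, which also makes the idle time measurable. A sketch assuming a GPipe-style schedule in which every stage takes one time slot per micro-batch; under that assumption the idle (bubble) fraction is $(p-1)/(m+p-1)$ for $p$ stages and $m$ micro-batches. The function names are hypothetical:

```python
def forward_schedule(num_stages, num_microbatches):
    """rows[stage][slot] = micro-batch number (1-based), or None when idle."""
    slots = num_stages + num_microbatches - 1
    rows = []
    for stage in range(num_stages):
        row = [None] * slots
        for mb in range(num_microbatches):
            # Stage s can start micro-batch mb only after s earlier stages ran it.
            row[stage + mb] = mb + 1
        rows.append(row)
    return rows


def bubble_fraction(num_stages, num_microbatches):
    """Idle fraction of the schedule, assuming equal-cost stages."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


rows = forward_schedule(num_stages=4, num_microbatches=4)
for stage, row in enumerate(rows):
    cells = "".join(f"[{mb}]" if mb else " . " for mb in row)
    print(f"GPU {stage}: {cells}")

for m in (1, 4, 32):
    print(f"{m} micro-batches: {bubble_fraction(4, m):.0%} idle")
```

Increasing the number of micro-batches shrinks the bubble, at the cost of holding more activations in flight.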

Tensor Parallelism

Split individual operations across GPUs:

For matrix multiplication $\mathbf{Y} = \mathbf{X}\mathbf{W}$:

$$\mathbf{W} = [\mathbf{W}_1 | \mathbf{W}_2]$$
$$\mathbf{Y} = \mathbf{X}[\mathbf{W}_1 | \mathbf{W}_2] = [\mathbf{X}\mathbf{W}_1 | \mathbf{X}\mathbf{W}_2]$$

Each GPU holds one column block $\mathbf{W}_i$ and computes $\mathbf{X}\mathbf{W}_i$; an AllGather then concatenates the blocks into the full output.
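The column-split identity can be checked on a small example. A minimal sketch using nested lists as matrices; the helper names and the 2-GPU split are assumptions for illustration, and the concatenation stands in for the AllGather:

```python
def matmul(a, b):
    """Plain matrix product of two nested-list matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w, parts):
    """Split matrix w into `parts` equal column blocks, one per simulated GPU."""
    cols = len(w[0])
    assert cols % parts == 0, "column count must be divisible by parts"
    step = cols // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]


def gather_columns(blocks):
    """Concatenate per-GPU output blocks along the column axis."""
    return [sum((blk[i] for blk in blocks), []) for i in range(len(blocks[0]))]


X = [[1, 2, 3],
     [4, 5, 6]]
W = [[1, 0, 2, 1],
     [0, 1, 1, 3],
     [2, 1, 0, 1]]

W_blocks = split_columns(W, parts=2)               # W = [W_1 | W_2]
partial = [matmul(X, w_i) for w_i in W_blocks]     # each GPU computes X @ W_i
Y_parallel = gather_columns(partial)               # [X W_1 | X W_2]
Y_reference = matmul(X, W)

print(Y_parallel == Y_reference)
```

Each simulated GPU needs the full input X but only its own slice of W, which is where the memory saving comes from.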

FSDP (Fully Sharded Data Parallel)

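FSDP is data parallelism with sharded model state: each GPU still processes its own slice of the batch, but stores only a 1/N fraction of the model state.

Key insight: Shard model parameters, gradients, and optimizer states across GPUs. Each layer's full parameters are rebuilt with an AllGather just before that layer runs and freed afterwards, and gradients are reduced with a ReduceScatter so that each GPU keeps only its own shard: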

Architecture Diagram
GPU 0: [Params shard 0][Gradients shard 0][Optimizer shard 0]
GPU 1: [Params shard 1][Gradients shard 1][Optimizer shard 1]
GPU 2: [Params shard 2][Gradients shard 2][Optimizer shard 2]
GPU 3: [Params shard 3][Gradients shard 3][Optimizer shard 3]
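The sharding in the diagram can be pictured as flattening the parameters into one long vector, padding it to a multiple of the GPU count, and giving each GPU one slice. The sketch below is a simplified model of that idea, not PyTorch's implementation; the helper names are hypothetical:

```python
def shard_flat_params(flat, world_size):
    """Pad a flat parameter list to a multiple of world_size and split evenly.

    Returns (shards, pad) where shards[r] is rank r's slice and pad is the
    number of zero elements appended.
    """
    pad = (-len(flat)) % world_size
    padded = flat + [0.0] * pad
    shard_len = len(padded) // world_size
    shards = [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
    return shards, pad


def gather_full_params(shards, pad):
    """Rebuild the full parameter list from all shards (what AllGather does)."""
    full = [v for shard in shards for v in shard]
    return full[:len(full) - pad] if pad else full


params = [float(i) for i in range(10)]       # 10 parameters
shards, pad = shard_flat_params(params, 4)   # 4 simulated GPUs

for rank, shard in enumerate(shards):
    print(f"GPU {rank}: {shard}")

restored = gather_full_params(shards, pad)
print(restored == params)
```

Each simulated GPU stores 3 values instead of 10; the full list exists only for as long as a layer needs it.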

Memory savings:

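  • Standard DDP: $16 \times P$ bytes per GPU with Adam (FP32 params 4 + grads 4 + two optimizer moments 8), where $P$ is the number of parameters
  • FSDP: $16 \times P / N$ bytes per GPU (where $N$ is the number of GPUs), plus the temporarily gathered parameters of the layer being computed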

Communication Patterns

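Operation       Pattern        Bandwidth   Use Case
AllReduce       Many-to-many   High        Gradient sync (DDP)
AllGather       Many-to-many   High        Parameter gathering (FSDP)
ReduceScatter   Many-to-many   High        Gradient sharding (FSDP)
Broadcast       One-to-all     Low         Parameter initialization

An AllReduce is equivalent to a ReduceScatter followed by an AllGather.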
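The semantics of these collectives are easy to pin down with a single-process simulation, where a list of lists stands in for the buffers held by each rank. This is a sketch of what each operation computes, not of how a backend moves the data; it also checks that ReduceScatter followed by AllGather gives the same result as AllReduce:

```python
def all_reduce(buffers):
    """Every rank ends with the element-wise sum of all buffers."""
    summed = [sum(vals) for vals in zip(*buffers)]
    return [list(summed) for _ in buffers]


def reduce_scatter(buffers):
    """Rank r ends with only the r-th chunk of the element-wise sum."""
    n = len(buffers)
    chunk = len(buffers[0]) // n
    summed = [sum(vals) for vals in zip(*buffers)]
    return [summed[r * chunk:(r + 1) * chunk] for r in range(n)]


def all_gather(shards):
    """Every rank ends with the concatenation of all ranks' shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def broadcast(buffers, src=0):
    """Every rank ends with a copy of the source rank's buffer."""
    return [list(buffers[src]) for _ in buffers]


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]   # 2 ranks, 4 elements each

print(all_reduce(grads))
print(reduce_scatter(grads))
print(all_gather(reduce_scatter(grads)))
print(broadcast(grads, src=1))
```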

ZeRO Optimization

  • Stage 1: Shard optimizer states only
  • Stage 2: Shard optimizer states + gradients
  • Stage 3: Shard optimizer states + gradients + parameters

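Memory comparison (GPT-3, 175B), counting model states only and assuming mixed-precision Adam (16 bytes per parameter: 2 for FP16 weights, 2 for FP16 gradients, 12 for FP32 optimizer states) on N = 128 GPUs:

Stage           Per-GPU formula   Memory/GPU
DDP             16P               2800 GB
ZeRO-1          4P + 12P/N        716 GB
ZeRO-2          2P + 14P/N        369 GB
ZeRO-3 (FSDP)   16P/N             22 GB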
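Per-GPU figures depend on what is counted and on the number of GPUs, so it helps to compute them from explicit assumptions. The sketch below uses a common accounting for mixed-precision Adam (2 bytes for FP16 weights, 2 for FP16 gradients, 12 for FP32 master weights, momentum, and variance) and ignores activations and buffers. The GPU count of 128 and the function name are assumptions for illustration:

```python
def zero_memory_gb(num_params, num_gpus, stage):
    """Per-GPU model-state memory in decimal GB under mixed-precision Adam.

    stage 0 is plain DDP (nothing sharded); stages 1-3 are the ZeRO stages.
    Activations, buffers, and fragmentation are not counted.
    """
    weights, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim /= num_gpus    # Stage 1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus    # Stage 2: also shard gradients
    if stage >= 3:
        weights /= num_gpus  # Stage 3: also shard parameters
    return num_params * (weights + grads + optim) / 1e9


labels = {0: "DDP", 1: "ZeRO-1", 2: "ZeRO-2", 3: "ZeRO-3"}
for stage, label in labels.items():
    gb = zero_memory_gb(175e9, num_gpus=128, stage=stage)
    print(f"{label}: {gb:.1f} GB per GPU")
```

Stages 1 and 2 level off at 4x and 8x savings as the GPU count grows, because the unsharded part dominates; only Stage 3 keeps shrinking linearly with the number of GPUs.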

Communication Optimization

Overlap computation and communication:

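# Overlap gradient AllReduce with the backward pass:
# DDP groups gradients into buckets and launches an asynchronous AllReduce for
# each bucket as soon as its gradients are ready (bucket size: bucket_cap_mb).
# find_unused_parameters=True adds a graph traversal per step, so leave it False when possible.

# Overlap parameter AllGather with computation:
# FSDP prefetches the next unit's parameters while the current unit computes:
# backward_prefetch=BackwardPrefetch.BACKWARD_PRE in the backward pass,
# forward_prefetch=True in the forward pass.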
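The benefit of overlap can be estimated with a toy timing model. The sketch below assumes gradients become ready layer by layer during the backward pass and that all communication shares a single link; the per-layer times are made-up numbers, and `step_time` is a hypothetical helper:

```python
def step_time(compute_ms, comm_ms, overlap):
    """Backward-pass time for layers processed in the given order.

    compute_ms[i]: time to compute layer i's gradients.
    comm_ms[i]: time to synchronize them over one shared link.
    """
    if not overlap:
        # Communicate only after the whole backward pass has finished.
        return sum(compute_ms) + sum(comm_ms)

    compute_done = 0.0  # time at which the current layer's gradient is ready
    link_free = 0.0     # time at which the link finishes its previous transfer
    for compute, comm in zip(compute_ms, comm_ms):
        compute_done += compute
        start = max(compute_done, link_free)  # need both gradient and link
        link_free = start + comm
    return max(compute_done, link_free)


compute = [10, 10, 10, 10]  # ms per layer (illustrative)
comm = [8, 8, 8, 8]         # ms per layer (illustrative)

sequential = step_time(compute, comm, overlap=False)
overlapped = step_time(compute, comm, overlap=True)
print(sequential, overlapped)
```

In this example only the last layer's transfer is exposed. Overlap stops helping once communication per layer exceeds computation per layer, because the link becomes the bottleneck.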

Compression techniques:

  • Gradient quantization (FP32 → FP16/INT8)
  • Gradient sparsification (only send top-k gradients)
  • Error feedback (add the compression error from each step back into the next step's gradient, so that nothing is permanently dropped)
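Sparsification and error feedback are usually combined. A minimal sketch, assuming top-k selection by magnitude on a flat gradient list; the function name and the sample values are illustrative, not taken from any library:

```python
def topk_with_error_feedback(grad, residual, k):
    """Select the k largest-magnitude entries of grad + residual.

    Returns (sparse_update, new_residual): sparse_update maps index -> value
    for the entries that are sent; new_residual keeps everything that was
    not sent so it can be added to the next step's gradient.
    """
    corrected = [g + r for g, r in zip(grad, residual)]
    order = sorted(range(len(corrected)),
                   key=lambda i: abs(corrected[i]), reverse=True)
    sparse_update = {i: corrected[i] for i in order[:k]}
    new_residual = [0.0 if i in sparse_update else v
                    for i, v in enumerate(corrected)]
    return sparse_update, new_residual


residual = [0.0, 0.0, 0.0, 0.0]

# Step 1: the two largest entries are sent, the rest is remembered.
sent1, residual = topk_with_error_feedback([0.5, -2.0, 0.25, 1.5], residual, k=2)
print(sent1, residual)

# Step 2: the remembered values are added back before selecting again.
sent2, residual = topk_with_error_feedback([0.75, 0.0, 0.25, 0.0], residual, k=2)
print(sent2, residual)
```

Over the two steps, the sum of everything sent plus the final residual equals the sum of the original gradients, which is the property error feedback is meant to preserve.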

Follow-Up Questions

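Q: When should you use pipeline parallelism vs tensor parallelism?
A: Pipeline parallelism suits very deep models (split by layers) and only passes activations between neighboring stages. Tensor parallelism suits very wide layers (split within layers) but communicates inside every layer, so it is usually kept within a node with a fast interconnect such as NVLink. Both are usually combined with data parallelism.

Q: How does FSDP compare to DeepSpeed ZeRO?
A: FSDP is PyTorch-native and implements the same sharding idea as ZeRO Stage 3. Both support CPU offloading; DeepSpeed adds extras such as NVMe offloading and a CPU-optimized Adam (CPUAdam), along with more configuration options. For comparable sharding configurations, performance is broadly similar.

Q: What is the communication bottleneck in distributed training?
A: In data parallelism, gradient synchronization (AllReduce) during the backward pass. Solutions: overlap it with computation, compress gradients, or synchronize less often (for example, with gradient accumulation).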
