GPU Training & Distributed ML
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
GPU Optimization
Efficient GPU utilization reduces training time and cost significantly.
ℹ️
Large accelerator clusters such as Google's TPU pods can cut training time for GPT-3-scale models from weeks to hours, reducing costs by as much as 10x.
Multi-GPU Training
# distributed_training.py
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import autocast, GradScaler
import os
from typing import Dict, Optional
from dataclasses import dataclass
@dataclass
class DistributedConfig:
    """Per-process settings for joining a distributed training job.

    Every field has a default, so ``DistributedConfig()`` describes a
    single-process job on the local machine.
    """

    # Name of the communication backend.
    backend: str = "nccl"
    # Total number of participating processes.
    world_size: int = 1
    # Global index of this process within the job.
    rank: int = 0
    # Index of this process on its own machine.
    local_rank: int = 0
    # Host name or address of the rendezvous master.
    master_addr: str = "localhost"
    # Rendezvous port, kept as a string.
    master_port: str = "12355"
class DistributedTrainer:
    """Run mixed-precision training epochs inside a ``torch.distributed`` job.

    Constructing the trainer immediately joins the process group described by
    ``config`` and binds the process to ``cuda:<local_rank>``. Call
    :meth:`cleanup` once training is finished.
    """

    def __init__(self, config: "DistributedConfig"):
        """Store ``config`` and join the process group it describes."""
        self.config = config
        self.setup_distributed()

    def setup_distributed(self):
        """Export the rendezvous address, join the group, select the GPU."""
        # No init_method is passed below, so the master address and port are
        # handed over through the environment.
        os.environ["MASTER_ADDR"] = self.config.master_addr
        os.environ["MASTER_PORT"] = self.config.master_port
        dist.init_process_group(
            backend=self.config.backend,
            rank=self.config.rank,
            world_size=self.config.world_size,
        )
        torch.cuda.set_device(self.config.local_rank)
        self.device = torch.device(f"cuda:{self.config.local_rank}")

    def cleanup(self):
        """Leave the process group; a no-op if none is initialised."""
        # Guarded so cleanup() is safe in ``finally`` blocks even when
        # initialisation failed or cleanup already ran.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

    def train_epoch(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        scaler: GradScaler,
        epoch: int,
    ) -> Dict[str, float]:
        """Train ``model`` for one pass over ``dataloader``.

        Args:
            model: Module to train; put into training mode here.
            dataloader: Yields ``(data, target)`` batches.
            optimizer: Optimizer stepped once per batch through ``scaler``.
            criterion: Loss called as ``criterion(output, target)``.
            scaler: Gradient scaler used for backward and the optimizer step.
            epoch: Epoch index; forwarded to a ``DistributedSampler`` and
                used in the log line.

        Returns:
            ``{"loss", "accuracy", "lr"}`` where ``loss`` is the mean batch
            loss, ``accuracy`` the fraction of correct arg-max predictions
            and ``lr`` the first param group's learning rate. The values
            cover only the batches seen by this process (no cross-rank
            reduction is performed). ``loss`` and ``accuracy`` are ``0.0``
            when the loader yields nothing.
        """
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        # A DistributedSampler needs the epoch to reshuffle between epochs.
        sampler = getattr(dataloader, "sampler", None)
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)

        for data, target in dataloader:
            # non_blocking lets the copy overlap with compute when the batch
            # lives in pinned host memory; otherwise it is a normal copy.
            data = data.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            optimizer.zero_grad()
            with autocast():
                output = model(data)
                loss = criterion(output, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            num_batches += 1

        # Batches are counted rather than taken from len(dataloader) so an
        # empty (or length-less) loader cannot raise here.
        metrics = {
            "loss": total_loss / num_batches if num_batches else 0.0,
            "accuracy": correct / total if total else 0.0,
            "lr": optimizer.param_groups[0]["lr"],
        }

        # Only rank 0 logs, so the output is not repeated by every process.
        if self.config.rank == 0:
            self._log_metrics(metrics, epoch)
        return metrics

    def _log_metrics(self, metrics: Dict, epoch: int):
        """Print the metrics dictionary for ``epoch``."""
        print(f"Epoch {epoch}: {metrics}")
def setup_model(model: nn.Module, config: DistributedConfig) -> DDP:
    """Move ``model`` to this process's device and wrap it in DDP.

    Args:
        model: Module to distribute.
        config: Supplies ``local_rank``, used both as the target device and
            as the single entry of ``device_ids``.

    Returns:
        The ``DistributedDataParallel`` wrapper around the moved module.
    """
    placed = model.to(config.local_rank)
    return DDP(placed, device_ids=[config.local_rank])
def setup_dataloaders(
    train_dataset,
    val_dataset,
    batch_size: int,
    config: "DistributedConfig",
    num_workers: int = 4,
):
    """Build per-rank training and validation loaders.

    Each dataset is partitioned with a ``DistributedSampler`` using
    ``config.world_size`` and ``config.rank``.

    Args:
        train_dataset: Dataset for the training loader.
        val_dataset: Dataset for the validation loader.
        batch_size: Per-process batch size for both loaders.
        config: Supplies ``world_size`` and ``rank``.
        num_workers: Worker processes per loader (default 4, the value that
            was previously hard-coded).

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=config.world_size,
        rank=config.rank,
    )
    # Validation does not benefit from shuffling; a fixed order keeps the
    # evaluation reproducible from run to run.
    val_sampler = DistributedSampler(
        val_dataset,
        num_replicas=config.world_size,
        rank=config.rank,
        shuffle=False,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when there are no
        # workers, so only enable it when workers exist.
        persistent_workers=num_workers > 0,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        sampler=val_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader
Mixed Precision Training
# mixed_precision.py
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from typing import Dict, Tuple
import time
class MixedPrecisionTrainer:
    """Single-device trainer using autocast, gradient scaling and clipping."""

    def __init__(self, model: nn.Module, device: torch.device):
        """Move ``model`` to ``device`` and create a fresh ``GradScaler``."""
        self.model = model.to(device)
        self.device = device
        self.scaler = GradScaler()
        self.metrics_history = []

    def train_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor],
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Run one optimisation step on a ``(data, target)`` batch.

        Returns:
            ``{"loss": <float>}`` with the batch loss.
        """
        data, target = batch
        data, target = data.to(self.device), target.to(self.device)

        optimizer.zero_grad()
        with autocast():
            output = self.model(data)
            loss = criterion(output, target)

        self.scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so max_norm applies to
        # their true magnitude rather than the scaled one.
        self.scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.scaler.step(optimizer)
        self.scaler.update()
        return {"loss": loss.item()}

    def train_epoch(
        self,
        dataloader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Train over every batch in ``dataloader``.

        Returns:
            ``{"avg_loss": <float>}``, the mean batch loss, or ``0.0`` when
            ``dataloader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            total_loss += self.train_step(batch, criterion, optimizer)["loss"]
            num_batches += 1

        # An empty loader must not raise ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        return {"avg_loss": avg_loss}

    def benchmark(self, dataloader, criterion, optimizer, num_steps: int = 100) -> Dict:
        """Time up to ``num_steps`` training steps.

        The figures are based on the number of steps that actually ran, which
        is smaller than ``num_steps`` when ``dataloader`` is exhausted first.

        Returns:
            Dict with ``elapsed_seconds``, ``throughput_steps_per_sec`` and
            ``avg_step_time_ms``; the last two are ``0.0`` if no step ran.
        """
        self.model.train()
        steps_run = 0
        # perf_counter is monotonic and high-resolution, unlike time.time.
        # train_step ends with loss.item(), so each step's work has completed
        # before the clock is read again.
        start_time = time.perf_counter()
        for batch in dataloader:
            if steps_run >= num_steps:
                break
            self.train_step(batch, criterion, optimizer)
            steps_run += 1
        elapsed = time.perf_counter() - start_time

        throughput = steps_run / elapsed if steps_run and elapsed > 0 else 0.0
        avg_step_ms = (elapsed / steps_run) * 1000 if steps_run else 0.0
        return {
            "elapsed_seconds": elapsed,
            "throughput_steps_per_sec": throughput,
            "avg_step_time_ms": avg_step_ms,
        }
Efficient Data Loading
# efficient_dataloader.py
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data._utils.collate import default_collate
import numpy as np
from typing import Iterator, Callable, Optional
from functools import partial
class OptimizedDataLoader:
    """Factory that builds a ``DataLoader`` from a stored configuration."""

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 32,
        num_workers: int = 4,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        collate_fn: Optional[Callable] = None,
    ):
        """Record the loader settings.

        Args:
            dataset: Dataset to load from.
            batch_size: Samples per batch.
            num_workers: Worker processes; ``0`` loads in the main process.
            pin_memory: Forwarded to ``DataLoader``.
            persistent_workers: Forwarded only when ``num_workers > 0``.
            prefetch_factor: Forwarded only when ``num_workers > 0``.
            collate_fn: Batch collation function; ``default_collate`` is used
                when omitted.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor
        self.collate_fn = collate_fn or default_collate

    def get_loader(self, distributed: bool = False, rank: int = 0, world_size: int = 1):
        """Create the ``DataLoader``.

        Args:
            distributed: When true, partition the dataset across processes
                with a ``DistributedSampler``.
            rank: This process's index (used only when ``distributed``).
            world_size: Number of processes (used only when ``distributed``).

        Returns:
            A ``DataLoader`` with ``drop_last=True``. Map-style datasets are
            shuffled unless a distributed sampler is in charge of ordering;
            iterable-style datasets are never shuffled here.
        """
        sampler = None
        if distributed:
            from torch.utils.data.distributed import DistributedSampler

            sampler = DistributedSampler(
                self.dataset,
                num_replicas=world_size,
                rank=rank,
            )

        # DataLoader raises if shuffle=True is combined with an
        # IterableDataset, so only map-style datasets get shuffled.
        is_stream = isinstance(self.dataset, IterableDataset)
        loader_kwargs = {
            "batch_size": self.batch_size,
            "sampler": sampler,
            "shuffle": sampler is None and not is_stream,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
            "collate_fn": self.collate_fn,
            "drop_last": True,
        }
        # The worker-only options are left at DataLoader's own defaults when
        # there are no workers; older torch releases reject an explicit
        # prefetch_factor=None in that case.
        if self.num_workers > 0:
            loader_kwargs["persistent_workers"] = self.persistent_workers
            loader_kwargs["prefetch_factor"] = self.prefetch_factor

        return DataLoader(self.dataset, **loader_kwargs)
class StreamingDataset(IterableDataset):
    """Iterable dataset that pulls items from a callable data source.

    ``data_source`` is called each time iteration starts and must return an
    iterable. If ``transform`` is truthy it is applied to every item before
    the item is yielded.

    NOTE(review): nothing here splits the stream between DataLoader workers;
    confirm that ``data_source`` handles that itself if ``num_workers > 1``.
    """

    def __init__(self, data_source: Callable, transform: Optional[Callable] = None):
        self.transform = transform
        self.data_source = data_source

    def __iter__(self) -> Iterator:
        for item in self.data_source():
            yield self.transform(item) if self.transform else item
def worker_init_fn(worker_id: int):
    """Reseed NumPy's global RNG with a per-worker seed.

    The seed is the first word of the current global NumPy RNG state plus
    ``worker_id``, wrapped into the range ``np.random.seed`` accepts.

    Args:
        worker_id: Index of the worker being initialised.
    """
    # Convert to a Python int before adding: the state word is a fixed-width
    # unsigned NumPy integer, and the sum can exceed 2**32 - 1, which
    # np.random.seed rejects (or which silently overflows, depending on the
    # NumPy version).
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)
Follow-Up Questions
- How do you implement gradient accumulation for large batch training?
- What are the trade-offs between data parallelism and model parallelism?
- How would you handle fault tolerance in distributed training?
- What profiling tools work best for GPU training optimization?