GAN Fundamentals
Generative Adversarial Networks (GANs) learn to generate realistic data by pitting two neural networks against each other: a generator creates fake samples, and a discriminator tries to tell real samples from fake ones. Through this adversarial game both networks improve; at the ideal equilibrium the generator's samples are indistinguishable from real data.
GAN Architecture
The GAN Framework
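The generator $G$ maps noise $z$ to data space, while the discriminator $D$ outputs the probability that a sample is real. They play the minimax game $\min_G \max_D V(D, G)$, where $V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$.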
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation notices from PyTorch; consider narrowing it.
warnings.filterwarnings('ignore')
# Prefer a CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
Simple GAN on 2D Data
def generate_real_data(n=1000):
    """Sample ``n`` 2-D points from an equal-weight mixture of two Gaussians.

    The component means are (-2, -2) and (2, 2), each with isotropic standard
    deviation 0.5. Randomness comes from NumPy's global RNG (``np.random``),
    so seed it for reproducible data.

    Returns:
        float32 array of shape (n, 2).
    """
    means = np.array([[-2, -2], [2, 2]])
    component = np.random.randint(0, 2, n)  # which Gaussian each point comes from
    noise = 0.5 * np.random.randn(n, 2)
    samples = means[component] + noise
    return samples.astype(np.float32)
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to 2-D points.

    Two hidden layers of width ``hidden``, each Linear -> LeakyReLU(0.2) ->
    BatchNorm1d, followed by a linear projection to 2 outputs with no output
    activation (samples are unbounded). Because of BatchNorm1d, training-mode
    forward passes need batches with more than one sample.
    """

    def __init__(self, latent_dim=2, hidden=64):
        super().__init__()

        def hidden_block(n_in):
            # One hidden stage: affine map, activation, then batch normalization.
            return [nn.Linear(n_in, hidden), nn.LeakyReLU(0.2), nn.BatchNorm1d(hidden)]

        self.net = nn.Sequential(
            *hidden_block(latent_dim),
            *hidden_block(hidden),
            nn.Linear(hidden, 2),
        )

    def forward(self, z):
        """Map noise ``z`` (assumed shape (batch, latent_dim)) to (batch, 2) samples."""
        return self.net(z)
class Discriminator(nn.Module):
    """MLP discriminator assigning each 2-D point a probability of being real.

    Two hidden layers of width ``hidden``, each Linear -> LeakyReLU(0.2) ->
    Dropout(0.3), then a single-logit Linear layer squashed by a sigmoid so the
    output lies in (0, 1).
    """

    def __init__(self, hidden=64):
        super().__init__()

        def hidden_block(n_in):
            # Dropout regularizes the discriminator so it does not overpower G.
            return [nn.Linear(n_in, hidden), nn.LeakyReLU(0.2), nn.Dropout(0.3)]

        self.net = nn.Sequential(
            *hidden_block(2),
            *hidden_block(hidden),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return D(x), shape (batch, 1), for points ``x`` assumed (batch, 2)."""
        return self.net(x)
# Training loop
def train_gan(n_epochs=2000, batch_size=64, latent_dim=2, lr=0.0002):
    """Train the vanilla BCE-loss GAN on the two-Gaussian toy dataset.

    Each mini-batch performs one discriminator update (real -> 1, fake -> 0)
    followed by one generator update that asks D to label fresh fakes as real.

    Args:
        n_epochs: number of passes over the 1000-point dataset.
        batch_size: DataLoader mini-batch size.
        latent_dim: dimension of the generator's noise input.
        lr: Adam learning rate shared by G and D (betas=(0.5, 0.999)).

    Returns:
        (G, D, d_losses, g_losses): the trained generator and discriminator,
        plus loss histories holding the last mini-batch's losses of each epoch.
    """
    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)
    adam_betas = (0.5, 0.999)
    gen_opt = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
    disc_opt = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
    bce = nn.BCELoss()

    samples = torch.FloatTensor(generate_real_data(1000))
    loader = DataLoader(TensorDataset(samples), batch_size=batch_size, shuffle=True)

    d_history, g_history = [], []
    for epoch in range(n_epochs):
        for (real,) in loader:
            real = real.to(device)
            n_real = real.size(0)  # the final batch of an epoch may be smaller

            # Discriminator step. Fakes are detached so this backward pass
            # does not reach the generator's parameters.
            noise = torch.randn(n_real, latent_dim).to(device)
            fake = generator(noise).detach()
            score_real = discriminator(real)
            score_fake = discriminator(fake)
            d_loss = (bce(score_real, torch.ones_like(score_real))
                      + bce(score_fake, torch.zeros_like(score_fake)))
            disc_opt.zero_grad()
            d_loss.backward()
            disc_opt.step()

            # Generator step (non-saturating loss): target "real" for new fakes.
            noise = torch.randn(n_real, latent_dim).to(device)
            score_fake = discriminator(generator(noise))
            g_loss = bce(score_fake, torch.ones_like(score_fake))
            gen_opt.zero_grad()
            g_loss.backward()
            gen_opt.step()

        d_history.append(d_loss.item())
        g_history.append(g_loss.item())
        if (epoch + 1) % 500 == 0:
            print(f"Epoch {epoch+1}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")
    return generator, discriminator, d_history, g_history
# Train the vanilla GAN with its default hyperparameters.
G, D, d_losses, g_losses = train_gan(n_epochs=2000, batch_size=64, latent_dim=2, lr=0.0002)
Mode Collapse
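Mode collapse occurs when the generator learns to produce only a subset of the real data distribution (for example, samples from just one of the two Gaussian clusters used here) instead of covering all of its modes.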
def detect_mode_collapse(generator, n_samples=5000, latent_dim=2, n_bins=50,
                         data_range=5.0, density_frac=0.01):
    """Report how much of the 2-D plane a generator's samples occupy.

    Draws ``n_samples`` noise vectors, generates samples, and bins their first
    two coordinates into an ``n_bins`` x ``n_bins`` histogram over
    ``[-data_range, data_range]`` on both axes. A cell counts as "active" when
    its count exceeds ``density_frac`` times the fullest cell's count. This is
    a coarse coverage heuristic (it counts occupied cells, not individual
    modes); a collapsed generator shows up as a low coverage fraction.

    The generator is run in eval mode under ``torch.no_grad()``. BatchNorm
    layers therefore use their running statistics and are not updated by this
    check, and no autograd graph is built. The caller's train/eval mode is
    restored afterwards. Noise is created on the device of the generator's
    first parameter (CPU if it has none).

    Args:
        generator: an ``nn.Module`` mapping (n, latent_dim) noise to samples
            with at least two output columns.
        n_samples: number of samples to draw.
        latent_dim: noise dimension expected by ``generator``.
        n_bins: histogram bins per axis.
        data_range: half-width of the square histogram window.
        density_frac: fraction of the fullest cell's count that a cell must
            exceed to count as active.

    Returns:
        NumPy array of generated samples, shape (n_samples, output_dim).
    """
    first_param = next(generator.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, latent_dim, device=sample_device)
            generated = generator(z).cpu().numpy()
    finally:
        # Put the generator back into whatever mode the caller had it in.
        generator.train(was_training)

    # Check coverage using a 2-D histogram of the first two coordinates.
    bounds = [[-data_range, data_range], [-data_range, data_range]]
    hist, _, _ = np.histogram2d(generated[:, 0], generated[:, 1], bins=n_bins, range=bounds)
    # Cells with significant density relative to the densest cell.
    threshold = hist.max() * density_frac
    active_modes = (hist > threshold).sum()
    print(f"Active modes (out of {n_bins**2} cells): {active_modes}")
    print(f"Mode coverage: {active_modes / (n_bins**2):.3f}")
    return generated
# Sample the trained generator and report histogram coverage.
generated = detect_mode_collapse(G, n_samples=5000, latent_dim=2, n_bins=50)
WGAN: Wasserstein GAN
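WGAN replaces the Jensen–Shannon divergence implicit in the standard GAN objective with the Wasserstein-1 (Earth Mover's) distance, which gives smoother gradients and more stable training. The critic must be (approximately) 1-Lipschitz; the original WGAN enforces this by clipping the critic's weights.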
class WGAN_Generator(nn.Module):
    """MLP generator for the WGAN example (no normalization layers).

    Two hidden layers of width ``hidden`` with LeakyReLU(0.2) activations,
    followed by a linear projection to 2 unbounded outputs.
    """

    def __init__(self, latent_dim=2, hidden=64):
        super().__init__()

        def hidden_block(n_in):
            return [nn.Linear(n_in, hidden), nn.LeakyReLU(0.2)]

        self.net = nn.Sequential(
            *hidden_block(latent_dim),
            *hidden_block(hidden),
            nn.Linear(hidden, 2),
        )

    def forward(self, z):
        """Map noise ``z`` (assumed shape (batch, latent_dim)) to (batch, 2) samples."""
        return self.net(z)
class WGAN_Critic(nn.Module):  # No sigmoid — outputs an unbounded score
    """WGAN critic: scores 2-D points with an unbounded real number.

    Unlike the BCE discriminator there is no final sigmoid, since the critic's
    mean score difference between real and fake batches is used directly as
    the Wasserstein estimate.
    """

    def __init__(self, hidden=64):
        super().__init__()

        def hidden_block(n_in):
            return [nn.Linear(n_in, hidden), nn.LeakyReLU(0.2)]

        self.net = nn.Sequential(
            *hidden_block(2),
            *hidden_block(hidden),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return critic scores of shape (batch, 1) for ``x`` assumed (batch, 2)."""
        return self.net(x)
def train_wgan(n_epochs=3000, batch_size=64, latent_dim=2, n_critic=5, clip_value=0.01):
    """Train a weight-clipped WGAN on the two-Gaussian toy dataset.

    For every real mini-batch the critic is updated ``n_critic`` times (all
    reusing that same real batch, with fresh noise each time) and its weights
    are clipped to ``[-clip_value, clip_value]`` after every update. Then the
    generator takes one step. Both networks use Adam with lr=5e-5 and
    betas=(0.5, 0.9).

    Returns:
        (G, C): the trained generator and critic.
    """
    generator = WGAN_Generator(latent_dim).to(device)
    critic = WGAN_Critic().to(device)
    adam_kwargs = dict(lr=0.00005, betas=(0.5, 0.9))
    gen_opt = optim.Adam(generator.parameters(), **adam_kwargs)
    critic_opt = optim.Adam(critic.parameters(), **adam_kwargs)

    points = torch.FloatTensor(generate_real_data(1000))
    loader = DataLoader(TensorDataset(points), batch_size=batch_size, shuffle=True)

    for epoch in range(n_epochs):
        for (real,) in loader:
            real = real.to(device)
            n_real = real.size(0)

            # Critic updates (n_critic per generator step).
            for _ in range(n_critic):
                noise = torch.randn(n_real, latent_dim).to(device)
                fake = generator(noise).detach()
                real_score = critic(real).mean()
                fake_score = critic(fake).mean()
                # Negated Wasserstein estimate: minimizing it widens the
                # real-vs-fake score gap.
                c_loss = fake_score - real_score
                critic_opt.zero_grad()
                c_loss.backward()
                critic_opt.step()
                # Weight clipping keeps the critic (approximately) Lipschitz.
                with torch.no_grad():
                    for weight in critic.parameters():
                        weight.clamp_(-clip_value, clip_value)

            # Generator update: raise the critic's score on fresh fakes.
            noise = torch.randn(n_real, latent_dim).to(device)
            g_loss = -critic(generator(noise)).mean()
            gen_opt.zero_grad()
            g_loss.backward()
            gen_opt.step()

        if (epoch + 1) % 500 == 0:
            print(f"Epoch {epoch+1}: Critic loss={c_loss.item():.4f}")
    return generator, critic
# Train the weight-clipped WGAN with its default hyperparameters.
G_wgan, C_wgan = train_wgan(n_epochs=3000, batch_size=64, latent_dim=2, n_critic=5, clip_value=0.01)
Conditional GAN
Generate data conditioned on class labels.
class ConditionalGenerator(nn.Module):
    """Generator conditioned on an integer class label.

    Each label is mapped to a learned ``embed_dim``-dimensional embedding,
    which is concatenated with the noise vector and passed through a
    two-hidden-layer LeakyReLU MLP that outputs a ``data_dim``-dimensional
    sample.

    Args:
        latent_dim: size of the noise vector ``z``.
        n_classes: number of distinct labels (valid labels are 0..n_classes-1).
        hidden: width of both hidden layers.
        embed_dim: size of the label embedding (default 2).
        data_dim: size of each generated sample (default 2).
    """

    def __init__(self, latent_dim=2, n_classes=2, hidden=64, embed_dim=2, data_dim=2):
        super().__init__()
        self.label_embed = nn.Embedding(n_classes, embed_dim)
        # Input width follows from the embedding size rather than a magic number.
        self.net = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, data_dim),
        )

    def forward(self, z, labels):
        """Generate one sample per (noise, label) pair.

        Args:
            z: noise tensor, assumed shape (batch, latent_dim).
            labels: integer (long) class ids of shape (batch,).

        Returns:
            Tensor of shape (batch, data_dim).
        """
        label_emb = self.label_embed(labels)
        x = torch.cat([z, label_emb], dim=1)
        return self.net(x)
class ConditionalDiscriminator(nn.Module):
    """Discriminator conditioned on an integer class label.

    The label's learned ``embed_dim``-dimensional embedding is concatenated
    with the ``data_dim``-dimensional input point, and a two-hidden-layer
    LeakyReLU MLP with a final sigmoid outputs the probability that the
    (point, label) pair is real.

    Args:
        n_classes: number of distinct labels (valid labels are 0..n_classes-1).
        hidden: width of both hidden layers.
        embed_dim: size of the label embedding (default 2).
        data_dim: size of each input sample (default 2).
    """

    def __init__(self, n_classes=2, hidden=64, embed_dim=2, data_dim=2):
        super().__init__()
        self.label_embed = nn.Embedding(n_classes, embed_dim)
        # Input width is data + embedding size (4 with the defaults), derived
        # instead of hard-coded so other dimensions work too.
        self.net = nn.Sequential(
            nn.Linear(data_dim + embed_dim, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, labels):
        """Score (point, label) pairs.

        Args:
            x: samples, assumed shape (batch, data_dim).
            labels: integer (long) class ids of shape (batch,).

        Returns:
            Probabilities in (0, 1), shape (batch, 1).
        """
        label_emb = self.label_embed(labels)
        x = torch.cat([x, label_emb], dim=1)
        return self.net(x)
# Conditional generation demo: ask an untrained ConditionalGenerator for
# 100 samples of class 0 and check the output shape.
latent_dim = 2
G_cond = ConditionalGenerator(latent_dim=latent_dim).to(device)
z = torch.randn((100, latent_dim)).to(device)
labels = torch.full((100,), 0, dtype=torch.long).to(device)  # every sample is class 0
generated = G_cond(z, labels)
print(f"Conditional generation shape: {generated.shape}")
GAN Training Best Practices
# Spectral Normalization for stable training
class SN_Discriminator(nn.Module):
    """Discriminator whose Linear layers are wrapped in spectral normalization.

    Spectral normalization rescales each weight matrix by an estimate of its
    largest singular value, bounding each layer's Lipschitz constant. The layout
    is two LeakyReLU(0.2) hidden layers of width ``hidden`` and a sigmoid
    output in (0, 1).
    """

    def __init__(self, hidden=64):
        super().__init__()
        sn = nn.utils.spectral_norm
        self.net = nn.Sequential(
            sn(nn.Linear(2, hidden)), nn.LeakyReLU(0.2),
            sn(nn.Linear(hidden, hidden)), nn.LeakyReLU(0.2),
            sn(nn.Linear(hidden, 1)), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return D(x), shape (batch, 1), for ``x`` assumed (batch, 2)."""
        return self.net(x)
# Two Time-Scale Update Rule (TTUR): the discriminator learns 4x faster than
# the generator. Only the models and optimizers are built here; no training
# loop is run in this cell.
G_ttur = Generator().to(device)
D_ttur = SN_Discriminator().to(device)
opt_G = optim.Adam(G_ttur.parameters(), betas=(0.5, 0.999), lr=1e-4)  # generator: slower
opt_D = optim.Adam(D_ttur.parameters(), betas=(0.5, 0.999), lr=4e-4)  # discriminator: faster
print("Spectral normalization and TTUR applied")
Evaluating GANs
# FID-like simplified metric
def compute_fid(real_data, generated_data, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two sample sets.

    Uses the formula behind FID,
        ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2}),
    but on the raw samples rather than Inception-network features, so values
    are only comparable with other results of this function.

    Args:
        real_data: array of shape (n, d), or (n,) for a single feature.
        generated_data: array of shape (m, d), or (m,); ``d`` must match.
        eps: diagonal jitter added to both covariance matrices so the matrix
            square root stays well conditioned.

    Returns:
        The distance as a Python float (lower is better).

    Raises:
        ValueError: if the inputs have different feature dimensions.
    """
    from scipy.linalg import sqrtm

    real = np.asarray(real_data, dtype=np.float64)
    fake = np.asarray(generated_data, dtype=np.float64)
    # Treat 1-D input as n samples of one feature.
    if real.ndim == 1:
        real = real[:, None]
    if fake.ndim == 1:
        fake = fake[:, None]
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {real.shape[1]} (real) vs {fake.shape[1]} (generated)"
        )
    dim = real.shape[1]

    diff = real.mean(axis=0) - fake.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array when there is a single feature.
    real_cov = np.atleast_2d(np.cov(real, rowvar=False)) + np.eye(dim) * eps
    fake_cov = np.atleast_2d(np.cov(fake, rowvar=False)) + np.eye(dim) * eps
    # sqrtm may return tiny imaginary parts from round-off; keep the real part.
    covmean = np.real(sqrtm(real_cov @ fake_cov))
    fid = diff @ diff + np.trace(real_cov + fake_cov - 2 * covmean)
    return float(fid)
# Score the trained vanilla-GAN generator G against a freshly drawn real sample.
# G is switched to eval mode and run without autograd, so this evaluation uses
# BatchNorm running statistics instead of updating them; its previous mode is
# restored afterwards. The latent size 2 matches the train_gan() default used
# to build G.
real_data = generate_real_data(1000)
_g_was_training = G.training
G.eval()
try:
    with torch.no_grad():
        z = torch.randn(1000, 2).to(device)
        generated_data = G(z).cpu().numpy()
finally:
    G.train(_g_was_training)
fid = compute_fid(real_data, generated_data)
print(f"Simplified FID: {fid:.4f} (lower is better)")
Best Practices
- Use WGAN-GP — a gradient penalty is more stable than weight clipping
- Spectral normalization — stabilizes discriminator training
- TTUR — different learning rates for G and D
- Batch normalization in the generator only — avoid batch norm in the discriminator/critic (a WGAN-GP critic typically uses layer norm or no normalization)
- Monitor mode collapse — check whether generated samples cover all modes
- Evaluate with FID/IS — quantitative metrics, not just visual inspection
Summary
GANs learn to generate realistic data through adversarial training. Understand the minimax game, mode collapse, and stabilization techniques (WGAN, spectral norm, TTUR) to train reliable generators for data augmentation, style transfer, and creative applications.