GANs Deep Dive — Generative Adversarial Networks


GANs learn to generate realistic data by pitting two neural networks against each other in a minimax game: a generator creates fake samples, and a discriminator tries to distinguish real from fake.


The GAN Framework

Definition: GAN Framework

A GAN consists of:

  • Generator $G$: Maps random noise $z \sim p_z(z)$ to fake data $G(z)$
  • Discriminator $D$: Outputs the probability that its input $x$ is real data

The two networks compete: $G$ tries to fool $D$, while $D$ tries to correctly classify real vs. fake. Training converges when $G$ produces data indistinguishable from real data.

$$\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Discriminator Objective

$$\max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

Here,

  • $D(x)$ = Discriminator's estimate of the probability that $x$ is real
  • $G(z)$ = Generator's fake sample from noise $z$
  • $p_{\text{data}}$ = Real data distribution
  • $p_z$ = Prior noise distribution (e.g., Gaussian)
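As a rough illustration, the discriminator objective can be estimated from a batch of discriminator outputs, and maximizing it is equivalent to minimizing the sum of two binary cross-entropy terms (real samples labeled 1, fake samples labeled 0). The sketch below uses plain Python with made-up discriminator outputs; the names `bce` and `value_estimate` are illustrative and do not come from any library.

```python
import math

def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def value_estimate(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs.

    d_real holds D(x) for real samples, d_fake holds D(G(z)) for generated ones.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# Hypothetical discriminator outputs for a small batch (assumed values).
d_real = [0.9, 0.8, 0.95]
d_fake = [0.1, 0.3, 0.05]

v = value_estimate(d_real, d_fake)

# Same quantity with the sign flipped: BCE on reals (label 1) plus BCE on fakes (label 0).
bce_total = (sum(bce(p, 1) for p in d_real) / len(d_real)
             + sum(bce(p, 0) for p in d_fake) / len(d_fake))

print(f"V estimate = {v:.4f}, BCE total = {bce_total:.4f}")
```

Because the two quantities differ only in sign, a standard binary cross-entropy loss can be used to train the discriminator.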

Generator Objective (Non-saturating)

$$\max_G \; \mathbb{E}_{z \sim p_z}[\log D(G(z))]$$

Here,

  • $\log D(G(z))$ = Log-probability that the discriminator assigns to a fake being real; the generator wants the discriminator to output 1 for fakes
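To see why the non-saturating form helps, it is useful to compare gradients with respect to the discriminator's pre-sigmoid output (logit) $a$, where $D(G(z)) = \sigma(a)$. The derivative of $\log(1 - \sigma(a))$ is $-\sigma(a)$, while the derivative of $\log \sigma(a)$ is $1 - \sigma(a)$. The sketch below, with illustrative function names, evaluates both for a few logits; early in training the discriminator rejects fakes confidently, which corresponds to a strongly negative logit.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the minimax generator term (minimized by G)."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of log(sigmoid(a)), the non-saturating generator term (maximized by G)."""
    return 1.0 - sigmoid(a)

for a in (-6.0, -2.0, 0.0, 2.0):
    print(f"logit={a:+.1f}  D(G(z))={sigmoid(a):.4f}  "
          f"|saturating|={abs(grad_saturating(a)):.4f}  "
          f"|non-saturating|={abs(grad_non_saturating(a)):.4f}")
```

When the discriminator is confident that a sample is fake, the minimax term gives the generator almost no signal, whereas the non-saturating term gives a gradient close to 1 in magnitude.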

Nash Equilibrium

Theorem: Global Optimum of GAN

The global optimum of the minimax game is achieved when:

$$p_G = p_{\text{data}}$$

and the optimal discriminator is:

$$D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_G(x)}$$

At this point, $V(D^*, G) = -\log 4$ and the generator perfectly matches the data distribution.

ℹ️ Interpretation

When $p_G = p_{\text{data}}$, the discriminator cannot distinguish real from fake and outputs $D(x) = 0.5$ everywhere. The game reaches a Nash equilibrium where neither player can improve by changing strategy unilaterally.
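These statements can be checked numerically on a small discrete example. The sketch below assumes a three-outcome distribution with made-up probabilities and writes the second expectation over $x \sim p_G$ (the distribution of $G(z)$) rather than over $z$; the function names are illustrative.

```python
import math

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_G(x)) for each outcome x."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

def value(d, p_data, p_g):
    """V(D, G) = E_{x~p_data}[log D(x)] + E_{x~p_G}[log(1 - D(x))]."""
    return sum(pd * math.log(dx) + pg * math.log(1 - dx)
               for dx, pd, pg in zip(d, p_data, p_g))

p_data = [0.5, 0.3, 0.2]        # assumed real distribution
p_g_matched = [0.5, 0.3, 0.2]   # generator matches the data exactly
p_g_off = [0.2, 0.3, 0.5]       # generator puts mass in the wrong places

d_star = optimal_discriminator(p_data, p_g_matched)
v_matched = value(d_star, p_data, p_g_matched)
v_off = value(optimal_discriminator(p_data, p_g_off), p_data, p_g_off)

print("D* when matched:", d_star)
print("V(D*, G) matched:", v_matched, " -log 4 =", -math.log(4))
print("V(D*, G) mismatched:", v_off)
```

With a mismatched generator the value against the optimal discriminator is strictly larger than $-\log 4$, so the generator can still lower it by moving toward $p_{\text{data}}$.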


Training Challenges

Definition: Mode Collapse

The generator learns to produce only a few types of outputs that fool the discriminator, ignoring the diversity of the real data distribution. This is one of the most common failure modes of GANs.

Symptoms: Generator produces very similar outputs regardless of input noise.

Definition: Training Instability

GAN training is inherently unstable because:

  1. Non-convergence: Alternating optimization may not converge
  2. Vanishing gradients: When $D$ is too good, $\log(1 - D(G(z)))$ saturates
  3. Oscillation: $G$ and $D$ may cycle without converging
  4. Mode collapse: $G$ maps all inputs to the same output
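Non-convergence and oscillation already show up in a two-player game far simpler than a GAN. The sketch below is a toy illustration, not a GAN: one player minimizes $f(x, y) = xy$ over $x$ while the other maximizes it over $y$, both taking simultaneous gradient steps. The only equilibrium is $(0, 0)$, yet the iterates circle around it and drift outward. The function name and step size are assumptions chosen for the example.

```python
import math

def simultaneous_gda(x, y, lr=0.1, steps=100):
    """Simultaneous gradient descent (on x) and ascent (on y) for f(x, y) = x * y.

    Returns the distance from the equilibrium (0, 0) after every step.
    """
    radii = [math.hypot(x, y)]
    for _ in range(steps):
        grad_x, grad_y = y, x          # df/dx = y, df/dy = x
        x, y = x - lr * grad_x, y + lr * grad_y
        radii.append(math.hypot(x, y))
    return radii

radii = simultaneous_gda(1.0, 1.0)
print(f"start distance: {radii[0]:.3f}, end distance: {radii[-1]:.3f}")
```

Each step multiplies the squared distance from the equilibrium by $1 + \text{lr}^2$, so alternating-style updates on an adversarial objective can fail to settle even when an equilibrium exists.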

DCGAN (Deep Convolutional GAN)

Definition: DCGAN Architecture

DCGAN (Radford et al., 2015) proposed architectural guidelines that make GAN training more stable:

  • Replace pooling with strided convolutions (discriminator) and transposed convolutions (generator)
  • Use batch normalization in both networks (except the generator's output layer and the discriminator's input layer)
  • Remove fully connected layers
  • Use ReLU activation in generator (Tanh for output)
  • Use LeakyReLU in discriminator

Transposed Convolution Output Size

$$\text{out} = (\text{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{kernel}$$

Here,

  • $\text{in}$ = Input spatial dimension
  • $\text{stride}$ = Stride of the transposed convolution
  • $\text{padding}$ = Padding
  • $\text{kernel}$ = Kernel size
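The formula can be applied layer by layer to check how a generator grows a 1 x 1 latent input into a full image. The sketch below assumes dilation 1 and no output padding, and uses the (kernel, stride, padding) settings of the five transposed convolutions in the PyTorch generator shown later in this lesson; the helper name is illustrative.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for each transposed convolution in the generator.
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector is viewed as a 1 x 1 feature map
sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)
```

The first layer expands 1 x 1 to 4 x 4, and each later layer with kernel 4, stride 2, padding 1 doubles the spatial size.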

WGAN (Wasserstein GAN)

Definition: WGAN

WGAN (Arjovsky et al., 2017) replaces the JS divergence with Wasserstein distance (Earth-Mover distance) for more stable training:

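  • Uses the Wasserstein distance: $W(p_{\text{data}}, p_G) = \inf_{\gamma \in \Pi(p_{\text{data}}, p_G)} \mathbb{E}_{(x,y) \sim \gamma}[\|x - y\|]$, where $\Pi(p_{\text{data}}, p_G)$ is the set of joint distributions with marginals $p_{\text{data}}$ and $p_G$
  • The discriminator becomes a "critic" that outputs an unbounded scalar score, not a probability
  • Weight clipping or a gradient penalty enforces the Lipschitz constraint on the critic; with a gradient penalty, batch norm is avoided in the critic
  • The loss is meaningful: it tends to correlate with sample quality

The critic $D$ is trained to maximize the following objective, and the generator to minimize it:

$$\mathcal{L} = \mathbb{E}_{x \sim p_{\text{data}}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]$$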
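For intuition about the Wasserstein distance itself, the one-dimensional case has a simple closed form: for two samples of equal size, the optimal coupling pairs the sorted values, so the distance is the mean absolute difference between them. The sketch below uses this special case with synthetic Gaussian samples; it illustrates the distance only and is not how a WGAN critic estimates it. The function name and sample sizes are assumptions.

```python
import random

def wasserstein_1d(xs, ys):
    """W1 between two equal-size 1-D samples: pair up sorted values."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

random.seed(0)
real = [random.gauss(0.0, 1.0) for _ in range(1000)]

# Use a shifted copy of the real sample as a stand-in for generated data.
# For a pure shift by theta, W1 equals theta.
distances = {theta: wasserstein_1d(real, [x + theta for x in real])
             for theta in (0.0, 1.0, 5.0)}

for theta, dist in distances.items():
    print(f"shift={theta:.1f}  W1={dist:.4f}")
```

The distance keeps growing as the two samples move apart, which is the property that gives the generator a useful signal even when real and generated data barely overlap.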

💡 WGAN-GP Gradient Penalty

Instead of weight clipping (WGAN), use gradient penalty (WGAN-GP):

$$\mathcal{L}_{\text{GP}} = \lambda \, \mathbb{E}_{\hat{x}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right]$$

where $\hat{x}$ is interpolated between real and fake samples. This enforces the Lipschitz constraint smoothly.
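The penalty can be sketched without a deep learning framework by using a finite-difference gradient in place of automatic differentiation. Everything below is an assumption made for illustration: the linear toy critics, the 2-D synthetic samples, the helper names, and the default $\lambda = 10$. In real training the gradient would come from the framework's autograd and the critic would be a neural network.

```python
import math
import random

def numerical_grad(f, x, eps=1e-5):
    """Central-difference gradient of a scalar function f at point x (a list)."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return grad

def gradient_penalty(critic, real, fake, lam=10.0, rng=random):
    """lam * mean over sample pairs of (||grad critic(x_hat)||_2 - 1)^2."""
    total = 0.0
    for x_r, x_f in zip(real, fake):
        t = rng.random()  # interpolation coefficient in [0, 1)
        x_hat = [t * r + (1 - t) * f for r, f in zip(x_r, x_f)]
        norm = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        total += (norm - 1.0) ** 2
    return lam * total / len(real)

# Toy linear critics: the gradient is the weight vector w everywhere.
unit_critic = lambda x: 0.6 * x[0] + 0.8 * x[1]    # ||w|| = 1, no penalty
steep_critic = lambda x: 3.0 * x[0] + 4.0 * x[1]   # ||w|| = 5, penalized

rng = random.Random(0)
real = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(8)]
fake = [[rng.gauss(3, 1), rng.gauss(3, 1)] for _ in range(8)]

gp_unit = gradient_penalty(unit_critic, real, fake, rng=rng)
gp_steep = gradient_penalty(steep_critic, real, fake, rng=rng)
print(f"unit critic penalty: {gp_unit:.6f}, steep critic penalty: {gp_steep:.3f}")
```

A critic whose gradient norm is 1 at the interpolated points incurs no penalty, while a steeper critic is pushed back toward unit gradient norm.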


StyleGAN

Definition: StyleGAN Architecture

StyleGAN (Karras et al., 2019) introduces a style-based generator architecture:

  1. Mapping network: $z \to w$ (8 FC layers) maps the latent code to an intermediate style space
  2. Adaptive instance normalization (AdaIN): Injects style at each layer
  3. Noise injection: Per-pixel noise for stochastic variation
  4. Progressive growing: Train with increasing resolution

This enables disentangled control over high-level attributes (pose, identity) and stochastic variation (hair, freckles).

AdaIN (Adaptive Instance Normalization)

$$\text{AdaIN}(x_i, y) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$

Here,

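  • $x_i$ = Feature map $i$ (one channel of the layer's activations), normalized separately
  • $y_{s,i}$ = Style scale for feature map $i$ (from $w$)
  • $y_{b,i}$ = Style bias for feature map $i$ (from $w$)
  • $\mu(x_i), \sigma(x_i)$ = Mean and standard deviation of that feature map over its spatial positions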
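A minimal sketch of the AdaIN operation on a single feature map, in plain Python: the map is normalized to zero mean and unit standard deviation, then rescaled and shifted by the style parameters. The input values, the style scale and bias, and the small `eps` added for numerical safety are assumptions for the example; in StyleGAN the scale and bias would be produced from $w$ by a learned transformation.

```python
import math

def adain(feature_map, style_scale, style_bias, eps=1e-8):
    """Normalize one feature map (a 2-D list), then apply style scale and bias."""
    values = [v for row in feature_map for v in row]
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [[style_scale * (v - mean) / (std + eps) + style_bias for v in row]
            for row in feature_map]

x_i = [[1.0, 2.0], [3.0, 4.0]]  # one 2 x 2 feature map (assumed values)
out = adain(x_i, style_scale=2.0, style_bias=0.5)

# After AdaIN, the statistics of the map are set by the style alone.
flat = [v for row in out for v in row]
out_mean = sum(flat) / len(flat)
out_std = math.sqrt(sum((v - out_mean) ** 2 for v in flat) / len(flat))
print(f"output mean = {out_mean:.4f}, output std = {out_std:.4f}")
```

The output mean equals the style bias and the output standard deviation equals the style scale, regardless of the statistics of the incoming feature map.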

PyTorch Implementation

📝 Example: DCGAN

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, channels=3):
        super().__init__()
        self.main = nn.Sequential(
            # Latent -> 512 x 4 x 4
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 16 x 16 -> 64 x 32 x 32
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.main(z.view(z.size(0), -1, 1, 1))


class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.main = nn.Sequential(
            # 3 x 64 x 64 -> 64 x 32 x 32
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32 -> 128 x 16 x 16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16 -> 256 x 8 x 8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8 -> 512 x 4 x 4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4 -> 1
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).view(-1)


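# Training loop
# Assumes `dataloader` is defined elsewhere and yields (images, labels) batches
# of 64 x 64 images scaled to [-1, 1], matching the generator's Tanh output.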
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = Generator(100, 3).to(device)
D = Discriminator(3).to(device)

opt_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

for epoch in range(100):
    for real, _ in dataloader:
        real = real.to(device)
        batch_size = real.size(0)

        # Train Discriminator
        z = torch.randn(batch_size, 100, device=device)
        fake = G(z).detach()
        loss_D = criterion(D(real), torch.ones(batch_size, device=device)) + \
                 criterion(D(fake), torch.zeros(batch_size, device=device))

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # Train Generator
        z = torch.randn(batch_size, 100, device=device)
        fake = G(z)
        loss_G = criterion(D(fake), torch.ones(batch_size, device=device))

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

Training Tips

💡 GAN Training Best Practices

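  1. Use label smoothing: Real labels = 0.9 (reduces discriminator overconfidence). One-sided smoothing, which keeps fake labels at 0, is usually preferred over also raising fake labels to 0.1
  2. Balance D and G updates: Train D more often than G (e.g., a 5:1 ratio, as in WGAN), or use the two-time-scale update rule (TTUR), which gives D and G separate learning rates
  3. Spectral normalization: Apply to D weights for stable training
  4. Progressive growing: Start with low resolution, increase gradually
  5. Track FID/IS: Fréchet Inception Distance (lower is better) is the standard evaluation metric; Inception Score is an older alternative
  6. Avoid batch norm in a WGAN-GP critic: Use layer norm or instance norm instead, because the gradient penalty is computed per sample
  7. Adam optimizer: $\beta_1 = 0.5$, $\beta_2 = 0.999$, learning rate $2 \times 10^{-4}$ (the DCGAN settings)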

Practice Exercises

  1. Train DCGAN on CIFAR-10: Generate realistic images. Monitor FID over training.

  2. WGAN-GP implementation: Replace BCE loss with Wasserstein loss + gradient penalty. Compare training stability.

  3. Mode collapse experiment: Train a GAN on MNIST and observe mode collapse. Fix it with minibatch discrimination.

  4. Style mixing: Implement StyleGAN and experiment with style mixing at different layers.


Key Takeaways

📋 Summary: GANs

  • GANs consist of a generator $G$ and a discriminator $D$ in a minimax game
  • Nash equilibrium: $p_G = p_{\text{data}}$, $D(x) = 0.5$
  • Non-saturating loss: minimize $-\log D(G(z))$ instead of $\log(1 - D(G(z)))$
  • DCGAN: Architectural guidelines for stable training
  • WGAN: Wasserstein distance for better training dynamics
  • StyleGAN: Style-based generation with disentangled controls
  • Mode collapse and training instability are main challenges
  • FID score is the standard evaluation metric
  • GANs excel at image synthesis, style transfer, super-resolution
  • See also: GANs in ML for fundamentals
