Diffusion Models Advanced
DDPM Implementation
import torch
import torch.nn as nn
import numpy as np
class DDPM:
    """Denoising diffusion probabilistic model with a linear beta schedule.

    Implements the closed-form forward (noising) process, the simplified
    epsilon-prediction training loss, and ancestral sampling.

    Args:
        model: Noise-prediction callable invoked as ``model(x_t, t)``; its
            output is compared against / used as noise of the same shape as
            ``x_t``.
        T: Number of diffusion steps.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
    """

    def __init__(self, model, T=1000, beta_start=1e-4, beta_end=0.02):
        self.model = model
        self.T = T
        self.betas = torch.linspace(beta_start, beta_end, T)
        self.alphas = 1 - self.betas
        # alpha_bar[t] = prod_{s<=t} alpha_s, used for the closed-form q(x_t | x_0).
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        # Reverse-process variance sigma_t^2 = beta_t (the choice p_sample uses).
        self.sigma_squared = self.betas

    def q_sample(self, x_0, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

        Args:
            x_0: Clean data; the first dimension is the batch.
            t: Per-sample timesteps (int tensor of shape ``(batch,)``) or a
                single int applied to the whole batch.
            noise: Optional Gaussian noise shaped like ``x_0``; drawn if None.

        Returns:
            Tuple ``(x_t, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        # Keep the schedule and the index on x_0's device so GPU inputs work.
        t = torch.as_tensor(t, device=x_0.device)
        alpha_bar_t = self.alpha_bar.to(x_0.device)[t]
        # Broadcast one coefficient per sample over however many dims x_0 has.
        alpha_bar_t = alpha_bar_t.reshape(-1, *([1] * (x_0.dim() - 1)))
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
        return x_t, noise

    def compute_loss(self, x_0):
        """Return the MSE between the model's predicted noise and the true noise.

        A random timestep is drawn uniformly from ``[0, T)`` for each sample.
        """
        batch_size = x_0.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t, _ = self.q_sample(x_0, t, noise)
        predicted_noise = self.model(x_t, t)
        loss = nn.functional.mse_loss(predicted_noise, noise)
        return loss

    @torch.no_grad()
    def p_sample(self, x_t, t):
        """Take one ancestral reverse step x_t -> x_{t-1}.

        Args:
            x_t: Current noisy batch.
            t: Integer timestep shared by the whole batch.

        Returns:
            x_{t-1}. No noise is added at the final step (t == 0).
        """
        t = int(t)
        t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
        predicted_noise = self.model(x_t, t_batch)
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alpha_bar[t]
        beta_t = self.betas[t]
        # Posterior mean expressed through the predicted noise.
        mean = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * predicted_noise
        )
        if t > 0:
            noise = torch.randn_like(x_t)
            sigma = torch.sqrt(beta_t)
            return mean + sigma * noise
        return mean

    @torch.no_grad()
    def sample(self, shape, device=None):
        """Generate samples by running all T reverse steps from pure noise.

        Args:
            shape: Shape of the batch to generate.
            device: Device for the initial noise (default: torch's default
                device, as before). Should match the model's device.
        """
        x = torch.randn(shape, device=device)
        for t in reversed(range(self.T)):
            x = self.p_sample(x, t)
        return x
# Example usage of DDPM.
# NOTE(review): `unet_model`, `images` and `batch_size` are not defined anywhere
# earlier in this file (the U-Net instance `unet` is only created at the end), so
# these lines raise NameError if the file is executed top to bottom. Presumably
# they are placeholders for a noise-prediction network, a batch of training
# images and a sample count — confirm before running.
ddpm = DDPM(unet_model, T=1000)
loss = ddpm.compute_loss(images)  # MSE between predicted and true noise
generated = ddpm.sample((batch_size, 3, 64, 64))  # runs all T reverse steps
DDIM Sampling (Faster)
class DDIMSampler:
    """DDIM sampler that denoises over a sub-sequence of a DDPM's timesteps.

    Args:
        ddpm_model: A DDPM-like object exposing ``.model`` (called as
            ``model(x, t)``), ``.T`` and ``.alpha_bar``.
        ddim_steps: Number of denoising steps to take (at most ``T`` distinct
            timesteps are used).
        eta: Scale of the per-step noise; ``0.0`` makes sigma_t zero, so every
            step after the initial noise draw is deterministic.

    Raises:
        ValueError: If ``ddim_steps < 1``.
    """

    def __init__(self, ddpm_model, ddim_steps=50, eta=0.0):
        if ddim_steps < 1:
            raise ValueError(f"ddim_steps must be >= 1, got {ddim_steps}")
        self.model = ddpm_model
        self.ddim_steps = ddim_steps
        self.eta = eta
        num_train_steps = ddpm_model.T
        if ddim_steps == 1:
            # A single step must start from the noisiest timestep, not t = 0.
            steps = np.array([num_train_steps - 1])
        else:
            steps = np.linspace(0, num_train_steps - 1, ddim_steps).astype(int)
        # More steps than T would repeat timesteps; np.unique drops duplicates
        # and keeps the ascending order that sample() walks backwards.
        self.time_steps = np.unique(steps)

    @torch.no_grad()
    def sample(self, shape, device=None):
        """Generate a batch of the given shape starting from Gaussian noise.

        Args:
            shape: Shape of the batch to generate.
            device: Device for the initial noise (default: torch's default
                device). Should match the model's device.

        Returns:
            The final estimate of x_0 (the last step returns the clamped
            x_0 prediction, so values lie in [-1, 1]).
        """
        x = torch.randn(shape, device=device)
        alpha_bar = self.model.alpha_bar
        # Plain Python ints: safe for tensor indexing and torch.full.
        steps = [int(s) for s in self.time_steps]
        for i in reversed(range(len(steps))):
            t = steps[i]
            t_prev = steps[i - 1] if i > 0 else -1
            t_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
            predicted_noise = self.model.model(x, t_batch)
            alpha_bar_t = alpha_bar[t]
            # alpha_bar_{-1} := 1, so the last step lands exactly on x_0.
            alpha_bar_t_prev = alpha_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0)
            pred_x0 = (x - torch.sqrt(1 - alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_bar_t)
            pred_x0 = torch.clamp(pred_x0, -1, 1)
            sigma_t = self.eta * torch.sqrt(
                (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)
                * (1 - alpha_bar_t / alpha_bar_t_prev)
            )
            # Clamp at 0: rounding can make this slightly negative, and sqrt -> NaN.
            dir_coef = torch.clamp(1 - alpha_bar_t_prev - sigma_t ** 2, min=0.0)
            dir_xt = torch.sqrt(dir_coef) * predicted_noise
            noise = torch.randn_like(x) if t_prev >= 0 else 0
            x = torch.sqrt(alpha_bar_t_prev) * pred_x0 + dir_xt + sigma_t * noise
        return x
# Example: 50-step DDIM sampling with eta=0.0 (no per-step noise) using `ddpm`.
ddim = DDIMSampler(ddpm_model=ddpm, ddim_steps=50, eta=0.0)
fast_samples = ddim.sample(shape=(8, 3, 64, 64))
U-Net Architecture
class UNetBlock(nn.Module):
    """One U-Net stage: two time-conditioned 3x3 convolutions, then resampling.

    A down block (``up=False``) ends with a stride-2 4x4 convolution that
    halves height and width. An up block (``up=True``) takes the skip
    connection concatenated on the channel axis (hence ``2 * in_ch`` input
    channels) and ends with a stride-2 4x4 transposed convolution that
    doubles height and width.

    Args:
        in_ch: Input channels (per branch for up blocks).
        out_ch: Output channels.
        time_emb_dim: Size of the time embedding passed to ``forward``.
        up: Whether this is an upsampling block.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict layout are unchanged.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        conv1_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding=1)
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the block to feature map ``x`` conditioned on embedding ``t``."""
        act = self.relu
        features = self.bnorm1(act(self.conv1(x)))
        # Project the time embedding to one bias per channel and broadcast it
        # over the spatial dimensions.
        time_bias = self.time_mlp(act(t))
        features = features + time_bias[:, :, None, None]
        features = self.bnorm2(act(self.conv2(features)))
        return self.transform(features)
class SimpleUNet(nn.Module):
    """Small U-Net noise predictor: ``forward(x, timestep)`` returns a tensor
    with ``image_channels`` channels at the input's spatial size.

    The integer timestep is cast to float and fed through a 1 -> 256 -> 256
    MLP. Each down block halves the spatial size and each up block doubles it,
    concatenating the matching down-block output as a skip connection.
    NOTE(review): with the default 3 down blocks, the input height/width
    presumably must be divisible by 8 for the skip concatenation to line up.

    Args:
        image_channels: Channels of the input and output images.
        down_channels: Channel widths of the encoder path.
        up_channels: Channel widths of the decoder path.
    """

    def __init__(self, image_channels=3, down_channels=(64, 128, 256, 512), up_channels=(512, 256, 128, 64)):
        super().__init__()
        emb_dim = 256
        self.time_embed = nn.Sequential(
            nn.Linear(1, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        # Consecutive channel pairs (c_i, c_{i+1}) define each block.
        self.downs = nn.ModuleList(
            [UNetBlock(c_in, c_out, emb_dim) for c_in, c_out in zip(down_channels, down_channels[1:])]
        )
        self.ups = nn.ModuleList(
            [UNetBlock(c_in, c_out, emb_dim, up=True) for c_in, c_out in zip(up_channels, up_channels[1:])]
        )
        self.output = nn.Conv2d(up_channels[-1], image_channels, 1)

    def forward(self, x, timestep):
        """Predict the noise for batch ``x`` at per-sample ``timestep``."""
        t_emb = self.time_embed(timestep.float().unsqueeze(-1))
        h = self.conv0(x)
        skips = []
        for block in self.downs:
            h = block(h, t_emb)
            skips.append(h)
        # Decoder consumes skips in last-in, first-out order.
        for block in self.ups:
            h = block(torch.cat([h, skips.pop()], dim=1), t_emb)
        return self.output(h)
# Default 3-channel U-Net (encoder widths 64 -> 512, decoder 512 -> 64).
unet = SimpleUNet(image_channels=3)
Best Practices
- Use a cosine noise schedule (instead of the linear one above) for better sample quality
- DDIM enables faster sampling with far fewer steps (e.g. 50 instead of 1000)
- Classifier-free guidance improves conditional generation
- Progressive growing helps train high-resolution models
- Use an exponential moving average (EMA) of the model weights for stable training
- Monitor FID and Inception Score (IS) to evaluate sample quality