# Energy-Based Models
## Energy Function
import torch
import torch.nn as nn
class EnergyFunction(nn.Module):
    """Multi-layer perceptron that assigns one scalar energy to each input.

    The network is three ``Linear -> SiLU`` stages of width ``hidden_dim``
    followed by a final ``Linear`` layer with a single output unit.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden layer.
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        # Consecutive pairs of widths describe the hidden Linear layers.
        widths = [input_dim, hidden_dim, hidden_dim, hidden_dim]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.SiLU())
        stack.append(nn.Linear(hidden_dim, 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output with the trailing size-1 axis removed."""
        raw = self.net(x)
        return raw.squeeze(-1)

    def energy(self, x):
        """Alias for :meth:`forward`, used by the trainer and the sampler."""
        return self.forward(x)
class JointEnergyModel(nn.Module):
    """Shared feature trunk with an energy head and a classification head.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of logits produced by the classifier head.
        hidden_dim: Width of the two hidden layers in the shared trunk.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        first = nn.Linear(input_dim, hidden_dim)
        second = nn.Linear(hidden_dim, hidden_dim)
        self.feature_extractor = nn.Sequential(first, nn.ReLU(), second, nn.ReLU())
        self.energy_head = nn.Linear(hidden_dim, 1)
        self.classifier = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return ``(energy, logits)`` computed from the same features.

        ``energy`` has its trailing size-1 axis squeezed away; ``logits``
        keeps ``output_dim`` as its last dimension.
        """
        hidden = self.feature_extractor(x)
        scalar_energy = self.energy_head(hidden).squeeze(-1)
        class_logits = self.classifier(hidden)
        return scalar_energy, class_logits
## Score Matching
class ScoreMatchingTrainer:
    """Train an energy network with denoising score matching (DSM).

    The model score is defined as ``-dE/dx``, the negative input-gradient of
    ``score_net.energy``. It is regressed onto the score of the Gaussian
    corruption kernel, ``-noise / sigma**2``.

    Args:
        score_net: Object exposing ``energy(x)``; it must also expose
            ``parameters()`` if the default optimizer is to be created.
        sigma: Standard deviation of the Gaussian noise added to the data.
        optimizer: Optional optimizer over the network parameters. When left
            as ``None``, an Adam optimizer is created on the first call to
            :meth:`train_step`.
        lr: Learning rate for the lazily created default optimizer.
    """

    def __init__(self, score_net, sigma=0.01, optimizer=None, lr=1e-3):
        self.score_net = score_net
        self.sigma = sigma
        self.lr = lr
        # May stay None until train_step(); callers may also assign
        # ``trainer.optimizer`` themselves after construction.
        self.optimizer = optimizer

    def score_function(self, x):
        """Return the model score ``-dE/dx`` evaluated at ``x``.

        The result stays attached to the autograd graph
        (``create_graph=True``) so a loss built on it can be back-propagated
        into the network parameters.
        """
        with torch.enable_grad():
            if not x.requires_grad:
                # Work on a detached view so the caller's tensor does not
                # have its requires_grad flag flipped as a side effect.
                x = x.detach().requires_grad_(True)
            energy = self.score_net.energy(x)
            grad_energy = torch.autograd.grad(
                energy.sum(), x, create_graph=True
            )[0]
            # Score is the gradient of log-density, i.e. minus the energy
            # gradient; low energy then corresponds to high density.
            return -grad_energy

    def denoising_score_matching_loss(self, x):
        """Return the DSM loss for a batch ``x`` perturbed with Gaussian noise."""
        noise = torch.randn_like(x) * self.sigma
        x_noisy = x + noise
        score = self.score_function(x_noisy)
        # Score of N(x_noisy; x, sigma^2 I) with respect to x_noisy.
        target = -noise / (self.sigma ** 2)
        loss = 0.5 * ((score - target) ** 2).mean()
        return loss

    def train_step(self, batch):
        """Run one optimization step on ``batch`` and return the loss as a float."""
        if self.optimizer is None:
            self.optimizer = torch.optim.Adam(
                self.score_net.parameters(), lr=self.lr
            )
        loss = self.denoising_score_matching_loss(batch)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
# Energy network over 784-dimensional inputs, wrapped in a trainer that uses
# the default noise level.
score_net = EnergyFunction(784)
trainer = ScoreMatchingTrainer(score_net=score_net)
## Langevin Dynamics Sampling
class LangevinDynamics:
    """Noisy gradient-descent sampler on the energy of ``score_net``.

    Each step moves the samples against the energy gradient by ``step_size``
    and adds Gaussian noise with standard deviation ``sigma``.

    Args:
        score_net: Object exposing ``energy(x)``.
        step_size: Multiplier applied to the energy gradient in each update.
        n_steps: Number of update steps performed by :meth:`sample`.
        sigma: Standard deviation of the noise injected at every step.
    """

    def __init__(self, score_net, step_size=0.01, n_steps=100, sigma=0.01):
        self.score_net, self.step_size = score_net, step_size
        self.n_steps, self.sigma = n_steps, sigma

    def sample(self, initial_samples):
        """Evolve a copy of ``initial_samples`` for ``n_steps`` updates and return it."""
        current = initial_samples.clone()
        for _ in range(self.n_steps):
            # Draw the perturbation before evaluating the energy, so the
            # per-step order of random draws is fixed.
            perturbation = torch.randn_like(current) * self.sigma
            # Gradients are needed here even if the caller disabled autograd.
            with torch.enable_grad():
                current.requires_grad_(True)
                total_energy = self.score_net.energy(current).sum()
                (grad_energy,) = torch.autograd.grad(total_energy, current)
            # Detach so the graph does not grow across iterations.
            current = current.detach() - self.step_size * grad_energy + perturbation
        return current
# Start 32 chains of dimension 784 from standard-normal noise and run the
# sampler for 100 steps.
initial = torch.randn(32, 784)
sampler = LangevinDynamics(score_net, n_steps=100, step_size=0.01)
samples = sampler.sample(initial)
## Best Practices
- Use denoising score matching for stable training
- Apply annealed Langevin dynamics for better samples
- Monitor energy values during training
- Use multiple noise scales for robust score estimation
- Combine with diffusion models for state-of-the-art results