Mixture of Experts
What is Mixture of Experts?
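A Mixture of Experts (MoE) layer contains several expert subnetworks, typically feed-forward blocks, plus a learned gating network (the router). The router sends each input token to only a small subset of the experts, usually the top-k by gate score. Total parameter count grows with the number of experts, but per-token compute grows only with k. This is what makes MoE an efficient way to scale.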
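To put numbers on "efficient scaling", here is a rough parameter count for an MoE layer built like the implementation in this section. Each expert is a d_model → 4·d_model → d_model feed-forward block with biases, and the gate is a single linear layer. The helper names and the example sizes (d_model=512, 8 experts, top-2) are illustrative choices, not taken from any particular model. "Active" counts only the MoE-layer weights a single token touches, and it ignores attention and routing overhead.

```python
def ffn_expert_params(d_model: int, hidden_mult: int = 4) -> int:
    """Parameters in one Linear -> ReLU -> Linear expert (weights + biases)."""
    d_hidden = d_model * hidden_mult
    up = d_model * d_hidden + d_hidden      # Linear(d_model, d_hidden)
    down = d_hidden * d_model + d_model     # Linear(d_hidden, d_model)
    return up + down


def moe_param_budget(d_model: int, num_experts: int, top_k: int, hidden_mult: int = 4):
    """Return (total, active-per-token) parameter counts for one MoE layer."""
    per_expert = ffn_expert_params(d_model, hidden_mult)
    gate = d_model * num_experts + num_experts  # Linear(d_model, num_experts)
    total = num_experts * per_expert + gate     # every expert is stored
    active = top_k * per_expert + gate          # only top_k experts run per token
    return total, active


total, active = moe_param_budget(d_model=512, num_experts=8, top_k=2)
print(total, active)  # 16801800 4203528
```

Doubling `num_experts` roughly doubles `total` while leaving `active` unchanged, apart from the small gate. That is the scaling argument in a single line.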
Implementation
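```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(self, d_model, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Each expert is an independent position-wise feed-forward network.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.ReLU(),
                nn.Linear(d_model * 4, d_model),
            )
            for _ in range(num_experts)
        ])
        # The gate (router) produces one logit per expert for each token.
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        # Flatten to (num_tokens, d_model); reshape also accepts non-contiguous input.
        x_flat = x.reshape(-1, d_model)

        # Routing probabilities over all experts: (num_tokens, num_experts)
        gate_scores = F.softmax(self.gate(x_flat), dim=-1)
        # Keep the k highest-scoring experts per token: both (num_tokens, top_k)
        top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)

        output = torch.zeros_like(x_flat)
        for i, expert in enumerate(self.experts):
            # Tokens that selected expert i, and which of their k slots it occupies.
            token_idx, slot_idx = torch.where(top_k_indices == i)
            if token_idx.numel() == 0:
                continue  # no tokens routed here, so the expert is skipped
            expert_output = expert(x_flat[token_idx])
            weights = top_k_scores[token_idx, slot_idx].unsqueeze(-1)
            # Add the gate-weighted expert output back into each token's row.
            output.index_add_(0, token_idx, expert_output * weights)

        return output.reshape(batch_size, seq_len, d_model)
```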
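The layer above applies softmax over all experts and then keeps the top-k probabilities, so a token's k weights generally sum to less than 1. Some models instead normalize over only the selected experts. Mixtral, for example, applies softmax to the top-k gate logits. Below is a minimal sketch of that variant as a standalone function. The name `topk_softmax_gate` is hypothetical, and swapping it into `MoELayer` would change the layer's outputs. The sketch also checks that this variant is equivalent to renormalizing the full softmax over the chosen experts.

```python
import torch
import torch.nn.functional as F


def topk_softmax_gate(logits: torch.Tensor, k: int):
    """Select the top-k experts per token, then softmax over just those logits.

    logits: (num_tokens, num_experts) raw gate outputs.
    Returns (weights, indices), both (num_tokens, k); each weights row sums to 1.
    """
    top_logits, indices = torch.topk(logits, k, dim=-1)
    weights = F.softmax(top_logits, dim=-1)
    return weights, indices


torch.manual_seed(0)
logits = torch.randn(5, 8)  # 5 tokens, 8 experts
w, idx = topk_softmax_gate(logits, k=2)

# Same result: full softmax, gather the chosen experts, renormalize to sum to 1.
full = F.softmax(logits, dim=-1)
renorm = full.gather(-1, idx)
renorm = renorm / renorm.sum(dim=-1, keepdim=True)
print(torch.allclose(w, renorm, atol=1e-6))  # True
```

Softmax is monotonic, so top-k over logits selects the same experts as top-k over probabilities. Only the weights differ between the two schemes.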
Load Balancing
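```python
import torch
import torch.nn.functional as F


def load_balancing_loss(gate_scores, num_experts):
    """Encourage balanced expert utilization.

    gate_scores: (num_tokens, num_experts) softmax outputs of the gate.
    """
    # Average routing probability assigned to each expert across tokens
    expert_probs = gate_scores.mean(dim=0)
    # Target: uniform distribution over experts
    target = torch.full_like(expert_probs, 1.0 / num_experts)
    # F.kl_div(log_q, p) computes KL(p || q), here KL(uniform || expert_probs).
    # 'sum' gives the true KL for this 1-D input ('batchmean' would divide by
    # num_experts); the clamp guards against log(0).
    return F.kl_div(expert_probs.clamp_min(1e-9).log(), target, reduction='sum')
```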
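The KL term above compares only the average router probabilities with a uniform target. It does not look at which experts tokens are actually dispatched to. The Switch Transformer paper instead uses the auxiliary loss N·Σᵢ fᵢ·Pᵢ, where fᵢ is the fraction of tokens whose top-1 expert is i and Pᵢ is the mean router probability for expert i. This loss equals 1 under perfectly uniform routing and grows as routing concentrates on fewer experts. Below is a minimal sketch for top-1 routing. The function name is hypothetical. In practice, the result is scaled by a small coefficient before being added to the task loss.

```python
import torch
import torch.nn.functional as F


def switch_load_balancing_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Switch-Transformer-style auxiliary loss for top-1 routing.

    router_logits: (num_tokens, num_experts) raw gate outputs.
    """
    num_experts = router_logits.size(-1)
    probs = F.softmax(router_logits, dim=-1)
    # f_i: fraction of tokens dispatched to expert i (argmax, no gradient)
    top1 = probs.argmax(dim=-1)
    f = F.one_hot(top1, num_experts).float().mean(dim=0)
    # P_i: mean router probability for expert i (gradient flows through this)
    p = probs.mean(dim=0)
    return num_experts * torch.sum(f * p)


balanced = switch_load_balancing_loss(torch.eye(4) * 5.0)  # each expert wins one token
skewed = switch_load_balancing_loss(torch.tensor([[10.0, 0.0, 0.0, 0.0]] * 8))  # all pick expert 0
print(balanced.item(), skewed.item())  # roughly 1.0 and close to 4.0
```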
MoE Models
| Model | Experts (per MoE layer) | Active experts per token | Total parameters |
|---|---|---|---|
| Mixtral 8x7B | 8 | 2 | ~47B |
| Switch Transformer | 128 | 1 | Varies by configuration |
| GShard | 2048 | 2 | 600B |
Summary
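MoE decouples model size from per-token compute. Each token passes through only its top-k experts, so total parameters can grow with the number of experts while per-token FLOPs stay close to those of a dense model with k expert-sized feed-forward blocks. Mixtral 8x7B, for example, has about 47B parameters but uses roughly 13B per token. The costs are the memory needed to hold every expert and a load-balancing term to keep the router from collapsing onto a few experts.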
Next: We'll explore state space models.