Mixture of Experts (MoE)
Mixture of Experts (MoE) is an architecture that conditionally routes each input to a subset of the model's parameters, so the total parameter count can grow while per-token computation stays roughly constant. In practice this often gives better performance per FLOP than a dense model.
The MoE Intuition
An MoE layer is an architecture in which multiple "expert" neural networks process different inputs, with a learned "gating" network determining which experts process each input. Only a subset of experts is activated per input, enabling sparse computation.
The key insight: not every input needs every parameter. By learning to route different inputs to different experts, MoE models achieve better performance with less computation per token.
Gating Function
The gating function determines expert selection:
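Gating Network

$$G(x) = \mathrm{softmax}(W_g x)$$

Here,
- $x$ = input token representation
- $W_g$ = learned gating weight matrix (one row per expert)
- $G(x)_i$ = gating probability assigned to expert $i$
- $N$ = number of experts (the length of $G(x)$)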
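The gating step can be sketched in plain Python, assuming the common linear-plus-softmax form G(x) = softmax(W_g x). The helper name `gating_probs` and the toy values for `x` and `W_g` are illustrative assumptions, not part of any library.

```python
import math

def gating_probs(x, W_g):
    """Return softmax(W_g x): one probability per expert (one row of W_g per expert)."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in W_g]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: a 3-dimensional token and 4 experts (hypothetical values)
x = [1.0, -0.5, 2.0]
W_g = [
    [0.2, 0.1, 0.4],
    [-0.3, 0.5, 0.1],
    [0.0, 0.0, 0.0],
    [0.6, -0.2, 0.3],
]
probs = gating_probs(x, W_g)
best_expert = probs.index(max(probs))
print(probs, best_expert)
```

The probabilities sum to 1, and the expert whose weight row aligns best with the token receives the largest share.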
In practice, only the top-k experts are activated:
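Top-k Routing

$$y = \sum_{i \in \mathrm{TopK}(G(x),\,k)} \tilde{G}(x)_i \, E_i(x), \qquad \tilde{G}(x)_i = \frac{G(x)_i}{\sum_{j \in \mathrm{TopK}(G(x),\,k)} G(x)_j}$$

Here,
- $E_i(x)$ = output of expert $i$ for input $x$
- $\mathrm{TopK}(G(x),\,k)$ = indices of the $k$ experts with the largest gating probabilities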
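A minimal sketch of top-k routing, assuming the selected probabilities are renormalized to sum to 1 before the expert outputs are combined. The function names and the scalar toy "experts" are hypothetical and only serve to make the arithmetic visible.

```python
def top_k_route(probs, k):
    """Pick the k largest gating probabilities and renormalize them to sum to 1."""
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)
    return {i: probs[i] / norm for i in top}

def moe_output(x, probs, experts, k):
    """Weighted sum of the selected experts' outputs; unselected experts never run."""
    weights = top_k_route(probs, k)
    return sum(w * experts[i](x) for i, w in weights.items())

# Toy experts that scale a scalar input (illustrative only)
experts = [lambda v: 1 * v, lambda v: 2 * v, lambda v: 3 * v, lambda v: 4 * v]
probs = [0.1, 0.5, 0.1, 0.3]

weights = top_k_route(probs, k=2)   # experts 1 and 3 are selected
y = moe_output(2.0, probs, experts, k=2)
print(weights, y)
```

With k = 2, only experts 1 and 3 are evaluated, with weights 0.5/0.8 and 0.3/0.8; the other two experts contribute no computation.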
Load Balancing
Without an explicit balancing term, the gating network tends to collapse onto a few experts: experts that happen to be selected early receive more gradient updates, improve faster, and are then selected even more often. This "rich get richer" phenomenon leaves most of the model's capacity underutilized.
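Load Balancing Loss

$$\mathcal{L}_{\text{balance}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

Here,
- $\alpha$ = balancing coefficient (`balance_loss_coeff` in the code)
- $N$ = number of experts
- $f_i$ = fraction of tokens routed to expert $i$
- $P_i$ = average gating probability for expert $i$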
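To see how the loss L_balance = α · N · Σ f_i · P_i separates balanced from collapsed routing, here is a small pure-Python sketch. It assumes f_i is normalized over all routing assignments so that it sums to 1; the function name and the toy batches are illustrative.

```python
def load_balance_loss(assignments, gate_probs, num_experts, alpha=0.01):
    """alpha * N * sum_i f_i * P_i for one batch.

    assignments: per-token lists of selected expert indices (the top-k choices)
    gate_probs:  per-token lists of gating probabilities over all experts
    """
    total = sum(len(chosen) for chosen in assignments)
    f = [0.0] * num_experts
    for chosen in assignments:
        for i in chosen:
            f[i] += 1.0 / total  # fraction of assignments sent to expert i
    num_tokens = len(gate_probs)
    P = [sum(p[i] for p in gate_probs) / num_tokens for i in range(num_experts)]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))

# Balanced: 4 tokens, each routed to a different expert, uniform gate probabilities
balanced = load_balance_loss(
    assignments=[[0], [1], [2], [3]],
    gate_probs=[[0.25] * 4 for _ in range(4)],
    num_experts=4,
)

# Collapsed: every token routed to expert 0 with probability 1
collapsed = load_balance_loss(
    assignments=[[0]] * 4,
    gate_probs=[[1.0, 0.0, 0.0, 0.0] for _ in range(4)],
    num_experts=4,
)
print(balanced, collapsed)
```

Under uniform routing the loss equals α, its minimum; under full collapse it grows to α · N, so minimizing it pushes the gate toward an even spread.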
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        balance_loss_coeff: float = 0.01,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.balance_loss_coeff = balance_loss_coeff

        # Expert networks: independent two-layer feed-forward blocks
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim),
            ) for _ in range(num_experts)
        ])

        # Gating network: one logit per expert
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def compute_load_balance_loss(
        self, gate_probs: torch.Tensor, top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """Compute alpha * N * sum_i(f_i * P_i) over the current batch."""
        # f_i: fraction of routing assignments dispatched to expert i (hard counts)
        one_hot = F.one_hot(top_k_indices, self.num_experts).float()  # (tokens, top_k, N)
        f = one_hot.mean(dim=(0, 1))
        # P_i: average gating probability for expert i (carries the gradient)
        P = gate_probs.mean(dim=0)
        return self.balance_loss_coeff * self.num_experts * (f * P).sum()

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)

        # Compute gating scores
        gate_logits = self.gate(x_flat)  # (batch*seq, num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)

        # Select top-k experts and renormalize their weights to sum to 1
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Compute load balance loss
        balance_loss = self.compute_load_balance_loss(gate_probs, top_k_indices)

        # Route tokens to experts; the buffer uses output_dim, which may differ from input_dim
        output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]  # (batch*seq,)
            expert_weights = top_k_probs[:, k]    # (batch*seq,)
            for i in range(self.num_experts):
                mask = (expert_indices == i)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[i](expert_input)
                    output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output

        output = output.view(batch_size, seq_len, self.output_dim)
        return output, balance_loss
```
Mixtral Architecture
Mixtral (by Mistral AI) is a prominent MoE model with 8 experts and top-2 routing:
Mixtral 8x7B has 46.7B total parameters but uses only ~12.9B of them per token, because 2 of the 8 experts are active in each layer. Mistral AI reports performance comparable to LLaMA-2 70B with roughly 6x faster inference.
Mixtral Architecture Details
| Component | Specification |
|---|---|
| Total Parameters | 46.7B |
| Active Parameters | ~12.9B |
| Experts | 8 per layer |
| Top-k | 2 |
| Expert FFN Size | 14336 |
| Hidden Size | 4096 |
| Layers | 32 |
| Attention Heads | 32 |
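The totals in the table can be roughly reproduced from the per-layer dimensions. The sketch below is back-of-the-envelope arithmetic and relies on details that are not in the table and are assumptions here, taken from Mixtral's public configuration: a 32,000-token vocabulary, grouped-query attention with 8 key/value heads, three weight matrices per expert (SwiGLU), and untied input/output embeddings.

```python
# Values from the specification table above
hidden, ffn, layers, experts, top_k, heads = 4096, 14336, 32, 8, 2, 32

# Assumptions NOT in the table (from Mixtral's public configuration)
vocab = 32000        # vocabulary size
kv_heads = 8         # grouped-query attention: fewer K/V heads than query heads
mats_per_expert = 3  # SwiGLU experts use gate, up, and down projections

head_dim = hidden // heads
expert = mats_per_expert * hidden * ffn
attention = 2 * hidden * hidden + 2 * hidden * kv_heads * head_dim  # (q, o) + (k, v)
router = hidden * experts
norms = 2 * hidden                        # two norm weight vectors per layer
embeddings = 2 * vocab * hidden + hidden  # input embedding, output head, final norm

def param_count(experts_counted: int) -> int:
    """Parameters when `experts_counted` experts per layer are included."""
    per_layer = attention + router + norms + experts_counted * expert
    return layers * per_layer + embeddings

total_params = param_count(experts)  # every expert must be stored
active_params = param_count(top_k)   # only top-k experts run per token
print(f"total:  {total_params / 1e9:.1f}B")
print(f"active: {active_params / 1e9:.1f}B")
```

This also shows why "8x7B" does not mean 56B parameters: only the feed-forward experts are replicated, while attention, embeddings, and norms are shared.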
```python
class MixtralBlock(nn.Module):
    """Simplified Mixtral-style block for illustration.

    The released model uses RMSNorm, grouped-query attention with a causal mask,
    and SwiGLU experts with an FFN size of 14336; this sketch substitutes
    LayerNorm, standard multi-head attention, and a plain 4x feed-forward network.
    """

    def __init__(self, dim: int = 4096, num_experts: int = 8, top_k: int = 2,
                 num_heads: int = 32):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # MoE FFN
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, dim * 4),
                nn.SiLU(),
                nn.Linear(dim * 4, dim),
            ) for _ in range(num_experts)
        ])
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        # Self-attention with residual (pre-norm)
        residual = x
        x = self.norm1(x)
        x, _ = self.attention(x, x, x)
        x = residual + x

        # MoE FFN with residual
        residual = x
        x = self.norm2(x)

        # Gating: (batch, seq, num_experts) -> top-k weights renormalized to sum to 1
        gate_scores = F.softmax(self.gate(x), dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Expert computation: each expert only sees the tokens routed to it
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            weights_k = top_k_probs[:, :, k].unsqueeze(-1)  # (batch, seq, 1)
            for i in range(len(self.experts)):
                mask = (top_k_indices[:, :, k] == i)        # (batch, seq)
                if mask.any():
                    expert_input = x[mask]
                    expert_output = self.experts[i](expert_input)
                    output[mask] += weights_k[mask] * expert_output

        x = residual + output
        return x
```
DeepSeek-MoE
DeepSeek-MoE introduces finer-grained expert specialization:
DeepSeek-MoE uses more experts (up to 160) with finer granularity and implements "shared experts" that are always activated, combined with "routed experts" that are selected by the gating network.
DeepSeek-MoE Design Principles
- Finer-grained experts: More, smaller experts for better specialization
- Shared experts: Always-active experts for common knowledge
- Routed experts: Conditionally activated for specialized knowledge
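One way to quantify the benefit of finer-grained experts is to count how many distinct expert combinations a token can be routed to. The configurations below are illustrative assumptions: splitting each of 16 experts into 4 smaller ones (64 in total) and activating 4x as many keeps per-token compute roughly the same while greatly enlarging the set of possible combinations.

```python
from math import comb

# Coarse-grained: 16 experts, 2 active per token
coarse = comb(16, 2)

# Fine-grained: each expert split into 4 (64 experts, each 1/4 the size), 8 active per token
fine = comb(64, 8)

# The default configuration of DeepSeekMoELayer in this section: 64 routed experts, top-6
default_layer = comb(64, 6)

print(f"16 choose 2 = {coarse:,}")
print(f"64 choose 8 = {fine:,}")
print(f"64 choose 6 = {default_layer:,}")
```

More possible combinations give the router more freedom to assemble specialized mixtures for each token, which is the intuition behind the first design principle.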
```python
class DeepSeekMoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_shared_experts: int = 2,
        num_routed_experts: int = 64,
        top_k: int = 6,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.num_shared_experts = num_shared_experts
        self.num_routed_experts = num_routed_experts
        self.top_k = top_k

        def make_expert() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim),
            )

        # Shared experts (always active)
        self.shared_experts = nn.ModuleList(
            [make_expert() for _ in range(num_shared_experts)]
        )
        # Routed experts (conditionally active)
        self.routed_experts = nn.ModuleList(
            [make_expert() for _ in range(num_routed_experts)]
        )
        # Gating network scores only the routed experts
        self.gate = nn.Linear(input_dim, num_routed_experts, bias=False)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)
        num_tokens = x_flat.size(0)

        # Shared expert output: every token passes through every shared expert.
        # This sketch averages the shared outputs; summing them is another common choice.
        shared_output = x_flat.new_zeros(num_tokens, self.output_dim)
        for expert in self.shared_experts:
            shared_output += expert(x_flat)
        shared_output /= self.num_shared_experts

        # Gated expert output: top-k routed experts, weights renormalized to sum to 1
        gate_logits = self.gate(x_flat)
        gate_probs = F.softmax(gate_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        routed_output = x_flat.new_zeros(num_tokens, self.output_dim)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]
            expert_weights = top_k_probs[:, k]
            for i in range(self.num_routed_experts):
                mask = (expert_indices == i)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.routed_experts[i](expert_input)
                    routed_output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output

        # Combine shared and routed outputs
        output = shared_output + routed_output
        return output.view(batch_size, seq_len, self.output_dim)
```
Load Balancing Analysis
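Expert Utilization Variance

$$\mathrm{Var}(u) = \frac{1}{N} \sum_{i=1}^{N} \left(u_i - \bar{u}\right)^2$$

Here,
- $u_i$ = fraction of routed tokens handled by expert $i$
- $\bar{u}$ = mean utilization across experts ($1/N$ when the $u_i$ sum to 1)
- $N$ = number of experts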
A well-balanced MoE model has low variance in expert utilization:
| Model | Experts | Top-k | Utilization Variance |
|---|---|---|---|
| Mixtral 8x7B | 8 | 2 | 0.002 |
| Switch Transformer | 128 | 1 | 0.015 |
| GShard | 2048 | 2 | 0.008 |
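Utilization variance is straightforward to compute from per-expert token counts. The sketch below assumes utilization is the fraction of all routing assignments each expert receives; the counts are made-up illustrations, not measurements from any of the models in the table.

```python
def utilization_variance(counts):
    """Variance of per-expert utilization fractions; 0.0 means perfectly balanced."""
    total = sum(counts)
    n = len(counts)
    utilization = [c / total for c in counts]
    mean = 1.0 / n  # the fractions sum to 1, so their mean is 1/N
    return sum((u - mean) ** 2 for u in utilization) / n

balanced_var = utilization_variance([25, 25, 25, 25])  # every expert used equally
skewed_var = utilization_variance([70, 10, 10, 10])    # one expert dominates
print(balanced_var, skewed_var)
```

Because the maximum possible variance depends on the number of experts, values are easiest to compare between models with the same expert count.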
Practical: Deploying MoE Models
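```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class MoEInferenceEngine:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = self.load_moe_model(model_name)
        self.expert_cache = {}  # placeholder for an offloading cache; unused in this sketch

    def load_moe_model(self, model_name: str):
        """Load an MoE model; device_map="auto" places weights across available devices."""
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        model.eval()
        return model

    def generate_with_expert_stats(self, prompt: str, max_tokens: int = 100):
        """Greedily generate text while counting last-layer expert selections."""
        # Assumes a Mixtral-style config; falls back to 8 experts with top-2 routing
        num_experts = getattr(self.model.config, "num_local_experts", 8)
        top_k = getattr(self.model.config, "num_experts_per_tok", 2)
        expert_counts = {i: 0 for i in range(num_experts)}

        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)

        for _ in range(max_tokens):
            with torch.no_grad():
                # No KV cache: the full sequence is re-run every step, for simplicity
                outputs = self.model(input_ids, output_router_logits=True)

            # Track expert usage for the newest position only, so that earlier
            # positions are not counted again on every step (assumes batch size 1)
            router_logits = getattr(outputs, "router_logits", None)
            if router_logits is not None:
                last_layer = router_logits[-1]
                newest = last_layer.reshape(-1, last_layer.size(-1))[-1]
                for idx in torch.topk(newest, top_k).indices:
                    expert_counts[idx.item()] += 1

            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=-1)
            if next_token.item() == self.tokenizer.eos_token_id:
                break

        return {
            "text": self.tokenizer.decode(input_ids[0], skip_special_tokens=True),
            "expert_utilization": expert_counts,
        }
```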
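When deploying MoE models, ensure all experts fit in memory or use expert offloading. Memory scales with the total parameter count, which can be several times larger than that of a dense model with the same per-token compute (about 3.6x for Mixtral 8x7B: 46.7B total vs. ~12.9B active). Inference is nevertheless fast relative to a dense model of similar quality, because only a subset of experts is active per token.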
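A quick estimate of weight memory makes the trade-off concrete. This sketch counts weights only, ignoring the KV cache, activations, and framework overhead, and uses the Mixtral 8x7B parameter counts quoted earlier; the helper name `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory needed to hold the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

total_params = 46.7e9   # all experts must be resident (or offloaded)
active_params = 12.9e9  # parameters actually used per token

fp16_total = weight_memory_gb(total_params, 2)    # 2 bytes per parameter in float16
fp16_active = weight_memory_gb(active_params, 2)
int8_total = weight_memory_gb(total_params, 1)    # 1 byte per parameter if quantized to 8 bits

print(f"fp16, all experts:    {fp16_total:.1f} GB")
print(f"fp16, active per tok: {fp16_active:.1f} GB")
print(f"int8, all experts:    {int8_total:.1f} GB")
```

The gap between the first two numbers is what expert offloading tries to exploit: only the active experts need to be on the accelerator at any given moment, at the cost of transfer latency when the selected experts change.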
Summary
- MoE models route inputs to a subset of expert networks via a gating function
- Top-k routing activates only k experts per input, enabling sparse computation
- Load balancing loss prevents expert collapse: L_balance = α · N · Σ f_i · P_i
- Mixtral 8x7B achieves LLaMA-2 70B quality with 6x faster inference
- DeepSeek-MoE uses shared + routed experts for better specialization
- MoE models must keep all experts in memory, several times more weights than a dense model with the same per-token compute, but run only a subset of experts per token at inference
- Expert utilization variance measures load balancing quality
Practice Exercises
- Gating Analysis: Visualize the gating network's expert selection patterns. Do different experts specialize on different types of inputs?
- Load Balancing: Train an MoE model with and without load balancing loss. Compare expert utilization.
- Expert Specialization: Analyze what each expert learns. Do experts specialize on different linguistic phenomena?
- MoE vs Dense: Compare an MoE model with a dense model of equal computational budget. Which achieves better performance?
- Deployment Optimization: Implement expert offloading to reduce memory usage. Measure the impact on inference speed.