Mixture of Experts


Mixture of Experts (MoE)

Mixture of Experts is an architecture that conditionally routes each input to a subset of its parameters, so a model can grow its total parameter count while keeping per-token computation roughly fixed. As a result, MoE models typically achieve better quality per FLOP than comparable dense models.

The MoE Intuition

In an MoE layer, several "expert" neural networks are available, and a learned "gating" network (the router) decides which experts process each input. Only a subset of the experts is activated per input, so computation is sparse.

The key insight: not every input needs every parameter. By learning to route different inputs to different experts, an MoE model can hold far more parameters than a dense model while spending a similar amount of computation per token.

Gating Function

The gating function determines expert selection:

Gating Network

$$G(x) = \text{softmax}(W_g \cdot x + b_g)$$

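Here,

  • x = the input token representation (a d-dimensional vector)
  • W_g = the learned gating weight matrix (N × d), producing one logit per expert
  • b_g = the gating bias (often omitted in practice)
  • G(x) = the vector of N gate probabilities g_1(x), ..., g_N(x)

The layer output is the gate-weighted sum of the expert outputs, where E_i is the i-th expert network and N is the number of experts:

$$y = \sum_{i=1}^{N} g_i(x) \cdot E_i(x)$$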
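To make the two formulas concrete, here is a minimal NumPy sketch of dense (soft) gating, assuming each expert E_i is a plain linear map. The names `soft_moe`, `W_g`, `b_g`, and `experts` are illustrative, not taken from any library. Every expert runs on every input here, which is exactly the cost that top-k routing avoids.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    # Subtract the max for numerical stability
    e = np.exp(z - z.max())
    return e / e.sum()

def soft_moe(x, W_g, b_g, expert_weights):
    """Dense MoE: y = sum_i g_i(x) * E_i(x), with E_i(x) = expert_weights[i] @ x."""
    g = softmax(W_g @ x + b_g)                                   # G(x), shape (N,)
    expert_outputs = np.stack([E @ x for E in expert_weights])   # (N, d_out)
    return g @ expert_outputs, g                                 # weighted sum over experts

rng = np.random.default_rng(0)
d, d_out, N = 4, 3, 5
x = rng.normal(size=d)
W_g, b_g = rng.normal(size=(N, d)), np.zeros(N)
experts = [rng.normal(size=(d_out, d)) for _ in range(N)]

y, g = soft_moe(x, W_g, b_g, experts)
print(y.shape, round(g.sum(), 6))  # (3,) 1.0
```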

In practice, only the top-k experts are activated:

Top-k Routing

$$g_i(x) = \begin{cases} \frac{\exp((W_g \cdot x)_i)}{\sum_{j \in \text{Top-k}} \exp((W_g \cdot x)_j)} & \text{if } i \in \text{Top-k} \\ 0 & \text{otherwise} \end{cases}$$

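Here,

  • Top-k = the set of indices of the k largest gating logits (W_g · x)_i
  • k = the number of experts activated per token (e.g., 2 in Mixtral)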
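The Top-k formula applies the softmax only over the selected logits. Because the shared softmax denominator cancels, this is the same as taking the full softmax, keeping the k largest probabilities, and renormalizing them, which is how the PyTorch code in this lesson computes it. A small NumPy check (function names are illustrative):

```python
import numpy as np

def topk_gate_from_logits(logits: np.ndarray, k: int) -> np.ndarray:
    """g_i = softmax over the top-k logits; 0 for all other experts."""
    top = np.argsort(logits)[-k:]               # indices of the k largest logits
    g = np.zeros_like(logits)
    e = np.exp(logits[top] - logits[top].max())
    g[top] = e / e.sum()
    return g

def topk_gate_renormalized(logits: np.ndarray, k: int) -> np.ndarray:
    """Full softmax, keep the k largest probabilities, renormalize to sum to 1."""
    p = np.exp(logits - logits.max())
    p /= p.sum()
    top = np.argsort(p)[-k:]
    g = np.zeros_like(p)
    g[top] = p[top] / p[top].sum()
    return g

logits = np.array([1.2, -0.3, 2.5, 0.7, 0.1, 1.9, -1.0, 0.4])  # 8 experts
g1 = topk_gate_from_logits(logits, k=2)
g2 = topk_gate_renormalized(logits, k=2)
print(np.nonzero(g1)[0])     # experts 2 and 5 are selected
print(np.allclose(g1, g2))   # True
```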

Load Balancing

Without explicit balancing, the gating network may collapse to using only a few experts.

The mechanism is self-reinforcing: experts that happen to receive more tokens early in training get more gradient updates, improve faster, and are then selected even more often. This "rich get richer" phenomenon leaves the remaining experts undertrained, so much of the model's capacity goes unused.

Load Balancing Loss

$$\mathcal{L}_{\text{balance}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

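Here,

  • α = a small coefficient weighting the auxiliary loss (e.g., 0.01)
  • N = the number of experts
  • f_i = the fraction of tokens dispatched to expert i (from the hard top-k assignment)
  • P_i = the average router probability assigned to expert i over the batch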
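A quick NumPy computation shows the range of this loss, following the definitions above (the function name and the toy inputs are illustrative). With perfectly uniform routing, f_i = P_i = 1/N and the loss equals α; if every token goes to one expert with probability close to 1, it approaches α · N.

```python
import numpy as np

def balance_loss(gate_probs: np.ndarray, top_k_indices: np.ndarray, alpha: float = 0.01) -> float:
    """alpha * N * sum_i f_i * P_i for gate_probs (T, N) and top_k_indices (T, k)."""
    T, N = gate_probs.shape
    # f_i: fraction of token-to-expert assignments that went to expert i
    counts = np.bincount(top_k_indices.ravel(), minlength=N)
    f = counts / top_k_indices.size
    # P_i: mean router probability for expert i
    P = gate_probs.mean(axis=0)
    return alpha * N * float(np.sum(f * P))

N, T = 8, 64
# Uniform: equal probabilities, tokens assigned round-robin to experts
uniform_probs = np.full((T, N), 1.0 / N)
uniform_idx = (np.arange(T) % N).reshape(T, 1)
# Collapsed: every token puts (almost) all probability on expert 0
collapsed_probs = np.full((T, N), 1e-6)
collapsed_probs[:, 0] = 1.0 - (N - 1) * 1e-6
collapsed_idx = np.zeros((T, 1), dtype=int)

print(balance_loss(uniform_probs, uniform_idx))      # ≈ 0.01 (alpha)
print(balance_loss(collapsed_probs, collapsed_idx))  # ≈ 0.08 (alpha * N)
```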
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        balance_loss_coeff: float = 0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.balance_loss_coeff = balance_loss_coeff
        self.output_dim = output_dim
        
        # Expert networks: independent two-layer MLPs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_experts)
        ])
        
        # Gating network (router): one logit per expert
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
    
    def compute_load_balance_loss(
        self, gate_probs: torch.Tensor, top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """Auxiliary loss: alpha * N * sum_i f_i * P_i (computed per batch)."""
        num_tokens = gate_probs.size(0)
        # f_i: fraction of (token, slot) assignments dispatched to expert i.
        # It comes from the hard top-k choice, so it carries no gradient.
        one_hot = F.one_hot(top_k_indices, self.num_experts).float()  # (T, k, N)
        f = one_hot.sum(dim=(0, 1)) / (num_tokens * self.top_k)
        
        # P_i: average gating probability for each expert (differentiable)
        P = gate_probs.mean(dim=0)
        
        # Equals alpha when routing is perfectly uniform (f_i = P_i = 1/N)
        return self.balance_loss_coeff * self.num_experts * (f * P).sum()
    
    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)  # (batch*seq, input_dim)
        
        # Compute gating scores
        gate_logits = self.gate(x_flat)  # (batch*seq, num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)
        
        # Select top-k experts and renormalize their weights to sum to 1
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        # Compute load balance loss
        balance_loss = self.compute_load_balance_loss(gate_probs, top_k_indices)
        
        # Route tokens to experts (output_dim may differ from input_dim)
        output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]  # (batch*seq,)
            expert_weights = top_k_probs[:, k]  # (batch*seq,)
            
            for i in range(self.num_experts):
                mask = (expert_indices == i)
                if mask.any():
                    # Each expert only processes the tokens routed to it
                    expert_output = self.experts[i](x_flat[mask])
                    output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
        
        output = output.view(batch_size, seq_len, -1)
        return output, balance_loss

Mixtral Architecture

Mixtral (by Mistral AI) is a prominent MoE model with 8 experts and top-2 routing:

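Mixtral 8x7B has 46.7B total parameters but uses only ~12.9B of them per token, since 2 of the 8 experts in each layer are active. (The total is below 8 × 7B because only the feed-forward blocks are replicated; attention and embedding weights are shared.) Mistral AI reports that it matches or outperforms LLaMA-2 70B on most benchmarks with roughly 6x faster inference.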

Mixtral Architecture Details

| Component | Specification |
|---|---|
| Total parameters | 46.7B |
| Active parameters | ~12.9B |
| Experts | 8 per layer |
| Top-k | 2 |
| Expert FFN hidden size | 14336 |
| Hidden size | 4096 |
| Layers | 32 |
| Attention heads | 32 |
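The total and active counts in the table can be reproduced with back-of-the-envelope arithmetic. This sketch assumes details of Mixtral's published configuration beyond what the table lists: SwiGLU experts with three bias-free weight matrices, grouped-query attention with 8 key-value heads of size 128, a 32,000-token vocabulary, and separate (untied) input and output embeddings.

```python
# Mixtral 8x7B configuration; vocab size, KV heads, and head size are
# assumptions beyond the table above
dim, ffn_hidden, n_layers = 4096, 14336, 32
n_experts, top_k = 8, 2
n_heads, n_kv_heads, head_dim = 32, 8, 128
vocab = 32000

expert = 3 * dim * ffn_hidden                     # SwiGLU: w1, w2, w3
attention = (2 * dim * n_heads * head_dim         # query and output projections
             + 2 * dim * n_kv_heads * head_dim)   # key and value projections
router = dim * n_experts
norms = 2 * dim                                   # two norm weight vectors per layer
embeddings = 2 * vocab * dim + dim                # input embedding, output head, final norm

shared = n_layers * (attention + router + norms) + embeddings
total = shared + n_layers * n_experts * expert
active = shared + n_layers * top_k * expert

print(f"total:  {total / 1e9:.2f}B")   # total:  46.70B
print(f"active: {active / 1e9:.2f}B")  # active: 12.88B
```

Nearly all of the parameters live in the experts, which is why activating 2 of 8 cuts per-token compute to roughly a quarter of the total.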
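class MixtralExpert(nn.Module):
    """SwiGLU feed-forward expert as used in Mixtral: w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class MixtralBlock(nn.Module):
    # Simplified sketch: real Mixtral uses RMSNorm, rotary position
    # embeddings, and grouped-query attention, which are omitted here.
    def __init__(self, dim: int = 4096, hidden_dim: int = 14336,
                 num_heads: int = 32, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        
        # MoE FFN: SwiGLU experts with hidden size 14336, as in the table
        self.experts = nn.ModuleList([
            MixtralExpert(dim, hidden_dim) for _ in range(num_experts)
        ])
        
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k
    
    def forward(self, x: torch.Tensor):
        # Causal mask for a decoder: True marks positions that may not be attended
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        
        # Self-attention with residual
        residual = x
        x = self.norm1(x)
        x, _ = self.attention(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = residual + x
        
        # MoE FFN with residual
        residual = x
        x = self.norm2(x)
        
        # Gating: softmax over all experts, keep top-k, renormalize
        gate_scores = F.softmax(self.gate(x), dim=-1)  # (batch, seq, num_experts)
        top_k_probs, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        # Expert computation: each expert only sees the tokens routed to it
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            for i, expert in enumerate(self.experts):
                mask = (top_k_indices[:, :, k] == i)  # (batch, seq)
                if mask.any():
                    expert_output = expert(x[mask])  # (num_selected, dim)
                    output[mask] += top_k_probs[:, :, k][mask].unsqueeze(-1) * expert_output
        
        x = residual + output
        return x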

DeepSeek-MoE

DeepSeek-MoE introduces finer-grained expert specialization:

DeepSeek-MoE splits experts more finely, using many more, smaller experts (for example, 64 routed experts per layer in DeepSeekMoE 16B and 160 in DeepSeek-V2). It also adds "shared experts" that are always activated, alongside "routed experts" that are selected by the gating network.
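Written as an equation, roughly following the DeepSeekMoE paper's formulation (the notation is adapted to this lesson), a token representation u_t passes through N_s shared experts and N_r routed experts, with the leading u_t being the residual connection:

$$h_t = u_t + \sum_{i=1}^{N_s} \text{FFN}^{\text{shared}}_i(u_t) + \sum_{i=1}^{N_r} g_{i,t} \, \text{FFN}^{\text{routed}}_i(u_t)$$

Here g_{i,t} is the gate value for routed expert i, nonzero only for the K routed experts selected for token t. The shared experts carry no gate weight; their outputs are simply added.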

DeepSeek-MoE Design Principles

  1. Finer-grained experts: More, smaller experts for better specialization
  2. Shared experts: Always-active experts for common knowledge
  3. Routed experts: Conditionally activated for specialized knowledge
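The effect of finer-grained experts can be seen with simple counting. The sketch assumes a hypothetical baseline of 16 experts with top-2 routing, each split into 4 smaller experts (64 experts, top-8); the numbers are illustrative, not a specific DeepSeek configuration. Per-token compute stays the same, but the number of possible expert combinations grows by many orders of magnitude, giving the router far more room to specialize.

```python
from math import comb

h = 4096  # hypothetical expert FFN hidden size
coarse = {"experts": 16, "top_k": 2, "hidden": h}
# Fine-grained: split each expert into m smaller ones and activate m times as many
m = 4
fine = {"experts": 16 * m, "top_k": 2 * m, "hidden": h // m}

for name, cfg in [("coarse", coarse), ("fine", fine)]:
    combos = comb(cfg["experts"], cfg["top_k"])
    active_hidden = cfg["top_k"] * cfg["hidden"]  # proxy for per-token FFN compute
    print(f"{name}: {combos:,} expert combinations, active hidden units = {active_hidden}")
# coarse: 120 expert combinations, active hidden units = 8192
# fine: 4,426,165,368 expert combinations, active hidden units = 8192
```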
class DeepSeekMoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_shared_experts: int = 2,
        num_routed_experts: int = 64,
        top_k: int = 6
    ):
        super().__init__()
        self.num_shared_experts = num_shared_experts
        self.num_routed_experts = num_routed_experts
        self.top_k = top_k
        
        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_shared_experts)
        ])
        
        # Routed experts (conditionally active)
        self.routed_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_routed_experts)
        ])
        
        # Gating network
        self.gate = nn.Linear(input_dim, num_routed_experts, bias=False)
    
    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)
        
        # Shared experts: always active; outputs are summed (no gating weights),
        # following the DeepSeekMoE formulation. Assumes num_shared_experts >= 1.
        shared_output = sum(expert(x_flat) for expert in self.shared_experts)
        
        # Routed experts: softmax over routed experts, then top-k
        gate_logits = self.gate(x_flat)
        gate_probs = F.softmax(gate_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        # Shaped like shared_output, so output_dim may differ from input_dim
        routed_output = torch.zeros_like(shared_output)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]
            expert_weights = top_k_probs[:, k]
            
            for i in range(self.num_routed_experts):
                mask = (expert_indices == i)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.routed_experts[i](expert_input)
                    routed_output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
        
        # Combine shared and routed outputs
        output = shared_output + routed_output
        return output.view(batch_size, seq_len, -1)

Load Balancing Analysis

Expert Utilization Variance

$$\text{Var}(f) = \frac{1}{N} \sum_{i=1}^{N} (f_i - \bar{f})^2$$

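Here,

  • N = the number of experts
  • f_i = the fraction of tokens routed to expert i
  • f̄ = the mean utilization across experts (1/N when the f_i sum to 1)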

A well-balanced MoE model has low variance in expert utilization:

| Model | Experts | Top-k | Utilization variance |
|---|---|---|---|
| Mixtral 8x7B | 8 | 2 | 0.002 |
| Switch Transformer | 128 | 1 | 0.015 |
| GShard | 2048 | 2 | 0.008 |
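A minimal NumPy sketch of this metric, computed from a list of expert assignments (the function name is illustrative). Since f̄ = 1/N, the largest possible variance is (N - 1)/N^2, reached when every token goes to one expert, so the raw variance shrinks as N grows. The coefficient of variation (standard deviation divided by the mean) is a scale-free alternative when comparing models with different numbers of experts.

```python
import numpy as np

def utilization_stats(expert_indices: np.ndarray, num_experts: int):
    """Return (variance, coefficient of variation) of per-expert token fractions."""
    counts = np.bincount(expert_indices.ravel(), minlength=num_experts)
    f = counts / counts.sum()                   # f_i, sums to 1
    var = float(np.mean((f - f.mean()) ** 2))   # Var(f) with f_bar = 1/N
    cv = float(np.sqrt(var) / f.mean())         # scale-free imbalance measure
    return var, cv

balanced = np.arange(800) % 8          # 100 tokens per expert
collapsed = np.zeros(800, dtype=int)   # every token goes to expert 0

print(utilization_stats(balanced, 8))   # (0.0, 0.0)
print(utilization_stats(collapsed, 8))  # variance is (N - 1)/N^2 = 0.109375
```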

Practical: Deploying MoE Models

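import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class MoEInferenceEngine:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = self.load_moe_model(model_name)
    
    def load_moe_model(self, model_name: str):
        """Load an MoE model, letting Accelerate spread it across devices."""
        # device_map="auto" places whole layers on GPU(s), then CPU/disk if they
        # don't fit; this is layer-level placement, not per-expert offloading.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16
        )
        model.eval()
        return model
    
    @torch.no_grad()
    def generate_with_expert_stats(self, prompt: str, max_tokens: int = 100):
        """Greedy generation that counts which experts the last layer selects."""
        # Mixtral-style config names; other MoE architectures may use different ones
        num_experts = self.model.config.num_local_experts
        top_k = self.model.config.num_experts_per_tok
        expert_counts = {i: 0 for i in range(num_experts)}
        
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        generated = input_ids
        past_key_values = None
        
        for _ in range(max_tokens):
            outputs = self.model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True,
                output_router_logits=True,
            )
            past_key_values = outputs.past_key_values
            
            # router_logits: one tensor per layer, shape (batch*seq, num_experts).
            # Count the top-k experts chosen for the newest token in the last layer.
            last_layer_logits = outputs.router_logits[-1][-1]
            for idx in torch.topk(last_layer_logits, top_k).indices.tolist():
                expert_counts[idx] += 1
            
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            input_ids = next_token  # with the KV cache, feed only the new token
        
        return {
            "text": self.tokenizer.decode(generated[0], skip_special_tokens=True),
            "expert_utilization": expert_counts
        }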

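When deploying MoE models, make sure all experts fit in memory or use expert offloading: every expert must be available even though only a few run per token. Memory therefore scales with total parameters, while compute scales with active parameters. Mixtral 8x7B, for example, needs room for all 46.7B parameters but computes roughly like a 12.9B dense model per token, which is why it is typically faster than a dense model of similar quality.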
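A rough weights-only memory estimate makes this trade-off concrete. The sketch counts only parameter storage (it ignores the KV cache, activations, and quantization overhead such as scales), uses nominal bytes per parameter for each precision, and approximates forward-pass compute as about 2 FLOPs per active parameter per token.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Weights-only memory in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

total_params, active_params = 46.7e9, 12.9e9  # Mixtral 8x7B, from the table above

# Memory must hold every expert, so it scales with the total parameter count
for precision, nbytes in [("fp16/bf16", 2), ("int8", 1), ("4-bit", 0.5)]:
    print(f"{precision:>9}: {weight_memory_gb(total_params, nbytes):6.1f} GB of weights")

# Per-token compute scales with the active parameters only
print(f"~{2 * active_params / 1e9:.1f} GFLOPs per token")
```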

Summary

  • MoE models route inputs to a subset of expert networks via a gating function
  • Top-k routing activates only k experts per input, enabling sparse computation
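  • Load balancing loss prevents expert collapse: L_balance = α · N · Σ f_i · P_i
  • Mixtral 8x7B matches or outperforms LLaMA-2 70B on most benchmarks with roughly 6x faster inference, as reported by Mistral AI
  • DeepSeek-MoE uses shared + routed experts for better specialization
  • MoE models need memory for all of their parameters but compute only with the active ones, so they use more memory than a dense model of equal per-token compute while running faster than a dense model of equal total size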
  • Expert utilization variance measures load balancing quality

Practice Exercises

  1. Gating Analysis: Visualize the gating network's expert selection patterns. Do different experts specialize on different types of inputs?

  2. Load Balancing: Train an MoE model with and without load balancing loss. Compare expert utilization.

  3. Expert Specialization: Analyze what each expert learns. Do experts specialize on different linguistic phenomena?

  4. MoE vs Dense: Compare an MoE model with a dense model of equal computational budget. Which achieves better performance?

  5. Deployment Optimization: Implement expert offloading to reduce memory usage. Measure the impact on inference speed.


