Pruning for LLMs: Finding What You Can Remove
Not all weights contribute equally to model performance. Pruning removes redundant parameters while preserving quality, often achieving 2-5× compression with minimal accuracy loss. This guide covers unstructured, structured, and semi-structured pruning for LLMs.
- Unstructured Pruning: remove individual weights below a threshold
- Structured Pruning: remove entire neurons, heads, or layers
- Lottery Ticket Hypothesis: find subnetworks that train to full accuracy
- Magnitude vs Gradient Pruning: choose what to prune
The best architecture is the one that doesn't need to exist.
Pruning for LLMs
Neural networks are massively over-parameterized. Pruning exploits this redundancy by removing weights that contribute little to the model's output. For LLMs, pruning can reduce memory usage and inference cost while maintaining most of the model's capabilities.
Definition: Pruning
Pruning is the process of removing parameters from a trained neural network. The goal is to reduce model size and computation while preserving performance. Pruning can be unstructured (removing individual weights) or structured (removing entire units).
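To make the distinction concrete, here is a minimal sketch contrasting the two kinds of mask on a toy 4×4 random matrix at 50% sparsity. The helper names `unstructured_mask` and `structured_row_mask` are illustrative, not from any library. The unstructured mask zeroes individual entries anywhere in the matrix; the structured mask zeroes whole rows (output neurons), which could later be deleted outright.

```python
import torch

def unstructured_mask(w: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero the `sparsity` fraction of entries with the smallest |w|."""
    k = int(sparsity * w.numel())
    if k == 0:
        return torch.ones_like(w)
    threshold = w.abs().flatten().kthvalue(k).values
    # Strict '>' removes exactly k entries when there are no ties
    return (w.abs() > threshold).float()

def structured_row_mask(w: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero whole rows (output neurons) with the smallest L1 norm."""
    n_rows = int(sparsity * w.shape[0])
    mask = torch.ones_like(w)
    if n_rows > 0:
        drop = w.abs().sum(dim=1).topk(n_rows, largest=False).indices
        mask[drop] = 0.0
    return mask

torch.manual_seed(0)
w = torch.randn(4, 4)
m_u = unstructured_mask(w, 0.5)
m_s = structured_row_mask(w, 0.5)
print(m_u)  # scattered zeros: the matrix keeps its 4x4 shape
print(m_s)  # two all-zero rows: those neurons can be removed entirely
```

Both masks remove the same number of weights; only the structured one produces a pattern that can shrink the matrix.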
Unstructured Pruning
Magnitude Pruning
Definition: Magnitude Pruning
Magnitude pruning removes weights with the smallest absolute values, based on the assumption that small weights contribute less to the model's output.
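Magnitude Pruning Criterion
$$M_{ij} = \begin{cases} 1 & \text{if } |W_{ij}| > \tau \\ 0 & \text{otherwise} \end{cases}, \qquad W_{ij} \leftarrow M_{ij} \, W_{ij}$$
Here,
- $M_{ij}$ = Pruning mask
- $W_{ij}$ = Weight value
- $\tau$ = Pruning threshold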
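```python
import torch
import torch.nn as nn

class MagnitudePruner:
    """Unstructured global magnitude pruning (assumes all weights live on one device)."""

    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}

    def _prunable(self):
        # Weight matrices only: biases and 1-D norm weights stay dense
        return [(name, param) for name, param in self.model.named_parameters()
                if 'weight' in name and param.dim() >= 2]

    @torch.no_grad()
    def compute_masks(self):
        """Compute pruning masks based on weight magnitudes."""
        params = self._prunable()
        # Collect all weight magnitudes for a single global ranking
        all_weights = torch.cat([param.abs().flatten() for _, param in params])
        # Global threshold = k-th smallest magnitude. kthvalue is used because
        # torch.quantile rejects inputs with more than 2**24 elements.
        k = int(self.sparsity * all_weights.numel())
        threshold = all_weights.kthvalue(k).values if k > 0 else -1.0
        # Create and apply masks; strict '>' keeps already-zeroed weights pruned
        for name, param in params:
            mask = (param.abs() > threshold).to(param.dtype)
            self.masks[name] = mask
            param.mul_(mask)

    def prune_step(self, optimizer):
        """Masked optimizer step: call after loss.backward() instead of optimizer.step()."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks and param.grad is not None:
                    # Remove gradients for pruned weights
                    param.grad.mul_(self.masks[name])
        optimizer.step()
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    # Optimizer state (e.g. momentum from before pruning) can still
                    # move pruned weights, so reset them to zero
                    param.mul_(self.masks[name])
```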
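For comparison, PyTorch ships global magnitude pruning in `torch.nn.utils.prune`. The sketch below applies it to a toy two-layer MLP (the model is an assumption for illustration). Like `compute_masks`, it ranks all listed weights together, but it stores each mask as a `weight_mask` buffer and re-applies it on every forward pass, so pruned weights stay at zero during fine-tuning.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

# Every Linear weight takes part in one global L1-magnitude ranking
params = [(m, "weight") for m in model.modules() if isinstance(m, nn.Linear)]
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.5)

total = sum(m.weight.numel() for m, _ in params)
zeros = sum(int((m.weight == 0).sum()) for m, _ in params)
print(f"global sparsity: {zeros / total:.2%}")

# Make pruning permanent: fold the mask into `weight` and drop the reparametrization
for m, name in params:
    prune.remove(m, name)
```

`prune.remove` only makes the zeros permanent; the tensors remain dense, so memory and speed do not change until a sparse format or sparse kernel is used.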
Iterative Pruning
Iterative pruning increases sparsity gradually over several prune-then-fine-tune rounds. In each round, the least important weights are pruned and the model is fine-tuned to recover accuracy. At the same final sparsity, this typically produces better results than one-shot pruning, which removes all the weights at once.
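```python
class IterativePruner:
    """Iterative magnitude pruning with fine-tuning (builds on MagnitudePruner)."""

    def __init__(self, model, target_sparsity=0.7, n_steps=10):
        self.model = model
        self.target_sparsity = target_sparsity
        self.n_steps = n_steps
        self.current_sparsity = 0.0

    @torch.no_grad()
    def _apply_masks(self, masks, grads=False):
        """Multiply the weights (or their gradients) by the pruning masks."""
        for name, param in self.model.named_parameters():
            if name in masks:
                target = param.grad if grads else param
                if target is not None:
                    target.mul_(masks[name])

    def prune_and_finetune(self, train_loader, optimizer, n_epochs=5):
        """Iteratively prune and fine-tune; assumes model(batch) returns a scalar loss."""
        self.model.train()
        for step in range(self.n_steps):
            # Increase sparsity linearly (computed directly to avoid float drift)
            self.current_sparsity = self.target_sparsity * (step + 1) / self.n_steps
            print(f"Step {step+1}: Sparsity = {self.current_sparsity:.2%}")
            # Compute new global masks
            pruner = MagnitudePruner(self.model, self.current_sparsity)
            pruner.compute_masks()
            # Fine-tune with the masks enforced so pruned weights cannot regrow
            for epoch in range(n_epochs):
                for batch in train_loader:
                    loss = self.model(batch)
                    loss.backward()
                    self._apply_masks(pruner.masks, grads=True)
                    optimizer.step()
                    self._apply_masks(pruner.masks)
                    optimizer.zero_grad()
```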
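The linear schedule in `IterativePruner` adds the same sparsity increment each round. A commonly used alternative is the cubic schedule of Zhu & Gupta (2017), which prunes aggressively while many redundant weights remain and slows down near the target. A small sketch comparing the two (the function names are illustrative):

```python
def linear_schedule(target: float, n_steps: int) -> list[float]:
    """Equal sparsity increments, as in IterativePruner."""
    return [target * (step + 1) / n_steps for step in range(n_steps)]

def cubic_schedule(target: float, n_steps: int, initial: float = 0.0) -> list[float]:
    """Gradual schedule: s_t = s_f + (s_i - s_f) * (1 - t / n)^3."""
    return [
        target + (initial - target) * (1 - (step + 1) / n_steps) ** 3
        for step in range(n_steps)
    ]

lin = linear_schedule(0.7, 5)
cub = cubic_schedule(0.7, 5)
print([f"{s:.3f}" for s in lin])
print([f"{s:.3f}" for s in cub])
```

Both schedules end at the target; the cubic one reaches about 49% of its final sparsity after the first of five rounds, leaving the later, harder-to-recover rounds with smaller increments.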
Structured Pruning
Neuron Pruning
Definition: Neuron Pruning
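Neuron pruning removes entire neurons (a row of one weight matrix and the matching column of the next), reducing the model's width. Once those rows and columns are physically sliced out, the matrices genuinely shrink, which yields real speedups on standard hardware; merely zeroing them leaves the shapes, and the cost, unchanged.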
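Neuron Importance
$$I_j = \left( \frac{1}{d_{\text{in}}} \sum_{i=1}^{d_{\text{in}}} |W_{ji}| \right) \cdot \sigma_j^2, \qquad \sigma_j^2 = \operatorname{Var}(a_j)$$
Here,
- $I_j$ = Importance of neuron $j$
- $W_{ji}$ = Weight connecting input $i$ to neuron $j$
- $a_j$ = Activation of neuron $j$
- $\sigma_j^2$ = Variance of neuron $j$'s activations over the calibration data
- $d_{\text{in}}$ = Number of inputs to the layer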
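```python
class NeuronPruner:
    """Structured neuron pruning for nn.Linear layers.

    Assumes calibration_data is a re-iterable collection (e.g. a list) of inputs
    accepted by model(batch). Every nn.Linear is scored, including the output head;
    filter importance_scores down to the layers you actually want to prune.
    """

    def __init__(self, model, pruning_ratio=0.3):
        self.model = model
        self.pruning_ratio = pruning_ratio

    @torch.no_grad()
    def _get_activations(self, module, calibration_data):
        """Run the calibration data and collect the module's outputs as [tokens, out_features]."""
        outputs = []
        handle = module.register_forward_hook(
            lambda mod, inp, out: outputs.append(out.reshape(-1, out.shape[-1]).float()))
        try:
            for batch in calibration_data:
                self.model(batch)
        finally:
            handle.remove()
        return torch.cat(outputs)

    def compute_neuron_importance(self, calibration_data):
        """Compute importance scores for each output neuron (one calibration pass per layer)."""
        importance = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                # Weight magnitude importance: mean |W| over row j (neuron j's inputs)
                weight_importance = module.weight.detach().abs().float().mean(dim=1)
                # Variance of each neuron's activation over calibration tokens
                activations = self._get_activations(module, calibration_data)
                activation_var = activations.var(dim=0)
                # Combined importance
                importance[name] = weight_importance * activation_var
        return importance

    @torch.no_grad()
    def prune_neurons(self, importance_scores, pruning_ratio=None):
        """Zero out the least important neurons (rows of W and entries of b)."""
        ratio = self.pruning_ratio if pruning_ratio is None else pruning_ratio
        for name, module in self.model.named_modules():
            if name in importance_scores:
                scores = importance_scores[name]
                n_prune = int(len(scores) * ratio)
                if n_prune == 0:
                    continue
                # Find least important neurons
                _, indices_to_prune = scores.topk(n_prune, largest=False)
                # Create mask
                mask = torch.ones(len(scores), device=module.weight.device,
                                  dtype=module.weight.dtype)
                mask[indices_to_prune] = 0
                # Apply mask (zero out rows). Shapes are unchanged, so this simulates
                # removal; the speedup comes only after the rows are sliced out.
                module.weight.mul_(mask.unsqueeze(1))
                # Also zero the bias if it exists
                if module.bias is not None:
                    module.bias.mul_(mask)
```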
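Zeroing rows, as `prune_neurons` does, leaves every matrix the same size. The sketch below shows the physical removal for an assumed `Linear → ReLU → Linear` block (the helper `remove_neurons` is hypothetical): drop the pruned rows of the first layer and the matching input columns of the second, then check that the smaller block computes the same output as the zero-masked one.

```python
import torch
import torch.nn as nn

def remove_neurons(fc1: nn.Linear, fc2: nn.Linear, keep: torch.Tensor):
    """Build smaller layers that keep only the fc1 output neurons listed in `keep`."""
    new_fc1 = nn.Linear(fc1.in_features, len(keep), bias=fc1.bias is not None)
    new_fc2 = nn.Linear(len(keep), fc2.out_features, bias=fc2.bias is not None)
    with torch.no_grad():
        new_fc1.weight.copy_(fc1.weight[keep])       # rows = output neurons
        if fc1.bias is not None:
            new_fc1.bias.copy_(fc1.bias[keep])
        new_fc2.weight.copy_(fc2.weight[:, keep])    # columns = inputs coming from fc1
        if fc2.bias is not None:
            new_fc2.bias.copy_(fc2.bias)
    return new_fc1, new_fc2

torch.manual_seed(0)
fc1, fc2 = nn.Linear(8, 16), nn.Linear(16, 4)
keep = torch.tensor([i for i in range(16) if i % 4 != 0])  # drop 4 of 16 neurons

# Reference: zero the dropped rows in place, as prune_neurons does
mask = torch.zeros(16)
mask[keep] = 1.0
with torch.no_grad():
    fc1.weight *= mask.unsqueeze(1)
    fc1.bias *= mask

small1, small2 = remove_neurons(fc1, fc2, keep)
x = torch.randn(3, 8)
masked_out = fc2(torch.relu(fc1(x)))
small_out = small2(torch.relu(small1(x)))
print(small1.weight.shape, small2.weight.shape)
print(torch.allclose(masked_out, small_out, atol=1e-6))
```

The equivalence relies on the activation mapping 0 to 0 (true for ReLU, GELU, and SiLU). In a gated MLP such as SwiGLU, the same indices would have to be removed from both the gate and up projections and from the down projection's columns.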
Head Pruning
Definition: Head Pruning
Head pruning removes entire attention heads based on their importance, reducing the model's attention capacity while maintaining the overall architecture.
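Head Importance Score
$$I_h = \frac{\|v_h\|}{H_h + \epsilon}, \qquad H_h = -\frac{1}{T} \sum_{t=1}^{T} \sum_{i} A^{(h)}_{t,i} \log A^{(h)}_{t,i}$$
Here,
- $I_h$ = Importance of head $h$
- $A^{(h)}_{t,i}$ = Attention weight from token $t$ to $i$ in head $h$
- $v_h$ = Value vector in head $h$ (the code uses the norm of head $h$'s slice of the value projection)
- $H_h$ = Mean attention entropy of head $h$ over $T$ query positions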
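The entropy term rewards focused heads. As a quick sanity check of that heuristic (an illustration, not part of the original method), a head that attends uniformly over n tokens has entropy ln n, while a head that puts all of its weight on one token has entropy 0. The helper name is illustrative.

```python
import math
import torch

def attention_entropy(attn: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Mean entropy of attention rows; attn has shape [..., query, key] and rows sum to 1."""
    # clamp_min avoids 0 * log(0) = NaN for exactly-zero weights
    return -(attn * attn.clamp_min(eps).log()).sum(dim=-1).mean()

n = 8
uniform = torch.full((n, n), 1.0 / n)   # every query spreads its weight evenly
focused = torch.eye(n)                  # every query attends to exactly one key
print(attention_entropy(uniform).item(), math.log(n))
print(attention_entropy(focused).item())
```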
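```python
import torch
import torch.nn.functional as F

class HeadPruner:
    """Prune entire attention heads.

    Assumes attention modules expose q_proj/k_proj/v_proj/o_proj, num_heads and
    head_dim, use plain multi-head attention (no grouped-query K/V sharing), and
    that calibration_data is a re-iterable collection of equal-length batches.
    """

    def __init__(self, model):
        self.model = model

    def _get_module_by_name(self, name):
        return self.model.get_submodule(name)

    @torch.no_grad()
    def _compute_attention_patterns(self, module, calibration_data):
        """Approximate causal attention maps [batch, heads, query, key].

        Simplification: recomputes softmax(QK^T / sqrt(d)) from q_proj/k_proj
        and ignores rotary position embeddings.
        """
        hidden = []
        handle = module.q_proj.register_forward_pre_hook(
            lambda mod, args: hidden.append(args[0]))
        try:
            for batch in calibration_data:
                self.model(batch)
        finally:
            handle.remove()
        patterns = []
        for x in hidden:
            b, t, _ = x.shape
            q = module.q_proj(x).view(b, t, module.num_heads, module.head_dim).transpose(1, 2)
            k = module.k_proj(x).view(b, t, module.num_heads, module.head_dim).transpose(1, 2)
            scores = (q @ k.transpose(-2, -1)).float() / module.head_dim ** 0.5
            causal = torch.ones(t, t, dtype=torch.bool, device=x.device).triu(1)
            patterns.append(F.softmax(scores.masked_fill(causal, float('-inf')), dim=-1))
        return torch.cat(patterns)

    @torch.no_grad()
    def compute_head_importance(self, calibration_data):
        """Compute importance for each attention head."""
        head_importance = {}
        for name, module in self.model.named_modules():
            if hasattr(module, 'q_proj') and hasattr(module, 'k_proj'):
                # Get attention patterns
                attn_weights = self._compute_attention_patterns(module, calibration_data)
                n_heads = module.num_heads
                d_head = module.head_dim
                importance = torch.zeros(n_heads)
                for h in range(n_heads):
                    # Attention entropy (lower = more focused = treated as more important);
                    # clamping avoids 0 * log(0) = NaN
                    attn_h = attn_weights[:, h, :, :]
                    entropy = -(attn_h * attn_h.clamp_min(1e-12).log()).sum(dim=-1).mean()
                    # Norm of this head's value projection (higher = more important)
                    v_h = module.v_proj.weight[h*d_head:(h+1)*d_head, :]
                    importance[h] = (v_h.float().norm() / (entropy + 1e-6)).item()
                head_importance[name] = importance
        return head_importance

    @torch.no_grad()
    def prune_heads(self, head_importance, pruning_ratio=0.2):
        """Zero out the least important heads (simulated removal; shapes are unchanged)."""
        for name, importance in head_importance.items():
            n_prune = int(len(importance) * pruning_ratio)
            if n_prune == 0:
                continue
            _, heads_to_prune = importance.topk(n_prune, largest=False)
            # Get the attention module
            module = self._get_module_by_name(name)
            d_head = module.head_dim
            # Zero out pruned heads: their Q/K/V rows and O columns
            for h in heads_to_prune.tolist():
                rows = slice(h * d_head, (h + 1) * d_head)
                module.q_proj.weight[rows, :] = 0
                module.k_proj.weight[rows, :] = 0
                module.v_proj.weight[rows, :] = 0
                module.o_proj.weight[:, rows] = 0
```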
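Zeroing a head's Q/K/V rows and O columns removes its contribution exactly, so the head can then be sliced out to obtain a real speedup. The sketch below checks this on a toy self-attention layer (no mask, no RoPE; the sizes and the helpers `mha` and `drop_head_rows` are hypothetical): the zeroed 4-head layer and the 3-head layer with the head physically removed produce the same output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mha(x, q_proj, k_proj, v_proj, o_proj, n_heads, head_dim):
    """Plain multi-head self-attention (no mask, no positional encoding)."""
    b, t, _ = x.shape
    def split(proj):  # [b, t, n_heads*head_dim] -> [b, n_heads, t, head_dim]
        return proj(x).view(b, t, n_heads, head_dim).transpose(1, 2)
    q, k, v = split(q_proj), split(k_proj), split(v_proj)
    attn = F.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(b, t, n_heads * head_dim)
    return o_proj(out)

def drop_head_rows(lin, keep_idx):
    """Copy of `lin` that keeps only the output channels in keep_idx."""
    new = nn.Linear(lin.in_features, len(keep_idx), bias=False)
    with torch.no_grad():
        new.weight.copy_(lin.weight[keep_idx])
    return new

torch.manual_seed(0)
d_model, n_heads, head_dim, pruned = 16, 4, 4, 1
q, k, v = (nn.Linear(d_model, n_heads * head_dim, bias=False) for _ in range(3))
o = nn.Linear(n_heads * head_dim, d_model, bias=False)
x = torch.randn(2, 5, d_model)

# Keep every channel that does not belong to the pruned head
keep = torch.tensor([c for c in range(n_heads * head_dim) if c // head_dim != pruned])
q2, k2, v2 = (drop_head_rows(p, keep) for p in (q, k, v))
o2 = nn.Linear(len(keep), d_model, bias=False)
with torch.no_grad():
    o2.weight.copy_(o.weight[:, keep])
    # Zero the pruned head in place, as prune_heads does
    sl = slice(pruned * head_dim, (pruned + 1) * head_dim)
    for p in (q, k, v):
        p.weight[sl] = 0
    o.weight[:, sl] = 0

zeroed = mha(x, q, k, v, o, n_heads, head_dim)
removed = mha(x, q2, k2, v2, o2, n_heads - 1, head_dim)
print(torch.allclose(zeroed, removed, atol=1e-6))
```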
Layer Pruning
Definition: Layer Pruning
Layer pruning removes entire transformer layers, reducing the model's depth. This is more aggressive than neuron pruning and typically requires careful layer importance assessment.
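```python
class LayerPruner:
    """Prune entire transformer layers.

    Assumes model.layers is an nn.ModuleList of blocks that take the hidden states
    as first positional argument, that model(batch) returns a scalar loss, and that
    calibration_data is a re-iterable collection of batches.
    """

    def __init__(self, model):
        self.model = model

    @torch.no_grad()
    def _compute_residual_contribution(self, layer, calibration_data):
        """Mean relative change ||h_out - h_in|| / ||h_in|| the layer applies."""
        ratios = []
        def hook(mod, args, output):
            h_in = args[0]
            h_out = output[0] if isinstance(output, tuple) else output
            ratios.append(((h_out - h_in).float().norm() / (h_in.float().norm() + 1e-6)).item())
        handle = layer.register_forward_hook(hook)
        try:
            for batch in calibration_data:
                self.model(batch)
        finally:
            handle.remove()
        return sum(ratios) / len(ratios)

    def _compute_gradient_magnitude(self, layer, calibration_data):
        """L2 norm of the accumulated loss gradient w.r.t. the layer's parameters."""
        self.model.zero_grad()
        for batch in calibration_data:
            self.model(batch).backward()
        sq_norm = sum(p.grad.float().pow(2).sum().item()
                      for p in layer.parameters() if p.grad is not None)
        self.model.zero_grad()
        return sq_norm ** 0.5

    def compute_layer_importance(self, calibration_data):
        """Compute importance for each layer."""
        layer_importance = []
        for layer in self.model.layers:
            # Compute residual contribution
            residual_norm = self._compute_residual_contribution(layer, calibration_data)
            # Compute gradient magnitude (recomputed per layer for clarity, not speed)
            grad_magnitude = self._compute_gradient_magnitude(layer, calibration_data)
            # Combined importance
            layer_importance.append(residual_norm * grad_magnitude)
        return torch.tensor(layer_importance)

    def prune_layers(self, importance, n_prune=2):
        """Remove least important layers."""
        _, layers_to_prune = importance.topk(n_prune, largest=False)
        # Remove layers in reverse order so the remaining indices stay valid
        for idx in sorted(layers_to_prune.tolist(), reverse=True):
            del self.model.layers[idx]
        # Update layer indices (only matters if the model reads layer.layer_idx, e.g. for caching)
        for i, layer in enumerate(self.model.layers):
            layer.layer_idx = i
```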
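Backward passes for every layer are expensive. A cheaper, forward-only signal (an alternative to the residual-times-gradient score in `LayerPruner`, not the method above) is how much a layer rotates its input: 1 − cos(h_in, h_out), averaged over tokens. The sketch below uses toy residual blocks with a hypothetical `scale` knob; the near-identity block scores close to 0 and would be the first candidate for removal.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Toy residual block: h + scale * MLP(h)."""
    def __init__(self, d: int, scale: float):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.scale = scale

    def forward(self, h):
        return h + self.scale * self.mlp(h)

@torch.no_grad()
def cosine_layer_scores(layers, h):
    """Score each layer by 1 - mean cosine similarity between its input and output."""
    scores = []
    for layer in layers:
        out = layer(h)
        scores.append(1 - F.cosine_similarity(h, out, dim=-1).mean().item())
        h = out
    return torch.tensor(scores)

torch.manual_seed(0)
layers = nn.ModuleList([ResidualBlock(32, s) for s in (1.0, 0.01, 1.0)])
scores = cosine_layer_scores(layers, torch.randn(4, 10, 32))
print(scores)
print(scores.argmin().item())  # index of the most "skippable" layer
```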
Lottery Ticket Hypothesis
Theory
Definition: Lottery Ticket Hypothesis
The Lottery Ticket Hypothesis (Frankle & Carbin, 2019) states that a randomly initialized dense network contains a subnetwork (the "winning ticket") that, when trained in isolation from the same initialization, reaches accuracy comparable to the full network.
Finding lottery tickets requires iterative pruning: train, prune, reset the remaining weights to their initial values, and retrain. The winning ticket is the sparse subnetwork that reaches full accuracy when retrained from that original initialization.
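Each lottery-ticket round typically prunes a fixed fraction p of the *surviving* weights, so sparsity compounds as 1 − (1 − p)^k after k rounds instead of growing by p per round. A quick check with p = 0.2 (the function name is illustrative):

```python
def sparsity_after(rounds: int, p: float) -> float:
    """Sparsity after pruning a fraction p of the surviving weights `rounds` times."""
    return 1 - (1 - p) ** rounds

for k in range(1, 6):
    print(f"round {k}: {sparsity_after(k, 0.2):.2%} pruned, {(1 - 0.2) ** k:.2%} remaining")
```

After five rounds about 67% of the weights are gone, not the 100% that additive accounting (5 × 20%) would suggest.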
Finding Winning Tickets
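```python
import torch

class LotteryTicketFinder:
    """Find winning tickets through iterative magnitude pruning and rewinding.

    train_fn(model) and eval_fn(model) are user-supplied. A gradient hook keeps
    pruned weights at zero while train_fn runs (assumes train_fn builds a fresh
    optimizer each round, so no stale momentum moves pruned weights).
    """

    def __init__(self, model, train_fn, eval_fn):
        self.model = model
        self.train_fn = train_fn
        self.eval_fn = eval_fn
        # Save initial weights (detached copies, not views of the live parameters)
        self.initial_weights = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }
        # One mask per weight matrix, initially all ones
        self.masks = {
            name: torch.ones_like(param)
            for name, param in model.named_parameters()
            if 'weight' in name and param.dim() >= 2
        }
        for name, param in model.named_parameters():
            if name in self.masks:
                # Zero gradients of pruned weights; reads the current mask at call time
                param.register_hook(lambda grad, n=name: grad * self.masks[n])

    def find_winning_ticket(self, pruning_ratio=0.2, n_iterations=5):
        """Iteratively prune to find winning ticket."""
        for iteration in range(n_iterations):
            current_sparsity = self._sparsity()
            # Train current network
            self.train_fn(self.model)
            # Evaluate
            accuracy = self.eval_fn(self.model)
            print(f"Iteration {iteration+1}: Sparsity={current_sparsity:.0%}, Acc={accuracy:.2%}")
            # Prune a fraction of the *surviving* weights
            self._prune_global(pruning_ratio)
            # Rewind surviving weights to initialization
            self._reset_to_initial()
        return self.model

    def _sparsity(self):
        total = sum(m.numel() for m in self.masks.values())
        return 1 - sum(m.sum().item() for m in self.masks.values()) / total

    @torch.no_grad()
    def _prune_global(self, ratio):
        """Prune `ratio` of the remaining weights globally by magnitude."""
        params = dict(self.model.named_parameters())
        surviving = torch.cat([params[name].abs()[mask.bool()]
                               for name, mask in self.masks.items()])
        k = int(ratio * surviving.numel())
        if k == 0:
            return
        threshold = surviving.kthvalue(k).values
        for name, mask in self.masks.items():
            mask.mul_((params[name].abs() > threshold).to(mask.dtype))

    @torch.no_grad()
    def _reset_to_initial(self):
        """Reset surviving weights to initial values; pruned weights stay zero."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.copy_(self.initial_weights[name] * self.masks[name])
            else:
                param.copy_(self.initial_weights[name])
```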
Practical Implications
| Aspect | Full Model | Lottery Ticket | Implication |
|---|---|---|---|
| Parameters | 100% | 30-50% | Significant compression |
| Training Cost | 1× | 3-5× | Finding tickets is expensive |
| Inference Speed | Baseline | 2-3× faster | Structured tickets only |
| Quality | Baseline | 95-99% | Minimal degradation |
While lottery tickets provide theoretical insights, in practice we rarely find them for LLMs due to the computational cost. Instead, post-training pruning (magnitude, structured) is more practical for LLM compression.
Semi-Structured Pruning
N:M Sparsity
Definition: N:M Sparsity
N:M sparsity is a hardware-friendly pruning pattern in which at most N out of every M consecutive weights are non-zero. For example, 2:4 sparsity keeps 2 of every 4 weights and prunes the other 2, giving 50% sparsity that hardware can accelerate.
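N:M Sparsity
$$\|\mathbf{w}_g\|_0 \le N \quad \text{for every group } g \text{ of } M \text{ consecutive weights}, \qquad \text{sparsity} \ge 1 - \frac{N}{M}$$
Here,
- $N$ = Number of non-zero weights in group
- $M$ = Group size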
```python
class NMPruner:
    """N:M semi-structured pruning: keep the N largest-magnitude weights in every
    group of M consecutive weights along the last (input) dimension."""

    def __init__(self, n=2, m=4):
        self.n = n
        self.m = m

    def prune(self, weight):
        """Apply N:M pruning to a weight tensor and return the pruned copy."""
        original_shape = weight.shape
        if original_shape[-1] % self.m != 0:
            raise ValueError(f"last dimension {original_shape[-1]} is not divisible by M={self.m}")
        # Row-major reshape: each row of weight_grouped holds M consecutive weights
        weight_grouped = weight.reshape(-1, self.m)
        # Keep top N weights per group
        _, topk_indices = weight_grouped.abs().topk(self.n, dim=1)
        mask = torch.zeros_like(weight_grouped)
        mask.scatter_(1, topk_indices, 1.0)
        # Apply mask
        weight_pruned = weight_grouped * mask
        return weight_pruned.reshape(original_shape)
```
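Before handing a matrix to sparse kernels, it helps to verify the pattern. A minimal validator sketch (the name `satisfies_nm` is illustrative), applied to a hand-written row before and after keeping the top 2 of every 4 weights, the same rule `NMPruner` uses:

```python
import torch

def satisfies_nm(weight: torch.Tensor, n: int = 2, m: int = 4) -> bool:
    """True if every group of m consecutive weights along the last dim has <= n non-zeros."""
    if weight.shape[-1] % m != 0:
        return False
    groups = weight.reshape(-1, m)
    return bool(((groups != 0).sum(dim=1) <= n).all())

dense = torch.tensor([[0.9, -0.1, 0.4, 0.05, 0.3, 0.0, -0.7, 0.2]])
# Keep the 2 largest-magnitude entries in each group of 4
groups = dense.reshape(-1, 4)
keep = groups.abs().topk(2, dim=1).indices
sparse = torch.zeros_like(groups).scatter_(1, keep, groups.gather(1, keep)).reshape(dense.shape)
print(sparse)
print(satisfies_nm(dense), satisfies_nm(sparse))
```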
N:M sparsity is supported by NVIDIA's Ampere and newer GPUs, whose sparse tensor cores can accelerate 2:4 sparse matrix multiplication by up to 2× compared to dense FP16 operations.
Choosing Pruning Strategy
| Strategy | Speedup | Compression | Quality | Hardware Support |
|---|---|---|---|---|
| Unstructured | None | High | High | Limited |
| Structured (neuron) | High | Moderate | Moderate | Full |
| Structured (head) | Moderate | Moderate | Good | Full |
| Structured (layer) | High | High | Lower | Full |
| N:M (2:4) | Up to 2× | ~2× (before index metadata) | Good | Ampere+ |
For production deployment, structured pruning (neurons or heads) provides actual speedup. Unstructured pruning reduces memory but doesn't improve inference speed on standard hardware.
Practice Exercises
- Conceptual: Explain why the lottery ticket hypothesis suggests that over-parameterization is beneficial for training, even though the final model can be much smaller.
- Mathematical: For a 7B-parameter model with 2:4 N:M sparsity, calculate the actual memory savings, including the index overhead for sparse storage.
- Practical: Implement iterative magnitude pruning on a small language model and plot the accuracy vs. sparsity curve. At what sparsity does accuracy begin to degrade significantly?
- Research: Compare magnitude pruning with gradient-based pruning (using |w × ∂L/∂w| as importance). Which provides better quality at high sparsity levels?
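As a starting point for the research exercise, here is a minimal sketch of the first-order importance |w × ∂L/∂w| on a toy MLP with random data (the model, data, and helper names are illustrative assumptions). It differs from magnitude pruning only in the score fed to the threshold, so the two masks can be compared directly.

```python
import torch
import torch.nn as nn

def taylor_scores(model: nn.Module, loss: torch.Tensor) -> dict[str, torch.Tensor]:
    """Per-weight importance |w * dL/dw| for every 2-D weight matrix."""
    model.zero_grad()
    loss.backward()
    return {
        name: (p.detach() * p.grad).abs()
        for name, p in model.named_parameters()
        if p.dim() >= 2 and p.grad is not None
    }

def keep_mask(score: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Keep entries whose score exceeds the k-th smallest score."""
    k = int(sparsity * score.numel())
    return score > score.flatten().kthvalue(k).values

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 3))
x, y = torch.randn(64, 10), torch.randint(0, 3, (64,))
loss = nn.functional.cross_entropy(model(x), y)

scores = taylor_scores(model, loss)
mag = {n: p.detach().abs() for n, p in model.named_parameters() if p.dim() >= 2}
for name, s in scores.items():
    agree = (keep_mask(s, 0.5) == keep_mask(mag[name], 0.5)).float().mean().item()
    print(f"{name}: masks agree on {agree:.0%} of weights")
```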
Key Takeaways:
- Unstructured pruning removes individual weights but doesn't provide speedup
- Structured pruning (neurons, heads, layers) provides actual inference acceleration
- Lottery tickets show that subnetworks can match full network accuracy
- N:M sparsity enables hardware-accelerated sparse computation
- Iterative pruning with fine-tuning outperforms one-shot pruning
- Larger models are more robust to pruning (70B pruned > 7B dense)
- For production, structured pruning is preferred over unstructured
What to Learn Next
-> Low-Rank Factorization: SVD decomposition and weight sharing techniques.
-> Quantization Techniques Deep Dive: GPTQ, AWQ, GGUF, and INT4/INT8 methods.
-> Model Merging and Fusion: Combining multiple fine-tuned models.
-> Knowledge Distillation for LLMs: Training smaller models from larger teachers.
-> LoRA and PEFT: Efficient fine-tuning using low-rank adaptation.
-> Hardware-Aware LLM Design: Optimizing models for GPU memory hierarchy.