Pruning for LLMs

Pruning for LLMs: Finding What You Can Remove

Not all weights contribute equally to model performance. Pruning removes redundant parameters while preserving quality, and can achieve roughly 2-5× compression with little accuracy loss, depending on the model and method. This guide covers unstructured, structured, and semi-structured pruning for LLMs.

  • Unstructured Pruning: Remove individual weights below a threshold
  • Structured Pruning: Remove entire neurons, heads, or layers
  • Lottery Ticket Hypothesis: Finding subnetworks that train to full accuracy
  • Magnitude vs Gradient Pruning: Choosing what to prune

The best architecture is the one that doesn't need to exist.

Pruning for LLMs

Neural networks are massively over-parameterized. Pruning exploits this redundancy by removing weights that contribute little to the model's output. For LLMs, pruning can reduce memory usage and inference cost while maintaining most of the model's capabilities.

Definition: Pruning

Pruning is the process of removing parameters from a trained neural network. The goal is to reduce model size and computation while preserving performance. Pruning can be unstructured (removing individual weights) or structured (removing entire units).

Unstructured Pruning

Magnitude Pruning

Definition: Magnitude Pruning

Magnitude pruning removes weights with the smallest absolute values, based on the assumption that small weights contribute less to the model's output.

Magnitude Pruning Criterion

M_{ij} = \begin{cases} 0 & \text{if } |W_{ij}| < \theta \\ 1 & \text{otherwise} \end{cases}

Here,

  • M_{ij}: pruning mask (1 = keep, 0 = prune)
  • W_{ij}: weight value
  • \theta: pruning threshold
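
As a quick illustration of the criterion, here is a minimal sketch on a toy 2×3 matrix. The helper name `magnitude_mask` and the threshold value 0.25 are assumptions chosen for the example, not part of any library.

```python
import torch

def magnitude_mask(weight: torch.Tensor, theta: float) -> torch.Tensor:
    """Return M with M_ij = 0 where |W_ij| < theta and 1 otherwise."""
    return (weight.abs() >= theta).to(weight.dtype)

W = torch.tensor([[0.05, -0.80, 0.30],
                  [-0.02, 0.60, -0.10]])

M = magnitude_mask(W, theta=0.25)   # hypothetical threshold
W_pruned = W * M                    # small weights become exactly zero
sparsity = 1.0 - M.mean().item()    # fraction of weights removed
```

In practice the threshold is usually derived from a target sparsity rather than set by hand, which is what the class below does.
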
import torch
import torch.nn as nn

class MagnitudePruner:
    """Unstructured magnitude pruning."""
    
    def __init__(self, model, sparsity=0.5):
        self.model = model
        self.sparsity = sparsity
        self.masks = {}
        
    def compute_masks(self):
        """Compute pruning masks based on weight magnitudes."""
        all_weights = []
        
        # Collect all weights
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                all_weights.append(param.data.abs().flatten())
        
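        # Compute global threshold. torch.quantile rejects very large inputs
        # (LLM-sized tensors), so take the k-th smallest magnitude instead.
        all_weights = torch.cat(all_weights).float()
        k = max(1, int(self.sparsity * all_weights.numel()))
        threshold = torch.kthvalue(all_weights, k).values
        
        # Create masks: prune every weight at or below the threshold
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                mask = (param.data.abs() > threshold).to(param.dtype)
                self.masks[name] = mask
                
                # Apply mask
                param.data *= mask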
    
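    def prune_step(self, optimizer=None):
        """Re-apply masks after each optimizer step so pruned weights stay zero."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    mask = self.masks[name]
                    
                    # Remove gradients for pruned weights. grad is None before
                    # the first backward pass or after zero_grad(set_to_none=True).
                    if param.grad is not None:
                        param.grad.mul_(mask)
                    
                    # Reset pruned weights to zero (optimizers with momentum
                    # can otherwise move them away from zero)
                    param.mul_(mask)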
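
PyTorch also ships a pruning utility module, `torch.nn.utils.prune`, that implements global magnitude pruning with masks. The sketch below applies it to a toy two-layer network; the network and the 50% amount are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Prune the 50% smallest-magnitude weights across both Linear layers at once
params = [(m, "weight") for m in model if isinstance(m, nn.Linear)]
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.5)

# Each pruned module now holds weight_orig (parameter) and weight_mask (buffer);
# module.weight is recomputed as weight_orig * weight_mask.
total = sum(m.weight.numel() for m, _ in params)
zeros = sum(int((m.weight == 0).sum()) for m, _ in params)
global_sparsity = zeros / total

# Make the pruning permanent: drops weight_orig / weight_mask, keeps the masked weight
for m, name in params:
    prune.remove(m, name)
```

Because the threshold is global, the two layers generally end up with different per-layer sparsities even though the overall sparsity is 50%.
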

Iterative Pruning

Iterative pruning gradually increases sparsity over multiple pruning steps. At each step, the least important weights are pruned, and the model is fine-tuned to recover accuracy before the next step. Because the network can adapt between steps, this typically produces better results than one-shot pruning at the same final sparsity.

class IterativePruner:
    """Iterative magnitude pruning with fine-tuning."""
    
    def __init__(self, model, target_sparsity=0.7, n_steps=10):
        self.model = model
        self.target_sparsity = target_sparsity
        self.n_steps = n_steps
        self.current_sparsity = 0
        
    def prune_and_finetune(self, train_loader, optimizer, n_epochs=5):
        """Iteratively prune and fine-tune."""
        sparsity_per_step = self.target_sparsity / self.n_steps
        
        for step in range(self.n_steps):
            # Increase sparsity
            self.current_sparsity += sparsity_per_step
            print(f"Step {step+1}: Sparsity = {self.current_sparsity:.2%}")
            
            # Compute new masks
            pruner = MagnitudePruner(self.model, self.current_sparsity)
            pruner.compute_masks()
            
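            # Fine-tune, re-applying the masks so pruned weights stay at zero
            for epoch in range(n_epochs):
                for batch in train_loader:
                    optimizer.zero_grad()
                    # Assumes the model's forward pass returns a scalar loss
                    loss = self.model(batch)
                    loss.backward()
                    optimizer.step()
                    pruner.prune_step(optimizer)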
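
The class above raises sparsity by the same amount at every step. A common alternative is the cubic schedule from Zhu and Gupta's gradual pruning work, which prunes aggressively early on and slows down near the target. The sketch below compares the two; the function names are illustrative assumptions.

```python
def linear_schedule(step: int, n_steps: int, target: float) -> float:
    """Sparsity grows by the same amount at every step."""
    return target * (step / n_steps)

def cubic_schedule(step: int, n_steps: int, target: float, initial: float = 0.0) -> float:
    """s_t = s_f + (s_i - s_f) * (1 - t/n)^3: fast early pruning, slow near the target."""
    frac = min(max(step / n_steps, 0.0), 1.0)
    return target + (initial - target) * (1.0 - frac) ** 3

# Sparsity targets for a 10-step run ending at 70% sparsity
linear = [linear_schedule(s, 10, 0.7) for s in range(11)]
cubic = [cubic_schedule(s, 10, 0.7) for s in range(11)]
```

Halfway through, the cubic schedule has already reached 61.25% sparsity while the linear one is at 35%, which leaves more fine-tuning steps for the hardest final increments.
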

Structured Pruning

Neuron Pruning

Definition: Neuron Pruning

Neuron pruning removes entire neurons (rows/columns in weight matrices), reducing the model's width. This provides actual speedup on hardware because it reduces matrix dimensions.

Neuron Importance

I_j = \sum_{i} |W_{ij}| \cdot \text{Var}(X_j)

Here,

  • I_j: importance of neuron j
  • W_{ij}: weight connecting to neuron j
  • X_j: activation of neuron j
  • \text{Var}(X_j): variance of neuron j's activations
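
The importance score can be computed for a single layer with a forward hook that records activations. This is a minimal sketch under stated assumptions: the helper name `neuron_importance` is hypothetical, the layer is called directly rather than inside a full model, and PyTorch's `(out_features, in_features)` weight layout means neuron j corresponds to row j.

```python
import torch
import torch.nn as nn

def neuron_importance(layer: nn.Linear, inputs: torch.Tensor) -> torch.Tensor:
    """I_j = sum_i |W_ij| * Var(X_j) for each output neuron j of `layer`."""
    captured = []
    hook = layer.register_forward_hook(
        lambda module, inp, out: captured.append(out.detach())
    )
    with torch.no_grad():
        layer(inputs)
    hook.remove()

    acts = captured[0].reshape(-1, layer.out_features)    # (samples, neurons)
    weight_sum = layer.weight.detach().abs().sum(dim=1)   # sum over inputs i
    return weight_sum * acts.var(dim=0)

torch.manual_seed(0)
layer = nn.Linear(8, 4)
with torch.no_grad():
    layer.weight[2].zero_()   # neuron 2 ignores its input, so its output is constant

scores = neuron_importance(layer, torch.randn(64, 8))
```

Neuron 2 receives a score of zero, so it would be the first candidate for removal.
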
class NeuronPruner:
    """Structured neuron pruning."""
    
    def __init__(self, model, pruning_ratio=0.3):
        self.model = model
        self.pruning_ratio = pruning_ratio
        
    def compute_neuron_importance(self, calibration_data):
        """Compute importance scores for each neuron."""
        importance = {}
        
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                # Weight magnitude term: sum of |W| over inputs, one value per
                # output neuron (matches the sum in the importance formula)
                weight_importance = module.weight.data.abs().sum(dim=1)
                
                # Compute activation variance
                activations = self._get_activations(module, calibration_data)
                activation_var = activations.var(dim=0)
                
                # Combined importance
                importance[name] = weight_importance * activation_var
        
        return importance
    
    def prune_neurons(self, importance_scores, pruning_ratio=None):
        """Remove least important neurons."""
        if pruning_ratio is None:
            pruning_ratio = self.pruning_ratio
        for name, module in self.model.named_modules():
            if name in importance_scores:
                scores = importance_scores[name]
                n_prune = int(len(scores) * pruning_ratio)
                
                # Find least important neurons
                _, indices_to_prune = scores.topk(n_prune, largest=False)
                
                # Create mask
                mask = torch.ones(len(scores), device=scores.device)
                mask[indices_to_prune] = 0
                
                # Apply mask (zero out rows)
                module.weight.data *= mask.unsqueeze(1)
                
                # Also remove bias if exists
                if module.bias is not None:
                    module.bias.data *= mask
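
Note that the class above only zeroes rows, so the matrices keep their original shape. To obtain the speedup from smaller matrix dimensions, the pruned neurons have to be physically removed from the layer and from the input side of the layer that follows it. The sketch below does this for a pair of `nn.Linear` layers; `shrink_linear_pair` is a hypothetical helper, and real transformer blocks need the same idea applied to their own projection layout.

```python
import torch
import torch.nn as nn

def shrink_linear_pair(fc1: nn.Linear, fc2: nn.Linear, keep):
    """Return smaller copies of fc1 -> fc2 that keep only the hidden units in `keep`."""
    keep = torch.as_tensor(keep, dtype=torch.long)
    new_fc1 = nn.Linear(fc1.in_features, len(keep), bias=fc1.bias is not None)
    new_fc2 = nn.Linear(len(keep), fc2.out_features, bias=fc2.bias is not None)
    with torch.no_grad():
        new_fc1.weight.copy_(fc1.weight[keep])       # drop output rows of fc1
        if fc1.bias is not None:
            new_fc1.bias.copy_(fc1.bias[keep])
        new_fc2.weight.copy_(fc2.weight[:, keep])    # drop matching input columns of fc2
        if fc2.bias is not None:
            new_fc2.bias.copy_(fc2.bias)
    return new_fc1, new_fc2

torch.manual_seed(0)
fc1, fc2 = nn.Linear(8, 6), nn.Linear(6, 4)
pruned = [1, 4]
with torch.no_grad():            # emulate the zero-masking step
    fc1.weight[pruned] = 0
    fc1.bias[pruned] = 0

keep = [j for j in range(6) if j not in pruned]
small1, small2 = shrink_linear_pair(fc1, fc2, keep)

x = torch.randn(5, 8)
masked_out = fc2(torch.relu(fc1(x)))
shrunk_out = small2(torch.relu(small1(x)))
```

With a ReLU between the layers, the zero-masked network and the physically smaller one produce the same output, but only the second performs fewer multiply-accumulates.
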

Head Pruning

Definition: Head Pruning

Head pruning removes entire attention heads based on their importance, reducing the model's attention capacity while maintaining the overall architecture.

Head Importance Score

I_h = \sum_{t} \left| \sum_{i} \alpha_{h,t,i} \cdot v_{h,t,i} \right|

Here,

  • I_h: importance of head h
  • \alpha_{h,t,i}: attention weight from token t to token i in head h
  • v_{h,t,i}: value vector in head h
class HeadPruner:
    """Prune entire attention heads."""
    
    def __init__(self, model):
        self.model = model
        
    def compute_head_importance(self, calibration_data):
        """Compute importance for each attention head."""
        head_importance = {}
        
        for name, module in self.model.named_modules():
            if hasattr(module, 'q_proj') and hasattr(module, 'k_proj'):
                # Get attention patterns
                attn_weights = self._compute_attention_patterns(
                    module, calibration_data
                )
                
                # Compute head importance
                n_heads = module.num_heads
                d_head = module.head_dim
                
                importance = torch.zeros(n_heads)
                for h in range(n_heads):
                    # Attention entropy (lower = more focused = more important).
                    # The clamp avoids log(0) = -inf turning masked positions into NaN.
                    attn_h = attn_weights[:, h, :, :]
                    entropy = -(attn_h * attn_h.clamp_min(1e-12).log()).sum(dim=-1).mean()
                    
                    # Value norm (higher = more important)
                    v_h = module.v_proj.weight[h*d_head:(h+1)*d_head, :]
                    v_norm = v_h.norm()
                    
                    importance[h] = v_norm / (entropy + 1e-6)
                
                head_importance[name] = importance
        
        return head_importance
    
    def prune_heads(self, head_importance, pruning_ratio=0.2):
        """Remove least important heads."""
        for name, importance in head_importance.items():
            n_heads = len(importance)
            n_prune = int(n_heads * pruning_ratio)
            
            _, heads_to_prune = importance.topk(n_prune, largest=False)
            
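            # Get the attention module
            module = self.model.get_submodule(name)
            
            # Zero out pruned heads. In-place edits of parameters that require
            # grad must happen under no_grad. Assumes standard multi-head
            # attention; grouped-query attention needs different k/v indexing.
            d_head = module.head_dim
            with torch.no_grad():
                for h in heads_to_prune.tolist():
                    rows = slice(h * d_head, (h + 1) * d_head)
                    module.q_proj.weight[rows, :] = 0
                    module.k_proj.weight[rows, :] = 0
                    module.v_proj.weight[rows, :] = 0
                    module.o_proj.weight[:, rows] = 0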
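
The class above scores heads with an entropy and value-norm proxy. For comparison, the head importance formula itself can be evaluated directly from attention weights and value vectors. This is a minimal sketch that assumes the outer |·| is the L2 norm of each token's head output and that `values[h, i]` is the value vector of token i; both are interpretations, not something the formula fixes.

```python
import torch

def head_importance(attn: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """attn: (heads, T, T) softmax weights; values: (heads, T, d_head).

    Returns I_h = sum_t || sum_i attn[h, t, i] * values[h, i] ||_2.
    """
    head_out = attn @ values                    # (heads, T, d_head)
    return head_out.norm(dim=-1).sum(dim=-1)    # norm per token, summed over tokens

torch.manual_seed(0)
H, T, d = 4, 6, 8
attn = torch.softmax(torch.randn(H, T, T), dim=-1)
values = torch.randn(H, T, d)
values[1] = 0.0                                 # a head that writes nothing

scores = head_importance(attn, values)
least_important = int(scores.argmin())
```

The head with all-zero value vectors contributes nothing to the output and receives the lowest score.
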

Layer Pruning

Definition: Layer Pruning

Layer pruning removes entire transformer layers, reducing the model's depth. This is more aggressive than neuron pruning and typically requires careful layer importance assessment.

class LayerPruner:
    """Prune entire transformer layers."""
    
    def __init__(self, model):
        self.model = model
        
    def compute_layer_importance(self, calibration_data):
        """Compute importance for each layer."""
        layer_importance = []
        
        for i, layer in enumerate(self.model.layers):
            # Compute residual contribution
            residual_norm = self._compute_residual_contribution(
                layer, calibration_data
            )
            
            # Compute gradient magnitude
            grad_magnitude = self._compute_gradient_magnitude(
                layer, calibration_data
            )
            
            # Combined importance
            importance = residual_norm * grad_magnitude
            layer_importance.append(importance)
        
        return torch.tensor(layer_importance)
    
    def prune_layers(self, importance, n_prune=2):
        """Remove least important layers."""
        _, layers_to_prune = importance.topk(n_prune, largest=False)
        
        # Remove layers in reverse order so earlier deletions
        # do not shift the indices of later ones
        for idx in sorted(layers_to_prune.tolist(), reverse=True):
            del self.model.layers[idx]
        
        # Update layer indices
        for i, layer in enumerate(self.model.layers):
            layer.layer_idx = i
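
The helper `_compute_residual_contribution` is not defined in the class above. One simple way to measure how much a layer changes the residual stream is to compare its input and output hidden states with cosine similarity. The sketch below is a hypothetical stand-in under that assumption, not the method the class necessarily intends.

```python
import torch
import torch.nn.functional as F

def block_influence(hidden_in: torch.Tensor, hidden_out: torch.Tensor) -> float:
    """1 - mean cosine similarity between a layer's input and output states.

    Both tensors have shape (batch, tokens, d_model). Values near 0 mean the
    layer barely changes the residual stream and is a pruning candidate.
    """
    cos = F.cosine_similarity(hidden_in, hidden_out, dim=-1)   # per token
    return (1.0 - cos).mean().item()

torch.manual_seed(0)
x = torch.randn(2, 5, 16)
near_identity = x + 0.01 * torch.randn_like(x)   # layer that changes little
strong_change = torch.randn_like(x)              # layer that rewrites the state

scores = [block_influence(x, near_identity), block_influence(x, strong_change)]
```

In a real model the two tensors would come from forward hooks on each transformer layer, averaged over a calibration set.
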

Lottery Ticket Hypothesis

Theory

Definition: Lottery Ticket Hypothesis

The Lottery Ticket Hypothesis (Frankle & Carbin, 2019) states that a randomly-initialized neural network contains a subnetwork (the "winning ticket") that, when trained in isolation, achieves comparable accuracy to the full network.

Finding lottery tickets requires iterative pruning: train, prune, reset the remaining weights to their initial values, and retrain. The winning ticket is the subnetwork that can reach full accuracy when trained from that original initialization.

Finding Winning Tickets

class LotteryTicketFinder:
    """Find winning tickets through iterative pruning."""
    
    def __init__(self, model, train_fn, eval_fn):
        self.model = model
        self.train_fn = train_fn
        self.eval_fn = eval_fn
        
        # Save initial weights
        self.initial_weights = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }
        
    def find_winning_ticket(self, pruning_ratio=0.2, n_iterations=5):
        """Iteratively prune to find winning ticket."""
        current_sparsity = 0
        
        for iteration in range(n_iterations):
            # Train current network
            self.train_fn(self.model)
            
            # Evaluate
            accuracy = self.eval_fn(self.model)
            print(f"Iteration {iteration+1}: Sparsity={current_sparsity:.0%}, Acc={accuracy:.2%}")
            
            # Prune to the new cumulative sparsity level
            current_sparsity += pruning_ratio
            self._prune_global(current_sparsity)
            
            # Reset remaining weights to initialization
            self._reset_to_initial()
        
        return self.model
    
    def _prune_global(self, ratio):
        """Prune globally by magnitude."""
        all_weights = []
        
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                all_weights.append(param.data.abs().flatten())
        
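        all_weights = torch.cat(all_weights).float()
        # kthvalue instead of torch.quantile, which rejects very large inputs
        k = max(1, int(ratio * all_weights.numel()))
        threshold = torch.kthvalue(all_weights, k).values
        
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                mask = param.data.abs() > threshold
                param.data *= mask.to(param.dtype)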
    
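    def _reset_to_initial(self):
        """Rewind surviving weights to their initial values; pruned weights stay zero."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.initial_weights:
                    initial = self.initial_weights[name]
                    if 'weight' in name:
                        # Surviving weights are the non-zero entries left by _prune_global
                        initial = initial * (param != 0).to(param.dtype)
                    param.copy_(initial)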
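
In the original lottery ticket procedure, each round prunes a fraction of the weights that are still alive and keeps an explicit mask across rounds. The sketch below shows one such round on a toy tensor, plus the resulting sparsity after several rounds. The function names are illustrative assumptions.

```python
import torch

def imp_round(trained, init, mask, prune_frac):
    """Prune `prune_frac` of the surviving weights by magnitude, then rewind to init."""
    surviving = trained[mask.bool()].abs()
    k = int(prune_frac * surviving.numel())
    new_mask = mask.clone()
    if k > 0:
        threshold = torch.kthvalue(surviving, k).values
        new_mask = mask * (trained.abs() > threshold).to(mask.dtype)
    return new_mask, init * new_mask

def sparsity_after(rounds: int, prune_frac: float) -> float:
    """Pruning a fraction p of the remaining weights each round gives 1 - (1 - p)^rounds."""
    return 1.0 - (1.0 - prune_frac) ** rounds

trained = torch.tensor([0.1, -0.5, 0.3, 2.0, -1.0])   # weights after training
init = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])        # weights at initialization
mask = torch.ones(5)

new_mask, rewound = imp_round(trained, init, mask, prune_frac=0.4)
total_sparsity = sparsity_after(rounds=5, prune_frac=0.2)
```

The two smallest trained weights are removed, and the three survivors restart from their initial values. Five rounds at 20% each compound to about 67% sparsity rather than 100%.
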

Practical Implications

| Aspect | Full Model | Lottery Ticket | Implication |
| --- | --- | --- | --- |
| Parameters | 100% | 30-50% | Significant compression |
| Training Cost | 1× | 3-5× | Finding tickets is expensive |
| Inference Speed | Baseline | 2-3× faster | Structured tickets only |
| Quality | Baseline | 95-99% | Minimal degradation |

While lottery tickets provide theoretical insights, they are rarely searched for in LLMs in practice, because finding one requires several full training runs. Instead, post-training pruning (magnitude, structured) is more practical for LLM compression.

Semi-Structured Pruning

N:M Sparsity

Definition: N:M Sparsity

N:M sparsity is a hardware-friendly pruning pattern where at most N out of every M consecutive weights are non-zero. For example, 2:4 sparsity keeps 2 out of every 4 weights and prunes the other 2, achieving 50% sparsity with hardware acceleration support.

N:M Sparsity

\text{Sparsity} = 1 - \frac{N}{M}

Here,

  • N: number of non-zero weights in each group
  • M: group size
class NMPruner:
    """N:M structured pruning."""
    
    def __init__(self, n=2, m=4):
        self.n = n
        self.m = m
        
    def prune(self, weight):
        """Apply N:M pruning to weight tensor."""
        original_shape = weight.shape
        
        # Group M consecutive weights along the last (input) dimension.
        # Requiring divisibility keeps groups from spanning rows and
        # guarantees the result can be reshaped back.
        if weight.shape[-1] % self.m != 0:
            raise ValueError("last dimension must be divisible by m")
        weight_grouped = weight.reshape(-1, self.m)
        
        # Keep top N weights (by magnitude) per group
        _, topk_indices = weight_grouped.abs().topk(self.n, dim=1)
        
        mask = torch.zeros_like(weight_grouped)
        mask.scatter_(1, topk_indices, 1)
        
        # Apply mask
        weight_pruned = weight_grouped * mask
        
        return weight_pruned.reshape(original_shape)
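
To make the pattern concrete, the sketch below applies 2:4 pruning to a single row of eight weights and then checks that every group of four holds at most two non-zero values. Both helper names are assumptions for the example; `nm_prune` repeats the grouping logic in compact form so the block runs on its own.

```python
import torch

def nm_prune(weight: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
    """Keep the n largest-magnitude weights in every group of m consecutive weights."""
    groups = weight.reshape(-1, m)
    idx = groups.abs().topk(n, dim=1).indices
    mask = torch.zeros_like(groups).scatter_(1, idx, 1.0)
    return (groups * mask).reshape(weight.shape)

def satisfies_nm(weight: torch.Tensor, n: int = 2, m: int = 4) -> bool:
    """True if no group of m consecutive weights has more than n non-zeros."""
    nonzero_per_group = (weight.reshape(-1, m) != 0).sum(dim=1)
    return bool((nonzero_per_group <= n).all())

W = torch.tensor([[0.9, -0.1, 0.4, 0.05, -0.3, 0.2, 0.7, -0.6]])
W_24 = nm_prune(W)          # groups: [0.9, -0.1, 0.4, 0.05] and [-0.3, 0.2, 0.7, -0.6]
ok_before = satisfies_nm(W)
ok_after = satisfies_nm(W_24)
```

The first group keeps 0.9 and 0.4, the second keeps 0.7 and -0.6. Unlike global magnitude pruning, a fairly large weight such as -0.3 is dropped because it loses within its own group.
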

The 2:4 pattern is supported by the sparse tensor cores in NVIDIA's Ampere and newer GPUs, which can accelerate 2:4 sparse matrix multiplication by up to 2× compared to the equivalent dense FP16 operations.

Choosing Pruning Strategy

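| Strategy | Speedup | Compression | Quality | Hardware Support |
| --- | --- | --- | --- | --- |
| Unstructured | None | High | High | Limited |
| Structured (neuron) | High | Moderate | Moderate | Full |
| Structured (head) | Moderate | Moderate | Good | Full |
| Structured (layer) | High | High | Lower | Full |
| N:M | 2× | 2× | Good | Ampere+ |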

For production deployment, structured pruning (neurons or heads) provides actual speedup. Unstructured pruning can reduce memory when the weights are stored in a sparse format, but it doesn't improve inference speed on standard hardware.
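
The memory point deserves a concrete check: zeroing weights does not shrink a dense tensor, and a naive sparse format can even be larger at moderate sparsity. The sketch below compares dense storage with PyTorch's COO format at exactly 50% sparsity. The helper names are assumptions, and production systems typically use more compact formats than COO.

```python
import torch

def dense_bytes(t: torch.Tensor) -> int:
    """Bytes used by a dense tensor, zeros included."""
    return t.numel() * t.element_size()

def coo_bytes(t: torch.Tensor) -> int:
    """Bytes used by COO storage: int64 coordinates plus the non-zero values."""
    s = t.to_sparse().coalesce()
    idx, vals = s.indices(), s.values()
    return idx.numel() * idx.element_size() + vals.numel() * vals.element_size()

w = torch.arange(1, 64 * 64 + 1, dtype=torch.float32).reshape(64, 64)
w_pruned = torch.where(w > 2048, w, torch.zeros_like(w))   # exactly 50% zeros

dense_before = dense_bytes(w)
dense_after = dense_bytes(w_pruned)
sparse_after = coo_bytes(w_pruned)
```

Here the dense tensor takes the same 16 KiB before and after pruning, while COO needs 40 KiB because each surviving FP32 value carries two 64-bit coordinates. Sparse storage only pays off at much higher sparsity or with cheaper index encodings, such as the 2-bit per-value metadata used for 2:4 sparsity.
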

Practice Exercises

  1. Conceptual: Explain why the lottery ticket hypothesis suggests that over-parameterization is beneficial for training, even though the final model can be much smaller.

  2. Mathematical: For a 7B parameter model with 2:4 N:M sparsity, calculate the actual memory savings including the index overhead for sparse storage.

  3. Practical: Implement iterative magnitude pruning on a small language model and plot the accuracy vs. sparsity curve. At what sparsity does accuracy begin to degrade significantly?

  4. Research: Compare magnitude pruning with gradient-based pruning (using |w × ∇w| as importance). Which provides better quality at high sparsity levels?

Key Takeaways:

  • Unstructured pruning removes individual weights but doesn't provide speedup on standard hardware
  • Structured pruning (neurons, heads, layers) provides actual inference acceleration
  • Lottery tickets show that subnetworks can match full network accuracy
  • N:M sparsity enables hardware-accelerated sparse computation
  • Iterative pruning with fine-tuning outperforms one-shot pruning
  • Larger models tend to be more robust to pruning (a pruned 70B model can still outperform a dense 7B model)
  • For production, structured pruning is preferred over unstructured

What to Learn Next

-> Low-Rank Factorization: SVD decomposition and weight sharing techniques.

-> Quantization Techniques Deep Dive: GPTQ, AWQ, GGUF, and INT4/INT8 methods.

-> Model Merging and Fusion: Combining multiple fine-tuned models.

-> Knowledge Distillation for LLMs: Training smaller models from larger teachers.

-> LoRA and PEFT: Efficient fine-tuning using low-rank adaptation.

-> Hardware-Aware LLM Design: Optimizing models for GPU memory hierarchy.