Model Compression Techniques
Weight Pruning
Zero out low-importance connections (weights) in a neural network. The resulting sparsity reduces model size and computation once the weights are stored or executed in a sparse format.
import torch
import torch.nn as nn
import numpy as np
from typing import Dict
class WeightPruner:
    """Global unstructured pruning of a model's parameters via binary masks.

    Every tensor yielded by ``model.named_parameters()`` is scored element
    by element. One global threshold is derived from all scores, and a float
    mask (1.0 = keep, 0.0 = prune) is stored per parameter name in
    ``self.masks``.
    """

    # torch.quantile refuses very large inputs ("input tensor is too large",
    # around 2**24 elements). At or above this size the threshold is taken
    # with torch.kthvalue instead, which selects the same elements.
    _QUANTILE_MAX_ELEMENTS = 2 ** 24

    _METHODS = ("magnitude", "gradient")

    def __init__(self, model: nn.Module):
        """Store the model to prune and start with no masks."""
        self.model = model
        self.masks: Dict[str, torch.Tensor] = {}

    def compute_importance(self, method="magnitude"):
        """Return a per-element importance score for each parameter.

        Args:
            method: ``"magnitude"`` scores by ``|w|``; ``"gradient"`` scores
                by ``|w * grad|`` and therefore needs gradients to be
                populated.

        Returns:
            Dict mapping parameter name to a score tensor shaped like the
            parameter. With ``"gradient"``, parameters whose ``grad`` is
            ``None`` are left out, so they never receive a mask.

        Raises:
            ValueError: If ``method`` is not a supported name.
            RuntimeError: If ``"gradient"`` is requested but no parameter
                has a gradient.
        """
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown importance method {method!r}; "
                f"expected one of {self._METHODS}"
            )

        importance = {}
        saw_parameter = False
        for name, param in self.model.named_parameters():
            saw_parameter = True
            if method == "magnitude":
                importance[name] = torch.abs(param.data)
            elif param.grad is not None:
                importance[name] = torch.abs(param.data * param.grad.data)

        if saw_parameter and not importance:
            # Only reachable for "gradient": nothing has a .grad yet.
            raise RuntimeError(
                "Gradient importance requires gradients; "
                "run a backward pass before computing importance"
            )
        return importance

    def _global_threshold(self, scores, sparsity):
        """Return the score at the ``sparsity`` quantile of ``scores``."""
        count = scores.numel()
        if count < self._QUANTILE_MAX_ELEMENTS:
            return torch.quantile(scores, sparsity)

        if not 0.0 <= sparsity <= 1.0:
            # Same exception type torch.quantile raises for an invalid q.
            raise RuntimeError("sparsity must be in the range [0, 1]")

        # torch.quantile interpolates between the two order statistics that
        # surround position q * (n - 1). Because the mask test is
        # ``score >= threshold``, taking the upper neighbour keeps exactly
        # the same elements.
        position = sparsity * (count - 1)
        rank = int(position)
        if rank < position:
            rank += 1
        return torch.kthvalue(scores, rank + 1).values

    def create_mask(self, sparsity: float = 0.5, method: str = "magnitude"):
        """Build keep/prune masks from one global importance threshold.

        Args:
            sparsity: Quantile in ``[0, 1]`` of the pooled scores used as
                the threshold; elements scoring below it are pruned.
            method: Importance method, passed to ``compute_importance``.

        Masks from an earlier call are discarded. A model without scored
        parameters simply ends up with no masks.
        """
        importance = self.compute_importance(method)
        self.masks.clear()
        if not importance:
            return

        all_values = torch.cat([v.flatten() for v in importance.values()])
        threshold = self._global_threshold(all_values, sparsity)
        for name, imp in importance.items():
            self.masks[name] = (imp >= threshold).float()

    def apply_mask(self):
        """Multiply every masked parameter by its mask, in place."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]

    def get_compression_stats(self):
        """Summarise how many parameter elements the masks prune.

        Returns:
            Dict with ``total_params``, ``zero_params`` (zero entries in the
            masks), ``sparsity`` (``zero_params / total_params``) and
            ``compression_ratio`` (total over kept elements). A model with
            no parameters reports sparsity 0.0 and ratio 1.0; a fully
            masked model reports a ratio of ``inf``.
        """
        total_params = 0
        zero_params = 0
        for name, param in self.model.named_parameters():
            total_params += param.numel()
            if name in self.masks:
                zero_params += (self.masks[name] == 0).sum().item()

        kept_params = total_params - zero_params
        if total_params == 0:
            sparsity = 0.0
            compression_ratio = 1.0
        elif kept_params == 0:
            sparsity = 1.0
            compression_ratio = float("inf")
        else:
            sparsity = zero_params / total_params
            compression_ratio = total_params / kept_params

        return {
            "total_params": total_params,
            "zero_params": zero_params,
            "sparsity": sparsity,
            "compression_ratio": compression_ratio,
        }
# Example usage: mask the model with a sparsity target of 0.7, zero the
# masked weights in place, and print the resulting compression ratio.
# NOTE(review): `model` is not defined in the visible source; it is assumed
# to be an nn.Module created beforehand.
pruner = WeightPruner(model)
pruner.create_mask(sparsity=0.7)
pruner.apply_mask()
stats = pruner.get_compression_stats()
print(f"Compression: {stats['compression_ratio']:.2f}x")
Quantization
import torch
import torch.nn as nn
class DynamicQuantizer:
    """Dynamic int8 quantization of a model's ``nn.Linear`` layers."""

    def __init__(self, model: nn.Module):
        """Store the model to quantize."""
        self.model = model

    def quantize_linear_layers(self):
        """Return ``self.model`` with its ``nn.Linear`` layers quantized.

        Calls ``torch.quantization.quantize_dynamic`` for ``nn.Linear``
        with ``torch.qint8`` and returns its result.
        """
        quantized_model = torch.quantization.quantize_dynamic(
            self.model,
            {nn.Linear},
            dtype=torch.qint8,
        )
        return quantized_model

    def get_model_size(self, model):
        """Return the size of ``model``'s serialized state dict in MiB.

        The state dict is saved to ``temp.pq`` inside a private temporary
        directory. That avoids overwriting (and then deleting) an unrelated
        ``temp.pq`` in the working directory, and the scratch file is
        removed even if saving or measuring raises.

        Args:
            model: Any module exposing ``state_dict()``.

        Returns:
            File size in bytes divided by 1024 twice.
        """
        import os
        import tempfile

        with tempfile.TemporaryDirectory() as scratch_dir:
            path = os.path.join(scratch_dir, "temp.pq")
            torch.save(model.state_dict(), path)
            size_mb = os.path.getsize(path) / 1024 / 1024
        return size_mb
class PostTrainingQuantizer:
    """Static post-training quantization driven by calibration data."""

    @staticmethod
    def quantize(model, calibration_data, backend="fbgemm"):
        """Quantize ``model`` in place after calibrating on sample batches.

        The model is switched to eval mode, given the default qconfig for
        ``backend``, prepared in place, run once on every calibration batch
        with gradients disabled, and then converted in place.

        Args:
            model: Module to quantize; it is modified in place.
            calibration_data: Iterable of batches; each one is passed to
                the model as its single positional argument.
            backend: Name handed to
                ``torch.quantization.get_default_qconfig``. Defaults to
                ``"fbgemm"``, the value previously hard-coded;
                ``"qnnpack"`` is the usual alternative for ARM targets.

        Returns:
            The same ``model`` object, now converted.
        """
        model.eval()
        model.qconfig = torch.quantization.get_default_qconfig(backend)
        torch.quantization.prepare(model, inplace=True)

        # Calibration only feeds the observers; no gradients are needed.
        with torch.no_grad():
            for batch in calibration_data:
                model(batch)

        torch.quantization.convert(model, inplace=True)
        return model
# Example usage: dynamically quantize the Linear layers of `model`, then
# print the serialized state-dict size before and after.
# NOTE(review): `model` is not defined in the visible source; it is assumed
# to be an nn.Module created beforehand.
quantizer = DynamicQuantizer(model)
quantized = quantizer.quantize_linear_layers()
original_size = quantizer.get_model_size(model)
quantized_size = quantizer.get_model_size(quantized)
print(f"Original: {original_size:.2f}MB, Quantized: {quantized_size:.2f}MB")
Knowledge Distillation
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Blend of a softened teacher-matching loss and a hard-label loss.

    The soft term is the ``batchmean`` KL divergence between the
    temperature-scaled teacher and student distributions over ``dim=1``,
    multiplied by the temperature squared. The hard term is the cross
    entropy of the unscaled student logits against ``labels``. The result
    is ``alpha * soft + (1 - alpha) * hard``.
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits, labels):
        """Return the weighted distillation loss for one batch."""
        temp = self.temperature

        # Soft targets: the student is given as log-probabilities and the
        # teacher as probabilities, which is the form F.kl_div expects.
        student_log_probs = F.log_softmax(student_logits / temp, dim=1)
        teacher_probs = F.softmax(teacher_logits / temp, dim=1)
        divergence = F.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean"
        )
        soft_loss = divergence * (temp ** 2)

        # Hard targets: ordinary classification loss on the raw logits.
        hard_loss = F.cross_entropy(student_logits, labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
class DistillationTrainer:
    """Trains a student model to match a teacher with ``DistillationLoss``."""

    def __init__(self, teacher, student, temperature=4.0, alpha=0.7, lr=1e-4):
        """Set up the loss and an Adam optimizer over the student.

        Args:
            teacher: Model providing target logits; never optimized here.
            student: Model being trained.
            temperature: Softening temperature for ``DistillationLoss``.
            alpha: Soft-loss weight for ``DistillationLoss``. The default
                equals that class's own default.
            lr: Adam learning rate. The default is the value that was
                previously hard-coded.
        """
        self.teacher = teacher
        self.student = student
        self.criterion = DistillationLoss(temperature, alpha)
        self.optimizer = torch.optim.Adam(student.parameters(), lr=lr)

    def train_epoch(self, dataloader):
        """Run one optimization pass over ``dataloader``.

        Args:
            dataloader: Iterable yielding ``(batch, labels)`` pairs. It
                does not need to support ``len()``.

        Returns:
            Mean loss over the batches seen, or ``0.0`` if there were none.
        """
        self.teacher.eval()
        self.student.train()

        total_loss = 0.0
        num_batches = 0
        for batch, labels in dataloader:
            # The teacher only supplies targets, so skip autograd for it.
            with torch.no_grad():
                teacher_logits = self.teacher(batch)
            student_logits = self.student(batch)
            loss = self.criterion(student_logits, teacher_logits, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Batches are counted rather than using len(dataloader), so empty
        # loaders and plain iterators do not raise.
        if num_batches == 0:
            return 0.0
        return total_loss / num_batches
# Example usage: distil the teacher into the student for 10 epochs and
# print the mean loss after each one.
# NOTE(review): `teacher_model`, `student_model` and `train_loader` are not
# defined in the visible source; they must exist beforehand.
trainer = DistillationTrainer(teacher_model, student_model)
for epoch in range(10):
    loss = trainer.train_epoch(train_loader)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
Combined Compression
class CompressionPipeline:
    """Chains pruning, distillation and dynamic quantization on one model.

    Each ``apply_*`` method returns ``self`` so steps can be chained;
    ``compress`` runs the steps selected by a config dict.
    """

    def __init__(self, model):
        """Store the model to compress."""
        self.model = model
        # Kept so the pruning masks can be re-applied after fine-tuning.
        self._pruner = None

    def apply_pruning(self, sparsity=0.5):
        """Mask and zero the model's weights at the given sparsity."""
        pruner = WeightPruner(self.model)
        pruner.create_mask(sparsity)
        pruner.apply_mask()
        self._pruner = pruner
        return self

    def apply_quantization(self):
        """Replace ``self.model`` with its dynamically quantized version."""
        quantizer = DynamicQuantizer(self.model)
        self.model = quantizer.quantize_linear_layers()
        return self

    def apply_distillation(self, teacher, train_loader, epochs=5):
        """Train ``self.model`` as the student of ``teacher``.

        Args:
            teacher: Model providing the target logits.
            train_loader: Iterable of ``(batch, labels)`` pairs.
            epochs: Number of passes over ``train_loader``.
        """
        trainer = DistillationTrainer(teacher, self.model)
        for _ in range(epochs):
            trainer.train_epoch(train_loader)

        # Training updates every weight, including those pruning zeroed.
        # Re-apply the masks so the sparsity survives, but only if the
        # pruner still refers to the model that was just trained.
        pruner = self._pruner
        if pruner is not None and pruner.model is self.model:
            pruner.apply_mask()
        return self

    def compress(self, config):
        """Run the steps enabled in ``config`` and return the model.

        Recognised keys: ``prune`` with optional ``prune_sparsity``
        (default 0.5); ``distill`` with required ``teacher`` and
        ``train_loader`` and optional ``distill_epochs`` (default 5);
        ``quantize``.

        Steps run as prune, then distill, then quantize. Quantization goes
        last because distillation has to train the model's float weights.

        Raises:
            KeyError: If ``distill`` is set but ``teacher`` or
                ``train_loader`` is missing.
        """
        if config.get("prune"):
            self.apply_pruning(config.get("prune_sparsity", 0.5))
        if config.get("distill"):
            self.apply_distillation(
                config["teacher"],
                config["train_loader"],
                config.get("distill_epochs", 5),
            )
        if config.get("quantize"):
            self.apply_quantization()
        return self.model
# Example usage: prune at sparsity 0.5 and dynamically quantize in a single
# call; distillation is skipped because the config has no "distill" key.
# NOTE(review): `model` is not defined in the visible source; it is assumed
# to be an nn.Module created beforehand.
pipeline = CompressionPipeline(model)
compressed_model = pipeline.compress({
"prune": True,
"prune_sparsity": 0.5,
"quantize": True
})
Best Practices
- Start with quantization for the quickest wins
- Use structured pruning for hardware efficiency
- Combine techniques for maximum compression
- Validate accuracy after each compression step
- Benchmark on target hardware
- Monitor for accuracy degradation