Model Sizing
Model sizing means choosing the parameter count, the amount of training data, and the compute budget, guided by scaling laws and by deployment constraints.
Scaling Law Calculator
import numpy as np
from dataclasses import dataclass
@dataclass
class ScalingConfig:
    """One candidate training setup: model size, dataset size and compute."""

    params: int  # Number of parameters
    tokens: int  # Training tokens
    compute: float  # FLOPs


class ScalingLawCalculator:
    """Rough calculator relating parameters, training tokens and compute.

    The loss model is ``1 / N**alpha + 1 / D**beta`` with unit coefficients
    and no irreducible-loss term, so its values are only useful for ranking
    candidate configurations against each other, not as absolute predictions.
    """

    # FLOPs spent per parameter per training token (the "6 * N * D" rule).
    FLOPS_PER_PARAM_TOKEN = 6

    def __init__(self, alpha: float = 0.34, beta: float = 0.28,
                 tokens_per_param: float = 20):
        """Store the scaling exponents and the tokens-per-parameter ratio.

        Args:
            alpha: Parameter scaling exponent.
            beta: Data scaling exponent.
            tokens_per_param: Ratio used by ``compute_optimal_tokens`` and
                ``compute_optimal_params`` (Chinchilla heuristic: ~20).

        Raises:
            ValueError: If ``tokens_per_param`` is not positive.
        """
        if tokens_per_param <= 0:
            raise ValueError("tokens_per_param must be positive")
        self.alpha = alpha
        self.beta = beta
        self.tokens_per_param = tokens_per_param

    def compute_optimal_tokens(self, params: int) -> int:
        """Chinchilla optimal: ~20 tokens per parameter."""
        return int(params * self.tokens_per_param)

    def compute_optimal_params(self, tokens: int) -> int:
        """Given tokens, return the parameter count implied by the ratio."""
        return int(tokens / self.tokens_per_param)

    def estimate_loss(self, params: int, tokens: int) -> float:
        """Estimate loss as ``1 / params**alpha + 1 / tokens**beta``.

        Raises:
            ValueError: If ``params`` or ``tokens`` is not positive.
        """
        # A negative base raised to a fractional power yields a complex
        # number in Python, and zero divides by zero; reject both up front.
        if params <= 0 or tokens <= 0:
            raise ValueError("params and tokens must be positive")
        return 1 / (params ** self.alpha) + 1 / (tokens ** self.beta)

    def compute_flops(self, params: int, tokens: int) -> float:
        """Estimate training FLOPs (6 * N * D)."""
        return self.FLOPS_PER_PARAM_TOKEN * params * tokens

    def find_optimal_config(self, budget_flops: float, *,
                            min_params_exp: int = 9,
                            max_params_exp: int = 12,
                            steps_per_decade: int = 1) -> "ScalingConfig | None":
        """Find the lowest-estimated-loss (N, D) pair for a compute budget.

        Candidate parameter counts are spaced geometrically from
        ``10**min_params_exp`` to ``10**max_params_exp`` (both inclusive).
        For each one, the token count is whatever the budget affords under
        the ``6 * N * D`` rule. The defaults reproduce the original grid of
        1e9, 1e10, 1e11 and 1e12 parameters (1B to 1T).

        Args:
            budget_flops: Total training compute available, in FLOPs.
            min_params_exp: Base-10 exponent of the smallest candidate.
            max_params_exp: Base-10 exponent of the largest candidate.
            steps_per_decade: Number of candidates per factor of ten; raise
                it for a finer search.

        Returns:
            The best ``ScalingConfig``, or ``None`` when the budget cannot
            pay for a single token at any candidate size.

        Raises:
            ValueError: If ``steps_per_decade < 1`` or the exponent range is
                empty.
        """
        if steps_per_decade < 1:
            raise ValueError("steps_per_decade must be at least 1")
        if max_params_exp < min_params_exp:
            raise ValueError("max_params_exp must be >= min_params_exp")

        best_loss = float("inf")
        best_config = None
        total_steps = (max_params_exp - min_params_exp) * steps_per_decade
        for step in range(total_steps + 1):
            decade, offset = divmod(step, steps_per_decade)
            # Whole decades use exact integer powers so the default grid is
            # bit-for-bit the same as the original hard-coded one.
            params = 10 ** (min_params_exp + decade)
            if offset:
                params = round(params * 10 ** (offset / steps_per_decade))
            tokens = int(budget_flops / (self.FLOPS_PER_PARAM_TOKEN * params))
            if tokens <= 0:
                continue  # budget too small for this model size
            loss = self.estimate_loss(params, tokens)
            if loss < best_loss:
                best_loss = loss
                best_config = ScalingConfig(params, tokens, budget_flops)
        return best_config
# Usage: search the grid of 1e9 .. 1e12 parameters (powers of ten) for a 1e24-FLOP budget.
calculator = ScalingLawCalculator()
optimal = calculator.find_optimal_config(budget_flops=1e24)
# With alpha=0.34 and beta=0.28 the lowest estimated loss on that grid is
# ScalingConfig(params=100000000000, tokens=1666666666666, compute=1e+24)
Model Profiler
import torch
from thop import profile
class ModelProfiler:
    """Report parameter counts, a rough memory estimate and MACs for a model.

    The wrapped ``model`` must expose ``parameters()`` yielding objects with
    ``numel()``, ``element_size()``, ``requires_grad`` and ``device``.
    """

    def __init__(self, model):
        self.model = model

    def count_parameters(self) -> dict:
        """Return total, trainable and non-trainable parameter counts."""
        total = 0
        trainable = 0
        # Single pass instead of iterating the parameters twice.
        for param in self.model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return {"total": total, "trainable": trainable, "non_trainable": total - trainable}

    def estimate_memory(self, batch_size: int = 1, seq_len: int = 512,
                        bytes_per_token: int = 4 * 1024) -> dict:
        """Estimate memory for the weights plus one batch, in GiB.

        Args:
            batch_size: Number of sequences in the batch.
            seq_len: Tokens per sequence.
            bytes_per_token: Bytes budgeted per token position for the batch
                term. The default (4 * 1024) keeps the original rough
                estimate; pass a model-specific figure for better numbers.

        Returns:
            Dict with ``parameters_gb``, ``batch_memory_gb`` and ``total_gb``
            (all divided by 1024**3).
        """
        param_memory = sum(p.numel() * p.element_size() for p in self.model.parameters())
        batch_memory = batch_size * seq_len * bytes_per_token  # Rough estimate
        gib = 1024**3
        return {
            "parameters_gb": param_memory / gib,
            "batch_memory_gb": batch_memory / gib,
            "total_gb": (param_memory + batch_memory) / gib,
        }

    def _model_device(self):
        """Return the device of the first parameter, or CPU if there is none."""
        # next() with a default avoids a bare StopIteration on models that
        # have no parameters at all.
        first_param = next(iter(self.model.parameters()), None)
        return torch.device("cpu") if first_param is None else first_param.device

    def profile_inference(self, input_shape: tuple = (1, 512)) -> dict:
        """Profile one forward pass with ``thop.profile`` on a random input.

        NOTE: the dummy input is floating point (``torch.randn``); a model
        that expects integer token ids will need a different dummy input.

        Returns:
            Dict with ``macs`` and ``params`` as reported by ``profile`` and
            ``flops`` computed as ``macs * 2``.
        """
        dummy_input = torch.randn(*input_shape).to(self._model_device())
        macs, params = profile(self.model, inputs=(dummy_input,), verbose=False)
        return {"macs": macs, "params": params, "flops": macs * 2}
# Usage
# NOTE(review): `model` is not defined in this snippet; presumably it is an
# already-constructed torch.nn.Module created elsewhere — TODO confirm.
profiler = ModelProfiler(model)
params = profiler.count_parameters()  # {"total": ..., "trainable": ..., "non_trainable": ...}
memory = profiler.estimate_memory(batch_size=32, seq_len=1024)  # values are bytes / 1024**3
Quantization
import torch
from transformers import BitsAndBytesConfig
class ModelQuantizer:
    """Load models with bitsandbytes quantization and estimate size savings."""

    # Storage width, in bits per parameter, for each recognised label.
    BITS_PER_PARAM = {"4bit": 4, "8bit": 8, "fp16": 16, "fp32": 32}
    # Width assumed for labels missing from BITS_PER_PARAM.
    DEFAULT_BITS = 16

    def __init__(self):
        self.quantization_configs = {
            "4bit": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            ),
            "8bit": BitsAndBytesConfig(load_in_8bit=True),
        }

    def quantize_model(self, model_name: str, precision: str = "4bit"):
        """Load ``model_name`` as a causal LM with the chosen quantization.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            precision: Key of ``self.quantization_configs`` ("4bit" or "8bit").

        Raises:
            KeyError: If ``precision`` has no registered configuration.
        """
        # Imported lazily, as in the original, so the class can be used for
        # size estimates without loading the model classes.
        from transformers import AutoModelForCausalLM

        try:
            config = self.quantization_configs[precision]
        except KeyError:
            supported = sorted(self.quantization_configs)
            raise KeyError(
                f"Unsupported precision {precision!r}; expected one of {supported}"
            ) from None
        return AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=config, device_map="auto"
        )

    def get_size_reduction(self, original_params: int, precision: str,
                           original_precision: str = "fp32") -> dict:
        """Compare weight storage before and after a change of precision.

        Args:
            original_params: Number of parameters in the model.
            precision: Target precision label ("4bit", "8bit", "fp16", "fp32").
            original_precision: Precision label of the unquantized weights.
                Defaults to "fp32", the baseline the original code assumed.

        Returns:
            Dict with ``original_size_gb`` and ``quantized_size_gb`` (bytes
            divided by 1024**3) and ``reduction_factor``.

        Note:
            Labels not present in ``BITS_PER_PARAM`` are treated as 16-bit
            rather than rejected, matching the original behaviour for
            ``precision``.
        """
        original_bits = self.BITS_PER_PARAM.get(original_precision, self.DEFAULT_BITS)
        target_bits = self.BITS_PER_PARAM.get(precision, self.DEFAULT_BITS)
        bits_per_gib = 8 * 1024**3
        return {
            "original_size_gb": (original_params * original_bits) / bits_per_gib,
            "quantized_size_gb": (original_params * target_bits) / bits_per_gib,
            "reduction_factor": original_bits / target_bits,
        }
# Usage
quantizer = ModelQuantizer()
reduction = quantizer.get_size_reduction(7_000_000_000, "4bit")
# {"original_size_gb": ~26.08, "quantized_size_gb": ~3.26, "reduction_factor": 8.0}
Key Takeaways
- The Chinchilla-optimal rule of thumb is roughly 20 training tokens per parameter
- Over-training (Llama-style) trains a smaller model on more data than the Chinchilla ratio suggests, trading extra training compute for cheaper inference
- Scaling laws predict performance from compute, data, and parameters
- Quantization reduces memory with minimal quality loss
- MoE architectures scale parameters without proportional compute increase