Experiment Tracking (MLflow, W&B)
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Why Experiment Tracking Matters
Without proper tracking, ML experiments become unreproducible: you lose track of which hyperparameters, data versions, and model artifacts produced a given result.
ℹ️
Google's ML test score taxonomy requires experiment tracking as a foundational practice for production ML systems.
MLflow Experiment Tracking
# mlflow_tracking.py
import mlflow
import mlflow.sklearn
import mlflow.pytorch
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
import json
from typing import Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class ExperimentConfig:
    """Settings that identify one MLflow experiment and where it lives."""

    # Experiment name; passed to mlflow.set_experiment / create_experiment.
    experiment_name: str
    # Value handed to mlflow.set_tracking_uri (e.g. an http:// server URL).
    tracking_uri: str
    # Artifact root used when the experiment is first created.
    artifact_location: str
class MLflowTracker:
    """Small wrapper around MLflow for logging and comparing training runs.

    On construction the tracker points MLflow at ``config.tracking_uri``,
    makes sure the configured experiment exists (creating it with
    ``config.artifact_location`` if needed) and selects it as the active
    experiment.
    """

    def __init__(self, config: ExperimentConfig):
        self.config = config
        mlflow.set_tracking_uri(config.tracking_uri)
        # Create the experiment *before* selecting it. mlflow.set_experiment()
        # auto-creates a missing experiment with the default artifact
        # location, which would silently discard config.artifact_location.
        self.create_experiment()
        mlflow.set_experiment(config.experiment_name)
        self.client = MlflowClient(tracking_uri=config.tracking_uri)

    def create_experiment(self) -> str:
        """Return the id of the configured experiment, creating it if absent.

        Returns:
            The experiment id as reported by MLflow.

        Raises:
            mlflow.exceptions.MlflowException: if the experiment can neither
                be found nor created (e.g. the tracking server is down).
        """
        name = self.config.experiment_name
        existing = mlflow.get_experiment_by_name(name)
        if existing is not None:
            return existing.experiment_id
        try:
            return mlflow.create_experiment(
                name,
                artifact_location=self.config.artifact_location,
            )
        except mlflow.exceptions.MlflowException:
            # Another process may have created the experiment between the
            # lookup and the create call; any other failure is re-raised.
            existing = mlflow.get_experiment_by_name(name)
            if existing is None:
                raise
            return existing.experiment_id

    def log_training_run(
        self,
        model,
        X_train, X_test, y_train, y_test,
        params: Dict[str, Any],
        run_name: str
    ) -> str:
        """Evaluate a fitted model on the test split and log it as one run.

        Logs ``params`` plus ``model_type`` and ``features``, weighted
        accuracy/F1/precision/recall together with the split sizes, the
        model itself (registered as ``"<experiment_name>-model"``) and, when
        the model exposes ``feature_importances_``, a JSON file of them.

        Args:
            model: An already-fitted estimator with a ``predict`` method.
            X_train: Training features; must expose ``.shape`` and ``len``.
            X_test: Test features passed to ``model.predict``.
            y_train: Training labels (accepted for symmetry; not used here).
            y_test: Ground-truth labels for the test features.
            params: Hyperparameters to record for the run.
            run_name: Display name of the MLflow run.

        Returns:
            The id of the MLflow run that was created.
        """
        with mlflow.start_run(run_name=run_name) as run:
            mlflow.log_params(params)

            y_pred = model.predict(X_test)
            metrics = {
                "accuracy": accuracy_score(y_test, y_pred),
                "f1_score": f1_score(y_test, y_pred, average="weighted"),
                "precision": precision_score(y_test, y_pred, average="weighted"),
                "recall": recall_score(y_test, y_pred, average="weighted"),
                "train_size": len(X_train),
                "test_size": len(X_test),
            }
            mlflow.log_metrics(metrics)

            mlflow.log_param("model_type", type(model).__name__)
            mlflow.log_param("features", X_train.shape[1])

            mlflow.sklearn.log_model(
                model,
                artifact_path="model",
                registered_model_name=f"{self.config.experiment_name}-model"
            )

            # Only tree-style estimators expose feature_importances_.
            feature_importance = getattr(model, 'feature_importances_', None)
            if feature_importance is not None:
                importance_dict = {
                    f"feature_{i}": float(v)
                    for i, v in enumerate(feature_importance)
                }
                mlflow.log_dict(importance_dict, "feature_importance.json")

        return run.info.run_id

    def compare_runs(self, experiment_name: str, top_n: int = 5):
        """Return the ``top_n`` runs of an experiment ranked by accuracy.

        Args:
            experiment_name: Name of the experiment to query.
            top_n: Maximum number of runs to return.

        Returns:
            A DataFrame with the columns ``run_id``, ``metrics.accuracy``,
            ``metrics.f1_score`` and ``params.model_type``. Columns that no
            run has logged are present but filled with NaN.

        Raises:
            ValueError: if no experiment with that name exists.
        """
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment is None:
            raise ValueError(f"Experiment {experiment_name!r} does not exist.")
        runs = mlflow.search_runs(
            experiment_ids=[experiment.experiment_id],
            order_by=["metrics.accuracy DESC"],
            max_results=top_n
        )
        columns = ["run_id", "metrics.accuracy", "metrics.f1_score", "params.model_type"]
        # reindex instead of runs[columns]: an experiment with no runs (or
        # runs that never logged one of these fields) lacks the column, and
        # plain selection would raise KeyError.
        return runs.reindex(columns=columns)
# Usage
# End-to-end example: train a baseline random forest on synthetic data and
# log it to a local MLflow server.
config = ExperimentConfig(
    experiment_name="churn-prediction-v2",
    tracking_uri="http://localhost:5000",
    artifact_location="s3://ml-artifacts/churn"
)
tracker = MLflowTracker(config)
tracker.create_experiment()

from sklearn.datasets import make_classification

# One seed for data generation, the split and the model, so the logged run
# can be reproduced exactly.
SEED = 42

X, y = make_classification(n_samples=1000, n_features=20, random_state=SEED)
# The split must be seeded too; otherwise every execution evaluates on a
# different test set and the logged metrics are not comparable across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED
)

params = {"n_estimators": 100, "max_depth": 10, "random_state": SEED}
model = RandomForestClassifier(**params)
model.fit(X_train, y_train)

run_id = tracker.log_training_run(model, X_train, X_test, y_train, y_test, params, "rf-baseline")
print(f"Run ID: {run_id}")
Weights & Biases Integration
# wandb_tracking.py
import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class WandBConfig:
    """Settings shared by every W&B run started through ``WandBTracker``.

    Named ``WandBConfig`` because ``W&BConfig`` is not a valid Python
    identifier (``&`` is an operator).
    """

    # W&B project the runs are logged to.
    project: str
    # W&B entity (user or team) that owns the project.
    entity: str
    # Tags attached to each run.
    tags: list
    # Free-form notes attached to each run.
    notes: str


class WandBTracker:
    """Convenience wrapper around the module-level ``wandb`` logging API."""

    def __init__(self, config: WandBConfig):
        self.config = config

    def init_run(self, run_name: Optional[str] = None, hyperparameters: Optional[Dict] = None):
        """Start a W&B run using the stored config and return the run object.

        Args:
            run_name: Optional display name for the run.
            hyperparameters: Optional dict stored as the run's config.
        """
        run = wandb.init(
            project=self.config.project,
            entity=self.config.entity,
            name=run_name,
            tags=self.config.tags,
            notes=self.config.notes,
            config=hyperparameters
        )
        return run

    def log_training_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Log a dict of metrics to the active run, optionally at ``step``."""
        wandb.log(metrics, step=step)

    def log_model(self, model: nn.Module, name: str):
        """Save the model's ``state_dict`` to ``<name>.pth`` and log it as a model artifact.

        The checkpoint file is written relative to the current working
        directory.
        """
        torch.save(model.state_dict(), f"{name}.pth")
        artifact = wandb.Artifact(name, type="model")
        artifact.add_file(f"{name}.pth")
        wandb.log_artifact(artifact)

    def log_dataset(self, data: np.ndarray, name: str, columns: Optional[list] = None):
        """Log an array as a ``wandb.Table`` under the key ``name``.

        Args:
            data: 1-D or 2-D array; a 1-D array is logged as a single column.
            name: Key under which the table is logged.
            columns: Optional column names. When omitted, ``col_0``,
                ``col_1``, ... are generated to match the data width.
        """
        rows = np.asarray(data)
        if rows.ndim == 1:
            rows = rows.reshape(-1, 1)
        if columns is None:
            # wandb.Table falls back to a fixed three-column header when no
            # columns are given, which rejects rows of any other width.
            columns = [f"col_{i}" for i in range(rows.shape[1])]
        table = wandb.Table(columns=columns, data=rows.tolist())
        wandb.log({name: table})

    def log_confusion_matrix(self, y_true, y_pred, labels):
        """Log a confusion-matrix plot under the key ``confusion_matrix``."""
        wandb.log({
            "confusion_matrix": wandb.plot.confusion_matrix(
                y_true=y_true,
                preds=y_pred,
                labels=labels
            )
        })

    def log_hyperparameter_sweep(self, sweep_config: Dict):
        """Register a sweep in the configured project and return its id."""
        sweep_id = wandb.sweep(sweep_config, project=self.config.project)
        return sweep_id


class SimpleNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


# Training loop with W&B
# Single source of truth for the hyperparameters: the same dict is logged to
# W&B and used to build the optimizer, data loader and epoch loop, so the
# recorded config cannot drift from what actually ran.
hyperparameters = {
    "lr": 0.001,
    "batch_size": 32,
    "epochs": 10
}

config = WandBConfig(
    project="image-classification",
    entity="ml-team",
    tags=["resnet", "production"],
    notes="ResNet18 fine-tuning run"
)
tracker = WandBTracker(config)

model = SimpleNet(784, 256, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters["lr"])

# Random tensors stand in for a real dataset.
X_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=hyperparameters["batch_size"], shuffle=True)

# Using the run as a context manager finishes it even when training raises.
with tracker.init_run(run_name="resnet18-run-1", hyperparameters=hyperparameters) as run:
    for epoch in range(hyperparameters["epochs"]):
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0
        for batch_X, batch_y in dataloader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            _, predicted = outputs.max(1)
            total += batch_y.size(0)
            correct += predicted.eq(batch_y).sum().item()
        # One metrics record per epoch, averaged over its batches.
        metrics = {
            "train/loss": epoch_loss / len(dataloader),
            "train/accuracy": correct / total,
            "epoch": epoch
        }
        tracker.log_training_metrics(metrics, step=epoch)
Custom Tracking Framework
# custom_tracker.py
import json
import hashlib
import time
from pathlib import Path
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import pickle
@dataclass
class ExperimentRun:
    """Record of a single tracked run, stored as one JSON file."""

    run_id: str
    experiment_name: str
    params: Dict[str, Any]
    metrics: Dict[str, float]
    artifacts: List[str]
    git_commit: Optional[str]
    timestamp: str
    duration_seconds: float
    status: str


class ExperimentTracker:
    """File-based experiment tracker.

    Each finished run is written to
    ``<storage_path>/<experiment_name>/<run_id>.json``; pickled models go to
    ``<storage_path>/<run_id>/artifacts/``. Only one run is active at a time.
    """

    def __init__(self, storage_path: str = "./experiments"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.current_run: Optional[ExperimentRun] = None
        self.start_time: Optional[float] = None

    def _generate_run_id(self) -> str:
        """Return a random 12-character hexadecimal run id."""
        # A random UUID instead of a hash of time.time(): two runs started
        # within the same clock tick would otherwise receive the same id and
        # overwrite each other's files.
        import uuid
        return uuid.uuid4().hex[:12]

    def start_run(self, experiment_name: str, params: Dict[str, Any]) -> str:
        """Begin a new run and make it the active one.

        Args:
            experiment_name: Experiment the run belongs to; also the name of
                the directory the run file is written to.
            params: Hyperparameters to record with the run.

        Returns:
            The id of the newly started run.
        """
        run_id = self._generate_run_id()
        self.start_time = time.time()
        self.current_run = ExperimentRun(
            run_id=run_id,
            experiment_name=experiment_name,
            params=params,
            metrics={},
            artifacts=[],
            git_commit=self._get_git_commit(),
            timestamp=datetime.now().isoformat(),
            duration_seconds=0,
            status="running"
        )
        return run_id

    def log_metric(self, name: str, value: float, step: Optional[int] = None):
        """Record one metric value on the active run.

        With ``step`` given, the value is stored under ``"<name>_step_<step>"``
        so per-step values do not overwrite each other.

        Raises:
            ValueError: if no run is active.
        """
        if self.current_run is None:
            raise ValueError("No active run. Call start_run first.")
        key = f"{name}" if step is None else f"{name}_step_{step}"
        self.current_run.metrics[key] = value

    def log_metrics(self, metrics: Dict[str, float]):
        """Record several metrics (without a step) on the active run."""
        for name, value in metrics.items():
            self.log_metric(name, value)

    def log_artifact(self, artifact_path: str):
        """Attach an artifact path to the active run.

        Raises:
            ValueError: if no run is active.
        """
        if self.current_run is None:
            raise ValueError("No active run.")
        self.current_run.artifacts.append(artifact_path)

    def save_model(self, model, name: str):
        """Pickle ``model`` into the run's artifact directory and record its path.

        Raises:
            ValueError: if no run is active.
        """
        # Same guard as log_metric / log_artifact; without it the attribute
        # access below fails with an opaque AttributeError on None.
        if self.current_run is None:
            raise ValueError("No active run. Call start_run first.")
        artifact_dir = self.storage_path / self.current_run.run_id / "artifacts"
        artifact_dir.mkdir(parents=True, exist_ok=True)
        model_path = artifact_dir / f"{name}.pkl"
        # NOTE: pickle files must only ever be loaded from trusted storage.
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
        self.log_artifact(str(model_path))

    def end_run(self, status: str = "completed"):
        """Finalize the active run and write it to disk; no-op when idle.

        Args:
            status: Final status stored with the run.

        Raises:
            TypeError: if the run contains values JSON cannot serialize. In
                that case nothing is written and the run stays active.
        """
        if self.current_run is None:
            return
        run = self.current_run
        run.duration_seconds = time.time() - self.start_time
        run.status = status
        # Serialize fully in memory first: streaming json.dump into an open
        # file would leave a truncated, unparseable run file behind if a
        # value turns out not to be serializable.
        payload = json.dumps(asdict(run), indent=2)
        run_dir = self.storage_path / run.experiment_name
        run_dir.mkdir(parents=True, exist_ok=True)
        run_file = run_dir / f"{run.run_id}.json"
        with open(run_file, "w") as f:
            f.write(payload)
        self.current_run = None
        self.start_time = None

    def _get_git_commit(self) -> Optional[str]:
        """Return the current git HEAD hash, or None when it is unavailable."""
        import subprocess
        try:
            result = subprocess.run(
                ["git", "rev-parse", "HEAD"],
                capture_output=True, text=True, timeout=5
            )
        except (OSError, subprocess.SubprocessError):
            # git is not installed, not executable, or timed out.
            return None
        commit = result.stdout.strip()
        # Outside a repository git exits non-zero with empty stdout; report
        # that as "no commit" rather than as an empty string.
        if result.returncode != 0 or not commit:
            return None
        return commit

    def get_experiment_runs(self, experiment_name: str) -> List[ExperimentRun]:
        """Load every stored run of an experiment, newest first.

        Returns:
            Runs sorted by their ISO timestamp in descending order; an empty
            list when the experiment has no directory yet.
        """
        experiment_dir = self.storage_path / experiment_name
        if not experiment_dir.exists():
            return []
        runs = []
        for run_file in experiment_dir.glob("*.json"):
            with open(run_file) as f:
                data = json.load(f)
            runs.append(ExperimentRun(**data))
        return sorted(runs, key=lambda r: r.timestamp, reverse=True)
Follow-Up Questions
- How do you handle experiment tracking when training distributed across multiple nodes?
- What are the trade-offs between MLflow and W&B for enterprise use cases?
- How would you implement experiment tracking for AutoML systems?
- What governance requirements should experiment tracking satisfy?