🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans from $10

Experiment Tracking (MLflow, W&B)

MLOps / Experiment Tracking ⭐ Premium

Advertisement

Experiment Tracking (MLflow, W&B)

Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe

Why Experiment Tracking Matters

Without proper tracking, ML experiments become irreproducible: hyperparameters, data versions, and model artifacts get lost, so results can no longer be recreated or compared.

ℹ️

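Google's ML Test Score rubric (Breck et al., 2017) lists reproducible training and version-controlled model specs among its production-readiness criteria — goals that systematic experiment tracking directly supports.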

MLflow Experiment Tracking

# mlflow_tracking.py
import mlflow
import mlflow.sklearn
import mlflow.pytorch
from mlflow.tracking import MlflowClient
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
import json
from typing import Dict, Any, Optional
from dataclasses import dataclass

@dataclass()
class ExperimentConfig:
    """Settings telling MLflowTracker where, and under which name, to record runs."""

    experiment_name: str    # MLflow experiment that runs are logged under
    tracking_uri: str       # MLflow tracking server URI, e.g. "http://localhost:5000"
    artifact_location: str  # artifact root passed to mlflow.create_experiment

class MLflowTracker:
    """Thin wrapper over the MLflow fluent API for logging sklearn training runs."""

    def __init__(self, config: ExperimentConfig):
        """Point MLflow at ``config.tracking_uri`` and activate the configured experiment.

        The experiment is created (with ``config.artifact_location``) *before* it is
        activated: ``mlflow.set_experiment`` auto-creates a missing experiment using
        the server's default artifact root, which would otherwise make the configured
        artifact location silently ineffective.
        """
        self.config = config
        mlflow.set_tracking_uri(config.tracking_uri)
        self.client = MlflowClient()
        self.create_experiment()
        mlflow.set_experiment(config.experiment_name)

    def create_experiment(self) -> str:
        """Return the id of the configured experiment, creating it if it does not exist.

        Idempotent: an existing experiment with the same name is reused (its
        artifact location is left unchanged).

        Raises:
            mlflow.exceptions.MlflowException: creation failed for a reason other
                than the experiment already existing.
        """
        name = self.config.experiment_name
        experiment = mlflow.get_experiment_by_name(name)
        if experiment is not None:
            return experiment.experiment_id
        try:
            return mlflow.create_experiment(
                name,
                artifact_location=self.config.artifact_location,
            )
        except mlflow.exceptions.MlflowException:
            # Another client may have created it between the lookup and the create;
            # anything else (auth, bad URI, deleted-name clash) is re-raised.
            experiment = mlflow.get_experiment_by_name(name)
            if experiment is None:
                raise
            return experiment.experiment_id

    def log_training_run(
        self,
        model,
        X_train, X_test, y_train, y_test,
        params: Dict[str, Any],
        run_name: str
    ) -> str:
        """Evaluate an already-fitted classifier and log it as one MLflow run.

        Logs ``params`` plus ``model_type``/``features`` params; accuracy and
        weighted F1/precision/recall on the test split plus the split sizes as
        metrics; the model itself (registered as ``<experiment_name>-model``);
        and ``feature_importance.json`` when the model exposes
        ``feature_importances_``.

        NOTE(review): ``registered_model_name`` needs a tracking backend with a
        model registry — confirm for the target server.

        Returns:
            The MLflow run id.
        """
        with mlflow.start_run(run_name=run_name) as run:
            mlflow.log_params(params)

            y_pred = model.predict(X_test)
            metrics = {
                "accuracy": accuracy_score(y_test, y_pred),
                "f1_score": f1_score(y_test, y_pred, average="weighted"),
                "precision": precision_score(y_test, y_pred, average="weighted"),
                "recall": recall_score(y_test, y_pred, average="weighted"),
                "train_size": len(X_train),
                "test_size": len(X_test),
            }
            mlflow.log_metrics(metrics)

            mlflow.log_param("model_type", type(model).__name__)
            # np.shape also accepts nested lists, not only arrays/DataFrames.
            mlflow.log_param("features", np.shape(X_train)[1])

            mlflow.sklearn.log_model(
                model,
                artifact_path="model",
                registered_model_name=f"{self.config.experiment_name}-model"
            )

            feature_importance = getattr(model, 'feature_importances_', None)
            if feature_importance is not None:
                importance_dict = {
                    f"feature_{i}": float(v)
                    for i, v in enumerate(feature_importance)
                }
                mlflow.log_dict(importance_dict, "feature_importance.json")

            return run.info.run_id

    def compare_runs(self, experiment_name: Optional[str] = None, top_n: int = 5):
        """Return the ``top_n`` runs of an experiment ordered by accuracy (best first).

        Args:
            experiment_name: experiment to query; defaults to the configured one.
            top_n: maximum number of runs returned.

        Returns:
            A DataFrame with columns run_id, metrics.accuracy, metrics.f1_score and
            params.model_type. Columns missing from the search result (e.g. when
            the experiment has no runs yet) are present and filled with NaN.

        Raises:
            ValueError: the experiment does not exist.
        """
        name = experiment_name if experiment_name is not None else self.config.experiment_name
        experiment = mlflow.get_experiment_by_name(name)
        if experiment is None:
            raise ValueError(f"Experiment '{name}' not found")
        runs = mlflow.search_runs(
            experiment_ids=[experiment.experiment_id],
            order_by=["metrics.accuracy DESC"],
            max_results=top_n
        )
        # An empty result lacks the metric/param columns; plain indexing would KeyError.
        return runs.reindex(
            columns=["run_id", "metrics.accuracy", "metrics.f1_score", "params.model_type"]
        )


# Usage
config = ExperimentConfig(
    experiment_name="churn-prediction-v2",
    tracking_uri="http://localhost:5000",
    artifact_location="s3://ml-artifacts/churn"
)

tracker = MLflowTracker(config)
tracker.create_experiment()

from sklearn.datasets import make_classification

SEED = 42
TEST_SIZE = 0.2

X, y = make_classification(n_samples=1000, n_features=20, random_state=SEED)
# Seed the split as well: an unseeded split evaluates every rerun on a different
# test set, so logged metrics would not be comparable or reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SEED
)

params = {"n_estimators": 100, "max_depth": 10, "random_state": SEED}
model = RandomForestClassifier(**params)
model.fit(X_train, y_train)

# Record the split settings next to the model hyperparameters so the run can be recreated.
logged_params = {**params, "test_size": TEST_SIZE, "split_random_state": SEED}
run_id = tracker.log_training_run(
    model, X_train, X_test, y_train, y_test, logged_params, "rf-baseline"
)
print(f"Run ID: {run_id}")

Weights & Biases Integration

# wandb_tracking.py
import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import Dict, Any, Optional
from dataclasses import dataclass

@dataclass
class WandBConfig:
    """Run metadata shared by every W&B run a WandBTracker starts.

    Named ``WandBConfig`` because ``&`` is not legal in a Python identifier:
    the original ``class W&BConfig:`` header was a SyntaxError.
    """

    project: str  # W&B project that receives the runs
    entity: str   # W&B user or team that owns the project
    tags: list    # tags attached to every run
    notes: str    # free-text description attached to every run


class WandBTracker:
    """Convenience wrapper over the global ``wandb`` API for one project/entity."""

    def __init__(self, config: WandBConfig):
        self.config = config

    def init_run(self, run_name: Optional[str] = None, hyperparameters: Optional[Dict] = None):
        """Start a W&B run with this tracker's project, entity, tags and notes.

        ``hyperparameters`` is stored as the run's ``config``. Returns the run
        object produced by ``wandb.init``.
        """
        run = wandb.init(
            project=self.config.project,
            entity=self.config.entity,
            name=run_name,
            tags=self.config.tags,
            notes=self.config.notes,
            config=hyperparameters
        )
        return run

    def log_training_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Log a dict of scalar metrics to the active run at ``step``."""
        wandb.log(metrics, step=step)

    def log_model(self, model: nn.Module, name: str):
        """Save ``model``'s state_dict to ``<name>.pth`` in the working directory and log it as a model artifact."""
        path = f"{name}.pth"
        torch.save(model.state_dict(), path)
        artifact = wandb.Artifact(name, type="model")
        artifact.add_file(path)
        wandb.log_artifact(artifact)

    def log_dataset(self, data: np.ndarray, name: str):
        """Log ``data`` as a W&B table under key ``name``, one row per sample.

        A 1-D array is treated as a single feature column.
        """
        rows = np.asarray(data)
        if rows.ndim == 1:
            # wandb.Table needs a list of rows; scalars are not rows.
            rows = rows.reshape(-1, 1)
        # Without explicit columns wandb.Table uses a legacy fixed 3-column header
        # and rejects rows of any other width (e.g. 20 or 784 features).
        columns = [f"col_{i}" for i in range(rows.shape[1])]
        table = wandb.Table(columns=columns, data=rows.tolist())
        wandb.log({name: table})

    def log_confusion_matrix(self, y_true, y_pred, labels):
        """Log a confusion-matrix chart; ``labels`` are the display names of the classes."""
        wandb.log({
            "confusion_matrix": wandb.plot.confusion_matrix(
                y_true=y_true,
                preds=y_pred,
                # The keyword is ``class_names``; passing ``labels=`` raises TypeError.
                class_names=labels,
            )
        })

    def log_hyperparameter_sweep(self, sweep_config: Dict):
        """Register a sweep under this tracker's entity/project and return its id."""
        # Pass the entity too, otherwise the sweep lands under the default entity
        # rather than the one the runs are logged to.
        sweep_id = wandb.sweep(
            sweep_config,
            entity=self.config.entity,
            project=self.config.project,
        )
        return sweep_id


class SimpleNet(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


# Training loop with W&B
# Single source of truth: the values logged to W&B are exactly the ones used below,
# so the recorded config cannot drift from what actually ran.
hparams = {"lr": 0.001, "batch_size": 32, "epochs": 10, "seed": 42}
# Seed before creating the model and data so reruns with the same config match.
torch.manual_seed(hparams["seed"])

config = WandBConfig(
    project="image-classification",
    entity="ml-team",
    # Describe what is actually trained here (a 2-layer MLP on synthetic data),
    # not a ResNet18 fine-tune, so the run is not mislabeled in the project.
    tags=["mlp", "baseline"],
    notes="SimpleNet MLP baseline on synthetic data"
)

tracker = WandBTracker(config)
run = tracker.init_run(run_name="simplenet-run-1", hyperparameters=hparams)

model = SimpleNet(784, 256, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hparams["lr"])

X_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=hparams["batch_size"], shuffle=True)

try:
    for epoch in range(hparams["epochs"]):
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0

        for batch_X, batch_y in dataloader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            _, predicted = outputs.max(1)
            total += batch_y.size(0)
            correct += predicted.eq(batch_y).sum().item()

        metrics = {
            "train/loss": epoch_loss / len(dataloader),
            "train/accuracy": correct / total,
            "epoch": epoch
        }
        tracker.log_training_metrics(metrics, step=epoch)
except BaseException:
    # Record the run as failed (non-zero exit code) instead of leaving it open
    # or reporting it as a clean finish, then propagate the error.
    run.finish(exit_code=1)
    raise

run.finish()

Custom Tracking Framework

# custom_tracker.py
import json
import hashlib
import time
from pathlib import Path
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import pickle

@dataclass
class ExperimentRun:
    """One tracked training run, as serialized to JSON by ExperimentTracker.end_run."""

    run_id: str                # 12 hex characters from ExperimentTracker._generate_run_id
    experiment_name: str       # groups runs; also the sub-directory the run JSON is written to
    params: Dict[str, Any]     # snapshot of the params passed to start_run (must be JSON-serializable)
    metrics: Dict[str, float]  # metric name, optionally suffixed "_step_<n>", -> value
    artifacts: List[str]       # paths registered via log_artifact / save_model
    git_commit: Optional[str]  # HEAD commit when the run started, or None if unavailable
    timestamp: str             # local start time, ISO-8601 without a UTC offset
    duration_seconds: float    # wall-clock seconds between start_run and end_run
    status: str                # "running" until end_run stores its status argument


class ExperimentTracker:
    """Minimal file-based experiment tracker with one active run at a time.

    On-disk layout under ``storage_path``:
      <experiment_name>/<run_id>.json  -- run records written by end_run
      <run_id>/artifacts/<name>.pkl    -- models written by save_model
    """

    def __init__(self, storage_path: str = "./experiments"):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.current_run: Optional[ExperimentRun] = None
        self.start_time: Optional[float] = None

    def _generate_run_id(self) -> str:
        """Return a random 12-hex-character run id.

        Hashing ``time.time()`` produced identical ids for runs started within
        the clock's resolution (or on different machines at the same instant);
        uuid4 draws 48 random bits per id instead.
        """
        import uuid
        return uuid.uuid4().hex[:12]

    def start_run(self, experiment_name: str, params: Dict[str, Any]) -> str:
        """Begin a new run and make it the active one; return its run id.

        NOTE: an already-active run is replaced without being written to disk —
        call end_run first.
        """
        run_id = self._generate_run_id()
        self.start_time = time.time()
        self.current_run = ExperimentRun(
            run_id=run_id,
            experiment_name=experiment_name,
            # Copy so later mutation of the caller's dict cannot rewrite the record.
            params=dict(params),
            metrics={},
            artifacts=[],
            git_commit=self._get_git_commit(),
            timestamp=datetime.now().isoformat(),
            duration_seconds=0.0,
            status="running"
        )
        return run_id

    def log_metric(self, name: str, value: float, step: Optional[int] = None):
        """Record ``value`` under ``name`` (or ``<name>_step_<step>`` when step is given).

        Raises:
            ValueError: no run is active.
        """
        if self.current_run is None:
            raise ValueError("No active run. Call start_run first.")
        key = name if step is None else f"{name}_step_{step}"
        self.current_run.metrics[key] = value

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Record every entry of ``metrics``, all at the same optional ``step``."""
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_artifact(self, artifact_path: str):
        """Register an artifact path on the active run (the file is not copied).

        Raises:
            ValueError: no run is active.
        """
        if self.current_run is None:
            raise ValueError("No active run.")
        self.current_run.artifacts.append(artifact_path)

    def save_model(self, model, name: str):
        """Pickle ``model`` to ``<storage>/<run_id>/artifacts/<name>.pkl`` and register it.

        Raises:
            ValueError: no run is active.
        """
        if self.current_run is None:
            raise ValueError("No active run.")
        artifact_dir = self.storage_path / self.current_run.run_id / "artifacts"
        artifact_dir.mkdir(parents=True, exist_ok=True)
        model_path = artifact_dir / f"{name}.pkl"
        # NOTE: unpickling executes code — only load these files from trusted storage.
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
        self.log_artifact(str(model_path))

    def end_run(self, status: str = "completed"):
        """Finish the active run and write it to ``<storage>/<experiment>/<run_id>.json``.

        No-op when no run is active.

        Raises:
            TypeError: a param/metric value is not JSON-serializable; nothing is
                written and the run stays active so it can be fixed and retried.
        """
        run = self.current_run
        if run is None:
            return

        run.duration_seconds = time.time() - self.start_time
        run.status = status

        # Serialize before touching the filesystem so a failure cannot leave a
        # truncated record behind.
        payload = json.dumps(asdict(run), indent=2)

        run_dir = self.storage_path / run.experiment_name
        run_dir.mkdir(parents=True, exist_ok=True)
        run_file = run_dir / f"{run.run_id}.json"
        with open(run_file, "w") as f:
            f.write(payload)

        self.current_run = None

    def _get_git_commit(self) -> Optional[str]:
        """Return the current ``git rev-parse HEAD`` commit, or None if it cannot be determined."""
        import subprocess
        try:
            result = subprocess.run(
                ["git", "rev-parse", "HEAD"],
                capture_output=True, text=True, timeout=5
            )
        except (OSError, subprocess.SubprocessError):
            # git missing / not executable, or the call timed out.
            return None
        commit = result.stdout.strip()
        # Outside a repository git exits non-zero with empty stdout; report
        # "unknown" as None rather than an empty string.
        if result.returncode != 0 or not commit:
            return None
        return commit

    def get_experiment_runs(self, experiment_name: str) -> List[ExperimentRun]:
        """Load all saved runs of an experiment, newest first (by timestamp string).

        Returns an empty list when the experiment has no directory yet.
        """
        experiment_dir = self.storage_path / experiment_name
        if not experiment_dir.exists():
            return []

        runs = []
        for run_file in experiment_dir.glob("*.json"):
            with open(run_file) as f:
                runs.append(ExperimentRun(**json.load(f)))
        return sorted(runs, key=lambda r: r.timestamp, reverse=True)

Follow-Up Questions

  1. How do you handle experiment tracking when training distributed across multiple nodes?
  2. What are the trade-offs between MLflow and W&B for enterprise use cases?
  3. How would you implement experiment tracking for AutoML systems?
  4. What governance requirements should experiment tracking satisfy?

Advertisement