🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans (from $10)

Model Registry & Versioning

MLOps · Model Management · ⭐ Premium

Advertisement

Model Registry & Versioning

Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe

Model Registry Fundamentals

A model registry provides centralized model management with versioning, lineage tracking, and lifecycle management.

ℹ️

Netflix's model registry tracks over 1000 production models with full lineage from data to deployment.

MLflow Model Registry

# mlflow_registry.py
from mlflow.tracking import MlflowClient
from mlflow.entities import ModelVersion, Run
from mlflow.store.model_registry import models_uri_to_latest
import mlflow.pytorch
import mlflow.sklearn
from typing import Optional, Dict, List
from dataclasses import dataclass
from enum import Enum

class ModelStage(Enum):
    """Registry stages understood by the MLflow model registry.

    Each member's value is the exact stage string passed to MLflow
    (``transition_model_version_stage`` / ``get_latest_versions``).
    """

    # Freshly registered, not yet assigned to any stage.
    NONE = "None"
    # Candidate under validation.
    STAGING = "Staging"
    # Serving traffic.
    PRODUCTION = "Production"
    # Retired.
    ARCHIVED = "Archived"

@dataclass
class ModelMetadata:
    """Flat, string-typed summary of one registered model version.

    NOTE(review): nothing in this snippet constructs or returns this class;
    it appears to be a convenience record for callers — confirm intended use.
    """

    name: str  # registered model name
    version: str  # registry version number, kept as a string
    stage: str  # stage string, presumably one of ModelStage's values
    description: str
    tags: Dict[str, str]  # tag key -> tag value
    run_id: str  # id of the run that produced the model artifact
    created_at: str  # timestamp; format not fixed here — TODO confirm
    updated_at: str  # timestamp; format not fixed here — TODO confirm

class ModelRegistryManager:
    """Thin wrapper around ``MlflowClient`` for the model-registry workflow.

    Every operation goes through ``self.client``, so registration, tagging,
    stage transitions and lookups all target the same tracking server that
    was given to the constructor.
    """

    def __init__(self, tracking_uri: str = "http://localhost:5000"):
        """Create a manager bound to the MLflow server at ``tracking_uri``."""
        self.client = MlflowClient(tracking_uri=tracking_uri)

    def register_model(self, run_id: str, model_name: str, description: str = "") -> ModelVersion:
        """Register the ``model`` artifact logged by ``run_id`` as a new version.

        The fluent ``mlflow.register_model`` is deliberately not used: it
        talks to the process-global tracking URI, which need not be the
        ``tracking_uri`` this manager's client was built with.

        Args:
            run_id: MLflow run whose ``model`` artifact directory is registered.
            model_name: registered-model name; created if it does not exist yet.
            description: description stored on the new model version.

        Returns:
            The newly created ``ModelVersion``.

        Raises:
            MlflowException: on any registry error other than the
                registered model already existing.
        """
        from mlflow.exceptions import MlflowException

        try:
            self.client.create_registered_model(model_name)
        except MlflowException as exc:
            # Adding a version under an existing name is the normal case.
            if exc.error_code not in ("RESOURCE_ALREADY_EXISTS", "ALREADY_EXISTS"):
                raise

        # Resolve ``runs:/<run_id>/model`` to the run's concrete artifact
        # location on the same server the client talks to.
        run = self.client.get_run(run_id)
        source = f"{run.info.artifact_uri.rstrip('/')}/model"

        return self.client.create_model_version(
            name=model_name,
            source=source,
            run_id=run_id,
            description=description,
        )

    def transition_model_stage(
        self,
        model_name: str,
        version: str,
        stage: ModelStage,
        archive_existing: bool = True
    ):
        """Move ``version`` of ``model_name`` into ``stage``.

        When ``archive_existing`` is true, MLflow archives the versions that
        currently occupy the target stage.
        """
        self.client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage=stage.value,
            archive_existing_versions=archive_existing
        )

    def add_model_tag(self, model_name: str, version: str, key: str, value: str):
        """Set the tag ``key=value`` on one model version."""
        self.client.set_model_version_tag(
            name=model_name,
            version=version,
            key=key,
            value=value
        )

    def get_model_version(self, model_name: str, version: str) -> ModelVersion:
        """Fetch a single model version from the registry."""
        return self.client.get_model_version(model_name, version)

    def get_latest_versions(self, model_name: str, stages: Optional[List[str]] = None) -> List[ModelVersion]:
        """Return the latest version per stage (all stages when ``stages`` is None).

        NOTE(review): newer MLflow releases deprecate stages in favour of
        aliases — confirm against the server version in use.
        """
        return self.client.get_latest_versions(model_name, stages)

    def get_model_version_by_stage(self, model_name: str, stage: ModelStage) -> Optional[ModelVersion]:
        """Return the latest version in ``stage``, or None if the stage is empty."""
        versions = self.get_latest_versions(model_name, [stage.value])
        return versions[0] if versions else None

    def compare_model_versions(self, model_name: str, version_a: str, version_b: str) -> Dict:
        """Side-by-side stage, metrics and params of two versions' source runs.

        Returns:
            ``{"version_a": {...}, "version_b": {...}}``. Each value has the
            keys ``stage``, ``metrics`` and ``params``.
        """
        mv_a = self.get_model_version(model_name, version_a)
        mv_b = self.get_model_version(model_name, version_b)

        def summarize(mv: ModelVersion) -> Dict:
            # Metrics/params live on the run that produced the version.
            run = self.client.get_run(mv.run_id)
            return {
                "stage": mv.current_stage,
                "metrics": run.data.metrics,
                "params": run.data.params,
            }

        return {
            "version_a": summarize(mv_a),
            "version_b": summarize(mv_b),
        }

    def delete_model_version(self, model_name: str, version: str):
        """Permanently delete one model version from the registry."""
        self.client.delete_model_version(model_name, version)


# Usage: register a training run's model, tag it, then move it to Staging.
registry = ModelRegistryManager()

run_id = "abc123def456"
model_version = registry.register_model(
    run_id=run_id,
    model_name="churn-predictor",
    description="Random Forest model for customer churn prediction v2.1",
)

for tag_key, tag_value in [("team", "ml-ops"), ("dataset_version", "2024-01-15")]:
    registry.add_model_tag("churn-predictor", model_version.version, tag_key, tag_value)

registry.transition_model_stage("churn-predictor", model_version.version, ModelStage.STAGING)

Custom Model Registry

# custom_registry.py
import json
import pickle
import hashlib
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum
import shutil

class ModelStage(Enum):
    """Lifecycle stage of a version in the file-based registry.

    The string values are what is written to ``index.json`` and
    ``metadata.json``.
    """

    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"
    ARCHIVED = "archived"


@dataclass
class ModelVersionInfo:
    """Metadata for one registered model version.

    In memory ``stage`` is a ``ModelStage`` member. In ``metadata.json`` it
    is stored as the member's string value.
    """

    version: str  # sequential integer rendered as a string ("1", "2", ...)
    model_name: str
    stage: ModelStage
    description: str
    tags: Dict[str, str]
    metrics: Dict[str, float]
    params: Dict[str, Any]
    run_id: str  # "" when no run id was supplied
    artifact_path: str  # path of the pickled model file
    created_at: str  # ISO-8601 from a naive (local-time) datetime
    updated_at: str
    created_by: str


class ModelRegistry:
    """File-system backed model registry.

    Layout under ``registry_path``::

        index.json                        summary: model -> version -> stage/hash/created_at
        <model_name>/v<N>/model.pkl       pickled model artifact
        <model_name>/v<N>/metadata.json   full ModelVersionInfo record

    Versions are assigned sequentially per model, starting at 1, and are
    keyed by their string form in the index.
    """

    def __init__(self, registry_path: str = "./model_registry"):
        """Open the registry at ``registry_path``, creating the directory if needed."""
        self.registry_path = Path(registry_path)
        self.registry_path.mkdir(parents=True, exist_ok=True)
        self._load_index()

    def _load_index(self):
        """Load ``index.json`` into ``self.index``, or start with an empty index."""
        index_file = self.registry_path / "index.json"
        if index_file.exists():
            with open(index_file, encoding="utf-8") as f:
                self.index = json.load(f)
        else:
            self.index = {"models": {}}

    def _save_index(self):
        """Persist ``self.index`` to ``index.json``."""
        self._write_json(self.registry_path / "index.json", self.index)

    @staticmethod
    def _write_json(path: Path, data: Any):
        """Write ``data`` as JSON via a temp file plus rename.

        A crash or serialization error therefore never leaves a truncated
        file in place of the previous contents.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        tmp_path.replace(path)

    def register_model(
        self,
        model: Any,
        model_name: str,
        description: str = "",
        metrics: Optional[Dict[str, float]] = None,
        params: Optional[Dict[str, Any]] = None,
        tags: Optional[Dict[str, str]] = None,
        run_id: Optional[str] = None,
        created_by: str = "system"
    ) -> ModelVersionInfo:
        """Pickle ``model`` as the next version of ``model_name`` (stage DEVELOPMENT).

        Returns:
            The ``ModelVersionInfo`` that was written to ``metadata.json``.
        """
        version = str(self._get_next_version(model_name))

        model_dir = self.registry_path / model_name / f"v{version}"
        model_dir.mkdir(parents=True, exist_ok=True)

        artifact_path = model_dir / "model.pkl"
        with open(artifact_path, "wb") as f:
            pickle.dump(model, f)
        model_hash = self._compute_hash(artifact_path)

        # One timestamp, so a fresh version has created_at == updated_at.
        now = datetime.now().isoformat()
        version_info = ModelVersionInfo(
            version=version,
            model_name=model_name,
            stage=ModelStage.DEVELOPMENT,
            description=description,
            tags=tags or {},
            metrics=metrics or {},
            params=params or {},
            run_id=run_id or "",
            artifact_path=str(artifact_path),
            created_at=now,
            updated_at=now,
            created_by=created_by,
        )

        # asdict() keeps the ModelStage member itself, and json cannot
        # encode Enum objects, so store the stage's string value.
        metadata = asdict(version_info)
        metadata["stage"] = version_info.stage.value
        self._write_json(model_dir / "metadata.json", metadata)

        # Record the version in the index only once its files are on disk.
        model_entry = self.index["models"].setdefault(model_name, {"versions": {}})
        model_entry["versions"][version] = {
            "stage": version_info.stage.value,
            "hash": model_hash,
            "created_at": version_info.created_at,
        }
        self._save_index()
        return version_info

    def promote_model(self, model_name: str, version: str, target_stage: ModelStage):
        """Set the stage of ``model_name`` version ``version`` to ``target_stage``.

        Both ``metadata.json`` and the index are updated. Other versions are
        left untouched, so several versions may share a stage. Stage lookups
        resolve to the highest-numbered one.

        Args:
            version: version number; an int is accepted and normalised to str.

        Raises:
            FileNotFoundError: no metadata file exists for this version.
            KeyError: the version is missing from the index.
        """
        version = str(version)
        metadata_file = self.registry_path / model_name / f"v{version}" / "metadata.json"

        with open(metadata_file, encoding="utf-8") as f:
            metadata = json.load(f)
        # Look the index entry up before writing anything, so an unknown
        # version cannot leave metadata.json and index.json disagreeing.
        index_entry = self.index["models"][model_name]["versions"][version]

        metadata["stage"] = target_stage.value
        metadata["updated_at"] = datetime.now().isoformat()
        self._write_json(metadata_file, metadata)

        index_entry["stage"] = target_stage.value
        self._save_index()

    def load_model(self, model_name: str, version: Optional[str] = None, stage: Optional[ModelStage] = None):
        """Unpickle a model by explicit ``version`` or by ``stage``.

        If ``stage`` is given it takes precedence over ``version``. The
        newest version in that stage is loaded.

        Security: ``pickle.load`` can execute arbitrary code. Only load
        artifacts from a registry directory you trust.

        Raises:
            ValueError: no version given, or none found in ``stage``.
        """
        if stage is not None:
            version = self._get_version_by_stage(model_name, stage)
        if version is None:
            raise ValueError(f"No model found for {model_name}")

        model_path = self.registry_path / model_name / f"v{version}" / "model.pkl"
        with open(model_path, "rb") as f:
            return pickle.load(f)

    def list_models(self) -> List[str]:
        """Names of all registered models, in registration order."""
        return list(self.index["models"].keys())

    def list_versions(self, model_name: str) -> List[Dict]:
        """Index entries of ``model_name`` with a ``"version"`` key added ([] if unknown)."""
        if model_name not in self.index["models"]:
            return []
        versions = self.index["models"][model_name]["versions"]
        return [{"version": v, **info} for v, info in versions.items()]

    def get_latest_production_model(self, model_name: str) -> Optional[Dict]:
        """Index entry of the highest-numbered PRODUCTION version, or None."""
        prod_versions = [
            v for v in self.list_versions(model_name)
            if v["stage"] == ModelStage.PRODUCTION.value
        ]
        return max(prod_versions, key=lambda v: int(v["version"]), default=None)

    def _get_next_version(self, model_name: str) -> int:
        """One past the highest existing version number (1 for a new model)."""
        versions = self.index["models"].get(model_name, {}).get("versions", {})
        return max((int(v) for v in versions), default=0) + 1

    def _get_version_by_stage(self, model_name: str, stage: ModelStage) -> Optional[str]:
        """Highest-numbered version of ``model_name`` in ``stage``, or None.

        This picks the newest version, matching ``get_latest_production_model``.
        """
        candidates = [
            v["version"] for v in self.list_versions(model_name)
            if v["stage"] == stage.value
        ]
        return max(candidates, key=int, default=None)

    def _compute_hash(self, file_path: Path) -> str:
        """First 16 hex chars of the file's SHA-256, read in 8 KiB chunks."""
        hasher = hashlib.sha256()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                hasher.update(chunk)
        return hasher.hexdigest()[:16]


# Usage
registry = ModelRegistry("./my_model_registry")

from sklearn.ensemble import RandomForestClassifier
n_estimators = 100
model = RandomForestClassifier(n_estimators=n_estimators)
model.fit([[1, 2], [3, 4]], [0, 1])

version_info = registry.register_model(
    model=model,
    model_name="fraud-detector",
    description="Random Forest fraud detection model",
    metrics={"accuracy": 0.95, "f1": 0.93},
    params={"n_estimators": n_estimators},
    tags={"team": "security", "priority": "high"}
)

# Promote the version that was just registered. A hard-coded "1" would
# promote a stale version whenever the registry directory already holds
# earlier registrations, e.g. on a second run of this script.
registry.promote_model("fraud-detector", version_info.version, ModelStage.PRODUCTION)
loaded_model = registry.load_model("fraud-detector", stage=ModelStage.PRODUCTION)

Follow-Up Questions

  1. How do you handle model rollback in production?
  2. What metadata should be tracked for regulatory compliance?
  3. How would you implement A/B testing between model versions?
  4. What are the implications of model registry design on team collaboration?

Advertisement