🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
πŸ’Ό Servicesℹ️ Aboutβœ‰οΈ ContactView Pricing Plansfrom $10

Automated Retraining Pipelines

MLOps · Retraining · ⭐ Premium

Advertisement

Automated Retraining Pipelines

Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe

Retraining Triggers

Automated retraining ensures models stay current with changing data distributions and business requirements.

ℹ️

Netflix reportedly retrains about 80% of its models automatically, in response to performance-degradation signals.

Retraining Orchestrator

# retraining_orchestrator.py
import time
import json
from typing import Dict, List, Optional, Callable
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
import schedule

class RetrainingTrigger(Enum):
    """Reason a retraining job was queued.

    The string ``value`` of each member is what the orchestrator records in
    its retraining history and uses as the per-trigger key in its stats.
    Callers choose the trigger; nothing in this module dispatches on it.
    """
    SCHEDULE = "schedule"
    DRIFT = "drift"
    PERFORMANCE = "performance"
    DATA_UPDATE = "data_update"
    MANUAL = "manual"

@dataclass
class RetrainingJob:
    """A single retraining request, created by ``RetrainingOrchestrator.trigger_retraining``."""
    job_id: str  # built by trigger_retraining from the model name and a Unix timestamp
    model_name: str  # key into the orchestrator's registered configs
    trigger: RetrainingTrigger  # why the job was queued
    config: Dict  # snapshot of training settings taken from the model's RetrainingConfig
    created_at: str  # ISO-8601 string from datetime.now().isoformat() (naive local time)
    status: str  # free-form lifecycle string set by the orchestrator, e.g. "queued" / "running" / "completed"
    priority: int = 1  # sort key for process_queue, which runs lower values first

@dataclass
class RetrainingConfig:
    """Per-model retraining policy registered with ``RetrainingOrchestrator.register_model``."""
    model_name: str  # registry key; must match the name passed to trigger/check calls
    training_function: str  # name of the training callable; only copied into job configs here
    data_source: str  # location of training data, e.g. "s3://data-lake/churn"
    hyperparameters: Dict  # copied into each job's config
    min_performance_threshold: float  # retraining is needed when metrics["accuracy"] falls below this
    max_retraining_interval_hours: int  # hours; meant as the longest gap between retrains — confirm the orchestrator enforces it
    drift_threshold: float = 0.3  # retraining is needed when metrics["drift_score"] exceeds this
    auto_deploy: bool = False  # copied into each job's config
    notification_email: Optional[str] = None  # NOTE(review): not read anywhere in this snippet

class RetrainingOrchestrator:
    """Queue retraining jobs for registered models and run them in priority order.

    All state lives in memory on the instance:

    * ``configs`` -- registered :class:`RetrainingConfig` objects keyed by model name.
    * ``jobs`` -- jobs still waiting to run (status ``"queued"``).
    * ``completed_jobs`` -- jobs that finished successfully.
    * ``retraining_history`` -- one dict per executed job (job id, model name,
      trigger value, ISO timestamp, status).
    """

    def __init__(self):
        self.jobs: List[RetrainingJob] = []
        self.completed_jobs: List[RetrainingJob] = []
        self.configs: Dict[str, RetrainingConfig] = {}
        self.retraining_history: List[Dict] = []

    def register_model(self, config: RetrainingConfig):
        """Register *config*, replacing any earlier config for the same model name."""
        self.configs[config.model_name] = config

    def _get_config(self, model_name: str) -> RetrainingConfig:
        """Return the config for *model_name*, raising a KeyError that names the model."""
        try:
            return self.configs[model_name]
        except KeyError:
            raise KeyError(f"model {model_name!r} is not registered") from None

    def trigger_retraining(self, model_name: str, trigger: RetrainingTrigger, priority: int = 1) -> str:
        """Queue a retraining job for a registered model.

        Args:
            model_name: Name the model was registered under.
            trigger: Why the job is being queued.
            priority: Sort key for :meth:`process_queue`; lower values run first.

        Returns:
            The new job id, ``"retrain-<model_name>-<unix seconds>"``. A
            ``-<n>`` suffix is added if that id is already in use.

        Raises:
            KeyError: If *model_name* has not been registered.
        """
        config = self._get_config(model_name)

        base_id = f"retrain-{model_name}-{int(time.time())}"
        # Several triggers for the same model can fire within one second;
        # disambiguate so job ids stay unique across queued and finished jobs.
        taken = {j.job_id for j in self.jobs}
        taken.update(j.job_id for j in self.completed_jobs)
        job_id, suffix = base_id, 1
        while job_id in taken:
            job_id = f"{base_id}-{suffix}"
            suffix += 1

        job = RetrainingJob(
            job_id=job_id,
            model_name=model_name,
            trigger=trigger,
            config={
                "training_function": config.training_function,
                "data_source": config.data_source,
                # Copy so edits to a job's hyperparameters never leak back
                # into the registered config (or into other jobs).
                "hyperparameters": dict(config.hyperparameters),
                "auto_deploy": config.auto_deploy,
            },
            created_at=datetime.now().isoformat(),
            status="queued",
            priority=priority,
        )

        self.jobs.append(job)
        return job_id

    def check_retraining_needed(self, model_name: str, metrics: Dict) -> bool:
        """Return True if *model_name* should be retrained.

        Retraining is needed when any of the following holds:

        * ``metrics["accuracy"]`` is below ``min_performance_threshold``. A
          missing accuracy counts as 0, so it also triggers retraining.
        * ``metrics["drift_score"]`` is above ``drift_threshold``. A missing
          drift score counts as 0.
        * The model's most recent entry in ``retraining_history`` is older
          than ``max_retraining_interval_hours``. Models with no history are
          not considered stale.

        Raises:
            KeyError: If *model_name* has not been registered.
        """
        config = self._get_config(model_name)
        if metrics.get("accuracy", 0) < config.min_performance_threshold:
            return True
        if metrics.get("drift_score", 0) > config.drift_threshold:
            return True
        last_run = self._last_retrained_at(model_name)
        max_age = timedelta(hours=config.max_retraining_interval_hours)
        return last_run is not None and datetime.now() - last_run > max_age

    def _last_retrained_at(self, model_name: str) -> Optional[datetime]:
        """Return the newest history timestamp for *model_name*, or None if it never ran."""
        stamps = [
            datetime.fromisoformat(entry["timestamp"])
            for entry in self.retraining_history
            if entry["model_name"] == model_name
        ]
        return max(stamps, default=None)

    def process_queue(self, max_jobs: int = 5) -> List[RetrainingJob]:
        """Run up to *max_jobs* queued jobs, lowest ``priority`` value first.

        Jobs that run leave ``self.jobs``: successful ones move to
        ``completed_jobs``, and a job whose execution raises is marked
        ``"failed"`` before the exception propagates. Jobs not reached stay
        queued for the next call.

        Returns:
            The jobs completed during this call.
        """
        # list.sort is stable, so equal-priority jobs keep FIFO order.
        self.jobs.sort(key=lambda j: j.priority)
        batch = self.jobs[:max(0, max_jobs)]
        processed: List[RetrainingJob] = []

        try:
            for job in batch:
                job.status = "running"
                try:
                    self._execute_retraining(job)
                except Exception:
                    job.status = "failed"
                    raise
                job.status = "completed"
                self.completed_jobs.append(job)
                processed.append(job)
        finally:
            # Keep only jobs still waiting. Completed and failed jobs must
            # leave the queue or they would be re-run on every call.
            self.jobs = [j for j in self.jobs if j.status == "queued"]
        return processed

    def _execute_retraining(self, job: RetrainingJob):
        """Placeholder executor: print the job details and record a history entry.

        No training is performed here. The history entry is always recorded
        with status "completed".
        """
        print(f"Executing retraining job: {job.job_id}")
        print(f"Model: {job.model_name}")
        print(f"Trigger: {job.trigger.value}")

        self.retraining_history.append({
            "job_id": job.job_id,
            "model_name": job.model_name,
            "trigger": job.trigger.value,
            "timestamp": datetime.now().isoformat(),
            "status": "completed"
        })

    def get_retraining_stats(self) -> Dict:
        """Summarize completed jobs.

        Returns the total count, a count per trigger value (every trigger is
        present, possibly 0), and the mean priority (0 when nothing has
        completed).
        """
        return {
            "total_jobs": len(self.completed_jobs),
            "triggers": {
                trigger.value: len([j for j in self.completed_jobs if j.trigger == trigger])
                for trigger in RetrainingTrigger
            },
            "avg_priority": sum(j.priority for j in self.completed_jobs) / max(1, len(self.completed_jobs))
        }


# Usage: register one model, queue a drift-triggered job, then drain the queue.
orchestrator = RetrainingOrchestrator()

config = RetrainingConfig(
    model_name="churn-predictor",
    data_source="s3://data-lake/churn",
    training_function="train_churn_model",
    min_performance_threshold=0.85,
    max_retraining_interval_hours=24,
    hyperparameters={"n_estimators": 100, "max_depth": 10},
    auto_deploy=True,
)
orchestrator.register_model(config)

job_id = orchestrator.trigger_retraining(
    model_name="churn-predictor",
    trigger=RetrainingTrigger.DRIFT,
)
print("Job ID: " + job_id)

orchestrator.process_queue()

Retraining Pipeline

# retraining_pipeline.py
import json
from datetime import datetime
from typing import Dict, Optional, Tuple

import mlflow
import mlflow.sklearn
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

class RetrainingPipeline:
    """Retrain a RandomForest model on fresh data and compare it with a previous model.

    Each :meth:`execute` call logs params, metrics and the fitted model to
    MLflow under the experiment ``"<model_name>_retraining"``.
    """

    def __init__(self, model_name: str, mlflow_tracking_uri: str = "http://localhost:5000"):
        self.model_name = model_name
        # Note: this sets the tracking URI for the whole process, not just this instance.
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    def execute(self, data_path: str, hyperparameters: Dict, previous_model_path: Optional[str] = None) -> Dict:
        """Retrain on *data_path* and optionally compare against a previous model.

        Reads a parquet file whose ``"target"`` column is the label. Holds out
        20% (``random_state=42``) for evaluation, fits a
        ``RandomForestClassifier(**hyperparameters)`` and logs everything to
        MLflow.

        Args:
            data_path: Parquet path/URI readable by ``pd.read_parquet``.
            hyperparameters: Keyword arguments for ``RandomForestClassifier``.
            previous_model_path: Optional path to a pickled model. It is
                evaluated on the same held-out split for comparison.

        Returns:
            Dict with ``model_version`` (MLflow run id), ``metrics`` (accuracy,
            weighted F1), ``previous_accuracy``, ``improvement`` (accuracy delta,
            or None without a previous model) and ``deployed``. ``deployed`` is
            True when the new model beats the previous one; nothing is actually
            deployed here.
        """
        df = pd.read_parquet(data_path)
        X = df.drop(columns=["target"])
        y = df["target"]

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        previous_accuracy = None
        if previous_model_path:
            previous_accuracy = self._evaluate_previous_model(previous_model_path, X_test, y_test)

        mlflow.set_experiment(f"{self.model_name}_retraining")
        with mlflow.start_run(run_name=f"retrain_{datetime.now().strftime('%Y%m%d_%H%M%S')}") as run:
            mlflow.log_params(hyperparameters)
            mlflow.log_param("data_path", data_path)
            mlflow.log_param("train_samples", len(X_train))
            mlflow.log_param("test_samples", len(X_test))

            model = RandomForestClassifier(**hyperparameters)
            model.fit(X_train, y_train)

            y_pred = model.predict(X_test)
            metrics = {
                "accuracy": accuracy_score(y_test, y_pred),
                "f1_score": f1_score(y_test, y_pred, average="weighted"),
            }

            mlflow.log_metrics(metrics)
            mlflow.sklearn.log_model(model, "model")
            # Capture the id while the run is open. After the `with` block,
            # mlflow.active_run() returns None.
            run_id = run.info.run_id

        return self._build_result(run_id, metrics, previous_accuracy)

    @staticmethod
    def _build_result(run_id: str, metrics: Dict, previous_accuracy: Optional[float]) -> Dict:
        """Assemble the result dict returned by :meth:`execute`.

        ``improvement`` is the accuracy delta versus the previous model, or
        None when no previous model was evaluated. ``deployed`` is True only
        for a strictly positive improvement.
        """
        improvement = None
        if previous_accuracy is not None:
            improvement = metrics["accuracy"] - previous_accuracy

        return {
            "model_version": run_id,
            "metrics": metrics,
            "previous_accuracy": previous_accuracy,
            "improvement": improvement,
            "deployed": improvement is not None and improvement > 0,
        }

    def _evaluate_previous_model(self, model_path: str, X_test: pd.DataFrame, y_test: pd.Series) -> float:
        """Return the accuracy of the pickled model at *model_path* on the held-out split.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only point this at model artifacts from a trusted location.
        """
        import pickle
        with open(model_path, "rb") as f:
            previous_model = pickle.load(f)
        y_pred = previous_model.predict(X_test)
        return accuracy_score(y_test, y_pred)


# Usage
pipeline = RetrainingPipeline("churn-predictor")
result = pipeline.execute(
    data_path="s3://data-lake/churn/latest.parquet",
    hyperparameters={"n_estimators": 100, "max_depth": 10},
    previous_model_path="models/churn-predictor-current.pkl"
)
# "improvement" is None when no previous model was evaluated. Formatting None
# with ":.4f" would raise TypeError, so report that case explicitly.
if result["improvement"] is None:
    print("Improvement: n/a (no previous model to compare against)")
else:
    print(f"Improvement: {result['improvement']:.4f}")

Follow-Up Questions

  1. How do you implement safe retraining with automatic rollback?
  2. What data freshness requirements apply to different retraining frequencies?
  3. How would you handle catastrophic forgetting during retraining?
  4. What monitoring is needed for retraining pipeline health?

Advertisement