Automated Retraining Pipelines
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Retraining Triggers
Automated retraining ensures models stay current with changing data distributions and business requirements.
ℹ️ Note
Netflix retrains 80% of its models automatically based on performance-degradation signals.
Retraining Orchestrator
# retraining_orchestrator.py
import time
import json
from typing import Dict, List, Optional, Callable
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
import schedule
class RetrainingTrigger(Enum):
    """Why a retraining job was requested.

    Every member carries a lowercase string value; ``.value`` is what is
    written into retraining history records and used as the key in the
    per-trigger stats.
    """

    SCHEDULE = "schedule"
    DRIFT = "drift"
    PERFORMANCE = "performance"
    DATA_UPDATE = "data_update"
    MANUAL = "manual"
@dataclass
class RetrainingJob:
    """A single retraining request tracked by the orchestrator.

    Jobs are created with status ``"queued"`` and their ``status`` is updated
    in place while the queue is processed.
    """

    job_id: str                 # identifier handed back to the caller
    model_name: str             # key of the model's registered RetrainingConfig
    trigger: RetrainingTrigger  # reason the job was requested
    config: Dict                # training settings taken from the model's RetrainingConfig
    created_at: str             # creation time as an ISO-8601 string
    status: str                 # lifecycle state, e.g. "queued" / "running" / "completed"
    priority: int = 1           # queue sort key; lower values are processed first
@dataclass
class RetrainingConfig:
    """Per-model retraining settings registered with the orchestrator."""

    model_name: str                           # registry key for this model
    training_function: str                    # training entry-point name, copied onto each job's config
    data_source: str                          # training-data location, e.g. an s3:// URI
    hyperparameters: Dict                     # training hyperparameters, copied onto each job's config
    min_performance_threshold: float          # accuracy below this requests a retrain
    max_retraining_interval_hours: int        # NOTE(review): max hours between retrains — confirm where enforced
    drift_threshold: float = 0.3              # drift_score above this requests a retrain
    auto_deploy: bool = False                 # copied onto each job's config
    notification_email: Optional[str] = None  # optional contact address; not read by the visible code
class RetrainingOrchestrator:
    """In-memory queue that decides when models need retraining and runs the jobs.

    Typical flow: ``register_model`` a ``RetrainingConfig``, optionally ask
    ``check_retraining_needed``, enqueue work with ``trigger_retraining``, then
    drain the queue with ``process_queue``.
    """

    def __init__(self):
        self.jobs: List[RetrainingJob] = []             # pending jobs, not yet processed
        self.completed_jobs: List[RetrainingJob] = []   # jobs that finished without raising
        self.configs: Dict[str, RetrainingConfig] = {}  # model_name -> registered config
        self.retraining_history: List[Dict] = []        # records appended by _execute_retraining
        # Every job ID handed out so far. trigger_retraining uses it to keep IDs
        # unique when the same model is triggered twice within one second.
        self._issued_job_ids: set = set()

    def register_model(self, config: RetrainingConfig):
        """Register ``config`` under ``config.model_name``, replacing any previous one."""
        self.configs[config.model_name] = config

    def _get_config(self, model_name: str) -> RetrainingConfig:
        """Return the registered config for ``model_name`` or raise a descriptive KeyError."""
        try:
            return self.configs[model_name]
        except KeyError:
            raise KeyError(
                f"model {model_name!r} is not registered; call register_model() first"
            ) from None

    def trigger_retraining(self, model_name: str, trigger: RetrainingTrigger, priority: int = 1) -> str:
        """Queue a retraining job for ``model_name`` and return its job ID.

        Args:
            model_name: A model previously passed to ``register_model``.
            trigger: Why the retrain was requested; stored on the job.
            priority: Sort key for ``process_queue``; lower values run first.

        Returns:
            ``"retrain-<model_name>-<unix seconds>"``. A ``-<n>`` suffix is added
            only if that ID was already issued by this orchestrator.

        Raises:
            KeyError: If ``model_name`` has not been registered.
        """
        config = self._get_config(model_name)

        base_id = f"retrain-{model_name}-{int(time.time())}"
        job_id, suffix = base_id, 1
        while job_id in self._issued_job_ids:
            job_id = f"{base_id}-{suffix}"
            suffix += 1
        self._issued_job_ids.add(job_id)

        job = RetrainingJob(
            job_id=job_id,
            model_name=model_name,
            trigger=trigger,
            config={
                "training_function": config.training_function,
                "data_source": config.data_source,
                # Copy so later edits to the registered config and to this job's
                # settings cannot leak into each other.
                "hyperparameters": dict(config.hyperparameters),
                "auto_deploy": config.auto_deploy,
            },
            created_at=datetime.now().isoformat(),
            status="queued",
            priority=priority,
        )
        self.jobs.append(job)
        return job_id

    def check_retraining_needed(
        self,
        model_name: str,
        metrics: Dict,
        last_trained_at: Optional[datetime] = None,
    ) -> bool:
        """Return True if ``metrics`` (or model age) says ``model_name`` should be retrained.

        Args:
            model_name: A registered model.
            metrics: May contain ``"accuracy"`` and/or ``"drift_score"``.
            last_trained_at: Optional time of the last successful training. When
                given, a model older than ``max_retraining_interval_hours`` is
                also due for retraining. Naive and aware datetimes are both
                accepted; "now" is taken in the same timezone.

        Raises:
            KeyError: If ``model_name`` has not been registered.
        """
        config = self._get_config(model_name)
        # A missing "accuracy" counts as 0, i.e. below any positive threshold, so
        # metrics without accuracy always request retraining. A missing
        # "drift_score" counts as 0 and never triggers on its own.
        if metrics.get("accuracy", 0) < config.min_performance_threshold:
            return True
        if metrics.get("drift_score", 0) > config.drift_threshold:
            return True
        if last_trained_at is not None:
            age = datetime.now(last_trained_at.tzinfo) - last_trained_at
            if age > timedelta(hours=config.max_retraining_interval_hours):
                return True
        return False

    def process_queue(self, max_jobs: int = 5) -> List[RetrainingJob]:
        """Run up to ``max_jobs`` queued jobs, lowest ``priority`` value first.

        Each job is removed from ``self.jobs`` before it runs, so a job is
        processed at most once. Successful jobs are marked ``"completed"`` and
        appended to ``completed_jobs``.

        Args:
            max_jobs: Maximum number of jobs to run in this call (default 5).

        Returns:
            The jobs completed during this call, in execution order.

        Raises:
            Exception: Anything raised by ``_execute_retraining`` is propagated.
                The failing job is marked ``"failed"`` and not re-queued. Jobs
                not yet started stay in the queue.
        """
        # list.sort is stable, so jobs with equal priority keep FIFO order.
        self.jobs.sort(key=lambda j: j.priority)
        processed: List[RetrainingJob] = []
        while self.jobs and len(processed) < max_jobs:
            job = self.jobs.pop(0)
            job.status = "running"
            try:
                self._execute_retraining(job)
            except Exception:
                job.status = "failed"
                raise
            job.status = "completed"
            self.completed_jobs.append(job)
            processed.append(job)
        return processed

    def _execute_retraining(self, job: RetrainingJob):
        """Placeholder executor: print the job and append a "completed" history record.

        It does not invoke ``job.config["training_function"]``. A real
        implementation would dispatch the training run here.
        """
        print(f"Executing retraining job: {job.job_id}")
        print(f"Model: {job.model_name}")
        print(f"Trigger: {job.trigger.value}")
        self.retraining_history.append({
            "job_id": job.job_id,
            "model_name": job.model_name,
            "trigger": job.trigger.value,
            "timestamp": datetime.now().isoformat(),
            "status": "completed",
        })

    def get_retraining_stats(self) -> Dict:
        """Summarize completed jobs: total count, count per trigger, mean priority.

        ``avg_priority`` is 0.0 when nothing has completed, because the divisor
        is clamped to 1.
        """
        completed = self.completed_jobs
        return {
            "total_jobs": len(completed),
            "triggers": {
                trigger.value: sum(1 for j in completed if j.trigger == trigger)
                for trigger in RetrainingTrigger
            },
            "avg_priority": sum(j.priority for j in completed) / max(1, len(completed)),
        }
# Usage: register the churn model, queue a drift-triggered retrain, then drain the queue.
orchestrator = RetrainingOrchestrator()
config = RetrainingConfig(
    model_name="churn-predictor",
    training_function="train_churn_model",
    data_source="s3://data-lake/churn",
    hyperparameters={"n_estimators": 100, "max_depth": 10},
    min_performance_threshold=0.85,
    max_retraining_interval_hours=24,
    auto_deploy=True,
)
orchestrator.register_model(config)

job_id = orchestrator.trigger_retraining(config.model_name, trigger=RetrainingTrigger.DRIFT)
print("Job ID:", job_id)

orchestrator.process_queue()
Retraining Pipeline
# retraining_pipeline.py
import json
from datetime import datetime
from typing import Dict, Optional, Tuple

import mlflow
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
class RetrainingPipeline:
    """Retrain a RandomForest classifier from a parquet dataset and track it in MLflow.

    Each ``execute`` call loads the data, trains a fresh
    ``RandomForestClassifier`` and logs params, metrics and the model to a new
    MLflow run. When given a pickled previous model, it also compares accuracy
    on the same held-out split.
    """

    def __init__(self, model_name: str, mlflow_tracking_uri: str = "http://localhost:5000"):
        self.model_name = model_name
        # NOTE(review): set_tracking_uri presumably configures MLflow process-wide,
        # not just for this instance — confirm before running several pipelines
        # against different servers in one process.
        mlflow.set_tracking_uri(mlflow_tracking_uri)

    def execute(
        self,
        data_path: str,
        hyperparameters: Dict,
        previous_model_path: Optional[str] = None,
        *,
        target_column: str = "target",
        test_size: float = 0.2,
        random_state: int = 42,
    ) -> Dict:
        """Train, evaluate and log a new model; compare it to the previous one if given.

        Args:
            data_path: Parquet location readable by ``pd.read_parquet``.
            hyperparameters: Keyword arguments for ``RandomForestClassifier``;
                also logged as MLflow params.
            previous_model_path: Optional path to a pickled model to compare against.
            target_column: Label column name (default ``"target"``).
            test_size: Held-out fraction passed to ``train_test_split``.
            random_state: Split seed, so runs are reproducible.

        Returns:
            Dict with ``model_version`` (MLflow run id), ``metrics``
            (accuracy, weighted F1), ``previous_accuracy`` and ``improvement``
            (both None without a previous model), and ``deployed``.
            ``deployed`` is only a recommendation (True when accuracy improved);
            nothing here actually deploys the model.
        """
        df = pd.read_parquet(data_path)
        X = df.drop(columns=[target_column])
        y = df[target_column]
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

        previous_accuracy = None
        if previous_model_path:
            # NOTE(review): if the previous model was trained on rows that land in
            # this X_test, its score is optimistic — confirm data lineage.
            previous_accuracy = self._evaluate_previous_model(previous_model_path, X_test, y_test)

        mlflow.set_experiment(f"{self.model_name}_retraining")
        run_name = f"retrain_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        with mlflow.start_run(run_name=run_name) as run:
            mlflow.log_params(hyperparameters)
            mlflow.log_param("data_path", data_path)
            mlflow.log_param("train_samples", len(X_train))
            mlflow.log_param("test_samples", len(X_test))

            model = RandomForestClassifier(**hyperparameters)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            metrics = {
                "accuracy": accuracy_score(y_test, y_pred),
                "f1_score": f1_score(y_test, y_pred, average="weighted"),
            }
            mlflow.log_metrics(metrics)
            mlflow.sklearn.log_model(model, "model")
            # Read the id from the run handle: mlflow.active_run() is None once
            # the `with` block has ended the run.
            run_id = run.info.run_id

        improvement = None
        if previous_accuracy is not None:
            improvement = metrics["accuracy"] - previous_accuracy

        return {
            "model_version": run_id,
            "metrics": metrics,
            "previous_accuracy": previous_accuracy,
            "improvement": improvement,
            "deployed": improvement is not None and improvement > 0,
        }

    def _evaluate_previous_model(self, model_path: str, X_test: pd.DataFrame, y_test: pd.Series) -> float:
        """Load the pickled model at ``model_path`` and return its accuracy on the test split."""
        import pickle

        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only load model artifacts from a trusted, access-controlled location.
        with open(model_path, "rb") as f:
            previous_model = pickle.load(f)
        return accuracy_score(y_test, previous_model.predict(X_test))
# Usage: retrain the churn model and report the accuracy change vs. the current model.
pipeline = RetrainingPipeline("churn-predictor")
result = pipeline.execute(
    data_path="s3://data-lake/churn/latest.parquet",
    hyperparameters={"n_estimators": 100, "max_depth": 10},
    previous_model_path="models/churn-predictor-current.pkl",
)
# result["improvement"] is None whenever no previous model was evaluated, and
# formatting None with ":.4f" raises TypeError, so report that case explicitly.
if result["improvement"] is None:
    print("Improvement: n/a (no previous model to compare against)")
else:
    print(f"Improvement: {result['improvement']:.4f}")
Follow-Up Questions
- How do you implement safe retraining with automatic rollback?
- What data freshness requirements apply to different retraining frequencies?
- How would you handle catastrophic forgetting during retraining?
- What monitoring is needed for retraining pipeline health?