Model Monitoring & Drift Detection
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Why Model Monitoring?
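Deployed models degrade over time because of data drift (the distribution of input features changes), concept drift (the relationship between inputs and the target changes), and changing business conditions.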
ℹ️ Note
Amazon's monitoring systems detect model degradation within 15 minutes, triggering automated rollback.
Drift Detection
# drift_detection.py
import numpy as np
from scipy import stats
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
import warnings
# NOTE(review): this installs an "ignore" filter for every warning category
# as soon as the module is imported, so it also silences warnings raised by
# NumPy/SciPy (e.g. RuntimeWarnings) anywhere in a program that imports this
# file. Consider scoping suppression with warnings.catch_warnings() instead.
warnings.filterwarnings('ignore')
class DriftType(Enum):
    """Category of drift that a DriftAlert refers to."""

    DATA_DRIFT = "data_drift"  # input distribution differs from the reference
    CONCEPT_DRIFT = "concept_drift"  # prediction error behaviour changed
    PREDICTION_DRIFT = "prediction_drift"  # distribution of model outputs changed
    FEATURE_DRIFT = "feature_drift"  # drift attributed to an individual feature
class DriftSeverity(Enum):
    """Ordinal severity level attached to a DriftAlert (NONE < LOW < MEDIUM < HIGH)."""

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
@dataclass
class DriftAlert:
    """A single drift finding produced by a monitoring check."""

    drift_type: DriftType  # which kind of drift was detected
    feature_name: Optional[str]  # registry key / feature the check ran on, if any
    severity: DriftSeverity
    score: float  # numeric result of the drift test that fired
    threshold: float  # reference value the score is reported against; meaning depends on the test (see details)
    details: str  # human-readable description of the test result
    timestamp: str  # string timestamp; ModelMonitor.check_drift fills it with datetime.now().isoformat()
class DriftDetector:
    """Statistical tests that compare new data against a fixed reference sample.

    ``reference_data`` is indexed as ``(n_samples, n_features)``; a 1-D array
    is treated as a single feature column. Each ``detect_*`` method compares
    ``current_data`` (same feature layout) against the reference.
    """

    def __init__(self, reference_data: np.ndarray, significance_level: float = 0.05):
        """Store the reference sample and precompute its summary statistics.

        Args:
            reference_data: Baseline sample, ``(n_samples, n_features)`` or 1-D.
            significance_level: Per-feature p-value cutoff used by
                :meth:`detect_ks_drift`.
        """
        self.reference_data = reference_data
        self.significance_level = significance_level
        self.reference_stats = self._compute_stats(reference_data)

    def _compute_stats(self, data: np.ndarray) -> Dict:
        """Return column-wise mean, std, median and 25th/75th percentiles of ``data``."""
        return {
            "mean": np.mean(data, axis=0),
            "std": np.std(data, axis=0),
            "median": np.median(data, axis=0),
            "q25": np.percentile(data, 25, axis=0),
            "q75": np.percentile(data, 75, axis=0),
        }

    @staticmethod
    def _as_2d(data) -> np.ndarray:
        """Return ``data`` as an array with one column per feature (1-D -> one column)."""
        arr = np.asarray(data)
        return arr.reshape(-1, 1) if arr.ndim == 1 else arr

    def detect_ks_drift(self, current_data: np.ndarray) -> Tuple[bool, float, float]:
        """Run a two-sample Kolmogorov-Smirnov test on every feature.

        Returns:
            ``(drift_detected, min_p_value, statistic)``. Drift is flagged when
            any feature's p-value is below ``significance_level``.
            ``min_p_value`` is the smallest p-value over all features, and
            ``statistic`` is the KS statistic of that same feature. With zero
            features the result is ``(False, 1.0, 0.0)``.
        """
        reference = self._as_2d(self.reference_data)
        current = self._as_2d(current_data)
        statistics: List[float] = []
        p_values: List[float] = []
        for i in range(reference.shape[1]):
            statistic, p_value = stats.ks_2samp(reference[:, i], current[:, i])
            statistics.append(float(statistic))
            p_values.append(float(p_value))
        if not p_values:
            return False, 1.0, 0.0
        # Report the statistic of the most significant feature, not whichever
        # feature happened to be tested last.
        worst = int(np.argmin(p_values))
        min_p_value = p_values[worst]
        return min_p_value < self.significance_level, min_p_value, statistics[worst]

    def detect_psi(
        self,
        current_data: np.ndarray,
        n_bins: int = 10,
        threshold: float = 0.2,
        epsilon: float = 1e-6,
    ) -> Tuple[bool, float]:
        """Population Stability Index, averaged over features.

        Each feature uses ``n_bins`` equal-width bins spanning the reference
        range. Current values outside that range are clipped into the first or
        last bin, so mass that moves out of range still counts as drift. Per
        feature, PSI = sum((q - p) * ln(q / p)) over bin *proportions* ``p``
        (reference) and ``q`` (current). Each proportion is floored at
        ``epsilon`` to keep the log finite.

        Returns:
            ``(avg_psi > threshold, avg_psi)``; ``avg_psi`` is 0.0 when there
            are no features.
        """
        reference = self._as_2d(self.reference_data)
        current = self._as_2d(current_data)
        psi_values: List[float] = []
        for i in range(reference.shape[1]):
            ref_counts, bin_edges = np.histogram(reference[:, i], bins=n_bins)
            curr_col = np.clip(current[:, i], bin_edges[0], bin_edges[-1])
            curr_counts, _ = np.histogram(curr_col, bins=bin_edges)
            # Proportions (not densities): PSI must not depend on the bin width.
            ref_frac = np.clip(ref_counts / max(ref_counts.sum(), 1), epsilon, None)
            curr_frac = np.clip(curr_counts / max(curr_counts.sum(), 1), epsilon, None)
            psi_values.append(float(np.sum((curr_frac - ref_frac) * np.log(curr_frac / ref_frac))))
        avg_psi = float(np.mean(psi_values)) if psi_values else 0.0
        return avg_psi > threshold, avg_psi

    def detect_mmd(
        self,
        current_data: np.ndarray,
        gamma: float = 1.0,
        threshold: float = 0.1,
        max_samples: int = 500,
        random_state: Optional[int] = None,
    ) -> Tuple[bool, float]:
        """Maximum Mean Discrepancy with the RBF kernel ``exp(-gamma * ||x - y||^2)``.

        Up to ``max_samples`` rows are drawn *without replacement* from each
        sample; duplicated rows would inflate the within-sample kernel means.
        The biased (V-statistic) estimate of MMD^2 is computed and its square
        root is returned.

        Args:
            random_state: Seed for the subsampling; ``None`` uses NumPy's
                global RNG.

        Returns:
            ``(mmd > threshold, mmd)``.
        """
        from scipy.spatial.distance import cdist

        reference = self._as_2d(self.reference_data)
        current = self._as_2d(current_data)
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        n_ref = min(max_samples, len(reference))
        n_curr = min(max_samples, len(current))
        ref_sample = reference[rng.choice(len(reference), n_ref, replace=False)]
        curr_sample = current[rng.choice(len(current), n_curr, replace=False)]

        def rbf(a: np.ndarray, b: np.ndarray) -> np.ndarray:
            # Same kernel as sklearn.metrics.pairwise.rbf_kernel(a, b, gamma=gamma).
            return np.exp(-gamma * cdist(a, b, "sqeuclidean"))

        mmd2 = (
            np.mean(rbf(ref_sample, ref_sample))
            + np.mean(rbf(curr_sample, curr_sample))
            - 2 * np.mean(rbf(ref_sample, curr_sample))
        )
        mmd = float(np.sqrt(max(0.0, mmd2)))
        return mmd > threshold, mmd

    def detect_concept_drift(
        self,
        reference_predictions: np.ndarray,
        reference_labels: np.ndarray,
        current_predictions: np.ndarray,
        current_labels: np.ndarray,
        decision_threshold: float = 0.5,
        ratio_threshold: float = 1.5,
    ) -> Tuple[bool, float]:
        """Compare classification error rates between a reference and a current window.

        Scores are turned into class predictions with ``score > decision_threshold``
        and compared to the labels. This assumes binary 0/1 labels (TODO confirm
        with callers).

        Returns:
            ``(error_ratio > ratio_threshold, error_ratio)``, where
            ``error_ratio = current_error_rate / max(reference_error_rate, 1e-6)``.
        """
        ref_errors = (np.asarray(reference_predictions) > decision_threshold).astype(int) != np.asarray(reference_labels)
        curr_errors = (np.asarray(current_predictions) > decision_threshold).astype(int) != np.asarray(current_labels)
        ref_error_rate = float(np.mean(ref_errors))
        curr_error_rate = float(np.mean(curr_errors))
        # Floor the denominator so a perfect reference window does not divide by zero.
        error_ratio = curr_error_rate / max(ref_error_rate, 1e-6)
        return error_ratio > ratio_threshold, error_ratio
class ModelMonitor:
    """Runs registered drift detectors and keeps a log of alerts and predictions.

    Every alert returned by :meth:`check_drift` is also appended to
    ``self.alerts``. Every call to :meth:`record_prediction` appends to
    ``self.metrics_history``. Neither list is ever trimmed, so both grow for
    the lifetime of the monitor.
    """

    def __init__(self):
        self.drift_detectors: Dict[str, DriftDetector] = {}
        self.alerts: List[DriftAlert] = []
        self.metrics_history: List[Dict] = []

    def register_detector(self, name: str, detector: "DriftDetector"):
        """Register (or replace) the detector used for ``name``."""
        self.drift_detectors[name] = detector

    def check_drift(self, feature_name: str, current_data: np.ndarray) -> List["DriftAlert"]:
        """Run the KS and PSI tests of the detector registered as ``feature_name``.

        Returns:
            The alerts raised by this call. The list is empty when nothing
            drifted or when no detector is registered under ``feature_name``.
            The alerts are also appended to ``self.alerts``.
        """
        detector = self.drift_detectors.get(feature_name)
        if detector is None:
            return []

        alerts: List[DriftAlert] = []
        # One timestamp per check, so all alerts from this call share it.
        timestamp = datetime.now().isoformat()

        ks_drift, ks_p_value, _ks_stat = detector.detect_ks_drift(current_data)
        if ks_drift:
            alerts.append(DriftAlert(
                drift_type=DriftType.DATA_DRIFT,
                feature_name=feature_name,
                severity=DriftSeverity.HIGH if ks_p_value < 0.01 else DriftSeverity.MEDIUM,
                # The KS decision is "p-value < significance level". Report exactly
                # those two numbers, so for KS alerts the alert fired because
                # score < threshold.
                score=float(ks_p_value),
                threshold=detector.significance_level,
                details=f"KS test p-value: {ks_p_value:.4f}",
                timestamp=timestamp,
            ))

        psi_drift, psi_score = detector.detect_psi(current_data)
        if psi_drift:
            alerts.append(DriftAlert(
                drift_type=DriftType.DATA_DRIFT,
                feature_name=feature_name,
                severity=DriftSeverity.HIGH if psi_score > 0.5 else DriftSeverity.MEDIUM,
                score=psi_score,
                threshold=0.2,  # default PSI cutoff used by detect_psi
                details=f"PSI score: {psi_score:.4f}",
                timestamp=timestamp,
            ))

        self.alerts.extend(alerts)
        return alerts

    def record_prediction(self, prediction: float, actual: Optional[float] = None):
        """Append a prediction (and optional ground truth) with the current timestamp."""
        self.metrics_history.append({
            "prediction": prediction,
            "actual": actual,
            "timestamp": datetime.now().isoformat(),
        })

    def get_prediction_drift(self, window_size: int = 100, significance_level: float = 0.05) -> Dict:
        """KS-test the last ``window_size`` predictions against the window before them.

        Args:
            window_size: Size of each of the two windows; must be >= 1.
            significance_level: p-value cutoff for flagging drift.

        Returns:
            ``{"drift_detected": False}`` when fewer than ``2 * window_size``
            predictions have been recorded. Otherwise a dict with the keys
            ``drift_detected``, ``ks_statistic``, ``p_value``, ``recent_mean``
            and ``previous_mean``, as plain Python bool/float values.

        Raises:
            ValueError: if ``window_size`` < 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if len(self.metrics_history) < window_size * 2:
            return {"drift_detected": False}
        window = [m["prediction"] for m in self.metrics_history[-window_size * 2:]]
        previous, recent = window[:window_size], window[window_size:]
        stat, p_value = stats.ks_2samp(previous, recent)
        return {
            "drift_detected": bool(p_value < significance_level),
            "ks_statistic": float(stat),
            "p_value": float(p_value),
            "recent_mean": float(np.mean(recent)),
            "previous_mean": float(np.mean(previous)),
        }

    def get_alerts_summary(self) -> Dict:
        """Count all recorded alerts, in total and grouped by severity and by drift type."""
        by_severity: Dict[str, int] = {}
        by_type: Dict[str, int] = {}
        for alert in self.alerts:
            by_severity[alert.severity.value] = by_severity.get(alert.severity.value, 0) + 1
            by_type[alert.drift_type.value] = by_type.get(alert.drift_type.value, 0) + 1
        return {"total": len(self.alerts), "by_severity": by_severity, "by_type": by_type}
# Usage: runs only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    rng = np.random.default_rng(42)  # seeded so the demo output is reproducible
    reference_data = rng.standard_normal((1000, 5))
    # Shifted (+0.5) and widened (x1.2) version of the reference distribution.
    current_data = rng.standard_normal((1000, 5)) * 1.2 + 0.5
    detector = DriftDetector(reference_data)
    monitor = ModelMonitor()
    # The detector tests all 5 columns; "feature_0" is only the registry key.
    monitor.register_detector("feature_0", detector)
    alerts = monitor.check_drift("feature_0", current_data)
    for alert in alerts:
        print(f"Alert: {alert.drift_type.value} - {alert.severity.value} - {alert.details}")
Monitoring Dashboard
# monitoring_metrics.py
import time
from typing import Dict, List
from dataclasses import dataclass
from datetime import datetime
import json
@dataclass
class PredictionMetric:
    """One served prediction together with its serving metadata."""

    prediction: float  # model output value
    latency_ms: float  # serving latency, in milliseconds per the field name
    timestamp: datetime  # when the prediction was made; used for throughput
    model_version: str  # version tag; MetricsCollector keeps a counter per value
    features: Dict[str, float]  # input features by name (not read by MetricsCollector)
class MetricsCollector:
    """In-memory store of prediction metrics with summary and Prometheus export.

    Every recorded metric is kept in ``self.predictions`` (never trimmed).
    ``self.counters`` and ``self.histograms`` are updated incrementally on
    each record.
    """

    def __init__(self):
        self.predictions: List[PredictionMetric] = []
        self.counters: Dict[str, int] = {}
        self.histograms: Dict[str, List[float]] = {}

    def record_prediction(self, metric: "PredictionMetric"):
        """Store ``metric`` and update the counters and latency samples."""
        self.predictions.append(metric)
        self._update_counters(metric)
        self._update_histograms(metric)

    def _update_counters(self, metric: "PredictionMetric"):
        """Increment the total counter and the per-model-version counter."""
        for key in ("total_predictions", f"model_{metric.model_version}"):
            self.counters[key] = self.counters.get(key, 0) + 1

    def _update_histograms(self, metric: "PredictionMetric"):
        """Append the metric's latency to the raw ``latency`` sample list."""
        self.histograms.setdefault("latency", []).append(metric.latency_ms)

    @staticmethod
    def _nearest_rank(sorted_values: List[float], percentile: float) -> float:
        """Return the nearest-rank ``percentile`` (0-100] of a non-empty, sorted list."""
        import math

        # Nearest-rank definition: the value at rank ceil(p/100 * n), 1-based.
        rank = math.ceil(percentile * len(sorted_values) / 100)
        return sorted_values[min(max(rank, 1), len(sorted_values)) - 1]

    def get_metrics_summary(self) -> Dict:
        """Summarise latency and throughput over all recorded predictions.

        Returns an empty dict when nothing has been recorded. Throughput
        divides the prediction count by the wall-clock span between the
        earliest and latest timestamps. The span is floored at one second, so
        bursts shorter than a second do not blow up the rate.
        """
        if not self.predictions:
            return {}
        latencies = [p.latency_ms for p in self.predictions]
        ordered = sorted(latencies)
        timestamps = [p.timestamp for p in self.predictions]
        # total_seconds() (not .seconds) keeps days and fractional seconds.
        elapsed = (max(timestamps) - min(timestamps)).total_seconds()
        return {
            "total_predictions": len(self.predictions),
            "avg_latency": sum(latencies) / len(latencies),
            "p50_latency": self._nearest_rank(ordered, 50),
            "p99_latency": self._nearest_rank(ordered, 99),
            "throughput_rps": len(self.predictions) / max(1.0, elapsed),
        }

    def export_prometheus_metrics(self) -> str:
        """Render the metrics in the Prometheus text exposition format.

        Latency is exported as a ``summary`` (pre-computed quantiles plus
        ``_sum`` and ``_count``), which is the metric type Prometheus defines
        for client-side quantiles. The output ends with a newline, as the
        format requires.
        """
        lines = [
            "# HELP ml_predictions_total Total predictions",
            "# TYPE ml_predictions_total counter",
            f"ml_predictions_total {self.counters.get('total_predictions', 0)}",
        ]
        latencies = [p.latency_ms for p in self.predictions]
        if latencies:
            ordered = sorted(latencies)  # sort once for all quantiles
            lines.append("# HELP ml_prediction_latency_ms Prediction latency")
            lines.append("# TYPE ml_prediction_latency_ms summary")
            for percentile in (50, 90, 95, 99):
                quantile = f"{percentile / 100:g}"
                lines.append(
                    f'ml_prediction_latency_ms{{quantile="{quantile}"}} '
                    f"{self._nearest_rank(ordered, percentile)}"
                )
            lines.append(f"ml_prediction_latency_ms_sum {sum(latencies)}")
            lines.append(f"ml_prediction_latency_ms_count {len(latencies)}")
        return "\n".join(lines) + "\n"
Follow-Up Questions
- How do you distinguish between data drift and concept drift?
- What thresholds should trigger automated retraining?
- How would you implement monitoring for streaming ML models?
- What are the trade-offs between different drift detection algorithms?