Interview Question (Hard) — Asked at: Google, Netflix, Uber, Stripe, Square
"Design a data drift detection system that monitors input feature distributions, detects concept drift, and triggers model retraining. How do you balance sensitivity with false alarms?"
Data Drift Overview
Data drift occurs when the statistical properties of incoming data change over time, causing model performance degradation. Detecting drift early is critical for maintaining model accuracy in production.
Types of Drift
| Type | Description | Example |
|---|---|---|
| Data Drift | Change in input feature distribution | User demographics shift |
| Concept Drift | Change in relationship between features and target | Fraud patterns evolve |
| Label Drift | Change in target variable distribution | Class imbalance changes |
| Prediction Drift | Change in model output distribution | Prediction confidence drops |
Drift Detection Architecture
┌────────────────────────────────────────────────────────────────┐
│                  Data Drift Detection System                   │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐  │
│  │   Data   │───▶│  Drift   │───▶│  Alert   │───▶│ Retrain  │  │
│  │  Stream  │    │ Detector │    │ Manager  │    │ Trigger  │  │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘  │
│       │               │               │               │        │
│       ▼               ▼               ▼               ▼        │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐  │
│  │ Feature  │    │  Drift   │    │Dashboard │    │ Pipeline │  │
│  │  Store   │    │ History  │    │& Reports │    │ Orchestr.│  │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘  │
└────────────────────────────────────────────────────────────────┘
Statistical Drift Tests
Kolmogorov-Smirnov Test
import numpy as np
from scipy import stats
from typing import Tuple, Dict
import pandas as pd
class KolmogorovSmirnovDriftDetector:
    """Detect drift in one numeric feature with the two-sample KS test.

    The detector keeps a reference sample and compares every incoming
    sample against it using ``scipy.stats.ks_2samp``.
    """

    def __init__(self, reference_data: np.ndarray,
                 significance_level: float = 0.05):
        """
        Args:
            reference_data: Baseline/reference sample; must be non-empty.
            significance_level: p-value threshold for drift detection.

        Raises:
            ValueError: If ``reference_data`` is empty.
        """
        reference_data = np.asarray(reference_data)
        if reference_data.size == 0:
            raise ValueError("reference_data must not be empty")
        self.reference_data = reference_data
        self.significance_level = significance_level
        # Kept as a public attribute for callers that want the baseline ECDF.
        self.reference_cdf = self._compute_cdf(reference_data)

    def _compute_cdf(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(sorted values, empirical CDF evaluated at those values)``."""
        sorted_data = np.sort(data)
        n = len(sorted_data)
        cdf = np.arange(1, n + 1) / n
        return sorted_data, cdf

    def test(self, current_data: np.ndarray) -> Dict:
        """Run the KS test between the reference sample and ``current_data``.

        Args:
            current_data: Sample to compare against the reference.

        Returns:
            Dict with the KS statistic, p-value, the drift decision
            (``p_value < significance_level``), an approximate interval
            around the statistic and both sample sizes. All values are plain
            Python types, so the dict can be passed to ``json.dumps``.

        Raises:
            ValueError: If ``current_data`` is empty.
        """
        current_data = np.asarray(current_data)
        if current_data.size == 0:
            raise ValueError("current_data must not be empty")

        ks_statistic, p_value = stats.ks_2samp(
            self.reference_data,
            current_data
        )
        ks_statistic = float(ks_statistic)
        p_value = float(p_value)

        # bool() matters: the comparison would otherwise leak a numpy.bool_,
        # which json.dumps cannot serialize.
        drift_detected = bool(p_value < self.significance_level)

        n1, n2 = len(self.reference_data), len(current_data)
        # Heuristic interval: +/- 1.96 * sqrt((n1 + n2) / (n1 * n2)), clipped
        # to the valid [0, 1] range of the statistic.
        se = np.sqrt((n1 + n2) / (n1 * n2))
        ci_lower = max(0.0, ks_statistic - 1.96 * se)
        ci_upper = min(1.0, ks_statistic + 1.96 * se)

        return {
            'test': 'kolmogorov_smirnov',
            'ks_statistic': ks_statistic,
            'p_value': p_value,
            'drift_detected': drift_detected,
            'significance_level': self.significance_level,
            # The KS statistic itself is reported as the effect size.
            'effect_size': ks_statistic,
            'confidence_interval': [float(ci_lower), float(ci_upper)],
            'reference_size': n1,
            'current_size': n2
        }

    def visualize(self, current_data: np.ndarray):
        """Plot both empirical CDFs and mark the largest vertical gap.

        Args:
            current_data: Sample to compare against the reference.

        Returns:
            The matplotlib figure containing the plot.
        """
        import matplotlib.pyplot as plt

        ref_sorted, ref_cdf = self._compute_cdf(self.reference_data)
        cur_sorted, cur_cdf = self._compute_cdf(np.asarray(current_data))
        ks_stat = self.test(current_data)['ks_statistic']

        # Evaluate both ECDFs on the pooled sample points to locate the
        # position where they differ the most (that gap is the KS statistic).
        pooled = np.sort(np.concatenate([ref_sorted, cur_sorted]))
        ref_at = np.searchsorted(ref_sorted, pooled, side='right') / len(ref_sorted)
        cur_at = np.searchsorted(cur_sorted, pooled, side='right') / len(cur_sorted)
        gap_idx = int(np.argmax(np.abs(ref_at - cur_at)))

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(ref_sorted, ref_cdf, label='Reference', linewidth=2)
        ax.plot(cur_sorted, cur_cdf, label='Current', linewidth=2)
        ax.vlines(
            pooled[gap_idx],
            min(ref_at[gap_idx], cur_at[gap_idx]),
            max(ref_at[gap_idx], cur_at[gap_idx]),
            colors='red',
            linestyles='--',
            linewidth=2,
            label=f'KS Statistic: {ks_stat:.4f}'
        )
        ax.set_xlabel('Value')
        ax.set_ylabel('Cumulative Probability')
        ax.set_title('Kolmogorov-Smirnov Test')
        ax.legend()
        ax.grid(True, alpha=0.3)
        return fig
Population Stability Index (PSI)
import numpy as np
from typing import Dict, List
import pandas as pd
class PSIDriftDetector:
    """Population Stability Index (PSI) drift detector for one numeric feature.

    Bin edges are the reference sample's percentiles (with open-ended outer
    bins), so each bin holds roughly the same share of the reference data.
    """

    def __init__(self, reference_data: np.ndarray,
                 n_bins: int = 10,
                 psi_threshold: float = 0.2):
        """
        Args:
            reference_data: Baseline sample; must be non-empty.
            n_bins: Number of bins for discretization.
            psi_threshold: PSI value at or above which drift is reported.

        Raises:
            ValueError: If ``reference_data`` is empty.
        """
        if len(reference_data) == 0:
            raise ValueError("reference_data must not be empty")
        self.reference_data = reference_data
        self.n_bins = n_bins
        self.psi_threshold = psi_threshold

        # Percentile-based edges; the outer edges are opened up so values
        # outside the reference range still land in the first/last bin.
        self.bin_edges = np.percentile(
            reference_data,
            np.linspace(0, 100, n_bins + 1)
        )
        self.bin_edges[0] = -np.inf
        self.bin_edges[-1] = np.inf

        self.ref_proportions = self._compute_proportions(reference_data)

    def _compute_proportions(self, data: np.ndarray) -> np.ndarray:
        """Return the share of ``data`` in each bin, floored at 1e-6.

        Raises:
            ValueError: If ``data`` is empty (proportions would be undefined).
        """
        if len(data) == 0:
            raise ValueError("data must not be empty")
        counts = np.histogram(data, bins=self.bin_edges)[0]
        proportions = counts / len(data)
        # Floor keeps log() and the ratio finite for empty bins.
        return np.maximum(proportions, 1e-6)

    def test(self, current_data: np.ndarray) -> Dict:
        """Calculate PSI between the reference and ``current_data``.

        ``drift_level`` uses fixed bands (< 0.1 none, < 0.2 moderate,
        otherwise significant); ``drift_detected`` is decided by the
        configured ``psi_threshold``.

        Args:
            current_data: Sample to compare against the reference.

        Returns:
            Dict with total and per-bin PSI, the drift decision and level,
            both proportion vectors and both sample sizes.

        Raises:
            ValueError: If ``current_data`` is empty.
        """
        current_proportions = self._compute_proportions(current_data)

        psi_values = (current_proportions - self.ref_proportions) * \
            np.log(current_proportions / self.ref_proportions)
        psi_total = float(np.sum(psi_values))

        if psi_total < 0.1:
            drift_level = "no_drift"
        elif psi_total < 0.2:
            drift_level = "moderate_drift"
        else:
            drift_level = "significant_drift"

        # Honour the configured threshold (it was previously stored and
        # reported but never consulted).
        drift_detected = bool(psi_total >= self.psi_threshold)

        return {
            'test': 'population_stability_index',
            'psi_total': psi_total,
            'psi_per_bin': psi_values.tolist(),
            'drift_detected': drift_detected,
            'drift_level': drift_level,
            'psi_threshold': self.psi_threshold,
            'reference_proportions': self.ref_proportions.tolist(),
            'current_proportions': current_proportions.tolist(),
            'reference_size': len(self.reference_data),
            'current_size': len(current_data)
        }

    def visualize(self, current_data: np.ndarray):
        """Plot per-bin proportions and per-bin PSI contributions.

        Args:
            current_data: Sample to compare against the reference.

        Returns:
            The matplotlib figure with the two bar charts.
        """
        import matplotlib.pyplot as plt

        psi_result = self.test(current_data)
        current_proportions = psi_result['current_proportions']
        bins = range(self.n_bins)

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left panel: overlaid reference vs. current bin shares.
        axes[0].bar(bins, self.ref_proportions, alpha=0.5,
                    label='Reference', color='blue')
        axes[0].bar(bins, current_proportions, alpha=0.5,
                    label='Current', color='red')
        axes[0].set_xlabel('Bin')
        axes[0].set_ylabel('Proportion')
        axes[0].set_title('Distribution Comparison')
        axes[0].legend()

        # Right panel: how much each bin contributes to the total PSI.
        axes[1].bar(bins, psi_result['psi_per_bin'], color='green')
        axes[1].set_xlabel('Bin')
        axes[1].set_ylabel('PSI Contribution')
        axes[1].set_title(f'PSI per Bin (Total: {psi_result["psi_total"]:.4f})')

        plt.tight_layout()
        return fig
Wasserstein Distance (Earth Mover's Distance)
from scipy.stats import wasserstein_distance
import numpy as np
from typing import Dict
class WassersteinDriftDetector:
    """Drift detection using the 1-D Wasserstein (earth mover's) distance.

    The decision threshold is a percentile of the distances observed between
    pairs of bootstrap resamples of the reference data.
    """

    def __init__(self, reference_data: np.ndarray,
                 threshold_percentile: float = 95,
                 n_bootstrap: int = 1000,
                 random_state=None):
        """
        Args:
            reference_data: Baseline sample; must be non-empty.
            threshold_percentile: Percentile of the bootstrap distances used
                as the drift threshold.
            n_bootstrap: Number of bootstrap resample pairs.
            random_state: Optional integer seed for a reproducible threshold.
                When ``None`` the global ``np.random`` state is used.

        Raises:
            ValueError: If ``reference_data`` is empty.
        """
        if len(reference_data) == 0:
            raise ValueError("reference_data must not be empty")
        self.reference_data = reference_data
        self.threshold_percentile = threshold_percentile
        self.random_state = random_state
        self.threshold = self._calculate_threshold(n_bootstrap)

    def _calculate_threshold(self, n_bootstrap: int = 1000) -> float:
        """Return the configured percentile of bootstrap self-distances."""
        seed = getattr(self, 'random_state', None)
        # np.random (module) and RandomState both expose .choice, so the
        # unseeded path keeps using the global generator.
        rng = np.random if seed is None else np.random.RandomState(seed)
        size = len(self.reference_data)

        bootstrap_distances = []
        for _ in range(n_bootstrap):
            sample1 = rng.choice(self.reference_data, size=size, replace=True)
            sample2 = rng.choice(self.reference_data, size=size, replace=True)
            bootstrap_distances.append(wasserstein_distance(sample1, sample2))

        return float(np.percentile(bootstrap_distances, self.threshold_percentile))

    def test(self, current_data: np.ndarray) -> Dict:
        """Compare ``current_data`` with the reference sample.

        Args:
            current_data: Sample to compare against the reference.

        Returns:
            Dict with the raw distance, the distance divided by the reference
            standard deviation, the threshold, the drift decision
            (``distance > threshold``) and both sample sizes.
        """
        distance = float(wasserstein_distance(self.reference_data, current_data))
        drift_detected = bool(distance > self.threshold)

        reference_std = float(np.std(self.reference_data))
        if reference_std > 0:
            normalized_distance = distance / reference_std
        else:
            # Constant reference: avoid 0/0 (NaN). Any non-zero distance is
            # infinitely many standard deviations away.
            normalized_distance = 0.0 if distance == 0 else float('inf')

        return {
            'test': 'wasserstein_distance',
            'distance': distance,
            'normalized_distance': float(normalized_distance),
            'threshold': float(self.threshold),
            'drift_detected': drift_detected,
            'reference_size': len(self.reference_data),
            'current_size': len(current_data)
        }
Evidently AI Integration
Evidently Dashboard
from evidently.report import Report
from evidently.metric_preset import (
DataDriftPreset,
TargetDriftPreset,
DataQualityPreset
)
from evidently.test_suite import TestSuite
from evidently.tests import (
TestShareOfDriftedColumns,
TestColumnsType,
TestShareOfMissingValues
)
import pandas as pd
import json
from datetime import datetime
class EvidentlyDriftMonitor:
    """Drift reporting helper built around Evidently reports and test suites."""

    def __init__(self, reference_data: pd.DataFrame):
        """
        Args:
            reference_data: Baseline dataset every comparison is run against.
        """
        self.reference_data = reference_data

    def create_drift_report(self, current_data: pd.DataFrame,
                            html_path: str = "drift_report.html") -> dict:
        """Run the drift / target-drift / data-quality presets.

        Args:
            current_data: Dataset to compare with the reference.
            html_path: Where the HTML rendering of the report is written.

        Returns:
            The report content as returned by ``Report.as_dict()``.
        """
        report = Report(metrics=[
            DataDriftPreset(stattest='ks', stattest_threshold=0.05),
            TargetDriftPreset(),
            DataQualityPreset(),
        ])
        report.run(
            reference_data=self.reference_data,
            current_data=current_data
        )
        report.save_html(html_path)
        return report.as_dict()

    def create_drift_tests(self, current_data: pd.DataFrame) -> dict:
        """Run the drift test suite and return ``TestSuite.as_dict()``.

        The suite checks the share of drifted columns (< 0.3), column types
        and the share of missing values (< 0.1).
        """
        test_suite = TestSuite(tests=[
            TestShareOfDriftedColumns(lt=0.3),
            TestColumnsType(),
            TestShareOfMissingValues(lt=0.1),
        ])
        test_suite.run(
            reference_data=self.reference_data,
            current_data=current_data
        )
        return test_suite.as_dict()

    def calculate_feature_drift(self, current_data: pd.DataFrame) -> pd.DataFrame:
        """Run a two-sample KS test for every integer/float reference column.

        Args:
            current_data: Dataset to compare; must contain the same numeric
                columns as the reference.

        Returns:
            One row per numeric feature with the KS statistic, p-value, drift
            flag (``p_value < 0.05``), both means and the relative mean
            change in percent. ``mean_change`` is NaN when the reference
            mean is zero, because the relative change is undefined there.
        """
        from scipy.stats import ks_2samp

        drift_results = []
        for column in self.reference_data.columns:
            ref_col = self.reference_data[column]
            # Any integer/float width counts (int32, float32, nullable ...),
            # not only the default 64-bit dtypes.
            if not (pd.api.types.is_integer_dtype(ref_col)
                    or pd.api.types.is_float_dtype(ref_col)):
                continue
            cur_col = current_data[column]

            stat, p_value = ks_2samp(
                ref_col.dropna().to_numpy(dtype=float),
                cur_col.dropna().to_numpy(dtype=float)
            )

            reference_mean = float(ref_col.mean())
            current_mean = float(cur_col.mean())
            if reference_mean == 0:
                mean_change = float('nan')
            else:
                mean_change = (current_mean - reference_mean) / reference_mean * 100

            drift_results.append({
                'feature': column,
                'ks_statistic': float(stat),
                'p_value': float(p_value),
                'drift_detected': bool(p_value < 0.05),
                'reference_mean': reference_mean,
                'current_mean': current_mean,
                'mean_change': mean_change
            })
        return pd.DataFrame(drift_results)

    def generate_alert(self, drift_results: dict) -> dict:
        """Build an alert payload from a drift result dict.

        Args:
            drift_results: Mapping that may contain a ``'drifted_columns'``
                iterable of feature names.

        Returns:
            Alert dict whose severity is ``'high'`` for more than 5 drifted
            features, ``'medium'`` for more than 2 and ``'low'`` otherwise.
        """
        drifted_features = list(drift_results.get('drifted_columns') or [])

        if len(drifted_features) > 5:
            severity = "high"
        elif len(drifted_features) > 2:
            severity = "medium"
        else:
            severity = "low"

        return {
            'alert_type': 'data_drift',
            'severity': severity,
            'timestamp': datetime.now().isoformat(),
            'drifted_features': drifted_features,
            'total_drifted': len(drifted_features),
            'recommendation': self._get_recommendation(severity)
        }

    def _get_recommendation(self, severity: str) -> str:
        """Map a severity level to the recommended follow-up action."""
        if severity == 'high':
            return "Immediate retraining recommended. Multiple features show significant drift."
        elif severity == 'medium':
            return "Schedule retraining within 24 hours. Monitor affected features closely."
        else:
            return "Continue monitoring. Drift detected but within acceptable range."
Evidently with Prometheus Export
from prometheus_client import CollectorRegistry, Gauge, push_to_gateway
import time
class EvidentlyPrometheusExporter:
    """Publish per-feature drift results to a Prometheus Pushgateway."""

    def __init__(self, pushgateway_url: str):
        """Create a dedicated registry and the three gauges stored on it.

        Args:
            pushgateway_url: Address handed to ``push_to_gateway``.
        """
        self.pushgateway_url = pushgateway_url
        self.registry = CollectorRegistry()

        def build_gauge(metric_name, help_text, label):
            # Every gauge here has exactly one label and lives on self.registry.
            return Gauge(metric_name, help_text, [label], registry=self.registry)

        self.drift_score = build_gauge(
            'model_feature_drift_score', 'Feature drift score', 'feature_name'
        )
        self.data_quality_score = build_gauge(
            'model_data_quality_score', 'Data quality score', 'quality_type'
        )
        self.drift_detected = build_gauge(
            'model_drift_detected',
            'Whether drift is detected (1) or not (0)',
            'feature_name'
        )

    def export_metrics(self, drift_results: pd.DataFrame):
        """Set the gauges from each result row, then push the registry.

        Args:
            drift_results: Frame whose rows provide ``'feature'``,
                ``'ks_statistic'`` and ``'drift_detected'``.
        """
        for _index, record in drift_results.iterrows():
            feature = record['feature']
            self.drift_score.labels(feature_name=feature).set(record['ks_statistic'])
            flag = 1 if record['drift_detected'] else 0
            self.drift_detected.labels(feature_name=feature).set(flag)

        # One push per export, after all rows have been applied.
        push_to_gateway(
            self.pushgateway_url,
            job='ml_drift_monitor',
            registry=self.registry
        )
WhyLabs Integration
WhyLabs Profile Upload
import whylogs as why
from whylogs.api.writer.whylabs import WhyLabsWriter
from whylogs.core.schema import DatasetSchema
import pandas as pd
from datetime import datetime
class WhyLabsDriftMonitor:
    """Log whylogs profiles to WhyLabs and fetch drift analysis results."""

    def __init__(self, org_id: str, api_key: str, dataset_id: str):
        """Store the credentials and export them as ``WHYLABS_*`` env vars.

        Args:
            org_id: Exported as ``WHYLABS_ORG_ID``.
            api_key: Exported as ``WHYLABS_API_KEY``.
            dataset_id: Exported as ``WHYLABS_DEFAULT_DATASET_ID``.
        """
        self.org_id = org_id
        self.api_key = api_key
        self.dataset_id = dataset_id
        # Configure WhyLabs through process-wide environment variables.
        import os
        os.environ['WHYLABS_API_KEY'] = api_key
        os.environ['WHYLABS_ORG_ID'] = org_id
        os.environ['WHYLABS_DEFAULT_DATASET_ID'] = dataset_id

    def log_reference_profile(self, data: pd.DataFrame):
        """Profile ``data`` with whylogs and return the logging result."""
        # NOTE(review): confirm that the installed whylogs version accepts a
        # `writer=` keyword on why.log().
        result = why.log(
            pandas=data,
            schema=DatasetSchema(),
            writer=WhyLabsWriter()
        )
        print(f"Reference profile logged: {result}")
        return result

    def log_current_profile(self, data: pd.DataFrame,
                            timestamp: datetime = None):
        """Profile the current ``data`` and return the logging result.

        Args:
            data: Current batch to profile.
            timestamp: Defaults to ``datetime.now()``. It is only included in
                the printed message.
        """
        if timestamp is None:
            timestamp = datetime.now()
        # NOTE(review): `timestamp` is never attached to the profile itself;
        # confirm whether the dataset timestamp should be set on the result.
        result = why.log(
            pandas=data,
            schema=DatasetSchema(),
            writer=WhyLabsWriter()
        )
        print(f"Current profile logged at {timestamp}: {result}")
        return result

    def get_drift_results(self, timeout: float = 30.0) -> dict:
        """Retrieve drift analysis for the dataset over HTTP.

        Args:
            timeout: Seconds to wait for the server before giving up.
                Without a timeout ``requests`` may block indefinitely.

        Returns:
            The decoded JSON body of a 200 response.

        Raises:
            Exception: If the response status is not 200.
        """
        import requests

        # NOTE(review): verify the authentication scheme and endpoint path
        # against the WhyLabs API documentation.
        headers = {
            'Authorization': f'Basic {self.api_key}',
            'Content-Type': 'application/json'
        }
        url = (
            f"https://api.whylabsapp.com/v0/datasets/{self.dataset_id}/"
            f"analysis"
        )
        response = requests.get(url, headers=headers, timeout=timeout)
        if response.status_code == 200:
            return response.json()
        raise Exception(f"WhyLabs API error: {response.status_code}")

    def setup_monitor_suite(self):
        """Return the alert configuration for distribution and missing values.

        The configuration is only built and returned; nothing in this method
        sends it to WhyLabs.
        """
        monitoring_config = {
            'constraints': [
                {
                    'metric': 'distribution',
                    'condition': 'drift',
                    'threshold': 0.1,
                    'alert': True
                },
                {
                    'metric': 'missing_values',
                    'condition': 'greater_than',
                    'threshold': 0.05,
                    'alert': True
                }
            ]
        }
        return monitoring_config
ℹ️
WhyLabs provides end-to-end ML monitoring with automatic drift detection, data quality checks, and model performance tracking. It integrates with major data platforms and provides automated alerting.
Production Drift Monitoring System
Complete Drift Monitoring Pipeline
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from dataclasses import dataclass
import json
import logging
from kafka import KafkaConsumer, KafkaProducer
import redis
from prometheus_client import Counter, Histogram, Gauge
logger = logging.getLogger(__name__)
@dataclass
class DriftAlert:
    """Record describing one detected drift event for a single feature."""
    # Moment the alert was created.
    timestamp: datetime
    # Name of the feature whose distribution was checked.
    feature_name: str
    # Drift magnitude; ProductionDriftMonitor stores the KS statistic here.
    drift_score: float
    # p-value of the KS test that produced the alert.
    p_value: float
    # One of "critical", "high", "medium" or "low".
    severity: str
    # Human-readable summary of the detection.
    message: str
class ProductionDriftMonitor:
    """Per-feature drift checks with Prometheus metrics and Kafka alerts.

    Reference samples are kept in memory (with summary statistics cached in
    Redis). Each check runs a KS test plus PSI, updates metrics and emits a
    ``DriftAlert`` when the KS p-value is below ``ALERT_P_VALUE``.
    """

    # KS p-value below which a check is counted and alerted on.
    ALERT_P_VALUE = 0.05

    def __init__(self, config: dict):
        """
        Args:
            config: Must provide ``'redis_host'`` and ``'redis_port'``;
                ``'kafka_servers'`` is read when the first alert is sent.
        """
        self.config = config
        self.redis = redis.Redis(
            host=config['redis_host'],
            port=config['redis_port']
        )
        self.reference_profiles = {}
        self.drift_history = []
        # Created lazily on the first alert and then reused.
        self._producer = None

        # Prometheus metrics
        self.drift_counter = Counter(
            'drift_detection_total',
            'Total drift detections',
            ['feature', 'severity']
        )
        self.drift_score_gauge = Gauge(
            'drift_score',
            'Current drift score',
            ['feature']
        )
        self.latency_histogram = Histogram(
            'drift_check_duration_seconds',
            'Time spent on drift checks'
        )

    def load_reference_profile(self, feature_name: str,
                               reference_data: np.ndarray):
        """Register the reference sample for ``feature_name``.

        The raw data and its summary statistics are kept in memory; the
        statistics (not the raw data) are also cached in Redis for 30 days.
        """
        # Compute each statistic once and reuse it for both stores.
        mean = np.mean(reference_data)
        std = np.std(reference_data)
        percentiles = np.percentile(reference_data, [25, 50, 75])

        self.reference_profiles[feature_name] = {
            'data': reference_data,
            'mean': mean,
            'std': std,
            'percentiles': percentiles,
            'loaded_at': datetime.now()
        }
        self.redis.setex(
            f"reference:{feature_name}",
            timedelta(days=30),
            json.dumps({
                'mean': float(mean),
                'std': float(std),
                'percentiles': percentiles.tolist(),
                'n_samples': len(reference_data)
            })
        )

    def check_drift(self, feature_name: str,
                    current_data: np.ndarray) -> "Optional[DriftAlert]":
        """Check one feature for drift against its reference sample.

        Args:
            feature_name: Feature previously registered with
                ``load_reference_profile``.
            current_data: Current sample of that feature.

        Returns:
            The ``DriftAlert`` that was recorded and sent, or ``None`` when
            the KS p-value is not below ``ALERT_P_VALUE``.

        Raises:
            KeyError: If no reference profile is loaded for the feature.
        """
        try:
            reference_data = self.reference_profiles[feature_name]['data']
        except KeyError:
            raise KeyError(
                f"No reference profile loaded for feature '{feature_name}'"
            ) from None

        with self.latency_histogram.time():
            from scipy.stats import ks_2samp
            ks_stat, p_value = ks_2samp(reference_data, current_data)
            ks_stat, p_value = float(ks_stat), float(p_value)

            psi = self._calculate_psi(reference_data, current_data)
            severity = self._determine_severity(ks_stat, p_value, psi)

            # The gauge is refreshed on every check, drift or not.
            self.drift_score_gauge.labels(
                feature=feature_name
            ).set(ks_stat)

            if p_value >= self.ALERT_P_VALUE:
                return None

            self.drift_counter.labels(
                feature=feature_name,
                severity=severity
            ).inc()
            alert = DriftAlert(
                timestamp=datetime.now(),
                feature_name=feature_name,
                drift_score=ks_stat,
                p_value=p_value,
                severity=severity,
                message=self._generate_message(
                    feature_name, ks_stat, p_value, psi
                )
            )
            self.drift_history.append(alert)
            self._send_alert(alert)
            return alert

    def _calculate_psi(self, reference: np.ndarray,
                       current: np.ndarray, n_bins: int = 10) -> float:
        """Return the Population Stability Index of ``current`` vs ``reference``.

        Bins are the reference percentiles with open-ended outer bins;
        proportions are floored at 1e-6 so empty bins stay finite.
        """
        bin_edges = np.percentile(reference, np.linspace(0, 100, n_bins + 1))
        bin_edges[0] = -np.inf
        bin_edges[-1] = np.inf

        ref_props = np.histogram(reference, bins=bin_edges)[0] / len(reference)
        cur_props = np.histogram(current, bins=bin_edges)[0] / len(current)
        ref_props = np.maximum(ref_props, 1e-6)
        cur_props = np.maximum(cur_props, 1e-6)

        psi = np.sum((cur_props - ref_props) * np.log(cur_props / ref_props))
        return float(psi)

    def _determine_severity(self, ks_stat: float, p_value: float,
                            psi: float) -> str:
        """Map the p-value and PSI to a severity level.

        The first matching rule wins: critical (p < 0.001 or PSI > 0.5),
        high (p < 0.01 or PSI > 0.25), medium (p < 0.05 or PSI > 0.1),
        otherwise low. ``ks_stat`` is accepted but not used.
        """
        if p_value < 0.001 or psi > 0.5:
            return "critical"
        elif p_value < 0.01 or psi > 0.25:
            return "high"
        elif p_value < 0.05 or psi > 0.1:
            return "medium"
        else:
            return "low"

    def _generate_message(self, feature_name: str, ks_stat: float,
                          p_value: float, psi: float) -> str:
        """Generate the human-readable drift message."""
        return (
            f"Drift detected in feature '{feature_name}': "
            f"KS statistic={ks_stat:.4f}, p-value={p_value:.4f}, "
            f"PSI={psi:.4f}"
        )

    def _get_producer(self):
        """Return the shared Kafka producer, creating it on first use."""
        producer = getattr(self, '_producer', None)
        if producer is None:
            producer = KafkaProducer(
                bootstrap_servers=self.config['kafka_servers'],
                value_serializer=lambda v: json.dumps(v, default=str).encode('utf-8')
            )
            self._producer = producer
        return producer

    def _send_alert(self, alert: "DriftAlert"):
        """Log the alert and publish it to the ``drift-alerts`` Kafka topic."""
        logger.warning("DRIFT ALERT: %s", alert.message)

        # Reuse one producer: opening a new connection per alert and never
        # closing it leaks sockets and background threads.
        producer = self._get_producer()
        producer.send(
            'drift-alerts',
            value={
                'timestamp': alert.timestamp.isoformat(),
                'feature': alert.feature_name,
                'drift_score': alert.drift_score,
                'p_value': alert.p_value,
                'severity': alert.severity,
                'message': alert.message
            }
        )
        producer.flush()

    def close(self):
        """Flush and close the Kafka producer if one was created."""
        producer = getattr(self, '_producer', None)
        if producer is not None:
            producer.flush()
            producer.close()
            self._producer = None

    def generate_daily_report(self) -> dict:
        """Summarize the alerts recorded today (local date).

        Returns:
            Dict with the report timestamp, alert counts in total and per
            severity, the sorted affected feature names and recommendations.
        """
        # Single clock read so the filter and the report date agree.
        now = datetime.now()
        today = now.date()
        today_alerts = [
            a for a in self.drift_history
            if a.timestamp.date() == today
        ]

        def count(level):
            return sum(1 for a in today_alerts if a.severity == level)

        return {
            'date': now.isoformat(),
            'total_alerts': len(today_alerts),
            'critical_alerts': count('critical'),
            'high_alerts': count('high'),
            'medium_alerts': count('medium'),
            'low_alerts': count('low'),
            'features_affected': sorted({a.feature_name for a in today_alerts}),
            'recommendations': self._generate_recommendations(today_alerts)
        }

    def _generate_recommendations(self, alerts: "List[DriftAlert]") -> List[str]:
        """Derive follow-up recommendations from a list of alerts.

        Rules: any critical alert, more than 3 high alerts, and more than 10
        alerts overall each add one recommendation.
        """
        recommendations = []

        critical_count = sum(1 for a in alerts if a.severity == 'critical')
        if critical_count > 0:
            recommendations.append(
                "URGENT: Critical drift detected. Immediate model retraining recommended."
            )

        high_count = sum(1 for a in alerts if a.severity == 'high')
        if high_count > 3:
            recommendations.append(
                "Multiple high-severity drift events. Consider scheduled retraining."
            )

        if len(alerts) > 10:
            recommendations.append(
                "Elevated drift activity. Review data pipeline for potential issues."
            )
        return recommendations
⚠️
Drift detection thresholds should be tuned to your specific use case. Overly sensitive thresholds cause alert fatigue, while overly lenient thresholds miss real drift. Start with statistical significance (p < 0.05) and adjust based on observed false-positive rates.
Summary
Data drift detection is essential for maintaining model performance:
- Statistical Tests: KS test, PSI, Wasserstein distance
- Evidently AI: Comprehensive reporting and test suites
- WhyLabs: Cloud-native ML monitoring platform
- Production Monitoring: Real-time drift detection with alerting
Implement automated drift detection to maintain model accuracy in production.