ML Platform on Cloud
Difficulty: Senior Level | Companies: AWS, Google, Microsoft, Netflix, Uber
ML Platform Architecture
An ML platform handles the full lifecycle: data preparation, training, evaluation, deployment, and monitoring.
A production ML platform requires infrastructure for experiment tracking, feature management, model serving, and monitoring for data/model drift.
ML Platform Components
+-----------------------------------------------------------+
|                 ML Platform Architecture                  |
|  +-----------------------------------------------------+  |
|  |                     Data Layer                      |  |
|  |  +-------------+  +-------------+  +-------------+  |  |
|  |  |   Feature   |  |  Training   |  |    Model    |  |  |
|  |  |    Store    |  |    Data     |  |  Registry   |  |  |
|  |  +-------------+  +-------------+  +-------------+  |  |
|  +-----------------------------------------------------+  |
|                                                           |
|  +-----------------------------------------------------+  |
|  |                    Compute Layer                    |  |
|  |  +-------------+  +-------------+  +-------------+  |  |
|  |  |  Training   |  |   Hyper-    |  | Distributed |  |  |
|  |  |   Cluster   |  |  parameter  |  |  Training   |  |  |
|  |  |    (GPU)    |  |   Tuning    |  |             |  |  |
|  |  +-------------+  +-------------+  +-------------+  |  |
|  +-----------------------------------------------------+  |
|                                                           |
|  +-----------------------------------------------------+  |
|  |                    Serving Layer                    |  |
|  |  +-------------+  +-------------+  +-------------+  |  |
|  |  |    Model    |  |     A/B     |  |    Model    |  |  |
|  |  |   Serving   |  |   Testing   |  |   Monitor   |  |  |
|  |  +-------------+  +-------------+  +-------------+  |  |
|  +-----------------------------------------------------+  |
+-----------------------------------------------------------+
Pattern 1: Feature Store
Centralized feature management for training and inference.
# Feature store implementation with Feast
# Note: written against the older Feast API (Feature, ValueType, entity names as
# strings). Newer releases define schemas with Field, so check your version.
from datetime import datetime, timedelta

import pandas as pd
from feast import BigQuerySource, Entity, Feature, FeatureStore, FeatureView, ValueType

# Define entities
customer = Entity(
    name="customer_id",
    value_type=ValueType.STRING,
    description="Customer identifier",
)

# Define features
customer_features = [
    Feature(name="total_orders", value_type=ValueType.INT32),
    Feature(name="total_spent", value_type=ValueType.FLOAT),
    Feature(name="avg_order_value", value_type=ValueType.FLOAT),
    Feature(name="days_since_last_order", value_type=ValueType.INT32),
    Feature(name="customer_segment", value_type=ValueType.STRING),
]

# Define feature view
customer_feature_view = FeatureView(
    name="customer_features",
    entities=["customer_id"],
    ttl=timedelta(days=1),
    features=customer_features,
    online=True,
    source=BigQuerySource(
        table="project.dataset.customer_features",
        event_timestamp_column="event_timestamp",
    ),
)

# Register the definitions (or run `feast apply` from the CLI),
# then materialize features into the online store
store = FeatureStore(repo_path=".")
store.apply([customer, customer_feature_view])
store.materialize(start_date=datetime(2024, 1, 1), end_date=datetime.now())

# Entity rows to join features onto (illustrative values): each row needs
# the entity key and the timestamp at which features should be looked up
entity_df = pd.DataFrame(
    {
        "customer_id": ["customer_123"],
        "event_timestamp": [datetime(2024, 6, 1)],
    }
)

# Get features for training
training_df = store.get_historical_features(
    entity_df=entity_df,
    features=[
        "customer_features:total_orders",
        "customer_features:total_spent",
        "customer_features:avg_order_value",
    ],
).to_df()

# Get features for real-time inference
feature_vector = store.get_online_features(
    features=[
        "customer_features:total_orders",
        "customer_features:total_spent",
    ],
    entity_rows=[{"customer_id": "customer_123"}],
).to_dict()
Feature stores ensure consistency between training and serving. Use online stores for low-latency inference and offline stores for training.
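Training/serving consistency on the offline side usually depends on a point-in-time lookup: each training row should only see feature values that existed at that row's timestamp. The sketch below illustrates the idea in plain Python. It is not Feast's implementation; the function name, the tuple layout, and the TTL handling are assumptions made for illustration.

```python
from bisect import bisect_right
from datetime import datetime, timedelta


def point_in_time_lookup(feature_rows, entity_id, as_of, ttl):
    """Return the latest feature value for entity_id recorded at or before
    as_of, or None if no value exists within the ttl window.

    feature_rows is an iterable of (entity_id, event_timestamp, value) tuples.
    """
    history = sorted(
        ((ts, value) for eid, ts, value in feature_rows if eid == entity_id),
        key=lambda item: item[0],
    )
    timestamps = [ts for ts, _ in history]
    # Index of the last record whose timestamp is <= as_of
    idx = bisect_right(timestamps, as_of) - 1
    if idx < 0:
        return None  # no value existed yet, so returning one would leak the future
    ts, value = history[idx]
    if as_of - ts > ttl:
        return None  # the most recent value is too stale to use
    return value


# Hypothetical total_orders history
rows = [
    ("customer_123", datetime(2024, 1, 1), 3),
    ("customer_123", datetime(2024, 1, 10), 5),
    ("customer_456", datetime(2024, 1, 5), 1),
]
ttl = timedelta(days=30)

print(point_in_time_lookup(rows, "customer_123", datetime(2024, 1, 9), ttl))   # 3
print(point_in_time_lookup(rows, "customer_123", datetime(2024, 1, 10), ttl))  # 5
print(point_in_time_lookup(rows, "customer_123", datetime(2024, 3, 1), ttl))   # None
```

A label dated January 9 gets the January 1 value rather than the later January 10 value, which is what keeps future information out of the training set.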
Pattern 2: Training Pipeline
Orchestrate ML training with experiment tracking.
# Training pipeline with MLflow (the Kubeflow orchestration step is not shown)
import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split


class MLTrainingPipeline:
    def __init__(self, experiment_name: str):
        mlflow.set_experiment(experiment_name)

    def load_data(self, data_path: str) -> pd.DataFrame:
        """Load training data; assumes a CSV file with a binary "target" column."""
        return pd.read_csv(data_path)

    def train(self, data_path: str, params: dict):
        """Train model with experiment tracking."""
        with mlflow.start_run(run_name="training_run"):
            # Load data
            df = self.load_data(data_path)
            X_train, X_test, y_train, y_test = train_test_split(
                df.drop("target", axis=1),
                df["target"],
                test_size=0.2,
            )

            # Log parameters
            mlflow.log_params(params)

            # Train model
            model = RandomForestClassifier(**params)
            model.fit(X_train, y_train)

            # Evaluate (precision and recall default to binary averaging)
            predictions = model.predict(X_test)
            metrics = {
                "accuracy": accuracy_score(y_test, predictions),
                "precision": precision_score(y_test, predictions),
                "recall": recall_score(y_test, predictions),
            }

            # Log metrics
            mlflow.log_metrics(metrics)

            # Log model and register it in the model registry
            mlflow.sklearn.log_model(
                model,
                "model",
                registered_model_name="customer_churn_model",
            )

            # Save the confusion matrix plot, then log it as an artifact
            display = ConfusionMatrixDisplay.from_predictions(y_test, predictions)
            display.figure_.savefig("confusion_matrix.png")
            mlflow.log_artifact("confusion_matrix.png")

        return model, metrics
Pattern 3: Model Serving with Real-Time Inference
Deploy models for low-latency serving.
# SageMaker endpoint configuration
import boto3

sagemaker = boto3.client('sagemaker')

# Create model
# The image registry account is region-specific; confirm the URI for your region.
sagemaker.create_model(
    ModelName='customer-churn-model',
    PrimaryContainer={
        'Image': '246618743249.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-cpu-py3',
        'ModelDataUrl': 's3://models/customer-churn/model.tar.gz',
        'Environment': {
            'SAGEMAKER_PROGRAM': 'inference.py',
            'SAGEMAKER_SUBMIT_DIRECTORY': 's3://models/customer-churn/source.tar.gz',
        },
    },
    # Placeholder ARN; AWS account IDs are 12 digits
    ExecutionRoleArn='arn:aws:iam::123456789012:role/SageMakerRole',
)

# Create endpoint config
sagemaker.create_endpoint_config(
    EndpointConfigName='customer-churn-prod',
    ProductionVariants=[
        {
            'VariantName': 'primary',
            'ModelName': 'customer-churn-model',
            'InstanceType': 'ml.m5.xlarge',
            'InitialInstanceCount': 2,
            'InitialVariantWeight': 1.0,
        },
    ],
    # Capture 10% of requests and responses for later drift monitoring
    DataCaptureConfig={
        'EnableCapture': True,
        'InitialSamplingPercentage': 10,
        'DestinationS3Uri': 's3://ml-data-capture/',
        'CaptureOptions': [
            {'CaptureMode': 'Input'},
            {'CaptureMode': 'Output'},
        ],
    },
)

# Deploy endpoint (creation is asynchronous; the call returns before it is in service)
sagemaker.create_endpoint(
    EndpointName='customer-churn',
    EndpointConfigName='customer-churn-prod',
)
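Once the endpoint is in service, clients call it through the SageMaker runtime API rather than the control-plane client used above. A minimal sketch of that call follows. The helper name `predict_churn` and the CSV payload are assumptions: the request format depends on what the `inference.py` script accepts, which the source does not show. The client is passed in as an argument so the function can be exercised without AWS credentials.

```python
def predict_churn(runtime_client, feature_row, endpoint_name="customer-churn"):
    """Send one CSV-encoded feature row to the endpoint and return the
    response body as text.

    runtime_client is expected to behave like boto3.client("sagemaker-runtime").
    """
    body = ",".join(str(value) for value in feature_row)
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="text/csv",  # assumed: inference.py must accept CSV input
        Body=body,
    )
    # The response body is a stream; read it fully and decode to text
    return response["Body"].read().decode("utf-8")


# Usage with AWS credentials configured (not executed here):
#   import boto3
#   boto3.client("sagemaker").get_waiter("endpoint_in_service").wait(
#       EndpointName="customer-churn")
#   runtime = boto3.client("sagemaker-runtime")
#   print(predict_churn(runtime, [12, 540.0, 45.0]))
```

Keeping the client injectable also makes it straightforward to unit-test request formatting with a stub.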
Pattern 4: A/B Testing for Models
Route traffic between model versions.
// Model A/B testing with weighted routing
// ModelMetrics is not defined in the original; the fields below are assumed.
interface ModelMetrics {
  requests: number;     // number of scored requests
  successRate: number;  // fraction of successful outcomes, 0..1
}

interface ModelEndpoint {
  name: string;
  version: string;
  weight: number;
  metrics: ModelMetrics;
}

export abstract class ModelRouter {
  private endpoints: ModelEndpoint[];

  constructor(endpoints: ModelEndpoint[]) {
    this.endpoints = endpoints;
  }

  // Left to the implementer: fetch fresh per-endpoint metrics and
  // compute a p-value (t-test or Bayesian analysis)
  protected abstract collectMetrics(): Promise<ModelEndpoint[]>;
  protected abstract calculatePValue(control: ModelEndpoint, treatment: ModelEndpoint): number;

  routeRequest(): ModelEndpoint {
    // Weighted random selection
    const totalWeight = this.endpoints.reduce((sum, e) => sum + e.weight, 0);
    let random = Math.random() * totalWeight;
    for (const endpoint of this.endpoints) {
      random -= endpoint.weight;
      if (random <= 0) {
        return endpoint;
      }
    }
    // Fallback for floating-point edge cases
    return this.endpoints[0];
  }

  async evaluatePerformance(): Promise<void> {
    // Collect metrics over time period
    const metrics = await this.collectMetrics();
    const control = metrics.find(m => m.version === 'v1');
    const treatment = metrics.find(m => m.version === 'v2');
    if (!control || !treatment) {
      return; // nothing to compare yet
    }
    // Shift traffic only when the treatment is better and the difference is significant
    const treatmentIsBetter = treatment.metrics.successRate > control.metrics.successRate;
    if (treatmentIsBetter && this.isStatisticallySignificant(control, treatment)) {
      this.adjustWeights(treatment);
    }
  }

  private isStatisticallySignificant(control: ModelEndpoint, treatment: ModelEndpoint): boolean {
    const pValue = this.calculatePValue(control, treatment);
    return pValue < 0.05;
  }

  private adjustWeights(winner: ModelEndpoint): void {
    // Gradually shift traffic to the better model, 10 points at a time
    for (const endpoint of this.endpoints) {
      if (endpoint.version === winner.version) {
        endpoint.weight = Math.min(endpoint.weight + 10, 100);
      } else {
        endpoint.weight = Math.max(endpoint.weight - 10, 0);
      }
    }
  }
}
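Start with a 90/10 traffic split (90% to the current model, 10% to the candidate). Shift traffic gradually, and only once the difference between the models is statistically significant. Monitor for at least 7 days before making decisions so the comparison covers a full weekly cycle.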
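One common way to check significance when the metric is a success rate (for example, conversions per request) is a two-proportion z-test. The sketch below uses only the standard library. The function name and the sample counts are illustrative, and the source does not prescribe this particular test.

```python
import math


def two_proportion_p_value(successes_a, total_a, successes_b, total_b):
    """Two-sided p-value for the difference between two success rates,
    using the pooled two-proportion z-test."""
    rate_a = successes_a / total_a
    rate_b = successes_b / total_b
    pooled = (successes_a + successes_b) / (total_a + total_b)
    std_err = math.sqrt(pooled * (1 - pooled) * (1 / total_a + 1 / total_b))
    if std_err == 0:
        return 1.0  # both groups are all-success or all-failure: no evidence of a difference
    z = (rate_b - rate_a) / std_err
    # For a standard normal variable, P(|Z| > |z|) equals erfc(|z| / sqrt(2))
    return math.erfc(abs(z) / math.sqrt(2))


# Hypothetical counts under a 90/10 split: control converts 900 of 9000
# requests (10%), the candidate converts 130 of 1000 (13%)
p_value = two_proportion_p_value(900, 9000, 130, 1000)
print(f"p-value: {p_value:.4f}, significant at 0.05: {p_value < 0.05}")
```

With a 90/10 split the candidate arm accumulates data slowly, which is one reason to wait for enough traffic before reading the result.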
Pattern 5: Model Monitoring
Detect data drift and model degradation.
# Model monitoring with Evidently AI
# Note: uses Evidently's Report API; result keys can differ between versions.
import logging

import boto3
import numpy as np
import pandas as pd
from evidently.metrics import DataDriftTable
from evidently.report import Report

logger = logging.getLogger(__name__)


class ModelMonitor:
    def __init__(self):
        self.cloudwatch = boto3.client('cloudwatch')

    def send_alert(self, message: str) -> None:
        """Placeholder alert hook; swap in SNS, PagerDuty, or similar."""
        logger.warning(message)

    @staticmethod
    def calculate_kl_divergence(current: pd.Series, reference: pd.Series, eps: float = 1e-9) -> float:
        """KL(current || reference) over the union of observed classes."""
        classes = current.index.union(reference.index)
        # eps avoids log(0) for classes missing from one side
        p = current.reindex(classes, fill_value=0.0).to_numpy() + eps
        q = reference.reindex(classes, fill_value=0.0).to_numpy() + eps
        p, q = p / p.sum(), q / q.sum()
        return float(np.sum(p * np.log(p / q)))

    def monitor_data_drift(self, reference_data: pd.DataFrame, current_data: pd.DataFrame):
        """Detect drift between training and serving data."""
        report = Report(metrics=[DataDriftTable()])
        report.run(
            reference_data=reference_data,
            current_data=current_data,
        )

        # Extract drift results
        drift_result = report.as_dict()['metrics'][0]['result']

        # Send per-feature drift scores to CloudWatch
        for feature, result in drift_result['drift_by_columns'].items():
            self.cloudwatch.put_metric_data(
                Namespace='ML/Drift',
                MetricData=[
                    {
                        'MetricName': f'{feature}_drift_score',
                        'Value': float(result['drift_score']),
                        'Unit': 'None',
                        'Dimensions': [
                            {'Name': 'Model', 'Value': 'customer_churn'},
                        ],
                    },
                ],
            )

        # Alert if Evidently flags dataset-level drift
        if drift_result['dataset_drift']:
            share = drift_result['share_of_drifted_columns']
            self.send_alert(f"Data drift detected: {share:.2%} features drifted")

    def monitor_prediction_drift(self, predictions: list, reference_predictions: list):
        """Monitor shift in prediction distribution."""
        current_dist = pd.Series(predictions).value_counts(normalize=True)
        reference_dist = pd.Series(reference_predictions).value_counts(normalize=True)

        # KL divergence for distribution shift
        kl_divergence = self.calculate_kl_divergence(current_dist, reference_dist)
        self.cloudwatch.put_metric_data(
            Namespace='ML/Drift',
            MetricData=[{
                'MetricName': 'PredictionDistributionDrift',
                'Value': kl_divergence,
                'Unit': 'None',
            }],
        )
ML Platform Checklist
- Experiment Tracking - MLflow, Weights & Biases
- Feature Store - Feast, Tecton, SageMaker Feature Store
- Model Registry - Versioned model storage
- Model Serving - Low-latency inference endpoints
- Monitoring - Data drift, model performance
- Governance - Model lineage, audit trails
Follow-Up Questions
- How do you handle feature engineering pipelines that run daily for real-time serving?
- What strategies would you use to reduce ML model serving costs while maintaining latency?
- How do you implement model rollback when a new model version performs poorly?