Model Deployment in PySpark
Architecture Diagram
Model deployment architecture:
- Core flow: Training Pipeline → Model Registry → Deployment Target, supported respectively by MLflow Tracking, Model Versioning, and Serving Endpoints
- Deployment pipeline stages: Train/Evaluate (metrics, logs, params) → Validate/Test (accuracy, AUC-ROC, F1-score) → Package/Container (Docker image, K8s YAML) → Deploy/Serve (REST API, gRPC, batch)
- Production serving architecture: Load Balancer → Model Server → Feature Store → Cache, backed respectively by an API gateway, an inference engine, online features, and Redis/Memcached
Detailed Explanation
Model deployment in PySpark encompasses the entire lifecycle of taking machine learning models from development to production, ensuring reproducibility, scalability, and maintainability. This process involves model serialization, versioning, containerization, and serving infrastructure.
MLflow Integration
MLflow provides a comprehensive platform for managing the ML lifecycle:
MLflow Tracking:
- Logs experiments, parameters, metrics, and artifacts
- Provides UI for comparing runs
- Stores models in various formats
- Integrates with major ML frameworks
MLflow Models:
- Model packaging with conda/pip dependencies
- Multiple deployment flavors (spark, sklearn, pyfunc)
- Input/output schema (signature) inference from example inputs
- Cross-platform compatibility
MLflow Model Registry:
- Centralized model repository
- Version control and lineage tracking
- Stage transitions (Staging, Production, Archived), deprecated since MLflow 2.9 in favor of aliases
- Model aliases and tags
Model Serialization Formats
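PySpark Native (Spark ML persistence):
# Save a fitted Spark ML pipeline (a PipelineModel returned by pipeline.fit)
pipeline_model.write().overwrite().save("model_path")
# Load the fitted model with PipelineModel (Pipeline.load returns an unfitted pipeline)
from pyspark.ml import PipelineModel
loaded_model = PipelineModel.load("model_path")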
MLflow Format:
import mlflow.spark
# Log a fitted model to MLflow
mlflow.spark.log_model(pipeline_model, "model")
# Load model from MLflow (replace run_id with the ID of the logging run)
model = mlflow.spark.load_model("runs:/run_id/model")
ONNX (Open Neural Network Exchange):
- Cross-framework compatibility
- Optimized inference
- Hardware acceleration support
PMML (Predictive Model Markup Language):
- Standard XML-based format
- Java/JVM compatible
- Supports common ML algorithms
Deployment Strategies
Batch Deployment:
- Scheduled batch predictions
- High throughput, high latency
- Cost-effective for large volumes
- Example: Daily customer scoring
Real-time Serving:
- Low-latency inference (< 100ms)
- REST/gRPC APIs
- Auto-scaling based on load
- Example: Fraud detection, recommendations
Streaming Deployment:
- Event-driven predictions
- Micro-batch processing
- Integration with Kafka/Kinesis
- Example: Real-time personalization
Edge Deployment:
- Mobile/IoT inference
- Model optimization (quantization, pruning)
- On-device processing
- Example: Mobile vision apps
Containerization and Orchestration
Docker:
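FROM apache/spark-py:3.4.1
# The Spark images default to the non-root user UID 185; switch to root to install packages
USER root
COPY model /opt/spark/model
COPY serve.py /opt/spark/serve.py
COPY requirements.txt /opt/spark/requirements.txt
RUN pip install -r /opt/spark/requirements.txt
# Drop back to the unprivileged Spark user at runtime
USER 185
ENTRYPOINT ["/opt/spark/bin/spark-submit", "--master", "local[*]", "/opt/spark/serve.py"]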
Kubernetes:
- Horizontal pod autoscaling
- Resource limits and requests
- Health checks and readiness probes
- Rolling deployments
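Horizontal pod autoscaling is driven by a documented core rule: desiredReplicas = ceil(currentReplicas × currentMetricValue / desiredMetricValue). The sketch below reproduces that rule for a single utilization metric. It includes the controller's default 10% tolerance band and clamps the result to the min/max replica bounds. The function name and the defaults of 2 and 10 replicas (matching the HPA in Example 4) are illustrative assumptions. The real controller also applies stabilization windows and combines multiple metrics.

```python
import math


def desired_replicas(current_replicas: int, current_utilization: float,
                     target_utilization: float, min_replicas: int = 2,
                     max_replicas: int = 10, tolerance: float = 0.1) -> int:
    """Approximate the Kubernetes HPA rule for one utilization metric."""
    ratio = current_utilization / target_utilization
    # Inside the tolerance band the controller keeps the current replica count
    if abs(ratio - 1.0) <= tolerance:
        return current_replicas
    # Core rule: ceil(currentReplicas * currentMetric / targetMetric)
    desired = math.ceil(current_replicas * ratio)
    # Clamp to the configured minReplicas / maxReplicas bounds
    return max(min_replicas, min(max_replicas, desired))


# 3 pods averaging 91% CPU against a 70% target
print(desired_replicas(3, 91, 70))  # prints 4
```

Because of the ceiling, the controller rounds toward more capacity. A ratio of 1.3 on 3 pods yields 4, not 3.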
Serverless:
- AWS Lambda, Google Cloud Functions
- Azure Functions
- Auto-scaling to zero
- Pay-per-invocation
Model Optimization
Quantization:
- Reduce model size (FP32 → INT8)
- Faster inference on CPUs
- Minimal accuracy loss
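The FP32 → INT8 step can be illustrated with affine (asymmetric) per-tensor quantization. Each value is mapped to an 8-bit integer through a scale and a zero point, which cuts storage per weight from 4 bytes to 1. This is a minimal pure-Python sketch for intuition, not the scheme of any particular runtime. The function names are hypothetical; ONNX Runtime and TensorRT ship their own calibrated implementations.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float, int]:
    """Affine (asymmetric) per-tensor quantization of float weights to int8."""
    # Extend the range to include 0.0 so that zero is represented exactly
    w_min = min(min(weights), 0.0)
    w_max = max(max(weights), 0.0)
    scale = (w_max - w_min) / 255.0
    if scale == 0.0:  # all-zero tensor: any positive scale works
        scale = 1.0
    zero_point = round(-128 - w_min / scale)
    # Map each value onto the int8 grid and clamp to [-128, 127]
    q = [max(-128, min(127, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point


def dequantize_int8(q: List[int], scale: float, zero_point: int) -> List[float]:
    """Approximate reconstruction of the original float values."""
    return [(v - zero_point) * scale for v in q]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize_int8(weights)
restored = dequantize_int8(q, scale, zero_point)
```

The reconstruction error per value is bounded by roughly half the scale. This is why accuracy loss is usually small when the weight range is narrow.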
Pruning:
- Remove redundant weights
- Reduce model complexity
- Speed up inference
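A common form of pruning is unstructured magnitude pruning, which zeroes the smallest-magnitude weights. The sketch below shows the idea on a flat list of weights; the function name and the list representation are assumptions for illustration. It applies to weight-based models such as neural networks, not to Spark ML tree ensembles. Real speedups also depend on sparse-aware kernels or structured (channel-level) pruning.

```python
from typing import List


def magnitude_prune(weights: List[float], sparsity: float) -> List[float]:
    """Zero out the `sparsity` fraction of weights with the smallest |w|."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    n_prune = int(len(weights) * sparsity)
    # Indices sorted by absolute value; the first n_prune are removed
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]


print(magnitude_prune([0.9, -0.05, 0.3, 0.01, -0.7], sparsity=0.4))
```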
Knowledge Distillation:
- Train smaller "student" model
- Mimic larger "teacher" model
- Balance accuracy vs. speed
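The student is commonly trained on a blend of two terms. The first is a KL-divergence term against the teacher's temperature-softened probabilities; the second is ordinary cross-entropy against the true label. The sketch below computes that loss for one example in pure Python, using the widely used formulation with a T² scaling factor. The function names and the default `temperature` and `alpha` values are illustrative assumptions, not values from this document.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; higher temperature gives softer probabilities."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: List[float], teacher_logits: List[float],
                      label: int, temperature: float = 2.0, alpha: float = 0.5) -> float:
    """Blend of soft-target KL divergence and hard-label cross-entropy."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # KL(teacher || student) on the temperature-softened distributions
    soft = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    # Ordinary cross-entropy of the student against the true label (T = 1)
    hard = -math.log(softmax(student_logits)[label])
    # The T^2 factor keeps the soft-target term on a comparable gradient scale
    return alpha * temperature ** 2 * soft + (1 - alpha) * hard
```

Raising `alpha` weights imitation of the teacher more heavily. Lowering it leans on the ground-truth labels.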
Monitoring and Observability
Model Metrics:
- Prediction latency (p50, p95, p99)
- Throughput (requests/second)
- Error rates (4xx, 5xx)
- Model confidence distribution
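The p50/p95/p99 latency figures are percentiles over a window of observed request latencies. The sketch below uses the nearest-rank definition on raw samples; the function name is hypothetical. Monitoring systems such as Prometheus usually estimate quantiles from histogram buckets instead, so their numbers can differ slightly from this exact computation.

```python
import math
from typing import Dict, Iterable, Sequence


def latency_percentiles(latencies_ms: Sequence[float],
                        percentiles: Iterable[int] = (50, 95, 99)) -> Dict[str, float]:
    """Nearest-rank percentiles over a window of request latencies."""
    if not latencies_ms:
        raise ValueError("need at least one latency sample")
    ordered = sorted(latencies_ms)
    n = len(ordered)
    result = {}
    for p in percentiles:
        # Nearest rank: smallest sample with at least p% of samples at or below it
        rank = max(1, math.ceil(p * n / 100))
        result[f"p{p}"] = ordered[rank - 1]
    return result


print(latency_percentiles(list(range(100, 0, -1))))  # prints {'p50': 50, 'p95': 95, 'p99': 99}
```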
Data Drift Detection:
- Input feature distribution monitoring
- Statistical tests for drift
- Automatic retraining triggers
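One statistic often used for feature-drift monitoring is the Population Stability Index (PSI), an alternative to the Kolmogorov-Smirnov test. It compares binned proportions of a reference sample and a current sample. This is a swapped-in technique chosen for illustration; the text does not prescribe a particular test. The sketch uses equal-width bins over the reference range. The common reading (below 0.1 stable, above 0.25 significant drift) is an industry rule of thumb, not a standard.

```python
import math
from typing import List, Sequence


def population_stability_index(reference: Sequence[float], current: Sequence[float],
                               bins: int = 10, eps: float = 1e-6) -> float:
    """PSI = sum((a_i - e_i) * ln(a_i / e_i)) over equal-width reference bins."""
    lo, hi = min(reference), max(reference)
    width = (hi - lo) / bins if hi > lo else 1.0

    def proportions(values: Sequence[float]) -> List[float]:
        counts = [0] * bins
        for v in values:
            # Values outside the reference range fall into the first/last bin
            idx = min(max(int((v - lo) / width), 0), bins - 1)
            counts[idx] += 1
        # eps avoids log(0) for empty bins
        return [max(c / len(values), eps) for c in counts]

    expected = proportions(reference)
    actual = proportions(current)
    return sum((a - e) * math.log(a / e) for a, e in zip(actual, expected))


reference = [i / 100 for i in range(1000)]
shifted = [x + 5 for x in reference]
print(population_stability_index(reference, shifted))
```

A retraining trigger could then compare each feature's PSI against a chosen threshold.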
A/B Testing:
- Traffic splitting
- Statistical significance testing
- Gradual rollout
- Rollback capabilities
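Traffic splitting for A/B tests and gradual rollouts is usually done by hashing a stable user key into buckets. The same user then always sees the same variant. This is a minimal sketch under that assumption; the function name and the 10,000-bucket granularity are hypothetical choices. It uses `hashlib` rather than Python's built-in `hash()`, because string hashing is randomized per process. Raising `treatment_share` only moves users from control to treatment, which makes gradual rollout (and rollback) predictable.

```python
import hashlib


def assign_variant(experiment: str, user_id: int, treatment_share: float) -> str:
    """Deterministically bucket a user into 'treatment' or 'control'."""
    key = f"{experiment}:{user_id}".encode("utf-8")
    # SHA-256 serves as a stable, well-mixed hash; 10,000 buckets give 0.01% granularity
    bucket = int(hashlib.sha256(key).hexdigest(), 16) % 10_000
    return "treatment" if bucket < treatment_share * 10_000 else "control"


print(assign_variant("churn_model_v2_test", 12345, treatment_share=0.1))
```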
Security Considerations
Authentication:
- API keys
- OAuth 2.0
- JWT tokens
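For API-key authentication, a common pattern is to store only a digest of each issued key and compare digests in constant time. The sketch below assumes a hypothetical in-memory set of SHA-256 digests; a real service would load them from a secrets store. OAuth 2.0 and JWT validation should use a maintained library rather than hand-written code.

```python
import hashlib
import hmac

# Hypothetical key store: SHA-256 digests of issued API keys (raw keys are never stored)
API_KEY_DIGESTS = {
    hashlib.sha256(b"example-key-123").hexdigest(),
}


def is_valid_api_key(presented_key: str) -> bool:
    """Check a presented key against the stored digests."""
    digest = hashlib.sha256(presented_key.encode("utf-8")).hexdigest()
    # compare_digest avoids leaking, through timing, how many characters matched
    return any(hmac.compare_digest(digest, known) for known in API_KEY_DIGESTS)
```

In a FastAPI service, this check would typically run in a dependency that reads the key from a request header.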
Authorization:
- Role-based access control (RBAC)
- Model-level permissions
- Data-level security
Encryption:
- TLS/HTTPS for data in transit
- Encryption at rest for models
- Secure key management
Audit Logging:
- Request/response logging
- Model access tracking
- Compliance reporting
Cost Optimization
Right-Sizing:
- Match instance types to workload
- Use spot instances for non-critical workloads
- Use reserved instances for baseline load
Caching:
- Cache frequent predictions
- Cache feature computations
- Reduces redundant computation
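Caching frequent predictions can also be done in-process. This is a minimal sketch, assuming a single worker and hashable request keys. It is a TTL cache whose `clock` parameter (a hypothetical hook, defaulting to `time.monotonic`) makes expiry easy to test. Unlike the Redis cache in Best Practice 5, it is not shared between replicas, and it never evicts expired keys on its own.

```python
import time
from typing import Any, Callable, Dict, Hashable, Tuple


class TTLCache:
    """In-process cache whose entries expire after ttl_seconds."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl_seconds
        self.clock = clock
        self._store: Dict[Hashable, Tuple[float, Any]] = {}

    def get_or_compute(self, key: Hashable, compute: Callable[[], Any]) -> Any:
        now = self.clock()
        hit = self._store.get(key)
        # Serve the cached value while it is younger than the TTL
        if hit is not None and now - hit[0] < self.ttl:
            return hit[1]
        # Miss or expired: recompute and remember when it was stored
        value = compute()
        self._store[key] = (now, value)
        return value
```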
Batching:
- Group multiple predictions
- Amortize overhead
- Improve GPU utilization
These deployment strategies enable organizations to operationalize ML models at scale while maintaining reliability, performance, and cost efficiency.
Key Concepts Table
Mathematical Foundations
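Definition: Model Serving
Model serving maps input features $x$ to predictions $\hat{y}$ under a latency constraint $\tau$:
$$\hat{y} = f(x), \qquad \text{latency}(x) \le \tau$$
For batch serving, throughput is $\Theta = N / t$, where $N$ records are processed in time $t$.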
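A/B Test Sample Size
For a binary metric with baseline rate $p_1$ and treatment rate $p_2$ (effect size $\delta = p_2 - p_1$), detecting the effect with power $1 - \beta$ at significance $\alpha$ requires approximately this many samples per group:
$$n = \frac{(z_{1-\alpha/2} + z_{1-\beta})^2 \,\left[p_1(1-p_1) + p_2(1-p_2)\right]}{\delta^2}$$
For the conventional $\alpha = 0.05$ and $1 - \beta = 0.8$, $(z_{0.975} + z_{0.8})^2 \approx (1.96 + 0.84)^2 \approx 7.85$. The required $n$ then grows with the inverse square of the relative effect $\delta / p_1$.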
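The per-group sample size for a binary metric with baseline rate p1 and treatment rate p2 is commonly approximated as n = (z_{1−α/2} + z_{1−β})² [p1(1−p1) + p2(1−p2)] / (p2 − p1)². The sketch below computes it using only the standard library's `statistics.NormalDist`. The function name and the 10% baseline in the example are hypothetical.

```python
import math
from statistics import NormalDist


def ab_sample_size(baseline_rate: float, relative_effect: float,
                   alpha: float = 0.05, power: float = 0.8) -> int:
    """Per-group sample size for a two-sided test of two proportions."""
    p1 = baseline_rate
    p2 = baseline_rate * (1 + relative_effect)
    z_alpha = NormalDist().inv_cdf(1 - alpha / 2)  # about 1.96 for alpha = 0.05
    z_beta = NormalDist().inv_cdf(power)           # about 0.84 for power = 0.8
    variance = p1 * (1 - p1) + p2 * (1 - p2)
    return math.ceil((z_alpha + z_beta) ** 2 * variance / (p2 - p1) ** 2)


# Hypothetical: 10% baseline churn, detect a 10% relative change (0.10 -> 0.11)
n = ab_sample_size(0.10, 0.10)
print(n)
```

This gives just under 15,000 users per group. Halving the relative effect roughly quadruples the requirement.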
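Model Decay (Exponential Approximation)
Model performance typically degrades over time as the data distribution shifts. A common approximation is:
$$P(t) = P_0 \, e^{-\lambda t} + \varepsilon_t$$
where $P_0$ is performance at deployment, $\lambda$ is the decay rate, and $\varepsilon_t$ is noise. Retraining is triggered when $P(t) < P_{\min}$.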
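Under the exponential approximation P(t) = P0·e^(−λt), with the noise term ignored, the retraining threshold P(t) = P_min is reached at t* = ln(P0 / P_min) / λ. The sketch below computes that time. The function name and the example numbers (AUC 0.90 at deployment, a floor of 0.85, λ = 0.01 per day) are hypothetical. In practice λ would be estimated from monitored performance.

```python
import math


def retrain_after(p0: float, p_min: float, decay_rate: float) -> float:
    """Time t* at which P0 * exp(-decay_rate * t) falls to p_min (noise ignored)."""
    if not 0 < p_min < p0 or decay_rate <= 0:
        raise ValueError("require 0 < p_min < p0 and decay_rate > 0")
    return math.log(p0 / p_min) / decay_rate


# Hypothetical: AUC 0.90 at deployment, floor 0.85, decay rate 0.01 per day
t_star = retrain_after(0.90, 0.85, 0.01)
print(round(t_star, 2))
```

With these numbers, retraining would be due after roughly 5.7 days.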
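Canary Deployment and Rollback
Given $n$ canary requests with true error rate $p$ and an error-rate threshold $\theta$, the canary passes (rather than being rolled back) when at most $\lfloor \theta n \rfloor$ requests fail. Treating requests as independent trials:
$$P(\text{pass}) = \sum_{k=0}^{\lfloor \theta n \rfloor} \binom{n}{k} p^k (1-p)^{n-k}$$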
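Treating each of n canary requests as an independent trial with true error rate p, the canary passes when the observed error count stays at or below ⌊θn⌋. P(pass) is therefore a binomial CDF. The sketch below evaluates it with the pmf recurrence, which avoids huge binomial coefficients. Independence is an assumption, since real traffic errors are often correlated. The function name and parameters are hypothetical.

```python
import math


def canary_pass_probability(n_requests: int, true_error_rate: float,
                            threshold: float) -> float:
    """P(error count <= floor(threshold * n)) under a Binomial(n, p) model."""
    p = true_error_rate
    max_errors = math.floor(threshold * n_requests)
    if p == 0.0:
        return 1.0
    if p == 1.0:
        return 1.0 if max_errors >= n_requests else 0.0
    pmf = (1 - p) ** n_requests  # P(k = 0)
    total = pmf
    for k in range(min(max_errors, n_requests)):
        # Recurrence: P(k + 1) = P(k) * (n - k) / (k + 1) * p / (1 - p)
        pmf *= (n_requests - k) / (k + 1) * p / (1 - p)
        total += pmf
    return min(total, 1.0)


# A healthy canary (0.1% errors) vs. a faulty one (5%) against a 1% threshold
print(canary_pass_probability(1000, 0.001, 0.01))
print(canary_pass_probability(1000, 0.05, 0.01))
```

A larger canary sample separates healthy from faulty releases more sharply, at the cost of exposing more users to a bad version.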
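Versioning Cost
Total storage cost for $V$ model versions with average artifact size $\bar{s}$ at storage price $c$ per unit size:
$$C = V \cdot \bar{s} \cdot c$$
A common policy is to retain the $k$ most recent versions in hot storage and archive older ones to cold storage, which gives $C = \bar{s}\left[k\,c_{\text{hot}} + (V - k)\,c_{\text{cold}}\right]$.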
Key Insight
MLflow's model registry decouples model training from serving. Register models with stage transitions (Staging → Production) to enable atomic rollbacks. Use model signatures to validate input schemas at serving time.
Summary
Model deployment requires managing latency, versioning, and decay. A/B testing provides statistical rigor for model comparison. Model decay is commonly approximated as exponential degradation, which motivates scheduled or threshold-triggered retraining. Canary deployments reduce risk through gradual rollout with automatic rollback.
| Component | Purpose | Tools | Complexity | Cost |
|---|---|---|---|---|
| MLflow Tracking | Experiment logging | MLflow | Low | Free |
| Model Registry | Version control | MLflow | Medium | Free |
| Batch Inference | Scheduled predictions | Spark, Airflow | Medium | Low |
| Real-time Serving | Low-latency API | Flask, FastAPI | High | Medium |
| Containerization | Packaging | Docker, K8s | High | Medium |
| Model Optimization | Speed improvement | ONNX, TensorRT | High | Free |
| Monitoring | Observability | Prometheus, Grafana | Medium | Low |
| A/B Testing | Experimentation | Custom, LaunchDarkly | High | Medium |
Code Examples
Example 1: Complete MLflow Integration
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import mlflow
import mlflow.spark
from mlflow.tracking import MlflowClient
# Initialize Spark Session
spark = SparkSession.builder \
    .appName("MLflow Model Deployment") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()
# Configure MLflow
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("customer_churn_prediction")
# Create sample data
data = [
    (1, 25, 75000, 12, 0.85, 1),
    (2, 35, 55000, 8, 0.72, 0),
    (3, 45, 95000, 15, 0.92, 1),
    (4, 28, 35000, 5, 0.65, 0),
    (5, 52, 65000, 10, 0.78, 1),
    (6, 38, 85000, 14, 0.88, 1),
    (7, 31, 42000, 6, 0.68, 0),
    (8, 41, 58000, 9, 0.75, 0),
    (9, 33, 78000, 11, 0.86, 1),
    (10, 27, 38000, 4, 0.62, 0),
]
columns = ["customer_id", "age", "income", "tenure_months",
           "credit_score", "churned"]
df = spark.createDataFrame(data, columns)
# Split data
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)
# Define feature engineering
assembler = VectorAssembler(
    inputCols=["age", "income", "tenure_months", "credit_score"],
    outputCol="features_raw"
)
scaler = StandardScaler(
    inputCol="features_raw",
    outputCol="features"
)
# Define model
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="churned",
    numTrees=100,
    maxDepth=10,
    seed=42
)
# Create pipeline
pipeline = Pipeline(stages=[assembler, scaler, rf])
# Log experiment with MLflow
with mlflow.start_run(run_name="rf_baseline"):
    # Log parameters
    mlflow.log_param("num_trees", 100)
    mlflow.log_param("max_depth", 10)
    mlflow.log_param("algorithm", "RandomForest")
    # Fit model
    model = pipeline.fit(train_df)
    # Make predictions
    predictions = model.transform(test_df)
    # Evaluate
    evaluator = BinaryClassificationEvaluator(
        labelCol="churned",
        rawPredictionCol="rawPrediction",
        metricName="areaUnderROC"
    )
    auc = evaluator.evaluate(predictions)
    # Log metrics
    mlflow.log_metric("auc_roc", auc)
    mlflow.log_metric("train_samples", train_df.count())
    mlflow.log_metric("test_samples", test_df.count())
    # Log the fitted model and register it in the Model Registry
    mlflow.spark.log_model(
        model,
        "model",
        registered_model_name="customer_churn_rf"
    )
    # Log artifacts (assumes requirements.txt exists in the working directory)
    mlflow.log_artifact("requirements.txt")
    print(f"Run ID: {mlflow.active_run().info.run_id}")
    print(f"AUC-ROC: {auc:.4f}")
# Load and serve model
client = MlflowClient()
# Look up the version currently in the Production stage. A newly registered
# version starts in stage "None", so this block only runs after a version
# has been promoted (see Best Practice 2 below).
model_versions = client.get_latest_versions(
    "customer_churn_rf",
    stages=["Production"]
)
if model_versions:
    print(f"Serving version {model_versions[0].version}")
    # Load model
    loaded_model = mlflow.spark.load_model(
        "models:/customer_churn_rf/Production"
    )
    # Make predictions
    test_predictions = loaded_model.transform(test_df)
    test_predictions.select("customer_id", "prediction", "probability").show()
Example 2: REST API Model Serving with FastAPI
# serve.py
import os
import time
from typing import List
import mlflow.spark
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pyspark.sql import SparkSession
# Initialize the API
app = FastAPI(
    title="Churn Prediction API",
    description="Customer churn prediction model serving",
    version="1.0.0"
)
# Initialize Spark Session for model serving
spark = SparkSession.builder \
    .appName("ModelServer") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()
# Registry URI of the model to serve; the MODEL_URI environment variable
# (set in the Dockerfile and Kubernetes manifest) overrides the default
MODEL_URI = os.environ.get("MODEL_URI", "models:/customer_churn_rf/Production")
class PredictionRequest(BaseModel):
    customer_id: int
    age: int
    income: float
    tenure_months: int
    credit_score: float
class PredictionResponse(BaseModel):
    customer_id: int
    prediction: float
    probability: List[float]
    latency_ms: float
class BatchPredictionRequest(BaseModel):
    predictions: List[PredictionRequest]
# Global model cache
model_cache = {}
def load_model():
    """Load the model from the MLflow registry (cached after the first call)"""
    if "model" not in model_cache:
        model_cache["model"] = mlflow.spark.load_model(MODEL_URI)
    return model_cache["model"]
@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    load_model()
    print("Model loaded successfully")
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": "model" in model_cache}
# Spark calls are blocking, so the prediction endpoints are plain `def`
# functions: FastAPI runs them in a threadpool instead of on the event loop
@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    """Single prediction endpoint"""
    start_time = time.time()
    try:
        # Get model
        model = load_model()
        # Create DataFrame from request
        data = [(
            request.customer_id,
            request.age,
            request.income,
            request.tenure_months,
            request.credit_score
        )]
        columns = ["customer_id", "age", "income",
                   "tenure_months", "credit_score"]
        df = spark.createDataFrame(data, columns)
        # Make prediction
        predictions = model.transform(df)
        # Extract results
        result = predictions.collect()[0]
        latency = (time.time() - start_time) * 1000
        return PredictionResponse(
            customer_id=request.customer_id,
            prediction=float(result["prediction"]),
            probability=result["probability"].toArray().tolist(),
            latency_ms=round(latency, 2)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict/batch", response_model=List[PredictionResponse])
def predict_batch(request: BatchPredictionRequest):
    """Batch prediction endpoint"""
    start_time = time.time()
    # Spark cannot build a DataFrame from an empty list without a schema
    if not request.predictions:
        return []
    try:
        model = load_model()
        # Create DataFrame from batch requests
        data = [(
            pred.customer_id,
            pred.age,
            pred.income,
            pred.tenure_months,
            pred.credit_score
        ) for pred in request.predictions]
        columns = ["customer_id", "age", "income",
                   "tenure_months", "credit_score"]
        df = spark.createDataFrame(data, columns)
        # Make batch predictions
        rows = model.transform(df).collect()
        # Report the batch's total latency amortized over its rows
        total_latency = (time.time() - start_time) * 1000
        avg_latency = round(total_latency / len(rows), 2)
        return [
            PredictionResponse(
                customer_id=row["customer_id"],
                prediction=float(row["prediction"]),
                probability=row["probability"].toArray().tolist(),
                latency_ms=avg_latency
            )
            for row in rows
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/model/reload")
def reload_model():
    """Drop the cached model and reload it from the registry"""
    model_cache.clear()
    load_model()
    return {"status": "reloaded"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
Example 3: Docker Containerization
# Dockerfile
FROM apache/spark-py:3.4.1
# The Spark images default to the non-root user UID 185; switch to root to install packages
USER root
# Set working directory
WORKDIR /opt/spark
# Install system dependencies (curl is used by the HEALTHCHECK)
RUN apt-get update && apt-get install -y \
    curl \
    && rm -rf /var/lib/apt/lists/*
# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY serve.py .
COPY config/ config/
# Copy model artifacts (if not using remote registry)
COPY models/ models/
# Set environment variables
ENV SPARK_MASTER_URL=local[*]
ENV MODEL_URI=models:/customer_churn_rf/Production
ENV MLFLOW_TRACKING_URI=http://mlflow-server:5000
# Drop back to the unprivileged Spark user at runtime
USER 185
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1
# Run application with the image's python3 interpreter
CMD ["python3", "serve.py"]
# docker-compose.yml
version: '3.8'
services:
  model-server:
    build: .
    # No fixed host port: three replicas cannot all bind host port 8000,
    # so clients reach them through nginx on the shared network
    expose:
      - "8000"
    environment:
      - MLFLOW_TRACKING_URI=http://mlflow-server:5000
      - MODEL_URI=models:/customer_churn_rf/Production
    depends_on:
      - mlflow-server
    deploy:
      replicas: 3
      resources:
        limits:
          cpus: '2'
          memory: 4G
        reservations:
          cpus: '1'
          memory: 2G
    networks:
      - app-network
  mlflow-server:
    image: ghcr.io/mlflow/mlflow:v2.8.0
    ports:
      - "5000:5000"
    volumes:
      - mlflow-data:/mlflow
    # Keep run metadata, registry data and artifacts on the mounted volume
    command: mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri sqlite:////mlflow/mlflow.db --artifacts-destination /mlflow/artifacts
    networks:
      - app-network
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
    networks:
      - app-network
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - model-server
    networks:
      - app-network
volumes:
  mlflow-data:
networks:
  app-network:
    driver: bridge
Example 4: Kubernetes Deployment
# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: model-server
  labels:
    app: model-server
spec:
  replicas: 3
  selector:
    matchLabels:
      app: model-server
  template:
    metadata:
      labels:
        app: model-server
    spec:
      containers:
        - name: model-server
          image: model-server:latest
          ports:
            - containerPort: 8000
          env:
            - name: MLFLOW_TRACKING_URI
              value: "http://mlflow-service:5000"
            - name: MODEL_URI
              value: "models:/customer_churn_rf/Production"
          resources:
            requests:
              memory: "2Gi"
              cpu: "1000m"
            limits:
              memory: "4Gi"
              cpu: "2000m"
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 5
            periodSeconds: 5
        - name: redis-sidecar
          image: redis:7-alpine
          ports:
            - containerPort: 6379
          # Utilization-based HPA targets need requests on every container
          # in the pod; these sidecar values are illustrative
          resources:
            requests:
              memory: "64Mi"
              cpu: "100m"
---
apiVersion: v1
kind: Service
metadata:
  name: model-server-service
spec:
  selector:
    app: model-server
  ports:
    - protocol: TCP
      port: 80
      targetPort: 8000
  type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: model-server-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: model-server
  minReplicas: 2
  maxReplicas: 10
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: 80
Example 5: Batch Inference Pipeline
from datetime import datetime
import mlflow
import mlflow.spark
from pyspark.ml.functions import vector_to_array
from pyspark.sql import SparkSession
from pyspark.sql.functions import array_max, col, current_timestamp, lit, when
# Initialize Spark
spark = SparkSession.builder \
    .appName("BatchInference") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()
# Load production model
model = mlflow.spark.load_model("models:/customer_churn_rf/Production")
# Read input data (could be from data lake, warehouse, etc.)
input_data = spark.read.parquet("s3://data-lake/raw/customers/")
# Add metadata
batch_id = datetime.now().strftime("%Y%m%d_%H%M%S")
input_with_metadata = input_data \
    .withColumn("prediction_timestamp", current_timestamp()) \
    .withColumn("model_version", lit("1.0")) \
    .withColumn("batch_id", lit(batch_id))
# Make predictions
predictions = model.transform(input_with_metadata)
# Confidence = highest class probability; vector_to_array + array_max run
# natively in the JVM instead of calling a Python UDF for every row
predictions_with_confidence = predictions \
    .withColumn("confidence", array_max(vector_to_array(col("probability"))))
# Add prediction categories
predictions_final = predictions_with_confidence \
    .withColumn("risk_category",
                when(col("prediction") == 1,
                     when(col("confidence") > 0.8, "high_risk")
                     .otherwise("medium_risk"))
                .otherwise("low_risk"))
# Cache so the write and the counts below do not rerun the model
predictions_final.cache()
# Save predictions to data lake
predictions_final.write \
    .mode("overwrite") \
    .partitionBy("risk_category") \
    .parquet("s3://data-lake/predictions/churn/")
# Count rows per risk category in a single pass
risk_counts = {
    row["risk_category"]: row["count"]
    for row in predictions_final.groupBy("risk_category").count().collect()
}
total_predictions = sum(risk_counts.values())
# Log batch run to MLflow
with mlflow.start_run(run_name=f"batch_{datetime.now().strftime('%Y%m%d')}"):
    mlflow.log_metric("input_count", input_data.count())
    mlflow.log_metric("prediction_count", total_predictions)
    for category in ["high_risk", "medium_risk", "low_risk"]:
        mlflow.log_metric(f"{category}_count", risk_counts.get(category, 0))
print(f"Batch prediction completed: {total_predictions} predictions")
Performance Metrics
| Metric | Batch | Real-time | Streaming | Edge |
|---|---|---|---|---|
| Latency | 100-1000ms | 10-100ms | 1-10ms | 1-5ms |
| Throughput | 10K-100K/sec | 1K-10K/sec | 10K-100K/sec | 100-1K/sec |
| Cost/1K predictions | 0.10-0.50 | 0.001-0.01 | | |
| Model Size | Unconstrained | < 100MB | < 50MB | < 10MB |
| Availability | 99.9% | 99.99% | 99.99% | 99.9% |
| Cold Start | N/A | 1-5s | 1-3s | N/A |
| Auto-scaling | Manual | Yes | Yes | No |
Best Practices
1. Use Model Registry for Versioning
# Bad: hardcoded model paths
model = mlflow.spark.load_model("s3://models/churn_model_v1")
# Good: resolve the model through the model registry
model = mlflow.spark.load_model("models:/churn_model/Production")
2. Implement Model Validation Before Deployment
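from mlflow.tracking import MlflowClient
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Validate model performance before promotion
def validate_model(model, test_data, threshold=0.85):
    predictions = model.transform(test_data)
    # labelCol must match the training label ("churned"); the default is "label"
    evaluator = BinaryClassificationEvaluator(labelCol="churned")
    auc = evaluator.evaluate(predictions)
    if auc < threshold:
        raise ValueError(f"Model AUC {auc:.4f} below threshold {threshold}")
    return True
# Validate before promoting to Production
# (new_model, validation_data and new_version come from your training run)
client = MlflowClient()
if validate_model(new_model, validation_data):
    client.transition_model_version_stage(
        name="churn_model",
        version=new_version,
        stage="Production"
    )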
3. Use Health Checks and Readiness Probes
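from fastapi import FastAPI, HTTPException
app = FastAPI()
# check_model_loaded, get_model_version and get_last_prediction_time are
# application-specific helpers (not shown)
@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "model_loaded": check_model_loaded(),
        "model_version": get_model_version(),
        "last_prediction_time": get_last_prediction_time()
    }
@app.get("/ready")
async def ready():
    # Return 503 until the model can serve, so the readiness probe keeps
    # the pod out of the Service's endpoints
    if not check_model_loaded():
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"ready": True}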
4. Implement Circuit Breaker Pattern
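import circuitbreaker
# After 5 failures the circuit opens and calls fail fast for 30 seconds
# before a trial call is let through
@circuitbreaker.circuit(failure_threshold=5, recovery_timeout=30)
def predict_with_circuit_breaker(data):
    try:
        model = load_model()
        return model.transform(data)
    except Exception as e:
        # Log failure (log_prediction_failure is an application-specific helper)
        log_prediction_failure(e)
        raise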
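The `circuitbreaker` package handles the state machine internally. To make the pattern itself visible, here is a dependency-free sketch of the closed → open → half-open cycle. `SimpleCircuitBreaker`, `CircuitOpenError` and the injectable `clock` are hypothetical names, not part of that package. The class is not thread-safe.

```python
import time
from typing import Any, Callable, Optional


class CircuitOpenError(RuntimeError):
    """Raised when the circuit is open and calls are rejected immediately."""


class SimpleCircuitBreaker:
    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0,
                 clock: Callable[[], float] = time.monotonic):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.clock = clock
        self.failures = 0
        self.opened_at: Optional[float] = None  # None means the circuit is closed

    def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        if self.opened_at is not None:
            if self.clock() - self.opened_at < self.recovery_timeout:
                raise CircuitOpenError("circuit open; failing fast")
            # Recovery timeout elapsed: half-open, let one trial call through
        try:
            result = func(*args, **kwargs)
        except Exception:
            self.failures += 1
            # A failed trial call, or too many failures, (re)opens the circuit
            if self.opened_at is not None or self.failures >= self.failure_threshold:
                self.opened_at = self.clock()
            raise
        # Any success closes the circuit and resets the failure count
        self.failures = 0
        self.opened_at = None
        return result
```

Wrapping `model.transform` in `breaker.call(...)` gives the same fail-fast behavior as the decorator above.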
5. Use Feature Caching for Low Latency
import json
import redis
class FeatureCache:
    """Redis-backed cache for per-customer feature dictionaries"""
    def __init__(self, host="localhost", port=6379):
        self.redis_client = redis.Redis(host=host, port=port)
    def get_features(self, customer_id):
        cached = self.redis_client.get(f"features:{customer_id}")
        if cached:
            return json.loads(cached)
        return None
    def set_features(self, customer_id, features, ttl=3600):
        # setex stores the value with an expiry of `ttl` seconds
        self.redis_client.setex(
            f"features:{customer_id}",
            ttl,
            json.dumps(features)
        )
# Create the cache once so requests reuse its connection pool
feature_cache = FeatureCache()
# Use in prediction pipeline
def get_features_with_cache(customer_id):
    # Check cache first
    cached_features = feature_cache.get_features(customer_id)
    if cached_features is not None:
        return cached_features
    # Compute features if not cached (compute_features is application-specific)
    new_features = compute_features(customer_id)
    feature_cache.set_features(customer_id, new_features)
    return new_features
6. Implement A/B Testing Framework
import hashlib
from datetime import datetime
class ABTestManager:
    def __init__(self):
        self.experiments = {}
    def create_experiment(self, name, control_model, treatment_model,
                          traffic_split=0.5):
        self.experiments[name] = {
            "control": control_model,
            "treatment": treatment_model,
            "split": traffic_split,
            "start_time": datetime.now()
        }
    def get_model(self, experiment_name, user_id):
        experiment = self.experiments[experiment_name]
        # Deterministic assignment based on user_id. hashlib is used because
        # Python's built-in hash() of strings is randomized per process, which
        # would move users between variants across workers and restarts
        key = f"{experiment_name}_{user_id}".encode("utf-8")
        hash_value = int(hashlib.sha256(key).hexdigest(), 16) % 100
        if hash_value < experiment["split"] * 100:
            return "treatment", experiment["treatment"]
        return "control", experiment["control"]
# Usage
ab_manager = ABTestManager()
ab_manager.create_experiment(
    "churn_model_v2_test",
    control_model="models:/churn_model/Production",
    treatment_model="models:/churn_model_v2/Staging",
    traffic_split=0.1  # 10% to treatment
)
# Get model for user
variant, model = ab_manager.get_model("churn_model_v2_test", user_id=12345)
7. Monitor Model Drift
from scipy import stats
class DriftDetector:
    def __init__(self, reference_data):
        self.reference_data = reference_data
    def detect_drift(self, current_data, threshold=0.05):
        drift_report = {}
        for column in self.reference_data.columns:
            # Only DoubleType columns are tested in this example
            if self.reference_data.schema[column].dataType.simpleString() == 'double':
                # Two-sample KS test for numerical features; toPandas() collects
                # the column to the driver, so sample large tables first
                stat, p_value = stats.ks_2samp(
                    self.reference_data.select(column).toPandas().values.flatten(),
                    current_data.select(column).toPandas().values.flatten()
                )
                drift_report[column] = {
                    "statistic": stat,
                    "p_value": p_value,
                    "drift_detected": p_value < threshold
                }
        return drift_report
# Monitor drift
reference_df = spark.read.parquet("reference_data_path")
current_df = spark.read.parquet("current_data_path")
detector = DriftDetector(reference_df)
drift_report = detector.detect_drift(current_df)
# Alert if drift detected (send_alert is an application-specific helper)
for feature, report in drift_report.items():
    if report["drift_detected"]:
        send_alert(f"Drift detected in {feature}: p-value={report['p_value']}")
8. Implement Graceful Degradation
import logging
log = logging.getLogger(__name__)
class FallbackPredictor:
    def __init__(self, primary_model, fallback_model, failure_threshold=5):
        self.primary_model = primary_model
        self.fallback_model = fallback_model
        self.primary_failures = 0
        self.failure_threshold = failure_threshold
    def predict(self, data):
        try:
            result = self.primary_model.transform(data)
            self.primary_failures = 0  # Reset on success
            return result
        except Exception:
            self.primary_failures += 1
            if self.primary_failures >= self.failure_threshold:
                # Switch to fallback after repeated primary failures
                log.warning("Primary model failing, using fallback")
                return self.fallback_model.transform(data)
            raise
# Usage (primary_model and fallback_model are fitted models; test_data is a DataFrame)
predictor = FallbackPredictor(primary_model, fallback_model)
predictions = predictor.predict(test_data)
Related Topics
- CI/CD for ML: Automated testing and deployment pipelines
- Model Monitoring: Production performance tracking
- Feature Stores: Centralized feature management
- MLOps: End-to-end ML operations
See also: Snowflake Time Travel (snowflake/02), Kafka CDC (kafka/04), Airflow DAGs (airflow/02)