Model Deployment with FastAPI
Deploying ML models to production requires serving predictions via APIs, handling scaling, and ensuring reliability. This lesson covers FastAPI for ML APIs, Docker containerization, async processing, and production deployment patterns.
1. FastAPI Fundamentals
Why FastAPI for ML?
- Fast: Async support, high performance
- Type-safe: Pydantic validation
- Auto-docs: Swagger UI at /docs
- ML-friendly: Easy model serialization, batch predictions
REST API for ML
A REST (Representational State Transfer) API provides a standardized interface for serving ML model predictions over HTTP. Clients send feature data in JSON format and receive prediction responses, enabling integration with any programming language or platform.
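To make the request/response cycle concrete, here is a minimal client-side sketch using only the Python standard library. It assumes the service listens at http://localhost:8000 and exposes a POST /predict endpoint that accepts a JSON body with a features list; the base URL and the helper names are illustrative assumptions. The request is built but not sent, so the snippet runs without a server.

```python
import json
import urllib.request

# Assumed base URL for illustration; adjust to wherever the API is served.
BASE_URL = "http://localhost:8000"


def build_predict_request(features):
    """Build a JSON POST request for an assumed /predict endpoint."""
    body = json.dumps({"features": features}).encode("utf-8")
    return urllib.request.Request(
        url=f"{BASE_URL}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def parse_prediction(raw_body):
    """Decode a JSON response body (bytes) into a Python dict."""
    return json.loads(raw_body.decode("utf-8"))


request = build_predict_request([5.1, 3.5, 1.4, 0.2])
print(request.get_method(), request.full_url)
print(request.data.decode("utf-8"))

# With a running server, the request would be sent like this:
# with urllib.request.urlopen(request, timeout=5) as response:
#     print(parse_prediction(response.read()))
```

Because the contract is plain JSON over HTTP, a client in any other language only has to reproduce the same body and header.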
Basic ML API
import os
from typing import List, Optional

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI(title="ML Model API", version="1.0.0")

# Model path is configurable via the MODEL_PATH environment variable
MODEL_PATH = os.getenv("MODEL_PATH", "models/model.pkl")

# Load model at startup
# (on_event still works in FastAPI 0.104 but is deprecated in favor of lifespan handlers)
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = joblib.load(MODEL_PATH)
    print("Model loaded successfully")

# Request/Response schemas
class PredictionRequest(BaseModel):
    features: List[float]

class PredictionResponse(BaseModel):
    prediction: float
    # None when the model does not expose predict_proba
    confidence: Optional[float] = None

class BatchRequest(BaseModel):
    instances: List[List[float]]

class BatchResponse(BaseModel):
    predictions: List[float]
@app.get("/health")
def health():
    return {"status": "healthy", "model_loaded": model is not None}

@app.post("/predict", response_model=PredictionResponse)
def predict(request: PredictionRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        features = np.array(request.features).reshape(1, -1)
        # Convert NumPy scalars to plain floats so they serialize cleanly
        prediction = float(model.predict(features)[0])
        confidence = (
            float(model.predict_proba(features).max())
            if hasattr(model, "predict_proba")
            else None
        )
        return PredictionResponse(prediction=prediction, confidence=confidence)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.post("/predict/batch", response_model=BatchResponse)
def predict_batch(request: BatchRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        instances = np.array(request.instances)
        predictions = model.predict(instances).tolist()
        return BatchResponse(predictions=predictions)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
Health Check Endpoint
The /health endpoint is critical for production systems. Load balancers and orchestrators (Kubernetes, Docker Compose) use it to determine if a service is ready to receive traffic. Always include model status in the health check.
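The following sketch shows roughly what an orchestrator does with such an endpoint: it probes repeatedly and only treats the service as ready once the payload reports a loaded model. The function name wait_until_healthy and its parameters are hypothetical, and the probe is passed in as a callable so the example runs without a live server.

```python
import time


def wait_until_healthy(check, retries=3, delay_seconds=0.1):
    """Call check() until it reports a healthy service with a loaded model.

    check is any callable returning a dict shaped like the /health payload.
    Returns True once healthy, False after all retries are used up.
    """
    for attempt in range(retries):
        try:
            payload = check()
            if payload.get("status") == "healthy" and payload.get("model_loaded"):
                return True
        except OSError:
            # Connection refused or timed out: the service is not up yet.
            pass
        if attempt < retries - 1:
            time.sleep(delay_seconds)
    return False


# Simulated service: the model finishes loading on the third probe.
responses = iter([
    {"status": "healthy", "model_loaded": False},
    {"status": "healthy", "model_loaded": False},
    {"status": "healthy", "model_loaded": True},
])
print(wait_until_healthy(lambda: next(responses), retries=3, delay_seconds=0.01))
```

Checking model_loaded as well as status matters here: a process can answer HTTP requests well before a large model has finished loading.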
2. Pydantic Validation
Advanced Input Validation
from enum import Enum
from typing import List, Optional

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator

class ModelType(str, Enum):
    CLASSIFICATION = "classification"
    REGRESSION = "regression"

class FeatureInput(BaseModel):
    # Pydantic v2 reserves the "model_" prefix; allow model_type as a field name
    model_config = ConfigDict(protected_namespaces=())

    features: List[float] = Field(..., min_length=1, max_length=100)
    model_type: ModelType = ModelType.CLASSIFICATION
    threshold: Optional[float] = Field(None, ge=0.0, le=1.0)

    @field_validator("features")
    @classmethod
    def validate_features(cls, v):
        if any(np.isnan(v)):
            raise ValueError("Features cannot contain NaN values")
        if any(np.isinf(v)):
            raise ValueError("Features cannot contain infinite values")
        return v

class PredictionOutput(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    prediction: float
    probability: Optional[float] = None
    model_version: str = "v1.0"

@app.post("/predict/validated", response_model=PredictionOutput)
def predict_validated(input_data: FeatureInput):
    features = np.array(input_data.features).reshape(1, -1)
    prediction = float(model.predict(features)[0])
    result = {
        "prediction": prediction,
        "model_version": "v1.0"
    }
    if input_data.model_type == ModelType.CLASSIFICATION and hasattr(model, "predict_proba"):
        prob = model.predict_proba(features)[0]
        # Compare against None so that threshold=0.0 is still honored
        if input_data.threshold is not None:
            # Report the positive-class probability for binary classifiers
            result["probability"] = float(prob[1] if len(prob) > 1 else prob[0])
        else:
            result["probability"] = float(max(prob))
    return PredictionOutput(**result)
Input Validation Importance
Pydantic validation prevents malformed inputs from crashing the model. Validate: (1) correct number of features, (2) no NaN/Inf values, (3) values within expected ranges, (4) correct data types. This catches errors before they reach the model.
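The four checks listed above can be sketched in plain Python, independent of any framework. The expected feature count and value range below are hypothetical placeholders; real values depend on the model and its training data.

```python
import math

# Hypothetical expectations for illustration; real values depend on the model.
EXPECTED_FEATURES = 4
VALUE_RANGE = (0.0, 10.0)


def validate_features(features):
    """Return the features as floats, or raise ValueError describing the problem."""
    if len(features) != EXPECTED_FEATURES:
        raise ValueError(f"expected {EXPECTED_FEATURES} features, got {len(features)}")
    low, high = VALUE_RANGE
    cleaned = []
    for index, value in enumerate(features):
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"feature {index} is not a number")
        value = float(value)
        if math.isnan(value) or math.isinf(value):
            raise ValueError(f"feature {index} is NaN or infinite")
        if not low <= value <= high:
            raise ValueError(f"feature {index} is outside [{low}, {high}]")
        cleaned.append(value)
    return cleaned


print(validate_features([5.1, 3.5, 1.4, 0.2]))
for bad in ([1.0, 2.0], [1.0, float("nan"), 3.0, 4.0], [1.0, 2.0, 3.0, 99.0]):
    try:
        validate_features(bad)
    except ValueError as error:
        print("rejected:", error)
```

In a Pydantic model the same rules would live in field constraints and validators, so a bad request is rejected before the model is called.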
3. Model Versioning
Versioned Model Endpoints
from pathlib import Path
from typing import Dict, Optional

import joblib
from fastapi import HTTPException

class ModelManager:
    def __init__(self, model_dir: str = "models"):
        self.model_dir = Path(model_dir)
        self.models: Dict[str, object] = {}
        self.active_version: Optional[str] = None

    def load_model(self, version: str):
        model_path = self.model_dir / f"model_{version}.pkl"
        if not model_path.exists():
            raise FileNotFoundError(f"Model version {version} not found")
        self.models[version] = joblib.load(model_path)
        self.active_version = version
        return self.models[version]

    def get_model(self, version: Optional[str] = None):
        version = version or self.active_version
        if version not in self.models:
            raise ValueError(f"Model version {version} not loaded")
        return self.models[version]

manager = ModelManager()
manager.load_model("v1.0")

@app.get("/models/versions")
def list_versions():
    return {
        "versions": list(manager.models.keys()),
        "active": manager.active_version
    }

# POST, because the features arrive in the request body. The path sits under
# /models so it cannot shadow other POST /predict/... routes.
@app.post("/models/{version}/predict")
def predict_version(version: str, request: PredictionRequest):
    try:
        model = manager.get_model(version)
    except ValueError as e:
        # Unknown or unloaded version
        raise HTTPException(status_code=404, detail=str(e))
    features = np.array(request.features).reshape(1, -1)
    prediction = float(model.predict(features)[0])
    return {"prediction": prediction, "version": version}
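Once several versions are loaded, traffic can be split between them, for example to A/B test a new model against the current one. The sketch below is one possible approach, not part of the lesson's ModelManager: it hashes a request or user ID so the same caller always reaches the same version. The function name choose_version, the version labels, and the 90/10 split are illustrative assumptions.

```python
import hashlib


def choose_version(request_id, weights):
    """Deterministically map a request ID to a model version.

    weights maps version name -> traffic share; shares must sum to 1.0.
    The same request_id always lands on the same version.
    """
    if abs(sum(weights.values()) - 1.0) > 1e-9:
        raise ValueError("traffic shares must sum to 1.0")
    digest = hashlib.sha256(request_id.encode("utf-8")).digest()
    # Map the first 8 bytes of the hash to a point in [0, 1].
    point = int.from_bytes(digest[:8], "big") / 2**64
    cumulative = 0.0
    for version, share in weights.items():
        cumulative += share
        if point < cumulative:
            return version
    return version  # Guard against floating-point rounding at the upper edge.


weights = {"v1.0": 0.9, "v1.1": 0.1}  # Hypothetical 90/10 split
counts = {"v1.0": 0, "v1.1": 0}
for i in range(10_000):
    counts[choose_version(f"user-{i}", weights)] += 1
print(counts)
```

Hashing instead of random sampling keeps the assignment stable across requests, which makes per-version metrics easier to compare.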
4. Async Processing
Background Tasks
import logging
from datetime import datetime

from fastapi import BackgroundTasks

logger = logging.getLogger(__name__)

def log_prediction(features, prediction, timestamp):
    """Log prediction to database/file"""
    logger.info(f"Prediction logged: {prediction} at {timestamp}")
    # In production: save to database, send to monitoring system

@app.post("/predict/async")
async def predict_async(
    request: PredictionRequest,
    background_tasks: BackgroundTasks
):
    features = np.array(request.features).reshape(1, -1)
    # Convert the NumPy scalar to a plain float so it serializes cleanly
    prediction = float(model.predict(features)[0])
    timestamp = datetime.now().isoformat()
    # Log in background; the task runs after the response has been sent
    background_tasks.add_task(log_prediction, request.features, prediction, timestamp)
    return {"prediction": prediction, "timestamp": timestamp}
Async Batch Processing with Queues
import asyncio
from typing import Dict

import numpy as np

class PredictionQueue:
    def __init__(self, batch_size: int = 32, max_wait_ms: int = 100):
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = asyncio.Queue()
        self.results: Dict[str, asyncio.Future] = {}

    async def add_prediction(self, request_id: str, features: np.ndarray):
        future = asyncio.get_running_loop().create_future()
        self.results[request_id] = future
        await self.queue.put((request_id, features))
        return await future

    async def process_batches(self):
        # Run as a long-lived task, e.g. asyncio.create_task(...) at startup
        loop = asyncio.get_running_loop()
        while True:
            # Wait for first item
            batch = [await self.queue.get()]
            deadline = loop.time() + self.max_wait_ms / 1000
            # Collect more items until the batch is full or max_wait_ms elapses
            while len(batch) < self.batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    item = await asyncio.wait_for(self.queue.get(), timeout=remaining)
                    batch.append(item)
                except asyncio.TimeoutError:
                    break
            # Process batch
            request_ids = [b[0] for b in batch]
            features = np.array([b[1] for b in batch])
            try:
                # Run the blocking predict call off the event loop
                predictions = await loop.run_in_executor(None, model.predict, features)
            except Exception as exc:
                # Propagate the failure to every waiting caller
                for req_id in request_ids:
                    self.results.pop(req_id).set_exception(exc)
                continue
            for req_id, pred in zip(request_ids, predictions):
                # pop() removes finished futures so the dict does not grow
                self.results.pop(req_id).set_result(float(pred))
Batching for Throughput
Batch predictions increase throughput by amortizing per-call overhead (function dispatch, memory allocation) across multiple requests and by letting the model use vectorized computation. The tradeoff is added latency: the first request in a batch waits while the batch fills. Tune batch_size and max_wait_ms based on your latency/throughput requirements.
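A small back-of-the-envelope calculation makes the tradeoff visible. The cost model below is an assumption for illustration only (a fixed 2 ms per predict call plus 0.1 ms per row, with the server fully busy); measure your own model before choosing values.

```python
# Hypothetical cost model: every predict() call pays a fixed overhead,
# plus a small cost per row. The numbers are illustrative, not measured.
FIXED_OVERHEAD_MS = 2.0
PER_ROW_MS = 0.1


def batch_stats(batch_size, max_wait_ms):
    """Return (rows per second, worst-case latency in ms) for one batch."""
    compute_ms = FIXED_OVERHEAD_MS + PER_ROW_MS * batch_size
    throughput = batch_size / (compute_ms / 1000)
    # A batch of one is full immediately; otherwise the first request
    # may wait the whole window before compute starts.
    wait_ms = max_wait_ms if batch_size > 1 else 0.0
    return throughput, wait_ms + compute_ms


for size in (1, 8, 32, 128):
    rows_per_second, latency_ms = batch_stats(size, max_wait_ms=100)
    print(f"batch_size={size:>3}  ~{rows_per_second:7.0f} rows/s  worst-case {latency_ms:.1f} ms")
```

Under these assumptions throughput grows with batch size while worst-case latency is dominated by max_wait_ms, which is why the two parameters are tuned together.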
5. Docker Deployment
Dockerfile for ML API
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application
COPY . .
# Expose port
EXPOSE 8000
# Run with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
# requirements.txt
fastapi==0.104.1
uvicorn==0.24.0
scikit-learn==1.3.2
numpy==1.24.3
joblib==1.3.2
pydantic==2.5.2
python-multipart==0.0.6
Docker Compose
# docker-compose.yml
version: '3.8'
services:
  ml-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_PATH=/models/model.pkl
      - LOG_LEVEL=info
    volumes:
      - ./models:/models:ro
    healthcheck:
      # python:3.11-slim does not ship curl, so probe with the Python stdlib
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
    depends_on:
      - ml-api
Containerization for ML
Docker containers package the model, dependencies, and runtime into a single unit that runs consistently across environments. This eliminates "it works on my machine" issues and enables reproducible deployments.
6. Production Best Practices
Error Handling
from fastapi import Request
from fastapi.responses import JSONResponse

class ModelNotLoadedError(Exception):
    pass

class PredictionError(Exception):
    pass

@app.exception_handler(ModelNotLoadedError)
async def model_not_loaded_handler(request: Request, exc: ModelNotLoadedError):
    return JSONResponse(
        status_code=503,
        content={"detail": "Model not loaded. Please try again later."}
    )

@app.exception_handler(PredictionError)
async def prediction_error_handler(request: Request, exc: PredictionError):
    return JSONResponse(
        status_code=422,
        content={"detail": f"Prediction failed: {exc}"}
    )
Rate Limiting
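from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# Return HTTP 429 when a client exceeds its limit
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# POST on its own path: the features arrive in the body, and POST /predict already exists
@app.post("/predict/limited")
@limiter.limit("100/minute")
def predict_rate_limited(request: Request, input_data: FeatureInput):
    features = np.array(input_data.features).reshape(1, -1)
    prediction = float(model.predict(features)[0])
    return {"prediction": prediction}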
Monitoring with Prometheus
from fastapi import Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, generate_latest

PREDICTION_COUNT = Counter("predictions_total", "Total predictions", ["model_version"])
PREDICTION_LATENCY = Histogram("prediction_latency_seconds", "Prediction latency")

@app.get("/metrics")
def metrics():
    # generate_latest() renders all registered metrics in the Prometheus text format
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)

@app.post("/predict/monitored")
def predict_monitored(request: PredictionRequest):
    with PREDICTION_LATENCY.time():
        features = np.array(request.features).reshape(1, -1)
        prediction = float(model.predict(features)[0])
    PREDICTION_COUNT.labels(model_version="v1.0").inc()
    return {"prediction": prediction}
Production Monitoring Essentials
Monitor these key metrics: (1) prediction latency (p50, p95, p99), (2) request rate, (3) error rate, (4) model version, (5) input data distribution. Sudden changes in any of these can indicate model drift, infrastructure issues, or data quality problems.
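To illustrate why percentiles are tracked rather than the average alone, here is a short standard-library sketch. The latency samples are synthetic and the helper name latency_percentiles is hypothetical; in production these figures would normally come from the monitoring system's histogram data.

```python
import statistics


def latency_percentiles(latencies_ms):
    """Return p50, p95 and p99 for a list of latency samples (milliseconds)."""
    # quantiles(n=100) returns the 99 cut points for percentiles 1..99.
    cuts = statistics.quantiles(latencies_ms, n=100, method="inclusive")
    return {"p50": cuts[49], "p95": cuts[94], "p99": cuts[98]}


# Synthetic samples for illustration: mostly fast requests plus a slow tail.
samples = [10.0] * 90 + [50.0] * 8 + [400.0] * 2
stats = latency_percentiles(samples)
print(stats)
print("mean:", statistics.mean(samples))
```

With this data the median stays at the fast-request level while p99 exposes the slow tail, which the mean alone would blur.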
7. Key Takeaways
Summary: Model Deployment with FastAPI
- FastAPI provides fast, type-safe APIs with automatic documentation
- Pydantic validates inputs and handles serialization/deserialization
- Model versioning allows A/B testing and rollback capabilities
- Async processing handles high-throughput predictions efficiently
- Docker ensures consistent deployment across environments
- Production readiness: error handling, rate limiting, monitoring, health checks
- Always include a /health endpoint for orchestrator integration
- Use background tasks for non-critical operations (logging, notifications)
- Monitor prediction latency and error rates in production
8. Practice Exercises
Exercise 1: Build ML API
# TODO: Create a FastAPI app for your trained model
# Endpoints: /health, /predict, /predict/batch, /model/info
# Add: input validation, error handling, logging
# Test: with curl or Python requests
Exercise 2: Dockerize
# TODO: Create Dockerfile and docker-compose.yml
# Add: health checks, volume mounts for models
# Test: build and run in Docker
# Measure: startup time, memory usage
Exercise 3: Async Processing
# TODO: Add background task logging
# Add: prediction queue for batching
# Test: throughput under load
# Compare: sync vs async performance
Exercise 4: Production API
# TODO: Add rate limiting, monitoring, versioning
# Add: Prometheus metrics endpoint
# Implement: model rollback capability
# Document: API with OpenAPI/Swagger