Model Serving Patterns
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Serving Architecture
Model serving must balance latency, throughput, cost, and reliability requirements.
ℹ️ Note: Google's TensorFlow Serving (TFServing) handles 100+ billion predictions per day across all serving patterns.
FastAPI Model Server
# model_server.py
import asyncio
import logging
import pickle
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict, Any, Optional

import numpy as np
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ML Model Server",
    description="Production model serving API",
    version="1.0.0",
)

# CORS: any origin may call this API, but without credentials (cookies /
# HTTP auth). The CORS spec forbids `Access-Control-Allow-Origin: *` on
# credentialed requests. Combining a wildcard origin with
# allow_credentials=True therefore either breaks in browsers or is
# "satisfied" by echoing back whatever Origin the caller sent, which trusts
# every website with the user's credentials. If credentials are ever needed,
# replace "*" with an explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PredictionRequest(BaseModel):
    # Body of POST /predict: a single feature vector.
    features: List[float] = Field(...)
    # Optional caller-supplied identifier, echoed back unchanged in the response.
    request_id: Optional[str] = Field(default=None)
class PredictionResponse(BaseModel):
    # Payload returned by POST /predict.
    # NOTE(review): under pydantic v2, field names starting with "model_"
    # may emit a protected-namespace warning at import -- confirm version.
    prediction: float
    # Highest class probability when the model has predict_proba, else None.
    probability: Optional[float] = Field(default=None)
    model_version: str
    # Inference time in milliseconds as measured by the model manager.
    latency_ms: float
    request_id: Optional[str] = Field(default=None)
class BatchPredictionRequest(BaseModel):
    # Body of POST /predict/batch: one feature vector per instance.
    instances: List[List[float]] = Field(...)
    # Optional caller-supplied identifier, echoed back in the batch response.
    request_id: Optional[str] = Field(default=None)
class HealthResponse(BaseModel):
    # Payload returned by GET /health.
    status: str = Field(...)
    # NOTE(review): "model_" prefix may trigger a pydantic v2 namespace warning.
    model_loaded: bool = Field(...)
    # Seconds since the model manager's load_time, i.e. since the most recent
    # model load rather than since process start.
    uptime_seconds: float = Field(...)
    total_predictions: int = Field(...)
class ModelManager:
    """Hold the served model and run single and batch inference against it.

    Attributes:
        model: object unpickled by ``load_model``; ``None`` until a load
            succeeds. Must provide ``predict`` and may provide
            ``predict_proba`` (scikit-learn style -- assumption, confirm for
            other frameworks).
        model_version: version string reported by the API.
        load_time: ``time.time()`` of the last successful load, or ``None``.
        total_predictions: number of instances predicted so far.
        executor: thread pool on which callers may run blocking inference.
    """

    def __init__(self):
        self.model = None
        self.model_version = "1.0.0"
        self.load_time = None
        self.total_predictions = 0
        self.executor = ThreadPoolExecutor(max_workers=4)
        # predict()/predict_batch() may run on several executor threads at
        # once; `+=` on an attribute is not atomic, so guard the counter.
        self._counter_lock = threading.Lock()

    def load_model(self, model_path: str):
        """Unpickle the model at *model_path* and make it the served model.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only load artifacts from a trusted, access-controlled location.

        Raises:
            OSError: if the file cannot be opened (e.g. FileNotFoundError).
            Any error ``pickle.load`` raises for a corrupt file. On failure
            the previously loaded model stays in place.
        """
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        # Swap in only after a successful load.
        self.model = model
        self.load_time = time.time()
        logger.info("Model loaded from %s", model_path)

    def _add_predictions(self, count: int) -> None:
        """Add *count* to ``total_predictions`` under the counter lock."""
        with self._counter_lock:
            self.total_predictions += count

    def predict(self, features: np.ndarray) -> Dict[str, Any]:
        """Predict a single instance.

        Args:
            features: feature vector (any array-like); it is flattened into a
                single-row 2-D matrix before being passed to the model.

        Returns:
            dict with ``prediction`` (float), ``probability`` (max class
            probability, or None if the model has no ``predict_proba``) and
            ``latency_ms`` (inference time in milliseconds).
        """
        start = time.perf_counter()  # monotonic, unlike time.time()
        model = self.model  # snapshot: a concurrent reload cannot mix models
        row = np.asarray(features).reshape(1, -1)
        prediction = model.predict(row)[0]
        probability = None
        if hasattr(model, "predict_proba"):
            probability = float(model.predict_proba(row).max())
        latency_ms = (time.perf_counter() - start) * 1000
        self._add_predictions(1)
        return {
            "prediction": float(prediction),
            "probability": probability,
            "latency_ms": latency_ms,
        }

    def predict_batch(self, instances: List[np.ndarray]) -> List[Dict[str, Any]]:
        """Predict several instances in one model call.

        Args:
            instances: equal-length feature vectors; stacked into a 2-D matrix.

        Returns:
            One dict per instance with ``prediction`` and ``probability``
            (same meaning as in ``predict``). An empty input returns ``[]``
            without calling the model.
        """
        if len(instances) == 0:
            # np.array([]) is 1-D, which scikit-learn-style estimators reject.
            return []
        model = self.model  # snapshot, as in predict()
        matrix = np.array(instances)
        predictions = model.predict(matrix)
        probabilities = (
            model.predict_proba(matrix) if hasattr(model, "predict_proba") else None
        )
        results = [
            {
                "prediction": float(pred),
                "probability": (
                    float(probabilities[i].max()) if probabilities is not None else None
                ),
            }
            for i, pred in enumerate(predictions)
        ]
        self._add_predictions(len(instances))
        return results
# Single shared manager instance used by every endpoint in this module.
model_manager: ModelManager = ModelManager()
@app.on_event("startup")
async def startup_event():
    """Load the model from its default path when the application starts."""
    default_model_path = "model.pkl"
    model_manager.load_model(default_model_path)
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report model state and the running prediction counter.

    ``status`` is "healthy" only when a model is loaded. ``uptime_seconds``
    is measured from the manager's ``load_time`` (the most recent successful
    load, so it resets on reload). It is 0.0 when no load has succeeded yet;
    the original code raised TypeError on ``None`` in that case.
    """
    loaded = model_manager.model is not None
    load_time = model_manager.load_time
    uptime = time.time() - load_time if load_time is not None else 0.0
    return HealthResponse(
        status="healthy" if loaded else "unhealthy",
        model_loaded=loaded,
        uptime_seconds=uptime,
        total_predictions=model_manager.total_predictions,
    )
def _validate_feature_count(n_features: int) -> None:
    """Reject (HTTP 422) a feature vector the loaded model cannot accept.

    Always rejects an empty vector. When the model exposes scikit-learn's
    ``n_features_in_`` attribute, also rejects a length mismatch; models
    without that attribute are only checked for emptiness.
    """
    if n_features == 0:
        raise HTTPException(status_code=422, detail="features must not be empty")
    expected = getattr(model_manager.model, "n_features_in_", None)
    if expected is not None and n_features != expected:
        raise HTTPException(
            status_code=422,
            detail=f"expected {expected} features, got {n_features}",
        )


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict one instance.

    Returns 503 if no model is loaded and 422 for an unusable feature vector.
    """
    if model_manager.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    _validate_feature_count(len(request.features))
    features = np.array(request.features)
    # Inference is blocking CPU work. Run it on the manager's thread pool so
    # it does not stall the event loop for every other in-flight request.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        model_manager.executor, model_manager.predict, features
    )
    return PredictionResponse(
        prediction=result["prediction"],
        probability=result["probability"],
        model_version=model_manager.model_version,
        latency_ms=result["latency_ms"],
        request_id=request.request_id,
    )


@app.post("/predict/batch")
async def predict_batch(request: BatchPredictionRequest):
    """Predict a batch of instances in one model call.

    Returns 503 if no model is loaded. Returns 422 if the instances have
    differing lengths or an unusable width. An empty batch yields an empty
    ``predictions`` list without touching the model.
    """
    if model_manager.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    results = []
    if request.instances:
        widths = {len(inst) for inst in request.instances}
        if len(widths) != 1:
            raise HTTPException(
                status_code=422,
                detail="all instances must have the same number of features",
            )
        _validate_feature_count(widths.pop())
        instances = [np.array(inst) for inst in request.instances]
        # Same reason as predict(): keep blocking inference off the event loop.
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            model_manager.executor, model_manager.predict_batch, instances
        )
    return {
        "predictions": results,
        "model_version": model_manager.model_version,
        "batch_size": len(results),
        "request_id": request.request_id,
    }
@app.post("/model/reload")
async def reload_model(model_path: str = "model.pkl"):
    """Hot-swap the served model with the pickle at *model_path*.

    SECURITY: as written this endpoint has no authentication, and
    *model_path* comes straight from the query string. Because unpickling
    can execute arbitrary code, anyone who can reach this route and point it
    at a crafted file can run code on the server. Restrict it to an internal
    network or authenticated operators, and/or whitelist allowed paths.

    NOTE(review): ``model_version`` is not changed by a reload, so the
    returned "version" is whatever the manager already held.

    Raises:
        HTTPException(404): if *model_path* does not exist. A failed load
            leaves the previously served model in place.
    """
    # Unpickling is blocking file I/O plus deserialization; keep it off the
    # event loop so in-flight predictions are not frozen during a reload.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(
            model_manager.executor, model_manager.load_model, model_path
        )
    except FileNotFoundError as err:
        raise HTTPException(
            status_code=404, detail=f"Model file not found: {model_path}"
        ) from err
    return {"status": "reloaded", "version": model_manager.model_version}
TensorFlow Serving
# tf_serving_client.py
import requests
import numpy as np
import json
from typing import Dict, List, Any
from dataclasses import dataclass
@dataclass()
class TFServerConfig:
    """Connection settings for one model version on a TensorFlow Serving REST server.

    Field order defines the positional constructor signature.
    """

    host: str  # hostname or IP of the serving process
    port: int  # REST port to connect to
    model_name: str  # used as /v1/models/<model_name>
    model_version: int  # used as .../versions/<model_version>
class TFServingClient:
    """Minimal REST client for one model version on TensorFlow Serving.

    Wraps the ``/v1/models/<name>`` endpoints: predict (optionally with a
    named signature), model metadata, and a version-status health check.

    Args:
        config: host, port, model name and version to target.
        timeout: seconds passed as ``timeout`` to every HTTP call. Without it,
            ``requests`` waits indefinitely on a hung or unreachable server.
    """

    def __init__(self, config: TFServerConfig, timeout: float = 30.0):
        self.config = config
        self.timeout = timeout
        self.base_url = f"http://{config.host}:{config.port}/v1/models/{config.model_name}"

    def _version_url(self) -> str:
        """URL of the configured model version."""
        return f"{self.base_url}/versions/{self.config.model_version}"

    def _post_predict(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to the version's ``:predict`` endpoint and decode the JSON reply.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.RequestException: on connection errors or timeouts.
        """
        response = requests.post(
            f"{self._version_url()}:predict",
            json=payload,  # also sets Content-Type: application/json
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def predict(self, input_data: np.ndarray) -> Dict[str, Any]:
        """Run the default signature on *input_data* (sent as ``instances``)."""
        return self._post_predict({"instances": np.asarray(input_data).tolist()})

    def predict_with_signature(self, input_data: np.ndarray, signature_name: str) -> Dict[str, Any]:
        """Run the named *signature_name* on *input_data*."""
        return self._post_predict(
            {
                "instances": np.asarray(input_data).tolist(),
                "signature_name": signature_name,
            }
        )

    def get_model_metadata(self) -> Dict[str, Any]:
        """Fetch the model's metadata document.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        response = requests.get(f"{self.base_url}/metadata", timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def health_check(self) -> bool:
        """Return True only if the configured version reports state AVAILABLE.

        A 200 from the version-status endpoint alone is not enough: a version
        that is still loading or unloading is also reported with 200.
        Network, timeout, HTTP and JSON-decoding failures all yield False.
        """
        try:
            response = requests.get(self._version_url(), timeout=self.timeout)
            if response.status_code != 200:
                return False
            statuses = response.json().get("model_version_status", [])
        except (requests.RequestException, ValueError):
            return False
        return any(status.get("state") == "AVAILABLE" for status in statuses)
# Usage: point the client at a local server and send one random input.
config = TFServerConfig("localhost", 8501, "image_classifier", 1)
client = TFServingClient(config)
# One random 224x224x3 float32 input with a leading batch dimension of 1.
input_shape = (1, 224, 224, 3)
input_data = np.random.randn(*input_shape).astype(np.float32)
predictions = client.predict(input_data)
Load Balancer Configuration
# nginx_load_balancer.conf
upstream ml_backend {
    # Balancing method: send each request to the server with the fewest
    # active connections (server weights are taken into account).
    least_conn;
    # ml-server-1 and ml-server-2 are weighted 3; ml-server-3 is weighted 2.
    server ml-server-1:8000 weight=3;
    server ml-server-2:8000 weight=3;
    server ml-server-3:8000 weight=2;
    # Keep up to 32 idle upstream connections open for reuse. Per the nginx
    # upstream docs, a non-default balancing method must appear before this
    # directive, and reuse only happens when the proxying location speaks
    # HTTP/1.1 and clears the Connection header.
    keepalive 32;
}
server {
    listen 80;
    server_name ml-api.example.com;

    location /v1/models/ {
        proxy_pass http://ml_backend;
        # Required for the upstream `keepalive` pool to be used. By default
        # nginx talks HTTP/1.0 to upstreams and sends "Connection: close",
        # which forces a new TCP connection per request.
        proxy_http_version 1.1;
        proxy_set_header Connection "";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_connect_timeout 5s;
        proxy_read_timeout 30s;
        proxy_send_timeout 30s;
        # Try another server on these failures. NOTE: once a request has been
        # sent upstream, nginx does not retry non-idempotent methods (POST
        # predictions) unless `non_idempotent` is added here. Also note that
        # retrying on `timeout` can multiply load on an already slow model.
        proxy_next_upstream error timeout http_502 http_503;
        proxy_next_upstream_tries 3;
    }

    location /health {
        proxy_pass http://ml_backend;
        # Same keepalive requirement as above.
        proxy_http_version 1.1;
        proxy_set_header Connection "";
    }

    location /metrics {
        # NOTE(review): this forwards to prometheus:9090/metrics, i.e. the
        # Prometheus server's own metrics, and exposes them publicly on this
        # host. Confirm that is intended; consider allow/deny restrictions.
        proxy_pass http://prometheus:9090;
    }
}
Follow-Up Questions
- How would you implement model canary deployments?
- What are the trade-offs between batching and real-time inference?
- How do you handle model fallback when the primary model fails?
- What caching strategies work best for ML predictions?