Model Deployment with FastAPI

Module 15: Data Engineering & MLOps

Deploy ML models as production REST APIs using FastAPI, Docker, and best practices for monitoring and scaling.

Deployment Architecture

Figure: Model Deployment Architecture. Components shown: Client; Load Balancer; FastAPI + Model Server (Uvicorn workers); Model Store; Monitoring (Prometheus); Logging (ELK/Sentry); Tracing (OpenTelemetry); Health Checks; Docker Container (Ubuntu + Python + Model + FastAPI). Caption: Stateless API containers behind a load balancer.
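Because each container is stateless, any replica can serve any request, so the load balancer is free to rotate between replicas and skip any that fail their health check. A minimal sketch of that round-robin policy, assuming hypothetical names `RoundRobinBalancer` and `is_healthy` (in practice the callback would poll each replica's `/health`); real deployments leave this job to a dedicated load balancer or the orchestrator.

```python
import itertools
from typing import Callable, Sequence


class RoundRobinBalancer:
    """Rotates through backends, skipping any that fail a health check (hypothetical sketch)."""

    def __init__(self, backends: Sequence[str], is_healthy: Callable[[str], bool]):
        if not backends:
            raise ValueError("at least one backend is required")
        self._size = len(backends)
        self._cycle = itertools.cycle(list(backends))
        self._is_healthy = is_healthy

    def next_backend(self) -> str:
        # Inspect each backend at most once per call, so an all-unhealthy pool fails fast
        for _ in range(self._size):
            candidate = next(self._cycle)
            if self._is_healthy(candidate):
                return candidate
        raise RuntimeError("no healthy backends available")


# Example: three replicas, one of which is failing its health check
balancer = RoundRobinBalancer(
    ["api-1:8000", "api-2:8000", "api-3:8000"],
    is_healthy=lambda backend: backend != "api-2:8000",
)
picks = [balancer.next_backend() for _ in range(4)]
print(picks)  # ['api-1:8000', 'api-3:8000', 'api-1:8000', 'api-3:8000']
```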

1. FastAPI Model Server

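import os
import time
from contextlib import asynccontextmanager

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Model path comes from the environment (set in docker-compose.yml), defaulting to the local file
MODEL_PATH = os.environ.get("MODEL_PATH", "model.pkl")
model = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once per worker process at startup, not on every request
    global model
    model = joblib.load(MODEL_PATH)
    yield
    # Cleanup on shutdown
    model = None

app = FastAPI(title="ML Prediction API", version="1.0.0", lifespan=lifespan)

class PredictionRequest(BaseModel):
    # Pydantic v2 uses min_length/max_length for lists (min_items/max_items are deprecated)
    features: list[float] = Field(..., min_length=1, max_length=100)
    request_id: str | None = None

class PredictionResponse(BaseModel):
    prediction: int
    probability: float
    latency_ms: float

# Note: model.predict is CPU-bound and blocks the event loop inside an `async def`
# endpoint; declaring the endpoint with plain `def` makes FastAPI run it in a threadpool.
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    start = time.perf_counter()  # monotonic clock, suited to measuring durations
    try:
        X = np.array(request.features).reshape(1, -1)
        pred = int(model.predict(X)[0])
        prob = float(model.predict_proba(X).max())
    except ValueError as e:
        # scikit-learn raises ValueError for bad input, e.g. the wrong number of features
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        # Return a generic message; detailed errors belong in server-side logs, not responses
        raise HTTPException(status_code=500, detail="Prediction failed") from e
    latency = (time.perf_counter() - start) * 1000
    return PredictionResponse(
        prediction=pred,
        probability=prob,
        latency_ms=round(latency, 2)
    )

@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}

@app.get("/metadata")
async def metadata():
    return {
        "model_type": type(model).__name__,
        "features_expected": getattr(model, "n_features_in_", None)
    }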
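The `latency_ms` field reports one request at a time; for monitoring you usually want percentiles over recent traffic, since a healthy median can hide a slow tail. A minimal standard-library sketch, assuming a hypothetical `LatencyTracker` that the endpoint would feed with each measured latency (in production, a Prometheus histogram is the usual choice).

```python
from collections import deque
from statistics import quantiles


class LatencyTracker:
    """Keeps a sliding window of recent latencies and reports percentiles (hypothetical helper)."""

    def __init__(self, window: int = 1000):
        # deque(maxlen=...) silently drops the oldest sample once the window is full
        self._samples: deque = deque(maxlen=window)

    def record(self, latency_ms: float) -> None:
        self._samples.append(latency_ms)

    def percentile(self, p: int) -> float:
        if not 1 <= p <= 99:
            raise ValueError("p must be between 1 and 99")
        if not self._samples:
            raise ValueError("no samples recorded yet")
        if len(self._samples) == 1:
            return self._samples[0]
        # quantiles(n=100) returns the 99 cut points for percentiles 1..99
        cuts = quantiles(self._samples, n=100, method="inclusive")
        return cuts[p - 1]


tracker = LatencyTracker(window=100)
for ms in range(1, 101):
    tracker.record(float(ms))
print(tracker.percentile(50), round(tracker.percentile(95), 2))  # 50.5 95.05
```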

2. Request Validation

import math
from typing import Literal

from pydantic import BaseModel, field_validator

class AdvancedRequest(BaseModel):
    features: list[float]
    model_version: str = "latest"
    output_type: Literal["class", "probabilities", "both"] = "class"

    @field_validator("features")
    @classmethod
    def validate_features(cls, v: list[float]) -> list[float]:
        # A ValueError raised here surfaces as a 422 response in FastAPI
        if len(v) < 1:
            raise ValueError("At least one feature required")
        # Pydantic accepts NaN and inf for float fields by default, so reject them explicitly
        if not all(math.isfinite(x) for x in v):
            raise ValueError("NaN and infinite values not allowed")
        return v

3. Docker Containerization

# Dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY model.pkl .
COPY app.py .

EXPOSE 8000

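# Each Uvicorn worker is a separate process that loads its own copy of the model
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]

# requirements.txt (pin scikit-learn to the version that trained model.pkl)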
fastapi==0.109.0
uvicorn[standard]==0.27.0
joblib==1.3.2
numpy==1.26.3
scikit-learn==1.4.0
pydantic==2.5.3
prometheus-fastapi-instrumentator==6.1.0
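
# docker-compose.yml
services:
  api:
    build: .
    # A fixed host port can be bound by only one container, so give each replica its own
    ports:
      - "8000-8002:8000"
    environment:
      - MODEL_PATH=/app/model.pkl
    volumes:
      - ./models:/app/models
    deploy:
      replicas: 3
    healthcheck:
      # python:3.11-slim does not ship curl, so probe with the Python standard library
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health', timeout=5)"]
      interval: 30s
      timeout: 10s
      retries: 3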

  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml

  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
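
The `/health` endpoint in section 1 answers 200 even when `model_loaded` is false, and `python:3.11-slim` does not include `curl`. A stricter probe needs only the standard library. This sketch assumes a hypothetical `healthcheck.py` copied into the image next to `app.py`; `check_health` and `main` are illustrative names.

```python
import json
import urllib.error
import urllib.request


def check_health(url: str = "http://localhost:8000/health", timeout: float = 5.0) -> bool:
    """True only if the endpoint returns 200 with a JSON body reporting model_loaded=true."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            if resp.status != 200:
                return False
            body = json.loads(resp.read().decode("utf-8"))
    except (urllib.error.URLError, OSError, ValueError):
        # Connection refused, timeout, HTTP error status, or a body that isn't JSON
        return False
    return isinstance(body, dict) and body.get("model_loaded") is True


def main() -> int:
    """Exit status for a container health check: 0 = healthy, 1 = unhealthy."""
    return 0 if check_health() else 1
```

Assuming the check runs from `/app`, the compose `test` entry could call it as `["CMD", "python", "-c", "import sys, healthcheck; sys.exit(healthcheck.main())"]`.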

4. Model Serialization Formats

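import pickle

import joblib
import onnx
import torch
from skl2onnx import convert_sklearn

# `model` is a fitted scikit-learn estimator for the first three formats.
# Security: joblib and pickle both run arbitrary code when loading, so only load trusted files.

# Joblib (recommended by scikit-learn; efficient for objects holding large NumPy arrays)
joblib.dump(model, "model.joblib")

# Pickle (standard library, works for most Python objects)
with open("model.pkl", "wb") as f:
    pickle.dump(model, f)

# ONNX (cross-framework; loading does not unpickle Python objects)
# initial_types describes the model input, e.g. a FloatTensorType of shape [None, n_features]
onnx_model = convert_sklearn(model, initial_types=[...])
onnx.save(onnx_model, "model.onnx")

# TorchScript (PyTorch); here `model` must be a torch.nn.Module, not a scikit-learn estimator
scripted = torch.jit.script(model)
scripted.save("model.pt")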
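Since joblib and pickle files can run arbitrary code when loaded, a common safeguard is to record a SHA-256 digest when a model is published and refuse to load any file that doesn't match. This sketch assumes the expected digest comes from a trusted source (for example, your model registry); `sha256_of` and `load_verified_pickle` are hypothetical helpers. A matching checksum proves the file is the one you published, not that its contents are safe.

```python
import hashlib
import hmac
import pickle
from pathlib import Path


def sha256_of(path: Path) -> str:
    """Hex SHA-256 of a file, read in 1 MiB chunks so large models aren't held in memory at once."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def load_verified_pickle(path: Path, expected_sha256: str):
    """Unpickle `path` only if its digest matches the one recorded at publish time."""
    actual = sha256_of(path)
    # compare_digest compares the two hex strings in constant time
    if not hmac.compare_digest(actual, expected_sha256.lower()):
        raise ValueError(f"checksum mismatch for {path}: got {actual}")
    with path.open("rb") as f:
        return pickle.load(f)
```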

5. Async and Batch Inference

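# Continues app.py from section 1 (reuses app, model, np, time and the request/response models)
from uuid import uuid4
from fastapi import BackgroundTasks, HTTPException

# In-memory store for illustration only: per-process and lost on restart, so with several
# workers or replicas a status poll can miss the result. Use Redis or a database instead.
results_store: dict[str, dict] = {}

@app.post("/predict/batch", response_model=list[PredictionResponse])
async def predict_batch(requests: list[PredictionRequest]):
    if not requests:
        return []
    if len({len(r.features) for r in requests}) != 1:
        raise HTTPException(status_code=422, detail="All items need the same number of features")
    start = time.perf_counter()
    X = np.array([r.features for r in requests])  # one vectorized call for the whole batch
    preds, probs = model.predict(X), model.predict_proba(X).max(axis=1)
    latency = round((time.perf_counter() - start) * 1000, 2)  # whole-batch latency
    return [PredictionResponse(prediction=int(p), probability=float(q), latency_ms=latency)
            for p, q in zip(preds, probs)]

def process_prediction(task_id: str, request: PredictionRequest) -> None:
    # Sync background functions run in a threadpool after the response is sent
    X = np.array(request.features).reshape(1, -1)
    results_store[task_id] = {"prediction": int(model.predict(X)[0]),
                              "probability": float(model.predict_proba(X).max())}

@app.post("/predict/async")
async def predict_async(request: PredictionRequest, background_tasks: BackgroundTasks):
    task_id = str(uuid4())
    background_tasks.add_task(process_prediction, task_id, request)
    return {"task_id": task_id, "status": "processing"}

@app.get("/predict/status/{task_id}")
async def get_status(task_id: str):
    result = results_store.get(task_id)
    if result is None:
        return {"status": "pending"}  # also returned for unknown task IDs
    return {"status": "complete", "result": result}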
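The `/predict/batch` endpoint only batches when a client sends a list. A common complement is dynamic micro-batching: hold individual requests for a few milliseconds and run them through the model together. A minimal asyncio sketch, assuming a hypothetical `batch_fn` that maps a list of inputs to a list of outputs in the same order (for example, a wrapper around `model.predict` on a stacked array); `MicroBatcher` is an illustrative name, not a FastAPI API.

```python
import asyncio
from typing import Any, Callable, List, Optional, Sequence


class MicroBatcher:
    """Collects concurrent single-item calls and runs them through one batch function (sketch)."""

    def __init__(self, batch_fn: Callable[[List[Any]], Sequence[Any]],
                 max_batch: int = 32, max_wait_s: float = 0.005):
        self._batch_fn = batch_fn
        self._max_batch = max_batch
        self._max_wait_s = max_wait_s
        self._queue: Optional[asyncio.Queue] = None
        self._worker: Optional[asyncio.Task] = None

    async def submit(self, item: Any) -> Any:
        # Create the queue and background worker lazily, inside the running event loop
        if self._worker is None:
            self._queue = asyncio.Queue()
            self._worker = asyncio.create_task(self._run())
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((item, future))
        return await future

    async def _run(self) -> None:
        while True:
            batch = [await self._queue.get()]
            # Give concurrent callers a short window to join, then drain up to max_batch
            await asyncio.sleep(self._max_wait_s)
            while len(batch) < self._max_batch and not self._queue.empty():
                batch.append(self._queue.get_nowait())
            items = [item for item, _ in batch]
            try:
                # Run the (typically CPU-bound) model call off the event loop
                outputs = await asyncio.to_thread(self._batch_fn, items)
                if len(outputs) != len(batch):
                    raise ValueError("batch_fn must return one output per input")
                for (_, fut), out in zip(batch, outputs):
                    if not fut.done():
                        fut.set_result(out)
            except Exception as exc:
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)

    async def close(self) -> None:
        if self._worker is not None:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass
            self._worker = None


def sum_rows(rows):
    # Stand-in for a vectorized model call such as model.predict(np.array(rows))
    return [sum(row) for row in rows]


async def demo() -> list:
    batcher = MicroBatcher(sum_rows, max_batch=8, max_wait_s=0.01)
    results = await asyncio.gather(*(batcher.submit([i, i]) for i in range(5)))
    await batcher.close()
    return results


print(asyncio.run(demo()))  # [0, 2, 4, 6, 8]
```

The trade-off is a small fixed delay (`max_wait_s`) on every request in exchange for fewer, larger model calls.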

6. Production Checklist

  • Input validation (Pydantic models)
  • Error handling (try/except, proper HTTP codes)
  • Health check endpoint (/health)
  • Logging (structured JSON logs)
  • Monitoring (latency, throughput, error rates)
  • Rate limiting (prevent abuse)
  • Authentication (API keys, JWT)
  • CORS configuration
  • Model versioning in responses
  • Graceful shutdown handling
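
For the rate-limiting item, a token bucket is a common choice: each client may burst up to `capacity` requests and then gets `rate` requests per second. A minimal in-memory sketch; `TokenBucket` and `PerClientLimiter` are hypothetical names, and the injectable `clock` exists only to make the behaviour testable.

```python
import time
from typing import Callable, Dict


class TokenBucket:
    """Permits bursts of up to `capacity` requests, refilling at `rate` tokens per second."""

    def __init__(self, rate: float, capacity: float, clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self._clock = clock
        self._tokens = capacity
        self._last = clock()

    def allow(self, cost: float = 1.0) -> bool:
        now = self._clock()
        # Add tokens for the time elapsed since the last call, never exceeding capacity
        self._tokens = min(self.capacity, self._tokens + (now - self._last) * self.rate)
        self._last = now
        if self._tokens >= cost:
            self._tokens -= cost
            return True
        return False


class PerClientLimiter:
    """One bucket per client key such as an API key or IP (hypothetical; buckets never evicted)."""

    def __init__(self, rate: float, capacity: float, clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self._clock = clock
        self._buckets: Dict[str, TokenBucket] = {}

    def allow(self, client_id: str) -> bool:
        bucket = self._buckets.get(client_id)
        if bucket is None:
            bucket = TokenBucket(self.rate, self.capacity, self._clock)
            self._buckets[client_id] = bucket
        return bucket.allow()
```

A rejected request would typically get HTTP 429 (Too Many Requests). Because this state lives in one process, each Uvicorn worker and each replica keeps its own buckets, so the effective limit grows with the number of processes; a shared store such as Redis avoids that.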

Key Takeaways

  • FastAPI provides async support, automatic interactive API docs (served at /docs), and request validation driven by type hints
  • Docker ensures reproducible deployments across environments
  • Health checks enable orchestrators to manage the container lifecycle
  • Monitoring is essential: track latency, errors, and data drift
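
Data drift is the one monitoring signal here that the lesson doesn't demonstrate. One widely used per-feature measure is the population stability index (PSI): bin the training baseline, then sum (actual share - expected share) × ln(actual share / expected share) over the bins. A minimal standard-library sketch; the function name, quantile binning, and `eps` smoothing are assumptions, and a commonly quoted rule of thumb treats PSI above about 0.25 as a significant shift (tune thresholds per feature).

```python
import math
from bisect import bisect_right
from typing import List, Sequence


def population_stability_index(expected: Sequence[float], actual: Sequence[float],
                               bins: int = 10, eps: float = 1e-6) -> float:
    """PSI for one feature: live values (`actual`) versus a training baseline (`expected`)."""
    if bins < 2 or len(expected) < bins or len(actual) == 0:
        raise ValueError("need bins >= 2, at least `bins` baseline values, and some live values")
    ordered = sorted(expected)
    # Interior bin edges at baseline quantiles, so each bin holds about 1/bins of the baseline
    edges = [ordered[len(ordered) * k // bins] for k in range(1, bins)]

    def shares(values: Sequence[float]) -> List[float]:
        counts = [0] * bins
        for v in values:
            counts[bisect_right(edges, v)] += 1
        # eps keeps empty bins from causing log(0) or division by zero
        return [max(c / len(values), eps) for c in counts]

    e, a = shares(expected), shares(actual)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))


baseline = [float(x) for x in range(1000)]
print(population_stability_index(baseline, baseline))  # 0.0
shifted = [x + 500.0 for x in baseline]
print(population_stability_index(baseline, shifted) > 0.25)  # True
```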
