FastAPI Middleware - CORS, Custom Middleware

FastAPI › FastAPI Middleware › Free Lesson

Advertisement

Introduction

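Middleware in FastAPI lets you process every request and response globally: it runs before a request reaches its route handler and again after the handler returns. This tutorial covers the built-in CORS middleware and shows how to write custom middleware for request timing, logging, rate limiting, error handling, and security headers.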

CORS Configuration

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Fixed CORS policy: only these two front-end origins may call the API, and
# cross-origin requests are allowed to carry credentials (cookies,
# Authorization headers). Methods and request headers are not restricted.
_CORS_SETTINGS = {
    "allow_origins": ["https://example.com", "https://www.example.com"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Different origins for different environments.
# NOTE: middleware must be registered before the application starts serving.
# Starlette builds its middleware stack on first use and rejects
# add_middleware() afterwards, so this cannot run inside a startup/lifespan
# handler. This is an alternative to the fixed-origin configuration above;
# register only one CORSMiddleware per application.
def add_cors_middleware(target=None):
    """Register CORSMiddleware on *target* with a policy chosen by ``$ENV``.

    ``ENV=production`` gets a strict policy: one allowed origin, only GET and
    POST, and only the Authorization and Content-Type request headers. Any
    other value, including an unset variable, gets a permissive development
    policy.

    Args:
        target: Application to configure; defaults to the module-level ``app``.
    """
    import os

    target = app if target is None else target
    if os.getenv("ENV") == "production":
        target.add_middleware(
            CORSMiddleware,
            allow_origins=["https://app.example.com"],
            allow_credentials=True,
            allow_methods=["GET", "POST"],
            allow_headers=["Authorization", "Content-Type"],
        )
    else:
        # Development only: wildcard origins combined with credentials means
        # any site can make credentialed requests (Starlette reflects the
        # caller's Origin instead of sending "*"). Never ship this branch.
        target.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )


add_cors_middleware()

Custom Middleware

import logging
import math
import time
from collections import deque

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

class RequestTimingMiddleware(BaseHTTPMiddleware):
    """Report, in seconds, how long the downstream app took to respond.

    The elapsed time is measured around ``call_next`` and returned to the
    client in the ``X-Process-Time`` response header.
    """

    async def dispatch(self, request: Request, call_next):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can be adjusted mid-request and yield wrong
        # or even negative durations.
        start = time.perf_counter()
        response = await call_next(request)
        response.headers["X-Process-Time"] = str(time.perf_counter() - start)
        return response

logger = logging.getLogger(__name__)


class LoggingMiddleware(BaseHTTPMiddleware):
    """Log each request's method and URL, then the response status code.

    Both lines are emitted at INFO level on this module's logger. The
    module-level ``logging.info()`` would instead target the root logger and
    implicitly call ``basicConfig()`` when the root logger has no handlers,
    silently reconfiguring the application's logging.
    """

    async def dispatch(self, request: Request, call_next):
        # Lazy %-style arguments: the message is only formatted when INFO is
        # enabled for this logger.
        logger.info("Request: %s %s", request.method, request.url)
        response = await call_next(request)
        logger.info("Response: %s", response.status_code)
        return response

class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP sliding-window rate limiter kept in process memory.

    Each client IP may make at most ``max_requests`` accepted requests in any
    ``window``-second span. Further requests get ``429`` with a
    ``Retry-After`` header until enough earlier requests age out of the
    window. State lives only in this middleware instance, so limits are not
    shared between processes and reset when the process restarts.

    Args:
        app: The downstream ASGI application.
        max_requests: Accepted requests allowed per client within one window.
        window: Window length in seconds.
    """

    def __init__(self, app, max_requests: int = 100, window: int = 60):
        super().__init__(app)
        self.max_requests = max_requests
        self.window = window
        # client IP -> deque of monotonic timestamps of accepted requests,
        # oldest on the left.
        self.requests = {}
        # Monotonic time of the last purge of idle clients.
        self._last_sweep = time.monotonic()

    def _purge_idle_clients(self, now: float) -> None:
        """Forget clients with no accepted request inside the current window."""
        cutoff = now - self.window
        self.requests = {
            ip: stamps
            for ip, stamps in self.requests.items()
            if stamps and stamps[-1] > cutoff
        }
        self._last_sweep = now

    async def dispatch(self, request: Request, call_next):
        # request.client can be None when the peer address is unknown;
        # bucket such requests together instead of raising AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: immune to wall-clock adjustments.
        now = time.monotonic()

        # A full sweep at most once per window bounds memory for clients
        # that went away without paying O(clients) on every request.
        if now - self._last_sweep >= self.window:
            self._purge_idle_clients(now)

        # Slide this client's window: drop requests that are `window` or more
        # seconds old, so only recent requests count toward the limit.
        cutoff = now - self.window
        stamps = self.requests.setdefault(client_ip, deque())
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

        if len(stamps) >= self.max_requests:
            # Seconds until the oldest counted request leaves the window.
            oldest = stamps[0] if stamps else now
            retry_after = max(1, math.ceil(oldest + self.window - now))
            return Response(
                content="Rate limit exceeded",
                status_code=429,
                headers={"Retry-After": str(retry_after)},
            )

        stamps.append(now)
        return await call_next(request)

# Register the global middleware. Starlette wraps each newly added middleware
# around the ones added before it, so the last class listed here
# (LoggingMiddleware) becomes the outermost layer and sees each request first.
for _middleware_cls in (RequestTimingMiddleware, LoggingMiddleware):
    app.add_middleware(_middleware_cls)

Middleware with Error Handling

from fastapi import Request
from starlette.responses import JSONResponse

class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert any exception escaping the downstream app into a generic 500.

    The traceback is logged server-side. The client only receives
    ``{"error": "Internal server error"}``, so no internal details leak.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception:
            # Top-level boundary: log the traceback with request context on a
            # named logger. The module-level logging.exception() would go to
            # the root logger and may call basicConfig() implicitly.
            logging.getLogger(__name__).exception(
                "Unhandled exception while handling %s %s",
                request.method,
                request.url.path,
            )
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error"},
            )

class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security-related headers to every response.

    The values in ``SECURITY_HEADERS`` overwrite any value the endpoint set
    for the same header. Subclasses can override the mapping to change the
    policy.
    """

    # Header name -> value applied to each response.
    SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # Disable the legacy XSS auditor explicitly. "1; mode=block" is
        # deprecated and could itself be abused in the old browsers that
        # honour it. Use a Content-Security-Policy for XSS protection instead.
        "X-XSS-Protection": "0",
        # Only honoured by browsers when the response arrives over HTTPS.
        "Strict-Transport-Security": "max-age=31536000",
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        for name, value in self.SECURITY_HEADERS.items():
            response.headers[name] = value
        return response

Practice Problems

  1. Create middleware that validates authentication tokens
  2. Implement middleware that compresses large response payloads
  3. Add middleware that tracks request and response sizes
  4. Build middleware that implements request-ID tracing
  5. Create middleware that handles session management

Advertisement

Need Expert Python Help?

Get personalized tutoring, project support, or professional consulting.

Advertisement