Database & ORM Patterns
Difficulty: Medium-Hard | Companies: Google, Meta, Amazon, Netflix, Stripe
SQLAlchemy ORM Patterns
from sqlalchemy import (
create_engine, Column, Integer, String, DateTime, ForeignKey,
Boolean, Text, Float, Index, event
)
from sqlalchemy.orm import (
declarative_base, relationship, Session, sessionmaker,
validates, scoped_session
)
from sqlalchemy.pool import QueuePool
from datetime import datetime
from typing import Optional, List
from contextlib import contextmanager
import re
# Database Setup
# Declarative base class that every model in this module inherits from.
Base = declarative_base()

# Module-level engine; connections are served from a QueuePool.
engine = create_engine(
    "postgresql://user:pass@localhost/db",
    poolclass=QueuePool,
    pool_recycle=1800,  # recycle connections older than 1800 seconds
    pool_timeout=30,    # seconds to wait for a free connection
    max_overflow=10,    # extra connections allowed beyond pool_size
    pool_size=5,        # connections kept open in the pool
)

# Plain session factory bound to the engine above.
SessionLocal = sessionmaker(bind=engine)
# Scoped registry built on the same factory.
db_session = scoped_session(SessionLocal)
# Model Definitions
def _utcnow():
    """Return the current UTC time as a timezone-naive ``datetime``.

    Replacement for ``datetime.utcnow``, which is deprecated since
    Python 3.12. The value is computed timezone-aware and then stripped of
    its ``tzinfo`` so it is the same kind of value ``utcnow`` produced and
    still fits the plain ``DateTime`` columns below.
    """
    # Local import: the file's top-level import only brings in ``datetime``.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class TimestampMixin:
    """Mixin for timestamp fields."""

    # Filled in by the column default when a row is inserted.
    created_at = Column(DateTime, default=_utcnow, nullable=False)
    # Filled in on insert and refreshed by ``onupdate`` on later updates.
    updated_at = Column(DateTime, default=_utcnow, onupdate=_utcnow)
class User(Base, TimestampMixin):
    """User model with relationships.

    A user owns many posts and at most one profile. Both relationships use
    ``cascade="all, delete-orphan"`` so deleting a user also deletes the
    dependent rows instead of trying to null out their NOT NULL foreign keys.
    """

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    email = Column(String(255), unique=True, index=True, nullable=False)
    username = Column(String(50), unique=True, index=True, nullable=False)
    hashed_password = Column(String(255), nullable=False)
    is_active = Column(Boolean, default=True)
    is_superuser = Column(Boolean, default=False)

    # Relationships
    posts = relationship("Post", back_populates="author", cascade="all, delete-orphan")
    # Profile.user_id is declared NOT NULL, so without a delete cascade the ORM
    # would try to set it to NULL when the user is deleted and the flush fails.
    profile = relationship(
        "Profile",
        back_populates="user",
        uselist=False,
        cascade="all, delete-orphan",
    )

    # Table arguments for additional indexes
    __table_args__ = (
        Index('idx_user_email_active', 'email', 'is_active'),
    )

    @validates('email')
    def validate_email(self, key, email):
        """Validate and normalise an email assigned to this user.

        Args:
            key: Name of the attribute being set (always ``'email'`` here).
            email: The value being assigned.

        Returns:
            The email converted to lower case.

        Raises:
            ValueError: If the value is not a string or does not look like
                ``local@domain.tld``.
        """
        # Guard the type first so None / non-strings raise ValueError rather
        # than a TypeError from ``re.match``.
        if not isinstance(email, str) or not re.match(r'^[^@]+@[^@]+\.[^@]+$', email):
            raise ValueError("Invalid email format")
        return email.lower()

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}')>"
class Post(Base, TimestampMixin):
    """A post written by one user and labelled with any number of tags."""

    __tablename__ = "posts"
    __table_args__ = (
        # Composite index over (title, content).
        # NOTE(review): the original comment called this a full-text search
        # index, but as declared it is an ordinary multi-column Index with no
        # full-text options - confirm the intent.
        Index('idx_post_title_content', 'title', 'content'),
    )

    id = Column(Integer, index=True, primary_key=True)
    title = Column(String(200), nullable=False)
    content = Column(Text, nullable=False)
    published = Column(Boolean, default=False)
    author_id = Column(
        Integer,
        ForeignKey("users.id"),
        nullable=False,
    )

    # Many-to-many with Tag through the "post_tags" association table.
    tags = relationship("Tag", back_populates="posts", secondary="post_tags")
    # Many-to-one back to the owning user; paired with User.posts.
    author = relationship("User", back_populates="posts")
class Profile(Base):
    """Extra details attached to a user; at most one row per user."""

    __tablename__ = "profiles"

    id = Column(Integer, primary_key=True)
    # unique=True on the foreign key limits each user to a single profile.
    user_id = Column(
        Integer,
        ForeignKey("users.id"),
        nullable=False,
        unique=True,
    )
    bio = Column(Text)
    avatar_url = Column(String(500))

    # Paired with User.profile through back_populates.
    user = relationship("User", back_populates="profile")
class Tag(Base):
    """A uniquely named label that can be attached to posts."""

    __tablename__ = "tags"

    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False, unique=True)

    # Many-to-many with Post through the "post_tags" association table.
    posts = relationship("Post", back_populates="tags", secondary="post_tags")
# Association table
# MetaData is not used here; it is still imported so the names this block
# binds at module level are a superset of what the original bound.
from sqlalchemy import Table, MetaData

# Reuse the table when it is already registered on Base.metadata (declaring the
# same table name twice on one MetaData is an error); otherwise declare it.
# The original lookup referenced ``Base.tables``, which is not an attribute of
# the declarative base, and depended on ``or`` / ``if-else`` precedence.
post_tags = Base.metadata.tables.get('post_tags')
if post_tags is None:
    post_tags = Table(
        'post_tags', Base.metadata,
        # Composite primary key: one row per (post, tag) pair.
        Column('post_id', Integer, ForeignKey('posts.id'), primary_key=True),
        Column('tag_id', Integer, ForeignKey('tags.id'), primary_key=True),
    )
ℹ️
Use relationship cascade options to automatically handle related objects. The
back_populates parameter keeps both sides of the relationship in sync.
Repository Pattern
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any, TypeVar, Generic
from sqlalchemy.orm import Session
from sqlalchemy import desc, asc
# Entity type handled by a concrete repository.
T = TypeVar('T')


class Repository(ABC, Generic[T]):
    """Storage-agnostic CRUD contract that concrete repositories implement."""

    @abstractmethod
    def get_by_id(self, id: int) -> Optional[T]:
        """Look up a single entity by its identifier."""

    @abstractmethod
    def get_all(self, skip: int = 0, limit: int = 100) -> List[T]:
        """Return a window of entities described by *skip* and *limit*."""

    @abstractmethod
    def create(self, obj: Dict[str, Any]) -> T:
        """Create an entity from a mapping of field values."""

    @abstractmethod
    def update(self, id: int, obj: Dict[str, Any]) -> Optional[T]:
        """Apply a mapping of field values to the entity identified by *id*."""

    @abstractmethod
    def delete(self, id: int) -> bool:
        """Remove the entity identified by *id*."""

    @abstractmethod
    def count(self) -> int:
        """Return the number of entities."""
class SQLAlchemyRepository(Repository[T]):
    """SQLAlchemy repository implementation.

    Each mutating method (``create``, ``update``, ``delete``) commits the
    session immediately. If that commit raises, the session is rolled back
    before the exception propagates so the same session can keep being used.
    """

    def __init__(self, session: Session, model_class: type):
        """Store the session to work with and the mapped class to query."""
        self.session = session
        self.model_class = model_class

    def _commit(self) -> None:
        """Commit the session, rolling it back if the commit fails."""
        try:
            self.session.commit()
        except Exception:
            # Leave the session usable for the caller, then surface the error.
            self.session.rollback()
            raise

    def get_by_id(self, id: int) -> Optional[T]:
        """Return the row whose ``id`` column equals *id*, or ``None``."""
        return self.session.query(self.model_class).filter(
            self.model_class.id == id
        ).first()

    def get_all(self, skip: int = 0, limit: int = 100) -> List[T]:
        """Return up to *limit* rows after skipping the first *skip*."""
        return self.session.query(self.model_class).offset(skip).limit(limit).all()

    def create(self, obj_data: Dict[str, Any]) -> T:
        """Build a model instance from *obj_data*, persist and return it."""
        obj = self.model_class(**obj_data)
        self.session.add(obj)
        self._commit()
        # Reload so attributes populated during the insert are present.
        self.session.refresh(obj)
        return obj

    def update(self, id: int, obj_data: Dict[str, Any]) -> Optional[T]:
        """Set each key of *obj_data* on the row with *id* and persist it.

        Returns:
            The refreshed instance, or ``None`` when no such row exists.
        """
        obj = self.get_by_id(id)
        if obj is None:
            return None
        for key, value in obj_data.items():
            setattr(obj, key, value)
        self._commit()
        self.session.refresh(obj)
        return obj

    def delete(self, id: int) -> bool:
        """Delete the row with *id*; return ``False`` when it does not exist."""
        obj = self.get_by_id(id)
        if obj is None:
            return False
        self.session.delete(obj)
        self._commit()
        return True

    def count(self) -> int:
        """Return the number of rows of the mapped class."""
        return self.session.query(self.model_class).count()
class UserRepository(SQLAlchemyRepository[User]):
    """User-specific repository methods."""

    # Escape character used in LIKE patterns built from caller input.
    _LIKE_ESCAPE = "/"

    @classmethod
    def _contains_pattern(cls, term: str) -> str:
        """Return a ``%term%`` LIKE pattern with wildcards in *term* escaped."""
        esc = cls._LIKE_ESCAPE
        # Escape the escape character first so later replacements stay intact.
        escaped = (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email equals *email* lower-cased, or ``None``."""
        return self.session.query(User).filter(
            User.email == email.lower()
        ).first()

    def get_active_users(self) -> List[User]:
        """Return every user whose ``is_active`` flag is true."""
        return self.session.query(User).filter(
            User.is_active == True  # noqa: E712 - builds a SQL expression
        ).all()

    def search_by_username(self, query: str) -> List[User]:
        """Return users whose username contains *query*, ignoring case.

        ``%`` and ``_`` in *query* are matched literally instead of acting as
        LIKE wildcards.
        """
        return self.session.query(User).filter(
            User.username.ilike(self._contains_pattern(query), escape=self._LIKE_ESCAPE)
        ).all()
class PostRepository(SQLAlchemyRepository[Post]):
    """Post-specific repository methods."""

    # Escape character used in LIKE patterns built from caller input.
    _LIKE_ESCAPE = "/"

    @classmethod
    def _contains_pattern(cls, term: str) -> str:
        """Return a ``%term%`` LIKE pattern with wildcards in *term* escaped."""
        esc = cls._LIKE_ESCAPE
        # Escape the escape character first so later replacements stay intact.
        escaped = (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    def get_by_author(self, author_id: int) -> List[Post]:
        """Return all posts whose ``author_id`` equals *author_id*."""
        return self.session.query(Post).filter(
            Post.author_id == author_id
        ).all()

    def get_published_posts(self) -> List[Post]:
        """Return published posts ordered by ``created_at`` descending."""
        return self.session.query(Post).filter(
            Post.published == True  # noqa: E712 - builds a SQL expression
        ).order_by(desc(Post.created_at)).all()

    def search_posts(self, query: str) -> List[Post]:
        """Return posts whose title or content contains *query*, ignoring case.

        ``%`` and ``_`` in *query* are matched literally instead of acting as
        LIKE wildcards.
        """
        pattern = self._contains_pattern(query)
        return self.session.query(Post).filter(
            Post.title.ilike(pattern, escape=self._LIKE_ESCAPE)
            | Post.content.ilike(pattern, escape=self._LIKE_ESCAPE)
        ).all()
# Unit of Work Pattern
class UnitOfWork:
    """Unit of Work pattern for transactions.

    Used as a context manager: entering opens a session and exposes the
    ``users`` and ``posts`` repositories bound to it; leaving commits when the
    body finished normally, rolls back when it raised, and always closes the
    session.

    NOTE(review): the repositories defined in this file commit inside their
    own create/update/delete methods, so work done through them is already
    committed before this class's commit or rollback runs - confirm whether
    that is intended.
    """

    def __init__(self, session_factory):
        """Remember the zero-argument callable that creates sessions."""
        self.session_factory = session_factory
        self.session = None

    def __enter__(self):
        self.session = self.session_factory()
        self.users = UserRepository(self.session, User)
        # Fixed: the original referenced an undefined ``PostgreSQLRepository``.
        self.posts = PostRepository(self.session, Post)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the transaction; never suppresses the body's exception."""
        try:
            if exc_type is None:
                try:
                    self.commit()
                except Exception:
                    # A failed commit must not leave the transaction open.
                    self.rollback()
                    raise
            else:
                self.rollback()
        finally:
            # Close the session even when commit or rollback raised.
            self.session.close()

    def commit(self):
        """Commit the current session."""
        self.session.commit()

    def rollback(self):
        """Roll back the current session."""
        self.session.rollback()
# Usage
def create_user_with_post(user_data: dict, post_data: dict):
    """Create a user and a post authored by that user.

    Args:
        user_data: Field values for the new user.
        post_data: Field values for the new post. The dict is not modified;
            ``author_id`` is added to a copy.

    Returns:
        A ``(user, post)`` tuple of the created objects.

    NOTE(review): the objects are returned after the unit of work has closed
    its session - confirm callers do not depend on lazily loaded attributes.
    """
    with UnitOfWork(SessionLocal) as uow:
        user = uow.users.create(user_data)
        # Build a new mapping instead of writing author_id into the caller's dict.
        post = uow.posts.create({**post_data, 'author_id': user.id})
        return user, post
Query Optimization
from sqlalchemy import func, and_, or_, select
from sqlalchemy.orm import joinedload, selectinload, contains_eager
class QueryOptimizer:
    """Query optimization techniques."""

    def __init__(self, session: Session):
        """Store the session that every example query runs on."""
        self.session = session

    def eager_loading_example(self):
        """Demonstrate different loading strategies.

        Both strategies are executed; the users loaded with ``selectinload``
        are the ones returned.
        """
        # N+1 problem (what to avoid):
        #   users = session.query(User).all()
        #   for user in users:
        #       print(user.posts)  # additional query for each user!

        # Solution 1: joinedload (single query with JOIN)
        users = self.session.query(User).options(
            joinedload(User.posts)
        ).all()

        # Solution 2: selectinload (separate query, IN clause)
        users = self.session.query(User).options(
            selectinload(User.posts)
        ).all()
        return users

    def complex_query_example(self):
        """Return up to ten active users with the most posts.

        Returns:
            A list of ``(User, post_count)`` rows ordered by ``post_count``
            descending. Users without posts get a count of 0.
        """
        # Subquery for post counts
        post_counts = self.session.query(
            Post.author_id,
            func.count(Post.id).label('post_count')
        ).group_by(Post.author_id).subquery()

        # The outer join yields NULL for users without posts, and PostgreSQL
        # sorts NULLs first in descending order; coalescing to 0 keeps those
        # users at the bottom of the ranking.
        post_count = func.coalesce(post_counts.c.post_count, 0).label('post_count')

        # Main query with join
        result = self.session.query(
            User,
            post_count
        ).outerjoin(
            post_counts, User.id == post_counts.c.author_id
        ).filter(
            User.is_active == True  # noqa: E712 - builds a SQL expression
        ).order_by(
            desc(post_count)
        ).limit(10).all()
        return result

    def bulk_operations(self):
        """Demonstrate bulk operations for better performance.

        Runs a bulk insert, a bulk update and a bulk delete, committing after
        each step.
        """
        # Bulk insert
        users_to_insert = [
            {
                'email': f'user{i}@example.com',
                'username': f'user{i}',
                # users.hashed_password is NOT NULL, so every mapping must
                # supply it. "!" is a placeholder value.
                # TODO(review): replace with a real password hash.
                'hashed_password': '!',
            }
            for i in range(1000)
        ]
        self.session.bulk_insert_mappings(User, users_to_insert)
        self.session.commit()

        # Bulk update
        self.session.query(User).filter(
            User.is_active == False  # noqa: E712 - builds a SQL expression
        ).update({'is_active': True})
        self.session.commit()

        # Bulk delete
        self.session.query(Post).filter(
            Post.created_at < datetime(2020, 1, 1)
        ).delete()
        self.session.commit()

    def pagination_query(self, page: int, per_page: int):
        """Return one page of users ordered by id.

        Args:
            page: 1-based page number; values below 2 return the first page.
            per_page: Page size; values below 1 return an empty list.

        Returns:
            The users on the requested page, or an empty list when the page
            lies past the last user.

        The page boundary is still located with OFFSET, so a deep page costs
        about as much as plain offset pagination.
        """
        if per_page < 1:
            return []

        query = self.session.query(User).order_by(User.id)
        if page > 1:
            # The last row of the previous page sits at 0-based position
            # (page - 1) * per_page - 1. The original used
            # (page - 2) * per_page, which is the previous page's first row.
            last_user = self.session.query(User).order_by(
                User.id
            ).offset((page - 1) * per_page - 1).limit(1).first()
            if last_user is None:
                # Fewer rows than the previous pages need: nothing to return.
                return []
            query = query.filter(User.id > last_user.id)
        return query.limit(per_page).all()
⚠️
Always use joinedload or selectinload to avoid N+1 queries. Profile your queries using SQLAlchemy's echo=True parameter.
Connection Pooling
from sqlalchemy.pool import QueuePool, NullPool
from contextlib import contextmanager
from typing import Generator
import logging
logger = logging.getLogger(__name__)
class DatabaseManager:
    """Database connection manager with pooling.

    Owns an engine backed by a ``QueuePool``, a session factory bound to it,
    and listeners for the pool's ``connect`` and ``checkout`` events.
    """

    def __init__(self, database_url: str):
        """Create the engine for *database_url* and register pool listeners."""
        self.engine = create_engine(
            database_url,
            poolclass=QueuePool,
            pool_size=20,
            max_overflow=30,
            pool_timeout=30,
            pool_recycle=1800,
            pool_pre_ping=True,
            echo=False
        )
        self.SessionLocal = sessionmaker(bind=self.engine)

        # Connection pool events
        event.listen(self.engine, 'connect', self._on_connect)
        event.listen(self.engine, 'checkout', self._on_checkout)

    def _on_connect(self, dbapi_conn, connection_record):
        """Event handler for new connections."""
        logger.info("New database connection established")

    def _on_checkout(self, dbapi_conn, connection_record, connection_proxy):
        """Event handler when connection is checked out from pool."""
        # Stored on the pooled connection's record.
        connection_record.info['checkout_time'] = datetime.now()

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Get database session with automatic cleanup.

        Commits when the ``with`` body finishes normally. On any exception
        the session is rolled back, the error is logged and re-raised. The
        session is closed in every case.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            # Lazy %-formatting; the emitted message text is unchanged.
            logger.error("Database error: %s", e)
            raise
        finally:
            session.close()

    def execute_raw_query(self, query: str, params: Optional[dict] = None):
        """Execute raw SQL query.

        Args:
            query: SQL text; named bind parameters use the ``:name`` form.
                An already-built SQLAlchemy statement is passed through as is.
            params: Values for the bind parameters, if any.

        Returns:
            All rows produced by the statement.
        """
        # Local import: ``text`` is not among the file's top-level imports.
        from sqlalchemy import text

        # Plain strings must be wrapped in text() to be executable on
        # SQLAlchemy 2.x connections.
        statement = text(query) if isinstance(query, str) else query
        with self.engine.connect() as conn:
            result = conn.execute(statement, params or {})
            return result.fetchall()
# Health check
class DatabaseHealthCheck:
    """Database health check utilities."""

    def __init__(self, db_manager: DatabaseManager):
        """Store the manager whose engine and sessions are inspected."""
        self.db_manager = db_manager

    def check_connection(self) -> bool:
        """Check if database connection is working.

        Returns:
            ``True`` when ``SELECT 1`` runs successfully through a managed
            session, ``False`` when anything raises.
        """
        # Local import: ``text`` is not among the file's top-level imports.
        from sqlalchemy import text

        try:
            with self.db_manager.get_session() as session:
                # A bare string is not executable on SQLAlchemy 2.x, where it
                # would raise and make this check always report failure.
                session.execute(text("SELECT 1"))
            return True
        except Exception:
            # Deliberate catch-all: this probe reports health as a boolean.
            return False

    def get_pool_status(self) -> dict:
        """Get connection pool status.

        Returns:
            A dict with the pool's ``size``, ``checked_in``, ``checked_out``
            and ``overflow`` counters.
        """
        pool = self.db_manager.engine.pool
        return {
            'size': pool.size(),
            'checked_in': pool.checkedin(),
            'checked_out': pool.checkedout(),
            'overflow': pool.overflow()
        }
# Usage
# Shared manager instance used by the dependency below.
db_manager = DatabaseManager("postgresql://user:pass@localhost/db")


def get_db_session():
    """FastAPI dependency that yields one managed session.

    The session comes from ``db_manager.get_session()`` and is handed to the
    consumer while that context manager is active.
    """
    session_scope = db_manager.get_session()
    with session_scope as session:
        yield session
Follow-Up Questions
- Explain the N+1 query problem and how to solve it.
- When would you use raw SQL over an ORM?
- How do you handle database migrations in production?
- What are the trade-offs between different loading strategies?
- Explain the Unit of Work pattern and its benefits.