Context Managers in Python
Difficulty: Medium | Companies: Google, Meta, Amazon, Netflix, Stripe
Context Manager Protocol
from typing import Any, Optional
import time
class Timer:
    """Time the body of a ``with`` statement (class-based context manager).

    The duration is measured with ``time.perf_counter``. On exit it is stored
    in ``elapsed`` and a one-line summary is printed. Exceptions raised inside
    the block are reported and then re-raised, never swallowed.
    """

    def __init__(self, label: str = "Code block"):
        # Name shown in the printed summary.
        self.label = label
        # The timing fields stay None until the context is entered/exited.
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self.elapsed: Optional[float] = None

    def __enter__(self):
        """Record the start time and return this Timer for later inspection."""
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the end time, print a summary and let exceptions propagate."""
        finished = time.perf_counter()
        self.end_time = finished
        self.elapsed = finished - self.start_time
        failed = exc_type is not None
        if not failed:
            print(f"{self.label} completed in {self.elapsed:.4f}s")
            return False
        print(f"{self.label} failed after {self.elapsed:.4f}s")
        print(f"Exception: {exc_val}")
        # A falsy return value tells the interpreter to re-raise the exception.
        return False
# Usage: Timer.__enter__ returns the instance itself, so it can be created
# up front and its measurements read after the block finishes.
timer = Timer("API call")
with timer:
    time.sleep(0.1)
    print("Performing API call")
print(f"Elapsed: {timer.elapsed:.4f}s")
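ℹ️ The __exit__ method receives the exception type, value and traceback (all None if the block finished normally). Return a truthy value such as True to suppress the exception, or a falsy value such as False or None to let it propagate.
Advanced Class-Based Context Managers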
import threading
from contextlib import contextmanager
from typing import Generator
import tempfile
import shutil
from pathlib import Path
class ManagedResource:
    """Context manager modelling the acquire/release lifecycle of a named resource.

    Entering prints an acquisition message and returns a plain dict describing
    the resource. Exiting prints a release message and clears the handle.
    Exceptions from the block are not suppressed.
    """

    def __init__(self, name: str):
        self.name = name
        # Populated only while the context is active.
        self.resource = None
        self.is_active = False

    def __enter__(self):
        print(f"Acquiring resource: {self.name}")
        handle = {"name": self.name, "status": "active"}
        self.resource = handle
        self.is_active = True
        return handle

    def __exit__(self, exc_type, exc_val, exc_tb):
        print(f"Releasing resource: {self.name}")
        self.resource, self.is_active = None, False
        return False  # propagate any exception from the block
class ThreadSafeResource:
    """Context manager that guards access to a shared dict with a re-entrant lock.

    Mutual exclusion only holds between ``with`` blocks that use the *same*
    lock object. By default each instance creates its own ``threading.RLock``.
    So either share one instance across threads, or pass the same ``lock`` to
    every instance that wraps the same ``shared_resource``.

    Args:
        shared_resource: The object handed to the ``with`` body.
        lock: Optional lock object with ``acquire()``/``release()``. A new
            ``threading.RLock`` is created when omitted (the original behaviour).
    """

    def __init__(self, shared_resource: dict, lock: Any = None):
        self.shared_resource = shared_resource
        self.lock = lock if lock is not None else threading.RLock()

    def __enter__(self):
        self.lock.acquire()
        try:
            print(f"Thread {threading.current_thread().name} acquired lock")
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release here
            # rather than leaving the lock held forever.
            self.lock.release()
            raise
        return self.shared_resource

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            print(f"Thread {threading.current_thread().name} releasing lock")
        finally:
            # Release even if the log line fails.
            self.lock.release()
        return False
class TempDirectory:
    """Create a fresh temporary directory on entry and delete it on exit.

    The directory is made with ``tempfile.mkdtemp`` and removed recursively
    with ``shutil.rmtree``. Exceptions from the ``with`` body are not
    suppressed.
    """

    def __init__(self, suffix: str = "", prefix: str = "tmp_"):
        self.suffix = suffix
        self.prefix = prefix
        # Set by __enter__; None until the context has been entered.
        self.path: Optional[Path] = None

    def __enter__(self) -> Path:
        created = tempfile.mkdtemp(suffix=self.suffix, prefix=self.prefix)
        self.path = Path(created)
        print(f"Created temporary directory: {self.path}")
        return self.path

    def __exit__(self, exc_type, exc_val, exc_tb):
        target = self.path
        # Nothing to clean up if we never entered or the directory is gone.
        if not target or not target.exists():
            return False
        shutil.rmtree(target)
        print(f"Cleaned up temporary directory: {target}")
        return False
Function-Based Context Managers
from contextlib import contextmanager
from typing import Generator, Any
import time
import os
@contextmanager
def timer(label: str = "Code block") -> Generator[dict, None, None]:
    """Generator-based timer reporting how long the ``with`` body took.

    Yields a dict that is filled in when the block exits:
      * ``"elapsed"``: seconds measured with ``time.perf_counter``
      * ``"success"``: True only if the body finished without raising
      * ``"error"``: ``str()`` of the exception, present only on failure
    Exceptions are recorded and re-raised, never swallowed.
    """
    began = time.perf_counter()
    info = {"label": label, "elapsed": None, "success": False}
    try:
        yield info
    except Exception as exc:
        info["error"] = str(exc)
        raise
    else:
        info["success"] = True
    finally:
        info["elapsed"] = time.perf_counter() - began
        outcome = "completed" if info["success"] else "failed"
        print(f"{label} {outcome} in {info['elapsed']:.4f}s")
@contextmanager
def change_directory(path: str) -> Generator[None, None, None]:
    """Run the ``with`` body with *path* as the current working directory.

    The directory active on entry is restored on exit. This also happens when
    the body (or the initial ``os.chdir``) raises.
    """
    home = os.getcwd()
    try:
        os.chdir(path)
        yield None
    finally:
        os.chdir(home)
@contextmanager
def suppress_exception(*exception_types):
    """Swallow any exception matching one of *exception_types*.

    A matching exception is reported with a print, and execution resumes after
    the ``with`` statement. Anything else propagates unchanged. Called with no
    types, nothing is suppressed.
    """
    caught = tuple(exception_types)
    try:
        yield
    except caught as err:
        print(f"Suppressed exception: {err}")
# Usage: the yielded dict is populated with timing information on exit.
with timer("Data processing") as t:
    time.sleep(0.1)
    data = [n ** 2 for n in range(1000)]
print(f"Result: {t}")
⚠️ In function-based context managers, always wrap the yield in a try/finally block so that cleanup code runs even if the with-block raises an exception.
Contextlib Utilities
from contextlib import (
contextmanager, redirect_stdout, redirect_stderr,
ExitStack, closing, suppress, asynccontextmanager
)
import io
import sys
from typing import Generator
import asyncio
class Resource:
    """Toy resource with explicit open/close, used to demonstrate ``closing()``.

    ``open()`` returns ``self`` so it can be passed straight to
    ``closing(...)``. ``__del__`` is a last-resort safety net that closes a
    resource the caller forgot to close.
    """

    def __init__(self, name: str):
        self.name = name
        self.is_open = False

    def open(self):
        """Mark the resource open, print a message and return ``self``."""
        self.is_open = True
        print(f"Opened {self.name}")
        return self

    def close(self):
        """Mark the resource closed and print a message (on every call)."""
        self.is_open = False
        print(f"Closed {self.name}")

    def __del__(self):
        # getattr guards against instances whose __init__ never completed
        # (e.g. a TypeError from wrong arguments). Such instances have no
        # ``is_open`` attribute, and __del__ would otherwise raise
        # AttributeError during garbage collection.
        if getattr(self, "is_open", False):
            self.close()
def example_with_closing():
    """Show ``contextlib.closing`` calling ``close()`` when the block ends."""
    with closing(Resource("database").open()) as res:
        print(f"Using {res.name}")
    # Leaving the block has already invoked res.close().
def example_with_exit_stack(directory: str = ".", count: int = 3):
    """Open a runtime-determined number of files under a single ``ExitStack``.

    Creates ``test_0.txt`` … ``test_{count-1}.txt`` in *directory* and writes
    ``"Content {i}"`` into each. Every file registered with the stack is
    closed when the ``with`` block exits. If a later ``open()`` fails, the
    files opened before it are closed too.

    Args:
        directory: Folder to create the files in (default: the current
            working directory, as before).
        count: Number of files to open (default: 3, as before).
    """
    with ExitStack() as stack:
        files = [
            stack.enter_context(
                # Explicit encoding: the default is locale-dependent.
                open(os.path.join(directory, f"test_{i}.txt"), "w", encoding="utf-8")
            )
            for i in range(count)
        ]
        for i, f in enumerate(files):
            f.write(f"Content {i}")
def example_with_suppress():
    """Try to delete a file that probably does not exist, ignoring only FileNotFoundError."""
    missing = "nonexistent_file.txt"
    with suppress(FileNotFoundError):
        os.remove(missing)
    # Reached whether or not the file existed.
    print("Continued despite exception")
@contextmanager
def redirect_output(stdout_target=None, stderr_target=None):
    """Temporarily point ``sys.stdout`` and/or ``sys.stderr`` at other streams.

    Args:
        stdout_target: File-like object to receive stdout, or None to leave
            stdout untouched.
        stderr_target: File-like object to receive stderr, or None to leave
            stderr untouched.

    On exit, both ``sys.stdout`` and ``sys.stderr`` are restored to the
    objects they referenced on entry, even if the block raises.
    """
    old_stdout = sys.stdout
    old_stderr = sys.stderr
    # Compare against None instead of relying on truthiness. A file-like
    # object that defines __len__ or __bool__ (e.g. an empty buffer type)
    # would otherwise be silently ignored.
    if stdout_target is not None:
        sys.stdout = stdout_target
    if stderr_target is not None:
        sys.stderr = stderr_target
    try:
        yield
    finally:
        sys.stdout = old_stdout
        sys.stderr = old_stderr
# Usage: capture printed text in memory instead of writing to the terminal.
output = io.StringIO()
with redirect_stdout(new_target=output):
    print("This goes to StringIO")
# Back on the real stdout here.
print(f"Captured: {output.getvalue()}")
Async Context Managers
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator
class AsyncDatabase:
    """Simulated database connection usable with ``async with``.

    Connect/disconnect latency is simulated with ``asyncio.sleep``.
    ``connection`` is a placeholder dict while connected and None otherwise.
    """

    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self.connection = None

    async def __aenter__(self):
        print(f"Async connecting to {self.connection_string}")
        await asyncio.sleep(0.1)  # simulated connect latency
        self.connection = dict(status="connected")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        print("Async disconnecting")
        await asyncio.sleep(0.05)  # simulated disconnect latency
        self.connection = None
        return False  # exceptions from the block propagate

    async def execute(self, query: str):
        """Simulate running *query*; returns the string ``"Result: <query>"``."""
        await asyncio.sleep(0.01)
        response = f"Result: {query}"
        return response
@asynccontextmanager
async def async_timer(label: str) -> AsyncGenerator[dict, None]:
    """Async context manager that measures how long the ``async with`` body takes.

    Yields ``{"label": label, "elapsed": None}``. When the block exits
    (normally or by raising), ``"elapsed"`` is set to the duration in seconds
    according to the event loop's clock, and a summary line is printed.
    """
    # Inside a coroutine the asyncio docs recommend get_running_loop() over
    # get_event_loop(). Fetch the loop once and reuse it for both readings.
    loop = asyncio.get_running_loop()
    start = loop.time()
    result = {"label": label, "elapsed": None}
    try:
        yield result
    finally:
        result["elapsed"] = loop.time() - start
        print(f"{label} took {result['elapsed']:.4f}s")
async def async_usage_example():
    """Demonstrate class-based, function-based and combined async context managers."""
    # Class-based: AsyncDatabase implements __aenter__/__aexit__.
    async with AsyncDatabase("postgresql://localhost/db") as db:
        result = await db.execute("SELECT * FROM users")
        print(result)
    # Function-based: async_timer is built with @asynccontextmanager.
    async with async_timer("Async operation"):
        await asyncio.sleep(0.1)
    # Several async context managers in one statement: entered left to right,
    # exited in reverse order. (When the number of contexts is only known at
    # runtime, use contextlib.AsyncExitStack instead.)
    async with AsyncDatabase("db1") as db1, AsyncDatabase("db2") as db2:
        await db1.execute("Query 1")
        await db2.execute("Query 2")


asyncio.run(async_usage_example())
Nested and Composed Context Managers
from contextlib import contextmanager, ExitStack
from typing import Generator, List
import time
@contextmanager
def database_connection(db_name: str) -> Generator:
    """Simulate a database connection that lives for the duration of a block.

    Yields ``{"name": db_name, "active": True}``. On exit (normal or
    exceptional) ``"active"`` is set to False and a message is printed.
    """
    print(f"Connecting to {db_name}")
    conn = dict(name=db_name, active=True)
    try:
        yield conn
    finally:
        conn["active"] = False
        print(f"Disconnected from {db_name}")
@contextmanager
def transaction(connection: dict) -> Generator:
    """Wrap a block in a simulated transaction on *connection*.

    Yields the same connection dict back. Prints a commit message if the
    block completes, or a rollback message if an ``Exception`` escapes it.
    The exception is then re-raised.
    """
    db_name = connection['name']
    print(f"Starting transaction on {db_name}")
    try:
        yield connection
        print("Committing transaction")
    except Exception:
        print("Rolling back transaction")
        raise
def nested_context_demo():
    """Open a connection, run one transaction inside it, then close both in order."""
    with database_connection("postgres") as conn:
        with transaction(conn) as txn:
            print(f"Executing query on {txn['name']}")
            time.sleep(0.01)
        # Transaction has committed; the connection is still open here.
        print("Transaction completed")
    print("Database connection closed")
@contextmanager
def managed_resources(resources: List[str]) -> Generator[dict, None, None]:
    """Acquire several simulated resources and release them all on exit.

    Each name in *resources* becomes an entry ``{"name": ..., "acquired": True}``
    in the yielded dict. A release callback is registered on an ``ExitStack``
    right after each acquisition. Resources are therefore released in reverse
    acquisition order when the block exits, whether it finishes normally or
    raises. If acquisition fails part way through, the ones already acquired
    are still released.
    """
    managed = {}

    def _release(resource_name: str) -> None:
        # Mark one resource as released; run by the ExitStack on exit.
        managed[resource_name]["acquired"] = False
        print(f"Released: {resource_name}")

    with ExitStack() as stack:
        for resource_name in resources:
            # Simulate acquiring the resource.
            managed[resource_name] = {"name": resource_name, "acquired": True}
            print(f"Acquired: {resource_name}")
            stack.callback(_release, resource_name)
        yield managed
# Usage: the yielded dict maps each resource name to its state record.
with managed_resources(["file1.txt", "file2.txt", "database"]) as resources:
    print(f"Using resources: {list(resources)}")
ℹ️ Use ExitStack when you need to manage multiple context managers dynamically, for example when the number of contexts is only known at runtime.
Real-World Patterns
import functools
from typing import Callable, Any
import time
class RateLimiter:
    """Sliding-window rate limiter usable as a context manager.

    Entering records a call timestamp. If ``max_calls`` calls were already
    recorded within the last ``period`` seconds, ``RuntimeError`` is raised
    instead, and the rejected attempt is not recorded.

    Timestamps come from ``time.monotonic()``, so the window is unaffected by
    system clock adjustments. ``time.time()`` can jump backwards or forwards.
    """

    def __init__(self, max_calls: int, period: float):
        self.max_calls = max_calls
        self.period = period
        # Monotonic timestamps of accepted calls still inside the window.
        self.calls = []

    def __enter__(self):
        now = time.monotonic()
        # Drop timestamps that have slid out of the window.
        self.calls = [call for call in self.calls if now - call < self.period]
        if len(self.calls) >= self.max_calls:
            raise RuntimeError("Rate limit exceeded")
        self.calls.append(now)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; never suppress exceptions.
        return False
class CircuitBreaker:
    """Circuit-breaker pattern implemented as a context manager.

    States (held in ``self.state``):
      * ``"closed"``: calls run normally and consecutive failures are counted.
      * ``"open"``: entered after ``failure_threshold`` consecutive failures.
        Entering raises ``RuntimeError`` until more than ``recovery_timeout``
        seconds have passed since the last failure.
      * ``"half-open"``: calls are let through on trial. A success resets the
        failure count and closes the breaker. A failure re-opens it, because
        the count is still at or above the threshold.

    Any exception raised inside the ``with`` body counts as a failure and is
    re-raised. Times are measured with ``time.monotonic()``, so wall-clock
    adjustments cannot shorten or extend the recovery timeout.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 60):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        # Monotonic timestamp of the most recent failure (None if none yet).
        self.last_failure_time = None
        self.state = "closed"  # closed, open, half-open

    def __enter__(self):
        if self.state == "open":
            if time.monotonic() - self.last_failure_time > self.recovery_timeout:
                self.state = "half-open"
                print("Circuit breaker: half-open state")
            else:
                raise RuntimeError("Circuit breaker is open")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            self.failures += 1
            self.last_failure_time = time.monotonic()
            if self.failures >= self.failure_threshold:
                self.state = "open"
                print("Circuit breaker: opened")
        else:
            self.failures = 0
            if self.state == "half-open":
                self.state = "closed"
                print("Circuit breaker: closed")
        return False
def rate_limited(max_calls: int, period: float) -> Callable:
    """Build a decorator that runs the wrapped function inside a ``RateLimiter``.

    One ``RateLimiter`` is created per ``rate_limited(...)`` call. It is shared
    by every invocation of the decorated function, and by any other function
    decorated with the same decorator object. When the limit is hit, the
    ``RuntimeError`` raised by ``RateLimiter`` propagates to the caller.
    """
    shared_limiter = RateLimiter(max_calls, period)

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with shared_limiter:
                outcome = func(*args, **kwargs)
            return outcome

        return wrapper

    return decorator
# Usage: at most 10 calls per 1-second window; further calls raise RuntimeError.
def api_call(endpoint: str):
    """Pretend to call *endpoint* and return a canned response string."""
    return f"Response from {endpoint}"


# Equivalent to decorating api_call with @rate_limited(max_calls=10, period=1.0).
api_call = rate_limited(max_calls=10, period=1.0)(api_call)

try:
    for i in range(15):
        result = api_call(f"/api/endpoint/{i}")
        print(result)
except RuntimeError as e:
    # The first rejected call ends the whole loop.
    print(f"Rate limited: {e}")
Follow-Up Questions
- Explain the difference between the __enter__/__exit__ protocol and the @contextmanager decorator.
- How do you handle exceptions in context managers?
- What is the purpose of the ExitStack class?
- When would you use async context managers over regular ones?
- How do you create a context manager that can be used both with the with statement and manually?