Connection Management and Hooks in Apache Airflow

Formal Definitions

Definition: Connection

A Connection in Airflow is a named configuration object, stored by default in the metadata database, that contains the information needed to connect to an external system: host, login, password, schema, port, and extra parameters, plus a conn_type that tells Airflow which hook and UI form to use. Formally, $C = (n, h, u, p, s, e)$, where $n$ is the conn_id, $h$ is the host, $u$ is the login, $p$ is the password, $s$ is the schema, and $e$ is the extra JSON; the port and conn_type are stored alongside these fields.
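
As a minimal sketch, the tuple can be modeled as a plain Python dataclass that is independent of Airflow. The class name `ConnectionTuple` and its `extra_dict()` helper are hypothetical. Airflow's real `Connection` model has more fields (including `conn_type`, `port`, and `description`) and exposes the parsed extras as `extra_dejson`.

```python
import json
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class ConnectionTuple:
    """Hypothetical model of C = (n, h, u, p, s, e); not Airflow's Connection class."""

    conn_id: str                                               # n
    host: Optional[str] = None                                 # h
    login: Optional[str] = None                                # u
    password: Optional[str] = field(default=None, repr=False)  # p, kept out of repr()
    schema: Optional[str] = None                               # s
    extra: str = "{}"                                          # e, a JSON string

    def extra_dict(self) -> dict:
        """Parse e into a dict (Airflow's Connection offers this as extra_dejson)."""
        return json.loads(self.extra) if self.extra else {}


c = ConnectionTuple(
    conn_id="my_custom_conn",
    host="localhost",
    login="airflow_user",
    password="secure_password",
    schema="analytics",
    extra='{"sslmode": "require", "connect_timeout": 30}',
)
print(c)               # the password does not appear in the printed representation
print(c.extra_dict())  # {'sslmode': 'require', 'connect_timeout': 30}
```

Excluding the password from `repr()` keeps it out of logs when the object is printed.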

Definition: Hook

A Hook is a high-level interface to an external system that encapsulates connection logic and provides methods for common operations. Hooks wrap the underlying client library (for example, boto3 in S3Hook or psycopg2 in PostgresHook) and handle authentication; some also add retries or session reuse. Formally, $H: C \rightarrow \{M_1, M_2, \ldots, M_k\}$ maps a connection to the set of methods it makes available.

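Definition: Connection Pool

A Connection Pool is a cache of reusable database connections. Airflow uses a SQLAlchemy pool for its own metadata-database connections (configured in the [database] section of airflow.cfg), and SQLAlchemy engines created by database hooks pool their connections the same way. With SQLAlchemy's default QueuePool, connections are opened lazily, up to pool_size idle connections are kept for reuse, and the number in use is bounded by $N_{\text{active}} \leq N_{\text{max}} = \text{pool\_size} + \text{max\_overflow}$. Reuse avoids the overhead of opening a new connection for every query.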
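
The bound on $N_{\text{active}}$ can be illustrated with a toy counter. This is a hypothetical sketch, not SQLAlchemy's implementation: it assumes QueuePool semantics, where at most `pool_size + max_overflow` connections can be checked out at once. Where this sketch returns `False`, a real QueuePool would wait up to `pool_timeout` seconds and then raise `sqlalchemy.exc.TimeoutError`.

```python
import threading


class BoundedPoolSketch:
    """Toy model of the QueuePool limit: N_active <= pool_size + max_overflow."""

    def __init__(self, pool_size: int, max_overflow: int):
        self.n_max = pool_size + max_overflow
        self.n_active = 0  # connections currently checked out
        self._lock = threading.Lock()

    def try_checkout(self) -> bool:
        """Check out a connection if the limit allows it."""
        with self._lock:
            if self.n_active >= self.n_max:
                return False  # QueuePool would block here, then time out
            self.n_active += 1
            return True

    def checkin(self) -> None:
        """Return a connection to the pool."""
        with self._lock:
            if self.n_active == 0:
                raise RuntimeError("checkin without matching checkout")
            self.n_active -= 1


pool = BoundedPoolSketch(pool_size=5, max_overflow=10)
granted = [pool.try_checkout() for _ in range(16)]
print(granted.count(True))  # 15: the 16th request exceeds N_max
```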

Detailed Explanation

Creating Connections

Connections can be created through the Airflow UI, CLI, REST API, or environment variables.

# Using the CLI to create connections
# airflow connections add 'my_s3_conn' \
#   --conn-type 'aws' \
#   --conn-login 'AKIAIOSFODNN7EXAMPLE' \
#   --conn-password 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY' \
#   --conn-extra '{"region_name": "us-east-1"}'

# Using environment variables (the '/' characters in the secret are URL-encoded as %2F)
# AIRFLOW_CONN_MY_S3_CONN='aws://AKIAIOSFODNN7EXAMPLE:wJalrXUtnFEMI%2FK7MDENG%2FbPxRfiCYEXAMPLEKEY@?region_name=us-east-1'

# Programmatic connection creation
import json

from airflow.models.connection import Connection
from airflow.utils.session import create_session

def create_connection():
    """Create a connection programmatically if it does not already exist."""
    # create_session() commits on success, rolls back on error, and closes the session
    with create_session() as session:
        # Check if the connection already exists
        existing = session.query(Connection).filter(
            Connection.conn_id == 'my_custom_conn'
        ).first()
        
        if existing:
            print("Connection already exists")
            return existing.conn_id
        
        conn = Connection(
            conn_id='my_custom_conn',
            conn_type='postgres',
            host='localhost',
            login='airflow_user',
            password='secure_password',  # placeholder: never hard-code real secrets
            schema='analytics',
            port=5432,
            extra=json.dumps({"sslmode": "require", "connect_timeout": 30}),
        )
        session.add(conn)
    
    print("Connection created successfully")
    return conn.conn_id

if __name__ == "__main__":
    create_connection()
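
The environment-variable form packs the connection fields into a single URI, which is why the `/` characters in the example secret must be percent-encoded as `%2F`. The hypothetical `parse_conn_uri` below decodes such a URI with the standard library to show which part lands in which field. It is a simplified sketch, not Airflow's parser, which also handles cases omitted here (for example the `__extra__` query parameter and the JSON format accepted since Airflow 2.3).

```python
from urllib.parse import parse_qsl, unquote, urlsplit


def parse_conn_uri(conn_id: str, uri: str) -> dict:
    """Hypothetical, simplified decoder for an AIRFLOW_CONN_<CONN_ID> URI."""
    parts = urlsplit(uri)
    path = parts.path.lstrip("/")
    return {
        "conn_id": conn_id,
        "conn_type": parts.scheme,
        # .hostname is lower-cased by urllib
        "host": unquote(parts.hostname) if parts.hostname else None,
        "login": unquote(parts.username) if parts.username else None,
        # Percent-decoding restores characters such as "/" in the secret
        "password": unquote(parts.password) if parts.password else None,
        "schema": unquote(path) if path else None,
        "port": parts.port,
        # Query-string pairs become the extra dict
        "extra": dict(parse_qsl(parts.query)),
    }


s3 = parse_conn_uri(
    "my_s3_conn",
    "aws://AKIAIOSFODNN7EXAMPLE:wJalrXUtnFEMI%2FK7MDENG%2FbPxRfiCYEXAMPLEKEY@?region_name=us-east-1",
)
print(s3["password"])  # wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
print(s3["extra"])     # {'region_name': 'us-east-1'}
```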

Using Hooks

from contextlib import closing

from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.http.hooks.http import HttpHook

# Database hook usage
def database_operations():
    """Use PostgresHook for database operations."""
    hook = PostgresHook(postgres_conn_id='postgres_default')
    
    # Raw DB-API access; closing() releases the cursor and connection even on errors
    with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
        cursor.execute("SELECT COUNT(*) FROM users WHERE active = %s", (True,))
        count = cursor.fetchone()[0]
        print(f"Active users: {count}")
    
    # Built-in helpers open and close their own connections
    df = hook.get_pandas_df("SELECT * FROM orders LIMIT 100")
    print(f"Fetched {len(df)} orders")
    
    # Insert data
    hook.insert_rows(
        table='staging_orders',
        rows=[(1, 'order_1', 100.00), (2, 'order_2', 250.00)],
        target_fields=['id', 'order_id', 'amount'],
    )

# S3 hook usage
def s3_operations():
    """Use S3Hook for S3 operations."""
    hook = S3Hook(aws_conn_id='aws_default')
    
    # List object keys under a prefix
    keys = hook.list_keys(bucket_name='my-bucket', prefix='data/')
    print(f"Found {len(keys)} objects")
    
    # Download file into /tmp/; returns the path of the downloaded file
    local_path = hook.download_file(
        key='data/file.csv',
        bucket_name='my-bucket',
        local_path='/tmp/',
    )
    print(f"Downloaded to {local_path}")
    
    # Upload file
    hook.load_file(
        filename='/tmp/output.csv',
        key='data/output.csv',
        bucket_name='my-bucket',
        replace=True,
    )
    
    # Check if key exists
    exists = hook.check_for_key(
        key='data/file.csv',
        bucket_name='my-bucket',
    )
    print(f"data/file.csv exists: {exists}")

# HTTP hook usage
def api_operations():
    """Use HttpHook for API calls."""
    hook = HttpHook(http_conn_id='api_default', method='GET')
    
    # For GET requests, HttpHook sends `data` as URL query parameters
    response = hook.run(
        endpoint='/api/v1/data',
        data={'limit': 100},
        headers={'Accept': 'application/json'},
    )
    
    data = response.json()
    print(f"Received {len(data)} records")

Connection Retrieval

$$C_{\text{conn}} = \text{BaseHook.get\_connection}(conn_{\text{id}})$$

Here,

  • $C_{\text{conn}}$ = Connection object with all credentials
  • $conn_{\text{id}}$ = Unique connection identifier
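
`get_connection` searches several sources in order and returns the first match: the configured secrets backend (if any), then an `AIRFLOW_CONN_<CONN_ID>` environment variable (conn_id upper-cased), then the metadata database. The sketch below mimics that default order with plain dictionaries standing in for each source. `resolve_connection` and its arguments are hypothetical, and where it raises `LookupError`, Airflow raises `AirflowNotFoundException`.

```python
from typing import Mapping


def resolve_connection(
    conn_id: str,
    secrets_backend: Mapping[str, str],
    environ: Mapping[str, str],
    metastore: Mapping[str, str],
) -> str:
    """Hypothetical sketch of the default lookup order; the first hit wins."""
    # 1. Alternative secrets backend (Vault, AWS Secrets Manager, ...), if configured
    if conn_id in secrets_backend:
        return secrets_backend[conn_id]
    # 2. Environment variable named after the upper-cased conn_id
    env_key = f"AIRFLOW_CONN_{conn_id.upper()}"
    if env_key in environ:
        return environ[env_key]
    # 3. Metadata database
    if conn_id in metastore:
        return metastore[conn_id]
    raise LookupError(f"The conn_id `{conn_id}` isn't defined")


env = {"AIRFLOW_CONN_PRODUCTION_DB": "postgres://env-host/analytics"}
db = {"production_db": "postgres://db-host/analytics"}
print(resolve_connection("production_db", {}, env, db))  # postgres://env-host/analytics
```

Because the first hit wins, an environment variable silently shadows a connection of the same name defined in the UI.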

Connection Pool Efficiency

$$\eta_{\text{pool}} = \frac{N_{\text{reused}}}{N_{\text{total}}} \times 100\%$$

Here,

  • $N_{\text{reused}}$ = Number of connections reused from the pool
  • $N_{\text{total}}$ = Total connection attempts
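
As a worked example of $\eta_{\text{pool}}$: if a process makes 1,000 checkout requests and 850 of them are served by an already-open connection, $\eta_{\text{pool}} = 850 / 1000 \times 100\% = 85\%$. The helper below is a hypothetical sketch of that arithmetic. With SQLAlchemy, one way to collect the counts is to listen to the pool's `checkout` and `connect` events: every checkout counts toward $N_{\text{total}}$ and each `connect` marks a newly opened connection, so $N_{\text{reused}} \approx N_{\text{checkout}} - N_{\text{connect}}$ (an approximation that ignores connections invalidated and reopened).

```python
def pool_efficiency(n_reused: int, n_total: int) -> float:
    """eta_pool = N_reused / N_total * 100, as a percentage."""
    if n_total <= 0:
        raise ValueError("n_total must be positive")
    if not 0 <= n_reused <= n_total:
        raise ValueError("n_reused must be between 0 and n_total")
    return 100.0 * n_reused / n_total


# 1,000 checkouts, of which 150 had to open a new connection (connect events)
n_total = 1000
n_reused = n_total - 150
print(pool_efficiency(n_reused, n_total))  # 85.0
```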

Airflow stores connections in the metadata database by default. For production, consider using environment variables or a secrets backend (such as HashiCorp Vault or AWS Secrets Manager) to avoid storing credentials in the database.

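Database hooks built on DbApiHook, such as PostgresHook, expose get_sqlalchemy_engine(), which returns a SQLAlchemy engine with its own connection pool. Create the engine once and reuse it for all queries within a task instead of reconnecting for each query; because each task instance normally runs in its own process, pooled connections are not shared across tasks. This can reduce connection overhead by roughly 60-80%, depending on how many queries each task issues.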
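
The reuse pattern is to create one engine (or connection) per process and hand it to every query. The sketch below shows that pattern with the standard library's `sqlite3` standing in for a real database; `get_cached_connection` and `run_query` are hypothetical names. In an Airflow task you might apply the same caching to the engine returned by `hook.get_sqlalchemy_engine()`, since calling that method repeatedly may build a new engine each time.

```python
import sqlite3
from functools import lru_cache


@lru_cache(maxsize=None)
def get_cached_connection(database: str) -> sqlite3.Connection:
    """Open a connection on first use, then return the same one for this process."""
    # sqlite3 stands in for a real database driver in this sketch
    return sqlite3.connect(database)


def run_query(sql: str) -> list:
    """Run a query on the cached connection instead of reconnecting each time."""
    conn = get_cached_connection(":memory:")
    return conn.execute(sql).fetchall()


print(run_query("SELECT 1 + 1"))                  # [(2,)]
print(get_cached_connection.cache_info().misses)  # 1: only one connection was opened
```

`sqlite3` connections are bound to the thread that created them by default, so this sketch assumes single-threaded use.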

Key Concepts Table

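| Connection Type | Conn Type | Hook Class | Key Parameters |
|---|---|---|---|
| PostgreSQL | postgres | PostgresHook | host, login, password, schema (database name), port |
| MySQL | mysql | MySqlHook | host, login, password, schema (database name), port |
| S3 | aws | S3Hook | login (access key ID), password (secret key), extra (region_name) |
| GCS | google_cloud_platform | GCSHook | extra (key path or keyfile JSON, project) |
| HTTP | http | HttpHook | host, schema (http or https), extra (headers) |
| Spark | spark | SparkSubmitHook | host (master URL), port, extra (deploy-mode, queue) |
| Redis | redis | RedisHook | host, login, password, port, extra (db) |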

Code Examples

Custom Hook Implementation

# custom_hook.py
from airflow.hooks.base import BaseHook
from airflow.exceptions import AirflowException
import requests

class CustomAPIHook(BaseHook):
    """Custom hook for internal API services."""
    
    conn_name_attr = 'custom_api_conn_id'
    default_conn_name = 'custom_api_default'
    conn_type = 'custom_api'
    hook_name = 'Custom API'
    
    def __init__(self, custom_api_conn_id=None):
        super().__init__()  # initialize BaseHook (logging setup)
        self.conn_id = custom_api_conn_id or self.default_conn_name
        self.base_url = None
        self.headers = {}
        self._get_connection()
    
    def _get_connection(self):
        """Build the base URL and auth headers from the connection fields."""
        conn = self.get_connection(self.conn_id)
        scheme = conn.schema or 'https'  # the schema field holds the URL scheme
        self.base_url = f"{scheme}://{conn.host}"
        if conn.port:
            self.base_url += f":{conn.port}"
        
        # The API token is stored in the connection's password field
        if conn.password:
            self.headers['Authorization'] = f"Bearer {conn.password}"
        
        # extra_dejson returns the parsed extra JSON (an empty dict if unset)
        self.headers.update(conn.extra_dejson.get('headers', {}))
    
    def get(self, endpoint, params=None):
        """Make GET request."""
        url = f"{self.base_url}{endpoint}"
        try:
            response = requests.get(
                url,
                headers=self.headers,
                params=params,
                timeout=30,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            raise AirflowException(f"API request failed: {e}") from e
    
    def post(self, endpoint, data=None):
        """Make POST request with a JSON body."""
        url = f"{self.base_url}{endpoint}"
        try:
            response = requests.post(
                url,
                headers=self.headers,
                json=data,
                timeout=30,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            raise AirflowException(f"API request failed: {e}") from e

# Usage in DAG (in practice, keep CustomAPIHook in an importable module,
# e.g. under plugins/, and import it here)
from datetime import datetime

from airflow.decorators import dag, task

# `schedule` replaces the deprecated `schedule_interval` (Airflow 2.4+)
@dag(schedule="@daily", start_date=datetime(2024, 1, 1), catchup=False)
def custom_api_dag():
    
    @task
    def fetch_data():
        hook = CustomAPIHook(custom_api_conn_id='my_api')
        return hook.get('/api/v1/data', params={'limit': 100})
    
    @task
    def process_data(data: dict):
        print(f"Processing {len(data.get('results', []))} records")
    
    data = fetch_data()
    process_data(data)

custom_api_dag()
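
The Hook definition mentions retry logic, which `CustomAPIHook` does not implement. One way to add it is a retry decorator with exponential backoff. The sketch below is hypothetical (`with_retries` is not an Airflow API) and takes an injectable `sleep` so it can be exercised without waiting. For the hook, you would typically pass `retry_on=(requests.exceptions.ConnectionError, requests.exceptions.Timeout)` and apply it where `requests` is called, before errors are wrapped in `AirflowException`. Airflow's task-level `retries` and `retry_delay` arguments are the coarser alternative.

```python
import time
from functools import wraps
from typing import Callable, Tuple, Type


def with_retries(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    retry_on: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError),
    sleep: Callable[[float], None] = time.sleep,
):
    """Retry the wrapped call, waiting base_delay * 2**attempt between attempts."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except retry_on:
                    if attempt == max_attempts - 1:
                        raise  # out of attempts: re-raise the last error
                    sleep(base_delay * 2 ** attempt)
        return wrapper
    return decorator


# Usage with a fake flaky call; delays are recorded instead of slept
delays = []
calls = {"n": 0}


@with_retries(max_attempts=3, base_delay=1.0, sleep=delays.append)
def flaky_fetch():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient network error")
    return {"results": []}


print(flaky_fetch())  # {'results': []} after two retries
print(delays)         # [1.0, 2.0]
```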

Connection Pool Configuration

# connection_pool_config.py
from airflow import settings
from sqlalchemy.pool import QueuePool

def inspect_connection_pool():
    """Report the status of Airflow's metadata-database connection pool."""
    pool = settings.engine.pool
    print(f"Pool type: {type(pool).__name__}")
    
    # size()/checkedin()/checkedout()/overflow() exist on QueuePool only;
    # with sql_alchemy_pool_enabled = False, Airflow uses a NullPool instead
    if not isinstance(pool, QueuePool):
        return {}
    
    status = {
        'pool_size': pool.size(),
        'checked_in': pool.checkedin(),
        'checked_out': pool.checkedout(),
        'overflow': pool.overflow(),
    }
    for name, value in status.items():
        print(f"{name}: {value}")
    return status

def optimize_pool_for_workload(
    avg_concurrent_tasks: int,
    peak_concurrent_tasks: int,
    task_duration_seconds: int,
):
    """Heuristic pool settings derived from workload figures."""
    
    # Base pool size should handle average concurrency
    pool_size = max(5, avg_concurrent_tasks)
    
    # Overflow connections absorb peaks above the base pool
    max_overflow = max(10, peak_concurrent_tasks - pool_size)
    
    # Seconds to wait for a free connection before raising an error;
    # heuristic: twice the typical task duration
    pool_timeout = task_duration_seconds * 2
    
    # Recycle connections after 30 minutes to avoid stale connections
    pool_recycle = 1800
    
    return {
        'pool_size': pool_size,
        'max_overflow': max_overflow,
        'pool_timeout': pool_timeout,
        'pool_recycle': pool_recycle,
        'pool_pre_ping': True,
    }

# Example: 50 avg tasks, 100 peak, 60s duration
config = optimize_pool_for_workload(
    avg_concurrent_tasks=50,
    peak_concurrent_tasks=100,
    task_duration_seconds=60,
)
print(f"Recommended pool config: {config}")

# Apply to airflow.cfg ([core] section before Airflow 2.3). These values apply
# per process: each Airflow process with a metadata-DB engine gets its own pool.
# [database]
# sql_alchemy_pool_size = 50
# sql_alchemy_max_overflow = 50
# sql_alchemy_pool_timeout = 120
# sql_alchemy_pool_recycle = 1800
# sql_alchemy_pool_pre_ping = True
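
One caveat when sizing these values: SQLAlchemy pools are per process, and every Airflow process that talks to the metadata database (schedulers, webservers, workers, and so on) creates its own engine. The database can therefore see up to roughly $N_{\text{processes}} \times (\text{pool\_size} + \text{max\_overflow})$ connections. The hypothetical helper below does that arithmetic. With 50 + 50 per pool and 20 such processes, the ceiling is 2,000 connections, far above PostgreSQL's default `max_connections` of 100, which is why a connection pooler such as PgBouncer is commonly placed in front of the metadata database.

```python
def max_metadata_db_connections(n_processes: int, pool_size: int, max_overflow: int) -> int:
    """Upper bound on metadata-DB connections when each process has its own pool."""
    if min(n_processes, pool_size, max_overflow) < 0:
        raise ValueError("arguments must be non-negative")
    return n_processes * (pool_size + max_overflow)


# Values from the example config, across 20 Airflow processes
print(max_metadata_db_connections(n_processes=20, pool_size=50, max_overflow=50))  # 2000
```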

Secrets Backend Integration

# secrets_backend.py
# Airflow reads the secrets backend from airflow.cfg (or the equivalent
# AIRFLOW__SECRETS__BACKEND and AIRFLOW__SECRETS__BACKEND_KWARGS environment
# variables). Only one custom backend is active at a time; choose ONE of:

# AWS Secrets Manager: connection secrets named "airflow/connections/<conn_id>"
# [secrets]
# backend = airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
# backend_kwargs = {"connections_prefix": "airflow/connections", "variables_prefix": "airflow/variables", "region_name": "us-east-1"}

# HashiCorp Vault: secrets at "connections/<conn_id>" under the "airflow" mount
# (authentication settings omitted)
# [secrets]
# backend = airflow.providers.hashicorp.secrets.vault.VaultBackend
# backend_kwargs = {"connections_path": "connections", "variables_path": "variables", "mount_point": "airflow", "url": "http://127.0.0.1:8200"}

# Google Secret Manager: connection secrets named "airflow-connections-<conn_id>"
# [secrets]
# backend = airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend
# backend_kwargs = {"connections_prefix": "airflow-connections", "variables_prefix": "airflow-variables", "project_id": "my-project", "sep": "-"}

# Usage in DAG: hooks and Variable.get resolve through the configured backend
from datetime import datetime

from airflow.decorators import dag, task

@dag(schedule="@daily", start_date=datetime(2024, 1, 1), catchup=False)
def secrets_backend_dag():
    
    @task
    def use_secret_connection():
        """Connection is automatically retrieved from the secrets backend."""
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        
        # This will look up 'production_db' in the configured secrets backend
        hook = PostgresHook(postgres_conn_id='production_db')
        result = hook.get_first("SELECT NOW()")
        return str(result[0])  # str() keeps the XCom value JSON-serializable
    
    @task
    def use_secret_variable():
        """Variable is automatically retrieved from the secrets backend."""
        from airflow.models import Variable
        
        api_key = Variable.get("api_key")
        # Avoid printing secret material; even a partial value can leak
        print(f"Retrieved API key ({len(api_key)} characters)")
    
    use_secret_connection() >> use_secret_variable()

secrets_backend_dag()
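
Prefix-based backends find a connection by joining the prefix, the separator, and the conn_id, so with the prefixes shown in the configuration comments `production_db` is looked up as `airflow/connections/production_db` in AWS Secrets Manager and `airflow-connections-production_db` in Google Secret Manager. The hypothetical helpers below compute such names and build a JSON secret value. JSON-formatted connection secrets are accepted by Airflow 2.3 and later (URI strings also work), but check the format your provider version expects.

```python
import json


def backend_secret_name(connections_prefix: str, sep: str, conn_id: str) -> str:
    """Hypothetical helper: the secret name a prefix-based backend looks up."""
    return f"{connections_prefix}{sep}{conn_id}"


def connection_secret_value(conn_type: str, **fields) -> str:
    """Hypothetical helper: serialize connection fields as a JSON secret value."""
    payload = {"conn_type": conn_type}
    # Drop unset fields so the secret only contains what is configured
    payload.update({key: value for key, value in fields.items() if value is not None})
    return json.dumps(payload)


print(backend_secret_name("airflow/connections", "/", "production_db"))
print(backend_secret_name("airflow-connections", "-", "production_db"))
print(connection_secret_value(
    "postgres", host="db.example.com", login="airflow_user",
    password="secure_password", schema="analytics", port=5432,
    extra={"sslmode": "require"},
))
```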

Performance Metrics

Credential Storage Comparison

| Aspect | Metadata DB | Environment Variables | Secrets Manager |
|---|---|---|---|
| Security | Medium | High | Very High |
| Performance | Fast (DB lookup) | Fast (OS lookup) | Medium (API call) |
| Audit Trail | Yes | No | Yes |
| Rotation | Manual | Manual | Automated |
| Cost | Free | Free | Low-Medium |
| Complexity | Low | Low | Medium |
| Multi-environment | Difficult | Easy | Easy |

Connection Pool Metrics

| Metric | Recommended | Warning | Critical |
|---|---|---|---|
| Pool Utilization | < 70% | 70-90% | > 90% |
| Connection Wait Time | < 1 s | 1-5 s | > 5 s |
| Connection Errors | < 1% | 1-5% | > 5% |
| Pool Overflow | < 50% | 50-80% | > 80% |
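
For alerting, these thresholds can be encoded directly. The sketch below is hypothetical (`THRESHOLDS` and `classify` are illustrative names); it treats each lower bound as the start of the Warning band and each upper bound as inclusive, matching ranges such as 70-90%, and leaves open how the raw metric values are collected.

```python
# Thresholds from the pool metrics table: (warning band starts at, critical above)
THRESHOLDS = {
    "pool_utilization_pct": (70.0, 90.0),
    "connection_wait_s": (1.0, 5.0),
    "connection_errors_pct": (1.0, 5.0),
    "pool_overflow_pct": (50.0, 80.0),
}


def classify(metric: str, value: float) -> str:
    """Map a metric value to Recommended / Warning / Critical."""
    warning_from, critical_above = THRESHOLDS[metric]
    if value < warning_from:
        return "Recommended"
    if value <= critical_above:
        return "Warning"
    return "Critical"


print(classify("pool_utilization_pct", 65))  # Recommended
print(classify("connection_wait_s", 3.2))    # Warning
print(classify("connection_errors_pct", 7))  # Critical
```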

Key Takeaways:

  • Connections are stored in the metadata database by default; use a secrets backend in production
  • Hooks abstract client libraries and provide high-level interfaces for external systems
  • Connection pools reduce overhead through connection reuse, by roughly 60-80% depending on workload
  • Use conn_id consistently across DAGs to enable credential rotation without code changes
  • Environment variables (AIRFLOW_CONN_*) provide the simplest secrets management
  • Custom hooks extend Airflow's capabilities for proprietary systems
