Building Custom Airflow Operators
Formal Definitions
Definition: Custom Operator
A custom operator is a user-defined class that extends BaseOperator to encapsulate reusable workflow logic. It must implement an execute(context) method and optionally define template_fields, template_ext, and serialization hooks for the Airflow scheduler.
Definition: Operator Template Fields
Template fields are a tuple of attribute names that Airflow renders through Jinja2 just before it calls execute(). If $T$ is the set of template fields, then for each attribute $f \in T$, its value is replaced by the result of rendering it against the task context $C$: $f \leftarrow \mathrm{render}(f, C)$.
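To make the rendering rule concrete, here is a minimal sketch of what happens to each template field. It is not Airflow's implementation: the tiny regex-based render_value stands in for Jinja2 (it only understands bare {{ name }} placeholders), and ToyOperator and render_fields are hypothetical names. Like Airflow, it recurses into container values such as lists and dicts and leaves attributes outside template_fields untouched.

```python
import re
from typing import Any, Dict

# Matches "{{ name }}" placeholders; a deliberately tiny stand-in for Jinja2.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render_value(value: Any, context: Dict[str, Any]) -> Any:
    """Substitute {{ key }} placeholders, recursing into lists and dicts."""
    if isinstance(value, str):
        return _PLACEHOLDER.sub(lambda m: str(context[m.group(1)]), value)
    if isinstance(value, list):
        return [render_value(v, context) for v in value]
    if isinstance(value, dict):
        return {k: render_value(v, context) for k, v in value.items()}
    return value  # numbers, None, etc. pass through unchanged


class ToyOperator:
    """Hypothetical operator used only to illustrate template_fields."""

    template_fields = ("table_name", "validation_rules")

    def __init__(self, table_name: str, validation_rules: Dict[str, Any], retries: int = 0):
        self.table_name = table_name
        self.validation_rules = validation_rules
        self.retries = retries  # not listed in template_fields, so never rendered

    def render_fields(self, context: Dict[str, Any]) -> None:
        # Only attributes named in template_fields are rendered, in place.
        for field in self.template_fields:
            setattr(self, field, render_value(getattr(self, field), context))


op = ToyOperator(
    table_name="orders_{{ ds_nodash }}",
    validation_rules={"row_count": {"query": "SELECT COUNT(*) FROM orders_{{ ds_nodash }}"}},
)
op.render_fields({"ds_nodash": "20240101"})
print(op.table_name)  # orders_20240101
```

In a real task the context is supplied by Airflow (ds_nodash is one of its standard keys) and rendering happens before execute() runs, so execute() only ever sees the rendered values.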

Definition: Hook Abstraction Layer
A hook provides a connection-aware client for external systems. It encapsulates connection retrieval, authentication, retry logic, and session management. Hooks decouple operators from low-level networking concerns.
Detailed Explanation
Anatomy of a Custom Operator
Every custom operator inherits from BaseOperator. The class must define template_fields for any parameter that should support Jinja templating. The execute() method receives a context dictionary and performs the operator's work.
```python
from typing import Any, Dict

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator


class DataQualityCheckOperator(BaseOperator):
    """
    Custom operator for validating data quality against defined rules.

    :param table_name: Name of the table to validate
    :param validation_rules: Mapping of rule name to {'query': ..., 'expected_value': ...}
    :param connection_id: Airflow connection ID for database access
    :param fail_threshold: Maximum allowed fraction of failed rules (0.0 to 1.0)
    """

    template_fields = ('table_name', 'validation_rules')
    template_ext = ('.sql',)
    ui_color = '#4CAF50'
    ui_fgcolor = '#FFFFFF'

    # Airflow 2 applies default_args automatically, so @apply_defaults is not needed,
    # and operators must be constructed with keyword arguments only.
    def __init__(
        self,
        table_name: str,
        validation_rules: Dict[str, Any],
        connection_id: str = 'postgres_default',
        fail_threshold: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.table_name = table_name
        self.validation_rules = validation_rules
        self.connection_id = connection_id
        self.fail_threshold = fail_threshold

    def execute(self, context: Dict[str, Any]) -> Any:
        # Template fields have already been rendered when execute() runs.
        self.log.info("Running quality checks on %s", self.table_name)
        results = self._run_checks()
        self._evaluate_results(results)
        context['ti'].xcom_push(key='quality_results', value=results)
        return results

    def _run_checks(self) -> Dict[str, Any]:
        # Import inside the method to keep DAG-file parsing fast.
        from airflow.providers.postgres.hooks.postgres import PostgresHook

        hook = PostgresHook(postgres_conn_id=self.connection_id)
        results = {}
        for rule_name, rule_config in self.validation_rules.items():
            query = rule_config['query']  # a rule without a query fails loudly
            expected = rule_config.get('expected_value')
            record = hook.get_first(query)
            actual = record[0] if record else None
            results[rule_name] = {
                'passed': actual == expected,
                'actual': actual,
                'expected': expected,
            }
        return results

    def _evaluate_results(self, results: Dict[str, Any]) -> None:
        failed = [name for name, r in results.items() if not r['passed']]
        total = len(results)
        fail_rate = len(failed) / total if total > 0 else 0.0
        if fail_rate > self.fail_threshold:
            raise AirflowException(
                f"Quality check failed: {len(failed)}/{total} rules failed. "
                f"Failures: {failed}"
            )

    def on_kill(self) -> None:
        # Called when the running task is killed; release external resources here.
        self.log.info("Operator terminated")
```
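The threshold check in _evaluate_results compares a fraction between 0.0 and 1.0, so fail_threshold has to be a fraction as well: 0.25 tolerates one failure in four rules, while a value such as 25 (meant as "25 percent") would never trigger a failure. The sketch below restates that arithmetic outside Airflow; fail_rate and should_fail are hypothetical helper names used only for illustration.

```python
from typing import Dict, List


def fail_rate(results: Dict[str, Dict[str, bool]]) -> float:
    """Fraction of rules whose 'passed' flag is False (0.0 when there are no rules)."""
    failed: List[str] = [name for name, r in results.items() if not r["passed"]]
    return len(failed) / len(results) if results else 0.0


def should_fail(results: Dict[str, Dict[str, bool]], fail_threshold: float) -> bool:
    # Mirrors the strict ">" comparison used by the operator.
    return fail_rate(results) > fail_threshold


results = {
    "row_count": {"passed": True},
    "null_ids": {"passed": False},
    "dup_keys": {"passed": True},
    "freshness": {"passed": True},
}
print(fail_rate(results))          # 0.25
print(should_fail(results, 0.0))   # True
print(should_fail(results, 0.25))  # False: a rate equal to the threshold is allowed
print(should_fail(results, 25))    # False: 25 means 2500%, not 25%
```

Because the comparison is strict, the default threshold of 0.0 means any single failed rule fails the task.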
Custom Hook Implementation
Hooks manage connection details and client lifecycle. They retrieve credentials from Airflow's connection store and provide a clean API for operators.
```python
from typing import Any, Dict, Optional

import requests
from airflow.hooks.base import BaseHook
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


class InternalAPIHook(BaseHook):
    """
    Hook for connecting to internal REST APIs.

    :param api_conn_id: Airflow connection ID
    :param timeout: Request timeout in seconds
    :param max_retries: Maximum retry attempts
    """

    conn_name_attr = 'api_conn_id'
    default_conn_name = 'internal_api_default'
    conn_type = 'http'
    hook_name = 'Internal API'

    def __init__(
        self,
        api_conn_id: str = default_conn_name,
        timeout: int = 30,
        max_retries: int = 3,
    ):
        super().__init__()
        self.api_conn_id = api_conn_id
        self.timeout = timeout
        self.max_retries = max_retries
        self._session: Optional[requests.Session] = None
        self._base_url: Optional[str] = None

    def get_conn(self) -> requests.Session:
        # Create the session lazily and reuse it for every request.
        if self._session is None:
            self._session = self._create_session()
        return self._session

    def _create_session(self) -> requests.Session:
        conn = self.get_connection(self.api_conn_id)
        session = requests.Session()
        # Retry transient failures with exponential backoff. By default urllib3
        # only retries idempotent methods, so POST requests are not retried.
        retry = Retry(
            total=self.max_retries,
            backoff_factor=0.5,
            status_forcelist=[429, 500, 502, 503, 504],
        )
        adapter = HTTPAdapter(max_retries=retry)
        session.mount('https://', adapter)
        session.mount('http://', adapter)
        token = conn.get_password()
        if token:
            session.headers['Authorization'] = f'Bearer {token}'
        session.headers['Content-Type'] = 'application/json'
        # The connection's schema field holds the URL scheme; append the port only if set.
        scheme = conn.schema or 'https'
        self._base_url = f'{scheme}://{conn.host}'
        if conn.port:
            self._base_url += f':{conn.port}'
        return session

    def request(
        self,
        method: str,
        endpoint: str,
        json: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        session = self.get_conn()
        url = f'{self._base_url}{endpoint}'
        response = session.request(
            method=method,
            url=url,
            json=json,
            params=params,
            timeout=self.timeout,
        )
        # Raise on 4xx/5xx responses that remain after retries.
        response.raise_for_status()
        # Responses such as 204 No Content have no JSON body to decode.
        return response.json() if response.content else {}

    def get(self, endpoint: str, **kwargs) -> Any:
        return self.request('GET', endpoint, **kwargs)

    def post(self, endpoint: str, data: Dict[str, Any], **kwargs) -> Any:
        return self.request('POST', endpoint, json=data, **kwargs)

    def put(self, endpoint: str, data: Dict[str, Any], **kwargs) -> Any:
        return self.request('PUT', endpoint, json=data, **kwargs)

    def delete(self, endpoint: str, **kwargs) -> Any:
        return self.request('DELETE', endpoint, **kwargs)
```
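The hook derives its base URL from the connection's schema, host, and port fields. Formatting all three into one f-string breaks when a field is empty (an unset port would render as ":None"). Below is a standalone sketch of a safer helper; build_base_url is a hypothetical name, not an Airflow API. It follows the same convention as the HTTP provider's HttpHook (the schema field carries the URL scheme, and the port is appended only when set), though defaulting to https is a choice made here for illustration.

```python
from typing import Optional


def build_base_url(
    schema: Optional[str],
    host: str,
    port: Optional[int] = None,
    default_scheme: str = "https",
) -> str:
    """Assemble scheme://host[:port] from connection fields, skipping unset parts."""
    if "://" in host:
        # The host already carries a scheme, e.g. "https://api.example.com".
        base = host.rstrip("/")
    else:
        scheme = schema or default_scheme
        base = f"{scheme}://{host}"
    if port:
        base += f":{port}"
    return base


print(build_base_url("https", "api.example.com", 443))  # https://api.example.com:443
print(build_base_url(None, "api.example.com"))          # https://api.example.com
```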
Operator with Multiple Hooks
```python
from typing import Any, Dict, List

from airflow.models import BaseOperator


class MultiSourceSyncOperator(BaseOperator):
    """
    Copy rows from one source connection into a target table, in batches.
    """

    template_fields = ('source_query', 'target_table')

    def __init__(
        self,
        source_conn_id: str,
        target_conn_id: str,
        source_query: str,
        target_table: str,
        batch_size: int = 1000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.source_conn_id = source_conn_id
        self.target_conn_id = target_conn_id
        self.source_query = source_query
        self.target_table = target_table
        self.batch_size = batch_size

    def execute(self, context: Dict[str, Any]) -> Any:
        # One hook per connection: source and target may be different databases.
        source_hook = self._get_hook(self.source_conn_id)
        target_hook = self._get_hook(self.target_conn_id)
        data = self._extract(source_hook)
        loaded = self._load(target_hook, data)
        self.log.info('Synced %d records to %s', loaded, self.target_table)
        return {'records_loaded': loaded}

    def _get_hook(self, conn_id: str) -> Any:
        # Both sides are Postgres here; return a different hook per conn_id if needed.
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=conn_id)

    def _extract(self, hook: Any) -> List[tuple]:
        # Loads the full result set into memory.
        return hook.get_records(self.source_query)

    def _load(self, hook: Any, data: List[tuple]) -> int:
        count = 0
        for i in range(0, len(data), self.batch_size):
            batch = data[i:i + self.batch_size]
            hook.insert_rows(table=self.target_table, rows=batch)
            count += len(batch)
        return count
```
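_load reaches its batch boundaries by slicing a list that _extract has already loaded in full. When the result set is large, the same boundaries can come from a lazy batcher fed by any iterator, for example a database cursor that supports iteration. This is a sketch under that assumption; iter_batches is a hypothetical helper, not part of Airflow.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_batches(rows: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield lists of at most batch_size items without materializing every row."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


# Same batch sizes as the slicing loop, but the input can be a one-pass generator.
rows = ((i, f"name_{i}") for i in range(5))
print([len(b) for b in iter_batches(rows, 2)])  # [2, 2, 1]
```

Each yielded batch could then be passed to hook.insert_rows exactly as in _load.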
Serialization for Dynamic DAGs
```python
from typing import Any, Dict

from airflow.models import BaseOperator


class SerializableOperator(BaseOperator):
    """
    Custom operator with explicit serialize/deserialize helpers for dynamic DAGs.

    These helpers belong to this class; Airflow's built-in DAG serialization
    does not rely on them.
    """

    template_fields = ('message',)

    def __init__(self, message: str = 'hello', **kwargs):
        super().__init__(**kwargs)
        self.message = message

    def execute(self, context: Dict[str, Any]) -> None:
        self.log.info(self.message)

    def serialize(self) -> Dict[str, Any]:
        # Keep every value JSON-safe so the payload can be stored or transmitted.
        return {
            'class': self.__class__.__name__,
            'module': self.__class__.__module__,
            'kwargs': {
                'message': self.message,
                'task_id': self.task_id,
                'retries': self.retries,
            },
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any], **kwargs) -> 'SerializableOperator':
        # Merge so explicit kwargs override stored ones instead of raising
        # "got multiple values for keyword argument".
        return cls(**{**data['kwargs'], **kwargs})
```
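A serialize/deserialize pair like this only round-trips if the payload is JSON-safe and if deserialize merges extra keyword arguments instead of passing them alongside the stored ones (a duplicate key such as task_id or retries would otherwise raise TypeError). The sketch below checks both properties without Airflow; EchoTask is a hypothetical stand-in with the same method shapes as SerializableOperator.

```python
import json
from typing import Any, Dict


class EchoTask:
    """Hypothetical stand-in with the same serialize/deserialize shape as the operator."""

    def __init__(self, task_id: str, message: str = "hello", retries: int = 0):
        self.task_id = task_id
        self.message = message
        self.retries = retries

    def serialize(self) -> Dict[str, Any]:
        return {
            "class": type(self).__name__,
            "module": type(self).__module__,
            "kwargs": {"task_id": self.task_id, "message": self.message, "retries": self.retries},
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any], **overrides: Any) -> "EchoTask":
        # Overrides win over stored kwargs rather than colliding with them.
        return cls(**{**data["kwargs"], **overrides})


original = EchoTask(task_id="greet", message="hi", retries=2)
payload = json.dumps(original.serialize())  # fails fast if anything is not JSON-safe
restored = EchoTask.deserialize(json.loads(payload), retries=3)
print(restored.task_id, restored.message, restored.retries)  # greet hi 3
```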
Key Concepts Table
| Component | Purpose | Required? | Example |
|---|---|---|---|
| BaseOperator | Parent class for all operators | Yes | class MyOp(BaseOperator) |
| execute() | Core logic entry point | Yes | def execute(self, context) |
| template_fields | Attributes rendered with Jinja2 | Recommended | template_fields = ('query',) |
| template_ext | File extensions whose contents are loaded and rendered as templates | Optional | template_ext = ('.sql',) |
| ui_color | Task color in the Airflow UI | Optional | ui_color = '#4CAF50' |
| on_kill() | Cleanup when the running task is killed | Optional | def on_kill(self) |
| Hook | External connection management | When needed | class MyHook(BaseHook) |
| Provider Package | Distribution packaging | Optional | apache-airflow-providers-x |
Code Examples
Testing Custom Operators
```python
from unittest.mock import MagicMock, patch

import pytest
from airflow.exceptions import AirflowException
from airflow.providers.postgres.hooks.postgres import PostgresHook

# Assumed module path: import the operator from wherever it is defined.
from custom_provider.operators import DataQualityCheckOperator


class TestDataQualityCheckOperator:
    """Unit tests for the custom DataQualityCheckOperator."""

    def test_execute_passes_all_rules(self):
        # Patch the hook method so no database is contacted.
        with patch.object(PostgresHook, 'get_first', return_value=(42,)):
            op = DataQualityCheckOperator(
                task_id='test_quality',
                table_name='orders',
                validation_rules={
                    'row_count': {
                        'query': 'SELECT COUNT(*) FROM orders',
                        'expected_value': 42,
                    }
                },
                connection_id='test_postgres',
            )
            context = {'ti': MagicMock(), 'ds': '2024-01-01'}
            result = op.execute(context)
        assert result['row_count']['passed'] is True
        context['ti'].xcom_push.assert_called_once_with(key='quality_results', value=result)

    def test_execute_fails_on_threshold(self):
        # Both rules return 0 instead of 1, so the fail rate (1.0) exceeds 0.0.
        with patch.object(PostgresHook, 'get_first', return_value=(0,)):
            op = DataQualityCheckOperator(
                task_id='test_quality',
                table_name='orders',
                validation_rules={
                    'check1': {'query': 'SELECT 1', 'expected_value': 1},
                    'check2': {'query': 'SELECT 1', 'expected_value': 1},
                },
                connection_id='test_postgres',
                fail_threshold=0.0,
            )
            context = {'ti': MagicMock(), 'ds': '2024-01-01'}
            with pytest.raises(AirflowException):
                op.execute(context)

    def test_template_fields_declared(self):
        op = DataQualityCheckOperator(
            task_id='test',
            table_name='{{ ds_nodash }}',
            validation_rules={},
        )
        # The raw template is stored on the instance until Airflow renders it.
        assert 'table_name' in op.template_fields
        assert op.table_name == '{{ ds_nodash }}'
```
Testing Custom Hooks
```python
from unittest.mock import MagicMock, patch

from airflow.models import Connection

# Assumed module path for the hook defined above.
from custom_provider.hooks import InternalAPIHook


class TestInternalAPIHook:
    """Unit tests for the custom hook."""

    def test_get_conn_returns_session(self):
        mock_conn = Connection(
            conn_id='test_api',
            conn_type='http',
            host='api.example.com',
            port=443,
            schema='https',
            password='test-token',
        )
        with patch.object(InternalAPIHook, 'get_connection', return_value=mock_conn):
            hook = InternalAPIHook(api_conn_id='test_api')
            session = hook.get_conn()
        assert session.headers['Authorization'] == 'Bearer test-token'
        assert hook._base_url == 'https://api.example.com:443'

    def test_get_uses_get_method(self):
        hook = InternalAPIHook(api_conn_id='test')
        # Pre-seed the session and base URL so no connection lookup happens.
        hook._session = MagicMock()
        hook._base_url = 'https://api.example.com'
        hook.get('/health')
        hook._session.request.assert_called_once()
        call_kwargs = hook._session.request.call_args.kwargs
        assert call_kwargs['method'] == 'GET'
        assert call_kwargs['url'] == 'https://api.example.com/health'
```
Provider Package Structure
```python
# setup.py for a custom provider package
from setuptools import find_packages, setup

setup(
    name='apache-airflow-providers-custom',
    version='1.0.0',
    description='Custom Airflow provider for internal APIs',
    packages=find_packages(),
    install_requires=[
        'apache-airflow>=2.6.0',
        'requests>=2.28.0',
    ],
    entry_points={
        # Airflow discovers installed providers through this entry-point group.
        'apache_airflow_provider': [
            'provider_info = custom_provider:get_provider_info',
        ],
    },
)
```

```python
# custom_provider/__init__.py
def get_provider_info():
    # Paths assume custom_provider/hooks.py and custom_provider/operators.py.
    return {
        'package-name': 'apache-airflow-providers-custom',
        'name': 'Custom Provider',
        'description': 'Custom operators and hooks',
        'version': '1.0.0',
        'hook-class-names': [
            'custom_provider.hooks.InternalAPIHook',
        ],
        'operator-modules': [
            'custom_provider.operators',
        ],
    }
```
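After pip-installing the package, a quick way to confirm the entry point was registered is to list the apache_airflow_provider group with the standard library. This is a sketch; provider_entry_points is a hypothetical helper, and the version check exists because entry_points(group=...) is only available from Python 3.10 (earlier versions return a dict keyed by group name).

```python
import sys
from importlib.metadata import entry_points
from typing import List

PROVIDER_GROUP = "apache_airflow_provider"


def provider_entry_points() -> List[str]:
    """Return 'name = value' strings for every installed provider_info entry point."""
    if sys.version_info >= (3, 10):
        eps = entry_points(group=PROVIDER_GROUP)
    else:  # Python 3.8/3.9: entry_points() returns a dict of group -> entry points
        eps = entry_points().get(PROVIDER_GROUP, [])
    return [f"{ep.name} = {ep.value}" for ep in eps]


for line in provider_entry_points():
    print(line)
```

If the custom package shows up here, running airflow providers list should include it as well.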
Performance Metrics
| Metric | Description | Optimization Strategy |
|---|---|---|
| Hook connection time | Time to establish connection | Connection pooling, session reuse |
| Operator parse time | DAG file parse duration | Minimize imports at module level |
| Serialization overhead | Time for dynamic DAG serialization | Cache serialized operators |
| Template render time | Jinja2 rendering duration | Avoid expensive expressions in templates |
| Memory footprint | RAM per operator instance | Keep large payloads off the operator instance; stream or batch data inside execute() |
Best Practices
- Idempotency: Ensure operators are safe to retry. Use unique identifiers, upsert operations, and atomic transactions.
- Single Responsibility: Each operator should do one thing well. Decompose complex workflows into multiple tasks.
- Template Fields: Always declare template_fields for parameters that should accept Jinja expressions.
- Error Handling: Raise AirflowException for failures that should be retried, AirflowFailException for permanent failures that should not be retried, and AirflowSkipException when there is no data to process.
- Logging: Use self.log.info() and self.log.error() instead of print() so output goes through Airflow's task logging.
- Type Hints: Use Python type hints for better IDE support and documentation.
- Documentation: Write comprehensive docstrings with parameter descriptions and usage examples.
- Testing: Mock external dependencies and test both success and failure paths.
- Connection Management: Always use Airflow's connection system. Never hardcode credentials.
- Package as Provider: Distribute custom operators as Airflow provider packages for reuse across teams.
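The idempotency advice can be illustrated with an upsert. This sketch uses the standard library's sqlite3 as a stand-in for the target database (its INSERT ... ON CONFLICT ... DO UPDATE form requires SQLite 3.24 or newer; PostgreSQL accepts the same clause). The table and load_daily_totals are hypothetical; the point is that running the same load twice, as a retry would, leaves one row per key.

```python
import sqlite3
from typing import Iterable, Tuple


def load_daily_totals(conn: sqlite3.Connection, rows: Iterable[Tuple[str, int]]) -> None:
    """Upsert (day, total) rows so re-running the same load leaves one row per day."""
    with conn:  # commits on success, rolls back if an exception escapes
        conn.executemany(
            "INSERT INTO daily_totals (day, total) VALUES (?, ?) "
            "ON CONFLICT(day) DO UPDATE SET total = excluded.total",
            rows,
        )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE daily_totals (day TEXT PRIMARY KEY, total INTEGER)")
batch = [("2024-01-01", 42), ("2024-01-02", 17)]
load_daily_totals(conn, batch)
load_daily_totals(conn, batch)  # simulated retry of the same task run
print(conn.execute("SELECT COUNT(*) FROM daily_totals").fetchone()[0])  # 2
```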
When designing custom operators, follow the "operator → hook → connection" pattern. The operator handles business logic, the hook manages connection details, and the connection stores credentials in the metadata database or a configured secrets backend.
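The division of responsibilities in that pattern can be sketched without Airflow. Everything below is hypothetical: CONNECTIONS stands in for the metadata database or secrets backend, ToyHook for a real hook, and ToySyncOperator for an operator. The operator only knows a conn_id; the hook resolves the stored connection and builds credentials from it.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Conn:
    """Hypothetical connection record, like one entry in Airflow's connection store."""
    conn_id: str
    host: str
    password: Optional[str] = None


# Hypothetical in-memory stand-in for the metadata database / secrets backend.
CONNECTIONS: Dict[str, Conn] = {
    "internal_api_default": Conn("internal_api_default", "api.example.com", "s3cr3t"),
}


class ToyHook:
    """Resolves a connection by conn_id and turns it into client settings."""

    def __init__(self, conn_id: str):
        self.conn_id = conn_id

    def get_conn(self) -> Conn:
        return CONNECTIONS[self.conn_id]

    def auth_header(self) -> Dict[str, str]:
        token = self.get_conn().password
        return {"Authorization": f"Bearer {token}"} if token else {}


class ToySyncOperator:
    """Business logic only: holds a conn_id and delegates connection details to the hook."""

    def __init__(self, conn_id: str):
        self.conn_id = conn_id

    def execute(self, context: Dict[str, Any]) -> Dict[str, Any]:
        hook = ToyHook(self.conn_id)
        headers = hook.auth_header()
        url = f"https://{hook.get_conn().host}/runs/{context['ds']}"
        return {"url": url, "authenticated": bool(headers)}


print(ToySyncOperator("internal_api_default").execute({"ds": "2024-01-01"}))
```

Swapping the credential store or the client library changes only the hook; the operator and the DAG that uses it stay the same.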
Custom operators can be deployed as provider packages (pip-installable) or placed directly in the dags/ folder. Provider packages are recommended for production as they support versioning, testing, and cross-DAG reuse.
Key Takeaways:
- Custom operators extend BaseOperator and implement execute(context)
- Template fields enable Jinja2 rendering for dynamic parameters
- Hooks abstract connection management and should be used for external service access
- Always use Airflow's connection system for credentials — never hardcode
- Distribute custom operators as provider packages for team-wide reuse
- Test operators with mocked hooks and validate both success and failure paths
See Also
- Airflow Architecture — Core architecture and component overview
- Operators and Hooks — Built-in operators and hook patterns
- Sensors and Operators — Sensor-based operators and poke modes
- Branching Logic — BranchPythonOperator and conditional workflows
- XCom Communications — Task communication and data passing