
Building Custom Airflow Operators

Formal Definitions

Definition: Custom Operator

A custom operator is a user-defined class $O_{\text{custom}}$ that extends BaseOperator to encapsulate reusable workflow logic. It must implement an execute(context) method. It can optionally define class-level attributes such as template_fields, template_ext, and ui_color, which Airflow reads when it parses and serializes the DAG.

Definition: Operator Template Fields

Template fields are a tuple of attribute names whose values Airflow renders through Jinja2 before execute() is called. Let $T = (t_1, t_2, \ldots, t_n)$ be the tuple of template fields. For each $t_i \in T$, the attribute's value is rendered against the task's runtime context: $v_{t_i} = \text{render}(t_i, \text{context})$.
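The substitution rule can be modeled in a few lines. This is a toy sketch, not Airflow's implementation. Airflow renders with a full Jinja2 environment (filters, macros, template_ext file loading), while the hypothetical `render` below only replaces `{{ name }}` placeholders.

What the sketch does capture:
- Rendering walks each attribute named in template_fields.
- It recurses into nested dicts and lists.
- It writes the result back onto the object before execution.

`apply_template_fields` and `ToyOperator` are also illustrative names.

```python
import re
from typing import Any, Dict

# Toy stand-in for Jinja2: only "{{ name }}" placeholders are supported.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render(value: Any, context: Dict[str, Any]) -> Any:
    """Recursively render strings nested inside dicts, lists, and tuples."""
    if isinstance(value, str):
        return _PLACEHOLDER.sub(lambda m: str(context[m.group(1)]), value)
    if isinstance(value, dict):
        return {k: render(v, context) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(render(v, context) for v in value)
    return value  # ints, None, etc. pass through unchanged


def apply_template_fields(obj: Any, context: Dict[str, Any]) -> None:
    """Apply v_{t_i} = render(t_i, context) to every attribute named in obj.template_fields."""
    for field in obj.template_fields:
        setattr(obj, field, render(getattr(obj, field), context))


class ToyOperator:
    template_fields = ("table_name", "validation_rules")

    def __init__(self, table_name: str, validation_rules: Dict[str, Any]):
        self.table_name = table_name
        self.validation_rules = validation_rules


op = ToyOperator(
    "orders_{{ ds_nodash }}",
    {"rows": {"query": "SELECT COUNT(*) FROM orders_{{ ds_nodash }}", "expected_value": 42}},
)
apply_template_fields(op, {"ds_nodash": "20240101"})
print(op.table_name)  # orders_20240101
```

Non-string values such as `expected_value` are left untouched, just as integers pass through Airflow's renderer unchanged.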

Definition: Hook Abstraction Layer

A hook $H$ provides a connection-aware client for an external system. It encapsulates connection retrieval, authentication, retry logic, and session management. This decouples operators from low-level networking concerns.

Detailed Explanation

Anatomy of a Custom Operator

Every custom operator inherits from BaseOperator. The class must define template_fields for any parameter that should support Jinja templating. The execute() method receives a context dictionary and performs the operator's work.

from typing import Any, Dict

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator


class DataQualityCheckOperator(BaseOperator):
    """
    Custom operator for validating data quality against defined rules.

    :param table_name: Name of the table to validate (templated)
    :param validation_rules: Mapping of rule name to {'query': ..., 'expected_value': ...} (templated)
    :param connection_id: Airflow connection ID for database access
    :param fail_threshold: Maximum allowed fraction of failed rules, from 0.0 to 1.0
    """

    template_fields = ('table_name', 'validation_rules')
    # Templated strings ending in .sql are treated as paths to template files
    template_ext = ('.sql',)
    ui_color = '#4CAF50'
    ui_fgcolor = '#FFFFFF'

    # No @apply_defaults: since Airflow 2.0, BaseOperator applies default_args automatically.
    # Operators must be constructed with keyword arguments, so *args is not needed.
    def __init__(
        self,
        *,
        table_name: str,
        validation_rules: Dict[str, Any],
        connection_id: str = 'postgres_default',
        fail_threshold: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.table_name = table_name
        self.validation_rules = validation_rules
        self.connection_id = connection_id
        self.fail_threshold = fail_threshold

    def execute(self, context: Dict[str, Any]) -> Any:
        self.log.info(f"Running quality checks on {self.table_name}")
        results = self._run_checks()
        self._evaluate_results(results)
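        # Explicit push under a named key; the return value is also pushed as the
        # 'return_value' XCom because do_xcom_push defaults to True
        context['ti'].xcom_push(key='quality_results', value=results)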
        return results

    def _run_checks(self) -> Dict[str, Any]:
        from airflow.providers.postgres.hooks.postgres import PostgresHook

        hook = PostgresHook(postgres_conn_id=self.connection_id)
        results = {}

        for rule_name, rule_config in self.validation_rules.items():
            query = rule_config.get('query')
            expected = rule_config.get('expected_value')

            records = hook.get_first(query)
            actual = records[0] if records else None

            results[rule_name] = {
                'passed': actual == expected,
                'actual': actual,
                'expected': expected,
            }

        return results

    def _evaluate_results(self, results: Dict[str, Any]) -> None:
        failed = [r for r, v in results.items() if not v['passed']]
        total = len(results)
        fail_rate = len(failed) / total if total > 0 else 0

        if fail_rate > self.fail_threshold:
            raise AirflowException(
                f"Quality check failed: {len(failed)}/{total} rules failed. "
                f"Failures: {failed}"
            )

    def on_kill(self) -> None:
        self.log.info("Operator terminated")
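`fail_threshold` is compared against a fraction of failed rules using a strict `>`, so it helps to check the arithmetic in isolation. The helper `exceeds_threshold` below is hypothetical. It reproduces the decision made in `_evaluate_results` as a pure function, which makes the boundary cases easy to see.

```python
from typing import Any, Dict


def exceeds_threshold(results: Dict[str, Dict[str, Any]], fail_threshold: float) -> bool:
    """Same decision as DataQualityCheckOperator._evaluate_results, returned instead of raised."""
    failed = [name for name, outcome in results.items() if not outcome["passed"]]
    total = len(results)
    fail_rate = len(failed) / total if total > 0 else 0
    # Strict ">" means a fail rate exactly equal to the threshold is tolerated.
    return fail_rate > fail_threshold


# Four rules, one failing: fail_rate = 1/4 = 0.25
results = {
    "row_count": {"passed": True},
    "no_nulls": {"passed": True},
    "unique_ids": {"passed": False},
    "fresh_data": {"passed": True},
}
print(exceeds_threshold(results, 0.25))  # False
print(exceeds_threshold(results, 0.2))   # True
```

With one of four rules failing:
- `fail_threshold=0.25` lets the task succeed, while `0.2` fails it.
- The default `0.0` fails the task as soon as any single rule fails.
- An empty rule set never fails.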

Custom Hook Implementation

Hooks manage connection details and client lifecycle. They retrieve credentials from Airflow's connection store and provide a clean API for operators.

from typing import Any, Dict, Optional
from airflow.hooks.base import BaseHook
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


class InternalAPIHook(BaseHook):
    """
    Hook for connecting to internal REST APIs.

    :param api_conn_id: Airflow connection ID
    :param timeout: Request timeout in seconds
    :param max_retries: Maximum retry attempts
    """

    conn_name_attr = 'api_conn_id'
    default_conn_name = 'internal_api_default'
    conn_type = 'http'
    hook_name = 'Internal API'

    def __init__(
        self,
        api_conn_id: str = default_conn_name,
        timeout: int = 30,
        max_retries: int = 3,
    ):
        super().__init__()
        self.api_conn_id = api_conn_id
        self.timeout = timeout
        self.max_retries = max_retries
        self._session: Optional[requests.Session] = None
        self._base_url: Optional[str] = None  # set in _create_session()

    def get_conn(self) -> requests.Session:
        if self._session is None:
            self._session = self._create_session()
        return self._session

    def _create_session(self) -> requests.Session:
        conn = self.get_connection(self.api_conn_id)
        session = requests.Session()

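        # Retry throttling and transient server errors with exponential backoff.
        # By default urllib3 retries only idempotent methods, so POST is not retried.
        retry = Retry(
            total=self.max_retries,
            backoff_factor=0.5,
            status_forcelist=[429, 500, 502, 503, 504],
        )
        adapter = HTTPAdapter(max_retries=retry)
        session.mount('https://', adapter)
        session.mount('http://', adapter)

        # The connection's password field holds the bearer token
        token = conn.get_password()
        if token:
            session.headers['Authorization'] = f'Bearer {token}'
        session.headers['Content-Type'] = 'application/json'

        # As with Airflow HTTP connections, the schema field holds the protocol
        # (defaulting to https here); omit the port when the connection has none
        scheme = conn.schema or 'https'
        netloc = f'{conn.host}:{conn.port}' if conn.port else conn.host
        self._base_url = f'{scheme}://{netloc}'
        return session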

    def request(
        self,
        method: str,
        endpoint: str,
        json: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        session = self.get_conn()
        url = f'{self._base_url}{endpoint}'

        response = session.request(
            method=method,
            url=url,
            json=json,
            params=params,
            timeout=self.timeout,
        )
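        response.raise_for_status()
        # DELETE (and some PUT) endpoints reply 204 No Content; avoid a JSON decode error
        if response.status_code == 204 or not response.content:
            return {}
        return response.json()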

    def get(self, endpoint: str, **kwargs) -> Dict[str, Any]:
        return self.request('GET', endpoint, **kwargs)

    def post(self, endpoint: str, data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self.request('POST', endpoint, json=data, **kwargs)

    def put(self, endpoint: str, data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return self.request('PUT', endpoint, json=data, **kwargs)

    def delete(self, endpoint: str, **kwargs) -> Dict[str, Any]:
        return self.request('DELETE', endpoint, **kwargs)
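The hook's `backoff_factor=0.5` translates into concrete waits between retries. The hypothetical function below mirrors the formula urllib3 documents for `Retry`: the first retry does not sleep, and later ones sleep `backoff_factor * 2 ** (retry_number - 1)` seconds.

It deliberately leaves out several refinements:
- urllib3's upper cap on backoff.
- The optional jitter added in urllib3 2.x.
- Honoring `Retry-After` headers on 429/503 responses.

Treat the result as an approximation for your installed urllib3 version.

```python
from typing import List


def backoff_schedule(backoff_factor: float, retries: int) -> List[float]:
    """Approximate sleep (seconds) before each retry, per urllib3's documented formula.

    Ignores the backoff cap, jitter, and Retry-After handling.
    """
    delays = []
    for attempt in range(1, retries + 1):
        # First retry is immediate; afterwards the delay doubles each time
        delays.append(0.0 if attempt <= 1 else backoff_factor * 2 ** (attempt - 1))
    return delays


# InternalAPIHook defaults: backoff_factor=0.5, max_retries=3
print(backoff_schedule(0.5, 3))  # [0.0, 1.0, 2.0]
```

With the defaults, a request that keeps failing waits about three seconds in total across its retries. The per-request `timeout=30` applies to each attempt separately.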

Operator with Multiple Hooks

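from typing import Any, Dict, List

from airflow.models import BaseOperator


class MultiSourceSyncOperator(BaseOperator):
    """
    Copy the rows returned by a query on a source connection into a table on a target connection.

    :param source_conn_id: Connection ID of the source database
    :param target_conn_id: Connection ID of the target database
    :param source_query: SQL query whose rows are copied (templated)
    :param target_table: Destination table name (templated)
    :param batch_size: Number of rows passed to each insert_rows() call
    """

    template_fields = ('source_query', 'target_table')

    # No @apply_defaults or *args: Airflow 2 handles default_args and requires keyword arguments
    def __init__(
        self,
        *,
        source_conn_id: str,
        target_conn_id: str,
        source_query: str,
        target_table: str,
        batch_size: int = 1000,
        **kwargs,
    ):
        super().__init__(**kwargs)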
        self.source_conn_id = source_conn_id
        self.target_conn_id = target_conn_id
        self.source_query = source_query
        self.target_table = target_table
        self.batch_size = batch_size

    def execute(self, context: Dict[str, Any]) -> Any:
        source_hook = self._get_hook(self.source_conn_id)
        target_hook = self._get_hook(self.target_conn_id)

        data = self._extract(source_hook)
        loaded = self._load(target_hook, data)

        self.log.info(f'Synced {loaded} records to {self.target_table}')
        return {'records_loaded': loaded}

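    def _get_hook(self, conn_id: str) -> Any:
        # Deferred import keeps DAG parsing fast; both sides are Postgres in this example
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=conn_id)

    def _extract(self, hook: Any) -> List[tuple]:
        # get_records() materializes the whole result set in worker memory
        return hook.get_records(self.source_query)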

    def _load(self, hook: Any, data: List[tuple]) -> int:
        count = 0
        for i in range(0, len(data), self.batch_size):
            batch = data[i:i + self.batch_size]
            hook.insert_rows(table=self.target_table, rows=batch)
            count += len(batch)
        return count
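`_extract` calls `get_records`, which pulls the entire result set into worker memory before `_load` starts batching. So `batch_size` bounds the size of each insert but not the operator's memory use. One alternative is to stream from a DB-API cursor with `fetchmany`.

The sketch below uses the standard-library `sqlite3` module as a stand-in for the Postgres hooks; `iter_batches` and `sync` are hypothetical names. With a real hook you would take a cursor from `hook.get_conn().cursor()`. With psycopg2, only a named (server-side) cursor avoids buffering the full result on the client.

```python
import sqlite3
from typing import Iterator, List, Tuple


def iter_batches(cursor: sqlite3.Cursor, batch_size: int) -> Iterator[List[Tuple]]:
    """Yield rows in chunks via fetchmany instead of materializing the full result."""
    while True:
        batch = cursor.fetchmany(batch_size)
        if not batch:
            return
        yield batch


def sync(source: sqlite3.Connection, target: sqlite3.Connection,
         query: str, target_table: str, batch_size: int = 1000) -> int:
    count = 0
    cursor = source.execute(query)
    for batch in iter_batches(cursor, batch_size):
        placeholders = ", ".join("?" * len(batch[0]))
        # target_table is interpolated directly: only pass trusted identifiers
        target.executemany(f"INSERT INTO {target_table} VALUES ({placeholders})", batch)
        count += len(batch)
    target.commit()
    return count


src = sqlite3.connect(":memory:")
src.execute("CREATE TABLE orders (id INTEGER, amount REAL)")
src.executemany("INSERT INTO orders VALUES (?, ?)", [(i, i * 1.5) for i in range(2500)])
dst = sqlite3.connect(":memory:")
dst.execute("CREATE TABLE orders_copy (id INTEGER, amount REAL)")
print(sync(src, dst, "SELECT id, amount FROM orders", "orders_copy", batch_size=1000))  # 2500
```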

Serialization for Dynamic DAGs

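from typing import Any, Dict

from airflow.models import BaseOperator


class SerializableOperator(BaseOperator):
    """
    Custom operator that can export its constructor arguments and be rebuilt from them,
    e.g. when dynamic DAGs are generated from stored configuration.

    serialize() and deserialize() are a project-level convention: Airflow's own DAG
    serialization does not call them. Airflow serializes template_fields and the
    built-in BaseOperator attributes automatically.
    """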

    template_fields = ('message',)

    def __init__(self, message: str = 'hello', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.message = message

    def execute(self, context: Dict[str, Any]) -> None:
        self.log.info(self.message)

    def serialize(self) -> Dict[str, Any]:
        return {
            'class': self.__class__.__name__,
            'module': self.__class__.__module__,
            'kwargs': {
                'message': self.message,
                'task_id': self.task_id,
                'retries': self.retries,
            },
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any], **kwargs) -> 'SerializableOperator':
        return cls(**data['kwargs'], **kwargs)
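`serialize()` records `module` and `class`, but `deserialize()` ignores them because it is bound to a single class. It also raises a `TypeError` if an override repeats a stored key such as `task_id`.

A generic loader, such as the hypothetical `build_from_spec` below, can use those keys to rebuild any operator type from stored config. It lets overrides replace stored arguments. It is shown with a standard-library class so it runs without Airflow. In a DAG file, the spec would come from `SerializableOperator.serialize()` or a config file.

```python
import importlib
from datetime import timedelta
from typing import Any, Dict


def build_from_spec(spec: Dict[str, Any], **overrides: Any) -> Any:
    """Re-create an object from a {'module', 'class', 'kwargs'} dict.

    Only load specs from trusted sources: this imports whatever module the spec names.
    """
    module = importlib.import_module(spec["module"])
    cls = getattr(module, spec["class"])
    # Overrides win over stored kwargs (e.g. a fresh task_id per generated DAG)
    return cls(**{**spec["kwargs"], **overrides})


spec = {"module": "datetime", "class": "timedelta", "kwargs": {"hours": 1}}
print(build_from_spec(spec))               # 1:00:00
print(build_from_spec(spec, minutes=30))   # 1:30:00
```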

Key Concepts Table

| Component | Purpose | Required? | Example |
| --- | --- | --- | --- |
| BaseOperator | Parent class for all operators | Yes | class MyOp(BaseOperator) |
| execute() | Core logic entry point | Yes | def execute(self, context) |
| template_fields | Jinja-rendered attributes | Recommended | template_fields = ('query',) |
| template_ext | External file templates | Optional | template_ext = ('.sql',) |
| ui_color | DAG visualization color | Optional | ui_color = '#4CAF50' |
| on_kill() | Cleanup on termination | Optional | def on_kill(self) |
| Hook | External connection management | When needed | class MyHook(BaseHook) |
| Provider package | Distribution packaging | Optional | apache-airflow-providers-x |

Code Examples

Testing Custom Operators

from unittest.mock import MagicMock, patch

import pytest

# Assumes DataQualityCheckOperator is importable from your plugins or provider package


class TestDataQualityCheckOperator:
    """Unit tests for the custom DataQualityCheckOperator."""

    def test_execute_passes_all_rules(self):
        from airflow.providers.postgres.hooks.postgres import PostgresHook

        with patch.object(PostgresHook, 'get_first', return_value=(42,)):
            op = DataQualityCheckOperator(
                task_id='test_quality',
                table_name='orders',
                validation_rules={
                    'row_count': {
                        'query': 'SELECT COUNT(*) FROM orders',
                        'expected_value': 42,
                    }
                },
                connection_id='test_postgres',
            )
            context = {'ti': MagicMock(), 'ds': '2024-01-01'}
            result = op.execute(context)

            assert result['row_count']['passed'] is True

    def test_execute_fails_on_threshold(self):
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        from airflow.exceptions import AirflowException

        with patch.object(PostgresHook, 'get_first', return_value=(0,)):
            op = DataQualityCheckOperator(
                task_id='test_quality',
                table_name='orders',
                validation_rules={
                    'check1': {'query': 'SELECT 1', 'expected_value': 1},
                    'check2': {'query': 'SELECT 1', 'expected_value': 1},
                },
                connection_id='test_postgres',
                fail_threshold=0.0,
            )
            context = {'ti': MagicMock(), 'ds': '2024-01-01'}

            with pytest.raises(AirflowException):
                op.execute(context)

    def test_template_fields_declared(self):
        op = DataQualityCheckOperator(
            task_id='test',
            table_name='{{ ds_nodash }}',
            validation_rules={},
        )
        # template_fields lists attribute names; values stay unrendered until the task runs
        assert 'table_name' in op.template_fields
        assert op.table_name == '{{ ds_nodash }}'

Testing Custom Hooks

from unittest.mock import MagicMock, patch

# Assumes InternalAPIHook is importable from your plugins or provider package


class TestInternalAPIHook:
    """Unit tests for the custom hook."""

    def test_get_conn_returns_session(self):
        from airflow.models import Connection

        mock_conn = Connection(
            conn_id='test_api',
            conn_type='http',
            host='api.example.com',
            port=443,
            schema='https',
            password='test-token',
        )

        with patch.object(InternalAPIHook, 'get_connection', return_value=mock_conn):
            hook = InternalAPIHook(api_conn_id='test_api')
            session = hook.get_conn()

            assert session is not None
            assert 'Authorization' in session.headers

    def test_get_uses_get_method(self):
        hook = InternalAPIHook(api_conn_id='test')
        # Pre-seed the session and base URL so no connection lookup happens
        hook._session = MagicMock()
        hook._base_url = 'https://api.example.com'

        hook.get('/health')

        hook._session.request.assert_called_once()
        call_kwargs = hook._session.request.call_args.kwargs
        assert call_kwargs['method'] == 'GET'
        assert call_kwargs['url'] == 'https://api.example.com/health'

Provider Package Structure

# setup.py for custom provider package
from setuptools import setup, find_packages

setup(
    name='apache-airflow-providers-custom',
    version='1.0.0',
    description='Custom Airflow provider for internal APIs',
    packages=find_packages(),
    install_requires=[
        'apache-airflow>=2.6.0',
        'requests>=2.28.0',
    ],
    entry_points={
        'apache_airflow_provider': [
            'provider_info = custom_provider.__init__:get_provider_info',
        ],
    },
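)


# custom_provider/__init__.py
def get_provider_info():
    # Verify these keys against Airflow's provider-info schema for your Airflow version
    # (newer 2.x releases prefer 'connection-types' over the deprecated 'hook-class-names')
    return {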
        'package-name': 'apache-airflow-providers-custom',
        'name': 'Custom Provider',
        'description': 'Custom operators and hooks',
        'version': '1.0.0',
        'hook-class-names': [
            'custom_provider.hooks.InternalAPIHook',
        ],
        'operator-modules': [
            'custom_provider.operators',
        ],
    }
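After `pip install`, you can confirm that the entry point is registered by listing the `apache_airflow_provider` group, which is the group Airflow's providers manager scans. The function name below is hypothetical, and it uses only `importlib.metadata` from the standard library.

```python
from importlib.metadata import entry_points
from typing import List


def find_provider_entry_points(group: str = "apache_airflow_provider") -> List[str]:
    """Return 'name = module:attr' strings for every installed entry point in `group`."""
    eps = entry_points()
    if hasattr(eps, "select"):  # Python 3.10+
        selected = eps.select(group=group)
    else:  # Python 3.8/3.9 return a dict keyed by group name
        selected = eps.get(group, [])
    return sorted(f"{ep.name} = {ep.value}" for ep in selected)


for line in find_provider_entry_points():
    print(line)
```

Once discovery works, `airflow providers list` should also show the package.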

Performance Metrics

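| Metric | Description | Optimization strategy |
| --- | --- | --- |
| Hook connection time | Time to establish a connection | Connection pooling, session reuse |
| Operator parse time | DAG file parse duration | Minimize module-level imports; import heavy libraries inside execute() or helper methods |
| Serialization overhead | Time for dynamic DAG serialization | Cache serialized operators |
| Template render time | Jinja2 rendering duration | Avoid expensive expressions in templates |
| Memory footprint | RAM per operator instance | Keep large data out of operator attributes and load it inside execute() (__slots__ saves little, since BaseOperator instances already carry a __dict__) |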

Best Practices

  1. Idempotency: Ensure operators are safe to retry. Use unique identifiers, upsert operations, and atomic transactions.
  2. Single Responsibility: Each operator should do one thing well. Decompose complex workflows into multiple tasks.
  3. Template Fields: Always declare template_fields for parameters that should accept Jinja expressions.
  4. Error Handling: Use AirflowException for retryable errors, AirflowFailException for permanent failures, and AirflowSkipException for no-data scenarios.
  5. Logging: Use self.log.info() and self.log.error() instead of print() for structured output.
  6. Type Hints: Use Python type hints for better IDE support and documentation.
  7. Documentation: Write comprehensive docstrings with parameter descriptions and usage examples.
  8. Testing: Mock external dependencies and test both success and failure paths.
  9. Connection Management: Always use Airflow's connection system. Never hardcode credentials.
  10. Package as Provider: Distribute custom operators as Airflow provider packages for reuse across teams.
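As an illustration of the idempotency guideline, the sketch below makes a daily load safe to retry. It replaces the whole partition for the run's logical date in a single transaction. It uses the standard-library `sqlite3` module; the `daily_sales` table and the `load_partition` helper are hypothetical. An upsert keyed on a unique identifier (`INSERT ... ON CONFLICT`) is an equally valid approach.

```python
import sqlite3
from typing import Iterable, Tuple


def load_partition(conn: sqlite3.Connection, ds: str, rows: Iterable[Tuple[int, float]]) -> None:
    """Atomically replace the rows for logical date `ds`.

    Delete-then-insert in one transaction: a retry after a partial failure,
    or a rerun of the same date, leaves exactly one copy of the data.
    """
    with conn:  # commits on success, rolls back if an exception escapes
        conn.execute("DELETE FROM daily_sales WHERE ds = ?", (ds,))
        conn.executemany(
            "INSERT INTO daily_sales (ds, store_id, amount) VALUES (?, ?, ?)",
            [(ds, store_id, amount) for store_id, amount in rows],
        )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE daily_sales (ds TEXT, store_id INTEGER, amount REAL)")
rows = [(1, 10.0), (2, 20.5)]
load_partition(conn, "2024-01-01", rows)
load_partition(conn, "2024-01-01", rows)  # simulated retry of the same run
print(conn.execute("SELECT COUNT(*) FROM daily_sales").fetchone()[0])  # 2
```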

When designing custom operators, follow the "operator → hook → connection" pattern:
- The operator handles business logic.
- The hook manages connection details.
- The connection stores credentials, in the metadata database or in an environment variable or secrets backend.
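The layering can be modeled without Airflow. In this toy sketch, each layer has one job:
- The connection layer reads a URI from an `AIRFLOW_CONN_<CONN_ID>` environment variable, one of the places Airflow itself can load connections from.
- The hook turns that connection into a client detail (a base URL).
- The operator logic never touches credentials.

`Conn`, `get_connection`, `ToyHook`, and `toy_operator_execute` are hypothetical names. Airflow's real `Connection` parser handles more cases, such as extras passed in the query string.

```python
import os
from dataclasses import dataclass
from typing import Optional
from urllib.parse import unquote, urlsplit


@dataclass
class Conn:
    conn_type: str
    host: Optional[str]
    port: Optional[int]
    login: Optional[str]
    password: Optional[str]
    schema: str


def get_connection(conn_id: str) -> Conn:
    """Connection layer: credentials live outside the code, here in an env var URI."""
    parts = urlsplit(os.environ[f"AIRFLOW_CONN_{conn_id.upper()}"])
    return Conn(
        conn_type=parts.scheme,
        host=parts.hostname,
        port=parts.port,
        login=unquote(parts.username) if parts.username else None,
        password=unquote(parts.password) if parts.password else None,
        schema=parts.path.lstrip("/"),
    )


class ToyHook:
    """Hook layer: converts a connection into what a client needs (here, a base URL)."""

    def __init__(self, conn_id: str):
        self.conn_id = conn_id

    def base_url(self) -> str:
        conn = get_connection(self.conn_id)
        netloc = f"{conn.host}:{conn.port}" if conn.port else conn.host
        return f"https://{netloc}"  # assumption: the service is always reached over TLS


def toy_operator_execute(conn_id: str, endpoint: str) -> str:
    """Operator layer: business logic only, no credentials in sight."""
    return ToyHook(conn_id).base_url() + endpoint


os.environ["AIRFLOW_CONN_INTERNAL_API_DEFAULT"] = "http://svc-user:s3cr%2Ft@api.internal.example:8443/v1"
print(toy_operator_execute("internal_api_default", "/health"))  # https://api.internal.example:8443/health
```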

Custom operators can be deployed as provider packages (pip-installable) or placed directly in the dags/ folder. Provider packages are recommended for production as they support versioning, testing, and cross-DAG reuse.

Key Takeaways:

  • Custom operators extend BaseOperator and implement execute(context)
  • Template fields enable Jinja2 rendering for dynamic parameters
  • Hooks abstract connection management and should be used for external service access
  • Always use Airflow's connection system for credentials — never hardcode
  • Distribute custom operators as provider packages for team-wide reuse
  • Test operators with mocked hooks and validate both success and failure paths
