TaskFlow API and Decorators in Apache Airflow

Formal Definitions

TaskFlow API

The TaskFlow API is a modern DAG authoring style in Airflow that uses Python decorators (@task, @dag) to reduce boilerplate. Each @task function's return value is automatically pushed to XCom. When that result is passed into another task function, Airflow pulls it from XCom, injects it as an argument, and sets the upstream/downstream dependency between the two tasks.
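The sketch below is a toy model of the push/pull mechanics in plain Python. It runs without Airflow. The `toy_task` decorator, the `Ref` class, and the in-memory `XCOM_STORE` dictionary are hypothetical stand-ins, not Airflow APIs. Real TaskFlow defers execution to the scheduler and stores values through the configured XCom backend. This sketch, by contrast, runs each function as soon as it is called.

```python
import functools
from typing import Any, Callable, Dict, Tuple

# Hypothetical in-memory stand-in for the XCom table: (task_id, key) -> value
XCOM_STORE: Dict[Tuple[str, str], Any] = {}


class Ref:
    """Placeholder for an upstream task's return value (loosely analogous to XComArg)."""

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id

    def resolve(self) -> Any:
        # "Pull": look up the value stored under the default key
        return XCOM_STORE[(self.task_id, "return_value")]


def toy_task(func: Callable[..., Any]) -> Callable[..., Ref]:
    """Toy decorator: resolves Ref arguments, runs func, stores its result."""

    @functools.wraps(func)
    def wrapper(*args: Any) -> Ref:
        # Inject upstream results as ordinary function arguments
        resolved = [a.resolve() if isinstance(a, Ref) else a for a in args]
        result = func(*resolved)
        # "Push": store the return value under the key "return_value"
        XCOM_STORE[(func.__name__, "return_value")] = result
        return Ref(func.__name__)

    return wrapper


@toy_task
def extract() -> list:
    return [85, 92, 78]


@toy_task
def average(scores: list) -> float:
    return sum(scores) / len(scores)


result_ref = average(extract())
print(result_ref.resolve())
```

The decorator returns a reference rather than the value. That is what lets the DAG body read like ordinary function composition while the data itself only moves through the store.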

Dynamic Task Mapping

Dynamic Task Mapping creates task instances at runtime based on data. Take a function $f: D \rightarrow 2^T$, where $D$ is the input data and $T$ is the set of possible tasks. The .expand() method turns a single task definition into $|f(d)|$ parallel task instances. Expansion happens at runtime, once the upstream data is available, not when the DAG file is parsed.
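The expansion is easy to picture in plain Python. The sketch below uses a hypothetical `expand_single` helper, which is not an Airflow function. It produces one task-instance record per element of the input list, each tagged with a `map_index`. This mirrors how Airflow's `.expand()` creates one instance per element when mapping over a single argument. A downstream task then receives all mapped results as one sequence.

```python
from typing import Any, Callable, Dict, List


def expand_single(func: Callable[[Any], Any], arg_name: str, values: List[Any]) -> List[Dict[str, Any]]:
    """Hypothetical helper: one 'instance' per input element, tagged with map_index."""
    instances = []
    for map_index, value in enumerate(values):
        instances.append({
            "map_index": map_index,
            "kwargs": {arg_name: value},
            "result": func(value),  # every instance runs the same function on one element
        })
    return instances


def region_prefix(region: str) -> str:
    return region.split("-")[0]


regions = ["us-east-1", "us-west-2", "eu-west-1"]
instances = expand_single(region_prefix, "region", regions)

# Fan-in: the downstream task sees all mapped results, in map_index order
collected = [inst["result"] for inst in instances]
print(len(instances), collected)
```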

XCom Auto-Serialization

XCom Auto-Serialization is the mechanism by which TaskFlow automatically serializes return values to XCom. It deserializes them again when they are passed as function arguments. Serialization is a map $s: V \rightarrow B$, where $V$ is the return value and $B$ is the serialized bytes. With the default backend, $B$ is stored in the Airflow metadata database.
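Here is a minimal sketch of $s$ and its inverse. It uses plain `json` as a stand-in for the backend's encoder. Airflow's own encoder handles more types, so treat this as an approximation. The value becomes UTF-8 bytes on push and is decoded back on pull. Plain JSON has no tuple type, so in this sketch a tuple comes back as a list.

```python
import json
from typing import Any


def serialize(value: Any) -> bytes:
    """s: V -> B, encode a JSON-compatible value as UTF-8 bytes."""
    return json.dumps(value).encode("utf-8")


def deserialize(blob: bytes) -> Any:
    """Inverse of s: decode the stored bytes back into a Python value."""
    return json.loads(blob.decode("utf-8"))


record = {"id": 1, "name": "Alice", "score": 85}
blob = serialize(record)
print(type(blob).__name__, deserialize(blob) == record)

# JSON has no tuple type, so tuples round-trip as lists
print(deserialize(serialize((1, 2))))
```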

Detailed Explanation

Basic TaskFlow Usage

The TaskFlow API transforms how you write Airflow DAGs by using Python decorators instead of explicit operator instantiation.

from airflow.decorators import task, dag
from datetime import datetime, timedelta

# Using @dag decorator
@dag(
    schedule=timedelta(hours=1),  # "schedule" supersedes the deprecated "schedule_interval"
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['taskflow', 'example'],
)
def taskflow_example():
    
    @task
    def extract_data():
        """Extract data - return value auto-pushed to XCom."""
        data = [
            {"id": 1, "name": "Alice", "score": 85},
            {"id": 2, "name": "Bob", "score": 92},
            {"id": 3, "name": "Charlie", "score": 78},
        ]
        return data
    
    @task
    def transform_data(raw_data: list) -> list:
        """Transform data - raw_data auto-pulled from upstream XCom."""
        transformed = []
        for record in raw_data:
            transformed.append({
                "id": record["id"],
                "name": record["name"].upper(),
                "score": record["score"],
                "grade": "A" if record["score"] >= 90 else "B" if record["score"] >= 80 else "C",
            })
        return transformed
    
    @task
    def load_data(transformed_data: list) -> int:
        """Load data and return count."""
        print(f"Loading {len(transformed_data)} records")
        for record in transformed_data:
            print(f"  {record['name']}: {record['grade']}")
        return len(transformed_data)
    
    @task
    def report(count: int):
        """Generate report."""
        print(f"Processed {count} records successfully")
    
    # Define the pipeline
    raw = extract_data()
    transformed = transform_data(raw)
    loaded = load_data(transformed)
    report(loaded)

# Instantiate the DAG
taskflow_example()

TaskFlow with Multiple Returns

from airflow.decorators import task, dag
from datetime import datetime
from typing import Dict

@dag(
    schedule="@daily",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['taskflow', 'multiple-returns'],
)
def multi_return_dag():

    @task(multiple_outputs=True)
    def extract_metrics() -> Dict:
        """Return a dict - with multiple_outputs=True each key is pushed as a separate XCom."""
        records = [
            {"user_id": 1, "action": "login", "timestamp": "2024-01-01 10:00"},
            {"user_id": 2, "action": "purchase", "timestamp": "2024-01-01 11:00"},
        ]
        metadata = {"total_records": len(records), "source": "analytics"}
        return {"records": records, "metadata": metadata}

    @task
    def process_records(records: list, metadata: dict):
        """Receive individual outputs, each pulled from its own XCom key."""
        print(f"Processing {metadata['total_records']} records from {metadata['source']}")
        for record in records:
            print(f"  User {record['user_id']}: {record['action']}")

    @task
    def compute_statistics(records: list) -> Dict:
        """Compute aggregate statistics."""
        actions = [r["action"] for r in records]
        return {
            "unique_users": len(set(r["user_id"] for r in records)),
            "action_counts": {a: actions.count(a) for a in set(actions)},
        }

    # Select outputs by key; a task's output cannot be unpacked like a tuple
    extracted = extract_metrics()
    process_records(extracted["records"], extracted["metadata"])
    compute_statistics(extracted["records"])

multi_return_dag()

Dynamic Task Mapping

from airflow.decorators import task, dag
from datetime import datetime

@dag(
    schedule="@daily",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['taskflow', 'dynamic-mapping'],
)
def dynamic_mapping_dag():
    
    @task
    def get_regions() -> list:
        """Return list of regions to process."""
        return ["us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"]
    
    @task
    def process_region(region: str) -> dict:
        """Process a single region - will be mapped dynamically."""
        import random
        record_count = random.randint(100, 1000)
        return {
            "region": region,
            "records_processed": record_count,
            "status": "success",
        }
    
    @task
    def aggregate_results(region_results: list) -> dict:
        """Aggregate results from all mapped tasks."""
        total_records = sum(r["records_processed"] for r in region_results)
        return {
            "total_regions": len(region_results),
            "total_records": total_records,
            "regions": [r["region"] for r in region_results],
        }
    
    # Dynamic mapping - creates one process_region instance per region at runtime
    regions = get_regions()
    region_results = process_region.expand(region=regions)
    aggregate_results(region_results)

dynamic_mapping_dag()
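
Dynamic Task Count

$$N_{\text{tasks}} = |f_{\text{map}}(D)|$$

Here,

  • $N_{\text{tasks}}$ = Number of mapped task instances created
  • $f_{\text{map}}$ = Mapping function applied to the input data
  • $D$ = Input dataset (a list, a dict, or an XCom that resolves to one)

For a single expanded argument, $N_{\text{tasks}}$ is the number of list elements (or dict entries). Expanding over several arguments at once creates their cross product, so the individual lengths multiply.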
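The count can be computed directly. The sketch below uses a hypothetical `mapped_task_count` helper. It assumes Airflow's documented `.expand()` behavior:

  • a list yields one instance per element
  • a dict yields one instance per key/value pair
  • several mapped arguments combine as a cross product

```python
from math import prod
from typing import Any, Dict, List, Union

MappedInput = Union[List[Any], Dict[Any, Any]]


def mapped_task_count(**expand_kwargs: MappedInput) -> int:
    """N_tasks for a hypothetical .expand(**expand_kwargs) call.

    Each list contributes len(list) values and each dict len(dict) pairs;
    several mapped arguments combine as a cross product, so lengths multiply.
    """
    return prod(len(values) for values in expand_kwargs.values())


print(mapped_task_count(region=["us-east-1", "us-west-2", "eu-west-1", "ap-southeast-1"]))
print(mapped_task_count(env=["dev", "staging", "prod"], shard=[0, 1]))
```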

TaskFlow XCom Transfer Time

$$T_{\text{xcom}} = \sum_{i=1}^{n} (S_{\text{push},i} + S_{\text{pull},i}) \cdot L_{\text{latency}}$$

Here,

  • $n$ = Number of tasks with XCom operations
  • $S_{\text{push},i}$ = Size in bytes of the serialized value pushed by task $i$
  • $S_{\text{pull},i}$ = Size in bytes of the serialized values task $i$ pulls and deserializes
  • $L_{\text{latency}}$ = Average XCom backend transfer cost per byte (seconds per byte), so $T_{\text{xcom}}$ is a rough estimate of total time spent moving XCom data; fixed per-operation overhead is not modeled
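Here is the cost model with numbers plugged in. The sketch assumes $L_{\text{latency}}$ is expressed in seconds per byte. The payload sizes and the 2e-8 s/byte figure are made-up illustrative values, not measurements.

```python
from typing import Sequence


def xcom_overhead_seconds(push_bytes: Sequence[int], pull_bytes: Sequence[int], seconds_per_byte: float) -> float:
    """T_xcom = sum_i (S_push,i + S_pull,i) * L, with L in seconds per byte."""
    if len(push_bytes) != len(pull_bytes):
        raise ValueError("need one push size and one pull size per task")
    return sum(p + q for p, q in zip(push_bytes, pull_bytes)) * seconds_per_byte


# Illustrative sizes in bytes for extract -> transform -> load; not measured values
push = [40_000, 42_000, 100]
pull = [0, 40_000, 42_000]
print(round(xcom_overhead_seconds(push, pull, 2e-8), 6))
```

Because $L_{\text{latency}}$ is constant across tasks, it factors out of the sum. Halving the payload sizes halves the estimate.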

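TaskFlow automatically pushes a task's return value to XCom under the key return_value. Returning a tuple does not split it: the whole tuple is stored as a single XCom value, and the output cannot be unpacked in the DAG body. To push multiple values, return a dict and either set @task(multiple_outputs=True) or annotate the return type as a dict (e.g. -> Dict[str, list]) so Airflow infers it. Each top-level key is then pushed as a separate XCom under its own name. Downstream tasks select a value by key, e.g. result["records"].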
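This is a toy model of what multiple outputs do with a return value. The `push_outputs` helper is hypothetical and assumes Airflow 2.x behavior:

  • a None return creates no XCom
  • the full value is kept under `return_value`
  • with `multiple_outputs=True`, every top-level key of the returned dict, which must be a string, gets its own entry
  • a non-dict return is rejected

```python
from typing import Any, Dict


def push_outputs(value: Any, multiple_outputs: bool) -> Dict[str, Any]:
    """Hypothetical model: map a task's return value to the XCom entries it creates."""
    if value is None:
        return {}  # a None return value creates no XCom
    entries = {"return_value": value}
    if not multiple_outputs:
        return entries
    if not isinstance(value, dict):
        raise TypeError("multiple_outputs=True requires returning a dict")
    for key, item in value.items():
        if not isinstance(key, str):
            raise TypeError("dict keys must be strings when using multiple_outputs")
        entries[key] = item  # one extra XCom per top-level key
    return entries


entries = push_outputs({"records": [1, 2], "metadata": {"source": "analytics"}}, multiple_outputs=True)
print(sorted(entries))
```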

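Don't pass large data (roughly anything over 48 KB) through the default database backend. Its hard size limits depend on which metadata database you run, and large values slow the database down for every DAG. Using custom keys with xcom_push and xcom_pull does not change these limits. Instead, configure an alternative XCom backend, for example one that writes values to S3 or GCS and keeps only a reference in the database. Alternatively, have tasks exchange references such as file paths or object URIs rather than the data itself.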
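A large-payload backend stores big values somewhere else and keeps only a small reference in the database. The sketch below shows that pattern in plain Python, with a local temporary directory in place of S3 or GCS. The `OffloadingStore` class, its methods, and the 48 KB cutoff are illustrative assumptions, not an Airflow API.

```python
import json
import os
import tempfile
import uuid
from typing import Any, Dict

THRESHOLD_BYTES = 48 * 1024  # rule-of-thumb cutoff used in this lesson


class OffloadingStore:
    """Hypothetical store: small values inline, large values as file references."""

    def __init__(self, directory: str) -> None:
        self.directory = directory
        self.table: Dict[str, str] = {}  # stands in for the XCom table (key -> stored text)

    def push(self, key: str, value: Any) -> None:
        payload = json.dumps(value)
        if len(payload.encode("utf-8")) <= THRESHOLD_BYTES:
            self.table[key] = json.dumps({"inline": payload})
            return
        # Large payload: write it out and keep only the path in the "database"
        path = os.path.join(self.directory, f"{uuid.uuid4().hex}.json")
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(payload)
        self.table[key] = json.dumps({"ref": path})

    def pull(self, key: str) -> Any:
        stored = json.loads(self.table[key])
        if "inline" in stored:
            return json.loads(stored["inline"])
        with open(stored["ref"], encoding="utf-8") as fh:
            return json.loads(fh.read())


with tempfile.TemporaryDirectory() as tmp:
    store = OffloadingStore(tmp)
    store.push("small", {"count": 3})
    store.push("large", list(range(20_000)))  # well over 48 KB once JSON-encoded
    small_back = store.pull("small")
    large_back = store.pull("large")
    print("ref" in json.loads(store.table["large"]), len(large_back))
```

A real object-storage backend follows the same split: the write happens in the serialize step and the read in the deserialize step. Only the reference ever reaches the metadata database.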

Key Concepts Table

| Feature | Traditional Operators | TaskFlow API |
|---|---|---|
| Task Definition | PythonOperator(python_callable=func) | @task def func(): |
| XCom Push | ti.xcom_push(key, value) | return value (automatic) |
| XCom Pull | ti.xcom_pull(task_ids, key) | Function parameter injection |
| Multiple Returns | Multiple push calls | Return a dict with multiple_outputs=True |
| DAG Definition | with DAG(...): block | @dag decorator |
| Dynamic Mapping | Operator.partial(...).expand(...) | .expand() on the decorated function |
| Code Boilerplate | High | Low |
| Type Hints | Optional | Encouraged (a dict return hint also infers multiple_outputs) |

Code Examples

Advanced TaskFlow Patterns

from airflow.decorators import task, dag
from airflow.operators.python import get_current_context
from datetime import datetime, timedelta
from typing import List, Dict
import json

@dag(
    schedule="0 6 * * *",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['taskflow', 'advanced'],
    doc_md="""
    ## Advanced TaskFlow Patterns
    Demonstrates TaskFlow API with:
    - Retries configured in the decorator
    - Dynamic task mapping with .expand()
    - Custom XCom keys alongside the automatic return value
    """,
)
def advanced_taskflow_dag():

    @task(retries=3, retry_delay=timedelta(minutes=1))
    def extract_from_source(source: str) -> List[Dict]:
        """Extract data; simulated transient failures are retried by Airflow."""
        import random
        if random.random() < 0.1:
            raise ConnectionError(f"Failed to connect to {source}")

        return [
            {"id": i, "source": source, "value": random.randint(1, 100)}
            for i in range(10)
        ]

    @task
    def validate_data(batches: list) -> List[Dict]:
        """Flatten per-source batches; return valid records, push invalid ones under a custom key."""
        records = [r for batch in batches for r in batch]
        valid = [r for r in records if 0 <= r["value"] <= 100]
        invalid = [r for r in records if not 0 <= r["value"] <= 100]
        # Explicit push with a custom key; the return value still goes to "return_value"
        get_current_context()["ti"].xcom_push(key="invalid_records", value=invalid)
        return valid

    @task
    def transform_record(record: Dict) -> Dict:
        """Transform a single record (mapped with .expand())."""
        return {
            **record,
            "value_normalized": record["value"] / 100.0,
            "transformed": True,
        }

    @task
    def load_batch(records: list, destination: str) -> int:
        """Load all transformed records (mapped results arrive as one lazy sequence)."""
        records = list(records)
        print(f"Loading {len(records)} records to {destination}")
        return len(records)

    @task
    def generate_report(loaded_count: int) -> str:
        """Generate summary report, pulling invalid records by their custom key."""
        ti = get_current_context()["ti"]
        invalid_records = ti.xcom_pull(task_ids="validate_data", key="invalid_records") or []
        total_invalid = len(invalid_records)
        total = loaded_count + total_invalid

        report = {
            "total_loaded": loaded_count,
            "total_invalid": total_invalid,
            "success_rate": loaded_count / total * 100 if total else 0.0,
        }
        return json.dumps(report, indent=2)

    # Define sources
    sources = ["postgres", "mysql", "mongodb"]

    # Extract from multiple sources: one mapped task instance per source
    raw_data = extract_from_source.expand(source=sources)

    # Validate all batches together (mapped results are passed as one sequence)
    valid_data = validate_data(raw_data)

    # Transform each valid record in its own mapped task instance
    transformed = transform_record.expand(record=valid_data)

    # Load to destination
    loaded_count = load_batch(transformed, "data_warehouse")

    # Generate report (runs after load, so validate_data has already pushed its custom XCom)
    generate_report(loaded_count)

advanced_taskflow_dag()

TaskFlow with XCom Backend Configuration

from airflow.decorators import task, dag
from airflow.models.xcom import BaseXCom
from datetime import datetime
import json

class CustomXComBackend(BaseXCom):
    """Custom XCom backend that stores values as UTF-8 JSON.

    Backends are enabled globally, not per task: set [core] xcom_backend
    (or AIRFLOW__CORE__XCOM_BACKEND) to the class's import path, e.g.
    path.to.CustomXComBackend. A backend for large payloads would upload the
    value to object storage in serialize_value and store only a reference.
    """

    @staticmethod
    def serialize_value(value, **kwargs):
        """Serialize value for storage; the XCom value column holds bytes."""
        return json.dumps(value).encode("utf-8")

    @staticmethod
    def deserialize_value(result):
        """Deserialize value from storage; result is the stored XCom row."""
        if result.value is None:
            return None
        return json.loads(result.value.decode("utf-8"))

@dag(
    schedule="@hourly",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['taskflow', 'xcom-backend'],
    render_template_as_native_obj=True,
)
def xcom_backend_dag():

    # The XCom backend is global configuration, not a @task argument.
    # pool= assumes a pool named "data_processing" has already been created.
    @task(pool='data_processing')
    def process_large_dataset(records: list) -> dict:
        """Process dataset; the result is stored through the configured XCom backend."""
        import pandas as pd
        
        df = pd.DataFrame(records)
        
        result = {
            "mean": float(df["value"].mean()),
            "std": float(df["value"].std()),
            "count": int(len(df)),
            "percentiles": {
                "25": float(df["value"].quantile(0.25)),
                "50": float(df["value"].quantile(0.50)),
                "75": float(df["value"].quantile(0.75)),
            },
        }
        return result
    
    @task
    def summarize(stats: dict) -> str:
        """Create summary string."""
        return f"Processed {stats['count']} records, mean={stats['mean']:.2f}"
    
    sample_data = [{"value": i * 1.5} for i in range(100)]
    stats = process_large_dataset(sample_data)
    summarize(stats)

xcom_backend_dag()

TaskMap with Cross-Task Dependencies

from airflow.decorators import task, dag
from datetime import datetime, timedelta

@dag(
    schedule_interval="@daily",
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=['taskflow', 'taskmap', 'cross-dependencies'],
)
def cross_task_dependency_dag():
    
    @task
    def get_config() -> dict:
        """Return configuration for downstream tasks."""
        return {
            "environments": ["dev", "staging", "prod"],
            "parallel_workers": 4,
            "timeout_seconds": 300,
        }
    
    @task
    def setup_environment(env: str, config: dict) -> dict:
        """Setup each environment using config."""
        return {
            "environment": env,
            "status": "ready",
            "workers": config["parallel_workers"],
        }
    
    @task
    def run_tests(env_setup: dict) -> dict:
        """Run tests in each environment."""
        return {
            "environment": env_setup["environment"],
            "tests_passed": 42,
            "tests_failed": 0,
        }
    
    @task
    def cleanup_environments(results: list) -> str:
        """Cleanup all environments after tests."""
        environments = [r["environment"] for r in results]
        return f"Cleaned up: {', '.join(environments)}"
    
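    @task
    def get_environments(config: dict) -> list:
        """Return the list of environments to map over."""
        return config["environments"]

    # Get configuration
    config = get_config()

    # Pull the environment list out in its own task so expand() maps over a plain return value
    environments = get_environments(config)

    # partial() fixes arguments shared by every instance; expand() maps over the list
    env_setups = setup_environment.partial(config=config).expand(env=environments)

    # Run tests in each environment (one instance per upstream mapped instance)
    test_results = run_tests.expand(env_setup=env_setups)

    # Cleanup
    cleanup_environments(test_results)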

cross_task_dependency_dag()
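`partial()` and `expand()` split a task's arguments into constants and mapped values. The plain-Python sketch below shows the resulting instances. It uses a hypothetical `partial_expand` helper, which is not Airflow's implementation. Every instance gets the same constant `config` and one `env` each. A second expansion over those results still produces one instance per upstream instance.

```python
from typing import Any, Callable, Dict, List


def partial_expand(func: Callable[..., Any], constants: Dict[str, Any], arg_name: str, values: List[Any]) -> List[Any]:
    """Hypothetical helper: call func once per mapped value, merging in the constant kwargs."""
    return [func(**constants, **{arg_name: value}) for value in values]


def setup_environment(env: str, config: dict) -> dict:
    return {"environment": env, "status": "ready", "workers": config["parallel_workers"]}


def run_tests(env_setup: dict) -> dict:
    return {"environment": env_setup["environment"], "tests_passed": 42, "tests_failed": 0}


config = {"environments": ["dev", "staging", "prod"], "parallel_workers": 4}
env_setups = partial_expand(setup_environment, {"config": config}, "env", config["environments"])
test_results = partial_expand(run_tests, {}, "env_setup", env_setups)
print([r["environment"] for r in test_results])
```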

Performance Metrics

TaskFlow vs Traditional Operators

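| Metric | Traditional | TaskFlow | Improvement |
|---|---|---|---|
| Lines of Code | ~50 per task | ~15 per task | ~70% reduction |
| XCom Operations | Explicit calls | Automatic | Simplified |
| Error Handling | retries= on the operator | retries= in @task(...) | Same mechanism, less code |
| Dynamic Mapping | Operator.partial(...).expand(...) | func.expand(...) | Less boilerplate |
| Type Safety | Optional | Encouraged | Better |
| Testability | Moderate | High | Improved |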

XCom Performance

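| Backend | Max Payload | Latency (approx.) | Use Case |
|---|---|---|---|
| Database (default) | 48 KB recommended | ~5 ms | Small metadata |
| S3 | 5 TB per object (5 GB per single PUT) | ~50 ms | Large datasets |
| GCS | 5 TB per object | ~100 ms | Cloud-native |
| Redis | 512 MB per value | ~1 ms | High-throughput |
| Custom | Configurable | Varies | Specialized needs |

The S3, GCS, and Redis rows assume a custom (or provider-supplied) XCom backend. Airflow's default backend stores XComs in the metadata database.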

Key Takeaways:

  • TaskFlow API reduces boilerplate by ~70% compared to traditional operators
  • Return values are automatically pushed to XCom; parameters auto-pull from upstream XCom
  • Dynamic task mapping with .expand() creates parallel task instances at runtime
  • Use custom XCom backends for payloads > 48KB
  • @task(retries=N) adds retry logic directly in the decorator
  • multiple_outputs=True (or a dict return hint) pushes each key of a returned dict as its own XCom, selected downstream by key
