🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services ℹ️ About ✉️ Contact View Pricing Plans from $10

Testing Spark Applications

Apache SparkEngineering⭐ Premium

Advertisement

Testing Spark Applications

Difficulty: Senior Level | Companies: Databricks, Netflix, Uber, Apple, Airbnb

Testing Challenges in Spark

Spark applications face unique testing challenges: distributed execution, lazy evaluation, and side effects. A systematic approach is essential.

Local Testing Setup

import pytest
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

@pytest.fixture(scope="session")
def spark():
    """Provide one local SparkSession shared by the whole test session.

    The session uses the ``local[2]`` master, has the web UI disabled, uses
    two shuffle partitions and binds the driver to 127.0.0.1. It is stopped
    when the fixture is torn down.
    """
    builder = (
        SparkSession.builder
        .master("local[2]")
        .appName("unit-tests")
        .config("spark.ui.enabled", "false")
        .config("spark.sql.shuffle.partitions", "2")
        .config("spark.driver.bindAddress", "127.0.0.1")
    )
    session = builder.getOrCreate()

    yield session

    # Teardown: release the session once the fixture goes out of scope.
    session.stop()

@pytest.fixture
def sample_data(spark):
    """Build a four-row events DataFrame (user_id, event_type, amount, date)."""
    column_names = ["user_id", "event_type", "amount", "date"]
    rows = [
        ("user1", "click", 100.0, "2024-06-15"),
        ("user2", "view", 0.0, "2024-06-15"),
        ("user1", "purchase", 250.0, "2024-06-16"),
        ("user3", "click", 100.0, "2024-06-16"),
    ]
    events = spark.createDataFrame(rows, column_names)
    return events

def test_filter_events(spark, sample_data):
    """filter_events(event_type="click") keeps exactly the two click rows."""
    from my_module import filter_events

    clicks = filter_events(sample_data, event_type="click")

    # Two of the four sample rows are clicks.
    assert clicks.count() == 2

    # Nothing other than clicks may survive the filter.
    is_not_click = F.col("event_type") != "click"
    assert clicks.filter(is_not_click).count() == 0

ℹ️

Interview Insight: Always test Spark code locally with small datasets. Use local[*] mode (one worker thread per CPU core) to run tests in parallel, and disable the Spark UI to reduce resource usage.

Unit Testing Transformations

def test_add_tax_column(spark, sample_data):
    """add_tax_column must add a "tax" column matching a reference calculation.

    The reference taxes purchase rows at 8% of ``amount`` and gives every
    other row 0.0.
    """
    from my_module import add_tax_column

    with_tax = add_tax_column(sample_data, tax_rate=0.08)

    # The new column has to be present.
    assert "tax" in with_tax.columns

    # Build the reference column directly with Spark functions.
    is_purchase = F.col("event_type") == "purchase"
    reference_tax = F.when(is_purchase, F.col("amount") * 0.08).otherwise(0.0)
    reference = sample_data.withColumn("tax", reference_tax)

    actual_rows = with_tax.select("tax").collect()
    reference_rows = reference.select("tax").collect()
    assert actual_rows == reference_rows

def test_user_aggregation(spark, sample_data):
    """aggregate_by_user yields one row per user with a summed total_amount."""
    from my_module import aggregate_by_user

    totals = aggregate_by_user(sample_data)

    # user1 has a 100.0 click and a 250.0 purchase in the sample data.
    user1_rows = totals.filter(F.col("user_id") == "user1").collect()
    first_row = user1_rows[0]
    assert first_row["total_amount"] == 350.0  # 100 + 250

    # user1, user2 and user3 must all appear.
    assert totals.count() == 3

def test_handle_nulls(spark):
    """process_data must leave no nulls in "value" while keeping all three rows."""
    from my_module import process_data

    rows = [(1, None), (2, "value"), (3, None)]
    with_nulls = spark.createDataFrame(rows, ["id", "value"])

    processed = process_data(with_nulls)

    # No null may remain in the "value" column ...
    remaining_nulls = processed.filter(F.col("value").isNull())
    assert remaining_nulls.count() == 0
    # ... and no row may have been dropped to achieve that.
    assert processed.count() == 3

Testing with PySpark

from pyspark.testing import assertDataFrameEqual

def test_dataframe_equality(spark):
    """transform_data output must equal the expected DataFrame (schema and rows)."""
    from my_module import transform_data

    source_df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
    expected_df = spark.createDataFrame(
        [(1, "A"), (2, "B")],
        ["id", "value_upper"],
    )

    actual_df = transform_data(source_df)

    # assertDataFrameEqual checks both the schema and the data.
    assertDataFrameEqual(actual_df, expected_df)

def test_schema_validation(spark):
    """Validate the column names and types of the DataFrame from process_data.

    The expected layout is written as a DDL string and turned into a real
    schema by Spark, so the comparison is made between schema fields.
    Comparing ``str(result.schema)`` to the DDL text cannot work: the string
    form of a schema is a ``StructType([...])`` repr, never DDL.
    """
    from my_module import process_data

    df = spark.createDataFrame([(1, "test")], ["id", "name"])
    result = process_data(df)

    # NOTE(review): createDataFrame infers Python ints as bigint, so
    # "id: int" only holds if process_data casts the column - confirm.
    expected_schema = "id: int, name: string, processed: boolean"
    expected = spark.createDataFrame([], schema=expected_schema).schema

    # Compare names and data types only; the DDL string says nothing about
    # nullability, so it is left out of the check.
    actual_fields = [(f.name, f.dataType) for f in result.schema.fields]
    expected_fields = [(f.name, f.dataType) for f in expected.fields]
    assert actual_fields == expected_fields

def test_partition_count(spark):
    """repartition_data(df, num_partitions=10) must return ten partitions."""
    from my_module import repartition_data

    thousand_rows = spark.range(1000)
    repartitioned = repartition_data(thousand_rows, num_partitions=10)

    partition_count = repartitioned.rdd.getNumPartitions()
    assert partition_count == 10

Integration Testing

import tempfile
import os

def test_end_to_end_pipeline(spark):
    """Run run_pipeline from a Parquet input to a Parquet output and check it.

    Two rows are written to a temporary input directory; the pipeline output
    must contain two rows and a "processed_amount" column.
    """
    from my_module import run_pipeline

    columns = ["id", "category", "amount"]
    rows = [(1, "a", 100.0), (2, "b", 200.0)]

    # Everything lives under one temporary directory that is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "input")
        output_path = os.path.join(tmpdir, "output")

        # Seed the pipeline input.
        spark.createDataFrame(rows, columns).write.parquet(input_path)

        # Execute the pipeline under test.
        run_pipeline(spark, input_path, output_path)

        # Read the output back while the directory still exists.
        written = spark.read.parquet(output_path)
        assert written.count() == 2
        assert "processed_amount" in written.columns

def test_with_external_services(spark):
    """Test process_with_api against a mocked external API client.

    ``my_module.api_client`` is replaced by a mock whose ``get`` always
    returns ``{"status": "success"}``. The input holds two rows so that the
    expected two calls and two successful rows match the data.
    """
    from unittest.mock import Mock, patch
    from my_module import process_with_api

    # Mock external API
    mock_api = Mock()
    mock_api.get.return_value = {"status": "success"}

    rows = [(1, "test"), (2, "test2")]
    df = spark.createDataFrame(rows, ["id", "name"])

    with patch('my_module.api_client', mock_api):
        result = process_with_api(df)
        # Spark evaluates lazily: run the action while the patch is still
        # active, otherwise the work could execute against the real client.
        success_count = result.filter(F.col("status") == "success").count()

    # NOTE(review): assumes process_with_api calls api_client.get once per
    # input row - confirm against the implementation.
    assert mock_api.get.call_count == len(rows)
    assert success_count == len(rows)

⚠️

Warning: Integration tests with external services should use mocks to avoid flaky tests and reduce costs. Never test against production databases.

Data Quality Testing

def test_data_quality(spark, sample_data):
    """Validate data quality constraints on the sample events.

    Checks that the required columns contain no nulls, that no amount is
    negative, and that every event type belongs to the allowed set.
    """
    # Check for nulls in required columns
    required_columns = ["user_id", "event_type", "amount"]
    for column_name in required_columns:
        null_count = sample_data.filter(F.col(column_name).isNull()).count()
        assert null_count == 0, f"Nulls found in {column_name}: {null_count}"

    # Check value ranges
    assert sample_data.filter(F.col("amount") < 0).count() == 0

    # Check distinct values; collect() avoids a pandas dependency that
    # toPandas() would pull in just to build a small set.
    valid_event_types = {"click", "view", "purchase"}
    distinct_rows = sample_data.select("event_type").distinct().collect()
    actual_types = {row["event_type"] for row in distinct_rows}
    assert actual_types.issubset(valid_event_types)

def test_referential_integrity(spark):
    """Validate that every order references an existing customer.

    Orders and customers are read from Parquet; a left anti join on
    ``customer_id`` keeps only the orders without a matching customer, and
    that set must be empty.
    """
    orders = spark.read.parquet("hdfs://test/orders")
    customers = spark.read.parquet("hdfs://test/customers")

    # Check all orders have valid customer_id
    orphan_orders = orders.join(
        customers,
        "customer_id",
        "left_anti"
    )

    # Count once: each count() is a separate Spark job.
    orphan_count = orphan_orders.count()
    assert orphan_count == 0, f"Orphan orders found: {orphan_count}"

Performance Testing

def test_query_performance(spark):
    """Ensure optimized_query over one million rows finishes within 30 seconds.

    Elapsed time is measured with ``time.perf_counter()``, a monotonic
    clock, so system clock adjustments cannot distort the result.
    """
    import time
    from my_module import optimized_query

    df = spark.range(1_000_000).withColumn("value", F.rand())

    start = time.perf_counter()
    result = optimized_query(df)
    # Materialize: the count() action forces execution inside the timed span.
    result.count()
    elapsed = time.perf_counter() - start

    assert elapsed < 30.0, f"Query took {elapsed:.1f}s, exceeds 30s threshold"

def test_shuffle_metrics(spark):
    """Run join_data on a 100000-row and a 100-row DataFrame and materialize it.

    NOTE(review): the test makes no assertion; shuffle metrics still have to
    be inspected in the Spark UI or read programmatically after the run.
    """
    from my_module import join_data

    bucket_key = F.col("id") % 100
    large_side = spark.range(100000).withColumn("key", bucket_key)
    small_side = spark.range(100).withColumn("key", F.col("id"))

    # NOTE(review): intended to enable shuffle tracking; this only sets a
    # local property named "spark.scheduler.mode" - confirm it has any effect.
    context = spark.sparkContext
    context.setLocalProperty("spark.scheduler.mode", "FAIR")

    joined = join_data(large_side, small_side)
    # Force execution of the join.
    joined.count()

Testing Best Practices

# conftest.py - Shared test fixtures
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """Provide a session-scoped local SparkSession and stop it afterwards.

    The session uses the ``local[2]`` master with the web UI disabled.
    Yielding instead of returning lets the fixture stop the session during
    teardown, so the session is not left running after the test run.
    """
    session = (
        SparkSession.builder
        .master("local[2]")
        .config("spark.ui.enabled", "false")
        .getOrCreate()
    )

    yield session

    # Teardown: release the session once the test session ends.
    session.stop()

@pytest.fixture
def test_data_path(tmp_path):
    """Return, as a string, the path "test_data" under pytest's ``tmp_path``.

    Nothing is created on disk here; only the path is computed.
    """
    data_location = tmp_path / "test_data"
    return str(data_location)

# Parameterized tests
@pytest.mark.parametrize("tax_rate,expected", [
    (0.08, 8.0),
    (0.10, 10.0),
    (0.00, 0.0),
])
def test_tax_calculation(spark, tax_rate, expected):
    """Check calculate_tax on a single 100.0 amount for several tax rates.

    Args:
        spark: SparkSession fixture.
        tax_rate: Rate passed to ``calculate_tax``.
        expected: Expected value of the "tax" column for an amount of 100.0.
    """
    from my_module import calculate_tax

    df = spark.createDataFrame([(100.0,)], ["amount"])
    result = calculate_tax(df, tax_rate)

    # Compare with a tolerance: the tax is a floating-point product, and
    # exact equality would fail on harmless rounding differences.
    assert result.collect()[0]["tax"] == pytest.approx(expected)

# Skip tests based on conditions
def test_local_only(spark):
    """Placeholder for a test that must only run against a local Spark master.

    The locality check needs a live session, so it runs inside the test via
    the ``spark`` fixture. A ``skipif`` decorator is evaluated at import
    time, where the name ``spark`` is the fixture function rather than a
    session, and accessing ``sparkContext`` on it raises during collection.
    """
    if not spark.sparkContext._jsc.sc().isLocal():
        pytest.skip("Only run locally")

ℹ️

Key Takeaway: Test Spark code systematically: unit test transformations locally, integration test complete pipelines, validate data quality, and ensure performance meets thresholds. Use small datasets and local mode for fast feedback.

Follow-Up Questions

  • How would you test a Spark Streaming application?
  • Explain strategies for testing UDFs and custom transformations.
  • How do you handle flaky tests in Spark test suites?
  • Describe testing approaches for Delta Lake operations.
  • How would you test Spark applications that interact with external systems?

Advertisement