Testing Spark Applications
Difficulty: Senior Level | Companies: Databricks, Netflix, Uber, Apple, Airbnb
Testing Challenges in Spark
Spark applications face unique testing challenges: distributed execution, lazy evaluation, and side effects. A systematic approach is essential.
Local Testing Setup
import pytest
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
@pytest.fixture(scope="session")
def spark():
    """Provide one local SparkSession shared by the whole test session.

    The session runs with two local worker threads, the web UI disabled, only
    two shuffle partitions, and the driver bound to the loopback address.
    It is stopped after every test in the session has finished.
    """
    builder = (
        SparkSession.builder
        .master("local[2]")
        .appName("unit-tests")
        .config("spark.ui.enabled", "false")
        .config("spark.sql.shuffle.partitions", "2")
        .config("spark.driver.bindAddress", "127.0.0.1")
    )
    session = builder.getOrCreate()
    yield session
    # Teardown: runs once the test session is over.
    session.stop()
@pytest.fixture
def sample_data(spark):
    """Return a small four-row event DataFrame for transformation tests.

    Columns: user_id, event_type, amount (float), date (ISO-formatted string).
    """
    columns = ["user_id", "event_type", "amount", "date"]
    rows = [
        ("user1", "click", 100.0, "2024-06-15"),
        ("user2", "view", 0.0, "2024-06-15"),
        ("user1", "purchase", 250.0, "2024-06-16"),
        ("user3", "click", 100.0, "2024-06-16"),
    ]
    events_df = spark.createDataFrame(rows, columns)
    return events_df
def test_filter_events(spark, sample_data):
    """Check that filter_events keeps only rows of the requested event type."""
    from my_module import filter_events

    clicks = filter_events(sample_data, event_type="click")
    non_click_rows = clicks.filter(F.col("event_type") != "click")

    # The fixture contains exactly two "click" events.
    assert clicks.count() == 2
    assert non_click_rows.count() == 0
ℹ️
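Interview Insight: Always test Spark code locally with small datasets. Run in local mode (e.g. local[2] as above, or local[*] to use every core) for parallel execution, disable the Spark UI, and lower spark.sql.shuffle.partitions to reduce resource usage.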
Unit Testing Transformations
def test_add_tax_column(spark, sample_data):
    """Check add_tax_column: `tax` is amount * rate for purchases and 0.0 otherwise."""
    from my_module import add_tax_column

    rate = 0.08
    taxed = add_tax_column(sample_data, tax_rate=rate)

    # The new column must be present.
    assert "tax" in taxed.columns

    # Build the reference result with the same rule and compare row by row.
    # Lists of Rows compare positionally, so this relies on add_tax_column
    # preserving the input row order.
    is_purchase = F.col("event_type") == "purchase"
    reference = sample_data.withColumn(
        "tax",
        F.when(is_purchase, F.col("amount") * rate).otherwise(0.0),
    )
    actual_tax = taxed.select("tax").collect()
    expected_tax = reference.select("tax").collect()
    assert actual_tax == expected_tax
def test_user_aggregation(spark, sample_data):
    """Check that aggregate_by_user yields one row per user with summed amounts."""
    from my_module import aggregate_by_user

    per_user = aggregate_by_user(sample_data)

    # user1 has a 100.0 click and a 250.0 purchase.
    user1_rows = per_user.filter(F.col("user_id") == "user1").collect()
    assert user1_rows[0]["total_amount"] == 350.0

    # The fixture contains three distinct users.
    assert per_user.count() == 3
def test_handle_nulls(spark):
    """Check that process_data leaves no nulls in `value` and drops no rows."""
    from my_module import process_data

    rows = [(1, None), (2, "value"), (3, None)]
    processed = process_data(spark.createDataFrame(rows, ["id", "value"]))

    remaining_nulls = processed.filter(F.col("value").isNull()).count()
    assert remaining_nulls == 0
    # Nulls must be handled in place, not filtered away.
    assert processed.count() == 3
Testing with PySpark's Built-in Test Utilities (pyspark.testing, Spark 3.5+)
from pyspark.testing import assertDataFrameEqual
def test_dataframe_equality(spark):
    """Compare transform_data's output with an expected DataFrame (schema and rows)."""
    from my_module import transform_data

    source = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
    actual = transform_data(source)
    expected = spark.createDataFrame([(1, "A"), (2, "B")], ["id", "value_upper"])

    # assertDataFrameEqual checks both the schema and the row contents.
    assertDataFrameEqual(actual, expected)
def test_schema_validation(spark):
    """Validate the column names and types produced by process_data."""
    from my_module import process_data

    # Declare the input schema explicitly: Python ints are otherwise inferred as
    # bigint, which would never match an expected "int" column.
    df = spark.createDataFrame([(1, "test")], "id INT, name STRING")
    result = process_data(df)

    # str(result.schema) renders as "StructType([StructField(...), ...])", so a
    # string such as "id: int, ..." can never match. DataFrame.dtypes gives
    # (name, type) pairs instead. Nullability is deliberately not compared.
    expected_dtypes = [("id", "int"), ("name", "string"), ("processed", "boolean")]
    assert result.dtypes == expected_dtypes
def test_partition_count(spark):
    """Check that repartition_data produces exactly the requested partition count."""
    from my_module import repartition_data

    target = 10
    repartitioned = repartition_data(spark.range(1000), num_partitions=target)
    assert repartitioned.rdd.getNumPartitions() == target
Integration Testing
import tempfile
import os
def test_end_to_end_pipeline(spark):
    """Run run_pipeline on Parquet files in a scratch directory and check its output."""
    from my_module import run_pipeline

    with tempfile.TemporaryDirectory() as workdir:
        src_dir = os.path.join(workdir, "input")
        dst_dir = os.path.join(workdir, "output")

        # Stage a two-row input dataset as Parquet.
        input_df = spark.createDataFrame(
            [(1, "a", 100.0), (2, "b", 200.0)],
            ["id", "category", "amount"],
        )
        input_df.write.parquet(src_dir)

        run_pipeline(spark, src_dir, dst_dir)

        # The output must keep both rows and contain the derived column.
        output = spark.read.parquet(dst_dir)
        assert output.count() == 2
        assert "processed_amount" in output.columns
def test_with_external_services(spark):
    """Test process_with_api against a mocked external API client."""
    from unittest.mock import Mock, patch
    from my_module import process_with_api

    # Mock the external API: every get() call reports success.
    mock_api = Mock()
    mock_api.get.return_value = {"status": "success"}

    # Two input rows, so the data can satisfy the assertions below: two API calls
    # and two successful rows (presumably one call per row; confirm against my_module).
    df = spark.createDataFrame([(1, "test"), (2, "test2")], ["id", "name"])

    with patch("my_module.api_client", mock_api):
        result = process_with_api(df)
        # Spark is lazy: run the action while the patch is still active, so API
        # calls deferred to execution time still hit the mock.
        success_count = result.filter(F.col("status") == "success").count()

    # NOTE(review): patch() only replaces the client in the driver process. If
    # process_with_api calls the API inside a UDF, those calls presumably run in
    # separate Python worker processes that do not see this mock; confirm.
    assert mock_api.get.call_count == 2
    assert success_count == 2
⚠️
Warning: Integration tests with external services should use mocks to avoid flaky tests and reduce costs. Never test against production databases.
Data Quality Testing
def test_data_quality(spark, sample_data):
    """Validate data quality constraints on the sample event data."""
    # Required columns must be fully populated. Count the nulls for all of them
    # in one aggregation job instead of one job per column.
    required_columns = ["user_id", "event_type", "amount"]
    null_counts = sample_data.select(
        [F.count(F.when(F.col(c).isNull(), 1)).alias(c) for c in required_columns]
    ).first()
    for c in required_columns:
        assert null_counts[c] == 0, f"Nulls found in {c}: {null_counts[c]}"

    # Amounts must be non-negative.
    assert sample_data.filter(F.col("amount") < 0).count() == 0

    # Only known event types may appear. Collecting the few distinct values
    # directly avoids a pandas dependency (toPandas).
    valid_event_types = {"click", "view", "purchase"}
    distinct_rows = sample_data.select("event_type").distinct().collect()
    actual_types = {row["event_type"] for row in distinct_rows}
    assert actual_types <= valid_event_types, (
        f"Unexpected event types: {actual_types - valid_event_types}"
    )
def test_referential_integrity(spark):
    """Validate foreign key relationships: every order references an existing customer."""
    # NOTE(review): these paths point at an HDFS test cluster, so this is an
    # integration test rather than a local unit test. Consider small in-memory
    # fixtures for fast local runs.
    orders = spark.read.parquet("hdfs://test/orders")
    customers = spark.read.parquet("hdfs://test/customers")

    # A left anti join keeps only orders whose customer_id has no match in customers.
    orphan_orders = orders.join(customers, "customer_id", "left_anti")

    # Count once: every count() re-executes the join. The sample in the message
    # is only computed when the assertion fails.
    orphan_count = orphan_orders.count()
    assert orphan_count == 0, (
        f"Orphan orders found: {orphan_count}; sample customer_ids: "
        f"{orphan_orders.select('customer_id').limit(5).collect()}"
    )
Performance Testing
def test_query_performance(spark):
    """Ensure optimized_query completes within a time threshold."""
    import time
    from my_module import optimized_query

    threshold_s = 30.0
    df = spark.range(1_000_000).withColumn("value", F.rand())

    # perf_counter is monotonic. time.time() follows the system clock, which can
    # be adjusted (e.g. by NTP) mid-measurement.
    start = time.perf_counter()
    result = optimized_query(df)
    result.count()  # Materialize: transformations are lazy until an action runs.
    elapsed = time.perf_counter() - start
    assert elapsed < threshold_s, (
        f"Query took {elapsed:.1f}s, exceeds {threshold_s:.0f}s threshold"
    )
def test_shuffle_metrics(spark):
    """Verify that join_data avoids a shuffle-based join strategy.

    The right side has only 100 rows, so the join is expected to be planned as a
    broadcast hash join rather than a sort-merge join, which shuffles both sides.
    NOTE(review): this presumes join_data is an equi-join on `key` that broadcasts
    (or lets Spark broadcast) the small side; confirm against my_module.
    """
    import io
    from contextlib import redirect_stdout
    from my_module import join_data

    left = spark.range(100000).withColumn("key", F.col("id") % 100)
    right = spark.range(100).withColumn("key", F.col("id"))
    result = join_data(left, right)

    # DataFrame.explain() prints the physical plan; capture it to inspect the
    # chosen join strategy programmatically.
    buf = io.StringIO()
    with redirect_stdout(buf):
        result.explain()
    plan = buf.getvalue()

    assert "BroadcastHashJoin" in plan, f"Expected a broadcast join, got:\n{plan}"
    assert "SortMergeJoin" not in plan, f"Unexpected shuffle join in plan:\n{plan}"

    # Also execute the query to make sure it actually runs.
    result.count()
Testing Best Practices
# conftest.py - Shared test fixtures
import pytest
from pyspark.sql import SparkSession
@pytest.fixture(scope="session")
def spark():
    """Session-wide local SparkSession, stopped after the test session ends."""
    session = (
        SparkSession.builder
        .master("local[2]")
        .config("spark.ui.enabled", "false")
        # Few shuffle partitions keep small test datasets fast.
        .config("spark.sql.shuffle.partitions", "2")
        .getOrCreate()
    )
    # Yield rather than return, so teardown can stop the session and release
    # its driver resources once all tests have run.
    yield session
    session.stop()
@pytest.fixture
def test_data_path(tmp_path):
    """Return, as a str, a `test_data` path inside pytest's per-test tmp directory.

    The path itself is not created; tests create it as needed.
    """
    path = tmp_path / "test_data"
    return str(path)
# Parameterized tests
@pytest.mark.parametrize("tax_rate,expected", [
    (0.08, 8.0),
    (0.10, 10.0),
    (0.00, 0.0),
])
def test_tax_calculation(spark, tax_rate, expected):
    """Check that calculate_tax yields amount * tax_rate in the `tax` column."""
    from my_module import calculate_tax

    df = spark.createDataFrame([(100.0,)], ["amount"])
    result = calculate_tax(df, tax_rate)
    # Floating-point products (e.g. 100.0 * 0.08) are not guaranteed to be
    # exact, so compare with a tolerance instead of ==.
    assert result.collect()[0]["tax"] == pytest.approx(expected)
# Skip tests based on runtime conditions
def test_local_only(spark):
    """Placeholder test that should run only against a local Spark master."""
    # A skipif condition is evaluated at import/collection time, before fixtures
    # exist. At that point `spark` is the fixture function, not a SparkSession,
    # so the check has to run inside the test once the fixture is injected.
    if not spark.sparkContext.master.startswith("local"):
        pytest.skip("Only run locally")
ℹ️
Key Takeaway: Test Spark code systematically: unit test transformations locally, integration test complete pipelines, validate data quality, and ensure performance meets thresholds. Use small datasets and local mode for fast feedback.
Follow-Up Questions
- How would you test a Spark Streaming application?
- Explain strategies for testing UDFs and custom transformations.
- How do you handle flaky tests in Spark test suites?
- Describe testing approaches for Delta Lake operations.
- How would you test Spark applications that interact with external systems?