Testing Spark: LocalCluster, SharedSparkSession, Assertions
Difficulty: Expert | Companies: Databricks, Netflix, Uber, Airbnb, LinkedIn
ℹ️ Interview Context
Testing Spark applications is often overlooked but critical for production reliability. Interviewers expect knowledge of testing strategies, frameworks, and how to test Spark code efficiently.
Question
How do you test Spark applications effectively? Compare different testing approaches: local mode, LocalCluster, and SharedSparkSession. What are the best practices for testing DataFrame transformations, Spark SQL queries, and streaming queries? How do you handle flaky tests and slow test suites?
Detailed Answer
1. Testing Approaches Overview
# Testing approaches comparison:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 140" width="100%" style={{ maxWidth: 700 }} xmlns="http://www.w3.org/2000/svg">
<defs>
<linearGradient id="test-hdr" x1="0" y1="0" x2="1" y2="1">
<stop offset="0%" stopColor="#6366f1"/>
<stop offset="100%" stopColor="#4f46e5"/>
</linearGradient>
<filter id="test-shadow">
<feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
</filter>
</defs>
<rect x="10" y="10" width="780" height="120" rx="14" fill="#fff" filter="url(#test-shadow)" stroke="#e2e8f0" strokeWidth="1"/>
<rect x="10" y="10" width="780" height="30" rx="14" fill="url(#test-hdr)"/>
<rect x="10" y="24" width="780" height="16" fill="url(#test-hdr)"/>
<text x="140" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Approach</text>
<text x="340" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Speed</text>
<text x="500" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Realism</text>
<text x="680" y="30" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="11" fontWeight="700">Complexity</text>
<text x="140" y="56" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Local Mode</text>
<text x="340" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Fast</text>
<text x="500" y="56" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Low</text>
<text x="680" y="56" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Simple</text>
<line x1="30" y1="66" x2="770" y2="66" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="140" y="80" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">LocalCluster</text>
<text x="340" y="80" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
<text x="500" y="80" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
<text x="680" y="80" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
<line x1="30" y1="90" x2="770" y2="90" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="140" y="104" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">SharedSparkSession</text>
<text x="340" y="104" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
<text x="500" y="104" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">High</text>
<text x="680" y="104" textAnchor="middle" fill="#f59e0b" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Medium</text>
<line x1="30" y1="114" x2="770" y2="114" stroke="#e2e8f0" strokeWidth="0.5"/>
<text x="140" y="128" textAnchor="middle" fill="#334155" fontFamily="Inter,system-ui,sans-serif" fontSize="10">Integration Test</text>
<text x="340" y="128" textAnchor="middle" fill="#ef4444" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Slow</text>
<text x="500" y="128" textAnchor="middle" fill="#10b981" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">High</text>
<text x="680" y="128" textAnchor="middle" fill="#ef4444" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">Complex</text>
</svg>
</div>
# Recommendation: Use SharedSparkSession for unit tests
# Use integration tests for end-to-end validation
2. SharedSparkSession (Recommended)
import pytest
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    DoubleType, IntegerType, StringType, StructField, StructType,
)


@pytest.fixture(scope="session")
def spark():
    """Create shared SparkSession for all tests."""
    spark = SparkSession.builder \
        .master("local[*]") \
        .appName("test") \
        .config("spark.ui.enabled", "false") \
        .config("spark.sql.shuffle.partitions", "2") \
        .config("spark.driver.bindAddress", "127.0.0.1") \
        .getOrCreate()
    yield spark
    spark.stop()


@pytest.fixture(scope="function")
def sample_data(spark):
    """Create sample data for tests."""
    return spark.createDataFrame(
        [(1, "Alice", 100.0),
         (2, "Bob", 200.0),
         (3, "Charlie", 150.0)],
        ["id", "name", "amount"]
    )


def test_basic_transformation(spark, sample_data):
    """Test basic DataFrame transformation."""
    result = sample_data.withColumn(
        "doubled", F.col("amount") * 2
    )
    # Assert schema
    assert result.schema["doubled"].dataType == DoubleType()
    # Assert values
    rows = result.collect()
    assert len(rows) == 3
    assert rows[0]["doubled"] == 200.0
    assert rows[1]["doubled"] == 400.0
    assert rows[2]["doubled"] == 300.0


def test_filter(spark, sample_data):
    """Test filter operation."""
    result = sample_data.filter(F.col("amount") > 120)
    assert result.count() == 2
    # Compare names as a set so the assertion does not depend on row order
    assert {row["name"] for row in result.collect()} == {"Bob", "Charlie"}


def test_aggregation(spark, sample_data):
    """Test aggregation."""
    result = sample_data.agg(
        F.sum("amount").alias("total"),
        F.avg("amount").alias("avg")
    ).collect()[0]
    assert result["total"] == 450.0
    assert result["avg"] == 150.0
3. LocalCluster Mode
# LocalCluster mode: Runs executors in separate JVMs
# More realistic than local mode (separate processes)
@pytest.fixture(scope="session")
def spark_cluster():
    """Create SparkSession with LocalCluster."""
    # local-cluster[workers, cores per worker, memory per worker in MB]:
    # 2 workers with 4 cores and 1 GB each.
    # getOrCreate() reuses an already-running session, so run these tests
    # in a separate pytest invocation from the local[*] fixture.
    spark = SparkSession.builder \
        .master("local-cluster[2,4,1024]") \
        .appName("test-cluster") \
        .config("spark.ui.enabled", "false") \
        .config("spark.executor.memory", "1g") \
        .getOrCreate()
    yield spark
    spark.stop()


def test_with_cluster(spark_cluster):
    """Test with LocalCluster (more realistic)."""
    df = spark_cluster.range(1000000)
    # This test runs with actual executor processes
    result = df.repartition(10).count()
    assert result == 1000000
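The `local-cluster[...]` master string packs three numbers: the worker count, the cores per worker, and the memory per worker in MB. A small sketch of a helper that makes them explicit and rejects an executor memory setting larger than the worker memory, a mismatch that would leave the application waiting for executors. The name `local_cluster_master` and its parameters are hypothetical, not part of Spark:

```python
def local_cluster_master(workers: int, cores_per_worker: int,
                         worker_memory_mb: int, executor_memory_mb: int) -> str:
    """Build a local-cluster master URL (hypothetical helper, not a Spark API)."""
    sizes = (workers, cores_per_worker, worker_memory_mb, executor_memory_mb)
    if min(sizes) < 1:
        raise ValueError("all sizes must be positive")
    if executor_memory_mb > worker_memory_mb:
        # An executor that does not fit on any worker can never be scheduled
        raise ValueError(
            f"executor memory {executor_memory_mb} MB exceeds "
            f"worker memory {worker_memory_mb} MB"
        )
    return f"local-cluster[{workers},{cores_per_worker},{worker_memory_mb}]"


# Matches the fixture above: 2 workers, 4 cores each, 1024 MB each
master = local_cluster_master(2, 4, 1024, executor_memory_mb=1024)
```

The resulting string can be passed to `SparkSession.builder.master(...)`, with `spark.executor.memory` set to the same executor size.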
4. Test Assertions
# Comprehensive assertion strategies:
from collections import Counter

from pyspark.sql.utils import AnalysisException


def assert_dataframe_equal(actual, expected, check_schema=True,
                           check_values=True, check_row_order=True):
    """Assert two DataFrames are equal."""
    # Check schema
    if check_schema:
        assert actual.schema == expected.schema, \
            f"Schema mismatch: {actual.schema} != {expected.schema}"
    # Check row count
    actual_count = actual.count()
    expected_count = expected.count()
    assert actual_count == expected_count, \
        f"Row count mismatch: {actual_count} != {expected_count}"
    # Check values
    if check_values:
        if check_row_order:
            actual_rows = actual.collect()
            expected_rows = expected.collect()
            assert actual_rows == expected_rows, \
                f"Values mismatch:\nActual: {actual_rows}\nExpected: {expected_rows}"
        else:
            # Compare as multisets so duplicate rows are not collapsed
            actual_counts = Counter(actual.collect())
            expected_counts = Counter(expected.collect())
            assert actual_counts == expected_counts, \
                "Values mismatch (order ignored)"


def test_with_approx(spark):
    """Test with approximate floating point comparison."""
    df = spark.createDataFrame([(1.0,)], ["value"])
    result = df.withColumn("sqrt", F.sqrt(F.col("value"))).collect()[0]
    assert result["sqrt"] == pytest.approx(1.0, rel=1e-6)


def test_schema(spark):
    """Test DataFrame schema."""
    df = spark.createDataFrame([], "id: int, name: string, amount: double")
    expected_schema = StructType([
        StructField("id", IntegerType(), True),
        StructField("name", StringType(), True),
        StructField("amount", DoubleType(), True)
    ])
    assert df.schema == expected_schema


def test_exceptions(spark):
    """Test that exceptions are raised correctly."""
    df = spark.createDataFrame([], "id: int")
    # Column resolution fails during analysis, before any job runs
    with pytest.raises(AnalysisException):
        df.select("nonexistent_column").collect()
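Order-insensitive comparison and floating-point tolerance are often needed together, for example when an aggregation returns rows in arbitrary order with computed doubles. A minimal sketch that works on collected rows (PySpark `Row` objects are tuples, so plain tuples are used here); the helper name `rows_approx_equal` is hypothetical, and the pairwise matching is only meant for small test fixtures:

```python
import math


def _value_close(a, b, rel_tol, abs_tol):
    """Compare two cell values, using a tolerance only for floats."""
    if isinstance(a, float) and isinstance(b, float):
        if math.isnan(a) and math.isnan(b):
            return True
        return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
    return a == b


def _row_close(row, other, rel_tol, abs_tol):
    return len(row) == len(other) and all(
        _value_close(a, b, rel_tol, abs_tol) for a, b in zip(row, other)
    )


def rows_approx_equal(actual, expected, rel_tol=1e-6, abs_tol=1e-9):
    """Return True if both row lists hold the same rows, ignoring order."""
    if len(actual) != len(expected):
        return False
    remaining = list(expected)
    for row in actual:
        for i, candidate in enumerate(remaining):
            if _row_close(row, candidate, rel_tol, abs_tol):
                del remaining[i]  # each expected row may be matched only once
                break
        else:
            return False
    return True


actual_rows = [(2, "Bob", 0.1 + 0.2), (1, "Alice", 1.0)]
expected_rows = [(1, "Alice", 1.0), (2, "Bob", 0.3)]
same = rows_approx_equal(actual_rows, expected_rows)
```

In a Spark test the inputs would be `actual.collect()` and `expected.collect()`.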
5. Testing Transformations
# Test complex transformations:
from pyspark.sql import Window


def test_complex_transformation(spark):
    """Test a complex transformation pipeline."""
    # Input data
    input_df = spark.createDataFrame(
        [(1, "2024-01-01", 100.0),
         (2, "2024-01-02", 200.0),
         (3, "2024-01-03", None)],
        ["id", "date", "amount"]
    )

    # Transformation under test
    def transform(df):
        return df \
            .withColumn("date", F.to_date("date")) \
            .withColumn("amount", F.coalesce(F.col("amount"), F.lit(0.0))) \
            .withColumn("year", F.year("date")) \
            .withColumn("month", F.month("date")) \
            .filter(F.col("amount") > 0)

    # Apply transformation and keep only the columns under comparison
    result = transform(input_df).select("id", "year", "month", "amount")
    # Row 3 has a null amount, which becomes 0.0 and is filtered out
    expected = spark.createDataFrame(
        [(1, 2024, 1, 100.0),
         (2, 2024, 1, 200.0)],
        ["id", "year", "month", "amount"]
    )
    assert_dataframe_equal(result, expected, check_schema=False,
                           check_row_order=False)


def test_window_functions(spark):
    """Test window functions."""
    df = spark.createDataFrame(
        [("A", 1), ("A", 2), ("B", 3), ("B", 4)],
        ["group", "value"]
    )
    window = Window.partitionBy("group").orderBy("value")
    result = df.withColumn("rank", F.row_number().over(window))
    expected = spark.createDataFrame(
        [("A", 1, 1), ("A", 2, 2), ("B", 3, 1), ("B", 4, 2)],
        ["group", "value", "rank"]
    )
    # row_number() yields IntegerType while Python ints are inferred as
    # LongType, and a shuffle does not preserve row order
    assert_dataframe_equal(result, expected, check_schema=False,
                           check_row_order=False)
6. Testing Spark SQL
# Test Spark SQL queries:
def test_spark_sql(spark):
    """Test Spark SQL queries."""
    # Create temp view
    df = spark.createDataFrame(
        [(1, "Alice"), (2, "Bob")],
        ["id", "name"]
    )
    df.createOrReplaceTempView("users")
    # SQL query under test
    result = spark.sql("""
        SELECT id, name,
               CASE WHEN id = 1 THEN 'First' ELSE 'Other' END as position
        FROM users
        WHERE id > 0
    """)
    expected = spark.createDataFrame(
        [(1, "Alice", "First"), (2, "Bob", "Other")],
        ["id", "name", "position"]
    )
    # The CASE expression over literals is non-nullable, unlike the inferred
    # expected column, so compare values only
    assert_dataframe_equal(result, expected, check_schema=False,
                           check_row_order=False)


def test_udf_in_sql(spark):
    """Test UDF used in SQL."""
    @F.udf(returnType=StringType())
    def upper_udf(s):
        return s.upper() if s else None

    spark.udf.register("upper_udf", upper_udf)
    df = spark.createDataFrame([("alice",)], ["name"])
    df.createOrReplaceTempView("names")
    result = spark.sql("SELECT upper_udf(name) as upper_name FROM names")
    expected = spark.createDataFrame([("ALICE",)], ["upper_name"])
    assert_dataframe_equal(result, expected)
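The logic inside a UDF is ordinary Python, so it can be tested without starting Spark at all, leaving only a thin registration test for the SQL path. A minimal sketch, assuming the UDF body is moved into a module-level function (the name `to_upper` is hypothetical):

```python
from typing import Optional


def to_upper(s: Optional[str]) -> Optional[str]:
    """Upper-case a string; empty strings and None both map to None."""
    return s.upper() if s else None


# Plain unit checks: no SparkSession, no JVM startup cost
results = [to_upper("alice"), to_upper(""), to_upper(None)]

# In Spark code the same function would then be wrapped, for example:
#   upper_udf = F.udf(to_upper, StringType())
#   spark.udf.register("upper_udf", upper_udf)
```

Keeping edge cases such as empty strings and nulls in the plain-Python tests keeps the slower Spark-backed test down to a single happy-path check.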
7. Testing Streaming
# Test Structured Streaming:
def test_streaming_query(spark):
    """Test streaming query with test data."""
    # Create streaming source from rate
    stream_df = spark.readStream \
        .format("rate") \
        .option("rowsPerSecond", 10) \
        .load()
    # Apply transformation
    result = stream_df \
        .withColumn("double_value", F.col("value") * 2) \
        .select("timestamp", "double_value")
    # Write to memory sink
    query = result.writeStream \
        .format("memory") \
        .queryName("test_output") \
        .outputMode("append") \
        .start()
    try:
        # The rate source emits rows over wall-clock time, so let the query
        # run briefly before draining what is available
        query.awaitTermination(2)
        query.processAllAvailable()
        # Read results
        output_df = spark.sql("SELECT * FROM test_output")
        # Assert
        assert output_df.count() > 0
        assert "double_value" in output_df.columns
    finally:
        # Stop the query even when an assertion fails
        query.stop()


def test_streaming_with_watermark(spark):
    """Test streaming with watermark."""
    stream_df = spark.readStream \
        .format("rate") \
        .option("rowsPerSecond", 10) \
        .load() \
        .withColumn("event_time", F.current_timestamp())
    windowed = stream_df \
        .withWatermark("event_time", "10 seconds") \
        .groupBy(
            F.window("event_time", "5 seconds"),
        ).count()
    query = windowed.writeStream \
        .format("memory") \
        .queryName("windowed_output") \
        .outputMode("update") \
        .start()
    try:
        query.awaitTermination(2)
        query.processAllAvailable()
        output_df = spark.sql("SELECT * FROM windowed_output")
        assert output_df.count() > 0
    finally:
        query.stop()
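Streaming tests are a common source of flakiness because results arrive asynchronously. One way to avoid both fixed sleeps and unbounded waits is to poll for the expected condition with a deadline. A minimal sketch; the helper name `wait_until` and its default timings are assumptions, not part of Spark or pytest:

```python
import time


def wait_until(predicate, timeout_s=10.0, interval_s=0.1):
    """Poll predicate() until it returns True or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        if predicate():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Stand-in condition that becomes true on the third poll
polls = {"count": 0}


def has_rows():
    polls["count"] += 1
    return polls["count"] >= 3


arrived = wait_until(has_rows, timeout_s=2.0, interval_s=0.01)

# In a streaming test the predicate would query the memory sink, for example:
#   assert wait_until(lambda: spark.sql("SELECT * FROM test_output").count() > 0)
```

The test then fails with a clear assertion when data never arrives, instead of hanging or passing by luck.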
8. Performance Testing
# Test performance characteristics:
import os
import time

import psutil


def test_performance(spark):
    """Test that transformation meets performance requirements."""
    # Create large dataset
    df = spark.range(10000000).withColumn(
        "value", F.randn()
    )
    # Measure execution time with a monotonic clock
    start_time = time.perf_counter()
    result = df.groupBy(F.floor(F.col("id") / 1000)).agg(
        F.sum("value").alias("total")
    ).count()
    execution_time = time.perf_counter() - start_time
    # Assert performance requirement
    assert execution_time < 10.0, \
        f"Transformation too slow: {execution_time:.2f}s > 10s"
    # Assert result correctness: 10,000,000 ids / 1000 per group
    assert result == 10000


def test_memory_usage(spark):
    """Test memory usage doesn't exceed threshold."""
    process = psutil.Process(os.getpid())

    def total_rss_mb():
        # Spark's JVM typically runs as a child of the Python driver
        # process, so include child processes in the measurement
        procs = [process] + process.children(recursive=True)
        return sum(p.memory_info().rss for p in procs) / 1024 / 1024

    memory_before = total_rss_mb()
    # Perform operation
    df = spark.range(1000000)
    result = df.cache()
    result.count()
    memory_after = total_rss_mb()
    memory_delta = memory_after - memory_before
    # Assert memory usage
    assert memory_delta < 500, \
        f"Memory usage too high: {memory_delta:.1f}MB > 500MB"
    result.unpersist()
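A single wall-clock measurement is sensitive to JVM warm-up and noisy CI machines, which makes threshold assertions flaky. One option is to discard a warm-up run and assert on the median of several runs. A minimal sketch; the helper name `median_runtime` and its defaults are assumptions:

```python
import statistics
import time


def median_runtime(fn, repeats=5, warmup=1):
    """Return the median wall-clock time of fn() in seconds."""
    for _ in range(warmup):
        fn()  # warm-up runs are executed but not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload; in a Spark test this would trigger an action, e.g.
#   median_runtime(lambda: df.groupBy("key").count().count())
elapsed = median_runtime(lambda: sum(range(100_000)), repeats=3)
```

The threshold assertion then becomes `assert median_runtime(job) < 10.0`, at the cost of running the job `warmup + repeats` times.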
9. Test Organization
# Project structure:
<div className="my-6 flex justify-center">
<svg viewBox="0 0 800 240" width="100%" style={{ maxWidth: 500 }} xmlns="http://www.w3.org/2000/svg">
<defs>
<linearGradient id="proj-root" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#6366f1"/>
<stop offset="100%" stopColor="#4f46e5"/>
</linearGradient>
<linearGradient id="proj-folder" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#3b82f6"/>
<stop offset="100%" stopColor="#2563eb"/>
</linearGradient>
<linearGradient id="proj-file" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#10b981"/>
<stop offset="100%" stopColor="#059669"/>
</linearGradient>
<filter id="proj-shadow">
<feDropShadow dx="0" dy="2" stdDeviation="3" floodOpacity="0.12"/>
</filter>
</defs>
<rect x="20" y="10" width="120" height="28" rx="8" fill="url(#proj-root)" filter="url(#proj-shadow)"/>
<text x="80" y="29" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="10" fontWeight="600">tests/</text>
<line x1="80" y1="38" x2="80" y2="55" stroke="#94a3b8" strokeWidth="1.5"/>
<line x1="80" y1="55" x2="30" y2="55" stroke="#94a3b8" strokeWidth="1.5"/>
<line x1="80" y1="75" x2="30" y2="75" stroke="#94a3b8" strokeWidth="1.5"/>
<line x1="80" y1="95" x2="30" y2="95" stroke="#94a3b8" strokeWidth="1.5"/>
<line x1="80" y1="175" x2="30" y2="175" stroke="#94a3b8" strokeWidth="1.5"/>
<line x1="30" y1="55" x2="30" y2="175" stroke="#94a3b8" strokeWidth="1.5"/>
<rect x="40" y="48" width="180" height="20" rx="5" fill="url(#proj-file)" filter="url(#proj-shadow)"/>
<text x="130" y="62" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">conftest.py</text>
<text x="240" y="62" fill="#6b7280" fontFamily="Inter,system-ui,sans-serif" fontSize="9">Shared fixtures</text>
<rect x="40" y="70" width="100" height="20" rx="5" fill="url(#proj-folder)" filter="url(#proj-shadow)"/>
<text x="90" y="84" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">unit/</text>
<line x1="90" y1="90" x2="90" y2="100" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="90" y1="100" x2="60" y2="100" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="90" y1="115" x2="60" y2="115" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="90" y1="130" x2="60" y2="130" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="60" y1="100" x2="60" y2="130" stroke="#cbd5e1" strokeWidth="1"/>
<rect x="70" y="96" width="180" height="16" rx="4" fill="#d1fae5"/>
<text x="160" y="108" textAnchor="middle" fill="#065f46" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_transformations.py</text>
<rect x="70" y="112" width="180" height="16" rx="4" fill="#d1fae5"/>
<text x="160" y="124" textAnchor="middle" fill="#065f46" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_sql_queries.py</text>
<rect x="70" y="128" width="180" height="16" rx="4" fill="#d1fae5"/>
<text x="160" y="140" textAnchor="middle" fill="#065f46" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_udfs.py</text>
<rect x="40" y="95" width="130" height="20" rx="5" fill="url(#proj-folder)" filter="url(#proj-shadow)"/>
<text x="105" y="109" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">integration/</text>
<line x1="105" y1="115" x2="105" y2="150" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="105" y1="150" x2="70" y2="150" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="105" y1="163" x2="70" y2="163" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="70" y1="150" x2="70" y2="163" stroke="#cbd5e1" strokeWidth="1"/>
<rect x="80" y="146" width="190" height="16" rx="4" fill="#dbeafe"/>
<text x="175" y="158" textAnchor="middle" fill="#1e40af" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_data_pipeline.py</text>
<rect x="80" y="160" width="190" height="16" rx="4" fill="#dbeafe"/>
<text x="175" y="172" textAnchor="middle" fill="#1e40af" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_streaming.py</text>
<rect x="40" y="170" width="140" height="20" rx="5" fill="url(#proj-folder)" filter="url(#proj-shadow)"/>
<text x="110" y="184" textAnchor="middle" fill="#fff" fontFamily="Inter,system-ui,sans-serif" fontSize="9" fontWeight="500">performance/</text>
<line x1="110" y1="190" x2="110" y2="205" stroke="#cbd5e1" strokeWidth="1"/>
<line x1="110" y1="205" x2="70" y2="205" stroke="#cbd5e1" strokeWidth="1"/>
<rect x="80" y="200" width="180" height="16" rx="4" fill="#fef3c7"/>
<text x="170" y="212" textAnchor="middle" fill="#92400e" fontFamily="Inter,system-ui,sans-serif" fontSize="8">test_benchmarks.py</text>
</svg>
</div>
# conftest.py:
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark():
    spark = SparkSession.builder \
        .master("local[*]") \
        .appName("test") \
        .config("spark.ui.enabled", "false") \
        .config("spark.sql.shuffle.partitions", "2") \
        .getOrCreate()
    yield spark
    spark.stop()
# Running tests:
# pytest tests/ -v                 # Run all tests
# pytest tests/unit/ -v            # Run unit tests only
# pytest tests/ -k "test_filter"   # Run tests whose names match
# pytest tests/ --cov=src          # With coverage (requires pytest-cov)
# pytest tests/ -x                 # Stop on first failure
# pytest tests/ --timeout=60       # Timeout per test (requires pytest-timeout)
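⚠️ Common Pitfall
Leaving a SparkSession running after the test run leaks resources, and state left behind on a shared session (temp views, cached DataFrames, running streaming queries) makes tests flaky. Create the session in a fixture that calls spark.stop() in its teardown, and clean up per-test state rather than restarting the session for every test.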
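With a session-scoped SparkSession, per-test cleanup usually means resetting shared state instead of stopping the session. A minimal sketch of such a reset, using the `spark.streams.active`, `spark.catalog.clearCache()`, `spark.catalog.listTables()`, and `spark.catalog.dropTempView()` calls from the PySpark API; the function name `reset_spark_state` is hypothetical, and global temp views are not handled:

```python
def reset_spark_state(spark):
    """Undo state that tests commonly leave behind on a shared session."""
    # Stop streaming queries that a failed test may have left running
    for query in spark.streams.active:
        query.stop()
    # Drop all cached DataFrames and tables
    spark.catalog.clearCache()
    # Remove temp views in the current database listing
    for table in spark.catalog.listTables():
        if table.isTemporary:
            spark.catalog.dropTempView(table.name)


# One possible wiring in conftest.py (assumes the session-scoped `spark`
# fixture shown earlier):
#
#   @pytest.fixture(autouse=True)
#   def clean_spark(spark):
#       yield
#       reset_spark_state(spark)
```

Running the reset after every test keeps tests independent while still paying the session start-up cost only once.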
💡 Interview Tip
When discussing testing, mention that test data should be deterministic. Avoid using F.rand() in test data without setting a seed, as non-deterministic tests are flaky and hard to debug.
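One way to keep generated test data deterministic is to build the rows in plain Python from a seeded random generator and pass them to `spark.createDataFrame`. A minimal sketch; the function name `make_rows` and its column layout are hypothetical:

```python
import random


def make_rows(n, seed=42):
    """Generate n (id, name, amount) rows that are identical on every run."""
    rng = random.Random(seed)  # private generator, unaffected by global state
    names = ["Alice", "Bob", "Charlie"]
    return [
        (i, rng.choice(names), round(rng.uniform(0.0, 500.0), 2))
        for i in range(1, n + 1)
    ]


rows = make_rows(5)

# In a test: df = spark.createDataFrame(rows, ["id", "name", "amount"])
```

When random values must be produced inside Spark, `F.rand(seed=42)` accepts a seed as well, though the generated values can still depend on how the data is partitioned.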
Summary
| Approach | Speed | Use Case | Key Benefit |
|---|---|---|---|
| SharedSparkSession | Fast | Unit tests | Quick feedback |
| LocalCluster | Medium | Integration tests | Realistic execution |
| Memory Sink | Fast | Streaming tests | No external dependencies |
| Performance Tests | Slow | Benchmarks | Catch regressions |
The key to Spark testing is: use SharedSparkSession for speed, create deterministic test data, and test both correctness and performance.