PySpark Advanced Interview Series
Module 20: Production Deployment — From Dev to Prod
Interview Question
"At Google, we deploy Spark pipelines through robust CI/CD processes. Walk us through your testing strategy for Spark applications, how you would set up a CI/CD pipeline, and what production safeguards you would implement." — Google Data Engineer Interview
"At Microsoft, we need production-grade Spark applications. Explain how you would structure a Spark project, handle configuration management across environments, implement logging and monitoring, and ensure disaster recovery." — Microsoft Senior Data Engineer Interview
Project Structure
Configuration Management
# src/main/python/config/settings.py
import os
from dataclasses import dataclass
from typing import Optional
@dataclass(init=True, repr=True, eq=True)
class SparkConfig:
    """Spark resource and tuning settings for one deployment environment.

    Memory sizes are Spark size strings such as "4g"; the session factory
    passes them to the builder unchanged.
    """

    app_name: str  # spark application name
    master: str  # cluster manager URL, e.g. "local[*]" or "yarn"
    executor_memory: str  # spark.executor.memory
    executor_cores: int  # spark.executor.cores
    num_executors: int  # spark.executor.instances
    driver_memory: str  # spark.driver.memory
    shuffle_partitions: int  # spark.sql.shuffle.partitions
    enable_aqe: bool  # spark.sql.adaptive.enabled
@dataclass(init=True, repr=True, eq=True)
class StorageConfig:
    """Storage locations used by one deployment environment."""

    input_path: str  # location the pipeline reads its source data from
    output_path: str  # location the pipeline writes its results to
    checkpoint_path: str  # checkpoint directory (consumer not shown in this module)
    temp_path: str  # scratch directory (consumer not shown in this module)
@dataclass(init=True, repr=True, eq=True)
class AppConfig:
    """Complete configuration bundle for one deployment environment."""

    environment: str  # environment name, e.g. "dev" or "prod"
    spark: SparkConfig  # Spark session settings
    storage: StorageConfig  # data locations
    enable_monitoring: bool  # monitoring toggle (consumer not shown in this module)
    log_level: str  # logging level name, e.g. "DEBUG" or "WARN"
def load_config(environment: str) -> AppConfig:
    """Return the application configuration for a deployment environment.

    Args:
        environment: Environment name. Surrounding whitespace is ignored and
            matching is case-insensitive. Supported values: "dev", "prod".

    Returns:
        A freshly built AppConfig for that environment. The dict is rebuilt on
        every call, so callers may mutate the result without affecting others.

    Raises:
        KeyError: If ``environment`` is not a supported name. The message lists
            the valid names, so a deploy-time validation step fails with an
            actionable error.
    """
    # NOTE(review): the production checklist says "Deploy to staging first",
    # but there is no "staging" entry here — add one before relying on that step.
    configs = {
        "dev": AppConfig(
            environment="dev",
            spark=SparkConfig(
                app_name="MySparkApp-Dev",
                master="local[*]",
                executor_memory="4g",
                executor_cores=2,
                num_executors=2,
                driver_memory="2g",
                shuffle_partitions=20,
                enable_aqe=True,
            ),
            storage=StorageConfig(
                input_path="s3a://dev-bucket/input/",
                output_path="s3a://dev-bucket/output/",
                checkpoint_path="s3a://dev-bucket/checkpoints/",
                temp_path="s3a://dev-bucket/temp/",
            ),
            enable_monitoring=True,
            log_level="DEBUG",
        ),
        "prod": AppConfig(
            environment="prod",
            spark=SparkConfig(
                app_name="MySparkApp-Prod",
                master="yarn",
                executor_memory="16g",
                executor_cores=4,
                num_executors=100,
                driver_memory="8g",
                shuffle_partitions=500,
                enable_aqe=True,
            ),
            storage=StorageConfig(
                input_path="s3a://prod-bucket/input/",
                output_path="s3a://prod-bucket/output/",
                checkpoint_path="s3a://prod-bucket/checkpoints/",
                temp_path="s3a://prod-bucket/temp/",
            ),
            enable_monitoring=True,
            log_level="WARN",
        ),
    }
    # Normalize so "PROD" or "prod\n" (e.g. from CLI/env vars) still resolve.
    key = environment.strip().lower() if isinstance(environment, str) else environment
    try:
        return configs[key]
    except KeyError:
        raise KeyError(
            f"Unknown environment {environment!r}; expected one of {sorted(configs)}"
        ) from None
Spark Session Factory
# src/main/python/utils/spark.py
from pyspark.sql import SparkSession
from config.settings import AppConfig
class SparkSessionFactory:
    """Build and cache a single SparkSession for the current process.

    The first ``get_instance`` call creates the session from its ``config``.
    Later calls return the cached session and ignore their ``config``
    argument until ``stop`` clears the cache.
    """

    _instance = None  # cached SparkSession; None before first use and after stop()

    @staticmethod
    def _conf_value(value) -> str:
        """Render ``value`` as a Spark config string (booleans as "true"/"false")."""
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    @classmethod
    def get_instance(cls, config: AppConfig) -> SparkSession:
        """Return the cached SparkSession, creating it from ``config`` if needed.

        Args:
            config: Application configuration. It is only consulted when no
                session has been cached yet.

        Returns:
            The process-wide SparkSession.
        """
        if cls._instance is not None:
            return cls._instance

        spark_cfg = config.spark
        as_conf = cls._conf_value
        builder = (
            SparkSession.builder
            .appName(spark_cfg.app_name)
            # NOTE(review): properties set in code take precedence over
            # spark-submit flags such as --master; confirm that is intended
            # when the job is launched through spark-submit.
            .config("spark.master", spark_cfg.master)
            .config("spark.executor.memory", spark_cfg.executor_memory)
            .config("spark.executor.cores", as_conf(spark_cfg.executor_cores))
            .config("spark.executor.instances", as_conf(spark_cfg.num_executors))
            # In client mode the driver JVM is already running at this point,
            # so this setting has no effect there; pass --driver-memory to
            # spark-submit instead.
            .config("spark.driver.memory", spark_cfg.driver_memory)
            .config("spark.sql.shuffle.partitions", as_conf(spark_cfg.shuffle_partitions))
            .config("spark.sql.adaptive.enabled", as_conf(spark_cfg.enable_aqe))
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        )

        if config.environment == "prod":
            # Persist event logs so finished production runs can be inspected
            # from the Spark History Server's log directory.
            event_log_dir = "s3a://spark-logs/event-logs/"
            builder = (
                builder
                .config("spark.eventLog.enabled", "true")
                .config("spark.eventLog.dir", event_log_dir)
                .config("spark.history.fs.logDirectory", event_log_dir)
            )

        cls._instance = builder.getOrCreate()
        return cls._instance

    @classmethod
    def stop(cls):
        """Stop the cached SparkSession, if any, and clear the cache.

        The cache is cleared even if ``stop()`` raises, so a later
        ``get_instance`` call never hands out a half-stopped session.
        """
        if cls._instance is None:
            return
        try:
            cls._instance.stop()
        finally:
            cls._instance = None
Testing Strategies
Unit Tests
# tests/unit/test_transformers.py
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from src.etl.transformers import clean_data, add_features
@pytest.fixture(scope="session")
def spark():
    """Provide one local SparkSession shared by the whole test session.

    The session is stopped at teardown, so the local Spark context does not
    outlive the test run.
    """
    session = (
        SparkSession.builder
        .appName("TestSession")
        .master("local[*]")
        .getOrCreate()
    )
    yield session
    session.stop()
@pytest.fixture
def sample_data(spark):
    """Three-row DataFrame mixing one clean record with two dirty ones.

    Row id=2 has a null ``name``. Row id=3 has a negative ``amount`` and a
    ``date`` string that is not a valid date.
    """
    schema = StructType(
        [
            StructField("id", IntegerType(), False),
            StructField("name", StringType(), True),
            StructField("amount", IntegerType(), True),
            StructField("date", StringType(), True),
        ]
    )
    rows = [
        (1, "Alice", 100, "2024-01-01"),
        (2, None, 200, "2024-01-02"),
        (3, "Charlie", -50, "invalid-date"),
    ]
    return spark.createDataFrame(rows, schema)
def test_clean_data_removes_nulls(spark, sample_data):
    """After clean_data, no row may have a null ``name``."""
    cleaned = clean_data(sample_data)
    null_name_rows = cleaned.filter("name is null").count()
    assert null_name_rows == 0
def test_clean_data_validates_amount(spark, sample_data):
    """After clean_data, no row may have a negative ``amount``."""
    cleaned = clean_data(sample_data)
    negative_rows = cleaned.filter("amount < 0").count()
    assert negative_rows == 0
def test_add_features(spark, sample_data):
    """add_features must add an ``amount_category`` column to cleaned data."""
    featured = add_features(clean_data(sample_data))
    assert "amount_category" in featured.columns
def test_transformations_preserve_id(spark, sample_data):
    """clean_data must keep exactly ids 1 and 2 from the sample data."""
    # NOTE(review): id 2 has a null name in sample_data. This expectation only
    # holds if clean_data fills null names instead of dropping the row — confirm
    # against clean_data and test_clean_data_removes_nulls.
    kept_ids = {row["id"] for row in clean_data(sample_data).select("id").collect()}
    assert kept_ids == {1, 2}
Integration Tests
# tests/integration/test_pipeline.py
import pytest
from pyspark.sql import SparkSession
from src.main import run_pipeline
@pytest.fixture(scope="session")
def spark():
    """Provide one local SparkSession shared by all integration tests.

    The session is stopped at teardown, so the local Spark context does not
    outlive the test run.
    """
    session = (
        SparkSession.builder
        .appName("IntegrationTest")
        .master("local[*]")
        .getOrCreate()
    )
    yield session
    session.stop()
def test_full_pipeline(spark, test_data_path, test_output_path):
    """Run the ETL pipeline end to end and check the Parquet it writes."""
    # NOTE(review): test_data_path / test_output_path fixtures are not defined
    # here — presumably provided by a conftest.py; confirm.
    run_pipeline(spark, test_data_path, test_output_path)

    written = spark.read.parquet(test_output_path)

    # The pipeline must produce at least one row.
    assert written.count() > 0

    # Output columns must match exactly, ignoring order.
    # NOTE(review): the unit tests expect add_features to create
    # "amount_category", while this expects "category" — confirm which name
    # the pipeline actually writes.
    assert set(written.columns) == {"id", "name", "amount", "category"}

    # Data-quality rules the pipeline is expected to enforce.
    for violation in ("amount < 0", "name is null"):
        assert written.filter(violation).count() == 0
def test_pipeline_idempotency(spark, test_data_path, test_output_path):
    """Running the pipeline twice on the same input must yield identical output.

    Compares the full, order-independent set of output rows, not just the row
    count. A count-only check also passes when a rerun writes different rows
    of the same size.
    """

    def snapshot():
        # Materialize the current output before the next run overwrites it.
        # Sorting repr() strings gives an order-independent comparison that
        # keeps duplicates and tolerates None / nested values.
        rows = spark.read.parquet(test_output_path).collect()
        return sorted(repr(tuple(row)) for row in rows)

    run_pipeline(spark, test_data_path, test_output_path)
    first = snapshot()
    run_pipeline(spark, test_data_path, test_output_path)
    second = snapshot()

    # Count check first: it gives the clearest failure for append-instead-of-overwrite.
    assert len(first) == len(second)
    assert first == second
CI/CD Pipeline
# .github/workflows/spark-ci.yml
name: Spark CI/CD

on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]

# Least-privilege token: the workflow only needs to read the repository.
permissions:
  contents: read

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
          cache: pip

      # PySpark launches a JVM; pin its version instead of relying on
      # whatever the runner image happens to ship.
      - name: Set up Java
        uses: actions/setup-java@v4
        with:
          distribution: temurin
          java-version: '17'

      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest pyspark

      - name: Run unit tests
        run: pytest tests/unit/ -v --tb=short

      - name: Run integration tests
        run: pytest tests/integration/ -v --tb=short

  deploy-dev:
    needs: test
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/develop'
    environment: dev
    steps:
      - uses: actions/checkout@v4

      # deploy.sh runs Python (config validation, pytest, packaging), so the
      # deploy job needs the same toolchain as the test job.
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
          cache: pip

      - name: Set up Java
        uses: actions/setup-java@v4
        with:
          distribution: temurin
          java-version: '17'

      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest pyspark

      # NOTE(review): spark-submit --master yarn needs the cluster's Hadoop/YARN
      # client configuration; confirm this runner has it (e.g. self-hosted).
      - name: Deploy to Dev
        run: ./scripts/deploy.sh dev

  deploy-prod:
    needs: test
    runs-on: ubuntu-latest
    if: github.ref == 'refs/heads/main'
    # A GitHub environment lets the repository require manual approval
    # before anything reaches production.
    environment: production
    # Never run two production deployments at the same time.
    concurrency:
      group: deploy-prod
      cancel-in-progress: false
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
          cache: pip

      - name: Set up Java
        uses: actions/setup-java@v4
        with:
          distribution: temurin
          java-version: '17'

      - name: Install dependencies
        run: |
          pip install -r requirements.txt
          pip install pytest pyspark

      # NOTE(review): same cluster-access requirement as the dev deploy.
      - name: Deploy to Production
        run: ./scripts/deploy.sh prod
Production Deployment Scripts
#!/bin/bash
# scripts/deploy.sh — validate, test, package and submit the Spark application.
#
# Usage: ./scripts/deploy.sh <environment>    (e.g. dev | prod)
#
# Abort on the first failing command, on unset variables and on failures
# inside pipelines, so a failing test or build never reaches spark-submit.
set -euo pipefail

if [[ $# -ne 1 ]]; then
  echo "Usage: $0 <environment>" >&2
  exit 2
fi
ENVIRONMENT="$1"

echo "Deploying to $ENVIRONMENT..."

# Validate configuration. The environment name is passed as an argument rather
# than spliced into the Python source, so it cannot inject code.
python -c 'import sys; from config.settings import load_config; load_config(sys.argv[1])' "$ENVIRONMENT"

# Run tests
pytest tests/ -v --tb=short

# Package application from a clean tree so exactly one fresh wheel exists.
# NOTE(review): direct setup.py invocation is deprecated by setuptools;
# consider `python -m build --wheel`.
rm -rf build dist
python setup.py bdist_wheel
shopt -s nullglob
wheels=(dist/my_spark_app-*.whl)
if [[ ${#wheels[@]} -ne 1 ]]; then
  echo "Expected exactly one wheel in dist/, found ${#wheels[@]}" >&2
  exit 1
fi

# Deploy to cluster.
# spark-submit's primary resource must be a .py file (or jar); the wheel is
# shipped via --py-files, which works when the wheel is pure Python.
# NOTE(review): the entry-point path is an assumption (tests import
# src.main); override it with SPARK_ENTRY_POINT if it lives elsewhere.
ENTRY_POINT="${SPARK_ENTRY_POINT:-src/main.py}"
spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --name "MySparkApp-$ENVIRONMENT" \
  --conf spark.dynamicAllocation.enabled=true \
  --conf spark.sql.adaptive.enabled=true \
  --py-files "${wheels[0]}" \
  "$ENTRY_POINT" \
  --environment "$ENVIRONMENT"

echo "Deployment complete!"
Logging and Monitoring
# src/main/python/utils/logging.py
import logging
from pyspark.sql import SparkSession
def setup_logging(spark: SparkSession, log_level: str = "INFO"):
    """Configure Python logging and Spark's own log level.

    Args:
        spark: Active SparkSession whose log level is set.
        log_level: Level name such as "DEBUG", "INFO", "WARN"/"WARNING",
            "ERROR" or "CRITICAL". Matching is case-insensitive.

    Returns:
        This module's logger.

    Raises:
        ValueError: If ``log_level`` is not a logging level name.
    """
    level_name = log_level.strip().upper()
    level = getattr(logging, level_name, None)
    # getattr alone would also accept non-level attributes (e.g. BASIC_FORMAT).
    if not isinstance(level, int):
        raise ValueError(f"Unknown log level: {log_level!r}")

    # Python logging
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Spark logging: setLogLevel accepts log4j names (ALL, DEBUG, ERROR, FATAL,
    # INFO, OFF, TRACE, WARN), so translate the Python-only spellings.
    python_to_log4j = {"WARNING": "WARN", "CRITICAL": "FATAL", "NOTSET": "ALL"}
    spark.sparkContext.setLogLevel(python_to_log4j.get(level_name, level_name))

    logger = logging.getLogger(__name__)
    logger.info("Setting log level to %s", log_level)
    return logger
class SparkMetrics:
    """Time jobs and emit one duration/status metric record per job.

    Call ``log_job_start`` before ``log_job_end`` for the same job.
    ``send_metrics`` is a hook for forwarding records to a monitoring backend
    and currently does nothing.
    """

    def __init__(self, spark: SparkSession):
        self.spark = spark
        self.logger = logging.getLogger(__name__)
        # Monotonic start time of the current job; set by log_job_start.
        self.start_time = None

    def log_job_start(self, job_name: str):
        """Log that ``job_name`` started and begin timing it."""
        # Local import: this module's top-level imports do not include time.
        import time

        self.logger.info("Starting job: %s", job_name)
        # Monotonic clock: durations are unaffected by system clock changes.
        self.start_time = time.monotonic()

    def log_job_end(self, job_name: str, success: bool):
        """Log the outcome and duration of ``job_name`` and emit a metric record.

        Args:
            job_name: Name used in the log line and the metric record.
            success: Whether the job succeeded.

        Raises:
            RuntimeError: If called before ``log_job_start``.
        """
        # Local imports: this module's top-level imports include neither.
        import time
        from datetime import datetime, timezone

        if self.start_time is None:
            raise RuntimeError(
                f"log_job_end({job_name!r}) called before log_job_start"
            )
        duration = time.monotonic() - self.start_time
        status = "SUCCESS" if success else "FAILED"
        self.logger.info("Job %s %s in %.2fs", job_name, status, duration)

        # Send to monitoring system
        metrics = {
            "job_name": job_name,
            "duration": duration,
            "success": success,
            # Timezone-aware UTC timestamp, unambiguous across hosts.
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.send_metrics(metrics)

    def send_metrics(self, metrics: dict):
        """Forward ``metrics`` to an external monitoring system (not implemented)."""
        # Send to CloudWatch, Datadog, etc.
        pass
Error Handling and Recovery
# src/main/python/utils/error_handling.py
from pyspark.sql import SparkSession
import logging
logger = logging.getLogger(__name__)
class SparkJobError(Exception):
    """Error signalling that a Spark job still failed after all retry attempts."""
def retry_on_failure(max_retries: int = 3, delay: int = 60):
    """Build a decorator that retries the wrapped callable on any Exception.

    Args:
        max_retries: Total number of attempts, including the first; must be
            at least 1.
        delay: Seconds to sleep between consecutive attempts.

    Returns:
        A decorator. The decorated callable returns the first successful
        result. If every attempt raises, it raises SparkJobError chained
        (``__cause__``) to the last underlying exception.

    Raises:
        ValueError: If ``max_retries`` is less than 1. Such a decorator could
            never call the wrapped function.
    """
    # Local imports: this module's top-level imports include neither.
    import functools
    import time

    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    def decorator(func):
        # wraps keeps the wrapped function's name/docstring for logs and tooling.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    logger.warning("Attempt %d failed: %s", attempt, e)
                    if attempt == max_retries:
                        raise SparkJobError(
                            f"Job failed after {max_retries} attempts: {e}"
                        ) from e
                    time.sleep(delay)
        return wrapper
    return decorator
@retry_on_failure(max_retries=3)
def run_pipeline(spark: SparkSession, config):
    """Run the extract -> transform -> load pipeline, retrying on failure.

    Args:
        spark: Active SparkSession.
        config: Application config. ``config.storage.input_path`` and
            ``config.storage.output_path`` are read.

    Raises:
        SparkJobError: From the retry decorator once all attempts have failed.
    """
    try:
        # Extract
        raw_data = extract(spark, config.storage.input_path)
        # Transform
        transformed = transform(raw_data)
        # Load
        # NOTE(review): a retry re-runs load() after a partially failed attempt;
        # that is only safe if load() overwrites its output (idempotent write).
        load(transformed, config.storage.output_path)
    except Exception as e:
        # logger.exception records the full traceback, which the retry
        # wrapper's one-line warning does not.
        logger.exception("Pipeline failed: %s", e)
        raise
    else:
        logger.info("Pipeline completed successfully")
Production Checklist
💡 Production Deployment Checklist
Before Deployment:
- All unit tests pass
- All integration tests pass
- Configuration reviewed for target environment
- Resource requirements documented
- Monitoring and alerting configured
- Rollback plan documented
During Deployment:
- Deploy to staging first
- Run smoke tests
- Verify data quality
- Monitor job execution
After Deployment:
- Verify output data
- Check performance metrics
- Update documentation
- Notify stakeholders
Disaster Recovery
# Backup strategy
def backup_delta_table(spark, table_path, backup_path):
    """Copy the current snapshot of a Delta table to a backup location.

    Each call overwrites ``backup_path`` with the full current contents of
    ``table_path``, so the backup mirrors the source as of the backup time.
    Because the backup is itself a Delta table, earlier backup runs remain
    available as older versions of it (subject to VACUUM/retention).

    A MERGE keyed on a constant condition cannot tell new rows from existing
    ones, so a full snapshot overwrite is used instead. The DataFrame API also
    avoids splicing paths into SQL text.

    Args:
        spark: SparkSession with Delta Lake configured.
        table_path: Path of the source Delta table.
        backup_path: Path of the backup Delta table (created if absent).
    """
    snapshot = spark.read.format("delta").load(table_path)
    (
        snapshot.write.format("delta")
        .mode("overwrite")
        # Let the backup follow schema changes in the source table.
        .option("overwriteSchema", "true")
        .save(backup_path)
    )
# Recovery procedure
def restore_delta_table(spark, backup_path, restore_path, version=None):
    """Restore a Delta table, either by time travel or from a backup copy.

    Args:
        spark: SparkSession with Delta Lake configured.
        backup_path: Path of the backup Delta table. Used only when
            ``version`` is None.
        restore_path: Path of the Delta table to restore.
        version: If given, roll ``restore_path`` back to this version of its
            own history (``RESTORE ... TO VERSION AS OF``). If None, replace
            ``restore_path`` with the current contents of ``backup_path``.
    """
    if version is not None:
        # Explicit None check: version 0 is a valid target. int() rejects
        # non-numeric input before it is spliced into the SQL text.
        spark.sql(
            f"RESTORE TABLE delta.`{restore_path}` TO VERSION AS OF {int(version)}"
        )
        return

    # RESTORE only supports TO VERSION/TIMESTAMP AS OF, so restoring from a
    # separate backup location is done by overwriting the target with it.
    (
        spark.read.format("delta").load(backup_path)
        .write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .save(restore_path)
    )
Summary
Production Spark applications require robust testing, configuration management, CI/CD, monitoring, and disaster recovery. At Google and Microsoft, these practices ensure reliable, maintainable, and scalable data pipelines. The investment in production hardening pays dividends in reduced incidents and faster iteration.