PySpark Advanced Interview Series
Module 06: Partitioning - Controlling Data Distribution
Interview Question
"At Netflix, we process billions of viewing events daily. Walk us through how partitioning affects query performance, the difference between repartition and coalesce, and how you would design a partitioning strategy for a 10TB viewing history table queried by date and user_id." - Netflix Data Engineer Interview
"At Uber, we need real-time partitioning strategies for ride data. Explain hash vs range partitioning, how partition pruning works, and how you would handle a join between a partitioned fact table and an unpartitioned dimension table." - Uber Senior Data Engineer Interview
Partitioning Fundamentals
Partitioning is the process of dividing data into smaller, manageable chunks. In Spark, partitioning determines how data is distributed across executors and affects shuffle behavior, parallelism, and I/O efficiency.
Why Partitioning Matters
from pyspark.sql import SparkSession
spark = SparkSession.builder \
    .appName("PartitioningInterview") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()
# Create a large DataFrame
df = spark.range(0, 10000000, 1, 100) # 10M rows, 100 partitions
print(f"Number of partitions: {df.rdd.getNumPartitions()}") # 100
# Each partition is processed by one task in parallel
# More partitions = more parallelism (but also more scheduling overhead)
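Each partition becomes one task, so the partition count and the cluster's total cores together determine how many "waves" of tasks a stage runs in. Below is a minimal plain-Python sketch of that arithmetic. `task_waves` and `idle_cores_in_last_wave` are hypothetical helpers, and `total_cores` is assumed to be executors x cores per executor. The model ignores scheduling delays, locality waits, and speculative tasks.

```python
import math


def task_waves(num_partitions: int, total_cores: int) -> int:
    """Rounds of tasks needed when each core runs one task at a time."""
    if num_partitions <= 0 or total_cores <= 0:
        raise ValueError("num_partitions and total_cores must be positive")
    return math.ceil(num_partitions / total_cores)


def idle_cores_in_last_wave(num_partitions: int, total_cores: int) -> int:
    """Cores left without work while the final wave runs."""
    remainder = num_partitions % total_cores
    return 0 if remainder == 0 else total_cores - remainder


# 100 partitions on a hypothetical cluster with 40 cores (10 executors x 4 cores)
print(task_waves(100, 40))               # 3
print(idle_cores_in_last_wave(100, 40))  # 20
```

With 100 partitions on 40 cores, half the cores sit idle during the last wave. This is one reason partition counts are often chosen as a multiple of the core count.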
Types of Partitioning
1. Hash Partitioning
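Assigns each row to partition pmod(hash(key), numPartitions), so all rows with the same key land in the same partition. Partitions are balanced only when key frequencies are roughly uniform. A single hot key puts all of its rows into one partition.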
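from pyspark.sql.functions import col, expr
# spark.range() only has an "id" column, so derive a user_id to partition on
# (illustrative: 100,000 distinct users)
df = df.withColumn("user_id", (col("id") % 100000).cast("string"))
# Hash partition by a column: pmod(murmur3(user_id), 200)
df_hashed = df.repartition(200, "user_id")
# Partition by an expression instead of a plain column.
# Column has no .hash() method; use the SQL hash/pmod functions (pmod avoids negative values).
# Spark still hashes this expression again to pick the final partition.
df_custom = df.repartition(200, expr("pmod(hash(user_id), 200)"))
# Verify partition distribution (row count per partition)
partition_counts = df_hashed \
    .rdd \
    .mapPartitionsWithIndex(lambda idx, it: [(idx, sum(1 for _ in it))]) \
    .collect()
for idx, count in partition_counts[:10]:
    print(f"Partition {idx}: {count} rows")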
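The following plain-Python model shows why hash partitioning keeps every key together but can still leave partitions uneven. It uses `zlib.crc32` as a stand-in hash. This is an assumption: Spark actually uses Murmur3 with a non-negative modulo, so the partition numbers will not match Spark's. The behavior is the same, though: one hot key lands entirely in one partition. `hash_partition` and `partition_sizes` are hypothetical helpers.

```python
import zlib
from collections import Counter


def hash_partition(key: str, num_partitions: int) -> int:
    """Stable hash, then modulo; zlib.crc32 is unsigned in Python 3, so the result is never negative."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def partition_sizes(keys, num_partitions: int) -> Counter:
    """Count rows per partition for a list of keys."""
    return Counter(hash_partition(k, num_partitions) for k in keys)


# Uniform keys: 10,000 distinct users, one row each
uniform = [f"user_{i}" for i in range(10_000)]
# Skewed keys: the same users plus 5,000 extra rows for one hot user
skewed = uniform + ["user_42"] * 5_000

hot = hash_partition("user_42", 8)
print(partition_sizes(uniform, 8)[hot], partition_sizes(skewed, 8)[hot])
```

The hot user's partition grows by exactly 5,000 rows while every other partition is unchanged. Adding more partitions cannot split a single key.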
2. Range Partitioning
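Distributes data based on sorted value ranges. Spark samples the key column, picks boundary values, and sends each row to the partition whose range contains its key. Each partition therefore holds a contiguous range of keys, which suits time-series data.
from pyspark.sql.functions import col, floor
# Assumes df has an epoch-seconds "timestamp" column
# Bucketing rows by day and repartitioning on that bucket is still HASH partitioning:
# each day lands in one partition, but neighboring days are not kept together
df_by_day = df \
    .withColumn("day_id", floor(col("timestamp") / (24 * 3600)).cast("int")) \
    .repartition(100, "day_id")
# True range partitioning in the DataFrame API (PySpark 2.4+)
df_range = df.repartitionByRange(100, "timestamp")
# RDD level: PySpark has no RangePartitioner class; sortByKey samples the keys
# and range-partitions them before sorting within each partition
pair_rdd = df.rdd.map(lambda row: (row.timestamp, row))
partitioned_rdd = pair_rdd.sortByKey(numPartitions=100)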
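Range partitioning can be modeled in a few lines: sort a sample of keys, take evenly spaced values as partition bounds, then binary-search each key against the bounds. This is a simplified sketch of the idea behind `repartitionByRange` and `sortByKey`. Spark's own implementation adds refinements (sample sizing, weighting) that this version omits, and the helper names are hypothetical.

```python
import bisect
import random


def compute_range_bounds(sample, num_partitions: int):
    """Pick num_partitions - 1 cut points from a sorted sample (equal-frequency bounds)."""
    ordered = sorted(sample)
    step = len(ordered) / num_partitions
    return [ordered[int(step * i)] for i in range(1, num_partitions)]


def range_partition(key, bounds) -> int:
    """Index of the first bound >= key, so partition i holds keys <= bounds[i]."""
    return bisect.bisect_left(bounds, key)


# Hypothetical epoch-second timestamps covering about 28 hours
timestamps = list(range(1_700_000_000, 1_700_100_000))
# Work from a small sample rather than every key
sample = random.Random(0).sample(timestamps, 1_000)
bounds = compute_range_bounds(sample, 4)
ids = [range_partition(t, bounds) for t in timestamps]
print(bounds)
```

Because sorted keys map to non-decreasing partition ids, each partition covers one contiguous time window. Range filters and sorted output line up with partition boundaries.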
3. Round-Robin Partitioning
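Distributes rows across partitions without looking at their content, so partitions end up with nearly equal row counts.
# Round-robin (no column specified)
df_roundrobin = df.repartition(200)
# This distributes rows evenly across 200 partitions
# Useful for evening out uneven input partitions (e.g., after a selective filter).
# It does not fix join skew: a join shuffles again by the join key.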
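Round-robin assignment is simple enough to model directly: row i goes to partition (start + i) mod n. Spark picks a pseudo-random start position for each input partition, and the `start` argument below stands in for that. Within one input partition, row counts therefore differ by at most one. `round_robin` is a hypothetical helper.

```python
from collections import Counter


def round_robin(num_rows: int, num_partitions: int, start: int = 0) -> Counter:
    """Assign row i to partition (start + i) % num_partitions and count rows per partition."""
    return Counter((start + i) % num_partitions for i in range(num_rows))


counts = round_robin(1_003, 200, start=17)
print(max(counts.values()) - min(counts.values()))  # 1
```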
4. Coalesce (Reduce Partitions)
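Reduces the number of partitions by merging existing partitions, without a full shuffle. Efficient for decreasing partitions.
# Coalesce down to 50 partitions (no full shuffle)
df_coalesced = df.coalesce(50)
# Coalesce is efficient when:
# 1. Reducing partitions (e.g., 200 -> 50)
# 2. Data is already relatively balanced (merged partitions keep any existing skew)
# 3. You want to avoid a full shuffle
# Caveat: coalesce can lower the parallelism of upstream work in the same stage;
# coalesce(1) can force the whole computation onto a single task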
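The difference between the two operations shows up in partition sizes. In this plain-Python model, `coalesce_sizes` merges whole neighboring partitions. This is an assumption: Spark groups partitions by data locality, not strictly by position. `repartition_sizes` spreads all rows evenly, as a round-robin repartition would. Both helpers are hypothetical.

```python
def coalesce_sizes(sizes, target: int):
    """Merge existing partitions into at most `target` groups of neighboring partitions.

    Rows never leave their original partition; whole partitions are merged.
    Asking for more partitions than exist leaves the count unchanged.
    """
    n = min(target, len(sizes))
    groups = [0] * n
    for i, size in enumerate(sizes):
        groups[i * n // len(sizes)] += size
    return groups


def repartition_sizes(sizes, target: int):
    """Redistribute all rows evenly into exactly `target` partitions."""
    base, extra = divmod(sum(sizes), target)
    return [base + (1 if i < extra else 0) for i in range(target)]


# 8 input partitions, one of them much larger
sizes = [100, 100, 100, 100, 100, 100, 100, 1_000]
print(coalesce_sizes(sizes, 4))        # [200, 200, 200, 1100]
print(len(coalesce_sizes(sizes, 16)))  # 8
print(repartition_sizes(sizes, 4))     # [425, 425, 425, 425]
```

Coalesce keeps the oversized partition intact (it just gains a neighbor), while repartition pays for a shuffle to even everything out.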
Repartition vs Coalesce
| Operation | Shuffle | Use Case | Performance |
|---|---|---|---|
| repartition(n) | Full shuffle | Increase or decrease partition count | Slower (full shuffle); evenly sized output |
| repartition(n, col) | Full shuffle | Partition by column | Slower (full shuffle); co-locates rows with equal keys |
| coalesce(n) | No shuffle | Decrease partition count only | Usually faster; keeps existing skew and can reduce upstream parallelism |
# Start from 200 partitions so the counts in the comments below hold
df200 = df.repartition(200)
# repartition: full shuffle, can increase or decrease
df_repartitioned = df200.repartition(500)  # Full shuffle to 500 partitions
df_repartitioned = df200.repartition(200, "user_id")  # Full shuffle by user_id
# coalesce: no shuffle, can only decrease
df_coalesced = df200.coalesce(50)  # Merge 200 partitions into 50 (no shuffle)
# IMPORTANT: coalesce cannot increase partitions
# Requesting more partitions than exist leaves the count unchanged
df_wrong = df200.coalesce(500)  # Still has 200 partitions!
⚠️ Common Pitfall
At Netflix, a common mistake is using coalesce() to increase partitions. Coalesce can only reduce partitions; asking for more than the current count leaves the count unchanged. If you need more parallelism, use repartition(), which performs a full shuffle.
Partition Pruning
Spark can skip reading entire directories when a query filters on the columns used in partitionBy() at write time. This is on-disk (directory) partitioning, which is separate from the in-memory partitions discussed above.
# Write data partitioned by date (assumes df has year, month and day columns)
df.write \
    .partitionBy("year", "month", "day") \
    .parquet("s3a://bucket/partitioned-data/")
# Read with partition pruning
# Spark only lists and reads files under the matching directories
df_pruned = spark.read.parquet("s3a://bucket/partitioned-data/") \
    .filter(col("year") == 2024) \
    .filter(col("month") == 1) \
    .filter(col("day") == 15)
# Verify pruning with explain plan
df_pruned.explain(True)
# The FileScan node lists the year, month and day predicates under PartitionFilters
# (alongside isnotnull checks)
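Pruning happens before any data file is opened. Spark lists the `key=value` directories, evaluates the partition filters against the values encoded in each path, and schedules only the files under matching directories. Below is a plain-Python sketch of that step. It compares strings, whereas Spark compares typed values after partition type inference. The directory layout and helper names are hypothetical.

```python
def parse_partition_path(path: str) -> dict:
    """Turn 'year=2024/month=1/day=15' into {'year': '2024', 'month': '1', 'day': '15'}."""
    return dict(segment.split("=", 1) for segment in path.strip("/").split("/"))


def prune(paths, predicate):
    """Keep only directories whose partition values satisfy the predicate."""
    return [p for p in paths if predicate(parse_partition_path(p))]


# Hypothetical table layout: one directory per day, 28 days in each of two months
paths = [f"year=2024/month={m}/day={d}" for m in (1, 2) for d in range(1, 29)]
wanted = prune(
    paths,
    lambda v: v["year"] == "2024" and v["month"] == "1" and v["day"] == "15",
)
print(len(paths), wanted)  # 56 ['year=2024/month=1/day=15']
```

Only 1 of 56 directories survives, so only its files are ever read.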
Partition Discovery
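# Spark automatically discovers partitions from key=value directory names
df = spark.read.parquet("s3a://bucket/data/")
df.printSchema()  # Includes the partition columns
# Disable partition inference (Spark 3.0+): recursiveFileLookup loads every file
# under the path and does not turn key=value directories into columns
df = spark.read.option("recursiveFileLookup", "true").parquet("s3a://bucket/data/")
# Read one subdirectory but keep its partition columns by pointing basePath at the table root
df = spark.read.option("basePath", "s3a://bucket/data/").parquet("s3a://bucket/data/year=2024/")
# Keep partition values as strings instead of inferring numeric/date types
# (to fix specific types instead, pass a schema that includes the partition columns via spark.read.schema(...))
spark.conf.set("spark.sql.sources.partitionColumnTypeInference.enabled", "false")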
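Discovery boils down to parsing `key=value` path segments and guessing a type for each value. The simplified model below tries int, then float, then falls back to string. Spark's real inference also recognizes dates, timestamps, and decimals. The helper names are hypothetical. Note how `month=01` becomes the integer 1. Spark can likewise infer such values as integers, which is a common reason to turn type inference off.

```python
def infer_value(raw: str):
    """Simplified inference: int if possible, then float, else keep the string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def discover_partition_columns(path: str) -> dict:
    """Extract typed partition columns from key=value segments; other segments are ignored."""
    columns = {}
    for segment in path.strip("/").split("/"):
        if "=" in segment:
            key, raw = segment.split("=", 1)
            columns[key] = infer_value(raw)
    return columns


print(discover_partition_columns(
    "s3a://bucket/data/year=2024/month=01/region=us-east-1/part-0000.parquet"
))
# {'year': 2024, 'month': 1, 'region': 'us-east-1'}
```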
Real-World Scenario: Netflix Viewing Analytics
Problem Statement
Design an optimal partitioning strategy for a 10TB viewing history table that supports:
- Daily analytics queries (filter by date)
- User-level queries (filter by user_id)
- Content-level queries (filter by content_id)
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, dayofmonth, dayofweek, month, to_date, year
spark = SparkSession.builder \
    .appName("NetflixPartitioning") \
    .config("spark.sql.shuffle.partitions", "500") \
    .config("spark.sql.adaptive.enabled", "true") \
    .getOrCreate()
# Read raw viewing data (assumed columns: user_id, content_id, view_timestamp, ...)
viewing_df = spark.read.parquet("s3a://netflix-data/viewing-raw/")
# Step 1: Add partition columns
partitioned_df = viewing_df \
    .withColumn("view_date", to_date(col("view_timestamp"))) \
    .withColumn("year", year(col("view_date"))) \
    .withColumn("month", month(col("view_date"))) \
    .withColumn("day", dayofmonth(col("view_date"))) \
    .withColumn("day_of_week", dayofweek(col("view_date")))
# Step 2: Write with multi-level partitioning
# Partition by date first (most common filter)
partitioned_df.write \
    .partitionBy("year", "month", "day") \
    .mode("overwrite") \
    .option("compression", "snappy") \
    .parquet("s3a://netflix-data/viewing-partitioned/")
# Step 3: For user-level queries, create a bucketed copy
# Bucket by user_id for fast user lookups (bucketBy requires saveAsTable)
# Trade-off: each extra layout stores another full copy of the data
partitioned_df.write \
    .bucketBy(1024, "user_id") \
    .sortBy("view_timestamp") \
    .partitionBy("year", "month") \
    .mode("overwrite") \
    .saveAsTable("viewing_bucketed")
# Step 4: For content-level queries, create another bucketed table
partitioned_df.write \
    .bucketBy(512, "content_id") \
    .sortBy("view_timestamp") \
    .partitionBy("year", "month") \
    .mode("overwrite") \
    .saveAsTable("viewing_by_content")
# Query optimization examples
# 1. Daily analytics (benefits from date partitioning)
daily_stats = spark.read.parquet("s3a://netflix-data/viewing-partitioned/") \
    .filter((col("year") == 2024) & (col("month") == 1) & (col("day") == 15)) \
    .groupBy("content_id") \
    .agg(count("*").alias("views"))
# 2. User-level query (bucket pruning, Spark 2.4+: an equality filter on user_id
#    reads only the matching bucket in each year/month directory)
user_history = spark.table("viewing_bucketed") \
    .filter(col("user_id") == "user_12345") \
    .orderBy(col("view_timestamp").desc())
# 3. Content-level query (benefits from content bucketing)
content_stats = spark.table("viewing_by_content") \
    .filter(col("content_id") == "content_67890") \
    .groupBy("user_id") \
    .agg(count("*").alias("views"))
# Show explain plans
daily_stats.explain(True)
user_history.explain(True)
content_stats.explain(True)
spark.stop()
Partition Count Guidelines
Too Few Partitions
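# BAD: 10TB dataset squeezed into 10 partitions = ~1TB per partition
# (file scans normally split input by spark.sql.files.maxPartitionBytes, 128MB by default,
#  so this usually comes from an explicit coalesce or a few huge unsplittable files)
df = spark.read.parquet("s3a://bucket/huge-data/").coalesce(10)
print(f"Partitions: {df.rdd.getNumPartitions()}")  # 10
# Each partition is far too large for one task to process efficiently
# Leads to spills, OOM errors, long garbage collection pauses and poor parallelism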
Too Many Partitions
# BAD: 1M rows spread over 10,000 partitions = 100 rows per partition
df = spark.range(0, 1000000, 1, 10000)  # 1M rows, 10K partitions
print(f"Partitions: {df.rdd.getNumPartitions()}")  # 10000
# Each partition is tiny, causing:
# 1. Excessive task scheduling overhead (10,000 tasks for ~8MB of long values)
# 2. Many small files if the result is written out
# 3. Extra driver memory for tracking task and partition metadata
Sweet Spot
# Rule of thumb: 128MB-256MB per partition
# For 10TB dataset: 10TB / 256MB = ~40,000 partitions
# For 1GB dataset: 1GB / 256MB = ~4 partitions
# Calculate optimal partition count
import math
data_size_mb = 10 * 1024 * 1024  # 10TB in MB
target_partition_size_mb = 256
optimal_partitions = math.ceil(data_size_mb / target_partition_size_mb)
print(f"Optimal partitions: {optimal_partitions}")  # 40960
💡 Uber Pro Tip
At Uber, the default spark.sql.shuffle.partitions=200 is often too low for large joins. Tune it to the amount of data being shuffled (the shuffle write size in the Spark UI), not the input size. A common formula: partitions = shuffle_data_size_bytes / (128 * 1024 * 1024). With AQE enabled (the default since Spark 3.2), Spark can also merge small shuffle partitions after each stage.
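The tip's formula can be wrapped in a small helper. This sketch is hypothetical: `shuffle_bytes` should come from a previous run's shuffle write size. Rounding up to a multiple of the total core count is an added assumption (it avoids a partially filled last wave of tasks) and is not part of the original formula.

```python
import math


def shuffle_partitions(shuffle_bytes: int, target_bytes: int = 128 * 1024 * 1024,
                       total_cores: int = 1) -> int:
    """partitions = shuffle size / target size, rounded up to a multiple of the core count."""
    raw = math.ceil(shuffle_bytes / target_bytes)
    return max(total_cores, math.ceil(raw / total_cores) * total_cores)


one_tib = 1024 ** 4
print(shuffle_partitions(one_tib))                   # 8192
print(shuffle_partitions(one_tib, total_cores=500))  # 8500
# Then, for example:
# spark.conf.set("spark.sql.shuffle.partitions", str(shuffle_partitions(one_tib, total_cores=500)))
```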
Handling Partition Skew
Detecting Skew
# Check partition sizes (row counts per partition)
partition_sizes = df.rdd \
    .mapPartitionsWithIndex(lambda idx, it: [(idx, sum(1 for _ in it))]) \
    .collect()
sizes = [size for _, size in partition_sizes]
print(f"Min partition rows: {min(sizes)}")
print(f"Max partition rows: {max(sizes)}")
print(f"Avg partition rows: {sum(sizes) / len(sizes):.1f}")
print(f"Skew ratio (max / avg): {max(sizes) / (sum(sizes) / len(sizes)):.2f}")
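Comparing the maximum to the mean understates skew, because the outlier itself inflates the mean. Below is a hypothetical `skew_report` helper that also flags partitions larger than a multiple of the median. It takes the per-partition counts collected above, and the 5x factor is an assumed threshold to tune per workload.

```python
import statistics


def skew_report(partition_counts, factor: float = 5.0) -> dict:
    """Summarize per-partition row counts and flag partitions above factor x median."""
    counts = list(partition_counts)
    median = statistics.median(counts)
    mean = sum(counts) / len(counts)
    flagged = [i for i, c in enumerate(counts) if median > 0 and c > factor * median]
    return {
        "min": min(counts),
        "max": max(counts),
        "median": median,
        "skew_ratio": max(counts) / mean if mean else 0.0,
        "skewed_partitions": flagged,
    }


report = skew_report([1_000, 1_100, 950, 1_050, 25_000, 980])
print(report["median"])             # 1025.0
print(report["skewed_partitions"])  # [4]
```

Here the max/mean ratio is only about 5, even though the hot partition is roughly 24x the median.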
Fixing Skew
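# Strategy 1: Salt the skewed key
from pyspark.sql.functions import col, concat, lit, rand, when
SALT_BUCKETS = 100
# Identify skewed keys (the 100,000-row threshold is workload-specific)
skewed_keys = df \
    .groupBy("user_id") \
    .count() \
    .filter(col("count") > 100000) \
    .select("user_id") \
    .collect()
skewed_key_list = [row.user_id for row in skewed_keys]
# Salt the skewed keys: random salt in [0, SALT_BUCKETS) for hot keys, 0 otherwise
salted_df = df.withColumn(
    "salt",
    when(
        col("user_id").isin(skewed_key_list),
        (rand() * SALT_BUCKETS).cast("int")
    ).otherwise(0)
).withColumn(
    "salted_key",
    when(
        col("salt") > 0,
        concat(col("user_id").cast("string"), lit("_"), col("salt").cast("string"))
    ).otherwise(col("user_id").cast("string"))
)
# Repartition by salted key: each hot key now spreads over up to SALT_BUCKETS partitions.
# Aggregations then need a second pass to merge the partial results per user_id,
# and joins need the other side replicated once per salt value.
df_fixed = salted_df.repartition(400, "salted_key")
# Strategy 2: Use AQE (Spark 3.0+), which splits skewed partitions in sort-merge joins at runtime
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")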
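Salting a join only works if the other side carries every salt value for the hot keys; otherwise salted rows find no match. The plain-Python model below replicates each hot-key dimension row once per salt and checks that the salted join returns exactly what a plain join returns. `salted_join`, `plain_join`, and the sample data are hypothetical.

```python
import random
from collections import defaultdict


def salted_join(fact_rows, dim_rows, hot_keys, salt_buckets: int, seed: int = 0):
    """Inner-join (key, value) fact rows to (key, attr) dim rows, salting hot keys.

    Fact side: each hot-key row gets a random salt in [0, salt_buckets).
    Dim side: each hot-key row is replicated once per salt so every salted fact row matches.
    Non-hot keys use salt 0 on both sides.
    """
    rng = random.Random(seed)
    dim_index = defaultdict(list)
    for key, attr in dim_rows:
        for salt in (range(salt_buckets) if key in hot_keys else [0]):
            dim_index[(key, salt)].append(attr)

    result = []
    for key, value in fact_rows:
        salt = rng.randrange(salt_buckets) if key in hot_keys else 0
        for attr in dim_index.get((key, salt), []):
            result.append((key, value, attr))
    return result


def plain_join(fact_rows, dim_rows):
    """Reference inner join without salting."""
    dim_index = defaultdict(list)
    for key, attr in dim_rows:
        dim_index[key].append(attr)
    return [(k, v, a) for k, v in fact_rows for a in dim_index.get(k, [])]


facts = [("hot", i) for i in range(1_000)] + [("cold", i) for i in range(10)]
dims = [("hot", "premium"), ("cold", "basic")]
salted = salted_join(facts, dims, hot_keys={"hot"}, salt_buckets=8)
print(sorted(salted) == sorted(plain_join(facts, dims)))  # True
```

In PySpark, the dimension-side replication is typically done with `explode(sequence(lit(0), lit(SALT_BUCKETS - 1)))`, building the same salted key format the fact side uses. The cost is a dimension side that grows by a factor of SALT_BUCKETS for the hot keys.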
Partition Pruning Optimization
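# Write with partition columns (assumes df has date and region columns)
df.write \
    .partitionBy("date", "region") \
    .parquet("s3a://bucket/data/")
# Query with partition filters (pruned)
df = spark.read.parquet("s3a://bucket/data/") \
    .filter(col("date") == "2024-01-15") \
    .filter(col("region") == "us-east-1")
# Verify pruning in explain plan
df.explain(True)
# The FileScan node lists the date and region predicates under PartitionFilters
# Not an anti-pattern: filtering in a separate statement. Spark evaluates lazily,
# so this produces the same plan (and the same pruning) as chaining .filter() on the read
df_same = spark.read.parquet("s3a://bucket/data/")
df_same = df_same.filter(col("date") == "2024-01-15")
# Real anti-pattern: filtering on a data column instead of the partition column.
# event_ts is a hypothetical timestamp column; no directories are pruned, so every
# date directory is listed and scanned (Parquet statistics may still skip some row groups)
df_bad = spark.read.parquet("s3a://bucket/data/") \
    .filter(col("event_ts").cast("date") == "2024-01-15")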
Dynamic Partition Overwrite
# Enable dynamic partition overwrite
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
# Write with dynamic overwrite (only overwrites affected partitions)
df.write \
    .mode("overwrite") \
    .partitionBy("date") \
    .parquet("s3a://bucket/data/")
# Without dynamic mode (the default is "static"), overwrite deletes ALL existing partitions first
# With dynamic mode, only the partitions present in df are replaced
ℹ️ Netflix Interview Insight
At Netflix, dynamic partition overwrite is critical for incremental daily loads. Without it, overwriting a single day's partition deletes every other day's data. Enable spark.sql.sources.partitionOverwriteMode=dynamic in production, or pass .option("partitionOverwriteMode", "dynamic") on the individual write, and reserve static overwrite for full rebuilds.
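The two modes differ only in which existing partitions survive the write. In the plain-Python model below, a table is represented as a hypothetical dict mapping each partition value to its file list. The `overwrite` helper is illustrative, not a Spark API.

```python
def overwrite(existing: dict, incoming: dict, mode: str) -> dict:
    """Model of an overwrite of a date-partitioned table.

    static:  every existing partition is dropped, then incoming partitions are written.
    dynamic: only partitions present in `incoming` are replaced; the rest are kept.
    """
    if mode == "static":
        return dict(incoming)
    if mode == "dynamic":
        merged = dict(existing)
        merged.update(incoming)
        return merged
    raise ValueError(f"unknown mode: {mode}")


table = {"2024-01-13": ["a.parquet"], "2024-01-14": ["b.parquet"], "2024-01-15": ["c.parquet"]}
todays_load = {"2024-01-15": ["c2.parquet"]}

print(sorted(overwrite(table, todays_load, "static")))   # ['2024-01-15']
print(sorted(overwrite(table, todays_load, "dynamic")))  # ['2024-01-13', '2024-01-14', '2024-01-15']
```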
Best Practices
| Practice | When to Use | Example |
|---|---|---|
| Partition by date | Time-series data | .partitionBy("year", "month", "day") |
| Bucket by join key | Repeated joins | .bucketBy(256, "user_id") |
| Coalesce before write | Reducing output file count | .coalesce(100).write.parquet(path) |
| Repartition before join | Reusing one shuffle across several joins or aggregations on the same key | .repartition(200, "join_key") |
| Dynamic partition overwrite | Incremental loads | spark.sql.sources.partitionOverwriteMode=dynamic |
| Partition pruning | Query optimization | Filter directly on partition columns |
Common Anti-Patterns
# ANTI-PATTERN 1: Coalesce to increase partitions
df.coalesce(1000) # WRONG: cannot increase beyond current count
# CORRECT
df.repartition(1000)
# ANTI-PATTERN 2: Partition by high-cardinality column
df.write.partitionBy("user_id").parquet("s3a://bucket/data/")  # WRONG: millions of directories!
# Each directory holds very few rows, producing many small files
# CORRECT: Partition by a low-cardinality column, bucket by the high-cardinality one
# (bucketBy only works with saveAsTable; the table name is illustrative)
df.write.partitionBy("date").bucketBy(256, "user_id").saveAsTable("events_bucketed")
# ANTI-PATTERN 3: Not checking partition distribution
df.write.partitionBy("date").parquet("s3a://bucket/data/")
# If some dates have 100x more data than others, the write tasks and files for those dates are skewed
# CORRECT: Check distribution before writing
df.groupBy("date").count().orderBy(col("count").desc()).show()
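A rough way to reason about file counts: each write task produces one file for every partition value it holds. The plain-Python model below counts distinct (task, partition value) pairs. It ignores further splitting by `spark.sql.files.maxRecordsPerFile`. The task layouts and the `output_file_count` helper are hypothetical.

```python
def output_file_count(rows) -> int:
    """Files written by partitionBy: one per distinct (writing task, partition value) pair."""
    return len({(task_id, value) for task_id, value in rows})


num_tasks = 200
dates = [f"2024-01-{d:02d}" for d in range(1, 31)]

# Without repartitioning: assume every write task holds rows for every date
mixed = [(task, date) for task in range(num_tasks) for date in dates]
# After df.repartition("date"): all rows for a date sit in one task
# (task chosen by position here; Spark would use the hash of the date)
grouped = [(i % num_tasks, date) for i, date in enumerate(dates)]
# partitionBy("user_id") with 100,000 hypothetical users: at least one file per user
by_user = [(i % num_tasks, f"user_{i}") for i in range(100_000)]

print(output_file_count(mixed))    # 6000
print(output_file_count(grouped))  # 30
print(output_file_count(by_user))  # 100000
```

This is why `repartition("date")` before `.write.partitionBy("date")` is a common way to cut file counts. It costs one shuffle and can create very large tasks for the biggest dates, and no amount of repartitioning rescues a partition column with 100,000 distinct values.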
Summary
Partitioning is fundamental to Spark performance. Understanding when to use hash vs range partitioning, how repartition differs from coalesce, and how to lay out data for partition pruning separates senior data engineers from intermediate ones. At Netflix and Uber scale, a query that prunes down to a single day's directory reads a tiny fraction of what a full-table scan reads, which can turn hours into seconds.