Data Skew Solutions
Difficulty: Senior Level | Companies: Databricks, Netflix, Uber, Apple, Airbnb
Understanding Data Skew
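Data skew occurs when a few partitions hold far more data than the rest, typically because a handful of key values (a very active user, a default ID, or nulls) account for a large share of the rows. A stage finishes only when its slowest task finishes, so these straggler tasks bottleneck the entire job while the other executors sit idle.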
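To see why one hot key produces one straggler, the plain-Python sketch below mimics hash partitioning: every row goes to partition hash(key) mod num_partitions, so all rows that share a key land together. It uses zlib.crc32 as a stand-in for Spark's Murmur3 hash and treats rows per partition as a proxy for task runtime. Both are simplifying assumptions, and the data is made up.

```python
import zlib


def hash_partition_sizes(keys, num_partitions):
    """Rows per partition when rows are hash-partitioned by key.

    zlib.crc32 stands in for Spark's Murmur3 hash (an assumption); what matters
    is that identical keys always map to the same partition.
    """
    sizes = [0] * num_partitions
    for key in keys:
        sizes[zlib.crc32(key.encode("utf-8")) % num_partitions] += 1
    return sizes


# 1,000 keys with one row each, plus one hot key with 10,000 rows
keys = [f"key_{i}" for i in range(1000)] + ["skewed_key"] * 10_000
sizes = hash_partition_sizes(keys, num_partitions=8)

# A stage ends when its slowest task ends, so the largest partition sets the pace
average = sum(sizes) / len(sizes)
print("rows per partition:", sizes)
print(f"slowest / average: {max(sizes) / average:.1f}x")
```

Whatever the hash function, the partition that receives "skewed_key" holds at least 10,000 of the 11,000 rows, so adding more partitions does not shorten that task.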
Detecting Skew
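```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder \
    .appName("DataSkew") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .getOrCreate()

# Detect skew by analyzing the row distribution across partitions
df = spark.read.parquet("hdfs://data/events")

# Rows per partition (for a freshly read table these are the input splits)
partition_stats = df \
    .withColumn("partition_id", F.spark_partition_id()) \
    .groupBy("partition_id") \
    .agg(
        F.count("*").alias("row_count"),
        F.sum(F.col("amount").isNotNull().cast("long")).alias("non_null_count")
    ) \
    .cache()

# Mean and sample standard deviation in a single pass
stats = partition_stats.agg(
    F.avg("row_count").alias("mean_rows"),
    F.stddev("row_count").alias("stddev_rows")
).first()
mean_rows = stats["mean_rows"]
stddev_rows = stats["stddev_rows"] or 0.0  # stddev can be null with one partition

# Identify skewed partitions: more than 3 standard deviations above the mean
skewed_partitions = partition_stats.filter(
    F.col("row_count") > mean_rows + 3 * stddev_rows
)
print(f"Skewed partitions: {skewed_partitions.count()}")
skewed_partitions.show()

# Input splits are usually balanced by file size; key skew shows up after a
# shuffle, so also check the heaviest keys ("user_id" is an assumed key column)
df.groupBy("user_id").count().orderBy(F.desc("count")).show(10)
```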
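The mean + 3 standard deviations rule above has a blind spot when there are few partitions: one huge partition inflates the standard deviation itself. With n values and the sample standard deviation (what F.stddev computes), no value can lie more than (n - 1)/sqrt(n) standard deviations above the mean, which is below 3 for n ≤ 10. The plain-Python sketch below compares that rule with a max-to-median ratio test. The factor of 5 mirrors the skewedPartitionFactor value used in the AQE section, and the row counts are made-up example data.

```python
import statistics


def three_sigma_outliers(counts, k=3.0):
    """Indices whose count exceeds mean + k * sample standard deviation."""
    mean = statistics.mean(counts)
    stdev = statistics.stdev(counts)  # sample stddev, like Spark's F.stddev
    return [i for i, c in enumerate(counts) if c > mean + k * stdev]


def median_ratio_outliers(counts, factor=5.0):
    """Indices whose count exceeds factor * median count."""
    median = statistics.median(counts)
    return [i for i, c in enumerate(counts) if c > factor * median]


# Hypothetical row counts: nine normal partitions and one 100x larger partition
counts = [100] * 9 + [10_000]

print(three_sigma_outliers(counts))   # -> []
print(median_ratio_outliers(counts))  # -> [9]
```

The median is barely moved by a single outlier, which is why a ratio to the median stays sensitive even with a handful of partitions.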
ℹ️ Interview Insight: Data skew is one of the most common performance issues in Spark. The key is to detect it early and choose the mitigation that matches the skew pattern, for example hot keys in a join versus hot keys in an aggregation.
Skew Detection with Spark UI
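```python
# Monitor task execution times in the Spark UI (Stages tab, Summary Metrics).
# Look for:
# 1. Tasks that run far longer than the median task
# 2. Shuffle Read/Write imbalance across tasks
# 3. GC time variance across tasks

# Event logging keeps these metrics for the Spark History Server. It is an
# application-level setting, so set it at launch (spark-submit --conf or the
# session builder); spark.conf.set() on a running session is rejected in Spark 3.x:
#   spark-submit --conf spark.eventLog.enabled=true \
#                --conf spark.eventLog.dir=hdfs://logs/spark-events app.py

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Create a skewed dataset for demonstration
schema = StructType([
    StructField("key", StringType(), True),
    StructField("value", IntegerType(), True)
])

# 1,000 normal keys with one row each, plus one key with 10,000 rows
skewed_data = [(f"key_{i}", i) for i in range(1000)]      # Normal keys
skewed_data += [("skewed_key", i) for i in range(10000)]  # Skewed key
df = spark.createDataFrame(skewed_data, schema)

result = df.groupBy("key").agg(F.sum("value"))
result.explain(mode="formatted")  # shows the plan only; nothing runs yet

# Run an action so the stage and its task metrics appear in the UI;
# the "noop" sink (Spark 3.0+) executes the query without writing output
result.write.format("noop").mode("overwrite").save()
# sum() is partially aggregated before the shuffle, so the hot key sends at most
# one row per input partition; a join on "key" exposes the skew more clearly
```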
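Why a groupBy(...).agg(F.sum(...)) on the demo data may not look skewed: Spark aggregates within each map-side partition before the shuffle, so the hot key contributes at most one row per input partition to the shuffle. The plain-Python sketch below counts the rows each key would ship with and without that map-side combine. Splitting the input into 4 round-robin partitions is an assumption for illustration.

```python
from collections import defaultdict


def shuffle_rows_per_key(partitions, combine):
    """Count rows each key sends through the shuffle.

    partitions: list of input partitions, each a list of (key, value) rows.
    combine=True mimics map-side partial aggregation (one row per key per partition).
    """
    sent = defaultdict(int)
    for part in partitions:
        if combine:
            for key in {k for k, _ in part}:
                sent[key] += 1
        else:
            for key, _ in part:
                sent[key] += 1
    return dict(sent)


# Same shape as the demo: 1,000 normal keys plus 10,000 rows for "skewed_key"
rows = [(f"key_{i}", i) for i in range(1000)] + [("skewed_key", i) for i in range(10_000)]
num_inputs = 4
partitions = [rows[i::num_inputs] for i in range(num_inputs)]  # round-robin split

print(shuffle_rows_per_key(partitions, combine=False)["skewed_key"])  # 10000
print(shuffle_rows_per_key(partitions, combine=True)["skewed_key"])   # 4
```

Aggregations such as collect_list still carry every value through the shuffle, and joins ship every row, so those operations show the skew that a plain sum hides.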
Solution 1: Salting
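Salting appends a random suffix (the salt) to a skewed key so that its rows spread across several shuffle partitions. Aggregating on the salted key produces partial results, and a second aggregation on the original key combines them.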
```python
from pyspark.sql import functions as F

def add_salt(df, key_col="key", salt_range=10):
    """Add a random salt in [0, salt_range) and a combined salted key."""
    # F.rand() is a built-in nondeterministic expression, cheaper than a Python UDF
    salted_df = df.withColumn("salt", (F.rand() * salt_range).cast("int"))
    # Combine original key with salt, e.g. "abc" -> "abc_7"
    salted_df = salted_df.withColumn(
        "salted_key",
        F.concat(F.col(key_col).cast("string"), F.lit("_"), F.col("salt").cast("string"))
    )
    return salted_df

# Apply salting to skewed dataset
df = spark.read.parquet("hdfs://data/skewed-events")
salted_df = add_salt(df, salt_range=10)

# Phase 1: group by salted key; each hot key is now split across up to 10 groups.
# Keeping "key" in the grouping avoids parsing the salt back out of the string,
# which would break for keys that themselves contain "_".
partial = salted_df \
    .groupBy("salted_key", "key") \
    .agg(F.sum("value").alias("partial_sum"))

# Phase 2: drop the salt and combine the partial sums per original key
result = partial \
    .groupBy("key") \
    .agg(F.sum("partial_sum").alias("total"))
result.show(20)
```
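The salting pattern above can be checked in plain Python: give each row a random salt, aggregate per (key, salt), then aggregate per key. This hypothetical sketch (not Spark code) confirms the totals match a direct aggregation and that the hot key's rows are spread over salt_range groups. A seeded random.Random stands in for F.rand() so the run is repeatable.

```python
import random
from collections import defaultdict


def salted_sum(rows, salt_range, seed=0):
    """Two-phase sum: phase 1 groups by (key, salt), phase 2 by key."""
    rng = random.Random(seed)
    partial = defaultdict(int)                 # (key, salt) -> partial sum
    for key, value in rows:
        partial[(key, rng.randrange(salt_range))] += value
    total = defaultdict(int)                   # key -> final sum
    for (key, _salt), psum in partial.items():
        total[key] += psum
    return dict(total), partial


rows = [(f"key_{i}", 1) for i in range(100)] + [("hot", 1)] * 10_000
totals, partial = salted_sum(rows, salt_range=10)

# Reference result without salting
direct = defaultdict(int)
for key, value in rows:
    direct[key] += value

hot_groups = {s: v for (k, s), v in partial.items() if k == "hot"}
print(totals == dict(direct))                  # True
print(len(hot_groups), max(hot_groups.values()))
```

Grouping on the (key, salt) tuple plays the same role as keeping the original key column in the Spark version: the salt never has to be parsed out of a string.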
Advanced Salting for Joins
```python
from pyspark.sql import functions as F

# Salt both sides of a join when one side is skewed
left = spark.read.parquet("hdfs://data/skewed-orders")
right = spark.read.parquet("hdfs://data/customers")

# Identify skewed keys in left table
skewed_keys = left \
    .groupBy("customer_id") \
    .agg(F.count("*").alias("count")) \
    .filter(F.col("count") > 10000) \
    .select("customer_id") \
    .collect()
skewed_keys = [row[0] for row in skewed_keys]

salt_range = 10

# Left side: random salt for skewed keys only; other keys keep their plain value
left_salted = left \
    .withColumn("salt", (F.rand() * salt_range).cast("int")) \
    .withColumn(
        "join_key",
        F.when(
            F.col("customer_id").isin(skewed_keys),
            F.concat(F.col("customer_id").cast("string"), F.lit("_"),
                     F.col("salt").cast("string"))
        ).otherwise(F.col("customer_id").cast("string"))
    ) \
    .drop("salt")

# Right side: replicate each skewed-key row once per salt value 0..salt_range-1
right_expanded = right \
    .withColumn(
        "salt",
        F.explode(
            F.when(
                F.col("customer_id").isin(skewed_keys),
                F.array([F.lit(i) for i in range(salt_range)])
            ).otherwise(F.array(F.lit(None).cast("int")))
        )
    ) \
    .withColumn(
        "join_key",
        F.when(
            F.col("salt").isNotNull(),
            F.concat(F.col("customer_id").cast("string"), F.lit("_"),
                     F.col("salt").cast("string"))
        ).otherwise(F.col("customer_id").cast("string"))
    ) \
    .drop("salt", "customer_id")  # avoid duplicate columns after the join

# Hot keys are now spread over salt_range partitions
result = left_salted.join(right_expanded, "join_key").drop("join_key")
```
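A plain-Python check of the join-salting scheme above, with hypothetical data and names: left rows get a random salt only for hot keys, matching right rows are replicated once per salt value, and the join runs on the composite (key, salt). The sketch confirms the salted join returns exactly the same pairs as a plain join and counts how many extra right-side rows the replication creates.

```python
import random
from collections import Counter, defaultdict


def salted_join(left, right, skewed_keys, salt_range, seed=0):
    """Inner join on key, salting only the keys in skewed_keys."""
    rng = random.Random(seed)

    # Left (skewed) side: attach a random salt to hot keys only
    left_salted = [
        ((k, rng.randrange(salt_range)) if k in skewed_keys else (k, None), lv)
        for k, lv in left
    ]

    # Right side: replicate hot-key rows once per salt value
    right_expanded = []
    for k, rv in right:
        salts = range(salt_range) if k in skewed_keys else [None]
        right_expanded.extend(((k, s), rv) for s in salts)

    # Hash join on the composite (key, salt) join key
    index = defaultdict(list)
    for jk, rv in right_expanded:
        index[jk].append(rv)
    joined = [(jk[0], lv, rv) for jk, lv in left_salted for rv in index.get(jk, [])]
    return joined, len(right_expanded) - len(right)


left = [("c1", i) for i in range(5_000)] + [("c2", 1), ("c3", 2)]
right = [("c1", "alice"), ("c2", "bob"), ("c3", "carol")]

joined, extra_rows = salted_join(left, right, skewed_keys={"c1"}, salt_range=10)

plain = [(k, lv, rv) for k, lv in left for k2, rv in right if k == k2]
print(Counter(joined) == Counter(plain))  # True
print(extra_rows)                         # 9
```

The replication cost grows with salt_range times the number of right-side rows that match skewed keys, which is why only the hot keys are salted.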
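⚠️ Warning: Salting has a cost. A salted join replicates every row on the other side that matches a skewed key salt_range times, and a salted aggregation adds a second shuffle stage. Use it only when skew is severe (>10x difference) and AQE or a broadcast join does not resolve it.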
Solution 2: Adaptive Query Execution (AQE)
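Adaptive Query Execution (AQE), introduced in Spark 3.0 and enabled by default since Spark 3.2, re-optimizes a query at runtime using shuffle statistics. Its skew-join optimization splits oversized partitions of shuffle joins such as sort-merge joins; it does not split skewed groupBy aggregations, which still call for salting or two-phase aggregation.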
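```python
from pyspark.sql import SparkSession

# Enable AQE skew-join handling (the two skewJoin values shown are the defaults)
spark = SparkSession.builder \
    .appName("AQESkewHandling") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5") \
    .getOrCreate()

# A shuffle partition counts as skewed when it is larger than
# skewedPartitionFactor x the median partition size AND larger than the threshold.
# Skew handling applies to joins, so the example joins two tables
# (both are assumed to have a "customer_id" column).
orders = spark.read.parquet("hdfs://data/skewed-dataset")
customers = spark.read.parquet("hdfs://data/customers")
result = orders.join(customers, "customer_id")

# AQE will:
# 1. Measure shuffle partition sizes once the map stage finishes
# 2. Split skewed partitions into smaller sub-partitions, each joined with a
#    copy of the matching partition from the other side
# 3. Re-optimize the join strategy from runtime statistics
#    (e.g. switch to a broadcast join if one side turns out to be small)
result.explain(mode="formatted")
# Before execution the root is "AdaptiveSparkPlan isFinalPlan=false"; after an
# action, the final plan in the SQL tab shows "SortMergeJoin(skew=true)" where
# a skewed partition was split
```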
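The detection rule described in the comments above can be written down directly: a partition is skewed when it is larger than skewedPartitionFactor times the median partition size and also larger than skewedPartitionThresholdInBytes. The plain-Python sketch below applies that rule with the configured values, then estimates how many roughly 64 MB pieces a skewed partition would become. The splitting step is a simplification: real AQE picks split points along map-output boundaries, guided by spark.sql.adaptive.advisoryPartitionSizeInBytes (64 MB by default). The partition sizes are made-up.

```python
import math
import statistics

MB = 1024 * 1024


def find_skewed(sizes, factor=5, threshold=256 * MB):
    """Indices of partitions that the factor-and-threshold rule treats as skewed."""
    median = statistics.median(sizes)
    return [i for i, s in enumerate(sizes) if s > factor * median and s > threshold]


def split_count(size, target=64 * MB):
    """Number of roughly target-sized sub-partitions (simplified split)."""
    return max(1, math.ceil(size / target))


# Hypothetical shuffle partition sizes: 199 partitions of 50 MB and one of 2 GB
sizes = [50 * MB] * 199 + [2048 * MB]
for i in find_skewed(sizes):
    print(f"partition {i}: {sizes[i] // MB} MB -> {split_count(sizes[i])} tasks")
# partition 199: 2048 MB -> 32 tasks
```

Both conditions matter: a 200 MB partition among 10 MB partitions exceeds 5x the median but stays under the 256 MB threshold, so it is left alone.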
Solution 3: Broadcast Join
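```python
from pyspark.sql.functions import broadcast

# If the skewed side is the larger table, broadcast the other side
large_skewed = spark.read.parquet("hdfs://data/skewed-fact")    # Skewed
small_dimension = spark.read.parquet("hdfs://data/dimensions")  # Small

# Each executor receives a full copy of small_dimension, so large_skewed is
# joined where it sits with no shuffle, and hot keys never converge on one task.
# The broadcast table must fit in driver and executor memory; for a left outer
# join only the right side can be broadcast.
result = large_skewed.join(broadcast(small_dimension), "key")

# Check that broadcast was used
result.explain(mode="formatted")
# Look for "BroadcastHashJoin"
```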
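Why broadcasting sidesteps skew, as a plain-Python sketch: the small table becomes an in-memory hash map available to every task, and each partition of the large table is joined where it already sits. A task's work therefore depends on its input split size, not on how many rows share a key. The data and the 4-way split are made-up, and dict(small_rows) assumes the small side has unique keys, as a dimension table usually does.

```python
def broadcast_hash_join(large_partitions, small_rows):
    """Inner-join each large-side partition against a broadcast dict of the small side."""
    build = dict(small_rows)                     # small side: key -> attribute
    return [
        [(key, value, build[key]) for key, value in part if key in build]
        for part in large_partitions             # each partition joined in place
    ]


# Large side: 90% of rows carry the hot key, split into 4 equal input partitions
large_rows = [("hot" if i % 10 else "cold", i) for i in range(4_000)]
large_partitions = [large_rows[i::4] for i in range(4)]   # 1,000 rows each
small_rows = [("hot", "H"), ("cold", "C")]

output = broadcast_hash_join(large_partitions, small_rows)
print([len(p) for p in output])  # [1000, 1000, 1000, 1000]
```

Every task emits the same number of rows even though one key dominates, because no task is ever asked to gather all rows of that key.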
Solution 4: Repartitioning
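```python
df = spark.read.parquet("hdfs://data/events")

# Repartition by a high-cardinality, evenly distributed column instead of the
# skewed "user_id". A low-cardinality column such as "event_type" would put each
# type into one partition and leave most of the 200 empty ("event_id" is assumed).
df_repartitioned = df.repartition(200, "event_id")

# Or repartition without a key: round-robin yields evenly sized partitions,
# but rows sharing a key are no longer co-located
df_round_robin = df.repartition(200)

# For joins, repartitioning both sides on the join key co-partitions them, but
# hash partitioning still sends every row of a hot key to the same partition;
# pair it with salting or AQE when the join key itself is skewed
left = df.repartition(200, "join_key")
right = spark.table("other_table").repartition(200, "join_key")
result = left.join(right, "join_key")
```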
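A plain-Python comparison of the partitioning choices above, using zlib.crc32 as a stand-in for Spark's hash function (an assumption) and made-up column values. Hashing a low-cardinality column fills only a few of 200 partitions, hashing a skewed key still keeps the hot key together, and round-robin gives even sizes. Spark's round-robin starts each input partition at a random offset, which this sketch ignores.

```python
import zlib


def hash_partition(values, n):
    """Rows per partition when hash-partitioning on a column."""
    sizes = [0] * n
    for v in values:
        sizes[zlib.crc32(str(v).encode("utf-8")) % n] += 1
    return sizes


def round_robin(count, n):
    """Rows per partition when rows are dealt out in turn."""
    sizes = [0] * n
    for i in range(count):
        sizes[i % n] += 1
    return sizes


N = 200
event_types = ["click", "view", "purchase", "signup", "logout"] * 2_000  # 5 distinct values
user_ids = ["u0"] * 5_000 + [f"u{i}" for i in range(1, 5_001)]          # one hot user

by_type = hash_partition(event_types, N)
by_user = hash_partition(user_ids, N)
rr = round_robin(10_000, N)

print(sum(1 for s in by_type if s > 0))  # at most 5 non-empty partitions
print(max(by_user))                      # at least 5000: the hot user stays together
print(min(rr), max(rr))                  # 50 50
```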
Solution 5: Two-Phase Aggregation
```python
from pyspark.sql import functions as F

# Split aggregation into two phases to handle skew
df = spark.read.parquet("hdfs://data/skewed-events")

# Phase 1: partial aggregation per (key, random salt in [0, salt_range))
salt_range = 10
partial = df \
    .withColumn("salt", (F.rand() * salt_range).cast("int")) \
    .groupBy("key", "salt") \
    .agg(F.sum("value").alias("partial_sum"))

# Phase 2: final aggregation per original key; grouping on the separate
# "key" column avoids splitting a salted string, which breaks on keys with "_"
result = partial \
    .groupBy("key") \
    .agg(F.sum("partial_sum").alias("total"))
result.show(20)
```
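ℹ️ Pro Tip: Two-phase aggregation works directly for aggregates whose partial results can be merged, such as sum, count, min, and max; an average should be carried as a (sum, count) pair and divided only in the final phase. Distinct counts are different: partial distinct counts only add up correctly if every occurrence of a value lands in the same salt bucket, so derive the salt from the value (for example, a hash of it) instead of choosing it at random.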
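Two cases beyond sum and count, sketched in plain Python with made-up data. For an average, phase 1 keeps a (sum, count) pair per (key, bucket) and phase 2 divides only after merging. For a distinct count, deriving the bucket from the value itself (here zlib.crc32, an assumption) puts each distinct value in exactly one bucket, so per-bucket distinct counts add up to the exact total; a row-based bucket, used here as a deterministic stand-in for a random salt, over-counts.

```python
import zlib
from collections import defaultdict


def two_phase_avg(rows, num_buckets):
    """Average per key: phase 1 keeps (sum, count) per (key, bucket), phase 2 merges."""
    partial = defaultdict(lambda: [0, 0])
    for i, (key, value) in enumerate(rows):
        acc = partial[(key, i % num_buckets)]   # i % num_buckets stands in for a salt
        acc[0] += value
        acc[1] += 1
    merged = defaultdict(lambda: [0, 0])
    for (key, _bucket), (s, c) in partial.items():
        merged[key][0] += s
        merged[key][1] += c
    return {key: s / c for key, (s, c) in merged.items()}


def two_phase_count_distinct(rows, num_buckets, salt_by_value=True):
    """Distinct values per key, summed over salt buckets."""
    partial = defaultdict(set)
    for i, (key, value) in enumerate(rows):
        if salt_by_value:
            bucket = zlib.crc32(str(value).encode("utf-8")) % num_buckets
        else:
            bucket = i % num_buckets            # stands in for a random salt
        partial[(key, bucket)].add(value)
    result = defaultdict(int)
    for (key, _bucket), values in partial.items():
        result[key] += len(values)
    return dict(result)


rows = [("hot", v % 300) for v in range(10_000)] + [("cold", 1), ("cold", 1), ("cold", 2)]

print(round(two_phase_avg(rows, num_buckets=7)["cold"], 4))                # 1.3333
print(two_phase_count_distinct(rows, num_buckets=8))                        # {'hot': 300, 'cold': 2}
print(two_phase_count_distinct(rows, num_buckets=8, salt_by_value=False))   # {'hot': 600, 'cold': 3}
```

In Spark terms, the value-based bucket corresponds to something like pmod(hash(value), N) as the salt column, with countDistinct in phase 1 and sum in phase 2.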
Monitoring Skew in Production
```python
from pyspark.sql import SparkSession, functions as F

# Monitor skew metrics continuously (event logging must be set at session start)
spark = SparkSession.builder \
    .appName("SkewMonitoring") \
    .config("spark.eventLog.enabled", "true") \
    .config("spark.eventLog.dir", "hdfs://logs/spark-events") \
    .getOrCreate()

def log_skew_metrics(df, operation_name):
    """Print max/min/avg rows per partition and the max-to-average skew ratio.

    Runs an extra Spark job over df; partitions with no rows produce no group,
    so they are not included in the statistics.
    """
    stats = df \
        .withColumn("partition_id", F.spark_partition_id()) \
        .groupBy("partition_id") \
        .agg(F.count("*").alias("row_count")) \
        .collect()
    counts = [row["row_count"] for row in stats]
    if counts:
        max_count = max(counts)
        min_count = min(counts)
        avg_count = sum(counts) / len(counts)
        skew_ratio = max_count / avg_count if avg_count > 0 else 0
        print(f"{operation_name}:")
        print(f"  Max: {max_count}, Min: {min_count}, Avg: {avg_count:.0f}")
        print(f"  Skew ratio: {skew_ratio:.2f}")
        if skew_ratio > 3:
            print("  WARNING: Significant skew detected!")

# Use in pipeline
df = spark.read.parquet("hdfs://data/events")
log_skew_metrics(df, "Input")
result = df.groupBy("key").agg(F.sum("value"))
# After aggregation there is one row per key, so this reflects output balance
# rather than the shuffle skew itself (check shuffle read sizes in the UI for that)
log_skew_metrics(result, "After aggregation")
```
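The statistics in log_skew_metrics only see partitions that contain at least one row, because empty partitions produce no group. When the total partition count is known (for example from df.rdd.getNumPartitions()), padding the counts with zeros gives a truer picture of wasted parallelism. A plain-Python sketch of that calculation: skew_summary is a hypothetical helper, the 3x warning threshold follows the code above, and the counts are made-up.

```python
def skew_summary(nonempty_counts, total_partitions, warn_ratio=3.0):
    """Max/min/avg rows per partition, treating missing (empty) partitions as 0."""
    missing = total_partitions - len(nonempty_counts)
    if missing < 0:
        raise ValueError("more non-empty partitions than total partitions")
    counts = list(nonempty_counts) + [0] * missing
    if not counts:
        return {"max": 0, "min": 0, "avg": 0.0, "ratio": 0.0, "skewed": False}
    avg = sum(counts) / len(counts)
    ratio = max(counts) / avg if avg > 0 else 0.0
    return {"max": max(counts), "min": min(counts), "avg": avg,
            "ratio": ratio, "skewed": ratio > warn_ratio}


# Hypothetical: the same 4 non-empty partitions, first alone, then out of 16 total
nonempty = [1_000, 1_000, 1_000, 1_000]
print(skew_summary(nonempty, total_partitions=4)["ratio"])   # 1.0
print(skew_summary(nonempty, total_partitions=16)["ratio"])  # 4.0
```

Identical per-partition counts look perfectly balanced until the 12 empty partitions are included, at which point the ratio crosses the warning threshold.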
ℹ️ Key Takeaway: Data skew requires a multi-faceted approach: detect it early, let AQE split skewed join partitions automatically, broadcast the small side of a join when it fits in memory, and apply salting for severe cases that remain. Always monitor skew metrics in production.
Follow-Up Questions
- How does Spark's AQE skew detection work under the hood?
- Explain the trade-offs between salting and broadcast joins for skew handling.
- How would you handle skew in a streaming application?
- Describe strategies for preventing skew in data ingestion pipelines.
- How does skew affect dynamic allocation and resource utilization?