
Topic: Join Strategies and Optimization


PySpark Advanced Interview Series

Module 05: Join Strategies - The Art of Data Combination

Companies: Meta, Apple | Difficulty: Hard

Interview Question

"At Meta, we join petabytes of social graph data with user profile dimensions daily. Walk us through every join strategy Spark supports, when each is optimal, and how you would handle a join between a 500GB fact table and a 2GB dimension table with skewed keys." - Meta Data Engineer Interview

"At Apple, our analytics platform processes complex joins across multiple data sources. Explain the difference between broadcast hash join and sort-merge join in terms of memory usage, network I/O, and fault tolerance. How does Spark's adaptive query execution optimize joins at runtime?" - Apple Senior Data Engineer Interview


Join Types in Spark

Spark supports the standard SQL join types, plus left semi and left anti joins for existence checks:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, broadcast

spark = SparkSession.builder.appName("JoinInterview").getOrCreate()

# Create sample DataFrames
employees = spark.createDataFrame([
    (1, "Alice", "Engineering"),
    (2, "Bob", "Marketing"),
    (3, "Charlie", "Engineering"),
    (4, "Diana", "Sales")
], ["emp_id", "name", "department"])

departments = spark.createDataFrame([
    ("Engineering", "San Francisco", 50),
    ("Marketing", "New York", 30),
    ("Sales", "Chicago", 20),
    ("HR", "Boston", 10)
], ["dept_name", "location", "budget"])

# Inner join (default)
inner = employees.join(departments, employees.department == departments.dept_name, "inner")

# Left outer join
left = employees.join(departments, employees.department == departments.dept_name, "left")

# Right outer join
right = employees.join(departments, employees.department == departments.dept_name, "right")

# Full outer join
full = employees.join(departments, employees.department == departments.dept_name, "full")

# Left semi join (like EXISTS)
semi = employees.join(departments, employees.department == departments.dept_name, "left_semi")

# Left anti join (like NOT EXISTS)
anti = employees.join(departments, employees.department == departments.dept_name, "left_anti")

# Cross join (Cartesian product)
cross = employees.crossJoin(departments)
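The difference between an inner join and a left semi join shows up when the right side has several matches for one key. The plain-Python sketch below (no Spark required; the rows are hypothetical) models only what each join type returns, not how Spark executes it.

```python
# Plain-Python model of join semantics on lists of (key, value) tuples.
# Illustrates the result of each join type; it is not Spark's execution strategy.

left = [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
right = [(1, "x"), (1, "y"), (2, "z")]  # key 1 appears twice on the right


def inner_join(l, r):
    # One output row per matching (left, right) pair
    return [(lk, lv, rv) for lk, lv in l for rk, rv in r if lk == rk]


def left_semi_join(l, r):
    # Each left row at most once if any right row matches; no right-side columns
    right_keys = {rk for rk, _ in r}
    return [(lk, lv) for lk, lv in l if lk in right_keys]


def left_anti_join(l, r):
    # Left rows that have no match on the right
    right_keys = {rk for rk, _ in r}
    return [(lk, lv) for lk, lv in l if lk not in right_keys]


print(inner_join(left, right))      # Alice appears twice
print(left_semi_join(left, right))  # Alice appears once
print(left_anti_join(left, right))
```

This is why a semi join is the right tool for "does a match exist?" questions: an inner join would duplicate left rows and add columns you then have to drop.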

Join Strategies

1. Broadcast Hash Join (BHJ)

The most efficient join strategy when one table is small enough to fit in executor memory.

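# Automatic broadcast (if the estimated size of departments < spark.sql.autoBroadcastJoinThreshold)
result = employees.join(departments, employees.department == departments.dept_name)

# Explicit broadcast hint (applies regardless of the threshold)
result = employees.join(broadcast(departments), employees.department == departments.dept_name)

# SQL broadcast hint: the hint comment must come right after SELECT
employees.createOrReplaceTempView("employees")
departments.createOrReplaceTempView("departments")
result = spark.sql("""
    SELECT /*+ BROADCAST(d) */ e.*, d.location, d.budget
    FROM employees e
    JOIN departments d ON e.department = d.dept_name
""")

# Configure threshold (default: 10MB; -1 disables automatic broadcasting)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(50 * 1024 * 1024))  # 50MB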

How it works:

  1. The driver collects the small (build) table
  2. The driver builds a hash table from it, keyed on the join column
  3. The hash table is broadcast once to every executor
  4. Each task streams its partition of the large table and probes the hash table
  5. The large table is never shuffled
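The build/probe idea behind a hash join fits in a few lines of plain Python. This is a single-process sketch with hypothetical dict rows: in Spark the built table is broadcast and each task probes it with its own partition of the large table.

```python
from collections import defaultdict


def hash_join(build_rows, stream_rows, build_key, stream_key):
    """Inner hash join: index the small (build) side, then probe it with each streamed row."""
    # Build phase: key -> list of build-side rows (a key may map to several rows)
    table = defaultdict(list)
    for row in build_rows:
        table[row[build_key]].append(row)
    # Probe phase: a single pass over the large side, O(1) average lookup per row
    for row in stream_rows:
        for match in table.get(row[stream_key], []):
            yield {**row, **match}


departments = [
    {"dept_name": "Engineering", "location": "San Francisco"},
    {"dept_name": "Sales", "location": "Chicago"},
]
employees = [
    {"name": "Alice", "department": "Engineering"},
    {"name": "Diana", "department": "Sales"},
    {"name": "Bob", "department": "Marketing"},  # no matching department row
]

result = list(hash_join(departments, employees, "dept_name", "department"))
print(result)
```

Only the build side has to fit in memory; the streamed side is read once and never materialized, which is what makes the strategy cheap for small-large joins.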

When to use:

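  • Small table fits in driver and executor memory (estimated size < autoBroadcastJoinThreshold, or forced with a broadcast hint)
  • Large table is too big to shuffle efficiently
  • The join type allows the small table to be the build side: inner joins; left outer, left semi, and left anti joins with the right table broadcast; right outer joins with the left table broadcast. Full outer joins cannot use a broadcast hash join.

Meta Interview Insight

At Meta, the autoBroadcastJoinThreshold is often set to 100MB-500MB for data warehouse workloads. However, broadcasting a 500MB table to 1000 executors means about 500GB of memory consumed cluster-wide, on top of the copy the driver must collect, and the in-memory hash table is often larger than the compressed on-disk size. Always consider the total cluster memory impact.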
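The cluster-wide cost in the insight above generalizes to a one-line estimate. A sketch, assuming each executor keeps one full copy of the table; `expansion` is a hypothetical knob for in-memory versus on-disk size, not a Spark setting.

```python
def broadcast_footprint_gb(table_mb: float, executors: int, expansion: float = 1.0) -> float:
    """Cluster-wide memory for one broadcast table: every executor holds its own copy.

    `expansion` is a hypothetical factor for decoded in-memory size vs. on-disk size
    (compressed columnar files often grow when loaded); 1.0 gives the simple estimate.
    """
    return table_mb * expansion * executors / 1000  # MB -> GB (decimal units)


print(broadcast_footprint_gb(500, 1000))       # the 500MB x 1000 executors case
print(broadcast_footprint_gb(500, 1000, 3.0))  # same table if it triples in memory
```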

2. Sort-Merge Join (SMJ)

The default strategy for large-large joins. Both tables are shuffled, sorted, and merged.

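# Force sort-merge join with a DataFrame hint
# (the third join() argument is the join type, not a strategy name)
result = employees.join(
    departments.hint("shuffle_merge"),
    employees.department == departments.dept_name
)

# SQL hint (assumes employees/departments are registered as temp views)
result = spark.sql("""
    SELECT /*+ SHUFFLE_MERGE(e, d) */ *
    FROM employees e JOIN departments d ON e.department = d.dept_name
""")

# Pre-partition both sides on the join key with the same partition count;
# Spark can then reuse this partitioning instead of adding another exchange
# (worthwhile mainly when the partitioned DataFrames are cached and reused)
employees_partitioned = employees.repartition(200, "department")
departments_partitioned = departments.repartition(200, "dept_name")

result = employees_partitioned.join(
    departments_partitioned,
    employees_partitioned.department == departments_partitioned.dept_name
)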

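How it works:

  1. Both tables are shuffled by the join key
  2. Each partition is sorted by the join key
  3. Within each partition, the two sorted sides are merged by advancing two cursors in key order and emitting rows whose keys match
  4. The algorithm itself does not handle skew: all rows of a hot key still land in one partition, so skew must be mitigated with AQE skew-join splitting or manual salting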
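The sort and merge steps can be sketched for a single partition in plain Python: sort both sides on the key, then walk them with two cursors. This illustrates the merge logic only, not Spark's implementation; the rows are hypothetical dicts.

```python
def sort_merge_join(left, right, key):
    """Inner sort-merge join on lists of dicts (single-partition sketch)."""
    left = sorted(left, key=lambda r: r[key])
    right = sorted(right, key=lambda r: r[key])
    i = j = 0
    out = []
    while i < len(left) and j < len(right):
        lk, rk = left[i][key], right[j][key]
        if lk < rk:
            i += 1  # left key is behind: advance the left cursor
        elif lk > rk:
            j += 1  # right key is behind: advance the right cursor
        else:
            # Find the run of equal keys on both sides, then emit their cross product
            i_end = i
            while i_end < len(left) and left[i_end][key] == lk:
                i_end += 1
            j_end = j
            while j_end < len(right) and right[j_end][key] == rk:
                j_end += 1
            for l_row in left[i:i_end]:
                for r_row in right[j:j_end]:
                    out.append({**l_row, **r_row})
            i, j = i_end, j_end
    return out


orders = [
    {"customer_id": 2, "order_id": 10},
    {"customer_id": 1, "order_id": 11},
    {"customer_id": 2, "order_id": 12},
]
customers = [
    {"customer_id": 1, "name": "Ann"},
    {"customer_id": 2, "name": "Ben"},
    {"customer_id": 3, "name": "Cy"},
]
result = sort_merge_join(orders, customers, "customer_id")
print(result)
```

Once sorted, each side is read sequentially, which is why Spark can spill sorted runs to disk and keep memory usage low for large-large joins.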

When to use:

  • Both tables are large (don't fit in memory)
  • Join key has relatively uniform distribution
  • Memory is constrained

3. Shuffle Hash Join

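Used when one table is moderately sized but too large for broadcast. Because spark.sql.join.preferSortMergeJoin defaults to true, Spark seldom picks this strategy on its own; it is usually requested with a hint.

# Force shuffle hash join (hint the side intended as the hash-table build side)
result = employees.join(
    departments.hint("shuffle_hash"),
    employees.department == departments.dept_name
)

# SQL hint (assumes employees/departments are registered as temp views)
result = spark.sql("""
    SELECT /*+ SHUFFLE_HASH(d) */ *
    FROM employees e JOIN departments d ON e.department = d.dept_name
""")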

How it works:

  1. Both tables are shuffled by join key
  2. The smaller table's partition is used to build a hash table
  3. The larger table is streamed through the hash table
  4. No sort required (unlike SMJ)
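A single-process sketch of the two phases: a hash "exchange" that routes rows by key, then an independent build/probe per partition. The partition count and row layout are hypothetical, and Python's `hash()` stands in for Spark's Murmur3-based partitioning hash.

```python
from collections import defaultdict


def shuffle(rows, key, num_partitions):
    """Route each row to partition hash(key) % num_partitions, like a hash exchange."""
    parts = [[] for _ in range(num_partitions)]
    for row in rows:
        parts[hash(row[key]) % num_partitions].append(row)
    return parts


def shuffle_hash_join(build, stream, key, num_partitions=4):
    build_parts = shuffle(build, key, num_partitions)
    stream_parts = shuffle(stream, key, num_partitions)
    out = []
    # Equal keys land in the same partition index on both sides, so each
    # partition pair is joined on its own with a small hash table (no sort)
    for b_part, s_part in zip(build_parts, stream_parts):
        table = defaultdict(list)
        for row in b_part:
            table[row[key]].append(row)
        for row in s_part:
            for match in table.get(row[key], []):
                out.append({**row, **match})
    return out


depts = [{"dept_id": d, "dept": f"D{d}"} for d in range(5)]
emps = [{"emp_id": e, "dept_id": e % 5} for e in range(20)]
result = shuffle_hash_join(depts, emps, "dept_id")
print(len(result))
```

Each task only holds the hash table for its own slice of the build side, which is why this strategy tolerates a build side too large to broadcast.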

When to use:

  • One table is moderately sized (larger than broadcast threshold but smaller than the other)
  • Memory is sufficient for hash table construction
  • Avoids sorting overhead of SMJ

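4. Broadcast Nested Loop Join

Used for non-equi joins or when no other strategy applies. Every row of the streamed side is compared with every row of the broadcast side, so the cost grows with the product of the two row counts.

# Non-equi (range) join: there is no equality key to hash or sort on, so Spark plans a
# broadcast nested loop join (for inner joins, a Cartesian product if neither side can be broadcast)
people = spark.createDataFrame([("Alice", 34), ("Bob", 19)], ["name", "age"])
age_bands = spark.createDataFrame(
    [("18-29", 18, 29), ("30-49", 30, 49)], ["band", "min_age", "max_age"]
)
result = people.join(
    broadcast(age_bands),
    (people.age >= age_bands.min_age) & (people.age <= age_bands.max_age),
    "inner"
)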
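The nested-loop idea in plain Python, with hypothetical age-band rows: any predicate works because every pair is tested, which is exactly what makes the strategy general and expensive.

```python
def nested_loop_join(stream, broadcast_side, condition):
    """Test every (streamed, broadcast) pair: O(len(stream) * len(broadcast_side)) comparisons."""
    return [
        {**s, **b}
        for s in stream
        for b in broadcast_side
        if condition(s, b)
    ]


people = [{"name": "Alice", "age": 34}, {"name": "Bob", "age": 19}]
age_bands = [
    {"band": "18-29", "min_age": 18, "max_age": 29},
    {"band": "30-49", "min_age": 30, "max_age": 49},
]
result = nested_loop_join(
    people, age_bands, lambda p, b: b["min_age"] <= p["age"] <= b["max_age"]
)
print([(r["name"], r["band"]) for r in result])
```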

Performance Comparison

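| Strategy | Shuffle Required | Memory Usage | Network I/O | Best For |
|---|---|---|---|---|
| Broadcast Hash Join | No | High (build side copied to every executor) | Small table × number of executors | Small-large joins |
| Sort-Merge Join | Yes (both sides) | Low (sorting can spill to disk) | High (shuffle) | Large-large joins |
| Shuffle Hash Join | Yes (both sides) | Medium (per-partition hash tables) | High (shuffle) | Medium-large joins |
| Broadcast Nested Loop Join | No | High (broadcast side) | Small table × number of executors | Non-equi joins |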

Real-World Scenario: Meta Social Graph Analytics

Problem Statement

Join a 500GB user_actions fact table with a 2GB user_profiles dimension table. The join key (user_id) has significant skew: 0.1% of users have 100x more actions than average.
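A quick check of what that skew means for data volume. Reading "100x more actions than average" as 100 times the overall mean (an assumption), the heavy users alone account for a tenth of the fact table:

```python
heavy_user_fraction = 0.001  # 0.1% of users
heavy_multiplier = 100       # assumed: each heavy user has 100x the overall mean action count
total_gb = 500               # size of user_actions

# total actions = num_users * mean; heavy users hold num_users * 0.001 * 100 * mean of them
heavy_share = heavy_user_fraction * heavy_multiplier
heavy_gb = total_gb * heavy_share
print(heavy_share, heavy_gb)
```

With hash partitioning, all of one user's rows go to a single shuffle partition, so that concentrated volume becomes a handful of oversized tasks.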

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, broadcast, rand, lit, when, explode, sequence, array
)

spark = SparkSession.builder \
    .appName("MetaJoinOptimization") \
    .config("spark.sql.shuffle.partitions", "500") \
    .config("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024)) \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .getOrCreate()

# Read data
user_actions = spark.read.parquet("s3a://meta-data/user-actions/")
user_profiles = spark.read.parquet("s3a://meta-data/user-profiles/")

# Step 1: Analyze skew
skew_analysis = user_actions \
    .groupBy("user_id") \
    .count() \
    .orderBy(col("count").desc())

skew_analysis.show(20)
# e.g., the top users have millions of actions

# Strategy 1: Broadcast join
# At 2GB, the profile table is far above the 100MB automatic threshold, but an
# explicit broadcast() hint ignores autoBroadcastJoinThreshold. The driver must
# collect the table and every executor must hold the resulting hash table, so
# size driver and executor memory accordingly (Spark refuses broadcasts over 8GB).
result_broadcast = user_actions.join(
    broadcast(user_profiles),
    "user_id",
    "inner"
)

# Strategy 2: Salting for skew handling
SALT_BUCKETS = 100

# Identify skewed users (the 10,000-action cutoff is workload-specific);
# cached because it is reused on both sides of the join
skewed_users = user_actions \
    .groupBy("user_id") \
    .count() \
    .filter(col("count") > 10000) \
    .select("user_id", lit(True).alias("is_skewed")) \
    .cache()

# Salt the fact side: skewed users get a random salt in [0, SALT_BUCKETS), others get 0
salted_actions = user_actions \
    .join(broadcast(skewed_users), "user_id", "left") \
    .withColumn(
        "salt",
        when(col("is_skewed"), (rand() * SALT_BUCKETS).cast("int")).otherwise(lit(0))
    )

# Replicate only the skewed users' profiles once per salt value; others keep salt 0
salted_profiles = user_profiles \
    .join(broadcast(skewed_users), "user_id", "left") \
    .withColumn(
        "salt",
        explode(
            when(col("is_skewed"), sequence(lit(0), lit(SALT_BUCKETS - 1)))
            .otherwise(array(lit(0)))
        )
    )

# Join on (user_id, salt): each hot user is now spread over SALT_BUCKETS join keys
result_salted = salted_actions.join(
    salted_profiles.select("user_id", "salt", "profile_name", "email"),
    ["user_id", "salt"],
    "inner"
).drop("salt", "is_skewed")

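# Strategy 3: AQE skew join (Spark 3.0+)
# AQE detects oversized shuffle partitions at runtime, splits them, and replicates
# the matching partition of the other side. Pin the threshold at the session's
# 100MB value so this join is planned as a sort-merge join, not a broadcast.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024))
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")

result_aqe = user_actions.join(
    user_profiles,
    "user_id",
    "inner"
)

# Before an action runs, explain() shows "AdaptiveSparkPlan isFinalPlan=false";
# the skew-split final plan is visible after execution (e.g., in the Spark UI SQL tab)
result_aqe.explain(True)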

spark.stop()
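The salting logic (random salt on skewed fact rows, dimension rows replicated per salt, join on the composite key) can be checked on toy data without Spark. This sketch uses the variant that replicates only the skewed keys' dimension rows; `SALT_BUCKETS`, the user ids, and the row counts are hypothetical.

```python
import random
from collections import Counter

SALT_BUCKETS = 4  # hypothetical small value for illustration

skewed_users = {"u_hot"}
actions = [("u_hot", i) for i in range(1000)] + [("u_cold", i) for i in range(10)]
profiles = {"u_hot": "Hot User", "u_cold": "Cold User"}

rng = random.Random(42)  # fixed seed so the sketch is reproducible

# Fact side: skewed keys get a random salt, all other keys get salt 0
salted_actions = [
    (user, rng.randrange(SALT_BUCKETS) if user in skewed_users else 0, action)
    for user, action in actions
]

# Dimension side: replicate skewed users once per salt value; others only at salt 0
salted_profiles = {}
for user, name in profiles.items():
    salts = range(SALT_BUCKETS) if user in skewed_users else [0]
    for salt in salts:
        salted_profiles[(user, salt)] = name

# Join on the composite (user, salt) key: every fact row still finds its profile
joined = [
    (user, action, salted_profiles[(user, salt)])
    for user, salt, action in salted_actions
    if (user, salt) in salted_profiles
]

# The hot user's rows are now spread over SALT_BUCKETS distinct join keys
rows_per_key = Counter((user, salt) for user, salt, _ in salted_actions)
print(len(joined))
print(sorted(c for (u, _), c in rows_per_key.items() if u == "u_hot"))
```

Replicating only the skewed keys keeps the dimension side from growing by a factor of `SALT_BUCKETS`; salting every key would multiply the 2GB profile table instead.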

Multi-Join Optimization

Join Ordering

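Without cost-based optimization, Spark does not reorder joins by size. With spark.sql.cbo.enabled and spark.sql.cbo.joinReorder.enabled set to true (both default to false) and table and column statistics collected with ANALYZE TABLE, it can reorder inner joins based on estimated cardinality:

# Without CBO, DataFrame joins run largely in the order written, so put the small,
# selective joins first; broadcast hints keep the dimensions from being shuffled
# (orders, customers, products, order_items are hypothetical DataFrames)
result = orders \
    .join(broadcast(customers), "customer_id") \
    .join(broadcast(products), "product_id") \
    .join(order_items, "order_id")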

# Or in SQL
result = spark.sql("""
    SELECT /*+ BROADCAST(c), BROADCAST(p) */
        o.order_id,
        c.customer_name,
        p.product_name,
        o.total_amount
    FROM orders o
    JOIN customers c ON o.customer_id = c.customer_id
    JOIN products p ON o.product_id = p.product_id
""")

Bucketed Joins

For repeated joins on the same key, bucketing eliminates shuffle:

# Bucket the data at write time
orders.write \
    .bucketBy(256, "customer_id") \
    .sortBy("customer_id") \
    .saveAsTable("bucketed_orders")

customers.write \
    .bucketBy(256, "customer_id") \
    .sortBy("customer_id") \
    .saveAsTable("bucketed_customers")

# Join without shuffle (buckets are pre-aligned)
orders_bucketed = spark.table("bucketed_orders")
customers_bucketed = spark.table("bucketed_customers")

result = orders_bucketed.join(customers_bucketed, "customer_id")
# No shuffle: both sides are already bucketed by customer_id with the same bucket count

Apple Pro Tip

At Apple, bucketed tables are used extensively for star schema joins. When a fact table and its dimension tables are repeatedly joined on the same key, bucket all of them on that key with the same number of buckets; this eliminates the shuffle for all subsequent joins on that key. A table can only be bucketed on one set of columns, so joins to dimensions on other keys still shuffle unless the dimension is broadcast.
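Why aligned buckets remove the shuffle can be seen in a small plain-Python model: the same key always hashes to the same bucket id in both tables, so bucket i only ever joins with bucket i. Table contents and `NUM_BUCKETS` are hypothetical, and Python's `hash()` stands in for Spark's Murmur3-based bucketing hash.

```python
NUM_BUCKETS = 8  # hypothetical; both tables must use the same bucket count


def bucket_id(key: int) -> int:
    # Stand-in for Spark's bucketing hash: deterministic key -> bucket mapping
    return hash(key) % NUM_BUCKETS


def write_bucketed(rows, key):
    """Write time: place each row in the bucket chosen by its key."""
    buckets = [[] for _ in range(NUM_BUCKETS)]
    for row in rows:
        buckets[bucket_id(row[key])].append(row)
    return buckets


orders = [{"order_id": o, "customer_id": o % 5} for o in range(20)]
customers = [{"customer_id": c, "name": f"C{c}"} for c in range(5)]

order_buckets = write_bucketed(orders, "customer_id")
customer_buckets = write_bucketed(customers, "customer_id")

# Read time: bucket i of orders only needs bucket i of customers (no data movement)
joined = []
for ob, cb in zip(order_buckets, customer_buckets):
    names = {c["customer_id"]: c["name"] for c in cb}
    joined += [{**o, "name": names[o["customer_id"]]} for o in ob if o["customer_id"] in names]
print(len(joined))
```

If the two tables used different bucket counts, the same key could land in bucket 3 on one side and bucket 7 on the other, and Spark would have to shuffle at least one side to realign them.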


Handling Data Skew in Joins

Detecting Skew Before the Join

from pyspark.sql import functions as F

# Detect skew before join: per-key row counts
key_distribution = user_actions \
    .groupBy("user_id") \
    .count() \
    .withColumnRenamed("count", "action_count")

# Statistical analysis
key_distribution.select(
    F.mean("action_count").alias("mean"),
    F.stddev("action_count").alias("stddev"),
    F.expr("percentile_approx(action_count, 0.99)").alias("p99"),
    F.expr("percentile_approx(action_count, 0.999)").alias("p999"),
    F.max("action_count").alias("max")
).show()

# If skew detected, use salting or AQE
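The statistics above describe the distribution; acting on them needs a rule. The sketch below borrows the shape of AQE's skewed-partition test (larger than a factor times the median and above an absolute floor) and applies it to any key-to-size mapping. The counts are made up and both thresholds are assumptions to tune per workload.

```python
from statistics import median


def find_skewed(sizes: dict, factor: float = 5.0, min_size: float = 0) -> list:
    """Flag entries larger than `factor` x median AND larger than `min_size`.

    Same shape as AQE's skewedPartitionFactor / skewedPartitionThresholdInBytes
    test, applied here to per-key counts instead of partition byte sizes.
    """
    med = median(sizes.values())
    return sorted(k for k, v in sizes.items() if v > factor * med and v > min_size)


action_counts = {"u1": 120, "u2": 95, "u3": 110, "u4": 105, "u5": 50_000, "u6": 900}
print(find_skewed(action_counts, factor=5, min_size=10_000))
print(find_skewed(action_counts, factor=5))
```

The absolute floor matters: without it, "u6" is flagged even though 900 rows are harmless in practice.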

Isolated Shuffle for Skewed Keys

# Separate skewed keys and process them differently
# (hard-coded here; in practice derive them from the skew analysis above)
skewed_keys = ["user_123", "user_456", "user_789"]

# Process skewed and non-skewed rows separately
skewed_actions = user_actions.filter(col("user_id").isin(skewed_keys))
normal_actions = user_actions.filter(~col("user_id").isin(skewed_keys))

# Broadcast join for skewed keys: only the few matching profile rows are broadcast,
# so the hot keys never go through a shuffle
skewed_profiles = user_profiles.filter(col("user_id").isin(skewed_keys))
skewed_result = skewed_actions.join(
    broadcast(skewed_profiles),
    "user_id"
)

# Regular (shuffle) join for non-skewed keys
normal_result = normal_actions.join(
    user_profiles,
    "user_id"
)

# Union results by column name
final_result = skewed_result.unionByName(normal_result)
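Splitting is safe because the two key sets are disjoint: each fact row goes to exactly one branch, so the union reproduces the single join. A plain-Python check on toy data (keys and values are hypothetical):

```python
def join(fact, dim):
    """Inner equi-join of (key, value) fact rows against a key -> attribute dict."""
    return [(k, v, dim[k]) for k, v in fact if k in dim]


fact = [("hot", i) for i in range(6)] + [("a", 1), ("b", 2), ("c", 3)]
dim = {"hot": "H", "a": "A", "b": "B"}  # "c" has no dimension row
skewed_keys = {"hot"}

# Split the fact side by key set; the skewed branch only needs the skewed dimension rows
skewed_part = join([r for r in fact if r[0] in skewed_keys],
                   {k: v for k, v in dim.items() if k in skewed_keys})
normal_part = join([r for r in fact if r[0] not in skewed_keys], dim)

# The union of both parts equals a single join over all rows
assert sorted(skewed_part + normal_part) == sorted(join(fact, dim))
print(len(skewed_part), len(normal_part))
```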

Edge Cases

1. Cartesian Products

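# DANGEROUS: Creates N x M rows
result = df_a.crossJoin(df_b)
# If df_a has 1M rows and df_b has 1M rows = 1 trillion rows!

# SAFER: if a condition relates the two sides, make it the join condition.
# An equality condition turns this into an ordinary equi-join (no Cartesian product).
result = df_a.alias("a").join(df_b.alias("b"), col("a.date") == col("b.date"))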

2. NULL Keys in Joins

# NULL keys never match in equality joins: NULL = NULL evaluates to NULL, not true
df_a = spark.createDataFrame([(1, "a"), (None, "b")], ["id", "val"])
df_b = spark.createDataFrame([(1, "x"), (None, "y")], ["id", "val"])

# Inner join: only the id=1 rows match (a with x); NULL-keyed rows are excluded
result = df_a.join(df_b, df_a.id == df_b.id, "inner")

# Null-safe join: eqNullSafe (SQL <=>) treats NULL as equal to NULL. This is safer
# than coalescing both keys to a sentinel such as -1, which can collide with real keys.
result = df_a.join(
    df_b,
    df_a.id.eqNullSafe(df_b.id),
    "left"
)
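The difference between `=` and null-safe equality is SQL's three-valued logic. A plain-Python model, with `None` standing in for SQL NULL; a join keeps a pair only when its condition is exactly true.

```python
from typing import Optional


def sql_equals(a, b) -> Optional[bool]:
    """SQL '=': comparing anything with NULL yields NULL (unknown), not True."""
    if a is None or b is None:
        return None
    return a == b


def null_safe_equals(a, b) -> bool:
    """SQL '<=>' / Column.eqNullSafe: NULL <=> NULL is True, NULL <=> 1 is False."""
    if a is None or b is None:
        return a is None and b is None
    return a == b


left_ids = [1, None]
right_ids = [1, None]
print([(a, b) for a in left_ids for b in right_ids if sql_equals(a, b) is True])
print([(a, b) for a in left_ids for b in right_ids if null_safe_equals(a, b)])
```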

3. Duplicate Keys

# Both sides have duplicate keys → Cartesian product per key
df_a = spark.createDataFrame([(1, "a1"), (1, "a2")], ["id", "val"])
df_b = spark.createDataFrame([(1, "b1"), (1, "b2")], ["id", "val"])

# Result: 4 rows (a1-b1, a1-b2, a2-b1, a2-b2)
result = df_a.join(df_b, "id")

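# Deduplicate before the join when one row per key is intended
# (dropDuplicates keeps an arbitrary row per key; pick the surviving row explicitly if it matters)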
df_a_dedup = df_a.dropDuplicates(["id"])
df_b_dedup = df_b.dropDuplicates(["id"])
result = df_a_dedup.join(df_b_dedup, "id")
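The 4-row result generalizes: each key contributes (left count × right count) output rows, which is how a few duplicated hot keys can blow up a join's output. A plain-Python estimator, usable on sampled key lists (the inputs below are hypothetical):

```python
from collections import Counter


def inner_join_row_count(left_keys, right_keys):
    """Output rows of an inner equi-join = sum over shared keys of (left count x right count)."""
    left_counts = Counter(left_keys)
    right_counts = Counter(right_keys)
    return sum(n * right_counts[k] for k, n in left_counts.items() if k in right_counts)


print(inner_join_row_count([1, 1], [1, 1]))              # the duplicate-key example
print(inner_join_row_count([1, 1, 1, 2], [1, 1, 2, 3]))  # 3*2 + 1*1
```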

Join Performance Checklist

Production Join Checklist

  • Always check if broadcast join is feasible (small table < threshold)
  • Use broadcast() hint when you know one table is small
  • Enable AQE for automatic join optimization
  • Pre-partition/bucket data for repeated joins
  • Detect and handle data skew before joining
  • Use left_semi/left_anti instead of inner/left when checking existence
  • Filter data before joining to reduce shuffle
  • Consider bucketing for star schema joins

Summary

Join optimization is one of the most impactful performance tuning areas in Spark. Understanding when to use broadcast vs sort-merge vs shuffle hash joins, how to handle data skew, and how to leverage bucketing for repeated joins separates expert data engineers from intermediate ones. At Meta and Apple, mastering join strategies can mean the difference between a 10-minute query and a 2-hour query.
