
Topic: Catalyst Optimizer and Query Optimization

PySpark Advanced Interview Series

Module 09: Catalyst & Tungsten - The Optimization Engine

Companies: Google, Netflix | Difficulty: Hard

Interview Question

"At Google, we optimize query performance by understanding Catalyst's optimization phases. Walk us through the complete query optimization pipeline from logical plan to physical execution. How does predicate pushdown work, and what role does the cost-based optimizer play?" - Google Data Engineer Interview

"At Netflix, we tune Spark queries for petabyte-scale workloads. Explain how Tungsten's whole-stage code generation works, how you would read and interpret an EXPLAIN plan, and what techniques you use to identify optimization opportunities." - Netflix Senior Data Engineer Interview


Catalyst Optimizer Pipeline

Catalyst is Spark's extensible query optimization framework. It transforms queries through four phases:

Architecture Diagram
SQL/DataFrame API
      ↓
Parsed Logical Plan
      ↓
Analyzed Logical Plan
      ↓
Optimized Logical Plan
      ↓
Physical Plans (multiple)
      ↓
Selected Physical Plan
      ↓
RDD Code (via Tungsten)

Phase 1: Analysis

Catalyst resolves references and validates the query:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as _sum

spark = SparkSession.builder.appName("CatalystInterview").getOrCreate()

# Read data
sales = spark.read.parquet("s3a://bucket/sales/")
products = spark.read.parquet("s3a://bucket/products/")

# This query will be analyzed
result = sales \
    .join(products, "product_id") \
    .filter(col("category") == "Electronics") \
    .groupBy("product_name") \
    .agg(_sum("revenue").alias("total_revenue"))

# View the analyzed logical plan
result.explain("extended")

The analyzer:

  1. Resolves column references (e.g., category → products.category)
  2. Validates data types
  3. Checks for ambiguous references
  4. Fills in implicit defaults (e.g., join type)
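
The resolution steps above can be sketched with a toy resolver. This is a simplified model in plain Python, not Spark's analyzer: the `resolve_column` helper and the two schemas are hypothetical and exist only to show how an unqualified name is matched to exactly one table, and how a missing or ambiguous name is rejected.

```python
# Toy model of column resolution; not Spark's actual analyzer.
# The schemas below are assumed for illustration.
SCHEMAS = {
    "sales": ["product_id", "revenue", "date"],
    "products": ["product_id", "product_name", "category"],
}


def resolve_column(name, schemas=SCHEMAS):
    """Return 'table.column' for an unqualified column name."""
    owners = [table for table, cols in schemas.items() if name in cols]
    if not owners:
        raise ValueError(f"cannot resolve column: {name}")
    if len(owners) > 1:
        raise ValueError(f"ambiguous column {name}: found in {sorted(owners)}")
    return f"{owners[0]}.{name}"


print(resolve_column("category"))  # products.category
```

In this sketch `product_id` is ambiguous because both tables carry it. That is why the join in the example names the key once with `.join(products, "product_id")`.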

Phase 2: Optimization (Logical)

Catalyst applies rule-based and cost-based optimizations:

Predicate Pushdown

# Catalyst pushes filters closer to data source
result = sales \
    .filter(col("date") == "2024-01-01") \
    .join(products, "product_id") \
    .filter(col("category") == "Electronics")

# Catalyst rewrites to:
# 1. Filter sales by date (pushed to scan)
# 2. Filter products by category (pushed to scan)
# 3. Join filtered results

# Verify with explain
result.explain(True)
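
To make the rewrite concrete, here is a minimal sketch of a pushdown rule over a toy plan tree. It is plain Python and only models the idea; the `Scan`, `Join`, `Filter`, and `push_down` names are hypothetical and are not Catalyst classes.

```python
from dataclasses import dataclass, field


@dataclass
class Scan:
    table: str
    columns: frozenset
    pushed_filters: list = field(default_factory=list)


@dataclass
class Join:
    left: object
    right: object


@dataclass
class Filter:
    column: str      # column the predicate refers to
    predicate: str   # predicate text, kept opaque in this toy model
    child: object


def output_columns(node):
    """Columns produced by a plan node."""
    if isinstance(node, Scan):
        return node.columns
    if isinstance(node, Join):
        return output_columns(node.left) | output_columns(node.right)
    return output_columns(node.child)


def push_down(node):
    """Move each Filter toward the Scan that owns its column."""
    if isinstance(node, Scan):
        return node
    if isinstance(node, Join):
        return Join(push_down(node.left), push_down(node.right))
    child = push_down(node.child)
    if isinstance(child, Scan):
        child.pushed_filters.append(node.predicate)  # filter handled at the scan
        return child
    if isinstance(child, Join):
        if node.column in output_columns(child.left):
            return Join(push_down(Filter(node.column, node.predicate, child.left)), child.right)
        if node.column in output_columns(child.right):
            return Join(child.left, push_down(Filter(node.column, node.predicate, child.right)))
    return Filter(node.column, node.predicate, child)


sales = Scan("sales", frozenset({"product_id", "revenue", "date"}))
products = Scan("products", frozenset({"product_id", "product_name", "category"}))
plan = Filter("category", "category = 'Electronics'",
              Join(Filter("date", "date = '2024-01-01'", sales), products))

optimized = push_down(plan)
print(optimized.left.pushed_filters, optimized.right.pushed_filters)
```

After the rewrite the root of the toy plan is the join, and each scan carries the one predicate that refers to its own columns.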

Column Pruning

# Catalyst removes unused columns from scans
result = sales \
    .select("product_id", "revenue") \
    .join(products, "product_id") \
    .select("product_name", "revenue")

# Catalyst prunes unused columns from both tables
# Only reads product_id, revenue from sales
# Only reads product_id, product_name from products
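
The pruning decision can be sketched in a few lines of plain Python. This is a toy model, not Catalyst: the `required_columns` helper is hypothetical, and the extra columns `store_id` and `supplier` are assumed so there is something to prune.

```python
# Toy column pruning: work out which columns each scan really needs.
# Schemas are assumed; store_id and supplier are never used by the query.
SCHEMAS = {
    "sales": {"product_id", "revenue", "date", "store_id"},
    "products": {"product_id", "product_name", "category", "supplier"},
}


def required_columns(schemas, join_keys, output_columns):
    """Keep only the columns needed for the join keys and the final output."""
    needed = set(join_keys) | set(output_columns)
    return {table: sorted(cols & needed) for table, cols in schemas.items()}


pruned = required_columns(SCHEMAS, ["product_id"], ["product_name", "revenue"])
print(pruned)
```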

Constant Folding

# Catalyst evaluates constant expressions at compile time
result = sales.withColumn(
    "adjusted_revenue",
    col("revenue") * (1 + 0.1)  # Constant: 1.1
)

# Catalyst rewrites to:
# result.withColumn("adjusted_revenue", col("revenue") * 1.1)
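
As a rough illustration of the rule, the sketch below folds constants in a tiny expression tree made of nested tuples. The tuple encoding and the `fold_constants` function are hypothetical; Catalyst works on its own expression classes.

```python
import operator

OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def fold_constants(expr):
    """Replace any operator whose inputs are both literals with one literal."""
    kind = expr[0]
    if kind in ("lit", "col"):
        return expr
    left, right = fold_constants(expr[1]), fold_constants(expr[2])
    if left[0] == "lit" and right[0] == "lit":
        return ("lit", OPS[kind](left[1], right[1]))
    return (kind, left, right)


# revenue * (1 + 0.1)
expr = ("*", ("col", "revenue"), ("+", ("lit", 1), ("lit", 0.1)))
folded = fold_constants(expr)
print(folded)  # ('*', ('col', 'revenue'), ('lit', 1.1))
```

The multiplication survives because one of its inputs is a column, which is only known per row at execution time.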

Join Reordering

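# With the cost-based optimizer enabled (spark.sql.cbo.enabled and
# spark.sql.cbo.joinReorder.enabled, both off by default) and table statistics
# collected with ANALYZE TABLE, Catalyst can reorder joins by estimated size and selectivity.
# large_table, medium_table, and small_table are placeholder DataFrames.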
result = large_table \
    .join(medium_table, "id") \
    .join(small_table, "id")

# Catalyst may reorder to:
# small_table.join(medium_table, "id").join(large_table, "id")
# Processing small table first reduces intermediate result size
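
A small cost model shows why the order matters. The sketch below is a toy: the row counts, the single `SELECTIVITY` value, and the `plan_cost` function are assumptions for illustration, and Spark's cost-based optimizer uses richer statistics than this.

```python
from itertools import permutations

# Assumed row counts and a single assumed join selectivity.
ROW_COUNTS = {"large_table": 1_000_000, "medium_table": 10_000, "small_table": 100}
SELECTIVITY = 0.001  # fraction of row pairs that match per join


def plan_cost(order, rows=ROW_COUNTS, selectivity=SELECTIVITY):
    """Sum of estimated intermediate result sizes for a left-deep join order."""
    current = rows[order[0]]
    cost = 0
    for table in order[1:]:
        current = current * rows[table] * selectivity
        cost += current
    return cost


original = ("large_table", "medium_table", "small_table")
best = min(permutations(ROW_COUNTS), key=plan_cost)

print("original cost:", plan_cost(original))
print("best order:", best, "cost:", plan_cost(best))
```

Under these assumptions every order produces the same final result size, so the difference comes entirely from the first intermediate result. The cheapest orders are the ones that leave the large table for last.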

Phase 3: Physical Planning

Catalyst generates multiple physical plans and selects the best one:

# View the selected physical plan (explain prints only the plan Spark chose)
result.explain("formatted")

# Simplified, illustrative output:
# == Physical Plan ==
# *(2) Project [product_name, revenue]
# +- *(2) BroadcastHashJoin [product_id], [product_id], BuildLeft
#    :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, int, false]))
#    |  +- *(1) Filter (category = Electronics)
#    |     +- *(1) ColumnarToRow
#    |        +- ParquetScan [product_id, product_name, category]
#    +- *(2) Filter (date = 2024-01-01)
#       +- *(2) ColumnarToRow
#          +- ParquetScan [product_id, revenue, date]

Physical Plan Types

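Operator | Description | When Used
--- | --- | ---
FileScan | Read from storage | Leaf of any file-based query
Filter | Apply predicate | Predicates not fully handled by the scan
Project | Select columns | Column selection and computed expressions
BroadcastHashJoin | Broadcast small table | Small table < threshold
SortMergeJoin | Sort and merge | Large-large join
ShuffledHashJoin | Hash join after shuffle | Medium-large join
Sort | Sort data | ORDER BY, inputs to SortMergeJoin
HashAggregate / SortAggregate | Compute aggregations | GROUP BY, aggregations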
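
The join rows of the table can be read as a decision rule. The sketch below is a simplified model of that rule, not Spark's planner: `choose_join_strategy` is a hypothetical helper, and the real selection also considers hints, join type, and further size checks. The 10 MB value is the documented default of `spark.sql.autoBroadcastJoinThreshold`, and -1 disables broadcasting.

```python
DEFAULT_BROADCAST_THRESHOLD = 10 * 1024 * 1024  # 10 MB default threshold


def choose_join_strategy(left_bytes, right_bytes,
                         threshold=DEFAULT_BROADCAST_THRESHOLD,
                         prefer_sort_merge=True):
    """Pick a join operator from estimated input sizes (simplified)."""
    smaller = min(left_bytes, right_bytes)
    if threshold >= 0 and smaller <= threshold:
        return "BroadcastHashJoin"   # ship the small side to every executor
    if prefer_sort_merge:
        return "SortMergeJoin"       # shuffle and sort both sides
    return "ShuffledHashJoin"        # shuffle both sides, hash the smaller one


MB = 1024 * 1024
print(choose_join_strategy(5 * MB, 500_000 * MB))
print(choose_join_strategy(2_000 * MB, 500_000 * MB))
print(choose_join_strategy(5 * MB, 500_000 * MB, threshold=-1))
```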

Phase 4: Code Generation (Tungsten)

Tungsten generates optimized JVM bytecode at runtime:

# Tungsten optimizations:
# 1. Whole-stage code generation (fuses multiple operations)
# 2. Cache-aware computation (uses CPU cache efficiently)
# 3. Off-heap memory management (reduces GC pressure)

# Verify code generation in explain plan
result.explain("codegen")

# Output shows the generated Java source (heavily abbreviated sketch below)
# == Generated Code ==
# class GeneratedClass {
#   public Object generate(Object[] references) {
#     return new CodeIterator(references);
#   }
# }

ℹ️Google Interview Insight

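At Google, understanding Tungsten's code generation is crucial. Whole-stage code generation fuses multiple operators (filter + project + aggregate) into a single generated function, which eliminates virtual function calls between operators and lets intermediate values stay in CPU registers. Speedups of 2-10x are often quoted for CPU-bound stages, but the actual gain depends on the workload.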
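
The fusion idea can be imitated in plain Python. The sketch below contrasts an operator-at-a-time pipeline, where each operator is its own iterator, with a single loop produced from generated source text. It is only an analogy: Spark generates Java and compiles it on the JVM, and every function name here is hypothetical.

```python
# Operator-at-a-time model: each operator pulls rows from its child.
def scan(rows):
    for row in rows:
        yield row


def filter_op(child, predicate):
    for row in child:
        if predicate(row):
            yield row


def project_op(child, fn):
    for row in child:
        yield fn(row)


def run_iterator_model(rows):
    pipeline = project_op(filter_op(scan(rows), lambda r: r > 10), lambda r: r * 2)
    return sum(pipeline)


# Fused model: generate source for one loop that does filter, project,
# and aggregate together, then compile it at runtime.
def generate_fused_function():
    source = (
        "def fused(rows):\n"
        "    total = 0\n"
        "    for r in rows:\n"
        "        if r > 10:\n"
        "            total += r * 2\n"
        "    return total\n"
    )
    namespace = {}
    exec(source, namespace)
    return namespace["fused"]


rows = list(range(100))
fused = generate_fused_function()
print(run_iterator_model(rows), fused(rows))
```

Both versions return the same total. The fused one does it without passing each row through three separate iterator objects, which is the overhead whole-stage code generation removes.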


Reading EXPLAIN Plans

# Simple explain
df.explain()

# Extended explain (shows all phases)
df.explain("extended")

# Formatted explain (most readable)
df.explain("formatted")

# Codegen explain (shows generated code)
df.explain("codegen")

# Cost-based explain (shows statistics)
df.explain("cost")

Plan Reading Example

from pyspark.sql.functions import col, sum as _sum, broadcast

result = sales \
    .filter(col("date") >= "2024-01-01") \
    .select("product_id", "revenue") \
    .join(broadcast(products), "product_id") \
    .groupBy("category") \
    .agg(_sum("revenue").alias("total_revenue"))

result.explain("formatted")

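Output (simplified and illustrative; expression IDs and scan details are omitted, and exact operators vary with Spark version and settings):

== Physical Plan ==
*(3) HashAggregate(keys=[category], functions=[sum(revenue)])
+- Exchange hashpartitioning(category, 200)
   +- *(2) HashAggregate(keys=[category], functions=[partial_sum(revenue)])
      +- *(2) Project [revenue, category]
         +- *(2) BroadcastHashJoin [product_id], [product_id], Inner, BuildRight
            :- *(2) Project [product_id, revenue]
            :  +- *(2) Filter (date >= 2024-01-01)
            :     +- *(2) ColumnarToRow
            :        +- FileScan parquet [product_id, revenue, date]
            +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, int, false]))
               +- *(1) Filter isnotnull(product_id)
                  +- *(1) ColumnarToRow
                     +- FileScan parquet [product_id, category]

Reading this plan (bottom-up):

  1. FileScan: Parquet scans read only the needed columns
  2. Filter: applies the date predicate to sales
  3. Project: keeps product_id and revenue
  4. BroadcastExchange: sends products to all executors
  5. BroadcastHashJoin: joins filtered sales with products without shuffling sales
  6. Partial HashAggregate: pre-aggregates revenue by category within each partition
  7. Exchange: shuffles the partial results by category
  8. Final HashAggregate: merges partial sums into total revenue per category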
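
When plans get long, it can help to count the expensive operators instead of reading every line. The sketch below does that over plan text pasted into a string. `SAMPLE_PLAN` is an assumed, simplified plan and `summarize_plan` is a hypothetical helper; real plan text carries expression IDs and more detail, so treat the regular expression as a starting point.

```python
import re
from collections import Counter

# Assumed, simplified plan text for illustration.
SAMPLE_PLAN = """\
*(3) HashAggregate(keys=[category], functions=[sum(revenue)])
+- Exchange hashpartitioning(category, 200)
   +- *(2) HashAggregate(keys=[category], functions=[partial_sum(revenue)])
      +- *(2) BroadcastHashJoin [product_id], [product_id], Inner, BuildRight
         :- *(2) Filter (date >= 2024-01-01)
         :  +- FileScan parquet [product_id,revenue,date]
         +- BroadcastExchange HashedRelationBroadcastMode
            +- FileScan parquet [product_id,category]
"""


def operator_names(plan_text):
    """Extract the leading operator name from each plan line."""
    names = []
    for line in plan_text.splitlines():
        if line.startswith("=="):
            continue  # skip section headers such as '== Physical Plan =='
        match = re.search(r"(?:\*\(\d+\) )?([A-Z][A-Za-z]+)", line)
        if match:
            names.append(match.group(1))
    return names


def summarize_plan(plan_text):
    """Count the operators that usually dominate cost."""
    counts = Counter(operator_names(plan_text))
    return {
        "shuffles": counts["Exchange"],
        "broadcast_joins": counts["BroadcastHashJoin"],
        "sort_merge_joins": counts["SortMergeJoin"],
        "scans": counts["FileScan"],
    }


summary = summarize_plan(SAMPLE_PLAN)
print(summary)
```

A plan with more shuffles than expected, or a SortMergeJoin where one side is known to be small, is usually the first thing worth investigating.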

Real-World Scenario: Google Query Optimization

Problem Statement

Optimize a complex analytics query that was running in 45 minutes. The query joins 3 tables, applies multiple filters, and computes complex aggregations.

from pyspark.sql import SparkSession
# Explicit imports instead of a wildcard; note that this sum shadows Python's built-in sum
from pyspark.sql.functions import avg, broadcast, col, count, sum

spark = SparkSession.builder \
    .appName("GoogleQueryOptimization") \
    .config("spark.sql.shuffle.partitions", "500") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024)) \
    .getOrCreate()

# === ORIGINAL SLOW QUERY (45 minutes) ===
def slow_query():
    # Read all data without pruning
    orders = spark.read.parquet("s3a://google-data/orders/")
    customers = spark.read.parquet("s3a://google-data/customers/")
    products = spark.read.parquet("s3a://google-data/products/")
    
    # Complex query without optimization hints
    result = orders \
        .join(customers, "customer_id") \
        .join(products, "product_id") \
        .filter(col("order_date") >= "2024-01-01") \
        .filter(col("status") == "completed") \
        .groupBy("category", "region") \
        .agg(
            sum("amount").alias("total_revenue"),
            count("*").alias("order_count"),
            avg("amount").alias("avg_order_value")
        )
    
    return result

# === OPTIMIZED QUERY (3 minutes) ===
def optimized_query():
    # 1. Read only needed columns (column pruning)
    orders = spark.read.parquet("s3a://google-data/orders/") \
        .select("order_id", "customer_id", "product_id", "amount", "status", "order_date")
    
    customers = spark.read.parquet("s3a://google-data/customers/") \
        .select("customer_id", "region")
    
    products = spark.read.parquet("s3a://google-data/products/") \
        .select("product_id", "category")
    
    # 2. Filter early (predicate pushdown)
    orders = orders \
        .filter(col("order_date") >= "2024-01-01") \
        .filter(col("status") == "completed")
    
    # 3. Cache filtered orders (only worthwhile if they are reused downstream)
    filtered_orders = orders.cache()
    filtered_orders.count()  # Force materialization

    # 4. Repartition by join key. This adds a shuffle and is not required
    #    for the broadcast joins below, so measure before keeping it.
    orders = filtered_orders.repartition(500, "product_id")

    # 5. Broadcast the small dimension tables and join
    result = orders \
        .join(broadcast(products), "product_id") \
        .join(broadcast(customers), "customer_id") \
        .groupBy("category", "region") \
        .agg(
            sum("amount").alias("total_revenue"),
            count("*").alias("order_count"),
            avg("amount").alias("avg_order_value")
        )

    # 6. Cache and materialize the result for downstream use
    result.cache()
    result.count()

    # 7. Release the cached orders (unpersist the handle that was actually cached)
    filtered_orders.unpersist()
    
    return result

# Compare performance
import time

start = time.time()
slow_result = slow_query()
slow_result.count()
slow_time = time.time() - start
print(f"Slow query: {slow_time:.0f}s")

start = time.time()
fast_result = optimized_query()
fast_result.count()
fast_time = time.time() - start
print(f"Optimized query: {fast_time:.0f}s")

# View optimized plan
fast_result.explain("formatted")

spark.stop()

AQE (Adaptive Query Execution)

Spark 3.0 introduced Adaptive Query Execution (AQE), which re-optimizes the plan at runtime. It is enabled by default since Spark 3.2:

# Enable AQE
spark.conf.set("spark.sql.adaptive.enabled", "true")

# AQE features:
# 1. Auto coalesce shuffle partitions
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# 2. Use local shuffle readers when a sort-merge join is converted to a broadcast join
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")

# 3. Handle skewed joins
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# 4. Tune skew detection (the values shown are the defaults)
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
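
The two skew settings work together: a partition counts as skewed only if it is larger than the factor times the median partition size and also larger than the byte threshold. The sketch below applies that rule to assumed partition sizes. `skewed_partitions` is a hypothetical helper, not a Spark API.

```python
from statistics import median

MB = 1024 * 1024


def skewed_partitions(sizes, factor=5, threshold_bytes=256 * MB):
    """Return indexes of partitions that exceed both skew conditions."""
    med = median(sizes)
    return [
        index
        for index, size in enumerate(sizes)
        if size > factor * med and size > threshold_bytes
    ]


# Assumed shuffle partition sizes after a join stage.
sizes = [64 * MB, 70 * MB, 60 * MB, 2048 * MB, 300 * MB]
print(skewed_partitions(sizes))  # [3]
```

The 300 MB partition passes the byte threshold but not the factor test (5 x 70 MB = 350 MB), so only the 2 GB partition is flagged for splitting.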

How AQE Works

# AQE re-optimizes at runtime based on actual data statistics:
# 1. After shuffle, it knows actual partition sizes
# 2. It can merge small partitions
# 3. It can split large partitions
# 4. It can change join strategies based on actual table sizes

# Example: Sort-merge join converted to broadcast join
# Before AQE: Both sides shuffled (expensive)
# After AQE: If one side is small, converted to broadcast join (fast)

# Verify AQE decisions
result.explain("formatted")
# Look for: "AdaptiveSparkPlan" in the plan
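
The partition-merging step can be sketched as a greedy pass over the shuffle partition sizes. This is a toy model under stated assumptions: `coalesce_partitions` is a hypothetical function, and the 64 MB target mirrors the documented default of `spark.sql.adaptive.advisoryPartitionSizeInBytes`.

```python
def coalesce_partitions(sizes, target):
    """Greedily merge adjacent partitions without exceeding the target size."""
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes):
        if current and current_size + size > target:
            groups.append(current)      # close the group before it overflows
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups


# Assumed partition sizes in MB after a shuffle, with a 64 MB target.
sizes_mb = [10, 20, 30, 5, 64, 1, 2]
print(coalesce_partitions(sizes_mb, target=64))
# [[0, 1, 2], [3], [4], [5, 6]]
```

Seven shuffle partitions become four tasks in this example, which reduces scheduling overhead for the many small partitions that a high `spark.sql.shuffle.partitions` setting tends to produce.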

Performance Analysis Checklist

Query Optimization Checklist

  1. Check EXPLAIN plan for unnecessary shuffles
  2. Verify predicate pushdown (filters near scan)
  3. Check column pruning (only needed columns read)
  4. Verify broadcast joins for small tables
  5. Check partition count (not too many/too few)
  6. Enable AQE for automatic optimization
  7. Cache DataFrames used multiple times
  8. Use bucketing for repeated joins
  9. Monitor Spark UI for stragglers
  10. Profile with Spark UI DAG visualization

Common Optimization Techniques

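Technique | When to Use | Expected Improvement (rough, workload-dependent)
--- | --- | ---
Predicate pushdown | Always (automatic for supported sources) | 2-10x
Column pruning | Always (automatic for columnar formats) | 2-5x
Broadcast join | Small table < threshold | 2-100x
AQE | Always (Spark 3.0+) | 1.5-3x
Caching | DataFrame used > once | 2-10x
Bucketing | Repeated joins on same key | 2-5x
Repartition | Data skew, pre-join | 2-10x
Kryo serialization | RDD-based code or custom JVM objects (DataFrames already use Tungsten's binary format) | 2-10x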

Edge Cases

1. UDFs Prevent Optimization

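from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType

# UDFs are black boxes: Catalyst can't optimize through them
@udf(returnType=DoubleType())
def my_udf(x):
    # Return a float so the value matches the declared DoubleType; pass nulls through
    return float(x) * 2 if x is not None else None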

# Catalyst can't push filter through this UDF
df.withColumn("result", my_udf(col("value"))) \
  .filter(col("result") > 100)  # Filter applied AFTER UDF

# Better: use built-in functions
df.withColumn("result", col("value") * 2) \
  .filter(col("result") > 100)  # Catalyst can optimize this

2. Skewed Data Prevents Optimization

# Catalyst's estimates assume a roughly uniform key distribution; skew breaks that assumption
# Enable AQE skew handling
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# Or manually salt skewed keys
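
Manual salting can be illustrated without a cluster. The sketch below is a toy model in plain Python: it appends a salt to the keys on the large side, replicates the small side once per salt, and joins on the salted key. All names are hypothetical, and the salt is derived from the row index to keep the example deterministic. In a real job it would typically be random.

```python
from collections import Counter

NUM_SALTS = 4  # assumed salt count; tune to the observed skew


def salt_large_side(rows, num_salts=NUM_SALTS):
    """Append a salt to each key so one hot key spreads over several groups."""
    return [((key, i % num_salts), value) for i, (key, value) in enumerate(rows)]


def explode_small_side(rows, num_salts=NUM_SALTS):
    """Replicate each small-side row once per salt so every salted key matches."""
    return [((key, salt), value) for key, value in rows for salt in range(num_salts)]


def hash_join(left, right):
    """Minimal inner hash join on the first tuple element."""
    lookup = {}
    for key, value in right:
        lookup.setdefault(key, []).append(value)
    return [(key, lv, rv) for key, lv in left for rv in lookup.get(key, [])]


facts = [("hot", n) for n in range(8)] + [("cold", 100)]
dims = [("hot", "H"), ("cold", "C")]

plain = hash_join(facts, dims)
salted = hash_join(salt_large_side(facts), explode_small_side(dims))

largest_plain = max(Counter(key for key, _, _ in plain).values())
largest_salted = max(Counter(key for key, _, _ in salted).values())
print(largest_plain, largest_salted)
```

The join result is the same once the salt is dropped, but the largest key group shrinks, which is what lets the work spread across more tasks. The cost is that the small side grows by a factor of `NUM_SALTS`.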

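3. Partition Pruning

# Catalyst prunes partitions when a filter references the partition column directly.
# (Dynamic partition pruning is the join-time variant, where the filter values
# come from the other side of a join at runtime.)
# This example assumes logs is partitioned by date.
spark.sql("""
    SELECT * FROM logs 
    WHERE date = '2024-01-01'  -- Prunes to single partition
""")

# Avoid filtering on a derived expression instead of the partition column itself
spark.sql("""
    SELECT * FROM logs 
    WHERE to_date(timestamp) = '2024-01-01'  -- No pruning!
""")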

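The difference between the two queries can be modeled with a list of partition directories. This is a toy sketch, not Spark's file index: the directory names and the `partitions_to_scan` helper are assumptions for illustration.

```python
# Assumed Hive-style partition layout for a table partitioned by date.
PARTITIONS = ["date=2023-12-31", "date=2024-01-01", "date=2024-01-02"]


def partitions_to_scan(partition_dirs, filter_column, value, partition_column="date"):
    """Return the directories an equality filter would need to read."""
    if filter_column != partition_column:
        # The filter is on some other column or expression: nothing can be skipped.
        return list(partition_dirs)
    wanted = f"{partition_column}={value}"
    return [d for d in partition_dirs if d == wanted]


print(partitions_to_scan(PARTITIONS, "date", "2024-01-01"))
print(partitions_to_scan(PARTITIONS, "to_date(timestamp)", "2024-01-01"))
```

The first call reads one directory. The second reads all three, because the predicate gives the planner nothing it can match against the directory names.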
Summary

Understanding Catalyst and Tungsten is essential for optimizing Spark queries at Google and Netflix scale. The optimizer applies rule-based and cost-based transformations to convert your query into an efficient physical plan. By writing optimizer-friendly code (filtering early, using broadcast hints, preferring built-in functions over UDFs, enabling AQE), you give Catalyst more opportunities to optimize.