Topic: User Defined Functions (UDFs)

PySpark Advanced Interview Series

Module 08: UDFs - Extending Spark's Capabilities

Microsoft | Meta | Difficulty: Hard

Interview Question

"At Microsoft, we process billions of records using UDFs for complex business logic. Walk us through the difference between regular UDFs and Pandas UDFs, how serialization works, and what performance improvements you can expect from vectorized UDFs." - Microsoft Data Engineer Interview

"At Meta, we use UDFs for custom ML inference at scale. Explain how PySpark UDFs serialize data between JVM and Python, the limitations of regular UDFs, and how Pandas UDFs achieve 3-100x performance improvements." - Meta Senior Data Engineer Interview


UDF Architecture in PySpark

PySpark UDFs work through a complex serialization pipeline:

Architecture Diagram
JVM (Spark) → Python Worker → UDF Execution → Python Worker → JVM (Spark)
  1. Spark serializes data from JVM to Python using pickle
  2. Python worker executes the UDF function
  3. Results are serialized back to JVM

This JVM-Python bridge introduces significant overhead compared to JVM-based UDFs in Scala/Java.
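
The three steps above can be imitated in plain Python to see where the cost comes from. This is a toy model under stated assumptions, not Spark's actual worker protocol: `simulate_udf_round_trip` and `double` are hypothetical names, and the "JVM side" is just a list of tuples.

```python
import pickle

def double(x):
    """Row-at-a-time function body, as a regular UDF would run it."""
    return None if x is None else x * 2

def simulate_udf_round_trip(rows, func):
    """Toy model of steps 1-3: pickle in, call once per row, pickle out."""
    # Step 1: the "JVM side" serializes the input rows
    payload = pickle.dumps(rows)
    # Step 2: the "Python worker" deserializes and calls func once per row
    worker_rows = pickle.loads(payload)
    worker_results = [func(*row) for row in worker_rows]
    # Step 3: results are serialized and sent back
    return pickle.loads(pickle.dumps(worker_results))

rows = [(1.5,), (None,), (4.0,)]
results = simulate_udf_round_trip(rows, double)
print(results)  # [3.0, None, 8.0]
```

Every value crosses the boundary twice and triggers one Python function call, which is the overhead the rest of this module tries to reduce.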


Regular UDFs

Basic UDF

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType, IntegerType, DoubleType

spark = SparkSession.builder.appName("UDFInterview").getOrCreate()

# Define a UDF
@udf(returnType=StringType())
def categorize_age(age):
    if age is None:
        return "Unknown"
    elif age < 25:
        return "Junior"
    elif age < 35:
        return "Mid-Level"
    else:
        return "Senior"

# Use with the DataFrame API (no registration needed)
df = spark.createDataFrame([(1, 30), (2, 25), (3, 40)], ["id", "age"])
df = df.withColumn("category", categorize_age(col("age")))

# Register for SQL
spark.udf.register("categorize_age", categorize_age)

# Use in SQL
df.createOrReplaceTempView("employees")
spark.sql("SELECT id, age, categorize_age(age) as category FROM employees").show()

UDF with Multiple Parameters

@udf(returnType=DoubleType())
def calculate_discount(price, quantity, membership_level):
    if price is None or quantity is None:
        return 0.0
    
    base_discount = price * quantity * 0.1
    
    if membership_level == "gold":
        return base_discount * 1.5
    elif membership_level == "silver":
        return base_discount * 1.2
    else:
        return base_discount

# Apply UDF
df = df.withColumn(
    "discount",
    calculate_discount(col("price"), col("quantity"), col("membership_level"))
)

UDF with State

# UDFs should be stateless: don't rely on external mutable state
# But you can use closures for configuration

def create_threshold_extractor(threshold):
    @udf(returnType=StringType())
    def extractor(value):
        if value is None:
            return "below"
        return "above" if value > threshold else "below"
    return extractor

above_100 = create_threshold_extractor(100)
df = df.withColumn("status", above_100(col("value")))
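
Because the closure captures only a plain value, the same logic can be exercised without a Spark session. A minimal sketch, assuming a hypothetical `make_threshold_classifier` that mirrors `create_threshold_extractor` minus the `@udf` wrapper:

```python
def make_threshold_classifier(threshold):
    """Return a plain function configured by a captured threshold."""
    def classify(value):
        if value is None:
            return "below"
        return "above" if value > threshold else "below"
    return classify

above_100 = make_threshold_classifier(100)
above_500 = make_threshold_classifier(500)

# Each returned function carries its own threshold
checks = [above_100(150), above_100(100), above_100(None), above_500(150)]
print(checks)  # ['above', 'below', 'below', 'below']
```

Keeping the logic in a plain function like this makes it easy to unit-test; on a cluster the returned function could then be wrapped with `udf(classify, StringType())`.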

Pandas UDFs (Vectorized UDFs)

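Pandas UDFs use Apache Arrow to transfer data between JVM and Python in columnar batches. The function then runs on a whole pandas Series at a time instead of one row at a time, which is where most of the speedup over regular UDFs comes from.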

Series to Series UDF

from pyspark.sql.functions import pandas_udf
import pandas as pd

# Pandas UDF that operates on Series
@pandas_udf(DoubleType())
def multiply_by_two(s: pd.Series) -> pd.Series:
    return s * 2

df = df.withColumn("doubled", multiply_by_two(col("value")))

Grouped Map UDF

from pyspark.sql.types import StructType, StructField, StringType, DoubleType
import pandas as pd

# Define output schema (the statistics are nullable: NaN results arrive as NULL)
schema = StructType([
    StructField("category", StringType(), False),
    StructField("avg_value", DoubleType(), True),
    StructField("std_value", DoubleType(), True)
])

# Plain Python function: receives all rows of one group as a pandas DataFrame
def compute_stats(pdf: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame({
        "category": [pdf["category"].iloc[0]],
        "avg_value": [pdf["value"].mean()],
        "std_value": [pdf["value"].std()]  # NaN for single-row groups
    })

# Apply grouped map. Since Spark 3.0, applyInPandas replaces the deprecated
# @pandas_udf(schema, PandasUDFType.GROUPED_MAP) with groupBy().apply()
result = df.groupBy("category").applyInPandas(compute_stats, schema=schema)
result.show()

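Iterator UDF (Spark 3.0+)

from typing import Iterator
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import DoubleType
import pandas as pd

# Iterator of Series -> Iterator of Series: the function receives every batch
# of a partition, so one-time setup runs once instead of once per batch
@pandas_udf(DoubleType())
def batch_process(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    fill_value = 0.0  # stand-in for expensive setup, e.g. loading a model
    for s in batches:
        # Batch boundaries are arbitrary, so keep the logic element-wise;
        # a rolling window here would restart at every batch
        yield s.fillna(fill_value)

df = df.withColumn("value_filled", batch_process(col("value")))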
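
An iterator-style Pandas UDF receives an iterator of batches and yields one result per batch, which allows setup work to happen once per partition. The sketch below imitates that calling pattern with plain pandas and no Spark. `load_scale_factor` is a made-up stand-in for expensive initialization such as loading a model, and the list of Series stands in for the Arrow batches Spark would supply.

```python
from typing import Iterator
import pandas as pd

setup_log = []  # records each time the "expensive" setup runs

def load_scale_factor() -> float:
    """Hypothetical expensive setup step."""
    setup_log.append("loaded")
    return 2.0

def scale_batches(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Runs once for the whole iterator, not once per batch
    factor = load_scale_factor()
    for batch in batches:
        yield batch.fillna(0) * factor

# Three "batches" standing in for one partition's Arrow batches
partition = [pd.Series([1.0, None]), pd.Series([3.0]), pd.Series([4.0, 5.0])]
outputs = [out.tolist() for out in scale_batches(iter(partition))]
print(outputs)         # [[2.0, 0.0], [6.0], [8.0, 10.0]]
print(len(setup_log))  # 1
```

With a Series-to-Series UDF the setup would sit inside the function body and run once per batch instead.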

Performance Comparison

import time
from pyspark.sql.functions import udf, pandas_udf, col
from pyspark.sql.types import DoubleType
import pandas as pd

# Create test data
df = spark.range(10000000).withColumn("value", (col("id") * 1.5).cast("double"))

def time_it(label, frame):
    # Write to the no-op sink (Spark 3.0+) so every column is computed;
    # with count() alone, Catalyst may prune the unused "result" column
    start = time.time()
    frame.write.format("noop").mode("overwrite").save()
    elapsed = time.time() - start
    print(f"{label}: {elapsed:.2f}s")
    return elapsed

# Regular UDF
@udf(returnType=DoubleType())
def regular_udf(x):
    # Return a float or None; an int result would come back as NULL for DoubleType
    return x * 2.0 if x is not None else None

regular_time = time_it("Regular UDF", df.withColumn("result", regular_udf(col("value"))))

# Pandas UDF
@pandas_udf(DoubleType())
def pandas_udf_func(s: pd.Series) -> pd.Series:
    return s * 2

pandas_time = time_it("Pandas UDF", df.withColumn("result", pandas_udf_func(col("value"))))

# Built-in function (for comparison)
builtin_time = time_it("Built-in", df.withColumn("result", col("value") * 2))

Method              10M Rows   Relative Speed
Built-in function   2.1s       1x (fastest)
Pandas UDF          4.8s       2.3x slower
Regular UDF         45.2s      21.5x slower

Meta Pro Tip

At Meta, Pandas UDFs are preferred over regular UDFs for any custom logic. The Arrow-based data transfer reduces serialization overhead by 10-100x. Built-in functions are still faster than either kind of UDF in almost every case, because they run inside the JVM, stay visible to the Catalyst optimizer, and skip the Python round trip entirely.


Serialization Deep Dive

Regular UDF Serialization

# Regular UDFs use pickle for serialization
# Rows are pickled on the JVM side, unpickled in the Python worker,
# and the function is called once per row

# This is inefficient because:
# 1. Every value is converted to a Python object and back
# 2. Python objects carry more per-value overhead than JVM or Arrow representations
# 3. No vectorized operations

# Example: Sending 1M rows to Python
# JVM sends: pickled rows [row1, row2, ..., row1000000]
# Python processes: one row at a time (1M function calls)
# Python returns: pickled results [result1, result2, ..., result1000000]

Pandas UDF Serialization (Arrow)

# Pandas UDFs use Apache Arrow for serialization
# Data is sent as columnar batches, not row-by-row

# This is efficient because:
# 1. Columnar format reduces serialization overhead
# 2. Arrow is language-agnostic (JVM and Python)
# 3. Vectorized operations on entire columns

# Example: Sending 1M rows to Python
# JVM sends: Arrow record batches (up to maxRecordsPerBatch rows each, 10,000 by default)
# Python receives: one pd.Series per batch
# Python processes: an entire batch per call
# Python returns: one pd.Series of results per batch
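
The contrast between the two formats can be made concrete without Spark or Arrow. The sketch below is only an analogy: it uses the standard-library `array` module as a stand-in for a columnar buffer and compares it with pickling each row separately. The 1,000-value column is made up for illustration.

```python
import pickle
from array import array

values = [float(i) for i in range(1000)]

# Row-oriented: one pickle payload per row, as in a row-at-a-time exchange
row_payloads = [pickle.dumps((v,)) for v in values]
row_bytes = sum(len(p) for p in row_payloads)

# Column-oriented: one contiguous buffer of 8-byte doubles for the whole column
column_buffer = array("d", values).tobytes()
column_bytes = len(column_buffer)

# Both representations decode back to the same values
decoded_rows = [pickle.loads(p)[0] for p in row_payloads]
decoded_column = array("d")
decoded_column.frombytes(column_buffer)

print(f"row payloads: {len(row_payloads)} objects, {row_bytes} bytes")
print(f"column buffer: 1 object, {column_bytes} bytes")
```

The columnar form needs one encode and one decode for the whole column, while the row form pays framing and call overhead for every value.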

Arrow Configuration

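# Enable Arrow for toPandas() and createDataFrame(pandas_df);
# Pandas UDFs use Arrow regardless of this setting
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Configure Arrow batch size (rows per batch; 10000 is the default)
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

# Larger batches mean fewer Python calls but more memory per batch;
# raise the limit only if executors have headroom
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "100000")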
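
To get a feel for what `maxRecordsPerBatch` controls, the following sketch splits a partition into batches the way a row cap would. It is a plain-Python illustration: `split_into_batches` is a hypothetical helper, not a Spark API, and the partition size is an assumed number.

```python
def split_into_batches(num_rows, max_records_per_batch):
    """Yield (start, end) row ranges capped at max_records_per_batch rows."""
    for start in range(0, num_rows, max_records_per_batch):
        yield start, min(start + max_records_per_batch, num_rows)

partition_rows = 250_000  # assumed partition size, for illustration only

batch_counts = {}
for cap in (10_000, 100_000):
    batches = list(split_into_batches(partition_rows, cap))
    batch_counts[cap] = len(batches)
    last_start, last_end = batches[-1]
    # Each batch corresponds to one call of a Series-to-Series Pandas UDF
    print(f"cap={cap}: {len(batches)} batches, last batch has {last_end - last_start} rows")
```

A tenfold larger cap cuts the number of function calls for this partition from 25 to 3, at the price of holding up to 100,000 rows in memory per batch.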

Real-World Scenario: Microsoft Analytics Pipeline

Problem Statement

Build a complex analytics pipeline with custom business logic that requires UDFs. Optimize with Pandas UDFs and handle edge cases.

from pyspark.sql import SparkSession
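# Explicit imports: a wildcard import would shadow Python builtins such as max and sum
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, col
from pyspark.sql.types import (
    StringType, DoubleType, ArrayType, StructType, StructField
)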
import pandas as pd
import numpy as np

spark = SparkSession.builder \
    .appName("MicrosoftAnalyticsUDF") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()

# Read data
sales_df = spark.read.parquet("s3a://microsoft-data/sales/")
customer_df = spark.read.parquet("s3a://microsoft-data/customers/")

# === Regular UDFs for complex logic ===

@udf(returnType=StringType())
def classify_revenue(revenue):
    """Classify revenue into tiers"""
    if revenue is None:
        return "Unknown"
    elif revenue < 1000:
        return "Low"
    elif revenue < 10000:
        return "Medium"
    elif revenue < 100000:
        return "High"
    else:
        return "Enterprise"

@udf(returnType=DoubleType())
def calculate_ltv(monthly_revenue, tenure_months, churn_probability):
    """Calculate customer lifetime value"""
    if monthly_revenue is None or tenure_months is None:
        return 0.0
    
    expected_lifetime = tenure_months * (1 - (churn_probability or 0))
    return monthly_revenue * expected_lifetime

@udf(returnType=ArrayType(StringType()))
def parse_product_bundle(bundle_string):
    """Parse comma-separated product bundle"""
    if bundle_string is None:
        return []
    return [p.strip() for p in bundle_string.split(",")]

# Apply regular UDFs
enriched_sales = sales_df \
    .withColumn("revenue_tier", classify_revenue(col("revenue"))) \
    .withColumn("ltv", calculate_ltv(
        col("monthly_revenue"), 
        col("tenure_months"), 
        col("churn_probability")
    )) \
    .withColumn("products", parse_product_bundle(col("product_bundle")))

# === Pandas UDFs for vectorized processing ===

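# Window aggregation using a Series-to-scalar Pandas UDF
from pyspark.sql.window import Window

@pandas_udf(DoubleType())
def rolling_average(values: pd.Series) -> float:
    """Mean of the rows in the current window frame"""
    return values.mean()

# Trailing 7-row frame per customer (7 days when there is one row per day)
seven_day_window = Window.partitionBy("customer_id") \
    .orderBy("date") \
    .rowsBetween(-6, 0)

# Apply rolling average: .over() goes on the UDF call, not on the input column
enriched_sales = enriched_sales \
    .withColumn("rolling_avg_revenue",
                rolling_average(col("revenue")).over(seven_day_window))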

# Complex transformation using grouped map
customer_metrics_schema = StructType([
    StructField("customer_id", StringType(), False),
    StructField("region", StringType(), False),
    StructField("total_revenue", DoubleType(), False),
    StructField("avg_order_value", DoubleType(), False),
    StructField("order_frequency", DoubleType(), False),
    StructField("customer_segment", StringType(), False)
])

def compute_customer_metrics(pdf: pd.DataFrame) -> pd.DataFrame:
    """Compute comprehensive customer metrics for one customer's rows"""
    total_revenue = pdf["revenue"].sum()
    order_count = len(pdf)
    avg_order = pdf["revenue"].mean()

    # Determine segment
    if total_revenue > 100000:
        segment = "Enterprise"
    elif total_revenue > 10000:
        segment = "Mid-Market"
    else:
        segment = "SMB"

    # Guard against a zero or missing day count before dividing
    active_days = pdf["days_since_first_order"].max()
    if not active_days >= 1:
        active_days = 1

    return pd.DataFrame({
        "customer_id": [pdf["customer_id"].iloc[0]],
        "region": [pdf["region"].iloc[0]],
        "total_revenue": [total_revenue],
        "avg_order_value": [avg_order],
        "order_frequency": [order_count / active_days],
        "customer_segment": [segment]
    })

# Apply grouped map (applyInPandas is the non-deprecated form since Spark 3.0)
customer_metrics = sales_df.groupBy("customer_id").applyInPandas(
    compute_customer_metrics, schema=customer_metrics_schema
)

# Show results
enriched_sales.show(10, truncate=False)
customer_metrics.show(10, truncate=False)

spark.stop()

UDF Best Practices

1. Prefer Built-in Functions

# BAD: UDF for simple operations
@udf(returnType=DoubleType())
def square(x):
    return x ** 2 if x else None

df.withColumn("squared", square(col("value")))

# GOOD: Built-in function
from pyspark.sql.functions import pow
df.withColumn("squared", pow(col("value"), 2))

2. Handle NULLs

# BAD: Ignores NULLs
@udf(returnType=DoubleType())
def process(x):
    return x * 2  # Raises TypeError when x is None, failing the task

# GOOD: Explicitly handles NULLs
@udf(returnType=DoubleType())
def process(x):
    if x is None:
        return None
    return x * 2

3. Use Pandas UDFs for Vectorized Operations

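# BAD: Row-by-row processing
@udf(returnType=DoubleType())
def normalize(x, mean, std):
    return (x - mean) / std if x is not None else None

# GOOD: Vectorized processing; every argument arrives as a pd.Series,
# so pass constants as literal columns: normalize(col("value"), lit(mean), lit(std))
@pandas_udf(DoubleType())
def normalize(s: pd.Series, mean: pd.Series, std: pd.Series) -> pd.Series:
    return (s - mean) / std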
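
Outside Spark, the two styles can be compared directly on a pandas Series. This is a small sketch with made-up sample values; `normalize_row` is a hypothetical name for the row-at-a-time version.

```python
import pandas as pd

values = pd.Series([10.0, 20.0, 30.0, None])
mean, std = values.mean(), values.std()  # missing values are skipped: 20.0 and 10.0

# Row-by-row: one Python call per element, with an explicit missing-value check
def normalize_row(x, mean, std):
    return None if pd.isna(x) else (x - mean) / std

row_result = [normalize_row(x, mean, std) for x in values]

# Vectorized: one expression over the whole Series; NaN propagates on its own
vector_result = (values - mean) / std

print(row_result[:3])              # first three: -1.0, 0.0, 1.0
print(vector_result.tolist()[:3])  # first three: -1.0, 0.0, 1.0
```

Both produce the same numbers, but the vectorized form needs no per-element branching and runs in compiled code inside pandas.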

4. Register UDFs for SQL

# Always register UDFs if you need them in SQL
spark.udf.register("my_udf", my_udf_function, returnType=StringType())

# Now usable in SQL
spark.sql("SELECT my_udf(column) FROM table")

ℹ️Microsoft Interview Insight

At Microsoft, interviewers expect you to know that UDFs bypass Catalyst optimization. When you use a UDF, Spark can't push predicates through it, optimize column pruning, or apply code generation. Always prefer built-in functions for performance-critical paths.


Edge Cases

1. UDF with External Dependencies

# BAD: UDF depends on external library not available on executors
@udf(returnType=StringType())
def call_api(value):
    import requests  # May not be installed on executors!
    response = requests.get(f"https://api.example.com/{value}")
    return response.json()["result"]

# GOOD: Ensure the library is installed on all executors, or ship it with the job
# (e.g. spark-submit --py-files for pure-Python code, --archives for a packed environment)

2. UDF State Management

# BAD: UDF maintains state between calls (unreliable!)
counter = 0

@udf(returnType=IntegerType())
def increment(x):
    global counter
    counter += 1
    return counter  # Each Python worker process has its own counter, so values repeat

# GOOD: Stateless UDF
@udf(returnType=IntegerType())
def process(x):
    return x * 2 if x is not None else None
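
Why the counter misbehaves can be shown with a toy model in which each "worker" holds its own copy of the state, as separate Python worker processes would. `Worker` is a hypothetical class for illustration only, and the alternating row assignment is an assumption.

```python
class Worker:
    """Toy stand-in for one Python worker process with its own globals."""
    def __init__(self):
        self.counter = 0  # each worker starts from its own copy

    def increment(self, _row):
        self.counter += 1
        return self.counter

rows = ["a", "b", "c", "d"]
workers = [Worker(), Worker()]

# Rows are split across workers, here simply by alternating
ids = [workers[i % 2].increment(row) for i, row in enumerate(rows)]
print(ids)  # [1, 1, 2, 2]
```

The "unique" IDs collide as soon as more than one worker is involved. For row identifiers, a built-in such as `monotonically_increasing_id()` is the safer choice.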

3. UDF Error Handling

# BAD: One bad row fails the task and, after retries, the whole job
@udf(returnType=DoubleType())
def risky_function(x):
    return 1 / x  # ZeroDivisionError for 0, TypeError for None

# GOOD: UDF handles errors gracefully
@udf(returnType=DoubleType())
def safe_function(x):
    try:
        if x is None or x == 0:
            return None
        return 1 / x
    except Exception:
        return None
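
When several UDFs need the same guard, the try/except can be factored into a decorator applied to the plain function before it is wrapped as a UDF. A minimal sketch: `null_on_error` is a hypothetical helper name, and swallowing every exception is a design choice that trades visibility of bad data for robustness.

```python
import functools

def null_on_error(func):
    """Wrap func so None inputs and any exception yield None instead of failing the task."""
    @functools.wraps(func)
    def wrapper(*args):
        if any(arg is None for arg in args):
            return None
        try:
            return func(*args)
        except Exception:
            return None
    return wrapper

@null_on_error
def reciprocal(x):
    return 1 / x

safe_results = [reciprocal(4), reciprocal(0), reciprocal(None)]
print(safe_results)  # [0.25, None, None]
```

The decorated function could then be passed to `udf(reciprocal, DoubleType())`. If silent NULLs are a concern, a filter on the result column afterwards shows how many rows were affected.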

Performance Optimization

Technique            Speedup    Use Case
Pandas UDF           3-100x     Vectorized operations
Built-in functions   10-1000x   When available
Arrow optimization   2-5x       Enable Arrow
Batch processing     2-10x      Process batches, not rows
Caching input        2-5x       Reuse input data

# Enable Arrow for toPandas() and createDataFrame() conversions (Pandas UDFs always use Arrow)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Optimize batch size
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "50000")

# Use built-in functions when possible
from pyspark.sql.functions import col
df.withColumn("result", col("value") * 2 + 1)  # Better than UDF

Summary

UDFs are essential for implementing complex business logic in PySpark, but they come with performance costs. Regular UDFs have significant serialization overhead; Pandas UDFs using Arrow are much faster; built-in functions are fastest. At Microsoft and Meta, the rule is: never use a UDF when a built-in function can do the job.
