ML Pipeline Integration in PySpark
Architecture Diagram
┌──────────────────────────────────────────────────────────────────────┐
│                   PYSPARK ML PIPELINE ARCHITECTURE                   │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  ┌──────────────┐        ┌──────────────┐        ┌──────────────┐    │
│  │   Raw Data   │───────▶│   Feature    │───────▶│    Model     │    │
│  │ (DataFrame)  │        │ Engineering  │        │   Training   │    │
│  └──────────────┘        └───────┬──────┘        └───────┬──────┘    │
│                                  │                       │           │
│                                  ▼                       ▼           │
│                        ┌──────────────────┐    ┌──────────────────┐  │
│                        │   Transformers   │    │    Estimators    │  │
│                        │   ────────────   │    │   ────────────   │  │
│                        │ StringIndexer    │    │ LogisticReg      │  │
│                        │ OneHotEncoder    │    │ RandomForest     │  │
│                        │ VectorAssembler  │    │ GBTClassifier    │  │
│                        │ StandardScaler   │    │ CrossValidator   │  │
│                        └─────────┬────────┘    └─────────┬────────┘  │
│                                  │                       │           │
│                                  ▼                       ▼           │
│                        ┌──────────────────┐    ┌──────────────────┐  │
│                        │     Feature      │    │      Model       │  │
│                        │     Pipeline     │    │   Persistence    │  │
│                        │     (Chained     │    │   (Save/Load)    │  │
│                        │     stages)      │    │                  │  │
│                        └─────────┬────────┘    └─────────┬────────┘  │
│                                  │                       │           │
│                                  ▼                       ▼           │
│                        ┌──────────────────┐    ┌──────────────────┐  │
│                        │    Evaluation    │    │    Deployment    │  │
│                        │   ────────────   │    │   ────────────   │  │
│                        │ BinaryClassif    │    │ MLflow           │  │
│                        │ Multiclassif     │    │ Airflow          │  │
│                        │ Regression       │    │ Kubernetes       │  │
│                        └──────────────────┘    └──────────────────┘  │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────────┐
│                   TRANSFORMER vs ESTIMATOR PATTERN                   │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  TRANSFORMER (DataFrame ──▶ DataFrame)                               │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │  Input DataFrame                                               │  │
│  │  ┌─────┬─────┬─────┬─────┐                                     │  │
│  │  │  A  │  B  │  C  │  D  │                                     │  │
│  │  ├─────┼─────┼─────┼─────┤                                     │  │
│  │  │ 1.0 │ 2.0 │ 3.0 │ 4.0 │                                     │  │
│  │  └─────┴─────┴─────┴─────┘                                     │  │
│  │        │                                                       │  │
│  │        ▼  StandardScalerModel.transform()                      │  │
│  │        │                                                       │  │
│  │  ┌─────┬─────┬─────┬─────┬─────────────────┐                   │  │
│  │  │  A  │  B  │  C  │  D  │ features_scaled │                   │  │
│  │  ├─────┼─────┼─────┼─────┼─────────────────┤                   │  │
│  │  │ 1.0 │ 2.0 │ 3.0 │ 4.0 │ [0.27, 0.53,    │                   │  │
│  │  │     │     │     │     │  0.80, 1.07]    │                   │  │
│  │  └─────┴─────┴─────┴─────┴─────────────────┘                   │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  ESTIMATOR (DataFrame ──▶ Model ──▶ DataFrame)                       │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │  Training Data                                                 │  │
│  │  ┌─────┬─────────────────┬───────┐                             │  │
│  │  │ id  │    features     │ label │                             │  │
│  │  ├─────┼─────────────────┼───────┤                             │  │
│  │  │  1  │ [1.0, 2.0, 3.0] │   0   │                             │  │
│  │  │  2  │ [4.0, 5.0, 6.0] │   1   │                             │  │
│  │  └─────┴─────────────────┴───────┘                             │  │
│  │        │                                                       │  │
│  │        ▼  LogisticRegression.fit()                             │  │
│  │        │                                                       │  │
│  │  ┌─────────────────────────────────────┐                       │  │
│  │  │ Model (LogisticRegressionModel)     │                       │  │
│  │  │ ─────────────────────────────────   │                       │  │
│  │  │ coefficients: [0.1, 0.2, 0.3]       │                       │  │
│  │  │ intercept: -0.5                     │                       │  │
│  │  └─────────────────────────────────────┘                       │  │
│  │        │                                                       │  │
│  │        ▼  model.transform(testData)                            │  │
│  │        │                                                       │  │
│  │  ┌─────┬───────────┬───────┬────────────┐                      │  │
│  │  │ id  │ features  │ label │ prediction │                      │  │
│  │  ├─────┼───────────┼───────┼────────────┤                      │  │
│  │  │  1  │  [1,2,3]  │   0   │     0      │                      │  │
│  │  │  2  │  [4,5,6]  │   1   │     1      │                      │  │
│  │  └─────┴───────────┴───────┴────────────┘                      │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────────┐
│               CROSS-VALIDATION & HYPERPARAMETER TUNING               │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  Dataset                                                             │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │   Fold 1   │   Fold 2   │   Fold 3   │   Fold 4   │   Fold 5   │  │
│  │ ┌────────┐ │ ┌────────┐ │ ┌────────┐ │ ┌────────┐ │ ┌────────┐ │  │
│  │ │ Train  │ │ │ Train  │ │ │ Train  │ │ │ Train  │ │ │ Train  │ │  │
│  │ │  80%   │ │ │  80%   │ │ │  80%   │ │ │  80%   │ │ │  80%   │ │  │
│  │ ├────────┤ │ ├────────┤ │ ├────────┤ │ ├────────┤ │ ├────────┤ │  │
│  │ │  Test  │ │ │  Test  │ │ │  Test  │ │ │  Test  │ │ │  Test  │ │  │
│  │ │  20%   │ │ │  20%   │ │ │  20%   │ │ │  20%   │ │ │  20%   │ │  │
│  │ └────────┘ │ └────────┘ │ └────────┘ │ └────────┘ │ └────────┘ │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  For each fold:                                                      │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │  Hyperparameter Grid:                                          │  │
│  │  ┌─────────────┬─────────────┬─────────────┐                   │  │
│  │  │  regParam   │ elasticNet  │   maxIter   │                   │  │
│  │  ├─────────────┼─────────────┼─────────────┤                   │  │
│  │  │    0.01     │     0.0     │     100     │  ← Config 1       │  │
│  │  │    0.1      │     0.5     │     100     │  ← Config 2       │  │
│  │  │    1.0      │     1.0     │     200     │  ← Config 3       │  │
│  │  └─────────────┴─────────────┴─────────────┘                   │  │
│  │                                                                │  │
│  │  For each config × fold:                                       │  │
│  │    Train on (K-1) folds → Evaluate on held-out fold            │  │
│  │                                                                │  │
│  │  Select config with best average metric across all folds       │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │  Results:                                                      │  │
│  │    Config 1: Avg AUC = 0.892                                   │  │
│  │    Config 2: Avg AUC = 0.934  ← Best                           │  │
│  │    Config 3: Avg AUC = 0.921                                   │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘
Detailed Explanation
PySpark MLlib provides a high-level API built on DataFrames that enables the construction of machine learning pipelines. The fundamental design pattern distinguishes between Transformers (which transform DataFrames) and Estimators (which learn from DataFrames to produce Transformers). This abstraction enables the composition of complex workflows through the Pipeline class, which chains stages together in a linear sequence where the output of one stage feeds into the input of the next.
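Feature engineering in PySpark is done with the feature stages in pyspark.ml.feature. Some of these are pure Transformers, while others are Estimators whose fitted models are Transformers. StringIndexer (an Estimator) converts a categorical string column into numerical indices, by default giving index 0 to the most frequent label, and OneHotEncoder (also an Estimator in Spark 3.x) turns those indices into sparse binary vectors. VectorAssembler, a pure Transformer, combines multiple columns into a single feature vector, which is the input format most MLlib estimators expect. StandardScaler (an Estimator) scales each feature to unit standard deviation by default and also centers it to zero mean when withMean=True, which matters for algorithms that are sensitive to feature scale, such as logistic regression and SVMs.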
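As a rough illustration of what the indexing and encoding steps compute, the plain-Python sketch below mimics frequency-ordered indexing followed by one-hot encoding with the last category dropped. It is not Spark's implementation: `fit_string_index` and `one_hot` are hypothetical helper names, the alphabetical tie-break is an assumption of this sketch, and Spark produces sparse vectors rather than Python lists.

```python
from collections import Counter


def fit_string_index(values):
    """Map each label to an index; the most frequent label gets index 0.

    Ties are broken alphabetically here (an assumption for this sketch).
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: index for index, label in enumerate(ordered)}


def one_hot(index, num_categories, drop_last=True):
    """Encode an index as a dense 0/1 list; optionally drop the last slot."""
    size = num_categories - 1 if drop_last else num_categories
    vector = [0.0] * size
    if index < size:  # the dropped category is encoded as all zeros
        vector[index] = 1.0
    return vector


contracts = ["monthly", "yearly", "monthly", "two_year", "monthly", "yearly"]
mapping = fit_string_index(contracts)
encoded = [one_hot(mapping[c], len(mapping)) for c in contracts]

print(mapping)
print(encoded[0], encoded[3])
```

With three categories and the last one dropped, each row becomes a two-element vector, and the least frequent category is represented by all zeros.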
The Estimator pattern implements the fit-transform paradigm. An Estimator (like LogisticRegression) accepts a training DataFrame and produces a Model (like LogisticRegressionModel), which is itself a Transformer. This model can then be applied to new data via the transform method. The model persists the learned parameters (coefficients, intercepts, etc.) and encapsulates the complete inference logic.
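The fit-then-transform contract can be sketched without Spark at all. The classes below are hypothetical stand-ins (`MeanCenterer` and `MeanCentererModel` are not MLlib names) that operate on plain lists, intended only to show that fitting returns a separate object which carries the learned state and does the transforming.

```python
class MeanCenterer:
    """Estimator-like object: fit() learns a statistic and returns a model."""

    def fit(self, rows):
        mean = sum(rows) / len(rows)
        return MeanCentererModel(mean)


class MeanCentererModel:
    """Transformer-like object: holds learned state and applies it to data."""

    def __init__(self, mean):
        self.mean = mean

    def transform(self, rows):
        return [value - self.mean for value in rows]


train = [1.0, 2.0, 3.0, 6.0]
new_data = [3.0, 5.0]

model = MeanCenterer().fit(train)      # learning happens once, on training data
centered = model.transform(new_data)   # inference reuses the learned mean

print(model.mean, centered)
```

The new data is centered with the mean learned from the training rows, not with its own mean, which is the property that makes a fitted model safe to reuse at inference time.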
Pipeline composition enables reproducible workflows. When you call fit() on a Pipeline, it walks the stages in order: each Estimator is fit on the current DataFrame and its resulting Model transforms the data before it is passed to the next stage, while each Transformer is applied directly. The result is a PipelineModel in which every Estimator has been replaced by its fitted Model. Both the unfitted Pipeline and the fitted PipelineModel are serializable, so a fitted PipelineModel (including all intermediate Transformers and Models) can be persisted to disk and loaded for inference without retraining.
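A simplified, Spark-free sketch of that stage-by-stage fitting order follows. All names (`AddOne`, `MaxScaler`, `fit_pipeline`, `transform_pipeline`) are hypothetical, and the loop is an approximation of the behaviour described here rather than Spark's actual code.

```python
class AddOne:
    """Pure transformer: no fitting required."""

    def transform(self, rows):
        return [value + 1.0 for value in rows]


class MaxScaler:
    """Estimator: learns the maximum and returns a fitted transformer."""

    def fit(self, rows):
        return MaxScalerModel(max(rows))


class MaxScalerModel:
    def __init__(self, maximum):
        self.maximum = maximum

    def transform(self, rows):
        return [value / self.maximum for value in rows]


def fit_pipeline(stages, rows):
    """Walk the stages in order and return the list of fitted transformers."""
    fitted = []
    current = rows
    for stage in stages:
        # Estimators expose fit(); fit them on the data as transformed so far.
        transformer = stage.fit(current) if hasattr(stage, "fit") else stage
        current = transformer.transform(current)
        fitted.append(transformer)
    return fitted


def transform_pipeline(fitted, rows):
    """Apply every fitted stage in order to new data."""
    for transformer in fitted:
        rows = transformer.transform(rows)
    return rows


fitted = fit_pipeline([AddOne(), MaxScaler()], [1.0, 3.0])
print(fitted[1].maximum)  # learned after AddOne was applied to the data
print(transform_pipeline(fitted, [0.0, 7.0]))
```

The scaler learns its maximum from the output of the preceding stage, so the order of stages affects what each Estimator sees.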
Hyperparameter tuning is accomplished through CrossValidator and TrainValidationSplit. CrossValidator performs K-fold cross-validation: for every parameter combination it trains K models, each on K-1 folds, and evaluates each one on the held-out fold. TrainValidationSplit is cheaper because it evaluates each combination only once, on a single train/validation split. The ParamGridBuilder defines the search space over hyperparameters, and the evaluator (BinaryClassificationEvaluator, RegressionEvaluator, etc.) measures model quality. The parameter combination with the best average metric across folds is selected, and the estimator is then refit with those parameters on the entire training dataset.
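The cost of a grid search grows multiplicatively, which is easy to underestimate. The sketch below counts the model fits for a grid of three values for each of three parameters with five folds; the dictionary layout is an illustrative stand-in for a parameter grid, not the ParamGridBuilder API.

```python
from itertools import product

grid = {
    "regParam": [0.01, 0.1, 1.0],
    "elasticNetParam": [0.0, 0.5, 1.0],
    "maxIter": [50, 100, 200],
}
num_folds = 5

# Cartesian product of all value lists, one dict per configuration.
param_maps = [dict(zip(grid, values)) for values in product(*grid.values())]

cv_fits = len(param_maps) * num_folds   # one fit per configuration per fold
total_fits = cv_fits + 1                # plus the final refit of the best one

print(len(param_maps), cv_fits, total_fits)
```

Adding a fourth parameter with three values would triple every one of these numbers, which is a common reason to prefer a coarse grid first or to fall back to a single train/validation split.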
The evaluation metrics are crucial for model selection. For binary classification, PySpark provides BinaryClassificationEvaluator with metrics like AUC-ROC and AUC-PR. For multiclass classification, MulticlassClassificationEvaluator provides accuracy, weighted precision, weighted recall, F1 score, and log-loss. For regression, RegressionEvaluator provides MSE, RMSE, MAE, and R-squared. Understanding which metric aligns with your business objective is essential: on imbalanced datasets, precision and recall are more informative than accuracy, while cost-sensitive applications may need custom weight configurations.
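A small worked example shows why accuracy can mislead on imbalanced data. The helper `binary_metrics` is hypothetical and computes positive-class precision, recall, and F1 in plain Python; Spark's multiclass evaluator reports weighted averages across classes, so its numbers would differ.

```python
def binary_metrics(labels, predictions):
    """Return accuracy, precision, recall and F1 for the positive class (1)."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return accuracy, precision, recall, f1


# 95 negatives and 5 positives; the model finds only one of the positives.
labels = [0] * 95 + [1] * 5
predictions = [0] * 95 + [1] + [0] * 4

accuracy, precision, recall, f1 = binary_metrics(labels, predictions)
print(accuracy, precision, recall, round(f1, 3))
```

Accuracy comes out at 0.96 even though four of the five positive cases are missed, which recall (0.2) makes visible.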
Mathematical Foundations
Definition: ML Pipeline
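An ML pipeline is a directed acyclic graph $G = (V, E)$ of stages where vertices $V$ are transformation stages and edges $E$ define data dependencies. Each stage $i$ implements a function $f_i : \mathcal{D}_{i-1} \to \mathcal{D}_i$ that maps its input dataset to its output dataset.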
Cross-Validation Error
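For $K$-fold cross-validation with dataset $D$ split into $K$ equal partitions $D_1, \dots, D_K$:
$$\mathrm{CV}_K = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{|D_k|} \sum_{(x_i, y_i) \in D_k} L\bigl(y_i, \hat{f}^{(-k)}(x_i)\bigr)$$
where $L$ is the loss function and $\hat{f}^{(-k)}$ is the model trained on all folds except $D_k$.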
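The cross-validation estimate described here (average held-out loss, with each model trained on all folds except the one being scored) can be checked by hand on a toy problem. In the sketch below, `k_fold_cv_error`, `fit_mean`, and `squared_error` are hypothetical helpers, folds are contiguous slices, and the dataset size is assumed to be divisible by the number of folds.

```python
def k_fold_cv_error(xs, ys, k, fit, loss):
    """Average held-out loss over k contiguous, equally sized folds."""
    n = len(xs)
    fold_size = n // k  # assumes n is divisible by k
    fold_errors = []
    for fold in range(k):
        start, stop = fold * fold_size, (fold + 1) * fold_size
        train_x = xs[:start] + xs[stop:]
        train_y = ys[:start] + ys[stop:]
        model = fit(train_x, train_y)  # trained on all folds except this one
        held_out = zip(xs[start:stop], ys[start:stop])
        fold_loss = sum(loss(y, model(x)) for x, y in held_out) / fold_size
        fold_errors.append(fold_loss)
    return sum(fold_errors) / k


def fit_mean(train_x, train_y):
    """Toy learner: always predict the mean training label."""
    mean = sum(train_y) / len(train_y)
    return lambda x: mean


def squared_error(y, prediction):
    return (y - prediction) ** 2


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]
cv_error = k_fold_cv_error(xs, ys, k=2, fit=fit_mean, loss=squared_error)
print(cv_error)
```

With two folds, the model trained on the labels 5 and 7 predicts 6 for the held-out labels 1 and 3 (mean squared error 17), and the mirror-image fold gives the same value, so the estimate is 17.0.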
Bias-Variance Decomposition
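The expected prediction error of model $\hat{f}$ at a point $x$ decomposes as:
$$\mathbb{E}\bigl[(y - \hat{f}(x))^2\bigr] = \mathrm{Bias}\bigl[\hat{f}(x)\bigr]^2 + \mathrm{Var}\bigl[\hat{f}(x)\bigr] + \sigma^2$$
where $\sigma^2$ is the irreducible noise variance.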
Regularization Strength
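For L2-regularized regression with parameter $\lambda$:
$$\hat{\beta} = \arg\min_{\beta} \; \lVert y - X\beta \rVert_2^2 + \lambda \lVert \beta \rVert_2^2$$
The effective degrees of freedom reduce as $\lambda$ increases:
$$\mathrm{df}(\lambda) = \sum_{j=1}^{p} \frac{d_j^2}{d_j^2 + \lambda}$$
where $d_j$ are the singular values of the design matrix $X$.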
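To see the shrinkage numerically, the sketch below evaluates the standard ridge-regression expression for effective degrees of freedom, the sum over singular values d_j of d_j^2 / (d_j^2 + lambda). The singular values are made up for illustration and `effective_degrees_of_freedom` is a hypothetical helper name.

```python
def effective_degrees_of_freedom(singular_values, lam):
    """Sum of d_j^2 / (d_j^2 + lam) over the singular values d_j of X."""
    return sum(d * d / (d * d + lam) for d in singular_values)


singular_values = [3.0, 2.0, 1.0]  # hypothetical singular values of X

for lam in [0.0, 1.0, 10.0]:
    print(lam, round(effective_degrees_of_freedom(singular_values, lam), 3))
```

With no regularization the value equals the number of features (3.0 here), and it falls steadily as the penalty grows, with directions of small singular value shrinking fastest.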
Pipeline Complexity
Total pipeline computation cost with $S$ stages, each stage processing $n$ records:
$$T = \sum_{i=1}^{S} n \cdot c_i$$
where $c_i$ is the per-record cost of stage $i$.
Key Insight
Spark ML pipelines distribute each stage independently, but data shuffle between stages creates synchronization barriers. Optimal pipeline design minimizes shuffles by combining compatible transformations within stages.
Summary
ML pipelines formalize ML workflows as DAGs of transformations. Cross-validation estimates generalization error, bias-variance decomposition guides model selection, and regularization trades off fit complexity against generalization. Pipeline cost scales linearly with stages and data volume.
Key Concepts Table
| Concept | Description | Usage Pattern |
|---|---|---|
| Transformer | Applies transformation to DataFrame | transformer.transform(df) |
| Estimator | Learns from DataFrame, produces Transformer | estimator.fit(df) → Model |
| Pipeline | Chains stages (Transformers + Estimators) | Pipeline(stages=[...]).fit(df) |
| Parameter | Hyperparameter for algorithm configuration | LogisticRegression(regParam=0.1) |
| Evaluator | Measures model quality on test data | BinaryClassificationEvaluator() |
| CrossValidator | K-fold cross-validation for tuning | CrossValidator(estimator=..., estimatorParamMaps=..., evaluator=...) |
| Feature Vector | Combined feature column (required input) | VectorAssembler(inputCols=[...], outputCol="features").transform(df) |
| Model Persistence | Save/load fitted pipelines | model.save(path) / PipelineModel.load(path) |
| StringIndexer | Converts strings to category indices | StringIndexer(inputCol="cat") |
| OneHotEncoder | Converts indices to binary vectors | OneHotEncoder(inputCol="idx") |
Code Examples
Complete Classification Pipeline
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
    StringIndexer, OneHotEncoder, VectorAssembler,
    StandardScaler, Imputer
)
from pyspark.ml.classification import (
    LogisticRegression, RandomForestClassifier,
    GBTClassifier, LinearSVC
)
from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator
)
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# MLlib ships with PySpark, so no extra package configuration is needed
spark = SparkSession.builder \
    .appName("MLPipeline") \
    .getOrCreate()

# Load and prepare data
df = spark.read.parquet("/data/customer_churn")

# Define feature columns
categorical_cols = ["gender", "contract_type", "payment_method"]
numerical_cols = ["tenure_months", "monthly_charges", "total_charges"]
label_col = "churn"

# Handle missing values
imputer = Imputer(
    inputCols=numerical_cols,
    outputCols=[f"{col}_imputed" for col in numerical_cols],
    strategy="median"
)

# Index categorical columns
indexers = [
    StringIndexer(
        inputCol=col,
        outputCol=f"{col}_index",
        handleInvalid="keep"
    )
    for col in categorical_cols
]

# One-hot encode indexed columns
encoders = [
    OneHotEncoder(
        inputCol=f"{col}_index",
        outputCol=f"{col}_encoded",
        dropLast=True
    )
    for col in categorical_cols
]

# Assemble all features into a single vector
assembler = VectorAssembler(
    inputCols=[f"{col}_imputed" for col in numerical_cols] +
              [f"{col}_encoded" for col in categorical_cols],
    outputCol="features_raw"
)

# Scale features
scaler = StandardScaler(
    inputCol="features_raw",
    outputCol="features",
    withStd=True,
    withMean=True
)

# Define classifiers
lr = LogisticRegression(
    featuresCol="features",
    labelCol=label_col,
    maxIter=100,
    regParam=0.01,
    elasticNetParam=0.5
)
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol=label_col,
    numTrees=100,
    maxDepth=10,
    featureSubsetStrategy="sqrt"
)
gbt = GBTClassifier(
    featuresCol="features",
    labelCol=label_col,
    maxIter=100,
    maxDepth=5,
    stepSize=0.1
)

# Build pipeline
pipeline = Pipeline(stages=[
    imputer,
    *indexers,
    *encoders,
    assembler,
    scaler,
    lr
])

# Hyperparameter tuning: 3 x 3 x 3 = 27 parameter combinations
paramGrid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.01, 0.1, 1.0]) \
    .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0]) \
    .addGrid(lr.maxIter, [50, 100, 200]) \
    .build()

evaluator = BinaryClassificationEvaluator(
    labelCol=label_col,
    metricName="areaUnderROC"
)

crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=5,
    parallelism=4,
    seed=42
)

# Hold out a test set so the final evaluation uses data unseen during tuning
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

# Train with cross-validation on the training split only
cv_model = crossval.fit(train_df)

# Evaluate on test set
predictions = cv_model.transform(test_df)
auc = evaluator.evaluate(predictions)
print(f"AUC-ROC: {auc:.4f}")

# Detailed metrics
accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol=label_col, metricName="accuracy"
)
f1_evaluator = MulticlassClassificationEvaluator(
    labelCol=label_col, metricName="f1"
)
print(f"Accuracy: {accuracy_evaluator.evaluate(predictions):.4f}")
print(f"F1 Score: {f1_evaluator.evaluate(predictions):.4f}")
Feature Engineering Pipeline
from pyspark.ml.feature import (
    Bucketizer, QuantileDiscretizer, SQLTransformer,
    Interaction, ChiSqSelector, PCA
)

# Reuses df, label_col, imputer, indexers, encoders, assembler and scaler
# from the classification pipeline above.

# Bucket continuous features (shown for illustration; assumes an "age" column
# and is not added to feature_pipeline below)
bucketizer = Bucketizer(
    splits=[0, 18, 30, 45, 60, 100],
    inputCol="age",
    outputCol="age_bucket"
)

# SQL-based feature engineering
sql_transformer = SQLTransformer(
    statement="""
        SELECT
            *,
            monthly_charges / NULLIF(tenure_months, 0) AS avg_monthly_spend,
            CASE
                WHEN total_charges > 1000 THEN 'high_value'
                WHEN total_charges > 500 THEN 'medium_value'
                ELSE 'low_value'
            END AS value_segment
        FROM __THIS__
    """
)

# Create interaction features (only reaches the model if the output column
# is also added to the assembler's inputCols)
interaction = Interaction(
    inputCols=["gender_encoded", "contract_type_encoded"],
    outputCol="gender_contract_interaction"
)

# Feature selection using Chi-Squared. ChiSqSelector treats features as
# categorical and is deprecated in recent Spark 3.x releases in favor of
# UnivariateFeatureSelector.
selector = ChiSqSelector(
    featuresCol="features",
    labelCol=label_col,
    numTopFeatures=20,
    outputCol="selected_features"
)

# Dimensionality reduction with PCA (k must not exceed the number of
# assembled features)
pca = PCA(
    k=10,
    inputCol="features",
    outputCol="pca_features"
)

# Complete feature pipeline
feature_pipeline = Pipeline(stages=[
    sql_transformer,
    imputer,
    *indexers,
    *encoders,
    interaction,
    assembler,
    scaler,
    selector,
    pca
])

feature_model = feature_pipeline.fit(df)
feature_df = feature_model.transform(df)
feature_df.select("features", "pca_features", "selected_features").show(5, truncate=False)
Performance Metrics
| Metric | LogisticRegression | RandomForest | GBTClassifier | LinearSVC |
|---|---|---|---|---|
| Training Time (1M rows, 50 features) | 2-5 seconds | 10-30 seconds | 30-90 seconds | 3-8 seconds |
| Inference Time (100K rows) | 50-100 ms | 100-300 ms | 150-400 ms | 40-90 ms |
| AUC-ROC (typical) | 0.85-0.92 | 0.88-0.95 | 0.90-0.97 | 0.84-0.91 |
| Memory Usage | Low | High | Very High | Low |
| Interpretability | High | Medium | Low | High |
| Overfitting Risk | Low | Low-Medium | High | Low |
| Feature Importance | Coefficient magnitude | Built-in | Built-in | Coefficient magnitude |
| Parallelism | Data parallel | Data + Tree parallel | Data parallel (trees built sequentially) | Data parallel |
| Hyperparameter Sensitivity | Medium | Low | High | Medium |
| Handles Imbalanced Data | With weightCol | With weightCol | With weightCol | With weightCol |
Best Practices
- Always use Pipeline to ensure reproducibility and prevent data leakage from test set into training transformations
- Fit all preprocessors on training data only, then apply those same fitted preprocessors to the test data
- Use `handleInvalid="keep"` in StringIndexer to gracefully handle unseen categories at inference time
- Set a random seed in all stochastic algorithms and cross-validation to ensure reproducible results
- Use CrossValidator with parallelism > 1 to leverage Spark's distributed computing for faster tuning
- Scale features before training models sensitive to feature magnitudes (logistic regression, SVM, neural networks)
- Monitor overfitting by comparing training and validation metrics; if the gap exceeds 5%, reduce model complexity
- Use `featureSubsetStrategy="sqrt"` in RandomForest to decorrelate trees and reduce overfitting
- Persist intermediate DataFrames when building complex pipelines to avoid recomputation
- Log all experiments with MLflow or similar tracking system, including data versions, hyperparameters, and metrics
- Use `VectorAssembler` before `StandardScaler`, since the scaler operates on a single vector column
- Implement stratified sampling for imbalanced datasets using `df.sampleBy()` to ensure balanced class representation in training
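The practice of fitting preprocessors on training data only can be illustrated with a few numbers. The sketch below is plain Python with hypothetical helpers (`fit_scaler`, `apply_scaler`) and compares scaling statistics learned from the training split alone with statistics that have leaked information from the test split.

```python
from statistics import mean, stdev


def fit_scaler(values):
    """Learn centering and scaling statistics from the given values only."""
    return mean(values), stdev(values)  # sample standard deviation (n - 1)


def apply_scaler(values, stats):
    """Standardize values with previously learned statistics."""
    center, scale = stats
    return [(value - center) / scale for value in values]


train = [10.0, 20.0, 30.0]
test = [40.0]

# Correct: statistics come from the training split and are reused on test.
train_stats = fit_scaler(train)
test_scaled = apply_scaler(test, train_stats)

# Leaky: statistics are computed on train and test together.
leaky_stats = fit_scaler(train + test)

print(train_stats, test_scaled, leaky_stats)
```

The leaky statistics shift the mean from 20 to 25 because the test row influenced them, so any metric computed afterwards no longer reflects truly unseen data. Keeping the scaler inside a pipeline that is fit on the training split avoids this by construction.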
See also: Snowflake Time Travel (snowflake/02), Kafka CDC (kafka/04), Airflow DAGs (airflow/02)