Training Pipeline Automation
Difficulty: Senior Level | Companies: Google, Meta, Netflix, Uber, Stripe
Pipeline Orchestration
Automated pipelines ensure reproducibility, scalability, and maintainability of ML workflows.
ℹ️ Note: Google's TFX pipelines fully automate the ML workflow behind production systems that serve billions of predictions daily.
Kubeflow Pipelines
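# kubeflow_pipeline.py
# Written against the KFP v2 SDK (kfp>=2), where `component`, `Input`, `Output`,
# `Dataset`, etc. live in `kfp.dsl`. `ContainerOp` was removed in v2, so drop it
# from the import below or this module will fail to import.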
import kfp
from kfp import dsl
from kfp.components import load_component_from_file
from kfp.dsl import (
Input, Output, Artifact, Dataset, Model,
ContainerOp, component
)
from typing import Dict, Any
@component(
    base_image="python:3.9",
    # pandas.read_parquet needs a parquet engine (pyarrow), and reading gs://
    # URIs such as the pipeline's default data_path needs gcsfs. Neither is
    # installed by pandas itself.
    packages_to_install=[
        "pandas==1.5.0",
        "pyarrow==10.0.1",
        "gcsfs==2022.11.0",
        "scikit-learn==1.1.0",
    ],
)
def data_validation_component(
    data_path: str,
    min_samples: int = 1000,
    max_null_ratio: float = 0.1
) -> str:
    """Validate a parquet dataset before it enters the pipeline.

    Checks that the dataset has at least ``min_samples`` rows and that no
    single column's fraction of null values exceeds ``max_null_ratio``.

    Args:
        data_path: Location of the parquet file (local path or URI).
        min_samples: Minimum number of rows required.
        max_null_ratio: Maximum allowed null fraction for any one column.

    Returns:
        JSON string with ``total_samples``, ``total_features``,
        ``null_ratios`` (per column), ``passed`` and ``errors``.

    Raises:
        ValueError: If any check fails; the message lists every failure.
    """
    # KFP lightweight components run only this function's source inside the
    # container, so imports must be local to the function.
    import json

    import pandas as pd

    df = pd.read_parquet(data_path)
    null_ratios = df.isnull().mean().to_dict()

    # Collect every failure first so one run reports all problems at once.
    errors = []
    if len(df) < min_samples:
        errors.append(f"Insufficient samples: {len(df)} < {min_samples}")
    for col, ratio in null_ratios.items():
        if ratio > max_null_ratio:
            errors.append(f"High null ratio in {col}: {ratio}")
    if errors:
        raise ValueError(f"Data validation failed: {errors}")

    validation_results = {
        "total_samples": len(df),
        "total_features": len(df.columns),
        "null_ratios": null_ratios,
        "passed": True,
        "errors": [],
    }
    return json.dumps(validation_results)
@component(
    base_image="python:3.9",
    # pyarrow: parquet engine for pandas. gcsfs (which pulls in fsspec): read
    # and write gs:// URIs such as the pipeline's default output_path.
    packages_to_install=[
        "pandas==1.5.0",
        "pyarrow==10.0.1",
        "gcsfs==2022.11.0",
        "scikit-learn==1.1.0",
        "joblib==1.2.0",
    ],
)
def feature_engineering_component(
    data_path: str,
    output_path: str,
    feature_config: Dict[str, Any]
) -> str:
    """Scale numeric columns, label-encode categorical columns, persist results.

    Args:
        data_path: Parquet file to read.
        output_path: Directory/URI prefix. ``label_encoders.pkl`` and
            ``features.parquet`` are always written under it; ``scaler.pkl``
            is written only when numerical features are configured.
        feature_config: Mapping with optional ``"numerical"`` and
            ``"categorical"`` column lists. Any ``"target"`` entry is not used
            here. Columns not listed (e.g. the target) pass through unchanged.

    Returns:
        JSON string with ``num_features``, ``cat_features`` and
        ``output_samples``.
    """
    # KFP lightweight components run only this function's source inside the
    # container, so imports must be local to the function.
    import json

    import fsspec
    import joblib
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder, StandardScaler

    def _dump(obj, filename):
        # joblib.dump with a path string opens a *local* file, so open the
        # destination through fsspec to support gs:// (and other) URIs too.
        with fsspec.open(f"{output_path}/{filename}", "wb") as fh:
            joblib.dump(obj, fh)

    df = pd.read_parquet(data_path)
    numerical_features = feature_config.get("numerical", [])
    categorical_features = feature_config.get("categorical", [])

    # NOTE(review): the scaler and encoders are fit on *all* rows, before
    # training_component's train/test split, so hold-out statistics leak into
    # preprocessing. Consider fitting on the training split only.
    if numerical_features:
        scaler = StandardScaler()
        df[numerical_features] = scaler.fit_transform(df[numerical_features])
        _dump(scaler, "scaler.pkl")

    label_encoders = {}
    for col in categorical_features:
        encoder = LabelEncoder()
        # astype(str) lets mixed/None values be encoded as ordinary labels.
        df[col] = encoder.fit_transform(df[col].astype(str))
        label_encoders[col] = encoder
    _dump(label_encoders, "label_encoders.pkl")

    df.to_parquet(f"{output_path}/features.parquet", index=False)
    return json.dumps({
        "num_features": len(numerical_features),
        "cat_features": len(categorical_features),
        "output_samples": len(df),
    })
@component(
    base_image="python:3.9",
    # pandas is imported below, pyarrow is its parquet engine, and gcsfs lets
    # pandas/fsspec read and write gs:// URIs.
    packages_to_install=[
        "pandas==1.5.0",
        "pyarrow==10.0.1",
        "gcsfs==2022.11.0",
        "scikit-learn==1.1.0",
        "mlflow==2.0.0",
    ],
)
def training_component(
    features_path: str,
    model_output: str,
    hyperparameters: Dict[str, Any]
) -> str:
    """Train a RandomForestClassifier on engineered features and log to MLflow.

    Args:
        features_path: Parquet file containing the features and a ``target``
            column.
        model_output: Directory/URI prefix. The fitted model is written to
            ``{model_output}/model.joblib``.
        hyperparameters: Keyword arguments for ``RandomForestClassifier``. They
            are also logged as MLflow params.

    Returns:
        JSON string with hold-out ``accuracy`` and weighted ``f1_score``.
    """
    # KFP lightweight components run only this function's source inside the
    # container, so imports must be local to the function.
    import json

    import fsspec
    import joblib
    import mlflow
    import mlflow.sklearn
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score, f1_score
    from sklearn.model_selection import train_test_split

    df = pd.read_parquet(features_path)
    X = df.drop(columns=["target"])
    y = df["target"]
    # Fixed seed so the same data always yields the same hold-out split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    mlflow.set_experiment("training_pipeline")
    with mlflow.start_run():
        mlflow.log_params(hyperparameters)
        model = RandomForestClassifier(**hyperparameters)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "f1_score": f1_score(y_test, y_pred, average="weighted"),
        }
        mlflow.log_metrics(metrics)
        mlflow.sklearn.log_model(model, "model")

    # Persist the model at model_output as well, so downstream steps can load
    # it without going through MLflow. fsspec handles gs:// as well as local paths.
    with fsspec.open(f"{model_output}/model.joblib", "wb") as fh:
        joblib.dump(model, fh)

    return json.dumps(metrics)
@dsl.pipeline(
    name="ML Training Pipeline",
    description="End-to-end ML training pipeline"
)
def training_pipeline(
    data_path: str = "gs://data-bucket/raw/data.parquet",
    output_path: str = "gs://data-bucket/processed",
    hyperparameters: Dict[str, Any] = {"n_estimators": 100, "max_depth": 10}
):
    """Validate raw data, engineer features, then train a model.

    The steps hand data to one another through files under ``output_path``.
    No task consumes another task's output, so execution order is declared
    explicitly with ``.after()``.

    Args:
        data_path: Raw parquet dataset.
        output_path: Prefix for the engineered features and model artifacts.
        hyperparameters: Keyword arguments for the RandomForestClassifier in
            ``training_component``. The dict literal serves as the parameter's
            default and is never mutated here.
    """
    # Entry step. This pipeline has no ingestion task, so validation has no
    # upstream dependency. (`.after()` accepts task objects, not name strings.)
    validation_task = data_validation_component(data_path=data_path)

    feature_task = feature_engineering_component(
        data_path=data_path,
        output_path=output_path,
        feature_config={
            "numerical": ["feature_1", "feature_2", "feature_3"],
            "categorical": ["category_1", "category_2"],
            "target": "target",
        },
    )
    feature_task.after(validation_task)

    # Reads the parquet file that feature_engineering_component writes.
    training_task = training_component(
        features_path=f"{output_path}/features.parquet",
        model_output=f"{output_path}/model",
        hyperparameters=hyperparameters,
    )
    training_task.after(feature_task)
# Turn the pipeline function into a YAML package that can be uploaded to KFP.
compiler = kfp.compiler.Compiler()
compiler.compile(
    pipeline_func=training_pipeline,
    package_path="training_pipeline.yaml",
)
Apache Airflow Pipeline
# airflow_pipeline.py
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
from airflow.utils.task_group import TaskGroup
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
from typing import Dict
# Task defaults handed to the DAG below: email ml-alerts@company.com on
# failure, and retry a failed task up to 3 times with a 5-minute delay.
default_args = dict(
    owner="ml-team",
    depends_on_past=False,
    email_on_failure=True,
    email=["ml-alerts@company.com"],
    retries=3,
    retry_delay=timedelta(minutes=5),
)
def validate_data(**context):
    """Validate the raw dataset and fail the task if it is unusable.

    Requires at least 1000 rows, and an average null percentage (mean of the
    per-column null ratios, times 100) of at most 10%. The results dict is
    pushed to XCom under ``validation_results`` before any failure is raised.

    Args:
        **context: Airflow task context; only ``ti`` is used.

    Returns:
        The results dict with ``samples``, ``features``, ``null_percentage``,
        ``passed`` and ``errors``.

    Raises:
        ValueError: If any check fails, so the DAG stops before feature
            computation and training.
    """
    ti = context["ti"]
    df = pd.read_parquet("s3://data-bucket/raw/data.parquet")

    null_percentage = df.isnull().mean().mean() * 100
    errors = []
    if len(df) < 1000:
        errors.append(f"Insufficient samples: {len(df)} < 1000")
    if null_percentage > 10:
        errors.append(f"Null percentage too high: {null_percentage:.2f}% > 10%")

    validation_results = {
        "samples": len(df),
        "features": len(df.columns),
        "null_percentage": null_percentage,
        "passed": not errors,
        "errors": errors,
    }
    # Push before raising so the results are recorded even when the task fails.
    ti.xcom_push(key="validation_results", value=validation_results)
    if errors:
        raise ValueError(f"Data validation failed: {errors}")
    return validation_results
def compute_features(**context):
    """Add rolling, lag, difference and percent-change features of ``value``.

    Reads the raw dataset from S3, writes the augmented frame to
    ``s3://data-bucket/processed/features.parquet``, and pushes the resulting
    column count to XCom under ``feature_count``. Rows where any derived
    feature is non-finite are dropped. This covers the leading rows without a
    full 7-row history and percent changes that divide by zero.

    Args:
        **context: Airflow task context; only ``ti`` is used.
    """
    ti = context["ti"]
    df = pd.read_parquet("s3://data-bucket/raw/data.parquet")

    # NOTE(review): window/lag features assume rows are already in
    # chronological order. TODO: confirm, or sort by the timestamp column first.
    value = df["value"]
    df["feature_rolling_mean"] = value.rolling(window=7).mean()
    df["feature_rolling_std"] = value.rolling(window=7).std()
    df["feature_lag_1"] = value.shift(1)
    df["feature_lag_7"] = value.shift(7)
    df["feature_diff"] = value.diff()
    df["feature_pct_change"] = value.pct_change()

    feature_cols = [
        "feature_rolling_mean",
        "feature_rolling_std",
        "feature_lag_1",
        "feature_lag_7",
        "feature_diff",
        "feature_pct_change",
    ]
    # pct_change divides by the previous value, producing +/-inf after a zero.
    df[feature_cols] = df[feature_cols].replace([np.inf, -np.inf], np.nan)
    # The GradientBoostingClassifier trained downstream does not accept
    # NaN/inf inputs, so drop the rows whose features are undefined.
    df = df.dropna(subset=feature_cols)

    df.to_parquet("s3://data-bucket/processed/features.parquet", index=False)
    ti.xcom_push(key="feature_count", value=len(df.columns))
def train_model(**context):
    """Train a GradientBoostingClassifier on the engineered features.

    Reads ``s3://data-bucket/processed/features.parquet``, holds out 20% of
    rows for evaluation, and logs parameters, accuracy and the model to the
    MLflow experiment ``airflow_pipeline``. The hold-out accuracy is pushed
    to XCom under ``model_accuracy``.

    Args:
        **context: Airflow task context; only ``ti`` is used.
    """
    import mlflow
    import mlflow.sklearn
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split

    ti = context["ti"]
    # Seeded (same seed as the Kubeflow training step) so reruns on the same
    # data are reproducible.
    params = {"n_estimators": 100, "max_depth": 5, "random_state": 42}

    df = pd.read_parquet("s3://data-bucket/processed/features.parquet")
    X = df.drop(columns=["target"])
    y = df["target"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    mlflow.set_experiment("airflow_pipeline")
    with mlflow.start_run():
        mlflow.log_params(params)
        model = GradientBoostingClassifier(**params)
        model.fit(X_train, y_train)
        accuracy = accuracy_score(y_test, model.predict(X_test))
        mlflow.log_metric("accuracy", accuracy)
        mlflow.sklearn.log_model(model, "model")

    ti.xcom_push(key="model_accuracy", value=accuracy)
def evaluate_model(*, min_accuracy: float = 0.8, **context):
    """Gate the DAG on model quality.

    Pulls the accuracy that the ``train_model`` task pushed to XCom under
    ``model_accuracy``, and fails if it is missing or below ``min_accuracy``.

    Args:
        min_accuracy: Minimum acceptable accuracy (keyword-only, so it can be
            supplied through ``op_kwargs``).
        **context: Airflow task context; only ``ti`` is used.

    Raises:
        ValueError: If no accuracy was found, or it is below ``min_accuracy``.
    """
    ti = context["ti"]
    # train_model is not inside a TaskGroup (the DAG's only group is
    # "data_processing"), so its task_id carries no group prefix.
    accuracy = ti.xcom_pull(key="model_accuracy", task_ids="train_model")
    if accuracy is None:
        raise ValueError("No 'model_accuracy' XCom found from task 'train_model'")
    if accuracy < min_accuracy:
        raise ValueError(
            f"Model accuracy {accuracy} below threshold {min_accuracy}"
        )
# Daily DAG: wait for the raw file in S3, validate it, compute features, then
# train and gate on model accuracy. catchup=False skips backfilling past runs.
with DAG(
    "ml_training_pipeline",
    default_args=default_args,
    description="ML training pipeline with Airflow",
    schedule_interval="@daily",
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:
    # Data-preparation tasks; their task_ids get a "data_processing." prefix.
    with TaskGroup("data_processing") as data_processing:
        check_data = S3KeySensor(
            task_id="check_data_available",
            bucket_name="data-bucket",
            bucket_key="raw/data.parquet",
            # Poll every 60 s for up to an hour. "reschedule" frees the worker
            # slot between pokes instead of occupying it for the whole wait.
            mode="reschedule",
            timeout=3600,
            poke_interval=60,
        )
        validate = PythonOperator(
            task_id="validate_data",
            python_callable=validate_data,
        )
        features = PythonOperator(
            task_id="compute_features",
            python_callable=compute_features,
        )
        check_data >> validate >> features

    # Outside the group, so these task_ids are unprefixed ("train_model",
    # "evaluate_model").
    train = PythonOperator(
        task_id="train_model",
        python_callable=train_model,
    )
    evaluate = PythonOperator(
        task_id="evaluate_model",
        python_callable=evaluate_model,
    )
    data_processing >> train >> evaluate
Follow-Up Questions
- How would you implement pipeline versioning and rollback?
- What are the trade-offs between Kubeflow and Airflow for ML pipelines?
- How do you handle long-running training jobs with spot instances?
- What monitoring is needed for pipeline health and performance?