🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans from $10

Topic: Data Quality and Validation

PySpark Advanced · Data Quality · ⭐ Premium

Advertisement

PySpark Advanced Interview Series

Module 13: Data Quality — Ensuring Data Integrity at Scale

Netflix · Amazon · Difficulty: Hard

Interview Question

"At Netflix, data quality is critical for recommendation systems. Walk us through your data validation framework — how do you detect and handle duplicates, nulls, schema drift, and data anomalies at petabyte scale?" — Netflix Data Engineer Interview

"At Amazon, we process millions of transactions hourly. Explain how you would implement data quality checks that catch issues before they propagate to downstream systems, and how you would design a quarantine mechanism for bad data." — Amazon Senior Data Engineer Interview


Data Quality Dimensions

  1. Completeness: Are required fields present?
  2. Uniqueness: Are records unique?
  3. Validity: Do values conform to expected formats?
  4. Consistency: Are values consistent across sources?
  5. Timeliness: Is data arriving on time?
  6. Accuracy: Does data reflect reality?

Null Handling Strategies

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Local session used by all of the interview examples below.
spark = (
    SparkSession.builder
    .appName("DataQualityInterview")
    .getOrCreate()
)

# Sample rows seeded with typical quality problems: a missing email, a
# missing name and amount, an exact duplicate row, an unparseable date and a
# negative amount.
data = [
    (1, "Alice", "alice@email.com", 100, "2024-01-01"),
    (2, "Bob", None, 200, "2024-01-02"),
    (3, None, "charlie@email.com", None, "2024-01-03"),
    (1, "Alice", "alice@email.com", 100, "2024-01-01"),  # Duplicate
    (4, "Diana", "diana@email.com", 300, "invalid-date"),  # Invalid date
    (5, "Eve", "eve@email.com", -50, "2024-01-05"),  # Negative value
]

# Only column names are given; Spark infers the column types from the values.
df = spark.createDataFrame(
    data,
    schema=["id", "name", "email", "amount", "date"],
)

# === NULL Detection ===
# Count NULLs per column. when() yields a non-NULL marker only for NULL cells
# (and NULL otherwise), and count() skips NULLs, so the result is the number
# of NULL cells in each column.
null_counts = df.select([
    count(when(col(c).isNull(), c)).alias(c) for c in df.columns
])
null_counts.show()

# Percentage of NULLs
total_rows = df.count()
# Guard the denominator: on an empty DataFrame a division by zero yields NULL
# (or raises when Spark's ANSI mode is enabled). With zero rows every NULL
# count is 0, so dividing by 1 correctly reports 0%.
null_percentage = df.select([
    (count(when(col(c).isNull(), c)) / (total_rows or 1) * 100).alias(f"{c}_null_pct")
    for c in df.columns
])
null_percentage.show()

# === NULL Handling ===
# 1. Drop rows containing NULLs (DataFrame.dropna is the alias of df.na.drop).
df_dropped = df.dropna()  # a NULL in any column removes the row
df_dropped_subset = df.dropna(subset=["id", "name"])  # only these columns are checked

# 2. Replace NULLs with a default per column (fillna is the alias of na.fill).
df_filled = df.fillna({
    "name": "Unknown",
    "email": "no-email@domain.com",
    "amount": 0,
    "date": "1970-01-01",
})

# 3. Column-by-column defaults with coalesce (first non-NULL argument wins).
df_filled_expr = (
    df
    .withColumn("email", coalesce(col("email"), lit("no-email@domain.com")))
    .withColumn("amount", coalesce(col("amount"), lit(0)))
)

# 4. Forward fill: carry the latest non-NULL amount down the id ordering.
#    NOTE: a window with no partitionBy pulls every row into one partition.
from pyspark.sql.window import Window
ffill_window = Window.orderBy("id").rowsBetween(Window.unboundedPreceding, Window.currentRow)
df_ffill = df.withColumn(
    "amount_ffill",
    last(col("amount"), ignorenulls=True).over(ffill_window),
)

# 5. Gap fill: a NULL amount becomes the mean of the previous and next
#    amounts in id order (NULL if either neighbour is missing).
df_interpolated = df.withColumn(
    "amount_interpolated",
    when(col("amount").isNotNull(), col("amount")).otherwise(
        (lag("amount", 1).over(Window.orderBy("id"))
         + lead("amount", 1).over(Window.orderBy("id"))) / 2
    ),
)

Deduplication

Exact Deduplication

# Drop exact duplicates (every column must match).
df_dedup = df.dropDuplicates()

# Drop duplicates judged only on the listed columns; which of the duplicate
# rows survives is not defined.
df_dedup_key = df.dropDuplicates(["id", "name"])

# Keep the first / last occurrence per id.
# PySpark's dropDuplicates() has no `keep=` argument (that is pandas'
# drop_duplicates API); passing it raises a TypeError. Spark rows also have no
# inherent order, so "first"/"last" needs an explicit ordering column.
# monotonically_increasing_id() captures the current scan order: it increases
# within and across partitions, but is NOT guaranteed stable if the source is
# re-read or repartitioned. Prefer a real business column (for example an
# ingestion timestamp) when one exists.
from pyspark.sql.functions import col, monotonically_increasing_id, row_number
from pyspark.sql.window import Window

_df_ordered = df.withColumn("_arrival_order", monotonically_increasing_id())
_first_window = Window.partitionBy("id").orderBy(col("_arrival_order").asc())
_last_window = Window.partitionBy("id").orderBy(col("_arrival_order").desc())

df_dedup_first = (
    _df_ordered
    .withColumn("_rn", row_number().over(_first_window))
    .filter(col("_rn") == 1)
    .drop("_rn", "_arrival_order")
)

df_dedup_last = (
    _df_ordered
    .withColumn("_rn", row_number().over(_last_window))
    .filter(col("_rn") == 1)
    .drop("_rn", "_arrival_order")
)

Fuzzy Deduplication

from pyspark.sql.functions import *
from pyspark.sql.window import Window

# Deduplicate with time priority (keep the most recent record per id).
# `date` is a string column, and a plain descending string sort misranks
# malformed values: "invalid-date" sorts above every "2024-..." value and
# would be kept as the "most recent" row. ISO yyyy-MM-dd strings do sort
# chronologically, so order by the date only when it has that shape and push
# everything else (including NULL) to the end. This avoids to_date(), which
# may raise on malformed input when Spark's ANSI mode is enabled.
_iso_date = when(col("date").rlike(r"^\d{4}-\d{2}-\d{2}$"), col("date"))
window_dedup = Window.partitionBy("id").orderBy(_iso_date.desc_nulls_last())

df_dedup_time = df \
    .withColumn("row_num", row_number().over(window_dedup)) \
    .filter(col("row_num") == 1) \
    .drop("row_num")

# Deduplicate with complex logic
# Keep the record with the most populated optional fields (name, email, amount).
completeness_score = df.withColumn(
    "completeness",
    (when(col("name").isNotNull(), 1).otherwise(0) +
     when(col("email").isNotNull(), 1).otherwise(0) +
     when(col("amount").isNotNull(), 1).otherwise(0))
)

# Rows tied on completeness would otherwise be picked arbitrarily; break the
# tie with the most recent well-formed date.
window_complete = Window.partitionBy("id").orderBy(
    col("completeness").desc(), _iso_date.desc_nulls_last()
)
df_best = completeness_score \
    .withColumn("row_num", row_number().over(window_complete)) \
    .filter(col("row_num") == 1) \
    .drop("row_num", "completeness")

Data Validation Patterns

Schema Validation

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Expected contract for the sample dataset: id and name are mandatory
# (nullable=False) and the other columns may be NULL. `date` stays a string
# because it arrives unparsed.
expected_schema = (
    StructType()
    .add("id", IntegerType(), nullable=False)
    .add("name", StringType(), nullable=False)
    .add("email", StringType(), nullable=True)
    .add("amount", IntegerType(), nullable=True)
    .add("date", StringType(), nullable=True)
)

# Validate schema
def validate_schema(df, expected):
    """Check that ``df`` has every column named in ``expected``.

    Only column *names* are compared; data types and nullability are not
    checked here.

    Args:
        df: DataFrame (anything exposing a ``columns`` list of names).
        expected: StructType describing the required columns.

    Returns:
        True when no expected column is missing. Extra columns only print a
        warning.

    Raises:
        ValueError: if one or more expected columns are absent from ``df``.
    """
    actual_fields = set(df.columns)
    # StructType.fields is a list attribute, not a method: calling
    # `expected.fields()` raises "TypeError: 'list' object is not callable".
    expected_fields = {field.name for field in expected.fields}

    missing = expected_fields - actual_fields
    extra = actual_fields - expected_fields

    if missing:
        raise ValueError(f"Missing columns: {missing}")
    if extra:
        print(f"Warning: Extra columns: {extra}")

    return True

Data Type Validation

# Validate data types
# amount_valid: the value survives a cast to int (a NULL amount counts as invalid).
# date_valid:   the string looks like yyyy-MM-dd. rlike on a NULL date yields
#               NULL, which coalesce turns into False.
df_validated = (
    df
    .withColumn("amount_valid", col("amount").cast("int").isNotNull())
    .withColumn(
        "date_valid",
        coalesce(col("date").rlike(r"^\d{4}-\d{2}-\d{2}$"), lit(False)),
    )
)

# Report validation failures: rows failing at least one of the two checks.
invalid_rows = df_validated.filter(~(col("amount_valid") & col("date_valid")))

Business Rule Validation

# Define business rules
def validate_business_rules(df):
    """Evaluate the sample business rules and describe any violations.

    All rules are counted in a single aggregation, so the DataFrame is scanned
    once instead of once per rule.

    Rules:
        1. ``amount`` must not be negative (zero is allowed).
        2. A non-NULL ``email`` must contain "@".
        3. ``date`` must look like yyyy-MM-dd; NULL dates are not counted.

    Returns:
        list[str]: one message per violated rule, in rule order; empty when
        every rule passes.
    """
    # when() without otherwise() yields NULL unless its condition is true, and
    # count() skips NULLs, so each count is the number of rows where the
    # condition holds. This matches df.filter(condition).count().
    counts = df.agg(
        count(when(col("amount") < 0, 1)).alias("negative_amount"),
        count(when(col("email").isNotNull() & ~col("email").contains("@"), 1))
        .alias("invalid_email"),
        count(when(~col("date").rlike(r"^\d{4}-\d{2}-\d{2}$"), 1))
        .alias("invalid_date"),
    ).first()

    checks = [
        ("Negative amounts", counts["negative_amount"]),
        ("Invalid emails", counts["invalid_email"]),
        ("Invalid dates", counts["invalid_date"]),
    ]
    return [f"{label}: {n} rows" for label, n in checks if n > 0]

# Run the rules on the sample data and print one indented line per violation.
violations = validate_business_rules(df)
if violations:
    print("Data quality violations found:\n" + "\n".join(f"  - {v}" for v in violations))

Real-World Scenario: Netflix Data Quality Pipeline

Problem Statement

Build a comprehensive data quality framework that validates, cleanses, and quarantines bad data from a petabyte-scale viewing history dataset.

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
from pyspark.sql.types import *
import json

# Session with an explicit shuffle partition count.
spark = (
    SparkSession.builder
    .appName("NetflixDataQuality")
    .config("spark.sql.shuffle.partitions", "200")
    .getOrCreate()
)

# Read raw data: Parquet files under the raw viewing-history prefix.
raw_data = spark.read.format("parquet").load("s3a://netflix-data/viewing-raw/")

# === DATA QUALITY FRAMEWORK ===

class DataQualityChecker:
    """Fluent, rule-based data quality checks over a Spark DataFrame.

    Each ``check_*`` method runs eagerly (it triggers Spark actions), appends
    a violation dict to ``self.violations`` when its rule fails, and returns
    ``self`` so calls can be chained. ``quarantine_bad_records`` then splits
    ``self.df`` into quarantined and clean rows, and ``get_report`` summarises
    the result.

    NOTE(review): every check re-scans ``self.df``; for large inputs consider
    persisting the DataFrame before running the checks.
    """

    def __init__(self, df, name="dataset"):
        self.df = df  # data under test; replaced by the clean subset after quarantine
        self.name = name  # label used in the report
        self.violations = []  # one dict per failed rule
        self.quarantine = None  # DataFrame of quarantined rows, once computed

    def check_completeness(self, required_columns, threshold=0.95):
        """Flag required columns whose non-NULL fraction is below ``threshold``.

        All columns are counted in one aggregation (a single scan). An empty
        dataset is reported as 0% complete instead of raising
        ZeroDivisionError.
        """
        # count(lit(1)) counts every row; count(col(c)) counts non-NULL values.
        total_rows, *non_null_counts = self.df.agg(
            count(lit(1)), *[count(col(c)) for c in required_columns]
        ).first()
        for col_name, non_null in zip(required_columns, non_null_counts):
            completeness = non_null / total_rows if total_rows else 0.0
            if completeness < threshold:
                self.violations.append({
                    "rule": "completeness",
                    "column": col_name,
                    "completeness": completeness,
                    "threshold": threshold
                })
        return self

    def check_uniqueness(self, key_columns):
        """Flag duplicate records over ``key_columns``."""
        total_rows = self.df.count()
        unique_rows = self.df.dropDuplicates(key_columns).count()
        duplicates = total_rows - unique_rows
        # duplicates > 0 implies total_rows > 0, so the division below is safe.
        if duplicates > 0:
            self.violations.append({
                "rule": "uniqueness",
                "columns": key_columns,
                "duplicates": duplicates,
                "percentage": duplicates / total_rows
            })
        return self

    def check_range(self, column, min_val=None, max_val=None):
        """Flag values below ``min_val`` / above ``max_val`` (NULLs are ignored)."""
        if min_val is not None:
            below_min = self.df.filter(col(column) < min_val).count()
            if below_min > 0:
                self.violations.append({
                    "rule": "range_min",
                    "column": column,
                    "min": min_val,
                    "violations": below_min
                })
        if max_val is not None:
            above_max = self.df.filter(col(column) > max_val).count()
            if above_max > 0:
                self.violations.append({
                    "rule": "range_max",
                    "column": column,
                    "max": max_val,
                    "violations": above_max
                })
        return self

    def check_format(self, column, pattern):
        """Flag non-NULL values that do not match the regex ``pattern``."""
        invalid = self.df.filter(
            col(column).isNotNull() & ~col(column).rlike(pattern)
        ).count()
        if invalid > 0:
            self.violations.append({
                "rule": "format",
                "column": column,
                "pattern": pattern,
                "violations": invalid
            })
        return self

    def check_freshness(self, timestamp_column, max_delay_hours=24):
        """Flag the dataset as stale when its newest timestamp is too old.

        The lag (now - max(timestamp_column), in hours) is computed inside
        Spark. The check is skipped when the column has no non-NULL values.
        NOTE(review): assumes ``timestamp_column`` is a Spark timestamp
        column — confirm for string-typed sources.
        """
        # Explicit import: the module's star import shadows the builtin max()
        # with Spark's aggregate; make that dependency visible.
        from pyspark.sql.functions import max as spark_max

        # Do the arithmetic in Spark: collect() returns a Python datetime, which
        # has no .cast(). Casting both timestamps to long gives epoch seconds.
        delay_hours = self.df.agg(
            ((current_timestamp().cast("long")
              - spark_max(col(timestamp_column)).cast("long")) / 3600)
            .alias("delay_hours")
        ).first()[0]
        if delay_hours is not None and delay_hours > max_delay_hours:
            self.violations.append({
                "rule": "freshness",
                "column": timestamp_column,
                "delay_hours": delay_hours,
                "max_allowed": max_delay_hours
            })
        return self

    def quarantine_bad_records(self, quarantine_path):
        """Append rows breaking a row-level rule to ``quarantine_path``.

        Only completeness, range and format violations map to a per-row
        predicate; dataset-level rules (uniqueness, freshness) are skipped.
        Quarantined rows are written as Parquet (append mode) with
        ``quarantine_reason`` (JSON of *all* recorded violations, not only the
        ones the row broke) and ``quarantine_timestamp``. ``self.df`` is then
        replaced by the remaining rows.
        """
        if not self.violations:
            return self

        quarantine_condition = None
        for violation in self.violations:
            rule = violation["rule"]
            if rule == "completeness":
                condition = col(violation["column"]).isNull()
            elif rule == "range_min":
                condition = col(violation["column"]) < violation["min"]
            elif rule == "range_max":
                condition = col(violation["column"]) > violation["max"]
            elif rule == "format":
                condition = ~col(violation["column"]).rlike(violation["pattern"])
            else:
                continue
            # Comparisons and rlike on NULL evaluate to NULL, not False. A NULL
            # predicate fails both filter(cond) and filter(~cond) below, so the
            # row would silently disappear from both outputs. Pinning NULL to
            # False keeps such rows as clean, matching check_range/check_format,
            # which do not count NULL values as violations.
            condition = coalesce(condition, lit(False))
            if quarantine_condition is None:
                quarantine_condition = condition
            else:
                quarantine_condition = quarantine_condition | condition

        # `if quarantine_condition:` would invoke Column.__bool__, which raises
        # ValueError; test for None explicitly.
        if quarantine_condition is not None:
            self.quarantine = (
                self.df.filter(quarantine_condition)
                .withColumn("quarantine_reason", lit(json.dumps(self.violations)))
                .withColumn("quarantine_timestamp", current_timestamp())
            )

            # Write quarantine records
            self.quarantine.write \
                .mode("append") \
                .parquet(quarantine_path)

            # Remove bad records from main dataset
            self.df = self.df.filter(~quarantine_condition)

        return self

    def get_report(self):
        """Return a summary dict.

        ``total_rows`` counts the current ``self.df``, i.e. the clean rows when
        quarantine has run.
        """
        return {
            "dataset": self.name,
            "total_rows": self.df.count(),
            "violations": self.violations,
            "quarantine_rows": self.quarantine.count() if self.quarantine is not None else 0
        }

# === APPLY DATA QUALITY CHECKS ===

checker = DataQualityChecker(raw_data, "viewing_history")

# Each DataQualityChecker method used here returns the checker itself, so
# issuing the calls as separate statements is equivalent to chaining them.
checker.check_completeness(
    required_columns=["user_id", "content_id", "event_time"],
    threshold=0.99,
)
checker.check_uniqueness(key_columns=["user_id", "content_id", "event_time"])
checker.check_range(
    column="watch_duration_seconds",
    min_val=0,
    max_val=86400,  # 24 hours max
)
checker.check_format(column="user_id", pattern=r"^user_\d{8}$")
checker.check_freshness(timestamp_column="event_time", max_delay_hours=2)
checker.quarantine_bad_records(quarantine_path="s3a://netflix-data/quarantine/")
quality_report = checker.get_report()

print(json.dumps(quality_report, indent=2))

# Persist the rows that were not quarantined.
# NOTE(review): assumes the raw data already carries year/month/day columns;
# they are not derived anywhere in this script — confirm against the schema.
(
    checker.df.write
    .mode("overwrite")
    .partitionBy("year", "month", "day")
    .parquet("s3a://netflix-data/viewing-clean/")
)

spark.stop()

Data Anomaly Detection

from pyspark.sql.window import Window

# Statistical anomaly detection
def detect_anomalies(df, column, threshold=3):
    """Flag rows whose ``column`` lies more than ``threshold`` standard
    deviations from the column mean (z-score test).

    Adds two columns:
        z_score:    (value - mean) / stddev, or NULL when no spread exists
        is_anomaly: True when |z_score| > threshold, otherwise False

    The mean and sample standard deviation are computed in one aggregation
    and collected to the driver. When the standard deviation is NULL/NaN
    (e.g. empty input or a single non-NULL value) or 0 (constant column), no
    z-score is defined. Every row is then marked non-anomalous instead of
    dividing by zero or NULL.
    """
    import math

    # Explicit imports: the module's star import shadows the builtin abs();
    # make the dependency on Spark's abs visible.
    from pyspark.sql.functions import abs as spark_abs, avg, col, lit, stddev, when

    stats = df.agg(
        avg(column).alias("mean"),
        stddev(column).alias("stddev")
    ).first()

    mean_val = stats["mean"]
    stddev_val = stats["stddev"]

    if stddev_val is None or stddev_val == 0 or math.isnan(stddev_val):
        z_score = lit(None).cast("double")
    else:
        z_score = (col(column) - mean_val) / stddev_val

    # abs(NULL) > threshold is NULL, so otherwise(False) covers the undefined case.
    anomalies = df.withColumn("z_score", z_score).withColumn(
        "is_anomaly",
        when(spark_abs(col("z_score")) > threshold, True).otherwise(False)
    )

    return anomalies

# Time-series anomaly detection
def detect_time_anomalies(df, time_col, value_col, window_size=7,
                          threshold=3, partition_by=None):
    """Flag values outside rolling mean +/- ``threshold`` rolling stddevs.

    Each row's baseline is the ``window_size`` rows that precede it in
    ``time_col`` order. The current row is excluded, so a spike cannot mask
    itself.

    Args:
        df: input DataFrame.
        time_col: column defining the series order.
        value_col: numeric column to test.
        window_size: number of preceding rows in the baseline window.
        threshold: band half-width in standard deviations (default 3, the
            previously hard-coded value).
        partition_by: optional column name or list of names. When given, each
            group is treated as an independent series. Without it, Spark
            moves every row into a single partition to evaluate the window.

    Returns:
        ``df`` plus rolling_avg, rolling_std, upper_bound, lower_bound and
        is_anomaly. Rows with too little history get NULL bounds; comparisons
        against NULL are not true, so they are reported as not anomalous.
    """
    if partition_by is None:
        ordered = Window.orderBy(time_col)
    else:
        ordered = Window.partitionBy(partition_by).orderBy(time_col)
    window_spec = ordered.rowsBetween(-window_size, -1)

    df_with_stats = df \
        .withColumn("rolling_avg", avg(value_col).over(window_spec)) \
        .withColumn("rolling_std", stddev(value_col).over(window_spec)) \
        .withColumn(
            "upper_bound",
            col("rolling_avg") + threshold * col("rolling_std")
        ) \
        .withColumn(
            "lower_bound",
            col("rolling_avg") - threshold * col("rolling_std")
        ) \
        .withColumn(
            "is_anomaly",
            when(
                (col(value_col) > col("upper_bound")) |
                (col(value_col) < col("lower_bound")),
                True
            ).otherwise(False)
        )

    return df_with_stats

Best Practices

💡 Data Quality Checklist

  • Define quality rules BEFORE building pipelines
  • Implement quarantine for bad data (don't just drop)
  • Monitor quality metrics over time
  • Set up alerts for quality degradation
  • Log all quality violations for audit
  • Test quality rules with known bad data
  • Use schema evolution carefully
  • Validate at source, not just destination

Summary

Data quality is not optional in production pipelines. At Netflix and Amazon, poor data quality can impact recommendations and revenue. Implementing comprehensive validation, deduplication, and quarantine mechanisms ensures downstream systems receive clean, reliable data.

Advertisement