Imbalanced Data: SMOTE, Class Weights and Sampling
Imbalanced datasets are common in fraud detection, medical diagnosis, and anomaly detection. This lesson covers techniques to handle class imbalance.
The Imbalanced Data Problem
<svg width="600" height="350" viewBox="0 0 600 350" xmlns="http://www.w3.org/2000/svg">
<rect width="600" height="350" fill="#f8f9fa" rx="10"/>
<text x="300" y="30" text-anchor="middle" font-size="18" font-weight="bold" fill="#2c3e50">Class Imbalance Problem</text>
<!-- Balanced Dataset -->
<rect x="50" y="60" width="200" height="120" fill="white" stroke="#2ecc71" stroke-width="2" rx="5"/>
<text x="150" y="80" text-anchor="middle" font-size="12" font-weight="bold" fill="#2c3e50">Balanced Dataset</text>
<rect x="70" y="95" width="80" height="70" fill="#3498db" rx="3"/>
<text x="110" y="135" text-anchor="middle" font-size="11" fill="white">Class 0</text>
<text x="110" y="150" text-anchor="middle" font-size="10" fill="white">50%</text>
<rect x="160" y="95" width="80" height="70" fill="#2ecc71" rx="3"/>
<text x="200" y="135" text-anchor="middle" font-size="11" fill="white">Class 1</text>
<text x="200" y="150" text-anchor="middle" font-size="10" fill="white">50%</text>
<!-- Imbalanced Dataset -->
<rect x="350" y="60" width="200" height="120" fill="white" stroke="#e74c3c" stroke-width="2" rx="5"/>
<text x="450" y="80" text-anchor="middle" font-size="12" font-weight="bold" fill="#2c3e50">Imbalanced Dataset</text>
<rect x="370" y="95" width="160" height="70" fill="#3498db" rx="3"/>
<text x="450" y="135" text-anchor="middle" font-size="11" fill="white">Class 0</text>
<text x="450" y="150" text-anchor="middle" font-size="10" fill="white">95%</text>
<rect x="370" y="95" width="10" height="70" fill="#e74c3c" rx="3"/>
<text x="380" y="88" text-anchor="middle" font-size="9" fill="#e74c3c">5% Class 1</text>
<!-- Consequences -->
<text x="300" y="210" text-anchor="middle" font-size="14" font-weight="bold" fill="#2c3e50">Why It Matters:</text>
<text x="300" y="240" text-anchor="middle" font-size="11" fill="#e74c3c">• Model predicts majority class only</text>
<text x="300" y="260" text-anchor="middle" font-size="11" fill="#e74c3c">• Misleading accuracy (95% accuracy = 0% recall on minority)</text>
<text x="300" y="280" text-anchor="middle" font-size="11" fill="#e74c3c">• Poor generalization on rare events</text>
<!-- Solutions -->
<text x="300" y="310" text-anchor="middle" font-size="14" font-weight="bold" fill="#27ae60">Solutions:</text>
<text x="300" y="330" text-anchor="middle" font-size="11" fill="#27ae60">Resampling | SMOTE | Class Weights | Ensemble Methods | Metrics</text>
</svg>
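Resampling Techniques
Apply resampling to the training data only: do it after the train/test split, or inside each cross-validation fold. Resampling the full dataset leaks duplicated or synthetic samples into the evaluation data and inflates the reported scores.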
import pandas as pd
import numpy as np
from sklearn.utils import resample

# NOTE(review): `df` (with a binary 'target' column) and the arrays `X`, `y`
# are assumed to be defined by the reader. Resample the TRAINING data only —
# never the held-out evaluation data.

# Split data by class (0 = majority, 1 = minority in this example)
df_majority = df[df['target'] == 0]
df_minority = df[df['target'] == 1]

# 1. Random Oversampling: draw minority rows WITH replacement until the
#    minority class has as many rows as the majority class (duplicates rows).
df_minority_upsampled = resample(
    df_minority,
    replace=True,
    n_samples=len(df_majority),
    random_state=42,
)
df_oversampled = pd.concat([df_majority, df_minority_upsampled])

# 2. Random Undersampling: keep a random subset (WITHOUT replacement) of the
#    majority rows, equal in size to the minority class (discards data).
df_majority_downsampled = resample(
    df_majority,
    replace=False,
    n_samples=len(df_minority),
    random_state=42,
)
df_undersampled = pd.concat([df_majority_downsampled, df_minority])
# pd.concat keeps the rows grouped by class; shuffle (e.g.
# `.sample(frac=1, random_state=42)`) before feeding order-sensitive learners.

# 3. Tomek Links (undersampling with cleaning): removes samples that form
#    cross-class nearest-neighbour pairs, sharpening the class boundary.
from imblearn.under_sampling import TomekLinks
tl = TomekLinks()
X_resampled, y_resampled = tl.fit_resample(X, y)

# 4. Edited Nearest Neighbours: removes samples whose label disagrees with
#    their nearest neighbours. imbalanced-learn uses the British spelling
#    "Neighbours"; importing `EditedNearestNeighbors` raises ImportError.
from imblearn.under_sampling import EditedNearestNeighbours
enn = EditedNearestNeighbours()
X_resampled, y_resampled = enn.fit_resample(X, y)
SMOTE Algorithm
from imblearn.over_sampling import ADASYN, BorderlineSMOTE, SMOTE

# --- Basic SMOTE ------------------------------------------------------------
# Creates synthetic minority samples by interpolating between a minority point
# and one of its k nearest minority neighbours (k = 5 here).
smote = SMOTE(k_neighbors=5, random_state=42)
X_smote, y_smote = smote.fit_resample(X, y)

# --- Borderline SMOTE -------------------------------------------------------
# Same idea, but only minority points near the decision boundary are used as
# seeds for new samples.
bsmote = BorderlineSMOTE(random_state=42)
X_bsmote, y_bsmote = bsmote.fit_resample(X, y)

# --- ADASYN -----------------------------------------------------------------
# Adaptive synthetic sampling: generates more samples around minority points
# that are harder to learn (surrounded by more majority neighbours).
adasyn = ADASYN(random_state=42)
X_adasyn, y_adasyn = adasyn.fit_resample(X, y)

# --- SMOTE + Tomek Links ----------------------------------------------------
# Oversample with SMOTE, then clean the result by removing Tomek links.
from imblearn.combine import SMOTETomek
smt = SMOTETomek(random_state=42)
X_smt, y_smt = smt.fit_resample(X, y)
Class Weights
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

# Method 1: 'balanced' parameter — scikit-learn derives the weights from
# y_train as n_samples / (n_classes * np.bincount(y)).
model = LogisticRegression(class_weight='balanced', random_state=42)
model.fit(X_train, y_train)

# Method 2: Custom class weights
# Inversely proportional to class frequency: the same formula as 'balanced',
# written out so it can be tweaked. Works for binary and multiclass targets.
# np.bincount expects integer labels 0..K-1; a label value with no samples
# gets a zero count (division by zero), so encode labels contiguously first.
class_counts = np.bincount(y_train)
n_classes = len(class_counts)
class_weights = len(y_train) / (n_classes * class_counts)
weight_dict = dict(enumerate(class_weights))  # {class_label: weight}
model = LogisticRegression(class_weight=weight_dict, random_state=42)
model.fit(X_train, y_train)

# Method 3: Sample weights in fit
# Weight 10 for minority-class rows, 1 otherwise (tune to your costs).
# np.asarray makes the comparison element-wise even if y_train is a list.
sample_weights = np.where(np.asarray(y_train) == 1, 10, 1)
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train, sample_weight=sample_weights)

# For Random Forest: 'balanced_subsample' recomputes the balanced weights on
# each tree's bootstrap sample rather than once on the full training set.
rf = RandomForestClassifier(
    class_weight='balanced_subsample',
    random_state=42,
)
rf.fit(X_train, y_train)
Evaluation Metrics
import matplotlib.pyplot as plt  # `plt` is used below and was never imported
from sklearn.metrics import (
    classification_report, confusion_matrix,
    precision_recall_curve, roc_curve, auc,
    f1_score, roc_auc_score, average_precision_score,
    matthews_corrcoef, balanced_accuracy_score
)

# Avoid accuracy for imbalanced data: a model that always predicts the
# majority class can still score a high accuracy.
y_pred = model.predict(X_test)               # hard class labels
y_prob = model.predict_proba(X_test)[:, 1]   # score for label 1 (assumes classes_ == [0, 1])

# Key metrics (per-class precision / recall / F1)
print(classification_report(y_test, y_pred))

# F1 Score (harmonic mean of precision and recall)
f1 = f1_score(y_test, y_pred)

# AUROC (Area Under ROC Curve)
auroc = roc_auc_score(y_test, y_prob)

# AUPRC (Area Under Precision-Recall Curve), summarized as average precision.
# AP is a step-wise sum, not the trapezoidal `auc(recall, precision)`, which
# the scikit-learn docs warn can be overly optimistic.
auprc = average_precision_score(y_test, y_prob)

# Matthews Correlation Coefficient
mcc = matthews_corrcoef(y_test, y_pred)

# Balanced Accuracy (mean of per-class recall)
bal_acc = balanced_accuracy_score(y_test, y_pred)

# Visualize metrics
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# ROC Curve — a random classifier lies on the diagonal.
fpr, tpr, _ = roc_curve(y_test, y_prob)
axes[0].plot(fpr, tpr, label=f'AUROC = {auroc:.3f}')
axes[0].plot([0, 1], [0, 1], 'k--', label='Chance')
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title('ROC Curve')
axes[0].legend()

# Precision-Recall Curve — a random classifier has precision equal to the
# positive-class prevalence, so that horizontal line is the PR baseline.
precision, recall, _ = precision_recall_curve(y_test, y_prob)
prevalence = np.mean(np.asarray(y_test) == 1)
axes[1].plot(recall, precision, label=f'AUPRC = {auprc:.3f}')
axes[1].axhline(prevalence, color='k', linestyle='--',
                label=f'No skill = {prevalence:.3f}')
axes[1].set_xlabel('Recall')
axes[1].set_ylabel('Precision')
axes[1].set_title('Precision-Recall Curve')
axes[1].legend()

plt.tight_layout()
plt.show()
Threshold Optimization
from sklearn.metrics import precision_recall_curve, roc_curve

precision, recall, thresholds = precision_recall_curve(y_test, y_prob)

# Find threshold that maximizes F1.
# precision/recall have one more entry than `thresholds`: the final point
# (precision=1, recall=0) has no threshold. Drop it so index i of the F1
# array corresponds to thresholds[i].
p, r = precision[:-1], recall[:-1]
f1_scores = 2 * (p * r) / (p + r + 1e-8)  # 1e-8 avoids 0/0 when p = r = 0
optimal_threshold = thresholds[np.argmax(f1_scores)]

# Apply custom threshold (same ">=" convention the curve functions use)
y_pred_custom = (y_prob >= optimal_threshold).astype(int)
print(f"Optimal threshold: {optimal_threshold:.3f}")
print(f"F1 at optimal threshold: {f1_score(y_test, y_pred_custom):.3f}")

# Cost-sensitive threshold (if misclassification costs differ)
# Minimize the total cost  cost_fp * FP + cost_fn * FN  over all candidate
# thresholds. roc_curve gives the FPR/TPR at each threshold; multiplying by
# the class sizes turns the rates back into FP / FN counts.
cost_fp = 1   # Cost of one false positive
cost_fn = 10  # Cost of one false negative (higher for rare class)
y_true = np.asarray(y_test)
n_pos = np.sum(y_true == 1)
n_neg = len(y_true) - n_pos
cost_fpr, cost_tpr, roc_thresholds = roc_curve(
    y_test, y_prob, drop_intermediate=False
)
false_positives = cost_fpr * n_neg
false_negatives = (1 - cost_tpr) * n_pos
total_costs = cost_fp * false_positives + cost_fn * false_negatives
# The first ROC threshold is above every score ("predict nothing positive"),
# which is a legitimate answer if false positives are costly enough.
optimal_cost_threshold = roc_thresholds[np.argmin(total_costs)]
Ensemble Methods for Imbalanced Data
from imblearn.ensemble import (
    BalancedBaggingClassifier,
    BalancedRandomForestClassifier,
    EasyEnsembleClassifier,
    RUSBoostClassifier,
)

# Each ensemble rebalances the classes internally (random undersampling of
# the majority class per member), so no separate resampling step is needed.

# Random forest whose trees are each grown on a class-balanced sample
brf = BalancedRandomForestClassifier(n_estimators=100, random_state=42)
# Bagging where every bootstrap is rebalanced before fitting its estimator
bbc = BalancedBaggingClassifier(n_estimators=10, random_state=42)
# Bag of AdaBoost learners, each trained on a balanced subset
ee = EasyEnsembleClassifier(n_estimators=10, random_state=42)
# AdaBoost with random undersampling at every boosting round
rusboost = RUSBoostClassifier(n_estimators=100, random_state=42)

# Train every ensemble on the same training split, in the order above.
for estimator in (brf, bbc, ee, rusboost):
    estimator.fit(X_train, y_train)
Key Takeaways
- Never use accuracy alone for imbalanced problems
- Use SMOTE or class weights as first approaches
- Optimize decision threshold for your cost function
- Consider ensemble methods like BalancedRandomForest
- For severely imbalanced data, always report AUPRC (average precision). Unlike AUROC, it is not flattered by the large number of easy true negatives.