What Is a Decision Tree?
A decision tree partitions the feature space into axis-aligned rectangular regions by learning a sequence of if-then-else rules from data. The same structure handles classification (each leaf predicts a class) and regression (each leaf predicts a numeric value).
For example, a regression tree that predicts house prices might look like this:

```text
Is MedInc > 3.5?
├── YES → Is HouseAge > 20?
│   ├── YES → Predict: $280k
│   └── NO  → Predict: $320k
└── NO  → Is AveRooms > 4?
    ├── YES → Predict: $190k
    └── NO  → Predict: $150k
```
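The diagram above is nothing more than nested conditionals. A hand-written sketch of the same rules as a Python function (the name `predict_price` is hypothetical; thresholds and leaf values are the illustrative ones from the diagram, not fitted values):

```python
def predict_price(med_inc: float, house_age: float, ave_rooms: float) -> int:
    """Walk the example tree from the root to a leaf; return the leaf's prediction in $k."""
    if med_inc > 3.5:
        # YES branch of the root: the next question is about house age
        return 280 if house_age > 20 else 320
    # NO branch of the root: the next question is about average rooms
    return 190 if ave_rooms > 4 else 150


# One call per root-to-leaf path
print(predict_price(med_inc=5.0, house_age=30, ave_rooms=6))  # 280
print(predict_price(med_inc=5.0, house_age=10, ave_rooms=6))  # 320
print(predict_price(med_inc=2.0, house_age=30, ave_rooms=5))  # 190
print(predict_price(med_inc=2.0, house_age=30, ave_rooms=3))  # 150
```

Each leaf corresponds to one rectangle of feature space; the $280k leaf, for instance, is the region MedInc > 3.5 and HouseAge > 20.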
Splitting Criteria
For Classification — Gini Impurity
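$$G(S) = 1 - \sum_{k=1}^{K} p_k^2$$

where $p_k$ is the fraction of samples in node $S$ that belong to class $k$, and $K$ is the number of classes. Pure node: $G = 0$. Maximally impure (50/50 binary): $G = 0.5$.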
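A minimal NumPy sketch of the Gini formula (the helper name `gini` is our own, not a library function), checked against the two reference values above:

```python
import numpy as np


def gini(labels) -> float:
    """Gini impurity G = 1 - sum_k p_k^2 of a 1-D array of class labels."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0  # convention for an empty node
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size  # class proportions p_k
    return float(1.0 - np.sum(p ** 2))


print(gini([0, 0, 0, 0]))  # pure node: 0.0
print(gini([0, 1, 0, 1]))  # 50/50 binary: 0.5
print(gini([0, 1, 2]))     # three equally frequent classes: ≈ 0.667
```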
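Information Gain (split quality): the parent's impurity minus the size-weighted impurity of the two children $S_L$ and $S_R$:

$$IG = G(S) - \frac{|S_L|}{|S|}\,G(S_L) - \frac{|S_R|}{|S|}\,G(S_R)$$

The same formula applies with entropy $H$ in place of $G$. Strictly, "information gain" names the entropy version; with Gini the quantity is often called the impurity decrease.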
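The gain formula as a small function, assuming a binary split described by a boolean mask (`information_gain` and its `goes_left` argument are illustrative names). The impurity function is a parameter, so Gini or entropy can be plugged in:

```python
import numpy as np


def gini(labels) -> float:
    """Gini impurity of a non-empty 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def information_gain(y, goes_left, impurity=gini) -> float:
    """Parent impurity minus the size-weighted impurity of the two children.

    y: 1-D array of labels; goes_left: boolean mask of the same length
    (True means the sample is sent to the left child).
    """
    y = np.asarray(y)
    goes_left = np.asarray(goes_left, dtype=bool)
    children = (y[goes_left], y[~goes_left])
    n = len(y)
    # Empty children contribute nothing (their weight |S_child| / |S| is 0)
    weighted = sum(len(c) / n * impurity(c) for c in children if len(c) > 0)
    return impurity(y) - weighted


y = [0, 0, 1, 1]
print(information_gain(y, [True, True, False, False]))  # perfect split: 0.5
print(information_gain(y, [True, False, True, False]))  # useless split: 0.0
```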
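For Classification — Entropy

$$H(S) = -\sum_{k=1}^{K} p_k \log_2 p_k \qquad (\text{with } 0 \log_2 0 = 0)$$

Pure node: $H = 0$. Maximally impure (50/50 binary): $H = 1$ bit.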
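A matching sketch for entropy (the helper name `entropy` is our own). It uses the equivalent form $\sum_k p_k \log_2(1/p_k)$, which avoids printing a negative zero for a pure node:

```python
import numpy as np


def entropy(labels) -> float:
    """Shannon entropy H = -sum_k p_k log2 p_k of a 1-D array of class labels, in bits."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    # Only observed classes appear in `counts`, so every p_k > 0 and log2 is finite
    return float(np.sum(p * np.log2(1.0 / p)))


print(entropy([1, 1, 1, 1]))  # pure node: 0.0
print(entropy([0, 1, 0, 1]))  # 50/50 binary: 1.0
```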
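For Regression — MSE / Variance Reduction

$$\text{MSE}(S) = \frac{1}{|S|} \sum_{i \in S} \left(y_i - \bar{y}_S\right)^2$$

where $\bar{y}_S$ is the mean target in node $S$ (the value the node predicts). The best split maximises the variance reduction:

$$\Delta = \text{MSE}(S) - \frac{|S_L|}{|S|}\,\text{MSE}(S_L) - \frac{|S_R|}{|S|}\,\text{MSE}(S_R)$$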
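To make the greedy search concrete, here is a sketch that tries every candidate threshold on a single numeric feature and keeps the one with the largest variance reduction. The function names, the midpoint rule for candidate thresholds, and the toy data are illustrative assumptions, not taken from any library:

```python
import numpy as np


def mse(y) -> float:
    """Mean squared deviation from the node mean (the node's prediction)."""
    y = np.asarray(y, dtype=float)
    return float(np.mean((y - y.mean()) ** 2))


def best_threshold(x, y):
    """Try every split 'x <= t' on one feature; return (t, variance_reduction)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(y)
    parent = mse(y)
    best_t, best_gain = None, 0.0
    values = np.unique(x)  # sorted distinct feature values
    # Candidate thresholds: midpoints between consecutive distinct values
    for t in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= t], y[x > t]
        gain = parent - len(left) / n * mse(left) - len(right) / n * mse(right)
        if gain > best_gain:
            best_t, best_gain = float(t), gain
    return best_t, best_gain


# Made-up toy data: the target jumps from about 1 to about 5 between x=3 and x=4
x = [1, 2, 3, 4, 5, 6]
y = [1.0, 1.2, 0.8, 5.0, 5.1, 4.9]
print(best_threshold(x, y))  # best threshold is 3.5, between the two clusters
```

A full tree builder repeats this scan for every feature at every node and recurses into the two children.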
Worked Example
```text
Data: 10 samples, target = loan default (4 default, 6 non-default)
Candidate split: Income > $50k?
  YES (left):  7 samples (1 default, 6 non-default)
  NO  (right): 3 samples (3 default, 0 non-default)

Parent Gini      = 1 - (4/10)² - (6/10)² = 1 - 0.16 - 0.36 = 0.480
Left Gini        = 1 - (1/7)² - (6/7)² = 1 - 0.020 - 0.735 = 12/49 ≈ 0.245
Right Gini       = 1 - (3/3)² - (0/3)² = 1 - 1.000 - 0.000 = 0.000
Weighted Gini    = (7/10)(12/49) + (3/10)(0) = 6/35 ≈ 0.171
Information Gain = 0.480 - 0.171 = 0.309   ← choose the split with the highest IG
```
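A quick check of the arithmetic above using exact fractions, so that no rounding creeps in between steps (the helper name `gini_from_counts` is our own):

```python
from fractions import Fraction


def gini_from_counts(*counts: int) -> Fraction:
    """Exact Gini impurity 1 - sum_k (n_k / n)^2 from per-class counts."""
    n = sum(counts)
    return 1 - sum(Fraction(c, n) ** 2 for c in counts)


parent = gini_from_counts(4, 6)   # 4 defaults, 6 non-defaults
left = gini_from_counts(1, 6)     # Income > $50k: 1 default, 6 non-defaults
right = gini_from_counts(3, 0)    # Income <= $50k: 3 defaults, 0 non-defaults
weighted = Fraction(7, 10) * left + Fraction(3, 10) * right
gain = parent - weighted

for name, value in [("parent", parent), ("left", left), ("right", right),
                    ("weighted", weighted), ("gain", gain)]:
    print(f"{name:>8}: {value} ≈ {float(value):.3f}")
```

This prints 0.480, 0.245, 0.000, 0.171 and 0.309 for the five quantities (exactly 12/25, 12/49, 0, 6/35 and 54/175).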
Complete Python Implementation
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                          plot_tree, export_text)
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
import warnings; warnings.filterwarnings("ignore")

# ── Classification ────────────────────────────────────────────────────
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

# Fully grown tree (overfits)
tree_full = DecisionTreeClassifier(criterion="gini", random_state=42)
tree_full.fit(X_train, y_train)

# Pre-pruned tree: growth is stopped early by these constraints
tree_pruned = DecisionTreeClassifier(
    criterion="gini",
    max_depth=3,           # limit depth
    min_samples_split=10,  # min samples needed to split a node
    min_samples_leaf=5,    # min samples in each leaf
    random_state=42,
)
tree_pruned.fit(X_train, y_train)

print(f"Full tree depth   : {tree_full.get_depth()}")
print(f"Pruned tree depth : {tree_pruned.get_depth()}")
print(f"Full Test Acc     : {accuracy_score(y_test, tree_full.predict(X_test)):.4f}")
print(f"Pruned Test Acc   : {accuracy_score(y_test, tree_pruned.predict(X_test)):.4f}")

# 5-fold CV accuracy on the whole dataset (less dependent on one particular split)
cv_full = cross_val_score(tree_full, X, y, cv=5).mean()
cv_pruned = cross_val_score(tree_pruned, X, y, cv=5).mean()
print(f"Full CV Acc       : {cv_full:.4f}")
print(f"Pruned CV Acc     : {cv_pruned:.4f}")

# Visualise tree
plt.figure(figsize=(14, 6))
plot_tree(tree_pruned, feature_names=iris.feature_names,
          class_names=iris.target_names, filled=True, fontsize=10)
plt.title(f"Pruned Decision Tree (depth={tree_pruned.get_depth()}), Iris Dataset")
plt.tight_layout(); plt.show()

# Text representation
print(export_text(tree_pruned, feature_names=list(iris.feature_names)))
```
```python
# ── Cost-Complexity Pruning ───────────────────────────────────────────
cancer = load_breast_cancer()
X2, y2 = cancer.data, cancer.target
X_train2, X_test2, y_train2, y_test2 = train_test_split(
    X2, y2, test_size=0.2, random_state=42)

path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train2, y_train2)
alphas = path.ccp_alphas[:-1]  # drop the last alpha, which prunes down to a single root node

train_scores, test_scores = [], []
for a in alphas:
    dt = DecisionTreeClassifier(ccp_alpha=a, random_state=42)
    dt.fit(X_train2, y_train2)
    train_scores.append(dt.score(X_train2, y_train2))
    test_scores.append(dt.score(X_test2, y_test2))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.plot(alphas, train_scores, "o-", label="Train", color="#3b82f6")
ax1.plot(alphas, test_scores, "o-", label="Test", color="#ef4444")
ax1.set_xlabel("ccp_alpha"); ax1.set_ylabel("Accuracy")
ax1.set_title("Accuracy vs Alpha (Cost-Complexity Pruning)")
ax1.legend(); ax1.grid(True, alpha=0.3)

# Best alpha: choose by 5-fold CV on the training set only. Choosing it by test
# accuracy would leak the test set into model selection and inflate the test score.
cv_means = [cross_val_score(DecisionTreeClassifier(ccp_alpha=a, random_state=42),
                            X_train2, y_train2, cv=5).mean() for a in alphas]
best_alpha = alphas[int(np.argmax(cv_means))]
best_tree = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
best_tree.fit(X_train2, y_train2)

ax2.bar(["Train", "Test", "CV"],
        [best_tree.score(X_train2, y_train2),
         best_tree.score(X_test2, y_test2),
         cross_val_score(best_tree, X2, y2, cv=5).mean()],
        color=["#3b82f6", "#10b981", "#f59e0b"], edgecolor="white")
ax2.set_title(f"Best Tree (alpha={best_alpha:.5f}, depth={best_tree.get_depth()})")
ax2.set_ylim(0.8, 1.0); ax2.grid(True, alpha=0.3, axis="y")
plt.tight_layout(); plt.show()
```
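The listing imports `DecisionTreeRegressor` but never uses it. A minimal regression sketch follows; the noisy sine-wave data and the depth values are arbitrary choices for illustration, and the estimator's default criterion (squared error) corresponds to the MSE / variance-reduction rule described earlier:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D data (made up for illustration): a noisy sine wave
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 5, size=200)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)

results = {}
for depth in (2, 4, None):  # None: expand until leaves are pure or too small to split
    reg = DecisionTreeRegressor(max_depth=depth, random_state=0)
    reg.fit(X_train, y_train)
    results[depth] = (reg.score(X_train, y_train),   # train R²
                      reg.score(X_test, y_test),     # test R²
                      reg.get_n_leaves())
    train_r2, test_r2, leaves = results[depth]
    print(f"max_depth={depth}: leaves={leaves}, "
          f"train R²={train_r2:.3f}, test R²={test_r2:.3f}")
```

The unconstrained tree should reach a training R² of 1.0 (each distinct training point ends up in its own leaf) while its test R² lags behind: the same overfitting pattern as the classification example.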
Hyperparameter Reference
| Parameter | Controls | Recommendation |
|---|---|---|
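| `max_depth` | Maximum tree depth | Start at 3–5; tune via CV |
| `min_samples_split` | Min samples required to attempt a split | 2–20 (default 2) |
| `min_samples_leaf` | Min samples in each leaf (terminal node) | 1–10 (default 1) |
| `max_features` | Features considered at each split | `None` (all features, the default) for a single tree; `"sqrt"` or `"log2"` add randomness and are mainly useful in ensembles |
| `ccp_alpha` | Cost-complexity pruning strength | Tune via `cost_complexity_pruning_path` plus CV |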
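One way to "tune via CV" is scikit-learn's `GridSearchCV`. The sketch below searches a grid drawn from the ranges in the table; the specific grid values, the Iris dataset and the 5-fold setting are illustrative assumptions, not recommendations:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

param_grid = {
    "max_depth": [2, 3, 4, 5, None],
    "min_samples_split": [2, 10, 20],
    "min_samples_leaf": [1, 5, 10],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42),
                      param_grid, cv=5, scoring="accuracy")
search.fit(X_train, y_train)  # cross-validation runs on the training split only

print("Best params :", search.best_params_)
print(f"Best CV acc : {search.best_score_:.4f}")
print(f"Test acc    : {search.score(X_test, y_test):.4f}")  # refit best model, scored once
```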
Key Takeaways
- Trees split on the feature/threshold that maximises information gain (Gini/entropy)
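- Fully grown trees overfit; constrain growth with `max_depth` or `min_samples_leaf`, or prune afterwards with `ccp_alpha`
- Cost-complexity pruning yields a nested sequence of subtrees, one per `ccp_alpha`; pick `ccp_alpha` by cross-validation, not by test accuracy
- Single trees are interpretable but high-variance (small changes in the data can change the learned splits); ensembles such as Random Forest usually generalise better
- Trees need no feature scaling and model non-linear boundaries naturally; in scikit-learn, categorical features must be encoded numerically first, and native missing-value (`NaN`) support in the tree estimators arrived in version 1.3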
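To check the Random Forest takeaway on a concrete dataset, one can compare cross-validated accuracy of a single unpruned tree and a forest. This is a sketch under arbitrary settings (breast-cancer data, 100 trees, 5 folds); the size of the gap, or whether there is one, depends on the data:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

models = {
    "Single tree": DecisionTreeClassifier(random_state=42),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
}

scores = {}
for name, model in models.items():
    scores[name] = cross_val_score(model, X, y, cv=5)
    # Fold-to-fold spread (std) is only a rough proxy for the model's variance
    print(f"{name:<14} mean acc={scores[name].mean():.4f}  std={scores[name].std():.4f}")
```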