Decision Trees — Theory, Math & Python


What Is a Decision Tree?

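A decision tree partitions the feature space into axis-aligned rectangular regions by learning a sequence of if-then-else rules from data. Each internal node tests a single feature against a threshold, and each leaf stores a prediction. It handles both classification and regression. The example below sketches a small regression tree for house prices.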

Is MedInc > 3.5?
    ├── YES → Is HouseAge > 20?
    │             ├── YES → Predict: $280k
    │             └── NO  → Predict: $320k
    └── NO  → Is AveRooms > 4?
                  ├── YES → Predict: $190k
                  └── NO  → Predict: $150k
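
The tree above is equivalent to a nested if/else. A minimal sketch, where the function name `predict_price` and its argument names are illustrative and not part of any library:

```python
def predict_price(med_inc: float, house_age: float, ave_rooms: float) -> int:
    """Return the leaf prediction (in dollars) of the example tree above."""
    if med_inc > 3.5:
        # YES branch of the root: the next question is about house age
        if house_age > 20:
            return 280_000
        return 320_000
    # NO branch of the root: the next question is about average rooms
    if ave_rooms > 4:
        return 190_000
    return 150_000


print(predict_price(med_inc=4.2, house_age=30, ave_rooms=5.0))  # 280000
print(predict_price(med_inc=2.0, house_age=30, ave_rooms=3.5))  # 150000
```

Each sample follows exactly one root-to-leaf path, which is why the leaves correspond to non-overlapping regions of the feature space.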

Splitting Criteria

For Classification — Gini Impurity

$$G(t) = 1 - \sum_{k=1}^{K} p_k^2$$

Here $p_k$ is the fraction of samples in node $t$ that belong to class $k$. Pure node: $G = 0$. Maximally impure (50/50 binary): $G = 0.5$.
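
The Gini formula can be checked with a few lines of standard-library Python. This is a sketch; the helper name `gini` is illustrative:

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum_k p_k^2."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


print(gini(["a", "a", "a", "a"]))  # pure node -> 0.0
print(gini(["a", "a", "b", "b"]))  # 50/50 binary -> 0.5
```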

Information Gain (split quality):

$$IG = G(t) - \frac{n_L}{n} G(t_L) - \frac{n_R}{n} G(t_R)$$

For Classification — Entropy

$$H(t) = -\sum_{k=1}^{K} p_k \log_2(p_k)$$

$$IG_{entropy} = H(t) - \frac{n_L}{n} H(t_L) - \frac{n_R}{n} H(t_R)$$
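
As a rough illustration of the entropy criterion, the sketch below computes $H(t)$ and the entropy-based gain for a binary split. The function names `entropy` and `information_gain` are illustrative, and classes with zero count are skipped, which matches the usual convention that $0 \log_2 0 = 0$:

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    # Counter only holds classes that occur, so p > 0 and log2(p) is defined
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted


parent = ["a", "a", "b", "b"]
print(entropy(parent))                                        # 1.0 bit for a 50/50 node
print(information_gain(parent, ["a", "a"], ["b", "b"]))       # perfect split recovers the full 1.0
print(information_gain(parent, ["a", "b"], ["a", "b"]))       # uninformative split gains 0.0
```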

For Regression — MSE / Variance Reduction

$$\text{Impurity}(t) = \frac{1}{n_t} \sum_{i \in t} (y_i - \bar{y}_t)^2$$
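
To make the regression criterion concrete, here is a small sketch that scans candidate thresholds on one feature and keeps the one with the largest impurity reduction. It assumes the reduction is computed with the same size-weighted form as the classification gain above and that candidate thresholds are midpoints between sorted unique feature values. The helper names and toy data are illustrative:

```python
def mse_impurity(values):
    """Mean squared deviation from the node mean (the node's variance)."""
    n = len(values)
    if n == 0:
        return 0.0
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n


def best_threshold(x, y):
    """Return (threshold, reduction) for the best single split on feature x."""
    n = len(y)
    parent = mse_impurity(y)
    unique_x = sorted(set(x))
    best = (None, 0.0)
    for lo, hi in zip(unique_x, unique_x[1:]):
        thr = (lo + hi) / 2  # midpoint between neighbouring feature values
        left = [yi for xi, yi in zip(x, y) if xi <= thr]
        right = [yi for xi, yi in zip(x, y) if xi > thr]
        reduction = (parent
                     - (len(left) / n) * mse_impurity(left)
                     - (len(right) / n) * mse_impurity(right))
        if reduction > best[1]:
            best = (thr, reduction)
    return best


# Toy data: two clearly separated groups of targets
x = [1, 2, 3, 10, 11, 12]
y = [5.0, 6.0, 5.5, 20.0, 21.0, 19.5]
threshold, reduction = best_threshold(x, y)
print(threshold)  # 6.5, the midpoint between the two groups
```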


Worked Example

Data: 10 samples, predict loan default

Feature: Income > $50k?   YES: 7 samples (1 default, 6 no)
                           NO:  3 samples (3 default, 0 no)

Parent Gini = 1 - (4/10)² - (6/10)² = 1 - 0.16 - 0.36 = 0.48

Left Gini  = 1 - (1/7)² - (6/7)² = 1 - 0.020 - 0.735 = 0.245
Right Gini = 1 - (3/3)² - (0/3)² = 1 - 1.0 - 0.0    = 0.000

Weighted Gini = (7/10)(0.245) + (3/10)(0.000) = 0.171  (using the unrounded left Gini, 12/49)

Information Gain = 0.48 - 0.171 = 0.309  ← choose the split with highest IG
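
The arithmetic above can be reproduced directly from the class counts. The sketch below uses an illustrative helper, `gini_from_counts`. The last digit may differ from hand calculation when intermediate values are rounded:

```python
def gini_from_counts(counts):
    """Gini impurity from per-class sample counts, e.g. [defaults, non_defaults]."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)


parent = gini_from_counts([4, 6])   # all 10 samples
left = gini_from_counts([1, 6])     # Income > $50k: YES
right = gini_from_counts([3, 0])    # Income > $50k: NO

weighted = (7 / 10) * left + (3 / 10) * right
gain = parent - weighted

print(f"parent={parent:.3f} left={left:.3f} right={right:.3f}")
print(f"weighted={weighted:.3f} gain={gain:.3f}")
```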

Complete Python Implementation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.tree import (DecisionTreeClassifier, DecisionTreeRegressor,
                          plot_tree, export_text)
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
import warnings; warnings.filterwarnings("ignore")

# ── Classification ────────────────────────────────────────────────────
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

# Fully grown tree (no depth limit, so it is prone to overfitting)
tree_full = DecisionTreeClassifier(criterion="gini", random_state=42)
tree_full.fit(X_train, y_train)

# Pre-pruned tree: growth is stopped early by the limits below
tree_pruned = DecisionTreeClassifier(
    criterion="gini",
    max_depth=3,           # limit depth
    min_samples_split=10,  # min samples to split a node
    min_samples_leaf=5,    # min samples in a leaf
    random_state=42,
)
tree_pruned.fit(X_train, y_train)

print(f"Full tree depth   : {tree_full.get_depth()}")
print(f"Pruned tree depth : {tree_pruned.get_depth()}")
print(f"Full   Test Acc   : {accuracy_score(y_test, tree_full.predict(X_test)):.4f}")
print(f"Pruned Test Acc   : {accuracy_score(y_test, tree_pruned.predict(X_test)):.4f}")

# CV scores
cv_full   = cross_val_score(tree_full,   X, y, cv=5).mean()
cv_pruned = cross_val_score(tree_pruned, X, y, cv=5).mean()
print(f"Full   CV Acc : {cv_full:.4f}")
print(f"Pruned CV Acc : {cv_pruned:.4f}")

# Visualise tree
plt.figure(figsize=(14, 6))
plot_tree(tree_pruned, feature_names=iris.feature_names,
          class_names=iris.target_names, filled=True, fontsize=10)
plt.title("Pruned Decision Tree (depth=3) — Iris Dataset")
plt.tight_layout(); plt.show()

# Text representation
print(export_text(tree_pruned, feature_names=list(iris.feature_names)))

# ── Cost-Complexity Pruning ───────────────────────────────────────────
cancer = load_breast_cancer()
X2, y2 = cancer.data, cancer.target
X_train2, X_test2, y_train2, y_test2 = train_test_split(
    X2, y2, test_size=0.2, random_state=42)

path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train2, y_train2)
alphas = path.ccp_alphas[:-1]   # remove last (trivial tree)

train_scores, test_scores = [], []
for a in alphas:
    dt = DecisionTreeClassifier(ccp_alpha=a, random_state=42)
    dt.fit(X_train2, y_train2)
    train_scores.append(dt.score(X_train2, y_train2))
    test_scores.append(dt.score(X_test2,  y_test2))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(alphas, train_scores, "o-", label="Train", color="#3b82f6")
ax1.plot(alphas, test_scores,  "o-", label="Test",  color="#ef4444")
ax1.set_xlabel("ccp_alpha"); ax1.set_ylabel("Accuracy")
ax1.set_title("Accuracy vs Alpha (Cost-Complexity Pruning)")
ax1.legend(); ax1.grid(True, alpha=0.3)

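# Best alpha: select by 5-fold CV on the training set only, so the test set
# is not used for model selection and remains a fair final check
cv_scores = [
    cross_val_score(DecisionTreeClassifier(ccp_alpha=a, random_state=42),
                    X_train2, y_train2, cv=5).mean()
    for a in alphas
]
best_alpha = alphas[np.argmax(cv_scores)]
best_tree = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
best_tree.fit(X_train2, y_train2)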
ax2.bar(["Train", "Test", "CV"],
        [best_tree.score(X_train2, y_train2),
         best_tree.score(X_test2,  y_test2),
         cross_val_score(best_tree, X2, y2, cv=5).mean()],
        color=["#3b82f6","#10b981","#f59e0b"], edgecolor="white")
ax2.set_title(f"Best Tree (alpha={best_alpha:.5f}, depth={best_tree.get_depth()})")
ax2.set_ylim(0.8, 1.0); ax2.grid(True, alpha=0.3, axis="y")

plt.tight_layout(); plt.show()
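
The script above imports DecisionTreeRegressor but only exercises the classifier. As a hedged companion sketch, the snippet below fits regression trees of different depths to synthetic data (a noisy sine curve invented here for illustration, not a dataset from the lesson) and reports train and test R²:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D regression problem: y = sin(x) + Gaussian noise (illustrative only)
rng = np.random.default_rng(0)
X = rng.uniform(0, 6, size=(300, 1))
y = np.sin(X[:, 0]) + rng.normal(scale=0.2, size=300)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)

scores = {}
for depth in (2, 4, None):  # None lets the tree grow until leaves are pure
    reg = DecisionTreeRegressor(max_depth=depth, random_state=0)
    reg.fit(X_tr, y_tr)
    scores[depth] = (reg.score(X_tr, y_tr), reg.score(X_te, y_te))
    print(f"max_depth={depth}: train R2={scores[depth][0]:.3f}, "
          f"test R2={scores[depth][1]:.3f}")
```

The unrestricted tree should fit the training data almost perfectly, while its test score shows how much of that fit is noise.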

Hyperparameter Reference

Parameter         | Controls                       | Recommendation
max_depth         | Tree depth                     | Start at 3–5; tune via CV
min_samples_split | Min samples to attempt a split | 2–20
min_samples_leaf  | Min samples in terminal node   | 1–10
max_features      | Features considered per split  | sqrt or log2
ccp_alpha         | Cost-complexity pruning        | Tune via pruning path
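
One common way to tune these parameters via CV is a grid search. The sketch below uses scikit-learn's GridSearchCV on the Iris data. The grid values are illustrative picks from the ranges in the table, not tuned recommendations:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Illustrative grid drawn from the ranges listed in the table
param_grid = {
    "max_depth": [2, 3, 4, 5],
    "min_samples_split": [2, 10, 20],
    "min_samples_leaf": [1, 5, 10],
}

search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,                 # 5-fold cross-validation for every combination
    scoring="accuracy",
)
search.fit(X, y)

print("Best params :", search.best_params_)
print(f"Best CV acc : {search.best_score_:.4f}")
```

For an unbiased estimate of the tuned model's accuracy, run the search on a training split and score `search.best_estimator_` on held-out data.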

Key Takeaways

  1. Trees split on the feature/threshold that maximises information gain (Gini/entropy)
  2. Fully grown trees overfit — prune with max_depth, min_samples_leaf, or ccp_alpha
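  3. Cost-complexity pruning produces a nested sequence of subtrees indexed by ccp_alpha; pick alpha by cross-validation instead of fixing the depth by hand
  4. Single trees are interpretable but high variance; ensembles such as Random Forest average many trees to reduce that variance and are usually the safer choice in production
  5. Decision trees capture non-linear boundaries and need no feature scaling; in scikit-learn, categorical features still require numeric encoding, and native missing-value handling is only available in recent releases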
