Decision Trees — Complete Guide with Visualizations

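Decision trees make predictions by learning simple if-then-else rules from data and arranging them like a flowchart. Each internal node tests one feature, each branch is an outcome of that test, and each leaf gives a prediction.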


How Decision Trees Work

Example: Should I Play Tennis?

Outlook    Temp   Humidity   Wind    Play?
─────────────────────────────────────────────
Sunny      Hot    High       Weak    No
Sunny      Hot    High       Strong  No
Overcast   Hot    High       Weak    Yes
Rain       Mild   High       Weak    Yes
Rain       Cool   Normal     Weak    Yes
Rain       Cool   Normal     Strong  No
Overcast   Cool   Normal     Strong  Yes
Sunny      Mild   High       Weak    No
Sunny      Cool   Normal     Weak    Yes
Rain       Mild   Normal     Weak    Yes
Sunny      Mild   Normal     Strong  Yes
Overcast   Mild   High       Strong  Yes
Overcast   Hot    Normal     Weak    Yes
Rain       Mild   High       Strong  No

Learned tree (note that Temp is never tested):
Outlook?
├─ Sunny → Humidity?
│         ├─ High → No
│         └─ Normal → Yes
├─ Overcast → Yes
└─ Rain → Wind?
          ├─ Weak → Yes
          └─ Strong → No
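The flowchart maps directly onto nested if/else statements. The sketch below hand-codes the tree above (it does not learn it) and checks it against the 14 rows of the table. The names `DATA` and `predict` are our own illustration, not part of any library.

```python
# Rows of the Play Tennis table: (Outlook, Temp, Humidity, Wind, Play?)
DATA = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def predict(outlook: str, temp: str, humidity: str, wind: str) -> str:
    """The tree above written as if-then-else rules (temp is accepted but never tested)."""
    if outlook == "Sunny":
        return "Yes" if humidity == "Normal" else "No"
    if outlook == "Overcast":
        return "Yes"
    # Remaining branch: outlook == "Rain"
    return "Yes" if wind == "Weak" else "No"


correct = sum(predict(*row[:4]) == row[4] for row in DATA)
print(f"{correct}/{len(DATA)} rows classified correctly")
```

Each root-to-leaf path is one rule, for example "if Outlook is Rain and Wind is Strong, then No".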

Splitting Criteria

Gini Impurity

Gini = 1 - Σ(pᵢ²)

where pᵢ is the fraction of samples in the node that belong to class i

Pure node (all same class):
Gini = 1 - (1²) = 0

Maximally impure binary node (50/50):
Gini = 1 - (0.5² + 0.5²) = 0.5

Most impure node with K classes (uniform, every pᵢ = 1/K):
Gini = 1 - K × (1/K)² = 1 - 1/K
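A minimal sketch of the Gini formula, computing class proportions from a list of labels. The function name `gini` is our own; it assumes a non-empty list of labels.

```python
from collections import Counter


def gini(labels) -> float:
    """Gini impurity of a node: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


print(gini(["Yes", "Yes", "Yes", "Yes"]))  # pure node -> 0.0
print(gini(["Yes", "Yes", "No", "No"]))    # 50/50 binary node -> 0.5
print(gini(["a", "b", "c", "d"]))          # uniform over K = 4 classes -> 1 - 1/4 = 0.75
```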

Information Gain (Entropy)

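Entropy = -Σ pᵢ log₂(pᵢ)   (with 0 × log₂(0) taken as 0)

Pure node: Entropy = 0
50/50 binary node: Entropy = 1 (the maximum for two classes)
Most impure node with K classes (uniform): Entropy = log₂(K)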

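Information Gain = Entropy(parent) - Σⱼ (nⱼ / n) × Entropy(childⱼ)

where n is the number of samples in the parent and nⱼ is the number sent to child j.

At each node, split on the feature with the HIGHEST information gain.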
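To connect the formula to the tennis example, the sketch below computes the information gain of each feature at the root, using the same 14 rows. The helpers `entropy` and `information_gain` are hypothetical names for this illustration; `Counter` only yields non-zero counts, so log₂(0) never occurs.

```python
from collections import Counter, defaultdict
from math import log2

COLUMNS = ["Outlook", "Temp", "Humidity", "Wind"]
# Rows of the Play Tennis table: (Outlook, Temp, Humidity, Wind, Play?)
DATA = [
    ("Sunny", "Hot", "High", "Weak", "No"),
    ("Sunny", "Hot", "High", "Strong", "No"),
    ("Overcast", "Hot", "High", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Cool", "Normal", "Strong", "No"),
    ("Overcast", "Cool", "Normal", "Strong", "Yes"),
    ("Sunny", "Mild", "High", "Weak", "No"),
    ("Sunny", "Cool", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "Normal", "Weak", "Yes"),
    ("Sunny", "Mild", "Normal", "Strong", "Yes"),
    ("Overcast", "Mild", "High", "Strong", "Yes"),
    ("Overcast", "Hot", "Normal", "Weak", "Yes"),
    ("Rain", "Mild", "High", "Strong", "No"),
]


def entropy(labels) -> float:
    """Entropy in bits of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(rows, col: int) -> float:
    """Parent entropy minus the size-weighted entropy of the children from splitting on column col."""
    parent = [r[-1] for r in rows]
    children = defaultdict(list)
    for r in rows:
        children[r[col]].append(r[-1])
    weighted = sum(len(c) / len(rows) * entropy(c) for c in children.values())
    return entropy(parent) - weighted


gains = {name: information_gain(DATA, i) for i, name in enumerate(COLUMNS)}
for name, g in sorted(gains.items(), key=lambda kv: -kv[1]):
    print(f"{name:9s} {g:.3f}")
print("Split on:", max(gains, key=gains.get))
```

On this table Outlook has the highest gain (about 0.247, versus about 0.152 for Humidity), which is why it sits at the root of the tree. Repeating the same computation inside each branch yields the Humidity and Wind splits below it.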

Python Implementation

from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42, stratify=iris.target)

# Fit tree
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X_train, y_train)

# Evaluate
print(f"Accuracy: {tree.score(X_test, y_test):.3f}")

# Print the learned rules as indented text
print(export_text(tree, feature_names=iris.feature_names))

# Feature importance
for name, imp in zip(iris.feature_names, tree.feature_importances_):
    print(f"{name}: {imp:.3f}")
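scikit-learn documents `feature_importances_` as the normalized total reduction of the splitting criterion contributed by each feature (the "Gini importance"). The sketch below recomputes it by walking the fitted tree's low-level `tree_` arrays, so the numbers above are less of a black box. It refits on the full iris data for simplicity; `-1` in `children_left` marks a leaf.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

t = tree.tree_
importances = np.zeros(iris.data.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: no split, so no impurity decrease
        continue
    # Weighted impurity decrease produced by this node's split
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    importances[t.feature[node]] += decrease
importances /= importances.sum()  # normalize so importances sum to 1

for name, manual, sk in zip(iris.feature_names, importances, tree.feature_importances_):
    print(f"{name}: manual={manual:.3f} sklearn={sk:.3f}")
```

Because these importances come from training-set impurity decreases, they can favor features with many possible split points; they describe the fitted tree rather than the data.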

Pruning

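Problem: an unrestricted tree keeps splitting until its leaves are pure, so it memorizes noise in the training data and overfits.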

Pre-pruning (stop early):
├─ max_depth: Limit the depth of the tree
├─ min_samples_split: Minimum samples a node needs before it can be split
├─ min_samples_leaf: Minimum samples required in each leaf
└─ max_features: Consider only a subset of features at each split
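A minimal pre-pruning sketch with the scikit-learn parameters listed above. The choice of the bundled breast cancer dataset and the specific values (`max_depth=3`, `min_samples_split=10`, `min_samples_leaf=5`) are our own assumptions for illustration, not tuned settings.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

# No limits: grows until leaves are pure (or cannot be split further)
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Pre-pruned: growth stops early because of the limits below
pre_pruned = DecisionTreeClassifier(
    max_depth=3, min_samples_split=10, min_samples_leaf=5, random_state=0
).fit(X_train, y_train)

for label, model in [("unrestricted", full), ("pre-pruned", pre_pruned)]:
    print(f"{label:12s} depth={model.get_depth()} leaves={model.get_n_leaves()} "
          f"train={model.score(X_train, y_train):.3f} test={model.score(X_test, y_test):.3f}")
```

The pre-pruned tree is guaranteed to be shallower; whether its test accuracy is also higher depends on the dataset and split, so these limits are usually tuned with cross-validation.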

Post-pruning (grow then cut):
├─ Cost-complexity pruning (α penalizes the number of leaves; ccp_alpha in scikit-learn)
├─ Reduced error pruning
└─ Use a validation set to decide how much to prune
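The sketch below combines two of the ideas above: scikit-learn's cost-complexity pruning (`cost_complexity_pruning_path` and `ccp_alpha`), with α chosen on a held-out validation set. It is not reduced error pruning. The dataset, the three-way split sizes, and the tie-breaking rule are our own assumptions.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
# Three-way split: train (grow trees), validation (pick alpha), test (final check)
X_rest, X_test, y_rest, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y)
X_train, X_val, y_train, y_val = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0, stratify=y_rest)

# Grow the full tree, then get the alphas at which its subtrees would be pruned away
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
path = full.cost_complexity_pruning_path(X_train, y_train)
candidate_alphas = path.ccp_alphas[:-1]  # the last alpha prunes everything down to the root

best_alpha, best_val_acc = 0.0, -1.0
for alpha in candidate_alphas:
    alpha = max(float(alpha), 0.0)  # guard against tiny negative floating-point round-off
    model = DecisionTreeClassifier(ccp_alpha=alpha, random_state=0).fit(X_train, y_train)
    val_acc = model.score(X_val, y_val)
    if val_acc >= best_val_acc:  # ">=" prefers the larger alpha (smaller tree) on ties
        best_alpha, best_val_acc = alpha, val_acc

pruned = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=0).fit(X_train, y_train)
for label, model in [("full", full), ("pruned", pruned)]:
    print(f"{label:6s} leaves={model.get_n_leaves():3d} "
          f"val={model.score(X_val, y_val):.3f} test={model.score(X_test, y_test):.3f}")
print(f"chosen ccp_alpha = {best_alpha:.5f}")
```

The test set is touched only once, after α is fixed, so its score is not biased by the selection. With little data, cross-validation over the same candidate alphas is a common alternative to a single validation split.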

Typical pattern (illustrative numbers):

Unpruned tree:        Pruned tree:
Training: 100%        Training: 95%
Test: 70%             Test: 93%
(overfitting)         (better generalization)

Key Takeaways

  1. Decision trees are easy to understand and visualize
  2. Splits are chosen by Gini impurity or information gain (entropy)
  3. Pruning reduces overfitting
  4. Decision trees are the building blocks for Random Forests and Gradient Boosting
  5. Feature importance shows which features matter most
  6. Decision trees can handle mixed numerical and categorical features in principle, though scikit-learn's DecisionTreeClassifier expects categorical features to be encoded as numbers first
  7. Non-parametric: no assumptions about the data distribution
  8. Unstable — small data changes can create very different trees
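To see the instability in takeaway 8, the sketch below fits unrestricted trees on five bootstrap resamples of the same data (the kind of resampling bagging uses) and compares their structure. The dataset, the number of resamples, and the "signature" used to compare trees are our own choices for illustration.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
rng = np.random.default_rng(0)

signatures = []
for i in range(5):
    idx = rng.choice(len(X), size=len(X), replace=True)  # bootstrap resample of the rows
    t = DecisionTreeClassifier(random_state=0).fit(X[idx], y[idx]).tree_
    # Structural signature: split feature and rounded threshold of every node
    signatures.append((tuple(t.feature), tuple(np.round(t.threshold, 6))))
    print(f"resample {i}: nodes={t.node_count} root feature={t.feature[0]} "
          f"root threshold={t.threshold[0]:.3f}")

print("distinct tree structures:", len(set(signatures)))
```

Averaging many such trees, as random forests do, smooths out this variance.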
