Decision Trees: Gini, Entropy and Pruning
Decision trees are non-parametric supervised learning models that partition feature space into regions, making predictions based on the majority class (classification) or mean value (regression) within each region. Their intuitive, tree-like structure makes them one of the most interpretable models in machine learning.
How Decision Trees Work
A decision tree recursively partitions the input space by selecting feature thresholds that maximize predictive purity. Each internal node represents a feature test, each branch represents the outcome of that test, and each leaf node represents a prediction.
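The node roles described above can be inspected on a fitted scikit-learn tree. The following is a minimal sketch, assuming the Iris dataset as a stand-in and a shallow tree so the output stays short; it walks the fitted `tree_` arrays and prints whether each node is a feature test or a leaf.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Assumption: Iris is used only as a convenient stand-in dataset.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

t = clf.tree_
for node in range(t.node_count):
    if t.children_left[node] == t.children_right[node]:
        # Leaf: both child pointers are -1, so the node holds a prediction.
        predicted = iris.target_names[t.value[node][0].argmax()]
        print(f"node {node}: leaf, predicts {predicted}")
    else:
        # Internal node: a feature test of the form x[feature] <= threshold.
        name = iris.feature_names[t.feature[node]]
        print(f"node {node}: test {name} <= {t.threshold[node]:.2f}")
```

Samples that satisfy the test go to `children_left[node]`, the rest to `children_right[node]`.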
Recursive Partitioning Process:
- Select the best split — evaluate all features and thresholds
- Partition the data — split into child nodes
- Repeat — recursively split child nodes until stopping criteria are met
- Assign predictions — leaves hold class labels or regression values
```python
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# Assumes X_train, y_train, feature_names and class_names are already defined
# Train a decision tree
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X_train, y_train)

# Visualize
plt.figure(figsize=(12, 8))
plot_tree(tree, feature_names=feature_names, class_names=class_names, filled=True)
plt.show()
```
[Figure: Decision tree structure]
Impurity Measures
The quality of a split is measured by the reduction in impurity it produces. Two impurity measures dominate in practice, Gini impurity and entropy; Information Gain expresses the reduction in entropy achieved by a split.
Gini Impurity
The Gini impurity measures the probability of incorrectly classifying a randomly chosen element if it were labeled at random according to the class distribution of the dataset.
$$Gini(D) = 1 - \sum_{i=1}^{C} p_i^2$$
where $p_i$ is the proportion of class $i$ in dataset $D$, and $C$ is the number of classes.
Properties:
- Ranges from 0 (pure node) to $1 - 1/C$ for $C$ classes (maximum impurity, reached when all classes are equally likely)
- For binary classification: maximum Gini = 0.5
- Computationally efficient (no logarithms required)
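As a minimal sketch of the definition (one minus the sum of squared class proportions), the helper below computes Gini impurity from a list of labels. The function name `gini` and the toy label lists are illustrative, not part of any library.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

print(gini(["a", "a", "a", "a"]))   # pure node -> 0.0
print(gini(["a", "a", "b", "b"]))   # balanced binary node -> 0.5
print(gini(["a", "b", "c"]))        # balanced three-class node, 1 - 1/3
```

The three calls illustrate the listed properties: a pure node scores 0, a balanced binary node scores 0.5, and the maximum grows with the number of classes.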
Entropy
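Entropy measures the expected amount of information (surprise) in the class distribution.
$$H(D) = -\sum_{i=1}^{C} p_i \log_2 p_i$$
where $p_i$ is the proportion of class $i$ in dataset $D$ and $C$ is the number of classes; terms with $p_i = 0$ contribute zero.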
Properties:
- Ranges from 0 (pure node) to $\log_2 C$ for $C$ classes (maximum impurity)
- For binary classification: maximum Entropy = 1.0
- Rooted in Shannon's information theory
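A small sketch of the entropy computation in bits, using only the standard library. The function name `entropy` is illustrative; it sums `p * log2(1/p)` over the classes that actually occur, which is the same quantity as the usual negative-sum form.

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy in bits of the empirical class distribution."""
    n = len(labels)
    # Only observed classes appear in the Counter, so p is never zero.
    return sum((c / n) * math.log2(n / c) for c in Counter(labels).values())

print(entropy(["a", "a", "a", "a"]))   # pure node -> 0.0
print(entropy(["a", "a", "b", "b"]))   # balanced binary node -> 1.0
print(entropy(["a", "b", "c", "d"]))   # four equally likely classes -> 2.0
```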
Information Gain
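Information Gain quantifies the reduction in entropy (or impurity) achieved by splitting on a feature $A$.
$$IG(D, A) = H(D) - \sum_{v=1}^{V} \frac{|D_v|}{|D|} H(D_v)$$
where $H$ denotes entropy, $D_v$ is the subset of $D$ where feature $A$ takes value $v$, and $V$ is the number of distinct values of $A$.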
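To make the definition concrete, here is a minimal sketch that computes Information Gain for a categorical feature: parent entropy minus the size-weighted entropy of each subset. The weather-style toy data and the function names are hypothetical.

```python
import math
from collections import Counter, defaultdict

def entropy(labels):
    n = len(labels)
    return sum((c / n) * math.log2(n / c) for c in Counter(labels).values())

def information_gain(feature_values, labels):
    """Entropy of the parent minus the size-weighted entropy of each subset."""
    subsets = defaultdict(list)
    for value, label in zip(feature_values, labels):
        subsets[value].append(label)
    n = len(labels)
    weighted = sum(len(s) / n * entropy(s) for s in subsets.values())
    return entropy(labels) - weighted

# Hypothetical toy data: does the outlook predict whether we play?
outlook = ["sunny", "sunny", "rain", "rain", "overcast", "overcast"]
play    = ["no",    "no",    "yes",  "no",   "yes",      "yes"]
print(f"IG(outlook) = {information_gain(outlook, play):.3f}")
```

Here the parent has entropy 1 bit (three of each class); only the "rain" subset stays mixed, so the weighted child entropy is 1/3 and the gain is 2/3.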
Gini vs Entropy Comparison
Key Insight: Both measures peak at $p = 0.5$ for binary classification, where entropy reaches 1.0 and Gini reaches 0.5. In practice, Gini and entropy produce very similar trees; the difference in accuracy is typically small, often within 1–2%.
Splitting Criteria: Mathematical Framework
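For a candidate split on feature $j$ with threshold $t$, the data $D$ is divided into $D_L = \{x \in D : x_j \le t\}$ and $D_R = \{x \in D : x_j > t\}$, and the split is scored by the size-weighted impurity of the children:
$$I_{split}(j, t) = \frac{|D_L|}{|D|} I(D_L) + \frac{|D_R|}{|D|} I(D_R)$$
Goal: Minimize $I_{split}(j, t)$, where $I$ is the chosen impurity measure (equivalently, maximize the impurity reduction, which is Information Gain when $I$ is entropy).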
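The search over thresholds can be sketched for a single numeric feature as follows. This is an illustrative implementation, not the one any particular library uses; it assumes midpoints between consecutive distinct values as candidate thresholds (a common choice) and scores each split by the size-weighted Gini impurity of the two children.

```python
import numpy as np

def gini(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_threshold(x, y):
    """Return (threshold, weighted Gini) of the best split x <= t on one feature."""
    values = np.unique(x)
    # Candidate thresholds: midpoints between consecutive distinct values.
    candidates = (values[:-1] + values[1:]) / 2
    best_t, best_score = None, np.inf
    for t in candidates:
        left, right = y[x <= t], y[x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# Hypothetical toy data: two well-separated groups.
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([0, 0, 0, 1, 1, 1])
t, score = best_threshold(x, y)
print(f"best threshold: {t}, weighted Gini: {score}")
```

A full tree builder would repeat this search over every feature at every node and recurse on the two resulting subsets.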
Gain Ratio (C4.5)
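To avoid bias toward multi-valued features, C4.5 uses the Gain Ratio:
$$GainRatio(D, A) = \frac{IG(D, A)}{SplitInfo(D, A)}$$
where:
$$SplitInfo(D, A) = -\sum_{v=1}^{V} \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|}$$
with $D_v$ the subset of $D$ where feature $A$ takes value $v$ and $V$ the number of distinct values of $A$.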
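The effect of the normalization can be seen on a small example. The sketch below, with hypothetical data and function names, compares a row-identifier feature (unique per sample) with a two-valued feature. Both separate the classes perfectly and so have the same Information Gain, but the identifier is penalized by its large split information.

```python
import math
from collections import Counter, defaultdict

def entropy(items):
    n = len(items)
    return sum((c / n) * math.log2(n / c) for c in Counter(items).values())

def gain_ratio(feature_values, labels):
    subsets = defaultdict(list)
    for v, label in zip(feature_values, labels):
        subsets[v].append(label)
    n = len(labels)
    gain = entropy(labels) - sum(len(s) / n * entropy(s) for s in subsets.values())
    split_info = entropy(feature_values)  # entropy of the partition sizes
    return gain / split_info if split_info > 0 else 0.0

labels = ["yes"] * 4 + ["no"] * 4
row_id = list(range(8))            # unique per row: an 8-way split
group = ["a"] * 4 + ["b"] * 4      # a 2-way split that is just as pure
print(f"row_id: {gain_ratio(row_id, labels):.3f}")
print(f"group:  {gain_ratio(group, labels):.3f}")
```

Both features have a gain of 1 bit, but the split information is 3 bits for `row_id` and 1 bit for `group`, so the ratios are 1/3 and 1.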
For Regression Trees
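The splitting criterion minimizes the mean squared error (MSE) within the child nodes:
$$MSE(D) = \frac{1}{|D|} \sum_{i \in D} (y_i - \bar{y}_D)^2$$
where $\bar{y}_D$ is the mean target value in node $D$. The split is chosen to minimize the size-weighted sum $\frac{|D_L|}{|D|} MSE(D_L) + \frac{|D_R|}{|D|} MSE(D_R)$ over the left and right child subsets $D_L$ and $D_R$.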
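A short sketch of the regression criterion on hypothetical data: the target jumps between x = 3 and x = 4, so a threshold in that gap should give the lowest weighted MSE. The helper names are illustrative.

```python
import numpy as np

def mse(y):
    return float(np.mean((y - y.mean()) ** 2)) if len(y) else 0.0

def split_mse(x, y, t):
    """Size-weighted MSE of the two children produced by the split x <= t."""
    left, right = y[x <= t], y[x > t]
    return (len(left) * mse(left) + len(right) * mse(right)) / len(y)

# Hypothetical toy data with a step between x = 3 and x = 4.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([1.0, 1.2, 0.8, 5.0, 5.2, 4.8])
for t in (1.5, 3.5, 5.5):
    print(f"t={t}: weighted MSE = {split_mse(x, y, t):.3f}")
```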
Tree Building Algorithms
ID3 (Iterative Dichotomiser 3)
- Uses Information Gain as splitting criterion
- Handles only categorical features
- No pruning — grows until pure leaves or no more splits
- Prone to overfitting
C4.5 (Successor to ID3)
- Uses Gain Ratio to reduce multi-valued feature bias
- Handles continuous features via thresholding
- Handles missing values through fractional instance weighting
- Includes post-pruning via error-based pruning
CART (Classification and Regression Trees)
- Uses Gini impurity (classification) or MSE (regression)
- Produces binary trees only (2-way splits)
- Supports both classification and regression
- Uses cost-complexity pruning for generalization
```python
from sklearn.tree import DecisionTreeClassifier

# Assumes X_train, X_test, y_train and y_test are already defined
# CART with different criteria
cart_gini = DecisionTreeClassifier(criterion='gini', max_depth=5)
cart_entropy = DecisionTreeClassifier(criterion='entropy', max_depth=5)
cart_gini.fit(X_train, y_train)
cart_entropy.fit(X_train, y_train)

print(f"Gini accuracy: {cart_gini.score(X_test, y_test):.4f}")
print(f"Entropy accuracy: {cart_entropy.score(X_test, y_test):.4f}")
```
Hyperparameters
| Parameter | Description | Default |
|---|---|---|
| criterion | 'gini' or 'entropy' (classification), 'squared_error' (regression; called 'mse' before scikit-learn 1.0) | 'gini' |
| max_depth | Maximum tree depth | None (unlimited) |
| min_samples_split | Minimum samples to split a node | 2 |
| min_samples_leaf | Minimum samples in a leaf node | 1 |
| max_features | Number of features to consider for best split | None (all) |
| max_leaf_nodes | Maximum number of leaf nodes | None (unlimited) |
| min_impurity_decrease | Minimum impurity decrease for a split | 0.0 |
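Defaults can differ between library versions, so it may be worth checking them against the installed scikit-learn. A minimal sketch using `get_params()`:

```python
from sklearn.tree import DecisionTreeClassifier

# Print the defaults of the installed scikit-learn version for the
# parameters listed in the table.
params = DecisionTreeClassifier().get_params()
for name in ("criterion", "max_depth", "min_samples_split", "min_samples_leaf",
             "max_features", "max_leaf_nodes", "min_impurity_decrease"):
    print(f"{name:24s} {params[name]!r}")
```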
Pruning Strategies
Pruning removes branches that provide little predictive power, reducing overfitting and improving generalization.
Pre-Pruning (Early Stopping)
Stop growing the tree before it becomes too complex.
- max_depth — limit tree height
- min_samples_split — require minimum samples to split
- min_samples_leaf — require minimum samples in leaves
- min_impurity_decrease — require minimum improvement
Advantage: Computationally efficient — avoids growing unnecessary branches.
Disadvantage: May stop too early (horizon effect) — a seemingly poor split now might lead to a good split later.
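The effect of early stopping can be seen by comparing an unrestricted tree with a pre-pruned one. The sketch below assumes the breast cancer dataset bundled with scikit-learn and arbitrary limits (`max_depth=3`, `min_samples_leaf=5`) chosen only for illustration.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, stratify=y)

# Unrestricted tree versus a tree stopped early by depth and leaf-size limits.
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
limited = DecisionTreeClassifier(
    max_depth=3, min_samples_leaf=5, random_state=0
).fit(X_train, y_train)

for name, model in [("unrestricted", full), ("pre-pruned", limited)]:
    print(f"{name:12s} depth={model.get_depth()} leaves={model.get_n_leaves()} "
          f"train={model.score(X_train, y_train):.3f} "
          f"test={model.score(X_test, y_test):.3f}")
```

The pre-pruned tree is much smaller and fits the training data less closely; whether it generalizes better depends on the dataset and the limits chosen.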
Post-Pruning
Grow the full tree first, then remove branches that don't improve generalization.
Reduced Error Pruning
- Grow tree to maximum depth
- For each non-leaf node (bottom-up):
  - Evaluate validation accuracy if subtree is replaced by a leaf
  - Prune if accuracy doesn't decrease
Cost-Complexity Pruning (CART)
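Minimize the cost-complexity function:
$$R_\alpha(T) = R(T) + \alpha |T|$$
where:
- $R(T)$ is the misclassification rate (resubstitution error) of tree $T$; scikit-learn uses the total sample-weighted impurity of the leaves instead
- $|T|$ is the number of leaf nodes
- $\alpha \ge 0$ is the complexity parameter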
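A worked example may help show how the complexity parameter trades error against size. The numbers below are hypothetical: a full tree with training error 0.02 and 20 leaves, and a pruned tree with error 0.05 and 5 leaves. The cost is the error plus the complexity parameter times the number of leaves.

```python
def cost_complexity(error_rate, n_leaves, alpha):
    """Training error plus alpha times the number of leaves."""
    return error_rate + alpha * n_leaves

# Hypothetical candidates: (training error, number of leaves)
candidates = {"full tree": (0.02, 20), "pruned tree": (0.05, 5)}

def preferred(alpha):
    """Name of the candidate with the lowest cost-complexity at this alpha."""
    return min(candidates, key=lambda name: cost_complexity(*candidates[name], alpha))

for alpha in (0.001, 0.01):
    print(f"alpha={alpha}: prefer {preferred(alpha)}")
```

With these numbers the two trees break even at alpha = 0.002 (0.02 + 20α = 0.05 + 5α); below that the full tree wins, above it the pruned tree does.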
Optimal $\alpha$ is found via cross-validation:
```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score
import numpy as np

# Assumes X_train and y_train are already defined
# Find optimal ccp_alpha
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

cv_scores = []
for alpha in ccp_alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(tree, X_train, y_train, cv=5, scoring='accuracy')
    cv_scores.append(scores.mean())

optimal_alpha = ccp_alphas[np.argmax(cv_scores)]
print(f"Optimal ccp_alpha: {optimal_alpha:.6f}")

# Train final model
pruned_tree = DecisionTreeClassifier(ccp_alpha=optimal_alpha, random_state=42)
pruned_tree.fit(X_train, y_train)
```
[Figure: Overfitting vs proper fit]
Advantages and Disadvantages
Advantages
| Advantage | Description |
|---|---|
| Interpretability | Visual tree structure is easy to explain to stakeholders |
| No feature scaling | Invariant to monotonic feature transformations |
| Handles mixed types | Works with numerical and categorical features |
| Feature importance | Built-in feature ranking via impurity decrease |
| Non-parametric | No assumptions about data distribution |
| Fast inference | $O(\text{depth})$ prediction time, roughly $O(\log n)$ for a balanced tree trained on $n$ samples |
Disadvantages
| Disadvantage | Description |
|---|---|
| Overfitting | Prone to memorizing noise without regularization |
| Instability | Small data changes can produce different trees |
| Greedy optimization | Locally optimal splits ≠ globally optimal tree |
| Imbalanced data | Bias toward majority classes |
| Axis-aligned splits | Cannot efficiently represent diagonal boundaries |
| Extrapolation | Cannot predict beyond training range (regression) |
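The extrapolation limitation is easy to demonstrate. In the sketch below, a regression tree is fitted to a hypothetical noiseless linear trend on x in [0, 10] and then queried far outside that range. Because every prediction is the mean of some training leaf, it cannot exceed the largest training target.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical data: a noiseless linear trend y = 2x on x in [0, 10].
X_train = np.linspace(0, 10, 50).reshape(-1, 1)
y_train = 2 * X_train.ravel()

reg = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)

# Query one point inside the training range and two far outside it.
X_new = np.array([[5.0], [20.0], [100.0]])
print("predictions:", reg.predict(X_new))
print("largest training target:", y_train.max())
```

The predictions at x = 20 and x = 100 are identical, since both points fall into the rightmost leaf.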
Implementation in Python
Complete Pipeline
```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree, export_text
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt

# Load dataset
from sklearn.datasets import load_breast_cancer
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# --- 1. Basic Decision Tree ---
basic_tree = DecisionTreeClassifier(random_state=42)
basic_tree.fit(X_train, y_train)
print(f"Depth: {basic_tree.get_depth()}")
print(f"Leaves: {basic_tree.get_n_leaves()}")
print(f"Train accuracy: {basic_tree.score(X_train, y_train):.4f}")
print(f"Test accuracy: {basic_tree.score(X_test, y_test):.4f}")

# --- 2. Hyperparameter Tuning ---
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 2, 5, 10],
    'criterion': ['gini', 'entropy'],
    'ccp_alpha': [0.0, 0.001, 0.01, 0.02, 0.05]
}
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=0
)
grid_search.fit(X_train, y_train)
print(f"\nBest params: {grid_search.best_params_}")
print(f"Best CV accuracy: {grid_search.best_score_:.4f}")
best_tree = grid_search.best_estimator_
print(f"Test accuracy: {best_tree.score(X_test, y_test):.4f}")

# --- 3. Cost-Complexity Pruning Path ---
clf = DecisionTreeClassifier(random_state=42)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

fig, ax = plt.subplots(1, 2, figsize=(14, 5))
ax[0].plot(ccp_alphas, impurities, marker='o', drawstyle='steps-post')
ax[0].set_xlabel('ccp_alpha')
ax[0].set_ylabel('Total impurity of leaves')
ax[0].set_title('Effective Alphas vs Impurities')

# Fit one tree per effective alpha along the pruning path
trees = []
for ccp_alpha in ccp_alphas:
    tree = DecisionTreeClassifier(ccp_alpha=ccp_alpha, random_state=42)
    tree.fit(X_train, y_train)
    trees.append(tree)

train_scores = [t.score(X_train, y_train) for t in trees]
test_scores = [t.score(X_test, y_test) for t in trees]
ax[1].plot(ccp_alphas, train_scores, marker='o', label='train', drawstyle='steps-post')
ax[1].plot(ccp_alphas, test_scores, marker='o', label='test', drawstyle='steps-post')
ax[1].set_xlabel('ccp_alpha')
ax[1].set_ylabel('Accuracy')
ax[1].set_title('Accuracy vs Alpha')
ax[1].legend()
plt.tight_layout()
plt.show()

# --- 4. Feature Importance ---
importances = best_tree.feature_importances_
indices = np.argsort(importances)[::-1]
plt.figure(figsize=(10, 6))
plt.title('Feature Importances')
plt.bar(range(X.shape[1]), importances[indices], align='center')
plt.xticks(range(X.shape[1]), X.columns[indices], rotation=45, ha='right')
plt.tight_layout()
plt.show()

# --- 5. Text Representation ---
print(export_text(best_tree, feature_names=list(X.columns), max_depth=3))
```
Regression Trees
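```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score

# Example data (assumption: the diabetes dataset stands in for your own regression data)
X_reg, y_reg = load_diabetes(return_X_y=True)
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)

# Regression example
reg_tree = DecisionTreeRegressor(
    max_depth=5,
    min_samples_leaf=10,
    random_state=42
)
reg_tree.fit(X_train_reg, y_train_reg)
y_pred = reg_tree.predict(X_test_reg)
print(f"RMSE: {np.sqrt(mean_squared_error(y_test_reg, y_pred)):.4f}")
print(f"R²: {r2_score(y_test_reg, y_pred):.4f}")
```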
Key Takeaways
- Gini vs Entropy — both produce similar trees; Gini is slightly faster, Entropy tends to produce slightly more balanced trees
- Pre-pruning is computationally efficient but risks the horizon effect; post-pruning is more robust but requires growing the full tree first and using held-out data or cross-validation to decide how much to prune
- Cost-complexity pruning with cross-validation is the standard approach for finding optimal tree size
- Decision trees are the foundation for ensemble methods (Random Forests, Gradient Boosting) which address individual tree limitations
- Feature importance from trees is based on impurity decrease and can be biased toward high-cardinality features — consider permutation importance as an alternative
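As a sketch of the permutation-importance alternative mentioned in the last takeaway, the example below uses `sklearn.inspection.permutation_importance` on held-out data and prints it next to the impurity-based importance. The dataset, tree depth and `n_repeats=10` are illustrative assumptions.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=0, stratify=data.target
)
clf = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)

# Shuffle each feature on held-out data and record the drop in accuracy.
result = permutation_importance(clf, X_test, y_test, n_repeats=10, random_state=0)

# Show the five features with the largest mean permutation importance.
top = result.importances_mean.argsort()[::-1][:5]
for i in top:
    print(f"{data.feature_names[i]:25s} "
          f"impurity={clf.feature_importances_[i]:.3f} "
          f"permutation={result.importances_mean[i]:.3f}")
```

Impurity-based importances are computed from the training data and sum to 1, while permutation importances measure the change in held-out score, so the two rankings need not agree.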
Further Reading
- Breiman, L. et al. (1984). Classification and Regression Trees. Wadsworth.
- Quinlan, J.R. (1986). Induction of Decision Trees. Machine Learning, 1(1), 81–106.
- Quinlan, J.R. (1993). C4.5: Programs for Machine Learning. Morgan Kaufmann.
- Hastie, T., Tibshirani, R., and Friedman, J. (2009). The Elements of Statistical Learning. Springer. Chapter 9.