Decision Trees & max_depth — a simple blog + runnable code
What you’ll learn
- What `max_depth` does (intuitively).
- How training vs test accuracy usually changes as the tree gets deeper.
- A small runnable script that plots the accuracies so you can see underfitting vs overfitting.
Intuition (plain language)
- `max_depth` caps how many levels of splits the tree can make from root → leaf (the sketch after this list shows how deep an uncapped tree grows).
- Small `max_depth` → shallow tree → simple model → may underfit (low train & test accuracy).
- Large `max_depth` → deep tree → complex model → may overfit (high train accuracy, lower test accuracy).
- There’s usually a sweet spot where test accuracy is highest.
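To make “depth” concrete, here is a minimal sketch (assuming scikit-learn is installed, and reusing the wine dataset from the script below) that grows an uncapped tree and reports how deep it actually gets:

```python
# Sketch: grow an unconstrained tree and inspect its depth and complexity.
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
tree = DecisionTreeClassifier(random_state=0).fit(X, y)  # no max_depth cap

print(tree.get_depth())      # depth the tree actually grew to
print(tree.get_n_leaves())   # number of leaves (a rough complexity measure)
print(tree.score(X, y))      # training accuracy, typically 1.0 for an uncapped tree
```

An uncapped tree will usually keep splitting until it fits the training data almost perfectly, which is exactly the overfitting risk that `max_depth` is meant to limit.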
Run this code (copy & paste)
```python
# demo_decision_tree_max_depth.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# 1) Load data and split
X, y = load_wine(return_X_y=True, as_frame=False)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.10, random_state=12, stratify=y
)

# 2) Fixed hyperparameters shared by every tree in the sweep
min_samples_split = 2
min_samples_leaf = 3
random_state = 81

# 3) Evaluate for a range of max_depth values
max_depths = list(range(1, 16))  # try depths 1..15
train_scores = []
test_scores = []
for depth in max_depths:
    clf = DecisionTreeClassifier(
        max_depth=depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        random_state=random_state,
    )
    clf.fit(X_train, y_train)
    y_train_pred = clf.predict(X_train)
    y_test_pred = clf.predict(X_test)
    train_scores.append(accuracy_score(y_train, y_train_pred))
    test_scores.append(accuracy_score(y_test, y_test_pred))

# 4) Plot the results
plt.figure(figsize=(8, 5))
plt.plot(max_depths, train_scores, marker='o', label='Train accuracy')
plt.plot(max_depths, test_scores, marker='o', label='Test accuracy')
plt.xlabel('max_depth')
plt.ylabel('Accuracy')
plt.title('DecisionTree: train vs test accuracy by max_depth')
plt.xticks(max_depths)
plt.grid(axis='y', linestyle='--', alpha=0.4)
plt.legend()
plt.tight_layout()
plt.show()

# 5) Print best depth by test score
best_idx = int(np.argmax(test_scores))
print(f"Best max_depth (by test accuracy): {max_depths[best_idx]}")
print(f"Train acc at best: {train_scores[best_idx]:.4f}, Test acc at best: {test_scores[best_idx]:.4f}")
```
What to expect when you run it
- Training accuracy will start low (small depth) and increase as `max_depth` increases, often reaching 1.0 for very deep trees (a perfect fit to the training set).
- Test accuracy will usually increase up to a certain `max_depth`, then plateau or decrease as the model overfits.
- If you decrease `max_depth` from a moderate value like 4, the training score will most likely decrease (or stay the same), because the model is less able to memorize training examples.
Short summary for interviews / MCQs
- `max_depth` ↓ → complexity ↓ → training accuracy non-increasing (usually decreases); the sketch below demonstrates this.
- For the training score printed by the script above: decreasing `max_depth` will generally decrease the printed value.
- For the test score: decreasing `max_depth` can increase or decrease it, depending on whether you remove overfitting or cause underfitting.
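Here is a quick sketch of the first point (the depth values and `random_state` are illustrative, not tuned): on the full wine dataset, training accuracy can only fall or hold steady as the depth cap shrinks.

```python
# Sketch: training accuracy is non-increasing as the max_depth cap shrinks.
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
for depth in (4, 3, 2, 1):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=81).fit(X, y)
    print(f"max_depth={depth}: train accuracy = {clf.score(X, y):.4f}")
```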
Next steps (try these)
- Run the script and observe the plot; visual intuition helps a lot.
- Change `min_samples_leaf` or `test_size` and see how the curves move.
- Use cross-validation (e.g., `GridSearchCV`) to pick the best `max_depth` reliably; see the sketch below.
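As a starting point, here is a minimal cross-validation sketch (the depth range, 5-fold split, and reused hyperparameters are illustrative assumptions):

```python
# Sketch: pick max_depth with 5-fold cross-validation instead of a single split.
from sklearn.datasets import load_wine
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
param_grid = {"max_depth": list(range(1, 16))}
search = GridSearchCV(
    DecisionTreeClassifier(min_samples_leaf=3, random_state=81),
    param_grid,
    scoring="accuracy",
    cv=5,
)
search.fit(X, y)
print("Best max_depth:", search.best_params_["max_depth"])
print("Mean CV accuracy:", round(search.best_score_, 4))
```

Because each depth is scored on five different held-out folds, the chosen `max_depth` is less sensitive to one lucky or unlucky train/test split than the single-split script above.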
Here’s a quick comparison table of commonly used metrics (the `scoring` strings you can pass to `GridSearchCV` or `cross_val_score`):
| Problem Type | Metric (scoring in sklearn) | What it measures |
|---|---|---|
| Classification | "accuracy" | Fraction of correctly classified samples |
| Classification | "precision" | How many predicted positives are correct |
| Classification | "recall" | How many actual positives are correctly identified |
| Classification | "f1" | Harmonic mean of precision & recall (good for imbalanced classes) |
| Classification | "roc_auc" | Area under the ROC curve (how well the model separates classes) |
| Regression | "r2" | Proportion of variance explained by the model |
| Regression | "neg_mean_absolute_error" | Average absolute difference between predicted & true values |
| Regression | "neg_mean_squared_error" | Average squared difference (penalizes larger errors more) |
| Regression | "explained_variance" | Variance explained by the model predictions |
⚠️ Note: For regression error metrics, scikit-learn uses the negated value (e.g., "neg_mean_squared_error") because `cross_val_score()` always tries to maximize the scoring metric, and maximizing a negated error is the same as minimizing the error.
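A minimal sketch of that convention in practice (using `load_diabetes` as a stand-in regression dataset; `LinearRegression` and the 5-fold split are illustrative choices):

```python
# Sketch: the "neg_" convention. cross_val_score maximizes its scoring
# metric, so error metrics come back negated (closer to 0 is better).
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

X, y = load_diabetes(return_X_y=True)
scores = cross_val_score(
    LinearRegression(), X, y, scoring="neg_mean_squared_error", cv=5
)
print(scores)          # all negative: higher (less negative) means lower MSE
print(-scores.mean())  # flip the sign to report an ordinary MSE
```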