🔄 StratifiedKFold Explained with Example (Scikit-Learn)
Cross-validation is a powerful technique for evaluating machine learning models. With plain KFold, however, a test fold can end up missing a class entirely, especially when the data is ordered by label or the classes are imbalanced.
That’s where StratifiedKFold comes in: it preserves the class proportions of the full dataset in every fold, as closely as the fold sizes allow. Let’s walk through a small example and analyze a multiple-choice question about which training sets are possible.
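To see the failure mode first, here is a minimal sketch. The six-sample y_sorted array is a made-up illustration (not the dataset used below): when labels are grouped by class, plain KFold cuts contiguous slices, so some test folds contain only one class.

import numpy as np
from sklearn.model_selection import KFold

# Hypothetical labels grouped by class, e.g. a dataset sorted by target
y_sorted = np.array([0, 0, 0, 1, 1, 1])
for train, test in KFold(n_splits=3).split(y_sorted):
    print(y_sorted[test])  # prints [0 0], [0 1], [1 1]: the first fold has no 1s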
📌 The Code
import numpy as np
from sklearn.model_selection import StratifiedKFold
X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0])
kf = StratifiedKFold(n_splits=3)
for train, test in kf.split(X, y):
    print(X[train])
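With the default shuffle=False, these splits are deterministic. On a recent scikit-learn this prints the three training sets, one per fold (fold order may vary across versions):

[4 5 6 7 8 9]
[1 2 3 6 8 9]
[1 2 3 4 5 7]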
🧮 Step 1: Dataset
- Features (X) = [1, 2, 3, 4, 5, 6, 7, 8, 9]
- Labels (y) = [0, 1, 0, 1, 0, 1, 0, 1, 0]

So:

- Class 0 → [1, 3, 5, 7, 9] (5 samples)
- Class 1 → [2, 4, 6, 8] (4 samples)

Class ratio = 5 : 4
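You can confirm the counts directly:

import numpy as np

y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0])
print(np.bincount(y))  # [5 4]: five samples of class 0, four of class 1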
🧩 Step 2: How StratifiedKFold Splits
- We set n_splits=3. That means:
  - Each test fold will have 9 ÷ 3 = 3 samples.
  - Each training fold will have 9 − 3 = 6 samples.
- Stratification ensures both classes appear in every fold (ratios preserved as closely as possible).
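A quick way to verify both points, reusing the X and y from above (Counter is only there for readable per-class counts):

import numpy as np
from collections import Counter
from sklearn.model_selection import StratifiedKFold

X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0])

for train, test in StratifiedKFold(n_splits=3).split(X, y):
    # every fold: 6 train samples, 3 test samples, both classes in the test set
    print(len(train), len(test), Counter(y[test].tolist()))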
🔍 Step 3: Analyze Each Option
Let’s check which training sets are possible (a code version of these checks follows the list of options).
❌ Option [1, 2, 3, 6, 7]
- Train size = 5 (should be 6).
- Invalid.
✅ Option [1, 2, 3, 4, 5, 7]
- Train labels = [0, 1, 0, 1, 0, 0] → 4 zeros, 2 ones
- Test set = [6, 8, 9] → labels [1, 1, 0] → 2 ones, 1 zero
- Both train and test contain both classes, and the sizes are correct.
- Valid.
❌ Option [2, 4, 6, 8]
- Train labels = [1, 1, 1, 1] → only class 1 (and the size, 4, is wrong too)
- Test set = [1, 3, 5, 7, 9] → labels [0, 0, 0, 0, 0] → only class 0
- Stratification broken.
- Invalid.
✅ Option [1, 2, 3, 6, 8, 9]
- Train labels = [0, 1, 0, 1, 1, 0] → 3 zeros, 3 ones
- Test set = [4, 5, 7] → labels [1, 0, 0] → 2 zeros, 1 one
- Balanced distribution, correct sizes.
- Valid.
❌ Option [2, 5, 8]
- Train size = 3 (should be 6).
- Invalid.
✅ Option [4, 5, 6, 7, 8, 9]
- Train labels = [1, 0, 1, 0, 1, 0] → 3 zeros, 3 ones
- Test set = [1, 2, 3] → labels [0, 1, 0] → 2 zeros, 1 one
- Balanced distribution, correct sizes.
- Valid.
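These checks are easy to put into code. Here is a minimal sketch (the passes_basic_checks helper is ours, purely illustrative); it encodes exactly the two conditions used above: correct train size, and both classes present on each side of the split.

import numpy as np

X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0])
label = dict(zip(X.tolist(), y.tolist()))  # sample value -> class

def passes_basic_checks(train_values, expected_train_size=6):
    # A valid option needs the right train size, and both classes
    # must appear in both the training and the test portion.
    test_values = [x for x in X.tolist() if x not in train_values]
    train_classes = {label[x] for x in train_values}
    test_classes = {label[x] for x in test_values}
    return (len(train_values) == expected_train_size
            and train_classes == {0, 1} == test_classes)

for option in ([1, 2, 3, 6, 7], [1, 2, 3, 4, 5, 7], [2, 4, 6, 8],
               [1, 2, 3, 6, 8, 9], [2, 5, 8], [4, 5, 6, 7, 8, 9]):
    print(option, passes_basic_checks(option))

Note that these conditions are necessary, not sufficient: a candidate could pass them without being a split StratifiedKFold actually produces, so the definitive check is to run the original code and compare.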
🎯 Final Answer
The valid training sets that can be printed are:
- [1, 2, 3, 4, 5, 7]
- [1, 2, 3, 6, 8, 9]
- [4, 5, 6, 7, 8, 9]
✨ Key Takeaways
- Training set size = n_samples - (n_samples / n_splits) → here, 6 training and 3 test samples in each fold.
- StratifiedKFold preserves class balance approximately: it doesn’t force identical ratios in every fold, but it does keep both classes represented.
- Options with the wrong size or a missing class are automatically invalid.