Cross-validation using StratifiedKFold with an exogenous group feature


Good morning/afternoon, I would like to use cross-validation in sklearn for the prediction of a continuous variable.

I have referred to the "Visualizing cross-validation behavior in scikit-learn" page to select the cross-validation method suited to my problem. https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_indices.html#sphx-glr-auto-examples-model-selection-plot-cv-indices-py

I want to use StratifiedKFold, but it does not seem to provide a way to stratify on a variable other than the target ("class"), as in the example below.

[Figure: StratifiedKFold split visualization from the linked scikit-learn example, with folds stratified on the target "class" variable]

What I would like is to use the "group" variable to stratify instead.

Currently, what I do is this:

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score

skf = StratifiedKFold(n_splits=5,
                      shuffle=True,
                      random_state=57)
# Pass `groups` as the second argument to split() so the folds are
# stratified on it rather than on the target y
cross_val_score(regr, X, y, cv=skf.split(X, groups))

where regr is my regressor, X my features, y my target, and groups a pandas Series holding my preferred "stratifying" variable. I have checked that skf.split(X, groups) provides splits suited to my needs, i.e., train and test sets where the original distribution of my groups is maintained.
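
One detail worth noting here, as a minimal sketch (regr, X, y, and groups as defined above): skf.split() returns a one-shot generator, so if the same splits are needed both for inspection and for scoring, they can be materialized into a list first, which cross_val_score also accepts:

splits = list(skf.split(X, groups))     # materialize so the splits can be inspected and reused
cross_val_score(regr, X, y, cv=splits)  # cv accepts an iterable of (train, test) index arrays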

However, I have no means of checking that the cross-validation has the behavior I am expecting. Am I correct? Can I check?

1 answer

DataJanitor (accepted answer):

Your approach looks correct to me, even if it is rather uncommon.

You could check if the stratification worked with this code:

import numpy as np

# Setup StratifiedKFold, just as you did
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=57)

# Set in which the already-seen test indices are collected
seen_test_indices = set()

# Iterating over each fold
for split_id, (train_index, test_index) in enumerate(skf.split(X, groups), start=1):
    
    # Check if any of the test indices have been seen before
    overlapping_indices = seen_test_indices.intersection(test_index)
    if overlapping_indices:
        print(f"Overlap detected in Split ID {split_id} with indices {overlapping_indices}")
        break
    seen_test_indices.update(test_index)
    
    # Distribution of 'groups' in the train and test split; np.bincount
    # assumes the group labels are non-negative integers, and .iloc is used
    # because `groups` is a pandas Series indexed here by position
    train_groups_distribution = np.bincount(groups.iloc[train_index])
    test_groups_distribution = np.bincount(groups.iloc[test_index])
    
    print(f"Split ID: {split_id}")
    print("Train Groups Distribution:", train_groups_distribution)
    print("Test Groups Distribution:", test_groups_distribution)
    print("-----")

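If you prefer a numeric check over eyeballing the printed counts, here is a follow-up sketch that compares each test fold's group proportions against the overall distribution (assuming X and the pandas Series groups from the question, and skf as set up above):

overall = groups.value_counts(normalize=True)

for split_id, (_, test_index) in enumerate(skf.split(X, groups), start=1):
    # Proportion of each group within this test fold, aligned to the overall index
    fold = groups.iloc[test_index].value_counts(normalize=True)
    fold = fold.reindex(overall.index, fill_value=0.0)
    max_dev = (fold - overall).abs().max()
    print(f"Split {split_id}: max deviation from overall group proportions = {max_dev:.4f}")
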
I wouldn't use this approach if the variable groups has too many distinct values: if each group then contains only a handful of samples, StratifiedKFold might throw an error because there are not enough samples per group to create stratified folds.
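
As a rough guard against that failure mode, one could verify up front that every group appears at least n_splits times; StratifiedKFold warns about, or in the worst case refuses, strata smaller than n_splits. A sketch, again assuming groups is the pandas Series from the question:

n_splits = 5
counts = groups.value_counts()
too_small = counts[counts < n_splits]
if not too_small.empty:
    # These groups cannot appear in every fold; stratification
    # will warn or may fail outright
    print(f"Groups with fewer than {n_splits} samples:")
    print(too_small)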