From c6b933689cb12e1bcc8e05c6f4378f97725f6864 Mon Sep 17 00:00:00 2001 From: Kirill Poliakov Date: Wed, 1 Feb 2023 15:59:07 +0300 Subject: [PATCH] Stratify val_ids in stratified_train_val_test_split --- dpipe/split/cv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dpipe/split/cv.py b/dpipe/split/cv.py index 94c1f98..b2f45c1 100644 --- a/dpipe/split/cv.py +++ b/dpipe/split/cv.py @@ -89,7 +89,8 @@ def stratified_train_val_test_split(ids: Sequence, labels: Union[Callable, Seque train_val_ids = extract(ids, train_val_indices) test_ids = extract(ids, test_indices) if val_size: - train_ids, val_ids = train_test_split(train_val_ids, test_size=val_size, random_state=25 + i) + train_val_labels = extract(labels, train_val_indices) + train_ids, val_ids = train_test_split(train_val_ids, test_size=val_size, random_state=25 + i, stratify=train_val_labels) else: train_ids, val_ids = train_val_ids, []