From 1edf1f04e368a33f0bb37c7fc52b739bc4d6f1b4 Mon Sep 17 00:00:00 2001 From: Fede Date: Fri, 16 Jan 2026 11:54:51 +0200 Subject: [PATCH 1/8] Fix metadata routing --- julearn/api.py | 26 +++++++++++++++++------ julearn/model_selection/final_model_cv.py | 22 ++++++++++++++----- julearn/pipeline/pipeline_creator.py | 12 +++++++++++ julearn/prepare.py | 11 +++++++++- julearn/tests/test_api.py | 1 - julearn/utils/testing.py | 6 +++++- 6 files changed, 63 insertions(+), 15 deletions(-) diff --git a/julearn/api.py b/julearn/api.py index b1437a254..d3ff19848 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -4,6 +4,7 @@ # Sami Hamdan # License: AGPL +import inspect from typing import Optional, Union import numpy as np @@ -26,7 +27,10 @@ from .utils.typing import CVLike -def _validata_api_params( # noqa: C901 +sklearn.set_config(enable_metadata_routing=True) + + +def _validate_api_params( # noqa: C901 X: list[str], # noqa: N803 y: str, model: Union[str, PipelineCreator, BaseEstimator, list[PipelineCreator]], @@ -533,7 +537,7 @@ def run_cross_validation( return_inspector, wrap_score, problem_type, - ) = _validata_api_params( + ) = _validate_api_params( X=X, y=y, model=model, @@ -561,7 +565,7 @@ def run_cross_validation( ) logger.info(f"Using outer CV scheme {cv_outer}") - check_consistency(df_y, cv, groups, problem_type) # type: ignore + groups_needed = check_consistency(df_y, cv, groups, problem_type) # type: ignore scoring = check_scoring( pipeline, # type: ignore @@ -570,10 +574,18 @@ def run_cross_validation( ) cv_mdsum = _compute_cvmdsum(cv_outer) + fit_params = {} if df_groups is not None: - if isinstance(pipeline, BaseSearchCV): - fit_params["groups"] = df_groups.values + if groups_needed: + # If we need groups, we have to pass them to the fit method of + # the last step of the pipeline + if not isinstance(pipeline, BaseSearchCV): + last_step = pipeline.steps[-1][1] + argspec = inspect.getfullargspec(last_step.fit) + if "groups" in argspec.args: + last_step.set_fit_request(groups=True) + fit_params["groups"] = df_groups.values _sklearn_deprec_fit_params = {} if sklearn.__version__ >= "1.4.0": @@ -587,7 +599,7 @@ def run_cross_validation( df_y, cv=cv_outer, scoring=scoring, - groups=df_groups, + # groups=df_groups, return_estimator=cv_return_estimator, n_jobs=n_jobs, return_train_score=return_train_score, @@ -776,7 +788,7 @@ def run_fit( _, _, problem_type, - ) = _validata_api_params( + ) = _validate_api_params( X=X, y=y, model=model, diff --git a/julearn/model_selection/final_model_cv.py b/julearn/model_selection/final_model_cv.py index 0852abaed..739382d3b 100644 --- a/julearn/model_selection/final_model_cv.py +++ b/julearn/model_selection/final_model_cv.py @@ -7,13 +7,11 @@ from typing import TYPE_CHECKING, Optional import numpy as np +from sklearn.model_selection import BaseCrossValidator +from sklearn.model_selection._split import GroupsConsumerMixin -if TYPE_CHECKING: - from sklearn.model_selection import BaseCrossValidator - - -class _JulearnFinalModelCV: +class _JulearnFinalModelCV(BaseCrossValidator): """Final model cross-validation iterator. Wraps any CV iterator to provide an extra iteration with the full dataset. @@ -30,6 +28,20 @@ def __init__(self, cv: "BaseCrossValidator") -> None: if hasattr(cv, "n_repeats"): self.n_repeats = cv.n_repeats + def get_metadata_routing(self) -> dict: + """Get metadata routing information from the underlying CV. + + Returns + ------- + dict + The metadata routing information. + + """ + if hasattr(self.cv, "get_metadata_routing"): + return self.cv.get_metadata_routing() + else: + return {} + def split( self, X: np.ndarray, # noqa: N803 diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index 06ce6a5c0..f7b4a0aad 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -9,7 +9,9 @@ from typing import Any, Optional, Union import numpy as np +import sklearn from scipy import stats +from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor from sklearn.model_selection import RandomizedSearchCV, check_cv from sklearn.pipeline import Pipeline @@ -417,6 +419,16 @@ def add( # noqa: C901 f"{name}__{param}": val for param, val in params_to_tune.items() } + # Disable metadata routing for AdaBoost-based estimators until + # scikit-learn implements them. + + if isinstance(step, (AdaBoostClassifier, AdaBoostRegressor)): + warn_with_log( + "Disabling metadata routing for AdaBoost-based " + "estimators until scikit-learn implements them." + ) + sklearn.set_config(enable_metadata_routing=False) + self._steps.append( Step( name=name, diff --git a/julearn/prepare.py b/julearn/prepare.py index b575c33bf..c8b75b18c 100644 --- a/julearn/prepare.py +++ b/julearn/prepare.py @@ -365,7 +365,7 @@ def check_consistency( cv: CVLike, groups: Optional[pd.Series], problem_type: str, -) -> None: +) -> bool: """Check the consistency of the parameters/input. Parameters @@ -379,6 +379,11 @@ def check_consistency( problem_type : str The problem type. Can be "classification" or "regression". + Returns + ------- + groups_needed : bool + True if the groups variable is needed for the CV scheme, False otherwise. + Raises ------ ValueError @@ -435,6 +440,7 @@ def check_consistency( "The problem type must be either 'classification' or 'regression'." ) # Check groups and CV scheme + groups_needed = False if groups is not None: valid_instances = ( GroupKFold, @@ -450,6 +456,9 @@ def check_consistency( "The parameter groups was specified but the CV strategy " "will not consider them." ) + else: + groups_needed = True + return groups_needed def _check_x_types( diff --git a/julearn/tests/test_api.py b/julearn/tests/test_api.py index cfd6a2110..41273dc97 100644 --- a/julearn/tests/test_api.py +++ b/julearn/tests/test_api.py @@ -549,7 +549,6 @@ def test_tune_hyperparam_gridsearch_groups(df_iris: pd.DataFrame) -> None: sk_y, # type: ignore cv=cv_outer, scoring=[scoring], - groups=sk_groups, params={"groups": sk_groups}, ) diff --git a/julearn/utils/testing.py b/julearn/utils/testing.py index f34422294..0eb8eea48 100644 --- a/julearn/utils/testing.py +++ b/julearn/utils/testing.py @@ -245,13 +245,17 @@ def do_scoring_test( ) np.random.seed(42) + if sk_groups is not None: + params = dict(groups=sk_groups) + else: + params = {} expected = cross_validate( sklearn_model, # type: ignore sk_X, sk_y, cv=sk_cv, scoring=scorers, - groups=sk_groups, # type: ignore + params=params, # type: ignore ) # Compare the models From e595de3c6ed9516669702dd204524107074a7761 Mon Sep 17 00:00:00 2001 From: Fede Date: Fri, 16 Jan 2026 11:57:55 +0200 Subject: [PATCH 2/8] fix linter --- julearn/model_selection/final_model_cv.py | 3 +-- julearn/prepare.py | 3 ++- julearn/utils/testing.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/julearn/model_selection/final_model_cv.py b/julearn/model_selection/final_model_cv.py index 739382d3b..05c7f1b31 100644 --- a/julearn/model_selection/final_model_cv.py +++ b/julearn/model_selection/final_model_cv.py @@ -4,11 +4,10 @@ # License: AGPL from collections.abc import Generator -from typing import TYPE_CHECKING, Optional +from typing import, Optional import numpy as np from sklearn.model_selection import BaseCrossValidator -from sklearn.model_selection._split import GroupsConsumerMixin class _JulearnFinalModelCV(BaseCrossValidator): diff --git a/julearn/prepare.py b/julearn/prepare.py index c8b75b18c..f53725836 100644 --- a/julearn/prepare.py +++ b/julearn/prepare.py @@ -382,7 +382,8 @@ def check_consistency( Returns ------- groups_needed : bool - True if the groups variable is needed for the CV scheme, False otherwise. + True if the groups variable is needed for the CV scheme, + False otherwise. Raises ------ diff --git a/julearn/utils/testing.py b/julearn/utils/testing.py index 0eb8eea48..8c37db705 100644 --- a/julearn/utils/testing.py +++ b/julearn/utils/testing.py @@ -246,7 +246,7 @@ def do_scoring_test( np.random.seed(42) if sk_groups is not None: - params = dict(groups=sk_groups) + params = {"groups": sk_groups} else: params = {} expected = cross_validate( From 7ebd3bf973bd13d30d80e842749dc87ad33293e3 Mon Sep 17 00:00:00 2001 From: Fede Date: Fri, 16 Jan 2026 11:59:41 +0200 Subject: [PATCH 3/8] fix linter --- julearn/model_selection/final_model_cv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/julearn/model_selection/final_model_cv.py b/julearn/model_selection/final_model_cv.py index 05c7f1b31..dcd3bae2a 100644 --- a/julearn/model_selection/final_model_cv.py +++ b/julearn/model_selection/final_model_cv.py @@ -4,7 +4,7 @@ # License: AGPL from collections.abc import Generator -from typing import, Optional +from typing import Optional import numpy as np from sklearn.model_selection import BaseCrossValidator From 4a18a3f653cdbc07cd0f5d3d54d1a314834dce12 Mon Sep 17 00:00:00 2001 From: Fede Date: Fri, 16 Jan 2026 12:13:34 +0200 Subject: [PATCH 4/8] Add pytest fixture for metadata routing enabling --- julearn/conftest.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/julearn/conftest.py b/julearn/conftest.py index 926209bf2..08ea021e0 100644 --- a/julearn/conftest.py +++ b/julearn/conftest.py @@ -19,6 +19,14 @@ } +@pytest.fixture(autouse=True) +def enable_metadata_routing() -> None: + """Enable metadata routing in sklearn for all tests.""" + import sklearn + + sklearn.set_config(enable_metadata_routing=True) + + def pytest_configure(config: pytest.Config) -> None: """Add a new marker to pytest. From 311336079849ae18142a9e9e0f3d622b8f5bf2fb Mon Sep 17 00:00:00 2001 From: Fede Date: Fri, 16 Jan 2026 12:49:18 +0200 Subject: [PATCH 5/8] Fix metadata routing for bayes search CV --- .../run_hyperparameter_tuning_bayessearch.py | 5 +++++ julearn/api.py | 11 ++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py b/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py index 05cf4cdad..5ad168266 100644 --- a/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py +++ b/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py @@ -19,6 +19,7 @@ import numpy as np from seaborn import load_dataset +import sklearn from julearn import run_cross_validation from julearn.utils import configure_logging, logger @@ -29,6 +30,10 @@ # Set the logging level to info to see extra information. configure_logging(level="INFO") +############################################################################### +# Disable metadata routing to avoid errors due to BayesSearchCV being used. +sklearn.set_config(enable_metadata_routing=False) + ############################################################################### # Set the random seed to always have the same example. np.random.seed(42) diff --git a/julearn/api.py b/julearn/api.py index d3ff19848..fa79d6eff 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -576,6 +576,8 @@ def run_cross_validation( cv_mdsum = _compute_cvmdsum(cv_outer) fit_params = {} + _sklearn_deprec_fit_params = {} + if df_groups is not None: if groups_needed: # If we need groups, we have to pass them to the fit method of @@ -587,11 +589,10 @@ def run_cross_validation( last_step.set_fit_request(groups=True) fit_params["groups"] = df_groups.values - _sklearn_deprec_fit_params = {} - if sklearn.__version__ >= "1.4.0": - _sklearn_deprec_fit_params["params"] = fit_params - else: - _sklearn_deprec_fit_params["fit_params"] = fit_params + if sklearn.__version__ >= "1.4.0": + _sklearn_deprec_fit_params["params"] = fit_params + else: + _sklearn_deprec_fit_params["fit_params"] = fit_params scores = cross_validate( pipeline, From 0d235c830e9057d45cd7889807ede8020fc371ed Mon Sep 17 00:00:00 2001 From: Fede Date: Wed, 21 Jan 2026 16:26:23 +0200 Subject: [PATCH 6/8] Add debug message --- julearn/api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/julearn/api.py b/julearn/api.py index fa79d6eff..39d74907c 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -586,6 +586,9 @@ def run_cross_validation( last_step = pipeline.steps[-1][1] argspec = inspect.getfullargspec(last_step.fit) if "groups" in argspec.args: + logger.debug( + "Pipeline's last step uses groups," \ + "calling `set_fit_request(groups=True)`") last_step.set_fit_request(groups=True) fit_params["groups"] = df_groups.values From 86d4204cd79bc1356a11b2daf3067335f36eb040 Mon Sep 17 00:00:00 2001 From: Fede Date: Wed, 21 Jan 2026 16:29:46 +0200 Subject: [PATCH 7/8] Be more strict with group CV and group parameters --- julearn/prepare.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/julearn/prepare.py b/julearn/prepare.py index f53725836..b3e917536 100644 --- a/julearn/prepare.py +++ b/julearn/prepare.py @@ -442,23 +442,28 @@ def check_consistency( ) # Check groups and CV scheme groups_needed = False + valid_group_cv_instances = ( + GroupKFold, + GroupShuffleSplit, + LeaveOneGroupOut, + LeavePGroupsOut, + StratifiedGroupKFold, + ContinuousStratifiedGroupKFold, + RepeatedContinuousStratifiedGroupKFold, + ) if groups is not None: - valid_instances = ( - GroupKFold, - GroupShuffleSplit, - LeaveOneGroupOut, - LeavePGroupsOut, - StratifiedGroupKFold, - ContinuousStratifiedGroupKFold, - RepeatedContinuousStratifiedGroupKFold, - ) - if not isinstance(cv, valid_instances): + if not isinstance(cv, valid_group_cv_instances): warn_with_log( "The parameter groups was specified but the CV strategy " "will not consider them." ) else: groups_needed = True + elif isinstance(cv, valid_group_cv_instances): + raise_error( + "The CV strategy requires groups but the parameter groups was " + "not specified." + ) return groups_needed From d02ef73d219b1f48865d9bc131ef5e7d5e907c8a Mon Sep 17 00:00:00 2001 From: Fede Date: Tue, 10 Feb 2026 16:33:10 +0200 Subject: [PATCH 8/8] Add changes --- docs/changes/newsfragments/298.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/changes/newsfragments/298.bugfix diff --git a/docs/changes/newsfragments/298.bugfix b/docs/changes/newsfragments/298.bugfix new file mode 100644 index 000000000..916eb1957 --- /dev/null +++ b/docs/changes/newsfragments/298.bugfix @@ -0,0 +1 @@ +Enable and use metadata routing for hyperparameter tuning estimators by `Fede Raimondo`_ \ No newline at end of file