diff --git a/docs/changes/newsfragments/298.bugfix b/docs/changes/newsfragments/298.bugfix new file mode 100644 index 000000000..916eb1957 --- /dev/null +++ b/docs/changes/newsfragments/298.bugfix @@ -0,0 +1 @@ +Enable and use metadata routing for hyperparameter tuning estimators by `Fede Raimondo`_ \ No newline at end of file diff --git a/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py b/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py index 05cf4cdad..5ad168266 100644 --- a/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py +++ b/examples/03_complex_models/run_hyperparameter_tuning_bayessearch.py @@ -19,6 +19,7 @@ import numpy as np from seaborn import load_dataset +import sklearn from julearn import run_cross_validation from julearn.utils import configure_logging, logger @@ -29,6 +30,10 @@ # Set the logging level to info to see extra information. configure_logging(level="INFO") +############################################################################### +# Disable metadata routing to avoid errors due to BayesSearchCV being used. +sklearn.set_config(enable_metadata_routing=False) + ############################################################################### # Set the random seed to always have the same example. np.random.seed(42) diff --git a/julearn/api.py b/julearn/api.py index b1437a254..39d74907c 100644 --- a/julearn/api.py +++ b/julearn/api.py @@ -4,6 +4,7 @@ # Sami Hamdan # License: AGPL +import inspect from typing import Optional, Union import numpy as np @@ -26,7 +27,10 @@ from .utils.typing import CVLike -def _validata_api_params( # noqa: C901 +sklearn.set_config(enable_metadata_routing=True) + + +def _validate_api_params( # noqa: C901 X: list[str], # noqa: N803 y: str, model: Union[str, PipelineCreator, BaseEstimator, list[PipelineCreator]], @@ -533,7 +537,7 @@ def run_cross_validation( return_inspector, wrap_score, problem_type, - ) = _validata_api_params( + ) = _validate_api_params( X=X, y=y, model=model, @@ -561,7 +565,7 @@ def run_cross_validation( ) logger.info(f"Using outer CV scheme {cv_outer}") - check_consistency(df_y, cv, groups, problem_type) # type: ignore + groups_needed = check_consistency(df_y, cv, groups, problem_type) # type: ignore scoring = check_scoring( pipeline, # type: ignore @@ -570,16 +574,28 @@ def run_cross_validation( ) cv_mdsum = _compute_cvmdsum(cv_outer) - fit_params = {} - if df_groups is not None: - if isinstance(pipeline, BaseSearchCV): - fit_params["groups"] = df_groups.values + fit_params = {} _sklearn_deprec_fit_params = {} - if sklearn.__version__ >= "1.4.0": - _sklearn_deprec_fit_params["params"] = fit_params - else: - _sklearn_deprec_fit_params["fit_params"] = fit_params + + if df_groups is not None: + if groups_needed: + # If we need groups, we have to pass them to the fit method of + # the last step of the pipeline + if not isinstance(pipeline, BaseSearchCV): + last_step = pipeline.steps[-1][1] + argspec = inspect.getfullargspec(last_step.fit) + if "groups" in argspec.args: + logger.debug( + "Pipeline's last step uses groups," \ + "calling `set_fit_request(groups=True)`") + last_step.set_fit_request(groups=True) + fit_params["groups"] = df_groups.values + + if sklearn.__version__ >= "1.4.0": + _sklearn_deprec_fit_params["params"] = fit_params + else: + _sklearn_deprec_fit_params["fit_params"] = fit_params scores = cross_validate( pipeline, @@ -587,7 +603,7 @@ def run_cross_validation( df_y, cv=cv_outer, scoring=scoring, - groups=df_groups, + # groups=df_groups, return_estimator=cv_return_estimator, n_jobs=n_jobs, return_train_score=return_train_score, @@ -776,7 +792,7 @@ def run_fit( _, _, problem_type, - ) = _validata_api_params( + ) = _validate_api_params( X=X, y=y, model=model, diff --git a/julearn/conftest.py b/julearn/conftest.py index 926209bf2..08ea021e0 100644 --- a/julearn/conftest.py +++ b/julearn/conftest.py @@ -19,6 +19,14 @@ } +@pytest.fixture(autouse=True) +def enable_metadata_routing() -> None: + """Enable metadata routing in sklearn for all tests.""" + import sklearn + + sklearn.set_config(enable_metadata_routing=True) + + def pytest_configure(config: pytest.Config) -> None: """Add a new marker to pytest. diff --git a/julearn/model_selection/final_model_cv.py b/julearn/model_selection/final_model_cv.py index 0852abaed..dcd3bae2a 100644 --- a/julearn/model_selection/final_model_cv.py +++ b/julearn/model_selection/final_model_cv.py @@ -4,16 +4,13 @@ # License: AGPL from collections.abc import Generator -from typing import TYPE_CHECKING, Optional +from typing import Optional import numpy as np +from sklearn.model_selection import BaseCrossValidator -if TYPE_CHECKING: - from sklearn.model_selection import BaseCrossValidator - - -class _JulearnFinalModelCV: +class _JulearnFinalModelCV(BaseCrossValidator): """Final model cross-validation iterator. Wraps any CV iterator to provide an extra iteration with the full dataset. @@ -30,6 +27,20 @@ def __init__(self, cv: "BaseCrossValidator") -> None: if hasattr(cv, "n_repeats"): self.n_repeats = cv.n_repeats + def get_metadata_routing(self) -> dict: + """Get metadata routing information from the underlying CV. + + Returns + ------- + dict + The metadata routing information. + + """ + if hasattr(self.cv, "get_metadata_routing"): + return self.cv.get_metadata_routing() + else: + return {} + def split( self, X: np.ndarray, # noqa: N803 diff --git a/julearn/pipeline/pipeline_creator.py b/julearn/pipeline/pipeline_creator.py index 06ce6a5c0..f7b4a0aad 100644 --- a/julearn/pipeline/pipeline_creator.py +++ b/julearn/pipeline/pipeline_creator.py @@ -9,7 +9,9 @@ from typing import Any, Optional, Union import numpy as np +import sklearn from scipy import stats +from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor from sklearn.model_selection import RandomizedSearchCV, check_cv from sklearn.pipeline import Pipeline @@ -417,6 +419,16 @@ def add( # noqa: C901 f"{name}__{param}": val for param, val in params_to_tune.items() } + # Disable metadata routing for AdaBoost-based estimators until + # scikit-learn implements them. + + if isinstance(step, (AdaBoostClassifier, AdaBoostRegressor)): + warn_with_log( + "Disabling metadata routing for AdaBoost-based " + "estimators until scikit-learn implements them." + ) + sklearn.set_config(enable_metadata_routing=False) + self._steps.append( Step( name=name, diff --git a/julearn/prepare.py b/julearn/prepare.py index b575c33bf..b3e917536 100644 --- a/julearn/prepare.py +++ b/julearn/prepare.py @@ -365,7 +365,7 @@ def check_consistency( cv: CVLike, groups: Optional[pd.Series], problem_type: str, -) -> None: +) -> bool: """Check the consistency of the parameters/input. Parameters @@ -379,6 +379,12 @@ def check_consistency( problem_type : str The problem type. Can be "classification" or "regression". + Returns + ------- + groups_needed : bool + True if the groups variable is needed for the CV scheme, + False otherwise. + Raises ------ ValueError @@ -435,21 +441,30 @@ def check_consistency( "The problem type must be either 'classification' or 'regression'." ) # Check groups and CV scheme + groups_needed = False + valid_group_cv_instances = ( + GroupKFold, + GroupShuffleSplit, + LeaveOneGroupOut, + LeavePGroupsOut, + StratifiedGroupKFold, + ContinuousStratifiedGroupKFold, + RepeatedContinuousStratifiedGroupKFold, + ) if groups is not None: - valid_instances = ( - GroupKFold, - GroupShuffleSplit, - LeaveOneGroupOut, - LeavePGroupsOut, - StratifiedGroupKFold, - ContinuousStratifiedGroupKFold, - RepeatedContinuousStratifiedGroupKFold, - ) - if not isinstance(cv, valid_instances): + if not isinstance(cv, valid_group_cv_instances): warn_with_log( "The parameter groups was specified but the CV strategy " "will not consider them." ) + else: + groups_needed = True + elif isinstance(cv, valid_group_cv_instances): + raise_error( + "The CV strategy requires groups but the parameter groups was " + "not specified." + ) + return groups_needed def _check_x_types( diff --git a/julearn/tests/test_api.py b/julearn/tests/test_api.py index cfd6a2110..41273dc97 100644 --- a/julearn/tests/test_api.py +++ b/julearn/tests/test_api.py @@ -549,7 +549,6 @@ def test_tune_hyperparam_gridsearch_groups(df_iris: pd.DataFrame) -> None: sk_y, # type: ignore cv=cv_outer, scoring=[scoring], - groups=sk_groups, params={"groups": sk_groups}, ) diff --git a/julearn/utils/testing.py b/julearn/utils/testing.py index f34422294..8c37db705 100644 --- a/julearn/utils/testing.py +++ b/julearn/utils/testing.py @@ -245,13 +245,17 @@ def do_scoring_test( ) np.random.seed(42) + if sk_groups is not None: + params = {"groups": sk_groups} + else: + params = {} expected = cross_validate( sklearn_model, # type: ignore sk_X, sk_y, cv=sk_cv, scoring=scorers, - groups=sk_groups, # type: ignore + params=params, # type: ignore ) # Compare the models