Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/newsfragments/298.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable and use metadata routing for hyperparameter tuning estimators by `Fede Raimondo`_
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np
from seaborn import load_dataset
import sklearn

from julearn import run_cross_validation
from julearn.utils import configure_logging, logger
Expand All @@ -29,6 +30,10 @@
# Set the logging level to info to see extra information.
configure_logging(level="INFO")

###############################################################################
# Disable metadata routing to avoid errors due to BayesSearchCV being used.
sklearn.set_config(enable_metadata_routing=False)

###############################################################################
# Set the random seed to always have the same example.
np.random.seed(42)
Expand Down
42 changes: 29 additions & 13 deletions julearn/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Sami Hamdan <s.hamdan@fz-juelich.de>
# License: AGPL

import inspect
from typing import Optional, Union

import numpy as np
Expand All @@ -26,7 +27,10 @@
from .utils.typing import CVLike


def _validata_api_params( # noqa: C901
sklearn.set_config(enable_metadata_routing=True)


def _validate_api_params( # noqa: C901
X: list[str], # noqa: N803
y: str,
model: Union[str, PipelineCreator, BaseEstimator, list[PipelineCreator]],
Expand Down Expand Up @@ -533,7 +537,7 @@ def run_cross_validation(
return_inspector,
wrap_score,
problem_type,
) = _validata_api_params(
) = _validate_api_params(
X=X,
y=y,
model=model,
Expand Down Expand Up @@ -561,7 +565,7 @@ def run_cross_validation(
)
logger.info(f"Using outer CV scheme {cv_outer}")

check_consistency(df_y, cv, groups, problem_type) # type: ignore
groups_needed = check_consistency(df_y, cv, groups, problem_type) # type: ignore

scoring = check_scoring(
pipeline, # type: ignore
Expand All @@ -570,24 +574,36 @@ def run_cross_validation(
)

cv_mdsum = _compute_cvmdsum(cv_outer)
fit_params = {}
if df_groups is not None:
if isinstance(pipeline, BaseSearchCV):
fit_params["groups"] = df_groups.values

fit_params = {}
_sklearn_deprec_fit_params = {}
if sklearn.__version__ >= "1.4.0":
_sklearn_deprec_fit_params["params"] = fit_params
else:
_sklearn_deprec_fit_params["fit_params"] = fit_params

if df_groups is not None:
if groups_needed:
# If we need groups, we have to pass them to the fit method of
# the last step of the pipeline
if not isinstance(pipeline, BaseSearchCV):
last_step = pipeline.steps[-1][1]
argspec = inspect.getfullargspec(last_step.fit)
if "groups" in argspec.args:
logger.debug(
"Pipeline's last step uses groups," \
"calling `set_fit_request(groups=True)`")
last_step.set_fit_request(groups=True)
fit_params["groups"] = df_groups.values

if sklearn.__version__ >= "1.4.0":
_sklearn_deprec_fit_params["params"] = fit_params
else:
_sklearn_deprec_fit_params["fit_params"] = fit_params

scores = cross_validate(
pipeline,
df_X,
df_y,
cv=cv_outer,
scoring=scoring,
groups=df_groups,
# groups=df_groups,
return_estimator=cv_return_estimator,
n_jobs=n_jobs,
return_train_score=return_train_score,
Expand Down Expand Up @@ -776,7 +792,7 @@ def run_fit(
_,
_,
problem_type,
) = _validata_api_params(
) = _validate_api_params(
X=X,
y=y,
model=model,
Expand Down
8 changes: 8 additions & 0 deletions julearn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
}


@pytest.fixture(autouse=True)
def enable_metadata_routing() -> None:
    """Automatically turn on sklearn's metadata routing for every test.

    Marked ``autouse`` so each test in the suite runs with
    ``enable_metadata_routing=True`` without having to request the
    fixture explicitly.
    """
    import sklearn

    sklearn.set_config(enable_metadata_routing=True)


def pytest_configure(config: pytest.Config) -> None:
"""Add a new marker to pytest.

Expand Down
23 changes: 17 additions & 6 deletions julearn/model_selection/final_model_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
# License: AGPL

from collections.abc import Generator
from typing import TYPE_CHECKING, Optional
from typing import Optional

import numpy as np
from sklearn.model_selection import BaseCrossValidator


if TYPE_CHECKING:
from sklearn.model_selection import BaseCrossValidator


class _JulearnFinalModelCV:
class _JulearnFinalModelCV(BaseCrossValidator):
"""Final model cross-validation iterator.

Wraps any CV iterator to provide an extra iteration with the full dataset.
Expand All @@ -30,6 +27,20 @@ def __init__(self, cv: "BaseCrossValidator") -> None:
if hasattr(cv, "n_repeats"):
self.n_repeats = cv.n_repeats

def get_metadata_routing(self) -> dict:
    """Return the metadata routing information of the wrapped CV.

    Delegates to the underlying cross-validator when it implements
    ``get_metadata_routing``.

    Returns
    -------
    dict
        The metadata routing information reported by ``self.cv``, or an
        empty dict when the wrapped CV does not expose any.

    """
    # Guard clause: older/plain CV objects may not implement the
    # metadata-routing API at all.
    wrapped_cv = self.cv
    if not hasattr(wrapped_cv, "get_metadata_routing"):
        return {}
    return wrapped_cv.get_metadata_routing()

def split(
self,
X: np.ndarray, # noqa: N803
Expand Down
12 changes: 12 additions & 0 deletions julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from typing import Any, Optional, Union

import numpy as np
import sklearn
from scipy import stats
from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor
from sklearn.model_selection import RandomizedSearchCV, check_cv
from sklearn.pipeline import Pipeline

Expand Down Expand Up @@ -417,6 +419,16 @@ def add( # noqa: C901
f"{name}__{param}": val for param, val in params_to_tune.items()
}

# Disable metadata routing for AdaBoost-based estimators until
# scikit-learn implements them.

if isinstance(step, (AdaBoostClassifier, AdaBoostRegressor)):
warn_with_log(
"Disabling metadata routing for AdaBoost-based "
"estimators until scikit-learn implements them."
)
sklearn.set_config(enable_metadata_routing=False)

self._steps.append(
Step(
name=name,
Expand Down
37 changes: 26 additions & 11 deletions julearn/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def check_consistency(
cv: CVLike,
groups: Optional[pd.Series],
problem_type: str,
) -> None:
) -> bool:
"""Check the consistency of the parameters/input.

Parameters
Expand All @@ -379,6 +379,12 @@ def check_consistency(
problem_type : str
The problem type. Can be "classification" or "regression".

Returns
-------
groups_needed : bool
True if the groups variable is needed for the CV scheme,
False otherwise.

Raises
------
ValueError
Expand Down Expand Up @@ -435,21 +441,30 @@ def check_consistency(
"The problem type must be either 'classification' or 'regression'."
)
# Check groups and CV scheme
groups_needed = False
valid_group_cv_instances = (
GroupKFold,
GroupShuffleSplit,
LeaveOneGroupOut,
LeavePGroupsOut,
StratifiedGroupKFold,
ContinuousStratifiedGroupKFold,
RepeatedContinuousStratifiedGroupKFold,
)
if groups is not None:
valid_instances = (
GroupKFold,
GroupShuffleSplit,
LeaveOneGroupOut,
LeavePGroupsOut,
StratifiedGroupKFold,
ContinuousStratifiedGroupKFold,
RepeatedContinuousStratifiedGroupKFold,
)
if not isinstance(cv, valid_instances):
if not isinstance(cv, valid_group_cv_instances):
warn_with_log(
"The parameter groups was specified but the CV strategy "
"will not consider them."
)
else:
groups_needed = True
elif isinstance(cv, valid_group_cv_instances):
raise_error(
"The CV strategy requires groups but the parameter groups was "
"not specified."
)
return groups_needed


def _check_x_types(
Expand Down
1 change: 0 additions & 1 deletion julearn/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,6 @@ def test_tune_hyperparam_gridsearch_groups(df_iris: pd.DataFrame) -> None:
sk_y, # type: ignore
cv=cv_outer,
scoring=[scoring],
groups=sk_groups,
params={"groups": sk_groups},
)

Expand Down
6 changes: 5 additions & 1 deletion julearn/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,17 @@ def do_scoring_test(
)

np.random.seed(42)
if sk_groups is not None:
params = {"groups": sk_groups}
else:
params = {}
expected = cross_validate(
sklearn_model, # type: ignore
sk_X,
sk_y,
cv=sk_cv,
scoring=scorers,
groups=sk_groups, # type: ignore
params=params, # type: ignore
)

# Compare the models
Expand Down