Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 27 additions & 37 deletions openml/setups/setup.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# License: BSD 3-Clause
from __future__ import annotations

from dataclasses import asdict, dataclass
from typing import Any

import openml.config
import openml.flows


@dataclass
class OpenMLSetup:
"""Setup object (a.k.a. Configuration).

Expand All @@ -20,20 +22,20 @@ class OpenMLSetup:
The setting of the parameters
"""

def __init__(self, setup_id: int, flow_id: int, parameters: dict[int, Any] | None):
if not isinstance(setup_id, int):
setup_id: int
flow_id: int
parameters: dict[int, Any] | None

def __post_init__(self) -> None:
if not isinstance(self.setup_id, int):
raise ValueError("setup id should be int")

if not isinstance(flow_id, int):
if not isinstance(self.flow_id, int):
raise ValueError("flow id should be int")

if parameters is not None and not isinstance(parameters, dict):
if self.parameters is not None and not isinstance(self.parameters, dict):
raise ValueError("parameters should be dict")

self.setup_id = setup_id
self.flow_id = flow_id
self.parameters = parameters

def _to_dict(self) -> dict[str, Any]:
return {
"setup_id": self.setup_id,
Expand Down Expand Up @@ -66,6 +68,7 @@ def __repr__(self) -> str:
return header + body


@dataclass
class OpenMLParameter:
"""Parameter object (used in setup).

Expand All @@ -91,37 +94,24 @@ class OpenMLParameter:
If the parameter was set, the value that it was set to.
"""

def __init__( # noqa: PLR0913
self,
input_id: int,
flow_id: int,
flow_name: str,
full_name: str,
parameter_name: str,
data_type: str,
default_value: str,
value: str,
):
self.id = input_id
self.flow_id = flow_id
self.flow_name = flow_name
self.full_name = full_name
self.parameter_name = parameter_name
self.data_type = data_type
self.default_value = default_value
self.value = value
input_id: int
flow_id: int
flow_name: str
full_name: str
parameter_name: str
data_type: str
default_value: str
value: str

def __post_init__(self) -> None:
# Map input_id to id for backward compatibility
self.id = self.input_id

def _to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"flow_id": self.flow_id,
"flow_name": self.flow_name,
"full_name": self.full_name,
"parameter_name": self.parameter_name,
"data_type": self.data_type,
"default_value": self.default_value,
"value": self.value,
}
result = asdict(self)
# Replaces input_id with id for backward compatibility
result["id"] = result.pop("input_id")
return result

def __repr__(self) -> str:
header = "OpenML Parameter"
Expand Down
17 changes: 17 additions & 0 deletions tests/test_setups/test_setup_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _existing_setup_exists(self, classif):
assert setup_id == run.setup_id

@pytest.mark.sklearn()
@pytest.mark.xfail(
reason="Dataset 20 has processing errors on test server - see issue #1544",
raises=openml.exceptions.OpenMLServerException,
)
def test_existing_setup_exists_1(self):
def side_effect(self):
self.var_smoothing = 1e-9
Expand All @@ -97,11 +101,19 @@ def side_effect(self):
self._existing_setup_exists(nb)

@pytest.mark.sklearn()
@pytest.mark.xfail(
reason="Dataset 20 has processing errors on test server - see issue #1544",
raises=openml.exceptions.OpenMLServerException,
)
def test_exisiting_setup_exists_2(self):
# Check a flow with one hyperparameter
self._existing_setup_exists(sklearn.naive_bayes.GaussianNB())

@pytest.mark.sklearn()
@pytest.mark.xfail(
reason="Dataset 20 has processing errors on test server - see issue #1544",
raises=openml.exceptions.OpenMLServerException,
)
def test_existing_setup_exists_3(self):
# Check a flow with many hyperparameters
self._existing_setup_exists(
Expand Down Expand Up @@ -166,6 +178,11 @@ def test_list_setups_output_format(self):
def test_setuplist_offset(self):
size = 10
setups = openml.setups.list_setups(offset=0, size=size)

# Skip if test server has no setup data - see issue #1544
if len(setups) == 0:
pytest.skip("Test server has no setup data available")

assert len(setups) == size
setups2 = openml.setups.list_setups(offset=size, size=size)
assert len(setups2) == size
Expand Down