diff --git a/energy_fault_detector/config/quickstart_config.py b/energy_fault_detector/config/quickstart_config.py index 208187b..ebc3885 100644 --- a/energy_fault_detector/config/quickstart_config.py +++ b/energy_fault_detector/config/quickstart_config.py @@ -3,6 +3,8 @@ import yaml +from .config import Config + def _build_preprocessor_steps( *, @@ -102,35 +104,6 @@ def _build_preprocessor_steps( return steps -def _dump_yaml_if_requested( - config: Dict[str, Any], - output_path: Optional[Union[str, Path]], -) -> None: - """ - Write the configuration dictionary to a YAML file if a path is provided. - - Args: - config (Dict[str, Any]): The configuration dictionary to serialize. - output_path (Optional[Union[str, Path]]): Destination path. If None, nothing is written. - - Raises: - RuntimeError: If PyYAML is not installed but output_path is not None. - """ - if output_path is None: - return - - if yaml is None: # pragma: no cover - optional dependency - raise RuntimeError( - "PyYAML is not installed; install 'pyyaml' or set output_path=None." - ) - - path = Path(output_path) - path.parent.mkdir(parents=True, exist_ok=True) - - with path.open("w", encoding="utf-8") as f: - yaml.safe_dump(config, f, sort_keys=False) - - def generate_quickstart_config( output_path: Optional[Union[str, Path]] = "base_config.yaml", *, @@ -153,7 +126,7 @@ def generate_quickstart_config( epochs: int = 10, layers: Optional[List[int]] = None, learning_rate: float = 1e-3, -) -> Dict[str, Any]: +) -> Config: """ Generate a minimal, valid configuration for EnergyFaultDetector. @@ -185,7 +158,7 @@ def generate_quickstart_config( learning_rate (float): Optimizer learning rate. Returns: - Dict[str, Any]: Configuration dictionary ready for Config(config_dict=...). + Config: Configuration ready to use: FaultDetector(config). Raises: ValueError: If early_stopping is True but validation_split is not in (0, 1). 
@@ -238,9 +211,10 @@ def generate_quickstart_config( # "data_clipping": {"lower_percentile": 0.001, "upper_percentile": 0.999}, } - config: Dict[str, Any] = {"train": train_config} + config = Config(config_dict={"train": train_config}) # Optionally write YAML - _dump_yaml_if_requested(config=config, output_path=output_path) + if output_path: + config.write_config(output_path) return config diff --git a/tests/config/test_quickstart_config.py b/tests/config/test_quickstart_config.py index ae55b30..4e93506 100644 --- a/tests/config/test_quickstart_config.py +++ b/tests/config/test_quickstart_config.py @@ -1,14 +1,23 @@ from unittest import TestCase from typing import Any, Dict +from pathlib import Path +import tempfile +import shutil from energy_fault_detector.config.config import Config -from energy_fault_detector.config.quickstart_config import generate_quickstart_config # adjust +from energy_fault_detector.config.quickstart_config import generate_quickstart_config class TestQuickstartConfig(TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + def test_generate_quickstart_config_valid_dict(self) -> None: """Should return a valid config dict that Config accepts and includes required sections.""" - cfg: Dict[str, Any] = generate_quickstart_config( + cfg: Config = generate_quickstart_config( output_path=None, angle_columns=["theta_deg"], counter_columns=["energy_total_kwh"], @@ -36,8 +45,10 @@ def test_generate_quickstart_config_valid_dict(self) -> None: "Expected a scaler step in the pipeline." ) - # Should not raise: validate via Config - Config(config_dict=cfg) + # Returned object is already a validated Config instance + self.assertIsInstance( + cfg, Config
+ ) def test_generate_quickstart_config_validation_split_guard(self) -> None: """If validation split not in (0, 1) it should raise ValueError.""" @@ -45,3 +56,22 @@ def test_generate_quickstart_config_validation_split_guard(self) -> None: _ = generate_quickstart_config( validation_split=0.0, # invalid by design ) + + def test_save_load_quickstart_config(self) -> None: + """Test that generated config can be saved and loaded back correctly.""" + config_path = Path(self.test_dir) / "config.yaml" + cfg: Config = generate_quickstart_config( + output_path=config_path, + angle_columns=["theta_deg"], + early_stopping=True, + validation_split=0.25 + ) + + self.assertTrue(config_path.exists()) + + # Load the config back + loaded_cfg = Config(config_path) + + # Compare dictionaries + self.assertEqual(cfg.config_dict, loaded_cfg.config_dict) + self.assertEqual(loaded_cfg["train"]["data_splitter"]["validation_split"], 0.25)