Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 7 additions & 33 deletions energy_fault_detector/config/quickstart_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import yaml

from .config import Config


def _build_preprocessor_steps(
*,
Expand Down Expand Up @@ -102,35 +104,6 @@ def _build_preprocessor_steps(
return steps


def _dump_yaml_if_requested(
config: Dict[str, Any],
output_path: Optional[Union[str, Path]],
) -> None:
"""
Write the configuration dictionary to a YAML file if a path is provided.

Args:
config (Dict[str, Any]): The configuration dictionary to serialize.
output_path (Optional[Union[str, Path]]): Destination path. If None, nothing is written.

Raises:
RuntimeError: If PyYAML is not installed but output_path is not None.
"""
if output_path is None:
return

if yaml is None: # pragma: no cover - optional dependency
raise RuntimeError(
"PyYAML is not installed; install 'pyyaml' or set output_path=None."
)

path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)

with path.open("w", encoding="utf-8") as f:
yaml.safe_dump(config, f, sort_keys=False)


def generate_quickstart_config(
output_path: Optional[Union[str, Path]] = "base_config.yaml",
*,
Expand All @@ -153,7 +126,7 @@ def generate_quickstart_config(
epochs: int = 10,
layers: Optional[List[int]] = None,
learning_rate: float = 1e-3,
) -> Dict[str, Any]:
) -> Config:
"""
Generate a minimal, valid configuration for EnergyFaultDetector.

Expand Down Expand Up @@ -185,7 +158,7 @@ def generate_quickstart_config(
learning_rate (float): Optimizer learning rate.

Returns:
Dict[str, Any]: Configuration dictionary ready for Config(config_dict=...).
Config: Configuration ready to use: FaultDetector(config).

Raises:
ValueError: If early_stopping is True but validation_split is not in (0, 1).
Expand Down Expand Up @@ -238,9 +211,10 @@ def generate_quickstart_config(
# "data_clipping": {"lower_percentile": 0.001, "upper_percentile": 0.999},
}

config: Dict[str, Any] = {"train": train_config}
config = Config(config_dict={"train": train_config})

# Optionally write YAML
_dump_yaml_if_requested(config=config, output_path=output_path)
if output_path:
config.write_config(output_path)

return config
38 changes: 34 additions & 4 deletions tests/config/test_quickstart_config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from unittest import TestCase
from typing import Any, Dict
from pathlib import Path
import tempfile
import shutil

from energy_fault_detector.config.config import Config
from energy_fault_detector.config.quickstart_config import generate_quickstart_config # adjust
from energy_fault_detector.config.quickstart_config import generate_quickstart_config


class TestQuickstartConfig(TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.test_dir)

def test_generate_quickstart_config_valid_dict(self) -> None:
"""Should return a valid config dict that Config accepts and includes required sections."""
cfg: Dict[str, Any] = generate_quickstart_config(
cfg: Config = generate_quickstart_config(
output_path=None,
angle_columns=["theta_deg"],
counter_columns=["energy_total_kwh"],
Expand Down Expand Up @@ -36,12 +45,33 @@ def test_generate_quickstart_config_valid_dict(self) -> None:
"Expected a scaler step in the pipeline."
)

# Should not raise: validate via Config
Config(config_dict=cfg)
self.assertTrue(
any(n in ("standard_scaler", "minmax_scaler") for n in step_names),
"Expected a scaler step in the pipeline."
)

def test_generate_quickstart_config_validation_split_guard(self) -> None:
"""If validation split not in (0, 1) it should raise ValueError."""
with self.assertRaises(ValueError):
_ = generate_quickstart_config(
validation_split=0.0, # invalid by design
)

def test_save_load_quickstart_config(self) -> None:
"""Test that generated config can be saved and loaded back correctly."""
config_path = Path(self.test_dir) / "config.yaml"
cfg: Config = generate_quickstart_config(
output_path=config_path,
angle_columns=["theta_deg"],
early_stopping=True,
validation_split=0.25
)

self.assertTrue(config_path.exists())

# Load the config back
loaded_cfg = Config(config_path)

# Compare dictionaries
self.assertEqual(cfg.config_dict, loaded_cfg.config_dict)
self.assertEqual(loaded_cfg["train"]["data_splitter"]["validation_split"], 0.25)