diff --git a/src/osekit/core_api/frequency_scale.py b/src/osekit/core_api/frequency_scale.py index b2709afd..6b7423b5 100644 --- a/src/osekit/core_api/frequency_scale.py +++ b/src/osekit/core_api/frequency_scale.py @@ -11,6 +11,7 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Literal import numpy as np @@ -18,44 +19,59 @@ from osekit.utils.core_utils import get_closest_value_index +@dataclass(frozen=True) class ScalePart: """Represent a part of the frequency scale of a spectrogram. - The given part goes from: - p_min (in % of the axis), representing f_min - to: - p_max (in % of the axis), representing f_max - """ - - def __init__( - self, - p_min: float, - p_max: float, - f_min: float, - f_max: float, - scale_type: Literal["lin", "log"] = "lin", - ) -> None: - """Initialize a ScalePart. + p_min: float + Relative position of the bottom of the scale part on the full scale. + Must be in the interval [0.0, 1.0], where 0.0 is the bottom of the scale + and 1.0 is the top. + p_max: float + Relative position of the top of the scale part on the full scale. + Must be in the interval [0.0, 1.0], where 0.0 is the bottom of the scale + and 1.0 is the top. + f_min: float + Frequency corresponding to the bottom of the scale part. + f_max: float + Frequency corresponding to the top of the scale part. + scale_type: Literal["lin", "log"] + Type of the scale, either linear or logarithmic. - Parameters - ---------- - p_min: float - Position (in percent) of the bottom of the scale part on the full scale. - p_max: float - Position (in percent) of the top of the scale part on the full scale. - f_min: float - Frequency corresponding to the bottom of the scale part. - f_max: float - Frequency corresponding to the top of the scale part. - scale_type: Literal["lin", "log"] - Type of the scale, either linear or logarithmic. + """ - """ - self.p_min = p_min - self.p_max = p_max - self.f_min = f_min - self.f_max = f_max - self.scale_type: Literal["lin", "log"] = scale_type + p_min: float + p_max: float + f_min: float + f_max: float + scale_type: Literal["lin", "log"] = "lin" + + def __post_init__(self) -> None: + """Check if ScalePart values are correct.""" + err = [] + if not 0.0 <= self.p_min <= 1.0: + err.append(f"p_min must be between 0 and 1, got {self.p_min}") + if not 0.0 <= self.p_max <= 1.0: + err.append(f"p_max must be between 0 and 1, got {self.p_max}") + if self.p_min >= self.p_max: + err.append( + f"p_min must be strictly inferior than p_max, got ({self.p_min},{self.p_max})", + ) + if self.f_min < 0: + err.append( + f"f_min must be positive, got {self.f_min}", + ) + if self.f_max < 0: + err.append( + f"f_max must be positive, got {self.f_max}", + ) + if self.f_min >= self.f_max: + err.append( + f"f_min must be strictly inferior than f_max, got ({self.f_min},{self.f_max})", + ) + if err: + msg = "\n".join(err) + raise ValueError(msg) def get_frequencies(self, nb_points: int) -> list[int]: """Return the frequency points of the present scale part.""" diff --git a/tests/test_frequency_scales.py b/tests/test_frequency_scales.py index 2cb314bb..ccb56a22 100644 --- a/tests/test_frequency_scales.py +++ b/tests/test_frequency_scales.py @@ -1,3 +1,5 @@ +import contextlib + import numpy as np import pytest @@ -721,3 +723,112 @@ def test_frequency_scale_equality(scale1: Scale, scale2: Scale, expected: bool) ) def test_frequency_scale_serialization(scale: Scale) -> None: assert Scale.from_dict_value(scale.to_dict_value()) == scale + + +@pytest.mark.parametrize( + ("p_min", "p_max", "f_min", "f_max", "expected"), + [ + pytest.param( + -0.5, + 1.0, + 1.0, + 100.0, + pytest.raises( + ValueError, + match="p_min must be between 0 and 1, got -0\\.5", + ), + id="negative_min", + ), + pytest.param( + 5.0, + 1.0, + 1.0, + 100.0, + pytest.raises( + ValueError, + match="p_min must be between 0 and 1, got 5\\.0\n" + "p_min must be strictly inferior than p_max, got \\(5\\.0,1\\.0\\)", + ), + id="min_too_big", + ), + pytest.param( + 0.0, + -1.0, + 1.0, + 100.0, + pytest.raises( + ValueError, + match="p_max must be between 0 and 1, got -1\\.0\n" + "p_min must be strictly inferior than p_max, got \\(0\\.0,-1\\.0\\)", + ), + id="negative_max", + ), + pytest.param( + 0.0, + 2.0, + 1.0, + 100.0, + pytest.raises( + ValueError, + match=r"p_max must be between 0 and 1, got 2.0", + ), + id="max_too_big", + ), + pytest.param( + 0.5, + 0.5, + 1.0, + 100.0, + pytest.raises( + ValueError, + match="p_min must be strictly inferior than p_max," + " got \\(0\\.5,0\\.5\\)", + ), + id="p_min_equals_p_max", + ), + pytest.param( + 0.0, + 1.0, + -1.0, + 100.0, + pytest.raises( + ValueError, + match=r"f_min must be positive, got -1.0", + ), + id="negative_f_min", + ), + pytest.param( + 0.0, + 1.0, + 0.0, + -100.0, + pytest.raises( + ValueError, + match="f_max must be positive, got -100\\.0\n" + "f_min must be strictly inferior than f_max, got \\(0\\.0,-100\\.0\\)", + ), + id="negative_f_max", + ), + pytest.param( + 0.0, + 1.0, + 500.0, + 500.0, + pytest.raises( + ValueError, + match="f_min must be strictly inferior than f_max," + " got \\(500\\.0,500\\.0\\)", + ), + id="f_min_equals_f_max", + ), + ], +) +def test_scale_part_errors( + p_min: float, + p_max: float, + f_min: float, + f_max: float, + expected: contextlib.AbstractContextManager, +) -> None: + with expected as e: + assert ScalePart(p_min=p_min, p_max=p_max, f_min=f_min, f_max=f_max) == e