Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 49 additions & 33 deletions src/osekit/core_api/frequency_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,67 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np

from osekit.utils.core_utils import get_closest_value_index


@dataclass(frozen=True)
class ScalePart:
"""Represent a part of the frequency scale of a spectrogram.

The given part goes from:
p_min (in % of the axis), representing f_min
to:
p_max (in % of the axis), representing f_max
"""

def __init__(
self,
p_min: float,
p_max: float,
f_min: float,
f_max: float,
scale_type: Literal["lin", "log"] = "lin",
) -> None:
"""Initialize a ScalePart.
p_min: float
Relative position of the bottom of the scale part on the full scale.
Must be in the interval [0.0, 1.0], where 0.0 is the bottom of the scale
and 1.0 is the top.
p_max: float
Relative position of the top of the scale part on the full scale.
Must be in the interval [0.0, 1.0], where 0.0 is the bottom of the scale
and 1.0 is the top.
f_min: float
Frequency corresponding to the bottom of the scale part.
f_max: float
Frequency corresponding to the top of the scale part.
scale_type: Literal["lin", "log"]
Type of the scale, either linear or logarithmic.

Parameters
----------
p_min: float
Position (in percent) of the bottom of the scale part on the full scale.
p_max: float
Position (in percent) of the top of the scale part on the full scale.
f_min: float
Frequency corresponding to the bottom of the scale part.
f_max: float
Frequency corresponding to the top of the scale part.
scale_type: Literal["lin", "log"]
Type of the scale, either linear or logarithmic.
"""

"""
self.p_min = p_min
self.p_max = p_max
self.f_min = f_min
self.f_max = f_max
self.scale_type: Literal["lin", "log"] = scale_type
p_min: float
p_max: float
f_min: float
f_max: float
scale_type: Literal["lin", "log"] = "lin"

def __post_init__(self) -> None:
"""Check if ScalePart values are correct."""
err = []
if not 0.0 <= self.p_min <= 1.0:
err.append(f"p_min must be between 0 and 1, got {self.p_min}")
if not 0.0 <= self.p_max <= 1.0:
err.append(f"p_max must be between 0 and 1, got {self.p_max}")
if self.p_min >= self.p_max:
err.append(
f"p_min must be strictly inferior than p_max, got ({self.p_min},{self.p_max})",
)
if self.f_min < 0:
err.append(
f"f_min must be positive, got {self.f_min}",
)
if self.f_max < 0:
err.append(
f"f_max must be positive, got {self.f_max}",
)
if self.f_min >= self.f_max:
err.append(
f"f_min must be strictly inferior than f_max, got ({self.f_min},{self.f_max})",
)
if err:
msg = "\n".join(err)
raise ValueError(msg)

def get_frequencies(self, nb_points: int) -> list[int]:
"""Return the frequency points of the present scale part."""
Expand Down
111 changes: 111 additions & 0 deletions tests/test_frequency_scales.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextlib

import numpy as np
import pytest

Expand Down Expand Up @@ -721,3 +723,112 @@ def test_frequency_scale_equality(scale1: Scale, scale2: Scale, expected: bool)
)
def test_frequency_scale_serialization(scale: Scale) -> None:
assert Scale.from_dict_value(scale.to_dict_value()) == scale


@pytest.mark.parametrize(
("p_min", "p_max", "f_min", "f_max", "expected"),
[
pytest.param(
-0.5,
1.0,
1.0,
100.0,
pytest.raises(
ValueError,
match="p_min must be between 0 and 1, got -0\\.5",
),
id="negative_min",
),
pytest.param(
5.0,
1.0,
1.0,
100.0,
pytest.raises(
ValueError,
match="p_min must be between 0 and 1, got 5\\.0\n"
"p_min must be strictly inferior than p_max, got \\(5\\.0,1\\.0\\)",
),
id="min_too_big",
),
pytest.param(
0.0,
-1.0,
1.0,
100.0,
pytest.raises(
ValueError,
match="p_max must be between 0 and 1, got -1\\.0\n"
"p_min must be strictly inferior than p_max, got \\(0\\.0,-1\\.0\\)",
),
id="negative_max",
),
pytest.param(
0.0,
2.0,
1.0,
100.0,
pytest.raises(
ValueError,
match=r"p_max must be between 0 and 1, got 2.0",
),
id="max_too_big",
),
pytest.param(
0.5,
0.5,
1.0,
100.0,
pytest.raises(
ValueError,
match="p_min must be strictly inferior than p_max,"
" got \\(0\\.5,0\\.5\\)",
),
id="p_min_equals_p_max",
),
pytest.param(
0.0,
1.0,
-1.0,
100.0,
pytest.raises(
ValueError,
match=r"f_min must be positive, got -1.0",
),
id="negative_f_min",
),
pytest.param(
0.0,
1.0,
0.0,
-100.0,
pytest.raises(
ValueError,
match="f_max must be positive, got -100\\.0\n"
"f_min must be strictly inferior than f_max, got \\(0\\.0,-100\\.0\\)",
),
id="negative_f_max",
),
pytest.param(
0.0,
1.0,
500.0,
500.0,
pytest.raises(
ValueError,
match="f_min must be strictly inferior than f_max,"
" got \\(500\\.0,500\\.0\\)",
),
id="f_min_equals_f_max",
),
],
)
def test_scale_part_errors(
p_min: float,
p_max: float,
f_min: float,
f_max: float,
expected: contextlib.AbstractContextManager,
) -> None:
with expected as e:
assert ScalePart(p_min=p_min, p_max=p_max, f_min=f_min, f_max=f_max) == e