Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion s2generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "0.0.8"
__version__ = "0.0.9"

__all__ = [
"Node",
Expand Down
6 changes: 3 additions & 3 deletions s2generator/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__all__ = [
"amplitude_modulation",
"censor_augmentation",
"empirical_model_modulation",
"empirical_mode_modulation",
"frequency_perturbation",
"spike_injection",
"wiener_filter",
Expand All @@ -23,8 +23,8 @@
# Import the censoring augmentation function
from ._censor_augmentation import censor_augmentation

# Import the empirical model modulation function
from ._empirical_model_modulation import empirical_model_modulation
# Import the empirical mode modulation function
from ._empirical_mode_modulation import empirical_mode_modulation

# Import the frequency perturbation function
from ._frequency_perturbation import frequency_perturbation
Expand Down
2 changes: 1 addition & 1 deletion s2generator/augmentation/_amplitude_modulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,4 @@ def amplitude_modulation(
# Apply the modulation trend to the original time series
modulated_series = time_series * modulation_trend

return modulated_series, np.array(modulation_trend)
return modulated_series
6 changes: 3 additions & 3 deletions s2generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray
if xs.ndim > 1:
# For multivariate case, keep other dimensions constant
x_uniform_input = np.tile(np.mean(xs, axis=0), (n_integration_points, 1))
x_uniform_input[
:, 0
] = x_uniform # Replace first dimension with uniform grid
x_uniform_input[:, 0] = (
x_uniform # Replace first dimension with uniform grid
)
else:
x_uniform_input = x_uniform.reshape(-1, 1) # Ensure 2D array for val method

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setuptools.setup(
name="S2Generator",
packages=setuptools.find_packages(),
version="0.0.8",
version="0.0.9",
description="A series-symbol (S2) dual-modality data generation mechanism, enabling the unrestricted creation of high-quality time series data paired with corresponding symbolic representations.", # 包的简短描述
url="https://github.com/wwhenxuan/S2Generator",
author="whenxuan, johnfan12, changewam",
Expand Down
17 changes: 12 additions & 5 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from s2generator.augmentation import (
amplitude_modulation,
censor_augmentation,
empirical_mode_modulation,
frequency_perturbation,
wiener_filter,
add_linear_trend,
time_series_mixup,
)

from s2generator.augmentation._frequency_perturbation import sample_random_perturbation
Expand Down Expand Up @@ -60,7 +64,11 @@ def test_frequency_perturbation(self) -> None:
r = 0.3

perturbed_series = frequency_perturbation(
series=series, min_alpha=min_alpha, max_alpha=max_alpha, r=r, rng=self.rng
time_series=series,
min_alpha=min_alpha,
max_alpha=max_alpha,
r=r,
rng=self.rng,
)

# Check that the output has the same length as the input
Expand Down Expand Up @@ -113,13 +121,12 @@ def test_amplitude_modulation(self) -> None:
t = np.linspace(0, 1, 100)
series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100)

min_modulation = 0.5
max_modulation = 1.5
amplitude_mean, amplitude_variation = 1.0, 1.0

modulated_series = amplitude_modulation(
time_series=series.copy(),
min_modulation=min_modulation,
max_modulation=max_modulation,
amplitude_mean=amplitude_mean,
amplitude_variation=amplitude_variation,
rng=self.rng,
)

Expand Down