Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
3 changes: 2 additions & 1 deletion s2generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "0.0.7"
__version__ = "0.0.8"

__all__ = [
"Node",
Expand All @@ -17,6 +17,7 @@
"print_hello",
"excitation",
"simulator",
"augmentation",
"utils",
"params",
]
Expand Down
11 changes: 11 additions & 0 deletions s2generator/augmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
"""
Created on 2026/03/02 12:15:45
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/S2Generator
"""

__all__ = ["frequency_perturbation"]

from .frequency_perturbation import frequency_perturbation
88 changes: 88 additions & 0 deletions s2generator/augmentation/frequency_perturbation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
"""
Created on 2026/03/02 12:16:05
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/S2Generator
"""

import numpy as np
from numpy import fft


def sample_random_perturbation(
K: int, min_alpha: float, max_alpha: float, rng: np.random.RandomState = None
) -> np.ndarray:
"""
Randomly sample K numbers in the interval [-alpha_max, -alpha_min] U [alpha_min, alpha_max]
The purpose of this sampling is to construct random perturbations in the frequency domain.

:param K: Number of random numbers to sample
:param min_alpha: Minimum absolute value of the random numbers
:param max_alpha: Maximum absolute value of the random numbers
:param rng: Optional random number generator, if not provided, the global numpy random number generator will be used

:return: A numpy array containing K random numbers, which are uniformly distributed in the interval [-alpha_max, -alpha_min] U [alpha_min, alpha_max]
"""

# First generate random numbers in [alpha_min, alpha_max]
if rng is not None:
positive_rand = rng.uniform(min_alpha, max_alpha, K)

# Randomly generate sign (-1 or 1)
signs = rng.choice([-1, 1], size=K)

else:
# When the random number generator is not passed in, use the global numpy random number generator
positive_rand = np.random.uniform(min_alpha, max_alpha, K)
signs = np.random.choice([-1, 1], size=K)

# Combine to get the final result
final_random_nums = positive_rand * signs

return final_random_nums


def frequency_perturbation(
series: np.ndarray,
min_alpha: float,
max_alpha: float,
r: float = 0.5,
rng: np.random.RandomState = None,
) -> np.ndarray:
"""
Perform frequency domain perturbation on the input time series.
This method adds random perturbations to the frequency components of the time series,
which can help to enhance the diversity of the data and improve the robustness of models trained on it.

:param series: Input time series, a 1D numpy array
:param min_alpha: Minimum absolute value of the random perturbation added to the frequency components
:param max_alpha: Maximum absolute value of the random perturbation added to the frequency components
:param r: Proportion of frequency components to perturb (default is 0.5, meaning 50% of the frequency components will be perturbed)
:param rng: Optional random number generator, if not provided, the global numpy random number generator will be used.

:return: Perturbed time series, a 1D numpy array of the same length as the input series.
"""
f = fft.rfft(series)
f_perturbed = f.copy()
frequencies = fft.fftfreq(len(series))

# Calculate the number of frequency domain components that can be perturbed
K = int(len(frequencies) * r)

# Sample random perturbations for the real and imaginary parts of the frequency components
alpha_real = sample_random_perturbation(
K=K, min_alpha=min_alpha, max_alpha=max_alpha, rng=rng
)
alpha_imag = sample_random_perturbation(
K=K, min_alpha=min_alpha, max_alpha=max_alpha, rng=rng
)

# Randomly select K frequency domain components for perturbation
indices = np.random.choice(len(f_perturbed), size=K, replace=False)
f_perturbed[indices] += alpha_real + 1j * alpha_imag

# Perform inverse Fourier transform to restore the original time-domain signal
perturbed_series = fft.irfft(f_perturbed).real

return perturbed_series
2 changes: 2 additions & 0 deletions s2generator/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
@url: https://github.com/wwhenxuan/S2Generator
"""

__all__ = ["ARIMASimulator", "WienerFilterSimulator"]

from .arima import ARIMASimulator

from .wiener_filter import WienerFilterSimulator
49 changes: 47 additions & 2 deletions s2generator/simulator/wiener_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self.R = None

# Noise variance σ²
self.sigma_sq = None
self._sigma_sq = None

# Wiener filter coefficients [filter_order, ]
self._coeffs = None
Expand Down Expand Up @@ -128,7 +128,7 @@ def fit(self, time_series: np.ndarray) -> None:
self.R = toeplitz(self.acf_vals[: self.filter_order])

# The filter coefficients and noise variance are obtained by solving the Yule-Walker equation.
self._coeffs, self.sigma_sq = yule_walker(A=self.R)
self._coeffs, self._sigma_sq = yule_walker(A=self.R)

# Initialize noise and calculate the fitted residuals.
# Note that increasing the filter order is necessary to avoid edge effects.
Expand Down Expand Up @@ -259,6 +259,42 @@ def check_inputs(self, time_series: np.ndarray) -> np.ndarray:

return time_series

def set_coeffs(self, coeffs: np.ndarray) -> None:
"""
Manually set the Wiener filter coefficients (for testing purposes).

:param coeffs: The Wiener filter coefficients to set, with shape 1D [filter_order, ].

:return: None
"""
assert isinstance(coeffs, np.ndarray), "Coefficients must be a NumPy array."
assert (
len(coeffs) == self.filter_order
), f"Length of coefficients must be equal to filter_order ({self.filter_order})."
self._coeffs = coeffs

def set_sigma_sq(self, sigma_sq: float) -> None:
"""
Manually set the noise variance σ² (for testing purposes).

:param sigma_sq: The noise variance σ² to set, a positive float.

:return: None
"""
# Check if sigma_sq is a numeric value
assert isinstance(
sigma_sq, (int, float, np.ndarray)
), "Noise variance σ² must be a numeric value."

# Check if sigma_sq is a positive value
assert sigma_sq > 0, "Noise variance σ² must be a positive float."

# If sigma_sq is a NumPy array, check if it is a scalar (shape should be ()).
if isinstance(sigma_sq, np.ndarray):
assert sigma_sq.shape == (), "Noise variance σ² must be a scalar value."

self._sigma_sq = np.asarray(sigma_sq, dtype=np.float64)

@property
def coeffs(self) -> np.ndarray:
"""Get the Wiener filter coefficients after fitting the model."""
Expand All @@ -267,3 +303,12 @@ def coeffs(self) -> np.ndarray:
"The filter coefficients have not been calculated yet; please call the `fit` method first."
)
return self._coeffs

@property
def sigma_sq(self) -> float:
"""Get the noise variance σ² after fitting the model."""
if self._sigma_sq is None:
raise ValueError(
"The noise variance has not been calculated yet; please call the `fit` method first."
)
return self._sigma_sq
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setuptools.setup(
name="S2Generator",
packages=setuptools.find_packages(),
version="0.0.7",
version="0.0.8",
description="A series-symbol (S2) dual-modality data generation mechanism, enabling the unrestricted creation of high-quality time series data paired with corresponding symbolic representations.", # 包的简短描述
url="https://github.com/wwhenxuan/S2Generator",
author="whenxuan, johnfan12, changewam",
Expand Down
Binary file modified tests/data/data.npy
Binary file not shown.
Binary file modified tests/data/data.npz
Binary file not shown.
Binary file modified tests/data/s2data.npy
Binary file not shown.
Binary file modified tests/data/s2data.npz
Binary file not shown.
76 changes: 76 additions & 0 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
Created on 2026/03/02 16:02:37
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/S2Generator
"""
import unittest

import numpy as np

from s2generator.augmentation import frequency_perturbation
from s2generator.augmentation.frequency_perturbation import sample_random_perturbation


class TestDataAugmentation(unittest.TestCase):
"""Testing the data augmentation module for time series data"""

# Random number generator for testing
rng = np.random.RandomState(42)

def test_sample_random_perturbation(self) -> None:
"""Test the function for sampling random perturbations in the frequency domain"""
K = 10
min_alpha = 0.1
max_alpha = 0.5

random_perturbations = sample_random_perturbation(
K=K, min_alpha=min_alpha, max_alpha=max_alpha, rng=self.rng
)

# Check the length of the output
self.assertEqual(
len(random_perturbations),
K,
msg="Wrong length of random perturbations in `test_sample_random_perturbation` method",
)

# Check the value range of the output
for alpha in random_perturbations:
self.assertTrue(
(alpha >= min_alpha and alpha <= max_alpha)
or (alpha <= -min_alpha and alpha >= -max_alpha),
msg="Random perturbation value out of range in `test_sample_random_perturbation` method",
)

def test_frequency_perturbation(self) -> None:
"""Test the function for performing frequency domain perturbation on time series data"""
# Generate a simple time series for testing
t = np.linspace(0, 1, 100)
series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100)

min_alpha = 0.1
max_alpha = 0.5
r = 0.3

perturbed_series = frequency_perturbation(
series=series, min_alpha=min_alpha, max_alpha=max_alpha, r=r, rng=self.rng
)

# Check that the output has the same length as the input
self.assertEqual(
len(perturbed_series),
len(series),
msg="Output length does not match input length in `test_frequency_perturbation` method",
)

# Check that the output is different from the input (since we added perturbations)
self.assertFalse(
np.array_equal(perturbed_series, series),
msg="Perturbed series is identical to original series in `test_frequency_perturbation` method",
)


if __name__ == "__main__":
unittest.main()
Loading