Skip to content

Commit b205a2a

Browse files
authored
Merge pull request #38 from wwhenxuan/master
whenxuan: update the v0.0.8 version for s2generator
2 parents cdd0b0c + 252e3d8 commit b205a2a

13 files changed

Lines changed: 431 additions & 4 deletions

File tree

s2generator/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "0.0.7"
3+
__version__ = "0.0.8"
44

55
__all__ = [
66
"Node",
@@ -17,6 +17,7 @@
1717
"print_hello",
1818
"excitation",
1919
"simulator",
20+
"augmentation",
2021
"utils",
2122
"params",
2223
]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on 2026/03/02 12:15:45
4+
@author: Whenxuan Wang
5+
@email: wwhenxuan@gmail.com
6+
@url: https://github.com/wwhenxuan/S2Generator
7+
"""
8+
9+
__all__ = ["frequency_perturbation"]
10+
11+
from .frequency_perturbation import frequency_perturbation
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on 2026/03/02 12:16:05
4+
@author: Whenxuan Wang
5+
@email: wwhenxuan@gmail.com
6+
@url: https://github.com/wwhenxuan/S2Generator
7+
"""
8+
9+
import numpy as np
10+
from numpy import fft
11+
12+
13+
def sample_random_perturbation(
14+
K: int, min_alpha: float, max_alpha: float, rng: np.random.RandomState = None
15+
) -> np.ndarray:
16+
"""
17+
Randomly sample K numbers in the interval [-alpha_max, -alpha_min] U [alpha_min, alpha_max]
18+
The purpose of this sampling is to construct random perturbations in the frequency domain.
19+
20+
:param K: Number of random numbers to sample
21+
:param min_alpha: Minimum absolute value of the random numbers
22+
:param max_alpha: Maximum absolute value of the random numbers
23+
:param rng: Optional random number generator, if not provided, the global numpy random number generator will be used
24+
25+
:return: A numpy array containing K random numbers, which are uniformly distributed in the interval [-alpha_max, -alpha_min] U [alpha_min, alpha_max]
26+
"""
27+
28+
# First generate random numbers in [alpha_min, alpha_max]
29+
if rng is not None:
30+
positive_rand = rng.uniform(min_alpha, max_alpha, K)
31+
32+
# Randomly generate sign (-1 or 1)
33+
signs = rng.choice([-1, 1], size=K)
34+
35+
else:
36+
# When the random number generator is not passed in, use the global numpy random number generator
37+
positive_rand = np.random.uniform(min_alpha, max_alpha, K)
38+
signs = np.random.choice([-1, 1], size=K)
39+
40+
# Combine to get the final result
41+
final_random_nums = positive_rand * signs
42+
43+
return final_random_nums
44+
45+
46+
def frequency_perturbation(
47+
series: np.ndarray,
48+
min_alpha: float,
49+
max_alpha: float,
50+
r: float = 0.5,
51+
rng: np.random.RandomState = None,
52+
) -> np.ndarray:
53+
"""
54+
Perform frequency domain perturbation on the input time series.
55+
This method adds random perturbations to the frequency components of the time series,
56+
which can help to enhance the diversity of the data and improve the robustness of models trained on it.
57+
58+
:param series: Input time series, a 1D numpy array
59+
:param min_alpha: Minimum absolute value of the random perturbation added to the frequency components
60+
:param max_alpha: Maximum absolute value of the random perturbation added to the frequency components
61+
:param r: Proportion of frequency components to perturb (default is 0.5, meaning 50% of the frequency components will be perturbed)
62+
:param rng: Optional random number generator, if not provided, the global numpy random number generator will be used.
63+
64+
:return: Perturbed time series, a 1D numpy array of the same length as the input series.
65+
"""
66+
f = fft.rfft(series)
67+
f_perturbed = f.copy()
68+
frequencies = fft.fftfreq(len(series))
69+
70+
# Calculate the number of frequency domain components that can be perturbed
71+
K = int(len(frequencies) * r)
72+
73+
# Sample random perturbations for the real and imaginary parts of the frequency components
74+
alpha_real = sample_random_perturbation(
75+
K=K, min_alpha=min_alpha, max_alpha=max_alpha, rng=rng
76+
)
77+
alpha_imag = sample_random_perturbation(
78+
K=K, min_alpha=min_alpha, max_alpha=max_alpha, rng=rng
79+
)
80+
81+
# Randomly select K frequency domain components for perturbation
82+
indices = np.random.choice(len(f_perturbed), size=K, replace=False)
83+
f_perturbed[indices] += alpha_real + 1j * alpha_imag
84+
85+
# Perform inverse Fourier transform to restore the original time-domain signal
86+
perturbed_series = fft.irfft(f_perturbed).real
87+
88+
return perturbed_series

s2generator/simulator/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
@url: https://github.com/wwhenxuan/S2Generator
77
"""
88

9+
__all__ = ["ARIMASimulator", "WienerFilterSimulator"]
10+
911
from .arima import ARIMASimulator
1012

1113
from .wiener_filter import WienerFilterSimulator

s2generator/simulator/wiener_filter.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
self.R = None
7575

7676
# Noise variance σ²
77-
self.sigma_sq = None
77+
self._sigma_sq = None
7878

7979
# Wiener filter coefficients [filter_order, ]
8080
self._coeffs = None
@@ -128,7 +128,7 @@ def fit(self, time_series: np.ndarray) -> None:
128128
self.R = toeplitz(self.acf_vals[: self.filter_order])
129129

130130
# The filter coefficients and noise variance are obtained by solving the Yule-Walker equation.
131-
self._coeffs, self.sigma_sq = yule_walker(A=self.R)
131+
self._coeffs, self._sigma_sq = yule_walker(A=self.R)
132132

133133
# Initialize noise and calculate the fitted residuals.
134134
# Note that increasing the filter order is necessary to avoid edge effects.
@@ -259,6 +259,42 @@ def check_inputs(self, time_series: np.ndarray) -> np.ndarray:
259259

260260
return time_series
261261

262+
def set_coeffs(self, coeffs: np.ndarray) -> None:
263+
"""
264+
Manually set the Wiener filter coefficients (for testing purposes).
265+
266+
:param coeffs: The Wiener filter coefficients to set, with shape 1D [filter_order, ].
267+
268+
:return: None
269+
"""
270+
assert isinstance(coeffs, np.ndarray), "Coefficients must be a NumPy array."
271+
assert (
272+
len(coeffs) == self.filter_order
273+
), f"Length of coefficients must be equal to filter_order ({self.filter_order})."
274+
self._coeffs = coeffs
275+
276+
def set_sigma_sq(self, sigma_sq: float) -> None:
277+
"""
278+
Manually set the noise variance σ² (for testing purposes).
279+
280+
:param sigma_sq: The noise variance σ² to set, a positive float.
281+
282+
:return: None
283+
"""
284+
# Check if sigma_sq is a numeric value
285+
assert isinstance(
286+
sigma_sq, (int, float, np.ndarray)
287+
), "Noise variance σ² must be a numeric value."
288+
289+
# Check if sigma_sq is a positive value
290+
assert sigma_sq > 0, "Noise variance σ² must be a positive float."
291+
292+
# If sigma_sq is a NumPy array, check if it is a scalar (shape should be ()).
293+
if isinstance(sigma_sq, np.ndarray):
294+
assert sigma_sq.shape == (), "Noise variance σ² must be a scalar value."
295+
296+
self._sigma_sq = np.asarray(sigma_sq, dtype=np.float64)
297+
262298
@property
263299
def coeffs(self) -> np.ndarray:
264300
"""Get the Wiener filter coefficients after fitting the model."""
@@ -267,3 +303,12 @@ def coeffs(self) -> np.ndarray:
267303
"The filter coefficients have not been calculated yet; please call the `fit` method first."
268304
)
269305
return self._coeffs
306+
307+
@property
308+
def sigma_sq(self) -> float:
309+
"""Get the noise variance σ² after fitting the model."""
310+
if self._sigma_sq is None:
311+
raise ValueError(
312+
"The noise variance has not been calculated yet; please call the `fit` method first."
313+
)
314+
return self._sigma_sq

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setuptools.setup(
77
name="S2Generator",
88
packages=setuptools.find_packages(),
9-
version="0.0.7",
9+
version="0.0.8",
1010
description="A series-symbol (S2) dual-modality data generation mechanism, enabling the unrestricted creation of high-quality time series data paired with corresponding symbolic representations.", # 包的简短描述
1111
url="https://github.com/wwhenxuan/S2Generator",
1212
author="whenxuan, johnfan12, changewam",

tests/data/data.npy

21 Bytes
Binary file not shown.

tests/data/data.npz

-8 Bytes
Binary file not shown.

tests/data/s2data.npy

21 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)