From 9c069f11c838014a0bec625910116b109f4a47e7 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Wed, 4 Mar 2026 22:54:39 +0800 Subject: [PATCH 01/14] whenxuan: black init --- .../augmentation/amplitude_modulation.py | 29 +++++++++++++++++++ .../augmentation/frequency_perturbation.py | 8 ++--- s2generator/base.py | 6 ++-- s2generator/utils/visualization.py | 1 - tests/test_wiener_filter_simulator.py | 1 - 5 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 s2generator/augmentation/amplitude_modulation.py diff --git a/s2generator/augmentation/amplitude_modulation.py b/s2generator/augmentation/amplitude_modulation.py new file mode 100644 index 0000000..e9103ca --- /dev/null +++ b/s2generator/augmentation/amplitude_modulation.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/04 22:52:40 +@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" + +import numpy as np + + +def amplitude_modulation( + time_series: np.ndarray, +) -> np.ndarray: + """ + Perform amplitude modulation on the input time series. + This method applies a random amplitude modulation to the time series, which can help to enhance the diversity of the data and improve the robustness of models trained on it. + + :param time_series: Input time series, a 1D numpy array + + :return: Amplitude modulated time series, a 1D numpy array of the same length as the input series. 
+ """ + # Generate a random amplitude modulation signal + modulation_signal = np.random.uniform(0.5, 1.5, size=len(time_series)) + + # Apply amplitude modulation to the input time series + modulated_series = time_series * modulation_signal + + return modulated_series diff --git a/s2generator/augmentation/frequency_perturbation.py b/s2generator/augmentation/frequency_perturbation.py index 34fb22a..9347000 100644 --- a/s2generator/augmentation/frequency_perturbation.py +++ b/s2generator/augmentation/frequency_perturbation.py @@ -44,7 +44,7 @@ def sample_random_perturbation( def frequency_perturbation( - series: np.ndarray, + time_series: np.ndarray, min_alpha: float, max_alpha: float, r: float = 0.5, @@ -55,7 +55,7 @@ def frequency_perturbation( This method adds random perturbations to the frequency components of the time series, which can help to enhance the diversity of the data and improve the robustness of models trained on it. - :param series: Input time series, a 1D numpy array + :param time_series: Input time series, a 1D numpy array :param min_alpha: Minimum absolute value of the random perturbation added to the frequency components :param max_alpha: Maximum absolute value of the random perturbation added to the frequency components :param r: Proportion of frequency components to perturb (default is 0.5, meaning 50% of the frequency components will be perturbed) @@ -63,9 +63,9 @@ def frequency_perturbation( :return: Perturbed time series, a 1D numpy array of the same length as the input series. 
""" - f = fft.rfft(series) + f = fft.rfft(time_series) f_perturbed = f.copy() - frequencies = fft.fftfreq(len(series)) + frequencies = fft.fftfreq(len(time_series)) # Calculate the number of frequency domain components that can be perturbed K = int(len(frequencies) * r) diff --git a/s2generator/base.py b/s2generator/base.py index 69586b8..8c78568 100644 --- a/s2generator/base.py +++ b/s2generator/base.py @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray if xs.ndim > 1: # For multivariate case, keep other dimensions constant x_uniform_input = np.tile(np.mean(xs, axis=0), (n_integration_points, 1)) - x_uniform_input[:, 0] = ( - x_uniform # Replace first dimension with uniform grid - ) + x_uniform_input[ + :, 0 + ] = x_uniform # Replace first dimension with uniform grid else: x_uniform_input = x_uniform.reshape(-1, 1) # Ensure 2D array for val method diff --git a/s2generator/utils/visualization.py b/s2generator/utils/visualization.py index a589d20..6e0315e 100644 --- a/s2generator/utils/visualization.py +++ b/s2generator/utils/visualization.py @@ -426,7 +426,6 @@ def plot_simulator_statistics( # Plot the histogram of residuals and the Q-Q plot for the normality test. 
if residuals is not None: - from statsmodels.graphics.gofplots import qqplot from scipy.stats import shapiro, norm diff --git a/tests/test_wiener_filter_simulator.py b/tests/test_wiener_filter_simulator.py index 5e0c298..3682522 100644 --- a/tests/test_wiener_filter_simulator.py +++ b/tests/test_wiener_filter_simulator.py @@ -31,7 +31,6 @@ def test_create_instance(self) -> None: revin=revin, random_state=random_state, ): - # Create an instance of WienerFilterSimulator with the specified parameters simulator = WienerFilterSimulator( filter_order=filter_order, From d9f10557bc61b4805899bdd62248a44a9aa9ffd8 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Wed, 4 Mar 2026 23:34:08 +0800 Subject: [PATCH 02/14] whenxuan: update the interpolate methods --- .../augmentation/amplitude_modulation.py | 82 +++++++++++++++++-- 1 file changed, 75 insertions(+), 7 deletions(-) diff --git a/s2generator/augmentation/amplitude_modulation.py b/s2generator/augmentation/amplitude_modulation.py index e9103ca..e0bc88e 100644 --- a/s2generator/augmentation/amplitude_modulation.py +++ b/s2generator/augmentation/amplitude_modulation.py @@ -7,10 +7,85 @@ """ import numpy as np +from scipy.interpolate import interp1d, lagrange # 改用lagrange函数 + + +def linear_interpolation( + x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray +) -> np.ndarray: + """ + 线性插值函数 + :param x_known: 已知离散点的x坐标(时间点),numpy数组 + :param y_known: 已知离散点的y坐标,numpy数组 + :param x_new: 需要插值的新x坐标,numpy数组或单个数值 + :return: 插值后的y_new值,与x_new形状相同 + """ + # 输入校验 + if len(x_known) != len(y_known): + raise ValueError("x_known和y_known的长度必须相等") + if np.any(np.diff(x_known) <= 0): + raise ValueError("x_known必须是严格递增的序列") + + # 创建线性插值器 + linear_interp = interp1d(x_known, y_known, kind="linear", fill_value="extrapolate") + # 计算插值结果 + y_new = linear_interp(x_new) + return y_new + + +def cubic_spline_interpolation( + x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray +) -> np.ndarray: + """ + 三次样条插值函数 + :param x_known: 
已知离散点的x坐标(时间点),numpy数组 + :param y_known: 已知离散点的y坐标,numpy数组 + :param x_new: 需要插值的新x坐标,numpy数组或单个数值 + :return: 插值后的y_new值,与x_new形状相同 + """ + # 输入校验 + if len(x_known) != len(y_known): + raise ValueError("x_known和y_known的长度必须相等") + if len(x_known) < 3: + raise ValueError("三次样条插值需要至少3个已知点") + if np.any(np.diff(x_known) <= 0): + raise ValueError("x_known必须是严格递增的序列") + + # 创建三次样条插值器 + spline_interp = interp1d(x_known, y_known, kind="cubic", fill_value="extrapolate") + # 计算插值结果 + y_new = spline_interp(x_new) + return y_new + + +def lagrange_interpolation(x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray): + """ + 拉格朗日插值函数(兼容所有scipy版本) + :param x_known: 已知离散点的x坐标(时间点),numpy数组 + :param y_known: 已知离散点的y坐标,numpy数组 + :param x_new: 需要插值的新x坐标,numpy数组或单个数值 + :return: 插值后的y_new值,与x_new形状相同 + """ + # 输入校验 + if len(x_known) != len(y_known): + raise ValueError("x_known和y_known的长度必须相等") + if len(x_known) < 2: + raise ValueError("拉格朗日插值需要至少2个已知点") + if len(np.unique(x_known)) != len(x_known): + raise ValueError("x_known中不能有重复的坐标点") + + # 创建拉格朗日插值多项式(兼容旧版本scipy) + lagrange_poly = lagrange(x_known, y_known) + # 计算插值结果(polyval支持单个值或数组输入) + y_new = np.polyval(lagrange_poly, x_new) + return y_new def amplitude_modulation( time_series: np.ndarray, + num_changepoints: int = 5, + mean_amplitude: float = 1.0, + amplitude_variation: float = 1.0, ) -> np.ndarray: """ Perform amplitude modulation on the input time series. @@ -20,10 +95,3 @@ def amplitude_modulation( :return: Amplitude modulated time series, a 1D numpy array of the same length as the input series. 
""" - # Generate a random amplitude modulation signal - modulation_signal = np.random.uniform(0.5, 1.5, size=len(time_series)) - - # Apply amplitude modulation to the input time series - modulated_series = time_series * modulation_signal - - return modulated_series From 3975305c95d7cf96c37a74b94bd24b3d0a224327 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Wed, 4 Mar 2026 23:36:58 +0800 Subject: [PATCH 03/14] whenxuan: remove the interpolate to tools --- .../augmentation/amplitude_modulation.py | 72 ------------------ s2generator/utils/__init__.py | 10 +++ s2generator/utils/_tools.py | 76 +++++++++++++++++++ 3 files changed, 86 insertions(+), 72 deletions(-) diff --git a/s2generator/augmentation/amplitude_modulation.py b/s2generator/augmentation/amplitude_modulation.py index e0bc88e..1eafb44 100644 --- a/s2generator/augmentation/amplitude_modulation.py +++ b/s2generator/augmentation/amplitude_modulation.py @@ -7,78 +7,6 @@ """ import numpy as np -from scipy.interpolate import interp1d, lagrange # 改用lagrange函数 - - -def linear_interpolation( - x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray -) -> np.ndarray: - """ - 线性插值函数 - :param x_known: 已知离散点的x坐标(时间点),numpy数组 - :param y_known: 已知离散点的y坐标,numpy数组 - :param x_new: 需要插值的新x坐标,numpy数组或单个数值 - :return: 插值后的y_new值,与x_new形状相同 - """ - # 输入校验 - if len(x_known) != len(y_known): - raise ValueError("x_known和y_known的长度必须相等") - if np.any(np.diff(x_known) <= 0): - raise ValueError("x_known必须是严格递增的序列") - - # 创建线性插值器 - linear_interp = interp1d(x_known, y_known, kind="linear", fill_value="extrapolate") - # 计算插值结果 - y_new = linear_interp(x_new) - return y_new - - -def cubic_spline_interpolation( - x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray -) -> np.ndarray: - """ - 三次样条插值函数 - :param x_known: 已知离散点的x坐标(时间点),numpy数组 - :param y_known: 已知离散点的y坐标,numpy数组 - :param x_new: 需要插值的新x坐标,numpy数组或单个数值 - :return: 插值后的y_new值,与x_new形状相同 - """ - # 输入校验 - if len(x_known) != len(y_known): - raise 
ValueError("x_known和y_known的长度必须相等") - if len(x_known) < 3: - raise ValueError("三次样条插值需要至少3个已知点") - if np.any(np.diff(x_known) <= 0): - raise ValueError("x_known必须是严格递增的序列") - - # 创建三次样条插值器 - spline_interp = interp1d(x_known, y_known, kind="cubic", fill_value="extrapolate") - # 计算插值结果 - y_new = spline_interp(x_new) - return y_new - - -def lagrange_interpolation(x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray): - """ - 拉格朗日插值函数(兼容所有scipy版本) - :param x_known: 已知离散点的x坐标(时间点),numpy数组 - :param y_known: 已知离散点的y坐标,numpy数组 - :param x_new: 需要插值的新x坐标,numpy数组或单个数值 - :return: 插值后的y_new值,与x_new形状相同 - """ - # 输入校验 - if len(x_known) != len(y_known): - raise ValueError("x_known和y_known的长度必须相等") - if len(x_known) < 2: - raise ValueError("拉格朗日插值需要至少2个已知点") - if len(np.unique(x_known)) != len(x_known): - raise ValueError("x_known中不能有重复的坐标点") - - # 创建拉格朗日插值多项式(兼容旧版本scipy) - lagrange_poly = lagrange(x_known, y_known) - # 计算插值结果(polyval支持单个值或数组输入) - y_new = np.polyval(lagrange_poly, x_new) - return y_new def amplitude_modulation( diff --git a/s2generator/utils/__init__.py b/s2generator/utils/__init__.py index 2338118..70bda06 100644 --- a/s2generator/utils/__init__.py +++ b/s2generator/utils/__init__.py @@ -22,6 +22,9 @@ "generate_nonstationary_sine", "eacf_rlike", "yule_walker", + "linear_interpolation", + "cubic_spline_interpolation", + "lagrange_interpolation", "fft", "fftshift", "ifft", @@ -73,6 +76,13 @@ # The Yule-Walker method to estimate the parameters of AR model from ._tools import yule_walker +# The interpolation methods for time series data, including linear interpolation, cubic spline interpolation, and Lagrange interpolation +from ._tools import ( + linear_interpolation, + cubic_spline_interpolation, + lagrange_interpolation, +) + # Print the Generation Status from ._print_status import PrintStatus diff --git a/s2generator/utils/_tools.py b/s2generator/utils/_tools.py index a3ce6a2..fe54a9c 100644 --- a/s2generator/utils/_tools.py +++ 
b/s2generator/utils/_tools.py @@ -31,6 +31,9 @@ "generate_nonstationary_sine", "eacf_rlike", "yule_walker", + "linear_interpolation", + "cubic_spline_interpolation", + "lagrange_interpolation", ] import os @@ -40,6 +43,8 @@ from numpy import bool_ from numpy import fft as np_fft +from scipy.interpolate import interp1d, lagrange # 改用lagrange函数 + import pandas as pd from typing import Optional, Dict, Union, Tuple @@ -587,3 +592,74 @@ def yule_walker(A: np.ndarray) -> Tuple[np.ndarray, Union[float, np.ndarray]]: sigma_sq = np.dot(A[0], x) return x, sigma_sq + + +def linear_interpolation( + x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray +) -> np.ndarray: + """ + 线性插值函数 + :param x_known: 已知离散点的x坐标(时间点),numpy数组 + :param y_known: 已知离散点的y坐标,numpy数组 + :param x_new: 需要插值的新x坐标,numpy数组或单个数值 + :return: 插值后的y_new值,与x_new形状相同 + """ + # 输入校验 + if len(x_known) != len(y_known): + raise ValueError("x_known和y_known的长度必须相等") + if np.any(np.diff(x_known) <= 0): + raise ValueError("x_known必须是严格递增的序列") + + # 创建线性插值器 + linear_interp = interp1d(x_known, y_known, kind="linear", fill_value="extrapolate") + # 计算插值结果 + y_new = linear_interp(x_new) + return y_new + + +def cubic_spline_interpolation( + x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray +) -> np.ndarray: + """ + 三次样条插值函数 + :param x_known: 已知离散点的x坐标(时间点),numpy数组 + :param y_known: 已知离散点的y坐标,numpy数组 + :param x_new: 需要插值的新x坐标,numpy数组或单个数值 + :return: 插值后的y_new值,与x_new形状相同 + """ + # 输入校验 + if len(x_known) != len(y_known): + raise ValueError("x_known和y_known的长度必须相等") + if len(x_known) < 3: + raise ValueError("三次样条插值需要至少3个已知点") + if np.any(np.diff(x_known) <= 0): + raise ValueError("x_known必须是严格递增的序列") + + # 创建三次样条插值器 + spline_interp = interp1d(x_known, y_known, kind="cubic", fill_value="extrapolate") + # 计算插值结果 + y_new = spline_interp(x_new) + return y_new + + +def lagrange_interpolation(x_known: np.ndarray, y_known: np.ndarray, x_new: np.ndarray): + """ + 拉格朗日插值函数(兼容所有scipy版本) + :param x_known: 
已知离散点的x坐标(时间点),numpy数组 + :param y_known: 已知离散点的y坐标,numpy数组 + :param x_new: 需要插值的新x坐标,numpy数组或单个数值 + :return: 插值后的y_new值,与x_new形状相同 + """ + # 输入校验 + if len(x_known) != len(y_known): + raise ValueError("x_known和y_known的长度必须相等") + if len(x_known) < 2: + raise ValueError("拉格朗日插值需要至少2个已知点") + if len(np.unique(x_known)) != len(x_known): + raise ValueError("x_known中不能有重复的坐标点") + + # 创建拉格朗日插值多项式(兼容旧版本scipy) + lagrange_poly = lagrange(x_known, y_known) + # 计算插值结果(polyval支持单个值或数组输入) + y_new = np.polyval(lagrange_poly, x_new) + return y_new From dff0f7d35124c2c493e98ee14ef219e0164fd75e Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Wed, 4 Mar 2026 23:40:55 +0800 Subject: [PATCH 04/14] whenxuan: add the input validate for frequency perturbation --- s2generator/augmentation/amplitude_modulation.py | 16 +++++++++++++++- .../augmentation/frequency_perturbation.py | 16 ++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/s2generator/augmentation/amplitude_modulation.py b/s2generator/augmentation/amplitude_modulation.py index 1eafb44..2b5697c 100644 --- a/s2generator/augmentation/amplitude_modulation.py +++ b/s2generator/augmentation/amplitude_modulation.py @@ -8,18 +8,32 @@ import numpy as np +from s2generator.utils._tools import ( + linear_interpolation, + cubic_spline_interpolation, + lagrange_interpolation, +) + def amplitude_modulation( time_series: np.ndarray, num_changepoints: int = 5, mean_amplitude: float = 1.0, amplitude_variation: float = 1.0, + interpolation_method: str = "linear", ) -> np.ndarray: """ Perform amplitude modulation on the input time series. - This method applies a random amplitude modulation to the time series, which can help to enhance the diversity of the data and improve the robustness of models trained on it. + This augmentation introduces scale trends and change points into the time series + by multiplying the signal with a piecewise linear trend. 
+ The modulation trend is generated by sampling change points and interpolating amplitudes between them. :param time_series: Input time series, a 1D numpy array + :param num_changepoints: Number of change points to introduce in the modulation trend. + :param mean_amplitude: The mean amplitude of the modulation trend. + :param amplitude_variation: The variation of the amplitude around the mean. + :param interpolation_method: The method to interpolate the modulation trend. + Options are "linear", "cubic", or "lagrange". :return: Amplitude modulated time series, a 1D numpy array of the same length as the input series. """ diff --git a/s2generator/augmentation/frequency_perturbation.py b/s2generator/augmentation/frequency_perturbation.py index 9347000..92a66d9 100644 --- a/s2generator/augmentation/frequency_perturbation.py +++ b/s2generator/augmentation/frequency_perturbation.py @@ -63,6 +63,22 @@ def frequency_perturbation( :return: Perturbed time series, a 1D numpy array of the same length as the input series. """ + # Validate the input parameters + assert 0 <= r <= 1, "The proportion r must be between 0 and 1." + assert min_alpha >= 0, "min_alpha must be non-negative." + assert ( + max_alpha >= min_alpha + ), "max_alpha must be greater than or equal to min_alpha." 
+ + # Validate that the input time series is ndarray + if isinstance(time_series, list): + time_series = np.array(time_series) + + # Validate that the input time series is 1D + if time_series.ndim != 1: + raise ValueError("Input time series must be a 1D array.") + + # Perform Fast Fourier Transform to convert the time series to the frequency domain f = fft.rfft(time_series) f_perturbed = f.copy() frequencies = fft.fftfreq(len(time_series)) From 4d705ca3f2af866e97aab1036a2090dba8a43212 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 00:40:39 +0800 Subject: [PATCH 05/14] whenxuan: update the amplitude modulation for data augmentation --- .../augmentation/amplitude_modulation.py | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/s2generator/augmentation/amplitude_modulation.py b/s2generator/augmentation/amplitude_modulation.py index 2b5697c..e35142d 100644 --- a/s2generator/augmentation/amplitude_modulation.py +++ b/s2generator/augmentation/amplitude_modulation.py @@ -18,9 +18,11 @@ def amplitude_modulation( time_series: np.ndarray, num_changepoints: int = 5, - mean_amplitude: float = 1.0, + amplitude_mean: float = 1.0, amplitude_variation: float = 1.0, interpolation_method: str = "linear", + rng: np.random.RandomState = None, + seed: int = 42, ) -> np.ndarray: """ Perform amplitude modulation on the input time series. @@ -30,10 +32,72 @@ def amplitude_modulation( :param time_series: Input time series, a 1D numpy array :param num_changepoints: Number of change points to introduce in the modulation trend. - :param mean_amplitude: The mean amplitude of the modulation trend. + :param amplitude_mean: The mean amplitude of the modulation trend. :param amplitude_variation: The variation of the amplitude around the mean. :param interpolation_method: The method to interpolate the modulation trend. Options are "linear", "cubic", or "lagrange". + :param rng: Optional random number generator for reproducibility. 
If None, a new RNG will be created using the provided seed. + :param seed: Random seed for reproducibility if rng is not provided. :return: Amplitude modulated time series, a 1D numpy array of the same length as the input series. """ + # Validate the input time series + time_series = np.asarray(time_series) + if time_series.ndim != 1: + raise ValueError("Input time_series must be a 1D array.") + + # Validate interpolation method + if interpolation_method not in ["linear", "cubic", "lagrange"]: + raise ValueError( + "interpolation_method must be one of 'linear', 'cubic', or 'lagrange'." + ) + + # Get the length of the time series + n = len(time_series) + + # Validate num_changepoints + if num_changepoints < 2: + raise ValueError( + "num_changepoints must be at least 2 to create a modulation trend." + ) + if num_changepoints > n: + raise ValueError( + "num_changepoints cannot exceed the length of the time series." + ) + + # Initialize random number generator + if rng is None: + rng = np.random.RandomState(seed) + + # Sample change points and their corresponding amplitudes + changepoints = np.hstack( + [ + 0, + np.sort(rng.choice(n - 2, size=num_changepoints - 2, replace=False) + 1), + n - 1, + ] + ) + + # Generate random amplitudes for each change point + amplitude = rng.normal( + loc=amplitude_mean, scale=amplitude_variation, size=num_changepoints + ) + + # Interpolate the modulation trend across the entire time series + if interpolation_method == "linear": + modulation_trend = linear_interpolation( + x_known=changepoints, y_known=amplitude, x_new=np.arange(n) + ) + elif interpolation_method == "cubic": + modulation_trend = cubic_spline_interpolation( + x_known=changepoints, y_known=amplitude, x_new=np.arange(n) + ) + elif interpolation_method == "lagrange": + modulation_trend = lagrange_interpolation( + x_known=changepoints, y_known=amplitude, x_new=np.arange(n) + ) + + # Apply the modulation trend to the original time series + modulated_series = time_series * 
modulation_trend + + return modulated_series, np.array(modulation_trend) From 2749499825edd93d70d3895dd7694b6386832ee9 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 01:10:48 +0800 Subject: [PATCH 06/14] whenxuan: update the censor for data augmentation --- s2generator/augmentation/__init__.py | 6 +- .../augmentation/censor_augmentation.py | 74 +++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 s2generator/augmentation/censor_augmentation.py diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index bfcbd4a..cd1d3a5 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -6,6 +6,10 @@ @url: https://github.com/wwhenxuan/S2Generator """ -__all__ = ["frequency_perturbation"] +__all__ = ["frequency_perturbation", "amplitude_modulation"] +# Import the frequency perturbation function from .frequency_perturbation import frequency_perturbation + +# Import the amplitude modulation function +from .amplitude_modulation import amplitude_modulation diff --git a/s2generator/augmentation/censor_augmentation.py b/s2generator/augmentation/censor_augmentation.py new file mode 100644 index 0000000..8d380a5 --- /dev/null +++ b/s2generator/augmentation/censor_augmentation.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/05 00:42:51 +@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" + +import numpy as np + + +def censor_augmentation( + time_series: np.ndarray, + upper_quantile: float = 0.65, + lower_quantile: float = 0.35, + bernoulli_p: float = 0.8, + rng: np.random.RandomState = None, + seed: int = 42, +) -> np.ndarray: + """ + Perform censoring augmentation on the input time series. + + This augmentation censors (clips) the input signal either from below or + above, depending on a randomly sampled direction. 
The clipping threshold is determined by drawing + a quantile uniformly from the empirical distribution of the signal. + + :param time_series: Input time series, a 1D numpy array + :param upper_quantile: Upper quantile threshold for censoring, default is 0.65 + :param lower_quantile: Lower quantile threshold for censoring, default is 0.35 + :param bernoulli_p: Probability of censoring direction (0 for lower censoring, 1 for upper censoring), + the default is 0.5, meaning equal probability for both directions. + :param rng: Optional random number generator for reproducibility. + If None, a new RNG will be created using the provided seed. + :param seed: Random seed for reproducibility if rng is not provided. + + :return: Censored time series, a 1D numpy array of the same length as the input series. + """ + + # Validate the input time series + time_series = np.asarray(time_series) + if time_series.ndim != 1: + raise ValueError("Input time_series must be a 1D array.") + + # Validate bernoulli_p + if not (0 <= bernoulli_p <= 1): + raise ValueError("bernoulli_p must be in the range [0, 1].") + + # Get the length of the time series + length = time_series.shape[0] + + # Set random seed for reproducibility + if rng is None: + rng = np.random.RandomState(seed) + + # Randomly sample quantile thresholds for each time step + quantile_threshold = rng.uniform(lower_quantile, upper_quantile, size=length) + + # Compute the threshold value based on the quantile of the time series + threshold_value = np.quantile(time_series, quantile_threshold) + + # Sample the censor direction from bernoulli distribution (0.5) + censor_direction = rng.binomial( + n=1, p=bernoulli_p, size=length + ) # 0 for lower censoring, 1 for upper censoring + + for t in range(length): + if censor_direction[t] == 1: + # Lower censoring + time_series[t] = max(time_series[t], threshold_value[t]) + else: + # Upper censoring + time_series[t] = min(time_series[t], threshold_value[t]) + + return time_series From 
f0685e938fed1aba7a51986c9822bb58c8d1e976 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 01:20:47 +0800 Subject: [PATCH 07/14] whenxuan: update the unit test for data augmentation --- s2generator/augmentation/__init__.py | 11 +++-- tests/test_augmentation.py | 67 +++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index cd1d3a5..a86c462 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -6,10 +6,13 @@ @url: https://github.com/wwhenxuan/S2Generator """ -__all__ = ["frequency_perturbation", "amplitude_modulation"] - -# Import the frequency perturbation function -from .frequency_perturbation import frequency_perturbation +__all__ = ["amplitude_modulation", "censor_augmentation", "frequency_perturbation"] # Import the amplitude modulation function from .amplitude_modulation import amplitude_modulation + +# Import the censoring augmentation function +from .censor_augmentation import censor_augmentation + +# Import the frequency perturbation function +from .frequency_perturbation import frequency_perturbation diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 1bf98bb..53cd2f0 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -9,7 +9,12 @@ import numpy as np -from s2generator.augmentation import frequency_perturbation +from s2generator.augmentation import ( + amplitude_modulation, + censor_augmentation, + frequency_perturbation, +) + from s2generator.augmentation.frequency_perturbation import sample_random_perturbation @@ -71,6 +76,66 @@ def test_frequency_perturbation(self) -> None: msg="Perturbed series is identical to original series in `test_frequency_perturbation` method", ) + def test_censor_augmentation(self) -> None: + """Test the function for performing censoring augmentation on time series data.""" + # Generate a simple time series for 
testing + t = np.linspace(0, 1, 100) + series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + + upper_quantile = 0.65 + lower_quantile = 0.35 + bernoulli_p = 0.8 + + censored_series = censor_augmentation( + time_series=series.copy(), + upper_quantile=upper_quantile, + lower_quantile=lower_quantile, + bernoulli_p=bernoulli_p, + rng=self.rng, + ) + + # Check that the output has the same length as the input + self.assertEqual( + len(censored_series), + len(series), + msg="Output length does not match input length in `test_censor_augmentation` method", + ) + + # Check that the output is different from the input (since we applied censoring) + self.assertFalse( + np.array_equal(censored_series, series), + msg="Censored series is identical to original series in `test_censor_augmentation` method", + ) + + def test_amplitude_modulation(self) -> None: + """Test the function for performing amplitude modulation on time series data.""" + # Generate a simple time series for testing + t = np.linspace(0, 1, 100) + series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + + min_modulation = 0.5 + max_modulation = 1.5 + + modulated_series = amplitude_modulation( + time_series=series.copy(), + min_modulation=min_modulation, + max_modulation=max_modulation, + rng=self.rng, + ) + + # Check that the output has the same length as the input + self.assertEqual( + len(modulated_series), + len(series), + msg="Output length does not match input length in `test_amplitude_modulation` method", + ) + + # Check that the output is different from the input (since we applied amplitude modulation) + self.assertFalse( + np.array_equal(modulated_series, series), + msg="Modulated series is identical to original series in `test_amplitude_modulation` method", + ) + if __name__ == "__main__": unittest.main() From 328ffba6ad813b26889427f4e72d7dab48662eb2 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 15:40:58 +0800 Subject: [PATCH 08/14] whenxuan: update the 
empirical mode modulation for data augmentation --- .../augmentation/empirical_mode_modulation.py | 126 ++++++++++++++++++ s2generator/base.py | 6 +- 2 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 s2generator/augmentation/empirical_mode_modulation.py diff --git a/s2generator/augmentation/empirical_mode_modulation.py b/s2generator/augmentation/empirical_mode_modulation.py new file mode 100644 index 0000000..233962d --- /dev/null +++ b/s2generator/augmentation/empirical_mode_modulation.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/05 11:05:45 +@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" +from typing import Optional + +import numpy as np + +from pysdkit import EMD + + +def empirical_mode_modulation( + time_series: np.ndarray, + min_scale_factor: float = 0.5, + max_scale_factor: float = 2.0, + low_frequency_enhancement: bool = True, + spline_kind: str = "cubic", + extrema_detection: str = "parabol", + max_imfs: Optional[int] = None, + rng: Optional[np.random.RandomState] = None, + seed: int = 42, +) -> np.ndarray: + """ + Perform empirical mode modulation on the input time series. + This augmentation decomposes the time series into intrinsic mode functions (IMFs) + using Empirical Mode Decomposition (EMD) and then reconstructs the signal + by randomly modifying the IMFs to introduce non-linear trends and variations. + + :param time_series: Input time series, a 1D numpy array. + :param min_scale_factor: The minimum scaling factor to apply to each IMF. + :param max_scale_factor: The maximum scaling factor to apply to each IMF. + :param low_frequency_enhancement: Whether to enhance the low-frequency components. + If True, the method will suppress the high-frequency noise in the + empirical mode decomposition results and focus on enhancing + the characterization of the low-frequency part. 
+ :param spline_kind: The kind of spline to use for interpolation, + options are "akima", "cubic", "pchip", "cubic_hermite", "slinear", "quadratic", "linear" + :param extrema_detection: The method for detecting extrema in the EMD process, options are "parabol" or "simple" + :param max_imfs: The maximum number of IMFs to extract, + if None, it will extract all possible IMFs until the residue is a monotonic function. + + :return: Empirical mode modulated time series, a 1D numpy array of the same length as the input series. + """ + # Validate the input time series + time_series = np.asarray(time_series) + if time_series.ndim != 1: + raise ValueError("Input time_series must be a 1D array.") + + # Normalize the input time series to have zero mean and unit variance + mean, std = np.mean(time_series), np.std(time_series) + time_series = (time_series - mean) / ( + std + 1e-8 + ) # Add a small value to avoid division by zero + + # Check if max_imfs is valid + if max_imfs is None: + # 表示会完整的分解所有的IMF分量 + # 直到剩余的分量不再满足IMF的定义为止 + max_imfs = -1 + + # Validate the spline kind + assert spline_kind in [ + "akima", + "cubic", + "pchip", + "cubic_hermite", + "slinear", + "quadratic", + "linear", + ], "spline_kind must be one of 'akima', 'cubic', 'pchip', 'cubic_hermite', 'slinear', 'quadratic', 'linear'." + + # Validate the extrema detection method + assert extrema_detection in [ + "parabol", + "simple", + ], "extrema_detection must be one of 'parabol' or 'simple'." 
+ + # Initialize random number generator + if rng is None: + rng = np.random.RandomState(seed=seed) + + # Perform Empirical Mode Decomposition + emd = EMD( + max_imfs=max_imfs, spline_kind=spline_kind, extrema_detection=extrema_detection + ) + print(type(time_series)) + imfs = emd.fit_transform(signal=time_series) + + # Get the number of IMFs extracted + num_imfs = imfs.shape[0] + + # Randomly select a scaling factor for modulation + scale_factor = rng.uniform( + low=min_scale_factor, high=max_scale_factor, size=num_imfs + ) + + # Validate the low_frequency_enhancement parameter + assert isinstance( + low_frequency_enhancement, bool + ), "low_frequency_enhancement must be a boolean value." + if low_frequency_enhancement is True: + # If low-frequency enhancement is enabled, we can apply a stronger scaling to the lower frequency IMFs + scale_factor = np.sort( + scale_factor + ) # Sort the scale factors to enhance low-frequency components more than high-frequency ones + + # Randomly modify the IMFs to create modulation + modified_imfs = [] + for index, imf in enumerate(imfs): + # Randomly scale each IMF by a factor between min_scale_factor and max_scale_factor + scaled_imf = imf * scale_factor[index] + modified_imfs.append(scaled_imf) + + # Reconstruct the signal from the modified IMFs + modulated_time_series = np.sum(modified_imfs, axis=0) + + # Denormalize the modulated time series to restore the original scale + modulated_time_series = (modulated_time_series - np.mean(modulated_time_series)) / ( + np.std(modulated_time_series) + 1e-8 + ) + modulated_time_series = modulated_time_series * (std + 1e-8) + mean + + return modulated_time_series diff --git a/s2generator/base.py b/s2generator/base.py index 8c78568..69586b8 100644 --- a/s2generator/base.py +++ b/s2generator/base.py @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray if xs.ndim > 1: # For multivariate case, keep other dimensions constant x_uniform_input = 
np.tile(np.mean(xs, axis=0), (n_integration_points, 1)) - x_uniform_input[ - :, 0 - ] = x_uniform # Replace first dimension with uniform grid + x_uniform_input[:, 0] = ( + x_uniform # Replace first dimension with uniform grid + ) else: x_uniform_input = x_uniform.reshape(-1, 1) # Ensure 2D array for val method From 55d2c858bc48e60ff81d03e0eb9c84bfb47886ed Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 15:52:47 +0800 Subject: [PATCH 09/14] whenxuan: rename the augmentation --- s2generator/augmentation/__init__.py | 16 ++++++++++++---- ...de_modulation.py => _amplitude_modulation.py} | 0 ...r_augmentation.py => _censor_augmentation.py} | 0 ...dulation.py => _empirical_mode_modulation.py} | 0 ...erturbation.py => _frequency_perturbation.py} | 0 s2generator/base.py | 6 +++--- tests/test_augmentation.py | 2 +- 7 files changed, 16 insertions(+), 8 deletions(-) rename s2generator/augmentation/{amplitude_modulation.py => _amplitude_modulation.py} (100%) rename s2generator/augmentation/{censor_augmentation.py => _censor_augmentation.py} (100%) rename s2generator/augmentation/{empirical_mode_modulation.py => _empirical_mode_modulation.py} (100%) rename s2generator/augmentation/{frequency_perturbation.py => _frequency_perturbation.py} (100%) diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index a86c462..a2f6db6 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -6,13 +6,21 @@ @url: https://github.com/wwhenxuan/S2Generator """ -__all__ = ["amplitude_modulation", "censor_augmentation", "frequency_perturbation"] +__all__ = [ + "amplitude_modulation", + "censor_augmentation", + "empirical_model_modulation", + "frequency_perturbation", +] # Import the amplitude modulation function -from .amplitude_modulation import amplitude_modulation +from ._amplitude_modulation import amplitude_modulation # Import the censoring augmentation function -from .censor_augmentation import 
censor_augmentation +from ._censor_augmentation import censor_augmentation + +# Import the empirical model modulation function +from ._empirical_model_modulation import empirical_model_modulation # Import the frequency perturbation function -from .frequency_perturbation import frequency_perturbation +from ._frequency_perturbation import frequency_perturbation diff --git a/s2generator/augmentation/amplitude_modulation.py b/s2generator/augmentation/_amplitude_modulation.py similarity index 100% rename from s2generator/augmentation/amplitude_modulation.py rename to s2generator/augmentation/_amplitude_modulation.py diff --git a/s2generator/augmentation/censor_augmentation.py b/s2generator/augmentation/_censor_augmentation.py similarity index 100% rename from s2generator/augmentation/censor_augmentation.py rename to s2generator/augmentation/_censor_augmentation.py diff --git a/s2generator/augmentation/empirical_mode_modulation.py b/s2generator/augmentation/_empirical_mode_modulation.py similarity index 100% rename from s2generator/augmentation/empirical_mode_modulation.py rename to s2generator/augmentation/_empirical_mode_modulation.py diff --git a/s2generator/augmentation/frequency_perturbation.py b/s2generator/augmentation/_frequency_perturbation.py similarity index 100% rename from s2generator/augmentation/frequency_perturbation.py rename to s2generator/augmentation/_frequency_perturbation.py diff --git a/s2generator/base.py b/s2generator/base.py index 69586b8..8c78568 100644 --- a/s2generator/base.py +++ b/s2generator/base.py @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray if xs.ndim > 1: # For multivariate case, keep other dimensions constant x_uniform_input = np.tile(np.mean(xs, axis=0), (n_integration_points, 1)) - x_uniform_input[:, 0] = ( - x_uniform # Replace first dimension with uniform grid - ) + x_uniform_input[ + :, 0 + ] = x_uniform # Replace first dimension with uniform grid else: x_uniform_input = 
x_uniform.reshape(-1, 1) # Ensure 2D array for val method diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 53cd2f0..158f19d 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -15,7 +15,7 @@ frequency_perturbation, ) -from s2generator.augmentation.frequency_perturbation import sample_random_perturbation +from s2generator.augmentation._frequency_perturbation import sample_random_perturbation class TestDataAugmentation(unittest.TestCase): From ca1eb6258cfc4fd69f80bbc95f645615dfe854cc Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 16:13:35 +0800 Subject: [PATCH 10/14] whenxuan: update the wiener filter for de-noise --- s2generator/augmentation/__init__.py | 8 +++ s2generator/augmentation/_spike_injection.py | 16 ++++++ s2generator/augmentation/_wiener_filter.py | 60 ++++++++++++++++++++ 3 files changed, 84 insertions(+) create mode 100644 s2generator/augmentation/_spike_injection.py create mode 100644 s2generator/augmentation/_wiener_filter.py diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index a2f6db6..60163f4 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -11,6 +11,8 @@ "censor_augmentation", "empirical_model_modulation", "frequency_perturbation", + "spike_injection", + "wiener_filter" ] # Import the amplitude modulation function @@ -24,3 +26,9 @@ # Import the frequency perturbation function from ._frequency_perturbation import frequency_perturbation + +# Import the spike injection function +from ._spike_injection import spike_injection + +# Import the wiener filter function +from ._wiener_filter import wiener_filter diff --git a/s2generator/augmentation/_spike_injection.py b/s2generator/augmentation/_spike_injection.py new file mode 100644 index 0000000..844d36c --- /dev/null +++ b/s2generator/augmentation/_spike_injection.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/05 15:54:19 
+@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" +import numpy as np + + +def spike_injection(time_series: np.ndarray) -> np.ndarray: + """ + Perform spike injection augmentation on the input time series. + This augmentation randomly injects spikes into the input time series to simulate sudden and extreme events, which can help models learn to handle such anomalies. + """ + pass diff --git a/s2generator/augmentation/_wiener_filter.py b/s2generator/augmentation/_wiener_filter.py new file mode 100644 index 0000000..e9e02eb --- /dev/null +++ b/s2generator/augmentation/_wiener_filter.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/05 16:12:41 +@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" + +from typing import Optional + +import numpy as np + + +def wiener_filter( + time_series: np.ndarray, + noise_variance: float = 1.0, + window_size: Optional[int] = None, +) -> np.ndarray: + """ + Implement Wiener filter to remove white noise from signals. + + :param time_series: The input time series, a 1D numpy array representing the noisy signal. + :param noise_variance: The known variance of the white noise to be removed, default is 1.0 + :param window_size: The window size for power spectral density estimation, optional + if None, it will use the entire signal length as the window size. + + :return: The filtered time series, a 1D numpy array of the same length as the input series. + """ + # Validate the input time series + time_series = np.asarray(time_series, dtype=np.float64) + + # If window_size is not provided, use the length of the time series + if window_size is None: + window_size = len(time_series) + + # Calculate the Fourier transform of the signal + signal_fft = np.fft.fft(time_series) + + # Calculate the power spectral density (PSD) of the signal. 
+ signal_psd = np.abs(signal_fft) ** 2 / len(time_series) + + # Estimate the power spectrum of the original signal (noisy PSD - noise PSD) + # The power spectral density of noise is constant for white noise: noise variance + noise_psd = noise_variance + original_signal_psd_estimate = np.maximum( + signal_psd - noise_psd, 1e-10 + ) # Avoid negative values + + # Calculate the frequency domain response of the Wiener filter. + wiener_filter_freq = original_signal_psd_estimate / ( + original_signal_psd_estimate + noise_psd + ) + + # Apply a filter and perform an inverse Fourier transform to return to the time domain. + filtered_fft = signal_fft * wiener_filter_freq + + # Remove numerical error by taking the real part + filtered_signal = np.fft.ifft(filtered_fft).real + + return filtered_signal From 0a91fc8720a389cadace67f8a4a6317eac938a05 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 16:31:28 +0800 Subject: [PATCH 11/14] whenxuan: add the linear trend --- s2generator/augmentation/__init__.py | 2 +- .../augmentation/_time_transformation.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 s2generator/augmentation/_time_transformation.py diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index 60163f4..dc1165b 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -12,7 +12,7 @@ "empirical_model_modulation", "frequency_perturbation", "spike_injection", - "wiener_filter" + "wiener_filter", ] # Import the amplitude modulation function diff --git a/s2generator/augmentation/_time_transformation.py b/s2generator/augmentation/_time_transformation.py new file mode 100644 index 0000000..cd5c368 --- /dev/null +++ b/s2generator/augmentation/_time_transformation.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/05 16:19:59 +@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" 
+import numpy as np + + +def add_linear_trend( + time_series: np.ndarray, trend_strength: float = 1.0, direction: str = "upward" +) -> np.ndarray: + """ + Perform linear trend augmentation on the input time series. + This augmentation adds a linear trend to the input time series, + which can help models learn to handle non-stationary data and improve their robustness to trends. + + :param time_series: Input time series, a 1D numpy array + :param trend_strength: The strength of the linear trend to be added, default is 1.0. + :param direction: The direction of the linear trend, either "upward" or "downward", default is "upward". + + :return: Augmented time series with a linear trend, a 1D numpy array of the same length as the input series. + """ + + # Get the length of the time series + seq_length = len(time_series) + + # Calculate the energy of the original time series + original_energy = np.mean(time_series**2) + + # Create a linear trend + if direction == "upward": + trend = np.linspace(0, trend_strength * seq_length, seq_length) + elif direction == "downward": + trend = np.linspace(0, -trend_strength * seq_length, seq_length) + else: + raise ValueError("direction must be either 'upward' or 'downward'") + + # Scale the trend to have the same energy as the original time series + trend_energy = np.mean(trend**2) + + if trend_energy > 0: + trend = trend * np.sqrt(original_energy / trend_energy) + + # Average the original signal and the trend to maintain the overall scale + return (time_series + trend) / 2 + + +if __name__ == "__main__": + # Example usage + import matplotlib.pyplot as plt + + # Create a sample time series (sine wave) + t = np.linspace(0, 10, 500) + original_series = np.sin(t) + + # Add linear trend + augmented_series = add_linear_trend( + original_series, trend_strength=1, direction="downward" + ) + + # Plot the original and augmented time series + plt.figure(figsize=(12, 6)) + plt.plot(t, original_series, label="Original Time Series") + 
plt.plot(t, augmented_series, label="Augmented Time Series with Linear Trend") + plt.legend() + plt.title("Linear Trend Augmentation") + plt.xlabel("Time") + plt.ylabel("Value") + plt.grid() + plt.show() From d9f0f260f64af1e73673b9bf10456d742ef408f2 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 16:45:22 +0800 Subject: [PATCH 12/14] whenxuan: add the mixup for time series --- .../augmentation/_time_transformation.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/s2generator/augmentation/_time_transformation.py b/s2generator/augmentation/_time_transformation.py index cd5c368..2676305 100644 --- a/s2generator/augmentation/_time_transformation.py +++ b/s2generator/augmentation/_time_transformation.py @@ -47,6 +47,27 @@ def add_linear_trend( return (time_series + trend) / 2 +def time_series_mixup(a: np.ndarray, b: np.ndarray, alpha: float = 0.7) -> np.ndarray: + """ + Mixup Enhancement: Weighted mixing of two time series to create a new augmented signal. + This method combines two time series by taking a weighted average of them, + where the weights are determined by a mixing parameter alpha. + This can help models learn to generalize better by exposing them to a wider variety of signal combinations. + + :param a: First input time series, a 1D numpy array. + :param b: Second input time series, a 1D numpy array of the same length as a. + :param alpha: The mixing parameter that controls the weight of each time series in the mixup, + default is 0.7. A value of alpha close to 1 gives more weight to the first time series (a), + while a value close to 0 gives more weight to the second time series (b). + + :return: Mixed time series, a 1D numpy array of the same length as the input series. 
+ """ + assert a.shape == b.shape, "Input time series must have the same shape" + + # Calculate the mixed signal as a weighted average of the two input signals + return alpha * a + (1 - alpha) * b + + if __name__ == "__main__": # Example usage import matplotlib.pyplot as plt From c813ea9001020e7c75a33b21f913c59a0a3d1ce1 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 17:23:07 +0800 Subject: [PATCH 13/14] whenxuan: update the init --- s2generator/augmentation/__init__.py | 3 +++ .../augmentation/_time_transformation.py | 25 ------------------- 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index dc1165b..0b7db0b 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -32,3 +32,6 @@ # Import the wiener filter function from ._wiener_filter import wiener_filter + +# Import the time transformation functions +from ._time_transformation import add_linear_trend, time_series_mixup diff --git a/s2generator/augmentation/_time_transformation.py b/s2generator/augmentation/_time_transformation.py index 2676305..9fed6d2 100644 --- a/s2generator/augmentation/_time_transformation.py +++ b/s2generator/augmentation/_time_transformation.py @@ -66,28 +66,3 @@ def time_series_mixup(a: np.ndarray, b: np.ndarray, alpha: float = 0.7) -> np.nd # Calculate the mixed signal as a weighted average of the two input signals return alpha * a + (1 - alpha) * b - - -if __name__ == "__main__": - # Example usage - import matplotlib.pyplot as plt - - # Create a sample time series (sine wave) - t = np.linspace(0, 10, 500) - original_series = np.sin(t) - - # Add linear trend - augmented_series = add_linear_trend( - original_series, trend_strength=1, direction="downward" - ) - - # Plot the original and augmented time series - plt.figure(figsize=(12, 6)) - plt.plot(t, original_series, label="Original Time Series") - plt.plot(t, augmented_series, label="Augmented 
Time Series with Linear Trend") - plt.legend() - plt.title("Linear Trend Augmentation") - plt.xlabel("Time") - plt.ylabel("Value") - plt.grid() - plt.show() From 356a4fc270433a6ee8c5d9b8eb9adcf4242b4399 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 17:25:11 +0800 Subject: [PATCH 14/14] whenxuan: update the __all__ --- s2generator/augmentation/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/s2generator/augmentation/__init__.py b/s2generator/augmentation/__init__.py index 0b7db0b..8538e6c 100644 --- a/s2generator/augmentation/__init__.py +++ b/s2generator/augmentation/__init__.py @@ -13,6 +13,8 @@ "frequency_perturbation", "spike_injection", "wiener_filter", + "add_linear_trend", + "time_series_mixup", ] # Import the amplitude modulation function