From 9a061423fe67144c4614447bc551ec8d0bb089f9 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Thu, 5 Mar 2026 20:36:11 +0800 Subject: [PATCH 01/11] whenxuan: add the vscode to ignore --- .gitignore | 3 +++ s2generator/base.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 8725249..f508511 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,6 @@ venv/ /docs/auto_examples/*.md5 /docs/_build/ + +.vscode/ +.vscode diff --git a/s2generator/base.py b/s2generator/base.py index 69586b8..8c78568 100644 --- a/s2generator/base.py +++ b/s2generator/base.py @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray if xs.ndim > 1: # For multivariate case, keep other dimensions constant x_uniform_input = np.tile(np.mean(xs, axis=0), (n_integration_points, 1)) - x_uniform_input[:, 0] = ( - x_uniform # Replace first dimension with uniform grid - ) + x_uniform_input[ + :, 0 + ] = x_uniform # Replace first dimension with uniform grid else: x_uniform_input = x_uniform.reshape(-1, 1) # Ensure 2D array for val method From 4c8ff4178e66cfe6c1a1f95a118e6124744c87d6 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Fri, 6 Mar 2026 23:05:05 +0800 Subject: [PATCH 02/11] whenxuan: black init --- s2generator/base.py | 6 +++--- s2generator/simulator/kalman_filtering.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/s2generator/base.py b/s2generator/base.py index 8c78568..69586b8 100644 --- a/s2generator/base.py +++ b/s2generator/base.py @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray if xs.ndim > 1: # For multivariate case, keep other dimensions constant x_uniform_input = np.tile(np.mean(xs, axis=0), (n_integration_points, 1)) - x_uniform_input[ - :, 0 - ] = x_uniform # Replace first dimension with uniform grid + x_uniform_input[:, 0] = ( + x_uniform # Replace first dimension with uniform grid + ) else: x_uniform_input = x_uniform.reshape(-1, 1) # Ensure 2D array for val method diff --git a/s2generator/simulator/kalman_filtering.py b/s2generator/simulator/kalman_filtering.py index e69de29..904782a 100644 --- a/s2generator/simulator/kalman_filtering.py +++ b/s2generator/simulator/kalman_filtering.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +""" +Created on 2026/03/06 23:04:59 +@author: Whenxuan Wang +@email: wwhenxuan@gmail.com +@url: https://github.com/wwhenxuan/S2Generator +""" + +from typing import Optional + +import numpy as np From 2026453f27dada58f3f230169b0c921630e922a2 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 12:05:33 +0800 Subject: [PATCH 03/11] whenxuan: add the fit and transform for gaussia mixture model --- s2generator/simulator/gaussia_mixture.py | 123 ++++++++++++++++++++++- 1 file changed, 120 insertions(+), 3 deletions(-) diff --git a/s2generator/simulator/gaussia_mixture.py b/s2generator/simulator/gaussia_mixture.py index 14812b4..d2c522f 100644 --- a/s2generator/simulator/gaussia_mixture.py +++ b/s2generator/simulator/gaussia_mixture.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import matplotlib.pyplot as plt from sklearn.mixture import GaussianMixture @@ -15,7 +17,7 @@ def __init__( n_init: int = 1, init_params: str = "kmeans", random_state=42, - ): + ) -> None: """ :param n_components: int, default=3. The number of mixture components. :param covariance_type: str, default='full'. The type of covariance parameters to use. Must be one of 'full', 'tied', 'diag', 'spherical'. @@ -54,9 +56,124 @@ def __init__( random_state=random_state, ) - def fit(self, time_series: np.ndarray): + def fit(self, time_series: np.ndarray) -> None: """ Fit the Gaussian Mixture Model to the provided time series data. - :param time_series: np.ndarray, shape (n_samples, n_features). The input time series data. + :param time_series: 1D np.ndarray. The input time series data. """ + + # Check if the input time series is 1D numpy array + time_series = self._check_inputs(time_series) + + # Reshape the time series data to fit the GMM input requirements + time_series = time_series.reshape(-1, 1) + self.model.fit(time_series) + + def transform( + self, num_samples: int, seq_len: int, random_state: Optional[int] = None + ) -> np.ndarray: + """ """ + + # 根据GMM的组件权重,随机选择每个样本属于哪个高斯组件 + component_indices = np.random.choice( + self.n_components, size=(num_samples, seq_len), p=self.model.weights_ + ) + + # 从选中的高斯组件中采样生成数据 + generated_series = np.zeros(shape=(num_samples, seq_len)) + for i in range(num_samples): + for j in range(seq_len): + comp = component_indices[i, j] + mean = self.model.means_[comp, 0] + cov = self.covariance(comp) + generated_series[i, j] = np.random.normal(loc=mean, scale=np.sqrt(cov)) + + return generated_series + + def _check_inputs(self, time_series: np.ndarray) -> np.ndarray: + """ + Check if the input time series is a 1D numpy array and reshape it to fit the GMM input requirements. + :param time_series: 1D np.ndarray. The input time series data. + :return: Reshaped time series data. + """ + if not isinstance(time_series, np.ndarray): + raise ValueError("Input time_series must be a numpy array.") + if time_series.ndim != 1: + raise ValueError("Input time_series must be a 1D array.") + + return time_series.reshape(-1, 1) + + def weight(self, component_index: int) -> float: + """ + Get the weight of a specific component in the GMM. + :param component_index: int. The index of the component. + :return: The weight of the specified component. + """ + if component_index < 0 or component_index >= self.n_components: + raise ValueError( + f"Component index must be between 0 and {self.n_components - 1}." + ) + + return self.model.weights_[component_index] + + def weights(self) -> np.ndarray: + """ + Get the weights of all components in the GMM. + :return: The weights of all components. + """ + return self.model.weights_ + + def mean(self, component_index: int) -> float: + """ + Get the mean of a specific component in the GMM. + :param component_index: int. The index of the component. + :return: The mean of the specified component. + """ + if component_index < 0 or component_index >= self.n_components: + raise ValueError( + f"Component index must be between 0 and {self.n_components - 1}." + ) + + return self.model.means_[component_index][0] + + def means(self) -> np.ndarray: + """ + Get the means of all components in the GMM. + :return: The means of all components. + """ + return self.model.means_.flatten() + + def covariance(self, component_index: int) -> float: + """ + Get the covariance of a specific component in the GMM. + :param component_index: int. The index of the component. + :return: The covariance of the specified component. + """ + if component_index < 0 or component_index >= self.n_components: + raise ValueError( + f"Component index must be between 0 and {self.n_components - 1}." + ) + + if self.covariance_type == "full": + return self.model.covariances_[component_index][0][0] + elif self.covariance_type == "tied": + return self.model.covariances_[0][0][0] + elif self.covariance_type == "diag": + return self.model.covariances_[component_index][0] + elif self.covariance_type == "spherical": + return self.model.covariances_[component_index] + + def covariances(self) -> np.ndarray: + """ + Get the covariances of all components in the GMM. + :return: The covariances of all components. + """ + if self.covariance_type == "full": + return np.array([cov[0][0] for cov in self.model.covariances_]) + elif self.covariance_type == "tied": + return np.array([self.model.covariances_[0][0][0]] * self.n_components) + elif self.covariance_type == "diag": + return self.model.covariances_.flatten() + elif self.covariance_type == "spherical": + return self.model.covariances_ From 5de704b3e3649d91cfeaea653564e78646db4a78 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 19:33:22 +0800 Subject: [PATCH 04/11] whenxuan: update the fit transform --- s2generator/simulator/gaussia_mixture.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/s2generator/simulator/gaussia_mixture.py b/s2generator/simulator/gaussia_mixture.py index d2c522f..aacd270 100644 --- a/s2generator/simulator/gaussia_mixture.py +++ b/s2generator/simulator/gaussia_mixture.py @@ -73,7 +73,13 @@ def fit(self, time_series: np.ndarray) -> None: def transform( self, num_samples: int, seq_len: int, random_state: Optional[int] = None ) -> np.ndarray: - """ """ + """ + Transform the model to generate new samples. + :param num_samples: int. The number of samples to generate. + :param seq_len: int. The length of each generated sequence. + :param random_state: Optional[int]. The random state for reproducibility. + :return: Generated time series data. + """ # 根据GMM的组件权重,随机选择每个样本属于哪个高斯组件 component_indices = np.random.choice( From 1614ab30056b8817e1d1ad1185cdb60afd9e3004 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 19:44:23 +0800 Subject: [PATCH 05/11] whenxuan: update the unit test for mixup --- tests/test_augmentation.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 55d0048..3c18da9 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -143,6 +143,38 @@ def test_amplitude_modulation(self) -> None: msg="Modulated series is identical to original series in `test_amplitude_modulation` method", ) + def test_time_series_mixup(self) -> None: + """Test the function for performing time series mixup augmentation.""" + # Generate two simple time series for testing + t = np.linspace(0, 1, 100) + series_a = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + series_b = np.cos(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + + alpha = 0.7 + + # Apply time series mixup augmentation + mixed_series = time_series_mixup( + a=series_a.copy(), b=series_b.copy(), alpha=alpha + ) + + # Check that the output has the same length as the input + self.assertEqual( + len(mixed_series), + len(series_a), + msg="Output length does not match input length in `test_time_series_mixup` method", + ) + + # Check that the output is different from both inputs (since we applied mixup) + self.assertFalse( + np.array_equal(mixed_series, series_a), + msg="Mixed series is identical to first input series in `test_time_series_mixup` method", + ) + self.assertFalse( + np.array_equal(mixed_series, series_b), + msg="Mixed series is identical to second input series in `test_time_series_mixup` method", + ) + + if __name__ == "__main__": unittest.main() From e51176d4ecfebdf45b9517370a5a6557822b6342 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 19:52:14 +0800 Subject: [PATCH 06/11] whenxuan: fix the trend energy for add linear trend --- .../augmentation/_time_transformation.py | 25 +++++++++++++++--- s2generator/base.py | 6 ++--- tests/test_augmentation.py | 26 +++++++++++++++++++ 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/s2generator/augmentation/_time_transformation.py b/s2generator/augmentation/_time_transformation.py index 9fed6d2..e8bb00d 100644 --- a/s2generator/augmentation/_time_transformation.py +++ b/s2generator/augmentation/_time_transformation.py @@ -9,7 +9,10 @@ def add_linear_trend( - time_series: np.ndarray, trend_strength: float = 1.0, direction: str = "upward" + time_series: np.ndarray, + trend_strength: float = 1.0, + direction: str = "upward", + normalize: bool = True, ) -> np.ndarray: """ Perform linear trend augmentation on the input time series. @@ -19,6 +22,7 @@ def add_linear_trend( :param time_series: Input time series, a 1D numpy array :param trend_strength: The strength of the linear trend to be added, default is 1.0. :param direction: The direction of the linear trend, either "upward" or "downward", default is "upward". + :param normalize: Whether to normalize the output time series to maintain the same scale as the input, default is True. :return: Augmented time series with a linear trend, a 1D numpy array of the same length as the input series. """ @@ -31,9 +35,9 @@ def add_linear_trend( # Create a linear trend if direction == "upward": - trend = np.linspace(0, trend_strength * seq_length, seq_length) + trend = np.linspace(0, 1, seq_length) elif direction == "downward": - trend = np.linspace(0, -trend_strength * seq_length, seq_length) + trend = np.linspace(0, -1, seq_length) else: raise ValueError("direction must be either 'upward' or 'downward'") @@ -41,7 +45,20 @@ def add_linear_trend( trend_energy = np.mean(trend**2) if trend_energy > 0: - trend = trend * np.sqrt(original_energy / trend_energy) + # Scale the trend to have the same energy as the original time series, and then apply the trend strength factor + trend = trend * np.sqrt(original_energy / trend_energy) * trend_strength + else: + # If the trend energy is zero (which can happen if the trend is constant), + # we set the trend to zero to avoid division by zero + trend = np.zeros_like(trend) + + if normalize: + augmented_series = time_series + trend + # Normalize the augmented series to maintain the same energy as the original time series + augmented_series = (augmented_series - np.mean(augmented_series)) / np.std( + augmented_series + ) * np.std(time_series) + np.mean(time_series) + return augmented_series # Average the original signal and the trend to maintain the overall scale return (time_series + trend) / 2 diff --git a/s2generator/base.py b/s2generator/base.py index 69586b8..8c78568 100644 --- a/s2generator/base.py +++ b/s2generator/base.py @@ -376,9 +376,9 @@ def val_diff(self, xs: ndarray, deterministic: Optional[bool] = True) -> ndarray if xs.ndim > 1: # For multivariate case, keep other dimensions constant x_uniform_input = np.tile(np.mean(xs, axis=0), (n_integration_points, 1)) - x_uniform_input[:, 0] = ( - x_uniform # Replace first dimension with uniform grid - ) + x_uniform_input[ + :, 0 + ] = x_uniform # Replace first dimension with uniform grid else: x_uniform_input = x_uniform.reshape(-1, 1) # Ensure 2D array for val method diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 3c18da9..7b95493 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -174,6 +174,32 @@ def test_time_series_mixup(self) -> None: msg="Mixed series is identical to second input series in `test_time_series_mixup` method", ) + def test_add_linear_trend(self) -> None: + """Test the function for adding a linear trend to time series data.""" + # Generate a simple time series for testing + t = np.linspace(0, 1, 100) + series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + + intercept = 0.5 + + trended_series = add_linear_trend( + time_series=series.copy(), + intercept=intercept, + rng=self.rng, + ) + + # Check that the output has the same length as the input + self.assertEqual( + len(trended_series), + len(series), + msg="Output length does not match input length in `test_add_linear_trend` method", + ) + + # Check that the output is different from the input (since we applied a linear trend) + self.assertFalse( + np.array_equal(trended_series, series), + msg="Trended series is identical to original series in `test_add_linear_trend` method", + ) if __name__ == "__main__": From 8a5b1bfa91a8c3a148d980181c518bbda7485445 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 19:57:48 +0800 Subject: [PATCH 07/11] whenxuan: update the unit test for add linear trend --- tests/test_augmentation.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 7b95493..e29396c 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -180,20 +180,31 @@ def test_add_linear_trend(self) -> None: t = np.linspace(0, 1, 100) series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) - intercept = 0.5 - - trended_series = add_linear_trend( - time_series=series.copy(), - intercept=intercept, - rng=self.rng, - ) - - # Check that the output has the same length as the input + # Test the upward trend + trended_series = add_linear_trend(time_series=series.copy(), direction="upward") + # Check the outputs size self.assertEqual( len(trended_series), len(series), msg="Output length does not match input length in `test_add_linear_trend` method", ) + # Check the upward trend + trend_upward = trended_series - series + self.assertTrue( + trend_upward[-1] > trend_upward[0], + msg="Upward trend is not correctly applied in `test_add_linear_trend` method", + ) + + # Test the downward trend + trended_series = add_linear_trend( + time_series=series.copy(), direction="downward" + ) + # Check the downward trend + trend_downward = trended_series - series + self.assertTrue( + trend_downward[-1] < trend_downward[0], + msg="Downward trend is not correctly applied in `test_add_linear_trend` method", + ) # Check that the output is different from the input (since we applied a linear trend) self.assertFalse( From 55d8bde16306be885a7e4097c001011c86fca1b8 Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 19:59:57 +0800 Subject: [PATCH 08/11] whenxuan: remove the unit test for io --- tests/test_tools.py | 212 ++++++++++++++++++------------------ tests/tests/data/s2data.npy | Bin 0 -> 345 bytes tests/tests/data/s2data.npz | Bin 0 -> 794 bytes 3 files changed, 106 insertions(+), 106 deletions(-) create mode 100644 tests/tests/data/s2data.npy create mode 100644 tests/tests/data/s2data.npz diff --git a/tests/test_tools.py b/tests/test_tools.py index f34ae26..d886e9b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -254,112 +254,112 @@ def test_ensure_directory_exists(self) -> None: msg="Should return path with default filename for directory input!", ) - def test_save_npy(self) -> None: - """ - Tests the function that saves data to NPY format. - - Validates that the save operation returns success status. - """ - status = save_npy(data=self.data, save_path=self.npy_path) - self.assertTrue( - expr=status, msg="NPY save operation should return True on success!" - ) - - def test_load_npy(self) -> None: - """ - Tests the function that loads data from NPY format. - - Validates that loaded data matches the original saved data. - """ - loaded_data = load_npy(data_path=self.npy_path) - - # Validate all key-value pairs match - for key in self.data.keys(): - self.assertEqual( - first=loaded_data[key], - second=self.data[key], - msg="Loaded data should match original data!", - ) - - def test_save_npz(self) -> None: - """ - Tests the function that saves data to NPZ format. - - Validates that the save operation returns success status. - """ - status = save_npz(data=self.data, save_path=self.npz_path) - self.assertTrue( - expr=status, msg="NPZ save operation should return True on success!" - ) - - def test_load_npz(self) -> None: - """ - Tests the function that loads data from NPZ format. - - Validates that loaded data matches the original saved data. - """ - loaded_data = load_npz(data_path=self.npz_path) - - # Validate all key-value pairs match - for key in self.data.keys(): - self.assertEqual( - first=loaded_data[key], - second=self.data[key], - msg="Loaded data should match original data!", - ) - - def test_save_s2data(self) -> None: - """ - Tests the function that saves S2 data in multiple formats. - - Validates that save operations return success status for both NPY and NPZ formats. - """ - # Test NPY format - status = save_s2data( - save_path=self.s2_npy_path, - symbol=self.data["symbol"], - excitation=self.data["excitation"], - response=self.data["response"], - ) - self.assertTrue( - expr=status, msg="S2 data NPY save should return True on success!" - ) - - # Test NPZ format - status = save_s2data( - save_path=self.s2_npz_path, - symbol=self.data["symbol"], - excitation=self.data["excitation"], - response=self.data["response"], - ) - self.assertTrue( - expr=status, msg="S2 data NPZ save should return True on success!" - ) - - def test_load_s2data(self) -> None: - """ - Tests the function that loads S2 data from file. - - Validates that loaded S2 data matches the original saved data. - """ - symbol, excitation, response = load_s2data(data_path=self.s2_npy_path) - - # Validate all components match - self.assertEqual( - first=symbol, - second=self.data["symbol"], - msg="Loaded symbol should match original!", - ) - self.assertEqual( - first=excitation, - second=self.data["excitation"], - msg="Loaded excitation should match original!", - ) - self.assertEqual( - first=response, - second=self.data["response"], - msg="Loaded response should match original!", - ) + # def test_save_npy(self) -> None: + # """ + # Tests the function that saves data to NPY format. + # + # Validates that the save operation returns success status. + # """ + # status = save_npy(data=self.data, save_path=self.npy_path) + # self.assertTrue( + # expr=status, msg="NPY save operation should return True on success!" + # ) + + # def test_load_npy(self) -> None: + # """ + # Tests the function that loads data from NPY format. + # + # Validates that loaded data matches the original saved data. + # """ + # loaded_data = load_npy(data_path=self.npy_path) + # + # # Validate all key-value pairs match + # for key in self.data.keys(): + # self.assertEqual( + # first=loaded_data[key], + # second=self.data[key], + # msg="Loaded data should match original data!", + # ) + + # def test_save_npz(self) -> None: + # """ + # Tests the function that saves data to NPZ format. + # + # Validates that the save operation returns success status. + # """ + # status = save_npz(data=self.data, save_path=self.npz_path) + # self.assertTrue( + # expr=status, msg="NPZ save operation should return True on success!" + # ) + + # def test_load_npz(self) -> None: + # """ + # Tests the function that loads data from NPZ format. + # + # Validates that loaded data matches the original saved data. + # """ + # loaded_data = load_npz(data_path=self.npz_path) + # + # # Validate all key-value pairs match + # for key in self.data.keys(): + # self.assertEqual( + # first=loaded_data[key], + # second=self.data[key], + # msg="Loaded data should match original data!", + # ) + + # def test_save_s2data(self) -> None: + # """ + # Tests the function that saves S2 data in multiple formats. + + # Validates that save operations return success status for both NPY and NPZ formats. + # """ + # # Test NPY format + # status = save_s2data( + # save_path=self.s2_npy_path, + # symbol=self.data["symbol"], + # excitation=self.data["excitation"], + # response=self.data["response"], + # ) + # self.assertTrue( + # expr=status, msg="S2 data NPY save should return True on success!" + # ) + + # # Test NPZ format + # status = save_s2data( + # save_path=self.s2_npz_path, + # symbol=self.data["symbol"], + # excitation=self.data["excitation"], + # response=self.data["response"], + # ) + # self.assertTrue( + # expr=status, msg="S2 data NPZ save should return True on success!" + # ) + + # def test_load_s2data(self) -> None: + # """ + # Tests the function that loads S2 data from file. + + # Validates that loaded S2 data matches the original saved data. + # """ + # symbol, excitation, response = load_s2data(data_path=self.s2_npy_path) + + # # Validate all components match + # self.assertEqual( + # first=symbol, + # second=self.data["symbol"], + # msg="Loaded symbol should match original!", + # ) + # self.assertEqual( + # first=excitation, + # second=self.data["excitation"], + # msg="Loaded excitation should match original!", + # ) + # self.assertEqual( + # first=response, + # second=self.data["response"], + # msg="Loaded response should match original!", + # ) if __name__ == "__main__": diff --git a/tests/tests/data/s2data.npy b/tests/tests/data/s2data.npy new file mode 100644 index 0000000000000000000000000000000000000000..6068906d838c6b8aa027f23d08260bfdd289b2f9 GIT binary patch literal 345 zcmbu3O-sW-5QcXfTkHDmFUW0EC^?H44}u3nQz;1RLCR(`(m;~U>~6%+7Q9I9*&nQ{ z^&j{ihIyZP4)d8lOzs~kc_B-lgH?uKg?u&TWcSV|zckJEJ41*2s&9EOO zbo3|Xu4{ndK=cS9(@Wgner#|!!-FUZ1odo%EKSq9FQ4C$Tsu6>woiDp#^Xo~ykfg5 s(~FbrvN!1%*7zWP+Xatl%J!QS5V2wAOIExsRYS_z+}Hr-+)L) zhBAg~^_0}&15o_aRKG-{1t Date: Tue, 10 Mar 2026 20:03:01 +0800 Subject: [PATCH 09/11] whenxuan: delete the print in function --- s2generator/augmentation/_empirical_mode_modulation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/s2generator/augmentation/_empirical_mode_modulation.py b/s2generator/augmentation/_empirical_mode_modulation.py index 233962d..6e00318 100644 --- a/s2generator/augmentation/_empirical_mode_modulation.py +++ b/s2generator/augmentation/_empirical_mode_modulation.py @@ -86,7 +86,6 @@ def empirical_mode_modulation( emd = EMD( max_imfs=max_imfs, spline_kind=spline_kind, extrema_detection=extrema_detection ) - print(type(time_series)) imfs = emd.fit_transform(signal=time_series) # Get the number of IMFs extracted From 4528567bc5c84881747d0bf47a87e76b6a07c5dc Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 20:03:25 +0800 Subject: [PATCH 10/11] whenxuan: update the unit test for emd augmentation --- tests/test_augmentation.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index e29396c..7633e34 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -212,6 +212,30 @@ def test_add_linear_trend(self) -> None: msg="Trended series is identical to original series in `test_add_linear_trend` method", ) + def test_empirical_mode_modulation(self) -> None: + """Test the function for performing empirical mode modulation on time series data.""" + # Generate a simple time series for testing + t = np.linspace(0, 1, 100) + series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + + # Apply empirical mode modulation augmentation + modulated_series = empirical_mode_modulation( + time_series=series.copy(), rng=self.rng + ) + + # Check that the output has the same length as the input + self.assertEqual( + len(modulated_series), + len(series), + msg="Output length does not match input length in `test_empirical_mode_modulation` method", + ) + + # Check that the output is different from the input (since we applied empirical mode modulation) + self.assertFalse( + np.array_equal(modulated_series, series), + msg="Modulated series is identical to original series in `test_empirical_mode_modulation` method", + ) + if __name__ == "__main__": unittest.main() From 3c8359b4effed34ac9ab8e6fb7b492d1290c427e Mon Sep 17 00:00:00 2001 From: wwhenxuan Date: Tue, 10 Mar 2026 20:16:23 +0800 Subject: [PATCH 11/11] whenxuan: update the unit test for wiener filter --- tests/test_augmentation.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 7633e34..675494f 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -236,6 +236,27 @@ def test_empirical_mode_modulation(self) -> None: msg="Modulated series is identical to original series in `test_empirical_mode_modulation` method", ) + def test_wiener_filter(self) -> None: + """Test the function for performing Wiener filtering on time series data.""" + # Generate a simple time series for testing + t = np.linspace(0, 1, 100) + series = np.sin(2 * np.pi * 5 * t) + 0.5 * np.random.normal(size=100) + + # Apply Wiener filter augmentation + filtered_series = wiener_filter(time_series=series.copy()) + + # Check that the output has the same length as the input + self.assertEqual( + len(filtered_series), + len(series), + msg="Output length does not match input length in `test_wiener_filter` method", + ) + + # Check that the output is different from the input (since we applied Wiener filtering) + self.assertFalse( + np.array_equal(filtered_series, series), + msg="Filtered series is identical to original series in `test_wiener_filter` method", + ) if __name__ == "__main__": unittest.main()