diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b2f61a..cfe85b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ noise ([#58](https://github.com/ctrltz/meegsim/pull/58)) - A method for setting phase-phase coupling by adding noise to the shifted copy of input waveform ([#71](https://github.com/ctrltz/meegsim/pull/71)) - Function to convert the sources to mne.Label ([#73](https://github.com/ctrltz/meegsim/pull/73)) - Quick list-like access to the simulated sources ([#82](https://github.com/ctrltz/meegsim/pull/82)) +- Control over the amplitude envelope of the coupled waveform ([#87](https://github.com/ctrltz/meegsim/pull/87)) ### Changed diff --git a/docs/api/coupling.rst b/docs/api/coupling.rst index e6cd55b..d111368 100644 --- a/docs/api/coupling.rst +++ b/docs/api/coupling.rst @@ -6,6 +6,6 @@ Coupling methods .. autosummary:: :toctree: ../generated/ - ppc_shifted_copy_with_noise + ppc_constant_phase_shift ppc_von_mises - constant_phase_shift + ppc_shifted_copy_with_noise diff --git a/src/meegsim/coupling.py b/src/meegsim/coupling.py index 1f786dc..003d1c9 100644 --- a/src/meegsim/coupling.py +++ b/src/meegsim/coupling.py @@ -7,13 +7,48 @@ from scipy.stats import vonmises from scipy.signal import butter, filtfilt, hilbert -from meegsim._check import check_numeric +from meegsim._check import check_numeric, check_option from meegsim.snr import get_variance, amplitude_adjustment_factor from meegsim.utils import normalize_variance from meegsim.waveform import narrowband_oscillation, white_noise -def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None): +def _get_envelope(waveform, envelope, sfreq, fmin=None, fmax=None, random_state=None): + check_option( + "the amplitude envelope of the coupled waveform", envelope, ["same", "random"] + ) + if not np.iscomplexobj(waveform): + waveform = hilbert(waveform) + + if envelope == "same": + return np.abs(waveform) + + if fmin is None or fmax is None: + raise ValueError( + "Frequency limits are required for generating the envelope of the coupled waveform" + ) + times = np.arange(waveform.size) / sfreq + random_waveform = narrowband_oscillation( + 1, times, fmin=fmin, fmax=fmax, random_state=random_state + ) + random_waveform = hilbert(random_waveform) + + # TODO: here we could also mix original and random envelope with different + # values of SNR to achieve smooth control over the resulting envelope correlation + return np.abs(random_waveform) + + +def ppc_constant_phase_shift( + waveform, + sfreq, + phase_lag, + fmin=None, + fmax=None, + envelope="random", + m=1, + n=1, + random_state=None, +): """ Generate a time series that is phase coupled to the input time series with a constant phase lag. @@ -32,21 +67,32 @@ def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None The input signal to be processed. It can be a real or complex time series. sfreq : float - Sampling frequency of the signal, in Hz. This argument is not used in this - function but is accepted for consistency with other coupling methods. + Sampling frequency of the signal, in Hz. phase_lag : float Constant phase lag to apply to the waveform in radians. - m : int, optional + envelope : str, {"same", "random"} + Controls the amplitude envelope of the coupled waveform to be either randomly + generated (default) or to be the same as the envelope of the input waveform. + + fmin : float, optional + Lower cutoff frequency for the oscillation that gives rise to the random + amplitude envelope (only if the ``envelope`` is set to ``"random"``). + + fmax : float, optional + Upper cutoff frequency for the oscillation that gives rise to the random + amplitude envelope (only if the ``envelope`` is set to ``"random"``). + + m : float, optional Multiplier for the base frequency of the output oscillation, default is 1. - n : int, optional + n : float, optional Multiplier for the base frequency of the input oscillation, default is 1. random_state : None, optional - This parameter is accepted for consistency with other coupling functions - but not used since no randomness is involved. + Random state can be fixed to provide reproducible results if the envelope + is generated randomly. If not set, the results may differ between function calls. Returns ------- @@ -56,16 +102,36 @@ def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None if not np.iscomplexobj(waveform): waveform = hilbert(waveform) - waveform_amp = np.abs(waveform) + waveform_amp = _get_envelope(waveform, envelope, sfreq, fmin, fmax, random_state) waveform_angle = np.angle(waveform) - waveform_coupled = waveform_amp * np.exp( - 1j * m / n * waveform_angle + 1j * phase_lag + waveform_coupled = np.real( + waveform_amp * np.exp(1j * m / n * waveform_angle + 1j * phase_lag) + ) + if envelope == "same": + return normalize_variance(waveform_coupled) + + # NOTE: if the envelope was modified, we filter the result again in the target + # frequency range to suppress possible distortions due to merging amplitude + # envelope and phase from different time series + b, a = butter( + N=2, Wn=np.array([m / n * fmin, m / n * fmax]) / sfreq * 2, btype="bandpass" ) - return normalize_variance(np.real(waveform_coupled)) + waveform_coupled = filtfilt(b, a, waveform_coupled) + + return normalize_variance(waveform_coupled) def ppc_von_mises( - waveform, sfreq, phase_lag, kappa, fmin, fmax, m=1, n=1, random_state=None + waveform, + sfreq, + phase_lag, + kappa, + fmin, + fmax, + envelope="random", + m=1, + n=1, + random_state=None, ): """ Generate a time series that is phase coupled to the input time series with @@ -102,10 +168,14 @@ def ppc_von_mises( fmax: float Upper cutoff frequency of the base frequency harmonic (in Hz). - m : int, optional + envelope : str, {"same", "random"} + Controls the amplitude envelope of the coupled waveform to be either randomly + generated (default) or to be the same as the envelope of the input waveform. + + m : float, optional Multiplier for the base frequency of the output oscillation, default is 1. - n : int, optional + n : float, optional Multiplier for the base frequency of the input oscillation, default is 1. random_state : None (default) or int @@ -121,23 +191,26 @@ def ppc_von_mises( if not np.iscomplexobj(waveform): waveform = hilbert(waveform) - waveform_amp = np.abs(waveform) + waveform_amp = _get_envelope(waveform, envelope, sfreq, fmin, fmax, random_state) waveform_angle = np.angle(waveform) - n_samples = len(waveform) + n_samples = waveform.size ph_distr = vonmises.rvs( kappa, loc=phase_lag, size=n_samples, random_state=random_state ) - tmp_waveform = np.real( + waveform_coupled = np.real( waveform_amp * np.exp(1j * m / n * waveform_angle + 1j * ph_distr) ) + + # NOTE: we filter the result again in the target frequency range to suppress + # possible distortions due to separate adjustment of the phase and amplitude + # of the coupled time series b, a = butter( N=2, Wn=np.array([m / n * fmin, m / n * fmax]) / sfreq * 2, btype="bandpass" ) - tmp_waveform = filtfilt(b, a, tmp_waveform) - waveform_coupled = waveform_amp * np.exp(1j * np.angle(hilbert(tmp_waveform))) + waveform_coupled = filtfilt(b, a, waveform_coupled) - return normalize_variance(np.real(waveform_coupled)) + return normalize_variance(waveform_coupled) def _shifted_copy_with_noise( @@ -148,7 +221,9 @@ def _shifted_copy_with_noise( waveform and (2) mixing it with noise to achieve a desired level of signal-to-noise ratio, which determines the resulting phase-phase and amplitude-amplitude coupling. """ - shifted_waveform = constant_phase_shift(waveform, sfreq, phase_lag) + shifted_waveform = ppc_constant_phase_shift( + waveform, sfreq, phase_lag, envelope="same" + ) signal_var = get_variance(shifted_waveform, sfreq, fmin, fmax, filter=True) # NOTE: to make coupling band-limited (substantial only in the band of interest), diff --git a/src/meegsim/utils.py b/src/meegsim/utils.py index f178456..a46d305 100644 --- a/src/meegsim/utils.py +++ b/src/meegsim/utils.py @@ -3,7 +3,6 @@ import warnings from mne.io.constants import FIFF -from scipy.special import i1, i0 logger = logging.getLogger("meegsim") @@ -79,15 +78,16 @@ def normalize_variance(data): Returns ------- - data: array + data_norm: array Normalized time series. The variance of each row is equal to 1. """ + # NOTE: make a copy to keep the original waveform intact + data_norm = data.copy() + if data_norm.ndim == 1: + return data_norm / np.std(data_norm) - if data.ndim == 1: - return data / np.std(data) - - data /= np.std(data, axis=-1)[:, np.newaxis] - return data + data_norm /= np.std(data_norm, axis=-1)[:, np.newaxis] + return data_norm def _extract_hemi(src): @@ -195,10 +195,6 @@ def unpack_vertices(vertices_lists): return unpacked_vertices -def theoretical_plv(kappa): - return i1(kappa) / i0(kappa) - - def vertices_to_mne(vertices, src): """ Convert the vertices to the MNE format (list of lists). diff --git a/tests/test_coupling.py b/tests/test_coupling.py index 5cd1d07..ee7451a 100644 --- a/tests/test_coupling.py +++ b/tests/test_coupling.py @@ -6,33 +6,66 @@ from scipy.signal import hilbert from meegsim.coupling import ( - constant_phase_shift, + ppc_constant_phase_shift, ppc_von_mises, ppc_shifted_copy_with_noise, + _get_envelope, _get_required_snr, _shifted_copy_with_noise, ) -from meegsim.utils import get_sfreq, theoretical_plv +from meegsim.utils import get_sfreq from utils.prepare import prepare_sinusoid -def prepare_inputs(): +def prepare_inputs(sfreq=250, duration=60): n_series = 2 - fs = 1000 - times = np.arange(0, 1, 1 / fs) + times = np.arange(0, duration, 1 / sfreq) return n_series, len(times), times @pytest.mark.parametrize( "phase_lag", [np.pi / 4, np.pi / 3, np.pi / 2, np.pi, 2 * np.pi] ) -def test_constant_phase_shift(phase_lag): +def test_ppc_constant_phase_shift_same_envelope(phase_lag): # Test with a simple sinusoidal waveform _, _, times = prepare_inputs() waveform = np.sin(2 * np.pi * 10 * times) - result = constant_phase_shift(waveform, get_sfreq(times), phase_lag) + result = ppc_constant_phase_shift( + waveform, get_sfreq(times), phase_lag, envelope="same", random_state=1234 + ) + + waveform = hilbert(waveform) + result = hilbert(result) + + cplv = compute_plv(waveform, result, m=1, n=1, plv_type="complex") + plv = np.abs(cplv) + test_angle = np.angle(cplv) + + assert plv >= 0.9, f"Expected PLV to be at least 0.9, got {plv}" + assert ( + (np.abs(test_angle) - phase_lag) <= 0.01 + ), f"Test failed: angle is different from phase_lag. difference = {np.round((np.abs(test_angle) - phase_lag),2)}" + + +@pytest.mark.parametrize( + "phase_lag", [np.pi / 4, np.pi / 3, np.pi / 2, np.pi, 2 * np.pi] +) +def test_ppc_constant_phase_shift_random_envelope(phase_lag): + # Test with a simple sinusoidal waveform + _, _, times = prepare_inputs() + waveform = np.sin(2 * np.pi * 10 * times) + + result = ppc_constant_phase_shift( + waveform, + get_sfreq(times), + phase_lag, + envelope="random", + fmin=9.5, + fmax=10.5, + random_state=1234, + ) waveform = hilbert(waveform) result = hilbert(result) @@ -41,20 +74,28 @@ def test_constant_phase_shift(phase_lag): plv = np.abs(cplv) test_angle = np.angle(cplv) - assert plv >= 0.99, f"Test failed: plv is smaller than 0.99. plv = {plv}" + assert plv >= 0.9, f"Expected PLV to be at least 0.9, got {plv}" assert ( (np.abs(test_angle) - phase_lag) <= 0.01 ), f"Test failed: angle is different from phase_lag. difference = {np.round((np.abs(test_angle) - phase_lag),2)}" @pytest.mark.parametrize("m, n", [(2, 1), (3, 1), (5 / 2, 1)]) -def test_constant_phase_shift_harmonics(m, n): +def test_ppc_constant_phase_shift_harmonics_same_envelope(m, n): # Test with different m and n harmonics _, _, times = prepare_inputs() waveform = np.sin(2 * np.pi * 10 * times) phase_lag = np.pi / 3 - result = constant_phase_shift(waveform, get_sfreq(times), phase_lag, m=m, n=n) + result = ppc_constant_phase_shift( + waveform, + get_sfreq(times), + phase_lag, + envelope="same", + m=m, + n=n, + random_state=1234, + ) waveform = hilbert(waveform) result = hilbert(result) @@ -63,21 +104,37 @@ def test_constant_phase_shift_harmonics(m, n): plv = np.abs(cplv) test_angle = np.angle(cplv) - assert plv >= 0.9, f"Test failed: plv is smaller than 0.9. plv = {plv}" + assert plv >= 0.9, f"Expected PLV to be at least 0.9, got {plv}" assert ( (np.abs(test_angle) - phase_lag) <= 0.1 ), f"Test failed: angle is different from phase_lag. difference = {np.round((np.abs(test_angle) - phase_lag),2)}" -@pytest.mark.parametrize("kappa", [0.001, 0.1, 0.5, 1, 5, 10, 50]) -def test_ppc_von_mises(kappa): - # Test kappas that are reliable (more than 0.5) - _, _, times = prepare_inputs() +@pytest.mark.parametrize( + "kappa,lo,hi", + [ + (0.001, 0, 0.2), + (0.01, 0, 0.2), + (0.1, 0.1, 0.4), + (0.3, 0.3, 0.7), + (0.5, 0.5, 0.9), + (1, 0.7, 1), + (10, 0.9, 1), + ], +) +def test_ppc_von_mises_same_envelope_kappa(kappa, lo, hi): + _, _, times = prepare_inputs(duration=120) waveform = np.sin(2 * np.pi * 10 * times) - phase_lag = 0 + phase_lag = np.pi / 4 result = ppc_von_mises( - waveform, get_sfreq(times), phase_lag, kappa=kappa, fmin=8, fmax=12 + waveform, + get_sfreq(times), + phase_lag, + kappa=kappa, + fmin=8, + fmax=12, + random_state=1234, ) waveform = hilbert(waveform) @@ -85,11 +142,52 @@ def test_ppc_von_mises(kappa): cplv = compute_plv(waveform, result, m=1, n=1, plv_type="complex") plv = np.abs(cplv) - plv_theoretical = theoretical_plv(kappa) + test_angle = np.abs(np.angle(cplv)) + + # NOTE: lower and upper bounds were selected by simulating multiple time series, + # this test should prevent large deviations from the expected result due to + # errors in the processing + assert lo <= plv <= hi, f"Expected PLV to be between {lo} and {hi}, got {plv}" + if kappa >= 0.5: + assert np.allclose(test_angle, phase_lag, atol=0.1), ( + f"Expected the actual phase lag ({test_angle:.2f}) to be within " + f"0.1 from the desired one ({phase_lag:.2f})" + ) + + +@pytest.mark.parametrize("kappa,lo,hi", [(0.01, 0, 0.2), (10, 0.8, 1)]) +def test_ppc_von_mises_random_envelope_kappa(kappa, lo, hi): + _, _, times = prepare_inputs(duration=120) + waveform = np.sin(2 * np.pi * 10 * times) + phase_lag = np.pi / 4 - assert ( - plv >= plv_theoretical - ), f"Test failed: plv is smaller than theoretical. plv = {plv}, plv_theoretical = {plv_theoretical}" + result = ppc_von_mises( + waveform, + get_sfreq(times), + phase_lag, + kappa=kappa, + envelope="random", + fmin=8, + fmax=12, + random_state=1234, + ) + + waveform = hilbert(waveform) + result = hilbert(result) + + cplv = compute_plv(waveform, result, m=1, n=1, plv_type="complex") + plv = np.abs(cplv) + test_angle = np.abs(np.angle(cplv)) + + # NOTE: we check only extreme cases here to make sure the random envelope does + # not break the result completely. Still, random envelope might decrease PLV a bit, + # so the bounds are a bit wider + assert lo <= plv <= hi, f"Expected PLV to be between {lo} and {hi}, got {plv}" + if kappa >= 0.5: + assert np.allclose(np.abs(test_angle), phase_lag, atol=0.1), ( + f"Expected the actual phase lag ({test_angle:.2f}) to be within " + f"0.1 from the desired one ({phase_lag:.2f})" + ) @pytest.mark.parametrize("m, n", [(2, 1), (3, 1), (5 / 2, 1)]) @@ -97,11 +195,20 @@ def test_ppc_von_mises_harmonics(m, n): # Test with different m and n harmonics _, _, times = prepare_inputs() waveform = np.sin(2 * np.pi * 10 * times) - phase_lag = 0 + phase_lag = np.pi / 4 kappa = 10 result = ppc_von_mises( - waveform, get_sfreq(times), phase_lag, m=m, n=n, kappa=kappa, fmin=8, fmax=12 + waveform, + get_sfreq(times), + phase_lag, + m=m, + n=n, + envelope="same", + kappa=kappa, + fmin=8, + fmax=12, + random_state=1234, ) waveform = hilbert(waveform) @@ -111,15 +218,20 @@ def test_ppc_von_mises_harmonics(m, n): plv = np.abs(cplv) test_angle = np.angle(cplv) - assert plv >= 0.8, f"Test failed: plv is smaller than 0.8. plv = {plv}" - assert ( - (np.abs(test_angle) - phase_lag) <= 0.1 - ), f"Test failed: angle is different from phase_lag. difference = {np.round((np.abs(test_angle) - phase_lag),2)}" + assert 0.7 <= plv, f"Expected PLV to be at least 0.7, got {plv}" + assert np.allclose(np.abs(test_angle), phase_lag, atol=0.1), ( + f"Expected the actual phase lag ({test_angle:.2f}) to be within " + f"0.1 from the desired one ({phase_lag:.2f})" + ) @pytest.mark.parametrize( "coupling_fun,params", [ + ( + ppc_constant_phase_shift, + dict(phase_lag=np.pi / 4, envelope="random", fmin=8, fmax=12), + ), (ppc_von_mises, dict(kappa=1, phase_lag=np.pi / 4, fmin=8, fmax=12)), ( ppc_shifted_copy_with_noise, @@ -165,6 +277,33 @@ def test_get_required_snr(coh, band_limited, expected_snr): assert np.isclose(_get_required_snr(coh, band_limited), expected_snr) +def test_get_envelope_same(): + sfreq = 250 + waveform = prepare_sinusoid(f=10, sfreq=sfreq, duration=60) + waveform_amp = np.abs(hilbert(waveform)) + envelope = _get_envelope(waveform, envelope="same", sfreq=sfreq) + assert np.allclose(waveform_amp, envelope), "Expected envelope to match input" + + +def test_get_envelope_random(): + sfreq = 250 + fmin, fmax = 8, 12 + seed = 1234 + waveform = prepare_sinusoid(f=10, sfreq=sfreq, duration=60) + waveform_amp = np.abs(hilbert(waveform)) + envelope = _get_envelope( + waveform, + envelope="random", + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + random_state=seed, + ) + assert not np.allclose( + waveform_amp, envelope + ), "Expected envelope not to match input" + + @pytest.mark.parametrize( "target_coh,tol", [ diff --git a/tests/test_integration.py b/tests/test_integration.py index afca711..c9f831e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -4,7 +4,7 @@ import numpy as np -from meegsim.coupling import constant_phase_shift, ppc_von_mises +from meegsim.coupling import ppc_constant_phase_shift, ppc_von_mises from meegsim.location import select_random from meegsim.simulate import SourceSimulator from meegsim.waveform import narrowband_oscillation, white_noise @@ -62,7 +62,13 @@ def test_builtin_methods(): ) # Define several edges to test graph traversal and built-in coupling methods - sim.set_coupling(("point1", "point2"), method=constant_phase_shift, phase_lag=0) + sim.set_coupling( + ("point1", "point2"), + method=ppc_constant_phase_shift, + phase_lag=0, + fmin=8, + fmax=12, + ) sim.set_coupling( coupling={ ("point2", "patch3"): dict(kappa=0.1, phase_lag=-np.pi / 6), diff --git a/tests/test_utils.py b/tests/test_utils.py index 35c206b..5714f8f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -96,11 +96,12 @@ def test_combine_stcs_overlap(): def test_normalize_variance(): data = np.random.randn(10, 1000) + data_orig = data.copy() normalized = normalize_variance(data) - # Should not change the shape but should change the norm - assert data.shape == normalized.shape - assert np.allclose(np.var(normalized, axis=1), 1) + assert data.shape == normalized.shape, "Array shape should not be changed" + assert np.allclose(np.var(normalized, axis=1), 1), "Expected variance to be 1" + assert np.allclose(data_orig, data), "The input waveform should not be modified" def test_get_sfreq():