Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions src/meegsim/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from meegsim._check import check_numeric
from meegsim.snr import get_variance, amplitude_adjustment_factor
from meegsim.utils import normalize_variance
from meegsim.waveform import narrowband_oscillation
from meegsim.waveform import narrowband_oscillation, white_noise


def constant_phase_shift(waveform, sfreq, phase_lag, m=1, n=1, random_state=None):
Expand Down Expand Up @@ -140,7 +140,9 @@ def ppc_von_mises(
return normalize_variance(np.real(waveform_coupled))


def _shifted_copy_with_noise(waveform, sfreq, phase_lag, snr, fmin, fmax, random_state):
def _shifted_copy_with_noise(
waveform, sfreq, phase_lag, snr, fmin, fmax, band_limited, random_state
):
"""
Generate a coupled time series by (1) applying a constant phase shift to the input
waveform and (2) mixing it with noise to achieve a desired level of signal-to-noise
Expand All @@ -149,14 +151,20 @@ def _shifted_copy_with_noise(waveform, sfreq, phase_lag, snr, fmin, fmax, random
shifted_waveform = constant_phase_shift(waveform, sfreq, phase_lag)
signal_var = get_variance(shifted_waveform, sfreq, fmin, fmax, filter=True)

# NOTE: we use another randomly generated narrowband oscillation as noise here
# so that only the frequency band of interest is affected. If a broadband signal
# is provided as input, this behavior might not be desirable (?). In this case,
# it can be addressed with an additional parameter
# NOTE: to make coupling band-limited (substantial only in the band of interest),
# we need to corrupt the rest of the coherence spectra with white noise,
# affecting other parts of the signal apart from the frequency band of interest.
#
# For oscillations as our main case this is not a big deal but might be important
# for other signals. If we filter the added noise in the frequency band of
# interest, it leads to flat connectivity spectra but only affects target frequencies
times = np.arange(waveform.size) / sfreq
noise_waveform = narrowband_oscillation(
n_series=1, times=times, fmin=fmin, fmax=fmax, random_state=random_state
)
if band_limited:
noise_waveform = white_noise(n_series=1, times=times, random_state=random_state)
else:
noise_waveform = narrowband_oscillation(
n_series=1, times=times, fmin=fmin, fmax=fmax, random_state=random_state
)
noise_var = get_variance(noise_waveform, sfreq, fmin, fmax, filter=True)

# Process the corner cases
Expand All @@ -173,16 +181,21 @@ def _shifted_copy_with_noise(waveform, sfreq, phase_lag, snr, fmin, fmax, random
return normalize_variance(coupled_waveform)


def _get_required_snr(coh):
def _get_required_snr(coh, band_limited):
"""
Calculate the value of SNR that is required to obtain desired coherence
between a waveform and its copy mixed with noise.
"""
# NOTE: prevent infinite SNR to always mix some noise in case we need to make
# the coupling band-limited
if band_limited and np.isclose(coh, 1, atol=1e-3):
coh = 0.999

return np.divide(coh**2, 1 - coh**2)


def ppc_shifted_copy_with_noise(
waveform, sfreq, phase_lag, coh, fmin, fmax, random_state=None
waveform, sfreq, phase_lag, coh, fmin, fmax, band_limited=True, random_state=None
):
"""
Generate a time series with desired level of coherence with the provided waveform
Expand Down Expand Up @@ -212,6 +225,12 @@ def ppc_shifted_copy_with_noise(
fmax : float
Upper cutoff frequency of the frequency band of interest (in Hz).

band_limited : bool
Whether to limit coupling only to the frequency band of interest (True by
default). If set to False, coupling will be the same for all frequencies,
resulting in a flat connectivity spectra. However, the signal outside of the
frequency band of interest will be modified negligibly.

random_state : None (default) or int
Seed for the random number generator. If None (default), results will vary
between function calls. Use a fixed value for reproducibility.
Expand All @@ -230,13 +249,14 @@ def ppc_shifted_copy_with_noise(
</auto_examples/plot_coupling>`.
"""
check_numeric("coherence", coh, bounds=(0, 1), allow_none=False)
snr = _get_required_snr(coh)
snr = _get_required_snr(coh, band_limited)
return _shifted_copy_with_noise(
waveform=waveform,
sfreq=sfreq,
phase_lag=phase_lag,
snr=snr,
fmin=fmin,
fmax=fmax,
band_limited=band_limited,
random_state=random_state,
)
21 changes: 13 additions & 8 deletions tests/test_coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,18 @@ def test_reproducibility_with_random_state(coupling_fun, params):


@pytest.mark.parametrize(
"coh,expected_snr",
"coh,band_limited,expected_snr",
[
(1.0, np.inf),
(1.0 / np.sqrt(2), 1.0),
(0.0, 0.0),
(1.0, False, np.inf),
(1.0, True, 499.25), # pre-computed for coh=0.999
(1.0 / np.sqrt(2), False, 1.0),
(1.0 / np.sqrt(2), True, 1.0),
(0.0, False, 0.0),
(0.0, True, 0.0),
],
)
def test_get_required_snr(coh, expected_snr):
assert np.isclose(_get_required_snr(coh), expected_snr)
def test_get_required_snr(coh, band_limited, expected_snr):
assert np.isclose(_get_required_snr(coh, band_limited), expected_snr)


@pytest.mark.parametrize(
Expand All @@ -186,6 +189,7 @@ def test_ppc_shifted_copy_with_noise(target_coh, tol):
coh=target_coh,
fmin=8,
fmax=12,
band_limited=False,
random_state=seed,
)
actual_coh = compute_plv(waveform, coupled, m=1, n=1, coh=True)
Expand All @@ -200,6 +204,7 @@ def test_ppc_shifted_copy_with_noise(target_coh, tol):
coh=target_coh,
fmin=8,
fmax=12,
band_limited=False,
random_state=seed,
)
actual_coh = compute_plv(waveform, coupled, m=1, n=1, coh=True)
Expand All @@ -212,7 +217,7 @@ def test_shifted_copy_with_noise_infinite_snr():
waveform = np.sqrt(2) * prepare_sinusoid(f=10, sfreq=sfreq, duration=30)

# Infinite SNR with no phase lag should return the input
coupled = _shifted_copy_with_noise(waveform, sfreq, 0, np.inf, 8, 12, None)
coupled = _shifted_copy_with_noise(waveform, sfreq, 0, np.inf, 8, 12, False, None)
assert np.allclose(coupled, waveform)


Expand All @@ -222,6 +227,6 @@ def test_shifted_copy_with_noise_zero_snr(osc_mock):
waveform = np.sqrt(2) * prepare_sinusoid(f=10, sfreq=sfreq, duration=30)

# Zero SNR should return noise waveform (mock in our case)
coupled = _shifted_copy_with_noise(waveform, sfreq, 0, 0, 8, 12, None)
coupled = _shifted_copy_with_noise(waveform, sfreq, 0, 0, 8, 12, False, None)
assert np.allclose(coupled, 1.0)
osc_mock.assert_called_once()