From 00eac850a20f51c6db73baa944498a06cbe8f5b7 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Tue, 10 Jun 2025 08:58:27 +0200 Subject: [PATCH 1/5] FIX: provide subject and subjects_dir for patch generation --- examples/test_sample.py | 81 ++++++++++++++++++++++++++++++++++++ src/meegsim/simulate.py | 14 ++++++- src/meegsim/source_groups.py | 35 +++++++++++++++- src/meegsim/sources.py | 13 +++--- 4 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 examples/test_sample.py diff --git a/examples/test_sample.py b/examples/test_sample.py new file mode 100644 index 0000000..28a86ed --- /dev/null +++ b/examples/test_sample.py @@ -0,0 +1,81 @@ +""" +Adjustment of local SNR +======================== + +This example shows how the local SNR can be adjusted. +""" + +import mne +import matplotlib.pyplot as plt + +from mne.datasets import sample + +from meegsim.location import select_random +from meegsim.simulate import SourceSimulator +from meegsim.waveform import narrowband_oscillation + +# %% +# First, we load the head model and associated source space: + +# Paths +data_path = sample.data_path() / "MEG" / "sample" +fwd_path = data_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" +raw_path = data_path / "sample_audvis_raw.fif" +subjects_dir = sample.data_path() / "subjects" + +# Load the prerequisites: fwd, src, and info +fwd = mne.read_forward_solution(fwd_path) +fwd = mne.convert_forward_solution(fwd, force_fixed=True) +raw = mne.io.read_raw(raw_path) +src = fwd["src"] +info = raw.info + +# Pick EEG channels only +eeg_idx = mne.pick_types(info, eeg=True) +info_eeg = mne.pick_info(info, eeg_idx) +fwd_eeg = fwd.pick_channels(info_eeg.ch_names) + +# %% +# We simulate the same configuration (100 noise sources and 3 point sources) +# several times with different levels of SNR. As shown in the picture below, +# the average alpha power increases relative to the 1/f level with higher SNR: + +# Simulation parameters +sfreq = 250 +duration = 60 +seed = 123 + +fig, axes = plt.subplots(ncols=3, figsize=(8, 3)) +snr_values = [1, 5, 10] + +for i_snr, target_snr in enumerate(snr_values): + sim = SourceSimulator(src) + + # Select some vertices randomly + sim.add_patch_sources( + location=select_random, + waveform=narrowband_oscillation, + location_params=dict(n=3), + waveform_params=dict(fmin=8, fmax=12), + snr=target_snr, + snr_params=dict(fmin=8, fmax=12), + extents=5, + subjects_dir=subjects_dir, + names=["s1", "s2", "s3"], + ) + + sim.add_noise_sources(location=select_random, location_params=dict(n=100)) + + sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) + raw = sc.to_raw(fwd, info) + + spec = raw.compute_psd(fmax=40, n_fft=sfreq, n_overlap=sfreq // 2, n_per_seg=sfreq) + spec.plot(average=True, dB=False, axes=axes[i_snr], amplitude=False) + + axes[i_snr].set_title(f"SNR={target_snr}") + axes[i_snr].set_xlabel("Frequency (Hz)") + axes[i_snr].set_ylabel("PSD (uV^2/Hz)") + axes[i_snr].set_ylim([0, 0.125]) + +fig.tight_layout() +plt.show() diff --git a/src/meegsim/simulate.py b/src/meegsim/simulate.py index 6c2eceb..135baef 100644 --- a/src/meegsim/simulate.py +++ b/src/meegsim/simulate.py @@ -168,6 +168,8 @@ def add_patch_sources( waveform_params=dict(), snr_params=dict(), extents=None, + subject=None, + subjects_dir=None, names=None, ): """ @@ -227,6 +229,14 @@ def add_patch_sources( grown (using :func:`mne.grow_labels`) from vertices specified in location according to the provided values of extent. If a single number is provided, all patch sources have the same extent. + subject : str, optional + Subject name, only used when growing patch sources from the central vertex. + If None (default), it is derived from the ``src`` object provided when + initializing the simulator. + subject_dir : str, optional + Path to the directory with FreeSurfer output, only used when growing patch + sources from the central vertex. If None (default), it is resolved automatically + by MNE-Python. names : list, optional A list of names for each source. If not specified, the names will be autogenerated using the format 'auto-sgN-sM', where N is the index @@ -247,8 +257,10 @@ def add_patch_sources( std=std, location_params=location_params, waveform_params=waveform_params, - extents=extents, snr_params=snr_params, + extents=extents, + subject=subject, + subjects_dir=subjects_dir, names=names, group=f"sg{next_group_idx}", existing=self._sources, diff --git a/src/meegsim/source_groups.py b/src/meegsim/source_groups.py index 8cdad60..c409c23 100644 --- a/src/meegsim/source_groups.py +++ b/src/meegsim/source_groups.py @@ -162,7 +162,17 @@ def create( class PatchSourceGroup(_BaseSourceGroup): def __init__( - self, n_sources, location, waveform, snr, snr_params, std, extents, names + self, + n_sources, + location, + waveform, + snr, + snr_params, + std, + extents, + subject, + subjects_dir, + names, ): super().__init__() @@ -178,6 +188,8 @@ def __init__( self.std = std self.names = names self.extents = extents + self.subject = subject + self.subjects_dir = subjects_dir def __repr__(self): location_desc = "list" @@ -204,6 +216,8 @@ def simulate(self, src, times, random_state=None): self.std, self.names, self.extents, + self.subject, + self.subjects_dir, random_state=random_state, ) @@ -219,6 +233,8 @@ def create( waveform_params, snr_params, extents, + subject, + subjects_dir, names, group, existing, @@ -247,6 +263,10 @@ def create( Additional parameters for the adjustment of SNR. extents: list Extents (radius in mm) of each patch provided by the user. + subject: str, optional + Subject name. + subject_dir: str, optional + Path to the directory with FreeSurfer output. names: The names of sources provided by the user. group: @@ -280,4 +300,15 @@ def create( else: check_names(names, n_sources, existing) - return cls(n_sources, location, waveform, snr, snr_params, std, extents, names) + return cls( + n_sources, + location, + waveform, + snr, + snr_params, + std, + extents, + subject, + subjects_dir, + names, + ) diff --git a/src/meegsim/sources.py b/src/meegsim/sources.py index 2ae104d..2e6a2fc 100644 --- a/src/meegsim/sources.py +++ b/src/meegsim/sources.py @@ -299,6 +299,8 @@ def _create( stds, names, extents, + subject, + subjects_dir, random_state=None, ): """ @@ -323,8 +325,11 @@ def _create( if data.shape[1] != len(times): raise ValueError("The number of samples in waveform does not match") - # find patch vertices - subject = src[0].get("subject_his_id", None) + # Pick subject name from src if not provided explicitly + if subject is None: + subject = src[0].get("subject_his_id", None) + + # Find patch vertices patch_vertices = [] patch_stds = [] if isinstance(stds, mne.SourceEstimate) else stds for isource, extent in enumerate(extents): @@ -346,9 +351,7 @@ def _create( continue # Grow the patch from center otherwise - patch = mne.grow_labels( - subject, vertno, extent, src_idx, subjects_dir=None - )[0] + patch = mne.grow_labels(subject, vertno, extent, src_idx, subjects_dir)[0] # Prune vertices patch_vertno = [ From caf283c23a091f984671e4c01848172627890455 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Tue, 22 Jul 2025 16:32:56 +0200 Subject: [PATCH 2/5] TEST: fix tests --- examples/test_sample.py | 2 +- src/meegsim/sources.py | 8 +++++++- tests/test_simulate.py | 2 ++ tests/test_snr.py | 4 ++++ tests/test_sources.py | 8 +++++--- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/examples/test_sample.py b/examples/test_sample.py index 28a86ed..362925e 100644 --- a/examples/test_sample.py +++ b/examples/test_sample.py @@ -49,7 +49,7 @@ snr_values = [1, 5, 10] for i_snr, target_snr in enumerate(snr_values): - sim = SourceSimulator(src) + sim = SourceSimulator(src, snr_mode="local") # Select some vertices randomly sim.add_patch_sources( diff --git a/src/meegsim/sources.py b/src/meegsim/sources.py index 2e6a2fc..3534684 100644 --- a/src/meegsim/sources.py +++ b/src/meegsim/sources.py @@ -351,7 +351,13 @@ def _create( continue # Grow the patch from center otherwise - patch = mne.grow_labels(subject, vertno, extent, src_idx, subjects_dir)[0] + patch = mne.grow_labels( + subject=subject, + seeds=vertno, + extents=extent, + hemis=src_idx, + subjects_dir=subjects_dir, + )[0] # Prune vertices patch_vertno = [ diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 6044a8c..c2e506b 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -383,6 +383,8 @@ def test_simulate_std_adjustment(): snr_params=dict(), std=[3], extents=[None], + subject=None, + subjects_dir=None, names=["patch"], ), ] diff --git a/tests/test_snr.py b/tests/test_snr.py index e65549f..5b144a1 100644 --- a/tests/test_snr.py +++ b/tests/test_snr.py @@ -243,6 +243,8 @@ def test_adjust_snr_local_patch(adjust_snr_mock): snr_params=dict(fmin=8, fmax=12), std=1, extents=None, + subject=None, + subjects_dir=None, names=["s1"], ), PatchSourceGroup( @@ -253,6 +255,8 @@ def test_adjust_snr_local_patch(adjust_snr_mock): snr_params=dict(), std=1, extents=None, + subject=None, + subjects_dir=None, names=["s2"], ), ] diff --git a/tests/test_sources.py b/tests/test_sources.py index f0fc3a3..425ec81 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -347,6 +347,8 @@ def test_patchsource_create_with_extent(): stds=stds, names=names, extents=extents, + subject=None, + subjects_dir=None, random_state=None, ) @@ -369,7 +371,7 @@ def test_patchsource_create_with_extent(): # Verify that grow_labels was called once for the source with extent mock_grow_labels.assert_called_once_with( - "meegsim", [2], 3, 0, subjects_dir=None + subject="meegsim", seeds=[2], extents=3, hemis=0, subjects_dir=None ) @@ -391,13 +393,13 @@ def test_patchsource_create_std_sourceestimate(get_param_mock): # Values are passed directly - the mock should not be used sources = PatchSource._create( - src, times, n_sources, location, waveform, stds, names, extents + src, times, n_sources, location, waveform, stds, names, extents, None, None ) get_param_mock.assert_not_called() # Values are passed in stc - the mock should be called once per patch sources = PatchSource._create( - src, times, n_sources, location, waveform, std_stc, names, extents + src, times, n_sources, location, waveform, std_stc, names, extents, None, None ) assert get_param_mock.call_count == n_sources From 92f230c046b7f385c27c0d53194e318ee84a987a Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Tue, 22 Jul 2025 16:46:55 +0200 Subject: [PATCH 3/5] Add local test for patch growing, skip on CI --- .github/workflows/python-tests.yml | 3 ++ tests/test_subject.py | 60 ++++++++++++++++++++++++++++++ tests/utils/misc.py | 5 +++ 3 files changed, 68 insertions(+) create mode 100644 tests/test_subject.py create mode 100644 tests/utils/misc.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 17592e6..9013a04 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -9,6 +9,9 @@ on: pull_request: branches: [ "master" ] +env: + BUILD_ENV: ci + permissions: contents: read diff --git a/tests/test_subject.py b/tests/test_subject.py new file mode 100644 index 0000000..15ae7c0 --- /dev/null +++ b/tests/test_subject.py @@ -0,0 +1,60 @@ +""" +Tests for ensuring that subject-specific info is processed and set correctly. +""" + +import mne +import pytest + +from mne.datasets import sample + +from meegsim.location import select_random +from meegsim.simulate import SourceSimulator +from meegsim.waveform import narrowband_oscillation + +from utils.misc import running_on_ci + + +def prepare_real_data(): + data_path = sample.data_path() / "MEG" / "sample" + fwd_path = data_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" + raw_path = data_path / "sample_audvis_raw.fif" + subjects_dir = sample.data_path() / "subjects" + + # Load the prerequisites: fwd, src, and info + fwd = mne.read_forward_solution(fwd_path) + fwd = mne.convert_forward_solution(fwd, force_fixed=True) + raw = mne.io.read_raw(raw_path) + info = raw.info + + # Pick EEG channels only + eeg_idx = mne.pick_types(info, eeg=True) + info_eeg = mne.pick_info(info, eeg_idx) + fwd_eeg = fwd.pick_channels(info_eeg.ch_names) + + return fwd_eeg, info_eeg, subjects_dir + + +@pytest.mark.skipif(running_on_ci(), reason="Skip tests with real data on CI") +def test_grow_patch_source(): + fwd, info, subjects_dir = prepare_real_data() + src = fwd["src"] + + sfreq = 250 + duration = 60 + seed = 123 + + sim = SourceSimulator(src, snr_mode="local") + + # Select some vertices randomly + sim.add_patch_sources( + location=select_random, + waveform=narrowband_oscillation, + location_params=dict(n=3), + waveform_params=dict(fmin=8, fmax=12), + extents=5, + subjects_dir=subjects_dir, + names=["s1", "s2", "s3"], + ) + + sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) + sc.to_raw(fwd, info) diff --git a/tests/utils/misc.py b/tests/utils/misc.py new file mode 100644 index 0000000..350235c --- /dev/null +++ b/tests/utils/misc.py @@ -0,0 +1,5 @@ +import os + + +def running_on_ci(): + return os.environ.get("BUILD_ENV", "local") == "ci" From a042b48fd4da1b8210457a97e168d5f734039162 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Tue, 22 Jul 2025 16:54:01 +0200 Subject: [PATCH 4/5] TEST: add test for plotting the configuration --- examples/test_sample.py | 3 +++ tests/test_subject.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/examples/test_sample.py b/examples/test_sample.py index 362925e..241fd08 100644 --- a/examples/test_sample.py +++ b/examples/test_sample.py @@ -79,3 +79,6 @@ fig.tight_layout() plt.show() + +brain = sc.plot("sample", subjects_dir=subjects_dir) +brain.close() diff --git a/tests/test_subject.py b/tests/test_subject.py index 15ae7c0..c372460 100644 --- a/tests/test_subject.py +++ b/tests/test_subject.py @@ -58,3 +58,20 @@ def test_grow_patch_source(): sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) sc.to_raw(fwd, info) + + +@pytest.mark.skipif(running_on_ci(), reason="Skip tests with real data on CI") +def test_sourceconfiguration_plot(): + fwd, _, subjects_dir = prepare_real_data() + src = fwd["src"] + + sfreq = 250 + duration = 60 + seed = 123 + + sim = SourceSimulator(src, snr_mode="local") + sim.add_noise_sources(location=select_random, location_params=dict(n=10)) + + sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) + brain = sc.plot(subject="sample", subjects_dir=subjects_dir) + brain.close() From 877d473faf8723dcc324d8ee62eca7388920ed69 Mon Sep 17 00:00:00 2001 From: Nikolai Kapralov <4dvlup@gmail.com> Date: Tue, 22 Jul 2025 16:54:15 +0200 Subject: [PATCH 5/5] Remove the example --- examples/test_sample.py | 84 ----------------------------------------- 1 file changed, 84 deletions(-) delete mode 100644 examples/test_sample.py diff --git a/examples/test_sample.py b/examples/test_sample.py deleted file mode 100644 index 241fd08..0000000 --- a/examples/test_sample.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Adjustment of local SNR -======================== - -This example shows how the local SNR can be adjusted. -""" - -import mne -import matplotlib.pyplot as plt - -from mne.datasets import sample - -from meegsim.location import select_random -from meegsim.simulate import SourceSimulator -from meegsim.waveform import narrowband_oscillation - -# %% -# First, we load the head model and associated source space: - -# Paths -data_path = sample.data_path() / "MEG" / "sample" -fwd_path = data_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" -raw_path = data_path / "sample_audvis_raw.fif" -subjects_dir = sample.data_path() / "subjects" - -# Load the prerequisites: fwd, src, and info -fwd = mne.read_forward_solution(fwd_path) -fwd = mne.convert_forward_solution(fwd, force_fixed=True) -raw = mne.io.read_raw(raw_path) -src = fwd["src"] -info = raw.info - -# Pick EEG channels only -eeg_idx = mne.pick_types(info, eeg=True) -info_eeg = mne.pick_info(info, eeg_idx) -fwd_eeg = fwd.pick_channels(info_eeg.ch_names) - -# %% -# We simulate the same configuration (100 noise sources and 3 point sources) -# several times with different levels of SNR. As shown in the picture below, -# the average alpha power increases relative to the 1/f level with higher SNR: - -# Simulation parameters -sfreq = 250 -duration = 60 -seed = 123 - -fig, axes = plt.subplots(ncols=3, figsize=(8, 3)) -snr_values = [1, 5, 10] - -for i_snr, target_snr in enumerate(snr_values): - sim = SourceSimulator(src, snr_mode="local") - - # Select some vertices randomly - sim.add_patch_sources( - location=select_random, - waveform=narrowband_oscillation, - location_params=dict(n=3), - waveform_params=dict(fmin=8, fmax=12), - snr=target_snr, - snr_params=dict(fmin=8, fmax=12), - extents=5, - subjects_dir=subjects_dir, - names=["s1", "s2", "s3"], - ) - - sim.add_noise_sources(location=select_random, location_params=dict(n=100)) - - sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) - raw = sc.to_raw(fwd, info) - - spec = raw.compute_psd(fmax=40, n_fft=sfreq, n_overlap=sfreq // 2, n_per_seg=sfreq) - spec.plot(average=True, dB=False, axes=axes[i_snr], amplitude=False) - - axes[i_snr].set_title(f"SNR={target_snr}") - axes[i_snr].set_xlabel("Frequency (Hz)") - axes[i_snr].set_ylabel("PSD (uV^2/Hz)") - axes[i_snr].set_ylim([0, 0.125]) - -fig.tight_layout() -plt.show() - -brain = sc.plot("sample", subjects_dir=subjects_dir) -brain.close()