diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 17592e6..9013a04 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -9,6 +9,9 @@ on: pull_request: branches: [ "master" ] +env: + BUILD_ENV: ci + permissions: contents: read diff --git a/src/meegsim/simulate.py b/src/meegsim/simulate.py index 6c2eceb..135baef 100644 --- a/src/meegsim/simulate.py +++ b/src/meegsim/simulate.py @@ -168,6 +168,8 @@ def add_patch_sources( waveform_params=dict(), snr_params=dict(), extents=None, + subject=None, + subjects_dir=None, names=None, ): """ @@ -227,6 +229,14 @@ def add_patch_sources( grown (using :func:`mne.grow_labels`) from vertices specified in location according to the provided values of extent. If a single number is provided, all patch sources have the same extent. + subject : str, optional + Subject name, only used when growing patch sources from the central vertex. + If None (default), it is derived from the ``src`` object provided when + initializing the simulator. + subject_dir : str, optional + Path to the directory with FreeSurfer output, only used when growing patch + sources from the central vertex. If None (default), it is resolved automatically + by MNE-Python. names : list, optional A list of names for each source. If not specified, the names will be autogenerated using the format 'auto-sgN-sM', where N is the index @@ -247,8 +257,10 @@ def add_patch_sources( std=std, location_params=location_params, waveform_params=waveform_params, - extents=extents, snr_params=snr_params, + extents=extents, + subject=subject, + subjects_dir=subjects_dir, names=names, group=f"sg{next_group_idx}", existing=self._sources, diff --git a/src/meegsim/source_groups.py b/src/meegsim/source_groups.py index 8cdad60..c409c23 100644 --- a/src/meegsim/source_groups.py +++ b/src/meegsim/source_groups.py @@ -162,7 +162,17 @@ def create( class PatchSourceGroup(_BaseSourceGroup): def __init__( - self, n_sources, location, waveform, snr, snr_params, std, extents, names + self, + n_sources, + location, + waveform, + snr, + snr_params, + std, + extents, + subject, + subjects_dir, + names, ): super().__init__() @@ -178,6 +188,8 @@ def __init__( self.std = std self.names = names self.extents = extents + self.subject = subject + self.subjects_dir = subjects_dir def __repr__(self): location_desc = "list" @@ -204,6 +216,8 @@ def simulate(self, src, times, random_state=None): self.std, self.names, self.extents, + self.subject, + self.subjects_dir, random_state=random_state, ) @@ -219,6 +233,8 @@ def create( waveform_params, snr_params, extents, + subject, + subjects_dir, names, group, existing, @@ -247,6 +263,10 @@ def create( Additional parameters for the adjustment of SNR. extents: list Extents (radius in mm) of each patch provided by the user. + subject: str, optional + Subject name. + subject_dir: str, optional + Path to the directory with FreeSurfer output. names: The names of sources provided by the user. group: @@ -280,4 +300,15 @@ def create( else: check_names(names, n_sources, existing) - return cls(n_sources, location, waveform, snr, snr_params, std, extents, names) + return cls( + n_sources, + location, + waveform, + snr, + snr_params, + std, + extents, + subject, + subjects_dir, + names, + ) diff --git a/src/meegsim/sources.py b/src/meegsim/sources.py index 2ae104d..3534684 100644 --- a/src/meegsim/sources.py +++ b/src/meegsim/sources.py @@ -299,6 +299,8 @@ def _create( stds, names, extents, + subject, + subjects_dir, random_state=None, ): """ @@ -323,8 +325,11 @@ def _create( if data.shape[1] != len(times): raise ValueError("The number of samples in waveform does not match") - # find patch vertices - subject = src[0].get("subject_his_id", None) + # Pick subject name from src if not provided explicitly + if subject is None: + subject = src[0].get("subject_his_id", None) + + # Find patch vertices patch_vertices = [] patch_stds = [] if isinstance(stds, mne.SourceEstimate) else stds for isource, extent in enumerate(extents): @@ -347,7 +352,11 @@ def _create( # Grow the patch from center otherwise patch = mne.grow_labels( - subject, vertno, extent, src_idx, subjects_dir=None + subject=subject, + seeds=vertno, + extents=extent, + hemis=src_idx, + subjects_dir=subjects_dir, )[0] # Prune vertices diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 6044a8c..c2e506b 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -383,6 +383,8 @@ def test_simulate_std_adjustment(): snr_params=dict(), std=[3], extents=[None], + subject=None, + subjects_dir=None, names=["patch"], ), ] diff --git a/tests/test_snr.py b/tests/test_snr.py index e65549f..5b144a1 100644 --- a/tests/test_snr.py +++ b/tests/test_snr.py @@ -243,6 +243,8 @@ def test_adjust_snr_local_patch(adjust_snr_mock): snr_params=dict(fmin=8, fmax=12), std=1, extents=None, + subject=None, + subjects_dir=None, names=["s1"], ), PatchSourceGroup( @@ -253,6 +255,8 @@ def test_adjust_snr_local_patch(adjust_snr_mock): snr_params=dict(), std=1, extents=None, + subject=None, + subjects_dir=None, names=["s2"], ), ] diff --git a/tests/test_sources.py b/tests/test_sources.py index f0fc3a3..425ec81 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -347,6 +347,8 @@ def test_patchsource_create_with_extent(): stds=stds, names=names, extents=extents, + subject=None, + subjects_dir=None, random_state=None, ) @@ -369,7 +371,7 @@ def test_patchsource_create_with_extent(): # Verify that grow_labels was called once for the source with extent mock_grow_labels.assert_called_once_with( - "meegsim", [2], 3, 0, subjects_dir=None + subject="meegsim", seeds=[2], extents=3, hemis=0, subjects_dir=None ) @@ -391,13 +393,13 @@ def test_patchsource_create_std_sourceestimate(get_param_mock): # Values are passed directly - the mock should not be used sources = PatchSource._create( - src, times, n_sources, location, waveform, stds, names, extents + src, times, n_sources, location, waveform, stds, names, extents, None, None ) get_param_mock.assert_not_called() # Values are passed in stc - the mock should be called once per patch sources = PatchSource._create( - src, times, n_sources, location, waveform, std_stc, names, extents + src, times, n_sources, location, waveform, std_stc, names, extents, None, None ) assert get_param_mock.call_count == n_sources diff --git a/tests/test_subject.py b/tests/test_subject.py new file mode 100644 index 0000000..c372460 --- /dev/null +++ b/tests/test_subject.py @@ -0,0 +1,77 @@ +""" +Tests for ensuring that subject-specific info is processed and set correctly. +""" + +import mne +import pytest + +from mne.datasets import sample + +from meegsim.location import select_random +from meegsim.simulate import SourceSimulator +from meegsim.waveform import narrowband_oscillation + +from utils.misc import running_on_ci + + +def prepare_real_data(): + data_path = sample.data_path() / "MEG" / "sample" + fwd_path = data_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" + raw_path = data_path / "sample_audvis_raw.fif" + subjects_dir = sample.data_path() / "subjects" + + # Load the prerequisites: fwd, src, and info + fwd = mne.read_forward_solution(fwd_path) + fwd = mne.convert_forward_solution(fwd, force_fixed=True) + raw = mne.io.read_raw(raw_path) + info = raw.info + + # Pick EEG channels only + eeg_idx = mne.pick_types(info, eeg=True) + info_eeg = mne.pick_info(info, eeg_idx) + fwd_eeg = fwd.pick_channels(info_eeg.ch_names) + + return fwd_eeg, info_eeg, subjects_dir + + +@pytest.mark.skipif(running_on_ci(), reason="Skip tests with real data on CI") +def test_grow_patch_source(): + fwd, info, subjects_dir = prepare_real_data() + src = fwd["src"] + + sfreq = 250 + duration = 60 + seed = 123 + + sim = SourceSimulator(src, snr_mode="local") + + # Select some vertices randomly + sim.add_patch_sources( + location=select_random, + waveform=narrowband_oscillation, + location_params=dict(n=3), + waveform_params=dict(fmin=8, fmax=12), + extents=5, + subjects_dir=subjects_dir, + names=["s1", "s2", "s3"], + ) + + sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) + sc.to_raw(fwd, info) + + +@pytest.mark.skipif(running_on_ci(), reason="Skip tests with real data on CI") +def test_sourceconfiguration_plot(): + fwd, _, subjects_dir = prepare_real_data() + src = fwd["src"] + + sfreq = 250 + duration = 60 + seed = 123 + + sim = SourceSimulator(src, snr_mode="local") + sim.add_noise_sources(location=select_random, location_params=dict(n=10)) + + sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed) + brain = sc.plot(subject="sample", subjects_dir=subjects_dir) + brain.close() diff --git a/tests/utils/misc.py b/tests/utils/misc.py new file mode 100644 index 0000000..350235c --- /dev/null +++ b/tests/utils/misc.py @@ -0,0 +1,5 @@ +import os + + +def running_on_ci(): + return os.environ.get("BUILD_ENV", "local") == "ci"