From 6e592a0c7b6939eb6e8068f84b816e885012b96f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Oct 2021 09:57:22 +0200 Subject: [PATCH 1/2] Update MS4 and HS to new API --- .../sorters/herdingspikes/herdingspikes.py | 19 +++++++++--- .../sorters/mountainsort4/mountainsort4.py | 29 +++++++++++++------ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/spikeinterface/sorters/herdingspikes/herdingspikes.py b/spikeinterface/sorters/herdingspikes/herdingspikes.py index df897ec24a..76986b2e9c 100644 --- a/spikeinterface/sorters/herdingspikes/herdingspikes.py +++ b/spikeinterface/sorters/herdingspikes/herdingspikes.py @@ -1,5 +1,5 @@ from pathlib import Path -import copy +from packaging import version from ..basesorter import BaseSorter from ..utils import RecordingExtractorOldAPI @@ -147,6 +147,13 @@ def _run_from_folder(cls, output_folder, params, verbose): import herdingspikes as hs import spikeinterface.toolkit as st + hs_version = version.parse(hs.__version) + + if hs_version >= version.parse("0.4.0"): + new_api = True + else: + new_api = False + recording = load_extractor(output_folder / 'spikeinterface_recording.json') p = params @@ -162,12 +169,16 @@ def _run_from_folder(cls, output_folder, params, verbose): median=0.0, q1=0.05, q2=0.95 ) - print('Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI') - old_api_recording = RecordingExtractorOldAPI(recording) + if new_api: + recording_to_hs = recording + else: + print('herdingspikes version < 0.4 uses the OLD spikeextractors with RecordingExtractorOldAPI.\n' + 'Consider updating herdingspikes (pip install herdingspikes>=0.4') + recording_to_hs = RecordingExtractorOldAPI(recording) # this should have its name changed Probe = hs.probe.RecordingExtractor( - old_api_recording, + recording_to_hs, masked_channels=p['probe_masked_channels'], inner_radius=p['probe_inner_radius'], neighbor_radius=p['probe_neighbor_radius'], diff --git a/spikeinterface/sorters/mountainsort4/mountainsort4.py b/spikeinterface/sorters/mountainsort4/mountainsort4.py index 0745c3f45e..a5cd96bf12 100644 --- a/spikeinterface/sorters/mountainsort4/mountainsort4.py +++ b/spikeinterface/sorters/mountainsort4/mountainsort4.py @@ -1,4 +1,4 @@ -import copy +from packaging import version from pathlib import Path from spikeinterface.toolkit import bandpass_filter, whiten @@ -92,6 +92,13 @@ def _setup_recording(cls, recording, output_folder, params, verbose): def _run_from_folder(cls, output_folder, params, verbose): import mountainsort4 + ms4_version = version.parse(mountainsort4.__version) + + if ms4_version >= version.parse("1.1.0"): + new_api = True + else: + new_api = False + recording = load_extractor(output_folder / 'spikeinterface_recording.json') # alias to params @@ -111,12 +118,16 @@ def _run_from_folder(cls, output_folder, params, verbose): print('whitenning') recording = whiten(recording=recording) - print('Mountainsort4 use the OLD spikeextractors mapped with RecordingExtractorOldAPI') - old_api_recording = RecordingExtractorOldAPI(recording) + if new_api: + recording_to_ms4 = recording + else: + print('mountainsort4 version < 1.1 uses the OLD spikeextractors with RecordingExtractorOldAPI.\n' + 'Consider updating mountainsort4 (pip install mountainsort4>=1.1') + recording_to_ms4 = RecordingExtractorOldAPI(recording) # Check location no more needed done in basesorter - old_api_sorting = mountainsort4.mountainsort4( - recording=old_api_recording, + sorting = mountainsort4.mountainsort4( + recording=recording_to_ms4, detect_sign=p['detect_sign'], adjacency_radius=p['adjacency_radius'], clip_size=p['clip_size'], @@ -137,10 +148,10 @@ def _run_from_folder(cls, output_folder, params, verbose): # ) # convert sorting to new API and save it - unit_ids = old_api_sorting.get_unit_ids() - units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}] - new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate) - NpzSortingExtractor.write_sorting(new_api_sorting, str(output_folder / 'firings.npz')) + # unit_ids = old_api_sorting.get_unit_ids() + # units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}] + # new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate) + NpzSortingExtractor.write_sorting(sorting, str(output_folder / 'firings.npz')) @classmethod def _get_result_from_folder(cls, output_folder): From 760e4ba9dad493c0b226e8878ff92b4b7a581678 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Oct 2021 10:20:26 +0200 Subject: [PATCH 2/2] oups --- spikeinterface/sorters/herdingspikes/herdingspikes.py | 2 +- spikeinterface/sorters/mountainsort4/mountainsort4.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/spikeinterface/sorters/herdingspikes/herdingspikes.py b/spikeinterface/sorters/herdingspikes/herdingspikes.py index 76986b2e9c..a92403d1b5 100644 --- a/spikeinterface/sorters/herdingspikes/herdingspikes.py +++ b/spikeinterface/sorters/herdingspikes/herdingspikes.py @@ -147,7 +147,7 @@ def _run_from_folder(cls, output_folder, params, verbose): import herdingspikes as hs import spikeinterface.toolkit as st - hs_version = version.parse(hs.__version) + hs_version = version.parse(hs.__version__) if hs_version >= version.parse("0.4.0"): new_api = True diff --git a/spikeinterface/sorters/mountainsort4/mountainsort4.py b/spikeinterface/sorters/mountainsort4/mountainsort4.py index a5cd96bf12..dc1a03e292 100644 --- a/spikeinterface/sorters/mountainsort4/mountainsort4.py +++ b/spikeinterface/sorters/mountainsort4/mountainsort4.py @@ -92,7 +92,7 @@ def _setup_recording(cls, recording, output_folder, params, verbose): def _run_from_folder(cls, output_folder, params, verbose): import mountainsort4 - ms4_version = version.parse(mountainsort4.__version) + ms4_version = version.parse(mountainsort4.__version__) if ms4_version >= version.parse("1.1.0"): new_api = True