Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions spikeinterface/sorters/herdingspikes/herdingspikes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
import copy
from packaging import version

from ..basesorter import BaseSorter
from ..utils import RecordingExtractorOldAPI
Expand Down Expand Up @@ -147,6 +147,13 @@ def _run_from_folder(cls, output_folder, params, verbose):
import herdingspikes as hs
import spikeinterface.toolkit as st

hs_version = version.parse(hs.__version__)

if hs_version >= version.parse("0.4.0"):
new_api = True
else:
new_api = False

recording = load_extractor(output_folder / 'spikeinterface_recording.json')

p = params
Expand All @@ -162,12 +169,16 @@ def _run_from_folder(cls, output_folder, params, verbose):
median=0.0, q1=0.05, q2=0.95
)

print('Herdingspikes use the OLD spikeextractors with RecordingExtractorOldAPI')
old_api_recording = RecordingExtractorOldAPI(recording)
if new_api:
recording_to_hs = recording
else:
print('herdingspikes version < 0.4 uses the OLD spikeextractors with RecordingExtractorOldAPI.\n'
'Consider updating herdingspikes (pip install herdingspikes>=0.4')
recording_to_hs = RecordingExtractorOldAPI(recording)

# this should have its name changed
Probe = hs.probe.RecordingExtractor(
old_api_recording,
recording_to_hs,
masked_channels=p['probe_masked_channels'],
inner_radius=p['probe_inner_radius'],
neighbor_radius=p['probe_neighbor_radius'],
Expand Down
29 changes: 20 additions & 9 deletions spikeinterface/sorters/mountainsort4/mountainsort4.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import copy
from packaging import version
from pathlib import Path

from spikeinterface.toolkit import bandpass_filter, whiten
Expand Down Expand Up @@ -92,6 +92,13 @@ def _setup_recording(cls, recording, output_folder, params, verbose):
def _run_from_folder(cls, output_folder, params, verbose):
import mountainsort4

ms4_version = version.parse(mountainsort4.__version__)

if ms4_version >= version.parse("1.1.0"):
new_api = True
else:
new_api = False

recording = load_extractor(output_folder / 'spikeinterface_recording.json')

# alias to params
Expand All @@ -111,12 +118,16 @@ def _run_from_folder(cls, output_folder, params, verbose):
print('whitenning')
recording = whiten(recording=recording)

print('Mountainsort4 use the OLD spikeextractors mapped with RecordingExtractorOldAPI')
old_api_recording = RecordingExtractorOldAPI(recording)
if new_api:
recording_to_ms4 = recording
else:
print('mountainsort4 version < 1.1 uses the OLD spikeextractors with RecordingExtractorOldAPI.\n'
'Consider updating mountainsort4 (pip install mountainsort4>=1.1')
recording_to_ms4 = RecordingExtractorOldAPI(recording)

# Check location no more needed done in basesorter
old_api_sorting = mountainsort4.mountainsort4(
recording=old_api_recording,
sorting = mountainsort4.mountainsort4(
recording=recording_to_ms4,
detect_sign=p['detect_sign'],
adjacency_radius=p['adjacency_radius'],
clip_size=p['clip_size'],
Expand All @@ -137,10 +148,10 @@ def _run_from_folder(cls, output_folder, params, verbose):
# )

# convert sorting to new API and save it
unit_ids = old_api_sorting.get_unit_ids()
units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}]
new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate)
NpzSortingExtractor.write_sorting(new_api_sorting, str(output_folder / 'firings.npz'))
# unit_ids = old_api_sorting.get_unit_ids()
# units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}]
# new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate)
NpzSortingExtractor.write_sorting(sorting, str(output_folder / 'firings.npz'))

@classmethod
def _get_result_from_folder(cls, output_folder):
Expand Down