diff --git a/dn3/data/dataset.py b/dn3/data/dataset.py index 2a397b6..e5c7dd8 100644 --- a/dn3/data/dataset.py +++ b/dn3/data/dataset.py @@ -381,7 +381,7 @@ def __init__(self, epochs: mne.Epochs, session_id=0, person_id=0, force_label=No self.epoch_codes_to_class_labels = event_mapping else: reverse_mapping = {v: k for k, v in event_mapping.items()} - self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.keys()))} + self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.values()))} skip_epochs = list() if skip_epochs is None else skip_epochs self._skip_map = [i for i in range(len(self.epochs.events)) if i not in skip_epochs] self._skip_map = dict(zip(range(len(self._skip_map)), self._skip_map))