diff --git a/db/tracker/json_param.py b/db/tracker/json_param.py index a048d442..b8befe36 100644 --- a/db/tracker/json_param.py +++ b/db/tracker/json_param.py @@ -83,6 +83,8 @@ def norm_trait(trait, value): record = Model.objects.get(pk=value) value = record.get() # Otherwise, let's hope it's already an instance + elif ttype == 'Array': + value = np.array(value) elif ttype == 'DataFile': # Similar to Instance traits, except we always know to use models.DataFile as the database table to look up the primary key from . import models diff --git a/features/__init__.py b/features/__init__.py index 0e8f7280..973007c4 100644 --- a/features/__init__.py +++ b/features/__init__.py @@ -15,7 +15,7 @@ from .plexon_features import PlexonBMI, RelayPlexon, RelayPlexByte from .hdf_features import SaveHDF from .video_recording_features import SingleChannelVideo, E3Video -from .bmi_task_features import NormFiringRates +from .bmi_task_features import NormFiringRates, RandomUnitDropout from .arduino_features import PlexonSerialDIORowByte from .blackrock_features import BlackrockBMI from .blackrock_features import RelayBlackrockByte @@ -79,6 +79,7 @@ eye_calibration=EyeCalibration, force_sensor=ForceControl, show_fixation_progress=Progressbar_fixation, + random_unit_dropout=RandomUnitDropout, clda_kfrml=CLDA_KFRML_IntendedVelocity ) diff --git a/features/bmi_task_features.py b/features/bmi_task_features.py index f3557c51..1c59b08d 100644 --- a/features/bmi_task_features.py +++ b/features/bmi_task_features.py @@ -1,9 +1,11 @@ ''' BMI task features ''' +import ast import time import numpy as np from riglib.experiment import traits, experiment +import copy ###### CONSTANTS sec_per_min = 60 @@ -11,6 +13,98 @@ ######################################################################################################## # Decoder/BMISystem add-ons ######################################################################################################## +class RandomUnitDropout(traits.HasTraits): + ''' + Randomly removes units from the decoder. Units are removed at the end of the delay period on each + trial and replaced when the trial ends (either in reward or penalty).The same units will be dropped + on repeated trials. The units to drop are specified in the `unit_drop_groups` attribute by a list of + lists of unit indices. The `unit_drop_targets` attribute specifies the target indices on which to + drop each group of units. Does not work with CLDA turned on. + ''' + + unit_drop_prob = traits.Float(0, desc="Probability of dropping a group of units from the decoder") + unit_drop_group1_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time") + unit_drop_group1_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder") + unit_drop_group2_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time") + unit_drop_group2_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder") + unit_drop_group3_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time") + unit_drop_group3_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder") + unit_drop_group4_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time") + unit_drop_group4_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder") + + def init(self): + self.decoder_units_dropped = np.ones((len(self.decoder.units),), dtype='bool') + self.add_dtype('decoder_units_dropped', '?', self.decoder_units_dropped.shape) + self.unit_drop_groups = [self.unit_drop_group1_channels, self.unit_drop_group2_channels, + self.unit_drop_group3_channels, self.unit_drop_group4_channels] + self.unit_drop_groups = np.array([g for g in self.unit_drop_groups if not g == ['']]) + self.unit_drop_targets = [self.unit_drop_group1_targets, self.unit_drop_group2_targets, + self.unit_drop_group3_targets, self.unit_drop_group4_targets] + self.unit_drop_targets = np.array([t for t in self.unit_drop_targets if not t == ['']]) + self.unit_drop_group_idx = 0 + super().init() + + # Save a copy of the decoder + self.decoder_orig = copy.deepcopy(self.decoder) + self.reportstats['Units dropped'] = '' # Runtime stat displayed on the UI + + print('--------------------\nUnit dropping settings:') + if len(self.unit_drop_groups) == len(self.unit_drop_targets): + for i in range(len(self.unit_drop_groups)): + print(f'group {i} channels: {self.unit_drop_groups[i]}, targets: {self.unit_drop_targets[i]}') + else: + print('Unit dropping settings invalid! Please reformat into groups [[ch1, ch2], ...] and targets [[t1, t2, ...], ...]') + print('Current groups:', self.unit_drop_groups) + print('Current targets:', self.unit_drop_targets) + print('--------------------') + + def _start_wait(self): + super()._start_wait() + + # Decide which units to drop in this trial but don't actually drop them yet + if (self.gen_indices[self.target_index] in np.array(self.unit_drop_targets[self.unit_drop_group_idx]) and + np.random.rand() < self.unit_drop_prob): + self.decoder_units_dropped = np.isin(self.decoder.channels, self.unit_drop_groups[self.unit_drop_group_idx]) + + # Update the group for next trial + self.unit_drop_group_idx = (self.unit_drop_group_idx + 1) % len(self.unit_drop_groups) + else: + self.decoder_units_dropped = np.zeros((len(self.decoder.units),), dtype='bool') + + def _start_targ_transition(self): + ''' + Override the decoder to drop random units. Keep a record of what's going on in the `trial` data. + ''' + super()._start_targ_transition() + if self.target_index == -1: + + # Came from a penalty state + pass + elif self.target_index + 1 < self.chain_length and np.any(self.decoder_units_dropped): + if hasattr(self.decoder.filt, 'C'): + self.decoder.filt.C[self.decoder_units_dropped, :] = 0 + elif hasattr(self.decoder.filt, 'unit_to_state'): + self.decoder.filt.unit_to_state[:, self.decoder_units_dropped] = 0 + self.task_data['decoder_units_dropped'] = self.decoder_units_dropped + self.reportstats['Units dropped'] = str(self.decoder.channels[self.decoder_units_dropped]) + + def _reset_decoder_units(self): + if hasattr(self.decoder.filt, 'C'): + self.decoder.filt.C = self.decoder_orig.filt.C + elif hasattr(self.decoder.filt, 'unit_to_state'): + self.decoder.filt.unit_to_state = self.decoder_orig.filt.unit_to_state + self.task_data['decoder_units_dropped'] = np.zeros((len(self.decoder.units),), dtype='bool') + self.reportstats['Units dropped'] = '[]' + + def _increment_tries(self): + super()._increment_tries() + self._reset_decoder_units() + + def _start_reward(self): + super()._start_reward() + self._reset_decoder_units() + + class NormFiringRates(traits.HasTraits): ''' Docstring ''' diff --git a/features/clda_features.py b/features/clda_features.py index 64e85175..e006e842 100644 --- a/features/clda_features.py +++ b/features/clda_features.py @@ -41,4 +41,4 @@ def create_updater(self): alter the decoder parameters to better match the intention estimates. ''' self.updater = clda.KFRML(self.clda_batch_time, self.clda_update_half_life) - self.updater.init(self.decoder) \ No newline at end of file + self.updater.init(self.decoder) diff --git a/setup.py b/setup.py index 0a3272a3..09aa4f97 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="aolab-bmi3d", - version="1.0.5", + version="1.0.6", author="Lots of people", description="electrophysiology experimental rig library", packages=setuptools.find_packages(),