Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions db/tracker/json_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def norm_trait(trait, value):
record = Model.objects.get(pk=value)
value = record.get()
# Otherwise, let's hope it's already an instance
elif ttype == 'Array':
value = np.array(value)
elif ttype == 'DataFile':
# Similar to Instance traits, except we always know to use models.DataFile as the database table to look up the primary key
from . import models
Expand Down
3 changes: 2 additions & 1 deletion features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .plexon_features import PlexonBMI, RelayPlexon, RelayPlexByte
from .hdf_features import SaveHDF
from .video_recording_features import SingleChannelVideo, E3Video
from .bmi_task_features import NormFiringRates
from .bmi_task_features import NormFiringRates, RandomUnitDropout
from .arduino_features import PlexonSerialDIORowByte
from .blackrock_features import BlackrockBMI
from .blackrock_features import RelayBlackrockByte
Expand Down Expand Up @@ -79,6 +79,7 @@
eye_calibration=EyeCalibration,
force_sensor=ForceControl,
show_fixation_progress=Progressbar_fixation,
random_unit_dropout=RandomUnitDropout,
clda_kfrml=CLDA_KFRML_IntendedVelocity
)

Expand Down
94 changes: 94 additions & 0 deletions features/bmi_task_features.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,110 @@
'''
BMI task features
'''
import ast
import time
import numpy as np
from riglib.experiment import traits, experiment
import copy

###### CONSTANTS
sec_per_min = 60

########################################################################################################
# Decoder/BMISystem add-ons
########################################################################################################
class RandomUnitDropout(traits.HasTraits):
'''
Randomly removes units from the decoder. Units are removed at the end of the delay period on each
trial and replaced when the trial ends (either in reward or penalty).The same units will be dropped
on repeated trials. The units to drop are specified in the `unit_drop_groups` attribute by a list of
lists of unit indices. The `unit_drop_targets` attribute specifies the target indices on which to
drop each group of units. Does not work with CLDA turned on.
'''

unit_drop_prob = traits.Float(0, desc="Probability of dropping a group of units from the decoder")
unit_drop_group1_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time")
unit_drop_group1_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder")
unit_drop_group2_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time")
unit_drop_group2_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder")
unit_drop_group3_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time")
unit_drop_group3_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder")
unit_drop_group4_channels = traits.List(value=[], desc="Channels to drop from the decoder one at a time")
unit_drop_group4_targets = traits.List(value=[], desc="Target indices on which to drop groups of units from the decoder")

def init(self):
self.decoder_units_dropped = np.ones((len(self.decoder.units),), dtype='bool')
self.add_dtype('decoder_units_dropped', '?', self.decoder_units_dropped.shape)
self.unit_drop_groups = [self.unit_drop_group1_channels, self.unit_drop_group2_channels,
self.unit_drop_group3_channels, self.unit_drop_group4_channels]
self.unit_drop_groups = np.array([g for g in self.unit_drop_groups if not g == ['']])
self.unit_drop_targets = [self.unit_drop_group1_targets, self.unit_drop_group2_targets,
self.unit_drop_group3_targets, self.unit_drop_group4_targets]
self.unit_drop_targets = np.array([t for t in self.unit_drop_targets if not t == ['']])
self.unit_drop_group_idx = 0
super().init()

# Save a copy of the decoder
self.decoder_orig = copy.deepcopy(self.decoder)
self.reportstats['Units dropped'] = '' # Runtime stat displayed on the UI

print('--------------------\nUnit dropping settings:')
if len(self.unit_drop_groups) == len(self.unit_drop_targets):
for i in range(len(self.unit_drop_groups)):
print(f'group {i} channels: {self.unit_drop_groups[i]}, targets: {self.unit_drop_targets[i]}')
else:
print('Unit dropping settings invalid! Please reformat into groups [[ch1, ch2], ...] and targets [[t1, t2, ...], ...]')
print('Current groups:', self.unit_drop_groups)
print('Current targets:', self.unit_drop_targets)
print('--------------------')

def _start_wait(self):
super()._start_wait()

# Decide which units to drop in this trial but don't actually drop them yet
if (self.gen_indices[self.target_index] in np.array(self.unit_drop_targets[self.unit_drop_group_idx]) and
np.random.rand() < self.unit_drop_prob):
self.decoder_units_dropped = np.isin(self.decoder.channels, self.unit_drop_groups[self.unit_drop_group_idx])

# Update the group for next trial
self.unit_drop_group_idx = (self.unit_drop_group_idx + 1) % len(self.unit_drop_groups)
else:
self.decoder_units_dropped = np.zeros((len(self.decoder.units),), dtype='bool')

def _start_targ_transition(self):
'''
Override the decoder to drop random units. Keep a record of what's going on in the `trial` data.
'''
super()._start_targ_transition()
if self.target_index == -1:

# Came from a penalty state
pass
elif self.target_index + 1 < self.chain_length and np.any(self.decoder_units_dropped):
if hasattr(self.decoder.filt, 'C'):
self.decoder.filt.C[self.decoder_units_dropped, :] = 0
elif hasattr(self.decoder.filt, 'unit_to_state'):
self.decoder.filt.unit_to_state[:, self.decoder_units_dropped] = 0
self.task_data['decoder_units_dropped'] = self.decoder_units_dropped
self.reportstats['Units dropped'] = str(self.decoder.channels[self.decoder_units_dropped])

def _reset_decoder_units(self):
if hasattr(self.decoder.filt, 'C'):
self.decoder.filt.C = self.decoder_orig.filt.C
elif hasattr(self.decoder.filt, 'unit_to_state'):
self.decoder.filt.unit_to_state = self.decoder_orig.filt.unit_to_state
self.task_data['decoder_units_dropped'] = np.zeros((len(self.decoder.units),), dtype='bool')
self.reportstats['Units dropped'] = '[]'

def _increment_tries(self):
super()._increment_tries()
self._reset_decoder_units()

def _start_reward(self):
super()._start_reward()
self._reset_decoder_units()


class NormFiringRates(traits.HasTraits):
''' Docstring '''

Expand Down
2 changes: 1 addition & 1 deletion features/clda_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def create_updater(self):
alter the decoder parameters to better match the intention estimates.
'''
self.updater = clda.KFRML(self.clda_batch_time, self.clda_update_half_life)
self.updater.init(self.decoder)
self.updater.init(self.decoder)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name="aolab-bmi3d",
version="1.0.5",
version="1.0.6",
author="Lots of people",
description="electrophysiology experimental rig library",
packages=setuptools.find_packages(),
Expand Down