diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index 75bd00836..a0d08cdb9 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -4,8 +4,10 @@ import numpy as np import time, random -from riglib.experiment import traits, experiment +from riglib.experiment import traits, experiment, Sequence from features.bmi_task_features import LinearlyDecreasingAssist, LinearlyDecreasingHalfLife +from target_capture_task import plantlist + import os from riglib.bmi import clda, assist, extractor, train, goal_calculators, ppfdecoder @@ -25,9 +27,29 @@ from riglib.stereo_opengl.window import WindowDispl2D from riglib.stereo_opengl.primitives import Line + from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain -from .target_capture_task import ScreenTargetCapture + +from built_in_tasks.manualcontrolmultitasks import ManualControlMulti + + +target_colors = {"blue":(0,0,1,0.5), +"yellow": (1,1,0,0.5), +"hibiscus":(0.859,0.439,0.576,0.5), +"magenta": (1,0,1,0.5), +"purple":(0.608,0.188,1,0.5), +"lightsteelblue":(0.690,0.769,0.901,0.5), +"dodgerblue": (0.118,0.565,1,0.5), +"teal":(0,0.502,0.502,0.5), +"aquamarine":(0.498,1,0.831,0.5), +"olive":(0.420,0.557,0.137,0.5), +"chiffonlemon": (0.933,0.914,0.749,0.5), +"juicyorange": (1,0.502,0,0.5), +"salmon":(1,0.549,0.384,0.5), +"wood": (0.259,0.149,0.071,0.5), +"elephant":(0.409,0.409,0.409,0.5)} + np.set_printoptions(suppress=False) @@ -71,7 +93,7 @@ def get_F(self, assist_level): np.mat ''' assist_level_idx = min(int(assist_level * self.n_assist_levels), self.n_assist_levels-1) - F = np.mat(self.fb_ctrl.F_dict[assist_level_idx]) + F = np.mat(self.fb_ctrl.F_dict[assist_level_idx]) return F class SimpleEndpointAssister(Assister): @@ -98,12 +120,12 @@ def calc_assisted_BMI_state(self, current_state, target_state, assist_level, mod speed = self.assist_speed * decoder_binlen target_radius = self.target_radius Bu = self.endpoint_assist_simple(cursor_pos, target_pos, decoder_binlen, speed, target_radius, assist_level) - assist_weight = assist_level + assist_weight = assist_level # return Bu, assist_weight return dict(x_assist=Bu, assist_level=assist_weight) - @staticmethod + @staticmethod def endpoint_assist_simple(cursor_pos, target_pos, decoder_binlen=0.1, speed=0.5, target_radius=2., assist_level=0.): ''' Estimate the next state using a constant velocity estimate moving toward the specified target @@ -128,10 +150,10 @@ def endpoint_assist_simple(cursor_pos, target_pos, decoder_binlen=0.1, speed=0.5 x_assist : np.ndarray of shape (7, 1) Control vector to add onto the state vector to assist control. ''' - diff_vec = target_pos - cursor_pos + diff_vec = target_pos - cursor_pos dist_to_target = np.linalg.norm(diff_vec) dir_to_target = diff_vec / (np.spacing(1) + dist_to_target) - + if dist_to_target > target_radius: assist_cursor_pos = cursor_pos + speed*dir_to_target else: @@ -184,7 +206,7 @@ def __init__(self, *args, **kwargs): ------- ''' dt = 0.1 - A = np.mat([[1., 0, 0, dt, 0, 0, 0], + A = np.mat([[1., 0, 0, dt, 0, 0, 0], [0., 1, 0, 0, dt, 0, 0], [0., 0, 1, 0, 0, dt, 0], [0., 0, 0, 0, 0, 0, 0], @@ -203,7 +225,7 @@ def __init__(self, *args, **kwargs): ################# ##### Tasks ##### ################# -class BMIControlMulti(BMILoop, LinearlyDecreasingAssist, ScreenTargetCapture): +class BMIControlMulti(BMILoop, LinearlyDecreasingAssist, ManualControlMulti): ''' Target capture task with cursor position controlled by BMI output. Cursor movement can be assisted toward target by setting assist_level > 0. @@ -220,9 +242,23 @@ class BMIControlMulti(BMILoop, LinearlyDecreasingAssist, ScreenTargetCapture): is_bmi_seed = False - def __init__(self, *args, **kwargs): + cursor_color_adjust = traits.OptionsList(*list(target_colors.keys()), bmi3d_input_options=list(target_colors.keys())) + + def __init__(self, *args, **kwargs): super(BMIControlMulti, self).__init__(*args, **kwargs) - + + def init(self, *args, **kwargs): + sph = self.plant.graphics_models[0] + sph.color = target_colors[self.cursor_color_adjust] + sph.radius = self.cursor_radius + self.plant.cursor_radius = self.cursor_radius + self.plant.cursor.radius = self.cursor_radius + super(BMIControlMulti, self).init(*args, **kwargs) + + + def move_effector(self, *args, **kwargs): + pass + def create_assister(self): # Create the appropriate type of assister object start_level, end_level = self.assist_level @@ -244,8 +280,8 @@ def create_assister(self): ## # self.assister = FeedbackControllerAssist(fb_ctrl, style='additive') ## self.assister = TentacleAssist(ssm=self.decoder.ssm, kin_chain=self.plant.kin_chain, update_rate=self.decoder.binlen) else: - raise NotImplementedError("Cannot assist for this type of statespace: %r" % self.decoder.ssm) - + raise NotImplementedError("Cannot assist for this type of statespace: %r" % self.decoder.ssm) + print(self.assister) def create_goal_calculator(self): @@ -269,7 +305,7 @@ def create_goal_calculator(self): goal_calc_class = goal_calculators.PlanarMultiLinkJointGoal multiproc = True - self.goal_calculator = goal_calc_class(namelist.tentacle_2D_state_space, shoulder_anchor, + self.goal_calculator = goal_calc_class(namelist.tentacle_2D_state_space, shoulder_anchor, chain, multiproc=multiproc, init_resp=x_init) else: raise ValueError("Unrecognized decoder state space!") @@ -296,24 +332,837 @@ def _end_timeout_penalty(self): self.decoder.filt.state.mean = self.init_decoder_mean self.hdf.sendMsg("reset") + def move_effector(self): + pass + + # def _test_enter_target(self, ts): + # ''' + # return true if the distance between center of cursor and target is smaller than the cursor radius + # ''' + # cursor_pos = self.plant.get_endpoint_pos() + # d = np.linalg.norm(cursor_pos - self.target_location) + # return d <= self.target_radius class BMIControlMulti2DWindow(BMIControlMulti, WindowDispl2D): fps = 20. def __init__(self,*args, **kwargs): super(BMIControlMulti2DWindow, self).__init__(*args, **kwargs) + + def create_goal_calculator(self): + self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) + + def _start_wait(self): + self.wait_time = 0. + super(BMIControlMulti2DWindow, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + + +class BMIResetting(BMIControlMulti): + ''' + Task where the virtual plant starts in configuration sampled from a discrete set and resets every trial + ''' + status = dict( + wait = dict(start_trial="premove", stop=None), + premove=dict(premove_complete="target"), + target = dict(enter_target="hold", timeout="timeout_penalty", stop=None), + hold = dict(leave_early="hold_penalty", hold_complete="targ_transition"), + targ_transition = dict(trial_complete="reward",trial_abort="wait", trial_incomplete="target", trial_restart="premove"), + timeout_penalty = dict(timeout_penalty_end="targ_transition"), + hold_penalty = dict(hold_penalty_end="targ_transition"), + reward = dict(reward_end="wait") + ) + + plant_visible = 1 + plant_hide_rate = -1 + premove_time = traits.Float(.1, desc='Time before subject must start doing BMI control') + # static_states = ['premove'] # states in which the decoder is not run + add_noise = 0.35 + sequence_generators = BMIControlMulti.sequence_generators + ['outcenter_half_hidden', 'short_long_centerout'] + + # def __init__(self, *args, **kwargs): + # super(BMIResetting, self).__init__(*args, **kwargs) + + def init(self, *args, **kwargs): + #self.add_dtype('bmi_P', 'f8', (self.decoder.ssm.n_states, self.decoder.ssm.n_states)) + super(BMIResetting, self).init(*args, **kwargs) + + # def move_plant(self, *args, **kwargs): + # super(BMIResetting, self).move_plant(*args, **kwargs) + # c = self.plant.get_endpoint_pos() + # self.plant.set_endpoint_pos(c + self.add_noise*np.array([np.random.rand()-0.5, 0., np.random.rand()-0.5])) + + def _cycle(self, *args, **kwargs): + #self.task_data['bmi_P'] = self.decoder.filt.state.cov + super(BMIResetting, self)._cycle(*args, **kwargs) + + def _while_premove(self): + self.plant.set_endpoint_pos(self.targs[0]) + self.decoder['q'] = self.plant.get_intrinsic_coordinates() + # self.decoder.filt.state.mean = self.calc_perturbed_ik(self.targs[0]) + + def _start_premove(self): + + #move a target to current location (target1 and target2 alternate moving) and set location attribute + target = self.targets[(self.target_index+1) % 2] + target.move_to_position(self.targs[self.target_index+1]) + target.cue_trial_start() + + def _end_timeout_penalty(self): + pass + + def _test_premove_complete(self, ts): + return ts>=self.premove_time + + def _parse_next_trial(self): + try: + self.targs, self.plant_visible = self.next_trial + except: + self.targs = self.next_trial + + def _test_hold_complete(self,ts): + ## Disable origin holds for this task + if self.target_index == 0: + return True + else: + return ts>=self.hold_time + + def _test_trial_incomplete(self, ts): + return (self.target_index self.wait_time and not self.pause + + @staticmethod + def sim_target_seq_generator_multi(n_targs=8, n_trials=8): + ''' + Simulated generator for simulations of the BMIControlMulti and CLDAControlMulti tasks + ''' + center = np.zeros(2) + pi = np.pi + targets = 8*np.vstack([[np.cos(pi/4*k), np.sin(pi/4*k)] for k in range(8)]) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + for k in range(n_trials): + targ = targets[target_inds[k], :] + yield np.array([[center[0], 0, center[1]], + [targ[0], 0, targ[1]]]) + + @staticmethod + def sim_target_no_center(n_targs=8, n_trials=8): + ''' + Simulated generator for simulations of the BMIControlMulti and CLDAControlMulti tasks + ''' + pi = np.pi + targets = 8*np.vstack([[np.cos(pi/4*k), np.sin(pi/4*k)] for k in range(8)]) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + for k in range(n_trials): + targ = targets[target_inds[k], :] + yield np.array([[targ[0], 0, targ[1]]]) + +class SimBMICosEncKFDec(SimCosineTunedEnc, SimKFDecoderSup, SimBMIControlMulti): + def __init__(self, *args, **kwargs): + N_NEURONS = 4 + N_STATES = 7 # 3 positions and 3 velocities and an offset + + # build the observation matrix + sim_C = np.zeros((N_NEURONS, N_STATES)) + # control x positive directions + sim_C[0, :] = np.array([0, 0, 0, 1, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 0, -1, 0, 0, 0]) + # control z positive directions + sim_C[2, :] = np.array([0, 0, 0, 0, 0, 1, 0]) + sim_C[3, :] = np.array([0, 0, 0, 0, 0, -1, 0]) + + kwargs['sim_C'] = sim_C + + ssm = StateSpaceEndptVel2D() + A, B, W = ssm.get_ssm_matrices() + Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) + R = 10000*np.mat(np.diag([1., 1., 1.])) + self.fb_ctrl = LQRController(A, B, Q, R) + self.ssm = ssm + + super(SimBMICosEncKFDec, self).__init__(*args, **kwargs) + +from features.simulation_features import SimLFPCosineTunedEnc, SimNormCosineTunedEnc +from riglib.bmi.lindecoder import PosVelScaleFilter +from riglib.bmi.bmi import Decoder +class SimBMICosEncLinDec(SimLFPCosineTunedEnc, SimBMIControlMulti): + def __init__(self, *args, **kwargs): + + ssm = StateSpaceEndptVel2D() + + # build the observation matrix + sim_C = np.zeros((2, 7)) + + # control x and z position + sim_C[0, :] = np.array([1, 0, 0, 0, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 1, 0, 0, 0, 0]) + self.vel_control = False + self.fb_ctrl = PosFeedbackController() + + # map neurons (2) to states (7) using C + self.decoder_map = sim_C.T + self.ssm = ssm + kwargs['sim_C'] = sim_C + kwargs['assist_level'] = (0, 0) + + super(SimBMICosEncLinDec, self).__init__(*args, **kwargs) + + def init(self, *args, **kwargs): + self.max_attempts = 1 + self.timeout_time = 1 + super(SimBMICosEncLinDec, self).init(*args, **kwargs) + + def load_decoder(self): + units = self.encoder.get_units() + filt_counts = 10000 # number of observations to calculate range + filt_window = 1 # number of observations to average for each tick + filt_map = self.decoder_map # map from states to units + filt = PosVelScaleFilter(self.vel_control, filt_counts, self.ssm.n_states, \ + len(units), unit_to_state=filt_map, smoothing_window=filt_window, call_rate=self.fps, + decoder_to_plant=2*np.max(self.plant.endpt_bounds)) + + # supply some known good attributes + neural_gain = self.fov + scaling_gain = 1 + filt.update_norm_attr(neural_mean=[neural_gain/2, neural_gain/2], neural_std=[neural_gain,neural_gain], \ + offset=[0,0], scale=[scaling_gain,scaling_gain]) + #filt.fix_norm_attr() + + # or allow decoder to figure it out + # neural_gain = self.fov * 1.1 + # filt.update_norm_attr(neural_mean=[neural_gain/2, neural_gain/2], neural_std=[neural_gain,neural_gain]) + + self.decoder = Decoder(filt, units, self.ssm, binlen=0.1, subbins=1, call_rate=self.fps) + self.decoder.n_features = len(units) + +class SimBMIVelocityLinDec(SimBMICosEncLinDec): + def __init__(self, *args, **kwargs): + + ssm = StateSpaceEndptVel2D() + + # control x and z velocity + sim_C = np.zeros((2, 7)) + sim_C[0, :] = np.array([0, 0, 0, 1, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 0, 0, 0, 1, 0]) + self.vel_control = True + A, B, W = ssm.get_ssm_matrices() + Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) + R = 10000*np.mat(np.diag([1., 1., 1.])) + self.fb_ctrl = LQRController(A, B, Q, R) + + # map neurons (2) to states (7) using C + self.decoder_map = sim_C.T + self.ssm = ssm + kwargs['sim_C'] = sim_C + kwargs['assist_level'] = (0, 0) + + super(SimBMICosEncLinDec, self).__init__(*args, **kwargs) + + +from built_in_tasks.target_graphics import VirtualCircularTarget +class ControlMultiNoWindow(Sequence): + ''' + models after manualControlMultiTasks + perserves the trial structure + but gets rid of the display completely + + simulation happens in the mind, who needs a screen. + learning happens in our brain not on the screen + Author Si Jia, Jan 2021 + ''' + + ''' + for the class attributs + copied and pastes most of the stuff. + but the following is not included from manualCursorControl + + background = (0,0,0,1) + cursor_color = (.5,0,.5,1) + + target_color = (1,0,0,.5) + cursor_visible = False # Determines when to hide the cursor. + + ''' + + starting_pos = (5, 0, 5) + + status = dict( + wait = dict(start_trial="target"), + target = dict(enter_target="hold", timeout="timeout_penalty"), + hold = dict(leave_target="hold_penalty", hold_complete="delay"), + delay = dict(leave_target="delay_penalty", delay_complete="targ_transition"), + targ_transition = dict(trial_complete="reward", trial_abort="wait", trial_incomplete="target"), + timeout_penalty = dict(timeout_penalty_end="targ_transition", end_state=True), + hold_penalty = dict(hold_penalty_end="targ_transition", end_state=True), + delay_penalty = dict(delay_penalty_end="targ_transition", end_state=True), + reward = dict(reward_end="wait", stoppable=False, end_state=True) + ) + + trial_end_states = ['reward', 'timeout_penalty', 'hold_penalty'] + + RED = (1,0,0,.5) + _target_color = RED + + #initial state + state = "wait" + target_index = -1 # Helper variable to keep track of which target to display within a trial + tries = 0 # Helper variable to keep track of the number of failed attempts at a given trial. + + no_data_count = 0 # Counter for number of missing data frames in a row + scale_factor = 3.0 #scale factor for converting hand movement to screen movement (1cm hand movement = 3.5cm cursor movement) + + sequence_generators = ['centerout_2D_discrete', 'centerout_2D_discrete_offset', 'point_to_point_3D', 'centerout_3D', 'centerout_3D_cube', 'centerout_2D_discrete_upper','centerout_2D_discrete_rot', 'centerout_2D_discrete_multiring', + 'centerout_2D_discrete_randorder', 'centeroutback_2D', 'centeroutback_2D_farcatch', 'centeroutback_2D_farcatch_discrete', + 'outcenterout_2D_discrete', 'outcenter_2D_discrete', 'rand_target_sequence_3d', 'rand_target_sequence_2d', 'rand_target_sequence_2d_centerout', + 'rand_target_sequence_2d_partial_centerout', 'rand_multi_sequence_2d_centerout2step', 'rand_pt_to_pt', + 'centerout_2D_discrete_far', 'centeroutback_2D_v2','centerout_2D_discrete_eyetracker_calibration'] + is_bmi_seed = True + + + # Runtime settable traits + reward_time = traits.Float(.2, desc="Length of juice reward") + target_radius = traits.Float(2, desc="Radius of targets in cm") + + hold_time = traits.Float(.2, desc="Length of hold required at targets") + hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error") + timeout_time = traits.Float(10, desc="Time allowed to go between targets") + timeout_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error") + max_attempts = traits.Int(10, desc='The number of attempts at a target before\ + skipping to the next one') + # session_length = traits.Float(0, desc="Time until task automatically stops. Length of 0 means no auto stop.") + marker_num = traits.Int(14, desc="The index of the motiontracker marker to use for cursor position") + + + plant_hide_rate = traits.Float(0.0, desc='If the plant is visible, specifies a percentage of trials where it will be hidden') + plant_type_options = list(plantlist.keys()) + plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=list(plantlist.keys())) + plant_visible = traits.Bool(True, desc='Specifies whether entire plant is displayed or just endpoint') + cursor_radius = traits.Float(.5, desc="Radius of cursor") + + + + def __init__(self, *args, **kwargs): + super(ControlMultiNoWindow, self).__init__(*args, **kwargs) + #self.cursor_visible = True + + # Initialize the plant + if not hasattr(self, 'plant'): + self.plant = plantlist[self.plant_type] + #self.plant_vis_prev = True + + + # Instantiate the targets + instantiate_targets = kwargs.pop('instantiate_targets', True) + if instantiate_targets: + target1 = VirtualCircularTarget(target_radius=self.target_radius, target_color=self._target_color) + target2 = VirtualCircularTarget(target_radius=self.target_radius, target_color=self._target_color) + + self.targets = [target1, target2] + + # Initialize target location variable + self.target_location = np.array([0, 0, 0]) + + # Declare any plant attributes which must be saved to the HDF file at the _cycle rate + for attr in self.plant.hdf_attrs: + self.add_dtype(*attr) + + def init(self): + self.add_dtype('target', 'f8', (3,)) + self.add_dtype('target_index', 'i', (1,)) + super(ControlMultiNoWindow, self).init() + + + def _cycle(self): + ''' + Calls any update functions necessary + ''' + self.task_data['target'] = self.target_location.copy() + self.task_data['target_index'] = self.target_index + + + self.move_effector() + + ## Save plant status to HDF file + plant_data = self.plant.get_data_to_save() + for key in plant_data: + self.task_data[key] = plant_data[key] + + super(ControlMultiNoWindow, self)._cycle() + + + def move_effector(self): + ''' + mainly for sim, default do don't do anything + ''' + pass + + + def run(self): + ''' + See experiment.Experiment.run for documentation. + ''' + # Fire up the plant. For virtual/simulation plants, this does little/nothing. + self.plant.start() + try: + super(ControlMultiNoWindow, self).run() + finally: + self.plant.stop() + + def update_report_stats(self): + ''' + see experiment.Experiment.update_report_stats for docs + ''' + super(ControlMultiNoWindow, self).update_report_stats() + self.reportstats['Trial #'] = self.calc_trial_num() + self.reportstats['Reward/min'] = np.round(self.calc_events_per_min('reward', 120.), decimals=2) + + + + + #### TEST FUNCTIONS #### + def _test_enter_target(self, ts): + ''' + return true if the distance between center of cursor and target is smaller than the cursor radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.target_location) + return d <= (self.target_radius - self.cursor_radius) + + def _test_leave_early(self, ts): + ''' + return true if cursor moves outside the exit radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.target_location) + rad = self.target_radius - self.cursor_radius + return d > rad + + def _test_hold_complete(self, ts): + return ts>=self.hold_time + + def _test_timeout(self, ts): + return ts>self.timeout_time + + def _test_timeout_penalty_end(self, ts): + return ts>self.timeout_penalty_time + + def _test_hold_penalty_end(self, ts): + return ts>self.hold_penalty_time + + def _test_trial_complete(self, ts): + return self.target_index==self.chain_length-1 + + def _test_trial_incomplete(self, ts): + return (not self._test_trial_complete(ts)) and (self.triesself.reward_time + + + #### STATE FUNCTIONS #### + def _parse_next_trial(self): + self.targs = self.next_trial + + + def _start_wait(self): + super(ControlMultiNoWindow, self)._start_wait() + self.tries = 0 + self.target_index = -1 + + self.chain_length = self.targs.shape[0] #Number of sequential targets in a single trial + + + def _start_target(self): + self.target_index += 1 + + #move a target to current location (target1 and target2 alternate moving) and set location attribute + target = self.targets[self.target_index % 2] + self.target_location = self.targs[self.target_index] + target.move_to_position(self.target_location) + #target.cue_trial_start() + + if self.target_index == 0: + target.move_to_position(self.targs[self.target_index]) + target.show() + self.sync_event('TARGET_ON', self.gen_indices[self.target_index]) + + + def _start_hold(self): + #make next target visible unless this is the final target in the trial + idx = (self.target_index + 1) + if idx < self.chain_length: + target = self.targets[idx % 2] + target.move_to_position(self.targs[idx]) + + self.sync_event('CURSOR_ENTER_TARGET', self.gen_indices[self.target_index]) + + + def _end_hold(self): + # change current target color to green + #self.targets[self.target_index % 2].cue_trial_end_success() + pass + + + + def _start_hold_penalty(self): + + self.tries += 1 + self.target_index = -1 + + def _start_timeout_penalty(self): + + self.tries += 1 + self.target_index = -1 + + + def _start_targ_transition(self): + pass + + def _start_reward(self): + pass + + +from target_capture_task import ConcreteTargetCapture + +class BMIControlMultiNoWindow(BMILoop, LinearlyDecreasingAssist, ConcreteTargetCapture): + ''' + Target capture task with cursor position controlled by BMI output. + Cursor movement can be assisted toward target by setting assist_level > 0. + ''' + + #background = (.5,.5,.5,1) # Set the screen background color to grey + reset = traits.Int(0, desc='reset the decoder state to the starting configuration') + + ordered_traits = ['session_length', 'assist_level', 'assist_level_time', 'reward_time','timeout_time','timeout_penalty_time'] + exclude_parent_traits = ['marker_count', 'marker_num', 'goal_cache_block'] + + static_states = [] # states in which the decoder is not run + hidden_traits = ['arm_hide_rate', 'arm_visible', 'hold_penalty_time', 'rand_start', 'reset', 'target_radius', 'window_size'] + + is_bmi_seed = False + + #cursor_color_adjust = traits.OptionsList(*list(target_colors.keys()), bmi3d_input_options=list(target_colors.keys())) + + def __init__(self, *args, **kwargs): + super(BMIControlMultiNoWindow, self).__init__(*args, **kwargs) + + def init(self, *args, **kwargs): + sph = self.plant.graphics_models[0] + #sph.color = target_colors[self.cursor_color_adjust] no window, no color + sph.radius = self.cursor_radius + self.plant.cursor_radius = self.cursor_radius + self.plant.cursor.radius = self.cursor_radius + super(BMIControlMultiNoWindow, self).init(*args, **kwargs) + + + #the following two functions from BMIControlMulti + def _start_wait(self): + self.wait_time = 0. + super(BMIControlMultiNoWindow, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + + + def move_effector(self, *args, **kwargs): + pass def create_assister(self): + # Create the appropriate type of assister object + start_level, end_level = self.assist_level kwargs = dict(decoder_binlen=self.decoder.binlen, target_radius=self.target_radius) if hasattr(self, 'assist_speed'): kwargs['assist_speed'] = self.assist_speed - self.assister = SimpleEndpointAssister(**kwargs) + + if isinstance(self.decoder.ssm, StateSpaceEndptVel2D) and isinstance(self.decoder, ppfdecoder.PPFDecoder): + self.assister = OFCEndpointAssister() + elif isinstance(self.decoder.ssm, StateSpaceEndptVel2D): + self.assister = SimpleEndpointAssister(**kwargs) + else: + raise NotImplementedError("Cannot assist for this type of statespace: %r" % self.decoder.ssm) + + print(self.assister) def create_goal_calculator(self): - self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) + if isinstance(self.decoder.ssm, StateSpaceEndptVel2D): + self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) + elif isinstance(self.decoder.ssm, StateSpaceNLinkPlanarChain) and self.decoder.ssm.n_links == 2: + self.goal_calculator = goal_calculators.PlanarMultiLinkJointGoal(self.decoder.ssm, self.plant.base_loc, self.plant.kin_chain, multiproc=False, init_resp=None) + elif isinstance(self.decoder.ssm, StateSpaceNLinkPlanarChain) and self.decoder.ssm.n_links == 4: + shoulder_anchor = self.plant.base_loc + chain = self.plant.kin_chain + q_start = self.plant.get_intrinsic_coordinates() + x_init = np.hstack([q_start, np.zeros_like(q_start), 1]) + x_init = np.mat(x_init).reshape(-1, 1) + + cached = True + + if cached: + goal_calc_class = goal_calculators.PlanarMultiLinkJointGoalCached + multiproc = False + else: + goal_calc_class = goal_calculators.PlanarMultiLinkJointGoal + multiproc = True + + self.goal_calculator = goal_calc_class(namelist.tentacle_2D_state_space, shoulder_anchor, + chain, multiproc=multiproc, init_resp=x_init) + else: + raise ValueError("Unrecognized decoder state space!") + + def get_target_BMI_state(self, *args): + ''' + Run the goal calculator to determine the target state of the task + ''' + if isinstance(self.goal_calculator, goal_calculators.PlanarMultiLinkJointGoalCached): + task_eps = np.inf + else: + task_eps = 0.5 + ik_eps = task_eps/10 + data, solution_updated = self.goal_calculator(self.target_location, verbose=False, n_particles=500, eps=ik_eps, n_iter=10, q_start=self.plant.get_intrinsic_coordinates()) + target_state, error = data + if isinstance(self.goal_calculator, goal_calculators.PlanarMultiLinkJointGoal) and error > task_eps and solution_updated: + self.goal_calculator.reset() + + return np.array(target_state).reshape(-1,1) + + def _end_timeout_penalty(self): + if self.reset: + self.decoder.filt.state.mean = self.init_decoder_mean + self.hdf.sendMsg("reset") + + +class SimpleTargetCapture(BMIControlMultiNoWindow): + status = dict( + wait = dict(start_trial="target"), + target = dict(enter_target="wait", timeout="wait") + ) + + + #the following two functions from BMIControlMulti def _start_wait(self): self.wait_time = 0. - super(BMIControlMulti2DWindow, self)._start_wait() + super(SimpleTargetCapture, self)._start_wait() + + + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + + def _test_enter_target(self, ts): + ''' + return true if the distance between center of cursor and target is smaller than the cursor radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.targs[self.target_index]) + + entered_target = (d <= (self.target_radius - self.cursor_radius)) + + if entered_target: self.sync_event('REWARD') + return entered_target + + def _test_timeout(self, time_in_state): + + timed_out = time_in_state > self.timeout_time + if timed_out: self.sync_event('HOLD_PENALTY') + return timed_out + + + def _start_target(self): + super()._start_target() + + + # Show target if it is hidden (this is the first target, or previous state was a penalty) + target = self.targets[self.target_index % 2] + self.target_location = self.targs[self.target_index] + + target.move_to_position(self.targs[self.target_index]) + #target.show() + self.sync_event('TARGET_ON', self.gen_indices[self.target_index]) + + self.decoder.filt.state.mean = np.matrix((0,0,0,0,0,0,1)).T + +class SimpleTargetCaptureWithHold(BMIControlMultiNoWindow): + + """ + this is mainly a simulation task where the user is instructed to move and cursor into the target and hold there. + + """ + + status = dict( + wait = dict(start_trial="target"), + target = dict(enter_target="hold", timeout="wait"), + hold = dict(leave_early="wait", hold_complete="wait"), + ) + + #the following two functions from BMIControlMulti + def _start_wait(self): + self.wait_time = 0. + super(SimpleTargetCaptureWithHold, self)._start_wait() + def _test_start_trial(self, ts): return ts > self.wait_time and not self.pause + + def _test_enter_target(self, ts): + ''' + return true if the distance between center of cursor and target is smaller than the cursor radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.targs[self.target_index]) + + entered_target = (d <= (self.target_radius - self.cursor_radius)) + + #if entered_target: self.sync_event('CURSOR_ENTER_TARGET') + return entered_target + + + + def _test_timeout(self, time_in_state): + + timed_out = time_in_state > self.timeout_time + if timed_out: self.sync_event('TIMEOUT_PENALTY') + return timed_out + + + def _test_leave_early(self, ts): + ''' + return true if cursor moves outside the exit radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.target_location) + rad = self.target_radius - self.cursor_radius + + leave_early = d > rad + + if leave_early: self.sync_event('HOLD_PENALTY') + return leave_early + + + def _test_hold_complete(self, ts): + hold_complete = ts>=self.hold_time + + if hold_complete: self.sync_event('REWARD') + + return hold_complete + + + def _start_target(self): + super()._start_target() + + + # Show target if it is hidden (this is the first target, or previous state was a penalty) + target = self.targets[self.target_index % 2] + self.target_location = self.targs[self.target_index] + + target.move_to_position(self.targs[self.target_index]) + #target.show() + #self.sync_event('TARGET_ON', self.gen_indices[self.target_index]) + + self.decoder.filt.state.mean = np.matrix((0,0,0,0,0,0,1)).T + diff --git a/built_in_tasks/cursorControlTasks_saveHDF.py b/built_in_tasks/cursorControlTasks_saveHDF.py new file mode 100644 index 000000000..917d7053e --- /dev/null +++ b/built_in_tasks/cursorControlTasks_saveHDF.py @@ -0,0 +1,128 @@ + +from manualcontrolmultitasks import ManualControlMulti +from riglib.stereo_opengl.window import WindowDispl2D +#from bmimultitasks import BMIControlMulti +import pygame +import numpy as np +import copy + +#from riglib.bmi.extractor import DummyExtractor +#from riglib.bmi.state_space_models import StateSpaceEndptVel2D +#from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter +from riglib import experiment +from features.hdf_features import SaveHDF +from features.task_code_features import TaskCodeStreamer + +class CursorControl(ManualControlMulti, WindowDispl2D): + ''' + this class implements a python cursor control task for human + ''' + + def __init__(self, *args, **kwargs): + # just run the parent ManualControlMulti's initialization + self.move_step = 1 + + # Initialize target location variable + #target location and index have been initializd + + super(CursorControl, self).__init__(*args, **kwargs) + + def init(self): + pygame.init() + + + + self.assist_level = (0, 0) + super(CursorControl, self).init() + + # override the _cycle function + def _cycle(self): + #print(self.state) + + #target and plant data have been saved in + #the parent manualcontrolmultitasks + + self.move_effector_cursor() + super(CursorControl, self)._cycle() + + # do nothing + def move_effector(self): + pass + + def move_plant(self, **kwargs): + pass + + # use keyboard to control the task + def move_effector_cursor(self): + np.array([0., 0., 0.]) + curr_pos = copy.deepcopy(self.plant.get_endpoint_pos()) + + for event in pygame.event.get(): + if event.type == pygame.KEYUP: + if event.type == pygame.K_q: + pygame.quit() + quit() + if event.key == pygame.K_LEFT: + curr_pos[0] -= self.move_step + if event.key == pygame.K_RIGHT: + curr_pos[0] += self.move_step + if event.key == pygame.K_UP: + curr_pos[2] += self.move_step + if event.key == pygame.K_DOWN: + curr_pos[2] -= self.move_step + #print('Current position: ') + #print(curr_pos) + + # set the current position + self.plant.set_endpoint_pos(curr_pos) + + def _start_wait(self): + self.wait_time = 0. + super(CursorControl, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + +#this task can be run on its +#we will not involve database at this time +target_pos_radius = 10 + +def target_seq_generator(n_targs, n_trials): + #generate targets + angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs)) + unit_targets = targets = np.stack((np.cos(angles), np.sin(angles)),1) + targets = unit_targets * target_pos_radius + + center = np.array((0,0)) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + + k = 0 + while k < n_trials: + targ = targets[target_inds[k], :] + yield np.array([[center[0], 0, center[1]], + [targ[0], 0, targ[1]]]) + k += 1 + + +if __name__ == "__main__": + print('Remember to set window size in stereoOpenGL class') + gen = target_seq_generator(8, 1000) + + #incorporate the saveHDF feature by blending code + #see tests\start_From_cmd_line_sim + + base_class = CursorControl + + feats = [SaveHDF, TaskCodeStreamer] + Exp = experiment.make(base_class, feats=feats) + print(Exp) + + exp = Exp(gen) + exp.init() + exp.run() #start the task + + + + \ No newline at end of file diff --git a/built_in_tasks/manualcontrolmultitasks.py b/built_in_tasks/manualcontrolmultitasks.py index c3147570b..2d8c695f5 100644 --- a/built_in_tasks/manualcontrolmultitasks.py +++ b/built_in_tasks/manualcontrolmultitasks.py @@ -1,187 +1,1331 @@ ''' -Virtual target capture tasks where cursors are controlled by physical -motion interfaces such as joysticks +Base tasks for generic point-to-point reaching ''' + import numpy as np from collections import OrderedDict import time + +from riglib import reward # This import file corresponds to the Orsborn lab reward system now +from riglib.experiment import traits, Sequence + +from riglib.stereo_opengl.window import Window, FPScontrol, WindowDispl2D +from riglib.stereo_opengl.primitives import Cylinder, Plane, Sphere, Cube +from riglib.stereo_opengl.models import FlatMesh, Group +from riglib.stereo_opengl.textures import Texture, TexModel +from riglib.stereo_opengl.render import stereo, Renderer +from riglib.stereo_opengl.utils import cloudy_tex + +from target_capture_task import plantlist + +from riglib.stereo_opengl import ik import os + import math import traceback -from riglib.experiment import traits - -from .target_graphics import * -from .target_capture_task import ScreenTargetCapture, ScreenReachAngle -from riglib.stereo_opengl.window import WindowDispl2D - - -rotations = dict( - yzx = np.array( - [[0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 0], - [0, 0, 0, 1]] - ), - zyx = np.array( - [[0, 0, 1, 0], - [0, 1, 0, 0], - [1, 0, 0, 0], - [0, 0, 0, 1]] - ), - xzy = np.array( - [[1, 0, 0, 0], - [0, 0, 1, 0], - [0, 1, 0, 0], - [0, 0, 0, 1]] - ), - xyz = np.identity(4), -) - -class ManualControlMixin(traits.HasTraits): - '''Target capture task where the subject operates a joystick - to control a cursor. Targets are captured by having the cursor - dwell in the screen target for the allotted time''' - - # Settable Traits - wait_time = traits.Float(2., desc="Time between successful trials") - velocity_control = traits.Bool(False, desc="Position or velocity control") - random_rewards = traits.Bool(False, desc="Add randomness to reward") - rotation = traits.OptionsList(*rotations, desc="Control rotation matrix", bmi3d_input_options=list(rotations.keys())) - scale = traits.Float(1.0, desc="Control scale factor") - offset = traits.Array(value=[0,0,0], desc="Control offset") +####### CONSTANTS +sec_per_min = 60.0 +RED = (1,0,0,.5) +GREEN = (0,1,0,0.5) +GOLD = (1., 0.843, 0., 0.5) +mm_per_cm = 1./10 + +from built_in_tasks.target_graphics import * + +target_colors = { +"yellow": (1,1,0,0.75), +"magenta": (1,0,1,0.75), +"purple":(0.608,0.188,1,0.75), +"dodgerblue": (0.118,0.565,1,0.75), +"teal":(0,0.502,0.502,0.75), +"olive":(0.420,0.557,0.137,.75), +"juicyorange": (1,0.502,0.,0.75), +"hotpink":(1,0.0,0.606,.75), +"lightwood": (0.627,0.322,0.176,0.75), +"elephant":(0.409,0.409,0.409,0.5), +"green":(0., 1., 0., 0.5)} + + +class ManualControlMulti(Sequence, Window): + ''' + This is an improved version of the original manual control tasks that includes the functionality + of ManualControl, ManualControl2, and TargetCapture all in a single task. This task doesn't + assume anything about the trial structure of the task and allows a trial to consist of a sequence + of any number of sequential targets that must be captured before the reward is triggered. The number + of targets per trial is determined by the structure of the target sequence used. + ''' + + background = (0,0,0,1) + cursor_color = (.5,0,.5,1) + + plant_type = traits.OptionsList(*plantlist, desc='', bmi3d_input_options=list(plantlist.keys())) + + starting_pos = (5, 0, 5) + + status = dict( + wait = dict(start_trial="target", stop=None), + target = dict(enter_target="hold", timeout="timeout_penalty", stop=None), + hold = dict(leave_early="hold_penalty", hold_complete="targ_transition", stop=None), + targ_transition = dict(trial_complete="reward",trial_abort="wait", trial_incomplete="target", stop=None), + timeout_penalty = dict(timeout_penalty_end="targ_transition", stop=None), + hold_penalty = dict(hold_penalty_end="targ_transition", stop=None), + reward = dict(reward_end="wait") + ) + trial_end_states = ['reward', 'timeout_penalty', 'hold_penalty'] + + #initial state + state = "wait" + + target_color = (1,0,0,.5) + target_index = -1 # Helper variable to keep track of which target to display within a trial + tries = 0 # Helper variable to keep track of the number of failed attempts at a given trial. + + cursor_visible = False # Determines when to hide the cursor. + no_data_count = 0 # Counter for number of missing data frames in a row + scale_factor = 3.0 #scale factor for converting hand movement to screen movement (1cm hand movement = 3.5cm cursor movement) + + limit2d = 1 + + sequence_generators = ['centerout_2D_discrete', 'centerout_2D_discrete_offset', 'point_to_point_3D', 'centerout_3D', 'centerout_3D_cube', 'centerout_2D_discrete_upper','centerout_2D_discrete_rot', 'centerout_2D_discrete_multiring', + 'centerout_2D_discrete_randorder', 'centeroutback_2D', 'centeroutback_2D_farcatch', 'centeroutback_2D_farcatch_discrete', + 'outcenterout_2D_discrete', 'outcenter_2D_discrete', 'rand_target_sequence_3d', 'rand_target_sequence_2d', 'rand_target_sequence_2d_centerout', + 'rand_target_sequence_2d_partial_centerout', 'rand_multi_sequence_2d_centerout2step', 'rand_pt_to_pt', + 'centerout_2D_discrete_far', 'centeroutback_2D_v2','centerout_2D_discrete_eyetracker_calibration'] is_bmi_seed = True + _target_color = RED + + + # Runtime settable traits + reward_time = traits.Float(.2, desc="Length of juice reward") + target_radius = traits.Float(2, desc="Radius of targets in cm") + + hold_time = traits.Float(.2, desc="Length of hold required at targets") + hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error") + timeout_time = traits.Float(10, desc="Time allowed to go between targets") + timeout_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error") + max_attempts = traits.Int(10, desc='The number of attempts at a target before\ + skipping to the next one') + # session_length = traits.Float(0, desc="Time until task automatically stops. Length of 0 means no auto stop.") + marker_num = traits.Int(14, desc="The index of the motiontracker marker to use for cursor position") + # NOTE!!! The marker on the hand was changed from #0 to #14 on + # 5/19/13 after LED #0 broke. All data files saved before this date + # have LED #0 controlling the cursor. + plant_hide_rate = traits.Float(0.0, desc='If the plant is visible, specifies a percentage of trials where it will be hidden') + plant_type_options = list(plantlist.keys()) + plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=list(plantlist.keys())) + plant_visible = traits.Bool(True, desc='Specifies whether entire plant is displayed or just endpoint') + cursor_radius = traits.Float(.5, desc="Radius of cursor") + def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.current_pt=np.zeros([3]) #keep track of current pt - self.last_pt=np.zeros([3]) #keep track of last pt to calc. velocity - self._quality_window_size = 500 # how many cycles to accumulate quality statistics - self.reportstats['Input quality'] = "100 %" - if self.random_rewards: - self.reward_time_base = self.reward_time + super(ManualControlMulti, self).__init__(*args, **kwargs) + self.cursor_visible = True + + # Initialize the plant + if not hasattr(self, 'plant'): + self.plant = plantlist[self.plant_type] + self.plant_vis_prev = True + + # Add graphics models for the plant and targets to the window + if hasattr(self.plant, 'graphics_models'): + for model in self.plant.graphics_models: + self.add_model(model) + + # Instantiate the targets + instantiate_targets = kwargs.pop('instantiate_targets', True) + if instantiate_targets: + target1 = VirtualCircularTarget(target_radius=self.target_radius, target_color=self._target_color) + target2 = VirtualCircularTarget(target_radius=self.target_radius, target_color=self._target_color) + + self.targets = [target1, target2] + for target in self.targets: + for model in target.graphics_models: + self.add_model(model) + + # Initialize target location variable + self.target_location = np.array([0, 0, 0]) + + # Declare any plant attributes which must be saved to the HDF file at the _cycle rate + for attr in self.plant.hdf_attrs: + self.add_dtype(*attr) def init(self): - self.add_dtype('manual_input', 'f8', (3,)) - super().init() - self.no_data_counter = np.zeros((self._quality_window_size,), dtype='?') + self.add_dtype('target', 'f8', (3,)) + self.add_dtype('target_index', 'i', (1,)) + super(ManualControlMulti, self).init() - def _test_start_trial(self, ts): - return ts > self.wait_time and not self.pause + def _cycle(self): + ''' + Calls any update functions necessary and redraws screen. Runs 60x per second. + ''' + self.task_data['target'] = self.target_location.copy() + self.task_data['target_index'] = self.target_index + + ## Run graphics commands to show/hide the plant if the visibility has changed + if self.plant_type != 'CursorPlant': + if self.plant_visible != self.plant_vis_prev: + self.plant_vis_prev = self.plant_visible + self.plant.set_visibility(self.plant_visible) + # self.show_object(self.plant, show=self.plant_visible) + + self.move_effector() + + ## Save plant status to HDF file + plant_data = self.plant.get_data_to_save() + for key in plant_data: + self.task_data[key] = plant_data[key] + + super(ManualControlMulti, self)._cycle() + + def move_effector(self): + ''' Sets the plant configuration based on motiontracker data. For manual control, uses + motiontracker data. If no motiontracker data available, returns None''' + + #get data from motion tracker- take average of all data points since last poll + pt = self.motiondata.get() + if len(pt) > 0: + pt = pt[:, self.marker_num, :] + conds = pt[:, 3] + inds = np.nonzero((conds>=0) & (conds!=4))[0] + if len(inds) > 0: + pt = pt[inds,:3] + #scale actual movement to desired amount of screen movement + pt = pt.mean(0) * self.scale_factor + #Set y coordinate to 0 for 2D tasks + if self.limit2d: pt[1] = 0 + pt[1] = pt[1]*2 + # Return cursor location + self.no_data_count = 0 + pt = pt * mm_per_cm #self.convert_to_cm(pt) + else: #if no usable data + self.no_data_count += 1 + pt = None + else: #if no new data + self.no_data_count +=1 + pt = None + + # Set the plant's endpoint to the position determined by the motiontracker, unless there is no data available + if pt is not None: + self.plant.set_endpoint_pos(pt) + + def run(self): + ''' + See experiment.Experiment.run for documentation. + ''' + # Fire up the plant. For virtual/simulation plants, this does little/nothing. + self.plant.start() + try: + super(ManualControlMulti, self).run() + finally: + self.plant.stop() + + ##### HELPER AND UPDATE FUNCTIONS #### + def update_cursor_visibility(self): + ''' Update cursor visible flag to hide cursor if there has been no good data for more than 3 frames in a row''' + prev = self.cursor_visible + if self.no_data_count < 3: + self.cursor_visible = True + if prev != self.cursor_visible: + self.show_object(self.cursor, show=True) + else: + self.cursor_visible = False + if prev != self.cursor_visible: + self.show_object(self.cursor, show=False) + + def update_report_stats(self): + ''' + see experiment.Experiment.update_report_stats for docs + ''' + super(ManualControlMulti, self).update_report_stats() + self.reportstats['Trial #'] = self.calc_trial_num() + self.reportstats['Reward/min'] = np.round(self.calc_events_per_min('reward', 120.), decimals=2) + + #### TEST FUNCTIONS #### + def _test_enter_target(self, ts): + ''' + return true if the distance between center of cursor and target is smaller than the cursor radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.target_location) + return d <= (self.target_radius - self.cursor_radius) + + def _test_leave_early(self, ts): + ''' + return true if cursor moves outside the exit radius + ''' + cursor_pos = self.plant.get_endpoint_pos() + d = np.linalg.norm(cursor_pos - self.target_location) + rad = self.target_radius - self.cursor_radius + return d > rad + + def _test_hold_complete(self, ts): + return ts>=self.hold_time + + def _test_timeout(self, ts): + return ts>self.timeout_time + + def _test_timeout_penalty_end(self, ts): + return ts>self.timeout_penalty_time + + def _test_hold_penalty_end(self, ts): + return ts>self.hold_penalty_time + def _test_trial_complete(self, ts): + return self.target_index==self.chain_length-1 + + def _test_trial_incomplete(self, ts): + return (not self._test_trial_complete(ts)) and (self.triesself.reward_time + + #### STATE FUNCTIONS #### + def _parse_next_trial(self): + self.targs = self.next_trial + + def _start_wait(self): + super(ManualControlMulti, self)._start_wait() + self.tries = 0 + self.target_index = -1 + #hide targets + for target in self.targets: + target.hide() + + self.chain_length = self.targs.shape[0] #Number of sequential targets in a single trial + + def _start_target(self): + self.target_index += 1 + + #move a target to current location (target1 and target2 alternate moving) and set location attribute + target = self.targets[self.target_index % 2] + self.target_location = self.targs[self.target_index] + target.move_to_position(self.target_location) + target.cue_trial_start() + + def _start_hold(self): + #make next target visible unless this is the final target in the trial + idx = (self.target_index + 1) + if idx < self.chain_length: + target = self.targets[idx % 2] + target.move_to_position(self.targs[idx]) + + def _end_hold(self): + # change current target color to green + self.targets[self.target_index % 2].cue_trial_end_success() + + def _start_hold_penalty(self): + #hide targets + for target in self.targets: + target.hide() + + self.tries += 1 + self.target_index = -1 + + def _start_timeout_penalty(self): + #hide targets + for target in self.targets: + target.hide() + + self.tries += 1 + self.target_index = -1 + + def _start_targ_transition(self): + #hide targets + for target in self.targets: + target.hide() + + def _start_reward(self): + #super(ManualControlMulti, self)._start_reward() + self.targets[self.target_index % 2].show() + + #### Generator functions #### + @staticmethod + def point_to_point_3D(length=2000, boundaries=(-18,18,-10,10,-15,15), distance=10, chain_length=2):1 + + @staticmethod + def centerout_3D(length=1000, boundaries=(-18,18,-10,10,-15,15),distance=8): + # Choose a random sequence of points on the surface of a sphere of radius + # "distance" + theta = np.random.rand(length)*2*np.pi + phi = np.arccos(2*np.random.rand(length) - 1) + x = distance*np.cos(theta)*np.sin(phi) + y = distance*np.sin(theta)*np.sin(phi) + z = distance*np.cos(theta) + + pairs = np.zeros([length,2,3]) + pairs[:,1,0] = x + pairs[:,1,1] = y + pairs[:,1,2] = z + + return pairs + + @staticmethod + def centerout_3D_cube(length=1000, edge_length=8): + ''' + Choose a random sequence of points on the surface of a sphere of radius + "distance" + ''' + coord = [-float(edge_length)/2, float(edge_length)/2] + from itertools import product + target_locs = [(x, y, z) for x, y, z in product(coord, coord, coord)] + + n_corners_in_cube = 8 + pairs = np.zeros([length, 2, 3]) + + for k in range(length): + pairs[k, 0, :] = np.zeros(3) + pairs[k, 1, :] = target_locs[np.random.randint(0, n_corners_in_cube)] + + print(pairs.shape) + return pairs + + @staticmethod + def centerout_2D_discrete(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=10): + ''' + + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + + x = distance*np.cos(theta) + y = np.zeros(len(theta)) + z = distance*np.sin(theta) + + pairs = np.zeros([len(theta), 2, 3]) + pairs[:,1,:] = np.vstack([x, y, z]).T + + return pairs + + @staticmethod + def centerout_2D_discrete_offset(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=5,xoffset = -8, zoffset = 0, centeroffset = -8): + ''' + + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin (offset from center of screen). + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + + x = distance*np.cos(theta)+xoffset + y = np.zeros(len(theta)) + z = distance*np.sin(theta)+zoffset + + pairs = np.zeros([len(theta), 2, 3]) + pairs[:,1,:] = np.vstack([x, y, z]).T + pairs[:,0,:] = np.array([centeroffset, 0, 0]) + + return pairs + + @staticmethod + def centerout_2D_discrete_eyetracker_calibration(nblocks=100, ntargets=4, boundaries=(-18,18,-12,12), + distance=10): + ''' + + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin. The order required by the eye-tracker calibration + is Center, Left, Right, Up, Down. The sequence generator therefore + displays 4- trial sequences in the following order: + Center-Left (C-L), C-R, C-U, C-D. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations + + + ''' + + # Create a LRUD (Left Right Up Down) sequence of points on the edge of + # a circle of radius "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) # ntargets = 4 --> shape of a + + temp2 = np.array([temp[2],temp[0],temp[1],temp[3]]) # Left Right Up Down + theta = theta + [temp2] + theta = np.hstack(theta) + + + x = distance*np.cos(theta) + y = np.zeros(len(theta)) + z = distance*np.sin(theta) + + pairs = np.zeros([len(theta), 2, 3]) + pairs[:,1,:] = np.vstack([x, y, z]).T + + return pairs + + @staticmethod + def centerout_2D_discrete_upper(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=10): + '''Same as centerout_2D_discrete, but rotates position of targets by 'rotate_deg'. + For example, if you wanted only 1 target, but sometimes wanted it at pi and sometiems at 3pi/2, + you could rotate it by 90 degrees''' + + theta = [] + for i in range(nblocks): + temp = np.arange(0, np.pi, np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + + x = distance*np.cos(theta) + y = np.zeros(len(theta)) + z = distance*np.sin(theta) + + pairs = np.zeros([len(theta), 2, 3]) + pairs[:,1,:] = np.vstack([x, y, z]).T + + return pairs + + @staticmethod + def centerout_2D_discrete_rot(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=10,rotate_deg=0): + '''Same as centerout_2D_discrete, but rotates position of targets by 'rotate_deg'. + For example, if you wanted only 1 target, but sometimes wanted it at pi and sometiems at 3pi/2, + you could rotate it by 90 degrees''' + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + temp = temp + (rotate_deg)*(2*np.pi/360) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + + x = distance*np.cos(theta) + y = np.zeros(len(theta)) + z = distance*np.sin(theta) + + pairs = np.zeros([len(theta), 2, 3]) + pairs[:,1,:] = np.vstack([x, y, z]).T + + return pairs + + @staticmethod + def centerout_2D_discrete_multiring(n_blocks=100, n_angles=8, boundaries=(-18,18,-12,12), + distance=10, n_rings=2): + ''' + + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + target_set = [] + angles = np.arange(0, 2*np.pi, 2*np.pi/n_angles) + distances = np.arange(0, distance + 1, float(distance)/n_rings)[1:] + for angle in angles: + for dist in distances: + targ = np.array([np.cos(angle), 0, np.sin(angle)]) * dist + target_set.append(targ) + + target_set = np.vstack(target_set) + n_targets = len(target_set) + + + periph_target_list = [] + for k in range(n_blocks): + target_inds = np.arange(n_targets) + np.random.shuffle(target_inds) + periph_target_list.append(target_set[target_inds]) + + periph_target_list = np.vstack(periph_target_list) + + + pairs = np.zeros([len(periph_target_list), 2, 3]) + pairs[:,1,:] = periph_target_list#np.vstack([x, y, z]).T + + return pairs + + @staticmethod + def centerout_2D_discrete_far(nblocks=100, ntargets=8, xmax=25, xmin=-25, zmin=-14, zmax=14, distance=10): + target_angles = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + target_pos = np.vstack([np.cos(target_angles), np.zeros_like(target_angles), np.sin(target_angles)]).T*distance + target_pos = np.vstack([targ for targ in target_pos if targ[0] < xmax and targ[0] > xmin and targ[2] > zmin and targ[2] < zmax]) + + from riglib.experiment.generate import block_random + periph_targets_per_trial = block_random(target_pos, nblocks=nblocks) + target_seqs = [] + for targ in periph_targets_per_trial: + target_seqs.append(np.vstack([np.zeros(3), targ])) + return target_seqs + + @staticmethod + def centerout_2D_discrete_randorder(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=10): + ''' + + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin, totally randomized instead of block randomized. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + np.random.shuffle(theta) + + + x = distance*np.cos(theta) + y = np.zeros(len(theta)) + z = distance*np.sin(theta) + + pairs = np.zeros([len(theta), 2, 3]) + pairs[:,1,:] = np.vstack([x, y, z]).T + + return pairs + + @staticmethod + def centeroutback_2D(length, boundaries=(-18,18,-12,12), distance=8): + ''' + Generates a sequence of 2D (x and z) center-edge-center target triplets. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between consecutive targets. + + Returns + ------- + targs : [length x 3 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + theta = np.random.rand(length)*2*np.pi + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + targs = np.zeros([length, 3, 3]) + targs[:,1,0] = x + targs[:,1,2] = z + + return targs + + @staticmethod + def centeroutback_2D_v2(length, boundaries=(-18,18,-12,12), distance=8): + ''' + This fn exists purely for compatibility reasons? + ''' + return centeroutback_2D(length, boundaries=boundaries, distance=distance) + + @staticmethod + def centeroutback_2D_farcatch(length, boundaries=(-18,18,-12,12), distance=8, catchrate=.1): + ''' + + Generates a sequence of 2D (x and z) center-edge-center target triplets, with occasional + center-edge-far edge catch trials thrown in. + + Parameters + ---------- + length : int + The number of target sets in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between consecutive targets. + catchrate: float + The percent of trials that are far target catch trials. If distance*2 + + Returns + ------- + targs : [length x 3 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" and a corresponding set of points on circle of raidus 2*distance + theta = np.random.rand(length)*2*np.pi + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + x2 = 2*x + x2[np.nonzero((x2boundaries[1]))] = np.nan + + z2 = 2*z + z2[np.nonzero((z2boundaries[3]))] = np.nan + + outertargs = np.zeros([length, 3]) + outertargs[:,0] = x2 + outertargs[:,2] = z2 + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + targs = np.zeros([length, 3, 3]) + targs[:,1,0] = x + targs[:,1,2] = z + + # shuffle order of indices and select specified percent to use for catch trials + numcatch = int(length*catchrate) + shuffinds = np.array(list(range(length))) + np.random.shuffle(shuffinds) + replace = shuffinds[:numcatch] + count = numcatch + + while np.any(np.isnan(outertargs[list(replace)])): + replace = replace[~np.isnan(np.sum(outertargs[list(replace)],axis=1))] + diff = numcatch - len(replace) + new = shuffinds[count:count+diff] + replace = np.concatenate((replace, new)) + count += diff + + targs[list(replace),2,:] = outertargs[list(replace)] + + return targs + + @staticmethod + def centeroutback_2D_farcatch_discrete(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=8, catchrate=0): + ''' + + Generates a sequence of 2D (x and z) center-edge-center target triplets, with occasional + center-edge-far edge catch trials thrown in. + + Parameters + ---------- + length : int + The number of target sets in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between consecutive targets. + catchrate: float + The percent of trials that are far target catch trials. If distance*2 + + Returns + ------- + targs : [length x 3 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + # Choose a corresponding set of points on circle of radius 2*distance. Mark any points that + # are outside specified boundaries with nans + + x2 = 2*x + x2[np.nonzero((x2boundaries[1]))] = np.nan + + z2 = 2*z + z2[np.nonzero((z2boundaries[3]))] = np.nan + + outertargs = np.zeros([nblocks*ntargets, 3]) + outertargs[:,0] = x2 + outertargs[:,2] = z2 + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + targs = np.zeros([nblocks*ntargets, 3, 3]) + targs[:,1,0] = x + targs[:,1,2] = z + + # shuffle order of indices and select specified percent to use for catch trials + numcatch = int(nblocks*ntargets*catchrate) + shuffinds = np.array(list(range(nblocks*ntargets))) + np.random.shuffle(shuffinds) + replace = shuffinds[:numcatch] + count = numcatch + + while np.any(np.isnan(outertargs[list(replace)])): + replace = replace[~np.isnan(np.sum(outertargs[list(replace)],axis=1))] + diff = numcatch - len(replace) + new = shuffinds[count:count+diff] + replace = np.concatenate((replace, new)) + count += diff + + targs[list(replace),2,:] = outertargs[list(replace)] + + return targs + + @staticmethod + def outcenterout_2D_discrete(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), + distance=8): + ''' + + Generates a sequence of 2D (x and z) center-edge-center target triplets, with occasional + center-edge-far edge catch trials thrown in. + + Parameters + ---------- + length : int + The number of target sets in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between consecutive targets. + catchrate: float + The percent of trials that are far target catch trials. If distance*2 + + Returns + ------- + targs : [length x 3 x 3] array of pairs of target locations + + + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + targs = np.zeros([nblocks*ntargets, 3, 3]) + targs[:,0,0] = x + targs[:,2,0] = x + targs[:,0,2] = z + targs[:,2,2] = z + + return targs + + @staticmethod + def outcenter_2D_discrete(nblocks=100, ntargets=4, boundaries=(-18,18,-12,12), + distance=8, startangle=np.pi/4): + ''' + + Generates a sequence of 2D (x and z) center-edge-center target triplets, with occasional + center-edge-far edge catch trials thrown in. + + Parameters + ---------- + length : int + The number of target sets in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between consecutive targets. + catchrate: float + The percent of trials that are far target catch trials. If distance*2 + + Returns + ------- + targs : [length x 3 x 3] array of pairs of target locations + ''' + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + theta = [] + for i in range(nblocks): + temp = np.arange(startangle, startangle+(2*np.pi), 2*np.pi/ntargets) + np.random.shuffle(temp) + theta = theta + [temp] + theta = np.hstack(theta) + + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + targs = np.zeros([nblocks*ntargets, 2, 3]) + targs[:,0,0] = x + targs[:,0,2] = z + + return targs + + @staticmethod + def rand_target_sequence_3d(length, boundaries=(-18,18,-10,10,-15,15), distance=10): + ''' + + Generates a sequence of 3D target pairs. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -y, y, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [n x 3 x 2] array of pairs of target locations + + + ''' + + # Choose a random sequence of points at least "distance" from the edge of + # the allowed area + pts = np.random.rand(length, 3)*((boundaries[1]-boundaries[0]-2*distance), + (boundaries[3]-boundaries[2]-2*distance), + (boundaries[5]-boundaries[4]-2*distance)) + pts = pts+(boundaries[0]+distance,boundaries[2]+distance, + boundaries[4]+distance) + + # Choose a random sequence of points on the surface of a sphere of radius + # "distance" + theta = np.random.rand(length)*2*np.pi + phi = np.arccos(2*np.random.rand(length) - 1) + x = distance*np.cos(theta)*np.sin(phi) + y = distance*np.sin(theta)*np.sin(phi) + z = distance*np.cos(theta) + + # Shift points to correct position relative to first sequence and join two + # sequences together in a trial x coordinate x start/end point array + pts2 = np.array([x,y,z]).transpose([1,0]) + pts + pairs = np.array([pts, pts2]).transpose([1,2,0]) + copy = pairs[0:length//2,:,:].copy() + + # Swap start and endpoint for first half of the pairs + pairs[0:length//2,:,0] = copy[:,:,1] + pairs[0:length//2,:,1] = copy[:,:,0] + + # Shuffle list of pairs + np.random.shuffle(pairs) + + return pairs + + @staticmethod + def rand_pt_to_pt(length=100, boundaries=(-18,18,-12,12), buf=2, seq_len=2): + ''' + Generates sequences of random postiions in the XZ plane + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + list + Each element of the list is an array of shape (seq_len, 3) indicating the target + positions to be acquired for the trial. + ''' + xmin, xmax, zmin, zmax = boundaries + L = length*seq_len + pts = np.vstack([np.random.uniform(xmin+buf, xmax-buf, L), + np.zeros(L), np.random.uniform(zmin+buf, zmax-buf, L)]).T + targ_seqs = [] + for k in range(length): + targ_seqs.append(pts[k*seq_len:(k+1)*seq_len]) + return targ_seqs + + @staticmethod + def rand_target_sequence_2d(length, boundaries=(-18,18,-12,12), distance=10): + ''' + + Generates a sequence of 2D (x and z) target pairs. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [n x 2 x 2] array of pairs of target locations + ''' + + # Choose a random sequence of points at least "distance" from the boundaries + pts = np.random.rand(length, 3)*((boundaries[1]-boundaries[0]-2*distance), + 0, (boundaries[3]-boundaries[2]-2*distance)) + pts = pts+(boundaries[0]+distance, 0, boundaries[2]+distance) + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + theta = np.random.rand(length)*2*np.pi + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + # Shift points to correct position relative to first sequence and join two + # sequences together in a trial x coordinate x start/end point array + pts2 = np.array([x,np.zeros(length),z]).transpose([1,0]) + pts + pairs = np.array([pts, pts2]).transpose([1,2,0]) + copy = pairs[0:length//2,:,:].copy() + + # Swap start and endpoint for first half of the pairs + pairs[0:length//2,:,0] = copy[:,:,1] + pairs[0:length//2,:,1] = copy[:,:,0] + + # Shuffle list of pairs + np.random.shuffle(pairs) + + return pairs + + @staticmethod + def rand_target_sequence_2d_centerout(length, boundaries=(-18,18,-12,12), + distance=10): + ''' + + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + Returns + ------- + pairs : [n x 2 x 2] array of pairs of target locations + + + ''' + + # Create list of origin targets + pts1 = np.zeros([length,3]) + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + theta = np.random.rand(length)*2*np.pi + x = distance*np.cos(theta) + z = distance*np.sin(theta) + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + pts2 = np.array([x,np.zeros(length),z]).transpose([1,0]) + pairs = np.array([pts1, pts2]).transpose([1,2,0]) + + return pairs + + @staticmethod + def rand_target_sequence_2d_partial_centerout(length, boundaries=(-18,18,-12,12),distance=10,perc_z=20): + ''' + PK + Generates a sequence of 2D (x and z) target pairs with the first target + always at the origin. + + Parameters + ---------- + perc_z : float + The percentage of the z axis to be used for targets + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between the targets in a pair. + + + Returns + ------- + pairs : [n x 2 x 2] array of pairs of target locations + + + ''' + # Need to get perc_z from settable traits + # perc_z = traits.Float(0.1, desc="Percent of Y axis that targets move along") + perc_z=float(10) + # Create list of origin targets + pts1 = np.zeros([length,3]) + + # Choose a random sequence of points on the edge of a circle of radius + # "distance" + + #Added PK -- to confine z value according to entered boundaries: + theta_max = math.asin(boundaries[3]/distance*(perc_z)/float(100)) + theta = (np.random.rand(length)-0.5)*2*theta_max + + #theta = np.random.rand(length)*2*np.pi + + x = distance*np.cos(theta)*(np.ones(length)*-1)**np.random.randint(1,3,length) + z = distance*np.sin(theta) + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + pts2 = np.array([x,np.zeros(length),z]).transpose([1,0]) + + pairs = np.array([pts1, pts2]).transpose([1,2,0]) + + return pairs + + @staticmethod + def rand_multi_sequence_2d_centerout2step(length, boundaries=(-20,20,-12,12), distance=10): + ''' + + Generates a sequence of 2D (x and z) center-edge-far edge target triplets. + + Parameters + ---------- + length : int + The number of target pairs in the sequence. + boundaries: 6 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + distance : float + The distance in cm between consecutive targets. + + Returns + ------- + targs : [length x 3 x 3] array of pairs of target locations + + + ''' + + # Create list of origin targets + pts1 = np.zeros([length,3]) + + # Choose a random sequence of points on the edge of a circle of radius + # "distance", and matching set with radius distance*2 + theta = np.random.rand(length*10)*2*np.pi + x1 = distance*np.cos(theta) + z1 = distance*np.sin(theta) + x2 = distance*2*np.cos(theta) + z2 = distance*2*np.sin(theta) + + mask = np.logical_and(np.logical_and(x2>=boundaries[0],x2<=boundaries[1]),np.logical_and(z2>=boundaries[2],z2<=boundaries[3])) + + # Join start and end points together in a [trial x coordinate x start/end] + # array (fill in zeros for endpoint y values) + pts2 = np.array([x1[mask][:length], np.zeros(length), z1[mask][:length]]).transpose([1,0]) + pts3 = np.array([x2[mask][:length], np.zeros(length), z2[mask][:length]]).transpose([1,0]) + targs = np.array([pts1, pts2, pts3]).transpose([1,2,0]) + + return targs + +class JoystickMulti(ManualControlMulti): + + # #Settable Traits + joystick_method = traits.Float(1,desc="1: Normal velocity, 0: Position control") + random_rewards = traits.Float(0,desc="Add randomness to reward, 1: yes, 0: no") + joystick_speed = traits.Float(20, desc="Radius of cursor") + + is_bmi_seed = True + + def __init__(self, *args, **kwargs): + super(JoystickMulti, self).__init__(*args, **kwargs) + self.current_pt=np.zeros([3]) #keep track of current pt + self.last_pt=np.zeros([3]) #keep track of last pt to calc. velocity + #self.plant_visible = False + #self.plant.cursor_color = (0., 1., 0., 0.5) + + def update_report_stats(self): + super(JoystickMulti, self).update_report_stats() + start_time = self.state_log[0][1] + rewardtimes=np.array([state[1] for state in self.state_log if state[0]=='reward']) + if len(rewardtimes): + rt = rewardtimes[-1]-start_time + else: + rt= np.float64("0.0") + + sec = str(np.int(np.mod(rt,60))) + if len(sec) < 2: + sec = '0'+sec + self.reportstats['Time Of Last Reward'] = str(np.int(np.floor(rt/60))) + ':' + sec + def _test_trial_complete(self, ts): if self.target_index==self.chain_length-1 : if self.random_rewards: if not self.rand_reward_set_flag: #reward time has not been set for this iteration self.reward_time = np.max([2*(np.random.rand()-0.5) + self.reward_time_base, self.reward_time_base/2]) #set randomly with min of base / 2 - self.rand_reward_set_flag =1 + self.rand_reward_set_flag =1; #print self.reward_time, self.rand_reward_set_flag return self.target_index==self.chain_length-1 - + def _test_reward_end(self, ts): - #When finished reward, reset flag. + #When finished reward, reset flag. if self.random_rewards: if ts > self.reward_time: - self.rand_reward_set_flag = 0 + self.rand_reward_set_flag = 0; #print self.reward_time, self.rand_reward_set_flag, ts return ts > self.reward_time - def _transform_coords(self, coords): - ''' - Returns transformed coordinates based on rotation, offset, and scale traits - ''' - offset = np.array( - [[1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [self.offset[0], self.offset[1], self.offset[2], 1]] - ) - scale = np.array( - [[self.scale, 0, 0, 0], - [0, self.scale, 0, 0], - [0, 0, self.scale, 0], - [0, 0, 0, 1]] - ) - old = np.concatenate((np.reshape(coords, -1), [1])) - new = np.linalg.multi_dot((old, offset, scale, rotations[self.rotation])) - return new[0:3] - - def _get_manual_position(self): - ''' - Fetches joystick position - ''' - if not hasattr(self, 'joystick'): - return + def move_effector(self): + ''' Returns the 3D coordinates of the cursor. For manual control, uses + motiontracker data. If no motiontracker data available, returns None''' + + #get data from phidget pt = self.joystick.get() - if len(pt) == 0: - return + #print pt - pt = pt[-1] # Use only the latest coordinate + if len(pt) > 0: - if len(pt) == 2: - pt = np.concatenate((np.reshape(pt, -1), [0])) + pt = pt[-1][0] + x = pt[1] + y = 1-pt[0] - return [pt] - def move_effector(self): - ''' - Sets the 3D coordinates of the cursor. For manual control, uses - motiontracker / joystick / mouse data. If no data available, returns None - ''' - - # Get raw input and save it as task data - raw_coords = self._get_manual_position() # array of [3x1] arrays - if raw_coords is None or len(raw_coords) < 1: - self.no_data_counter[self.cycle_count % self._quality_window_size] = 1 - self.update_report_stats() - self.task_data['manual_input'] = np.empty((3,)) - return - - self.task_data['manual_input'] = raw_coords.copy() - self.no_data_counter[self.cycle_count % self._quality_window_size] = 0 - - # Transform coordinates - coords = self._transform_coords(raw_coords) - if self.limit2d: - coords[1] = 0 - - # Set cursor position - if not self.velocity_control: - self.current_pt = coords - else: - epsilon = 2*(10**-2) # Define epsilon to stabilize cursor movement - if sum((coords)**2) > epsilon: + pt[0]=1-pt[0]; #Switch L / R axes + calib = [0.5,0.5] #Sometimes zero point is subject to drift this is the value of the incoming joystick when at 'rest' + # calib = [ 0.487, 0. ] - # Add the velocity (units/s) to the position (units) - self.current_pt = coords / self.fps + self.last_pt - else: - self.current_pt = self.last_pt + #if self.joystick_method==0: + if self.joystick_method==0: + pos = np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]]) + pos[0] = pos[0]*36 + pos[2] = pos[2]*24 + self.current_pt = pos - self.plant.set_endpoint_pos(self.current_pt) - self.last_pt = self.current_pt.copy() + elif self.joystick_method==1: + #vel=np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]]) + vel = np.array([x-calib[0], 0., y-calib[1]]) + epsilon = 2*(10**-2) #Define epsilon to stabilize cursor movement + if sum((vel)**2) > epsilon: + self.current_pt=self.last_pt+20*vel*(1/60) #60 Hz update rate, dt = 1/60 + else: + self.current_pt = self.last_pt - def update_report_stats(self): - super().update_report_stats() - window_size = min(max(1, self.cycle_count), self._quality_window_size) - num_missing = np.sum(self.no_data_counter[:window_size]) - quality = 1 - num_missing / window_size - self.reportstats['Input quality'] = "{} %".format(int(100*quality)) + #self.current_pt = self.current_pt + (np.array([np.random.rand()-0.5, 0., np.random.rand()-0.5])*self.joystick_speed) + + if self.current_pt[0] < -25: self.current_pt[0] = -25 + if self.current_pt[0] > 25: self.current_pt[0] = 25 + if self.current_pt[-1] < -14: self.current_pt[-1] = -14 + if self.current_pt[-1] > 14: self.current_pt[-1] = 14 + + self.plant.set_endpoint_pos(self.current_pt) + self.last_pt = self.current_pt.copy() @classmethod - def get_desc(cls, params, log_summary): - duration = round(log_summary['runtime'] / 60, 1) - return "{}/{} succesful trials in {} min".format( - log_summary['n_success_trials'], log_summary['n_trials'], duration) + def get_desc(cls, params, report): + duration = report[-1][-1] - report[0][-1] + reward_count = 0 + for item in report: + if item[0] == "reward": + reward_count += 1 + return "{} rewarded trials in {} min".format(reward_count, duration) +class JoystickMulti2DWindow(JoystickMulti, WindowDispl2D): + fps = 20. + def __init__(self,*args, **kwargs): + super(JoystickMulti2DWindow, self).__init__(*args, **kwargs) -class ManualControl(ManualControlMixin, ScreenTargetCapture): - ''' - Slightly refactored original manual control task - ''' - pass + def _start_wait(self): + self.wait_time = 0. + super(JoystickMulti2DWindow, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause -class ManualControlDirectionConstraint(ManualControlMixin, ScreenReachAngle): - ''' - Adds an additional constraint that the direction of travel must be within a certain angle - ''' - pass \ No newline at end of file diff --git a/built_in_tasks/sim_task_KF.py b/built_in_tasks/sim_task_KF.py new file mode 100644 index 000000000..ff542842a --- /dev/null +++ b/built_in_tasks/sim_task_KF.py @@ -0,0 +1,174 @@ +from bmimultitasks import SimBMIControlMulti, SimBMICosEncKFDec +from features import SaveHDF +<<<<<<< HEAD +from features.task_code_features import TaskCodeStreamer +from features.simulation_features import get_enc_setup, SimKFDecoderRandom, SimCosineTunedEnc,SimIntentionLQRController +======= +from features.simulation_features import get_enc_setup, SimKFDecoderRandom, SimCosineTunedEnc,SimIntentionLQRController, SimHDF +>>>>>>> cd76cda088fabeed8a08d30b2a3f1eb478d76105 +from riglib import experiment + +import time +import numpy as np + + +""" +this task uses +""" + + +# build a sequence generator +if __name__ == "__main__": + + #generate task params + N_TARGETS = 1 + N_TRIALS = 3 +<<<<<<< HEAD +======= + tuning_level = 70 + clda_batch_time = 60 + clda_half_life = 60 +>>>>>>> cd76cda088fabeed8a08d30b2a3f1eb478d76105 + + + #clda on random matrix + ''' + DECODER_MODE = 'random' # in this case we load simulation_features.SimKFDecoderRandom + ENCODER_TYPE = 'cosine_tuned_encoder' + LEARNER_TYPE = 'feedback' # to dumb or not dumb it is a question 'feedback' + UPDATER_TYPE = 'smooth_batch' #none or "smooth_batch" + + + + + + #clda on trained decoder + #expect to get worse + ''' + +<<<<<<< HEAD + """ +======= + +<<<<<<< HEAD + DECODER_MODE = 'random' # in this case we load simulation_features.SimKFDecoderRandom +======= +>>>>>>> cd76cda088fabeed8a08d30b2a3f1eb478d76105 + DECODER_MODE = 'trainedKF' # in this case we load simulation_features.SimKFDecoderRandom +>>>>>>> 8728e3967b64bb04584a8a00023bc458f9ddf38f + ENCODER_TYPE = 'cosine_tuned_encoder' + LEARNER_TYPE = 'feedback' # to dumb or not dumb it is a question 'feedback' + UPDATER_TYPE = 'smooth_batch' #none or "smooth_batch" + + + + + +<<<<<<< HEAD + """ + #no clda +======= + + #on random + ''' +>>>>>>> cd76cda088fabeed8a08d30b2a3f1eb478d76105 + DECODER_MODE = 'trainedKF' # in this case we load simulation_features.SimKFDecoderRandom + ENCODER_TYPE = 'cosine_tuned_encoder' + LEARNER_TYPE = 'dumb' # to dumb or not dumb it is a question 'feedback' + UPDATER_TYPE = 'none' #none or "smooth_batch" +<<<<<<< HEAD + +======= + ''' +>>>>>>> cd76cda088fabeed8a08d30b2a3f1eb478d76105 + + SAVE_HDF = True + DEBUG_FEATURE = False + TASK_CODE_STREAMER = True + + seq = SimBMIControlMulti.sim_target_seq_generator_multi( + N_TARGETS, N_TRIALS) + + #neuron set up : 'std (20 neurons)' or 'toy (4 neurons)' + N_NEURONS, N_STATES, sim_C = get_enc_setup(sim_mode = 'toy', tuning_level=tuning_level) + + # set up assist level + assist_level = (1, 0) + + base_class = SimBMIControlMulti + feats = [] + + #set up intention feedbackcontroller + #this ideally set before the encoder + feats.append(SimIntentionLQRController) + + + #set up the encoder + if ENCODER_TYPE == 'cosine_tuned_encoder' : + feats.append(SimCosineTunedEnc) + print(f'{__name__}: selected SimCosineTunedEnc\n') + + + #now, we can set up a dumb/or not-dumb learner + if LEARNER_TYPE == 'feedback': + from features.simulation_features import SimFeedbackLearner + feats.append(SimFeedbackLearner) + else: + from features.simulation_features import SimDumbLearner + feats.append(SimDumbLearner) + + #take care the decoder setup + if DECODER_MODE == 'random': + feats.append(SimKFDecoderRandom) + print(f'{__name__}: set base class ') + print(f'{__name__}: selected SimKFDecoderRandom \n') + else: #defaul to a cosEnc and a pre-traind KF DEC + from features.simulation_features import SimKFDecoderSup + feats.append(SimKFDecoderSup) + print(f'{__name__}: set decoder to SimKFDecoderSup\n') + + + + #you know what? + #learner only collects firing rates labeled with estimated estimates + #we would also need to use the labeled data + #to update the decoder. + if UPDATER_TYPE == 'smooth_batch': + from features.simulation_features import SimSmoothBatch + feats.append(SimSmoothBatch) + else: #defaut to none + print(f'{__name__}: need to specify an updater') + + + if DEBUG_FEATURE: + from features.simulation_features import DebugFeature + feats.append(DebugFeature) + + if TASK_CODE_STREAMER: feats.append(TaskCodeStreamer) + + if SAVE_HDF: feats.append(SimHDF) + + #sav everthing in a kw + kwargs = dict() + kwargs['sim_C'] = sim_C + + kwargs['batch_time'] = clda_batch_time + kwargs['half_life'] = clda_half_life + kwargs['assist_level'] = assist_level + + #spawn the task + Exp = experiment.make(base_class, feats=feats) + #print(Exp) + exp = Exp(seq, **kwargs) + exp.init() + exp.run() # start the task + exp.decoder.plot_K() + + + #we clearn up and get the saved hdf file + N_SLEEP_TIME = 5 + print(f'{__name__} :sleep {N_SLEEP_TIME} seconds for hdf file to finish') + time.sleep(N_SLEEP_TIME) + + exp.cleanup_hdf() + diff --git a/built_in_tasks/target_capture_task.py b/built_in_tasks/target_capture_task.py index f309a583a..75e2c0e2b 100644 --- a/built_in_tasks/target_capture_task.py +++ b/built_in_tasks/target_capture_task.py @@ -13,7 +13,7 @@ from riglib import plants from riglib.stereo_opengl.window import Window -from .target_graphics import * +from target_graphics import * ## Plants # List of possible "plants" that a subject could control either during manual or brain control @@ -263,7 +263,10 @@ def update_report_stats(self): self.reportstats['Trial #'] = self.calc_trial_num() self.reportstats['Reward/min'] = np.round(self.calc_events_per_min('reward', 120.), decimals=2) -class ScreenTargetCapture(TargetCapture, Window): + + + +class ConcreteTargetCapture(TargetCapture): """Concrete implementation of TargetCapture task where targets are acquired by "holding" a cursor in an on-screen target""" @@ -301,13 +304,12 @@ def __init__(self, *args, **kwargs): self.plant_vis_prev = True self.cursor_vis_prev = True - # Add graphics models for the plant and targets to the window - if hasattr(self.plant, 'graphics_models'): - for model in self.plant.graphics_models: - self.add_model(model) + # Instantiate the targets instantiate_targets = kwargs.pop('instantiate_targets', True) + + self.target_location = np.array([0,0,0]) if instantiate_targets: # Need two targets to have the ability for delayed holds @@ -323,6 +325,9 @@ def __init__(self, *args, **kwargs): def init(self): self.add_dtype('trial', 'u4', (1,)) self.add_dtype('plant_visible', '?', (1,)) + + self.add_dtype('target', 'f8',(3,)) + self.add_dtype('target_index', 'i', (1,)) super().init() def _cycle(self): @@ -396,17 +401,19 @@ def _start_wait(self): # Instantiate the targets here so they don't show up in any states that might come before "wait" for target in self.targets: for model in target.graphics_models: - self.add_model(model) - target.hide() + #self.add_model(model) + #target.hide() + pass def _start_target(self): super()._start_target() # Show target if it is hidden (this is the first target, or previous state was a penalty) target = self.targets[self.target_index % 2] + self.target_location = self.targs[self.target_index] if self.target_index == 0: target.move_to_position(self.targs[self.target_index]) - target.show() + #target.show() self.sync_event('TARGET_ON', self.gen_indices[self.target_index]) def _start_hold(self): @@ -421,7 +428,7 @@ def _start_delay(self): if next_idx < self.chain_length: target = self.targets[next_idx % 2] target.move_to_position(self.targs[next_idx]) - target.show() + #target.show() self.sync_event('TARGET_ON', self.gen_indices[next_idx]) else: # This delay state should only last 1 cycle, don't sync anything @@ -436,7 +443,7 @@ def _start_targ_transition(self): elif self.target_index + 1 < self.chain_length: # Hide the current target if there are more - self.targets[self.target_index % 2].hide() + #self.targets[self.target_index % 2].hide() self.sync_event('TARGET_OFF', self.gen_indices[self.target_index]) def _start_hold_penalty(self): @@ -444,7 +451,7 @@ def _start_hold_penalty(self): super()._start_hold_penalty() # Hide targets for target in self.targets: - target.hide() + #target.hide() target.reset() def _end_hold_penalty(self): @@ -456,7 +463,7 @@ def _start_delay_penalty(self): super()._start_delay_penalty() # Hide targets for target in self.targets: - target.hide() + #target.hide() target.reset() def _end_delay_penalty(self): @@ -468,7 +475,7 @@ def _start_timeout_penalty(self): super()._start_timeout_penalty() # Hide targets for target in self.targets: - target.hide() + #target.hide() target.reset() def _end_timeout_penalty(self): @@ -485,7 +492,7 @@ def _end_reward(self): # Hide targets for target in self.targets: - target.hide() + #target.hide() target.reset() #### Generator functions #### @@ -508,7 +515,7 @@ def static(pos=(0,0,0), ntrials=0): yield [0], np.array(pos) @staticmethod - def out_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)): + def out_2D(nblocks=500, ntargets=8, distance=10, origin=(0,0,0)): ''' Generates a sequence of 2D (x and z) targets at a given distance from the origin @@ -635,6 +642,31 @@ def rand_target_chain_3D(ntrials=100, chain_length=1, boundaries=(-12,12,-10,10, yield idx+np.arange(chain_length), pts idx += chain_length +class ScreenTargetCapture(ConcreteTargetCapture, Window): + """ + a task that is mixed with Window display + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Add graphics models for the plant and targets to the window + if hasattr(self.plant, 'graphics_models'): + for model in self.plant.graphics_models: + self.add_model(model) + + + #### STATE FUNCTIONS #### + def _start_wait(self): + super()._start_wait() + + if self.calc_trial_num() == 0: + + # Instantiate the targets here so they don't show up in any states that might come before "wait" + for target in self.targets: + for model in target.graphics_models: + self.add_model(model) + target.hide() + class ScreenReachAngle(ScreenTargetCapture): ''' A modified task that requires the cursor to move in the right direction towards the target, diff --git a/features/simulation_features.py b/features/simulation_features.py index f24ed4848..272a6f0df 100644 --- a/features/simulation_features.py +++ b/features/simulation_features.py @@ -75,6 +75,23 @@ def sendMsg(self, msg): None ''' self.msgs.append((msg, -1)) + + def set_state(self, condition, **kwargs): + ''' + mimicks hdf_features + Save task state transitions to HDF + + Parameters + ---------- + condition: string + Name of new state to transition into. The state name must be a key in the 'status' dictionary attribute of the task + + Returns + ------- + None + ''' + self.sendMsg(condition) + super(SimHDF, self).set_state(condition, **kwargs) def _cycle(self): super(SimHDF, self)._cycle() @@ -88,6 +105,25 @@ def __init__(self, *args, **kwargs): def tick(self, *args, **kwargs): pass + def get_time(self): + ''' + Simulates time based on Delta*cycle_count, where the update_rate is specified as an instance attribute + ''' + try: + return self.cycle_count * self.update_rate + + except: + # loop_counter has not been initialized yet, return 0 + return 0 + + @property + def update_rate(self): + ''' + Attribute for update rate of task. Using @property in case any future modifications + decide to change fps on initialization + ''' + return 1./60 + class SimClockTick(object): ''' Summary: A simulation pygame.clock to use in simulations that inherit from experiment.Experiment, to overwrite @@ -131,8 +167,6 @@ def get_time(self): Simulates time based on Delta*cycle_count, where the update_rate is specified as an instance attribute ''' try: - if not (self.cycle_count % (60*10)): - print(self.cycle_count/(60*10.)) return self.cycle_count * self.update_rate except: @@ -147,6 +181,31 @@ def update_rate(self): ''' return 1./60 +############################# +##### Simulation Feedback controllers +##### the stuff actually mimicks higher level of the brain +############################# +class SimIntentionLQRController(object): + """ + this class uses feedback contro + """ + def __init__(self, *args, **kwargs): + + ssm = state_space_models.StateSpaceEndptVel2D() + A, B, W = ssm.get_ssm_matrices() + Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) + R = 10000*np.mat(np.diag([1., 1., 1.])) + self.fb_ctrl = feedback_controllers.LQRController(A, B, Q, R) + + print() + print(f'{__name__}.{__class__.__name__}: LQRController used \n') + + super().__init__(*args, **kwargs) + + +############################# +##### Simulation Encoders +############################# class SimNeuralEnc(object): def __init__(self, *args, **kwargs): @@ -212,7 +271,11 @@ class SimCosineTunedEnc(SimNeuralEnc): def _init_neural_encoder(self): ## Simulation neural encoder from riglib.bmi.sim_neurons import GenericCosEnc#CLDASimCosEnc - print('SimCosineTunedEnc SSM:', self.ssm) + from riglib.bmi.state_space_models import StateSpaceEndptVel2D + + self.ssm = StateSpaceEndptVel2D() + + print('\nSimCosineTunedEnc SSM:', self.ssm, '\n') self.encoder = GenericCosEnc(self.sim_C, self.ssm, return_ts=True, DT=0.1, call_ds_rate=6) def create_feature_extractor(self): @@ -223,6 +286,31 @@ def create_feature_extractor(self): n_subbins=self.decoder.n_subbins, units=self.decoder.units, task=self) self._add_feature_extractor_dtype() +class SimCosineTunedEncWithNoise(SimCosineTunedEnc): + + def __init__(self, *args, **kwargs): + super().__init__( *args, **kwargs) + + self.noise_mode = kwargs['noise_mode'] + self.percent_poisson_noise = kwargs['percent_noise'] + if 'fixed_noise_level' in kwargs.keys(): self.fixed_noise_level = kwargs['fixed_noise_level'] + + print(f'{__class__}: added CosineTunedEncWithNoise ') + + + def _init_neural_encoder(self): + ## Simulation neural encoder + from riglib.bmi.sim_neurons import GenericCosEncWithNoise + from riglib.bmi.state_space_models import StateSpaceEndptVel2D + + self.ssm = StateSpaceEndptVel2D() + + print('\n using encoder with additional noises ', self.ssm, '\n') + self.encoder = GenericCosEncWithNoise(self.sim_C, self.ssm, + self.noise_mode, noise_profile= self.percent_poisson_noise, fixed_noise_rate = self.fixed_noise_level, + return_ts=True, DT=0.1, call_ds_rate=6) + + class SimNormCosineTunedEnc(SimNeuralEnc): def _init_neural_encoder(self): @@ -326,6 +414,30 @@ class SimKFDecoder(object): def __init__(self, *args, **kwargs): super(SimKFDecoder, self).__init__(*args, **kwargs) + def init(self, *args, **kwargs): + units = self.encoder.get_units() + n_units = len(units) + NUM_STATES = 7 + self.add_dtype('obs_t', 'f8', (n_units,1)) + self.add_dtype('KC', 'f8', (NUM_STATES,NUM_STATES)) + + self.add_dtype('pred_state_mean', 'f8', (NUM_STATES,1)) + self.add_dtype('post_state_mean', 'f8', (NUM_STATES,1)) + + self.add_dtype('pred_state_P', 'f8', (NUM_STATES,NUM_STATES)) + self.add_dtype('post_state_P', 'f8', (NUM_STATES,NUM_STATES)) + + self.k_mat_params = list() + print(f'{__name__}: added tracking of K matrix') + + super().init(*args, **kwargs) + + def _cycle(self, *args, **kwargs): + super()._cycle(*args, **kwargs) + + self.k_mat_params.append(self.decoder.filt.K) + + def change_dec_ssm(self): decoder_old = self.decoder_old ssm_old = decoder_old.ssm @@ -373,42 +485,97 @@ class SimKFDecoderSup(SimKFDecoder): ''' Construct a KFDecoder based on encoder output in response to states simulated according to the state space model's process noise ''' - def load_decoder(self): + + + + def load_decoder(self, supplied_encoder = None, supplied_SSM = None, n_samples = 2000): ''' Instantiate the neural encoder and "train" the decoder + + update 2020 Dec.: + allow training with supplied SSM by adding supplied_enc and supplied_ssm + + + Parameters: + supplied_encoder: SimNeuralEnc and its children classses + supplied_SSM: ssm used to esablish the decoder. + + Output: + None + however: rele decoder information can be accessed via + self.decoder: KFDecoder object + self.init_neural_features : init training data + self.init_kin_features: kinematic states. ''' if hasattr(self, 'decoder'): print('Already have a decoder!') else: print("Creating simulation decoder..") - print(self.encoder, type(self.encoder)) - n_samples = 2000 - units = self.encoder.get_units() + + #select the encoder + #prioritize loading self. encoder + if hasattr(self, 'encoder'): + encoder = self.encoder + print('SimKFDecoderSup:loaded self.encoder') + elif supplied_encoder: + encoder = supplied_encoder + print('SimKFDecoderSup:loaded suppled_encoder from function input') + else: + print('SimKFDecoderSup: Neither self or supplied enc is supplied') + print('Decoder not traiined') + return + #if succussful, print out the type of decoder, eh + print(encoder, type(encoder)) + + #also need to select which ssm to use + #select the encoder + #prioritize loading self. encoder + if hasattr(self, 'ssm'): + ssm = self.ssm + print('SimKFDecoderSup:loaded self.ssm') + elif supplied_SSM: + ssm = supplied_SSM + print('SimKFDecoderSup:loaded suppled_ssm from function input') + else: + print('SimKFDecoderSup: Neither self or supplied ssm is suppleid') + print('Decoder not traiined') + return + #if succussful, print out the type of decoder, eh + print(encoder, type(encoder)) + + + + + units = encoder.get_units() n_units = len(units) - print('units: ', n_units) + print('SimKFDecoderSup: units: ', n_units) # draw samples from the W distribution - ssm = self.ssm A, _, W = ssm.get_ssm_matrices() mean = np.zeros(A.shape[0]) mean[-1] = 1 state_samples = np.random.multivariate_normal(mean, W, n_samples) + kin = state_samples.T + #produce spike samples spike_counts = np.zeros([n_units, n_samples]) - self.encoder.call_ds_rate = 1 + encoder.call_ds_rate = 1 for k in range(n_samples): - spike_counts[:,k] = np.array(self.encoder(state_samples[k], mode='counts')).ravel() + spike_counts[:,k] = np.array(encoder(state_samples[k,:].reshape(-1,1), mode='counts')).ravel() - kin = state_samples.T + #deal with clda zscore = False if hasattr(self, 'clda_adapt_mFR_stats'): if self.clda_adapt_mFR_stats: zscore = True - print(' zscore decoder ? : ', zscore) - self.decoder = train.train_KFDecoder_abstract(ssm, kin, spike_counts, units, 0.1, zscore=zscore) - self.encoder.call_ds_rate = 6 + print(' SimKFDecoderSup: zscore decoder ? : ', zscore) + #now we can train the decoder. + self.decoder = train.train_KFDecoder_abstract(ssm, kin, spike_counts, units, 0.1, zscore=zscore) + encoder.call_ds_rate = 6 + + #save the initial decoder parameters self.init_neural_features = spike_counts self.init_kin_features = kin @@ -455,6 +622,8 @@ def load_decoder(self): self.encoder.call_ds_rate = 6 super(SimKFDecoderShuffled, self).load_decoder() + + class SimKFDecoderRandom(SimKFDecoder): def load_decoder(self): ''' @@ -500,3 +669,278 @@ def load_decoder(self): # self._init_neural_encoder() self.decoder = train._train_PPFDecoder_sim_known_beta(self.beta_full[inds], self.encoder.units, dt=1./180) + + +############################# +##### Simulation learners +############################# +from riglib.bmi import clda +class SimDumbLearner(object): + """ + a feature wrapper to set up learner + this is essentially a replica of bmi.create_learner + but copy and paste here for + 1. unnessary redundency + 2. trivial understanding + """ + def create_learner(self): + ''' + The "learner" uses knowledge of the task goals to determine the "intended" + action of the BMI subject and pairs this intention estimation with actual observations. + ''' + + self.learn_flag = False + self.learner = clda.DumbLearner() + + +class SimFeedbackLearner(object): + """ + a trivial class to set up learner in the BMIloop + that uses feedback controller. + + We need + 1. batch_size(int): to tell how many data points we are accumulating. + + CAVEAT: + this uses the same feedback control as the encoder does. + + with this caveat, + we commencer the setup mainly using the setup info from riglib.bmi.clda.FeedbackControllerLearner + """ + + def __init__(self,*args, **kwargs): + """ + set up function + main job is to determine the batch_size + default batch size to 16 + + """ + #this is another boring python liner. + #highly UNrecommended for its simplicity and readability + self.batch_size = kwargs['batch_size'] if 'batch_size' in kwargs else 16 + print(f'\n {__name__}.{__class__.__name__}: start to create a sim leaner with a batchsize of {self.batch_size}') + super().__init__(*args, **kwargs) + + def create_learner(self): + """ + just set up the feedback controller learner + """ + + if hasattr(self, 'fb_ctrl'): + self.learn_flag = True + print(f'{__name__}: batch size is {self.batch_size}') + self.learner = clda.FeedbackControllerLearner(self.batch_size, self.fb_ctrl) + + print('') + print(f'{__name__}.{__class__.__name__}: flip the self.learn_flag to true') + print(f'{__name__}.{__class__.__name__}: succussfully created a feedback controller learner\n') + else: + print() + print(f'\n {__name__}.{__class__.__name__}: does not have fb_ctrl, sorrry, raise an error\n') + raise(ValueError) + +############################# +##### Simulation updaters +############################# +class SimSmoothBatch(object): + """ + this loads a simulation wrapper + for riglib.clda.KFSmoothbatch + + the below is commment on comment, which is very very important + it is labeled as a depreciated updater. + how can a classical updater be depreciated? + it should not be depreciated for R&D purposes! + """ + + def __init__(self, *arg, **kwargs): + """ + guess what, this function simplify sets up the key params for + the eh smoothbatch clda + + if you have any doubts, please go to riglib.bmi.KFSmoothBatch for a detailed discussion. + for those leaning towards mathematics, can waste your time on Orsborn 2012 or Dangi 2014. + """ + # honestly don't know what these params + #guess we do need to waste time on the foundational papers + + DEFAULT_BATCH_TIME = 1 + DEFAULT_HALF_LIFE = 60 + #again, no idea, seems related to how long it takes the weight of the past value drops to 1/2 + + self.batch_time = kwargs['batch_time'] if 'batch_time' in kwargs.keys() else DEFAULT_BATCH_TIME + self.half_life = kwargs['half_life'] if 'half_life' in kwargs.keys() else DEFAULT_HALF_LIFE + + #we are going to print out the rhos, sort of thing. + rho = np.exp(np.log(0.5)/(self.half_life/self.batch_time)) + print(f'{__name__}.{__class__.__name__}: rho in this simulation is {rho}\n') + + super().__init__(*arg, **kwargs) + + + def create_updater(self): + self.updater = clda.KFSmoothbatch(self.batch_time, self.half_life) + print() + print(f'{__class__.__name__}: created an updater with a batch time of {self.batch_time} and a half_life of {self.half_life} \n') + + +class SimSmoothBatchFullFeature(SimSmoothBatch): + """ + This is the full feature version of the SimSmoothBatch + """ + def __init__(self, *arg, **kwargs): + + super().__init__(*arg, **kwargs) + + self._num_full_features = kwargs['n_starting_feats'] # keep track of how many features start from begining + + def create_updater(self): + self.updater = clda.KFSmoothBatchFullFeature(self.batch_time, self.half_life, self._num_full_features) + print() + print(f'{__class__.__name__}: created a FULL FEATURE updater with a batch time of {self.batch_time} and a half_life of {self.half_life} \n') + + + + +############################# +##### Simulation helper classes/ features +############################# + + +class TimeCountDown(object): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + + self.TOTAL_RUNNNING_TIME = kwargs.pop('total_exp_time', 10) + fps = 60 + self.total_frames = self.TOTAL_RUNNNING_TIME * fps + self.left_frames = self.total_frames + + print(f'TimeCountDown: assume fps to be {fps}') + + def _cycle(self, *args, **kwargs): + #basically just counts down + self.left_frames -= 1 + + #TODO: determine how to end the experiment: + #for now, we just + if self.left_frames == 0: + self.state = None + #and the experiemnt would be able to safely exit. + + super()._cycle(*args, **kwargs) + + + +class DebugFeature(object): + """ + the purpose of this feature is just to set self.debug_flag to true + + to plant this flag everywhere, use the following format + if hasattr(self, 'debug_flag'): + if self.debug_flag: + #do your thing + + + """ + def __init__(self, *args, **kwargs): + self.debug_flag = True + print(f'{__class__.__name__}:set debug flag to {self.debug_flag}') + super().__init__(*args, **kwargs) + +def get_enc_setup(sim_mode = 'two_gaussian_peaks', tuning_level = 1, n_neurons = 4): + ''' + sim_mode:str + std: mn 20 neurons + 'toy' # mn 4 neurons + + tuning_level: float + the tuning level at which a particular direction the firng rate is tuned + the higher the better + ''' + + print(f'{__name__}: get_enc_setup has a tuning_level of {tuning_level} \n') + + if sim_mode == 'toy': + #by toy, w mn 4 neurons: + #first 2 ctrl x velo + #lst 2 ctrl y vel + # build a observer matrix + N_NEURONS = 4 + N_STATES = 7 # 3 positions and 3 velocities and an offset + # build the observation matrix + sim_C = np.zeros((N_NEURONS, N_STATES)) + + + # control x positive directions + sim_C[0, :] = np.array([0, 0, 0, tuning_level, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 0, -tuning_level, 0, 0, 0]) + # control z positive directions + sim_C[2, :] = np.array([0, 0, 0, 0, 0, tuning_level, 0]) + sim_C[3, :] = np.array([0, 0, 0, 0, 0, -tuning_level, 0]) + + + elif sim_mode == 'std': + # build a observer matrix + N_NEURONS = 25 + N_STATES = 7 # 3 positions and 3 velocities and an offset + # build the observation matrix + sim_C = np.zeros((N_NEURONS, N_STATES)) + # control x positive directions + sim_C[0, :] = np.array([0, 0, 0, tuning_level, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 0, -tuning_level, 0, 0, 0]) + # control z positive directions + sim_C[2, :] = np.array([0, 0, 0, 0, 0, tuning_level, 0]) + sim_C[3, :] = np.array([0, 0, 0, 0, 0, -tuning_level, 0]) + + elif sim_mode == 'rot_90': + #the directions are along the four axes + N_NEURONS = n_neurons + N_STATES = 7 + sim_C = _get_alternate_encoder_setup_matrix(N_NEURONS, N_STATES, tuning_level) + + elif sim_mode == 'rand': + N_STATES = 7 + sim_C = _get_rand_encoder_matrix(n_neurons, N_STATES, tuning_level) + else: + raise Exception(f'not recognized mode {sim_mode}') + + return (n_neurons, N_STATES, sim_C) + +def _get_alternate_encoder_setup_matrix(N_NEURONS, N_STATES, tuning_level): + from itertools import cycle + axial_angle_iterator = cycle([0, np.pi / 2, np.pi, np.pi * 3 / 2]) + + X_VEL_IND = 3 + Y_VEL_IND = 5 + + sim_C = np.zeros((N_NEURONS, N_STATES)) + x_weights = np.zeros(N_NEURONS) + y_weights = np.zeros(N_NEURONS) + + for i in range(N_NEURONS): + current_angle = next(axial_angle_iterator) + x_weights[i] = np.cos(current_angle) * tuning_level + y_weights[i] = np.sin(current_angle) * tuning_level + + sim_C[:,X_VEL_IND] = x_weights + sim_C[:,Y_VEL_IND] = y_weights + + return sim_C + +def _get_rand_encoder_matrix(n_neurons, N_STATES, tuning_level): + #sample 2 pi: + prefered_angles_in_rad = np.random.uniform(low = 0, high = 2 * np.pi, size = n_neurons) + + sim_C = np.zeros((n_neurons, N_STATES)) + + X_VEL_IND = 3 + Y_VEL_IND = 5 + + #calculate the matrices + sim_C[:,X_VEL_IND] = np.cos(prefered_angles_in_rad) * tuning_level + sim_C[:,Y_VEL_IND] = np.sin(prefered_angles_in_rad) * tuning_level + + return sim_C \ No newline at end of file diff --git a/features/sync_features.py b/features/sync_features.py index 429280df3..e5f846a71 100644 --- a/features/sync_features.py +++ b/features/sync_features.py @@ -1,5 +1,5 @@ from riglib.experiment import traits -from riglib.gpio import NIGPIO, DigitalWave +from riglib.gpio import NIGPIO, DigitalWave, TestGPIO import numpy as np import tables import time @@ -42,7 +42,7 @@ def decode_event(dictionary, value): return ordered_list[-1][0], 0 return None -class NIDAQSync(traits.HasTraits): +class HDFSync(traits.HasTraits): sync_params = dict( sync_protocol = 'rig1', @@ -61,7 +61,7 @@ class NIDAQSync(traits.HasTraits): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.sync_gpio = NIGPIO() + self.sync_gpio = TestGPIO() self.sync_every_cycle = True def init(self, *args, **kwargs): @@ -153,7 +153,18 @@ def cleanup_hdf(self): for param in self.sync_params.keys(): h5file.root.sync_events.attrs[param] = self.sync_params[param] h5file.close() - + +class NIDAQSync(HDFSync): + """ + this class replaces the NIDAQ output with a dummy hardware output class (GPIO) + so all sync information (clock etc, ) is saved to the hdf file. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sync_gpio = NIGPIO() + + + import copy import pygame from built_in_tasks.target_graphics import VirtualRectangularTarget diff --git a/features/task_code_features.py b/features/task_code_features.py new file mode 100644 index 000000000..35076ab6c --- /dev/null +++ b/features/task_code_features.py @@ -0,0 +1,77 @@ +from riglib.dio.NIUSB6501 import write_to_comedi +import time + +''' +TO-DO +this needs abstracton and encapsulation +''' + +class TaskCodeStreamer(object): + + ''' + TaskCodeDict = { + 'wait': 1, + 'target':2, #target appears + 'hold': 15, + 'targ_transition': 6, + 'reward': 0 + } + ''' + #binary state, reward or not + TaskCodeDict = { + 'wait': 0, + 'target':1, #target appears + 'hold': 2, + 'targ_transition': 3, + 'reward': 4, + 'None':255 + } + NONE_CODE = 255 + + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + #clear the output + write_to_comedi(0) + + def set_state(self, condition, **kwargs): + ''' + Extension of riglib.experiment.Experiment.set_state. + + Parameters + ---------- + condition : string + Name of new state. + **kwargs : dict + Passed to 'super' set_state function + + Returns + ------- + None + ''' + if condition in self.TaskCodeDict.keys(): + print(f'transition to {condition} with task code {self.TaskCodeDict[condition]}') + write_to_comedi(self.TaskCodeDict[condition]) + elif condition is None: + print(f'transition to {condition}') + write_to_comedi(self.NONE_CODE) + else: + print(f'transition to {condition}') + + + + super().set_state(condition, **kwargs) + + + +if __name__ == "__main__": + #testing script + #that flashes every 0.4 second + + while 1: + write_to_comedi(0) + time.sleep(0.2) + + write_to_comedi(255) + time.sleep(0.2) + \ No newline at end of file diff --git a/riglib/bmi/bmi.py b/riglib/bmi/bmi.py index 48ccc404c..29e23df10 100644 --- a/riglib/bmi/bmi.py +++ b/riglib/bmi/bmi.py @@ -10,6 +10,7 @@ import time import re import os +from numpy.core.fromnumeric import trace import tables import datetime import copy @@ -559,7 +560,6 @@ def plot_pds(self, C, ax=None, plot_states=['hand_vx', 'hand_vz'], invert=False, if ax == None: plt.figure() ax = plt.subplot(111) - ax.hold(True) if C.shape[1] > 2: state_inds = [self.states.index(x) for x in plot_states] @@ -572,7 +572,12 @@ def plot_pds(self, C, ax=None, plot_states=['hand_vx', 'hand_vz'], invert=False, C = C*-1 for k in range(n_neurons): unit_str = '%d%s' % (self.units[k,0], chr(96 + self.units[k,1])) - ax.plot([0, C[k, x]], [0, C[k, z]], label=unit_str, linestyle=linestyles[k/7 % len(linestyles)], **kwargs) + ax.plot([0, C[k, x]], [0, C[k, z]], + label=unit_str, + linestyle=linestyles[int(k/7) % len(linestyles)], + **kwargs) + + ax.legend(bbox_to_anchor=(1.1, 1.05), prop=dict(size=8)) try: ax.set_xlabel(plot_states[0]) @@ -1325,7 +1330,7 @@ def cleanup_hdf(self): Re-open the HDF file and save any extra task data kept in RAM ''' super(BMILoop, self).cleanup_hdf() - log_file = open(os.path.join(os.getenv("HOME"), 'code/bmi3d/log/clda_log'), 'w') + log_file = open('clda_log.log', 'w') log_file.write(str(self.state) + '\n') try: from . import clda @@ -1339,6 +1344,7 @@ def cleanup_hdf(self): ignore_none=ignore_none) except: import traceback + traceback.print_exc() traceback.print_exc(file=log_file) log_file.close() @@ -1358,7 +1364,7 @@ def write_clda_data_to_hdf_table(hdf_fname, data, ignore_none=False): ------- None ''' - log_file = open(os.path.expandvars('$HOME/code/bmi3d/log/clda_hdf_log'), 'w') + log_file = open(os.path.expandvars('$HOME/code/bmi3d/log/clda_hdf_log'), 'a') compfilt = tables.Filters(complevel=5, complib="zlib", shuffle=True) if len(data) > 0: @@ -1370,6 +1376,7 @@ def write_clda_data_to_hdf_table(hdf_fname, data, ignore_none=False): first_update = data[k] table_col_names = list(first_update.keys()) + print(f'{__name__}: clda table names') print(table_col_names) dtype = [] shapes = [] @@ -1385,8 +1392,8 @@ def write_clda_data_to_hdf_table(hdf_fname, data, ignore_none=False): # Create the HDF table with the datatype above dtype = np.dtype(dtype) - h5file = tables.openFile(hdf_fname, mode='a') - arr = h5file.createTable("/", 'clda', dtype, filters=compfilt) + h5file = tables.open_file(hdf_fname, mode='a') + arr = h5file.create_table("/", 'clda', dtype, filters=compfilt) null_update = np.zeros((1,), dtype=dtype) for col_name in table_col_names: diff --git a/riglib/bmi/clda.py b/riglib/bmi/clda.py index 2b8726a24..19954e58b 100644 --- a/riglib/bmi/clda.py +++ b/riglib/bmi/clda.py @@ -1017,7 +1017,6 @@ def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=Non determine the C_hat and Q_hat of new batch. Then combine with old parameters using step-size rho """ - print("calculating new SB parameters") C_old = decoder.kf.C Q_old = decoder.kf.Q drives_neurons = decoder.drives_neurons @@ -1044,7 +1043,105 @@ def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=Non 'mFR':mFR, 'sdFR':sdFR, 'rho':rho } return new_params +class KFSmoothBatchFullFeature(KFSmoothbatch): + update_kwargs = dict(steady_state=True) + def __init__(self, batch_time, half_life, + number_of_features, number_of_states = 7): + ''' + Constructor for KFSmoothbatch + + Parameters + ---------- + batch_time : float + Time over which to collect sample data + half_life : float + Time over which parameters are half-overwritten + + Return + ------ + KFSmoothbatch instance + ''' + super(KFSmoothbatch, self).__init__(self.calc, multiproc=False) + self.half_life = half_life + self.batch_time = batch_time + self.rho = np.exp(np.log(0.5) / (self.half_life/batch_time)) + + self._full_C = np.zeros((number_of_features, number_of_states)) + self._full_Q = np.zeros((number_of_features, number_of_features)) + self._full_mFR = np.zeros(number_of_features) + self._full_sdFR = np.zeros(number_of_features) + + def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=None, **kwargs): + """ + Smoothbatch calculations + + Run least-squares on (intended_kinematics, spike_counts) to + determine the C_hat and Q_hat of new batch. Then combine with + old parameters using step-size rho + """ + C_old = self._full_C + Q_old = self._full_Q + drives_neurons = decoder.drives_neurons # TODO: check if this has any real infuence + mFR_old = self._full_mFR + sdFR_old = self._full_sdFR + + spike_counts = kwargs['unselected_spike_counts'] + C_hat, Q_hat = kfdecoder.KalmanFilter.MLE_obs_model( + intended_kin, spike_counts, include_offset=False, drives_obs=drives_neurons) + + if not (half_life is None): + rho = np.exp(np.log(0.5)/(half_life/self.batch_time)) + else: + rho = self.rho + + # apply CLDA + C = (1-rho)*C_hat + rho*C_old + Q = (1-rho)*Q_hat + rho*Q_old + + mFR = (1-rho)*np.mean(spike_counts.T, axis=0) + rho*mFR_old + sdFR = (1-rho)*np.std(spike_counts.T, axis=0) + rho*sdFR_old + + self._update_the_matrices(C,Q, mFR, sdFR, **kwargs) + + # select the active neurons + # not all neurons are selected to drive state updates at the same time + if "active_feat_set" in kwargs: + active_feat_set = kwargs['active_feat_set'] + C, Q, mFR, sdFR = self.select_decoder_matrices(active_feat_set, C, Q, mFR, sdFR) + + print(__class__.__name__ + ": " + str(len(active_feat_set)) + " neurons selected") + + D = C.T @ np.linalg.pinv(Q) @ C + + new_params = {'kf.C':C, 'kf.Q':Q, + 'kf.C_xpose_Q_inv_C':D, 'kf.C_xpose_Q_inv':C.T @ np.linalg.pinv(Q), + 'mFR':mFR, 'sdFR':sdFR, 'rho':rho, + 'selected_decoder_features_flag': True } + + return new_params + + def _update_the_matrices(self, C, Q, mFR, sdFR, **kwargs): + + self._full_C = C + self._full_Q = Q + self._full_mFR = mFR + self._full_sdFR = sdFR + + + + def select_decoder_matrices(self, active_feat_set, C, Q, mFR, sdFR): + ''' + Select the decoder matrices for the active features + ''' + # select the active neurons + C_selected = C[active_feat_set, :] + Q_selected = Q[active_feat_set, :][:, active_feat_set] + mFR_selected = mFR[active_feat_set] + sdFR_selected = sdFR[active_feat_set] + + return C_selected, Q_selected, mFR_selected, sdFR_selected + class KFOrthogonalPlantSmoothbatch(KFSmoothbatch): '''This module is deprecated. See KFRML_IVC''' def __init__(self, *args, **kwargs): diff --git a/riglib/bmi/kfdecoder.py b/riglib/bmi/kfdecoder.py index 483b27e51..b8f779b82 100644 --- a/riglib/bmi/kfdecoder.py +++ b/riglib/bmi/kfdecoder.py @@ -8,6 +8,7 @@ from . import bmi import pickle import re +import copy class KalmanFilter(bmi.GaussianStateHMM): """ @@ -50,6 +51,17 @@ def __init__(self, A=None, W=None, C=None, Q=None, is_stochastic=None): self.W = np.mat(W) self.C = np.mat(C) self.Q = np.mat(Q) + self.K = np.nan + self.KC = np.nan + + self.obs_t = np.nan + self.pred_state_mean = np.nan + self.post_state_mean = np.nan + + self.pred_state_P = np.nan + self.post_state_P = np.nan + + if is_stochastic is None: n_states = self.A.shape[0] @@ -129,6 +141,9 @@ def _forward_infer(self, st, obs_t, Bu=None, u=None, x_target=None, F=None, obs_ using_control_input = (Bu is not None) or (u is not None) or (x_target is not None) pred_state = self._ssm_pred(st, target_state=x_target, Bu=Bu, u=u, F=F) + self.pred_state_mean = copy.deepcopy(pred_state.mean) + self.pred_state_P = copy.deepcopy(pred_state.cov) + C, Q = self.C, self.Q P = pred_state.cov @@ -147,6 +162,15 @@ def _forward_infer(self, st, obs_t, Bu=None, u=None, x_target=None, F=None, obs_ post_state.cov = (I - KC) * P + + #save everything + self.obs_t = np.mat(obs_t) + self.KC = np.mat(KC) + self.K = np.mat(K) + + self.post_state_mean = copy.deepcopy(post_state.mean) + self.post_state_P = copy.deepcopy(post_state.cov) + return post_state def set_state_cov(self, n_steps): @@ -345,7 +369,7 @@ def MLE_obs_model(self, hidden_state, obs, include_offset=True, drives_obs=None, # ML estimate of C and Q if regularizer is None: - C = np.mat(np.linalg.lstsq(X.T, Y.T)[0].T) + C = np.mat(np.linalg.lstsq(X.T, Y.T, rcond=None)[0].T) else: x = X.T y = Y.T diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index 6d63a6d18..61860c235 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -170,7 +170,116 @@ def __call__(self, next_state, mode=None): self.call_count += 1 return ts_data +class GenericCosEncWithNoise(GenericCosEnc): + + def __init__(self, C, ssm, noise_mode, noise_profile = None, fixed_noise_rate = 100, + return_ts=False, DT=0.1, call_ds_rate=6): + ''' + Constructor for CosEncWithVariableNoises + + Parameters + ---------- + + + Returns + ------- + + ''' + #exactly the same function signature. + super().__init__(C,ssm, return_ts = return_ts, DT =DT, call_ds_rate=call_ds_rate) + + self._initialize_noise_profile(noise_profile, noise_mode, FIXED_NOISE_RATE = fixed_noise_rate) + + def _initialize_noise_profile(self, noise_profile, noise_mode , FIXED_NOISE_RATE = 100): + + #get the neuron number + self.n_neurons = self.C.shape[0] + + self.noise_mode = noise_mode + self.FIXED_NOISE_RATE = FIXED_NOISE_RATE + + self._select_gen_noise_function() + + if noise_profile is None: + self.noise_profile = np.zeros(n_neurons) + else: + #make sure the same number of neurons + assert (self.n_neurons,1) == noise_profile.shape + self.noise_profile = noise_profile + + def _select_gen_noise_function(self): + + self.noise_models_dict = { + 'fixed_poisson': self._gen_fixed_poisson_noise, + 'relative_poisson':self._gen_relative_poisson_noise, + 'fixed_gaussian':self._gen_fixed_gaussian_noise, + 'relative_gaussian':self._gen_relative_gaussian_noise + } + + if self.noise_mode not in self.noise_models_dict.keys(): + raise Exception(f'unsupported noise model {self.noise_mode}, and available modes are {self.noise_models_dict.keys()}') + + self._gen_noise = self.noise_models_dict[self.noise_mode] + + print(self._gen_noise) + + + def return_spikes(self, rates, mode=None): + + counts = self._generate_poisson_counts(rates) + counts = counts + self._gen_noise(counts) + counts[counts < 0] = 0 #floor the counts to zero + + if np.logical_or(mode=='ts', np.logical_and(mode is None, self.return_ts)): + return self.gen_time_stamped_spikes(counts, mode = mode) + + elif np.logical_or(mode=='counts', np.logical_and(mode is None, self.return_ts is False)): + return counts + + def _generate_poisson_counts(self, rates): + # Floor firing rates at 0 Hz + rates[rates < 0] = 0 + + return np.random.poisson(rates * self.DT) + + def _gen_fixed_poisson_noise(self, counts): + ''' + counts not used, for compatibility + ''' + return np.random.poisson(self.noise_profile * self.FIXED_NOISE_RATE) + + def _gen_relative_poisson_noise(self, counts): + ''' + relative to the counts + ''' + return np.random.poisson(self.noise_profile * counts) + + def _gen_fixed_gaussian_noise(self, counts): + return np.random.standard_normal(counts.shape) * self.noise_profile * self.FIXED_NOISE_RATE + + def _gen_relative_gaussian_noise(self, counts): + return np.random.standard_normal(counts.shape) * self.noise_profile * counts + + def gen_time_stamped_spikes(self, counts, mode = None): + + ts = [] + n_neurons = self.n_neurons + + for k, ind in enumerate(self.unit_inds): + + # separate spike counts into individual time-stamps + n_spikes = int(counts[k]) + fake_time = (self.call_count + 0.5)* 1./60 + if n_spikes > 0: + + spike_data = [(fake_time, ind, 1) for m in range(n_spikes)] + ts += (spike_data) + + ts = np.array(ts, dtype=ts_dtype) + return ts + + class FACosEnc(GenericCosEnc): ''' Simulate neurons where rate is linear function of underlying factor modulation, rate param through Poisson diff --git a/riglib/dio/NIUSB6501/control_comedi b/riglib/dio/NIUSB6501/control_comedi index 2f58df990..6345e644f 100755 Binary files a/riglib/dio/NIUSB6501/control_comedi and b/riglib/dio/NIUSB6501/control_comedi differ diff --git a/riglib/dio/NIUSB6501/control_comedi_swig.i b/riglib/dio/NIUSB6501/control_comedi_swig.i index d6f66b3c2..36e488ed5 100644 --- a/riglib/dio/NIUSB6501/control_comedi_swig.i +++ b/riglib/dio/NIUSB6501/control_comedi_swig.i @@ -5,4 +5,4 @@ extern int set_bits_in_nidaq_using_mask_and_data(int mask, int data, int base_ch %} extern unsigned char comedi_init(char* dev); -extern int set_bits_in_nidaq_using_mask_and_data(int mask, int data, int base_channel); \ No newline at end of file +extern int set_bits_in_nidaq_using_mask_and_data(int mask, int data, int base_channel); diff --git a/riglib/dio/nidaq/pcidio.py b/riglib/dio/nidaq/pcidio.py new file mode 100644 index 000000000..4260f55d2 --- /dev/null +++ b/riglib/dio/nidaq/pcidio.py @@ -0,0 +1,127 @@ +# This file was automatically generated by SWIG (http://www.swig.org). +# Version 3.0.8 +# +# Do not make changes to this file unless you know what you are doing--modify +# the SWIG interface file instead. + + + + + +from sys import version_info +if version_info >= (2, 6, 0): + def swig_import_helper(): + from os.path import dirname + import imp + fp = None + try: + fp, pathname, description = imp.find_module('_pcidio', [dirname(__file__)]) + except ImportError: + import _pcidio + return _pcidio + if fp is not None: + try: + _mod = imp.load_module('_pcidio', fp, pathname, description) + finally: + fp.close() + return _mod + _pcidio = swig_import_helper() + del swig_import_helper +else: + import _pcidio +del version_info +try: + _swig_property = property +except NameError: + pass # Python < 2.2 doesn't have 'property'. + + +def _swig_setattr_nondynamic(self, class_type, name, value, static=1): + if (name == "thisown"): + return self.this.own(value) + if (name == "this"): + if type(value).__name__ == 'SwigPyObject': + self.__dict__[name] = value + return + method = class_type.__swig_setmethods__.get(name, None) + if method: + return method(self, value) + if (not static): + if _newclass: + object.__setattr__(self, name, value) + else: + self.__dict__[name] = value + else: + raise AttributeError("You cannot add attributes to %s" % self) + + +def _swig_setattr(self, class_type, name, value): + return _swig_setattr_nondynamic(self, class_type, name, value, 0) + + +def _swig_getattr_nondynamic(self, class_type, name, static=1): + if (name == "thisown"): + return self.this.own() + method = class_type.__swig_getmethods__.get(name, None) + if method: + return method(self) + if (not static): + return object.__getattr__(self, name) + else: + raise AttributeError(name) + +def _swig_getattr(self, class_type, name): + return _swig_getattr_nondynamic(self, class_type, name, 0) + + +def _swig_repr(self): + try: + strthis = "proxy of " + self.this.__repr__() + except Exception: + strthis = "" + return "<%s.%s; %s >" % (self.__class__.__module__, self.__class__.__name__, strthis,) + +try: + _object = object + _newclass = 1 +except AttributeError: + class _object: + pass + _newclass = 0 + + + +def init(dev): + return _pcidio.init(dev) +init = _pcidio.init + +def closeall(): + return _pcidio.closeall() +closeall = _pcidio.closeall + +def sendMsg(msg): + return _pcidio.sendMsg(msg) +sendMsg = _pcidio.sendMsg + +def register_sys(name, dtype): + return _pcidio.register_sys(name, dtype) +register_sys = _pcidio.register_sys + +def sendData(idx, data): + return _pcidio.sendData(idx, data) +sendData = _pcidio.sendData + +def sendRowCount(idx): + return _pcidio.sendRowCount(idx) +sendRowCount = _pcidio.sendRowCount + +def sendRowByte(idx): + return _pcidio.sendRowByte(idx) +sendRowByte = _pcidio.sendRowByte + +def rstart(start): + return _pcidio.rstart(start) +rstart = _pcidio.rstart +# This file is compatible with both classic and new-style classes. + + diff --git a/riglib/ecube/digital_stream_methods.py b/riglib/ecube/digital_stream_methods.py new file mode 100644 index 000000000..26b70cf6a --- /dev/null +++ b/riglib/ecube/digital_stream_methods.py @@ -0,0 +1,76 @@ +import numpy as np +import pyeCubeStream + +def ts_values(data, srate): + """Finds the timestamp and corresponding value + of all the bit flips in data + author Leo + """ + logical_idx = np.insert(np.diff(data) != 0, 0, True) + time = np.arange(np.size(data))/srate + return time[logical_idx], data[logical_idx] + + +def ffs(x): + """Returns the index, counting from 0, of the + least significant set bit in `x`. + author leo + """ + return (x & -x).bit_length() - 1 + +def mask_and_shift(data, bit_mask): + """Apply bit mask and shift to the least + significant set bit + author leo: + """ + return np.bitwise_and(data, np.uint64(bit_mask)) >> np.uint64(ffs(bit_mask)) + +def test_mask_and_shift(): + return + +if __name__ == "__main__": + + ecubeDigital = pyeCubeStream.eCubeStream('DigitalPanel') + + ecubeDigital.start() + sample = ecubeDigital.get() + last_sample = sample[1][-1] + print(f'last digital value of all channels: {last_sample}') + + long_data = np.squeeze(np.uint64(last_sample)) + print(f'after some conversion: {long_data}') + + binary_rep = np.binary_repr(long_data) + print(f'binary representation:') + print(binary_rep + '\n') + + bits = list(binary_rep) + print(''.join(bits[-8:])) + print(''.join(bits[-16:-8])) + print(''.join(bits[-24:-16])) + print(''.join(bits[-32:-24])) + print(''.join(bits[-40:-32])) + print(''.join(bits[-48:-40])) + print(''.join(bits[-56:-48])) + print(''.join(bits[-64:-56])) + + masks = [0xff, + 0xff00, + 0xff0000, + 0xff000000, + 0xff00000000, + 0xff0000000000, + 0xff000000000000, + 0xff00000000000000] + + #code_num_1 = mask_and_shift(long_data,0xff) + + + print(f'\nprint the channels of in groups of 8') + print('MSB <- LSB') + for m in masks: + code_num_1 = mask_and_shift(long_data,m) + print(f'{np.binary_repr(code_num_1)}') + + #ts, values = ts_values(long_data, dat.samplerate) + diff --git a/riglib/experiment/experiment.py b/riglib/experiment/experiment.py index a0ef078f9..f030d335c 100644 --- a/riglib/experiment/experiment.py +++ b/riglib/experiment/experiment.py @@ -797,7 +797,8 @@ def _start_wait(self): new information needed to start the trial. If the generator runs out, the task stops. ''' if self.debug: - print("_start_wait") + # print("_start_wait") + pass try: self.next_trial = next(self.gen) diff --git a/riglib/hdfwriter/hdfwriter.py b/riglib/hdfwriter/hdfwriter.py index 01e6dd4f7..0af9c4ee6 100644 --- a/riglib/hdfwriter/hdfwriter.py +++ b/riglib/hdfwriter/hdfwriter.py @@ -37,6 +37,7 @@ def __init__(self, filename): self.data = {} self.msgs = {} self.f = [] + self.h5_file_name = filename def register(self, name, dtype, include_msgs=True): ''' @@ -134,5 +135,5 @@ def close(self): Close the HDF file so that it saves properly after the process terminates ''' self.h5.close() - print("Closed hdf") + print(f"Closed hdf with filename: {self.h5_file_name}") diff --git a/riglib/mp_proxy.py b/riglib/mp_proxy.py index 406e8efd0..a2580b529 100644 --- a/riglib/mp_proxy.py +++ b/riglib/mp_proxy.py @@ -10,9 +10,9 @@ class PipeWrapper(object): - def __init__(self, pipe=None, log_filename='', cmd_event=None, **kwargs): + def __init__(self, pipe=None, log_filename='hdf_sink.log', cmd_event=None, **kwargs): self.pipe = pipe - self.log_filename = log_filename + self.log_filename = "hdf_sink.log" self.cmd_event = cmd_event def log_error(self, err, mode='a'): @@ -165,7 +165,7 @@ def __init__(self, target_class=object, target_kwargs=dict(), log_filename='', * super().__init__() self.cmd_pipe = None self.data_pipe = None - self.log_filename = log_filename + self.log_filename = "hdf_log.log" self.target = None self.target_class = target_class diff --git a/riglib/optitrack_client/optitrack_interface_sijia.py b/riglib/optitrack_client/optitrack_interface_sijia.py index b00934cad..d82515247 100644 --- a/riglib/optitrack_client/optitrack_interface_sijia.py +++ b/riglib/optitrack_client/optitrack_interface_sijia.py @@ -20,6 +20,7 @@ class System(object): #a list of supported commands SUPPORTED_COMMANDS = [ "start_rec", + "stop_rec", "send_markers", 'send_rigid_bodies', 'stop', @@ -42,7 +43,7 @@ def __init__(self): self.rigid_body_count = 1 #for now,only one rigid body - def start(self, stream_type = "rb"): + def start(self, stream_type = "rigid_bodies"): """ stream_type """ @@ -70,20 +71,21 @@ def start(self, stream_type = "rb"): print(f"Connection to c# client \ {self.optitrack_ip_addr} has been established.") - if stream_type in self.SUPPORTED_STREAM_TYPES: + self.stream_type = stream_type self.send_command('send_'+stream_type) + print(f'stream_type set to {self.stream_type}') else: raise Exception(f'{stream_type} is not supported \n\n suported stream types are\n{self.SUPPORTED_STREAM_TYPES}') + - - + def sample_stream_data(self, n_grab_frames = 10): #automatically pull 10 frames # and cal the mean round trip time t1 = time.perf_counter() - for i in range(N_TEST_FRAMES): self.get() + for i in range(n_grab_frames): self.get() t2 = time.perf_counter() - print(f'time to grab {N_TEST_FRAMES} frames : \ + print(f'time to grab {n_grab_frames} frames : \ {(t2 - t1)} s ') @@ -154,9 +156,18 @@ def get(self): if __name__ == "__main__": s = System() - s.start() - s.send_command("start_rec") - #s.send_command("send_markers") - time.sleep(5) + s.start(stream_type = "rigid_bodies") + + s.sample_stream_data(n_grab_frames = 3) + time.sleep(2) + s.stop() + time.sleep(2) + + #start a new ssystem that streams markerset + s = System() + s.start(stream_type = "markers") + s.sample_stream_data(n_grab_frames = 3) + time.sleep(2) s.stop() + print("finished")