diff --git a/built_in_tasks/manualcontrolmultitasks.py b/built_in_tasks/manualcontrolmultitasks.py index 92ed268e..d2478745 100644 --- a/built_in_tasks/manualcontrolmultitasks.py +++ b/built_in_tasks/manualcontrolmultitasks.py @@ -45,6 +45,7 @@ def __init__(self, *args, **kwargs): def init(self): self.add_dtype('manual_input', 'f8', (3,)) + self.add_dtype('user_screen', 'f8', (3,)) super().init() self.no_data_counter = np.zeros((self._quality_window_size,), dtype='?') @@ -121,6 +122,7 @@ def move_effector(self, pos_offset=[0,0,0], vel_offset=[0,0,0]): self.no_data_counter[self.cycle_count % self._quality_window_size] = 1 self.update_report_stats() self.task_data['manual_input'] = np.ones((3,))*np.nan + self.task_data['user_screen'] = np.ones((3,))*np.nan return self.task_data['manual_input'] = raw_coords.copy() @@ -128,6 +130,7 @@ def move_effector(self, pos_offset=[0,0,0], vel_offset=[0,0,0]): # Transform coordinates coords = self._transform_coords(raw_coords) + self.task_data['user_screen'] = coords try: if self.limit2d: diff --git a/built_in_tasks/target_graphics.py b/built_in_tasks/target_graphics.py index 3a78e496..1fca55da 100644 --- a/built_in_tasks/target_graphics.py +++ b/built_in_tasks/target_graphics.py @@ -2,7 +2,7 @@ Base tasks for generic point-to-point reaching ''' import numpy as np -from riglib.stereo_opengl.primitives import Cable, Sphere, Cube, Torus, Text +from riglib.stereo_opengl.primitives import Cable, Snake, Sphere, Cube, Torus, Text from riglib.stereo_opengl.primitives import Cylinder, Plane, Sphere, Cube from riglib.stereo_opengl.models import FlatMesh, Group from riglib.stereo_opengl.textures import Texture, TexModel @@ -36,6 +36,8 @@ "gold": (0.941,0.637,0.25,0.75), "elephant":(0.5,0.5,0.5,0.5), "white": (1, 1, 1, 0.75), + "black": (0, 0, 0, 0.75), + "invisible": (0, 0, 0, 0.0), } class CircularTarget(object): @@ -186,7 +188,7 @@ def __init__(self, target_radius=1, target_color=(1, 0, 0, .5), starting_pos=np. self._pickle_init() def _pickle_init(self): - self.cable = Cable(radius=self.target_radius,trajectory = self.trajectory, color=self.target_color) + self.cable = Cable(radius=self.target_radius, xyz=self.trajectory, color=self.target_color) self.graphics_models = [self.cable] self.cable.translate(*self.position) @@ -238,6 +240,20 @@ def reset(self): def get_position(self): return self.cable.xfm.move +class VirtualSnakeTarget(VirtualCableTarget): + + def _pickle_init(self): + self.trajectory = np.array(self.trajectory) + self.cable = Snake(radius=self.target_radius, trajectory=self.trajectory, color=self.target_color) + self.graphics_models = [self.cable] + self.cable.translate(*self.position) + + def update_mask(self, start_frame, end_frame, inverse=False): + ''' + Update the texture mask of the snake target. + ''' + self.cable.update_texture(start_frame, end_frame, inverse=inverse) + class VirtualTorusTarget(VirtualCircularTarget): def __init__(self, inner_radius=2, outer_radius=3, target_color=(1, 0, 0, .5), starting_pos=np.zeros(3)): diff --git a/built_in_tasks/target_tracking_task.py b/built_in_tasks/target_tracking_task.py index a67f11cb..8994f61e 100644 --- a/built_in_tasks/target_tracking_task.py +++ b/built_in_tasks/target_tracking_task.py @@ -69,7 +69,7 @@ def init(self): self.trial_dtype = np.dtype([('trial', 'u4'), ('index', 'u4'), ('target', 'f8',(3,)), ('disturbance', 'f8',(3,)), ('is_disturbance', '?')]) super().init() - self.frame_index = 0 # index in the frame in a trial + self.frame_index = -1 # index in the frame in a trial self.total_distance_error = 0 # Euclidian distance between cursor and target during each trial self.trial_timed_out = True # check if the trial is finished self.plant_position = [] @@ -92,9 +92,6 @@ def _parse_next_trial(self): self.targs = np.squeeze(self.targs,axis=0) self.disturbance_path = np.squeeze(self.disturbance_path) - WIDTH, HEIGHT = self.window_size[0], self.window_size[1] - lookahead = np.zeros((self.lookahead,np.shape(self.targs)[1])) # (30,3) - self.targs = self.trajectory_amplitude*self.targs self.disturbance_path = self.disturbance_amplitude*self.disturbance_path # print(np.amax(self.targs), np.amax(self.disturbance_path)) @@ -104,18 +101,16 @@ def _parse_next_trial(self): self.ramp_counter[:int(self.ramp_up_time*self.sample_rate)] = 1 if self.ramp_down_time > 0: self.ramp_counter[-int(self.ramp_down_time*self.sample_rate):] = 2 - - self.targs = np.concatenate((lookahead, self.targs),axis=0) # (time_length*sample_rate+30,3) # targs and disturbance are no longer same length def tracking_task_start_wait(self): - print(self.gen_index) + # print(self.gen_index) self.trial_record['trial'] = self.calc_trial_num() self.trial_record['index'] = self.gen_index self.trial_record['is_disturbance'] = self.disturbance_trial for i in range(len(self.disturbance_path)): # Update the data sinks with trial information --> bmi3d_trials - self.trial_record['target'] = self.targs[i+self.lookahead] + self.trial_record['target'] = self.targs[i] self.trial_record['disturbance'] = self.disturbance_path[i] self.sinks.send("trials", self.trial_record) @@ -330,19 +325,19 @@ def _test_hold_complete_no_ramp(self, time_in_state): def _test_ramp_complete(self, time_in_state): '''Test whether the ramp up is finished''' - return self.frame_index == self.ramp_up_time*self.sample_rate + return self.frame_index-1 == self.ramp_up_time*self.sample_rate def _test_traj_complete(self, time_in_state): '''Test whether the trajectory is finished and whether there is a ramp down before the trial ends''' - return (self.frame_index + self.lookahead == self.trajectory_length - self.ramp_down_time*self.sample_rate) and (self.ramp_down_time > 0) + return (self.frame_index-1 == self.trajectory_length - self.ramp_down_time*self.sample_rate) and (self.ramp_down_time > 0) def _test_ramp_and_trial_complete(self, time_in_state): '''Test whether the ramp down is finished, ending the trial''' - return (self.frame_index + self.lookahead == self.trajectory_length) and (self.ramp_down_time > 0) + return (self.frame_index == self.trajectory_length) and (self.ramp_down_time > 0) def _test_trial_complete(self, time_in_state): '''Test whether the trajectory is finished, ending the trial''' - return (self.frame_index + self.lookahead == self.trajectory_length) and (self.ramp_down_time == 0) + return (self.frame_index == self.trajectory_length) and (self.ramp_down_time == 0) def _test_tracking_out_timeout(self, time_in_state): return time_in_state > self.tracking_out_time @@ -405,6 +400,8 @@ class ScreenTargetTracking(TargetTracking, Window): target_radius = traits.Float(2, desc="Radius of targets in cm") #2,0.75 trajectory_radius = traits.Float(.5, desc="Radius of targets in cm") trajectory_color = traits.OptionsList("gold", *target_colors, desc="Color of the trajectory", bmi3d_input_options=list(target_colors.keys())) + trajectory_type = traits.OptionsList("1d", ["1d", "2d", "none"], desc="Type of trajectory to use", bmi3d_input_options=["1d", "2d", "none"]) + lookahead_time = traits.Float(0.5, desc="Time in seconds to display the future trajectory") target_color = traits.OptionsList("yellow", *target_colors, desc="Color of the target", bmi3d_input_options=list(target_colors.keys())) plant_hide_rate = traits.Float(0.0, desc='If the plant is visible, specifies a percentage of trials where it will be hidden') plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=list(plantlist.keys())) @@ -429,7 +426,8 @@ def __init__(self, *args, **kwargs): self.plant.set_cursor_radius(self.cursor_radius) self.plant_vis_prev = True self.cursor_vis_prev = True - self.lookahead = 30 # number of frames to create a "lookahead" window of 0.5 seconds (half the screen) + self.lookahead = int(self.fps * self.lookahead_time) # convert to frames + self.lookahead_scale = (0.5 * self.screen_cm[0]) / self.lookahead # cm per frame self.original_limit1d = self.limit1d # keep track of original settable trait if not self.always_1d: @@ -445,17 +443,14 @@ def __init__(self, *args, **kwargs): if instantiate_targets: # This is the center target being followed by the user self.target = VirtualCircularTarget(target_radius=self.target_radius, target_color=target_colors[self.target_color]) - - # This is the trajectory that spans the screen - self.trajectory = VirtualCableTarget(target_radius=self.trajectory_radius, target_color=target_colors[self.trajectory_color]) + for model in self.target.graphics_models: + self.add_model(model) # This is the optional progress bar (off by default) self.bar = VirtualRectangularTarget(target_width=1, target_height=0, target_color=(0., 1., 0., 0.75), starting_pos=[0,-15,9]) # print('INIT TRAJ') - - # This is a black cube that optionally hides the "lookbehind" of trajectory (off by default) - self.box = VirtualRectangularTarget(target_width=20, target_height=10, target_color=(0, 0, 0, 1), starting_pos=[-10,-1,0]) - # target_width of RectangularTarget is total height, target_height is 1/2 of total width (from center to edge) + for model in self.bar.graphics_models: + self.add_model(model) # Declare any plant attributes which must be saved to the HDF file at the _cycle rate for attr in self.plant.hdf_attrs: @@ -465,9 +460,8 @@ def init(self): self.add_dtype('trial', 'u4', (1,)) self.add_dtype('gen_idx', 'int', (1,)) # dtype needs to be able to represent -1 self.add_dtype('plant_visible', '?', (1,)) - self.add_dtype('current_target', 'f8', (3,)) - self.add_dtype('current_disturbance', 'f8', (3,)) # see task_data['manual_input'] for cursor position without added disturbance - self.add_dtype('current_target_validate', 'f8', (3,)) + self.add_dtype('target', 'f8', (3,)) + self.add_dtype('disturbance', 'f8', (3,)) super().init() self.plant.set_endpoint_pos(np.array(self.starting_pos)) @@ -475,6 +469,10 @@ def _cycle(self): ''' Calls any update functions necessary and redraws screen ''' + # if self.frame_index >= 0: + # print('FRAME ', self.frame_index, self.get_state(), self.trial_timed_out) + # print(self.target.get_position()[2], self.pos_offset[2]) + self.move_effector(pos_offset=np.asarray(self.pos_offset), vel_offset=np.asarray(self.vel_offset)) # Run graphics commands to show/hide the plant if the visibility has changed @@ -489,17 +487,10 @@ def _cycle(self): # Update the trial index self.task_data['trial'] = self.calc_trial_num() self.task_data['gen_idx'] = self.gen_index - # print(self.task_data['gen_idx']) - # Save the target position at each cycle. - if self.trial_timed_out: - self.task_data['current_target'] = [np.nan,np.nan,np.nan] - self.task_data['current_disturbance'] = [np.nan,np.nan,np.nan] - self.task_data['current_target_validate'] = self.target.get_position() # default VirtualCircularTarget position is [0,0,0] - else: - self.task_data['current_target'] = self.targs[self.frame_index+self.lookahead] - self.task_data['current_disturbance'] = self.disturbance_path[self.frame_index] - self.task_data['current_target_validate'] = self.target.get_position() + # Save the target position and disturbance value at each cycle + self.task_data['target'] = self.target.get_position() + self.task_data['disturbance'] = self.pos_offset super()._cycle() @@ -528,30 +519,19 @@ def update_plant_visibility(self): self.plant.set_visibility(self.plant_visible) def update_frame(self): - self.target.move_to_position(np.array([0,0,self.targs[self.frame_index+self.lookahead][2]])) # xzy - self.trajectory.move_to_position(np.array([-self.frame_index/3,10,0])) # same update constant works for 60 and 120 hz - self.target.show() - self.trajectory.show() - self.frame_index +=1 - - def setup_start_wait(self): - if self.calc_trial_num() == 0: - # Instantiate the targets here so they don't show up in any states that might come before "wait" - for model in self.target.graphics_models: - self.add_model(model) - self.target.hide() - - for model in self.trajectory.graphics_models: - self.add_model(model) - self.trajectory.hide() + if self.trial_timed_out: + use_frame_index = self.frame_index - 1 + else: + use_frame_index = self.frame_index - for model in self.bar.graphics_models: - self.add_model(model) - self.bar.hide() + self.target.move_to_position(self.targs[use_frame_index]) + if self.trajectory_type == '1d': + self.trajectory.move_to_position(np.array([-use_frame_index*self.lookahead_scale - self.lookahead*self.lookahead_scale,0,0])) + elif self.trajectory_type == '2d': + self.trajectory.update_mask(use_frame_index, use_frame_index+self.lookahead) + self.frame_index += 1 # increment the frame_index for the following cycle - for model in self.box.graphics_models: - self.add_model(model) - self.box.hide() + def setup_start_wait(self): # Allow 2d movement if not self.always_1d: @@ -562,18 +542,30 @@ def setup_start_wait(self): self.tracking_frame_index = 0 # Set up the next trajectory - next_trajectory = np.array(np.squeeze(self.targs)[:,2]) - next_trajectory[:self.lookahead] = next_trajectory[self.lookahead] - if hasattr(self, 'trajectory'): for model in self.trajectory.graphics_models: self.remove_model(model) del self.trajectory - - self.trajectory = VirtualCableTarget(target_radius=self.trajectory_radius, target_color=target_colors[self.trajectory_color], trajectory=next_trajectory) + if self.trajectory_type == '1d': + + lookbehind = self.targs[1,:]*np.ones((self.lookahead, np.shape(self.targs)[1])) + next_trajectory = np.concatenate((lookbehind, self.targs), axis=0) + next_trajectory = np.array(np.squeeze(next_trajectory)[:,2]) + next_trajectory = np.vstack([ + self.lookahead_scale * np.arange(len(next_trajectory)), # set the lookahead by scaling the trajectory to fit in the screen + np.zeros(len(next_trajectory)), + next_trajectory + ]).T + self.trajectory = VirtualSnakeTarget(target_radius=self.trajectory_radius, target_color=target_colors[self.trajectory_color], trajectory=next_trajectory) + elif self.trajectory_type == '2d': + self.trajectory = VirtualSnakeTarget(target_radius=self.trajectory_radius, target_color=target_colors[self.trajectory_color], trajectory=self.targs) + self.trajectory.update_mask(self.frame_index, self.frame_index+self.lookahead) + else: # 'none' + next_trajectory = np.zeros((self.lookahead, 3)) + self.trajectory = VirtualCircularTarget() for model in self.trajectory.graphics_models: - self.add_model(model) + self.add_model(model) self.target.hide() self.trajectory.hide() @@ -586,8 +578,6 @@ def setup_screen_reset(self): self.trajectory.reset() self.bar.hide() self.bar.reset() - self.box.hide() - self.box.reset() def setup_start_tracking_in(self): # Revert to settable trait @@ -595,28 +585,45 @@ def setup_start_tracking_in(self): # Cue successful tracking self.target.cue_trial_end_success() + if self.frame_index == 0: + # Add disturbance + if self.disturbance_trial == True: + cursor_pos = self.plant.get_endpoint_pos() + if self.velocity_control: + # TODO check manualcontrolmixin for how to implement velocity control + self.vel_offset = (cursor_pos + self.disturbance_path[self.frame_index])*1/self.fps + else: + # position control + self.pos_offset = self.disturbance_path[self.frame_index] + + # Move target and trajectory to next frame so it appears to be moving + self.update_frame() + def setup_start_tracking_out(self): # Reset target color self.target.reset() def setup_while_tracking(self): - # Add disturbance - cursor_pos = self.plant.get_endpoint_pos() - if self.disturbance_trial == True: - if self.velocity_control: - # TODO check manualcontrolmixin for how to implement velocity control - self.vel_offset = (cursor_pos + self.disturbance_path[self.frame_index])*1/self.fps - else: - # position control - self.pos_offset = self.disturbance_path[self.frame_index] + # Check whether there are no more target frames to display + if self.frame_index + self.lookahead == self.trajectory_length: + self.trial_timed_out = True + self.pos_offset = [0,0,0] + self.vel_offset = [0,0,0] + + else: + # Add disturbance + if self.disturbance_trial == True: + cursor_pos = self.plant.get_endpoint_pos() + if self.velocity_control: + # TODO check manualcontrolmixin for how to implement velocity control + self.vel_offset = (cursor_pos + self.disturbance_path[self.frame_index])*1/self.fps + else: + # position control + self.pos_offset = self.disturbance_path[self.frame_index] # Move target and trajectory to next frame so it appears to be moving self.update_frame() - # Check if the trial is over and there are no more target frames to display - if self.frame_index+self.lookahead >= np.shape(self.targs)[0]: - self.trial_timed_out = True - #### TEST FUNCTIONS #### def _test_enter_target(self, time_in_state): ''' @@ -654,15 +661,17 @@ def _while_wait_retry(self): def _start_trajectory(self): super()._start_trajectory() if self.frame_index == 0: - self.target.move_to_position(np.array([0,0,self.targs[self.frame_index+self.lookahead][2]])) # tablet screen x-axis ranges -19,19, center 0 - self.trajectory.move_to_position(np.array([0,10,0])) # tablet screen x-axis ranges 0,41.33333, center 22ish - # print(self.target.get_position()) - # print(self.trajectory.get_position()) + self.target.move_to_position(self.targs[self.frame_index]) + if self.trajectory_type == '1d': + self.trajectory.move_to_position(np.array([-self.lookahead*self.lookahead_scale,0,0])) self.target.show() - self.trajectory.show() + if self.trajectory_type != "none": + self.trajectory.show() # print('SHOW TRAJ') self.sync_event('TARGET_ON') + else: + print('WARNING: trajectory state started with frame_index != 0', self.frame_index) def _while_trajectory(self): super()._while_trajectory() @@ -680,8 +689,8 @@ def _while_hold(self): def _start_tracking_in_ramp(self): super()._start_tracking_in_ramp() self.setup_start_tracking_in() - # print('START TRACKING RAMP', self.ramp_counter[self.frame_index]) - self.sync_event('CURSOR_ENTER_TARGET', self.ramp_counter[self.frame_index]) + # print('START TRACKING IN RAMP', self.ramp_counter[self.frame_index-1]) + self.sync_event('CURSOR_ENTER_TARGET', self.ramp_counter[self.frame_index-1]) def _while_tracking_in_ramp(self): super()._while_tracking_in_ramp() @@ -700,8 +709,8 @@ def _while_tracking_in(self): def _start_tracking_out_ramp(self): super()._start_tracking_out_ramp() self.setup_start_tracking_out() - # print('STOP TRACKING RAMP', self.ramp_counter[self.frame_index]) - self.sync_event('CURSOR_LEAVE_TARGET', self.ramp_counter[self.frame_index]) + # print('START TRACKING OUT RAMP', self.ramp_counter[self.frame_index-1]) + self.sync_event('CURSOR_LEAVE_TARGET', self.ramp_counter[self.frame_index-1]) def _while_tracking_out_ramp(self): super()._while_tracking_out_ramp() @@ -936,7 +945,118 @@ def generate_trajectories(num_trials=2, time_length=20, seed=40, sample_rate=120 return trials, trial_order @staticmethod - def generate_trajectory(primes, base_period, ramp = .0): + def generate_2D_trajectories(num_trials=2, time_length=20, seed=40, sample_rate=120, base_period=20, ramp=0, ramp_down=0, num_primes=8, use_disturb = True, decay_rate = None): + ''' + Sets up variables and uses prime numbers to call the above functions and generate trajectories in both x & y + ramp is time length for preparatory lines + ''' + np.random.seed(seed) + hz = sample_rate # Hz -- sampling rate + dt = 1/hz # sec -- sampling period + + T0 = base_period # sec -- base period + w0 = 1./T0 # Hz -- base frequency + + r = ramp # "ramp up" duration (see sum_of_sines_ramp) + rd = ramp_down # "ramp down" duration (see sum_of_sines_ramp) + P = time_length/T0 # number of periods in signal + T = P*T0+r+rd # sec -- signal duration + dw = 1./T # Hz -- frequency resolution + W = 1./dt/2 # Hz -- signal bandwidth + + full_primes = np.asarray([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, + 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199]) + + xprimes = full_primes[:num_primes:2] #even elements of full primes + yprimes = full_primes[1:num_primes:2] #odd elements of full primes + + f_x = xprimes*w0 # stimulated frequencies + f_y = yprimes*w0 + + if decay_rate is None: #use linear decay rate to generate amplitudes if no decay rate is specified + a_x = 1/(1+np.arange(f_x.size)) # amplitude linear/reciprocal + a_y = 1/(1+np.arange(f_y.size)) + + if decay_rate is not None: #if the decay rate is a value then use exponential decay to generate amplitiudes + k = decay_rate + a_x = np.exp(-k*np.arange(f_x.size)) # amplitude exponential decay + a_y = np.exp(-k*np.arange(f_y.size)) + + # phase offset + o_x = np.random.rand(num_trials, f_x.size) + o_xdis = o_x*0.8 + + o_y = np.random.rand(num_trials, f_y.size) + o_ydis = o_y*0.8 + + t = np.arange(0,T,dt) # time samplesseed + w = np.arange(0,W,dw) # frequency samples + + N = t.size # = T/dt -- number of samples + + # create trials dictionary with x & y + trials = dict( + id=np.arange(num_trials), + times=np.tile(t,(num_trials,1)), + ref_x=np.zeros((num_trials,N)), + dis_x=np.zeros((num_trials,N)), + ref_y=np.zeros((num_trials,N)), + dis_y=np.zeros((num_trials,N)) + ) + + # randomize order of first two trials to provide random starting point + order = np.random.choice([0,1]) + if order == 0: + trial_order = [(1,'E','O'),(1,'O','E')] + elif order == 1: + trial_order = [(1,'O','E'),(1,'E','O')] + + # generate reference and disturbance trajectories for all trials + for trial_id, (num_reps,ref_ind,dis_ind) in enumerate(trial_order*int(num_trials/2)): + if ref_ind == 'E': + sines_r = np.arange(len(xprimes))[0::2] # use even indices + + elif ref_ind == 'O': + sines_r = np.arange(len(xprimes))[1::2] # use odd indices + + else: + sines_r = np.arange(len(xprimes)) + if dis_ind == 'E': + sines_d = np.arange(len(xprimes))[0::2] + + elif dis_ind == 'O': + sines_d = np.arange(len(xprimes))[1::2] + + else: + sines_d = np.arange(len(xprimes)) #every element in vector + + if use_disturb == False: #use both odd and even indices for reference trajectory if distrubance is turned off. + # generate X-dimension + refx_traj, ref_Ax = ScreenTargetTracking.calc_sum_of_sines_ramp(t, r, rd, f_x[np.arange(len(xprimes))], a_x[np.arange(len(xprimes))], o_x[trial_id][np.arange(len(xprimes))]) + disx_traj, dis_Ax = ScreenTargetTracking.calc_sum_of_sines_ramp(t,r,rd, f_x[sines_d], a_x[sines_d], o_xdis[trial_id][sines_d]) + # generate Y-dimension + refy_traj, ref_Ay = ScreenTargetTracking.calc_sum_of_sines_ramp(t, r, rd, f_y[np.arange(len(xprimes))], a_y[np.arange(len(xprimes))], o_y[trial_id][np.arange(len(xprimes))]) + disy_traj, dis_Ay = ScreenTargetTracking.calc_sum_of_sines_ramp(t, r, rd, f_y[sines_d], a_y[sines_d], o_ydis[trial_id][sines_d]) + + else: + # generate X-dimension + refx_traj, ref_Ax = ScreenTargetTracking.calc_sum_of_sines_ramp(t, r, rd, f_x[sines_r], a_x[sines_r], o_x[trial_id][sines_r]) + disx_traj, dis_Ax = ScreenTargetTracking.calc_sum_of_sines_ramp(t,r,rd, f_x[sines_d], a_x[sines_d], o_xdis[trial_id][sines_d]) + # generate Y-dimension + refy_traj, ref_Ay = ScreenTargetTracking.calc_sum_of_sines_ramp(t, r, rd, f_y[sines_r], a_y[sines_r], o_y[trial_id][sines_r]) + disy_traj, dis_Ay = ScreenTargetTracking.calc_sum_of_sines_ramp(t, r, rd, f_y[sines_d], a_y[sines_d], o_ydis[trial_id][sines_d]) + + + # normalized trajectories + trials['ref_x'][trial_id] = refx_traj/ref_Ax # previously, denominator was np.sum(a_ref) + trials['dis_x'][trial_id] = disx_traj/dis_Ax # previously, denominator was np.sum(a_dis) + trials['ref_y'][trial_id] = refy_traj/ref_Ay # previously, denominator was np.sum(a_ref) + trials['dis_y'][trial_id] = disy_traj/dis_Ay # previously, denominator was np.sum(a_dis) + + return trials, trial_order + + @staticmethod + def generate_trajectory(primes, base_period, ramp = 0.0): ''' Sets up variables and uses prime numbers to call the above functions and generate then trajectories ramp is time length for preparatory lines @@ -969,7 +1089,7 @@ def generate_trajectory(primes, base_period, ramp = .0): ### Generator functions #### @staticmethod - def tracking_target_chain(nblocks=1, ntrials=500, time_length=20, ramp=1.5, ramp_down=1.5, num_primes=8, seed=40, sample_rate=120, disturbance=True, boundaries=(-10,10,-10,10)): + def tracking_target_chain(nblocks=1, ntrials=500, time_length=20, ramp=1.5, ramp_down=1.5, num_primes=8, seed=40, sample_rate=120, dimensions = 1, disturbance=True, boundaries=(-10,10,-10,10), decay_rate = None): ''' Generates a sequence of 1D (z axis) target trajectories @@ -987,10 +1107,14 @@ def tracking_target_chain(nblocks=1, ntrials=500, time_length=20, ramp=1.5, ramp The sample rate of the generated trajectories ramp : float The length of ramp up into a trial in seconds + dimensions: int + Number of dimensions to generate trajectories for (1 or 2) disturbance : boolean Whether to add disturbance to the cursor (disturbance is generated regardless) boundaries: 4 element tuple The limits of the allowed target locations (-x, x, -z, z) + decay_rate: None or float + This generates amplitudes using a decay_rate. Used for 2d trajectories. If set to None (default), amplitudes are generated using a linear decay. Returns ------- @@ -1001,19 +1125,38 @@ def tracking_target_chain(nblocks=1, ntrials=500, time_length=20, ramp=1.5, ramp ''' idx = 0 base_period = 20 - for block_id in range(nblocks): - trials, trial_order = ScreenTargetTracking.generate_trajectories( - num_trials=ntrials, time_length=time_length, seed=seed, sample_rate=sample_rate, base_period=base_period, ramp=ramp, ramp_down=ramp_down, num_primes=num_primes - ) - for trial_id in range(ntrials): - targs = [] - ref_trajectory = np.zeros((int((time_length+ramp+ramp_down)*sample_rate),3)) - dis_trajectory = ref_trajectory.copy() - ref_trajectory[:,2] = trials['ref'][trial_id] - dis_trajectory[:,2] = trials['dis'][trial_id] # scale will determine lower limit of target size for perfect tracking - targs.append(ref_trajectory) - yield idx, targs, disturbance, dis_trajectory, sample_rate, ramp, ramp_down - idx += 1 + for block_id in range(nblocks): + if dimensions == 1: + trials, trial_order = ScreenTargetTracking.generate_trajectories( + num_trials=ntrials, time_length=time_length, seed=seed, sample_rate=sample_rate, base_period=base_period, ramp=ramp, ramp_down=ramp_down, num_primes=num_primes + ) + for trial_id in range(ntrials): + targs = [] + ref_trajectory = np.zeros((int((time_length+ramp+ramp_down)*sample_rate),3)) + dis_trajectory = ref_trajectory.copy() + ref_trajectory[:,2] = trials['ref'][trial_id] + dis_trajectory[:,2] = trials['dis'][trial_id] # scale will determine lower limit of target size for perfect tracking + targs.append(ref_trajectory) + yield idx, targs, disturbance, dis_trajectory, sample_rate, ramp, ramp_down + idx += 1 + + if dimensions == 2: + trials, trial_order = ScreenTargetTracking.generate_2D_trajectories( + num_trials=ntrials, time_length=time_length, seed=seed, sample_rate=sample_rate, base_period=base_period, ramp=ramp, ramp_down=ramp_down, num_primes=num_primes, use_disturb = disturbance, + decay_rate = decay_rate) + for trial_id in range(ntrials): + targs = [] + ref_trajectory = np.zeros((int((time_length+ramp+ramp_down)*sample_rate),3)) + dis_trajectory = ref_trajectory.copy() + + ref_trajectory[:,2] = trials['ref_y'][trial_id] #y is out of the screen, x is left and right and z is up and down + ref_trajectory[:,0] = trials['ref_x'][trial_id] + + dis_trajectory[:,2] = trials['dis_y'][trial_id] + dis_trajectory[:,0] = trials['dis_x'][trial_id] + targs.append(ref_trajectory) + yield idx, targs, disturbance, dis_trajectory, sample_rate, ramp, ramp_down + idx += 1 @staticmethod def tracking_target_debug(nblocks=1, ntrials=2, time_length=20, seed=40, sample_rate=60, ramp=0, disturbance=True, boundaries=(-10,10,-10,10)): diff --git a/features/generator_features.py b/features/generator_features.py index 811b30dc..187d3286 100644 --- a/features/generator_features.py +++ b/features/generator_features.py @@ -328,15 +328,19 @@ def _start_reward(self): class HideLeftTrajectory(traits.HasTraits): ''' - Cover left side of tracking task screen with a black box. + Hide the left side of the tracking trajectory. This will cover the 'lookbehind' of the target trajectory. Useful for task with bumpers. ''' - def _start_trajectory(self): - super()._start_trajectory() - if self.frame_index == 0: - self.box.show() + def setup_start_wait(self): + super().setup_start_wait() + print(self.frame_index) + self.trajectory.update_mask(self.lookahead+2, self.lookahead*2) + + def update_frame(self): + super().update_frame() + self.trajectory.update_mask(self.frame_index+self.lookahead+1, self.frame_index+2*self.lookahead) class ReadysetMedley(traits.HasTraits): diff --git a/riglib/stereo_opengl/models.py b/riglib/stereo_opengl/models.py index 595da561..dba54bae 100644 --- a/riglib/stereo_opengl/models.py +++ b/riglib/stereo_opengl/models.py @@ -37,6 +37,9 @@ def __init__(self, shader="default", color=(0.5, 0.5, 0.5, 1), # The orientation of the object, in the world frame self._xfm = self.xfm self.allocated = False + + # Keep track of the model's size for rendering + self.bounding_radius = 0.0 def __setattr__(self, attr, xfm): '''Checks if the xfm was changed, and recaches the _xfm which is sent to the shader''' @@ -185,13 +188,18 @@ def init(self): model.init() def render_queue(self, xfm=np.eye(4), **kwargs): - for model in self.models: + def sort_key(obj): + pos = obj.xfm.move[1] + radius = obj.bounding_radius + dist = pos - radius + return -dist # Negative for far-to-near sorting + sorted_models = sorted(self.models, key=sort_key) + for model in sorted_models: for out in model.render_queue(**kwargs): yield out def draw(self, ctx, **kwargs): - sorted_models = sorted(self.models, key=lambda obj: -obj.xfm.move[1]) - for model in sorted_models: + for model in self.models: model.draw(ctx, **kwargs) def __getitem__(self, idx): @@ -232,6 +240,8 @@ def __init__(self, verts, polys, normals=None, tcoords=None, **kwargs): self.polys = polys self.tcoords = tcoords self.normals = normals + + self.bounding_radius = abs(np.min(self.verts[:, 1])) def init(self): allocated = super(TriMesh, self).init() diff --git a/riglib/stereo_opengl/primitives.py b/riglib/stereo_opengl/primitives.py index 99dfdc6a..6414425b 100644 --- a/riglib/stereo_opengl/primitives.py +++ b/riglib/stereo_opengl/primitives.py @@ -15,7 +15,7 @@ from .models import TriMesh from .textures import Texture, TexModel -from OpenGL.GL import GL_NEAREST, GL_REPEAT +from OpenGL.GL import * from PIL import Image, ImageDraw, ImageFont import matplotlib.font_manager as fm @@ -154,33 +154,68 @@ def __init__(self, height=1, radius=1, segments=36, **kwargs): super().__init__(total_pts, total_polys, tcoords=total_tcoords, normals=total_normals, **kwargs) class Cable(TriMesh): - def __init__(self,radius=.5, trajectory = np.array([np.sin(x) for x in range(60)]), segments=12,**kwargs): - self.trial_trajectory = trajectory + def __init__(self,radius=.5, xyz = np.array([np.sin(x) for x in range(60)]), segments=12,**kwargs): + self.xyz = xyz + if np.ndim(xyz) == 1: + self.xyz = np.array([[x,0,xyz[x]] for x in range(len(xyz))]) self.center_value = [0,0,0] self.radius = radius self.segments = segments self.update(**kwargs) def update(self, **kwargs): - theta = np.linspace(0, 2*np.pi, self.segments, endpoint=False) - unit = np.array([np.ones(self.segments),np.cos(theta) ,np.sin(theta)]).T - intial = np.array([[0,0,self.trial_trajectory[x]] for x in range(len(self.trial_trajectory))]) - self.pts = (unit*[-30/1.36,self.radius,self.radius])+intial[0] - for i in range(1,len(intial)): - self.pts = np.vstack([self.pts, (unit*[(i-30)/3,self.radius,self.radius])+intial[i]]) - - self.normals = np.vstack([unit*[1,1,0], unit*[1,1,0]]) - self.polys = [] - for i in range(self.segments-1): - for j in range(len(intial)-1): - self.polys.append((i+j*self.segments, i+1+j*self.segments, i+self.segments+j*self.segments)) - self.polys.append((i+self.segments+j*self.segments, i+1+j*self.segments, i+1+self.segments+j*self.segments)) - - tcoord = np.array([np.arange(self.segments), np.ones(self.segments)]).T - n = 1./self.segments - self.tcoord = np.vstack([tcoord*[n,1], tcoord*[n,0]]) - super(Cable, self).__init__(self.pts, np.array(self.polys), - tcoords=self.tcoord, normals=self.normals, **kwargs) + theta = np.linspace(0, 2 * np.pi, self.segments, endpoint=False) + circle = np.stack([np.cos(theta), np.sin(theta)], axis=1) # (segments, 2) + + pts = [] + normals = [] + tcoords = [] + n_path = len(self.xyz) + + a = np.array([0, 1, 0]) # fixed up direction + + # Compute tangents along path + tangents = np.gradient(self.xyz, axis=0) + tangents = tangents / np.linalg.norm(tangents, axis=1, keepdims=True) + + for i in range(n_path): + p = self.xyz[i] + t = tangents[i] + + # Ring orientation + b = np.cross(t, a) + if np.linalg.norm(b) < 1e-6: + b = np.array([0, 0, 1]) # fallback + else: + b = b / np.linalg.norm(b) + + for j in range(self.segments): + cx, cy = circle[j] + offset = self.radius * (cx * b + cy * a) + pts.append(p + offset) + normals.append(offset / np.linalg.norm(offset)) + tcoords.append([j / self.segments, i / (n_path - 1)]) + + self.pts = np.array(pts) + self.normals = np.array(normals) + self.tcoord = np.array(tcoords) + + # Create triangle strips between rings + polys = [] + for i in range(n_path - 1): + for j in range(self.segments): + i0 = i * self.segments + j + i1 = i * self.segments + (j + 1) % self.segments + i2 = (i + 1) * self.segments + j + i3 = (i + 1) * self.segments + (j + 1) % self.segments + + polys.append((i2, i1, i0)) + polys.append((i3, i1, i2)) + + self.polys = np.array(polys) + + super().__init__(self.pts, self.polys, tcoords=self.tcoord, + normals=self.normals, **kwargs) class Torus(TriMesh): ''' @@ -473,6 +508,46 @@ def __init__(self, radius, alpha=1, stop=False, **kwargs): texture_mapping='planar', **kwargs) self.rotate_x(90) # Make the target face the camera +class Snake(Cable, TexModel): + ''' + A Cable with a gradient texture applied along its length. + ''' + def __init__(self, radius=.5, trajectory=np.array([np.sin(x) for x in range(100)]), segments=12, **kwargs): + self.trajectory = trajectory + color = kwargs.pop('color', [1, 1, 1, 1]) # Default color if not provided + self.color = color + tex = self.get_texture(0, len(trajectory)) + super().__init__(radius, trajectory, segments, tex=tex, color=[0, 0, 0, 1], **kwargs) + self.color = color # Store the color for later use + + def get_texture(self, start_frame, end_frame, inverse=False): + mask = np.zeros((len(self.trajectory))) + if start_frame >= len(self.trajectory): + start_frame = len(self.trajectory) + if end_frame >= len(self.trajectory): + end_frame = len(self.trajectory) + mask[start_frame:end_frame] = 1 + if inverse: + mask = 1 - mask + mask = np.tile(mask, (4, 1)).T # Repeat for RGBA + mask = self.color * mask # Apply color + tex = Texture(mask.reshape((1, len(mask), 4))) # Reshape to (1, n_colors, 4) + return tex + + def update_texture(self, start_frame, end_frame, inverse=False): + ''' + Update the texture of the snake based on the new trajectory. + ''' + self.tex.delete() # Delete the old texture + tex = self.get_texture(start_frame, end_frame, inverse=inverse) + self.tex = tex + self.tex.init() + + def draw(self, ctx): + glDisable(GL_DEPTH_TEST) + super().draw(ctx) + glEnable(GL_DEPTH_TEST) + ##### 2-D primitives ##### class Shape2D(object): diff --git a/riglib/stereo_opengl/shaders/none.f.glsl b/riglib/stereo_opengl/shaders/none.f.glsl index 675df621..ebc74cff 100644 --- a/riglib/stereo_opengl/shaders/none.f.glsl +++ b/riglib/stereo_opengl/shaders/none.f.glsl @@ -13,5 +13,7 @@ void main() { texcolor.rgb + basecolor.rgb, texcolor.a * basecolor.a ); + if (frag_diffuse.a < 0.01) + discard; FragColor = frag_diffuse; } diff --git a/riglib/stereo_opengl/shaders/phong.f.glsl b/riglib/stereo_opengl/shaders/phong.f.glsl index b6ae4f84..d7c872e6 100644 --- a/riglib/stereo_opengl/shaders/phong.f.glsl +++ b/riglib/stereo_opengl/shaders/phong.f.glsl @@ -58,6 +58,8 @@ vec4 phong() { texcolor.rgb + basecolor.rgb, texcolor.a * basecolor.a ); + if (frag_diffuse.a < 0.01) + discard; vec4 diffuse_factor = max(-dot(normal, mv_light_direction), 0.0) * light_diffuse; diff --git a/riglib/stereo_opengl/textures.py b/riglib/stereo_opengl/textures.py index 8ac0944b..839a7830 100644 --- a/riglib/stereo_opengl/textures.py +++ b/riglib/stereo_opengl/textures.py @@ -26,11 +26,13 @@ def __init__(self, tex, size=None, if isinstance(tex, np.ndarray): if tex.max() <= 1: - tex *= 255 - if len(tex.shape) < 3: - tex = np.tile(tex, [3, 1, 1]).T - if tex.shape[-1] == 3: - tex = np.dstack([tex, np.ones(tex.shape[:-1])]) + tex = (tex * 255).astype(np.uint8) + else: + tex = tex.astype(np.uint8) + if tex.ndim == 2: + tex = np.stack([tex]*3, axis=-1) # grayscale → RGB + elif tex.shape[-1] == 1: + tex = np.repeat(tex, 3, axis=-1) size = tex.shape[:2] tex = tex.astype(np.uint8).tobytes() elif isinstance(tex, str): diff --git a/riglib/stereo_opengl/window.py b/riglib/stereo_opengl/window.py index e877855c..9da27513 100644 --- a/riglib/stereo_opengl/window.py +++ b/riglib/stereo_opengl/window.py @@ -92,7 +92,7 @@ def screen_init(self): glDisable(GL_FRAMEBUFFER_SRGB) # disable gamma correction glEnable(GL_BLEND) - glDepthFunc(GL_LESS) + glDepthFunc(GL_LEQUAL) glEnable(GL_DEPTH_TEST) glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA) glClearColor(*self.background) diff --git a/setup.py b/setup.py index e1a4154b..10bbec2f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="aolab-bmi3d", - version="1.2.4", + version="1.2.5", author="Lots of people", description="electrophysiology experimental rig library", packages=setuptools.find_packages(), diff --git a/tests/test_graphics.py b/tests/test_graphics.py index 84d7bee3..06f1f9fb 100644 --- a/tests/test_graphics.py +++ b/tests/test_graphics.py @@ -11,7 +11,7 @@ from riglib.stereo_opengl.environment import Grid from riglib.stereo_opengl.window import Window, Window2D, FPScontrol -from riglib.stereo_opengl.primitives import AprilTag, Cylinder, Cube, Plane, Sphere, Cone, Text, TexSphere, TexCube, TexPlane +from riglib.stereo_opengl.primitives import AprilTag, Cylinder, Cube, Plane, Snake, Sphere, Cone, Text, Cable, TexSphere, TexCube, TexPlane from features.optitrack_features import SpheresToCylinders from riglib.stereo_opengl.window import Window, Window2D, FPScontrol, WindowSSAO from riglib.stereo_opengl.openxr import WindowVR @@ -39,13 +39,14 @@ planet = Sphere(3, color=[0.75,0.25,0.25,0.75]) orbit_radius = 4 orbit_speed = 1 -wobble_radius = 0 +wobble_radius = 5 wobble_speed = 0.5 #TexSphere = type("TexSphere", (Sphere, TexModel), {}) #TexPlane = type("TexPlane", (Plane, TexModel), {}) -#reward_text = Text(7.5, "123", justify='right', color=[1,0,1,1]) +reward_text = Text(7.5, "123", justify='right', color=[1,0,1,1]) # center_out_gen = ScreenTargetCapture.centerout_2D(1) # center_out_positions = [pos[1] for _, pos in center_out_gen] +cable = Snake(0.5, 2*np.sin(np.arange(200)/2), color=(1,0,1,0.75)).translate(-15, 0, -10) center_out_gen = ScreenTargetCapture.centerout_tabletop(1) center_out_positions = [(pos[1][0], pos[1][1], -10) for _, pos in center_out_gen] center_out_targets = [ @@ -70,8 +71,9 @@ def _start_draw(self): #arm4j.set_joint_pos([0,0,np.pi/2,np.pi/2]) #arm4j.get_endpoint_pos() self.add_model(Grid(50)) - self.add_model(apriltag) self.add_model(moon) + self.add_model(planet) + self.add_model(apriltag) # self.add_model(moon) # self.add_model(planet) # self.add_model(arm4j) @@ -80,16 +82,20 @@ def _start_draw(self): # self.add_model(TexPlane(5,5, tex=cloudy_tex(), specular_color=(0.,0,0,1)).rotate_x(90)) # self.add_model(TexPlane(5,5, specular_color=(0.,0,0,1), tex=cloudy_tex()).rotate_x(90)) # reward_text = Text(7.5, "123", justify='right', color=[1,0,1,0.75]) - # self.add_model(reward_text) + self.add_model(reward_text) # self.add_model(TexPlane(4,4,color=[0,0,0,0.9], tex=cloudy_tex()).rotate_x(90).translate(0,0,-5)) #self.screen_init() #self.draw_world() for model in center_out_targets: self.add_model(model) - self.add_model(Sphere(radius=1, color=target_colors['purple']).translate(3,3,-10)) + for model in center_out_targets: + self.add_model(model) + self.add_model(Sphere(radius=1, color=target_colors['purple']).translate(3,0,-10)) + self.add_model(cable) def _while_draw(self): ts = time.time() - self.start_time + # cable.update_texture(int(ts*10), len(cable.trajectory)) x = travel_radius * np.cos(ts * travel_speed) y = travel_radius * np.sin(ts * travel_speed) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 41a83857..9f60c60b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -20,6 +20,7 @@ from riglib.stereo_opengl.window import Window, Window2D import unittest import numpy as np +import matplotlib.pyplot as plt import os import socket @@ -77,9 +78,28 @@ def test_example_task(self): @unittest.skip("") def test_tracking(self): print("Running tracking task test") - seq = TrackingTask.tracking_target_debug(nblocks=1, ntrials=6, time_length=5, seed=40, sample_rate=60, ramp=1) # sample_rate needs to match fps in ScreenTargetTracking - exp = init_exp(TrackingTask, [MouseEmulateTouch, Window2D, MouseControl], seq) # , window_size=(1000,800) + seq = TrackingTask.tracking_target_chain(nblocks=1, ntrials=2, time_length=5, ramp=1, ramp_down=1, + num_primes=8, seed=42, sample_rate=60, + disturbance=False, boundaries=(-10,10,-10,10)) + exp = init_exp(TrackingTask, [HideLeftTrajectory, MouseControl, Window2D], seq, window_size=(1000,800), fullscreen=False, + lookahead_time=1, screen_half_height=10) exp.rotation = 'xzy' + # exp.trajectory_type = 'space' + exp.trajectory_amplitude = 5 + exp.trajectory_radius = 0.2 + exp.run() + + @unittest.skip("") + def test_tracking_2d(self): + print("Running tracking task test") + seq = TrackingTask.tracking_target_chain(nblocks=1, ntrials=2, time_length=20, ramp=1, ramp_down=1, + num_primes=10, seed=42, sample_rate=60, dimensions=2, + disturbance=True, boundaries=(-10,10,-10,10), decay_rate = 0.1) + exp = init_exp(TrackingTask, [Window2D, MouseControl], seq, window_size=(1000,800), fullscreen=False, + limit1d=False, trajectory_amplitude=5, lookahead_time=1) + exp.stereo_mode = 'projection' + exp.rotation = 'xzy' + exp.trajectory_type = 'space' exp.run() @unittest.skip("") @@ -102,7 +122,7 @@ def test_force_task(self): exp.end_task() @unittest.skip("only to test progress bar") - def test_tracking_progress(self): + def test_progress_bar(self): seq = TrackingTask.tracking_target_debug(nblocks=1, ntrials=6, time_length=5, seed=40, sample_rate=60, ramp=1) # sample_rate needs to match fps in ScreenTargetTracking exp = init_exp(TrackingTask, [MouseControl, Window2D, ProgressBar], seq) exp.rotation = 'xzy' @@ -163,6 +183,36 @@ def test_corners(self): print(loc) print("---------------corners") + @unittest.skip("") + def test_tracking_2d(self): + seq = TrackingTask.tracking_target_chain(nblocks=1, ntrials=2, time_length=20, ramp=0, ramp_down=0, + num_primes=12, seed=42, sample_rate=60, dimensions=2, + disturbance=False, boundaries=(-10,10,-10,10), decay_rate = None) + trajectories = [t[1][0] for t in seq] # pulls out trajectory. Can use t[3] to get disturbance array + print("2D Test-------") + print(np.shape(trajectories)) + print("2D Test-------") + fig, axs = plt.subplots(2,1, figsize=(10,8)) + for idx, trial in enumerate(trajectories): + ax = axs[idx] + trialx = np.fft.fft(trial[:,0]) + trial_length = np.shape(trialx)[0] + freq = np.fft.fftfreq(trial_length, d=1./60) + non_neg_freq = freq[freq >= 0] #get positive frequencies + non_neg_x = trialx[freq >= 0] / complex(trial_length, 0) #normalize + non_neg_x[1:] = 2*non_neg_x[1:] #account for negative frequencies + trialy = np.fft.fft(trial[:,2]) + non_neg_y = trialy[freq >= 0] / complex(trial_length, 0) #normalize + non_neg_y[1:] = 2*non_neg_y[1:] #account for negative frequencies + ax.plot(non_neg_freq, np.abs(non_neg_x), 'o-', label = 'X') + ax.plot(non_neg_freq, np.abs(non_neg_y), 'o-', label = 'Y') + ax.set_title(f'Trial {idx}') + ax.set_xlim(0, 3) + ax.set_xlabel('Frequency (Hz)') + plt.legend() + plt.tight_layout() + plt.show() + class TestYouTube(unittest.TestCase): @unittest.skip("")