diff --git a/built_in_tasks/target_tracking_task.py b/built_in_tasks/target_tracking_task.py index 847e18e0..81128b55 100644 --- a/built_in_tasks/target_tracking_task.py +++ b/built_in_tasks/target_tracking_task.py @@ -44,11 +44,13 @@ class TargetTracking(Sequence): wait_retry = dict(start_trial="trajectory", start_pause="pause"), trajectory = dict(enter_target="hold", timeout="timeout_penalty", start_pause="pause"), hold = dict(hold_complete="tracking_in", leave_target="hold_penalty", start_pause="pause"), - tracking_in = dict(trial_complete="reward", leave_target="tracking_out", start_pause="pause"), - tracking_out = dict(trial_complete="reward", enter_target="tracking_in", tracking_out_timeout="tracking_out_penalty", start_pause="pause"), + tracking_in = dict(trial_complete="reward", leave_target="tracking_out", start_pause="pause", cursor_is_still="inactive_tracking_penalty"), + tracking_out = dict(trial_complete="reward", enter_target="tracking_in", tracking_out_timeout="tracking_out_penalty", start_pause="pause", cursor_is_still="inactive_tracking_penalty"), timeout_penalty = dict(timeout_penalty_end="wait", start_pause="pause", end_state=True), hold_penalty = dict(hold_penalty_end="wait", hold_penalty_end_retry="wait_retry", start_pause="pause", end_state=True), tracking_out_penalty = dict(tracking_out_penalty_end="wait", start_pause="pause", end_state=True), + # TODO + inactive_tracking_penalty = dict(inactive_tracking_penalty_end="wait", start_pause="pause", end_state=True), reward = dict(reward_end="wait", start_pause="pause", stoppable=False, end_state=True), pause = dict(end_pause="wait", end_state=True) # all end_states will result in trial counter +1, so if you start pause during a penalty state, @@ -64,6 +66,8 @@ class TargetTracking(Sequence): reward_time = traits.Float(.5, desc="Length of reward dispensation") timeout_time = traits.Float(10, desc="Time allowed to go between trajectories") timeout_penalty_time = traits.Float(1, desc="Length of penalty time for initiation timeout error") + inactive_tracking = traits.Float(2, desc = "Length of time allowed for inactive tracking error") + inactive_tracking_penalty_time = traits.Float(0, desc="Length of penalty time for inactive tracking error") hold_time = traits.Float(.1, desc="Time of hold required at target before trajectory begins") hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error") tracking_out_time = traits.Float(2.5, desc="Time allowed to be tracking outside the target") # AKA tolerance time @@ -248,6 +252,18 @@ def _end_tracking_out_penalty(self): '''Nothing generic to do.''' pass + def _start_inactive_tracking_penalty(self): + '''Nothing generic to do.''' + pass + + def _while_inactive_tracking_penalty(self): + self.pos_offset = [0,0,0] + self.vel_offset = [0,0,0] + + def _end_inactive_tracking_penalty(self): + '''Nothing generic to do.''' + pass + def _start_reward(self): '''Nothing generic to do.''' pass @@ -303,9 +319,15 @@ def _test_trial_complete(self, time_in_state): def _test_tracking_out_timeout(self, time_in_state): return time_in_state > self.tracking_out_time + + # def _test_inactive_tracking(self, time_in_state): + # return time_in_state > self.inactive_tracking # change this def _test_timeout_penalty_end(self, time_in_state): return time_in_state > self.timeout_penalty_time #or self.pause + + def _test_inactive_tracking_penalty_end(self, time_in_state): + return time_in_state > self.inactive_tracking_penalty_time #or self.pause def _test_hold_penalty_end(self, time_in_state): return (time_in_state > self.hold_penalty_time) and (self.tries==self.max_hold_attempts) #or self.pause @@ -313,7 +335,7 @@ def _test_hold_penalty_end(self, time_in_state): def _test_hold_penalty_end_retry(self, time_in_state): return (time_in_state > self.hold_penalty_time) and (self.tries self.tracking_out_penalty_time #or self.pause def _test_reward_end(self, time_in_state): @@ -326,6 +348,10 @@ def _test_enter_target(self, time_in_state): def _test_leave_target(self, time_in_state): '''This function is task-specific and not much can be done generically''' return False + + def _test_cursor_is_still(self, time_in_state): + '''This function is task-specific and not much can be done generically''' + return False def _test_start_pause(self, time_in_state): return self.pause @@ -388,6 +414,8 @@ def __init__(self, *args, **kwargs): self.cursor_vis_prev = True self.lookahead = 30 # number of frames to create a "lookahead" window of 0.5 seconds (half the screen) self.original_limit1d = self.limit1d # keep track of original settable trait + + self.count = 0 if not self.always_1d: self.limit1d = False # allow 2d movement before center-hold initiation @@ -428,6 +456,8 @@ def _cycle(self): ''' Calls any update functions necessary and redraws screen ''' + # print(self.count) + self.prev_cursor = self.last_pt self.move_effector(pos_offset=np.asarray(self.pos_offset), vel_offset=np.asarray(self.vel_offset)) # Run graphics commands to show/hide the plant if the visibility has changed @@ -503,6 +533,21 @@ def _test_leave_target(self, time_in_state): cursor_pos = self.plant.get_endpoint_pos() d = np.linalg.norm(cursor_pos - self.target.get_position()) return d > (self.target_radius - self.cursor_radius) or super()._test_leave_target(time_in_state) + + # TODO + def _test_cursor_is_still(self, time_in_state): + ''' + Test if the cursor has been still + ''' + # print(self.count, self.frame_index) + cursor_pos = self.plant.get_endpoint_pos() + if (cursor_pos == self.prev_cursor).all(): + # print('same pos') + self.count += 1/self.fps + else: + self.count = 0 + + return (self.count >= self.inactive_tracking) or super()._test_cursor_is_still(time_in_state) #### STATE FUNCTIONS #### def setup_start_wait(self): @@ -527,6 +572,9 @@ def setup_start_wait(self): # Set up for progress bar self.bar_width = 12 self.tracking_frame_index = 0 + + # Set up for inactive tracking + self.count = 0 # Set up the next trajectory next_trajectory = np.array(np.squeeze(self.targs)[:,2]) @@ -652,7 +700,7 @@ def _while_tracking_out(self): # Move target and trajectory to next frame so it appears to be moving self.update_frame() - + # Check if the trial is over and there are no more target frames to display if self.frame_index+self.lookahead >= np.shape(self.targs)[0]: self.trial_timed_out = True @@ -675,6 +723,26 @@ def _while_timeout_penalty(self): def _end_timeout_penalty(self): super()._end_timeout_penalty() self.sync_event('TRIAL_END') + + # TODO + def _start_inactive_tracking_penalty(self): + super()._start_inactive_tracking_penalty() + print('START INACTIVE PENALTY') + self.sync_event('TIMEOUT_PENALTY') + + # self.in_end_state = True + self.setup_screen_reset() + + # skip to next generated trial using same freq set + self.repeat_freq_set = True + + def _while_inactive_tracking_penalty(self): + super()._while_inactive_tracking_penalty() + # # Add disturbance + + def _end_inactive_tracking_penalty(self): + super()._end_inactive_tracking_penalty() + self.sync_event('TRIAL_END') def _start_hold_penalty(self): super()._start_hold_penalty() @@ -880,7 +948,7 @@ def generate_trajectories(num_trials=2, time_length=20, seed=40, sample_rate=120 trials['ref'][trial_id] = ref_trajectory/ref_A # previously, denominator was np.sum(a_ref) trials['dis'][trial_id] = dis_trajectory/dis_A # previously, denominator was np.sum(a_dis) # print(trial_order, ref_A, dis_A) - + return trials, trial_order @staticmethod diff --git a/features/reward_features.py b/features/reward_features.py index dd53bd81..6b07ba82 100644 --- a/features/reward_features.py +++ b/features/reward_features.py @@ -3,6 +3,7 @@ ''' import time import os +import socket import subprocess from riglib.experiment import traits from riglib.experiment.experiment import control_decorator @@ -69,6 +70,8 @@ def __init__(self, *args, **kwargs): super(RewardSystem, self).__init__(*args, **kwargs) self.reward = RemoteReward() self.reportstats['Reward #'] = 0 + hostname = socket.gethostname() + self.device_ip = socket.gethostbyname(hostname) def _start_reward(self): if hasattr(super(RewardSystem, self), '_start_reward'): @@ -76,9 +79,10 @@ def _start_reward(self): self.reportstats['Reward #'] += 1 if self.reportstats['Reward #'] % self.trials_per_reward == 0: - for _ in range(self.pellets_per_reward): # call trigger num of pellets_per_reward time - self.reward.trigger() - time.sleep(0.5) # wait for 0.5 seconds + for _ in range(self.pellets_per_reward): # call trigger num of pellets_per_reward time + print(self.device_ip) + self.reward.trigger(self.device_ip) + time.sleep(0.5) # wait for 0.5 seconds def _end_reward(self): if hasattr(super(RewardSystem, self), '_end_reward'): diff --git a/riglib/tablet_reward.py b/riglib/tablet_reward.py index 58ecffc5..10b93af1 100644 --- a/riglib/tablet_reward.py +++ b/riglib/tablet_reward.py @@ -49,4 +49,4 @@ def trigger(self): try: requests.post(url, timeout=3) except: - traceback.print_exc() + traceback.print_exc() \ No newline at end of file