built_in_tasks/target_tracking_task.py (73 additions & 5 deletions)
@@ -44,11 +44,13 @@ class TargetTracking(Sequence):
wait_retry = dict(start_trial="trajectory", start_pause="pause"),
trajectory = dict(enter_target="hold", timeout="timeout_penalty", start_pause="pause"),
hold = dict(hold_complete="tracking_in", leave_target="hold_penalty", start_pause="pause"),
tracking_in = dict(trial_complete="reward", leave_target="tracking_out", start_pause="pause", cursor_is_still="inactive_tracking_penalty"),
tracking_out = dict(trial_complete="reward", enter_target="tracking_in", tracking_out_timeout="tracking_out_penalty", start_pause="pause", cursor_is_still="inactive_tracking_penalty"),
timeout_penalty = dict(timeout_penalty_end="wait", start_pause="pause", end_state=True),
hold_penalty = dict(hold_penalty_end="wait", hold_penalty_end_retry="wait_retry", start_pause="pause", end_state=True),
tracking_out_penalty = dict(tracking_out_penalty_end="wait", start_pause="pause", end_state=True),
# TODO: penalty state for a cursor that stops moving during tracking
inactive_tracking_penalty = dict(inactive_tracking_penalty_end="wait", start_pause="pause", end_state=True),
reward = dict(reward_end="wait", start_pause="pause", stoppable=False, end_state=True),
pause = dict(end_pause="wait", end_state=True)
# all end_states will result in trial counter +1, so if you start pause during a penalty state,
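# Below is a minimal, self-contained sketch (illustrative only, not riglib's
# actual FSM implementation) of how a transition table like the one above is
# typically consumed: each cycle, the machine calls the _test_<event> method
# for every event defined on the current state and jumps to the mapped state
# on the first test that returns True.
class MiniFSM:
    status = dict(
        tracking_in=dict(leave_target="tracking_out", cursor_is_still="inactive_tracking_penalty"),
        tracking_out=dict(enter_target="tracking_in", cursor_is_still="inactive_tracking_penalty"),
        inactive_tracking_penalty=dict(inactive_tracking_penalty_end="wait"),
        wait=dict(),
    )

    def __init__(self):
        self.state = "tracking_in"

    # Trivial stand-ins for the task's real _test_<event> methods
    def _test_leave_target(self): return False
    def _test_enter_target(self): return False
    def _test_cursor_is_still(self): return True   # pretend the cursor froze
    def _test_inactive_tracking_penalty_end(self): return True

    def cycle(self):
        for event, next_state in self.status[self.state].items():
            if getattr(self, "_test_" + event)():
                self.state = next_state
                break

fsm = MiniFSM()
fsm.cycle()   # tracking_in -> inactive_tracking_penalty (cursor_is_still fired)
fsm.cycle()   # inactive_tracking_penalty -> wait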
@@ -64,6 +66,8 @@ class TargetTracking(Sequence):
reward_time = traits.Float(.5, desc="Length of reward dispensation")
timeout_time = traits.Float(10, desc="Time allowed to go between trajectories")
timeout_penalty_time = traits.Float(1, desc="Length of penalty time for initiation timeout error")
inactive_tracking = traits.Float(2, desc="Length of time the cursor may remain still before an inactive tracking error")
inactive_tracking_penalty_time = traits.Float(0, desc="Length of penalty time for inactive tracking error")
hold_time = traits.Float(.1, desc="Time of hold required at target before trajectory begins")
hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error")
tracking_out_time = traits.Float(2.5, desc="Time allowed to be tracking outside the target") # AKA tolerance time
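# With the defaults above, a cursor that stays perfectly still for
# inactive_tracking = 2 s triggers the inactive tracking penalty, and
# inactive_tracking_penalty_time = 0 ends that penalty state on the very next
# cycle (time_in_state > 0 becomes true immediately).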
@@ -248,6 +252,18 @@ def _end_tracking_out_penalty(self):
'''Nothing generic to do.'''
pass

def _start_inactive_tracking_penalty(self):
'''Nothing generic to do.'''
pass

def _while_inactive_tracking_penalty(self):
# Zero any position/velocity offsets (e.g. an applied disturbance) for the duration of the penalty
self.pos_offset = [0,0,0]
self.vel_offset = [0,0,0]

def _end_inactive_tracking_penalty(self):
'''Nothing generic to do.'''
pass

def _start_reward(self):
'''Nothing generic to do.'''
pass
@@ -303,17 +319,23 @@ def _test_trial_complete(self, time_in_state):

def _test_tracking_out_timeout(self, time_in_state):
return time_in_state > self.tracking_out_time

def _test_timeout_penalty_end(self, time_in_state):
return time_in_state > self.timeout_penalty_time #or self.pause

def _test_inactive_tracking_penalty_end(self, time_in_state):
return time_in_state > self.inactive_tracking_penalty_time #or self.pause

def _test_hold_penalty_end(self, time_in_state):
return (time_in_state > self.hold_penalty_time) and (self.tries==self.max_hold_attempts) #or self.pause

def _test_hold_penalty_end_retry(self, time_in_state):
return (time_in_state > self.hold_penalty_time) and (self.tries<self.max_hold_attempts) #or self.pause
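# Together, these two tests route the FSM out of hold_penalty: to "wait" (a new
# trial) once max_hold_attempts is used up, otherwise to "wait_retry" so the
# same trial can be attempted again.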

def _test_tracking_out_penalty_end(self, time_in_state): # True once the tracking-out penalty period has elapsed
return time_in_state > self.tracking_out_penalty_time #or self.pause

def _test_reward_end(self, time_in_state):
@@ -326,6 +348,10 @@ def _test_enter_target(self, time_in_state):
def _test_leave_target(self, time_in_state):
'''This function is task-specific and not much can be done generically'''
return False

def _test_cursor_is_still(self, time_in_state):
'''This function is task-specific and not much can be done generically'''
return False

def _test_start_pause(self, time_in_state):
return self.pause
@@ -388,6 +414,8 @@ def __init__(self, *args, **kwargs):
self.cursor_vis_prev = True
self.lookahead = 30 # number of frames to create a "lookahead" window of 0.5 seconds (half the screen)
self.original_limit1d = self.limit1d # keep track of original settable trait

self.count = 0 # time (s) the cursor has been continuously still; see _test_cursor_is_still

if not self.always_1d:
self.limit1d = False # allow 2d movement before center-hold initiation
@@ -428,6 +456,8 @@ def _cycle(self):
'''
Calls any update functions necessary and redraws screen
'''
# Remember the most recent cursor position before moving the effector, so
# _test_cursor_is_still can detect a frame with no movement
self.prev_cursor = self.last_pt
self.move_effector(pos_offset=np.asarray(self.pos_offset), vel_offset=np.asarray(self.vel_offset))

# Run graphics commands to show/hide the plant if the visibility has changed
@@ -503,6 +533,21 @@ def _test_leave_target(self, time_in_state):
cursor_pos = self.plant.get_endpoint_pos()
d = np.linalg.norm(cursor_pos - self.target.get_position())
return d > (self.target_radius - self.cursor_radius) or super()._test_leave_target(time_in_state)

def _test_cursor_is_still(self, time_in_state):
'''
Test whether the cursor has remained still for longer than the inactive tracking limit
'''
cursor_pos = self.plant.get_endpoint_pos()
if (cursor_pos == self.prev_cursor).all():
self.count += 1/self.fps # accumulate one frame's worth of stillness
else:
self.count = 0 # any movement resets the timer

return (self.count >= self.inactive_tracking) or super()._test_cursor_is_still(time_in_state)
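# Note on units: count accumulates 1/fps seconds per frame of stillness, so
# with e.g. fps = 60 (consistent with the 30-frame / 0.5 s lookahead above) and
# inactive_tracking = 2 s, the cursor must hold the exact same endpoint
# position for 120 consecutive frames before this test fires.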

#### STATE FUNCTIONS ####
def setup_start_wait(self):
@@ -527,6 +572,9 @@ def setup_start_wait(self):
# Set up for progress bar
self.bar_width = 12
self.tracking_frame_index = 0

# Reset the inactive tracking stillness timer
self.count = 0

# Set up the next trajectory
next_trajectory = np.array(np.squeeze(self.targs)[:,2])
@@ -652,7 +700,7 @@ def _while_tracking_out(self):

# Move target and trajectory to next frame so it appears to be moving
self.update_frame()

# Check if the trial is over and there are no more target frames to display
if self.frame_index+self.lookahead >= np.shape(self.targs)[0]:
self.trial_timed_out = True
@@ -675,6 +723,26 @@ def _while_timeout_penalty(self):
def _end_timeout_penalty(self):
super()._end_timeout_penalty()
self.sync_event('TRIAL_END')

def _start_inactive_tracking_penalty(self):
super()._start_inactive_tracking_penalty()
# TODO: this reuses the TIMEOUT_PENALTY sync event code; inactive tracking may warrant its own event
self.sync_event('TIMEOUT_PENALTY')

self.setup_screen_reset()

# skip to the next generated trial using the same freq set
self.repeat_freq_set = True

def _while_inactive_tracking_penalty(self):
super()._while_inactive_tracking_penalty() # base class zeroes the position/velocity offsets

def _end_inactive_tracking_penalty(self):
super()._end_inactive_tracking_penalty()
self.sync_event('TRIAL_END')

def _start_hold_penalty(self):
super()._start_hold_penalty()
@@ -880,7 +948,7 @@ def generate_trajectories(num_trials=2, time_length=20, seed=40, sample_rate=120
trials['ref'][trial_id] = ref_trajectory/ref_A # previously, denominator was np.sum(a_ref)
trials['dis'][trial_id] = dis_trajectory/dis_A # previously, denominator was np.sum(a_dis)
# print(trial_order, ref_A, dis_A)

return trials, trial_order
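# Hypothetical usage of generate_trajectories (values are illustrative only):
#
#   trials, trial_order = generate_trajectories(num_trials=4, time_length=20,
#                                               seed=40, sample_rate=120)
#   ref = trials['ref'][0]   # normalized reference trajectory, first trial
#   dis = trials['dis'][0]   # normalized disturbance trajectory, first trial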

@staticmethod
features/reward_features.py (7 additions & 3 deletions)
@@ -3,6 +3,7 @@
'''
import time
import os
import socket
import subprocess
from riglib.experiment import traits
from riglib.experiment.experiment import control_decorator
@@ -69,16 +70,19 @@ def __init__(self, *args, **kwargs):
super(RewardSystem, self).__init__(*args, **kwargs)
self.reward = RemoteReward()
self.reportstats['Reward #'] = 0
# Resolve this machine's IP address once so it can be passed with each reward trigger
hostname = socket.gethostname()
self.device_ip = socket.gethostbyname(hostname)

def _start_reward(self):
if hasattr(super(RewardSystem, self), '_start_reward'):
super(RewardSystem, self)._start_reward()
self.reportstats['Reward #'] += 1

if self.reportstats['Reward #'] % self.trials_per_reward == 0:
for _ in range(self.pellets_per_reward): # trigger once per pellet
self.reward.trigger(self.device_ip)
time.sleep(0.5) # pause between pellet dispensations
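# Note: time.sleep blocks the calling thread between pellets, so dispensing
# pellets_per_reward pellets holds this state for roughly 0.5 * pellets_per_reward seconds.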

def _end_reward(self):
if hasattr(super(RewardSystem, self), '_end_reward'):
riglib/tablet_reward.py (1 addition & 1 deletion)
@@ -49,4 +49,4 @@ def trigger(self):
try:
requests.post(url, timeout=3)
except:
traceback.print_exc()
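# Note: _start_reward in features/reward_features.py now passes self.device_ip
# to trigger(), while this tablet endpoint's trigger(self) takes no argument;
# RemoteReward.trigger is assumed to be a separate method that accepts the
# device IP. A hypothetical sketch (the payload format is illustrative, not
# the actual riglib API):
#
#   def trigger(self, device_ip):
#       # POST to the reward server, identifying the requesting rig by its IP
#       requests.post(url, data={'ip': device_ip}, timeout=3)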