Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 121 additions & 4 deletions analysis/online_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,8 @@ def cleanup(self):

class SaccadeAnalysisWorker(BehaviorAnalysisWorker):
'''
Plots eye, cursor, and target data from experiments that have them. Performs automatic
calibration of eye data to target locations when the cursor enters the target if no
calibration coefficients are available.
Plots calibrated_eye, cursor, and target data from experiments that have them.
This is for eye-related task that requires calibrated eye position
'''

def __init__(self, task_params, data_queue, calibration_dir='/var/tmp', buffer_time=1, ylim=1, px_per_cm=51.67, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you remove the unnecessary variables calibration_dir='/var/tmp', buffer_time=1, ylim=1, px_per_cm=51.67, since they're not used

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think you can completely remove the init function here since it doesn't do anything anyway

Expand All @@ -324,6 +323,7 @@ def get_current_pos(self):
targets = [(self.target_pos[k], radius, color if v == 1 else 'green') for k, v in self.targets.items() if v]
except:
targets = []

return self.cursor_pos, self.calibrated_eye_pos, targets

def draw(self):
Expand All @@ -333,7 +333,7 @@ def draw(self):
buffer = self.task_params['fixation_radius_buffer']
elif 'fixation_dist' in self.task_params:
buffer = self.task_params['fixation_dist'] - self.task_params['target_radius']
eye_radius = 0.2
eye_radius = 0.1

patches1 = [plt.Circle(pos, radius+buffer) for pos, radius, _ in targets]
patches2 = [plt.Circle(cursor_pos, cursor_radius), plt.Circle(calibrated_eye_pos, eye_radius)]
Expand All @@ -352,6 +352,117 @@ def draw(self):
self.diam_plot.set_data(np.arange(len(self.eye_diam)) * 1/(int(self.task_params['fps'])) - self.buffer_time,
self.eye_diam[:, 2]/self.px_per_cm)

class EyeHandAnalysisWorker(SaccadeAnalysisWorker):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need to inherit SaccadeAnalysisWorker or can you inherit from BehaviorAnalysisWorker instead to simplify the inheritance?

'''
Plots calibrated_eye, cursor, and target data from experiments that have them.
This is for eye-hand task
'''

def init(self):
super().init()
self.hand_targets = {}
self.eye_targets = {}
self.target_pos = []
self.target_idx_trial = []

def handle_data(self, key, values):
#super().handle_data(key, values)
if key == 'sync_event':
event_name, event_data = values
if event_name == 'TARGET_ON':
self.hand_targets[event_data] = 1 # event data represents target index in bmi3d
#self.eye_targets[self.target_idx_trial[0]] = 1 # bacause eye initial target and hand target appear at the same time
elif event_name == 'TARGET_OFF':
self.hand_targets[event_data] = 0
elif event_name == 'EYE_TARGET_ON':
self.eye_targets[event_data] = 1
elif event_name == 'EYE_TARGET_OFF':
self.eye_targets[event_data] = 0

if self.task_params['experiment_name'] == 'EyeConstrainedReachingTask':
self.hand_targets[self.target_idx_trial[-1]] = 0 # In this task, the hand target also disappear

elif event_name in ['PAUSE', 'TRIAL_END', 'HOLD_PENALTY', 'DELAY_PENALTY', 'TIMEOUT_PENALTY','FIXATION_PENALTY','OTHER_PENALTY']:
# Clear targets at the end of the trial
self.hand_targets = {}
self.eye_targets = {}
self.target_pos = []
self.target_idx_trial = []

elif event_name == 'REWARD':
# Set all active targets to reward
for target_idx in self.hand_targets.keys():
self.hand_targets[target_idx] = 2 if self.hand_targets[target_idx] else 0
for target_idx in self.eye_targets.keys():
self.eye_targets[target_idx] = 2 if self.eye_targets[target_idx] else 0

elif key == 'cursor':
self.cursor_pos = np.array(values[0])[[0,2]]
elif key == 'calibrated_eye_pos':
self.calibrated_eye_pos = np.array(values[0])[:2]

# Update eye diameter
if self.calibrated_eye_pos.size > 2:
self.temp = np.array(values[0])[[0,1,4]]
self.eye_diam = np.roll(self.eye_diam, -1, axis=0)
self.eye_diam[-1] = self.temp
Comment on lines +404 to +408
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calibrated_eye_pos is never more than 2 elements, but we still get eye pos which contains diameter

Suggested change
# Update eye diameter
if self.calibrated_eye_pos.size > 2:
self.temp = np.array(values[0])[[0,1,4]]
self.eye_diam = np.roll(self.eye_diam, -1, axis=0)
self.eye_diam[-1] = self.temp
elif key == 'eye_pos':
# Update eye diameter
if self.eye_pos.size > 2:
self.temp = np.array(values[0])[[0,1,4]]
self.eye_diam = np.roll(self.eye_diam, -1, axis=0)
self.eye_diam[-1] = self.temp


elif key == 'target_location':
target_idx, target_location = values
self.target_pos.append(np.array(target_location)[[0,2]])
self.target_idx_trial.append(target_idx)


def get_current_pos(self):
'''
Get the current cursor, eye, and target positions

Returns:
cursor_pos ((2,) tuple): Current cursor position
eye_pos ((2,) tuple): Current eye position and diameters
targets (list): List of active targets in (position, radius, color) format
'''
try:
radius = self.task_params['target_radius']
eye_radius = self.task_params['fixation_radius']
color = 'orange'
eye_color = 'lightskyblue'
eye_targets = [(self.target_pos[0], eye_radius, eye_color if v == 1 else 'green') for k, v in self.eye_targets.items() if v and k < 3]
eye_targets.extend([(self.target_pos[1], eye_radius, eye_color if v == 1 else 'green') for k, v in self.eye_targets.items() if v and k >= 3])
hand_targets = [(self.target_pos[2], radius, color if v == 1 else 'green') for k, v in self.hand_targets.items() if v]
except:
eye_targets = []
hand_targets = []

return self.cursor_pos, self.calibrated_eye_pos, eye_targets, hand_targets

def draw(self):
cursor_pos, calibrated_eye_pos, eye_targets, hand_targets = self.get_current_pos()
cursor_radius = self.task_params.get('cursor_radius', 0.25)
if 'fixation_radius_buffer' in self.task_params:
buffer = self.task_params['fixation_radius_buffer']
elif 'fixation_dist' in self.task_params:
buffer = self.task_params['fixation_dist'] - self.task_params['target_radius']
eye_radius = 0.2

patches1 = [plt.Circle(pos, radius+buffer) for pos, radius, _ in eye_targets]
patches2 = [plt.Circle(cursor_pos, cursor_radius), plt.Circle(calibrated_eye_pos, eye_radius)]
patches3 = [plt.Circle(pos, radius) for pos, radius, _ in eye_targets]
patches4 = [plt.Circle(pos, radius) for pos, radius, _ in hand_targets]
patches = patches1 + patches2 + patches3 + patches4
self.circles.set_paths(patches)
colors = [[0.8,0.8,0.8] for _, _, c in eye_targets] + ['darkblue', 'darkgreen'] + [c for _, _, c in eye_targets] + [c for _, _, c in hand_targets]
self.circles.set_facecolor(colors)
self.circles.set_alpha(0.5)

# Update eye diameter plot
self.x_plot.set_data(np.arange(len(self.eye_diam)) * 1/(int(self.task_params['fps'])) - self.buffer_time,
self.eye_diam[:, 0])
self.y_plot.set_data(np.arange(len(self.eye_diam)) * 1/(int(self.task_params['fps'])) - self.buffer_time,
self.eye_diam[:, 1])
self.diam_plot.set_data(np.arange(len(self.eye_diam)) * 1/(int(self.task_params['fps'])) - self.buffer_time,
self.eye_diam[:, 2]/self.px_per_cm)

class ERPAnalysisWorker(AnalysisWorker):
'''
Plots ERP data from experiments with an ECoG244 array. Automatically calculates
Expand Down Expand Up @@ -646,6 +757,12 @@ def init(self):
elif self.task_params['experiment_name'] == 'SaccadeTask':
self.analysis_workers.append((SaccadeAnalysisWorker(self.task_params, data_queue), data_queue))

elif self.task_params['experiment_name'] == 'HandConstrainedSaccadeTask':
self.analysis_workers.append((EyeHandAnalysisWorker(self.task_params, data_queue), data_queue))

elif self.task_params['experiment_name'] == 'EyeConstrainedReachingTask':
self.analysis_workers.append((EyeHandAnalysisWorker(self.task_params, data_queue), data_queue))

# Is there ecube neural data?
if 'record_headstage' in self.task_params and self.task_params['record_headstage']:
data_queue = mp.Queue()
Expand Down
14 changes: 13 additions & 1 deletion built_in_tasks/manualcontrolmultitasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .target_graphics import *
from .target_capture_task import ScreenTargetCapture
from .target_capture_task_xt import ScreenReachAngle, ScreenReachLine, SequenceCapture, ScreenTargetCapture_ReadySet
from .target_capture_task_eye import EyeConstrainedTargetCapture, HandConstrainedEyeCapture, ScreenTargetCapture_Saccade
from .target_capture_task_eye import EyeConstrainedTargetCapture, HandConstrainedEyeCapture, EyeConstrainedHandCapture, EyeHandSequenceCapture, ScreenTargetCapture_Saccade
from .target_tracking_task import ScreenTargetTracking
from .rotation_matrices import *

Expand Down Expand Up @@ -214,6 +214,18 @@ class HandConstrainedSaccadeTask(ManualControlMixin, HandConstrainedEyeCapture):
'''
pass

class EyeHandSequenceTask(ManualControlMixin, EyeHandSequenceCapture):
'''
Saccade task while holding different targets by hand
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't the correct description i don't think.

'''
pass

class EyeConstrainedReachingTask(ManualControlMixin, EyeConstrainedHandCapture):
'''
Saccade and reaching task while holding different targets by eye and hand
'''
pass

class SaccadeTask(ManualControlMixin, ScreenTargetCapture_Saccade):
'''
Center out saccade task. The controller for the cursor is eye positions. The target color changes when subjects fixate the target.
Expand Down
56 changes: 54 additions & 2 deletions built_in_tasks/target_capture_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ class ScreenTargetCapture(TargetCapture, Window):
limit2d = traits.Bool(True, desc="Limit cursor movement to 2D")

sequence_generators = [
'out_2D', 'out_2D_select','centerout_2D', 'centeroutback_2D', 'centerout_2D_select', 'rand_target_chain_2D', 'rand_same_target_chain_2D',
'rand_target_chain_3D', 'corners_2D', 'centerout_tabletop',
'out_2D', 'out_2D_select', 'centerout_2D', 'centeroutback_2D', 'centerout_2D_select', 'rand_target_chain_2D', 'rand_same_target_chain_2D',
'rand_target_chain_3D', 'corners_2D', 'centerout_tabletop', 'out_2D_square', 'centerout_2D_square'
]

hidden_traits = ['cursor_color', 'target_color', 'cursor_bounds', 'cursor_radius', 'plant_hide_rate', 'starting_pos']
Expand Down Expand Up @@ -586,6 +586,39 @@ def out_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)):
]).T
yield [idx], [pos + origin]

@staticmethod
def out_2D_square(nblocks=100, width=10, height=10, origin=(0,0,0)):
'''
Generates a sequence of 2D (x and z) targets at a point on the side of the square
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you be a bit more descriptive? How many positions will there be? What exactly does it mean to be a point on the side of the square?

'''
ntargets = 8
rng = np.random.default_rng()
for _ in range(nblocks):
order = np.arange(ntargets) + 1 # target indices, starting from 1
rng.shuffle(order)

for t in range(ntargets):
idx = order[t]

if idx == 1:
pos = np.array([0,0,height/2]).T
elif idx == 2:
pos = np.array([width/2,0,height/2]).T
elif idx == 3:
pos = np.array([width/2,0,0]).T
elif idx == 4:
pos = np.array([width/2,0,-height/2]).T
elif idx == 5:
pos = np.array([0,0,-height/2]).T
elif idx == 6:
pos = np.array([-width/2,0,-height/2]).T
elif idx == 7:
pos = np.array([-width/2,0,0]).T
elif idx == 8:
pos = np.array([-width/2,0,height/2]).T

yield [idx], [pos + origin]

@staticmethod
def centerout_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)):
'''
Expand Down Expand Up @@ -638,6 +671,25 @@ def out_2D_select(nblocks=100, ntargets=8, distance=10, origin=(0,0,0), target_i
except StopIteration:
break

@staticmethod
def centerout_2D_square(nblocks=100, width=10, height=10, origin=(0,0,0)):
'''
Pairs of central targets at the origin and peripheral targets centered around the origin

Returns
-------
[nblocks*ntargets x 1] array of tuples containing trial indices and [2 x 3] target coordinates
'''
ntargets = 8
gen = ScreenTargetCapture.out_2D_square(nblocks, width, height, origin)
for _ in range(nblocks*ntargets):
idx, pos = next(gen)
targs = np.zeros([2, 3]) + origin
targs[1,:] = pos[0]
indices = np.zeros([2,1])
indices[1] = idx
yield indices, targs

@staticmethod
def centeroutback_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)):
'''
Expand Down
Loading