diff --git a/src/visiomode/core.py b/src/visiomode/core.py index 4a3008a..dda7b42 100644 --- a/src/visiomode/core.py +++ b/src/visiomode/core.py @@ -205,15 +205,15 @@ def request_listener(self): ) conf.Config().save() - task = tasks.get_task(request["data"].pop("task")) + task_cls = tasks.get_task(request["data"].pop("task")) + task = task_cls(screen=self.screen, **request["data"]) self.session = models.Session( animal_id=request["data"].pop("animal_id"), experimenter_name=request["data"].pop("experimenter_name"), experiment=request["data"].pop("experiment"), duration=float(request["data"].pop("duration")), timestamp=datetime.datetime.now().isoformat(), - task=task(screen=self.screen, **request["data"]), - spec=request["data"], + task=task, ) self.session.task.start() elif request["type"] == "status": diff --git a/src/visiomode/models.py b/src/visiomode/models.py index c0dd70f..7fe9c4c 100644 --- a/src/visiomode/models.py +++ b/src/visiomode/models.py @@ -140,6 +140,7 @@ def __post_init__(self): self.device = socket.gethostname() if not self.device else self.device self.animal_meta = {} if not self.animal_meta else self.animal_meta self.experimenter_meta = {} if not self.experimenter_meta else self.experimenter_meta + self.spec = self.task.get_spec if self.task else {} def to_dict(self): """Get class instance attributes as a dictionary. @@ -349,7 +350,7 @@ def delete_experimenter(cls, experimenter_name: str) -> None: list(map(operator.itemgetter("experimenter_name"), experimenters)).index(experimenter_name) ) except ValueError: - logging.error(f"Tried removing '{experimenter_name}' from database " f"but it was not in it.") + logging.error(f"Tried removing '{experimenter_name}' from database but it was not in it.") with open(database_path, "w") as handle: json.dump(experimenters, handle) diff --git a/src/visiomode/stimuli/__init__.py b/src/visiomode/stimuli/__init__.py index 473a50e..d3a933a 100644 --- a/src/visiomode/stimuli/__init__.py +++ b/src/visiomode/stimuli/__init__.py @@ -89,7 +89,14 @@ def set_centerx(self, centerx): def get_details(self): """Returns a dictionary of stimulus attributes.""" - return {"id": self.get_identifier(), "common_name": self.get_common_name()} + return { + "id": self.get_identifier(), + "common_name": self.get_common_name(), + "width": self.width, + "height": self.height, + "center_x": self.rect.centerx, + "center_y": self.rect.centery, + } def generate_new_trial(self): """Regenerate stimuli for a fresh trial""" diff --git a/src/visiomode/stimuli/grating.py b/src/visiomode/stimuli/grating.py index 347fb62..0d0a9f0 100644 --- a/src/visiomode/stimuli/grating.py +++ b/src/visiomode/stimuli/grating.py @@ -21,6 +21,19 @@ def __init__(self, background, period=30, contrast=1.0, **kwargs): self.rect = self.image.get_rect() self.area = self.screen.get_rect() + def get_details(self): + """Returns a dictionary of stimulus attributes.""" + return { + "id": self.get_identifier(), + "common_name": self.get_common_name(), + "width": self.width, + "height": self.height, + "center_x": self.rect.centerx, + "center_y": self.rect.centery, + "contrast": self.contrast, + "period": self.period, + } + @classmethod def sinusoid(cls, width: int, height: int, period: int, contrast: float = 1.0): sinusoid = Grating._sinusoid(width, height, period) diff --git a/src/visiomode/stimuli/moving_grating.py b/src/visiomode/stimuli/moving_grating.py index d88c105..b808096 100644 --- a/src/visiomode/stimuli/moving_grating.py +++ b/src/visiomode/stimuli/moving_grating.py @@ -44,3 +44,19 @@ def update(self): # self.rect.move_ip(0, self.px_per_cycle) self.draw() + + def get_details(self): + """Returns a dictionary of stimulus attributes.""" + return { + "id": self.get_identifier(), + "common_name": self.get_common_name(), + "width": self.width, + "height": self.height, + "center_x": self.rect.centerx, + "center_y": self.rect.centery, + "contrast": self.contrast, + "period": self.period, + "frequency": self.frequency, + "drift_direction": "upwards" if self.direction > 0 else "downwards", + "velocity_px": self.px_per_cycle, + } diff --git a/src/visiomode/stimuli/solid_colour.py b/src/visiomode/stimuli/solid_colour.py index 41f666d..0967f88 100644 --- a/src/visiomode/stimuli/solid_colour.py +++ b/src/visiomode/stimuli/solid_colour.py @@ -12,9 +12,22 @@ class SolidColour(stimuli.Stimulus): def __init__(self, background, colour, **kwargs): super().__init__(background, **kwargs) + self.colour = colour rgb = pg.Color(colour) self.image = pg.Surface((self.width, self.height)) self.image.fill(rgb) self.rect = self.image.get_rect() self.area = self.screen.get_rect() + + def get_details(self): + """Returns a dictionary of stimulus attributes.""" + return { + "id": self.get_identifier(), + "common_name": self.get_common_name(), + "width": self.width, + "height": self.height, + "center_x": self.rect.centerx, + "center_y": self.rect.centery, + "colour": self.colour, + } diff --git a/src/visiomode/stimuli/variable_contrast_grating.py b/src/visiomode/stimuli/variable_contrast_grating.py index 908f228..588a3c4 100644 --- a/src/visiomode/stimuli/variable_contrast_grating.py +++ b/src/visiomode/stimuli/variable_contrast_grating.py @@ -30,8 +30,14 @@ def generate_new_trial(self): self.area = self.screen.get_rect() def get_details(self): + """Returns a dictionary of stimulus attributes.""" return { "id": self.get_identifier(), "common_name": self.get_common_name(), + "width": self.width, + "height": self.height, + "center_x": self.rect.centerx, + "center_y": self.rect.centery, "trial_contrast": self.trial_contrast, + "period": self.period, } diff --git a/src/visiomode/stimuli/variable_contrast_moving_grating.py b/src/visiomode/stimuli/variable_contrast_moving_grating.py index 1cf920f..70062bf 100644 --- a/src/visiomode/stimuli/variable_contrast_moving_grating.py +++ b/src/visiomode/stimuli/variable_contrast_moving_grating.py @@ -30,8 +30,17 @@ def generate_new_trial(self): self.area = self.screen.get_rect() def get_details(self): + """Returns a dictionary of stimulus attributes.""" return { "id": self.get_identifier(), "common_name": self.get_common_name(), + "width": self.width, + "height": self.height, + "center_x": self.rect.centerx, + "center_y": self.rect.centery, "trial_contrast": self.trial_contrast, + "period": self.period, + "frequency": self.frequency, + "drift_direction": "upwards" if self.direction > 0 else "downwards", + "velocity_px": self.px_per_cycle, } diff --git a/src/visiomode/tasks/__init__.py b/src/visiomode/tasks/__init__.py index c33c836..0e013e5 100644 --- a/src/visiomode/tasks/__init__.py +++ b/src/visiomode/tasks/__init__.py @@ -68,8 +68,12 @@ def __init__( self.distractor = None self.separator = None + self.response_address = response_address + self.response_profile = response_device self.response_device = devices.get_input_device(response_device, response_address) + self.reward_address = reward_address + self.reward_profile = reward_profile self.reward_device = devices.get_output_profile(reward_profile, reward_address) self._response_q = queue.Queue() @@ -265,6 +269,16 @@ def on_precued(self): self.response_device.on_precued() self.reward_device.on_precued() + def get_spec(self): + """Return task specification as a dictionary.""" + return { + "iti": self.iti, + "stimulus_duration": self.stimulus_duration, + "response_profile": self.response_profile, + "response_address": self.response_address, + "corrections_enabled": self.corrections_enabled, + } + def _session_runner(self): self.on_task_start() while self.is_running: diff --git a/src/visiomode/tasks/gonogo.py b/src/visiomode/tasks/gonogo.py index 0122041..d3313cd 100644 --- a/src/visiomode/tasks/gonogo.py +++ b/src/visiomode/tasks/gonogo.py @@ -53,6 +53,18 @@ def update_stimulus(self): def get_random_stimulus(self): return random.choice([self.target, self.distractor]) # noqa: S311 + def get_spec(self): + """Return task specification as a dictionary.""" + return { + "iti": self.iti, + "stimulus_duration": self.stimulus_duration, + "response_profile": self.response_profile, + "response_address": self.response_address, + "corrections_enabled": self.corrections_enabled, + "target_stimulus": self.target.get_identifier(), + "distractor_stimulus": self.distractor.get_identifier(), + } + @classmethod def get_common_name(cls): return "Go / NoGo" diff --git a/src/visiomode/tasks/tafc.py b/src/visiomode/tasks/tafc.py index 2e197c4..df59471 100644 --- a/src/visiomode/tasks/tafc.py +++ b/src/visiomode/tasks/tafc.py @@ -61,3 +61,15 @@ def shuffle_centerx(self): self.screen.get_width() + (self.separator_size / 2), ] return random.sample(centers, 2) + + def get_spec(self): + """Return task specification as a dictionary.""" + return { + "iti": self.iti, + "stimulus_duration": self.stimulus_duration, + "response_profile": self.response_profile, + "response_address": self.response_address, + "corrections_enabled": self.corrections_enabled, + "target_stimulus": self.target.get_identifier(), + "distractor_stimulus": self.distractor.get_identifier(), + } diff --git a/src/visiomode/tasks/target_only.py b/src/visiomode/tasks/target_only.py index 846a8ed..725ef20 100644 --- a/src/visiomode/tasks/target_only.py +++ b/src/visiomode/tasks/target_only.py @@ -30,3 +30,14 @@ def show_stimulus(self): def hide_stimulus(self): self.target.hide() + + def get_spec(self): + """Return task specification as a dictionary.""" + return { + "iti": self.iti, + "stimulus_duration": self.stimulus_duration, + "response_profile": self.response_profile, + "response_address": self.response_address, + "corrections_enabled": self.corrections_enabled, + "target_stimulus": self.target.get_identifier(), + } diff --git a/src/visiomode/webpanel/api.py b/src/visiomode/webpanel/api.py index ebaefa0..d82ea2c 100644 --- a/src/visiomode/webpanel/api.py +++ b/src/visiomode/webpanel/api.py @@ -110,7 +110,7 @@ def get(self): "fname": session_file.split(os.sep)[-1], "animal_id": session["animal_id"], "date": session["timestamp"], - "task": session["task"], + "task": session.get("task"), "experiment": session["experiment"], "session_id": pathlib.Path(session_file).stem, } diff --git a/src/visiomode/webpanel/export.py b/src/visiomode/webpanel/export.py index a29840e..37da211 100644 --- a/src/visiomode/webpanel/export.py +++ b/src/visiomode/webpanel/export.py @@ -103,7 +103,7 @@ def _flatten_trials(session): for trial in session.get("trials"): start_time = (datetime.fromisoformat(trial["timestamp"]) - session_start_time).total_seconds() - stop_time = start_time + trial["iti"] + float(session["spec"]["stimulus_duration"]) / 1000 + stop_time = start_time + trial["iti"] + float(session["spec"].get("stimulus_duration", 10000)) / 1000 if trial["response"].get("timestamp"): stop_time = (datetime.fromisoformat(trial["response"]["timestamp"]) - session_start_time).total_seconds()