Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/visiomode/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,15 @@ def request_listener(self):
)
conf.Config().save()

task = tasks.get_task(request["data"].pop("task"))
task_cls = tasks.get_task(request["data"].pop("task"))
task = task_cls(screen=self.screen, **request["data"])
self.session = models.Session(
animal_id=request["data"].pop("animal_id"),
experimenter_name=request["data"].pop("experimenter_name"),
experiment=request["data"].pop("experiment"),
duration=float(request["data"].pop("duration")),
timestamp=datetime.datetime.now().isoformat(),
task=task(screen=self.screen, **request["data"]),
spec=request["data"],
task=task,
)
self.session.task.start()
elif request["type"] == "status":
Expand Down
3 changes: 2 additions & 1 deletion src/visiomode/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __post_init__(self):
self.device = socket.gethostname() if not self.device else self.device
self.animal_meta = {} if not self.animal_meta else self.animal_meta
self.experimenter_meta = {} if not self.experimenter_meta else self.experimenter_meta
self.spec = self.task.get_spec if self.task else {}

def to_dict(self):
"""Get class instance attributes as a dictionary.
Expand Down Expand Up @@ -349,7 +350,7 @@ def delete_experimenter(cls, experimenter_name: str) -> None:
list(map(operator.itemgetter("experimenter_name"), experimenters)).index(experimenter_name)
)
except ValueError:
logging.error(f"Tried removing '{experimenter_name}' from database " f"but it was not in it.")
logging.error(f"Tried removing '{experimenter_name}' from database but it was not in it.")

with open(database_path, "w") as handle:
json.dump(experimenters, handle)
9 changes: 8 additions & 1 deletion src/visiomode/stimuli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ def set_centerx(self, centerx):

def get_details(self):
"""Returns a dictionary of stimulus attributes."""
return {"id": self.get_identifier(), "common_name": self.get_common_name()}
return {
"id": self.get_identifier(),
"common_name": self.get_common_name(),
"width": self.width,
"height": self.height,
"center_x": self.rect.centerx,
"center_y": self.rect.centery,
}

def generate_new_trial(self):
"""Regenerate stimuli for a fresh trial"""
13 changes: 13 additions & 0 deletions src/visiomode/stimuli/grating.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ def __init__(self, background, period=30, contrast=1.0, **kwargs):
self.rect = self.image.get_rect()
self.area = self.screen.get_rect()

def get_details(self):
"""Returns a dictionary of stimulus attributes."""
return {
"id": self.get_identifier(),
"common_name": self.get_common_name(),
"width": self.width,
"height": self.height,
"center_x": self.rect.centerx,
"center_y": self.rect.centery,
"contrast": self.contrast,
"period": self.period,
}

@classmethod
def sinusoid(cls, width: int, height: int, period: int, contrast: float = 1.0):
sinusoid = Grating._sinusoid(width, height, period)
Expand Down
16 changes: 16 additions & 0 deletions src/visiomode/stimuli/moving_grating.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,19 @@ def update(self):
# self.rect.move_ip(0, self.px_per_cycle)

self.draw()

def get_details(self):
"""Returns a dictionary of stimulus attributes."""
return {
"id": self.get_identifier(),
"common_name": self.get_common_name(),
"width": self.width,
"height": self.height,
"center_x": self.rect.centerx,
"center_y": self.rect.centery,
"contrast": self.contrast,
"period": self.period,
"frequency": self.frequency,
"drift_direction": "upwards" if self.direction > 0 else "downwards",
"velocity_px": self.px_per_cycle,
}
13 changes: 13 additions & 0 deletions src/visiomode/stimuli/solid_colour.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,22 @@ class SolidColour(stimuli.Stimulus):

def __init__(self, background, colour, **kwargs):
super().__init__(background, **kwargs)
self.colour = colour
rgb = pg.Color(colour)

self.image = pg.Surface((self.width, self.height))
self.image.fill(rgb)
self.rect = self.image.get_rect()
self.area = self.screen.get_rect()

def get_details(self):
"""Returns a dictionary of stimulus attributes."""
return {
"id": self.get_identifier(),
"common_name": self.get_common_name(),
"width": self.width,
"height": self.height,
"center_x": self.rect.centerx,
"center_y": self.rect.centery,
"colour": self.colour,
}
6 changes: 6 additions & 0 deletions src/visiomode/stimuli/variable_contrast_grating.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ def generate_new_trial(self):
self.area = self.screen.get_rect()

def get_details(self):
"""Returns a dictionary of stimulus attributes."""
return {
"id": self.get_identifier(),
"common_name": self.get_common_name(),
"width": self.width,
"height": self.height,
"center_x": self.rect.centerx,
"center_y": self.rect.centery,
"trial_contrast": self.trial_contrast,
"period": self.period,
}
9 changes: 9 additions & 0 deletions src/visiomode/stimuli/variable_contrast_moving_grating.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,17 @@ def generate_new_trial(self):
self.area = self.screen.get_rect()

def get_details(self):
"""Returns a dictionary of stimulus attributes."""
return {
"id": self.get_identifier(),
"common_name": self.get_common_name(),
"width": self.width,
"height": self.height,
"center_x": self.rect.centerx,
"center_y": self.rect.centery,
"trial_contrast": self.trial_contrast,
"period": self.period,
"frequency": self.frequency,
"drift_direction": "upwards" if self.direction > 0 else "downwards",
"velocity_px": self.px_per_cycle,
}
14 changes: 14 additions & 0 deletions src/visiomode/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ def __init__(
self.distractor = None
self.separator = None

self.response_address = response_address
self.response_profile = response_device
self.response_device = devices.get_input_device(response_device, response_address)

self.reward_address = reward_address
self.reward_profile = reward_profile
self.reward_device = devices.get_output_profile(reward_profile, reward_address)

self._response_q = queue.Queue()
Expand Down Expand Up @@ -265,6 +269,16 @@ def on_precued(self):
self.response_device.on_precued()
self.reward_device.on_precued()

def get_spec(self):
"""Return task specification as a dictionary."""
return {
"iti": self.iti,
"stimulus_duration": self.stimulus_duration,
"response_profile": self.response_profile,
"response_address": self.response_address,
"corrections_enabled": self.corrections_enabled,
}

def _session_runner(self):
self.on_task_start()
while self.is_running:
Expand Down
12 changes: 12 additions & 0 deletions src/visiomode/tasks/gonogo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def update_stimulus(self):
def get_random_stimulus(self):
return random.choice([self.target, self.distractor]) # noqa: S311

def get_spec(self):
"""Return task specification as a dictionary."""
return {
"iti": self.iti,
"stimulus_duration": self.stimulus_duration,
"response_profile": self.response_profile,
"response_address": self.response_address,
"corrections_enabled": self.corrections_enabled,
"target_stimulus": self.target.get_identifier(),
"distractor_stimulus": self.distractor.get_identifier(),
}

@classmethod
def get_common_name(cls):
return "Go / NoGo"
12 changes: 12 additions & 0 deletions src/visiomode/tasks/tafc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,15 @@ def shuffle_centerx(self):
self.screen.get_width() + (self.separator_size / 2),
]
return random.sample(centers, 2)

def get_spec(self):
"""Return task specification as a dictionary."""
return {
"iti": self.iti,
"stimulus_duration": self.stimulus_duration,
"response_profile": self.response_profile,
"response_address": self.response_address,
"corrections_enabled": self.corrections_enabled,
"target_stimulus": self.target.get_identifier(),
"distractor_stimulus": self.distractor.get_identifier(),
}
11 changes: 11 additions & 0 deletions src/visiomode/tasks/target_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,14 @@ def show_stimulus(self):

def hide_stimulus(self):
self.target.hide()

def get_spec(self):
"""Return task specification as a dictionary."""
return {
"iti": self.iti,
"stimulus_duration": self.stimulus_duration,
"response_profile": self.response_profile,
"response_address": self.response_address,
"corrections_enabled": self.corrections_enabled,
"target_stimulus": self.target.get_identifier(),
}
2 changes: 1 addition & 1 deletion src/visiomode/webpanel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get(self):
"fname": session_file.split(os.sep)[-1],
"animal_id": session["animal_id"],
"date": session["timestamp"],
"task": session["task"],
"task": session.get("task"),
"experiment": session["experiment"],
"session_id": pathlib.Path(session_file).stem,
}
Expand Down
2 changes: 1 addition & 1 deletion src/visiomode/webpanel/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def to_nwb(session_path):
# Delayed import to save on startup time
import pynwb

Check failure on line 28 in src/visiomode/webpanel/export.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLC0415)

src/visiomode/webpanel/export.py:28:5: PLC0415 `import` should be at the top-level of a file

Check failure on line 28 in src/visiomode/webpanel/export.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (PLC0415)

src/visiomode/webpanel/export.py:28:5: PLC0415 `import` should be at the top-level of a file

with open(session_path) as f:
session = json.load(f)
Expand Down Expand Up @@ -103,7 +103,7 @@
for trial in session.get("trials"):
start_time = (datetime.fromisoformat(trial["timestamp"]) - session_start_time).total_seconds()

stop_time = start_time + trial["iti"] + float(session["spec"]["stimulus_duration"]) / 1000
stop_time = start_time + trial["iti"] + float(session["spec"].get("stimulus_duration", 10000)) / 1000
if trial["response"].get("timestamp"):
stop_time = (datetime.fromisoformat(trial["response"]["timestamp"]) - session_start_time).total_seconds()

Expand Down
Loading