diff --git a/pyproject.toml b/pyproject.toml index 57b37ef6..8df75d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "Werkzeug>=2.2,<3.0", "pyserial>=3.5", "numpy>=1.22", + "pydantic>=2.7", "pygame>=2.5.1", "pynwb>=2.2.0", ] diff --git a/src/visiomode/core.py b/src/visiomode/core.py index 975fa16c..c357c22e 100644 --- a/src/visiomode/core.py +++ b/src/visiomode/core.py @@ -2,17 +2,19 @@ # This file is part of visiomode. # Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree # Distributed under the terms of the MIT Licence. import os import logging import time -import datetime import threading import queue +import json import pkg_resources import pygame as pg import visiomode.config as conf +import visiomode.database as database import visiomode.models as models import visiomode.webpanel as webpanel import visiomode.protocols as protocols @@ -142,6 +144,7 @@ def run_main(self): protocol is stopped. If the application receives a quit event, the session is saved and the application exits. """ + session_database = database.get_database(models.Session) while True: if self.session: self.session.protocol.update() @@ -155,7 +158,7 @@ def run_main(self): self.session.protocol.stop() self.session.complete = True self.session.trials = self.session.protocol.trials - self.session.save(self.config.data_dir) + session_database.save_entry(self.session) self.session = None pg.event.clear() # Clear unused events so queue doesn't fill up @@ -163,7 +166,7 @@ def run_main(self): if pg.event.get(eventtype=pg.QUIT): if self.session: self.session.trials = self.session.protocol.trials - self.session.save(self.config.data_dir) + session_database.save_entry(self.session) return pg.display.flip() @@ -187,10 +190,7 @@ def request_listener(self): if request["type"] == "start": protocol = protocols.get_protocol(request["data"].pop("protocol")) self.session = models.Session( - animal_id=request["data"].pop("animal_id"), - experiment=request["data"].pop("experiment"), - duration=float(request["data"].pop("duration")), - timestamp=datetime.datetime.now().isoformat(), + **request["data"], protocol=protocol(screen=self.screen, **request["data"]), spec=request["data"], ) @@ -199,7 +199,11 @@ def request_listener(self): self.log_q.put( { "status": "active" if self.session else "inactive", - "data": self.session.to_json() if self.session else [], + "data": ( + json.loads(self.session.model_dump_json()) + if self.session + else [] + ), } ) elif request["type"] == "stop": diff --git a/src/visiomode/database/CorruptedDatabaseError.py b/src/visiomode/database/CorruptedDatabaseError.py new file mode 100755 index 00000000..53abad2f --- /dev/null +++ b/src/visiomode/database/CorruptedDatabaseError.py @@ -0,0 +1,11 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + + +class CorruptedDatabaseError(Exception): + def __init__(self, database_type: str): + super().__init__( + f"Parsing of '{database_type}' database failed. " + f"It is most likely corrupted." + ) diff --git a/src/visiomode/database/Database.py b/src/visiomode/database/Database.py new file mode 100755 index 00000000..d0275c21 --- /dev/null +++ b/src/visiomode/database/Database.py @@ -0,0 +1,185 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import json +import os +import typing + +import pydantic + +from visiomode.config import Config +from visiomode.database.CorruptedDatabaseError import CorruptedDatabaseError +from visiomode.database.WrongDataTypeError import WrongDataTypeError + + +class Database: + _database_entry_type: type[pydantic.BaseModel] + _database_file: str + _database_path: str + + _in_memory_contents: typing.Optional[ + dict[str, typing.Union[str, list[dict]]] + ] = None + + def __init__( + self, + database_entry_type: type[pydantic.BaseModel], + database_file: str, + default_uniqueness_key: str, + ) -> None: + self._database_entry_type = database_entry_type + self._database_file = database_file + + self._initialise(default_uniqueness_key) + + @property + def uniqueness_key(self) -> typing.Optional[str]: + if self._in_memory_contents is not None: + return self._in_memory_contents["uniqueness_key"] + return None + + @property + def entries(self) -> typing.Optional[list[dict]]: + if self._in_memory_contents is not None: + return self._in_memory_contents["entries"] + return None + + def save_entry(self, entry: pydantic.BaseModel) -> None: + # Check entry data type is compatible with database (i.e. exact type) + if type(entry) is not self._database_entry_type: + raise WrongDataTypeError(str(type(entry)), str(self._database_entry_type)) + + # Check the entry does have the uniqueness key as an attribute + entry_unique_value = getattr(entry, self.uniqueness_key, None) + if entry_unique_value is None: + raise ValueError( + f"Invalid entry provided. It does not have a value for uniqueness key" + f"'{self.uniqueness_key}'." + ) + serialised_entry = self.serialise_model(entry) + + # Append entry to database or replace matching entry based on uniqueness key + replaced_entry = False + for database_index, database_entry in enumerate(self.entries): + if entry_unique_value == database_entry[self.uniqueness_key]: + self.entries[database_index] = serialised_entry + replaced_entry = True + break + + if not replaced_entry: + self.entries.append(serialised_entry) + + self.dump_database() + + def get_entry(self, entry_id: str) -> typing.Optional[dict]: + for database_entry in self.entries: + if database_entry[self.uniqueness_key] == entry_id: + print(database_entry) + return database_entry + + return None + + def get_entries(self) -> list[dict]: + return self.entries + + def get_filtered_entries(self, filter_key: str, filter_value: str) -> list[dict]: + filtered_entries = [] + + for entry in self.entries: + if entry.get(filter_key) == filter_value: + filtered_entries.append(entry) + + return filtered_entries + + def delete_entry(self, entry_id: str) -> None: + deleted_an_entry = False + for database_index, database_entry in enumerate(self.entries): + if database_entry[self.uniqueness_key] == entry_id: + self.entries.pop(database_index) + deleted_an_entry = True + break + + if deleted_an_entry: + self.dump_database() + + def dump_database(self) -> None: + with open(self._database_path, "w") as database_handle: + json.dump(self._in_memory_contents, database_handle) + + def update_uniqueness_key(self, new_key: str) -> None: + self.validate_database(new_key) + self._in_memory_contents["uniqueness_key"] = new_key + + self.dump_database() + + def reload_database(self): + self.dump_database() + + self._initialise(self.uniqueness_key) + + def validate_database( + self, validation_uniqueness_key: typing.Optional[str] = None + ) -> None: + if validation_uniqueness_key is None: + if self._in_memory_contents.get("uniqueness_key") is None or not isinstance( + self.uniqueness_key, str + ): + raise CorruptedDatabaseError(str(self._database_entry_type)) + validation_uniqueness_key = self.uniqueness_key + + values = [] + + database_entries = self._in_memory_contents.get("entries") + if database_entries is None or not isinstance(database_entries, list): + raise CorruptedDatabaseError(str(self._database_entry_type)) + + for database_entry in database_entries: + database_entry_unique_value = database_entry.get(validation_uniqueness_key) + if database_entry_unique_value is None: + raise KeyError( + f"Validation of the database failed with key '{validation_uniqueness_key}'. " + f"At least one entry does not have a value for the given key." + ) + + if database_entry_unique_value in values: + raise KeyError( + f"Validation of the database failed with key '{validation_uniqueness_key}'. " + f"At least one duplicated entry was found." + ) + + values.append(database_entry_unique_value) + + def _initialise(self, default_uniqueness_key: typing.Optional[str] = None): + # Retrieve up-to-date path based on current config + self._database_path = f"{Config().db_dir}{os.sep}{self._database_file}" + + if os.path.exists(self._database_path): + try: + self._load_database() + except json.JSONDecodeError: + raise CorruptedDatabaseError(str(self._database_entry_type)) + except KeyError as key_error: + raise Exception( + "Initialisation of the database from disk failed." + ) from key_error + else: + if default_uniqueness_key is None: + raise KeyError( + "Cannot initialise database without either a default uniqueness key" + " or an already initialised database on disk." + ) + self._in_memory_contents = { + "uniqueness_key": default_uniqueness_key, + "entries": [], + } + self.dump_database() + + def _load_database(self) -> None: + with open(self._database_path, "r") as database_handle: + self._in_memory_contents = json.load(database_handle) + self.validate_database() + + @staticmethod + def serialise_model(model: pydantic.BaseModel) -> dict: + return json.loads(model.model_dump_json()) diff --git a/src/visiomode/database/NoMatchingDatabaseError.py b/src/visiomode/database/NoMatchingDatabaseError.py new file mode 100755 index 00000000..8400a372 --- /dev/null +++ b/src/visiomode/database/NoMatchingDatabaseError.py @@ -0,0 +1,8 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + + +class NoMatchingDatabaseError(Exception): + def __init__(self, provided_type: str) -> None: + super().__init__(f"Cannot find a database for data type '{provided_type}'.") diff --git a/src/visiomode/database/WrongDataTypeError.py b/src/visiomode/database/WrongDataTypeError.py new file mode 100755 index 00000000..6f518fe8 --- /dev/null +++ b/src/visiomode/database/WrongDataTypeError.py @@ -0,0 +1,11 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + + +class WrongDataTypeError(Exception): + def __init__(self, provided_type: str, database_type: str) -> None: + super().__init__( + f"Provided data type is not compatible with database type " + f"('{provided_type}' vs '{database_type}')." + ) diff --git a/src/visiomode/database/__init__.py b/src/visiomode/database/__init__.py new file mode 100755 index 00000000..d38b4fc2 --- /dev/null +++ b/src/visiomode/database/__init__.py @@ -0,0 +1,41 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import typing + +import visiomode.models as models +from visiomode.database.CorruptedDatabaseError import CorruptedDatabaseError +from visiomode.database.Database import Database +from visiomode.database.NoMatchingDatabaseError import NoMatchingDatabaseError +from visiomode.database.WrongDataTypeError import WrongDataTypeError + + +# Defaults are presented in tuples of: +# - Database file +# - Database default uniqueness key +DATABASE_DEFAULTS = { + models.Animal: ("animals.json", "animal_id"), + models.Experimenter: ("experimenters.json", "experimenter_name"), + models.Session: ("sessions.json", "session_id"), +} +DatabaseSupported = typing.Union[ + type[models.Animal], type[models.Experimenter], type[models.Session] +] + + +_managed_databases = dict() + + +def get_database(database_item_type: DatabaseSupported) -> Database: + if database_item_type not in _managed_databases.keys(): + if database_item_type not in DATABASE_DEFAULTS.keys(): + raise NoMatchingDatabaseError(str(database_item_type)) + + database_defaults = DATABASE_DEFAULTS[database_item_type] + database = Database(database_item_type, *database_defaults) + _managed_databases[database_item_type] = database + else: + database = _managed_databases[database_item_type] + + return database diff --git a/src/visiomode/devices/lever_push.py b/src/visiomode/devices/lever_push.py index 2c8ff367..7ee034f7 100644 --- a/src/visiomode/devices/lever_push.py +++ b/src/visiomode/devices/lever_push.py @@ -32,7 +32,7 @@ def get_response(self): if not self._response_q.empty(): self._response_q.get() # Remove response from queue return models.Response( - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.datetime.now(), name="leverpush", pos_x=self.config.width / 2, pos_y=self.config.height / 2, diff --git a/src/visiomode/devices/touchscreen.py b/src/visiomode/devices/touchscreen.py index 299d1953..42893880 100644 --- a/src/visiomode/devices/touchscreen.py +++ b/src/visiomode/devices/touchscreen.py @@ -26,7 +26,7 @@ def get_response(self): dist_y = touch_event.dy * self.config.height name = "left" if pos_x >= (self.config.width / 2) else "right" return models.Response( - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.datetime.now(), name=name, pos_x=pos_x, pos_y=pos_y, diff --git a/src/visiomode/models.py b/src/visiomode/models.py deleted file mode 100644 index ccb557c0..00000000 --- a/src/visiomode/models.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Application data model classes.""" - -# This file is part of visiomode. -# Copyright (c) 2020 Constantinos Eleftheriou -# Distributed under the terms of the MIT Licence. -import os -import dataclasses -import datetime -import socket -import json -import typing -import copy - -from visiomode import __about__, config - - -cfg = config.Config() - - -@dataclasses.dataclass -class Base: - """Base model class.""" - - def to_dict(self): - """Get class instance attributes as a dictionary. - - Returns: - Dictionary with class instance attributes. - """ - return dataclasses.asdict(self) - - def to_json(self): - """Get class instance attributes as JSON. - - Returns: - JSON string with class instance attributes. - """ - return json.dumps(self.to_dict()) - - -@dataclasses.dataclass -class Response(Base): - """ - Attributes: - timestamp: String trial date and time (ISO format). Defaults to current date and time. - name: Response type identifier (e.g. left, right or lever). - pos_x: Float representing the touch position in the x-axis. - pos_y: Float representing the touch position in the y-axis. - dist_x: Float representing the distance travelled while touching the screen in the x-axis. - dist_y: Float representing the distance travelled while touching the screen in the y-axis. - """ - - timestamp: str - name: str - pos_x: float - pos_y: float - dist_x: float - dist_y: float - - -@dataclasses.dataclass -class Trial(Base): - """Trial model class. - - Attributes: - outcome: String descriptive of trial outcome, e.g. "correct", "incorrect", "no_response", "precued". - iti: Float representing the "silent" time before the stimulus is presented in milliseconds. - duration: Integer representing the duration of the touch in milliseconds. - pos_x: Float representing the touch position in the x-axis. - pos_y: Float representing the touch position in the y-axis. - dist_x: Float representing the distance travelled while touching the screen in the x-axis. - dist_y: Float representing the distance travelled while touching the screen in the y-axis. - timestamp: String trial date and time (ISO format). Defaults to current date and time. - correction: Boolean indicating whether or not trial is a correction trial. Defaults to False. - response_time: Integer representing the time between stimulus presentation and response in seconds. - sdt_type: Signal detection theory outcome classification (i.e. hit/miss/false_alarm/correct_rejection) - """ - - outcome: str - iti: float - response: Response - timestamp: str = datetime.datetime.now().isoformat() - correction: bool = False - response_time: int = 0 - stimulus: dict = dataclasses.field(default_factory=dict) - sdt_type: str = "NA" - - def __repr__(self): - return "".format(str(self.timestamp)) - - -@dataclasses.dataclass -class Session(Base): - """Session model class. - - Attributes: - animal_id: String representing the animal identifier. - experiment: A string holding the experiment identifier. - protocol: An instance of the Protocol class. - duration: Integer representing the session duration in minutes. - complete: Boolean value indicating whether or not a session was completed - timestamp: A string with the session start date and time (ISO format). Defaults to current date and time. - notes: String with additional session notes. Defaults to empty string - device: String hostname of the device running the session. Defaults to the hostname provided by the socket lib. - trials: A mutable list of session trials; each trial is an instance of the Trial dataclass. Automatically populated using protocol.trials after class instantiation. - animal_meta: A dictionary with animal metadata (see Animal class). Automatically populated using animal_id after class instantiation. - version: Visiomode version this was generated with. - """ - - animal_id: str - experiment: str - duration: float - protocol: None = None - spec: dict = None - complete: bool = False - timestamp: str = datetime.datetime.now().isoformat() - notes: str = "" - device: str = socket.gethostname() - trials: typing.List[Trial] = dataclasses.field(default_factory=list) - animal_meta: dict = None - version: str = __about__.__version__ - - def __post_init__(self): - self.animal_meta = Animal.get_animal(self.animal_id) - self.trials = self.protocol.trials - - def to_dict(self): - """Get class instance attributes as a dictionary. - - This method overrides the Base class to cast nested Trial objects under self.trials as dictionaries. - - Returns: - Dictionary with class instance attributes. - """ - instance = copy.copy(self) - instance.trials = [trial.to_dict() for trial in self.trials if self.trials] - instance.protocol = self.protocol.get_identifier() - return dataclasses.asdict(instance) - - def save(self, path): - """Save session to json file.""" - session_id = ( - "sub-" - + self.animal_id - + "_exp-" - + self.experiment - + "_date-" - + self.timestamp.replace(":", "").replace("-", "").replace(".", "") - ) - f_path = path + os.sep + session_id + ".json" - with open(f_path, "w", encoding="utf-8") as f: - json.dump(self.to_dict(), f) - - def __repr__(self): - return "".format(str(self.timestamp)) - - -@dataclasses.dataclass -class Animal(Base): - """Animal model class. - - Attributes: - animal_id: String representing the animal identifier. - date_of_birth: String representing the animal date of birth (ISO format). - sex: Character representing the animal's sex (M/F/U/O). - species: String representing the animal's species. Use the latin name, eg. Mus musculus. - genotype: String representing the animal's genotype. Defaults to empty string. - description: String with additional animal notes. Defaults to empty string. - rfid: String representing the animal's RFID tag. Defaults to empty string. - """ - - animal_id: str - date_of_birth: str - sex: str - species: str - genotype: str = "" - description: str = "" - rfid: str = "" - - def save(self): - """Append animal to json database file.""" - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - # If the animal already exists, remove it and append the new one - animals = [ - animal - for animal in animals - if not animal["animal_id"] == self.animal_id - ] - animals.append(self.to_dict()) - with open(path, "w") as f: - json.dump(animals, f) - else: - with open(path, "w") as f: - json.dump([self.to_dict()], f) - - @classmethod - def get_animal(cls, animal_id): - """Get an animal from the database based on its ID.""" - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - for animal in animals: - if animal["animal_id"] == animal_id: - return animal - return None - - @classmethod - def get_animals(cls): - """Get all animals stored in the database. - - Returns: - List of dictionaries with animal attributes. - """ - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - return animals - return [] - - @classmethod - def delete_animal(cls, animal_id): - """Delete animal from database.""" - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - # If the animal exists, remove it - animals = [ - animal for animal in animals if not animal["animal_id"] == animal_id - ] - with open(path, "w") as f: - json.dump(animals, f) diff --git a/src/visiomode/models/Animal.py b/src/visiomode/models/Animal.py new file mode 100755 index 00000000..5da41546 --- /dev/null +++ b/src/visiomode/models/Animal.py @@ -0,0 +1,20 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import typing + +import pydantic + + +class Animal(pydantic.BaseModel): + animal_id: str = pydantic.Field(repr=True) + date_of_birth: pydantic.PastDatetime = pydantic.Field(repr=False) + sex: typing.Literal["U", "M", "F", "O"] = pydantic.Field(repr=False) + species: typing.Literal[ + "Mus musculus", "Rattus norvegicus", "Other" + ] = pydantic.Field(repr=False) + description: str = pydantic.Field(default="", repr=False) + genotype: str = pydantic.Field(default="", repr=False) + rfid: str = pydantic.Field(default="", repr=False) diff --git a/src/visiomode/models/Experimenter.py b/src/visiomode/models/Experimenter.py new file mode 100755 index 00000000..752cd5cd --- /dev/null +++ b/src/visiomode/models/Experimenter.py @@ -0,0 +1,11 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import pydantic + + +class Experimenter(pydantic.BaseModel): + experimenter_name: str = pydantic.Field(repr=True) + laboratory_name: str = pydantic.Field(repr=False) + institution_name: str = pydantic.Field(repr=False) diff --git a/src/visiomode/models/Response.py b/src/visiomode/models/Response.py new file mode 100755 index 00000000..03d15915 --- /dev/null +++ b/src/visiomode/models/Response.py @@ -0,0 +1,21 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import typing + +import pydantic + + +class Response(pydantic.BaseModel, validate_assignment=True): + timestamp: pydantic.PastDatetime + name: typing.Literal["left", "right", "leverpush", "none"] + pos_x: pydantic.FiniteFloat + pos_y: pydantic.FiniteFloat + dist_x: pydantic.FiniteFloat + dist_y: pydantic.FiniteFloat + + @pydantic.field_serializer("timestamp") + def serialise_timestamp(self, timestamp: pydantic.PastDatetime) -> str: + return timestamp.isoformat() diff --git a/src/visiomode/models/Session.py b/src/visiomode/models/Session.py new file mode 100755 index 00000000..5a9007da --- /dev/null +++ b/src/visiomode/models/Session.py @@ -0,0 +1,57 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import datetime +import socket +import typing + +import pydantic + +import visiomode.__about__ as __about__ +import visiomode.database as database +import visiomode.protocols as protocols +import visiomode._models.Animal as Animal +import visiomode._models.Trial as Trial + + +class Session(pydantic.BaseModel, arbitrary_types_allowed=True): + session_id: str = pydantic.Field(default=None, repr=True) + animal_id: str = pydantic.Field(repr=False) + duration: pydantic.FiniteFloat = pydantic.Field(repr=False) + experiment: str = pydantic.Field(repr=False) + protocol: protocols.Protocol = pydantic.Field(repr=False) + animal_metadata: typing.Optional[dict] = pydantic.Field(default={}, repr=False) + device: str = pydantic.Field(default_factory=socket.gethostname, repr=False) + complete: pydantic.StrictBool = pydantic.Field(default=False, repr=False) + notes: str = pydantic.Field(default="", repr=False) + spec: typing.Optional[dict] = pydantic.Field(default=None, repr=False) + timestamp: pydantic.NaiveDatetime = pydantic.Field( + default_factory=datetime.datetime.now, repr=True + ) + trials: list[Trial.Trial] = pydantic.Field(default=[], repr=False) + version: str = pydantic.Field(default=__about__.__version__, repr=False) + + def model_post_init(self, __context: typing.Any) -> None: + self.session_id = self.generate_session_id() + # Don't really know why but somehow by this point, Animal has turned from the + # module into the class (at runtime, according to Pydantic)? Even Pycharm is + # confused. + self.animal_metadata = database.get_database(Animal).get_entry(self.animal_id) + self.trials = self.protocol.trials + + @pydantic.field_serializer("protocol") + def serialise_protocol(self, protocol: protocols.Protocol) -> str: + return protocol.get_identifier() + + @pydantic.field_serializer("timestamp") + def serialise_timestamp(self, timestamp: pydantic.NaiveDatetime) -> str: + return timestamp.isoformat() + + def generate_session_id(self) -> str: + timestamp = self.serialise_timestamp(self.timestamp) + return ( + f"sub-{self.animal_id}_exp-{self.experiment}_date-" + f"{timestamp.replace(':', '').replace('-', '').replace('.', '')}" + ) diff --git a/src/visiomode/models/Trial.py b/src/visiomode/models/Trial.py new file mode 100755 index 00000000..d63e74e2 --- /dev/null +++ b/src/visiomode/models/Trial.py @@ -0,0 +1,32 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import datetime +import typing + +import pydantic + +from visiomode._models.Response import Response + + +class Trial(pydantic.BaseModel): + timestamp: pydantic.NaiveDatetime = pydantic.Field( + default_factory=datetime.datetime.now, repr=True + ) + iti: pydantic.FiniteFloat = pydantic.Field(repr=False) + outcome: typing.Literal[ + "correct", "incorrect", "precued", "no_response" + ] = pydantic.Field(repr=False) + response: typing.Optional[Response] = pydantic.Field(repr=False) + correction: pydantic.StrictBool = pydantic.Field(default=False, repr=False) + response_time: pydantic.FiniteFloat = pydantic.Field(default=0.0, repr=False) + sdt_type: typing.Literal[ + "hit", "miss", "false_alarm", "correct_rejection", "NA" + ] = pydantic.Field(default="NA", repr=False) + stimulus: dict = pydantic.Field(default={}, repr=False) + + @pydantic.field_serializer("timestamp") + def serialise_timestamp(self, timestamp: pydantic.NaiveDatetime) -> str: + return timestamp.isoformat() diff --git a/src/visiomode/models/__init__.py b/src/visiomode/models/__init__.py new file mode 100755 index 00000000..de34f94e --- /dev/null +++ b/src/visiomode/models/__init__.py @@ -0,0 +1,9 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +from visiomode._models.Animal import Animal +from visiomode._models.Experimenter import Experimenter +from visiomode._models.Response import Response +from visiomode._models.Session import Session +from visiomode._models.Trial import Trial diff --git a/src/visiomode/protocols/__init__.py b/src/visiomode/protocols/__init__.py index b7622210..3efbfa40 100644 --- a/src/visiomode/protocols/__init__.py +++ b/src/visiomode/protocols/__init__.py @@ -1,5 +1,6 @@ # This file is part of visiomode. # Copyright (c) 2021 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree # Distributed under the terms of the MIT Licence. """Module that defines the available task and stimulation protocols in a stimulus agnostic manner.""" @@ -198,7 +199,7 @@ def trial_block(self): if not outcome: return - stimulus = "None" + stimulus = {"id": "None"} if outcome != PRECUED: if ( self.distractor @@ -252,12 +253,10 @@ def parse_trial( trial_start, outcome, response=None, - response_time=0, + response_time=0.0, sdt_type="NA", stimulus=None, ): - if not response: - response = {"name": "none"} trial = models.Trial( outcome=outcome, iti=self.iti, diff --git a/src/visiomode/webpanel/api.py b/src/visiomode/webpanel/api.py index 19399293..2a74e932 100644 --- a/src/visiomode/webpanel/api.py +++ b/src/visiomode/webpanel/api.py @@ -2,6 +2,7 @@ # This file is part of visiomode. # Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree # Distributed under the terms of the MIT Licence. import os import json @@ -11,6 +12,8 @@ import glob import flask import flask.views +import pydantic + import visiomode.config as cfg import visiomode.devices as devices import visiomode.protocols as protocols @@ -18,6 +21,7 @@ import visiomode.webpanel.export as export from visiomode.models import Animal +import visiomode.database as db class DeviceAPI(flask.views.MethodView): @@ -158,9 +162,11 @@ def post(self): class AnimalsAPI(flask.views.MethodView): """API for managing animal profiles.""" + database = db.get_database(Animal) + def get(self): """Get animal profiles.""" - return {"animals": Animal.get_animals()} + return {"animals": self.database.get_entries()} def post(self): request_type = flask.request.json.get("type") # add, delete, update @@ -169,20 +175,22 @@ def post(self): if request_type == "delete": animal_id = request.get("id") if animal_id: - Animal.delete_animal(animal_id) + self.database.delete_entry(animal_id) else: - animals = Animal.get_animals() + animals = self.database.get_entries() for animal in animals: - Animal.delete_animal(animal["animal_id"]) + self.database.delete_entry(animal["animal_id"]) elif (request_type == "update") or (request_type == "add"): - animal = Animal( - animal_id=request.get("id"), - date_of_birth=request.get("dob"), - sex=request.get("sex"), - species=request.get("species"), - genotype=request.get("genotype"), - description=request.get("description"), - rfid=request.get("rfid"), - ) - animal.save() + try: + animal = Animal(**request) + self.database.save_entry(animal) + except pydantic.ValidationError: + logging.error( + f"Could not validate animal model from request data:\n{request}" + ) + return ( + json.dumps({"success": False}), + 400, + {"ContentType": "application/json"}, + ) return json.dumps({"success": True}), 200, {"ContentType": "application/json"} diff --git a/src/visiomode/webpanel/static/js/session.js b/src/visiomode/webpanel/static/js/session.js index b3e522f2..96405384 100644 --- a/src/visiomode/webpanel/static/js/session.js +++ b/src/visiomode/webpanel/static/js/session.js @@ -297,8 +297,8 @@ function addAnimal() { data: JSON.stringify({ type: "add", data: { - id: animalId, - dob: animalDob, + animal_id: animalId, + date_of_birth: animalDob, sex: animalSex, species: animalSpecies, genotype: animalGenotype, diff --git a/src/visiomode/webpanel/static/js/settings-animals.js b/src/visiomode/webpanel/static/js/settings-animals.js index 848cfff9..84d9de6a 100644 --- a/src/visiomode/webpanel/static/js/settings-animals.js +++ b/src/visiomode/webpanel/static/js/settings-animals.js @@ -64,8 +64,8 @@ function updateAnimal() { data: JSON.stringify({ type: "update", data: { - id: animalId, - dob: animalDob, + animal_id: animalId, + date_of_birth: animalDob, sex: animalSex, species: animalSpecies, genotype: animalGenotype, diff --git a/src/visiomode/webpanel/static/js/settings.js b/src/visiomode/webpanel/static/js/settings.js index dee29622..baccbffe 100644 --- a/src/visiomode/webpanel/static/js/settings.js +++ b/src/visiomode/webpanel/static/js/settings.js @@ -123,8 +123,8 @@ function addAnimal() { data: JSON.stringify({ type: "add", data: { - id: animalId, - dob: animalDob, + animal_id: animalId, + date_of_birth: animalDob, sex: animalSex, species: animalSpecies, genotype: animalGenotype,