From 8ef389805fabdd5ae5802ecb60a0ed35411836d4 Mon Sep 17 00:00:00 2001 From: Olivier Delree Date: Mon, 10 Jun 2024 13:30:09 +0100 Subject: [PATCH 1/3] reimplement models using pydantic Models that we previously defined inside of `models.py` are now implemented using pydantic. This means they are validated at time of instantiation. As a result, it is possible to simply retrieve information form a POST request and feed it directly into the model constructor without needing to do manual validation. Since the models are now purely data containers, they are no longer responsible for storing themselves (e.g., `Animal`). A future commit will implement a `Database` class that is responsible for managing a database of the model type it was instantiated with. Note the web API and the models are not currently connected but will be once the refactor is a bit further along. Also note, the current `_models` directory is there to avoid name clashing with `models.py`. It will be changed in future commits once the refactor is complete. --- pyproject.toml | 1 + src/visiomode/_models/Animal.py | 20 ++++++++++ src/visiomode/_models/Experimenter.py | 11 ++++++ src/visiomode/_models/Response.py | 21 ++++++++++ src/visiomode/_models/Session.py | 57 +++++++++++++++++++++++++++ src/visiomode/_models/Trial.py | 32 +++++++++++++++ src/visiomode/_models/__init__.py | 9 +++++ 7 files changed, 151 insertions(+) create mode 100755 src/visiomode/_models/Animal.py create mode 100755 src/visiomode/_models/Experimenter.py create mode 100755 src/visiomode/_models/Response.py create mode 100755 src/visiomode/_models/Session.py create mode 100755 src/visiomode/_models/Trial.py create mode 100755 src/visiomode/_models/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 57b37ef6..8df75d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "Werkzeug>=2.2,<3.0", "pyserial>=3.5", "numpy>=1.22", + "pydantic>=2.7", "pygame>=2.5.1", "pynwb>=2.2.0", ] diff --git a/src/visiomode/_models/Animal.py b/src/visiomode/_models/Animal.py new file mode 100755 index 00000000..5da41546 --- /dev/null +++ b/src/visiomode/_models/Animal.py @@ -0,0 +1,20 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import typing + +import pydantic + + +class Animal(pydantic.BaseModel): + animal_id: str = pydantic.Field(repr=True) + date_of_birth: pydantic.PastDatetime = pydantic.Field(repr=False) + sex: typing.Literal["U", "M", "F", "O"] = pydantic.Field(repr=False) + species: typing.Literal[ + "Mus musculus", "Rattus norvegicus", "Other" + ] = pydantic.Field(repr=False) + description: str = pydantic.Field(default="", repr=False) + genotype: str = pydantic.Field(default="", repr=False) + rfid: str = pydantic.Field(default="", repr=False) diff --git a/src/visiomode/_models/Experimenter.py b/src/visiomode/_models/Experimenter.py new file mode 100755 index 00000000..752cd5cd --- /dev/null +++ b/src/visiomode/_models/Experimenter.py @@ -0,0 +1,11 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import pydantic + + +class Experimenter(pydantic.BaseModel): + experimenter_name: str = pydantic.Field(repr=True) + laboratory_name: str = pydantic.Field(repr=False) + institution_name: str = pydantic.Field(repr=False) diff --git a/src/visiomode/_models/Response.py b/src/visiomode/_models/Response.py new file mode 100755 index 00000000..03d15915 --- /dev/null +++ b/src/visiomode/_models/Response.py @@ -0,0 +1,21 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import typing + +import pydantic + + +class Response(pydantic.BaseModel, validate_assignment=True): + timestamp: pydantic.PastDatetime + name: typing.Literal["left", "right", "leverpush", "none"] + pos_x: pydantic.FiniteFloat + pos_y: pydantic.FiniteFloat + dist_x: pydantic.FiniteFloat + dist_y: pydantic.FiniteFloat + + @pydantic.field_serializer("timestamp") + def serialise_timestamp(self, timestamp: pydantic.PastDatetime) -> str: + return timestamp.isoformat() diff --git a/src/visiomode/_models/Session.py b/src/visiomode/_models/Session.py new file mode 100755 index 00000000..5a9007da --- /dev/null +++ b/src/visiomode/_models/Session.py @@ -0,0 +1,57 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import datetime +import socket +import typing + +import pydantic + +import visiomode.__about__ as __about__ +import visiomode.database as database +import visiomode.protocols as protocols +import visiomode._models.Animal as Animal +import visiomode._models.Trial as Trial + + +class Session(pydantic.BaseModel, arbitrary_types_allowed=True): + session_id: str = pydantic.Field(default=None, repr=True) + animal_id: str = pydantic.Field(repr=False) + duration: pydantic.FiniteFloat = pydantic.Field(repr=False) + experiment: str = pydantic.Field(repr=False) + protocol: protocols.Protocol = pydantic.Field(repr=False) + animal_metadata: typing.Optional[dict] = pydantic.Field(default={}, repr=False) + device: str = pydantic.Field(default_factory=socket.gethostname, repr=False) + complete: pydantic.StrictBool = pydantic.Field(default=False, repr=False) + notes: str = pydantic.Field(default="", repr=False) + spec: typing.Optional[dict] = pydantic.Field(default=None, repr=False) + timestamp: pydantic.NaiveDatetime = pydantic.Field( + default_factory=datetime.datetime.now, repr=True + ) + trials: list[Trial.Trial] = pydantic.Field(default=[], repr=False) + version: str = pydantic.Field(default=__about__.__version__, repr=False) + + def model_post_init(self, __context: typing.Any) -> None: + self.session_id = self.generate_session_id() + # Don't really know why but somehow by this point, Animal has turned from the + # module into the class (at runtime, according to Pydantic)? Even Pycharm is + # confused. + self.animal_metadata = database.get_database(Animal).get_entry(self.animal_id) + self.trials = self.protocol.trials + + @pydantic.field_serializer("protocol") + def serialise_protocol(self, protocol: protocols.Protocol) -> str: + return protocol.get_identifier() + + @pydantic.field_serializer("timestamp") + def serialise_timestamp(self, timestamp: pydantic.NaiveDatetime) -> str: + return timestamp.isoformat() + + def generate_session_id(self) -> str: + timestamp = self.serialise_timestamp(self.timestamp) + return ( + f"sub-{self.animal_id}_exp-{self.experiment}_date-" + f"{timestamp.replace(':', '').replace('-', '').replace('.', '')}" + ) diff --git a/src/visiomode/_models/Trial.py b/src/visiomode/_models/Trial.py new file mode 100755 index 00000000..d63e74e2 --- /dev/null +++ b/src/visiomode/_models/Trial.py @@ -0,0 +1,32 @@ +# This file is part of visiomode. +# Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import datetime +import typing + +import pydantic + +from visiomode._models.Response import Response + + +class Trial(pydantic.BaseModel): + timestamp: pydantic.NaiveDatetime = pydantic.Field( + default_factory=datetime.datetime.now, repr=True + ) + iti: pydantic.FiniteFloat = pydantic.Field(repr=False) + outcome: typing.Literal[ + "correct", "incorrect", "precued", "no_response" + ] = pydantic.Field(repr=False) + response: typing.Optional[Response] = pydantic.Field(repr=False) + correction: pydantic.StrictBool = pydantic.Field(default=False, repr=False) + response_time: pydantic.FiniteFloat = pydantic.Field(default=0.0, repr=False) + sdt_type: typing.Literal[ + "hit", "miss", "false_alarm", "correct_rejection", "NA" + ] = pydantic.Field(default="NA", repr=False) + stimulus: dict = pydantic.Field(default={}, repr=False) + + @pydantic.field_serializer("timestamp") + def serialise_timestamp(self, timestamp: pydantic.NaiveDatetime) -> str: + return timestamp.isoformat() diff --git a/src/visiomode/_models/__init__.py b/src/visiomode/_models/__init__.py new file mode 100755 index 00000000..de34f94e --- /dev/null +++ b/src/visiomode/_models/__init__.py @@ -0,0 +1,9 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +from visiomode._models.Animal import Animal +from visiomode._models.Experimenter import Experimenter +from visiomode._models.Response import Response +from visiomode._models.Session import Session +from visiomode._models.Trial import Trial From c16d64b019cd509bc69ed01bed4dd1a6ccfd2958 Mon Sep 17 00:00:00 2001 From: Olivier Delree Date: Mon, 10 Jun 2024 13:44:07 +0100 Subject: [PATCH 2/3] add `database` module for models management The new `database` module introduces a new class `Database` that is used to store a single type of model. For example, when the application requires a database for animals, it can simply call `database.get_database(Animal)`. This will return the appropriate database object for that model, and any subsequent calls will return the same object. These databases use a so-called "uniqueness-key" to ensure there are no duplicates added to the database. The functionality mirrors what was available from the `Animal` class of `models.py` with a few more bells and whistles because why not. --- .../database/CorruptedDatabaseError.py | 11 ++ src/visiomode/database/Database.py | 185 ++++++++++++++++++ .../database/NoMatchingDatabaseError.py | 8 + src/visiomode/database/WrongDataTypeError.py | 11 ++ src/visiomode/database/__init__.py | 41 ++++ 5 files changed, 256 insertions(+) create mode 100755 src/visiomode/database/CorruptedDatabaseError.py create mode 100755 src/visiomode/database/Database.py create mode 100755 src/visiomode/database/NoMatchingDatabaseError.py create mode 100755 src/visiomode/database/WrongDataTypeError.py create mode 100755 src/visiomode/database/__init__.py diff --git a/src/visiomode/database/CorruptedDatabaseError.py b/src/visiomode/database/CorruptedDatabaseError.py new file mode 100755 index 00000000..53abad2f --- /dev/null +++ b/src/visiomode/database/CorruptedDatabaseError.py @@ -0,0 +1,11 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + + +class CorruptedDatabaseError(Exception): + def __init__(self, database_type: str): + super().__init__( + f"Parsing of '{database_type}' database failed. " + f"It is most likely corrupted." + ) diff --git a/src/visiomode/database/Database.py b/src/visiomode/database/Database.py new file mode 100755 index 00000000..d0275c21 --- /dev/null +++ b/src/visiomode/database/Database.py @@ -0,0 +1,185 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import json +import os +import typing + +import pydantic + +from visiomode.config import Config +from visiomode.database.CorruptedDatabaseError import CorruptedDatabaseError +from visiomode.database.WrongDataTypeError import WrongDataTypeError + + +class Database: + _database_entry_type: type[pydantic.BaseModel] + _database_file: str + _database_path: str + + _in_memory_contents: typing.Optional[ + dict[str, typing.Union[str, list[dict]]] + ] = None + + def __init__( + self, + database_entry_type: type[pydantic.BaseModel], + database_file: str, + default_uniqueness_key: str, + ) -> None: + self._database_entry_type = database_entry_type + self._database_file = database_file + + self._initialise(default_uniqueness_key) + + @property + def uniqueness_key(self) -> typing.Optional[str]: + if self._in_memory_contents is not None: + return self._in_memory_contents["uniqueness_key"] + return None + + @property + def entries(self) -> typing.Optional[list[dict]]: + if self._in_memory_contents is not None: + return self._in_memory_contents["entries"] + return None + + def save_entry(self, entry: pydantic.BaseModel) -> None: + # Check entry data type is compatible with database (i.e. exact type) + if type(entry) is not self._database_entry_type: + raise WrongDataTypeError(str(type(entry)), str(self._database_entry_type)) + + # Check the entry does have the uniqueness key as an attribute + entry_unique_value = getattr(entry, self.uniqueness_key, None) + if entry_unique_value is None: + raise ValueError( + f"Invalid entry provided. It does not have a value for uniqueness key" + f"'{self.uniqueness_key}'." + ) + serialised_entry = self.serialise_model(entry) + + # Append entry to database or replace matching entry based on uniqueness key + replaced_entry = False + for database_index, database_entry in enumerate(self.entries): + if entry_unique_value == database_entry[self.uniqueness_key]: + self.entries[database_index] = serialised_entry + replaced_entry = True + break + + if not replaced_entry: + self.entries.append(serialised_entry) + + self.dump_database() + + def get_entry(self, entry_id: str) -> typing.Optional[dict]: + for database_entry in self.entries: + if database_entry[self.uniqueness_key] == entry_id: + print(database_entry) + return database_entry + + return None + + def get_entries(self) -> list[dict]: + return self.entries + + def get_filtered_entries(self, filter_key: str, filter_value: str) -> list[dict]: + filtered_entries = [] + + for entry in self.entries: + if entry.get(filter_key) == filter_value: + filtered_entries.append(entry) + + return filtered_entries + + def delete_entry(self, entry_id: str) -> None: + deleted_an_entry = False + for database_index, database_entry in enumerate(self.entries): + if database_entry[self.uniqueness_key] == entry_id: + self.entries.pop(database_index) + deleted_an_entry = True + break + + if deleted_an_entry: + self.dump_database() + + def dump_database(self) -> None: + with open(self._database_path, "w") as database_handle: + json.dump(self._in_memory_contents, database_handle) + + def update_uniqueness_key(self, new_key: str) -> None: + self.validate_database(new_key) + self._in_memory_contents["uniqueness_key"] = new_key + + self.dump_database() + + def reload_database(self): + self.dump_database() + + self._initialise(self.uniqueness_key) + + def validate_database( + self, validation_uniqueness_key: typing.Optional[str] = None + ) -> None: + if validation_uniqueness_key is None: + if self._in_memory_contents.get("uniqueness_key") is None or not isinstance( + self.uniqueness_key, str + ): + raise CorruptedDatabaseError(str(self._database_entry_type)) + validation_uniqueness_key = self.uniqueness_key + + values = [] + + database_entries = self._in_memory_contents.get("entries") + if database_entries is None or not isinstance(database_entries, list): + raise CorruptedDatabaseError(str(self._database_entry_type)) + + for database_entry in database_entries: + database_entry_unique_value = database_entry.get(validation_uniqueness_key) + if database_entry_unique_value is None: + raise KeyError( + f"Validation of the database failed with key '{validation_uniqueness_key}'. " + f"At least one entry does not have a value for the given key." + ) + + if database_entry_unique_value in values: + raise KeyError( + f"Validation of the database failed with key '{validation_uniqueness_key}'. " + f"At least one duplicated entry was found." + ) + + values.append(database_entry_unique_value) + + def _initialise(self, default_uniqueness_key: typing.Optional[str] = None): + # Retrieve up-to-date path based on current config + self._database_path = f"{Config().db_dir}{os.sep}{self._database_file}" + + if os.path.exists(self._database_path): + try: + self._load_database() + except json.JSONDecodeError: + raise CorruptedDatabaseError(str(self._database_entry_type)) + except KeyError as key_error: + raise Exception( + "Initialisation of the database from disk failed." + ) from key_error + else: + if default_uniqueness_key is None: + raise KeyError( + "Cannot initialise database without either a default uniqueness key" + " or an already initialised database on disk." + ) + self._in_memory_contents = { + "uniqueness_key": default_uniqueness_key, + "entries": [], + } + self.dump_database() + + def _load_database(self) -> None: + with open(self._database_path, "r") as database_handle: + self._in_memory_contents = json.load(database_handle) + self.validate_database() + + @staticmethod + def serialise_model(model: pydantic.BaseModel) -> dict: + return json.loads(model.model_dump_json()) diff --git a/src/visiomode/database/NoMatchingDatabaseError.py b/src/visiomode/database/NoMatchingDatabaseError.py new file mode 100755 index 00000000..8400a372 --- /dev/null +++ b/src/visiomode/database/NoMatchingDatabaseError.py @@ -0,0 +1,8 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + + +class NoMatchingDatabaseError(Exception): + def __init__(self, provided_type: str) -> None: + super().__init__(f"Cannot find a database for data type '{provided_type}'.") diff --git a/src/visiomode/database/WrongDataTypeError.py b/src/visiomode/database/WrongDataTypeError.py new file mode 100755 index 00000000..6f518fe8 --- /dev/null +++ b/src/visiomode/database/WrongDataTypeError.py @@ -0,0 +1,11 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + + +class WrongDataTypeError(Exception): + def __init__(self, provided_type: str, database_type: str) -> None: + super().__init__( + f"Provided data type is not compatible with database type " + f"('{provided_type}' vs '{database_type}')." + ) diff --git a/src/visiomode/database/__init__.py b/src/visiomode/database/__init__.py new file mode 100755 index 00000000..d38b4fc2 --- /dev/null +++ b/src/visiomode/database/__init__.py @@ -0,0 +1,41 @@ +# This file is part of visiomode. +# Copyright (c) 2024 Olivier Delree +# Distributed under the terms of the MIT Licence. + +import typing + +import visiomode.models as models +from visiomode.database.CorruptedDatabaseError import CorruptedDatabaseError +from visiomode.database.Database import Database +from visiomode.database.NoMatchingDatabaseError import NoMatchingDatabaseError +from visiomode.database.WrongDataTypeError import WrongDataTypeError + + +# Defaults are presented in tuples of: +# - Database file +# - Database default uniqueness key +DATABASE_DEFAULTS = { + models.Animal: ("animals.json", "animal_id"), + models.Experimenter: ("experimenters.json", "experimenter_name"), + models.Session: ("sessions.json", "session_id"), +} +DatabaseSupported = typing.Union[ + type[models.Animal], type[models.Experimenter], type[models.Session] +] + + +_managed_databases = dict() + + +def get_database(database_item_type: DatabaseSupported) -> Database: + if database_item_type not in _managed_databases.keys(): + if database_item_type not in DATABASE_DEFAULTS.keys(): + raise NoMatchingDatabaseError(str(database_item_type)) + + database_defaults = DATABASE_DEFAULTS[database_item_type] + database = Database(database_item_type, *database_defaults) + _managed_databases[database_item_type] = database + else: + database = _managed_databases[database_item_type] + + return database From b803f6ed0659600fa2398949ad12e15999aeb5e1 Mon Sep 17 00:00:00 2001 From: Olivier Delree Date: Mon, 10 Jun 2024 13:54:31 +0100 Subject: [PATCH 3/3] link models and database refactor with rest of application This removes the old `models.py` and replaces the old usages with the new API. Note that some of the Javascript is modified here to allow a one-to-one mapping of request attributes to model attribute, hence making validation just a matter of unpacking the request into the model constructor. --- src/visiomode/core.py | 20 +- src/visiomode/devices/lever_push.py | 2 +- src/visiomode/devices/touchscreen.py | 2 +- src/visiomode/models.py | 241 ------------------ src/visiomode/{_models => models}/Animal.py | 0 .../{_models => models}/Experimenter.py | 0 src/visiomode/{_models => models}/Response.py | 0 src/visiomode/{_models => models}/Session.py | 0 src/visiomode/{_models => models}/Trial.py | 0 src/visiomode/{_models => models}/__init__.py | 0 src/visiomode/protocols/__init__.py | 7 +- src/visiomode/webpanel/api.py | 36 ++- src/visiomode/webpanel/static/js/session.js | 4 +- .../webpanel/static/js/settings-animals.js | 4 +- src/visiomode/webpanel/static/js/settings.js | 4 +- 15 files changed, 45 insertions(+), 275 deletions(-) delete mode 100644 src/visiomode/models.py rename src/visiomode/{_models => models}/Animal.py (100%) rename src/visiomode/{_models => models}/Experimenter.py (100%) rename src/visiomode/{_models => models}/Response.py (100%) rename src/visiomode/{_models => models}/Session.py (100%) rename src/visiomode/{_models => models}/Trial.py (100%) rename src/visiomode/{_models => models}/__init__.py (100%) diff --git a/src/visiomode/core.py b/src/visiomode/core.py index 975fa16c..c357c22e 100644 --- a/src/visiomode/core.py +++ b/src/visiomode/core.py @@ -2,17 +2,19 @@ # This file is part of visiomode. # Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree # Distributed under the terms of the MIT Licence. import os import logging import time -import datetime import threading import queue +import json import pkg_resources import pygame as pg import visiomode.config as conf +import visiomode.database as database import visiomode.models as models import visiomode.webpanel as webpanel import visiomode.protocols as protocols @@ -142,6 +144,7 @@ def run_main(self): protocol is stopped. If the application receives a quit event, the session is saved and the application exits. """ + session_database = database.get_database(models.Session) while True: if self.session: self.session.protocol.update() @@ -155,7 +158,7 @@ def run_main(self): self.session.protocol.stop() self.session.complete = True self.session.trials = self.session.protocol.trials - self.session.save(self.config.data_dir) + session_database.save_entry(self.session) self.session = None pg.event.clear() # Clear unused events so queue doesn't fill up @@ -163,7 +166,7 @@ def run_main(self): if pg.event.get(eventtype=pg.QUIT): if self.session: self.session.trials = self.session.protocol.trials - self.session.save(self.config.data_dir) + session_database.save_entry(self.session) return pg.display.flip() @@ -187,10 +190,7 @@ def request_listener(self): if request["type"] == "start": protocol = protocols.get_protocol(request["data"].pop("protocol")) self.session = models.Session( - animal_id=request["data"].pop("animal_id"), - experiment=request["data"].pop("experiment"), - duration=float(request["data"].pop("duration")), - timestamp=datetime.datetime.now().isoformat(), + **request["data"], protocol=protocol(screen=self.screen, **request["data"]), spec=request["data"], ) @@ -199,7 +199,11 @@ def request_listener(self): self.log_q.put( { "status": "active" if self.session else "inactive", - "data": self.session.to_json() if self.session else [], + "data": ( + json.loads(self.session.model_dump_json()) + if self.session + else [] + ), } ) elif request["type"] == "stop": diff --git a/src/visiomode/devices/lever_push.py b/src/visiomode/devices/lever_push.py index 2c8ff367..7ee034f7 100644 --- a/src/visiomode/devices/lever_push.py +++ b/src/visiomode/devices/lever_push.py @@ -32,7 +32,7 @@ def get_response(self): if not self._response_q.empty(): self._response_q.get() # Remove response from queue return models.Response( - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.datetime.now(), name="leverpush", pos_x=self.config.width / 2, pos_y=self.config.height / 2, diff --git a/src/visiomode/devices/touchscreen.py b/src/visiomode/devices/touchscreen.py index 299d1953..42893880 100644 --- a/src/visiomode/devices/touchscreen.py +++ b/src/visiomode/devices/touchscreen.py @@ -26,7 +26,7 @@ def get_response(self): dist_y = touch_event.dy * self.config.height name = "left" if pos_x >= (self.config.width / 2) else "right" return models.Response( - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.datetime.now(), name=name, pos_x=pos_x, pos_y=pos_y, diff --git a/src/visiomode/models.py b/src/visiomode/models.py deleted file mode 100644 index ccb557c0..00000000 --- a/src/visiomode/models.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Application data model classes.""" - -# This file is part of visiomode. -# Copyright (c) 2020 Constantinos Eleftheriou -# Distributed under the terms of the MIT Licence. -import os -import dataclasses -import datetime -import socket -import json -import typing -import copy - -from visiomode import __about__, config - - -cfg = config.Config() - - -@dataclasses.dataclass -class Base: - """Base model class.""" - - def to_dict(self): - """Get class instance attributes as a dictionary. - - Returns: - Dictionary with class instance attributes. - """ - return dataclasses.asdict(self) - - def to_json(self): - """Get class instance attributes as JSON. - - Returns: - JSON string with class instance attributes. - """ - return json.dumps(self.to_dict()) - - -@dataclasses.dataclass -class Response(Base): - """ - Attributes: - timestamp: String trial date and time (ISO format). Defaults to current date and time. - name: Response type identifier (e.g. left, right or lever). - pos_x: Float representing the touch position in the x-axis. - pos_y: Float representing the touch position in the y-axis. - dist_x: Float representing the distance travelled while touching the screen in the x-axis. - dist_y: Float representing the distance travelled while touching the screen in the y-axis. - """ - - timestamp: str - name: str - pos_x: float - pos_y: float - dist_x: float - dist_y: float - - -@dataclasses.dataclass -class Trial(Base): - """Trial model class. - - Attributes: - outcome: String descriptive of trial outcome, e.g. "correct", "incorrect", "no_response", "precued". - iti: Float representing the "silent" time before the stimulus is presented in milliseconds. - duration: Integer representing the duration of the touch in milliseconds. - pos_x: Float representing the touch position in the x-axis. - pos_y: Float representing the touch position in the y-axis. - dist_x: Float representing the distance travelled while touching the screen in the x-axis. - dist_y: Float representing the distance travelled while touching the screen in the y-axis. - timestamp: String trial date and time (ISO format). Defaults to current date and time. - correction: Boolean indicating whether or not trial is a correction trial. Defaults to False. - response_time: Integer representing the time between stimulus presentation and response in seconds. - sdt_type: Signal detection theory outcome classification (i.e. hit/miss/false_alarm/correct_rejection) - """ - - outcome: str - iti: float - response: Response - timestamp: str = datetime.datetime.now().isoformat() - correction: bool = False - response_time: int = 0 - stimulus: dict = dataclasses.field(default_factory=dict) - sdt_type: str = "NA" - - def __repr__(self): - return "".format(str(self.timestamp)) - - -@dataclasses.dataclass -class Session(Base): - """Session model class. - - Attributes: - animal_id: String representing the animal identifier. - experiment: A string holding the experiment identifier. - protocol: An instance of the Protocol class. - duration: Integer representing the session duration in minutes. - complete: Boolean value indicating whether or not a session was completed - timestamp: A string with the session start date and time (ISO format). Defaults to current date and time. - notes: String with additional session notes. Defaults to empty string - device: String hostname of the device running the session. Defaults to the hostname provided by the socket lib. - trials: A mutable list of session trials; each trial is an instance of the Trial dataclass. Automatically populated using protocol.trials after class instantiation. - animal_meta: A dictionary with animal metadata (see Animal class). Automatically populated using animal_id after class instantiation. - version: Visiomode version this was generated with. - """ - - animal_id: str - experiment: str - duration: float - protocol: None = None - spec: dict = None - complete: bool = False - timestamp: str = datetime.datetime.now().isoformat() - notes: str = "" - device: str = socket.gethostname() - trials: typing.List[Trial] = dataclasses.field(default_factory=list) - animal_meta: dict = None - version: str = __about__.__version__ - - def __post_init__(self): - self.animal_meta = Animal.get_animal(self.animal_id) - self.trials = self.protocol.trials - - def to_dict(self): - """Get class instance attributes as a dictionary. - - This method overrides the Base class to cast nested Trial objects under self.trials as dictionaries. - - Returns: - Dictionary with class instance attributes. - """ - instance = copy.copy(self) - instance.trials = [trial.to_dict() for trial in self.trials if self.trials] - instance.protocol = self.protocol.get_identifier() - return dataclasses.asdict(instance) - - def save(self, path): - """Save session to json file.""" - session_id = ( - "sub-" - + self.animal_id - + "_exp-" - + self.experiment - + "_date-" - + self.timestamp.replace(":", "").replace("-", "").replace(".", "") - ) - f_path = path + os.sep + session_id + ".json" - with open(f_path, "w", encoding="utf-8") as f: - json.dump(self.to_dict(), f) - - def __repr__(self): - return "".format(str(self.timestamp)) - - -@dataclasses.dataclass -class Animal(Base): - """Animal model class. - - Attributes: - animal_id: String representing the animal identifier. - date_of_birth: String representing the animal date of birth (ISO format). - sex: Character representing the animal's sex (M/F/U/O). - species: String representing the animal's species. Use the latin name, eg. Mus musculus. - genotype: String representing the animal's genotype. Defaults to empty string. - description: String with additional animal notes. Defaults to empty string. - rfid: String representing the animal's RFID tag. Defaults to empty string. - """ - - animal_id: str - date_of_birth: str - sex: str - species: str - genotype: str = "" - description: str = "" - rfid: str = "" - - def save(self): - """Append animal to json database file.""" - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - # If the animal already exists, remove it and append the new one - animals = [ - animal - for animal in animals - if not animal["animal_id"] == self.animal_id - ] - animals.append(self.to_dict()) - with open(path, "w") as f: - json.dump(animals, f) - else: - with open(path, "w") as f: - json.dump([self.to_dict()], f) - - @classmethod - def get_animal(cls, animal_id): - """Get an animal from the database based on its ID.""" - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - for animal in animals: - if animal["animal_id"] == animal_id: - return animal - return None - - @classmethod - def get_animals(cls): - """Get all animals stored in the database. - - Returns: - List of dictionaries with animal attributes. - """ - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - return animals - return [] - - @classmethod - def delete_animal(cls, animal_id): - """Delete animal from database.""" - path = cfg.db_dir + os.sep + "animals.json" - - if os.path.exists(path): - with open(path, "r") as f: - animals = json.load(f) - # If the animal exists, remove it - animals = [ - animal for animal in animals if not animal["animal_id"] == animal_id - ] - with open(path, "w") as f: - json.dump(animals, f) diff --git a/src/visiomode/_models/Animal.py b/src/visiomode/models/Animal.py similarity index 100% rename from src/visiomode/_models/Animal.py rename to src/visiomode/models/Animal.py diff --git a/src/visiomode/_models/Experimenter.py b/src/visiomode/models/Experimenter.py similarity index 100% rename from src/visiomode/_models/Experimenter.py rename to src/visiomode/models/Experimenter.py diff --git a/src/visiomode/_models/Response.py b/src/visiomode/models/Response.py similarity index 100% rename from src/visiomode/_models/Response.py rename to src/visiomode/models/Response.py diff --git a/src/visiomode/_models/Session.py b/src/visiomode/models/Session.py similarity index 100% rename from src/visiomode/_models/Session.py rename to src/visiomode/models/Session.py diff --git a/src/visiomode/_models/Trial.py b/src/visiomode/models/Trial.py similarity index 100% rename from src/visiomode/_models/Trial.py rename to src/visiomode/models/Trial.py diff --git a/src/visiomode/_models/__init__.py b/src/visiomode/models/__init__.py similarity index 100% rename from src/visiomode/_models/__init__.py rename to src/visiomode/models/__init__.py diff --git a/src/visiomode/protocols/__init__.py b/src/visiomode/protocols/__init__.py index b7622210..3efbfa40 100644 --- a/src/visiomode/protocols/__init__.py +++ b/src/visiomode/protocols/__init__.py @@ -1,5 +1,6 @@ # This file is part of visiomode. # Copyright (c) 2021 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree # Distributed under the terms of the MIT Licence. """Module that defines the available task and stimulation protocols in a stimulus agnostic manner.""" @@ -198,7 +199,7 @@ def trial_block(self): if not outcome: return - stimulus = "None" + stimulus = {"id": "None"} if outcome != PRECUED: if ( self.distractor @@ -252,12 +253,10 @@ def parse_trial( trial_start, outcome, response=None, - response_time=0, + response_time=0.0, sdt_type="NA", stimulus=None, ): - if not response: - response = {"name": "none"} trial = models.Trial( outcome=outcome, iti=self.iti, diff --git a/src/visiomode/webpanel/api.py b/src/visiomode/webpanel/api.py index 19399293..2a74e932 100644 --- a/src/visiomode/webpanel/api.py +++ b/src/visiomode/webpanel/api.py @@ -2,6 +2,7 @@ # This file is part of visiomode. # Copyright (c) 2020 Constantinos Eleftheriou +# Copyright (c) 2024 Olivier Delree # Distributed under the terms of the MIT Licence. import os import json @@ -11,6 +12,8 @@ import glob import flask import flask.views +import pydantic + import visiomode.config as cfg import visiomode.devices as devices import visiomode.protocols as protocols @@ -18,6 +21,7 @@ import visiomode.webpanel.export as export from visiomode.models import Animal +import visiomode.database as db class DeviceAPI(flask.views.MethodView): @@ -158,9 +162,11 @@ def post(self): class AnimalsAPI(flask.views.MethodView): """API for managing animal profiles.""" + database = db.get_database(Animal) + def get(self): """Get animal profiles.""" - return {"animals": Animal.get_animals()} + return {"animals": self.database.get_entries()} def post(self): request_type = flask.request.json.get("type") # add, delete, update @@ -169,20 +175,22 @@ def post(self): if request_type == "delete": animal_id = request.get("id") if animal_id: - Animal.delete_animal(animal_id) + self.database.delete_entry(animal_id) else: - animals = Animal.get_animals() + animals = self.database.get_entries() for animal in animals: - Animal.delete_animal(animal["animal_id"]) + self.database.delete_entry(animal["animal_id"]) elif (request_type == "update") or (request_type == "add"): - animal = Animal( - animal_id=request.get("id"), - date_of_birth=request.get("dob"), - sex=request.get("sex"), - species=request.get("species"), - genotype=request.get("genotype"), - description=request.get("description"), - rfid=request.get("rfid"), - ) - animal.save() + try: + animal = Animal(**request) + self.database.save_entry(animal) + except pydantic.ValidationError: + logging.error( + f"Could not validate animal model from request data:\n{request}" + ) + return ( + json.dumps({"success": False}), + 400, + {"ContentType": "application/json"}, + ) return json.dumps({"success": True}), 200, {"ContentType": "application/json"} diff --git a/src/visiomode/webpanel/static/js/session.js b/src/visiomode/webpanel/static/js/session.js index b3e522f2..96405384 100644 --- a/src/visiomode/webpanel/static/js/session.js +++ b/src/visiomode/webpanel/static/js/session.js @@ -297,8 +297,8 @@ function addAnimal() { data: JSON.stringify({ type: "add", data: { - id: animalId, - dob: animalDob, + animal_id: animalId, + date_of_birth: animalDob, sex: animalSex, species: animalSpecies, genotype: animalGenotype, diff --git a/src/visiomode/webpanel/static/js/settings-animals.js b/src/visiomode/webpanel/static/js/settings-animals.js index 848cfff9..84d9de6a 100644 --- a/src/visiomode/webpanel/static/js/settings-animals.js +++ b/src/visiomode/webpanel/static/js/settings-animals.js @@ -64,8 +64,8 @@ function updateAnimal() { data: JSON.stringify({ type: "update", data: { - id: animalId, - dob: animalDob, + animal_id: animalId, + date_of_birth: animalDob, sex: animalSex, species: animalSpecies, genotype: animalGenotype, diff --git a/src/visiomode/webpanel/static/js/settings.js b/src/visiomode/webpanel/static/js/settings.js index dee29622..baccbffe 100644 --- a/src/visiomode/webpanel/static/js/settings.js +++ b/src/visiomode/webpanel/static/js/settings.js @@ -123,8 +123,8 @@ function addAnimal() { data: JSON.stringify({ type: "add", data: { - id: animalId, - dob: animalDob, + animal_id: animalId, + date_of_birth: animalDob, sex: animalSex, species: animalSpecies, genotype: animalGenotype,