Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"Werkzeug>=2.2,<3.0",
"pyserial>=3.5",
"numpy>=1.22",
"pydantic>=2.7",
"pygame>=2.5.1",
"pynwb>=2.2.0",
]
Expand Down
20 changes: 12 additions & 8 deletions src/visiomode/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

# This file is part of visiomode.
# Copyright (c) 2020 Constantinos Eleftheriou <Constantinos.Eleftheriou@ed.ac.uk>
# Copyright (c) 2024 Olivier Delree <odelree@ed.ac.uk>
# Distributed under the terms of the MIT Licence.

import os
import logging
import time
import datetime
import threading
import queue
import json
import pkg_resources
import pygame as pg
import visiomode.config as conf
import visiomode.database as database
import visiomode.models as models
import visiomode.webpanel as webpanel
import visiomode.protocols as protocols
Expand Down Expand Up @@ -142,6 +144,7 @@ def run_main(self):
protocol is stopped. If the application receives a quit event, the
session is saved and the application exits.
"""
session_database = database.get_database(models.Session)
while True:
if self.session:
self.session.protocol.update()
Expand All @@ -155,15 +158,15 @@ def run_main(self):
self.session.protocol.stop()
self.session.complete = True
self.session.trials = self.session.protocol.trials
self.session.save(self.config.data_dir)
session_database.save_entry(self.session)

self.session = None
pg.event.clear() # Clear unused events so queue doesn't fill up

if pg.event.get(eventtype=pg.QUIT):
if self.session:
self.session.trials = self.session.protocol.trials
self.session.save(self.config.data_dir)
session_database.save_entry(self.session)
return

pg.display.flip()
Expand All @@ -187,10 +190,7 @@ def request_listener(self):
if request["type"] == "start":
protocol = protocols.get_protocol(request["data"].pop("protocol"))
self.session = models.Session(
animal_id=request["data"].pop("animal_id"),
experiment=request["data"].pop("experiment"),
duration=float(request["data"].pop("duration")),
timestamp=datetime.datetime.now().isoformat(),
**request["data"],
protocol=protocol(screen=self.screen, **request["data"]),
spec=request["data"],
)
Expand All @@ -199,7 +199,11 @@ def request_listener(self):
self.log_q.put(
{
"status": "active" if self.session else "inactive",
"data": self.session.to_json() if self.session else [],
"data": (
json.loads(self.session.model_dump_json())
if self.session
else []
),
}
)
elif request["type"] == "stop":
Expand Down
11 changes: 11 additions & 0 deletions src/visiomode/database/CorruptedDatabaseError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# This file is part of visiomode.
# Copyright (c) 2024 Olivier Delree <odelree@ed.ac.uk>
# Distributed under the terms of the MIT Licence.


class CorruptedDatabaseError(Exception):
def __init__(self, database_type: str):
super().__init__(
f"Parsing of '{database_type}' database failed. "
f"It is most likely corrupted."
)
185 changes: 185 additions & 0 deletions src/visiomode/database/Database.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly we're moving towards a single JSON object acting as a database for all visiomode data?

I think this module could use a little bit more documentation, I'm struggling to wrap my head around this a little

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example what constitutes an entry in this database? An animal? A session? Are those separate things stored in the same file or different files?

Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# This file is part of visiomode.
# Copyright (c) 2024 Olivier Delree <odelree@ed.ac.uk>
# Distributed under the terms of the MIT Licence.

import json
import os
import typing

import pydantic

from visiomode.config import Config
from visiomode.database.CorruptedDatabaseError import CorruptedDatabaseError
from visiomode.database.WrongDataTypeError import WrongDataTypeError


class Database:
_database_entry_type: type[pydantic.BaseModel]
_database_file: str
_database_path: str

_in_memory_contents: typing.Optional[
dict[str, typing.Union[str, list[dict]]]
] = None

def __init__(
self,
database_entry_type: type[pydantic.BaseModel],
database_file: str,
default_uniqueness_key: str,
) -> None:
self._database_entry_type = database_entry_type
self._database_file = database_file

self._initialise(default_uniqueness_key)

@property
def uniqueness_key(self) -> typing.Optional[str]:
if self._in_memory_contents is not None:
return self._in_memory_contents["uniqueness_key"]
return None

@property
def entries(self) -> typing.Optional[list[dict]]:
if self._in_memory_contents is not None:
return self._in_memory_contents["entries"]
return None

def save_entry(self, entry: pydantic.BaseModel) -> None:
# Check entry data type is compatible with database (i.e. exact type)
if type(entry) is not self._database_entry_type:
raise WrongDataTypeError(str(type(entry)), str(self._database_entry_type))

# Check the entry does have the uniqueness key as an attribute
entry_unique_value = getattr(entry, self.uniqueness_key, None)
if entry_unique_value is None:
raise ValueError(
f"Invalid entry provided. It does not have a value for uniqueness key"
f"'{self.uniqueness_key}'."
)
serialised_entry = self.serialise_model(entry)

# Append entry to database or replace matching entry based on uniqueness key
replaced_entry = False
for database_index, database_entry in enumerate(self.entries):
if entry_unique_value == database_entry[self.uniqueness_key]:
self.entries[database_index] = serialised_entry
replaced_entry = True
break

if not replaced_entry:
self.entries.append(serialised_entry)

self.dump_database()

def get_entry(self, entry_id: str) -> typing.Optional[dict]:
for database_entry in self.entries:
if database_entry[self.uniqueness_key] == entry_id:
print(database_entry)
return database_entry

return None

def get_entries(self) -> list[dict]:
return self.entries

def get_filtered_entries(self, filter_key: str, filter_value: str) -> list[dict]:
filtered_entries = []

for entry in self.entries:
if entry.get(filter_key) == filter_value:
filtered_entries.append(entry)

return filtered_entries

def delete_entry(self, entry_id: str) -> None:
deleted_an_entry = False
for database_index, database_entry in enumerate(self.entries):
if database_entry[self.uniqueness_key] == entry_id:
self.entries.pop(database_index)
deleted_an_entry = True
break

if deleted_an_entry:
self.dump_database()

def dump_database(self) -> None:
with open(self._database_path, "w") as database_handle:
json.dump(self._in_memory_contents, database_handle)

def update_uniqueness_key(self, new_key: str) -> None:
self.validate_database(new_key)
self._in_memory_contents["uniqueness_key"] = new_key

self.dump_database()

def reload_database(self):
self.dump_database()

self._initialise(self.uniqueness_key)

def validate_database(
self, validation_uniqueness_key: typing.Optional[str] = None
) -> None:
if validation_uniqueness_key is None:
if self._in_memory_contents.get("uniqueness_key") is None or not isinstance(
self.uniqueness_key, str
):
raise CorruptedDatabaseError(str(self._database_entry_type))
validation_uniqueness_key = self.uniqueness_key

values = []

database_entries = self._in_memory_contents.get("entries")
if database_entries is None or not isinstance(database_entries, list):
raise CorruptedDatabaseError(str(self._database_entry_type))

for database_entry in database_entries:
database_entry_unique_value = database_entry.get(validation_uniqueness_key)
if database_entry_unique_value is None:
raise KeyError(
f"Validation of the database failed with key '{validation_uniqueness_key}'. "
f"At least one entry does not have a value for the given key."
)

if database_entry_unique_value in values:
raise KeyError(
f"Validation of the database failed with key '{validation_uniqueness_key}'. "
f"At least one duplicated entry was found."
)

values.append(database_entry_unique_value)

def _initialise(self, default_uniqueness_key: typing.Optional[str] = None):
# Retrieve up-to-date path based on current config
self._database_path = f"{Config().db_dir}{os.sep}{self._database_file}"

if os.path.exists(self._database_path):
try:
self._load_database()
except json.JSONDecodeError:
raise CorruptedDatabaseError(str(self._database_entry_type))
except KeyError as key_error:
raise Exception(
"Initialisation of the database from disk failed."
) from key_error
else:
if default_uniqueness_key is None:
raise KeyError(
"Cannot initialise database without either a default uniqueness key"
" or an already initialised database on disk."
)
self._in_memory_contents = {
"uniqueness_key": default_uniqueness_key,
"entries": [],
}
self.dump_database()

def _load_database(self) -> None:
with open(self._database_path, "r") as database_handle:
self._in_memory_contents = json.load(database_handle)
self.validate_database()

@staticmethod
def serialise_model(model: pydantic.BaseModel) -> dict:
return json.loads(model.model_dump_json())
8 changes: 8 additions & 0 deletions src/visiomode/database/NoMatchingDatabaseError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# This file is part of visiomode.
# Copyright (c) 2024 Olivier Delree <odelree@ed.ac.uk>
# Distributed under the terms of the MIT Licence.


class NoMatchingDatabaseError(Exception):
def __init__(self, provided_type: str) -> None:
super().__init__(f"Cannot find a database for data type '{provided_type}'.")
11 changes: 11 additions & 0 deletions src/visiomode/database/WrongDataTypeError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# This file is part of visiomode.
# Copyright (c) 2024 Olivier Delree <odelree@ed.ac.uk>
# Distributed under the terms of the MIT Licence.


class WrongDataTypeError(Exception):
def __init__(self, provided_type: str, database_type: str) -> None:
super().__init__(
f"Provided data type is not compatible with database type "
f"('{provided_type}' vs '{database_type}')."
)
41 changes: 41 additions & 0 deletions src/visiomode/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# This file is part of visiomode.
# Copyright (c) 2024 Olivier Delree <odelree@ed.ac.uk>
# Distributed under the terms of the MIT Licence.

import typing

import visiomode.models as models
from visiomode.database.CorruptedDatabaseError import CorruptedDatabaseError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a huge fan of this style of importing, i.e. from thing.Stuff import Stuff. Ideally the modules should be structured so that you can either do from thing import Stuff, or for example errors go in their own module such that the imports become from visiomode.database.errors import CorruptedDatabaseError, NoMathingDatabaseError etc

from visiomode.database.Database import Database
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, I find the structuring of the imports a little awkward. You could also use an __all__ export at the module _init__.py (https://docs.python.org/3/tutorial/modules.html#importing-from-a-package), but my personal preference would be to group e.g. errors in a single file and define module-level classes at the module __init__.py file. f

from visiomode.database.NoMatchingDatabaseError import NoMatchingDatabaseError
from visiomode.database.WrongDataTypeError import WrongDataTypeError


# Defaults are presented in tuples of:
# - Database file
# - Database default uniqueness key
DATABASE_DEFAULTS = {
models.Animal: ("animals.json", "animal_id"),
models.Experimenter: ("experimenters.json", "experimenter_name"),
models.Session: ("sessions.json", "session_id"),
}
Comment on lines +17 to +21
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh is each one of these a different "database" ?

DatabaseSupported = typing.Union[
type[models.Animal], type[models.Experimenter], type[models.Session]
]


_managed_databases = dict()


def get_database(database_item_type: DatabaseSupported) -> Database:
if database_item_type not in _managed_databases.keys():
if database_item_type not in DATABASE_DEFAULTS.keys():
raise NoMatchingDatabaseError(str(database_item_type))

database_defaults = DATABASE_DEFAULTS[database_item_type]
database = Database(database_item_type, *database_defaults)
_managed_databases[database_item_type] = database
else:
database = _managed_databases[database_item_type]

return database
2 changes: 1 addition & 1 deletion src/visiomode/devices/lever_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_response(self):
if not self._response_q.empty():
self._response_q.get() # Remove response from queue
return models.Response(
timestamp=datetime.datetime.now().isoformat(),
timestamp=datetime.datetime.now(),
name="leverpush",
pos_x=self.config.width / 2,
pos_y=self.config.height / 2,
Expand Down
2 changes: 1 addition & 1 deletion src/visiomode/devices/touchscreen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_response(self):
dist_y = touch_event.dy * self.config.height
name = "left" if pos_x >= (self.config.width / 2) else "right"
return models.Response(
timestamp=datetime.datetime.now().isoformat(),
timestamp=datetime.datetime.now(),
name=name,
pos_x=pos_x,
pos_y=pos_y,
Expand Down
Loading