diff --git a/config/.env b/config/.env
index 3bcc0594..1cfc943e 100644
--- a/config/.env
+++ b/config/.env
@@ -10,3 +10,5 @@
 TOKEN_EXP_TIME=300
 CORS=true
 AUTH_CONFIG=auth.yml
+
+PATH_MAP=${ISPYB_DATA_PATH}
\ No newline at end of file
diff --git a/config/ci.env b/config/ci.env
index 55335196..c602ec8a 100644
--- a/config/ci.env
+++ b/config/ci.env
@@ -14,3 +14,5 @@
 SECRET_KEY=ci_secret
 SQLALCHEMY_DATABASE_URI=mysql+mysqlconnector://test:test@127.0.0.1/test
 AUTH_CONFIG=tests/config/auth.yml
+
+PATH_MAP=${ISPYB_DATA_PATH}
\ No newline at end of file
diff --git a/config/dev.env b/config/dev.env
index 5850ae88..b772cbf1 100644
--- a/config/dev.env
+++ b/config/dev.env
@@ -14,3 +14,5 @@
 SECRET_KEY=dev_secret
 SQLALCHEMY_DATABASE_URI=mysql+mysqlconnector://test:test@127.0.0.1/test
 AUTH_CONFIG=auth.yml
+
+PATH_MAP=${ISPYB_DATA_PATH}
\ No newline at end of file
diff --git a/config/docker.env b/config/docker.env
index f3a71780..793314b7 100644
--- a/config/docker.env
+++ b/config/docker.env
@@ -10,3 +10,5 @@
 TOKEN_EXP_TIME=300
 CORS=true
 AUTH_CONFIG=/config/auth.yml
+
+PATH_MAP=${ISPYB_DATA_PATH}
\ No newline at end of file
diff --git a/config/test.env b/config/test.env
index d59b8135..5d290459 100644
--- a/config/test.env
+++ b/config/test.env
@@ -14,3 +14,5 @@
 SECRET_KEY=test_secret
 SQLALCHEMY_DATABASE_URI=mysql+mysqlconnector://test:test@127.0.0.1/test
 AUTH_CONFIG=tests/config/auth.yml
+
+PATH_MAP=${ISPYB_DATA_PATH}
\ No newline at end of file
diff --git a/docs/conf.md b/docs/conf.md
index cf403cd4..7ad5e361 100644
--- a/docs/conf.md
+++ b/docs/conf.md
@@ -34,3 +34,7 @@
 SQLALCHEMY_DATABASE_URI=mysql+mysqlconnector://test:test@127.0.0.1/test
 AUTH_CONFIG=auth.yml
 ```
+
+## ISPYB_DATA_PATH
+
+The env variable `ISPYB_DATA_PATH` allows you to define a path prefix that ISPyB adds to all data file paths stored in the database.
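The config files above map `ISPYB_DATA_PATH` onto the `PATH_MAP` setting, which `pyispyb/core/modules/ssx.py` reads as `settings.path_map`. A minimal sketch of the resulting resolution logic (the example values are hypothetical):

```python
# Sketch of how the PATH_MAP prefix is applied, mirroring the
# os.path.join(settings.path_map, path) calls in pyispyb/core/modules/ssx.py.
import os

path_map = "/mnt/ispyb"  # hypothetical value of PATH_MAP / ISPYB_DATA_PATH
db_path = "data/visitor/mx0001/ssx_stats.json"  # hypothetical stored path

resolved = os.path.join(path_map, db_path) if path_map else db_path
print(resolved)  # /mnt/ispyb/data/visitor/mx0001/ssx_stats.json
```

Note that `os.path.join` returns the second argument unchanged when it is absolute, so the prefix only takes effect for relative paths stored in the database.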
diff --git a/pyispyb/app/base.py b/pyispyb/app/base.py
index 27c27f45..f4c0a2e8 100644
--- a/pyispyb/app/base.py
+++ b/pyispyb/app/base.py
@@ -14,9 +14,12 @@ def custom_generate_unique_id(route: APIRoute):
 
 class AuthenticatedAPIRouter(BaseRouter):
     def __init__(self, *args, **kwargs):
+
+        deps = kwargs.pop("dependencies", [])
+
         super().__init__(
             *args,
-            dependencies=[Depends(JWTBearer)],
+            dependencies=[Depends(JWTBearer), *deps],
             **kwargs,
             generate_unique_id_function=custom_generate_unique_id,
         )
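`AuthenticatedAPIRouter` now merges caller-supplied `dependencies` with the JWT check instead of silently discarding them, which is what lets the webservices routers below attach a permission guard. Usage as it appears later in this diff (`pyispyb/core/routes/webservices/ssx.py`):

```python
from fastapi import Depends

from pyispyb.app.base import AuthenticatedAPIRouter
from pyispyb.dependencies import permission

router = AuthenticatedAPIRouter(
    prefix="/webservices/ssx",
    tags=["Webservices - Serial crystallography"],
    # Appended after Depends(JWTBearer), which the router always injects first.
    dependencies=[Depends(permission("ssx_sync"))],
)
```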
diff --git a/pyispyb/app/utils/__init__.py b/pyispyb/app/utils/__init__.py
index f6f75a54..c322c7cb 100644
--- a/pyispyb/app/utils/__init__.py
+++ b/pyispyb/app/utils/__init__.py
@@ -25,6 +25,8 @@
 from sqlalchemy import text
 from functools import wraps
 from pyispyb.config import settings
+from sqlalchemy.orm import class_mapper
+from sqlalchemy import inspect
 
 logger = logging.getLogger("ispyb")
 
@@ -87,3 +89,11 @@ def wrapper(self, *args, **kwargs):
         return result
 
     return wrapper
+
+
+def model_from_json(model, data):
+    mapper = class_mapper(model)
+    keys = mapper.attrs.keys()
+    relationships = inspect(mapper).relationships
+    args = {k: v for k, v in data.items() if k in keys and k not in relationships}
+    return model(**args)
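`model_from_json` instantiates a mapped model from a plain dict, keeping only keys that correspond to mapped attributes and dropping relationship keys and anything unknown. A usage sketch with a hypothetical payload:

```python
from ispyb import models

from pyispyb.app.utils import model_from_json

payload = {
    "name": "chain-1",         # mapped column -> kept
    "dataCollectionId": 1234,  # mapped column -> kept
    "events": [],              # relationship -> dropped
    "unknownField": 42,        # not mapped -> dropped
}
event_chain = model_from_json(models.EventChain, payload)
```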
diff --git a/pyispyb/core/modules/eventchains.py b/pyispyb/core/modules/eventchains.py
new file mode 100644
index 00000000..1aab9853
--- /dev/null
+++ b/pyispyb/core/modules/eventchains.py
@@ -0,0 +1,43 @@
+from sqlalchemy.orm import joinedload
+
+from ispyb import models
+
+from pyispyb.app.extensions.database.middleware import db
+from pyispyb.app.extensions.database.definitions import with_authorization
+from pyispyb.app.extensions.database.utils import Paged, page
+
+
+def get_datacollection_eventchains(
+    dataCollectionId: int,
+    skip: int,
+    limit: int,
+) -> Paged[models.EventChain]:
+    query = (
+        db.session.query(models.EventChain)
+        .filter(models.EventChain.dataCollectionId == dataCollectionId)
+        .options(joinedload(models.EventChain.events))
+        .join(
+            models.DataCollection,
+            models.EventChain.dataCollectionId
+            == models.DataCollection.dataCollectionId,
+        )
+        .join(
+            models.DataCollectionGroup,
+            models.DataCollection.dataCollectionGroupId
+            == models.DataCollectionGroup.dataCollectionGroupId,
+        )
+        .join(
+            models.BLSession,
+            models.DataCollectionGroup.sessionId == models.BLSession.sessionId,
+        )
+        .join(
+            models.Proposal, models.BLSession.proposalId == models.Proposal.proposalId
+        )
+    )
+
+    query = with_authorization(query, joinBLSession=False)
+
+    total = query.count()
+    query = page(query, skip=skip, limit=limit)
+
+    return Paged(total=total, results=query.all(), skip=skip, limit=limit)
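The module returns a `Paged` wrapper rather than a bare list, so callers get the pre-paging total alongside the rows. A consumption sketch (assuming attribute access on `Paged`, as suggested by its constructor call above):

```python
from pyispyb.core.modules.eventchains import get_datacollection_eventchains

paged = get_datacollection_eventchains(dataCollectionId=1234, skip=0, limit=25)
print(paged.total)           # matching rows before skip/limit are applied
for chain in paged.results:  # at most `limit` EventChain rows, events preloaded
    print(chain.name)
```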
diff --git a/pyispyb/core/modules/ssx.py b/pyispyb/core/modules/ssx.py
new file mode 100644
index 00000000..36f2558b
--- /dev/null
+++ b/pyispyb/core/modules/ssx.py
@@ -0,0 +1,472 @@
+import json
+import logging
+import os
+import traceback
+from typing import Optional, Type, TypeVar
+from fastapi import HTTPException
+
+from ispyb import models
+import pydantic
+from pyispyb.app.extensions.database.definitions import with_authorization
+from sqlalchemy.orm import joinedload
+
+from pyispyb.app.extensions.database.middleware import db
+from pyispyb.app.utils import model_from_json
+from pyispyb.config import settings
+from pyispyb.core.modules.samples import get_samples
+from pyispyb.core.modules.sessions import get_sessions
+from pyispyb.core.schemas import ssx as schema
+from fastapi.concurrency import run_in_threadpool
+import numpy as np
+
+
+def find_or_create_event_type(name: str):
+    type = (
+        db.session.query(models.EventType).filter(models.EventType.name == name).first()
+    )
+    if type is None:
+        type = models.EventType(name=name)
+        db.session.add(type)
+        db.session.flush()
+    return type
+
+
+def create_ssx_datacollection(
+    ssx_datacollection_create: schema.SSXDataCollectionCreate,
+) -> Optional[int]:
+    data_collection_dict = ssx_datacollection_create.dict()
+    event_chains_list = data_collection_dict.pop("event_chains")
+
+    try:
+
+        # Check that DCG exists
+        count_dcg = (
+            db.session.query(models.DataCollectionGroup)
+            .filter(
+                models.DataCollectionGroup.dataCollectionGroupId
+                == ssx_datacollection_create.dataCollectionGroupId
+            )
+            .count()
+        )
+        if count_dcg != 1:
+            raise HTTPException(
+                status_code=422,
+                detail=f"Could not find DataCollectionGroup with id {ssx_datacollection_create.dataCollectionGroupId}",
+            )
+
+        # Check that Detector exists:
+        if ssx_datacollection_create.detectorId is not None:
+            count_detector = (
+                db.session.query(models.Detector)
+                .filter(
+                    models.Detector.detectorId == ssx_datacollection_create.detectorId
+                )
+                .count()
+            )
+            if count_detector != 1:
+                raise HTTPException(
+                    status_code=422,
+                    detail=f"Could not find Detector with id {ssx_datacollection_create.detectorId}",
+                )
+
+        # DATA COLLECTION
+
+        data_collection = model_from_json(
+            models.DataCollection,
+            {**data_collection_dict},
+        )
+        db.session.add(data_collection)
+        db.session.flush()
+
+        ssx_data_collection = model_from_json(
+            models.SSXDataCollection,
+            {
+                **data_collection_dict,
+                "dataCollectionId": data_collection.dataCollectionId,
+            },
+        )
+        db.session.add(ssx_data_collection)
+        db.session.flush()
+
+        # EVENT CHAINS
+
+        for event_chain_dict in event_chains_list:
+            events_list = event_chain_dict.pop("events")
+            event_chain = model_from_json(
+                models.EventChain,
+                {
+                    **event_chain_dict,
+                    "dataCollectionId": data_collection.dataCollectionId,
+                },
+            )
+            db.session.add(event_chain)
+            db.session.flush()
+            for event_dict in events_list:
+                type = find_or_create_event_type(event_dict["type"])
+                event = model_from_json(
+                    models.Event,
+                    {
+                        **event_dict,
+                        "eventChainId": event_chain.eventChainId,
+                        "eventTypeId": type.eventTypeId,
+                    },
+                )
+                db.session.add(event)
+                db.session.flush()
+
+        db.session.commit()
+        return data_collection.dataCollectionId
+
+    except Exception as e:
+        logging.error(traceback.format_exc())
+        db.session.rollback()
+        raise e
+
+
+def find_or_create_component_type(name: str):
+    type = (
+        db.session.query(models.ComponentType)
+        .filter(models.ComponentType.name == name)
+        .first()
+    )
+    if type is None:
+        type = models.ComponentType(name=name)
+        db.session.add(type)
+        db.session.flush()
+    return type
+
+
+def create_ssx_datacollectiongroup(
+    ssx_datacollectiongroup_create: schema.SSXDataCollectionGroupCreate,
+) -> Optional[int]:
+    datacollectiongroup_dict = ssx_datacollectiongroup_create.dict()
+    sample_dict = datacollectiongroup_dict.pop("sample")
+    sampleId = datacollectiongroup_dict.pop("sampleId")
+
+    try:
+
+        sessionId = datacollectiongroup_dict["sessionId"]
+        try:
+            session = get_sessions(sessionId=sessionId, skip=0, limit=1).first
+        except IndexError:
+            raise HTTPException(
+                status_code=422, detail=f"Could not find session with id {sessionId}"
+            )
+
+        ## SAMPLE
+
+        if sample_dict is None and sampleId is None:
+            raise HTTPException(
+                status_code=422,
+                detail="You have to provide sampleId or sample create object",
+            )
+        if sample_dict is not None and sampleId is not None:
+            raise HTTPException(
+                status_code=422,
+                detail="You have to provide only one of sampleId or sample create object",
+            )
+        elif sampleId is not None:
+            try:
+                sample = get_samples(skip=0, limit=1, blSampleId=sampleId).first
+            except IndexError:
+                raise HTTPException(
+                    status_code=422, detail=f"Could not find sample with id {sampleId}"
+                )
+        else:
+            crystal_dict = sample_dict.pop("crystal")
+            protein_dict = crystal_dict.pop("protein")
+            crystal_components_list = crystal_dict.pop("components")
+            sample_components_list = sample_dict.pop("components")
+            protein = model_from_json(
+                models.Protein,
+                {
+                    **protein_dict,
+                    "proposalId": session.proposalId,
+                },
+            )
+            db.session.add(protein)
+            db.session.flush()
+
+            crystal = model_from_json(
+                models.Crystal,
+                {
+                    **crystal_dict,
+                    "proteinId": protein.proteinId,
+                },
+            )
+            db.session.add(crystal)
+            db.session.flush()
+
+            sample = model_from_json(
+                models.BLSample,
+                {
+                    **sample_dict,
+                    "crystalId": crystal.crystalId,
+                },
+            )
+            db.session.add(sample)
+            db.session.flush()
+
+            for component_dict in crystal_components_list:
+                type = find_or_create_component_type(component_dict["componentType"])
+                component = model_from_json(
+                    models.Component,
+                    {
+                        **component_dict,
+                        "componentTypeId": type.componentTypeId,
+                    },
+                )
+                db.session.add(component)
+                db.session.flush()
+                composition = model_from_json(
+                    models.CrystalComposition,
+                    {
+                        **component_dict,
+                        "componentId": component.componentId,
+                        "crystalId": crystal.crystalId,
+                    },
+                )
+                db.session.add(composition)
+                db.session.flush()
+
+            for component_dict in sample_components_list:
+                type = find_or_create_component_type(component_dict["componentType"])
+                component = model_from_json(
+                    models.Component,
+                    {
+                        **component_dict,
+                        "componentTypeId": type.componentTypeId,
+                    },
+                )
+                db.session.add(component)
+                db.session.flush()
+                composition = model_from_json(
+                    models.SampleComposition,
+                    {
+                        **component_dict,
+                        "componentId": component.componentId,
+                        "blSampleId": sample.blSampleId,
+                    },
+                )
+                db.session.add(composition)
+                db.session.flush()
+
+        # DATA COLLECTION GROUP
+
+        data_collection_group = model_from_json(
+            models.DataCollectionGroup,
+            {
+                **datacollectiongroup_dict,
+                "blSampleId": sample.blSampleId,
+            },
+        )
+        db.session.add(data_collection_group)
+        db.session.flush()
+
+        db.session.commit()
+        return data_collection_group.dataCollectionGroupId
+
+    except Exception as e:
+        logging.error(traceback.format_exc())
+        db.session.rollback()
+        raise e
+
+
+def create_ssx_datacollection_processing(
+    dataCollectionId: int, data: schema.SSXDataCollectionProcessingCreate
+) -> int:
+
+    program = models.AutoProcProgram(
+        dataCollectionId=dataCollectionId,
+        processingCommandLine=data.processingCommandLine,
+        processingPrograms=data.processingPrograms,
+        processingStatus="SUCCESS",
+        processingMessage=data.processingMessage,
+        processingStartTime=data.processingStartTime,
+        processingEndTime=data.processingEndTime,
+        processingEnvironment=data.processingEnvironment,
+    )
+    db.session.add(program)
+    db.session.flush()
+
+    autoProcProgramId = program.autoProcProgramId
+
+    for resultPath in data.results:
+        [filePath, fileName] = os.path.split(resultPath)
+        attachment = models.AutoProcProgramAttachment(
+            filePath=filePath,
+            fileName=fileName,
+            fileType="Result",
+            autoProcProgramId=autoProcProgramId,
+        )
+        db.session.add(attachment)
+        db.session.flush()
+
+    db.session.commit()
+    return autoProcProgramId
+
+
+def get_ssx_datacollection_processing_attachments_results(
+    dataCollectionIds: list[int],
+) -> list[models.AutoProcProgramAttachment]:
+    query = (
+        db.session.query(models.AutoProcProgramAttachment)
+        .filter(models.AutoProcProgramAttachment.fileType == "Result")
+        .options(joinedload(models.AutoProcProgramAttachment.AutoProcProgram))
+        .options(
+            joinedload(
+                models.AutoProcProgramAttachment.AutoProcProgram,
+                models.AutoProcProgram.DataCollection,
+            )
+        )
+        .join(
+            models.AutoProcProgram,
+            models.AutoProcProgramAttachment.autoProcProgramId
+            == models.AutoProcProgram.autoProcProgramId,
+        )
+        .join(
+            models.DataCollection,
+            models.AutoProcProgram.dataCollectionId
+            == models.DataCollection.dataCollectionId,
+        )
+        .filter(models.DataCollection.dataCollectionId.in_(dataCollectionIds))
+        .join(
+            models.DataCollectionGroup,
+            models.DataCollection.dataCollectionGroupId
+            == models.DataCollectionGroup.dataCollectionGroupId,
+        )
+        .join(
+            models.BLSession,
+            models.DataCollectionGroup.sessionId == models.BLSession.sessionId,
+        )
+        .join(
+            models.Proposal, models.BLSession.proposalId == models.Proposal.proposalId
+        )
+    )
+
+    query = with_authorization(query, joinBLSession=False)
+
+    return query.all()
+
+
+T = TypeVar("T")
+
+
+def parse_file_as_sync(type_: Type[T], path: str, validate: bool = True) -> T | None:
+    try:
+        with open(path, mode="r") as f:
+            contents = f.read()
+            if validate:
+                parsed = pydantic.parse_raw_as(type_, contents)
+            else:
+                parsed = json.loads(contents)
+            return parsed
+    except pydantic.error_wrappers.ValidationError:
+        return None
+    except FileNotFoundError:
+        return None
+
+
+async def get_ssx_datacollection_processing_stats(
+    dataCollectionIds: list[int],
+) -> list[schema.SSXDataCollectionProcessingStats]:
+    attachments: list[
+        models.AutoProcProgramAttachment
+    ] = get_ssx_datacollection_processing_attachments_results(dataCollectionIds)
+
+    res: list[schema.SSXDataCollectionProcessingStats] = []
+
+    for attachment in attachments:
+        if attachment.fileName == "ssx_stats.json":
+            path = os.path.join(attachment.filePath, attachment.fileName)
+            if settings.path_map:
+                path = os.path.join(settings.path_map, path)
+            parsed = await run_in_threadpool(
+                parse_file_as_sync, schema.SSXDataCollectionProcessingStatsBase, path
+            )
+            if parsed is not None:
+                res.append(
+                    {
+                        "dataCollectionId": attachment.AutoProcProgram.DataCollection.dataCollectionId,
+                        **parsed.dict(),
+                    }
+                )
+    return res
+
+
+async def get_ssx_datacollection_processing_cells(
+    dataCollectionId: int,
+) -> schema.SSXDataCollectionProcessingCells | None:
+
+    attachments: list[
+        models.AutoProcProgramAttachment
+    ] = get_ssx_datacollection_processing_attachments_results([dataCollectionId])
+
+    for attachment in attachments:
+        if attachment.fileName == "ssx_cells.json":
+            path = os.path.join(attachment.filePath, attachment.fileName)
+            if settings.path_map:
+                path = os.path.join(settings.path_map, path)
+            parsed = await run_in_threadpool(
+                parse_file_as_sync,
+                schema.SSXDataCollectionProcessingCells,
+                path,
+                validate=False,
+            )
+            if parsed is not None:
+                return parsed
+    return None
+
+
+async def get_ssx_datacollection_processing_cells_histogram(
+    dataCollectionIds: list[int],
+) -> schema.SSXDataCollectionProcessingCellsHistogram:
+    cells = []
+    for dataCollectionId in dataCollectionIds:
+        cells_json = await get_ssx_datacollection_processing_cells(dataCollectionId)
+        if cells_json is not None:
+            cells = cells + cells_json["unit_cells"]
+    if len(cells) == 0:
+        return {
+            "a": None,
+            "b": None,
+            "c": None,
+            "alpha": None,
+            "beta": None,
+            "gamma": None,
+            "dataCollectionIds": dataCollectionIds,
+        }
+    bins = to_bins(cells)
+    return {
+        "a": bins[0],
+        "b": bins[1],
+        "c": bins[2],
+        "alpha": bins[3],
+        "beta": bins[4],
+        "gamma": bins[5],
+        "dataCollectionIds": dataCollectionIds,
+    }
+
+
+def to_bins(data: list[list[float]], nb_bins: int = 50):
+    unzipped = list(zip(*data))
+    res = []
+    for cell in unzipped:
+        hist, bin_edges = np.histogram(filter_outliers(cell), nb_bins)
+        median = np.median(cell)
+        res = res + [{"y": list(hist), "x": list(bin_edges), "median": median}]
+    return res
+
+
+def filter_outliers(data: list[float]):
+    # FROM https://gist.github.com/vishalkuo/f4aec300cf6252ed28d3
+    a = np.array(data)
+    upper_quartile = np.percentile(a, 75)
+    lower_quartile = np.percentile(a, 25)
+    IQR = (upper_quartile - lower_quartile) * 1.5
+    quartileSet = (lower_quartile - IQR, upper_quartile + IQR)
+    resultList = []
+    for y in a.tolist():
+        if y >= quartileSet[0] and y <= quartileSet[1]:
+            resultList.append(y)
+    return resultList
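`to_bins` transposes the `[a, b, c, alpha, beta, gamma]` rows into six series, drops values outside `[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`, and histograms what remains; the median, however, is computed on the unfiltered values. A small numeric sketch with made-up unit cells:

```python
from pyispyb.core.modules.ssx import filter_outliers, to_bins

a_axis = [78.9, 79.0, 79.1, 79.2, 500.0]
print(filter_outliers(a_axis))  # [78.9, 79.0, 79.1, 79.2] -- 500.0 is discarded

# Each row is one indexed unit cell: a, b, c, alpha, beta, gamma.
cells = [
    [78.9, 90.0, 38.1, 90.0, 90.0, 90.0],
    [79.0, 90.1, 38.2, 90.0, 90.0, 90.0],
]
bins = to_bins(cells)  # one {"x": edges, "y": counts, "median": ...} per axis
print(bins[0]["median"])  # 78.95, the median of the raw a values
```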
diff --git a/pyispyb/core/routes/eventchains.py b/pyispyb/core/routes/eventchains.py
new file mode 100644
index 00000000..e3c6d79e
--- /dev/null
+++ b/pyispyb/core/routes/eventchains.py
@@ -0,0 +1,19 @@
+from fastapi import Depends
+from pyispyb.app.base import AuthenticatedAPIRouter
+from pyispyb.app.extensions.database.utils import Paged
+import pyispyb.core.modules.eventchains as crud
+import pyispyb.core.schemas.eventchains as schema
+from pyispyb.dependencies import pagination
+
+router = AuthenticatedAPIRouter(prefix="/eventchains", tags=["Event chains"])
+
+
+@router.get(
+    "",
+    response_model=Paged[schema.EventChainResponse],
+)
+def get_datacollection_eventchains(
+    dataCollectionId: int,
+    page: dict[str, int] = Depends(pagination),
+) -> Paged[schema.EventChainResponse]:
+    return crud.get_datacollection_eventchains(dataCollectionId, **page)
diff --git a/pyispyb/core/routes/ssx.py b/pyispyb/core/routes/ssx.py
new file mode 100644
index 00000000..d571dfd1
--- /dev/null
+++ b/pyispyb/core/routes/ssx.py
@@ -0,0 +1,49 @@
+from fastapi import HTTPException
+from pydantic import constr
+from pyispyb.app.base import AuthenticatedAPIRouter
+import pyispyb.core.modules.ssx as crud
+import pyispyb.core.schemas.ssx as schema
+
+router = AuthenticatedAPIRouter(prefix="/ssx", tags=["Serial crystallography"])
+
+
+IdList = constr(regex=r"^\d+(,\d+)*$")
+
+
+@router.get(
+    "/datacollection/processing/stats",
+    response_model=list[schema.SSXDataCollectionProcessingStats],
+)
+async def get_ssx_datacollection_processing_stats(
+    dataCollectionIds: IdList,
+) -> list[schema.SSXDataCollectionProcessingStats]:
+    result = await crud.get_ssx_datacollection_processing_stats(
+        dataCollectionIds.split(",")
+    )
+    return result
+
+
+@router.get(
+    "/datacollection/processing/cells",
+    response_model=schema.SSXDataCollectionProcessingCells,
+)
+async def get_ssx_datacollection_processing_cells(
+    dataCollectionId: int,
+):
+    result = await crud.get_ssx_datacollection_processing_cells(dataCollectionId)
+    if result is not None:
+        return result
+    raise HTTPException(status_code=404, detail="Item not found")
+
+
+@router.get(
+    "/datacollection/processing/cells/histogram",
+    response_model=schema.SSXDataCollectionProcessingCellsHistogram,
+)
+async def get_ssx_datacollection_processing_cells_histogram(
+    dataCollectionIds: IdList,
+):
+    result = await crud.get_ssx_datacollection_processing_cells_histogram(
+        dataCollectionIds.split(",")
+    )
+    return result
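`IdList` validates the comma-separated id list at the query-string level, so malformed input is rejected with a 422 before any database work (the test data later in this diff exercises exactly this). The pattern in isolation:

```python
import re

ID_LIST = re.compile(r"^\d+(,\d+)*$")

assert ID_LIST.match("1")        # single id: accepted
assert ID_LIST.match("1,2,3")    # comma-separated ids: accepted
assert not ID_LIST.match("1d2")  # rejected -> FastAPI answers 422
assert not ID_LIST.match("1,")   # trailing comma: rejected
```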
index 5e38b5a7..a49bfc00 100644
--- a/pyispyb/core/routes/webservices/__init__.py
+++ b/pyispyb/core/routes/webservices/__init__.py
@@ -3,7 +3,6 @@
 from importlib import import_module
 
 from fastapi import FastAPI
-from .base import router
 
 logger = logging.getLogger(__name__)
 
@@ -23,5 +22,3 @@ def init_app(app: FastAPI, prefix: str = None, **kwargs):
             app.include_router(module.router, prefix=prefix)
         except Exception:
             logger.exception(f"Could not import module `{module_name}`")
-
-    app.include_router(router, prefix=prefix)
diff --git a/pyispyb/core/routes/webservices/base.py b/pyispyb/core/routes/webservices/base.py
deleted file mode 100644
index 9dd58855..00000000
--- a/pyispyb/core/routes/webservices/base.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from ....app.base import AuthenticatedAPIRouter
-
-
-router = AuthenticatedAPIRouter(
-    prefix="/webservices", tags=["Webservices - Used by external applications"]
-)
diff --git a/pyispyb/core/routes/webservices/ssx.py b/pyispyb/core/routes/webservices/ssx.py
new file mode 100644
index 00000000..6f9efc69
--- /dev/null
+++ b/pyispyb/core/routes/webservices/ssx.py
@@ -0,0 +1,47 @@
+import logging
+from fastapi import Depends
+from ....dependencies import permission
+from ....app.base import AuthenticatedAPIRouter
+from pyispyb.core.schemas import ssx as schema
+
+import pyispyb.core.modules.ssx as crud
+from ispyb import models
+
+
+router = AuthenticatedAPIRouter(
+    prefix="/webservices/ssx",
+    tags=["Webservices - Serial crystallography"],
+    dependencies=[Depends(permission("ssx_sync"))],
+)
+
+logger = logging.getLogger("ispyb")
+
+
+@router.post(
+    "/datacollection",
+    response_model=int,
+)
+def create_datacollection(
+    ssx_datacollection_create: schema.SSXDataCollectionCreate,
+) -> models.SSXDataCollection:
+    return crud.create_ssx_datacollection(ssx_datacollection_create)
+
+
+@router.post(
+    "/datacollectiongroup",
+    response_model=int,
+)
+def create_datacollectiongroup(
+    ssx_datacollectiongroup_create: schema.SSXDataCollectionGroupCreate,
+) -> models.DataCollectionGroup:
+    return crud.create_ssx_datacollectiongroup(ssx_datacollectiongroup_create)
+
+
+@router.post(
+    "/datacollection/{dataCollectionId:int}/processing",
+    response_model=int,
+)
+def create_ssx_datacollection_processing(
+    data: schema.SSXDataCollectionProcessingCreate, dataCollectionId: int
+) -> int:
+    return crud.create_ssx_datacollection_processing(dataCollectionId, data)
diff --git a/pyispyb/core/routes/webservices/userportalsync.py b/pyispyb/core/routes/webservices/userportalsync.py
index ba6aa1be..10268f34 100644
--- a/pyispyb/core/routes/webservices/userportalsync.py
+++ b/pyispyb/core/routes/webservices/userportalsync.py
@@ -4,20 +4,24 @@
 from ...schemas import userportalsync as schema
 from ....dependencies import permission
 from ..responses import Message
-from .base import router
+from ....app.base import AuthenticatedAPIRouter
 
+router = AuthenticatedAPIRouter(
+    prefix="/webservices/userportalsync",
+    tags=["Webservices - User portal sync"],
+    dependencies=[Depends(permission("uportal_sync"))],
+)
 
 logger = logging.getLogger("ispyb")
 
 
 @router.post(
-    "/userportalsync/sync_proposal",
+    "/sync_proposal",
     response_model=Message,
     responses={400: {"description": "The input data is not following the schema"}},
 )
 def sync_proposal(
     proposal: schema.UserPortalProposalSync,
-    depends: bool = Depends(permission("uportal_sync")),
 ):
     """Create/Update a proposal from the User Portal and all its related entities"""
     try:
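With the permission guard attached at router level, a request from a token lacking `ssx_sync` is rejected with 403 before the payload is even validated, as the `create dc no permission` test case later in this diff asserts. A behavior sketch (the `create_app` import path is hypothetical):

```python
from fastapi.testclient import TestClient

from pyispyb.app import create_app  # hypothetical application factory

client = TestClient(create_app())
response = client.post(
    "/webservices/ssx/datacollection",
    json={"dummy": 0},  # invalid payload, but the guard fires first
    headers={"Authorization": "Bearer <token-without-ssx_sync>"},
)
assert response.status_code == 403
```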
diff --git a/pyispyb/core/schemas/eventchains.py b/pyispyb/core/schemas/eventchains.py
new file mode 100644
index 00000000..def9a106
--- /dev/null
+++ b/pyispyb/core/schemas/eventchains.py
@@ -0,0 +1,34 @@
+from typing import Literal, Optional, Dict, Any
+from pydantic import BaseModel
+from pydantic_sqlalchemy import sqlalchemy_to_pydantic
+from ispyb import models
+
+
+class EventCreate(BaseModel):
+    type: Literal["XrayDetection", "XrayExposure", "LaserExcitation", "ReactionTrigger"]
+    name: Optional[str]
+    offset: float
+    duration: Optional[float]
+    period: Optional[float]
+    repetition: Optional[float]
+
+
+class EventChainCreate(BaseModel):
+    name: Optional[str]
+    events: list[EventCreate]
+
+
+class EventResponse(sqlalchemy_to_pydantic(models.Event)):
+    EventType: sqlalchemy_to_pydantic(models.EventType)
+
+    def dict(self, *args, **kwargs) -> Dict[str, Any]:
+        kwargs.pop("exclude_none", None)
+        return super().dict(*args, exclude_none=True, **kwargs)
+
+
+class EventChainResponse(sqlalchemy_to_pydantic(models.EventChain)):
+    events: list[EventResponse]
+
+    def dict(self, *args, **kwargs) -> Dict[str, Any]:
+        kwargs.pop("exclude_none", None)
+        return super().dict(*args, exclude_none=True, **kwargs)
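The `dict()` overrides force `exclude_none=True` so sparsely populated rows serialize without null columns, regardless of what the caller (including FastAPI) passes. The pattern in isolation, on a stand-in model rather than the generated SQLAlchemy ones:

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel


class SparseResponse(BaseModel):
    eventId: int
    name: Optional[str] = None

    def dict(self, *args, **kwargs) -> Dict[str, Any]:
        kwargs.pop("exclude_none", None)  # the caller's choice is overridden below
        return super().dict(*args, exclude_none=True, **kwargs)


print(SparseResponse(eventId=1).dict())  # {'eventId': 1} -- no null "name" key
```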
diff --git a/pyispyb/core/schemas/ssx.py b/pyispyb/core/schemas/ssx.py
new file mode 100644
index 00000000..569075de
--- /dev/null
+++ b/pyispyb/core/schemas/ssx.py
@@ -0,0 +1,131 @@
+from datetime import datetime
+from typing import Literal, Optional
+from .eventchains import EventChainCreate
+
+from pydantic import BaseModel
+
+
+class SSXDataCollectionProcessingStatsBase(BaseModel):
+    nbHits: int
+    nbIndexed: int
+    laticeType: str
+    estimatedResolution: float
+
+
+class SSXDataCollectionProcessingStats(SSXDataCollectionProcessingStatsBase):
+    dataCollectionId: int
+
+
+class SSXDataCollectionProcessingCells(BaseModel):
+    unit_cells: list[list[float]]
+
+
+class Histogram(BaseModel):
+    x: list[float]
+    y: list[int]
+    median: float
+
+
+class SSXDataCollectionProcessingCellsHistogram(BaseModel):
+    a: Histogram | None
+    b: Histogram | None
+    c: Histogram | None
+    alpha: Histogram | None
+    beta: Histogram | None
+    gamma: Histogram | None
+    dataCollectionIds: list[int]
+
+
+class SSXDataCollectionProcessingCreate(BaseModel):
+    processingCommandLine: Optional[str]
+    processingPrograms: Optional[str]
+    processingMessage: Optional[str]
+    processingStartTime: Optional[datetime]
+    processingEndTime: Optional[datetime]
+    processingEnvironment: Optional[str]
+    results: list[str]
+
+
+class SSXProteinCreate(BaseModel):
+    name: Optional[str]
+    acronym: Optional[str]
+
+
+class SSXSampleComponentCreate(BaseModel):
+    name: Optional[str]
+    componentType: Literal["Ligand", "Buffer", "JetMaterial"]
+    composition: Optional[str]
+    abundance: Optional[float]
+
+
+class SSXCrystalCreate(BaseModel):
+    size_X: Optional[float]
+    size_Y: Optional[float]
+    size_Z: Optional[float]
+    abundance: Optional[float]
+    protein: SSXProteinCreate
+    components: list[SSXSampleComponentCreate]
+
+
+class SSXSampleCreate(BaseModel):
+    name: Optional[str]
+    support: Optional[str]
+    crystal: SSXCrystalCreate
+    components: list[SSXSampleComponentCreate]
+
+
+class SSXDataCollectionCreate(BaseModel):
+    dataCollectionGroupId: int
+
+    # Table DataCollection
+    exposureTime: Optional[float]
+    transmission: Optional[float]
+    flux: Optional[float]
+    xBeam: Optional[float]
+    yBeam: Optional[float]
+    wavelength: Optional[float]
+    detectorDistance: Optional[float]
+    beamSizeAtSampleX: Optional[float]
+    beamSizeAtSampleY: Optional[float]
+    averageTemperature: Optional[float]
+    xtalSnapshotFullPath1: Optional[str]
+    xtalSnapshotFullPath2: Optional[str]
+    xtalSnapshotFullPath3: Optional[str]
+    xtalSnapshotFullPath4: Optional[str]
+    imagePrefix: Optional[str]
+    numberOfPasses: Optional[int]
+    numberOfImages: Optional[int]
+    resolution: Optional[float]
+    resolutionAtCorner: Optional[float]
+    flux_end: Optional[float]
+    detectorId: Optional[int]
+    startTime: datetime
+    endTime: Optional[datetime]
+    beamShape: Optional[str]
+    polarisation: Optional[float]
+    undulatorGap1: Optional[float]
+
+    # Table SSXDataCollection
+    repetitionRate: Optional[float]
+    energyBandwidth: Optional[float]
+    monoStripe: Optional[str]
+    experimentName: Optional[str]
+    jetSize: Optional[float]
+    jetSpeed: Optional[float]
+    laserEnergy: Optional[float]
+    chipModel: Optional[str]
+    chipPattern: Optional[str]
+
+    event_chains: list[EventChainCreate]
+
+
+class SSXDataCollectionGroupCreate(BaseModel):
+    # Table DataCollectionGroup
+    sessionId: int
+    startTime: datetime
+    endTime: Optional[datetime]
+    experimentType: Optional[Literal["SSX-Chip", "SSX-Jet"]]
+    comments: Optional[str]
+
+    sample: Optional[SSXSampleCreate]
+    sampleId: Optional[int]
diff --git a/requirements.txt b/requirements.txt
index 750fe688..2b3fc173 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-ispyb-models==1.0.6
+ispyb-models==1.1.0
 fastapi
 pydantic[dotenv]
 
diff --git a/tests/core/api/data/eventchains.py b/tests/core/api/data/eventchains.py
new file mode 100644
index 00000000..e84b0599
--- /dev/null
+++ b/tests/core/api/data/eventchains.py
@@ -0,0 +1,13 @@
+from tests.core.api.utils.apitest import ApiTestElem, ApiTestExpected, ApiTestInput
+
+
+test_data_event_chains = [
+    ApiTestElem(
+        name="list eventchains empty",
+        input=ApiTestInput(
+            permissions=[],
+            route="/eventchains?dataCollectionId=0",
+        ),
+        expected=ApiTestExpected(code=200),
+    )
+]
diff --git a/tests/core/api/data/ssx.py b/tests/core/api/data/ssx.py
new file mode 100644
index 00000000..5099cd3e
--- /dev/null
+++ b/tests/core/api/data/ssx.py
@@ -0,0 +1,400 @@
+from tests.core.api.utils.apitest import ApiTestElem, ApiTestExpected, ApiTestInput
+
+
+test_data_ssx_stats = [
+    ApiTestElem(
+        name="list stats empty",
+        input=ApiTestInput(
+            permissions=[],
+            route="/ssx/datacollection/processing/stats?dataCollectionIds=1,2",
+        ),
+        expected=ApiTestExpected(code=200, res=[]),
+    ),
+    ApiTestElem(
+        name="list stats wrong ids",
+        input=ApiTestInput(
+            permissions=[],
+            route="/ssx/datacollection/processing/stats?dataCollectionIds=1d2",
+        ),
+        expected=ApiTestExpected(code=422),
+    ),
+]
+
+test_data_ssx_cells = [
+    ApiTestElem(
+        name="cells empty",
+        input=ApiTestInput(
+            permissions=[],
+            route="/ssx/datacollection/processing/cells?dataCollectionId=1",
+        ),
+        expected=ApiTestExpected(code=404),
+    ),
+    ApiTestElem(
+        name="list stats wrong ids",
+        input=ApiTestInput(
+            permissions=[],
+            route="/ssx/datacollection/processing/cells?dataCollectionId=1d2",
+        ),
+        expected=ApiTestExpected(code=422),
+    ),
+]
+
+
+test_data_ssx_histogram = [
+    ApiTestElem(
+        name="histogram empty",
+        input=ApiTestInput(
+            permissions=[],
+            route="/ssx/datacollection/processing/cells/histogram?dataCollectionIds=1,2",
+        ),
+        expected=ApiTestExpected(
+            code=200,
+            res={
+                "a": None,
+                "b": None,
+                "c": None,
+                "alpha": None,
+                "beta": None,
+                "gamma": None,
+                "dataCollectionIds": [1, 2],
+            },
+        ),
+    ),
+    ApiTestElem(
+        name="histogram wrong ids",
+        input=ApiTestInput(
+            permissions=[],
+            route="/ssx/datacollection/processing/cells/histogram?dataCollectionIds=1d2",
+        ),
+        expected=ApiTestExpected(code=422),
+    ),
+]
+
+
+test_data_ssx_create = [
+    ApiTestElem(
+        name="create dc no permission",
+        input=ApiTestInput(
+            permissions=[],
+            route="/webservices/ssx/datacollection",
+            method="post",
+            payload={"dummy": 0},
+        ),
+        expected=ApiTestExpected(
+            code=403,
+        ),
+    ),
+    ApiTestElem(
+        name="create dc",
+        input=ApiTestInput(
+            permissions=["ssx_sync"],
+            route="/webservices/ssx/datacollection",
+            method="post",
+            payload={
+                "dataCollectionGroupId": 1,
+                "detectorId": 4,
+                "exposureTime": 0,
+                "transmission": 0,
+                "flux": 0,
+                "xBeam": 0,
+                "yBeam": 0,
+                "wavelength": 0,
+                "detectorDistance": 0,
+                "beamSizeAtSampleX": 0,
+                "beamSizeAtSampleY": 0,
+                "averageTemperature": 0,
+                "xtalSnapshotFullPath1": "string",
+                "xtalSnapshotFullPath2": "string",
+                "xtalSnapshotFullPath3": "string",
+                "xtalSnapshotFullPath4": "string",
+                "imagePrefix": "string",
+                "numberOfPasses": 0,
+                "numberOfImages": 0,
+                "resolution": 0,
+                "resolutionAtCorner": 0,
+                "flux_end": 0,
+                "startTime": "2023-01-25T09:16:42.053Z",
+                "endTime": "2023-01-25T09:16:42.053Z",
+                "beamShape": "string",
+                "polarisation": 0,
+                "undulatorGap1": 0,
+                "repetitionRate": 0,
+                "energyBandwidth": 0,
+                "monoStripe": "string",
+                "experimentName": "string",
+                "jetSize": 0,
+                "jetSpeed": 0,
+                "laserEnergy": 0,
+                "chipModel": "string",
+                "chipPattern": "string",
+                "event_chains": [
+                    {
+                        "name": "string",
+                        "events": [
+                            {
+                                "type": "XrayDetection",
+                                "name": "string",
+                                "offset": 0,
+                                "duration": 0,
+                                "period": 0,
+                                "repetition": 0,
+                            }
+                        ],
+                    }
+                ],
+            },
+        ),
+        expected=ApiTestExpected(
+            code=200,
+        ),
+    ),
+    ApiTestElem(
+        name="create dc unknown group",
+        input=ApiTestInput(
+            permissions=["ssx_sync"],
+            route="/webservices/ssx/datacollection",
+            method="post",
+            payload={
+                "dataCollectionGroupId": 99999999,
+                "detectorId": 4,
+                "exposureTime": 0,
+                "transmission": 0,
+                "flux": 0,
+                "xBeam": 0,
+                "yBeam": 0,
+                "wavelength": 0,
+                "detectorDistance": 0,
+                "beamSizeAtSampleX": 0,
+                "beamSizeAtSampleY": 0,
+                "averageTemperature": 0,
+                "xtalSnapshotFullPath1": "string",
+                "xtalSnapshotFullPath2": "string",
+                "xtalSnapshotFullPath3": "string",
+                "xtalSnapshotFullPath4": "string",
+                "imagePrefix": "string",
+                "numberOfPasses": 0,
+                "numberOfImages": 0,
+                "resolution": 0,
+                "resolutionAtCorner": 0,
+                "flux_end": 0,
+                "startTime": "2023-01-25T09:16:42.053Z",
+                "endTime": "2023-01-25T09:16:42.053Z",
+                "beamShape": "string",
+                "polarisation": 0,
+                "undulatorGap1": 0,
+                "repetitionRate": 0,
+                "energyBandwidth": 0,
+                "monoStripe": "string",
+                "experimentName": "string",
+                "jetSize": 0,
+                "jetSpeed": 0,
+                "laserEnergy": 0,
+                "chipModel": "string",
+                "chipPattern": "string",
+                "event_chains": [
+                    {
+                        "name": "string",
+                        "events": [
+                            {
+                                "type": "XrayDetection",
+                                "name": "string",
+                                "offset": 0,
+                                "duration": 0,
+                                "period": 0,
+                                "repetition": 0,
+                            }
+                        ],
+                    }
+                ],
+            },
+        ),
+        expected=ApiTestExpected(
+            code=422,
+        ),
+    ),
+    ApiTestElem(
+        name="create dc unknown detector",
+        input=ApiTestInput(
+            permissions=["ssx_sync"],
+            route="/webservices/ssx/datacollection",
+            method="post",
+            payload={
+                "dataCollectionGroupId": 1,
+                "detectorId": 99999999,
+                "exposureTime": 0,
+                "transmission": 0,
+                "flux": 0,
+                "xBeam": 0,
+                "yBeam": 0,
+                "wavelength": 0,
+                "detectorDistance": 0,
+                "beamSizeAtSampleX": 0,
+                "beamSizeAtSampleY": 0,
+                "averageTemperature": 0,
+                "xtalSnapshotFullPath1": "string",
+                "xtalSnapshotFullPath2": "string",
+                "xtalSnapshotFullPath3": "string",
+                "xtalSnapshotFullPath4": "string",
+                "imagePrefix": "string",
"numberOfPasses": 0, + "numberOfImages": 0, + "resolution": 0, + "resolutionAtCorner": 0, + "flux_end": 0, + "startTime": "2023-01-25T09:16:42.053Z", + "endTime": "2023-01-25T09:16:42.053Z", + "beamShape": "string", + "polarisation": 0, + "undulatorGap1": 0, + "repetitionRate": 0, + "energyBandwidth": 0, + "monoStripe": "string", + "experimentName": "string", + "jetSize": 0, + "jetSpeed": 0, + "laserEnergy": 0, + "chipModel": "string", + "chipPattern": "string", + "event_chains": [ + { + "name": "string", + "events": [ + { + "type": "XrayDetection", + "name": "string", + "offset": 0, + "duration": 0, + "period": 0, + "repetition": 0, + } + ], + } + ], + }, + ), + expected=ApiTestExpected( + code=422, + ), + ), + ApiTestElem( + name="create dcg no rights", + input=ApiTestInput( + permissions=[], + route="/webservices/ssx/datacollectiongroup", + method="post", + payload={"dummy": 0}, + ), + expected=ApiTestExpected( + code=403, + ), + ), + ApiTestElem( + name="create dcg", + input=ApiTestInput( + permissions=["ssx_sync"], + route="/webservices/ssx/datacollectiongroup", + method="post", + payload={ + "sessionId": 1, + "startTime": "2023-01-25T09:21:50.646Z", + "endTime": "2023-01-25T09:21:50.646Z", + "experimentType": "SSX-Chip", + "comments": "string", + "sample": { + "name": "string", + "support": "string", + "crystal": { + "size_X": 0, + "size_Y": 0, + "size_Z": 0, + "abundance": 0, + "protein": {"name": "string", "acronym": "string"}, + "components": [ + { + "name": "string", + "componentType": "Ligand", + "composition": "string", + "abundance": 0, + } + ], + }, + "components": [ + { + "name": "string", + "componentType": "Ligand", + "composition": "string", + "abundance": 0, + } + ], + }, + }, + ), + expected=ApiTestExpected( + code=200, + ), + ), + ApiTestElem( + name="create dcg wrong session", + input=ApiTestInput( + permissions=["ssx_sync"], + route="/webservices/ssx/datacollectiongroup", + method="post", + payload={ + "sessionId": 9999999, + "startTime": "2023-01-25T09:21:50.646Z", + "endTime": "2023-01-25T09:21:50.646Z", + "experimentType": "SSX-Chip", + "comments": "string", + "sample": { + "name": "string", + "support": "string", + "crystal": { + "size_X": 0, + "size_Y": 0, + "size_Z": 0, + "abundance": 0, + "protein": {"name": "string", "acronym": "string"}, + "components": [ + { + "name": "string", + "componentType": "Ligand", + "composition": "string", + "abundance": 0, + } + ], + }, + "components": [ + { + "name": "string", + "componentType": "Ligand", + "composition": "string", + "abundance": 0, + } + ], + }, + }, + ), + expected=ApiTestExpected( + code=422, + ), + ), + ApiTestElem( + name="create dcg wrong sample", + input=ApiTestInput( + permissions=["ssx_sync"], + route="/webservices/ssx/datacollectiongroup", + method="post", + payload={ + "sessionId": 1, + "startTime": "2023-01-25T09:21:50.646Z", + "endTime": "2023-01-25T09:21:50.646Z", + "experimentType": "SSX-Chip", + "comments": "string", + "sampleId": 99999, + }, + ), + expected=ApiTestExpected( + code=422, + ), + ), +] diff --git a/tests/core/api/test_eventchains.py b/tests/core/api/test_eventchains.py new file mode 100644 index 00000000..f7d0995c --- /dev/null +++ b/tests/core/api/test_eventchains.py @@ -0,0 +1,13 @@ +import pytest + +from starlette.types import ASGIApp + +from tests.conftest import AuthClient +from tests.core.api.utils.apitest import get_elem_name, run_test, ApiTestElem + +from tests.core.api.data.eventchains import test_data_event_chains + + +@pytest.mark.parametrize("test_elem", 
test_data_event_chains, ids=get_elem_name)
+def test_event_chains(auth_client: AuthClient, test_elem: ApiTestElem, app: ASGIApp):
+    run_test(auth_client, test_elem, app)
diff --git a/tests/core/api/test_ssx.py b/tests/core/api/test_ssx.py
new file mode 100644
index 00000000..08cf0d26
--- /dev/null
+++ b/tests/core/api/test_ssx.py
@@ -0,0 +1,33 @@
+import pytest
+
+from starlette.types import ASGIApp
+
+from tests.conftest import AuthClient
+from tests.core.api.utils.apitest import get_elem_name, run_test, ApiTestElem
+
+from tests.core.api.data.ssx import (
+    test_data_ssx_stats,
+    test_data_ssx_cells,
+    test_data_ssx_histogram,
+    test_data_ssx_create,
+)
+
+
+@pytest.mark.parametrize("test_elem", test_data_ssx_stats, ids=get_elem_name)
+def test_ssx_stats(auth_client: AuthClient, test_elem: ApiTestElem, app: ASGIApp):
+    run_test(auth_client, test_elem, app)
+
+
+@pytest.mark.parametrize("test_elem", test_data_ssx_cells, ids=get_elem_name)
+def test_ssx_cells(auth_client: AuthClient, test_elem: ApiTestElem, app: ASGIApp):
+    run_test(auth_client, test_elem, app)
+
+
+@pytest.mark.parametrize("test_elem", test_data_ssx_histogram, ids=get_elem_name)
+def test_ssx_histogram(auth_client: AuthClient, test_elem: ApiTestElem, app: ASGIApp):
+    run_test(auth_client, test_elem, app)
+
+
+@pytest.mark.parametrize("test_elem", test_data_ssx_create, ids=get_elem_name)
+def test_ssx_create(auth_client: AuthClient, test_elem: ApiTestElem, app: ASGIApp):
+    run_test(auth_client, test_elem, app)
diff --git a/tests/core/api/utils/apitest.py b/tests/core/api/utils/apitest.py
index 49c049c5..ea750360 100644
--- a/tests/core/api/utils/apitest.py
+++ b/tests/core/api/utils/apitest.py
@@ -12,7 +12,7 @@ class ApiTestInput:
     def __init__(
         self,
         *,
-        login: str,
+        login: str = "abcd",
         route: str,
         permissions: list[str] = [],
        method: str = "get",