diff --git a/src/models.py b/src/models.py index 350a752f..0f813f0e 100644 --- a/src/models.py +++ b/src/models.py @@ -14,11 +14,7 @@ class BaseModelConfig(BaseModel): - model_config = ConfigDict( - frozen=True, - regex_engine="rust-regex", - extra="forbid", - ) + model_config = ConfigDict(frozen=True, regex_engine="rust-regex", extra="forbid", arbitrary_types_allowed=True) class SupportedScanExtension(StrEnum): diff --git a/src/preprocessors/pipelines.py b/src/preprocessors/pipelines.py index 984c292e..c60e933b 100644 --- a/src/preprocessors/pipelines.py +++ b/src/preprocessors/pipelines.py @@ -2,6 +2,8 @@ from functools import partial from pathlib import Path +import numpy as np +from container_models.base import BinaryMask from container_models.light_source import LightSource from container_models.scan_image import ScanImage from parsers import load_scan_image, parse_to_x3p, save_x3p, subsample_scan_image @@ -41,6 +43,25 @@ def parse_scan_pipeline(scan_file: Path, step_size_x: int, step_size_y: int) -> ) +def parse_mask_pipeline(raw_data: bytes, shape: tuple[int, int], is_bitpacked: bool = False) -> BinaryMask: + """ + Convert incoming binary data to a 2D mask array. + + :param raw_data: The binary data to convert. + :param shape: The shape of the mask array. + :param is_bitpacked: Boolean indicating whether the binary data is bit-packed + and should be decompressed before reshaping. + :returns: The 2D mask array. + """ + if not is_bitpacked: + # TODO: rewrite logic to use `run_pipeline()` + array = np.frombuffer(raw_data, dtype=np.bool).reshape(*shape) + return array + else: + # TODO: implement unpacking of bits + raise NotImplementedError + + def x3p_pipeline(parsed_scan: ScanImage, output_path: Path) -> Path: """ Convert a scan image to X3P format and save it to the specified path. diff --git a/src/preprocessors/router.py b/src/preprocessors/router.py index fef5d43f..5ae58560 100644 --- a/src/preprocessors/router.py +++ b/src/preprocessors/router.py @@ -1,9 +1,11 @@ from functools import partial from http import HTTPStatus +from typing import Annotated -from fastapi import APIRouter +from fastapi import APIRouter, File, Form, HTTPException, UploadFile from fastapi.responses import RedirectResponse from loguru import logger +from pydantic import Json from constants import PreprocessorEndpoint, RoutePrefix from extractors import ProcessedDataAccess @@ -13,6 +15,7 @@ from .pipelines import ( impression_mark_pipeline, + parse_mask_pipeline, parse_scan_pipeline, preview_pipeline, striation_mark_pipeline, @@ -143,15 +146,42 @@ async def prepare_mark_striation(prepare_mark_parameters: PrepareMarkStriation) "description": "processing error", }, }, + openapi_extra={ + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "properties": { + "edit_image": EditImage.model_json_schema(), + "mask_data": {"type": "string", "format": "binary", "example": b"\x01\x00\x00\x01"}, + }, + "required": ["edit_image, mask_data"], + } + } + } + } + }, ) -async def edit_scan(edit_image: EditImage) -> ProcessedDataAccess: +async def edit_scan( + edit_image: Annotated[Json[EditImage], Form(...)], mask_data: Annotated[UploadFile, File(...)] | None = None +) -> ProcessedDataAccess: """ - Validate and parse a scan file with edit parameters. + Validate and parse a scan file with edit parameters and optional mask. Accepts an X3P scan file and edit parameters (mask, zoom, step sizes), validates the file format, parses it according to the parameters, and creates a vault directory for future outputs. Returns access URLs for the vault. """ + if mask_data is not None: + if edit_image.mask_parameters is None: + raise HTTPException(HTTPStatus.UNPROCESSABLE_CONTENT, "Invalid request: missing mask parameters.") + + _ = parse_mask_pipeline( + raw_data=await mask_data.read(), + shape=edit_image.mask_parameters.shape, + is_bitpacked=edit_image.mask_parameters.is_bitpacked, + ) + _ = parse_scan_pipeline(edit_image.scan_file, edit_image.step_size_x, edit_image.step_size_y) vault = create_vault(edit_image.tag) diff --git a/src/preprocessors/schemas.py b/src/preprocessors/schemas.py index 06a801e9..fa0448df 100644 --- a/src/preprocessors/schemas.py +++ b/src/preprocessors/schemas.py @@ -1,19 +1,10 @@ from __future__ import annotations from enum import StrEnum, auto -from functools import cached_property -from typing import Annotated, Self +from typing import Annotated, Any -import numpy as np from container_models.light_source import LightSource -from numpy.typing import NDArray -from pydantic import ( - AfterValidator, - Field, - PositiveFloat, - PositiveInt, - model_validator, -) +from pydantic import AfterValidator, Field, PositiveFloat, PositiveInt, model_validator from scipy.constants import micro from constants import ImpressionMarks, MaskTypes, StriationMarks @@ -144,23 +135,18 @@ class RegressionOrder(StrEnum): R2 = auto() -type Mask = tuple[tuple[bool, ...], ...] +class MaskParameters(BaseModelConfig): + shape: tuple[PositiveInt, PositiveInt] = Field( + ..., examples=[[100, 100], [250, 150]], description="Defines the shape of the 2D mask array." + ) + is_bitpacked: bool = Field( + default=False, examples=[False, True], description="Whether the mask is bit-packed." + ) # TODO: create enum/flags for compression types class EditImage(BaseParameters): """Request model for editing and transforming processed scan images.""" - mask: Mask = Field( - description=( - "Binary mask defining regions to include (True/1) or exclude (False/0) during processing. " - "Accepts both boolean (True/False) and integer (1/0) representations. " - "Must be a 2D tuple structure matching the scan dimensions." - ), - examples=[ - ((1, 0), (0, 1)), # Integer format - ((True, False), (False, True)), # Boolean format - ], - ) cutoff_length: Annotated[PositiveFloat, AfterValidator(lambda x: x * micro)] = Field( description="Cutoff wavelength in micrometers (µm) for Gaussian regression filtering. " "Defines the spatial frequency threshold for surface texture analysis.", @@ -195,30 +181,33 @@ class EditImage(BaseParameters): description="Subsampling step size in y-direction. Values > 1 reduce resolution by skipping pixels.", examples=[1, 2, 4], ) + mask_parameters: MaskParameters | None = Field(default=None, description="Mask parameters.") @model_validator(mode="after") - def validate_mask_is_2d(self) -> Self: - """ - Validate that the mask is a valid 2D array structure. - - Ensures the mask can be converted to a numpy array and has exactly - 2 dimensions, as required for image masking operations. - """ - try: - self.mask_array - except (ValueError, TypeError) as e: - raise ValueError("Bad mask value: unable to capture mask") from e - if not self.mask_array.ndim == 2: # noqa: PLR2004 - raise ValueError(f"Mask is not a 2D image: D{self.mask_array.ndim}") - if self.scan_file.suffix != ".x3p": + def check_file_is_x3p(self): + """Check whether the scan file is an x3p file.""" + if self.scan_file.suffix.lower() != ".x3p": raise ValueError(f"Unsupported extension: {self.scan_file.suffix}") return self - @cached_property - def mask_array(self) -> NDArray: - """ - Convert the mask tuple to a numpy boolean array. - - :return: 2D numpy array of boolean values representing the mask - """ - return np.array(self.mask, np.bool_) + @classmethod + def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]: + """Override the base method.""" + schema = super().model_json_schema(*args, **kwargs) + # Add schema for mask parameters to JSON model + schema["properties"]["mask_parameters"] = MaskParameters.model_json_schema(*args, **kwargs) + # Add schema for BaseParameters and EditImage to JSON model + attr_to_class = ( + ("scan_file", "ScanFile"), + ("regression_order", "RegressionOrder"), + ("terms", "Terms"), + ("project_name", "ProjectTag"), + ) + for attribute, class_name in attr_to_class: + updated = schema["$defs"][class_name] + for key in ("examples", "description"): + if value := schema["properties"][attribute].get(key): + updated[key] = value + schema["properties"][attribute] = updated + + return schema diff --git a/tests/preprocessors/conftest.py b/tests/preprocessors/conftest.py index 35f2bf3f..c46ef2ed 100644 --- a/tests/preprocessors/conftest.py +++ b/tests/preprocessors/conftest.py @@ -2,11 +2,14 @@ from pathlib import Path from typing import Final +import numpy as np import pytest -from preprocessors.schemas import EditImage, Mask, UploadScan +from preprocessors.schemas import EditImage, MaskParameters, UploadScan -MASK: Final[Mask] = ((1, 0, 1), (0, 1, 0)) # type: ignore +MASK = np.array([[True, False, True], [False, True, False]], dtype=np.bool) +MASK_BYTES = MASK.tobytes(order="C") +MASK_SHAPE = MASK.shape CUTOFF_LENGTH: Final[float] = 250 @@ -14,7 +17,12 @@ def edit_image_parameter(scan_directory: Path) -> Callable[..., EditImage]: def wrapper(**kwargs) -> EditImage: return EditImage.model_validate( - {"scan_file": scan_directory / "circle.x3p", "mask": MASK, "cutoff_length": CUTOFF_LENGTH} | kwargs + { + "scan_file": scan_directory / "circle.x3p", + "cutoff_length": CUTOFF_LENGTH, + "mask_parameters": {"shape": MASK_SHAPE}, + } + | kwargs ) return wrapper @@ -24,8 +32,8 @@ def wrapper(**kwargs) -> EditImage: def edit_image(scan_directory: Path) -> EditImage: return EditImage( scan_file=scan_directory / "circle.x3p", - mask=MASK, cutoff_length=CUTOFF_LENGTH, + mask_parameters=MaskParameters(shape=MASK_SHAPE), # type: ignore ) # type: ignore diff --git a/tests/preprocessors/pipelines/test_parse_mask_pipeline.py b/tests/preprocessors/pipelines/test_parse_mask_pipeline.py new file mode 100644 index 00000000..e6bc6c95 --- /dev/null +++ b/tests/preprocessors/pipelines/test_parse_mask_pipeline.py @@ -0,0 +1,38 @@ +import numpy as np +import pytest +from container_models.base import BinaryMask + +from preprocessors.pipelines import parse_mask_pipeline + + +@pytest.mark.integration +class TestParseMaskPipeline: + @pytest.fixture(scope="class") + def mask_array(self) -> BinaryMask: + """Fixture for a 2D mask array.""" + return np.array([[True, False, True], [False, False, True]], dtype=np.bool) + + def test_pipeline_can_parse_mask(self, mask_array: BinaryMask) -> None: + """Test that the pipeline can parse a 2D mask from binary data.""" + # Arrange + raw_data = mask_array.tobytes(order="C") + shape = mask_array.shape + # Act + parsed_mask = parse_mask_pipeline(raw_data, shape) + # Assert + assert np.array_equal(parsed_mask, mask_array) + + def test_pipeline_raises_on_incorrect_shape(self, mask_array: BinaryMask) -> None: + """Test that the pipeline will raise an error if the shape is incorrect.""" + raw_data = mask_array.tobytes(order="C") + incorrect_shape = (100, 150) + with pytest.raises(ValueError, match="cannot reshape array"): + _ = parse_mask_pipeline(raw_data, incorrect_shape) + + def test_pipeline_not_implemented(self, mask_array: BinaryMask) -> None: + """Test that the pipeline will raise an error for not implemented operations.""" + # TODO: Remove test when bit unpacking is implemented. + raw_data = mask_array.tobytes(order="C") + shape = mask_array.shape + with pytest.raises(NotImplementedError): + _ = parse_mask_pipeline(raw_data, shape, is_bitpacked=True) diff --git a/tests/preprocessors/schemas/test_edit_image.py b/tests/preprocessors/schemas/test_edit_image.py index ae9b222e..410d6e26 100644 --- a/tests/preprocessors/schemas/test_edit_image.py +++ b/tests/preprocessors/schemas/test_edit_image.py @@ -9,11 +9,10 @@ from pydantic import ValidationError from scipy.constants import micro -from preprocessors.schemas import EditImage, Mask, RegressionOrder, Terms +from preprocessors.schemas import EditImage, RegressionOrder, Terms DEFAULT_RESAMPLING_FACTOR: Final[int] = 4 DEFAULT_STEP_SIZE: Final[int] = 1 -MASK: Final[Mask] = ((1, 0, 1), (0, 1, 0)) # type: ignore CUTOFF_LENGTH: Final[float] = 250 @@ -173,4 +172,4 @@ def test_should_reject_when_required_fields_not_provided(self) -> None: EditImage() # type: ignore # Assert - assert get_error_fields(exc_info, "missing") == ("scan_file", "mask", "cutoff_length") + assert get_error_fields(exc_info, "missing") == ("scan_file", "cutoff_length") diff --git a/tests/test_contracts.py b/tests/test_contracts.py index 0bec0d91..625cb44c 100644 --- a/tests/test_contracts.py +++ b/tests/test_contracts.py @@ -1,7 +1,9 @@ +import json from enum import StrEnum from http import HTTPStatus from pathlib import Path +import numpy as np import pytest import requests from pydantic import BaseModel @@ -15,6 +17,7 @@ from models import DirectoryAccess from preprocessors.schemas import ( EditImage, + MaskParameters, PrepareMarkImpression, PrepareMarkStriation, PreprocessingImpressionParams, @@ -24,7 +27,9 @@ from settings import get_settings SCANS_DIR = PROJECT_ROOT / "packages/scratch-core/tests/resources/scans" -MASK = ((1, 0), (0, 1)) +MASK = np.array([[True, False], [False, True]], dtype=np.bool) +MASK_BYTES = MASK.tobytes(order="C") +MASK_SHAPE = MASK.shape CUTOFF_LENGTH = 250 # 250 micrometers in meters @@ -101,8 +106,8 @@ def edit_scan(self, scan_directory: Path) -> Interface: """ data = EditImage( # type: ignore scan_file=scan_directory / "Klein_non_replica_mode_X3P_Scratch.x3p", - mask=MASK, cutoff_length=CUTOFF_LENGTH, + mask_parameters=MaskParameters(shape=MASK_SHAPE), ) return data, ProcessedDataAccess @@ -127,7 +132,6 @@ def test_root(self, route: str) -> None: pytest.param("process_scan", "process-scan", id="process_scan"), pytest.param("prepare_mark_impression", "prepare-mark-impression", id="prepare_mark_impression"), pytest.param("prepare_mark_striation", "prepare-mark-striation", id="prepare_mark_striation"), - pytest.param("edit_scan", "edit-scan", marks=pytest.mark.xfail, id="edit_scan"), ], ) def test_pre_processor_post_requests( @@ -145,6 +149,20 @@ def test_pre_processor_post_requests( assert response.status_code == HTTPStatus.OK expected_response.model_validate(response.json()) + def test_pre_processor_edit_image_post_requests(self, edit_scan: tuple[EditImage, ProcessedDataAccess]) -> None: + """Test if preprocessor EditImage POST endpoints return expected models.""" + params, expected_response = edit_scan + # Act + response = requests.post( + f"{get_settings().base_url}/{RoutePrefix.PREPROCESSOR}/edit-scan", + data={"edit_image": json.dumps(params.model_dump(mode="json"))}, + files={"mask_data": ("mask.bin", MASK_BYTES, "application/octet-stream")}, + timeout=5, + ) + # Assert + assert response.status_code == HTTPStatus.OK + expected_response.model_validate(response.json()) + def test_extractor_get_file_endpoint(self, directory_access: DirectoryAccess) -> None: """Test if extractor /files/{token}/{filename} endpoint retrieves processed files.