Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@


class BaseModelConfig(BaseModel):
model_config = ConfigDict(
frozen=True,
regex_engine="rust-regex",
extra="forbid",
)
model_config = ConfigDict(frozen=True, regex_engine="rust-regex", extra="forbid", arbitrary_types_allowed=True)


class SupportedScanExtension(StrEnum):
Expand Down
21 changes: 21 additions & 0 deletions src/preprocessors/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import partial
from pathlib import Path

import numpy as np
from container_models.base import BinaryMask
from container_models.light_source import LightSource
from container_models.scan_image import ScanImage
from parsers import load_scan_image, parse_to_x3p, save_x3p, subsample_scan_image
Expand Down Expand Up @@ -41,6 +43,25 @@ def parse_scan_pipeline(scan_file: Path, step_size_x: int, step_size_y: int) ->
)


def parse_mask_pipeline(raw_data: bytes, shape: tuple[int, int], is_bitpacked: bool = False) -> BinaryMask:
    """
    Convert incoming binary data to a 2D boolean mask array.

    When the data is not bit-packed, every byte of `raw_data` is one mask element,
    laid out in row-major (C) order. Any non-zero byte is treated as `True`.

    :param raw_data: The binary data to convert.
    :param shape: The shape of the mask array.
    :param is_bitpacked: Boolean indicating whether the binary data is bit-packed
        and should be decompressed before reshaping.
    :returns: The 2D mask array.
    :raises ValueError: If the number of elements in `raw_data` does not fit `shape`.
    :raises NotImplementedError: If `is_bitpacked` is true (not supported yet).
    """
    if is_bitpacked:
        # TODO: implement unpacking of bits
        raise NotImplementedError

    # TODO: rewrite logic to use `run_pipeline()`
    # Read the payload as plain bytes and convert explicitly instead of viewing the
    # buffer as booleans: a direct bool view keeps whatever byte values the client
    # sent (e.g. 0x02 or 0xFF), which are not canonical boolean values. The
    # conversion also yields a writable array rather than a read-only view of
    # the immutable `bytes` object.
    # `np.bool_` is used because the `np.bool` alias does not exist on NumPy 1.24-1.26.
    flat_mask = np.frombuffer(raw_data, dtype=np.uint8).astype(np.bool_)
    return flat_mask.reshape(shape)


def x3p_pipeline(parsed_scan: ScanImage, output_path: Path) -> Path:
"""
Convert a scan image to X3P format and save it to the specified path.
Expand Down
36 changes: 33 additions & 3 deletions src/preprocessors/router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from functools import partial
from http import HTTPStatus
from typing import Annotated

from fastapi import APIRouter
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.responses import RedirectResponse
from loguru import logger
from pydantic import Json

from constants import PreprocessorEndpoint, RoutePrefix
from extractors import ProcessedDataAccess
Expand All @@ -13,6 +15,7 @@

from .pipelines import (
impression_mark_pipeline,
parse_mask_pipeline,
parse_scan_pipeline,
preview_pipeline,
striation_mark_pipeline,
Expand Down Expand Up @@ -143,15 +146,42 @@ async def prepare_mark_striation(prepare_mark_parameters: PrepareMarkStriation)
"description": "processing error",
},
},
openapi_extra={
    # Hand-written description of the multipart request body: a JSON-encoded
    # `edit_image` form field plus an optional binary `mask_data` file part.
    "requestBody": {
        "content": {
            "multipart/form-data": {
                "schema": {
                    "properties": {
                        "edit_image": EditImage.model_json_schema(),
                        "mask_data": {"type": "string", "format": "binary", "example": b"\x01\x00\x00\x01"},
                    },
                    # `required` takes one string per property name. Only `edit_image`
                    # is listed because the endpoint handler declares `mask_data`
                    # with a `None` default, i.e. requests without a mask are accepted.
                    "required": ["edit_image"],
                }
            }
        }
    }
},
)
async def edit_scan(edit_image: EditImage) -> ProcessedDataAccess:
async def edit_scan(
edit_image: Annotated[Json[EditImage], Form(...)], mask_data: Annotated[UploadFile, File(...)] | None = None
) -> ProcessedDataAccess:
"""
Validate and parse a scan file with edit parameters.
Validate and parse a scan file with edit parameters and optional mask.
Accepts an X3P scan file and edit parameters (mask, zoom, step sizes),
validates the file format, parses it according to the parameters, and
creates a vault directory for future outputs. Returns access URLs for the vault.
"""
if mask_data is not None:
if edit_image.mask_parameters is None:
raise HTTPException(HTTPStatus.UNPROCESSABLE_CONTENT, "Invalid request: missing mask parameters.")

_ = parse_mask_pipeline(
raw_data=await mask_data.read(),
shape=edit_image.mask_parameters.shape,
is_bitpacked=edit_image.mask_parameters.is_bitpacked,
)

_ = parse_scan_pipeline(edit_image.scan_file, edit_image.step_size_x, edit_image.step_size_y)
vault = create_vault(edit_image.tag)

Expand Down
79 changes: 34 additions & 45 deletions src/preprocessors/schemas.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
from __future__ import annotations

from enum import StrEnum, auto
from functools import cached_property
from typing import Annotated, Self
from typing import Annotated, Any

import numpy as np
from container_models.light_source import LightSource
from numpy.typing import NDArray
from pydantic import (
AfterValidator,
Field,
PositiveFloat,
PositiveInt,
model_validator,
)
from pydantic import AfterValidator, Field, PositiveFloat, PositiveInt, model_validator
from scipy.constants import micro

from constants import ImpressionMarks, MaskTypes, StriationMarks
Expand Down Expand Up @@ -144,23 +135,18 @@ class RegressionOrder(StrEnum):
R2 = auto()


type Mask = tuple[tuple[bool, ...], ...]
class MaskParameters(BaseModelConfig):
    # Metadata telling the server how to interpret the raw mask bytes uploaded
    # alongside an edit request: the target 2D shape and whether the payload
    # is bit-packed.
    shape: tuple[PositiveInt, PositiveInt] = Field(
        ...,
        examples=[[100, 100], [250, 150]],
        description="Defines the shape of the 2D mask array.",
    )
    # TODO: create enum/flags for compression types
    is_bitpacked: bool = Field(
        default=False,
        examples=[False, True],
        description="Whether the mask is bit-packed.",
    )


class EditImage(BaseParameters):
"""Request model for editing and transforming processed scan images."""

mask: Mask = Field(
description=(
"Binary mask defining regions to include (True/1) or exclude (False/0) during processing. "
"Accepts both boolean (True/False) and integer (1/0) representations. "
"Must be a 2D tuple structure matching the scan dimensions."
),
examples=[
((1, 0), (0, 1)), # Integer format
((True, False), (False, True)), # Boolean format
],
)
cutoff_length: Annotated[PositiveFloat, AfterValidator(lambda x: x * micro)] = Field(
description="Cutoff wavelength in micrometers (µm) for Gaussian regression filtering. "
"Defines the spatial frequency threshold for surface texture analysis.",
Expand Down Expand Up @@ -195,30 +181,33 @@ class EditImage(BaseParameters):
description="Subsampling step size in y-direction. Values > 1 reduce resolution by skipping pixels.",
examples=[1, 2, 4],
)
mask_parameters: MaskParameters | None = Field(default=None, description="Mask parameters.")

@model_validator(mode="after")
def validate_mask_is_2d(self) -> Self:
"""
Validate that the mask is a valid 2D array structure.
Ensures the mask can be converted to a numpy array and has exactly
2 dimensions, as required for image masking operations.
"""
try:
self.mask_array
except (ValueError, TypeError) as e:
raise ValueError("Bad mask value: unable to capture mask") from e
if not self.mask_array.ndim == 2: # noqa: PLR2004
raise ValueError(f"Mask is not a 2D image: D{self.mask_array.ndim}")
if self.scan_file.suffix != ".x3p":
def check_file_is_x3p(self):
    """
    Ensure the scan file has an `.x3p` extension (compared case-insensitively).

    :returns: The validated instance itself.
    :raises ValueError: If the scan file has any other extension.
    """
    extension = self.scan_file.suffix
    if extension.lower() == ".x3p":
        return self
    raise ValueError(f"Unsupported extension: {extension}")

@cached_property
def mask_array(self) -> NDArray:
"""
Convert the mask tuple to a numpy boolean array.
:return: 2D numpy array of boolean values representing the mask
"""
return np.array(self.mask, np.bool_)
@classmethod
def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
    """
    Build the JSON schema with selected property schemas written out in place.

    Starts from the schema produced by the parent implementation, then:

    * sets the `mask_parameters` property to the schema of `MaskParameters`;
    * replaces the schema of each property listed below with the matching
      entry from `$defs`, carrying over the property's own `examples` and
      `description` when they are set.

    The `$defs` entries are updated in place, so they carry the copied keys too.
    """
    schema = super().model_json_schema(*args, **kwargs)
    properties = schema["properties"]

    # Add schema for mask parameters to JSON model
    properties["mask_parameters"] = MaskParameters.model_json_schema(*args, **kwargs)

    # Add schema for BaseParameters and EditImage to JSON model
    definition_by_property = {
        "scan_file": "ScanFile",
        "regression_order": "RegressionOrder",
        "terms": "Terms",
        "project_name": "ProjectTag",
    }
    for property_name, definition_name in definition_by_property.items():
        definition = schema["$defs"][definition_name]
        for key in ("examples", "description"):
            value = properties[property_name].get(key)
            if value:
                definition[key] = value
        properties[property_name] = definition

    return schema
16 changes: 12 additions & 4 deletions tests/preprocessors/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@
from pathlib import Path
from typing import Final

import numpy as np
import pytest

from preprocessors.schemas import EditImage, Mask, UploadScan
from preprocessors.schemas import EditImage, MaskParameters, UploadScan

MASK: Final[Mask] = ((1, 0, 1), (0, 1, 0)) # type: ignore
MASK = np.array([[True, False, True], [False, True, False]], dtype=np.bool)
MASK_BYTES = MASK.tobytes(order="C")
MASK_SHAPE = MASK.shape
CUTOFF_LENGTH: Final[float] = 250


@pytest.fixture(scope="module")
def edit_image_parameter(scan_directory: Path) -> Callable[..., EditImage]:
def wrapper(**kwargs) -> EditImage:
return EditImage.model_validate(
{"scan_file": scan_directory / "circle.x3p", "mask": MASK, "cutoff_length": CUTOFF_LENGTH} | kwargs
{
"scan_file": scan_directory / "circle.x3p",
"cutoff_length": CUTOFF_LENGTH,
"mask_parameters": {"shape": MASK_SHAPE},
}
| kwargs
)

return wrapper
Expand All @@ -24,8 +32,8 @@ def wrapper(**kwargs) -> EditImage:
def edit_image(scan_directory: Path) -> EditImage:
return EditImage(
scan_file=scan_directory / "circle.x3p",
mask=MASK,
cutoff_length=CUTOFF_LENGTH,
mask_parameters=MaskParameters(shape=MASK_SHAPE), # type: ignore
) # type: ignore


Expand Down
38 changes: 38 additions & 0 deletions tests/preprocessors/pipelines/test_parse_mask_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pytest
from container_models.base import BinaryMask

from preprocessors.pipelines import parse_mask_pipeline


@pytest.mark.integration
class TestParseMaskPipeline:
    # Exercises `parse_mask_pipeline` with raw mask data holding one byte per element.

    @pytest.fixture(scope="class")
    def mask_array(self) -> BinaryMask:
        """Fixture for a 2D mask array."""
        # `np.bool_` works on every NumPy release; the `np.bool` alias is missing on 1.24-1.26.
        return np.array([[True, False, True], [False, False, True]], dtype=np.bool_)

    def test_pipeline_can_parse_mask(self, mask_array: BinaryMask) -> None:
        """Test that the pipeline can parse a 2D mask from binary data."""
        # Arrange
        raw_data = mask_array.tobytes(order="C")
        shape = mask_array.shape
        # Act
        parsed_mask = parse_mask_pipeline(raw_data, shape)
        # Assert
        # `np.array_equal` ignores dtype, so dtype and shape are pinned explicitly.
        assert parsed_mask.dtype == np.bool_
        assert parsed_mask.shape == shape
        assert np.array_equal(parsed_mask, mask_array)

    def test_pipeline_raises_on_incorrect_shape(self, mask_array: BinaryMask) -> None:
        """Test that the pipeline will raise an error if the shape is incorrect."""
        # Arrange
        raw_data = mask_array.tobytes(order="C")
        incorrect_shape = (100, 150)
        # Act / Assert
        with pytest.raises(ValueError, match="cannot reshape array"):
            _ = parse_mask_pipeline(raw_data, incorrect_shape)

    def test_pipeline_not_implemented(self, mask_array: BinaryMask) -> None:
        """Test that the pipeline will raise an error for not implemented operations."""
        # TODO: Remove test when bit unpacking is implemented.
        # Arrange
        raw_data = mask_array.tobytes(order="C")
        shape = mask_array.shape
        # Act / Assert
        with pytest.raises(NotImplementedError):
            _ = parse_mask_pipeline(raw_data, shape, is_bitpacked=True)
5 changes: 2 additions & 3 deletions tests/preprocessors/schemas/test_edit_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from pydantic import ValidationError
from scipy.constants import micro

from preprocessors.schemas import EditImage, Mask, RegressionOrder, Terms
from preprocessors.schemas import EditImage, RegressionOrder, Terms

DEFAULT_RESAMPLING_FACTOR: Final[int] = 4
DEFAULT_STEP_SIZE: Final[int] = 1
MASK: Final[Mask] = ((1, 0, 1), (0, 1, 0)) # type: ignore
CUTOFF_LENGTH: Final[float] = 250


Expand Down Expand Up @@ -173,4 +172,4 @@ def test_should_reject_when_required_fields_not_provided(self) -> None:
EditImage() # type: ignore

# Assert
assert get_error_fields(exc_info, "missing") == ("scan_file", "mask", "cutoff_length")
assert get_error_fields(exc_info, "missing") == ("scan_file", "cutoff_length")
24 changes: 21 additions & 3 deletions tests/test_contracts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from enum import StrEnum
from http import HTTPStatus
from pathlib import Path

import numpy as np
import pytest
import requests
from pydantic import BaseModel
Expand All @@ -15,6 +17,7 @@
from models import DirectoryAccess
from preprocessors.schemas import (
EditImage,
MaskParameters,
PrepareMarkImpression,
PrepareMarkStriation,
PreprocessingImpressionParams,
Expand All @@ -24,7 +27,9 @@
from settings import get_settings

SCANS_DIR = PROJECT_ROOT / "packages/scratch-core/tests/resources/scans"
MASK = ((1, 0), (0, 1))
MASK = np.array([[True, False], [False, True]], dtype=np.bool)
MASK_BYTES = MASK.tobytes(order="C")
MASK_SHAPE = MASK.shape
CUTOFF_LENGTH = 250  # in micrometers; EditImage converts the value to meters during validation


Expand Down Expand Up @@ -101,8 +106,8 @@ def edit_scan(self, scan_directory: Path) -> Interface:
"""
data = EditImage( # type: ignore
scan_file=scan_directory / "Klein_non_replica_mode_X3P_Scratch.x3p",
mask=MASK,
cutoff_length=CUTOFF_LENGTH,
mask_parameters=MaskParameters(shape=MASK_SHAPE),
)
return data, ProcessedDataAccess

Expand All @@ -127,7 +132,6 @@ def test_root(self, route: str) -> None:
pytest.param("process_scan", "process-scan", id="process_scan"),
pytest.param("prepare_mark_impression", "prepare-mark-impression", id="prepare_mark_impression"),
pytest.param("prepare_mark_striation", "prepare-mark-striation", id="prepare_mark_striation"),
pytest.param("edit_scan", "edit-scan", marks=pytest.mark.xfail, id="edit_scan"),
],
)
def test_pre_processor_post_requests(
Expand All @@ -145,6 +149,20 @@ def test_pre_processor_post_requests(
assert response.status_code == HTTPStatus.OK
expected_response.model_validate(response.json())

def test_pre_processor_edit_image_post_requests(self, edit_scan: tuple[EditImage, ProcessedDataAccess]) -> None:
    """Post the EditImage parameters and raw mask bytes as multipart form data and validate the response model."""
    # Arrange
    edit_image, response_model = edit_scan
    url = f"{get_settings().base_url}/{RoutePrefix.PREPROCESSOR}/edit-scan"
    # The parameters travel as a JSON string in a form field; the mask is a separate file part.
    form_fields = {"edit_image": json.dumps(edit_image.model_dump(mode="json"))}
    uploads = {"mask_data": ("mask.bin", MASK_BYTES, "application/octet-stream")}
    # Act
    response = requests.post(url, data=form_fields, files=uploads, timeout=5)
    # Assert
    assert response.status_code == HTTPStatus.OK
    response_model.model_validate(response.json())

def test_extractor_get_file_endpoint(self, directory_access: DirectoryAccess) -> None:
"""Test if extractor /files/{token}/{filename} endpoint retrieves processed files.

Expand Down