diff --git a/packages/scratch-core/src/mutations/__init__.py b/packages/scratch-core/src/mutations/__init__.py index 68e249f2..0db9179d 100644 --- a/packages/scratch-core/src/mutations/__init__.py +++ b/packages/scratch-core/src/mutations/__init__.py @@ -9,8 +9,8 @@ can be chained together using a pipeline (e.g. `returns.pipeline.pipe`). """ -from .filter import LevelMap +from .filter import LevelMap, Mask from .spatial import CropToMask, Resample -__all__ = ["LevelMap", "CropToMask", "Resample"] +__all__ = ["LevelMap", "Mask", "Resample", "CropToMask"] diff --git a/packages/scratch-core/src/mutations/filter.py b/packages/scratch-core/src/mutations/filter.py index c4e92235..2f6df754 100644 --- a/packages/scratch-core/src/mutations/filter.py +++ b/packages/scratch-core/src/mutations/filter.py @@ -1,12 +1,16 @@ from typing import NamedTuple -from container_models.base import FloatArray1D, FloatArray2D + +import numpy as np +from loguru import logger + +from container_models.base import BinaryMask, FloatArray1D, FloatArray2D from container_models.scan_image import ScanImage from conversion.leveling.data_types import SurfaceTerms from conversion.leveling.solver.design import build_design_matrix from conversion.leveling.solver.grid import get_2d_grid from conversion.leveling.solver.transforms import normalize_coordinates +from exceptions import ImageShapeMismatchError from mutations.base import ImageMutation -import numpy as np class PointCloud(NamedTuple): @@ -15,6 +19,58 @@ class PointCloud(NamedTuple): zs: FloatArray1D +class Mask(ImageMutation): + """ + Image mutation that applies a binary mask to a scan image. + + All pixels corresponding to `False` (or zero) values in the mask + are set to `np.nan` in the image data. Pixels where the mask is + `True` remain unchanged. + """ + + def __init__(self, mask: BinaryMask) -> None: + """ + Initialize the Mask mutation. + + :param mask: Binary mask indicating which pixels should be kept (`True`) + or masked (`False`). + """ + self.mask = mask + + @property + def skip_predicate(self) -> bool: + """ + Determine whether the masking operation can be skipped. + + If the mask contains no masked pixels (i.e. all values are `True`), + applying the mask would have no effect and the mutation is skipped. + + :returns: bool `True` if the mutation can be skipped, otherwise `False`. + """ + if self.mask.all(): + logger.warning( + "skipping masking, Mask area is not containing any masking fields." + ) + return True + return False + + def apply_on_image(self, scan_image: ScanImage) -> ScanImage: + """ + Apply the mask to the image. + + :params scan_image: Input scan image to which the mask is applied. + :return: The masked scan image. + :raises ImageShapeMismatchError: If the mask shape does not match the image data shape. + """ + if self.mask.shape != scan_image.data.shape: + raise ImageShapeMismatchError( + f"Mask shape: {self.mask.shape} does not match image shape: {scan_image.data.shape}" + ) + logger.info("Applying mask to scan_image") + scan_image.data[~self.mask] = np.nan + return scan_image + + class LevelMap(ImageMutation): """ Image mutation that performs surface leveling by fitting and subtracting diff --git a/packages/scratch-core/tests/mutations/test_mask.py b/packages/scratch-core/tests/mutations/test_mask.py new file mode 100644 index 00000000..d253e255 --- /dev/null +++ b/packages/scratch-core/tests/mutations/test_mask.py @@ -0,0 +1,77 @@ +import re + +import numpy as np +import pytest + +from container_models.scan_image import ScanImage +from exceptions import ImageShapeMismatchError +from mutations.filter import Mask + + +class TestMask2dArray: + @pytest.fixture + def scan_image( + self, + ): + return ScanImage( + data=np.array([[1, 2], [3, 4]], dtype=float), scale_x=1.0, scale_y=1.0 + ) + + def test_mask_sets_background_pixels_to_nan(self, scan_image: ScanImage) -> None: + # Arrange + mask = np.array([[1, 0], [0, 1]], dtype=bool) + masking_mutator = Mask(mask=mask) + # Act + result = masking_mutator.apply_on_image(scan_image=scan_image) + # Assert + assert np.array_equal( + result.data, np.array([[1, np.nan], [np.nan, 4]]), equal_nan=True + ) + + def test_raises_on_shape_mismatch(self, scan_image: ScanImage) -> None: + # Arrange + mask = np.array([[1, 0, 0], [0, 1, 0]], dtype=bool) + masking_mutator = Mask(mask=mask) + # Act / Assert + with pytest.raises( + ImageShapeMismatchError, + match=re.escape( + f"Mask shape: {mask.shape} does not match image shape: {scan_image.data.shape}" + ), + ): + masking_mutator.apply_on_image(scan_image=scan_image) + + def test_full_mask_preserves_all_values(self, scan_image: ScanImage) -> None: + # Arrange + mask = np.ones((2, 2), dtype=bool) + masking_mutator = Mask(mask=mask) + # Act + result = masking_mutator.apply_on_image(scan_image=scan_image) + # Assert + assert np.array_equal(result.data, scan_image.data, equal_nan=True) + + def test_full_mask_skips_calculation( + self, scan_image: ScanImage, caplog: pytest.LogCaptureFixture + ) -> None: + # Arrange + mask = np.ones((2, 2), dtype=bool) + masking_mutator = Mask(mask=mask) + # Act + result = masking_mutator(scan_image=scan_image).unwrap() + # Assert + assert np.array_equal(result.data, scan_image.data, equal_nan=True) + assert ( + "skipping masking, Mask area is not containing any masking fields." + in caplog.messages + ) + + def test_empty_mask_sets_all_to_nan( + self, scan_image: ScanImage, caplog: pytest.LogCaptureFixture + ) -> None: + # Arrange + mask = np.zeros((2, 2), dtype=bool) + masking_mutator = Mask(mask=mask) + result = masking_mutator(scan_image=scan_image).unwrap() + + assert np.all(np.isnan(result.data)) + assert "Applying mask to scan_image" in caplog.messages