Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/scratch-core/src/mutations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
can be chained together using a pipeline (e.g. `returns.pipeline.pipe`).
"""

from .filter import LevelMap
from .filter import LevelMap, Mask
from .spatial import CropToMask, Resample


__all__ = ["LevelMap", "CropToMask", "Resample"]
__all__ = ["LevelMap", "Mask", "Resample", "CropToMask"]
60 changes: 58 additions & 2 deletions packages/scratch-core/src/mutations/filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import NamedTuple
from container_models.base import FloatArray1D, FloatArray2D

import numpy as np
from loguru import logger

from container_models.base import BinaryMask, FloatArray1D, FloatArray2D
from container_models.scan_image import ScanImage
from conversion.leveling.data_types import SurfaceTerms
from conversion.leveling.solver.design import build_design_matrix
from conversion.leveling.solver.grid import get_2d_grid
from conversion.leveling.solver.transforms import normalize_coordinates
from exceptions import ImageShapeMismatchError
from mutations.base import ImageMutation
import numpy as np


class PointCloud(NamedTuple):
Expand All @@ -15,6 +19,58 @@ class PointCloud(NamedTuple):
zs: FloatArray1D


class Mask(ImageMutation):
"""
Image mutation that applies a binary mask to a scan image.

All pixels corresponding to `False` (or zero) values in the mask
are set to `np.nan` in the image data. Pixels where the mask is
`True` remain unchanged.
"""

def __init__(self, mask: BinaryMask) -> None:
"""
Initialize the Mask mutation.

:param mask: Binary mask indicating which pixels should be kept (`True`)
or masked (`False`).
"""
self.mask = mask

@property
def skip_predicate(self) -> bool:
"""
Determine whether the masking operation can be skipped.

If the mask contains no masked pixels (i.e. all values are `True`),
applying the mask would have no effect and the mutation is skipped.

:returns: bool `True` if the mutation can be skipped, otherwise `False`.
"""
if self.mask.all():
logger.warning(
"skipping masking, Mask area is not containing any masking fields."
)
return True
return False

def apply_on_image(self, scan_image: ScanImage) -> ScanImage:
"""
Apply the mask to the image.

:params scan_image: Input scan image to which the mask is applied.
:return: The masked scan image.
:raises ImageShapeMismatchError: If the mask shape does not match the image data shape.
"""
if self.mask.shape != scan_image.data.shape:
raise ImageShapeMismatchError(
f"Mask shape: {self.mask.shape} does not match image shape: {scan_image.data.shape}"
)
logger.info("Applying mask to scan_image")
scan_image.data[~self.mask] = np.nan
return scan_image


class LevelMap(ImageMutation):
"""
Image mutation that performs surface leveling by fitting and subtracting
Expand Down
77 changes: 77 additions & 0 deletions packages/scratch-core/tests/mutations/test_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import re

import numpy as np
import pytest

from container_models.scan_image import ScanImage
from exceptions import ImageShapeMismatchError
from mutations.filter import Mask


class TestMask2dArray:
@pytest.fixture
def scan_image(
self,
):
return ScanImage(
data=np.array([[1, 2], [3, 4]], dtype=float), scale_x=1.0, scale_y=1.0
)

def test_mask_sets_background_pixels_to_nan(self, scan_image: ScanImage) -> None:
# Arrange
mask = np.array([[1, 0], [0, 1]], dtype=bool)
masking_mutator = Mask(mask=mask)
# Act
result = masking_mutator.apply_on_image(scan_image=scan_image)
# Assert
assert np.array_equal(
result.data, np.array([[1, np.nan], [np.nan, 4]]), equal_nan=True
)

def test_raises_on_shape_mismatch(self, scan_image: ScanImage) -> None:
# Arrange
mask = np.array([[1, 0, 0], [0, 1, 0]], dtype=bool)
masking_mutator = Mask(mask=mask)
# Act / Assert
with pytest.raises(
ImageShapeMismatchError,
match=re.escape(
f"Mask shape: {mask.shape} does not match image shape: {scan_image.data.shape}"
),
):
masking_mutator.apply_on_image(scan_image=scan_image)

def test_full_mask_preserves_all_values(self, scan_image: ScanImage) -> None:
# Arrange
mask = np.ones((2, 2), dtype=bool)
masking_mutator = Mask(mask=mask)
# Act
result = masking_mutator.apply_on_image(scan_image=scan_image)
# Assert
assert np.array_equal(result.data, scan_image.data, equal_nan=True)

def test_full_mask_skips_calculation(
self, scan_image: ScanImage, caplog: pytest.LogCaptureFixture
) -> None:
# Arrange
mask = np.ones((2, 2), dtype=bool)
masking_mutator = Mask(mask=mask)
# Act
result = masking_mutator(scan_image=scan_image).unwrap()
# Assert
assert np.array_equal(result.data, scan_image.data, equal_nan=True)
assert (
"skipping masking, Mask area is not containing any masking fields."
in caplog.messages
)

def test_empty_mask_sets_all_to_nan(
self, scan_image: ScanImage, caplog: pytest.LogCaptureFixture
) -> None:
# Arrange
mask = np.zeros((2, 2), dtype=bool)
masking_mutator = Mask(mask=mask)
result = masking_mutator(scan_image=scan_image).unwrap()

assert np.all(np.isnan(result.data))
assert "Applying mask to scan_image" in caplog.messages