Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions nion/data/Core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,30 +1801,49 @@ def calculate_data() -> _ImageDataType:


def function_warp(data_and_metadata_in: _DataAndMetadataLike, coordinates_in: typing.Sequence[_DataAndMetadataLike], order: int = 1) -> DataAndMetadata.DataAndMetadata:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fix this during the merge, but this is the wrong format for a Python function comment.

https://peps.python.org/pep-0257/#multi-line-docstrings

Warp or unwarp input data using an N-dimensional warp map.

The warp map is applied along N axes and broadcast over any additional
dimensions in the input, allowing a single warp map to be used for
higher-dimensional data (e.g., image sequences). For channelled data
such as RGB/RGBA, the warp is applied uniformly to all channels.

scipy map_coordinates does not broadcast by default, so need to loop
"""
data_and_metadata = DataAndMetadata.promote_ndarray(data_and_metadata_in)
coordinates = [DataAndMetadata.promote_ndarray(c) for c in coordinates_in]
coords = numpy.moveaxis(numpy.dstack([coordinate.data for coordinate in coordinates]), -1, 0)
coords = numpy.stack([c.data.astype(float) for c in coordinates], axis=0)
data = data_and_metadata._data_ex
if data_and_metadata.is_data_rgb:
rgb: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (3,), numpy.uint8)
rgb[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order)
rgb[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order)
rgb[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order)
return DataAndMetadata.new_data_and_metadata(data=rgb,
dimensional_calibrations=data_and_metadata.dimensional_calibrations,
intensity_calibration=data_and_metadata.intensity_calibration)
elif data_and_metadata.is_data_rgba:
rgba: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (4,), numpy.uint8)
rgba[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order)
rgba[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order)
rgba[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order)
rgba[..., 3] = scipy.ndimage.map_coordinates(data[..., 3], coords, order=order)
return DataAndMetadata.new_data_and_metadata(data=rgba,
num_frame_dims = coords.shape[0]

if data_and_metadata.is_data_rgb_type:
# Last dimension is channels
leading_shape = data.shape[:-num_frame_dims - 1]
output_shape = leading_shape + coords.shape[1:]
channels = 3 if data_and_metadata.is_data_rgb else 4
output = numpy.zeros(tuple(output_shape) + (channels,), numpy.uint8)

for idx in numpy.ndindex(leading_shape):
for chan in range(channels):
output[idx + (..., chan)] = scipy.ndimage.map_coordinates(
data[idx + (..., chan)],
coords,
order=order)

return DataAndMetadata.new_data_and_metadata(data=output,
dimensional_calibrations=data_and_metadata.dimensional_calibrations,
intensity_calibration=data_and_metadata.intensity_calibration)
else:
leading_shape = data.shape[:-num_frame_dims]
output_shape = leading_shape + coords.shape[1:]
output = numpy.zeros(output_shape, dtype=data.dtype)

for idx in numpy.ndindex(leading_shape):
output[idx] = scipy.ndimage.map_coordinates(data[idx], coords, order=order)

return DataAndMetadata.new_data_and_metadata(
data=scipy.ndimage.map_coordinates(data, coords, order=order),
data=output,
dimensional_calibrations=data_and_metadata.dimensional_calibrations,
intensity_calibration=data_and_metadata.intensity_calibration)

Expand Down
130 changes: 130 additions & 0 deletions nion/data/test/Core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,136 @@ def test_fft_zero_component_calibration(self) -> None:
result4 = Core.function_fft(xdata4)
self.assertAlmostEqual(0.0, result4.dimensional_calibrations[0].convert_to_calibrated_value(7.5))

## WARP TESTS
# Helper func
def _create_warp_test_data(self,
input_shape: tuple[int,...],
output_shape: tuple[int, ...] | None = None,
identity: bool = False,
mode: str = "greyscale") -> tuple[DataAndMetadata.DataAndMetadata, list[numpy.ndarray]]:
# Determine data type and channels based on mode
dtype: numpy.typing.DTypeLike
if mode == "greyscale":
dtype = float
channels = None
elif mode == "rgb":
dtype = numpy.uint8
channels = 3
elif mode == "rgba":
dtype = numpy.uint8
channels = 4
else:
raise ValueError(f"Invalid mode: {mode}. Choose 'greyscale', 'rgb', or 'rgba'.")

# Prepare input shape for data array
if channels is None:
full_shape = input_shape
else:
full_shape = input_shape + (channels,)

# Input data: sequential numbers for easy validation
data = numpy.arange(numpy.prod(full_shape), dtype=dtype).reshape(full_shape)
src = DataAndMetadata.new_data_and_metadata(data=data)

# Determine output grid shape
if output_shape is None:
H, W = input_shape[-2:]
else:
H, W = output_shape[-2:]

# Create warp coordinates
if identity:
# Identity warp: map output coordinates to same as input indices
warp_y, warp_x = numpy.meshgrid(
numpy.arange(input_shape[-2]),
numpy.arange(input_shape[-1]),
indexing="ij"
)
else:
# Resampling / scaling: map output grid into input index space
in_H, in_W = input_shape[-2:]
y = numpy.arange(0, in_H, in_H / H)
x = numpy.arange(0, in_W, in_W / W)
warp_y, warp_x = numpy.meshgrid(y, x, indexing="ij")

return src, [warp_y, warp_x]

def _validate_warp(self, src: DataAndMetadata.DataAndMetadata, dst: DataAndMetadata.DataAndMetadata, coords: list[numpy.ndarray], is_channel_data: bool = False) -> None:

# ---- shape validation ----
n_dims = len(coords) # number of warped dimensions
output_shape = coords[0].shape # shape of warp grid

expected_shape = dst.data_shape[:-n_dims] + output_shape

if is_channel_data:
expected_shape = dst.data_shape[:-n_dims-1] + output_shape + (dst.data_shape[-1],)

assert dst.data_shape == expected_shape, (
f"Output shape mismatch: {dst.data_shape} != {expected_shape}"
)

# ---- extract warped subspace ----
# Take the first element of all leading dimensions
warped = dst._data_ex
for _ in range(warped.ndim - n_dims):
warped = warped[0]

# warped now has shape == output_shape

# ---- monotonicity checks for each warped axis ----
for axis in range(n_dims):
# Build a slice that varies only along this axis
slicer: list[typing.Union[int, slice]] = [0] * n_dims
slicer[axis] = slice(None)

axis_values = warped[tuple(slicer)]

# Remove out-of-range zeros (leading or trailing)
nonzero = axis_values != 0
if numpy.count_nonzero(nonzero) < 2:
# Not enough valid data to validate this axis
continue

valid_values = axis_values[nonzero]

diffs = numpy.diff(valid_values)

assert numpy.all(diffs > 0), (
f"Warped axis {axis} is not strictly increasing: {axis_values}"
)

def test_warp_identity(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(4, 4), identity=True)
dst = Core.function_warp(src, coords)
self._validate_warp(src, dst, coords)

def test_warp_sequence(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4))
dst = Core.function_warp(src, coords)
self._validate_warp(src, dst, coords)

def test_warp_upscale(self) -> None:
# Input 4x4, warp to 8x8
src, coords = self._create_warp_test_data(input_shape=(4, 4), output_shape=(8, 8))
dst = Core.function_warp(src, coords)
self._validate_warp(src, dst, coords)

def test_warp_sequence_upscale(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(6, 8, 8))
dst = Core.function_warp(src, coords)
self._validate_warp(src, dst, coords)

def test_warp_rgb(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgb")
dst = Core.function_warp(src, coords)
self._validate_warp(src, dst, coords, is_channel_data=True)

def test_warp_rgba(self) -> None:
src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgba")
dst = Core.function_warp(src, coords)
self._validate_warp(src, dst, coords, is_channel_data=True)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
Expand Down