diff --git a/nion/data/Core.py b/nion/data/Core.py index 41fa9c3..015f046 100755 --- a/nion/data/Core.py +++ b/nion/data/Core.py @@ -1801,30 +1801,49 @@ def calculate_data() -> _ImageDataType: def function_warp(data_and_metadata_in: _DataAndMetadataLike, coordinates_in: typing.Sequence[_DataAndMetadataLike], order: int = 1) -> DataAndMetadata.DataAndMetadata: + """ + Warp or unwarp input data using an N-dimensional warp map. + + The warp map is applied along N axes and broadcast over any additional + dimensions in the input, allowing a single warp map to be used for + higher-dimensional data (e.g., image sequences). For channelled data + such as RGB/RGBA, the warp is applied uniformly to all channels. + + scipy map_coordinates does not broadcast by default, so need to loop + """ data_and_metadata = DataAndMetadata.promote_ndarray(data_and_metadata_in) coordinates = [DataAndMetadata.promote_ndarray(c) for c in coordinates_in] - coords = numpy.moveaxis(numpy.dstack([coordinate.data for coordinate in coordinates]), -1, 0) + coords = numpy.stack([c.data.astype(float) for c in coordinates], axis=0) data = data_and_metadata._data_ex - if data_and_metadata.is_data_rgb: - rgb: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (3,), numpy.uint8) - rgb[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order) - rgb[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order) - rgb[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order) - return DataAndMetadata.new_data_and_metadata(data=rgb, - dimensional_calibrations=data_and_metadata.dimensional_calibrations, - intensity_calibration=data_and_metadata.intensity_calibration) - elif data_and_metadata.is_data_rgba: - rgba: numpy.typing.NDArray[numpy.uint8] = numpy.zeros(tuple(data_and_metadata.dimensional_shape) + (4,), numpy.uint8) - rgba[..., 0] = scipy.ndimage.map_coordinates(data[..., 0], coords, order=order) - rgba[..., 1] = scipy.ndimage.map_coordinates(data[..., 1], coords, order=order) - rgba[..., 2] = scipy.ndimage.map_coordinates(data[..., 2], coords, order=order) - rgba[..., 3] = scipy.ndimage.map_coordinates(data[..., 3], coords, order=order) - return DataAndMetadata.new_data_and_metadata(data=rgba, + num_frame_dims = coords.shape[0] + + if data_and_metadata.is_data_rgb_type: + # Last dimension is channels + leading_shape = data.shape[:-num_frame_dims - 1] + output_shape = leading_shape + coords.shape[1:] + channels = 3 if data_and_metadata.is_data_rgb else 4 + output = numpy.zeros(tuple(output_shape) + (channels,), numpy.uint8) + + for idx in numpy.ndindex(leading_shape): + for chan in range(channels): + output[idx + (..., chan)] = scipy.ndimage.map_coordinates( + data[idx + (..., chan)], + coords, + order=order) + + return DataAndMetadata.new_data_and_metadata(data=output, dimensional_calibrations=data_and_metadata.dimensional_calibrations, intensity_calibration=data_and_metadata.intensity_calibration) else: + leading_shape = data.shape[:-num_frame_dims] + output_shape = leading_shape + coords.shape[1:] + output = numpy.zeros(output_shape, dtype=data.dtype) + + for idx in numpy.ndindex(leading_shape): + output[idx] = scipy.ndimage.map_coordinates(data[idx], coords, order=order) + return DataAndMetadata.new_data_and_metadata( - data=scipy.ndimage.map_coordinates(data, coords, order=order), + data=output, dimensional_calibrations=data_and_metadata.dimensional_calibrations, intensity_calibration=data_and_metadata.intensity_calibration) diff --git a/nion/data/test/Core_test.py b/nion/data/test/Core_test.py index 5106973..3f9e236 100755 --- a/nion/data/test/Core_test.py +++ b/nion/data/test/Core_test.py @@ -1336,6 +1336,136 @@ def test_fft_zero_component_calibration(self) -> None: result4 = Core.function_fft(xdata4) self.assertAlmostEqual(0.0, result4.dimensional_calibrations[0].convert_to_calibrated_value(7.5)) + ## WARP TESTS + # Helper func + def _create_warp_test_data(self, + input_shape: tuple[int,...], + output_shape: tuple[int, ...] | None = None, + identity: bool = False, + mode: str = "greyscale") -> tuple[DataAndMetadata.DataAndMetadata, list[numpy.ndarray]]: + # Determine data type and channels based on mode + dtype: numpy.typing.DTypeLike + if mode == "greyscale": + dtype = float + channels = None + elif mode == "rgb": + dtype = numpy.uint8 + channels = 3 + elif mode == "rgba": + dtype = numpy.uint8 + channels = 4 + else: + raise ValueError(f"Invalid mode: {mode}. Choose 'greyscale', 'rgb', or 'rgba'.") + + # Prepare input shape for data array + if channels is None: + full_shape = input_shape + else: + full_shape = input_shape + (channels,) + + # Input data: sequential numbers for easy validation + data = numpy.arange(numpy.prod(full_shape), dtype=dtype).reshape(full_shape) + src = DataAndMetadata.new_data_and_metadata(data=data) + + # Determine output grid shape + if output_shape is None: + H, W = input_shape[-2:] + else: + H, W = output_shape[-2:] + + # Create warp coordinates + if identity: + # Identity warp: map output coordinates to same as input indices + warp_y, warp_x = numpy.meshgrid( + numpy.arange(input_shape[-2]), + numpy.arange(input_shape[-1]), + indexing="ij" + ) + else: + # Resampling / scaling: map output grid into input index space + in_H, in_W = input_shape[-2:] + y = numpy.arange(0, in_H, in_H / H) + x = numpy.arange(0, in_W, in_W / W) + warp_y, warp_x = numpy.meshgrid(y, x, indexing="ij") + + return src, [warp_y, warp_x] + + def _validate_warp(self, src: DataAndMetadata.DataAndMetadata, dst: DataAndMetadata.DataAndMetadata, coords: list[numpy.ndarray], is_channel_data: bool = False) -> None: + + # ---- shape validation ---- + n_dims = len(coords) # number of warped dimensions + output_shape = coords[0].shape # shape of warp grid + + expected_shape = dst.data_shape[:-n_dims] + output_shape + + if is_channel_data: + expected_shape = dst.data_shape[:-n_dims-1] + output_shape + (dst.data_shape[-1],) + + assert dst.data_shape == expected_shape, ( + f"Output shape mismatch: {dst.data_shape} != {expected_shape}" + ) + + # ---- extract warped subspace ---- + # Take the first element of all leading dimensions + warped = dst._data_ex + for _ in range(warped.ndim - n_dims): + warped = warped[0] + + # warped now has shape == output_shape + + # ---- monotonicity checks for each warped axis ---- + for axis in range(n_dims): + # Build a slice that varies only along this axis + slicer: list[typing.Union[int, slice]] = [0] * n_dims + slicer[axis] = slice(None) + + axis_values = warped[tuple(slicer)] + + # Remove out-of-range zeros (leading or trailing) + nonzero = axis_values != 0 + if numpy.count_nonzero(nonzero) < 2: + # Not enough valid data to validate this axis + continue + + valid_values = axis_values[nonzero] + + diffs = numpy.diff(valid_values) + + assert numpy.all(diffs > 0), ( + f"Warped axis {axis} is not strictly increasing: {axis_values}" + ) + + def test_warp_identity(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(4, 4), identity=True) + dst = Core.function_warp(src, coords) + self._validate_warp(src, dst, coords) + + def test_warp_sequence(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4)) + dst = Core.function_warp(src, coords) + self._validate_warp(src, dst, coords) + + def test_warp_upscale(self) -> None: + # Input 4x4, warp to 8x8 + src, coords = self._create_warp_test_data(input_shape=(4, 4), output_shape=(8, 8)) + dst = Core.function_warp(src, coords) + self._validate_warp(src, dst, coords) + + def test_warp_sequence_upscale(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(6, 8, 8)) + dst = Core.function_warp(src, coords) + self._validate_warp(src, dst, coords) + + def test_warp_rgb(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgb") + dst = Core.function_warp(src, coords) + self._validate_warp(src, dst, coords, is_channel_data=True) + + def test_warp_rgba(self) -> None: + src, coords = self._create_warp_test_data(input_shape=(6, 4, 4), output_shape=(4, 4), mode="rgba") + dst = Core.function_warp(src, coords) + self._validate_warp(src, dst, coords, is_channel_data=True) + if __name__ == '__main__': logging.getLogger().setLevel(logging.DEBUG)