diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml index cac06e76..3c1a6634 100644 --- a/.github/workflows/notebook_tests.yml +++ b/.github/workflows/notebook_tests.yml @@ -9,7 +9,8 @@ jobs: max-parallel: 4 matrix: python-version: [3.7] - + env: + TEST_ENV: TRUE steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} diff --git a/notebooks/transformations.ipynb b/notebooks/transformations.ipynb new file mode 100644 index 00000000..d5e313bd --- /dev/null +++ b/notebooks/transformations.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-15T09:57:50.791332Z", + "start_time": "2020-02-15T09:57:46.068701Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "!pip install napari\n", + "!pip install SimpleITK" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:54:54.920459Z", + "start_time": "2020-02-16T16:54:54.669509Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running test environment: False\n" + ] + } + ], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%gui qt\n", + "import os\n", + "if 'TEST_ENV' in os.environ:\n", + " TEST_ENV = os.environ['TEST_ENV'].lower() == \"true\"\n", + "else:\n", + " TEST_ENV = 0\n", + "print(f\"Running test environment: {bool(TEST_ENV)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-15T11:55:06.130717Z", + "start_time": "2020-02-15T11:54:22.119553Z" + } + }, + "outputs": [], + "source": [ + "from io import BytesIO\n", + "from zipfile import ZipFile\n", + "from urllib.request import urlopen\n", + "\n", + "resp = urlopen(\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\")\n", + "zipfile = ZipFile(BytesIO(resp.read()))\n", + "\n", + "img_file = zipfile.extract(\"ExBox3/T1_brain.nii.gz\")\n", + "mask_file = zipfile.extract(\"ExBox3/T1_brain_seg.nii.gz\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:01.394975Z", + "start_time": "2020-02-16T16:55:00.893340Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape (192, 192, 174)\n", + "Image shape (192, 192, 174)\n" + ] + } + ], + "source": [ + "import SimpleITK as sitk\n", + "import numpy as np\n", + "\n", + "# load image and mask\n", + "img_file = \"./ExBox3/T1_brain.nii.gz\"\n", + "mask_file = \"./ExBox3/T1_brain_seg.nii.gz\"\n", + "img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\n", + "img = img.astype(np.float32)\n", + "mask = mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\n", + "mask = mask.astype(np.float32)\n", + "\n", + "assert mask.shape == img.shape\n", + "print(f\"Image shape {img.shape}\")\n", + "print(f\"Image shape {mask.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:04.255613Z", + "start_time": "2020-02-16T16:55:03.213336Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/micha/miniconda3/envs/phoenix/lib/python3.7/site-packages/qtpy/__init__.py:216: RuntimeWarning: Selected binding \"pyqt5\" could not be found, using \"pyside2\"\n", + " 'using \"{}\"'.format(initial_api, API), RuntimeWarning)\n" + ] + } + ], + "source": [ + "if TEST_ENV:\n", + " def view_batch(batch):\n", + " pass\n", + "else:\n", + " %gui qt\n", + " import napari\n", + " def view_batch(batch):\n", + " viewer = napari.view_image(batch[\"data\"].cpu().numpy(), name=\"data\")\n", + " viewer.add_image(batch[\"mask\"].cpu().numpy(), name=\"mask\", opacity=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:54.082493Z", + "start_time": "2020-02-16T16:55:54.019599Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from rising.transforms import *\n", + "\n", + "batch = {\n", + " \"data\": torch.from_numpy(img).float()[None, None],\n", + " \"mask\": torch.from_numpy(mask).long()[None, None],\n", + "}\n", + "\n", + "def apply_transform(trafo, batch):\n", + " transformed = trafo(**batch)\n", + " print(f\"Transformed data shape: {transformed['data'].shape}\")\n", + " print(f\"Transformed mask shape: {transformed['mask'].shape}\")\n", + " print(f\"Transformed data min: {transformed['data'].min()}\")\n", + " print(f\"Transformed data max: {transformed['data'].max()}\")\n", + " print(f\"Transformed data mean: {transformed['data'].mean()}\")\n", + " return transformed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:06.109008Z", + "start_time": "2020-02-16T16:55:06.069336Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed data min: 0.0\n", + "Transformed data max: 502.0\n", + "Transformed data mean: 37.62009048461914\n" + ] + } + ], + "source": [ + "print(f\"Transformed data shape: {batch['data'].shape}\")\n", + "print(f\"Transformed mask shape: {batch['mask'].shape}\")\n", + "print(f\"Transformed data min: {batch['data'].min()}\")\n", + "print(f\"Transformed data max: {batch['data'].max()}\")\n", + "print(f\"Transformed data mean: {batch['data'].mean()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:55:57.391117Z", + "start_time": "2020-02-16T16:55:55.675294Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed data min: 0.0\n", + "Transformed data max: 445.4209289550781\n", + "Transformed data mean: 110.88246154785156\n" + ] + } + ], + "source": [ + "trafo = Scale(1.5, adjust_size=False)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T17:03:58.535489Z", + "start_time": "2020-02-16T17:03:57.964843Z" + } + }, + "outputs": [], + "source": [ + "trafo = Rotate([0, 0, 45], degree=True, adjust_size=False)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-02-16T16:00:26.032367Z", + "start_time": "2020-02-16T16:00:25.466391Z" + } + }, + "outputs": [], + "source": [ + "trafo = Translate([0.1, 0, 0], adjust_size=False)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trafo = CenterCropGrid(size=100)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trafo = RandomCropGrid(size=100)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1, 192, 192, 174])\n", + "torch.Size([1, 192, 192, 174, 3])\n", + "tensor(0.0100)\n", + "tensor(-0.0100)\n", + "Transformed data shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed mask shape: torch.Size([1, 1, 192, 192, 174])\n", + "Transformed data min: 0.0\n", + "Transformed data max: 474.02880859375\n", + "Transformed data mean: 37.61506652832031\n" + ] + } + ], + "source": [ + "trafo = ElasticDistortion(alpha=0.01, std=[0.1, 0.1, 0.000001], dim=3)\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trafo = RadialDistortion(scale=[0.1, 1., 0.1])\n", + "transformed = apply_transform(trafo, batch)\n", + "view_batch(transformed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/rising/transforms/__init__.py b/rising/transforms/__init__.py index 99b95a19..27b5f141 100644 --- a/rising/transforms/__init__.py +++ b/rising/transforms/__init__.py @@ -8,3 +8,5 @@ from rising.transforms.spatial import * from rising.transforms.utility import * from rising.transforms.tensor import * +from rising.transforms.affine import * +from rising.transforms.grid import * diff --git a/rising/transforms/affine.py b/rising/transforms/affine.py index 7beb6dfd..be002381 100644 --- a/rising/transforms/affine.py +++ b/rising/transforms/affine.py @@ -1,10 +1,13 @@ -from rising.transforms.abstract import BaseTransform -from rising.transforms.functional.affine import affine_image_transform -from rising.utils.affine import AffineParamType, \ - assemble_matrix_if_necessary, matrix_to_homogeneous, matrix_to_cartesian -from rising.utils.checktype import check_scalar import torch -from typing import Sequence, Union, Iterable +from torch import Tensor +from typing import Sequence, Union, Iterable, Dict, Tuple + +from rising.transforms.grid import GridTransform +from rising.transforms.functional.affine import create_affine_grid, \ + AffineParamType, parametrize_matrix +from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian +from rising.utils.checktype import check_scalar + __all__ = [ 'Affine', @@ -16,14 +19,11 @@ ] -class Affine(BaseTransform): - def __init__(self, scale: AffineParamType = None, - rotation: AffineParamType = None, - translation: AffineParamType = None, - matrix: torch.Tensor = None, +class Affine(GridTransform): + def __init__(self, + matrix: Union[Tensor, Sequence[Sequence[float]]] = None, keys: Sequence = ('data',), grad: bool = False, - degree: bool = False, output_size: tuple = None, adjust_size: bool = False, interpolation_mode: str = 'bilinear', @@ -37,55 +37,15 @@ def __init__(self, scale: AffineParamType = None, Parameters ---------- - scale : torch.Tensor, int, float, optional - the scale factor(s). Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float, optional - the rotation factor(s). Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples - None will be treated as a translation offset of 0 - matrix : torch.Tensor, optional + matrix : Tensor, optional if given, overwrites the parameters for :param:`scale`, :param:rotation` and :param:`translation`. - Should be a matrix o shape (BATCHSIZE,) NDIM, NDIM+1. - This matrix represents the whole homogeneous transformation matrix + Should be a matrix of shape [(BATCHSIZE,) NDIM, NDIM(+1)] + This matrix represents the whole transformation matrix keys: Sequence keys which should be augmented grad: bool enable gradient computation inside transformation - degree : bool - whether the given rotation(s) are in degrees. - Only valid for rotation parameters, which aren't passed as full - transformation matrix. output_size : Iterable if given, this will be the resulting image size. Defaults to ``None`` @@ -105,87 +65,76 @@ def __init__(self, scale: AffineParamType = None, referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. **kwargs : - additional keyword arguments passed to the affine transform - - Notes - ----- - If a :param:`matrix` is specified, it overwrites all arguments given - for :param:`scale`, :param:rotation` and :param:`translation` + additional keyword arguments passed to grid sample """ - super().__init__(augment_fn=affine_image_transform, - keys=keys, - grad=grad, + super().__init__(keys=keys, grad=grad, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, **kwargs) - self.scale = scale - self.rotation = rotation - self.translation = translation self.matrix = matrix - self.degree = degree self.output_size = output_size self.adjust_size = adjust_size - self.interpolation_mode = interpolation_mode - self.padding_mode = padding_mode - self.align_corners = align_corners - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: """ Assembles the matrix (and takes care of batching and having it on the right device and in the correct dtype and dimensionality). Parameters ---------- - **data : - the data to be transformed. Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix - - """ - - batchsize = data[self.keys[0]].shape[0] - ndim = len(data[self.keys[0]].shape) - 2 # channel and batch dim - device = data[self.keys[0]].device - dtype = data[self.keys[0]].dtype - - matrix = assemble_matrix_if_necessary( - batchsize, ndim, scale=self.scale, rotation=self.rotation, - translation=self.translation, matrix=self.matrix, - degree=self.degree, device=device, dtype=dtype) - - return matrix - - def forward(self, **data) -> dict: - """ - Assembles the matrix and applies it to the specified sample-entities. - - Parameters - ---------- - **data : - the data to transform - - Returns - ------- - dict - dictionary containing the transformed data - """ - matrix = self.assemble_matrix(**data) - - for key in self.keys: - data[key] = self.augment_fn( - data[key], matrix_batch=matrix, - output_size=self.output_size, - adjust_size=self.adjust_size, - interpolation_mode=self.interpolation_mode, - padding_mode=self.padding_mode, - align_corners=self.align_corners, - **self.kwargs - ) - - return data + if self.matrix is None: + raise ValueError("Matrix needs to be initialized or overwritten.") + if not torch.is_tensor(self.matrix): + self.matrix = Tensor(self.matrix) + self.matrix = self.matrix.to(device=device, dtype=dtype) + + batchsize = batch_shape[0] + ndim = len(batch_shape) - 2 # channel and batch dim + + # batch dimension missing -> Replicate for each sample in batch + if len(self.matrix.shape) == 2: + self.matrix = self.matrix[None].expand(batchsize, -1, -1).clone() + if self.matrix.shape == (batchsize, ndim, ndim + 1): + return self.matrix + elif self.matrix.shape == (batchsize, ndim, ndim): + return matrix_to_homogeneous(self.matrix)[:, :-1] + elif self.matrix.shape == (batchsize, ndim + 1, ndim + 1): + return matrix_to_cartesian(self.matrix) + + raise ValueError( + "Invalid Shape for affine transformation matrix. " + "Got %s but expected %s" % ( + str(tuple(self.matrix.shape)), + str((batchsize, ndim, ndim + 1)))) + + def create_grid(self, input_size: Sequence[Sequence[int]], + matrix: Tensor = None) -> Dict[Tuple, Tensor]: + grid = {} + for size in input_size: + if tuple(size) not in grid: + grid[tuple(size)] = create_affine_grid( + size, self.assemble_matrix(size), output_size=self.output_size, + adjust_size=self.adjust_size, align_corners=self.align_corners, + ) + return grid + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + return grid def __add__(self, other): """ @@ -194,16 +143,15 @@ def __add__(self, other): Parameters ---------- - other : torch.Tensor, Affine + other : Tensor, Affine the other transformation Returns ------- StackedAffine a stacked affine transformation - """ - if not isinstance(other, Affine): + if not isinstance(other, GridTransform): other = Affine(matrix=other, keys=self.keys, grad=self.grad, output_size=self.output_size, adjust_size=self.adjust_size, @@ -212,12 +160,15 @@ def __add__(self, other): align_corners=self.align_corners, **self.kwargs) - return StackedAffine(self, other, keys=self.keys, grad=self.grad, - output_size=self.output_size, - adjust_size=self.adjust_size, - interpolation_mode=self.interpolation_mode, - padding_mode=self.padding_mode, - align_corners=self.align_corners, **self.kwargs) + if isinstance(other, Affine): + return StackedAffine(self, other, keys=self.keys, grad=self.grad, + output_size=self.output_size, + adjust_size=self.adjust_size, + interpolation_mode=self.interpolation_mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners, **self.kwargs) + else: + return super().__add__(other) def __radd__(self, other): """ @@ -226,16 +177,15 @@ def __radd__(self, other): Parameters ---------- - other : torch.Tensor, Affine + other : Tensor, Affine the other transformation Returns ------- StackedAffine a stacked affine transformation - """ - if not isinstance(other, Affine): + if not isinstance(other, GridTransform): other = Affine(matrix=other, keys=self.keys, grad=self.grad, output_size=self.output_size, adjust_size=self.adjust_size, @@ -243,16 +193,250 @@ def __radd__(self, other): padding_mode=self.padding_mode, align_corners=self.align_corners, **self.kwargs) - return StackedAffine(other, self, grad=other.grad, - output_size=other.output_size, - adjust_size=other.adjust_size, - interpolation_mode=other.interpolation_mode, - padding_mode=other.padding_mode, - align_corners=other.align_corners, - **other.kwargs) + if isinstance(other, Affine): + return StackedAffine(other, self, grad=other.grad, + output_size=other.output_size, + adjust_size=other.adjust_size, + interpolation_mode=other.interpolation_mode, + padding_mode=other.padding_mode, + align_corners=other.align_corners, + **other.kwargs) + else: + return super().__add__(other) + + +class StackedAffine(Affine): + def __init__( + self, + *transforms: Union[Affine, Sequence[ + Union[Sequence[Affine], Affine]]], + keys: Sequence = ('data',), + grad: bool = False, + output_size: tuple = None, + adjust_size: bool = False, + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + **kwargs): + """ + Class Performing an Affine Transformation on a given sample dict. + The transformation will be applied to all the dict-entries specified + in :attr:`keys`. + + Parameters + ---------- + transforms : sequence of Affines + the transforms to stack. Each transform must have a function + called ``assemble_matrix``, which is called to dynamically + assemble stacked matrices. Afterwards these transformations are + stacked by matrix-multiplication to only perform a single + interpolation + keys: Sequence + keys which should be augmented + grad: bool + enable gradient computation inside transformation + output_size : Iterable + if given, this will be the resulting image size. + Defaults to ``None`` + adjust_size : bool + if True, the resulting image size will be calculated dynamically + to ensure that the whole image fits. + interpolation_mode : str + interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear' + padding_mode : + padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros' + align_corners : Geometrically, we consider the pixels of the input as + squares rather than points. If set to True, the extrema (-1 and 1) + are considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. + **kwargs : + additional keyword arguments passed to the affine transform + """ + if isinstance(transforms, (tuple, list)): + if isinstance(transforms[0], (tuple, list)): + transforms = transforms[0] + + # ensure trafos are Affines and not raw matrices + transforms = tuple( + [trafo if isinstance(trafo, Affine) else Affine(matrix=trafo) + for trafo in transforms]) + + super().__init__(keys=keys, grad=grad, + output_size=output_size, adjust_size=adjust_size, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + **kwargs) + + self.transforms = transforms + + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: + """ + Handles the matrix assembly and stacking + + Parameters + ---------- + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid + + Returns + ------- + Tensor + the (batched) transformation matrix + + """ + whole_trafo = None + for trafo in self.transforms: + matrix = matrix_to_homogeneous(trafo.assemble_matrix( + batch_shape=batch_shape, device=device, dtype=dtype + )) + if whole_trafo is None: + whole_trafo = matrix + else: + whole_trafo = torch.bmm(whole_trafo, matrix) + return matrix_to_cartesian(whole_trafo) -class Rotate(Affine): +class BaseAffine(Affine): + def __init__(self, + scale: AffineParamType = None, + rotation: AffineParamType = None, + translation: AffineParamType = None, + degree: bool = False, + image_transform: bool = True, + keys: Sequence = ('data',), + grad: bool = False, + output_size: tuple = None, + adjust_size: bool = False, + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + **kwargs, + ): + """ + Class performing a basic Affine Transformation on a given sample dict. + The transformation will be applied to all the dict-entries specified + in :attr:`keys`. + + Parameters + ---------- + scale : Tensor, int, float, optional + the scale factor(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a scaling factor of 1 + rotation : Tensor, int, float, optional + the rotation factor(s). The rotation is performed in + consecutive order axis0 -> axis1 (-> axis 2). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a rotation factor of 1 + translation : Tensor, int, float + the translation offset(s) relative to image (should be in the + range [0, 1]). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a translation offset of 0 + keys: Sequence + keys which should be augmented + grad: bool + enable gradient computation inside transformation + degree : bool + whether the given rotation(s) are in degrees. + Only valid for rotation parameters, which aren't passed as full + transformation matrix. + output_size : Iterable + if given, this will be the resulting image size. + Defaults to ``None`` + adjust_size : bool + if True, the resulting image size will be calculated dynamically + to ensure that the whole image fits. + interpolation_mode : str + interpolation mode to calculate output values + 'bilinear' | 'nearest'. Default: 'bilinear' + padding_mode : + padding mode for outside grid values + 'zeros' | 'border' | 'reflection'. Default: 'zeros' + align_corners : Geometrically, we consider the pixels of the input as + squares rather than points. If set to True, the extrema (-1 and 1) + are considered as referring to the center points of the input’s + corner pixels. If set to False, they are instead considered as + referring to the corner points of the input’s corner pixels, + making the sampling more resolution agnostic. + **kwargs : + additional keyword arguments passed to the affine transform + """ + super().__init__(keys=keys, grad=grad, output_size=output_size, + adjust_size=adjust_size, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + **kwargs) + self.scale = scale + self.rotation = rotation + self.translation = translation + self.degree = degree + self.image_transform = image_transform + + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: + """ + Assembles the matrix (and takes care of batching and having it on the + right device and in the correct dtype and dimensionality). + + Parameters + ---------- + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid + + Returns + ------- + Tensor + the (batched) transformation matrix + """ + batchsize = batch_shape[0] + ndim = len(batch_shape) - 2 # channel and batch dim + + self.matrix = parametrize_matrix( + scale=self.scale, rotation=self.rotation, translation=self.translation, + batchsize=batchsize, ndim=ndim, degree=self.degree, + device=device, dtype=dtype, image_transform=self.image_transform) + return self.matrix + + +class Rotate(BaseAffine): def __init__(self, rotation: AffineParamType, keys: Sequence = ('data',), @@ -266,23 +450,22 @@ def __init__(self, **kwargs): """ Class Performing a Rotation-OnlyAffine Transformation on a given - sample dict. + sample dict. The rotation is applied in consecutive order: + rot axis 0 -> rot axis 1 -> rot axis 2 The transformation will be applied to all the dict-entries specified in :attr:`keys`. Parameters ---------- - rotation : torch.Tensor, int, float, optional + rotation : Tensor, int, float, optional the rotation factor(s). Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension None will be treated as a rotation factor of 1 keys: Sequence keys which should be augmented @@ -312,11 +495,6 @@ def __init__(self, making the sampling more resolution agnostic. **kwargs : additional keyword arguments passed to the affine transform - - Warnings - -------- - This transform is not applied around the image center - """ super().__init__(scale=None, rotation=rotation, @@ -333,7 +511,7 @@ def __init__(self, **kwargs) -class Translate(Affine): +class Translate(BaseAffine): def __init__(self, translation: AffineParamType, keys: Sequence = ('data',), @@ -343,6 +521,7 @@ def __init__(self, interpolation_mode: str = 'bilinear', padding_mode: str = 'zeros', align_corners: bool = False, + unit: str = 'relative', **kwargs): """ Class Performing an Translation-Only @@ -352,17 +531,16 @@ def __init__(self, Parameters ---------- - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + translation : Tensor, int, float + the translation offset(s). The translation unit can be specified. + Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension None will be treated as a translation offset of 0 keys: Sequence keys which should be augmented @@ -377,7 +555,7 @@ def __init__(self, interpolation_mode : str interpolation mode to calculate output values 'bilinear' | 'nearest'. Default: 'bilinear' - padding_mode : + padding_mode : str padding mode for outside grid values 'zeros' | 'border' | 'reflection'. Default: 'zeros' align_corners : Geometrically, we consider the pixels of the input as @@ -386,9 +564,13 @@ def __init__(self, corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. + unit: str + defines the unit of the translation parameter. + 'pixel': define number of pixels to translate | 'relative': + translation should be in the range [0, 1] and is scaled + with the image size **kwargs : additional keyword arguments passed to the affine transform - """ super().__init__(scale=None, rotation=None, @@ -403,9 +585,40 @@ def __init__(self, padding_mode=padding_mode, align_corners=align_corners, **kwargs) + self.unit = unit + + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: + """ + Assembles the matrix (and takes care of batching and having it on the + right device and in the correct dtype and dimensionality). + Parameters + ---------- + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid -class Scale(Affine): + Returns + ------- + Tensor + the (batched) transformation matrix [N, NDIM, NDIM] + """ + matrix = super().assemble_matrix(batch_shape=batch_shape, + device=device, dtype=dtype) + if self.unit.lower() == 'pixel': + img_size = torch.tensor(batch_shape[2:]).to(matrix) + matrix[..., -1] = matrix[..., -1] / img_size + return matrix + + +class Scale(BaseAffine): def __init__(self, scale: AffineParamType, keys: Sequence = ('data',), @@ -424,17 +637,15 @@ def __init__(self, Parameters ---------- - scale : torch.Tensor, int, float, optional + scale : Tensor, int, float, optional the scale factor(s). Supported are: - * a full transformation matrix of shape - (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be - replicated for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will - be replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for - all batch samples + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension None will be treated as a scaling factor of 1 keys: Sequence keys which should be augmented @@ -484,108 +695,6 @@ def __init__(self, **kwargs) -class StackedAffine(Affine): - def __init__( - self, - *transforms: Union[Affine, Sequence[Union[Sequence[Affine], - Affine]]], - keys: Sequence = ('data',), - grad: bool = False, - output_size: tuple = None, - adjust_size: bool = False, - interpolation_mode: str = 'bilinear', - padding_mode: str = 'zeros', - align_corners: bool = False, - **kwargs): - """ - Class Performing an Affine Transformation on a given sample dict. - The transformation will be applied to all the dict-entries specified - in :attr:`keys`. - - Parameters - ---------- - transforms : sequence of Affines - the transforms to stack. Each transform must have a function - called ``assemble_matrix``, which is called to dynamically - assemble stacked matrices. Afterwards these transformations are - stacked by matrix-multiplication to only perform a single - interpolation - keys: Sequence - keys which should be augmented - grad: bool - enable gradient computation inside transformation - output_size : Iterable - if given, this will be the resulting image size. - Defaults to ``None`` - adjust_size : bool - if True, the resulting image size will be calculated dynamically - to ensure that the whole image fits. - interpolation_mode : str - interpolation mode to calculate output values - 'bilinear' | 'nearest'. Default: 'bilinear' - padding_mode : - padding mode for outside grid values - 'zeros' | 'border' | 'reflection'. Default: 'zeros' - align_corners : Geometrically, we consider the pixels of the input as - squares rather than points. If set to True, the extrema (-1 and 1) - are considered as referring to the center points of the input’s - corner pixels. If set to False, they are instead considered as - referring to the corner points of the input’s corner pixels, - making the sampling more resolution agnostic. - **kwargs : - additional keyword arguments passed to the affine transform - - """ - - if isinstance(transforms, (tuple, list)): - if isinstance(transforms[0], (tuple, list)): - transforms = transforms[0] - - # ensure trafos are Affines and not raw matrices - transforms = tuple( - [trafo if isinstance(trafo, Affine) else Affine(matrix=trafo) - for trafo in transforms]) - - super().__init__(matrix=None, - scale=None, rotation=None, translation=None, - keys=keys, grad=grad, degree=False, - output_size=output_size, adjust_size=adjust_size, - interpolation_mode=interpolation_mode, - padding_mode=padding_mode, - align_corners=align_corners, - **kwargs) - - self.transforms = transforms - - def assemble_matrix(self, **data) -> torch.Tensor: - """ - Handles the matrix assembly and stacking - - Parameters - ---------- - **data : - the data to be transformed. Will be used to determine batchsize, - dimensionality, dtype and device - - Returns - ------- - torch.Tensor - the (batched) transformation matrix - - """ - whole_trafo = None - - for trafo in self.transforms: - matrix = matrix_to_homogeneous(trafo.assemble_matrix(**data)) - - if whole_trafo is None: - whole_trafo = matrix - else: - whole_trafo = torch.bmm(whole_trafo, matrix) - - return matrix_to_cartesian(whole_trafo) - - class Resize(Scale): def __init__(self, size: Union[int, Iterable], @@ -646,24 +755,31 @@ def __init__(self, align_corners=align_corners, **kwargs) - def assemble_matrix(self, **data) -> torch.Tensor: + def assemble_matrix(self, + batch_shape: Sequence[int], + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> Tensor: """ Handles the matrix assembly and calculates the scale factors for resizing Parameters ---------- - **data : - the data to be transformed. Will be used to determine batchsize, - dimensionality, dtype and device + batch_shape : Sequence[int] + shape of batch + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- - torch.Tensor + Tensor the (batched) transformation matrix """ - curr_img_size = data[self.keys[0]].shape[2:] + curr_img_size = batch_shape[2:] was_scalar = check_scalar(self.output_size) @@ -673,7 +789,8 @@ def assemble_matrix(self, **data) -> torch.Tensor: self.scale = [self.output_size[i] / curr_img_size[-i] for i in range(len(curr_img_size))] - matrix = super().assemble_matrix(**data) + matrix = super().assemble_matrix(batch_shape=batch_shape, + device=device, dtype=dtype) if was_scalar: self.output_size = self.output_size[0] diff --git a/rising/transforms/functional/__init__.py b/rising/transforms/functional/__init__.py index b1a5aabe..1117a081 100644 --- a/rising/transforms/functional/__init__.py +++ b/rising/transforms/functional/__init__.py @@ -5,3 +5,4 @@ from rising.transforms.functional.tensor import * from rising.transforms.functional.utility import * from rising.transforms.functional.channel import * +from rising.transforms.functional.kernel import * diff --git a/rising/transforms/functional/affine.py b/rising/transforms/functional/affine.py index db98401a..8f98d211 100644 --- a/rising/transforms/functional/affine.py +++ b/rising/transforms/functional/affine.py @@ -1,15 +1,411 @@ import torch +import warnings +from torch import Tensor +from typing import Union, Sequence + from rising.utils.affine import points_to_cartesian, matrix_to_homogeneous, \ - points_to_homogeneous, matrix_revert_coordinate_order + points_to_homogeneous, unit_box, get_batched_eye, deg_to_rad from rising.utils.checktype import check_scalar -import warnings + __all__ = [ - 'affine_image_transform', - 'affine_point_transform' + 'create_affine_grid', + 'affine_point_transform', + "create_rotation", + "create_scale", + "create_translation", ] +AffineParamType = Union[int, float, Sequence, torch.Tensor] + + +def expand_scalar_param(param: AffineParamType, batchsize: int, ndim: int) -> Tensor: + """ + Bring affine params to shape (batchsize, ndim) + + Parameters + ---------- + param: AffineParamType + affine parameter + batchsize: int + size of batch + ndim: int + number of spatial dimensions + + Returns + ------- + Tensor: + affine params in correct shape + """ + if check_scalar(param): + return torch.tensor([[param] * ndim] * batchsize).float() + + if not torch.is_tensor(param): + param = torch.tensor(param) + else: + param = param.clone() + + if not param.ndimension() == 2: + if param.shape[0] == ndim: # scalar per dim + param = param.reshape(1, -1).expand(batchsize, ndim) + elif param.shape[0] == batchsize: # scalar per batch + param = param.reshape(-1, 1).expand(batchsize, ndim) + else: + raise ValueError("Unknown param for expanding. " + f"Found {param} for batchsize {batchsize} and ndim {ndim}") + assert all([i == j for i, j in zip(param.shape, (batchsize, ndim))]), \ + (f"Affine param need to have shape (batchsize, ndim)" + f"({(batchsize, ndim)}) but found {param.shape}") + return param.float() + + +def create_scale(scale: AffineParamType, + batchsize: int, ndim: int, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + image_transform: bool = True) -> torch.Tensor: + """ + Formats the given scale parameters to a homogeneous transformation matrix + + Parameters + ---------- + scale : torch.Tensor, int, float + the scale factor(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a scaling factor of 1 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. Defaults to the default dtype + image_transform: bool + inverts the scale matrix to match expected behavior when applied to + an image, e.g. scale>1 increases the size of an image but decrease + the size of an grid + + Returns + ------- + torch.Tensor + the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is + the batch size and NDIM is the number of spatial dimensions + """ + if scale is None: + scale = 1 + + scale = expand_scalar_param(scale, batchsize, ndim).to( + device=device, dtype=dtype) + if image_transform: + scale = 1 / scale + scale_matrix = torch.stack( + [eye * s for eye, s in zip(get_batched_eye( + batchsize=batchsize, ndim=ndim, device=device, dtype=dtype), scale)]) + return matrix_to_homogeneous(scale_matrix) + + +def create_translation(offset: AffineParamType, + batchsize: int, ndim: int, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + image_transform: bool = True) -> torch.Tensor: + """ + Formats the given translation parameters to a homogeneous transformation + matrix + + Parameters + ---------- + offset : torch.Tensor, int, float + the translation offset(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a translation offset of 0 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. Defaults to the default dtype + image_transform: bool + inverts the translation matrix to match expected behavior when applied + to an image, e.g. translation > 0 should move the image in the + positive direction of an axis but the grid in the negative direction + + Returns + ------- + torch.Tensor + the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is + the batch size and NDIM is the number of spatial dimensions + """ + if offset is None: + offset = 0 + offset = expand_scalar_param(offset, batchsize, ndim).to( + device=device, dtype=dtype) + eye_batch = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, dtype=dtype) + translation_matrix = torch.stack([torch.cat([eye, o.view(-1, 1)], dim=1) + for eye, o in zip(eye_batch, offset)]) + if image_transform: + translation_matrix[..., -1] = -translation_matrix[..., -1] + return matrix_to_homogeneous(translation_matrix) + + +def create_rotation(rotation: AffineParamType, + batchsize: int, ndim: int, + degree: bool = False, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None) -> torch.Tensor: + """ + Formats the given scale parameters to a homogeneous transformation matrix + + Parameters + ---------- + rotation : torch.Tensor, int, float + the rotation factor(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a rotation factor of 0 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + degree : bool + whether the given rotation(s) are in degrees. + Only valid for rotation parameters, which aren't passed as full + transformation matrix. + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. Defaults to the default dtype + + Returns + ------- + torch.Tensor + the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1], N is + the batch size and NDIM is the number of spatial dimensions + + """ + if rotation is None: + rotation = 0 + num_rot_params = 1 if ndim == 2 else ndim + + rotation = expand_scalar_param(rotation, batchsize, num_rot_params) + if degree: + rotation = deg_to_rad(rotation) + + matrix_fn = create_rotation_2d if ndim == 2 else create_rotation_3d + sin, cos = torch.sin(rotation), torch.cos(rotation) + rotation_matrix = torch.stack([matrix_fn(s, c) for s, c in zip(sin, cos)]).to( + device=device, dtype=dtype) + return matrix_to_homogeneous(rotation_matrix) + + +def create_rotation_2d(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a 2d rotation matrix + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor + rotation matrix, [2, 2] + """ + return torch.tensor([[cos.clone(), -sin.clone()], [sin.clone(), cos.clone()]]) + + +def create_rotation_3d(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a 3d rotation matrix which sequentially applies the rotation + around axis (rot axis 0 -> rot axis 1 -> rot axis 2) + + Parameters + ---------- + sin: Tensor + sin values to use for the rotation, (axis 0, axis 1, axis 2)[3] + cos: Tensor + cos values to use for the rotation, (axis 0, axis 1, axis 2)[3] + + Returns + ------- + Tensor + rotation matrix, [3, 3] + """ + rot_0 = create_rotation_3d_0(sin[0], cos[0]) + rot_1 = create_rotation_3d_1(sin[1], cos[1]) + rot_2 = create_rotation_3d_2(sin[2], cos[2]) + return rot_2 @ (rot_1 @ rot_0) + + +def create_rotation_3d_0(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a rotation matrix around the zero-th axis + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor: + rotation matrix, [3, 3] + """ + return torch.tensor([[1., 0., 0.], + [0., cos.clone(), -sin.clone()], + [0., sin.clone(), cos.clone()]]) + + +def create_rotation_3d_1(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a rotation matrix around the first axis + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor: + rotation matrix, [3, 3] + """ + return torch.tensor([[cos.clone(), 0., sin.clone()], + [0., 1., 0.], + [-sin.clone(), 0., cos.clone()]]) + + +def create_rotation_3d_2(sin: Tensor, cos: Tensor) -> Tensor: + """ + Create a rotation matrix around the second axis + + Parameters + ---------- + sin: Tensor + sin value to use for rotation matrix, [1] + cos: Tensor + cos value to use for rotation matrix, [1] + + Returns + ------- + Tensor: + rotation matrix, [3, 3] + """ + return torch.tensor([[cos.clone(), -sin.clone(), 0.], + [sin.clone(), cos.clone(), 0.], + [0., 0., 1.]]) + + +def parametrize_matrix(scale: AffineParamType, + rotation: AffineParamType, + translation: AffineParamType, + batchsize: int, ndim: int, + degree: bool = False, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + image_transform: bool = True, + ) -> torch.Tensor: + """ + Formats the given scale parameters to a homogeneous transformation matrix + + Parameters + ---------- + scale : torch.Tensor, int, float + the scale factor(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a scaling factor of 1 + rotation : torch.Tensor, int, float + the rotation factor(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a rotation factor of 1 + translation : torch.Tensor, int, float + the translation offset(s). Supported are: + * a single parameter (as float or int), which will be replicated + for all dimensions and batch samples + * a parameter per sample, which will be + replicated for all dimensions + * a parameter per dimension, which will be replicated for all + batch samples + * a parameter per sampler per dimension + None will be treated as a translation offset of 0 + batchsize : int + the number of samples per batch + ndim : int + the dimensionality of the transform + degree : bool + whether the given rotation(s) are in degrees. + Only valid for rotation parameters, which aren't passed as full + transformation matrix. + device : torch.device, str, optional + the device to put the resulting tensor to. Defaults to the default + device + dtype : torch.dtype, str, optional + the dtype of the resulting trensor. Defaults to the default dtype + image_transform: bool + adjusts transformation matrices such that they match the expected + behavior on images (see :func:`create_scale` and + :func:`create_translation` for more info) + + Returns + ------- + torch.Tensor + the transformation matrix [N, NDIM, NDIM+1], N is + the batch size and NDIM is the number of spatial dimensions + """ + scale = create_scale(scale, batchsize=batchsize, ndim=ndim, + device=device, dtype=dtype, + image_transform=image_transform) + rotation = create_rotation(rotation, batchsize=batchsize, ndim=ndim, + degree=degree, device=device, dtype=dtype) + translation = create_translation(translation, batchsize=batchsize, + ndim=ndim, device=device, dtype=dtype, + image_transform=image_transform) + return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] + + def affine_point_transform(point_batch: torch.Tensor, matrix_batch: torch.Tensor) -> torch.Tensor: """ @@ -18,60 +414,58 @@ def affine_point_transform(point_batch: torch.Tensor, Parameters ---------- point_batch : torch.Tensor - a point batch of shape BATCHSIZE x NUM_POINTS x NDIM + a point batch of shape [N, NP, NDIM] NP is the number of points, + N is the batch size, NDIM is the number of spatial dimensions matrix_batch : torch.Tensor - a batch of affine matrices with shape N x NDIM-1 x NDIM + a batch of affine matrices with shape [N, NDIM, NDIM + 1], + N is the batch size and NDIM is the number of spatial dimensions Returns ------- torch.Tensor the batch of transformed points in cartesian coordinates) - + [N, NP, NDIM] NP is the number of points, N is the batch size, + NDIM is the number of spatial dimensions """ point_batch = points_to_homogeneous(point_batch) matrix_batch = matrix_to_homogeneous(matrix_batch) - - matrix_batch = matrix_revert_coordinate_order(matrix_batch) - transformed_points = torch.bmm(point_batch, matrix_batch.permute(0, 2, 1)) - return points_to_cartesian(transformed_points) -def affine_image_transform(image_batch: torch.Tensor, - matrix_batch: torch.Tensor, - output_size: tuple = None, - adjust_size: bool = False, - interpolation_mode: str = 'bilinear', - padding_mode: str = 'zeros', - align_corners: bool = False) -> torch.Tensor: +def create_affine_grid(batch_shape: Sequence[int], + matrix_batch: torch.Tensor, + output_size: tuple = None, + adjust_size: bool = False, + align_corners: bool = False, + device: Union[torch.device, str] = None, + dtype: Union[torch.dtype, str] = None, + ) -> torch.Tensor: """ Performs an affine transformation on a batch of images Parameters ---------- - image_batch : torch.Tensor - the batch to transform. Should have shape of N x C x (D x) H x W + batch_shape : Sequence[int] + shape of batch matrix_batch : torch.Tensor - a batch of affine matrices with shape N x NDIM-1 x NDIM + a batch of affine matrices with shape [N, NDIM, NDIM+1] output_size : Iterable if given, this will be the resulting image size. Defaults to ``None`` adjust_size : bool if True, the resulting image size will be calculated dynamically to ensure that the whole image fits. - interpolation_mode : str - interpolation mode to calculate output values 'bilinear' | 'nearest'. - Default: 'bilinear' - padding_mode : - padding mode for outside grid values - 'zeros' | 'border' | 'reflection'. Default: 'zeros' align_corners : Geometrically, we consider the pixels of the input as squares rather than points. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. + device: Union[torch.device, str] + device where grid will be cached + dtype: Union[torch.dtype, str] + data type of grid Returns ------- @@ -91,13 +485,12 @@ def affine_image_transform(image_batch: torch.Tensor, If None of them is set, the resulting image will have the same size as the input image """ - # add batch dimension if necessary if len(matrix_batch.shape) < 3: - matrix_batch = matrix_batch[None, ...].expand(image_batch.size(0), - -1, -1).clone() + matrix_batch = matrix_batch[None, ...].expand( + batch_shape[0], -1, -1).clone() - image_size = image_batch.shape[2:] + image_size = batch_shape[2:] if output_size is not None: if check_scalar(output_size): @@ -106,9 +499,7 @@ def affine_image_transform(image_batch: torch.Tensor, if adjust_size: warnings.warn("Adjust size is mutually exclusive with a " "given output size.", UserWarning) - new_size = output_size - elif adjust_size: new_size = tuple([int(tmp.item()) for tmp in _check_new_img_size(image_size, @@ -116,23 +507,19 @@ def affine_image_transform(image_batch: torch.Tensor, else: new_size = image_size - if len(image_size) < len(image_batch.shape): - missing_dims = len(image_batch.shape) - len(image_size) - new_size = (*image_batch.shape[:missing_dims], *new_size) + if len(image_size) < len(batch_shape): + missing_dims = len(batch_shape) - len(image_size) + new_size = (*batch_shape[:missing_dims], *new_size) - matrix_batch = matrix_batch.to(device=image_batch.device, - dtype=image_batch.dtype) + matrix_batch = matrix_batch.to(device=device, dtype=dtype) - grid = torch.nn.functional.affine_grid(matrix_batch, size=new_size, - align_corners=align_corners) + grid = torch.nn.functional.affine_grid( + matrix_batch, size=new_size, align_corners=align_corners) + return grid - return torch.nn.functional.grid_sample(image_batch, grid, - mode=interpolation_mode, - padding_mode=padding_mode, - align_corners=align_corners) - -def _check_new_img_size(curr_img_size, matrix: torch.Tensor) -> torch.Tensor: +def _check_new_img_size(curr_img_size, matrix: torch.Tensor, + zero_border: bool = False) -> torch.Tensor: """ Calculates the image size so that the whole image content fits the image. The resulting size will be the maximum size of the batch, so that the @@ -144,50 +531,28 @@ def _check_new_img_size(curr_img_size, matrix: torch.Tensor) -> torch.Tensor: the size of the current image. If int, it will be used as size for all image dimensions matrix : torch.Tensor - a batch of affine matrices with shape N x NDIM x NDIM + 1 + a batch of affine matrices with shape [N, NDIM, NDIM+1] + zero_border : bool + whether or not to have a fixed image border at zero Returns ------- torch.Tensor the new image size - """ - n_dim = matrix.size(-1) - 1 - if check_scalar(curr_img_size): curr_img_size = [curr_img_size] * n_dim - - curr_img_size = [tmp - 1 for tmp in curr_img_size] - - if n_dim == 2: - possible_points = torch.tensor([[0., 0.], [0., curr_img_size[1]], - [curr_img_size[0], 0], curr_img_size], - dtype=matrix.dtype, - device=matrix.device) - elif n_dim == 3: - possible_points = torch.tensor( - [ - [0., 0., 0.], - [0., 0., curr_img_size[2]], - [0., curr_img_size[1], 0], - [0., curr_img_size[1], curr_img_size[2]], - [curr_img_size[0], 0., 0.], - [curr_img_size[0], 0., curr_img_size[2]], - [curr_img_size[0], curr_img_size[1], 0.], - curr_img_size - ], device=matrix.device, dtype=matrix.dtype - ) - - else: - raise ValueError('Invalid number of dimensions! Expected One of ' - '{2, 3}, but got %s' % str(n_dim)) + possible_points = unit_box(n_dim, torch.tensor(curr_img_size)).to(matrix) transformed_edges = affine_point_transform( possible_points[None].expand( - matrix.size(0), - *[-1 for _ in possible_points.shape]).clone(), + matrix.size(0), *[-1 for _ in possible_points.shape]).clone(), matrix) - return (transformed_edges.max(1)[0] - - transformed_edges.min(1)[0]).max(0)[0] + 1 + if zero_border: + substr = 0 + else: + substr = transformed_edges.min(1)[0] + + return (transformed_edges.max(1)[0] - substr).max(0)[0] diff --git a/rising/transforms/functional/crop.py b/rising/transforms/functional/crop.py index cc358c4c..2c47cbc2 100644 --- a/rising/transforms/functional/crop.py +++ b/rising/transforms/functional/crop.py @@ -7,18 +7,24 @@ __all__ = ["crop", "center_crop", "random_crop"] -def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int]): +def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int], + grid_crop: bool = False): """ Extract crop from last dimensions of data Parameters ---------- data: torch.Tensor - input tensor + input tensor [... , spatial dims] spatial dims can be arbitrary + spatial dimensions. Leading dimensions will be preserved. corner: Sequence[int] top left corner point size: Sequence[int] size of patch + grid_crop: bool + crop from grid of shape [N, spatial dims, NDIM], where N is the batch + size, spatial dims can be arbitrary spatial dimensions and NDIM + is the number of spatial dimensions Returns ------- @@ -26,24 +32,33 @@ def crop(data: torch.Tensor, corner: Sequence[int], size: Sequence[int]): cropped data """ _slices = [] - if len(corner) < data.ndim: - for i in range(data.ndim - len(corner)): + ndim = data.ndimension() - int(bool(grid_crop)) + if len(corner) < ndim: + for i in range(ndim - len(corner)): _slices.append(slice(0, data.shape[i])) _slices = _slices + [slice(c, c + s) for c, s in zip(corner, size)] + if grid_crop: + _slices.append(slice(0, data.shape[-1])) return data[_slices] -def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]]) -> torch.Tensor: +def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]], + grid_crop: bool = False) -> torch.Tensor: """ Crop patch from center Parameters ---------- data: torch.Tensor - input tensor + input tensor [... , spatial dims] spatial dims can be arbitrary + spatial dimensions. Leading dimensions will be preserved. size: Union[int, Sequence[int]] size of patch + grid_crop: bool + crop from grid of shape [N, spatial dims, NDIM], where N is the batch + size, spatial dims can be arbitrary spatial dimensions and NDIM + is the number of spatial dimensions Returns ------- @@ -55,23 +70,34 @@ def center_crop(data: torch.Tensor, size: Union[int, Sequence[int]]) -> torch.Te if not isinstance(size[0], int): size = [int(s) for s in size] - corner = [int(round((img_dim - crop_dim) / 2.)) for img_dim, crop_dim in zip(data.shape[2:], size)] - return crop(data, corner, size) + if grid_crop: + data_shape = data.shape[1:-1] + else: + data_shape = data.shape[2:] + + corner = [int(round((img_dim - crop_dim) / 2.)) for img_dim, crop_dim in zip(data_shape, size)] + return crop(data, corner, size, grid_crop=grid_crop) def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], - dist: Union[int, Sequence[int]] = 0) -> torch.Tensor: + dist: Union[int, Sequence[int]] = 0, + grid_crop: bool = False) -> torch.Tensor: """ Crop random patch/volume from input tensor Parameters ---------- data: torch.Tensor - input tensor + input tensor [... , spatial dims] spatial dims can be arbitrary + spatial dimensions. Leading dimensions will be preserved. size: Union[int, Sequence[int]] size of patch/volume dist: Union[int, Sequence[int]] minimum distance to border. By default zero + grid_crop: bool + crop from grid of shape [N, spatial dims, NDIM], where N is the batch + size, spatial dims can be arbitrary spatial dimensions and NDIM + is the number of spatial dimensions Returns ------- @@ -85,9 +111,14 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]], if not isinstance(size[0], int): size = [int(s) for s in size] - if any([crop_dim + dist_dim >= img_dim for img_dim, crop_dim, dist_dim in zip(data.shape[2:], size, dist)]): + if grid_crop: + data_shape = data.shape[1:-1] + else: + data_shape = data.shape[2:] + + if any([crop_dim + dist_dim >= img_dim for img_dim, crop_dim, dist_dim in zip(data_shape, size, dist)]): raise TypeError(f"Crop can not be realized with given size {size} and dist {dist}.") corner = [random.randrange(0, img_dim - crop_dim - dist_dim) for - img_dim, crop_dim, dist_dim in zip(data.shape[2:], size, dist)] - return crop(data, corner, size) + img_dim, crop_dim, dist_dim in zip(data_shape, size, dist)] + return crop(data, corner, size, grid_crop=grid_crop) diff --git a/rising/transforms/functional/kernel.py b/rising/transforms/functional/kernel.py new file mode 100644 index 00000000..e287bd41 --- /dev/null +++ b/rising/transforms/functional/kernel.py @@ -0,0 +1,33 @@ +import math +import torch + +from typing import Sequence, Union + +from rising.utils import check_scalar + +__all__ = ["gaussian_kernel"] + + +def gaussian_kernel(dim: int, kernel_size: Union[int, Sequence[int]], + std: Union[float, Sequence[float]], in_channels: int = 1) -> torch.Tensor: + if check_scalar(kernel_size): + kernel_size = [kernel_size] * dim + if check_scalar(std): + std = [std] * dim + # The gaussian kernel is the product of the gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) + for size in kernel_size]) + + for size, std, mgrid in zip(kernel_size, std, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / kernel.sum() + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(in_channels, *[1] * (kernel.dim() - 1)) + kernel.requires_grad = False + return kernel.contiguous() diff --git a/rising/transforms/grid.py b/rising/transforms/grid.py new file mode 100644 index 00000000..961fd4d6 --- /dev/null +++ b/rising/transforms/grid.py @@ -0,0 +1,204 @@ +from typing import Sequence, Union, Dict, Tuple + +import torch + +from abc import abstractmethod +from torch import Tensor + +from rising.transforms import AbstractTransform, GaussianSmoothing +from rising.utils.affine import get_batched_eye, matrix_to_homogeneous +from rising.transforms.functional import center_crop, random_crop + + +__all__ = ["GridTransform", "StackedGridTransform", + "CenterCropGrid", "RandomCropGrid", "ElasticDistortion", "RadialDistortion"] + + +class GridTransform(AbstractTransform): + def __init__(self, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs, + ): + super().__init__(grad=grad) + self.keys = keys + self.interpolation_mode = interpolation_mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.kwargs = kwargs + + self.grid: Dict[Tuple, Tensor] = None + + def forward(self, **data) -> dict: + if self.grid is None: + self.grid = self.create_grid([data[key].shape for key in self.keys]) + + self.grid = self.augment_grid(self.grid) + + for key in self.keys: + _grid = self.grid[tuple(data[key].shape)] + _grid = _grid.to(data[key]) + + data[key] = torch.nn.functional.grid_sample( + data[key], _grid, mode=self.interpolation_mode, + padding_mode=self.padding_mode, align_corners=self.align_corners) + self.grid = None + return data + + @abstractmethod + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + raise NotImplementedError + + def create_grid(self, input_size: Sequence[Sequence[int]], + matrix: Tensor = None) -> Dict[Tuple, Tensor]: + if matrix is None: + matrix = get_batched_eye(batchsize=input_size[0][0], ndim=len(input_size[0]) - 2) + matrix = matrix_to_homogeneous(matrix)[:, :-1] + + grid = {} + for size in input_size: + if tuple(size) not in grid: + grid[tuple(size)] = torch.nn.functional.affine_grid( + matrix, size=size, align_corners=self.align_corners) + return grid + + def __add__(self, other): + if not isinstance(other, GridTransform): + raise ValueError("Concatenation is only supported for grid transforms.") + return StackedGridTransform(self, other) + + def __radd__(self, other): + if not isinstance(other, GridTransform): + raise ValueError("Concatenation is only supported for grid transforms.") + return StackedGridTransform(other, self) + + +class StackedGridTransform(GridTransform): + def __init__(self, *transforms: Union[GridTransform, Sequence[GridTransform]]): + super().__init__(keys=None, interpolation_mode=None, padding_mode=None, + align_corners=None) + if isinstance(transforms, (tuple, list)): + if isinstance(transforms[0], (tuple, list)): + transforms = transforms[0] + self.transforms = transforms + + def create_grid(self, input_size: Sequence[Sequence[int]], matrix: Tensor = None) -> \ + Dict[Tuple, Tensor]: + return self.transforms[0].create_grid(input_size=input_size, matrix=matrix) + + def augment_grid(self, grid: Tensor) -> Tensor: + for transform in self.transforms: + grid = transform.augment_grid(grid) + return grid + + +class CenterCropGrid(GridTransform): + def __init__(self, + size: Union[int, Sequence[int]], + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.size = size + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + return {key: center_crop(item, size=self.size, grid_crop=True) + for key, item in grid.items()} + + +class RandomCropGrid(GridTransform): + def __init__(self, + size: Union[int, Sequence[int]], + dist: Union[int, Sequence[int]] = 0, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.size = size + self.dist = dist + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + return {key: random_crop(item, size=self.size, dist=self.dist, grid_crop=True) + for key, item in grid.items()} + + +class ElasticDistortion(GridTransform): + def __init__(self, + std: Union[float, Sequence[float]], + alpha: float, + dim: int = 2, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.std = std + self.alpha = alpha + self.gaussian = GaussianSmoothing(in_channels=1, kernel_size=7, std=self.std, + dim=dim, stride=1, padding=3) + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + for key in grid.keys(): + random_offsets = torch.rand(1, 1, *grid[key].shape[1:-1]) * 2 - 1 + random_offsets = self.gaussian(**{"data": random_offsets})["data"] * self.alpha + print(random_offsets.shape) + print(grid[key].shape) + print(random_offsets.max()) + print(random_offsets.min()) + grid[key] += random_offsets[:, 0, ..., None] + return grid + + +class RadialDistortion(GridTransform): + def __init__(self, + scale: float, + keys: Sequence[str] = ('data',), + interpolation_mode: str = 'bilinear', + padding_mode: str = 'zeros', + align_corners: bool = False, + grad: bool = False, + **kwargs,): + super().__init__(keys=keys, interpolation_mode=interpolation_mode, + padding_mode=padding_mode, align_corners=align_corners, + grad=grad, **kwargs) + self.scale = scale + + def augment_grid(self, grid: Dict[Tuple, Tensor]) -> Dict[Tuple, Tensor]: + + new_grid = {key: radial_distortion_grid(item, scale=self.scale) + for key, item in grid.items()} + print(new_grid) + return new_grid + + +def radial_distortion_grid(grid: Tensor, scale: float) -> Tensor: + # spatial_shape = grid.shape[1:-1] + # new_grid = torch.stack([torch.meshgrid( + # *[torch.linspace(-1, 1, i) for i in spatial_shape])], dim=-1).to(grid) + # print(new_grid.shape) + # + # distortion = + + dist = torch.norm(grid, 2, dim=-1, keepdim=True) + dist = dist / dist.max() + distortion = (scale[0] * dist.pow(3) + scale[1] * dist.pow(2) + scale[2] * dist) / 3 + print(distortion.max()) + print(distortion.min()) + return grid * (1 - distortion) diff --git a/rising/transforms/kernel.py b/rising/transforms/kernel.py index 5937e5d5..f6730aa1 100644 --- a/rising/transforms/kernel.py +++ b/rising/transforms/kernel.py @@ -4,6 +4,7 @@ from .abstract import AbstractTransform from rising.utils import check_scalar +from rising.transforms.functional import gaussian_kernel __all__ = ["KernelTransform", "GaussianSmoothing"] @@ -111,10 +112,12 @@ def forward(self, **data) -> dict: class GaussianSmoothing(KernelTransform): - def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], + def __init__(self, + in_channels: int, + kernel_size: Union[int, Sequence], std: Union[int, Sequence], dim: int = 2, stride: Union[int, Sequence] = 1, padding: Union[int, Sequence] = 0, - padding_mode: str = 'reflect', keys: Sequence = ('data',), grad: bool = False, + padding_mode: str = 'constant', keys: Sequence = ('data',), grad: bool = False, **kwargs): """ Perform Gaussian Smoothing. @@ -150,9 +153,8 @@ def __init__(self, in_channels: int, kernel_size: Union[int, Sequence], -------- :func:`torch.functional.pad` """ - if check_scalar(std): - std = [std] * dim self.std = std + self.spatial_dim = dim super().__init__(in_channels=in_channels, kernel_size=kernel_size, dim=dim, stride=stride, padding=padding, padding_mode=padding_mode, keys=keys, grad=grad, **kwargs) @@ -160,22 +162,6 @@ def create_kernel(self) -> torch.Tensor: """ Create gaussian blur kernel """ - # The gaussian kernel is the product of the gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid([ - torch.arange(size, dtype=torch.float32) - for size in self.kernel_size - ]) - - for size, std, mgrid in zip(self.kernel_size, self.std, meshgrids): - mean = (size - 1) / 2 - kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2) - - # Make sure sum of values in gaussian kernel equals 1. - kernel = kernel / kernel.sum() - - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(self.in_channels, *[1] * (kernel.dim() - 1)) - kernel.requires_grad = False - return kernel.contiguous() + return gaussian_kernel(kernel_size=self.kernel_size, + std=self.std, in_channels=self.in_channels, + dim=self.spatial_dim) diff --git a/rising/utils/affine.py b/rising/utils/affine.py index ee21a9ab..3b916fe4 100644 --- a/rising/utils/affine.py +++ b/rising/utils/affine.py @@ -1,10 +1,9 @@ import torch -from rising.utils.checktype import check_scalar +import itertools + from math import pi from typing import Union, Sequence -AffineParamType = Union[int, float, Sequence, torch.Tensor] - def points_to_homogeneous(batch: torch.Tensor) -> torch.Tensor: """ @@ -35,7 +34,7 @@ def matrix_to_homogeneous(batch: torch.Tensor) -> torch.Tensor: Parameters ---------- batch : torch.Tensor - the batch of matrices to convert + the batch of matrices to convert [N, dim, dim] Returns ------- @@ -47,10 +46,9 @@ def matrix_to_homogeneous(batch: torch.Tensor) -> torch.Tensor: missing = batch.new_zeros(size=(*batch.shape[:-1], 1)) batch = torch.cat([batch, missing], dim=-1) - missing = torch.zeros((batch.size(0), - *[1 for tmp in batch.shape[1:-1]], - batch.size(-1)), - device=batch.device, dtype=batch.dtype) + missing = torch.zeros( + (batch.size(0), *[1 for tmp in batch.shape[1:-1]], batch.size(-1)), + device=batch.device, dtype=batch.dtype) missing[..., -1] = 1 @@ -106,28 +104,6 @@ def points_to_cartesian(batch: torch.Tensor) -> torch.Tensor: return batch[..., :-1] / batch[..., -1, None] -def matrix_revert_coordinate_order(batch: torch.Tensor) -> torch.Tensor: - """ - Reverts the coordinate order of a matrix (e.g. from xyz to zyx). - - Parameters - ---------- - batch : torch.Tensor - the batched transformation matrices; Should be of shape - BATCHSIZE x NDIM x NDIM - - Returns - ------- - torch.Tensor - the matrix performing the same transformation on vectors with a - reversed coordinate order - - """ - batch[:, :-1, :] = batch[:, :-1, :].flip(1).clone() - batch[:, :-1, :-1] = batch[:, :-1, :-1].flip(2).clone() - return batch - - def get_batched_eye(batchsize: int, ndim: int, device: Union[torch.device, str] = None, dtype: Union[torch.dtype, str] = None) -> torch.Tensor: @@ -156,161 +132,6 @@ def get_batched_eye(batchsize: int, ndim: int, 1, ndim, ndim).expand(batchsize, -1, -1).clone() -def _format_scale(scale: AffineParamType, - batchsize: int, ndim: int, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: - """ - Formats the given scale parameters to a homogeneous transformation matrix - - Parameters - ---------- - scale : torch.Tensor, int, float - the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a scaling factor of 1 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. Defaults to the default dtype - - Returns - ------- - torch.Tensor - the homogeneous transformation matrix - - """ - - if scale is None: - scale = 1 - - if check_scalar(scale): - - scale = get_batched_eye(batchsize=batchsize, ndim=ndim, device=device, - dtype=dtype) * scale - - elif not torch.is_tensor(scale): - scale = torch.tensor(scale, dtype=dtype, device=device) - - # scale must be tensor by now - scale = scale.to(device=device, dtype=dtype) - - # scale is already batched matrix - if scale.size() == (batchsize, ndim, ndim) or scale.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(scale) - - # scale is batched matrix with same element for each dimension or just - # not diagonalized - if scale.size() == (batchsize, ndim) or scale.size() == (batchsize,): - new_scale = get_batched_eye(batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - - return matrix_to_homogeneous(new_scale * scale.view(batchsize, -1, 1)) - - # scale contains a non-diagonalized form (will be repeated for each batch - # item) - elif scale.size() == (ndim,): - return matrix_to_homogeneous( - torch.diag(scale).view(1, ndim, ndim).expand(batchsize, - -1, -1).clone()) - - # scale contains a diagonalized but not batched matrix - # (will be repeated for each batch item) - elif scale.size() == (ndim, ndim): - return matrix_to_homogeneous( - scale.view(1, ndim, ndim).expand(batchsize, -1, -1).clone()) - - raise ValueError("Unknown shape for scale matrix: %s" - % str(tuple(scale.size()))) - - -def _format_translation(offset: AffineParamType, - batchsize: int, ndim: int, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None - ) -> torch.Tensor: - """ - Formats the given translation parameters to a homogeneous transformation - matrix - - Parameters - ---------- - offset : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a translation offset of 0 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. Defaults to the default dtype - - Returns - ------- - torch.Tensor - the homogeneous transformation matrix - - """ - if offset is None: - offset = 0 - - if check_scalar(offset): - offset = torch.tensor([offset] * ndim, device=device, dtype=dtype) - - elif not torch.is_tensor(offset): - offset = torch.tensor(offset, device=device, dtype=dtype) - - # assumes offset to be tensor from now on - offset = offset.to(device=device, dtype=dtype) - - # translation matrix already built - if offset.size() == (batchsize, ndim + 1, ndim + 1): - return offset - elif offset.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(offset) - - # not completely built so far -> bring in shape (batchsize, ndim) - if offset.size() == (batchsize,): - offset = offset.view(-1, 1).expand(-1, ndim).clone() - elif offset.size() == (ndim,): - offset = offset.view(1, -1).expand(batchsize, -1).clone() - elif not offset.size() == (batchsize, ndim): - raise ValueError("Unknown shape for offsets: %s" - % str(tuple(offset.shape))) - - # directly build homogeneous form -> use dim+1 - whole_translation_matrix = get_batched_eye(batchsize=batchsize, - ndim=ndim + 1, device=device, - dtype=dtype) - - whole_translation_matrix[:, :-1, -1] = offset.clone() - return whole_translation_matrix - - def deg_to_rad(angles: Union[torch.Tensor, float, int] ) -> Union[torch.Tensor, float, int]: """ @@ -330,291 +151,24 @@ def deg_to_rad(angles: Union[torch.Tensor, float, int] return angles * pi / 180 -def _format_rotation(rotation: AffineParamType, - batchsize: int, ndim: int, - degree: bool = False, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: - """ - Formats the given scale parameters to a homogeneous transformation matrix - - Parameters - ---------- - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM(+1) x NDIM(+1)) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a rotation factor of 0 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - degree : bool - whether the given rotation(s) are in degrees. - Only valid for rotation parameters, which aren't passed as full - transformation matrix. - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. Defaults to the default dtype - - Returns - ------- - torch.Tensor - the homogeneous transformation matrix - - """ - if rotation is None: - rotation = 0 - - num_rot_params = 1 if ndim == 2 else ndim - - if check_scalar(rotation): - rotation = torch.ones(batchsize, num_rot_params, - device=device, dtype=dtype) * rotation - - elif not torch.is_tensor(rotation): - rotation = torch.tensor(rotation, device=device, dtype=dtype) - - # assumes rotation to be tensor by now - rotation = rotation.to(device=device, dtype=dtype) - - # already complete - if rotation.size() == (batchsize, ndim, ndim) or rotation.size() == (batchsize, ndim, ndim + 1): - return matrix_to_homogeneous(rotation) - elif rotation.size() == (batchsize, ndim + 1, ndim + 1): - return rotation - - if degree: - rotation = deg_to_rad(rotation) - - # repeat along batch dimension - if rotation.size() == (ndim, ndim) or rotation.size() == (ndim + 1, ndim + 1): - rotation = rotation[None].expand(batchsize, -1, -1).clone() - if rotation.size(-1) == ndim: - rotation = matrix_to_homogeneous(rotation) - - return rotation - # bring it to default size of (batchsize, num_rot_params) - elif rotation.size() == (batchsize,): - rotation = rotation.view(batchsize, 1).expand(-1, - num_rot_params).clone() - elif rotation.size() == (num_rot_params,): - rotation = rotation.view(1, num_rot_params).expand(batchsize, - -1).clone() - elif rotation.size() != (batchsize, num_rot_params): - raise ValueError("Invalid shape for rotation parameters: %s" - % (str(tuple(rotation.size())))) - - sin, cos = rotation.sin(), rotation.cos() - - whole_rot_matrix = get_batched_eye(batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - - # assemble the actual matrix - if num_rot_params == 1: - whole_rot_matrix[:, 0, 0] = cos[0].clone() - whole_rot_matrix[:, 1, 1] = cos[0].clone() - whole_rot_matrix[:, 0, 1] = (-sin[0]).clone() - whole_rot_matrix[:, 1, 0] = sin[0].clone() - - else: - whole_rot_matrix[:, 0, 0] = (cos[:, 0] * cos[:, 1] * cos[:, 2] - - sin[:, 0] * sin[:, 2]).clone() - whole_rot_matrix[:, 0, 1] = (-cos[:, 0] * cos[:, 1] * sin[:, 2] - - sin[:, 0] * cos[:, 2]).clone() - whole_rot_matrix[:, 0, 2] = (cos[:, 0] * sin[:, 1]).clone() - whole_rot_matrix[:, 1, 0] = (sin[:, 0] * cos[:, 1] * cos[:, 2] - + cos[:, 0] * sin[:, 2]).clone() - whole_rot_matrix[:, 1, 1] = (-sin[:, 0] * cos[:, 1] * sin[:, 2] - + cos[:, 0] * cos[:, 2]).clone() - whole_rot_matrix[:, 2, 0] = (-sin[:, 1] * cos[:, 2]).clone() - whole_rot_matrix[:, 2, 1] = (-sin[:, 1] * sin[:, 2]).clone() - whole_rot_matrix[:, 2, 2] = (cos[:, 1]).clone() - - return matrix_to_homogeneous(whole_rot_matrix) - - -def parametrize_matrix(scale: AffineParamType, - rotation: AffineParamType, - translation: AffineParamType, - batchsize: int, ndim: int, - degree: bool = False, - device: Union[torch.device, str] = None, - dtype: Union[torch.dtype, str] = None) -> torch.Tensor: - """ - Formats the given scale parameters to a homogeneous transformation matrix - - Parameters - ---------- - scale : torch.Tensor, int, float - the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a translation offset of 0 - batchsize : int - the number of samples per batch - ndim : int - the dimensionality of the transform - degree : bool - whether the given rotation(s) are in degrees. - Only valid for rotation parameters, which aren't passed as full - transformation matrix. - device : torch.device, str, optional - the device to put the resulting tensor to. Defaults to the default - device - dtype : torch.dtype, str, optional - the dtype of the resulting trensor. Defaults to the default dtype - - Returns - ------- - torch.Tensor - the transformation matrix (of shape (BATCHSIZE x NDIM x NDIM+1) - - """ - scale = _format_scale(scale, batchsize=batchsize, ndim=ndim, - device=device, dtype=dtype) - rotation = _format_rotation(rotation, batchsize=batchsize, ndim=ndim, - degree=degree, device=device, dtype=dtype) - - translation = _format_translation(translation, batchsize=batchsize, - ndim=ndim, device=device, dtype=dtype) - - return torch.bmm(torch.bmm(scale, rotation), translation)[:, :-1] - - -def assemble_matrix_if_necessary(batchsize: int, ndim: int, - scale: AffineParamType, - rotation: AffineParamType, - translation: AffineParamType, - matrix: torch.Tensor, - degree: bool, - device: Union[torch.device, str], - dtype: Union[torch.dtype, str] - ) -> torch.Tensor: +def unit_box(n: int, scale: torch.Tensor = None) -> torch.Tensor: """ - Assembles a matrix, if the matrix is not already given + Create a sclaed version of a unit box Parameters ---------- - batchsize : int - number of samples per batch - ndim : int - the image dimensionality - scale : torch.Tensor, int, float - the scale factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a scaling factor of 1 - rotation : torch.Tensor, int, float - the rotation factor(s). Supported are: - * a full transformation matrix of shape (BATCHSIZE x NDIM x NDIM) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a rotation factor of 1 - translation : torch.Tensor, int, float - the translation offset(s). Supported are: - * a full homogeneous transformation matrix of shape - (BATCHSIZE x NDIM+1 x NDIM+1) - * a single parameter (as float or int), which will be replicated - for all dimensions and batch samples - * a single parameter per sample (as a 1d tensor), which will be - replicated for all dimensions - * a single parameter per dimension (either as 1d tensor or as - 2d transformation matrix), which will be replicated for all - batch samples - None will be treated as a translation offset of 0 - matrix : torch.Tensor - the transformation matrix. If other than None: overwrites separate - parameters for :param:`scale`, :param:`rotation` and - :param:`translation` - degree : bool - whether the given rotation is in degrees. Only valid for explicit - rotation parameters - device : str, torch.device - the device, the matrix should be put on - dtype : str, torch.dtype - the datatype, the matrix should have + n: int + number of dimensions + scale: Tensor + scaling of each dimension Returns ------- - torch.Tensor - the assembled transformation matrix - - """ - if matrix is None: - matrix = parametrize_matrix(scale=scale, rotation=rotation, - translation=translation, - batchsize=batchsize, - ndim=ndim, - degree=degree, - device=device, - dtype=dtype) - - else: - if not torch.is_tensor(matrix): - matrix = torch.tensor(matrix) - - matrix = matrix.to(dtype=dtype, device=device) - - # batch dimension missing -> Replicate for each sample in batch - if len(matrix.shape) == 2: - matrix = matrix[None].expand(batchsize, -1, -1).clone() - - if matrix.shape == (batchsize, ndim, ndim + 1): - return matrix - elif matrix.shape == (batchsize, ndim + 1, ndim + 1): - return matrix_to_cartesian(matrix) - - raise ValueError( - "Invalid Shape for affine transformation matrix. " - "Got %s but expected %s" % ( - str(tuple(matrix.shape)), - str((batchsize, ndim, ndim + 1)))) + Tensor + scaled unit box + """ + box = torch.tensor( + [list(i) for i in itertools.product([0, 1], repeat=n)]) + if scale is not None: + box = box.to(scale) * scale[None] + return box diff --git a/tests/transforms/functional/test_affine.py b/tests/transforms/functional/test_affine.py index b5877678..e920cee0 100644 --- a/tests/transforms/functional/test_affine.py +++ b/tests/transforms/functional/test_affine.py @@ -1,30 +1,26 @@ import unittest import torch from rising.transforms.functional.affine import _check_new_img_size, \ - affine_point_transform, affine_image_transform -from rising.utils.affine import parametrize_matrix, matrix_to_homogeneous, matrix_to_cartesian, \ - matrix_revert_coordinate_order -from rising.utils.checktype import check_scalar + affine_point_transform, create_affine_grid, parametrize_matrix, \ + create_rotation, create_translation, create_scale +from rising.utils.affine import matrix_to_homogeneous, matrix_to_cartesian class AffineTestCase(unittest.TestCase): - def test_check_image_size(self): - images = [torch.rand(11, 2, 3, 4, 5), torch.rand(11, 2, 3, 4), torch.rand(11, 2, 3, 3)] - - img_sizes = [ - [3, 4, 5], [3, 4], 3 + images = [ + torch.rand(11, 2, 3, 4, 5), + torch.rand(11, 2, 3, 4), + torch.rand(11, 2, 3, 3), ] + img_sizes = [[3, 4, 5], [3, 4], 3] + scales = [ - torch.tensor([[2., 0., 0.], - [0., 3., 0.], - [0., 0., 4.]]), - torch.tensor([[2., 0.], [0., 3.]]), - torch.tensor([[2., 0.], [0., 3.]]) + torch.tensor([2., 3., 4.]), + torch.tensor([2., 3.]), + torch.tensor([2., 3.]) ] - rots = [[45., 90., 135.], [45.], [45.]] - trans = [[0., 10., 20.], [10., 20.], [10., 20.]] edges = [ @@ -40,39 +36,41 @@ def test_check_image_size(self): ] ] - for img, size, scale, rot, tran, edge_pts in zip(images, img_sizes, - scales, rots, trans, - edges): + for img, size, scale, rot, tran, edge_pts in zip( + images, img_sizes, scales, rots, trans, edges): ndim = scale.size(-1) with self.subTest(ndim=ndim): affine = matrix_to_homogeneous( parametrize_matrix(scale=scale, rotation=rot, translation=tran, degree=True, - batchsize=1, ndim=ndim, dtype=torch.float)) + batchsize=1, ndim=ndim, dtype=torch.float, + image_transform=False)) edge_pts = torch.tensor(edge_pts, dtype=torch.float) - edge_pts[edge_pts > 1] = edge_pts[edge_pts > 1] - 1 img = img.to(torch.float) + new_edges = torch.bmm(edge_pts.unsqueeze(0), affine.clone().permute(0, 2, 1)) - new_edges = torch.bmm(edge_pts.unsqueeze(0), - matrix_revert_coordinate_order(affine.clone()).permute(0, 2, 1)) + img_size_zero_border = new_edges.max(dim=1)[0][0] + img_size_non_zero_border = (new_edges.max(dim=1)[0] - new_edges.min(dim=1)[0])[0] - img_size = (new_edges.max(dim=1)[0] - new_edges.min(dim=1)[0])[0] + fn_result_zero_border = _check_new_img_size( + size, matrix_to_cartesian(affine.expand(img.size(0), -1, -1).clone()), + zero_border=True, + ) + fn_result_non_zero_border = _check_new_img_size( + size, matrix_to_cartesian(affine.expand(img.size(0), -1, -1).clone()), + zero_border=False, + ) - fn_result = _check_new_img_size(size, - matrix_to_cartesian( - affine.expand(img.size(0), -1, -1).clone())) - - self.assertTrue(torch.allclose(img_size[:-1] + 1, - fn_result)) - - with self.assertRaises(ValueError): - _check_new_img_size([2, 3, 4, 5], torch.rand(11, 2, 2, 3, 4, 5)) + self.assertTrue(torch.allclose( + img_size_zero_border[:-1], fn_result_zero_border)) + self.assertTrue(torch.allclose( + img_size_non_zero_border[:-1], fn_result_non_zero_border)) def test_affine_point_transform(self): points = [ [[[0, 1], [1, 0]]], - [[[0, 0, 1]]] + [[[1, 1, 1]]], ] matrices = [ torch.tensor([[[1., 0.], [0., 5.]]]), @@ -81,11 +79,12 @@ def test_affine_point_transform(self): rotation=[0, 0, 90], degree=True, batchsize=1, ndim=3, dtype=torch.float, - device='cpu') + device='cpu', + image_transform=False) ] expected = [ - [[0, 1], [5, 0]], - [[0, 1, 0]] + [[0, 5], [1, 0]], + [[-1, 1, 1]] ] for input_pt, matrix, expected_pt in zip(points, matrices, expected): @@ -106,12 +105,11 @@ def test_affine_point_transform(self): atol=1e-7)) def test_affine_image_trafo(self): - matrix = torch.tensor([[4., 0., 0.], [0., 5., 0.]]) image_batch = torch.zeros(10, 3, 25, 25, dtype=torch.float, device='cpu') - target_sizes = [(121, 97), image_batch.shape[2:], (50, 50), (50, 50), + target_sizes = [(100, 125), image_batch.shape[2:], (50, 50), (50, 50), (45, 50), (45, 50)] for output_size in [None, 50, (45, 50)]: @@ -123,20 +121,135 @@ def test_affine_image_trafo(self): output_size=output_size): if output_size is not None and adjust_size: with self.assertWarns(UserWarning): - result = affine_image_transform( - image_batch=image_batch, + grid = create_affine_grid( + batch_shape=image_batch.shape, matrix_batch=matrix, output_size=output_size, - adjust_size=adjust_size) + adjust_size=adjust_size, + device=image_batch.device, + dtype=image_batch.dtype, + ) else: - result = affine_image_transform( - image_batch=image_batch, + grid = create_affine_grid( + batch_shape=image_batch.shape, matrix_batch=matrix, output_size=output_size, - adjust_size=adjust_size) - + adjust_size=adjust_size, + device=image_batch.device, + dtype=image_batch.dtype, + ) + result = torch.nn.functional.grid_sample(image_batch, grid) self.assertTupleEqual(result.shape[2:], target_size) + def test_create_scale(self): + inputs = [ + {'scale': None, 'batchsize': 2, 'ndim': 2}, + {'scale': 2, 'batchsize': 2, 'ndim': 2}, + {'scale': [2, 3], 'batchsize': 3, 'ndim': 2}, + {'scale': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, + ] + + expectations = [ + torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), + torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]]]), + torch.tensor([[[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]]]), + torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], + [[3., 0., 0.], [0., 3., 0.], [0., 0., 1.]], + [[4., 0., 0.], [0., 4., 0.], [0., 0., 1.]]]), + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = create_scale(**inp, image_transform=False).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + create_scale([4, 5, 6, 7], batchsize=3, ndim=2) + + def test_create_translation(self): + inputs = [ + {'offset': None, 'batchsize': 2, 'ndim': 2}, + {'offset': 2, 'batchsize': 2, 'ndim': 2}, + {'offset': [2, 3], 'batchsize': 3, 'ndim': 2}, + {'offset': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, + ] + + expectations = [ + torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]), + torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], + [[1, 0, 2], [0, 1, 2], [0, 0, 1]]]), + torch.tensor([[[1, 0, 2], [0, 1, 3], [0, 0, 1]], + [[1, 0, 2], [0, 1, 3], [0, 0, 1]], + [[1, 0, 2], [0, 1, 3], [0, 0, 1]]]), + torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], + [[1, 0, 3], [0, 1, 3], [0, 0, 1]], + [[1, 0, 4], [0, 1, 4], [0, 0, 1]]]), + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = create_translation(**inp, image_transform=False).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + create_translation([4, 5, 6, 7], batchsize=3, ndim=2) + + def test_format_rotation(self): + inputs = [ + {'rotation': None, 'batchsize': 2, 'ndim': 3}, + {'rotation': 0, 'degree': True, 'batchsize': 2, 'ndim': 2}, + ] + expectations = [ + torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], + [0., 0., 1., 0.], [0., 0., 0., 1.]], + [[1., 0., 0., 0.], [0., 1., 0., 0.], + [0., 0., 1., 0.], [0., 0., 0., 1.]]]), + torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = create_rotation(**inp).to(exp.dtype) + self.assertTrue(torch.allclose(res, exp, atol=1e-6)) + + with self.assertRaises(ValueError): + create_rotation([4, 5, 6, 7], batchsize=1, ndim=2) + + def test_matrix_parametrization(self): + inputs = [ + {'scale': None, 'translation': None, 'rotation': None, 'batchsize': 2, 'ndim': 2, + 'dtype': torch.float}, + {'scale': [2, 5], 'translation': [9, 18, 27], + 'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, + 'ndim': 2, 'dtype':torch.float} + ] + + expectations = [ + torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), + + torch.bmm(torch.bmm(torch.tensor([[[2., 0., 0], [0., 5., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 5., 0.], [0., 0., 1.]], + [[2., 0., 0.], [0., 5., 0.], [0., 0., 1.]]]), + torch.tensor([[[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], + [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], + [[-1., 0., 0.], [0., -1., 0.], [0., 0., -1.]]])), + torch.tensor([[[1., 0., 9.], [0., 1., 9.], [0., 0., 1.]], + [[1., 0., 18.], [0., 1., 18.], [0., 0., 1.]], + [[1., 0., 27.], [0., 1., 27.], [0., 0., 1.]]])) + ] + + for inp, exp in zip(inputs, expectations): + with self.subTest(input=inp, expected=exp): + res = parametrize_matrix(**inp, image_transform=False).to(exp.dtype) + self.assertTrue(torch.allclose(res, matrix_to_cartesian(exp), atol=1e-6)) + if __name__ == '__main__': unittest.main() diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index fb0405da..c133d35c 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -13,7 +13,7 @@ def test_affine(self): device='cpu') matrix = matrix.expand(image_batch.size(0), -1, -1).clone() - target_sizes = [(121, 97), image_batch.shape[2:], (50, 50), (50, 50), + target_sizes = [(100, 125), image_batch.shape[2:], (50, 50), (50, 50), (45, 50), (45, 50)] for output_size in [None, 50, (45, 50)]: @@ -36,6 +36,36 @@ def test_affine(self): self.assertEqual(sample['label'], result['label']) + def test_affine_assemble_matrix(self): + matrices = [ + [[1., 0.], [0., 1.]], + [[1., 0., 1.], [0., 1., 1.]], + [[1., 0., 1.], [0., 1., 1.], [0., 0., 1.]], + None, + [0., 1., 1., 0.] + ] + expected_matrices = [ + torch.tensor([[1., 0., 0.], [0., 1., 0.]])[None], + torch.tensor([[1., 0., 1.], [0., 1., 1.]])[None], + torch.tensor([[1., 0., 1.], [0., 1., 1.]])[None], + None, + None, + ] + value_error = [False, False, False, True, True] + batch = torch.zeros(1, 1, 10, 10) + + for matrix, expected, ve in zip(matrices, expected_matrices, value_error): + with self.subTest(matrix=matrix, expected=expected): + trafo = Affine(matrix=matrix) + if ve: + with self.assertRaises(ValueError): + assembled = trafo.assemble_matrix( + batch.shape, device=batch.device, dtype=batch.dtype) + else: + assembled = trafo.assemble_matrix( + batch.shape, device=batch.device, dtype=batch.dtype) + self.assertTrue(expected.allclose(assembled)) + def test_affine_stacking(self): affines = [ Affine(scale=1), @@ -64,9 +94,10 @@ def test_stacked_transformation_assembly(self): second_matrix = torch.tensor([[[4., 0., 3.], [0., 5., 4.]]]) trafo = StackedAffine([first_matrix, second_matrix]) - sample = {'data': torch.rand(1, 3, 25, 25)} + sample = torch.rand(1, 3, 25, 25) - matrix = trafo.assemble_matrix(**sample) + matrix = trafo.assemble_matrix(sample.shape, + dtype=sample.dtype, device=sample.device) target_matrix = matrix_to_cartesian( torch.bmm( @@ -78,20 +109,36 @@ def test_stacked_transformation_assembly(self): self.assertTrue(torch.allclose(matrix, target_matrix)) def test_affine_subtypes(self): + sample = {'data': torch.rand(1, 3, 25, 30)} - sample = {'data': torch.rand(10, 3, 25, 25)} trafos = [ - Scale(5), - Rotate(45), - Translate(10), - Resize((5, 4)) + Scale([5, 3], adjust_size=True), + Resize([50, 90]), + Rotate([90], adjust_size=True, degree=True), ] - for trafo in trafos: - with self.subTest(trafo=trafo): - self.assertIsInstance(trafo(**sample)['data'], torch.Tensor) + expected_sizes = [ + (5, 10), + (50, 90), + (30, 25), + ] - self.assertTupleEqual((5, 4), trafos[-1](**sample)['data'].shape[2:]) + for trafo, expected_size in zip(trafos, expected_sizes): + with self.subTest(trafo=trafo, exp_size=expected_size): + result = trafo(**sample)['data'] + self.assertIsInstance(result, torch.Tensor) + self.assertTupleEqual(expected_size, result.shape[-2:]) + + def test_translation_assemble_matrix_with_pixel(self): + trafo = Translate([1, 10, 100], unit='pixel') + sample = torch.rand(3, 3, 100, 100) + expected = torch.tensor([[[1., 0., -0.01], [0., 1., -0.01]], + [[1., 0., -0.1], [0., 1., -0.1]], + [[1., 0., -1.], [0., 1., -1.]]]) + + out = trafo.assemble_matrix( + sample.shape, device=sample.device, dtype=sample.dtype) + self.assertTrue(expected.allclose(out)) if __name__ == '__main__': diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index e88af523..196b7ecd 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -1,3 +1,4 @@ + import unittest import torch @@ -61,12 +62,12 @@ def __call__(self, *args, **kwargs): trafo_a = trafo_a.to(torch.float32) trafo_b = DummyTrafo(torch.tensor([2.], dtype=torch.float32)) trafo_b = trafo_b.to(torch.float32) - self.assertEquals(trafo_a.tmp.dtype, torch.float32) - self.assertEquals(trafo_b.tmp.dtype, torch.float32) + self.assertEqual(trafo_a.tmp.dtype, torch.float32) + self.assertEqual(trafo_b.tmp.dtype, torch.float32) compose = Compose(trafo_a, trafo_b) compose = compose.to(torch.float64) - self.assertEquals(compose.transforms[0].tmp.dtype, torch.float64) + self.assertEqual(compose.transforms[0].tmp.dtype, torch.float64) def test_wrapping_non_module_trafos(self): class DummyTrafo: diff --git a/tests/utils/test_affine.py b/tests/utils/test_affine.py index 8867524c..c6563b72 100644 --- a/tests/utils/test_affine.py +++ b/tests/utils/test_affine.py @@ -1,8 +1,7 @@ import unittest from rising.utils.affine import points_to_homogeneous, matrix_to_homogeneous, \ - matrix_to_cartesian, points_to_cartesian, matrix_revert_coordinate_order, \ - get_batched_eye, _format_scale, _format_translation, deg_to_rad, \ - _format_rotation, parametrize_matrix, assemble_matrix_if_necessary + matrix_to_cartesian, points_to_cartesian, \ + get_batched_eye, deg_to_rad, unit_box import torch import math @@ -128,24 +127,6 @@ def test_matrix_to_cartesian(self): self.assertTrue(torch.allclose(matrix_to_cartesian(inp, keep_square=keep_square), exp)) keep_square = not keep_square - def test_matrix_coordinate_order(self): - inputs = [ - torch.tensor([[[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]]) - ] - - expectations = [ - torch.tensor([[[5, 4, 6], - [2, 1, 3], - [7, 8, 9]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - self.assertTrue(torch.allclose(matrix_revert_coordinate_order(inp), exp)) - # self.assertTrue(torch.allclose(inp, matrix_revert_coordinate_order(exp))) - def test_batched_eye(self): for dtype in [torch.float, torch.long]: for ndim in range(10): @@ -159,85 +140,6 @@ def test_batched_eye(self): for _eye in batched_eye: self.assertTrue(torch.allclose(_eye, non_batched_eye, atol=1e-6)) - def test_format_scale(self): - inputs = [ - {'scale': None, 'batchsize': 2, 'ndim': 2}, - {'scale': 2, 'batchsize': 2, 'ndim': 2}, - {'scale': [2, 3], 'batchsize': 3, 'ndim': 2}, - {'scale': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - {'scale': [[2, 3], [4, 5]], 'batchsize': 3, 'ndim': 2}, - ] - - expectations = [ - torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], - [[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]]]), - torch.tensor([[[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], - [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]], - [[2., 0., 0.], [0., 3., 0.], [0., 0., 1.]]]), - torch.tensor([[[2., 0., 0.], [0., 2., 0.], [0., 0., 1.]], - [[3., 0., 0.], [0., 3., 0.], [0., 0., 1.]], - [[4., 0., 0.], [0., 4., 0.], [0., 0., 1.]]]), - torch.tensor([[[2, 3, 0], [4, 5, 0], [0, 0, 1]], - [[2, 3, 0], [4, 5, 0], [0, 0, 1]], - [[2, 3, 0], [4, 5, 0], [0, 0, 1]]]) - - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = _format_scale(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - _format_scale([4, 5, 6, 7], batchsize=3, ndim=2) - - def test_format_translation(self): - inputs = [ - {'offset': None, 'batchsize': 2, 'ndim': 2}, - {'offset': 2, 'batchsize': 2, 'ndim': 2}, - {'offset': [2, 3], 'batchsize': 3, 'ndim': 2}, - {'offset': [2, 3, 4], 'batchsize': 3, 'ndim': 2}, - {'offset': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], - 'batchsize': 3, 'ndim': 2}, - {'offset': [[[1, 2, 3], [4, 5, 6]], - [[10, 11, 12], [13, 14, 15]], - [[19, 20, 21], [22, 23, 24]]], - 'batchsize': 3, 'ndim': 2} - - ] - - expectations = [ - torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]), - torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], - [[1, 0, 2], [0, 1, 2], [0, 0, 1]]]), - torch.tensor([[[1, 0, 2], [0, 1, 3], [0, 0, 1]], - [[1, 0, 2], [0, 1, 3], [0, 0, 1]], - [[1, 0, 2], [0, 1, 3], [0, 0, 1]]]), - torch.tensor([[[1, 0, 2], [0, 1, 2], [0, 0, 1]], - [[1, 0, 3], [0, 1, 3], [0, 0, 1]], - [[1, 0, 4], [0, 1, 4], [0, 0, 1]]]), - torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]), - torch.tensor([[[1, 2, 3], [4, 5, 6], [0, 0, 1]], - [[10, 11, 12], [13, 14, 15], [0, 0, 1]], - [[19, 20, 21], [22, 23, 24], [0, 0, 1]]]) - - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = _format_translation(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - _format_translation([4, 5, 6, 7], batchsize=3, ndim=2) - def test_deg_to_rad(self): inputs = [ torch.tensor([tmp * 45. for tmp in range(9)]), @@ -251,117 +153,35 @@ def test_deg_to_rad(self): with self.subTest(input=inp, expected=exp): self.assertTrue(torch.allclose(deg_to_rad(inp), exp, atol=1e-6)) - def test_format_rotation(self): - inputs = [ - {'rotation': None, 'batchsize': 2, 'ndim': 3}, - {'rotation': 0, 'degree': True, 'batchsize': 2, 'ndim': 2}, - {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 2, 'ndim': 3}, - {'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, 'ndim': 2}, - {'rotation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]]], - 'batchsize': 2, 'ndim': 2}, - {'rotation': [[[1, 2, 3], [4, 5, 6]], - [[10, 11, 12], [13, 14, 15]]], - 'batchsize': 2, 'ndim': 2}, - {'rotation': [[1, 2], [3, 4]], 'batchsize': 3, 'ndim': 2, 'degree': False} - - ] - expectations = [ - torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]], - [[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]]]), - torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]], - [[1., 0., 0., 0.], [0., 1., 0., 0.], - [0., 0., 1., 0.], [0., 0., 0., 1.]]]), - torch.tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - [[-1, 0, 0], [0, -1, 0], [0, 0, 1]], - [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]]), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - [[10., 11., 12.], [13., 14., 15.], [16., 17., 18.]]]), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]]]), - torch.tensor([[[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]], - [[1., 2., 0.], [3., 4., 0.], [0., 0., 1.]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = _format_rotation(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - _format_rotation([4, 5, 6, 7], batchsize=1, ndim=2) - - def test_matrix_parametrization(self): - inputs = [ - {'scale': None, 'translation': None, 'rotation': None, 'batchsize': 2, 'ndim': 2, - 'dtype': torch.float}, - {'scale': [[2, 3], [4, 5]], 'translation': [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[10, 11, 12], [13, 14, 15], [16, 17, 18]], - [[19, 20, 21], [22, 23, 24], [25, 26, 27]]], - 'rotation': [180, 0, 180], 'degree': True, 'batchsize': 3, - 'ndim': 2, 'dtype':torch.float} - ] - - expectations = [ - torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]), - - torch.bmm(torch.bmm(torch.tensor([[[2., 3., 0], [4., 5., 0.], [0., 0., 1.]], - [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]], - [[2., 3., 0.], [4., 5., 0.], [0., 0., 1.]]]), - torch.tensor([[[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], - [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]], - [[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]]])), - torch.tensor([[[1., 2., 3.], [4., 5., 6.], [0., 0., 1.]], - [[10., 11., 12.], [13., 14., 15.], [0., 0., 1.]], - [[19., 20., 21.], [22., 23., 24.], [0., 0., 1.]]])) - - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = parametrize_matrix(**inp).to(exp.dtype) - self.assertTrue(torch.allclose(res, matrix_to_cartesian(exp), atol=1e-6)) - - def test_necessary_assembly(self): - inputs = [ - {'matrix': None, 'translation': [2, 3], 'ndim':2, 'batchsize': 3, - 'dtype': torch.float}, - {'matrix': [[1., 0., 4.], [0., 1., 5.], [0., 0., 1.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, - 'dtype': torch.float}, - {'matrix': [[1., 0., 4.], [0., 1., 5.]], 'translation': [2, 3], 'ndim': 2, 'batchsize': 3, - 'dtype': torch.float} - - ] - expectations = [ - torch.tensor([[[1., 0., 2.], [0., 1., 3.]], - [[1., 0., 2.], [0., 1., 3.]], - [[1., 0., 2.], [0., 1., 3.]]]), - torch.tensor([[[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]]]), - torch.tensor([[[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]], - [[1., 0., 4.], [0., 1., 5.]]]) - ] - - for inp, exp in zip(inputs, expectations): - with self.subTest(input=inp, expected=exp): - res = assemble_matrix_if_necessary(**inp, degree=False, - device='cpu', scale=None, rotation=None).to(exp.dtype) - self.assertTrue(torch.allclose(res, exp, atol=1e-6)) - - with self.assertRaises(ValueError): - assemble_matrix_if_necessary(matrix=[1, 2, 3, 4, 5], scale=None, - rotation=None, translation=None, - degree=False, dtype=torch.float, - device='cpu', batchsize=1, ndim=2) + def test_unit_box_2d(self): + curr_img_size = torch.tensor([2, 3]) + box = torch.tensor([[0., 0.], [0., curr_img_size[1]], + [curr_img_size[0], 0], curr_img_size]) + created_box = unit_box(2, curr_img_size).to(box) + self.compare_points_unordered(box, created_box) + + def compare_points_unordered(self, points0: torch.Tensor, points1: torch.Tensor): + self.assertEqual(tuple(points0.shape), tuple(points1.shape)) + for point in points0: + comp = point[None] == points1 + comp = comp.sum(dim=1) == comp.shape[1] + self.assertTrue(comp.any()) + + def test_unit_box_3d(self): + curr_img_size = torch.tensor([2, 3, 4]) + box = torch.tensor( + [ + [0., 0., 0.], + [0., 0., curr_img_size[2]], + [0., curr_img_size[1], 0], + [0., curr_img_size[1], curr_img_size[2]], + [curr_img_size[0], 0., 0.], + [curr_img_size[0], 0., curr_img_size[2]], + [curr_img_size[0], curr_img_size[1], 0.], + curr_img_size + ]) + created_box = unit_box(3, curr_img_size).to(box) + self.compare_points_unordered(box, created_box) if __name__ == '__main__':