diff --git a/FastGeodis/__init__.py b/FastGeodis/__init__.py index b677288..2aa2176 100644 --- a/FastGeodis/__init__.py +++ b/FastGeodis/__init__.py @@ -35,11 +35,7 @@ def generalised_geodesic2d( - image: torch.Tensor, - softmask: torch.Tensor, - v: float, - lamb: float, - iter: int = 2 + image: torch.Tensor, softmask: torch.Tensor, v: float, lamb: float, iter: int = 2 ): r"""Computes Generalised Geodesic Distance using FastGeodis raster scanning. For more details on generalised geodesic distance, check the following reference: @@ -99,11 +95,7 @@ def generalised_geodesic3d( def signed_generalised_geodesic2d( - image: torch.Tensor, - softmask: torch.Tensor, - v: float, - lamb: float, - iter: int = 2 + image: torch.Tensor, softmask: torch.Tensor, v: float, lamb: float, iter: int = 2 ): r"""Computes Signed Generalised Geodesic Distance using FastGeodis raster scanning. For more details on generalised geodesic distance, check the following reference: @@ -163,18 +155,14 @@ def signed_generalised_geodesic3d( def generalised_geodesic2d_toivanen( - image: torch.Tensor, - softmask: torch.Tensor, - v: float, - lamb: float, - iter: int = 2 + image: torch.Tensor, softmask: torch.Tensor, v: float, lamb: float, iter: int = 2 ): r"""Computes Generalised Geodesic Distance using Toivanen's raster scanning method from: - Toivanen, Pekka J. - "New geodosic distance transforms for gray-scale images." + Toivanen, Pekka J. + "New geodosic distance transforms for gray-scale images." Pattern Recognition Letters 17.5 (1996): 437-450. - + For more details on generalised geodesic distance, check the following reference: Criminisi, Antonio, Toby Sharp, and Andrew Blake. @@ -208,10 +196,10 @@ def generalised_geodesic3d_toivanen( ): r"""Computes Generalised Geodesic Distance using Toivanen's raster scanning method from: - Toivanen, Pekka J. - "New geodosic distance transforms for gray-scale images." + Toivanen, Pekka J. + "New geodosic distance transforms for gray-scale images." Pattern Recognition Letters 17.5 (1996): 437-450. - + For more details on generalised geodesic distance, check the following reference: @@ -236,17 +224,14 @@ def generalised_geodesic3d_toivanen( image, softmask, spacing, v, lamb, 1 - lamb, iter ) + def signed_generalised_geodesic2d_toivanen( - image: torch.Tensor, - softmask: torch.Tensor, - v: float, - lamb: float, - iter: int = 2 + image: torch.Tensor, softmask: torch.Tensor, v: float, lamb: float, iter: int = 2 ): r"""Computes Signed Generalised Geodesic Distance using Toivanen's raster scanning method from: - Toivanen, Pekka J. - "New geodosic distance transforms for gray-scale images." + Toivanen, Pekka J. + "New geodosic distance transforms for gray-scale images." Pattern Recognition Letters 17.5 (1996): 437-450. For more details on generalised geodesic distance, check the following reference: @@ -282,8 +267,8 @@ def signed_generalised_geodesic3d_toivanen( ): r"""Computes Signed Generalised Geodesic Distance using Toivanen's raster scanning method from: - Toivanen, Pekka J. - "New geodosic distance transforms for gray-scale images." + Toivanen, Pekka J. + "New geodosic distance transforms for gray-scale images." Pattern Recognition Letters 17.5 (1996): 437-450. For more details on generalised geodesic distance, check the following reference: @@ -309,17 +294,14 @@ def signed_generalised_geodesic3d_toivanen( image, softmask, spacing, v, lamb, 1 - lamb, iter ) -def geodesic2d_pixelqueue( - image: torch.Tensor, - seed: torch.Tensor, - lamb: float -): + +def geodesic2d_pixelqueue(image: torch.Tensor, seed: torch.Tensor, lamb: float): r"""Computes Geodesic Distance using Pixel Queue method from: - - Ikonen, L., & Toivanen, P. (2007). + + Ikonen, L., & Toivanen, P. (2007). "Distance and nearest neighbor transforms on gray-level surfaces." Pattern Recognition Letters, 28(5), 604-612. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -330,23 +312,18 @@ def geodesic2d_pixelqueue( Returns: torch.Tensor with distance transform """ - return FastGeodisCpp.geodesic2d_pixelqueue( - image, seed, lamb, 1 - lamb - ) + return FastGeodisCpp.geodesic2d_pixelqueue(image, seed, lamb, 1 - lamb) def geodesic3d_pixelqueue( - image: torch.Tensor, - seed: torch.Tensor, - spacing: List, - lamb: float + image: torch.Tensor, seed: torch.Tensor, spacing: List, lamb: float ): r"""Computes Geodesic Distance using Pixel Queue method from: - - Ikonen, L., & Toivanen, P. (2007). + + Ikonen, L., & Toivanen, P. (2007). "Distance and nearest neighbor transforms on gray-level surfaces." Pattern Recognition Letters, 28(5), 604-612. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -358,21 +335,16 @@ def geodesic3d_pixelqueue( Returns: torch.Tensor with distance transform """ - return FastGeodisCpp.geodesic3d_pixelqueue( - image, seed, spacing, lamb, 1 - lamb - ) + return FastGeodisCpp.geodesic3d_pixelqueue(image, seed, spacing, lamb, 1 - lamb) -def signed_geodesic2d_pixelqueue( - image: torch.Tensor, - seed: torch.Tensor, - lamb: float -): + +def signed_geodesic2d_pixelqueue(image: torch.Tensor, seed: torch.Tensor, lamb: float): r"""Computes Signed Generalised Geodesic Distance using Pixel Queue method from: - - Ikonen, L., & Toivanen, P. (2007). + + Ikonen, L., & Toivanen, P. (2007). "Distance and nearest neighbor transforms on gray-level surfaces." Pattern Recognition Letters, 28(5), 604-612. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -383,23 +355,18 @@ def signed_geodesic2d_pixelqueue( Returns: torch.Tensor with distance transform """ - return FastGeodisCpp.signed_geodesic2d_pixelqueue( - image, seed, lamb, 1 - lamb - ) + return FastGeodisCpp.signed_geodesic2d_pixelqueue(image, seed, lamb, 1 - lamb) def signed_geodesic3d_pixelqueue( - image: torch.Tensor, - seed: torch.Tensor, - spacing: List, - lamb: float + image: torch.Tensor, seed: torch.Tensor, spacing: List, lamb: float ): r"""Computes Signed Geodesic Distance using Pixel Queue method from: - - Ikonen, L., & Toivanen, P. (2007). + + Ikonen, L., & Toivanen, P. (2007). "Distance and nearest neighbor transforms on gray-level surfaces." Pattern Recognition Letters, 28(5), 604-612. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -416,17 +383,14 @@ def signed_geodesic3d_pixelqueue( image, seed, spacing, lamb, 1 - lamb ) -def geodesic2d_fastmarch( - image: torch.Tensor, - seed: torch.Tensor, - lamb: float -): + +def geodesic2d_fastmarch(image: torch.Tensor, seed: torch.Tensor, lamb: float): r"""Computes Geodesic Distance using Fast Marching method from: - Sethian, James A. - "Fast marching methods." + Sethian, James A. + "Fast marching methods." SIAM review 41.2 (1999): 199-235. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -437,23 +401,18 @@ def geodesic2d_fastmarch( Returns: torch.Tensor with distance transform """ - return FastGeodisCpp.geodesic2d_fastmarch( - image, seed, lamb, 1 - lamb - ) + return FastGeodisCpp.geodesic2d_fastmarch(image, seed, lamb, 1 - lamb) def geodesic3d_fastmarch( - image: torch.Tensor, - seed: torch.Tensor, - spacing: List, - lamb: float + image: torch.Tensor, seed: torch.Tensor, spacing: List, lamb: float ): r"""Computes Geodesic Distance using Fast Marching method from: - TSethian, James A. - "Fast marching methods." + TSethian, James A. + "Fast marching methods." SIAM review 41.2 (1999): 199-235. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -465,21 +424,16 @@ def geodesic3d_fastmarch( Returns: torch.Tensor with distance transform """ - return FastGeodisCpp.geodesic3d_fastmarch( - image, seed, spacing, lamb, 1 - lamb - ) + return FastGeodisCpp.geodesic3d_fastmarch(image, seed, spacing, lamb, 1 - lamb) -def signed_geodesic2d_fastmarch( - image: torch.Tensor, - seed: torch.Tensor, - lamb: float -): + +def signed_geodesic2d_fastmarch(image: torch.Tensor, seed: torch.Tensor, lamb: float): r"""Computes Signed Geodesic Distance using Fast Marching method from: - Sethian, James A. - "Fast marching methods." + Sethian, James A. + "Fast marching methods." SIAM review 41.2 (1999): 199-235. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -490,23 +444,18 @@ def signed_geodesic2d_fastmarch( Returns: torch.Tensor with distance transform """ - return FastGeodisCpp.signed_geodesic2d_fastmarch( - image, seed, lamb, 1 - lamb - ) + return FastGeodisCpp.signed_geodesic2d_fastmarch(image, seed, lamb, 1 - lamb) def signed_geodesic3d_fastmarch( - image: torch.Tensor, - seed: torch.Tensor, - spacing: List, - lamb: float + image: torch.Tensor, seed: torch.Tensor, spacing: List, lamb: float ): r"""Computes Signed Geodesic Distance using Fast Marching method from: - Sethian, James A. - "Fast marching methods." + Sethian, James A. + "Fast marching methods." SIAM review 41.2 (1999): 199-235. - + The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location Args: @@ -523,6 +472,7 @@ def signed_geodesic3d_fastmarch( image, seed, spacing, lamb, 1 - lamb ) + def GSF2d( image: torch.Tensor, softmask: torch.Tensor, @@ -584,6 +534,7 @@ def GSF3d( """ return FastGeodisCpp.GSF3d(image, softmask, theta, spacing, v, lamb, iter) + def GSF2d_toivanen( image: torch.Tensor, softmask: torch.Tensor, @@ -594,8 +545,8 @@ def GSF2d_toivanen( ): r"""Computes Geodesic Symmetric Filtering (GSF) using Toivanen's raster scanning method from: - Toivanen, Pekka J. - "New geodosic distance transforms for gray-scale images." + Toivanen, Pekka J. + "New geodosic distance transforms for gray-scale images." Pattern Recognition Letters 17.5 (1996): 437-450. For more details on GSF, check the following reference: @@ -630,10 +581,10 @@ def GSF3d_toivanen( ): r"""Computes Geodesic Symmetric Filtering (GSF) using Toivanen's raster scanning method from: - Toivanen, Pekka J. - "New geodosic distance transforms for gray-scale images." + Toivanen, Pekka J. + "New geodosic distance transforms for gray-scale images." Pattern Recognition Letters 17.5 (1996): 437-450. - + For more details on GSF, check the following reference: Criminisi, Antonio, Toby Sharp, and Andrew Blake. @@ -655,15 +606,13 @@ def GSF3d_toivanen( """ return FastGeodisCpp.GSF3d_toivanen(image, softmask, theta, spacing, v, lamb, iter) + def GSF2d_pixelqueue( - image: torch.Tensor, - seed: torch.Tensor, - theta: float, - lamb: float + image: torch.Tensor, seed: torch.Tensor, theta: float, lamb: float ): r"""Computes Geodesic Symmetric Filtering (GSF) using Pixel Queue method from: - Ikonen, L., & Toivanen, P. (2007). + Ikonen, L., & Toivanen, P. (2007). "Distance and nearest neighbor transforms on gray-level surfaces." Pattern Recognition Letters, 28(5), 604-612. @@ -685,6 +634,7 @@ def GSF2d_pixelqueue( """ return FastGeodisCpp.GSF2d_pixelqueue(image, seed, theta, lamb) + def GSF3d_pixelqueue( image: torch.Tensor, seed: torch.Tensor, @@ -694,7 +644,7 @@ def GSF3d_pixelqueue( ): r"""Computes Geodesic Symmetric Filtering (GSF) using Pixel Queue method from: - Ikonen, L., & Toivanen, P. (2007). + Ikonen, L., & Toivanen, P. (2007). "Distance and nearest neighbor transforms on gray-level surfaces." Pattern Recognition Letters, 28(5), 604-612. @@ -717,16 +667,12 @@ def GSF3d_pixelqueue( """ return FastGeodisCpp.GSF3d_pixelqueue(image, seed, theta, spacing, lamb) -def GSF2d_fastmarch( - image: torch.Tensor, - seed: torch.Tensor, - theta: float, - lamb: float -): + +def GSF2d_fastmarch(image: torch.Tensor, seed: torch.Tensor, theta: float, lamb: float): r"""Computes Geodesic Symmetric Filtering (GSF) using Fast Marching method from: - Sethian, James A. - "Fast marching methods." + Sethian, James A. + "Fast marching methods." SIAM review 41.2 (1999): 199-235. For more details on GSF, check the following reference: @@ -757,8 +703,8 @@ def GSF3d_fastmarch( ): r"""Computes Geodesic Symmetric Filtering (GSF) using Fast Marching method from: - Sethian, James A. - "Fast marching methods." + Sethian, James A. + "Fast marching methods." SIAM review 41.2 (1999): 199-235. For more details on GSF, check the following reference: @@ -779,3 +725,147 @@ def GSF3d_fastmarch( torch.Tensor with distance transform """ return FastGeodisCpp.GSF3d_fastmarch(image, seed, theta, spacing, lamb) + + +def exact_euclidean2d(mask: torch.Tensor, spacing: List = [1.0, 1.0]): + r"""Computes Exact Euclidean Distance Transform using the PBA+ (Parallel Banding Algorithm Plus) + algorithm from: + + Cao, Thanh-Tung, Ke Tang, Anis Mohamed, and Tiow-Seng Tan. + "Parallel banding algorithm to compute exact distance transform with the GPU." + In Proceedings of the 2010 ACM SIGGRAPH symposium on Interactive 3D Graphics and Games, pp. 83-90. 2010. + + This function computes the EXACT Euclidean distance transform, unlike the approximate methods + provided by generalised_geodesic2d with lamb=0.0. + + The function expects input as torch.Tensor on CUDA device. + + Args: + mask: binary mask where 0 indicates seed points (distance=0) and 1 indicates background. + Should be a 4D tensor with shape (B, C, H, W). Supports arbitrary batch and channel sizes. + spacing: pixel spacing [spacing_height, spacing_width] to match [H, W] tensor convention. + Default is [1.0, 1.0]. + + Returns: + torch.Tensor with exact Euclidean distance transform + + Note: + - GPU only: Requires CUDA. No CPU fallback is available. + - Memory: Uses ~21 bytes per pixel (input + int32 Voronoi buffer + float32 output). + + Example: + >>> import torch + >>> import FastGeodis + >>> mask = torch.ones(1, 1, 512, 512, device='cuda') + >>> mask[0, 0, 256, 256] = 0 # Single seed point + >>> distance = FastGeodis.exact_euclidean2d(mask, spacing=[1.0, 1.0]) + """ + return FastGeodisCpp.exact_euclidean2d(mask, spacing) + + +def exact_euclidean3d(mask: torch.Tensor, spacing: List = [1.0, 1.0, 1.0]): + r"""Computes Exact Euclidean Distance Transform using the PBA+ (Parallel Banding Algorithm Plus) + algorithm from: + + Cao, Thanh-Tung, Ke Tang, Anis Mohamed, and Tiow-Seng Tan. + "Parallel banding algorithm to compute exact distance transform with the GPU." + In Proceedings of the 2010 ACM SIGGRAPH symposium on Interactive 3D Graphics and Games, pp. 83-90. 2010. + + This function computes the EXACT Euclidean distance transform for 3D volumetric data, unlike + the approximate methods provided by generalised_geodesic3d with lamb=0.0. + + The function expects input as torch.Tensor on CUDA device. + + Args: + mask: binary mask where 0 indicates seed points (distance=0) and 1 indicates background. + Should be a 5D tensor with shape (B, C, D, H, W). Supports arbitrary batch and channel sizes. + spacing: voxel spacing [spacing_depth, spacing_height, spacing_width] to match [D, H, W] + tensor convention. Default is [1.0, 1.0, 1.0]. + + Returns: + torch.Tensor with exact Euclidean distance transform + + Note: + - GPU only: Requires CUDA. No CPU fallback is available. + - Memory: Uses ~29 bytes per voxel (input + int32 Voronoi buffer + float32 output). + + Example: + >>> import torch + >>> import FastGeodis + >>> mask = torch.ones(1, 1, 128, 128, 128, device='cuda') + >>> mask[0, 0, 64, 64, 64] = 0 # Single seed point + >>> distance = FastGeodis.exact_euclidean3d(mask, spacing=[1.0, 1.0, 1.0]) + """ + return FastGeodisCpp.exact_euclidean3d(mask, spacing) + + +def signed_exact_euclidean2d(mask: torch.Tensor, spacing: List = [1.0, 1.0]): + r"""Computes Signed Exact Euclidean Distance Transform using the PBA+ algorithm. + + This function computes the signed distance where: + - Negative values: distance inside the foreground region (mask=1) to nearest boundary + - Positive values: distance outside the foreground region (mask=0) to nearest boundary + + This follows the convention where distance is negative inside the object, matching + common signed distance field (SDF) conventions. + + The function expects input as torch.Tensor on CUDA device. + + Args: + mask: binary mask where 0 indicates background and 1 indicates foreground. + Should be a 4D tensor with shape (B, C, H, W). Supports arbitrary batch and channel sizes. + spacing: pixel spacing [spacing_height, spacing_width] to match [H, W] tensor convention. + Default is [1.0, 1.0]. + + Returns: + torch.Tensor with signed exact Euclidean distance transform + + Note: + - GPU only: Requires CUDA. No CPU fallback is available. + - Memory: Uses ~42 bytes per pixel (2x the unsigned version for inside/outside computation). + + Example: + >>> import torch + >>> import FastGeodis + >>> mask = torch.zeros(1, 1, 512, 512, device='cuda') + >>> mask[0, 0, 200:300, 200:300] = 1 # Square region + >>> signed_distance = FastGeodis.signed_exact_euclidean2d(mask, spacing=[1.0, 1.0]) + >>> # signed_distance is negative inside the square, positive outside + """ + return FastGeodisCpp.signed_exact_euclidean2d(mask, spacing) + + +def signed_exact_euclidean3d(mask: torch.Tensor, spacing: List = [1.0, 1.0, 1.0]): + r"""Computes Signed Exact Euclidean Distance Transform for 3D volumetric data using the PBA+ algorithm. + + This function computes the signed distance where: + - Negative values: distance inside the foreground region (mask=1) to nearest boundary + - Positive values: distance outside the foreground region (mask=0) to nearest boundary + + This follows the convention where distance is negative inside the object, matching + common signed distance field (SDF) conventions. + + The function expects input as torch.Tensor on CUDA device. + + Args: + mask: binary mask where 0 indicates background and 1 indicates foreground. + Should be a 5D tensor with shape (B, C, D, H, W). Supports arbitrary batch and channel sizes. + spacing: voxel spacing [spacing_depth, spacing_height, spacing_width] to match [D, H, W] + tensor convention. Default is [1.0, 1.0, 1.0]. + + Returns: + torch.Tensor with signed exact Euclidean distance transform + + Note: + - GPU only: Requires CUDA. No CPU fallback is available. + - Memory: Uses ~58 bytes per voxel (2x the unsigned version for inside/outside computation). + + Example: + >>> import torch + >>> import FastGeodis + >>> mask = torch.zeros(1, 1, 128, 128, 128, device='cuda') + >>> mask[0, 0, 40:80, 40:80, 40:80] = 1 # Cube region + >>> signed_distance = FastGeodis.signed_exact_euclidean3d(mask, spacing=[1.0, 1.0, 1.0]) + >>> # signed_distance is negative inside the cube, positive outside + """ + return FastGeodisCpp.signed_exact_euclidean3d(mask, spacing) diff --git a/FastGeodis/fastgeodis.cpp b/FastGeodis/fastgeodis.cpp index 1367acc..a7b23d2 100755 --- a/FastGeodis/fastgeodis.cpp +++ b/FastGeodis/fastgeodis.cpp @@ -34,6 +34,7 @@ #include #include "fastgeodis.h" #include "common.h" +#include "geodis_pba.h" #ifdef _OPENMP #include @@ -420,6 +421,154 @@ torch::Tensor GSF3d_fastmarch(const torch::Tensor &image, const torch::Tensor &m return Dd_Md + De_Me; } +torch::Tensor exact_euclidean2d(const torch::Tensor &mask, const std::vector &spacing) +{ + // Check input dimensions - expect BCHW format + const int num_dims = mask.dim(); + if (num_dims != 4) + { + throw std::invalid_argument( + "exact_euclidean2d only supports 4D inputs (BCHW), received " + std::to_string(num_dims) + "D"); + } + + // Note: batch and channel dimensions are now supported + + if (spacing.size() != 2) + { + throw std::invalid_argument( + "exact_euclidean2d requires 2D spacing, received " + std::to_string(spacing.size())); + } + + if (mask.is_cuda()) + { + #ifdef WITH_CUDA + if (!torch::cuda::is_available()) + { + throw std::runtime_error( + "cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support"); + } + return exact_euclidean2d_cuda(mask, spacing); + #else + AT_ERROR("exact_euclidean2d is only available with CUDA support. Not compiled with CUDA."); + #endif + } + else + { + AT_ERROR("exact_euclidean2d is only available on CUDA devices. Please move tensor to GPU."); + } +} + +torch::Tensor exact_euclidean3d(const torch::Tensor &mask, const std::vector &spacing) +{ + // Check input dimensions - expect BCDHW format + const int num_dims = mask.dim(); + if (num_dims != 5) + { + throw std::invalid_argument( + "exact_euclidean3d only supports 5D inputs (BCDHW), received " + std::to_string(num_dims) + "D"); + } + + // Note: batch and channel dimensions are now supported + + if (spacing.size() != 3) + { + throw std::invalid_argument( + "exact_euclidean3d requires 3D spacing, received " + std::to_string(spacing.size())); + } + + if (mask.is_cuda()) + { + #ifdef WITH_CUDA + if (!torch::cuda::is_available()) + { + throw std::runtime_error( + "cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support"); + } + return exact_euclidean3d_cuda(mask, spacing); + #else + AT_ERROR("exact_euclidean3d is only available with CUDA support. Not compiled with CUDA."); + #endif + } + else + { + AT_ERROR("exact_euclidean3d is only available on CUDA devices. Please move tensor to GPU."); + } +} + +torch::Tensor signed_exact_euclidean2d(const torch::Tensor &mask, const std::vector &spacing) +{ + // Check input dimensions - expect BCHW format + const int num_dims = mask.dim(); + if (num_dims != 4) + { + throw std::invalid_argument( + "signed_exact_euclidean2d only supports 4D inputs (BCHW), received " + std::to_string(num_dims) + "D"); + } + + // Note: batch and channel dimensions are now supported + + if (spacing.size() != 2) + { + throw std::invalid_argument( + "signed_exact_euclidean2d requires 2D spacing, received " + std::to_string(spacing.size())); + } + + if (mask.is_cuda()) + { + #ifdef WITH_CUDA + if (!torch::cuda::is_available()) + { + throw std::runtime_error( + "cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support"); + } + return signed_exact_euclidean2d_cuda(mask, spacing); + #else + AT_ERROR("signed_exact_euclidean2d is only available with CUDA support. Not compiled with CUDA."); + #endif + } + else + { + AT_ERROR("signed_exact_euclidean2d is only available on CUDA devices. Please move tensor to GPU."); + } +} + +torch::Tensor signed_exact_euclidean3d(const torch::Tensor &mask, const std::vector &spacing) +{ + // Check input dimensions - expect BCDHW format + const int num_dims = mask.dim(); + if (num_dims != 5) + { + throw std::invalid_argument( + "signed_exact_euclidean3d only supports 5D inputs (BCDHW), received " + std::to_string(num_dims) + "D"); + } + + // Note: batch and channel dimensions are now supported + + if (spacing.size() != 3) + { + throw std::invalid_argument( + "signed_exact_euclidean3d requires 3D spacing, received " + std::to_string(spacing.size())); + } + + if (mask.is_cuda()) + { + #ifdef WITH_CUDA + if (!torch::cuda::is_available()) + { + throw std::runtime_error( + "cuda.is_available() returned false, please check if the library was compiled successfully with CUDA support"); + } + return signed_exact_euclidean3d_cuda(mask, spacing); + #else + AT_ERROR("signed_exact_euclidean3d is only available with CUDA support. Not compiled with CUDA."); + #endif + } + else + { + AT_ERROR("signed_exact_euclidean3d is only available on CUDA devices. Please move tensor to GPU."); + } +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("generalised_geodesic2d", &generalised_geodesic2d, "Generalised Geodesic distance 2d"); @@ -449,4 +598,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("GSF2d_fastmarch", &GSF2d_fastmarch, "Geodesic Symmetric Filtering 2d using Fast Marching method"); m.def("GSF3d_fastmarch", &GSF3d_fastmarch, "Geodesic Symmetric Filtering 3d using Fast Marching method"); + // Exact Euclidean Distance Transform using PBA+ algorithm + m.def("exact_euclidean2d", &exact_euclidean2d, "Exact Euclidean Distance Transform 2D using PBA+ algorithm"); + m.def("exact_euclidean3d", &exact_euclidean3d, "Exact Euclidean Distance Transform 3D using PBA+ algorithm"); + m.def("signed_exact_euclidean2d", &signed_exact_euclidean2d, "Signed Exact Euclidean Distance Transform 2D using PBA+ algorithm"); + m.def("signed_exact_euclidean3d", &signed_exact_euclidean3d, "Signed Exact Euclidean Distance Transform 3D using PBA+ algorithm"); + } \ No newline at end of file diff --git a/FastGeodis/fastgeodis.h b/FastGeodis/fastgeodis.h index 35b837f..49a9742 100644 --- a/FastGeodis/fastgeodis.h +++ b/FastGeodis/fastgeodis.h @@ -33,6 +33,7 @@ #include #include #include "common.h" +#include "geodis_pba.h" #ifdef WITH_CUDA torch::Tensor generalised_geodesic2d_cuda( @@ -293,4 +294,21 @@ torch::Tensor GSF3d_fastmarch( const torch::Tensor &mask, const float &theta, const std::vector &spacing, - const float &lambda); \ No newline at end of file + const float &lambda); + +// Exact Euclidean Distance Transform using PBA+ algorithm +torch::Tensor exact_euclidean2d( + const torch::Tensor &mask, + const std::vector &spacing); + +torch::Tensor exact_euclidean3d( + const torch::Tensor &mask, + const std::vector &spacing); + +torch::Tensor signed_exact_euclidean2d( + const torch::Tensor &mask, + const std::vector &spacing); + +torch::Tensor signed_exact_euclidean3d( + const torch::Tensor &mask, + const std::vector &spacing); \ No newline at end of file diff --git a/FastGeodis/geodis_pba.cu b/FastGeodis/geodis_pba.cu new file mode 100644 index 0000000..cb9e9d0 --- /dev/null +++ b/FastGeodis/geodis_pba.cu @@ -0,0 +1,1249 @@ +// BSD 3-Clause License + +// Copyright (c) 2021, Muhammad Asad (masadcv@gmail.com) +// All rights reserved. + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. + +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// PBA+ (Parallel Banding Algorithm Plus) for Exact Euclidean Distance Transform +// Based on: Cao et al. (2010) "Parallel Banding Algorithm to compute exact distance transform with the GPU" +// +// This implementation is adapted from the CUCIM project (MIT License): +// https://github.com/rapidsai/cucim +// Original PBA+ implementation: https://github.com/orzzzjq/Parallel-Banding-Algorithm-plus +// +// The parallel banding algorithm was originally described in: +// Thanh-Tung Cao, Ke Tang, Anis Mohamed, and Tiow-Seng Tan. 2010. +// Parallel Banding Algorithm to compute exact distance transform with the GPU. +// In Proceedings of the 2010 ACM SIGGRAPH symposium on Interactive 3D Graphics and Games. + +#include +#include +#include +#include + +#include +#include +#include +#include + +// ============================================================================ +// PBA+ Constants and Macros +// ============================================================================ +#define MARKER -32768 +#define BLOCKSIZE 32 + +#define TOID(x, y, size) ((y) * (size) + (x)) + +// Use short2 for 2D coordinate encoding (supports images up to 32767x32767) +typedef short2 pixel_int2_t; +#define make_pixel(x, y) make_short2(x, y) + +// ============================================================================ +// Domination test for Voronoi diagram construction +// Returns true if site2 dominates site3 at column x0 (site3 should be removed) +// ============================================================================ +#define LL long long +__device__ __forceinline__ bool dominate(LL x1, LL y1, LL x2, LL y2, LL x3, LL y3, LL x0) +{ + LL k1 = y2 - y1, k2 = y3 - y2; + return (k1 * (y1 + y2) + (x2 - x1) * ((x1 + x2) - (x0 << 1))) * k2 > \ + (k2 * (y2 + y3) + (x3 - x2) * ((x2 + x3) - (x0 << 1))) * k1; +} +#undef LL + +// Spacing-aware domination test +__device__ __forceinline__ bool dominate_sp(int _x1, int _y1, int _x2, int _y2, int _x3, int _y3, int _x0, float sx, float sy) +{ + float x1 = static_cast(_x1) * sx; + float x2 = static_cast(_x2) * sx; + float x3 = static_cast(_x3) * sx; + float y1 = static_cast(_y1) * sy; + float y2 = static_cast(_y2) * sy; + float y3 = static_cast(_y3) * sy; + float x0_2 = static_cast(_x0 << 1) * sx; + float k1 = (y2 - y1); + float k2 = (y3 - y2); + return (k1 * (y1 + y2) + (x2 - x1) * ((x1 + x2) - x0_2)) * k2 > \ + (k2 * (y2 + y3) + (x3 - x2) * ((x2 + x3) - x0_2)) * k1; +} + +// ============================================================================ +// Phase 1 Kernels: Vertical (Y-axis) Processing +// ============================================================================ + +// Flood downward along columns +__global__ void kernelFloodDown(pixel_int2_t *input, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + int id = TOID(tx, ty, size); + + pixel_int2_t pixel1, pixel2; + + pixel1 = make_pixel(MARKER, MARKER); + + for (int i = 0; i < bandSize; i++, id += size) { + pixel2 = input[id]; + + if (pixel2.x != MARKER) + pixel1 = pixel2; + + output[id] = pixel1; + } +} + +// Flood upward along columns +__global__ void kernelFloodUp(pixel_int2_t *input, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = (blockIdx.y + 1) * bandSize - 1; + int id = TOID(tx, ty, size); + + pixel_int2_t pixel1, pixel2; + int dist1, dist2; + + pixel1 = make_pixel(MARKER, MARKER); + + for (int i = 0; i < bandSize; i++, id -= size) { + dist1 = abs(pixel1.y - ty + i); + + pixel2 = input[id]; + dist2 = abs(pixel2.y - ty + i); + + if (dist2 < dist1) + pixel1 = pixel2; + + output[id] = pixel1; + } +} + +// Propagate information between bands +__global__ void kernelPropagateInterband(pixel_int2_t *input, pixel_int2_t *margin_out, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int inc = bandSize * size; + int ny, nid, nDist; + pixel_int2_t pixel; + + int ty = blockIdx.y * bandSize; + int topId = TOID(tx, ty, size); + int bottomId = TOID(tx, ty + bandSize - 1, size); + int tid = blockIdx.y * size + tx; + int bid = tid + (size * size / bandSize); + + pixel = input[topId]; + int myDist = abs(pixel.y - ty); + margin_out[tid] = pixel; + + for (nid = bottomId - inc; nid >= 0; nid -= inc) { + pixel = input[nid]; + + if (pixel.x != MARKER) { + nDist = abs(pixel.y - ty); + + if (nDist < myDist) + margin_out[tid] = pixel; + + break; + } + } + + ty = ty + bandSize - 1; + pixel = input[bottomId]; + myDist = abs(pixel.y - ty); + margin_out[bid] = pixel; + + for (ny = ty + 1, nid = topId + inc; ny < size; ny += bandSize, nid += inc) { + pixel = input[nid]; + + if (pixel.x != MARKER) { + nDist = abs(pixel.y - ty); + + if (nDist < myDist) + margin_out[bid] = pixel; + + break; + } + } +} + +// Update vertical distances and transpose +__global__ void kernelUpdateVertical(pixel_int2_t *color, pixel_int2_t *margin, pixel_int2_t *output, int size, int bandSize) +{ + __shared__ pixel_int2_t block[BLOCKSIZE][BLOCKSIZE]; + + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + + pixel_int2_t top = margin[blockIdx.y * size + tx]; + pixel_int2_t bottom = margin[(blockIdx.y + size / bandSize) * size + tx]; + pixel_int2_t pixel; + + int dist, myDist; + + int id = TOID(tx, ty, size); + + int n_step = bandSize / blockDim.x; + for(int step = 0; step < n_step; ++step) { + int y_start = blockIdx.y * bandSize + step * blockDim.x; + int y_end = y_start + blockDim.x; + + for (ty = y_start; ty < y_end; ++ty, id += size) { + pixel = color[id]; + myDist = abs(pixel.y - ty); + + dist = abs(top.y - ty); + if (dist < myDist) { myDist = dist; pixel = top; } + + dist = abs(bottom.y - ty); + if (dist < myDist) pixel = bottom; + + block[threadIdx.x][ty - y_start] = make_pixel(pixel.y, pixel.x); + } + + __syncthreads(); + + int tid = TOID(blockIdx.y * bandSize + step * blockDim.x + threadIdx.x, \ + blockIdx.x * blockDim.x, size); + + for(int i = 0; i < blockDim.x; ++i, tid += size) { + output[tid] = block[i][threadIdx.x]; + } + + __syncthreads(); + } +} + +// ============================================================================ +// Phase 2 Kernels: Horizontal (X-axis) Processing with Stack +// ============================================================================ + +// Build stack of proximate points using domination test (isotropic) +__global__ void kernelProximatePoints(pixel_int2_t *input, pixel_int2_t *stack, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + int id = TOID(tx, ty, size); + int lasty = -1; + pixel_int2_t last1, last2, current; + + last1.y = -1; last2.y = -1; + + for (int i = 0; i < bandSize; i++, id += size) { + current = input[id]; + + if (current.x != MARKER) { + while (last2.y >= 0) { + if (!dominate(last1.x, last2.y, last2.x, lasty, current.x, current.y, tx)) + break; + + lasty = last2.y; last2 = last1; + + if (last1.y >= 0) + last1 = stack[TOID(tx, last1.y, size)]; + } + + last1 = last2; last2 = make_pixel(current.x, lasty); lasty = current.y; + + stack[id] = last2; + } + } + + if (lasty != ty + bandSize - 1) + stack[TOID(tx, ty + bandSize - 1, size)] = make_pixel(MARKER, lasty); +} + +// Build stack with spacing support +__global__ void kernelProximatePointsWithSpacing(pixel_int2_t *input, pixel_int2_t *stack, int size, int bandSize, float sx, float sy) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + int id = TOID(tx, ty, size); + int lasty = -1; + pixel_int2_t last1, last2, current; + + last1.y = -1; last2.y = -1; + + for (int i = 0; i < bandSize; i++, id += size) { + current = input[id]; + + if (current.x != MARKER) { + while (last2.y >= 0) { + if (!dominate_sp(last1.x, last2.y, last2.x, lasty, current.x, current.y, tx, sx, sy)) + break; + + lasty = last2.y; last2 = last1; + + if (last1.y >= 0) + last1 = stack[TOID(tx, last1.y, size)]; + } + + last1 = last2; last2 = make_pixel(current.x, lasty); lasty = current.y; + + stack[id] = last2; + } + } + + if (lasty != ty + bandSize - 1) + stack[TOID(tx, ty + bandSize - 1, size)] = make_pixel(MARKER, lasty); +} + +// Create forward pointers from backward-linked structure +__global__ void kernelCreateForwardPointers(pixel_int2_t *input, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = (blockIdx.y + 1) * bandSize - 1; + int id = TOID(tx, ty, size); + int lasty = -1, nexty; + pixel_int2_t current; + + current = input[id]; + + if (current.x == MARKER) + nexty = current.y; + else + nexty = ty; + + for (int i = 0; i < bandSize; i++, id -= size) + if (ty - i == nexty) { + current = make_pixel(lasty, input[id].y); + output[id] = current; + + lasty = nexty; + nexty = current.y; + } + + if (lasty != ty - bandSize + 1) + output[id + size] = make_pixel(lasty, MARKER); +} + +// Merge adjacent bands (isotropic) +__global__ void kernelMergeBands(pixel_int2_t *color, pixel_int2_t *link, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int band1 = blockIdx.y * 2; + int band2 = band1 + 1; + int firsty, lasty; + pixel_int2_t last1, last2, current; + + lasty = band2 * bandSize - 1; + last2 = make_pixel(color[TOID(tx, lasty, size)].x, link[TOID(tx, lasty, size)].y); + + if (last2.x == MARKER) { + lasty = last2.y; + + if (lasty >= 0) + last2 = make_pixel(color[TOID(tx, lasty, size)].x, link[TOID(tx, lasty, size)].y); + else + last2 = make_pixel(MARKER, MARKER); + } + + if (last2.y >= 0) { + last1 = make_pixel(color[TOID(tx, last2.y, size)].x, link[TOID(tx, last2.y, size)].y); + } + + firsty = band2 * bandSize; + current = make_pixel(link[TOID(tx, firsty, size)].x, color[TOID(tx, firsty, size)].x); + + if (current.y == MARKER) { + firsty = current.x; + + if (firsty >= 0) + current = make_pixel(link[TOID(tx, firsty, size)].x, color[TOID(tx, firsty, size)].x); + else + current = make_pixel(MARKER, MARKER); + } + + int top = 0; + + while (top < 2 && current.y >= 0) { + while (last2.y >= 0) { + if (!dominate(last1.x, last2.y, last2.x, lasty, current.y, firsty, tx)) + break; + + lasty = last2.y; last2 = last1; + top--; + + if (last1.y >= 0) + last1 = make_pixel(color[TOID(tx, last1.y, size)].x, link[TOID(tx, last1.y, size)].y); + } + + output[TOID(tx, firsty, size)] = make_pixel(current.x, lasty); + + if (lasty >= 0) + output[TOID(tx, lasty, size)] = make_pixel(firsty, last2.y); + + last1 = last2; last2 = make_pixel(current.y, lasty); lasty = firsty; + firsty = current.x; + + top = max(1, top + 1); + + if (firsty >= 0) + current = make_pixel(link[TOID(tx, firsty, size)].x, color[TOID(tx, firsty, size)].x); + else + current = make_pixel(MARKER, MARKER); + } + + firsty = band1 * bandSize; + lasty = band2 * bandSize; + current = link[TOID(tx, firsty, size)]; + + if (current.y == MARKER && current.x < 0) { + last1 = link[TOID(tx, lasty, size)]; + + if (last1.y == MARKER) + current.x = last1.x; + else + current.x = lasty; + + output[TOID(tx, firsty, size)] = current; + } + + firsty = band1 * bandSize + bandSize - 1; + lasty = band2 * bandSize + bandSize - 1; + current = link[TOID(tx, lasty, size)]; + + if (current.x == MARKER && current.y < 0) { + last1 = link[TOID(tx, firsty, size)]; + + if (last1.x == MARKER) + current.y = last1.y; + else + current.y = firsty; + + output[TOID(tx, lasty, size)] = current; + } +} + +// Merge bands with spacing support +__global__ void kernelMergeBandsWithSpacing(pixel_int2_t *color, pixel_int2_t *link, pixel_int2_t *output, int size, int bandSize, float sx, float sy) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int band1 = blockIdx.y * 2; + int band2 = band1 + 1; + int firsty, lasty; + pixel_int2_t last1, last2, current; + + lasty = band2 * bandSize - 1; + last2 = make_pixel(color[TOID(tx, lasty, size)].x, link[TOID(tx, lasty, size)].y); + + if (last2.x == MARKER) { + lasty = last2.y; + + if (lasty >= 0) + last2 = make_pixel(color[TOID(tx, lasty, size)].x, link[TOID(tx, lasty, size)].y); + else + last2 = make_pixel(MARKER, MARKER); + } + + if (last2.y >= 0) { + last1 = make_pixel(color[TOID(tx, last2.y, size)].x, link[TOID(tx, last2.y, size)].y); + } + + firsty = band2 * bandSize; + current = make_pixel(link[TOID(tx, firsty, size)].x, color[TOID(tx, firsty, size)].x); + + if (current.y == MARKER) { + firsty = current.x; + + if (firsty >= 0) + current = make_pixel(link[TOID(tx, firsty, size)].x, color[TOID(tx, firsty, size)].x); + else + current = make_pixel(MARKER, MARKER); + } + + int top = 0; + + while (top < 2 && current.y >= 0) { + while (last2.y >= 0) { + if (!dominate_sp(last1.x, last2.y, last2.x, lasty, current.y, firsty, tx, sx, sy)) + break; + + lasty = last2.y; last2 = last1; + top--; + + if (last1.y >= 0) + last1 = make_pixel(color[TOID(tx, last1.y, size)].x, link[TOID(tx, last1.y, size)].y); + } + + output[TOID(tx, firsty, size)] = make_pixel(current.x, lasty); + + if (lasty >= 0) + output[TOID(tx, lasty, size)] = make_pixel(firsty, last2.y); + + last1 = last2; last2 = make_pixel(current.y, lasty); lasty = firsty; + firsty = current.x; + + top = max(1, top + 1); + + if (firsty >= 0) + current = make_pixel(link[TOID(tx, firsty, size)].x, color[TOID(tx, firsty, size)].x); + else + current = make_pixel(MARKER, MARKER); + } + + firsty = band1 * bandSize; + lasty = band2 * bandSize; + current = link[TOID(tx, firsty, size)]; + + if (current.y == MARKER && current.x < 0) { + last1 = link[TOID(tx, lasty, size)]; + + if (last1.y == MARKER) + current.x = last1.x; + else + current.x = lasty; + + output[TOID(tx, firsty, size)] = current; + } + + firsty = band1 * bandSize + bandSize - 1; + lasty = band2 * bandSize + bandSize - 1; + current = link[TOID(tx, lasty, size)]; + + if (current.x == MARKER && current.y < 0) { + last1 = link[TOID(tx, firsty, size)]; + + if (last1.x == MARKER) + current.y = last1.y; + else + current.y = firsty; + + output[TOID(tx, lasty, size)] = current; + } +} + +// Convert double-linked list to single list +__global__ void kernelDoubleToSingleList(pixel_int2_t *color, pixel_int2_t *link, pixel_int2_t *output, int size) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y; + int id = TOID(tx, ty, size); + + output[id] = make_pixel(color[id].x, link[id].y); +} + +// ============================================================================ +// Phase 3 Kernels: Final Distance Computation +// ============================================================================ + +// Compute final Voronoi coloring (isotropic) +__global__ void kernelColor(pixel_int2_t *input, pixel_int2_t *output, int size) +{ + __shared__ pixel_int2_t block[BLOCKSIZE][BLOCKSIZE]; + + int col = threadIdx.x; + int tid = threadIdx.y; + int tx = blockIdx.x * blockDim.x + col; + int dx, dy, lasty; + unsigned int best, dist; + pixel_int2_t last1, last2; + + lasty = size - 1; + + last2 = input[TOID(tx, lasty, size)]; + + if (last2.x == MARKER) { + lasty = max(last2.y, 0); + last2 = input[TOID(tx, lasty, size)]; + } + + if (last2.y >= 0) + last1 = input[TOID(tx, last2.y, size)]; + + int y_start, y_end, n_step = size / blockDim.x; + for(int step = 0; step < n_step; ++step) { + y_start = size - step * blockDim.x - 1; + y_end = size - (step + 1) * blockDim.x; + + for (int ty = y_start - tid; ty >= y_end; ty -= blockDim.y) { + dx = last2.x - tx; dy = lasty - ty; + best = dist = dx * dx + dy * dy; + + while (last2.y >= 0) { + dx = last1.x - tx; dy = last2.y - ty; + dist = dx * dx + dy * dy; + + if (dist > best) + break; + + best = dist; lasty = last2.y; last2 = last1; + + if (last2.y >= 0) + last1 = input[TOID(tx, last2.y, size)]; + } + + block[threadIdx.x][ty - y_end] = make_pixel(lasty, last2.x); + } + + __syncthreads(); + + if(!threadIdx.y) { + int id = TOID(y_end + threadIdx.x, blockIdx.x * blockDim.x, size); + for(int i = 0; i < blockDim.x; ++i, id += size) { + output[id] = block[i][threadIdx.x]; + } + } + + __syncthreads(); + } +} + +// Compute final Voronoi coloring with spacing +__global__ void kernelColorWithSpacing(pixel_int2_t *input, pixel_int2_t *output, int size, float sx, float sy) +{ + __shared__ pixel_int2_t block[BLOCKSIZE][BLOCKSIZE]; + + int col = threadIdx.x; + int tid = threadIdx.y; + int tx = blockIdx.x * blockDim.x + col; + int lasty; + float dx, dy, best, dist; + pixel_int2_t last1, last2; + + lasty = size - 1; + + last2 = input[TOID(tx, lasty, size)]; + + if (last2.x == MARKER) { + lasty = max(last2.y, 0); + last2 = input[TOID(tx, lasty, size)]; + } + + if (last2.y >= 0) + last1 = input[TOID(tx, last2.y, size)]; + + int y_start, y_end, n_step = size / blockDim.x; + for(int step = 0; step < n_step; ++step) { + y_start = size - step * blockDim.x - 1; + y_end = size - (step + 1) * blockDim.x; + + for (int ty = y_start - tid; ty >= y_end; ty -= blockDim.y) { + dx = static_cast(last2.x - tx) * sx; + dy = static_cast(lasty - ty) * sy; + best = dist = dx * dx + dy * dy; + + while (last2.y >= 0) { + dx = static_cast(last1.x - tx) * sx; + dy = static_cast(last2.y - ty) * sy; + dist = dx * dx + dy * dy; + + if (dist > best) + break; + + best = dist; lasty = last2.y; last2 = last1; + + if (last2.y >= 0) + last1 = input[TOID(tx, last2.y, size)]; + } + + block[threadIdx.x][ty - y_end] = make_pixel(lasty, last2.x); + } + + __syncthreads(); + + if(!threadIdx.y) { + int id = TOID(y_end + threadIdx.x, blockIdx.x * blockDim.x, size); + for(int i = 0; i < blockDim.x; ++i, id += size) { + output[id] = block[i][threadIdx.x]; + } + } + + __syncthreads(); + } +} + +// ============================================================================ +// Kernel to initialize buffer to MARKER (for proper padding initialization) +// Note: cudaMemset cannot be used because MARKER (-32768 = 0x8000) cannot be +// represented as a single byte pattern. This kernel ensures the padded region +// is properly initialized to MARKER values. +// ============================================================================ +__global__ void kernelInitToMarker(pixel_int2_t* __restrict__ buffer, int size) +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= size || y >= size) return; + + int id = y * size + x; + buffer[id] = make_pixel(MARKER, MARKER); +} + +// ============================================================================ +// Kernel to initialize sites from mask +// ============================================================================ +__global__ void kernelInitSites(const float* __restrict__ mask, pixel_int2_t* __restrict__ sites, + int width, int height, int size) +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // mask has stride 'width', sites buffer has stride 'size' (padded) + int mask_id = y * width + x; + int site_id = y * size + x; + + // mask == 0 means seed point (foreground in EDT terms) + if (mask[mask_id] < 0.5f) { + sites[site_id] = make_pixel(x, y); + } else { + sites[site_id] = make_pixel(MARKER, MARKER); + } +} + +// ============================================================================ +// Kernel to compute final distance from Voronoi sites +// ============================================================================ +__global__ void kernelComputeDistance(const pixel_int2_t* __restrict__ sites, float* __restrict__ distance, + int width, int height, int size, float sx, float sy) +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= width || y >= height) return; + + // sites buffer has stride 'size' (padded), output has stride 'width' + int site_id = y * size + x; + int out_id = y * width + x; + pixel_int2_t site = sites[site_id]; + + if (site.x == MARKER) { + distance[out_id] = FLT_MAX; + } else { + float dx = static_cast(x - site.x) * sx; + float dy = static_cast(y - site.y) * sy; + distance[out_id] = sqrtf(dx * dx + dy * dy); + } +} + +// ============================================================================ +// Helper: Compute band sizes for PBA+ +// ============================================================================ +inline void computeBandSizes(int size, int& m1, int& m2) +{ + // Use heuristics from PBA+ paper + // m1: band size for phase 1 (should divide size) + // m2: band size for phase 2 (should be power of 2 and divide size) + + // Default values that work well + m1 = BLOCKSIZE; + m2 = BLOCKSIZE; + + // Adjust if size is not divisible + while (size % m1 != 0 && m1 > 1) m1 /= 2; + while (size % m2 != 0 && m2 > 1) m2 /= 2; + + // Ensure minimum band size + if (m1 < 1) m1 = 1; + if (m2 < 1) m2 = 1; +} + +// ============================================================================ +// Main 2D PBA+ EDT function +// ============================================================================ +torch::Tensor exact_euclidean2d_cuda( + const torch::Tensor &mask, + const std::vector &spacing +) { + int device = mask.get_device(); + c10::cuda::CUDAGuard device_guard(device); + + // Get dimensions (assume BCHW format) + const int batch = mask.size(0); + const int channels = mask.size(1); + const int height = mask.size(2); + const int width = mask.size(3); + + // spacing is [height_spacing, width_spacing] to match [H, W] tensor convention + float sy = spacing.size() > 0 ? spacing[0] : 1.0f; // height (row) spacing + float sx = spacing.size() > 1 ? spacing[1] : 1.0f; // width (column) spacing + bool use_spacing = (sx != 1.0f || sy != 1.0f); + + // Early return if all pixels are seeds (mask=0 means seed, so sum=0 means all seeds) + // In this case, all distances should be 0 + if (mask.sum().item() < 0.5f) { + return torch::zeros({batch, channels, height, width}, + torch::TensorOptions().dtype(torch::kFloat32).device(mask.device())); + } + + // PBA+ requires square images with size being power of 2 + // Round up to next power of 2 + int size = max(width, height); + + // Next power of 2 using bit manipulation + size--; + size |= size >> 1; + size |= size >> 2; + size |= size >> 4; + size |= size >> 8; + size |= size >> 16; + size++; + + // Ensure size is at least 2*BLOCKSIZE for the algorithm to work + if (size < 2 * BLOCKSIZE) size = 2 * BLOCKSIZE; + + // Compute band sizes + int bandSize1, bandSize2; + computeBandSizes(size, bandSize1, bandSize2); + + // Allocate output + auto options_float = torch::TensorOptions().dtype(torch::kFloat32).device(mask.device()); + torch::Tensor distance = torch::zeros({batch, channels, height, width}, options_float); + + // Allocate intermediate buffers (padded size) + auto options_short2 = torch::TensorOptions().dtype(torch::kInt32).device(mask.device()); + // We use int32 but interpret as short2 (same size) + torch::Tensor buffer1 = torch::full({size, size}, 0, options_short2); + torch::Tensor buffer2 = torch::full({size, size}, 0, options_short2); + torch::Tensor margin = torch::full({2 * size * size / bandSize1}, 0, options_short2); + + // Process each batch and channel + for (int b = 0; b < batch; b++) { + for (int c = 0; c < channels; c++) { + // Get pointers + const float* mask_ptr = mask.data_ptr() + (b * channels + c) * height * width; + float* dist_ptr = distance.data_ptr() + (b * channels + c) * height * width; + pixel_int2_t* buf1_ptr = reinterpret_cast(buffer1.data_ptr()); + pixel_int2_t* buf2_ptr = reinterpret_cast(buffer2.data_ptr()); + pixel_int2_t* margin_ptr = reinterpret_cast(margin.data_ptr()); + + // Initialize buffers to MARKER using kernel (cudaMemset cannot set MARKER correctly) + { + dim3 block(16, 16); + dim3 grid((size + 15) / 16, (size + 15) / 16); + kernelInitToMarker<<>>(buf1_ptr, size); + kernelInitToMarker<<>>(buf2_ptr, size); + } + + // Initialize sites from mask (only for valid region) + { + dim3 block(16, 16); + dim3 grid((width + 15) / 16, (height + 15) / 16); + kernelInitSites<<>>(mask_ptr, buf1_ptr, width, height, size); + } + + // Phase 1: Vertical processing + { + dim3 block(BLOCKSIZE); + dim3 grid(size / BLOCKSIZE, size / bandSize1); + + kernelFloodDown<<>>(buf1_ptr, buf1_ptr, size, bandSize1); + kernelFloodUp<<>>(buf1_ptr, buf1_ptr, size, bandSize1); + kernelPropagateInterband<<>>(buf1_ptr, margin_ptr, size, bandSize1); + kernelUpdateVertical<<>>(buf1_ptr, margin_ptr, buf2_ptr, size, bandSize1); + } + + // Phase 2: Horizontal processing with band merging + { + dim3 block(BLOCKSIZE); + dim3 grid(size / BLOCKSIZE, size / bandSize2); + + if (use_spacing) { + kernelProximatePointsWithSpacing<<>>(buf2_ptr, buf1_ptr, size, bandSize2, sx, sy); + } else { + kernelProximatePoints<<>>(buf2_ptr, buf1_ptr, size, bandSize2); + } + + kernelCreateForwardPointers<<>>(buf1_ptr, buf1_ptr, size, bandSize2); + + // Iteratively merge bands + int noBand = size / bandSize2; + dim3 grid2(size / BLOCKSIZE, noBand / 2); + + while (noBand > 1) { + if (use_spacing) { + kernelMergeBandsWithSpacing<<>>(buf2_ptr, buf1_ptr, buf1_ptr, size, size / noBand, sx, sy); + } else { + kernelMergeBands<<>>(buf2_ptr, buf1_ptr, buf1_ptr, size, size / noBand); + } + noBand /= 2; + grid2.y = noBand / 2; + if (grid2.y == 0) break; + } + + dim3 grid3(size / BLOCKSIZE, size); + kernelDoubleToSingleList<<>>(buf2_ptr, buf1_ptr, buf1_ptr, size); + } + + // Phase 3: Final coloring + { + dim3 block(BLOCKSIZE, BLOCKSIZE); + dim3 grid(size / BLOCKSIZE); + + if (use_spacing) { + kernelColorWithSpacing<<>>(buf1_ptr, buf2_ptr, size, sx, sy); + } else { + kernelColor<<>>(buf1_ptr, buf2_ptr, size); + } + } + + // Compute final distances (only for valid region) + { + dim3 block(16, 16); + dim3 grid((width + 15) / 16, (height + 15) / 16); + kernelComputeDistance<<>>(buf2_ptr, dist_ptr, width, height, size, sx, sy); + } + } + } + + return distance; +} + +// ============================================================================ +// 3D PBA+ Implementation +// ============================================================================ + +// For 3D, we use a similar approach but process along Z, Y, X axes +// Using int for 3D coordinate encoding (10 bits per coordinate, supports up to 1024^3) + +#define MARKER3D -1 +#define ENCODE3D(x, y, z) (((x) << 20) | ((y) << 10) | (z)) +#define DECODE3D_X(v) (((v) >> 20) & 0x3FF) +#define DECODE3D_Y(v) (((v) >> 10) & 0x3FF) +#define DECODE3D_Z(v) ((v) & 0x3FF) + +// 3D domination test with spacing +__device__ __forceinline__ bool dominate3d_sp( + int x1, int y1, int z1, + int x2, int y2, int z2, + int x3, int y3, int z3, + int px, int py, + float sx, float sy, float sz +) { + // Check if (x2,y2,z2) dominates (x3,y3,z3) at position (px,py,*) + float fx1 = x1 * sx, fy1 = y1 * sy, fz1 = z1 * sz; + float fx2 = x2 * sx, fy2 = y2 * sy, fz2 = z2 * sz; + float fx3 = x3 * sx, fy3 = y3 * sy, fz3 = z3 * sz; + float fpx = px * sx, fpy = py * sy; + + float d12_xy = (fx1 - fpx) * (fx1 - fpx) + (fy1 - fpy) * (fy1 - fpy) - + (fx2 - fpx) * (fx2 - fpx) - (fy2 - fpy) * (fy2 - fpy); + float d23_xy = (fx2 - fpx) * (fx2 - fpx) + (fy2 - fpy) * (fy2 - fpy) - + (fx3 - fpx) * (fx3 - fpx) - (fy3 - fpy) * (fy3 - fpy); + + float dz12 = fz2 - fz1; + float dz23 = fz3 - fz2; + + if (dz12 == 0 || dz23 == 0) return false; + + float t1 = (d12_xy + fz1 * fz1 - fz2 * fz2) / (2 * dz12); + float t2 = (d23_xy + fz2 * fz2 - fz3 * fz3) / (2 * dz23); + + return t1 >= t2; +} + +// Flood along Z-axis for 3D +__global__ void kernelFloodZ3D(int* input, int* output, int sizeX, int sizeY, int sizeZ) +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + + if (x >= sizeX || y >= sizeY) return; + + int lastSite = MARKER3D; + + // Forward pass + for (int z = 0; z < sizeZ; z++) { + int id = z * sizeY * sizeX + y * sizeX + x; + int site = input[id]; + + if (site != MARKER3D) + lastSite = site; + + output[id] = lastSite; + } + + lastSite = MARKER3D; + + // Backward pass + for (int z = sizeZ - 1; z >= 0; z--) { + int id = z * sizeY * sizeX + y * sizeX + x; + int site = input[id]; + + if (site != MARKER3D) + lastSite = site; + + int siteForward = output[id]; + + if (lastSite == MARKER3D) + continue; + + if (siteForward == MARKER3D) { + output[id] = lastSite; + } else { + int zForward = DECODE3D_Z(siteForward); + int zBack = DECODE3D_Z(lastSite); + + if (abs(z - zBack) < abs(z - zForward)) + output[id] = lastSite; + } + } +} + +// Process Y-axis for 3D (Maurer's algorithm) +__global__ void kernelMaurerY3D(int* input, int* output, int sizeX, int sizeY, int sizeZ, + float sx, float sy, float sz) +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; + int z = blockIdx.z; + + if (x >= sizeX || z >= sizeZ) return; + + int lastSite = MARKER3D; + + // Forward pass + for (int y = 0; y < sizeY; y++) { + int id = z * sizeY * sizeX + y * sizeX + x; + int site = input[id]; + + if (site != MARKER3D) + lastSite = site; + + output[id] = lastSite; + } + + lastSite = MARKER3D; + + // Backward pass with distance comparison + for (int y = sizeY - 1; y >= 0; y--) { + int id = z * sizeY * sizeX + y * sizeX + x; + int site = input[id]; + + if (site != MARKER3D) + lastSite = site; + + int siteForward = output[id]; + + if (lastSite == MARKER3D) + continue; + + if (siteForward == MARKER3D) { + output[id] = lastSite; + } else { + // Compare distances + int xf = DECODE3D_X(siteForward), yf = DECODE3D_Y(siteForward), zf = DECODE3D_Z(siteForward); + int xb = DECODE3D_X(lastSite), yb = DECODE3D_Y(lastSite), zb = DECODE3D_Z(lastSite); + + float dxf = (x - xf) * sx, dyf = (y - yf) * sy, dzf = (z - zf) * sz; + float dxb = (x - xb) * sx, dyb = (y - yb) * sy, dzb = (z - zb) * sz; + + float distF = dxf*dxf + dyf*dyf + dzf*dzf; + float distB = dxb*dxb + dyb*dyb + dzb*dzb; + + if (distB < distF) + output[id] = lastSite; + } + } +} + +// Process X-axis for 3D and compute final distance +__global__ void kernelColorX3D(int* input, float* distance, int sizeX, int sizeY, int sizeZ, + float sx, float sy, float sz) +{ + int y = blockIdx.y * blockDim.y + threadIdx.y; + int z = blockIdx.z; + int x = blockIdx.x * blockDim.x + threadIdx.x; + + if (x >= sizeX || y >= sizeY || z >= sizeZ) return; + + int id = z * sizeY * sizeX + y * sizeX + x; + + int bestSite = input[id]; + float bestDist = FLT_MAX; + + if (bestSite != MARKER3D) { + int bx = DECODE3D_X(bestSite), by = DECODE3D_Y(bestSite), bz = DECODE3D_Z(bestSite); + float dx = (x - bx) * sx, dy = (y - by) * sy, dz = (z - bz) * sz; + bestDist = dx*dx + dy*dy + dz*dz; + } + + // Search left + for (int lx = x - 1; lx >= 0; lx--) { + float horizDist = (x - lx) * sx; + if (horizDist * horizDist >= bestDist) break; + + int leftSite = input[z * sizeY * sizeX + y * sizeX + lx]; + if (leftSite != MARKER3D) { + int lsx = DECODE3D_X(leftSite), lsy = DECODE3D_Y(leftSite), lsz = DECODE3D_Z(leftSite); + float dx = (x - lsx) * sx, dy = (y - lsy) * sy, dz = (z - lsz) * sz; + float dist = dx*dx + dy*dy + dz*dz; + if (dist < bestDist) { + bestDist = dist; + } + } + } + + // Search right + for (int rx = x + 1; rx < sizeX; rx++) { + float horizDist = (rx - x) * sx; + if (horizDist * horizDist >= bestDist) break; + + int rightSite = input[z * sizeY * sizeX + y * sizeX + rx]; + if (rightSite != MARKER3D) { + int rsx = DECODE3D_X(rightSite), rsy = DECODE3D_Y(rightSite), rsz = DECODE3D_Z(rightSite); + float dx = (x - rsx) * sx, dy = (y - rsy) * sy, dz = (z - rsz) * sz; + float dist = dx*dx + dy*dy + dz*dz; + if (dist < bestDist) { + bestDist = dist; + } + } + } + + distance[id] = (bestDist < FLT_MAX) ? sqrtf(bestDist) : FLT_MAX; +} + +// Initialize 3D sites from mask +__global__ void kernelInitSites3D(const float* mask, int* sites, int sizeX, int sizeY, int sizeZ) +{ + int x = blockIdx.x * blockDim.x + threadIdx.x; + int y = blockIdx.y * blockDim.y + threadIdx.y; + int z = blockIdx.z; + + if (x >= sizeX || y >= sizeY || z >= sizeZ) return; + + int id = z * sizeY * sizeX + y * sizeX + x; + + if (mask[id] < 0.5f) { + sites[id] = ENCODE3D(x, y, z); + } else { + sites[id] = MARKER3D; + } +} + +// ============================================================================ +// Main 3D PBA+ EDT function +// ============================================================================ +torch::Tensor exact_euclidean3d_cuda( + const torch::Tensor &mask, + const std::vector &spacing +) { + int device = mask.get_device(); + c10::cuda::CUDAGuard device_guard(device); + + // Get dimensions (assume BCDHW format) + const int batch = mask.size(0); + const int channels = mask.size(1); + const int depth = mask.size(2); + const int height = mask.size(3); + const int width = mask.size(4); + + // spacing is [depth_spacing, height_spacing, width_spacing] to match [D, H, W] tensor convention + float sz = spacing.size() > 0 ? spacing[0] : 1.0f; // depth (z) spacing + float sy = spacing.size() > 1 ? spacing[1] : 1.0f; // height (y) spacing + float sx = spacing.size() > 2 ? spacing[2] : 1.0f; // width (x) spacing + + // Early return if all pixels are seeds (mask=0 means seed, so sum=0 means all seeds) + if (mask.sum().item() < 0.5f) { + return torch::zeros({batch, channels, depth, height, width}, + torch::TensorOptions().dtype(torch::kFloat32).device(mask.device())); + } + + // Allocate output + auto options_float = torch::TensorOptions().dtype(torch::kFloat32).device(mask.device()); + auto options_int = torch::TensorOptions().dtype(torch::kInt32).device(mask.device()); + + torch::Tensor distance = torch::zeros({batch, channels, depth, height, width}, options_float); + + // Allocate intermediate buffers + torch::Tensor buffer1 = torch::full({depth, height, width}, MARKER3D, options_int); + torch::Tensor buffer2 = torch::full({depth, height, width}, MARKER3D, options_int); + + // Process each batch and channel + for (int b = 0; b < batch; b++) { + for (int c = 0; c < channels; c++) { + const float* mask_ptr = mask.data_ptr() + (b * channels + c) * depth * height * width; + float* dist_ptr = distance.data_ptr() + (b * channels + c) * depth * height * width; + int* buf1_ptr = buffer1.data_ptr(); + int* buf2_ptr = buffer2.data_ptr(); + + // Initialize sites + { + dim3 block(16, 16, 1); + dim3 grid((width + 15) / 16, (height + 15) / 16, depth); + kernelInitSites3D<<>>(mask_ptr, buf1_ptr, width, height, depth); + } + + // Phase 1: Z-axis flooding + { + dim3 block(16, 16); + dim3 grid((width + 15) / 16, (height + 15) / 16); + kernelFloodZ3D<<>>(buf1_ptr, buf2_ptr, width, height, depth); + } + + // Phase 2: Y-axis processing + { + dim3 block(256); + dim3 grid((width + 255) / 256, 1, depth); + kernelMaurerY3D<<>>(buf2_ptr, buf1_ptr, width, height, depth, sx, sy, sz); + } + + // Phase 3: X-axis processing and distance computation + { + dim3 block(16, 16); + dim3 grid((width + 15) / 16, (height + 15) / 16, depth); + kernelColorX3D<<>>(buf1_ptr, dist_ptr, width, height, depth, sx, sy, sz); + } + } + } + + return distance; +} + +// ============================================================================ +// Signed EDT functions +// ============================================================================ +torch::Tensor signed_exact_euclidean2d_cuda( + const torch::Tensor &mask, + const std::vector &spacing +) { + // Distance from foreground (mask > 0.5) + torch::Tensor D_fg = exact_euclidean2d_cuda(mask, spacing); + + // Distance from background - invert mask + torch::Tensor inv_mask = 1.0f - mask; + torch::Tensor D_bg = exact_euclidean2d_cuda(inv_mask, spacing); + + // Signed distance: negative inside (mask=1), positive outside (mask=0) + return D_bg - D_fg; +} + +torch::Tensor signed_exact_euclidean3d_cuda( + const torch::Tensor &mask, + const std::vector &spacing +) { + // Distance from foreground + torch::Tensor D_fg = exact_euclidean3d_cuda(mask, spacing); + + // Distance from background - invert mask + torch::Tensor inv_mask = 1.0f - mask; + torch::Tensor D_bg = exact_euclidean3d_cuda(inv_mask, spacing); + + // Signed distance: negative inside (mask=1), positive outside (mask=0) + return D_bg - D_fg; +} diff --git a/FastGeodis/geodis_pba.h b/FastGeodis/geodis_pba.h new file mode 100644 index 0000000..1c8422b --- /dev/null +++ b/FastGeodis/geodis_pba.h @@ -0,0 +1,58 @@ +// BSD 3-Clause License + +// Copyright (c) 2021, Muhammad Asad (masadcv@gmail.com) +// All rights reserved. + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. + +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include + +#ifdef WITH_CUDA + +// 2D Exact Euclidean Distance Transform using PBA+ algorithm +torch::Tensor exact_euclidean2d_cuda( + const torch::Tensor &mask, + const std::vector &spacing); + +// 3D Exact Euclidean Distance Transform using PBA+ algorithm +torch::Tensor exact_euclidean3d_cuda( + const torch::Tensor &mask, + const std::vector &spacing); + +// Signed 2D Exact Euclidean Distance Transform +torch::Tensor signed_exact_euclidean2d_cuda( + const torch::Tensor &mask, + const std::vector &spacing); + +// Signed 3D Exact Euclidean Distance Transform +torch::Tensor signed_exact_euclidean3d_cuda( + const torch::Tensor &mask, + const std::vector &spacing); + +#endif diff --git a/README.md b/README.md index 66d700a..ff12220 100755 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ The above raster scan method can be parallelised for each row/plane on an availa In addition, implementation of generalised version of Geodesic distance transforms along with Geodesic Symmetric Filtering (GSF) is provided for use in interactive segmentation methods, that were originally proposed in [1, 2, 5]. -> The raster scan based implementation provides a balance towards speed rather than accuracy of Geodesic distance transform and hence results in efficient hardware utilisation. On the other hand, in case of Euclidean distance transform, exact results can be achieved with other packages (albeit not on necessarilly on GPU) [6, 7, 8] +> The raster scan based implementation provides a balance towards speed rather than accuracy of Geodesic distance transform and hence results in efficient hardware utilisation. For exact Euclidean distance transform on GPU, FastGeodis now includes the PBA+ algorithm [12], which provides exact results matching scipy [7]. For CPU-only exact EDT, see also [6, 7, 8]. # Citation If you use this code in your research, then please consider citing: @@ -101,6 +101,15 @@ or (on conda environments with existing installation of PyTorch with CUDA) | Fast Marching Geodesic Symmetric Filtering 2D | Fast Marching geodesic symmetric filtering for CPU [2, 9] | [FastGeodis.GSF2d_fastmarch](https://fastgeodis.readthedocs.io/en/latest/api_docs.html#FastGeodis.GSF2d_fastmarch) | | Fast Marching Geodesic Symmetric Filtering 3D | Fast Marching geodesic symmetric filtering for CPU [2, 9] | [FastGeodis.GSF3d_fastmarch](https://fastgeodis.readthedocs.io/en/latest/api_docs.html#FastGeodis.GSF3d_fastmarch) | +## Exact Euclidean Distance Transform for GPU based on [12] + +| Method | Description | Documentation | +|--------|-------------|---------------| +| Exact Euclidean Distance 2D | PBA+ exact Euclidean distance transform for GPU [12] | [FastGeodis.exact_euclidean2d](https://fastgeodis.readthedocs.io/en/latest/api_docs.html#FastGeodis.exact_euclidean2d) | +| Exact Euclidean Distance 3D | PBA+ exact Euclidean distance transform for GPU [12] | [FastGeodis.exact_euclidean3d](https://fastgeodis.readthedocs.io/en/latest/api_docs.html#FastGeodis.exact_euclidean3d) | +| Signed Exact Euclidean Distance 2D | PBA+ signed exact Euclidean distance transform for GPU [12] | [FastGeodis.signed_exact_euclidean2d](https://fastgeodis.readthedocs.io/en/latest/api_docs.html#FastGeodis.signed_exact_euclidean2d) | +| Signed Exact Euclidean Distance 3D | PBA+ signed exact Euclidean distance transform for GPU [12] | [FastGeodis.signed_exact_euclidean3d](https://fastgeodis.readthedocs.io/en/latest/api_docs.html#FastGeodis.signed_exact_euclidean3d) | + # Example usage ### Fast Geodesic Distance Transform @@ -219,3 +228,5 @@ FastGeodis (CPU/GPU) is compared with existing GeodisTK ([https://github.com/tai - [10] Sethian, James A. "Fast marching methods." SIAM review 41.2 (1999): 199-235. - [11] Ikonen, L., & Toivanen, P. (2007). Distance and nearest neighbor transforms on gray-level surfaces. Pattern Recognition Letters, 28(5), 604-612. [[doi](https://doi.org/10.1016/j.patrec.2006.10.010)] + +- [12] Cao, Thanh-Tung, Ke Tang, Anis Mohamed, and Tiow-Seng Tan. "Parallel banding algorithm to compute exact distance transform with the GPU." In Proceedings of the 2010 ACM SIGGRAPH symposium on Interactive 3D Graphics and Games, pp. 83-90. 2010. [[doi](https://doi.org/10.1145/1730804.1730818)] diff --git a/figures/experiment_2d_pba.json b/figures/experiment_2d_pba.json new file mode 100644 index 0000000..5224c62 --- /dev/null +++ b/figures/experiment_2d_pba.json @@ -0,0 +1,50 @@ +{ + "scipy_edt_2d": [ + 0.00011696815490722657, + 0.0002858161926269531, + 0.0019108772277832032, + 0.010666751861572265, + 0.036969566345214845, + 0.1944645404815674 + ], + "generalised_geodesic2d_raster_gpu": [ + 0.014127683639526368, + 0.0009024143218994141, + 0.014866113662719727, + 0.01650266647338867, + 0.028803634643554687, + 0.05953817367553711 + ], + "exact_euclidean2d_pba_gpu": [ + 0.006118297576904297, + 0.004769802093505859, + 0.0014224052429199219, + 0.00026378631591796877, + 0.0009524822235107422, + 0.00973048210144043 + ], + "spatial_dim": [ + 64, + 128, + 256, + 512, + 1024, + 2048 + ], + "raster_error": [ + 2.8720017066366665, + 5.748403090096247, + 11.49685195655968, + 22.993745347295487, + 45.986171455232125, + 91.96245521515175 + ], + "pba_error": [ + 1.8856557275626074e-06, + 3.7977307556502637e-06, + 7.6272718843029e-06, + 1.525863092410873e-05, + 3.0517357686221658e-05, + 6.1034715372443316e-05 + ] +} \ No newline at end of file diff --git a/figures/experiment_2d_pba.png b/figures/experiment_2d_pba.png new file mode 100644 index 0000000..1c316d6 Binary files /dev/null and b/figures/experiment_2d_pba.png differ diff --git a/figures/experiment_2d_pba_error.png b/figures/experiment_2d_pba_error.png new file mode 100644 index 0000000..b91e00c Binary files /dev/null and b/figures/experiment_2d_pba_error.png differ diff --git a/figures/experiment_3d_pba.json b/figures/experiment_3d_pba.json new file mode 100644 index 0000000..2299e94 --- /dev/null +++ b/figures/experiment_3d_pba.json @@ -0,0 +1,32 @@ +{ + "scipy_edt_3d": [ + 0.0012058417002360027, + 0.008033990859985352, + 0.14787872632344565 + ], + "generalised_geodesic3d_raster_gpu": [ + 0.015847047170003254, + 0.0033197402954101562, + 0.05194632212320963 + ], + "exact_euclidean3d_pba_gpu": [ + 0.004782517751057942, + 0.005153099695841472, + 0.010609149932861328 + ], + "spatial_dim": [ + 32, + 64, + 128 + ], + "raster_error": [ + 2.3527656564190202, + 4.705527498140775, + 9.415359322211515 + ], + "pba_error": [ + 9.437702850334517e-07, + 1.900751037453574e-06, + 3.8146837226804564e-06 + ] +} \ No newline at end of file diff --git a/figures/experiment_3d_pba.png b/figures/experiment_3d_pba.png new file mode 100644 index 0000000..cac98ad Binary files /dev/null and b/figures/experiment_3d_pba.png differ diff --git a/figures/experiment_3d_pba_error.png b/figures/experiment_3d_pba_error.png new file mode 100644 index 0000000..8757ecc Binary files /dev/null and b/figures/experiment_3d_pba_error.png differ diff --git a/samples/test_speed_benchmark_pba.py b/samples/test_speed_benchmark_pba.py new file mode 100644 index 0000000..3503aff --- /dev/null +++ b/samples/test_speed_benchmark_pba.py @@ -0,0 +1,319 @@ +""" +Benchmark script for PBA+ (Parallel Banding Algorithm Plus) Exact EDT. + +Compares: +- scipy.ndimage.distance_transform_edt (CPU reference) +- FastGeodis raster scanning GPU (approximate) +- FastGeodis PBA+ GPU (exact) + +Also computes accuracy errors vs scipy reference. +""" + +import json +import os +import time +from functools import wraps + +import FastGeodis +import matplotlib.pyplot as plt +import numpy as np +import torch + +try: + from scipy import ndimage + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + print("scipy not available, some benchmarks will be skipped") + + +def timing(f): + @wraps(f) + def wrap(*args, **kw): + ts = time.time() + result = f(*args, **kw) + te = time.time() + print("func:%r took: %2.4f sec" % (f.__name__, te - ts)) + return result + return wrap + + +@timing +def scipy_edt_2d(mask_np): + """Scipy reference EDT (CPU)""" + return ndimage.distance_transform_edt(mask_np) + + +@timing +def scipy_edt_3d(mask_np): + """Scipy reference EDT (CPU)""" + return ndimage.distance_transform_edt(mask_np) + + +@timing +def generalised_geodesic2d_raster_gpu(I, S, v, lamb, iter): + """FastGeodis raster scanning (GPU, approximate)""" + result = FastGeodis.generalised_geodesic2d(I, S, v, lamb, iter) + torch.cuda.synchronize() + return result + + +@timing +def generalised_geodesic3d_raster_gpu(I, S, spacing, v, lamb, iter): + """FastGeodis raster scanning 3D (GPU, approximate)""" + result = FastGeodis.generalised_geodesic3d(I, S, spacing, v, lamb, iter) + torch.cuda.synchronize() + return result + + +@timing +def exact_euclidean2d_pba_gpu(mask, spacing): + """FastGeodis PBA+ exact EDT (GPU)""" + result = FastGeodis.exact_euclidean2d(mask, spacing) + torch.cuda.synchronize() + return result + + +@timing +def exact_euclidean3d_pba_gpu(mask, spacing): + """FastGeodis PBA+ exact EDT 3D (GPU)""" + result = FastGeodis.exact_euclidean3d(mask, spacing) + torch.cuda.synchronize() + return result + + +def test2d(): + """Benchmark 2D EDT methods""" + num_runs = 5 + + sizes_to_test = [64, 128, 256, 512, 1024, 2048] + print(sizes_to_test) + + time_taken_dict = {} + error_dict = {'raster_error': [], 'pba_error': []} + + # Test scipy + if SCIPY_AVAILABLE: + time_taken_dict['scipy_edt_2d'] = [] + for size in sizes_to_test: + mask_np = np.ones((size, size), dtype=np.float32) + mask_np[size // 2, size // 2] = 0 # Single seed point + + tic = time.time() + for _ in range(num_runs): + scipy_edt_2d(mask_np) + time_taken_dict['scipy_edt_2d'].append((time.time() - tic) / num_runs) + print() + + # Test raster GPU + time_taken_dict['generalised_geodesic2d_raster_gpu'] = [] + for size in sizes_to_test: + image = torch.ones((1, 1, size, size), device='cuda') + seed = torch.ones((1, 1, size, size), device='cuda') + seed[:, :, size // 2, size // 2] = 0.0 + + tic = time.time() + for _ in range(num_runs): + generalised_geodesic2d_raster_gpu(image, seed, 1e10, 0.0, 4) + time_taken_dict['generalised_geodesic2d_raster_gpu'].append((time.time() - tic) / num_runs) + print() + + # Test PBA+ GPU + time_taken_dict['exact_euclidean2d_pba_gpu'] = [] + for size in sizes_to_test: + mask = torch.ones((1, 1, size, size), device='cuda') + mask[:, :, size // 2, size // 2] = 0.0 + + tic = time.time() + for _ in range(num_runs): + exact_euclidean2d_pba_gpu(mask, [1.0, 1.0]) + time_taken_dict['exact_euclidean2d_pba_gpu'].append((time.time() - tic) / num_runs) + print() + + # Compute errors vs scipy + if SCIPY_AVAILABLE: + print("Computing errors vs scipy reference...") + for size in sizes_to_test: + mask_np = np.ones((size, size), dtype=np.float32) + mask_np[size // 2, size // 2] = 0 + + # Scipy reference + scipy_dist = ndimage.distance_transform_edt(mask_np) + + # Raster GPU + image = torch.ones((1, 1, size, size), device='cuda') + seed = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to('cuda') + raster_dist = FastGeodis.generalised_geodesic2d(image, seed, 1e10, 0.0, 4) + raster_dist_np = raster_dist[0, 0].cpu().numpy() + + # PBA+ GPU + mask_torch = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to('cuda') + pba_dist = FastGeodis.exact_euclidean2d(mask_torch, [1.0, 1.0]) + pba_dist_np = pba_dist[0, 0].cpu().numpy() + + raster_err = np.abs(scipy_dist - raster_dist_np).max() + pba_err = np.abs(scipy_dist - pba_dist_np).max() + + error_dict['raster_error'].append(float(raster_err)) + error_dict['pba_error'].append(float(pba_err)) + print(f" Size {size}x{size}: Raster err={raster_err:.2e}, PBA+ err={pba_err:.2e}") + + time_taken_dict['spatial_dim'] = sizes_to_test + time_taken_dict.update(error_dict) + + return sizes_to_test, time_taken_dict + + +def test3d(): + """Benchmark 3D EDT methods""" + num_runs = 3 + spacing = [1.0, 1.0, 1.0] + + sizes_to_test = [32, 64, 128] + print(sizes_to_test) + + time_taken_dict = {} + error_dict = {'raster_error': [], 'pba_error': []} + + # Test scipy + if SCIPY_AVAILABLE: + time_taken_dict['scipy_edt_3d'] = [] + for size in sizes_to_test: + mask_np = np.ones((size, size, size), dtype=np.float32) + mask_np[size // 2, size // 2, size // 2] = 0 + + tic = time.time() + for _ in range(num_runs): + scipy_edt_3d(mask_np) + time_taken_dict['scipy_edt_3d'].append((time.time() - tic) / num_runs) + print() + + # Test raster GPU + time_taken_dict['generalised_geodesic3d_raster_gpu'] = [] + for size in sizes_to_test: + image = torch.ones((1, 1, size, size, size), device='cuda') + seed = torch.ones((1, 1, size, size, size), device='cuda') + seed[:, :, size // 2, size // 2, size // 2] = 0.0 + + tic = time.time() + for _ in range(num_runs): + generalised_geodesic3d_raster_gpu(image, seed, spacing, 1e10, 0.0, 4) + time_taken_dict['generalised_geodesic3d_raster_gpu'].append((time.time() - tic) / num_runs) + print() + + # Test PBA+ GPU + time_taken_dict['exact_euclidean3d_pba_gpu'] = [] + for size in sizes_to_test: + mask = torch.ones((1, 1, size, size, size), device='cuda') + mask[:, :, size // 2, size // 2, size // 2] = 0.0 + + tic = time.time() + for _ in range(num_runs): + exact_euclidean3d_pba_gpu(mask, spacing) + time_taken_dict['exact_euclidean3d_pba_gpu'].append((time.time() - tic) / num_runs) + print() + + # Compute errors vs scipy + if SCIPY_AVAILABLE: + print("Computing errors vs scipy reference...") + for size in sizes_to_test: + mask_np = np.ones((size, size, size), dtype=np.float32) + mask_np[size // 2, size // 2, size // 2] = 0 + + # Scipy reference + scipy_dist = ndimage.distance_transform_edt(mask_np) + + # Raster GPU + image = torch.ones((1, 1, size, size, size), device='cuda') + seed = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to('cuda') + raster_dist = FastGeodis.generalised_geodesic3d(image, seed, spacing, 1e10, 0.0, 4) + raster_dist_np = raster_dist[0, 0].cpu().numpy() + + # PBA+ GPU + mask_torch = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to('cuda') + pba_dist = FastGeodis.exact_euclidean3d(mask_torch, spacing) + pba_dist_np = pba_dist[0, 0].cpu().numpy() + + raster_err = np.abs(scipy_dist - raster_dist_np).max() + pba_err = np.abs(scipy_dist - pba_dist_np).max() + + error_dict['raster_error'].append(float(raster_err)) + error_dict['pba_error'].append(float(pba_err)) + print(f" Size {size}x{size}x{size}: Raster err={raster_err:.2e}, PBA+ err={pba_err:.2e}") + + time_taken_dict['spatial_dim'] = sizes_to_test + time_taken_dict.update(error_dict) + + return sizes_to_test, time_taken_dict + + +def save_timing_plot(sizes, time_taken_dict, figname, dim='2D'): + """Save timing comparison plot""" + plt.figure(figsize=(10, 6)) + plt.grid(True, alpha=0.3) + + for key in time_taken_dict.keys(): + if key in ['spatial_dim', 'raster_error', 'pba_error']: + continue + if 'scipy' in key: + plt.plot(sizes, time_taken_dict[key], 'b-s', + label="scipy EDT (CPU)", linewidth=2, markersize=8) + elif 'raster' in key: + plt.plot(sizes, time_taken_dict[key], 'r-^', + label="FastGeodis Raster (GPU)", linewidth=2, markersize=8) + elif 'pba' in key or 'exact' in key: + plt.plot(sizes, time_taken_dict[key], 'g-o', + label="FastGeodis PBA+ (GPU)", linewidth=2, markersize=8) + + plt.legend(fontsize=12) + plt.xticks(sizes, [str(s) for s in sizes], rotation=45) + plt.title(f"{dim} EDT Performance Comparison", fontsize=14) + plt.xlabel("Spatial size", fontsize=12) + plt.ylabel("Execution time (seconds)", fontsize=12) + plt.yscale('log') + plt.tight_layout() + plt.savefig(os.path.join("figures", figname + ".png"), dpi=150) + plt.close() + + # Save JSON + with open(os.path.join('figures', figname + '.json'), 'w') as fp: + json.dump(time_taken_dict, fp, indent=4) + + +def save_error_plot(sizes, time_taken_dict, figname, dim='2D'): + """Save accuracy comparison plot""" + if 'raster_error' not in time_taken_dict or 'pba_error' not in time_taken_dict: + return + + plt.figure(figsize=(10, 6)) + plt.grid(True, alpha=0.3) + + plt.plot(sizes, time_taken_dict['raster_error'], 'r-^', + label='FastGeodis Raster (GPU)', linewidth=2, markersize=8) + plt.plot(sizes, time_taken_dict['pba_error'], 'g-o', + label='FastGeodis PBA+ (GPU)', linewidth=2, markersize=8) + + plt.legend(fontsize=12) + plt.xticks(sizes, [str(s) for s in sizes], rotation=45) + plt.title(f"{dim} EDT Max Error vs scipy Reference", fontsize=14) + plt.xlabel("Spatial size", fontsize=12) + plt.ylabel("Max absolute error", fontsize=12) + plt.yscale('log') + plt.tight_layout() + plt.savefig(os.path.join("figures", figname + "_error.png"), dpi=150) + plt.close() + + +if __name__ == "__main__": + # Ensure figures directory exists + os.makedirs("figures", exist_ok=True) + + # 2D benchmarks + sizes, ttdict = test2d() + save_timing_plot(sizes, ttdict, "experiment_2d_pba", "2D") + save_error_plot(sizes, ttdict, "experiment_2d_pba", "2D") + + # 3D benchmarks + sizes, ttdict = test3d() + save_timing_plot(sizes, ttdict, "experiment_3d_pba", "3D") + save_error_plot(sizes, ttdict, "experiment_3d_pba", "3D") diff --git a/setup.py b/setup.py index 5e0baa3..bc6ac7c 100755 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ from setuptools import find_packages, setup FORCE_CUDA = os.getenv("FORCE_CUDA", "0") == "1" +DISABLE_CUDA = os.getenv("DISABLE_CUDA", "0") == "1" BUILD_CPP = BUILD_CUDA = False TORCH_VERSION = 0 @@ -21,7 +22,12 @@ BUILD_CPP = True from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension - BUILD_CUDA = (CUDA_HOME is not None) if torch.cuda.is_available() else FORCE_CUDA + if DISABLE_CUDA: + BUILD_CUDA = False + else: + BUILD_CUDA = ( + (CUDA_HOME is not None) if torch.cuda.is_available() else FORCE_CUDA + ) _pt_version = pkg_resources.parse_version(torch.__version__)._version.release if _pt_version is None or len(_pt_version) < 3: @@ -97,10 +103,10 @@ def get_extensions(): extra_compile_args["cxx"] += omp_flags() if extension is None or not sources: return [] # compile nothing - + # compile release extra_compile_args["cxx"] += ["-g0"] - + ext_modules = [ extension( name="FastGeodisCpp", @@ -113,9 +119,10 @@ def get_extensions(): ] return ext_modules + def get_version(): # following guidance from: https://stackoverflow.com/a/7071358 - VERSIONFILE="FastGeodis/_version.py" + VERSIONFILE = "FastGeodis/_version.py" verstrline = open(VERSIONFILE, "rt").read() VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]" mo = re.search(VSRE, verstrline, re.M) @@ -125,6 +132,7 @@ def get_version(): raise RuntimeError("Unable to find version string in %s." % (VERSIONFILE,)) return verstr + # get current version version = get_version() print(f"FastGeodis building version: {version}") diff --git a/tests/test_pba.py b/tests/test_pba.py new file mode 100644 index 0000000..e990332 --- /dev/null +++ b/tests/test_pba.py @@ -0,0 +1,421 @@ +# BSD 3-Clause License + +# Copyright (c) 2021, Muhammad Asad (masadcv@gmail.com) +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Tests for PBA+ (Parallel Banding Algorithm Plus) Exact Euclidean Distance Transform. + +This module tests the exact EDT implementation against scipy's reference implementation. +""" + +import math +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from .utils import skip_if_no_cuda + +# set deterministic seed +torch.manual_seed(15) +np.random.seed(15) + +try: + import FastGeodis +except: + print( + "Unable to load FastGeodis for unittests\nMake sure to install using: python setup.py install" + ) + exit() + +try: + from scipy import ndimage + + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + print("scipy not available, some tests will be skipped") + + +# Test configurations for 2D +CONF_2D_CUDA = [("cuda", bas) for bas in [32, 64, 128, 256]] + +# Test configurations for 3D (smaller sizes due to memory) +CONF_3D_CUDA = [("cuda", bas) for bas in [16, 32, 64]] + + +class TestExactEuclidean2D(unittest.TestCase): + """Tests for 2D Exact Euclidean Distance Transform using PBA+""" + + @skip_if_no_cuda + @parameterized.expand(CONF_2D_CUDA) + def test_single_seed_point(self, device, base_dim): + """Test EDT with a single seed point at center""" + height, width = base_dim, base_dim + + # Create mask with single seed point + mask = torch.ones(1, 1, height, width, dtype=torch.float32, device=device) + mask[0, 0, height // 2, width // 2] = 0 # Seed point + + # Compute exact EDT + distance = FastGeodis.exact_euclidean2d(mask, spacing=[1.0, 1.0]) + + # Check output shape + self.assertEqual(distance.shape, mask.shape) + + # Check seed point has distance 0 + self.assertAlmostEqual( + distance[0, 0, height // 2, width // 2].item(), 0.0, places=5 + ) + + # Check corner has expected distance (Euclidean from center to corner) + expected_corner_dist = math.sqrt((height // 2) ** 2 + (width // 2) ** 2) + actual_corner_dist = distance[0, 0, 0, 0].item() + + # Allow small tolerance for floating point + self.assertAlmostEqual(actual_corner_dist, expected_corner_dist, places=3) + + @unittest.skipUnless(SCIPY_AVAILABLE, "scipy required for exact comparison") + @skip_if_no_cuda + @parameterized.expand(CONF_2D_CUDA) + def test_exact_match_scipy(self, device, base_dim): + """Test that PBA+ matches scipy's exact EDT""" + height, width = base_dim, base_dim + + # Create random binary mask with multiple seed points + np.random.seed(42) + mask_np = np.ones((height, width), dtype=np.float32) + + # Add some random seed points + num_seeds = max(1, base_dim // 8) + for _ in range(num_seeds): + y, x = np.random.randint(0, height), np.random.randint(0, width) + mask_np[y, x] = 0 + + # Scipy reference (exact) + scipy_dist = ndimage.distance_transform_edt(mask_np) + + # FastGeodis PBA+ (should be exact) + mask_torch = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to(device) + pba_dist = FastGeodis.exact_euclidean2d(mask_torch, spacing=[1.0, 1.0]) + pba_dist_np = pba_dist[0, 0].cpu().numpy() + + # Compare - should be exact (or very close due to floating point) + max_diff = np.abs(scipy_dist - pba_dist_np).max() + mean_diff = np.abs(scipy_dist - pba_dist_np).mean() + + # Allow small tolerance for floating point precision + self.assertLess( + max_diff, + 1e-3, + f"Max difference {max_diff} exceeds threshold for size {base_dim}", + ) + self.assertLess( + mean_diff, + 1e-4, + f"Mean difference {mean_diff} exceeds threshold for size {base_dim}", + ) + + @skip_if_no_cuda + def test_anisotropic_spacing(self): + """Test EDT with anisotropic pixel spacing""" + device = "cuda" + height, width = 64, 64 + spacing = [2.0, 1.0] # Pixels are 2x taller than wide + + mask = torch.ones(1, 1, height, width, dtype=torch.float32, device=device) + mask[0, 0, height // 2, width // 2] = 0 + + distance = FastGeodis.exact_euclidean2d(mask, spacing=spacing) + + # Distance to point 1 pixel right should be 1.0 + dist_right = distance[0, 0, height // 2, width // 2 + 1].item() + self.assertAlmostEqual(dist_right, 1.0, places=3) + + # Distance to point 1 pixel down should be 2.0 (due to spacing) + dist_down = distance[0, 0, height // 2 + 1, width // 2].item() + self.assertAlmostEqual(dist_down, 2.0, places=3) + + @skip_if_no_cuda + def test_all_seeds(self): + """Test when all pixels are seeds (distance should be 0)""" + device = "cuda" + height, width = 32, 32 + + mask = torch.zeros(1, 1, height, width, dtype=torch.float32, device=device) + distance = FastGeodis.exact_euclidean2d(mask, spacing=[1.0, 1.0]) + + # All distances should be 0 + self.assertAlmostEqual(distance.max().item(), 0.0, places=5) + + @skip_if_no_cuda + def test_ill_shape(self): + """Test that wrong input shapes raise errors""" + device = "cuda" + + # 3D input (should fail - expects 4D) + mask_3d = torch.ones(1, 32, 32, dtype=torch.float32, device=device) + with self.assertRaises(Exception): + FastGeodis.exact_euclidean2d(mask_3d, spacing=[1.0, 1.0]) + + # Wrong spacing dimension + mask = torch.ones(1, 1, 32, 32, dtype=torch.float32, device=device) + with self.assertRaises(Exception): + FastGeodis.exact_euclidean2d(mask, spacing=[1.0, 1.0, 1.0]) + + @unittest.skipUnless(SCIPY_AVAILABLE, "scipy required for exact comparison") + @skip_if_no_cuda + def test_non_aligned_sizes(self): + """Test EDT with non-32-aligned sizes (tests padding handling)""" + device = "cuda" + + # Test various non-aligned sizes that require padding + # Include sizes that would pad to non-power-of-2 (e.g., 65->96, 129->160) + # to ensure power-of-2 padding is working correctly + for size in [(100, 100), (50, 75), (33, 47), (65, 65), (32, 65), (32, 96), (17, 93), (32, 129)]: + height, width = size + + # Create mask with seed near corner to test padding boundary + mask_np = np.ones((height, width), dtype=np.float32) + mask_np[0, 0] = 0 # Seed at corner + + # Scipy reference + scipy_dist = ndimage.distance_transform_edt(mask_np) + + # PBA+ result + mask_torch = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to(device) + pba_dist = FastGeodis.exact_euclidean2d(mask_torch, spacing=[1.0, 1.0]) + pba_dist_np = pba_dist[0, 0].cpu().numpy() + + # Compare + max_diff = np.abs(scipy_dist - pba_dist_np).max() + self.assertLess( + max_diff, + 1e-3, + f"Non-aligned size {size} failed with max_diff={max_diff}", + ) + + +class TestExactEuclidean3D(unittest.TestCase): + """Tests for 3D Exact Euclidean Distance Transform using PBA+""" + + @skip_if_no_cuda + @parameterized.expand(CONF_3D_CUDA) + def test_single_seed_point(self, device, base_dim): + """Test 3D EDT with a single seed point at center""" + depth, height, width = base_dim, base_dim, base_dim + + mask = torch.ones( + 1, 1, depth, height, width, dtype=torch.float32, device=device + ) + mask[0, 0, depth // 2, height // 2, width // 2] = 0 + + distance = FastGeodis.exact_euclidean3d(mask, spacing=[1.0, 1.0, 1.0]) + + # Check output shape + self.assertEqual(distance.shape, mask.shape) + + # Check seed point has distance 0 + self.assertAlmostEqual( + distance[0, 0, depth // 2, height // 2, width // 2].item(), 0.0, places=5 + ) + + # Check corner distance + expected_corner_dist = math.sqrt( + (depth // 2) ** 2 + (height // 2) ** 2 + (width // 2) ** 2 + ) + actual_corner_dist = distance[0, 0, 0, 0, 0].item() + + self.assertAlmostEqual(actual_corner_dist, expected_corner_dist, places=2) + + @unittest.skipUnless(SCIPY_AVAILABLE, "scipy required for exact comparison") + @skip_if_no_cuda + def test_exact_match_scipy_3d(self): + """Test that 3D PBA+ matches scipy's exact EDT""" + device = "cuda" + depth, height, width = 32, 32, 32 + + # Create mask with some seed points + mask_np = np.ones((depth, height, width), dtype=np.float32) + mask_np[depth // 2, height // 2, width // 2] = 0 + mask_np[5, 5, 5] = 0 + mask_np[depth - 5, height - 5, width - 5] = 0 + + # Scipy reference + scipy_dist = ndimage.distance_transform_edt(mask_np) + + # FastGeodis PBA+ + mask_torch = torch.from_numpy(mask_np).unsqueeze(0).unsqueeze(0).to(device) + pba_dist = FastGeodis.exact_euclidean3d(mask_torch, spacing=[1.0, 1.0, 1.0]) + pba_dist_np = pba_dist[0, 0].cpu().numpy() + + # Compare + max_diff = np.abs(scipy_dist - pba_dist_np).max() + mean_diff = np.abs(scipy_dist - pba_dist_np).mean() + + self.assertLess(max_diff, 1.0, f"Max difference {max_diff} exceeds threshold") + self.assertLess( + mean_diff, 0.1, f"Mean difference {mean_diff} exceeds threshold" + ) + + @skip_if_no_cuda + def test_anisotropic_spacing_3d(self): + """Test 3D EDT with anisotropic voxel spacing""" + device = "cuda" + depth, height, width = 32, 32, 32 + spacing = [1.0, 2.0, 3.0] # z, y, x spacing + + mask = torch.ones( + 1, 1, depth, height, width, dtype=torch.float32, device=device + ) + mask[0, 0, depth // 2, height // 2, width // 2] = 0 + + distance = FastGeodis.exact_euclidean3d(mask, spacing=spacing) + + # Distance along each axis should reflect spacing + # +1 in x direction: distance = 3.0 + dist_x = distance[0, 0, depth // 2, height // 2, width // 2 + 1].item() + self.assertAlmostEqual(dist_x, 3.0, places=2) + + # +1 in y direction: distance = 2.0 + dist_y = distance[0, 0, depth // 2, height // 2 + 1, width // 2].item() + self.assertAlmostEqual(dist_y, 2.0, places=2) + + # +1 in z direction: distance = 1.0 + dist_z = distance[0, 0, depth // 2 + 1, height // 2, width // 2].item() + self.assertAlmostEqual(dist_z, 1.0, places=2) + + +class TestSignedExactEuclidean2D(unittest.TestCase): + """Tests for Signed 2D Exact EDT""" + + @skip_if_no_cuda + def test_signed_distance(self): + """Test signed distance has correct signs""" + device = "cuda" + height, width = 64, 64 + + # Create a square region of foreground + mask = torch.zeros(1, 1, height, width, dtype=torch.float32, device=device) + mask[0, 0, 20:44, 20:44] = 1 # Square from (20,20) to (43,43) + + signed_dist = FastGeodis.signed_exact_euclidean2d(mask, spacing=[1.0, 1.0]) + + # Inside the square (e.g., center) should be negative + center_dist = signed_dist[0, 0, 32, 32].item() + self.assertLess(center_dist, 0, "Center should have negative signed distance") + + # Outside the square (e.g., corner) should be positive + corner_dist = signed_dist[0, 0, 0, 0].item() + self.assertGreater( + corner_dist, 0, "Corner should have positive signed distance" + ) + + @skip_if_no_cuda + def test_signed_boundary(self): + """Test that boundary has distance close to 0""" + device = "cuda" + height, width = 64, 64 + + mask = torch.zeros(1, 1, height, width, dtype=torch.float32, device=device) + mask[0, 0, 20:44, 20:44] = 1 + + signed_dist = FastGeodis.signed_exact_euclidean2d(mask, spacing=[1.0, 1.0]) + + # On the boundary, signed distance should be close to 0 + boundary_dist = signed_dist[0, 0, 20, 32].item() # Edge of square + self.assertAlmostEqual(abs(boundary_dist), 0.5, places=0) + + +class TestSignedExactEuclidean3D(unittest.TestCase): + """Tests for Signed 3D Exact EDT""" + + @skip_if_no_cuda + def test_signed_distance_3d(self): + """Test 3D signed distance has correct signs""" + device = "cuda" + depth, height, width = 32, 32, 32 + + # Create a cube region of foreground + mask = torch.zeros( + 1, 1, depth, height, width, dtype=torch.float32, device=device + ) + mask[0, 0, 10:22, 10:22, 10:22] = 1 # Cube + + signed_dist = FastGeodis.signed_exact_euclidean3d(mask, spacing=[1.0, 1.0, 1.0]) + + # Inside the cube should be negative + center_dist = signed_dist[0, 0, 16, 16, 16].item() + self.assertLess(center_dist, 0, "Center should have negative signed distance") + + # Outside the cube should be positive + corner_dist = signed_dist[0, 0, 0, 0, 0].item() + self.assertGreater( + corner_dist, 0, "Corner should have positive signed distance" + ) + + +class TestCompareWithApproximate(unittest.TestCase): + """Compare PBA+ exact EDT with approximate raster scanning EDT""" + + @skip_if_no_cuda + def test_pba_more_accurate(self): + """Test that PBA+ is more accurate than raster scanning for diagonal distances""" + device = "cuda" + height, width = 128, 128 + + # Single seed at corner + mask = torch.ones(1, 1, height, width, dtype=torch.float32, device=device) + mask[0, 0, 0, 0] = 0 + + # PBA+ exact + pba_dist = FastGeodis.exact_euclidean2d(mask, spacing=[1.0, 1.0]) + + # Approximate (raster scanning with lamb=0) + image = torch.ones_like(mask) + approx_dist = FastGeodis.generalised_geodesic2d(image, mask, 1e10, 0.0, 4) + + # Expected distance to opposite corner + expected_dist = math.sqrt((height - 1) ** 2 + (width - 1) ** 2) + + pba_corner = pba_dist[0, 0, height - 1, width - 1].item() + approx_corner = approx_dist[0, 0, height - 1, width - 1].item() + + pba_error = abs(pba_corner - expected_dist) + approx_error = abs(approx_corner - expected_dist) + + # PBA should be more accurate (or at least as accurate) + self.assertLessEqual( + pba_error, 1.0, f"PBA+ error {pba_error} is too high for diagonal distance" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 4f301f9..7fa47f7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,7 +28,6 @@ CONF_ALL_CPU_FM = CONF_2D_CPU_FM + CONF_3D_CPU_FM - def skip_if_no_cuda(obj): return unittest.skipUnless(torch.cuda.is_available(), "Skipping CUDA-based tests")( obj @@ -92,9 +91,8 @@ def pixelqueue_signed_generalised_geodesic_distance_2d(image, softmask, lamb, it def pixelqueue_signed_generalised_geodesic_distance_3d( image, softmask, lamb, iter, spacing ): - return FastGeodis.signed_geodesic3d_pixelqueue( - image, softmask, spacing, lamb - ) + return FastGeodis.signed_geodesic3d_pixelqueue(image, softmask, spacing, lamb) + def fastmarch_signed_generalised_geodesic_distance_2d(image, softmask, lamb, iter): return FastGeodis.signed_geodesic2d_fastmarch(image, softmask, lamb) @@ -103,9 +101,8 @@ def fastmarch_signed_generalised_geodesic_distance_2d(image, softmask, lamb, ite def fastmarch_signed_generalised_geodesic_distance_3d( image, softmask, lamb, iter, spacing ): - return FastGeodis.signed_geodesic3d_fastmarch( - image, softmask, spacing, lamb - ) + return FastGeodis.signed_geodesic3d_fastmarch(image, softmask, spacing, lamb) + def toivanen_generalised_geodesic_distance_2d(image, softmask, v, lamb, iter): return FastGeodis.generalised_geodesic2d_toivanen(image, softmask, v, lamb, iter) @@ -122,18 +119,15 @@ def pixelqueue_geodesic_distance_2d(image, softmask, lamb, iter): def pixelqueue_geodesic_distance_3d(image, softmask, lamb, iter, spacing): - return FastGeodis.geodesic3d_pixelqueue( - image, softmask, spacing, lamb - ) + return FastGeodis.geodesic3d_pixelqueue(image, softmask, spacing, lamb) + def fastmarch_geodesic_distance_2d(image, softmask, lamb, iter): return FastGeodis.geodesic2d_fastmarch(image, softmask, lamb) def fastmarch_geodesic_distance_3d(image, softmask, lamb, iter, spacing): - return FastGeodis.geodesic3d_fastmarch( - image, softmask, spacing, lamb - ) + return FastGeodis.geodesic3d_fastmarch(image, softmask, spacing, lamb) def fastgeodis_GSF_2d(image, softmask, theta, v, lamb, iter): @@ -159,6 +153,7 @@ def pixelqueue_GSF_2d(image, softmask, theta, lamb, iter): def pixelqueue_GSF_3d(image, softmask, theta, lamb, iter, spacing): return FastGeodis.GSF3d_pixelqueue(image, softmask, theta, spacing, lamb) + def fastmarch_GSF_2d(image, softmask, theta, lamb, iter): return FastGeodis.GSF2d_fastmarch(image, softmask, theta, lamb) @@ -232,6 +227,7 @@ def get_signed_pixelqueue_func(num_dims, spacing=[1.0, 1.0, 1.0]): else: raise ValueError("Unsupported num_dims received: {}".format(num_dims)) + def get_fastmarch_func(num_dims, spacing=[1.0, 1.0, 1.0]): if num_dims == 2: return fastmarch_geodesic_distance_2d @@ -279,7 +275,6 @@ def get_GSF_pixelqueue_func(num_dims, spacing=[1.0, 1.0, 1.0]): raise ValueError("Unsupported num_dims received: {}".format(num_dims)) - def get_GSF_fastmarch_func(num_dims, spacing=[1.0, 1.0, 1.0]): if num_dims == 2: return fastmarch_GSF_2d @@ -288,3 +283,37 @@ def get_GSF_fastmarch_func(num_dims, spacing=[1.0, 1.0, 1.0]): else: raise ValueError("Unsupported num_dims received: {}".format(num_dims)) + +# PBA+ Exact Euclidean Distance Transform functions +def exact_euclidean_2d(mask, spacing): + return FastGeodis.exact_euclidean2d(mask, spacing) + + +def exact_euclidean_3d(mask, spacing): + return FastGeodis.exact_euclidean3d(mask, spacing) + + +def signed_exact_euclidean_2d(mask, spacing): + return FastGeodis.signed_exact_euclidean2d(mask, spacing) + + +def signed_exact_euclidean_3d(mask, spacing): + return FastGeodis.signed_exact_euclidean3d(mask, spacing) + + +def get_exact_euclidean_func(num_dims, spacing=[1.0, 1.0, 1.0]): + if num_dims == 2: + return partial(exact_euclidean_2d, spacing=spacing[:2]) + elif num_dims == 3: + return partial(exact_euclidean_3d, spacing=spacing) + else: + raise ValueError("Unsupported num_dims received: {}".format(num_dims)) + + +def get_signed_exact_euclidean_func(num_dims, spacing=[1.0, 1.0, 1.0]): + if num_dims == 2: + return partial(signed_exact_euclidean_2d, spacing=spacing[:2]) + elif num_dims == 3: + return partial(signed_exact_euclidean_3d, spacing=spacing) + else: + raise ValueError("Unsupported num_dims received: {}".format(num_dims))