diff --git a/benchmark/distance_transform.py b/benchmark/distance_transform.py deleted file mode 100644 index 3737ced..0000000 --- a/benchmark/distance_transform.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import torch.utils.benchmark as benchmark -from prettytable import PrettyTable - -sizes = [64, 128, 256, 512, 1024] -batches = [1, 4, 8, 16] -dtype = torch.float32 -device = "cuda" -MIN_RUN = 1.0 # seconds per measurement - -torch.set_num_threads(torch.get_num_threads()) - -for B in batches: - table = PrettyTable() - table.field_names = [ - "Size", - "SciPy (ms/img)", - "Torch 1× (ms/img)", - "Torch batch (ms/img)", - "Speedup 1×", - "Speedup batch", - ] - for c in table.field_names: - table.align[c] = "r" - - for s in sizes: - # Inputs - x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) - x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] - x_imgs = [x[i : i + 1] for i in range(B)] - - # SciPy (CPU, one-by-one) - stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]" - t_scipy = benchmark.Timer( - stmt=stmt_scipy, - setup="from __main__ import x_np_list, ndi", - num_threads=torch.get_num_threads(), - ).blocked_autorange(min_run_time=MIN_RUN) - scipy_per_img_ms = (t_scipy.median * 1e3) / B - - # Torch (CUDA, one-by-one) - stmt_torch1 = """ -for xi in x_imgs: - tm.distance_transform(xi) -""" - t_torch1 = benchmark.Timer( - stmt=stmt_torch1, - setup="from __main__ import x_imgs, tm", - num_threads=torch.get_num_threads(), - ).blocked_autorange(min_run_time=MIN_RUN) - torch1_per_img_ms = (t_torch1.median * 1e3) / B - - # Torch (CUDA, batched) - t_batch = benchmark.Timer( - stmt="tm.distance_transform(x)", - setup="from __main__ import x, tm", - num_threads=torch.get_num_threads(), - ).blocked_autorange(min_run_time=MIN_RUN) - torchB_per_img_ms = (t_batch.median * 1e3) / B - - # Speedups - speed1 = scipy_per_img_ms / torch1_per_img_ms - speedB = scipy_per_img_ms / torchB_per_img_ms - - table.add_row( - [ - s, - 
f"{scipy_per_img_ms:.3f}", - f"{torch1_per_img_ms:.3f}", - f"{torchB_per_img_ms:.3f}", - f"{speed1:.1f}×", - f"{speedB:.1f}×", - ] - ) - - print(f"\n=== Batch Size: {B} ===") - print(table) diff --git a/benchmark/distance_transform_cdt.py b/benchmark/distance_transform_cdt.py new file mode 100644 index 0000000..4e56135 --- /dev/null +++ b/benchmark/distance_transform_cdt.py @@ -0,0 +1,89 @@ +import scipy.ndimage as ndi # noqa: F401 +import torch +import torch.utils.benchmark as benchmark +from prettytable import PrettyTable + +import torchmorph as tm # noqa: F401 + +sizes = [64, 128, 256, 512, 1024] +batches = [1, 4, 8, 16] +dtype = torch.float32 +device = "cuda" +MIN_RUN = 1.0 # seconds per measurement + +torch.set_num_threads(torch.get_num_threads()) + +for metric in ["chessboard", "taxicab"]: + print(f"\n{'='*60}") + print(f" CDT Benchmark - Metric: {metric}") + print(f"{'='*60}") + + for B in batches: + table = PrettyTable() + table.field_names = [ + "Size", + "SciPy (ms/img)", + "Torch 1× (ms/img)", + "Torch batch (ms/img)", + "Speedup 1×", + "Speedup batch", + ] + for c in table.field_names: + table.align[c] = "r" + + for s in sizes: + # Inputs: (B, C, H, W) format - C=1 for single channel + x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) + # For scipy, we need (H, W) arrays + x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] + # For torch single image processing: each is (1, 1, H, W) + x_imgs = [x[i : i + 1] for i in range(B)] + + # SciPy (CPU, one-by-one) + stmt_scipy = ( + f"out = [ndi.distance_transform_cdt(arr, metric='{metric}') for arr in x_np_list]" + ) + t_scipy = benchmark.Timer( + stmt=stmt_scipy, + setup="from __main__ import x_np_list, ndi", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + scipy_per_img_ms = (t_scipy.median * 1e3) / B + + # Torch (CUDA, one-by-one) + stmt_torch1 = f""" +for xi in x_imgs: + tm.distance_transform_cdt(xi, metric='{metric}') +""" + t_torch1 = 
benchmark.Timer( + stmt=stmt_torch1, + setup="from __main__ import x_imgs, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + torch1_per_img_ms = (t_torch1.median * 1e3) / B + + # Torch (CUDA, batched) + t_batch = benchmark.Timer( + stmt=f"tm.distance_transform_cdt(x, metric='{metric}')", + setup="from __main__ import x, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + torchB_per_img_ms = (t_batch.median * 1e3) / B + + # Speedups + speed1 = scipy_per_img_ms / torch1_per_img_ms + speedB = scipy_per_img_ms / torchB_per_img_ms + + table.add_row( + [ + s, + f"{scipy_per_img_ms:.3f}", + f"{torch1_per_img_ms:.3f}", + f"{torchB_per_img_ms:.3f}", + f"{speed1:.1f}×", + f"{speedB:.1f}×", + ] + ) + + print(f"\n=== Metric: {metric}, Batch Size: {B} ===") + print(table) diff --git a/benchmark/distance_transform_edt.py b/benchmark/distance_transform_edt.py new file mode 100644 index 0000000..f1f7318 --- /dev/null +++ b/benchmark/distance_transform_edt.py @@ -0,0 +1,106 @@ +import scipy.ndimage as ndi # noqa: F401 +import torch +import torch.utils.benchmark as benchmark +from prettytable import PrettyTable + +import torchmorph as tm # noqa: F401 + +sizes = [64, 128, 256, 512, 1024] +batches = [1, 4, 8, 16] +dtype = torch.float32 +device = "cuda" +MIN_RUN = 1.0 # seconds per measurement + +torch.set_num_threads(torch.get_num_threads()) + +for B in batches: + table = PrettyTable() + table.field_names = [ + "Size", + "SciPy (ms/img)", + "Exact 1× (ms/img)", + "Exact batch (ms/img)", + "JFA 1× (ms/img)", + "JFA batch (ms/img)", + "Speedup Exact", + "Speedup JFA", + ] + for c in table.field_names: + table.align[c] = "r" + + for s in sizes: + # Inputs: (B, C, H, W) format - C=1 for single channel + x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) + # For scipy, we need (H, W) arrays + x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] + # For torch single image processing: each is (1, 
1, H, W) + x_imgs = [x[i : i + 1] for i in range(B)] + + # SciPy (CPU, one-by-one) + stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]" + t_scipy = benchmark.Timer( + stmt=stmt_scipy, + setup="from __main__ import x_np_list, ndi", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + scipy_per_img_ms = (t_scipy.median * 1e3) / B + + # Torch Exact (CUDA, one-by-one) + stmt_exact1 = """ +for xi in x_imgs: + tm.distance_transform_edt(xi, algorithm="exact") +""" + t_exact1 = benchmark.Timer( + stmt=stmt_exact1, + setup="from __main__ import x_imgs, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + exact1_per_img_ms = (t_exact1.median * 1e3) / B + + # Torch Exact (CUDA, batched) + t_exact_batch = benchmark.Timer( + stmt='tm.distance_transform_edt(x, algorithm="exact")', + setup="from __main__ import x, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + exactB_per_img_ms = (t_exact_batch.median * 1e3) / B + + # Torch JFA (CUDA, one-by-one) + stmt_jfa1 = """ +for xi in x_imgs: + tm.distance_transform_edt(xi, algorithm="jfa") +""" + t_jfa1 = benchmark.Timer( + stmt=stmt_jfa1, + setup="from __main__ import x_imgs, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + jfa1_per_img_ms = (t_jfa1.median * 1e3) / B + + # Torch JFA (CUDA, batched) + t_jfa_batch = benchmark.Timer( + stmt='tm.distance_transform_edt(x, algorithm="jfa")', + setup="from __main__ import x, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + jfaB_per_img_ms = (t_jfa_batch.median * 1e3) / B + + # Speedups (batch mode vs scipy) + speed_exact = scipy_per_img_ms / exactB_per_img_ms + speed_jfa = scipy_per_img_ms / jfaB_per_img_ms + + table.add_row( + [ + s, + f"{scipy_per_img_ms:.3f}", + f"{exact1_per_img_ms:.3f}", + f"{exactB_per_img_ms:.3f}", + f"{jfa1_per_img_ms:.3f}", + f"{jfaB_per_img_ms:.3f}", + 
f"{speed_exact:.1f}×", + f"{speed_jfa:.1f}×", + ] + ) + + print(f"\n=== Batch Size: {B} ===") + print(table) diff --git a/benchmark/distance_transform_edt_3D.py b/benchmark/distance_transform_edt_3D.py new file mode 100644 index 0000000..49a5c31 --- /dev/null +++ b/benchmark/distance_transform_edt_3D.py @@ -0,0 +1,112 @@ +import scipy.ndimage as ndi # noqa: F401 +import torch +import torch.utils.benchmark as benchmark +from prettytable import PrettyTable + +import torchmorph as tm # noqa: F401 + +# 3D benchmark configurations +sizes = [32, 64, 128, 256] # D=H=W +batches = [1, 2, 4, 8] +dtype = torch.float32 +device = "cuda" +MIN_RUN = 1.0 # seconds per measurement + +torch.set_num_threads(torch.get_num_threads()) + +for B in batches: + table = PrettyTable() + table.field_names = [ + "Size (D×H×W)", + "SciPy (ms/vol)", + "Exact 1× (ms/vol)", + "Exact batch (ms/vol)", + "JFA 1× (ms/vol)", + "JFA batch (ms/vol)", + "Speedup Exact", + "Speedup JFA", + ] + for c in table.field_names: + table.align[c] = "r" + + for s in sizes: + # Skip large sizes with large batches to avoid OOM + if s >= 256 and B >= 4: + table.add_row([f"{s}³", "OOM", "OOM", "OOM", "OOM", "OOM", "-", "-"]) + continue + + # Inputs: (B, D, H, W) format for 3D - no channel dimension for JFA 3D + x = (torch.randn(B, s, s, s, device=device) > 0).to(dtype) + # For scipy, we need (D, H, W) arrays + x_np_list = [x[i].detach().cpu().numpy() for i in range(B)] + # For torch single volume processing: each is (1, D, H, W) + x_vols = [x[i : i + 1] for i in range(B)] + + # SciPy (CPU, one-by-one) + stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]" + t_scipy = benchmark.Timer( + stmt=stmt_scipy, + setup="from __main__ import x_np_list, ndi", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + scipy_per_vol_ms = (t_scipy.median * 1e3) / B + + # Torch Exact (CUDA, one-by-one) + stmt_exact1 = """ +for xi in x_vols: + tm.distance_transform_edt(xi, 
algorithm="exact") +""" + t_exact1 = benchmark.Timer( + stmt=stmt_exact1, + setup="from __main__ import x_vols, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + exact1_per_vol_ms = (t_exact1.median * 1e3) / B + + # Torch Exact (CUDA, batched) + t_exact_batch = benchmark.Timer( + stmt='tm.distance_transform_edt(x, algorithm="exact")', + setup="from __main__ import x, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + exactB_per_vol_ms = (t_exact_batch.median * 1e3) / B + + # Torch JFA (CUDA, one-by-one) + stmt_jfa1 = """ +for xi in x_vols: + tm.distance_transform_edt(xi, algorithm="jfa") +""" + t_jfa1 = benchmark.Timer( + stmt=stmt_jfa1, + setup="from __main__ import x_vols, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + jfa1_per_vol_ms = (t_jfa1.median * 1e3) / B + + # Torch JFA (CUDA, batched) + t_jfa_batch = benchmark.Timer( + stmt='tm.distance_transform_edt(x, algorithm="jfa")', + setup="from __main__ import x, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + jfaB_per_vol_ms = (t_jfa_batch.median * 1e3) / B + + # Speedups (batch mode vs scipy) + speed_exact = scipy_per_vol_ms / exactB_per_vol_ms + speed_jfa = scipy_per_vol_ms / jfaB_per_vol_ms + + table.add_row( + [ + f"{s}³", + f"{scipy_per_vol_ms:.3f}", + f"{exact1_per_vol_ms:.3f}", + f"{exactB_per_vol_ms:.3f}", + f"{jfa1_per_vol_ms:.3f}", + f"{jfaB_per_vol_ms:.3f}", + f"{speed_exact:.1f}×", + f"{speed_jfa:.1f}×", + ] + ) + + print(f"\n=== 3D EDT Benchmark | Batch Size: {B} ===") + print(table) diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py deleted file mode 100644 index 5855bf3..0000000 --- a/test/test_distance_transform.py +++ /dev/null @@ -1,182 +0,0 @@ -import numpy as np # noqa: F401 -import pytest -import torch -from scipy.ndimage import distance_transform_edt as scipy_edt # noqa: F401 - -import torchmorph as tm 
# noqa: F401 - - -# ====================================================================== -# Helper functions -# ====================================================================== -def batch_scipy_edt_with_indices( - batch_numpy: np.ndarray, -) -> tuple[np.ndarray, np.ndarray]: - """Compute SciPy EDT and indices for a batch of arrays.""" - dist_results: list[np.ndarray] = [] - indices_results: list[np.ndarray] = [] - - # Ensure batch_numpy has at least shape (Batch, ...) - # If the input is (H, W), it is already converted to (1, H, W) outside. - if batch_numpy.ndim == 1: - batch_numpy = batch_numpy[np.newaxis, ...] - - for sample in batch_numpy: - dist, indices = scipy_edt( - sample, - return_indices=True, - return_distances=True, - ) - dist_results.append(dist) - indices_results.append(indices) - - output_dist = np.stack(dist_results, axis=0) - output_indices = np.stack(indices_results, axis=0) - output_indices = np.moveaxis(output_indices, 1, -1) - - return output_dist, output_indices - - -# ====================================================================== -# Test data -# ====================================================================== -case_batch_1d = np.array( - [[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], - dtype=np.float32, -) - -case_batch_2d = np.array( - [ - [[0.0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]], - [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], - ], - dtype=np.float32, -) - -# This is a single 2D image with shape (4, 4) -case_single_2d = np.array( - [ - [0, 1, 0, 1], - [1, 0, 1, 0], - [0, 1, 0, 1], - [1, 0, 1, 0], - ], - dtype=np.float32, -) -case_explicit_batch_one = case_single_2d[np.newaxis, ...] 
- -_case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32) -_case_3d_s1[1, 1, 1] = 0.0 -_case_3d_s1[2, 3, 4] = 0.0 - -_case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32) -_case_3d_s2[0, 0, 0] = 0.0 - -case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0) - -case_dim_one = np.ones((2, 5, 1), dtype=np.float32) -case_dim_one[0, 2, 0] = 0.0 -case_dim_one[1, 4, 0] = 0.0 - -# 4D spatial case -_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32) -_case_4d_s1[0, 0, 0, 0] = 0.0 - -_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32) -_case_4d_s2[1, 1, 1, 1] = 0.0 - -case_batch_4d_spatial = np.stack([_case_4d_s1, _case_4d_s2], axis=0) - -# 5D spatial case -case_batch_5d_spatial = np.ones((1, 2, 2, 2, 2, 2), dtype=np.float32) -case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0 -case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 - - -# ====================================================================== -# Test logic -# ====================================================================== -@pytest.mark.parametrize( - "input_numpy, has_batch_dim", - [ - pytest.param(case_batch_1d, True, id="1D_Batch"), - pytest.param(case_batch_2d, True, id="2D_Batch"), - pytest.param(case_single_2d, False, id="2D_Single_NoBatch"), - pytest.param( - case_explicit_batch_one, - True, - id="2D_Single_ExplicitBatch", - ), - pytest.param(case_batch_3d, True, id="3D_Batch"), - pytest.param(case_dim_one, True, id="2D_UnitDim_Batch"), - pytest.param(case_batch_4d_spatial, True, id="4D_Spatial_Batch"), - pytest.param(case_batch_5d_spatial, True, id="5D_Spatial_Batch"), - ], -) -def test_distance_transform_and_indices( - input_numpy: np.ndarray, - has_batch_dim: bool, - request: pytest.FixtureRequest, -) -> None: - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - - # 1. Prepare NumPy data - x_numpy_contiguous = np.ascontiguousarray(input_numpy) - - # 2. Prepare SciPy input. 
- # If this is a single sample (has_batch_dim=False), manually add a - # batch dimension so SciPy treats it as one image instead of N 1D - # signals. - if not has_batch_dim: - scipy_input = x_numpy_contiguous[np.newaxis, ...] - else: - scipy_input = x_numpy_contiguous - - # 3. Prepare CUDA input. - # If has_batch_dim=False, the input is (H, W) and we want 2D EDT. - # The C++ API assumes the first dimension is batch, so we must - # unsqueeze(0) to get shape (1, H, W). Otherwise, it will be - # interpreted as (Batch=H, Length=W) and run 1D EDT. - x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() - if not has_batch_dim: - x_cuda = x_cuda.unsqueeze(0) - - print(f"\n\n--- Running test: {request.node.callspec.id} ---") - print(f"CUDA input shape: {x_cuda.shape}") - - # 4. Run CUDA EDT - dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone()) - - # 5. Run SciPy (ground truth) - dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(scipy_input) - dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() - - # 6. Validate distances - print( - f"CUDA distance shape: {dist_cuda.shape}, " f"reference shape: {dist_ref.shape}", - ) - assert ( - dist_cuda.shape == dist_ref.shape - ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" - torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3) - print(">> Distance validation passed.") - - # 7. 
Validate indices - # idx_cuda: (B, H, W, D) - spatial_shape = x_cuda.shape[1:] - coords = [torch.arange(s, device="cuda") for s in spatial_shape] - grid = torch.stack(torch.meshgrid(*coords, indexing="ij"), dim=-1) - grid = grid.unsqueeze(0) # (1, H, W, D) - - diff = grid.float() - idx_cuda.float() - dist_sq_calculated = torch.sum(diff * diff, dim=-1) - dist_sq_output = dist_cuda * dist_cuda - - torch.testing.assert_close( - dist_sq_calculated, - dist_sq_output, - atol=1e-3, - rtol=1e-3, - ) - print(">> Index validation passed.") diff --git a/test/test_distance_transform_cdt.py b/test/test_distance_transform_cdt.py new file mode 100644 index 0000000..b35248f --- /dev/null +++ b/test/test_distance_transform_cdt.py @@ -0,0 +1,500 @@ +import numpy as np +import pytest +import torch +from scipy.ndimage import distance_transform_cdt as scipy_cdt + +import torchmorph as tm + + +# ====================================================================== +# Helper functions +# ====================================================================== +def batch_scipy_cdt( + batch_numpy: np.ndarray, + metric: str = "chessboard", + return_indices: bool = False, + spatial_ndim: int = 2, +) -> tuple[np.ndarray, np.ndarray | None]: + """Compute SciPy CDT for a batch of arrays. + + Args: + batch_numpy: Input array with shape (batch..., spatial...) + metric: 'chessboard' or 'taxicab' + return_indices: Whether to return indices + spatial_ndim: Number of spatial dimensions (1, 2 or 3) + """ + original_shape = batch_numpy.shape + spatial_shape = original_shape[-spatial_ndim:] + batch_shape = original_shape[:-spatial_ndim] + + if len(batch_shape) > 0: + batch_size = int(np.prod(batch_shape)) + flat_input = batch_numpy.reshape(batch_size, *spatial_shape) + else: + batch_size = 1 + flat_input = batch_numpy[np.newaxis, ...] 
+ + dist_results: list[np.ndarray] = [] + indices_results: list[np.ndarray] = [] + + for sample in flat_input: + if return_indices: + dist, indices = scipy_cdt( + sample, + metric=metric, + return_distances=True, + return_indices=True, + ) + dist_results.append(dist) + indices_results.append(indices) + else: + dist = scipy_cdt(sample, metric=metric) + dist_results.append(dist) + + output_dist = np.stack(dist_results, axis=0) + + # Reshape back + if len(batch_shape) > 0: + output_dist = output_dist.reshape(*batch_shape, *spatial_shape) + else: + output_dist = output_dist[0] + + if return_indices: + output_indices = np.stack(indices_results, axis=0) + if len(batch_shape) > 0: + output_indices = output_indices.reshape(*batch_shape, spatial_ndim, *spatial_shape) + else: + output_indices = output_indices[0] + return output_dist, output_indices + + return output_dist, None + + +# ====================================================================== +# Test data: (B, C, Spatial...) format +# ====================================================================== +# 1D spatial: (B=2, C=1, W=9) +case_1d = np.array( + [[[0, 1, 1, 1, 1, 0, 1, 1, 0]], [[1, 1, 0, 1, 1, 1, 1, 0, 1]]], + dtype=np.float32, +) + +# 2D spatial: (B=1, C=1, H=5, W=6) +case_2d_simple = np.array( + [ + [ + [ + [0, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 1, 1, 1, 1, 0], + ] + ] + ], + dtype=np.float32, +) + +# 2D spatial batch: (B=2, C=1, H=4, W=5) +case_2d_batch = np.array( + [ + [[[0, 1, 1, 1, 0], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 0]]], + [[[1, 1, 0, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 0, 1, 1]]], + ], + dtype=np.float32, +) + +# 2D checkerboard: (B=1, C=1, H=4, W=4) +case_checkerboard = np.array( + [ + [ + [ + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + ] + ] + ], + dtype=np.float32, +) + +# 3D spatial: (B=1, C=1, D=5, H=5, W=5) +_case_3d_simple = np.zeros((1, 1, 5, 5, 5), dtype=np.float32) +_case_3d_simple[0, 
0, 1:4, 1:4, 1:4] = 1 # 3x3x3 cube of foreground +case_3d_simple = _case_3d_simple + +# 3D sphere: (B=1, C=1, D=7, H=7, W=7) +_case_3d_sphere = np.zeros((1, 1, 7, 7, 7), dtype=np.float32) +for z in range(7): + for y in range(7): + for x in range(7): + if (z - 3) ** 2 + (y - 3) ** 2 + (x - 3) ** 2 <= 4: + _case_3d_sphere[0, 0, z, y, x] = 1 +case_3d_sphere = _case_3d_sphere + +# 3D batch: (B=2, C=1, D=4, H=5, W=6) +_case_3d_batch_s1 = np.ones((1, 4, 5, 6), dtype=np.float32) +_case_3d_batch_s1[0, 1, 1, 1] = 0.0 +_case_3d_batch_s1[0, 2, 3, 4] = 0.0 + +_case_3d_batch_s2 = np.ones((1, 4, 5, 6), dtype=np.float32) +_case_3d_batch_s2[0, 0, 0, 0] = 0.0 + +case_3d_batch = np.stack( + [_case_3d_batch_s1, _case_3d_batch_s2], axis=0 +) # (B=2, C=1, D=4, H=5, W=6) + + +# ====================================================================== +# Test basic CDT functionality with BCHW format +# ====================================================================== +@pytest.mark.parametrize( + "input_numpy, spatial_ndim, metric", + [ + pytest.param(case_1d, 1, "chessboard", id="1D_B2C1_chessboard"), + pytest.param(case_1d, 1, "taxicab", id="1D_B2C1_taxicab"), + pytest.param(case_2d_simple, 2, "chessboard", id="2D_B1C1_chessboard"), + pytest.param(case_2d_simple, 2, "taxicab", id="2D_B1C1_taxicab"), + pytest.param(case_2d_batch, 2, "chessboard", id="2D_B2C1_chessboard"), + pytest.param(case_2d_batch, 2, "taxicab", id="2D_B2C1_taxicab"), + pytest.param(case_checkerboard, 2, "chessboard", id="2D_checkerboard_chessboard"), + pytest.param(case_checkerboard, 2, "taxicab", id="2D_checkerboard_taxicab"), + pytest.param(case_3d_simple, 3, "chessboard", id="3D_B1C1_simple_chessboard"), + pytest.param(case_3d_simple, 3, "taxicab", id="3D_B1C1_simple_taxicab"), + pytest.param(case_3d_sphere, 3, "chessboard", id="3D_B1C1_sphere_chessboard"), + pytest.param(case_3d_sphere, 3, "taxicab", id="3D_B1C1_sphere_taxicab"), + pytest.param(case_3d_batch, 3, "chessboard", id="3D_B2C1_batch_chessboard"), + 
pytest.param(case_3d_batch, 3, "taxicab", id="3D_B2C1_batch_taxicab"), + ], +) +def test_cdt_basic( + input_numpy: np.ndarray, + spatial_ndim: int, + metric: str, + request: pytest.FixtureRequest, +) -> None: + """Test CDT distance computation against scipy with BCHW format.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + x_numpy_contiguous = np.ascontiguousarray(input_numpy) + x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}, spatial_ndim: {spatial_ndim}, metric: {metric}") + + # Run torchmorph CDT + dist_cuda = tm.distance_transform_cdt(x_cuda, metric=metric) + + # Run scipy CDT (ground truth) + dist_scipy, _ = batch_scipy_cdt(x_numpy_contiguous, metric=metric, spatial_ndim=spatial_ndim) + dist_ref = torch.from_numpy(dist_scipy).to(torch.float32).cuda() + + print(f"CUDA distance shape: {dist_cuda.shape}, reference shape: {dist_ref.shape}") + assert ( + dist_cuda.shape == dist_ref.shape + ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + + # Debug: print actual values for small tensors + if dist_cuda.numel() <= 50: + print(f"Input:\n{x_cuda.cpu().numpy()}") + print(f"CUDA result:\n{dist_cuda.cpu().numpy()}") + print(f"SciPy reference:\n{dist_ref.cpu().numpy()}") + + torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-5, rtol=1e-5) + print(">> Distance validation passed.") + + +# ====================================================================== +# Test metric aliases +# ====================================================================== +@pytest.mark.parametrize( + "alias, canonical", + [ + pytest.param("cityblock", "taxicab", id="cityblock"), + pytest.param("manhattan", "taxicab", id="manhattan"), + ], +) +def test_metric_aliases(alias: str, canonical: str) -> None: + """Test that metric aliases produce same results.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + 
x_cuda = torch.from_numpy(case_2d_simple).cuda() + + dist_alias = tm.distance_transform_cdt(x_cuda, metric=alias) + dist_canonical = tm.distance_transform_cdt(x_cuda, metric=canonical) + + torch.testing.assert_close(dist_alias, dist_canonical, atol=1e-5, rtol=1e-5) + print(f">> Alias '{alias}' == '{canonical}' validation passed.") + + +# ====================================================================== +# Test return flags with BCHW format +# ====================================================================== +def test_return_flags() -> None: + """Test return_distances and return_indices flags with BCHW format.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # (B=1, C=1, H=5, W=6) + x = torch.from_numpy(case_2d_simple).cuda() + + # Only distances (default) + result = tm.distance_transform_cdt(x, return_distances=True, return_indices=False) + assert isinstance(result, torch.Tensor), "Should return single tensor" + assert result.shape == x.shape + + # Only indices - spatial_ndim=2 for BCHW + result = tm.distance_transform_cdt(x, return_distances=False, return_indices=True) + assert isinstance(result, torch.Tensor), "Should return single tensor" + assert result.shape == (2, *x.shape) # (spatial_ndim, B, C, H, W) + + # Both + dist, idx = tm.distance_transform_cdt(x, return_distances=True, return_indices=True) + assert dist.shape == x.shape + assert idx.shape == (2, *x.shape) + + print(">> Return flags test passed.") + + +# ====================================================================== +# Test pre-allocated output tensors (scipy convention) with BCHW format +# ====================================================================== +def test_preallocated_output() -> None: + """Test pre-allocated output tensors with scipy-style return convention.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # (B=1, C=1, H=5, W=6) + x = torch.from_numpy(case_2d_simple).cuda() + + # Pre-allocate distances tensor 
+ dist_out = torch.empty_like(x) + result = tm.distance_transform_cdt(x, distances=dist_out) + + # Should return None (scipy convention) + assert result is None, "Should return None when distances tensor is provided" + + # But dist_out should be filled + dist_ref, _ = batch_scipy_cdt(case_2d_simple, metric="chessboard", spatial_ndim=2) + dist_ref_tensor = torch.from_numpy(dist_ref).to(torch.float32).cuda() + torch.testing.assert_close(dist_out, dist_ref_tensor, atol=1e-5, rtol=1e-5) + + print(">> Pre-allocated output test passed.") + + +# ====================================================================== +# Test indices correctness with BCHW format +# ====================================================================== +def test_indices_correctness() -> None: + """Test that indices point to correct nearest background pixel with BCHW format.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # (B=1, C=1, H=5, W=6) + x = torch.from_numpy(case_2d_simple).cuda() + + dist, idx = tm.distance_transform_cdt(x, metric="chessboard", return_indices=True) + + # For each foreground pixel, verify the index points to a background pixel + # idx shape: (spatial_ndim=2, B=1, C=1, H=5, W=6) + B, C, H, W = x.shape + x_np = x.cpu().numpy() + idx_np = idx.cpu().numpy() # (2, B, C, H, W) + dist_np = dist.cpu().numpy() + + for b in range(B): + for c in range(C): + for y in range(H): + for x_coord in range(W): + if x_np[b, c, y, x_coord] != 0: # Foreground + idx_y = idx_np[0, b, c, y, x_coord] + idx_x = idx_np[1, b, c, y, x_coord] + # The pointed pixel should be background + assert ( + x_np[b, c, idx_y, idx_x] == 0 + ), f"Index ({idx_y}, {idx_x}) should point to background" + # Chessboard distance should match + expected_dist = max(abs(y - idx_y), abs(x_coord - idx_x)) + assert ( + dist_np[b, c, y, x_coord] == expected_dist + ), f"Distance mismatch at ({b}, {c}, {y}, {x_coord})" + + print(">> Indices correctness test passed.") + + +# 
====================================================================== +# Test indices correctness - 3D with BCHW format +# ====================================================================== +def test_indices_correctness_3d() -> None: + """Test that 3D indices point to correct nearest background pixel with BCDHW format.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # (B=1, C=1, D=5, H=5, W=5) + x = torch.from_numpy(case_3d_simple).cuda() + + dist, idx = tm.distance_transform_cdt(x, metric="chessboard", return_indices=True) + + # For each foreground pixel, verify the index points to a background pixel + # idx shape: (spatial_ndim=3, B=1, C=1, D=5, H=5, W=5) + B, C, D, H, W = x.shape + x_np = x.cpu().numpy() + idx_np = idx.cpu().numpy() # (3, B, C, D, H, W) + dist_np = dist.cpu().numpy() + + for b in range(B): + for c in range(C): + for z in range(D): + for y in range(H): + for x_coord in range(W): + if x_np[b, c, z, y, x_coord] != 0: # Foreground + idx_z = idx_np[0, b, c, z, y, x_coord] + idx_y = idx_np[1, b, c, z, y, x_coord] + idx_x = idx_np[2, b, c, z, y, x_coord] + # The pointed pixel should be background + assert ( + x_np[b, c, idx_z, idx_y, idx_x] == 0 + ), f"Index ({idx_z}, {idx_y}, {idx_x}) should point to background" + # Chessboard distance should match + expected_dist = max( + abs(z - idx_z), abs(y - idx_y), abs(x_coord - idx_x) + ) + assert ( + dist_np[b, c, z, y, x_coord] == expected_dist + ), f"Distance mismatch at ({b}, {c}, {z}, {y}, {x_coord})" + + print(">> 3D Indices correctness test passed.") + + +# ====================================================================== +# Test invalid inputs +# ====================================================================== +def test_invalid_metric() -> None: + """Test that invalid metric raises error.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + x = torch.from_numpy(case_2d_simple).cuda() + + with pytest.raises(ValueError, match="metric 
must be"): + tm.distance_transform_cdt(x, metric="invalid") + + +def test_cpu_input_error() -> None: + """Test that CPU input raises error.""" + x = torch.from_numpy(case_2d_simple) # CPU tensor + + with pytest.raises(ValueError, match="CUDA"): + tm.distance_transform_cdt(x) + + +# ====================================================================== +# Test with random data - BCHW format +# ====================================================================== +@pytest.mark.parametrize( + "shape, spatial_ndim, metric", + [ + # 1D spatial: (B, C, W) + pytest.param((2, 1, 32), 1, "chessboard", id="1D_B2C1_32_chessboard"), + pytest.param((2, 1, 32), 1, "taxicab", id="1D_B2C1_32_taxicab"), + pytest.param((4, 2, 64), 1, "chessboard", id="1D_B4C2_64_chessboard"), + # 2D spatial: (B, C, H, W) + pytest.param((1, 1, 32, 32), 2, "chessboard", id="2D_B1C1_32x32_chessboard"), + pytest.param((1, 1, 32, 32), 2, "taxicab", id="2D_B1C1_32x32_taxicab"), + pytest.param((2, 1, 32, 32), 2, "chessboard", id="2D_B2C1_32x32_chessboard"), + pytest.param((2, 1, 32, 32), 2, "taxicab", id="2D_B2C1_32x32_taxicab"), + pytest.param((2, 3, 64, 48), 2, "chessboard", id="2D_B2C3_64x48_chessboard"), + # 3D spatial: (B, C, D, H, W) + pytest.param((1, 1, 8, 8, 8), 3, "chessboard", id="3D_B1C1_8x8x8_chessboard"), + pytest.param((1, 1, 8, 8, 8), 3, "taxicab", id="3D_B1C1_8x8x8_taxicab"), + pytest.param((2, 1, 16, 16, 16), 3, "chessboard", id="3D_B2C1_16x16x16_chessboard"), + pytest.param((2, 2, 12, 10, 8), 3, "taxicab", id="3D_B2C2_12x10x8_taxicab"), + ], +) +def test_random_data(shape: tuple, spatial_ndim: int, metric: str) -> None: + """Test CDT with random binary data in BCHW format.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + np.random.seed(42) + input_numpy = (np.random.rand(*shape) > 0.3).astype(np.float32) + + x_cuda = torch.from_numpy(input_numpy).cuda() + + # Run torchmorph CDT + dist_cuda = tm.distance_transform_cdt(x_cuda, metric=metric) + + # Run scipy 
CDT + dist_scipy, _ = batch_scipy_cdt(input_numpy, metric=metric, spatial_ndim=spatial_ndim) + dist_ref = torch.from_numpy(dist_scipy).to(torch.float32).cuda() + + torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-5, rtol=1e-5) + print(f">> Random data test ({shape}, spatial_ndim={spatial_ndim}, {metric}) passed.") + + +# ====================================================================== +# Test indices validation with BCHW format +# ====================================================================== +@pytest.mark.parametrize( + "input_numpy, spatial_ndim, metric", + [ + pytest.param(case_1d, 1, "chessboard", id="1D_indices_chessboard"), + pytest.param(case_1d, 1, "taxicab", id="1D_indices_taxicab"), + pytest.param(case_2d_batch, 2, "chessboard", id="2D_batch_indices_chessboard"), + pytest.param(case_3d_batch, 3, "chessboard", id="3D_batch_indices_chessboard"), + ], +) +def test_indices_validation( + input_numpy: np.ndarray, + spatial_ndim: int, + metric: str, + request: pytest.FixtureRequest, +) -> None: + """Test that indices correctly point to nearest background pixels in BCHW format.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + x_numpy_contiguous = np.ascontiguousarray(input_numpy) + x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}, spatial_ndim: {spatial_ndim}") + + # Run torchmorph CDT with indices + dist_cuda, idx_cuda = tm.distance_transform_cdt(x_cuda, metric=metric, return_indices=True) + + # Validate indices shape: (spatial_ndim, *input_shape) + expected_idx_shape = (spatial_ndim, *x_cuda.shape) + assert ( + idx_cuda.shape == expected_idx_shape + ), f"Index shape mismatch: {idx_cuda.shape} vs {expected_idx_shape}" + + # Validate that indices point to background pixels and distance matches + spatial_shape = x_cuda.shape[-spatial_ndim:] + batch_shape = x_cuda.shape[:-spatial_ndim] + + # Create 
coordinate grid for spatial dimensions + coords = [torch.arange(s, device="cuda") for s in spatial_shape] + grid = torch.stack( + torch.meshgrid(*coords, indexing="ij"), dim=0 + ) # (spatial_ndim, *spatial_shape) + + # Expand grid for batch dimensions + for _ in batch_shape: + grid = grid.unsqueeze(1) + grid = grid.expand(spatial_ndim, *batch_shape, *spatial_shape) + + # Calculate distance from indices based on metric + diff = grid.float() - idx_cuda.float() + if metric in ("chessboard",): + # Chessboard: max of absolute differences + dist_calculated = torch.max(torch.abs(diff), dim=0).values + else: + # Taxicab: sum of absolute differences + dist_calculated = torch.sum(torch.abs(diff), dim=0) + + torch.testing.assert_close(dist_calculated, dist_cuda, atol=1e-5, rtol=1e-5) + print(">> Index validation passed.") diff --git a/test/test_distance_transform_edt.py b/test/test_distance_transform_edt.py new file mode 100644 index 0000000..20b993a --- /dev/null +++ b/test/test_distance_transform_edt.py @@ -0,0 +1,498 @@ +import numpy as np # noqa: F401 +import pytest +import torch +from scipy.ndimage import distance_transform_edt as scipy_edt # noqa: F401 + +import torchmorph as tm # noqa: F401 + + +# ====================================================================== +# Helper functions +# ====================================================================== +def batch_scipy_edt_with_indices( + batch_numpy: np.ndarray, + spatial_ndim: int, +) -> tuple[np.ndarray, np.ndarray]: + """Compute SciPy EDT and indices for a batch of arrays. 
+ + Args: + batch_numpy: Input array with shape (batch..., *spatial_shape) + spatial_ndim: Number of spatial dimensions + """ + dist_results: list[np.ndarray] = [] + indices_results: list[np.ndarray] = [] + + # Compute batch shape + batch_shape = batch_numpy.shape[:-spatial_ndim] if spatial_ndim > 0 else () + spatial_shape = batch_numpy.shape[-spatial_ndim:] if spatial_ndim > 0 else batch_numpy.shape + + # Flatten batch dimensions + if len(batch_shape) > 0: + batch_size = int(np.prod(batch_shape)) + flat_input = batch_numpy.reshape(batch_size, *spatial_shape) + else: + batch_size = 1 + flat_input = batch_numpy[np.newaxis, ...] + + for sample in flat_input: + dist, indices = scipy_edt( + sample, + return_indices=True, + return_distances=True, + ) + dist_results.append(dist) + indices_results.append(indices) + + output_dist = np.stack(dist_results, axis=0) + output_indices = np.stack(indices_results, axis=0) + + # Reshape back to original batch shape + if len(batch_shape) > 0: + output_dist = output_dist.reshape(*batch_shape, *spatial_shape) + output_indices = output_indices.reshape(*batch_shape, spatial_ndim, *spatial_shape) + else: + output_dist = output_dist[0] + output_indices = output_indices[0] + + return output_dist, output_indices + + +# ====================================================================== +# Test data: (B, C, Spatial...) 
format +# ====================================================================== +# 1D spatial: (B=2, C=1, W=6) +case_1d = np.array( + [[[1, 1, 0, 1, 0, 1]], [[0, 1, 1, 1, 1, 0]]], + dtype=np.float32, +) + +# 2D spatial: (B=2, C=1, H=3, W=4) +case_2d = np.array( + [ + [[[0.0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]]], + [[[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], + ], + dtype=np.float32, +) + +# 2D spatial single batch: (B=1, C=1, H=4, W=4) +case_2d_single = np.array( + [ + [ + [ + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + ] + ] + ], + dtype=np.float32, +) + +# 3D spatial: (B=2, C=1, D=4, H=5, W=6) +_case_3d_s1 = np.ones((1, 4, 5, 6), dtype=np.float32) +_case_3d_s1[0, 1, 1, 1] = 0.0 +_case_3d_s1[0, 2, 3, 4] = 0.0 + +_case_3d_s2 = np.ones((1, 4, 5, 6), dtype=np.float32) +_case_3d_s2[0, 0, 0, 0] = 0.0 + +case_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0) # (B=2, C=1, D=4, H=5, W=6) + +# 2D with unit dimension: (B=2, C=1, H=5, W=1) +case_2d_unit = np.ones((2, 1, 5, 1), dtype=np.float32) +case_2d_unit[0, 0, 2, 0] = 0.0 +case_2d_unit[1, 0, 4, 0] = 0.0 + + +# ====================================================================== +# Test logic +# ====================================================================== +@pytest.mark.parametrize( + "input_numpy, spatial_ndim", + [ + pytest.param(case_1d, 1, id="1D_B2C1"), + pytest.param(case_2d, 2, id="2D_B2C1"), + pytest.param(case_2d_single, 2, id="2D_B1C1"), + pytest.param(case_3d, 3, id="3D_B2C1"), + pytest.param(case_2d_unit, 2, id="2D_UnitDim_B2C1"), + ], +) +def test_distance_transform_and_indices( + input_numpy: np.ndarray, + spatial_ndim: int, + request: pytest.FixtureRequest, +) -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # 1. 
Prepare data + x_numpy_contiguous = np.ascontiguousarray(input_numpy) + x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}, spatial_ndim: {spatial_ndim}") + + # 2. Create sampling list to specify spatial dimensions + sampling = [1.0] * spatial_ndim + + # 3. Run CUDA EDT + dist_cuda, idx_cuda = tm.distance_transform( + x_cuda.clone(), sampling=sampling, return_indices=True + ) + + # 4. Run SciPy (ground truth) + dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(x_numpy_contiguous, spatial_ndim) + dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + # 5. Validate distances + print( + f"CUDA distance shape: {dist_cuda.shape}, reference shape: {dist_ref.shape}", + ) + assert ( + dist_cuda.shape == dist_ref.shape + ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + + # Debug: print actual values for small tensors + if dist_cuda.numel() <= 30: + print(f"Input:\n{x_cuda.cpu().numpy()}") + print(f"CUDA result:\n{dist_cuda.cpu().numpy()}") + print(f"SciPy reference:\n{dist_ref.cpu().numpy()}") + + torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-5, rtol=1e-5) + print(">> Distance validation passed.") + + # 6. 
Validate indices + # idx_cuda shape: (spatial_ndim, *input_shape) + # We need to verify that the indices point to the correct nearest background pixel + spatial_shape = x_cuda.shape[-spatial_ndim:] + batch_shape = x_cuda.shape[:-spatial_ndim] + + # Create coordinate grid for spatial dimensions + coords = [torch.arange(s, device="cuda") for s in spatial_shape] + grid = torch.stack( + torch.meshgrid(*coords, indexing="ij"), dim=0 + ) # (spatial_ndim, *spatial_shape) + + # Expand grid for batch dimensions + for _ in batch_shape: + grid = grid.unsqueeze(1) # (spatial_ndim, 1, ..., *spatial_shape) + grid = grid.expand( + spatial_ndim, *batch_shape, *spatial_shape + ) # (spatial_ndim, *batch_shape, *spatial_shape) + + # Calculate distance from indices + diff = grid.float() - idx_cuda.float() + dist_sq_calculated = torch.sum(diff * diff, dim=0) + dist_sq_output = dist_cuda * dist_cuda + + torch.testing.assert_close( + dist_sq_calculated, + dist_sq_output, + atol=1e-5, + rtol=1e-5, + ) + print(">> Index validation passed.") + + +# ====================================================================== +# Helper for sampling tests +# ====================================================================== +def batch_scipy_edt_with_sampling( + batch_numpy: np.ndarray, + spatial_ndim: int, + sampling: list[float], +) -> tuple[np.ndarray, np.ndarray]: + """Compute SciPy EDT with sampling for a batch of arrays. 
+ + Args: + batch_numpy: Input array with shape (batch..., *spatial_shape) + spatial_ndim: Number of spatial dimensions + sampling: Spacing for each spatial dimension + """ + dist_results: list[np.ndarray] = [] + indices_results: list[np.ndarray] = [] + + batch_shape = batch_numpy.shape[:-spatial_ndim] if spatial_ndim > 0 else () + spatial_shape = batch_numpy.shape[-spatial_ndim:] if spatial_ndim > 0 else batch_numpy.shape + + if len(batch_shape) > 0: + batch_size = int(np.prod(batch_shape)) + flat_input = batch_numpy.reshape(batch_size, *spatial_shape) + else: + batch_size = 1 + flat_input = batch_numpy[np.newaxis, ...] + + for sample in flat_input: + dist, indices = scipy_edt( + sample, + sampling=sampling, + return_indices=True, + return_distances=True, + ) + dist_results.append(dist) + indices_results.append(indices) + + output_dist = np.stack(dist_results, axis=0) + output_indices = np.stack(indices_results, axis=0) + + if len(batch_shape) > 0: + output_dist = output_dist.reshape(*batch_shape, *spatial_shape) + output_indices = output_indices.reshape(*batch_shape, spatial_ndim, *spatial_shape) + else: + output_dist = output_dist[0] + output_indices = output_indices[0] + + return output_dist, output_indices + + +# ====================================================================== +# Test sampling functionality +# ====================================================================== +@pytest.mark.parametrize( + "input_numpy, spatial_ndim, sampling", + [ + # 2D with non-uniform sampling + pytest.param(case_2d_single, 2, [0.5, 1.0], id="2D_Sampling_0.5_1.0"), + pytest.param(case_2d_single, 2, [2.0, 0.5], id="2D_Sampling_2.0_0.5"), + pytest.param(case_2d_single, 2, [0.25, 0.25], id="2D_Sampling_0.25_0.25"), + # 2D batch with sampling + pytest.param(case_2d, 2, [1.5, 0.75], id="2D_Batch_Sampling"), + # 3D with sampling + pytest.param(case_3d, 3, [1.0, 2.0, 0.5], id="3D_Batch_Sampling"), + # 1D with sampling + pytest.param(case_1d, 1, [0.5], 
id="1D_Batch_Sampling"), + # Test single-element list broadcast + pytest.param(case_2d_single, 2, [0.5], id="2D_SingleElementList_Broadcast"), + pytest.param(case_3d, 3, [2.0], id="3D_SingleElementList_Broadcast"), + ], +) +def test_distance_transform_with_sampling( + input_numpy: np.ndarray, + spatial_ndim: int, + sampling: list[float], + request: pytest.FixtureRequest, +) -> None: + """Test EDT with non-unit sampling (pixel spacing).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + x_numpy_contiguous = np.ascontiguousarray(input_numpy) + x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}, spatial_ndim: {spatial_ndim}, sampling: {sampling}") + + # Run CUDA EDT with sampling + dist_cuda, idx_cuda = tm.distance_transform_edt( + x_cuda.clone(), sampling=sampling, return_indices=True + ) + + # Expand single-element list for SciPy (it doesn't support broadcast) + scipy_sampling = sampling if len(sampling) == spatial_ndim else sampling * spatial_ndim + + # Run SciPy with sampling (ground truth) + dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_sampling( + x_numpy_contiguous, spatial_ndim, scipy_sampling + ) + dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + # Validate distances + print(f"CUDA distance shape: {dist_cuda.shape}, reference shape: {dist_ref.shape}") + assert ( + dist_cuda.shape == dist_ref.shape + ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-5, rtol=1e-5) + print(">> Distance validation with sampling passed.") + + # Validate indices shape + expected_idx_shape = (spatial_ndim, *x_cuda.shape) + assert ( + idx_cuda.shape == expected_idx_shape + ), f"Index shape mismatch: {idx_cuda.shape} vs {expected_idx_shape}" + print(">> Index shape validation passed.") + + # Validate indices correctness using sampling + spatial_shape 
= x_cuda.shape[-spatial_ndim:] + batch_shape = x_cuda.shape[:-spatial_ndim] + + coords = [torch.arange(s, device="cuda") for s in spatial_shape] + grid = torch.stack(torch.meshgrid(*coords, indexing="ij"), dim=0) + + for _ in batch_shape: + grid = grid.unsqueeze(1) + grid = grid.expand(spatial_ndim, *batch_shape, *spatial_shape) + + # Calculate distance with sampling (use expanded sampling for validation) + sampling_expanded = sampling if len(sampling) == spatial_ndim else sampling * spatial_ndim + sampling_tensor = torch.tensor(sampling_expanded, device="cuda", dtype=torch.float32) + for _ in range(len(batch_shape) + len(spatial_shape)): + sampling_tensor = sampling_tensor.unsqueeze(-1) + sampling_tensor = sampling_tensor.expand(spatial_ndim, *batch_shape, *spatial_shape) + + diff = (grid.float() - idx_cuda.float()) * sampling_tensor + dist_sq_calculated = torch.sum(diff * diff, dim=0) + dist_sq_output = dist_cuda * dist_cuda + + torch.testing.assert_close(dist_sq_calculated, dist_sq_output, atol=1e-5, rtol=1e-5) + print(">> Index validation with sampling passed.") + + +# ====================================================================== +# Test return_distances and return_indices flags +# ====================================================================== +def test_return_flags() -> None: + """Test return_distances and return_indices flags.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # (B=1, C=1, H=2, W=3) + x = torch.tensor([[[[1, 1, 0], [1, 0, 0]]]], dtype=torch.float32).cuda() + + # Only distances + result = tm.distance_transform_edt(x, return_distances=True, return_indices=False) + assert isinstance( + result, torch.Tensor + ), "Should return single tensor when only distances requested" + assert result.shape == x.shape + + # Only indices + result = tm.distance_transform_edt(x, return_distances=False, return_indices=True) + assert isinstance( + result, torch.Tensor + ), "Should return single tensor when only indices 
requested" + assert result.shape == (2, *x.shape) # (spatial_ndim, B, C, H, W) + + # Both + dist, idx = tm.distance_transform_edt(x, return_distances=True, return_indices=True) + assert dist.shape == x.shape + assert idx.shape == (2, *x.shape) + + print(">> Return flags test passed.") + + +# ====================================================================== +# Test single float sampling +# ====================================================================== +def test_single_float_sampling() -> None: + """Test that a single float sampling value applies to all dimensions.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Use case_2d_single which is (B=1, C=1, H=4, W=4) format + x_numpy = case_2d_single + x_cuda = torch.from_numpy(x_numpy).cuda() + + # Single float should apply to all spatial dimensions + dist_cuda = tm.distance_transform_edt(x_cuda, sampling=0.5) + + # Compare with scipy using [0.5, 0.5] - use batch helper for BCHW format + spatial_ndim = 2 + dist_ref_numpy, _ = batch_scipy_edt_with_sampling(x_numpy, spatial_ndim, [0.5, 0.5]) + dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-5, rtol=1e-5) + print(">> Single float sampling test passed.") + + +# ====================================================================== +# Test algorithm parameter (JFA vs Exact) +# ====================================================================== +@pytest.mark.parametrize( + "input_numpy, spatial_ndim, algorithm", + [ + # 2D tests with different algorithms + pytest.param(case_2d, 2, "exact", id="2D_exact"), + pytest.param(case_2d, 2, "jfa", id="2D_jfa"), + pytest.param(case_2d, 2, "auto", id="2D_auto"), + pytest.param(case_2d_single, 2, "exact", id="2D_single_exact"), + pytest.param(case_2d_single, 2, "jfa", id="2D_single_jfa"), + pytest.param(case_2d_single, 2, "auto", id="2D_single_auto"), + # 3D tests with different algorithms + pytest.param(case_3d, 
3, "exact", id="3D_exact"), + pytest.param(case_3d, 3, "jfa", id="3D_jfa"), + pytest.param(case_3d, 3, "auto", id="3D_auto"), + ], +) +def test_distance_transform_algorithm( + input_numpy: np.ndarray, + spatial_ndim: int, + algorithm: str, + request: pytest.FixtureRequest, +) -> None: + """Test EDT with different algorithm options (exact, jfa, auto).""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + x_numpy_contiguous = np.ascontiguousarray(input_numpy) + x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}, spatial_ndim: {spatial_ndim}, algorithm: {algorithm}") + + # Run CUDA EDT with specified algorithm + dist_cuda = tm.distance_transform_edt(x_cuda.clone(), algorithm=algorithm) + + # Run SciPy (ground truth) + dist_ref_numpy, _ = batch_scipy_edt_with_indices(x_numpy_contiguous, spatial_ndim) + dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + # Validate distances + print(f"CUDA distance shape: {dist_cuda.shape}, reference shape: {dist_ref.shape}") + assert ( + dist_cuda.shape == dist_ref.shape + ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + + torch.testing.assert_close(dist_cuda, dist_ref, rtol=1e-5, atol=1e-5) + + print(f">> Algorithm '{algorithm}' validation passed.") + + +def test_algorithm_fallback_with_sampling() -> None: + """Test that JFA falls back to exact when sampling is provided.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + x_numpy = case_2d_single + x_cuda = torch.from_numpy(x_numpy).cuda() + + # With non-unit sampling, JFA should fall back to exact algorithm + # Both should give same result + dist_jfa = tm.distance_transform_edt(x_cuda.clone(), sampling=[0.5, 1.0], algorithm="jfa") + dist_exact = tm.distance_transform_edt(x_cuda.clone(), sampling=[0.5, 1.0], algorithm="exact") + + # Compare with scipy + spatial_ndim = 2 + dist_ref_numpy, _ 
= batch_scipy_edt_with_sampling(x_numpy, spatial_ndim, [0.5, 1.0]) + dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + torch.testing.assert_close(dist_jfa, dist_ref, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(dist_exact, dist_ref, atol=1e-5, rtol=1e-5) + + print(">> Algorithm fallback with sampling test passed.") + + +def test_jfa_vs_exact_consistency() -> None: + """Test that JFA and exact produce similar results for unit sampling.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Create a larger random test case + torch.manual_seed(42) + x = (torch.randn(2, 1, 64, 64, device="cuda") > 0).float() + + dist_exact = tm.distance_transform_edt(x, algorithm="exact") + dist_jfa = tm.distance_transform_edt(x, algorithm="jfa") + + # JFA should be very close to exact for most pixels + # Allow for small differences due to JFA's approximate nature + diff = torch.abs(dist_exact - dist_jfa) + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + print(f"JFA vs Exact - Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}") + + # Most pixels should be exact or very close + assert mean_diff < 0.1, f"Mean difference too large: {mean_diff}" + print(">> JFA vs Exact consistency test passed.") diff --git a/torchmorph/__init__.py b/torchmorph/__init__.py index f35a5c9..3310c8c 100644 --- a/torchmorph/__init__.py +++ b/torchmorph/__init__.py @@ -1,10 +1,12 @@ from .add import add from .dilation_erosion import binary_dilation, binary_erosion -from .distance_transform import distance_transform +from .distance_transform import distance_transform, distance_transform_cdt, distance_transform_edt __all__ = [ "add", "distance_transform", + "distance_transform_edt", + "distance_transform_cdt", "binary_dilation", "binary_erosion", ] diff --git a/torchmorph/csrc/distance_transform_cdt.cu b/torchmorph/csrc/distance_transform_cdt.cu new file mode 100644 index 0000000..6de71dd --- /dev/null +++ 
b/torchmorph/csrc/distance_transform_cdt.cu @@ -0,0 +1,446 @@ +#include +#include +#include +#include +#include + +#define CDT_BLOCK_SIZE 256 +#define CDT_INF_VAL 1000000000 +#define MAX_NDIM 16 + +// ============================================================================ +// High-performance N-dimensional CDT using dimension-separable parallel scans +// For each dimension, we do forward and backward sweeps that can be parallelized +// across all other dimensions and batch elements. +// ============================================================================ + +// Initialize distance and indices +__global__ void cdt_init_kernel( + const float* __restrict__ input, + int32_t* __restrict__ dist, + int32_t* __restrict__ indices, // [spatial_ndim, total_elements] + int64_t total_elements, + int spatial_ndim, + int64_t spatial_elements, + const int64_t* __restrict__ spatial_strides, + bool compute_indices +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + if (input[tid] == 0.0f) { + dist[tid] = 0; + if (compute_indices) { + int64_t spatial_idx = tid % spatial_elements; + int64_t rem = spatial_idx; + for (int d = 0; d < spatial_ndim; d++) { + int64_t coord = rem / spatial_strides[d]; + rem = rem % spatial_strides[d]; + indices[d * total_elements + tid] = (int32_t)coord; + } + } + } else { + dist[tid] = CDT_INF_VAL; + if (compute_indices) { + for (int d = 0; d < spatial_ndim; d++) { + indices[d * total_elements + tid] = -1; + } + } + } +} + +// ============================================================================ +// Dimension-wise sweep kernels for chessboard metric +// Each thread handles one "line" along the scan dimension +// ============================================================================ + +// Forward sweep along dimension d (from 0 to size-1) +// For chessboard: check neighbor at offset -1 in dimension d, and diagonal neighbors +__global__ void cdt_sweep_forward_chessboard_kernel( + 
int32_t* __restrict__ dist, + int32_t* __restrict__ indices, + int64_t total_elements, + int64_t num_lines, // number of parallel lines + int64_t line_stride, // stride between elements in the same line + int64_t line_length, // number of elements in one line + int64_t batch_stride, // stride between batches + int64_t spatial_elements, + int spatial_ndim, + int scan_dim, + const int64_t* __restrict__ spatial_strides, + const int64_t* __restrict__ spatial_shape, + bool compute_indices +) { + int64_t line_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (line_idx >= num_lines) return; + + // Compute starting position for this line + int64_t batch_idx = line_idx / (spatial_elements / line_length); + int64_t within_batch = line_idx % (spatial_elements / line_length); + + // Convert within_batch to actual spatial offset (excluding scan dimension) + int64_t spatial_offset = 0; + int64_t rem = within_batch; + for (int d = 0; d < spatial_ndim; d++) { + if (d == scan_dim) continue; + int64_t dim_size = spatial_shape[d]; + int64_t coord = rem % dim_size; + rem /= dim_size; + spatial_offset += coord * spatial_strides[d]; + } + + int64_t base = batch_idx * spatial_elements + spatial_offset; + + // Forward sweep: i = 0 to line_length-1 + for (int64_t i = 1; i < line_length; i++) { + int64_t curr_idx = base + i * line_stride; + int32_t curr_dist = dist[curr_idx]; + + if (curr_dist == 0) continue; + + // Check previous element in this dimension + int64_t prev_idx = base + (i - 1) * line_stride; + int32_t prev_dist = dist[prev_idx]; + + if (prev_dist < CDT_INF_VAL) { + int32_t new_dist = prev_dist + 1; + if (new_dist < curr_dist) { + dist[curr_idx] = new_dist; + if (compute_indices) { + for (int d = 0; d < spatial_ndim; d++) { + indices[d * total_elements + curr_idx] = indices[d * total_elements + prev_idx]; + } + } + } + } + } +} + +// Backward sweep along dimension d (from size-1 to 0) +__global__ void cdt_sweep_backward_chessboard_kernel( + int32_t* __restrict__ dist, + 
int32_t* __restrict__ indices, + int64_t total_elements, + int64_t num_lines, + int64_t line_stride, + int64_t line_length, + int64_t batch_stride, + int64_t spatial_elements, + int spatial_ndim, + int scan_dim, + const int64_t* __restrict__ spatial_strides, + const int64_t* __restrict__ spatial_shape, + bool compute_indices +) { + int64_t line_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (line_idx >= num_lines) return; + + int64_t batch_idx = line_idx / (spatial_elements / line_length); + int64_t within_batch = line_idx % (spatial_elements / line_length); + + int64_t spatial_offset = 0; + int64_t rem = within_batch; + for (int d = 0; d < spatial_ndim; d++) { + if (d == scan_dim) continue; + int64_t dim_size = spatial_shape[d]; + int64_t coord = rem % dim_size; + rem /= dim_size; + spatial_offset += coord * spatial_strides[d]; + } + + int64_t base = batch_idx * spatial_elements + spatial_offset; + + // Backward sweep: i = line_length-2 down to 0 + for (int64_t i = line_length - 2; i >= 0; i--) { + int64_t curr_idx = base + i * line_stride; + int32_t curr_dist = dist[curr_idx]; + + if (curr_dist == 0) continue; + + int64_t next_idx = base + (i + 1) * line_stride; + int32_t next_dist = dist[next_idx]; + + if (next_dist < CDT_INF_VAL) { + int32_t new_dist = next_dist + 1; + if (new_dist < curr_dist) { + dist[curr_idx] = new_dist; + if (compute_indices) { + for (int d = 0; d < spatial_ndim; d++) { + indices[d * total_elements + curr_idx] = indices[d * total_elements + next_idx]; + } + } + } + } + } +} + +// ============================================================================ +// Diagonal sweep kernels for chessboard metric (handles corner neighbors) +// ============================================================================ + +// Check all neighbors at distance 1 in chessboard metric +__global__ void cdt_diagonal_pass_kernel( + int32_t* __restrict__ dist, + int32_t* __restrict__ indices, + int64_t total_elements, + int64_t batch_size, + int64_t 
spatial_elements, + int spatial_ndim, + const int64_t* __restrict__ spatial_strides, + const int64_t* __restrict__ spatial_shape, + const int32_t* __restrict__ offsets, + int num_offsets, + bool compute_indices, + bool forward // true for forward pass, false for backward +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + int32_t curr_dist = dist[tid]; + if (curr_dist == 0) return; + + int64_t batch_idx = tid / spatial_elements; + int64_t spatial_idx = tid % spatial_elements; + int64_t base = batch_idx * spatial_elements; + + // Compute current coordinates + int32_t coords[MAX_NDIM]; + int64_t rem = spatial_idx; + for (int d = 0; d < spatial_ndim; d++) { + coords[d] = (int32_t)(rem / spatial_strides[d]); + rem = rem % spatial_strides[d]; + } + + int32_t min_dist = curr_dist; + int best_neighbor = -1; + + for (int n = 0; n < num_offsets; n++) { + int64_t neighbor_spatial = spatial_idx + offsets[n]; + + // Check bounds and no wrap-around + if (neighbor_spatial < 0 || neighbor_spatial >= spatial_elements) continue; + + // Verify no wrap-around by checking coordinate differences + int64_t n_rem = neighbor_spatial; + bool valid = true; + for (int d = 0; d < spatial_ndim; d++) { + int32_t n_coord = (int32_t)(n_rem / spatial_strides[d]); + n_rem = n_rem % spatial_strides[d]; + int32_t diff = coords[d] - n_coord; + if (diff < -1 || diff > 1) { + valid = false; + break; + } + } + if (!valid) continue; + + int64_t neighbor_idx = base + neighbor_spatial; + int32_t neighbor_dist = dist[neighbor_idx]; + + if (neighbor_dist < CDT_INF_VAL) { + int32_t new_dist = neighbor_dist + 1; + if (new_dist < min_dist) { + min_dist = new_dist; + best_neighbor = n; + } + } + } + + if (min_dist < curr_dist) { + dist[tid] = min_dist; + if (compute_indices && best_neighbor >= 0) { + int64_t src_idx = base + spatial_idx + offsets[best_neighbor]; + for (int d = 0; d < spatial_ndim; d++) { + indices[d * total_elements + tid] = indices[d * 
total_elements + src_idx];
+            }
+        }
+    }
+}
+
+// ============================================================================
+// Taxicab metric uses simpler dimension-separable sweeps (no diagonals)
+// ============================================================================
+
+std::tuple<torch::Tensor, torch::Tensor> distance_transform_cdt_cuda(
+    torch::Tensor input,
+    const std::string& metric,
+    bool return_distances,
+    bool return_indices
+) {
+    TORCH_CHECK(input.is_cuda(), "Input must be CUDA tensor");
+    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");
+    TORCH_CHECK(metric == "chessboard" || metric == "taxicab",
+                "metric must be 'chessboard' or 'taxicab'");
+    TORCH_CHECK(return_distances || return_indices,
+                "At least one of return_distances or return_indices must be True");
+
+    input = input.contiguous();
+
+    bool is_taxicab = (metric == "taxicab");
+    int total_ndim = input.dim();
+
+    TORCH_CHECK(total_ndim >= 3, "Input must be (B, C, Spatial...) format with at least 3 dimensions");
+
+    auto shape_vec = input.sizes().vec();
+    int64_t batch_size = shape_vec[0] * shape_vec[1];
+    int spatial_ndim = total_ndim - 2;
+
+    TORCH_CHECK(spatial_ndim >= 1 && spatial_ndim <= MAX_NDIM,
+                "CDT supports 1D-" + std::to_string(MAX_NDIM) + "D spatial dimensions");
+
+    std::vector<int64_t> spatial_shape(spatial_ndim);
+    std::vector<int64_t> spatial_strides(spatial_ndim);
+
+    int64_t spatial_elements = 1;
+    for (int d = 0; d < spatial_ndim; d++) {
+        spatial_shape[d] = shape_vec[d + 2];
+        spatial_elements *= spatial_shape[d];
+    }
+
+    spatial_strides[spatial_ndim - 1] = 1;
+    for (int d = spatial_ndim - 2; d >= 0; d--) {
+        spatial_strides[d] = spatial_strides[d + 1] * spatial_shape[d + 1];
+    }
+
+    int64_t total_elements = input.numel();
+
+    // Allocate output tensors
+    auto dist = torch::empty({total_elements}, input.options().dtype(torch::kInt32));
+    torch::Tensor indices;
+    if (return_indices) {
+        indices = torch::empty({spatial_ndim, total_elements}, input.options().dtype(torch::kInt32));
+    
}
+
+    // Copy shape/strides to device
+    auto spatial_shape_tensor = torch::tensor(spatial_shape, torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
+    auto spatial_strides_tensor = torch::tensor(spatial_strides, torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
+
+    // Initialize
+    int block = CDT_BLOCK_SIZE;
+    int grid = (total_elements + block - 1) / block;
+
+    cdt_init_kernel<<<grid, block>>>(
+        input.data_ptr<float>(),
+        dist.data_ptr<int32_t>(),
+        return_indices ? indices.data_ptr<int32_t>() : nullptr,
+        total_elements,
+        spatial_ndim,
+        spatial_elements,
+        spatial_strides_tensor.data_ptr<int64_t>(),
+        return_indices
+    );
+
+    // For each dimension, do forward and backward sweeps
+    for (int d = 0; d < spatial_ndim; d++) {
+        int64_t line_length = spatial_shape[d];
+        int64_t line_stride = spatial_strides[d];
+        int64_t num_lines = batch_size * (spatial_elements / line_length);
+
+        int sweep_block = CDT_BLOCK_SIZE;
+        int sweep_grid = (num_lines + sweep_block - 1) / sweep_block;
+
+        // Forward sweep
+        cdt_sweep_forward_chessboard_kernel<<<sweep_grid, sweep_block>>>(
+            dist.data_ptr<int32_t>(),
+            return_indices ? indices.data_ptr<int32_t>() : nullptr,
+            total_elements,
+            num_lines,
+            line_stride,
+            line_length,
+            spatial_elements,
+            spatial_elements,
+            spatial_ndim,
+            d,
+            spatial_strides_tensor.data_ptr<int64_t>(),
+            spatial_shape_tensor.data_ptr<int64_t>(),
+            return_indices
+        );
+
+        // Backward sweep
+        cdt_sweep_backward_chessboard_kernel<<<sweep_grid, sweep_block>>>(
+            dist.data_ptr<int32_t>(),
+            return_indices ? 
indices.data_ptr() : nullptr, + total_elements, + num_lines, + line_stride, + line_length, + spatial_elements, + spatial_elements, + spatial_ndim, + d, + spatial_strides_tensor.data_ptr(), + spatial_shape_tensor.data_ptr(), + return_indices + ); + } + + // For chessboard metric, we need additional diagonal passes + if (!is_taxicab && spatial_ndim >= 2) { + // Generate diagonal offsets + std::vector diagonal_offsets; + + // All neighbors in 3^ndim hypercube except axis-aligned ones + int total_combos = 1; + for (int d = 0; d < spatial_ndim; d++) total_combos *= 3; + + for (int i = 0; i < total_combos; i++) { + int temp = i; + int64_t offset = 0; + int non_zero_count = 0; + + for (int d = 0; d < spatial_ndim; d++) { + int dir = (temp % 3) - 1; + temp /= 3; + if (dir != 0) non_zero_count++; + offset += dir * spatial_strides[d]; + } + + // Only include diagonal neighbors (more than one non-zero direction) + if (non_zero_count >= 2) { + diagonal_offsets.push_back((int32_t)offset); + } + } + + if (!diagonal_offsets.empty()) { + auto offsets_tensor = torch::tensor(diagonal_offsets, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); + + // Multiple passes to propagate diagonal distances + // Need more passes for higher dimensions to ensure full propagation + int num_passes = spatial_ndim * 2; // Scale with dimensions + for (int pass = 0; pass < num_passes; pass++) { + cdt_diagonal_pass_kernel<<>>( + dist.data_ptr(), + return_indices ? 
indices.data_ptr() : nullptr, + total_elements, + batch_size, + spatial_elements, + spatial_ndim, + spatial_strides_tensor.data_ptr(), + spatial_shape_tensor.data_ptr(), + offsets_tensor.data_ptr(), + diagonal_offsets.size(), + return_indices, + pass % 2 == 0 + ); + } + } + } + + // Prepare output + torch::Tensor result_dist; + torch::Tensor result_indices; + + if (return_distances) { + result_dist = dist.to(torch::kFloat32).view(input.sizes()); + } + + if (return_indices) { + std::vector idx_shape = {spatial_ndim}; + for (int d = 0; d < total_ndim; d++) { + idx_shape.push_back(shape_vec[d]); + } + result_indices = indices.view(idx_shape); + } + + return std::make_tuple(result_dist, result_indices); +} diff --git a/torchmorph/csrc/distance_transform_edt.cu b/torchmorph/csrc/distance_transform_edt.cu new file mode 100644 index 0000000..d021c62 --- /dev/null +++ b/torchmorph/csrc/distance_transform_edt.cu @@ -0,0 +1,1598 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================== +// Configuration +// ============================================================================== +#define INF_VAL 1e20f +#define MAX_THREADS 256 +#define SHARED_MEM_LIMIT 2048 // Max dimension size for shared memory path (48KB limit) + +// JFA Configuration +#define BLOCK_SIZE 256 +#define SMEM_LIMIT_ELEMENTS 4096 +#define JFA_BLOCK_DIM 32 +#define JFA_FUSED_STEPS 4 +#define JFA_MAX_OFFSET 8 +#define JFA_SMEM_DIM (JFA_BLOCK_DIM + 2 * JFA_MAX_OFFSET) +#define JFA_3D_BLOCK 8 +#define JFA_3D_HALO 1 + +// ============================================================================== +// JFA Device Helpers +// ============================================================================== +__device__ __forceinline__ float sqr(float x) { return x * x; } + +__device__ __forceinline__ float dist_sq_2d(int y1, int x1, int y2, int x2) { + return sqr((float)(y1 - y2)) + 
sqr((float)(x1 - x2)); +} + +__device__ __forceinline__ float dist_sq_3d_soa(int z1, int y1, int x1, int z2, int y2, int x2) { + if (z2 == -1) return INF_VAL; + float dz = (float)(z1 - z2); + float dy = (float)(y1 - y2); + float dx = (float)(x1 - x2); + return dz*dz + dy*dy + dx*dx; +} + +__device__ __forceinline__ float compute_cost(int q, int p, float val_p) { + if (p < 0 || val_p >= INF_VAL) return INF_VAL; + return sqr((float)q - (float)p) + val_p; +} + +__device__ __forceinline__ float dist_sq_int2(int y, int x, int2 seed) { + if (seed.x == -1) return INF_VAL; + float dy = (float)(y - seed.x); + float dx = (float)(x - seed.y); + return dy*dy + dx*dx; +} + +// ============================================================================== +// JFA 2D Kernels (Vectorized int2 + Block Shared) +// ============================================================================== +__global__ void init_jfa_kernel_2d_opt( + const float* __restrict__ input, + int2* __restrict__ output, + int64_t total_elements, + int H, int W +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + if (input[tid] == 0.0f) { + int64_t spatial_size = (int64_t)H * W; + int64_t rem = tid % spatial_size; + int w = (int)(rem % W); + int h = (int)(rem / W); + output[tid] = make_int2(h, w); + } else { + output[tid] = make_int2(-1, -1); + } +} + +__global__ void jfa_block_fused_kernel_2d( + const int2* __restrict__ in_idx, + int2* __restrict__ out_idx, + int H, int W, + int64_t num_images +) { + __shared__ int2 smem[JFA_SMEM_DIM][JFA_SMEM_DIM]; + + int tx = threadIdx.x; + int ty = threadIdx.y; + + int bx = blockIdx.x * blockDim.x; + int by = blockIdx.y * blockDim.y; + int img_idx = blockIdx.z; + int64_t batch_offset = (int64_t)img_idx * (H * W); + + int gx = bx + tx; + int gy = by + ty; + + // Phase 1: load data to Shared Memory + int smem_linear_size = JFA_SMEM_DIM * JFA_SMEM_DIM; + int total_threads = blockDim.x * blockDim.y; + int thread_linear_idx = ty 
* blockDim.x + tx; + + int base_x = bx - JFA_MAX_OFFSET; + int base_y = by - JFA_MAX_OFFSET; + + for (int i = thread_linear_idx; i < smem_linear_size; i += total_threads) { + int s_y = i / JFA_SMEM_DIM; + int s_x = i % JFA_SMEM_DIM; + int global_y = base_y + s_y; + int global_x = base_x + s_x; + int2 val = make_int2(-1, -1); + if (global_y >= 0 && global_y < H && global_x >= 0 && global_x < W) { + val = in_idx[batch_offset + global_y * W + global_x]; + } + smem[s_y][s_x] = val; + } + __syncthreads(); + + // Phase 2: Iterate in Shared Memory + if (gx < W && gy < H) { + int center_sy = ty + JFA_MAX_OFFSET; + int center_sx = tx + JFA_MAX_OFFSET; + + int2 best_seed = smem[center_sy][center_sx]; + float best_dist = dist_sq_int2(gy, gx, best_seed); + + int step = 1; + #pragma unroll + for (int k = 0; k < JFA_FUSED_STEPS; ++k) { + #pragma unroll + for (int dy = -1; dy <= 1; ++dy) { + #pragma unroll + for (int dx = -1; dx <= 1; ++dx) { + if (dy == 0 && dx == 0) continue; + int2 neighbor_seed = smem[center_sy + dy * step][center_sx + dx * step]; + if (neighbor_seed.x != -1) { + float d = dist_sq_int2(gy, gx, neighbor_seed); + if (d < best_dist) { + best_dist = d; + best_seed = neighbor_seed; + } + } + } + } + __syncthreads(); + smem[center_sy][center_sx] = best_seed; + __syncthreads(); + step *= 2; + } + out_idx[batch_offset + gy * W + gx] = best_seed; + } +} + +__global__ void jfa_step_global_2d_opt( + const int2* __restrict__ in_idx, + int2* __restrict__ out_idx, + int step, + int H, int W, + int64_t total_pixels +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_pixels) return; + + int64_t spatial_size = (int64_t)H * W; + int64_t rem = tid % spatial_size; + int64_t batch_offset = tid - rem; + int w = (int)(rem % W); + int h = (int)(rem / W); + + int2 best_seed = in_idx[tid]; + float best_dist = dist_sq_int2(h, w, best_seed); + + #pragma unroll + for (int dy = -1; dy <= 1; ++dy) { + #pragma unroll + for (int dx = -1; dx <= 1; ++dx) { + if (dx 
== 0 && dy == 0) continue; + + int ny = h + dy * step; + int nx = w + dx * step; + + if (ny >= 0 && ny < H && nx >= 0 && nx < W) { + int2 neighbor_seed = in_idx[batch_offset + ny * W + nx]; + if (neighbor_seed.x != -1) { + float d = dist_sq_int2(h, w, neighbor_seed); + if (d < best_dist) { + best_dist = d; + best_seed = neighbor_seed; + } + } + } + } + } + out_idx[tid] = best_seed; +} + +__global__ void calc_dist_kernel_2d_opt( + const int2* __restrict__ indices, + float* __restrict__ dist_out, + int64_t total_elements, + int H, int W +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + int2 s = indices[tid]; + if (s.x == -1) { + dist_out[tid] = INF_VAL; + } else { + int64_t spatial_size = (int64_t)H * W; + int64_t rem = tid % spatial_size; + int cur_w = (int)(rem % W); + int cur_h = (int)(rem / W); + dist_out[tid] = sqrtf(dist_sq_int2(cur_h, cur_w, s)); + } +} + +// ============================================================================== +// JFA 3D Kernels (Optimized SoA Layout) +// ============================================================================== +template +__global__ void init_jfa_kernel_3d_soa( + const float* __restrict__ input, + IndexType* __restrict__ indices_z, + IndexType* __restrict__ indices_y, + IndexType* __restrict__ indices_x, + int64_t total_elements, + int D, int H, int W +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + if (input[tid] == 0.0f) { + int64_t spatial_size = (int64_t)D * H * W; + int64_t rem = tid % spatial_size; + int w = (int)(rem % W); + int h = (int)((rem / W) % H); + int d = (int)(rem / (W * H)); + + indices_z[tid] = (IndexType)d; + indices_y[tid] = (IndexType)h; + indices_x[tid] = (IndexType)w; + } else { + indices_z[tid] = (IndexType)-1; + indices_y[tid] = (IndexType)-1; + indices_x[tid] = (IndexType)-1; + } +} + +template +__global__ void jfa_block_fused_kernel_3d_soa( + const IndexType* __restrict__ in_z, + 
const IndexType* __restrict__ in_y, + const IndexType* __restrict__ in_x, + IndexType* __restrict__ out_z, + IndexType* __restrict__ out_y, + IndexType* __restrict__ out_x, + int D, int H, int W, + int blocks_per_d +) { + const int BLOCK_DIM = 8; + const int HALO = 3; + const int SMEM_DIM = BLOCK_DIM + 2 * HALO; // 14 + const int SMEM_SIZE = SMEM_DIM * SMEM_DIM * SMEM_DIM; + + extern __shared__ char smem_raw[]; + IndexType* smem_z = (IndexType*)smem_raw; + IndexType* smem_y = smem_z + SMEM_SIZE; + IndexType* smem_x = smem_y + SMEM_SIZE; + + int tx = threadIdx.x; int ty = threadIdx.y; int tz = threadIdx.z; + + int b_z_total = blockIdx.z; + int batch_id = b_z_total / blocks_per_d; + int b_z_local = b_z_total % blocks_per_d; + + int bx = blockIdx.x * BLOCK_DIM; + int by = blockIdx.y * BLOCK_DIM; + int bz = b_z_local * BLOCK_DIM; + + int64_t spatial_offset = (int64_t)batch_id * (D * H * W); + + // Phase 1: Load to SoA Shared Memory + int tid = tz * 64 + ty * 8 + tx; + int base_x = bx - HALO; + int base_y = by - HALO; + int base_z = bz - HALO; + + for (int i = tid; i < SMEM_SIZE; i += 512) { + int temp = i; + int sx = temp % SMEM_DIM; temp /= SMEM_DIM; + int sy = temp % SMEM_DIM; + int sz = temp / SMEM_DIM; + + int gx = base_x + sx; + int gy = base_y + sy; + int gz = base_z + sz; + + IndexType val_z = -1, val_y = -1, val_x = -1; + if (gz >= 0 && gz < D && gy >= 0 && gy < H && gx >= 0 && gx < W) { + int64_t idx = spatial_offset + (int64_t)gz * (H * W) + gy * W + gx; + val_z = in_z[idx]; + val_y = in_y[idx]; + val_x = in_x[idx]; + } + smem_z[i] = val_z; + smem_y[i] = val_y; + smem_x[i] = val_x; + } + __syncthreads(); + + // Phase 2: Compute + int center_sz = tz + HALO; + int center_sy = ty + HALO; + int center_sx = tx + HALO; + int my_s_idx = (center_sz * SMEM_DIM + center_sy) * SMEM_DIM + center_sx; + + int best_z = (int)smem_z[my_s_idx]; + int best_y = (int)smem_y[my_s_idx]; + int best_x = (int)smem_x[my_s_idx]; + + int g_cz = bz + tz; + int g_cy = by + ty; + int g_cx = 
bx + tx; + + float best_dist = dist_sq_3d_soa(g_cz, g_cy, g_cx, best_z, best_y, best_x); + + int step = 1; + #pragma unroll + for (int k = 0; k < 2; ++k) { + #pragma unroll + for (int dz = -1; dz <= 1; ++dz) { + #pragma unroll + for (int dy = -1; dy <= 1; ++dy) { + #pragma unroll + for (int dx = -1; dx <= 1; ++dx) { + if (dz == 0 && dy == 0 && dx == 0) continue; + + int nz = center_sz + dz * step; + int ny = center_sy + dy * step; + int nx = center_sx + dx * step; + int n_idx = (nz * SMEM_DIM + ny) * SMEM_DIM + nx; + + int sz_in = (int)smem_z[n_idx]; + if (sz_in != -1) { + int sy_in = (int)smem_y[n_idx]; + int sx_in = (int)smem_x[n_idx]; + float d = dist_sq_3d_soa(g_cz, g_cy, g_cx, sz_in, sy_in, sx_in); + if (d < best_dist) { + best_dist = d; + best_z = sz_in; + best_y = sy_in; + best_x = sx_in; + } + } + } + } + } + __syncthreads(); + smem_z[my_s_idx] = (IndexType)best_z; + smem_y[my_s_idx] = (IndexType)best_y; + smem_x[my_s_idx] = (IndexType)best_x; + __syncthreads(); + step *= 2; + } + + if (g_cz < D && g_cy < H && g_cx < W) { + int64_t out_idx_g = spatial_offset + (int64_t)g_cz * (H * W) + g_cy * W + g_cx; + out_z[out_idx_g] = (IndexType)best_z; + out_y[out_idx_g] = (IndexType)best_y; + out_x[out_idx_g] = (IndexType)best_x; + } +} + +template +__global__ void jfa_step_3d_soa( + const IndexType* __restrict__ in_z, + const IndexType* __restrict__ in_y, + const IndexType* __restrict__ in_x, + IndexType* __restrict__ out_z, + IndexType* __restrict__ out_y, + IndexType* __restrict__ out_x, + int step, + int D, int H, int W, + int64_t total_pixels +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_pixels) return; + + int64_t spatial_size = (int64_t)D * H * W; + int64_t rem = tid % spatial_size; + int64_t batch_offset = tid - rem; + int cur_w = (int)(rem % W); + int cur_h = (int)((rem / W) % H); + int cur_d = (int)(rem / (W * H)); + + int best_z = (int)in_z[tid]; + int best_y = (int)in_y[tid]; + int best_x = (int)in_x[tid]; + + float 
best_dist = dist_sq_3d_soa(cur_d, cur_h, cur_w, best_z, best_y, best_x); + + #pragma unroll + for (int dz = -1; dz <= 1; ++dz) { + #pragma unroll + for (int dy = -1; dy <= 1; ++dy) { + #pragma unroll + for (int dx = -1; dx <= 1; ++dx) { + if (dz == 0 && dy == 0 && dx == 0) continue; + + int nz = cur_d + dz * step; + int ny = cur_h + dy * step; + int nx = cur_w + dx * step; + + if (nz >= 0 && nz < D && ny >= 0 && ny < H && nx >= 0 && nx < W) { + int64_t n_idx = batch_offset + (int64_t)nz * (H * W) + ny * W + nx; + + int seed_z = (int)in_z[n_idx]; + if (seed_z != -1) { + float dz_val = (float)(cur_d - seed_z); + float dz_sq = dz_val * dz_val; + + if (dz_sq < best_dist) { + int seed_y = (int)in_y[n_idx]; + int seed_x = (int)in_x[n_idx]; + float dist = dz_sq + sqr((float)(cur_h - seed_y)) + sqr((float)(cur_w - seed_x)); + + if (dist < best_dist) { + best_dist = dist; + best_z = seed_z; + best_y = seed_y; + best_x = seed_x; + } + } + } + } + } + } + } + out_z[tid] = (IndexType)best_z; + out_y[tid] = (IndexType)best_y; + out_x[tid] = (IndexType)best_x; +} + +template +__global__ void calc_dist_kernel_3d_soa( + const IndexType* __restrict__ in_z, + const IndexType* __restrict__ in_y, + const IndexType* __restrict__ in_x, + float* __restrict__ dist_out, + int64_t total_elements, + int D, int H, int W +) { + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + int seed_d = (int)in_z[tid]; + if (seed_d == -1) { + dist_out[tid] = INF_VAL; + } else { + int seed_h = (int)in_y[tid]; + int seed_w = (int)in_x[tid]; + + int64_t spatial_size = (int64_t)D * H * W; + int64_t rem = tid % spatial_size; + int cur_w = (int)(rem % W); + int cur_h = (int)((rem / W) % H); + int cur_d = (int)(rem / (W * H)); + + dist_out[tid] = sqrtf(dist_sq_3d_soa(cur_d, cur_h, cur_w, seed_d, seed_h, seed_w)); + } +} + +// ============================================================================== +// 2D Optimized: Initialization kernel +// 
// ==============================================================================
// Per-pixel init: 0 at background (input == 0), INF at foreground; optionally
// seed the index maps with each pixel's own (y, x).
__global__ void init_distance_2d_kernel(
    const float* __restrict__ input,
    float* __restrict__ distance,
    int* __restrict__ indices_y,
    int* __restrict__ indices_x,
    int height,
    int width,
    int64_t batch_stride,
    bool compute_indices
) {
    int64_t batch_idx = blockIdx.z;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    int x = blockIdx.x * blockDim.x + threadIdx.x;

    if (y >= height || x >= width) return;

    int64_t idx = batch_idx * batch_stride + y * width + x;

    float val = input[idx];
    distance[idx] = (val != 0.0f) ? INF_VAL : 0.0f;

    if (compute_indices) {
        indices_y[idx] = y;
        indices_x[idx] = x;
    }
}

// ==============================================================================
// 2D Optimized: Row-wise EDT (X direction) - contiguous access
// Each block processes one row; thread 0 builds the lower envelope, then all
// threads fill in parallel via binary search over the envelope intervals.
// ==============================================================================
__global__ void edt_2d_rows_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int* __restrict__ input_idx_y,
    const int* __restrict__ input_idx_x,
    int* __restrict__ output_idx_y,
    int* __restrict__ output_idx_x,
    int height,
    int width,
    int64_t batch_stride,
    float spacing,
    bool compute_indices
) {
    // blockIdx.x = batch_idx * height + row_idx
    int64_t linear_idx = blockIdx.x;
    int row_idx = linear_idx % height;
    int64_t batch_idx = linear_idx / height;

    int64_t row_base = batch_idx * batch_stride + row_idx * width;

    extern __shared__ char shared_mem[];
    float* v_val = (float*)shared_mem;
    int* v_idx = (int*)(v_val + width);
    float* z = (float*)(v_idx + width);

    int tid = threadIdx.x;
    int num_threads = blockDim.x;

    // Load row into shared memory (contiguous access - optimal)
    for (int i = tid; i < width; i += num_threads) {
        v_val[i] = input[row_base + i];
    }
    __syncthreads();

    // Build lower envelope (thread 0 only)
    __shared__ int k_shared;

    if (tid == 0) {
        int k = -1;

        for (int q = 0; q < width; q++) {
            float fq = v_val[q];
            if (fq >= INF_VAL * 0.5f) continue;  // skip "infinite" parabolas

            float q_pos = (float)q * spacing;
            float q_pos_sq = q_pos * q_pos;

            // Pop parabolas dominated by the new one.
            while (k >= 0) {
                int vk = v_idx[k];
                float vk_pos = (float)vk * spacing;
                float fvk = v_val[vk];
                float s = ((fq + q_pos_sq) - (fvk + vk_pos * vk_pos)) / (2.0f * (q_pos - vk_pos));

                if (s > z[k]) break;
                k--;
            }

            k++;
            v_idx[k] = q;

            if (k == 0) {
                z[0] = -INF_VAL;
            } else {
                int vk_prev = v_idx[k - 1];
                float vk_prev_pos = (float)vk_prev * spacing;
                float fvk_prev = v_val[vk_prev];
                z[k] = ((fq + q_pos_sq) - (fvk_prev + vk_prev_pos * vk_prev_pos)) /
                       (2.0f * (q_pos - vk_prev_pos));
            }
            z[k + 1] = INF_VAL;
        }
        k_shared = k;
    }
    __syncthreads();

    int k = k_shared;

    // Parallel fill with binary search
    for (int q = tid; q < width; q += num_threads) {
        int64_t out_idx = row_base + q;

        if (k < 0) {
            // No finite source in this row.
            output[out_idx] = INF_VAL;
            if (compute_indices) {
                output_idx_y[out_idx] = row_idx;
                output_idx_x[out_idx] = 0;
            }
        } else {
            float q_pos = (float)q * spacing;

            // Binary search for the envelope interval containing q_pos.
            int lo = 0, hi = k;
            while (lo < hi) {
                int mid = (lo + hi + 1) / 2;
                if (z[mid] <= q_pos) lo = mid;
                else hi = mid - 1;
            }

            int nearest = v_idx[lo];
            float nearest_pos = (float)nearest * spacing;
            float diff = q_pos - nearest_pos;
            float dist_sq = diff * diff + v_val[nearest];

            output[out_idx] = dist_sq;  // Keep squared for next pass

            if (compute_indices) {
                output_idx_y[out_idx] = row_idx;  // Y unchanged in X pass
                output_idx_x[out_idx] = nearest;
            }
        }
    }
}

// ==============================================================================
// 2D Optimized: Column-wise EDT (Y direction) - strided access with shared memory
// Each block processes one column
// ==============================================================================
__global__ void edt_2d_cols_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int* __restrict__ input_idx_y,
    const int* __restrict__ input_idx_x,
    int* __restrict__ output_idx_y,
    int* __restrict__ output_idx_x,
    int height,
    int width,
    int64_t batch_stride,
    float spacing,
    bool is_final,
    bool compute_indices
) {
    // blockIdx.x = batch_idx * width + col_idx
    int64_t linear_idx = blockIdx.x;
    int col_idx = linear_idx % width;
    int64_t batch_idx = linear_idx / width;

    int64_t col_base = batch_idx * batch_stride + col_idx;
    int stride = width;  // Stride to next row

    extern __shared__ char shared_mem[];
    float* v_val = (float*)shared_mem;
    int* v_idx = (int*)(v_val + height);
    float* z = (float*)(v_idx + height);
    int* src_x = (int*)(z + height + 1);  // Source X indices for index propagation

    int tid = threadIdx.x;
    int num_threads = blockDim.x;

    // Load column into shared memory (strided access - but only once)
    for (int i = tid; i < height; i += num_threads) {
        v_val[i] = input[col_base + i * stride];
        if (compute_indices) {
            src_x[i] = input_idx_x[col_base + i * stride];
        }
    }
    __syncthreads();

    // Build lower envelope (thread 0 only)
    __shared__ int k_shared;

    if (tid == 0) {
        int k = -1;

        for (int q = 0; q < height; q++) {
            float fq = v_val[q];
            if (fq >= INF_VAL * 0.5f) continue;

            float q_pos = (float)q * spacing;
            float q_pos_sq = q_pos * q_pos;

            while (k >= 0) {
                int vk = v_idx[k];
                float vk_pos = (float)vk * spacing;
                float fvk = v_val[vk];
                float s = ((fq + q_pos_sq) - (fvk + vk_pos * vk_pos)) / (2.0f * (q_pos - vk_pos));

                if (s > z[k]) break;
                k--;
            }

            k++;
            v_idx[k] = q;

            if (k == 0) {
                z[0] = -INF_VAL;
            } else {
                int vk_prev = v_idx[k - 1];
                float vk_prev_pos = (float)vk_prev * spacing;
                float fvk_prev = v_val[vk_prev];
                z[k] = ((fq + q_pos_sq) - (fvk_prev + vk_prev_pos * vk_prev_pos)) /
                       (2.0f * (q_pos - vk_prev_pos));
            }
            z[k + 1] = INF_VAL;
        }
        k_shared = k;
    }
    __syncthreads();

    int k = k_shared;

    // Parallel fill with binary search
    for (int q = tid; q < height; q += num_threads) {
        int64_t out_idx = col_base + q * stride;

        if (k < 0) {
            output[out_idx] = INF_VAL;
            if (compute_indices) {
                output_idx_y[out_idx] = 0;
                output_idx_x[out_idx] = col_idx;
            }
        } else {
            float q_pos = (float)q * spacing;

            int lo = 0, hi = k;
            while (lo < hi) {
                int mid = (lo + hi + 1) / 2;
                if (z[mid] <= q_pos) lo = mid;
                else hi = mid - 1;
            }

            int nearest = v_idx[lo];
            float nearest_pos = (float)nearest * spacing;
            float diff = q_pos - nearest_pos;
            float dist_sq = diff * diff + v_val[nearest];

            // Final pass converts squared distances to Euclidean.
            output[out_idx] = is_final ? sqrtf(dist_sq) : dist_sq;

            if (compute_indices) {
                output_idx_y[out_idx] = nearest;
                output_idx_x[out_idx] = src_x[nearest];  // Propagate X from source
            }
        }
    }
}

// ==============================================================================
// 2D Optimized: Host function (shared memory only, for dimensions <= 2048)
// Returns (distance, indices); `indices` is undefined when return_indices is
// false. NOTE(review): <<<...>>> launch configs and data_ptr<T> specializations
// were stripped in extraction and are reconstructed here — confirm.
// ==============================================================================
std::tuple<torch::Tensor, torch::Tensor> run_edt_2d_optimized(
    torch::Tensor input,
    float spacing_y,
    float spacing_x,
    bool return_indices
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");

    input = input.contiguous();

    int total_ndim = input.dim();
    TORCH_CHECK(total_ndim >= 2, "Input must have at least 2 dimensions");

    auto shape = input.sizes().vec();
    int height = shape[total_ndim - 2];
    int width = shape[total_ndim - 1];
    int64_t batch_stride = (int64_t)height * width;
    int64_t batch_size = input.numel() / batch_stride;

    // This function should only be called when both dimensions fit in shared memory
    TORCH_CHECK(height <= SHARED_MEM_LIMIT && width <= SHARED_MEM_LIMIT,
                "Dimensions too large for 2D optimized path, use general N-D version");

    // Create output tensors
    auto distance = torch::empty_like(input);
    auto temp = torch::empty_like(input);

    torch::Tensor indices_y, indices_x, temp_idx_y, temp_idx_x;
    if (return_indices) {
        indices_y = torch::empty_like(input, input.options().dtype(torch::kInt32));
        indices_x = torch::empty_like(input, input.options().dtype(torch::kInt32));
        temp_idx_y = torch::empty_like(indices_y);
        temp_idx_x = torch::empty_like(indices_x);
    }

    // Step 1: Initialize
    {
        dim3 block(16, 16);
        dim3 grid((width + 15) / 16, (height + 15) / 16, batch_size);

        init_distance_2d_kernel<<<grid, block>>>(
            input.data_ptr<float>(),
            distance.data_ptr<float>(),
            return_indices ? indices_y.data_ptr<int>() : nullptr,
            return_indices ? indices_x.data_ptr<int>() : nullptr,
            height, width, batch_stride,
            return_indices
        );
    }

    // Step 2: Row-wise EDT (X direction) - shared memory
    {
        int64_t num_rows = batch_size * height;
        int threads = std::min(width, MAX_THREADS);
        size_t shared_mem_size = width * sizeof(float) +        // v_val
                                 width * sizeof(int) +          // v_idx
                                 (width + 1) * sizeof(float);   // z

        edt_2d_rows_kernel<<<num_rows, threads, shared_mem_size>>>(
            distance.data_ptr<float>(),
            temp.data_ptr<float>(),
            return_indices ? indices_y.data_ptr<int>() : nullptr,
            return_indices ? indices_x.data_ptr<int>() : nullptr,
            return_indices ? temp_idx_y.data_ptr<int>() : nullptr,
            return_indices ? temp_idx_x.data_ptr<int>() : nullptr,
            height, width, batch_stride,
            spacing_x,
            return_indices
        );
    }

    // Step 3: Column-wise EDT (Y direction) - shared memory
    {
        int64_t num_cols = batch_size * width;
        int threads = std::min(height, MAX_THREADS);
        size_t shared_mem_size = height * sizeof(float) +       // v_val
                                 height * sizeof(int) +         // v_idx
                                 (height + 1) * sizeof(float);  // z
        if (return_indices) {
            shared_mem_size += height * sizeof(int);            // src_x
        }

        edt_2d_cols_kernel<<<num_cols, threads, shared_mem_size>>>(
            temp.data_ptr<float>(),
            distance.data_ptr<float>(),
            return_indices ? temp_idx_y.data_ptr<int>() : nullptr,
            return_indices ? temp_idx_x.data_ptr<int>() : nullptr,
            return_indices ? indices_y.data_ptr<int>() : nullptr,
            return_indices ? indices_x.data_ptr<int>() : nullptr,
            height, width, batch_stride,
            spacing_y,
            true,  // is_final
            return_indices
        );
    }

    // Combine indices into single tensor with shape [2, ...]
    torch::Tensor indices;
    if (return_indices) {
        std::vector<int64_t> idx_shape = {2};
        for (auto s : shape) idx_shape.push_back(s);
        indices = torch::empty(idx_shape, input.options().dtype(torch::kInt32));

        // Copy Y and X indices
        indices.select(0, 0).copy_(indices_y);
        indices.select(0, 1).copy_(indices_x);
    }

    return std::make_tuple(distance, indices);
}

// ==============================================================================
// 1D EDT kernel using GLOBAL memory (for large dimensions)
// Same envelope algorithm as edt_1d_kernel below, but scratch arrays live in
// caller-provided global buffers instead of shared memory.
// ==============================================================================
__global__ void edt_1d_kernel_global(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int* __restrict__ input_idx,
    int* __restrict__ output_idx,
    float* __restrict__ g_v_val,
    int* __restrict__ g_v_idx,
    float* __restrict__ g_z,
    int* __restrict__ g_k,
    int64_t num_slices,
    int64_t slice_len,
    int64_t num_pixels,
    int spatial_ndim,
    int current_dim,
    float spacing,
    bool is_final,
    bool compute_indices
) {
    int64_t slice_idx = blockIdx.x;
    if (slice_idx >= num_slices) return;

    int64_t base_offset = slice_idx * slice_len;

    // Per-slice scratch regions in the global buffers.
    float* v_val = g_v_val + base_offset;
    int* v_idx = g_v_idx + base_offset;
    float* z = g_z + slice_idx * (slice_len + 1);
    int* k_ptr = g_k + slice_idx;

    int tid = threadIdx.x;
    int num_threads = blockDim.x;

    // Load input values
    for (int i = tid; i < slice_len; i += num_threads) {
        v_val[i] = input[base_offset + i];
    }
    __syncthreads();

    // Build lower envelope (thread 0 only)
    if (tid == 0) {
        int k = -1;

        for (int q = 0; q < slice_len; q++) {
            float fq = v_val[q];
            if (fq >= INF_VAL * 0.5f) continue;

            float q_pos = (float)q * spacing;
            float q_pos_sq = q_pos * q_pos;

            while (k >= 0) {
                int vk = v_idx[k];
                float vk_pos = (float)vk * spacing;
                float fvk = v_val[vk];
                float s = ((fq + q_pos_sq) - (fvk + vk_pos * vk_pos)) / (2.0f * (q_pos - vk_pos));

                if (s > z[k]) break;
                k--;
            }

            k++;
            v_idx[k] = q;

            if (k == 0) {
                z[0] = -INF_VAL;
            } else {
                int vk_prev = v_idx[k - 1];
                float vk_prev_pos = (float)vk_prev * spacing;
                float fvk_prev = v_val[vk_prev];
                z[k] = ((fq + q_pos_sq) - (fvk_prev + vk_prev_pos * vk_prev_pos)) /
                       (2.0f * (q_pos - vk_prev_pos));
            }
            z[k + 1] = INF_VAL;
        }
        *k_ptr = k;
    }
    __syncthreads();

    int k = *k_ptr;

    // Parallel fill with binary search
    for (int q = tid; q < slice_len; q += num_threads) {
        int64_t out_idx = base_offset + q;

        if (k < 0) {
            output[out_idx] = INF_VAL;
            if (compute_indices) {
                for (int d = 0; d < spatial_ndim; d++) {
                    output_idx[d * num_pixels + out_idx] = 0;
                }
            }
        } else {
            float q_pos = (float)q * spacing;

            int lo = 0, hi = k;
            while (lo < hi) {
                int mid = (lo + hi + 1) / 2;
                if (z[mid] <= q_pos) lo = mid;
                else hi = mid - 1;
            }

            int nearest = v_idx[lo];
            float nearest_pos = (float)nearest * spacing;
            float diff = q_pos - nearest_pos;
            float dist_sq = diff * diff + v_val[nearest];

            output[out_idx] = is_final ? sqrtf(dist_sq) : dist_sq;

            if (compute_indices) {
                int64_t src_idx = base_offset + nearest;
                for (int d = 0; d < spatial_ndim; d++) {
                    if (d == current_dim) {
                        output_idx[d * num_pixels + out_idx] = nearest;
                    } else {
                        output_idx[d * num_pixels + out_idx] = input_idx[d * num_pixels + src_idx];
                    }
                }
            }
        }
    }
}

// ==============================================================================
// 1D Euclidean Distance Transform (Felzenszwalb & Huttenlocher)
// ==============================================================================
__global__ void edt_1d_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    const int* __restrict__ input_idx,
    int* __restrict__ output_idx,
    int64_t num_slices,
    int64_t slice_len,
    int64_t num_pixels,
    int spatial_ndim,
    int current_dim,
    float spacing,
    bool is_final,
    bool compute_indices
) {
    int64_t slice_idx = blockIdx.x;
    if (slice_idx >= num_slices) return;

    int64_t base_offset = slice_idx * slice_len;

    extern __shared__ char shared_mem[];
    float* v_val = (float*)shared_mem;
    int* v_idx = (int*)(v_val + slice_len);
    float* z = (float*)(v_idx + slice_len);

    int tid = threadIdx.x;
    int num_threads = blockDim.x;

    // Load input values into shared memory
    for (int i = tid; i < slice_len; i += num_threads) {
        v_val[i] = input[base_offset + i];
    }
    __syncthreads();

    // Build lower envelope (thread 0 only)
    __shared__ int k_shared;

    if (tid == 0) {
        int k = -1;

        for (int q = 0; q < slice_len; q++) {
            float fq = v_val[q];
            if (fq >= INF_VAL * 0.5f) continue;

            float q_pos = (float)q * spacing;
            float q_pos_sq = q_pos * q_pos;

            while (k >= 0) {
                int vk = v_idx[k];
                float vk_pos = (float)vk * spacing;
                float fvk = v_val[vk];
                float s = ((fq + q_pos_sq) - (fvk + vk_pos * vk_pos)) / (2.0f * (q_pos - vk_pos));

                if (s > z[k]) break;
                k--;
            }

            k++;
            v_idx[k] = q;

            if (k == 0) {
                z[0] = -INF_VAL;
            } else {
                int vk_prev = v_idx[k -
1];
                float vk_prev_pos = (float)vk_prev * spacing;
                float fvk_prev = v_val[vk_prev];
                z[k] = ((fq + q_pos_sq) - (fvk_prev + vk_prev_pos * vk_prev_pos)) /
                       (2.0f * (q_pos - vk_prev_pos));
            }
            z[k + 1] = INF_VAL;
        }
        k_shared = k;
    }
    __syncthreads();

    int k = k_shared;

    // Parallel fill
    for (int q = tid; q < slice_len; q += num_threads) {
        int64_t out_idx = base_offset + q;

        if (k < 0) {
            output[out_idx] = INF_VAL;
            if (compute_indices) {
                for (int d = 0; d < spatial_ndim; d++) {
                    output_idx[d * num_pixels + out_idx] = 0;
                }
            }
        } else {
            float q_pos = (float)q * spacing;

            // Binary search
            int lo = 0, hi = k;
            while (lo < hi) {
                int mid = (lo + hi + 1) / 2;
                if (z[mid] <= q_pos) lo = mid;
                else hi = mid - 1;
            }

            int nearest = v_idx[lo];
            float nearest_pos = (float)nearest * spacing;
            float diff = q_pos - nearest_pos;
            float dist_sq = diff * diff + v_val[nearest];

            output[out_idx] = is_final ? sqrtf(dist_sq) : dist_sq;

            if (compute_indices) {
                int64_t src_idx = base_offset + nearest;
                for (int d = 0; d < spatial_ndim; d++) {
                    if (d == current_dim) {
                        output_idx[d * num_pixels + out_idx] = nearest;
                    } else {
                        output_idx[d * num_pixels + out_idx] = input_idx[d * num_pixels + src_idx];
                    }
                }
            }
        }
    }
}

// ==============================================================================
// Initialization kernel: set up initial distances and indices
// ==============================================================================
__global__ void init_distance_kernel(
    const float* __restrict__ input,
    float* __restrict__ distance,
    int* __restrict__ indices,
    int64_t total_pixels,
    int total_ndim,
    int spatial_ndim,
    const int64_t* __restrict__ shape,
    bool compute_indices
) {
    int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_pixels) return;

    // Set distance: 0 for background (input == 0), INF for foreground (input != 0)
    float val = input[idx];
    distance[idx] = (val != 0.0f) ? INF_VAL : 0.0f;

    // Initialize indices to current coordinates
    if (compute_indices) {
        int64_t temp = idx;
        int coords[16];  // Support up to 16D

        // Compute coordinates from linear index
        for (int d = total_ndim - 1; d >= 0; d--) {
            int64_t dim_size = shape[d];
            coords[d] = temp % dim_size;
            temp /= dim_size;
        }

        // Store spatial coordinates
        int start_dim = total_ndim - spatial_ndim;
        for (int s = 0; s < spatial_ndim; s++) {
            indices[s * total_pixels + idx] = coords[start_dim + s];
        }
    }
}

// ==============================================================================
// Host function to run separable EDT: applies the 1D transform along each
// spatial dimension in turn (transposing the target dim to be innermost).
// NOTE(review): template args, <<<...>>> launch configs and data_ptr<T>
// specializations were stripped in extraction and are reconstructed — confirm.
// ==============================================================================
std::tuple<torch::Tensor, torch::Tensor> run_edt_separable(
    torch::Tensor input,
    const std::vector<float>& sampling,
    bool return_indices
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32");

    input = input.contiguous();

    int total_ndim = input.dim();
    int spatial_ndim = sampling.size();
    int start_dim = total_ndim - spatial_ndim;

    auto shape = input.sizes().vec();
    int64_t total_pixels = input.numel();

    // Create output tensors
    auto distance = torch::empty_like(input);
    torch::Tensor indices;
    if (return_indices) {
        std::vector<int64_t> idx_shape = {spatial_ndim};
        for (auto s : shape) idx_shape.push_back(s);
        indices = torch::empty(idx_shape, input.options().dtype(torch::kInt32));
    }

    // Copy shape to device
    auto shape_tensor = torch::tensor(std::vector<int64_t>(shape.begin(), shape.end()),
        torch::TensorOptions().dtype(torch::kInt64).device(input.device()));

    // Initialize distances and indices
    int threads = 256;
    int blocks = (total_pixels + threads - 1) / threads;

    init_distance_kernel<<<blocks, threads>>>(
        input.data_ptr<float>(),
        distance.data_ptr<float>(),
        return_indices ? indices.data_ptr<int>() : nullptr,
        total_pixels,
        total_ndim,
        spatial_ndim,
        shape_tensor.data_ptr<int64_t>(),
        return_indices
    );

    // Global memory buffers (allocated lazily for large dimensions)
    torch::Tensor g_v_val, g_v_idx, g_z, g_k;

    // Process each spatial dimension
    for (int dim_idx = 0; dim_idx < spatial_ndim; dim_idx++) {
        int actual_dim = start_dim + dim_idx;
        bool is_final = (dim_idx == spatial_ndim - 1);
        float spacing = sampling[dim_idx];

        // Transpose to make current dimension last
        auto dist_transposed = distance.transpose(actual_dim, total_ndim - 1).contiguous();
        auto dist_out = torch::empty_like(dist_transposed);

        torch::Tensor idx_transposed, idx_out;
        if (return_indices) {
            // Indices have an extra leading dimension
            idx_transposed = indices.transpose(actual_dim + 1, total_ndim).contiguous();
            idx_out = torch::empty_like(idx_transposed);
        }

        // Get dimensions after transpose
        int64_t slice_len = dist_transposed.size(-1);
        int64_t num_slices = dist_transposed.numel() / slice_len;

        int kernel_threads = std::min((int)slice_len, MAX_THREADS);

        // Choose between shared memory and global memory kernel
        bool use_shared = (slice_len <= SHARED_MEM_LIMIT);

        if (use_shared) {
            // Calculate shared memory size
            size_t shared_mem_size = slice_len * sizeof(float) +       // v_val
                                     slice_len * sizeof(int) +         // v_idx
                                     (slice_len + 1) * sizeof(float);  // z

            edt_1d_kernel<<<num_slices, kernel_threads, shared_mem_size>>>(
                dist_transposed.data_ptr<float>(),
                dist_out.data_ptr<float>(),
                return_indices ? idx_transposed.data_ptr<int>() : nullptr,
                return_indices ?
idx_out.data_ptr() : nullptr, + num_slices, + slice_len, + dist_transposed.numel(), + spatial_ndim, + dim_idx, + spacing, + is_final, + return_indices + ); + } else { + // Allocate global memory buffers if needed + int64_t total_elements = dist_transposed.numel(); + if (!g_v_val.defined() || g_v_val.numel() < total_elements) { + g_v_val = torch::empty({total_elements}, dist_transposed.options()); + g_v_idx = torch::empty({total_elements}, dist_transposed.options().dtype(torch::kInt32)); + } + if (!g_z.defined() || g_z.numel() < num_slices * (slice_len + 1)) { + g_z = torch::empty({num_slices * (slice_len + 1)}, dist_transposed.options()); + } + if (!g_k.defined() || g_k.numel() < num_slices) { + g_k = torch::empty({num_slices}, dist_transposed.options().dtype(torch::kInt32)); + } + + edt_1d_kernel_global<<>>( + dist_transposed.data_ptr(), + dist_out.data_ptr(), + return_indices ? idx_transposed.data_ptr() : nullptr, + return_indices ? idx_out.data_ptr() : nullptr, + g_v_val.data_ptr(), + g_v_idx.data_ptr(), + g_z.data_ptr(), + g_k.data_ptr(), + num_slices, + slice_len, + dist_transposed.numel(), + spatial_ndim, + dim_idx, + spacing, + is_final, + return_indices + ); + } + + // Transpose back + distance = dist_out.transpose(actual_dim, total_ndim - 1); + if (return_indices) { + indices = idx_out.transpose(actual_dim + 1, total_ndim); + } + } + + return std::make_tuple(distance.contiguous(), return_indices ? 
indices.contiguous() : torch::Tensor()); +} + +// ============================================================================== +// JFA Dispatch Helpers +// ============================================================================== +std::tuple run_jfa_2d( + torch::Tensor input, int64_t H, int64_t W, int grid, int block, int64_t numel +) { + auto index_opts = input.options().dtype(torch::kInt32); + auto idx_shape = input.sizes().vec(); + idx_shape.push_back(2); + auto curr_idx = torch::empty(idx_shape, index_opts); + auto next_idx = torch::empty(idx_shape, index_opts); + + int2* d_curr = (int2*)curr_idx.data_ptr(); + int2* d_next = (int2*)next_idx.data_ptr(); + + init_jfa_kernel_2d_opt<<>>( + input.data_ptr(), d_curr, numel, H, W + ); + + { + dim3 dimBlock(JFA_BLOCK_DIM, JFA_BLOCK_DIM); + int64_t batch_size = numel / (H * W); + dim3 dimGrid((W + JFA_BLOCK_DIM - 1) / JFA_BLOCK_DIM, + (H + JFA_BLOCK_DIM - 1) / JFA_BLOCK_DIM, + batch_size); + + jfa_block_fused_kernel_2d<<>>(d_curr, d_next, H, W, batch_size); + std::swap(d_curr, d_next); + std::swap(curr_idx, next_idx); + } + + int max_dim = std::max((int)H, (int)W); + int step = 16; + + while (step < max_dim) { + jfa_step_global_2d_opt<<>>(d_curr, d_next, step, H, W, numel); + std::swap(d_curr, d_next); + std::swap(curr_idx, next_idx); + step *= 2; + } + + auto final_dist = torch::empty_like(input); + calc_dist_kernel_2d_opt<<>>(d_curr, final_dist.data_ptr(), numel, H, W); + + return std::make_tuple(final_dist, curr_idx); +} + +std::tuple run_jfa_3d( + torch::Tensor input, int64_t D, int64_t H, int64_t W, int grid, int block, int64_t numel +) { + bool use_int16 = (D < 32767 && H < 32767 && W < 32767); + auto index_opts = input.options().dtype(use_int16 ? 
torch::kInt16 : torch::kInt32);

    int64_t batch = numel / (D * H * W);

    // SoA layout (3, Batch, D, H, W): one contiguous plane per coordinate axis
    auto curr_idx_soa = torch::empty({3, batch, D, H, W}, index_opts);
    auto next_idx_soa = torch::empty({3, batch, D, H, W}, index_opts);

    void* d_curr = curr_idx_soa.data_ptr();
    void* d_next = next_idx_soa.data_ptr();
    int64_t plane_stride = numel; // B*D*H*W elements between coordinate planes

    // 1. Init: seed each voxel's nearest-site coordinates
    // NOTE(review): launch configurations below reconstructed from context —
    // the extraction stripped the <<<...>>> arguments; verify against source.
    if (use_int16) {
        int16_t* ptr = (int16_t*)d_curr;
        init_jfa_kernel_3d_soa<<<grid, block>>>(
            input.data_ptr<float>(), ptr, ptr + plane_stride, ptr + 2 * plane_stride, numel, D, H, W
        );
    } else {
        int32_t* ptr = (int32_t*)d_curr;
        init_jfa_kernel_3d_soa<<<grid, block>>>(
            input.data_ptr<float>(), ptr, ptr + plane_stride, ptr + 2 * plane_stride, numel, D, H, W
        );
    }

    // 2. Fused Steps: block-local JFA over 8^3 tiles with halo in shared memory
    int block_dim = 8;
    int blocks_per_d = (D + block_dim - 1) / block_dim;
    dim3 fused_block(block_dim, block_dim, block_dim);
    dim3 fused_grid((W + block_dim - 1) / block_dim, (H + block_dim - 1) / block_dim, blocks_per_d * batch);
    // 14^3 tile (8 + halo) x 3 coordinate planes x element size
    size_t smem_bytes = (14*14*14) * 3 * (use_int16 ? 2 : 4);

    if (use_int16) {
        int16_t* c = (int16_t*)d_curr;
        int16_t* n = (int16_t*)d_next;
        jfa_block_fused_kernel_3d_soa<<<fused_grid, fused_block, smem_bytes>>>(
            c, c + plane_stride, c + 2 * plane_stride,
            n, n + plane_stride, n + 2 * plane_stride,
            D, H, W, blocks_per_d
        );
    } else {
        int32_t* c = (int32_t*)d_curr;
        int32_t* n = (int32_t*)d_next;
        jfa_block_fused_kernel_3d_soa<<<fused_grid, fused_block, smem_bytes>>>(
            c, c + plane_stride, c + 2 * plane_stride,
            n, n + plane_stride, n + 2 * plane_stride,
            D, H, W, blocks_per_d
        );
    }
    std::swap(d_curr, d_next);

    // 3. Global Steps: classic JFA doubling, starting past the fused range
    int max_dim = std::max({(int)D, (int)H, (int)W});
    int step = 4;
    while (step < max_dim) {
        if (use_int16) {
            int16_t* c = (int16_t*)d_curr;
            int16_t* n = (int16_t*)d_next;
            jfa_step_3d_soa<<<grid, block>>>(
                c, c + plane_stride, c + 2 * plane_stride,
                n, n + plane_stride, n + 2 * plane_stride,
                step, D, H, W, numel
            );
        } else {
            int32_t* c = (int32_t*)d_curr;
            int32_t* n = (int32_t*)d_next;
            jfa_step_3d_soa<<<grid, block>>>(
                c, c + plane_stride, c + 2 * plane_stride,
                n, n + plane_stride, n + 2 * plane_stride,
                step, D, H, W, numel
            );
        }
        std::swap(d_curr, d_next);
        step *= 2;
    }

    // 4. Final Dist: convert nearest-site coordinates to Euclidean distance
    auto final_dist = torch::empty_like(input);
    if (use_int16) {
        int16_t* c = (int16_t*)d_curr;
        calc_dist_kernel_3d_soa<<<grid, block>>>(
            c, c + plane_stride, c + 2 * plane_stride,
            final_dist.data_ptr<float>(), numel, D, H, W
        );
    } else {
        int32_t* c = (int32_t*)d_curr;
        calc_dist_kernel_3d_soa<<<grid, block>>>(
            c, c + plane_stride, c + 2 * plane_stride,
            final_dist.data_ptr<float>(), numel, D, H, W
        );
    }

    // Pick whichever ping-pong buffer holds the final result, then permute
    // indices back to (Batch, D, H, W, 3)
    torch::Tensor result_indices;
    if (d_curr == curr_idx_soa.data_ptr()) result_indices = curr_idx_soa;
    else result_indices = next_idx_soa;

    result_indices = result_indices.permute({1, 2, 3, 4, 0}).contiguous();

    return std::make_tuple(final_dist, result_indices);
}

// ==============================================================================
// JFA Main Entry Point
// ==============================================================================
// Dispatches to the 2-D/3-D JFA paths based on tensor rank, falling back to
// the exact separable algorithm for >= 4 spatial dimensions.
std::tuple<torch::Tensor, torch::Tensor> distance_transform_cuda(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "Input must be CUDA tensor");
    input = input.contiguous();

    int64_t dims = input.dim();
    int64_t numel = input.numel();
    int block = BLOCK_SIZE;
    int grid = (numel + block - 1) / block;

    if (dims >= 5) {
        // For 4D+ spatial, fall back to separable algorithm
        int spatial_ndim = dims - 1;
        std::vector<float> sampling(spatial_ndim, 1.0f);
        return
run_edt_separable(input, sampling, true);
    }
    else if (dims == 4) {
        // 4-D tensor: (B, 1, H, W) is treated as batched 2-D, otherwise the
        // second dimension is interpreted as depth for a 3-D transform.
        int64_t dim1 = input.size(1);
        if (dim1 == 1) {
            int64_t H = input.size(-2);
            int64_t W = input.size(-1);
            return run_jfa_2d(input, H, W, grid, block, numel);
        }
        else {
            int64_t D = dim1;
            int64_t H = input.size(-2);
            int64_t W = input.size(-1);
            return run_jfa_3d(input, D, H, W, grid, block, numel);
        }
    }
    else if (dims == 3) {
        // (B, H, W): batched 2-D
        int64_t H = input.size(-2);
        int64_t W = input.size(-1);
        return run_jfa_2d(input, H, W, grid, block, numel);
    }
    else if (dims == 2) {
        // (B, L): run as 2-D with H == 1, then keep only the W coordinate
        int64_t H = 1;
        int64_t W = input.size(-1);
        auto result = run_jfa_2d(input, H, W, grid, block, numel);
        torch::Tensor dist = std::get<0>(result);
        torch::Tensor idx_2d = std::get<1>(result);
        auto idx_1d = idx_2d.slice(/*dim=*/-1, /*start=*/1, /*end=*/2).contiguous();
        return std::make_tuple(dist, idx_1d);
    }
    else {
        TORCH_CHECK(false, "Unsupported dimensions.");
        return std::make_tuple(torch::Tensor(), torch::Tensor());
    }
}

// ==============================================================================
// Python binding entry point
// ==============================================================================
// Selects between the exact (Felzenszwalb) and approximate (JFA) algorithms.
// NOTE(review): return_distances is accepted but never consulted — distances
// are always computed and returned; confirm this matches the Python wrapper's
// contract.
std::tuple<torch::Tensor, torch::Tensor> distance_transform_edt_cuda(
    torch::Tensor input,
    std::vector<float> sampling,
    bool return_distances,
    bool return_indices,
    const std::string& algorithm
) {
    TORCH_CHECK(input.is_cuda(), "Input must be a CUDA tensor");

    int total_ndim = input.dim();

    // Handle empty sampling (default to unit spacing for all spatial dimensions)
    if (sampling.empty()) {
        // Assume all dimensions are spatial if no sampling provided
        // But typically input is (B, C, spatial...) so use total_ndim - 2
        int spatial_ndim = total_ndim >= 3 ? total_ndim - 2 : total_ndim;
        sampling.resize(spatial_ndim, 1.0f);
    }

    int spatial_ndim = sampling.size();

    // Check if we can use JFA algorithm
    bool can_use_jfa = true;

    // JFA doesn't support non-unit sampling
    for (float s : sampling) {
        if (std::abs(s - 1.0f) > 1e-6f) {
            can_use_jfa = false;
            break;
        }
    }

    // JFA only supports 2D and 3D (spatial dimensions)
    if (spatial_ndim > 3) {
        can_use_jfa = false;
    }

    // Determine which algorithm to use
    bool use_jfa = false;
    if (algorithm == "jfa") {
        if (can_use_jfa) {
            use_jfa = true;
        } else {
            // Fall back to exact with warning (or we can throw)
            // For now, silently fall back to exact
            use_jfa = false;
        }
    } else if (algorithm == "exact") {
        use_jfa = false;
    } else if (algorithm == "auto") {
        // Auto mode: use JFA only for 2D with unit sampling
        // For 3D, exact algorithm performs better in practice
        use_jfa = can_use_jfa && (spatial_ndim == 2);
    } else {
        TORCH_CHECK(false, "algorithm must be 'exact', 'jfa', or 'auto', got: ", algorithm);
    }

    if (use_jfa) {
        // Use JFA algorithm (approximate, unit spacing only)
        auto [distances, indices_result] = distance_transform_cuda(input);

        if (!return_indices) {
            indices_result = torch::Tensor();
        }

        return std::make_tuple(distances, indices_result);
    }

    // Use exact (Felzenszwalb) algorithm
    // Use 2D optimized path only when both dimensions fit in shared memory
    // For larger dimensions, the N-D general version with transpose is faster
    if (spatial_ndim == 2) {
        auto shape = input.sizes().vec();
        int height = shape[total_ndim - 2];
        int width = shape[total_ndim - 1];

        // Only use 2D optimized path when shared memory can be used for both directions
        if (height <= SHARED_MEM_LIMIT && width <= SHARED_MEM_LIMIT) {
            float spacing_y = sampling[0];
            float spacing_x = sampling[1];

            auto [distances, indices_result] = run_edt_2d_optimized(
                input, spacing_y, spacing_x, return_indices
            );

            if (!return_indices) {
                indices_result =
torch::Tensor(); + } + + return std::make_tuple(distances, indices_result); + } + } + + // Fall back to general N-D implementation + auto [distances, indices_result] = run_edt_separable(input, sampling, return_indices); + + if (!return_indices) { + indices_result = torch::Tensor(); + } + + return std::make_tuple(distances, indices_result); +} diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu deleted file mode 100644 index 503d13c..0000000 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ /dev/null @@ -1,469 +0,0 @@ -#include -#include -#include -#include -#include -#include - -// ------------------------------------------------------------------ -// Configuration Constants -// ------------------------------------------------------------------ -#define INF_VAL 1e8f -#define MAX_THREADS 1024 -// Shared memory limit: typically 48 KB. -// Each pixel requires: float(value) + int(idx1) + int(idx2) = 12 bytes. -// 4096 * 12 = 48 KB. -#define SMEM_LIMIT_ELEMENTS 4096 - -// ------------------------------------------------------------------ -// Device Helper Functions -// ------------------------------------------------------------------ - -__device__ __forceinline__ float sqr(float x) { return x * x; } - -// Compute the JFA cost: (q - p)^2 + weight[p] -__device__ __forceinline__ float compute_cost(int q, int p, float val_p) { - if (p < 0) return INF_VAL; - return sqr((float)q - (float)p) + val_p; -} - -// ------------------------------------------------------------------ -// JFA Core Logic (Device Only) -// ------------------------------------------------------------------ -// Core JFA logic, independent of data location (works with both Shared and Global memory). -__device__ void run_jfa_core( - int N, - int tid, - const float* __restrict__ vals, // input weight (read-only) - int* __restrict__ idx_curr, // Ping-Pong Buffer A - int* __restrict__ idx_next // Ping-Pong Buffer B -) { - // 1. 
Initialization: determine whether each pixel is a valid source based on vals. - for (int i = tid; i < N; i += blockDim.x) { - if (vals[i] >= INF_VAL * 0.9f) { - idx_curr[i] = -1; // background - } else { - idx_curr[i] = i; // For each object/source point, the initial index points to itself. - } - } - __syncthreads(); - - // 2. Iterative Propagation (Step = 1, 2, 4, ... < N) - int* idx_in = idx_curr; - int* idx_out = idx_next; - - for (int step = 1; step < N; step *= 2) { - for (int i = tid; i < N; i += blockDim.x) { - int my_best_p = idx_in[i]; - float min_cost = INF_VAL; - - // Check its current best solution - if (my_best_p != -1) { - min_cost = compute_cost(i, my_best_p, vals[my_best_p]); - } - - // Check Left Neighbor (-step) - int left = i - step; - if (left >= 0) { - int left_p = idx_in[left]; - if (left_p != -1) { - float c = compute_cost(i, left_p, vals[left_p]); - if (c < min_cost) { - min_cost = c; - my_best_p = left_p; - } - } - } - - // Check Right Neighbor (+step) - int right = i + step; - if (right < N) { - int right_p = idx_in[right]; - if (right_p != -1) { - float c = compute_cost(i, right_p, vals[right_p]); - if (c < min_cost) { - min_cost = c; - my_best_p = right_p; - } - } - } - idx_out[i] = my_best_p; - } - - // Swap Pointers - int* temp = idx_in; - idx_in = idx_out; - idx_out = temp; - __syncthreads(); - } - - // 3. Ensure the final result is stored in idx_curr (if the loop ends with idx_next, copy it back). - if (idx_in != idx_curr) { - for (int i = tid; i < N; i += blockDim.x) { - idx_curr[i] = idx_next[i]; - } - __syncthreads(); - } -} - -// ------------------------------------------------------------------ -// Kernel 1: Shared Memory JFA (Fast Path) -// ------------------------------------------------------------------ -// Template parameter NDim: when NDim > 0, the compiler performs loop unrolling optimizations. -// Runtime parameter runtime_ndim: when NDim == 0 (default behavior), this parameter specifies the dimension. 
-template -__global__ void edt_kernel_shared( - const float* __restrict__ in_data, // input Dist^2 - const int32_t* __restrict__ in_indices, // output Indices - float* __restrict__ out_dist, // output Dist (IsFinal ? sqrt : sqr) - int32_t* __restrict__ out_indices, // output Indices - int64_t L, // Size of the current dimension - int64_t total_elements, // Total number of elements - int runtime_ndim // Runtime dimension (used as fallback) -) { - // Determine the effective dimension - const int D = (NDim > 0) ? NDim : runtime_ndim; - - // Compute row offset - int64_t row_idx = blockIdx.x; - int64_t offset = row_idx * L; - - if (offset >= total_elements) return; - - // Shared memory layout - extern __shared__ char s_buffer[]; - float* s_vals = (float*)s_buffer; - int* s_idx1 = (int*)(s_vals + L); - int* s_idx2 = (int*)(s_idx1 + L); - - // 1. Load distances into Shared Memory - for (int i = threadIdx.x; i < L; i += blockDim.x) { - s_vals[i] = __ldg(&in_data[offset + i]); - } - __syncthreads(); - - // 2. Run the core JFA logic - run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); - - // 3. Write back the results - for (int q = threadIdx.x; q < L; q += blockDim.x) { - int p = s_idx1[q]; // Nearest point (local index within 0..L-1) - float dist_val; - - // Compute updated distance - if (p != -1) { - float dist_sq = sqr((float)q - (float)p) + s_vals[p]; - dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; - } else { - // No source point found (e.g., entire row is background) - dist_val = IsFinal ? 
INF_VAL : sqr(INF_VAL); - p = 0; // Prevent out-of-bounds access - } - out_dist[offset + q] = dist_val; - - // Propagate indices: copy a vector of size [D] - if (p != -1) { - int64_t src_offset = (offset + p) * D; - int64_t dst_offset = (offset + q) * D; - - // When NDim > 0, this loop is fully unrolled by the compiler - for (int d = 0; d < D; ++d) { - out_indices[dst_offset + d] = in_indices[src_offset + d]; - } - } else { - // Fallback: no source available - int64_t dst_offset = (offset + q) * D; - for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; - } - } -} - -// ------------------------------------------------------------------ -// Kernel 2: Global Memory JFA (Fallback Path) -// ------------------------------------------------------------------ -// Same logic as above, but uses Global Memory as the ping-pong buffer -template -__global__ void edt_kernel_global( - const float* __restrict__ in_data, - const int32_t* __restrict__ in_indices, - float* __restrict__ out_dist, - int32_t* __restrict__ out_indices, - int* __restrict__ global_buffer_1, - int* __restrict__ global_buffer_2, - int64_t L, - int64_t total_elements, - int runtime_ndim -) { - const int D = (NDim > 0) ? NDim : runtime_ndim; - - int64_t row_idx = blockIdx.x; - int64_t offset = row_idx * L; - - if (offset >= total_elements) return; - - // Pointers to Global Memory - int* g_idx1 = global_buffer_1 + offset; - int* g_idx2 = global_buffer_2 + offset; - - // 1. & 2. Run the JFA core (operating directly on Global Memory) - run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); - - // 3. Write back results - for (int q = threadIdx.x; q < L; q += blockDim.x) { - int p = g_idx1[q]; - float dist_val; - - if (p != -1) { - float val_p = in_data[offset + p]; - float dist_sq = sqr((float)q - (float)p) + val_p; - dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; - } else { - dist_val = IsFinal ? 
INF_VAL : sqr(INF_VAL); - p = 0; - } - - out_dist[offset + q] = dist_val; - - if (p != -1) { - int64_t src_offset = (offset + p) * D; - int64_t dst_offset = (offset + q) * D; - for (int d = 0; d < D; ++d) { - out_indices[dst_offset + d] = in_indices[src_offset + d]; - } - } else { - int64_t dst_offset = (offset + q) * D; - for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; - } - } -} - - -// ------------------------------------------------------------------ -// Kernel 3: Initialize Indices -// ------------------------------------------------------------------ -// Initialize index tensor as grid coordinates -// indices shape: (..., D) -__global__ void init_indices_kernel( - int32_t* indices, - int64_t total_pixels, - int NDim, - const int64_t* __restrict__ shape_ptr // shape of spatial dimensions -) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_pixels) return; - - // Unravel Index - // idx is the flat index of each pixel - // We need to compute its coordinate in spatial_shape - - int64_t temp = idx; - // Use local register array to avoid repeated global memory reads (assume max 10 dims) - int32_t coords[10]; - - // Example: spatial_shape = [D0, D1, D2] - // compute by modulo from last dimension - for (int d = NDim - 1; d >= 0; --d) { - int64_t dim_size = shape_ptr[d]; - coords[d] = temp % dim_size; - temp /= dim_size; - } - - // Write to Global Memory - // Indices tensor is flattened as (TotalPixels, NDim) - int64_t out_ptr = idx * NDim; - for (int d = 0; d < NDim; ++d) { - indices[out_ptr + d] = coords[d]; - } -} - -// ------------------------------------------------------------------ -// Host Function: C++ Entry Point -// ------------------------------------------------------------------ - -std::tuple distance_transform_cuda(torch::Tensor input) { - TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device."); - TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32."); - - input = input.contiguous(); - 
- // Handle batch dimension: if input is 1D (L), treat as no batch but internally add a batch dimension. - // Convention: input shape is (Batch, D1, D2, ..., Dn) - // Algorithm treats batch and other dims identically (batch is just another leading dimension) - // But index initialization needs to know which are "spatial dimensions". - // Assumption: all dims except dim 0 (Batch) are spatial. - - const int ndim = input.dim(); - // If ndim=1, assume (L) -> sample_ndim=1 - // If ndim=4 (B, C, H, W), sample_ndim=3 (C,H,W treated as spatial? Channels often processed independently) - // Correction: classical EDT usually runs on (H,W) or (D,H,W). - // If channels exist, typically each channel is processed independently. - // For maximum generality, we treat **all dims except dim 0** as spatial dims. - // If input has no batch dim, user should use unsqueeze(0) in Python. - - const int sample_ndim = ndim - 1; - TORCH_CHECK(sample_ndim > 0, "Input tensor must have at least 2 dimensions (Batch, ...)"); - - auto shape = input.sizes().vec(); - int64_t num_pixels = input.numel(); - - if (num_pixels == 0) { - auto index_shape = shape; - index_shape.push_back(sample_ndim); - return std::make_tuple(torch::empty_like(input), - torch::empty(index_shape, input.options().dtype(torch::kInt32))); - } - - // 1. Initialize Distance Tensor - // 0 -> 0, 1 -> INF - auto current_dist = torch::where(input == 0, - torch::tensor(0.0f, input.options()), - torch::tensor(INF_VAL, input.options())); - - // 2. 
Initialize Index Tensor - // Shape: (Batch, D1, ..., Dn, sample_ndim) - auto index_shape = shape; - index_shape.push_back(sample_ndim); - auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - - // 2.1 Prepare shape tensor for kernel - std::vector spatial_shape(shape.begin() + 1, shape.end()); - auto shape_tensor = torch::tensor(spatial_shape, torch::kInt64).to(input.device()); - - // 2.2 Launch initialization kernel - { - int threads = 256; - int blocks = (num_pixels + threads - 1) / threads; - init_indices_kernel<<>>( - current_idx.data_ptr(), - num_pixels, - sample_ndim, - shape_tensor.data_ptr() - ); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - printf("Init Kernel Failed: %s\n", cudaGetErrorString(err)); - } - } - - // Pre-allocate Global Memory Buffers (lazy) - torch::Tensor global_buf1, global_buf2; - - // 3. Process each spatial dimension (Separable JFA) - // Iterate through each spatial dimension (1 to ndim-1) - for (int d = 1; d < ndim; ++d) { - bool is_final_pass = (d == ndim - 1); - - // --- Step A: Transpose current dim to last --- - // Resulting shape: (..., L) - auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); - auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); - - int64_t L = dist_in.size(-1); - int64_t total_slices = dist_in.numel() / L; - - auto dist_out = torch::empty_like(dist_in); - auto idx_out = torch::empty_like(idx_in); - - // --- Step B: Kernel Dispatch --- - int threads = std::min((int64_t)MAX_THREADS, L); - - // Check whether Shared Memory can be used - if (L <= SMEM_LIMIT_ELEMENTS) { - size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); - - // Switch macro to handle template dimension specialization - #define DISPATCH_SHARED(IS_FINAL) \ - switch(sample_ndim) { \ - case 1: edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 2: 
edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 3: edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 4: edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 5: edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 6: edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - default: /* Fallback for > 6D */ \ - edt_kernel_shared<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - } - - if (is_final_pass) { DISPATCH_SHARED(true); } - else { DISPATCH_SHARED(false); } - - } else { - // Global Memory fallback (L > 4096) - if (global_buf1.numel() < dist_in.numel()) { - global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); - global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); - } - - #define DISPATCH_GLOBAL(IS_FINAL) \ - switch(sample_ndim) { \ - case 1: edt_kernel_global<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 2: edt_kernel_global<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 3: edt_kernel_global<<>>( \ - 
dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 4: edt_kernel_global<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 5: edt_kernel_global<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - case 6: edt_kernel_global<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - default: /* Fallback */ \ - edt_kernel_global<<>>( \ - dist_in.data_ptr(), idx_in.data_ptr(), \ - dist_out.data_ptr(), idx_out.data_ptr(), \ - global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel(), sample_ndim); break; \ - } - - if (is_final_pass) { DISPATCH_GLOBAL(true); } - else { DISPATCH_GLOBAL(false); } - } - - // --- Step C: Transpose Back --- - current_dist = dist_out.transpose(d, ndim - 1); - current_idx = idx_out.transpose(d, ndim - 1); - } - - return std::make_tuple(current_dist, current_idx); -} - diff --git a/torchmorph/csrc/torchmorph.cpp b/torchmorph/csrc/torchmorph.cpp index c79970c..58cf6ed 100644 --- a/torchmorph/csrc/torchmorph.cpp +++ b/torchmorph/csrc/torchmorph.cpp @@ -2,9 +2,39 @@ // Declare CUDA implementations torch::Tensor add_cuda(torch::Tensor input, float scalar); -std::tuple distance_transform_cuda(torch::Tensor input); + +// Distance Transform functions +std::tuple distance_transform_edt_cuda( + torch::Tensor input, + std::vector sampling, + bool return_distances, + bool return_indices, + const std::string& algorithm +); + +std::tuple distance_transform_cdt_cuda( + torch::Tensor input, + 
const std::string& metric, + bool return_distances, + bool return_indices +); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("add_cuda", &add_cuda, "Add tensor with scalar"); - m.def("distance_transform_cuda", &distance_transform_cuda, "Distance transform"); + + // Distance Transform + m.def("distance_transform_edt_cuda", &distance_transform_edt_cuda, + "Exact Euclidean Distance Transform (Felzenszwalb algorithm)", + py::arg("input"), + py::arg("sampling"), + py::arg("return_distances") = true, + py::arg("return_indices") = false, + py::arg("algorithm") = "exact"); + m.def("distance_transform_cdt_cuda", &distance_transform_cdt_cuda, + "Chamfer Distance Transform", + py::arg("input"), + py::arg("metric") = "chessboard", + py::arg("return_distances") = true, + py::arg("return_indices") = false); } diff --git a/torchmorph/distance_transform.py b/torchmorph/distance_transform.py index 868e84a..c3356e2 100644 --- a/torchmorph/distance_transform.py +++ b/torchmorph/distance_transform.py @@ -1,16 +1,282 @@ +from typing import Optional, Sequence, Tuple, Union + import torch from torchmorph import _C -def distance_transform(input: torch.Tensor) -> torch.Tensor: - """Distance Transform in CUDA.""" +def distance_transform_edt( + input: torch.Tensor, + sampling: Optional[Union[float, Sequence[float]]] = None, + return_distances: bool = True, + return_indices: bool = False, + distances: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + algorithm: str = "exact", +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], None]: + """Exact Euclidean Distance Transform (EDT) using Felzenszwalb algorithm. + + Args: + input: Binary input tensor (0 = background, non-zero = foreground). + Must be in (B, C, Spatial...) format where Spatial can be 1D, 2D, or 3D. + For single images, use unsqueeze to add batch and channel dims. + sampling: Spacing of elements along each spatial dimension. 
If a single + number, the spacing is uniform in all spatial dimensions. If a + sequence, it must match the number of spatial dimensions. + Default is None (unit spacing for all spatial dimensions). + Note: When sampling is not unit spacing, only "exact" algorithm is used. + return_distances: Whether to calculate the distance transform. + Default is True. + return_indices: Whether to calculate the feature transform (indices + of closest background element). Default is False. + distances: Optional output tensor for distances. If provided, must have + the same shape as input. If None and return_distances is True, + a new tensor will be created and returned. + indices: Optional output tensor for indices. If provided, must have shape + (spatial_ndim, ...) where ... matches input shape. If None and + return_indices is True, a new tensor will be created and returned. + algorithm: Algorithm to use for distance transform. Options: + - "exact": Use Felzenszwalb's exact algorithm (default). + - "jfa": Use Jump Flooding Algorithm (fast but approximate). + Only available for 2D/3D with unit sampling. + - "auto": Automatically choose based on input (uses JFA when + applicable, otherwise exact). 
+ + Returns: + Depending on return_distances, return_indices, and whether output tensors + are provided: + - Returns distance tensor only when return_distances=True and distances=None + - Returns indices tensor only when return_indices=True and indices=None + - Returns tuple of (distances, indices) when both conditions above are met + - Returns None if output tensors are provided for all requested outputs + + Example: + >>> import torchmorph as tm + >>> # 2D image: (B, C, H, W) + >>> x = torch.zeros(1, 1, 64, 64, device='cuda') + >>> x[0, 0, 10:20, 10:20] = 1 + >>> dist = tm.distance_transform_edt(x) + >>> dist, indices = tm.distance_transform_edt(x, return_indices=True) + >>> dist = tm.distance_transform_edt(x, sampling=[0.5, 1.0]) + >>> # Using JFA algorithm (faster for large images) + >>> dist = tm.distance_transform_edt(x, algorithm="jfa") + >>> # Using pre-allocated output tensors + >>> dist_out = torch.empty_like(x) + >>> tm.distance_transform_edt(x, distances=dist_out) # Returns None, fills dist_out + >>> # 3D volume: (B, C, D, H, W) + >>> x_3d = torch.zeros(2, 1, 32, 64, 64, device='cuda') + >>> dist_3d = tm.distance_transform_edt(x_3d, sampling=[2.0, 1.0, 1.0]) + """ + if not input.is_cuda: + raise ValueError("Input tensor must be on CUDA device.") + if input.ndim < 3: + raise ValueError( + f"Input must be (B, C, ) format with at least 3 dimensions, got {input.shape}. " + "For single images, use unsqueeze to add batch and channel dims." 
+ ) + if input.numel() == 0: + raise ValueError(f"Invalid input: empty tensor with shape {input.shape}.") + + # Validate pre-allocated output tensors + if distances is not None: + if distances.shape != input.shape: + raise ValueError( + f"distances shape {distances.shape} must match input shape {input.shape}" + ) + if not distances.is_cuda: + raise ValueError("distances tensor must be on CUDA device.") + return_distances = True + + if indices is not None: + if not indices.is_cuda: + raise ValueError("indices tensor must be on CUDA device.") + return_indices = True + + if not return_distances and not return_indices: + raise ValueError( + "At least one of return_distances or return_indices must be True, " + "or output tensors must be provided." + ) + + input = input.float().contiguous() + total_ndim = input.ndim + spatial_ndim = total_ndim - 2 # Exclude B and C dimensions + + # Process sampling parameter for spatial dimensions only + if sampling is None: + # Unit spacing for all spatial dimensions + sampling_list = [1.0] * spatial_ndim + elif isinstance(sampling, (int, float)): + # Single value: same spacing for all spatial dimensions + sampling_list = [float(sampling)] * spatial_ndim + else: + # Sequence: convert to list + sampling_list = [float(s) for s in sampling] + if len(sampling_list) == 1: + # Single element list: broadcast to all spatial dimensions + sampling_list = sampling_list * spatial_ndim + elif len(sampling_list) != spatial_ndim: + raise ValueError( + f"sampling has {len(sampling_list)} but input {spatial_ndim} dimensions " + f"(input shape: {input.shape}, format: (B, C, Spatial...))" + ) + + # Call CUDA kernel - it handles batch dimensions based on sampling size + raw_distances, raw_indices = _C.distance_transform_edt_cuda( + input, sampling_list, return_distances, return_indices, algorithm + ) + + # Copy to pre-allocated tensors if provided + if distances is not None and raw_distances is not None: + distances.copy_(raw_distances) + + if indices is 
not None and raw_indices is not None: + indices.copy_(raw_indices) + + # Return based on scipy convention: + # Only return tensors that were NOT provided by the user + return_dist_tensor = return_distances and distances is None + return_idx_tensor = return_indices and indices is None + + if return_dist_tensor and return_idx_tensor: + return raw_distances, raw_indices + elif return_dist_tensor: + return raw_distances + elif return_idx_tensor: + return raw_indices + else: + return None + + +def distance_transform_cdt( + input: torch.Tensor, + metric: str = "chessboard", + return_distances: bool = True, + return_indices: bool = False, + distances: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], None]: + """Chamfer Distance Transform (CDT). + + Calculates the distance transform of the input using a chamfer metric. + The input is treated as a binary image where non-zero values are foreground + and zero values are background. Distances are computed from each foreground + pixel to the nearest background pixel. + + Args: + input: Binary input tensor (0 = background, non-zero = foreground). + Must be in (B, C, H, W) or (B, C, D, H, W) format for batch processing, + or (H, W) / (D, H, W) for single images. + metric: Distance metric to use: + - "chessboard": L-infinity norm (default). Also known as Chebyshev distance. + - "taxicab": L1 norm. Also known as Manhattan or city-block distance. + - "cityblock": Alias for "taxicab". + - "manhattan": Alias for "taxicab". + return_distances: Whether to calculate the distance transform. Default is True. + return_indices: Whether to calculate the feature transform (indices of closest + background element). Default is False. + distances: Optional output tensor for distances. If provided, must have + the same shape as input. If None and return_distances is True, + a new tensor will be created. + indices: Optional output tensor for indices. 
If provided, must have shape + (..., ndim) where ... matches input shape. If None and return_indices + is True, a new tensor will be created. + + Returns: + Depending on return_distances, return_indices, and whether output tensors + are provided: + - Returns distance tensor only when return_distances=True and distances=None + - Returns indices tensor only when return_indices=True and indices=None + - Returns tuple of (distances, indices) when both conditions above are met + - Returns None if output tensors are provided for all requested outputs + + Example: + >>> import torchmorph as tm + >>> # 2D image with batch: (B, C, H, W) + >>> x = torch.zeros(1, 1, 64, 64, device='cuda') + >>> x[0, 0, 10:20, 10:20] = 1 + >>> dist = tm.distance_transform_cdt(x) # chessboard by default + >>> dist = tm.distance_transform_cdt(x, metric='taxicab') + >>> dist, indices = tm.distance_transform_cdt(x, return_indices=True) + >>> # Using pre-allocated output tensors + >>> dist_out = torch.empty_like(x) + >>> tm.distance_transform_cdt(x, distances=dist_out) # Returns None, fills dist_out + """ if not input.is_cuda: raise ValueError("Input tensor must be on CUDA device.") if input.ndim < 2 or input.numel() == 0: raise ValueError(f"Invalid input dimension: {input.shape}.") - # binarize input - input[input != 0] = 1 + # Normalize metric aliases + if metric in ("cityblock", "manhattan"): + metric = "taxicab" + + if metric not in ("chessboard", "taxicab"): + raise ValueError("metric must be 'chessboard', 'taxicab', 'cityblock', or 'manhattan'.") + if not return_distances and not return_indices: + if distances is None and indices is None: + raise ValueError( + "At least one of return_distances or return_indices must be True, " + "or output tensors must be provided." 
+ ) + + input = input.float().contiguous() + + # Validate pre-allocated output tensors + if distances is not None: + if distances.shape != input.shape: + raise ValueError( + f"distances shape {distances.shape} must match input shape {input.shape}" + ) + if not distances.is_cuda: + raise ValueError("distances tensor must be on CUDA device.") + return_distances = True + + if indices is not None: + if not indices.is_cuda: + raise ValueError("indices tensor must be on CUDA device.") + return_indices = True + + # Call CUDA kernel + raw_distances, raw_indices = _C.distance_transform_cdt_cuda( + input, metric, return_distances, return_indices + ) + + # Copy to pre-allocated tensors if provided + if distances is not None and raw_distances is not None: + distances.copy_(raw_distances) + + if indices is not None and raw_indices is not None: + indices.copy_(raw_indices) + + # Return based on scipy convention: + # Only return tensors that were NOT provided by the user + return_dist_tensor = return_distances and distances is None + return_idx_tensor = return_indices and indices is None + + if return_dist_tensor and return_idx_tensor: + return raw_distances, raw_indices + elif return_dist_tensor: + return raw_distances + elif return_idx_tensor: + return raw_indices + else: + return None + + +# Backward compatibility alias +def distance_transform( + input: torch.Tensor, + sampling: Optional[Union[float, Sequence[float]]] = None, + return_distances: bool = True, + return_indices: bool = False, + distances: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], None]: + """Distance Transform (alias for distance_transform_edt). - return _C.distance_transform_cuda(input) + See distance_transform_edt for full documentation. + """ + return distance_transform_edt( + input, sampling, return_distances, return_indices, distances, indices + )