diff --git a/.flake8 b/.flake8 index 5fc408e..6ef9e30 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 100 -extend-ignore = E203, W503 # Compatibility with Black +extend-ignore = E203, W503 exclude = __pycache__, build, diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..bfebf61 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,32 @@ +name: pre-commit + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "*" ] + +jobs: + pre-commit: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install pre-commit + run: | + pip install pre-commit + + - name: Install pre-commit hooks + run: | + pre-commit install --install-hooks + + - name: Run pre-commit on all files + run: | + pre-commit run --all-files diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..ca3e9b3 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,27 @@ +name: test + +on: + pull_request: + branches: ["*"] + push: + branches: [main] + +jobs: + + test: + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + - run: | + # Make CUDA visible to this shell and all child processes + export CUDA_HOME=/usr/local/cuda + export PATH="$CUDA_HOME/bin:$PATH" + export LD_LIBRARY_PATH="$CUDA_HOME/lib64:${LD_LIBRARY_PATH:-}" + echo "CUDA_HOME=$CUDA_HOME" + pip install -r requirements-dev.txt --break-system-packages --user + pip uninstall torchmorph --yes --break-system-packages + python setup.py install --user + - run: | + ORIGINAL=$(pwd) + cd /tmp + pytest $ORIGINAL/test \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6df40f7..86088b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,15 @@ repos: rev: 6.1.0 hooks: - id: flake8 + + # ------------------------- + # ⭐ Local Hook: forbid non-ASCII in C/C++/CUDA + # ------------------------- + - repo: local + hooks: + - id: forbid-non-ascii + name: "Forbid non-ASCII characters in C/C++/CUDA" + entry: python3 scripts/check_ascii.py + language: system + types: [file] + files: '\.(c|cc|cpp|cxx|cu|cuh|h|hpp|py)$' diff --git a/benchmark/distance_transform.py b/benchmark/distance_transform.py new file mode 100644 index 0000000..3737ced --- /dev/null +++ b/benchmark/distance_transform.py @@ -0,0 +1,77 @@ +import torch +import torch.utils.benchmark as benchmark +from prettytable import PrettyTable + +sizes = [64, 128, 256, 512, 1024] +batches = [1, 4, 8, 16] +dtype = torch.float32 +device = "cuda" +MIN_RUN = 1.0 # seconds per measurement + +torch.set_num_threads(torch.get_num_threads()) + +for B in batches: + table = PrettyTable() + table.field_names = [ + "Size", + "SciPy (ms/img)", + "Torch 1× (ms/img)", + "Torch batch (ms/img)", + "Speedup 1×", + "Speedup batch", + ] + for c in table.field_names: + table.align[c] = "r" + + for s in sizes: + # Inputs + x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) + x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] + x_imgs = [x[i : i + 1] for i in range(B)] + + # SciPy (CPU, one-by-one) + stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]" + t_scipy = benchmark.Timer( + stmt=stmt_scipy, + setup="from __main__ import x_np_list, ndi", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + scipy_per_img_ms = (t_scipy.median * 1e3) / B + + # Torch (CUDA, one-by-one) + stmt_torch1 = """ +for xi in x_imgs: + tm.distance_transform(xi) +""" + t_torch1 = benchmark.Timer( + stmt=stmt_torch1, + setup="from __main__ import x_imgs, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + torch1_per_img_ms = (t_torch1.median * 1e3) / B + + # Torch (CUDA, batched) + t_batch = benchmark.Timer( + stmt="tm.distance_transform(x)", + setup="from __main__ import x, tm", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=MIN_RUN) + torchB_per_img_ms = (t_batch.median * 1e3) / B + + # Speedups + speed1 = scipy_per_img_ms / torch1_per_img_ms + speedB = scipy_per_img_ms / torchB_per_img_ms + + table.add_row( + [ + s, + f"{scipy_per_img_ms:.3f}", + f"{torch1_per_img_ms:.3f}", + f"{torchB_per_img_ms:.3f}", + f"{speed1:.1f}×", + f"{speedB:.1f}×", + ] + ) + + print(f"\n=== Batch Size: {B} ===") + print(table) diff --git a/pyproject.toml b/pyproject.toml index 1dce09a..a4d6163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,9 @@ max-line-length = 100 extend-ignore = ["E203", "W503"] [tool.pytest.ini_options] -addopts = "-v" +addopts = "-v --import-mode=importlib" testpaths = ["test"] +[build-system] +requires = ["setuptools>=61.0", "wheel", "torch", "numpy"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index fc961bb..0df6115 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,4 +12,4 @@ flake8>=6.0 setuptools>=65.0 wheel>=0.40 ninja>=1.11 # optional, speeds up torch extension builds - +prettytable>=3.16.0 diff --git a/scripts/check_ascii.py b/scripts/check_ascii.py new file mode 100644 index 0000000..d788056 --- /dev/null +++ b/scripts/check_ascii.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +import sys +import unicodedata +from pathlib import Path + +TARGET_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".cu", ".cuh", ".h", ".hpp", ".py"} + + +# --- Helpers -------------------------------------------------------- + + +# Latin ranges we still consider "English-ish" and therefore allowed. +# (You can shrink this if you want to ban accented letters too.) +LATIN_RANGES = [ + (0x0000, 0x007F), # Basic Latin (ASCII) + (0x00C0, 0x024F), # Latin-1 Supplement + Latin Extended-A/B + (0x1E00, 0x1EFF), # Latin Extended Additional +] + + +def in_ranges(ch: str, ranges) -> bool: + cp = ord(ch) + for start, end in ranges: + if start <= cp <= end: + return True + return False + + +def is_forbidden_char(ch: str) -> bool: + """ + Return True if ch should be *forbidden*. + + Policy: + - ASCII (<= 0x7F): always OK + - Non-ASCII letters (Unicode category starting with 'L') + that are NOT in Latin ranges: forbidden + - Everything else (emoji, arrows, symbols, etc.): allowed + """ + cp = ord(ch) + if cp <= 0x7F: + return False # pure ASCII + + cat = unicodedata.category(ch) + + # Forbid letters that are not Latin. + if cat.startswith("L"): # Letter + if in_ranges(ch, LATIN_RANGES): + return False # Latin letters allowed + return True # Non-Latin letters forbidden + + # All non-letter stuff (emoji, arrows, symbols, punctuation) is allowed. + return False + + +def find_forbidden_chars(line: str): + """Return list of (index, char) for all forbidden chars in a line.""" + result = [] + for i, ch in enumerate(line): + if is_forbidden_char(ch): + result.append((i, ch)) + return result + + +# --- Core logic ----------------------------------------------------- + + +def check_file(path: Path) -> bool: + ok = True + with path.open("r", encoding="utf-8", errors="ignore") as f: + for lineno, line in enumerate(f, start=1): + forbidden = find_forbidden_chars(line) + if forbidden: + ok = False + print(f"\n❌ {path}:{lineno}: non-English letters detected") + + # Print the full line + print(" Line content:") + print(f" {line.rstrip()}") + + # Underline the forbidden characters + underline = [" " for _ in line.rstrip("\n")] + for idx, ch in forbidden: + if idx < len(underline): + underline[idx] = "^" + print(f" {''.join(underline)}") + + # Print what characters exactly + chars = ", ".join( + f"'{ch}' (U+{ord(ch):04X}) [{unicodedata.name(ch, 'UNKNOWN')}]" + for _, ch in forbidden + ) + print(f" Offending chars: {chars}") + + return ok + + +def main(files): + ok = True + for f in files: + p = Path(f) + if p.suffix.lower() in TARGET_SUFFIXES and p.exists(): + if not check_file(p): + ok = False + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: check_ascii.py ") + sys.exit(1) + main(sys.argv[1:]) diff --git a/setup.py b/setup.py index 9876b88..fd3f95e 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -import os import glob -from setuptools import setup, find_packages +import os + +from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension diff --git a/test/test_add.py b/test/test_add.py index a641494..5f647ff 100644 --- a/test/test_add.py +++ b/test/test_add.py @@ -1,5 +1,6 @@ -import torch import pytest +import torch + import torchmorph as tm diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index 166968c..5855bf3 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -1,20 +1,182 @@ -import torch +import numpy as np # noqa: F401 import pytest -import torchmorph as tm -from scipy.ndimage import distance_transform_edt as dte +import torch +from scipy.ndimage import distance_transform_edt as scipy_edt # noqa: F401 + +import torchmorph as tm # noqa: F401 + + +# ====================================================================== +# Helper functions +# ====================================================================== +def batch_scipy_edt_with_indices( + batch_numpy: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute SciPy EDT and indices for a batch of arrays.""" + dist_results: list[np.ndarray] = [] + indices_results: list[np.ndarray] = [] + + # Ensure batch_numpy has at least shape (Batch, ...) + # If the input is (H, W), it is already converted to (1, H, W) outside. + if batch_numpy.ndim == 1: + batch_numpy = batch_numpy[np.newaxis, ...] + + for sample in batch_numpy: + dist, indices = scipy_edt( + sample, + return_indices=True, + return_distances=True, + ) + dist_results.append(dist) + indices_results.append(indices) + + output_dist = np.stack(dist_results, axis=0) + output_indices = np.stack(indices_results, axis=0) + output_indices = np.moveaxis(output_indices, 1, -1) + + return output_dist, output_indices + + +# ====================================================================== +# Test data +# ====================================================================== +case_batch_1d = np.array( + [[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], + dtype=np.float32, +) + +case_batch_2d = np.array( + [ + [[0.0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]], + [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], + ], + dtype=np.float32, +) + +# This is a single 2D image with shape (4, 4) +case_single_2d = np.array( + [ + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + ], + dtype=np.float32, +) +case_explicit_batch_one = case_single_2d[np.newaxis, ...] + +_case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32) +_case_3d_s1[1, 1, 1] = 0.0 +_case_3d_s1[2, 3, 4] = 0.0 + +_case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32) +_case_3d_s2[0, 0, 0] = 0.0 +case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0) -@pytest.mark.cuda -def test_distance_transform(): - """Test that tm.foo doubles all tensor elements.""" +case_dim_one = np.ones((2, 5, 1), dtype=np.float32) +case_dim_one[0, 2, 0] = 0.0 +case_dim_one[1, 4, 0] = 0.0 + +# 4D spatial case +_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32) +_case_4d_s1[0, 0, 0, 0] = 0.0 + +_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32) +_case_4d_s2[1, 1, 1, 1] = 0.0 + +case_batch_4d_spatial = np.stack([_case_4d_s1, _case_4d_s2], axis=0) + +# 5D spatial case +case_batch_5d_spatial = np.ones((1, 2, 2, 2, 2, 2), dtype=np.float32) +case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0 +case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 + + +# ====================================================================== +# Test logic +# ====================================================================== +@pytest.mark.parametrize( + "input_numpy, has_batch_dim", + [ + pytest.param(case_batch_1d, True, id="1D_Batch"), + pytest.param(case_batch_2d, True, id="2D_Batch"), + pytest.param(case_single_2d, False, id="2D_Single_NoBatch"), + pytest.param( + case_explicit_batch_one, + True, + id="2D_Single_ExplicitBatch", + ), + pytest.param(case_batch_3d, True, id="3D_Batch"), + pytest.param(case_dim_one, True, id="2D_UnitDim_Batch"), + pytest.param(case_batch_4d_spatial, True, id="4D_Spatial_Batch"), + pytest.param(case_batch_5d_spatial, True, id="5D_Spatial_Batch"), + ], +) +def test_distance_transform_and_indices( + input_numpy: np.ndarray, + has_batch_dim: bool, + request: pytest.FixtureRequest, +) -> None: if not torch.cuda.is_available(): pytest.skip("CUDA not available") - x = torch.arange(6, dtype=torch.float32, device="cuda").reshape(2, 3) - y = tm.distance_transform(x) + # 1. Prepare NumPy data + x_numpy_contiguous = np.ascontiguousarray(input_numpy) + + # 2. Prepare SciPy input. + # If this is a single sample (has_batch_dim=False), manually add a + # batch dimension so SciPy treats it as one image instead of N 1D + # signals. + if not has_batch_dim: + scipy_input = x_numpy_contiguous[np.newaxis, ...] + else: + scipy_input = x_numpy_contiguous + + # 3. Prepare CUDA input. + # If has_batch_dim=False, the input is (H, W) and we want 2D EDT. + # The C++ API assumes the first dimension is batch, so we must + # unsqueeze(0) to get shape (1, H, W). Otherwise, it will be + # interpreted as (Batch=H, Length=W) and run 1D EDT. + x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + if not has_batch_dim: + x_cuda = x_cuda.unsqueeze(0) + + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}") + + # 4. Run CUDA EDT + dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone()) + + # 5. Run SciPy (ground truth) + dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(scipy_input) + dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + # 6. Validate distances + print( + f"CUDA distance shape: {dist_cuda.shape}, " f"reference shape: {dist_ref.shape}", + ) + assert ( + dist_cuda.shape == dist_ref.shape + ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3) + print(">> Distance validation passed.") + + # 7. Validate indices + # idx_cuda: (B, H, W, D) + spatial_shape = x_cuda.shape[1:] + coords = [torch.arange(s, device="cuda") for s in spatial_shape] + grid = torch.stack(torch.meshgrid(*coords, indexing="ij"), dim=-1) + grid = grid.unsqueeze(0) # (1, H, W, D) + + diff = grid.float() - idx_cuda.float() + dist_sq_calculated = torch.sum(diff * diff, dim=-1) + dist_sq_output = dist_cuda * dist_cuda - expected = x * 2 - torch.testing.assert_close(y, expected) - assert y.device.type == "cuda" - assert y.shape == x.shape - print("tm.foo test passed ✅") + torch.testing.assert_close( + dist_sq_calculated, + dist_sq_output, + atol=1e-3, + rtol=1e-3, + ) + print(">> Index validation passed.") diff --git a/torchmorph/add.py b/torchmorph/add.py index 4737073..f7c16b9 100644 --- a/torchmorph/add.py +++ b/torchmorph/add.py @@ -1,4 +1,5 @@ import torch + from torchmorph import _C diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index 6a57f49..503d13c 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -1,24 +1,469 @@ #include +#include +#include +#include +#include +#include -__global__ void distance_transform_kernel(const float* in, float* out, int64_t N) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < N) { - out[idx] = 2.0f * in[idx]; +// ------------------------------------------------------------------ +// Configuration Constants +// ------------------------------------------------------------------ +#define INF_VAL 1e8f +#define MAX_THREADS 1024 +// Shared memory limit: typically 48 KB. +// Each pixel requires: float(value) + int(idx1) + int(idx2) = 12 bytes. +// 4096 * 12 = 48 KB. +#define SMEM_LIMIT_ELEMENTS 4096 + +// ------------------------------------------------------------------ +// Device Helper Functions +// ------------------------------------------------------------------ + +__device__ __forceinline__ float sqr(float x) { return x * x; } + +// Compute the JFA cost: (q - p)^2 + weight[p] +__device__ __forceinline__ float compute_cost(int q, int p, float val_p) { + if (p < 0) return INF_VAL; + return sqr((float)q - (float)p) + val_p; +} + +// ------------------------------------------------------------------ +// JFA Core Logic (Device Only) +// ------------------------------------------------------------------ +// Core JFA logic, independent of data location (works with both Shared and Global memory). +__device__ void run_jfa_core( + int N, + int tid, + const float* __restrict__ vals, // input weight (read-only) + int* __restrict__ idx_curr, // Ping-Pong Buffer A + int* __restrict__ idx_next // Ping-Pong Buffer B +) { + // 1. Initialization: determine whether each pixel is a valid source based on vals. + for (int i = tid; i < N; i += blockDim.x) { + if (vals[i] >= INF_VAL * 0.9f) { + idx_curr[i] = -1; // background + } else { + idx_curr[i] = i; // For each object/source point, the initial index points to itself. + } + } + __syncthreads(); + + // 2. Iterative Propagation (Step = 1, 2, 4, ... < N) + int* idx_in = idx_curr; + int* idx_out = idx_next; + + for (int step = 1; step < N; step *= 2) { + for (int i = tid; i < N; i += blockDim.x) { + int my_best_p = idx_in[i]; + float min_cost = INF_VAL; + + // Check its current best solution + if (my_best_p != -1) { + min_cost = compute_cost(i, my_best_p, vals[my_best_p]); + } + + // Check Left Neighbor (-step) + int left = i - step; + if (left >= 0) { + int left_p = idx_in[left]; + if (left_p != -1) { + float c = compute_cost(i, left_p, vals[left_p]); + if (c < min_cost) { + min_cost = c; + my_best_p = left_p; + } + } + } + + // Check Right Neighbor (+step) + int right = i + step; + if (right < N) { + int right_p = idx_in[right]; + if (right_p != -1) { + float c = compute_cost(i, right_p, vals[right_p]); + if (c < min_cost) { + min_cost = c; + my_best_p = right_p; + } + } + } + idx_out[i] = my_best_p; + } + + // Swap Pointers + int* temp = idx_in; + idx_in = idx_out; + idx_out = temp; + __syncthreads(); + } + + // 3. Ensure the final result is stored in idx_curr (if the loop ends with idx_next, copy it back). + if (idx_in != idx_curr) { + for (int i = tid; i < N; i += blockDim.x) { + idx_curr[i] = idx_next[i]; + } + __syncthreads(); } } -torch::Tensor distance_transform_cuda(torch::Tensor input) { - auto output = torch::empty_like(input); - int64_t N = input.numel(); - int threads = 256; - int blocks = (N + threads - 1) / threads; +// ------------------------------------------------------------------ +// Kernel 1: Shared Memory JFA (Fast Path) +// ------------------------------------------------------------------ +// Template parameter NDim: when NDim > 0, the compiler performs loop unrolling optimizations. +// Runtime parameter runtime_ndim: when NDim == 0 (default behavior), this parameter specifies the dimension. +template +__global__ void edt_kernel_shared( + const float* __restrict__ in_data, // input Dist^2 + const int32_t* __restrict__ in_indices, // output Indices + float* __restrict__ out_dist, // output Dist (IsFinal ? sqrt : sqr) + int32_t* __restrict__ out_indices, // output Indices + int64_t L, // Size of the current dimension + int64_t total_elements, // Total number of elements + int runtime_ndim // Runtime dimension (used as fallback) +) { + // Determine the effective dimension + const int D = (NDim > 0) ? NDim : runtime_ndim; + + // Compute row offset + int64_t row_idx = blockIdx.x; + int64_t offset = row_idx * L; + + if (offset >= total_elements) return; + + // Shared memory layout + extern __shared__ char s_buffer[]; + float* s_vals = (float*)s_buffer; + int* s_idx1 = (int*)(s_vals + L); + int* s_idx2 = (int*)(s_idx1 + L); + + // 1. Load distances into Shared Memory + for (int i = threadIdx.x; i < L; i += blockDim.x) { + s_vals[i] = __ldg(&in_data[offset + i]); + } + __syncthreads(); + + // 2. Run the core JFA logic + run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); + + // 3. Write back the results + for (int q = threadIdx.x; q < L; q += blockDim.x) { + int p = s_idx1[q]; // Nearest point (local index within 0..L-1) + float dist_val; + + // Compute updated distance + if (p != -1) { + float dist_sq = sqr((float)q - (float)p) + s_vals[p]; + dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; + } else { + // No source point found (e.g., entire row is background) + dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); + p = 0; // Prevent out-of-bounds access + } + out_dist[offset + q] = dist_val; + + // Propagate indices: copy a vector of size [D] + if (p != -1) { + int64_t src_offset = (offset + p) * D; + int64_t dst_offset = (offset + q) * D; + + // When NDim > 0, this loop is fully unrolled by the compiler + for (int d = 0; d < D; ++d) { + out_indices[dst_offset + d] = in_indices[src_offset + d]; + } + } else { + // Fallback: no source available + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; + } + } +} + +// ------------------------------------------------------------------ +// Kernel 2: Global Memory JFA (Fallback Path) +// ------------------------------------------------------------------ +// Same logic as above, but uses Global Memory as the ping-pong buffer +template +__global__ void edt_kernel_global( + const float* __restrict__ in_data, + const int32_t* __restrict__ in_indices, + float* __restrict__ out_dist, + int32_t* __restrict__ out_indices, + int* __restrict__ global_buffer_1, + int* __restrict__ global_buffer_2, + int64_t L, + int64_t total_elements, + int runtime_ndim +) { + const int D = (NDim > 0) ? NDim : runtime_ndim; + + int64_t row_idx = blockIdx.x; + int64_t offset = row_idx * L; + + if (offset >= total_elements) return; + + // Pointers to Global Memory + int* g_idx1 = global_buffer_1 + offset; + int* g_idx2 = global_buffer_2 + offset; + + // 1. & 2. Run the JFA core (operating directly on Global Memory) + run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); + + // 3. Write back results + for (int q = threadIdx.x; q < L; q += blockDim.x) { + int p = g_idx1[q]; + float dist_val; - distance_transform_kernel<<>>( - input.data_ptr(), - output.data_ptr(), - N - ); + if (p != -1) { + float val_p = in_data[offset + p]; + float dist_sq = sqr((float)q - (float)p) + val_p; + dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; + } else { + dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); + p = 0; + } + + out_dist[offset + q] = dist_val; + + if (p != -1) { + int64_t src_offset = (offset + p) * D; + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) { + out_indices[dst_offset + d] = in_indices[src_offset + d]; + } + } else { + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; + } + } +} + + +// ------------------------------------------------------------------ +// Kernel 3: Initialize Indices +// ------------------------------------------------------------------ +// Initialize index tensor as grid coordinates +// indices shape: (..., D) +__global__ void init_indices_kernel( + int32_t* indices, + int64_t total_pixels, + int NDim, + const int64_t* __restrict__ shape_ptr // shape of spatial dimensions +) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_pixels) return; + + // Unravel Index + // idx is the flat index of each pixel + // We need to compute its coordinate in spatial_shape + + int64_t temp = idx; + // Use local register array to avoid repeated global memory reads (assume max 10 dims) + int32_t coords[10]; + + // Example: spatial_shape = [D0, D1, D2] + // compute by modulo from last dimension + for (int d = NDim - 1; d >= 0; --d) { + int64_t dim_size = shape_ptr[d]; + coords[d] = temp % dim_size; + temp /= dim_size; + } + + // Write to Global Memory + // Indices tensor is flattened as (TotalPixels, NDim) + int64_t out_ptr = idx * NDim; + for (int d = 0; d < NDim; ++d) { + indices[out_ptr + d] = coords[d]; + } +} + +// ------------------------------------------------------------------ +// Host Function: C++ Entry Point +// ------------------------------------------------------------------ + +std::tuple distance_transform_cuda(torch::Tensor input) { + TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device."); + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32."); + + input = input.contiguous(); + + // Handle batch dimension: if input is 1D (L), treat as no batch but internally add a batch dimension. + // Convention: input shape is (Batch, D1, D2, ..., Dn) + // Algorithm treats batch and other dims identically (batch is just another leading dimension) + // But index initialization needs to know which are "spatial dimensions". + // Assumption: all dims except dim 0 (Batch) are spatial. + + const int ndim = input.dim(); + // If ndim=1, assume (L) -> sample_ndim=1 + // If ndim=4 (B, C, H, W), sample_ndim=3 (C,H,W treated as spatial? Channels often processed independently) + // Correction: classical EDT usually runs on (H,W) or (D,H,W). + // If channels exist, typically each channel is processed independently. + // For maximum generality, we treat **all dims except dim 0** as spatial dims. + // If input has no batch dim, user should use unsqueeze(0) in Python. + + const int sample_ndim = ndim - 1; + TORCH_CHECK(sample_ndim > 0, "Input tensor must have at least 2 dimensions (Batch, ...)"); + + auto shape = input.sizes().vec(); + int64_t num_pixels = input.numel(); + + if (num_pixels == 0) { + auto index_shape = shape; + index_shape.push_back(sample_ndim); + return std::make_tuple(torch::empty_like(input), + torch::empty(index_shape, input.options().dtype(torch::kInt32))); + } + + // 1. Initialize Distance Tensor + // 0 -> 0, 1 -> INF + auto current_dist = torch::where(input == 0, + torch::tensor(0.0f, input.options()), + torch::tensor(INF_VAL, input.options())); + + // 2. Initialize Index Tensor + // Shape: (Batch, D1, ..., Dn, sample_ndim) + auto index_shape = shape; + index_shape.push_back(sample_ndim); + auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); + + // 2.1 Prepare shape tensor for kernel + std::vector spatial_shape(shape.begin() + 1, shape.end()); + auto shape_tensor = torch::tensor(spatial_shape, torch::kInt64).to(input.device()); + + // 2.2 Launch initialization kernel + { + int threads = 256; + int blocks = (num_pixels + threads - 1) / threads; + init_indices_kernel<<>>( + current_idx.data_ptr(), + num_pixels, + sample_ndim, + shape_tensor.data_ptr() + ); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Init Kernel Failed: %s\n", cudaGetErrorString(err)); + } + } + + // Pre-allocate Global Memory Buffers (lazy) + torch::Tensor global_buf1, global_buf2; + + // 3. Process each spatial dimension (Separable JFA) + // Iterate through each spatial dimension (1 to ndim-1) + for (int d = 1; d < ndim; ++d) { + bool is_final_pass = (d == ndim - 1); + + // --- Step A: Transpose current dim to last --- + // Resulting shape: (..., L) + auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); + auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); + + int64_t L = dist_in.size(-1); + int64_t total_slices = dist_in.numel() / L; + + auto dist_out = torch::empty_like(dist_in); + auto idx_out = torch::empty_like(idx_in); + + // --- Step B: Kernel Dispatch --- + int threads = std::min((int64_t)MAX_THREADS, L); + + // Check whether Shared Memory can be used + if (L <= SMEM_LIMIT_ELEMENTS) { + size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); + + // Switch macro to handle template dimension specialization + #define DISPATCH_SHARED(IS_FINAL) \ + switch(sample_ndim) { \ + case 1: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 2: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 3: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 4: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 5: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 6: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + default: /* Fallback for > 6D */ \ + edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + } + + if (is_final_pass) { DISPATCH_SHARED(true); } + else { DISPATCH_SHARED(false); } + + } else { + // Global Memory fallback (L > 4096) + if (global_buf1.numel() < dist_in.numel()) { + global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); + global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); + } + + #define DISPATCH_GLOBAL(IS_FINAL) \ + switch(sample_ndim) { \ + case 1: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 2: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 3: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 4: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 5: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 6: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + default: /* Fallback */ \ + edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + } + + if (is_final_pass) { DISPATCH_GLOBAL(true); } + else { DISPATCH_GLOBAL(false); } + } + + // --- Step C: Transpose Back --- + current_dist = dist_out.transpose(d, ndim - 1); + current_idx = idx_out.transpose(d, ndim - 1); + } - return output; + return std::make_tuple(current_dist, current_idx); } diff --git a/torchmorph/csrc/torchmorph.cpp b/torchmorph/csrc/torchmorph.cpp index 5d1dae8..c79970c 100644 --- a/torchmorph/csrc/torchmorph.cpp +++ b/torchmorph/csrc/torchmorph.cpp @@ -2,10 +2,9 @@ // Declare CUDA implementations torch::Tensor add_cuda(torch::Tensor input, float scalar); -torch::Tensor distance_transform_cuda(torch::Tensor input); +std::tuple distance_transform_cuda(torch::Tensor input); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("add_cuda", &add_cuda, "Add tensor with scalar"); m.def("distance_transform_cuda", &distance_transform_cuda, "Distance transform"); } - diff --git a/torchmorph/distance_transform.py b/torchmorph/distance_transform.py index e4b54db..868e84a 100644 --- a/torchmorph/distance_transform.py +++ b/torchmorph/distance_transform.py @@ -1,4 +1,5 @@ import torch + from torchmorph import _C @@ -6,4 +7,10 @@ def distance_transform(input: torch.Tensor) -> torch.Tensor: """Distance Transform in CUDA.""" if not input.is_cuda: raise ValueError("Input tensor must be on CUDA device.") + if input.ndim < 2 or input.numel() == 0: + raise ValueError(f"Invalid input dimension: {input.shape}.") + + # binarize input + input[input != 0] = 1 + return _C.distance_transform_cuda(input)