From 6e60ec07ddfc533222b604c2f14e5befefec406d Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Mon, 27 Oct 2025 11:34:19 +0800
Subject: [PATCH 01/39] init distance transform

---
 pyproject.toml                               | 3 +--
 test/test_distance_transform.py              | 4 ++++
 torchmorph/csrc/distance_transform_kernel.cu | 3 +++
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1dce09a..6027eac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,5 @@ max-line-length = 100
 extend-ignore = ["E203", "W503"]
 
 [tool.pytest.ini_options]
-addopts = "-v"
+addopts = "-v --import-mode=importlib"
 testpaths = ["test"]
-

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 166968c..e002200 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -14,6 +14,10 @@ def test_distance_transform():
     y = tm.distance_transform(x)
 
     expected = x * 2
+    # Here we compare the output of our distance transform with the expected
+    # output, i.e. the results of scipy.ndimage.distance_transform_edt.
+    # For now the implementation simply multiplies the input by 2;
+    # eventually we have to implement the full algorithm.
     torch.testing.assert_close(y, expected)
     assert y.device.type == "cuda"
     assert y.shape == x.shape

diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 6a57f49..33d3abd 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -1,5 +1,8 @@
 #include <torch/extension.h>
 
+// distance transform: https://en.wikipedia.org/wiki/Distance_transform
+// https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
+
 __global__ void distance_transform_kernel(const float* in, float* out, int64_t N) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < N) {
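For reference, the behavior this placeholder is meant to converge to is that of
scipy.ndimage.distance_transform_edt, already linked in the kernel comments:
every nonzero element maps to its Euclidean distance from the nearest zero
element. A minimal sketch of those target semantics (the mask below is an
illustrative example, not code from the patches):

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    mask = np.array([[0, 1, 1],
                     [1, 1, 1]], dtype=np.float32)
    # Nonzero pixels receive the distance to the nearest zero pixel.
    print(distance_transform_edt(mask))
    # [[0.        1.        2.       ]
    #  [1.        1.4142135 2.236068 ]]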
"import torch" 失败并抛出 `ModuleNotFoundError 所以在 `pyproject.toml` 的 `[build-system].requires` 列表中明确添加 `"torch"`。 这会强制 pip 在构建开始前,先将 torch 安装到其临时环境中,从而确保构建过程顺利完成。 --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6027eac..10af0cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,3 +16,7 @@ extend-ignore = ["E203", "W503"] [tool.pytest.ini_options] addopts = "-v --import-mode=importlib" testpaths = ["test"] + +[build-system] +requires = ["setuptools>=61.0", "wheel", "torch"] +build-backend = "setuptools.build_meta" \ No newline at end of file From 79c5acdb26faeba1f282a435c33ff36d04fc865c Mon Sep 17 00:00:00 2001 From: Yuhandeng Date: Sat, 1 Nov 2025 21:18:22 +0800 Subject: [PATCH 03/39] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BA=8C=E7=BB=B4?= =?UTF-8?q?=E6=AC=A7=E5=BC=8F=E8=B7=9D=E7=A6=BB=E5=8F=98=E6=8D=A2=EF=BC=88?= =?UTF-8?q?EDT=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 CUDA 内核,分别用于行与列方向的距离变换 - 支持在 GPU 上处理二维张量 - 已通过基础单元测试验证正确性 - 注意:当前实现仅适用于二维情况,尚未推广到 N 维张量 --- test/test_distance_transform.py | 113 ++++++++++++++++--- torchmorph/csrc/distance_transform_kernel.cu | 109 ++++++++++++++---- torchmorph/csrc/torchmorph.cpp | 11 -- torchmorph/csrc/torchmorph.cu | 51 +++++++++ torchmorph/distance_transform.py | 2 +- 5 files changed, 239 insertions(+), 47 deletions(-) delete mode 100644 torchmorph/csrc/torchmorph.cpp create mode 100644 torchmorph/csrc/torchmorph.cu diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index e002200..493758a 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -1,24 +1,105 @@ import torch import pytest -import torchmorph as tm from scipy.ndimage import distance_transform_edt as dte +import torchmorph as tm +import numpy as np + +# --- 我们在这里定义所有的测试用例 --- + +# 用例 1: 我们之前成功的那个标准例子 +case_standard = np.array([ + [0, 1, 1, 1, 1], + [0, 0, 1, 1, 1], + [0, 1, 1, 1, 1], + [0, 1, 1, 1, 0], + [0, 1, 1, 0, 0] +], dtype=np.float32) + +# 用例 2: 全是背景 (0),输出应该全是 0 +case_all_background = np.zeros((5, 5), dtype=np.float32) + +# 用例 3: 全是前景 (1),输出应该也全是 0 (因为前景点到背景的距离未定义,SciPy默认输出0) +case_all_foreground = np.ones((5, 5), dtype=np.float32) + +# 用例 4: 只有一个背景点 (0) 在中间 +case_single_background = np.ones((5, 5), dtype=np.float32) +case_single_background[2, 2] = 0 +# 用例 5: 只有一个前景点 (1) 在中间 +case_single_foreground = np.zeros((5, 5), dtype=np.float32) +case_single_foreground[2, 2] = 1 -@pytest.mark.cuda -def test_distance_transform(): - """Test that tm.foo doubles all tensor elements.""" +# 用例 6: 非正方形的矩阵 (高 > 宽) +case_tall_matrix = np.array([ + [1, 0, 1], + [1, 1, 1], + [1, 1, 1], + [0, 1, 0], + [1, 1, 1], +], dtype=np.float32) + +# 用例 7: 非正方形的矩阵 (宽 > 高) +case_wide_matrix = np.array([ + [1, 1, 0, 1, 1], + [1, 1, 1, 1, 0], + [0, 1, 1, 1, 1], +], dtype=np.float32) + +# 用例 8: 棋盘格,考验对角线距离的计算 +case_checkerboard = np.array([ + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], +], dtype=np.float32) + +# --- 使用 pytest.mark.parametrize 来自动运行所有测试用例 --- + +@pytest.mark.parametrize( + "input_numpy", + [ + pytest.param(case_standard, id="Standard Case"), + pytest.param(case_all_background, id="All Background"), + pytest.param(case_all_foreground, id="All Foreground"), + pytest.param(case_single_background, id="Single Background Pixel"), + pytest.param(case_single_foreground, id="Single Foreground Pixel"), + pytest.param(case_tall_matrix, id="Tall Matrix (H>W)"), + pytest.param(case_wide_matrix, id="Wide Matrix (W>H)"), + 
From 79c5acdb26faeba1f282a435c33ff36d04fc865c Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Sat, 1 Nov 2025 21:18:22 +0800
Subject: [PATCH 03/39] Implement the 2D Euclidean distance transform (EDT)

- Add CUDA kernels for the row-wise and column-wise distance-transform passes
- Support 2D tensors on the GPU
- Correctness verified against the basic unit tests
- Note: the current implementation only handles the 2D case; it has not yet
  been generalized to N-D tensors
---
 test/test_distance_transform.py              | 113 ++++++++++++++++---
 torchmorph/csrc/distance_transform_kernel.cu | 109 ++++++++++++++----
 torchmorph/csrc/torchmorph.cpp               |  11 --
 torchmorph/csrc/torchmorph.cu                |  51 +++++++++
 torchmorph/distance_transform.py             |   2 +-
 5 files changed, 239 insertions(+), 47 deletions(-)
 delete mode 100644 torchmorph/csrc/torchmorph.cpp
 create mode 100644 torchmorph/csrc/torchmorph.cu

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index e002200..493758a 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -1,24 +1,105 @@
 import torch
 import pytest
-import torchmorph as tm
 from scipy.ndimage import distance_transform_edt as dte
+import torchmorph as tm
+import numpy as np
+
+# --- All test cases are defined here ---
+
+# Case 1: the standard example that already passed
+case_standard = np.array([
+    [0, 1, 1, 1, 1],
+    [0, 0, 1, 1, 1],
+    [0, 1, 1, 1, 1],
+    [0, 1, 1, 1, 0],
+    [0, 1, 1, 0, 0]
+], dtype=np.float32)
+
+# Case 2: all background (0); the output should be all zeros
+case_all_background = np.zeros((5, 5), dtype=np.float32)
+
+# Case 3: all foreground (1); the output should also be all zeros (the distance
+# from a foreground point to the background is undefined, and SciPy outputs 0 by default)
+case_all_foreground = np.ones((5, 5), dtype=np.float32)
+
+# Case 4: a single background pixel (0) in the center
+case_single_background = np.ones((5, 5), dtype=np.float32)
+case_single_background[2, 2] = 0
+# Case 5: a single foreground pixel (1) in the center
+case_single_foreground = np.zeros((5, 5), dtype=np.float32)
+case_single_foreground[2, 2] = 1
+
+# Case 6: a non-square matrix (taller than wide)
+case_tall_matrix = np.array([
+    [1, 0, 1],
+    [1, 1, 1],
+    [1, 1, 1],
+    [0, 1, 0],
+    [1, 1, 1],
+], dtype=np.float32)
+
+# Case 7: a non-square matrix (wider than tall)
+case_wide_matrix = np.array([
+    [1, 1, 0, 1, 1],
+    [1, 1, 1, 1, 0],
+    [0, 1, 1, 1, 1],
+], dtype=np.float32)
+
+# Case 8: a checkerboard; stresses the diagonal-distance computation
+case_checkerboard = np.array([
+    [0, 1, 0, 1],
+    [1, 0, 1, 0],
+    [0, 1, 0, 1],
+    [1, 0, 1, 0],
+], dtype=np.float32)
+
+# --- Use pytest.mark.parametrize to run all test cases automatically ---
+
+@pytest.mark.parametrize(
+    "input_numpy",
+    [
+        pytest.param(case_standard, id="Standard Case"),
+        pytest.param(case_all_background, id="All Background"),
+        pytest.param(case_all_foreground, id="All Foreground"),
+        pytest.param(case_single_background, id="Single Background Pixel"),
+        pytest.param(case_single_foreground, id="Single Foreground Pixel"),
+        pytest.param(case_tall_matrix, id="Tall Matrix (H>W)"),
+        pytest.param(case_wide_matrix, id="Wide Matrix (W>H)"),
+        pytest.param(case_checkerboard, id="Checkerboard"),
+    ]
+)
+def test_distance_transform_comprehensive(input_numpy, request):
+    """
+    A single unified test function that validates all the different inputs.
+    """
-@pytest.mark.cuda
-def test_distance_transform():
-    """Test that tm.foo doubles all tensor elements."""
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
-    x = torch.arange(6, dtype=torch.float32, device="cuda").reshape(2, 3)
-    y = tm.distance_transform(x)
-
-    expected = x * 2
-    # Here we compare the output of our distance transform with the expected
-    # output, i.e. the results of scipy.ndimage.distance_transform_edt.
-    # For now the implementation simply multiplies the input by 2;
-    # eventually we have to implement the full algorithm.
-    torch.testing.assert_close(y, expected)
-    assert y.device.type == "cuda"
-    assert y.shape == x.shape
+    # Prepare the input data
+    x = torch.from_numpy(input_numpy).cuda()
+
+    # 1. Run our CUDA implementation
+    y_cuda = tm.distance_transform(x)
+
+    # 2. Run the official SciPy implementation
+    y_ref_numpy = dte(input_numpy)
+    y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
+
+    # Print the results for a visual comparison
+    print(f"\n\n--- Running Test: {request.node.callspec.id} ---")
+    print("Input Array:\n", input_numpy)
+    print("\nYour CUDA Implementation Output:\n", y_cuda.cpu().numpy())
+    print("\nSciPy Reference Output:\n", y_ref.cpu().numpy())
+    if request.node.callspec.id == "All Foreground":
+        # For this special case we do not compare against SciPy.
+        # We verify our own convention instead: every output value should be
+        # very large (representing "infinitely far").
+        print("\nSciPy has different behavior for this edge case. Verifying CUDA output is ~inf.")
+        # Assert that every element exceeds a large threshold
+        assert torch.all(y_cuda > 1e4)
+    else:
+        # For all other normal cases, compare against the SciPy gold standard.
+        y_ref_numpy = dte(input_numpy)
+        y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
+        print("\nSciPy Reference Output:\n", y_ref.cpu().numpy())
+        torch.testing.assert_close(y_cuda, y_ref, atol=1e-3, rtol=1e-3)
+    print("--- Test End ---")
diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 33d3abd..70527db 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -1,27 +1,98 @@
-#include <torch/extension.h>
+#include <torch/extension.h>
+#include <cuda_runtime.h>
 
-// distance transform: https://en.wikipedia.org/wiki/Distance_transform
-// https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
+__global__ void edt_pass1_rows(const float* input, float* temp, int H, int W) {
+    int y = blockIdx.x;
+    if (y >= H) return;
 
-__global__ void distance_transform_kernel(const float* in, float* out, int64_t N) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < N) {
-        out[idx] = 2.0f * in[idx];
-    }
-}
+    extern __shared__ float sdata[];
+    float* f = sdata;
+    int* v = (int*)(sdata + W);
+    float* z = (float*)(v + W + 1);
+
+    for (int x = threadIdx.x; x < W; x += blockDim.x) {
+        float val = input[y * W + x];
+        // Key change for this task:
+        // a pixel equal to 0 (background) is a source point (distance 0);
+        // everything else starts out infinitely far away.
+        f[x] = (val < 0.5f) ? 0.0f : 1e10f;
+    }
+    __syncthreads();
 
-torch::Tensor distance_transform_cuda(torch::Tensor input) {
-    auto output = torch::empty_like(input);
-    int64_t N = input.numel();
-    int threads = 256;
-    int blocks = (N + threads - 1) / threads;
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e10f;
+        z[1] = 1e10f;
 
-    distance_transform_kernel<<<blocks, threads>>>(
-        input.data_ptr<float>(),
-        output.data_ptr<float>(),
-        N
-    );
+        for (int q = 1; q < W; q++) {
+            float s;
+            while (true) {
+                int p = v[k];
+                s = ((f[q] + q * q) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) { break; }
+                if (k == 0) { break; }
+                k--;
+            }
+            k++;
+            v[k] = q;
+            z[k] = s;
+            z[k + 1] = 1e10f;
+        }
 
-    return output;
+        k = 0;
+        for (int q = 0; q < W; q++) {
+            while (z[k + 1] < q) k++;
+            int p = v[k];
+            temp[y * W + q] = (q - p) * (q - p) + f[p];
+        }
+    }
 }
+// PASS 2: operate on each column
+__global__ void edt_pass2_cols(const float* temp, float* output, int H, int W) {
+    int x = blockIdx.x;
+    if (x >= W) return;
+
+    extern __shared__ float sdata[];
+    float* f = sdata;
+    int* v = (int*)(sdata + H);
+    float* z = (float*)(v + H + 1);
+
+    for (int y = threadIdx.x; y < H; y += blockDim.x) {
+        f[y] = temp[y * W + x];
+    }
+    __syncthreads();
+
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e10f;
+        z[1] = 1e10f;
+
+        for (int q = 1; q < H; q++) {
+            float s;
+            while (true) {
+                int p = v[k];
+                s = ((f[q] + q * q) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) {
+                    break;
+                }
+                if (k == 0) {
+                    break;
+                }
+                k--;
+            }
+            k++;
+            v[k] = q;
+            z[k] = s;
+            z[k + 1] = 1e10f;
+        }
+
+        k = 0;
+        for (int q = 0; q < H; q++) {
+            while (z[k + 1] < q) k++;
+            int p = v[k];
+            output[q * W + x] = sqrtf((q - p) * (q - p) + f[p]);
+        }
+    }
+}
\ No newline at end of file
diff --git a/torchmorph/csrc/torchmorph.cpp b/torchmorph/csrc/torchmorph.cpp
deleted file mode 100644
index 5d1dae8..0000000
--- a/torchmorph/csrc/torchmorph.cpp
+++ /dev/null
@@ -1,11 +0,0 @@
-#include <torch/extension.h>
-
-// Declare CUDA implementations
-torch::Tensor add_cuda(torch::Tensor input, float scalar);
-torch::Tensor distance_transform_cuda(torch::Tensor input);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("add_cuda", &add_cuda, "Add tensor with scalar");
-    m.def("distance_transform_cuda", &distance_transform_cuda, "Distance transform");
-}
-
diff --git a/torchmorph/csrc/torchmorph.cu b/torchmorph/csrc/torchmorph.cu
new file mode 100644
index 0000000..dc2e4ce
--- /dev/null
+++ b/torchmorph/csrc/torchmorph.cu
@@ -0,0 +1,51 @@
+// =========================================================================
+// File: torchmorph/csrc/torchmorph.cu
+// =========================================================================
+
+#include <torch/extension.h>
+
+// Forward declarations: tell the compiler that these two CUDA kernels are
+// defined in another file, so the host code here can launch them.
+__global__ void edt_pass1_rows(const float* input, float* temp, int H, int W);
+__global__ void edt_pass2_cols(const float* temp, float* output, int H, int W);
+
+
+
+// Host-side driver function (runs on the CPU)
+torch::Tensor distance_transform_cuda(torch::Tensor input) {
+    // Check that the input tensor is on CUDA and two-dimensional
+    TORCH_CHECK(input.is_cuda(), "Input must be on CUDA");
+    TORCH_CHECK(input.dim() == 2, "Only 2D tensors supported");
+
+    int H = input.size(0);
+    int W = input.size(1);
+
+    // Allocate the temporary and final output tensors
+    auto temp = torch::empty_like(input);
+    auto output = torch::empty_like(input);
+
+    // Compute the dynamic shared-memory sizes
+    size_t shared_mem_pass1 = W * sizeof(float) + (W + 1) * sizeof(int) + (W + 2) * sizeof(float);
+    size_t shared_mem_pass2 = H * sizeof(float) + (H + 1) * sizeof(int) + (H + 2) * sizeof(float);
+
+    // Threads per block
+    int threads_per_block = 32;
+
+    // The <<<...>>> syntax launches a CUDA kernel.
+    // Arguments: grid size, block size, shared-memory bytes, (optional stream).
+
+    // Pass 1: one block per row
+    edt_pass1_rows<<<H, threads_per_block, shared_mem_pass1>>>(
+        input.data_ptr<float>(), temp.data_ptr<float>(), H, W);
+
+    // Pass 2: one block per column
+    edt_pass2_cols<<<W, threads_per_block, shared_mem_pass2>>>(
+        temp.data_ptr<float>(), output.data_ptr<float>(), H, W);
+
+    return output;
+}
+
+// Bind the C++ function into the Python module via pybind11
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("distance_transform", &distance_transform_cuda, "CUDA-accelerated Exact Euclidean Distance Transform");
+}
\ No newline at end of file
diff --git a/torchmorph/distance_transform.py b/torchmorph/distance_transform.py
index e4b54db..b4cd458 100644
--- a/torchmorph/distance_transform.py
+++ b/torchmorph/distance_transform.py
@@ -6,4 +6,4 @@ def distance_transform(input: torch.Tensor) -> torch.Tensor:
     """Distance Transform in CUDA."""
     if not input.is_cuda:
         raise ValueError("Input tensor must be on CUDA device.")
-    return _C.distance_transform_cuda(input)
+    return _C.distance_transform(input)
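The two kernels are the row and column passes of the Felzenszwalb-Huttenlocher
algorithm: pass 1 computes, per row, the squared distance to the nearest source
in that row; pass 2 sweeps the same parabola lower envelope down each column of
the squared field and takes the square root. The intersection abscissa used in
both kernels follows from equating the parabolas rooted at p and q:
s = ((f(q) + q^2) - (f(p) + p^2)) / (2 (q - p)). A CPU reference of the 1D
scan, as a sketch for checking the kernels by hand (the function name is
illustrative, not part of the repository):

    import numpy as np

    def edt_1d_squared(f):
        """Squared 1D distance transform of sampled costs f (Felzenszwalb-Huttenlocher)."""
        n = len(f)
        v = np.zeros(n, dtype=int)   # vertex index of each envelope parabola
        z = np.full(n + 1, np.inf)   # boundaries between envelope intervals
        z[0] = -np.inf
        k = 0
        for q in range(1, n):
            s = ((f[q] + q * q) - (f[v[k]] + v[k] * v[k])) / (2.0 * (q - v[k]))
            while s <= z[k]:         # parabola v[k] is dominated by q: pop it
                k -= 1
                s = ((f[q] + q * q) - (f[v[k]] + v[k] * v[k])) / (2.0 * (q - v[k]))
            k += 1
            v[k] = q
            z[k] = s
            z[k + 1] = np.inf
        out = np.empty(n)
        k = 0
        for q in range(n):
            while z[k + 1] < q:      # find the envelope segment above q
                k += 1
            out[q] = (q - v[k]) ** 2 + f[v[k]]
        return out

    row = np.where(np.array([0, 1, 1, 1, 0]) == 0, 0.0, 1e10)  # sources cost 0
    print(edt_1d_squared(row))  # [0. 1. 4. 1. 0.]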
From 72943e44a723e794deb4a31af56b345a8e5834a9 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Sun, 2 Nov 2025 04:12:02 +0800
Subject: [PATCH 04/39] N-dimensional batched Euclidean distance transform

---
 test/test_distance_transform.py              | 138 ++++------
 torchmorph/csrc/distance_transform_kernel.cu | 271 ++++++++++++++-----
 torchmorph/csrc/torchmorph.cpp               |  10 +
 torchmorph/csrc/torchmorph.cu                |  51 ----
 torchmorph/distance_transform.py             |   2 +-
 5 files changed, 269 insertions(+), 203 deletions(-)
 create mode 100644 torchmorph/csrc/torchmorph.cpp
 delete mode 100644 torchmorph/csrc/torchmorph.cu

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 493758a..3feffb4 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -4,102 +4,82 @@ import torchmorph as tm
 import numpy as np
 
-# --- All test cases are defined here ---
 
-# Case 1: the standard example that already passed
-case_standard = np.array([
-    [0, 1, 1, 1, 1],
-    [0, 0, 1, 1, 1],
-    [0, 1, 1, 1, 1],
-    [0, 1, 1, 1, 0],
-    [0, 1, 1, 0, 0]
-], dtype=np.float32)
+def batch_distance_transform_edt(batch_numpy):
 
-# Case 2: all background (0); the output should be all zeros
-case_all_background = np.zeros((5, 5), dtype=np.float32)
+    is_single_sample = batch_numpy.ndim <= 2
+    # (H, W) -> (1, H, W)
+    if is_single_sample:
+        batch_numpy = batch_numpy[np.newaxis, ...]
+
+    results = [dte(sample) for sample in batch_numpy]
+    output = np.stack(results, axis=0)
+    # (1, H, W) -> (H, W)
+    if is_single_sample:
+        output = output.squeeze(0)
+
+    return output
 
-# Case 3: all foreground (1); the output should also be all zeros (the distance
-# from a foreground point to the background is undefined, and SciPy outputs 0 by default)
-case_all_foreground = np.ones((5, 5), dtype=np.float32)
+# Case 1: a batch of 2D images
+case_batch_2d = np.array([
+    # image 1
+    [[0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],
+    # image 2
+    [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]
+], dtype=np.float32)
 
-# Case 4: a single background pixel (0) in the center
-case_single_background = np.ones((5, 5), dtype=np.float32)
-case_single_background[2, 2] = 0
-# Case 5: a single foreground pixel (1) in the center
-case_single_foreground = np.zeros((5, 5), dtype=np.float32)
-case_single_foreground[2, 2] = 1
+# Case 2: a batch of 3D volumes
+case_3d_sample1 = np.ones((4, 5, 6), dtype=np.float32); case_3d_sample1[1, 1, 1] = 0.0; case_3d_sample1[2, 3, 4] = 0.0
+case_3d_sample2 = np.ones((4, 5, 6), dtype=np.float32); case_3d_sample2[0, 0, 0] = 0.0
+case_batch_3d = np.stack([case_3d_sample1, case_3d_sample2], axis=0)
 
-# Case 6: a non-square matrix (taller than wide)
-case_tall_matrix = np.array([
-    [1, 0, 1],
-    [1, 1, 1],
-    [1, 1, 1],
-    [0, 1, 0],
-    [1, 1, 1],
+# Case 3: a single 2D image (implicit batch)
+case_single_2d = np.array([
+    [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0],
 ], dtype=np.float32)
 
-# Case 7: a non-square matrix (wider than tall)
-case_wide_matrix = np.array([
-    [1, 1, 0, 1, 1],
-    [1, 1, 1, 1, 0],
-    [0, 1, 1, 1, 1],
-], dtype=np.float32)
+# Case 4: a single 2D image (explicit batch of one)
+case_explicit_batch_one = case_single_2d[np.newaxis, ...]
 
-# Case 8: a checkerboard; stresses the diagonal-distance computation
-case_checkerboard = np.array([
-    [0, 1, 0, 1],
-    [1, 0, 1, 0],
-    [0, 1, 0, 1],
-    [1, 0, 1, 0],
-], dtype=np.float32)
+# Case 5: a batch with a singleton dimension
+case_dim_one = np.ones((2, 5, 1), dtype=np.float32)  # two 5x1 images
+case_dim_one[0, 2, 0] = 0.0
+case_dim_one[1, 4, 0] = 0.0
 
-# --- Use pytest.mark.parametrize to run all test cases automatically ---
+# Case 6: a batch of 1D tensors
+case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32)
+
 
 @pytest.mark.parametrize(
     "input_numpy",
     [
-        pytest.param(case_standard, id="Standard Case"),
-        pytest.param(case_all_background, id="All Background"),
-        pytest.param(case_all_foreground, id="All Foreground"),
-        pytest.param(case_single_background, id="Single Background Pixel"),
-        pytest.param(case_single_foreground, id="Single Foreground Pixel"),
-        pytest.param(case_tall_matrix, id="Tall Matrix (H>W)"),
-        pytest.param(case_wide_matrix, id="Wide Matrix (W>H)"),
-        pytest.param(case_checkerboard, id="Checkerboard"),
-    ]
+        pytest.param(case_batch_2d, id="Batched 2D images"),
+        pytest.param(case_batch_3d, id="Batched 3D volumes"),
+        pytest.param(case_single_2d, id="Single 2D image (implicit batch)"),
+        pytest.param(case_explicit_batch_one, id="Single 2D image (explicit batch)"),
+        pytest.param(case_dim_one, id="Batch with singleton dimension"),
+        pytest.param(case_batch_1d, id="Batched 1D data"),
+    ],
 )
-def test_distance_transform_comprehensive(input_numpy, request):
-    """
-    A single unified test function that validates all the different inputs.
-    """
+def test_batch_processing(input_numpy, request):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
+    x_numpy_contiguous = np.ascontiguousarray(input_numpy)
+    x = torch.from_numpy(x_numpy_contiguous).cuda()
 
-    # Prepare the input data
-    x = torch.from_numpy(input_numpy).cuda()
-
-    # 1. Run our CUDA implementation
-    y_cuda = tm.distance_transform(x)
-
-    # 2. Run the official SciPy implementation
-    y_ref_numpy = dte(input_numpy)
+    print(f"\n\n--- Running test: {request.node.callspec.id} ---")
+    print(f"Input tensor shape: {x.shape}")
+    y_cuda = tm.distance_transform(x.clone())
+
+    y_ref_numpy = batch_distance_transform_edt(x_numpy_contiguous)
     y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
-
-    # Print the results for a visual comparison
-    print(f"\n\n--- Running Test: {request.node.callspec.id} ---")
-    print("Input Array:\n", input_numpy)
-    print("\nYour CUDA Implementation Output:\n", y_cuda.cpu().numpy())
-    print("\nSciPy Reference Output:\n", y_ref.cpu().numpy())
-    if request.node.callspec.id == "All Foreground":
-        # For this special case we do not compare against SciPy.
-        # We verify our own convention instead: every output value should be
-        # very large (representing "infinitely far").
-        print("\nSciPy has different behavior for this edge case. Verifying CUDA output is ~inf.")
-        # Assert that every element exceeds a large threshold
-        assert torch.all(y_cuda > 1e4)
-    else:
-        # For all other normal cases, compare against the SciPy gold standard.
-        y_ref_numpy = dte(input_numpy)
-        y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
-        print("\nSciPy Reference Output:\n", y_ref.cpu().numpy())
-        torch.testing.assert_close(y_cuda, y_ref, atol=1e-3, rtol=1e-3)
-    print("--- Test End ---")
+
+    assert y_cuda.shape == y_ref.shape, f"Shape mismatch! CUDA output: {y_cuda.shape}, SciPy expects: {y_ref.shape}"
+    print("CUDA and SciPy output shapes match.")
+
+    torch.testing.assert_close(y_cuda, y_ref, atol=1e-3, rtol=1e-3)
+    print("--- Assertions passed (values close) ---")
\ No newline at end of file
diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 70527db..0bc5c26 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -1,98 +1,225 @@
-#include <torch/extension.h>
-#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+// --- Kernel 1: binarization kernel ---
+/**
+ * @brief Binarizes the input tensor element-wise.
+ * @details A simple parallel operation. It sets background pixels (value < 0.5) to 0
+ *          and foreground pixels (value >= 0.5) to a very large value (1e20f), which in the
+ *          context of a distance transform can be regarded as infinity.
+ * @param in Data pointer of the input tensor.
+ * @param out Data pointer of the output tensor.
+ * @param N Total number of elements in the tensor.
+ */
+__global__ void binarize_kernel(const float* in, float* out, int64_t N) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < N) {
+        out[idx] = (in[idx] < 0.5f) ? 0.0f : 1e20f;
+    }
+}
+
+// --- Kernel 2: 1D-pass squared-distance kernel ---
+/**
+ * @brief Runs the 1D parabolic lower-envelope algorithm along one spatial dimension of an N-D tensor.
+ * @details This is the core of the Felzenszwalb-Huttenlocher EDT algorithm: the N-D problem is
+ *          solved by decomposing it into N one-dimensional problems, and this kernel handles one
+ *          of those dimensions. It computes squared distances only, which avoids expensive square
+ *          roots and preserves numerical precision.
+ *          Each CUDA thread block processes one complete 1D scanline (slice).
+ * @param in_data Input tensor data pointer.
+ * @param out_data Output tensor data pointer.
+ * @param shape Pointer to the array describing the input tensor's shape (on the GPU).
+ * @param strides Pointer to the array describing the input tensor's strides (on the GPU).
+ * @param ndim Total number of tensor dimensions (including the batch dimension).
+ * @param process_dim_sample Index of the spatial dimension being processed (0 is the first spatial dimension, and so on).
+ * @param total_slices Total number of 1D scanlines to process (batch_size * num_slices_per_sample).
+ * @param num_slices_per_sample Number of scanlines per sample perpendicular to the processed dimension.
+ */
+__global__ void edt_1d_pass_sq_kernel(
+    const float* in_data, float* out_data,
+    const int64_t* shape, const int64_t* strides,
+    int32_t ndim, int32_t process_dim_sample,
+    int64_t total_slices, int64_t num_slices_per_sample
+) {
+    // Each thread block processes one 1D scanline
+    int64_t slice_idx = blockIdx.x;
+    if (slice_idx >= total_slices) return;
+
+
+    int64_t batch_idx = slice_idx / num_slices_per_sample;
+    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
+    int64_t batch_offset = batch_idx * strides[0];  // base address of this batch element
+    int64_t sample_base_offset = 0;
+    int64_t temp_idx = slice_idx_in_sample;
+    int sample_ndim = ndim - 1;
+
+    // Derive the in-sample base offset from the non-processed dimensions
+    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+        if (d == process_dim_sample) continue;  // skip the dimension being processed
+        int64_t size_of_dim = shape[d + 1];
+        if (size_of_dim == 0) continue;
+        int64_t coord_in_dim = temp_idx % size_of_dim;
+        temp_idx /= size_of_dim;
+        sample_base_offset += coord_in_dim * strides[d + 1];
+    }
+
+    const int64_t process_dim_actual = process_dim_sample + 1;  // actual index once the batch dimension is counted
+    const int64_t N = shape[process_dim_actual];                // length of the processed dimension
+    const int64_t stride = strides[process_dim_actual];         // stride for one step along this dimension
+    const int64_t base_offset = batch_offset + sample_base_offset;  // final start address
 
-__global__ void edt_pass1_rows(const float* input, float* temp, int H, int W) {
-    int y = blockIdx.x;
-    if (y >= H) return;
 
     extern __shared__ float sdata[];
-    float* f = sdata;
-    int* v = (int*)(sdata + W);
-    float* z = (float*)(v + W + 1);
-
-    for (int x = threadIdx.x; x < W; x += blockDim.x) {
-        float val = input[y * W + x];
-        // Key change for this task:
-        // a pixel equal to 0 (background) is a source point (distance 0);
-        // everything else starts out infinitely far away.
-        f[x] = (val < 0.5f) ? 0.0f : 1e10f;
+    float* f = sdata;                // sampled function values f(p)
+    int* v = (int*)(sdata + N);      // indices of the parabola vertices
+    float* z = (float*)(v + N + 1);  // intersection points of adjacent parabolas
+
+    // All threads of the block cooperatively load from global into shared memory
+    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
+        f[i] = in_data[base_offset + i * stride];
     }
-    __syncthreads();
+    __syncthreads();  // wait until every thread has finished loading
 
-    if (threadIdx.x == 0) {
-        int k = 0;
-        v[0] = 0;
-        z[0] = -1e10f;
-        z[1] = 1e10f;
+    // Compute the lower envelope of the parabolas
+    if (threadIdx.x == 0 && N > 0) {
+        int k = 0;                     // number of parabolas in the lower envelope
+        v[0] = 0;                      // the first parabola's vertex index is 0
+        z[0] = -1e20f; z[1] = 1e20f;   // initialize the boundaries to -inf and +inf
 
-        for (int q = 1; q < W; q++) {
+        // Sweep over all points and build the lower envelope
+        for (int q = 1; q < N; q++) {
             float s;
+            // Find where the new parabola q has to be inserted
             while (true) {
-                int p = v[k];
+                int p = v[k]; if (q == p) break;
+                // s is the x-coordinate of the intersection of parabolas p and q
                 s = ((f[q] + q * q) - (f[p] + p * p)) / (2.0f * (q - p));
+                // If the intersection lies right of the current interval, we found the spot
                 if (s > z[k]) { break; }
-                if (k == 0) { break; }
+                // Otherwise parabola p is completely dominated by q and must be removed
+                if (k == 0) { break; }
                 k--;
             }
-            k++;
-            v[k] = q;
-            z[k] = s;
-            z[k + 1] = 1e10f;
+            // Insert the new parabola q
+            k++;
+            v[k] = q;
+            z[k] = s;
+            z[k + 1] = 1e20f;
         }
-
+        // Compute the squared distances
         k = 0;
-        for (int q = 0; q < W; q++) {
-            while (z[k + 1] < q) k++;
-            int p = v[k];
-            temp[y * W + q] = (q - p) * (q - p) + f[p];
+        // For each point, find the envelope segment above it and evaluate the distance
+        for (int q = 0; q < N; q++) {
+            while (z[k + 1] < q) k++;  // find the interval that contains q
+            int p = v[k];              // vertex index of that interval's parabola
+            // Squared distance: D(q)^2 = (q - p)^2 + f(p)
+            out_data[base_offset + q * stride] = (q - p) * (q - p) + f[p];
         }
     }
 }
 
-// PASS 2: operate on each column
-__global__ void edt_pass2_cols(const float* temp, float* output, int H, int W) {
-    int x = blockIdx.x;
-    if (x >= W) return;
-
-    extern __shared__ float sdata[];
-    float* f = sdata;
-    int* v = (int*)(sdata + H);
-    float* z = (float*)(v + H + 1);
-
-    for (int y = threadIdx.x; y < H; y += blockDim.x) {
-        f[y] = temp[y * W + x];
-    }
-    __syncthreads();
-
-    if (threadIdx.x == 0) {
-        int k = 0;
-        v[0] = 0;
-        z[0] = -1e10f;
-        z[1] = 1e10f;
-
-        for (int q = 1; q < H; q++) {
-            float s;
-            while (true) {
-                int p = v[k];
-                s = ((f[q] + q * q) - (f[p] + p * p)) / (2.0f * (q - p));
-                if (s > z[k]) {
-                    break;
-                }
-                if (k == 0) {
-                    break;
-                }
-                k--;
-            }
-            k++;
-            v[k] = q;
-            z[k] = s;
-            z[k + 1] = 1e10f;
-        }
-
-        k = 0;
-        for (int q = 0; q < H; q++) {
-            while (z[k + 1] < q) k++;
-            int p = v[k];
-            output[q * W + x] = sqrtf((q - p) * (q - p) + f[p]);
-        }
-    }
-}
\ No newline at end of file
+
+// --- Kernel 3: square-root kernel ---
+/**
+ * @brief Takes the square root of every element of a tensor.
+ * @details A simple element-wise operation. Because the preceding 1D passes compute squared
+ *          distances, this kernel runs once all dimensions have been processed, yielding the
+ *          final Euclidean distances.
+ * @param data Data pointer of the tensor whose elements are square-rooted.
+ * @param N Total number of elements in the tensor.
+ */
+__global__ void sqrt_kernel(float* data, int64_t N) {
+    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < N) {
+        data[idx] = sqrtf(data[idx]);
+    }
+}
+
+// --- Host-side driver ---
+/**
+ * @brief Performs an N-dimensional Euclidean distance transform.
+ * @param input An N-D PyTorch tensor whose first dimension is treated as the batch dimension.
+ * @return A tensor of the same shape as the input, holding each point's Euclidean distance
+ *         to the nearest foreground point (value >= 0.5).
+ */
+torch::Tensor distance_transform_cuda(torch::Tensor input) {
+    auto original_input = input;
+
+    // --- Preprocessing: normalize the input layout ---
+    // Make every input at least 3D (B, ...) so the rest of the code is uniform.
+    // An (H, W) or (L) input becomes (1, H, W) or (1, L).
+    bool had_no_batch_dim = (input.dim() <= 2);
+    if (had_no_batch_dim) { input = input.unsqueeze(0); }
 
+    // Check that the input tensor is on CUDA and contiguous in memory
+    TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device.");
+    TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous.");
 
+    if (input.numel() == 0) { return torch::empty_like(original_input); }
+
+    // --- Gather tensor metadata ---
+    const auto ndim = input.dim();
+    const auto sample_ndim = ndim - 1;  // spatial dims = total dims - 1 (batch)
+    const auto batch_size = input.size(0);
+    const int64_t N_total = input.numel();
+
+    auto shape_vec = input.sizes().vec();
+    auto strides_vec = input.strides().vec();
+
+    // --- Allocation: ping-pong buffering ---
+    // Two buffers alternate as input and output for each dimension pass,
+    // avoiding in-place read/write conflicts.
+    auto output = torch::empty_like(input);
+    auto buffer = (sample_ndim > 0) ? torch::empty_like(input) : output;
 
+    // Binarize
+    int threads = 256;                               // threads per block
+    int blocks = (N_total + threads - 1) / threads;  // number of blocks to launch
+    binarize_kernel<<<blocks, threads>>>(input.data_ptr<float>(), buffer.data_ptr<float>(), N_total);
+
+    // Loop over the dimensions with edt_1d_pass_sq_kernel.
+    // Copy the shape and stride info from CPU to GPU memory so the kernel can read it.
+    int64_t *shape_gpu, *strides_gpu;
+    cudaMalloc(&shape_gpu, ndim * sizeof(int64_t));
+    cudaMalloc(&strides_gpu, ndim * sizeof(int64_t));
+    cudaMemcpy(shape_gpu, shape_vec.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice);
+    cudaMemcpy(strides_gpu, strides_vec.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice);
+
+    torch::Tensor current_input = buffer;
+    torch::Tensor current_output = output;
+
+    // Iterate over all spatial dimensions
+    for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) {
+        // Compute the launch parameters for the dimension being processed
+        int64_t num_slices_per_sample = 1;
+        for(int i = 0; i < sample_ndim; ++i) {
+            if (i != d_sample) num_slices_per_sample *= shape_vec[i + 1];
+        }
+        int64_t total_slices = batch_size * num_slices_per_sample;
+        int64_t slice_len = shape_vec[d_sample + 1];
+
+        // Pick the thread count and shared-memory size dynamically
+        int threads_pass = (slice_len > 0 && slice_len < 256) ? slice_len : 256;
+        if (threads_pass == 0) threads_pass = 1;
+        size_t shared_mem_size = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) + (slice_len + 2) * sizeof(float);
+
+        edt_1d_pass_sq_kernel<<<total_slices, threads_pass, shared_mem_size>>>(
+            current_input.data_ptr<float>(), current_output.data_ptr<float>(),
+            shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+        );
+        // Swap the input and output buffers for the next dimension
+        std::swap(current_input, current_output);
+    }
+
+    cudaFree(shape_gpu);
+    cudaFree(strides_gpu);
+
+    // Compute the final distances.
+    // After the loop, current_input holds the final squared distances.
+    sqrt_kernel<<<blocks, threads>>>(current_input.data_ptr<float>(), N_total);
+
+    // If the last pass did not end up in the expected output tensor, copy once
+    if (current_input.data_ptr() != output.data_ptr()){
+        output.copy_(current_input);
+    }
+
+    // If there was originally no batch dimension, remove the one we added
+    if (had_no_batch_dim) { output = output.squeeze(0); }
+
+    return output;
+}
\ No newline at end of file
diff --git a/torchmorph/csrc/torchmorph.cpp b/torchmorph/csrc/torchmorph.cpp
new file mode 100644
index 0000000..b7f466a
--- /dev/null
+++ b/torchmorph/csrc/torchmorph.cpp
@@ -0,0 +1,10 @@
+#include <torch/extension.h>
+
+// Declare CUDA implementations
+torch::Tensor add_cuda(torch::Tensor input, float scalar);
+torch::Tensor distance_transform_cuda(torch::Tensor input);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("add_cuda", &add_cuda, "Add tensor with scalar");
+    m.def("distance_transform_cuda", &distance_transform_cuda, "Distance transform");
+}
\ No newline at end of file
diff --git a/torchmorph/csrc/torchmorph.cu b/torchmorph/csrc/torchmorph.cu
deleted file mode 100644
index dc2e4ce..0000000
--- a/torchmorph/csrc/torchmorph.cu
+++ /dev/null
@@ -1,51 +0,0 @@
-// =========================================================================
-// File: torchmorph/csrc/torchmorph.cu
-// =========================================================================
-
-#include <torch/extension.h>
-
-// Forward declarations: tell the compiler that these two CUDA kernels are
-// defined in another file, so the host code here can launch them.
-__global__ void edt_pass1_rows(const float* input, float* temp, int H, int W);
-__global__ void edt_pass2_cols(const float* temp, float* output, int H, int W);
-
-
-
-// Host-side driver function (runs on the CPU)
-torch::Tensor distance_transform_cuda(torch::Tensor input) {
-    // Check that the input tensor is on CUDA and two-dimensional
-    TORCH_CHECK(input.is_cuda(), "Input must be on CUDA");
-    TORCH_CHECK(input.dim() == 2, "Only 2D tensors supported");
-
-    int H = input.size(0);
-    int W = input.size(1);
-
-    // Allocate the temporary and final output tensors
-    auto temp = torch::empty_like(input);
-    auto output = torch::empty_like(input);
-
-    // Compute the dynamic shared-memory sizes
-    size_t shared_mem_pass1 = W * sizeof(float) + (W + 1) * sizeof(int) + (W + 2) * sizeof(float);
-    size_t shared_mem_pass2 = H * sizeof(float) + (H + 1) * sizeof(int) + (H + 2) * sizeof(float);
-
-    // Threads per block
-    int threads_per_block = 32;
-
-    // The <<<...>>> syntax launches a CUDA kernel.
-    // Arguments: grid size, block size, shared-memory bytes, (optional stream).
-
-    // Pass 1: one block per row
-    edt_pass1_rows<<<H, threads_per_block, shared_mem_pass1>>>(
-        input.data_ptr<float>(), temp.data_ptr<float>(), H, W);
-
-    // Pass 2: one block per column
-    edt_pass2_cols<<<W, threads_per_block, shared_mem_pass2>>>(
-        temp.data_ptr<float>(), output.data_ptr<float>(), H, W);
-
-    return output;
-}
-
-// Bind the C++ function into the Python module via pybind11
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("distance_transform", &distance_transform_cuda, "CUDA-accelerated Exact Euclidean Distance Transform");
-}
\ No newline at end of file
diff --git a/torchmorph/distance_transform.py b/torchmorph/distance_transform.py
index b4cd458..7840158 100644
--- a/torchmorph/distance_transform.py
+++ b/torchmorph/distance_transform.py
@@ -6,4 +6,4 @@ def distance_transform(input: torch.Tensor) -> torch.Tensor:
     """Distance Transform in CUDA."""
     if not input.is_cuda:
         raise ValueError("Input tensor must be on CUDA device.")
-    return _C.distance_transform(input)
+    return _C.distance_transform_cuda(input)
\ No newline at end of file
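The host loop above is the dimension-by-dimension decomposition of the N-D
EDT: each pass treats the previous pass's squared distances as the input costs
of a fresh 1D problem, and one square root at the end turns the accumulated
squares into Euclidean distances. The same structure can be validated against
NumPy on small inputs; a sketch reusing the edt_1d_squared helper sketched
after PATCH 03 (both function names are illustrative, not part of the repo):

    import numpy as np

    def edt_nd_reference(mask):
        """N-D EDT: run the squared 1D pass along every axis, then sqrt once."""
        g = np.where(mask == 0, 0.0, 1e10)  # sources (zeros) start at distance 0
        for axis in range(g.ndim):
            g = np.apply_along_axis(edt_1d_squared, axis, g)
        return np.sqrt(g)

np.apply_along_axis is slow, but it mirrors the ping-pong buffer loop closely
enough to diff intermediate passes against the CUDA buffers.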
From f9420b25ffcca6115a49a7ec7c329c2db46583e7 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Sun, 2 Nov 2025 15:07:10 +0800
Subject: [PATCH 05/39] Fix bug: the previous version mistakenly treated 0 as
 background and 1 as foreground

---
 torchmorph/csrc/distance_transform_kernel.cu | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 0bc5c26..d8cc8a7 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -1,11 +1,11 @@
 #include <torch/extension.h>
 #include <cuda_runtime.h>
 
-// --- Kernel 1: binarization kernel ---
-/**
- * @brief Binarizes the input tensor element-wise.
- * @details A simple parallel operation. It sets background pixels (value < 0.5) to 0
- *          and foreground pixels (value >= 0.5) to a very large value (1e20f), which in the
+// --- Kernel 1: binarization kernel ---
+/*
+ * @brief Binarizes the input tensor element-wise to set up the distance transform.
+ * @details Sets the initial distance of foreground (source) points (in[idx] == 0) to 0
+ *          and of background points to a very large value (1e20f), which in the
  *          context of a distance transform can be regarded as infinity.
 * @param in Data pointer of the input tensor.
 * @param out Data pointer of the output tensor.
@@ -14,7 +14,9 @@
 __global__ void binarize_kernel(const float* in, float* out, int64_t N) {
     int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < N) {
-        out[idx] = (in[idx] < 0.5f) ? 0.0f : 1e20f;
+        // If the input pixel is 0, it is a foreground (source) point: distance 0.
+        // If the input pixel is nonzero, it is a background point: initial distance infinity.
+        out[idx] = (in[idx] == 0.0f) ? 0.0f : 1e20f;
     }
 }
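The corrected convention now matches SciPy's: distance_transform_edt treats
zero elements as the source set and returns, for each nonzero element, the
distance to its nearest zero. A one-line check of that convention
(illustrative input):

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    print(distance_transform_edt(np.array([1, 1, 0, 1])))  # [2. 1. 0. 1.]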
From 85bd1e546b2c53552868ff79b165e78c6a2176e5 Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Mon, 3 Nov 2025 16:53:10 +0800
Subject: [PATCH 06/39] returns both distance and index

---
 test/test_distance_transform.py              |  9 ++---
 torchmorph/csrc/distance_transform_kernel.cu | 37 +++++++++++---------
 torchmorph/csrc/torchmorph.cpp               |  4 +--
 torchmorph/distance_transform.py             |  8 ++++-
 4 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 3feffb4..63b8568 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -73,13 +73,14 @@ def test_batch_processing(input_numpy, request):
 
     print(f"\n\n--- Running test: {request.node.callspec.id} ---")
     print(f"Input tensor shape: {x.shape}")
-    y_cuda = tm.distance_transform(x.clone())
+    dist_cuda, idx_cuda = tm.distance_transform(x.clone())
+    print(f"Output index shape: {idx_cuda.shape}.")
 
     y_ref_numpy = batch_distance_transform_edt(x_numpy_contiguous)
     y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
 
-    assert y_cuda.shape == y_ref.shape, f"Shape mismatch! CUDA output: {y_cuda.shape}, SciPy expects: {y_ref.shape}"
+    assert dist_cuda.shape == y_ref.shape, f"Shape mismatch! CUDA output: {dist_cuda.shape}, SciPy expects: {y_ref.shape}"
     print("CUDA and SciPy output shapes match.")
 
-    torch.testing.assert_close(y_cuda, y_ref, atol=1e-3, rtol=1e-3)
-    print("--- Assertions passed (values close) ---")
\ No newline at end of file
+    torch.testing.assert_close(dist_cuda, y_ref, atol=1e-3, rtol=1e-3)
+    print("--- Assertions passed (values close) ---")
diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index d8cc8a7..534618f 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -11,7 +11,7 @@
  * @param out Data pointer of the output tensor.
  * @param N Total number of elements in the tensor.
  */
-__global__ void binarize_kernel(const float* in, float* out, int64_t N) {
+__global__ void initialize_distance_kernel(const float* in, float* out, int64_t N) {
     int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < N) {
         // If the input pixel is 0, it is a foreground (source) point: distance 0.
@@ -140,7 +140,7 @@
 * @param input An N-D PyTorch tensor whose first dimension is treated as the batch dimension.
 * @return A tensor of the same shape as the input, holding each point's Euclidean distance
 *         to the nearest foreground point (value >= 0.5).
 */
-torch::Tensor distance_transform_cuda(torch::Tensor input) {
+std::tuple<torch::Tensor, torch::Tensor> distance_transform_cuda(torch::Tensor input) {
     auto original_input = input;
 
     // --- Preprocessing: normalize the input layout ---
@@ -153,7 +153,6 @@
     TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device.");
     TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous.");
 
-    if (input.numel() == 0) { return torch::empty_like(original_input); }
 
     // --- Gather tensor metadata ---
     const auto ndim = input.dim();
@@ -161,39 +160,45 @@
     const auto batch_size = input.size(0);
     const int64_t N_total = input.numel();
 
-    auto shape_vec = input.sizes().vec();
+    auto shape = input.sizes().vec();
+    auto index_shape = shape;
+    index_shape.push_back(sample_ndim);
+
     auto strides_vec = input.strides().vec();
 
     // --- Allocation: ping-pong buffering ---
     // Two buffers alternate as input and output for each dimension pass,
     // avoiding in-place read/write conflicts.
-    auto output = torch::empty_like(input);
-    auto buffer = (sample_ndim > 0) ? torch::empty_like(input) : output;
+    auto distance = torch::zeros_like(input);
+    auto index = torch::zeros(index_shape);
+    auto buffer = (sample_ndim > 0) ? torch::empty_like(input) : distance;
+
+    if (input.numel() == 0) { return std::make_tuple(distance, index); }
 
     // Binarize
     int threads = 256;                               // threads per block
     int blocks = (N_total + threads - 1) / threads;  // number of blocks to launch
-    binarize_kernel<<<blocks, threads>>>(input.data_ptr<float>(), buffer.data_ptr<float>(), N_total);
+    initialize_distance_kernel<<<blocks, threads>>>(input.data_ptr<float>(), buffer.data_ptr<float>(), N_total);
 
     // Loop over the dimensions with edt_1d_pass_sq_kernel.
     // Copy the shape and stride info from CPU to GPU memory so the kernel can read it.
     int64_t *shape_gpu, *strides_gpu;
     cudaMalloc(&shape_gpu, ndim * sizeof(int64_t));
     cudaMalloc(&strides_gpu, ndim * sizeof(int64_t));
-    cudaMemcpy(shape_gpu, shape_vec.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice);
+    cudaMemcpy(shape_gpu, shape.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice);
     cudaMemcpy(strides_gpu, strides_vec.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice);
 
     torch::Tensor current_input = buffer;
-    torch::Tensor current_output = output;
+    torch::Tensor current_output = distance;
 
     // Iterate over all spatial dimensions
     for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) {
         // Compute the launch parameters for the dimension being processed
         int64_t num_slices_per_sample = 1;
         for(int i = 0; i < sample_ndim; ++i) {
-            if (i != d_sample) num_slices_per_sample *= shape_vec[i + 1];
+            if (i != d_sample) num_slices_per_sample *= shape[i + 1];
         }
         int64_t total_slices = batch_size * num_slices_per_sample;
-        int64_t slice_len = shape_vec[d_sample + 1];
+        int64_t slice_len = shape[d_sample + 1];
 
         // Pick the thread count and shared-memory size dynamically
         int threads_pass = (slice_len > 0 && slice_len < 256) ? slice_len : 256;
@@ -216,12 +221,12 @@
     sqrt_kernel<<<blocks, threads>>>(current_input.data_ptr<float>(), N_total);
 
     // If the last pass did not end up in the expected output tensor, copy once
-    if (current_input.data_ptr() != output.data_ptr()){
-        output.copy_(current_input);
+    if (current_input.data_ptr() != distance.data_ptr()){
+        distance.copy_(current_input);
     }
 
     // If there was originally no batch dimension, remove the one we added
-    if (had_no_batch_dim) { output = output.squeeze(0); }
+    if (had_no_batch_dim) { distance = distance.squeeze(0); }
 
-    return output;
-}
\ No newline at end of file
+    return std::make_tuple(distance, index);
+}
diff --git a/torchmorph/csrc/torchmorph.cpp b/torchmorph/csrc/torchmorph.cpp
index b7f466a..c79970c 100644
--- a/torchmorph/csrc/torchmorph.cpp
+++ b/torchmorph/csrc/torchmorph.cpp
@@ -2,9 +2,9 @@
 
 // Declare CUDA implementations
 torch::Tensor add_cuda(torch::Tensor input, float scalar);
-torch::Tensor distance_transform_cuda(torch::Tensor input);
+std::tuple<torch::Tensor, torch::Tensor> distance_transform_cuda(torch::Tensor input);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("add_cuda", &add_cuda, "Add tensor with scalar");
     m.def("distance_transform_cuda", &distance_transform_cuda, "Distance transform");
-}
\ No newline at end of file
+}
diff --git a/torchmorph/distance_transform.py b/torchmorph/distance_transform.py
index 7840158..0184be5 100644
--- a/torchmorph/distance_transform.py
+++ b/torchmorph/distance_transform.py
@@ -6,4 +6,10 @@ def distance_transform(input: torch.Tensor) -> torch.Tensor:
     """Distance Transform in CUDA."""
     if not input.is_cuda:
         raise ValueError("Input tensor must be on CUDA device.")
-    return _C.distance_transform_cuda(input)
\ No newline at end of file
+    if input.ndim < 2 or input.numel() == 0:
+        raise ValueError(f"Invalid input dimension: {input.shape}.")
+
+    # binarize input
+    input[input != 0] = 1
+
+    return _C.distance_transform_cuda(input)
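Two behaviors of the new entry point are worth noting. First, it now returns a
(distance, index) pair, though at this stage the index tensor is only
allocated with torch::zeros and is not yet filled by any kernel (that lands in
PATCH 11). Second, `input[input != 0] = 1` binarizes the caller's tensor in
place, which is why the tests pass `x.clone()`. A usage sketch (shapes are
illustrative):

    import torch
    import torchmorph as tm

    x = (torch.rand(64, 64, device="cuda") > 0.5).float()
    # clone() protects x: distance_transform binarizes its argument in place
    dist, idx = tm.distance_transform(x.clone())
    print(dist.shape)  # torch.Size([64, 64])
    print(idx.shape)   # one trailing axis holding a coordinate per spatial dim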
From be4eb3dba55a34b80dc87089cd110d31715d7741 Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Mon, 3 Nov 2025 16:53:30 +0800
Subject: [PATCH 07/39] format

---
 test/test_distance_transform.py | 61 ++++++++++++++++++++-------------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 63b8568..f65a0b4 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -11,48 +11,59 @@ def batch_distance_transform_edt(batch_numpy):
     # (H, W) -> (1, H, W)
     if is_single_sample:
         batch_numpy = batch_numpy[np.newaxis, ...]
-    
+
     results = [dte(sample) for sample in batch_numpy]
-    output = np.stack(results, axis=0) 
+    output = np.stack(results, axis=0)
     # (1, H, W) -> (H, W)
     if is_single_sample:
         output = output.squeeze(0)
-    
+
     return output
 
+
 # Case 1: a batch of 2D images
-case_batch_2d = np.array([
-    # image 1
-    [[0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],
-    # image 2
-    [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]
-], dtype=np.float32)
+case_batch_2d = np.array(
+    [
+        # image 1
+        [[0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],
+        # image 2
+        [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
+    ],
+    dtype=np.float32,
+)
 
+
 # Case 2: a batch of 3D volumes
-case_3d_sample1 = np.ones((4, 5, 6), dtype=np.float32); case_3d_sample1[1, 1, 1] = 0.0; case_3d_sample1[2, 3, 4] = 0.0
-case_3d_sample2 = np.ones((4, 5, 6), dtype=np.float32); case_3d_sample2[0, 0, 0] = 0.0
+case_3d_sample1 = np.ones((4, 5, 6), dtype=np.float32)
+case_3d_sample1[1, 1, 1] = 0.0
+case_3d_sample1[2, 3, 4] = 0.0
+case_3d_sample2 = np.ones((4, 5, 6), dtype=np.float32)
+case_3d_sample2[0, 0, 0] = 0.0
 case_batch_3d = np.stack([case_3d_sample1, case_3d_sample2], axis=0)
 
 # Case 3: a single 2D image (implicit batch)
-case_single_2d = np.array([
-    [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0],
-], dtype=np.float32)
+case_single_2d = np.array(
+    [
+        [0, 1, 0, 1],
+        [1, 0, 1, 0],
+        [0, 1, 0, 1],
+        [1, 0, 1, 0],
+    ],
+    dtype=np.float32,
+)
 
+
 # Case 4: a single 2D image (explicit batch of one)
 case_explicit_batch_one = case_single_2d[np.newaxis, ...]
 
 # Case 5: a batch with a singleton dimension
-case_dim_one = np.ones((2, 5, 1), dtype=np.float32) # two 5x1 images
+case_dim_one = np.ones((2, 5, 1), dtype=np.float32)  # two 5x1 images
 case_dim_one[0, 2, 0] = 0.0
 case_dim_one[1, 4, 0] = 0.0
 
 # Case 6: a batch of 1D tensors
 case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32)
 
+
 @pytest.mark.parametrize(
     "input_numpy",
     [
         pytest.param(case_batch_2d, id="Batched 2D images"),
         pytest.param(case_batch_3d, id="Batched 3D volumes"),
         pytest.param(case_single_2d, id="Single 2D image (implicit batch)"),
         pytest.param(case_explicit_batch_one, id="Single 2D image (explicit batch)"),
         pytest.param(case_dim_one, id="Batch with singleton dimension"),
         pytest.param(case_batch_1d, id="Batched 1D data"),
-    ]
+    ],
 )
 def test_batch_processing(input_numpy, request):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
     x_numpy_contiguous = np.ascontiguousarray(input_numpy)
     x = torch.from_numpy(x_numpy_contiguous).cuda()
 
     print(f"\n\n--- Running test: {request.node.callspec.id} ---")
     print(f"Input tensor shape: {x.shape}")
     dist_cuda, idx_cuda = tm.distance_transform(x.clone())
     print(f"Output index shape: {idx_cuda.shape}.")
-    
+
     y_ref_numpy = batch_distance_transform_edt(x_numpy_contiguous)
     y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
-    
-    assert dist_cuda.shape == y_ref.shape, f"Shape mismatch! CUDA output: {dist_cuda.shape}, SciPy expects: {y_ref.shape}"
+
+    assert (
+        dist_cuda.shape == y_ref.shape
+    ), f"Shape mismatch! CUDA output: {dist_cuda.shape}, SciPy expects: {y_ref.shape}"
     print("CUDA and SciPy output shapes match.")
-    
+
     torch.testing.assert_close(dist_cuda, y_ref, atol=1e-3, rtol=1e-3)
     print("--- Assertions passed (values close) ---")
From 8f835e37e59e2749eb93ef9efb4bd88836a44453 Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Mon, 3 Nov 2025 18:00:28 +0800
Subject: [PATCH 08/39] benchmark

---
 benchmark/distance_transform.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)
 create mode 100644 benchmark/distance_transform.py

diff --git a/benchmark/distance_transform.py b/benchmark/distance_transform.py
new file mode 100644
index 0000000..91219bb
--- /dev/null
+++ b/benchmark/distance_transform.py
@@ -0,0 +1,24 @@
+import torch
+import torch.utils.benchmark as benchmark
+import scipy.ndimage as ndi
+import torchmorph as tm
+
+for size in [64, 128, 256, 512, 1024, 2048]:
+    x = (torch.randn(1, 1, size, size, device="cuda") > 0).to(torch.float32)
+
+    # TorchMorph CUDA
+    t1 = benchmark.Timer(
+        stmt="tm.distance_transform(x)",
+        setup="from __main__ import x, tm",
+        num_threads=torch.get_num_threads()
+    )
+    # SciPy (CPU)
+    import numpy as np
+    x_np = x.cpu().squeeze().numpy()
+    t2 = benchmark.Timer(
+        stmt="ndi.distance_transform_edt(x_np)",
+        setup="from __main__ import x_np, ndi"
+    )
+
+    print(f"Size {size}:\n", t1.blocked_autorange())
+    print(f"Size {size}:\n", t2.blocked_autorange())
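One pitfall this benchmark avoids by construction: torch.utils.benchmark.Timer
synchronizes the GPU around each measurement, so the CUDA timings are directly
comparable to SciPy's CPU timings. A manual measurement, sketched below with
illustrative sizes, needs the synchronization spelled out:

    import time
    import torch
    import torchmorph as tm

    x = (torch.randn(1, 1, 512, 512, device="cuda") > 0).float()
    torch.cuda.synchronize()               # drain pending GPU work before timing
    t0 = time.perf_counter()
    tm.distance_transform(x)
    torch.cuda.synchronize()               # wait for the kernels to finish
    print(f"{(time.perf_counter() - t0) * 1e3:.3f} ms")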
From 7e108791dee292f7fcaa1e035e58384ce752adc8 Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Mon, 3 Nov 2025 18:27:09 +0800
Subject: [PATCH 09/39] benchmark outputs tables

---
 benchmark/distance_transform.py | 93 ++++++++++++++++++++++++++-------
 1 file changed, 74 insertions(+), 19 deletions(-)

diff --git a/benchmark/distance_transform.py b/benchmark/distance_transform.py
index 91219bb..c659c82 100644
--- a/benchmark/distance_transform.py
+++ b/benchmark/distance_transform.py
@@ -1,24 +1,79 @@
 import torch
 import torch.utils.benchmark as benchmark
 import scipy.ndimage as ndi
+import numpy as np
+from prettytable import PrettyTable
 import torchmorph as tm
 
-for size in [64, 128, 256, 512, 1024, 2048]:
-    x = (torch.randn(1, 1, size, size, device="cuda") > 0).to(torch.float32)
-
-    # TorchMorph CUDA
-    t1 = benchmark.Timer(
-        stmt="tm.distance_transform(x)",
-        setup="from __main__ import x, tm",
-        num_threads=torch.get_num_threads()
-    )
-    # SciPy (CPU)
-    import numpy as np
-    x_np = x.cpu().squeeze().numpy()
-    t2 = benchmark.Timer(
-        stmt="ndi.distance_transform_edt(x_np)",
-        setup="from __main__ import x_np, ndi"
-    )
-
-    print(f"Size {size}:\n", t1.blocked_autorange())
-    print(f"Size {size}:\n", t2.blocked_autorange())
+sizes = [64, 128, 256, 512, 1024]
+batches = [1, 4, 8, 16]
+dtype = torch.float32
+device = "cuda"
+MIN_RUN = 1.0  # seconds per measurement
+
+torch.set_num_threads(torch.get_num_threads())
+
+for B in batches:
+    table = PrettyTable()
+    table.field_names = [
+        "Size",
+        "SciPy (ms/img)",
+        "Torch 1× (ms/img)",
+        "Torch batch (ms/img)",
+        "Speedup 1×",
+        "Speedup batch",
+    ]
+    for c in table.field_names:
+        table.align[c] = "r"
+
+    for s in sizes:
+        # Inputs
+        x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype)
+        x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)]
+        x_imgs = [x[i:i+1] for i in range(B)]
+
+        # SciPy (CPU, one-by-one)
+        stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]"
+        t_scipy = benchmark.Timer(
+            stmt=stmt_scipy,
+            setup="from __main__ import x_np_list, ndi",
+            num_threads=torch.get_num_threads(),
+        ).blocked_autorange(min_run_time=MIN_RUN)
+        scipy_per_img_ms = (t_scipy.median * 1e3) / B
+
+        # Torch (CUDA, one-by-one)
+        stmt_torch1 = """
+for xi in x_imgs:
+    tm.distance_transform(xi)
+"""
+        t_torch1 = benchmark.Timer(
+            stmt=stmt_torch1,
+            setup="from __main__ import x_imgs, tm",
+            num_threads=torch.get_num_threads(),
+        ).blocked_autorange(min_run_time=MIN_RUN)
+        torch1_per_img_ms = (t_torch1.median * 1e3) / B
+
+        # Torch (CUDA, batched)
+        t_batch = benchmark.Timer(
+            stmt="tm.distance_transform(x)",
+            setup="from __main__ import x, tm",
+            num_threads=torch.get_num_threads(),
+        ).blocked_autorange(min_run_time=MIN_RUN)
+        torchB_per_img_ms = (t_batch.median * 1e3) / B
+
+        # Speedups
+        speed1 = scipy_per_img_ms / torch1_per_img_ms
+        speedB = scipy_per_img_ms / torchB_per_img_ms
+
+        table.add_row([
+            s,
+            f"{scipy_per_img_ms:.3f}",
+            f"{torch1_per_img_ms:.3f}",
+            f"{torchB_per_img_ms:.3f}",
+            f"{speed1:.1f}×",
+            f"{speedB:.1f}×",
+        ])
+
+    print(f"\n=== Batch Size: {B} ===")
+    print(table)
+

From 7b6b8aae176f5dc8a7ac03898d84fb287fcea674 Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Mon, 3 Nov 2025 21:31:14 +0800
Subject: [PATCH 10/39] prettytable

---
 requirements-dev.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index fc961bb..0df6115 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -12,4 +12,4 @@ flake8>=6.0
 setuptools>=65.0
 wheel>=0.40
 ninja>=1.11  # optional, speeds up torch extension builds
-
+prettytable>=3.16.0

From 15933d9df18aefd1b10c38c5ef8e32e4ad967a76 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Thu, 20 Nov 2025 18:20:03 +0800
Subject: [PATCH 11/39] Exact N-D batched Euclidean distance transform that
 returns both coordinates and distances

---
 pyproject.toml                               |   2 +-
 test/test_distance_transform.py              | 136 ++--
 torchmorph/csrc/distance_transform_kernel.cu | 731 ++++++++++++++-----
 3 files changed, 619 insertions(+), 250 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 10af0cd..a4d6163 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,5 +18,5 @@ addopts = "-v --import-mode=importlib"
 testpaths = ["test"]
 
 [build-system]
-requires = ["setuptools>=61.0", "wheel", "torch"]
+requires = ["setuptools>=61.0", "wheel", "torch", "numpy"]
 build-backend = "setuptools.build_meta"
\ No newline at end of file
diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index f65a0b4..3a11166 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -1,99 +1,91 @@
 import torch
 import pytest
-from scipy.ndimage import distance_transform_edt as dte
+from scipy.ndimage import distance_transform_edt as scipy_edt
 import numpy as np
 import torchmorph as tm
 
-
-def batch_distance_transform_edt(batch_numpy):
-
+# Helper: call SciPy and massage the in/out formats
+def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     is_single_sample = batch_numpy.ndim <= 2
-    # (H, W) -> (1, H, W)
     if is_single_sample:
         batch_numpy = batch_numpy[np.newaxis, ...]
-
-    results = [dte(sample) for sample in batch_numpy]
-    output = np.stack(results, axis=0)
-    # (1, H, W) -> (H, W)
+    dist_results, indices_results = [], []
+    for sample in batch_numpy:
+        dist, indices = scipy_edt(sample, return_indices=True, return_distances=True)
+        dist_results.append(dist)
+        indices_results.append(indices)
+    output_dist = np.stack(dist_results, axis=0)
+    output_indices = np.stack(indices_results, axis=0)
+    output_indices = np.moveaxis(output_indices, 1, -1)
     if is_single_sample:
-        output = output.squeeze(0)
-
-    return output
+        output_dist = output_dist.squeeze(0)
+        output_indices = output_indices.squeeze(0)
+    return output_dist, output_indices
 
+# Case definitions
+case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],[[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32)
+_case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s1[1, 1, 1] = 0.0; _case_3d_s1[2, 3, 4] = 0.0
+_case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s2[0, 0, 0] = 0.0
+case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0)
+case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32)
+case_explicit_batch_one = case_single_2d[np.newaxis, ...]
+case_dim_one = np.ones((2, 5, 1), dtype=np.float32); case_dim_one[0, 2, 0] = 0.0; case_dim_one[1, 4, 0] = 0.0
+case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32)
 
 @pytest.mark.parametrize(
     "input_numpy",
     [
         pytest.param(case_batch_2d, id="Batched 2D images"),
         pytest.param(case_batch_3d, id="Batched 3D volumes"),
         pytest.param(case_single_2d, id="Single 2D image (implicit batch)"),
         pytest.param(case_explicit_batch_one, id="Single 2D image (explicit batch)"),
         pytest.param(case_dim_one, id="Batch with singleton dimension"),
         pytest.param(case_batch_1d, id="Batched 1D data"),
     ],
 )
-def test_batch_processing(input_numpy, request):
+def test_distance_transform_and_indices(input_numpy: np.ndarray, request: pytest.FixtureRequest):
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
+
     x_numpy_contiguous = np.ascontiguousarray(input_numpy)
-    x = torch.from_numpy(x_numpy_contiguous).cuda()
+    x_cuda = torch.from_numpy(x_numpy_contiguous).cuda()
 
     print(f"\n\n--- Running test: {request.node.callspec.id} ---")
-    print(f"Input tensor shape: {x.shape}")
-    dist_cuda, idx_cuda = tm.distance_transform(x.clone())
-    print(f"Output index shape: {idx_cuda.shape}.")
+    print(f"Input tensor shape: {x_cuda.shape}")
+
+    # Call our Python wrapper
+    dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone())
+
+    print(f"CUDA distance output shape: {dist_cuda.shape}")
+    print(f"CUDA index output shape: {idx_cuda.shape}")
 
-    y_ref_numpy = batch_distance_transform_edt(x_numpy_contiguous)
-    y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()
+    # Call SciPy as the reference baseline
+    dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(x_numpy_contiguous)
+    dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda()
 
-    assert (
-        dist_cuda.shape == y_ref.shape
-    ), f"Shape mismatch! CUDA output: {dist_cuda.shape}, SciPy expects: {y_ref.shape}"
-    print("CUDA and SciPy output shapes match.")
-
-    torch.testing.assert_close(dist_cuda, y_ref, atol=1e-3, rtol=1e-3)
-    print("--- Assertions passed (values close) ---")
+    print(f"SciPy distance output shape: {dist_ref.shape}")
+
+    # Assertions
+    print("\n--- Verifying distances... ---")
+    assert dist_cuda.shape == dist_ref.shape
+    torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3)
+    print("Distance assertions passed (shape and values match).")
+
+    print("\n--- Verifying indices... ---")
+
+    # Robust index verification: recompute distances from the returned coordinates
+    had_no_batch_dim = (x_numpy_contiguous.ndim <= idx_cuda.shape[-1])
+    spatial_shape = x_cuda.shape if had_no_batch_dim else x_cuda.shape[1:]
+    coords = [torch.arange(s, device='cuda') for s in spatial_shape]
+    grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1)
+
+    if not had_no_batch_dim:
+        grid = grid.unsqueeze(0)
+
+    diff = grid.float() - idx_cuda.float()
+    dist_sq_from_indices = torch.sum(diff * diff, dim=-1)
+
+    torch.testing.assert_close(dist_sq_from_indices, dist_cuda * dist_cuda, atol=1e-3, rtol=1e-3)
+    print("Index correctness assertion passed (distances recomputed from the indices match the returned distances).")
+
+    print("--- Test passed ---")
\ No newline at end of file
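The rewritten kernels below keep the scanline costs, the envelope bookkeeping,
and now also the per-point source coordinates in dynamic shared memory, so the
budget grows with both the scanline length N and the number of spatial
dimensions. A worked check of that layout, mirroring the f/v/z/s_idx pointer
arithmetic in the kernels (the numbers are an illustrative example):

    def shared_bytes(N, sample_ndim):
        f = N * 4                    # float f[N]
        v = (N + 1) * 4              # int v[N + 1]
        z = (N + 2) * 4              # float z[N + 2]
        s_idx = N * sample_ndim * 4  # int32 s_idx[N * sample_ndim]
        return f + v + z + s_idx

    # A 1024-long scanline of a 3-D volume:
    print(shared_bytes(1024, 3))  # 24588 bytes, within the default 48 KiB limit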
diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 534618f..90785bd 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -1,232 +1,609 @@
 #include <torch/extension.h>
 #include <cuda_runtime.h>
+#include <vector>
+#include <tuple>
+#include <cstdint>
+#include <cmath>
 
-// --- Kernel 1: binarization kernel ---
-/*
- * @brief Binarizes the input tensor element-wise to set up the distance transform.
- * @details Sets the initial distance of foreground (source) points (in[idx] == 0) to 0
- *          and of background points to a very large value (1e20f), which in the
- *          context of a distance transform can be regarded as infinity.
- * @param in Data pointer of the input tensor.
- * @param out Data pointer of the output tensor.
- * @param N Total number of elements in the tensor.
- */
-__global__ void initialize_distance_kernel(const float* in, float* out, int64_t N) {
-    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < N) {
-        // If the input pixel is 0, it is a foreground (source) point: distance 0.
-        // If the input pixel is nonzero, it is a background point: initial distance infinity.
-        out[idx] = (in[idx] == 0.0f) ? 0.0f : 1e20f;
-    }
-}
+// Optimization strategy: replace the templated kernel with 4 specialized kernels,
+// eliminating the branches entirely.
+
+// Kernel 1: the first pass that is also the final pass (the 1D case)
+__global__ void edt_kernel_first_final(
+    const float* in_data,
+    float* out_dist,
+    int32_t* out_idx,
+    const int64_t* shape,
+    const int64_t* strides,
+    int32_t ndim,
+    int32_t process_dim_sample,
+    int64_t total_slices,
+    int64_t num_slices_per_sample
+) {
+    int64_t slice_idx = blockIdx.x;
+    if (slice_idx >= total_slices) return;
+
+    int64_t batch_idx = slice_idx / num_slices_per_sample;
+    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
+    int64_t batch_offset = batch_idx * strides[0];
+    int64_t sample_base_offset = 0;
+    int64_t temp_idx = slice_idx_in_sample;
+    const int sample_ndim = ndim - 1;
+
+    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+        if (d == process_dim_sample) continue;
+        int64_t size_of_dim = shape[d + 1];
+        if (size_of_dim > 0) {
+            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
+            temp_idx /= size_of_dim;
+        }
+    }
+
+    const int64_t process_dim_actual = process_dim_sample + 1;
+    const int64_t N = shape[process_dim_actual];
+    const int64_t stride = strides[process_dim_actual];
+    const int64_t base_offset = batch_offset + sample_base_offset;
+
+    if (N == 0) return;
+
+    extern __shared__ char s_buffer[];
+    float* f = (float*)s_buffer;
+    int* v = (int*)(f + N);
+    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
+
+    // Load data - first pass
+    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
+        int64_t global_offset = base_offset + i * stride;
+        float val = __ldg(&in_data[global_offset]);
+        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
+
+        if (val == 0.0f) {
+            f[i] = 0.0f;
+            int64_t temp_coord = slice_idx_in_sample;
+            for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+                if (d == process_dim_sample) continue;
+                int64_t size_of_dim = shape[d + 1];
+                if (size_of_dim > 0) {
+                    shared_idx_ptr[d] = temp_coord % size_of_dim;
+                    temp_coord /= size_of_dim;
+                } else {
+                    shared_idx_ptr[d] = 0;
+                }
+            }
+            shared_idx_ptr[process_dim_sample] = i;
+        } else {
+            f[i] = 1e20f;
+            for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1;
+        }
+    }
+    __syncthreads();
+
+    // Build the lower envelope
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e20f;
+        z[1] = 1e20f;
+
+        for (int q = 1; q < N; q++) {
+            float fq = f[q];
+            int q_sq = q * q;
+
+            while (k >= 0) {
+                int p = v[k];
+                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) {
+                    k++;
+                    v[k] = q;
+                    z[k] = s;
+                    z[k + 1] = 1e20f;
+                    break;
+                }
+                k--;
+                if (k < 0) {
+                    k = 0;
+                    v[0] = q;
+                    z[0] = -1e20f;
+                    z[1] = 1e20f;
+                    break;
+                }
+            }
+        }
+    }
+    __syncthreads();
+
+    // Compute distances - this is the final pass, so take the square root directly
+    for (int q = threadIdx.x; q < N; q += blockDim.x) {
+        int k = 0;
+        float q_float = (float)q;
+        while (z[k + 1] < q_float) k++;
+
+        int p = v[k];
+        int64_t global_offset = base_offset + q * stride;
+        float dist_sq = (float)(q - p) * (q - p) + f[p];
+
+        out_dist[global_offset] = sqrtf(dist_sq);  // sqrt directly
+
+        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
+        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            out_idx_ptr[d] = src_idx_ptr[d];
+        }
+    }
+}
 
-// --- Kernel 2: 1D-pass squared-distance kernel ---
-/**
- * @brief Runs the 1D parabolic lower-envelope algorithm along one spatial dimension of an N-D tensor.
- * @details This is the core of the Felzenszwalb-Huttenlocher EDT algorithm: the N-D problem is
- *          solved by decomposing it into N one-dimensional problems, and this kernel handles one
- *          of those dimensions. It computes squared distances only, which avoids expensive square
- *          roots and preserves numerical precision.
- *          Each CUDA thread block processes one complete 1D scanline (slice).
- * @param in_data Input tensor data pointer.
- * @param out_data Output tensor data pointer.
- * @param shape Pointer to the array describing the input tensor's shape (on the GPU).
@param strides 描述输入张量步幅的数组指针 (在GPU上)。 - * @param ndim 张量的总维度数 (包括批处理维度)。 - * @param process_dim_sample 当前正在处理的空间维度索引 (0代表第一个空间维度,依此类推)。 - * @param total_slices 需要处理的一维扫描线总数 (batch_size * num_slices_per_sample)。 - * @param num_slices_per_sample 每个样本中,垂直于当前处理维度的扫描线数量。 - */ -__global__ void edt_1d_pass_sq_kernel( - const float* in_data, float* out_data, - const int64_t* shape, const int64_t* strides, - int32_t ndim, int32_t process_dim_sample, - int64_t total_slices, int64_t num_slices_per_sample +// 内核2: 第一个pass但不是最后 +__global__ void edt_kernel_first_only( + const float* in_data, + float* out_dist, + int32_t* out_idx, + const int64_t* shape, + const int64_t* strides, + int32_t ndim, + int32_t process_dim_sample, + int64_t total_slices, + int64_t num_slices_per_sample ) { - // 每个线程块处理一条一维扫描线 int64_t slice_idx = blockIdx.x; if (slice_idx >= total_slices) return; - int64_t batch_idx = slice_idx / num_slices_per_sample; int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t batch_offset = batch_idx * strides[0]; // 获取批处理的基地址 + int64_t batch_offset = batch_idx * strides[0]; int64_t sample_base_offset = 0; int64_t temp_idx = slice_idx_in_sample; - int sample_ndim = ndim - 1; + const int sample_ndim = ndim - 1; - // 从非处理维度中计算出样本内的基地址偏移 for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; // 跳过当前正在处理的维度 + if (d == process_dim_sample) continue; int64_t size_of_dim = shape[d + 1]; - if (size_of_dim == 0) continue; - int64_t coord_in_dim = temp_idx % size_of_dim; - temp_idx /= size_of_dim; - sample_base_offset += coord_in_dim * strides[d + 1]; + if (size_of_dim > 0) { + sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; + temp_idx /= size_of_dim; + } } - const int64_t process_dim_actual = process_dim_sample + 1; // 加上批处理维度的实际索引 - const int64_t N = shape[process_dim_actual]; // 当前处理维度的长度 - const int64_t stride = strides[process_dim_actual]; // 沿当前维度移动一个元素所需的步幅 - const int64_t base_offset = batch_offset + sample_base_offset; // 最终的起始地址 + const int64_t process_dim_actual = process_dim_sample + 1; + const int64_t N = shape[process_dim_actual]; + const int64_t stride = strides[process_dim_actual]; + const int64_t base_offset = batch_offset + sample_base_offset; + if (N == 0) return; - extern __shared__ float sdata[]; - float* f = sdata; // 存储函数值 g(p) = f(p) + p^2 - int* v = (int*)(sdata + N); // 存储抛物线顶点的索引 - float* z = (float*)(v + N + 1); // 存储相邻抛物线的交点 + extern __shared__ char s_buffer[]; + float* f = (float*)s_buffer; + int* v = (int*)(f + N); + float* z = (float*)((char*)v + (N + 1) * sizeof(int)); + int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); - // 块内的所有线程协同将数据从全局内存加载到共享内存 for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - f[i] = in_data[base_offset + i * stride]; + int64_t global_offset = base_offset + i * stride; + float val = __ldg(&in_data[global_offset]); + int32_t* shared_idx_ptr = s_idx + i * sample_ndim; + + if (val == 0.0f) { + f[i] = 0.0f; + int64_t temp_coord = slice_idx_in_sample; + for (int32_t d = sample_ndim - 1; d >= 0; --d) { + if (d == process_dim_sample) continue; + int64_t size_of_dim = shape[d + 1]; + if (size_of_dim > 0) { + shared_idx_ptr[d] = temp_coord % size_of_dim; + temp_coord /= size_of_dim; + } else { + shared_idx_ptr[d] = 0; + } + } + shared_idx_ptr[process_dim_sample] = i; + } else { + f[i] = 1e20f; + for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + int k = 0; + v[0] = 0; + z[0] = -1e20f; + z[1] = 1e20f; + 
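+        // Build the lower envelope of parabolas (the Felzenszwalb-Huttenlocher
+        // step named in the doc comment above). Point q contributes the
+        // parabola f[q] + (x - q)^2; v[] stores the vertices of the envelope
+        // parabolas and z[] the boundaries between their intervals. A new
+        // parabola q intersects the last kept parabola p = v[k] at
+        //     s = ((f[q] + q^2) - (f[p] + p^2)) / (2 * (q - p));
+        // while s <= z[k], parabola v[k] is dominated by q and is popped.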
+ for (int q = 1; q < N; q++) { + float fq = f[q]; + int q_sq = q * q; + + while (k >= 0) { + int p = v[k]; + float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); + if (s > z[k]) { + k++; + v[k] = q; + z[k] = s; + z[k + 1] = 1e20f; + break; + } + k--; + if (k < 0) { + k = 0; + v[0] = q; + z[0] = -1e20f; + z[1] = 1e20f; + break; + } + } + } + } + __syncthreads(); + + for (int q = threadIdx.x; q < N; q += blockDim.x) { + int k = 0; + float q_float = (float)q; + while (z[k + 1] < q_float) k++; + + int p = v[k]; + int64_t global_offset = base_offset + q * stride; + float dist_sq = (float)(q - p) * (q - p) + f[p]; + + out_dist[global_offset] = dist_sq; // 不开方 + + int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; + const int32_t* src_idx_ptr = s_idx + p * sample_ndim; + for (int d = 0; d < sample_ndim; ++d) { + out_idx_ptr[d] = src_idx_ptr[d]; + } + } +} + +// 内核3: 中间pass +__global__ void edt_kernel_middle( + const float* in_dist, + const int32_t* in_idx, + float* out_dist, + int32_t* out_idx, + const int64_t* shape, + const int64_t* strides, + int32_t ndim, + int32_t process_dim_sample, + int64_t total_slices, + int64_t num_slices_per_sample +) { + int64_t slice_idx = blockIdx.x; + if (slice_idx >= total_slices) return; + + int64_t batch_idx = slice_idx / num_slices_per_sample; + int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; + int64_t batch_offset = batch_idx * strides[0]; + int64_t sample_base_offset = 0; + int64_t temp_idx = slice_idx_in_sample; + const int sample_ndim = ndim - 1; + + for (int32_t d = sample_ndim - 1; d >= 0; --d) { + if (d == process_dim_sample) continue; + int64_t size_of_dim = shape[d + 1]; + if (size_of_dim > 0) { + sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; + temp_idx /= size_of_dim; + } } - __syncthreads(); // 等待所有线程完成加载 + + const int64_t process_dim_actual = process_dim_sample + 1; + const int64_t N = shape[process_dim_actual]; + const int64_t stride = strides[process_dim_actual]; + const int64_t base_offset = batch_offset + sample_base_offset; + + if (N == 0) return; - //计算抛物线的下包络 - if (threadIdx.x == 0 && N > 0) { - int k = 0; // 下包络中的抛物线数量 - v[0] = 0; // 第一个抛物线的顶点索引为0 - z[0] = -1e20f; z[1] = 1e20f; // 初始化交点为负无穷和正无穷 + extern __shared__ char s_buffer[]; + float* f = (float*)s_buffer; + int* v = (int*)(f + N); + float* z = (float*)((char*)v + (N + 1) * sizeof(int)); + int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); - // 遍历所有点,构建下包络 + for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { + int64_t global_offset = base_offset + i * stride; + f[i] = __ldg(&in_dist[global_offset]); + + const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim; + int32_t* shared_idx_ptr = s_idx + i * sample_ndim; + for (int d = 0; d < sample_ndim; ++d) { + shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + int k = 0; + v[0] = 0; + z[0] = -1e20f; + z[1] = 1e20f; + for (int q = 1; q < N; q++) { - float s; - // 寻找新的抛物线q应该插入的位置 - while (true) { - int p = v[k]; if (q == p) break; - // s 是抛物线 p 和 q 的交点的横坐标 - s = ((f[q] + q * q) - (f[p] + p * p)) / (2.0f * (q - p)); - // 如果交点在当前区间的右侧,则找到了插入点 - if (s > z[k]) { break; } - // 否则,抛物线p被q完全覆盖,需要移除p - if (k == 0) { break; } + float fq = f[q]; + int q_sq = q * q; + + while (k >= 0) { + int p = v[k]; + float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); + if (s > z[k]) { + k++; + v[k] = q; + z[k] = s; + z[k + 1] = 1e20f; + break; + } k--; + if (k < 0) { + k = 0; + v[0] = q; + z[0] = -1e20f; + z[1] = 1e20f; 
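+                        // k fell below 0: the new parabola q undercuts every
+                        // parabola on the envelope, so restart it with q alone.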
+ break; + } } - // 插入新的抛物线q - k++; - v[k] = q; - z[k] = s; - z[k + 1] = 1e20f; - } - // 计算距离平方 - k = 0; - // 遍历所有点,找到其头顶上方的下包络线段,并计算距离 - for (int q = 0; q < N; q++) { - while (z[k + 1] < q) k++; // 找到点q所属的区间 - int p = v[k]; // 获取该区间的抛物线顶点索引 - // 计算距离平方: D(q)^2 = (q - p)^2 + g(p) - out_data[base_offset + q * stride] = (q - p) * (q - p) + f[p]; + } + } + __syncthreads(); + + for (int q = threadIdx.x; q < N; q += blockDim.x) { + int k = 0; + float q_float = (float)q; + while (z[k + 1] < q_float) k++; + + int p = v[k]; + int64_t global_offset = base_offset + q * stride; + float dist_sq = (float)(q - p) * (q - p) + f[p]; + + out_dist[global_offset] = dist_sq; // 不开方 + + int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; + const int32_t* src_idx_ptr = s_idx + p * sample_ndim; + for (int d = 0; d < sample_ndim; ++d) { + out_idx_ptr[d] = src_idx_ptr[d]; } } } -// --- Kernel 3: 开平方根内核 --- -/** - * @brief 对张量中的每个元素计算平方根。 - * @details 这是一个简单的逐元素操作。由于之前的1D pass计算的是距离的平方, - * 此内核在所有维度处理完毕后被调用,以得到最终的欧氏距离。 - * @param data 需要进行开方操作的张量数据指针。 - * @param N 张量中的元素总数。 - */ -__global__ void sqrt_kernel(float* data, int64_t N) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < N) { - data[idx] = sqrtf(data[idx]); +// 内核4: 最后一个pass +__global__ void edt_kernel_final( + const float* in_dist, + const int32_t* in_idx, + float* out_dist, + int32_t* out_idx, + const int64_t* shape, + const int64_t* strides, + int32_t ndim, + int32_t process_dim_sample, + int64_t total_slices, + int64_t num_slices_per_sample +) { + int64_t slice_idx = blockIdx.x; + if (slice_idx >= total_slices) return; + + int64_t batch_idx = slice_idx / num_slices_per_sample; + int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; + int64_t batch_offset = batch_idx * strides[0]; + int64_t sample_base_offset = 0; + int64_t temp_idx = slice_idx_in_sample; + const int sample_ndim = ndim - 1; + + for (int32_t d = sample_ndim - 1; d >= 0; --d) { + if (d == process_dim_sample) continue; + int64_t size_of_dim = shape[d + 1]; + if (size_of_dim > 0) { + sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; + temp_idx /= size_of_dim; + } + } + + const int64_t process_dim_actual = process_dim_sample + 1; + const int64_t N = shape[process_dim_actual]; + const int64_t stride = strides[process_dim_actual]; + const int64_t base_offset = batch_offset + sample_base_offset; + + if (N == 0) return; + + extern __shared__ char s_buffer[]; + float* f = (float*)s_buffer; + int* v = (int*)(f + N); + float* z = (float*)((char*)v + (N + 1) * sizeof(int)); + int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); + + for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { + int64_t global_offset = base_offset + i * stride; + f[i] = __ldg(&in_dist[global_offset]); + + const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim; + int32_t* shared_idx_ptr = s_idx + i * sample_ndim; + for (int d = 0; d < sample_ndim; ++d) { + shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + int k = 0; + v[0] = 0; + z[0] = -1e20f; + z[1] = 1e20f; + + for (int q = 1; q < N; q++) { + float fq = f[q]; + int q_sq = q * q; + + while (k >= 0) { + int p = v[k]; + float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); + if (s > z[k]) { + k++; + v[k] = q; + z[k] = s; + z[k + 1] = 1e20f; + break; + } + k--; + if (k < 0) { + k = 0; + v[0] = q; + z[0] = -1e20f; + z[1] = 1e20f; + break; + } + } + } + } + __syncthreads(); + + for (int q = threadIdx.x; q < N; q += blockDim.x) { + int k 
= 0; + float q_float = (float)q; + while (z[k + 1] < q_float) k++; + + int p = v[k]; + int64_t global_offset = base_offset + q * stride; + float dist_sq = (float)(q - p) * (q - p) + f[p]; + + out_dist[global_offset] = sqrtf(dist_sq); // 最后开方 + + int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; + const int32_t* src_idx_ptr = s_idx + p * sample_ndim; + for (int d = 0; d < sample_ndim; ++d) { + out_idx_ptr[d] = src_idx_ptr[d]; + } } } -// --- 主调函数 (Host) --- -/** - * @brief 执行N维欧氏距离变换。 - * @param input 一个N维的PyTorch张量,第一个维度被视为批处理(batch)维度。 - * @return 一个与输入形状相同的张量,包含每个点到最近前景点(值>=0.5)的欧氏距离。 - */ +// Host函数 std::tuple distance_transform_cuda(torch::Tensor input) { - auto original_input = input; - - // --- 预处理: 统一输入格式 --- - // 确保所有输入都至少是3D的 (B, ...),方便后续统一处理。 - // 如果输入是 (H, W) 或 (L),则变为 (1, H, W) 或 (1, L)。 + TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device."); + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be a float tensor."); + input = input.contiguous(); + bool had_no_batch_dim = (input.dim() <= 2); if (had_no_batch_dim) { input = input.unsqueeze(0); } - // 检查输入张量是否在CUDA上并且是内存连续的 - TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device."); - TORCH_CHECK(input.is_contiguous(), "Input tensor must be contiguous."); - - - // --- 获取张量元数据 --- const auto ndim = input.dim(); - const auto sample_ndim = ndim - 1; // 空间维度 = 总维度 - 1 (batch) + const auto sample_ndim = ndim - 1; const auto batch_size = input.size(0); - const int64_t N_total = input.numel(); auto shape = input.sizes().vec(); - auto index_shape = shape; - index_shape.push_back(sample_ndim); - auto strides_vec = input.strides().vec(); - // --- 内存分配: 使用Ping-Pong缓冲策略 --- - // 分配两个缓冲区,在处理每个维度时交替作为输入和输出,避免原地读写冲突。 - auto distance = torch::zeros_like(input); - auto index = torch::zeros(index_shape); - auto buffer = (sample_ndim > 0) ? torch::empty_like(input) : distance; - - if (input.numel() == 0) { return std::make_tuple(distance, index); } - - //二值化 - int threads = 256; // 定义每个线程块的线程数 - int blocks = (N_total + threads - 1) / threads; // 计算启动的线程块数 - initialize_distance_kernel<<>>(input.data_ptr(), buffer.data_ptr(), N_total); - - //循环调用 edt_1d_pass_sq_kernel - // 将shape和strides信息从CPU内存拷贝到GPU内存,以便内核可以访问 - int64_t *shape_gpu, *strides_gpu; - cudaMalloc(&shape_gpu, ndim * sizeof(int64_t)); - cudaMalloc(&strides_gpu, ndim * sizeof(int64_t)); - cudaMemcpy(shape_gpu, shape.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice); - cudaMemcpy(strides_gpu, strides_vec.data(), ndim * sizeof(int64_t), cudaMemcpyHostToDevice); - - torch::Tensor current_input = buffer; - torch::Tensor current_output = distance; - - // 遍历所有空间维度 - for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) { - // 为当前处理的维度计算启动内核所需的参数 - int64_t num_slices_per_sample = 1; - for(int i = 0; i < sample_ndim; ++i) { - if (i != d_sample) num_slices_per_sample *= shape[i + 1]; - } - int64_t total_slices = batch_size * num_slices_per_sample; - int64_t slice_len = shape[d_sample + 1]; - - // 动态设置线程数和共享内存大小 - int threads_pass = (slice_len > 0 && slice_len < 256) ? 
slice_len : 256; - if (threads_pass == 0) threads_pass = 1; - size_t shared_mem_size = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) + (slice_len + 2) * sizeof(float); - - edt_1d_pass_sq_kernel<<>>( - current_input.data_ptr(), current_output.data_ptr(), - shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample - ); - // 交换输入和输出缓冲区,为下一个维度做准备 - std::swap(current_input, current_output); + if (input.numel() == 0) { + auto distance = torch::empty_like(input); + auto index_shape = shape; + index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); + auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32)); + if (had_no_batch_dim) return std::make_tuple(distance.squeeze(0), index.squeeze(0)); + return std::make_tuple(distance, index); } - cudaFree(shape_gpu); - cudaFree(strides_gpu); + auto distance = torch::empty_like(input); + auto index_options = input.options().dtype(torch::kInt32); + auto index_shape = shape; + index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); + auto index = torch::empty(index_shape, index_options); - //计算最终距离 - // 经过循环后,current_input 指向的是包含最终距离平方结果的张量 - sqrt_kernel<<>>(current_input.data_ptr(), N_total); + if (torch::all(input != 0).item()) { + distance.fill_(std::numeric_limits::infinity()); + index.fill_(-1); + if (had_no_batch_dim) { + return std::make_tuple(distance.squeeze(0), index.squeeze(0)); + } + return std::make_tuple(distance, index); + } + + auto shape_tensor = torch::tensor(shape, + torch::TensorOptions().dtype(torch::kInt64).device(input.device())); + auto strides_tensor = torch::tensor(strides_vec, + torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - // 如果最后一轮的输出不在我们期望的 output 张量里,就做一次拷贝 - if (current_input.data_ptr() != distance.data_ptr()){ - distance.copy_(current_input); + const int64_t* shape_gpu = shape_tensor.data_ptr(); + const int64_t* strides_gpu = strides_tensor.data_ptr(); + + std::vector> dim_order_pairs; + for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) { + dim_order_pairs.push_back({strides_vec[d_sample + 1], d_sample}); + } + std::sort(dim_order_pairs.rbegin(), dim_order_pairs.rend()); + + if (sample_ndim == 0) { + int64_t total_slices = batch_size; + int64_t slice_len = (shape.size() > 1) ? 
shape[1] : 0; + int threads = std::min((int64_t)256, slice_len); + size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) + + (slice_len + 2) * sizeof(float) + slice_len * 1 * sizeof(int32_t); + + edt_kernel_first_final<<>>( + input.data_ptr(), + distance.data_ptr(), index.data_ptr(), + shape_gpu, strides_gpu, ndim, 0, total_slices, 1 + ); + } else { + auto buffer_dist = torch::empty_like(distance); + auto buffer_idx = torch::empty_like(index); + + for (int pass = 0; pass < sample_ndim; ++pass) { + int32_t d_sample = dim_order_pairs[pass].second; + bool is_first = (pass == 0); + bool is_final = (pass == sample_ndim - 1); + + torch::Tensor *in_dist, *in_idx, *out_dist, *out_idx; + + if (pass % 2 == 0) { + in_dist = &distance; in_idx = &index; + out_dist = &buffer_dist; out_idx = &buffer_idx; + } else { + in_dist = &buffer_dist; in_idx = &buffer_idx; + out_dist = &distance; out_idx = &index; + } + + int64_t num_slices_per_sample = 1; + for(int i = 0; i < sample_ndim; ++i) { + if (i != d_sample) num_slices_per_sample *= shape[i + 1]; + } + int64_t total_slices = batch_size * num_slices_per_sample; + int64_t slice_len = shape[d_sample + 1]; + + int threads = std::min((int64_t)256, slice_len); + size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) + + (slice_len + 2) * sizeof(float) + slice_len * sample_ndim * sizeof(int32_t); + + if (is_first && is_final) { + edt_kernel_first_final<<>>( + input.data_ptr(), + out_dist->data_ptr(), out_idx->data_ptr(), + shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample + ); + } else if (is_first) { + edt_kernel_first_only<<>>( + input.data_ptr(), + out_dist->data_ptr(), out_idx->data_ptr(), + shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample + ); + } else if (is_final) { + edt_kernel_final<<>>( + in_dist->data_ptr(), in_idx->data_ptr(), + out_dist->data_ptr(), out_idx->data_ptr(), + shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample + ); + } else { + edt_kernel_middle<<>>( + in_dist->data_ptr(), in_idx->data_ptr(), + out_dist->data_ptr(), out_idx->data_ptr(), + shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample + ); + } + } + + if (sample_ndim % 2 != 0) { + distance.copy_(buffer_dist); + index.copy_(buffer_idx); + } } - // 如果最初没有批处理维度,则移除我们添加的维度 - if (had_no_batch_dim) { distance = distance.squeeze(0); } + if (had_no_batch_dim) { + return std::make_tuple(distance.squeeze(0), index.squeeze(0)); + } return std::make_tuple(distance, index); } From 11f067d24d841d0d0071236f8f221f0d7ca30578 Mon Sep 17 00:00:00 2001 From: Yuhandeng Date: Thu, 20 Nov 2025 19:22:11 +0800 Subject: [PATCH 12/39] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E8=B0=83=E6=95=B4+=E9=80=9F=E5=BA=A6=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_distance_transform.py | 20 +- torchmorph/csrc/distance_transform_kernel.cu | 638 ++++++------------- 2 files changed, 205 insertions(+), 453 deletions(-) diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index 3a11166..476ffca 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -6,20 +6,32 @@ # 辅助函数:调用 SciPy 并处理格式 def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - is_single_sample = batch_numpy.ndim <= 2 - if is_single_sample: - batch_numpy = batch_numpy[np.newaxis, ...] 
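+    # NOTE: a 2-D array is treated here as a batch of 1-D samples (N, L), matching
+    # the CUDA path; only a bare 1-D input gets a singleton batch dimension added.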
+ + + input_is_1d_batch = (batch_numpy.ndim == 2) + input_is_single_sample_no_batch = (batch_numpy.ndim == 1) + + if input_is_single_sample_no_batch: + batch_numpy = batch_numpy[np.newaxis, ...] # (L) -> (1, L) + + dist_results, indices_results = [], [] for sample in batch_numpy: dist, indices = scipy_edt(sample, return_indices=True, return_distances=True) dist_results.append(dist) indices_results.append(indices) + output_dist = np.stack(dist_results, axis=0) output_indices = np.stack(indices_results, axis=0) + + # indices shape fix: (N, ndim_sample, ...) -> (N, ..., ndim_sample) + # 对于 1D: (N, 1, L) -> (N, L, 1) output_indices = np.moveaxis(output_indices, 1, -1) - if is_single_sample: + + if input_is_single_sample_no_batch: output_dist = output_dist.squeeze(0) output_indices = output_indices.squeeze(0) + return output_dist, output_indices # 用例定义 diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index 90785bd..c3ded1e 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -5,15 +5,21 @@ #include #include -// 优化策略:用4个独立的内核函数替代模板,完全消除分支 - -// 内核1: 第一个pass且是唯一pass (1D情况) -__global__ void edt_kernel_first_final( - const float* in_data, - float* out_dist, - int32_t* out_idx, - const int64_t* shape, - const int64_t* strides, +#define MAX_DIMS 10 +#define INF_VAL 1e8f // 使用 1e8 保证 float32 精度下的数值稳定性 + +__device__ __forceinline__ float sqr(float x) { return x * x; } + +// ------------------------------------------------------------------ +// 内核 1: 初始 Pass (First Pass) +// ------------------------------------------------------------------ +template +__global__ void edt_kernel_first_pass( + const float* __restrict__ in_data, + float* __restrict__ out_dist, + int32_t* __restrict__ out_idx, + const int64_t* __restrict__ shape, + const int64_t* __restrict__ strides, int32_t ndim, int32_t process_dim_sample, int64_t total_slices, @@ -24,244 +30,117 @@ __global__ void edt_kernel_first_final( int64_t batch_idx = slice_idx / num_slices_per_sample; int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t batch_offset = batch_idx * strides[0]; - int64_t sample_base_offset = 0; - int64_t temp_idx = slice_idx_in_sample; - const int sample_ndim = ndim - 1; - - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - if (size_of_dim > 0) { - sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; - temp_idx /= size_of_dim; - } - } - - const int64_t process_dim_actual = process_dim_sample + 1; - const int64_t N = shape[process_dim_actual]; - const int64_t stride = strides[process_dim_actual]; - const int64_t base_offset = batch_offset + sample_base_offset; + int64_t current_offset = batch_idx * strides[0]; - if (N == 0) return; - - extern __shared__ char s_buffer[]; - float* f = (float*)s_buffer; - int* v = (int*)(f + N); - float* z = (float*)((char*)v + (N + 1) * sizeof(int)); - int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); - - // 加载数据 - 第一个pass - for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - int64_t global_offset = base_offset + i * stride; - float val = __ldg(&in_data[global_offset]); - int32_t* shared_idx_ptr = s_idx + i * sample_ndim; - - if (val == 0.0f) { - f[i] = 0.0f; - int64_t temp_coord = slice_idx_in_sample; - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - if (size_of_dim > 0) { - 
shared_idx_ptr[d] = temp_coord % size_of_dim; - temp_coord /= size_of_dim; - } else { - shared_idx_ptr[d] = 0; - } - } - shared_idx_ptr[process_dim_sample] = i; - } else { - f[i] = 1e20f; - for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1; - } - } - __syncthreads(); - - // 构建包络 - if (threadIdx.x == 0) { - int k = 0; - v[0] = 0; - z[0] = -1e20f; - z[1] = 1e20f; - - for (int q = 1; q < N; q++) { - float fq = f[q]; - int q_sq = q * q; - - while (k >= 0) { - int p = v[k]; - float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); - if (s > z[k]) { - k++; - v[k] = q; - z[k] = s; - z[k + 1] = 1e20f; - break; - } - k--; - if (k < 0) { - k = 0; - v[0] = q; - z[0] = -1e20f; - z[1] = 1e20f; - break; - } - } - } - } - __syncthreads(); - - // 计算距离 - 最后一个pass,直接开方 - for (int q = threadIdx.x; q < N; q += blockDim.x) { - int k = 0; - float q_float = (float)q; - while (z[k + 1] < q_float) k++; - - int p = v[k]; - int64_t global_offset = base_offset + q * stride; - float dist_sq = (float)(q - p) * (q - p) + f[p]; - - out_dist[global_offset] = sqrtf(dist_sq); // 直接开方 - - int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; - const int32_t* src_idx_ptr = s_idx + p * sample_ndim; - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; - } - } -} - -// 内核2: 第一个pass但不是最后 -__global__ void edt_kernel_first_only( - const float* in_data, - float* out_dist, - int32_t* out_idx, - const int64_t* shape, - const int64_t* strides, - int32_t ndim, - int32_t process_dim_sample, - int64_t total_slices, - int64_t num_slices_per_sample -) { - int64_t slice_idx = blockIdx.x; - if (slice_idx >= total_slices) return; - - int64_t batch_idx = slice_idx / num_slices_per_sample; - int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t batch_offset = batch_idx * strides[0]; - int64_t sample_base_offset = 0; + // 预计算基准坐标 (除了 process_dim 以外的维度坐标) + int32_t base_coords[MAX_DIMS]; int64_t temp_idx = slice_idx_in_sample; const int sample_ndim = ndim - 1; + // 根据 slice_idx 反解坐标 for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - if (size_of_dim > 0) { - sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; - temp_idx /= size_of_dim; + if (d == process_dim_sample) { + base_coords[d] = 0; // 占位 + continue; } + int64_t size_of_dim = shape[d + 1]; + int32_t coord = (int32_t)(temp_idx % size_of_dim); + base_coords[d] = coord; + current_offset += coord * strides[d + 1]; + temp_idx /= size_of_dim; } - + const int64_t process_dim_actual = process_dim_sample + 1; const int64_t N = shape[process_dim_actual]; const int64_t stride = strides[process_dim_actual]; - const int64_t base_offset = batch_offset + sample_base_offset; - + if (N == 0) return; extern __shared__ char s_buffer[]; float* f = (float*)s_buffer; int* v = (int*)(f + N); float* z = (float*)((char*)v + (N + 1) * sizeof(int)); - int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); + // Phase 1: 加载数据 for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - int64_t global_offset = base_offset + i * stride; - float val = __ldg(&in_data[global_offset]); - int32_t* shared_idx_ptr = s_idx + i * sample_ndim; - - if (val == 0.0f) { - f[i] = 0.0f; - int64_t temp_coord = slice_idx_in_sample; - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - if (size_of_dim > 0) { - shared_idx_ptr[d] = temp_coord % size_of_dim; - temp_coord /= size_of_dim; - } else { - 
shared_idx_ptr[d] = 0; - } - } - shared_idx_ptr[process_dim_sample] = i; - } else { - f[i] = 1e20f; - for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1; - } + float val = __ldg(&in_data[current_offset + i * stride]); + f[i] = (val == 0.0f) ? 0.0f : INF_VAL; } __syncthreads(); + // Phase 2: 构建包络 if (threadIdx.x == 0) { int k = 0; v[0] = 0; - z[0] = -1e20f; - z[1] = 1e20f; + z[0] = -INF_VAL; + z[1] = INF_VAL; for (int q = 1; q < N; q++) { + // 显式跳过背景点,避免 INF 污染计算 + if (f[q] >= (INF_VAL * 0.9f)) continue; + float fq = f[q]; - int q_sq = q * q; - - while (k >= 0) { - int p = v[k]; - float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); - if (s > z[k]) { - k++; - v[k] = q; - z[k] = s; - z[k + 1] = 1e20f; - break; - } - k--; - if (k < 0) { - k = 0; - v[0] = q; - z[0] = -1e20f; - z[1] = 1e20f; + int k_curr = k; + while (k_curr >= 0) { + int p = v[k_curr]; + + // --- 核心修复:数值稳定的交点公式 --- + // 先计算差值再相加,防止大数吞小数 + float diff_f = fq - f[p]; + float diff_sq = (float)q*(float)q - (float)p*(float)p; + float s = (diff_f + diff_sq) / (2.0f * (float)(q - p)); + + if (s > z[k_curr]) { + k_curr++; + v[k_curr] = q; + z[k_curr] = s; + z[k_curr + 1] = INF_VAL; + k = k_curr; break; } + k_curr--; + } + if (k_curr < 0) { + k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL; } } } __syncthreads(); + // Phase 3: 计算距离 for (int q = threadIdx.x; q < N; q += blockDim.x) { int k = 0; float q_float = (float)q; while (z[k + 1] < q_float) k++; - int p = v[k]; - int64_t global_offset = base_offset + q * stride; - float dist_sq = (float)(q - p) * (q - p) + f[p]; + int p = v[k]; + + int64_t global_idx = current_offset + q * stride; + float dist_sq = sqr(q_float - (float)p) + f[p]; - out_dist[global_offset] = dist_sq; // 不开方 + out_dist[global_idx] = IsFinal ? sqrtf(dist_sq) : dist_sq; - int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; - const int32_t* src_idx_ptr = s_idx + p * sample_ndim; + // 写入索引 + int32_t* out_idx_ptr = out_idx + global_idx * sample_ndim; + #pragma unroll for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; + // 只有当前处理的维度写入 p,其他维度写入基准坐标 + out_idx_ptr[d] = (d == process_dim_sample) ? 
p : base_coords[d]; } } } -// 内核3: 中间pass -__global__ void edt_kernel_middle( - const float* in_dist, - const int32_t* in_idx, - float* out_dist, - int32_t* out_idx, - const int64_t* shape, - const int64_t* strides, +// ------------------------------------------------------------------ +// 内核 2: 后续 Pass (Subsequent Pass) +// ------------------------------------------------------------------ +template +__global__ void edt_kernel_subsequent_pass( + const float* __restrict__ in_dist, + const int32_t* __restrict__ in_idx, + float* __restrict__ out_dist, + int32_t* __restrict__ out_idx, + const int64_t* __restrict__ shape, + const int64_t* __restrict__ strides, int32_t ndim, int32_t process_dim_sample, int64_t total_slices, @@ -272,24 +151,21 @@ __global__ void edt_kernel_middle( int64_t batch_idx = slice_idx / num_slices_per_sample; int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t batch_offset = batch_idx * strides[0]; - int64_t sample_base_offset = 0; + int64_t current_offset = batch_idx * strides[0]; + int64_t temp_idx = slice_idx_in_sample; const int sample_ndim = ndim - 1; for (int32_t d = sample_ndim - 1; d >= 0; --d) { if (d == process_dim_sample) continue; int64_t size_of_dim = shape[d + 1]; - if (size_of_dim > 0) { - sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; - temp_idx /= size_of_dim; - } + current_offset += (temp_idx % size_of_dim) * strides[d + 1]; + temp_idx /= size_of_dim; } const int64_t process_dim_actual = process_dim_sample + 1; const int64_t N = shape[process_dim_actual]; const int64_t stride = strides[process_dim_actual]; - const int64_t base_offset = batch_offset + sample_base_offset; if (N == 0) return; @@ -297,49 +173,33 @@ __global__ void edt_kernel_middle( float* f = (float*)s_buffer; int* v = (int*)(f + N); float* z = (float*)((char*)v + (N + 1) * sizeof(int)); - int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - int64_t global_offset = base_offset + i * stride; - f[i] = __ldg(&in_dist[global_offset]); - - const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim; - int32_t* shared_idx_ptr = s_idx + i * sample_ndim; - for (int d = 0; d < sample_ndim; ++d) { - shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]); - } + f[i] = __ldg(&in_dist[current_offset + i * stride]); } __syncthreads(); if (threadIdx.x == 0) { int k = 0; - v[0] = 0; - z[0] = -1e20f; - z[1] = 1e20f; + v[0] = 0; z[0] = -INF_VAL; z[1] = INF_VAL; for (int q = 1; q < N; q++) { + if (f[q] >= (INF_VAL * 0.9f)) continue; + float fq = f[q]; - int q_sq = q * q; - - while (k >= 0) { - int p = v[k]; - float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); - if (s > z[k]) { - k++; - v[k] = q; - z[k] = s; - z[k + 1] = 1e20f; - break; - } - k--; - if (k < 0) { - k = 0; - v[0] = q; - z[0] = -1e20f; - z[1] = 1e20f; - break; + int k_curr = k; + while (k_curr >= 0) { + int p = v[k_curr]; + float diff_f = fq - f[p]; + float diff_sq = (float)q*(float)q - (float)p*(float)p; + float s = (diff_f + diff_sq) / (2.0f * (float)(q - p)); + if (s > z[k_curr]) { + k_curr++; v[k_curr] = q; z[k_curr] = s; z[k_curr + 1] = INF_VAL; + k = k_curr; break; } + k_curr--; } + if (k_curr < 0) { k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL; } } } __syncthreads(); @@ -350,260 +210,140 @@ __global__ void edt_kernel_middle( while (z[k + 1] < q_float) k++; int p = v[k]; - int64_t global_offset = base_offset + q * stride; - float dist_sq = (float)(q - p) * (q - p) + f[p]; - - out_dist[global_offset] = dist_sq; // 
不开方 - - int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; - const int32_t* src_idx_ptr = s_idx + p * sample_ndim; - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; - } - } -} - -// 内核4: 最后一个pass -__global__ void edt_kernel_final( - const float* in_dist, - const int32_t* in_idx, - float* out_dist, - int32_t* out_idx, - const int64_t* shape, - const int64_t* strides, - int32_t ndim, - int32_t process_dim_sample, - int64_t total_slices, - int64_t num_slices_per_sample -) { - int64_t slice_idx = blockIdx.x; - if (slice_idx >= total_slices) return; - - int64_t batch_idx = slice_idx / num_slices_per_sample; - int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t batch_offset = batch_idx * strides[0]; - int64_t sample_base_offset = 0; - int64_t temp_idx = slice_idx_in_sample; - const int sample_ndim = ndim - 1; - - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - if (size_of_dim > 0) { - sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1]; - temp_idx /= size_of_dim; - } - } - - const int64_t process_dim_actual = process_dim_sample + 1; - const int64_t N = shape[process_dim_actual]; - const int64_t stride = strides[process_dim_actual]; - const int64_t base_offset = batch_offset + sample_base_offset; - - if (N == 0) return; - - extern __shared__ char s_buffer[]; - float* f = (float*)s_buffer; - int* v = (int*)(f + N); - float* z = (float*)((char*)v + (N + 1) * sizeof(int)); - int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float)); - - for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - int64_t global_offset = base_offset + i * stride; - f[i] = __ldg(&in_dist[global_offset]); - const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim; - int32_t* shared_idx_ptr = s_idx + i * sample_ndim; - for (int d = 0; d < sample_ndim; ++d) { - shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]); - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - int k = 0; - v[0] = 0; - z[0] = -1e20f; - z[1] = 1e20f; + int64_t q_global_offset = current_offset + q * stride; + int64_t p_global_offset = current_offset + p * stride; - for (int q = 1; q < N; q++) { - float fq = f[q]; - int q_sq = q * q; - - while (k >= 0) { - int p = v[k]; - float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p)); - if (s > z[k]) { - k++; - v[k] = q; - z[k] = s; - z[k + 1] = 1e20f; - break; - } - k--; - if (k < 0) { - k = 0; - v[0] = q; - z[0] = -1e20f; - z[1] = 1e20f; - break; - } - } - } - } - __syncthreads(); + float dist_sq = sqr(q_float - (float)p) + f[p]; + out_dist[q_global_offset] = IsFinal ? 
sqrtf(dist_sq) : dist_sq; - for (int q = threadIdx.x; q < N; q += blockDim.x) { - int k = 0; - float q_float = (float)q; - while (z[k + 1] < q_float) k++; - - int p = v[k]; - int64_t global_offset = base_offset + q * stride; - float dist_sq = (float)(q - p) * (q - p) + f[p]; - - out_dist[global_offset] = sqrtf(dist_sq); // 最后开方 + // 索引直接从 Global Memory 拷贝,无需 Shared Memory + const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; + int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; - int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim; - const int32_t* src_idx_ptr = s_idx + p * sample_ndim; + #pragma unroll for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; + out_idx_ptr[d] = src_idx_ptr[d]; } } } -// Host函数 +// ------------------------------------------------------------------ +// Host 函数 +// ------------------------------------------------------------------ std::tuple distance_transform_cuda(torch::Tensor input) { - TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device."); - TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be a float tensor."); + TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device."); + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32."); + input = input.contiguous(); - bool had_no_batch_dim = (input.dim() <= 2); - if (had_no_batch_dim) { input = input.unsqueeze(0); } + // 自动处理 1D 输入:(L) -> (1, L) + // 自动处理 1D 批处理:(N, L) 保持不变 (视为 N 个 1D 样本) + bool had_no_batch_dim = (input.dim() == 1); + if (had_no_batch_dim) { + input = input.unsqueeze(0); + } const auto ndim = input.dim(); const auto sample_ndim = ndim - 1; const auto batch_size = input.size(0); - auto shape = input.sizes().vec(); auto strides_vec = input.strides().vec(); if (input.numel() == 0) { - auto distance = torch::empty_like(input); auto index_shape = shape; index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); - auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - if (had_no_batch_dim) return std::make_tuple(distance.squeeze(0), index.squeeze(0)); - return std::make_tuple(distance, index); + return std::make_tuple(torch::empty_like(input), + torch::empty(index_shape, input.options().dtype(torch::kInt32))); } - + auto distance = torch::empty_like(input); - auto index_options = input.options().dtype(torch::kInt32); auto index_shape = shape; - index_shape.push_back(sample_ndim > 0 ? 
sample_ndim : 1); - auto index = torch::empty(index_shape, index_options); - - if (torch::all(input != 0).item()) { - distance.fill_(std::numeric_limits::infinity()); - index.fill_(-1); - if (had_no_batch_dim) { - return std::make_tuple(distance.squeeze(0), index.squeeze(0)); - } - return std::make_tuple(distance, index); - } + index_shape.push_back(sample_ndim); + auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - auto shape_tensor = torch::tensor(shape, - torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - auto strides_tensor = torch::tensor(strides_vec, - torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - - const int64_t* shape_gpu = shape_tensor.data_ptr(); - const int64_t* strides_gpu = strides_tensor.data_ptr(); + auto buffer_dist = torch::empty_like(distance); + auto buffer_idx = torch::empty_like(index); + + auto shape_tensor = torch::tensor(shape, torch::TensorOptions().dtype(torch::kInt64).device(input.device())); + auto strides_tensor = torch::tensor(strides_vec, torch::TensorOptions().dtype(torch::kInt64).device(input.device())); std::vector> dim_order_pairs; - for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) { - dim_order_pairs.push_back({strides_vec[d_sample + 1], d_sample}); + for (int32_t d = 0; d < sample_ndim; ++d) { + dim_order_pairs.push_back({strides_vec[d + 1], d}); } std::sort(dim_order_pairs.rbegin(), dim_order_pairs.rend()); - if (sample_ndim == 0) { - int64_t total_slices = batch_size; - int64_t slice_len = (shape.size() > 1) ? shape[1] : 0; - int threads = std::min((int64_t)256, slice_len); - size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) + - (slice_len + 2) * sizeof(float) + slice_len * 1 * sizeof(int32_t); - - edt_kernel_first_final<<>>( - input.data_ptr(), - distance.data_ptr(), index.data_ptr(), - shape_gpu, strides_gpu, ndim, 0, total_slices, 1 - ); - } else { - auto buffer_dist = torch::empty_like(distance); - auto buffer_idx = torch::empty_like(index); - - for (int pass = 0; pass < sample_ndim; ++pass) { - int32_t d_sample = dim_order_pairs[pass].second; - bool is_first = (pass == 0); - bool is_final = (pass == sample_ndim - 1); - - torch::Tensor *in_dist, *in_idx, *out_dist, *out_idx; - - if (pass % 2 == 0) { - in_dist = &distance; in_idx = &index; - out_dist = &buffer_dist; out_idx = &buffer_idx; + for (int pass = 0; pass < sample_ndim; ++pass) { + int32_t d_sample = dim_order_pairs[pass].second; + bool is_first_pass = (pass == 0); + bool is_final_pass = (pass == sample_ndim - 1); + + torch::Tensor *in_d, *in_i, *out_d, *out_i; + + if (is_first_pass) { + in_d = nullptr; in_i = nullptr; + out_d = is_final_pass ? &distance : &buffer_dist; + out_i = is_final_pass ? 
&index : &buffer_idx;
+        } else {
+            if (pass % 2 != 0) {
+                in_d = &buffer_dist; in_i = &buffer_idx;
+                out_d = &distance;   out_i = &index;
             } else {
-                in_dist = &buffer_dist; in_idx = &buffer_idx;
-                out_dist = &distance; out_idx = &index;
+                in_d = &distance;     in_i = &index;
+                out_d = &buffer_dist; out_i = &buffer_idx;
             }
+            if (is_final_pass) {
+                out_d = &distance; out_i = &index;
             }
+        }
+
+        int64_t num_slices_per_sample = 1;
+        for(int i = 0; i < sample_ndim; ++i) {
+            if (i != d_sample) num_slices_per_sample *= shape[i + 1];
         }
-            int64_t total_slices = batch_size * num_slices_per_sample;
-            int64_t slice_len = shape[d_sample + 1];
-
-            int threads = std::min((int64_t)256, slice_len);
-            size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) +
-                          (slice_len + 2) * sizeof(float) + slice_len * sample_ndim * sizeof(int32_t);
-
-            if (is_first && is_final) {
-                edt_kernel_first_final<<<total_slices, threads, smem>>>(
-                    input.data_ptr<float>(),
-                    out_dist->data_ptr<float>(), out_idx->data_ptr<int32_t>(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+        int64_t total_slices = batch_size * num_slices_per_sample;
+        int64_t slice_len = shape[d_sample + 1];
+
+        int threads = std::min((int64_t)256, slice_len);
+        // Shared memory must cover the kernels' pointer layout exactly:
+        // f[N] floats, v[N + 1] ints, z[N + 2] floats.
+        size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) + (slice_len + 2) * sizeof(float);
+
+        if (is_first_pass) {
+            const float* in_ptr = input.data_ptr<float>();
+            if (is_final_pass) {
+                edt_kernel_first_pass<true><<<total_slices, threads, smem>>>(
+                    in_ptr, out_d->data_ptr<float>(), out_i->data_ptr<int32_t>(),
+                    shape_tensor.data_ptr<int64_t>(), strides_tensor.data_ptr<int64_t>(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
-            } else if (is_first) {
-                edt_kernel_first_only<<<total_slices, threads, smem>>>(
-                    input.data_ptr<float>(),
-                    out_dist->data_ptr<float>(), out_idx->data_ptr<int32_t>(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+            } else {
+                edt_kernel_first_pass<false><<<total_slices, threads, smem>>>(
+                    in_ptr, out_d->data_ptr<float>(), out_i->data_ptr<int32_t>(),
+                    shape_tensor.data_ptr<int64_t>(), strides_tensor.data_ptr<int64_t>(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
-            } else if (is_final) {
-                edt_kernel_final<<<total_slices, threads, smem>>>(
-                    in_dist->data_ptr<float>(), in_idx->data_ptr<int32_t>(),
-                    out_dist->data_ptr<float>(), out_idx->data_ptr<int32_t>(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+            }
+        } else {
+            if (is_final_pass) {
+                edt_kernel_subsequent_pass<true><<<total_slices, threads, smem>>>(
+                    in_d->data_ptr<float>(), in_i->data_ptr<int32_t>(),
+                    out_d->data_ptr<float>(), out_i->data_ptr<int32_t>(),
+                    shape_tensor.data_ptr<int64_t>(), strides_tensor.data_ptr<int64_t>(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
             } else {
-                edt_kernel_middle<<<total_slices, threads, smem>>>(
-                    in_dist->data_ptr<float>(), in_idx->data_ptr<int32_t>(),
-                    out_dist->data_ptr<float>(), out_idx->data_ptr<int32_t>(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+                edt_kernel_subsequent_pass<false><<<total_slices, threads, smem>>>(
+                    in_d->data_ptr<float>(), in_i->data_ptr<int32_t>(),
+                    out_d->data_ptr<float>(), out_i->data_ptr<int32_t>(),
+                    shape_tensor.data_ptr<int64_t>(), strides_tensor.data_ptr<int64_t>(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
             }
         }
-
-        if (sample_ndim % 2 != 0) {
-            distance.copy_(buffer_dist);
-            index.copy_(buffer_idx);
-        }
     }
 
     if (had_no_batch_dim) {
         return std::make_tuple(distance.squeeze(0), index.squeeze(0));
     }
-    
     return std::make_tuple(distance, index);
-}
+}
\ No newline at end of file
From 4c427c70f036ca24094ca671c9b44ab560ffc0d8 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Wed, 3 Dec 2025 03:32:21 +0800
Subject: [PATCH 13/39] Use the JFA algorithm to increase parallelism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torchmorph/csrc/distance_transform_kernel.cu | 290 ++++++++++++------- 1 file changed, 184 insertions(+), 106 deletions(-) diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index c3ded1e..64c69fb 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -4,14 +4,126 @@ #include #include #include +#include +#include #define MAX_DIMS 10 -#define INF_VAL 1e8f // 使用 1e8 保证 float32 精度下的数值稳定性 +#define INF_VAL 1e8f +// 保证 blockDim 足以覆盖大多数常见维度大小,或者配合 Loop 处理 +#define MAX_THREADS 1024 __device__ __forceinline__ float sqr(float x) { return x * x; } +// 计算从像素 q 到源点 p 的距离代价 (考虑了 p 点本身的数值权重 val_p) +__device__ __forceinline__ float compute_cost(int q, int p, float val_p) { + if (p < 0) return INF_VAL; // 无效点 + // dist = (q - p)^2 + f[p] + return sqr((float)q - (float)p) + val_p; +} + // ------------------------------------------------------------------ -// 内核 1: 初始 Pass (First Pass) +// 核心逻辑: 1D Jump Flooding (JFA) / Doubling Algorithm +// 全并行求解最近点索引,替代串行的抛物线构建 +// ------------------------------------------------------------------ +__device__ void compute_1d_jfa( + int N, + float* __restrict__ s_vals, // 输入数值 (dist^2) + int* __restrict__ s_idx_curr, // ping-pong buffer 1 + int* __restrict__ s_idx_next // ping-pong buffer 2 +) { + int tid = threadIdx.x; + + // --- 1. 初始化 --- + // 每个线程负责一个或多个像素的初始化 + for (int i = tid; i < N; i += blockDim.x) { + // 如果当前位置的值很大,说明它是背景,没有初始源点 (-1) + // 否则源点就是它自己 (i) + if (s_vals[i] >= INF_VAL * 0.9f) { + s_idx_curr[i] = -1; + } else { + s_idx_curr[i] = i; + } + } + __syncthreads(); + + // --- 2. 迭代传播 (Step = 1, 2, 4, 8...) --- + // 类似于双调排序或倍增法 + int* idx_in = s_idx_curr; + int* idx_out = s_idx_next; + + // 只要步长小于 N,就需要传播 + // 对于 N=1024, 只需要 10 次迭代,每次所有线程全并行 + for (int step = 1; step < N; step *= 2) { + + for (int i = tid; i < N; i += blockDim.x) { + int my_best_p = idx_in[i]; + float min_cost = INF_VAL; + + // 获取当前最优点的代价 + if (my_best_p != -1) { + min_cost = compute_cost(i, my_best_p, s_vals[my_best_p]); + } + + // --- 检查左边邻居 (i - step) --- + int left = i - step; + if (left >= 0) { + int left_p = idx_in[left]; // 邻居推荐的源点 + if (left_p != -1) { + float c = compute_cost(i, left_p, s_vals[left_p]); + if (c < min_cost) { + min_cost = c; + my_best_p = left_p; + } + } + } + + // --- 检查右边邻居 (i + step) --- + int right = i + step; + if (right < N) { + int right_p = idx_in[right]; // 邻居推荐的源点 + if (right_p != -1) { + float c = compute_cost(i, right_p, s_vals[right_p]); + if (c < min_cost) { + min_cost = c; + my_best_p = right_p; + } + } + } + + // 写入下一轮 Buffer + idx_out[i] = my_best_p; + } + + // 交换 Buffer 指针 + int* temp = idx_in; + idx_in = idx_out; + idx_out = temp; + + __syncthreads(); + } + + // --- 3. 
结果写回 --- + // 如果最后结果在 s_idx_next 里 (循环次数是奇数),需要拷回 s_idx_curr + // 或者直接让调用者知道结果在哪。 + // 为了简单,我们统一把结果放在 s_idx_curr 指向的内存里。 + // 注意:idx_in 现在指向的是包含最新结果的 buffer。 + + // 如果 idx_in 已经指向 s_idx_curr,那不用动。 + // 如果 idx_in 指向 s_idx_next,说明最新结果在 s_idx_next,我们需要把它拷贝回 s_idx_curr + // 或者是调整后续代码读取的指针。 + + // 这里采用简单拷贝回 s_idx_curr 的方式,确保后续逻辑一致 + if (idx_in != s_idx_curr) { + for (int i = tid; i < N; i += blockDim.x) { + s_idx_curr[i] = s_idx_next[i]; + } + __syncthreads(); + } +} + + +// ------------------------------------------------------------------ +// 内核 1: 初始 Pass (JFA Version) // ------------------------------------------------------------------ template __global__ void edt_kernel_first_pass( @@ -32,15 +144,13 @@ __global__ void edt_kernel_first_pass( int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; int64_t current_offset = batch_idx * strides[0]; - // 预计算基准坐标 (除了 process_dim 以外的维度坐标) int32_t base_coords[MAX_DIMS]; int64_t temp_idx = slice_idx_in_sample; const int sample_ndim = ndim - 1; - // 根据 slice_idx 反解坐标 for (int32_t d = sample_ndim - 1; d >= 0; --d) { if (d == process_dim_sample) { - base_coords[d] = 0; // 占位 + base_coords[d] = 0; continue; } int64_t size_of_dim = shape[d + 1]; @@ -56,82 +166,58 @@ __global__ void edt_kernel_first_pass( if (N == 0) return; + // Shared Memory Layout: + // f: float[N] (Values) + // idx1: int[N] (Buffer 1) + // idx2: int[N] (Buffer 2) extern __shared__ char s_buffer[]; float* f = (float*)s_buffer; - int* v = (int*)(f + N); - float* z = (float*)((char*)v + (N + 1) * sizeof(int)); + int* idx1 = (int*)(f + N); + int* idx2 = (int*)(idx1 + N); - // Phase 1: 加载数据 + // Phase 1: 并行加载数据 for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { float val = __ldg(&in_data[current_offset + i * stride]); f[i] = (val == 0.0f) ? 0.0f : INF_VAL; } __syncthreads(); - // Phase 2: 构建包络 - if (threadIdx.x == 0) { - int k = 0; - v[0] = 0; - z[0] = -INF_VAL; - z[1] = INF_VAL; - - for (int q = 1; q < N; q++) { - // 显式跳过背景点,避免 INF 污染计算 - if (f[q] >= (INF_VAL * 0.9f)) continue; - - float fq = f[q]; - int k_curr = k; - while (k_curr >= 0) { - int p = v[k_curr]; - - // --- 核心修复:数值稳定的交点公式 --- - // 先计算差值再相加,防止大数吞小数 - float diff_f = fq - f[p]; - float diff_sq = (float)q*(float)q - (float)p*(float)p; - float s = (diff_f + diff_sq) / (2.0f * (float)(q - p)); - - if (s > z[k_curr]) { - k_curr++; - v[k_curr] = q; - z[k_curr] = s; - z[k_curr + 1] = INF_VAL; - k = k_curr; - break; - } - k_curr--; - } - if (k_curr < 0) { - k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL; - } - } - } - __syncthreads(); + // Phase 2: 并行 JFA 计算 + compute_1d_jfa(N, f, idx1, idx2); + // 结果现在存储在 idx1 中 - // Phase 3: 计算距离 + // Phase 3: 并行写回 for (int q = threadIdx.x; q < N; q += blockDim.x) { - int k = 0; - float q_float = (float)q; - while (z[k + 1] < q_float) k++; - - int p = v[k]; - + int p = idx1[q]; + float dist_val; + int p_idx; + + if (p != -1) { + // JFA 得到的是最近源点的索引 p + // 距离 = (q-p)^2 + f[p] + // 注意:在 First Pass 中,f[p] 要么是 0 要么是 INF。如果 p != -1,f[p] 必为 0。 + // 但为了通用性,还是加上 f[p] + float dist_sq = sqr((float)q - (float)p) + f[p]; + dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; + p_idx = p; + } else { + dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); + p_idx = 0; + } + int64_t global_idx = current_offset + q * stride; - float dist_sq = sqr(q_float - (float)p) + f[p]; - - out_dist[global_idx] = IsFinal ? 
sqrtf(dist_sq) : dist_sq; + out_dist[global_idx] = dist_val; - // 写入索引 int32_t* out_idx_ptr = out_idx + global_idx * sample_ndim; #pragma unroll for (int d = 0; d < sample_ndim; ++d) { - // 只有当前处理的维度写入 p,其他维度写入基准坐标 - out_idx_ptr[d] = (d == process_dim_sample) ? p : base_coords[d]; + out_idx_ptr[d] = (d == process_dim_sample) ? p_idx : base_coords[d]; } } } // ------------------------------------------------------------------ -// 内核 2: 后续 Pass (Subsequent Pass) +// 内核 2: 后续 Pass (JFA Version) // ------------------------------------------------------------------ template __global__ void edt_kernel_subsequent_pass( @@ -169,61 +255,48 @@ __global__ void edt_kernel_subsequent_pass( if (N == 0) return; + // Shared Memory Layout 同上 extern __shared__ char s_buffer[]; float* f = (float*)s_buffer; - int* v = (int*)(f + N); - float* z = (float*)((char*)v + (N + 1) * sizeof(int)); + int* idx1 = (int*)(f + N); + int* idx2 = (int*)(idx1 + N); + // Phase 1: 加载 for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { f[i] = __ldg(&in_dist[current_offset + i * stride]); } __syncthreads(); - if (threadIdx.x == 0) { - int k = 0; - v[0] = 0; z[0] = -INF_VAL; z[1] = INF_VAL; - - for (int q = 1; q < N; q++) { - if (f[q] >= (INF_VAL * 0.9f)) continue; - - float fq = f[q]; - int k_curr = k; - while (k_curr >= 0) { - int p = v[k_curr]; - float diff_f = fq - f[p]; - float diff_sq = (float)q*(float)q - (float)p*(float)p; - float s = (diff_f + diff_sq) / (2.0f * (float)(q - p)); - if (s > z[k_curr]) { - k_curr++; v[k_curr] = q; z[k_curr] = s; z[k_curr + 1] = INF_VAL; - k = k_curr; break; - } - k_curr--; - } - if (k_curr < 0) { k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL; } - } - } - __syncthreads(); + // Phase 2: 并行 JFA 计算 + // 这里的 f[i] 是上一轮计算出的距离平方,作为权重 + compute_1d_jfa(N, f, idx1, idx2); + // Phase 3: 写回 for (int q = threadIdx.x; q < N; q += blockDim.x) { - int k = 0; - float q_float = (float)q; - while (z[k + 1] < q_float) k++; - - int p = v[k]; - + int p = idx1[q]; // 最近源点在当前行的索引 + float dist_val; + + if (p != -1) { + float dist_sq = sqr((float)q - (float)p) + f[p]; + dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; + } else { + dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); + p = 0; // fallback + } + int64_t q_global_offset = current_offset + q * stride; - int64_t p_global_offset = current_offset + p * stride; - - float dist_sq = sqr(q_float - (float)p) + f[p]; - out_dist[q_global_offset] = IsFinal ? 
sqrtf(dist_sq) : dist_sq; + out_dist[q_global_offset] = dist_val; - // 索引直接从 Global Memory 拷贝,无需 Shared Memory - const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; - int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; + // 索引处理 + if (p != -1) { + int64_t p_global_offset = current_offset + p * stride; + const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; + int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; - #pragma unroll - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; + #pragma unroll + for (int d = 0; d < sample_ndim; ++d) { + out_idx_ptr[d] = src_idx_ptr[d]; + } } } } @@ -237,8 +310,6 @@ std::tuple distance_transform_cuda(torch::Tensor i input = input.contiguous(); - // 自动处理 1D 输入:(L) -> (1, L) - // 自动处理 1D 批处理:(N, L) 保持不变 (视为 N 个 1D 样本) bool had_no_batch_dim = (input.dim() == 1); if (had_no_batch_dim) { input = input.unsqueeze(0); @@ -305,8 +376,15 @@ std::tuple distance_transform_cuda(torch::Tensor i int64_t total_slices = batch_size * num_slices_per_sample; int64_t slice_len = shape[d_sample + 1]; - int threads = std::min((int64_t)256, slice_len); - size_t smem = slice_len * (sizeof(float) + sizeof(int)) + (slice_len + 1) * sizeof(float); + int threads = std::min((int64_t)MAX_THREADS, slice_len); + + // JFA 需要的 Shared Memory: + // float f[N] + // int idx1[N] + // int idx2[N] + // 总共 slice_len * (4 + 4 + 4) = 12 * slice_len bytes + size_t smem = slice_len * sizeof(float) + + slice_len * sizeof(int) * 2; if (is_first_pass) { const float* in_ptr = input.data_ptr(); From 9291a49cf655aa938c20fe031c4b6305b9997521 Mon Sep 17 00:00:00 2001 From: Yuhandeng Date: Wed, 3 Dec 2025 04:07:18 +0800 Subject: [PATCH 14/39] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E5=86=85=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torchmorph/csrc/distance_transform_kernel.cu | 588 ++++++++++--------- 1 file changed, 298 insertions(+), 290 deletions(-) diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index 64c69fb..24ba467 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -1,75 +1,64 @@ #include #include -#include -#include #include -#include +#include #include #include -#define MAX_DIMS 10 -#define INF_VAL 1e8f -// 保证 blockDim 足以覆盖大多数常见维度大小,或者配合 Loop 处理 -#define MAX_THREADS 1024 +#define INF_VAL 1e8f +#define MAX_THREADS 1024 +#define SMEM_LIMIT_ELEMENTS 4096 // 48KB / 12 bytes (float+int+int) ~= 4096 __device__ __forceinline__ float sqr(float x) { return x * x; } -// 计算从像素 q 到源点 p 的距离代价 (考虑了 p 点本身的数值权重 val_p) +// 计算从像素 q 到源点 p 的距离代价 +// val_p 是源点 p 在上一轮计算后的距离平方值 (weight) __device__ __forceinline__ float compute_cost(int q, int p, float val_p) { - if (p < 0) return INF_VAL; // 无效点 - // dist = (q - p)^2 + f[p] + if (p < 0) return INF_VAL; return sqr((float)q - (float)p) + val_p; } // ------------------------------------------------------------------ -// 核心逻辑: 1D Jump Flooding (JFA) / Doubling Algorithm -// 全并行求解最近点索引,替代串行的抛物线构建 +// JFA 核心逻辑 (Device Function) // ------------------------------------------------------------------ -__device__ void compute_1d_jfa( +// 无论数据是在 Shared Memory 还是 Global Memory,逻辑是一样的 +__device__ void run_jfa_core( int N, - float* __restrict__ s_vals, // 输入数值 (dist^2) - int* __restrict__ s_idx_curr, // ping-pong buffer 1 - int* __restrict__ s_idx_next // ping-pong buffer 2 + int tid, + const float* __restrict__ 
vals, // 输入权重 (只读) + int* __restrict__ idx_curr, // Ping-Pong Buffer A + int* __restrict__ idx_next // Ping-Pong Buffer B ) { - int tid = threadIdx.x; - - // --- 1. 初始化 --- - // 每个线程负责一个或多个像素的初始化 + // 1. 初始化 for (int i = tid; i < N; i += blockDim.x) { - // 如果当前位置的值很大,说明它是背景,没有初始源点 (-1) - // 否则源点就是它自己 (i) - if (s_vals[i] >= INF_VAL * 0.9f) { - s_idx_curr[i] = -1; + // 如果输入值很大,说明是背景,没有初始源点 + if (vals[i] >= INF_VAL * 0.9f) { + idx_curr[i] = -1; } else { - s_idx_curr[i] = i; + idx_curr[i] = i; } } __syncthreads(); - // --- 2. 迭代传播 (Step = 1, 2, 4, 8...) --- - // 类似于双调排序或倍增法 - int* idx_in = s_idx_curr; - int* idx_out = s_idx_next; + // 2. 迭代传播 (Step = 1, 2, 4, ... < N) + int* idx_in = idx_curr; + int* idx_out = idx_next; - // 只要步长小于 N,就需要传播 - // 对于 N=1024, 只需要 10 次迭代,每次所有线程全并行 for (int step = 1; step < N; step *= 2) { - for (int i = tid; i < N; i += blockDim.x) { int my_best_p = idx_in[i]; float min_cost = INF_VAL; - // 获取当前最优点的代价 if (my_best_p != -1) { - min_cost = compute_cost(i, my_best_p, s_vals[my_best_p]); + min_cost = compute_cost(i, my_best_p, vals[my_best_p]); } - // --- 检查左边邻居 (i - step) --- + // Check Left int left = i - step; if (left >= 0) { - int left_p = idx_in[left]; // 邻居推荐的源点 + int left_p = idx_in[left]; if (left_p != -1) { - float c = compute_cost(i, left_p, s_vals[left_p]); + float c = compute_cost(i, left_p, vals[left_p]); if (c < min_cost) { min_cost = c; my_best_p = left_p; @@ -77,230 +66,206 @@ __device__ void compute_1d_jfa( } } - // --- 检查右边邻居 (i + step) --- + // Check Right int right = i + step; if (right < N) { - int right_p = idx_in[right]; // 邻居推荐的源点 + int right_p = idx_in[right]; if (right_p != -1) { - float c = compute_cost(i, right_p, s_vals[right_p]); + float c = compute_cost(i, right_p, vals[right_p]); if (c < min_cost) { min_cost = c; my_best_p = right_p; } } } - - // 写入下一轮 Buffer idx_out[i] = my_best_p; } - // 交换 Buffer 指针 + // Swap Pointers int* temp = idx_in; idx_in = idx_out; idx_out = temp; - __syncthreads(); } - // --- 3. 结果写回 --- - // 如果最后结果在 s_idx_next 里 (循环次数是奇数),需要拷回 s_idx_curr - // 或者直接让调用者知道结果在哪。 - // 为了简单,我们统一把结果放在 s_idx_curr 指向的内存里。 - // 注意:idx_in 现在指向的是包含最新结果的 buffer。 - - // 如果 idx_in 已经指向 s_idx_curr,那不用动。 - // 如果 idx_in 指向 s_idx_next,说明最新结果在 s_idx_next,我们需要把它拷贝回 s_idx_curr - // 或者是调整后续代码读取的指针。 - - // 这里采用简单拷贝回 s_idx_curr 的方式,确保后续逻辑一致 - if (idx_in != s_idx_curr) { + // 3. 确保最终结果在 idx_curr (如果循环结束时在 next,则拷回) + if (idx_in != idx_curr) { for (int i = tid; i < N; i += blockDim.x) { - s_idx_curr[i] = s_idx_next[i]; + idx_curr[i] = idx_next[i]; } __syncthreads(); } } - // ------------------------------------------------------------------ -// 内核 1: 初始 Pass (JFA Version) +// Kernel 1: Shared Memory JFA (Fast Path) +// 适用于 N <= 4096 // ------------------------------------------------------------------ -template -__global__ void edt_kernel_first_pass( - const float* __restrict__ in_data, - float* __restrict__ out_dist, - int32_t* __restrict__ out_idx, - const int64_t* __restrict__ shape, - const int64_t* __restrict__ strides, - int32_t ndim, - int32_t process_dim_sample, - int64_t total_slices, - int64_t num_slices_per_sample +template +__global__ void edt_kernel_shared( + const float* __restrict__ in_data, // 当前维度的输入 (dist^2) + const int32_t* __restrict__ in_indices, // 上一轮的索引图 (N_slices, L, NDim) + float* __restrict__ out_dist, // 输出距离 + int32_t* __restrict__ out_indices, // 输出索引图 + int64_t L, // 当前维度的长度 (Length) + int64_t total_elements // Batch * ... 
* L ) { - int64_t slice_idx = blockIdx.x; - if (slice_idx >= total_slices) return; - - int64_t batch_idx = slice_idx / num_slices_per_sample; - int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t current_offset = batch_idx * strides[0]; - - int32_t base_coords[MAX_DIMS]; - int64_t temp_idx = slice_idx_in_sample; - const int sample_ndim = ndim - 1; - - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) { - base_coords[d] = 0; - continue; - } - int64_t size_of_dim = shape[d + 1]; - int32_t coord = (int32_t)(temp_idx % size_of_dim); - base_coords[d] = coord; - current_offset += coord * strides[d + 1]; - temp_idx /= size_of_dim; - } - - const int64_t process_dim_actual = process_dim_sample + 1; - const int64_t N = shape[process_dim_actual]; - const int64_t stride = strides[process_dim_actual]; + // 这里的 total_elements 是展平后的总像素数 + // 由于我们做了 transpose,数据布局是 [Batch_and_other_dims, L] + // 每个 Block 处理一行 (长度 L) + + int64_t row_idx = blockIdx.x; + int64_t offset = row_idx * L; - if (N == 0) return; + if (offset >= total_elements) return; - // Shared Memory Layout: - // f: float[N] (Values) - // idx1: int[N] (Buffer 1) - // idx2: int[N] (Buffer 2) + // Shared Memory 布局: float vals[L], int idx1[L], int idx2[L] extern __shared__ char s_buffer[]; - float* f = (float*)s_buffer; - int* idx1 = (int*)(f + N); - int* idx2 = (int*)(idx1 + N); - - // Phase 1: 并行加载数据 - for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - float val = __ldg(&in_data[current_offset + i * stride]); - f[i] = (val == 0.0f) ? 0.0f : INF_VAL; + float* s_vals = (float*)s_buffer; + int* s_idx1 = (int*)(s_vals + L); + int* s_idx2 = (int*)(s_idx1 + L); + + // 1. 加载数据到 Shared Memory + for (int i = threadIdx.x; i < L; i += blockDim.x) { + float val = __ldg(&in_data[offset + i]); + // 如果是初始 Pass (无输入索引),val 为 0 或 INF + // 如果是后续 Pass,val 为上一步的 dist^2 + s_vals[i] = val; } __syncthreads(); - // Phase 2: 并行 JFA 计算 - compute_1d_jfa(N, f, idx1, idx2); - // 结果现在存储在 idx1 中 - - // Phase 3: 并行写回 - for (int q = threadIdx.x; q < N; q += blockDim.x) { - int p = idx1[q]; + // 2. 运行 JFA + run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); + + // 3. 写回结果 + for (int q = threadIdx.x; q < L; q += blockDim.x) { + int p = s_idx1[q]; // 最近点在当前行内的局部索引 (0..L-1) float dist_val; - int p_idx; if (p != -1) { - // JFA 得到的是最近源点的索引 p - // 距离 = (q-p)^2 + f[p] - // 注意:在 First Pass 中,f[p] 要么是 0 要么是 INF。如果 p != -1,f[p] 必为 0。 - // 但为了通用性,还是加上 f[p] - float dist_sq = sqr((float)q - (float)p) + f[p]; + // 计算新距离: (q-p)^2 + val[p] + float dist_sq = sqr((float)q - (float)p) + s_vals[p]; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; - p_idx = p; } else { dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); - p_idx = 0; + p = 0; // fallback } - int64_t global_idx = current_offset + q * stride; - out_dist[global_idx] = dist_val; + out_dist[offset + q] = dist_val; - int32_t* out_idx_ptr = out_idx + global_idx * sample_ndim; - #pragma unroll - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = (d == process_dim_sample) ? p_idx : base_coords[d]; + // 4. 
索引传播 + // 我们需要从 in_indices 查找完整的高维索引 + // in_indices 形状: [Batch..., L, NDim] + // 这里的 offset 对应 [Batch..., 0] + // p 是当前维度的偏移 + if (p != -1) { + int64_t src_offset = (offset + p) * NDim; + int64_t dst_offset = (offset + q) * NDim; + + // 手动展开拷贝,或者循环 + for (int d = 0; d < NDim; ++d) { + out_indices[dst_offset + d] = in_indices[src_offset + d]; + } + } else { + // 保持原样或填0 (通常保持原样即可,或者为了安全填0) + int64_t dst_offset = (offset + q) * NDim; + for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; } } } // ------------------------------------------------------------------ -// 内核 2: 后续 Pass (JFA Version) +// Kernel 2: Global Memory JFA (Fallback Path) +// 适用于 N > 4096,使用 Global Memory 作为 Ping-Pong Buffer // ------------------------------------------------------------------ -template -__global__ void edt_kernel_subsequent_pass( - const float* __restrict__ in_dist, - const int32_t* __restrict__ in_idx, +template +__global__ void edt_kernel_global( + const float* __restrict__ in_data, + const int32_t* __restrict__ in_indices, float* __restrict__ out_dist, - int32_t* __restrict__ out_idx, - const int64_t* __restrict__ shape, - const int64_t* __restrict__ strides, - int32_t ndim, - int32_t process_dim_sample, - int64_t total_slices, - int64_t num_slices_per_sample + int32_t* __restrict__ out_indices, + int* __restrict__ global_buffer_1, // 临时 buffer A [TotalElements] + int* __restrict__ global_buffer_2, // 临时 buffer B [TotalElements] + int64_t L, + int64_t total_elements ) { - int64_t slice_idx = blockIdx.x; - if (slice_idx >= total_slices) return; - - int64_t batch_idx = slice_idx / num_slices_per_sample; - int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t current_offset = batch_idx * strides[0]; + int64_t row_idx = blockIdx.x; + int64_t offset = row_idx * L; - int64_t temp_idx = slice_idx_in_sample; - const int sample_ndim = ndim - 1; - - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - current_offset += (temp_idx % size_of_dim) * strides[d + 1]; - temp_idx /= size_of_dim; - } - - const int64_t process_dim_actual = process_dim_sample + 1; - const int64_t N = shape[process_dim_actual]; - const int64_t stride = strides[process_dim_actual]; - - if (N == 0) return; - - // Shared Memory Layout 同上 - extern __shared__ char s_buffer[]; - float* f = (float*)s_buffer; - int* idx1 = (int*)(f + N); - int* idx2 = (int*)(idx1 + N); + if (offset >= total_elements) return; - // Phase 1: 加载 - for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - f[i] = __ldg(&in_dist[current_offset + i * stride]); - } - __syncthreads(); - - // Phase 2: 并行 JFA 计算 - // 这里的 f[i] 是上一轮计算出的距离平方,作为权重 - compute_1d_jfa(N, f, idx1, idx2); + // 指向当前行在 Global Memory 中的位置 + // 注意:in_data 是只读的,我们需要把它当做 weight + // JFA 需要两个 int buffer 来存 index + int* g_idx1 = global_buffer_1 + offset; + int* g_idx2 = global_buffer_2 + offset; + + // 直接在 Global Memory 上运行 JFA + // 注意:这里 vals 指针直接指向 in_data (Global),读取稍慢但无需拷贝 + run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); - // Phase 3: 写回 - for (int q = threadIdx.x; q < N; q += blockDim.x) { - int p = idx1[q]; // 最近源点在当前行的索引 + // 写回逻辑同上 + for (int q = threadIdx.x; q < L; q += blockDim.x) { + int p = g_idx1[q]; float dist_val; if (p != -1) { - float dist_sq = sqr((float)q - (float)p) + f[p]; + float val_p = in_data[offset + p]; + float dist_sq = sqr((float)q - (float)p) + val_p; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; } else { dist_val = IsFinal ? 
INF_VAL : sqr(INF_VAL); - p = 0; // fallback + p = 0; } - int64_t q_global_offset = current_offset + q * stride; - out_dist[q_global_offset] = dist_val; + out_dist[offset + q] = dist_val; - // 索引处理 if (p != -1) { - int64_t p_global_offset = current_offset + p * stride; - const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; - int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; - - #pragma unroll - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; + int64_t src_offset = (offset + p) * NDim; + int64_t dst_offset = (offset + q) * NDim; + for (int d = 0; d < NDim; ++d) { + out_indices[dst_offset + d] = in_indices[src_offset + d]; } + } else { + int64_t dst_offset = (offset + q) * NDim; + for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; } } } + +// ------------------------------------------------------------------ +// 辅助:初始化索引张量 +// ------------------------------------------------------------------ +// 将 index tensor 初始化为 grid grid coordinates +// shape: (..., D), 最后一个维度存坐标 +__global__ void init_indices_kernel(int32_t* indices, int64_t total_elements, int NDim, + const int64_t* shape, const int64_t* strides) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) return; + + // 反解坐标 + int64_t temp = idx; + int32_t coords[10]; // max dims + + // strides 是针对 elements 展开的,但这里 indices 是 (Total, NDim) + // 我们可以简单地根据 shape 反解 + // 注意:这里的 total_elements 是像素数,不是 indices 数组的大小 + + // 假设 shape 是 [D0, D1, D2] + // idx 对应 flat index + + for (int d = NDim - 1; d >= 0; --d) { + coords[d] = temp % shape[d]; + temp /= shape[d]; + } + + // 写入 + for (int d = 0; d < NDim; ++d) { + indices[idx * NDim + d] = coords[d]; + } +} + // ------------------------------------------------------------------ // Host 函数 // ------------------------------------------------------------------ @@ -309,119 +274,162 @@ std::tuple distance_transform_cuda(torch::Tensor i TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32."); input = input.contiguous(); + bool had_no_batch_dim = (input.dim() == 1); + if (had_no_batch_dim) input = input.unsqueeze(0); - bool had_no_batch_dim = (input.dim() == 1); - if (had_no_batch_dim) { - input = input.unsqueeze(0); - } - - const auto ndim = input.dim(); - const auto sample_ndim = ndim - 1; - const auto batch_size = input.size(0); + const int ndim = input.dim(); // Include batch + const int sample_ndim = ndim - 1; auto shape = input.sizes().vec(); - auto strides_vec = input.strides().vec(); - - if (input.numel() == 0) { + int64_t num_pixels = input.numel(); + + if (num_pixels == 0) { auto index_shape = shape; index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); return std::make_tuple(torch::empty_like(input), - torch::empty(index_shape, input.options().dtype(torch::kInt32))); + torch::empty(index_shape, input.options().dtype(torch::kInt32))); } - auto distance = torch::empty_like(input); + // 1. 
初始化输出 Tensor + // current_dist 在迭代过程中存储 dist^2,最后开方 + // 初始状态:Input 里的 0 还是 0,其他非 0 (背景) 设为 INF + auto current_dist = torch::where(input == 0, + torch::tensor(0.0f, input.options()), + torch::tensor(INF_VAL, input.options())); + + // 初始化索引 Map (Batch, ..., NDim) auto index_shape = shape; - index_shape.push_back(sample_ndim); - auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - - auto buffer_dist = torch::empty_like(distance); - auto buffer_idx = torch::empty_like(index); - - auto shape_tensor = torch::tensor(shape, torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - auto strides_tensor = torch::tensor(strides_vec, torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - - std::vector> dim_order_pairs; - for (int32_t d = 0; d < sample_ndim; ++d) { - dim_order_pairs.push_back({strides_vec[d + 1], d}); + index_shape.push_back(sample_ndim); + auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); + + // 启动 Kernel 初始化索引 + // 为了反解坐标,我们需要把 shape 传进去 + { + // 排除 batch 维度的 shape 用于坐标计算? + // 需求是:返回的索引是 (batch_idx, z, y, x) 还是只是 (z, y, x)? + // 通常 EDT 返回的是 sample 内的坐标。所以我们忽略 batch 维度。 + std::vector sample_shape_vec(shape.begin() + 1, shape.end()); + auto sample_shape_tensor = torch::tensor(sample_shape_vec, torch::kInt64).to(input.device()); + // 这里的 strides 不需要,直接由 shape 反解 + + int threads = 256; + int blocks = (num_pixels + threads - 1) / threads; + + // 我们需要传递 sample_ndim + init_indices_kernel<<>>( + current_idx.data_ptr(), + num_pixels, + sample_ndim, + sample_shape_tensor.data_ptr(), + nullptr // strides not needed for simple unravel + ); } - std::sort(dim_order_pairs.rbegin(), dim_order_pairs.rend()); + + // 用于 Global Memory Fallback 的临时 buffer + torch::Tensor global_buf1, global_buf2; + + // 2. 
逐维处理 (Separable Phases) + // 从最后一个维度倒着处理,或者顺序处理都可以。 + // 为了 Host Transpose 方便,我们遍历 sample 的每一个维度 (1 到 ndim-1) + for (int d = 1; d < ndim; ++d) { + bool is_final_pass = (d == ndim - 1); + + // ----------------------------------------------------------- + // Step A: Permute & Contiguous + // 将当前处理维度 d 移到最后: (0, 1, ..., d, ..., N-1) -> (0, 1, ..., N-1, d) + // 这样最后内存布局就是 [..., L],stride=1 + // ----------------------------------------------------------- + + // 这种 swap 策略比较简单: transpose(d, -1) + // 注意:index tensor 也要变换,但 index tensor 最后一维是 coord_dim,不能乱动。 + // Index tensor 形状是 [..., sample_ndim]。 + // 我们需要变换的是前面的空间维度 [...]。 + + auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); + auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); + + // 此时 dist_in shape: [..., L] + // idx_in shape: [..., L, sample_ndim] + + int64_t L = dist_in.size(-1); // 当前维度的长度 + int64_t total_slices = dist_in.numel() / L; // 有多少行 + + auto dist_out = torch::empty_like(dist_in); + auto idx_out = torch::empty_like(idx_in); - for (int pass = 0; pass < sample_ndim; ++pass) { - int32_t d_sample = dim_order_pairs[pass].second; - bool is_first_pass = (pass == 0); - bool is_final_pass = (pass == sample_ndim - 1); + // ----------------------------------------------------------- + // Step B: Kernel Dispatch + // ----------------------------------------------------------- + int threads = std::min((int64_t)MAX_THREADS, L); + + // 检查 Shared Memory 需求 + // Need: float(4) + int(4) + int(4) = 12 bytes per pixel + if (L <= SMEM_LIMIT_ELEMENTS) { + size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); + + // 模板参数 NDim 需要是编译期常量。 + // 动态分发 sample_ndim (1D, 2D, 3D usually) + // 使用 switch case 覆盖常见维度 (1, 2, 3) + #define DISPATCH_SHARED(IS_FINAL) \ + switch(sample_ndim) { \ + case 1: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel()); break; \ + case 2: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel()); break; \ + case 3: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel()); break; \ + default: /* fallback for >3D */ break; \ + } - torch::Tensor *in_d, *in_i, *out_d, *out_i; + if (is_final_pass) { DISPATCH_SHARED(true); } + else { DISPATCH_SHARED(false); } - if (is_first_pass) { - in_d = nullptr; in_i = nullptr; - out_d = is_final_pass ? &distance : &buffer_dist; - out_i = is_final_pass ? 
&index : &buffer_idx; } else { - if (pass % 2 != 0) { - in_d = &buffer_dist; in_i = &buffer_idx; - out_d = &distance; out_i = &index; - } else { - in_d = &distance; in_i = &index; - out_d = &buffer_dist; out_i = &buffer_idx; - } - if (is_final_pass) { - out_d = &distance; out_i = &index; + // Fallback: Global Memory + // 需要分配 buffer: [total_slices * L] = [numel] + if (global_buf1.numel() < dist_in.numel()) { + global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); + global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); } - } - int64_t num_slices_per_sample = 1; - for(int i = 0; i < sample_ndim; ++i) { - if (i != d_sample) num_slices_per_sample *= shape[i + 1]; - } - int64_t total_slices = batch_size * num_slices_per_sample; - int64_t slice_len = shape[d_sample + 1]; - - int threads = std::min((int64_t)MAX_THREADS, slice_len); - - // JFA 需要的 Shared Memory: - // float f[N] - // int idx1[N] - // int idx2[N] - // 总共 slice_len * (4 + 4 + 4) = 12 * slice_len bytes - size_t smem = slice_len * sizeof(float) + - slice_len * sizeof(int) * 2; - - if (is_first_pass) { - const float* in_ptr = input.data_ptr(); - if (is_final_pass) { - edt_kernel_first_pass<<>>( - in_ptr, out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } else { - edt_kernel_first_pass<<>>( - in_ptr, out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } - } else { - if (is_final_pass) { - edt_kernel_subsequent_pass<<>>( - in_d->data_ptr(), in_i->data_ptr(), - out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } else { - edt_kernel_subsequent_pass<<>>( - in_d->data_ptr(), in_i->data_ptr(), - out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } + #define DISPATCH_GLOBAL(IS_FINAL) \ + switch(sample_ndim) { \ + case 1: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel()); break; \ + case 2: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel()); break; \ + case 3: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel()); break; \ + default: break; \ + } + + if (is_final_pass) { DISPATCH_GLOBAL(true); } + else { DISPATCH_GLOBAL(false); } } + + // ----------------------------------------------------------- + // Step C: Transpose Back + // ----------------------------------------------------------- + current_dist = dist_out.transpose(d, ndim - 1); // View, non-contiguous is fine here as next step makes it contiguous + current_idx = idx_out.transpose(d, ndim - 1); } - - if (had_no_batch_dim) { - return std::make_tuple(distance.squeeze(0), index.squeeze(0)); + + if (had_no_batch_dim) { + return std::make_tuple(current_dist.squeeze(0), current_idx.squeeze(0)); } - return std::make_tuple(distance, index); + return 
std::make_tuple(current_dist, current_idx);
 }
\ No newline at end of file

From 698ce2493053f21f7e7acc7b55f5ab38736599f6 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Sat, 6 Dec 2025 21:28:04 +0800
Subject: [PATCH 15/39] Handle computations in more than three dimensions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_distance_transform.py              | 146 ++++----
 torchmorph/csrc/distance_transform_kernel.cu | 335 ++++++++++---------
 2 files changed, 269 insertions(+), 212 deletions(-)

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 476ffca..6e6265d 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -1,21 +1,20 @@
 import torch
 import pytest
-from scipy.ndimage import distance_transform_edt as scipy_edt
 import numpy as np
-import torchmorph as tm
+from scipy.ndimage import distance_transform_edt as scipy_edt
+import torchmorph as tm
 
-# Helper: call SciPy and massage the output format
+# ======================================================================
+# Helpers
+# ======================================================================
 def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
-
     dist_results, indices_results = [], []
-    input_is_1d_batch = (batch_numpy.ndim == 2)
-    input_is_single_sample_no_batch = (batch_numpy.ndim == 1)
+    # Guarantee that batch_numpy is at least (Batch, ...);
+    # an (H, W) input has already been promoted to (1, H, W) by the caller
+    if batch_numpy.ndim == 1:
+        batch_numpy = batch_numpy[np.newaxis, ...]
 
-    if input_is_single_sample_no_batch:
-        batch_numpy = batch_numpy[np.newaxis, ...]  # (L) -> (1, L)
-
-
-    dist_results, indices_results = [], []
     for sample in batch_numpy:
         dist, indices = scipy_edt(sample, return_indices=True, return_distances=True)
         dist_results.append(dist)
@@ -23,81 +22,106 @@ def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, n
     output_dist = np.stack(dist_results, axis=0)
     output_indices = np.stack(indices_results, axis=0)
+    output_indices = np.moveaxis(output_indices, 1, -1)
 
-    # indices shape fix: (N, ndim_sample, ...) -> (N, ..., ndim_sample)
-    # for 1D: (N, 1, L) -> (N, L, 1)
-    output_indices = np.moveaxis(output_indices, 1, -1)
-
-    if input_is_single_sample_no_batch:
-        output_dist = output_dist.squeeze(0)
-        output_indices = output_indices.squeeze(0)
-
     return output_dist, output_indices
 
-# Test case definitions
-case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],[[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32)
+# ======================================================================
+# Test data
+# ======================================================================
+case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32)
+
+case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],
+                          [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32)
+
+# Defined as (4, 4): the intent is a single 2D image
+case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32)
+case_explicit_batch_one = case_single_2d[np.newaxis, ...]
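+# As a sanity check on tiny inputs, the SciPy reference itself can be
+# cross-checked by brute force: the EDT value of a pixel is simply the L2
+# distance to the nearest zero. A minimal sketch (the helper name
+# brute_force_edt is illustrative and not part of this test suite; it reuses
+# the np import above):
+def brute_force_edt(sample: np.ndarray) -> np.ndarray:
+    zeros = np.argwhere(sample == 0)       # coordinates of all background pixels
+    out = np.zeros(sample.shape, dtype=np.float32)
+    for coord in np.ndindex(*sample.shape):
+        if zeros.size:                     # with no background, stay 0 (SciPy's convention)
+            out[coord] = np.sqrt(((zeros - np.array(coord)) ** 2).sum(axis=1).min())
+    return out
+
+# e.g. np.allclose(brute_force_edt(case_single_2d), scipy_edt(case_single_2d)) should hold.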
+ _case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s1[1, 1, 1] = 0.0; _case_3d_s1[2, 3, 4] = 0.0 _case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s2[0, 0, 0] = 0.0 case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0) -case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32) -case_explicit_batch_one = case_single_2d[np.newaxis, ...] + case_dim_one = np.ones((2, 5, 1), dtype=np.float32); case_dim_one[0, 2, 0] = 0.0; case_dim_one[1, 4, 0] = 0.0 -case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32) +# 4D Case +_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s1[0, 0, 0, 0] = 0.0 +_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s2[1, 1, 1, 1] = 0.0 +case_batch_4d_spatial = np.stack([_case_4d_s1, _case_4d_s2], axis=0) + +# 5D Case +case_batch_5d_spatial = np.ones((1, 2, 2, 2, 2, 2), dtype=np.float32) +case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0; case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 + +# ====================================================================== +# 测试逻辑 +# ====================================================================== @pytest.mark.parametrize( - "input_numpy", + "input_numpy, has_batch_dim", [ - pytest.param(case_batch_2d, id="批处理2D图像"), - pytest.param(case_batch_3d, id="批处理3D图像"), - pytest.param(case_single_2d, id="单张2D图像(隐式批处理)"), - pytest.param(case_explicit_batch_one, id="单张2D图像(显式批处理)"), - pytest.param(case_dim_one, id="含幺元维度的批处理"), - pytest.param(case_batch_1d, id="批处理1D数据"), + pytest.param(case_batch_1d, True, id="1D_Batch"), + pytest.param(case_batch_2d, True, id="2D_Batch"), + pytest.param(case_single_2d, False, id="2D_Single_NoBatch"), + pytest.param(case_explicit_batch_one, True, id="2D_Single_ExplicitBatch"), + pytest.param(case_batch_3d, True, id="3D_Batch"), + pytest.param(case_dim_one, True, id="2D_UnitDim_Batch"), + pytest.param(case_batch_4d_spatial, True, id="4D_Spatial_Batch"), + pytest.param(case_batch_5d_spatial, True, id="5D_Spatial_Batch"), ], ) -def test_distance_transform_and_indices(input_numpy: np.ndarray, request: pytest.FixtureRequest): +def test_distance_transform_and_indices(input_numpy: np.ndarray, has_batch_dim: bool, request: pytest.FixtureRequest): if not torch.cuda.is_available(): pytest.skip("CUDA not available") + # 1. 准备 Numpy 数据 x_numpy_contiguous = np.ascontiguousarray(input_numpy) + + # 2. 准备 SciPy 输入 + # 如果意图是单样本 (has_batch_dim=False),我们手动增加 Batch 维, + # 这样 scipy 辅助函数就会把它当做一张图来处理,而不是 N 张 1D 图 + if not has_batch_dim: + scipy_input = x_numpy_contiguous[np.newaxis, ...] + else: + scipy_input = x_numpy_contiguous + + # 3. 准备 CUDA 输入 + # 关键修复: + # 如果 has_batch_dim=False,说明这是单张 (H, W),我们要测 2D EDT。 + # C++ API 默认第一维是 Batch,所以我们必须 unsqueeze(0) 变成 (1, H, W)。 + # 否则 C++ 会把它当做 (Batch=H, Len=W) 做 1D EDT。 x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + if not has_batch_dim: + x_cuda = x_cuda.unsqueeze(0) - print(f"\n\n--- 正在运行测试: {request.node.callspec.id} ---") - print(f"输入张量形状: {x_cuda.shape}") + print(f"\n\n--- 运行测试: {request.node.callspec.id} ---") + print(f"CUDA 输入形状: {x_cuda.shape}") - # 调用您的 Python 包装函数 + # 4. 运行 CUDA EDT dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone()) - print(f"CUDA 距离输出形状: {dist_cuda.shape}") - print(f"CUDA 坐标输出形状: {idx_cuda.shape}") - - # 调用 SciPy 作为参考基准 - dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(x_numpy_contiguous) + # 5. 
运行 SciPy (Ground Truth) + dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(scipy_input) dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() - - print(f"SciPy 距离输出形状: {dist_ref.shape}") - # 断言验证 - print("\n--- 正在验证距离... ---") - assert dist_cuda.shape == dist_ref.shape + # 6. 验证距离 + # 此时 dist_cuda 是 (1, H, W),dist_ref 也是 (1, H, W) + # 如果原意是 NoBatch,我们可以把 Batch 维 squeeze 掉再比,或者直接比 + print(f"CUDA Out Shape: {dist_cuda.shape}, Ref Shape: {dist_ref.shape}") + assert dist_cuda.shape == dist_ref.shape, f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3) - print("距离断言通过 (形状和数值接近)。") + print(">> 距离验证通过。") - print("\n--- 正在验证坐标... ---") - - # 鲁棒的坐标验证逻辑 - had_no_batch_dim = (x_numpy_contiguous.ndim <= idx_cuda.shape[-1]) - spatial_shape = x_cuda.shape if had_no_batch_dim else x_cuda.shape[1:] + # 7. 验证索引 + # idx_cuda: (1, H, W, 2) + # 构造 Grid + spatial_shape = x_cuda.shape[1:] # (H, W) coords = [torch.arange(s, device='cuda') for s in spatial_shape] - grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1) - - if not had_no_batch_dim: - grid = grid.unsqueeze(0) - + grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1) # (H, W, 2) + grid = grid.unsqueeze(0) # (1, H, W, 2) + diff = grid.float() - idx_cuda.float() - dist_sq_from_indices = torch.sum(diff * diff, dim=-1) + dist_sq_calculated = torch.sum(diff * diff, dim=-1) + dist_sq_output = dist_cuda * dist_cuda - torch.testing.assert_close(dist_sq_from_indices, dist_cuda * dist_cuda, atol=1e-3, rtol=1e-3) - print("坐标正确性断言通过 (计算出的距离与返回距离匹配)。") - - print("--- 测试通过 ---") \ No newline at end of file + torch.testing.assert_close(dist_sq_calculated, dist_sq_output, atol=1e-3, rtol=1e-3) + print(">> 索引验证通过。") \ No newline at end of file diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index 24ba467..3bda330 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -5,23 +5,32 @@ #include #include +// ------------------------------------------------------------------ +// 配置常量 +// ------------------------------------------------------------------ #define INF_VAL 1e8f #define MAX_THREADS 1024 -#define SMEM_LIMIT_ELEMENTS 4096 // 48KB / 12 bytes (float+int+int) ~= 4096 +// Shared Memory 限制: 48KB 一般安全。 +// 每个像素需要: float(val) + int(idx1) + int(idx2) = 12 bytes +// 4096 * 12 = 48KB. +#define SMEM_LIMIT_ELEMENTS 4096 + +// ------------------------------------------------------------------ +// Device Helper Functions +// ------------------------------------------------------------------ __device__ __forceinline__ float sqr(float x) { return x * x; } -// 计算从像素 q 到源点 p 的距离代价 -// val_p 是源点 p 在上一轮计算后的距离平方值 (weight) +// 计算 JFA 代价: (q - p)^2 + weight[p] __device__ __forceinline__ float compute_cost(int q, int p, float val_p) { if (p < 0) return INF_VAL; return sqr((float)q - (float)p) + val_p; } // ------------------------------------------------------------------ -// JFA 核心逻辑 (Device Function) +// JFA Core Logic (Device Only) // ------------------------------------------------------------------ -// 无论数据是在 Shared Memory 还是 Global Memory,逻辑是一样的 +// 核心 JFA 逻辑,与数据位置无关 (Shared 或 Global 均通用) __device__ void run_jfa_core( int N, int tid, @@ -29,13 +38,12 @@ __device__ void run_jfa_core( int* __restrict__ idx_curr, // Ping-Pong Buffer A int* __restrict__ idx_next // Ping-Pong Buffer B ) { - // 1. 初始化 + // 1. 
初始化: 根据 vals 决定是否是有效源点 for (int i = tid; i < N; i += blockDim.x) { - // 如果输入值很大,说明是背景,没有初始源点 if (vals[i] >= INF_VAL * 0.9f) { - idx_curr[i] = -1; + idx_curr[i] = -1; // 背景 } else { - idx_curr[i] = i; + idx_curr[i] = i; // 物体/源点,初始索引指向自己 } } __syncthreads(); @@ -49,11 +57,12 @@ __device__ void run_jfa_core( int my_best_p = idx_in[i]; float min_cost = INF_VAL; + // 检查自己当前的最优解 if (my_best_p != -1) { min_cost = compute_cost(i, my_best_p, vals[my_best_p]); } - // Check Left + // Check Left Neighbor (-step) int left = i - step; if (left >= 0) { int left_p = idx_in[left]; @@ -66,7 +75,7 @@ __device__ void run_jfa_core( } } - // Check Right + // Check Right Neighbor (+step) int right = i + step; if (right < N) { int right_p = idx_in[right]; @@ -99,42 +108,41 @@ __device__ void run_jfa_core( // ------------------------------------------------------------------ // Kernel 1: Shared Memory JFA (Fast Path) -// 适用于 N <= 4096 // ------------------------------------------------------------------ +// 模板参数 NDim: 如果 > 0,编译器会展开循环优化。 +// 参数 runtime_ndim: 如果 NDim == 0 (Default case),使用该参数作为维度。 template __global__ void edt_kernel_shared( - const float* __restrict__ in_data, // 当前维度的输入 (dist^2) - const int32_t* __restrict__ in_indices, // 上一轮的索引图 (N_slices, L, NDim) - float* __restrict__ out_dist, // 输出距离 - int32_t* __restrict__ out_indices, // 输出索引图 - int64_t L, // 当前维度的长度 (Length) - int64_t total_elements // Batch * ... * L + const float* __restrict__ in_data, // 输入 Dist^2 + const int32_t* __restrict__ in_indices, // 输入 Indices + float* __restrict__ out_dist, // 输出 Dist (IsFinal ? sqrt : sqr) + int32_t* __restrict__ out_indices, // 输出 Indices + int64_t L, // 当前维度的长度 + int64_t total_elements, // 总像素数 + int runtime_ndim // 运行时维度 (fallback) ) { - // 这里的 total_elements 是展平后的总像素数 - // 由于我们做了 transpose,数据布局是 [Batch_and_other_dims, L] - // 每个 Block 处理一行 (长度 L) - + // 确定实际维度 + const int D = (NDim > 0) ? NDim : runtime_ndim; + + // 计算行偏移 int64_t row_idx = blockIdx.x; int64_t offset = row_idx * L; if (offset >= total_elements) return; - // Shared Memory 布局: float vals[L], int idx1[L], int idx2[L] + // Shared Memory 布局 extern __shared__ char s_buffer[]; float* s_vals = (float*)s_buffer; int* s_idx1 = (int*)(s_vals + L); int* s_idx2 = (int*)(s_idx1 + L); - // 1. 加载数据到 Shared Memory + // 1. 加载 Dist 到 Shared Memory for (int i = threadIdx.x; i < L; i += blockDim.x) { - float val = __ldg(&in_data[offset + i]); - // 如果是初始 Pass (无输入索引),val 为 0 或 INF - // 如果是后续 Pass,val 为上一步的 dist^2 - s_vals[i] = val; + s_vals[i] = __ldg(&in_data[offset + i]); } __syncthreads(); - // 2. 运行 JFA + // 2. 运行 JFA 核心 run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); // 3. 写回结果 @@ -142,69 +150,64 @@ __global__ void edt_kernel_shared( int p = s_idx1[q]; // 最近点在当前行内的局部索引 (0..L-1) float dist_val; + // 计算新距离 if (p != -1) { - // 计算新距离: (q-p)^2 + val[p] float dist_sq = sqr((float)q - (float)p) + s_vals[p]; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; } else { dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); - p = 0; // fallback + p = 0; // 防止越界,随便指一个 } - out_dist[offset + q] = dist_val; - // 4. 
索引传播 - // 我们需要从 in_indices 查找完整的高维索引 - // in_indices 形状: [Batch..., L, NDim] - // 这里的 offset 对应 [Batch..., 0] - // p 是当前维度的偏移 + // 索引传播: Copy Vector [D] if (p != -1) { - int64_t src_offset = (offset + p) * NDim; - int64_t dst_offset = (offset + q) * NDim; + int64_t src_offset = (offset + p) * D; + int64_t dst_offset = (offset + q) * D; - // 手动展开拷贝,或者循环 - for (int d = 0; d < NDim; ++d) { + // 如果 NDim > 0,这里会完全展开,非常快 + for (int d = 0; d < D; ++d) { out_indices[dst_offset + d] = in_indices[src_offset + d]; } } else { - // 保持原样或填0 (通常保持原样即可,或者为了安全填0) - int64_t dst_offset = (offset + q) * NDim; - for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; + // 找不到源点(全图都是背景的情况) + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; } } } // ------------------------------------------------------------------ // Kernel 2: Global Memory JFA (Fallback Path) -// 适用于 N > 4096,使用 Global Memory 作为 Ping-Pong Buffer // ------------------------------------------------------------------ +// 逻辑同上,只是用 Global Memory 做 Ping-Pong Buffer template __global__ void edt_kernel_global( const float* __restrict__ in_data, const int32_t* __restrict__ in_indices, float* __restrict__ out_dist, int32_t* __restrict__ out_indices, - int* __restrict__ global_buffer_1, // 临时 buffer A [TotalElements] - int* __restrict__ global_buffer_2, // 临时 buffer B [TotalElements] + int* __restrict__ global_buffer_1, + int* __restrict__ global_buffer_2, int64_t L, - int64_t total_elements + int64_t total_elements, + int runtime_ndim ) { + const int D = (NDim > 0) ? NDim : runtime_ndim; + int64_t row_idx = blockIdx.x; int64_t offset = row_idx * L; if (offset >= total_elements) return; - // 指向当前行在 Global Memory 中的位置 - // 注意:in_data 是只读的,我们需要把它当做 weight - // JFA 需要两个 int buffer 来存 index + // 指向 Global Memory 的指针 int* g_idx1 = global_buffer_1 + offset; int* g_idx2 = global_buffer_2 + offset; - // 直接在 Global Memory 上运行 JFA - // 注意:这里 vals 指针直接指向 in_data (Global),读取稍慢但无需拷贝 + // 1. & 2. 运行 JFA (直接在 Global Mem 上读写) run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); - // 写回逻辑同上 + // 3. 
Write back the results
     for (int q = threadIdx.x; q < L; q += blockDim.x) {
         int p = g_idx1[q];
         float dist_val;
@@ -221,177 +224,191 @@
         out_dist[offset + q] = dist_val;
 
         if (p != -1) {
-            int64_t src_offset = (offset + p) * NDim;
-            int64_t dst_offset = (offset + q) * NDim;
-            for (int d = 0; d < NDim; ++d) {
+            int64_t src_offset = (offset + p) * D;
+            int64_t dst_offset = (offset + q) * D;
+            for (int d = 0; d < D; ++d) {
                 out_indices[dst_offset + d] = in_indices[src_offset + d];
             }
         } else {
-            int64_t dst_offset = (offset + q) * NDim;
-            for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0;
+            int64_t dst_offset = (offset + q) * D;
+            for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0;
         }
     }
 }
-
 // ------------------------------------------------------------------
-// Helper: initialize the index tensor
+// Kernel 3: Initialize Indices
 // ------------------------------------------------------------------
-// Initialize the index tensor to grid coordinates
-// shape: (..., D); the last dimension stores the coordinates
-__global__ void init_indices_kernel(int32_t* indices, int64_t total_elements, int NDim,
-                                    const int64_t* shape, const int64_t* strides) {
+// Initialize the index tensor to grid coordinates
+// indices shape: (..., D)
+__global__ void init_indices_kernel(
+    int32_t* indices,
+    int64_t total_pixels,
+    int NDim,
+    const int64_t* __restrict__ shape_ptr // shape of spatial dimensions
+) {
     int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_elements) return;
 
-    // Recover the coordinates
+    if (idx >= total_pixels) return;
+
+    // Unravel the flat index into coordinates
+    // idx is a pixel's flat index; we need its coordinates within spatial_shape
+
     int64_t temp = idx;
-    int32_t coords[10]; // max dims
+    // A local register array avoids repeated global-memory reads (assumes at most 10 dims)
+    int32_t coords[10];
 
-    // strides are laid out per element, but indices here is (Total, NDim);
-    // we can simply unravel from the shape alone
-    // Note: total_elements here is the pixel count, not the size of the indices array
-
-    // Suppose shape is [D0, D1, D2]; idx is the corresponding flat index
-
+    // Suppose spatial_shape is [D0, D1, D2];
+    // compute the mod/div steps in reverse order
     for (int d = NDim - 1; d >= 0; --d) {
-        coords[d] = temp % shape[d];
-        temp /= shape[d];
+        int64_t dim_size = shape_ptr[d];
+        coords[d] = temp % dim_size;
+        temp /= dim_size;
     }
 
-    // Write out
+    // Write to global memory
+    // The indices tensor is flattened as (TotalPixels, NDim)
+    int64_t out_ptr = idx * NDim;
     for (int d = 0; d < NDim; ++d) {
-        indices[idx * NDim + d] = coords[d];
+        indices[out_ptr + d] = coords[d];
    }
 }
 
 // ------------------------------------------------------------------
-// Host function
+// Host Function: C++ Entry Point
 // ------------------------------------------------------------------
+
 std::tuple<torch::Tensor, torch::Tensor> distance_transform_cuda(torch::Tensor input) {
     TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device.");
     TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32.");
     input = input.contiguous();
-    bool had_no_batch_dim = (input.dim() == 1);
-    if (had_no_batch_dim) input = input.unsqueeze(0);
-
-    const int ndim = input.dim(); // Include batch
+
+    // Batch handling: by convention the input shape is (Batch, D1, D2, ..., Dn).
+    // The algorithm treats the batch dimension like any other independent dimension,
+    // but the index initialization must know which dimensions are spatial.
+    // Assumption here: every dimension except Batch (dim 0) is spatial.
+
+    const int ndim = input.dim();
+    // With ndim=1 the input would be (L) and sample_ndim=1;
+    // with ndim=4 (B, C, H, W) we get sample_ndim=3.
+    // Note: a standard EDT runs over (H, W) or (D, H, W), and a channel
+    // dimension is normally processed independently. To stay fully generic we
+    // record indices over every dimension except dim 0 (Batch).
+    // If the input has no batch dimension, unsqueeze(0) on the Python side first.
+
+    // Assume the input is already (Batch, ...Spatial...)
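+    // A hedged call-site sketch (illustrative only, not enforced by this
+    // function): a single 64x64 binary image becomes a batch of one first.
+    //
+    //   torch::Tensor img = (torch::rand({64, 64}, torch::kCUDA) > 0.5f).to(torch::kFloat32);
+    //   auto [dist, idx] = distance_transform_cuda(img.unsqueeze(0));
+    //   // dist: (1, 64, 64) float32; idx: (1, 64, 64, 2) int32 nearest-zero coordinates
+    //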
const int sample_ndim = ndim - 1; + TORCH_CHECK(sample_ndim > 0, "Input tensor must have at least 2 dimensions (Batch, ...)"); + auto shape = input.sizes().vec(); int64_t num_pixels = input.numel(); if (num_pixels == 0) { auto index_shape = shape; - index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); + index_shape.push_back(sample_ndim); return std::make_tuple(torch::empty_like(input), torch::empty(index_shape, input.options().dtype(torch::kInt32))); } - // 1. 初始化输出 Tensor - // current_dist 在迭代过程中存储 dist^2,最后开方 - // 初始状态:Input 里的 0 还是 0,其他非 0 (背景) 设为 INF + // 1. 初始化 Distance Tensor + // 0 -> 0, 1 -> INF auto current_dist = torch::where(input == 0, torch::tensor(0.0f, input.options()), torch::tensor(INF_VAL, input.options())); - // 初始化索引 Map (Batch, ..., NDim) + // 2. 初始化 Index Tensor + // Shape: (Batch, D1, ..., Dn, sample_ndim) auto index_shape = shape; index_shape.push_back(sample_ndim); auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - // 启动 Kernel 初始化索引 - // 为了反解坐标,我们需要把 shape 传进去 + // 2.1 准备 Shape 数据传给 Kernel + std::vector spatial_shape(shape.begin() + 1, shape.end()); + auto shape_tensor = torch::tensor(spatial_shape, torch::kInt64).to(input.device()); + + // 2.2 运行初始化 Kernel { - // 排除 batch 维度的 shape 用于坐标计算? - // 需求是:返回的索引是 (batch_idx, z, y, x) 还是只是 (z, y, x)? - // 通常 EDT 返回的是 sample 内的坐标。所以我们忽略 batch 维度。 - std::vector sample_shape_vec(shape.begin() + 1, shape.end()); - auto sample_shape_tensor = torch::tensor(sample_shape_vec, torch::kInt64).to(input.device()); - // 这里的 strides 不需要,直接由 shape 反解 - int threads = 256; int blocks = (num_pixels + threads - 1) / threads; - - // 我们需要传递 sample_ndim init_indices_kernel<<>>( - current_idx.data_ptr(), - num_pixels, + current_idx.data_ptr(), + num_pixels, sample_ndim, - sample_shape_tensor.data_ptr(), - nullptr // strides not needed for simple unravel + shape_tensor.data_ptr() ); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Init Kernel Failed: %s\n", cudaGetErrorString(err)); + } } - - // 用于 Global Memory Fallback 的临时 buffer + + // 预分配 Global Memory Buffer (懒加载) torch::Tensor global_buf1, global_buf2; - // 2. 逐维处理 (Separable Phases) - // 从最后一个维度倒着处理,或者顺序处理都可以。 - // 为了 Host Transpose 方便,我们遍历 sample 的每一个维度 (1 到 ndim-1) + // 3. 
逐维处理 (Separable JFA) + // 遍历每一个空间维度 (从 1 到 ndim-1) for (int d = 1; d < ndim; ++d) { bool is_final_pass = (d == ndim - 1); - // ----------------------------------------------------------- - // Step A: Permute & Contiguous - // 将当前处理维度 d 移到最后: (0, 1, ..., d, ..., N-1) -> (0, 1, ..., N-1, d) - // 这样最后内存布局就是 [..., L],stride=1 - // ----------------------------------------------------------- - - // 这种 swap 策略比较简单: transpose(d, -1) - // 注意:index tensor 也要变换,但 index tensor 最后一维是 coord_dim,不能乱动。 - // Index tensor 形状是 [..., sample_ndim]。 - // 我们需要变换的是前面的空间维度 [...]。 - + // --- Step A: Transpose current dim to last --- + // 变换后 Shape: (..., L) auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); - // 此时 dist_in shape: [..., L] - // idx_in shape: [..., L, sample_ndim] - int64_t L = dist_in.size(-1); // 当前维度的长度 - int64_t total_slices = dist_in.numel() / L; // 有多少行 + int64_t total_slices = dist_in.numel() / L; auto dist_out = torch::empty_like(dist_in); auto idx_out = torch::empty_like(idx_in); - // ----------------------------------------------------------- - // Step B: Kernel Dispatch - // ----------------------------------------------------------- + // --- Step B: Kernel Dispatch --- int threads = std::min((int64_t)MAX_THREADS, L); - // 检查 Shared Memory 需求 - // Need: float(4) + int(4) + int(4) = 12 bytes per pixel + // 检查是否可以使用 Shared Memory if (L <= SMEM_LIMIT_ELEMENTS) { size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); - // 模板参数 NDim 需要是编译期常量。 - // 动态分发 sample_ndim (1D, 2D, 3D usually) - // 使用 switch case 覆盖常见维度 (1, 2, 3) + // 使用 Switch 宏来处理常用的维度模板特化 #define DISPATCH_SHARED(IS_FINAL) \ switch(sample_ndim) { \ case 1: edt_kernel_shared<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel()); break; \ + L, dist_in.numel(), sample_ndim); break; \ case 2: edt_kernel_shared<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel()); break; \ + L, dist_in.numel(), sample_ndim); break; \ case 3: edt_kernel_shared<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel()); break; \ - default: /* fallback for >3D */ break; \ + L, dist_in.numel(), sample_ndim); break; \ + case 4: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 5: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 6: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + default: /* Fallback for > 6D */ \ + edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ } if (is_final_pass) { DISPATCH_SHARED(true); } else { DISPATCH_SHARED(false); } } else { - // Fallback: Global Memory - // 需要分配 buffer: [total_slices * L] = [numel] + // Global Memory Fallback (L > 4096) if (global_buf1.numel() < dist_in.numel()) { global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); @@ -403,33 +420,49 @@ std::tuple 
distance_transform_cuda(torch::Tensor i
         dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
         dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
         global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
-        L, dist_in.numel()); break; \
+        L, dist_in.numel(), sample_ndim); break; \
     case 2: edt_kernel_global<IS_FINAL, 2><<<total_slices, threads>>>( \
         dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
         dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
         global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
-        L, dist_in.numel()); break; \
+        L, dist_in.numel(), sample_ndim); break; \
     case 3: edt_kernel_global<IS_FINAL, 3><<<total_slices, threads>>>( \
         dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
         dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
         global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
-        L, dist_in.numel()); break; \
-    default: break; \
+        L, dist_in.numel(), sample_ndim); break; \
+    case 4: edt_kernel_global<IS_FINAL, 4><<<total_slices, threads>>>( \
+        dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
+        dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
+        global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
+        L, dist_in.numel(), sample_ndim); break; \
+    case 5: edt_kernel_global<IS_FINAL, 5><<<total_slices, threads>>>( \
+        dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
+        dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
+        global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
+        L, dist_in.numel(), sample_ndim); break; \
+    case 6: edt_kernel_global<IS_FINAL, 6><<<total_slices, threads>>>( \
+        dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
+        dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
+        global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
+        L, dist_in.numel(), sample_ndim); break; \
+    default: /* Fallback */ \
+        edt_kernel_global<IS_FINAL, 0><<<total_slices, threads>>>( \
+        dist_in.data_ptr<float>(), idx_in.data_ptr<int32_t>(), \
+        dist_out.data_ptr<float>(), idx_out.data_ptr<int32_t>(), \
+        global_buf1.data_ptr<int>(), global_buf2.data_ptr<int>(), \
+        L, dist_in.numel(), sample_ndim); break; \
    }
 
-            if (is_final_pass) { DISPATCH_GLOBAL(true); }
+            if (is_final_pass) { DISPATCH_GLOBAL(true); }
             else { DISPATCH_GLOBAL(false); }
         }
 
-        // -----------------------------------------------------------
-        // Step C: Transpose Back
-        // -----------------------------------------------------------
-        current_dist = dist_out.transpose(d, ndim - 1); // View, non-contiguous is fine here as next step makes it contiguous
+        // --- Step C: Transpose Back ---
+        current_dist = dist_out.transpose(d, ndim - 1);
         current_idx = idx_out.transpose(d, ndim - 1);
     }
 
-    if (had_no_batch_dim) {
-        return std::make_tuple(current_dist.squeeze(0), current_idx.squeeze(0));
-    }
     return std::make_tuple(current_dist, current_idx);
-}
\ No newline at end of file
+}
+

From bc9c03b3033522c8a98b6a07ee012e751ba25960 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Mon, 8 Dec 2025 19:58:15 +0800
Subject: [PATCH 16/39] Implement an exact N-dimensional batched Euclidean
 distance transform that returns both coordinates and distances
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_distance_transform.py              | 140 +--
 torchmorph/csrc/distance_transform_kernel.cu | 939 +++++++++++--------
 2 files changed, 592 insertions(+), 487 deletions(-)

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 6e6265d..3a11166 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -1,127 +1,91 @@
 import torch
 import pytest
-import numpy as np
 from scipy.ndimage import distance_transform_edt as scipy_edt
+import numpy as np
 import torchmorph as tm
 
-# ======================================================================
-# Helpers
-# ======================================================================
+# Helper: call SciPy and massage the output format
 def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
-    dist_results, indices_results = [], []
-
-    # Guarantee that batch_numpy is at least (Batch, ...);
-    # an (H, W) input has already been promoted to (1, H, W) by the caller
-    if batch_numpy.ndim == 1:
+    is_single_sample = batch_numpy.ndim <= 2
+    if is_single_sample:
         batch_numpy = batch_numpy[np.newaxis, ...]
-
+    dist_results, indices_results = [], []
     for sample in batch_numpy:
         dist, indices = scipy_edt(sample, return_indices=True, return_distances=True)
         dist_results.append(dist)
         indices_results.append(indices)
-
     output_dist = np.stack(dist_results, axis=0)
     output_indices = np.stack(indices_results, axis=0)
-    output_indices = np.moveaxis(output_indices, 1, -1)
-
+    output_indices = np.moveaxis(output_indices, 1, -1)
+    if is_single_sample:
+        output_dist = output_dist.squeeze(0)
+        output_indices = output_indices.squeeze(0)
     return output_dist, output_indices
 
-# ======================================================================
-# Test data
-# ======================================================================
-case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32)
-
-case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],
-                          [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32)
-
-# Defined as (4, 4): the intent is a single 2D image
-case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32)
-case_explicit_batch_one = case_single_2d[np.newaxis, ...]
-
+# Test case definitions
+case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],[[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32)
 _case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s1[1, 1, 1] = 0.0; _case_3d_s1[2, 3, 4] = 0.0
 _case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s2[0, 0, 0] = 0.0
 case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0)
-
+case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32)
+case_explicit_batch_one = case_single_2d[np.newaxis, ...]
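+# The assertions further down rely on one identity: if idx holds, for every
+# pixel q, the coordinates of its nearest background pixel, then the returned
+# distance must satisfy dist(q)**2 == sum((q - idx(q))**2). A minimal sketch of
+# that check for a single unbatched sample (the helper name is illustrative
+# only and not part of this test suite):
+def check_index_consistency(dist: np.ndarray, idx: np.ndarray) -> bool:
+    spatial = dist.shape                                   # e.g. (H, W); idx is then (H, W, 2)
+    grid = np.stack(np.meshgrid(*[np.arange(s) for s in spatial], indexing="ij"), axis=-1)
+    dist_sq_from_idx = ((grid - idx) ** 2).sum(axis=-1)    # squared distance implied by idx
+    return np.allclose(dist_sq_from_idx, dist * dist, atol=1e-3)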
case_dim_one = np.ones((2, 5, 1), dtype=np.float32); case_dim_one[0, 2, 0] = 0.0; case_dim_one[1, 4, 0] = 0.0 +case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32) -# 4D Case -_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s1[0, 0, 0, 0] = 0.0 -_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s2[1, 1, 1, 1] = 0.0 -case_batch_4d_spatial = np.stack([_case_4d_s1, _case_4d_s2], axis=0) - -# 5D Case -case_batch_5d_spatial = np.ones((1, 2, 2, 2, 2, 2), dtype=np.float32) -case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0; case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 - -# ====================================================================== -# 测试逻辑 -# ====================================================================== @pytest.mark.parametrize( - "input_numpy, has_batch_dim", + "input_numpy", [ - pytest.param(case_batch_1d, True, id="1D_Batch"), - pytest.param(case_batch_2d, True, id="2D_Batch"), - pytest.param(case_single_2d, False, id="2D_Single_NoBatch"), - pytest.param(case_explicit_batch_one, True, id="2D_Single_ExplicitBatch"), - pytest.param(case_batch_3d, True, id="3D_Batch"), - pytest.param(case_dim_one, True, id="2D_UnitDim_Batch"), - pytest.param(case_batch_4d_spatial, True, id="4D_Spatial_Batch"), - pytest.param(case_batch_5d_spatial, True, id="5D_Spatial_Batch"), + pytest.param(case_batch_2d, id="批处理2D图像"), + pytest.param(case_batch_3d, id="批处理3D图像"), + pytest.param(case_single_2d, id="单张2D图像(隐式批处理)"), + pytest.param(case_explicit_batch_one, id="单张2D图像(显式批处理)"), + pytest.param(case_dim_one, id="含幺元维度的批处理"), + pytest.param(case_batch_1d, id="批处理1D数据"), ], ) -def test_distance_transform_and_indices(input_numpy: np.ndarray, has_batch_dim: bool, request: pytest.FixtureRequest): +def test_distance_transform_and_indices(input_numpy: np.ndarray, request: pytest.FixtureRequest): if not torch.cuda.is_available(): pytest.skip("CUDA not available") - # 1. 准备 Numpy 数据 x_numpy_contiguous = np.ascontiguousarray(input_numpy) - - # 2. 准备 SciPy 输入 - # 如果意图是单样本 (has_batch_dim=False),我们手动增加 Batch 维, - # 这样 scipy 辅助函数就会把它当做一张图来处理,而不是 N 张 1D 图 - if not has_batch_dim: - scipy_input = x_numpy_contiguous[np.newaxis, ...] - else: - scipy_input = x_numpy_contiguous - - # 3. 准备 CUDA 输入 - # 关键修复: - # 如果 has_batch_dim=False,说明这是单张 (H, W),我们要测 2D EDT。 - # C++ API 默认第一维是 Batch,所以我们必须 unsqueeze(0) 变成 (1, H, W)。 - # 否则 C++ 会把它当做 (Batch=H, Len=W) 做 1D EDT。 x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() - if not has_batch_dim: - x_cuda = x_cuda.unsqueeze(0) - print(f"\n\n--- 运行测试: {request.node.callspec.id} ---") - print(f"CUDA 输入形状: {x_cuda.shape}") + print(f"\n\n--- 正在运行测试: {request.node.callspec.id} ---") + print(f"输入张量形状: {x_cuda.shape}") - # 4. 运行 CUDA EDT + # 调用您的 Python 包装函数 dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone()) - # 5. 运行 SciPy (Ground Truth) - dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(scipy_input) + print(f"CUDA 距离输出形状: {dist_cuda.shape}") + print(f"CUDA 坐标输出形状: {idx_cuda.shape}") + + # 调用 SciPy 作为参考基准 + dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(x_numpy_contiguous) dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() + + print(f"SciPy 距离输出形状: {dist_ref.shape}") - # 6. 
verify distances
-    # At this point dist_cuda is (1, H, W) and dist_ref is also (1, H, W)
-    # If the intent was NoBatch, we could squeeze the batch dim before comparing, or just compare directly
-    print(f"CUDA Out Shape: {dist_cuda.shape}, Ref Shape: {dist_ref.shape}")
-    assert dist_cuda.shape == dist_ref.shape, f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}"
+    # Assertion checks
+    print("\n--- Verifying distances... ---")
+    assert dist_cuda.shape == dist_ref.shape
     torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3)
-    print(">> Distance check passed.")
+    print("Distance assertions passed (shapes match, values close).")

-    # 7. Verify indices
-    # idx_cuda: (1, H, W, 2)
-    # Build the grid
-    spatial_shape = x_cuda.shape[1:]  # (H, W)
+    print("\n--- Verifying coordinates... ---")
+
+    # Robust coordinate-verification logic
+    had_no_batch_dim = (x_numpy_contiguous.ndim <= idx_cuda.shape[-1])
+    spatial_shape = x_cuda.shape if had_no_batch_dim else x_cuda.shape[1:]
     coords = [torch.arange(s, device='cuda') for s in spatial_shape]
-    grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1)  # (H, W, 2)
-    grid = grid.unsqueeze(0)  # (1, H, W, 2)
-
+    grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1)
+
+    if not had_no_batch_dim:
+        grid = grid.unsqueeze(0)
+
     diff = grid.float() - idx_cuda.float()
-    dist_sq_calculated = torch.sum(diff * diff, dim=-1)
-    dist_sq_output = dist_cuda * dist_cuda
+    dist_sq_from_indices = torch.sum(diff * diff, dim=-1)

-    torch.testing.assert_close(dist_sq_calculated, dist_sq_output, atol=1e-3, rtol=1e-3)
-    print(">> Index check passed.")
\ No newline at end of file
+    torch.testing.assert_close(dist_sq_from_indices, dist_cuda * dist_cuda, atol=1e-3, rtol=1e-3)
+    print("Coordinate correctness assertions passed (distances recomputed from indices match the returned distances).")
+
+    print("--- Test passed ---")
\ No newline at end of file
diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 3bda330..4ec4647 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -1,468 +1,609 @@
 #include
 #include
-#include
 #include
-#include
-#include
-
-// ------------------------------------------------------------------
-// Configuration constants
-// ------------------------------------------------------------------
-#define INF_VAL 1e8f
-#define MAX_THREADS 1024
-// Shared memory limit: 48KB is generally safe.
-// Each pixel needs: float(val) + int(idx1) + int(idx2) = 12 bytes
-// 4096 * 12 = 48KB.
-#define SMEM_LIMIT_ELEMENTS 4096
-
-// ------------------------------------------------------------------
-// Device Helper Functions
-// ------------------------------------------------------------------
-
-__device__ __forceinline__ float sqr(float x) { return x * x; }
-
-// Compute the JFA cost: (q - p)^2 + weight[p]
-__device__ __forceinline__ float compute_cost(int q, int p, float val_p) {
-    if (p < 0) return INF_VAL;
-    return sqr((float)q - (float)p) + val_p;
-}
-
-// ------------------------------------------------------------------
-// JFA Core Logic (Device Only)
-// ------------------------------------------------------------------
-// Core JFA logic, independent of where the data lives (works for shared and global memory alike)
-__device__ void run_jfa_core(
-    int N,
-    int tid,
-    const float* __restrict__ vals,   // input weights (read-only)
-    int* __restrict__ idx_curr,       // ping-pong buffer A
-    int* __restrict__ idx_next        // ping-pong buffer B
+#include
+#include
+#include
+
+// Optimization strategy: replace the template with 4 independent kernel functions, eliminating branches entirely
+
+// Kernel 1: first pass that is also the only pass (the 1D case)
+__global__ void edt_kernel_first_final(
+    const float* in_data,
+    float* out_dist,
+    int32_t* out_idx,
+    const int64_t* shape,
+    const int64_t* strides,
+    int32_t ndim,
+    int32_t process_dim_sample,
+    int64_t total_slices,
+    int64_t num_slices_per_sample
 ) {
-    // 1. Initialize: use vals to decide whether this is a valid source point
-    for (int i = tid; i < N; i += blockDim.x) {
-        if (vals[i] >= INF_VAL * 0.9f) {
-            idx_curr[i] = -1;  // background
-        } else {
-            idx_curr[i] = i;   // object/source point; the initial index points to itself
+    int64_t slice_idx = blockIdx.x;
+    if (slice_idx >= total_slices) return;
+
+    int64_t batch_idx = slice_idx / num_slices_per_sample;
+    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
+    int64_t batch_offset = batch_idx * strides[0];
+    int64_t sample_base_offset = 0;
+    int64_t temp_idx = slice_idx_in_sample;
+    const int sample_ndim = ndim - 1;
+
+    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+        if (d == process_dim_sample) continue;
+        int64_t size_of_dim = shape[d + 1];
+        if (size_of_dim > 0) {
+            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
+            temp_idx /= size_of_dim;
         }
    }
-    __syncthreads();
-
-    // 2. Iterative propagation (step = 1, 2, 4, ... < N)
-    int* idx_in = idx_curr;
-    int* idx_out = idx_next;
+
+    const int64_t process_dim_actual = process_dim_sample + 1;
+    const int64_t N = shape[process_dim_actual];
+    const int64_t stride = strides[process_dim_actual];
+    const int64_t base_offset = batch_offset + sample_base_offset;

-    for (int step = 1; step < N; step *= 2) {
-        for (int i = tid; i < N; i += blockDim.x) {
-            int my_best_p = idx_in[i];
-            float min_cost = INF_VAL;
+    if (N == 0) return;

-            // Check our current best candidate
-            if (my_best_p != -1) {
-                min_cost = compute_cost(i, my_best_p, vals[my_best_p]);
-            }
-
-            // Check Left Neighbor (-step)
-            int left = i - step;
-            if (left >= 0) {
-                int left_p = idx_in[left];
-                if (left_p != -1) {
-                    float c = compute_cost(i, left_p, vals[left_p]);
-                    if (c < min_cost) {
-                        min_cost = c;
-                        my_best_p = left_p;
-                    }
+    extern __shared__ char s_buffer[];
+    float* f = (float*)s_buffer;
+    int* v = (int*)(f + N);
+    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
+
+    // Load data - first pass
+    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
+        int64_t global_offset = base_offset + i * stride;
+        float val = __ldg(&in_data[global_offset]);
+        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
+
+        if (val == 0.0f) {
+            f[i] = 0.0f;
+            int64_t temp_coord = slice_idx_in_sample;
+            for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+                if (d == process_dim_sample) continue;
+                int64_t size_of_dim = shape[d + 1];
+                if (size_of_dim > 0) {
+                    shared_idx_ptr[d] = temp_coord % size_of_dim;
+                    temp_coord /= size_of_dim;
+                } else {
+                    shared_idx_ptr[d] = 0;
                 }
            }
+            shared_idx_ptr[process_dim_sample] = i;
+        } else {
+            f[i] = 1e20f;
+            for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1;
+        }
+    }
+    __syncthreads();

-            // Check Right Neighbor (+step)
-            int right = i + step;
-            if (right < N) {
-                int right_p = idx_in[right];
-                if (right_p != -1) {
-                    float c = compute_cost(i, right_p, vals[right_p]);
-                    if (c < min_cost) {
-                        min_cost = c;
-                        my_best_p = right_p;
-                    }
+    // Build the lower envelope
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e20f;
+        z[1] = 1e20f;
+
+        for (int q = 1; q < N; q++) {
+            float fq = f[q];
+            int q_sq = q * q;
+
+            while (k >= 0) {
+                int p = v[k];
+                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) {
+                    k++;
+                    v[k] = q;
+                    z[k] = s;
+                    z[k + 1] = 1e20f;
+                    break;
+                }
+                k--;
+                if (k < 0) {
+                    k = 0;
+                    v[0] = q;
+                    z[0] = -1e20f;
+                    z[1] = 1e20f;
+                    break;
                 }
            }
-            idx_out[i] = my_best_p;
        }
-
-        // Swap Pointers
-        int* temp = idx_in;
-        idx_in = idx_out;
-        idx_out = temp;
-        __syncthreads();
    }
+
+    __syncthreads();

-    // 3. Make sure the final result ends up in idx_curr (copy back if the loop finished in idx_next)
-    if (idx_in != idx_curr) {
-        for (int i = tid; i < N; i += blockDim.x) {
-            idx_curr[i] = idx_next[i];
+    // Compute distances - this is the last pass, so take the square root directly
+    for (int q = threadIdx.x; q < N; q += blockDim.x) {
+        int k = 0;
+        float q_float = (float)q;
+        while (z[k + 1] < q_float) k++;
+
+        int p = v[k];
+        int64_t global_offset = base_offset + q * stride;
+        float dist_sq = (float)(q - p) * (q - p) + f[p];
+
+        out_dist[global_offset] = sqrtf(dist_sq);  // square root taken here
+
+        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
+        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            out_idx_ptr[d] = src_idx_ptr[d];
        }
-        __syncthreads();
    }
}

-// ------------------------------------------------------------------
-// Kernel 1: Shared Memory JFA (Fast Path)
-// ------------------------------------------------------------------
-// Template parameter NDim: when > 0 the compiler can unroll the loops.
-// Parameter runtime_ndim: used as the dimensionality when NDim == 0 (default case).
-template
-__global__ void edt_kernel_shared(
-    const float* __restrict__ in_data,       // input dist^2
-    const int32_t* __restrict__ in_indices,  // input indices
-    float* __restrict__ out_dist,            // output dist (IsFinal ? sqrt : sqr)
-    int32_t* __restrict__ out_indices,       // output indices
-    int64_t L,                               // length of the current dimension
-    int64_t total_elements,                  // total number of pixels
-    int runtime_ndim                         // runtime dimensionality (fallback)
+// Kernel 2: first pass, but not the last
+__global__ void edt_kernel_first_only(
+    const float* in_data,
+    float* out_dist,
+    int32_t* out_idx,
+    const int64_t* shape,
+    const int64_t* strides,
+    int32_t ndim,
+    int32_t process_dim_sample,
+    int64_t total_slices,
+    int64_t num_slices_per_sample
 ) {
-    // Determine the effective dimensionality
-    const int D = (NDim > 0) ? NDim : runtime_ndim;
-
-    // Compute the row offset
-    int64_t row_idx = blockIdx.x;
-    int64_t offset = row_idx * L;
+    int64_t slice_idx = blockIdx.x;
+    if (slice_idx >= total_slices) return;
+
+    int64_t batch_idx = slice_idx / num_slices_per_sample;
+    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
+    int64_t batch_offset = batch_idx * strides[0];
+    int64_t sample_base_offset = 0;
+    int64_t temp_idx = slice_idx_in_sample;
+    const int sample_ndim = ndim - 1;
+
+    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+        if (d == process_dim_sample) continue;
+        int64_t size_of_dim = shape[d + 1];
+        if (size_of_dim > 0) {
+            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
+            temp_idx /= size_of_dim;
+        }
    }

-    if (offset >= total_elements) return;
-
-    // Shared memory layout
-    extern __shared__ char s_buffer[];
-    float* s_vals = (float*)s_buffer;
-    int* s_idx1 = (int*)(s_vals + L);
-    int* s_idx2 = (int*)(s_idx1 + L);
+    const int64_t process_dim_actual = process_dim_sample + 1;
+    const int64_t N = shape[process_dim_actual];
+    const int64_t stride = strides[process_dim_actual];
+    const int64_t base_offset = batch_offset + sample_base_offset;

-    // 1. Load dist into shared memory
-    for (int i = threadIdx.x; i < L; i += blockDim.x) {
-        s_vals[i] = __ldg(&in_data[offset + i]);
-    }
-    __syncthreads();
+    if (N == 0) return;

-    // 2. Run the JFA core
-    run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2);
-
-    // 3. Write back the results
-    for (int q = threadIdx.x; q < L; q += blockDim.x) {
-        int p = s_idx1[q];  // local index of the nearest point within this row (0..L-1)
-        float dist_val;
-
-        // Compute the new distance
-        if (p != -1) {
-            float dist_sq = sqr((float)q - (float)p) + s_vals[p];
-            dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq;
+    extern __shared__ char s_buffer[];
+    float* f = (float*)s_buffer;
+    int* v = (int*)(f + N);
+    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
+
+    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
+        int64_t global_offset = base_offset + i * stride;
+        float val = __ldg(&in_data[global_offset]);
+        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
+
+        if (val == 0.0f) {
+            f[i] = 0.0f;
+            int64_t temp_coord = slice_idx_in_sample;
+            for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+                if (d == process_dim_sample) continue;
+                int64_t size_of_dim = shape[d + 1];
+                if (size_of_dim > 0) {
+                    shared_idx_ptr[d] = temp_coord % size_of_dim;
+                    temp_coord /= size_of_dim;
+                } else {
+                    shared_idx_ptr[d] = 0;
+                }
+            }
+            shared_idx_ptr[process_dim_sample] = i;
        } else {
-            dist_val = IsFinal ? INF_VAL : sqr(INF_VAL);
-            p = 0;  // avoid out-of-bounds access; point anywhere valid
+            f[i] = 1e20f;
+            for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1;
        }
-        out_dist[offset + q] = dist_val;
+    }
+    __syncthreads();

-        // Index propagation: copy vector [D]
-        if (p != -1) {
-            int64_t src_offset = (offset + p) * D;
-            int64_t dst_offset = (offset + q) * D;
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e20f;
+        z[1] = 1e20f;
+
+        for (int q = 1; q < N; q++) {
+            float fq = f[q];
+            int q_sq = q * q;

-            // With NDim > 0 this loop is fully unrolled and very fast
-            for (int d = 0; d < D; ++d) {
-                out_indices[dst_offset + d] = in_indices[src_offset + d];
+            while (k >= 0) {
+                int p = v[k];
+                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) {
+                    k++;
+                    v[k] = q;
+                    z[k] = s;
+                    z[k + 1] = 1e20f;
+                    break;
+                }
+                k--;
+                if (k < 0) {
+                    k = 0;
+                    v[0] = q;
+                    z[0] = -1e20f;
+                    z[1] = 1e20f;
+                    break;
+                }
            }
-        } else {
-            // No source point found (the whole image is background)
-            int64_t dst_offset = (offset + q) * D;
-            for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0;
+        }
+    }
+    __syncthreads();
+
+    for (int q = threadIdx.x; q < N; q += blockDim.x) {
+        int k = 0;
+        float q_float = (float)q;
+        while (z[k + 1] < q_float) k++;
+
+        int p = v[k];
+        int64_t global_offset = base_offset + q * stride;
+        float dist_sq = (float)(q - p) * (q - p) + f[p];
+
+        out_dist[global_offset] = dist_sq;  // no square root yet
+
+        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
+        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            out_idx_ptr[d] = src_idx_ptr[d];
        }
    }
}

-// ------------------------------------------------------------------
-// Kernel 2: Global Memory JFA (Fallback Path)
-// ------------------------------------------------------------------
-// Same logic as above, but uses global memory for the ping-pong buffers
-template
-__global__ void edt_kernel_global(
-    const float* __restrict__ in_data,
-    const int32_t* __restrict__ in_indices,
-    float* __restrict__ out_dist,
-    int32_t* __restrict__ out_indices,
-    int* __restrict__ global_buffer_1,
-    int* __restrict__ global_buffer_2,
-    int64_t L,
-    int64_t total_elements,
-    int runtime_ndim
+// Kernel 3: intermediate passes
+__global__ void edt_kernel_middle(
+    const float* in_dist,
+    const int32_t* in_idx,
+    float* out_dist,
+    int32_t* out_idx,
+    const int64_t* shape,
+    const int64_t* strides,
+    int32_t ndim,
+    int32_t process_dim_sample,
+    int64_t total_slices,
+    int64_t num_slices_per_sample
 ) {
-    const int D = (NDim > 0) ? NDim : runtime_ndim;
-
-    int64_t row_idx = blockIdx.x;
-    int64_t offset = row_idx * L;
+    int64_t slice_idx = blockIdx.x;
+    if (slice_idx >= total_slices) return;
+
+    int64_t batch_idx = slice_idx / num_slices_per_sample;
+    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
+    int64_t batch_offset = batch_idx * strides[0];
+    int64_t sample_base_offset = 0;
+    int64_t temp_idx = slice_idx_in_sample;
+    const int sample_ndim = ndim - 1;
+
+    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+        if (d == process_dim_sample) continue;
+        int64_t size_of_dim = shape[d + 1];
+        if (size_of_dim > 0) {
+            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
+            temp_idx /= size_of_dim;
+        }
    }

-    if (offset >= total_elements) return;
+    const int64_t process_dim_actual = process_dim_sample + 1;
+    const int64_t N = shape[process_dim_actual];
+    const int64_t stride = strides[process_dim_actual];
+    const int64_t base_offset = batch_offset + sample_base_offset;

-    // Pointers into global memory
-    int* g_idx1 = global_buffer_1 + offset;
-    int* g_idx2 = global_buffer_2 + offset;
-
-    // 1 & 2: run JFA (reading/writing global memory directly)
-    run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2);
-
-    // 3. Write back the results
-    for (int q = threadIdx.x; q < L; q += blockDim.x) {
-        int p = g_idx1[q];
-        float dist_val;
-
-        if (p != -1) {
-            float val_p = in_data[offset + p];
-            float dist_sq = sqr((float)q - (float)p) + val_p;
-            dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq;
-        } else {
-            dist_val = IsFinal ? INF_VAL : sqr(INF_VAL);
-            p = 0;
-        }
+    if (N == 0) return;

-        out_dist[offset + q] = dist_val;
+    extern __shared__ char s_buffer[];
+    float* f = (float*)s_buffer;
+    int* v = (int*)(f + N);
+    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
+
+    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
+        int64_t global_offset = base_offset + i * stride;
+        f[i] = __ldg(&in_dist[global_offset]);
+
+        const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim;
+        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]);
+        }
+    }
+    __syncthreads();

-        if (p != -1) {
-            int64_t src_offset = (offset + p) * D;
-            int64_t dst_offset = (offset + q) * D;
-            for (int d = 0; d < D; ++d) {
-                out_indices[dst_offset + d] = in_indices[src_offset + d];
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e20f;
+        z[1] = 1e20f;
+
+        for (int q = 1; q < N; q++) {
+            float fq = f[q];
+            int q_sq = q * q;
+
+            while (k >= 0) {
+                int p = v[k];
+                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) {
+                    k++;
+                    v[k] = q;
+                    z[k] = s;
+                    z[k + 1] = 1e20f;
+                    break;
+                }
+                k--;
+                if (k < 0) {
+                    k = 0;
+                    v[0] = q;
+                    z[0] = -1e20f;
+                    z[1] = 1e20f;
+                    break;
+                }
            }
-        } else {
-            int64_t dst_offset = (offset + q) * D;
-            for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0;
+        }
+    }
+    __syncthreads();
+
+    for (int q = threadIdx.x; q < N; q += blockDim.x) {
+        int k = 0;
+        float q_float = (float)q;
+        while (z[k + 1] < q_float) k++;
+
+        int p = v[k];
+        int64_t global_offset = base_offset + q * stride;
+        float dist_sq = (float)(q - p) * (q - p) + f[p];
+
+        out_dist[global_offset] = dist_sq;  // no square root yet
+
+        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
+        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            out_idx_ptr[d] = src_idx_ptr[d];
        }
    }
}

-// ------------------------------------------------------------------
-// Kernel 3: Initialize Indices
-// ------------------------------------------------------------------
-// Initialize the index tensor with grid coordinates
-// indices shape: (..., D)
-__global__ void init_indices_kernel(
-    int32_t* indices,
-    int64_t total_pixels,
-    int NDim,
-    const int64_t* __restrict__ shape_ptr  // shape of spatial dimensions
+// Kernel 4: the final pass
+__global__ void edt_kernel_final(
+    const float* in_dist,
+    const int32_t* in_idx,
+    float* out_dist,
+    int32_t* out_idx,
+    const int64_t* shape,
+    const int64_t* strides,
+    int32_t ndim,
+    int32_t process_dim_sample,
+    int64_t total_slices,
+    int64_t num_slices_per_sample
 ) {
-    int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_pixels) return;
-
-    // Recover the coordinates (unravel index)
-    // idx is each pixel's flat index;
-    // we need to compute its coordinates within spatial_shape
+    int64_t slice_idx = blockIdx.x;
+    if (slice_idx >= total_slices) return;
+
+    int64_t batch_idx = slice_idx / num_slices_per_sample;
+    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
+    int64_t batch_offset = batch_idx * strides[0];
+    int64_t sample_base_offset = 0;
+    int64_t temp_idx = slice_idx_in_sample;
+    const int sample_ndim = ndim - 1;
+
+    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
+        if (d == process_dim_sample) continue;
+        int64_t size_of_dim = shape[d + 1];
+        if (size_of_dim > 0) {
+            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
+            temp_idx /= size_of_dim;
+        }
    }

-    int64_t temp = idx;
-    // Use a local register array to avoid repeated global memory reads (assumes at most 10 dims)
-    int32_t coords[10];
-
-    // Assume spatial_shape is [D0, D1, D2];
-    // compute div/mod in reverse order
-    for (int d = NDim - 1; d >= 0; --d) {
-        int64_t dim_size = shape_ptr[d];
-        coords[d] = temp % dim_size;
-        temp /= dim_size;
+    const int64_t process_dim_actual = process_dim_sample + 1;
+    const int64_t N = shape[process_dim_actual];
+    const int64_t stride = strides[process_dim_actual];
+    const int64_t base_offset = batch_offset + sample_base_offset;
+
+    if (N == 0) return;
+
+    extern __shared__ char s_buffer[];
+    float* f = (float*)s_buffer;
+    int* v = (int*)(f + N);
+    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
+
+    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
+        int64_t global_offset = base_offset + i * stride;
+        f[i] = __ldg(&in_dist[global_offset]);
+        const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim;
+        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]);
+        }
+    }
+    __syncthreads();
+
+    if (threadIdx.x == 0) {
+        int k = 0;
+        v[0] = 0;
+        z[0] = -1e20f;
+        z[1] = 1e20f;
+
+        for (int q = 1; q < N; q++) {
+            float fq = f[q];
+            int q_sq = q * q;
+
+            while (k >= 0) {
+                int p = v[k];
+                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
+                if (s > z[k]) {
+                    k++;
+                    v[k] = q;
+                    z[k] = s;
+                    z[k + 1] = 1e20f;
+                    break;
+                }
+                k--;
+                if (k < 0) {
+                    k = 0;
+                    v[0] = q;
+                    z[0] = -1e20f;
+                    z[1] = 1e20f;
+                    break;
+                }
+            }
+        }
    }
+    __syncthreads();
+
+    for (int q = threadIdx.x; q < N; q += blockDim.x) {
+        int k = 0;
+        float q_float = (float)q;
+        while (z[k + 1] < q_float) k++;
+
+        int p = v[k];
+        int64_t global_offset = base_offset + q * stride;
+        float dist_sq = (float)(q - p) * (q - p) + f[p];
+
+        out_dist[global_offset] = sqrtf(dist_sq);  // take the square root at the end

-    // Write to global memory
-    // the indices tensor is flattened as (TotalPixels, NDim)
-    int64_t out_ptr = idx * NDim;
-    for (int d = 0; d < NDim; ++d) {
-        indices[out_ptr + d] = coords[d];
+        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
+        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        for (int d = 0; d < sample_ndim; ++d) {
+            out_idx_ptr[d] = src_idx_ptr[d];
+        }
    }
}

-// ------------------------------------------------------------------
-// Host Function: C++ Entry Point
-// ------------------------------------------------------------------
-
+// Host function
 std::tuple distance_transform_cuda(torch::Tensor input) {
-    TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device.");
-    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32.");
-
+    TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device.");
+    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be a float tensor.");
     input = input.contiguous();
-
-    // Handling the batch dimension: if the input is 1D (L), treat it as batch-less, but add a
-    // batch dim during processing for uniformity.
-    // Standard convention: input shape (Batch, D1, D2, ..., Dn).
-    // The algorithm actually treats the batch dim like any other independent dimension,
-    // but index initialization needs to know which dims are the "spatial dimensions".
-    // Assumption here: every dimension except batch (dim 0) is spatial.
-
-    const int ndim = input.dim();
-    // If ndim=1, assume shape (L), sample_ndim=1.
-    // If ndim=4 (B, C, H, W), sample_ndim=3 (does C count as spatial? usually C is processed independently).
-    // **Correction**: standard EDT usually runs on (H, W) or (D, H, W);
-    // if there is a channel dim, it is normally handled independently too.
-    // To stay maximally general, we record indices over **all dims except dim 0 (batch)**
-    // as spatial dimensions.
-    // If the user's input has no batch dim, please unsqueeze(0) on the Python side.
-
-    // Assume the input is already (Batch, ...Spatial...)
-    const int sample_ndim = ndim - 1;
-    TORCH_CHECK(sample_ndim > 0, "Input tensor must have at least 2 dimensions (Batch, ...)");
-
     auto shape = input.sizes().vec();
-    int64_t num_pixels = input.numel();
+    bool had_no_batch_dim = (input.dim() <= 2);
+    if (had_no_batch_dim) { input = input.unsqueeze(0); }

-    if (num_pixels == 0) {
+    const auto ndim = input.dim();
+    const auto sample_ndim = ndim - 1;
+    const auto batch_size = input.size(0);
+
+    auto shape = input.sizes().vec();
+    auto strides_vec = input.strides().vec();
+
+    if (input.numel() == 0) {
+        auto distance = torch::empty_like(input);
         auto index_shape = shape;
-        index_shape.push_back(sample_ndim);
-        return std::make_tuple(torch::empty_like(input),
-                               torch::empty(index_shape, input.options().dtype(torch::kInt32)));
+        index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1);
+        auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32));
+        if (had_no_batch_dim) return std::make_tuple(distance.squeeze(0), index.squeeze(0));
+        return std::make_tuple(distance, index);
    }
-
-    // 1. Initialize the distance tensor
-    //    0 -> 0, 1 -> INF
-    auto current_dist = torch::where(input == 0,
-                                     torch::tensor(0.0f, input.options()),
-                                     torch::tensor(INF_VAL, input.options()));

-    // 2. Initialize the index tensor
-    //    shape: (Batch, D1, ..., Dn, sample_ndim)
+    auto distance = torch::empty_like(input);
+    auto index_options = input.options().dtype(torch::kInt32);
     auto index_shape = shape;
-    index_shape.push_back(sample_ndim);
-    auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32));
-
-    // 2.1 Prepare the shape data for the kernel
-    std::vector spatial_shape(shape.begin() + 1, shape.end());
-    auto shape_tensor = torch::tensor(spatial_shape, torch::kInt64).to(input.device());
-
-    // 2.2 Launch the initialization kernel
-    {
-        int threads = 256;
-        int blocks = (num_pixels + threads - 1) / threads;
-        init_indices_kernel<<>>(
-            current_idx.data_ptr(),
-            num_pixels,
-            sample_ndim,
-            shape_tensor.data_ptr()
-        );
-        cudaError_t err = cudaGetLastError();
-        if (err != cudaSuccess) {
-            printf("Init Kernel Failed: %s\n", cudaGetErrorString(err));
+    index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1);
+    auto index = torch::empty(index_shape, index_options);
+
+    if (torch::all(input != 0).item()) {
+        distance.fill_(std::numeric_limits::infinity());
+        index.fill_(-1);
+        if (had_no_batch_dim) {
+            return std::make_tuple(distance.squeeze(0), index.squeeze(0));
        }
+        return std::make_tuple(distance, index);
    }

-    // Pre-allocate global memory buffers (lazily)
-    torch::Tensor global_buf1, global_buf2;
+    auto shape_tensor = torch::tensor(shape,
+        torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
+    auto strides_tensor = torch::tensor(strides_vec,
+        torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
+
+    const int64_t* shape_gpu = shape_tensor.data_ptr();
+    const int64_t* strides_gpu = strides_tensor.data_ptr();

-    // 3. Process dimension by dimension (separable JFA)
-    //    iterate over every spatial dimension (from 1 to ndim-1)
-    for (int d = 1; d < ndim; ++d) {
-        bool is_final_pass = (d == ndim - 1);
-
-        // --- Step A: Transpose current dim to last ---
-        // resulting shape: (..., L)
-        auto dist_in = current_dist.transpose(d, ndim - 1).contiguous();
-        auto idx_in = current_idx.transpose(d, ndim - 1).contiguous();
+    std::vector> dim_order_pairs;
+    for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) {
+        dim_order_pairs.push_back({strides_vec[d_sample + 1], d_sample});
+    }
+    std::sort(dim_order_pairs.rbegin(), dim_order_pairs.rend());
+
+    if (sample_ndim == 0) {
+        int64_t total_slices = batch_size;
+        int64_t slice_len = (shape.size() > 1) ? shape[1] : 0;
+        int threads = std::min((int64_t)256, slice_len);
+        size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) +
+                      (slice_len + 2) * sizeof(float) + slice_len * 1 * sizeof(int32_t);

-        int64_t L = dist_in.size(-1);  // length of the current dimension
-        int64_t total_slices = dist_in.numel() / L;
+        edt_kernel_first_final<<>>(
+            input.data_ptr(),
+            distance.data_ptr(), index.data_ptr(),
+            shape_gpu, strides_gpu, ndim, 0, total_slices, 1
+        );
+    } else {
+        auto buffer_dist = torch::empty_like(distance);
+        auto buffer_idx = torch::empty_like(index);

-        auto dist_out = torch::empty_like(dist_in);
-        auto idx_out = torch::empty_like(idx_in);
+        for (int pass = 0; pass < sample_ndim; ++pass) {
+            int32_t d_sample = dim_order_pairs[pass].second;
+            bool is_first = (pass == 0);
+            bool is_final = (pass == sample_ndim - 1);

-        // --- Step B: Kernel Dispatch ---
-        int threads = std::min((int64_t)MAX_THREADS, L);
+            torch::Tensor *in_dist, *in_idx, *out_dist, *out_idx;

-        // Check whether shared memory can be used
-        if (L <= SMEM_LIMIT_ELEMENTS) {
-            size_t smem_size = L * (sizeof(float) + 2 * sizeof(int));
+            if (pass % 2 == 0) {
+                in_dist = &distance; in_idx = &index;
+                out_dist = &buffer_dist; out_idx = &buffer_idx;

-            // Use a switch macro to dispatch the template specializations for common dimensionalities
-            #define DISPATCH_SHARED(IS_FINAL) \
-                switch(sample_ndim) { \
-                    case 1: edt_kernel_shared<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 2: edt_kernel_shared<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 3: edt_kernel_shared<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 4: edt_kernel_shared<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 5: edt_kernel_shared<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 6: edt_kernel_shared<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    default: /* Fallback for > 6D */ \
-                        edt_kernel_shared<<>>( \
-                            dist_in.data_ptr(), idx_in.data_ptr(), \
-                            dist_out.data_ptr(), idx_out.data_ptr(), \
-                            L, dist_in.numel(), sample_ndim); break; \
-                }
-
-            if (is_final_pass) { DISPATCH_SHARED(true); }
-            else { DISPATCH_SHARED(false); }
-
-        } else {
-            // Global Memory Fallback (L > 4096)
-            if (global_buf1.numel() < dist_in.numel()) {
-                global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device()));
-                global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device()));
+            } else {
+                in_dist = &buffer_dist; in_idx = &buffer_idx;
+                out_dist = &distance; out_idx = &index;
            }
+
+            int64_t num_slices_per_sample = 1;
+            for(int i = 0; i < sample_ndim; ++i) {
+                if (i != d_sample) num_slices_per_sample *= shape[i + 1];
+            }
+            int64_t total_slices = batch_size * num_slices_per_sample;
+            int64_t slice_len = shape[d_sample + 1];
+
+            int threads = std::min((int64_t)256, slice_len);
+            size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) +
+                          (slice_len + 2) * sizeof(float) + slice_len * sample_ndim * sizeof(int32_t);
+
+            if (is_first && is_final) {
+                edt_kernel_first_final<<>>(
+                    input.data_ptr(),
+                    out_dist->data_ptr(), out_idx->data_ptr(),
+                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+                );
+            } else if (is_first) {
+                edt_kernel_first_only<<>>(
+                    input.data_ptr(),
+                    out_dist->data_ptr(), out_idx->data_ptr(),
+                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+                );
+            } else if (is_final) {
+                edt_kernel_final<<>>(
+                    in_dist->data_ptr(), in_idx->data_ptr(),
+                    out_dist->data_ptr(), out_idx->data_ptr(),
+                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+                );
+            } else {
+                edt_kernel_middle<<>>(
+                    in_dist->data_ptr(), in_idx->data_ptr(),
+                    out_dist->data_ptr(), out_idx->data_ptr(),
+                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+                );
            }
-
-            #define DISPATCH_GLOBAL(IS_FINAL) \
-                switch(sample_ndim) { \
-                    case 1: edt_kernel_global<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 2: edt_kernel_global<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 3: edt_kernel_global<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 4: edt_kernel_global<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 5: edt_kernel_global<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    case 6: edt_kernel_global<<>>( \
-                        dist_in.data_ptr(), idx_in.data_ptr(), \
-                        dist_out.data_ptr(), idx_out.data_ptr(), \
-                        global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                        L, dist_in.numel(), sample_ndim); break; \
-                    default: /* Fallback */ \
-                        edt_kernel_global<<>>( \
-                            dist_in.data_ptr(), idx_in.data_ptr(), \
-                            dist_out.data_ptr(), idx_out.data_ptr(), \
-                            global_buf1.data_ptr(), global_buf2.data_ptr(), \
-                            L, dist_in.numel(), sample_ndim); break; \
-                }
-
-            if (is_final_pass) { DISPATCH_GLOBAL(true); }
-            else { DISPATCH_GLOBAL(false); }
        }
-
-        // --- Step C: Transpose Back ---
-        current_dist = dist_out.transpose(d, ndim - 1);
-        current_idx = idx_out.transpose(d, ndim - 1);
+
+        if (sample_ndim % 2 != 0) {
+            distance.copy_(buffer_dist);
+            index.copy_(buffer_idx);
+        }
    }
-
-    return std::make_tuple(current_dist, current_idx);
-}
-
+
+    if (had_no_batch_dim) {
+        return std::make_tuple(distance.squeeze(0), index.squeeze(0));
+    }
+
+    return std::make_tuple(distance, index);
+} 
\ No newline at end of file
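[Editor's note, not part of the patch series: the kernels added above all run Felzenszwalb-Huttenlocher's lower-envelope scan along one axis at a time. A minimal NumPy sketch of that 1D building block, for reference only; the helper name edt_1d_sq and the finite sentinel 1e10 (standing in for the kernels' INF) are assumptions, not code from this repository.]

    import numpy as np

    def edt_1d_sq(f):
        """One lower-envelope pass: for each q, min_p (q - p)^2 + f[p] and the argmin p."""
        n = len(f)
        v = np.zeros(n, dtype=int)        # parabola positions on the envelope
        z = np.full(n + 1, 1e30)          # boundaries between envelope parabolas
        z[0] = -1e30
        k = 0
        for q in range(1, n):
            # intersection of parabola q with the rightmost parabola of the envelope
            s = ((f[q] + q * q) - (f[v[k]] + v[k] * v[k])) / (2.0 * (q - v[k]))
            while s <= z[k]:
                k -= 1
                s = ((f[q] + q * q) - (f[v[k]] + v[k] * v[k])) / (2.0 * (q - v[k]))
            k += 1
            v[k] = q
            z[k] = s
            z[k + 1] = 1e30
        d = np.empty(n)
        arg = np.empty(n, dtype=int)
        k = 0
        for q in range(n):
            while z[k + 1] < q:           # advance to the parabola covering q
                k += 1
            arg[q] = v[k]
            d[q] = (q - v[k]) ** 2 + f[v[k]]
        return d, arg

    # sources at positions 2 and 5, everything else background:
    row = np.array([1e10, 1e10, 0.0, 1e10, 1e10, 0.0])
    print(np.sqrt(edt_1d_sq(row)[0]))     # [2. 1. 0. 1. 1. 0.]

Applied once per axis with the previous pass's squared distances fed back in as f, this yields the exact N-D squared EDT, which is why the kernels only take the square root on the final pass.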
From a6fafaa41678abcdaad1fceaa523cc20c25d3914 Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Mon, 8 Dec 2025 20:23:07 +0800
Subject: [PATCH 17/39] Adjust tests + speed optimizations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_distance_transform.py              |  20 +-
 torchmorph/csrc/distance_transform_kernel.cu | 636 ++++++------------
 2 files changed, 204 insertions(+), 452 deletions(-)

diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py
index 3a11166..476ffca 100644
--- a/test/test_distance_transform.py
+++ b/test/test_distance_transform.py
@@ -6,20 +6,32 @@
 # Helper: call SciPy and fix up the output format
 def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
-    is_single_sample = batch_numpy.ndim <= 2
-    if is_single_sample:
-        batch_numpy = batch_numpy[np.newaxis, ...]
+
+
+    input_is_1d_batch = (batch_numpy.ndim == 2)
+    input_is_single_sample_no_batch = (batch_numpy.ndim == 1)
+
+    if input_is_single_sample_no_batch:
+        batch_numpy = batch_numpy[np.newaxis, ...]  # (L) -> (1, L)
+
+
     dist_results, indices_results = [], []
     for sample in batch_numpy:
         dist, indices = scipy_edt(sample, return_indices=True, return_distances=True)
         dist_results.append(dist)
         indices_results.append(indices)
+
     output_dist = np.stack(dist_results, axis=0)
     output_indices = np.stack(indices_results, axis=0)
+
+    # indices shape fix: (N, ndim_sample, ...) -> (N, ..., ndim_sample)
+    # for 1D: (N, 1, L) -> (N, L, 1)
     output_indices = np.moveaxis(output_indices, 1, -1)
-    if is_single_sample:
+
+    if input_is_single_sample_no_batch:
         output_dist = output_dist.squeeze(0)
         output_indices = output_indices.squeeze(0)
+
     return output_dist, output_indices

 # Test case definitions
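[Editor's note, not part of the patch: the axis move the helper above performs can be seen on a tiny single-sample input; this sketch calls SciPy directly rather than through the test's scipy_edt alias.]

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    x = np.array([[0, 1, 1],
                  [1, 1, 0]], dtype=np.float32)
    dist, idx = distance_transform_edt(x, return_indices=True)
    print(idx.shape)                    # (2, 2, 3): SciPy puts the coordinate axis first
    idx_last = np.moveaxis(idx, 0, -1)  # (2, 3, 2): coordinate axis last, matching the CUDA op
    print(idx_last[0, 1])               # nearest background pixel for cell (0, 1): [0 0]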
diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 4ec4647..c3ded1e 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -5,15 +5,21 @@
 #include
 #include

-// Optimization strategy: replace the template with 4 independent kernel functions, eliminating branches entirely
-
-// Kernel 1: first pass that is also the only pass (1D case)
-__global__ void edt_kernel_first_final(
-    const float* in_data,
-    float* out_dist,
-    int32_t* out_idx,
-    const int64_t* shape,
-    const int64_t* strides,
+#define MAX_DIMS 10
+#define INF_VAL 1e8f  // use 1e8 to keep float32 arithmetic numerically stable
+
+__device__ __forceinline__ float sqr(float x) { return x * x; }
+
+// ------------------------------------------------------------------
+// Kernel 1: initial pass (first pass)
+// ------------------------------------------------------------------
+template
+__global__ void edt_kernel_first_pass(
+    const float* __restrict__ in_data,
+    float* __restrict__ out_dist,
+    int32_t* __restrict__ out_idx,
+    const int64_t* __restrict__ shape,
+    const int64_t* __restrict__ strides,
     int32_t ndim,
     int32_t process_dim_sample,
     int64_t total_slices,
@@ -24,244 +30,117 @@ __global__ void edt_kernel_first_final(

     int64_t batch_idx = slice_idx / num_slices_per_sample;
     int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
-    int64_t batch_offset = batch_idx * strides[0];
-    int64_t sample_base_offset = 0;
-    int64_t temp_idx = slice_idx_in_sample;
-    const int sample_ndim = ndim - 1;
-
-    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
-        if (d == process_dim_sample) continue;
-        int64_t size_of_dim = shape[d + 1];
-        if (size_of_dim > 0) {
-            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
-            temp_idx /= size_of_dim;
-        }
-    }
-
-    const int64_t process_dim_actual = process_dim_sample + 1;
-    const int64_t N = shape[process_dim_actual];
-    const int64_t stride = strides[process_dim_actual];
-    const int64_t base_offset = batch_offset + sample_base_offset;
-
-    if (N == 0) return;
-
-    extern __shared__ char s_buffer[];
-    float* f = (float*)s_buffer;
-    int* v = (int*)(f + N);
-    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
-    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
-
-    // Load data - first pass
-    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
-        int64_t global_offset = base_offset + i * stride;
-        float val = __ldg(&in_data[global_offset]);
-        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
-
-        if (val == 0.0f) {
-            f[i] = 0.0f;
-            int64_t temp_coord = slice_idx_in_sample;
-            for (int32_t d = sample_ndim - 1; d >= 0; --d) {
-                if (d == process_dim_sample) continue;
-                int64_t size_of_dim = shape[d + 1];
-                if (size_of_dim > 0) {
-                    shared_idx_ptr[d] = temp_coord % size_of_dim;
-                    temp_coord /= size_of_dim;
-                } else {
-                    shared_idx_ptr[d] = 0;
-                }
-            }
-            shared_idx_ptr[process_dim_sample] = i;
-        } else {
-            f[i] = 1e20f;
-            for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1;
-        }
-    }
-    __syncthreads();
-
-    // Build the lower envelope
-    if (threadIdx.x == 0) {
-        int k = 0;
-        v[0] = 0;
-        z[0] = -1e20f;
-        z[1] = 1e20f;
-
-        for (int q = 1; q < N; q++) {
-            float fq = f[q];
-            int q_sq = q * q;
-
-            while (k >= 0) {
-                int p = v[k];
-                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
-                if (s > z[k]) {
-                    k++;
-                    v[k] = q;
-                    z[k] = s;
-                    z[k + 1] = 1e20f;
-                    break;
-                }
-                k--;
-                if (k < 0) {
-                    k = 0;
-                    v[0] = q;
-                    z[0] = -1e20f;
-                    z[1] = 1e20f;
-                    break;
-                }
-            }
-        }
-    }
-    __syncthreads();
-
-    // Compute distances - this is the last pass, take the square root directly
-    for (int q = threadIdx.x; q < N; q += blockDim.x) {
-        int k = 0;
-        float q_float = (float)q;
-        while (z[k + 1] < q_float) k++;
-
-        int p = v[k];
-        int64_t global_offset = base_offset + q * stride;
-        float dist_sq = (float)(q - p) * (q - p) + f[p];
-
-        out_dist[global_offset] = sqrtf(dist_sq);  // square root taken here
-
-        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
-        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
-        for (int d = 0; d < sample_ndim; ++d) {
-            out_idx_ptr[d] = src_idx_ptr[d];
-        }
-    }
-}
+    int64_t current_offset = batch_idx * strides[0];

-// Kernel 2: first pass, but not the last
-__global__ void edt_kernel_first_only(
-    const float* in_data,
-    float* out_dist,
-    int32_t* out_idx,
-    const int64_t* shape,
-    const int64_t* strides,
-    int32_t ndim,
-    int32_t process_dim_sample,
-    int64_t total_slices,
-    int64_t num_slices_per_sample
-) {
-    int64_t slice_idx = blockIdx.x;
-    if (slice_idx >= total_slices) return;
-
-    int64_t batch_idx = slice_idx / num_slices_per_sample;
-    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
-    int64_t batch_offset = batch_idx * strides[0];
-    int64_t sample_base_offset = 0;
+    // Precompute the base coordinates (every dim except process_dim)
+    int32_t base_coords[MAX_DIMS];
     int64_t temp_idx = slice_idx_in_sample;
     const int sample_ndim = ndim - 1;

+    // Recover the coordinates from slice_idx
     for (int32_t d = sample_ndim - 1; d >= 0; --d) {
-        if (d == process_dim_sample) continue;
-        int64_t size_of_dim = shape[d + 1];
-        if (size_of_dim > 0) {
-            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
-            temp_idx /= size_of_dim;
+        if (d == process_dim_sample) {
+            base_coords[d] = 0;  // placeholder
+            continue;
         }
+        int64_t size_of_dim = shape[d + 1];
+        int32_t coord = (int32_t)(temp_idx % size_of_dim);
+        base_coords[d] = coord;
+        current_offset += coord * strides[d + 1];
+        temp_idx /= size_of_dim;
    }

     const int64_t process_dim_actual = process_dim_sample + 1;
     const int64_t N = shape[process_dim_actual];
     const int64_t stride = strides[process_dim_actual];
-    const int64_t base_offset = batch_offset + sample_base_offset;
-
+
     if (N == 0) return;

     extern __shared__ char s_buffer[];
     float* f = (float*)s_buffer;
     int* v = (int*)(f + N);
     float* z = (float*)((char*)v + (N + 1) * sizeof(int));
-    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));

+    // Phase 1: load data
     for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
-        int64_t global_offset = base_offset + i * stride;
-        float val = __ldg(&in_data[global_offset]);
-        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
-
-        if (val == 0.0f) {
-            f[i] = 0.0f;
-            int64_t temp_coord = slice_idx_in_sample;
-            for (int32_t d = sample_ndim - 1; d >= 0; --d) {
-                if (d == process_dim_sample) continue;
-                int64_t size_of_dim = shape[d + 1];
-                if (size_of_dim > 0) {
-                    shared_idx_ptr[d] = temp_coord % size_of_dim;
-                    temp_coord /= size_of_dim;
-                } else {
-                    shared_idx_ptr[d] = 0;
-                }
-            }
-            shared_idx_ptr[process_dim_sample] = i;
-        } else {
-            f[i] = 1e20f;
-            for (int d = 0; d < sample_ndim; ++d) shared_idx_ptr[d] = -1;
-        }
+        float val = __ldg(&in_data[current_offset + i * stride]);
+        f[i] = (val == 0.0f) ? 0.0f : INF_VAL;
    }
     __syncthreads();

+    // Phase 2: build the lower envelope
     if (threadIdx.x == 0) {
         int k = 0;
         v[0] = 0;
-        z[0] = -1e20f;
-        z[1] = 1e20f;
+        z[0] = -INF_VAL;
+        z[1] = INF_VAL;

         for (int q = 1; q < N; q++) {
+            // Explicitly skip background points so INF does not pollute the computation
+            if (f[q] >= (INF_VAL * 0.9f)) continue;
+
             float fq = f[q];
-            int q_sq = q * q;
-
-            while (k >= 0) {
-                int p = v[k];
-                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
-                if (s > z[k]) {
-                    k++;
-                    v[k] = q;
-                    z[k] = s;
-                    z[k + 1] = 1e20f;
-                    break;
-                }
-                k--;
-                if (k < 0) {
-                    k = 0;
-                    v[0] = q;
-                    z[0] = -1e20f;
-                    z[1] = 1e20f;
+            int k_curr = k;
+            while (k_curr >= 0) {
+                int p = v[k_curr];
+
+                // --- Core fix: numerically stable intersection formula ---
+                // Compute the differences first, then add, so large values do not swallow small ones
+                float diff_f = fq - f[p];
+                float diff_sq = (float)q*(float)q - (float)p*(float)p;
+                float s = (diff_f + diff_sq) / (2.0f * (float)(q - p));
+
+                if (s > z[k_curr]) {
+                    k_curr++;
+                    v[k_curr] = q;
+                    z[k_curr] = s;
+                    z[k_curr + 1] = INF_VAL;
+                    k = k_curr;
                     break;
                 }
+                k_curr--;
+            }
+            if (k_curr < 0) {
+                k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL;
            }
        }
    }
     __syncthreads();

+    // Phase 3: compute distances
     for (int q = threadIdx.x; q < N; q += blockDim.x) {
         int k = 0;
         float q_float = (float)q;
         while (z[k + 1] < q_float) k++;
-
         int p = v[k];
-        int64_t global_offset = base_offset + q * stride;
-        float dist_sq = (float)(q - p) * (q - p) + f[p];
+
+        int64_t global_idx = current_offset + q * stride;
+        float dist_sq = sqr(q_float - (float)p) + f[p];

-        out_dist[global_offset] = dist_sq;  // no square root yet
+        out_dist[global_idx] = IsFinal ? sqrtf(dist_sq) : dist_sq;

-        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
-        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        // Write the indices
+        int32_t* out_idx_ptr = out_idx + global_idx * sample_ndim;
+        #pragma unroll
         for (int d = 0; d < sample_ndim; ++d) {
-            out_idx_ptr[d] = src_idx_ptr[d];
+            // Only the dimension being processed gets p; the others get the base coordinates
+            out_idx_ptr[d] = (d == process_dim_sample) ? p : base_coords[d];
        }
    }
}

-// Kernel 3: intermediate passes
-__global__ void edt_kernel_middle(
-    const float* in_dist,
-    const int32_t* in_idx,
-    float* out_dist,
-    int32_t* out_idx,
-    const int64_t* shape,
-    const int64_t* strides,
+// ------------------------------------------------------------------
+// Kernel 2: subsequent passes
+// ------------------------------------------------------------------
+template
+__global__ void edt_kernel_subsequent_pass(
+    const float* __restrict__ in_dist,
+    const int32_t* __restrict__ in_idx,
+    float* __restrict__ out_dist,
+    int32_t* __restrict__ out_idx,
+    const int64_t* __restrict__ shape,
+    const int64_t* __restrict__ strides,
     int32_t ndim,
     int32_t process_dim_sample,
     int64_t total_slices,
@@ -272,24 +151,21 @@

     int64_t batch_idx = slice_idx / num_slices_per_sample;
     int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
-    int64_t batch_offset = batch_idx * strides[0];
-    int64_t sample_base_offset = 0;
+    int64_t current_offset = batch_idx * strides[0];
+
     int64_t temp_idx = slice_idx_in_sample;
     const int sample_ndim = ndim - 1;

     for (int32_t d = sample_ndim - 1; d >= 0; --d) {
         if (d == process_dim_sample) continue;
         int64_t size_of_dim = shape[d + 1];
-        if (size_of_dim > 0) {
-            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
-            temp_idx /= size_of_dim;
-        }
+        current_offset += (temp_idx % size_of_dim) * strides[d + 1];
+        temp_idx /= size_of_dim;
    }

     const int64_t process_dim_actual = process_dim_sample + 1;
     const int64_t N = shape[process_dim_actual];
     const int64_t stride = strides[process_dim_actual];
-    const int64_t base_offset = batch_offset + sample_base_offset;

     if (N == 0) return;

@@ -297,49 +173,33 @@
     float* f = (float*)s_buffer;
     int* v = (int*)(f + N);
     float* z = (float*)((char*)v + (N + 1) * sizeof(int));
-    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));

     for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
-        int64_t global_offset = base_offset + i * stride;
-        f[i] = __ldg(&in_dist[global_offset]);
-
-        const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim;
-        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
-        for (int d = 0; d < sample_ndim; ++d) {
-            shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]);
-        }
+        f[i] = __ldg(&in_dist[current_offset + i * stride]);
    }
     __syncthreads();

     if (threadIdx.x == 0) {
         int k = 0;
-        v[0] = 0;
-        z[0] = -1e20f;
-        z[1] = 1e20f;
+        v[0] = 0; z[0] = -INF_VAL; z[1] = INF_VAL;

         for (int q = 1; q < N; q++) {
+            if (f[q] >= (INF_VAL * 0.9f)) continue;
+
             float fq = f[q];
-            int q_sq = q * q;
-
-            while (k >= 0) {
-                int p = v[k];
-                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
-                if (s > z[k]) {
-                    k++;
-                    v[k] = q;
-                    z[k] = s;
-                    z[k + 1] = 1e20f;
-                    break;
-                }
-                k--;
-                if (k < 0) {
-                    k = 0;
-                    v[0] = q;
-                    z[0] = -1e20f;
-                    z[1] = 1e20f;
-                    break;
+            int k_curr = k;
+            while (k_curr >= 0) {
+                int p = v[k_curr];
+                float diff_f = fq - f[p];
+                float diff_sq = (float)q*(float)q - (float)p*(float)p;
+                float s = (diff_f + diff_sq) / (2.0f * (float)(q - p));
+                if (s > z[k_curr]) {
+                    k_curr++; v[k_curr] = q; z[k_curr] = s; z[k_curr + 1] = INF_VAL;
+                    k = k_curr; break;
                 }
+                k_curr--;
            }
+            if (k_curr < 0) { k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL; }
        }
    }
     __syncthreads();
@@ -350,260 +210,140 @@
         while (z[k + 1] < q_float) k++;

         int p = v[k];
-        int64_t global_offset = base_offset + q * stride;
-        float dist_sq = (float)(q - p) * (q - p) + f[p];
-
-        out_dist[global_offset] = dist_sq;  // no square root yet
-
-        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
-        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
-        for (int d = 0; d < sample_ndim; ++d) {
-            out_idx_ptr[d] = src_idx_ptr[d];
-        }
-    }
-}
-
-// Kernel 4: the final pass
-__global__ void edt_kernel_final(
-    const float* in_dist,
-    const int32_t* in_idx,
-    float* out_dist,
-    int32_t* out_idx,
-    const int64_t* shape,
-    const int64_t* strides,
-    int32_t ndim,
-    int32_t process_dim_sample,
-    int64_t total_slices,
-    int64_t num_slices_per_sample
-) {
-    int64_t slice_idx = blockIdx.x;
-    if (slice_idx >= total_slices) return;
-
-    int64_t batch_idx = slice_idx / num_slices_per_sample;
-    int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
-    int64_t batch_offset = batch_idx * strides[0];
-    int64_t sample_base_offset = 0;
-    int64_t temp_idx = slice_idx_in_sample;
-    const int sample_ndim = ndim - 1;
-
-    for (int32_t d = sample_ndim - 1; d >= 0; --d) {
-        if (d == process_dim_sample) continue;
-        int64_t size_of_dim = shape[d + 1];
-        if (size_of_dim > 0) {
-            sample_base_offset += (temp_idx % size_of_dim) * strides[d + 1];
-            temp_idx /= size_of_dim;
-        }
-    }
-
-    const int64_t process_dim_actual = process_dim_sample + 1;
-    const int64_t N = shape[process_dim_actual];
-    const int64_t stride = strides[process_dim_actual];
-    const int64_t base_offset = batch_offset + sample_base_offset;
-
-    if (N == 0) return;
-
-    extern __shared__ char s_buffer[];
-    float* f = (float*)s_buffer;
-    int* v = (int*)(f + N);
-    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
-    int32_t* s_idx = (int32_t*)((char*)z + (N + 2) * sizeof(float));
-
-    for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
-        int64_t global_offset = base_offset + i * stride;
-        f[i] = __ldg(&in_dist[global_offset]);
-        const int32_t* global_idx_ptr = in_idx + global_offset * sample_ndim;
-        int32_t* shared_idx_ptr = s_idx + i * sample_ndim;
-        for (int d = 0; d < sample_ndim; ++d) {
-            shared_idx_ptr[d] = __ldg(&global_idx_ptr[d]);
-        }
-    }
-    __syncthreads();
-
-    if (threadIdx.x == 0) {
-        int k = 0;
-        v[0] = 0;
-        z[0] = -1e20f;
-        z[1] = 1e20f;
-
-        for (int q = 1; q < N; q++) {
-            float fq = f[q];
-            int q_sq = q * q;
-
-            while (k >= 0) {
-                int p = v[k];
-                float s = ((fq + q_sq) - (f[p] + p * p)) / (2.0f * (q - p));
-                if (s > z[k]) {
-                    k++;
-                    v[k] = q;
-                    z[k] = s;
-                    z[k + 1] = 1e20f;
-                    break;
-                }
-                k--;
-                if (k < 0) {
-                    k = 0;
-                    v[0] = q;
-                    z[0] = -1e20f;
-                    z[1] = 1e20f;
-                    break;
-                }
-            }
-        }
-    }
-    __syncthreads();
-
-    for (int q = threadIdx.x; q < N; q += blockDim.x) {
-        int k = 0;
-        float q_float = (float)q;
-        while (z[k + 1] < q_float) k++;
-
-        int p = v[k];
-        int64_t global_offset = base_offset + q * stride;
-        float dist_sq = (float)(q - p) * (q - p) + f[p];
-
-        out_dist[global_offset] = sqrtf(dist_sq);  // square root taken at the end
+        int64_t q_global_offset = current_offset + q * stride;
+        int64_t p_global_offset = current_offset + p * stride;

-        int32_t* out_idx_ptr = out_idx + global_offset * sample_ndim;
-        const int32_t* src_idx_ptr = s_idx + p * sample_ndim;
+        float dist_sq = sqr(q_float - (float)p) + f[p];
+        out_dist[q_global_offset] = IsFinal ? sqrtf(dist_sq) : dist_sq;

+        // Indices are copied straight from global memory; no shared memory needed
+        const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim;
+        int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim;
+
+        #pragma unroll
         for (int d = 0; d < sample_ndim; ++d) {
-            out_idx_ptr[d] = src_idx_ptr[d];
+            out_idx_ptr[d] = src_idx_ptr[d];
        }
    }
}

-// Host function
+// ------------------------------------------------------------------
+// Host function
+// ------------------------------------------------------------------
 std::tuple distance_transform_cuda(torch::Tensor input) {
-    TORCH_CHECK(input.is_cuda(), "Input must be on a CUDA device.");
-    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be a float tensor.");
+    TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device.");
+    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32.");
+
     input = input.contiguous();
-    bool had_no_batch_dim = (input.dim() <= 2);
-    if (had_no_batch_dim) { input = input.unsqueeze(0); }
+
+    // Automatically handle 1D input: (L) -> (1, L)
+    // Automatically handle batched 1D: (N, L) stays unchanged (treated as N 1D samples)
+    bool had_no_batch_dim = (input.dim() == 1);
+    if (had_no_batch_dim) {
+        input = input.unsqueeze(0);
+    }

     const auto ndim = input.dim();
     const auto sample_ndim = ndim - 1;
     const auto batch_size = input.size(0);
     auto shape = input.sizes().vec();
     auto strides_vec = input.strides().vec();

     if (input.numel() == 0) {
-        auto distance = torch::empty_like(input);
         auto index_shape = shape;
         index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1);
-        auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32));
-        if (had_no_batch_dim) return std::make_tuple(distance.squeeze(0), index.squeeze(0));
-        return std::make_tuple(distance, index);
+        return std::make_tuple(torch::empty_like(input),
+                               torch::empty(index_shape, input.options().dtype(torch::kInt32)));
    }
-
+
     auto distance = torch::empty_like(input);
-    auto index_options = input.options().dtype(torch::kInt32);
     auto index_shape = shape;
-    index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1);
-    auto index = torch::empty(index_shape, index_options);
-
-    if (torch::all(input != 0).item()) {
-        distance.fill_(std::numeric_limits::infinity());
-        index.fill_(-1);
-        if (had_no_batch_dim) {
-            return std::make_tuple(distance.squeeze(0), index.squeeze(0));
-        }
-        return std::make_tuple(distance, index);
-    }
+    index_shape.push_back(sample_ndim);
+    auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32));

-    auto shape_tensor = torch::tensor(shape,
-        torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
-    auto strides_tensor = torch::tensor(strides_vec,
-        torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
-
-    const int64_t* shape_gpu = shape_tensor.data_ptr();
-    const int64_t* strides_gpu = strides_tensor.data_ptr();
+    auto buffer_dist = torch::empty_like(distance);
+    auto buffer_idx = torch::empty_like(index);
+
+    auto shape_tensor = torch::tensor(shape, torch::TensorOptions().dtype(torch::kInt64).device(input.device()));
+    auto strides_tensor = torch::tensor(strides_vec, torch::TensorOptions().dtype(torch::kInt64).device(input.device()));

     std::vector> dim_order_pairs;
-    for (int32_t d_sample = 0; d_sample < sample_ndim; ++d_sample) {
-        dim_order_pairs.push_back({strides_vec[d_sample + 1], d_sample});
+    for (int32_t d = 0; d < sample_ndim; ++d) {
+        dim_order_pairs.push_back({strides_vec[d + 1], d});
    }
     std::sort(dim_order_pairs.rbegin(), dim_order_pairs.rend());

-    if (sample_ndim == 0) {
-        int64_t total_slices = batch_size;
-        int64_t slice_len = (shape.size() > 1) ? shape[1] : 0;
-        int threads = std::min((int64_t)256, slice_len);
-        size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) +
-                      (slice_len + 2) * sizeof(float) + slice_len * 1 * sizeof(int32_t);
-
-        edt_kernel_first_final<<>>(
-            input.data_ptr(),
-            distance.data_ptr(), index.data_ptr(),
-            shape_gpu, strides_gpu, ndim, 0, total_slices, 1
-        );
-    } else {
-        auto buffer_dist = torch::empty_like(distance);
-        auto buffer_idx = torch::empty_like(index);
-
-        for (int pass = 0; pass < sample_ndim; ++pass) {
-            int32_t d_sample = dim_order_pairs[pass].second;
-            bool is_first = (pass == 0);
-            bool is_final = (pass == sample_ndim - 1);
-
-            torch::Tensor *in_dist, *in_idx, *out_dist, *out_idx;
-
-            if (pass % 2 == 0) {
-                in_dist = &distance; in_idx = &index;
-                out_dist = &buffer_dist; out_idx = &buffer_idx;
+    for (int pass = 0; pass < sample_ndim; ++pass) {
+        int32_t d_sample = dim_order_pairs[pass].second;
+        bool is_first_pass = (pass == 0);
+        bool is_final_pass = (pass == sample_ndim - 1);
+
+        torch::Tensor *in_d, *in_i, *out_d, *out_i;
+
+        if (is_first_pass) {
+            in_d = nullptr; in_i = nullptr;
+            out_d = is_final_pass ? &distance : &buffer_dist;
+            out_i = is_final_pass ? &index : &buffer_idx;
+        } else {
+            if (pass % 2 != 0) {
+                in_d = &buffer_dist; in_i = &buffer_idx;
+                out_d = &distance; out_i = &index;
             } else {
-                in_dist = &buffer_dist; in_idx = &buffer_idx;
-                out_dist = &distance; out_idx = &index;
+                in_d = &distance; in_i = &index;
+                out_d = &buffer_dist; out_i = &buffer_idx;
             }
-
-            int64_t num_slices_per_sample = 1;
-            for(int i = 0; i < sample_ndim; ++i) {
-                if (i != d_sample) num_slices_per_sample *= shape[i + 1];
+            if (is_final_pass) {
+                out_d = &distance; out_i = &index;
             }
-            int64_t total_slices = batch_size * num_slices_per_sample;
-            int64_t slice_len = shape[d_sample + 1];
-
-            int threads = std::min((int64_t)256, slice_len);
-            size_t smem = slice_len * sizeof(float) + (slice_len + 1) * sizeof(int) +
-                          (slice_len + 2) * sizeof(float) + slice_len * sample_ndim * sizeof(int32_t);
-
-            if (is_first && is_final) {
-                edt_kernel_first_final<<>>(
-                    input.data_ptr(),
-                    out_dist->data_ptr(), out_idx->data_ptr(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+        }
+
+        int64_t num_slices_per_sample = 1;
+        for(int i = 0; i < sample_ndim; ++i) {
+            if (i != d_sample) num_slices_per_sample *= shape[i + 1];
+        }
+        int64_t total_slices = batch_size * num_slices_per_sample;
+        int64_t slice_len = shape[d_sample + 1];
+
+        int threads = std::min((int64_t)256, slice_len);
+        size_t smem = slice_len * (sizeof(float) + sizeof(int)) + (slice_len + 1) * sizeof(float);
+
+        if (is_first_pass) {
+            const float* in_ptr = input.data_ptr();
+            if (is_final_pass) {
+                edt_kernel_first_pass<<>>(
+                    in_ptr, out_d->data_ptr(), out_i->data_ptr(),
+                    shape_tensor.data_ptr(), strides_tensor.data_ptr(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
-            } else if (is_first) {
-                edt_kernel_first_only<<>>(
-                    input.data_ptr(),
-                    out_dist->data_ptr(), out_idx->data_ptr(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+            } else {
+                edt_kernel_first_pass<<>>(
+                    in_ptr, out_d->data_ptr(), out_i->data_ptr(),
+                    shape_tensor.data_ptr(), strides_tensor.data_ptr(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
-            } else if (is_final) {
-                edt_kernel_final<<>>(
-                    in_dist->data_ptr(), in_idx->data_ptr(),
-                    out_dist->data_ptr(), out_idx->data_ptr(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+            }
+        } else {
+            if (is_final_pass) {
+                edt_kernel_subsequent_pass<<>>(
+                    in_d->data_ptr(), in_i->data_ptr(),
+                    out_d->data_ptr(), out_i->data_ptr(),
+                    shape_tensor.data_ptr(), strides_tensor.data_ptr(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
             } else {
-                edt_kernel_middle<<>>(
-                    in_dist->data_ptr(), in_idx->data_ptr(),
-                    out_dist->data_ptr(), out_idx->data_ptr(),
-                    shape_gpu, strides_gpu, ndim, d_sample, total_slices, num_slices_per_sample
+                edt_kernel_subsequent_pass<<>>(
+                    in_d->data_ptr(), in_i->data_ptr(),
+                    out_d->data_ptr(), out_i->data_ptr(),
+                    shape_tensor.data_ptr(), strides_tensor.data_ptr(),
+                    ndim, d_sample, total_slices, num_slices_per_sample
                 );
             }
        }
-
-        if (sample_ndim % 2 != 0) {
-            distance.copy_(buffer_dist);
-            index.copy_(buffer_idx);
-        }
    }

     if (had_no_batch_dim) {
         return std::make_tuple(distance.squeeze(0), index.squeeze(0));
    }
     return std::make_tuple(distance, index);
} 
\ No newline at end of file
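[Editor's note, not part of the patch: the next patch swaps the serial envelope construction for a jump-flooding/doubling scheme. A rough pure-Python sketch of that 1D propagation, mirroring the kernel's idea; the names jfa_1d/cost and the finite sentinel INF are assumptions for illustration, and this runs serially where the kernel updates all cells in parallel per step.]

    import numpy as np

    INF = 1e10

    def cost(q, p, vals):
        # cost of claiming cell q from source p, counting p's own weight
        return (q - p) ** 2 + vals[p] if p >= 0 else INF

    def jfa_1d(vals):
        n = len(vals)
        # cells with ~INF weight start with no source (-1); others point to themselves
        idx = np.array([i if vals[i] < INF * 0.9 else -1 for i in range(n)])
        step = 1
        while step < n:
            nxt = idx.copy()
            for i in range(n):
                best, best_cost = idx[i], cost(i, idx[i], vals)
                for j in (i - step, i + step):       # consult neighbors at distance `step`
                    if 0 <= j < n and idx[j] != -1:
                        c = cost(i, idx[j], vals)
                        if c < best_cost:
                            best, best_cost = idx[j], c
                nxt[i] = best
            idx = nxt
            step *= 2
        return idx

    row = np.array([INF, INF, 0.0, INF, INF, 0.0])
    idx = jfa_1d(row)
    print(idx)                                       # [2 2 2 2 5 5]
    print([(q - p) ** 2 for q, p in enumerate(idx)]) # squared distances [4, 1, 0, 1, 1, 0]

Each doubling step only reads the previous step's buffer, which is why the kernel keeps two index buffers and ping-pongs between them: log2(N) fully parallel sweeps instead of one thread walking the whole row.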
From 54464a87aed9492434c125e599f41bea87866aef Mon Sep 17 00:00:00 2001
From: Yuhandeng
Date: Mon, 8 Dec 2025 20:26:03 +0800
Subject: [PATCH 18/39] Adopt the JFA algorithm to improve parallelism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torchmorph/csrc/distance_transform_kernel.cu | 290 ++++++++++++-------
 1 file changed, 184 insertions(+), 106 deletions(-)

diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index c3ded1e..64c69fb 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -4,14 +4,126 @@
 #include
 #include
 #include
+#include
+#include

 #define MAX_DIMS 10
-#define INF_VAL 1e8f  // use 1e8 to keep float32 arithmetic numerically stable
+#define INF_VAL 1e8f
+// Ensure blockDim covers most common dimension sizes, or fall back to the loop
+#define MAX_THREADS 1024

 __device__ __forceinline__ float sqr(float x) { return x * x; }

+// Cost of going from pixel q to source point p (includes p's own weight val_p)
+__device__ __forceinline__ float compute_cost(int q, int p, float val_p) {
+    if (p < 0) return INF_VAL;  // invalid point
+    // dist = (q - p)^2 + f[p]
+    return sqr((float)q - (float)p) + val_p;
+}
+
 // ------------------------------------------------------------------
-// Kernel 1: initial pass (first pass)
+// Core logic: 1D Jump Flooding (JFA) / doubling algorithm
+// Solves for nearest-source indices fully in parallel, replacing the serial parabola construction
+// ------------------------------------------------------------------
+__device__ void compute_1d_jfa(
+    int N,
+    float* __restrict__ s_vals,      // input values (dist^2)
+    int* __restrict__ s_idx_curr,    // ping-pong buffer 1
+    int* __restrict__ s_idx_next     // ping-pong buffer 2
+) {
+    int tid = threadIdx.x;
+
+    // --- 1. Initialization ---
+    // Each thread initializes one or more pixels
+    for (int i = tid; i < N; i += blockDim.x) {
+        // A very large value here means background: no initial source point (-1);
+        // otherwise the source is the pixel itself (i)
+        if (s_vals[i] >= INF_VAL * 0.9f) {
+            s_idx_curr[i] = -1;
+        } else {
+            s_idx_curr[i] = i;
+        }
+    }
+    __syncthreads();
+
+    // --- 2. Iterative propagation (step = 1, 2, 4, 8...) ---
+    // Similar in spirit to bitonic sort or binary lifting
+    int* idx_in = s_idx_curr;
+    int* idx_out = s_idx_next;
+
+    // Propagation continues while the step is smaller than N;
+    // for N=1024 only 10 iterations are needed, each fully parallel across all threads
+    for (int step = 1; step < N; step *= 2) {
+
+        for (int i = tid; i < N; i += blockDim.x) {
+            int my_best_p = idx_in[i];
+            float min_cost = INF_VAL;
+
+            // Cost of the current best candidate
+            if (my_best_p != -1) {
+                min_cost = compute_cost(i, my_best_p, s_vals[my_best_p]);
+            }
+
+            // --- Check the left neighbor (i - step) ---
+            int left = i - step;
+            if (left >= 0) {
+                int left_p = idx_in[left];  // the source proposed by the neighbor
+                if (left_p != -1) {
+                    float c = compute_cost(i, left_p, s_vals[left_p]);
+                    if (c < min_cost) {
+                        min_cost = c;
+                        my_best_p = left_p;
+                    }
+                }
+            }
+
+            // --- Check the right neighbor (i + step) ---
+            int right = i + step;
+            if (right < N) {
+                int right_p = idx_in[right];  // the source proposed by the neighbor
+                if (right_p != -1) {
+                    float c = compute_cost(i, right_p, s_vals[right_p]);
+                    if (c < min_cost) {
+                        min_cost = c;
+                        my_best_p = right_p;
+                    }
+                }
+            }
+
+            // Write into the next round's buffer
+            idx_out[i] = my_best_p;
+        }
+
+        // Swap the buffer pointers
+        int* temp = idx_in;
+        idx_in = idx_out;
+        idx_out = temp;
+
+        __syncthreads();
+    }
+
+    // --- 3. Write back the result ---
+    // If the final result sits in s_idx_next (odd number of iterations), it must be copied
+    // back to s_idx_curr, or the caller must be told where it lives.
+    // For simplicity, we always leave the result in the memory s_idx_curr points to.
+    // Note: idx_in now points at the buffer holding the latest result.
+
+    // If idx_in already points at s_idx_curr, nothing needs to move.
+    // If idx_in points at s_idx_next, the latest result is there, and we copy it back to
+    // s_idx_curr (the alternative would be adjusting the pointer the follow-up code reads).
+
+    // We use the simple copy back to s_idx_curr so the subsequent logic stays uniform
+    if (idx_in != s_idx_curr) {
+        for (int i = tid; i < N; i += blockDim.x) {
+            s_idx_curr[i] = s_idx_next[i];
+        }
+        __syncthreads();
+    }
+}
+
+
+// ------------------------------------------------------------------
+// Kernel 1: initial pass (JFA version)
 // ------------------------------------------------------------------
 template
 __global__ void edt_kernel_first_pass(
@@ -32,15 +144,13 @@
     int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample;
     int64_t current_offset = batch_idx * strides[0];

-    // Precompute the base coordinates (every dim except process_dim)
     int32_t base_coords[MAX_DIMS];
     int64_t temp_idx = slice_idx_in_sample;
     const int sample_ndim = ndim - 1;

-    // Recover the coordinates from slice_idx
     for (int32_t d = sample_ndim - 1; d >= 0; --d) {
         if (d == process_dim_sample) {
-            base_coords[d] = 0;  // placeholder
+            base_coords[d] = 0;
             continue;
         }
         int64_t size_of_dim = shape[d + 1];
@@ -56,82 +166,58 @@

     if (N == 0) return;

+    // Shared memory layout:
+    //   f:    float[N] (values)
+    //   idx1: int[N]   (buffer 1)
+    //   idx2: int[N]   (buffer 2)
     extern __shared__ char s_buffer[];
     float* f = (float*)s_buffer;
-    int* v = (int*)(f + N);
-    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int* idx1 = (int*)(f + N);
+    int* idx2 = (int*)(idx1 + N);

-    // Phase 1: load data
+    // Phase 1: load data in parallel
     for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
         float val = __ldg(&in_data[current_offset + i * stride]);
         f[i] = (val == 0.0f) ? 0.0f : INF_VAL;
    }
     __syncthreads();

-    // Phase 2: build the lower envelope
-    if (threadIdx.x == 0) {
-        int k = 0;
-        v[0] = 0;
-        z[0] = -INF_VAL;
-        z[1] = INF_VAL;
-
-        for (int q = 1; q < N; q++) {
-            // Explicitly skip background points so INF does not pollute the computation
-            if (f[q] >= (INF_VAL * 0.9f)) continue;
-
-            float fq = f[q];
-            int k_curr = k;
-            while (k_curr >= 0) {
-                int p = v[k_curr];
-
-                // --- Core fix: numerically stable intersection formula ---
-                // Compute the differences first, then add, so large values do not swallow small ones
-                float diff_f = fq - f[p];
-                float diff_sq = (float)q*(float)q - (float)p*(float)p;
-                float s = (diff_f + diff_sq) / (2.0f * (float)(q - p));
-
-                if (s > z[k_curr]) {
-                    k_curr++;
-                    v[k_curr] = q;
-                    z[k_curr] = s;
-                    z[k_curr + 1] = INF_VAL;
-                    k = k_curr;
-                    break;
-                }
-                k_curr--;
-            }
-            if (k_curr < 0) {
-                k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL;
-            }
-        }
-    }
-    __syncthreads();
+    // Phase 2: parallel JFA computation
+    compute_1d_jfa(N, f, idx1, idx2);
+    // The result now lives in idx1

-    // Phase 3: compute distances
+    // Phase 3: write back in parallel
     for (int q = threadIdx.x; q < N; q += blockDim.x) {
-        int k = 0;
-        float q_float = (float)q;
-        while (z[k + 1] < q_float) k++;
-        int p = v[k];
-
+        int p = idx1[q];
+        float dist_val;
+        int p_idx;
+
+        if (p != -1) {
+            // JFA yields the index p of the nearest source point.
+            // distance = (q-p)^2 + f[p]
+            // Note: in the first pass f[p] is either 0 or INF, and if p != -1 then f[p] must be 0.
+            // We still add f[p] for generality.
+            float dist_sq = sqr((float)q - (float)p) + f[p];
+            dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq;
+            p_idx = p;
+        } else {
+            dist_val = IsFinal ? INF_VAL : sqr(INF_VAL);
+            p_idx = 0;
+        }
+
         int64_t global_idx = current_offset + q * stride;
-        float dist_sq = sqr(q_float - (float)p) + f[p];
-
-        out_dist[global_idx] = IsFinal ? sqrtf(dist_sq) : dist_sq;
+        out_dist[global_idx] = dist_val;

-        // Write the indices
         int32_t* out_idx_ptr = out_idx + global_idx * sample_ndim;
         #pragma unroll
         for (int d = 0; d < sample_ndim; ++d) {
-            // Only the dimension being processed gets p; the others get the base coordinates
-            out_idx_ptr[d] = (d == process_dim_sample) ? p : base_coords[d];
+            out_idx_ptr[d] = (d == process_dim_sample) ? p_idx : base_coords[d];
        }
    }
}

 // ------------------------------------------------------------------
-// Kernel 2: subsequent passes
+// Kernel 2: subsequent passes (JFA version)
 // ------------------------------------------------------------------
 template
 __global__ void edt_kernel_subsequent_pass(
@@ -169,61 +255,48 @@

     if (N == 0) return;

+    // Same shared memory layout as above
     extern __shared__ char s_buffer[];
     float* f = (float*)s_buffer;
-    int* v = (int*)(f + N);
-    float* z = (float*)((char*)v + (N + 1) * sizeof(int));
+    int* idx1 = (int*)(f + N);
+    int* idx2 = (int*)(idx1 + N);

+    // Phase 1: load
     for (int64_t i = threadIdx.x; i < N; i += blockDim.x) {
         f[i] = __ldg(&in_dist[current_offset + i * stride]);
    }
     __syncthreads();

-    if (threadIdx.x == 0) {
-        int k = 0;
-        v[0] = 0; z[0] = -INF_VAL; z[1] = INF_VAL;
-
-        for (int q = 1; q < N; q++) {
-            if (f[q] >= (INF_VAL * 0.9f)) continue;
-
-            float fq = f[q];
-            int k_curr = k;
-            while (k_curr >= 0) {
-                int p = v[k_curr];
-                float diff_f = fq - f[p];
-                float diff_sq = (float)q*(float)q - (float)p*(float)p;
-                float s = (diff_f + diff_sq) / (2.0f * (float)(q - p));
-                if (s > z[k_curr]) {
-                    k_curr++; v[k_curr] = q; z[k_curr] = s; z[k_curr + 1] = INF_VAL;
-                    k = k_curr; break;
-                }
-                k_curr--;
-            }
-            if (k_curr < 0) { k = 0; v[0] = q; z[0] = -INF_VAL; z[1] = INF_VAL; }
-        }
-    }
-    __syncthreads();
+    // Phase 2: parallel JFA computation
+    // Here f[i] is the squared distance produced by the previous pass, used as the weight
+    compute_1d_jfa(N, f, idx1, idx2);

+    // Phase 3: write back
     for (int q = threadIdx.x; q < N; q += blockDim.x) {
-        int k = 0;
-        float q_float = (float)q;
-        while (z[k + 1] < q_float) k++;
-
-        int p = v[k];
-
+        int p = idx1[q];  // index of the nearest source within this row
+        float dist_val;
+
+        if (p != -1) {
+            float dist_sq = sqr((float)q - (float)p) + f[p];
+            dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq;
+        } else {
+            dist_val = IsFinal ? INF_VAL : sqr(INF_VAL);
+            p = 0;  // fallback
+        }
+
         int64_t q_global_offset = current_offset + q * stride;
-        int64_t p_global_offset = current_offset + p * stride;
-
-        float dist_sq = sqr(q_float - (float)p) + f[p];
-        out_dist[q_global_offset] = IsFinal ? sqrtf(dist_sq) : dist_sq;
+        out_dist[q_global_offset] = dist_val;

-        // Indices are copied straight from global memory; no shared memory needed
-        const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim;
-        int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim;
+        // Handle the indices
+        if (p != -1) {
+            int64_t p_global_offset = current_offset + p * stride;
+            const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim;
+            int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim;

-        #pragma unroll
-        for (int d = 0; d < sample_ndim; ++d) {
-            out_idx_ptr[d] = src_idx_ptr[d];
+            #pragma unroll
+            for (int d = 0; d < sample_ndim; ++d) {
+                out_idx_ptr[d] = src_idx_ptr[d];
+            }
        }
    }
}
@@ -237,8 +310,6 @@

     input = input.contiguous();

-    // Automatically handle 1D input: (L) -> (1, L)
-    // Automatically handle batched 1D: (N, L) stays unchanged (treated as N 1D samples)
     bool had_no_batch_dim = (input.dim() == 1);
     if (had_no_batch_dim) {
         input = input.unsqueeze(0);
@@ -305,8 +376,15 @@
         int64_t total_slices = batch_size * num_slices_per_sample;
         int64_t slice_len = shape[d_sample + 1];

-        int threads = std::min((int64_t)256, slice_len);
-        size_t smem = slice_len * (sizeof(float) + sizeof(int)) + (slice_len + 1) * sizeof(float);
+        int threads = std::min((int64_t)MAX_THREADS, slice_len);
+
+        // Shared memory needed by JFA:
+        //   float f[N]
+        //   int idx1[N]
+        //   int idx2[N]
+        // In total slice_len * (4 + 4 + 4) = 12 * slice_len bytes
+        size_t smem = slice_len * sizeof(float) +
+                      slice_len * sizeof(int) * 2;

         if (is_first_pass) {
             const float* in_ptr = input.data_ptr();
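[Editor's note, not part of the patch: the consolidation below keys its shared-memory fast path on the per-pixel budget the comments describe. A quick back-of-the-envelope check, assuming the common 48KB default static shared-memory limit per block (the actual limit varies by GPU and launch configuration):]

    # one float value + two int ping-pong indices per pixel
    bytes_per_pixel = 4 + 4 + 4
    print(48 * 1024 // bytes_per_pixel)  # 4096 -> rows up to 4096 elements fit in shared memory

Rows longer than that spill to the global-memory fallback, which is what SMEM_LIMIT_ELEMENTS gates.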
sqrtf(dist_sq) : dist_sq; + out_dist[q_global_offset] = dist_val; - // 索引直接从 Global Memory 拷贝,无需 Shared Memory - const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; - int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; + // 索引处理 + if (p != -1) { + int64_t p_global_offset = current_offset + p * stride; + const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; + int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; - #pragma unroll - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; + #pragma unroll + for (int d = 0; d < sample_ndim; ++d) { + out_idx_ptr[d] = src_idx_ptr[d]; + } } } } @@ -237,8 +310,6 @@ std::tuple distance_transform_cuda(torch::Tensor i input = input.contiguous(); - // 自动处理 1D 输入:(L) -> (1, L) - // 自动处理 1D 批处理:(N, L) 保持不变 (视为 N 个 1D 样本) bool had_no_batch_dim = (input.dim() == 1); if (had_no_batch_dim) { input = input.unsqueeze(0); @@ -305,8 +376,15 @@ std::tuple distance_transform_cuda(torch::Tensor i int64_t total_slices = batch_size * num_slices_per_sample; int64_t slice_len = shape[d_sample + 1]; - int threads = std::min((int64_t)256, slice_len); - size_t smem = slice_len * (sizeof(float) + sizeof(int)) + (slice_len + 1) * sizeof(float); + int threads = std::min((int64_t)MAX_THREADS, slice_len); + + // JFA 需要的 Shared Memory: + // float f[N] + // int idx1[N] + // int idx2[N] + // 总共 slice_len * (4 + 4 + 4) = 12 * slice_len bytes + size_t smem = slice_len * sizeof(float) + + slice_len * sizeof(int) * 2; if (is_first_pass) { const float* in_ptr = input.data_ptr(); From 32a09da72bc0653e7bfa6b286416f4677396144a Mon Sep 17 00:00:00 2001 From: Yuhandeng Date: Mon, 8 Dec 2025 20:27:55 +0800 Subject: [PATCH 19/39] =?UTF-8?q?=E5=90=88=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=86=85=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torchmorph/csrc/distance_transform_kernel.cu | 588 ++++++++++--------- 1 file changed, 298 insertions(+), 290 deletions(-) diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index 64c69fb..a8e5a6a 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -1,75 +1,64 @@ #include #include -#include -#include #include -#include +#include #include #include -#define MAX_DIMS 10 -#define INF_VAL 1e8f -// 保证 blockDim 足以覆盖大多数常见维度大小,或者配合 Loop 处理 -#define MAX_THREADS 1024 +#define INF_VAL 1e8f +#define MAX_THREADS 1024 +#define SMEM_LIMIT_ELEMENTS 4096 // 48KB / 12 bytes (float+int+int) ~= 4096 __device__ __forceinline__ float sqr(float x) { return x * x; } -// 计算从像素 q 到源点 p 的距离代价 (考虑了 p 点本身的数值权重 val_p) +// 计算从像素 q 到源点 p 的距离代价 +// val_p 是源点 p 在上一轮计算后的距离平方值 (weight) __device__ __forceinline__ float compute_cost(int q, int p, float val_p) { - if (p < 0) return INF_VAL; // 无效点 - // dist = (q - p)^2 + f[p] + if (p < 0) return INF_VAL; return sqr((float)q - (float)p) + val_p; } // ------------------------------------------------------------------ -// 核心逻辑: 1D Jump Flooding (JFA) / Doubling Algorithm -// 全并行求解最近点索引,替代串行的抛物线构建 +// JFA 核心逻辑 (Device Function) // ------------------------------------------------------------------ -__device__ void compute_1d_jfa( +// 无论数据是在 Shared Memory 还是 Global Memory,逻辑是一样的 +__device__ void run_jfa_core( int N, - float* __restrict__ s_vals, // 输入数值 (dist^2) - int* __restrict__ s_idx_curr, // ping-pong buffer 1 - int* __restrict__ s_idx_next // ping-pong buffer 2 + int tid, + const float* __restrict__ 
vals, // 输入权重 (只读) + int* __restrict__ idx_curr, // Ping-Pong Buffer A + int* __restrict__ idx_next // Ping-Pong Buffer B ) { - int tid = threadIdx.x; - - // --- 1. 初始化 --- - // 每个线程负责一个或多个像素的初始化 + // 1. 初始化 for (int i = tid; i < N; i += blockDim.x) { - // 如果当前位置的值很大,说明它是背景,没有初始源点 (-1) - // 否则源点就是它自己 (i) - if (s_vals[i] >= INF_VAL * 0.9f) { - s_idx_curr[i] = -1; + // 如果输入值很大,说明是背景,没有初始源点 + if (vals[i] >= INF_VAL * 0.9f) { + idx_curr[i] = -1; } else { - s_idx_curr[i] = i; + idx_curr[i] = i; } } __syncthreads(); - // --- 2. 迭代传播 (Step = 1, 2, 4, 8...) --- - // 类似于双调排序或倍增法 - int* idx_in = s_idx_curr; - int* idx_out = s_idx_next; + // 2. 迭代传播 (Step = 1, 2, 4, ... < N) + int* idx_in = idx_curr; + int* idx_out = idx_next; - // 只要步长小于 N,就需要传播 - // 对于 N=1024, 只需要 10 次迭代,每次所有线程全并行 for (int step = 1; step < N; step *= 2) { - for (int i = tid; i < N; i += blockDim.x) { int my_best_p = idx_in[i]; float min_cost = INF_VAL; - // 获取当前最优点的代价 if (my_best_p != -1) { - min_cost = compute_cost(i, my_best_p, s_vals[my_best_p]); + min_cost = compute_cost(i, my_best_p, vals[my_best_p]); } - // --- 检查左边邻居 (i - step) --- + // Check Left int left = i - step; if (left >= 0) { - int left_p = idx_in[left]; // 邻居推荐的源点 + int left_p = idx_in[left]; if (left_p != -1) { - float c = compute_cost(i, left_p, s_vals[left_p]); + float c = compute_cost(i, left_p, vals[left_p]); if (c < min_cost) { min_cost = c; my_best_p = left_p; @@ -77,230 +66,206 @@ __device__ void compute_1d_jfa( } } - // --- 检查右边邻居 (i + step) --- + // Check Right int right = i + step; if (right < N) { - int right_p = idx_in[right]; // 邻居推荐的源点 + int right_p = idx_in[right]; if (right_p != -1) { - float c = compute_cost(i, right_p, s_vals[right_p]); + float c = compute_cost(i, right_p, vals[right_p]); if (c < min_cost) { min_cost = c; my_best_p = right_p; } } } - - // 写入下一轮 Buffer idx_out[i] = my_best_p; } - // 交换 Buffer 指针 + // Swap Pointers int* temp = idx_in; idx_in = idx_out; idx_out = temp; - __syncthreads(); } - // --- 3. 结果写回 --- - // 如果最后结果在 s_idx_next 里 (循环次数是奇数),需要拷回 s_idx_curr - // 或者直接让调用者知道结果在哪。 - // 为了简单,我们统一把结果放在 s_idx_curr 指向的内存里。 - // 注意:idx_in 现在指向的是包含最新结果的 buffer。 - - // 如果 idx_in 已经指向 s_idx_curr,那不用动。 - // 如果 idx_in 指向 s_idx_next,说明最新结果在 s_idx_next,我们需要把它拷贝回 s_idx_curr - // 或者是调整后续代码读取的指针。 - - // 这里采用简单拷贝回 s_idx_curr 的方式,确保后续逻辑一致 - if (idx_in != s_idx_curr) { + // 3. 确保最终结果在 idx_curr (如果循环结束时在 next,则拷回) + if (idx_in != idx_curr) { for (int i = tid; i < N; i += blockDim.x) { - s_idx_curr[i] = s_idx_next[i]; + idx_curr[i] = idx_next[i]; } __syncthreads(); } } - // ------------------------------------------------------------------ -// 内核 1: 初始 Pass (JFA Version) +// Kernel 1: Shared Memory JFA (Fast Path) +// 适用于 N <= 4096 // ------------------------------------------------------------------ -template -__global__ void edt_kernel_first_pass( - const float* __restrict__ in_data, - float* __restrict__ out_dist, - int32_t* __restrict__ out_idx, - const int64_t* __restrict__ shape, - const int64_t* __restrict__ strides, - int32_t ndim, - int32_t process_dim_sample, - int64_t total_slices, - int64_t num_slices_per_sample +template +__global__ void edt_kernel_shared( + const float* __restrict__ in_data, // 当前维度的输入 (dist^2) + const int32_t* __restrict__ in_indices, // 上一轮的索引图 (N_slices, L, NDim) + float* __restrict__ out_dist, // 输出距离 + int32_t* __restrict__ out_indices, // 输出索引图 + int64_t L, // 当前维度的长度 (Length) + int64_t total_elements // Batch * ... 
* L ) { - int64_t slice_idx = blockIdx.x; - if (slice_idx >= total_slices) return; - - int64_t batch_idx = slice_idx / num_slices_per_sample; - int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t current_offset = batch_idx * strides[0]; - - int32_t base_coords[MAX_DIMS]; - int64_t temp_idx = slice_idx_in_sample; - const int sample_ndim = ndim - 1; - - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) { - base_coords[d] = 0; - continue; - } - int64_t size_of_dim = shape[d + 1]; - int32_t coord = (int32_t)(temp_idx % size_of_dim); - base_coords[d] = coord; - current_offset += coord * strides[d + 1]; - temp_idx /= size_of_dim; - } - - const int64_t process_dim_actual = process_dim_sample + 1; - const int64_t N = shape[process_dim_actual]; - const int64_t stride = strides[process_dim_actual]; + // 这里的 total_elements 是展平后的总像素数 + // 由于我们做了 transpose,数据布局是 [Batch_and_other_dims, L] + // 每个 Block 处理一行 (长度 L) + + int64_t row_idx = blockIdx.x; + int64_t offset = row_idx * L; - if (N == 0) return; + if (offset >= total_elements) return; - // Shared Memory Layout: - // f: float[N] (Values) - // idx1: int[N] (Buffer 1) - // idx2: int[N] (Buffer 2) + // Shared Memory 布局: float vals[L], int idx1[L], int idx2[L] extern __shared__ char s_buffer[]; - float* f = (float*)s_buffer; - int* idx1 = (int*)(f + N); - int* idx2 = (int*)(idx1 + N); - - // Phase 1: 并行加载数据 - for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - float val = __ldg(&in_data[current_offset + i * stride]); - f[i] = (val == 0.0f) ? 0.0f : INF_VAL; + float* s_vals = (float*)s_buffer; + int* s_idx1 = (int*)(s_vals + L); + int* s_idx2 = (int*)(s_idx1 + L); + + // 1. 加载数据到 Shared Memory + for (int i = threadIdx.x; i < L; i += blockDim.x) { + float val = __ldg(&in_data[offset + i]); + // 如果是初始 Pass (无输入索引),val 为 0 或 INF + // 如果是后续 Pass,val 为上一步的 dist^2 + s_vals[i] = val; } __syncthreads(); - // Phase 2: 并行 JFA 计算 - compute_1d_jfa(N, f, idx1, idx2); - // 结果现在存储在 idx1 中 - - // Phase 3: 并行写回 - for (int q = threadIdx.x; q < N; q += blockDim.x) { - int p = idx1[q]; + // 2. 运行 JFA + run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); + + // 3. 写回结果 + for (int q = threadIdx.x; q < L; q += blockDim.x) { + int p = s_idx1[q]; // 最近点在当前行内的局部索引 (0..L-1) float dist_val; - int p_idx; if (p != -1) { - // JFA 得到的是最近源点的索引 p - // 距离 = (q-p)^2 + f[p] - // 注意:在 First Pass 中,f[p] 要么是 0 要么是 INF。如果 p != -1,f[p] 必为 0。 - // 但为了通用性,还是加上 f[p] - float dist_sq = sqr((float)q - (float)p) + f[p]; + // 计算新距离: (q-p)^2 + val[p] + float dist_sq = sqr((float)q - (float)p) + s_vals[p]; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; - p_idx = p; } else { dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); - p_idx = 0; + p = 0; // fallback } - int64_t global_idx = current_offset + q * stride; - out_dist[global_idx] = dist_val; + out_dist[offset + q] = dist_val; - int32_t* out_idx_ptr = out_idx + global_idx * sample_ndim; - #pragma unroll - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = (d == process_dim_sample) ? p_idx : base_coords[d]; + // 4. 
索引传播 + // 我们需要从 in_indices 查找完整的高维索引 + // in_indices 形状: [Batch..., L, NDim] + // 这里的 offset 对应 [Batch..., 0] + // p 是当前维度的偏移 + if (p != -1) { + int64_t src_offset = (offset + p) * NDim; + int64_t dst_offset = (offset + q) * NDim; + + // 手动展开拷贝,或者循环 + for (int d = 0; d < NDim; ++d) { + out_indices[dst_offset + d] = in_indices[src_offset + d]; + } + } else { + // 保持原样或填0 (通常保持原样即可,或者为了安全填0) + int64_t dst_offset = (offset + q) * NDim; + for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; } } } // ------------------------------------------------------------------ -// 内核 2: 后续 Pass (JFA Version) +// Kernel 2: Global Memory JFA (Fallback Path) +// 适用于 N > 4096,使用 Global Memory 作为 Ping-Pong Buffer // ------------------------------------------------------------------ -template -__global__ void edt_kernel_subsequent_pass( - const float* __restrict__ in_dist, - const int32_t* __restrict__ in_idx, +template +__global__ void edt_kernel_global( + const float* __restrict__ in_data, + const int32_t* __restrict__ in_indices, float* __restrict__ out_dist, - int32_t* __restrict__ out_idx, - const int64_t* __restrict__ shape, - const int64_t* __restrict__ strides, - int32_t ndim, - int32_t process_dim_sample, - int64_t total_slices, - int64_t num_slices_per_sample + int32_t* __restrict__ out_indices, + int* __restrict__ global_buffer_1, // 临时 buffer A [TotalElements] + int* __restrict__ global_buffer_2, // 临时 buffer B [TotalElements] + int64_t L, + int64_t total_elements ) { - int64_t slice_idx = blockIdx.x; - if (slice_idx >= total_slices) return; - - int64_t batch_idx = slice_idx / num_slices_per_sample; - int64_t slice_idx_in_sample = slice_idx % num_slices_per_sample; - int64_t current_offset = batch_idx * strides[0]; + int64_t row_idx = blockIdx.x; + int64_t offset = row_idx * L; - int64_t temp_idx = slice_idx_in_sample; - const int sample_ndim = ndim - 1; - - for (int32_t d = sample_ndim - 1; d >= 0; --d) { - if (d == process_dim_sample) continue; - int64_t size_of_dim = shape[d + 1]; - current_offset += (temp_idx % size_of_dim) * strides[d + 1]; - temp_idx /= size_of_dim; - } - - const int64_t process_dim_actual = process_dim_sample + 1; - const int64_t N = shape[process_dim_actual]; - const int64_t stride = strides[process_dim_actual]; - - if (N == 0) return; - - // Shared Memory Layout 同上 - extern __shared__ char s_buffer[]; - float* f = (float*)s_buffer; - int* idx1 = (int*)(f + N); - int* idx2 = (int*)(idx1 + N); + if (offset >= total_elements) return; - // Phase 1: 加载 - for (int64_t i = threadIdx.x; i < N; i += blockDim.x) { - f[i] = __ldg(&in_dist[current_offset + i * stride]); - } - __syncthreads(); - - // Phase 2: 并行 JFA 计算 - // 这里的 f[i] 是上一轮计算出的距离平方,作为权重 - compute_1d_jfa(N, f, idx1, idx2); + // 指向当前行在 Global Memory 中的位置 + // 注意:in_data 是只读的,我们需要把它当做 weight + // JFA 需要两个 int buffer 来存 index + int* g_idx1 = global_buffer_1 + offset; + int* g_idx2 = global_buffer_2 + offset; + + // 直接在 Global Memory 上运行 JFA + // 注意:这里 vals 指针直接指向 in_data (Global),读取稍慢但无需拷贝 + run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); - // Phase 3: 写回 - for (int q = threadIdx.x; q < N; q += blockDim.x) { - int p = idx1[q]; // 最近源点在当前行的索引 + // 写回逻辑同上 + for (int q = threadIdx.x; q < L; q += blockDim.x) { + int p = g_idx1[q]; float dist_val; if (p != -1) { - float dist_sq = sqr((float)q - (float)p) + f[p]; + float val_p = in_data[offset + p]; + float dist_sq = sqr((float)q - (float)p) + val_p; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; } else { dist_val = IsFinal ? 
INF_VAL : sqr(INF_VAL); - p = 0; // fallback + p = 0; } - int64_t q_global_offset = current_offset + q * stride; - out_dist[q_global_offset] = dist_val; + out_dist[offset + q] = dist_val; - // 索引处理 if (p != -1) { - int64_t p_global_offset = current_offset + p * stride; - const int32_t* src_idx_ptr = in_idx + p_global_offset * sample_ndim; - int32_t* out_idx_ptr = out_idx + q_global_offset * sample_ndim; - - #pragma unroll - for (int d = 0; d < sample_ndim; ++d) { - out_idx_ptr[d] = src_idx_ptr[d]; + int64_t src_offset = (offset + p) * NDim; + int64_t dst_offset = (offset + q) * NDim; + for (int d = 0; d < NDim; ++d) { + out_indices[dst_offset + d] = in_indices[src_offset + d]; } + } else { + int64_t dst_offset = (offset + q) * NDim; + for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; } } } + +// ------------------------------------------------------------------ +// 辅助:初始化索引张量 +// ------------------------------------------------------------------ +// 将 index tensor 初始化为 grid grid coordinates +// shape: (..., D), 最后一个维度存坐标 +__global__ void init_indices_kernel(int32_t* indices, int64_t total_elements, int NDim, + const int64_t* shape, const int64_t* strides) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) return; + + // 反解坐标 + int64_t temp = idx; + int32_t coords[10]; // max dims + + // strides 是针对 elements 展开的,但这里 indices 是 (Total, NDim) + // 我们可以简单地根据 shape 反解 + // 注意:这里的 total_elements 是像素数,不是 indices 数组的大小 + + // 假设 shape 是 [D0, D1, D2] + // idx 对应 flat index + + for (int d = NDim - 1; d >= 0; --d) { + coords[d] = temp % shape[d]; + temp /= shape[d]; + } + + // 写入 + for (int d = 0; d < NDim; ++d) { + indices[idx * NDim + d] = coords[d]; + } +} + // ------------------------------------------------------------------ // Host 函数 // ------------------------------------------------------------------ @@ -309,119 +274,162 @@ std::tuple distance_transform_cuda(torch::Tensor i TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32."); input = input.contiguous(); + bool had_no_batch_dim = (input.dim() == 1); + if (had_no_batch_dim) input = input.unsqueeze(0); - bool had_no_batch_dim = (input.dim() == 1); - if (had_no_batch_dim) { - input = input.unsqueeze(0); - } - - const auto ndim = input.dim(); - const auto sample_ndim = ndim - 1; - const auto batch_size = input.size(0); + const int ndim = input.dim(); // Include batch + const int sample_ndim = ndim - 1; auto shape = input.sizes().vec(); - auto strides_vec = input.strides().vec(); - - if (input.numel() == 0) { + int64_t num_pixels = input.numel(); + + if (num_pixels == 0) { auto index_shape = shape; index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); return std::make_tuple(torch::empty_like(input), - torch::empty(index_shape, input.options().dtype(torch::kInt32))); + torch::empty(index_shape, input.options().dtype(torch::kInt32))); } - auto distance = torch::empty_like(input); + // 1. 
初始化输出 Tensor + // current_dist 在迭代过程中存储 dist^2,最后开方 + // 初始状态:Input 里的 0 还是 0,其他非 0 (背景) 设为 INF + auto current_dist = torch::where(input == 0, + torch::tensor(0.0f, input.options()), + torch::tensor(INF_VAL, input.options())); + + // 初始化索引 Map (Batch, ..., NDim) auto index_shape = shape; - index_shape.push_back(sample_ndim); - auto index = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - - auto buffer_dist = torch::empty_like(distance); - auto buffer_idx = torch::empty_like(index); - - auto shape_tensor = torch::tensor(shape, torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - auto strides_tensor = torch::tensor(strides_vec, torch::TensorOptions().dtype(torch::kInt64).device(input.device())); - - std::vector> dim_order_pairs; - for (int32_t d = 0; d < sample_ndim; ++d) { - dim_order_pairs.push_back({strides_vec[d + 1], d}); + index_shape.push_back(sample_ndim); + auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); + + // 启动 Kernel 初始化索引 + // 为了反解坐标,我们需要把 shape 传进去 + { + // 排除 batch 维度的 shape 用于坐标计算? + // 需求是:返回的索引是 (batch_idx, z, y, x) 还是只是 (z, y, x)? + // 通常 EDT 返回的是 sample 内的坐标。所以我们忽略 batch 维度。 + std::vector sample_shape_vec(shape.begin() + 1, shape.end()); + auto sample_shape_tensor = torch::tensor(sample_shape_vec, torch::kInt64).to(input.device()); + // 这里的 strides 不需要,直接由 shape 反解 + + int threads = 256; + int blocks = (num_pixels + threads - 1) / threads; + + // 我们需要传递 sample_ndim + init_indices_kernel<<>>( + current_idx.data_ptr(), + num_pixels, + sample_ndim, + sample_shape_tensor.data_ptr(), + nullptr // strides not needed for simple unravel + ); } - std::sort(dim_order_pairs.rbegin(), dim_order_pairs.rend()); + + // 用于 Global Memory Fallback 的临时 buffer + torch::Tensor global_buf1, global_buf2; + + // 2. 
逐维处理 (Separable Phases) + // 从最后一个维度倒着处理,或者顺序处理都可以。 + // 为了 Host Transpose 方便,我们遍历 sample 的每一个维度 (1 到 ndim-1) + for (int d = 1; d < ndim; ++d) { + bool is_final_pass = (d == ndim - 1); + + // ----------------------------------------------------------- + // Step A: Permute & Contiguous + // 将当前处理维度 d 移到最后: (0, 1, ..., d, ..., N-1) -> (0, 1, ..., N-1, d) + // 这样最后内存布局就是 [..., L],stride=1 + // ----------------------------------------------------------- + + // 这种 swap 策略比较简单: transpose(d, -1) + // 注意:index tensor 也要变换,但 index tensor 最后一维是 coord_dim,不能乱动。 + // Index tensor 形状是 [..., sample_ndim]。 + // 我们需要变换的是前面的空间维度 [...]。 + + auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); + auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); + + // 此时 dist_in shape: [..., L] + // idx_in shape: [..., L, sample_ndim] + + int64_t L = dist_in.size(-1); // 当前维度的长度 + int64_t total_slices = dist_in.numel() / L; // 有多少行 + + auto dist_out = torch::empty_like(dist_in); + auto idx_out = torch::empty_like(idx_in); - for (int pass = 0; pass < sample_ndim; ++pass) { - int32_t d_sample = dim_order_pairs[pass].second; - bool is_first_pass = (pass == 0); - bool is_final_pass = (pass == sample_ndim - 1); + // ----------------------------------------------------------- + // Step B: Kernel Dispatch + // ----------------------------------------------------------- + int threads = std::min((int64_t)MAX_THREADS, L); + + // 检查 Shared Memory 需求 + // Need: float(4) + int(4) + int(4) = 12 bytes per pixel + if (L <= SMEM_LIMIT_ELEMENTS) { + size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); + + // 模板参数 NDim 需要是编译期常量。 + // 动态分发 sample_ndim (1D, 2D, 3D usually) + // 使用 switch case 覆盖常见维度 (1, 2, 3) + #define DISPATCH_SHARED(IS_FINAL) \ + switch(sample_ndim) { \ + case 1: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel()); break; \ + case 2: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel()); break; \ + case 3: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel()); break; \ + default: /* fallback for >3D */ break; \ + } - torch::Tensor *in_d, *in_i, *out_d, *out_i; + if (is_final_pass) { DISPATCH_SHARED(true); } + else { DISPATCH_SHARED(false); } - if (is_first_pass) { - in_d = nullptr; in_i = nullptr; - out_d = is_final_pass ? &distance : &buffer_dist; - out_i = is_final_pass ? 
&index : &buffer_idx; } else { - if (pass % 2 != 0) { - in_d = &buffer_dist; in_i = &buffer_idx; - out_d = &distance; out_i = &index; - } else { - in_d = &distance; in_i = &index; - out_d = &buffer_dist; out_i = &buffer_idx; - } - if (is_final_pass) { - out_d = &distance; out_i = &index; + // Fallback: Global Memory + // 需要分配 buffer: [total_slices * L] = [numel] + if (global_buf1.numel() < dist_in.numel()) { + global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); + global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); } - } - int64_t num_slices_per_sample = 1; - for(int i = 0; i < sample_ndim; ++i) { - if (i != d_sample) num_slices_per_sample *= shape[i + 1]; - } - int64_t total_slices = batch_size * num_slices_per_sample; - int64_t slice_len = shape[d_sample + 1]; - - int threads = std::min((int64_t)MAX_THREADS, slice_len); - - // JFA 需要的 Shared Memory: - // float f[N] - // int idx1[N] - // int idx2[N] - // 总共 slice_len * (4 + 4 + 4) = 12 * slice_len bytes - size_t smem = slice_len * sizeof(float) + - slice_len * sizeof(int) * 2; - - if (is_first_pass) { - const float* in_ptr = input.data_ptr(); - if (is_final_pass) { - edt_kernel_first_pass<<>>( - in_ptr, out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } else { - edt_kernel_first_pass<<>>( - in_ptr, out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } - } else { - if (is_final_pass) { - edt_kernel_subsequent_pass<<>>( - in_d->data_ptr(), in_i->data_ptr(), - out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } else { - edt_kernel_subsequent_pass<<>>( - in_d->data_ptr(), in_i->data_ptr(), - out_d->data_ptr(), out_i->data_ptr(), - shape_tensor.data_ptr(), strides_tensor.data_ptr(), - ndim, d_sample, total_slices, num_slices_per_sample - ); - } + #define DISPATCH_GLOBAL(IS_FINAL) \ + switch(sample_ndim) { \ + case 1: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel()); break; \ + case 2: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel()); break; \ + case 3: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel()); break; \ + default: break; \ + } + + if (is_final_pass) { DISPATCH_GLOBAL(true); } + else { DISPATCH_GLOBAL(false); } } + + // ----------------------------------------------------------- + // Step C: Transpose Back + // ----------------------------------------------------------- + current_dist = dist_out.transpose(d, ndim - 1); + current_idx = idx_out.transpose(d, ndim - 1); } - - if (had_no_batch_dim) { - return std::make_tuple(distance.squeeze(0), index.squeeze(0)); + + if (had_no_batch_dim) { + return std::make_tuple(current_dist.squeeze(0), current_idx.squeeze(0)); } - return std::make_tuple(distance, index); + return std::make_tuple(current_dist, current_idx); } \ No newline at end of file From 
5a9872f96073b4d4e3c8472d188825209752c274 Mon Sep 17 00:00:00 2001 From: Yuhandeng Date: Mon, 8 Dec 2025 20:30:09 +0800 Subject: [PATCH 20/39] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=89=E7=BB=B4?= =?UTF-8?q?=E4=BB=A5=E4=B8=8A=E7=BB=B4=E5=BA=A6=E7=9A=84=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_distance_transform.py | 146 ++++---- torchmorph/csrc/distance_transform_kernel.cu | 332 ++++++++++--------- 2 files changed, 267 insertions(+), 211 deletions(-) diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index 476ffca..6e6265d 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -1,21 +1,20 @@ import torch import pytest -from scipy.ndimage import distance_transform_edt as scipy_edt import numpy as np -import torchmorph as tm +from scipy.ndimage import distance_transform_edt as scipy_edt +import torchmorph as tm -# 辅助函数:调用 SciPy 并处理格式 +# ====================================================================== +# 辅助函数 +# ====================================================================== def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - + dist_results, indices_results = [], [] - input_is_1d_batch = (batch_numpy.ndim == 2) - input_is_single_sample_no_batch = (batch_numpy.ndim == 1) + # 保证 batch_numpy 至少是 (Batch, ...) + # 如果进来的是 (H, W),我们在外面已经处理成 (1, H, W) 了 + if batch_numpy.ndim == 1: + batch_numpy = batch_numpy[np.newaxis, ...] - if input_is_single_sample_no_batch: - batch_numpy = batch_numpy[np.newaxis, ...] # (L) -> (1, L) - - - dist_results, indices_results = [], [] for sample in batch_numpy: dist, indices = scipy_edt(sample, return_indices=True, return_distances=True) dist_results.append(dist) @@ -23,81 +22,106 @@ def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, n output_dist = np.stack(dist_results, axis=0) output_indices = np.stack(indices_results, axis=0) + output_indices = np.moveaxis(output_indices, 1, -1) - # indices shape fix: (N, ndim_sample, ...) -> (N, ..., ndim_sample) - # 对于 1D: (N, 1, L) -> (N, L, 1) - output_indices = np.moveaxis(output_indices, 1, -1) - - if input_is_single_sample_no_batch: - output_dist = output_dist.squeeze(0) - output_indices = output_indices.squeeze(0) - return output_dist, output_indices -# 用例定义 -case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],[[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32) +# ====================================================================== +# 测试数据 +# ====================================================================== +case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32) + +case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]], + [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32) + +# 这里定义为 (4, 4),意图是单张 2D 图 +case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32) +case_explicit_batch_one = case_single_2d[np.newaxis, ...] + _case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s1[1, 1, 1] = 0.0; _case_3d_s1[2, 3, 4] = 0.0 _case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s2[0, 0, 0] = 0.0 case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0) -case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32) -case_explicit_batch_one = case_single_2d[np.newaxis, ...] 
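All of these cases, including the 4D/5D ones added below, exercise one invariant that SciPy itself guarantees: the squared distance at every pixel can be rebuilt from the nearest-background coordinates returned alongside it. A minimal standalone sketch of that identity, using only NumPy and SciPy; the specific array here is illustrative:

    import numpy as np
    from scipy.ndimage import distance_transform_edt

    img = np.ones((4, 5, 6), dtype=np.float32)
    img[1, 1, 1] = 0.0
    dist, idx = distance_transform_edt(img, return_distances=True, return_indices=True)
    grid = np.indices(img.shape)  # shape (ndim, 4, 5, 6), same layout as idx
    assert np.allclose(dist ** 2, ((grid - idx) ** 2).sum(axis=0))

The test body further down performs the same check against the CUDA output, after moving the coordinate axis to the last dimension.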
+ case_dim_one = np.ones((2, 5, 1), dtype=np.float32); case_dim_one[0, 2, 0] = 0.0; case_dim_one[1, 4, 0] = 0.0 -case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32) +# 4D Case +_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s1[0, 0, 0, 0] = 0.0 +_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s2[1, 1, 1, 1] = 0.0 +case_batch_4d_spatial = np.stack([_case_4d_s1, _case_4d_s2], axis=0) + +# 5D Case +case_batch_5d_spatial = np.ones((1, 2, 2, 2, 2, 2), dtype=np.float32) +case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0; case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 + +# ====================================================================== +# 测试逻辑 +# ====================================================================== @pytest.mark.parametrize( - "input_numpy", + "input_numpy, has_batch_dim", [ - pytest.param(case_batch_2d, id="批处理2D图像"), - pytest.param(case_batch_3d, id="批处理3D图像"), - pytest.param(case_single_2d, id="单张2D图像(隐式批处理)"), - pytest.param(case_explicit_batch_one, id="单张2D图像(显式批处理)"), - pytest.param(case_dim_one, id="含幺元维度的批处理"), - pytest.param(case_batch_1d, id="批处理1D数据"), + pytest.param(case_batch_1d, True, id="1D_Batch"), + pytest.param(case_batch_2d, True, id="2D_Batch"), + pytest.param(case_single_2d, False, id="2D_Single_NoBatch"), + pytest.param(case_explicit_batch_one, True, id="2D_Single_ExplicitBatch"), + pytest.param(case_batch_3d, True, id="3D_Batch"), + pytest.param(case_dim_one, True, id="2D_UnitDim_Batch"), + pytest.param(case_batch_4d_spatial, True, id="4D_Spatial_Batch"), + pytest.param(case_batch_5d_spatial, True, id="5D_Spatial_Batch"), ], ) -def test_distance_transform_and_indices(input_numpy: np.ndarray, request: pytest.FixtureRequest): +def test_distance_transform_and_indices(input_numpy: np.ndarray, has_batch_dim: bool, request: pytest.FixtureRequest): if not torch.cuda.is_available(): pytest.skip("CUDA not available") + # 1. 准备 Numpy 数据 x_numpy_contiguous = np.ascontiguousarray(input_numpy) + + # 2. 准备 SciPy 输入 + # 如果意图是单样本 (has_batch_dim=False),我们手动增加 Batch 维, + # 这样 scipy 辅助函数就会把它当做一张图来处理,而不是 N 张 1D 图 + if not has_batch_dim: + scipy_input = x_numpy_contiguous[np.newaxis, ...] + else: + scipy_input = x_numpy_contiguous + + # 3. 准备 CUDA 输入 + # 关键修复: + # 如果 has_batch_dim=False,说明这是单张 (H, W),我们要测 2D EDT。 + # C++ API 默认第一维是 Batch,所以我们必须 unsqueeze(0) 变成 (1, H, W)。 + # 否则 C++ 会把它当做 (Batch=H, Len=W) 做 1D EDT。 x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() + if not has_batch_dim: + x_cuda = x_cuda.unsqueeze(0) - print(f"\n\n--- 正在运行测试: {request.node.callspec.id} ---") - print(f"输入张量形状: {x_cuda.shape}") + print(f"\n\n--- 运行测试: {request.node.callspec.id} ---") + print(f"CUDA 输入形状: {x_cuda.shape}") - # 调用您的 Python 包装函数 + # 4. 运行 CUDA EDT dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone()) - print(f"CUDA 距离输出形状: {dist_cuda.shape}") - print(f"CUDA 坐标输出形状: {idx_cuda.shape}") - - # 调用 SciPy 作为参考基准 - dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(x_numpy_contiguous) + # 5. 运行 SciPy (Ground Truth) + dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(scipy_input) dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() - - print(f"SciPy 距离输出形状: {dist_ref.shape}") - # 断言验证 - print("\n--- 正在验证距离... ---") - assert dist_cuda.shape == dist_ref.shape + # 6. 
验证距离 + # 此时 dist_cuda 是 (1, H, W),dist_ref 也是 (1, H, W) + # 如果原意是 NoBatch,我们可以把 Batch 维 squeeze 掉再比,或者直接比 + print(f"CUDA Out Shape: {dist_cuda.shape}, Ref Shape: {dist_ref.shape}") + assert dist_cuda.shape == dist_ref.shape, f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3) - print("距离断言通过 (形状和数值接近)。") + print(">> 距离验证通过。") - print("\n--- 正在验证坐标... ---") - - # 鲁棒的坐标验证逻辑 - had_no_batch_dim = (x_numpy_contiguous.ndim <= idx_cuda.shape[-1]) - spatial_shape = x_cuda.shape if had_no_batch_dim else x_cuda.shape[1:] + # 7. 验证索引 + # idx_cuda: (1, H, W, 2) + # 构造 Grid + spatial_shape = x_cuda.shape[1:] # (H, W) coords = [torch.arange(s, device='cuda') for s in spatial_shape] - grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1) - - if not had_no_batch_dim: - grid = grid.unsqueeze(0) - + grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1) # (H, W, 2) + grid = grid.unsqueeze(0) # (1, H, W, 2) + diff = grid.float() - idx_cuda.float() - dist_sq_from_indices = torch.sum(diff * diff, dim=-1) + dist_sq_calculated = torch.sum(diff * diff, dim=-1) + dist_sq_output = dist_cuda * dist_cuda - torch.testing.assert_close(dist_sq_from_indices, dist_cuda * dist_cuda, atol=1e-3, rtol=1e-3) - print("坐标正确性断言通过 (计算出的距离与返回距离匹配)。") - - print("--- 测试通过 ---") \ No newline at end of file + torch.testing.assert_close(dist_sq_calculated, dist_sq_output, atol=1e-3, rtol=1e-3) + print(">> 索引验证通过。") \ No newline at end of file diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu index a8e5a6a..6e9c4c1 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -5,23 +5,32 @@ #include #include +// ------------------------------------------------------------------ +// 配置常量 +// ------------------------------------------------------------------ #define INF_VAL 1e8f #define MAX_THREADS 1024 -#define SMEM_LIMIT_ELEMENTS 4096 // 48KB / 12 bytes (float+int+int) ~= 4096 +// Shared Memory 限制: 48KB 一般安全。 +// 每个像素需要: float(val) + int(idx1) + int(idx2) = 12 bytes +// 4096 * 12 = 48KB. +#define SMEM_LIMIT_ELEMENTS 4096 + +// ------------------------------------------------------------------ +// Device Helper Functions +// ------------------------------------------------------------------ __device__ __forceinline__ float sqr(float x) { return x * x; } -// 计算从像素 q 到源点 p 的距离代价 -// val_p 是源点 p 在上一轮计算后的距离平方值 (weight) +// 计算 JFA 代价: (q - p)^2 + weight[p] __device__ __forceinline__ float compute_cost(int q, int p, float val_p) { if (p < 0) return INF_VAL; return sqr((float)q - (float)p) + val_p; } // ------------------------------------------------------------------ -// JFA 核心逻辑 (Device Function) +// JFA Core Logic (Device Only) // ------------------------------------------------------------------ -// 无论数据是在 Shared Memory 还是 Global Memory,逻辑是一样的 +// 核心 JFA 逻辑,与数据位置无关 (Shared 或 Global 均通用) __device__ void run_jfa_core( int N, int tid, @@ -29,13 +38,12 @@ __device__ void run_jfa_core( int* __restrict__ idx_curr, // Ping-Pong Buffer A int* __restrict__ idx_next // Ping-Pong Buffer B ) { - // 1. 初始化 + // 1. 
初始化: 根据 vals 决定是否是有效源点 for (int i = tid; i < N; i += blockDim.x) { - // 如果输入值很大,说明是背景,没有初始源点 if (vals[i] >= INF_VAL * 0.9f) { - idx_curr[i] = -1; + idx_curr[i] = -1; // 背景 } else { - idx_curr[i] = i; + idx_curr[i] = i; // 物体/源点,初始索引指向自己 } } __syncthreads(); @@ -49,11 +57,12 @@ __device__ void run_jfa_core( int my_best_p = idx_in[i]; float min_cost = INF_VAL; + // 检查自己当前的最优解 if (my_best_p != -1) { min_cost = compute_cost(i, my_best_p, vals[my_best_p]); } - // Check Left + // Check Left Neighbor (-step) int left = i - step; if (left >= 0) { int left_p = idx_in[left]; @@ -66,7 +75,7 @@ __device__ void run_jfa_core( } } - // Check Right + // Check Right Neighbor (+step) int right = i + step; if (right < N) { int right_p = idx_in[right]; @@ -99,42 +108,41 @@ __device__ void run_jfa_core( // ------------------------------------------------------------------ // Kernel 1: Shared Memory JFA (Fast Path) -// 适用于 N <= 4096 // ------------------------------------------------------------------ +// 模板参数 NDim: 如果 > 0,编译器会展开循环优化。 +// 参数 runtime_ndim: 如果 NDim == 0 (Default case),使用该参数作为维度。 template __global__ void edt_kernel_shared( - const float* __restrict__ in_data, // 当前维度的输入 (dist^2) - const int32_t* __restrict__ in_indices, // 上一轮的索引图 (N_slices, L, NDim) - float* __restrict__ out_dist, // 输出距离 - int32_t* __restrict__ out_indices, // 输出索引图 - int64_t L, // 当前维度的长度 (Length) - int64_t total_elements // Batch * ... * L + const float* __restrict__ in_data, // 输入 Dist^2 + const int32_t* __restrict__ in_indices, // 输入 Indices + float* __restrict__ out_dist, // 输出 Dist (IsFinal ? sqrt : sqr) + int32_t* __restrict__ out_indices, // 输出 Indices + int64_t L, // 当前维度的长度 + int64_t total_elements, // 总像素数 + int runtime_ndim // 运行时维度 (fallback) ) { - // 这里的 total_elements 是展平后的总像素数 - // 由于我们做了 transpose,数据布局是 [Batch_and_other_dims, L] - // 每个 Block 处理一行 (长度 L) - + // 确定实际维度 + const int D = (NDim > 0) ? NDim : runtime_ndim; + + // 计算行偏移 int64_t row_idx = blockIdx.x; int64_t offset = row_idx * L; if (offset >= total_elements) return; - // Shared Memory 布局: float vals[L], int idx1[L], int idx2[L] + // Shared Memory 布局 extern __shared__ char s_buffer[]; float* s_vals = (float*)s_buffer; int* s_idx1 = (int*)(s_vals + L); int* s_idx2 = (int*)(s_idx1 + L); - // 1. 加载数据到 Shared Memory + // 1. 加载 Dist 到 Shared Memory for (int i = threadIdx.x; i < L; i += blockDim.x) { - float val = __ldg(&in_data[offset + i]); - // 如果是初始 Pass (无输入索引),val 为 0 或 INF - // 如果是后续 Pass,val 为上一步的 dist^2 - s_vals[i] = val; + s_vals[i] = __ldg(&in_data[offset + i]); } __syncthreads(); - // 2. 运行 JFA + // 2. 运行 JFA 核心 run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); // 3. 写回结果 @@ -142,69 +150,64 @@ __global__ void edt_kernel_shared( int p = s_idx1[q]; // 最近点在当前行内的局部索引 (0..L-1) float dist_val; + // 计算新距离 if (p != -1) { - // 计算新距离: (q-p)^2 + val[p] float dist_sq = sqr((float)q - (float)p) + s_vals[p]; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; } else { dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); - p = 0; // fallback + p = 0; // 防止越界,随便指一个 } - out_dist[offset + q] = dist_val; - // 4. 
索引传播 - // 我们需要从 in_indices 查找完整的高维索引 - // in_indices 形状: [Batch..., L, NDim] - // 这里的 offset 对应 [Batch..., 0] - // p 是当前维度的偏移 + // 索引传播: Copy Vector [D] if (p != -1) { - int64_t src_offset = (offset + p) * NDim; - int64_t dst_offset = (offset + q) * NDim; + int64_t src_offset = (offset + p) * D; + int64_t dst_offset = (offset + q) * D; - // 手动展开拷贝,或者循环 - for (int d = 0; d < NDim; ++d) { + // 如果 NDim > 0,这里会完全展开,非常快 + for (int d = 0; d < D; ++d) { out_indices[dst_offset + d] = in_indices[src_offset + d]; } } else { - // 保持原样或填0 (通常保持原样即可,或者为了安全填0) - int64_t dst_offset = (offset + q) * NDim; - for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; + // 找不到源点(全图都是背景的情况) + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; } } } // ------------------------------------------------------------------ // Kernel 2: Global Memory JFA (Fallback Path) -// 适用于 N > 4096,使用 Global Memory 作为 Ping-Pong Buffer // ------------------------------------------------------------------ +// 逻辑同上,只是用 Global Memory 做 Ping-Pong Buffer template __global__ void edt_kernel_global( const float* __restrict__ in_data, const int32_t* __restrict__ in_indices, float* __restrict__ out_dist, int32_t* __restrict__ out_indices, - int* __restrict__ global_buffer_1, // 临时 buffer A [TotalElements] - int* __restrict__ global_buffer_2, // 临时 buffer B [TotalElements] + int* __restrict__ global_buffer_1, + int* __restrict__ global_buffer_2, int64_t L, - int64_t total_elements + int64_t total_elements, + int runtime_ndim ) { + const int D = (NDim > 0) ? NDim : runtime_ndim; + int64_t row_idx = blockIdx.x; int64_t offset = row_idx * L; if (offset >= total_elements) return; - // 指向当前行在 Global Memory 中的位置 - // 注意:in_data 是只读的,我们需要把它当做 weight - // JFA 需要两个 int buffer 来存 index + // 指向 Global Memory 的指针 int* g_idx1 = global_buffer_1 + offset; int* g_idx2 = global_buffer_2 + offset; - // 直接在 Global Memory 上运行 JFA - // 注意:这里 vals 指针直接指向 in_data (Global),读取稍慢但无需拷贝 + // 1. & 2. 运行 JFA (直接在 Global Mem 上读写) run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); - // 写回逻辑同上 + // 3. 
写回结果 for (int q = threadIdx.x; q < L; q += blockDim.x) { int p = g_idx1[q]; float dist_val; @@ -221,177 +224,191 @@ __global__ void edt_kernel_global( out_dist[offset + q] = dist_val; if (p != -1) { - int64_t src_offset = (offset + p) * NDim; - int64_t dst_offset = (offset + q) * NDim; - for (int d = 0; d < NDim; ++d) { + int64_t src_offset = (offset + p) * D; + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) { out_indices[dst_offset + d] = in_indices[src_offset + d]; } } else { - int64_t dst_offset = (offset + q) * NDim; - for (int d = 0; d < NDim; ++d) out_indices[dst_offset + d] = 0; + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; } } } - // ------------------------------------------------------------------ -// 辅助:初始化索引张量 +// Kernel 3: Initialize Indices // ------------------------------------------------------------------ -// 将 index tensor 初始化为 grid grid coordinates -// shape: (..., D), 最后一个维度存坐标 -__global__ void init_indices_kernel(int32_t* indices, int64_t total_elements, int NDim, - const int64_t* shape, const int64_t* strides) { +// 初始化索引张量为网格坐标 +// indices shape: (..., D) +__global__ void init_indices_kernel( + int32_t* indices, + int64_t total_pixels, + int NDim, + const int64_t* __restrict__ shape_ptr // shape of spatial dimensions +) { int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) return; + if (idx >= total_pixels) return; - // 反解坐标 + // 反解坐标 (Unravel Index) + // idx 是每个像素的 flat index + // 我们需要计算它在 spatial_shape 中的坐标 + int64_t temp = idx; - int32_t coords[10]; // max dims + // 使用本地寄存器数组避免多次全局内存读取 (假设最大 10 维) + int32_t coords[10]; - // strides 是针对 elements 展开的,但这里 indices 是 (Total, NDim) - // 我们可以简单地根据 shape 反解 - // 注意:这里的 total_elements 是像素数,不是 indices 数组的大小 - - // 假设 shape 是 [D0, D1, D2] - // idx 对应 flat index - + // 假设 spatial_shape 是 [D0, D1, D2] + // 倒序计算除余 for (int d = NDim - 1; d >= 0; --d) { - coords[d] = temp % shape[d]; - temp /= shape[d]; + int64_t dim_size = shape_ptr[d]; + coords[d] = temp % dim_size; + temp /= dim_size; } - // 写入 + // 写入 Global Memory + // Indices tensor 是 (TotalPixels, NDim) 扁平化的 + int64_t out_ptr = idx * NDim; for (int d = 0; d < NDim; ++d) { - indices[idx * NDim + d] = coords[d]; + indices[out_ptr + d] = coords[d]; } } // ------------------------------------------------------------------ -// Host 函数 +// Host Function: C++ Entry Point // ------------------------------------------------------------------ + std::tuple distance_transform_cuda(torch::Tensor input) { TORCH_CHECK(input.is_cuda(), "Input must be on CUDA device."); TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32."); input = input.contiguous(); - bool had_no_batch_dim = (input.dim() == 1); - if (had_no_batch_dim) input = input.unsqueeze(0); - - const int ndim = input.dim(); // Include batch + + // 处理 Batch 维度:如果输入是 1D (L),视为无 Batch,但在处理中统一加一个 Batch 维方便 + // 标准约定:Input shape (Batch, D1, D2, ..., Dn) + // 算法对 Batch 维度和其他维度处理其实是一样的(视为无关维度) + // 但索引初始化需要知道哪些是 "Spatial Dimensions"。 + // 这里假设:输入的所有维度除了 Batch (Dim 0) 外都是空间维度。 + + const int ndim = input.dim(); + // 如果 ndim=1, 假设是 (L),sample_ndim=1 + // 如果 ndim=4 (B, C, H, W),sample_ndim=3 (C,H,W 都算空间? 通常 C 也是独立处理的) + // **修正**: 标准 EDT 通常是在 (H, W) 或 (D, H, W) 上进行的。 + // 如果有 Channel,通常 Channel 也是独立的。 + // 为了最通用,我们将 **除了第0维(Batch)** 以外的所有维度都视为空间维度进行索引记录。 + // 如果用户输入没有 Batch 维,请在 Python 端 unsqueeze(0)。 + + // 假设输入已经是 (Batch, ...Spatial...) 
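In caller terms, the convention described above amounts to the following sketch, assuming the Python wrapper keeps the (distance, indices) return signature exercised by the tests:

    import torch
    import torchmorph as tm

    mask = (torch.rand(64, 64, device="cuda") > 0.5).float()  # single 2D sample, no batch dim
    dist, idx = tm.distance_transform(mask.unsqueeze(0))      # in: (1, H, W); out: (1, H, W) and (1, H, W, 2)
    dist, idx = dist.squeeze(0), idx.squeeze(0)

Without the unsqueeze(0), the first dimension would be treated as the batch axis and each row would receive an independent 1D transform.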
const int sample_ndim = ndim - 1; + TORCH_CHECK(sample_ndim > 0, "Input tensor must have at least 2 dimensions (Batch, ...)"); + auto shape = input.sizes().vec(); int64_t num_pixels = input.numel(); if (num_pixels == 0) { auto index_shape = shape; - index_shape.push_back(sample_ndim > 0 ? sample_ndim : 1); + index_shape.push_back(sample_ndim); return std::make_tuple(torch::empty_like(input), torch::empty(index_shape, input.options().dtype(torch::kInt32))); } - // 1. 初始化输出 Tensor - // current_dist 在迭代过程中存储 dist^2,最后开方 - // 初始状态:Input 里的 0 还是 0,其他非 0 (背景) 设为 INF + // 1. 初始化 Distance Tensor + // 0 -> 0, 1 -> INF auto current_dist = torch::where(input == 0, torch::tensor(0.0f, input.options()), torch::tensor(INF_VAL, input.options())); - // 初始化索引 Map (Batch, ..., NDim) + // 2. 初始化 Index Tensor + // Shape: (Batch, D1, ..., Dn, sample_ndim) auto index_shape = shape; index_shape.push_back(sample_ndim); auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - // 启动 Kernel 初始化索引 - // 为了反解坐标,我们需要把 shape 传进去 + // 2.1 准备 Shape 数据传给 Kernel + std::vector spatial_shape(shape.begin() + 1, shape.end()); + auto shape_tensor = torch::tensor(spatial_shape, torch::kInt64).to(input.device()); + + // 2.2 运行初始化 Kernel { - // 排除 batch 维度的 shape 用于坐标计算? - // 需求是:返回的索引是 (batch_idx, z, y, x) 还是只是 (z, y, x)? - // 通常 EDT 返回的是 sample 内的坐标。所以我们忽略 batch 维度。 - std::vector sample_shape_vec(shape.begin() + 1, shape.end()); - auto sample_shape_tensor = torch::tensor(sample_shape_vec, torch::kInt64).to(input.device()); - // 这里的 strides 不需要,直接由 shape 反解 - int threads = 256; int blocks = (num_pixels + threads - 1) / threads; - - // 我们需要传递 sample_ndim init_indices_kernel<<>>( - current_idx.data_ptr(), - num_pixels, + current_idx.data_ptr(), + num_pixels, sample_ndim, - sample_shape_tensor.data_ptr(), - nullptr // strides not needed for simple unravel + shape_tensor.data_ptr() ); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Init Kernel Failed: %s\n", cudaGetErrorString(err)); + } } - - // 用于 Global Memory Fallback 的临时 buffer + + // 预分配 Global Memory Buffer (懒加载) torch::Tensor global_buf1, global_buf2; - // 2. 逐维处理 (Separable Phases) - // 从最后一个维度倒着处理,或者顺序处理都可以。 - // 为了 Host Transpose 方便,我们遍历 sample 的每一个维度 (1 到 ndim-1) + // 3. 
逐维处理 (Separable JFA) + // 遍历每一个空间维度 (从 1 到 ndim-1) for (int d = 1; d < ndim; ++d) { bool is_final_pass = (d == ndim - 1); - // ----------------------------------------------------------- - // Step A: Permute & Contiguous - // 将当前处理维度 d 移到最后: (0, 1, ..., d, ..., N-1) -> (0, 1, ..., N-1, d) - // 这样最后内存布局就是 [..., L],stride=1 - // ----------------------------------------------------------- - - // 这种 swap 策略比较简单: transpose(d, -1) - // 注意:index tensor 也要变换,但 index tensor 最后一维是 coord_dim,不能乱动。 - // Index tensor 形状是 [..., sample_ndim]。 - // 我们需要变换的是前面的空间维度 [...]。 - + // --- Step A: Transpose current dim to last --- + // 变换后 Shape: (..., L) auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); - // 此时 dist_in shape: [..., L] - // idx_in shape: [..., L, sample_ndim] - int64_t L = dist_in.size(-1); // 当前维度的长度 - int64_t total_slices = dist_in.numel() / L; // 有多少行 + int64_t total_slices = dist_in.numel() / L; auto dist_out = torch::empty_like(dist_in); auto idx_out = torch::empty_like(idx_in); - // ----------------------------------------------------------- - // Step B: Kernel Dispatch - // ----------------------------------------------------------- + // --- Step B: Kernel Dispatch --- int threads = std::min((int64_t)MAX_THREADS, L); - // 检查 Shared Memory 需求 - // Need: float(4) + int(4) + int(4) = 12 bytes per pixel + // 检查是否可以使用 Shared Memory if (L <= SMEM_LIMIT_ELEMENTS) { size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); - // 模板参数 NDim 需要是编译期常量。 - // 动态分发 sample_ndim (1D, 2D, 3D usually) - // 使用 switch case 覆盖常见维度 (1, 2, 3) + // 使用 Switch 宏来处理常用的维度模板特化 #define DISPATCH_SHARED(IS_FINAL) \ switch(sample_ndim) { \ case 1: edt_kernel_shared<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel()); break; \ + L, dist_in.numel(), sample_ndim); break; \ case 2: edt_kernel_shared<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel()); break; \ + L, dist_in.numel(), sample_ndim); break; \ case 3: edt_kernel_shared<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ - L, dist_in.numel()); break; \ - default: /* fallback for >3D */ break; \ + L, dist_in.numel(), sample_ndim); break; \ + case 4: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 5: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 6: edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + default: /* Fallback for > 6D */ \ + edt_kernel_shared<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ } if (is_final_pass) { DISPATCH_SHARED(true); } else { DISPATCH_SHARED(false); } } else { - // Fallback: Global Memory - // 需要分配 buffer: [total_slices * L] = [numel] + // Global Memory Fallback (L > 4096) if (global_buf1.numel() < dist_in.numel()) { global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); @@ -403,33 +420,48 @@ std::tuple 
distance_transform_cuda(torch::Tensor i dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel()); break; \ + L, dist_in.numel(), sample_ndim); break; \ case 2: edt_kernel_global<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel()); break; \ + L, dist_in.numel(), sample_ndim); break; \ case 3: edt_kernel_global<<>>( \ dist_in.data_ptr(), idx_in.data_ptr(), \ dist_out.data_ptr(), idx_out.data_ptr(), \ global_buf1.data_ptr(), global_buf2.data_ptr(), \ - L, dist_in.numel()); break; \ - default: break; \ + L, dist_in.numel(), sample_ndim); break; \ + case 4: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 5: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + case 6: edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ + default: /* Fallback */ \ + edt_kernel_global<<>>( \ + dist_in.data_ptr(), idx_in.data_ptr(), \ + dist_out.data_ptr(), idx_out.data_ptr(), \ + global_buf1.data_ptr(), global_buf2.data_ptr(), \ + L, dist_in.numel(), sample_ndim); break; \ } - if (is_final_pass) { DISPATCH_GLOBAL(true); } + if (is_final_pass) { DISPATCH_GLOBAL(true); } else { DISPATCH_GLOBAL(false); } } - // ----------------------------------------------------------- - // Step C: Transpose Back - // ----------------------------------------------------------- - current_dist = dist_out.transpose(d, ndim - 1); + // --- Step C: Transpose Back --- + current_dist = dist_out.transpose(d, ndim - 1); current_idx = idx_out.transpose(d, ndim - 1); } - if (had_no_batch_dim) { - return std::make_tuple(current_dist.squeeze(0), current_idx.squeeze(0)); - } return std::make_tuple(current_dist, current_idx); } \ No newline at end of file From 44930399b03615df52ec95713890e5bdbbffb3ef Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 08:12:40 +0800 Subject: [PATCH 21/39] forbid non-ascii --- .pre-commit-config.yaml | 12 +++++++++ scripts/check_ascii.py | 58 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 scripts/check_ascii.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6df40f7..a79f4db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,15 @@ repos: rev: 6.1.0 hooks: - id: flake8 + + # ------------------------- + # ⭐ Local Hook: forbid non-ASCII in C/C++/CUDA + # ------------------------- + - repo: local + hooks: + - id: forbid-non-ascii + name: "Forbid non-ASCII characters in C/C++/CUDA" + entry: python3 scripts/check_ascii.py + language: system + types: [file] + files: '\.(c|cc|cpp|cxx|cu|cuh|h|hpp)$' diff --git a/scripts/check_ascii.py b/scripts/check_ascii.py new file mode 100644 index 0000000..eb1cf91 --- /dev/null +++ b/scripts/check_ascii.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +import sys +from pathlib import Path + +TARGET_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".cu", ".cuh", ".h", ".hpp"} + +def 
find_non_ascii(line): + """Return list of (index, char) for all non-ASCII chars in a line.""" + result = [] + for i, ch in enumerate(line): + if ord(ch) > 127: + result.append((i, ch)) + return result + + +def check_file(path: Path) -> bool: + ok = True + with path.open("r", encoding="utf-8", errors="ignore") as f: + for lineno, line in enumerate(f, start=1): + non_ascii = find_non_ascii(line) + if non_ascii: + ok = False + print(f"\n❌ {path}:{lineno}: non-ASCII characters detected") + + # Print the full line + print(f" Line content:") + print(f" {line.rstrip()}") + + # Underline the exact non-ASCII characters + underline = [" " for _ in line.rstrip("\n")] + for idx, ch in non_ascii: + if idx < len(underline): + underline[idx] = "^" + print(f" {' '.join(underline)}") + + # Print what characters exactly + chars = ", ".join(f"'{ch}' (U+{ord(ch):04X})" for _, ch in non_ascii) + print(f" Offending chars: {chars}") + + return ok + + +def main(files): + ok = True + for f in files: + p = Path(f) + if p.suffix.lower() in TARGET_SUFFIXES and p.exists(): + if not check_file(p): + ok = False + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: check_ascii.py ") + sys.exit(1) + main(sys.argv[1:]) + From 81e43df98f9efff001b6b5c7d24655ff4a59f3aa Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 08:23:01 +0800 Subject: [PATCH 22/39] workflow: precommit --- .github/workflows/pre-commit.yml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..ca7885d --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,32 @@ +name: pre-commit + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +jobs: + pre-commit: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install pre-commit + run: | + pip install pre-commit + + - name: Install pre-commit hooks + run: | + pre-commit install --install-hooks + + - name: Run pre-commit on all files + run: | + pre-commit run --from-ref origin/main --to-ref HEAD From 95f8f34a9884d8c3ccebf120b8f41f89dd544dba Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 08:25:21 +0800 Subject: [PATCH 23/39] workflow: run on all files --- .github/workflows/pre-commit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index ca7885d..2c1bd6c 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -29,4 +29,4 @@ jobs: - name: Run pre-commit on all files run: | - pre-commit run --from-ref origin/main --to-ref HEAD + pre-commit run --all-files From 83410854d0bc49658b45221bd75554450eb58d1e Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 09:17:51 +0800 Subject: [PATCH 24/39] prevent duplicated ci --- .github/workflows/pre-commit.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 2c1bd6c..bfebf61 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -2,7 +2,7 @@ name: pre-commit on: push: - branches: [ "*" ] + branches: [ "main" ] pull_request: branches: [ "*" ] From 
fd4cd5e191327dc3a9c871b756d8c450e7bd35de Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 22:10:13 +0800 Subject: [PATCH 25/39] test workflow --- .github/workflows/test.yml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..f602494 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,24 @@ +name: test + +on: + pull_request: + branches: ["*"] + push: + branches: [main] + +permissions: + contents: read + pull-requests: read + +test: + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + - run: | + pip install -r requirements-dev.txt --break-system-packages --user + pip uninstall torchmorph --yes + python setup.py install --user + - run: | + ORIGINAL=$(pwd) + cd /tmp + pytest $ORIGINAL/test \ No newline at end of file From 1b520c25c738e56a37068572e4f58d67759dea10 Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 22:12:46 +0800 Subject: [PATCH 26/39] test workflow --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f602494..d8e1647 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ test: - uses: actions/checkout@v4 - run: | pip install -r requirements-dev.txt --break-system-packages --user - pip uninstall torchmorph --yes + pip uninstall torchmorph --yes --break-system-packages python setup.py install --user - run: | ORIGINAL=$(pwd) From 157bfac630449175a87778209487d5b8aa763862 Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 22:13:55 +0800 Subject: [PATCH 27/39] test workflow --- .github/workflows/test.yml | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d8e1647..bc4ddf6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,19 +6,17 @@ on: push: branches: [main] -permissions: - contents: read - pull-requests: read +jobs: -test: - runs-on: self-hosted - steps: - - uses: actions/checkout@v4 - - run: | - pip install -r requirements-dev.txt --break-system-packages --user - pip uninstall torchmorph --yes --break-system-packages - python setup.py install --user - - run: | - ORIGINAL=$(pwd) - cd /tmp - pytest $ORIGINAL/test \ No newline at end of file + test: + runs-on: self-hosted + steps: + - uses: actions/checkout@v4 + - run: | + pip install -r requirements-dev.txt --break-system-packages --user + pip uninstall torchmorph --yes --break-system-packages + python setup.py install --user + - run: | + ORIGINAL=$(pwd) + cd /tmp + pytest $ORIGINAL/test \ No newline at end of file From 659999175ea83da5c74359526fabfcbc833ab51a Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Tue, 9 Dec 2025 22:22:20 +0800 Subject: [PATCH 28/39] CUDA_HOME --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bc4ddf6..5ef06c7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,4 +19,4 @@ jobs: - run: | ORIGINAL=$(pwd) cd /tmp - pytest $ORIGINAL/test \ No newline at end of file + CUDA_HOME=/usr/local/cuda pytest $ORIGINAL/test \ No newline at end of file From d417ba9339c7a9e26d634373bd54e6106c25485b Mon Sep 17 00:00:00 2001 From: Kai ZHAO 
<694691@qq.com>
Date: Tue, 9 Dec 2025 22:24:28 +0800
Subject: [PATCH 29/39] CUDA_HOME

---
 .github/workflows/test.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5ef06c7..0f7cb49 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -13,10 +13,11 @@ jobs:
     steps:
     - uses: actions/checkout@v4
     - run: |
+        CUDA_HOME=/usr/local/cuda
        pip install -r requirements-dev.txt --break-system-packages --user
        pip uninstall torchmorph --yes --break-system-packages
        python setup.py install --user
     - run: |
        ORIGINAL=$(pwd)
        cd /tmp
-        CUDA_HOME=/usr/local/cuda pytest $ORIGINAL/test
\ No newline at end of file
+        pytest $ORIGINAL/test
\ No newline at end of file

From c93176c8adaccf7f3a5be9fa00be6a3feec4f917 Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Tue, 9 Dec 2025 23:27:11 +0800
Subject: [PATCH 30/39] CUDA_HOME

---
 .github/workflows/test.yml | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 0f7cb49..ca3e9b3 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -13,7 +13,11 @@ jobs:
     steps:
     - uses: actions/checkout@v4
     - run: |
-        CUDA_HOME=/usr/local/cuda
+        # Make CUDA visible to this shell and all child processes
+        export CUDA_HOME=/usr/local/cuda
+        export PATH="$CUDA_HOME/bin:$PATH"
+        export LD_LIBRARY_PATH="$CUDA_HOME/lib64:${LD_LIBRARY_PATH:-}"
+        echo "CUDA_HOME=$CUDA_HOME"
        pip install -r requirements-dev.txt --break-system-packages --user
        pip uninstall torchmorph --yes --break-system-packages
        python setup.py install --user

From 384f6413a152106b93d02168465b6cb005545f6a Mon Sep 17 00:00:00 2001
From: Kai ZHAO <694691@qq.com>
Date: Wed, 10 Dec 2025 12:04:53 +0800
Subject: [PATCH 31/39] check ascii for .py

---
 scripts/check_ascii.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/check_ascii.py b/scripts/check_ascii.py
index eb1cf91..75d9151 100644
--- a/scripts/check_ascii.py
+++ b/scripts/check_ascii.py
@@ -2,7 +2,7 @@
 import sys
 from pathlib import Path

-TARGET_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".cu", ".cuh", ".h", ".hpp"}
+TARGET_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".cu", ".cuh", ".h", ".hpp", ".py"}

 def find_non_ascii(line):
     """Return list of (index, char) for all non-ASCII chars in a line."""

From da7cadd6da2409f2896d0d8081402f32b9ee0ada Mon Sep 17 00:00:00 2001
From: dongliangnie
Date: Wed, 10 Dec 2025 21:41:40 +0800
Subject: [PATCH 32/39] Translate distance_transform_kernel.cu comments to
 English
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torchmorph/csrc/distance_transform_kernel.cu | 158 ++++++++++---------
 1 file changed, 80 insertions(+), 78 deletions(-)

diff --git a/torchmorph/csrc/distance_transform_kernel.cu b/torchmorph/csrc/distance_transform_kernel.cu
index 6e9c4c1..41defa6 100644
--- a/torchmorph/csrc/distance_transform_kernel.cu
+++ b/torchmorph/csrc/distance_transform_kernel.cu
@@ -6,13 +6,13 @@
 #include

 // ------------------------------------------------------------------
-// 配置常量
+// Configuration Constants
 // ------------------------------------------------------------------
 #define INF_VAL 1e8f
 #define MAX_THREADS 1024
-// Shared Memory 限制: 48KB 一般安全。
-// 每个像素需要: float(val) + int(idx1) + int(idx2) = 12 bytes
-// 4096 * 12 = 48KB.
+// Shared memory limit: typically 48 KB.
+// Each pixel requires: float(value) + int(idx1) + int(idx2) = 12 bytes.
+// 4096 * 12 = 48 KB.
 #define SMEM_LIMIT_ELEMENTS 4096

 // ------------------------------------------------------------------
@@ -21,7 +21,7 @@

 __device__ __forceinline__ float sqr(float x) { return x * x; }

-// 计算 JFA 代价: (q - p)^2 + weight[p]
+// Compute the JFA cost: (q - p)^2 + weight[p]
 __device__ __forceinline__ float compute_cost(int q, int p, float val_p) {
     if (p < 0) return INF_VAL;
     return sqr((float)q - (float)p) + val_p;
@@ -30,25 +30,25 @@ __device__ __forceinline__ float compute_cost(int q, int p, float val_p) {
 // ------------------------------------------------------------------
 // JFA Core Logic (Device Only)
 // ------------------------------------------------------------------
-// 核心 JFA 逻辑,与数据位置无关 (Shared 或 Global 均通用)
+// Core JFA logic, independent of data location (works with both Shared and Global memory).
 __device__ void run_jfa_core(
     int N,
     int tid,
-    const float* __restrict__ vals,  // 输入权重 (只读)
+    const float* __restrict__ vals,  // input weight (read-only)
     int* __restrict__ idx_curr,      // Ping-Pong Buffer A
     int* __restrict__ idx_next       // Ping-Pong Buffer B
 ) {
-    // 1. 初始化: 根据 vals 决定是否是有效源点
+    // 1. Initialization: determine whether each pixel is a valid source based on vals.
     for (int i = tid; i < N; i += blockDim.x) {
         if (vals[i] >= INF_VAL * 0.9f) {
-            idx_curr[i] = -1;  // 背景
+            idx_curr[i] = -1;  // background
         } else {
-            idx_curr[i] = i;   // 物体/源点,初始索引指向自己
+            idx_curr[i] = i;   // For each object/source point, the initial index points to itself.
         }
     }
     __syncthreads();

-    // 2. 迭代传播 (Step = 1, 2, 4, ... < N)
+    // 2. Iterative Propagation (Step = 1, 2, 4, ... < N)
     int* idx_in = idx_curr;
     int* idx_out = idx_next;

@@ -57,7 +57,7 @@ __device__ void run_jfa_core(
             int my_best_p = idx_in[i];
             float min_cost = INF_VAL;

-            // 检查自己当前的最优解
+            // Check its current best solution
             if (my_best_p != -1) {
                 min_cost = compute_cost(i, my_best_p, vals[my_best_p]);
             }
@@ -97,7 +97,7 @@ __device__ void run_jfa_core(
         __syncthreads();
     }

-    // 3. 确保最终结果在 idx_curr (如果循环结束时在 next,则拷回)
+    // 3. Ensure the final result is stored in idx_curr (if the loop ends with idx_next, copy it back).
     if (idx_in != idx_curr) {
         for (int i = tid; i < N; i += blockDim.x) {
             idx_curr[i] = idx_next[i];
@@ -109,70 +109,71 @@ __device__ void run_jfa_core(
 // ------------------------------------------------------------------
 // Kernel 1: Shared Memory JFA (Fast Path)
 // ------------------------------------------------------------------
-// 模板参数 NDim: 如果 > 0,编译器会展开循环优化。
-// 参数 runtime_ndim: 如果 NDim == 0 (Default case),使用该参数作为维度。
+// Template parameter NDim: when NDim > 0, the compiler performs loop unrolling optimizations.
+// Runtime parameter runtime_ndim: when NDim == 0 (default behavior), this parameter specifies the dimension.
 template
 __global__ void edt_kernel_shared(
-    const float* __restrict__ in_data,        // 输入 Dist^2
-    const int32_t* __restrict__ in_indices,   // 输入 Indices
-    float* __restrict__ out_dist,             // 输出 Dist (IsFinal ? sqrt : sqr)
-    int32_t* __restrict__ out_indices,        // 输出 Indices
-    int64_t L,                                // 当前维度的长度
-    int64_t total_elements,                   // 总像素数
-    int runtime_ndim                          // 运行时维度 (fallback)
+    const float* __restrict__ in_data,        // input Dist^2
+    const int32_t* __restrict__ in_indices,   // input Indices
+    float* __restrict__ out_dist,             // output Dist (IsFinal ?
sqrt : sqr) + int32_t* __restrict__ out_indices, // output Indices + int64_t L, // Size of the current dimension + int64_t total_elements, // Total number of elements + int runtime_ndim // Runtime dimension (used as fallback) ) { - // 确定实际维度 + // Determine the effective dimension const int D = (NDim > 0) ? NDim : runtime_ndim; - // 计算行偏移 + // Compute row offset int64_t row_idx = blockIdx.x; int64_t offset = row_idx * L; - + if (offset >= total_elements) return; - // Shared Memory 布局 + // Shared memory layout extern __shared__ char s_buffer[]; float* s_vals = (float*)s_buffer; int* s_idx1 = (int*)(s_vals + L); int* s_idx2 = (int*)(s_idx1 + L); - // 1. 加载 Dist 到 Shared Memory + // 1. Load distances into Shared Memory for (int i = threadIdx.x; i < L; i += blockDim.x) { s_vals[i] = __ldg(&in_data[offset + i]); } __syncthreads(); - // 2. 运行 JFA 核心 + // 2. Run the core JFA logic run_jfa_core(L, threadIdx.x, s_vals, s_idx1, s_idx2); - - // 3. 写回结果 + + // 3. Write back the results for (int q = threadIdx.x; q < L; q += blockDim.x) { - int p = s_idx1[q]; // 最近点在当前行内的局部索引 (0..L-1) + int p = s_idx1[q]; // Nearest point (local index within 0..L-1) float dist_val; - // 计算新距离 + // Compute updated distance if (p != -1) { float dist_sq = sqr((float)q - (float)p) + s_vals[p]; dist_val = IsFinal ? sqrtf(dist_sq) : dist_sq; } else { + // No source point found (e.g., entire row is background) dist_val = IsFinal ? INF_VAL : sqr(INF_VAL); - p = 0; // 防止越界,随便指一个 + p = 0; // Prevent out-of-bounds access } out_dist[offset + q] = dist_val; - // 索引传播: Copy Vector [D] + // Propagate indices: copy a vector of size [D] if (p != -1) { int64_t src_offset = (offset + p) * D; int64_t dst_offset = (offset + q) * D; - - // 如果 NDim > 0,这里会完全展开,非常快 + + // When NDim > 0, this loop is fully unrolled by the compiler for (int d = 0; d < D; ++d) { out_indices[dst_offset + d] = in_indices[src_offset + d]; } } else { - // 找不到源点(全图都是背景的情况) - int64_t dst_offset = (offset + q) * D; - for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; + // Fallback: no source available + int64_t dst_offset = (offset + q) * D; + for (int d = 0; d < D; ++d) out_indices[dst_offset + d] = 0; } } } @@ -180,7 +181,7 @@ __global__ void edt_kernel_shared( // ------------------------------------------------------------------ // Kernel 2: Global Memory JFA (Fallback Path) // ------------------------------------------------------------------ -// 逻辑同上,只是用 Global Memory 做 Ping-Pong Buffer +// Same logic as above, but uses Global Memory as the ping-pong buffer template __global__ void edt_kernel_global( const float* __restrict__ in_data, @@ -200,14 +201,14 @@ __global__ void edt_kernel_global( if (offset >= total_elements) return; - // 指向 Global Memory 的指针 + // Pointers to Global Memory int* g_idx1 = global_buffer_1 + offset; int* g_idx2 = global_buffer_2 + offset; - // 1. & 2. 运行 JFA (直接在 Global Mem 上读写) + // 1. & 2. Run the JFA core (operating directly on Global Memory) run_jfa_core(L, threadIdx.x, in_data + offset, g_idx1, g_idx2); - // 3. 写回结果 + // 3. 
Write back results for (int q = threadIdx.x; q < L; q += blockDim.x) { int p = g_idx1[q]; float dist_val; @@ -236,10 +237,11 @@ __global__ void edt_kernel_global( } } + // ------------------------------------------------------------------ // Kernel 3: Initialize Indices // ------------------------------------------------------------------ -// 初始化索引张量为网格坐标 +// Initialize index tensor as grid coordinates // indices shape: (..., D) __global__ void init_indices_kernel( int32_t* indices, @@ -250,24 +252,24 @@ __global__ void init_indices_kernel( int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_pixels) return; - // 反解坐标 (Unravel Index) - // idx 是每个像素的 flat index - // 我们需要计算它在 spatial_shape 中的坐标 + // Unravel Index + // idx is the flat index of each pixel + // We need to compute its coordinate in spatial_shape int64_t temp = idx; - // 使用本地寄存器数组避免多次全局内存读取 (假设最大 10 维) + // Use local register array to avoid repeated global memory reads (assume max 10 dims) int32_t coords[10]; - // 假设 spatial_shape 是 [D0, D1, D2] - // 倒序计算除余 + // Example: spatial_shape = [D0, D1, D2] + // compute by modulo from last dimension for (int d = NDim - 1; d >= 0; --d) { int64_t dim_size = shape_ptr[d]; coords[d] = temp % dim_size; temp /= dim_size; } - // 写入 Global Memory - // Indices tensor 是 (TotalPixels, NDim) 扁平化的 + // Write to Global Memory + // Indices tensor is flattened as (TotalPixels, NDim) int64_t out_ptr = idx * NDim; for (int d = 0; d < NDim; ++d) { indices[out_ptr + d] = coords[d]; @@ -284,24 +286,23 @@ std::tuple distance_transform_cuda(torch::Tensor i input = input.contiguous(); - // 处理 Batch 维度:如果输入是 1D (L),视为无 Batch,但在处理中统一加一个 Batch 维方便 - // 标准约定:Input shape (Batch, D1, D2, ..., Dn) - // 算法对 Batch 维度和其他维度处理其实是一样的(视为无关维度) - // 但索引初始化需要知道哪些是 "Spatial Dimensions"。 - // 这里假设:输入的所有维度除了 Batch (Dim 0) 外都是空间维度。 + // Handle batch dimension: if input is 1D (L), treat as no batch but internally add a batch dimension. + // Convention: input shape is (Batch, D1, D2, ..., Dn) + // Algorithm treats batch and other dims identically (batch is just another leading dimension) + // But index initialization needs to know which are "spatial dimensions". + // Assumption: all dims except dim 0 (Batch) are spatial. const int ndim = input.dim(); - // 如果 ndim=1, 假设是 (L),sample_ndim=1 - // 如果 ndim=4 (B, C, H, W),sample_ndim=3 (C,H,W 都算空间? 通常 C 也是独立处理的) - // **修正**: 标准 EDT 通常是在 (H, W) 或 (D, H, W) 上进行的。 - // 如果有 Channel,通常 Channel 也是独立的。 - // 为了最通用,我们将 **除了第0维(Batch)** 以外的所有维度都视为空间维度进行索引记录。 - // 如果用户输入没有 Batch 维,请在 Python 端 unsqueeze(0)。 + // If ndim=1, assume (L) → sample_ndim=1 + // If ndim=4 (B, C, H, W), sample_ndim=3 (C,H,W treated as spatial? Channels often processed independently) + // Correction: classical EDT usually runs on (H,W) or (D,H,W). + // If channels exist, typically each channel is processed independently. + // For maximum generality, we treat **all dims except dim 0** as spatial dims. + // If input has no batch dim, user should use unsqueeze(0) in Python. - // 假设输入已经是 (Batch, ...Spatial...) const int sample_ndim = ndim - 1; TORCH_CHECK(sample_ndim > 0, "Input tensor must have at least 2 dimensions (Batch, ...)"); - + auto shape = input.sizes().vec(); int64_t num_pixels = input.numel(); @@ -312,23 +313,23 @@ std::tuple distance_transform_cuda(torch::Tensor i torch::empty(index_shape, input.options().dtype(torch::kInt32))); } - // 1. 初始化 Distance Tensor + // 1. 
Initialize Distance Tensor // 0 -> 0, 1 -> INF auto current_dist = torch::where(input == 0, torch::tensor(0.0f, input.options()), torch::tensor(INF_VAL, input.options())); - // 2. 初始化 Index Tensor + // 2. Initialize Index Tensor // Shape: (Batch, D1, ..., Dn, sample_ndim) auto index_shape = shape; index_shape.push_back(sample_ndim); auto current_idx = torch::empty(index_shape, input.options().dtype(torch::kInt32)); - // 2.1 准备 Shape 数据传给 Kernel + // 2.1 Prepare shape tensor for kernel std::vector spatial_shape(shape.begin() + 1, shape.end()); auto shape_tensor = torch::tensor(spatial_shape, torch::kInt64).to(input.device()); - // 2.2 运行初始化 Kernel + // 2.2 Launch initialization kernel { int threads = 256; int blocks = (num_pixels + threads - 1) / threads; @@ -344,20 +345,20 @@ std::tuple distance_transform_cuda(torch::Tensor i } } - // 预分配 Global Memory Buffer (懒加载) + // Pre-allocate Global Memory Buffers (lazy) torch::Tensor global_buf1, global_buf2; - // 3. 逐维处理 (Separable JFA) - // 遍历每一个空间维度 (从 1 到 ndim-1) + // 3. Process each spatial dimension (Separable JFA) + // Iterate through each spatial dimension (1 to ndim-1) for (int d = 1; d < ndim; ++d) { bool is_final_pass = (d == ndim - 1); // --- Step A: Transpose current dim to last --- - // 变换后 Shape: (..., L) + // Resulting shape: (..., L) auto dist_in = current_dist.transpose(d, ndim - 1).contiguous(); auto idx_in = current_idx.transpose(d, ndim - 1).contiguous(); - int64_t L = dist_in.size(-1); // 当前维度的长度 + int64_t L = dist_in.size(-1); int64_t total_slices = dist_in.numel() / L; auto dist_out = torch::empty_like(dist_in); @@ -366,11 +367,11 @@ std::tuple distance_transform_cuda(torch::Tensor i // --- Step B: Kernel Dispatch --- int threads = std::min((int64_t)MAX_THREADS, L); - // 检查是否可以使用 Shared Memory + // Check whether Shared Memory can be used if (L <= SMEM_LIMIT_ELEMENTS) { size_t smem_size = L * (sizeof(float) + 2 * sizeof(int)); - // 使用 Switch 宏来处理常用的维度模板特化 + // Switch macro to handle template dimension specialization #define DISPATCH_SHARED(IS_FINAL) \ switch(sample_ndim) { \ case 1: edt_kernel_shared<<>>( \ @@ -408,7 +409,7 @@ std::tuple distance_transform_cuda(torch::Tensor i else { DISPATCH_SHARED(false); } } else { - // Global Memory Fallback (L > 4096) + // Global Memory fallback (L > 4096) if (global_buf1.numel() < dist_in.numel()) { global_buf1 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); global_buf2 = torch::empty({dist_in.numel()}, torch::TensorOptions().dtype(torch::kInt32).device(input.device())); @@ -464,4 +465,5 @@ std::tuple distance_transform_cuda(torch::Tensor i } return std::make_tuple(current_dist, current_idx); -} \ No newline at end of file +} + From d622ccddacb338ddea3cd10384450857db99db63 Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Wed, 10 Dec 2025 23:43:03 +0800 Subject: [PATCH 33/39] flake8 --- .flake8 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 5fc408e..6ef9e30 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 100 -extend-ignore = E203, W503 # Compatibility with Black +extend-ignore = E203, W503 exclude = __pycache__, build, From ee9e779a5606dde6290712a2cd55a29ce293f517 Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Wed, 10 Dec 2025 23:45:05 +0800 Subject: [PATCH 34/39] -> --- torchmorph/csrc/distance_transform_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmorph/csrc/distance_transform_kernel.cu 
b/torchmorph/csrc/distance_transform_kernel.cu index 41defa6..503d13c 100644 --- a/torchmorph/csrc/distance_transform_kernel.cu +++ b/torchmorph/csrc/distance_transform_kernel.cu @@ -293,7 +293,7 @@ std::tuple distance_transform_cuda(torch::Tensor i // Assumption: all dims except dim 0 (Batch) are spatial. const int ndim = input.dim(); - // If ndim=1, assume (L) → sample_ndim=1 + // If ndim=1, assume (L) -> sample_ndim=1 // If ndim=4 (B, C, H, W), sample_ndim=3 (C,H,W treated as spatial? Channels often processed independently) // Correction: classical EDT usually runs on (H,W) or (D,H,W). // If channels exist, typically each channel is processed independently. From c42bb3eb119839f1215a157633648d344b8b413a Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Wed, 10 Dec 2025 23:50:39 +0800 Subject: [PATCH 35/39] check ascii for py --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a79f4db..86088b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,4 +22,4 @@ repos: entry: python3 scripts/check_ascii.py language: system types: [file] - files: '\.(c|cc|cpp|cxx|cu|cuh|h|hpp)$' + files: '\.(c|cc|cpp|cxx|cu|cuh|h|hpp|py)$' From b8fbc8e4a177af96d34fdae14283f1a8ae36ef0c Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Wed, 10 Dec 2025 23:54:25 +0800 Subject: [PATCH 36/39] check non-latin languages --- scripts/check_ascii.py | 77 +++++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/scripts/check_ascii.py b/scripts/check_ascii.py index 75d9151..bc1dcdd 100644 --- a/scripts/check_ascii.py +++ b/scripts/check_ascii.py @@ -1,40 +1,94 @@ #!/usr/bin/env python3 import sys +import unicodedata from pathlib import Path TARGET_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".cu", ".cuh", ".h", ".hpp", ".py"} -def find_non_ascii(line): - """Return list of (index, char) for all non-ASCII chars in a line.""" + +# --- Helpers -------------------------------------------------------- + + +# Latin ranges we still consider "English-ish" and therefore allowed. +# (You can shrink this if you want to ban accented letters too.) +LATIN_RANGES = [ + (0x0000, 0x007F), # Basic Latin (ASCII) + (0x00C0, 0x024F), # Latin-1 Supplement + Latin Extended-A/B + (0x1E00, 0x1EFF), # Latin Extended Additional +] + + +def in_ranges(ch: str, ranges) -> bool: + cp = ord(ch) + for start, end in ranges: + if start <= cp <= end: + return True + return False + + +def is_forbidden_char(ch: str) -> bool: + """ + Return True if ch should be *forbidden*. + + Policy: + - ASCII (<= 0x7F): always OK + - Non-ASCII letters (Unicode category starting with 'L') + that are NOT in Latin ranges: forbidden + - Everything else (emoji, arrows, symbols, etc.): allowed + """ + cp = ord(ch) + if cp <= 0x7F: + return False # pure ASCII + + cat = unicodedata.category(ch) + + # Forbid letters that are not Latin. + if cat.startswith("L"): # Letter + if in_ranges(ch, LATIN_RANGES): + return False # Latin letters allowed + return True # Non-Latin letters forbidden + + # All non-letter stuff (emoji, arrows, symbols, punctuation) is allowed. 
+ return False + + +def find_forbidden_chars(line: str): + """Return list of (index, char) for all forbidden chars in a line.""" result = [] for i, ch in enumerate(line): - if ord(ch) > 127: + if is_forbidden_char(ch): result.append((i, ch)) return result +# --- Core logic ----------------------------------------------------- + + def check_file(path: Path) -> bool: ok = True with path.open("r", encoding="utf-8", errors="ignore") as f: for lineno, line in enumerate(f, start=1): - non_ascii = find_non_ascii(line) - if non_ascii: + forbidden = find_forbidden_chars(line) + if forbidden: ok = False - print(f"\n❌ {path}:{lineno}: non-ASCII characters detected") + print(f"\n❌ {path}:{lineno}: non-English letters detected") # Print the full line - print(f" Line content:") + print(" Line content:") print(f" {line.rstrip()}") - # Underline the exact non-ASCII characters + # Underline the forbidden characters underline = [" " for _ in line.rstrip("\n")] - for idx, ch in non_ascii: + for idx, ch in forbidden: if idx < len(underline): underline[idx] = "^" - print(f" {' '.join(underline)}") + print(f" {''.join(underline)}") # Print what characters exactly - chars = ", ".join(f"'{ch}' (U+{ord(ch):04X})" for _, ch in non_ascii) + chars = ", ".join( + f"'{ch}' (U+{ord(ch):04X}) [{unicodedata.name(ch, 'UNKNOWN')}]" + for _, ch in forbidden + ) print(f" Offending chars: {chars}") return ok @@ -55,4 +109,3 @@ def main(files): print("Usage: check_ascii.py ") sys.exit(1) main(sys.argv[1:]) - From 4ed06101a9e0edc33068a16e8dcfc3ea1be58dac Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Wed, 10 Dec 2025 23:58:19 +0800 Subject: [PATCH 37/39] reformat --- test/test_distance_transform.py | 178 +++++++++++++++++++++----------- 1 file changed, 117 insertions(+), 61 deletions(-) diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index 6e6265d..56acb37 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -1,60 +1,100 @@ -import torch -import pytest import numpy as np +import pytest +import torch from scipy.ndimage import distance_transform_edt as scipy_edt -import torchmorph as tm + +import torchmorph as tm + # ====================================================================== -# 辅助函数 +# Helper functions # ====================================================================== -def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - dist_results, indices_results = [], [] - - # 保证 batch_numpy 至少是 (Batch, ...) - # 如果进来的是 (H, W),我们在外面已经处理成 (1, H, W) 了 +def batch_scipy_edt_with_indices( + batch_numpy: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute SciPy EDT and indices for a batch of arrays.""" + dist_results: list[np.ndarray] = [] + indices_results: list[np.ndarray] = [] + + # Ensure batch_numpy has at least shape (Batch, ...) + # If the input is (H, W), it is already converted to (1, H, W) outside. if batch_numpy.ndim == 1: batch_numpy = batch_numpy[np.newaxis, ...] 
for sample in batch_numpy: - dist, indices = scipy_edt(sample, return_indices=True, return_distances=True) + dist, indices = scipy_edt( + sample, + return_indices=True, + return_distances=True, + ) dist_results.append(dist) indices_results.append(indices) - + output_dist = np.stack(dist_results, axis=0) output_indices = np.stack(indices_results, axis=0) - output_indices = np.moveaxis(output_indices, 1, -1) - + output_indices = np.moveaxis(output_indices, 1, -1) + return output_dist, output_indices + # ====================================================================== -# 测试数据 +# Test data # ====================================================================== -case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32) +case_batch_1d = np.array( + [[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], + dtype=np.float32, +) -case_batch_2d = np.array([[[0., 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]], - [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.float32) +case_batch_2d = np.array( + [ + [[0.0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]], + [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], + ], + dtype=np.float32, +) -# 这里定义为 (4, 4),意图是单张 2D 图 -case_single_2d = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]], dtype=np.float32) +# This is a single 2D image with shape (4, 4) +case_single_2d = np.array( + [ + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + ], + dtype=np.float32, +) case_explicit_batch_one = case_single_2d[np.newaxis, ...] -_case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s1[1, 1, 1] = 0.0; _case_3d_s1[2, 3, 4] = 0.0 -_case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32); _case_3d_s2[0, 0, 0] = 0.0 +_case_3d_s1 = np.ones((4, 5, 6), dtype=np.float32) +_case_3d_s1[1, 1, 1] = 0.0 +_case_3d_s1[2, 3, 4] = 0.0 + +_case_3d_s2 = np.ones((4, 5, 6), dtype=np.float32) +_case_3d_s2[0, 0, 0] = 0.0 + case_batch_3d = np.stack([_case_3d_s1, _case_3d_s2], axis=0) -case_dim_one = np.ones((2, 5, 1), dtype=np.float32); case_dim_one[0, 2, 0] = 0.0; case_dim_one[1, 4, 0] = 0.0 +case_dim_one = np.ones((2, 5, 1), dtype=np.float32) +case_dim_one[0, 2, 0] = 0.0 +case_dim_one[1, 4, 0] = 0.0 + +# 4D spatial case +_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32) +_case_4d_s1[0, 0, 0, 0] = 0.0 + +_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32) +_case_4d_s2[1, 1, 1, 1] = 0.0 -# 4D Case -_case_4d_s1 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s1[0, 0, 0, 0] = 0.0 -_case_4d_s2 = np.ones((3, 3, 3, 3), dtype=np.float32); _case_4d_s2[1, 1, 1, 1] = 0.0 case_batch_4d_spatial = np.stack([_case_4d_s1, _case_4d_s2], axis=0) -# 5D Case +# 5D spatial case case_batch_5d_spatial = np.ones((1, 2, 2, 2, 2, 2), dtype=np.float32) -case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0; case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 +case_batch_5d_spatial[0, 0, 0, 0, 0, 0] = 0.0 +case_batch_5d_spatial[0, 1, 1, 1, 1, 1] = 0.0 + # ====================================================================== -# 测试逻辑 +# Test logic # ====================================================================== @pytest.mark.parametrize( "input_numpy, has_batch_dim", @@ -62,66 +102,82 @@ def batch_scipy_edt_with_indices(batch_numpy: np.ndarray) -> tuple[np.ndarray, n pytest.param(case_batch_1d, True, id="1D_Batch"), pytest.param(case_batch_2d, True, id="2D_Batch"), pytest.param(case_single_2d, False, id="2D_Single_NoBatch"), - pytest.param(case_explicit_batch_one, True, id="2D_Single_ExplicitBatch"), + pytest.param( + case_explicit_batch_one, + True, + id="2D_Single_ExplicitBatch", + ), 
pytest.param(case_batch_3d, True, id="3D_Batch"), pytest.param(case_dim_one, True, id="2D_UnitDim_Batch"), pytest.param(case_batch_4d_spatial, True, id="4D_Spatial_Batch"), pytest.param(case_batch_5d_spatial, True, id="5D_Spatial_Batch"), ], ) -def test_distance_transform_and_indices(input_numpy: np.ndarray, has_batch_dim: bool, request: pytest.FixtureRequest): +def test_distance_transform_and_indices( + input_numpy: np.ndarray, + has_batch_dim: bool, + request: pytest.FixtureRequest, +) -> None: if not torch.cuda.is_available(): pytest.skip("CUDA not available") - - # 1. 准备 Numpy 数据 + + # 1. Prepare NumPy data x_numpy_contiguous = np.ascontiguousarray(input_numpy) - - # 2. 准备 SciPy 输入 - # 如果意图是单样本 (has_batch_dim=False),我们手动增加 Batch 维, - # 这样 scipy 辅助函数就会把它当做一张图来处理,而不是 N 张 1D 图 + + # 2. Prepare SciPy input. + # If this is a single sample (has_batch_dim=False), manually add a + # batch dimension so SciPy treats it as one image instead of N 1D + # signals. if not has_batch_dim: scipy_input = x_numpy_contiguous[np.newaxis, ...] else: scipy_input = x_numpy_contiguous - # 3. 准备 CUDA 输入 - # 关键修复: - # 如果 has_batch_dim=False,说明这是单张 (H, W),我们要测 2D EDT。 - # C++ API 默认第一维是 Batch,所以我们必须 unsqueeze(0) 变成 (1, H, W)。 - # 否则 C++ 会把它当做 (Batch=H, Len=W) 做 1D EDT。 + # 3. Prepare CUDA input. + # If has_batch_dim=False, the input is (H, W) and we want 2D EDT. + # The C++ API assumes the first dimension is batch, so we must + # unsqueeze(0) to get shape (1, H, W). Otherwise, it will be + # interpreted as (Batch=H, Length=W) and run 1D EDT. x_cuda = torch.from_numpy(x_numpy_contiguous).cuda() if not has_batch_dim: x_cuda = x_cuda.unsqueeze(0) - print(f"\n\n--- 运行测试: {request.node.callspec.id} ---") - print(f"CUDA 输入形状: {x_cuda.shape}") + print(f"\n\n--- Running test: {request.node.callspec.id} ---") + print(f"CUDA input shape: {x_cuda.shape}") - # 4. 运行 CUDA EDT + # 4. Run CUDA EDT dist_cuda, idx_cuda = tm.distance_transform(x_cuda.clone()) - # 5. 运行 SciPy (Ground Truth) + # 5. Run SciPy (ground truth) dist_ref_numpy, idx_ref_numpy = batch_scipy_edt_with_indices(scipy_input) dist_ref = torch.from_numpy(dist_ref_numpy).to(torch.float32).cuda() - # 6. 验证距离 - # 此时 dist_cuda 是 (1, H, W),dist_ref 也是 (1, H, W) - # 如果原意是 NoBatch,我们可以把 Batch 维 squeeze 掉再比,或者直接比 - print(f"CUDA Out Shape: {dist_cuda.shape}, Ref Shape: {dist_ref.shape}") - assert dist_cuda.shape == dist_ref.shape, f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + # 6. Validate distances + print( + f"CUDA distance shape: {dist_cuda.shape}, " + f"reference shape: {dist_ref.shape}", + ) + assert dist_cuda.shape == dist_ref.shape, ( + f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + ) torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3) - print(">> 距离验证通过。") + print(">> Distance validation passed.") - # 7. 验证索引 - # idx_cuda: (1, H, W, 2) - # 构造 Grid - spatial_shape = x_cuda.shape[1:] # (H, W) - coords = [torch.arange(s, device='cuda') for s in spatial_shape] - grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1) # (H, W, 2) - grid = grid.unsqueeze(0) # (1, H, W, 2) + # 7. 
Validate indices + # idx_cuda: (B, H, W, D) + spatial_shape = x_cuda.shape[1:] + coords = [torch.arange(s, device="cuda") for s in spatial_shape] + grid = torch.stack(torch.meshgrid(*coords, indexing="ij"), dim=-1) + grid = grid.unsqueeze(0) # (1, H, W, D) diff = grid.float() - idx_cuda.float() dist_sq_calculated = torch.sum(diff * diff, dim=-1) dist_sq_output = dist_cuda * dist_cuda - - torch.testing.assert_close(dist_sq_calculated, dist_sq_output, atol=1e-3, rtol=1e-3) - print(">> 索引验证通过。") \ No newline at end of file + + torch.testing.assert_close( + dist_sq_calculated, + dist_sq_output, + atol=1e-3, + rtol=1e-3, + ) + print(">> Index validation passed.") From 95724d9e200fde67276cad184ff5bd381f496607 Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Thu, 11 Dec 2025 00:04:11 +0800 Subject: [PATCH 38/39] reformat --- benchmark/distance_transform.py | 4 ---- test/test_distance_transform.py | 7 +++---- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/benchmark/distance_transform.py b/benchmark/distance_transform.py index c659c82..c809ca3 100644 --- a/benchmark/distance_transform.py +++ b/benchmark/distance_transform.py @@ -1,9 +1,6 @@ import torch import torch.utils.benchmark as benchmark -import scipy.ndimage as ndi -import numpy as np from prettytable import PrettyTable -import torchmorph as tm sizes = [64, 128, 256, 512, 1024] batches = [1, 4, 8, 16] @@ -76,4 +73,3 @@ print(f"\n=== Batch Size: {B} ===") print(table) - diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index 56acb37..f852d36 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -1,9 +1,8 @@ -import numpy as np +import numpy as np # noqa: F401 import pytest import torch -from scipy.ndimage import distance_transform_edt as scipy_edt - -import torchmorph as tm +from scipy.ndimage import distance_transform_edt as scipy_edt # noqa: F401 +import torchmorph as tm # noqa: F401 # ====================================================================== From 86381e40e1d863ebc21c685956953f94a9553f50 Mon Sep 17 00:00:00 2001 From: Kai ZHAO <694691@qq.com> Date: Thu, 11 Dec 2025 00:07:26 +0800 Subject: [PATCH 39/39] isort --- benchmark/distance_transform.py | 20 +++++++++++--------- scripts/check_ascii.py | 2 +- setup.py | 5 +++-- test/test_add.py | 3 ++- test/test_distance_transform.py | 10 +++++----- torchmorph/add.py | 1 + torchmorph/distance_transform.py | 1 + 7 files changed, 24 insertions(+), 18 deletions(-) diff --git a/benchmark/distance_transform.py b/benchmark/distance_transform.py index c809ca3..3737ced 100644 --- a/benchmark/distance_transform.py +++ b/benchmark/distance_transform.py @@ -27,7 +27,7 @@ # Inputs x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] - x_imgs = [x[i:i+1] for i in range(B)] + x_imgs = [x[i : i + 1] for i in range(B)] # SciPy (CPU, one-by-one) stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]" @@ -62,14 +62,16 @@ speed1 = scipy_per_img_ms / torch1_per_img_ms speedB = scipy_per_img_ms / torchB_per_img_ms - table.add_row([ - s, - f"{scipy_per_img_ms:.3f}", - f"{torch1_per_img_ms:.3f}", - f"{torchB_per_img_ms:.3f}", - f"{speed1:.1f}×", - f"{speedB:.1f}×", - ]) + table.add_row( + [ + s, + f"{scipy_per_img_ms:.3f}", + f"{torch1_per_img_ms:.3f}", + f"{torchB_per_img_ms:.3f}", + f"{speed1:.1f}×", + f"{speedB:.1f}×", + ] + ) print(f"\n=== Batch Size: {B} ===") print(table) diff --git a/scripts/check_ascii.py 
b/scripts/check_ascii.py index bc1dcdd..d788056 100644 --- a/scripts/check_ascii.py +++ b/scripts/check_ascii.py @@ -46,7 +46,7 @@ def is_forbidden_char(ch: str) -> bool: if cat.startswith("L"): # Letter if in_ranges(ch, LATIN_RANGES): return False # Latin letters allowed - return True # Non-Latin letters forbidden + return True # Non-Latin letters forbidden # All non-letter stuff (emoji, arrows, symbols, punctuation) is allowed. return False diff --git a/setup.py b/setup.py index 9876b88..fd3f95e 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -import os import glob -from setuptools import setup, find_packages +import os + +from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension diff --git a/test/test_add.py b/test/test_add.py index a641494..5f647ff 100644 --- a/test/test_add.py +++ b/test/test_add.py @@ -1,5 +1,6 @@ -import torch import pytest +import torch + import torchmorph as tm diff --git a/test/test_distance_transform.py b/test/test_distance_transform.py index f852d36..5855bf3 100644 --- a/test/test_distance_transform.py +++ b/test/test_distance_transform.py @@ -2,6 +2,7 @@ import pytest import torch from scipy.ndimage import distance_transform_edt as scipy_edt # noqa: F401 + import torchmorph as tm # noqa: F401 @@ -153,12 +154,11 @@ def test_distance_transform_and_indices( # 6. Validate distances print( - f"CUDA distance shape: {dist_cuda.shape}, " - f"reference shape: {dist_ref.shape}", - ) - assert dist_cuda.shape == dist_ref.shape, ( - f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" + f"CUDA distance shape: {dist_cuda.shape}, " f"reference shape: {dist_ref.shape}", ) + assert ( + dist_cuda.shape == dist_ref.shape + ), f"Shape mismatch: {dist_cuda.shape} vs {dist_ref.shape}" torch.testing.assert_close(dist_cuda, dist_ref, atol=1e-3, rtol=1e-3) print(">> Distance validation passed.") diff --git a/torchmorph/add.py b/torchmorph/add.py index 4737073..f7c16b9 100644 --- a/torchmorph/add.py +++ b/torchmorph/add.py @@ -1,4 +1,5 @@ import torch + from torchmorph import _C diff --git a/torchmorph/distance_transform.py b/torchmorph/distance_transform.py index 0184be5..868e84a 100644 --- a/torchmorph/distance_transform.py +++ b/torchmorph/distance_transform.py @@ -1,4 +1,5 @@ import torch + from torchmorph import _C
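
A quick end-to-end sketch of the API this series converges on. This is a usage
illustration, not part of the patch stream: it assumes the torchmorph build
produced by these patches, a CUDA device, and the batch-first convention the
tests exercise (tm.distance_transform takes a float tensor of shape
(Batch, ...spatial...) and returns a (distances, indices) pair, with indices
shaped (Batch, ...spatial..., D)). The mask contents and sizes below are
arbitrary.

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt as scipy_edt

import torchmorph as tm

assert torch.cuda.is_available(), "torchmorph's EDT kernels are CUDA-only"

# A simple (H, W) mask: zeros are sources, ones are foreground.
mask = np.ones((64, 64), dtype=np.float32)
mask[10, 20] = 0.0
mask[50, 40] = 0.0

# The C++ entry point treats dim 0 as batch, so add an explicit batch dim.
x = torch.from_numpy(mask).cuda().unsqueeze(0)  # (1, 64, 64)
dist, idx = tm.distance_transform(x)  # dist: (1, 64, 64), idx: (1, 64, 64, 2)

# The distances should match SciPy's EDT computed on the CPU.
ref = torch.from_numpy(scipy_edt(mask)).float()
torch.testing.assert_close(dist[0].cpu(), ref, atol=1e-3, rtol=1e-3)

# idx stores, for every pixel, the coordinates of its nearest source, so the
# distance map can be reconstructed from the index map alone.
coords = torch.stack(
    torch.meshgrid(
        torch.arange(64, device="cuda"),
        torch.arange(64, device="cuda"),
        indexing="ij",
    ),
    dim=-1,
)  # (64, 64, 2) grid of pixel coordinates
delta = coords.float() - idx[0].float()
recon = torch.sqrt((delta * delta).sum(dim=-1))
torch.testing.assert_close(recon, dist[0], atol=1e-3, rtol=1e-3)

Reconstructing the distance map from the index map is the same invariant the
test suite checks, which makes this a cheap smoke test when porting the
kernels to new hardware.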