Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
6e60ec0
init distance transform
zeakey Oct 27, 2025
875483d
docs: 详细记录并修复项目环境搭建与构建流程
moliflower Oct 28, 2025
79c5acd
实现二维欧式距离变换(EDT)
moliflower Nov 1, 2025
72943e4
N维批处理的欧氏距离变换
moliflower Nov 1, 2025
f9420b2
修复BUG:原先版本中误将0当成背景,1当成前景
moliflower Nov 2, 2025
85bd1e5
returns both distance and index
zeakey Nov 3, 2025
be4eb3d
format
zeakey Nov 3, 2025
8f835e3
benchmark
zeakey Nov 3, 2025
7e10879
benchmark outputs tables
zeakey Nov 3, 2025
7b6b8aa
prettytable
zeakey Nov 3, 2025
15933d9
实现n维批处理同时返回坐标和距离的精确欧式距离变换
moliflower Nov 20, 2025
11f067d
测试文件调整+速度优化
moliflower Nov 20, 2025
4c427c7
采用JFA算法提高并行度
moliflower Dec 2, 2025
9291a49
优化合并内存
moliflower Dec 2, 2025
698ce24
增加3维以上维度的计算处理
moliflower Dec 6, 2025
bc9c03b
实现n维批处理同时返回坐标和距离的精确欧式距离变换
moliflower Dec 8, 2025
a6fafaa
测试文件调整+速度优化
moliflower Dec 8, 2025
54464a8
采用JFA算法提高并行度
moliflower Dec 8, 2025
32a09da
合并优化内存
moliflower Dec 8, 2025
5a9872f
增加三维以上维度的计算处理
moliflower Dec 8, 2025
4493039
forbid non-ascii
zeakey Dec 9, 2025
81e43df
workflow: precommit
zeakey Dec 9, 2025
95f8f34
workflow: run on all files
zeakey Dec 9, 2025
8341085
prevent duplicated ci
zeakey Dec 9, 2025
fd4cd5e
test workflow
zeakey Dec 9, 2025
1b520c2
test workflow
zeakey Dec 9, 2025
157bfac
test workflow
zeakey Dec 9, 2025
6599991
CUDA_HOME
zeakey Dec 9, 2025
d417ba9
CUDA_HOME
zeakey Dec 9, 2025
c93176c
CUDA_HOME
zeakey Dec 9, 2025
384f641
check ascii for .py
zeakey Dec 10, 2025
da7cadd
修改distance_transform_kernel.cu注释为英文
dongliangnie Dec 10, 2025
d622ccd
flake8
zeakey Dec 10, 2025
ee9e779
->
zeakey Dec 10, 2025
c42bb3e
check ascii for py
zeakey Dec 10, 2025
b8fbc8e
check non-latin languages
zeakey Dec 10, 2025
4ed0610
reformat
zeakey Dec 10, 2025
95724d9
reformat
zeakey Dec 10, 2025
86381e4
isort
zeakey Dec 10, 2025
d2df725
移除switch精简代码 INF_VAL 改为 1e20
moliflower Dec 14, 2025
11a9657
add imports for scipy and torchmorph
moliflower Dec 14, 2025
bdaf48f
use JFA for 2D/3D and separable transform for high dimensions
moliflower Dec 20, 2025
326e247
speed up 2D
moliflower Dec 22, 2025
f91761a
speed up 3D
moliflower Dec 23, 2025
879eaa5
add 1D EDT test
moliflower Jan 2, 2026
e4ff013
1D EDT
moliflower Jan 2, 2026
8bf6318
make input tensor to be at least 1D
moliflower Jan 2, 2026
0f4c900
restore code to JFA
moliflower Jan 13, 2026
955e185
Restore code to JFA
moliflower Jan 13, 2026
4ce7da7
test ping
moliflower Jan 16, 2026
fa5c940
Merge branch 'dist' of github.com:intcomp/torchmorph into dist
moliflower Jan 16, 2026
bb2bb7e
Align scipy edt
moliflower Jan 24, 2026
ea02c5b
speed up 2D
moliflower Jan 24, 2026
4fb5ee5
add cdt function
moliflower Jan 28, 2026
fe845fe
Merge branch 'main' into dist
moliflower Jan 30, 2026
1d3546b
Feat: Implement Chamfer Distance (CDT) & Resolve merge conflicts
moliflower Jan 30, 2026
6a16d6b
add edt 3D benchmark
moliflower Feb 2, 2026
2d67351
add algorithm parameter for edt
moliflower Feb 2, 2026
292a8c4
Merge branch 'main' into dist
moliflower Feb 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 0 additions & 77 deletions benchmark/distance_transform.py

This file was deleted.

89 changes: 89 additions & 0 deletions benchmark/distance_transform_cdt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Benchmark torchmorph's chamfer distance transform (CDT) against SciPy.

Compares per-image runtime of ``scipy.ndimage.distance_transform_cdt`` (CPU,
one image at a time) with ``tm.distance_transform_cdt`` on CUDA, both
one-image-at-a-time and batched, for the "chessboard" and "taxicab" metrics.
Results are printed as one PrettyTable per (metric, batch size) pair.
"""
import scipy.ndimage as ndi
import torch
import torch.utils.benchmark as benchmark
from prettytable import PrettyTable

import torchmorph as tm

SIZES = [64, 128, 256, 512, 1024]
BATCHES = [1, 4, 8, 16]
DTYPE = torch.float32
DEVICE = "cuda"
MIN_RUN = 1.0  # seconds per benchmark.Timer measurement


def _median_ms_per_item(stmt, env, n_items):
    """Time ``stmt`` with torch.utils.benchmark; return median ms per item.

    ``env`` maps the names referenced by ``stmt`` to live objects via the
    Timer's ``globals`` argument, which works regardless of how this module
    is executed (unlike ``setup="from __main__ import ..."`` strings).
    """
    timer = benchmark.Timer(
        stmt=stmt,
        globals=env,
        num_threads=torch.get_num_threads(),
    )
    return (timer.blocked_autorange(min_run_time=MIN_RUN).median * 1e3) / n_items


def _run_metric(metric):
    """Benchmark one CDT metric across all batch sizes and image sizes."""
    print(f"\n{'='*60}")
    print(f" CDT Benchmark - Metric: {metric}")
    print(f"{'='*60}")

    for B in BATCHES:
        table = PrettyTable()
        table.field_names = [
            "Size",
            "SciPy (ms/img)",
            "Torch 1× (ms/img)",
            "Torch batch (ms/img)",
            "Speedup 1×",
            "Speedup batch",
        ]
        for col in table.field_names:
            table.align[col] = "r"

        for s in SIZES:
            # Inputs: (B, C, H, W) with C=1; random ~50/50 foreground mask.
            x = (torch.randn(B, 1, s, s, device=DEVICE) > 0).to(DTYPE)
            # SciPy operates on plain (H, W) numpy arrays.
            x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)]
            # Single-image torch inputs: each (1, 1, H, W).
            x_imgs = [x[i : i + 1] for i in range(B)]

            # SciPy (CPU, one image at a time).
            scipy_ms = _median_ms_per_item(
                "out = [ndi.distance_transform_cdt(arr, metric=metric)"
                " for arr in x_np_list]",
                {"ndi": ndi, "x_np_list": x_np_list, "metric": metric},
                B,
            )

            # Torch (CUDA, one image at a time).
            torch1_ms = _median_ms_per_item(
                "for xi in x_imgs:\n"
                "    tm.distance_transform_cdt(xi, metric=metric)",
                {"tm": tm, "x_imgs": x_imgs, "metric": metric},
                B,
            )

            # Torch (CUDA, batched).
            torchB_ms = _median_ms_per_item(
                "tm.distance_transform_cdt(x, metric=metric)",
                {"tm": tm, "x": x, "metric": metric},
                B,
            )

            table.add_row(
                [
                    s,
                    f"{scipy_ms:.3f}",
                    f"{torch1_ms:.3f}",
                    f"{torchB_ms:.3f}",
                    f"{scipy_ms / torch1_ms:.1f}×",
                    f"{scipy_ms / torchB_ms:.1f}×",
                ]
            )

        print(f"\n=== Metric: {metric}, Batch Size: {B} ===")
        print(table)


for metric in ["chessboard", "taxicab"]:
    _run_metric(metric)
106 changes: 106 additions & 0 deletions benchmark/distance_transform_edt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Benchmark torchmorph's 2D Euclidean distance transform (EDT) against SciPy.

Compares per-image runtime of ``scipy.ndimage.distance_transform_edt`` (CPU,
one image at a time) with ``tm.distance_transform_edt`` on CUDA — both the
"exact" and "jfa" algorithms, each one-image-at-a-time and batched.  Results
are printed as one PrettyTable per batch size; reported speedups compare
SciPy against the batched torch runs.
"""
import scipy.ndimage as ndi
import torch
import torch.utils.benchmark as benchmark
from prettytable import PrettyTable

import torchmorph as tm

SIZES = [64, 128, 256, 512, 1024]
BATCHES = [1, 4, 8, 16]
DTYPE = torch.float32
DEVICE = "cuda"
MIN_RUN = 1.0  # seconds per benchmark.Timer measurement


def _median_ms_per_item(stmt, env, n_items):
    """Time ``stmt`` with torch.utils.benchmark; return median ms per item.

    ``env`` maps the names referenced by ``stmt`` to live objects via the
    Timer's ``globals`` argument, which works regardless of how this module
    is executed (unlike ``setup="from __main__ import ..."`` strings).
    """
    timer = benchmark.Timer(
        stmt=stmt,
        globals=env,
        num_threads=torch.get_num_threads(),
    )
    return (timer.blocked_autorange(min_run_time=MIN_RUN).median * 1e3) / n_items


def _run_batch(B):
    """Benchmark 2D EDT for one batch size across all image sizes."""
    table = PrettyTable()
    table.field_names = [
        "Size",
        "SciPy (ms/img)",
        "Exact 1× (ms/img)",
        "Exact batch (ms/img)",
        "JFA 1× (ms/img)",
        "JFA batch (ms/img)",
        "Speedup Exact",
        "Speedup JFA",
    ]
    for col in table.field_names:
        table.align[col] = "r"

    # The same two statements are timed for both algorithms; only the
    # ``algo`` binding in the globals mapping changes.
    one_by_one = (
        "for xi in x_imgs:\n"
        "    tm.distance_transform_edt(xi, algorithm=algo)"
    )
    batched = "tm.distance_transform_edt(x, algorithm=algo)"

    for s in SIZES:
        # Inputs: (B, C, H, W) with C=1; random ~50/50 foreground mask.
        x = (torch.randn(B, 1, s, s, device=DEVICE) > 0).to(DTYPE)
        # SciPy operates on plain (H, W) numpy arrays.
        x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)]
        # Single-image torch inputs: each (1, 1, H, W).
        x_imgs = [x[i : i + 1] for i in range(B)]

        # SciPy (CPU, one image at a time).
        scipy_ms = _median_ms_per_item(
            "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]",
            {"ndi": ndi, "x_np_list": x_np_list},
            B,
        )

        # Torch exact / JFA, one-by-one and batched.
        exact1_ms = _median_ms_per_item(
            one_by_one, {"tm": tm, "x_imgs": x_imgs, "algo": "exact"}, B
        )
        exactB_ms = _median_ms_per_item(
            batched, {"tm": tm, "x": x, "algo": "exact"}, B
        )
        jfa1_ms = _median_ms_per_item(
            one_by_one, {"tm": tm, "x_imgs": x_imgs, "algo": "jfa"}, B
        )
        jfaB_ms = _median_ms_per_item(
            batched, {"tm": tm, "x": x, "algo": "jfa"}, B
        )

        table.add_row(
            [
                s,
                f"{scipy_ms:.3f}",
                f"{exact1_ms:.3f}",
                f"{exactB_ms:.3f}",
                f"{jfa1_ms:.3f}",
                f"{jfaB_ms:.3f}",
                f"{scipy_ms / exactB_ms:.1f}×",  # batched exact vs SciPy
                f"{scipy_ms / jfaB_ms:.1f}×",  # batched JFA vs SciPy
            ]
        )

    print(f"\n=== Batch Size: {B} ===")
    print(table)


for B in BATCHES:
    _run_batch(B)
112 changes: 112 additions & 0 deletions benchmark/distance_transform_edt_3D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Benchmark torchmorph's 3D Euclidean distance transform (EDT) against SciPy.

Compares per-volume runtime of ``scipy.ndimage.distance_transform_edt`` (CPU,
one volume at a time) with ``tm.distance_transform_edt`` on CUDA — both the
"exact" and "jfa" algorithms, each one-volume-at-a-time and batched.  Inputs
are (B, D, H, W) tensors with no channel dimension.  Large size/batch
combinations known to exhaust GPU memory are skipped and reported as "OOM".
"""
import scipy.ndimage as ndi
import torch
import torch.utils.benchmark as benchmark
from prettytable import PrettyTable

import torchmorph as tm

# 3D benchmark configurations
SIZES = [32, 64, 128, 256]  # D=H=W
BATCHES = [1, 2, 4, 8]
DTYPE = torch.float32
DEVICE = "cuda"
MIN_RUN = 1.0  # seconds per benchmark.Timer measurement


def _median_ms_per_item(stmt, env, n_items):
    """Time ``stmt`` with torch.utils.benchmark; return median ms per item.

    ``env`` maps the names referenced by ``stmt`` to live objects via the
    Timer's ``globals`` argument, which works regardless of how this module
    is executed (unlike ``setup="from __main__ import ..."`` strings).
    """
    timer = benchmark.Timer(
        stmt=stmt,
        globals=env,
        num_threads=torch.get_num_threads(),
    )
    return (timer.blocked_autorange(min_run_time=MIN_RUN).median * 1e3) / n_items


def _run_batch(B):
    """Benchmark 3D EDT for one batch size across all volume sizes."""
    table = PrettyTable()
    table.field_names = [
        "Size (D×H×W)",
        "SciPy (ms/vol)",
        "Exact 1× (ms/vol)",
        "Exact batch (ms/vol)",
        "JFA 1× (ms/vol)",
        "JFA batch (ms/vol)",
        "Speedup Exact",
        "Speedup JFA",
    ]
    for col in table.field_names:
        table.align[col] = "r"

    # The same two statements are timed for both algorithms; only the
    # ``algo`` binding in the globals mapping changes.
    one_by_one = (
        "for xi in x_vols:\n"
        "    tm.distance_transform_edt(xi, algorithm=algo)"
    )
    batched = "tm.distance_transform_edt(x, algorithm=algo)"

    for s in SIZES:
        # Skip large sizes with large batches to avoid OOM.
        if s >= 256 and B >= 4:
            table.add_row([f"{s}³", "OOM", "OOM", "OOM", "OOM", "OOM", "-", "-"])
            continue

        # Inputs: (B, D, H, W) — no channel dimension for 3D here;
        # random ~50/50 foreground mask.
        x = (torch.randn(B, s, s, s, device=DEVICE) > 0).to(DTYPE)
        # SciPy operates on plain (D, H, W) numpy arrays.
        x_np_list = [x[i].detach().cpu().numpy() for i in range(B)]
        # Single-volume torch inputs: each (1, D, H, W).
        x_vols = [x[i : i + 1] for i in range(B)]

        # SciPy (CPU, one volume at a time).
        scipy_ms = _median_ms_per_item(
            "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]",
            {"ndi": ndi, "x_np_list": x_np_list},
            B,
        )

        # Torch exact / JFA, one-by-one and batched.
        exact1_ms = _median_ms_per_item(
            one_by_one, {"tm": tm, "x_vols": x_vols, "algo": "exact"}, B
        )
        exactB_ms = _median_ms_per_item(
            batched, {"tm": tm, "x": x, "algo": "exact"}, B
        )
        jfa1_ms = _median_ms_per_item(
            one_by_one, {"tm": tm, "x_vols": x_vols, "algo": "jfa"}, B
        )
        jfaB_ms = _median_ms_per_item(
            batched, {"tm": tm, "x": x, "algo": "jfa"}, B
        )

        table.add_row(
            [
                f"{s}³",
                f"{scipy_ms:.3f}",
                f"{exact1_ms:.3f}",
                f"{exactB_ms:.3f}",
                f"{jfa1_ms:.3f}",
                f"{jfaB_ms:.3f}",
                f"{scipy_ms / exactB_ms:.1f}×",  # batched exact vs SciPy
                f"{scipy_ms / jfaB_ms:.1f}×",  # batched JFA vs SciPy
            ]
        )

    print(f"\n=== 3D EDT Benchmark | Batch Size: {B} ===")
    print(table)


for B in BATCHES:
    _run_batch(B)
Loading
Loading