Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6e60ec0
init distance transform
zeakey Oct 27, 2025
875483d
docs: 详细记录并修复项目环境搭建与构建流程
moliflower Oct 28, 2025
79c5acd
实现二维欧式距离变换(EDT)
moliflower Nov 1, 2025
72943e4
N维批处理的欧氏距离变换
moliflower Nov 1, 2025
f9420b2
修复BUG:原先版本中误将0当成背景,1当成前景
moliflower Nov 2, 2025
85bd1e5
returns both distance and index
zeakey Nov 3, 2025
be4eb3d
format
zeakey Nov 3, 2025
8f835e3
benchmark
zeakey Nov 3, 2025
7e10879
benchmark outputs tables
zeakey Nov 3, 2025
7b6b8aa
prettytable
zeakey Nov 3, 2025
15933d9
实现n维批处理同时返回坐标和距离的精确欧式距离变换
moliflower Nov 20, 2025
11f067d
测试文件调整+速度优化
moliflower Nov 20, 2025
4c427c7
采用JFA算法提高并行度
moliflower Dec 2, 2025
9291a49
优化合并内存
moliflower Dec 2, 2025
698ce24
增加3维以上维度的计算处理
moliflower Dec 6, 2025
bc9c03b
实现n维批处理同时返回坐标和距离的精确欧式距离变换
moliflower Dec 8, 2025
a6fafaa
测试文件调整+速度优化
moliflower Dec 8, 2025
54464a8
采用JFA算法提高并行度
moliflower Dec 8, 2025
32a09da
合并优化内存
moliflower Dec 8, 2025
5a9872f
增加三维以上维度的计算处理
moliflower Dec 8, 2025
4493039
forbid non-ascii
zeakey Dec 9, 2025
81e43df
workflow: precommit
zeakey Dec 9, 2025
95f8f34
workflow: run on all files
zeakey Dec 9, 2025
8341085
prevent duplicated ci
zeakey Dec 9, 2025
fd4cd5e
test workflow
zeakey Dec 9, 2025
1b520c2
test workflow
zeakey Dec 9, 2025
157bfac
test workflow
zeakey Dec 9, 2025
6599991
CUDA_HOME
zeakey Dec 9, 2025
d417ba9
CUDA_HOME
zeakey Dec 9, 2025
c93176c
CUDA_HOME
zeakey Dec 9, 2025
384f641
check ascii for .py
zeakey Dec 10, 2025
da7cadd
修改distance_transform_kernel.cu注释为英文
dongliangnie Dec 10, 2025
d622ccd
flake8
zeakey Dec 10, 2025
ee9e779
->
zeakey Dec 10, 2025
c42bb3e
check ascii for py
zeakey Dec 10, 2025
b8fbc8e
check non-latin languages
zeakey Dec 10, 2025
4ed0610
reformat
zeakey Dec 10, 2025
95724d9
reformat
zeakey Dec 10, 2025
86381e4
isort
zeakey Dec 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 100
extend-ignore = E203, W503 # Compatibility with Black
extend-ignore = E203, W503
exclude =
__pycache__,
build,
Expand Down
32 changes: 32 additions & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: pre-commit

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "*" ]

jobs:
pre-commit:
runs-on: ubuntu-latest

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.x"

- name: Install pre-commit
run: |
pip install pre-commit

- name: Install pre-commit hooks
run: |
pre-commit install --install-hooks

- name: Run pre-commit on all files
run: |
pre-commit run --all-files
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: test

on:
pull_request:
branches: ["*"]
push:
branches: [main]

jobs:

test:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- run: |
# Make CUDA visible to this shell and all child processes
export CUDA_HOME=/usr/local/cuda
export PATH="$CUDA_HOME/bin:$PATH"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64:${LD_LIBRARY_PATH:-}"
echo "CUDA_HOME=$CUDA_HOME"
pip install -r requirements-dev.txt --break-system-packages --user
pip uninstall torchmorph --yes --break-system-packages
python setup.py install --user
- run: |
ORIGINAL=$(pwd)
cd /tmp
pytest $ORIGINAL/test
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,15 @@ repos:
rev: 6.1.0
hooks:
- id: flake8

# -------------------------
# ⭐ Local Hook: forbid non-ASCII in C/C++/CUDA
# -------------------------
- repo: local
hooks:
- id: forbid-non-ascii
name: "Forbid non-ASCII characters in C/C++/CUDA"
entry: python3 scripts/check_ascii.py
language: system
types: [file]
files: '\.(c|cc|cpp|cxx|cu|cuh|h|hpp|py)$'
77 changes: 77 additions & 0 deletions benchmark/distance_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
import torch.utils.benchmark as benchmark
from prettytable import PrettyTable

sizes = [64, 128, 256, 512, 1024]
batches = [1, 4, 8, 16]
dtype = torch.float32
device = "cuda"
MIN_RUN = 1.0 # seconds per measurement

torch.set_num_threads(torch.get_num_threads())

for B in batches:
table = PrettyTable()
table.field_names = [
"Size",
"SciPy (ms/img)",
"Torch 1× (ms/img)",
"Torch batch (ms/img)",
"Speedup 1×",
"Speedup batch",
]
for c in table.field_names:
table.align[c] = "r"

for s in sizes:
# Inputs
x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype)
x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)]
x_imgs = [x[i : i + 1] for i in range(B)]

# SciPy (CPU, one-by-one)
stmt_scipy = "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]"
t_scipy = benchmark.Timer(
stmt=stmt_scipy,
setup="from __main__ import x_np_list, ndi",
num_threads=torch.get_num_threads(),
).blocked_autorange(min_run_time=MIN_RUN)
scipy_per_img_ms = (t_scipy.median * 1e3) / B

# Torch (CUDA, one-by-one)
stmt_torch1 = """
for xi in x_imgs:
tm.distance_transform(xi)
"""
t_torch1 = benchmark.Timer(
stmt=stmt_torch1,
setup="from __main__ import x_imgs, tm",
num_threads=torch.get_num_threads(),
).blocked_autorange(min_run_time=MIN_RUN)
torch1_per_img_ms = (t_torch1.median * 1e3) / B

# Torch (CUDA, batched)
t_batch = benchmark.Timer(
stmt="tm.distance_transform(x)",
setup="from __main__ import x, tm",
num_threads=torch.get_num_threads(),
).blocked_autorange(min_run_time=MIN_RUN)
torchB_per_img_ms = (t_batch.median * 1e3) / B

# Speedups
speed1 = scipy_per_img_ms / torch1_per_img_ms
speedB = scipy_per_img_ms / torchB_per_img_ms

table.add_row(
[
s,
f"{scipy_per_img_ms:.3f}",
f"{torch1_per_img_ms:.3f}",
f"{torchB_per_img_ms:.3f}",
f"{speed1:.1f}×",
f"{speedB:.1f}×",
]
)

print(f"\n=== Batch Size: {B} ===")
print(table)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ max-line-length = 100
extend-ignore = ["E203", "W503"]

[tool.pytest.ini_options]
addopts = "-v"
addopts = "-v --import-mode=importlib"
testpaths = ["test"]

[build-system]
requires = ["setuptools>=61.0", "wheel", "torch", "numpy"]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ flake8>=6.0
setuptools>=65.0
wheel>=0.40
ninja>=1.11 # optional, speeds up torch extension builds

prettytable>=3.16.0
111 changes: 111 additions & 0 deletions scripts/check_ascii.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#!/usr/bin/env python3
import sys
import unicodedata
from pathlib import Path

TARGET_SUFFIXES = {".c", ".cc", ".cpp", ".cxx", ".cu", ".cuh", ".h", ".hpp", ".py"}


# --- Helpers --------------------------------------------------------


# Latin ranges we still consider "English-ish" and therefore allowed.
# (You can shrink this if you want to ban accented letters too.)
LATIN_RANGES = [
(0x0000, 0x007F), # Basic Latin (ASCII)
(0x00C0, 0x024F), # Latin-1 Supplement + Latin Extended-A/B
(0x1E00, 0x1EFF), # Latin Extended Additional
]


def in_ranges(ch: str, ranges) -> bool:
cp = ord(ch)
for start, end in ranges:
if start <= cp <= end:
return True
return False


def is_forbidden_char(ch: str) -> bool:
"""
Return True if ch should be *forbidden*.

Policy:
- ASCII (<= 0x7F): always OK
- Non-ASCII letters (Unicode category starting with 'L')
that are NOT in Latin ranges: forbidden
- Everything else (emoji, arrows, symbols, etc.): allowed
"""
cp = ord(ch)
if cp <= 0x7F:
return False # pure ASCII

cat = unicodedata.category(ch)

# Forbid letters that are not Latin.
if cat.startswith("L"): # Letter
if in_ranges(ch, LATIN_RANGES):
return False # Latin letters allowed
return True # Non-Latin letters forbidden

# All non-letter stuff (emoji, arrows, symbols, punctuation) is allowed.
return False


def find_forbidden_chars(line: str):
"""Return list of (index, char) for all forbidden chars in a line."""
result = []
for i, ch in enumerate(line):
if is_forbidden_char(ch):
result.append((i, ch))
return result


# --- Core logic -----------------------------------------------------


def check_file(path: Path) -> bool:
ok = True
with path.open("r", encoding="utf-8", errors="ignore") as f:
for lineno, line in enumerate(f, start=1):
forbidden = find_forbidden_chars(line)
if forbidden:
ok = False
print(f"\n❌ {path}:{lineno}: non-English letters detected")

# Print the full line
print(" Line content:")
print(f" {line.rstrip()}")

# Underline the forbidden characters
underline = [" " for _ in line.rstrip("\n")]
for idx, ch in forbidden:
if idx < len(underline):
underline[idx] = "^"
print(f" {''.join(underline)}")

# Print what characters exactly
chars = ", ".join(
f"'{ch}' (U+{ord(ch):04X}) [{unicodedata.name(ch, 'UNKNOWN')}]"
for _, ch in forbidden
)
print(f" Offending chars: {chars}")

return ok


def main(files):
ok = True
for f in files:
p = Path(f)
if p.suffix.lower() in TARGET_SUFFIXES and p.exists():
if not check_file(p):
ok = False
sys.exit(0 if ok else 1)


if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: check_ascii.py <files...>")
sys.exit(1)
main(sys.argv[1:])
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import glob
from setuptools import setup, find_packages
import os

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


Expand Down
3 changes: 2 additions & 1 deletion test/test_add.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import pytest
import torch

import torchmorph as tm


Expand Down
Loading