From 8e85f18270cbbf37562213425d61638d70b7904e Mon Sep 17 00:00:00 2001 From: dongliangnie Date: Sat, 13 Dec 2025 18:25:36 +0800 Subject: [PATCH 1/4] Add binary dilation and erosion with N-D support --- benchmark/dilation_erosion.py | 93 ++++++++++++++++++ torchmorph/__init__.py | 4 +- torchmorph/dilation_erosion.py | 166 +++++++++++++++++++++++++++++++++ 3 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 benchmark/dilation_erosion.py create mode 100644 torchmorph/dilation_erosion.py diff --git a/benchmark/dilation_erosion.py b/benchmark/dilation_erosion.py new file mode 100644 index 0000000..ed49464 --- /dev/null +++ b/benchmark/dilation_erosion.py @@ -0,0 +1,93 @@ +import torch +import torch.utils.benchmark as benchmark +import scipy.ndimage as ndi +import numpy as np +from prettytable import PrettyTable +import torchmorph as tm + +sizes = [64, 128, 256, 512] +batches = [1, 4, 8, 16] +dtype = torch.float32 +device = "cuda" +MIN_RUN = 1.0 # seconds per measurement + +torch.set_num_threads(torch.get_num_threads()) + + +def bench_single_op(op_name): + """ + op_name: "dilation" or "erosion" + """ + + scipy_op = ndi.binary_dilation if op_name == "dilation" else ndi.binary_erosion + torch_op = tm.binary_dilation if op_name == "dilation" else tm.binary_erosion + + print("\n==============================") + print(f" Benchmark: Binary {op_name}") + print("==============================") + + for B in batches: + table = PrettyTable() + table.field_names = [ + "Size", + "SciPy (ms/img)", + "Torch 1× (ms/img)", + "Torch batch (ms/img)", + "Speedup 1×", + "Speedup batch", + ] + for c in table.field_names: + table.align[c] = "r" + + for s in sizes: + # Generate binary input + x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) + x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] + x_imgs = [x[i:i+1] for i in range(B)] + + # SciPy (CPU, one-by-one) + stmt_scipy = "out = [scipy_op(arr) for arr in x_np_list]" + t_scipy = benchmark.Timer( + stmt=stmt_scipy, + globals={"x_np_list": x_np_list, "scipy_op": scipy_op}, + ).blocked_autorange(min_run_time=MIN_RUN) + scipy_ms = (t_scipy.median * 1e3) / B + + # Torch CUDA (one-by-one) + stmt_torch1 = """ +for xi in x_imgs: + torch_op(xi) +""" + t_torch1 = benchmark.Timer( + stmt=stmt_torch1, + globals={"x_imgs": x_imgs, "torch_op": torch_op}, + ).blocked_autorange(min_run_time=MIN_RUN) + torch1_ms = (t_torch1.median * 1e3) / B + + # Torch CUDA (batched) + t_batch = benchmark.Timer( + stmt="torch_op(x)", + globals={"x": x, "torch_op": torch_op}, + ).blocked_autorange(min_run_time=MIN_RUN) + torchB_ms = (t_batch.median * 1e3) / B + + # Speedups + speed1 = scipy_ms / torch1_ms + speedB = scipy_ms / torchB_ms + + table.add_row([ + s, + f"{scipy_ms:.3f}", + f"{torch1_ms:.3f}", + f"{torchB_ms:.3f}", + f"{speed1:.1f}×", + f"{speedB:.1f}×", + ]) + + print(f"\n=== Batch Size: {B} ===") + print(table) + + +print("Loaded from:", tm.__file__) +bench_single_op("dilation") +bench_single_op("erosion") diff --git a/torchmorph/__init__.py b/torchmorph/__init__.py index 1f31247..a8e3d67 100644 --- a/torchmorph/__init__.py +++ b/torchmorph/__init__.py @@ -1,4 +1,6 @@ from .add import add from .distance_transform import distance_transform +from .dilation import binary_dilation,binary_erosion -__all__ = ["add", "distance_transform"] +__all__ = ["add", "distance_transform","binary_dilation", + "binary_erosion",] diff --git a/torchmorph/dilation_erosion.py b/torchmorph/dilation_erosion.py new file mode 100644 index 0000000..7484500 --- /dev/null +++ b/torchmorph/dilation_erosion.py @@ -0,0 +1,166 @@ +import torch +import torch.nn.functional as F +from typing import Optional, Union, Sequence, Tuple + +def _to_bool_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Convert an input value into a boolean PyTorch tensor. + + This helper function ensures that the input is represented as a + `torch.bool` tensor, which is the internal format required by + binary morphological operations (e.g., dilation/erosion). + + Behavior: + - If `x` is not already a tensor, it is converted using `torch.tensor(x)`. + - Non-zero values become `True`; zero values become `False`. + + Args: + x (torch.Tensor or array-like): + Input data. May be a Python list, scalar, NumPy array, or torch.Tensor. + + Returns: + torch.Tensor (dtype=torch.bool): + Boolean tensor where each element is `True` if corresponding input value + is non-zero, otherwise `False`. + + Examples: + >>> _to_bool_tensor([0, 1, 2]) + tensor([False, True, True]) + + >>> _to_bool_tensor(torch.tensor([3.0, 0.0])) + tensor([True, False]) + """ + # If x is not a tensor yet (e.g., list, numpy array, int, float), convert to tensor. + if not torch.is_tensor(x): + x = torch.tensor(x) + + # Convert input tensor into boolean by checking non-zero status. + # Non-zero -> True, zero -> False. + return (x != 0) + + +def _normalize_structure(structure: Optional[torch.Tensor], ndim: int) -> torch.Tensor: + if structure is None: + shape = (3,) * ndim + return torch.ones(shape, dtype=torch.bool) + st = _to_bool_tensor(structure) + if st.ndim != ndim: + raise ValueError(f"structure must be {ndim}-D (got {st.ndim}-D)") + return st + +def _origin_to_tuple(origin: Union[int, Sequence[int], Tuple[int,...]], ndim: int) -> Tuple[int,...]: + if isinstance(origin, int): + return tuple([origin] * ndim) + origin = tuple(origin) + if len(origin) != ndim: + raise ValueError("origin must match spatial ndim") + return origin + +def _pad_for_kernel(kernel_shape: Sequence[int], origin: Sequence[int]) -> Tuple[Tuple[int,int], ...]: + pads = [] + for k, o in zip(kernel_shape, origin): + pad_before = k//2 - o + pad_after = k - 1 - pad_before + pad_before = max(pad_before, 0) + pad_after = max(pad_after, 0) + pads.append((pad_before, pad_after)) + return tuple(pads) + +def _make_padding_tuple_for_Fpad(pads: Tuple[Tuple[int,int], ...]) -> Tuple[int,...]: + flat = [] + for pb, pa in reversed(pads): + flat.append(pb) + flat.append(pa) + return tuple(flat) + +def _conv_nd(x: torch.Tensor, kernel: torch.Tensor, ndim: int) -> torch.Tensor: + weight = ( + kernel.to(dtype=x.dtype, device=x.device) + .unsqueeze(0).unsqueeze(0) + ) + if ndim == 1: + return F.conv1d(x, weight) + elif ndim == 2: + return F.conv2d(x, weight) + elif ndim == 3: + return F.conv3d(x, weight) + else: + raise NotImplementedError("Only supports 1D/2D/3D") + +def _morph_op( + input_tensor: torch.Tensor, + structure: Optional[torch.Tensor], + iterations: int, + origin: Union[int, Sequence[int]], + border_value: int, + mode: str +) -> torch.Tensor: + + if mode not in ('dilation', 'erosion'): + raise ValueError("mode must be 'dilation' or 'erosion'") + + x = input_tensor + if not torch.is_tensor(x): + x = torch.tensor(x) + + x_bool = (x != 0) + + # Support: (H,W), (C,H,W), (B,C,H,W), (B,C,D,H,W) + full_ndim = x_bool.ndim + + if full_ndim < 2: + raise NotImplementedError("Need at least 2D (H,W)") + if full_ndim > 5: + raise NotImplementedError("Only supports up to 5D (B,C,D,H,W)") + + # Spatial dims = last 1~3 dims + spatial_ndim = full_ndim - 2 # remove (B,C) + if not (1 <= spatial_ndim <= 3): + raise NotImplementedError("Supports 1D/2D/3D spatial dims") + + B, C = x_bool.shape[0], x_bool.shape[1] + spatial_shape = x_bool.shape[2:] + + # structure must match spatial dims + st = _normalize_structure(structure, spatial_ndim) + origin_t = _origin_to_tuple(origin, spatial_ndim) + + k_sum = st.sum().item() + kernel = st.to(torch.float32) + + # apply origin shift + for axis, o in enumerate(origin_t): + if o != 0: + kernel = torch.roll(kernel, shifts=-o, dims=axis) + + pads = _pad_for_kernel(kernel.shape, origin_t) + pad_tuple = _make_padding_tuple_for_Fpad(pads) + + # cast to float + cur = x_bool.to(torch.float32) + + # Now do B*C loops, because conv2d can't do dilation per-channel independently + cur = cur.view(B*C, 1, *spatial_shape) + + for _ in range(max(1, iterations)): + x_pad = F.pad(cur, pad_tuple, value=float(border_value)) + conv_res = _conv_nd(x_pad, kernel, spatial_ndim) + conv_res = conv_res # shape unchanged: (BC,1,H,W) or (BC,1,D,H,W) + + if mode == 'dilation': + cur = (conv_res > 0).to(torch.float32) + else: # erosion + if k_sum == 0: + cur = torch.ones_like(cur) + else: + cur = (conv_res >= (k_sum - 1e-6)).to(torch.float32) + + # reshape back + out = cur.view(B, C, *spatial_shape) + return out.to(torch.bool) + +def binary_dilation(input_tensor, structure=None, iterations=1, origin=0, border_value=0): + return _morph_op(input_tensor, structure, iterations, origin, border_value, mode="dilation") + +def binary_erosion(input_tensor, structure=None, iterations=1, origin=0, border_value=0): + return _morph_op(input_tensor, structure, iterations, origin, border_value, mode="erosion") From 0061fc13ab689cef1cf47073a186a868dbc4c36d Mon Sep 17 00:00:00 2001 From: dongliangnie Date: Sat, 13 Dec 2025 19:54:44 +0800 Subject: [PATCH 2/4] Add more detailed comments to explain the purpose and behavior of this code. --- benchmark/dilation_erosion.py | 7 +- torchmorph/dilation_erosion.py | 210 +++++++++++++++++++++++++++++---- 2 files changed, 187 insertions(+), 30 deletions(-) diff --git a/benchmark/dilation_erosion.py b/benchmark/dilation_erosion.py index ed49464..0ee94fb 100644 --- a/benchmark/dilation_erosion.py +++ b/benchmark/dilation_erosion.py @@ -43,8 +43,7 @@ def bench_single_op(op_name): # Generate binary input x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] - x_imgs = [x[i:i+1] for i in range(B)] - + x_imgs = [x[i:i+1] for i in range(B)] # (1, 1, H, W) # SciPy (CPU, one-by-one) stmt_scipy = "out = [scipy_op(arr) for arr in x_np_list]" t_scipy = benchmark.Timer( @@ -54,10 +53,6 @@ def bench_single_op(op_name): scipy_ms = (t_scipy.median * 1e3) / B # Torch CUDA (one-by-one) - stmt_torch1 = """ -for xi in x_imgs: - torch_op(xi) -""" t_torch1 = benchmark.Timer( stmt=stmt_torch1, globals={"x_imgs": x_imgs, "torch_op": torch_op}, diff --git a/torchmorph/dilation_erosion.py b/torchmorph/dilation_erosion.py index 7484500..85c7153 100644 --- a/torchmorph/dilation_erosion.py +++ b/torchmorph/dilation_erosion.py @@ -40,44 +40,187 @@ def _to_bool_tensor(x: torch.Tensor) -> torch.Tensor: def _normalize_structure(structure: Optional[torch.Tensor], ndim: int) -> torch.Tensor: + """ + Normalize a structuring element into a boolean tensor with the correct + number of spatial dimensions. + + This utility function standardizes user-provided structuring elements + for binary morphological operations (e.g., dilation and erosion). + + Behavior: + 1. If `structure` is None, a default full-connectivity structuring + element of shape (3, 3, ..., 3) with `ndim` dimensions is created. + This matches the default behavior of scipy.ndimage morphology. + 2. If `structure` is provided, it is converted into a boolean tensor, + where non-zero values are treated as True. + 3. The dimensionality of the structuring element is strictly checked + to ensure it matches the spatial dimensionality of the input. + A mismatch indicates an invalid morphological definition and + raises a ValueError. + + Args: + structure (Optional[torch.Tensor]): + Structuring element defining the neighborhood for morphology. + If None, a full (3,) * ndim boolean structure is used. + ndim (int): + Number of spatial dimensions of the input (e.g., 2 for H×W, + 3 for D×H×W). Batch and channel dimensions are excluded. + + Returns: + torch.Tensor (dtype=torch.bool): + An `ndim`-dimensional boolean tensor representing the normalized + structuring element. + + Raises: + ValueError: + If the provided structuring element does not have exactly + `ndim` dimensions. + + Notes: + - This function does not enforce any particular kernel size other + than dimensionality; arbitrary shapes are allowed. + - Channel and batch dimensions are intentionally not supported + for structuring elements, as morphology is defined purely in + spatial dimensions. + + Examples: + >>> _normalize_structure(None, ndim=2) + tensor([[True, True, True], + [True, True, True], + [True, True, True]]) + + >>> _normalize_structure([[0, 1, 0], + ... [1, 1, 1], + ... [0, 1, 0]], ndim=2) + tensor([[False, True, False], + [ True, True, True], + [False, True, False]]) + """ + # Case 1: No structuring element provided by the user. + # Use a default full-connectivity neighborhood of size 3 in each + # spatial dimension (e.g., 3×3 for 2D, 3×3×3 for 3D). if structure is None: shape = (3,) * ndim return torch.ones(shape, dtype=torch.bool) + + # Case 2: A structuring element is provided. + # Convert it to a boolean tensor so that non-zero values indicate + # active neighbors and zero values are ignored. st = _to_bool_tensor(structure) + + # Validate dimensionality: the structuring element must have the same + # number of dimensions as the spatial dimensions of the input tensor. if st.ndim != ndim: - raise ValueError(f"structure must be {ndim}-D (got {st.ndim}-D)") + raise ValueError( + f"structure must be {ndim}-D (got {st.ndim}-D)" + ) + + # Return the normalized boolean structuring element. return st -def _origin_to_tuple(origin: Union[int, Sequence[int], Tuple[int,...]], ndim: int) -> Tuple[int,...]: + +def _origin_to_tuple( + origin: Union[int, Sequence[int], Tuple[int, ...]], + ndim: int +) -> Tuple[int, ...]: + """ + Normalize the `origin` argument into an ndim-length tuple. + + The origin defines the anchor point of the structuring element, + consistent with SciPy's definition. + + Args: + origin (int or sequence of int): + If an int is given, it is broadcast to all spatial dimensions. + If a sequence is given, its length must match `ndim`. + ndim (int): + Number of spatial dimensions. + + Returns: + Tuple[int, ...]: + Origin offset per spatial dimension. + """ + # If a scalar is given, replicate it across all dimensions. if isinstance(origin, int): return tuple([origin] * ndim) + + # Otherwise, ensure it is a tuple with correct dimensionality. origin = tuple(origin) if len(origin) != ndim: raise ValueError("origin must match spatial ndim") + return origin -def _pad_for_kernel(kernel_shape: Sequence[int], origin: Sequence[int]) -> Tuple[Tuple[int,int], ...]: + +def _pad_for_kernel( + kernel_shape: Sequence[int], + origin: Sequence[int] +) -> Tuple[Tuple[int, int], ...]: + """ + Compute per-dimension padding sizes required to keep output shape + identical to input shape after convolution. + + This takes into account the kernel size and the origin offset. + + Returns: + Tuple of (pad_before, pad_after) for each spatial dimension. + """ pads = [] for k, o in zip(kernel_shape, origin): - pad_before = k//2 - o + # Default symmetric padding would be k//2, + # but origin shifts the effective center. + pad_before = k // 2 - o pad_after = k - 1 - pad_before + + # Padding must be non-negative. pad_before = max(pad_before, 0) - pad_after = max(pad_after, 0) + pad_after = max(pad_after, 0) + pads.append((pad_before, pad_after)) return tuple(pads) -def _make_padding_tuple_for_Fpad(pads: Tuple[Tuple[int,int], ...]) -> Tuple[int,...]: + +def _make_padding_tuple_for_Fpad( + pads: Tuple[Tuple[int, int], ...] +) -> Tuple[int, ...]: + """ + Convert per-dimension padding into the flattened format required + by torch.nn.functional.pad. + + PyTorch expects padding in reverse order: + (pad_last_dim_left, pad_last_dim_right, ..., pad_first_dim_left, pad_first_dim_right) + """ flat = [] for pb, pa in reversed(pads): flat.append(pb) flat.append(pa) return tuple(flat) + def _conv_nd(x: torch.Tensor, kernel: torch.Tensor, ndim: int) -> torch.Tensor: + """ + Dispatch N-dimensional convolution based on spatial dimensionality. + + Args: + x (torch.Tensor): + Input tensor of shape (B*C, 1, *spatial_dims) + kernel (torch.Tensor): + Structuring element kernel. + ndim (int): + Number of spatial dimensions (1, 2, or 3). + + Returns: + torch.Tensor: + Convolution result. + """ + # Convert kernel into convolution weight: + # shape -> (out_channels=1, in_channels=1, *kernel_shape) weight = ( kernel.to(dtype=x.dtype, device=x.device) - .unsqueeze(0).unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) ) + if ndim == 1: return F.conv1d(x, weight) elif ndim == 2: @@ -87,6 +230,7 @@ def _conv_nd(x: torch.Tensor, kernel: torch.Tensor, ndim: int) -> torch.Tensor: else: raise NotImplementedError("Only supports 1D/2D/3D") + def _morph_op( input_tensor: torch.Tensor, structure: Optional[torch.Tensor], @@ -95,7 +239,31 @@ def _morph_op( border_value: int, mode: str ) -> torch.Tensor: + """ + Core implementation of binary dilation and erosion using convolution. + + This function supports batch and channel dimensions by flattening + (B, C) into a single dimension and applying morphology independently + per channel. + + Args: + input_tensor (torch.Tensor): + Input binary tensor. + structure (Optional[torch.Tensor]): + Structuring element. + iterations (int): + Number of times to apply the operation. + origin: + Origin offset of the structuring element. + border_value (int): + Value used for padding outside image boundaries. + mode (str): + Either 'dilation' or 'erosion'. + Returns: + torch.Tensor (dtype=torch.bool): + Output binary tensor. + """ if mode not in ('dilation', 'erosion'): raise ValueError("mode must be 'dilation' or 'erosion'") @@ -103,62 +271,56 @@ def _morph_op( if not torch.is_tensor(x): x = torch.tensor(x) + # Convert input to boolean (binary morphology). x_bool = (x != 0) - - # Support: (H,W), (C,H,W), (B,C,H,W), (B,C,D,H,W) + # Supported input shapes: + # (H,W), (C,H,W), (B,C,H,W), (B,C,D,H,W) full_ndim = x_bool.ndim if full_ndim < 2: raise NotImplementedError("Need at least 2D (H,W)") if full_ndim > 5: raise NotImplementedError("Only supports up to 5D (B,C,D,H,W)") - - # Spatial dims = last 1~3 dims - spatial_ndim = full_ndim - 2 # remove (B,C) + spatial_ndim = full_ndim - 2 # remove (B,C) if not (1 <= spatial_ndim <= 3): raise NotImplementedError("Supports 1D/2D/3D spatial dims") B, C = x_bool.shape[0], x_bool.shape[1] spatial_shape = x_bool.shape[2:] - - # structure must match spatial dims st = _normalize_structure(structure, spatial_ndim) origin_t = _origin_to_tuple(origin, spatial_ndim) k_sum = st.sum().item() kernel = st.to(torch.float32) - # apply origin shift + # Apply origin shift by rolling kernel. for axis, o in enumerate(origin_t): if o != 0: kernel = torch.roll(kernel, shifts=-o, dims=axis) - pads = _pad_for_kernel(kernel.shape, origin_t) pad_tuple = _make_padding_tuple_for_Fpad(pads) - # cast to float cur = x_bool.to(torch.float32) - # Now do B*C loops, because conv2d can't do dilation per-channel independently - cur = cur.view(B*C, 1, *spatial_shape) - + # Flatten (B,C) -> (B*C,1) + cur = cur.view(B * C, 1, *spatial_shape) for _ in range(max(1, iterations)): x_pad = F.pad(cur, pad_tuple, value=float(border_value)) conv_res = _conv_nd(x_pad, kernel, spatial_ndim) - conv_res = conv_res # shape unchanged: (BC,1,H,W) or (BC,1,D,H,W) if mode == 'dilation': + # Any overlap -> True cur = (conv_res > 0).to(torch.float32) - else: # erosion + else: + # Full overlap -> True if k_sum == 0: cur = torch.ones_like(cur) else: cur = (conv_res >= (k_sum - 1e-6)).to(torch.float32) - - # reshape back out = cur.view(B, C, *spatial_shape) return out.to(torch.bool) + def binary_dilation(input_tensor, structure=None, iterations=1, origin=0, border_value=0): return _morph_op(input_tensor, structure, iterations, origin, border_value, mode="dilation") From 41d99a54b45257cea053236097d81938fbb8f94c Mon Sep 17 00:00:00 2001 From: dongliangnie Date: Tue, 16 Dec 2025 11:28:03 +0800 Subject: [PATCH 3/4] Fix flake8 lint issues --- benchmark/dilation_erosion.py | 4 ++-- torchmorph/__init__.py | 5 ++--- torchmorph/dilation_erosion.py | 6 ++++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmark/dilation_erosion.py b/benchmark/dilation_erosion.py index 0ee94fb..5a4e7ca 100644 --- a/benchmark/dilation_erosion.py +++ b/benchmark/dilation_erosion.py @@ -1,7 +1,6 @@ import torch import torch.utils.benchmark as benchmark import scipy.ndimage as ndi -import numpy as np from prettytable import PrettyTable import torchmorph as tm @@ -43,7 +42,7 @@ def bench_single_op(op_name): # Generate binary input x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] - x_imgs = [x[i:i+1] for i in range(B)] # (1, 1, H, W) + x_imgs = [x[i:i+1] for i in range(B)] # (1, 1, H, W) # SciPy (CPU, one-by-one) stmt_scipy = "out = [scipy_op(arr) for arr in x_np_list]" t_scipy = benchmark.Timer( @@ -53,6 +52,7 @@ def bench_single_op(op_name): scipy_ms = (t_scipy.median * 1e3) / B # Torch CUDA (one-by-one) + stmt_torch1 = "out = [torch_op(img) for img in x_imgs]" t_torch1 = benchmark.Timer( stmt=stmt_torch1, globals={"x_imgs": x_imgs, "torch_op": torch_op}, diff --git a/torchmorph/__init__.py b/torchmorph/__init__.py index a8e3d67..821b203 100644 --- a/torchmorph/__init__.py +++ b/torchmorph/__init__.py @@ -1,6 +1,5 @@ from .add import add from .distance_transform import distance_transform -from .dilation import binary_dilation,binary_erosion +from .dilation_erosion import binary_dilation, binary_erosion -__all__ = ["add", "distance_transform","binary_dilation", - "binary_erosion",] +__all__ = ["add", "distance_transform", "binary_dilation", "binary_erosion",] diff --git a/torchmorph/dilation_erosion.py b/torchmorph/dilation_erosion.py index 85c7153..977f23f 100644 --- a/torchmorph/dilation_erosion.py +++ b/torchmorph/dilation_erosion.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from typing import Optional, Union, Sequence, Tuple + def _to_bool_tensor(x: torch.Tensor) -> torch.Tensor: """ Convert an input value into a boolean PyTorch tensor. @@ -170,11 +171,11 @@ def _pad_for_kernel( # Default symmetric padding would be k//2, # but origin shifts the effective center. pad_before = k // 2 - o - pad_after = k - 1 - pad_before + pad_after = k - 1 - pad_before # Padding must be non-negative. pad_before = max(pad_before, 0) - pad_after = max(pad_after, 0) + pad_after = max(pad_after, 0) pads.append((pad_before, pad_after)) return tuple(pads) @@ -324,5 +325,6 @@ def _morph_op( def binary_dilation(input_tensor, structure=None, iterations=1, origin=0, border_value=0): return _morph_op(input_tensor, structure, iterations, origin, border_value, mode="dilation") + def binary_erosion(input_tensor, structure=None, iterations=1, origin=0, border_value=0): return _morph_op(input_tensor, structure, iterations, origin, border_value, mode="erosion") From 3ec5587cf62dcd44d293d34d1657f97eb7b94d09 Mon Sep 17 00:00:00 2001 From: dongliangnie Date: Tue, 16 Dec 2025 13:25:03 +0800 Subject: [PATCH 4/4] style: apply black and isort formatting --- benchmark/dilation_erosion.py | 23 +++++++++++++---------- torchmorph/__init__.py | 9 +++++++-- torchmorph/dilation_erosion.py | 29 ++++++++++------------------- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/benchmark/dilation_erosion.py b/benchmark/dilation_erosion.py index 5a4e7ca..afe7e8e 100644 --- a/benchmark/dilation_erosion.py +++ b/benchmark/dilation_erosion.py @@ -1,7 +1,8 @@ +import scipy.ndimage as ndi import torch import torch.utils.benchmark as benchmark -import scipy.ndimage as ndi from prettytable import PrettyTable + import torchmorph as tm sizes = [64, 128, 256, 512] @@ -42,7 +43,7 @@ def bench_single_op(op_name): # Generate binary input x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype) x_np_list = [x[i, 0].detach().cpu().numpy() for i in range(B)] - x_imgs = [x[i:i+1] for i in range(B)] # (1, 1, H, W) + x_imgs = [x[i : i + 1] for i in range(B)] # (1, 1, H, W) # SciPy (CPU, one-by-one) stmt_scipy = "out = [scipy_op(arr) for arr in x_np_list]" t_scipy = benchmark.Timer( @@ -70,14 +71,16 @@ def bench_single_op(op_name): speed1 = scipy_ms / torch1_ms speedB = scipy_ms / torchB_ms - table.add_row([ - s, - f"{scipy_ms:.3f}", - f"{torch1_ms:.3f}", - f"{torchB_ms:.3f}", - f"{speed1:.1f}×", - f"{speedB:.1f}×", - ]) + table.add_row( + [ + s, + f"{scipy_ms:.3f}", + f"{torch1_ms:.3f}", + f"{torchB_ms:.3f}", + f"{speed1:.1f}×", + f"{speedB:.1f}×", + ] + ) print(f"\n=== Batch Size: {B} ===") print(table) diff --git a/torchmorph/__init__.py b/torchmorph/__init__.py index 821b203..f35a5c9 100644 --- a/torchmorph/__init__.py +++ b/torchmorph/__init__.py @@ -1,5 +1,10 @@ from .add import add -from .distance_transform import distance_transform from .dilation_erosion import binary_dilation, binary_erosion +from .distance_transform import distance_transform -__all__ = ["add", "distance_transform", "binary_dilation", "binary_erosion",] +__all__ = [ + "add", + "distance_transform", + "binary_dilation", + "binary_erosion", +] diff --git a/torchmorph/dilation_erosion.py b/torchmorph/dilation_erosion.py index 977f23f..4e35d1b 100644 --- a/torchmorph/dilation_erosion.py +++ b/torchmorph/dilation_erosion.py @@ -1,6 +1,7 @@ +from typing import Optional, Sequence, Tuple, Union + import torch import torch.nn.functional as F -from typing import Optional, Union, Sequence, Tuple def _to_bool_tensor(x: torch.Tensor) -> torch.Tensor: @@ -37,7 +38,7 @@ def _to_bool_tensor(x: torch.Tensor) -> torch.Tensor: # Convert input tensor into boolean by checking non-zero status. # Non-zero -> True, zero -> False. - return (x != 0) + return x != 0 def _normalize_structure(structure: Optional[torch.Tensor], ndim: int) -> torch.Tensor: @@ -112,17 +113,14 @@ def _normalize_structure(structure: Optional[torch.Tensor], ndim: int) -> torch. # Validate dimensionality: the structuring element must have the same # number of dimensions as the spatial dimensions of the input tensor. if st.ndim != ndim: - raise ValueError( - f"structure must be {ndim}-D (got {st.ndim}-D)" - ) + raise ValueError(f"structure must be {ndim}-D (got {st.ndim}-D)") # Return the normalized boolean structuring element. return st def _origin_to_tuple( - origin: Union[int, Sequence[int], Tuple[int, ...]], - ndim: int + origin: Union[int, Sequence[int], Tuple[int, ...]], ndim: int ) -> Tuple[int, ...]: """ Normalize the `origin` argument into an ndim-length tuple. @@ -154,8 +152,7 @@ def _origin_to_tuple( def _pad_for_kernel( - kernel_shape: Sequence[int], - origin: Sequence[int] + kernel_shape: Sequence[int], origin: Sequence[int] ) -> Tuple[Tuple[int, int], ...]: """ Compute per-dimension padding sizes required to keep output shape @@ -181,9 +178,7 @@ def _pad_for_kernel( return tuple(pads) -def _make_padding_tuple_for_Fpad( - pads: Tuple[Tuple[int, int], ...] -) -> Tuple[int, ...]: +def _make_padding_tuple_for_Fpad(pads: Tuple[Tuple[int, int], ...]) -> Tuple[int, ...]: """ Convert per-dimension padding into the flattened format required by torch.nn.functional.pad. @@ -216,11 +211,7 @@ def _conv_nd(x: torch.Tensor, kernel: torch.Tensor, ndim: int) -> torch.Tensor: """ # Convert kernel into convolution weight: # shape -> (out_channels=1, in_channels=1, *kernel_shape) - weight = ( - kernel.to(dtype=x.dtype, device=x.device) - .unsqueeze(0) - .unsqueeze(0) - ) + weight = kernel.to(dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(0) if ndim == 1: return F.conv1d(x, weight) @@ -238,7 +229,7 @@ def _morph_op( iterations: int, origin: Union[int, Sequence[int]], border_value: int, - mode: str + mode: str, ) -> torch.Tensor: """ Core implementation of binary dilation and erosion using convolution. @@ -273,7 +264,7 @@ def _morph_op( x = torch.tensor(x) # Convert input to boolean (binary morphology). - x_bool = (x != 0) + x_bool = x != 0 # Supported input shapes: # (H,W), (C,H,W), (B,C,H,W), (B,C,D,H,W) full_ndim = x_bool.ndim