From 8d73211abcabac0956404929c71cb01936bfc4a5 Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Wed, 3 Dec 2025 14:39:22 +0300 Subject: [PATCH 1/9] argument mode has been removed from dpipe.layers.interpolate_to_left --- dpipe/__version__.py | 2 +- dpipe/layers/fpn.py | 11 ++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/dpipe/__version__.py b/dpipe/__version__.py index 908c0bb..ed7d50e 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.4.3' +__version__ = '0.5.3' diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index 767c07e..e2a350e 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from torch.nn import functional +from torch.nn.functional import interpolate import numpy as np from dpipe.itertools import zip_equal, lmap @@ -131,18 +131,11 @@ def interpolate_merge(merge: Callable, order: int = 0): return lambda left, down: merge(*interpolate_to_left(left, down, order)) -def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, mode: str = None): - if mode is not None: - msg = 'Argument `mode` is deprecated. Use `order` instead.' - warn(msg, UserWarning) - warn(msg, DeprecationWarning) - order = mode - +def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0): if isinstance(order, int): order = order_to_mode(order, len(down.shape) - 2) if np.not_equal(left.shape, down.shape).any(): - interpolate = functional.interpolate if order in ['linear', 'bilinear', ' bicubic', 'trilinear']: interpolate = partial(interpolate, align_corners=False) From 7a147b52bed35eb601b1b241498faee8ab14a295 Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Wed, 3 Dec 2025 14:47:20 +0300 Subject: [PATCH 2/9] naming and conciseness --- dpipe/layers/fpn.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index e2a350e..2524cda 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -132,13 +132,10 @@ def interpolate_merge(merge: Callable, order: int = 0): def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0): - if isinstance(order, int): - order = order_to_mode(order, len(down.shape) - 2) + mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order + align_corners = False if mode in ['linear', 'bilinear', ' bicubic', 'trilinear'] else None if np.not_equal(left.shape, down.shape).any(): - if order in ['linear', 'bilinear', ' bicubic', 'trilinear']: - interpolate = partial(interpolate, align_corners=False) - - down = interpolate(down, size=left.shape[2:], mode=order) + down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners) return left, down From 2b4fb60b9184a78af56359a11d9a1bb7f7e3e8e4 Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Wed, 3 Dec 2025 14:58:04 +0300 Subject: [PATCH 3/9] Now interpolation is always performed --- dpipe/layers/fpn.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index 2524cda..8cb6650 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Callable, Sequence, Union from warnings import warn @@ -131,11 +130,14 @@ def interpolate_merge(merge: Callable, order: int = 0): return lambda left, down: merge(*interpolate_to_left(left, down, order)) -def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0): +def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, check_shape_equal: bool = False): + if check_shape_equal and np.equal(left.shape, down.shape).all(): + message = 'interpolate_to_left is called with =True. This may lead to branching.' + warn(message, UserWarning) + return left, down + mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order align_corners = False if mode in ['linear', 'bilinear', ' bicubic', 'trilinear'] else None - - if np.not_equal(left.shape, down.shape).any(): - down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners) + down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners) return left, down From 3aba379902bb4bf8e6c82b1d5ac518b576a97ab1 Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Wed, 3 Dec 2025 15:14:45 +0300 Subject: [PATCH 4/9] check_shape_equal is kwonly --- dpipe/layers/fpn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index 8cb6650..d3b3fb2 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -130,7 +130,7 @@ def interpolate_merge(merge: Callable, order: int = 0): return lambda left, down: merge(*interpolate_to_left(left, down, order)) -def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, check_shape_equal: bool = False): +def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False): if check_shape_equal and np.equal(left.shape, down.shape).all(): message = 'interpolate_to_left is called with =True. This may lead to branching.' warn(message, UserWarning) From 78a8bf0b3a637b7e752c721ac26d5d4471f4fb8f Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Wed, 3 Dec 2025 15:48:51 +0300 Subject: [PATCH 5/9] version, interpolation fix for torch>=2.4 --- dpipe/__version__.py | 2 +- dpipe/layers/fpn.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/dpipe/__version__.py b/dpipe/__version__.py index ed7d50e..2b8877c 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.5.3' +__version__ = '0.5.0' diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index d3b3fb2..519734f 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from typing import Callable, Sequence, Union from warnings import warn @@ -130,7 +131,7 @@ def interpolate_merge(merge: Callable, order: int = 0): return lambda left, down: merge(*interpolate_to_left(left, down, order)) -def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False): +def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False, amp_fix: bool = True): if check_shape_equal and np.equal(left.shape, down.shape).all(): message = 'interpolate_to_left is called with =True. This may lead to branching.' warn(message, UserWarning) @@ -138,6 +139,18 @@ def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order align_corners = False if mode in ['linear', 'bilinear', ' bicubic', 'trilinear'] else None - down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners) + down = interpolate_amp(down, size=left.shape[2:], mode=mode, align_corners=align_corners, amp_fix=amp_fix) return left, down + + +def interpolate_amp(*args, amp_fix=True, **kwargs): + """ + interpolate behaves strangely in torch >=2.4 - always returns fp32 regardless of AMP + Disabling autocast leads to dtype inheritance from interpolation input + """ + if amp_fix and torch.is_autocast_enabled('cuda'): + with torch.amp.autocast('cuda', enabled=False, cache_enabled=True): + return interpolate(*args, **kwargs) + + return interpolate(*args, **kwargs) From bd95e19913449aa4602f32b530f3fae1497235cc Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Wed, 3 Dec 2025 15:50:15 +0300 Subject: [PATCH 6/9] message fix --- dpipe/layers/fpn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index 519734f..2a848f2 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -133,7 +133,7 @@ def interpolate_merge(merge: Callable, order: int = 0): def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False, amp_fix: bool = True): if check_shape_equal and np.equal(left.shape, down.shape).all(): - message = 'interpolate_to_left is called with =True. This may lead to branching.' + message = 'interpolate_to_left is called with check_shape_equal=True. This may lead to branching.' warn(message, UserWarning) return left, down From 8876c8e51b953a162ca311e41c0003e1707f2bd6 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Wed, 3 Dec 2025 17:46:41 +0300 Subject: [PATCH 7/9] fixes --- dpipe/layers/fpn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index 2a848f2..b834d00 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext from typing import Callable, Sequence, Union from warnings import warn @@ -138,7 +137,7 @@ def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, return left, down mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order - align_corners = False if mode in ['linear', 'bilinear', ' bicubic', 'trilinear'] else None + align_corners = False if mode in ['linear', 'bilinear', 'bicubic', 'trilinear'] else None down = interpolate_amp(down, size=left.shape[2:], mode=mode, align_corners=align_corners, amp_fix=amp_fix) return left, down @@ -149,7 +148,7 @@ def interpolate_amp(*args, amp_fix=True, **kwargs): interpolate behaves strangely in torch >=2.4 - always returns fp32 regardless of AMP Disabling autocast leads to dtype inheritance from interpolation input """ - if amp_fix and torch.is_autocast_enabled('cuda'): + if amp_fix and torch.is_autocast_enabled(): with torch.amp.autocast('cuda', enabled=False, cache_enabled=True): return interpolate(*args, **kwargs) From a9d87c107138e7d79a1c5b5099a9cbc00822ffa5 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Thu, 4 Dec 2025 02:15:10 +0300 Subject: [PATCH 8/9] remove verbose function, remove amp_fix arg --- dpipe/layers/fpn.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index b834d00..ec3ea33 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from typing import Callable, Sequence, Union from warnings import warn @@ -130,7 +131,7 @@ def interpolate_merge(merge: Callable, order: int = 0): return lambda left, down: merge(*interpolate_to_left(left, down, order)) -def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False, amp_fix: bool = True): +def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False): if check_shape_equal and np.equal(left.shape, down.shape).all(): message = 'interpolate_to_left is called with check_shape_equal=True. This may lead to branching.' warn(message, UserWarning) @@ -138,18 +139,15 @@ def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order align_corners = False if mode in ['linear', 'bilinear', 'bicubic', 'trilinear'] else None - down = interpolate_amp(down, size=left.shape[2:], mode=mode, align_corners=align_corners, amp_fix=amp_fix) - return left, down + # interpolate behaves strangely in torch >=2.4 - always returns fp32 regardless of AMP + # disabling autocast leads to dtype inheritance from interpolation input + amp_manager = ( + torch.amp.autocast('cuda', enabled=False, cache_enabled=True) if torch.is_autocast_enabled() + else nullcontext() + ) + with amp_manager: + down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners) -def interpolate_amp(*args, amp_fix=True, **kwargs): - """ - interpolate behaves strangely in torch >=2.4 - always returns fp32 regardless of AMP - Disabling autocast leads to dtype inheritance from interpolation input - """ - if amp_fix and torch.is_autocast_enabled(): - with torch.amp.autocast('cuda', enabled=False, cache_enabled=True): - return interpolate(*args, **kwargs) - - return interpolate(*args, **kwargs) + return left, down From ab0456334201e7d695a29af640ba3477cb540d67 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Thu, 4 Dec 2025 02:15:49 +0300 Subject: [PATCH 9/9] at least ubuntu-22.04 --- .github/workflows/tests.yml | 2 +- .github/workflows/version.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2b8d15a..3ec32de 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,7 +7,7 @@ env: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: matrix: python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13' ] diff --git a/.github/workflows/version.yml b/.github/workflows/version.yml index bd8a13b..a1d12d6 100644 --- a/.github/workflows/version.yml +++ b/.github/workflows/version.yml @@ -7,7 +7,7 @@ env: jobs: check: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4