diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2b8d15a..3ec32de 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,7 +7,7 @@ env: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: matrix: python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13' ] diff --git a/.github/workflows/version.yml b/.github/workflows/version.yml index bd8a13b..a1d12d6 100644 --- a/.github/workflows/version.yml +++ b/.github/workflows/version.yml @@ -7,7 +7,7 @@ env: jobs: check: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 diff --git a/dpipe/__version__.py b/dpipe/__version__.py index 908c0bb..2b8877c 100644 --- a/dpipe/__version__.py +++ b/dpipe/__version__.py @@ -1 +1 @@ -__version__ = '0.4.3' +__version__ = '0.5.0' diff --git a/dpipe/layers/fpn.py b/dpipe/layers/fpn.py index 767c07e..ec3ea33 100644 --- a/dpipe/layers/fpn.py +++ b/dpipe/layers/fpn.py @@ -1,10 +1,10 @@ -from functools import partial +from contextlib import nullcontext from typing import Callable, Sequence, Union from warnings import warn import torch import torch.nn as nn -from torch.nn import functional +from torch.nn.functional import interpolate import numpy as np from dpipe.itertools import zip_equal, lmap @@ -131,21 +131,23 @@ def interpolate_merge(merge: Callable, order: int = 0): return lambda left, down: merge(*interpolate_to_left(left, down, order)) -def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, mode: str = None): - if mode is not None: - msg = 'Argument `mode` is deprecated. Use `order` instead.' - warn(msg, UserWarning) - warn(msg, DeprecationWarning) - order = mode +def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False): + if check_shape_equal and np.equal(left.shape, down.shape).all(): + message = 'interpolate_to_left is called with check_shape_equal=True. This may lead to branching.'
+ warn(message, UserWarning) + return left, down - if isinstance(order, int): - order = order_to_mode(order, len(down.shape) - 2) + mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order + align_corners = False if mode in ['linear', 'bilinear', 'bicubic', 'trilinear'] else None - if np.not_equal(left.shape, down.shape).any(): - interpolate = functional.interpolate - if order in ['linear', 'bilinear', ' bicubic', 'trilinear']: - interpolate = partial(interpolate, align_corners=False) + # interpolate behaves strangely in torch >=2.4 - always returns fp32 regardless of AMP + # disabling autocast leads to dtype inheritance from interpolation input + amp_manager = ( + torch.amp.autocast('cuda', enabled=False, cache_enabled=True) if torch.is_autocast_enabled() + else nullcontext() + ) - down = interpolate(down, size=left.shape[2:], mode=order) + with amp_manager: + down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners) return left, down