Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ env:

jobs:
test:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
strategy:
matrix:
python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13' ]
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/version.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ env:

jobs:
check:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion dpipe/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.3'
__version__ = '0.5.0'
32 changes: 17 additions & 15 deletions dpipe/layers/fpn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from functools import partial
from contextlib import nullcontext
from typing import Callable, Sequence, Union
from warnings import warn

import torch
import torch.nn as nn
from torch.nn import functional
from torch.nn.functional import interpolate
import numpy as np

from dpipe.itertools import zip_equal, lmap
Expand Down Expand Up @@ -131,21 +131,23 @@ def interpolate_merge(merge: Callable, order: int = 0):
return lambda left, down: merge(*interpolate_to_left(left, down, order))


def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, mode: str = None):
if mode is not None:
msg = 'Argument `mode` is deprecated. Use `order` instead.'
warn(msg, UserWarning)
warn(msg, DeprecationWarning)
order = mode
def interpolate_to_left(left: torch.Tensor, down: torch.Tensor, order: int = 0, *, check_shape_equal: bool = False):
if check_shape_equal and np.equal(left.shape, down.shape).all():
message = 'interpolate_to_left is called with check_shape_equal=True. This may lead to branching.'
warn(message, UserWarning)
return left, down

if isinstance(order, int):
order = order_to_mode(order, len(down.shape) - 2)
mode = order_to_mode(order, len(down.shape) - 2) if isinstance(order, int) else order
align_corners = False if mode in ['linear', 'bilinear', 'bicubic', 'trilinear'] else None

if np.not_equal(left.shape, down.shape).any():
interpolate = functional.interpolate
if order in ['linear', 'bilinear', ' bicubic', 'trilinear']:
interpolate = partial(interpolate, align_corners=False)
# interpolate behaves strangely in torch >=2.4 - always returns fp32 regardless of AMP
# disabling autocast leads to dtype inheritance from interpolation input
amp_manager = (
torch.amp.autocast('cuda', enabled=False, cache_enabled=True) if torch.is_autocast_enabled()
else nullcontext()
)

down = interpolate(down, size=left.shape[2:], mode=order)
with amp_manager:
down = interpolate(down, size=left.shape[2:], mode=mode, align_corners=align_corners)

return left, down