Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/actions/conda/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ inputs:
description: 'Override'
required: false
default: true
password:
password:
required: true
runs:
using: "composite"
Expand All @@ -40,7 +40,8 @@ runs:
- uses: conda-incubator/setup-miniconda@v2
with:
mamba-version: "*"
channels: balbasty,pytorch,conda-forge,defaults
miniforge-version: latest
channels: balbasty,pytorch,conda-forge
channel-priority: true
activate-environment: build
- name: Install boa / anaconda
Expand Down Expand Up @@ -73,7 +74,7 @@ runs:
- name: "Publish (dry run: ${{ inputs.dry-run }})"
if: inputs.dry-run == 'false'
shell: bash -el {0}
env:
env:
OVERRIDE: ${{ inputs.override }}
PLATFORMS: ${{ inputs.platforms }}
ANACONDA_API_TOKEN: ${{ inputs.password }}
Expand Down
7 changes: 4 additions & 3 deletions .github/actions/test/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@ runs:
- uses: conda-incubator/setup-miniconda@v2
with:
mamba-version: "*"
miniforge-version: latest
python-version: ${{ inputs.python-version }}
channels: balbasty,pytorch,conda-forge,defaults
channels: balbasty,pytorch,conda-forge
activate-environment: test-env
- name: Install dependencies
shell: bash -el {0}
env:
env:
PYTORCH_VERSION: ${{ inputs.pytorch-version }}
run: |
mamba install pytorch=${PYTORCH_VERSION} pytest
- name: Test with pytest
shell: bash -el {0}
run: |
pip install .
pytest --pyargs cornucopia
pytest --pyargs cornucopia
3 changes: 2 additions & 1 deletion cornucopia/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .random import Sampler, Uniform, RandInt, Fixed, make_range
from .utils.py import ensure_list, positive_index
from .utils.smart_inplace import add_, mul_, div_, pow_
from .utils.compat import clamp


class OpConstTransform(FinalTransform):
Expand Down Expand Up @@ -175,7 +176,7 @@ def xform(self, x):
vmin = vmin.to(x)
if torch.is_tensor(vmax):
vmax = vmax.to(x)
y = x.clamp(vmin, vmax)
y = clamp(x, vmin, vmax)
return prepare_output(
{'input': x, 'output': y, 'vmin': vmin, 'vmax': vmax},
self.returns
Expand Down
35 changes: 16 additions & 19 deletions cornucopia/qmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,34 +425,28 @@ def __init__(self, tr=25e-3, te=7e-3, alpha=20,
self.mt = mt

def get_parameters(self, x):
x = x.unbind(0)
if self.pd is None:
pd, *x = x
pd = pd[None]
pd, x = x[:1], x[1:]
else:
pd = self.pd
if self.t1 is None:
t1, *x = x
t1 = t1[None]
t1, x = x[:1], x[1:]
else:
t1 = self.t1
if self.t2 is None:
t2, *x = x
t2 = t2[None]
t2, x = x[:1], x[1:]
else:
t2 = self.t2
if self.b1 is None:
b1, *x = x
b1 = b1[None]
else:
b1 = self.b1
if self.mt is None:
mt, *x = x
mt = mt[None]
mt, x = x[:1], x[1:]
else:
mt = self.mt
if self.b1 is None:
b1 = x
else:
b1 = self.b1

return pd, t1, t2, b1, mt
return pd, t1, t2, mt, b1

def xform(self, x):
"""
Expand All @@ -468,7 +462,7 @@ def xform(self, x):
- b1: Transmit efficiency (B1+). `1` means 100% efficiency.
- mt: Magnetization transfer saturation (MTsat).
"""
pd, t1, t2, b1, mt = self.get_parameters(x)
pd, t1, t2, mt, b1 = self.get_parameters(x)
alpha = (math.pi * self.alpha / 180) * b1
if torch.is_tensor(alpha):
sinalpha, cosalpha = alpha.sin(), alpha.cos()
Expand Down Expand Up @@ -597,8 +591,11 @@ def sigmoid_(x):
logmt = logit(mt)
logsigma = sigma.log()

nb1 = len(b1) if (torch.is_tensor(b1) and b1.ndim > 0) else 1

dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
y = x.new_zeros([5, *x.shape[1:]], dtype=dtype)
y = x.new_zeros([4 + nb1, *x.shape[1:]], dtype=dtype)
print(y.shape)
# PD
y[0] = exp_(GaussianMixtureTransform(
mu=logpd, sigma=logsigma[:n], fwhm=fwhm[:n],
Expand All @@ -614,13 +611,13 @@ def sigmoid_(x):
mu=logt2, sigma=logsigma[2*n:3*n], fwhm=fwhm[2*n:3*n],
background=0, dtype=dtype
)(x)).squeeze(0)
# B1
y[4:] = b1
# MT
y[3] = sigmoid_(GaussianMixtureTransform(
mu=logmt, sigma=logsigma[3*n:4*n], fwhm=fwhm[3*n:4*n],
background=0, dtype=dtype
)(x)).squeeze(0)
# B1
y[4:].copy_(b1) # NOTE: y[4:] = b1 breaks the graph in torch 1.11

# GRE forward mode
mask = (1 - x[0]) if x.dtype.is_floating_point else (x != 0)
Expand Down
30 changes: 30 additions & 0 deletions cornucopia/utils/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional
import torch

Tensor = torch.Tensor


torch_version = torch.__version__
torch_version = torch_version.split("+")[0] # remove local modifier
torch_version = torch_version.split(".")[:2] # major + minor
torch_version = tuple(map(int, torch_version)) # integer


if torch_version < (1, 9):
def clamp(
x: Tensor,
min: Optional[Tensor] = None,
max: Optional[Tensor] = None,
**kwargs
):
if torch.is_tensor(min):
x = torch.maximum(x, min, **kwargs)
min = None
if torch.is_tensor(max):
x = torch.minimum(x, max, **kwargs)
max = None
if min is not None or max is not None:
x = x.clamp(min, max, **kwargs)
return x
else:
clamp = torch.clamp
Loading