diff --git a/.github/actions/conda/action.yaml b/.github/actions/conda/action.yaml index b7240e6..266cbad 100755 --- a/.github/actions/conda/action.yaml +++ b/.github/actions/conda/action.yaml @@ -21,7 +21,7 @@ inputs: description: 'Override' required: false default: true - password: + password: required: true runs: using: "composite" @@ -40,7 +40,8 @@ runs: - uses: conda-incubator/setup-miniconda@v2 with: mamba-version: "*" - channels: balbasty,pytorch,conda-forge,defaults + miniforge-version: latest + channels: balbasty,pytorch,conda-forge channel-priority: true activate-environment: build - name: Install boa / anaconda @@ -73,7 +74,7 @@ runs: - name: "Publish (dry run: ${{ inputs.dry-run }})" if: inputs.dry-run == 'false' shell: bash -el {0} - env: + env: OVERRIDE: ${{ inputs.override }} PLATFORMS: ${{ inputs.platforms }} ANACONDA_API_TOKEN: ${{ inputs.password }} diff --git a/.github/actions/test/action.yaml b/.github/actions/test/action.yaml index be7af89..c2f9b78 100755 --- a/.github/actions/test/action.yaml +++ b/.github/actions/test/action.yaml @@ -22,12 +22,13 @@ runs: - uses: conda-incubator/setup-miniconda@v2 with: mamba-version: "*" + miniforge-version: latest python-version: ${{ inputs.python-version }} - channels: balbasty,pytorch,conda-forge,defaults + channels: balbasty,pytorch,conda-forge activate-environment: test-env - name: Install dependencies shell: bash -el {0} - env: + env: PYTORCH_VERSION: ${{ inputs.pytorch-version }} run: | mamba install pytorch=${PYTORCH_VERSION} pytest @@ -35,4 +36,4 @@ runs: shell: bash -el {0} run: | pip install . - pytest --pyargs cornucopia \ No newline at end of file + pytest --pyargs cornucopia diff --git a/cornucopia/intensity.py b/cornucopia/intensity.py index c3a620b..2b23dfb 100755 --- a/cornucopia/intensity.py +++ b/cornucopia/intensity.py @@ -30,6 +30,7 @@ from .random import Sampler, Uniform, RandInt, Fixed, make_range from .utils.py import ensure_list, positive_index from .utils.smart_inplace import add_, mul_, div_, pow_ +from .utils.compat import clamp class OpConstTransform(FinalTransform): @@ -175,7 +176,7 @@ def xform(self, x): vmin = vmin.to(x) if torch.is_tensor(vmax): vmax = vmax.to(x) - y = x.clamp(vmin, vmax) + y = clamp(x, vmin, vmax) return prepare_output( {'input': x, 'output': y, 'vmin': vmin, 'vmax': vmax}, self.returns diff --git a/cornucopia/qmri.py b/cornucopia/qmri.py index f3df192..0108430 100755 --- a/cornucopia/qmri.py +++ b/cornucopia/qmri.py @@ -425,34 +425,28 @@ def __init__(self, tr=25e-3, te=7e-3, alpha=20, self.mt = mt def get_parameters(self, x): - x = x.unbind(0) if self.pd is None: - pd, *x = x - pd = pd[None] + pd, x = x[:1], x[1:] else: pd = self.pd if self.t1 is None: - t1, *x = x - t1 = t1[None] + t1, x = x[:1], x[1:] else: t1 = self.t1 if self.t2 is None: - t2, *x = x - t2 = t2[None] + t2, x = x[:1], x[1:] else: t2 = self.t2 - if self.b1 is None: - b1, *x = x - b1 = b1[None] - else: - b1 = self.b1 if self.mt is None: - mt, *x = x - mt = mt[None] + mt, x = x[:1], x[1:] else: mt = self.mt + if self.b1 is None: + b1 = x + else: + b1 = self.b1 - return pd, t1, t2, b1, mt + return pd, t1, t2, mt, b1 def xform(self, x): """ @@ -468,7 +462,7 @@ def xform(self, x): - b1: Transmit efficiency (B1+). `1` means 100% efficiency. - mt: Magnetization transfer saturation (MTsat). """ - pd, t1, t2, b1, mt = self.get_parameters(x) + pd, t1, t2, mt, b1 = self.get_parameters(x) alpha = (math.pi * self.alpha / 180) * b1 if torch.is_tensor(alpha): sinalpha, cosalpha = alpha.sin(), alpha.cos() @@ -597,8 +591,11 @@ def sigmoid_(x): logmt = logit(mt) logsigma = sigma.log() + nb1 = len(b1) if (torch.is_tensor(b1) and b1.ndim > 0) else 1 + dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype() - y = x.new_zeros([5, *x.shape[1:]], dtype=dtype) + y = x.new_zeros([4 + nb1, *x.shape[1:]], dtype=dtype) + print(y.shape) # PD y[0] = exp_(GaussianMixtureTransform( mu=logpd, sigma=logsigma[:n], fwhm=fwhm[:n], @@ -614,13 +611,13 @@ def sigmoid_(x): mu=logt2, sigma=logsigma[2*n:3*n], fwhm=fwhm[2*n:3*n], background=0, dtype=dtype )(x)).squeeze(0) - # B1 - y[4:] = b1 # MT y[3] = sigmoid_(GaussianMixtureTransform( mu=logmt, sigma=logsigma[3*n:4*n], fwhm=fwhm[3*n:4*n], background=0, dtype=dtype )(x)).squeeze(0) + # B1 + y[4:].copy_(b1) # NOTE: y[4:] = b1 breaks the graph in torch 1.11 # GRE forward mode mask = (1 - x[0]) if x.dtype.is_floating_point else (x != 0) diff --git a/cornucopia/utils/compat.py b/cornucopia/utils/compat.py new file mode 100644 index 0000000..8b604f5 --- /dev/null +++ b/cornucopia/utils/compat.py @@ -0,0 +1,30 @@ +from typing import Optional +import torch + +Tensor = torch.Tensor + + +torch_version = torch.__version__ +torch_version = torch_version.split("+")[0] # remove local modifier +torch_version = torch_version.split(".")[:2] # major + minor +torch_version = tuple(map(int, torch_version)) # integer + + +if torch_version < (1, 9): + def clamp( + x: Tensor, + min: Optional[Tensor] = None, + max: Optional[Tensor] = None, + **kwargs + ): + if torch.is_tensor(min): + x = torch.maximum(x, min, **kwargs) + min = None + if torch.is_tensor(max): + x = torch.minimum(x, max, **kwargs) + max = None + if min is not None or max is not None: + x = x.clamp(min, max, **kwargs) + return x +else: + clamp = torch.clamp