From 0c45842a571780d921b9f439cf93f812ae6181c3 Mon Sep 17 00:00:00 2001 From: balbasty Date: Wed, 16 Apr 2025 14:22:47 +0100 Subject: [PATCH 1/6] FIX(qmri): do not lose b1 channels --- cornucopia/qmri.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/cornucopia/qmri.py b/cornucopia/qmri.py index f3df192..cbdd036 100755 --- a/cornucopia/qmri.py +++ b/cornucopia/qmri.py @@ -425,34 +425,28 @@ def __init__(self, tr=25e-3, te=7e-3, alpha=20, self.mt = mt def get_parameters(self, x): - x = x.unbind(0) if self.pd is None: - pd, *x = x - pd = pd[None] + pd, x = x[:1], x[1:] else: pd = self.pd if self.t1 is None: - t1, *x = x - t1 = t1[None] + t1, x = x[:1], x[1:] else: t1 = self.t1 if self.t2 is None: - t2, *x = x - t2 = t2[None] + t2, x = x[:1], x[1:] else: t2 = self.t2 - if self.b1 is None: - b1, *x = x - b1 = b1[None] - else: - b1 = self.b1 if self.mt is None: - mt, *x = x - mt = mt[None] + mt, x = x[:1], x[1:] else: mt = self.mt + if self.b1 is None: + b1 = x + else: + b1 = self.b1 - return pd, t1, t2, b1, mt + return pd, t1, t2, mt, b1 def xform(self, x): """ @@ -468,7 +462,7 @@ def xform(self, x): - b1: Transmit efficiency (B1+). `1` means 100% efficiency. - mt: Magnetization transfer saturation (MTsat). """ - pd, t1, t2, b1, mt = self.get_parameters(x) + pd, t1, t2, mt, b1 = self.get_parameters(x) alpha = (math.pi * self.alpha / 180) * b1 if torch.is_tensor(alpha): sinalpha, cosalpha = alpha.sin(), alpha.cos() @@ -614,13 +608,13 @@ def sigmoid_(x): mu=logt2, sigma=logsigma[2*n:3*n], fwhm=fwhm[2*n:3*n], background=0, dtype=dtype )(x)).squeeze(0) - # B1 - y[4:] = b1 # MT y[3] = sigmoid_(GaussianMixtureTransform( mu=logmt, sigma=logsigma[3*n:4*n], fwhm=fwhm[3*n:4*n], background=0, dtype=dtype )(x)).squeeze(0) + # B1 + y[4:] = b1 # GRE forward mode mask = (1 - x[0]) if x.dtype.is_floating_point else (x != 0) From 9410cbb7df53757e775cb2459113103003fb71fc Mon Sep 17 00:00:00 2001 From: balbasty Date: Wed, 16 Apr 2025 16:03:02 +0100 Subject: [PATCH 2/6] FIX(qmri): b1 backprop --- cornucopia/qmri.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cornucopia/qmri.py b/cornucopia/qmri.py index cbdd036..0108430 100755 --- a/cornucopia/qmri.py +++ b/cornucopia/qmri.py @@ -591,8 +591,11 @@ def sigmoid_(x): logmt = logit(mt) logsigma = sigma.log() + nb1 = len(b1) if (torch.is_tensor(b1) and b1.ndim > 0) else 1 + dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype() - y = x.new_zeros([5, *x.shape[1:]], dtype=dtype) + y = x.new_zeros([4 + nb1, *x.shape[1:]], dtype=dtype) + print(y.shape) # PD y[0] = exp_(GaussianMixtureTransform( mu=logpd, sigma=logsigma[:n], fwhm=fwhm[:n], @@ -614,7 +617,7 @@ def sigmoid_(x): background=0, dtype=dtype )(x)).squeeze(0) # B1 - y[4:] = b1 + y[4:].copy_(b1) # NOTE: y[4:] = b1 breaks the graph in torch 1.11 # GRE forward mode mask = (1 - x[0]) if x.dtype.is_floating_point else (x != 0) From 95b7ac8bb1351a0722df7f5abdbe12c94a76ebdb Mon Sep 17 00:00:00 2001 From: balbasty Date: Wed, 16 Apr 2025 16:18:01 +0100 Subject: [PATCH 3/6] FIX(intensity): clamp does not take tensors before 1.8 --- cornucopia/intensity.py | 3 ++- cornucopia/utils/compat.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 cornucopia/utils/compat.py diff --git a/cornucopia/intensity.py b/cornucopia/intensity.py index c3a620b..2b23dfb 100755 --- a/cornucopia/intensity.py +++ b/cornucopia/intensity.py @@ -30,6 +30,7 @@ from .random import Sampler, Uniform, RandInt, Fixed, make_range from .utils.py import ensure_list, positive_index from .utils.smart_inplace import add_, mul_, div_, pow_ +from .utils.compat import clamp class OpConstTransform(FinalTransform): @@ -175,7 +176,7 @@ def xform(self, x): vmin = vmin.to(x) if torch.is_tensor(vmax): vmax = vmax.to(x) - y = x.clamp(vmin, vmax) + y = clamp(x, vmin, vmax) return prepare_output( {'input': x, 'output': y, 'vmin': vmin, 'vmax': vmax}, self.returns diff --git a/cornucopia/utils/compat.py b/cornucopia/utils/compat.py new file mode 100644 index 0000000..5f46b68 --- /dev/null +++ b/cornucopia/utils/compat.py @@ -0,0 +1,30 @@ +from typing import Optional +import torch + +Tensor = torch.Tensor + + +torch_version = torch.__version__ +torch_version = torch_version.split("+")[0] # remove local modifier +torch_version = torch_version.split(".")[:2] # major + minor +torch_version = tuple(map(int, torch_version)) # integer + + +if torch_version < (1, 8): + def clamp( + x: Tensor, + min: Optional[Tensor] = None, + max: Optional[Tensor] = None, + **kwargs + ): + if torch.is_tensor(min): + x = torch.maximum(x, min, **kwargs) + min = None + if torch.is_tensor(max): + x = torch.minimum(x, max, **kwargs) + max = None + if min is not None or max is not None: + x = x.clamp(min, max, **kwargs) + return x +else: + clamp = torch.clamp From f7fa1a173d5237791cfb20bb3fab099f6ddd99d8 Mon Sep 17 00:00:00 2001 From: balbasty Date: Wed, 16 Apr 2025 16:43:09 +0100 Subject: [PATCH 4/6] FIX(intensity): clamp does not take tensors before 1.9 --- cornucopia/utils/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cornucopia/utils/compat.py b/cornucopia/utils/compat.py index 5f46b68..8b604f5 100644 --- a/cornucopia/utils/compat.py +++ b/cornucopia/utils/compat.py @@ -10,7 +10,7 @@ torch_version = tuple(map(int, torch_version)) # integer -if torch_version < (1, 8): +if torch_version < (1, 9): def clamp( x: Tensor, min: Optional[Tensor] = None, From 5fce83b724ebabc25db3ca8ffa46134c1b8e35c5 Mon Sep 17 00:00:00 2001 From: balbasty Date: Wed, 16 Apr 2025 17:04:45 +0100 Subject: [PATCH 5/6] MNT(actions): remove defaults conda channel --- .github/actions/test/action.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/actions/test/action.yaml b/.github/actions/test/action.yaml index be7af89..f4eda28 100755 --- a/.github/actions/test/action.yaml +++ b/.github/actions/test/action.yaml @@ -23,11 +23,11 @@ runs: with: mamba-version: "*" python-version: ${{ inputs.python-version }} - channels: balbasty,pytorch,conda-forge,defaults + channels: balbasty,pytorch,conda-forge activate-environment: test-env - name: Install dependencies shell: bash -el {0} - env: + env: PYTORCH_VERSION: ${{ inputs.pytorch-version }} run: | mamba install pytorch=${PYTORCH_VERSION} pytest @@ -35,4 +35,4 @@ runs: shell: bash -el {0} run: | pip install . - pytest --pyargs cornucopia \ No newline at end of file + pytest --pyargs cornucopia From 039ddb221d40d601abbfa4825d93928013b77658 Mon Sep 17 00:00:00 2001 From: balbasty Date: Wed, 16 Apr 2025 17:12:29 +0100 Subject: [PATCH 6/6] MNT(actions): use minforge to avoid defaults --- .github/actions/conda/action.yaml | 7 ++++--- .github/actions/test/action.yaml | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/actions/conda/action.yaml b/.github/actions/conda/action.yaml index b7240e6..266cbad 100755 --- a/.github/actions/conda/action.yaml +++ b/.github/actions/conda/action.yaml @@ -21,7 +21,7 @@ inputs: description: 'Override' required: false default: true - password: + password: required: true runs: using: "composite" @@ -40,7 +40,8 @@ runs: - uses: conda-incubator/setup-miniconda@v2 with: mamba-version: "*" - channels: balbasty,pytorch,conda-forge,defaults + miniforge-version: latest + channels: balbasty,pytorch,conda-forge channel-priority: true activate-environment: build - name: Install boa / anaconda @@ -73,7 +74,7 @@ runs: - name: "Publish (dry run: ${{ inputs.dry-run }})" if: inputs.dry-run == 'false' shell: bash -el {0} - env: + env: OVERRIDE: ${{ inputs.override }} PLATFORMS: ${{ inputs.platforms }} ANACONDA_API_TOKEN: ${{ inputs.password }} diff --git a/.github/actions/test/action.yaml b/.github/actions/test/action.yaml index f4eda28..c2f9b78 100755 --- a/.github/actions/test/action.yaml +++ b/.github/actions/test/action.yaml @@ -22,6 +22,7 @@ runs: - uses: conda-incubator/setup-miniconda@v2 with: mamba-version: "*" + miniforge-version: latest python-version: ${{ inputs.python-version }} channels: balbasty,pytorch,conda-forge activate-environment: test-env