diff --git a/.codespellrc b/.codespellrc new file mode 100644 index 0000000..7e4ba7d --- /dev/null +++ b/.codespellrc @@ -0,0 +1,3 @@ +[codespell] +skip = [setup.cfg] +ignore-words-list = FWE, fwe \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..cae26d6 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,25 @@ +name: Tests + +on: [push, pull_request] + +jobs: + build-linux: + runs-on: ubuntu-latest + strategy: + max-parallel: 5 + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@v3 + with: + python-version: '3.11' + - name: Install + run: | + python -m pip install --upgrade pip + python -m pip install .[dev] + - name: Install and run pre-commit hooks + uses: pre-commit/action@v3.0.1 + - name: Test with pytest + run: | + pytest \ No newline at end of file diff --git a/.gitignore b/.gitignore index 94f5823..12ebc39 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ dist/** -fwe.egg-info \ No newline at end of file +fwe.egg-info +.DS_store +__pycache__ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f053dbd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +default_language_version: + python: python3 + +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.3 + hooks: + # Run the linter + - id: ruff + args: [ --fix, --config, pyproject.toml ] + # Run the formatter + - id: ruff-format + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + additional_dependencies: + - tomli \ No newline at end of file diff --git a/README.md b/README.md index 4e471d2..e75e468 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ Implements free water elimination (FWE) models for preprocessing diffusion MRI data. +This code was used in the paper: ["Free water elimination tractometry for aging brains"](https://direct.mit.edu/imag/article/doi/10.1162/IMAG.a.991/133658/Free-water-elimination-tractometry-for-aging) + --- ## To install @@ -22,7 +24,7 @@ automatically as part of the installation process specified above. ## Available Models -* Free water DTI as implmented in `dipy` ([Hoy et al., 2014](https://doi.org/10.1016/j.neuroimage.2014.09.053)): Use `fwe_model="dipy_fwdti"` +* Free water DTI as implemented in `dipy` ([Hoy et al., 2014](https://doi.org/10.1016/j.neuroimage.2014.09.053)): Use `fwe_model="dipy_fwdti"` * Beltrami regularized gradient descent free water DTI ([Golub et al., 2020](https://doi.org/10.1002/mrm.28599)): Use `fwe_model="golub_beltrami"` --- diff --git a/fwe/fwe.py b/fwe/fwe.py deleted file mode 100644 index d7d9fa0..0000000 --- a/fwe/fwe.py +++ /dev/null @@ -1,134 +0,0 @@ -import re -import logging -import argparse -import numpy as np -import nibabel as nib -from .beltrami import BeltramiModel -import dipy.reconst.fwdti as fwdti -from dipy.core.gradients import gradient_table - -# set up logging configuration -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("FWE") - - -def free_water_elimination( - dwi_fname, bval_fname, bvec_fname, mask_fname, fwe_model, output_fname, - Diso = 3.0e-3, save_params = False, - golub_kwargs = { - "init_method": "hybrid", "Stissue": None, "Swater": None, - "n_iterations": 100, "learning_rate": 0.0005 - } - ): - - # load diffusion, mask image and bvec/bval gradient table - dwi = nib.load(dwi_fname) - mask = nib.load(mask_fname).get_fdata() - gtab = gradient_table(bval_fname, bvec_fname) - - # perform free water elimination - match fwe_model.lower(): - case "dipy_fwdti": - logger.info("Performing free water elimination using DIPY's FWDTI model") - fwe_image, model_params = dipy_fwdti(dwi, gtab, mask, Diso, save_params) - case "golub_beltrami": - logger.info("Performing free water elimination using Golub's Beltrami model") - fwe_image, model_params = golub_beltrami( - dwi, gtab, mask, Diso, save_params, **golub_kwargs) - case _: - logger.error("Unrecognized free water elimination model") - - # save free water model parameters - if save_params: - output_params = re.sub("_(\\w+).nii.gz$", "_params.nii.gz", output_fname) - nib.save(model_params, output_params) - logger.info(f"Saving free water model parameters to: {output_fname}") - - # save free water eliminated image - nib.save(fwe_image, output_fname) - logger.info(f"Saving free water eliminated image to: {output_fname}") - - -def golub_beltrami(dwi, gtab, mask = None, Diso = 3.0e-3, save_params = False, - init_method = "hybrid", Stissue = None, Swater = None, - n_iterations = 100, learning_rate = 0.0005): - - # adjust b-values for Beltrami model - gtab.bvals = gtab.bvals * 1e-3 - - # fit Beltrami regularized gradient descent free-water diffusion tensor model - model = BeltramiModel(gtab, init_method = init_method, Diso = Diso * 1e3, - Stissue = Stissue, Swater = Swater, - iterations = n_iterations, - learning_rate = learning_rate) - model_fit = model.fit(dwi.get_fdata(), mask = mask) - - # save free water dti model parameters - model_params = nib.Nifti1Image(model_fit.model_params, affine = dwi.affine) \ - if save_params else None - - # extract free-water dti model parameters - fwf = model_fit.fw # free water fraction - - # undo b-value adjustment from Beltrami model - gtab.bvals = gtab.bvals * 1e3 - - # return free-water eliminated signal (and model parameters) - return (remove_free_water(dwi, gtab, fwf, Diso), model_params) - - -def dipy_fwdti(dwi, gtab, mask = None, Diso = 3.0e-3, - save_params = False): - # fit free-water diffusion tensor imaging model - model = fwdti.FreeWaterTensorModel(gtab) - model_fit = model.fit(dwi.get_fdata(), mask = mask) - - # save free water dti model parameters - model_params = nib.Nifti1Image(model_fit.model_params, affine = dwi.affine) \ - if save_params else None - - # extract free-water dti model parameters - fwf = model_fit.model_params[..., -1] - - # return free-water eliminated signal (and model parameters) - return (remove_free_water(dwi, gtab, fwf, Diso), model_params) - - -def remove_free_water(dwi, gtab, fwf, Diso): - # extract b0 signal from dwi image - S0 = dwi.get_fdata()[..., gtab.bvals == 0].mean(axis = -1) - - # compute free-water signal - fw_decay = np.exp(-gtab.bvals * Diso) # free-water exponential decay - fw_signal = (S0 * fwf).reshape(-1, 1) * fw_decay - fw_signal = fw_signal.reshape(fwf.shape + (dwi.shape[-1], )) - - # compute free-water eliminated signal - fwe_signal = dwi.get_fdata() - fw_signal - - # return free-water eliminated signal - return nib.Nifti1Image(fwe_signal, affine = dwi.affine) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("dwi_fname", type = str) - parser.add_argument("bval_fname", type = str) - parser.add_argument("bvec_fname", type = str) - parser.add_argument("mask_fname", type = str) - parser.add_argument("fwe_model", type = str) - parser.add_argument("output_fname", type = str) - parser.add_argument("--Diso", type = float, default = 3.0e-3) - parser.add_argument("--save_params", type = bool, default = False) - args = parser.parse_args() - - free_water_elimination( - dwi_fname = args.dwi_fname, - bval_fname = args.bval_fname, - bvec_fname = args.bvec_fname, - mask_fname = args.mask_fname, - fwe_model = args.fwe_model, - output_fname = args.output_fname, - Diso = args.Diso, - save_params = args.save_params - ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a1ee589..69c273c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,8 @@ build-backend = "setuptools.build_meta" name = "fwe" version = "0.0.1" authors = [ - { name = "Kelly Chang", email = "kchang4@uw.edu" } + { name = "Kelly Chang", email = "kchang4@uw.edu" }, + { name = "Ariel Rokem", email = "arokem@gmail.com"} ] description = "Free water elimination in diffusion MRI" readme = "README.md" @@ -21,6 +22,36 @@ dependencies = [ "dipy" ] +[project.optional-dependencies] + dev = ["pre-commit", + "pytest"] + +[tool.setuptools_scm] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["fwe*"] + +[tool.ruff] +target-version = "py311" +exclude = ["src/fwe/__init__.py"] + +[tool.ruff.lint] +select = [ + "F", + "E", + "C", + "W", + "B", + "I", +] +ignore = [ + "B905", + "C901", + "E203", + "F821" +] + [project.urls] Homepage = "https://github.com/nrdg/fwe" diff --git a/fwe/__init__.py b/src/fwe/__init__.py similarity index 100% rename from fwe/__init__.py rename to src/fwe/__init__.py diff --git a/fwe/beltrami.py b/src/fwe/beltrami.py similarity index 57% rename from fwe/beltrami.py rename to src/fwe/beltrami.py index f11ca55..7fd99d9 100644 --- a/fwe/beltrami.py +++ b/src/fwe/beltrami.py @@ -24,7 +24,7 @@ # * Neither the name of Marc Golub nor the names of any # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. - +# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -39,26 +39,58 @@ import numpy as np -from dipy.reconst.vec_val_sum import vec_val_vect -from dipy.reconst.dki import _positive_evals from dipy.core.gradients import gradient_table -from dipy.reconst.base import ReconstModel -from dipy.reconst.dti import (TensorFit, design_matrix, lower_triangular, - eig_from_lo_tri, MIN_POSITIVE_SIGNAL, - ols_fit_tensor, fractional_anisotropy, - mean_diffusivity) from dipy.core.onetime import auto_attr +from dipy.reconst.base import ReconstModel +from dipy.reconst.dki import _positive_evals +from dipy.reconst.dti import ( + MIN_POSITIVE_SIGNAL, + TensorFit, + design_matrix, + eig_from_lo_tri, + fractional_anisotropy, + lower_triangular, + mean_diffusivity, + ols_fit_tensor, +) +from dipy.reconst.vec_val_sum import vec_val_vect + +MAX_DIFFUSIVITY = 0.005 +MIN_DIFFUSIVITY = 0.00001 +MIN_FRAC = 0.0001 +WATER_DIFFUSIVITY = 0.003 # s/mm^2 + +__all__ = ["BeltramiModel", "BeltramiFit"] + + +def model_prediction(model_params, gtab, S0, Diso=WATER_DIFFUSIVITY): + """ + Predict dMRI signal based on fwdti model parameters. + + Parameters + ---------- + model_params: array + Shape (i, j, k, 13). -MAX_DIFFFUSIVITY = 5 -MIN_DIFFUSIVITY = 0.01 + gtab : GradientTable class instance + The gradient table for which predictions will be computed. + S0 : array + Shape (i, j, k), Non-diffusion-weighted signal -def model_prediction(model_params, gtab, S0, Diso): + Diso : float + The diffusivity of an isotropic compartment. Default: diffusivity of + water at body temperature, which is 3 s/mm^2. + Returns + ------- + prediction: array + Shape (i, j, k, len(gtab.bvals)). + """ evals = model_params[..., :3] evecs = model_params[..., 3:12].reshape(model_params.shape[:-1] + (3, 3)) fraction = model_params[..., 12][..., None] qform = vec_val_vect(evecs, evals) - lower_tissue = lower_triangular(qform, S0) + lower_tissue = lower_triangular(qform, b0=S0) lower_water = np.copy(lower_tissue) lower_water[..., 0] = Diso lower_water[..., 1] = 0 @@ -67,17 +99,54 @@ def model_prediction(model_params, gtab, S0, Diso): lower_water[..., 4] = 0 lower_water[..., 5] = Diso H = design_matrix(gtab) - Stissue = fraction * np.exp(np.einsum('...j,ij->...i', lower_tissue, H)) - Swater = (1 - fraction) * np.exp(np.einsum('...j,ij->...i', lower_water, H)) + Stissue = fraction * np.exp(np.einsum("...j,ij->...i", lower_tissue, H)) + Swater = (1 - fraction) * np.exp(np.einsum("...j,ij->...i", lower_water, H)) mask = _positive_evals(evals[..., 0], evals[..., 1], evals[..., 2]) return (Stissue + Swater) * mask[..., None] -class Manifold(): - - def __init__(self, design_matrix, model_params, attenuations, fmin, fmax, - Diso=3, beta=1, mask=None, zooms=None): - +class _BeltramiOptimizer: + """ + The Manifold class + """ + + def __init__( + self, + design_matrix, + model_params, + attenuations, + fmin, + fmax, + Diso=WATER_DIFFUSIVITY, + beta=1, + mask=None, + zooms=None, + ): + """ + Manifold initialization with given parameters and design matrix. + + Parameters + ---------- + design_matrix : array + The design matrix for the model. Shape (n_directions, 6) + model_params : array + Parameters of a FWDTI model. Shape (x, y, z, 13). + attenuations : array + The attenuation of the signal by diffusion-weighting in each + measurement. Shape (x, y, z, n_directions) + fmin: array + Tissue fraction minimum. Shape (x, y, z) + fmax : array + Tissue fraction maximum. Shape (x, y, z) + Diso : float + The diffusivity of an isotropic compartment. Default: diffusivity + of water at body temperature, which is 0.003 s/mm^2. + beta : float + A metric ratio parameter. + mask : array + Binary mask into the model. Used to avoid unstable derivatives at + the mask boundary Shape (x, y, z). + """ # Manifold shape self.shape = model_params.shape[:-1] @@ -105,16 +174,16 @@ def __init__(self, design_matrix, model_params, attenuations, fmin, fmax, self.mask = np.ones(model_params.shape[:-1]).astype(bool) else: self.mask = mask.astype(bool) - + # Masks for derivatives, # to avoid unstable derivatives at the mask boundary nx, ny, nz = self.mask.shape - shift_fx = np.append(np.arange(1, nx), nx-1) - shift_fy = np.append(np.arange(1, ny), ny-1) - shift_fz = np.append(np.arange(1, nz), nz-1) - shift_bx = np.append(0, np.arange(nx-1)) - shift_by = np.append(0, np.arange(ny-1)) - shift_bz = np.append(0, np.arange(nz-1)) + shift_fx = np.append(np.arange(1, nx), nx - 1) + shift_fy = np.append(np.arange(1, ny), ny - 1) + shift_fz = np.append(np.arange(1, nz), nz - 1) + shift_bx = np.append(0, np.arange(nx - 1)) + shift_by = np.append(0, np.arange(ny - 1)) + shift_bz = np.append(0, np.arange(nz - 1)) self.mask_forward_x = self.mask[shift_fx, ...] * self.mask self.mask_forward_y = self.mask[:, shift_fy, :] * self.mask self.mask_forward_z = self.mask[..., shift_fz] * self.mask @@ -124,7 +193,7 @@ def __init__(self, design_matrix, model_params, attenuations, fmin, fmax, # Voxel resolution if zooms is None: - self.zooms = np.array([1., 1., 1.]) + self.zooms = np.array([1.0, 1.0, 1.0]) else: self.zooms = zooms / np.min(zooms) @@ -143,19 +212,18 @@ def __init__(self, design_matrix, model_params, attenuations, fmin, fmax, self.flat_fmax = fmax[self.mask][..., None] # Increment matrices - self.flat_beltrami = np.zeros(self.flat_fraction.shape[:-1] + (6, )) - self.flat_fidelity = np.zeros(self.flat_fraction.shape[:-1] + (6, )) + self.flat_beltrami = np.zeros(self.flat_fraction.shape[:-1] + (6,)) + self.flat_fidelity = np.zeros(self.flat_fraction.shape[:-1] + (6,)) self.flat_df = np.zeros(self.flat_fraction.shape) # cost self.flat_cost = np.zeros(self.flat_fraction.shape) self.flat_g = np.zeros(self.flat_fraction.shape) - @staticmethod def forward_difference(array, d, axis): n = array.shape[axis] - shift = np.append(np.arange(1, n), n-1) + shift = np.append(np.arange(1, n), n - 1) if axis == 0: return (array[shift, ...] - array) / d elif axis == 1: @@ -163,18 +231,16 @@ def forward_difference(array, d, axis): elif axis == 2: return (array[..., shift, :] - array) / d - @staticmethod def backward_difference(array, d, axis): n = array.shape[axis] - shift = np.append(np.arange(1, n), n-1) + shift = np.append(np.arange(1, n), n - 1) if axis == 0: return (array - array[shift, ...]) / d elif axis == 1: return (array - array[:, shift, ...]) / d elif axis == 2: return (array - array[..., shift, :]) / d - @property def flat_lowtri(self): @@ -182,31 +248,29 @@ def flat_lowtri(self): out[..., [1, 3, 4]] *= 1 / np.sqrt(2) return out - def compute_beltrami(self): - # Computing derivatives dx, dy, dz = self.zooms - X_dx = (Manifold.forward_difference(self.X, dx, 0) - * self.mask_forward_x[..., None]) - X_dy = (Manifold.forward_difference(self.X, dy, 1) - * self.mask_forward_y[..., None]) - X_dz = (Manifold.forward_difference(self.X, dz, 2) - * self.mask_forward_z[..., None]) - - # Computing the Manifold metric (Euclidean) - g11 = np.sum(X_dx * X_dx, axis=-1) * self.beta + 1. + X_dx = self.forward_difference(self.X, dx, 0) * self.mask_forward_x[..., None] + X_dy = self.forward_difference(self.X, dy, 1) * self.mask_forward_y[..., None] + X_dz = self.forward_difference(self.X, dz, 2) * self.mask_forward_z[..., None] + + # Computing the Manifold metric (Euclidean) + g11 = np.sum(X_dx * X_dx, axis=-1) * self.beta + 1.0 g12 = np.sum(X_dx * X_dy, axis=-1) * self.beta - g22 = np.sum(X_dy * X_dy, axis=-1) * self.beta + 1. + g22 = np.sum(X_dy * X_dy, axis=-1) * self.beta + 1.0 g13 = np.sum(X_dx * X_dz, axis=-1) * self.beta g23 = np.sum(X_dy * X_dz, axis=-1) * self.beta - g33 = np.sum(X_dz * X_dz, axis=-1) * self.beta + 1. + g33 = np.sum(X_dz * X_dz, axis=-1) * self.beta + 1.0 # Computing inverse metric - gdet = (g12 * g13 * g23 * 2 + g11 * g22 * g33 - - g22 * g13**2 - - g33 * g12**2 - - g11 * g23**2) + gdet = ( + g12 * g13 * g23 * 2 + + g11 * g22 * g33 + - g22 * g13**2 + - g33 * g12**2 + - g11 * g23**2 + ) # # unstable values unstable_g = np.logical_or(gdet <= 0, gdet >= 1000) * self.mask gdet[unstable_g] = 1 @@ -236,15 +300,18 @@ def compute_beltrami(self): Ax = g11 * X_dx + g12 * X_dy + g13 * X_dz Ay = g12 * X_dx + g22 * X_dy + g23 * X_dz Az = g13 * X_dx + g23 * X_dy + g33 * X_dz - - beltrami = (Manifold.backward_difference(g * Ax, dx, 0) - * self.mask_backward_x[..., None]) - beltrami += (Manifold.backward_difference(g * Ay, dy, 1) - * self.mask_backward_y[..., None]) - beltrami += (Manifold.backward_difference(g * Az, dz, 2) - * self.mask_backward_z[..., None]) - beltrami *= 1 / g - + + beltrami = ( + self.backward_difference(g * Ax, dx, 0) * self.mask_backward_x[..., None] + ) + beltrami += ( + self.backward_difference(g * Ay, dy, 1) * self.mask_backward_y[..., None] + ) + beltrami += ( + self.backward_difference(g * Az, dz, 2) * self.mask_backward_z[..., None] + ) + beltrami *= 1 / g + self.flat_beltrami[...] = beltrami[self.mask] # Save the unstable voxels masks @@ -253,35 +320,33 @@ def compute_beltrami(self): # Save srt(det(g)) self.flat_g[..., 0] = g[self.mask, 0] - def compute_fidelity(self): - Awater = np.exp(np.einsum('...j,ij->...i', self.flat_Diso, - self.design_matrix)) - Atissue = np.exp(np.einsum('...j,ij->...i', self.flat_lowtri, - self.design_matrix)) + Awater = np.exp(np.einsum("...j,ij->...i", self.flat_Diso, self.design_matrix)) + Atissue = np.exp( + np.einsum("...j,ij->...i", self.flat_lowtri, self.design_matrix) + ) Cwater = (1 - self.flat_fraction) * Awater Ctissue = self.flat_fraction * Atissue Amodel = Ctissue + Cwater Adiff = Amodel - self.flat_attenuations - np.einsum('...i,ij->...j', -1 * Adiff * Ctissue, - self.dH, out=self.flat_fidelity) - np.sum(-1 * (Atissue - Awater) * Adiff, axis=-1, - out=self.flat_df[..., 0]) - + np.einsum( + "...i,ij->...j", -1 * Adiff * Ctissue, self.dH, out=self.flat_fidelity + ) + np.sum(-1 * (Atissue - Awater) * Adiff, axis=-1, out=self.flat_df[..., 0]) def compute_cost(self, alpha): - Awater = np.exp(np.einsum('...j,ij->...i', self.flat_Diso, - self.design_matrix)) - Atissue = np.exp(np.einsum('...j,ij->...i', self.flat_lowtri, - self.design_matrix)) + Awater = np.exp(np.einsum("...j,ij->...i", self.flat_Diso, self.design_matrix)) + Atissue = np.exp( + np.einsum("...j,ij->...i", self.flat_lowtri, self.design_matrix) + ) Cwater = (1 - self.flat_fraction) * Awater Ctissue = self.flat_fraction * Atissue Amodel = Ctissue + Cwater k = Amodel.shape[-1] - self.flat_cost[..., 0] = np.sum((Amodel - self.flat_attenuations)**2, axis=-1) / k - self.flat_cost *= 1/2 - # self.flat_cost += self.flat_g - + self.flat_cost[..., 0] = ( + np.sum((Amodel - self.flat_attenuations) ** 2, axis=-1) / k + ) + self.flat_cost *= 1 / 2 @property def update_mask(self): @@ -300,58 +365,63 @@ def update(self, dt, alpha): self.flat_df *= self.update_mask # Update parameters - self.X[self.mask, :] += dt * (self.flat_fidelity + - self.flat_beltrami * alpha) + self.X[self.mask, :] += dt * (self.flat_fidelity + self.flat_beltrami * alpha) self.flat_fraction += dt * self.flat_df # constrain the tissue fraction to its lower and upper bounds - np.clip(self.flat_fraction, self.flat_fmin, self.flat_fmax, - out=self.flat_fraction) - + np.clip( + self.flat_fraction, self.flat_fmin, self.flat_fmax, out=self.flat_fraction + ) + # update cost self.compute_cost(alpha) - @auto_attr def parameters(self): - dti_params = eig_from_lo_tri(self.flat_lowtri) - out = np.zeros(self.shape + (13, )) + out = np.zeros(self.shape + (13,)) out[self.mask, 0:12] = dti_params out[self.mask, 12] = self.flat_fraction[..., 0] return out class BeltramiModel(ReconstModel): - - def __init__(self, gtab, init_method='MD', **kwargs): + def __init__(self, gtab, init_method="MD", **kwargs): ReconstModel.__init__(self, gtab) if not callable(init_method): try: init_method = init_methods[init_method] - except KeyError: + except KeyError as e: e_s = '"' + str(init_method) + '" is not a known init ' - e_s += 'method, the init method should either be a ' - e_s += 'function or one of the available init methods' - raise ValueError(e_s) + e_s += "method, the init method should either be a " + e_s += "function or one of the available init methods" + raise ValueError(e_s) from e self.init_method = init_method self.kwargs = kwargs self.design_matrix = design_matrix(self.gtab) - init_keys = ('Diso', 'Stissue', 'Swater', 'min_tissue_diff', - 'max_tissue_diff', 'tissue_MD') - self.init_kwargs = {k:kwargs[k] for k in init_keys if k in kwargs} - fit_keys = ('iterations', 'learning_rate', 'zooms', 'metric_ratio' - 'reg_weight', 'Diso') - self.fit_kwargs = {k:kwargs[k] for k in fit_keys if k in kwargs} - + init_keys = ( + "Diso", + "Stissue", + "Swater", + "min_tissue_diff", + "max_tissue_diff", + "tissue_MD", + ) + self.init_kwargs = {k: kwargs[k] for k in init_keys if k in kwargs} + fit_keys = ( + "iterations", + "learning_rate", + "zooms", + "metric_ratio" "reg_weight", + "Diso", + ) + self.fit_kwargs = {k: kwargs[k] for k in fit_keys if k in kwargs} def predict(self, model_params, S0=1): - Diso = self.init_kwargs.get('Diso', 3) + Diso = self.init_kwargs.get("Diso", WATER_DIFFUSIVITY) return model_prediction(model_params, self.gtab, S0, Diso) - def fit(self, data, mask=None): - if mask is not None: if mask.shape != data.shape[:-1]: raise ValueError("Mask is not the same shape as data.") @@ -365,20 +435,24 @@ def fit(self, data, mask=None): f0 = np.zeros(data.shape[:-1]) fmin = np.zeros(data.shape[:-1]) fmax = np.ones(data.shape[:-1]) - f0[mask], fmin[mask], fmax[mask] = self.init_method(masked_data, - self.gtab, - **self.init_kwargs) - np.clip(f0, fmin, fmax, out=f0) + f0[mask], fmin[mask], fmax[mask] = self.init_method( + masked_data, self.gtab, **self.init_kwargs + ) + np.clip(f0, fmin, fmax, out=f0) # Initializing tissue tensor - init_params = np.zeros(data.shape[:-1] + (13, )) - Diso = self.init_kwargs.get('Diso', 3) - min_tissue_diff = self.init_kwargs.get('min_tissue_diff', 0.001) - max_tissue_diff = self.init_kwargs.get('max_tissue_diff', 2.5) - init_params[mask, 0:12] = tensor_init(masked_data, self.gtab, f0[mask], - min_tissue_diff=min_tissue_diff, - max_tissue_diff=max_tissue_diff, - Diso=Diso) + init_params = np.zeros(data.shape[:-1] + (13,)) + Diso = self.init_kwargs.get("Diso", WATER_DIFFUSIVITY) + min_tissue_diff = self.init_kwargs.get("min_tissue_diff", MIN_DIFFUSIVITY) + max_tissue_diff = self.init_kwargs.get("max_tissue_diff", MAX_DIFFUSIVITY) + init_params[mask, 0:12] = tensor_init( + masked_data, + self.gtab, + f0[mask], + min_tissue_diff=min_tissue_diff, + max_tissue_diff=max_tissue_diff, + Diso=Diso, + ) init_params[mask, 12] = f0[mask] md_tissue = np.mean(init_params[..., :3], axis=-1) @@ -389,12 +463,12 @@ def fit(self, data, mask=None): # Run gradient descent atten, gtab = get_attenuations(data, self.gtab) D = design_matrix(gtab) - beltrami_params = gradient_descent(D, init_params, - atten, fmin, fmax, mask, - **self.fit_kwargs) - + beltrami_params = gradient_descent( + D, init_params, atten, fmin, fmax, mask, **self.fit_kwargs + ) + fit = BeltramiFit(self, beltrami_params) - + # Add the initialization parameters to Class instance (for debugging) fit.initial_guess = init_params fit.finterval = np.stack((fmin, fmax), axis=-1) @@ -403,52 +477,43 @@ def fit(self, data, mask=None): class BeltramiFit(TensorFit): - def __init__(self, model, model_params): TensorFit.__init__(self, model, model_params, model_S0=None) - + @property def f(self): return self.model_params[..., 12] - @property def fw(self): - return (1 - self.model_params[..., 12]) - + return 1 - self.model_params[..., 12] @property def fwmin(self): - return (1 - self.finterval[..., 1]) - + return 1 - self.finterval[..., 1] @property def fwmax(self): - return (1 - self.finterval[..., 0]) - + return 1 - self.finterval[..., 0] @property def fw0(self): return 1 - self.initial_guess[..., 12] - @property def fa0(self): return fractional_anisotropy(self.initial_guess[..., 0:3]) - @property def md0(self): return mean_diffusivity(self.initial_guess[..., 0:3]) - def predict(self, gtab, S0=1): - Diso = self.model.fit_kwargs.get('Diso', 3) + Diso = self.model.fit_kwargs.get("Diso", WATER_DIFFUSIVITY) return model_prediction(self.model_params, gtab, S0, Diso) def get_attenuations(signal, gtab): - # Averaging S0 and getting normalized attenuations b0_inds = gtab.b0s_mask S0 = np.mean(signal[..., b0_inds], axis=-1) @@ -458,39 +523,45 @@ def get_attenuations(signal, gtab): # Correcting non realistic attenuations bvals = gtab.bvals[~b0_inds] bvecs = gtab.bvecs[~b0_inds] - Amin = np.exp(-bvals * MAX_DIFFFUSIVITY) - Amin = np.tile(Amin, Ak.shape[:-1] + (1, )) + Amin = np.exp(-bvals * MAX_DIFFUSIVITY) + Amin = np.tile(Amin, Ak.shape[:-1] + (1,)) Amax = np.exp(-bvals * MIN_DIFFUSIVITY) - Amax = np.tile(Amax, Ak.shape[:-1] + (1, )) + Amax = np.tile(Amax, Ak.shape[:-1] + (1,)) np.clip(Ak, Amin, Amax, out=Ak) # Adding 'dummy' b0 zero data to attenuations and gtab - bvals = np.insert(bvals, 0 , 0) + bvals = np.insert(bvals, 0, 0) bvecs = np.insert(bvecs, 0, np.array([0, 0, 0]), axis=0) - this_gtab = gradient_table(bvals, bvecs) - this_Ak = np.ones(Ak.shape[:-1] + (Ak.shape[-1] + 1, )) - this_gtab = gradient_table(bvals, bvecs, b0_threshold=0) + this_gtab = gradient_table(bvals, bvecs=bvecs) + this_Ak = np.ones(Ak.shape[:-1] + (Ak.shape[-1] + 1,)) + this_gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=0) this_Ak[..., 1:] = Ak return (this_Ak, this_gtab) -def fraction_init_s0(signal, gtab, Diso=3, Stissue=None, Swater=None, - min_tissue_diff=0.001, max_tissue_diff=2.5): - +def fraction_init_s0( + signal, + gtab, + Diso=WATER_DIFFUSIVITY, + Stissue=None, + Swater=None, + min_tissue_diff=MIN_DIFFUSIVITY, + max_tissue_diff=MAX_DIFFUSIVITY, +): S0 = np.mean(signal[..., gtab.b0s_mask], axis=-1) if Stissue is None or Swater is None: - Stissue = np.percentile(S0, 75) + Stissue = np.percentile(S0, 75) Swater = np.percentile(S0, 95) - print('Stissue = ' + str(Stissue)) - print('Swater = ' + str(Swater)) + print("Stissue = " + str(Stissue)) + print("Swater = " + str(Swater)) # Normalized attenuations Ak, this_gtab = get_attenuations(signal, gtab) Ak = Ak[..., 1:] bvals = this_gtab.bvals[1:] # non zero bvals Awater = np.exp(-bvals * Diso) - Awater = np.tile(Awater, Ak.shape[:-1] + (1, )) + Awater = np.tile(Awater, Ak.shape[:-1] + (1,)) # Min and Max attenuations expected in tissue Atissue_min = np.exp(-bvals * max_tissue_diff) @@ -502,29 +573,38 @@ def fraction_init_s0(signal, gtab, Diso=3, Stissue=None, Swater=None, # Min and Max volume fraction fmin = np.min(Ak - Awater, axis=-1) / np.max(Atissue_max - Awater, axis=-1) fmax = np.max(Ak - Awater, axis=-1) / np.min(Atissue_min - Awater, axis=-1) - fmin[fmin <= 0] = 0.0001 - fmin[fmin >= 1] = 1 - 0.0001 - fmax[fmax <= 0] = 0.0001 - fmax[fmax >= 1] = 1 - 0.0001 + if isinstance(fmin, np.ndarray): + fmin[fmin <= 0] = MIN_FRAC + fmin[fmin >= 1] = 1 - MIN_FRAC + else: + fmin = np.float64(max(MIN_FRAC, fmin)) + fmin = np.float64(min(1 - MIN_FRAC, fmin)) + if isinstance(fmax, np.ndarray): + fmax[fmax <= 0] = MIN_FRAC + fmax[fmax >= 1] = 1 - MIN_FRAC + else: + fmax = np.float64(max(MIN_FRAC, fmax)) + fmax = np.float64(min(1 - MIN_FRAC, fmax)) return (f0, fmin, fmax) -def fraction_init_md(signal, gtab, Diso=3, tissue_MD=0.6): - +def fraction_init_md( + signal, gtab, Diso=WATER_DIFFUSIVITY, tissue_MD=WATER_DIFFUSIVITY / 20 +): # bvals = gtab.bvals[~gtab.b0s_mask] bvals = gtab.bvals bvecs = gtab.bvecs mean_bval = np.max(bvals) # print(mean_bval) - mbvals = bvals[np.logical_or(bvals==0, bvals==mean_bval)] - mbvecs = bvecs[np.logical_or(bvals==0, bvals==mean_bval), :] - mgtab = gradient_table(mbvals, mbvecs, b0_threshold=0) - msignal = signal[..., np.logical_or(bvals==0, bvals==mean_bval)] + mbvals = bvals[np.logical_or(bvals == 0, bvals == mean_bval)] + mbvecs = bvecs[np.logical_or(bvals == 0, bvals == mean_bval), :] + mgtab = gradient_table(mbvals, bvecs=mbvecs, b0_threshold=0) + msignal = signal[..., np.logical_or(bvals == 0, bvals == mean_bval)] # Conventional DTI - dti_params = ols_fit_tensor(design_matrix(mgtab), msignal) + dti_params = ols_fit_tensor(design_matrix(mgtab), msignal)[0] eigvals = dti_params[..., 0:3] MD = np.mean(eigvals, axis=-1) # mean diffusivity @@ -534,97 +614,128 @@ def fraction_init_md(signal, gtab, Diso=3, tissue_MD=0.6): f0 = (np.exp(-mean_bval * MD) - Awater) / (Atissue - Awater) # Min and Max volume fractions - fmin = np.ones(f0.shape) * 0.0001 - fmax = np.ones(f0.shape) * (1 - 0.0001) + fmin = np.ones(f0.shape) * MIN_FRAC + fmax = np.ones(f0.shape) * (1 - MIN_FRAC) return (f0, fmin, fmax) -def fraction_init_hybrid(signal, gtab, Diso=3, Stissue=None, Swater=None, - min_tissue_diff=0.001, max_tissue_diff=2.5, - tissue_MD=0.6): - - f_S0, fmin, fmax = fraction_init_s0(signal, gtab, Diso=Diso, - Stissue=Stissue, Swater=Swater, - min_tissue_diff=min_tissue_diff, - max_tissue_diff=max_tissue_diff) - f_MD, _, _ = fraction_init_md(signal, gtab, Diso=Diso, - tissue_MD=tissue_MD) +def fraction_init_hybrid( + signal, + gtab, + Diso=WATER_DIFFUSIVITY, + Stissue=None, + Swater=None, + min_tissue_diff=MIN_DIFFUSIVITY, + max_tissue_diff=MAX_DIFFUSIVITY, + tissue_MD=WATER_DIFFUSIVITY / 20, +): + f_S0, fmin, fmax = fraction_init_s0( + signal, + gtab, + Diso=Diso, + Stissue=Stissue, + Swater=Swater, + min_tissue_diff=min_tissue_diff, + max_tissue_diff=max_tissue_diff, + ) + f_MD, _, _ = fraction_init_md(signal, gtab, Diso=Diso, tissue_MD=tissue_MD) # hybrid initialization alpha = np.copy(f_S0) - np.clip(alpha, 0.0001, 0.9999, out=alpha) + np.clip(alpha, MIN_FRAC, 1 - MIN_FRAC, out=alpha) np.clip(f_S0, fmin, fmax, out=f_S0) - np.clip(f_MD, 0.0001, 0.9999, out=f_MD) - f0 = (f_MD**alpha) * (f_S0**(1 - alpha)) + np.clip(f_MD, MIN_FRAC, 1 - MIN_FRAC, out=f_MD) + f0 = (f_MD**alpha) * (f_S0 ** (1 - alpha)) # f0 = (f_S0**(f_MD)) * f_MD**(1 - f_MD) return (f0, fmin, fmax) -def tensor_init(signal, gtab, fraction, Diso=3, min_tissue_diff=0.001, - max_tissue_diff=2.5): - +def tensor_init( + signal, + gtab, + fraction, + Diso=WATER_DIFFUSIVITY, + min_tissue_diff=MIN_DIFFUSIVITY, + max_tissue_diff=MAX_DIFFUSIVITY, +): Ak, this_gtab = get_attenuations(signal, gtab) - # nonzero bvals and bvecs bvals = this_gtab.bvals - bvecs = this_gtab.bvecs - + # Min and Max attenuations expected in tissue Atissue_min = np.exp(-bvals * max_tissue_diff) - Atissue_min = np.tile(Atissue_min, Ak.shape[:-1] + (1, )) + Atissue_min = np.tile(Atissue_min, Ak.shape[:-1] + (1,)) Atissue_max = np.exp(-bvals * min_tissue_diff) - Atissue_max = np.tile(Atissue_max, Ak.shape[:-1] + (1, )) + Atissue_max = np.tile(Atissue_max, Ak.shape[:-1] + (1,)) # correcting the attenuations for free water f = fraction[..., None] Awater = np.exp(-bvals * Diso) - Awater = np.tile(Awater, Ak.shape[:-1] + (1, )) - Atissue = (Ak - (1-f) * Awater) / f + Awater = np.tile(Awater, Ak.shape[:-1] + (1,)) + Atissue = (Ak - (1 - f) * Awater) / f # np.clip(Atissue, Atissue_min, Atissue_max, out=Atissue) - np.clip(Atissue, 0.0001, 0.9999, out=Atissue) + np.clip(Atissue, MIN_FRAC, 1 - MIN_FRAC, out=Atissue) # applying standard DTI to corrected signal - dti_params = ols_fit_tensor(design_matrix(this_gtab), Atissue) + dti_params = ols_fit_tensor(design_matrix(this_gtab), Atissue)[0] return dti_params -def gradient_descent(design_matrix, initial_guess, attenuations, fmin, fmax, - mask, iterations=100, learning_rate=0.01, metric_ratio=1, - reg_weight=1, Diso=3, zooms=None): - +def gradient_descent( + design_matrix, + initial_guess, + attenuations, + fmin, + fmax, + mask, + iterations=100, + learning_rate=0.01, + metric_ratio=1, + reg_weight=1, + Diso=WATER_DIFFUSIVITY, + zooms=None, +): # cropping the non zero information from the data Ak = attenuations[..., 1:] H = design_matrix[1:, :-1] # Initializing manifold - manifold = Manifold(H, initial_guess, Ak, fmin, fmax, Diso=Diso, - beta=metric_ratio, mask=mask, zooms=zooms) + opt = _BeltramiOptimizer( + H, + initial_guess, + Ak, + fmin, + fmax, + Diso=Diso, + beta=metric_ratio, + mask=mask, + zooms=zooms, + ) cost = np.zeros(iterations) for i in range(iterations): - # At half itarations, turn off the regualrization term if i == iterations // 2: reg_weight = 0 - # Update the manifold - manifold.update(learning_rate, reg_weight) - cost[i] = np.mean(manifold.flat_cost) - # print(manifold.flat_cost[1000]) + # Update the optimizer + opt.update(learning_rate, reg_weight) + cost[i] = np.mean(opt.flat_cost) # Return the estimated parameters - return manifold.parameters - - -init_methods = {'S0':fraction_init_s0, - 's0':fraction_init_s0, - 'b0':fraction_init_s0, - 'md':fraction_init_md, - 'MD':fraction_init_md, - 'mean_diffusivity':fraction_init_md, - 'hybrid':fraction_init_hybrid, - 'interp':fraction_init_hybrid, - 'log_linear':fraction_init_hybrid - } \ No newline at end of file + return opt.parameters + + +init_methods = { + "S0": fraction_init_s0, + "s0": fraction_init_s0, + "b0": fraction_init_s0, + "md": fraction_init_md, + "MD": fraction_init_md, + "mean_diffusivity": fraction_init_md, + "hybrid": fraction_init_hybrid, + "interp": fraction_init_hybrid, + "log_linear": fraction_init_hybrid, +} diff --git a/src/fwe/fwe.py b/src/fwe/fwe.py new file mode 100644 index 0000000..12ab4d2 --- /dev/null +++ b/src/fwe/fwe.py @@ -0,0 +1,167 @@ +import argparse +import logging +import re + +import dipy.reconst.fwdti as fwdti +import nibabel as nib +import numpy as np +from dipy.core.gradients import gradient_table + +from .beltrami import BeltramiModel + +# set up logging configuration +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("FWE") + +DEFAULT_OPT_KWARGS = { + "init_method": "hybrid", + "Stissue": None, + "Swater": None, + "n_iterations": 100, + "learning_rate": 0.0005, +} + + +def free_water_elimination( + dwi_fname, + bval_fname, + bvec_fname, + mask_fname, + fwe_model, + output_fname, + Diso=3.0e-3, + save_params=False, + opt_kwargs=DEFAULT_OPT_KWARGS, +): + # load diffusion, mask image and bvec/bval gradient table + dwi = nib.load(dwi_fname) + mask = nib.load(mask_fname).get_fdata() + gtab = gradient_table(bval_fname, bvec_fname) + + # perform free water elimination + match fwe_model.lower(): + case "dipy_fwdti": + logger.info("Performing free water elimination using DIPY's FWDTI model") + fwe_image, model_params = dipy_fwdti(dwi, gtab, mask, Diso, save_params) + case "beltrami": + logger.info( + "Performing free water elimination using Golub's Beltrami model" + ) + fwe_image, model_params = single_shell_beltrami( + dwi, gtab, mask, Diso, save_params, **opt_kwargs + ) + case _: + logger.error("Unrecognized free water elimination model") + + # save free water model parameters + if save_params: + output_params = re.sub("_(\\w+).nii.gz$", "_params.nii.gz", output_fname) + nib.save(model_params, output_params) + logger.info(f"Saving free water model parameters to: {output_fname}") + + # save free water eliminated image + nib.save(fwe_image, output_fname) + logger.info(f"Saving free water eliminated image to: {output_fname}") + + +def single_shell_beltrami( + dwi, + gtab, + mask=None, + Diso=3.0e-3, + save_params=False, + init_method="hybrid", + Stissue=None, + Swater=None, + n_iterations=100, + learning_rate=0.0005, +): + # adjust b-values for Beltrami model + gtab.bvals = gtab.bvals + + # fit Beltrami regularized gradient descent free-water diffusion tensor model + model = BeltramiModel( + gtab, + init_method=init_method, + Diso=Diso, + Stissue=Stissue, + Swater=Swater, + iterations=n_iterations, + learning_rate=learning_rate, + ) + model_fit = model.fit(dwi.get_fdata(), mask=mask) + + # save free water dti model parameters + model_params = ( + nib.Nifti1Image(model_fit.model_params, affine=dwi.affine) + if save_params + else None + ) + + # extract free-water dti model parameters + fwf = model_fit.fw # free water fraction + + # undo b-value adjustment from Beltrami model + gtab.bvals = gtab.bvals * 1e3 + + # return free-water eliminated signal (and model parameters) + return (remove_free_water(dwi, gtab, fwf, Diso), model_params) + + +def dipy_fwdti(dwi, gtab, mask=None, Diso=3.0e-3, save_params=False): + # fit free-water diffusion tensor imaging model + model = fwdti.FreeWaterTensorModel(gtab) + model_fit = model.fit(dwi.get_fdata(), mask=mask) + + # save free water dti model parameters + model_params = ( + nib.Nifti1Image(model_fit.model_params, affine=dwi.affine) + if save_params + else None + ) + + # extract free-water dti model parameters + fwf = model_fit.model_params[..., -1] + + # return free-water eliminated signal (and model parameters) + return (remove_free_water(dwi, gtab, fwf, Diso), model_params) + + +def remove_free_water(dwi, gtab, fwf, Diso): + # extract b0 signal from dwi image + S0 = dwi.get_fdata()[..., gtab.bvals == 0].mean(axis=-1) + + # compute free-water signal + fw_decay = np.exp(-gtab.bvals * Diso) # free-water exponential decay + fw_signal = (S0 * fwf).reshape(-1, 1) * fw_decay + fw_signal = fw_signal.reshape(fwf.shape + (dwi.shape[-1],)) + + # compute free-water eliminated signal + fwe_signal = dwi.get_fdata() - fw_signal + + # return free-water eliminated signal + return nib.Nifti1Image(fwe_signal, affine=dwi.affine) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dwi_fname", type=str) + parser.add_argument("bval_fname", type=str) + parser.add_argument("bvec_fname", type=str) + parser.add_argument("mask_fname", type=str) + parser.add_argument("fwe_model", type=str) + parser.add_argument("output_fname", type=str) + parser.add_argument("--Diso", type=float, default=3.0e-3) + parser.add_argument("--save_params", type=bool, default=False) + args = parser.parse_args() + + free_water_elimination( + dwi_fname=args.dwi_fname, + bval_fname=args.bval_fname, + bvec_fname=args.bvec_fname, + mask_fname=args.mask_fname, + fwe_model=args.fwe_model, + output_fname=args.output_fname, + Diso=args.Diso, + save_params=args.save_params, + ) diff --git a/test/test_beltrami.py b/test/test_beltrami.py new file mode 100644 index 0000000..40a8b0d --- /dev/null +++ b/test/test_beltrami.py @@ -0,0 +1,75 @@ +import numpy as np +from dipy.core.gradients import gradient_table +from dipy.data.fetcher import get_fnames, read_bvals_bvecs +from dipy.reconst.dti import TensorModel, decompose_tensor, from_lower_triangular +from dipy.sims.voxel import all_tensor_evecs, multi_tensor, single_tensor +from fwe.beltrami import BeltramiModel +from numpy.testing import assert_almost_equal + +_, fbvals, fbvecs = get_fnames(name="small_64D") +bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) +gtab = gradient_table(bvals, bvecs=bvecs) + + +def setup_module(): + """Module-level setup""" + global gtab, mevals, model_params_mv + global DWI, FAref, GTF, MDref, FAdti, MDdti + _, fbvals, fbvecs = get_fnames(name="small_64D") + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) + gtab = gradient_table(bvals, bvecs=bvecs) + + # Simulation a typical DT and DW signal for no water contamination + # S0 = np.array(100) + dt = np.array([0.0017, 0, 0.0003, 0, 0, 0.0003]) + evals, evecs = decompose_tensor(from_lower_triangular(dt)) + S_tissue = single_tensor(gtab, S0=100, evals=evals, evecs=evecs, snr=None) + dm = TensorModel(gtab, fit_method="WLS") + dtifit = dm.fit(S_tissue) + FAdti = dtifit.fa + MDdti = dtifit.md + + # Simulation of 8 voxels tested + DWI = np.zeros((2, 2, 2, len(gtab.bvals))) + FAref = np.zeros((2, 2, 2)) + MDref = np.zeros((2, 2, 2)) + # Diffusion of tissue and water compartments are constant for all voxel + mevals = np.array([[0.0017, 0.0003, 0.0003], [0.003, 0.003, 0.003]]) + # volume fractions + GTF = np.array([[[0.06, 0.71], [0.33, 0.91]], [[0.0, 0.0], [0.0, 0.0]]]) + # S0 multivoxel + # S0m = 100 * np.ones((2, 2, 2)) + # model_params ground truth (to be fill) + model_params_mv = np.zeros((2, 2, 2, 13)) + for i in range(2): + for j in range(2): + gtf = GTF[0, i, j] + S, p = multi_tensor( + gtab, + mevals, + S0=100, + angles=[(90, 0), (90, 0)], + fractions=[(1 - gtf) * 100, gtf * 100], + snr=None, + ) + DWI[0, i, j] = S + FAref[0, i, j] = FAdti + MDref[0, i, j] = MDdti + R = all_tensor_evecs(p[0]) + R = R.reshape(9) + model_params_mv[0, i, j] = np.concatenate( + ([0.0017, 0.0003, 0.0003], R, [gtf]), axis=0 + ) + + +def test_beltrami_model(): + global DWI, FAref, GTF, MDdti, FAdti + fwdm = BeltramiModel(gtab) + fwefit = fwdm.fit(DWI) + FA = fwefit.fa + FWF = fwefit.f + MD = fwefit.md + + assert_almost_equal(FWF, GTF, decimal=3) + assert_almost_equal(FA, FAref, decimal=3) + assert_almost_equal(MD, MDref, decimal=3)