diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 570f629b..00000000 Binary files a/.DS_Store and /dev/null differ diff --git a/tensorflow_mri/.DS_Store b/tensorflow_mri/.DS_Store deleted file mode 100644 index c8ff8f76..00000000 Binary files a/tensorflow_mri/.DS_Store and /dev/null differ diff --git a/tensorflow_mri/python/losses/iqa_losses.py b/tensorflow_mri/python/losses/iqa_losses.py index bde0c74d..d7c5013c 100644 --- a/tensorflow_mri/python/losses/iqa_losses.py +++ b/tensorflow_mri/python/losses/iqa_losses.py @@ -414,6 +414,87 @@ def ssim_multiscale_loss(y_true, y_pred, max_val=None, rank=rank) + +@tf.keras.utils.register_keras_serializable(package="MRI") +class MeanAbsoluteGradientError(LossFunctionWrapperIQA): + def __init__(self, + method='sobel', + norm=False, + batch_dims=None, + image_dims=None, + multichannel=True, + complex_part=None, + reduction=tf.keras.losses.Reduction.AUTO, + name='mean_absolute_gradient_error'): + super().__init__(mean_absolute_gradient_error, + reduction=reduction, name=name, method=method, + norm=norm, batch_dims=batch_dims, image_dims=image_dims, + multichannel=multichannel, complex_part=complex_part) + + +@tf.keras.utils.register_keras_serializable(package="MRI") +class MeanSquaredGradientError(LossFunctionWrapperIQA): + def __init__(self, + method='sobel', + norm=False, + batch_dims=None, + image_dims=None, + multichannel=True, + complex_part=None, + reduction=tf.keras.losses.Reduction.AUTO, + name='mean_squared_gradient_error'): + super().__init__(mean_squared_gradient_error, + reduction=reduction, name=name, method=method, + norm=norm, batch_dims=batch_dims, image_dims=image_dims, + multichannel=multichannel, complex_part=complex_part) + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_absolute_error(y_true, y_pred): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + return tf.math.reduce_mean(tf.math.abs(y_pred - y_true), axis=-1) + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_squared_error(y_true, y_pred): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + return tf.math.reduce_mean( + tf.math.real(tf.math.squared_difference(y_pred, y_true)), axis=-1) + + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_absolute_gradient_error(y_true, y_pred, method='sobel', + norm=False, batch_dims=None, image_dims=None): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + + grad_true = image_ops.image_gradients( + y_true, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + grad_pred = image_ops.image_gradients( + y_pred, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + + return mean_absolute_error(grad_true, grad_pred) + + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_squared_gradient_error(y_true, y_pred, method='sobel', + norm=False, batch_dims=None, image_dims=None): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + + grad_true = image_ops.image_gradients( + y_true, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + grad_pred = image_ops. image_gradients( + y_pred, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + + return mean_squared_error(grad_true, grad_pred) + + + # For backward compatibility. @tf.keras.utils.register_keras_serializable(package="MRI") class StructuralSimilarityLoss(SSIMLoss): diff --git a/tensorflow_mri/python/metrics/iqa_metrics.py b/tensorflow_mri/python/metrics/iqa_metrics.py index c23c5090..82620687 100755 --- a/tensorflow_mri/python/metrics/iqa_metrics.py +++ b/tensorflow_mri/python/metrics/iqa_metrics.py @@ -330,6 +330,61 @@ def __init__(self, multichannel=multichannel, complex_part=complex_part) +# We register this object with the Keras serialization framework under a +# different name, in order to avoid clashing with the loss of the same name. +@api_util.export("metrics.MeanAbsGrad", + "metrics.MeanAbsoluteGradientErrorMetric") +@tf.keras.utils.register_keras_serializable( + package='MRI', name='MeanAbsoluteGradientErrorMetric') +class MeanAbsoluteGradientError(MeanMetricWrapperIQA): + + def __init__(self, + method='sobel', + norm=False, + batch_dims=None, + image_dims=None, + multichannel=True, + complex_part=None, + name='mage', + dtype=None): + super().__init__(image_ops.mean_absolute_gradient_error, + method=method, + norm=norm, + batch_dims=batch_dims, + image_dims=image_dims, + multichannel=multichannel, + complex_part=complex_part, + name=name, + dtype=dtype) + + +# We register this object with the Keras serialization framework under a +# different name, in order to avoid clashing with the loss of the same name. +@api_util.export("metrics.MeanSqGrad", + "metrics.MeanSquaredGradientErrorMetric") +@tf.keras.utils.register_keras_serializable( + package='MRI', name='MeanSquaredGradientErrorMetric') +class MeanSquaredGradientError(MeanMetricWrapperIQA): + + def __init__(self, + method='sobel', + norm=False, + batch_dims=None, + image_dims=None, + multichannel=True, + complex_part=None, + name='msge', + dtype=None): + super().__init__(image_ops.mean_squared_gradient_error, + method=method, + norm=norm, + batch_dims=batch_dims, + image_dims=image_dims, + multichannel=multichannel, + complex_part=complex_part, + name=name, + dtype=dtype) + # For backward compatibility. @tf.keras.utils.register_keras_serializable(package="MRI") diff --git a/tensorflow_mri/python/ops/image_ops.py b/tensorflow_mri/python/ops/image_ops.py index 755871bd..9a995d22 100644 --- a/tensorflow_mri/python/ops/image_ops.py +++ b/tensorflow_mri/python/ops/image_ops.py @@ -31,6 +31,7 @@ from tensorflow_mri.python.util import api_util from tensorflow_mri.python.util import check_util from tensorflow_mri.python.util import deprecation +from tensorflow_mri.python.losses import iqa_losses @api_util.export("image.psnr") @@ -1025,6 +1026,85 @@ def _gradient_operators(method, norm=False, rank=2, dtype=tf.float32): kernels[d] *= operator_1d return tf.stack(kernels, axis=0) +@tf.keras.utils.register_keras_serializable(package="MRI") +class MeanAbsoluteGradientError(iqa_losses.LossFunctionWrapperIQA): + def __init__(self, + method='sobel', + norm=False, + batch_dims=None, + image_dims=None, + multichannel=True, + complex_part=None, + reduction=tf.keras.losses.Reduction.AUTO, + name='mean_absolute_gradient_error'): + super().__init__(mean_absolute_gradient_error, + reduction=reduction, name=name, method=method, + norm=norm, batch_dims=batch_dims, image_dims=image_dims, + multichannel=multichannel, complex_part=complex_part) + + +@tf.keras.utils.register_keras_serializable(package="MRI") +class MeanSquaredGradientError(iqa_losses.LossFunctionWrapperIQA): + def __init__(self, + method='sobel', + norm=False, + batch_dims=None, + image_dims=None, + multichannel=True, + complex_part=None, + reduction=tf.keras.losses.Reduction.AUTO, + name='mean_squared_gradient_error'): + super().__init__(mean_squared_gradient_error, + reduction=reduction, name=name, method=method, + norm=norm, batch_dims=batch_dims, image_dims=image_dims, + multichannel=multichannel, complex_part=complex_part) + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_absolute_error(y_true, y_pred): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + return tf.math.reduce_mean(tf.math.abs(y_pred - y_true), axis=-1) + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_squared_error(y_true, y_pred): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + return tf.math.reduce_mean( + tf.math.real(tf.math.squared_difference(y_pred, y_true)), axis=-1) + + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_absolute_gradient_error(y_true, y_pred, method='sobel', + norm=False, batch_dims=None, image_dims=None): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + + grad_true = image_gradients( + y_true, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + grad_pred = image_gradients( + y_pred, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + + return mean_absolute_error(grad_true, grad_pred) + + +@tf.keras.utils.register_keras_serializable(package="MRI") +def mean_squared_gradient_error(y_true, y_pred, method='sobel', + norm=False, batch_dims=None, image_dims=None): + y_pred = tf.convert_to_tensor(y_pred) + y_true = tf.cast(y_true, y_pred.dtype) + + grad_true = image_gradients( + y_true, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + grad_pred = image_gradients( + y_pred, method=method, norm=norm, + batch_dims=batch_dims, image_dims=image_dims) + + return mean_squared_error(grad_true, grad_pred) + + def _filter_image(image, kernels): """Filters an image using the specified kernels. diff --git a/tests/.DS_Store b/tests/.DS_Store deleted file mode 100644 index 97ba5efa..00000000 Binary files a/tests/.DS_Store and /dev/null differ diff --git a/tools/.DS_Store b/tools/.DS_Store deleted file mode 100644 index 5a583841..00000000 Binary files a/tools/.DS_Store and /dev/null differ diff --git a/tools/docs/guide/contribute.ipynb b/tools/docs/guide/contribute.ipynb deleted file mode 100644 index 98b41652..00000000 --- a/tools/docs/guide/contribute.ipynb +++ /dev/null @@ -1,32 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Contributing\n", - "\n", - "Coming soon..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.10 64-bit", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.8.10" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tools/docs/guide/fft.ipynb b/tools/docs/guide/fft.ipynb new file mode 100644 index 00000000..72099ca6 --- /dev/null +++ b/tools/docs/guide/fft.ipynb @@ -0,0 +1,101 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fast Fourier transform (FFT)\n", + "\n", + "TensorFlow MRI uses the built-in FFT ops in core TensorFlow. These are [`tf.signal.fft`](https://www.tensorflow.org/api_docs/python/tf/signal/fft), [`tf.signal.fft2d`](https://www.tensorflow.org/api_docs/python/tf/signal/fft2d) and [`tf.signal.fft3d`](https://www.tensorflow.org/api_docs/python/tf/signal/fft3d).\n", + "\n", + "## N-dimensional FFT\n", + "\n", + "For convenience, TensorFlow MRI also provides [`tfmri.signal.fft`](https://mrphys.github.io/tensorflow-mri/api_docs/tfmri/signal/fft/), which can be used for N-dimensional FFT calculations and provides convenient access to commonly used functionality such as padding/cropping, normalization and shifting of the zero-frequency component within the same function call.\n", + "\n", + "## Custom FFT kernels for CPU\n", + "\n", + "Unfortunately, TensorFlow's FFT ops are [known to be slow](https://github.com/tensorflow/tensorflow/issues/6541) on CPU. As a result, the FFT can become a significant bottleneck on MRI processing pipelines, especially on iterative reconstructions where the FFT is called repeatedly.\n", + "\n", + "To address this issue, TensorFlow MRI provides a set of custom FFT kernels based on the FFTW library. These offer a significant boost in performance compared to the kernels in core TensorFlow.\n", + "\n", + "The custom FFT kernels are automatically registered to the TensorFlow framework when importing TensorFlow MRI. If you have imported TensorFlow MRI, then the standard FFT ops will use the optimized kernels automatically.\n", + "\n", + "```{tip}\n", + "You only need to `import tensorflow_mri` in order to use the custom FFT kernels. You can then access them as usual through `tf.signal.fft`, `tf.signal.fft2d` and `tf.signal.fft3d`.\n", + "```\n", + "\n", + "The only caveat is that the [FFTW license](https://www.fftw.org/doc/License-and-Copyright.html) is more restrictive than the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0) used by TensorFlow MRI. In particular, GNU GPL requires you to distribute any derivative software under equivalent terms.\n", + "\n", + "```{warning}\n", + "If you intend to use custom FFT kernels for commercial purposes, you will need to purchase a commercial FFTW license.\n", + "```\n", + "\n", + "### Disable the use of custom FFT kernels\n", + "\n", + "You can control whether custom FFT kernels are used via the `TFMRI_USE_CUSTOM_FFT` environment variable. When set to false, TensorFlow MRI will not register its custom FFT kernels, falling back to the standard FFT kernels in core TensorFlow. If the variable is unset, its value defaults to true.\n", + "\n", + "````{tip}\n", + "Set `TFMRI_USE_CUSTOM_FFT=0` to disable the custom FFT kernels.\n", + "\n", + "```python\n", + "os.environ[\"TFMRI_USE_CUSTOM_FFT\"] = \"0\"\n", + "import tensorflow_mri as tfmri\n", + "```\n", + "\n", + "```{attention}\n", + "`TFMRI_USE_CUSTOM_FFT` must be set **before** importing TensorFlow MRI. Setting or changing its value after importing the package will have no effect.\n", + "```\n", + "````\n", + "\n", + "### Customize the behavior of custom FFT kernels\n", + "\n", + "FFTW allows you to control the rigor of the planning process. The more rigorously a plan is created, the more efficient the actual FFT execution is likely to be, at the expense of a longer planning time. TensorFlow MRI lets you control the FFTW planning rigor through the `TFMRI_FFTW_PLANNING_RIGOR` environment variable. Valid values for this variable are:\n", + "\n", + "- `\"estimate\"` specifies that, instead of actual measurements of different algorithms, a simple heuristic is used to pick a (probably sub-optimal) plan quickly.\n", + "- `\"measure\"` tells FFTW to find an optimized plan by actually computing several FFTs and measuring their execution time. Depending on your machine, this can take some time (often a few seconds). This is the default planning option.\n", + "- `\"patient\"` is like `\"measure\"`, but considers a wider range of algorithms and often produces a “more optimal” plan (especially for large transforms), but at the expense of several times longer planning time (especially for large transforms).\n", + "- `\"exhaustive\"` is like `\"patient\"`, but considers an even wider range of algorithms, including many that we think are unlikely to be fast, to produce the most optimal plan but with a substantially increased planning time.\n", + "\n", + "````{tip}\n", + "Set the environment variable `TFMRI_FFTW_PLANNING_RIGOR` to control the planning rigor.\n", + "\n", + "```python\n", + "os.environ[\"TFMRI_FFTW_PLANNING_RIGOR\"] = \"estimate\"\n", + "import tensorflow_mri as tfmri\n", + "```\n", + "\n", + "```{attention}\n", + "`TFMRI_FFTW_PLANNING_RIGOR` must be set **before** importing TensorFlow MRI. Setting or changing its value after importing the package will have no effect.\n", + "```\n", + "````\n", + "\n", + "```{note}\n", + "FFTW accumulates \"wisdom\" each time the planner is called, and this wisdom is persisted across invocations of the FFT kernels (during the same process). Therefore, more rigorous planning options will result in long planning times during the first FFT invocation, but may result in faster execution during subsequent invocations. When performing a large amount of similar FFT invocations (e.g., while training a model or performing iterative reconstructions), you are more likely to benefit from more rigorous planning.\n", + "```\n", + "\n", + "```{seealso}\n", + "The FFTW [planner flags](https://www.fftw.org/doc/Planner-Flags.html) documentation page.\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.2 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.8.2" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "0adcc2737ebf6a4a119f135174df96668767fca1ef1112612db5ecadf2b6d608" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tools/docs/guide/linalg.ipynb b/tools/docs/guide/linalg.ipynb deleted file mode 100644 index f45442d9..00000000 --- a/tools/docs/guide/linalg.ipynb +++ /dev/null @@ -1,32 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Linear algebra\n", - "\n", - "Coming soon..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.10 64-bit", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.8.10" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tools/docs/guide/optim.ipynb b/tools/docs/guide/optim.ipynb deleted file mode 100644 index 21363722..00000000 --- a/tools/docs/guide/optim.ipynb +++ /dev/null @@ -1,32 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Optimization\n", - "\n", - "Coming soon..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.2 64-bit", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.8.2" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "0adcc2737ebf6a4a119f135174df96668767fca1ef1112612db5ecadf2b6d608" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tools/docs/guide/recon.ipynb b/tools/docs/guide/recon.ipynb deleted file mode 100644 index 5291a70b..00000000 --- a/tools/docs/guide/recon.ipynb +++ /dev/null @@ -1,32 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# MR image reconstruction\n", - "\n", - "Coming soon..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.2 64-bit", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.8.2" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "0adcc2737ebf6a4a119f135174df96668767fca1ef1112612db5ecadf2b6d608" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tools/docs/templates/index.rst b/tools/docs/templates/index.rst index 13b8384c..e75901fc 100644 --- a/tools/docs/templates/index.rst +++ b/tools/docs/templates/index.rst @@ -16,11 +16,8 @@ TensorFlow MRI |release| Guide Installation + Uniform FFT Non-uniform FFT - Linear algebra - Optimization - MRI reconstruction - Contributing FAQ diff --git a/tools/docs/tutorials/recon.rst b/tools/docs/tutorials/recon.rst index dcc8f28e..cf9dff50 100644 --- a/tools/docs/tutorials/recon.rst +++ b/tools/docs/tutorials/recon.rst @@ -9,5 +9,5 @@ Image reconstruction CARTESIAN SENSE (2D+t Cartesian k-space) GRIDDING (Radials and Spirals) PRE-PROCESSING TRIGGERED CINE DATASET (with GRAPPA and PF) - CG-SENSE (Radial, 2D and 2D+t) + CG-SENSE COMPRESSED SENSING (Radial, 2D and 2D+t) \ No newline at end of file diff --git a/tools/docs/tutorials/recon/cg_sense.ipynb b/tools/docs/tutorials/recon/cg_sense.ipynb index ad1ccf84..d1ab7fa4 100644 --- a/tools/docs/tutorials/recon/cg_sense.ipynb +++ b/tools/docs/tutorials/recon/cg_sense.ipynb @@ -7,6 +7,16 @@ "# Image reconstruction with CG-SENSE" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![View on website](https://img.shields.io/badge/-View%20on%20website-128091?labelColor=grey&logo=)](https://mrphys.github.io/tensorflow-mri/tutorials/recon/cg_sense)\n", + "[![Run in Colab](https://img.shields.io/badge/-Run%20in%20Colab-128091?labelColor=grey&logo=googlecolab)](https://colab.research.google.com/github/mrphys/tensorflow-mri/blob/master/tools/docs/tutorials/recon/cg_sense.ipynb)\n", + "[![View on GitHub](https://img.shields.io/badge/-View%20on%20GitHub-128091?labelColor=grey&logo=github)](https://github.com/mrphys/tensorflow-mri/blob/master/tools/docs/tutorials/recon/cg_sense.ipynb)\n", + "[![Download notebook](https://img.shields.io/badge/-Download%20notebook-128091?labelColor=grey&logo=)](https://raw.githubusercontent.com/mrphys/tensorflow-mri/master/tools/docs/tutorials/recon/cg_sense.ipynb)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -970,214 +980,6 @@ "_ = plt.gcf().suptitle('Reconstructed image', color='w', fontsize=14)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# We will also try a 2D+t non-Cartesian SENSE example\n", - "\n", - "Firstly get the dataset from google drive\n", - "This is a prospective radial undersampled dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import gdown\n", - "\n", - "url = 'https://drive.google.com/uc?id=1nxJgqxOwFLIlO0Cz4NfhvYrB7_3C5Rhy'\n", - "output = '/workspaces/Tutorials/UPLOADED_radialCGsense2D/radiallyUndersampledProspectiveData_fromG.npy'\n", - "gdown.download(url, output, quiet=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now read the data, and calculate the trajectory and density weights for this prospective data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "raw_data = np.load(f'/workspaces/Tutorials/UPLOADED_radialCGsense2D/radiallyUndersampledProspectiveData_fromG.npy')\n", - "kspace = tf.cast(raw_data, dtype = tf.complex64)\n", - "\n", - "print('raw data shape:', raw_data.shape)\n", - "# (512, 30, 13, 27)\n", - "# nPtsPerSpoke, nCh, nSpokes, nTimePoints\n", - "\n", - "nSpokes = raw_data.shape[2]\n", - "nTimePts = raw_data.shape[3]\n", - "\n", - "kspace = np.transpose(kspace, [3,1,2,0]) \n", - "#(time, coils, spokes, readout)\n", - "sh = kspace.shape\n", - "kspace = tf.reshape(kspace,(sh[0],sh[1],sh[2]*sh[3]))\n", - "print('kspace shape: ', kspace.shape)\n", - "#(time, coils, spokes*readout)\n", - "# (27, 30, 6656)\n", - "\n", - "im_size = 256\n", - "image_shape = [im_size, im_size]\n", - "\n", - "# Compute trajectory.\n", - "traj = tfmri.sampling.radial_trajectory(base_resolution=im_size,\n", - " views=nSpokes,\n", - " phases=nTimePts,\n", - " ordering='sorted_half',\n", - " angle_range = 'full')\n", - "\n", - "print('traj shape: ', traj.shape)\n", - "#(time, spokes, readout, 2)\n", - "# (27, 13, 512, 2)\n", - "\n", - "# Compute density.\n", - "dens = tfmri.sampling.estimate_density(traj, image_shape, method=\"pipe\")\n", - "print('density.shape: ' + str(dens.shape))\n", - "# #(time, spokes, readout)\n", - "#density.shape: (27, 13, 512)\n", - "\n", - "# Flatten trajectory and density.\n", - "traj = tfmri.sampling.flatten_trajectory(traj)\n", - "# This should be size: [nTimePts, nPtsPerSpoke*nSpokes, 2]\n", - "#trajectory.shape: (27, 6656, 2)\n", - "\n", - "dens = tfmri.sampling.flatten_density(dens)\n", - "# This should be size: [nTimePts, nPtsPerSpoke*nSpokes]\n", - "#trajectory.shape: (27, 6656)\n", - "\n", - "# And compress to 12 coil elements\n", - "kspace = tfmri.coils.compress_coils(kspace, coil_axis=-2, out_coils=12)\n", - "print('kspace:', kspace.shape)\n", - "#(time, coils, spokes*readout)\n", - "#kspace: (27, 12, 6656)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And calculate the coil sensitivity info for this dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Now calcualte coil sensitivities by collapsing through time and gridding\n", - "\n", - "kSpaceCS = np.transpose(kspace, [1,0,2])\n", - "#(coils, time,spokes*readout)\n", - "# (12, 27, 6656)\n", - "\n", - "kSpaceCS = tf.reshape(kSpaceCS, [kSpaceCS.shape[0], kSpaceCS.shape[1]*kSpaceCS.shape[2]])\n", - "#kSpaceCS: (27, 199680)\n", - "trajCS = tf.reshape(traj, [traj.shape[0]*traj.shape[1], traj.shape[2]])\n", - "#trajCS: (179712, 2)\n", - "densCS = tf.reshape(dens, [dens.shape[0]*dens.shape[1]])\n", - "\n", - "# First let's filter the *k*-space data with a Hann window. We will apply the\n", - "# window to the central 20% of k-space (determined by the factor 5 below), the\n", - "# remaining 80% is filtered out completely.\n", - "filter_fn = lambda x: tfmri.signal.hann(5 * x)\n", - "\n", - "# Low-pass filtering of the k-space data.\n", - "filtered_kspace = tfmri.signal.filter_kspace(kSpaceCS,\n", - " trajectory=trajCS,\n", - " filter_fn=filter_fn)\n", - "\n", - "# Reconstruct low resolution estimates.\n", - "low_res_images = tfmri.recon.adjoint(filtered_kspace,\n", - " image_shape,\n", - " trajectory=trajCS,\n", - " density=densCS)\n", - "\n", - "_ = plot_tiled_images(tf.math.abs(low_res_images))\n", - "_ = plt.gcf().suptitle('Low-resolution images', color='w', fontsize=14)\n", - "\n", - "# Estimate the coil sensitivities.\n", - "coil_sens = tfmri.coils.estimate_sensitivities(\n", - " low_res_images, coil_axis=0, method='walsh')\n", - "\n", - "print('sensitivities.shape: ' + str(coil_sens.shape))\n", - "# This should be size: [nCoils, matrix_size, matrix_size]\n", - "#sensitivities.shape: (12, 256, 256)\n", - "\n", - "_ = plot_tiled_images(tf.math.abs(coil_sens))\n", - "_ = plt.gcf().suptitle('Coil Sensitivities', color='w', fontsize=14)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Lastly do iterative SENSE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "domain_shape =[nTimePts, im_size, im_size] #, dtype=tf.int32)\n", - "\n", - "#Create regularizer.\n", - "tikhonov_parameter = 0.2\n", - "regularizer = tfmri.convex.ConvexFunctionTikhonov( scale=tikhonov_parameter, dtype=tf.complex64)\n", - "\n", - " \n", - "# this should have the shape [t*x*y,]\n", - "print('regularizer.shape: ' + str(regularizer.shape)) \n", - "# regularizer.shape: ((1769472,)\n", - "\n", - "senserecon = tfmri.recon.least_squares(kspace, # correct\n", - " image_shape, # correct\n", - " extra_shape=nTimePts, # correct\n", - " trajectory=traj, # correct\n", - " density=dens, # correct\n", - " sensitivities=coil_sens, # correct\n", - " regularizer=regularizer, # correct\n", - " optimizer='cg',\n", - " optimizer_kwargs={\n", - " 'max_iterations': 20\n", - " },\n", - " filter_corners=True)\n", - "\n", - "print(np.shape(senserecon))\n", - "\n", - "\n", - "# And lets visualise\n", - "plt.rcParams[\"animation.html\"] = \"jshtml\"\n", - "plt.rcParams['figure.dpi'] = 150 \n", - "plt.ioff()\n", - "fig, ax = plt.subplots()\n", - "\n", - "t= np.linspace(0,nTimePts)\n", - "def animate(t):\n", - " plt.imshow(tf.squeeze(tf.math.abs(senserecon[t,:,:]), axis=1), cmap = 'gray')\n", - " plt.title('iterative SENSE Recon')\n", - "\n", - "import matplotlib.animation\n", - "matplotlib.animation.FuncAnimation(fig, animate, frames=nTimePts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Ths data is 24x undersampled so its not a great suprise that SENSE didnt resolve all of the artefacts!" - ] - }, { "cell_type": "markdown", "metadata": {},