Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed .DS_Store
Binary file not shown.
Binary file removed tensorflow_mri/.DS_Store
Binary file not shown.
81 changes: 81 additions & 0 deletions tensorflow_mri/python/losses/iqa_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,87 @@ def ssim_multiscale_loss(y_true, y_pred, max_val=None,
rank=rank)



@tf.keras.utils.register_keras_serializable(package="MRI")
class MeanAbsoluteGradientError(LossFunctionWrapperIQA):
def __init__(self,
method='sobel',
norm=False,
batch_dims=None,
image_dims=None,
multichannel=True,
complex_part=None,
reduction=tf.keras.losses.Reduction.AUTO,
name='mean_absolute_gradient_error'):
super().__init__(mean_absolute_gradient_error,
reduction=reduction, name=name, method=method,
norm=norm, batch_dims=batch_dims, image_dims=image_dims,
multichannel=multichannel, complex_part=complex_part)


@tf.keras.utils.register_keras_serializable(package="MRI")
class MeanSquaredGradientError(LossFunctionWrapperIQA):
def __init__(self,
method='sobel',
norm=False,
batch_dims=None,
image_dims=None,
multichannel=True,
complex_part=None,
reduction=tf.keras.losses.Reduction.AUTO,
name='mean_squared_gradient_error'):
super().__init__(mean_squared_gradient_error,
reduction=reduction, name=name, method=method,
norm=norm, batch_dims=batch_dims, image_dims=image_dims,
multichannel=multichannel, complex_part=complex_part)

@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_absolute_error(y_true, y_pred):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)
return tf.math.reduce_mean(tf.math.abs(y_pred - y_true), axis=-1)

@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_squared_error(y_true, y_pred):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)
return tf.math.reduce_mean(
tf.math.real(tf.math.squared_difference(y_pred, y_true)), axis=-1)


@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_absolute_gradient_error(y_true, y_pred, method='sobel',
norm=False, batch_dims=None, image_dims=None):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

grad_true = image_ops.image_gradients(
y_true, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)
grad_pred = image_ops.image_gradients(
y_pred, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)

return mean_absolute_error(grad_true, grad_pred)


@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_squared_gradient_error(y_true, y_pred, method='sobel',
norm=False, batch_dims=None, image_dims=None):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

grad_true = image_ops.image_gradients(
y_true, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)
grad_pred = image_ops. image_gradients(
y_pred, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)

return mean_squared_error(grad_true, grad_pred)



# For backward compatibility.
@tf.keras.utils.register_keras_serializable(package="MRI")
class StructuralSimilarityLoss(SSIMLoss):
Expand Down
55 changes: 55 additions & 0 deletions tensorflow_mri/python/metrics/iqa_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,61 @@ def __init__(self,
multichannel=multichannel,
complex_part=complex_part)

# We register this object with the Keras serialization framework under a
# different name, in order to avoid clashing with the loss of the same name.
@api_util.export("metrics.MeanAbsGrad",
"metrics.MeanAbsoluteGradientErrorMetric")
@tf.keras.utils.register_keras_serializable(
package='MRI', name='MeanAbsoluteGradientErrorMetric')
class MeanAbsoluteGradientError(MeanMetricWrapperIQA):

def __init__(self,
method='sobel',
norm=False,
batch_dims=None,
image_dims=None,
multichannel=True,
complex_part=None,
name='mage',
dtype=None):
super().__init__(image_ops.mean_absolute_gradient_error,
method=method,
norm=norm,
batch_dims=batch_dims,
image_dims=image_dims,
multichannel=multichannel,
complex_part=complex_part,
name=name,
dtype=dtype)


# We register this object with the Keras serialization framework under a
# different name, in order to avoid clashing with the loss of the same name.
@api_util.export("metrics.MeanSqGrad",
"metrics.MeanSquaredGradientErrorMetric")
@tf.keras.utils.register_keras_serializable(
package='MRI', name='MeanSquaredGradientErrorMetric')
class MeanSquaredGradientError(MeanMetricWrapperIQA):

def __init__(self,
method='sobel',
norm=False,
batch_dims=None,
image_dims=None,
multichannel=True,
complex_part=None,
name='msge',
dtype=None):
super().__init__(image_ops.mean_squared_gradient_error,
method=method,
norm=norm,
batch_dims=batch_dims,
image_dims=image_dims,
multichannel=multichannel,
complex_part=complex_part,
name=name,
dtype=dtype)


# For backward compatibility.
@tf.keras.utils.register_keras_serializable(package="MRI")
Expand Down
80 changes: 80 additions & 0 deletions tensorflow_mri/python/ops/image_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tensorflow_mri.python.util import api_util
from tensorflow_mri.python.util import check_util
from tensorflow_mri.python.util import deprecation
from tensorflow_mri.python.losses import iqa_losses


@api_util.export("image.psnr")
Expand Down Expand Up @@ -1025,6 +1026,85 @@ def _gradient_operators(method, norm=False, rank=2, dtype=tf.float32):
kernels[d] *= operator_1d
return tf.stack(kernels, axis=0)

@tf.keras.utils.register_keras_serializable(package="MRI")
class MeanAbsoluteGradientError(iqa_losses.LossFunctionWrapperIQA):
def __init__(self,
method='sobel',
norm=False,
batch_dims=None,
image_dims=None,
multichannel=True,
complex_part=None,
reduction=tf.keras.losses.Reduction.AUTO,
name='mean_absolute_gradient_error'):
super().__init__(mean_absolute_gradient_error,
reduction=reduction, name=name, method=method,
norm=norm, batch_dims=batch_dims, image_dims=image_dims,
multichannel=multichannel, complex_part=complex_part)


@tf.keras.utils.register_keras_serializable(package="MRI")
class MeanSquaredGradientError(iqa_losses.LossFunctionWrapperIQA):
def __init__(self,
method='sobel',
norm=False,
batch_dims=None,
image_dims=None,
multichannel=True,
complex_part=None,
reduction=tf.keras.losses.Reduction.AUTO,
name='mean_squared_gradient_error'):
super().__init__(mean_squared_gradient_error,
reduction=reduction, name=name, method=method,
norm=norm, batch_dims=batch_dims, image_dims=image_dims,
multichannel=multichannel, complex_part=complex_part)

@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_absolute_error(y_true, y_pred):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)
return tf.math.reduce_mean(tf.math.abs(y_pred - y_true), axis=-1)

@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_squared_error(y_true, y_pred):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)
return tf.math.reduce_mean(
tf.math.real(tf.math.squared_difference(y_pred, y_true)), axis=-1)


@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_absolute_gradient_error(y_true, y_pred, method='sobel',
norm=False, batch_dims=None, image_dims=None):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

grad_true = image_gradients(
y_true, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)
grad_pred = image_gradients(
y_pred, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)

return mean_absolute_error(grad_true, grad_pred)


@tf.keras.utils.register_keras_serializable(package="MRI")
def mean_squared_gradient_error(y_true, y_pred, method='sobel',
norm=False, batch_dims=None, image_dims=None):
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

grad_true = image_gradients(
y_true, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)
grad_pred = image_gradients(
y_pred, method=method, norm=norm,
batch_dims=batch_dims, image_dims=image_dims)

return mean_squared_error(grad_true, grad_pred)



def _filter_image(image, kernels):
"""Filters an image using the specified kernels.
Expand Down
Binary file removed tests/.DS_Store
Binary file not shown.
Binary file removed tools/.DS_Store
Binary file not shown.
32 changes: 0 additions & 32 deletions tools/docs/guide/contribute.ipynb

This file was deleted.

101 changes: 101 additions & 0 deletions tools/docs/guide/fft.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fast Fourier transform (FFT)\n",
"\n",
"TensorFlow MRI uses the built-in FFT ops in core TensorFlow. These are [`tf.signal.fft`](https://www.tensorflow.org/api_docs/python/tf/signal/fft), [`tf.signal.fft2d`](https://www.tensorflow.org/api_docs/python/tf/signal/fft2d) and [`tf.signal.fft3d`](https://www.tensorflow.org/api_docs/python/tf/signal/fft3d).\n",
"\n",
"## N-dimensional FFT\n",
"\n",
"For convenience, TensorFlow MRI also provides [`tfmri.signal.fft`](https://mrphys.github.io/tensorflow-mri/api_docs/tfmri/signal/fft/), which can be used for N-dimensional FFT calculations and provides convenient access to commonly used functionality such as padding/cropping, normalization and shifting of the zero-frequency component within the same function call.\n",
"\n",
"## Custom FFT kernels for CPU\n",
"\n",
"Unfortunately, TensorFlow's FFT ops are [known to be slow](https://github.com/tensorflow/tensorflow/issues/6541) on CPU. As a result, the FFT can become a significant bottleneck on MRI processing pipelines, especially on iterative reconstructions where the FFT is called repeatedly.\n",
"\n",
"To address this issue, TensorFlow MRI provides a set of custom FFT kernels based on the FFTW library. These offer a significant boost in performance compared to the kernels in core TensorFlow.\n",
"\n",
"The custom FFT kernels are automatically registered to the TensorFlow framework when importing TensorFlow MRI. If you have imported TensorFlow MRI, then the standard FFT ops will use the optimized kernels automatically.\n",
"\n",
"```{tip}\n",
"You only need to `import tensorflow_mri` in order to use the custom FFT kernels. You can then access them as usual through `tf.signal.fft`, `tf.signal.fft2d` and `tf.signal.fft3d`.\n",
"```\n",
"\n",
"The only caveat is that the [FFTW license](https://www.fftw.org/doc/License-and-Copyright.html) is more restrictive than the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0) used by TensorFlow MRI. In particular, GNU GPL requires you to distribute any derivative software under equivalent terms.\n",
"\n",
"```{warning}\n",
"If you intend to use custom FFT kernels for commercial purposes, you will need to purchase a commercial FFTW license.\n",
"```\n",
"\n",
"### Disable the use of custom FFT kernels\n",
"\n",
"You can control whether custom FFT kernels are used via the `TFMRI_USE_CUSTOM_FFT` environment variable. When set to false, TensorFlow MRI will not register its custom FFT kernels, falling back to the standard FFT kernels in core TensorFlow. If the variable is unset, its value defaults to true.\n",
"\n",
"````{tip}\n",
"Set `TFMRI_USE_CUSTOM_FFT=0` to disable the custom FFT kernels.\n",
"\n",
"```python\n",
"os.environ[\"TFMRI_USE_CUSTOM_FFT\"] = \"0\"\n",
"import tensorflow_mri as tfmri\n",
"```\n",
"\n",
"```{attention}\n",
"`TFMRI_USE_CUSTOM_FFT` must be set **before** importing TensorFlow MRI. Setting or changing its value after importing the package will have no effect.\n",
"```\n",
"````\n",
"\n",
"### Customize the behavior of custom FFT kernels\n",
"\n",
"FFTW allows you to control the rigor of the planning process. The more rigorously a plan is created, the more efficient the actual FFT execution is likely to be, at the expense of a longer planning time. TensorFlow MRI lets you control the FFTW planning rigor through the `TFMRI_FFTW_PLANNING_RIGOR` environment variable. Valid values for this variable are:\n",
"\n",
"- `\"estimate\"` specifies that, instead of actual measurements of different algorithms, a simple heuristic is used to pick a (probably sub-optimal) plan quickly.\n",
"- `\"measure\"` tells FFTW to find an optimized plan by actually computing several FFTs and measuring their execution time. Depending on your machine, this can take some time (often a few seconds). This is the default planning option.\n",
"- `\"patient\"` is like `\"measure\"`, but considers a wider range of algorithms and often produces a “more optimal” plan (especially for large transforms), but at the expense of several times longer planning time (especially for large transforms).\n",
"- `\"exhaustive\"` is like `\"patient\"`, but considers an even wider range of algorithms, including many that we think are unlikely to be fast, to produce the most optimal plan but with a substantially increased planning time.\n",
"\n",
"````{tip}\n",
"Set the environment variable `TFMRI_FFTW_PLANNING_RIGOR` to control the planning rigor.\n",
"\n",
"```python\n",
"os.environ[\"TFMRI_FFTW_PLANNING_RIGOR\"] = \"estimate\"\n",
"import tensorflow_mri as tfmri\n",
"```\n",
"\n",
"```{attention}\n",
"`TFMRI_FFTW_PLANNING_RIGOR` must be set **before** importing TensorFlow MRI. Setting or changing its value after importing the package will have no effect.\n",
"```\n",
"````\n",
"\n",
"```{note}\n",
"FFTW accumulates \"wisdom\" each time the planner is called, and this wisdom is persisted across invocations of the FFT kernels (during the same process). Therefore, more rigorous planning options will result in long planning times during the first FFT invocation, but may result in faster execution during subsequent invocations. When performing a large amount of similar FFT invocations (e.g., while training a model or performing iterative reconstructions), you are more likely to benefit from more rigorous planning.\n",
"```\n",
"\n",
"```{seealso}\n",
"The FFTW [planner flags](https://www.fftw.org/doc/Planner-Flags.html) documentation page.\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.2 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.2"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "0adcc2737ebf6a4a119f135174df96668767fca1ef1112612db5ecadf2b6d608"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading
Loading