Implement image translation and rotation operators #527
base: dev
Let's call these PhaseShiftFFT and RotateFFT. I know we currently have "fourier_image" as a local variable name all over the code, but now that we have function and class names that operate on FFTs, I'm slowly moving to image for real-space arrays and image_fft for Fourier-space arrays (or just fft when appropriate). See the rescale_fft function in cryojax.ndimage.
| ) | ||
| self.is_rfft = is_rfft | ||
| if is_rfft: | ||
| self.shape = (frequency_grid.shape[0], (frequency_grid.shape[1] - 1) * 2) |
The issue here is that this is not always the real-space shape: it can be off by a pixel depending on parity. This is slightly problematic because we don't know whether the highest frequency in the x dimension is self-conjugate. Instead, you could always assume the image is square, get rid of self.shape, set shape = (y_dim, y_dim) in __call__, and put a warning in the docstring.
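The parity ambiguity can be seen in a minimal sketch using plain jax.numpy (the array sizes and the guessed_shape helper are illustrative assumptions, not code from this PR): an even-width and an odd-width image produce rfft arrays of the same shape, so the real-space width cannot be recovered from the rfft shape alone.

```python
import jax.numpy as jnp

# Two real-space images with the same height but even vs. odd width.
even_image = jnp.zeros((8, 8))
odd_image = jnp.zeros((8, 9))

# rfft2 keeps only the non-negative frequencies along the last axis,
# so both widths give x_dim // 2 + 1 = 5 columns.
even_rfft_shape = jnp.fft.rfft2(even_image).shape
odd_rfft_shape = jnp.fft.rfft2(odd_image).shape


def guessed_shape(rfft_shape):
    # Hypothetical helper mirroring the reconstruction in the diff,
    # which implicitly assumes an even real-space width.
    return (rfft_shape[0], (rfft_shape[1] - 1) * 2)


print(even_rfft_shape, odd_rfft_shape)
print(guessed_shape(odd_rfft_shape), "vs. true shape", odd_image.shape)
```

For the odd-width image the guess is one pixel short, which is the off-by-one case the comment warns about.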
| This class implements a phase shift of an image in Fourier space. | ||
| The shift is specified in pixels or angstroms along each axis. | ||
|
|
||
| Attributes: |
Can get rid of Attributes in docstring. When adding this to the docs, you can just document translation_operator directly. is_rfft is redundant since it's in __init__.
I know this model is slightly ambiguous about what counts as an attribute; this approach is more user-focused (and attributes are immutable anyway in JAX!)
| - `is_rfft`: Whether the frequency grid is in full or rfft format. | ||
| Right now only full format is supported for image rotation. | ||
|
|
||
| **Example:** |
Examples should go in class docstring, not __init__.
| """ | ||
| **Arguments:** | ||
|
|
||
| - `offset`: The offset by which to shift the image, in pixels or angstroms. |
Not sure if the extra indent here impacts the markdown. Can you take a shot at adding these transforms to the docs and rendering them locally to make sure things look right? For instructions, see CONTRIBUTING.md.
|
|
||
| def __init__( | ||
| self, | ||
| angle_degrees: float, |
A few comments on __init__ and attributes.
- angle_degrees and pixel_size should be attributes of type Float[Array, ""]. This allows them to work with JAX transformations, e.g. RotateFourierImage can be vectorized over them, and calls under JIT won't recompile if they change. Additionally, cryojax.jax_util.FloatLike should be used for their __init__ type hint and they should be cast with jnp.asarray. Likewise, frequency_grid can be type hinted with NDArrayLike and cast.
- Keyword arguments should follow FourierSliceExtraction naming, i.e. order -> interpolation_order, cval -> fill_value, and mode -> out_of_bounds_mode. FWIW I don't mind getting rid of these arguments, or maybe just keeping interpolation_order. Whatever you find useful, we can always add them later.
- You could get rid of pixel_size altogether and force that frequency_grid is in nyquist units. Up to you.
- Let's keep naming consistent between the angle here and the angle in cryojax.ndimage.CylindricalCosineMask. Unfortunately I've named it in_plane_rotation_angle there, which is a little long and not appropriate here. Maybe just rotation_angle, and I'll change that at some point to the same?
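The point about scalar arrays and JAX transformations can be illustrated with a small sketch in plain JAX. It does not use the cryojax classes; rotation_matrix, rotate_static, rotate_array, and the trace counter are hypothetical names introduced only for this illustration. A value held static is baked into the compiled function, so each new value retraces, while a scalar array is a traced input and can also be batched with vmap.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_counts = {"static": 0, "array": 0}


def rotation_matrix(rotation_angle):
    # Hypothetical helper: in-plane angle in degrees -> 2x2 rotation matrix.
    angle = jnp.deg2rad(rotation_angle)
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s], [s, c]])


@partial(jax.jit, static_argnames="rotation_angle")
def rotate_static(points, rotation_angle):
    trace_counts["static"] += 1  # Python side effect: runs once per trace
    return points @ rotation_matrix(rotation_angle)


@jax.jit
def rotate_array(points, rotation_angle):
    trace_counts["array"] += 1  # Python side effect: runs once per trace
    return points @ rotation_matrix(rotation_angle)


points = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# A Python float held static triggers a new trace for each new value.
rotate_static(points, rotation_angle=30.0)
rotate_static(points, rotation_angle=45.0)

# A scalar array is a traced input: one trace serves both values.
rotate_array(points, jnp.asarray(30.0))
rotate_array(points, jnp.asarray(45.0))

# Scalar arrays can also be vectorized over with vmap.
angles = jnp.asarray([0.0, 90.0, 180.0])
batched_matrices = jax.vmap(rotation_matrix)(angles)  # shape (3, 2, 2)

print(trace_counts)
```

Here static_argnames only emulates a value being treated as static; how the actual classes in this PR are compiled is not shown in the thread.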
| - `is_rfft`: Whether the frequency grid is in full or rfft format. | ||
| Right now only full format is supported for image rotation. | ||
| - `pixel_size`: The pixel size in angstroms or the unit of choice. | ||
| - `order`: The order of the spline interpolation. Only 0 and 1 are supported. |
A few more comments:
- "Spline" specifically refers to cubic interpolation
- On boolean docstrings: is_rfft: If True, the frequency grid is X. Otherwise, Y.
- You can add a reference to map_coordinates for the interpolation options: order: ... See [cryojax.ndimage.map_coordinates][] for more information.
Damn, Copilot got me on this one. I didn't see that it wrote "spline"; I only read "interpolation".
| ) -> Complex[Array, "y_dim x_dim"]: | ||
| """Apply the phase shift to the input image in Fourier space. | ||
|
|
||
| Args: |
Seems like some old docstrings in the __call__ methods
| s = jnp.sin(angle) | ||
| rotation_matrix = jnp.array([[c, -s], [s, c]]) | ||
|
|
||
| rotated_grid = frequency_grid @ rotation_matrix |
Are we sure we don't need to run rotated_grid = jnp.reshape(frequency_grid.reshape((N, 2)) @ rotation_matrix, frequency_grid.shape[0:2]), where N is math.prod(frequency_grid.shape[0:2])?
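One way to check this question is a quick comparison of the two forms. The sketch below assumes a frequency grid of shape (y_dim, x_dim, 2) built from jnp.fft.fftfreq (an assumption about the layout, not taken from the PR), and reshapes the flattened product back to the full (y_dim, x_dim, 2) shape, since the product holds N * 2 values.

```python
import math

import jax.numpy as jnp

y_dim, x_dim = 4, 6

# Assumed layout: last axis holds the two frequency coordinates.
k_y = jnp.fft.fftfreq(y_dim)
k_x = jnp.fft.fftfreq(x_dim)
frequency_grid = jnp.stack(jnp.meshgrid(k_y, k_x, indexing="ij"), axis=-1)

angle = jnp.deg2rad(30.0)
c, s = jnp.cos(angle), jnp.sin(angle)
rotation_matrix = jnp.array([[c, -s], [s, c]])

# Direct form: matmul broadcasts over the leading (y_dim, x_dim) axes.
rotated_direct = frequency_grid @ rotation_matrix

# Explicit form: flatten to (N, 2), multiply, then restore the grid shape.
N = math.prod(frequency_grid.shape[0:2])
rotated_flat = frequency_grid.reshape((N, 2)) @ rotation_matrix
rotated_reshaped = jnp.reshape(rotated_flat, frequency_grid.shape)

print(jnp.allclose(rotated_direct, rotated_reshaped))
```

Under this layout the two results agree, because matrix multiplication treats the leading axes as batch dimensions.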
| logical_frequency_coordinates = (rotated_grid * N) + N // 2 | ||
| k_y, k_x = jnp.transpose(logical_frequency_coordinates, axes=[2, 0, 1]) | ||
|
|
||
| image_fft = jnp.fft.fftshift(fourier_image) |
Your fftshifts look buggy to me. It seems you only shift the image, not the grid, and you don't shift back to the original convention.
- If your FFT is the full fourier plane, before map_coordinates you need jnp.fft.fftshift(frequency_grid, axes=(0, 1)) as well as your jnp.fft.fftshift(image_fft). This puts zero frequency in the center so that out of bounds indexing in the interpolation is correct. Then, you just need to return jnp.fft.ifftshift(rotated_fft).
- If your FFT is the half plane (i.e. rfft), instead just run jnp.fft.fftshift(image_fft, axes=(0,)) and jnp.fft.fftshift(frequency_grid, axes=(0,)) before and jnp.fft.ifftshift(rotated_fft, axes=(0,)) after.
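The full-plane recipe above can be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's implementation: it uses jax.scipy.ndimage.map_coordinates rather than the cryojax function, assumes a square image with a frequency grid of shape (N, N, 2) in cycles per pixel, and interpolates real and imaginary parts separately as a conservative choice. The name rotate_full_fft is hypothetical, and the half-plane (rfft) case is not covered.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def rotate_full_fft(image_fft, frequency_grid, rotation_angle):
    """Sketch: rotate a full-plane FFT of a square image by an angle in degrees."""
    n = image_fft.shape[0]
    angle = jnp.deg2rad(rotation_angle)
    c, s = jnp.cos(angle), jnp.sin(angle)
    rotation_matrix = jnp.array([[c, -s], [s, c]])

    # Shift BOTH the image and the grid so zero frequency sits at index n // 2.
    shifted_fft = jnp.fft.fftshift(image_fft)
    shifted_grid = jnp.fft.fftshift(frequency_grid, axes=(0, 1))

    # Rotate the frequencies, then convert them to fractional array indices.
    rotated_grid = shifted_grid @ rotation_matrix
    coords = rotated_grid * n + n // 2
    k_y, k_x = jnp.transpose(coords, axes=[2, 0, 1])

    def interpolate(values):
        # Linear interpolation; samples outside the array are filled with 0.
        return map_coordinates(values, [k_y, k_x], order=1, mode="constant", cval=0.0)

    rotated_fft = interpolate(shifted_fft.real) + 1j * interpolate(shifted_fft.imag)

    # Shift back to the original (zero frequency in the corner) convention.
    return jnp.fft.ifftshift(rotated_fft)


n = 9  # odd size, so fftshift and ifftshift are not interchangeable
image = jax.random.normal(jax.random.PRNGKey(0), (n, n))
freqs = jnp.fft.fftfreq(n)
frequency_grid = jnp.stack(jnp.meshgrid(freqs, freqs, indexing="ij"), axis=-1)
image_fft = jnp.fft.fft2(image)

unrotated = rotate_full_fft(image_fft, frequency_grid, 0.0)
flipped = rotate_full_fft(image_fft, frequency_grid, 180.0)
```

Two sanity checks follow from the conventions: a rotation by 0 degrees should return the input, and for a real image a rotation by 180 degrees should return the complex conjugate, since F(-k) = conj(F(k)).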