Skip to content

Conversation

@DSilva27
Copy link
Collaborator

No description provided.

@DSilva27 DSilva27 marked this pull request as draft December 30, 2025 19:53
@DSilva27 DSilva27 marked this pull request as ready for review January 5, 2026 17:25
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call these PhaseShiftFFT and RotateFFT. I know we currently have "fourier_image" as local variable names all over the code, but as we've started to have function / class names operating on FFTs, I'm slowly moving to real-space as image or fourier-space as image_fft (or just fft when appropriate). See the rescale_fft function in cryojax.ndimage.

)
self.is_rfft = is_rfft
if is_rfft:
self.shape = (frequency_grid.shape[0], (frequency_grid.shape[1] - 1) * 2)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue here is that this is not always what the shape was in real-space and can be off by a pixel depending on parity. This is slightly problematic we don't know if the highest frequency in the x dim is self-conjugate or not. Instead you could just always assume image is square, get rid of self.shape, say shape = (y_dim, y_dim) in __call__, and put a warning in the docstring.

This class implements a phase shift of an image in Fourier space.
The shift is specified in pixels or angstroms along each axis.

Attributes:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can get rid of Attributes in docstring. When adding this to the docs, you can just document translation_operator directly. is_rfft is redundant since it's in __init__.

I know this model is slightly ambiguous at what is an attribute; this approach is more user focused (and attributes are immutable anyway in JAX!)

- `is_rfft`: Whether the frequency grid is in full or rfft format.
Right now only full format is supported for image rotation.

**Example:**
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Examples should go in class docstring, not __init__.

"""
**Arguments:**

- `offset`: The offset by which to shift the image, in pixels or angstroms.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if the extra indent here impacts the markdown. Can you give a shot at adding these transforms to the docs and render things locally to make sure things look right? For instructions, see CONTRIBUTING.md if you need.


def __init__(
self,
angle_degrees: float,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments on __init__ and attributes.

  1. angle_degrees and pixel_size should be attributes of type Float[Array, ""]. This allows them to work with JAX transformations, e.g. RotateFourierImage can be vectorized over them or calls under JIT won't recompile if they change. Additionally, cryojax.jax_util.FloatLike should be used for their __init__ type hint and they should be cast with jnp.asarray. Additionally, frequency_grid can be type hinted with NDArrayLike and cast.
  2. Keyword arguments should follow FourierSliceExtraction naming, i.e. order -> interpolation_order, cval -> fill_value, and mode -> out_of_bounds_mode. FWIW I don't mind getting rid of these arguments, or maybe just keeping interpolation_order. Whatever you find useful, we can always add them later.
  3. You could get rid of pixel_size all together and force that frequency_grid is in nyquist units. Up to you.
  4. Let's keep naming consistent between the angle here and the angle in cryojax.ndimage.CylindricalCosineMask. Unfortunately I've named it in_plane_rotation_angle there, which is a little long and not appropriate here. Maybe just rotation_angle and I'll change that at some point to the same?

- `is_rfft`: Whether the frequency grid is in full or rfft format.
Right now only full format is supported for image rotation.
- `pixel_size`: The pixel size in angstroms or the unit of choice.
- `order`: The order of the spline interpolation. Only 0 and 1 are supported.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more comments:

  • "Spline" specifically refers to cubic interpolation
  • On boolean docstrings:
    • is_rfft: If True, the frequency grid is X. Otherwise, Y.
  • You can add reference to map_coordinates for the interpolation options:
    • order: ... See [cryojax.ndimage.map_coordinates][] for more information.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

damm, copilot got me on this one. I didn't see it wrote spline, I only read interpolation

) -> Complex[Array, "y_dim x_dim"]:
"""Apply the phase shift to the input image in Fourier space.

Args:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like some old docstrings in the __call__ methods

s = jnp.sin(angle)
rotation_matrix = jnp.array([[c, -s], [s, c]])

rotated_grid = frequency_grid @ rotation_matrix
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are we sure we don't need to run rotated_grid = jnp.reshape(frequency_grid.reshape((N, 2)) @ rotation_matrix, frequency_grid.shape[0:2]), where N is math.prod(frequency_grid.shape[0:2])?

logical_frequency_coordinates = (rotated_grid * N) + N // 2
k_y, k_x = jnp.transpose(logical_frequency_coordinates, axes=[2, 0, 1])

image_fft = jnp.fft.fftshift(fourier_image)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your fftshifts look buggy to me. It seems you only shift the image, not the grid, and you don't shift back to the original convention.

  • If your FFT is the full fourier plane, before map_coordinates you need jnp.fft.fftshift(frequency_grid, axes=(0, 1)) as well as your jnp.fft.fftshift(image_fft). This puts zero frequency in the center so that out of bounds indexing in the interpolation is correct. Then, you just need to return jnp.fft.ifftshift(rotated_fft).
  • If your FFT is the half plane (i.e. rfft), instead just run jnp.fft.fftshift(image_fft, axes=(0,)) and jnp.fft.fftshift(frequency_grid, axes=(0,)) before and jnp.fft.ifftshift(rotated_fft, axes=(0,)) after.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants