From 644c016eed9abbc6ba75b28f516765efde63192e Mon Sep 17 00:00:00 2001 From: Ming Du Date: Thu, 13 Nov 2025 09:16:39 -0600 Subject: [PATCH 1/3] FIX: fix device assignment in `vignette` --- src/ptychi/image_proc.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ptychi/image_proc.py b/src/ptychi/image_proc.py index f7a3b02..a735905 100644 --- a/src/ptychi/image_proc.py +++ b/src/ptychi/image_proc.py @@ -1413,12 +1413,14 @@ def generate_vignette_mask( shape: tuple[int, int], margin: int = 20, sigma: float = 1.0, - method: Literal["gaussian", "linear"] = "gaussian" + method: Literal["gaussian", "linear"] = "gaussian", + device: Optional[torch.device] = None, ): """ Generate a vignette mask for an image of shape `shape`. """ - mask = torch.ones(shape, device=torch.get_default_device()) + device = device or torch.get_default_device() + mask = torch.ones(shape, device=device) mask = vignette(mask, margin, sigma, method=method) return mask @@ -1469,7 +1471,7 @@ def vignette( mask = torch.zeros(mask_shape, device=img.device) mask_slicer = [slice(None)] * i_dim + [slice(margin, None)] mask[*mask_slicer] = 1.0 - gauss_win = torch.signal.windows.gaussian(margin // 2, std=sigma) + gauss_win = torch.signal.windows.gaussian(margin // 2, std=sigma, device=img.device) gauss_win = gauss_win / torch.sum(gauss_win) mask = convolve1d(mask, gauss_win, dim=i_dim, padding="same") mask_final_slicer = [slice(None)] * i_dim + [slice(len(gauss_win), len(gauss_win) + margin)] From bb5632dd34d3369897af5e8564020ea3b3b63d93 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Thu, 13 Nov 2025 09:38:06 -0600 Subject: [PATCH 2/3] FIX: fix device assignment in `fourier_gradient` --- src/ptychi/image_proc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ptychi/image_proc.py b/src/ptychi/image_proc.py index a735905..72b15ad 100644 --- a/src/ptychi/image_proc.py +++ b/src/ptychi/image_proc.py @@ -803,6 +803,8 @@ def fourier_gradient(image: Tensor) -> Tuple[Tensor, Tensor]: The y and x gradients. """ u, v = torch.fft.fftfreq(image.shape[-2]), torch.fft.fftfreq(image.shape[-1]) + u = u.to(image.device) + v = v.to(image.device) u, v = torch.meshgrid(u, v, indexing="ij") grad_y = torch.fft.ifft(torch.fft.fft(image, dim=-2) * (2j * torch.pi) * u, dim=-2) grad_x = torch.fft.ifft(torch.fft.fft(image, dim=-1) * (2j * torch.pi) * v, dim=-1) From b4e000baa66f3c223b0422d157d3633d30da5167 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Thu, 13 Nov 2025 10:03:03 -0600 Subject: [PATCH 3/3] FIX: fix device assignment in `integrate_image_2d_fourier` --- src/ptychi/image_proc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ptychi/image_proc.py b/src/ptychi/image_proc.py index 72b15ad..a14f637 100644 --- a/src/ptychi/image_proc.py +++ b/src/ptychi/image_proc.py @@ -952,6 +952,8 @@ def integrate_image_2d_fourier(grad_y: Tensor, grad_x: Tensor) -> Tensor: shape = grad_y.shape f = pmath.fft2_precise(grad_x + 1j * grad_y) y, x = torch.fft.fftfreq(shape[0]), torch.fft.fftfreq(shape[1]) + y = y.to(grad_y.device) + x = x.to(grad_y.device) # In PtychoShelves' get_img_int_2D.m, they set the numerator of r to be # exp(2j * pi * (x + y[:, None])) to shift it by 1 pixel. We should NOT