diff --git a/src/ptychi/image_proc.py b/src/ptychi/image_proc.py index f7a3b02..a14f637 100644 --- a/src/ptychi/image_proc.py +++ b/src/ptychi/image_proc.py @@ -803,6 +803,8 @@ def fourier_gradient(image: Tensor) -> Tuple[Tensor, Tensor]: The y and x gradients. """ u, v = torch.fft.fftfreq(image.shape[-2]), torch.fft.fftfreq(image.shape[-1]) + u = u.to(image.device) + v = v.to(image.device) u, v = torch.meshgrid(u, v, indexing="ij") grad_y = torch.fft.ifft(torch.fft.fft(image, dim=-2) * (2j * torch.pi) * u, dim=-2) grad_x = torch.fft.ifft(torch.fft.fft(image, dim=-1) * (2j * torch.pi) * v, dim=-1) @@ -950,6 +952,8 @@ def integrate_image_2d_fourier(grad_y: Tensor, grad_x: Tensor) -> Tensor: shape = grad_y.shape f = pmath.fft2_precise(grad_x + 1j * grad_y) y, x = torch.fft.fftfreq(shape[0]), torch.fft.fftfreq(shape[1]) + y = y.to(grad_y.device) + x = x.to(grad_y.device) # In PtychoShelves' get_img_int_2D.m, they set the numerator of r to be # exp(2j * pi * (x + y[:, None])) to shift it by 1 pixel. We should NOT @@ -1413,12 +1417,14 @@ def generate_vignette_mask( shape: tuple[int, int], margin: int = 20, sigma: float = 1.0, - method: Literal["gaussian", "linear"] = "gaussian" + method: Literal["gaussian", "linear"] = "gaussian", + device: Optional[torch.device] = None, ): """ Generate a vignette mask for an image of shape `shape`. """ - mask = torch.ones(shape, device=torch.get_default_device()) + device = device or torch.get_default_device() + mask = torch.ones(shape, device=device) mask = vignette(mask, margin, sigma, method=method) return mask @@ -1469,7 +1475,7 @@ def vignette( mask = torch.zeros(mask_shape, device=img.device) mask_slicer = [slice(None)] * i_dim + [slice(margin, None)] mask[*mask_slicer] = 1.0 - gauss_win = torch.signal.windows.gaussian(margin // 2, std=sigma) + gauss_win = torch.signal.windows.gaussian(margin // 2, std=sigma, device=img.device) gauss_win = gauss_win / torch.sum(gauss_win) mask = convolve1d(mask, gauss_win, dim=i_dim, padding="same") mask_final_slicer = [slice(None)] * i_dim + [slice(len(gauss_win), len(gauss_win) + margin)]