Merged
51 changes: 51 additions & 0 deletions docs/source/using_pty_chi/devices.rst
@@ -27,3 +27,54 @@ and `ptychi`::

torch.set_default_device("xpu")
ptychi.device.set_torch_accelerator_module(torch.xpu)

Multi-GPU and multi-processing
------------------------------

Some reconstruction engines support multi-GPU and/or multi-processing. The main benefit
of using multiple GPUs or processes is to split the computation of the update vectors across
devices, which reduces the VRAM usage on each device. Note that multi-processing does not
always make the computation faster: it incurs communication overhead, which pays off only
when the data size is large.

Multi-GPU engines
+++++++++++++++++

The automatic differentiation (Autodiff) engine supports multi-GPU execution through PyTorch's
``DataParallel`` wrapper. By default, the reconstructor uses all available GPUs; no additional
settings are needed. To limit it to a single GPU, set ``CUDA_VISIBLE_DEVICES`` before
launching the reconstruction job::

export CUDA_VISIBLE_DEVICES=0
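
To verify how many devices are actually visible to the process (for example, after setting
``CUDA_VISIBLE_DEVICES``), a minimal check using only PyTorch is::

    import torch

    # ``DataParallel`` will use every device reported here.
    print(torch.cuda.device_count())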

Multi-processing engines
++++++++++++++++++++++++

Some analytical engines (currently only LSQML) support multi-GPU execution through
PyTorch's multi-processing facilities in ``torch.distributed``. To enable
multi-processing, launch the reconstruction script with ``torchrun``::

torchrun --nnodes=1 --nproc_per_node=2 reconstruction_script.py

The ``--nnodes`` and ``--nproc_per_node`` arguments specify the number of nodes and
the number of processes per node, respectively. On a single-node machine, keep ``--nnodes`` at 1.
When a job is launched this way, Pty-Chi assigns each rank to the GPU indexed
``rank % n_gpus``, where ``n_gpus`` is the number of available GPUs, so that all GPUs are
used while the number of ranks per GPU is kept to a minimum. It is not recommended, and
in some cases not allowed, to launch more processes than there are GPUs.
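
As an illustration, the rank-to-device mapping described above corresponds roughly to the
following sketch (not the actual Pty-Chi implementation; it assumes the default process
group has already been initialized)::

    import torch
    import torch.distributed as dist

    n_gpus = torch.cuda.device_count()
    rank = dist.get_rank()
    # Each rank is pinned to one GPU; ranks wrap around when there are more
    # ranks than GPUs, which is discouraged and disallowed with the NCCL backend.
    device = torch.device(f"cuda:{rank % n_gpus}")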

``torchrun`` spawns all processes at startup, so the reconstruction script is executed by
every process. If the script contains post-analysis or data-saving routines, make sure they
do not produce unexpected results when run in multiple processes. It is generally advised
to execute such routines only on rank 0::

import torch.distributed as dist

# Set up and run task

if dist.get_rank() == 0:
# Do post-analysis or data saving

``dist.get_rank()`` is only callable after the task object has been instantiated, since
that is where the process group is initialized.
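
If the same script may also be run without ``torchrun``, one safe pattern is to check whether
the process group has been initialized before querying the rank (a minimal sketch;
``save_results`` is a hypothetical post-processing helper)::

    import torch.distributed as dist

    # Treat the process as the main rank when running without torchrun as well.
    is_main_rank = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_main_rank:
        save_results()  # hypothetical data-saving / post-analysis routine
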
2 changes: 1 addition & 1 deletion docs/source/using_pty_chi/engines.rst
@@ -19,7 +19,7 @@ The table below compares their merits and limitations.
- No
- Yes
* - GPU support
- Single
- Multiple (multi-process)
- Single
- Single
- Multiple
39 changes: 34 additions & 5 deletions src/ptychi/api/task.py
@@ -27,11 +27,12 @@
from ptychi.timing import timer_utils
import ptychi.movies as movies
from ptychi.device import AcceleratorModuleWrapper
from ptychi.parallel import MultiprocessMixin

logger = logging.getLogger(__name__)


class Task:
class Task(MultiprocessMixin):
def __init__(self, options: api.options.base.TaskOptions, *args, **kwargs) -> None:
pass

@@ -86,6 +87,7 @@ def build(self):
self.build_random_seed()
self.build_default_device()
self.build_default_dtype()
self.build_logger()
self.build_data()
self.build_object()
self.build_probe()
@@ -103,7 +105,25 @@ def build_random_seed(self):
def build_default_device(self):
accelerator_module = AcceleratorModuleWrapper.get_module()

torch.set_default_device(maps.get_device_by_enum(self.reconstructor_options.default_device))
if self.detect_launcher() is None:
torch.set_default_device(maps.get_device_by_enum(self.reconstructor_options.default_device))
else:
self.init_process_group()

if self.backend == "nccl" and self.n_ranks > accelerator_module.device_count():
raise ValueError(
f"Number of ranks ({self.n_ranks}) is greater than the number of devices "
f"({accelerator_module.device_count()}). This is not allowed with NCCL backend."
)

if self.n_ranks == 1:
torch.set_default_device(maps.get_device_by_enum(self.reconstructor_options.default_device))
else:
logging.info(f"Multi-processing mode detected with {self.n_ranks} ranks.")
torch.set_default_device(
f"{AcceleratorModuleWrapper.get_to_device_string()}:{self.rank % accelerator_module.device_count()}"
)

if accelerator_module.device_count() > 0:
cuda_visible_devices_str = "(unset)"
if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
@@ -117,6 +137,10 @@ def build_default_device(self):
else:
logger.info("Using device: {}".format(torch.get_default_device()))

def build_logger(self):
if self.rank != 0:
logger.setLevel(level=logging.ERROR)

def build_default_dtype(self):
torch.set_default_dtype(maps.get_dtype_by_enum(self.reconstructor_options.default_dtype))
utils.set_default_complex_dtype(
@@ -205,9 +229,14 @@ def build_reconstructor(self):
opr_mode_weights=self.opr_mode_weights,
)

reconstructor_class = maps.get_reconstructor_by_enum(
self.reconstructor_options.get_reconstructor_type()
)
if self.n_ranks == 1:
reconstructor_class = maps.get_reconstructor_by_enum(
self.reconstructor_options.get_reconstructor_type()
)
else:
reconstructor_class = maps.get_multiprocess_reconstructor_by_enum(
self.reconstructor_options.get_reconstructor_type()
)

reconstructor_kwargs = {
"parameter_group": par_group,
2 changes: 2 additions & 0 deletions src/ptychi/api/types.py
@@ -6,6 +6,8 @@
import torch


Numeric: TypeAlias = int | float | complex

BooleanArray: TypeAlias = numpy.typing.NDArray[numpy.bool_]
IntegerArray: TypeAlias = numpy.typing.NDArray[numpy.integer[Any]]
RealArray: TypeAlias = numpy.typing.NDArray[numpy.floating[Any]]
35 changes: 27 additions & 8 deletions src/ptychi/data_structures/base.py
@@ -1,7 +1,7 @@
# Copyright © 2025 UChicago Argonne, LLC All right reserved
# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE

from typing import Optional, Union, Tuple, Sequence, TYPE_CHECKING
from typing import Optional, Union, Tuple, Sequence, Literal, TYPE_CHECKING
import logging

import torch
@@ -67,15 +67,18 @@ def complex(self) -> Tensor:
def shape(self) -> Tuple[int, ...]:
return self.data.shape[:-1]

def set_data(self, data: Union[Tensor, ndarray], slicer=None):
def set_data(self, data: Union[Tensor, ndarray], slicer=None, op: Literal["add", "set"] = "set"):
if slicer is None:
slicer = (slice(None),)
elif not isinstance(slicer, Sequence):
slicer = (slicer,)
data = to_tensor(data)
data = torch.stack([data.real, data.imag], dim=-1)
data = data.type(torch.get_default_dtype())
self.data[*slicer].copy_(to_tensor(data))
if op == "add":
self.data[*slicer].copy_(self.data[*slicer] + to_tensor(data))
else:
self.data[*slicer].copy_(to_tensor(data))


class ReconstructParameter(Module):
@@ -234,16 +237,22 @@ def get_tensor(self, name):
return var

def set_data(
self, data, slicer: Optional[Union[slice, int] | tuple[Union[slice, int], ...]] = None
self,
data,
slicer: Optional[Union[slice, int] | tuple[Union[slice, int], ...]] = None,
op: Literal["add", "set"] = "set",
):
if slicer is None:
slicer = (slice(None),)
elif not isinstance(slicer, Sequence):
slicer = (slicer,)
if isinstance(self.tensor, ComplexTensor):
self.tensor.set_data(data, slicer=slicer)
self.tensor.set_data(data, slicer=slicer, op=op)
else:
self.tensor[*slicer].copy_(to_tensor(data))
if op == "add":
self.tensor[*slicer].copy_(self.data + to_tensor(data))
else:
self.tensor[*slicer].copy_(to_tensor(data))

def get_grad(self):
if isinstance(self.tensor, ComplexTensor):
@@ -255,6 +264,7 @@ def set_grad(
self,
grad: Tensor,
slicer: Optional[Union[slice, int] | tuple[Union[slice, int], ...]] = None,
op: Literal["add", "set"] = "set",
):
"""
Populate the `grad` field of the contained tensor, so that it can be optimized
@@ -272,6 +282,9 @@ A tuple of, or a single slice object or integer, that defines the region of
A tuple of, or a single slice object or integer, that defines the region of
the gradient to update. The shape of `grad` should match
the region given by `slicer`, if given. If None, the whole gradient is updated.
op : Literal["add", "set"]
The operation to perform on the gradient. If "add", the gradient is added to the existing gradient.
If "set", the gradient is set to the given value.
"""
if self.tensor.data.grad is None and slicer is not None:
raise ValueError("Setting gradient with slicing is not allowed when gradient is None.")
@@ -286,12 +299,18 @@ def set_grad(
if self.tensor.data.grad is None:
self.tensor.data.grad = grad
else:
self.tensor.data.grad[*slicer, ..., :] = grad
if op == "add":
self.tensor.data.grad[*slicer, ..., :] += grad
else:
self.tensor.data.grad[*slicer, ..., :] = grad
else:
if self.tensor.grad is None:
self.tensor.grad = grad
else:
self.tensor.grad[*slicer] = grad
if op == "add":
self.tensor.grad[*slicer] += grad
else:
self.tensor.grad[*slicer] = grad

def initialize_grad(self):
"""
54 changes: 41 additions & 13 deletions src/ptychi/data_structures/opr_mode_weights.py
@@ -104,7 +104,17 @@ def update_variable_probe(
obj_patches: Tensor,
current_epoch: int,
probe_mode_index: Optional[int] = None,
apply_updates: bool = True,
):
"""Update the OPR mode weights and eigenmodes of the probe.

Parameters
----------
apply_updates : bool
If True, the data of the probe and OPR weights are modified with the
update vectors. Otherwise, the update vectors will be saved in the
``grad`` attribute of the probe and OPR weight objects.
"""
# TODO: OPR updates are calculated using chi with uniquely shifted probes,
# probe without shift, and probe updates that are adjoint-shifted. This is
# not accurate, but PtychoShelves does the same. We might need to revisit this.
@@ -118,7 +128,14 @@
or (self.eigenmode_weight_optimization_enabled(current_epoch))
):
self.update_opr_probe_modes_and_weights(
probe, indices, chi, delta_p_i, delta_p_hat, obj_patches, current_epoch
probe,
indices,
chi,
delta_p_i,
delta_p_hat,
obj_patches,
current_epoch,
apply_updates=apply_updates,
)

if self.intensity_variation_optimization_enabled(current_epoch):
@@ -128,7 +145,10 @@
chi,
obj_patches,
)
self._apply_variable_intensity_updates(delta_weights_int)
if apply_updates:
self.set_data(self.data + 0.1 * delta_weights_int)
else:
self.set_grad(-0.1 * delta_weights_int, op="add")

@timer()
def update_opr_probe_modes_and_weights(
@@ -140,12 +160,20 @@
delta_p_hat: Tensor,
obj_patches: Tensor,
current_epoch: int,
apply_updates: bool = True,
):
"""
Update the eigenmodes of the first incoherent mode of the probe, and update the OPR mode weights.

This implementation is adapted from PtychoShelves code (update_variable_probe.m) and has some
differences from Eq. 31 of Odstrcil (2018).

Parameters
----------
apply_updates : bool
If True, the data of the probe and OPR weights are modified with the
update vectors. Otherwise, the update vectors will be saved in the
``grad`` attribute of the probe and OPR weight objects.
"""
probe_data = probe.data
weights_data = self.data
@@ -172,7 +200,7 @@ def update_opr_probe_modes_and_weights(
# Just take the first incoherent mode.
eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode)
weights_i = self.get_weights(indices)[:, i_opr_mode]
eigenmode_i, weights_i = self._update_first_eigenmode_and_weight(
eigenmode_i, weights_i = self._calculate_updated_first_eigenmode_and_weight(
residue_update,
eigenmode_i,
weights_i,
@@ -193,13 +221,18 @@ def update_opr_probe_modes_and_weights(
probe_data[i_opr_mode, 0, :, :] = eigenmode_i
weights_data[indices, i_opr_mode] = weights_i

if probe.optimization_enabled(current_epoch):
probe.set_data(probe_data)
if self.eigenmode_weight_optimization_enabled(current_epoch):
self.set_data(weights_data)
if apply_updates:
if probe.optimization_enabled(current_epoch):
probe.set_data(probe_data)
if self.eigenmode_weight_optimization_enabled(current_epoch):
self.set_data(weights_data)
else:
# Gradient is the negative of the update vector.
probe.set_grad(probe.data - probe_data, op="add")
self.set_grad(self.data - weights_data, op="add")

@timer()
def _update_first_eigenmode_and_weight(
def _calculate_updated_first_eigenmode_and_weight(
self,
residue_update: Tensor,
eigenmode_i: Tensor,
@@ -208,7 +241,6 @@ def _update_first_eigenmode_and_weight(
relax_v: Tensor,
obj_patches: Tensor,
chi: Tensor,
eps=1e-5,
update_eigenmode=True,
update_weights=True,
):
@@ -276,10 +308,6 @@ def _calculate_intensity_variation_update_direction(
delta_weights_int[indices, 0] = delta_weights_int_i
return delta_weights_int

@timer()
def _apply_variable_intensity_updates(self, delta_weights_int: Tensor):
self.set_data(self.data + 0.1 * delta_weights_int)

@timer()
def smooth_weights(self):
"""