README.rst (11 changes: 11 additions & 0 deletions)
@@ -82,6 +82,17 @@ How to run test scripts
repository: ``export PTYCHO_CI_DATA_DIR="path_to_data_repo/ci_data"``.
4. Run any test scripts in ``tests`` with Python.

======================
To use non-NVIDIA GPUs
======================

Pty-Chi also runs on GPUs from vendors other than NVIDIA, such as Intel.
To run Pty-Chi on an Intel GPU, add these lines right after importing ``torch``
and ``ptychi``::

    torch.set_default_device("xpu")
    ptychi.device.set_torch_accelerator_module(torch.xpu)
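
For illustration, a minimal sketch of how a script could start on an Intel GPU
(this assumes a PyTorch build with a working ``torch.xpu`` backend; the
reconstruction calls that would follow are omitted)::

    import torch
    import ptychi

    # Route newly created tensors and Pty-Chi's accelerator calls through
    # the Intel XPU backend instead of torch.cuda.
    torch.set_default_device("xpu")
    ptychi.device.set_torch_accelerator_module(torch.xpu)

    # From here on, internal calls such as empty_cache() and synchronize()
    # go through torch.xpu.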


======================
Reading documentations
docs/source/using_pty_chi/devices.rst (10 changes: 10 additions & 0 deletions)
@@ -17,3 +17,13 @@ Note that it is always recommended to set the variable in terminal before running
If you have to set the variable in the Python code, make sure to set it before importing PyTorch
using ``os.environ["CUDA_VISIBLE_DEVICES"] = "<GPU index>"``. Setting the variable in Python
will not take effect if it is done after PyTorch is imported.
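
For illustration, a minimal sketch of setting the variable from Python before the
import (the index ``"0"`` is only a placeholder)::

    import os

    # Must be set before PyTorch is imported; it has no effect afterwards.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    import torch  # noqa: E402  (deliberately imported after the variable is set)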

Non-NVIDIA GPUs
---------------

Pty-Chi also runs on GPUs from vendors other than NVIDIA, such as Intel.
To run Pty-Chi on an Intel GPU, add these lines right after importing ``torch``
and ``ptychi``::

    torch.set_default_device("xpu")
    ptychi.device.set_torch_accelerator_module(torch.xpu)
src/ptychi/api/task.py (9 changes: 6 additions & 3 deletions)
@@ -26,6 +26,7 @@
import ptychi.maths as pmath
from ptychi.timing import timer_utils
import ptychi.movies as movies
from ptychi.device import AcceleratorModuleWrapper

logger = logging.getLogger(__name__)

@@ -54,7 +55,7 @@ def __exit__(
exception_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
torch.cuda.empty_cache()
AcceleratorModuleWrapper.get_module().empty_cache()


class PtychographyTask(Task):
@@ -100,14 +101,16 @@ def build_random_seed(self):
pmath.set_allow_nondeterministic_algorithms(self.reconstructor_options.allow_nondeterministic_algorithms)

def build_default_device(self):
accelerator_module = AcceleratorModuleWrapper.get_module()

torch.set_default_device(maps.get_device_by_enum(self.reconstructor_options.default_device))
if torch.cuda.device_count() > 0:
if accelerator_module.device_count() > 0:
cuda_visible_devices_str = "(unset)"
if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
cuda_visible_devices_str = os.environ["CUDA_VISIBLE_DEVICES"]
logger.info(
"Using device: {} (CUDA_VISIBLE_DEVICES=\"{}\")".format(
[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
[accelerator_module.get_device_name(i) for i in range(accelerator_module.device_count())],
cuda_visible_devices_str,
)
)
src/ptychi/device.py (72 changes: 61 additions & 11 deletions)
@@ -3,10 +3,36 @@

from collections.abc import Sequence
from dataclasses import dataclass
from types import ModuleType

import torch


_torch_accelerator_module = torch.cuda


class AcceleratorModuleWrapper:
"""A wrapper class for the accelerator device module of PyTorch.
"""

@classmethod
def get_module(cls) -> ModuleType:
return get_torch_accelerator_module()

@classmethod
def set_module(cls, module: ModuleType):
set_torch_accelerator_module(module)

@classmethod
def get_to_device_string(cls) -> str:
if cls.get_module() == torch.cuda:
return "cuda"
elif cls.get_module() == torch.xpu:
return "xpu"
else:
raise ValueError(f"Unsupported accelerator module: {cls.get_module()}")


@dataclass(frozen=True)
class Device:
backend: str
@@ -17,19 +43,43 @@ class Device:
def torch_device(self) -> str:
return f"{self.backend.lower()}:{self.ordinal}"


def list_available_devices() -> Sequence[Device]:
available_devices = list()

if torch.cuda.is_available():
for ordinal in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(ordinal)
device = Device("cuda", ordinal, name)
available_devices.append(device)

if torch.xpu.is_available():
for ordinal in range(torch.xpu.device_count()):
name = torch.xpu.get_device_name(ordinal)
device = Device("xpu", ordinal, name)
accelerator_module_wrapper = AcceleratorModuleWrapper()
accelerator_module = accelerator_module_wrapper.get_module()

if accelerator_module.is_available():
for ordinal in range(accelerator_module.device_count()):
name = accelerator_module.get_device_name(ordinal)
device = Device(accelerator_module_wrapper.get_to_device_string(), ordinal, name)
available_devices.append(device)

return available_devices


def set_torch_accelerator_module(module: ModuleType):
"""Set the global variable of the torch accelerator module.
By default, it is `torch.cuda`.

For Intel GPUs, use `torch.xpu`.

Parameters
----------
module: ModuleType
The torch accelerator module.
"""
global _torch_accelerator_module
_torch_accelerator_module = module


def get_torch_accelerator_module() -> ModuleType:
"""Get the global variable of the torch accelerator module.
By default, it is `torch.cuda`.

Returns
-------
ModuleType
The torch accelerator module.
"""
return _torch_accelerator_module
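
For reference, a short usage sketch of the wrapper and the module-level setter
defined above (assuming an Intel XPU build of PyTorch; the reported devices
depend on the machine)::

    import torch
    from ptychi.device import (
        AcceleratorModuleWrapper,
        set_torch_accelerator_module,
        list_available_devices,
    )

    # Switch the global accelerator module from the default torch.cuda to torch.xpu.
    set_torch_accelerator_module(torch.xpu)

    module = AcceleratorModuleWrapper.get_module()                 # torch.xpu
    device_str = AcceleratorModuleWrapper.get_to_device_string()   # "xpu"

    if module.is_available():
        print(f"{module.device_count()} {device_str} device(s) found:")
        print(list_available_devices())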
src/ptychi/io_handles.py (3 changes: 2 additions & 1 deletion)
@@ -14,6 +14,7 @@
import numpy as np
import scipy.spatial

from ptychi.device import AcceleratorModuleWrapper
from ptychi.utils import to_tensor, to_numpy
import ptychi.maths as pmath

@@ -33,7 +34,7 @@ def __init__(
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.patterns = to_tensor(patterns, device="cpu" if not save_data_on_device else "cuda")
self.patterns = to_tensor(patterns, device="cpu" if not save_data_on_device else AcceleratorModuleWrapper.get_to_device_string())
if fft_shift:
self.patterns = torch.fft.fftshift(self.patterns, dim=(-2, -1))
logger.info("Diffraction data have been FFT-shifted.")
src/ptychi/maps.py (6 changes: 3 additions & 3 deletions)
@@ -2,6 +2,7 @@
# Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE

from typing import Type
from functools import partial

import torch

@@ -12,8 +13,7 @@
from ptychi.reconstructors.base import Reconstructor
import ptychi.image_proc as ip
import ptychi.reconstructors.nn as pnn
from functools import partial

from ptychi.device import AcceleratorModuleWrapper

def get_complex_dtype_by_enum(key: enums.Dtypes) -> torch.dtype:
return {enums.Dtypes.FLOAT32: torch.complex64, enums.Dtypes.FLOAT64: torch.complex128}[key]
@@ -67,7 +67,7 @@ def get_noise_model_by_enum(key: enums.NoiseModels) -> str:


def get_device_by_enum(key: enums.Devices) -> str:
return {enums.Devices.CPU: "cpu", enums.Devices.GPU: "cuda"}[key]
return {enums.Devices.CPU: "cpu", enums.Devices.GPU: AcceleratorModuleWrapper.get_to_device_string()}[key]


def get_dtype_by_enum(key: enums.Dtypes) -> torch.dtype:
src/ptychi/timing/timer_utils.py (11 changes: 6 additions & 5 deletions)
@@ -7,9 +7,10 @@
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
from collections import defaultdict

from ptychi.device import AcceleratorModuleWrapper


# Type variables to retain function signatures
T = TypeVar("T", bound=Callable)
@@ -95,10 +96,10 @@ def wrapper(*args, **kwargs):
overhead_time_1 = time.time() - measure_overhead_start_1

# Measure function execution time
torch.cuda.synchronize()
AcceleratorModuleWrapper.get_module().synchronize()
start_time = time.time()
result = func(*args, **kwargs)
torch.cuda.synchronize()
AcceleratorModuleWrapper.get_module().synchronize()
elapsed_time = time.time() - start_time

# Measure the overhead from running the timer function
@@ -151,15 +152,15 @@ def start(self):
saved_dict_reference = update_current_dict_reference(self.name)
self.saved_dict_reference = saved_dict_reference
self.overhead_time = time.time() - measure_overhead_start
torch.cuda.synchronize()
AcceleratorModuleWrapper.get_module().synchronize()
self.start_time = time.time()

def end(self):
"""
Stops the timer and records the elapsed time if timing is enabled.
"""
if self.enabled and globals().get("ENABLE_TIMING", False):
torch.cuda.synchronize()
AcceleratorModuleWrapper.get_module().synchronize()
elapsed_time = time.time() - self.start_time
measure_overhead_start = time.time()
update_elapsed_time_dict(self.name, elapsed_time)
src/ptychi/utils.py (34 changes: 26 additions & 8 deletions)
@@ -14,6 +14,7 @@
import ptychi.maths as pmath
import ptychi.propagate as propagate
from ptychi.timing.timer_utils import timer
from ptychi.device import AcceleratorModuleWrapper

if TYPE_CHECKING:
from ptychi.api.task import PtychographyTask
@@ -337,11 +338,25 @@ def to_numpy(data: Union[ndarray, Tensor]) -> ndarray:


def set_default_complex_dtype(dtype):
"""Set the default complex dtype.

Parameters
----------
dtype : torch.dtype
The default complex dtype.
"""
global _default_complex_dtype
_default_complex_dtype = dtype


def get_default_complex_dtype():
"""Get the default complex dtype.

Returns
-------
torch.dtype
The default complex dtype.
"""
return _default_complex_dtype
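
A minimal usage sketch of this accessor pair (the dtype here is just an example)::

    import torch
    from ptychi.utils import set_default_complex_dtype, get_default_complex_dtype

    # Work with single-precision complex tensors by default.
    set_default_complex_dtype(torch.complex64)
    assert get_default_complex_dtype() == torch.complex64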


@@ -468,7 +483,7 @@ def get_max_batch_size(
else:
data_size_gb = 0.0

mem_avail = torch.cuda.mem_get_info()[0] * (1 - margin_factor) / 1024 ** 3
mem_avail = AcceleratorModuleWrapper.get_module().mem_get_info()[0] * (1 - margin_factor) / 1024 ** 3
mem_compute = mem_avail - data_size_gb
batch_size = (mem_compute - x1 * n_p - x2 * n_o) / (x0 * n_p)
batch_size = batch_size * (8 / dtype.itemsize)
@@ -488,13 +503,15 @@ def auto_transfer_to_device(data: Tensor) -> Tensor:
2.2. If `torch.cuda.device_count()` is not 0, we assume it is the latter case, and
we transfer the data to `cuda`.
"""
if torch.get_default_device().type == "cuda":
return data.cuda()
accelerator_module_wrapper = AcceleratorModuleWrapper()

if torch.get_default_device().type == accelerator_module_wrapper.get_to_device_string():
return data.to(accelerator_module_wrapper.get_to_device_string())
else:
if torch.cuda.device_count() == 0:
if accelerator_module_wrapper.get_module().device_count() == 0:
return data
else:
return data.cuda()
return data.to(accelerator_module_wrapper.get_to_device_string())


def clear_memory(task: Optional["PtychographyTask"] = None):
Expand All @@ -506,11 +523,12 @@ def clear_memory(task: Optional["PtychographyTask"] = None):
task : PtychographyTask, optional
The `Task` object to be deleted.
"""
accelerator_module_wrapper = AcceleratorModuleWrapper()
if task is not None:
del task
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
accelerator_module_wrapper.get_module().empty_cache()
accelerator_module_wrapper.get_module().ipc_collect()


def jsonize(val):
@@ -524,4 +542,4 @@ def jsonize(val):
elif isinstance(val, (list, tuple, dict, str, int, float, bool, type(None))):
return val
else:
raise TypeError(f"Object of type {type(val).__name__} is not JSON serializable")
raise TypeError(f"Object of type {type(val).__name__} is not JSON serializable")