# Add MPS (Metal Performance Shaders) support for Apple Silicon GPUs #429
This PR brings native PyTorch MPS support to CellBender 0.3.2 so users on Apple Silicon (M1/M2/M3/M4) Macs can run GPU-accelerated inference on macOS. It adapts the original work from commit `8a70ea3` on the legacy `sf_pytorch_mps_backend` branch (v0.2.0) to the 0.3.2 codebase.

Original commit: `8a70ea306efbfd63f2a219b93b1b2c2749f641df` · PR branch: `add-mps-support`

## Summary
- New CLI flag `--mps` to enable the PyTorch MPS backend.
- Instead of a boolean `use_cuda`, a string `device` is passed end to end and can be one of `cuda`, `mps`, or `cpu` (a minimal sketch follows).
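For illustration only, here is roughly what the boolean-to-string switch looks like in PyTorch terms; the toy model and function names are not CellBender's actual code:

```python
# Minimal sketch of the use_cuda -> device change described above.
# Names and the toy model are illustrative, not CellBender's real API.
import torch.nn as nn

# Before: a boolean that can only distinguish CUDA from CPU.
def build_model_old(use_cuda: bool) -> nn.Module:
    model = nn.Linear(8, 8)
    return model.cuda() if use_cuda else model

# After: a device string ("cuda", "mps", or "cpu") threaded end to end.
def build_model_new(device: str) -> nn.Module:
    return nn.Linear(8, 8).to(device)
```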
## Why this change?

CUDA is not available on Apple Silicon Macs, so macOS users have so far been limited to CPU execution; the PyTorch MPS backend lets them run GPU-accelerated inference natively.

## Changes (by file)
- `cellbender/remove_background/argparser.py`
  - Adds the `--mps` flag with help text and a link to the PyTorch MPS docs.
- `cellbender/remove_background/cli.py`
  - Sets `args.device` to `cuda` (if `--cuda` and available), else `mps` (if `--mps` and available), else `cpu`.
- `cellbender/remove_background/model.py`
  - Takes `device: str` instead of `use_cuda: bool`.
  - Uses `.to(device)` (model and submodules) instead of `.cuda()`.
  - Drops `use_cuda` state; stores `self.device` only.
  - Replaces `pyro.plate(..., use_cuda=..., device=...)` with `pyro.plate(..., device=...)`.
- `cellbender/remove_background/data/dataprep.py`
  - `DataLoader` accepts `device: str` and pushes tensors to that device.
  - `prep_sparse_data_for_training(...)` accepts `device` and propagates it to the loaders.
- `cellbender/remove_background/data/dataset.py`
  - `get_dataloader(...)` now takes `device: str` (instead of `use_cuda: bool`) and forwards it to `DataLoader`.
- `cellbender/remove_background/run.py`
  - Uses `args.device` consistently; CPU-only paths check `args.device == 'cpu'`.
  - Keeps `force_device` consistent with the selected backend.
  - Passes `device` into posterior computations and estimators.
- `cellbender/remove_background/posterior.py`
  - Obtains `device` from `vi_model.device` (or a sensible fallback).
  - Passes `device` explicitly.
- `cellbender/remove_background/train.py`
  - Checks `model.device == 'cuda'` (instead of `model.use_cuda`).
  - Calls `torch.cuda.empty_cache()` if CUDA, and `torch.mps.empty_cache()` if MPS and available (wrapped in try/except); see the sketch after this list.
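As an illustration of the `train.py` change, a guarded, backend-agnostic cache clear might look roughly like this (the helper name is hypothetical, not the actual CellBender code):

```python
# Hypothetical helper sketching the guarded cache clearing described above.
import torch

def free_accelerator_memory(device: str) -> None:
    """Release cached GPU memory for the selected backend, if possible."""
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        try:
            # torch.mps.empty_cache() exists only in newer PyTorch builds,
            # so guard the call instead of assuming it is present.
            torch.mps.empty_cache()
        except AttributeError:
            pass
```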
## Device selection behavior

- If `--cuda` is provided and `torch.cuda.is_available()`, use `cuda`.
- Else, if `--mps` is provided and both `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()`, use `mps`.
- Otherwise, fall back to `cpu`.
- CUDA takes precedence over MPS when both are requested/available (sketched below).
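A minimal sketch of that selection order (the function name is illustrative; the actual `cli.py` code may differ):

```python
# Sketch of the device-selection precedence described above; illustrative only.
import torch

def select_device(want_cuda: bool, want_mps: bool) -> str:
    if want_cuda and torch.cuda.is_available():
        return "cuda"
    if (want_mps
            and torch.backends.mps.is_available()
            and torch.backends.mps.is_built()):
        return "mps"
    return "cpu"
```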
## How to use

Pass `--mps` to `cellbender remove-background` (in place of `--cuda`) to run inference on the Apple GPU.
To verify the flag is visible:
```bash
cellbender remove-background --help | grep -A 4 -- --mps
```
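Independently of CellBender, you can confirm that the installed PyTorch build exposes the MPS backend; this quick interpreter check is not part of the PR, just a convenience:

```python
# Quick check that this PyTorch build sees the MPS backend on Apple Silicon.
import torch

print("MPS built:    ", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())
```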
## Testing performed

- `--mps` appears in the CLI help.
- Verified `torch.backends.mps.is_available()` and `is_built()` on an Apple Silicon test machine.
- The `--cuda` and CPU code paths remain intact.

Note: the full end-to-end test suite (including GPU tests) should be run in CI or by maintainers; this PR aims to be minimally invasive while restoring the MPS feature.
## Limitations and follow-ups
The GPU monitoring utility (`cellbender/monitor.py`) prints GPU utilization via `nvidia-smi` (CUDA only). There is no analogous standard CLI for MPS, so for now the logs omit MPS GPU utilization. A future improvement could add optional macOS/MPS metrics if a stable API becomes available.

## Backward compatibility
- `--mps` is additive; existing `--cuda` and CPU behavior is unchanged.
- `force_device` is still set where appropriate.

## Related work
- Commit `8a70ea3` (Stephen Fleming) on `sf_pytorch_mps_backend` (0.2.0).

## Reviewer notes
- Please check for any remaining `use_cuda` assumptions in the active code paths.

## Checklist
- `--mps` flag and help text

Thanks for reviewing! This should unlock fast, native GPU acceleration for a large portion of the community using Apple Silicon machines, while preserving the familiar CUDA and CPU paths.