Dear code maintainers,
Would it be possible to add MPS (PyTorch's Apple-silicon GPU backend) as a device option, so that parrot can run GPU-accelerated on MacBook laptops the way it already does on CUDA devices?
Here is the snippet where MPS could be added:
```python
# Device configuration
if forceCPU:
    device = 'cpu'
elif gpu_id:
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else 'cpu')
    print(f"You've specified to run this network on cuda:{gpu_id}. Running on {device=}")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

Source: https://github.com/idptools/parrot/blob/6e09567afdc3a59d0c03f0802cf4d2fe9c973feb/scripts/parrot-train#L135C1-L142C5
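This code could help integrate MPS:

```python
import torch

# is_available() checks that PyTorch was built with MPS support *and* that the
# current machine can actually use it; is_built() only checks the former.
has_mps = torch.backends.mps.is_available()
device = "mps" if has_mps else "cuda" if torch.cuda.is_available() else "cpu"
```

Thanks in advance!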
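For reference, here is a rough sketch of how the quoted `parrot-train` branching might fold in MPS. It is not parrot's code: `select_device` is a hypothetical helper, `force_cpu` and `gpu_id` stand in for the script's `forceCPU` and `gpu_id` options, and it assumes PyTorch 1.12 or later (where `torch.backends.mps` was introduced). CUDA stays the first choice, so behaviour on CUDA machines is unchanged; MPS is tried next, then CPU. Unlike the quoted snippet's truthiness check, it treats `gpu_id=0` as an explicit selection.

```python
import torch


def select_device(force_cpu: bool = False, gpu_id=None) -> torch.device:
    """Hypothetical helper mirroring the device branching in parrot-train."""
    if force_cpu:
        return torch.device("cpu")
    if gpu_id is not None:
        # An explicit GPU index only applies to CUDA; fall back to CPU otherwise.
        if torch.cuda.is_available():
            return torch.device(f"cuda:{gpu_id}")
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Only pick MPS when this build supports it and the machine can use it.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if __name__ == "__main__":
    device = select_device()
    print(f"Running on {device=}")
    # Quick smoke test: a small tensor op on the selected device.
    x = torch.ones(3, device=device)
    print((x * 2).sum().item())
```

Keeping CUDA ahead of MPS means the change only affects machines that have no CUDA device.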