PyTorch's Metal Performance Shaders (MPS) backend enables deep learning workloads to leverage Apple's GPU hardware on M1, M2, M3, and M4 devices. This report provides a comprehensive analysis of optimizations, limitations, and best practices for maximizing performance with PyTorch 2.7 on Apple Silicon.
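As a starting point, a minimal device-selection sketch that falls back to the CPU when the MPS backend is unavailable (the nn.Linear model and tensor shapes are illustrative):

import torch
import torch.nn as nn

# Prefer MPS when it is both built into this PyTorch install and usable on this machine
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# torch.backends.mps.is_built() distinguishes "PyTorch built without MPS"
# from "MPS built, but unavailable on this hardware / macOS version"

model = nn.Linear(16, 4).to(device)      # stand-in for a real model
x = torch.randn(8, 16, device=device)
out = model(x)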
Memory management is critical for MPS performance due to the unified memory architecture of Apple Silicon. Several techniques can help optimize memory usage:
PyTorch offers environment variables and functions to control memory allocation on MPS devices:
# Option 1: Set environment variables before importing torch
import os
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "1.5" # Default: 1.7
os.environ["PYTORCH_MPS_LOW_WATERMARK_RATIO"] = "1.2" # Default: 1.4
# Option 2: Use the API
import torch
torch.mps.set_per_process_memory_fraction(0.9)  # use up to 90% of available memory

# Clear unused cached memory
torch.mps.empty_cache()

Call torch.mps.empty_cache() especially after large operations or between epochs to prevent memory fragmentation.
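To judge when clearing the cache is worthwhile, the MPS allocator can be inspected directly; a minimal sketch (the log_mps_memory helper is illustrative):

import torch

def log_mps_memory(tag):
    # Memory currently occupied by tensors vs. total memory the Metal driver has reserved
    allocated_gb = torch.mps.current_allocated_memory() / 1e9
    driver_gb = torch.mps.driver_allocated_memory() / 1e9
    print(f"[{tag}] allocated: {allocated_gb:.2f} GB, driver: {driver_gb:.2f} GB")

log_mps_memory("before cleanup")
torch.mps.empty_cache()
log_mps_memory("after cleanup")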
To reset gradients, set them to None rather than zeroing them in place; this saves memory and a kernel launch, and is what optimizer.zero_grad(set_to_none=True) (the default since PyTorch 2.0) does:
for param in model.parameters():
    param.grad = None
For memory-constrained training, gradient (activation) checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them:

from torch.utils.checkpoint import checkpoint
def forward(self, x):
    # Recompute expensive_layer's activations in backward instead of storing them
    x = checkpoint(self.expensive_layer, x, use_reentrant=False)
    return x
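A self-contained sketch of the same idea (the layer names and sizes are illustrative):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.expensive_layer = nn.Sequential(
            nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
        )
        self.head = nn.Linear(1024, 10)

    def forward(self, x):
        # Activations of expensive_layer are recomputed during backward rather than stored
        x = checkpoint(self.expensive_layer, x, use_reentrant=False)
        return self.head(x)

model = CheckpointedNet().to("mps")
out = model(torch.randn(32, 1024, device="mps"))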
To simulate large batches via gradient accumulation:

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(loader):
    outputs = model(inputs.to("mps"))
    loss = loss_fn(outputs, targets.to("mps"))
    (loss / accumulation_steps).backward()  # scale so the accumulated gradient is an average
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
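If the number of batches is not an exact multiple of accumulation_steps, the loop above leaves a partial window unapplied; one option is a final flush after the loop (continuing the sketch above):

# After the loop: apply gradients left over from a partial accumulation window
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()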
MPS lacks full automatic mixed precision (AMP) support. A manual alternative is to cast the model and the inputs to float16 explicitly:

model = model.to("mps").half()                   # cast weights to float16
inputs = inputs.to("mps", dtype=torch.float16)   # inputs must match the model's dtype
outputs = model(inputs)
loss = loss_fn(outputs, targets.to("mps"))
loss.backward()   # note: pure float16 training may need manual loss scaling for stability
Convert float64 weights to float32 before moving models to MPS (MPS does not support float64):

model_data = torch.load("model.pt", map_location="cpu")
for k in model_data:
    if isinstance(model_data[k], torch.Tensor):
        if model_data[k].dtype == torch.float64:
            model_data[k] = model_data[k].float()  # downcast; MPS has no float64 kernels
        model_data[k] = model_data[k].to("mps")    # move tensors one by one; a state dict has no .to()

Avoid pin_memory=True for MPS; pinned (page-locked) host memory is a CUDA transfer optimization and provides no benefit with unified memory:
# Recommended: omit pin_memory and move each batch to the MPS device in the training loop
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32)
inputs = inputs.to("mps", non_blocking=True)

To allow CPU fallback for operators that are not implemented on MPS:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"Many operations are still not supported on MPS (e.g., grid_sampler_3d, some linear algebra ops). Use:
os.environ["PYTORCH_MPS_LOG_LEVEL"] = "1"To log fallbacks when running a model.
To profile GPU activity on the MPS backend with the built-in profiler:

with torch.mps.profiler.profile(mode="interval", wait_until_completed=True) as p:
    output = model(input_tensor)

To capture a Metal GPU trace for inspection in Xcode's Metal debugger:

if torch.mps.is_metal_capture_enabled():
    with torch.mps.metal_capture():
        output = model(input_tensor)
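When timing MPS code by hand rather than with the profiler, remember that kernels are dispatched asynchronously, so synchronize before reading the clock (a minimal sketch reusing model and input_tensor from above):

import time
import torch

torch.mps.synchronize()                  # drain pending GPU work before starting the timer
start = time.perf_counter()
output = model(input_tensor)
torch.mps.synchronize()                  # wait for the forward pass to finish on the GPU
print(f"forward pass: {(time.perf_counter() - start) * 1e3:.2f} ms")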
torch.compile support on MPS is still experimental:

model = torch.compile(model, backend="mps")

Known issues: operator fallbacks, shader generation bugs. Stable support is expected in PyTorch 2.8+.
To optimize PyTorch on MPS:
- Set memory limits via environment variables or the torch.mps API
- Use torch.mps.empty_cache() strategically
- Avoid pin_memory
- Use gradient accumulation and checkpointing for large models
- Profile fallbacks and memory pressure via logging and Metal tools
- Keep float64 on the CPU and convert to float32 before sending tensors to MPS

MPS support is improving steadily, but it still requires targeted optimization on Apple Silicon.
Complex tensors are not supported on MPS: keep complex-valued work on the CPU, or split each tensor into its real and imaginary parts and recover magnitude/phase from those.
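A minimal sketch of the splitting approach, carrying real and imaginary parts as separate float tensors on MPS (a complex multiply is used as the example):

import torch

z = torch.randn(1024, dtype=torch.complex64)    # complex tensor stays on the CPU
re, im = z.real.to("mps"), z.imag.to("mps")      # real/imaginary parts as float32 on MPS

def complex_mul(ar, ai, br, bi):
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, computed with real-valued tensors
    return ar * br - ai * bi, ar * bi + ai * br

out_re, out_im = complex_mul(re, im, re, im)
phase = torch.atan2(out_im, out_re)              # phase recovery from the split parts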
Testing the PyTorch 2.8 nightly builds: complex tensors appear to be supported; FFT, advanced matrix operations, and gradient calculations are working.
>>> print("MPS time:", timeit(lambda: x @ x.conj().T, number=1000))
print("CPU time:", timeit(lambda: x.cpu() @ x.cpu().conj().T, number=1000))MPS time: 0.0301269160117954
>>> print("CPU time:", timeit(lambda: x.cpu() @ x.cpu().conj().T, number=1000))
CPU time: 0.2310853329836391