Optimizing PyTorch 2.7 for MPS (Metal Performance Shaders) on Apple Silicon

PyTorch's Metal Performance Shaders (MPS) backend enables deep learning workloads to leverage Apple's GPU hardware on M1, M2, M3, and M4 devices. This report provides a comprehensive analysis of optimizations, limitations, and best practices for maximizing performance with PyTorch 2.7 on Apple Silicon.
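All examples below assume an MPS-capable build of PyTorch; a standard device guard using torch.backends.mps.is_available() looks like this:

import torch

# Prefer MPS when it is available; otherwise fall back to the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(8, 8, device=device)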

Memory Management Optimizations

Memory management is critical for MPS performance due to the unified memory architecture of Apple Silicon. Several techniques can help optimize memory usage:

Memory Limits Configuration

PyTorch offers environment variables and functions to control memory allocation on MPS devices:

# Option 1: Set environment variables before importing torch
import os
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "1.5"  # Default: 1.7
os.environ["PYTORCH_MPS_LOW_WATERMARK_RATIO"] = "1.2"   # Default: 1.4

# Option 2: Use the API
import torch
torch.mps.set_per_process_memory_fraction(0.9)  # Use up to 90% of available memory
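The fraction passed to set_per_process_memory_fraction is applied against Metal's recommended working-set size, which recent PyTorch releases expose directly; a quick sketch to inspect it:

import torch

# Metal's recommended maximum working-set size for this GPU, in bytes
budget = torch.mps.recommended_max_memory()
print(f"MPS memory budget: {budget / 2**30:.2f} GiB")

torch.mps.set_per_process_memory_fraction(0.9)  # cap this process at 90% of the budget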

Memory Release Mechanisms

# Clear unused cached memory
torch.mps.empty_cache()

Call this after large operations or between epochs to release cached blocks back to the system and reduce memory pressure and fragmentation.
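A sketch of where this fits in a training loop, using torch.mps.current_allocated_memory() and torch.mps.driver_allocated_memory() to watch the effect; the epoch count and train_one_epoch are placeholders:

import torch

for epoch in range(10):                          # placeholder epoch count
    train_one_epoch(model, loader)               # placeholder training step
    torch.mps.empty_cache()                      # release cached blocks between epochs
    print(f"epoch {epoch}: "
          f"tensors={torch.mps.current_allocated_memory() / 2**20:.0f} MiB, "
          f"driver={torch.mps.driver_allocated_memory() / 2**20:.0f} MiB")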

Performance Optimization Techniques

Gradient Management

For gradient reset, set gradients to None instead of zeroing them in place; this skips a separate zeroing kernel and is what optimizer.zero_grad(set_to_none=True) does, set_to_none=True being the default since PyTorch 2.0:

for param in model.parameters():
    param.grad = None

Gradient Checkpointing

For memory-constrained training, trade compute for memory by recomputing activations during the backward pass instead of storing them:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def forward(self, x):
        # expensive_layer's activations are recomputed in backward, not kept
        return checkpoint(self.expensive_layer, x, use_reentrant=False)

Batch Size Optimization

To simulate large batches via gradient accumulation:

accumulation_steps = 4  # effective batch size = 4 x the loader's batch size

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(loader):
    outputs = model(inputs.to("mps"))
    loss = loss_fn(outputs, targets.to("mps"))
    (loss / accumulation_steps).backward()  # scale so accumulated gradients average correctly
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Manual Mixed Precision

MPS lacks full AMP support. A manual alternative is to cast the model and inputs to float16 and keep the loss computation in float32:

model = model.to("mps", dtype=torch.float16)
inputs = inputs.to("mps", dtype=torch.float16)
outputs = model(inputs)
loss = loss_fn(outputs.float(), targets.to("mps"))  # compute the loss in float32 for stability
loss.backward()
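If gradients underflow in float16, static loss scaling can stand in for a gradient scaler; a sketch that replaces the plain loss.backward() above (the scale of 1024.0 is an illustrative assumption, not a library default):

scale = 1024.0                      # hypothetical static loss scale; tune per model
(loss * scale).backward()           # scale up so small fp16 gradients do not flush to zero
for p in model.parameters():
    if p.grad is not None:
        p.grad /= scale             # unscale before the optimizer step
optimizer.step()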

Known Limitations and Workarounds

Float64 Precision Limitation

MPS does not support torch.float64; convert such tensors to float32 on the CPU before moving anything to the device:

state_dict = torch.load("model.pt", map_location="cpu")
for k, v in state_dict.items():
    if isinstance(v, torch.Tensor) and v.dtype == torch.float64:
        state_dict[k] = v.float()   # downcast: MPS has no float64 kernels
model.load_state_dict(state_dict)
model.to("mps")

Data Loading

Pinned (page-locked) host memory is a CUDA transfer optimization; on Apple Silicon's unified memory it brings no benefit, so leave pin_memory at its default of False:

# Recommended: keep the default pin_memory=False
dataloader = DataLoader(dataset, batch_size=32)

for inputs, targets in dataloader:
    inputs = inputs.to("mps", non_blocking=True)
    targets = targets.to("mps", non_blocking=True)

MPS Fallback

To allow unsupported operators to fall back to the CPU, set this before torch is imported:

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # must be set before importing torch

Operator Support Status

Many operations are still unsupported on MPS (e.g., grid_sampler_3d and some linear algebra ops). To log which operators fall back while running a model, set:

os.environ["PYTORCH_MPS_LOG_LEVEL"] = "1"
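Without the CPU fallback enabled, unsupported operators raise NotImplementedError, which makes it possible to probe support up front; probe_mps below is a hypothetical helper written for this example, not a PyTorch API:

import torch

def probe_mps(fn, *tensors):
    # Returns True if fn executes natively on MPS.
    # Only meaningful when PYTORCH_ENABLE_MPS_FALLBACK is unset.
    try:
        fn(*(t.to("mps") for t in tensors))
        return True
    except (NotImplementedError, RuntimeError):
        return False

# Example: check 2-D grid sampling
ok = probe_mps(torch.nn.functional.grid_sample,
               torch.randn(1, 1, 8, 8), torch.zeros(1, 4, 4, 2))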

Profiling and Analysis

MPS Profiler

# Emits OS Signpost traces that can be inspected in Instruments
with torch.mps.profiler.profile(mode="interval", wait_until_completed=True) as p:
    output = model(input_tensor)

Metal Capture (advanced)

Capturing requires launching the process with MTL_CAPTURE_ENABLED=1; the capture API lives under torch.mps.profiler and takes a trace file name, and the resulting .gputrace opens in Xcode's Metal debugger:

if torch.mps.profiler.is_metal_capture_enabled():
    with torch.mps.profiler.metal_capture("model_trace"):
        output = model(input_tensor)

torch.compile on MPS (PyTorch 2.7)

Support is still experimental. There is no dedicated "mps" backend string for torch.compile; MPS tensors go through the default Inductor backend, whose Metal code generation is a prototype in 2.7:

model = torch.compile(model)  # default Inductor backend; Metal support is experimental

Known issues: operator fallbacks and shader-generation bugs. Stable support is expected in PyTorch 2.8+.
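Because compilation is lazy and failures surface on the first call, one defensive pattern is to trigger it once and keep the eager model if it fails; example_input is a placeholder batch:

import torch

compiled = torch.compile(model)      # compilation happens on the first forward pass
try:
    _ = compiled(example_input)      # force the first compile to surface errors early
    model = compiled                 # adopt the compiled module if it worked
except Exception:
    pass                             # otherwise keep running eagerly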

Conclusion

To optimize PyTorch on MPS:

  • Set memory limits via environment variables or API
  • Use torch.mps.empty_cache() strategically
  • Avoid pin_memory
  • Use gradient accumulation and checkpointing for large models
  • Profile fallbacks and memory pressure via logging and Metal tools
  • Keep float64 tensors on the CPU and convert to float32 before sending them to MPS

MPS support is improving steadily, but it still requires targeted optimization on Apple Silicon.

Complex Tensors

Complex tensors are not supported on the stable MPS backend. Keep them on the CPU, or split them into real and imaginary parts (recovering phase information afterwards if needed).
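A sketch of the split-and-recombine workaround; complex_mul is a hypothetical helper written for this example:

import torch

z = torch.randn(4, 4, dtype=torch.complex64)      # complex data stays on the CPU
re, im = z.real.to("mps"), z.imag.to("mps")       # move the real/imag parts to MPS

def complex_mul(ar, ai, br, bi):
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, computed on real-valued tensors
    return ar * br - ai * bi, ar * bi + ai * br

sq_re, sq_im = complex_mul(re, im, re, im)        # z * z, elementwise
z_sq = torch.complex(sq_re.cpu(), sq_im.cpu())    # recombine on the CPU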

Testing PyTorch nightly (2.8): complex tensors appear to be supported. FFT, advanced matrix ops, and gradient calculations are working.

>>> print("MPS time:", timeit(lambda: x @ x.conj().T, number=1000))
print("CPU time:", timeit(lambda: x.cpu() @ x.cpu().conj().T, number=1000))MPS time: 0.0301269160117954
>>> print("CPU time:", timeit(lambda: x.cpu() @ x.cpu().conj().T, number=1000))
CPU time: 0.2310853329836391
