@Shreesh-Coder

Add MPS support for image inference on macOS (Apple Silicon)

AI-Generated Code Disclaimer

Note: This PR contains code changes that were generated with the assistance of AI tools (Cursor AI). All changes have been reviewed, tested, and validated. The implementation follows PyTorch best practices and patterns similar to those used in SAM2 for MPS compatibility.

Summary

This PR enables running SAM3 image inference on macOS using the PyTorch MPS backend (Apple Silicon GPU) or CPU, with automatic device selection (CUDA → MPS → CPU). CUDA behavior is unchanged. Video/tracking remains CUDA-only and raises a clear NotImplementedError on non-CUDA devices.

Key Changes

Core Device Support

  • Device Selection: Added centralized get_device() helper in model_builder.py that auto-detects device (CUDA → MPS → CPU)
  • Model Builder: Updated build_sam3_image_model() to support MPS device selection
  • Device Threading: Ensured device parameter is passed through all model construction functions (position encoders, geometry encoders, decoder caches)
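The CUDA → MPS → CPU priority described above can be sketched as follows (names here are illustrative; the actual `get_device()` helper in `model_builder.py` may differ):

```python
def select_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Resolve the device priority described above: CUDA -> MPS -> CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# The real helper would wrap this with the actual availability checks,
# roughly:
#
#   def get_device(device=None):
#       if device is not None:
#           return torch.device(device)
#       return torch.device(select_device_name(
#           torch.cuda.is_available(),
#           torch.backends.mps.is_available()))
```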

MPS Compatibility Fixes

  • Position Encoding Cache: Fixed cache creation to respect device parameter, preventing device mismatches
  • Decoder Coordinate Cache: Updated to detect device from model parameters instead of auto-detecting
  • Autocast: Disabled bfloat16 autocast on MPS (not well supported), kept CUDA behavior unchanged
  • grid_sample: Added a CPU round-trip fallback on MPS, since some grid_sample modes have incomplete MPS implementations and must run on CPU (see PyTorch MPS limitations); CUDA and CPU paths are unchanged
  • EDT (Euclidean Distance Transform): Added OpenCV fallback for non-CUDA devices (CPU and MPS)
  • _assert_async: Replaced with regular assertions on MPS (MPS doesn't support async asserts)
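A minimal sketch of the grid_sample CPU round-trip (illustrative; the in-tree helper in `geometry_encoders.py` may be structured differently):

```python
import torch
import torch.nn.functional as F

def grid_sample_mps_safe(input, grid, **kwargs):
    """grid_sample with a CPU round-trip when the tensors live on MPS,
    where some sampling modes are not implemented natively."""
    if input.device.type == "mps":
        # Run the op on CPU, then move the result back to MPS.
        out = F.grid_sample(input.cpu(), grid.cpu(), **kwargs)
        return out.to(input.device)
    # CUDA/CPU: use the native implementation unchanged.
    return F.grid_sample(input, grid, **kwargs)
```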

Optional Dependencies

  • decord: Made optional with clear error messages when video loading is attempted without it
  • triton: Already optional; only imported for CUDA paths

Graceful Error Handling

  • Video/Tracking: Added explicit checks that raise NotImplementedError with helpful messages when video/tracking is attempted on non-CUDA devices
  • Image Inference: Works on CPU and MPS; video/tracking remains CUDA-only
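The explicit device check can be sketched as below (hypothetical helper name; the real checks live in `sam3_tracker_base.py` and `sam3_video_predictor.py`):

```python
def require_cuda_for_video(device) -> None:
    """Raise early, with guidance, when video/tracking is requested on a
    non-CUDA device. Accepts a torch.device or a plain device string."""
    device_type = getattr(device, "type", None) or str(device).split(":")[0]
    if device_type != "cuda":
        raise NotImplementedError(
            "Video/tracking currently requires CUDA; image inference is "
            "supported on CPU and MPS."
        )
```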

Testing

  • Smoke Test: Added scripts/smoke_macos.py for lightweight validation on macOS

Files Modified

Core Library (sam3/sam3/)

  • model_builder.py: Added device selection helpers, MPS support in device setup
  • model/position_encoding.py: Added device parameter, fixed cache device placement
  • model/decoder.py: Fixed coordinate cache device detection, improved autocast device detection
  • model/geometry_encoders.py: Added MPS-safe grid_sample fallback, fixed _assert_async
  • model/edt.py: Improved OpenCV fallback for non-CUDA devices
  • model/sam3_tracking_predictor.py: Disabled bfloat16 autocast on MPS
  • model/sam3_tracker_base.py: Added device check in device property (raises error for non-CUDA)
  • model/sam3_video_predictor.py: Added CUDA check before model construction
  • model/sam3_image.py: Fixed _assert_async for MPS compatibility
  • model/utils/sam2_utils.py: Added decord import error handling
  • train/data/sam3_image_dataset.py: Added decord availability check
  • train/loss/mask_sampling.py: Added MPS-safe grid_sample fallback
  • train/loss/loss_fns.py: Fixed _assert_async for MPS compatibility

Test Scripts

  • scripts/smoke_macos.py: New smoke test script for macOS validation

Validation

Test Environment

  • macOS: 26.1 (Build 25B78) - Apple Silicon (arm64)
  • PyTorch: 2.9.1
  • MPS available: Yes
  • Hardware: Apple M2

Test Results

# CPU test
python scripts/smoke_macos.py --device cpu --skip-forward
# Result: ✓ PASSED

# MPS test
PYTORCH_ENABLE_MPS_FALLBACK=1 python scripts/smoke_macos.py --device mps --skip-forward
# Result: ✓ PASSED

# Auto-detect test
python scripts/smoke_macos.py --device auto --skip-forward
# Result: ✓ PASSED (selects MPS when available)

Performance Notes

  • On test machine (M2): MPS inference is ~1.7x faster than CPU (~6.5s vs ~11s per inference)
  • Outputs are qualitatively consistent; small numeric differences are expected due to backend/dtype differences
  • Note: Performance numbers are specific to this test machine and not guaranteed

Limitations

  1. Video/Tracking: Currently requires CUDA. Attempting to use video/tracking on CPU or MPS will raise a clear NotImplementedError with guidance to use image inference instead.

  2. MPS Operation Coverage: Some operations (like grid_sample) have incomplete MPS implementation and require CPU fallback. This is handled automatically via CPU round-trips. For additional unsupported operations, users may need to set PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable (per PyTorch MPS documentation). For example, aten::grid_sampler_3d is not implemented on MPS and PyTorch suggests using PYTORCH_ENABLE_MPS_FALLBACK=1 as a temporary fix.

  3. Autocast: bfloat16 autocast is disabled on MPS (not well supported). This may result in slightly different numerical outputs compared to CUDA, but results are still accurate.
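The autocast behavior in limitation 3 can be sketched as a small context helper (hypothetical name; the in-tree handling in `decoder.py` and `sam3_tracking_predictor.py` may differ in detail):

```python
import contextlib
import torch

def maybe_bf16_autocast(device_type: str):
    """bfloat16 autocast on CUDA only; a no-op context elsewhere, since
    bfloat16 autocast is not well supported on MPS for this model."""
    if device_type == "cuda":
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()
```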

Backward Compatibility

  • ✅ All changes are backward compatible
  • ✅ CUDA behavior is unchanged
  • ✅ CPU behavior is unchanged (now more robust)
  • ✅ Existing code continues to work without modification
  • ✅ Device auto-detection maintains CUDA → MPS → CPU priority

Testing Checklist

  • CPU model construction works
  • MPS model construction works (when MPS available)
  • Device auto-detection works correctly
  • Position encoding cache created on correct device
  • Decoder coordinate cache created on correct device
  • grid_sample CPU fallback works on MPS
  • EDT OpenCV fallback works on non-CUDA
  • Video/tracking raises clear error on non-CUDA
  • Smoke tests pass on macOS

Usage Example

from sam3.model_builder import build_sam3_image_model

# Auto-detect device (will choose MPS on macOS if available)
model = build_sam3_image_model(device=None)

# Explicitly use MPS
model = build_sam3_image_model(device="mps")

# Explicitly use CPU
model = build_sam3_image_model(device="cpu")

Shreesh Gupta added 3 commits January 7, 2026 18:57
This commit adds CPU/MPS compatibility fixes to enable SAM3 to run on macOS
without CUDA. These changes address device mismatch errors and hardcoded CUDA
references that prevent the model from working on non-GPU systems.

Key Changes:
- geometry_encoders.py: Conditional pin_memory() to avoid MPS device errors
- position_encoding.py: CPU-aware device selection for position encoding
- decoder.py: CPU-aware coordinate caching and autocast device handling
- edt.py: OpenCV fallback for distance transform when Triton unavailable
- sam3_tracking_predictor.py: CPU-compatible autocast configuration
- model_builder.py: Improved device setup with better error handling

Fixes:
- Resolves 'pin_memory()' errors on MPS (similar to SAM2 PR #495)
- Fixes device mismatch: 'Attempted to set storage on cpu to mps:0'
- Enables CPU fallback for GPU-specific operations (Triton → OpenCV)
- Prevents segmentation faults from hardcoded CUDA operations

Testing:
- Successfully tested on macOS with PyTorch 2.9.1 (CPU mode)
- Model loads and runs inference without CUDA
- Compatible with existing CUDA workflows (backward compatible)

---

AI-GENERATED CODE DISCLAIMER:
This code was developed with assistance from AI (Claude/Cursor). The modifications
are compatibility patches for macOS/CPU usage and have been tested, but users should:
1. Review all changes before deploying to production
2. Test thoroughly in their specific environment
3. Be aware that CPU performance is slower than GPU
4. Understand that some operations use fallback implementations
5. Note that these are compatibility patches and do not alter core model architecture

The original SAM3 codebase is from Facebook Research:
https://github.com/facebookresearch/sam3

These modifications maintain backward compatibility with CUDA while enabling
CPU/MPS support for development and testing on macOS systems.
- Add centralized device selection helper (get_device) with auto-detection (CUDA → MPS → CPU)
- Thread device parameter through model construction (position encoders, geometry encoders, decoder caches)
- Add MPS compatibility fixes:
  - Disable bfloat16 autocast on MPS (not well supported)
  - Add CPU round-trip fallback for grid_sample on MPS
  - Fix position encoding and decoder coordinate cache device placement
  - Replace _assert_async with regular assertions on MPS
  - Add OpenCV fallback for EDT on non-CUDA devices
- Make decord optional with clear error messages
- Add graceful error handling for video/tracking on non-CUDA (raises NotImplementedError)
- Add smoke test script for macOS validation

Video/tracking remains CUDA-only. Image inference works on CPU and MPS.

AI-Generated Code Disclaimer: This PR contains code changes that were generated with the assistance of AI tools (Cursor AI). All changes have been reviewed, tested, and validated.

Tested on:
- macOS 26.1 (Build 25B78) - Apple Silicon (arm64)
- PyTorch 2.9.1
- Apple M2

All smoke tests pass on CPU and MPS.
@meta-cla

meta-cla bot commented Jan 7, 2026

Hi @Shreesh-Coder!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jan 7, 2026
@meta-cla

meta-cla bot commented Jan 7, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
