Add MPS support for image inference on macOS (Apple Silicon) #400
AI-Generated Code Disclaimer
Note: This PR contains code changes that were generated with the assistance of AI tools (Cursor AI). All changes have been reviewed, tested, and validated. The implementation follows PyTorch best practices and patterns similar to those used in SAM2 for MPS compatibility.
Summary
This PR enables running SAM3 image inference on macOS using the PyTorch MPS backend (Apple Silicon GPU) or CPU, with automatic device selection (CUDA → MPS → CPU). CUDA behavior is unchanged. Video/tracking remains CUDA-only and raises a clear NotImplementedError on non-CUDA devices.
Key Changes
Core Device Support
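The CUDA → MPS → CPU priority described in the Summary can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from model_builder.py: pick_device() is a hypothetical pure helper introduced here so the priority order can be checked without PyTorch installed, and the get_device() body shown is a guess at the general shape rather than the PR's implementation.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name using the CUDA -> MPS -> CPU order.

    Hypothetical helper for illustration; not part of the PR.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device() -> str:
    """Probe PyTorch for available backends and pick one (assumed shape)."""
    import torch  # imported lazily so pick_device() stays usable without PyTorch

    # torch.backends.mps is missing on very old PyTorch builds, hence the guard.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_ok)
```

Keeping the priority logic separate from the backend probing makes the ordering easy to unit test on machines that have neither CUDA nor MPS.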
- get_device() helper in model_builder.py that auto-detects the device (CUDA → MPS → CPU)
- build_sam3_image_model() now supports MPS device selection
MPS Compatibility Fixes
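One of the fixes listed in this PR is an MPS-safe grid_sample fallback that uses a CPU round-trip. A minimal sketch of that idea, assuming PyTorch is installed, is shown here. The name grid_sample_mps_safe is hypothetical, and the PR's actual helpers in model/geometry_encoders.py and train/loss/mask_sampling.py may differ in when they fall back.

```python
import torch
import torch.nn.functional as F


def grid_sample_mps_safe(inp: torch.Tensor, grid: torch.Tensor, **kwargs) -> torch.Tensor:
    """Run grid_sample on CPU when the input lives on MPS, then move the result back.

    Hypothetical helper for illustration; on every other device the call is
    forwarded to F.grid_sample unchanged.
    """
    if inp.device.type != "mps":
        return F.grid_sample(inp, grid, **kwargs)
    # CPU round-trip: copy both tensors to CPU, sample there, copy the result back.
    out = F.grid_sample(inp.cpu(), grid.cpu(), **kwargs)
    return out.to(inp.device)
```

The round-trip costs two device transfers per call, so a wrapper like this trades speed for coverage on MPS while leaving the CUDA and CPU paths untouched.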
Optional Dependencies
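The file list in this PR mentions decord import error handling and a decord availability check. Assuming that is what this section covers, an optional import could look something like the sketch below. The helper names optional_import and require_decord are hypothetical, and the error message is illustrative only.

```python
import importlib


def optional_import(module_name: str):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


# decord is only needed for video decoding, so a missing install is tolerated here.
decord = optional_import("decord")


def require_decord():
    """Return the decord module, or fail with a readable message at the point of use."""
    if decord is None:
        raise ImportError(
            "decord is required for video decoding but is not installed; "
            "image inference does not need it."
        )
    return decord
```

Deferring the failure to require_decord() means importing the package succeeds on machines without decord, and only code paths that actually decode video raise.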
Graceful Error Handling
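As a rough illustration of the guard described in the Summary (video/tracking raising NotImplementedError on non-CUDA devices), a check along these lines could run before the video model is constructed. The function name ensure_cuda_for_video and the message wording are assumptions, not the PR's code.

```python
def ensure_cuda_for_video(device_type: str) -> None:
    """Raise NotImplementedError unless the device type is 'cuda'.

    Hypothetical helper; device_type is a plain string such as 'cuda', 'mps' or 'cpu'.
    """
    if device_type != "cuda":
        raise NotImplementedError(
            f"Video/tracking currently requires CUDA, but the selected device is "
            f"'{device_type}'. Use image inference on CPU or MPS instead."
        )
```

Failing early with the device name in the message avoids a less readable error from deep inside a CUDA-only kernel.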
- NotImplementedError with helpful messages when video/tracking is attempted on non-CUDA devices
Testing
- scripts/smoke_macos.py for lightweight validation on macOS
Files Modified
Core Library (sam3/sam3/)
- model_builder.py: Added device selection helpers, MPS support in device setup
- model/position_encoding.py: Added device parameter, fixed cache device placement
- model/decoder.py: Fixed coordinate cache device detection, improved autocast device detection
- model/geometry_encoders.py: Added MPS-safe grid_sample fallback, fixed _assert_async
- model/edt.py: Improved OpenCV fallback for non-CUDA devices
- model/sam3_tracking_predictor.py: Disabled bfloat16 autocast on MPS
- model/sam3_tracker_base.py: Added device check in device property (raises error for non-CUDA)
- model/sam3_video_predictor.py: Added CUDA check before model construction
- model/sam3_image.py: Fixed _assert_async for MPS compatibility
- model/utils/sam2_utils.py: Added decord import error handling
- train/data/sam3_image_dataset.py: Added decord availability check
- train/loss/mask_sampling.py: Added MPS-safe grid_sample fallback
- train/loss/loss_fns.py: Fixed _assert_async for MPS compatibility
Testing
- scripts/smoke_macos.py: New smoke test script for macOS validation
Validation
Test Environment
Test Results
Performance Notes
Limitations
- Video/Tracking: Currently requires CUDA. Attempting to use video/tracking on CPU or MPS will raise a clear NotImplementedError with guidance to use image inference instead.
- MPS Operation Coverage: Some operations (like grid_sample) have incomplete MPS implementation and require CPU fallback. This is handled automatically via CPU round-trips. For additional unsupported operations, users may need to set the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable (per PyTorch MPS documentation). For example, aten::grid_sampler_3d is not implemented on MPS and PyTorch suggests using PYTORCH_ENABLE_MPS_FALLBACK=1 as a temporary fix.
- Autocast: bfloat16 autocast is disabled on MPS (not well supported). This may result in slightly different numerical outputs compared to CUDA, but results are still accurate.
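The autocast limitation can be pictured as choosing a context manager per device. The sketch below is an assumption about the general pattern, not the code in model/sam3_tracking_predictor.py: autocast_context is a hypothetical name, and treating CPU the same as MPS (no autocast) is a simplification made here for illustration.

```python
import contextlib


def autocast_context(device_type: str):
    """Return a bfloat16 autocast context on CUDA and a no-op context elsewhere.

    Hypothetical helper; device_type is a plain string such as 'cuda', 'mps' or 'cpu'.
    """
    if device_type == "cuda":
        import torch  # only needed on the CUDA path

        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    # MPS (and, in this sketch, CPU) run in the model's default precision.
    return contextlib.nullcontext()
```

A caller would wrap the forward pass in `with autocast_context(device.type):`, so the same inference code runs on every backend and only the precision differs.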
Backward Compatibility
Testing Checklist
Usage Example
References