Overview
MLX currently supports Apple Silicon and NVIDIA GPUs via CUDA. This feature request proposes adding support for AMD GPUs through ROCm (Radeon Open Compute), enabling MLX to run on a broader range of hardware, in particular AMD GPUs such as the Radeon RX 7900 XTX (Navi 31) and other RDNA3+ architectures.
Motivation
Growing AMD GPU adoption: With competitive offerings in both consumer (RX 7000 series) and datacenter (MI300 series) markets, AMD GPU support would significantly expand MLX's user base
Cross-platform compatibility: ROCm supports both Linux and Windows 11, aligning with MLX's multi-platform approach
Open-source ecosystem: ROCm's open-source nature complements MLX's philosophy
Cost-effectiveness: AMD GPUs often provide better price-to-performance ratios for ML workloads
Proposed Implementation
Based on MLX's current architecture, the following components would need adaptation:
Backend Extension (mlx/backend/)
Add a new rocm backend alongside existing metal and cuda backends
Implement ROCm-specific memory management and stream handling
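As a sketch of what the stream handling could look like, here is a minimal RAII wrapper over the HIP runtime API; the class name and its integration with MLX's scheduler are assumptions, not existing MLX code:

```cpp
#include <hip/hip_runtime.h>
#include <stdexcept>

// Hypothetical RAII wrapper for a HIP stream (not actual MLX code).
class RocmStream {
 public:
  explicit RocmStream(int device) {
    if (hipSetDevice(device) != hipSuccess ||
        hipStreamCreate(&stream_) != hipSuccess) {
      throw std::runtime_error("failed to create ROCm stream");
    }
  }
  ~RocmStream() { hipStreamDestroy(stream_); }

  // Block the host until all work queued on this stream has finished.
  void synchronize() { hipStreamSynchronize(stream_); }
  hipStream_t handle() const { return stream_; }

 private:
  hipStream_t stream_;
};
```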
Compiler Integration
Integrate HIP (Heterogeneous-compute Interface for Portability) as the CUDA-equivalent layer
Adapt existing CUDA kernels using HIP's translation tools or native HIP implementations
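To illustrate how mechanical the port typically is, the sketch below shows a host-side copy helper written directly in HIP; a CUDA original would differ only in the cuda*/hip* prefixes, which is precisely the rename that hipify-perl and hipify-clang automate. The helper function itself is hypothetical:

```cpp
#include <hip/hip_runtime.h>

// The CUDA original of this helper would use cudaMalloc / cudaMemcpy /
// cudaMemcpyHostToDevice; the HIP port only swaps the API prefixes.
void copy_to_device(const float* host, float** device, int n) {
  hipMalloc(reinterpret_cast<void**>(device), n * sizeof(float));
  hipMemcpy(*device, host, n * sizeof(float), hipMemcpyHostToDevice);
}
```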
Build System Updates
Extend CMake configuration to detect and link ROCm libraries
Add ROCm-specific compilation flags and paths
Core Operations Mapping
```python
# Pseudo-code for backend selection
def get_ops(backend):
    if backend == "rocm":
        import mlx.core.rocm as mx_rocm
        return mx_rocm.ops
```
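Beyond dispatch, each core operation needs a ROCm-side implementation. As one possible sketch, a float32 matmul could be backed by rocBLAS, much as a CUDA backend would lean on cuBLAS; the wrapper name and calling convention here are assumptions:

```cpp
#include <rocblas/rocblas.h>

// Illustrative wrapper (names are assumptions): C = A * B for
// column-major float32 matrices, A is m x k, B is k x n, C is m x n.
void matmul_f32(rocblas_handle handle,
                const float* A, const float* B, float* C,
                int m, int n, int k) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  rocblas_sgemm(handle,
                rocblas_operation_none, rocblas_operation_none,
                m, n, k,
                &alpha, A, /*lda=*/m, B, /*ldb=*/k,
                &beta, C, /*ldc=*/m);
}
```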
Memory Management
Implement ROCm memory pools using hipMalloc/hipFree
Add unified memory support via ROCm's managed memory APIs
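A minimal sketch of the two allocation paths, assuming hipMalloc/hipFree underneath a device-memory pool and hipMallocManaged for the unified-memory path (the pool's caching logic is omitted):

```cpp
#include <hip/hip_runtime.h>
#include <cstddef>

// Sketch of the two allocation paths; a real pool would cache and
// reuse freed blocks instead of hitting the driver on every call.
void* alloc_device(std::size_t bytes) {
  void* ptr = nullptr;
  hipMalloc(&ptr, bytes);         // plain device allocation
  return ptr;
}

void* alloc_unified(std::size_t bytes) {
  void* ptr = nullptr;
  hipMallocManaged(&ptr, bytes);  // managed memory, visible to CPU and GPU
  return ptr;
}

void release(void* ptr) {
  hipFree(ptr);                   // frees both kinds of allocation
}
```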
Technical Considerations
HIP Translation: Many CUDA kernels can be automatically translated using hipify-perl or hipify-clang
Performance Parity: ROCm 6.0+ provides comparable performance to CUDA for most ML operations
Testing Infrastructure: Would require AMD GPU CI/CD runners for comprehensive testing
Compatibility Matrix
| Component | Supported |
| --- | --- |
| ROCm | 5.7+ (6.0+ recommended) |
| Linux | Ubuntu 22.04/24.04, RHEL 9 |
| Windows | Windows 11 (via ROCm on Windows) |
| AMD GPUs | RDNA2+ (RX 6000 series and newer), CDNA (MI100+) |
Potential Challenges
ROCm's historically less mature ecosystem compared to CUDA
Limited Windows support (improving with recent releases)
Potential need for architecture-specific optimizations
Resources & References
ROCm Documentation
HIP Porting Guide
Similar implementations: PyTorch ROCm, JAX ROCm support
Contribution
I'm willing to contribute to this implementation and have experience with both CUDA and ROCm development. Initial focus could be on core operations with gradual expansion to full feature parity.
Questions for Maintainers
Is ROCm support aligned with MLX's roadmap?
Are there specific design patterns or architectural decisions that should be followed?
Would a phased approach (starting with basic ops) be acceptable?
Implementation Steps
Here are concrete technical starting points for the implementation:
- Backend architecture

```cpp
// mlx/backend/rocm/device.cpp
class ROCmDevice : public Device {
  hipDevice_t device_;  // HIP device handle
  hipStream_t stream_;  // default stream for queued work
  // Implementation details
};
```
- Kernel porting

```cpp
// Example: CUDA-to-HIP conversion
// CUDA version
__global__ void add_kernel(float* a, float* b, float* c, int n);

// HIP version (usually identical)
__global__ void add_kernel(float* a, float* b, float* c, int n);
```
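One caveat on launches: hipcc accepts the CUDA-style triple-chevron syntax, and HIP additionally provides the portable hipLaunchKernelGGL macro. A small self-contained sketch (grid-size arithmetic is illustrative):

```cpp
#include <hip/hip_runtime.h>

__global__ void add_kernel(float* a, float* b, float* c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) c[i] = a[i] + b[i];
}

void launch_add(float* a, float* b, float* c, int n, hipStream_t stream) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  // hipcc also accepts the CUDA-style launch:
  //   add_kernel<<<blocks, threads, 0, stream>>>(a, b, c, n);
  hipLaunchKernelGGL(add_kernel, dim3(blocks), dim3(threads),
                     /*sharedMem=*/0, stream, a, b, c, n);
}
```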
- Build system integration

```cmake
# CMakeLists.txt extension
option(MLX_BUILD_ROCM "Build with ROCm support" OFF)
if(MLX_BUILD_ROCM)
  find_package(hip REQUIRED)
  find_package(rocblas REQUIRED)
  # additional ROCm libraries
endif()
```
- Python bindings

```python
# mlx/backend_selector.py
def get_backend():
    # has_rocm(), prefer_rocm(), etc. are illustrative helpers
    if has_rocm() and prefer_rocm():
        return "rocm"
    elif has_cuda() and prefer_cuda():
        return "cuda"
    return "cpu"
```