
Add ROCm Support for AMD GPUs #2556

@MrDewitt88

Description


Overview
MLX currently supports Apple Silicon and NVIDIA GPUs via CUDA. This feature request proposes adding support for AMD GPUs through ROCm (Radeon Open Compute), enabling MLX to run on a broader range of hardware, particularly AMD GPUs like the Radeon RX 7900 XTX (Navi 31) and newer RDNA architectures.
Motivation

Growing AMD GPU adoption: With competitive offerings in both consumer (RX 7000 series) and datacenter (MI300 series) markets, AMD GPU support would significantly expand MLX's user base
Cross-platform compatibility: ROCm supports both Linux and Windows 11, aligning with MLX's multi-platform approach
Open-source ecosystem: ROCm's open-source nature complements MLX's philosophy
Cost-effectiveness: AMD GPUs often provide better price-to-performance ratios for ML workloads

Proposed Implementation
Based on MLX's current architecture, the following components would need adaptation:

Backend Extension (mlx/backend/)

Add a new rocm backend alongside existing metal and cuda backends
Implement ROCm-specific memory management and stream handling
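To make the "new backend alongside existing backends" idea concrete, here is a minimal registry sketch in C++. The names (`Backend`, `BackendRegistry`, `make_default_registry`) are hypothetical and do not match MLX internals; the point is only that a `rocm` entry slots in next to `metal` and `cuda` without touching the others.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical sketch: each backend registers a factory under its name.
struct Backend {
    std::string name;
    // A real backend would also hold device/stream handles
    // (e.g. hipDevice_t / hipStream_t for ROCm).
};

class BackendRegistry {
    std::map<std::string, std::function<Backend()>> factories_;
public:
    void add(const std::string& name, std::function<Backend()> make) {
        factories_[name] = std::move(make);
    }
    Backend create(const std::string& name) const {
        auto it = factories_.find(name);
        if (it == factories_.end())
            throw std::runtime_error("unknown backend: " + name);
        return it->second();
    }
    bool has(const std::string& name) const {
        return factories_.count(name) != 0;
    }
};

// Illustrative default wiring; "rocm" is the new entry.
inline BackendRegistry make_default_registry() {
    BackendRegistry reg;
    reg.add("metal", [] { return Backend{"metal"}; });
    reg.add("cuda",  [] { return Backend{"cuda"}; });
    reg.add("rocm",  [] { return Backend{"rocm"}; });
    return reg;
}
```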

Compiler Integration

Integrate HIP (Heterogeneous-compute Interface for Portability) as the CUDA-equivalent layer
Adapt existing CUDA kernels using HIP's translation tools or native HIP implementations
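Automatic translation is feasible because most host-side CUDA API names map one-to-one onto HIP names. The toy renamer below illustrates the idea with a handful of real CUDA-to-HIP mappings; the actual hipify-clang tool operates on the Clang AST rather than raw text, so this is only a sketch of the principle.

```cpp
#include <string>
#include <utility>
#include <vector>

// Toy illustration of the renaming hipify performs. The real
// hipify-clang works on the compiler AST, not string replacement.
inline std::string hipify(std::string src) {
    // A few real CUDA -> HIP host API mappings.
    static const std::vector<std::pair<std::string, std::string>> table = {
        {"cudaMalloc", "hipMalloc"},
        {"cudaFree", "hipFree"},
        {"cudaMemcpy", "hipMemcpy"},
        {"cudaStream_t", "hipStream_t"},
        {"cudaStreamCreate", "hipStreamCreate"},
    };
    for (const auto& [from, to] : table) {
        for (size_t pos = src.find(from); pos != std::string::npos;
             pos = src.find(from, pos + to.size())) {
            src.replace(pos, from.size(), to);
        }
    }
    return src;
}
```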

Build System Updates

Extend CMake configuration to detect and link ROCm libraries
Add ROCm-specific compilation flags and paths

Core Operations Mapping

```python
# Pseudo-code for backend selection
if backend == "rocm":
    import mlx.core.rocm as mx_rocm
    return mx_rocm.ops
```

Memory Management

Implement ROCm memory pools using hipMalloc/hipFree
Add unified memory support via ROCm's managed memory APIs
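The memory-pool bullet above can be sketched as a small caching allocator. This is a platform-independent toy: in a real ROCm backend, `allocate()` would call `hipMalloc` and the destructor would call `hipFree`; here `std::malloc`/`std::free` stand in so the sketch runs anywhere, and the class name `BufferPool` is made up for illustration.

```cpp
#include <cstdlib>
#include <map>
#include <vector>

// Minimal caching-pool sketch: released buffers are cached by size
// and reused, avoiding repeated device allocations.
class BufferPool {
    std::map<size_t, std::vector<void*>> free_;  // size -> cached buffers
public:
    void* allocate(size_t size) {
        auto& bucket = free_[size];
        if (!bucket.empty()) {            // reuse a cached buffer
            void* p = bucket.back();
            bucket.pop_back();
            return p;
        }
        return std::malloc(size);         // hipMalloc(&p, size) in real code
    }
    void release(void* p, size_t size) {  // cache instead of freeing
        free_[size].push_back(p);
    }
    size_t cached(size_t size) const {
        auto it = free_.find(size);
        return it == free_.end() ? 0 : it->second.size();
    }
    ~BufferPool() {                       // hipFree in real code
        for (auto& [sz, bucket] : free_)
            for (void* p : bucket) std::free(p);
    }
};
```

Deferring `hipFree` this way matters on GPUs because device allocation is comparatively expensive and can synchronize the device.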

Technical Considerations

HIP Translation: Many CUDA kernels can be automatically translated using hipify-perl or hipify-clang
Performance Parity: ROCm 6.0+ provides comparable performance to CUDA for most ML operations
Testing Infrastructure: Would require AMD GPU CI/CD runners for comprehensive testing

Compatibility Matrix

ROCm 5.7+ (recommended 6.0+)
Linux: Ubuntu 22.04/24.04, RHEL 9
Windows 11 (via ROCm on Windows)
AMD GPUs: RDNA2+ (RX 6000 series and newer), CDNA (MI100+)

Potential Challenges

ROCm's historically less mature ecosystem compared to CUDA
Limited Windows support (improving with recent releases)
Potential need for architecture-specific optimizations

Resources & References

ROCm Documentation
HIP Porting Guide
Similar implementations: PyTorch ROCm, JAX ROCm support

Contribution
I'm willing to contribute to this implementation and have experience with both CUDA and ROCm development. Initial focus could be on core operations with gradual expansion to full feature parity.
Questions for Maintainers

Is ROCm support aligned with MLX's roadmap?
Are there specific design patterns or architectural decisions that should be followed?
Would a phased approach (starting with basic ops) be acceptable?

Implementation Steps
Here are concrete technical starting points for the implementation:

1. Backend Architecture

```cpp
// mlx/backend/rocm/device.cpp
class ROCmDevice : public Device {
    hipDevice_t device_;
    hipStream_t stream_;
    // Implementation details
};
```
2. Kernel Porting

```cpp
// Example: CUDA to HIP conversion
// CUDA version
__global__ void add_kernel(float* a, float* b, float* c, int n);

// HIP version (usually identical)
__global__ void add_kernel(float* a, float* b, float* c, int n);
```
3. Build System Integration

```cmake
# CMakeLists.txt extension
option(MLX_BUILD_ROCM "Build with ROCm support" OFF)
if(MLX_BUILD_ROCM)
    find_package(hip REQUIRED)
    find_package(rocblas REQUIRED)
    # additional ROCm libraries
endif()
```
4. Python Bindings

```python
# mlx/backend_selector.py
def get_backend():
    if has_rocm() and prefer_rocm():
        return "rocm"
    elif has_cuda() and prefer_cuda():
        return "cuda"
    return "cpu"
```

Labels: enhancement (New feature or request)