[Experiment] ROCm backend #2300
Conversation
What an unexpected and amazing surprise! I'm absolutely thrilled.
@awni
I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN; I don't mean to discourage you from working on it.
I would love to see the ROCm backend get more traction. AMD's new AI series of processors has a unified-memory advantage similar to Apple Silicon, and getting MLX to run on those processors would be neat.
Stole my idea :(
How is it even possible for such an awesome PR to be left like this?
Pull request overview
This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.
Changes:
- Added ROCm backend infrastructure with device management, memory allocation, and stream handling
- Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting (a sketch of the kernel style follows this list)
- Updated build system (CMake) to support ROCm compilation with configurable GPU architectures
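For a flavor of what a HIP port of a CUDA-style elementwise kernel looks like, here is a minimal sketch; the names (`unary_kernel`, `Abs`) are illustrative, not the PR's actual code:

```cpp
#include <hip/hip_runtime.h>

// Functor applied per element; mirrors the CUDA-style op structs.
struct Abs {
  template <typename T>
  __device__ T operator()(T x) const {
    return x < T(0) ? -x : x;
  }
};

// One thread per element; grid-stride loops are omitted for brevity.
template <typename Op, typename T>
__global__ void unary_kernel(const T* in, T* out, size_t n) {
  size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    out[i] = Op{}(in[i]);
  }
}

// Launch example (the stream comes from the ROCm stream management):
// unary_kernel<Abs, float><<<num_blocks, 256, 0, stream>>>(in, out, n);
```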
Reviewed changes
Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Added MLX_BUILD_ROCM option and ROCm library detection |
| mlx/CMakeLists.txt | Integrated ROCm backend build configuration |
| mlx/device.cpp | Added ROCm device availability checks |
| mlx/backend/rocm/*.hip | HIP kernel implementations for various operations |
| mlx/backend/rocm/device.* | ROCm device and stream management |
| mlx/backend/rocm/allocator.* | ROCm-specific memory allocator using HIP unified memory |
| mlx/backend/rocm/worker.* | Async task execution worker for stream synchronization |
| mlx/backend/rocm/utils.* | HIP utility functions and error handling |
| mlx/backend/rocm/jit_module.* | JIT compilation support using HIPRTC |
| mlx/backend/rocm/device/*.hpp | Device-side utility functions and type definitions |
| mlx/backend/rocm/CMakeLists.txt | ROCm backend build configuration |
…ather, scatter, logsumexp, random bits generation, and sorting. Introduce new kernels for efficient computation and integrate with existing ROCm utilities. Update CMake configuration to include new source files and dependencies. Enhance error handling and ensure compatibility with different data types. This commit significantly expands the functionality of the ROCm backend.
👑👑👑
Can anyone run:

```
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .
```

Replace `{based on your GPU}` with your GPU architecture. You can run `rocm-smi` to get your GPU information.
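For reference, the architectures that come up later in this thread are gfx1151 for Strix Halo and gfx1011 for the Radeon Pro V520, so a Strix Halo build would pass `-DMLX_ROCM_ARCHITECTURES=gfx1151`.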
I'm getting this CMake error. Running on Strix Halo (gfx1151).
Could you retry with the latest push, please? (P.S. Keep your fingers crossed while it compiles; it worked for me on the 138th try.) 😅
… string formatting, replacing fmt library usage. Remove unused event.cpp file. Update kernel name generation and parameter formatting for consistency.
Now what can I test? 😍
I'm getting this:
I forgot to test the Python build, my bad. Can you try it now? Unfortunately I might not be able to help after it compiles; I don't have an AMD GPU to run tests 😔. I've tried replicating most things from CUDA, so hopefully it works.
Could you maybe try: mlx-community/Meta-Llama-3.1-8B-Instruct-bf16
I will get back to this in a bit 😁
I also have a Strix Halo. I have it set up in a C++ project, and I'm unable to compile with your branch either.
The problem looks like it's stemming from the CMakeLists.txt.
I have submitted a PR to the ROCm-support branch that fixes these compile errors.
Just got my hands on a Radeon Pro V520; I should be able to test things out now 😏
- Use PROJECT_SOURCE_DIR instead of CMAKE_SOURCE_DIR for correct path resolution
- Add GCC C++ standard library include paths for the HIP compiler (ROCm's clang needs explicit paths to libstdc++ headers)
Awesome! There is an is_available in eval.cpp that doesn't need to be there :)
- Replace rocPRIM-based sort with a custom block merge sort
- Avoids rocPRIM uninitialized_array compatibility issues with ROCm 7.x
- Mirrors the CUDA sort implementation approach
Here is some profiling information, along with the command I used.
- Add Limits struct to device/utils.hpp for sort operations
- Add missing numeric_limits specializations for int8, uint8, int16, uint16, bool
- Fix C++20 lambda syntax to be C++17 compatible
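As an illustration of the kind of specialization this adds (a sketch; the member names of this hypothetical `Limits` template are illustrative, not the file's actual code):

```cpp
#include <cstdint>

// Generic template; explicit specializations supply the bounds the
// device code needs for small integer and bool types.
template <typename T>
struct Limits;

template <>
struct Limits<int8_t> {
  static constexpr int8_t max = 127;
  static constexpr int8_t min = -128;
};

template <>
struct Limits<bool> {
  static constexpr bool max = true;
  static constexpr bool min = false;
};
```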
….cpp
- Remove mlx/backend/gpu/available.h include (doesn't exist)
- Remove is_available() function (already defined elsewhere)

Co-authored-by: Geramy Loveless <geramy@users.noreply.github.com>
- Implement gpu::device_info(), gpu::device_count(), gpu::is_available()
- Provides device name, architecture, UUID, PCI bus ID, memory info
- Uses hipGetDeviceProperties and hipMemGetInfo for AMD GPU info
- Mirrors the CUDA device_info.cpp implementation

Co-authored-by: Geramy Loveless <geramy@users.noreply.github.com>
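A sketch of how those two HIP calls yield the reported fields (an illustrative standalone program, not the PR's device_info.cpp):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, /*deviceId=*/0) != hipSuccess) {
    return 1;
  }
  size_t free_mem = 0, total_mem = 0;
  (void)hipMemGetInfo(&free_mem, &total_mem);
  std::printf("name: %s\narch: %s\npci bus: %d\nmemory: %zu free / %zu total\n",
              props.name, props.gcnArchName, props.pciBusID, free_mem,
              total_mem);
  return 0;
}
```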
- Add mlx/memory.h include to ensure MLX_API visibility attributes are applied to memory function definitions
- Fixes undefined symbol errors for reset_peak_memory and other memory management functions

Co-authored-by: Geramy Loveless <geramy@users.noreply.github.com>
- Add (void) casts to suppress nodiscard warnings for HIP API calls (hipMalloc, hipMemcpy, hipFree, hipStreamSynchronize, etc.)
- Fix implicit float-to-bool conversion warnings in unary_ops.hpp (Erf, ErfInv, Expm1) and binary_ops.hpp (ArcTan2)
- Add explicit type checks for bool/integral types before float operations
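Illustrative examples of the two patterns (function names here are hypothetical):

```cpp
#include <hip/hip_runtime.h>

void cleanup(void* ptr, hipStream_t stream) {
  // Explicitly discard the [[nodiscard]] hipError_t return values when the
  // result is intentionally ignored.
  (void)hipStreamSynchronize(stream);
  (void)hipFree(ptr);
}

__device__ bool nonzero(float x) {
  // Instead of `return x;` (implicit float-to-bool conversion warning),
  // make the comparison explicit.
  return x != 0.0f;
}
```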
- Add (void) casts for hipMemsetAsync and hipMemcpyAsync calls in:
  - conv/gemm_conv.cpp
  - random.hip
  - reduce/init_reduce.hip
  - scaled_dot_product_attention.hip
- Add python/src/rocm.cpp with mx.rocm.is_available() function
- Add python/tests/rocm_skip.py with tests to skip for the ROCm backend
- Update mlx_tests.py to detect the ROCm backend and use the appropriate skip list
- Update CMakeLists.txt to include rocm.cpp and the rocm.pyi stub

The ROCm skip list includes:
- The same tests as CUDA (FFT, linalg, hadamard, etc.)
- ROCm-specific: grouped convolution, 1D/3D convolution, input dilation
- Quantization tests (different support level than CUDA)
I am running the Phi3 kernel I had made, which works fine on macOS, with the ROCm experimental build, and I'm hitting:

```
signal SIGSEGV: address not mapped to object (fault address: 0x0)
```
The function needs the MLX_API attribute to be exported from the shared library so it can be called from Python bindings.
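A minimal sketch of the visibility pattern, assuming MLX_API expands to a default-visibility attribute (the macro body here is illustrative, not the project's exact definition):

```cpp
// In a header, e.g. mlx/memory.h:
#ifndef MLX_API
#define MLX_API __attribute__((visibility("default")))
#endif

namespace mlx::core {
MLX_API void reset_peak_memory();
}

// The .cpp file defining the function must include this header so the
// definition picks up the attribute; otherwise, when building with
// -fvisibility=hidden, the symbol is not exported and the Python bindings
// fail with an undefined-symbol error.
```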
Some AMD GPUs (like the Radeon Pro V520) report managed memory support but hipMallocManaged fails with "out of memory" even for small allocations. This change adds a runtime check that tests if managed memory actually works, and falls back to regular hipMalloc if it doesn't.
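A sketch of that probe (hypothetical function names; the real allocator is more involved):

```cpp
#include <hip/hip_runtime.h>

// Probe once whether managed memory actually works, since some GPUs
// report support but fail at allocation time.
static bool managed_memory_works() {
  void* probe = nullptr;
  if (hipMallocManaged(&probe, 64) != hipSuccess) {
    (void)hipGetLastError();  // clear the error so later calls start clean
    return false;
  }
  (void)hipFree(probe);
  return true;
}

void* gpu_allocate(size_t size) {
  static const bool use_managed = managed_memory_works();
  void* ptr = nullptr;
  hipError_t err =
      use_managed ? hipMallocManaged(&ptr, size) : hipMalloc(&ptr, size);
  return err == hipSuccess ? ptr : nullptr;
}
```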
Yup, a lot of errors on my end too. Earlier I had just tried eyeballing the implementation, copying the structure from CUDA and checking for compilation errors through Docker. I did not have an AMD GPU before this; now that I have one, I will incrementally patch all the errors.
When hipMallocManaged fails (which happens on some AMD GPUs like the Radeon Pro V520), fall back to hipHostMalloc instead of hipMalloc. hipHostMalloc allocates pinned host memory that is accessible from both CPU and GPU, which is required because MLX's array initialization code uses std::copy to write data directly to the allocated buffer from CPU. Regular hipMalloc allocates device-only memory that cannot be accessed from CPU code, causing segfaults when std::copy tries to write to it.
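A sketch of the revised fallback (same caveats as the probe sketch above):

```cpp
#include <hip/hip_runtime.h>

void* gpu_allocate_with_fallback(size_t size) {
  void* ptr = nullptr;
  if (hipMallocManaged(&ptr, size) == hipSuccess) {
    return ptr;
  }
  (void)hipGetLastError();
  // Pinned host memory is visible to both CPU and GPU, so host-side
  // std::copy into the buffer works; device-only hipMalloc memory would
  // segfault on host writes.
  if (hipHostMalloc(&ptr, size, hipHostMallocDefault) != hipSuccess) {
    return nullptr;
  }
  return ptr;
}
```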
AMD GPUs have different wavefront (warp) sizes depending on architecture:
- CDNA/GCN (gfx9xx and earlier): 64
- RDNA (gfx10xx, gfx11xx): 32

The previous code hardcoded WARP_SIZE=64 everywhere, which caused incorrect results on RDNA GPUs like the Radeon Pro V520 (gfx1011). This change:
1. Updates device/config.h to detect the target architecture and set WARP_SIZE appropriately, using __AMDGCN_WAVEFRONT_SIZE__ or architecture detection macros
2. Updates all kernel files to use the centralized WARP_SIZE definition instead of local hardcoded values
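A sketch of what the centralized definition in device/config.h could look like, assuming the compiler-provided __AMDGCN_WAVEFRONT_SIZE__ macro is available during device compilation (the fallback is the conservative wave64 default):

```cpp
#pragma once

// Wavefront size differs by architecture: CDNA/GCN use 64 lanes, RDNA
// uses 32. Prefer the compiler-provided macro when compiling device code.
#if defined(__AMDGCN_WAVEFRONT_SIZE__)
#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE__
#else
#define WARP_SIZE 64
#endif
```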
Experiment with ROCm backend.
Install MLX with the ROCm backend using:
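```
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
```

(the same build command given earlier in the thread; add -DMLX_ROCM_ARCHITECTURES to target a specific GPU)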
closes #2556
Inspired by @zcbenz