Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
8bb8b76
[Experiment] ROCM backend initial push
NripeshN Jun 16, 2025
ac5adfa
increment 1: few ops and jit update
NripeshN Jun 18, 2025
cc4de6a
Increment 2: Implement major ops and add structure similar to cuda
NripeshN Jun 18, 2025
1163da1
Merge remote-tracking branch 'upstream/main' into rocm-support
NripeshN Jan 24, 2026
667cd9b
rocm yaay
NripeshN Jan 24, 2026
8780ad9
Implement ROCm support for various operations including arg reduce, g…
NripeshN Jan 24, 2026
63d6b6a
chore fix cmake
NripeshN Jan 24, 2026
7c1b29d
Merge branch 'main' into rocm-support
NripeshN Jan 24, 2026
ee8b705
compile fix
NripeshN Jan 25, 2026
9aa0f5c
Refactor error handling in ROCm backend to use std::ostringstream for…
NripeshN Jan 25, 2026
cadf18c
lint
NripeshN Jan 25, 2026
6fa7c7c
add more features
NripeshN Jan 26, 2026
57941f9
Enhance ROCm backend with new features including binary operations, L…
NripeshN Jan 26, 2026
1856341
Remove optional MIOpen support from ROCm backend CMake configuration.…
NripeshN Jan 26, 2026
2e27dc9
Add scaled dot product attention kernel and update ROCm convolution i…
NripeshN Jan 26, 2026
da275f7
Fix symbol linking issue
NripeshN Jan 26, 2026
499d2a6
lazy load GPU
NripeshN Jan 26, 2026
c30b211
Add general gather and scatter kernels for arbitrary indexing in ROCm…
NripeshN Jan 26, 2026
86e4f85
Add dynamic copy kernel and gather operation in ROCm backend
NripeshN Jan 26, 2026
7141d8c
Add quantized matrix multiplication and gather QMM kernel in ROCm bac…
NripeshN Jan 26, 2026
1c74fba
Merge remote-tracking branch 'upstream/main' into rocm-support
NripeshN Jan 27, 2026
04efa16
Fix HIP include paths for C++ standard library headers
NripeshN Feb 3, 2026
bf993f8
Rewrite ROCm sort with custom merge sort implementation
NripeshN Feb 3, 2026
b76745e
Fix ROCm sort compilation errors
NripeshN Feb 3, 2026
969fd0b
Remove duplicate is_available() and unavailable header from ROCm eval…
Geramy Feb 3, 2026
b82594d
Add device_info.cpp for ROCm backend
Geramy Feb 3, 2026
231c078
Include memory.h in ROCm allocator for proper symbol visibility
Geramy Feb 3, 2026
8de6a7a
Fix all ROCm backend compiler warnings
NripeshN Feb 3, 2026
04b2e8d
Fix remaining ROCm backend compiler warnings
NripeshN Feb 3, 2026
bf3b69b
Add ROCm Python bindings and test skip list
NripeshN Feb 3, 2026
9af0755
Add MLX_API to rocm::is_available() for proper symbol export
NripeshN Feb 3, 2026
90377cc
Fix ROCm allocator to fall back to hipMalloc when managed memory fails
NripeshN Feb 3, 2026
b330ad1
Fix ROCm allocator to use hipHostMalloc when managed memory unavailable
NripeshN Feb 3, 2026
39b2926
Fix WARP_SIZE to be architecture-dependent for ROCm
NripeshN Feb 3, 2026
467fb00
Fix macro conflicts in WARP_SIZE and MAX_NDIM definitions
NripeshN Feb 3, 2026
4545bac
Fix WARP_SIZE_ROW namespace reference
NripeshN Feb 3, 2026
6e6d837
Fix MAX_NDIM macro reference in compiled.cpp
NripeshN Feb 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,6 @@ uv.lock
.cache/
# vim
*.swp

# keys
*.pem
37 changes: 35 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_BUILD_ROCM "Build rocm backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
Expand Down Expand Up @@ -158,6 +159,36 @@ if(MLX_BUILD_CUDA)
find_package(CUDNN REQUIRED)
endif()

if(MLX_BUILD_ROCM)
  # Select the GPU architectures the ROCm backend is compiled for. Precedence:
  #
  # 1. MLX_ROCM_ARCHITECTURES - project-specific override, always wins.
  # 2. A CMAKE_HIP_ARCHITECTURES value already provided by the user or cache.
  # 3. The built-in default list below.
  #
  # The default is only applied when nothing was provided, so that
  # -DCMAKE_HIP_ARCHITECTURES=<arch> on the command line is not overwritten.
  if(DEFINED MLX_ROCM_ARCHITECTURES)
    set(CMAKE_HIP_ARCHITECTURES
        "${MLX_ROCM_ARCHITECTURES}"
        CACHE STRING "HIP architectures" FORCE)
  elseif(NOT DEFINED CMAKE_HIP_ARCHITECTURES OR CMAKE_HIP_ARCHITECTURES
                                                STREQUAL "")
    set(CMAKE_HIP_ARCHITECTURES
        "gfx906;gfx908;gfx90a;gfx1030;gfx1100"
        CACHE STRING "HIP architectures" FORCE)
  endif()
  message(
    STATUS "Setting CMAKE_HIP_ARCHITECTURES to: ${CMAKE_HIP_ARCHITECTURES}")

  # Note: We don't enable_language(HIP) here because it causes CMake to add -x
  # hip to all CXX files in targets that link to HIP libraries. Instead, we
  # compile HIP files using custom commands in the ROCm backend CMakeLists.txt.
  #
  # Find the HIP compiler. An explicit ROCM_PATH (CMake variable or environment
  # variable) is searched first via HINTS; the hard-coded /opt/rocm locations
  # remain as last-resort guesses after the regular system search.
  find_program(
    CMAKE_HIP_COMPILER
    NAMES hipcc clang++
    HINTS ${ROCM_PATH} ENV ROCM_PATH
    PATHS /opt/rocm/bin /opt/rocm-6.0.0/bin /opt/rocm/llvm/bin
    PATH_SUFFIXES bin llvm/bin
    DOC "HIP compiler")
  if(NOT CMAKE_HIP_COMPILER)
    message(FATAL_ERROR "Could not find HIP compiler (hipcc or clang++)")
  endif()
  message(STATUS "Found HIP compiler: ${CMAKE_HIP_COMPILER}")
endif()

if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
Expand Down Expand Up @@ -286,10 +317,12 @@ if(MLX_BUILD_CPU)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
/usr/local/opt/openblas/include /usr/include/openblas)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
if(LAPACK_INCLUDE_DIRS)
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
endif()
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old
# version of lapack.h from the include dirs of blas.
Expand Down
11 changes: 10 additions & 1 deletion mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,16 @@ else()
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()

if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
# Without the ROCm backend only the no_rocm.cpp source is compiled into mlx;
# otherwise the backend/rocm subdirectory is added.
if(NOT MLX_BUILD_ROCM)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm)
endif()

if(MLX_BUILD_METAL
OR MLX_BUILD_CUDA
OR MLX_BUILD_ROCM)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
Expand Down
257 changes: 257 additions & 0 deletions mlx/backend/rocm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Filename rules in ROCm backend:
#
# * Use .hip/.hpp if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.

# Find ROCm packages
#
# Every dependency is looked up in CONFIG mode (package-provided config files
# only, no Find modules) and is REQUIRED, so configuration stops here if any of
# them is missing. The imported targets queried further down in this file
# (hip::device, hip::host, roc::rocblas, roc::rocthrust, roc::rocprim,
# hip::hiprand) are expected to come from these packages.
find_package(hip REQUIRED CONFIG)
find_package(rocblas REQUIRED CONFIG)
find_package(rocthrust REQUIRED CONFIG)
find_package(rocprim REQUIRED CONFIG)
find_package(hiprand REQUIRED CONFIG)

# Fall back to the default GPU target list only when no architecture list has
# been chosen yet; a non-empty CMAKE_HIP_ARCHITECTURES is left untouched.
if(NOT DEFINED CMAKE_HIP_ARCHITECTURES OR CMAKE_HIP_ARCHITECTURES STREQUAL "")
  set(CMAKE_HIP_ARCHITECTURES
      "gfx906;gfx908;gfx90a;gfx1030;gfx1100"
      CACHE STRING "HIP architectures" FORCE)
endif()
message(
  STATUS "ROCm backend using HIP architectures: ${CMAKE_HIP_ARCHITECTURES}")

# Turn the architecture list into one --offload-arch=<arch> compiler flag per
# entry.
set(HIP_ARCH_FLAGS "")
list(APPEND HIP_ARCH_FLAGS ${CMAKE_HIP_ARCHITECTURES})
list(TRANSFORM HIP_ARCH_FLAGS PREPEND "--offload-arch=")

# Get HIP include directories
#
# The include directories are read off the imported targets instead of linking
# the targets, because the HIP sources below are compiled by custom commands
# that need explicit -I flags. If a target does not set the property,
# get_target_property() stores "<VAR>-NOTFOUND" in the variable; the if(inc)
# guards in the loops that consume these variables skip such values.
get_target_property(HIP_DEVICE_INCLUDES hip::device
INTERFACE_INCLUDE_DIRECTORIES)
get_target_property(ROCTHRUST_INCLUDES roc::rocthrust
INTERFACE_INCLUDE_DIRECTORIES)
get_target_property(ROCPRIM_INCLUDES roc::rocprim INTERFACE_INCLUDE_DIRECTORIES)
get_target_property(HIPRAND_INCLUDES hip::hiprand INTERFACE_INCLUDE_DIRECTORIES)

# Find GCC installation for C++ standard library headers
# ROCm's clang needs to know where to find libstdc++ headers
#
# The host C++ compiler is asked where "include/c++" resolves to, and the parent
# directory of the answer is stored in GCC_CXX_INCLUDE_BASE.
# NOTE(review): this assumes CMAKE_CXX_COMPILER understands the GCC-style
# -print-file-name option; the exit status is not checked, so an unsupported
# compiler just leaves whatever it printed (possibly nothing) in the variable.
execute_process(
COMMAND ${CMAKE_CXX_COMPILER} -print-file-name=include/c++
OUTPUT_VARIABLE GCC_CXX_INCLUDE_BASE
OUTPUT_STRIP_TRAILING_WHITESPACE)
get_filename_component(GCC_CXX_INCLUDE_BASE "${GCC_CXX_INCLUDE_BASE}" DIRECTORY)

# Get GCC version for the target-specific include directory
# Only the leading run of digits reported by -dumpversion (the major version) is
# kept; it is used below to build versioned libstdc++ include paths.
# NOTE(review): with a non-GCC host compiler this is that compiler's version,
# not a GCC version - confirm the intended behaviour.
execute_process(
COMMAND ${CMAKE_CXX_COMPILER} -dumpversion
OUTPUT_VARIABLE GCC_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE)
string(REGEX MATCH "^[0-9]+" GCC_MAJOR_VERSION "${GCC_VERSION}")

# Build include flags - use PROJECT_SOURCE_DIR for correct path
set(HIP_INCLUDE_FLAGS "-I${PROJECT_SOURCE_DIR}")

# HIP_INCLUDE_DIRS may be empty or may hold several directories. Emit one -I
# flag per non-empty entry: a bare "-I" would swallow the next command line
# argument, and "-I<a>;<b>" would turn <b> into a stray positional argument.
foreach(hip_inc_dir IN LISTS HIP_INCLUDE_DIRS)
  if(hip_inc_dir)
    list(APPEND HIP_INCLUDE_FLAGS "-I${hip_inc_dir}")
  endif()
endforeach()

# Add C++ standard library include paths for HIP compiler
# NOTE(review): the x86_64-linux-gnu triplet is hard-coded here and below;
# other host architectures will not get their target-specific directory.
if(EXISTS "${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}")
  list(APPEND HIP_INCLUDE_FLAGS
       "-I${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}")
  list(
    APPEND HIP_INCLUDE_FLAGS
    "-I${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}/x86_64-linux-gnu")
  list(APPEND HIP_INCLUDE_FLAGS
       "-I${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}/backward")
endif()

# Also try to find system include directories
if(EXISTS "/usr/include/c++/${GCC_MAJOR_VERSION}")
  list(APPEND HIP_INCLUDE_FLAGS "-I/usr/include/c++/${GCC_MAJOR_VERSION}")
  list(APPEND HIP_INCLUDE_FLAGS
       "-I/usr/include/x86_64-linux-gnu/c++/${GCC_MAJOR_VERSION}")
  list(APPEND HIP_INCLUDE_FLAGS
       "-I/usr/include/c++/${GCC_MAJOR_VERSION}/backward")
endif()

# Add standard system include paths
list(APPEND HIP_INCLUDE_FLAGS "-I/usr/include/x86_64-linux-gnu")
list(APPEND HIP_INCLUDE_FLAGS "-I/usr/include")

# Append one -I flag for every include directory reported by the ROCm imported
# targets, in the order device, rocthrust, rocprim, hiprand. Entries that
# evaluate to false (empty strings, *-NOTFOUND placeholders) are skipped.
foreach(rocm_inc_dir IN LISTS HIP_DEVICE_INCLUDES ROCTHRUST_INCLUDES
                              ROCPRIM_INCLUDES HIPRAND_INCLUDES)
  if(rocm_inc_dir)
    list(APPEND HIP_INCLUDE_FLAGS "-I${rocm_inc_dir}")
  endif()
endforeach()

message(STATUS "HIP include flags: ${HIP_INCLUDE_FLAGS}")

# HIP source files, written relative to this directory and then expanded to
# absolute paths in a single step.
set(HIP_SOURCES
    event.hip
    arange.hip
    arg_reduce.hip
    binary.hip
    binary_two.hip
    copy.hip
    copy/copy_contiguous.hip
    copy/copy_general.hip
    copy/copy_general_input.hip
    copy/copy_general_dynamic.hip
    distributed.hip
    indexing.hip
    layer_norm.hip
    logsumexp.hip
    random.hip
    reduce.hip
    reduce/col_reduce.hip
    reduce/all_reduce.hip
    reduce/row_reduce.hip
    reduce/init_reduce.hip
    rms_norm.hip
    rope.hip
    scaled_dot_product_attention.hip
    scan.hip
    softmax.hip
    sort.hip
    ternary.hip
    unary.hip
    gemms/gemv.hip
    quantized/affine_quantize.hip
    quantized/fp_quantize.hip
    quantized/convert_fp8.hip
    quantized/qmm.hip)
list(TRANSFORM HIP_SOURCES PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")

# Create output directory for compiled objects
set(HIP_OBJ_DIR "${CMAKE_CURRENT_BINARY_DIR}/hip_objs")
file(MAKE_DIRECTORY ${HIP_OBJ_DIR})

# Header dependency tracking: the custom commands below would otherwise only
# re-run when the .hip file itself changes, leaving stale objects after a header
# edit. Where the generator supports DEPFILE (Ninja and Makefile generators with
# CMake >= 3.20) the compiler is asked to write a dependency file that the build
# tool consumes. CMP0116 is set to NEW locally so the Ninja generators rewrite
# the depfile paths correctly regardless of the project's policy settings.
set(hip_use_depfile OFF)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.20 AND CMAKE_GENERATOR MATCHES
                                                "Ninja|Makefiles")
  set(hip_use_depfile ON)
  cmake_policy(PUSH)
  cmake_policy(SET CMP0116 NEW)
endif()

# Compile each HIP file to an object file using custom commands. Relocatable
# device code (-fgpu-rdc) is not requested, so no device link step is needed.
set(HIP_OBJECTS "")
foreach(hip_src ${HIP_SOURCES})
  get_filename_component(hip_name ${hip_src} NAME_WE)
  get_filename_component(hip_dir ${hip_src} DIRECTORY)
  file(RELATIVE_PATH rel_dir ${CMAKE_CURRENT_SOURCE_DIR} ${hip_dir})

  # Mirror the source sub-directory inside the object directory. The check is
  # an explicit emptiness test so that a directory whose name happens to be a
  # CMake false constant (e.g. "off") is still handled as a sub-directory.
  if(NOT "${rel_dir}" STREQUAL "")
    set(obj_subdir "${HIP_OBJ_DIR}/${rel_dir}")
    file(MAKE_DIRECTORY ${obj_subdir})
    set(hip_obj "${obj_subdir}/${hip_name}.o")
  else()
    set(hip_obj "${HIP_OBJ_DIR}/${hip_name}.o")
  endif()

  # Extra compiler flags / add_custom_command arguments for the depfile; both
  # stay empty when DEPFILE is not usable with the current generator.
  set(hip_dep_flags "")
  set(hip_dep_args "")
  if(hip_use_depfile)
    set(hip_dep_flags -MD -MT ${hip_obj} -MF ${hip_obj}.d)
    set(hip_dep_args DEPFILE ${hip_obj}.d)
  endif()

  add_custom_command(
    OUTPUT ${hip_obj}
    COMMAND ${CMAKE_HIP_COMPILER} -c ${hip_src} -o ${hip_obj} -fPIC
            -DMLX_USE_ROCM ${HIP_ARCH_FLAGS} ${HIP_INCLUDE_FLAGS} -std=c++17
            ${hip_dep_flags}
    DEPENDS ${hip_src}
    ${hip_dep_args}
    COMMENT "Compiling HIP source ${hip_src}"
    VERBATIM)

  list(APPEND HIP_OBJECTS ${hip_obj})
endforeach()

if(hip_use_depfile)
  cmake_policy(POP)
endif()

# Create a custom target for all HIP objects
add_custom_target(mlx_hip_objects DEPENDS ${HIP_OBJECTS})

# Create static library from all objects (no device link needed without
# -fgpu-rdc)
#
# The archive is deleted before it is re-created: "ar rcs" on an existing
# archive only adds or replaces members, so objects belonging to sources that
# were removed from HIP_SOURCES would otherwise linger in the library.
# NOTE(review): "ar r" replaces members by base name, so two sources with the
# same file name in different sub-directories would overwrite each other.
set(HIP_STATIC_LIB "${CMAKE_CURRENT_BINARY_DIR}/libmlx_rocm_kernels.a")
add_custom_command(
  OUTPUT ${HIP_STATIC_LIB}
  COMMAND ${CMAKE_COMMAND} -E rm -f ${HIP_STATIC_LIB}
  COMMAND ${CMAKE_AR} rcs ${HIP_STATIC_LIB} ${HIP_OBJECTS}
  DEPENDS ${HIP_OBJECTS}
  COMMENT "Creating static library from HIP objects"
  VERBATIM)

add_custom_target(mlx_rocm_kernels_lib DEPENDS ${HIP_STATIC_LIB})

# Add C++ sources directly to mlx target
#
# These are the host-only .cpp files of the backend; they are compiled by the
# regular C++ compiler as part of mlx, unlike the .hip files above which go
# through the custom HIP compile commands.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp)

# Same MLX_USE_ROCM define that the HIP compile commands pass on their command
# line, so host and device translation units agree.
target_compile_definitions(mlx PRIVATE MLX_USE_ROCM)

# Make mlx depend on the HIP kernels library
# add_dependencies() only orders the build (the archive is produced before mlx
# is built); the archive itself is linked by path further below.
add_dependencies(mlx mlx_rocm_kernels_lib)

# Get the library paths from the imported targets (without propagating compile
# options)
#
# Lookup order for rocblas and hiprand: the target's IMPORTED_LOCATION, then
# its IMPORTED_LOCATION_RELEASE, and finally a plain find_library() search.
# NOTE(review): ROCM_PATH is not set anywhere in this file; when it is empty
# the first search path degrades to "/lib" - confirm it is provided by the
# caller or environment-derived elsewhere.
get_target_property(ROCBLAS_LIB roc::rocblas IMPORTED_LOCATION)
if(NOT ROCBLAS_LIB)
get_target_property(ROCBLAS_LIB roc::rocblas IMPORTED_LOCATION_RELEASE)
endif()
if(NOT ROCBLAS_LIB)
# Fallback to finding the library directly
find_library(ROCBLAS_LIB rocblas PATHS ${ROCM_PATH}/lib /opt/rocm/lib
/opt/rocm-6.0.0/lib)
endif()

# Same three-step lookup for hiprand.
get_target_property(HIPRAND_LIB hip::hiprand IMPORTED_LOCATION)
if(NOT HIPRAND_LIB)
get_target_property(HIPRAND_LIB hip::hiprand IMPORTED_LOCATION_RELEASE)
endif()
if(NOT HIPRAND_LIB)
find_library(HIPRAND_LIB hiprand PATHS ${ROCM_PATH}/lib /opt/rocm/lib
/opt/rocm-6.0.0/lib)
endif()

# Find amdhip64 library
# Located by file search only; no imported target is consulted here.
find_library(AMDHIP64_LIB amdhip64 PATHS ${ROCM_PATH}/lib /opt/rocm/lib
/opt/rocm-6.0.0/lib)

# Find hiprtc library (needed for JIT compilation)
find_library(HIPRTC_LIB hiprtc PATHS ${ROCM_PATH}/lib /opt/rocm/lib
/opt/rocm-6.0.0/lib)

# Report what was resolved; none of the results are validated in this block, so
# a missing library shows up here as a *-NOTFOUND value.
message(
STATUS
"ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}"
)

# Host C++ files are compiled with the AMD HIP platform define.
target_compile_definitions(mlx PRIVATE __HIP_PLATFORM_AMD__=1)

# ROCm headers for the host C++ files: first the include directories exported
# by the hip::host target (when it exports any), then HIP_INCLUDE_DIRS.
get_target_property(HIP_HOST_INCLUDES hip::host INTERFACE_INCLUDE_DIRECTORIES)
if(HIP_HOST_INCLUDES)
  target_include_directories(mlx PRIVATE ${HIP_HOST_INCLUDES})
endif()
target_include_directories(mlx PRIVATE ${HIP_INCLUDE_DIRS})

# Link the kernel archive and the ROCm libraries by file path. Linking the
# imported CMake targets instead would propagate compile options such as
# -x hip onto mlx.
target_link_libraries(
  mlx
  PRIVATE ${HIP_STATIC_LIB}
          ${AMDHIP64_LIB}
          ${ROCBLAS_LIB}
          ${HIPRAND_LIB}
          ${HIPRTC_LIB})
Loading