19 changes: 12 additions & 7 deletions mlx/backend/cuda/CMakeLists.txt
@@ -42,6 +42,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/fused_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
@@ -90,7 +91,8 @@ file(
GLOB MLX_JIT_SOURCES
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh"
"${CMAKE_CURRENT_SOURCE_DIR}/reduce/*.cuh")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
OUTPUT gen/cuda_jit_sources.h
@@ -210,13 +212,16 @@ install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
install(DIRECTORY ${cccl_SOURCE_DIR}/include/nv
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
# CUB and Thrust are needed to JIT compile unary -> reduction kernels.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cub
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
install(DIRECTORY ${cccl_SOURCE_DIR}/include/thrust
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

# The binary of C++ tests will not be installed so it can not find the CCCL
# headers, and we have to hard-code the path.
if(MLX_BUILD_TESTS)
target_compile_definitions(mlx
PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")
endif()
# JIT cannot find the CCCL headers when compiling unary -> reduction kernels, so
# we hard-code the path to the downloaded CCCL.
target_compile_definitions(mlx
PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")

# Use fixed version of NVTX.
FetchContent_Declare(
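A note on the MLX_CCCL_DIR definition above: its value is the path to the FetchContent checkout of CCCL, and the JIT compiler is expected to add it as an include search path when the installed CCCL headers are not available (e.g. when running the C++ tests from the build tree). The sketch below is illustrative only; the helper name is invented here and is not MLX's actual code, and only NVRTC's --include-path= option is a real flag.

// Illustrative sketch (not MLX's actual code): turning the MLX_CCCL_DIR
// compile definition into an NVRTC include path so that <cub/...> and
// <thrust/...> resolve when JIT-compiling fused reduction kernels.
#include <string>
#include <vector>

std::vector<std::string> cccl_include_options() {
  std::vector<std::string> options;
#ifdef MLX_CCCL_DIR
  // Fall back to the downloaded CCCL checkout when nothing has been installed.
  options.push_back(std::string("--include-path=") + MLX_CCCL_DIR);
#endif
  return options;
}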
8 changes: 8 additions & 0 deletions mlx/backend/cuda/jit_module.cpp
@@ -223,6 +223,7 @@ bool compiler_supports_device_sass(Device& device) {
}

#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
#define REDUCE_PREFIX "mlx/backend/cuda/reduce/"

constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh",
@@ -236,9 +237,13 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "unary_ops.cuh",
INCLUDE_PREFIX "ternary_ops.cuh",
INCLUDE_PREFIX "utils.cuh",
REDUCE_PREFIX "all_reduce.cuh",
REDUCE_PREFIX "reduce_ops.cuh",
REDUCE_PREFIX "reduce_utils.cuh",
};

#undef INCLUDE_PREFIX
#undef REDUCE_PREFIX

constexpr const char* g_headers[] = {
jit_source_atomic_ops,
@@ -252,6 +257,9 @@ constexpr const char* g_headers[] = {
jit_source_unary_ops,
jit_source_ternary_ops,
jit_source_utils,
jit_source_all_reduce,
jit_source_reduce_ops,
jit_source_reduce_utils,
};

void compile(
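The two arrays above are parallel: g_include_names[i] is the #include name seen by JIT-compiled source, and g_headers[i] is the embedded header text registered for it, so the three new REDUCE_PREFIX entries must be appended in the same order as the matching jit_source_* entries. A minimal sketch of that lookup, assuming the arrays exactly as declared above (the helper itself is hypothetical):

#include <cstddef>
#include <cstring>
#include <optional>

// Hypothetical helper, for illustration only: resolve a registered include
// name to its embedded header text. The two arrays must stay in lockstep.
inline std::optional<const char*> find_jit_header(const char* name) {
  constexpr size_t count = sizeof(g_include_names) / sizeof(g_include_names[0]);
  for (size_t i = 0; i < count; ++i) {
    if (std::strcmp(g_include_names[i], name) == 0) {
      return g_headers[i]; // same index as the matching include name
    }
  }
  return std::nullopt;
}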
33 changes: 33 additions & 0 deletions mlx/backend/cuda/reduce.cu
@@ -12,6 +12,39 @@ namespace mlx::core {

void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Reduce::eval_gpu");

if (has_fused_prefix()) {
array in = inputs[0];

auto& s = stream();
auto& encoder = cu::get_command_encoder(s);

if (in.size() == 0) {
init_reduce(encoder, in, out, reduce_type_);
return;
}

ReductionPlan plan = get_reduction_plan(in, axes_);

bool broadcasted = false;
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
broadcasted = in.strides(i) == 0;
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy = contiguous_copy_gpu(in, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}

fused_reduce(encoder, *this, inputs, out, axes_, plan, s);
return;
}

assert(inputs.size() == 1);
array in = inputs[0];

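The broadcast check in the fused path above walks the input axes in lockstep with the sorted reduction axes: an axis that is kept (not reduced) but has stride 0 means the input is a broadcast view, and in that case the input is first materialized with contiguous_copy_gpu and the plan recomputed. A standalone sketch of that check with illustrative values (the helper name is made up for this example):

#include <cstddef>
#include <vector>

// Hypothetical standalone version of the broadcast check in eval_gpu above.
// axes must be sorted ascending, mirroring axes_ in the original loop.
bool has_broadcast_kept_axis(
    const std::vector<size_t>& strides,
    const std::vector<int>& axes) {
  bool broadcasted = false;
  for (int i = 0, j = 0; i < static_cast<int>(strides.size()) && !broadcasted;
       i++) {
    if (j < static_cast<int>(axes.size()) && axes[j] == i) {
      j++; // this axis is reduced; its stride is irrelevant
    } else {
      broadcasted = strides[i] == 0; // kept axis with stride 0 => broadcast view
    }
  }
  return broadcasted;
}

// has_broadcast_kept_axis({0, 1}, {1}) == true:  a (4, 8) input broadcast
//   along axis 0 and reduced over axis 1 needs a contiguous copy first.
// has_broadcast_kept_axis({8, 1}, {1}) == false: a dense (4, 8) input does not.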
56 changes: 1 addition & 55 deletions mlx/backend/cuda/reduce/all_reduce.cu
@@ -1,64 +1,10 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/all_reduce.cuh"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;

auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;

T vals[N];
U accs[M];
accs[0] = init;

size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
size_t check = min(end, size);

size_t i = start;
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], cast_to<U>(vals[j]));
}
}

if (i < check) {
cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], cast_to<U>(vals[i]));
}
}

__shared__ U shared_accumulators[32];
block_reduce(block, warp, accs, shared_accumulators, op, init);

if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];
}
}

} // namespace cu

void all_reduce(
cu::CommandEncoder& encoder,
const array& in,
76 changes: 76 additions & 0 deletions mlx/backend/cuda/reduce/all_reduce.cuh
@@ -0,0 +1,76 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <typename T, typename U, typename ReduceOp, int N, typename PrefixOp>
__device__ void all_reduce_impl(
T* in,
U* out,
size_t block_step,
size_t size,
PrefixOp prefix) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;

auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

const U init = ReduceInit<ReduceOp, T>::value();
ReduceOp op;

T vals[N];
U accs[M];
accs[0] = init;

size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
size_t check = min(end, size);

size_t i = start;
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], cast_to<U>(prefix(vals[j])));
}
}

if (i < check) {
cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], cast_to<U>(prefix(vals[j])));
}
}

__shared__ U shared_accumulators[32];
block_reduce(block, warp, accs, shared_accumulators, op, init);

if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];
}
}

template <
typename T,
typename U,
typename ReduceOp,
int N = 4,
typename PrefixOp = Identity>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
all_reduce_impl<T, U, ReduceOp, N, PrefixOp>(
in, out, block_step, size, PrefixOp{});
}

} // namespace mlx::core::cu
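The substantive change from the old all_reduce.cu kernel is the PrefixOp template parameter: every loaded element is passed through prefix() before being accumulated, and the Identity default preserves the unfused behavior. A hedged sketch of a fused square-then-sum instantiation follows; the Square functor is invented for this example, and the reduction op name (written here as cu::Sum) is assumed to come from reduce_ops.cuh.

// Illustrative only: a hypothetical prefix functor for a fused square-and-sum.
struct Square {
  template <typename T>
  __device__ T operator()(T x) const {
    return x * x;
  }
};

// Host-side launch sketch (same grid setup as the unfused all_reduce; the
// exact reduction op name and signature come from reduce_ops.cuh):
//
//   cu::all_reduce<float, float, cu::Sum, 4, Square>
//       <<<num_blocks, block_dim, 0, stream>>>(in, out, block_step, size);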