From 23c4fc37176e6af8e4051f863896ee919bf0472e Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 31 Jan 2026 02:11:10 +0100 Subject: [PATCH 1/9] [WIP] fuse fusable into reduction --- mlx/compile.cpp | 73 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index ca5f069937..3c6e186cc5 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -801,6 +801,79 @@ void compile_fuse( if (global_cache.find(arr.id()) != global_cache.end()) { continue; } + // If current op is a reduction, we may want to fuse prefix ops + if (arr.has_primitive() && is_reduction(arr.primitive())) { + auto& reduction_input = arr.inputs()[0]; + Stream reduction_stream = arr.primitive().stream(); + const int max_prefix_depth = max_compile_depth - 1; // 1 for reduction + + std::vector prefix_tape; // + std::vector prefix_inputs; // + std::unordered_set visited; + + std::function collect_prefix; + collect_prefix = [&](const array& a, int depth) { + // Skip if already processed + if (visited.count(a.id())) { + return; + } + // Stop fusing if: + // depth limit exceeded + // non fusable primitive + // does not have primitive + // stream mismatch + // is a constant input + if (depth >= max_prefix_depth || !a.has_primitive() || + !is_fusable(a.primitive()) || + a.primitive().stream() != reduction_stream || + input_ids.count(a.id())) { + prefix_inputs.push_back(a); + visited.insert(a.id()); + return; + } + // Check if the input is used multiple times + auto pit = parents_map.find(a.id()); + if (pit != parents_map.end() && pit->second.size() > 1) { + prefix_inputs.push_back(a); + visited.insert(a.id()); + return; + } + visited.insert(a.id()); + for (auto& in : a.inputs()) { + collect_prefix(in, depth + 1); + } + prefix_tape.push_back(a); + }; + + collect_prefix(reduction_input, 0); + + // If there are operations that we can fuse + if (!prefix_tape.empty()) { + std::unordered_set constant_ids; + for (auto& in : prefix_inputs) { + if (in.size() == 1 && !in.has_primitive() && + input_ids.find(in.id()) == input_ids.end()) { + constant_ids.insert(in.id()); + } + } + + // Attach prefix to the Reduce primitive + auto& reduce = static_cast(arr.primitive()); + reduce.set_fused_prefix( + std::move(prefix_tape), + std::move(prefix_inputs), + std::move(constant_ids)); + + for (auto& p : reduce.prefix_tape()) { + global_cache.insert(p.id()); + } + } + + // Add the reduction to the new tape (with or without fused prefix) + new_tape.push_back(arr); + global_cache.insert(arr.id()); + continue; + } // Two pass recursion: // First pass: From 54a4043fbc55b41b73b8a44fc754a1eb0f6f6737 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 31 Jan 2026 22:08:56 +0100 Subject: [PATCH 2/9] move kernel code to .cuh --- mlx/backend/cuda/reduce/all_reduce.cu | 57 +----- mlx/backend/cuda/reduce/all_reduce.cuh | 66 +++++++ mlx/backend/cuda/reduce/col_reduce.cu | 243 +---------------------- mlx/backend/cuda/reduce/col_reduce.cuh | 254 +++++++++++++++++++++++++ mlx/backend/cuda/reduce/reduce_ops.cuh | 7 + mlx/backend/cuda/reduce/row_reduce.cu | 227 +--------------------- mlx/backend/cuda/reduce/row_reduce.cuh | 244 ++++++++++++++++++++++++ mlx/primitives.h | 30 +++ 8 files changed, 604 insertions(+), 524 deletions(-) create mode 100644 mlx/backend/cuda/reduce/all_reduce.cuh create mode 100644 mlx/backend/cuda/reduce/col_reduce.cuh create mode 100644 mlx/backend/cuda/reduce/row_reduce.cuh diff --git a/mlx/backend/cuda/reduce/all_reduce.cu 
b/mlx/backend/cuda/reduce/all_reduce.cu index 962e80d4f2..81932f4d6d 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -1,64 +1,9 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/reduce/reduce.cuh" - -#include -#include -#include +#include "mlx/backend/cuda/reduce/all_reduce.cuh" namespace mlx::core { -namespace cu { - -namespace cg = cooperative_groups; - -template -__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { - // TODO: Process multiple "rows" in each thread - constexpr int M = 1; - - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - const U init = cu::ReduceInit::value(); - ReduceOp op; - - T vals[N]; - U accs[M]; - accs[0] = init; - - size_t start = grid.block_rank() * block_step; - size_t end = start + block_step; - size_t check = min(end, size); - - size_t i = start; - for (; i + block.size() * N <= check; i += block.size() * N) { - cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); - for (int j = 0; j < N; j++) { - accs[0] = op(accs[0], cast_to(vals[j])); - } - } - - if (i < check) { - cub::LoadDirectBlocked( - block.thread_rank(), in + i, vals, check - i, cast_to(init)); - for (int i = 0; i < N; i++) { - accs[0] = op(accs[0], cast_to(vals[i])); - } - } - - __shared__ U shared_accumulators[32]; - block_reduce(block, warp, accs, shared_accumulators, op, init); - - if (block.thread_rank() == 0) { - out[grid.block_rank()] = accs[0]; - } -} - -} // namespace cu - void all_reduce( cu::CommandEncoder& encoder, const array& in, diff --git a/mlx/backend/cuda/reduce/all_reduce.cuh b/mlx/backend/cuda/reduce/all_reduce.cuh new file mode 100644 index 0000000000..f627622625 --- /dev/null +++ b/mlx/backend/cuda/reduce/all_reduce.cuh @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. 
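The header introduced below threads a PrefixOp functor through the reduce kernel: it is applied to each element as it is loaded, before the ReduceOp accumulates it, and it defaults to Identity (added to reduce_ops.cuh later in this patch) so existing call sites keep their behaviour. A minimal sketch of what a non-trivial prefix could look like, assuming a hypothetical Square functor that is not part of this patch:

// Hypothetical prefix functor: squares each element while it is being loaded.
struct Square {
  template <typename T>
  __device__ __forceinline__ T operator()(T x) const {
    // Runs in the load loop, before the accumulator sees the value.
    return x * x;
  }
};

// Hypothetical instantiation: a single-pass sum of squares.
// all_reduce<float, float, Sum, 4, Square>
//     <<<blocks, threads, 0, stream>>>(in, out, block_step, size);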
+ +#pragma once + +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" + +#include +#include +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename U, + typename ReduceOp, + int N = 4, + typename PrefixOp = Identity> +__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { + // TODO: Process multiple "rows" in each thread + constexpr int M = 1; + + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = ReduceInit::value(); + ReduceOp op; + PrefixOp prefix; + + T vals[N]; + U accs[M]; + accs[0] = init; + + size_t start = grid.block_rank() * block_step; + size_t end = start + block_step; + size_t check = min(end, size); + + size_t i = start; + for (; i + block.size() * N <= check; i += block.size() * N) { + cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], cast_to(prefix(vals[j]))); + } + } + + if (i < check) { + cub::LoadDirectBlocked( + block.thread_rank(), in + i, vals, check - i, cast_to(init)); + for (int i = 0; i < N; i++) { + accs[0] = op(accs[0], cast_to(prefix(vals[i]))); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[grid.block_rank()] = accs[0]; + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index e33551d86e..bbeb226bf1 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,250 +1,9 @@ // Copyright © 2025 Apple Inc. -#include - -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/reduce/reduce.cuh" - -#include -#include -#include -#include +#include "mlx/backend/cuda/reduce/col_reduce.cuh" namespace mlx::core { -namespace cu { - -namespace cg = cooperative_groups; - -struct ColReduceArgs { - // The size of the contiguous column reduction. - size_t reduction_size; - int64_t reduction_stride; - - // Input shape and strides excluding the reduction axes. - Shape shape; - Strides strides; - int ndim; - - // Input shape and strides of the reduction axes (including last dimension). - Shape reduce_shape; - Strides reduce_strides; - int reduce_ndim; - - // The number of column we are reducing. Namely prod(reduce_shape). 
- size_t non_col_reductions; - - ColReduceArgs( - const array& in, - const ReductionPlan& plan, - const std::vector& axes) { - using ShapeVector = decltype(plan.shape); - using StridesVector = decltype(plan.strides); - - ShapeVector shape_vec; - StridesVector strides_vec; - - assert(!plan.shape.empty()); - reduction_size = plan.shape.back(); - reduction_stride = plan.strides.back(); - - int64_t stride_back = 1; - std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); - while (!shape_vec.empty() && stride_back < reduction_stride) { - stride_back *= shape_vec.back(); - shape_vec.pop_back(); - strides_vec.pop_back(); - } - std::vector indices(shape_vec.size()); - std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), [&](int left, int right) { - return strides_vec[left] > strides_vec[right]; - }); - ShapeVector sorted_shape; - StridesVector sorted_strides; - for (auto idx : indices) { - sorted_shape.push_back(shape_vec[idx]); - sorted_strides.push_back(strides_vec[idx]); - } - std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(sorted_shape, sorted_strides); - shape = const_param(shape_vec); - strides = const_param(strides_vec); - ndim = shape_vec.size(); - - reduce_shape = const_param(plan.shape); - reduce_strides = const_param(plan.strides); - reduce_ndim = plan.shape.size(); - - non_col_reductions = 1; - for (int i = 0; i < reduce_ndim - 1; i++) { - non_col_reductions *= reduce_shape[i]; - } - } -}; - -template < - typename T, - typename U, - typename Op, - int NDIM, - int BM, - int BN, - int N_READS = 4, - int BLOCKS = 1> -__global__ void col_reduce_looped( - T* in, - U* out, - const __grid_constant__ ColReduceArgs args, - int64_t out_size) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - constexpr int threads_per_row = BN / N_READS; - - // Compute the indices for the tile - size_t tile_idx = grid.block_rank(); - size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); - size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); - size_t tile_out = tile_y / out_size; - tile_y = tile_y % out_size; - - // Compute the indices for the thread within the tile - short thread_x = block.thread_rank() % threads_per_row; - short thread_y = block.thread_rank() / threads_per_row; - - // Move the input pointer - in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + - tile_x * BN; - - // Initialize the running totals - Op op; - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = ReduceInit::value(); - } - - size_t total = args.non_col_reductions * args.reduction_size; - size_t per_block, start, end; - if constexpr (BLOCKS > 1) { - per_block = (total + BLOCKS - 1) / BLOCKS; - start = tile_out * per_block + thread_y; - end = min((tile_out + 1) * per_block, total); - } else { - per_block = total; - start = thread_y; - end = total; - } - - LoopedElemToLoc 2)> loop(args.reduce_ndim); - loop.next(start, args.reduce_shape.data(), args.reduce_strides.data()); - if (tile_x * BN + BN <= args.reduction_stride) { - if (args.reduction_stride % N_READS == 0) { - for (size_t r = start; r < end; r += BM) { - T vals[N_READS]; - cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], cast_to(vals[i])); - } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); - } - } else { - for (size_t r = start; r < end; r += BM) { - T 
vals[N_READS]; - cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], cast_to(vals[i])); - } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); - } - } - } else { - for (size_t r = start; r < end; r += BM) { - T vals[N_READS]; - cub::LoadDirectBlocked( - thread_x, - in + loop.location(), - vals, - args.reduction_stride - tile_x * BN, - cast_to(ReduceInit::value())); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], cast_to(vals[i])); - } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); - } - } - - // Do warp reduce for each output. - constexpr int n_outputs = BN / threads_per_row; - static_assert(BM == 32 && n_outputs == N_READS); - __shared__ U shared_vals[BM * BN]; - short s_idx = thread_y * BN + thread_x * N_READS; - for (int i = 0; i < N_READS; i++) { - shared_vals[s_idx + i] = totals[i]; - } - block.sync(); - s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; - for (int i = 0; i < n_outputs; i++) { - totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); - } - - // Write result. - if (warp.thread_rank() == 0) { - if (BLOCKS > 1) { - out += tile_out * out_size * args.reduction_stride; - } - cub::StoreDirectBlocked( - warp.meta_group_rank(), - out + tile_y * args.reduction_stride + tile_x * BN, - totals, - args.reduction_stride - tile_x * BN); - } -} - -template -__global__ void col_reduce_small( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args, - size_t total) { - Op op; - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - - const auto idx = grid.thread_rank() * N_READS; - const auto before_axis = idx / args.reduction_stride; - const auto after_axis = idx % args.reduction_stride; - const auto offset = - before_axis * args.reduction_stride * args.reduction_size + after_axis; - - if (idx >= total) { - return; - } - - in += offset; - out += idx; - - AlignedVector accumulator; - for (int i = 0; i < N_READS; i++) { - accumulator[i] = ReduceInit::value(); - } - - for (int i = 0; i < args.reduction_size; i++) { - auto values = load_vector(in, 0); - - for (int j = 0; j < N_READS; j++) { - accumulator[j] = op(accumulator[j], cast_to(values[j])); - } - - in += args.reduction_stride; - } - - store_vector(out, 0, accumulator); -} - -} // namespace cu - inline auto output_grid_for_col_reduce( const array& out, const cu::ColReduceArgs& args, diff --git a/mlx/backend/cuda/reduce/col_reduce.cuh b/mlx/backend/cuda/reduce/col_reduce.cuh new file mode 100644 index 0000000000..ba04e3df15 --- /dev/null +++ b/mlx/backend/cuda/reduce/col_reduce.cuh @@ -0,0 +1,254 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" + +#include +#include +#include +#include + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). 
+ size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + using ShapeVector = decltype(plan.shape); + using StridesVector = decltype(plan.strides); + + ShapeVector shape_vec; + StridesVector strides_vec; + + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + ShapeVector sorted_shape; + StridesVector sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4, + int BLOCKS = 1, + typename PrefixOp = Identity> +__global__ void col_reduce_looped( + T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + int64_t out_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + constexpr int threads_per_row = BN / N_READS; + + // Compute the indices for the tile + size_t tile_idx = grid.block_rank(); + size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); + size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); + size_t tile_out = tile_y / out_size; + tile_y = tile_y % out_size; + + // Compute the indices for the thread within the tile + short thread_x = block.thread_rank() % threads_per_row; + short thread_y = block.thread_rank() / threads_per_row; + + // Move the input pointer + in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + + tile_x * BN; + + // Initialize the running totals + Op op; + PrefixOp prefix; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + size_t total = args.non_col_reductions * args.reduction_size; + size_t per_block, start, end; + if constexpr (BLOCKS > 1) { + per_block = (total + BLOCKS - 1) / BLOCKS; + start = tile_out * per_block + thread_y; + end = min((tile_out + 1) * per_block, total); + } else { + per_block = total; + start = thread_y; + end = total; + } + + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(start, args.reduce_shape.data(), args.reduce_strides.data()); + if (tile_x * BN + BN <= args.reduction_stride) { + if (args.reduction_stride % N_READS == 0) { + for (size_t r = start; r < end; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(prefix(vals[i]))); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } 
else { + for (size_t r = start; r < end; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(prefix(vals[i]))); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + } else { + for (size_t r = start; r < end; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked( + thread_x, + in + loop.location(), + vals, + args.reduction_stride - tile_x * BN, + cast_to(ReduceInit::value())); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(prefix(vals[i]))); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / threads_per_row; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + short s_idx = thread_y * BN + thread_x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[s_idx + i] = totals[i]; + } + block.sync(); + s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); + } + + // Write result. + if (warp.thread_rank() == 0) { + if (BLOCKS > 1) { + out += tile_out * out_size * args.reduction_stride; + } + cub::StoreDirectBlocked( + warp.meta_group_rank(), + out + tile_y * args.reduction_stride + tile_x * BN, + totals, + args.reduction_stride - tile_x * BN); + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS = 4, + typename PrefixOp = Identity> +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + size_t total) { + Op op; + PrefixOp prefix; + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + const auto idx = grid.thread_rank() * N_READS; + const auto before_axis = idx / args.reduction_stride; + const auto after_axis = idx % args.reduction_stride; + const auto offset = + before_axis * args.reduction_stride * args.reduction_size + after_axis; + + if (idx >= total) { + return; + } + + in += offset; + out += idx; + + AlignedVector accumulator; + for (int i = 0; i < N_READS; i++) { + accumulator[i] = ReduceInit::value(); + } + + for (int i = 0; i < args.reduction_size; i++) { + auto values = load_vector(in, 0); + + for (int j = 0; j < N_READS; j++) { + accumulator[j] = op(accumulator[j], cast_to(prefix(values[j]))); + } + + in += args.reduction_stride; + } + + store_vector(out, 0, accumulator); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 6c6b1827ce..a8302dded4 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -9,6 +9,13 @@ namespace mlx::core::cu { +struct Identity { + template + __device__ __forceinline__ T operator()(T x) const { + return x; + } +}; + // Reduce ops. struct And { __device__ __forceinline__ bool operator()(bool a, bool b) { diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index ea99e11325..55441e8e23 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,234 +1,9 @@ // Copyright © 2025 Apple Inc. 
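The Identity functor added to reduce_ops.cuh above is the no-op default for PrefixOp, so prefix(vals[i]) keeps the original accumulate step when nothing is fused. Moving the kernel bodies into .cuh headers also prepares them for NVRTC JIT compilation in the later patches of this series, which wrap host-only pieces in __CUDACC_RTC__ guards. A minimal sketch of that guard pattern, assuming NVRTC's predefined __CUDACC_RTC__ macro:

#ifndef __CUDACC_RTC__
// Host-only material (array, ReductionPlan, std headers) is seen by NVCC but
// skipped when the header is JIT-compiled as device-only source.
#include "mlx/backend/cuda/reduce/reduce.cuh"
#endif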
-#include - -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/reduce/reduce.cuh" - -#include -#include +#include "mlx/backend/cuda/reduce/row_reduce.cuh" namespace mlx::core { -namespace cu { - -namespace cg = cooperative_groups; - -struct RowReduceArgs { - // The size of the row being reduced, i.e. the size of last dimension. - int row_size; - - // Input shape and strides excluding the reduction axes. - Shape shape; - Strides strides; - int ndim; - - // Input shape and strides of the reduction axes excluding last dimension. - Shape reduce_shape; - Strides reduce_strides; - int reduce_ndim; - - // The number of rows we are reducing. Namely prod(reduce_shape). - size_t non_row_reductions; - - RowReduceArgs( - const array& in, - const ReductionPlan& plan, - const std::vector& axes) { - assert(!plan.shape.empty()); - row_size = plan.shape.back(); - - auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); - std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(shape_vec, strides_vec); - shape = const_param(shape_vec); - strides = const_param(strides_vec); - ndim = shape_vec.size(); - - reduce_shape = const_param(plan.shape); - reduce_strides = const_param(plan.strides); - reduce_ndim = plan.shape.size() - 1; - - non_row_reductions = 1; - for (int i = 0; i < reduce_ndim; i++) { - non_row_reductions *= reduce_shape[i]; - } - } - - // Convert shape and strides as if in was contiguous - void sort_access_pattern(const array& in, const std::vector& axes) { - auto shape_vec = in.shape(); - auto strides_vec = in.strides(); - std::tie(shape_vec, strides_vec) = - shapes_without_reduction_axes(shape_vec, strides_vec, axes); - std::vector indices(shape_vec.size()); - std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), [&](int left, int right) { - return strides_vec[left] > strides_vec[right]; - }); - decltype(shape_vec) sorted_shape; - decltype(strides_vec) sorted_strides; - for (auto idx : indices) { - sorted_shape.push_back(shape_vec[idx]); - sorted_strides.push_back(strides_vec[idx]); - } - std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(sorted_shape, sorted_strides); - shape = const_param(shape_vec); - strides = const_param(strides_vec); - ndim = shape_vec.size(); - } -}; - -template -__global__ void -row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - const U init = cu::ReduceInit::value(); - ReduceOp op; - - AlignedVector vals[M]; - AlignedVector accs; - for (int i = 0; i < M; i++) { - accs[i] = init; - } - - const size_t start_row = - min(n_rows - M, static_cast(grid.block_rank() * M)); - const size_t full_blocks = size / (block.size() * N); - const size_t final_offset = full_blocks * (block.size() * N); - in += start_row * size + block.thread_rank() * N; - out += start_row; - - for (size_t r = 0; r < full_blocks; r++) { - for (int k = 0; k < M; k++) { - vals[k] = load_vector(in + k * size, 0); - } - for (int k = 0; k < M; k++) { - for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], cast_to(vals[k][j])); - } - } - - in += block.size() * N; - } - - if (final_offset < size) { - for (int k = 0; k < M; k++) { - for (int i = 0; i < N; i++) { - vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size) - ? 
in[k * size + i] - : cast_to(init); - } - } - for (int k = 0; k < M; k++) { - for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], cast_to(vals[k][j])); - } - } - } - - __shared__ U shared_accumulators[32 * M]; - block_reduce(block, warp, accs.val, shared_accumulators, op, init); - - if (block.thread_rank() == 0) { - if (grid.block_rank() * M + M <= n_rows) { - store_vector(out, 0, accs); - } else { - short offset = grid.block_rank() * M + M - n_rows; - for (int i = offset; i < M; i++) { - out[i] = accs[i]; - } - } - } -} - -template -__global__ void row_reduce_looped( - const T* in, - U* out, - const __grid_constant__ RowReduceArgs args) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - size_t out_idx = grid.block_rank(); - - Op op; - - U total[1]; - U init = ReduceInit::value(); - total[0] = init; - LoopedElemToLoc 2)> loop(args.reduce_ndim); - const size_t full_blocks = args.row_size / (block.size() * N_READS); - const size_t final_offset = full_blocks * (block.size() * N_READS); - - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - in += block.thread_rank() * N_READS; - - // Unaligned reduce - if (final_offset < args.row_size) { - bool mask[N_READS]; - for (int i = 0; i < N_READS; i++) { - mask[i] = - (final_offset + block.thread_rank() * N_READS + i) < args.row_size; - } - - for (size_t n = 0; n < args.non_row_reductions; n++) { - const T* inlocal = in + loop.location(); - - for (size_t r = 0; r < full_blocks; r++) { - auto vals = load_vector(inlocal, 0); - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(vals[i])); - } - inlocal += block.size() * N_READS; - } - - { - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = mask[i] ? inlocal[i] : cast_to(init); - } - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(vals[i])); - } - } - - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); - } - } - - // Aligned case - else { - for (size_t n = 0; n < args.non_row_reductions; n++) { - const T* inlocal = in + loop.location(); - - for (size_t r = 0; r < full_blocks; r++) { - auto vals = load_vector(inlocal, 0); - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(vals[i])); - } - inlocal += block.size() * N_READS; - } - - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); - } - } - - __shared__ U shared_accumulators[32]; - block_reduce(block, warp, total, shared_accumulators, op, init); - - if (block.thread_rank() == 0) { - out[out_idx] = total[0]; - } -} - -} // namespace cu - void row_reduce_simple( cu::CommandEncoder& encoder, const array& in, diff --git a/mlx/backend/cuda/reduce/row_reduce.cuh b/mlx/backend/cuda/reduce/row_reduce.cuh new file mode 100644 index 0000000000..e77cba44ba --- /dev/null +++ b/mlx/backend/cuda/reduce/row_reduce.cuh @@ -0,0 +1,244 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" + +#include +#include + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +struct RowReduceArgs { + // The size of the row being reduced, i.e. the size of last dimension. + int row_size; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes excluding last dimension. 
+ Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). + size_t non_row_reductions; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + row_size = plan.shape.back(); + + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size() - 1; + + non_row_reductions = 1; + for (int i = 0; i < reduce_ndim; i++) { + non_row_reductions *= reduce_shape[i]; + } + } + + // Convert shape and strides as if in was contiguous + void sort_access_pattern(const array& in, const std::vector& axes) { + auto shape_vec = in.shape(); + auto strides_vec = in.strides(); + std::tie(shape_vec, strides_vec) = + shapes_without_reduction_axes(shape_vec, strides_vec, axes); + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + decltype(shape_vec) sorted_shape; + decltype(strides_vec) sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + } +}; + +template < + typename T, + typename U, + typename ReduceOp, + int N = 4, + int M = 1, + typename PrefixOp = Identity> +__global__ void +row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = ReduceInit::value(); + ReduceOp op; + PrefixOp prefix; + + AlignedVector vals[M]; + AlignedVector accs; + for (int i = 0; i < M; i++) { + accs[i] = init; + } + + const size_t start_row = + min(n_rows - M, static_cast(grid.block_rank() * M)); + const size_t full_blocks = size / (block.size() * N); + const size_t final_offset = full_blocks * (block.size() * N); + in += start_row * size + block.thread_rank() * N; + out += start_row; + + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + vals[k] = load_vector(in + k * size, 0); + } + for (int k = 0; k < M; k++) { + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(prefix(vals[k][j]))); + } + } + + in += block.size() * N; + } + + if (final_offset < size) { + for (int k = 0; k < M; k++) { + for (int i = 0; i < N; i++) { + vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size) + ? 
in[k * size + i] + : cast_to(init); + } + } + for (int k = 0; k < M; k++) { + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(prefix(vals[k][j]))); + } + } + } + + __shared__ U shared_accumulators[32 * M]; + block_reduce(block, warp, accs.val, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + if (grid.block_rank() * M + M <= n_rows) { + store_vector(out, 0, accs); + } else { + short offset = grid.block_rank() * M + M - n_rows; + for (int i = offset; i < M; i++) { + out[i] = accs[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int N_READS = 4, + typename PrefixOp = Identity> +__global__ void row_reduce_looped( + const T* in, + U* out, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + size_t out_idx = grid.block_rank(); + + Op op; + PrefixOp prefix; + + U total[1]; + U init = ReduceInit::value(); + total[0] = init; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + const size_t full_blocks = args.row_size / (block.size() * N_READS); + const size_t final_offset = full_blocks * (block.size() * N_READS); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + in += block.thread_rank() * N_READS; + + // Unaligned reduce + if (final_offset < args.row_size) { + bool mask[N_READS]; + for (int i = 0; i < N_READS; i++) { + mask[i] = + (final_offset + block.thread_rank() * N_READS + i) < args.row_size; + } + + for (size_t n = 0; n < args.non_row_reductions; n++) { + const T* inlocal = in + loop.location(); + + for (size_t r = 0; r < full_blocks; r++) { + auto vals = load_vector(inlocal, 0); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(prefix(vals[i]))); + } + inlocal += block.size() * N_READS; + } + + { + T vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = mask[i] ? 
inlocal[i] : cast_to(init); + } + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(prefix(vals[i]))); + } + } + + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + // Aligned case + else { + for (size_t n = 0; n < args.non_row_reductions; n++) { + const T* inlocal = in + loop.location(); + + for (size_t r = 0; r < full_blocks; r++) { + auto vals = load_vector(inlocal, 0); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(prefix(vals[i]))); + } + inlocal += block.size() * N_READS; + } + + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, total, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[out_idx] = total[0]; + } +} + +} // namespace mlx::core::cu diff --git a/mlx/primitives.h b/mlx/primitives.h index 4091aafcfb..cfb96173dc 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1809,9 +1809,39 @@ class MLX_API Reduce : public UnaryPrimitive { return {reduce_type_, axes_}; }; + void set_fused_prefix( + std::vector tape, + std::vector inputs, + std::unordered_set constant_ids) { + prefix_tape_ = std::move(tape); + prefix_inputs_ = std::move(inputs); + prefix_constant_ids_ = std::move(constant_ids); + } + + bool has_fused_prefix() const { + return !prefix_tape_.empty(); + } + + const std::vector& prefix_tape() const { + return prefix_tape_; + } + + const std::vector& prefix_inputs() const { + return prefix_inputs_; + } + + const std::unordered_set& prefix_constant_ids() const { + return prefix_constant_ids_; + } + private: ReduceType reduce_type_; std::vector axes_; + + // Fused prefix storage + std::vector prefix_tape_; + std::vector prefix_inputs_; + std::unordered_set prefix_constant_ids_; }; class Round : public UnaryPrimitive { From c10ef0ef9b4dc9f4b2e7eac720cc3553a4a42aa3 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sun, 1 Feb 2026 01:11:36 +0100 Subject: [PATCH 3/9] fix --- mlx/backend/cuda/CMakeLists.txt | 4 +- mlx/backend/cuda/reduce.cu | 5 ++ mlx/backend/cuda/reduce/all_reduce.cuh | 37 ++++++++--- mlx/backend/cuda/reduce/col_reduce.cuh | 87 ++++++++++++++++++++---- mlx/backend/cuda/reduce/reduce.cuh | 8 +++ mlx/backend/cuda/reduce/row_reduce.cuh | 91 ++++++++++++++++++++++---- 6 files changed, 199 insertions(+), 33 deletions(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 013b24b2f4..9411885c22 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -42,6 +42,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/fused_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu @@ -90,7 +91,8 @@ file( GLOB MLX_JIT_SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h" - "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh") + "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh" + "${CMAKE_CURRENT_SOURCE_DIR}/reduce/*.cuh") string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES}) add_custom_command( OUTPUT gen/cuda_jit_sources.h diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu index 269efc034b..d8cc4cf5d7 100644 --- a/mlx/backend/cuda/reduce.cu +++ b/mlx/backend/cuda/reduce.cu @@ -51,6 +51,11 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { 
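The hunk below short-circuits Reduce::eval_gpu into fused_reduce() whenever compile_fuse() (patch 1) attached a prefix tape to the primitive. A rough sketch of a graph that qualifies, assuming it is traced through mx::compile; the names here are illustrative, not taken from the patch:

#include <vector>

#include "mlx/mlx.h"

namespace mx = mlx::core;

// Exp is fusable, is consumed only by the Sum, and runs on the reduction's
// stream, so collect_prefix() records it in prefix_tape and x in prefix_inputs;
// the fused path below then evaluates exp and sum in a single kernel.
std::vector<mx::array> fused_sum(const std::vector<mx::array>& xs) {
  return {mx::sum(mx::exp(xs[0]))};
}
// auto compiled = mx::compile(fused_sum);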
plan = get_reduction_plan(in, axes_); } + if (has_fused_prefix()) { + fused_reduce(encoder, *this, in, out, axes_, plan); + return; + } + if (plan.type == ContiguousAllReduce) { all_reduce(encoder, in, out, reduce_type_); return; diff --git a/mlx/backend/cuda/reduce/all_reduce.cuh b/mlx/backend/cuda/reduce/all_reduce.cuh index f627622625..a0dba31c9b 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cuh +++ b/mlx/backend/cuda/reduce/all_reduce.cuh @@ -13,13 +13,13 @@ namespace mlx::core::cu { namespace cg = cooperative_groups; -template < - typename T, - typename U, - typename ReduceOp, - int N = 4, - typename PrefixOp = Identity> -__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { +template +__device__ void all_reduce_impl( + T* in, + U* out, + size_t block_step, + size_t size, + PrefixOp prefix) { // TODO: Process multiple "rows" in each thread constexpr int M = 1; @@ -29,7 +29,6 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { const U init = ReduceInit::value(); ReduceOp op; - PrefixOp prefix; T vals[N]; U accs[M]; @@ -50,8 +49,8 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { if (i < check) { cub::LoadDirectBlocked( block.thread_rank(), in + i, vals, check - i, cast_to(init)); - for (int i = 0; i < N; i++) { - accs[0] = op(accs[0], cast_to(prefix(vals[i]))); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], cast_to(prefix(vals[j]))); } } @@ -63,4 +62,22 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { } } +template +__global__ void +all_reduce(T* in, U* out, size_t block_step, size_t size, PrefixOp prefix) { + all_reduce_impl( + in, out, block_step, size, prefix); +} + +template < + typename T, + typename U, + typename ReduceOp, + int N = 4, + typename PrefixOp = Identity> +__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { + all_reduce_impl( + in, out, block_step, size, PrefixOp{}); +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/col_reduce.cuh b/mlx/backend/cuda/reduce/col_reduce.cuh index ba04e3df15..24ab20b0e2 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cuh +++ b/mlx/backend/cuda/reduce/col_reduce.cuh @@ -90,14 +90,15 @@ template < int NDIM, int BM, int BN, - int N_READS = 4, - int BLOCKS = 1, - typename PrefixOp = Identity> -__global__ void col_reduce_looped( + int N_READS, + int BLOCKS, + typename PrefixOp> +__device__ void col_reduce_looped_impl( T* in, U* out, - const __grid_constant__ ColReduceArgs args, - int64_t out_size) { + const ColReduceArgs& args, + int64_t out_size, + PrefixOp prefix) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); @@ -121,7 +122,6 @@ __global__ void col_reduce_looped( // Initialize the running totals Op op; - PrefixOp prefix; U totals[N_READS]; for (int i = 0; i < N_READS; i++) { totals[i] = ReduceInit::value(); @@ -204,19 +204,56 @@ __global__ void col_reduce_looped( } } +// Kernel with prefix parameter +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS, + int BLOCKS, + typename PrefixOp> +__global__ void col_reduce_looped( + T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + int64_t out_size, + PrefixOp prefix) { + col_reduce_looped_impl( + in, out, args, out_size, prefix); +} + +// Kernel without prefix parameter (default Identity) template < typename T, typename U, typename Op, + int NDIM, + int BM, + int BN, int N_READS = 4, + int BLOCKS = 1, typename 
PrefixOp = Identity> -__global__ void col_reduce_small( - const T* in, +__global__ void col_reduce_looped( + T* in, U* out, const __grid_constant__ ColReduceArgs args, - size_t total) { + int64_t out_size) { + col_reduce_looped_impl( + in, out, args, out_size, PrefixOp{}); +} + +// Device function for col_reduce_small +template +__device__ void col_reduce_small_impl( + const T* in, + U* out, + const ColReduceArgs& args, + size_t total, + PrefixOp prefix) { Op op; - PrefixOp prefix; auto grid = cg::this_grid(); auto block = cg::this_thread_block(); @@ -251,4 +288,32 @@ __global__ void col_reduce_small( store_vector(out, 0, accumulator); } +// Kernel with prefix parameter +template +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + size_t total, + PrefixOp prefix) { + col_reduce_small_impl( + in, out, args, total, prefix); +} + +// Kernel without prefix parameter (default Identity) +template < + typename T, + typename U, + typename Op, + int N_READS = 4, + typename PrefixOp = Identity> +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + size_t total) { + col_reduce_small_impl( + in, out, args, total, PrefixOp{}); +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 02e495594a..788e4b97b0 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -68,4 +68,12 @@ void init_reduce( array& out, Reduce::ReduceType reduce_type); +void fused_reduce( + cu::CommandEncoder& encoder, + const Reduce& reduce, + const array& in, + array& out, + const std::vector& axes, + const ReductionPlan& plan); + } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/row_reduce.cuh b/mlx/backend/cuda/reduce/row_reduce.cuh index e77cba44ba..e29a2ce51d 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cuh +++ b/mlx/backend/cuda/reduce/row_reduce.cuh @@ -84,18 +84,21 @@ template < typename T, typename U, typename ReduceOp, - int N = 4, - int M = 1, - typename PrefixOp = Identity> -__global__ void -row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { + int N, + int M, + typename PrefixOp> +__device__ void row_reduce_simple_impl( + const T* in, + U* out, + size_t n_rows, + int size, + PrefixOp prefix) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); const U init = ReduceInit::value(); ReduceOp op; - PrefixOp prefix; AlignedVector vals[M]; AlignedVector accs; @@ -153,17 +156,51 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { } } +// Kernel with prefix parameter +template < + typename T, + typename U, + typename ReduceOp, + int N, + int M, + typename PrefixOp> +__global__ void row_reduce_simple( + const T* in, + U* out, + size_t n_rows, + int size, + PrefixOp prefix) { + row_reduce_simple_impl( + in, out, n_rows, size, prefix); +} + +// Kernel without prefix parameter (default Identity) +template < + typename T, + typename U, + typename ReduceOp, + int N = 4, + int M = 1, + typename PrefixOp = Identity> +__global__ void +row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { + row_reduce_simple_impl( + in, out, n_rows, size, PrefixOp{}); +} + +// Device function for row_reduce_looped template < typename T, typename U, typename Op, int NDIM, - int N_READS = 4, - typename PrefixOp = Identity> -__global__ void row_reduce_looped( + int N_READS, + typename PrefixOp> +__device__ void row_reduce_looped_impl( 
const T* in, U* out, - const __grid_constant__ RowReduceArgs args) { + const RowReduceArgs& args, + PrefixOp prefix) { auto grid = cg::this_grid(); auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); @@ -171,7 +208,6 @@ __global__ void row_reduce_looped( size_t out_idx = grid.block_rank(); Op op; - PrefixOp prefix; U total[1]; U init = ReduceInit::value(); @@ -241,4 +277,37 @@ __global__ void row_reduce_looped( } } +// Kernel with prefix parameter +template < + typename T, + typename U, + typename Op, + int NDIM, + int N_READS, + typename PrefixOp> +__global__ void row_reduce_looped( + const T* in, + U* out, + const __grid_constant__ RowReduceArgs args, + PrefixOp prefix) { + row_reduce_looped_impl( + in, out, args, prefix); +} + +// Kernel without prefix parameter (default Identity) +template < + typename T, + typename U, + typename Op, + int NDIM, + int N_READS = 4, + typename PrefixOp = Identity> +__global__ void row_reduce_looped( + const T* in, + U* out, + const __grid_constant__ RowReduceArgs args) { + row_reduce_looped_impl( + in, out, args, PrefixOp{}); +} + } // namespace mlx::core::cu From b813382bcfac894286b679e7426b95ed5ec64546 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sun, 1 Feb 2026 15:20:07 +0100 Subject: [PATCH 4/9] fix --- mlx/backend/cuda/reduce.cu | 38 ++++++++++++++++++++---- mlx/backend/cuda/reduce/all_reduce.cu | 1 + mlx/backend/cuda/reduce/all_reduce.cuh | 2 +- mlx/backend/cuda/reduce/reduce.cuh | 5 ++-- mlx/backend/cuda/reduce/reduce_utils.cuh | 14 ++++++--- mlx/compile.cpp | 8 ++--- 6 files changed, 52 insertions(+), 16 deletions(-) diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu index d8cc4cf5d7..9c90b4c1ad 100644 --- a/mlx/backend/cuda/reduce.cu +++ b/mlx/backend/cuda/reduce.cu @@ -12,6 +12,39 @@ namespace mlx::core { void Reduce::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Reduce::eval_gpu"); + + if (has_fused_prefix()) { + array in = inputs[0]; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + if (in.size() == 0) { + init_reduce(encoder, in, out, reduce_type_); + return; + } + + ReductionPlan plan = get_reduction_plan(in, axes_); + + bool broadcasted = false; + for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + broadcasted = in.strides(i) == 0; + } + } + if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { + array in_copy = contiguous_copy_gpu(in, s); + encoder.add_temporary(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + fused_reduce(encoder, *this, inputs, out, axes_, plan, s); + return; + } + assert(inputs.size() == 1); array in = inputs[0]; @@ -51,11 +84,6 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { plan = get_reduction_plan(in, axes_); } - if (has_fused_prefix()) { - fused_reduce(encoder, *this, in, out, axes_, plan); - return; - } - if (plan.type == ContiguousAllReduce) { all_reduce(encoder, in, out, reduce_type_); return; diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 81932f4d6d..7ed7336701 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. 
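The kernel headers reorganised above all follow the same pattern: the body lives in a __device__ *_impl function, with two thin __global__ wrappers on top, one taking a PrefixOp argument for the JIT path and one defaulting to Identity for the precompiled path. A condensed sketch of that pattern (illustrative names, not verbatim from the patch):

template <typename T, typename U, typename Op, int N, typename PrefixOp>
__device__ void reduce_impl(const T* in, U* out, PrefixOp prefix) {
  // shared body: load, apply prefix(), accumulate with Op, block_reduce, store
}

// JIT path: a generated prefix functor is passed in as a kernel argument.
template <typename T, typename U, typename Op, int N, typename PrefixOp>
__global__ void reduce_kernel(const T* in, U* out, PrefixOp prefix) {
  reduce_impl<T, U, Op, N, PrefixOp>(in, out, prefix);
}

// Precompiled path: no extra argument, Identity keeps the original behaviour.
template <typename T, typename U, typename Op, int N = 4, typename PrefixOp = Identity>
__global__ void reduce_kernel(const T* in, U* out) {
  reduce_impl<T, U, Op, N, PrefixOp>(in, out, PrefixOp{});
}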
#include "mlx/backend/cuda/reduce/all_reduce.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" namespace mlx::core { diff --git a/mlx/backend/cuda/reduce/all_reduce.cuh b/mlx/backend/cuda/reduce/all_reduce.cuh index a0dba31c9b..6ef70afe5d 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cuh +++ b/mlx/backend/cuda/reduce/all_reduce.cuh @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 788e4b97b0..078685e345 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -71,9 +71,10 @@ void init_reduce( void fused_reduce( cu::CommandEncoder& encoder, const Reduce& reduce, - const array& in, + const std::vector& inputs, array& out, const std::vector& axes, - const ReductionPlan& plan); + const ReductionPlan& plan, + const Stream& stream); } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index fdbe723378..ad28652bc5 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -2,15 +2,18 @@ #pragma once -#include - -#include "mlx/backend/common/utils.h" -#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/utils.cuh" #include #include +// Host-only includes +#ifndef __CUDACC_RTC__ +#include +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#endif + namespace mlx::core { namespace cu { @@ -90,6 +93,8 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { } // namespace cu +// Host-only function +#ifndef __CUDACC_RTC__ inline void allocate_same_layout( array& out, const array& in, @@ -141,5 +146,6 @@ inline void allocate_same_layout( fl, allocator::free); } +#endif } // namespace mlx::core diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 3c6e186cc5..41192f74a4 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -860,13 +860,13 @@ void compile_fuse( // Attach prefix to the Reduce primitive auto& reduce = static_cast(arr.primitive()); reduce.set_fused_prefix( - std::move(prefix_tape), - std::move(prefix_inputs), - std::move(constant_ids)); + prefix_tape, prefix_inputs, std::move(constant_ids)); - for (auto& p : reduce.prefix_tape()) { + for (auto& p : prefix_tape) { global_cache.insert(p.id()); + parents_map.erase(p.id()); } + arr.inputs() = std::move(prefix_inputs); } // Add the reduction to the new tape (with or without fused prefix) From 8c4a3b2e5950dff8edb8513395f155d9721b53c9 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 2 Feb 2026 17:11:13 +0100 Subject: [PATCH 5/9] revert to fuse only unary --- mlx/backend/cuda/CMakeLists.txt | 15 +- mlx/backend/cuda/jit_module.cpp | 12 + mlx/backend/cuda/reduce/all_reduce.cuh | 7 - mlx/backend/cuda/reduce/col_reduce.cuh | 7 +- mlx/backend/cuda/reduce/fused_reduce.cu | 351 +++++++++++++++++++++++ mlx/backend/cuda/reduce/reduce_utils.cuh | 1 - mlx/backend/cuda/reduce/row_reduce.cuh | 7 +- mlx/compile.cpp | 4 +- tests/compile_tests.cpp | 48 +++- 9 files changed, 427 insertions(+), 25 deletions(-) create mode 100644 mlx/backend/cuda/reduce/fused_reduce.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 9411885c22..eb8785fa0f 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -212,13 +212,16 @@ install(DIRECTORY 
${cccl_SOURCE_DIR}/include/cuda DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) install(DIRECTORY ${cccl_SOURCE_DIR}/include/nv DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) +# CUB and Thrust are needed to JIT compile unary -> reduction kernels. +install(DIRECTORY ${cccl_SOURCE_DIR}/include/cub + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) +install(DIRECTORY ${cccl_SOURCE_DIR}/include/thrust + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) -# The binary of C++ tests will not be installed so it can not find the CCCL -# headers, and we have to hard-code the path. -if(MLX_BUILD_TESTS) - target_compile_definitions(mlx - PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include") -endif() +# JIT cannot find the CCCL headers when compiling unary -> reduction kernels, so +# we hard-code the path to the downloaded CCCL. +target_compile_definitions(mlx + PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include") # Use fixed version of NVTX. FetchContent_Declare( diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index b0ebb40195..34efa20c44 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -223,6 +223,7 @@ bool compiler_supports_device_sass(Device& device) { } #define INCLUDE_PREFIX "mlx/backend/cuda/device/" +#define REDUCE_PREFIX "mlx/backend/cuda/reduce/" constexpr const char* g_include_names[] = { INCLUDE_PREFIX "atomic_ops.cuh", @@ -236,9 +237,15 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "unary_ops.cuh", INCLUDE_PREFIX "ternary_ops.cuh", INCLUDE_PREFIX "utils.cuh", + REDUCE_PREFIX "all_reduce.cuh", + REDUCE_PREFIX "col_reduce.cuh", + REDUCE_PREFIX "reduce_ops.cuh", + REDUCE_PREFIX "reduce_utils.cuh", + REDUCE_PREFIX "row_reduce.cuh", }; #undef INCLUDE_PREFIX +#undef REDUCE_PREFIX constexpr const char* g_headers[] = { jit_source_atomic_ops, @@ -252,6 +259,11 @@ constexpr const char* g_headers[] = { jit_source_unary_ops, jit_source_ternary_ops, jit_source_utils, + jit_source_all_reduce, + jit_source_col_reduce, + jit_source_reduce_ops, + jit_source_reduce_utils, + jit_source_row_reduce, }; void compile( diff --git a/mlx/backend/cuda/reduce/all_reduce.cuh b/mlx/backend/cuda/reduce/all_reduce.cuh index 6ef70afe5d..e661af56e1 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cuh +++ b/mlx/backend/cuda/reduce/all_reduce.cuh @@ -62,13 +62,6 @@ __device__ void all_reduce_impl( } } -template -__global__ void -all_reduce(T* in, U* out, size_t block_step, size_t size, PrefixOp prefix) { - all_reduce_impl( - in, out, block_step, size, prefix); -} - template < typename T, typename U, diff --git a/mlx/backend/cuda/reduce/col_reduce.cuh b/mlx/backend/cuda/reduce/col_reduce.cuh index 24ab20b0e2..73f52880d9 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cuh +++ b/mlx/backend/cuda/reduce/col_reduce.cuh @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include @@ -10,7 +10,10 @@ #include #include +#ifndef __CUDACC_RTC__ #include +#include "mlx/backend/cuda/reduce/reduce.cuh" +#endif namespace mlx::core::cu { @@ -34,6 +37,7 @@ struct ColReduceArgs { // The number of column we are reducing. Namely prod(reduce_shape). 
size_t non_col_reductions; +#ifndef __CUDACC_RTC__ ColReduceArgs( const array& in, const ReductionPlan& plan, @@ -81,6 +85,7 @@ struct ColReduceArgs { non_col_reductions *= reduce_shape[i]; } } +#endif }; template < diff --git a/mlx/backend/cuda/reduce/fused_reduce.cu b/mlx/backend/cuda/reduce/fused_reduce.cu new file mode 100644 index 0000000000..74c4a99082 --- /dev/null +++ b/mlx/backend/cuda/reduce/fused_reduce.cu @@ -0,0 +1,351 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/reduce/all_reduce.cuh" +#include "mlx/backend/cuda/reduce/reduce.cuh" +#include "mlx/backend/cuda/utils.h" +#include "mlx/graph_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +// Builder to generate prefix functor code that will be +// applied in reduction kernel +struct FusedReducePrefixBuilder { + std::string os; + const std::string& kernel_name; + const std::vector& inputs; + const std::vector& tape; + const std::function& is_constant; + + void build_prefix_struct() { + NodeNamer namer; + + std::unordered_set constant_set; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant(i)) { + constant_set.insert(inputs[i].id()); + } + } + + os += "struct " + kernel_name + "_Prefix {\n"; + + os += "\n template \n"; + os += " __device__ __forceinline__ T operator()(T val) const {\n"; + + // Read the first input (reduction input) from the passed value + // Find the non-constant input + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + if (is_constant(i)) { + std::ostringstream ss; + print_constant(ss, x); + os += fmt::format( + " {} tmp_{} = static_cast<{}>({}); // constant\n", + type, + xname, + type, + ss.str()); + } else { + os += fmt::format( + " {} tmp_{} = static_cast<{}>(val);\n", type, xname, type); + } + } + + for (const auto& x : tape) { + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_cuda_type(x.dtype()); + std::string value; + if (is_static_cast(x.primitive())) { + value = fmt::format( + "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0])); + } else { + value = x.primitive().name(); + value += "{}("; + for (size_t i = 0; i < x.inputs().size() - 1; ++i) { + value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i])); + } + value += fmt::format("tmp_{})", namer.get_name(x.inputs().back())); + } + os += fmt::format(" {} tmp_{} = {};\n", type, xname, value); + } + + // Return the result (last tape item or first input if tape is empty) + if (!tape.empty()) { + os += fmt::format( + " return static_cast(tmp_{});\n", namer.get_name(tape.back())); + } else { + // Find the non-constant input to return + for (size_t i = 0; i < inputs.size(); ++i) { + if (!is_constant(i)) { + os += fmt::format( + " return static_cast(tmp_{});\n", + namer.get_name(inputs[i])); + break; + } + } + } + + os += " }\n"; + + os += "};\n\n"; + } +}; + +} // namespace cu + +constexpr const char* g_fused_reduce_includes = R"( +#include "mlx/backend/cuda/device/binary_ops.cuh" +#include "mlx/backend/cuda/device/ternary_ops.cuh" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" +#include "mlx/backend/cuda/reduce/all_reduce.cuh" + +#include + +#define inf cuda::std::numeric_limits::infinity() +)"; + +std::string 
get_reduce_op_name(Reduce::ReduceType reduce_type) { + switch (reduce_type) { + case Reduce::ReduceType::And: + return "And"; + case Reduce::ReduceType::Or: + return "Or"; + case Reduce::ReduceType::Sum: + return "Sum"; + case Reduce::ReduceType::Prod: + return "Prod"; + case Reduce::ReduceType::Max: + return "Max"; + case Reduce::ReduceType::Min: + return "Min"; + default: + throw std::runtime_error("Unknown reduce type"); + } +} + +void fused_all_reduce( + cu::CommandEncoder& encoder, + const Reduce& reduce, + const std::vector& inputs, + array& out, + const Stream& stream) { + nvtx3::scoped_range r("fused_all_reduce"); + + // Copied from all_reduce.cu + constexpr int N_READS = 8; + + const auto& prefix_inputs = reduce.prefix_inputs(); + const auto& prefix_tape = reduce.prefix_tape(); + const auto& prefix_constant_ids = reduce.prefix_constant_ids(); + + auto is_constant = [&](size_t i) -> bool { + return prefix_constant_ids.count(prefix_inputs[i].id()) > 0; + }; + + NodeNamer namer; + std::ostringstream os; + std::ostringstream constant_hasher; + + for (const auto& x : prefix_inputs) { + namer.get_name(x); + } + + // Build string from tape operations + for (const auto& a : prefix_tape) { + os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); + os << a.primitive().name(); + for (const auto& inp : a.inputs()) { + os << namer.get_name(inp); + } + } + // Name the kernel: similar to Compiled::Compiled kernel naming + for (size_t i = 0; i < prefix_inputs.size(); ++i) { + const auto& x = prefix_inputs[i]; + if (is_constant(i)) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << (is_scalar(x) ? "S" : "V"); + } + } + + os << get_reduce_op_name(reduce.state().first); + os << dtype_to_cuda_type(prefix_inputs[0].dtype()); + os << std::hash{}(constant_hasher.str()); + + std::string kernel_name = os.str(); + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + // Copied from all_reduce.cu + auto get_args = [](int size, int N) { + int threads = std::min(512, (size + N - 1) / N); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = + (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + std::string reduce_op = get_reduce_op_name(reduce.state().first); + std::string in_type = dtype_to_cuda_type(prefix_inputs[0].dtype()); + std::string prefix_type = kernel_name + "_Prefix"; + + std::string full_kernel_name = + fmt::format("mlx::core::cu::{}_all_reduce", kernel_name); + + cu::JitModule& mod = cu::get_jit_module(stream.device, kernel_name, [&]() { + cu::FusedReducePrefixBuilder builder{ + g_fused_reduce_includes, + kernel_name, + prefix_inputs, + prefix_tape, + is_constant}; + builder.os += + "namespace mlx::core::cu {\n\n" + "namespace cg = cooperative_groups;\n\n"; + + // Generate the prefix struct + builder.build_prefix_struct(); + + builder.os += fmt::format( + "__global__ void {}_all_reduce({}* in, {}* out, size_t block_step, size_t size) {{\n" + " {} prefix{{}};\n" + " all_reduce_impl<{}, {}, {}, {}, {}>(in, out, block_step, size, prefix);\n" + "}}\n", + 
kernel_name, + in_type, + in_type, + prefix_type, + in_type, + in_type, + reduce_op, + N_READS, + prefix_type); + + builder.os += "\n} // namespace mlx::core::cu\n"; + + std::vector kernel_names; + kernel_names.push_back(full_kernel_name); + + return std::make_tuple( + false, std::move(builder.os), std::move(kernel_names)); + }); + + int blocks, threads; + size_t block_step; + + size_t insize = inputs[0].size(); + Reduce::ReduceType reduce_type = reduce.state().first; + + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + + encoder.set_input_array(inputs[0]); + + // If the reduction needs more than 1 block -- use fused kernel with + // fused prefix, then reduce intermediate to final output using build compiled + // all_reduce + if (blocks > 1) { + array intermediate({blocks}, out.dtype(), nullptr, {}); + intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder)); + encoder.add_temporary(intermediate); + encoder.set_output_array(intermediate); + + // First pass: apply fused prefix and reduce to intermediate + cu::KernelArgs args; + args.append(inputs[0]); + args.append(intermediate); + args.append(block_step); + args.append(insize); + + auto kernel = mod.get_kernel(full_kernel_name); + encoder.add_kernel_node(kernel, blocks, threads, 0, args.args()); + + // Second pass: reduce intermediate to final output + size_t intermediate_size = intermediate.size(); + std::tie(blocks, threads, block_step) = + get_args(intermediate_size, N_READS); + encoder.set_input_array(intermediate); + encoder.set_output_array(out); + + dispatch_all_types(out.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; + auto kernel2 = cu::all_reduce; + encoder.add_kernel_node( + kernel2, + blocks, + threads, + 0, + gpu_ptr(intermediate), + gpu_ptr(out), + block_step, + intermediate_size); + }); + }); + } else { + // Single block: direct reduction with fused prefix + encoder.set_output_array(out); + + cu::KernelArgs args; + args.append(inputs[0]); + args.append(out); + args.append(block_step); + args.append(insize); + + auto kernel = mod.get_kernel(full_kernel_name); + encoder.add_kernel_node(kernel, blocks, threads, 0, args.args()); + } +} + +void fused_reduce( + cu::CommandEncoder& encoder, + const Reduce& reduce, + const std::vector& inputs, + array& out, + const std::vector& axes, + const ReductionPlan& plan, + const Stream& stream) { + if (plan.type == ContiguousAllReduce) { + fused_all_reduce(encoder, reduce, inputs, out, stream); + return; + } + + // TODO: Implement fused row_reduce and col_reduce + throw std::runtime_error( + "Fused reduce not yet implemented for this reduction type"); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index ad28652bc5..4122e32b42 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -93,7 +93,6 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) { } // namespace cu -// Host-only function #ifndef __CUDACC_RTC__ inline void allocate_same_layout( array& out, diff --git a/mlx/backend/cuda/reduce/row_reduce.cuh b/mlx/backend/cuda/reduce/row_reduce.cuh index e29a2ce51d..301fd15e47 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cuh +++ b/mlx/backend/cuda/reduce/row_reduce.cuh @@ -2,13 +2,16 @@ #pragma once -#include "mlx/backend/cuda/reduce/reduce.cuh" 
+#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include #include +#ifndef __CUDACC_RTC__ #include +#include "mlx/backend/cuda/reduce/reduce.cuh" +#endif namespace mlx::core::cu { @@ -31,6 +34,7 @@ struct RowReduceArgs { // The number of rows we are reducing. Namely prod(reduce_shape). size_t non_row_reductions; +#ifndef __CUDACC_RTC__ RowReduceArgs( const array& in, const ReductionPlan& plan, @@ -78,6 +82,7 @@ struct RowReduceArgs { strides = const_param(strides_vec); ndim = shape_vec.size(); } +#endif }; template < diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 41192f74a4..ef6aa3a3f4 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -819,12 +819,12 @@ void compile_fuse( } // Stop fusing if: // depth limit exceeded - // non fusable primitive + // non unary primitive // does not have primitive // stream mismatch // is a constant input if (depth >= max_prefix_depth || !a.has_primitive() || - !is_fusable(a.primitive()) || + !is_unary(a.primitive()) || a.primitive().stream() != reduction_stream || input_ids.count(a.id())) { prefix_inputs.push_back(a); diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index d2146f3bad..901419519c 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -296,13 +296,13 @@ TEST_CASE("test compile unary fused") { CHECK_EQ(out.size(), 1); auto& p = out[0].primitive(); - // NB: this test is brittle, will need to update - // it if we change compile conditions + // With fused-into-reduction, unary ops are fused into the Reduce primitive CHECK_EQ(typeid(p), typeid(Reduce)); auto cout = out[0].inputs()[0]; - auto& cp = cout.primitive(); - CHECK_EQ(typeid(cp), typeid(Compiled)); - CHECK_EQ(cout.inputs()[0].id(), x.id()); + auto& reduce = static_cast(p); + CHECK(reduce.has_fused_prefix()); + // The input to Reduce is directly x (not a Compiled primitive) + CHECK_EQ(out[0].inputs()[0].id(), x.id()); } { @@ -397,10 +397,11 @@ TEST_CASE("test compile binary fused") { auto& p = out.primitive(); CHECK_EQ(typeid(p), typeid(Reduce)); - + // With fused-into-reduction, only unary ops (abs) are fused into Reduce + // The binary op (Add) remains as the input to fused reduce + abs auto cout = out.inputs()[0]; auto& cp = cout.primitive(); - CHECK_EQ(typeid(cp), typeid(Compiled)); + CHECK_EQ(typeid(cp), typeid(Add)); CHECK_EQ(cout.inputs()[0].id(), x.id()); CHECK_EQ(cout.inputs()[1].id(), y.id()); } @@ -816,3 +817,36 @@ TEST_CASE("test compile random bits") { auto out = compile(fun)({in})[0]; CHECK(array_equal(out, expected).item()); } + +TEST_CASE("test compile unary reduction one pass") { + auto fun = [](const std::vector& inputs) { + auto x = inputs[0]; + return std::vector{sum(abs(x))}; + }; + auto in = ones({128, 128}); + auto expected = fun({in})[0]; + auto out = compile(fun)({ones({128, 128})})[0]; + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test compile unary reduction two passes") { + auto fun = [](const std::vector& inputs) { + auto x = inputs[0]; + return std::vector{sum(abs(x))}; + }; + auto in = ones({1024, 1024}); + auto expected = fun({in})[0]; + auto out = compile(fun)({ones({1024, 1024})})[0]; + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test compile unary+constant reduction two passes") { + auto fun = [](const std::vector& inputs) { + auto x = inputs[0]; + return std::vector{sum(abs(x) + 1.0f)}; + }; + auto in = ones({1024, 1024}); + auto expected = fun({in})[0]; + auto out = compile(fun)({ones({1024, 1024})})[0]; + CHECK(array_equal(out, 
expected).item()); +} \ No newline at end of file From 88f46e83ebafc7fcb0c80bd5523a1cd0272ecc2a Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 2 Feb 2026 17:58:23 +0100 Subject: [PATCH 6/9] revert row/col files --- mlx/backend/cuda/reduce/col_reduce.cu | 244 ++++++++++++++++++- mlx/backend/cuda/reduce/col_reduce.cuh | 324 ------------------------- mlx/backend/cuda/reduce/row_reduce.cu | 228 ++++++++++++++++- mlx/backend/cuda/reduce/row_reduce.cuh | 318 ------------------------ mlx/compile.cpp | 4 +- 5 files changed, 473 insertions(+), 645 deletions(-) delete mode 100644 mlx/backend/cuda/reduce/col_reduce.cuh delete mode 100644 mlx/backend/cuda/reduce/row_reduce.cuh diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index bbeb226bf1..94fa498873 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,9 +1,251 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/reduce/col_reduce.cuh" +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" + +#include +#include +#include +#include + +#include +#include "mlx/backend/cuda/reduce/reduce.cuh" namespace mlx::core { +namespace cg = cooperative_groups; + +namespace cu { + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). + size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + using ShapeVector = decltype(plan.shape); + using StridesVector = decltype(plan.strides); + + ShapeVector shape_vec; + StridesVector strides_vec; + + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + ShapeVector sorted_shape; + StridesVector sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size(); + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4, + int BLOCKS = 1> +__global__ void col_reduce_looped( + T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + int64_t out_size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp 
= cg::tiled_partition(block); + + constexpr int threads_per_row = BN / N_READS; + + // Compute the indices for the tile + size_t tile_idx = grid.block_rank(); + size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); + size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); + size_t tile_out = tile_y / out_size; + tile_y = tile_y % out_size; + + // Compute the indices for the thread within the tile + short thread_x = block.thread_rank() % threads_per_row; + short thread_y = block.thread_rank() / threads_per_row; + + // Move the input pointer + in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + + tile_x * BN; + + // Initialize the running totals + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + size_t total = args.non_col_reductions * args.reduction_size; + size_t per_block, start, end; + if constexpr (BLOCKS > 1) { + per_block = (total + BLOCKS - 1) / BLOCKS; + start = tile_out * per_block + thread_y; + end = min((tile_out + 1) * per_block, total); + } else { + per_block = total; + start = thread_y; + end = total; + } + + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(start, args.reduce_shape.data(), args.reduce_strides.data()); + if (tile_x * BN + BN <= args.reduction_stride) { + if (args.reduction_stride % N_READS == 0) { + for (size_t r = start; r < end; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } else { + for (size_t r = start; r < end; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + } else { + for (size_t r = start; r < end; r += BM) { + T vals[N_READS]; + cub::LoadDirectBlocked( + thread_x, + in + loop.location(), + vals, + args.reduction_stride - tile_x * BN, + cast_to(ReduceInit::value())); + for (int i = 0; i < N_READS; i++) { + totals[i] = op(totals[i], cast_to(vals[i])); + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / threads_per_row; + static_assert(BM == 32 && n_outputs == N_READS); + __shared__ U shared_vals[BM * BN]; + short s_idx = thread_y * BN + thread_x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[s_idx + i] = totals[i]; + } + block.sync(); + s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; + for (int i = 0; i < n_outputs; i++) { + totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); + } + + // Write result. 
+ if (warp.thread_rank() == 0) { + if (BLOCKS > 1) { + out += tile_out * out_size * args.reduction_stride; + } + cub::StoreDirectBlocked( + warp.meta_group_rank(), + out + tile_y * args.reduction_stride + tile_x * BN, + totals, + args.reduction_stride - tile_x * BN); + } +} + +template +__global__ void col_reduce_small( + const T* in, + U* out, + const __grid_constant__ ColReduceArgs args, + size_t total) { + Op op; + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + + const auto idx = grid.thread_rank() * N_READS; + const auto before_axis = idx / args.reduction_stride; + const auto after_axis = idx % args.reduction_stride; + const auto offset = + before_axis * args.reduction_stride * args.reduction_size + after_axis; + + if (idx >= total) { + return; + } + + in += offset; + out += idx; + + AlignedVector accumulator; + for (int i = 0; i < N_READS; i++) { + accumulator[i] = ReduceInit::value(); + } + + for (int i = 0; i < args.reduction_size; i++) { + auto values = load_vector(in, 0); + + for (int j = 0; j < N_READS; j++) { + accumulator[j] = op(accumulator[j], cast_to(values[j])); + } + + in += args.reduction_stride; + } + + store_vector(out, 0, accumulator); +} + +} // namespace cu + inline auto output_grid_for_col_reduce( const array& out, const cu::ColReduceArgs& args, diff --git a/mlx/backend/cuda/reduce/col_reduce.cuh b/mlx/backend/cuda/reduce/col_reduce.cuh deleted file mode 100644 index 73f52880d9..0000000000 --- a/mlx/backend/cuda/reduce/col_reduce.cuh +++ /dev/null @@ -1,324 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include "mlx/backend/cuda/device/config.h" -#include "mlx/backend/cuda/reduce/reduce_ops.cuh" - -#include -#include -#include -#include - -#ifndef __CUDACC_RTC__ -#include -#include "mlx/backend/cuda/reduce/reduce.cuh" -#endif - -namespace mlx::core::cu { - -namespace cg = cooperative_groups; - -struct ColReduceArgs { - // The size of the contiguous column reduction. - size_t reduction_size; - int64_t reduction_stride; - - // Input shape and strides excluding the reduction axes. - Shape shape; - Strides strides; - int ndim; - - // Input shape and strides of the reduction axes (including last dimension). - Shape reduce_shape; - Strides reduce_strides; - int reduce_ndim; - - // The number of column we are reducing. Namely prod(reduce_shape). 
- size_t non_col_reductions; - -#ifndef __CUDACC_RTC__ - ColReduceArgs( - const array& in, - const ReductionPlan& plan, - const std::vector& axes) { - using ShapeVector = decltype(plan.shape); - using StridesVector = decltype(plan.strides); - - ShapeVector shape_vec; - StridesVector strides_vec; - - assert(!plan.shape.empty()); - reduction_size = plan.shape.back(); - reduction_stride = plan.strides.back(); - - int64_t stride_back = 1; - std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); - while (!shape_vec.empty() && stride_back < reduction_stride) { - stride_back *= shape_vec.back(); - shape_vec.pop_back(); - strides_vec.pop_back(); - } - std::vector indices(shape_vec.size()); - std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), [&](int left, int right) { - return strides_vec[left] > strides_vec[right]; - }); - ShapeVector sorted_shape; - StridesVector sorted_strides; - for (auto idx : indices) { - sorted_shape.push_back(shape_vec[idx]); - sorted_strides.push_back(strides_vec[idx]); - } - std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(sorted_shape, sorted_strides); - shape = const_param(shape_vec); - strides = const_param(strides_vec); - ndim = shape_vec.size(); - - reduce_shape = const_param(plan.shape); - reduce_strides = const_param(plan.strides); - reduce_ndim = plan.shape.size(); - - non_col_reductions = 1; - for (int i = 0; i < reduce_ndim - 1; i++) { - non_col_reductions *= reduce_shape[i]; - } - } -#endif -}; - -template < - typename T, - typename U, - typename Op, - int NDIM, - int BM, - int BN, - int N_READS, - int BLOCKS, - typename PrefixOp> -__device__ void col_reduce_looped_impl( - T* in, - U* out, - const ColReduceArgs& args, - int64_t out_size, - PrefixOp prefix) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - constexpr int threads_per_row = BN / N_READS; - - // Compute the indices for the tile - size_t tile_idx = grid.block_rank(); - size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); - size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); - size_t tile_out = tile_y / out_size; - tile_y = tile_y % out_size; - - // Compute the indices for the thread within the tile - short thread_x = block.thread_rank() % threads_per_row; - short thread_y = block.thread_rank() / threads_per_row; - - // Move the input pointer - in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + - tile_x * BN; - - // Initialize the running totals - Op op; - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = ReduceInit::value(); - } - - size_t total = args.non_col_reductions * args.reduction_size; - size_t per_block, start, end; - if constexpr (BLOCKS > 1) { - per_block = (total + BLOCKS - 1) / BLOCKS; - start = tile_out * per_block + thread_y; - end = min((tile_out + 1) * per_block, total); - } else { - per_block = total; - start = thread_y; - end = total; - } - - LoopedElemToLoc 2)> loop(args.reduce_ndim); - loop.next(start, args.reduce_shape.data(), args.reduce_strides.data()); - if (tile_x * BN + BN <= args.reduction_stride) { - if (args.reduction_stride % N_READS == 0) { - for (size_t r = start; r < end; r += BM) { - T vals[N_READS]; - cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], cast_to(prefix(vals[i]))); - } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); - } - } 
else { - for (size_t r = start; r < end; r += BM) { - T vals[N_READS]; - cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], cast_to(prefix(vals[i]))); - } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); - } - } - } else { - for (size_t r = start; r < end; r += BM) { - T vals[N_READS]; - cub::LoadDirectBlocked( - thread_x, - in + loop.location(), - vals, - args.reduction_stride - tile_x * BN, - cast_to(ReduceInit::value())); - for (int i = 0; i < N_READS; i++) { - totals[i] = op(totals[i], cast_to(prefix(vals[i]))); - } - loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); - } - } - - // Do warp reduce for each output. - constexpr int n_outputs = BN / threads_per_row; - static_assert(BM == 32 && n_outputs == N_READS); - __shared__ U shared_vals[BM * BN]; - short s_idx = thread_y * BN + thread_x * N_READS; - for (int i = 0; i < N_READS; i++) { - shared_vals[s_idx + i] = totals[i]; - } - block.sync(); - s_idx = warp.thread_rank() * BN + warp.meta_group_rank() * n_outputs; - for (int i = 0; i < n_outputs; i++) { - totals[i] = cg::reduce(warp, shared_vals[s_idx + i], op); - } - - // Write result. - if (warp.thread_rank() == 0) { - if (BLOCKS > 1) { - out += tile_out * out_size * args.reduction_stride; - } - cub::StoreDirectBlocked( - warp.meta_group_rank(), - out + tile_y * args.reduction_stride + tile_x * BN, - totals, - args.reduction_stride - tile_x * BN); - } -} - -// Kernel with prefix parameter -template < - typename T, - typename U, - typename Op, - int NDIM, - int BM, - int BN, - int N_READS, - int BLOCKS, - typename PrefixOp> -__global__ void col_reduce_looped( - T* in, - U* out, - const __grid_constant__ ColReduceArgs args, - int64_t out_size, - PrefixOp prefix) { - col_reduce_looped_impl( - in, out, args, out_size, prefix); -} - -// Kernel without prefix parameter (default Identity) -template < - typename T, - typename U, - typename Op, - int NDIM, - int BM, - int BN, - int N_READS = 4, - int BLOCKS = 1, - typename PrefixOp = Identity> -__global__ void col_reduce_looped( - T* in, - U* out, - const __grid_constant__ ColReduceArgs args, - int64_t out_size) { - col_reduce_looped_impl( - in, out, args, out_size, PrefixOp{}); -} - -// Device function for col_reduce_small -template -__device__ void col_reduce_small_impl( - const T* in, - U* out, - const ColReduceArgs& args, - size_t total, - PrefixOp prefix) { - Op op; - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - - const auto idx = grid.thread_rank() * N_READS; - const auto before_axis = idx / args.reduction_stride; - const auto after_axis = idx % args.reduction_stride; - const auto offset = - before_axis * args.reduction_stride * args.reduction_size + after_axis; - - if (idx >= total) { - return; - } - - in += offset; - out += idx; - - AlignedVector accumulator; - for (int i = 0; i < N_READS; i++) { - accumulator[i] = ReduceInit::value(); - } - - for (int i = 0; i < args.reduction_size; i++) { - auto values = load_vector(in, 0); - - for (int j = 0; j < N_READS; j++) { - accumulator[j] = op(accumulator[j], cast_to(prefix(values[j]))); - } - - in += args.reduction_stride; - } - - store_vector(out, 0, accumulator); -} - -// Kernel with prefix parameter -template -__global__ void col_reduce_small( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args, - size_t total, - PrefixOp prefix) { - col_reduce_small_impl( - in, out, args, total, prefix); -} - -// Kernel without prefix parameter 
(default Identity) -template < - typename T, - typename U, - typename Op, - int N_READS = 4, - typename PrefixOp = Identity> -__global__ void col_reduce_small( - const T* in, - U* out, - const __grid_constant__ ColReduceArgs args, - size_t total) { - col_reduce_small_impl( - in, out, args, total, PrefixOp{}); -} - -} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 55441e8e23..1bdccd1eae 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,9 +1,235 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/reduce/row_reduce.cuh" +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" + +#include +#include + +#include +#include "mlx/backend/cuda/reduce/reduce.cuh" namespace mlx::core { +namespace cg = cooperative_groups; + +namespace cu { + +struct RowReduceArgs { + // The size of the row being reduced, i.e. the size of last dimension. + int row_size; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes excluding last dimension. + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of rows we are reducing. Namely prod(reduce_shape). + size_t non_row_reductions; + + RowReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + assert(!plan.shape.empty()); + row_size = plan.shape.back(); + + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(shape_vec, strides_vec); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + + reduce_shape = const_param(plan.shape); + reduce_strides = const_param(plan.strides); + reduce_ndim = plan.shape.size() - 1; + + non_row_reductions = 1; + for (int i = 0; i < reduce_ndim; i++) { + non_row_reductions *= reduce_shape[i]; + } + } + + // Convert shape and strides as if in was contiguous + void sort_access_pattern(const array& in, const std::vector& axes) { + auto shape_vec = in.shape(); + auto strides_vec = in.strides(); + std::tie(shape_vec, strides_vec) = + shapes_without_reduction_axes(shape_vec, strides_vec, axes); + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + decltype(shape_vec) sorted_shape; + decltype(strides_vec) sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + shape = const_param(shape_vec); + strides = const_param(strides_vec); + ndim = shape_vec.size(); + } +}; + +template +__global__ void +row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = ReduceInit::value(); + ReduceOp op; + + AlignedVector vals[M]; + AlignedVector accs; + for (int i = 0; i < M; i++) { + accs[i] = init; + } + + const size_t start_row = + min(n_rows - M, static_cast(grid.block_rank() * M)); + const size_t full_blocks = size / (block.size() * N); + const size_t final_offset = full_blocks * (block.size() * N); + in += start_row * 
size + block.thread_rank() * N; + out += start_row; + + for (size_t r = 0; r < full_blocks; r++) { + for (int k = 0; k < M; k++) { + vals[k] = load_vector(in + k * size, 0); + } + for (int k = 0; k < M; k++) { + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(vals[k][j])); + } + } + + in += block.size() * N; + } + + if (final_offset < size) { + for (int k = 0; k < M; k++) { + for (int i = 0; i < N; i++) { + vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size) + ? in[k * size + i] + : cast_to(init); + } + } + for (int k = 0; k < M; k++) { + for (int j = 0; j < N; j++) { + accs[k] = op(accs[k], cast_to(vals[k][j])); + } + } + } + + __shared__ U shared_accumulators[32 * M]; + block_reduce(block, warp, accs.val, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + if (grid.block_rank() * M + M <= n_rows) { + store_vector(out, 0, accs); + } else { + short offset = grid.block_rank() * M + M - n_rows; + for (int i = offset; i < M; i++) { + out[i] = accs[i]; + } + } + } +} + +template +__global__ void row_reduce_looped( + const T* in, + U* out, + const __grid_constant__ RowReduceArgs args) { + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + size_t out_idx = grid.block_rank(); + + Op op; + + U total[1]; + U init = ReduceInit::value(); + total[0] = init; + LoopedElemToLoc 2)> loop(args.reduce_ndim); + const size_t full_blocks = args.row_size / (block.size() * N_READS); + const size_t final_offset = full_blocks * (block.size() * N_READS); + + in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); + in += block.thread_rank() * N_READS; + + // Unaligned reduce + if (final_offset < args.row_size) { + bool mask[N_READS]; + for (int i = 0; i < N_READS; i++) { + mask[i] = + (final_offset + block.thread_rank() * N_READS + i) < args.row_size; + } + + for (size_t n = 0; n < args.non_row_reductions; n++) { + const T* inlocal = in + loop.location(); + + for (size_t r = 0; r < full_blocks; r++) { + auto vals = load_vector(inlocal, 0); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + inlocal += block.size() * N_READS; + } + + { + T vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = mask[i] ? inlocal[i] : cast_to(init); + } + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + } + + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + // Aligned case + else { + for (size_t n = 0; n < args.non_row_reductions; n++) { + const T* inlocal = in + loop.location(); + + for (size_t r = 0; r < full_blocks; r++) { + auto vals = load_vector(inlocal, 0); + for (int i = 0; i < N_READS; i++) { + total[0] = op(total[0], cast_to(vals[i])); + } + inlocal += block.size() * N_READS; + } + + loop.next(args.reduce_shape.data(), args.reduce_strides.data()); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, total, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[out_idx] = total[0]; + } +} + +} // namespace cu + void row_reduce_simple( cu::CommandEncoder& encoder, const array& in, diff --git a/mlx/backend/cuda/reduce/row_reduce.cuh b/mlx/backend/cuda/reduce/row_reduce.cuh deleted file mode 100644 index 301fd15e47..0000000000 --- a/mlx/backend/cuda/reduce/row_reduce.cuh +++ /dev/null @@ -1,318 +0,0 @@ -// Copyright © 2025 Apple Inc. 
- -#pragma once - -#include "mlx/backend/cuda/device/config.h" -#include "mlx/backend/cuda/reduce/reduce_ops.cuh" - -#include -#include - -#ifndef __CUDACC_RTC__ -#include -#include "mlx/backend/cuda/reduce/reduce.cuh" -#endif - -namespace mlx::core::cu { - -namespace cg = cooperative_groups; - -struct RowReduceArgs { - // The size of the row being reduced, i.e. the size of last dimension. - int row_size; - - // Input shape and strides excluding the reduction axes. - Shape shape; - Strides strides; - int ndim; - - // Input shape and strides of the reduction axes excluding last dimension. - Shape reduce_shape; - Strides reduce_strides; - int reduce_ndim; - - // The number of rows we are reducing. Namely prod(reduce_shape). - size_t non_row_reductions; - -#ifndef __CUDACC_RTC__ - RowReduceArgs( - const array& in, - const ReductionPlan& plan, - const std::vector& axes) { - assert(!plan.shape.empty()); - row_size = plan.shape.back(); - - auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); - std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(shape_vec, strides_vec); - shape = const_param(shape_vec); - strides = const_param(strides_vec); - ndim = shape_vec.size(); - - reduce_shape = const_param(plan.shape); - reduce_strides = const_param(plan.strides); - reduce_ndim = plan.shape.size() - 1; - - non_row_reductions = 1; - for (int i = 0; i < reduce_ndim; i++) { - non_row_reductions *= reduce_shape[i]; - } - } - - // Convert shape and strides as if in was contiguous - void sort_access_pattern(const array& in, const std::vector& axes) { - auto shape_vec = in.shape(); - auto strides_vec = in.strides(); - std::tie(shape_vec, strides_vec) = - shapes_without_reduction_axes(shape_vec, strides_vec, axes); - std::vector indices(shape_vec.size()); - std::iota(indices.begin(), indices.end(), 0); - std::sort(indices.begin(), indices.end(), [&](int left, int right) { - return strides_vec[left] > strides_vec[right]; - }); - decltype(shape_vec) sorted_shape; - decltype(strides_vec) sorted_strides; - for (auto idx : indices) { - sorted_shape.push_back(shape_vec[idx]); - sorted_strides.push_back(strides_vec[idx]); - } - std::tie(shape_vec, strides_vec) = - collapse_contiguous_dims(sorted_shape, sorted_strides); - shape = const_param(shape_vec); - strides = const_param(strides_vec); - ndim = shape_vec.size(); - } -#endif -}; - -template < - typename T, - typename U, - typename ReduceOp, - int N, - int M, - typename PrefixOp> -__device__ void row_reduce_simple_impl( - const T* in, - U* out, - size_t n_rows, - int size, - PrefixOp prefix) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - const U init = ReduceInit::value(); - ReduceOp op; - - AlignedVector vals[M]; - AlignedVector accs; - for (int i = 0; i < M; i++) { - accs[i] = init; - } - - const size_t start_row = - min(n_rows - M, static_cast(grid.block_rank() * M)); - const size_t full_blocks = size / (block.size() * N); - const size_t final_offset = full_blocks * (block.size() * N); - in += start_row * size + block.thread_rank() * N; - out += start_row; - - for (size_t r = 0; r < full_blocks; r++) { - for (int k = 0; k < M; k++) { - vals[k] = load_vector(in + k * size, 0); - } - for (int k = 0; k < M; k++) { - for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], cast_to(prefix(vals[k][j]))); - } - } - - in += block.size() * N; - } - - if (final_offset < size) { - for (int k = 0; k < M; k++) { - for (int i = 0; i < N; i++) { - vals[k][i] = ((final_offset + 
block.thread_rank() * N + i) < size) - ? in[k * size + i] - : cast_to(init); - } - } - for (int k = 0; k < M; k++) { - for (int j = 0; j < N; j++) { - accs[k] = op(accs[k], cast_to(prefix(vals[k][j]))); - } - } - } - - __shared__ U shared_accumulators[32 * M]; - block_reduce(block, warp, accs.val, shared_accumulators, op, init); - - if (block.thread_rank() == 0) { - if (grid.block_rank() * M + M <= n_rows) { - store_vector(out, 0, accs); - } else { - short offset = grid.block_rank() * M + M - n_rows; - for (int i = offset; i < M; i++) { - out[i] = accs[i]; - } - } - } -} - -// Kernel with prefix parameter -template < - typename T, - typename U, - typename ReduceOp, - int N, - int M, - typename PrefixOp> -__global__ void row_reduce_simple( - const T* in, - U* out, - size_t n_rows, - int size, - PrefixOp prefix) { - row_reduce_simple_impl( - in, out, n_rows, size, prefix); -} - -// Kernel without prefix parameter (default Identity) -template < - typename T, - typename U, - typename ReduceOp, - int N = 4, - int M = 1, - typename PrefixOp = Identity> -__global__ void -row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { - row_reduce_simple_impl( - in, out, n_rows, size, PrefixOp{}); -} - -// Device function for row_reduce_looped -template < - typename T, - typename U, - typename Op, - int NDIM, - int N_READS, - typename PrefixOp> -__device__ void row_reduce_looped_impl( - const T* in, - U* out, - const RowReduceArgs& args, - PrefixOp prefix) { - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - size_t out_idx = grid.block_rank(); - - Op op; - - U total[1]; - U init = ReduceInit::value(); - total[0] = init; - LoopedElemToLoc 2)> loop(args.reduce_ndim); - const size_t full_blocks = args.row_size / (block.size() * N_READS); - const size_t final_offset = full_blocks * (block.size() * N_READS); - - in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim); - in += block.thread_rank() * N_READS; - - // Unaligned reduce - if (final_offset < args.row_size) { - bool mask[N_READS]; - for (int i = 0; i < N_READS; i++) { - mask[i] = - (final_offset + block.thread_rank() * N_READS + i) < args.row_size; - } - - for (size_t n = 0; n < args.non_row_reductions; n++) { - const T* inlocal = in + loop.location(); - - for (size_t r = 0; r < full_blocks; r++) { - auto vals = load_vector(inlocal, 0); - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(prefix(vals[i]))); - } - inlocal += block.size() * N_READS; - } - - { - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = mask[i] ? 
inlocal[i] : cast_to(init); - } - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(prefix(vals[i]))); - } - } - - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); - } - } - - // Aligned case - else { - for (size_t n = 0; n < args.non_row_reductions; n++) { - const T* inlocal = in + loop.location(); - - for (size_t r = 0; r < full_blocks; r++) { - auto vals = load_vector(inlocal, 0); - for (int i = 0; i < N_READS; i++) { - total[0] = op(total[0], cast_to(prefix(vals[i]))); - } - inlocal += block.size() * N_READS; - } - - loop.next(args.reduce_shape.data(), args.reduce_strides.data()); - } - } - - __shared__ U shared_accumulators[32]; - block_reduce(block, warp, total, shared_accumulators, op, init); - - if (block.thread_rank() == 0) { - out[out_idx] = total[0]; - } -} - -// Kernel with prefix parameter -template < - typename T, - typename U, - typename Op, - int NDIM, - int N_READS, - typename PrefixOp> -__global__ void row_reduce_looped( - const T* in, - U* out, - const __grid_constant__ RowReduceArgs args, - PrefixOp prefix) { - row_reduce_looped_impl( - in, out, args, prefix); -} - -// Kernel without prefix parameter (default Identity) -template < - typename T, - typename U, - typename Op, - int NDIM, - int N_READS = 4, - typename PrefixOp = Identity> -__global__ void row_reduce_looped( - const T* in, - U* out, - const __grid_constant__ RowReduceArgs args) { - row_reduce_looped_impl( - in, out, args, PrefixOp{}); -} - -} // namespace mlx::core::cu diff --git a/mlx/compile.cpp b/mlx/compile.cpp index ef6aa3a3f4..d90275f555 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -802,7 +802,9 @@ void compile_fuse( continue; } // If current op is a reduction, we may want to fuse prefix ops - if (arr.has_primitive() && is_reduction(arr.primitive())) { + // Only fuse for all_reduce + if (arr.has_primitive() && is_reduction(arr.primitive()) && + arr.size() == 1) { auto& reduction_input = arr.inputs()[0]; Stream reduction_stream = arr.primitive().stream(); const int max_prefix_depth = max_compile_depth - 1; // 1 for reduction From c9e47fc8f0c263b5cc049be593915f4c409fcc78 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 2 Feb 2026 18:03:56 +0100 Subject: [PATCH 7/9] revert jit files --- mlx/backend/cuda/jit_module.cpp | 4 ---- mlx/backend/cuda/reduce/col_reduce.cu | 13 ++++++------- mlx/backend/cuda/reduce/row_reduce.cu | 15 +++++++-------- python/tests/test_compile.py | 12 ++++++++++++ 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 34efa20c44..024bf549e5 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -238,10 +238,8 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "ternary_ops.cuh", INCLUDE_PREFIX "utils.cuh", REDUCE_PREFIX "all_reduce.cuh", - REDUCE_PREFIX "col_reduce.cuh", REDUCE_PREFIX "reduce_ops.cuh", REDUCE_PREFIX "reduce_utils.cuh", - REDUCE_PREFIX "row_reduce.cuh", }; #undef INCLUDE_PREFIX @@ -260,10 +258,8 @@ constexpr const char* g_headers[] = { jit_source_ternary_ops, jit_source_utils, jit_source_all_reduce, - jit_source_col_reduce, jit_source_reduce_ops, jit_source_reduce_utils, - jit_source_row_reduce, }; void compile( diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 94fa498873..e33551d86e 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -1,22 +1,21 @@ // Copyright © 2025 Apple Inc. 
-#include "mlx/backend/cuda/device/config.h" -#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" #include #include #include #include -#include -#include "mlx/backend/cuda/reduce/reduce.cuh" - namespace mlx::core { -namespace cg = cooperative_groups; - namespace cu { +namespace cg = cooperative_groups; + struct ColReduceArgs { // The size of the contiguous column reduction. size_t reduction_size; diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 1bdccd1eae..ea99e11325 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -1,20 +1,19 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/device/config.h" -#include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" #include #include -#include -#include "mlx/backend/cuda/reduce/reduce.cuh" - namespace mlx::core { -namespace cg = cooperative_groups; - namespace cu { +namespace cg = cooperative_groups; + struct RowReduceArgs { // The size of the row being reduced, i.e. the size of last dimension. int row_size; @@ -88,7 +87,7 @@ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) { auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - const U init = ReduceInit::value(); + const U init = cu::ReduceInit::value(); ReduceOp op; AlignedVector vals[M]; diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index d64c057fd1..10cd74b487 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1272,6 +1272,18 @@ def fun(x): np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr ) + def test_compile_reduction(self): + x = mx.random.uniform(shape=(4, 4)) + mx.eval(x) + + @mx.compile + def fun(x): + return mx.sum(x, axis=1) + + out = fun(x) + expected = mx.sum(x, axis=1) + self.assertTrue(mx.allclose(out, expected)) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From dc3d220b5bfdc9e158a11b2a3d8ca4ba8458c6fd Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 2 Feb 2026 18:25:36 +0100 Subject: [PATCH 8/9] python tests --- python/tests/test_compile.py | 23 ++++++++++++++++------- tests/compile_tests.cpp | 2 +- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 10cd74b487..8ea10591d8 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1272,16 +1272,25 @@ def fun(x): np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr ) - def test_compile_reduction(self): - x = mx.random.uniform(shape=(4, 4)) - mx.eval(x) + def test_compile_unary_reduction(self): + + x = mx.ones(shape=(4096, 4096)) + y = mx.ones(shape=(4096, 4096)) @mx.compile - def fun(x): - return mx.sum(x, axis=1) + def abs_max(x): + return mx.max(mx.abs(x)) + + out = abs_max(x) + expected = y.abs().max() + self.assertTrue(mx.allclose(out, expected)) + + @mx.compile + def square_sum(x): + return x.square().sum() - out = fun(x) - expected = mx.sum(x, axis=1) + out = square_sum(x) + expected = y.square().sum() self.assertTrue(mx.allclose(out, expected)) diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 901419519c..be3a6d36a9 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -296,7 +296,7 @@ TEST_CASE("test compile unary fused") { CHECK_EQ(out.size(), 1); auto& p = 
out[0].primitive(); - // With fused-into-reduction, unary ops are fused into the Reduce primitive + // Unary ops are fused into the Reduce primitive CHECK_EQ(typeid(p), typeid(Reduce)); auto cout = out[0].inputs()[0]; auto& reduce = static_cast(p); From ce5684897ae0e6f6fab30fc5f39086b4491290ca Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 2 Feb 2026 19:19:31 +0100 Subject: [PATCH 9/9] python tests --- python/tests/test_compile.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 8ea10591d8..a3dd948c59 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1285,14 +1285,5 @@ def abs_max(x): expected = y.abs().max() self.assertTrue(mx.allclose(out, expected)) - @mx.compile - def square_sum(x): - return x.square().sum() - - out = square_sum(x) - expected = y.square().sum() - self.assertTrue(mx.allclose(out, expected)) - - if __name__ == "__main__": mlx_tests.MLXTestRunner()
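
Note (illustrative, not part of the patches above): for the fused sum(abs(x)) case exercised by the "unary reduction" tests, the JIT source produced by FusedReducePrefixBuilder plus the kernel format string in fused_reduce.cu would look roughly like the sketch below. The sum_abs_ prefix and the tmp_* names are placeholders for the hash-derived kernel name and the NodeNamer output, and the generated code is emitted inside namespace mlx::core::cu.

  // Prefix functor generated by build_prefix_struct() for prefix_tape = [abs],
  // prefix_inputs = [x] (float32 input, no constants).
  struct sum_abs_Prefix {
    template <typename T>
    __device__ __forceinline__ T operator()(T val) const {
      float tmp_a = static_cast<float>(val); // the reduction input
      float tmp_b = Abs{}(tmp_a);            // fused unary op from the tape
      return static_cast<T>(tmp_b);          // last tape entry is the result
    }
  };

  // Wrapper emitted by the fmt::format call in fused_all_reduce()
  // (in_type = "float", reduce_op = "Sum", N_READS = 8).
  __global__ void sum_abs_all_reduce(float* in, float* out, size_t block_step, size_t size) {
    sum_abs_Prefix prefix{};
    all_reduce_impl<float, float, Sum, 8, sum_abs_Prefix>(in, out, block_step, size, prefix);
  }
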