diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 013b24b2f4..eb8785fa0f 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -42,6 +42,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
+    ${CMAKE_CURRENT_SOURCE_DIR}/reduce/fused_reduce.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
@@ -90,7 +91,8 @@ file(
   GLOB MLX_JIT_SOURCES
   RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
   "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
-  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
+  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh"
+  "${CMAKE_CURRENT_SOURCE_DIR}/reduce/*.cuh")
 string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
 add_custom_command(
   OUTPUT gen/cuda_jit_sources.h
@@ -210,13 +212,16 @@ install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
         DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
 install(DIRECTORY ${cccl_SOURCE_DIR}/include/nv
         DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
+# CUB and Thrust are needed to JIT compile unary -> reduction kernels.
+install(DIRECTORY ${cccl_SOURCE_DIR}/include/cub
+        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
+install(DIRECTORY ${cccl_SOURCE_DIR}/include/thrust
+        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
 
-# The binary of C++ tests will not be installed so it can not find the CCCL
-# headers, and we have to hard-code the path.
-if(MLX_BUILD_TESTS)
-  target_compile_definitions(mlx
-                             PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")
-endif()
+# The JIT cannot find the CCCL headers when compiling unary -> reduction
+# kernels, so we hard-code the path to the downloaded CCCL.
+target_compile_definitions(mlx
+                           PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")
 
 # Use fixed version of NVTX.
 FetchContent_Declare(
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index b0ebb40195..024bf549e5 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -223,6 +223,7 @@ bool compiler_supports_device_sass(Device& device) {
 }
 
 #define INCLUDE_PREFIX "mlx/backend/cuda/device/"
+#define REDUCE_PREFIX "mlx/backend/cuda/reduce/"
 
 constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
@@ -236,9 +237,13 @@ constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "unary_ops.cuh",
     INCLUDE_PREFIX "ternary_ops.cuh",
     INCLUDE_PREFIX "utils.cuh",
+    REDUCE_PREFIX "all_reduce.cuh",
+    REDUCE_PREFIX "reduce_ops.cuh",
+    REDUCE_PREFIX "reduce_utils.cuh",
 };
 
 #undef INCLUDE_PREFIX
+#undef REDUCE_PREFIX
 
 constexpr const char* g_headers[] = {
     jit_source_atomic_ops,
@@ -252,6 +257,9 @@ constexpr const char* g_headers[] = {
     jit_source_unary_ops,
     jit_source_ternary_ops,
     jit_source_utils,
+    jit_source_all_reduce,
+    jit_source_reduce_ops,
+    jit_source_reduce_utils,
 };
 
 void compile(
diff --git a/mlx/backend/cuda/reduce.cu b/mlx/backend/cuda/reduce.cu
index 269efc034b..9c90b4c1ad 100644
--- a/mlx/backend/cuda/reduce.cu
+++ b/mlx/backend/cuda/reduce.cu
@@ -12,6 +12,39 @@ namespace mlx::core {
 
 void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Reduce::eval_gpu");
+
+  if (has_fused_prefix()) {
+    array in = inputs[0];
+
+    auto& s = stream();
+    auto& encoder = cu::get_command_encoder(s);
+
+    if (in.size() == 0) {
+      init_reduce(encoder, in, out, reduce_type_);
+      return;
+    }
+
+    ReductionPlan plan = get_reduction_plan(in, axes_);
+
+    bool broadcasted = false;
+    for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
+      if (j < axes_.size() && axes_[j] == i) {
+        j++;
+      } else {
+        broadcasted = in.strides(i) == 0;
+      }
+    }
+    if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
+      array in_copy = contiguous_copy_gpu(in, s);
+      encoder.add_temporary(in_copy);
+      in = in_copy;
+      plan = get_reduction_plan(in, axes_);
+    }
+
+    fused_reduce(encoder, *this, {in}, out, axes_, plan, s);
+    return;
+  }
+
   assert(inputs.size() == 1);
   array in = inputs[0];
 
diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index 962e80d4f2..7ed7336701 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -1,64 +1,10 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/all_reduce.cuh" #include "mlx/backend/cuda/reduce/reduce.cuh" -#include -#include -#include - namespace mlx::core { -namespace cu { - -namespace cg = cooperative_groups; - -template -__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { - // TODO: Process multiple "rows" in each thread - constexpr int M = 1; - - auto grid = cg::this_grid(); - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - const U init = cu::ReduceInit::value(); - ReduceOp op; - - T vals[N]; - U accs[M]; - accs[0] = init; - - size_t start = grid.block_rank() * block_step; - size_t end = start + block_step; - size_t check = min(end, size); - - size_t i = start; - for (; i + block.size() * N <= check; i += block.size() * N) { - cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); - for (int j = 0; j < N; j++) { - accs[0] = op(accs[0], cast_to(vals[j])); - } - } - - if (i < check) { - cub::LoadDirectBlocked( - block.thread_rank(), in + i, vals, check - i, cast_to(init)); - for (int i = 0; i < N; i++) { - accs[0] = op(accs[0], cast_to(vals[i])); - } - } - - __shared__ U shared_accumulators[32]; - block_reduce(block, warp, accs, shared_accumulators, op, init); - - if (block.thread_rank() == 0) { - out[grid.block_rank()] = accs[0]; - } -} - -} // namespace cu - void all_reduce( cu::CommandEncoder& encoder, const array& in, diff --git a/mlx/backend/cuda/reduce/all_reduce.cuh b/mlx/backend/cuda/reduce/all_reduce.cuh new file mode 100644 index 0000000000..e661af56e1 --- /dev/null +++ b/mlx/backend/cuda/reduce/all_reduce.cuh @@ -0,0 +1,76 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/reduce/reduce_ops.cuh" + +#include +#include +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template +__device__ void all_reduce_impl( + T* in, + U* out, + size_t block_step, + size_t size, + PrefixOp prefix) { + // TODO: Process multiple "rows" in each thread + constexpr int M = 1; + + auto grid = cg::this_grid(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + const U init = ReduceInit::value(); + ReduceOp op; + + T vals[N]; + U accs[M]; + accs[0] = init; + + size_t start = grid.block_rank() * block_step; + size_t end = start + block_step; + size_t check = min(end, size); + + size_t i = start; + for (; i + block.size() * N <= check; i += block.size() * N) { + cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], cast_to(prefix(vals[j]))); + } + } + + if (i < check) { + cub::LoadDirectBlocked( + block.thread_rank(), in + i, vals, check - i, cast_to(init)); + for (int j = 0; j < N; j++) { + accs[0] = op(accs[0], cast_to(prefix(vals[j]))); + } + } + + __shared__ U shared_accumulators[32]; + block_reduce(block, warp, accs, shared_accumulators, op, init); + + if (block.thread_rank() == 0) { + out[grid.block_rank()] = accs[0]; + } +} + +template < + typename T, + typename U, + typename ReduceOp, + int N = 4, + typename PrefixOp = Identity> +__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { + all_reduce_impl( + in, out, block_step, size, PrefixOp{}); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/reduce/fused_reduce.cu b/mlx/backend/cuda/reduce/fused_reduce.cu new file mode 100644 index 0000000000..74c4a99082 --- /dev/null +++ 
@@ -0,0 +1,351 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/compiled.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/jit_module.h"
+#include "mlx/backend/cuda/reduce/all_reduce.cuh"
+#include "mlx/backend/cuda/reduce/reduce.cuh"
+#include "mlx/backend/cuda/utils.h"
+#include "mlx/graph_utils.h"
+#include "mlx/primitives.h"
+
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+#include <sstream>
+#include <unordered_set>
+
+namespace mlx::core {
+
+namespace cu {
+
+// Builds the source of the prefix functor that the reduction kernel applies
+// to every loaded value.
+struct FusedReducePrefixBuilder {
+  std::string os;
+  const std::string& kernel_name;
+  const std::vector<array>& inputs;
+  const std::vector<array>& tape;
+  const std::function<bool(size_t)>& is_constant;
+
+  void build_prefix_struct() {
+    NodeNamer namer;
+
+    std::unordered_set<std::uintptr_t> constant_set;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      if (is_constant(i)) {
+        constant_set.insert(inputs[i].id());
+      }
+    }
+
+    os += "struct " + kernel_name + "_Prefix {\n";
+
+    os += "\n  template <typename T>\n";
+    os += "  __device__ __forceinline__ T operator()(T val) const {\n";
+
+    // Bind the inputs: constants are inlined, the non-constant input (the
+    // reduction input) is read from the value passed to the functor.
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      const auto& x = inputs[i];
+      const std::string& xname = namer.get_name(x);
+      std::string type = dtype_to_cuda_type(x.dtype());
+      if (is_constant(i)) {
+        std::ostringstream ss;
+        print_constant(ss, x);
+        os += fmt::format(
+            "    {} tmp_{} = static_cast<{}>({}); // constant\n",
+            type,
+            xname,
+            type,
+            ss.str());
+      } else {
+        os += fmt::format(
+            "    {} tmp_{} = static_cast<{}>(val);\n", type, xname, type);
+      }
+    }
+
+    for (const auto& x : tape) {
+      const std::string& xname = namer.get_name(x);
+      std::string type = dtype_to_cuda_type(x.dtype());
+      std::string value;
+      if (is_static_cast(x.primitive())) {
+        value = fmt::format(
+            "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
+      } else {
+        value = x.primitive().name();
+        value += "{}(";
+        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
+          value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
+        }
+        value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
+      }
+      os += fmt::format("    {} tmp_{} = {};\n", type, xname, value);
+    }
+
+    // Return the result (the last tape entry, or the first non-constant
+    // input if the tape is empty)
+    if (!tape.empty()) {
+      os += fmt::format(
+          "    return static_cast<T>(tmp_{});\n", namer.get_name(tape.back()));
+    } else {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!is_constant(i)) {
+          os += fmt::format(
+              "    return static_cast<T>(tmp_{});\n",
+              namer.get_name(inputs[i]));
+          break;
+        }
+      }
+    }
+
+    os += "  }\n";
+
+    os += "};\n\n";
+  }
+};
+
+} // namespace cu
+
+constexpr const char* g_fused_reduce_includes = R"(
+#include "mlx/backend/cuda/device/binary_ops.cuh"
+#include "mlx/backend/cuda/device/ternary_ops.cuh"
+#include "mlx/backend/cuda/device/unary_ops.cuh"
+#include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/reduce/all_reduce.cuh"
+
+#include <cuda/std/limits>
+
+#define inf cuda::std::numeric_limits<float>::infinity()
+)";
+
+std::string get_reduce_op_name(Reduce::ReduceType reduce_type) {
+  switch (reduce_type) {
+    case Reduce::ReduceType::And:
+      return "And";
+    case Reduce::ReduceType::Or:
+      return "Or";
+    case Reduce::ReduceType::Sum:
+      return "Sum";
+    case Reduce::ReduceType::Prod:
+      return "Prod";
+    case Reduce::ReduceType::Max:
+      return "Max";
+    case Reduce::ReduceType::Min:
+      return "Min";
+    default:
throw std::runtime_error("Unknown reduce type"); + } +} + +void fused_all_reduce( + cu::CommandEncoder& encoder, + const Reduce& reduce, + const std::vector& inputs, + array& out, + const Stream& stream) { + nvtx3::scoped_range r("fused_all_reduce"); + + // Copied from all_reduce.cu + constexpr int N_READS = 8; + + const auto& prefix_inputs = reduce.prefix_inputs(); + const auto& prefix_tape = reduce.prefix_tape(); + const auto& prefix_constant_ids = reduce.prefix_constant_ids(); + + auto is_constant = [&](size_t i) -> bool { + return prefix_constant_ids.count(prefix_inputs[i].id()) > 0; + }; + + NodeNamer namer; + std::ostringstream os; + std::ostringstream constant_hasher; + + for (const auto& x : prefix_inputs) { + namer.get_name(x); + } + + // Build string from tape operations + for (const auto& a : prefix_tape) { + os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); + os << a.primitive().name(); + for (const auto& inp : a.inputs()) { + os << namer.get_name(inp); + } + } + // Name the kernel: similar to Compiled::Compiled kernel naming + for (size_t i = 0; i < prefix_inputs.size(); ++i) { + const auto& x = prefix_inputs[i]; + if (is_constant(i)) { + os << "C"; + print_constant(constant_hasher, x); + } else { + os << (is_scalar(x) ? "S" : "V"); + } + } + + os << get_reduce_op_name(reduce.state().first); + os << dtype_to_cuda_type(prefix_inputs[0].dtype()); + os << std::hash{}(constant_hasher.str()); + + std::string kernel_name = os.str(); + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + // Copied from all_reduce.cu + auto get_args = [](int size, int N) { + int threads = std::min(512, (size + N - 1) / N); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = + (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + std::string reduce_op = get_reduce_op_name(reduce.state().first); + std::string in_type = dtype_to_cuda_type(prefix_inputs[0].dtype()); + std::string prefix_type = kernel_name + "_Prefix"; + + std::string full_kernel_name = + fmt::format("mlx::core::cu::{}_all_reduce", kernel_name); + + cu::JitModule& mod = cu::get_jit_module(stream.device, kernel_name, [&]() { + cu::FusedReducePrefixBuilder builder{ + g_fused_reduce_includes, + kernel_name, + prefix_inputs, + prefix_tape, + is_constant}; + builder.os += + "namespace mlx::core::cu {\n\n" + "namespace cg = cooperative_groups;\n\n"; + + // Generate the prefix struct + builder.build_prefix_struct(); + + builder.os += fmt::format( + "__global__ void {}_all_reduce({}* in, {}* out, size_t block_step, size_t size) {{\n" + " {} prefix{{}};\n" + " all_reduce_impl<{}, {}, {}, {}, {}>(in, out, block_step, size, prefix);\n" + "}}\n", + kernel_name, + in_type, + in_type, + prefix_type, + in_type, + in_type, + reduce_op, + N_READS, + prefix_type); + + builder.os += "\n} // namespace mlx::core::cu\n"; + + std::vector kernel_names; + kernel_names.push_back(full_kernel_name); + + return std::make_tuple( + false, std::move(builder.os), std::move(kernel_names)); + }); + + int blocks, threads; + size_t block_step; + + 
+  size_t insize = inputs[0].size();
+  Reduce::ReduceType reduce_type = reduce.state().first;
+
+  std::tie(blocks, threads, block_step) = get_args(insize, N_READS);
+
+  encoder.set_input_array(inputs[0]);
+
+  // If the reduction needs more than one block, run the fused kernel
+  // (prefix + reduce) into an intermediate buffer first, then reduce the
+  // intermediate to the final output with the prebuilt all_reduce kernel.
+  if (blocks > 1) {
+    array intermediate({blocks}, out.dtype(), nullptr, {});
+    intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
+    encoder.add_temporary(intermediate);
+    encoder.set_output_array(intermediate);
+
+    // First pass: apply the fused prefix and reduce into the intermediate
+    cu::KernelArgs args;
+    args.append(inputs[0]);
+    args.append(intermediate);
+    args.append(block_step);
+    args.append(insize);
+
+    auto kernel = mod.get_kernel(full_kernel_name);
+    encoder.add_kernel_node(kernel, blocks, threads, 0, args.args());
+
+    // Second pass: reduce the intermediate to the final output
+    size_t intermediate_size = intermediate.size();
+    std::tie(blocks, threads, block_step) =
+        get_args(intermediate_size, N_READS);
+    encoder.set_input_array(intermediate);
+    encoder.set_output_array(out);
+
+    dispatch_all_types(out.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
+        auto kernel2 = cu::all_reduce<T, U, OP, N_READS>;
+        encoder.add_kernel_node(
+            kernel2,
+            blocks,
+            threads,
+            0,
+            gpu_ptr<T>(intermediate),
+            gpu_ptr<U>(out),
+            block_step,
+            intermediate_size);
+      });
+    });
+  } else {
+    // Single block: direct reduction with the fused prefix
+    encoder.set_output_array(out);
+
+    cu::KernelArgs args;
+    args.append(inputs[0]);
+    args.append(out);
+    args.append(block_step);
+    args.append(insize);
+
+    auto kernel = mod.get_kernel(full_kernel_name);
+    encoder.add_kernel_node(kernel, blocks, threads, 0, args.args());
+  }
+}
+
+void fused_reduce(
+    cu::CommandEncoder& encoder,
+    const Reduce& reduce,
+    const std::vector<array>& inputs,
+    array& out,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    const Stream& stream) {
+  if (plan.type == ContiguousAllReduce) {
+    fused_all_reduce(encoder, reduce, inputs, out, stream);
+    return;
+  }
+
+  // TODO: Implement fused row_reduce and col_reduce
+  throw std::runtime_error(
+      "Fused reduce not yet implemented for this reduction type");
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh
index 02e495594a..078685e345 100644
--- a/mlx/backend/cuda/reduce/reduce.cuh
+++ b/mlx/backend/cuda/reduce/reduce.cuh
@@ -68,4 +68,13 @@ void init_reduce(
     array& out,
     Reduce::ReduceType reduce_type);
 
+void fused_reduce(
+    cu::CommandEncoder& encoder,
+    const Reduce& reduce,
+    const std::vector<array>& inputs,
+    array& out,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    const Stream& stream);
+
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index 6c6b1827ce..a8302dded4 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -9,6 +9,13 @@
 
 namespace mlx::core::cu {
 
+struct Identity {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T x) const {
+    return x;
+  }
+};
+
 // Reduce ops.
 struct And {
   __device__ __forceinline__ bool operator()(bool a, bool b) {
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index fdbe723378..4122e32b42 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -2,15 +2,18 @@
 
 #pragma once
 
-#include <numeric>
-
-#include "mlx/backend/common/utils.h"
-#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 
+// Host-only includes
+#ifndef __CUDACC_RTC__
+#include <numeric>
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/device.h"
+#endif
+
 namespace mlx::core {
 
 namespace cu {
@@ -90,6 +93,7 @@ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
 }
 
 } // namespace cu
 
+#ifndef __CUDACC_RTC__
 inline void allocate_same_layout(
     array& out,
     const array& in,
@@ -141,5 +145,6 @@ inline void allocate_same_layout(
       fl,
       allocator::free);
 }
+#endif
 
 } // namespace mlx::core
diff --git a/mlx/compile.cpp b/mlx/compile.cpp
index ca5f069937..d90275f555 100644
--- a/mlx/compile.cpp
+++ b/mlx/compile.cpp
@@ -801,6 +801,81 @@ void compile_fuse(
     if (global_cache.find(arr.id()) != global_cache.end()) {
       continue;
     }
+    // If the current op is a reduction producing a scalar, try to fuse its
+    // unary prefix ops into it (only all-reduce is handled for now).
+    if (arr.has_primitive() && is_reduction(arr.primitive()) &&
+        arr.size() == 1) {
+      auto& reduction_input = arr.inputs()[0];
+      Stream reduction_stream = arr.primitive().stream();
+      const int max_prefix_depth = max_compile_depth - 1; // 1 for reduction
+
+      std::vector<array> prefix_tape;
+      std::vector<array> prefix_inputs;
+      std::unordered_set<std::uintptr_t> visited;
+
+      std::function<void(const array&, int)> collect_prefix;
+      collect_prefix = [&](const array& a, int depth) {
+        // Skip if already processed
+        if (visited.count(a.id())) {
+          return;
+        }
+        // Stop fusing if:
+        //   - the depth limit is exceeded
+        //   - the array has no primitive
+        //   - the primitive is not unary
+        //   - the stream does not match the reduction's stream
+        //   - the array is an input of the compiled graph
+        if (depth >= max_prefix_depth || !a.has_primitive() ||
+            !is_unary(a.primitive()) ||
+            a.primitive().stream() != reduction_stream ||
+            input_ids.count(a.id())) {
+          prefix_inputs.push_back(a);
+          visited.insert(a.id());
+          return;
+        }
+        // Stop if the intermediate is used by more than one node
+        auto pit = parents_map.find(a.id());
+        if (pit != parents_map.end() && pit->second.size() > 1) {
+          prefix_inputs.push_back(a);
+          visited.insert(a.id());
+          return;
+        }
+        visited.insert(a.id());
+        for (auto& in : a.inputs()) {
+          collect_prefix(in, depth + 1);
+        }
+        prefix_tape.push_back(a);
+      };
+
+      collect_prefix(reduction_input, 0);
+
+      // If there are operations that we can fuse
+      if (!prefix_tape.empty()) {
+        std::unordered_set<std::uintptr_t> constant_ids;
+        for (auto& in : prefix_inputs) {
+          if (in.size() == 1 && !in.has_primitive() &&
+              input_ids.find(in.id()) == input_ids.end()) {
+            constant_ids.insert(in.id());
+          }
+        }
+
+        // Attach the prefix to the Reduce primitive
+        auto& reduce = static_cast<Reduce&>(arr.primitive());
+        reduce.set_fused_prefix(
+            prefix_tape, prefix_inputs, std::move(constant_ids));
+
+        for (auto& p : prefix_tape) {
+          global_cache.insert(p.id());
+          parents_map.erase(p.id());
+        }
+        arr.inputs() = std::move(prefix_inputs);
+      }
+
+      // Add the reduction to the new tape (with or without a fused prefix)
+      new_tape.push_back(arr);
+      global_cache.insert(arr.id());
+      continue;
+    }
 
     // Two pass recursion:
     // First pass:
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 4091aafcfb..cfb96173dc 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1809,9 +1809,39 @@ class MLX_API Reduce : public UnaryPrimitive {
     return {reduce_type_, axes_};
   };
+
+  void set_fused_prefix(
+      std::vector<array> tape,
+      std::vector<array> inputs,
+      std::unordered_set<std::uintptr_t> constant_ids) {
+    prefix_tape_ = std::move(tape);
+    prefix_inputs_ = std::move(inputs);
+    prefix_constant_ids_ = std::move(constant_ids);
+  }
+
+  bool has_fused_prefix() const {
+    return !prefix_tape_.empty();
+  }
+
+  const std::vector<array>& prefix_tape() const {
+    return prefix_tape_;
+  }
+
+  const std::vector<array>& prefix_inputs() const {
+    return prefix_inputs_;
+  }
+
+  const std::unordered_set<std::uintptr_t>& prefix_constant_ids() const {
+    return prefix_constant_ids_;
+  }
 
  private:
   ReduceType reduce_type_;
   std::vector<int> axes_;
+
+  // Fused prefix storage
+  std::vector<array> prefix_tape_;
+  std::vector<array> prefix_inputs_;
+  std::unordered_set<std::uintptr_t> prefix_constant_ids_;
 };
 
 class Round : public UnaryPrimitive {
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index d64c057fd1..a3dd948c59 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -1272,6 +1272,18 @@ def fun(x):
             np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
         )
 
+    def test_compile_unary_reduction(self):
+        x = mx.ones(shape=(4096, 4096))
+
+        @mx.compile
+        def abs_max(x):
+            return mx.max(mx.abs(x))
+
+        out = abs_max(x)
+        expected = mx.abs(x).max()
+        self.assertTrue(mx.allclose(out, expected))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp
index d2146f3bad..be3a6d36a9 100644
--- a/tests/compile_tests.cpp
+++ b/tests/compile_tests.cpp
@@ -296,13 +296,13 @@ TEST_CASE("test compile unary fused") {
     CHECK_EQ(out.size(), 1);
 
     auto& p = out[0].primitive();
-    // NB: this test is brittle, will need to update
-    // it if we change compile conditions
+    // The unary ops are fused into the Reduce primitive itself
    CHECK_EQ(typeid(p), typeid(Reduce));
     auto cout = out[0].inputs()[0];
-    auto& cp = cout.primitive();
-    CHECK_EQ(typeid(cp), typeid(Compiled));
-    CHECK_EQ(cout.inputs()[0].id(), x.id());
+    auto& reduce = static_cast<Reduce&>(p);
+    CHECK(reduce.has_fused_prefix());
+    // The input to the Reduce is x itself, not a Compiled primitive
+    CHECK_EQ(cout.id(), x.id());
   }
 
   {
@@ -397,10 +397,11 @@
 
   auto& p = out.primitive();
   CHECK_EQ(typeid(p), typeid(Reduce));
-
+  // Only the unary op (abs) is fused into the Reduce; the binary op (Add)
+  // remains as the input to the fused reduction.
   auto cout = out.inputs()[0];
   auto& cp = cout.primitive();
-  CHECK_EQ(typeid(cp), typeid(Compiled));
+  CHECK_EQ(typeid(cp), typeid(Add));
   CHECK_EQ(cout.inputs()[0].id(), x.id());
   CHECK_EQ(cout.inputs()[1].id(), y.id());
 }
@@ -816,3 +817,36 @@ TEST_CASE("test compile random bits") {
   auto out = compile(fun)({in})[0];
   CHECK(array_equal(out, expected).item<bool>());
 }
+
+TEST_CASE("test compile unary reduction one pass") {
+  auto fun = [](const std::vector<array>& inputs) {
+    auto x = inputs[0];
+    return std::vector<array>{sum(abs(x))};
+  };
+  auto in = ones({128, 128});
+  auto expected = fun({in})[0];
+  auto out = compile(fun)({ones({128, 128})})[0];
+  CHECK(array_equal(out, expected).item<bool>());
+}
+
+TEST_CASE("test compile unary reduction two passes") {
+  auto fun = [](const std::vector<array>& inputs) {
+    auto x = inputs[0];
+    return std::vector<array>{sum(abs(x))};
+  };
+  auto in = ones({1024, 1024});
+  auto expected = fun({in})[0];
+  auto out = compile(fun)({ones({1024, 1024})})[0];
+  CHECK(array_equal(out, expected).item<bool>());
+}
+
passes") { + auto fun = [](const std::vector& inputs) { + auto x = inputs[0]; + return std::vector{sum(abs(x) + 1.0f)}; + }; + auto in = ones({1024, 1024}); + auto expected = fun({in})[0]; + auto out = compile(fun)({ones({1024, 1024})})[0]; + CHECK(array_equal(out, expected).item()); +} \ No newline at end of file