diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 4074e7b1e9..4af62f517d 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -48,6 +48,7 @@ if(MLX_METAL_JIT) make_jit_source(softmax) make_jit_source(scan) make_jit_source(sort) + make_jit_source(radix_select) make_jit_source( reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h) diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index a6ef0f14af..f93d1e593c 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -31,6 +31,7 @@ const char* scan(); const char* scatter_axis(); const char* softmax(); const char* sort(); +const char* radix_select(); const char* reduce(); const char* gemm(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index b657457e95..28b3fc953d 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -429,6 +429,44 @@ MTL::ComputePipelineState* get_mb_sort_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_radix_select_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + int bn, + int tn) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name, [&]() { + std::ostringstream kernel_source; + auto in_type = get_type_string(in.dtype()); + auto out_type = get_type_string(out.dtype()); + kernel_source << metal::utils() << metal::radix_select(); + for (bool is_arg_partition : {true, false}) { + std::string bool_string = is_arg_partition ? "true" : "false"; + std::string func_string = is_arg_partition ? 
"carg_" : "c_"; + kernel_source << get_template_definition( + func_string + lib_name, + "radix_select_partition", + in_type, + out_type, + bool_string, + bn, + tn); + kernel_source << get_template_definition( + "n" + func_string + lib_name, + "radix_select_partition_nc", + in_type, + out_type, + bool_string, + bn, + tn); + } + return kernel_source.str(); + }); + return d.get_kernel(kernel_name, lib); +} + MTL::ComputePipelineState* get_reduce_init_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 82aa4f976a..d8d1cb19c5 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -89,6 +89,14 @@ MTL::ComputePipelineState* get_mb_sort_kernel( int bn, int tn); +MTL::ComputePipelineState* get_radix_select_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out, + int bn, + int tn); + MTL::ComputePipelineState* get_reduce_init_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 7d010e6c8c..0169b9268d 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -140,6 +140,7 @@ if(NOT MLX_METAL_JIT) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) build_kernel(sort sort.h) + build_kernel(radix_select radix_select.h) build_kernel(ternary ternary.h ternary_ops.h) build_kernel(unary unary.h unary_ops.h) build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS}) diff --git a/mlx/backend/metal/kernels/radix_select.h b/mlx/backend/metal/kernels/radix_select.h new file mode 100644 index 0000000000..7654550c18 --- /dev/null +++ b/mlx/backend/metal/kernels/radix_select.h @@ -0,0 +1,1506 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include "mlx/backend/metal/kernels/bf16.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Radix Select Implementation for Metal +// +// Multi-pass radix-based selection algorithm for partition operations. +// Uses IEEE 754 bit manipulation for correct floating-point ordering. 
+/////////////////////////////////////////////////////////////////////////////// + +// Radix configuration +constant constexpr int RADIX_BITS = 8; +constant constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 256 bins +constant constexpr int SIMD_SIZE = 32; + +/////////////////////////////////////////////////////////////////////////////// +// Bit manipulation for radix sorting +/////////////////////////////////////////////////////////////////////////////// + +template +struct RadixTraits; + +template <> +struct RadixTraits { + using UnsignedT = uint32_t; + static constexpr constant int BITS = 32; + + static METAL_FUNC UnsignedT to_radix(float val) { + UnsignedT bits = as_type(val); + UnsignedT mask = -int32_t(bits >> 31) | 0x80000000u; + return bits ^ mask; + } + + static METAL_FUNC float from_radix(UnsignedT bits) { + UnsignedT mask = ((bits >> 31) - 1) | 0x80000000u; + return as_type(bits ^ mask); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint16_t; + static constexpr constant int BITS = 16; + + static METAL_FUNC UnsignedT to_radix(half val) { + UnsignedT bits = as_type(val); + UnsignedT mask = -int16_t(bits >> 15) | 0x8000u; + return bits ^ mask; + } + + static METAL_FUNC half from_radix(UnsignedT bits) { + UnsignedT mask = ((bits >> 15) - 1) | 0x8000u; + UnsignedT result = bits ^ mask; + return as_type(result); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint16_t; + static constexpr constant int BITS = 16; + + static METAL_FUNC UnsignedT to_radix(bfloat16_t val) { + UnsignedT bits = as_type(val); + UnsignedT mask = -int16_t(bits >> 15) | 0x8000u; + return bits ^ mask; + } + + static METAL_FUNC bfloat16_t from_radix(UnsignedT bits) { + UnsignedT mask = ((bits >> 15) - 1) | 0x8000u; + UnsignedT result = bits ^ mask; + return as_type(result); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint8_t; + static constexpr constant int BITS = 8; + static METAL_FUNC UnsignedT to_radix(int8_t val) { + return static_cast(val) ^ 0x80u; + } + static METAL_FUNC int8_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x80u); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint16_t; + static constexpr constant int BITS = 16; + static METAL_FUNC UnsignedT to_radix(int16_t val) { + return static_cast(val) ^ 0x8000u; + } + static METAL_FUNC int16_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x8000u); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint32_t; + static constexpr constant int BITS = 32; + static METAL_FUNC UnsignedT to_radix(int32_t val) { + return static_cast(val) ^ 0x80000000u; + } + static METAL_FUNC int32_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x80000000u); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint64_t; + static constexpr constant int BITS = 64; + static METAL_FUNC UnsignedT to_radix(int64_t val) { + return static_cast(val) ^ 0x8000000000000000ull; + } + static METAL_FUNC int64_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x8000000000000000ull); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint8_t; + static constexpr constant int BITS = 8; + static METAL_FUNC UnsignedT to_radix(uint8_t val) { + return val; + } + static METAL_FUNC uint8_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint16_t; + static constexpr constant int BITS = 16; + static METAL_FUNC UnsignedT to_radix(uint16_t val) { + return val; + } + static METAL_FUNC uint16_t 
from_radix(UnsignedT bits) { + return bits; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint32_t; + static constexpr constant int BITS = 32; + static METAL_FUNC UnsignedT to_radix(uint32_t val) { + return val; + } + static METAL_FUNC uint32_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint64_t; + static constexpr constant int BITS = 64; + static METAL_FUNC UnsignedT to_radix(uint64_t val) { + return val; + } + static METAL_FUNC uint64_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template +METAL_FUNC int extract_digit(UnsignedT val, int start_bit, int num_bits) { + return (val >> start_bit) & ((1 << num_bits) - 1); +} + +template +METAL_FUNC bool is_nan_value(T val) { + if constexpr (is_floating_point_v) { + return isnan(val); + } else { + return false; + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Multi-pass Radix Select Kernels +/////////////////////////////////////////////////////////////////////////////// + +// Build histogram across all elements +template +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_histogram_kernel( + const device ValT* input [[buffer(0)]], + device atomic_int* histogram [[buffer(1)]], + const constant int& n [[buffer(2)]], + const constant int& stride [[buffer(3)]], + const constant int& start_bit [[buffer(4)]], + const constant int& segment_stride [[buffer(5)]], + const constant uint64_t& prefix_mask [[buffer(6)]], + const constant uint64_t& target_prefix [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 grid_dims [[threadgroups_per_grid]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + threadgroup int shared_hist[RADIX_SIZE]; + + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int row = tid.y; + const device ValT* row_input = input + row * segment_stride; + + // Each threadgroup processes a chunk of the array + int total_threads = grid_dims.x * BLOCK_THREADS; + int global_tid = tid.x * BLOCK_THREADS + lid.x; + + for (int i = global_tid; i < n; i += total_threads) { + ValT val = row_input[i * stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + // Only count if matches current prefix + if ((key & UnsignedT(prefix_mask)) == UnsignedT(target_prefix)) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduce to global histogram + device atomic_int* row_hist = histogram + row * RADIX_SIZE; + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + if (shared_hist[i] > 0) { + atomic_fetch_add_explicit( + &row_hist[i], shared_hist[i], memory_order_relaxed); + } + } +} + +// Find target bin from histogram +template +[[kernel]] void radix_find_bin_kernel( + const device int* histogram [[buffer(0)]], + device int* target_bin [[buffer(1)]], + device int* new_k [[buffer(2)]], + const constant int& k [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]]) { + int row = tid.y; + const device int* row_hist = histogram + row * RADIX_SIZE; + + int cumsum = 0; + int bin = 0; + int remaining_k = k; + + for (int i = 0; i < RADIX_SIZE; i++) { + int count = row_hist[i]; + if (cumsum + 
count >= k) { + bin = i; + remaining_k = k - cumsum; + break; + } + cumsum += count; + } + + target_bin[row] = bin; + new_k[row] = remaining_k; +} + +// Partition output with known pivot +template +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_partition_output_kernel( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + device atomic_int* counters [[buffer(2)]], + const constant int& n [[buffer(3)]], + const constant int& in_stride [[buffer(4)]], + const constant int& out_stride [[buffer(5)]], + const constant int& segment_stride [[buffer(6)]], + const constant int& out_segment_stride [[buffer(7)]], + const constant uint64_t& pivot_key [[buffer(8)]], + const constant int& kth [[buffer(9)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 grid_dims [[threadgroups_per_grid]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + int row = tid.y; + const device ValT* row_input = input + row * segment_stride; + device OutT* row_output = output + row * out_segment_stride; + + // Counters: [0] = less_count, [1] = equal_count, [2] = greater_count + device atomic_int* row_counters = counters + row * 3; + + int total_threads = grid_dims.x * BLOCK_THREADS; + int global_tid = tid.x * BLOCK_THREADS + lid.x; + + UnsignedT pivot = UnsignedT(pivot_key); + + // Phase 1: Count and output elements less than pivot + for (int i = global_tid; i < n; i += total_threads) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if (key < pivot) { + int pos = + atomic_fetch_add_explicit(&row_counters[0], 1, memory_order_relaxed); + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } + } +} + +// Output equal elements +template +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_partition_equal_kernel( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + device atomic_int* counters [[buffer(2)]], + const constant int& n [[buffer(3)]], + const constant int& in_stride [[buffer(4)]], + const constant int& out_stride [[buffer(5)]], + const constant int& segment_stride [[buffer(6)]], + const constant int& out_segment_stride [[buffer(7)]], + const constant uint64_t& pivot_key [[buffer(8)]], + const constant int& less_count [[buffer(9)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 grid_dims [[threadgroups_per_grid]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + int row = tid.y; + const device ValT* row_input = input + row * segment_stride; + device OutT* row_output = output + row * out_segment_stride; + device atomic_int* row_counters = counters + row * 3; + + int total_threads = grid_dims.x * BLOCK_THREADS; + int global_tid = tid.x * BLOCK_THREADS + lid.x; + + UnsignedT pivot = UnsignedT(pivot_key); + + for (int i = global_tid; i < n; i += total_threads) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if (key == pivot) { + int pos = less_count + + atomic_fetch_add_explicit(&row_counters[1], 1, memory_order_relaxed); + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } + } +} + +// Output greater elements +template +[[kernel, 
max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_partition_greater_kernel( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + device atomic_int* counters [[buffer(2)]], + const constant int& n [[buffer(3)]], + const constant int& in_stride [[buffer(4)]], + const constant int& out_stride [[buffer(5)]], + const constant int& segment_stride [[buffer(6)]], + const constant int& out_segment_stride [[buffer(7)]], + const constant uint64_t& pivot_key [[buffer(8)]], + const constant int& less_equal_count [[buffer(9)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 grid_dims [[threadgroups_per_grid]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + int row = tid.y; + const device ValT* row_input = input + row * segment_stride; + device OutT* row_output = output + row * out_segment_stride; + device atomic_int* row_counters = counters + row * 3; + + int total_threads = grid_dims.x * BLOCK_THREADS; + int global_tid = tid.x * BLOCK_THREADS + lid.x; + + UnsignedT pivot = UnsignedT(pivot_key); + + for (int i = global_tid; i < n; i += total_threads) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if (key > pivot) { + int pos = less_equal_count + + atomic_fetch_add_explicit(&row_counters[2], 1, memory_order_relaxed); + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Fused Multi-pass Radix Select for Large Arrays +// +// Performs the complete radix select in a single dispatch: +// 1. Build histograms in parallel across threadgroups +// 2. Reduce histograms and find pivot +// 3. 
Output partitioned results +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_select_large_fused( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + device int* global_histogram [[buffer(2)]], + device atomic_int* global_counters [[buffer(3)]], + device int* pivot_info [[buffer(4)]], + const constant int& n [[buffer(5)]], + const constant int& kth [[buffer(6)]], + const constant int& in_stride [[buffer(7)]], + const constant int& out_stride [[buffer(8)]], + const constant int& segment_stride [[buffer(9)]], + const constant int& out_segment_stride [[buffer(10)]], + const constant int& num_blocks [[buffer(11)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_group [[simdgroup_index_in_threadgroup]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + constexpr int NUM_SIMD_GROUPS = BLOCK_THREADS / SIMD_SIZE; + + int row = tid.y; + int block_id = tid.x; + const device ValT* row_input = input + row * segment_stride; + device OutT* row_output = output + row * out_segment_stride; + + // Shared memory for histogram and reduction + threadgroup int shared_hist[RADIX_SIZE]; + threadgroup int simd_hist[NUM_SIMD_GROUPS][RADIX_SIZE]; + threadgroup int + shared_pivot[4]; // [target_bin, new_k, less_count, equal_count] + threadgroup UnsignedT shared_pivot_key[1]; + + // Per-row global state + device int* row_histogram = global_histogram + row * RADIX_SIZE; + device atomic_int* row_counters = global_counters + row * 4; + device int* row_pivot = pivot_info + row * 4; + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + // Multi-pass radix select + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + // Clear shared histogram + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + // Clear SIMD histograms + if (simd_lane < RADIX_SIZE / SIMD_SIZE) { + for (int s = 0; s < NUM_SIMD_GROUPS; s++) { + simd_hist[s][simd_lane * SIMD_SIZE + simd_group] = 0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Phase 1: Build local histogram with SIMD optimization + // Each thread maintains private histogram bins + int private_hist[4] = {0, 0, 0, 0}; // Process 4 bins at a time + + int elements_per_block = (n + num_blocks - 1) / num_blocks; + int start_idx = block_id * elements_per_block; + int end_idx = min(start_idx + elements_per_block, n); + + for (int i = start_idx + lid.x; i < end_idx; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + // Only count elements matching current prefix + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + // Use SIMD shuffle to aggregate within SIMD group + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&simd_hist[simd_group][digit], + 1, + memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Reduce SIMD histograms to shared histogram + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + int sum = 0; + for (int s = 0; s < NUM_SIMD_GROUPS; s++) { + sum += simd_hist[s][i]; + } + shared_hist[i] = sum; + } + 
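+    // At this point each thread has folded the per-SIMD-group counts for its
+    // bins into shared_hist; the barrier below publishes the combined
+    // per-threadgroup histogram to the whole group before Phase 2 accumulates
+    // it into this row's device-memory histogram.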
threadgroup_barrier(mem_flags::mem_threadgroup); + + // Phase 2: Reduce to global histogram (only first block does final + // reduction) + if (block_id == 0) { + // Clear global histogram first + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + row_histogram[i] = 0; + } + } + threadgroup_barrier(mem_flags::mem_device); + + // All blocks contribute to global histogram + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + if (shared_hist[i] > 0) { + atomic_fetch_add_explicit( + (device atomic_int*)&row_histogram[i], + shared_hist[i], + memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_device); + + // Phase 3: Find target bin (only block 0, thread 0) + if (block_id == 0 && lid.x == 0) { + int cumsum = 0; + int target_bin = 0; + int remaining_k = k; + + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = row_histogram[bin]; + if (cumsum + count >= k) { + target_bin = bin; + remaining_k = k - cumsum; + break; + } + cumsum += count; + } + + shared_pivot[0] = target_bin; + shared_pivot[1] = remaining_k; + row_pivot[pass * 2] = target_bin; + row_pivot[pass * 2 + 1] = remaining_k; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // All threads read the pivot info + int target_bin = shared_pivot[0]; + k = shared_pivot[1]; + + // Update prefix for next pass + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Store final pivot key + if (block_id == 0 && lid.x == 0) { + shared_pivot_key[0] = target_prefix; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + UnsignedT pivot_key = shared_pivot_key[0]; + + // Phase 4: Output partitioned array + // Reset counters + if (block_id == 0 && lid.x == 0) { + atomic_store_explicit(&row_counters[0], 0, memory_order_relaxed); // less + atomic_store_explicit(&row_counters[1], 0, memory_order_relaxed); // equal + atomic_store_explicit(&row_counters[2], 0, memory_order_relaxed); // greater + } + threadgroup_barrier(mem_flags::mem_device); + + // Count elements in each partition + int local_less = 0, local_equal = 0, local_greater = 0; + int elements_per_block = (n + num_blocks - 1) / num_blocks; + int start_idx = block_id * elements_per_block; + int end_idx = min(start_idx + elements_per_block, n); + + for (int i = start_idx + lid.x; i < end_idx; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if (key < pivot_key) + local_less++; + else if (key == pivot_key) + local_equal++; + else + local_greater++; + } + + // Reduce within SIMD group + local_less = simd_sum(local_less); + local_equal = simd_sum(local_equal); + local_greater = simd_sum(local_greater); + + // First lane of each SIMD group contributes to global count + if (simd_lane == 0) { + atomic_fetch_add_explicit( + &row_counters[0], local_less, memory_order_relaxed); + atomic_fetch_add_explicit( + &row_counters[1], local_equal, memory_order_relaxed); + atomic_fetch_add_explicit( + &row_counters[2], local_greater, memory_order_relaxed); + } + threadgroup_barrier(mem_flags::mem_device); + + // Read final counts + if (lid.x == 0) { + shared_pivot[2] = + atomic_load_explicit(&row_counters[0], memory_order_relaxed); + shared_pivot[3] = + atomic_load_explicit(&row_counters[1], memory_order_relaxed); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int less_count = 
shared_pivot[2]; + int equal_count = shared_pivot[3]; + + // Reset counters for output phase + if (block_id == 0 && lid.x == 0) { + atomic_store_explicit(&row_counters[0], 0, memory_order_relaxed); + atomic_store_explicit(&row_counters[1], 0, memory_order_relaxed); + atomic_store_explicit(&row_counters[2], 0, memory_order_relaxed); + } + threadgroup_barrier(mem_flags::mem_device); + + // Output elements + for (int i = start_idx + lid.x; i < end_idx; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < pivot_key) { + pos = + atomic_fetch_add_explicit(&row_counters[0], 1, memory_order_relaxed); + } else if (key == pivot_key) { + pos = less_count + + atomic_fetch_add_explicit(&row_counters[1], 1, memory_order_relaxed); + } else { + pos = less_count + equal_count + + atomic_fetch_add_explicit(&row_counters[2], 1, memory_order_relaxed); + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } +} + +// Large array streaming kernel +template +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_select_large_streaming( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + device atomic_int* counters [[buffer(2)]], + const constant int& n [[buffer(3)]], + const constant int& kth [[buffer(4)]], + const constant int& in_stride [[buffer(5)]], + const constant int& out_stride [[buffer(6)]], + const constant int& segment_stride [[buffer(7)]], + const constant int& out_segment_stride [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_group [[simdgroup_index_in_threadgroup]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + int row = tid.y; + const device ValT* row_input = input + row * segment_stride; + device OutT* row_output = output + row * out_segment_stride; + + // Shared memory + threadgroup int shared_hist[RADIX_SIZE]; + threadgroup int shared_pivot_info[2]; + threadgroup int shared_counts[2]; + threadgroup int shared_output_counters[3]; + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + // Check if data is contiguous for fast path + const bool is_contiguous = (in_stride == 1); + + // First pass - no prefix filtering needed (prefix_mask == 0) + { + int start_bit = (NUM_PASSES - 1) * RADIX_BITS; + + // Clear histogram + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Build histogram - no prefix check needed on first pass + if (is_contiguous) { + // Process 4 elements at a time for better memory throughput + int n4 = n & ~3; // Round down to multiple of 4 + for (int i = lid.x * 4; i < n4; i += BLOCK_THREADS * 4) { + ValT val0 = row_input[i]; + ValT val1 = row_input[i + 1]; + ValT val2 = row_input[i + 2]; + ValT val3 = row_input[i + 3]; + + UnsignedT key0 = Traits::to_radix(val0); + UnsignedT key1 = Traits::to_radix(val1); + UnsignedT key2 = Traits::to_radix(val2); + UnsignedT key3 = Traits::to_radix(val3); + + if (is_nan_value(val0)) + key0 = ~UnsignedT(0); + if (is_nan_value(val1)) + key1 = ~UnsignedT(0); + if (is_nan_value(val2)) + key2 = ~UnsignedT(0); + if (is_nan_value(val3)) + key3 = ~UnsignedT(0); + + int 
digit0 = extract_digit(key0, start_bit, RADIX_BITS); + int digit1 = extract_digit(key1, start_bit, RADIX_BITS); + int digit2 = extract_digit(key2, start_bit, RADIX_BITS); + int digit3 = extract_digit(key3, start_bit, RADIX_BITS); + + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit0], + 1, + memory_order_relaxed); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit1], + 1, + memory_order_relaxed); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit2], + 1, + memory_order_relaxed); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit3], + 1, + memory_order_relaxed); + } + // Handle remaining elements + for (int i = n4 + lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) + key = ~UnsignedT(0); + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } else { + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find target bin + if (lid.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_pivot_info[0] = target_bin; + shared_pivot_info[1] = k; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int target_bin = shared_pivot_info[0]; + k = shared_pivot_info[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Remaining passes - need prefix filtering + for (int pass = NUM_PASSES - 2; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + // Clear histogram + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Build histogram with prefix filtering + if (is_contiguous) { + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } + } else { + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find target bin + if (lid.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + 
target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_pivot_info[0] = target_bin; + shared_pivot_info[1] = k; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int target_bin = shared_pivot_info[0]; + k = shared_pivot_info[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + // Initialize counters for next phase while we have the barrier + if (lid.x == 0) { + shared_counts[0] = 0; + shared_counts[1] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Count partition sizes with SIMD reduction + int local_less = 0, local_equal = 0; + if (is_contiguous) { + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if (key < target_prefix) + local_less++; + else if (key == target_prefix) + local_equal++; + } + } else { + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if (key < target_prefix) + local_less++; + else if (key == target_prefix) + local_equal++; + } + } + + // SIMD reduction + local_less = simd_sum(local_less); + local_equal = simd_sum(local_equal); + + // Aggregate across SIMD groups (only first lane of each SIMD group) + if (simd_lane == 0) { + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_counts[0], + local_less, + memory_order_relaxed); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_counts[1], + local_equal, + memory_order_relaxed); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read final counts - all threads read the same values + int less_count = shared_counts[0]; + int equal_count = shared_counts[1]; + + // Initialize output counters + if (lid.x == 0) { + shared_output_counters[0] = 0; // less output counter + shared_output_counters[1] = 0; // equal output counter + shared_output_counters[2] = 0; // greater output counter + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Output partitioned elements + if (is_contiguous && out_stride == 1) { + // Fast path: both input and output are contiguous + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < target_prefix) { + pos = atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[0], + 1, + memory_order_relaxed); + } else if (key == target_prefix) { + pos = less_count + + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[1], + 1, + memory_order_relaxed); + } else { + pos = less_count + equal_count + + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[2], + 1, + memory_order_relaxed); + } + + if (ARG_PARTITION) { + row_output[pos] = i; + } else { + row_output[pos] = val; + } + } + } else { + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < target_prefix) { + pos = atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[0], + 1, + memory_order_relaxed); + } else if (key == target_prefix) { + pos = less_count + + atomic_fetch_add_explicit( + (threadgroup 
atomic_int*)&shared_output_counters[1], + 1, + memory_order_relaxed); + } else { + pos = less_count + equal_count + + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[2], + 1, + memory_order_relaxed); + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } + } +} + +// Large array streaming kernel for non-contiguous arrays +// Uses elem_to_loc for proper multi-dimensional indexing +template +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_select_large_streaming_nc( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + device atomic_int* counters [[buffer(2)]], + const constant int& n [[buffer(3)]], + const constant int& kth [[buffer(4)]], + const constant int& in_stride [[buffer(5)]], + const constant int& out_stride [[buffer(6)]], + const constant int& nc_dim [[buffer(7)]], + const constant int* nc_shape [[buffer(8)]], + const constant int64_t* in_nc_strides [[buffer(9)]], + const constant int64_t* out_nc_strides [[buffer(10)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_group [[simdgroup_index_in_threadgroup]]) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + // Compute row offsets using elem_to_loc for non-contiguous arrays + int row = tid.y; + auto in_offset = elem_to_loc(row, nc_shape, in_nc_strides, nc_dim); + auto out_offset = elem_to_loc(row, nc_shape, out_nc_strides, nc_dim); + + const device ValT* row_input = input + in_offset; + device OutT* row_output = output + out_offset; + + // Shared memory + threadgroup int shared_hist[RADIX_SIZE]; + threadgroup int shared_pivot_info[2]; + threadgroup int shared_counts[2]; + threadgroup int shared_output_counters[3]; + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + // Multi-pass to find pivot + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + // Clear histogram + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Build histogram + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find target bin + if (lid.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_pivot_info[0] = target_bin; + shared_pivot_info[1] = k; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int target_bin = shared_pivot_info[0]; + k = shared_pivot_info[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Initialize counters for partition size counting + if (lid.x == 0) { + 
shared_counts[0] = 0; // less_count + shared_counts[1] = 0; // equal_count + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Count partition sizes with SIMD reduction + int local_less = 0, local_equal = 0; + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if (key < target_prefix) + local_less++; + else if (key == target_prefix) + local_equal++; + } + + // SIMD reduction + local_less = simd_sum(local_less); + local_equal = simd_sum(local_equal); + + // Aggregate across SIMD groups (only first lane of each SIMD group) + if (simd_lane == 0) { + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_counts[0], + local_less, + memory_order_relaxed); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_counts[1], + local_equal, + memory_order_relaxed); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read final counts - all threads read the same values + int less_count = shared_counts[0]; + int equal_count = shared_counts[1]; + + // Initialize output counters + if (lid.x == 0) { + shared_output_counters[0] = 0; // less output counter + shared_output_counters[1] = 0; // equal output counter + shared_output_counters[2] = 0; // greater output counter + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Output partitioned elements + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < target_prefix) { + pos = atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[0], + 1, + memory_order_relaxed); + } else if (key == target_prefix) { + pos = less_count + + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[1], + 1, + memory_order_relaxed); + } else { + pos = less_count + equal_count + + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_output_counters[2], + 1, + memory_order_relaxed); + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Single-pass Radix Select for small arrays (fits in threadgroup memory) +/////////////////////////////////////////////////////////////////////////////// + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + short BLOCK_THREADS, + short ITEMS_PER_THREAD> +struct RadixSelectSmall { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + static constexpr constant short TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; + + static METAL_FUNC void partition( + const device ValT* input, + device OutT* output, + int kth, + int size_sorted_axis, + int in_stride_sorted_axis, + int out_stride_sorted_axis, + int in_stride_segment_axis, + int out_stride_segment_axis, + threadgroup UnsignedT* shared_keys, + threadgroup uint32_t* shared_idxs, + threadgroup int* shared_hist, + threadgroup int* shared_count, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + int row = tid.y; + const device ValT* row_input = input + row * in_stride_segment_axis; + device OutT* row_output = output + row * out_stride_segment_axis; + + int n = min(size_sorted_axis, int(TILE_SIZE)); + + // Load data into threadgroup memory + for (int i = lid.x; i < TILE_SIZE; i += 
BLOCK_THREADS) { + if (i < n) { + ValT val = row_input[i * in_stride_sorted_axis]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + shared_keys[i] = key; + shared_idxs[i] = i; + } else { + shared_keys[i] = ~UnsignedT(0); + shared_idxs[i] = i; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Radix select + int k = kth + 1; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + for (int i = lid.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_hist[digit], + 1, + memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (lid.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_count[0] = target_bin; + shared_count[1] = k; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int target_bin = shared_count[0]; + k = shared_count[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Output partitioned array + if (lid.x == 0) { + shared_count[0] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Phase 1: less than pivot + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key < target_prefix) { + int pos = atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_count[0], 1, memory_order_relaxed); + if (ARG_PARTITION) { + row_output[pos * out_stride_sorted_axis] = shared_idxs[i]; + } else { + row_output[pos * out_stride_sorted_axis] = + row_input[shared_idxs[i] * in_stride_sorted_axis]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Phase 2: equal to pivot + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key == target_prefix) { + int pos = atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_count[0], 1, memory_order_relaxed); + if (ARG_PARTITION) { + row_output[pos * out_stride_sorted_axis] = shared_idxs[i]; + } else { + row_output[pos * out_stride_sorted_axis] = + row_input[shared_idxs[i] * in_stride_sorted_axis]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Phase 3: greater than pivot + for (int i = lid.x; i < n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key > target_prefix) { + int pos = atomic_fetch_add_explicit( + (threadgroup atomic_int*)&shared_count[0], 1, memory_order_relaxed); + if (ARG_PARTITION) { + row_output[pos * out_stride_sorted_axis] = shared_idxs[i]; + } else { + row_output[pos * out_stride_sorted_axis] = + row_input[shared_idxs[i] * in_stride_sorted_axis]; + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel entry points 
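+//
+// Threadgroup memory footprint of the small-path entry points, for reference
+// (illustrative arithmetic for the configurations instantiated in
+// radix_select.metal): with BLOCK_THREADS = 256 and ITEMS_PER_THREAD = 8,
+// TILE_SIZE = 2048, so a 32-bit key type needs
+//   shared_keys: 2048 * 4 B = 8 KB
+//   shared_idxs: 2048 * 4 B = 8 KB
+//   shared_hist:  256 * 4 B = 1 KB
+// or roughly 17 KB of threadgroup memory; the 64-bit instantiations use
+// BLOCK_THREADS = 128 (TILE_SIZE = 1024, 8-byte keys), keeping the footprint
+// well inside the ~32 KB of threadgroup memory available on Apple GPUs.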
+/////////////////////////////////////////////////////////////////////////////// + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + short BLOCK_THREADS, + short ITEMS_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_select_partition( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + const constant int& kth [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& in_stride_sorted_axis [[buffer(4)]], + const constant int& out_stride_sorted_axis [[buffer(5)]], + const constant int& in_stride_segment_axis [[buffer(6)]], + const constant int& out_stride_segment_axis [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using SelectKernel = RadixSelectSmall< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + using UnsignedT = typename SelectKernel::UnsignedT; + + threadgroup UnsignedT shared_keys[SelectKernel::TILE_SIZE]; + threadgroup uint32_t shared_idxs[SelectKernel::TILE_SIZE]; + threadgroup int shared_hist[RADIX_SIZE]; + threadgroup int shared_count[2]; + + SelectKernel::partition( + input, + output, + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + shared_keys, + shared_idxs, + shared_hist, + shared_count, + tid, + lid); +} + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + short BLOCK_THREADS, + short ITEMS_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +radix_select_partition_nc( + const device ValT* input [[buffer(0)]], + device OutT* output [[buffer(1)]], + const constant int& kth [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& in_stride_sorted_axis [[buffer(4)]], + const constant int& out_stride_sorted_axis [[buffer(5)]], + const constant int& nc_dim [[buffer(6)]], + const constant int* nc_shape [[buffer(7)]], + const constant int64_t* in_nc_strides [[buffer(8)]], + const constant int64_t* out_nc_strides [[buffer(9)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using SelectKernel = RadixSelectSmall< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + using UnsignedT = typename SelectKernel::UnsignedT; + + auto in_offset = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); + auto out_offset = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); + + threadgroup UnsignedT shared_keys[SelectKernel::TILE_SIZE]; + threadgroup uint32_t shared_idxs[SelectKernel::TILE_SIZE]; + threadgroup int shared_hist[RADIX_SIZE]; + threadgroup int shared_count[2]; + + SelectKernel::partition( + input + in_offset, + output + out_offset, + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + 0, + 0, + shared_keys, + shared_idxs, + shared_hist, + shared_count, + tid, + lid); +} diff --git a/mlx/backend/metal/kernels/radix_select.metal b/mlx/backend/metal/kernels/radix_select.metal new file mode 100644 index 0000000000..cd657837de --- /dev/null +++ b/mlx/backend/metal/kernels/radix_select.metal @@ -0,0 +1,167 @@ +// Copyright © 2025 Apple Inc. 
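+
+// Naming convention produced by the instantiation macros below, shown as one
+// concrete expansion for float32 with bn = 256, tn = 8:
+//   carg_radix_select_float32_uint32_bn256_tn8   (contiguous, argpartition)
+//   ncarg_radix_select_float32_uint32_bn256_tn8  (non-contiguous, argpartition)
+//   c_radix_select_float32_float32_bn256_tn8     (contiguous, partition)
+//   nc_radix_select_float32_float32_bn256_tn8    (non-contiguous, partition)
+// These are the strings that the host code in sort.cpp concatenates when it
+// looks up a pipeline.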
+ +#include + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/radix_select.h" + +/////////////////////////////////////////////////////////////////////////////// +// Single-pass Radix Select Kernel Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_radix_select(name, itname, itype, otname, otype, arg_part, bn, tn) \ + instantiate_kernel( \ + "c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \ + radix_select_partition, \ + itype, otype, arg_part, bn, tn) \ + instantiate_kernel( \ + "nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \ + radix_select_partition_nc, \ + itype, otype, arg_part, bn, tn) + +#define instantiate_radix_select_arg(itname, itype, bn, tn) \ + instantiate_radix_select(arg_radix_select, itname, itype, uint32, uint32_t, true, bn, tn) + +#define instantiate_radix_select_val(itname, itype, bn, tn) \ + instantiate_radix_select(_radix_select, itname, itype, itname, itype, false, bn, tn) + +#define instantiate_radix_select_tn(itname, itype, bn) \ + instantiate_radix_select_arg(itname, itype, bn, 8) \ + instantiate_radix_select_val(itname, itype, bn, 8) + +#define instantiate_radix_select_bn(itname, itype) \ + instantiate_radix_select_tn(itname, itype, 256) + +instantiate_radix_select_bn(uint8, uint8_t) +instantiate_radix_select_bn(uint16, uint16_t) +instantiate_radix_select_bn(uint32, uint32_t) +instantiate_radix_select_bn(int8, int8_t) +instantiate_radix_select_bn(int16, int16_t) +instantiate_radix_select_bn(int32, int32_t) +instantiate_radix_select_bn(float16, half) +instantiate_radix_select_bn(float32, float) +instantiate_radix_select_bn(bfloat16, bfloat16_t) + +#define instantiate_radix_select_long(itname, itype) \ + instantiate_radix_select_tn(itname, itype, 128) + +instantiate_radix_select_long(uint64, uint64_t) +instantiate_radix_select_long(int64, int64_t) + +/////////////////////////////////////////////////////////////////////////////// +// Large Array Streaming Radix Select Kernel Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_radix_large_streaming(itname, itype, otname, otype, arg_part, bn) \ + instantiate_kernel( \ + "radix_select_large_" #itname "_" #otname "_" #arg_part "_bn" #bn, \ + radix_select_large_streaming, \ + itype, otype, arg_part, bn) + +#define instantiate_radix_large_streaming_all(itname, itype, bn) \ + instantiate_radix_large_streaming(itname, itype, uint32, uint32_t, true, bn) \ + instantiate_radix_large_streaming(itname, itype, itname, itype, false, bn) + +instantiate_radix_large_streaming_all(uint8, uint8_t, 256) +instantiate_radix_large_streaming_all(uint16, uint16_t, 256) +instantiate_radix_large_streaming_all(uint32, uint32_t, 256) +instantiate_radix_large_streaming_all(int8, int8_t, 256) +instantiate_radix_large_streaming_all(int16, int16_t, 256) +instantiate_radix_large_streaming_all(int32, int32_t, 256) +instantiate_radix_large_streaming_all(float16, half, 256) +instantiate_radix_large_streaming_all(float32, float, 256) +instantiate_radix_large_streaming_all(bfloat16, bfloat16_t, 256) +instantiate_radix_large_streaming_all(uint64, uint64_t, 128) +instantiate_radix_large_streaming_all(int64, int64_t, 128) + +/////////////////////////////////////////////////////////////////////////////// +// Large Array Non-Contiguous Streaming Radix Select Kernel Instantiations +/////////////////////////////////////////////////////////////////////////////// 
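+
+// Example expansion, for reference:
+// instantiate_radix_large_streaming_nc_all(float32, float, 256) registers
+//   radix_select_large_nc_float32_uint32_true_bn256    (argpartition)
+//   radix_select_large_nc_float32_float32_false_bn256  (value partition)
+// i.e. the boolean ARG_PARTITION template argument is stringized directly
+// into the kernel name.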
+ +#define instantiate_radix_large_streaming_nc(itname, itype, otname, otype, arg_part, bn) \ + instantiate_kernel( \ + "radix_select_large_nc_" #itname "_" #otname "_" #arg_part "_bn" #bn, \ + radix_select_large_streaming_nc, \ + itype, otype, arg_part, bn) + +#define instantiate_radix_large_streaming_nc_all(itname, itype, bn) \ + instantiate_radix_large_streaming_nc(itname, itype, uint32, uint32_t, true, bn) \ + instantiate_radix_large_streaming_nc(itname, itype, itname, itype, false, bn) + +instantiate_radix_large_streaming_nc_all(uint8, uint8_t, 256) +instantiate_radix_large_streaming_nc_all(uint16, uint16_t, 256) +instantiate_radix_large_streaming_nc_all(uint32, uint32_t, 256) +instantiate_radix_large_streaming_nc_all(int8, int8_t, 256) +instantiate_radix_large_streaming_nc_all(int16, int16_t, 256) +instantiate_radix_large_streaming_nc_all(int32, int32_t, 256) +instantiate_radix_large_streaming_nc_all(float16, half, 256) +instantiate_radix_large_streaming_nc_all(float32, float, 256) +instantiate_radix_large_streaming_nc_all(bfloat16, bfloat16_t, 256) +instantiate_radix_large_streaming_nc_all(uint64, uint64_t, 128) +instantiate_radix_large_streaming_nc_all(int64, int64_t, 128) + +/////////////////////////////////////////////////////////////////////////////// +// Multi-pass Radix Select Kernel Instantiations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_radix_histogram(itname, itype, bn) \ + instantiate_kernel("radix_histogram_" #itname "_bn" #bn, radix_histogram_kernel, itype, bn) + +#define instantiate_radix_histogram_all(itname, itype) \ + instantiate_radix_histogram(itname, itype, 256) + +instantiate_radix_histogram_all(uint8, uint8_t) +instantiate_radix_histogram_all(uint16, uint16_t) +instantiate_radix_histogram_all(uint32, uint32_t) +instantiate_radix_histogram_all(int8, int8_t) +instantiate_radix_histogram_all(int16, int16_t) +instantiate_radix_histogram_all(int32, int32_t) +instantiate_radix_histogram_all(float16, half) +instantiate_radix_histogram_all(float32, float) +instantiate_radix_histogram_all(bfloat16, bfloat16_t) +instantiate_radix_histogram_all(uint64, uint64_t) +instantiate_radix_histogram_all(int64, int64_t) + +#define instantiate_radix_find_bin(itname, itype) \ + instantiate_kernel("radix_find_bin_" #itname, radix_find_bin_kernel, itype) + +instantiate_radix_find_bin(uint8, uint8_t) +instantiate_radix_find_bin(uint16, uint16_t) +instantiate_radix_find_bin(uint32, uint32_t) +instantiate_radix_find_bin(int8, int8_t) +instantiate_radix_find_bin(int16, int16_t) +instantiate_radix_find_bin(int32, int32_t) +instantiate_radix_find_bin(float16, half) +instantiate_radix_find_bin(float32, float) +instantiate_radix_find_bin(bfloat16, bfloat16_t) +instantiate_radix_find_bin(uint64, uint64_t) +instantiate_radix_find_bin(int64, int64_t) + +#define instantiate_partition_output(itname, itype, otname, otype, arg_part, bn) \ + instantiate_kernel( \ + "radix_partition_output_" #itname "_" #otname "_" #arg_part "_bn" #bn, \ + radix_partition_output_kernel, itype, otype, arg_part, bn) \ + instantiate_kernel( \ + "radix_partition_equal_" #itname "_" #otname "_" #arg_part "_bn" #bn, \ + radix_partition_equal_kernel, itype, otype, arg_part, bn) \ + instantiate_kernel( \ + "radix_partition_greater_" #itname "_" #otname "_" #arg_part "_bn" #bn, \ + radix_partition_greater_kernel, itype, otype, arg_part, bn) + +#define instantiate_partition_output_all(itname, itype) \ + instantiate_partition_output(itname, itype, uint32, uint32_t, 
true, 256) \ + instantiate_partition_output(itname, itype, itname, itype, false, 256) + +instantiate_partition_output_all(uint8, uint8_t) +instantiate_partition_output_all(uint16, uint16_t) +instantiate_partition_output_all(uint32, uint32_t) +instantiate_partition_output_all(int8, int8_t) +instantiate_partition_output_all(int16, int16_t) +instantiate_partition_output_all(int32, int32_t) +instantiate_partition_output_all(float16, half) +instantiate_partition_output_all(float32, float) +instantiate_partition_output_all(bfloat16, bfloat16_t) +instantiate_partition_output_all(uint64, uint64_t) +instantiate_partition_output_all(int64, int64_t) + // clang-format on diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 533b1927c2..3f15109ab0 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -110,6 +110,16 @@ MTL::ComputePipelineState* get_mb_sort_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_radix_select_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&, + int, + int) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_reduce_init_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 3c84022f2c..4cfbbd917f 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -52,15 +52,20 @@ void single_block_sort( contiguous &= check_strides(out, out_stride_sorted_axis); // Prepare kernel name - std::ostringstream kname; - kname << (contiguous ? "c" : "nc"); - if (argsort) { - kname << "arg"; - } - - kname << "_block_sort_" << type_to_name(in) << "_" << type_to_name(out) - << "_bn" << bn << "_tn" << tn; - auto kernel = get_sort_kernel(d, kname.str(), in, out, bn, tn); + std::string kname; + concatenate( + kname, + contiguous ? "c" : "nc", + argsort ? 
"arg" : "", + "_block_sort_", + type_to_name(in), + "_", + type_to_name(out), + "_bn", + bn, + "_tn", + tn); + auto kernel = get_sort_kernel(d, kname, in, out, bn, tn); // Prepare command encoder auto& compute_encoder = d.get_command_encoder(s.index); @@ -164,11 +169,18 @@ void multi_block_sort( // Do blockwise sort { - std::ostringstream kname; - kname << "sort_mbsort_" << type_to_name(dev_vals_0) << "_" - << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn; - auto kernel = - get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); + std::string kname; + concatenate( + kname, + "sort_mbsort_", + type_to_name(dev_vals_0), + "_", + type_to_name(dev_idxs_0), + "_bn", + std::to_string(bn), + "_tn", + std::to_string(tn)); + auto kernel = get_mb_sort_kernel(d, kname, dev_vals_0, dev_idxs_0, bn, tn); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); @@ -204,12 +216,20 @@ void multi_block_sort( // Do partition { - std::ostringstream kname; - kname << "partition_mbsort_" << type_to_name(dev_vals_in) << "_" - << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; + std::string kname; + concatenate( + kname, + "partition_mbsort_", + type_to_name(dev_vals_in), + "_", + type_to_name(dev_idxs_in), + "_bn", + std::to_string(bn), + "_tn", + std::to_string(tn)); auto kernel = - get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); + get_mb_sort_kernel(d, kname, dev_vals_0, dev_idxs_0, bn, tn); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_output_array(block_partitions, 0); @@ -227,12 +247,20 @@ void multi_block_sort( // Do merge { - std::ostringstream kname; - kname << "merge_mbsort_" << type_to_name(dev_vals_in) << "_" - << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; + std::string kname; + concatenate( + kname, + "merge_mbsort_", + type_to_name(dev_vals_in), + "_", + type_to_name(dev_idxs_in), + "_bn", + std::to_string(bn), + "_tn", + std::to_string(tn)); auto kernel = - get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn); + get_mb_sort_kernel(d, kname, dev_vals_0, dev_idxs_0, bn, tn); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(block_partitions, 0); @@ -313,6 +341,341 @@ void gpu_merge_sort( } } +void gpu_radix_partition_small( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition, + int n_rows, + int size_sorted_axis, + int in_stride_sorted_axis, + int out_stride_sorted_axis, + bool contiguous, + const Shape& nc_shape, + const Strides& in_nc_str, + const Strides& out_nc_str) { + constexpr int bn = 256; + constexpr int tn = 8; + + std::string kname; + concatenate( + kname, + kname, + contiguous ? "c" : "nc", + arg_partition ? 
"arg_" : "_", + "radix_select_", + type_to_name(in), + "_", + type_to_name(out), + "_bn", + std::to_string(bn), + "_tn", + std::to_string(tn)); + + auto kernel = get_radix_select_kernel(d, kname, in, out, bn, tn); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(kth, 2); + compute_encoder.set_bytes(size_sorted_axis, 3); + compute_encoder.set_bytes(in_stride_sorted_axis, 4); + compute_encoder.set_bytes(out_stride_sorted_axis, 5); + + if (contiguous) { + int in_stride_segment_axis = INT32_MAX; + int out_stride_segment_axis = INT32_MAX; + for (size_t i = 0; i < in_nc_str.size(); i++) { + if (nc_shape[i] == 1) + continue; + in_stride_segment_axis = + std::min(in_stride_segment_axis, static_cast(in_nc_str[i])); + out_stride_segment_axis = + std::min(out_stride_segment_axis, static_cast(out_nc_str[i])); + } + compute_encoder.set_bytes(in_stride_segment_axis, 6); + compute_encoder.set_bytes(out_stride_segment_axis, 7); + } else { + int nc_dim = nc_shape.size(); + compute_encoder.set_bytes(nc_dim, 6); + if (nc_shape.empty()) { + int shape = 0; + int64_t stride = 0; + compute_encoder.set_bytes(shape, 7); + compute_encoder.set_bytes(stride, 8); + compute_encoder.set_bytes(stride, 9); + } else { + compute_encoder.set_vector_bytes(nc_shape, 7); + compute_encoder.set_vector_bytes(in_nc_str, 8); + compute_encoder.set_vector_bytes(out_nc_str, 9); + } + } + + MTL::Size group_dims = MTL::Size(bn, 1, 1); + MTL::Size grid_dims = MTL::Size(1, n_rows, 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gpu_radix_partition_large( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition, + int n_rows, + int size_sorted_axis, + int in_stride_sorted_axis, + int out_stride_sorted_axis, + int in_stride_segment_axis, + int out_stride_segment_axis) { + constexpr int bn = 256; + + // Allocate counter buffer for streaming kernel + // Layout: [less, equal, greater, sync] per row + array counters({n_rows, 4}, int32, nullptr, {}); + counters.set_data(allocator::malloc(counters.nbytes())); + + // Zero-initialize counters + std::memset(counters.data(), 0, counters.nbytes()); + + std::vector temps = {counters}; + + auto& compute_encoder = d.get_command_encoder(s.index); + + // Use the streaming kernel that processes all passes in one dispatch + std::string kname; + concatenate( + kname, + "radix_select_large_", + type_to_name(in), + "_", + type_to_name(out), + "_", + arg_partition ? 
"true" : "false", + "_bn", + std::to_string(bn)); + + auto kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_output_array(counters, 2); + compute_encoder.set_bytes(size_sorted_axis, 3); + compute_encoder.set_bytes(kth, 4); + compute_encoder.set_bytes(in_stride_sorted_axis, 5); + compute_encoder.set_bytes(out_stride_sorted_axis, 6); + compute_encoder.set_bytes(in_stride_segment_axis, 7); + compute_encoder.set_bytes(out_stride_segment_axis, 8); + + // Single threadgroup per row for streaming approach + MTL::Size group_dims = MTL::Size(bn, 1, 1); + MTL::Size grid_dims = MTL::Size(1, n_rows, 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(temps), s.index); +} + +void gpu_radix_partition_large_nc( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition, + int n_rows, + int size_sorted_axis, + int in_stride_sorted_axis, + int out_stride_sorted_axis, + const Shape& nc_shape, + const Strides& in_nc_str, + const Strides& out_nc_str) { + constexpr int bn = 256; + + // Allocate counter buffer for streaming kernel + array counters({n_rows, 4}, int32, nullptr, {}); + counters.set_data(allocator::malloc(counters.nbytes())); + + // Zero-initialize counters + std::memset(counters.data(), 0, counters.nbytes()); + + std::vector temps = {counters}; + + auto& compute_encoder = d.get_command_encoder(s.index); + + // Use the non-contiguous streaming kernel + std::string kname; + concatenate( + kname, + "radix_select_large_nc_", + type_to_name(in), + "_", + type_to_name(out), + "_", + arg_partition ? "true" : "false", + "_bn", + bn); + + auto kernel = d.get_kernel(kname); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_output_array(counters, 2); + compute_encoder.set_bytes(size_sorted_axis, 3); + compute_encoder.set_bytes(kth, 4); + compute_encoder.set_bytes(in_stride_sorted_axis, 5); + compute_encoder.set_bytes(out_stride_sorted_axis, 6); + + int nc_dim = nc_shape.size(); + compute_encoder.set_bytes(nc_dim, 7); + if (nc_shape.empty()) { + int shape = 0; + int64_t stride = 0; + compute_encoder.set_bytes(shape, 8); + compute_encoder.set_bytes(stride, 9); + compute_encoder.set_bytes(stride, 10); + } else { + compute_encoder.set_vector_bytes(nc_shape, 8); + compute_encoder.set_vector_bytes(in_nc_str, 9); + compute_encoder.set_vector_bytes(out_nc_str, 10); + } + + // Single threadgroup per row for streaming approach + MTL::Size group_dims = MTL::Size(bn, 1, 1); + MTL::Size grid_dims = MTL::Size(1, n_rows, 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + d.add_temporaries(std::move(temps), s.index); +} + +void gpu_radix_partition( + const Stream& s, + metal::Device& d, + const array& in, + array& out, + int axis_, + int kth, + bool arg_partition) { + int axis = axis_ < 0 ? 
+  int size_sorted_axis = in.shape(axis);
+  int n_rows = in.size() / size_sorted_axis;
+
+  // Normalize kth
+  if (kth < 0) {
+    kth += size_sorted_axis;
+  }
+
+  auto in_nc_str = in.strides();
+  in_nc_str.erase(in_nc_str.begin() + axis);
+
+  auto out_nc_str = out.strides();
+  out_nc_str.erase(out_nc_str.begin() + axis);
+
+  auto nc_shape = in.shape();
+  nc_shape.erase(nc_shape.begin() + axis);
+
+  int in_stride_sorted_axis = in.strides()[axis];
+  int out_stride_sorted_axis = out.strides()[axis];
+
+  // Check if we can use the contiguous kernel
+  bool contiguous = in.flags().contiguous;
+  auto check_strides = [](const array& x, int sort_stride) {
+    int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
+    int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
+    return sort_stride == min_stride || sort_stride == max_stride;
+  };
+  contiguous &= check_strides(in, in_stride_sorted_axis);
+  contiguous &= check_strides(out, out_stride_sorted_axis);
+
+  // Radix select configuration
+  constexpr int bn = 256;
+  constexpr int tn = 8;
+  constexpr int TILE_SIZE = bn * tn; // 2048
+
+  // Use single-pass kernel for small arrays that fit in threadgroup memory
+  if (size_sorted_axis <= TILE_SIZE) {
+    gpu_radix_partition_small(
+        s,
+        d,
+        in,
+        out,
+        axis,
+        kth,
+        arg_partition,
+        n_rows,
+        size_sorted_axis,
+        in_stride_sorted_axis,
+        out_stride_sorted_axis,
+        contiguous,
+        nc_shape,
+        in_nc_str,
+        out_nc_str);
+    return;
+  }
+
+  // For larger arrays, use the streaming radix select kernel
+  // This performs all radix passes in a single kernel dispatch
+  if (contiguous) {
+    int in_stride_segment_axis = size_sorted_axis;
+    int out_stride_segment_axis = size_sorted_axis;
+
+    // For contiguous arrays, the segment stride is the product of all
+    // dimensions after the sorted axis (or the sorted axis size for the last
+    // axis)
+    if (!in_nc_str.empty()) {
+      // Find the stride that separates rows
+      for (size_t i = 0; i < in_nc_str.size(); i++) {
+        if (nc_shape[i] == 1)
+          continue;
+        in_stride_segment_axis =
+            std::min(in_stride_segment_axis, static_cast<int>(in_nc_str[i]));
+        out_stride_segment_axis =
+            std::min(out_stride_segment_axis, static_cast<int>(out_nc_str[i]));
+      }
+    }
+
+    gpu_radix_partition_large(
+        s,
+        d,
+        in,
+        out,
+        axis,
+        kth,
+        arg_partition,
+        n_rows,
+        size_sorted_axis,
+        in_stride_sorted_axis,
+        out_stride_sorted_axis,
+        in_stride_segment_axis,
+        out_stride_segment_axis);
+  } else {
+    // Use non-contiguous kernel with elem_to_loc indexing
+    gpu_radix_partition_large_nc(
+        s,
+        d,
+        in,
+        out,
+        axis,
+        kth,
+        arg_partition,
+        n_rows,
+        size_sorted_axis,
+        in_stride_sorted_axis,
+        out_stride_sorted_axis,
+        nc_shape,
+        in_nc_str,
+        out_nc_str);
+  }
+}
+
 } // namespace
 
 void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +703,6 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
-  // We direct arg partition to sort for now
   assert(inputs.size() == 1);
 
   out.set_data(allocator::malloc(out.nbytes()));
@@ -349,11 +711,10 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
 
-  gpu_merge_sort(s, d, in, out, axis_, true);
+  gpu_radix_partition(s, d, in, out, axis_, kth_, true);
 }
 
 void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
-  // We direct partition to sort for now
   assert(inputs.size() == 1);
 
  out.set_data(allocator::malloc(out.nbytes()));
@@ -362,7 +723,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
 
-  gpu_merge_sort(s, d, in, out, axis_, false);
+  gpu_radix_partition(s, d, in, out, axis_, kth_, false);
 }
 
 } // namespace mlx::core
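Note (not part of the diff): with ArgPartition::eval_gpu and Partition::eval_gpu now routed to gpu_radix_partition, both dispatch paths can be exercised from the public C++ ops. The following is only a minimal sketch assuming the existing mlx::core partition/argpartition ops and float32 inputs; the shapes are arbitrary and chosen so that one case stays within the 2048-element tile of the single-pass kernel and the other takes the streaming kernel.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Rows of 1024 elements fit in one tile (<= bn * tn = 2048) and take the
  // single-pass kernel; rows of 8192 elements take the streaming kernel.
  auto small = random::uniform({16, 1024});
  auto large = random::uniform({16, 8192});

  auto p_small = partition(small, 100, 1);    // values partitioned around index 100
  auto p_large = partition(large, 100, 1);
  auto i_large = argpartition(large, 100, 1); // uint32 indices instead of values

  eval({p_small, p_large, i_large});
  return 0;
}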
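Note (not part of the diff): the radix_histogram_kernel / radix_find_bin_kernel pair instantiated above carries the per-pass work of the multi-pass selection, and the streaming kernels fold all passes into one dispatch using the per-row [less, equal, greater, sync] counters. Below is a hedged host-side sketch of one such pass for unsigned 32-bit keys; radix_select_pass_ref is a hypothetical helper written only to illustrate how a per-digit histogram plus a cumulative bin walk pins down the digit containing the kth element, not the actual Metal implementation.

// CPU reference sketch of a single radix-select pass over 8-bit digits.
#include <cstdint>
#include <vector>

// Returns the prefix extended by the digit selected at `shift`, and updates
// `kth` to the rank that remains to be resolved in the lower digits.
inline uint32_t radix_select_pass_ref(
    const std::vector<uint32_t>& keys,
    int& kth,
    int shift,        // bit position of the digit examined in this pass
    uint32_t prefix,  // digits already fixed by earlier passes
    uint32_t mask) {  // mask covering the already-fixed digits
  int hist[256] = {0};
  for (auto k : keys) {
    if ((k & mask) == prefix) { // only keys matching the current prefix count
      hist[(k >> shift) & 0xFF]++;
    }
  }
  for (int d = 0; d < 256; ++d) { // walk bins until the cumulative count covers kth
    if (kth < hist[d]) {
      return prefix | (uint32_t(d) << shift);
    }
    kth -= hist[d]; // every key in bin d ranks below kth; skip past them
  }
  return prefix; // not reached for a valid kth
}

// A full selection would run this from the most significant digit downward,
// extending prefix and mask after each pass; the partition-output, equal, and
// greater kernels then scatter values (or indices, for argpartition) on either
// side of the selected pivot.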