From 98eedd1e63d94ded93a910676f83628bb92319a1 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 16 Jan 2026 00:47:59 +0100 Subject: [PATCH 01/34] adding tensor scale [wip] --- mlx/backend/cuda/quantized/fp_quantize.cu | 50 +++++++++++++++++------ mlx/backend/cuda/quantized/quantized.cpp | 18 +++++++- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 9c96491358..7f693e920d 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -14,6 +14,9 @@ #include #include +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + namespace mlx::core { namespace cu { @@ -31,7 +34,17 @@ struct Dequantize { namespace cg = cooperative_groups; template -__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) { +__global__ void fp_quantize( + T* w, + uint8_t* out, + uint8_t* scales, + size_t size, + float* tensor_amax) { + // NVFP4 conversion: + // Global encode scale: (448 × 6) / *tensor_amax + // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 + // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b + const float scale_enc = (F8E4M3_MAX * F4E2M1_MAX) / *tensor_amax; using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use @@ -50,7 +63,7 @@ __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) { } auto w_tile = load_vector(w, thread_idx); - float scale = 0.0f; + float scale_dec_b = 0.0f; Tx2 amax_2x = Tx2{0.0f, 0.0f}; @@ -60,17 +73,18 @@ __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) { abs_max_x2(amax_2x, amax_2x, pair); } - scale = static_cast( + scale_dec_b = static_cast( max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y)))); - scale /= bits == 4 ? 6.0f : 448.0f; + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= use_mx_scale ? 1.0f : scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t; - auto s = ScaleType(scale); + auto s = ScaleType(scale_dec_b); uint8_t q_scale = s.__x; - scale = float(s); + float scale_enc_b = scale_enc / float(scale_dec_b); scales[thread_idx] = q_scale; constexpr int elem_per_byte = bits == 8 ? 
1 : 2; @@ -81,11 +95,11 @@ __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) { Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); if constexpr (bits == 8) { uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized[i * 4]) = quantized_val; } else { uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized[i * 2]) = quantized_val; } } @@ -93,8 +107,12 @@ __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) { } template -__global__ void -fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { +__global__ void fp_dequantize( + const uint8_t* w, + const uint8_t* scales, + T* out, + size_t size, + float* tensor_amax) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); @@ -105,6 +123,8 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; constexpr int pack_factor = bits == 8 ? 1 : 2; + const float inv_scale_enc = + use_mx_scale ? 1.0f : (*tensor_amax) / (F8E4M3_MAX * F4E2M1_MAX); size_t offset = tidx + grid_dim_x * size_t(tidy); size_t oindex = offset * pack_factor; @@ -115,7 +135,7 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { size_t gindex = oindex / group_size; using ScaleType = std::conditional_t; - auto scale = float(((ScaleType*)(scales))[gindex]); + auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; out += oindex; @@ -138,6 +158,7 @@ void fp_quantize( const array& w, array& wq, array& scales, + const array& tensor_amax, int group_size, int bits, cu::CommandEncoder& enc, @@ -166,7 +187,8 @@ void fp_quantize( gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), - w.size()); + w.size(), + gpu_ptr(tensor_amax)); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); @@ -177,6 +199,7 @@ void fp_quantize( void fp_dequantize( const array& wq, const array& scales, + const array& tensor_amax, array& w, int group_size, int bits, @@ -212,7 +235,8 @@ void fp_dequantize( gpu_ptr(wq), gpu_ptr(scales), gpu_ptr(w), - w.size()); + w.size(), + gpu_ptr(tensor_amax)); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not dequantize to output with type float64."); diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index a185523edc..b17e17fdd5 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -65,7 +65,8 @@ void fast::Quantize::eval_gpu( auto biases = ensure_row_contiguous(inputs[2], enc, s); affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); } else { - fp_dequantize(wq, scales, w, group_size_, bits_, enc, s); + auto tensor_amax = inputs[2]; + fp_dequantize(wq, scales, tensor_amax, w, group_size_, bits_, enc, s); } } else { auto w = ensure_row_contiguous(inputs[0], enc, s); @@ -74,12 +75,25 @@ void fast::Quantize::eval_gpu( wq.set_data(cu::malloc_async(wq.nbytes(), enc)); scales.set_data(cu::malloc_async(scales.nbytes(), enc)); + if (mode_ == QuantizationMode::Affine) { auto& biases = outputs[2]; biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, 
enc, s); } else { - fp_quantize(w, wq, scales, group_size_, bits_, enc, s); + auto& tensor_amax = outputs[2]; + tensor_amax.set_data(cu::malloc_async(tensor_amax.nbytes(), enc)); + // here we will write launch amax kernel + all_reduce(enc, s, w, tensor_amax, MAX_OP); // compute amax + fp_quantize( + w, + wq, + scales, + tensor_amax, + group_size_, + bits_, + enc, + s); // pass amax to quantization kernel } } } From 689240451e0f738dbc44ab06eb2804628416a6a3 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 17 Jan 2026 00:14:20 +0100 Subject: [PATCH 02/34] added absmax reduction, changed fp_quanitze api [wip] --- mlx/backend/cuda/quantized/qqmm.cpp | 8 ++++++-- mlx/backend/cuda/quantized/quantized.cpp | 5 +++-- mlx/backend/cuda/reduce/all_reduce.cu | 12 ++++++++++-- mlx/backend/cuda/reduce/reduce.cuh | 2 ++ mlx/backend/cuda/reduce/reduce_ops.cuh | 24 ++++++++++++++++++++++++ mlx/ops.cpp | 9 ++++++--- mlx/primitives.h | 4 +++- 7 files changed, 54 insertions(+), 10 deletions(-) diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 21e41cf0ed..68d08fcf36 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/quantized/cublas_qqmm.h" #include "mlx/backend/cuda/quantized/qqmm_utils.h" #include "mlx/backend/cuda/quantized/quantized.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/primitives.h" @@ -152,11 +153,14 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); array scales_x( cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); - - fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s); + array tensor_amax_x({}, x.dtype(), nullptr, {}); + all_reduce(encoder, x, tensor_amax_x, Reduce::ReduceType::AbsMax); + fp_quantize( + x, x_q, scales_x, tensor_amax_x, group_size_, bits_, encoder, s); encoder.add_temporary(x_q); encoder.add_temporary(scales_x); + encoder.add_temporary(tensor_amax_x); return {x_q, scales_x}; }; auto [x_q, scale_x_pre] = quantize(inputs[0], encoder, s); diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index b17e17fdd5..1d530f8103 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" @@ -83,8 +84,8 @@ void fast::Quantize::eval_gpu( } else { auto& tensor_amax = outputs[2]; tensor_amax.set_data(cu::malloc_async(tensor_amax.nbytes(), enc)); - // here we will write launch amax kernel - all_reduce(enc, s, w, tensor_amax, MAX_OP); // compute amax + all_reduce( + enc, w, tensor_amax, Reduce::ReduceType::AbsMax); // compute amax fp_quantize( w, wq, diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 1126f4cc76..bafa5817f9 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -37,7 +37,11 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { for (; i + block.size() * N <= check; i += block.size() * N) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { - accs[0] = op(accs[0], cast_to(vals[j])); + if constexpr (cuda::std::is_same_v) { 
+ accs[0] = op(accs[0], abs(cast_to(vals[j]))); + } else { + accs[0] = op(accs[0], cast_to(vals[j])); + } } } @@ -45,7 +49,11 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlocked( block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { - accs[0] = op(accs[0], cast_to(vals[i])); + if constexpr (cuda::std::is_same_v) { + accs[0] = op(accs[0], abs(cast_to(vals[i]))); + } else { + accs[0] = op(accs[0], cast_to(vals[i])); + } } } diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 02e495594a..947e8b36dc 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -35,6 +35,8 @@ void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Min) { f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::AbsMax) { + f(type_identity{}); } else { throw std::invalid_argument("Unknown reduce type."); } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 7f8cad0c4e..91fbe9460a 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -114,6 +114,18 @@ struct Max { } }; +struct AbsMax { + template + __device__ __forceinline__ T operator()(T a, T b) { + return a > b ? a : b; + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + // Traits to get the result type of reduce op. template struct ReduceResult; @@ -154,6 +166,11 @@ struct ReduceResult { using type = T; }; +template +struct ReduceResult { + using type = T; +}; + // Traits to get the init value of reduce op. template struct ReduceInit; @@ -208,4 +225,11 @@ struct ReduceInit { } }; +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return T(0); // abs values are >= 0 + } +}; + } // namespace mlx::core::cu diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4964322951..3b28ade4ba 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4597,9 +4597,10 @@ std::vector fp_quantize( wq_shape.back() = w.shape(-1) * bits / 32; auto sshape = w.shape(); sshape.back() = w.shape(-1) / group_size; + Shape tashape = {}; return array::make_arrays( - {std::move(wq_shape), std::move(sshape)}, - {uint32, uint8}, + {std::move(wq_shape), std::move(sshape), std::move(tashape)}, + {uint32, uint8, float32}, std::make_shared( s, fallback, group_size, bits, mode, false), {w}); @@ -4743,6 +4744,7 @@ array affine_dequantize( array fp_dequantize( const array& w, const array& scales, + const array& tensor_amax, int group_size, int bits, Dtype out_type, @@ -4848,7 +4850,7 @@ array fp_dequantize( out_type, std::make_shared( s, fallback, group_size, bits, mode, true), - {w, scales}); + {w, scales, tensor_amax}); } return fallback({w, scales})[0]; } @@ -4856,6 +4858,7 @@ array fp_dequantize( array dequantize( const array& w, const array& scales, + // const std::optional& tensor_amax /* = std::nullopt */, const std::optional& biases /* = std::nullopt */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, diff --git a/mlx/primitives.h b/mlx/primitives.h index c3ce00f92f..1e9b14da66 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1769,7 +1769,7 @@ class Reshape : public UnaryPrimitive { class Reduce : public UnaryPrimitive { public: - enum ReduceType { And, Or, Sum, Prod, Min, Max }; + enum ReduceType { And, Or, Sum, Prod, Min, Max, AbsMax }; explicit Reduce( 
Stream stream, @@ -1799,6 +1799,8 @@ class Reduce : public UnaryPrimitive { return "Min"; case Max: return "Max"; + case AbsMax: + return "AbsMax"; } return ""; } From 15d684b13bc66753b5636483b01f65e7180e1c1a Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 19 Jan 2026 00:27:32 +0100 Subject: [PATCH 03/34] refactoring --- mlx/backend/cuda/quantized/cublas_qqmm.h | 9 ++ mlx/backend/cuda/quantized/fp_quantize.cu | 25 ++-- mlx/backend/cuda/quantized/qqmm.cpp | 112 +++++++++++------- mlx/backend/cuda/quantized/qqmm_utils.cu | 31 +++++ mlx/backend/cuda/quantized/qqmm_utils.h | 7 ++ mlx/backend/cuda/quantized/quantized.cpp | 22 ++-- mlx/backend/cuda/quantized/quantized.h | 10 ++ mlx/backend/cuda/reduce/all_reduce.cu | 18 ++- mlx/ops.cpp | 138 ++++++++++++++++------ mlx/ops.h | 1 + mlx/primitives.cpp | 2 + tests/ops_tests.cpp | 2 +- 12 files changed, 277 insertions(+), 100 deletions(-) diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 0a710f6e10..ecce044afd 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -45,6 +45,15 @@ class CublasQQMM : public CublasMatmulBase { Dtype out_dtype, std::string quantization_mode); + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const array& alpha); + void run( cu::CommandEncoder& encoder, array& out, diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 7f693e920d..7f574d7fa4 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -39,12 +39,13 @@ __global__ void fp_quantize( uint8_t* out, uint8_t* scales, size_t size, - float* tensor_amax) { + float* tensor_amax = nullptr) { // NVFP4 conversion: // Global encode scale: (448 × 6) / *tensor_amax // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b - const float scale_enc = (F8E4M3_MAX * F4E2M1_MAX) / *tensor_amax; + const float scale_enc = + !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *tensor_amax : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use @@ -78,7 +79,7 @@ __global__ void fp_quantize( fabsf(static_cast(amax_2x.y)))); scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= use_mx_scale ? 1.0f : scale_enc; + scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t; @@ -112,7 +113,7 @@ __global__ void fp_dequantize( const uint8_t* scales, T* out, size_t size, - float* tensor_amax) { + float* tensor_amax = nullptr) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); @@ -158,12 +159,15 @@ void fp_quantize( const array& w, array& wq, array& scales, - const array& tensor_amax, + const std::optional& tensor_amax /* = std::nullopt */, int group_size, int bits, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); + if (tensor_amax.has_value()) { + enc.set_input_array(tensor_amax.value()); + } enc.set_output_array(wq); enc.set_output_array(scales); dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { @@ -188,7 +192,8 @@ void fp_quantize( gpu_ptr(wq), gpu_ptr(scales), w.size(), - gpu_ptr(tensor_amax)); + tensor_amax.has_value() ? 
gpu_ptr(tensor_amax.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); @@ -199,7 +204,7 @@ void fp_quantize( void fp_dequantize( const array& wq, const array& scales, - const array& tensor_amax, + const std::optional& tensor_amax /* = std::nullopt */, array& w, int group_size, int bits, @@ -215,6 +220,9 @@ void fp_dequantize( enc.set_input_array(wq); enc.set_input_array(scales); + if (tensor_amax.has_value()) { + enc.set_input_array(tensor_amax.value()); + } enc.set_output_array(w); dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) { using T = cuda_type_t; @@ -236,7 +244,8 @@ void fp_dequantize( gpu_ptr(scales), gpu_ptr(w), w.size(), - gpu_ptr(tensor_amax)); + tensor_amax.has_value() ? gpu_ptr(tensor_amax.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not dequantize to output with type float64."); diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 68d08fcf36..0bc8bd0826 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -4,7 +4,6 @@ #include "mlx/backend/cuda/quantized/cublas_qqmm.h" #include "mlx/backend/cuda/quantized/qqmm_utils.h" #include "mlx/backend/cuda/quantized/quantized.h" -#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/primitives.h" @@ -12,6 +11,8 @@ namespace mlx::core { +using QuantizedResult = std::tuple>; + namespace { inline array ensure_row_contiguous( @@ -71,6 +72,68 @@ array pad_and_swizzle_scales( return scale_tiled; } +QuantizedResult quantize_input( + const array& input, + cu::CommandEncoder& encoder, + const Stream& s, + QuantizationMode mode, + int bits, + int group_size) { + const array x = ensure_row_contiguous(input, encoder, s); + + auto build_shapes = [&](const array& x_in) { + auto xq_shape = x_in.shape(); + xq_shape.back() = x_in.shape(-1) * bits / 32; + + auto sshape = x_in.shape(); + const int64_t scales_inner = x_in.shape(-1) / group_size; + auto [pad_outer, pad_inner] = + get_padded_scale_dims(x_in.shape(-2), scales_inner); + sshape[x_in.ndim() - 2] = pad_outer; + sshape[x_in.ndim() - 1] = pad_inner; + sshape.back() = scales_inner; + + return std::tuple{ + std::move(xq_shape), + std::move(sshape), + pad_outer, + pad_inner, + }; + }; + + auto allocate_outputs = [&](const array& x_in) { + auto [xq_shape, sshape, pad_outer, pad_inner] = build_shapes(x_in); + + const int64_t xq_bytes = x_in.size() * bits / 8; + const int64_t batch = x_in.size() / (x_in.shape(-2) * x_in.shape(-1)); + const int64_t scales_bytes = batch * (pad_outer * pad_inner); + + array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); + array scales_x( + cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); + encoder.add_temporary(x_q); + encoder.add_temporary(scales_x); + + return std::pair{std::move(x_q), std::move(scales_x)}; + }; + + auto run_quant = [&](const array& x_in, std::optional tensor_amax) { + auto [x_q, scales_x] = allocate_outputs(x_in); + fp_quantize(x_in, x_q, scales_x, tensor_amax, group_size, bits, encoder, s); + return QuantizedResult{ + std::move(x_q), std::move(scales_x), std::move(tensor_amax)}; + }; + + if (mode == QuantizationMode::Nvfp4) { + array tensor_amax(cu::malloc_async(sizeof(float), encoder), {1}, float32); + encoder.add_temporary(tensor_amax); + all_reduce(encoder, x, tensor_amax, Reduce::ReduceType::AbsMax); + return run_quant(x, tensor_amax); + } + + return 
run_quant(x, std::nullopt); +} + void qqmm_impl( cu::CommandEncoder& encoder, int M, @@ -125,49 +188,16 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); } + auto quant_input_size = (mode_ == QuantizationMode::Nvfp4) ? 4 : 3; assert( - (inputs.size() == 3 && inputs[1].dtype() == uint32) || + (inputs.size() == quant_input_size && inputs[1].dtype() == uint32) || (inputs.size() == 2)); - auto quantize = [&](const array& input, - cu::CommandEncoder& encoder, - const Stream& s) -> std::pair { - const array x = ensure_row_contiguous(input, encoder, s); - - auto xq_shape = x.shape(); - xq_shape.back() = x.shape(-1) * bits_ / 32; - - auto sshape = x.shape(); - const int64_t scales_inner = x.shape(-1) / group_size_; - auto [pad_outer, pad_inner] = - get_padded_scale_dims(x.shape(-2), scales_inner); - sshape[x.ndim() - 2] = pad_outer; - sshape[x.ndim() - 1] = pad_inner; - sshape.back() = scales_inner; - - // Allocate outputs - const int64_t xq_bytes = x.size() * bits_ / 8; - const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1)); - const int64_t scales_bytes = batch * (pad_outer * pad_inner); - - array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); - array scales_x( - cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); - array tensor_amax_x({}, x.dtype(), nullptr, {}); - all_reduce(encoder, x, tensor_amax_x, Reduce::ReduceType::AbsMax); - fp_quantize( - x, x_q, scales_x, tensor_amax_x, group_size_, bits_, encoder, s); - - encoder.add_temporary(x_q); - encoder.add_temporary(scales_x); - encoder.add_temporary(tensor_amax_x); - return {x_q, scales_x}; - }; - auto [x_q, scale_x_pre] = quantize(inputs[0], encoder, s); - auto [w_q, scale_w_pre] = (inputs[1].dtype() != uint32) - ? quantize(inputs[1], encoder, s) - : std::make_pair(inputs[1], inputs[2]); - + auto [x_q, scale_x_pre, tensor_amax_x] = + quantize_input(inputs[0], encoder, s, mode_, bits_, group_size_); + auto [w_q, scale_w_pre, tensor_amax_w] = (inputs[1].dtype() != uint32) + ? 
quantize_input(inputs[1], encoder, s, mode_, bits_, group_size_) + : QuantizedResult{inputs[1], inputs[2], std::optional(inputs[3])}; out.set_data(cu::malloc_async(out.nbytes(), encoder)); auto out_dtype = out.dtype(); diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index c8764709b9..81404783a7 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -70,6 +70,19 @@ inline std::tuple get_swizzle_launch_args( namespace cu { +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + +__global__ void compute_qqmm_alpha( + float* alpha_out, + const float* tensor_amax_x, + const float* tensor_amax_w) { + // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 + constexpr float inv_scale_sq = + 1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX); + *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq; +} + __global__ void swizzle_scales( const uint8_t* scales_linear, uint8_t* scales_swizzled, @@ -224,4 +237,22 @@ void swizzle_scales( output_cols); } +void compute_qqmm_alpha( + array& alpha_out, + const array& tensor_amax_x, + const array& tensor_amax_w, + cu::CommandEncoder& enc) { + enc.set_input_array(tensor_amax_x); + enc.set_input_array(tensor_amax_w); + enc.set_output_array(alpha_out); + enc.add_kernel_node( + cu::compute_qqmm_alpha, + dim3(1), + dim3(1), + 0, + gpu_ptr(alpha_out), + gpu_ptr(tensor_amax_x), + gpu_ptr(tensor_amax_w)); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h index 0a9a78f70c..4e2d9c4739 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.h +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -27,4 +27,11 @@ void swizzle_scales( cu::CommandEncoder& enc, const Stream& s); +// Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 +void compute_qqmm_alpha( + array& alpha_out, + const array& tensor_amax_x, + const array& tensor_amax_w, + cu::CommandEncoder& enc); + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 1d530f8103..88d7839bf2 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -2,7 +2,6 @@ #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/reduce/reduce.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/fast_primitives.h" @@ -65,9 +64,11 @@ void fast::Quantize::eval_gpu( if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous(inputs[2], enc, s); affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); - } else { + } else if (mode_ == QuantizationMode::Nvfp4) { auto tensor_amax = inputs[2]; fp_dequantize(wq, scales, tensor_amax, w, group_size_, bits_, enc, s); + } else { + fp_dequantize(wq, scales, {}, w, group_size_, bits_, enc, s); } } else { auto w = ensure_row_contiguous(inputs[0], enc, s); @@ -81,20 +82,13 @@ void fast::Quantize::eval_gpu( auto& biases = outputs[2]; biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); - } else { + } else if (mode_ == QuantizationMode::Nvfp4) { auto& tensor_amax = outputs[2]; tensor_amax.set_data(cu::malloc_async(tensor_amax.nbytes(), enc)); - all_reduce( - enc, w, tensor_amax, Reduce::ReduceType::AbsMax); // compute amax - fp_quantize( - w, - wq, - scales, - tensor_amax, - group_size_, - bits_, - enc, - s); // pass 
amax to quantization kernel + all_reduce(enc, w, tensor_amax, Reduce::ReduceType::AbsMax); + fp_quantize(w, wq, scales, tensor_amax, group_size_, bits_, enc, s); + } else { + fp_quantize(w, wq, scales, {}, group_size_, bits_, enc, s); } } } diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h index 4f1980a9c9..fbeb62d6f1 100644 --- a/mlx/backend/cuda/quantized/quantized.h +++ b/mlx/backend/cuda/quantized/quantized.h @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. +#include #include "mlx/backend/cuda/device.h" +#include "mlx/primitives.h" namespace mlx::core { @@ -28,6 +30,7 @@ void fp_quantize( const array& w, array& wq, array& scales, + const std::optional& tensor_amax, int group_size, int bits, cu::CommandEncoder& enc, @@ -36,10 +39,17 @@ void fp_quantize( void fp_dequantize( const array& wq, const array& scales, + const std::optional& tensor_amax, array& w, int group_size, int bits, cu::CommandEncoder& enc, const Stream& s); +void all_reduce( + cu::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index bafa5817f9..100fdb3866 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -13,6 +13,20 @@ namespace cu { namespace cg = cooperative_groups; +template +__device__ __forceinline__ T absmax_val(T x) { + if constexpr (cuda::std::is_same_v) { + return x; + } else if constexpr (cuda::std::is_unsigned_v) { + return x; // unsigned types are non-negative + } else if constexpr (cuda::std::is_floating_point_v) { + return fabs(x); + } else { + // signed integer + return x < T(0) ? -x : x; + } +} + template __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { // TODO: Process multiple "rows" in each thread @@ -38,7 +52,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], abs(cast_to(vals[j]))); + accs[0] = op(accs[0], absmax_val(cast_to(vals[j]))); } else { accs[0] = op(accs[0], cast_to(vals[j])); } @@ -50,7 +64,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], abs(cast_to(vals[i]))); + accs[0] = op(accs[0], absmax_val(cast_to(vals[i]))); } else { accs[0] = op(accs[0], cast_to(vals[i])); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 3b28ade4ba..f6fd451735 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4264,6 +4264,7 @@ void validate_qqmm_inputs( array x, array w, std::optional scales_w, + std::optional tensor_amax_w, int group_size, int bits) { // check 2D (for now) @@ -4338,6 +4339,7 @@ array qqmm( array in_x, array w, std::optional scales_w, + std::optional tensor_amax_w, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, @@ -4369,7 +4371,21 @@ array qqmm( } else if (w.ndim() == 2 && x.ndim() > 2) { x = flatten(x, 0, -2, s); } - + if (qmode != QuantizationMode::Nvfp4) { + if (tensor_amax_w.has_value()) { + std::ostringstream msg; + msg << "[qqmm] The 'tensor_amax_w' argument is only supported" + << " with 'nvfp4' quantization mode."; + throw std::invalid_argument(msg.str()); + } + } else 
{ + if (!tensor_amax_w.has_value()) { + std::ostringstream msg; + msg << "[qqmm] The 'tensor_amax_w' argument must be provided" + << " with 'nvfp4' quantization mode."; + throw std::invalid_argument(msg.str()); + } + } // validate inputs validate_qqmm_inputs(x, w, scales_w, group_size, bits); // validate and extract shapes @@ -4382,6 +4398,9 @@ array qqmm( if (scales_w.has_value()) { inputs.push_back(*scales_w); } + if (tensor_amax_w.has_value()) { + inputs.push_back(*tensor_amax_w); + } auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; auto out = array( @@ -4534,32 +4553,18 @@ std::vector fp_quantize( << bits << "."; throw std::invalid_argument(msg.str()); } - auto fallback = [bits = bits, group_size = group_size, s]( + constexpr float F8E4M3_MAX = 448.0f; + constexpr float F4E2M1_MAX = 6.0f; + auto fallback = [bits = bits, group_size = group_size, mode = mode, s]( const std::vector& inputs) -> std::vector { auto& w = inputs[0]; - float maxval = (bits == 4) ? 6.0f : 448.0f; + float maxval = (bits == 4) ? F4E2M1_MAX : F8E4M3_MAX; auto new_shape = w.shape(); new_shape.back() = -1; auto wq = reshape(w, {-1, group_size}, s); - auto scales = - divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s); - if (group_size == 16) { - // convert to e4m3 - scales = to_fp8(scales, s); - wq = divide(wq, from_fp8(scales, w.dtype(), s), s); - } else { - // convert to e8m0 - auto z = array(0, scales.dtype()); - scales = where( - equal(scales, z, s), - z, - astype(round(log2(scales, s), s), int32, s), - s); - - wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); - scales = astype(add(scales, array(127, int32), s), uint8, s); - } - if (bits == 4) { + auto block_amax = max(abs(wq, s), -1, true, s); + + auto quantize_to_fp4 = [&](array& wq_in) { auto lut = array({ +0.0f, +0.5f, @@ -4579,11 +4584,41 @@ std::vector fp_quantize( -6.0f, }); lut = astype(lut, w.dtype(), s); - wq = argmin( - abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s); + wq_in = argmin( + abs(subtract(expand_dims(wq_in, -1, s), lut, s), s), -1, false, s); auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); - wq = reshape(wq, {-1, 4, 8}, s); - wq = sum(multiply(wq, shifts, s), -1, false, s); + wq_in = reshape(wq_in, {-1, 4, 8}, s); + wq_in = sum(multiply(wq_in, shifts, s), -1, false, s); + }; + + if (mode == QuantizationMode::Nvfp4) { + auto tensor_amax = astype(max(abs(w, s), s), float32, s); + // Global encode scale: (448 * 6) / tensor_amax + auto scale_enc = divide(array(F8E4M3_MAX * F4E2M1_MAX), tensor_amax, s); + // Per-block decode scale: (block_amax / 6) * scale_enc + auto scales = multiply( + divide(block_amax, array(F4E2M1_MAX, w.dtype()), s), scale_enc, s); + // Convert to e4m3 + scales = to_fp8(scales, s); + // Per-block encode scale: scale_enc / scale_dec_b + auto scale_enc_b = divide(scale_enc, from_fp8(scales, w.dtype(), s), s); + wq = multiply(wq, scale_enc_b, s); + quantize_to_fp4(wq); + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); + return {std::move(wq), std::move(scales), std::move(tensor_amax)}; + } + + auto scales = divide(block_amax, array(maxval, w.dtype()), s); + auto z = array(0, scales.dtype()); + scales = where( + equal(scales, z, s), z, astype(round(log2(scales, s), s), int32, s), s); + + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + + if (bits == 4) { + quantize_to_fp4(wq); } else { wq = view(to_fp8(wq, s), uint32, s); } @@ -4597,10 
+4632,18 @@ std::vector fp_quantize( wq_shape.back() = w.shape(-1) * bits / 32; auto sshape = w.shape(); sshape.back() = w.shape(-1) / group_size; - Shape tashape = {}; + // nvfp4 fp tensor scale + // TODO: should we try to have w.dtype() here? + std::vector shapes = {std::move(wq_shape), std::move(sshape)}; + std::vector dtypes = {uint32, uint8}; + + if (mode == QuantizationMode::Nvfp4) { + shapes.push_back({}); + dtypes.push_back(float32); + } return array::make_arrays( - {std::move(wq_shape), std::move(sshape), std::move(tashape)}, - {uint32, uint8, float32}, + std::move(shapes), + std::move(dtypes), std::make_shared( s, fallback, group_size, bits, mode, false), {w}); @@ -4744,7 +4787,7 @@ array affine_dequantize( array fp_dequantize( const array& w, const array& scales, - const array& tensor_amax, + const std::optional& tensor_amax, int group_size, int bits, Dtype out_type, @@ -4793,12 +4836,15 @@ array fp_dequantize( throw std::invalid_argument(msg.str()); } + constexpr float F8E4M3_MAX = 448.0f; + constexpr float F4E2M1_MAX = 6.0f; auto fallback = [wshape = std::move(wshape), sshape = std::move(sshape), group_size, bits, out_type, + mode, s](const std::vector& inputs) mutable -> std::vector { auto out = inputs[0]; auto scales = inputs[1]; @@ -4834,7 +4880,16 @@ array fp_dequantize( } out = reshape(out, {-1, group_size}, s); scales = reshape(scales, {-1, 1}, s); - if (group_size == 16) { + if (mode == QuantizationMode::Nvfp4) { + auto tensor_amax = inputs[2]; + // scale_dec_b stored as FP8 e4m3 + scales = from_fp8(scales, out_type, s); + // inv_scale_enc = tensor_amax / (448 * 6) + auto inv_scale_enc = + divide(tensor_amax, array(F8E4M3_MAX * F4E2M1_MAX, out_type), s); + // final scale = scale_dec_b * inv_scale_enc + scales = multiply(scales, inv_scale_enc, s); + } else if (mode == QuantizationMode::Mxfp4) { scales = from_fp8(scales, out_type, s); } else { scales = subtract(astype(scales, out_type, s), array(127, out_type), s); @@ -4845,12 +4900,20 @@ array fp_dequantize( if (s.device == Device::gpu) { auto out_shape = w.shape(); out_shape.back() = out_size; + auto inputs = std::vector{w, scales}; + if (mode == QuantizationMode::Nvfp4) { + inputs.push_back(*tensor_amax); + } + return array( std::move(out_shape), out_type, std::make_shared( s, fallback, group_size, bits, mode, true), - {w, scales, tensor_amax}); + std::move(inputs)); + } + if (mode == QuantizationMode::Nvfp4) { + return fallback({w, scales, *tensor_amax})[0]; } return fallback({w, scales})[0]; } @@ -4858,7 +4921,7 @@ array fp_dequantize( array dequantize( const array& w, const array& scales, - // const std::optional& tensor_amax /* = std::nullopt */, + const std::optional& tensor_amax /* = std::nullopt */, const std::optional& biases /* = std::nullopt */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, @@ -4897,7 +4960,14 @@ array dequantize( s); } else { return fp_dequantize( - w, scales, group_size, bits, out_type, qmode, to_stream(s)); + w, + scales, + tensor_amax, + group_size, + bits, + out_type, + qmode, + to_stream(s)); } } diff --git a/mlx/ops.h b/mlx/ops.h index ff92cbe926..a1573097c2 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1413,6 +1413,7 @@ std::vector quantize( array dequantize( const array& w, const array& scales, + const std::optional& tensor_amax = std::nullopt, const std::optional& biases = std::nullopt, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9767cd604c..6e8af544f4 
100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3421,6 +3421,7 @@ std::vector QuantizedMatmul::vjp( primals[1], ones_like(primals[2], stream()), zeros_like(primals[3], stream()), + {}, group_size_, bits_, quantization_mode_to_string(mode_), @@ -3643,6 +3644,7 @@ std::vector GatherQMM::vjp( w, ones_like(scales, stream()), zeros_like(*biases, stream()), + {}, group_size_, bits_, quantization_mode_to_string(mode_), diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..6c8cd8ab42 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3087,7 +3087,7 @@ TEST_CASE("test quantize dequantize") { CHECK_EQ(scales.shape(), Shape{128, 4}); CHECK_EQ(biases.shape(), Shape{128, 4}); - auto x_hat = dequantize(x_q, scales, biases, 128, i); + auto x_hat = dequantize(x_q, scales, {}, biases, 128, i); auto max_diff = max(abs(x - x_hat)).item(); CHECK(max_diff <= 127.0 / (1 << i)); } From a7fab997d6c4787fbe7c18963b1cf925018be587 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 19 Jan 2026 13:06:49 +0100 Subject: [PATCH 04/34] alpha device ptr for qqmm --- mlx/backend/cuda/gemms/cublas_gemm.cpp | 4 +- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 74 +++++++++++++++++++--- mlx/backend/cuda/quantized/cublas_qqmm.h | 29 +++++---- mlx/backend/cuda/quantized/qqmm.cpp | 5 +- mlx/ops.cpp | 38 ++++++----- 5 files changed, 107 insertions(+), 43 deletions(-) diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index ee33c78acf..5046e69681 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -214,8 +214,8 @@ void CublasGemm::execute( const void* a, const void* b, const void* c, - float alpha /* = 1 */, - float beta /* = 0 */) { + const float alpha /* = 1 */, + const float beta /* = 0 */) { const void* alpha_ptr = α const void* beta_ptr = β complex64_t alpha_c, beta_c; diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index 959ae93520..a67278391f 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -150,6 +150,35 @@ CublasQQMM::CublasQQMM( batch_count, c_batch_stride); } +// Supported overloads: +// alpha float +// alpha device ptr + +void CublasQQMM::run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const array& alpha) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(a_scale); + encoder.set_input_array(b_scale); + encoder.set_input_array(alpha); + encoder.set_output_array(out); + + execute( + encoder, + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(a_scale), + gpu_ptr(b_scale), + nullptr, + gpu_ptr(alpha)); +} void CublasQQMM::run( cu::CommandEncoder& encoder, @@ -176,16 +205,10 @@ void CublasQQMM::run( alpha); } -void CublasQQMM::execute( +void CublasQQMM::set_scales_decs( cu::CommandEncoder& encoder, - void* out, - const void* a, - const void* b, const void* a_scale, - const void* b_scale, - const void* c, - float alpha /* = 1 */, - float beta /* = 0 */) { + const void* b_scale) { CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, @@ -196,6 +219,41 @@ void CublasQQMM::execute( CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &a_scale, sizeof(a_scale))); +} + +void CublasQQMM::execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + 
const void* c, + const float* alpha) { + set_scales_decs(encoder, a_scale, b_scale); + // alpha and beta are both should be device pointers + // by default cublas uses host pointers + // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(pointer_mode))); + execute_matmul(encoder, out, a, b, c, alpha, nullptr); +} + +void CublasQQMM::execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const float alpha /* = 1 */, + const float beta /* = 0 */) { + set_scales_decs(encoder, a_scale, b_scale); const void* alpha_ptr = α const void* beta_ptr = β diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index ecce044afd..7cb40e54b7 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -61,20 +61,23 @@ class CublasQQMM : public CublasMatmulBase { const array& b, const array& a_scale, const array& b_scale, - float alpha = 1.0f); + const float alpha = 1.0f); private: - void run_batched( + void set_scales_decs( cu::CommandEncoder& encoder, - array& out, - const array& a, - const array& b, - const array& a_scale, - const array& b_scale, - const Shape& batch_shape, - const Strides& a_batch_strides, - const Strides& b_batch_strides, - float alpha); + const void* a_scale, + const void* b_scale); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const float* alpha); void execute( cu::CommandEncoder& encoder, @@ -84,8 +87,8 @@ class CublasQQMM : public CublasMatmulBase { const void* a_scale, const void* b_scale, const void* c, - float alpha = 1, - float beta = 0); + const float alpha = 1.0f, + const float beta = 0.0f); std::string quantization_mode_; cublasLtMatmulMatrixScale_t a_scale_mode_; diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 0bc8bd0826..739d0b0f4c 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -125,12 +125,11 @@ QuantizedResult quantize_input( }; if (mode == QuantizationMode::Nvfp4) { - array tensor_amax(cu::malloc_async(sizeof(float), encoder), {1}, float32); + array tensor_amax(cu::malloc_async(sizeof(float), encoder), {}, float32); encoder.add_temporary(tensor_amax); all_reduce(encoder, x, tensor_amax, Reduce::ReduceType::AbsMax); return run_quant(x, tensor_amax); } - return run_quant(x, std::nullopt); } @@ -150,7 +149,7 @@ void qqmm_impl( const array& b_scale, Dtype out_dtype, QuantizationMode mode, - float alpha = 1.0f) { + const float alpha) { // Invoke CublasQQMM std::string qmode = quantization_mode_to_string(mode); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f6fd451735..0f1a7cfebe 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4266,7 +4266,8 @@ void validate_qqmm_inputs( std::optional scales_w, std::optional tensor_amax_w, int group_size, - int bits) { + int bits, + QuantizationMode qmode) { // check 2D (for now) if (x.ndim() > 2 || w.ndim() > 2) { std::ostringstream msg; @@ -4303,6 +4304,24 @@ void validate_qqmm_inputs( << "first argument dtype == " << x.dtype() << "."; throw std::invalid_argument(msg.str()); } + // TODO: not sure if we want to support nvfp4 without tensor amax, + // maybe by 
adding a boolean variable for quantization if not amax -> set to + // array(1, foat32) + if (qmode != QuantizationMode::Nvfp4) { + if (tensor_amax_w.has_value()) { + std::ostringstream msg; + msg << "[qqmm] The 'tensor_amax_w' argument is only supported" + << " with 'nvfp4' quantization mode."; + throw std::invalid_argument(msg.str()); + } + } else { + if (!tensor_amax_w.has_value()) { + std::ostringstream msg; + msg << "[qqmm] The 'tensor_amax_w' argument must be provided" + << " with 'nvfp4' quantization mode."; + throw std::invalid_argument(msg.str()); + } + } } std::pair extract_qqmm_dims( @@ -4371,23 +4390,8 @@ array qqmm( } else if (w.ndim() == 2 && x.ndim() > 2) { x = flatten(x, 0, -2, s); } - if (qmode != QuantizationMode::Nvfp4) { - if (tensor_amax_w.has_value()) { - std::ostringstream msg; - msg << "[qqmm] The 'tensor_amax_w' argument is only supported" - << " with 'nvfp4' quantization mode."; - throw std::invalid_argument(msg.str()); - } - } else { - if (!tensor_amax_w.has_value()) { - std::ostringstream msg; - msg << "[qqmm] The 'tensor_amax_w' argument must be provided" - << " with 'nvfp4' quantization mode."; - throw std::invalid_argument(msg.str()); - } - } // validate inputs - validate_qqmm_inputs(x, w, scales_w, group_size, bits); + validate_qqmm_inputs(x, w, scales_w, tensor_amax_w, group_size, bits); // validate and extract shapes auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims(x, w, scales_w, group_size, bits); From 47be9947865d1ce63b1d5addb8295c2eb4b206a7 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 20 Jan 2026 16:35:03 +0100 Subject: [PATCH 05/34] device alpha, beta --- mlx/backend/cuda/cublas_utils.cpp | 7 - mlx/backend/cuda/gemms/cublas_gemm.cpp | 8 + mlx/backend/cuda/quantized/cublas_qqmm.cpp | 92 ++++----- mlx/backend/cuda/quantized/cublas_qqmm.h | 16 +- mlx/backend/cuda/quantized/qqmm.cpp | 229 +++++++++++---------- mlx/backend/cuda/quantized/qqmm_utils.cu | 11 +- mlx/backend/cuda/quantized/qqmm_utils.h | 5 +- mlx/ops.cpp | 37 ++-- mlx/ops.h | 3 + mlx/primitives.cpp | 2 + python/src/ops.cpp | 18 +- 11 files changed, 237 insertions(+), 191 deletions(-) diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp index 108f56c8ae..f9876055d5 100644 --- a/mlx/backend/cuda/cublas_utils.cpp +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -105,13 +105,6 @@ void CublasMatmulBase::init_base( CHECK_CUBLAS_ERROR( cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type)); - int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, - sizeof(int32_t))); - // In cublasLt matrices use column-major layout, while it is possible to use // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias // epilogue does not work with the option. 
So instead we swap A and B to make diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 5046e69681..6c7a48efdd 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -72,6 +72,14 @@ CublasGemm::CublasGemm( batch_count, a_batch_stride, b_batch_stride); + + // alpha and beta are both host pointers + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(pointer_mode))); } CublasGemm::CublasGemm( diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index a67278391f..41446c6f98 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -13,39 +13,26 @@ namespace mlx::core { namespace { -// Currently cublas supports only mxfp8 and nvfp4 -// quantization modes for block scaled quantization -cudaDataType_t qmode_to_cublas_scale_dtype(std::string mode) { - if (mode == "mxfp8") { - return CUDA_R_8F_UE8M0; - } else if (mode == "nvfp4") { - return CUDA_R_8F_UE4M3; - } else { - throw std::runtime_error( - fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); - } -} - -cudaDataType_t qmode_to_cublas_dtype(std::string mode) { - if (mode == "mxfp8") { - return CUDA_R_8F_E4M3; - } else if (mode == "nvfp4") { - return CUDA_R_4F_E2M1; - } else { - throw std::runtime_error( - fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); - } -} +struct QuantModeConfig { + cudaDataType_t data_type; + cudaDataType_t scale_dtype; + cublasLtMatmulMatrixScale_t scale_mode; +}; -cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string mode) { +QuantModeConfig get_quant_mode_config(const std::string& mode) { if (mode == "mxfp8") { - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + return { + CUDA_R_8F_E4M3, + CUDA_R_8F_UE8M0, + CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0}; } else if (mode == "nvfp4") { - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; - } else { - throw std::runtime_error( - fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + return { + CUDA_R_4F_E2M1, + CUDA_R_8F_UE4M3, + CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3}; } + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); } } // namespace @@ -64,21 +51,21 @@ CublasQQMM::CublasQQMM( int64_t a_batch_stride, int64_t b_batch_stride, Dtype out_dtype, - std::string qmode) { + const std::string& qmode) { + auto config = get_quant_mode_config(qmode); + // The compute type must be CUBLAS_COMPUTE_32F. // The scale type must be CUDA_R_32F. 
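// (The scale type is the data type of alpha and beta.) For nvfp4 the
// per-tensor decode factors are not carried by the block scales alone; they
// are folded into alpha by compute_qqmm_pointers in qqmm_utils.cu:
//   alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
// e.g. amax_x = 5, amax_w = 2  ->  alpha = 10 / 2688^2 ≈ 1.38e-6.
// Both amax values live only on the device, which is why the device-pointer
// execute() overload below switches alpha/beta to
// CUBLASLT_POINTER_MODE_DEVICE.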
cudaDataType_t scale_type = CUDA_R_32F; cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; cudaDataType_t output_type = cublas_utils::dtype_to_cublas_type(out_dtype, "CublasQQMM"); - cudaDataType_t data_type = qmode_to_cublas_dtype(qmode); - quantization_mode_ = std::string(qmode); init_base( device, scale_type, gemm_compute_type, - data_type, + config.data_type, output_type, a_transposed, a_rows, @@ -92,8 +79,8 @@ CublasQQMM::CublasQQMM( a_batch_stride, b_batch_stride); - a_scale_mode_ = qmode_to_cublas_scale_mode(qmode); - b_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + a_scale_mode_ = config.scale_mode; + b_scale_mode_ = config.scale_mode; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, @@ -123,7 +110,7 @@ CublasQQMM::CublasQQMM( int64_t b_batch_stride, int64_t c_batch_stride, Dtype out_dtype, - std::string qmode) + const std::string& qmode) : CublasQQMM( device, a_transposed, @@ -161,12 +148,14 @@ void CublasQQMM::run( const array& b, const array& a_scale, const array& b_scale, - const array& alpha) { + const array& alpha, + const array& beta) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(a_scale); encoder.set_input_array(b_scale); encoder.set_input_array(alpha); + encoder.set_input_array(beta); encoder.set_output_array(out); execute( @@ -177,7 +166,8 @@ void CublasQQMM::run( gpu_ptr(a_scale), gpu_ptr(b_scale), nullptr, - gpu_ptr(alpha)); + gpu_ptr(alpha), + gpu_ptr(beta)); } void CublasQQMM::run( @@ -186,8 +176,7 @@ void CublasQQMM::run( const array& a, const array& b, const array& a_scale, - const array& b_scale, - float alpha) { + const array& b_scale) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(a_scale); @@ -201,11 +190,10 @@ void CublasQQMM::run( gpu_ptr(b), gpu_ptr(a_scale), gpu_ptr(b_scale), - nullptr, - alpha); + nullptr); } -void CublasQQMM::set_scales_decs( +void CublasQQMM::set_scales_ptrs( cu::CommandEncoder& encoder, const void* a_scale, const void* b_scale) { @@ -229,9 +217,10 @@ void CublasQQMM::execute( const void* a_scale, const void* b_scale, const void* c, - const float* alpha) { - set_scales_decs(encoder, a_scale, b_scale); - // alpha and beta are both should be device pointers + const void* alpha, + const void* beta) { + set_scales_ptrs(encoder, a_scale, b_scale); + // alpha and beta are both should be device pointers for nvfp4 // by default cublas uses host pointers // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; @@ -240,7 +229,7 @@ void CublasQQMM::execute( CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); - execute_matmul(encoder, out, a, b, c, alpha, nullptr); + execute_matmul(encoder, out, a, b, c, alpha, beta); } void CublasQQMM::execute( @@ -253,7 +242,14 @@ void CublasQQMM::execute( const void* c, const float alpha /* = 1 */, const float beta /* = 0 */) { - set_scales_decs(encoder, a_scale, b_scale); + set_scales_ptrs(encoder, a_scale, b_scale); + // alpha and beta are both should be host pointers + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(pointer_mode))); const void* alpha_ptr = α const void* beta_ptr = β diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 7cb40e54b7..a9095012c5 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ 
b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -25,7 +25,7 @@ class CublasQQMM : public CublasMatmulBase { int64_t a_batch_stride, int64_t b_batch_stride, Dtype out_dtype, - std::string quantization_mode); + const std::string& quantization_mode); CublasQQMM( cu::Device& device, @@ -43,7 +43,7 @@ class CublasQQMM : public CublasMatmulBase { int64_t b_batch_stride, int64_t c_batch_stride, Dtype out_dtype, - std::string quantization_mode); + const std::string& quantization_mode); void run( cu::CommandEncoder& encoder, @@ -52,7 +52,8 @@ class CublasQQMM : public CublasMatmulBase { const array& b, const array& a_scale, const array& b_scale, - const array& alpha); + const array& alpha, + const array& beta); void run( cu::CommandEncoder& encoder, @@ -60,11 +61,10 @@ class CublasQQMM : public CublasMatmulBase { const array& a, const array& b, const array& a_scale, - const array& b_scale, - const float alpha = 1.0f); + const array& b_scale); private: - void set_scales_decs( + void set_scales_ptrs( cu::CommandEncoder& encoder, const void* a_scale, const void* b_scale); @@ -77,7 +77,8 @@ class CublasQQMM : public CublasMatmulBase { const void* a_scale, const void* b_scale, const void* c, - const float* alpha); + const void* alpha, + const void* beta); void execute( cu::CommandEncoder& encoder, @@ -90,7 +91,6 @@ class CublasQQMM : public CublasMatmulBase { const float alpha = 1.0f, const float beta = 0.0f); - std::string quantization_mode_; cublasLtMatmulMatrixScale_t a_scale_mode_; cublasLtMatmulMatrixScale_t b_scale_mode_; cublasLtMatmulMatrixScale_t c_scale_mode_; diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 739d0b0f4c..9eb52664a2 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -11,37 +11,25 @@ namespace mlx::core { -using QuantizedResult = std::tuple>; - namespace { -inline array ensure_row_contiguous( - const array& x, - cu::CommandEncoder& enc, - const Stream& s) { - if (!x.flags().row_contiguous) { - array x_copy = contiguous_copy_gpu(x, s); - enc.add_temporary(x_copy); - return x_copy; - } else { - return x; +using QuantizedTensor = std::tuple>; + +struct GemmScalars { + std::optional alpha_device; + std::optional beta_device; + + bool uses_device_pointers() const { + return alpha_device.has_value(); } -} +}; -inline array ensure_row_contiguous_matrix( +inline array ensure_row_contiguous( const array& x, cu::CommandEncoder& enc, const Stream& s) { - if (x.ndim() < 2) { - if (x.strides()[0] == 1) { - return x; - } - } else { - auto stride_0 = x.strides()[x.ndim() - 2]; - auto stride_1 = x.strides()[x.ndim() - 1]; - if (stride_0 == x.shape(-1) && stride_1 == 1) { - return x; - } + if (x.flags().row_contiguous) { + return x; } array x_copy = contiguous_copy_gpu(x, s); enc.add_temporary(x_copy); @@ -55,24 +43,22 @@ array pad_and_swizzle_scales( // Compute padded dimensions for full tiles (128 rows × 4 cols) auto [pad_outer, pad_inner] = get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + // cuBLAS requirements for scale factor layout: // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) // 2. Out-of-bounds values must be filled with zeros // 3. 
Starting addresses must be 16-byte aligned - // // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout - // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( cu::malloc_async(pad_outer * pad_inner, encoder), Shape{pad_outer, pad_inner}, scale.dtype()); swizzle_scales(scale, scale_tiled, encoder, s); - encoder.add_temporary(scale_tiled); return scale_tiled; } -QuantizedResult quantize_input( +QuantizedTensor quantize_input( const array& input, cu::CommandEncoder& encoder, const Stream& s, @@ -81,59 +67,80 @@ QuantizedResult quantize_input( int group_size) { const array x = ensure_row_contiguous(input, encoder, s); - auto build_shapes = [&](const array& x_in) { - auto xq_shape = x_in.shape(); - xq_shape.back() = x_in.shape(-1) * bits / 32; - - auto sshape = x_in.shape(); - const int64_t scales_inner = x_in.shape(-1) / group_size; - auto [pad_outer, pad_inner] = - get_padded_scale_dims(x_in.shape(-2), scales_inner); - sshape[x_in.ndim() - 2] = pad_outer; - sshape[x_in.ndim() - 1] = pad_inner; - sshape.back() = scales_inner; - - return std::tuple{ - std::move(xq_shape), - std::move(sshape), - pad_outer, - pad_inner, - }; - }; - - auto allocate_outputs = [&](const array& x_in) { - auto [xq_shape, sshape, pad_outer, pad_inner] = build_shapes(x_in); - - const int64_t xq_bytes = x_in.size() * bits / 8; - const int64_t batch = x_in.size() / (x_in.shape(-2) * x_in.shape(-1)); - const int64_t scales_bytes = batch * (pad_outer * pad_inner); - - array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); - array scales_x( - cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); - encoder.add_temporary(x_q); - encoder.add_temporary(scales_x); - - return std::pair{std::move(x_q), std::move(scales_x)}; - }; - - auto run_quant = [&](const array& x_in, std::optional tensor_amax) { - auto [x_q, scales_x] = allocate_outputs(x_in); - fp_quantize(x_in, x_q, scales_x, tensor_amax, group_size, bits, encoder, s); - return QuantizedResult{ - std::move(x_q), std::move(scales_x), std::move(tensor_amax)}; - }; + // Compute output shapes + auto xq_shape = x.shape(); + xq_shape.back() = x.shape(-1) * bits / 32; + + const int64_t scales_inner = x.shape(-1) / group_size; + auto [pad_outer, pad_inner] = + get_padded_scale_dims(x.shape(-2), scales_inner); + + auto sshape = x.shape(); + sshape[x.ndim() - 2] = pad_outer; + sshape[x.ndim() - 1] = pad_inner; + sshape.back() = scales_inner; + + // Allocate outputs + const int64_t xq_bytes = x.size() * bits / 8; + const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1)); + const int64_t scales_bytes = batch * (pad_outer * pad_inner); + array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); + array scales_x( + cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); + encoder.add_temporary(x_q); + encoder.add_temporary(scales_x); + + // For NVFP4: compute tensor-wide amax for global scaling + std::optional tensor_amax = std::nullopt; if (mode == QuantizationMode::Nvfp4) { - array tensor_amax(cu::malloc_async(sizeof(float), encoder), {}, float32); - encoder.add_temporary(tensor_amax); - all_reduce(encoder, x, tensor_amax, Reduce::ReduceType::AbsMax); - return run_quant(x, tensor_amax); + array amax(cu::malloc_async(sizeof(float), encoder), {}, float32); + encoder.add_temporary(amax); + all_reduce(encoder, x, amax, Reduce::ReduceType::AbsMax); + tensor_amax = amax; } - return run_quant(x, std::nullopt); + + fp_quantize(x, x_q, scales_x, tensor_amax, group_size, 
bits, encoder, s); + return {std::move(x_q), std::move(scales_x), std::move(tensor_amax)}; } -void qqmm_impl( +QuantizedTensor get_weight_tensors( + const std::vector& inputs, + cu::CommandEncoder& encoder, + const Stream& s, + QuantizationMode mode, + int bits, + int group_size) { + // Check if weights need quantization + if (inputs[1].dtype() != uint32) { + return quantize_input(inputs[1], encoder, s, mode, bits, group_size); + } + + // Weights are pre-quantized + if (mode == QuantizationMode::Nvfp4) { + // NVFP4: inputs = [x, w_q, scale_w, tensor_amax_w] + return {inputs[1], inputs[2], inputs[3]}; + } + // MXFP8: inputs = [x, w_q, scale_w] (no tensor_amax) + return {inputs[1], inputs[2], std::nullopt}; +} + +GemmScalars create_nvfp4_scalars( + const array& tensor_amax_x, + const array& tensor_amax_w, + cu::CommandEncoder& encoder) { + // NVFP4 requires alpha/beta as device pointers + // alpha = amax_x * amax_w / (448 * 6)^2 + // beta = 0 + array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32); + array beta(cu::malloc_async(sizeof(float), encoder), {}, float32); + compute_qqmm_pointers(alpha, beta, tensor_amax_x, tensor_amax_w, encoder); + encoder.add_temporary(alpha); + encoder.add_temporary(beta); + return {alpha, beta}; +} + +void run_qqmm( cu::CommandEncoder& encoder, int M, int N, @@ -147,15 +154,10 @@ void qqmm_impl( const array& b, const array& a_scale, const array& b_scale, - Dtype out_dtype, QuantizationMode mode, - const float alpha) { - // Invoke CublasQQMM + const GemmScalars& scalars) { std::string qmode = quantization_mode_to_string(mode); - // Currently only supports non-batched QQMM operations - // that covers all use cases for training, we will just collapse (batch, - // seq_len) into (tokens) CublasQQMM qqmm( encoder.device(), a_transposed, @@ -169,53 +171,72 @@ void qqmm_impl( 1, // batch_count 0, // a_batch_stride 0, // b_batch_stride - out_dtype, + out.dtype(), qmode); - qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha); + if (scalars.uses_device_pointers()) { + qqmm.run( + encoder, + out, + a, + b, + a_scale, + b_scale, + *scalars.alpha_device, + *scalars.beta_device); + } else { + qqmm.run(encoder, out, a, b, a_scale, b_scale); + } } + } // namespace void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("QQMatmul::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); + + // Check compute capability (requires Blackwell or newer) auto& device = encoder.device(); - auto cc = device.compute_capability_major() * 100 + + int cc = device.compute_capability_major() * 100 + device.compute_capability_minor() * 10; if (cc < 1000) { throw std::runtime_error( - "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); + "[QQMatmul::eval_gpu] QQMM requires compute capability 10.0+"); } - auto quant_input_size = (mode_ == QuantizationMode::Nvfp4) ? 4 : 3; + + size_t expected_size = (mode_ == QuantizationMode::Nvfp4) ? 4 : 3; assert( - (inputs.size() == quant_input_size && inputs[1].dtype() == uint32) || + (inputs.size() == expected_size && inputs[1].dtype() == uint32) || (inputs.size() == 2)); + // Quantize inputs (or use pre-quantized) auto [x_q, scale_x_pre, tensor_amax_x] = quantize_input(inputs[0], encoder, s, mode_, bits_, group_size_); - auto [w_q, scale_w_pre, tensor_amax_w] = (inputs[1].dtype() != uint32) - ? 
quantize_input(inputs[1], encoder, s, mode_, bits_, group_size_) - : QuantizedResult{inputs[1], inputs[2], std::optional(inputs[3])}; - out.set_data(cu::malloc_async(out.nbytes(), encoder)); + auto [w_q, scale_w_pre, tensor_amax_w] = + get_weight_tensors(inputs, encoder, s, mode_, bits_, group_size_); - auto out_dtype = out.dtype(); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); int M = x_q.shape(-2); - int N = w_q.shape(-2); // always transposed - int K_packed = x_q.shape(-1); - int K = K_packed * (32 / bits_); - - // Repack scales from linear to tiled layout for tensor cores - array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); - array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + int N = w_q.shape(-2); // transposed + int K = x_q.shape(-1) * (32 / bits_); bool x_transposed = false; - bool w_transposed = true; // always transposed + bool w_transposed = true; // weights are always transposed int64_t lda = K; int64_t ldb = K; - qqmm_impl( + // Repack scales to tiled layout for tensor cores + array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); + array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + + GemmScalars scalars; + if (mode_ == QuantizationMode::Nvfp4) { + scalars = create_nvfp4_scalars(*tensor_amax_x, *tensor_amax_w, encoder); + } + + run_qqmm( encoder, M, N, @@ -229,8 +250,8 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { w_q, scale_x, scale_w, - out_dtype, - mode_); + mode_, + scalars); } } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index 81404783a7..d429adbbb7 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -73,14 +73,16 @@ namespace cu { constexpr float F8E4M3_MAX = 448.0f; constexpr float F4E2M1_MAX = 6.0f; -__global__ void compute_qqmm_alpha( +__global__ void compute_qqmm_pointers( float* alpha_out, + float* beta_out, const float* tensor_amax_x, const float* tensor_amax_w) { // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 constexpr float inv_scale_sq = 1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX); *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq; + *beta_out = 0.0f; } __global__ void swizzle_scales( @@ -237,20 +239,23 @@ void swizzle_scales( output_cols); } -void compute_qqmm_alpha( +void compute_qqmm_pointers( array& alpha_out, + array& beta_out, const array& tensor_amax_x, const array& tensor_amax_w, cu::CommandEncoder& enc) { enc.set_input_array(tensor_amax_x); enc.set_input_array(tensor_amax_w); enc.set_output_array(alpha_out); + enc.set_output_array(beta_out); enc.add_kernel_node( - cu::compute_qqmm_alpha, + cu::compute_qqmm_pointers, dim3(1), dim3(1), 0, gpu_ptr(alpha_out), + gpu_ptr(beta_out), gpu_ptr(tensor_amax_x), gpu_ptr(tensor_amax_w)); } diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h index 4e2d9c4739..e40f09190f 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.h +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -28,8 +28,11 @@ void swizzle_scales( const Stream& s); // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 -void compute_qqmm_alpha( +// Allocate beta zero on device as well + +void compute_qqmm_pointers( array& alpha_out, + array& beta_out, const array& tensor_amax_x, const array& tensor_amax_w, cu::CommandEncoder& enc); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0f1a7cfebe..04a736c191 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ 
-4283,11 +4283,28 @@ void validate_qqmm_inputs( throw std::invalid_argument( "[qqmm] Scales must be provided if second argument is quantized."); } + if (qmode == QuantizationMode::Nvfp4) { + // nvfp4 quantization requires tensor amax + if (!tensor_amax_w.has_value()) { + std::ostringstream msg; + msg << "[qqmm] The 'tensor_amax_w' argument must be provided" + << " with 'nvfp4' quantization mode."; + throw std::invalid_argument(msg.str()); + } + } // if scales are provided, check compatibility with quantized w else { validate_quantized_input("qqmm", w, *scales_w, group_size, bits); + // other quantization modes do not support tensor amax + if (tensor_amax_w.has_value()) { + std::ostringstream msg; + msg << "[qqmm] The 'tensor_amax_w' argument is only supported" + << " with 'nvfp4' quantization mode."; + throw std::invalid_argument(msg.str()); + } } } + // if w is not quantized, dtype must be in {f16, bf16, fp32} else { if (!issubdtype(w.dtype(), floating) || w.dtype() == float64) { @@ -4304,24 +4321,6 @@ void validate_qqmm_inputs( << "first argument dtype == " << x.dtype() << "."; throw std::invalid_argument(msg.str()); } - // TODO: not sure if we want to support nvfp4 without tensor amax, - // maybe by adding a boolean variable for quantization if not amax -> set to - // array(1, foat32) - if (qmode != QuantizationMode::Nvfp4) { - if (tensor_amax_w.has_value()) { - std::ostringstream msg; - msg << "[qqmm] The 'tensor_amax_w' argument is only supported" - << " with 'nvfp4' quantization mode."; - throw std::invalid_argument(msg.str()); - } - } else { - if (!tensor_amax_w.has_value()) { - std::ostringstream msg; - msg << "[qqmm] The 'tensor_amax_w' argument must be provided" - << " with 'nvfp4' quantization mode."; - throw std::invalid_argument(msg.str()); - } - } } std::pair extract_qqmm_dims( @@ -4391,7 +4390,7 @@ array qqmm( x = flatten(x, 0, -2, s); } // validate inputs - validate_qqmm_inputs(x, w, scales_w, tensor_amax_w, group_size, bits); + validate_qqmm_inputs(x, w, scales_w, tensor_amax_w, group_size, bits, qmode); // validate and extract shapes auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims(x, w, scales_w, group_size, bits); diff --git a/mlx/ops.h b/mlx/ops.h index a1573097c2..7f2c65993f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1426,6 +1426,9 @@ array qqmm( array w, // maybe quantized weights std::optional w_scales = std::nullopt, // optional scales if w is // quantized + std::optional w_tensor_scale = + std::nullopt, // optional tensor amax if + // w is nvfp4 quantized std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6e8af544f4..984456eeb9 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3511,6 +3511,7 @@ std::vector QQMatmul::vjp( cotan, // M X N swapaxes(primals[1], -1, -2, s), // assuming that w is 2D {}, + {}, group_size_, bits_, qmode, @@ -3520,6 +3521,7 @@ std::vector QQMatmul::vjp( swapaxes(cotan, -1, -2, s), // (N, M) swapaxes(primals[0], -1, -2, s), // (K, M) {}, + {}, group_size_, bits_, qmode, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7160468582..f9a7c1585a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4288,6 +4288,8 @@ void init_ops(nb::module_& m) { * w_q (array): The quantized version of ``w`` * scales (array): The quantization scales + * tensor_scale (array): The per-tensor float32 absolute max + scale (returned for ``mode == "nvfp4"``) * biases (array): The quantization biases (returned for 
``mode=="affine"``). Notes: @@ -4344,9 +4346,21 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "dequantize", - &mx::dequantize, + [](const mx::array& w, + const mx::array& scales, + const std::optional& tensor_scale, + const std::optional& biases, + std::optional group_size, + std::optional bits, + const std::string& mode, + std::optional dtype, + mx::StreamOrDevice s) { + return mx::dequantize( + w, scales, tensor_scale, biases, group_size, bits, mode, dtype, s); + }, nb::arg(), "scales"_a, + "tensor_scale"_a = nb::none(), "biases"_a = nb::none(), "group_size"_a = nb::none(), "bits"_a = nb::none(), @@ -4362,6 +4376,8 @@ void init_ops(nb::module_& m) { Args: w (array): Matrix to be dequantized scales (array): The scales to use per ``group_size`` elements of ``w``. + tensor_scale (array, optional): The per-tensor float32 scale used for + ``"nvfp4"`` quantization. Default: ``None``. biases (array, optional): The biases to use per ``group_size`` elements of ``w``. Default: ``None``. group_size (int, optional): The size of the group in ``w`` that shares a From 7e4c6e8fd6da87c74189fb4cf91716a2584f1793 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 20 Jan 2026 20:43:49 +0100 Subject: [PATCH 06/34] harcoded absmax to output float --- mlx/backend/cuda/quantized/qqmm.cpp | 2 +- mlx/backend/cuda/reduce/reduce_ops.cuh | 8 +++++--- python/src/ops.cpp | 3 +++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 9eb52664a2..c008af3e44 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -223,7 +223,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { int K = x_q.shape(-1) * (32 / bits_); bool x_transposed = false; - bool w_transposed = true; // weights are always transposed + bool w_transposed = true; // always transposed int64_t lda = K; int64_t ldb = K; diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 91fbe9460a..7b55bdc525 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -166,9 +166,10 @@ struct ReduceResult { using type = T; }; +// TODO: this should not be hardcoded template struct ReduceResult { - using type = T; + using type = float; }; // Traits to get the init value of reduce op. @@ -227,8 +228,9 @@ struct ReduceInit { template struct ReduceInit { - static constexpr __host__ __device__ T value() { - return T(0); // abs values are >= 0 + using result_type = typename ReduceResult::type; + static constexpr __host__ __device__ result_type value() { + return result_type(0); // abs values are >= 0 } }; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f9a7c1585a..31cda08d8c 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5474,6 +5474,7 @@ void init_ops(nb::module_& m) { nb::arg(), // x nb::arg(), // w_q "scales"_a = nb::none(), // scales w + "tensor_scale"_a = nb::none(), "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "nvfp4", @@ -5504,6 +5505,8 @@ void init_ops(nb::module_& m) { w (array): Weight matrix. If quantized, it is packed in unsigned integers. scales (array, optional): The scales to use per ``group_size`` elements of ``w`` if ``w`` is quantized. Default: ``None``. + tensor_scale (array, optional): The tensor-wide scale to use if ``w`` is + quantized and ``mode="nvfp4"``. Default: ``None``. group_size (int, optional): Number of elements in ``x`` and ``w`` that share a scale. 
See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. From 11ff19ad586556d2309668685a397ebbbf680a05 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 20 Jan 2026 20:51:55 +0100 Subject: [PATCH 07/34] fixed ops python dequantize --- python/src/ops.cpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 31cda08d8c..c9550a7bca 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4346,18 +4346,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "dequantize", - [](const mx::array& w, - const mx::array& scales, - const std::optional& tensor_scale, - const std::optional& biases, - std::optional group_size, - std::optional bits, - const std::string& mode, - std::optional dtype, - mx::StreamOrDevice s) { - return mx::dequantize( - w, scales, tensor_scale, biases, group_size, bits, mode, dtype, s); - }, + &mx::dequantize, nb::arg(), "scales"_a, "tensor_scale"_a = nb::none(), From 2a86dc16c64d0437d909730a93a7d09b63bcf023 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 21 Jan 2026 00:52:42 +0100 Subject: [PATCH 08/34] input global_scale --- mlx/backend/cuda/quantized/fp_quantize.cu | 32 ++-- mlx/backend/cuda/quantized/qqmm.cpp | 70 +++----- mlx/backend/cuda/quantized/quantized.cpp | 17 +- mlx/ops.cpp | 194 ++++++++++------------ 4 files changed, 137 insertions(+), 176 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 7f574d7fa4..66d6465f16 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -32,20 +32,20 @@ struct Dequantize { }; namespace cg = cooperative_groups; - +// TODO: global_scale type template __global__ void fp_quantize( T* w, uint8_t* out, uint8_t* scales, size_t size, - float* tensor_amax = nullptr) { + float* global_scale = nullptr) { // NVFP4 conversion: - // Global encode scale: (448 × 6) / *tensor_amax + // Global encode scale: (448 × 6) / *global_scale // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b const float scale_enc = - !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *tensor_amax : 1.0f; + !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use @@ -113,7 +113,7 @@ __global__ void fp_dequantize( const uint8_t* scales, T* out, size_t size, - float* tensor_amax = nullptr) { + float* global_scale = nullptr) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); @@ -125,7 +125,7 @@ __global__ void fp_dequantize( constexpr int pack_factor = bits == 8 ? 1 : 2; const float inv_scale_enc = - use_mx_scale ? 1.0f : (*tensor_amax) / (F8E4M3_MAX * F4E2M1_MAX); + use_mx_scale ? 
1.0f : (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX); size_t offset = tidx + grid_dim_x * size_t(tidy); size_t oindex = offset * pack_factor; @@ -159,14 +159,14 @@ void fp_quantize( const array& w, array& wq, array& scales, - const std::optional& tensor_amax /* = std::nullopt */, + const std::optional& global_scale /* = std::nullopt */, int group_size, int bits, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); - if (tensor_amax.has_value()) { - enc.set_input_array(tensor_amax.value()); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); } enc.set_output_array(wq); enc.set_output_array(scales); @@ -192,8 +192,8 @@ void fp_quantize( gpu_ptr(wq), gpu_ptr(scales), w.size(), - tensor_amax.has_value() ? gpu_ptr(tensor_amax.value()) - : nullptr); + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); @@ -204,7 +204,7 @@ void fp_quantize( void fp_dequantize( const array& wq, const array& scales, - const std::optional& tensor_amax /* = std::nullopt */, + const std::optional& global_scale /* = std::nullopt */, array& w, int group_size, int bits, @@ -220,8 +220,8 @@ void fp_dequantize( enc.set_input_array(wq); enc.set_input_array(scales); - if (tensor_amax.has_value()) { - enc.set_input_array(tensor_amax.value()); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); } enc.set_output_array(w); dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) { @@ -244,8 +244,8 @@ void fp_dequantize( gpu_ptr(scales), gpu_ptr(w), w.size(), - tensor_amax.has_value() ? gpu_ptr(tensor_amax.value()) - : nullptr); + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not dequantize to output with type float64."); diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index c008af3e44..691731c8ba 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -13,8 +13,6 @@ namespace mlx::core { namespace { -using QuantizedTensor = std::tuple>; - struct GemmScalars { std::optional alpha_device; std::optional beta_device; @@ -58,13 +56,14 @@ array pad_and_swizzle_scales( return scale_tiled; } -QuantizedTensor quantize_input( +std::tuple quantize_input( const array& input, cu::CommandEncoder& encoder, const Stream& s, QuantizationMode mode, int bits, - int group_size) { + int group_size, + std::optional global_scale = std::nullopt) { const array x = ensure_row_contiguous(input, encoder, s); // Compute output shapes @@ -90,44 +89,14 @@ QuantizedTensor quantize_input( cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); encoder.add_temporary(x_q); encoder.add_temporary(scales_x); - - // For NVFP4: compute tensor-wide amax for global scaling - std::optional tensor_amax = std::nullopt; - if (mode == QuantizationMode::Nvfp4) { - array amax(cu::malloc_async(sizeof(float), encoder), {}, float32); - encoder.add_temporary(amax); - all_reduce(encoder, x, amax, Reduce::ReduceType::AbsMax); - tensor_amax = amax; - } - - fp_quantize(x, x_q, scales_x, tensor_amax, group_size, bits, encoder, s); - return {std::move(x_q), std::move(scales_x), std::move(tensor_amax)}; -} - -QuantizedTensor get_weight_tensors( - const std::vector& inputs, - cu::CommandEncoder& encoder, - const Stream& s, - QuantizationMode mode, - int bits, - int group_size) { - // Check if weights need quantization - 
if (inputs[1].dtype() != uint32) { - return quantize_input(inputs[1], encoder, s, mode, bits, group_size); - } - - // Weights are pre-quantized - if (mode == QuantizationMode::Nvfp4) { - // NVFP4: inputs = [x, w_q, scale_w, tensor_amax_w] - return {inputs[1], inputs[2], inputs[3]}; - } - // MXFP8: inputs = [x, w_q, scale_w] (no tensor_amax) - return {inputs[1], inputs[2], std::nullopt}; + // global_scale is not nullopt only for NVFP4 + fp_quantize(x, x_q, scales_x, global_scale, group_size, bits, encoder, s); + return {std::move(x_q), std::move(scales_x)}; } GemmScalars create_nvfp4_scalars( - const array& tensor_amax_x, - const array& tensor_amax_w, + const array& global_scale_x, + const array& global_scale_w, cu::CommandEncoder& encoder) { // NVFP4 requires alpha/beta as device pointers // alpha = amax_x * amax_w / (448 * 6)^2 @@ -204,17 +173,24 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[QQMatmul::eval_gpu] QQMM requires compute capability 10.0+"); } + // input size = 2 for non-quantized w for qmode != nvfp4 + // input size = 3 for quantized w for qmode != nvfp4 + // input size = 4 for non-quantized w for qmode == nvfp4 + // input size = 5 for quantized w for qmode == nvfp4 + auto num_amax_inputs = mode_ == QuantizationMode::Nvfp4 ? 2 : 0; + auto size = + inputs[1].dtype() == uint32 ? 3 + num_amax_inputs : 2 + num_amax_inputs; - size_t expected_size = (mode_ == QuantizationMode::Nvfp4) ? 4 : 3; - assert( - (inputs.size() == expected_size && inputs[1].dtype() == uint32) || - (inputs.size() == 2)); + assert(inputs.size() == size); // Quantize inputs (or use pre-quantized) - auto [x_q, scale_x_pre, tensor_amax_x] = + auto [x_q, scale_x_pre] = quantize_input(inputs[0], encoder, s, mode_, bits_, group_size_); - auto [w_q, scale_w_pre, tensor_amax_w] = - get_weight_tensors(inputs, encoder, s, mode_, bits_, group_size_); + auto [w_q, scale_w_pre] = inputs[1].dtype() != uint32 + ? quantize_input(inputs[1], encoder, s, mode_, bits_, group_size_) + : std::make_tuple( + ensure_row_contiguous(inputs[1], encoder, s), + ensure_row_contiguous(inputs[2], encoder, s)); out.set_data(cu::malloc_async(out.nbytes(), encoder)); @@ -233,7 +209,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { GemmScalars scalars; if (mode_ == QuantizationMode::Nvfp4) { - scalars = create_nvfp4_scalars(*tensor_amax_x, *tensor_amax_w, encoder); + scalars = create_nvfp4_scalars(inputs[size - 2], inputs[size - 1], encoder); } run_qqmm( diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 88d7839bf2..cf48fc7f99 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -64,11 +64,11 @@ void fast::Quantize::eval_gpu( if (mode_ == QuantizationMode::Affine) { auto biases = ensure_row_contiguous(inputs[2], enc, s); affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); - } else if (mode_ == QuantizationMode::Nvfp4) { - auto tensor_amax = inputs[2]; - fp_dequantize(wq, scales, tensor_amax, w, group_size_, bits_, enc, s); } else { - fp_dequantize(wq, scales, {}, w, group_size_, bits_, enc, s); + // third input is global scale for nvfp4 + auto global_scale = + mode_ == QuantizationMode::Nvfp4 ? 
inputs[2] : std::nullopt; + fp_dequantize(wq, scales, global_scale, w, group_size_, bits_, enc, s); } } else { auto w = ensure_row_contiguous(inputs[0], enc, s); @@ -82,13 +82,10 @@ void fast::Quantize::eval_gpu( auto& biases = outputs[2]; biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); - } else if (mode_ == QuantizationMode::Nvfp4) { - auto& tensor_amax = outputs[2]; - tensor_amax.set_data(cu::malloc_async(tensor_amax.nbytes(), enc)); - all_reduce(enc, w, tensor_amax, Reduce::ReduceType::AbsMax); - fp_quantize(w, wq, scales, tensor_amax, group_size_, bits_, enc, s); } else { - fp_quantize(w, wq, scales, {}, group_size_, bits_, enc, s); + auto global_scale = + mode_ == QuantizationMode::Nvfp4 ? inputs[2] : std::nullopt; + fp_quantize(w, wq, scales, global_scale, group_size_, bits_, enc, s); } } } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 04a736c191..896d42e348 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4207,6 +4207,39 @@ std::pair validate_mode_with_type( } } +void validate_global_scale(const array& global_scale, QuantizationMode qmode) { + if (global_scale.has_value()) { + if (qmode != QuantizationMode::Nvfp4) { + std::ostringstream msg; + msg << "[quantize] Global scale can only be provided for 'nvfp4' " + << "quantization mode but mode '" << mode << "' was provided."; + throw std::invalid_argument(msg.str()); + } else { + if (global_scale->size() != 1) { + std::ostringstream msg; + msg << "[quantize] Global scale must be a scalar but got an array " + << "with shape " << global_scale->shape() << "."; + throw std::invalid_argument(msg.str()); + } + // TODO: not sure about the type + if (!issubdtype(global_scale->dtype(), floating)) { + std::ostringstream msg; + msg << "[quantize] Global scale must be a floating type but got type " + << global_scale->dtype() << "."; + throw std::invalid_argument(msg.str()); + } + } + } else { + if (qmode == QuantizationMode::Nvfp4) { + std::ostringstream msg; + msg << "[quantize] Global scale must be provided for 'nvfp4' " + << "quantization mode."; + throw std::invalid_argument(msg.str()); + } + return; + } +} + array quantized_matmul( array x, array w, @@ -4264,9 +4297,10 @@ void validate_qqmm_inputs( array x, array w, std::optional scales_w, - std::optional tensor_amax_w, int group_size, int bits, + std::optional global_scale_x, + std::optional global_scale_w, QuantizationMode qmode) { // check 2D (for now) if (x.ndim() > 2 || w.ndim() > 2) { @@ -4283,28 +4317,11 @@ void validate_qqmm_inputs( throw std::invalid_argument( "[qqmm] Scales must be provided if second argument is quantized."); } - if (qmode == QuantizationMode::Nvfp4) { - // nvfp4 quantization requires tensor amax - if (!tensor_amax_w.has_value()) { - std::ostringstream msg; - msg << "[qqmm] The 'tensor_amax_w' argument must be provided" - << " with 'nvfp4' quantization mode."; - throw std::invalid_argument(msg.str()); - } - } // if scales are provided, check compatibility with quantized w else { validate_quantized_input("qqmm", w, *scales_w, group_size, bits); - // other quantization modes do not support tensor amax - if (tensor_amax_w.has_value()) { - std::ostringstream msg; - msg << "[qqmm] The 'tensor_amax_w' argument is only supported" - << " with 'nvfp4' quantization mode."; - throw std::invalid_argument(msg.str()); - } } } - // if w is not quantized, dtype must be in {f16, bf16, fp32} else { if (!issubdtype(w.dtype(), floating) || w.dtype() == float64) { @@ -4321,6 +4338,9 @@ void validate_qqmm_inputs( 
<< "first argument dtype == " << x.dtype() << "."; throw std::invalid_argument(msg.str()); } + // validate global scales + validate_global_scale(global_scale_x, qmode); + validate_global_scale(global_scale_w, qmode); } std::pair extract_qqmm_dims( @@ -4357,10 +4377,11 @@ array qqmm( array in_x, array w, std::optional scales_w, - std::optional tensor_amax_w, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, + const std::optional global_scale_x /* = std::nullopt */, + const std::optional global_scale_w /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); if (stream.device != Device::gpu || !cu::is_available()) { @@ -4389,8 +4410,10 @@ array qqmm( } else if (w.ndim() == 2 && x.ndim() > 2) { x = flatten(x, 0, -2, s); } + // validate inputs - validate_qqmm_inputs(x, w, scales_w, tensor_amax_w, group_size, bits, qmode); + validate_qqmm_inputs( + x, w, scales_w, group_size, bits, global_scale_x, global_scale_w); // validate and extract shapes auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims(x, w, scales_w, group_size, bits); @@ -4401,8 +4424,9 @@ array qqmm( if (scales_w.has_value()) { inputs.push_back(*scales_w); } - if (tensor_amax_w.has_value()) { - inputs.push_back(*tensor_amax_w); + if (global_scale_x.has_value() && global_scale_w.has_value()) { + inputs.push_back(*global_scale_x); + inputs.push_back(*global_scale_w); } auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; @@ -4539,6 +4563,7 @@ std::vector fp_quantize( int group_size, int bits, QuantizationMode mode, + const std::optional& global_scale /* = std::nullopt */, Stream s) { int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32; int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4; @@ -4556,18 +4581,32 @@ std::vector fp_quantize( << bits << "."; throw std::invalid_argument(msg.str()); } - constexpr float F8E4M3_MAX = 448.0f; - constexpr float F4E2M1_MAX = 6.0f; - auto fallback = [bits = bits, group_size = group_size, mode = mode, s]( + auto fallback = [bits = bits, group_size = group_size, s]( const std::vector& inputs) -> std::vector { auto& w = inputs[0]; - float maxval = (bits == 4) ? F4E2M1_MAX : F8E4M3_MAX; + float maxval = (bits == 4) ? 
6.0f : 448.0f; auto new_shape = w.shape(); new_shape.back() = -1; auto wq = reshape(w, {-1, group_size}, s); - auto block_amax = max(abs(wq, s), -1, true, s); - - auto quantize_to_fp4 = [&](array& wq_in) { + auto scales = + divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s); + if (group_size == 16) { + // convert to e4m3 + scales = to_fp8(scales, s); + wq = divide(wq, from_fp8(scales, w.dtype(), s), s); + } else { + // convert to e8m0 + auto z = array(0, scales.dtype()); + scales = where( + equal(scales, z, s), + z, + astype(round(log2(scales, s), s), int32, s), + s); + + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + } + if (bits == 4) { auto lut = array({ +0.0f, +0.5f, @@ -4587,41 +4626,11 @@ std::vector fp_quantize( -6.0f, }); lut = astype(lut, w.dtype(), s); - wq_in = argmin( - abs(subtract(expand_dims(wq_in, -1, s), lut, s), s), -1, false, s); + wq = argmin( + abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s); auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); - wq_in = reshape(wq_in, {-1, 4, 8}, s); - wq_in = sum(multiply(wq_in, shifts, s), -1, false, s); - }; - - if (mode == QuantizationMode::Nvfp4) { - auto tensor_amax = astype(max(abs(w, s), s), float32, s); - // Global encode scale: (448 * 6) / tensor_amax - auto scale_enc = divide(array(F8E4M3_MAX * F4E2M1_MAX), tensor_amax, s); - // Per-block decode scale: (block_amax / 6) * scale_enc - auto scales = multiply( - divide(block_amax, array(F4E2M1_MAX, w.dtype()), s), scale_enc, s); - // Convert to e4m3 - scales = to_fp8(scales, s); - // Per-block encode scale: scale_enc / scale_dec_b - auto scale_enc_b = divide(scale_enc, from_fp8(scales, w.dtype(), s), s); - wq = multiply(wq, scale_enc_b, s); - quantize_to_fp4(wq); - wq = reshape(wq, new_shape, s); - scales = reshape(scales, new_shape, s); - return {std::move(wq), std::move(scales), std::move(tensor_amax)}; - } - - auto scales = divide(block_amax, array(maxval, w.dtype()), s); - auto z = array(0, scales.dtype()); - scales = where( - equal(scales, z, s), z, astype(round(log2(scales, s), s), int32, s), s); - - wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); - scales = astype(add(scales, array(127, int32), s), uint8, s); - - if (bits == 4) { - quantize_to_fp4(wq); + wq = reshape(wq, {-1, 4, 8}, s); + wq = sum(multiply(wq, shifts, s), -1, false, s); } else { wq = view(to_fp8(wq, s), uint32, s); } @@ -4629,29 +4638,24 @@ std::vector fp_quantize( scales = reshape(scales, new_shape, s); return {std::move(wq), std::move(scales)}; }; + auto inputs = std::vector{w}; + if (global_scale.has_value()) { + inputs.push_back(global_scale.value()); + } if (s.device == Device::gpu) { auto wq_shape = w.shape(); wq_shape.back() = w.shape(-1) * bits / 32; auto sshape = w.shape(); sshape.back() = w.shape(-1) / group_size; - // nvfp4 fp tensor scale - // TODO: should we try to have w.dtype() here? 
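As a quick illustration of the packing arithmetic used by the GPU path in this hunk (packed words follow w.shape(-1) * bits / 32 and block scales follow w.shape(-1) / group_size), here is a minimal sketch using the Python bindings introduced later in this series; the shapes and the global_scale argument are taken from the diffs, and the example is illustrative only, not part of the patch:

import mlx.core as mx

K = 128
w = mx.random.normal(shape=(64, K), dtype=mx.bfloat16)
w_amax = mx.abs(w).max().astype(mx.float32)  # tensor-wide amax used as the NVFP4 global scale
w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4", global_scale=w_amax)
# 4-bit codes pack 8 values per uint32: K * 4 / 32 = 16 packed words per row
# one FP8 (e4m3) scale per 16-element block: K / 16 = 8 scales per row
print(w_q.shape, scales.shape)  # expected: (64, 16) (64, 8)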
- std::vector shapes = {std::move(wq_shape), std::move(sshape)}; - std::vector dtypes = {uint32, uint8}; - - if (mode == QuantizationMode::Nvfp4) { - shapes.push_back({}); - dtypes.push_back(float32); - } return array::make_arrays( - std::move(shapes), - std::move(dtypes), + {std::move(wq_shape), std::move(sshape)}, + {uint32, uint8}, std::make_shared( s, fallback, group_size, bits, mode, false), - {w}); + inputs); } - return fallback({w}); + return fallback(inputs); } std::vector quantize( @@ -4659,6 +4663,7 @@ std::vector quantize( std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, + const std::optional& global_scale /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto qmode = string_to_quantization_mode(mode, "quantize"); auto [group_size, bits] = @@ -4686,10 +4691,12 @@ std::vector quantize( throw std::invalid_argument(msg.str()); } + validate_global_scale("quantize", global_scale, qmode); + if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); } else { - return fp_quantize(w, group_size, bits, qmode, to_stream(s)); + return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s)); } } @@ -4790,11 +4797,11 @@ array affine_dequantize( array fp_dequantize( const array& w, const array& scales, - const std::optional& tensor_amax, int group_size, int bits, Dtype out_type, QuantizationMode mode, + const std::optional& global_scale /* = std::nullopt */, Stream s) { int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32; int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4; @@ -4839,15 +4846,12 @@ array fp_dequantize( throw std::invalid_argument(msg.str()); } - constexpr float F8E4M3_MAX = 448.0f; - constexpr float F4E2M1_MAX = 6.0f; auto fallback = [wshape = std::move(wshape), sshape = std::move(sshape), group_size, bits, out_type, - mode, s](const std::vector& inputs) mutable -> std::vector { auto out = inputs[0]; auto scales = inputs[1]; @@ -4883,16 +4887,7 @@ array fp_dequantize( } out = reshape(out, {-1, group_size}, s); scales = reshape(scales, {-1, 1}, s); - if (mode == QuantizationMode::Nvfp4) { - auto tensor_amax = inputs[2]; - // scale_dec_b stored as FP8 e4m3 - scales = from_fp8(scales, out_type, s); - // inv_scale_enc = tensor_amax / (448 * 6) - auto inv_scale_enc = - divide(tensor_amax, array(F8E4M3_MAX * F4E2M1_MAX, out_type), s); - // final scale = scale_dec_b * inv_scale_enc - scales = multiply(scales, inv_scale_enc, s); - } else if (mode == QuantizationMode::Mxfp4) { + if (group_size == 16) { scales = from_fp8(scales, out_type, s); } else { scales = subtract(astype(scales, out_type, s), array(127, out_type), s); @@ -4903,20 +4898,12 @@ array fp_dequantize( if (s.device == Device::gpu) { auto out_shape = w.shape(); out_shape.back() = out_size; - auto inputs = std::vector{w, scales}; - if (mode == QuantizationMode::Nvfp4) { - inputs.push_back(*tensor_amax); - } - return array( std::move(out_shape), out_type, std::make_shared( s, fallback, group_size, bits, mode, true), - std::move(inputs)); - } - if (mode == QuantizationMode::Nvfp4) { - return fallback({w, scales, *tensor_amax})[0]; + {w, scales}); } return fallback({w, scales})[0]; } @@ -4924,11 +4911,11 @@ array fp_dequantize( array dequantize( const array& w, const array& scales, - const std::optional& tensor_amax /* = std::nullopt */, const std::optional& biases /* = std::nullopt */, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const 
std::string& mode /* = "affine" */, + const std::optional& global_scale /* = std::nullopt */, std::optional dtype /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto [out_type, qmode] = @@ -4955,6 +4942,7 @@ array dequantize( << "but it has only " << w.ndim() << "."; throw std::invalid_argument(msg.str()); } + validate_global_scale("dequantize", global_scale, qmode); if (qmode == QuantizationMode::Affine) { return astype( @@ -4965,7 +4953,7 @@ array dequantize( return fp_dequantize( w, scales, - tensor_amax, + global_scale, group_size, bits, out_type, @@ -6169,4 +6157,4 @@ array contiguous( {a}); } -} // namespace mlx::core +} // namespace mlx::core \ No newline at end of file From 2c68fb62a030e0b9774b029704fc23aff0d194b0 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 21 Jan 2026 16:22:41 +0100 Subject: [PATCH 09/34] fix global_scale --- mlx/backend/cuda/quantized/fp_quantize.cu | 4 +- mlx/backend/cuda/quantized/qqmm.cpp | 21 ++-- mlx/backend/cuda/quantized/qqmm_utils.cu | 8 +- mlx/backend/cuda/quantized/quantized.cpp | 14 +-- mlx/backend/cuda/quantized/quantized.h | 4 +- mlx/ops.cpp | 111 ++++++++++++++-------- mlx/ops.h | 12 +-- mlx/primitives.cpp | 37 ++++++-- python/src/ops.cpp | 24 +++-- tests/ops_tests.cpp | 2 +- 10 files changed, 154 insertions(+), 83 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 66d6465f16..c2c7aa7ab6 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -159,9 +159,9 @@ void fp_quantize( const array& w, array& wq, array& scales, - const std::optional& global_scale /* = std::nullopt */, int group_size, int bits, + const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); @@ -204,10 +204,10 @@ void fp_quantize( void fp_dequantize( const array& wq, const array& scales, - const std::optional& global_scale /* = std::nullopt */, array& w, int group_size, int bits, + const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { constexpr int uint8_per_uint32 = 4; diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 691731c8ba..c9a6ea9523 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -90,7 +90,7 @@ std::tuple quantize_input( encoder.add_temporary(x_q); encoder.add_temporary(scales_x); // global_scale is not nullopt only for NVFP4 - fp_quantize(x, x_q, scales_x, global_scale, group_size, bits, encoder, s); + fp_quantize(x, x_q, scales_x, group_size, bits, global_scale, encoder, s); return {std::move(x_q), std::move(scales_x)}; } @@ -103,7 +103,7 @@ GemmScalars create_nvfp4_scalars( // beta = 0 array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32); array beta(cu::malloc_async(sizeof(float), encoder), {}, float32); - compute_qqmm_pointers(alpha, beta, tensor_amax_x, tensor_amax_w, encoder); + compute_qqmm_pointers(alpha, beta, global_scale_x, global_scale_w, encoder); encoder.add_temporary(alpha); encoder.add_temporary(beta); return {alpha, beta}; @@ -183,11 +183,20 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == size); + // For nvfp4, get global scales from inputs + std::optional global_scale_x = std::nullopt; + std::optional global_scale_w = std::nullopt; + if (mode_ == QuantizationMode::Nvfp4) { + global_scale_x = inputs[size - 2]; + global_scale_w = inputs[size - 1]; + } + // 
Quantize inputs (or use pre-quantized) - auto [x_q, scale_x_pre] = - quantize_input(inputs[0], encoder, s, mode_, bits_, group_size_); + auto [x_q, scale_x_pre] = quantize_input( + inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); auto [w_q, scale_w_pre] = inputs[1].dtype() != uint32 - ? quantize_input(inputs[1], encoder, s, mode_, bits_, group_size_) + ? quantize_input( + inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w) : std::make_tuple( ensure_row_contiguous(inputs[1], encoder, s), ensure_row_contiguous(inputs[2], encoder, s)); @@ -209,7 +218,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { GemmScalars scalars; if (mode_ == QuantizationMode::Nvfp4) { - scalars = create_nvfp4_scalars(inputs[size - 2], inputs[size - 1], encoder); + scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder); } run_qqmm( diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index d429adbbb7..d19865a3b3 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -254,10 +254,10 @@ void compute_qqmm_pointers( dim3(1), dim3(1), 0, - gpu_ptr(alpha_out), - gpu_ptr(beta_out), - gpu_ptr(tensor_amax_x), - gpu_ptr(tensor_amax_w)); + gpu_ptr(alpha_out), + gpu_ptr(beta_out), + gpu_ptr(tensor_amax_x), + gpu_ptr(tensor_amax_w)); } } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index cf48fc7f99..e101491f6c 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -66,9 +66,10 @@ void fast::Quantize::eval_gpu( affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); } else { // third input is global scale for nvfp4 - auto global_scale = - mode_ == QuantizationMode::Nvfp4 ? inputs[2] : std::nullopt; - fp_dequantize(wq, scales, global_scale, w, group_size_, bits_, enc, s); + std::optional global_scale = mode_ == QuantizationMode::Nvfp4 + ? std::make_optional(inputs[2]) + : std::nullopt; + fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s); } } else { auto w = ensure_row_contiguous(inputs[0], enc, s); @@ -83,9 +84,10 @@ void fast::Quantize::eval_gpu( biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); } else { - auto global_scale = - mode_ == QuantizationMode::Nvfp4 ? inputs[2] : std::nullopt; - fp_quantize(w, wq, scales, global_scale, group_size_, bits_, enc, s); + std::optional global_scale = mode_ == QuantizationMode::Nvfp4 + ? 
std::make_optional(inputs[1]) + : std::nullopt; + fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s); } } } diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h index fbeb62d6f1..93fba563a5 100644 --- a/mlx/backend/cuda/quantized/quantized.h +++ b/mlx/backend/cuda/quantized/quantized.h @@ -30,19 +30,19 @@ void fp_quantize( const array& w, array& wq, array& scales, - const std::optional& tensor_amax, int group_size, int bits, + const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); void fp_dequantize( const array& wq, const array& scales, - const std::optional& tensor_amax, array& w, int group_size, int bits, + const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 896d42e348..b57fc71fee 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4207,24 +4207,28 @@ std::pair validate_mode_with_type( } } -void validate_global_scale(const array& global_scale, QuantizationMode qmode) { +void validate_global_scale( + std::string_view tag, + QuantizationMode qmode, + const std::optional& global_scale) { if (global_scale.has_value()) { if (qmode != QuantizationMode::Nvfp4) { std::ostringstream msg; - msg << "[quantize] Global scale can only be provided for 'nvfp4' " - << "quantization mode but mode '" << mode << "' was provided."; + msg << "[" << tag << "] Global scale is only supported for 'nvfp4' " + << "quantization mode."; throw std::invalid_argument(msg.str()); } else { if (global_scale->size() != 1) { std::ostringstream msg; - msg << "[quantize] Global scale must be a scalar but got an array " - << "with shape " << global_scale->shape() << "."; + msg << "[" << tag << "] Global scale must be a scalar but got shape " + << global_scale->shape() << "."; throw std::invalid_argument(msg.str()); } // TODO: not sure about the type if (!issubdtype(global_scale->dtype(), floating)) { std::ostringstream msg; - msg << "[quantize] Global scale must be a floating type but got type " + msg << "[" << tag + << "] Global scale must be a floating type but got type " << global_scale->dtype() << "."; throw std::invalid_argument(msg.str()); } @@ -4232,7 +4236,7 @@ void validate_global_scale(const array& global_scale, QuantizationMode qmode) { } else { if (qmode == QuantizationMode::Nvfp4) { std::ostringstream msg; - msg << "[quantize] Global scale must be provided for 'nvfp4' " + msg << "[" << tag << "] Global scale must be provided for 'nvfp4' " << "quantization mode."; throw std::invalid_argument(msg.str()); } @@ -4282,7 +4286,7 @@ array quantized_matmul( if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } - + // TODO: add global scale auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; return array( @@ -4339,8 +4343,8 @@ void validate_qqmm_inputs( throw std::invalid_argument(msg.str()); } // validate global scales - validate_global_scale(global_scale_x, qmode); - validate_global_scale(global_scale_w, qmode); + validate_global_scale("qqmm", qmode, global_scale_x); + validate_global_scale("qqmm", qmode, global_scale_w); } std::pair extract_qqmm_dims( @@ -4413,7 +4417,7 @@ array qqmm( // validate inputs validate_qqmm_inputs( - x, w, scales_w, group_size, bits, global_scale_x, global_scale_w); + x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode); // validate and extract shapes auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims(x, w, scales_w, group_size, bits); @@ -4581,32 +4585,35 @@ std::vector 
fp_quantize( << bits << "."; throw std::invalid_argument(msg.str()); } - auto fallback = [bits = bits, group_size = group_size, s]( + auto fallback = [bits = bits, group_size = group_size, mode = mode, s]( const std::vector& inputs) -> std::vector { - auto& w = inputs[0]; + auto w = inputs[0]; float maxval = (bits == 4) ? 6.0f : 448.0f; + auto new_shape = w.shape(); new_shape.back() = -1; auto wq = reshape(w, {-1, group_size}, s); - auto scales = - divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s); + auto group_amax = max(abs(wq, s), -1, true, s); + if (group_size == 16) { + // NVFP4: scale_dec = (group_amax / 6) * (448 * 6) / global_scale + // = group_amax * 448 / global_scale + array scales; + if (mode == QuantizationMode::Nvfp4 && inputs.size() > 1) { + // scale_dec = group_amax * 448 / global_scale + scales = divide( + multiply(group_amax, array(448.0f, w.dtype()), s), inputs[1], s); + } else { + // Without global_scale: scale_dec = group_amax / 6 + scales = divide(group_amax, array(maxval, w.dtype()), s); + } // convert to e4m3 scales = to_fp8(scales, s); - wq = divide(wq, from_fp8(scales, w.dtype(), s), s); - } else { - // convert to e8m0 - auto z = array(0, scales.dtype()); - scales = where( - equal(scales, z, s), - z, - astype(round(log2(scales, s), s), int32, s), - s); + // quantized = w * 6 / group_amax + wq = divide(multiply(wq, array(maxval, w.dtype()), s), group_amax, s); + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); - wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); - scales = astype(add(scales, array(127, int32), s), uint8, s); - } - if (bits == 4) { auto lut = array({ +0.0f, +0.5f, @@ -4631,12 +4638,26 @@ std::vector fp_quantize( auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); wq = reshape(wq, {-1, 4, 8}, s); wq = sum(multiply(wq, shifts, s), -1, false, s); + wq = reshape(wq, new_shape, s); + + return {std::move(wq), std::move(scales)}; } else { + auto scales = divide(group_amax, array(maxval, w.dtype()), s); + auto z = array(0, scales.dtype()); + scales = where( + equal(scales, z, s), + z, + astype(round(log2(scales, s), s), int32, s), + s); + + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + wq = view(to_fp8(wq, s), uint32, s); + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); + return {std::move(wq), std::move(scales)}; } - wq = reshape(wq, new_shape, s); - scales = reshape(scales, new_shape, s); - return {std::move(wq), std::move(scales)}; }; auto inputs = std::vector{w}; @@ -4691,8 +4712,7 @@ std::vector quantize( throw std::invalid_argument(msg.str()); } - validate_global_scale("quantize", global_scale, qmode); - + validate_global_scale("quantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); } else { @@ -4852,6 +4872,7 @@ array fp_dequantize( group_size, bits, out_type, + mode = mode, s](const std::vector& inputs) mutable -> std::vector { auto out = inputs[0]; auto scales = inputs[1]; @@ -4888,13 +4909,27 @@ array fp_dequantize( out = reshape(out, {-1, group_size}, s); scales = reshape(scales, {-1, 1}, s); if (group_size == 16) { + // NVFP4: decode scale from fp8 e4m3 scales = from_fp8(scales, out_type, s); + // For nvfp4 with global_scale: effective_scale = scale * global_scale / + // (448 * 6) + if (mode == QuantizationMode::Nvfp4 && inputs.size() > 2) { + scales = divide( + multiply(scales, inputs[2], s), 
array(448.0f * 6.0f, out_type), s); + } } else { + // MXFP8: decode e8m0 scale scales = subtract(astype(scales, out_type, s), array(127, out_type), s); scales = power(array(2.0f, out_type), scales, s); } - return {reshape(multiply(out, scales, s), wshape, s)}; + out = reshape(multiply(out, scales, s), wshape, s); + + return {out}; }; + auto inputs = std::vector{w, scales}; + if (global_scale.has_value()) { + inputs.push_back(global_scale.value()); + } if (s.device == Device::gpu) { auto out_shape = w.shape(); out_shape.back() = out_size; @@ -4903,9 +4938,9 @@ array fp_dequantize( out_type, std::make_shared( s, fallback, group_size, bits, mode, true), - {w, scales}); + inputs); } - return fallback({w, scales})[0]; + return fallback(inputs)[0]; } array dequantize( @@ -4942,7 +4977,7 @@ array dequantize( << "but it has only " << w.ndim() << "."; throw std::invalid_argument(msg.str()); } - validate_global_scale("dequantize", global_scale, qmode); + validate_global_scale("dequantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return astype( @@ -4953,11 +4988,11 @@ array dequantize( return fp_dequantize( w, scales, - global_scale, group_size, bits, out_type, qmode, + global_scale, to_stream(s)); } } diff --git a/mlx/ops.h b/mlx/ops.h index 7f2c65993f..eea6b93377 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1407,31 +1407,31 @@ std::vector quantize( std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "affine", + const std::optional& global_scale = std::nullopt, StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ array dequantize( const array& w, const array& scales, - const std::optional& tensor_amax = std::nullopt, const std::optional& biases = std::nullopt, std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "affine", + const std::optional& global_scale = std::nullopt, std::optional dtype = std::nullopt, StreamOrDevice s = {}); array qqmm( array x, // input activations array w, // maybe quantized weights - std::optional w_scales = std::nullopt, // optional scales if w is - // quantized - std::optional w_tensor_scale = - std::nullopt, // optional tensor amax if - // w is nvfp4 quantized + const std::optional w_scales = std::nullopt, // optional scales if w + // is quantized std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", + const std::optional global_scale_x = std::nullopt, + const std::optional global_scale_w = std::nullopt, StreamOrDevice s = {}); /** Convert an E4M3 float8 to the given floating point dtype. 
*/ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 984456eeb9..8bd1fe68d5 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3421,10 +3421,10 @@ std::vector QuantizedMatmul::vjp( primals[1], ones_like(primals[2], stream()), zeros_like(primals[3], stream()), - {}, group_size_, bits_, quantization_mode_to_string(mode_), + {}, // placeholder for amax std::nullopt, stream()); wq = unflatten(wq, -1, {-1, group_size_}, stream()); @@ -3485,14 +3485,24 @@ std::vector QQMatmul::output_shapes(const std::vector& inputs) { } std::vector QQMatmul::vjp( - const std::vector& primals, // non quantized x, non quantized w + const std::vector& primals, // non quantized x, non quantized w, if + // nvfp4 global_scale_x, global_scale_w const std::vector& cotangents, // non quantized upstream grads const std::vector& argnums, const std::vector&) { - if (primals.size() != 2) { - throw std::runtime_error( - "[QQMatmul::vjp] Expected exactly 2 non-quantized primal inputs (x, w)."); + bool is_nvfp4 = (mode_ == QuantizationMode::Nvfp4); + auto expected_size = is_nvfp4 ? 4 : 2; + if (primals.size() != expected_size) { + auto msg = std::ostringstream(); + msg << "[QQMatmul::vjp] Expected exactly " << expected_size + << " non-quantized primal inputs (x, w"; + if (mode_ == QuantizationMode::Nvfp4) { + msg << ", global_scale_x, global_scale_w"; + } + msg << ")."; + throw std::runtime_error(msg.str()); } + std::vector vjps; auto& cotan = cotangents[0]; std::vector reorder(cotan.ndim()); @@ -3503,6 +3513,13 @@ std::vector QQMatmul::vjp( // primal[0] -- non quantized activations (M, K) // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); + std::optional cotan_amax = + is_nvfp4 ? std::make_optional(max(abs(cotan, s), s)) : std::nullopt; + + auto get_primal_scale = [&](int idx) { + return is_nvfp4 ? 
std::make_optional(primals[idx]) : std::nullopt; + }; + for (auto arg : argnums) { // TODO: we need a kernel that will quantize columnwise + transpose if (arg == 0) { // gradient wrt to x @@ -3510,21 +3527,23 @@ std::vector QQMatmul::vjp( vjps.push_back(qqmm( cotan, // M X N swapaxes(primals[1], -1, -2, s), // assuming that w is 2D - {}, - {}, + std::nullopt, group_size_, bits_, qmode, + cotan_amax, + get_primal_scale(2), // global_scale_x s)); } else if (arg == 1) { // gradient wrt to weights vjps.push_back(qqmm( swapaxes(cotan, -1, -2, s), // (N, M) swapaxes(primals[0], -1, -2, s), // (K, M) {}, - {}, group_size_, bits_, qmode, + cotan_amax, + get_primal_scale(3), // global_scale_w s)); } } @@ -3646,10 +3665,10 @@ std::vector GatherQMM::vjp( w, ones_like(scales, stream()), zeros_like(*biases, stream()), - {}, group_size_, bits_, quantization_mode_to_string(mode_), + {}, // placeholder for amax std::nullopt, stream()), -1, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c9550a7bca..2b9f69e2cc 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4254,10 +4254,11 @@ void init_ops(nb::module_& m) { "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "affine", + "global_scale"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, global_scale: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), R"pbdoc( Quantize the array ``w``. @@ -4282,14 +4283,14 @@ void init_ops(nb::module_& m) { ``w`` in the quantized array. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. + global_scale (array, optional): The per-input float32 scale used for + ``"nvfp4"`` quantization. Default: ``None``. Returns: tuple: A tuple with either two or three elements containing: * w_q (array): The quantized version of ``w`` * scales (array): The quantization scales - * tensor_scale (array): The per-tensor float32 absolute max - scale (returned for ``mode == "nvfp4"``) * biases (array): The quantization biases (returned for ``mode=="affine"``). Notes: @@ -4349,24 +4350,22 @@ void init_ops(nb::module_& m) { &mx::dequantize, nb::arg(), "scales"_a, - "tensor_scale"_a = nb::none(), "biases"_a = nb::none(), "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "affine", + "global_scale"_a = nb::none(), "dtype"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', global_scale: Optional[array] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Dequantize the matrix ``w`` using quantization parameters. Args: w (array): Matrix to be dequantized scales (array): The scales to use per ``group_size`` elements of ``w``. 
- tensor_scale (array, optional): The per-tensor float32 scale used for - ``"nvfp4"`` quantization. Default: ``None``. biases (array, optional): The biases to use per ``group_size`` elements of ``w``. Default: ``None``. group_size (int, optional): The size of the group in ``w`` that shares a @@ -4375,6 +4374,8 @@ void init_ops(nb::module_& m) { bits (int, optional): The number of bits occupied by each element of ``w`` in the quantized array. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. + global_scale (array, optional): The per-input float32 scale used for + ``"nvfp4"`` quantization. Default: ``None``. dtype (Dtype, optional): The data type of the dequantized output. If ``None`` the return type is inferred from the scales and biases when possible and otherwise defaults to ``bfloat16``. @@ -5467,10 +5468,12 @@ void init_ops(nb::module_& m) { "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "nvfp4", + "global_scale_x"_a = nb::none(), + "global_scale_w"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def qqmm(x: array, w: array, scales: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', *, stream: Union[None, Stream, Device] = None) -> array"), + "def qqmm(x: array, w: array, scales: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', global_scale_x: Optional[array] = None, global_scale_w: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform a matrix multiplication using a possibly quantized weight matrix ``w`` and a non-quantized input ``x``. The input ``x`` is quantized on the @@ -5505,7 +5508,10 @@ void init_ops(nb::module_& m) { mode (str, optional): The quantization mode. Default: ``"nvfp4"``. Supported modes are ``nvfp4`` and ``mxfp8``. See the :ref:`table of quantization modes ` for details. - + global_scale (array, optional): The per-input float32 scale used for x + ``"nvfp4"`` quantization. Default: ``None``. + global_scale_w (array, optional): The per-input float32 scale used for w + ``"nvfp4"`` quantization. Default: ``None``. Returns: array: The result of the multiplication of quantized ``x`` with quantized ``w``. needed). 
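As a usage sketch of the signatures documented above (the shapes, the ``bfloat16`` dtype, and the choice of the per-tensor absolute maximum as the global scale are illustrative assumptions, not prescribed values):

    import mlx.core as mx

    x = mx.random.normal(shape=(32, 64), dtype=mx.bfloat16)
    w = mx.random.normal(shape=(128, 64), dtype=mx.bfloat16)

    # Per-input float32 scales for nvfp4; here simply the per-tensor abs-max.
    gs_x = mx.abs(x).max().astype(mx.float32)
    gs_w = mx.abs(w).max().astype(mx.float32)

    # nvfp4 quantize/dequantize round-trip with the global scale.
    w_q, scales_w = mx.quantize(w, group_size=16, bits=4, mode="nvfp4", global_scale=gs_w)
    w_hat = mx.dequantize(
        w_q, scales_w, group_size=16, bits=4, mode="nvfp4",
        global_scale=gs_w, dtype=mx.bfloat16,
    )

    # qqmm with a pre-quantized weight; x is quantized on the fly.
    y = mx.qqmm(
        x, w_q, scales_w, group_size=16, bits=4, mode="nvfp4",
        global_scale_x=gs_x, global_scale_w=gs_w,
    )

For modes other than ``nvfp4`` the global-scale arguments can be left at their default of ``None``.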
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 6c8cd8ab42..62fd8c5923 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3087,7 +3087,7 @@ TEST_CASE("test quantize dequantize") { CHECK_EQ(scales.shape(), Shape{128, 4}); CHECK_EQ(biases.shape(), Shape{128, 4}); - auto x_hat = dequantize(x_q, scales, {}, biases, 128, i); + auto x_hat = dequantize(x_q, scales, biases, 128, i); auto max_diff = max(abs(x - x_hat)).item(); CHECK(max_diff <= 127.0 / (1 << i)); } From 277ceebbfc4d7f5d04156c7b40eb229c7d3bf202 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 21 Jan 2026 19:04:34 +0100 Subject: [PATCH 10/34] fix scale to be float(fp8e4m3(scale)) --- examples/python/qqmm.py | 52 ++++++++++++++++++++--- mlx/backend/cuda/quantized/fp_quantize.cu | 2 +- mlx/ops.cpp | 13 ++---- python/mlx/nn/layers/quantized.py | 22 +++++++++- python/src/ops.cpp | 1 - 5 files changed, 72 insertions(+), 18 deletions(-) diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index 5be7eae2f3..f0745f19e1 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -38,7 +38,18 @@ def test_qqmm(): for dtype in dtypes: x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) - w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) + x_amax = ( + mx.abs(x).max().astype(mx.float32) if group_size == 16 else None + ) + w_amax = ( + mx.abs(w).max().astype(mx.float32) if group_size == 16 else None + ) + w_q, scales_w = mx.quantize( + w, group_size, bits, mode=mode, global_scale=w_amax + ) + x_q, scales_x = mx.quantize( + x, group_size, bits, mode=mode, global_scale=x_amax + ) w_dq = mx.dequantize( w_q, scales_w, @@ -46,6 +57,7 @@ def test_qqmm(): bits=bits, mode=mode, dtype=dtype, + global_scale=w_amax, ) y_q = mx.qqmm( x, @@ -54,9 +66,11 @@ def test_qqmm(): group_size=group_size, bits=bits, mode=mode, + global_scale_x=x_amax, + global_scale_w=w_amax, ) x_q, scales_x = mx.quantize( - x, group_size=group_size, bits=bits, mode=mode + x, group_size=group_size, bits=bits, mode=mode, global_scale=x_amax ) x_dq = mx.dequantize( x_q, @@ -64,12 +78,16 @@ def test_qqmm(): group_size=group_size, bits=bits, mode=mode, + global_scale=x_amax, dtype=dtype, ) y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) ulp = ulp_bf16_at(y_hat) error = (y_q - y_hat).abs() if not (mx.logical_or(error < 1e-3, error <= ulp).all()): + import pdb + + pdb.set_trace() raise AssertionError( f"qqmm test failed for shape {(M, N, K)}, " f"group_size={group_size}, bits={bits}, " @@ -89,19 +107,43 @@ def test_qqmm_vjp(): ) x = mx.random.normal(shape=(M, K), key=k1) c = mx.ones(shape=(M, N)) + x_amax = mx.abs(x).max() if tests[0][0] == 16 else None for group_size, mode, bits in tests: w = mx.random.normal(shape=(N, K), key=k2) + x_amax = mx.abs(x).max() if group_size == 16 else None + w_amax = mx.abs(w).max() if group_size == 16 else None + c_amax = mx.abs(c).max() if group_size == 16 else None + def fn(x): - return mx.qqmm(x, w, group_size=group_size, bits=bits, mode=mode) + return mx.qqmm( + x, + w, + group_size=group_size, + bits=bits, + mode=mode, + global_scale_x=x_amax, + global_scale_w=w_amax, + ) _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) w_tq, scales_wt = mx.quantize( - mx.transpose(w), group_size=group_size, bits=bits, mode=mode + mx.transpose(w), + group_size=group_size, + bits=bits, + mode=mode, + global_scale=w_amax, ) expected_out = mx.qqmm( - c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode + c, + w_tq, + scales_wt, + 
group_size=group_size, + bits=bits, + mode=mode, + global_scale_x=c_amax, + global_scale_w=w_amax, ) ulp = ulp_bf16_at(expected_out) error = (vjp_out[0] - expected_out).abs() diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index c2c7aa7ab6..7cbc683025 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -85,7 +85,7 @@ __global__ void fp_quantize( std::conditional_t; auto s = ScaleType(scale_dec_b); uint8_t q_scale = s.__x; - float scale_enc_b = scale_enc / float(scale_dec_b); + float scale_enc_b = scale_enc / float(s); scales[thread_idx] = q_scale; constexpr int elem_per_byte = bits == 8 ? 1 : 2; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 17fa5b5ca9..319ff00ac5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4600,15 +4600,10 @@ std::vector fp_quantize( if (group_size == 16) { // NVFP4: scale_dec = (group_amax / 6) * (448 * 6) / global_scale // = group_amax * 448 / global_scale - array scales; - if (mode == QuantizationMode::Nvfp4 && inputs.size() > 1) { - // scale_dec = group_amax * 448 / global_scale - scales = divide( - multiply(group_amax, array(448.0f, w.dtype()), s), inputs[1], s); - } else { - // Without global_scale: scale_dec = group_amax / 6 - scales = divide(group_amax, array(maxval, w.dtype()), s); - } + array scales = (mode == QuantizationMode::Nvfp4 && inputs.size() > 1) + ? divide( + multiply(group_amax, array(448.0f, w.dtype()), s), inputs[1], s) + : divide(group_amax, array(maxval, w.dtype()), s); // convert to e4m3 scales = to_fp8(scales, s); // quantized = w * 6 / group_amax diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 0e520d1517..bb82860600 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -334,6 +334,7 @@ def __init__( shape=(output_dims, input_dims), ) self._quantized = False + self._use_global_scale = self.mode == "nvfp4" def _extra_repr(self): out_dims, in_dims = self.weight.shape @@ -346,11 +347,15 @@ def _extra_repr(self): def quantize(self): if not self._quantized: + self.global_amax_w = ( + (self.weight).abs().max() if self._use_global_scale else None + ) self.weight, self.scales = mx.quantize( self.weight, self.group_size, self.bits, mode=self.mode, + global_amax=self.global_amax_w, ) self._quantized = True @@ -362,8 +367,10 @@ def dequantize(self): group_size=self.group_size, bits=self.bits, mode=self.mode, + global_amax=self.global_scale_w, ) - self.__delattr__("scales") + del self.scales + del self.global_amax_w self._quantized = False def _set_training_mode(self, mode: bool): @@ -375,13 +382,24 @@ def _set_training_mode(self, mode: bool): self.quantize() def __call__(self, x): + # TODO: In the future we can implement different policies for amax update + # for the activations as well as for the weights + # (for example for the weights it can be ema ) + global_scale_w = ( + getattr(self, "global_scale_w", (self.weight).abs().max()) + if self._use_global_scale + else None + ) + global_scale_x = (x).abs().max() if self._use_global_scale else None x = mx.qqmm( x, - self["weight"], + self.weight, scales=self.get("scales"), group_size=self.group_size, bits=self.bits, mode=self.mode, + global_amax_x=global_scale_x, + global_amax_w=global_scale_w, ) return x diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2b9f69e2cc..06e167680b 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5464,7 +5464,6 @@ void init_ops(nb::module_& m) { nb::arg(), // x nb::arg(), // w_q 
"scales"_a = nb::none(), // scales w - "tensor_scale"_a = nb::none(), "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "nvfp4", From dad7e57e0aa143b7e8fe9ca0a04de34257f71f88 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 21 Jan 2026 19:15:34 +0100 Subject: [PATCH 11/34] removed AbsMax reduction (probably add back in the future as a separate PR) --- mlx/backend/cuda/quantized/quantized.h | 7 ------- mlx/backend/cuda/reduce/all_reduce.cu | 26 ++------------------------ mlx/backend/cuda/reduce/reduce.cuh | 2 -- mlx/backend/cuda/reduce/reduce_ops.cuh | 26 -------------------------- mlx/primitives.h | 4 +--- 5 files changed, 3 insertions(+), 62 deletions(-) diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h index 93fba563a5..9b8e030cd2 100644 --- a/mlx/backend/cuda/quantized/quantized.h +++ b/mlx/backend/cuda/quantized/quantized.h @@ -2,7 +2,6 @@ #include #include "mlx/backend/cuda/device.h" -#include "mlx/primitives.h" namespace mlx::core { @@ -46,10 +45,4 @@ void fp_dequantize( cu::CommandEncoder& enc, const Stream& s); -void all_reduce( - cu::CommandEncoder& encoder, - const array& in, - array& out, - Reduce::ReduceType reduce_type); - } // namespace mlx::core diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 100fdb3866..1126f4cc76 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -13,20 +13,6 @@ namespace cu { namespace cg = cooperative_groups; -template -__device__ __forceinline__ T absmax_val(T x) { - if constexpr (cuda::std::is_same_v) { - return x; - } else if constexpr (cuda::std::is_unsigned_v) { - return x; // unsigned types are non-negative - } else if constexpr (cuda::std::is_floating_point_v) { - return fabs(x); - } else { - // signed integer - return x < T(0) ? 
-x : x; - } -} - template __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { // TODO: Process multiple "rows" in each thread @@ -51,11 +37,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { for (; i + block.size() * N <= check; i += block.size() * N) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { - if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], absmax_val(cast_to(vals[j]))); - } else { - accs[0] = op(accs[0], cast_to(vals[j])); - } + accs[0] = op(accs[0], cast_to(vals[j])); } } @@ -63,11 +45,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlocked( block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { - if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], absmax_val(cast_to(vals[i]))); - } else { - accs[0] = op(accs[0], cast_to(vals[i])); - } + accs[0] = op(accs[0], cast_to(vals[i])); } } diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 947e8b36dc..02e495594a 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -35,8 +35,6 @@ void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Min) { f(type_identity{}); - } else if (reduce_type == Reduce::ReduceType::AbsMax) { - f(type_identity{}); } else { throw std::invalid_argument("Unknown reduce type."); } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 7b55bdc525..7f8cad0c4e 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -114,18 +114,6 @@ struct Max { } }; -struct AbsMax { - template - __device__ __forceinline__ T operator()(T a, T b) { - return a > b ? a : b; - } - - template - __device__ void atomic_update(T* x, T y) { - atomic_reduce(x, y); - } -}; - // Traits to get the result type of reduce op. template struct ReduceResult; @@ -166,12 +154,6 @@ struct ReduceResult { using type = T; }; -// TODO: this should not be hardcoded -template -struct ReduceResult { - using type = float; -}; - // Traits to get the init value of reduce op. 
template struct ReduceInit; @@ -226,12 +208,4 @@ struct ReduceInit { } }; -template -struct ReduceInit { - using result_type = typename ReduceResult::type; - static constexpr __host__ __device__ result_type value() { - return result_type(0); // abs values are >= 0 - } -}; - } // namespace mlx::core::cu diff --git a/mlx/primitives.h b/mlx/primitives.h index 1e9b14da66..c3ce00f92f 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1769,7 +1769,7 @@ class Reshape : public UnaryPrimitive { class Reduce : public UnaryPrimitive { public: - enum ReduceType { And, Or, Sum, Prod, Min, Max, AbsMax }; + enum ReduceType { And, Or, Sum, Prod, Min, Max }; explicit Reduce( Stream stream, @@ -1799,8 +1799,6 @@ class Reduce : public UnaryPrimitive { return "Min"; case Max: return "Max"; - case AbsMax: - return "AbsMax"; } return ""; } From 0a804a93a71ca47a127f978c63c1e38fff81670e Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Thu, 22 Jan 2026 22:58:10 +0100 Subject: [PATCH 12/34] fix columnwise quantize scale, precommit --- mlx/backend/cuda/quantized/fp_quantize.cu | 50 ++++++++++++----------- mlx/backend/cuda/quantized/qqmm.cpp | 2 - mlx/primitives.cpp | 7 ++-- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 7fc60cacab..d3762240bd 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -115,14 +115,14 @@ __global__ void fp_quantize_columnwise( uint8_t* scales, size_t size, int M, - int K, + int K, float* global_scale = nullptr) { // Input: [M, K] with strides [1, M] (M-major) // Quantized output: [M, K/elem_per_byte] row-major (K-major) // Scales: [M, K/group_size] row-major (K-major) // Quantize along K (last dimension, groups of group_size elements) const float scale_enc = - !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; @@ -174,7 +174,7 @@ __global__ void fp_quantize_columnwise( auto pair = Tx2{thread_data[r], thread_data[r + 1]}; abs_max_x2(amax_2x, amax_2x, pair); } - float scale = + float scale_dec_b = max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y))); scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; @@ -183,7 +183,7 @@ __global__ void fp_quantize_columnwise( using ScaleType = std::conditional_t; auto s = ScaleType(scale_dec_b); - scale = float(s); + float scale_enc_b = scale_enc / float(s); scales_smem[tidx][tidy] = s.__x; int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; @@ -193,12 +193,12 @@ __global__ void fp_quantize_columnwise( Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); if constexpr (bits == 8) { uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = quantized_val; } else { uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = quantized_val; } @@ -345,7 +345,9 @@ void fp_quantize( gpu_ptr(scales), w.size(), M, - K); + K, + global_scale.has_value() ? 
gpu_ptr(global_scale.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); @@ -364,24 +366,24 @@ void fp_quantize( bool large = w.size() > UINT_MAX; auto [num_blocks, block_dims] = get_launch_args( w.size(), w.shape(), w.strides(), large, group_size); - - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); - } - }); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); + } } void fp_dequantize( diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 1bc6f558f0..544b5237af 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -13,7 +13,6 @@ namespace mlx::core { namespace { - struct GemmScalars { std::optional alpha_device; std::optional beta_device; @@ -23,7 +22,6 @@ struct GemmScalars { } }; - inline array ensure_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) { if (x.flags().row_contiguous || x.flags().col_contiguous) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0a38a081e4..925030151d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3510,14 +3510,13 @@ std::vector QQMatmul::vjp( // primal[0] -- non quantized activations (M, K) // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); - std::optional cotan_amax = - is_nvfp4 ? std::make_optional(astype(max(abs(cotan, s), s), float32, s)) - : std::nullopt; + std::optional cotan_amax = is_nvfp4 + ? std::make_optional(astype(max(abs(cotan, s), s), float32, s)) + : std::nullopt; auto get_primal_scale = [&](int idx) { return is_nvfp4 ? 
std::make_optional(primals[idx]) : std::nullopt; }; - for (auto arg : argnums) { if (arg == 0) { // gradient wrt to x From 7ca2642ff1649378889b9da2c725db1311ab60da Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 00:55:50 +0100 Subject: [PATCH 13/34] abs_max --- .gitignore | 1 + mlx/backend/cuda/reduce/all_reduce.cu | 26 ++++++++++++++++++++++++-- mlx/backend/cuda/reduce/reduce.cuh | 2 ++ mlx/backend/cuda/reduce/reduce_ops.cuh | 24 ++++++++++++++++++++++++ mlx/ops.cpp | 14 ++++++++++++++ mlx/ops.h | 3 +++ mlx/primitives.h | 2 +- python/src/ops.cpp | 17 +++++++++++++++++ 8 files changed, 86 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 43629548db..898ec94c65 100644 --- a/.gitignore +++ b/.gitignore @@ -86,3 +86,4 @@ build/ # Jetbrains .cache +tmp \ No newline at end of file diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 1126f4cc76..100fdb3866 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -13,6 +13,20 @@ namespace cu { namespace cg = cooperative_groups; +template +__device__ __forceinline__ T absmax_val(T x) { + if constexpr (cuda::std::is_same_v) { + return x; + } else if constexpr (cuda::std::is_unsigned_v) { + return x; // unsigned types are non-negative + } else if constexpr (cuda::std::is_floating_point_v) { + return fabs(x); + } else { + // signed integer + return x < T(0) ? -x : x; + } +} + template __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { // TODO: Process multiple "rows" in each thread @@ -37,7 +51,11 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { for (; i + block.size() * N <= check; i += block.size() * N) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { - accs[0] = op(accs[0], cast_to(vals[j])); + if constexpr (cuda::std::is_same_v) { + accs[0] = op(accs[0], absmax_val(cast_to(vals[j]))); + } else { + accs[0] = op(accs[0], cast_to(vals[j])); + } } } @@ -45,7 +63,11 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlocked( block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { - accs[0] = op(accs[0], cast_to(vals[i])); + if constexpr (cuda::std::is_same_v) { + accs[0] = op(accs[0], absmax_val(cast_to(vals[i]))); + } else { + accs[0] = op(accs[0], cast_to(vals[i])); + } } } diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 02e495594a..947e8b36dc 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -35,6 +35,8 @@ void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Min) { f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::AbsMax) { + f(type_identity{}); } else { throw std::invalid_argument("Unknown reduce type."); } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 7f8cad0c4e..d7cdc41561 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -114,6 +114,18 @@ struct Max { } }; +struct AbsMax { + template + __device__ __forceinline__ T operator()(T a, T b) { + return a > b ? a : b; + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + // Traits to get the result type of reduce op. 
template struct ReduceResult; @@ -154,6 +166,11 @@ struct ReduceResult { using type = T; }; +template +struct ReduceResult { + using type = float; +}; + // Traits to get the init value of reduce op. template struct ReduceInit; @@ -208,4 +225,11 @@ struct ReduceInit { } }; +template +struct ReduceInit { + static constexpr __host__ __device__ T value() { + return T(0); // abs values are >= 0 + } +}; + } // namespace mlx::core::cu diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ed6c02748b..b409408c66 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2237,6 +2237,20 @@ array min( return min(a, std::vector{axis}, keepdims, s); } +array abs_max(const array& a, StreamOrDevice s /* = {}*/) { + if (a.size() == 0) { + throw std::invalid_argument( + "[abs_max] Cannot abs_max reduce zero size array."); + } + std::vector axes(a.ndim()); + std::iota(axes.begin(), axes.end(), 0); + return array( + {}, + float32, + std::make_shared(to_stream(s), Reduce::AbsMax, axes), + {a}); +} + array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { auto result = argmin(flatten(a, s), 0, true, s); if (keepdims) { diff --git a/mlx/ops.h b/mlx/ops.h index eea6b93377..2ac5a2f17b 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -678,6 +678,9 @@ array min( bool keepdims = false, StreamOrDevice s = {}); +/** The maximum of absolute values of all elements of the array. */ +array abs_max(const array& a, StreamOrDevice s = {}); + /** Returns the index of the minimum value in the array. */ array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); inline array argmin(const array& a, StreamOrDevice s = {}) { diff --git a/mlx/primitives.h b/mlx/primitives.h index c3ce00f92f..bbd91f90a1 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1769,7 +1769,7 @@ class Reshape : public UnaryPrimitive { class Reduce : public UnaryPrimitive { public: - enum ReduceType { And, Or, Sum, Prod, Min, Max }; + enum ReduceType { And, Or, Sum, Prod, Min, Max, AbsMax }; explicit Reduce( Stream stream, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 40231de070..83319dab63 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2402,6 +2402,23 @@ void init_ops(nb::module_& m) { Returns: array: The output array with the corresponding axes reduced. )pbdoc"); + m.def( + "abs_max", + &mx::abs_max, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def abs_max(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + The maximum of absolute values of all elements in the array. + + Args: + a (array): Input array. + + Returns: + array: A scalar array with the maximum absolute value. 
+ )pbdoc"); m.def( "logcumsumexp", [](const mx::array& a, From 934c0c830c3abf8c497c79fd6da8c7fd4d00593d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 01:44:31 +0100 Subject: [PATCH 14/34] fix --- .gitignore | 3 +-- mlx/backend/cuda/reduce/reduce_ops.cuh | 4 ++-- mlx/primitives.cpp | 5 ++--- python/mlx/nn/layers/quantized.py | 4 ++-- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 898ec94c65..ff4d56f592 100644 --- a/.gitignore +++ b/.gitignore @@ -85,5 +85,4 @@ build/ .DS_Store # Jetbrains -.cache -tmp \ No newline at end of file +.cache \ No newline at end of file diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index d7cdc41561..40be1bae41 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -227,8 +227,8 @@ struct ReduceInit { template struct ReduceInit { - static constexpr __host__ __device__ T value() { - return T(0); // abs values are >= 0 + static constexpr __host__ __device__ auto value() { + return typename ReduceResult::type(0); // abs values are >= 0 } }; diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 925030151d..994ca056d8 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3510,9 +3510,8 @@ std::vector QQMatmul::vjp( // primal[0] -- non quantized activations (M, K) // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); - std::optional cotan_amax = is_nvfp4 - ? std::make_optional(astype(max(abs(cotan, s), s), float32, s)) - : std::nullopt; + std::optional cotan_amax = + is_nvfp4 ? std::make_optional(abs_max(cotan, s)) : std::nullopt; auto get_primal_scale = [&](int idx) { return is_nvfp4 ? std::make_optional(primals[idx]) : std::nullopt; diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 7b226d01e1..61a5bc306a 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -386,11 +386,11 @@ def __call__(self, x): # for the activations as well as for the weights # (for example for the weights it can be ema ) global_scale_w = ( - getattr(self, "global_scale_w", (self.weight).abs().max()) + getattr(self, "global_scale_w", mx.abs_max(self.weight)) if self._use_global_scale else None ) - global_scale_x = (x).abs().max() if self._use_global_scale else None + global_scale_x = mx.abs_max(x) if self._use_global_scale else None x = mx.qqmm( x, self["weight"], From 1fea0253b592c1d12f7eb5face8ea4c4bf979df6 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 20:34:48 +0100 Subject: [PATCH 15/34] fixed the fallback, fixed absmax --- mlx/backend/cuda/quantized/fp_quantize.cu | 12 +- mlx/backend/cuda/quantized/qqmm.cpp | 31 ++--- mlx/backend/cuda/quantized/quantized.cpp | 17 +-- mlx/backend/cuda/reduce/all_reduce.cu | 17 +-- mlx/backend/cuda/reduce/reduce_ops.cuh | 15 ++- mlx/ops.cpp | 135 +++++++++++----------- mlx/primitives.h | 2 + python/tests/test_quantized.py | 17 +++ 8 files changed, 141 insertions(+), 105 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index d3762240bd..ce2b319331 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -44,8 +44,9 @@ __global__ void fp_quantize_rowwise( // Global encode scale: (448 × 6) / *global_scale // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b + 
const bool use_global_scale = global_scale != nullptr; const float scale_enc = - !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; @@ -121,8 +122,9 @@ __global__ void fp_quantize_columnwise( // Quantized output: [M, K/elem_per_byte] row-major (K-major) // Scales: [M, K/group_size] row-major (K-major) // Quantize along K (last dimension, groups of group_size elements) + const bool use_global_scale = global_scale != nullptr; const float scale_enc = - !use_mx_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; @@ -257,8 +259,10 @@ __global__ void fp_dequantize( auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; constexpr int pack_factor = bits == 8 ? 1 : 2; - const float inv_scale_enc = - use_mx_scale ? 1.0f : (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX); + const bool use_global_scale = global_scale != nullptr; + const float inv_scale_enc = use_mx_scale + ? 1.0f + : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); size_t offset = tidx + grid_dim_x * size_t(tidy); size_t oindex = offset * pack_factor; diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 544b5237af..b3a9cbab83 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -62,7 +62,7 @@ std::tuple quantize_input( int bits, int group_size, std::optional global_scale = std::nullopt) { - const array x = ensure_row_contiguous(input, encoder, s); + const array x = ensure_contiguous(input, encoder, s); // Compute output shapes auto xq_shape = x.shape(); @@ -171,28 +171,29 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[QQMatmul::eval_gpu] QQMM requires compute capability 10.0+"); } - // input size = 2 for non-quantized w for qmode != nvfp4 - // input size = 3 for quantized w for qmode != nvfp4 - // input size = 4 for non-quantized w for qmode == nvfp4 - // input size = 5 for quantized w for qmode == nvfp4 - auto num_amax_inputs = mode_ == QuantizationMode::Nvfp4 ? 2 : 0; - auto size = - inputs[1].dtype() == uint32 ? 3 + num_amax_inputs : 2 + num_amax_inputs; - assert(inputs.size() == size); + // - 2 inputs: x, w (non-quantized w) + // - 3 inputs: x, w, scales_w (quantized w) + bool w_is_quantized = inputs[1].dtype() == uint32; + int base_size = w_is_quantized ? 3 : 2; - // For nvfp4, get global scales from inputs + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; + + // For nvfp4, get global scales from inputs if present std::optional global_scale_x = std::nullopt; std::optional global_scale_w = std::nullopt; - if (mode_ == QuantizationMode::Nvfp4) { - global_scale_x = inputs[size - 2]; - global_scale_w = inputs[size - 1]; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; } // Quantize inputs (or use pre-quantized) auto [x_q, scale_x_pre] = quantize_input( inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); - auto [w_q, scale_w_pre] = inputs[1].dtype() != uint32 + auto [w_q, scale_w_pre] = !w_is_quantized ? 
quantize_input( inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w) : std::make_tuple( @@ -215,7 +216,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); GemmScalars scalars; - if (mode_ == QuantizationMode::Nvfp4) { + if (has_global_scales) { scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder); } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index c0829b2e7d..0fabdda88b 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -63,7 +63,6 @@ void fast::Quantize::eval_gpu( auto& s = stream(); auto& d = cu::device(s.device); auto& enc = d.get_command_encoder(s); - if (dequantize_) { auto wq = ensure_row_contiguous(inputs[0], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s); @@ -75,10 +74,11 @@ void fast::Quantize::eval_gpu( auto biases = ensure_row_contiguous(inputs[2], enc, s); affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); } else { - // third input is global scale for nvfp4 - std::optional global_scale = mode_ == QuantizationMode::Nvfp4 - ? std::make_optional(inputs[2]) - : std::nullopt; + // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4 + bool use_global_scale = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2; + std::optional global_scale = + use_global_scale ? std::make_optional(inputs[2]) : std::nullopt; fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s); } } else { @@ -94,9 +94,10 @@ void fast::Quantize::eval_gpu( biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); } else { - std::optional global_scale = mode_ == QuantizationMode::Nvfp4 - ? std::make_optional(inputs[1]) - : std::nullopt; + bool use_global_scale = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1; + std::optional global_scale = + use_global_scale ? std::make_optional(inputs[1]) : std::nullopt; fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s); } } diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 100fdb3866..1ab79d1f0a 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -14,16 +14,11 @@ namespace cu { namespace cg = cooperative_groups; template -__device__ __forceinline__ T absmax_val(T x) { - if constexpr (cuda::std::is_same_v) { - return x; - } else if constexpr (cuda::std::is_unsigned_v) { - return x; // unsigned types are non-negative - } else if constexpr (cuda::std::is_floating_point_v) { - return fabs(x); +__device__ __forceinline__ T absmax(T x) { + if constexpr (cuda::std::is_unsigned_v) { + return x; // No-op for unsigned types } else { - // signed integer - return x < T(0) ? 
-x : x; + return abs(x); // Uses cu::abs for half types, ::abs for others } } @@ -52,7 +47,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], absmax_val(cast_to(vals[j]))); + accs[0] = op(accs[0], absmax(cast_to(vals[j]))); } else { accs[0] = op(accs[0], cast_to(vals[j])); } @@ -64,7 +59,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], absmax_val(cast_to(vals[i]))); + accs[0] = op(accs[0], absmax(cast_to(vals[i]))); } else { accs[0] = op(accs[0], cast_to(vals[i])); } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 40be1bae41..5fcd6cdfa2 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -115,8 +115,21 @@ struct Max { }; struct AbsMax { + // abs is applied inside all_reduce kernel template __device__ __forceinline__ T operator()(T a, T b) { + if constexpr (is_complex_v) { + if (isnan(a.real()) || isnan(a.imag())) { + return a; + } + if (isnan(b.real()) || isnan(b.imag())) { + return b; + } + } else if constexpr (!cuda::std::is_integral_v) { + if (isnan(a) || isnan(b)) { + return cuda::std::numeric_limits::quiet_NaN(); + } + } return a > b ? a : b; } @@ -168,7 +181,7 @@ struct ReduceResult { template struct ReduceResult { - using type = float; + using type = T; }; // Traits to get the init value of reduce op. diff --git a/mlx/ops.cpp b/mlx/ops.cpp index b409408c66..0d34698a81 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2237,17 +2237,26 @@ array min( return min(a, std::vector{axis}, keepdims, s); } +// TODO: extend to row_reduce and col_reduce? 
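// Semantics of abs_max, for reference: a full reduction of |a| to a scalar
// (shape {}) with the same dtype as the input; floating-point inputs only.
// When the stream is not a GPU stream or CUDA is unavailable it falls back
// to composing max(abs(a)); otherwise it dispatches the Reduce::AbsMax
// primitive, which folds the abs into the all-reduce kernel.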
array abs_max(const array& a, StreamOrDevice s /* = {}*/) { if (a.size() == 0) { throw std::invalid_argument( "[abs_max] Cannot abs_max reduce zero size array."); } + if (!issubdtype(a.dtype(), floating)) { + throw std::invalid_argument( + "[abs_max] abs_max supported only for floating point types."); + } + auto stream = to_stream(s); + if (stream.device != Device::gpu || !cu::is_available()) { + return max(abs(a, s), false, s); + } std::vector axes(a.ndim()); std::iota(axes.begin(), axes.end(), 0); return array( {}, - float32, - std::make_shared(to_stream(s), Reduce::AbsMax, axes), + a.dtype(), + std::make_shared(stream, Reduce::AbsMax, axes), {a}); } @@ -4243,20 +4252,11 @@ void validate_global_scale( // TODO: not sure if type should be restricted to float32 if (global_scale->dtype() != float32) { std::ostringstream msg; - msg << "[" << tag - << "] Global scale must be a floating type but got type " + msg << "[" << tag << "] Global scale must have dtype float32 but got " << global_scale->dtype() << "."; throw std::invalid_argument(msg.str()); } } - } else { - if (qmode == QuantizationMode::Nvfp4) { - std::ostringstream msg; - msg << "[" << tag << "] Global scale must be provided for 'nvfp4' " - << "quantization mode."; - throw std::invalid_argument(msg.str()); - } - return; } } @@ -4361,6 +4361,16 @@ void validate_qqmm_inputs( // validate global scales validate_global_scale("qqmm", qmode, global_scale_x); validate_global_scale("qqmm", qmode, global_scale_w); + // For nvfp4 mode, both global scales must be provided together or neither + if (qmode == QuantizationMode::Nvfp4) { + bool has_x = global_scale_x.has_value(); + bool has_w = global_scale_w.has_value(); + if (has_x != has_w) { + throw std::invalid_argument( + "[qqmm] For nvfp4 mode, either both global_scale_x and " + "global_scale_w must be provided, or neither."); + } + } } std::pair extract_qqmm_dims( @@ -4444,9 +4454,13 @@ array qqmm( if (scales_w.has_value()) { inputs.push_back(*scales_w); } - if (global_scale_x.has_value() && global_scale_w.has_value()) { + // if + if (global_scale_x.has_value()) { // Stop gradient through global scales inputs.push_back(stop_gradient(*global_scale_x)); + } + if (global_scale_w.has_value()) { + // Stop gradient through global scales inputs.push_back(stop_gradient(*global_scale_w)); } auto out_shape = inputs[0].shape(); @@ -4602,30 +4616,41 @@ std::vector fp_quantize( << bits << "."; throw std::invalid_argument(msg.str()); } - auto fallback = [bits = bits, group_size = group_size, mode = mode, s]( + + auto inputs = std::vector{w}; + if (global_scale.has_value()) { + inputs.push_back(global_scale.value()); + } + + auto fallback = [bits = bits, group_size = group_size, s]( const std::vector& inputs) -> std::vector { - auto w = inputs[0]; + auto& w = inputs[0]; + auto scale_encode = + inputs.size() > 1 ? 448.0f * 6.0f / inputs[1] : array(1.0f, float32); float maxval = (bits == 4) ? 6.0f : 448.0f; - auto new_shape = w.shape(); new_shape.back() = -1; auto wq = reshape(w, {-1, group_size}, s); - auto group_amax = max(abs(wq, s), -1, true, s); - + auto scales = + divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s) * + scale_encode; if (group_size == 16) { - // NVFP4: scale_dec = (group_amax / 6) * (448 * 6) / global_scale - // = group_amax * 448 / global_scale - array scales = (mode == QuantizationMode::Nvfp4 && inputs.size() > 1) - ? 
divide( - multiply(group_amax, array(448.0f, w.dtype()), s), inputs[1], s) - : divide(group_amax, array(maxval, w.dtype()), s); // convert to e4m3 scales = to_fp8(scales, s); - // quantized = w * 6 / group_amax - wq = divide(multiply(wq, array(maxval, w.dtype()), s), group_amax, s); - wq = reshape(wq, new_shape, s); - scales = reshape(scales, new_shape, s); + wq = divide(wq, from_fp8(scales, w.dtype(), s), s) * scale_encode; + } else { + // convert to e8m0 + auto z = array(0, scales.dtype()); + scales = where( + equal(scales, z, s), + z, + astype(round(log2(scales, s), s), int32, s), + s); + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + } + if (bits == 4) { auto lut = array({ +0.0f, +0.5f, @@ -4650,32 +4675,14 @@ std::vector fp_quantize( auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); wq = reshape(wq, {-1, 4, 8}, s); wq = sum(multiply(wq, shifts, s), -1, false, s); - wq = reshape(wq, new_shape, s); - - return {std::move(wq), std::move(scales)}; } else { - auto scales = divide(group_amax, array(maxval, w.dtype()), s); - auto z = array(0, scales.dtype()); - scales = where( - equal(scales, z, s), - z, - astype(round(log2(scales, s), s), int32, s), - s); - - wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); - scales = astype(add(scales, array(127, int32), s), uint8, s); - wq = view(to_fp8(wq, s), uint32, s); - wq = reshape(wq, new_shape, s); - scales = reshape(scales, new_shape, s); - return {std::move(wq), std::move(scales)}; } + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); + return {std::move(wq), std::move(scales)}; }; - auto inputs = std::vector{w}; - if (global_scale.has_value()) { - inputs.push_back(global_scale.value()); - } if (s.device == Device::gpu) { auto wq_shape = w.shape(); wq_shape.back() = w.shape(-1) * bits / 32; @@ -4878,16 +4885,23 @@ array fp_dequantize( throw std::invalid_argument(msg.str()); } + auto inputs = std::vector{w, scales}; + if (global_scale.has_value()) { + inputs.push_back(global_scale.value()); + } + auto fallback = [wshape = std::move(wshape), sshape = std::move(sshape), group_size, bits, out_type, - mode = mode, s](const std::vector& inputs) mutable -> std::vector { auto out = inputs[0]; auto scales = inputs[1]; + array inv_scale_enc = inputs.size() > 2 + ? 
divide(inputs[2], array(448.0f * 6.0f, out_type), s) + : array(1.0f, out_type); if (bits == 4) { auto lut = array( { @@ -4921,27 +4935,16 @@ array fp_dequantize( out = reshape(out, {-1, group_size}, s); scales = reshape(scales, {-1, 1}, s); if (group_size == 16) { - // NVFP4: decode scale from fp8 e4m3 - scales = from_fp8(scales, out_type, s); - // For nvfp4 with global_scale: effective_scale = scale * global_scale / - // (448 * 6) - if (mode == QuantizationMode::Nvfp4 && inputs.size() > 2) { - scales = divide( - multiply(scales, inputs[2], s), array(448.0f * 6.0f, out_type), s); - } + // NVFP4: scales are E4M3, apply inv_scale_enc + scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s); } else { - // MXFP8: decode e8m0 scale + // MXFP: scales are E8M0 (power of 2) scales = subtract(astype(scales, out_type, s), array(127, out_type), s); scales = power(array(2.0f, out_type), scales, s); } - out = reshape(multiply(out, scales, s), wshape, s); - - return {out}; + return {reshape(multiply(out, scales, s), wshape, s)}; }; - auto inputs = std::vector{w, scales}; - if (global_scale.has_value()) { - inputs.push_back(global_scale.value()); - } + if (s.device == Device::gpu) { auto out_shape = w.shape(); out_shape.back() = out_size; diff --git a/mlx/primitives.h b/mlx/primitives.h index bbd91f90a1..1e9b14da66 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1799,6 +1799,8 @@ class Reduce : public UnaryPrimitive { return "Min"; case Max: return "Max"; + case AbsMax: + return "AbsMax"; } return ""; } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 9e58197039..cd0a167968 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -160,6 +160,23 @@ def test_nvfp4_quantize_dequantize(self): w_hat = mx.dequantize(w_q, scales, mode="nvfp4") self.assertTrue(mx.all(w_hat == 0)) + # Test nvfp4 quantize/dequantize with tensor-scale global_scale + if not mx.metal.is_available(): + global_scale = mx.abs_max(w).astype(mx.float32) + else: + global_scale = mx.array(1.0, dtype=mx.float32) + + w_q, scales = mx.quantize(w, mode="nvfp4", global_scale=global_scale) + w_hat = mx.dequantize( + w_q, scales, group_size=16, bits=4, mode="nvfp4", global_scale=global_scale + ) + # self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) + if (w - w_hat).abs().max() > 1e-5: + import pdb + + pdb.set_trace() + print("Max error with global scale:", (w - w_hat).abs().max().item()) + def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) From 306acd06ecda51577f2d265c685dabbc7dd077bd Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 22:39:05 +0100 Subject: [PATCH 16/34] fix docs, remove the diff --- .../actions/build-macos-release/action.yml | 2 ++ .github/actions/build-macos/action.yml | 10 +++---- .github/workflows/release.yml | 5 ++++ .gitignore | 2 +- examples/python/qqmm.py | 8 ++---- mlx/backend/cuda/quantized/qqmm.cpp | 5 ++-- mlx/ops.cpp | 2 -- mlx/primitives.cpp | 3 +-- python/mlx/nn/layers/quantized.py | 26 +++++++++++++------ python/src/ops.cpp | 9 ++++--- python/tests/test_quantized.py | 7 +---- 11 files changed, 41 insertions(+), 38 deletions(-) diff --git a/.github/actions/build-macos-release/action.yml b/.github/actions/build-macos-release/action.yml index 5fa98bcf8a..93e6dd943f 100644 --- a/.github/actions/build-macos-release/action.yml +++ b/.github/actions/build-macos-release/action.yml @@ -18,6 +18,7 @@ runs: - name: Build Python package shell: bash -l {0} env: + DEVELOPER_DIR: 
/Applications/Xcode-latest.app MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} run: | pip install build @@ -28,6 +29,7 @@ runs: if: ${{ inputs.build-backend }} shell: bash -l {0} env: + DEVELOPER_DIR: /Applications/Xcode-latest.app MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} run: | python setup.py clean --all diff --git a/.github/actions/build-macos/action.yml b/.github/actions/build-macos/action.yml index 067fb8d5a5..f04ed3d63f 100644 --- a/.github/actions/build-macos/action.yml +++ b/.github/actions/build-macos/action.yml @@ -17,15 +17,15 @@ runs: - name: Install tests dependencies shell: bash -l {0} run: | - pip install numpy torch tensorflow unittest-xml-reporting + pip install numpy torch tensorflow - name: Run Python tests shell: bash -l {0} env: LOW_MEMORY: 1 run: | - DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu - DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu + DEVICE=cpu python -m unittest discover -v python/tests + DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m unittest discover -v python/tests mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi @@ -77,6 +77,4 @@ runs: run: | CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ pip install -e . -v - python -m xmlrunner discover \ - -v python/tests \ - -o test-results/gpu_jit + python -m unittest discover -v python/tests diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3e62633a40..9b9420a50a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -110,6 +110,11 @@ jobs: with: macos-target: 15.0 build-backend: ${{ matrix.python-version == '3.10' }} + - name: Build macOS 26 package + uses: ./.github/actions/build-macos-release + with: + macos-target: 26.0 + build-backend: ${{ matrix.python-version == '3.10' }} - name: Upload MLX artifacts uses: actions/upload-artifact@v6 with: diff --git a/.gitignore b/.gitignore index ff4d56f592..43629548db 100644 --- a/.gitignore +++ b/.gitignore @@ -85,4 +85,4 @@ build/ .DS_Store # Jetbrains -.cache \ No newline at end of file +.cache diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index f0745f19e1..6f1162a830 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -38,12 +38,8 @@ def test_qqmm(): for dtype in dtypes: x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) - x_amax = ( - mx.abs(x).max().astype(mx.float32) if group_size == 16 else None - ) - w_amax = ( - mx.abs(w).max().astype(mx.float32) if group_size == 16 else None - ) + x_amax = mx.abs_max(x).astype(mx.float32) if group_size == 16 else None + w_amax = mx.abs_max(w).astype(mx.float32) if group_size == 16 else None w_q, scales_w = mx.quantize( w, group_size, bits, mode=mode, global_scale=w_amax ) diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index b3a9cbab83..fae845044c 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -39,17 +39,18 @@ array pad_and_swizzle_scales( // Compute padded dimensions for full tiles (128 rows × 4 cols) auto [pad_outer, pad_inner] = get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); - // 
cuBLAS requirements for scale factor layout: // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) // 2. Out-of-bounds values must be filled with zeros // 3. Starting addresses must be 16-byte aligned // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + // Note: cu::malloc_async already provides 256-byte alignment array scale_tiled( cu::malloc_async(pad_outer * pad_inner, encoder), Shape{pad_outer, pad_inner}, scale.dtype()); swizzle_scales(scale, scale_tiled, encoder, s); + encoder.add_temporary(scale_tiled); return scale_tiled; } @@ -162,8 +163,6 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("QQMatmul::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - - // Check compute capability (requires Blackwell or newer) auto& device = encoder.device(); int cc = device.compute_capability_major() * 100 + device.compute_capability_minor() * 10; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 0d34698a81..662048943a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4302,7 +4302,6 @@ array quantized_matmul( if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } - // TODO: add global scale auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; return array( @@ -6206,5 +6205,4 @@ array contiguous( std::make_shared(to_stream(s), allow_col_major), {a}); } - } // namespace mlx::core \ No newline at end of file diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 994ca056d8..75611344ac 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3523,7 +3523,7 @@ std::vector QQMatmul::vjp( vjps.push_back(qqmm( cotan, // M X N swapaxes(primals[1], -1, -2, s), // assuming that w is 2D - std::nullopt, + {}, group_size_, bits_, qmode, @@ -3664,7 +3664,6 @@ std::vector GatherQMM::vjp( group_size_, bits_, quantization_mode_to_string(mode_), - {}, // placeholder for amax std::nullopt, stream()), -1, diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 61a5bc306a..088bb75205 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -320,6 +320,7 @@ def __init__( group_size: int = None, bits: int = None, mode: str = "nvfp4", + use_global_scale: bool = True, ): super().__init__() @@ -327,6 +328,11 @@ def __init__( self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits) self.mode = mode + if self.mode != "nvfp4" and use_global_scale: + raise ValueError( + "Global scale can only be used with 'nvfp4' quantization mode." 
+ ) + scale = math.sqrt(1 / input_dims) self.weight = mx.random.uniform( low=-scale, @@ -334,7 +340,7 @@ def __init__( shape=(output_dims, input_dims), ) self._quantized = False - self._use_global_scale = self.mode == "nvfp4" + self._use_global_scale = use_global_scale def _extra_repr(self): out_dims, in_dims = self.weight.shape @@ -347,8 +353,11 @@ def _extra_repr(self): def quantize(self): if not self._quantized: + self.global_scale_w = ( - (self.weight).abs().max() if self._use_global_scale else None + mx.abs_max(self.weight).astype(mx.float32) + if self._use_global_scale + else None ) self.weight, self.scales = mx.quantize( self.weight, @@ -370,7 +379,8 @@ def dequantize(self): global_scale=self.global_scale_w, ) del self.scales - del self.global_scale_w + if self._use_global_scale: + del self.global_scale_w self._quantized = False def _set_training_mode(self, mode: bool): @@ -382,15 +392,15 @@ def _set_training_mode(self, mode: bool): self.quantize() def __call__(self, x): - # TODO: In the future we can implement different policies for amax update - # for the activations as well as for the weights - # (for example for the weights it can be ema ) + global_scale_w = ( - getattr(self, "global_scale_w", mx.abs_max(self.weight)) + getattr(self, "global_scale_w", mx.abs_max(self.weight).astype(mx.float32)) if self._use_global_scale else None ) - global_scale_x = mx.abs_max(x) if self._use_global_scale else None + global_scale_x = ( + mx.abs_max(x).astype(mx.float32) if self._use_global_scale else None + ) x = mx.qqmm( x, self["weight"], diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 83319dab63..ee71480332 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4301,7 +4301,7 @@ void init_ops(nb::module_& m) { :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. global_scale (array, optional): The per-input float32 scale used for - ``"nvfp4"`` quantization. Default: ``None``. + ``"nvfp4"`` quantization if provided. Default: ``None``. Returns: tuple: A tuple with either two or three elements containing: @@ -4392,7 +4392,7 @@ void init_ops(nb::module_& m) { ``w`` in the quantized array. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. global_scale (array, optional): The per-input float32 scale used for - ``"nvfp4"`` quantization. Default: ``None``. + ``"nvfp4"`` quantization if provided. Default: ``None``. dtype (Dtype, optional): The data type of the dequantized output. If ``None`` the return type is inferred from the scales and biases when possible and otherwise defaults to ``bfloat16``. @@ -5507,6 +5507,7 @@ void init_ops(nb::module_& m) { If ``x`` and `w`` are not quantized, their data types must be ``float32``, ``float16``, or ``bfloat16``. If ``w`` is quantized, it must be packed in unsigned integers. + ``global_scale_x`` and ``global_scale_w`` are only used for ``nvfp4`` quantization. Args: x (array): Input array. @@ -5523,9 +5524,9 @@ void init_ops(nb::module_& m) { Supported modes are ``nvfp4`` and ``mxfp8``. See the :ref:`table of quantization modes ` for details. global_scale (array, optional): The per-input float32 scale used for x - ``"nvfp4"`` quantization. Default: ``None``. + ``"nvfp4"`` quantization if provided. Default: ``None``. global_scale_w (array, optional): The per-input float32 scale used for w - ``"nvfp4"`` quantization. Default: ``None``. + ``"nvfp4"`` quantization if provided. Default: ``None``. 
Returns: array: The result of the multiplication of quantized ``x`` with quantized ``w``. needed). diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index cd0a167968..6b3c832d86 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -170,12 +170,7 @@ def test_nvfp4_quantize_dequantize(self): w_hat = mx.dequantize( w_q, scales, group_size=16, bits=4, mode="nvfp4", global_scale=global_scale ) - # self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) - if (w - w_hat).abs().max() > 1e-5: - import pdb - - pdb.set_trace() - print("Max error with global scale:", (w - w_hat).abs().max().item()) + self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) def test_qmm(self): key = mx.random.key(0) From 7492841d43e9b63fe787d4a6b8627d70751bdd55 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 22:45:49 +0100 Subject: [PATCH 17/34] fix docs, delete debuging print --- examples/python/qqmm.py | 3 --- mlx/backend/cuda/quantized/cublas_qqmm.cpp | 3 --- mlx/backend/cuda/quantized/qqmm.cpp | 2 +- mlx/ops.cpp | 1 + 4 files changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index 6f1162a830..5e5d6bee44 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -81,9 +81,6 @@ def test_qqmm(): ulp = ulp_bf16_at(y_hat) error = (y_q - y_hat).abs() if not (mx.logical_or(error < 1e-3, error <= ulp).all()): - import pdb - - pdb.set_trace() raise AssertionError( f"qqmm test failed for shape {(M, N, K)}, " f"group_size={group_size}, bits={bits}, " diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index 41446c6f98..f81b6455c5 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -137,9 +137,6 @@ CublasQQMM::CublasQQMM( batch_count, c_batch_stride); } -// Supported overloads: -// alpha float -// alpha device ptr void CublasQQMM::run( cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index fae845044c..14ac2535e3 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -168,7 +168,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { device.compute_capability_minor() * 10; if (cc < 1000) { throw std::runtime_error( - "[QQMatmul::eval_gpu] QQMM requires compute capability 10.0+"); + "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); } // - 2 inputs: x, w (non-quantized w) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 662048943a..b13b7ba7fa 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -6205,4 +6205,5 @@ array contiguous( std::make_shared(to_stream(s), allow_col_major), {a}); } + } // namespace mlx::core \ No newline at end of file From f49abe5797395352b907c9e877c0a2a339b9eca8 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 22:58:29 +0100 Subject: [PATCH 18/34] reverted the example --- examples/python/qqmm.py | 45 +++++------------------------------------ 1 file changed, 5 insertions(+), 40 deletions(-) diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index 5e5d6bee44..5be7eae2f3 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -38,14 +38,7 @@ def test_qqmm(): for dtype in dtypes: x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) - x_amax = mx.abs_max(x).astype(mx.float32) if group_size 
== 16 else None - w_amax = mx.abs_max(w).astype(mx.float32) if group_size == 16 else None - w_q, scales_w = mx.quantize( - w, group_size, bits, mode=mode, global_scale=w_amax - ) - x_q, scales_x = mx.quantize( - x, group_size, bits, mode=mode, global_scale=x_amax - ) + w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) w_dq = mx.dequantize( w_q, scales_w, @@ -53,7 +46,6 @@ def test_qqmm(): bits=bits, mode=mode, dtype=dtype, - global_scale=w_amax, ) y_q = mx.qqmm( x, @@ -62,11 +54,9 @@ def test_qqmm(): group_size=group_size, bits=bits, mode=mode, - global_scale_x=x_amax, - global_scale_w=w_amax, ) x_q, scales_x = mx.quantize( - x, group_size=group_size, bits=bits, mode=mode, global_scale=x_amax + x, group_size=group_size, bits=bits, mode=mode ) x_dq = mx.dequantize( x_q, @@ -74,7 +64,6 @@ def test_qqmm(): group_size=group_size, bits=bits, mode=mode, - global_scale=x_amax, dtype=dtype, ) y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) @@ -100,43 +89,19 @@ def test_qqmm_vjp(): ) x = mx.random.normal(shape=(M, K), key=k1) c = mx.ones(shape=(M, N)) - x_amax = mx.abs(x).max() if tests[0][0] == 16 else None for group_size, mode, bits in tests: w = mx.random.normal(shape=(N, K), key=k2) - x_amax = mx.abs(x).max() if group_size == 16 else None - w_amax = mx.abs(w).max() if group_size == 16 else None - c_amax = mx.abs(c).max() if group_size == 16 else None - def fn(x): - return mx.qqmm( - x, - w, - group_size=group_size, - bits=bits, - mode=mode, - global_scale_x=x_amax, - global_scale_w=w_amax, - ) + return mx.qqmm(x, w, group_size=group_size, bits=bits, mode=mode) _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) w_tq, scales_wt = mx.quantize( - mx.transpose(w), - group_size=group_size, - bits=bits, - mode=mode, - global_scale=w_amax, + mx.transpose(w), group_size=group_size, bits=bits, mode=mode ) expected_out = mx.qqmm( - c, - w_tq, - scales_wt, - group_size=group_size, - bits=bits, - mode=mode, - global_scale_x=c_amax, - global_scale_w=w_amax, + c, w_tq, scales_wt, group_size=group_size, bits=bits, mode=mode ) ulp = ulp_bf16_at(expected_out) error = (vjp_out[0] - expected_out).abs() From 37e57896760d5e4520bc0c3cefacb336d7d7854d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 23:28:33 +0100 Subject: [PATCH 19/34] abs_max -> absmax --- mlx/backend/cuda/quantized/fp_quantize.cu | 4 ++-- mlx/backend/cuda/quantized/quantized_utils.cuh | 2 +- mlx/ops.cpp | 6 +++--- mlx/ops.h | 2 +- mlx/primitives.cpp | 2 +- python/mlx/nn/layers/quantized.py | 6 +++--- python/src/ops.cpp | 6 +++--- python/tests/test_quantized.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index ce2b319331..9074e26df5 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -73,7 +73,7 @@ __global__ void fp_quantize_rowwise( #pragma unroll for (int i = 0; i < group_size; i += 2) { auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - abs_max_x2(amax_2x, amax_2x, pair); + absmax_x2(amax_2x, amax_2x, pair); } scale_dec_b = static_cast( @@ -174,7 +174,7 @@ __global__ void fp_quantize_columnwise( #pragma unroll for (int r = 0; r < group_size; r += 2) { auto pair = Tx2{thread_data[r], thread_data[r + 1]}; - abs_max_x2(amax_2x, amax_2x, pair); + absmax_x2(amax_2x, amax_2x, pair); } float scale_dec_b = max(fabsf(static_cast(amax_2x.x)), diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index 
e589c97057..8a6c6f1da2 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -16,7 +16,7 @@ inline constexpr __device__ short get_bytes_per_pack() { } template -__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) { +__device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) { if constexpr ( (std::is_same::value) || (std::is_same::value)) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9f0ee0e0f2..7d75c1cf9c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2238,14 +2238,14 @@ array min( } // TODO: extend to row_reduce and col_reduce? -array abs_max(const array& a, StreamOrDevice s /* = {}*/) { +array absmax(const array& a, StreamOrDevice s /* = {}*/) { if (a.size() == 0) { throw std::invalid_argument( - "[abs_max] Cannot abs_max reduce zero size array."); + "[absmax] Cannot absmax reduce zero size array."); } if (!issubdtype(a.dtype(), floating)) { throw std::invalid_argument( - "[abs_max] abs_max supported only for floating point types."); + "[absmax] absmax supported only for floating point types."); } auto stream = to_stream(s); if (stream.device != Device::gpu || !cu::is_available()) { diff --git a/mlx/ops.h b/mlx/ops.h index 0176e8a589..caf3e3d834 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -667,7 +667,7 @@ MLX_API array min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); /** The maximum of absolute values of all elements of the array. */ -array abs_max(const array& a, StreamOrDevice s = {}); +array absmax(const array& a, StreamOrDevice s = {}); /** Returns the index of the minimum value in the array. */ MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 75611344ac..8c48c88055 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3511,7 +3511,7 @@ std::vector QQMatmul::vjp( // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); std::optional cotan_amax = - is_nvfp4 ? std::make_optional(abs_max(cotan, s)) : std::nullopt; + is_nvfp4 ? std::make_optional(absmax(cotan, s)) : std::nullopt; auto get_primal_scale = [&](int idx) { return is_nvfp4 ? std::make_optional(primals[idx]) : std::nullopt; diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 088bb75205..04dca73650 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -355,7 +355,7 @@ def quantize(self): if not self._quantized: self.global_scale_w = ( - mx.abs_max(self.weight).astype(mx.float32) + mx.absmax(self.weight).astype(mx.float32) if self._use_global_scale else None ) @@ -394,12 +394,12 @@ def _set_training_mode(self, mode: bool): def __call__(self, x): global_scale_w = ( - getattr(self, "global_scale_w", mx.abs_max(self.weight).astype(mx.float32)) + getattr(self, "global_scale_w", mx.absmax(self.weight).astype(mx.float32)) if self._use_global_scale else None ) global_scale_x = ( - mx.abs_max(x).astype(mx.float32) if self._use_global_scale else None + mx.absmax(x).astype(mx.float32) if self._use_global_scale else None ) x = mx.qqmm( x, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index ee71480332..8fc6cd313d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2403,13 +2403,13 @@ void init_ops(nb::module_& m) { array: The output array with the corresponding axes reduced. 
)pbdoc"); m.def( - "abs_max", - &mx::abs_max, + "absmax", + &mx::absmax, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def abs_max(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + "def absmax(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( The maximum of absolute values of all elements in the array. diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 6b3c832d86..d014e5a7d6 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -162,7 +162,7 @@ def test_nvfp4_quantize_dequantize(self): # Test nvfp4 quantize/dequantize with tensor-scale global_scale if not mx.metal.is_available(): - global_scale = mx.abs_max(w).astype(mx.float32) + global_scale = mx.absmax(w).astype(mx.float32) else: global_scale = mx.array(1.0, dtype=mx.float32) From 507c94e919def90da4be3a11e47df3952dd195ef Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 23 Jan 2026 23:51:06 +0100 Subject: [PATCH 20/34] fix --- mlx/primitives.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8c48c88055..47841eddef 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3665,6 +3665,7 @@ std::vector GatherQMM::vjp( bits_, quantization_mode_to_string(mode_), std::nullopt, + std::nullopt, // amax placeholder stream()), -1, {-1, group_size_}, From 858fe00f78f152c58d05ae6d0e7881f35107e401 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 24 Jan 2026 00:05:16 +0100 Subject: [PATCH 21/34] fix test, force flobal scale only on cuda --- mlx/ops.cpp | 14 +++++++++++++- python/tests/test_quantized.py | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7d75c1cf9c..cc3d515be7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4729,7 +4729,12 @@ std::vector quantize( << " matrix has shape " << w.shape(); throw std::invalid_argument(msg.str()); } - + if (s.device == Device::gpu && metal::is_available() && + global_scale.has_value()) { + std::ostringstream msg; + msg << "[quantize] Global scale is not supported on the Metal backend."; + throw std::invalid_argument(msg.str()); + } validate_global_scale("quantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); @@ -4991,6 +4996,13 @@ array dequantize( << "but it has only " << w.ndim() << "."; throw std::invalid_argument(msg.str()); } + if (global_scale.has_value()) { + if (s.device == Device::gpu && metal::is_available()) { + std::ostringstream msg; + msg << "[dequantize] Global scale is not supported on the Metal backend."; + throw std::invalid_argument(msg.str()); + } + } validate_global_scale("dequantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index d014e5a7d6..2df23aaed7 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -164,7 +164,7 @@ def test_nvfp4_quantize_dequantize(self): if not mx.metal.is_available(): global_scale = mx.absmax(w).astype(mx.float32) else: - global_scale = mx.array(1.0, dtype=mx.float32) + global_scale = None w_q, scales = mx.quantize(w, mode="nvfp4", global_scale=global_scale) w_hat = mx.dequantize( From 5fdffe49b8c2659097fa4d829d9d65e90764fed9 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 24 Jan 2026 00:25:24 +0100 Subject: [PATCH 22/34] fix stream --- mlx/ops.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 
deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index cc3d515be7..beb4e5aff3 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -10,6 +10,7 @@ #include #include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/metal/metal.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/primitives.h" @@ -4729,7 +4730,7 @@ std::vector quantize( << " matrix has shape " << w.shape(); throw std::invalid_argument(msg.str()); } - if (s.device == Device::gpu && metal::is_available() && + if (to_stream(s).device == Device::gpu && metal::is_available() && global_scale.has_value()) { std::ostringstream msg; msg << "[quantize] Global scale is not supported on the Metal backend."; @@ -4997,7 +4998,7 @@ array dequantize( throw std::invalid_argument(msg.str()); } if (global_scale.has_value()) { - if (s.device == Device::gpu && metal::is_available()) { + if (to_stream(s).device == Device::gpu && metal::is_available()) { std::ostringstream msg; msg << "[dequantize] Global scale is not supported on the Metal backend."; throw std::invalid_argument(msg.str()); From 1cc13ba28bec73010d8ee26947d159291f028637 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 24 Jan 2026 01:00:38 +0100 Subject: [PATCH 23/34] made AbsMax the same structure as Max --- mlx/backend/cuda/reduce/all_reduce.cu | 4 ++-- mlx/backend/cuda/reduce/reduce_ops.cuh | 26 ++------------------------ 2 files changed, 4 insertions(+), 26 deletions(-) diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 1ab79d1f0a..71366c4a24 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -14,11 +14,11 @@ namespace cu { namespace cg = cooperative_groups; template -__device__ __forceinline__ T absmax(T x) { +__device__ __forceinline__ auto absmax(T x) { if constexpr (cuda::std::is_unsigned_v) { return x; // No-op for unsigned types } else { - return abs(x); // Uses cu::abs for half types, ::abs for others + return cuda::std::abs(x); } } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 953dd8ebf0..53c669cc8f 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -114,30 +114,8 @@ struct Max { } }; -struct AbsMax { - // abs is applied inside all_reduce kernel - template - __device__ __forceinline__ T operator()(T a, T b) { - if constexpr (is_complex_v) { - if (isnan(a.real()) || isnan(a.imag())) { - return a; - } - if (isnan(b.real()) || isnan(b.imag())) { - return b; - } - } else if constexpr (!cuda::std::is_integral_v) { - if (isnan(a) || isnan(b)) { - return cuda::std::numeric_limits::quiet_NaN(); - } - } - return a > b ? a : b; - } - - template - __device__ void atomic_update(T* x, T y) { - atomic_reduce(x, y); - } -}; +// AbsMax reuses Max logic; abs is applied inside all_reduce kernel +struct AbsMax : Max {}; // Traits to get the result type of reduce op. 
template From 05bd4d0b1bca4799d87cea4bd11c953d334e41fa Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 24 Jan 2026 01:11:38 +0100 Subject: [PATCH 24/34] abs_val rename --- mlx/backend/cuda/reduce/all_reduce.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 71366c4a24..3cd779e8e1 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -14,7 +14,7 @@ namespace cu { namespace cg = cooperative_groups; template -__device__ __forceinline__ auto absmax(T x) { +__device__ __forceinline__ auto abs_val(T x) { if constexpr (cuda::std::is_unsigned_v) { return x; // No-op for unsigned types } else { @@ -47,7 +47,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], absmax(cast_to(vals[j]))); + accs[0] = op(accs[0], abs_val(cast_to(vals[j]))); } else { accs[0] = op(accs[0], cast_to(vals[j])); } @@ -59,7 +59,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], absmax(cast_to(vals[i]))); + accs[0] = op(accs[0], abs_val(cast_to(vals[i]))); } else { accs[0] = op(accs[0], cast_to(vals[i])); } From 20480effed3195fe79619edf149c0d960a10a736 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 24 Jan 2026 14:22:31 +0100 Subject: [PATCH 25/34] fix abs type --- mlx/backend/cuda/reduce/all_reduce.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 3cd779e8e1..1cac7a37d8 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -14,7 +14,7 @@ namespace cu { namespace cg = cooperative_groups; template -__device__ __forceinline__ auto abs_val(T x) { +__device__ __forceinline__ T abs_val(T x) { if constexpr (cuda::std::is_unsigned_v) { return x; // No-op for unsigned types } else { From 9f9aabd4b79208385ad5a4e438a38cb20720f658 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sat, 24 Jan 2026 23:32:32 +0100 Subject: [PATCH 26/34] fix fp type for vjp --- mlx/primitives.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 47841eddef..844f58f252 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3511,7 +3511,7 @@ std::vector QQMatmul::vjp( // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); std::optional cotan_amax = - is_nvfp4 ? std::make_optional(absmax(cotan, s)) : std::nullopt; + is_nvfp4 ? std::make_optional(astype(absmax(cotan, s), float32)) : std::nullopt; auto get_primal_scale = [&](int idx) { return is_nvfp4 ? 
std::make_optional(primals[idx]) : std::nullopt; From d2dc3103c710efa4c7e58fc6b9ec8fac86c3dd90 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sun, 25 Jan 2026 23:06:47 +0100 Subject: [PATCH 27/34] decrease block size because of the register pressure --- mlx/backend/cuda/quantized/fp_quantize.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 9074e26df5..e8d1a39112 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -133,7 +133,7 @@ __global__ void fp_quantize_columnwise( auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); - constexpr int BLOCK_X = 32; + constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; constexpr int elem_per_byte = (bits == 8) ? 1 : 2; constexpr int bytes_per_group = group_size / elem_per_byte; @@ -292,7 +292,7 @@ __global__ void fp_dequantize( inline std::tuple get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { - constexpr int BLOCK_X = 32; + constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; int cols_per_block = BLOCK_Y * group_size; From 79d93e6a31bae32a72270bb2f57e5f5bedd47a6b Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 3 Feb 2026 00:26:51 +0100 Subject: [PATCH 28/34] drop absmax --- mlx/backend/cuda/reduce/all_reduce.cu | 21 ++------------------- mlx/backend/cuda/reduce/reduce.cuh | 2 -- mlx/backend/cuda/reduce/reduce_ops.cuh | 15 --------------- mlx/ops.cpp | 23 ----------------------- mlx/ops.h | 3 --- mlx/primitives.cpp | 5 +++-- mlx/primitives.h | 4 +--- python/src/ops.cpp | 17 ----------------- python/tests/test_quantized.py | 2 +- 9 files changed, 7 insertions(+), 85 deletions(-) diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 794157be19..962e80d4f2 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -13,15 +13,6 @@ namespace cu { namespace cg = cooperative_groups; -template -__device__ __forceinline__ T abs_val(T x) { - if constexpr (cuda::std::is_unsigned_v) { - return x; // No-op for unsigned types - } else { - return cuda::std::abs(x); - } -} - template __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { // TODO: Process multiple "rows" in each thread @@ -46,11 +37,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { for (; i + block.size() * N <= check; i += block.size() * N) { cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals); for (int j = 0; j < N; j++) { - if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], abs_val(cast_to(vals[j]))); - } else { - accs[0] = op(accs[0], cast_to(vals[j])); - } + accs[0] = op(accs[0], cast_to(vals[j])); } } @@ -58,11 +45,7 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { cub::LoadDirectBlocked( block.thread_rank(), in + i, vals, check - i, cast_to(init)); for (int i = 0; i < N; i++) { - if constexpr (cuda::std::is_same_v) { - accs[0] = op(accs[0], abs_val(cast_to(vals[i]))); - } else { - accs[0] = op(accs[0], cast_to(vals[i])); - } + accs[0] = op(accs[0], cast_to(vals[i])); } } diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index 947e8b36dc..02e495594a 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -35,8 +35,6 @@ 
void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Min) { f(type_identity{}); - } else if (reduce_type == Reduce::ReduceType::AbsMax) { - f(type_identity{}); } else { throw std::invalid_argument("Unknown reduce type."); } diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh index 53c669cc8f..6c6b1827ce 100644 --- a/mlx/backend/cuda/reduce/reduce_ops.cuh +++ b/mlx/backend/cuda/reduce/reduce_ops.cuh @@ -114,9 +114,6 @@ struct Max { } }; -// AbsMax reuses Max logic; abs is applied inside all_reduce kernel -struct AbsMax : Max {}; - // Traits to get the result type of reduce op. template struct ReduceResult; @@ -157,11 +154,6 @@ struct ReduceResult { using type = T; }; -template -struct ReduceResult { - using type = T; -}; - // Traits to get the init value of reduce op. template struct ReduceInit; @@ -216,11 +208,4 @@ struct ReduceInit { } }; -template -struct ReduceInit { - static constexpr __host__ __device__ auto value() { - return typename ReduceResult::type(0); // abs values are >= 0 - } -}; - } // namespace mlx::core::cu diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c6c2b894ab..a0e1636e33 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2238,29 +2238,6 @@ array min( return min(a, std::vector{axis}, keepdims, s); } -// TODO: extend to row_reduce and col_reduce? -array absmax(const array& a, StreamOrDevice s /* = {}*/) { - if (a.size() == 0) { - throw std::invalid_argument( - "[absmax] Cannot absmax reduce zero size array."); - } - if (!issubdtype(a.dtype(), floating)) { - throw std::invalid_argument( - "[absmax] absmax supported only for floating point types."); - } - auto stream = to_stream(s); - if (stream.device != Device::gpu || !cu::is_available()) { - return max(abs(a, s), false, s); - } - std::vector axes(a.ndim()); - std::iota(axes.begin(), axes.end(), 0); - return array( - {}, - a.dtype(), - std::make_shared(stream, Reduce::AbsMax, axes), - {a}); -} - array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { auto result = argmin(flatten(a, s), 0, true, s); if (keepdims) { diff --git a/mlx/ops.h b/mlx/ops.h index caf3e3d834..cc9db8aedb 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -666,9 +666,6 @@ min(const array& a, MLX_API array min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}); -/** The maximum of absolute values of all elements of the array. */ -array absmax(const array& a, StreamOrDevice s = {}); - /** Returns the index of the minimum value in the array. */ MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); inline array argmin(const array& a, StreamOrDevice s = {}) { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 844f58f252..e3ff52e7c7 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3510,8 +3510,9 @@ std::vector QQMatmul::vjp( // primal[0] -- non quantized activations (M, K) // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); - std::optional cotan_amax = - is_nvfp4 ? std::make_optional(astype(absmax(cotan, s), float32)) : std::nullopt; + std::optional cotan_amax = is_nvfp4 + ? std::make_optional(astype(max(abs(cotan, s), s), float32)) + : std::nullopt; auto get_primal_scale = [&](int idx) { return is_nvfp4 ? 
std::make_optional(primals[idx]) : std::nullopt; diff --git a/mlx/primitives.h b/mlx/primitives.h index 6d41ede0ff..4091aafcfb 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1770,7 +1770,7 @@ class Reshape : public UnaryPrimitive { class MLX_API Reduce : public UnaryPrimitive { public: - enum ReduceType { And, Or, Sum, Prod, Min, Max, AbsMax }; + enum ReduceType { And, Or, Sum, Prod, Min, Max }; explicit Reduce( Stream stream, @@ -1800,8 +1800,6 @@ class MLX_API Reduce : public UnaryPrimitive { return "Min"; case Max: return "Max"; - case AbsMax: - return "AbsMax"; } return ""; } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2a036100ad..131f3d12e9 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2406,23 +2406,6 @@ void init_ops(nb::module_& m) { Returns: array: The output array with the corresponding axes reduced. )pbdoc"); - m.def( - "absmax", - &mx::absmax, - nb::arg(), - nb::kw_only(), - "stream"_a = nb::none(), - nb::sig( - "def absmax(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - The maximum of absolute values of all elements in the array. - - Args: - a (array): Input array. - - Returns: - array: A scalar array with the maximum absolute value. - )pbdoc"); m.def( "logcumsumexp", [](const mx::array& a, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 214b84c289..5bdefa2f0c 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -162,7 +162,7 @@ def test_nvfp4_quantize_dequantize(self): # Test nvfp4 quantize/dequantize with tensor-scale global_scale if not mx.metal.is_available(): - global_scale = mx.absmax(w).astype(mx.float32) + global_scale = w.abs().max().astype(mx.float32) else: global_scale = None From d91fd8a6fd679ab9129db4c4d5f7b56c6545246b Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 3 Feb 2026 00:31:02 +0100 Subject: [PATCH 29/34] merge conflict fp-quantize --- mlx/backend/cuda/quantized/fp_quantize.cu | 101 +++++++++++++++++++++- python/mlx/nn/layers/quantized.py | 5 +- 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index e8d1a39112..40dc45f50a 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -32,7 +32,70 @@ struct Dequantize { }; namespace cg = cooperative_groups; -// TODO: global_scale type + +template +__global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; // reserved bits for future use + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + size_t thread_idx = tidx + grid_dim_x * size_t(tidy); + size_t base_idx = thread_idx * group_size; + + if (base_idx >= size) { + return; + } + + auto w_tile = load_vector(w, thread_idx); + float scale = 0.0f; + + Tx2 amax_2x = Tx2{0.0f, 0.0f}; + +#pragma unroll + for (int i = 0; i < group_size; i += 2) { + auto pair = Tx2{w_tile[i], w_tile[i + 1]}; + abs_max_x2(amax_2x, amax_2x, pair); + } + + scale = static_cast( + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y)))); + + scale /= bits == 4 ? 
6.0f : 448.0f; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale); + scale = float(s); + AlignedVector w_hat; + +#pragma unroll + for (int i = 0; i < group_size / 4; i++) { + Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); + float4 dq; + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + dq = dequant_fp8(quantized_val); + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + dq = dequant_fp4(quantized_val); + } + w_hat[i * 4] = static_cast(dq.x * scale); + w_hat[i * 4 + 1] = static_cast(dq.y * scale); + w_hat[i * 4 + 2] = static_cast(dq.z * scale); + w_hat[i * 4 + 3] = static_cast(dq.w * scale); + } + store_vector(out, thread_idx, w_hat); +} + template __global__ void fp_quantize_rowwise( T* w, @@ -277,7 +340,7 @@ __global__ void fp_dequantize( out += oindex; - uint val = w[offset]; + uint32_t val = w[offset]; #pragma clang loop unroll(full) for (int i = 0; i < pack_factor; i++) { uint8_t d; @@ -310,6 +373,40 @@ get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { } // namespace cu +void fp_quantize_dequantize( + const array& w, + array& what, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + enc.set_output_array(what); + dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize_dequantize; + if (bits == 8) { + kernel = cu::fp_quantize_dequantize; + } else if (group_size == 16) { + kernel = cu::fp_quantize_dequantize; + } + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = + get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); + + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(what), + w.size()); + } + }); +} + void fp_quantize( const array& w, array& wq, diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 04dca73650..f72f170aaf 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -295,9 +295,8 @@ class QQLinear(Module): Compared to the :class:`mlx.nn.QuantizedLinear` layer, this layer quantizes the input as well and includes weights in gradient computations. - :obj:`QQLinear` also provides: - - the class method :meth:`from_linear` to convert :class:`mlx.nn.Linear` - layers to :obj:`QQLinear` layers. + :obj:`QQLinear` also provides the class method :meth:`from_linear` to + convert :class:`mlx.nn.Linear` layers to :obj:`QQLinear` layers. Note: This layer does not support a bias term yet. 
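Reviewer note on the fused fp_quantize_dequantize kernel added above: it performs fake quantization, where each group's absolute maximum gives a per-block decode scale (amax / 6 for FP4, amax / 448 for FP8), the block is quantized with the reciprocal of that scale and immediately dequantized. The nvfp4 tensor-scale path of this series additionally folds in a global encode scale of (448 * 6) / tensor_amax. The following is a minimal Python sketch of the FP4 round trip under those formulas only; the nearest-value rounding and all names are illustrative, not the kernel's conversion intrinsics, and the FP8 cast of the block scale is skipped.

F8E4M3_MAX = 448.0
F4E2M1_MAX = 6.0
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # representable |E2M1| magnitudes

def to_fp4(x):
    # Round to the nearest representable E2M1 magnitude, keeping the sign.
    m = min(abs(x), F4E2M1_MAX)
    q = min(FP4_VALUES, key=lambda v: abs(v - m))
    return q if x >= 0.0 else -q

def fake_quantize_block(block, tensor_amax=None):
    # Global encode scale; identity when no tensor-level amax is supplied.
    scale_enc = 1.0 if tensor_amax is None else (F8E4M3_MAX * F4E2M1_MAX) / tensor_amax
    block_amax = max(abs(v) for v in block)
    if block_amax == 0.0:
        return [0.0] * len(block)
    scale_dec_b = (block_amax / F4E2M1_MAX) * scale_enc  # per-block decode scale (stored as FP8 in the kernel)
    scale_enc_b = scale_enc / scale_dec_b                # per-block encode scale
    scale_dec = scale_dec_b / scale_enc                  # decode scale with the global factor undone
    return [to_fp4(v * scale_enc_b) * scale_dec for v in block]

# Values already on the FP4 grid round-trip exactly.
print(fake_quantize_block([0.5, -3.0, 6.0, 1.5], tensor_amax=6.0))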
From 5cbf48fcf2d02f25c32d541ced51ffce2bbd27c9 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 3 Feb 2026 02:06:28 +0100 Subject: [PATCH 30/34] add scale to fp_quantiz-dequantize, fix merge conflicts, refactor --- mlx/backend/cuda/quantized/fp_quantize.cu | 42 +++++--- mlx/backend/cuda/quantized/qqmm.cpp | 118 ++++------------------ mlx/backend/cuda/quantized/qqmm_impl.cpp | 25 +++-- mlx/backend/cuda/quantized/qqmm_impl.h | 17 +++- mlx/backend/cuda/quantized/qqmm_utils.h | 23 +++++ mlx/backend/cuda/quantized/quantized.h | 1 + 6 files changed, 100 insertions(+), 126 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 40dc45f50a..e2d1af5ae2 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -34,7 +34,13 @@ struct Dequantize { namespace cg = cooperative_groups; template -__global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { +__global__ void +fp_quantize_dequantize(T* w, T* out, size_t size, float* global_scale = nullptr) { + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; + using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use @@ -53,26 +59,28 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { } auto w_tile = load_vector(w, thread_idx); - float scale = 0.0f; + float scale_dec_b = 0.0f; Tx2 amax_2x = Tx2{0.0f, 0.0f}; #pragma unroll for (int i = 0; i < group_size; i += 2) { auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - abs_max_x2(amax_2x, amax_2x, pair); + absmax_x2(amax_2x, amax_2x, pair); } - scale = static_cast( + scale_dec_b = static_cast( max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y)))); - scale /= bits == 4 ? 6.0f : 448.0f; + scale_dec_b /= bits == 4 ? 
F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t; - auto s = ScaleType(scale); - scale = float(s); + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); + float scale_dec = float(s) * inv_scale_enc; AlignedVector w_hat; #pragma unroll @@ -81,17 +89,17 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { float4 dq; if constexpr (bits == 8) { uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); dq = dequant_fp8(quantized_val); } else { uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); dq = dequant_fp4(quantized_val); } - w_hat[i * 4] = static_cast(dq.x * scale); - w_hat[i * 4 + 1] = static_cast(dq.y * scale); - w_hat[i * 4 + 2] = static_cast(dq.z * scale); - w_hat[i * 4 + 3] = static_cast(dq.w * scale); + w_hat[i * 4] = static_cast(dq.x * scale_dec); + w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); + w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); + w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); } store_vector(out, thread_idx, w_hat); } @@ -378,9 +386,13 @@ void fp_quantize_dequantize( array& what, int group_size, int bits, + const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); + } enc.set_output_array(what); dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) { using T = cuda_type_t; @@ -402,7 +414,9 @@ void fp_quantize_dequantize( 0, gpu_ptr(w), gpu_ptr(what), - w.size()); + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); } }); } diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 18f58bb391..f4cf03fd3a 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -14,48 +14,6 @@ namespace mlx::core { namespace { -struct GemmScalars { - std::optional alpha_device; - std::optional beta_device; - - bool uses_device_pointers() const { - return alpha_device.has_value(); - } -}; - -inline array -ensure_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) { - if (x.flags().row_contiguous || x.flags().col_contiguous) { - return x; - } - array x_copy = contiguous_copy_gpu(x, s); - enc.add_temporary(x_copy); - return x_copy; -} - -array pad_and_swizzle_scales( - const array& scale, - cu::CommandEncoder& encoder, - const Stream& s) { - // Compute padded dimensions for full tiles (128 rows × 4 cols) - auto [pad_outer, pad_inner] = - get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); - // cuBLAS requirements for scale factor layout: - // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) - // 2. Out-of-bounds values must be filled with zeros - // 3. 
Starting addresses must be 16-byte aligned - // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout - // Note: cu::malloc_async already provides 256-byte alignment - array scale_tiled( - cu::malloc_async(pad_outer * pad_inner, encoder), - Shape{pad_outer, pad_inner}, - scale.dtype()); - swizzle_scales(scale, scale_tiled, encoder, s); - - encoder.add_temporary(scale_tiled); - return scale_tiled; -} - std::tuple quantize_input( const array& input, cu::CommandEncoder& encoder, @@ -109,71 +67,32 @@ GemmScalars create_nvfp4_scalars( return {alpha, beta}; } -void run_qqmm( - cu::CommandEncoder& encoder, - int M, - int N, - int K, - bool a_transposed, - int64_t lda, - bool b_transposed, - int64_t ldb, - array& out, - const array& a, - const array& b, - const array& a_scale, - const array& b_scale, - QuantizationMode mode, - const GemmScalars& scalars) { - std::string qmode = quantization_mode_to_string(mode); - - CublasQQMM qqmm( - encoder.device(), - a_transposed, - M, - K, - lda, - b_transposed, - K, - N, - ldb, - 1, // batch_count - 0, // a_batch_stride - 0, // b_batch_stride - out.dtype(), - qmode); - - if (scalars.uses_device_pointers()) { - qqmm.run( - encoder, - out, - a, - b, - a_scale, - b_scale, - *scalars.alpha_device, - *scalars.beta_device); - } else { - qqmm.run(encoder, out, a, b, a_scale, b_scale); - } -} - } // namespace void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { - assert( - (inputs.size() == 3 && inputs[1].dtype() == uint32) || - (inputs.size() == 2)); nvtx3::scoped_range r("QQMatmul::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); auto& device = encoder.device(); - bool w_quantized = (inputs[1].dtype() == uint32); + int base_size = w_quantized ? 3 : 2; + + assert( + inputs.size() == base_size || + (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2)); + if (w_quantized && inputs[0].shape(-2) == 1) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); + // For nvfp4, get global scale for x from inputs if present + bool has_global_scales_m1 = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; + std::optional global_scale_x_m1 = std::nullopt; + if (has_global_scales_m1) { + global_scale_x_m1 = inputs[inputs.size() - 2]; + } + bool donate_x = inputs[0].is_donatable(); array x = ensure_row_contiguous(inputs[0], encoder, s); // If x is a copy it should be donatable @@ -184,7 +103,8 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { if (!donate_x) { encoder.add_temporary(xhat); } - fp_quantize_dequantize(x, xhat, group_size_, bits_, encoder, s); + fp_quantize_dequantize( + x, xhat, group_size_, bits_, global_scale_x_m1, encoder, s); // Make sure the last two dims of w and s are contiguous array w = ensure_row_contiguous_matrix(inputs[1], encoder, s); @@ -208,8 +128,6 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { // - 2 inputs: x, w (non-quantized w) // - 3 inputs: x, w, scales_w (quantized w) - bool w_is_quantized = inputs[1].dtype() == uint32; - int base_size = w_is_quantized ? 
3 : 2; // For nvfp4, global scales are optional but must be both present or both // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) @@ -227,7 +145,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { // Quantize inputs (or use pre-quantized) auto [x_q, scale_x_pre] = quantize_input( inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); - auto [w_q, scale_w_pre] = !w_is_quantized + auto [w_q, scale_w_pre] = !w_quantized ? quantize_input( inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w) : std::make_tuple( @@ -254,7 +172,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder); } - run_qqmm( + qqmm_impl( encoder, M, N, diff --git a/mlx/backend/cuda/quantized/qqmm_impl.cpp b/mlx/backend/cuda/quantized/qqmm_impl.cpp index dd9407dcdc..e005d4f2bf 100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.cpp +++ b/mlx/backend/cuda/quantized/qqmm_impl.cpp @@ -1,4 +1,4 @@ -// Copyright © 2026 Apple Inc. +// Copyright © 2025 Apple Inc. #include "mlx/backend/cuda/quantized/qqmm_impl.h" #include "mlx/backend/cuda/quantized/cublas_qqmm.h" @@ -19,15 +19,10 @@ void qqmm_impl( const array& b, const array& a_scale, const array& b_scale, - Dtype out_dtype, QuantizationMode mode, - float alpha) { - // Invoke CublasQQMM + const GemmScalars& scalars) { std::string qmode = quantization_mode_to_string(mode); - // Currently only supports non-batched QQMM operations - // that covers all use cases for training, we will just collapse (batch, - // seq_len) into (tokens) CublasQQMM qqmm( encoder.device(), a_transposed, @@ -41,10 +36,22 @@ void qqmm_impl( 1, // batch_count 0, // a_batch_stride 0, // b_batch_stride - out_dtype, + out.dtype(), qmode); - qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha); + if (scalars.uses_device_pointers()) { + qqmm.run( + encoder, + out, + a, + b, + a_scale, + b_scale, + *scalars.alpha_device, + *scalars.beta_device); + } else { + qqmm.run(encoder, out, a, b, a_scale, b_scale); + } } } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_impl.h b/mlx/backend/cuda/quantized/qqmm_impl.h index 7288e2fd7b..c562bfb186 100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.h +++ b/mlx/backend/cuda/quantized/qqmm_impl.h @@ -1,10 +1,22 @@ -// Copyright © 2026 Apple Inc. +// Copyright © 2025 Apple Inc. 
#pragma once #include "mlx/backend/cuda/device.h" #include "mlx/primitives.h" +#include + namespace mlx::core { + +struct GemmScalars { + std::optional alpha_device; + std::optional beta_device; + + bool uses_device_pointers() const { + return alpha_device.has_value(); + } +}; + void qqmm_impl( cu::CommandEncoder& encoder, int M, @@ -19,8 +31,7 @@ void qqmm_impl( const array& b, const array& a_scale, const array& b_scale, - Dtype out_dtype, QuantizationMode mode, - float alpha = 1.0f); + const GemmScalars& scalars = {}); } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h index e40f09190f..df79882a79 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.h +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -21,6 +21,29 @@ inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { return {padded_rows, padded_cols}; } +inline array pad_and_swizzle_scales( + const array& scale, + cu::CommandEncoder& encoder, + const Stream& s) { + // Compute padded dimensions for full tiles (128 rows × 4 cols) + auto [pad_outer, pad_inner] = + get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + // cuBLAS requirements for scale factor layout: + // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) + // 2. Out-of-bounds values must be filled with zeros + // 3. Starting addresses must be 16-byte aligned + // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + // Note: cu::malloc_async already provides 256-byte alignment + array scale_tiled( + cu::malloc_async(pad_outer * pad_inner, encoder), + Shape{pad_outer, pad_inner}, + scale.dtype()); + swizzle_scales(scale, scale_tiled, encoder, s); + + encoder.add_temporary(scale_tiled); + return scale_tiled; +} + void swizzle_scales( const array& scales, array& scales_tiled, diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h index a7f0062524..f15c0f76e1 100644 --- a/mlx/backend/cuda/quantized/quantized.h +++ b/mlx/backend/cuda/quantized/quantized.h @@ -50,6 +50,7 @@ void fp_quantize_dequantize( array& what, int group_size, int bits, + const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); From da1cacf7a7f4faa9d8b0e306403ed2eaae31594d Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 4 Feb 2026 02:20:13 +0100 Subject: [PATCH 31/34] pre-commit + update a comment --- mlx/backend/cuda/quantized/fp_quantize.cu | 7 +++++-- mlx/backend/cuda/quantized/qqmm_utils.h | 13 ++++++------- python/tests/test_quantized.py | 1 + 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index e2d1af5ae2..d36ae6581e 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -34,8 +34,11 @@ struct Dequantize { namespace cg = cooperative_groups; template -__global__ void -fp_quantize_dequantize(T* w, T* out, size_t size, float* global_scale = nullptr) { +__global__ void fp_quantize_dequantize( + T* w, + T* out, + size_t size, + float* global_scale = nullptr) { const bool use_global_scale = global_scale != nullptr; const float scale_enc = use_global_scale ? 
(F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h index df79882a79..fba9ac9d9e 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.h +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -21,6 +21,12 @@ inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { return {padded_rows, padded_cols}; } +void swizzle_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s); + inline array pad_and_swizzle_scales( const array& scale, cu::CommandEncoder& encoder, @@ -44,15 +50,8 @@ inline array pad_and_swizzle_scales( return scale_tiled; } -void swizzle_scales( - const array& scales, - array& scales_tiled, - cu::CommandEncoder& enc, - const Stream& s); - // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 // Allocate beta zero on device as well - void compute_qqmm_pointers( array& alpha_out, array& beta_out, diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 5bdefa2f0c..ba85eaf38a 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -161,6 +161,7 @@ def test_nvfp4_quantize_dequantize(self): self.assertTrue(mx.all(w_hat == 0)) # Test nvfp4 quantize/dequantize with tensor-scale global_scale + # currently supported only on cpu and cuda if not mx.metal.is_available(): global_scale = w.abs().max().astype(mx.float32) else: From 67385c8907758e64ed6ebbf7fab4716d3547f584 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 4 Feb 2026 02:26:18 +0100 Subject: [PATCH 32/34] revert qq_linear global scale [WIP] --- python/mlx/nn/layers/quantized.py | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index f72f170aaf..15bbfa76dd 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -319,7 +319,6 @@ def __init__( group_size: int = None, bits: int = None, mode: str = "nvfp4", - use_global_scale: bool = True, ): super().__init__() @@ -327,11 +326,6 @@ def __init__( self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits) self.mode = mode - if self.mode != "nvfp4" and use_global_scale: - raise ValueError( - "Global scale can only be used with 'nvfp4' quantization mode." 
- ) - scale = math.sqrt(1 / input_dims) self.weight = mx.random.uniform( low=-scale, @@ -339,7 +333,6 @@ def __init__( shape=(output_dims, input_dims), ) self._quantized = False - self._use_global_scale = use_global_scale def _extra_repr(self): out_dims, in_dims = self.weight.shape @@ -352,18 +345,11 @@ def _extra_repr(self): def quantize(self): if not self._quantized: - - self.global_scale_w = ( - mx.absmax(self.weight).astype(mx.float32) - if self._use_global_scale - else None - ) self.weight, self.scales = mx.quantize( self.weight, self.group_size, self.bits, mode=self.mode, - global_scale=self.global_scale_w, ) self._quantized = True @@ -375,11 +361,8 @@ def dequantize(self): group_size=self.group_size, bits=self.bits, mode=self.mode, - global_scale=self.global_scale_w, ) - del self.scales - if self._use_global_scale: - del self.global_scale_w + self.__delattr__("scales") self._quantized = False def _set_training_mode(self, mode: bool): @@ -391,15 +374,6 @@ def _set_training_mode(self, mode: bool): self.quantize() def __call__(self, x): - - global_scale_w = ( - getattr(self, "global_scale_w", mx.absmax(self.weight).astype(mx.float32)) - if self._use_global_scale - else None - ) - global_scale_x = ( - mx.absmax(x).astype(mx.float32) if self._use_global_scale else None - ) x = mx.qqmm( x, self["weight"], @@ -407,8 +381,6 @@ def __call__(self, x): group_size=self.group_size, bits=self.bits, mode=self.mode, - global_scale_x=global_scale_x, - global_scale_w=global_scale_w, ) return x From ad1fcf1f31d2084ebfcd6bdfeaa862f610275557 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 4 Feb 2026 18:09:54 +0100 Subject: [PATCH 33/34] refactoring, revert block size --- mlx/backend/cuda/quantized/fp_quantize.cu | 4 ++-- mlx/backend/cuda/quantized/qqmm.cpp | 10 +++++----- python/tests/test_quantized.py | 22 ---------------------- 3 files changed, 7 insertions(+), 29 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index d36ae6581e..d1b402f26d 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -207,7 +207,7 @@ __global__ void fp_quantize_columnwise( auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); - constexpr int BLOCK_X = 16; + constexpr int BLOCK_X = 32; constexpr int BLOCK_Y = 32; constexpr int elem_per_byte = (bits == 8) ? 
1 : 2; constexpr int bytes_per_group = group_size / elem_per_byte; @@ -366,7 +366,7 @@ __global__ void fp_dequantize( inline std::tuple get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { - constexpr int BLOCK_X = 16; + constexpr int BLOCK_X = 32; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; int cols_per_block = BLOCK_Y * group_size; diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index f4cf03fd3a..665bdb45b7 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -86,11 +86,11 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); // For nvfp4, get global scale for x from inputs if present - bool has_global_scales_m1 = + bool has_global_scale = mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; - std::optional global_scale_x_m1 = std::nullopt; - if (has_global_scales_m1) { - global_scale_x_m1 = inputs[inputs.size() - 2]; + std::optional global_scale = std::nullopt; + if (has_global_scale) { + global_scale = inputs[inputs.size() - 2]; } bool donate_x = inputs[0].is_donatable(); @@ -104,7 +104,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { encoder.add_temporary(xhat); } fp_quantize_dequantize( - x, xhat, group_size_, bits_, global_scale_x_m1, encoder, s); + x, xhat, group_size_, bits_, global_scale, encoder, s); // Make sure the last two dims of w and s are contiguous array w = ensure_row_contiguous_matrix(inputs[1], encoder, s); diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index ba85eaf38a..d53c5e568b 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -405,28 +405,6 @@ def test_fp_qmv(self): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) - # Test multiple of 16 but not 32 - M = 128 - N = 48 - mode = "nvfp4" - with self.subTest(shape=(B, M, N), mode=mode): - x_shape = (1, N) - w_shape = (M, N) - x = mx.random.normal(shape=x_shape, key=k1) - w = mx.random.normal(shape=w_shape, key=k2) - w_q, scales = mx.quantize(w, mode=mode) - w_hat = mx.dequantize(w_q, scales, mode=mode) - y_q = mx.quantized_matmul( - x, - w_q, - scales, - transpose=True, - mode=mode, - ) - y_hat = x @ mx.swapaxes(w_hat, -1, -2) - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) - def test_qvm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) From b1dcd2ff13cc067383bcaa969200476d6b7930cf Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 4 Feb 2026 18:18:47 +0100 Subject: [PATCH 34/34] revert the year change --- mlx/backend/cuda/quantized/qqmm_impl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/quantized/qqmm_impl.cpp b/mlx/backend/cuda/quantized/qqmm_impl.cpp index e005d4f2bf..d5986e05ea 100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.cpp +++ b/mlx/backend/cuda/quantized/qqmm_impl.cpp @@ -1,4 +1,4 @@ -// Copyright © 2025 Apple Inc. +// Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/quantized/qqmm_impl.h" #include "mlx/backend/cuda/quantized/cublas_qqmm.h"
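For reference, the tensor-scale path exercised by test_nvfp4_quantize_dequantize can be driven from Python roughly as below. This is a usage sketch, not part of the patch: the shapes are arbitrary, and per the checks added in mlx/ops.cpp the global_scale argument is rejected on the Metal backend, so it applies to the CPU/CUDA paths only.

import mlx.core as mx

w = mx.random.normal(shape=(128, 512))
# Tensor-level amax in float32, as in the test.
global_scale = w.abs().max().astype(mx.float32)

w_q, scales = mx.quantize(w, group_size=16, bits=4, mode="nvfp4", global_scale=global_scale)
w_hat = mx.dequantize(
    w_q,
    scales,
    group_size=16,
    bits=4,
    mode="nvfp4",
    global_scale=global_scale,
)
print((w - w_hat).abs().max())  # expected to be small; the test checks rtol=1e-5, atol=1e-5

The matching mx.qqmm keywords appear as global_scale_x / global_scale_w in the example and QQLinear code earlier in the series, while the qqmm docstring lists global_scale for the x operand; the exact keyword names should be taken from whichever form lands in the final revision.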