diff --git a/mlx/backend/cuda/cublas_utils.cpp b/mlx/backend/cuda/cublas_utils.cpp index 1176bd49d2..b214a37eab 100644 --- a/mlx/backend/cuda/cublas_utils.cpp +++ b/mlx/backend/cuda/cublas_utils.cpp @@ -105,13 +105,6 @@ void CublasMatmulBase::init_base( CHECK_CUBLAS_ERROR( cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type)); - int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, - sizeof(int32_t))); - // In cublasLt matrices use column-major layout, while it is possible to use // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias // epilogue does not work with the option. So instead we swap A and B to make diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 78340457a2..a790d946a5 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -73,6 +73,14 @@ CublasGemm::CublasGemm( batch_count, a_batch_stride, b_batch_stride); + + // alpha and beta are both host pointers + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(pointer_mode))); } CublasGemm::CublasGemm( @@ -215,8 +223,8 @@ void CublasGemm::execute( const void* a, const void* b, const void* c, - float alpha /* = 1 */, - float beta /* = 0 */) { + const float alpha /* = 1 */, + const float beta /* = 0 */) { const void* alpha_ptr = α const void* beta_ptr = β complex64_t alpha_c, beta_c; diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.cpp b/mlx/backend/cuda/quantized/cublas_qqmm.cpp index 959ae93520..f81b6455c5 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.cpp +++ b/mlx/backend/cuda/quantized/cublas_qqmm.cpp @@ -13,39 +13,26 @@ namespace mlx::core { namespace { -// Currently cublas supports only mxfp8 and nvfp4 -// quantization modes for block scaled quantization -cudaDataType_t qmode_to_cublas_scale_dtype(std::string mode) { - if (mode == "mxfp8") { - return CUDA_R_8F_UE8M0; - } else if (mode == "nvfp4") { - return CUDA_R_8F_UE4M3; - } else { - throw std::runtime_error( - fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); - } -} - -cudaDataType_t qmode_to_cublas_dtype(std::string mode) { - if (mode == "mxfp8") { - return CUDA_R_8F_E4M3; - } else if (mode == "nvfp4") { - return CUDA_R_4F_E2M1; - } else { - throw std::runtime_error( - fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); - } -} +struct QuantModeConfig { + cudaDataType_t data_type; + cudaDataType_t scale_dtype; + cublasLtMatmulMatrixScale_t scale_mode; +}; -cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string mode) { +QuantModeConfig get_quant_mode_config(const std::string& mode) { if (mode == "mxfp8") { - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + return { + CUDA_R_8F_E4M3, + CUDA_R_8F_UE8M0, + CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0}; } else if (mode == "nvfp4") { - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; - } else { - throw std::runtime_error( - fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); + return { + CUDA_R_4F_E2M1, + CUDA_R_8F_UE4M3, + CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3}; } + throw std::runtime_error( + fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode)); } } // namespace @@ -64,21 +51,21 @@ CublasQQMM::CublasQQMM( int64_t a_batch_stride, int64_t 
b_batch_stride, Dtype out_dtype, - std::string qmode) { + const std::string& qmode) { + auto config = get_quant_mode_config(qmode); + // The compute type must be CUBLAS_COMPUTE_32F. // The scale type must be CUDA_R_32F. cudaDataType_t scale_type = CUDA_R_32F; cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F; cudaDataType_t output_type = cublas_utils::dtype_to_cublas_type(out_dtype, "CublasQQMM"); - cudaDataType_t data_type = qmode_to_cublas_dtype(qmode); - quantization_mode_ = std::string(qmode); init_base( device, scale_type, gemm_compute_type, - data_type, + config.data_type, output_type, a_transposed, a_rows, @@ -92,8 +79,8 @@ CublasQQMM::CublasQQMM( a_batch_stride, b_batch_stride); - a_scale_mode_ = qmode_to_cublas_scale_mode(qmode); - b_scale_mode_ = qmode_to_cublas_scale_mode(qmode); + a_scale_mode_ = config.scale_mode; + b_scale_mode_ = config.scale_mode; CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, @@ -123,7 +110,7 @@ CublasQQMM::CublasQQMM( int64_t b_batch_stride, int64_t c_batch_stride, Dtype out_dtype, - std::string qmode) + const std::string& qmode) : CublasQQMM( device, a_transposed, @@ -158,11 +145,14 @@ void CublasQQMM::run( const array& b, const array& a_scale, const array& b_scale, - float alpha) { + const array& alpha, + const array& beta) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(a_scale); encoder.set_input_array(b_scale); + encoder.set_input_array(alpha); + encoder.set_input_array(beta); encoder.set_output_array(out); execute( @@ -173,19 +163,37 @@ void CublasQQMM::run( gpu_ptr(a_scale), gpu_ptr(b_scale), nullptr, - alpha); + gpu_ptr(alpha), + gpu_ptr(beta)); } -void CublasQQMM::execute( +void CublasQQMM::run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(a_scale); + encoder.set_input_array(b_scale); + encoder.set_output_array(out); + + execute( + encoder, + gpu_ptr(out), + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(a_scale), + gpu_ptr(b_scale), + nullptr); +} + +void CublasQQMM::set_scales_ptrs( cu::CommandEncoder& encoder, - void* out, - const void* a, - const void* b, const void* a_scale, - const void* b_scale, - const void* c, - float alpha /* = 1 */, - float beta /* = 0 */) { + const void* b_scale) { CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( matmul_desc_, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, @@ -196,6 +204,49 @@ void CublasQQMM::execute( CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &a_scale, sizeof(a_scale))); +} + +void CublasQQMM::execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const void* alpha, + const void* beta) { + set_scales_ptrs(encoder, a_scale, b_scale); + // alpha and beta are both should be device pointers for nvfp4 + // by default cublas uses host pointers + // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(pointer_mode))); + execute_matmul(encoder, out, a, b, c, alpha, beta); +} + +void CublasQQMM::execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const float alpha /* = 1 */, + const float beta /* = 0 */) { 
+ set_scales_ptrs(encoder, a_scale, b_scale); + // alpha and beta are both should be host pointers + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( + matmul_desc_, + CUBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, + sizeof(pointer_mode))); const void* alpha_ptr = α const void* beta_ptr = β diff --git a/mlx/backend/cuda/quantized/cublas_qqmm.h b/mlx/backend/cuda/quantized/cublas_qqmm.h index 0a710f6e10..a9095012c5 100644 --- a/mlx/backend/cuda/quantized/cublas_qqmm.h +++ b/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -25,7 +25,7 @@ class CublasQQMM : public CublasMatmulBase { int64_t a_batch_stride, int64_t b_batch_stride, Dtype out_dtype, - std::string quantization_mode); + const std::string& quantization_mode); CublasQQMM( cu::Device& device, @@ -43,7 +43,7 @@ class CublasQQMM : public CublasMatmulBase { int64_t b_batch_stride, int64_t c_batch_stride, Dtype out_dtype, - std::string quantization_mode); + const std::string& quantization_mode); void run( cu::CommandEncoder& encoder, @@ -52,20 +52,33 @@ class CublasQQMM : public CublasMatmulBase { const array& b, const array& a_scale, const array& b_scale, - float alpha = 1.0f); + const array& alpha, + const array& beta); - private: - void run_batched( + void run( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& a_scale, - const array& b_scale, - const Shape& batch_shape, - const Strides& a_batch_strides, - const Strides& b_batch_strides, - float alpha); + const array& b_scale); + + private: + void set_scales_ptrs( + cu::CommandEncoder& encoder, + const void* a_scale, + const void* b_scale); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + const void* alpha, + const void* beta); void execute( cu::CommandEncoder& encoder, @@ -75,10 +88,9 @@ class CublasQQMM : public CublasMatmulBase { const void* a_scale, const void* b_scale, const void* c, - float alpha = 1, - float beta = 0); + const float alpha = 1.0f, + const float beta = 0.0f); - std::string quantization_mode_; cublasLtMatmulMatrixScale_t a_scale_mode_; cublasLtMatmulMatrixScale_t b_scale_mode_; cublasLtMatmulMatrixScale_t c_scale_mode_; diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 3b4e96ef23..d1b402f26d 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -11,6 +11,11 @@ #include #include +#include +#include + +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; namespace mlx::core { namespace cu { @@ -29,7 +34,16 @@ struct Dequantize { namespace cg = cooperative_groups; template -__global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { +__global__ void fp_quantize_dequantize( + T* w, + T* out, + size_t size, + float* global_scale = nullptr) { + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + const float inv_scale_enc = use_global_scale ? 
1.0f / scale_enc : 1.0f; + using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use @@ -48,26 +62,28 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { } auto w_tile = load_vector(w, thread_idx); - float scale = 0.0f; + float scale_dec_b = 0.0f; Tx2 amax_2x = Tx2{0.0f, 0.0f}; #pragma unroll for (int i = 0; i < group_size; i += 2) { auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - abs_max_x2(amax_2x, amax_2x, pair); + absmax_x2(amax_2x, amax_2x, pair); } - scale = static_cast( + scale_dec_b = static_cast( max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y)))); - scale /= bits == 4 ? 6.0f : 448.0f; + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t; - auto s = ScaleType(scale); - scale = float(s); + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); + float scale_dec = float(s) * inv_scale_enc; AlignedVector w_hat; #pragma unroll @@ -76,24 +92,36 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) { float4 dq; if constexpr (bits == 8) { uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); dq = dequant_fp8(quantized_val); } else { uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); dq = dequant_fp4(quantized_val); } - w_hat[i * 4] = static_cast(dq.x * scale); - w_hat[i * 4 + 1] = static_cast(dq.y * scale); - w_hat[i * 4 + 2] = static_cast(dq.z * scale); - w_hat[i * 4 + 3] = static_cast(dq.w * scale); + w_hat[i * 4] = static_cast(dq.x * scale_dec); + w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); + w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); + w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); } store_vector(out, thread_idx, w_hat); } template -__global__ void -fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) { +__global__ void fp_quantize_rowwise( + T* w, + uint8_t* out, + uint8_t* scales, + size_t size, + float* global_scale = nullptr) { + // NVFP4 conversion: + // Global encode scale: (448 × 6) / *global_scale + // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 + // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + using Tx2 = Vector2_t; using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use @@ -112,27 +140,28 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) { } auto w_tile = load_vector(w, thread_idx); - float scale = 0.0f; + float scale_dec_b = 0.0f; Tx2 amax_2x = Tx2{0.0f, 0.0f}; #pragma unroll for (int i = 0; i < group_size; i += 2) { auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - abs_max_x2(amax_2x, amax_2x, pair); + absmax_x2(amax_2x, amax_2x, pair); } - scale = static_cast( + scale_dec_b = static_cast( max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y)))); - scale /= bits == 4 ? 6.0f : 448.0f; + scale_dec_b /= bits == 4 ? 
F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; // Convert to mx scale or nv scale using ScaleType = std::conditional_t; - auto s = ScaleType(scale); + auto s = ScaleType(scale_dec_b); uint8_t q_scale = s.__x; - scale = float(s); + float scale_enc_b = scale_enc / float(s); scales[thread_idx] = q_scale; constexpr int elem_per_byte = bits == 8 ? 1 : 2; @@ -143,11 +172,11 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) { Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); if constexpr (bits == 8) { uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized[i * 4]) = quantized_val; } else { uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized[i * 2]) = quantized_val; } } @@ -161,11 +190,15 @@ __global__ void fp_quantize_columnwise( uint8_t* scales, size_t size, int M, - int K) { + int K, + float* global_scale = nullptr) { // Input: [M, K] with strides [1, M] (M-major) // Quantized output: [M, K/elem_per_byte] row-major (K-major) // Scales: [M, K/group_size] row-major (K-major) // Quantize along K (last dimension, groups of group_size elements) + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; using Tx2 = Vector2_t; using Tx4 = Vector4_t; @@ -215,16 +248,18 @@ __global__ void fp_quantize_columnwise( #pragma unroll for (int r = 0; r < group_size; r += 2) { auto pair = Tx2{thread_data[r], thread_data[r + 1]}; - abs_max_x2(amax_2x, amax_2x, pair); + absmax_x2(amax_2x, amax_2x, pair); } - float scale = + float scale_dec_b = max(fabsf(static_cast(amax_2x.x)), fabsf(static_cast(amax_2x.y))); - scale /= (bits == 4) ? 6.0f : 448.0f; + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale using ScaleType = std::conditional_t; - auto s = ScaleType(scale); - scale = float(s); + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); scales_smem[tidx][tidy] = s.__x; int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; @@ -234,12 +269,12 @@ __global__ void fp_quantize_columnwise( Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); if constexpr (bits == 8) { uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = quantized_val; } else { uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = quantized_val; } @@ -282,8 +317,12 @@ __global__ void fp_quantize_columnwise( } template -__global__ void -fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { +__global__ void fp_dequantize( + const uint8_t* w, + const uint8_t* scales, + T* out, + size_t size, + float* global_scale = nullptr) { auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); auto idx_in_block = cg::this_thread_block().thread_index(); @@ -294,6 +333,10 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; constexpr int pack_factor = bits == 8 ? 
1 : 2; + const bool use_global_scale = global_scale != nullptr; + const float inv_scale_enc = use_mx_scale + ? 1.0f + : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); size_t offset = tidx + grid_dim_x * size_t(tidy); size_t oindex = offset * pack_factor; @@ -304,7 +347,7 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) { size_t gindex = oindex / group_size; using ScaleType = std::conditional_t; - auto scale = float(((ScaleType*)(scales))[gindex]); + auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; out += oindex; @@ -346,9 +389,13 @@ void fp_quantize_dequantize( array& what, int group_size, int bits, + const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); + } enc.set_output_array(what); dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) { using T = cuda_type_t; @@ -370,7 +417,9 @@ void fp_quantize_dequantize( 0, gpu_ptr(w), gpu_ptr(what), - w.size()); + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); } }); } @@ -381,9 +430,13 @@ void fp_quantize( array& scales, int group_size, int bits, + const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { enc.set_input_array(w); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); + } enc.set_output_array(wq); enc.set_output_array(scales); if (w.strides().back() != 1) { @@ -410,7 +463,9 @@ void fp_quantize( gpu_ptr(scales), w.size(), M, - K); + K, + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); @@ -438,7 +493,9 @@ void fp_quantize( gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), - w.size()); + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not quantize input with type float64."); @@ -453,6 +510,7 @@ void fp_dequantize( array& w, int group_size, int bits, + const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { constexpr int uint8_per_uint32 = 4; @@ -465,6 +523,9 @@ void fp_dequantize( enc.set_input_array(wq); enc.set_input_array(scales); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); + } enc.set_output_array(w); dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) { using T = cuda_type_t; @@ -485,7 +546,9 @@ void fp_dequantize( gpu_ptr(wq), gpu_ptr(scales), gpu_ptr(w), - w.size()); + w.size(), + global_scale.has_value() ? 
gpu_ptr(global_scale.value()) + : nullptr); } else { throw std::runtime_error( "[Quantize::eval_gpu] Can not dequantize to output with type float64."); diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index 3b5be7718a..665bdb45b7 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -14,46 +14,85 @@ namespace mlx::core { namespace { -array pad_and_swizzle_scales( - const array& scale, +std::tuple quantize_input( + const array& input, cu::CommandEncoder& encoder, - const Stream& s) { - // Compute padded dimensions for full tiles (128 rows × 4 cols) + const Stream& s, + QuantizationMode mode, + int bits, + int group_size, + std::optional global_scale = std::nullopt) { + const array x = ensure_contiguous(input, encoder, s); + + // Compute output shapes + auto xq_shape = x.shape(); + xq_shape.back() = x.shape(-1) * bits / 32; + + const int64_t scales_inner = x.shape(-1) / group_size; auto [pad_outer, pad_inner] = - get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); - // cuBLAS requirements for scale factor layout: - // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) - // 2. Out-of-bounds values must be filled with zeros - // 3. Starting addresses must be 16-byte aligned - // - // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout - // Note: cu::malloc_async already provides 256-byte alignment - array scale_tiled( - cu::malloc_async(pad_outer * pad_inner, encoder), - Shape{pad_outer, pad_inner}, - scale.dtype()); - swizzle_scales(scale, scale_tiled, encoder, s); - - encoder.add_temporary(scale_tiled); - return scale_tiled; + get_padded_scale_dims(x.shape(-2), scales_inner); + + auto sshape = x.shape(); + sshape[x.ndim() - 2] = pad_outer; + sshape[x.ndim() - 1] = pad_inner; + sshape.back() = scales_inner; + + // Allocate outputs + const int64_t xq_bytes = x.size() * bits / 8; + const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1)); + const int64_t scales_bytes = batch * (pad_outer * pad_inner); + + array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); + array scales_x( + cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); + encoder.add_temporary(x_q); + encoder.add_temporary(scales_x); + // global_scale is not nullopt only for NVFP4 + fp_quantize(x, x_q, scales_x, group_size, bits, global_scale, encoder, s); + return {std::move(x_q), std::move(scales_x)}; +} + +GemmScalars create_nvfp4_scalars( + const array& global_scale_x, + const array& global_scale_w, + cu::CommandEncoder& encoder) { + // NVFP4 requires alpha/beta as device pointers + // alpha = amax_x * amax_w / (448 * 6)^2 + // beta = 0 + array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32); + array beta(cu::malloc_async(sizeof(float), encoder), {}, float32); + compute_qqmm_pointers(alpha, beta, global_scale_x, global_scale_w, encoder); + encoder.add_temporary(alpha); + encoder.add_temporary(beta); + return {alpha, beta}; } } // namespace void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { - assert( - (inputs.size() == 3 && inputs[1].dtype() == uint32) || - (inputs.size() == 2)); nvtx3::scoped_range r("QQMatmul::eval_gpu"); auto& s = stream(); auto& encoder = cu::get_command_encoder(s); auto& device = encoder.device(); - bool w_quantized = (inputs[1].dtype() == uint32); + int base_size = w_quantized ? 
3 : 2; + + assert( + inputs.size() == base_size || + (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2)); + if (w_quantized && inputs[0].shape(-2) == 1) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); + // For nvfp4, get global scale for x from inputs if present + bool has_global_scale = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; + std::optional global_scale = std::nullopt; + if (has_global_scale) { + global_scale = inputs[inputs.size() - 2]; + } + bool donate_x = inputs[0].is_donatable(); array x = ensure_row_contiguous(inputs[0], encoder, s); // If x is a copy it should be donatable @@ -64,7 +103,8 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { if (!donate_x) { encoder.add_temporary(xhat); } - fp_quantize_dequantize(x, xhat, group_size_, bits_, encoder, s); + fp_quantize_dequantize( + x, xhat, group_size_, bits_, global_scale, encoder, s); // Make sure the last two dims of w and s are contiguous array w = ensure_row_contiguous_matrix(inputs[1], encoder, s); @@ -85,58 +125,53 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error( "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); } - auto quantize = [&](const array& input, - cu::CommandEncoder& encoder, - const Stream& s) -> std::pair { - auto x = ensure_contiguous(input, encoder, s); - auto xq_shape = x.shape(); - xq_shape.back() = x.shape(-1) * bits_ / 32; - - auto sshape = x.shape(); - const int64_t scales_inner = x.shape(-1) / group_size_; - auto [pad_outer, pad_inner] = - get_padded_scale_dims(x.shape(-2), scales_inner); - sshape[x.ndim() - 2] = pad_outer; - sshape[x.ndim() - 1] = pad_inner; - sshape.back() = scales_inner; - - // Allocate outputs - const int64_t xq_bytes = x.size() * bits_ / 8; - const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1)); - const int64_t scales_bytes = batch * (pad_outer * pad_inner); - - array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32); - array scales_x( - cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8); - - fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s); - - encoder.add_temporary(x_q); - encoder.add_temporary(scales_x); - return {x_q, scales_x}; - }; - auto [x_q, scale_x_pre] = quantize(inputs[0], encoder, s); - auto [w_q, scale_w_pre] = !w_quantized ? 
quantize(inputs[1], encoder, s) - : std::make_pair(inputs[1], inputs[2]); - out.set_data(cu::malloc_async(out.nbytes(), encoder)); + // - 2 inputs: x, w (non-quantized w) + // - 3 inputs: x, w, scales_w (quantized w) - auto out_dtype = out.dtype(); + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; - int M = x_q.shape(-2); - int N = w_q.shape(-2); // always transposed - int K_packed = x_q.shape(-1); - int K = K_packed * (32 / bits_); + // For nvfp4, get global scales from inputs if present + std::optional global_scale_x = std::nullopt; + std::optional global_scale_w = std::nullopt; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; + } - // Repack scales from linear to tiled layout for tensor cores - array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); - array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + // Quantize inputs (or use pre-quantized) + auto [x_q, scale_x_pre] = quantize_input( + inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); + auto [w_q, scale_w_pre] = !w_quantized + ? quantize_input( + inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w) + : std::make_tuple( + ensure_contiguous(inputs[1], encoder, s), + ensure_contiguous(inputs[2], encoder, s)); + + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + int M = x_q.shape(-2); + int N = w_q.shape(-2); // transposed + int K = x_q.shape(-1) * (32 / bits_); bool x_transposed = false; bool w_transposed = true; // always transposed int64_t lda = K; int64_t ldb = K; + // Repack scales to tiled layout for tensor cores + array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s); + array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s); + + GemmScalars scalars; + if (has_global_scales) { + scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder); + } + qqmm_impl( encoder, M, @@ -151,8 +186,8 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { w_q, scale_x, scale_w, - out_dtype, - mode_); + mode_, + scalars); } } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_impl.cpp b/mlx/backend/cuda/quantized/qqmm_impl.cpp index dd9407dcdc..d5986e05ea 100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.cpp +++ b/mlx/backend/cuda/quantized/qqmm_impl.cpp @@ -19,15 +19,10 @@ void qqmm_impl( const array& b, const array& a_scale, const array& b_scale, - Dtype out_dtype, QuantizationMode mode, - float alpha) { - // Invoke CublasQQMM + const GemmScalars& scalars) { std::string qmode = quantization_mode_to_string(mode); - // Currently only supports non-batched QQMM operations - // that covers all use cases for training, we will just collapse (batch, - // seq_len) into (tokens) CublasQQMM qqmm( encoder.device(), a_transposed, @@ -41,10 +36,22 @@ void qqmm_impl( 1, // batch_count 0, // a_batch_stride 0, // b_batch_stride - out_dtype, + out.dtype(), qmode); - qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha); + if (scalars.uses_device_pointers()) { + qqmm.run( + encoder, + out, + a, + b, + a_scale, + b_scale, + *scalars.alpha_device, + *scalars.beta_device); + } else { + qqmm.run(encoder, out, a, b, a_scale, b_scale); + } } } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_impl.h b/mlx/backend/cuda/quantized/qqmm_impl.h index 7288e2fd7b..c562bfb186 
100644 --- a/mlx/backend/cuda/quantized/qqmm_impl.h +++ b/mlx/backend/cuda/quantized/qqmm_impl.h @@ -1,10 +1,22 @@ -// Copyright © 2026 Apple Inc. +// Copyright © 2025 Apple Inc. #pragma once #include "mlx/backend/cuda/device.h" #include "mlx/primitives.h" +#include + namespace mlx::core { + +struct GemmScalars { + std::optional alpha_device; + std::optional beta_device; + + bool uses_device_pointers() const { + return alpha_device.has_value(); + } +}; + void qqmm_impl( cu::CommandEncoder& encoder, int M, @@ -19,8 +31,7 @@ void qqmm_impl( const array& b, const array& a_scale, const array& b_scale, - Dtype out_dtype, QuantizationMode mode, - float alpha = 1.0f); + const GemmScalars& scalars = {}); } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index c8764709b9..d19865a3b3 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -70,6 +70,21 @@ inline std::tuple get_swizzle_launch_args( namespace cu { +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + +__global__ void compute_qqmm_pointers( + float* alpha_out, + float* beta_out, + const float* tensor_amax_x, + const float* tensor_amax_w) { + // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 + constexpr float inv_scale_sq = + 1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX); + *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq; + *beta_out = 0.0f; +} + __global__ void swizzle_scales( const uint8_t* scales_linear, uint8_t* scales_swizzled, @@ -224,4 +239,25 @@ void swizzle_scales( output_cols); } +void compute_qqmm_pointers( + array& alpha_out, + array& beta_out, + const array& tensor_amax_x, + const array& tensor_amax_w, + cu::CommandEncoder& enc) { + enc.set_input_array(tensor_amax_x); + enc.set_input_array(tensor_amax_w); + enc.set_output_array(alpha_out); + enc.set_output_array(beta_out); + enc.add_kernel_node( + cu::compute_qqmm_pointers, + dim3(1), + dim3(1), + 0, + gpu_ptr(alpha_out), + gpu_ptr(beta_out), + gpu_ptr(tensor_amax_x), + gpu_ptr(tensor_amax_w)); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.h b/mlx/backend/cuda/quantized/qqmm_utils.h index 0a9a78f70c..fba9ac9d9e 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.h +++ b/mlx/backend/cuda/quantized/qqmm_utils.h @@ -27,4 +27,36 @@ void swizzle_scales( cu::CommandEncoder& enc, const Stream& s); +inline array pad_and_swizzle_scales( + const array& scale, + cu::CommandEncoder& encoder, + const Stream& s) { + // Compute padded dimensions for full tiles (128 rows × 4 cols) + auto [pad_outer, pad_inner] = + get_padded_scale_dims(scale.shape(-2), scale.shape(-1)); + // cuBLAS requirements for scale factor layout: + // 1. Dimensions must be padded to full tiles (128 rows × 4 cols) + // 2. Out-of-bounds values must be filled with zeros + // 3. 
Starting addresses must be 16-byte aligned + // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + // Note: cu::malloc_async already provides 256-byte alignment + array scale_tiled( + cu::malloc_async(pad_outer * pad_inner, encoder), + Shape{pad_outer, pad_inner}, + scale.dtype()); + swizzle_scales(scale, scale_tiled, encoder, s); + + encoder.add_temporary(scale_tiled); + return scale_tiled; +} + +// Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2 +// Allocate beta zero on device as well +void compute_qqmm_pointers( + array& alpha_out, + array& beta_out, + const array& tensor_amax_x, + const array& tensor_amax_w, + cu::CommandEncoder& enc); + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 30662d55a4..81ded94248 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -51,7 +51,6 @@ void fast::Quantize::eval_gpu( auto& s = stream(); auto& d = cu::device(s.device); auto& enc = d.get_command_encoder(s); - if (dequantize_) { auto wq = ensure_row_contiguous(inputs[0], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s); @@ -63,7 +62,12 @@ void fast::Quantize::eval_gpu( auto biases = ensure_row_contiguous(inputs[2], enc, s); affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); } else { - fp_dequantize(wq, scales, w, group_size_, bits_, enc, s); + // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4 + bool use_global_scale = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2; + std::optional global_scale = + use_global_scale ? std::make_optional(inputs[2]) : std::nullopt; + fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s); } } else { auto w = ensure_contiguous(inputs[0], enc, s); @@ -72,12 +76,17 @@ void fast::Quantize::eval_gpu( wq.set_data(cu::malloc_async(wq.nbytes(), enc)); scales.set_data(cu::malloc_async(scales.nbytes(), enc)); + if (mode_ == QuantizationMode::Affine) { auto& biases = outputs[2]; biases.set_data(cu::malloc_async(biases.nbytes(), enc)); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); } else { - fp_quantize(w, wq, scales, group_size_, bits_, enc, s); + bool use_global_scale = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1; + std::optional global_scale = + use_global_scale ? std::make_optional(inputs[1]) : std::nullopt; + fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s); } } } diff --git a/mlx/backend/cuda/quantized/quantized.h b/mlx/backend/cuda/quantized/quantized.h index 744a12f5c4..f15c0f76e1 100644 --- a/mlx/backend/cuda/quantized/quantized.h +++ b/mlx/backend/cuda/quantized/quantized.h @@ -1,5 +1,6 @@ // Copyright © 2025 Apple Inc. 
+#include #include "mlx/backend/cuda/device.h" namespace mlx::core { @@ -30,6 +31,7 @@ void fp_quantize( array& scales, int group_size, int bits, + const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); @@ -39,6 +41,7 @@ void fp_dequantize( array& w, int group_size, int bits, + const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); @@ -47,6 +50,7 @@ void fp_quantize_dequantize( array& what, int group_size, int bits, + const std::optional& global_scale, cu::CommandEncoder& enc, const Stream& s); diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index 93d83292dc..8cbe9b297d 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -29,7 +29,7 @@ inline constexpr __device__ short get_bytes_per_pack() { } template -__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) { +__device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) { if constexpr ( (std::is_same::value) || (std::is_same::value)) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 076de699b3..a0e1636e33 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -10,6 +10,7 @@ #include #include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/metal/metal.h" #include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/primitives.h" @@ -4209,6 +4210,34 @@ std::pair validate_mode_with_type( } } +void validate_global_scale( + std::string_view tag, + QuantizationMode qmode, + const std::optional& global_scale) { + if (global_scale.has_value()) { + if (qmode != QuantizationMode::Nvfp4) { + std::ostringstream msg; + msg << "[" << tag << "] Global scale is only supported for 'nvfp4' " + << "quantization mode."; + throw std::invalid_argument(msg.str()); + } else { + if (global_scale->size() != 1) { + std::ostringstream msg; + msg << "[" << tag << "] Global scale must be a scalar but got shape " + << global_scale->shape() << "."; + throw std::invalid_argument(msg.str()); + } + // TODO: not sure if type should be restricted to float32 + if (global_scale->dtype() != float32) { + std::ostringstream msg; + msg << "[" << tag << "] Global scale must have dtype float32 but got " + << global_scale->dtype() << "."; + throw std::invalid_argument(msg.str()); + } + } + } +} + array quantized_matmul( array x, array w, @@ -4251,7 +4280,6 @@ array quantized_matmul( if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); } - auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; return array( @@ -4267,7 +4295,10 @@ void validate_qqmm_inputs( array w, std::optional scales_w, int group_size, - int bits) { + int bits, + std::optional global_scale_x, + std::optional global_scale_w, + QuantizationMode qmode) { // check 2D (for now) if (x.ndim() > 2 || w.ndim() > 2) { std::ostringstream msg; @@ -4304,6 +4335,19 @@ void validate_qqmm_inputs( << "first argument dtype == " << x.dtype() << "."; throw std::invalid_argument(msg.str()); } + // validate global scales + validate_global_scale("qqmm", qmode, global_scale_x); + validate_global_scale("qqmm", qmode, global_scale_w); + // For nvfp4 mode, both global scales must be provided together or neither + if (qmode == QuantizationMode::Nvfp4) { + bool has_x = global_scale_x.has_value(); + bool has_w = global_scale_w.has_value(); + if (has_x != has_w) { + throw std::invalid_argument( + "[qqmm] For nvfp4 mode, either both global_scale_x and " + "global_scale_w must be provided, or neither."); + 
} + } } std::pair extract_qqmm_dims( @@ -4343,6 +4387,8 @@ array qqmm( std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "nvfp4" */, + const std::optional global_scale_x /* = std::nullopt */, + const std::optional global_scale_w /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); auto qmode = string_to_quantization_mode(mode, "qqmm"); @@ -4369,7 +4415,8 @@ array qqmm( } // validate inputs - validate_qqmm_inputs(x, w, scales_w, group_size, bits); + validate_qqmm_inputs( + x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode); // validate and extract shapes auto [w_inner_dims, w_outer_dims] = extract_qqmm_dims(x, w, scales_w, group_size, bits); @@ -4380,6 +4427,15 @@ array qqmm( if (scales_w.has_value()) { inputs.push_back(*scales_w); } + // if + if (global_scale_x.has_value()) { + // Stop gradient through global scales + inputs.push_back(stop_gradient(*global_scale_x)); + } + if (global_scale_w.has_value()) { + // Stop gradient through global scales + inputs.push_back(stop_gradient(*global_scale_w)); + } auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; auto out = array( @@ -4515,6 +4571,7 @@ std::vector fp_quantize( int group_size, int bits, QuantizationMode mode, + const std::optional& global_scale /* = std::nullopt */, Stream s) { int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32; int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4; @@ -4532,19 +4589,28 @@ std::vector fp_quantize( << bits << "."; throw std::invalid_argument(msg.str()); } + + auto inputs = std::vector{w}; + if (global_scale.has_value()) { + inputs.push_back(global_scale.value()); + } + auto fallback = [bits = bits, group_size = group_size, s]( const std::vector& inputs) -> std::vector { auto& w = inputs[0]; + auto scale_encode = + inputs.size() > 1 ? 448.0f * 6.0f / inputs[1] : array(1.0f, float32); float maxval = (bits == 4) ? 
6.0f : 448.0f; auto new_shape = w.shape(); new_shape.back() = -1; auto wq = reshape(w, {-1, group_size}, s); auto scales = - divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s); + divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s) * + scale_encode; if (group_size == 16) { // convert to e4m3 scales = to_fp8(scales, s); - wq = divide(wq, from_fp8(scales, w.dtype(), s), s); + wq = divide(wq, from_fp8(scales, w.dtype(), s), s) * scale_encode; } else { // convert to e8m0 auto z = array(0, scales.dtype()); @@ -4600,9 +4666,9 @@ std::vector fp_quantize( {uint32, uint8}, std::make_shared( s, fallback, group_size, bits, mode, false), - {w}); + inputs); } - return fallback({w}); + return fallback(inputs); } std::vector quantize( @@ -4610,6 +4676,7 @@ std::vector quantize( std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, + const std::optional& global_scale /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto qmode = string_to_quantization_mode(mode, "quantize"); auto [group_size, bits] = @@ -4636,11 +4703,17 @@ std::vector quantize( << " matrix has shape " << w.shape(); throw std::invalid_argument(msg.str()); } - + if (to_stream(s).device == Device::gpu && metal::is_available() && + global_scale.has_value()) { + std::ostringstream msg; + msg << "[quantize] Global scale is not supported on the Metal backend."; + throw std::invalid_argument(msg.str()); + } + validate_global_scale("quantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); } else { - return fp_quantize(w, group_size, bits, qmode, to_stream(s)); + return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s)); } } @@ -4745,6 +4818,7 @@ array fp_dequantize( int bits, Dtype out_type, QuantizationMode mode, + const std::optional& global_scale /* = std::nullopt */, Stream s) { int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32; int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4; @@ -4789,6 +4863,11 @@ array fp_dequantize( throw std::invalid_argument(msg.str()); } + auto inputs = std::vector{w, scales}; + if (global_scale.has_value()) { + inputs.push_back(global_scale.value()); + } + auto fallback = [wshape = std::move(wshape), sshape = std::move(sshape), @@ -4798,6 +4877,9 @@ array fp_dequantize( s](const std::vector& inputs) mutable -> std::vector { auto out = inputs[0]; auto scales = inputs[1]; + array inv_scale_enc = inputs.size() > 2 + ? 
divide(inputs[2], array(448.0f * 6.0f, out_type), s) + : array(1.0f, out_type); if (bits == 4) { auto lut = array( { @@ -4831,13 +4913,16 @@ array fp_dequantize( out = reshape(out, {-1, group_size}, s); scales = reshape(scales, {-1, 1}, s); if (group_size == 16) { - scales = from_fp8(scales, out_type, s); + // NVFP4: scales are E4M3, apply inv_scale_enc + scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s); } else { + // MXFP: scales are E8M0 (power of 2) scales = subtract(astype(scales, out_type, s), array(127, out_type), s); scales = power(array(2.0f, out_type), scales, s); } return {reshape(multiply(out, scales, s), wshape, s)}; }; + if (s.device == Device::gpu) { auto out_shape = w.shape(); out_shape.back() = out_size; @@ -4846,9 +4931,9 @@ array fp_dequantize( out_type, std::make_shared( s, fallback, group_size, bits, mode, true), - {w, scales}); + inputs); } - return fallback({w, scales})[0]; + return fallback(inputs)[0]; } array dequantize( @@ -4858,6 +4943,7 @@ array dequantize( std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, + const std::optional& global_scale /* = std::nullopt */, std::optional dtype /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto [out_type, qmode] = @@ -4884,6 +4970,14 @@ array dequantize( << "but it has only " << w.ndim() << "."; throw std::invalid_argument(msg.str()); } + if (global_scale.has_value()) { + if (to_stream(s).device == Device::gpu && metal::is_available()) { + std::ostringstream msg; + msg << "[dequantize] Global scale is not supported on the Metal backend."; + throw std::invalid_argument(msg.str()); + } + } + validate_global_scale("dequantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return astype( @@ -4892,7 +4986,14 @@ array dequantize( s); } else { return fp_dequantize( - w, scales, group_size, bits, out_type, qmode, to_stream(s)); + w, + scales, + group_size, + bits, + out_type, + qmode, + global_scale, + to_stream(s)); } } @@ -6091,4 +6192,4 @@ array contiguous( {a}); } -} // namespace mlx::core +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/ops.h b/mlx/ops.h index 1ff3bbfaa5..cc9db8aedb 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1391,6 +1391,7 @@ MLX_API std::vector quantize( std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "affine", + const std::optional& global_scale = std::nullopt, StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ @@ -1401,17 +1402,20 @@ MLX_API array dequantize( std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "affine", + const std::optional& global_scale = std::nullopt, std::optional dtype = std::nullopt, StreamOrDevice s = {}); MLX_API array qqmm( array x, // input activations array w, // maybe quantized weights - std::optional w_scales = std::nullopt, // optional scales if w is - // quantized + const std::optional w_scales = std::nullopt, // optional scales if w + // is quantized std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "nvfp4", + const std::optional global_scale_x = std::nullopt, + const std::optional global_scale_w = std::nullopt, StreamOrDevice s = {}); /** Convert an E4M3 float8 to the given floating point dtype. 
*/ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b07862514c..e3ff52e7c7 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3424,6 +3424,7 @@ std::vector QuantizedMatmul::vjp( group_size_, bits_, quantization_mode_to_string(mode_), + {}, // placeholder for amax std::nullopt, stream()); wq = unflatten(wq, -1, {-1, group_size_}, stream()); @@ -3484,14 +3485,24 @@ std::vector QQMatmul::output_shapes(const std::vector& inputs) { } std::vector QQMatmul::vjp( - const std::vector& primals, // non quantized x, non quantized w + const std::vector& primals, // non quantized x, non quantized w, if + // nvfp4 global_scale_x, global_scale_w const std::vector& cotangents, // non quantized upstream grads const std::vector& argnums, const std::vector&) { - if (primals.size() != 2) { - throw std::runtime_error( - "[QQMatmul::vjp] Expected exactly 2 non-quantized primal inputs (x, w)."); + bool is_nvfp4 = (mode_ == QuantizationMode::Nvfp4); + auto expected_size = is_nvfp4 ? 4 : 2; + if (primals.size() != expected_size) { + auto msg = std::ostringstream(); + msg << "[QQMatmul::vjp] Expected exactly " << expected_size + << " non-quantized primal inputs (x, w"; + if (mode_ == QuantizationMode::Nvfp4) { + msg << ", global_scale_x, global_scale_w"; + } + msg << ")."; + throw std::runtime_error(msg.str()); } + std::vector vjps; auto& cotan = cotangents[0]; auto& s = stream(); @@ -3499,6 +3510,14 @@ std::vector QQMatmul::vjp( // primal[0] -- non quantized activations (M, K) // cotan -- non quantized grads (M, N) auto qmode = quantization_mode_to_string(mode_); + std::optional cotan_amax = is_nvfp4 + ? std::make_optional(astype(max(abs(cotan, s), s), float32)) + : std::nullopt; + + auto get_primal_scale = [&](int idx) { + return is_nvfp4 ? std::make_optional(primals[idx]) : std::nullopt; + }; + for (auto arg : argnums) { if (arg == 0) { // gradient wrt to x // We transpose weights -> quantize along N @@ -3509,6 +3528,8 @@ std::vector QQMatmul::vjp( group_size_, bits_, qmode, + cotan_amax, + get_primal_scale(3), // global_scale_w (for w.T) s)); } else if (arg == 1) { // gradient wrt to weights vjps.push_back(qqmm( @@ -3518,6 +3539,8 @@ std::vector QQMatmul::vjp( group_size_, bits_, qmode, + cotan_amax, + get_primal_scale(2), // global_scale_x (for x.T) s)); } } @@ -3643,6 +3666,7 @@ std::vector GatherQMM::vjp( bits_, quantization_mode_to_string(mode_), std::nullopt, + std::nullopt, // amax placeholder stream()), -1, {-1, group_size_}, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index de51ad502c..131f3d12e9 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4258,10 +4258,11 @@ void init_ops(nb::module_& m) { "group_size"_a = nb::none(), "bits"_a = nb::none(), "mode"_a = "affine", + "global_scale"_a = nb::none(), nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, global_scale: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), R"pbdoc( Quantize the array ``w``. @@ -4286,6 +4287,8 @@ void init_ops(nb::module_& m) { ``w`` in the quantized array. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. 
+          global_scale (array, optional): The per-tensor float32 scale used for
+            ``"nvfp4"`` quantization if provided. Default: ``None``.

       Returns:
         tuple: A tuple with either two or three elements containing:
@@ -4355,11 +4358,12 @@ void init_ops(nb::module_& m) {
       "group_size"_a = nb::none(),
       "bits"_a = nb::none(),
       "mode"_a = "affine",
+      "global_scale"_a = nb::none(),
       "dtype"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', global_scale: Optional[array] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Dequantize the matrix ``w`` using quantization parameters.
@@ -4374,6 +4378,8 @@ void init_ops(nb::module_& m) {
           bits (int, optional): The number of bits occupied by each element of
             ``w`` in the quantized array. See supported values and defaults in
             the :ref:`table of quantization modes `. Default: ``None``.
+          global_scale (array, optional): The per-tensor float32 scale used for
+            ``"nvfp4"`` quantization if provided. Default: ``None``.
           dtype (Dtype, optional): The data type of the dequantized output. If ``None``
             the return type is inferred from the scales and biases when possible and
             otherwise defaults to ``bfloat16``.
@@ -5465,10 +5471,12 @@ void init_ops(nb::module_& m) {
       "group_size"_a = nb::none(),
       "bits"_a = nb::none(),
       "mode"_a = "nvfp4",
+      "global_scale_x"_a = nb::none(),
+      "global_scale_w"_a = nb::none(),
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def qqmm(x: array, w: array, scales: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def qqmm(x: array, w: array, scales: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'nvfp4', global_scale_x: Optional[array] = None, global_scale_w: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform a matrix multiplication using a possibly quantized weight matrix
         ``w`` and a non-quantized input ``x``. The input ``x`` is quantized on the
@@ -5486,6 +5494,7 @@ void init_ops(nb::module_& m) {
          If ``x`` and `w`` are not quantized, their data types must be ``float32``,
          ``float16``, or ``bfloat16``. If ``w`` is quantized, it must be packed in
          unsigned integers.
+         ``global_scale_x`` and ``global_scale_w`` are only used for ``nvfp4`` quantization.

        Args:
          x (array): Input array.
@@ -5501,7 +5510,10 @@ void init_ops(nb::module_& m) {
          mode (str, optional): The quantization mode. Default: ``"nvfp4"``.
            Supported modes are ``nvfp4`` and ``mxfp8``. See the
            :ref:`table of quantization modes ` for details.
-
+          global_scale_x (array, optional): The per-tensor float32 scale used for
+            ``x`` in ``"nvfp4"`` quantization if provided. Default: ``None``.
+          global_scale_w (array, optional): The per-tensor float32 scale used for
+            ``w`` in ``"nvfp4"`` quantization if provided. Default: ``None``.
        Returns:
          array: The result of the multiplication of quantized ``x`` with quantized ``w``.
            needed).
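An illustrative call pattern for the keyword arguments these bindings add (not part of the patch; the shapes and dtypes are made-up assumptions). Per the validation in mlx/ops.cpp, the global scales are float32 scalars, must be passed together or not at all for qqmm, and are rejected when the Metal backend is available; per qqmm.cpp, qqmm itself requires a GPU with compute capability 10.0 or higher.

    import mlx.core as mx

    x = mx.random.normal((128, 256)).astype(mx.bfloat16)
    w = mx.random.normal((512, 256)).astype(mx.bfloat16)

    # Per-tensor amax values used as the nvfp4 global scales (float32 scalars).
    gx = x.abs().max().astype(mx.float32)
    gw = w.abs().max().astype(mx.float32)

    # Quantize / dequantize round trip with a global scale.
    w_q, w_scales = mx.quantize(w, mode="nvfp4", global_scale=gw)
    w_hat = mx.dequantize(
        w_q, w_scales, group_size=16, bits=4, mode="nvfp4", global_scale=gw
    )

    # qqmm quantizes x (and w) on the fly; both global scales are given together.
    out = mx.qqmm(x, w, mode="nvfp4", global_scale_x=gx, global_scale_w=gw)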
diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py
index 6e4e9eee8f..d53c5e568b 100644
--- a/python/tests/test_quantized.py
+++ b/python/tests/test_quantized.py
@@ -160,6 +160,19 @@ def test_nvfp4_quantize_dequantize(self):
         w_hat = mx.dequantize(w_q, scales, mode="nvfp4")
         self.assertTrue(mx.all(w_hat == 0))

+        # Test nvfp4 quantize/dequantize with a per-tensor global_scale;
+        # currently supported only on the CPU and CUDA backends.
+        if not mx.metal.is_available():
+            global_scale = w.abs().max().astype(mx.float32)
+        else:
+            global_scale = None
+
+        w_q, scales = mx.quantize(w, mode="nvfp4", global_scale=global_scale)
+        w_hat = mx.dequantize(
+            w_q, scales, group_size=16, bits=4, mode="nvfp4", global_scale=global_scale
+        )
+        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))
+
     def test_qqmv(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
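A minimal NumPy sketch of the two-level NVFP4 scaling described by the kernel comments in fp_quantize.cu and by compute_qqmm_pointers, under simplifying assumptions: the per-block scale is kept in float32 instead of being rounded to FP8 E4M3, and quantized values are snapped onto the E2M1 grid rather than bit-packed.

    import numpy as np

    F8E4M3_MAX = 448.0
    F4E2M1_MAX = 6.0
    FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

    def round_to_fp4(x):
        # Nearest-value rounding onto the signed E2M1 grid.
        idx = np.argmin(np.abs(np.abs(x)[..., None] - FP4_GRID), axis=-1)
        return np.sign(x) * FP4_GRID[idx]

    def nvfp4_quantize(w, tensor_amax, group_size=16):
        # Global encode scale derived from the tensor amax (the `global_scale` input).
        scale_enc = (F8E4M3_MAX * F4E2M1_MAX) / tensor_amax
        blocks = w.reshape(-1, group_size)
        block_amax = np.abs(blocks).max(axis=-1, keepdims=True)
        # Per-block decode scale (stored as FP8 E4M3 by the kernel); the tiny
        # floor avoids dividing by zero on all-zero blocks.
        scale_dec_b = np.maximum(block_amax, 1e-12) / F4E2M1_MAX * scale_enc
        scale_enc_b = scale_enc / scale_dec_b  # per-block encode scale
        q = round_to_fp4(blocks * scale_enc_b)
        return q, scale_dec_b, scale_enc

    def nvfp4_dequantize(q, scale_dec_b, scale_enc, shape):
        # Undo the per-block scale, then the global encode scale.
        return (q * scale_dec_b / scale_enc).reshape(shape)

    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 32)).astype(np.float32)
    w = rng.standard_normal((4, 32)).astype(np.float32)
    amax_x = float(np.abs(x).max())
    amax_w = float(np.abs(w).max())

    q, s, scale_enc = nvfp4_quantize(w, amax_w)
    w_hat = nvfp4_dequantize(q, s, scale_enc, w.shape)
    err = np.max(np.abs(w - w_hat))  # bounded by the E2M1 rounding error

    # GEMM scalars computed on device by compute_qqmm_pointers:
    # alpha = amax_x * amax_w / (448 * 6)^2, beta = 0.
    alpha = amax_x * amax_w / (F8E4M3_MAX * F4E2M1_MAX) ** 2

This mirrors why fp_dequantize multiplies the stored E4M3 scale by inv_scale_enc = global_scale / (448 * 6), and why CublasQQMM switches alpha/beta to device pointers for the nvfp4 path.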