42 commits
98eedd1
adding tensor scale [wip]
nastya236 Jan 15, 2026
438830a
Merge branch 'main' into tensor-scale-nvfp4
nastya236 Jan 15, 2026
6892404
added absmax reduction, changed fp_quanitze api [wip]
nastya236 Jan 16, 2026
15d684b
refactoring
nastya236 Jan 18, 2026
8c67953
Merge branch 'ml-explore:main' into tensor-scale-nvfp4
nastya236 Jan 19, 2026
a7fab99
alpha device ptr for qqmm
nastya236 Jan 19, 2026
9fdfce6
Merge branch 'tensor-scale-nvfp4' of https://github.com/nastya236/mlx…
nastya236 Jan 19, 2026
47be994
device alpha, beta
nastya236 Jan 20, 2026
7e4c6e8
harcoded absmax to output float
nastya236 Jan 20, 2026
11ff19a
fixed ops python dequantize
nastya236 Jan 20, 2026
2a86dc1
input global_scale
nastya236 Jan 20, 2026
2c68fb6
fix global_scale
nastya236 Jan 21, 2026
abe37c2
Merge branch 'main' into tensor-scale-nvfp4
nastya236 Jan 21, 2026
277ceeb
fix scale to be float(fp8e4m3(scale))
nastya236 Jan 21, 2026
dad7e57
removed AbsMax reduction (probably add back in the future as a separa…
nastya236 Jan 21, 2026
3d7ebd9
Merge branch 'main' into tensor-scale-nvfp4
nastya236 Jan 22, 2026
0a804a9
fix columnwise quantize scale, precommit
nastya236 Jan 22, 2026
7ca2642
abs_max
nastya236 Jan 22, 2026
934c0c8
fix
nastya236 Jan 23, 2026
1fea025
fixed the fallback, fixed absmax
nastya236 Jan 23, 2026
306acd0
fix docs, remove the diff
nastya236 Jan 23, 2026
7492841
fix docs, delete debuging print
nastya236 Jan 23, 2026
5503802
Merge branch 'main' into tensor-scale-nvfp4
nastya236 Jan 23, 2026
f49abe5
reverted the example
nastya236 Jan 23, 2026
37e5789
abs_max -> absmax
nastya236 Jan 23, 2026
507c94e
fix
nastya236 Jan 23, 2026
858fe00
fix test, force flobal scale only on cuda
nastya236 Jan 23, 2026
5fdffe4
fix stream
nastya236 Jan 23, 2026
1cc13ba
made AbsMax the same structure as Max
nastya236 Jan 24, 2026
05bd4d0
abs_val rename
nastya236 Jan 24, 2026
20480ef
fix abs type
nastya236 Jan 24, 2026
9f9aabd
fix fp type for vjp
nastya236 Jan 24, 2026
d2dc310
decrease block size because of the register pressure
nastya236 Jan 25, 2026
76cd3b4
Merge branch 'main' into tensor-scale-nvfp4
nastya236 Feb 2, 2026
79d93e6
drop absmax
nastya236 Feb 2, 2026
d91fd8a
merge conflict fp-quantize
nastya236 Feb 2, 2026
5cbf48f
add scale to fp_quantiz-dequantize, fix merge conflicts, refactor
nastya236 Feb 3, 2026
da1cacf
pre-commit + update a comment
nastya236 Feb 4, 2026
67385c8
revert qq_linear global scale [WIP]
nastya236 Feb 4, 2026
ad1fcf1
refactoring, revert block size
nastya236 Feb 4, 2026
019a31d
Merge remote-tracking branch 'upstream/main' into tensor-scale-nvfp4
nastya236 Feb 4, 2026
b1dcd2f
revert the year change
nastya236 Feb 4, 2026
7 changes: 0 additions & 7 deletions mlx/backend/cuda/cublas_utils.cpp
@@ -105,13 +105,6 @@ void CublasMatmulBase::init_base(
CHECK_CUBLAS_ERROR(
cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));

int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(int32_t)));

// In cublasLt matrices use column-major layout, while it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with the option. So instead we swap A and B to make
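With the pointer mode no longer fixed in init_base, each matmul descriptor now chooses for itself how cublasLtMatmul should read alpha and beta: CublasGemm keeps host-side scalars, while CublasQQMM can switch to device pointers for the nvfp4 path. A minimal standalone sketch of that attribute, error handling elided and not code from this PR:

#include <cublasLt.h>

// Configure how cublasLtMatmul interprets alpha/beta for this descriptor:
// POINTER_MODE_HOST reads plain host floats, POINTER_MODE_DEVICE requires
// device-resident scalars.
void set_pointer_mode(cublasLtMatmulDesc_t desc, bool device_scalars) {
  cublasLtPointerMode_t mode = device_scalars ? CUBLASLT_POINTER_MODE_DEVICE
                                              : CUBLASLT_POINTER_MODE_HOST;
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &mode, sizeof(mode));
}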
12 changes: 10 additions & 2 deletions mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -73,6 +73,14 @@ CublasGemm::CublasGemm(
batch_count,
a_batch_stride,
b_batch_stride);

// alpha and beta are both host pointers
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(pointer_mode)));
}

CublasGemm::CublasGemm(
@@ -215,8 +223,8 @@ void CublasGemm::execute(
const void* a,
const void* b,
const void* c,
float alpha /* = 1 */,
float beta /* = 0 */) {
const float alpha /* = 1 */,
const float beta /* = 0 */) {
const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
complex64_t alpha_c, beta_c;
143 changes: 97 additions & 46 deletions mlx/backend/cuda/quantized/cublas_qqmm.cpp
@@ -13,39 +13,26 @@ namespace mlx::core {

namespace {

// Currently cublas supports only mxfp8 and nvfp4
// quantization modes for block scaled quantization
cudaDataType_t qmode_to_cublas_scale_dtype(std::string mode) {
if (mode == "mxfp8") {
return CUDA_R_8F_UE8M0;
} else if (mode == "nvfp4") {
return CUDA_R_8F_UE4M3;
} else {
throw std::runtime_error(
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
}
}

cudaDataType_t qmode_to_cublas_dtype(std::string mode) {
if (mode == "mxfp8") {
return CUDA_R_8F_E4M3;
} else if (mode == "nvfp4") {
return CUDA_R_4F_E2M1;
} else {
throw std::runtime_error(
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
}
}
struct QuantModeConfig {
cudaDataType_t data_type;
cudaDataType_t scale_dtype;
cublasLtMatmulMatrixScale_t scale_mode;
};

cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string mode) {
QuantModeConfig get_quant_mode_config(const std::string& mode) {
if (mode == "mxfp8") {
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
return {
CUDA_R_8F_E4M3,
CUDA_R_8F_UE8M0,
CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0};
} else if (mode == "nvfp4") {
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
} else {
throw std::runtime_error(
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
return {
CUDA_R_4F_E2M1,
CUDA_R_8F_UE4M3,
CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3};
}
throw std::runtime_error(
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
}

} // namespace
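The QuantModeConfig lookup folds the three previous qmode_to_cublas_* helpers into one table: data format, scale format, and scale granularity per block-scaled mode. A short usage sketch, assuming the helper above is visible; the commented values simply restate the enums used in the diff:

// "mxfp8" -> FP8 (E4M3) data, UE8M0 scales, one scale per 32 elements
// "nvfp4" -> FP4 (E2M1) data, UE4M3 scales, one scale per 16 elements
QuantModeConfig cfg = get_quant_mode_config("nvfp4");
// cfg.data_type   == CUDA_R_4F_E2M1
// cfg.scale_dtype == CUDA_R_8F_UE4M3
// cfg.scale_mode  == CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3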
@@ -64,21 +51,21 @@ CublasQQMM::CublasQQMM(
int64_t a_batch_stride,
int64_t b_batch_stride,
Dtype out_dtype,
std::string qmode) {
const std::string& qmode) {
auto config = get_quant_mode_config(qmode);

// The compute type must be CUBLAS_COMPUTE_32F.
// The scale type must be CUDA_R_32F.
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
cudaDataType_t output_type =
cublas_utils::dtype_to_cublas_type(out_dtype, "CublasQQMM");
cudaDataType_t data_type = qmode_to_cublas_dtype(qmode);
quantization_mode_ = std::string(qmode);

init_base(
device,
scale_type,
gemm_compute_type,
data_type,
config.data_type,
output_type,
a_transposed,
a_rows,
@@ -92,8 +79,8 @@
a_batch_stride,
b_batch_stride);

a_scale_mode_ = qmode_to_cublas_scale_mode(qmode);
b_scale_mode_ = qmode_to_cublas_scale_mode(qmode);
a_scale_mode_ = config.scale_mode;
b_scale_mode_ = config.scale_mode;

CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@@ -123,7 +110,7 @@ CublasQQMM::CublasQQMM(
int64_t b_batch_stride,
int64_t c_batch_stride,
Dtype out_dtype,
std::string qmode)
const std::string& qmode)
: CublasQQMM(
device,
a_transposed,
@@ -158,11 +145,14 @@ void CublasQQMM::run(
const array& b,
const array& a_scale,
const array& b_scale,
float alpha) {
const array& alpha,
const array& beta) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(a_scale);
encoder.set_input_array(b_scale);
encoder.set_input_array(alpha);
encoder.set_input_array(beta);
encoder.set_output_array(out);

execute(
@@ -173,19 +163,37 @@
gpu_ptr<void>(a_scale),
gpu_ptr<void>(b_scale),
nullptr,
alpha);
gpu_ptr<void>(alpha),
gpu_ptr<void>(beta));
}

void CublasQQMM::execute(
void CublasQQMM::run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& a_scale,
const array& b_scale) {
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(a_scale);
encoder.set_input_array(b_scale);
encoder.set_output_array(out);

execute(
encoder,
gpu_ptr<void>(out),
gpu_ptr<void>(a),
gpu_ptr<void>(b),
gpu_ptr<void>(a_scale),
gpu_ptr<void>(b_scale),
nullptr);
}

void CublasQQMM::set_scales_ptrs(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* a_scale,
const void* b_scale,
const void* c,
float alpha /* = 1 */,
float beta /* = 0 */) {
const void* b_scale) {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
@@ -196,6 +204,49 @@ void CublasQQMM::execute(
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&a_scale,
sizeof(a_scale)));
}

void CublasQQMM::execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* a_scale,
const void* b_scale,
const void* c,
const void* alpha,
const void* beta) {
set_scales_ptrs(encoder, a_scale, b_scale);
// alpha and beta should both be device pointers for nvfp4;
// by default cublas uses host pointers:
// https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(pointer_mode)));
execute_matmul(encoder, out, a, b, c, alpha, beta);
}

void CublasQQMM::execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* a_scale,
const void* b_scale,
const void* c,
const float alpha /* = 1 */,
const float beta /* = 0 */) {
set_scales_ptrs(encoder, a_scale, b_scale);
// alpha and beta should both be host pointers
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode,
sizeof(pointer_mode)));

const void* alpha_ptr = &alpha;
const void* beta_ptr = &beta;
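The two run() overloads above correspond to the two pointer modes: the one taking alpha/beta arrays flips the descriptor to CUBLASLT_POINTER_MODE_DEVICE (the nvfp4 global scale is produced on device), while the one without them goes through the host-pointer path with the defaults alpha = 1, beta = 0. A hypothetical caller-side sketch of that split; dispatch_qqmm and the std::optional wrapping are illustrative, not part of this PR:

#include <optional>

// Illustrative dispatcher: choose the device-scalar or host-default path.
void dispatch_qqmm(
    cu::CommandEncoder& enc,
    CublasQQMM& qqmm,
    array& out,
    const array& a_q,
    const array& b_q,
    const array& a_scale,
    const array& b_scale,
    const std::optional<array>& alpha, // device-resident float scalar
    const std::optional<array>& beta) { // device-resident float scalar
  if (alpha && beta) {
    // cuBLASLt reads alpha/beta from GPU memory (device pointer mode).
    qqmm.run(enc, out, a_q, b_q, a_scale, b_scale, *alpha, *beta);
  } else {
    // Host pointer mode with alpha = 1.0f, beta = 0.0f.
    qqmm.run(enc, out, a_q, b_q, a_scale, b_scale);
  }
}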
38 changes: 25 additions & 13 deletions mlx/backend/cuda/quantized/cublas_qqmm.h
@@ -25,7 +25,7 @@ class CublasQQMM : public CublasMatmulBase {
int64_t a_batch_stride,
int64_t b_batch_stride,
Dtype out_dtype,
std::string quantization_mode);
const std::string& quantization_mode);

CublasQQMM(
cu::Device& device,
@@ -43,7 +43,7 @@
int64_t b_batch_stride,
int64_t c_batch_stride,
Dtype out_dtype,
std::string quantization_mode);
const std::string& quantization_mode);

void run(
cu::CommandEncoder& encoder,
@@ -52,20 +52,33 @@
const array& b,
const array& a_scale,
const array& b_scale,
float alpha = 1.0f);
const array& alpha,
const array& beta);

private:
void run_batched(
void run(
cu::CommandEncoder& encoder,
array& out,
const array& a,
const array& b,
const array& a_scale,
const array& b_scale,
const Shape& batch_shape,
const Strides& a_batch_strides,
const Strides& b_batch_strides,
float alpha);
const array& b_scale);

private:
void set_scales_ptrs(
cu::CommandEncoder& encoder,
const void* a_scale,
const void* b_scale);

void execute(
cu::CommandEncoder& encoder,
void* out,
const void* a,
const void* b,
const void* a_scale,
const void* b_scale,
const void* c,
const void* alpha,
const void* beta);

void execute(
cu::CommandEncoder& encoder,
@@ -75,10 +88,9 @@
const void* a_scale,
const void* b_scale,
const void* c,
float alpha = 1,
float beta = 0);
const float alpha = 1.0f,
const float beta = 0.0f);

std::string quantization_mode_;
cublasLtMatmulMatrixScale_t a_scale_mode_;
cublasLtMatmulMatrixScale_t b_scale_mode_;
cublasLtMatmulMatrixScale_t c_scale_mode_;