Conversation

@CC-Yeh CC-Yeh commented Jan 20, 2026

Proposed changes

Add Metal quantized SDPA vector kernels based on #1515

Speedup vs fp16 (Quant SDPA)

Config: H=32, H_k=8, D=128, GQA=4x

| SeqLen | mxfp4 (4-bit) | mxfp8 (8-bit) |
| ------ | ------------- | ------------- |
| 2048   | 0.91x         | 1.21x         |
| 4096   | 1.59x         | 1.35x         |
| 8192   | 1.79x         | 1.52x         |
| 16384  | 2.31x         | 1.94x         |
| 32768  | 2.49x         | 2.04x         |
| 65536  | 2.62x         | 2.17x         |
| 131072 | 2.58x         | 2.05x         |

TODO:

What improved performance:

  • Removed per-thread storage of k, v to reduce register pressure (threads were otherwise stalled waiting on synchronization).
  • Fused the computation with dequantization (see the sketch below).
  • Tuned the read width (`uint16_t`/`uint32_t`) for the quantized k/v.
  • Manual unrolling beats the clang loop optimizer.
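
A minimal, illustrative sketch of the fused-dequantization and manual-unroll points above (not the actual kernel in this PR), written in Metal Shading Language. The function name, the packing layout (8 x 4-bit values per uint32_t), and the per-group scale/bias are assumptions for illustration only:

template <typename T>
inline float qk_dot_fused_4bit(
    const thread T* q,   // 8 query values already held in registers
    uint32_t k_packed,   // 8 x 4-bit quantized k values
    float scale,         // per-group dequantization scale
    float bias) {        // per-group dequantization bias
  float acc = 0.0f;
  // Manual unroll: dequantize each nibble and accumulate in one step,
  // instead of relying on the clang loop optimizer.
  acc += float(q[0]) * (float((k_packed >> 0) & 0xF) * scale + bias);
  acc += float(q[1]) * (float((k_packed >> 4) & 0xF) * scale + bias);
  acc += float(q[2]) * (float((k_packed >> 8) & 0xF) * scale + bias);
  acc += float(q[3]) * (float((k_packed >> 12) & 0xF) * scale + bias);
  acc += float(q[4]) * (float((k_packed >> 16) & 0xF) * scale + bias);
  acc += float(q[5]) * (float((k_packed >> 20) & 0xF) * scale + bias);
  acc += float(q[6]) * (float((k_packed >> 24) & 0xF) * scale + bias);
  acc += float(q[7]) * (float((k_packed >> 28) & 0xF) * scale + bias);
  return acc;
}

Keeping only the packed word and the running accumulator in registers is what avoids the per-thread k/v storage mentioned in the first bullet.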

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

awni commented Jan 21, 2026

The numbers seem quite good.. a little too good to be true 😅

What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?

CC-Yeh commented Jan 21, 2026

> The numbers seem quite good.. a little too good to be true 😅

Totally agree, must be missing something 🤔

> What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?

Attention is a simple reference implementation built from matmul + softmax + matmul (maybe too naive?).
SDPA uses mx.fast.scaled_dot_product_attention, which hits the sdpa_vector_2pass kernels when Lq ≤ 8 (as in this case).

The query sequence length here is 1 (q.shape = (1, 32, 1, 128)), so this benchmark is measuring the single-token decode case, where one new token attends to a long KV cache (L = 32768).
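
For concreteness, a minimal sketch of that setup, assuming the MLX C++ API from "mlx/mlx.h" (this is not the benchmark script that produced the numbers above; naive_attention and the random inputs are illustrative, and the KV heads are repeated to undo the 4x GQA for the plain-matmul reference):

#include <cmath>
#include "mlx/mlx.h"
namespace mx = mlx::core;

// Naive reference: matmul + softmax + matmul, with the KV heads repeated
// to match the 32 query heads so plain matmuls can be used.
mx::array naive_attention(const mx::array& q, mx::array k, mx::array v, float scale) {
  k = mx::repeat(k, 4, 1);  // (1, 8, L, 128) -> (1, 32, L, 128)
  v = mx::repeat(v, 4, 1);
  auto scores = mx::matmul(q * scale, mx::transpose(k, {0, 1, 3, 2}));
  return mx::matmul(mx::softmax(scores, -1), v);
}

int main() {
  int L = 32768;                                 // KV cache length
  auto q = mx::random::normal({1, 32, 1, 128});  // single decode token (Lq = 1)
  auto k = mx::random::normal({1, 8, L, 128});   // H_k = 8 KV heads
  auto v = mx::random::normal({1, 8, L, 128});
  float scale = 1.0f / std::sqrt(128.0f);
  mx::eval(naive_attention(q, k, v, scale));
  return 0;
}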

CC-Yeh commented Jan 21, 2026

@awni
Fixed some bugs in the 8-bit dequantization and in the benchmark (it had unnecessary dequantization steps).
Now the numbers make more sense 😃

awni commented Jan 21, 2026

So if I’m understanding correctly the fused implementation is slower in the quantized case than the unfused ops-based one?

CC-Yeh commented Jan 21, 2026

Fused SDPA is faster: MXFP4 15.33 ms vs 24.71 ms, and MXFP8 26.09 ms vs 46.48 ms to decode a single query.

awni commented Jan 21, 2026

Very nice!!

mlx/fast.cpp Outdated
Comment on lines 875 to 878
if (qmode == QuantizationMode::Nvfp4) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention.");
}
Member

Why not nvfp4?

Contributor Author

It’s on the way! I just wanted to make sure the PR structure was okay first.

Contributor Author

Added nvfp4 support.

mlx/fast.cpp Outdated
Comment on lines 871 to 874
if (qmode == QuantizationMode::Affine) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported.");
}
Member

Why not affine?

Member

Btw not suggesting we necessarily do it. Maybe it's better to be more limited in the quants we support here. Maybe fp8, fp4 are fine to start?

For example, I don't think it's necessary to support every bit width, because in practice no one will ever use 2 or 3 bits for KV cache quantization.

Contributor Author

Added initial support; there is still room for tuning bits 2/3/5/6.

awni commented Jan 27, 2026

@CC-Yeh I'm interested in this PR moving forward. Let me know if you have questions. Also no need to support everything on a first pass. I think doing one 8-bit (fp8 / int8) quant well for Metal / CUDA is already probably good enough to start.

@CC-Yeh changed the title from "[WIP] Quantized SDPA" to "Quantized SDPA" on Jan 29, 2026
@CC-Yeh marked this pull request as ready for review on January 29, 2026 at 22:04
CC-Yeh commented Jan 29, 2026

@awni

I’ve added the Metal paths for mxfp4/mxfp8, nvfp4, and affine (2/3/4/5/6/8 bits); the affine path is not yet optimized.
Further tuning likely needs validation on other machines.

For the CUDA path (maybe a next PR), Colab's GPUs don't support NVFP4, so I would need help testing that.

[Figure: quant_sdpa_speedup_vs_seqlen — speedup vs. sequence length]

@CC-Yeh requested a review from awni on January 29, 2026 at 22:22
awni commented Jan 29, 2026

> affine(2/3/4/5/6/8)

What group sizes did you do for that? I'm not convinced we need broad support for bit width × group size. I expect bits < 4 to be used rarely, if ever.

CC-Yeh commented Jan 29, 2026

> affine(2/3/4/5/6/8)
>
> What group sizes did you do for that? I'm not convinced we need broad support there. I expect < 4 to be used rarely.

What group sizes do you think we should support for affine? Currently it's templated so it can handle various sizes, but I can limit the instantiations if there's a specific set that's practical.

template <typename T, int D, QuantMode mode, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(

awni commented Jan 29, 2026

Yes totally. I think it's good to keep it generic. But probably better to limit initial support and grow than vice versa.

I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension right?

CC-Yeh commented Jan 30, 2026

> Yes totally. I think it's good to keep it generic. But probably better to limit initial support and grow than vice versa.
>
> I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension right?

Limited the affine support.

Yeah, 32 is most flexible for head dim.
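
A small host-side sketch of that scope (illustrative only, not the PR's actual dispatch code; quant_sdpa_affine_supported is a hypothetical helper name):

bool quant_sdpa_affine_supported(int group_size, int bits) {
  // Keep the kernel templated over (group_size, bits), but only route the
  // combinations agreed on above: group_size = 32 divides the common head
  // dims, and bits < 4 are rarely used for KV cache quantization.
  return group_size == 32 && (bits == 4 || bits == 6 || bits == 8);
}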
