Quantized SDPA #3026
base: main
Conversation
Force-pushed from b64b7dc to 11b24f5
The numbers seem quite good.. a little too good to be true 😅 What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?
Force-pushed from 11b24f5 to 640ec94
Totally agree, must be missing something 🤔
Attention is a simple reference implementation built from unfused ops. The query sequence length here is 1 (q.shape = (1, 32, 1, 128)), so this benchmark is measuring the single-token decode case, where one new token attends to a long KV cache (L = 32768).
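For context, a minimal sketch of what such an unfused reference attention could look like in MLX Python; the shapes follow the decode case described above, but the function name and exact implementation are assumptions rather than the PR's actual benchmark code.

```python
import mlx.core as mx

# Decode-case shapes from the discussion: H=32 query heads, H_k=8 KV heads
# (GQA 4x), D=128 head dim, one query token against an L=32768 KV cache.
B, H, H_k, D, L = 1, 32, 8, 128, 32768
q = mx.random.normal((B, H, 1, D)).astype(mx.float16)
k = mx.random.normal((B, H_k, L, D)).astype(mx.float16)
v = mx.random.normal((B, H_k, L, D)).astype(mx.float16)

def reference_attention(q, k, v, scale):
    # Expand the KV heads to match the query heads, then run the unfused
    # matmul -> softmax -> matmul sequence.
    n_rep = q.shape[1] // k.shape[1]
    k = mx.repeat(k, n_rep, axis=1)
    v = mx.repeat(v, n_rep, axis=1)
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(scores, axis=-1) @ v

out = reference_attention(q, k, v, scale=D**-0.5)
mx.eval(out)
```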
@awni
So if I’m understanding correctly, the fused implementation is slower in the quantized case than the unfused ops-based one?
Fused SDPA is faster:
Very nice!!
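A minimal timing sketch for such a comparison might look like the following, reusing q, k, v, and reference_attention from the sketch above. Note it times the existing fp16 fused kernel (mx.fast.scaled_dot_product_attention); the new quantized entry point's Python signature isn't shown in this thread, so it is not called here.

```python
import time
import mlx.core as mx

def time_fn(fn, *args, warmup=5, iters=50):
    # Force evaluation each iteration, since MLX is lazy.
    for _ in range(warmup):
        mx.eval(fn(*args))
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(*args))
    return 1e3 * (time.perf_counter() - tic) / iters  # ms per call

# q, k, v, and reference_attention come from the sketch above.
fused_ms = time_fn(
    lambda q, k, v: mx.fast.scaled_dot_product_attention(q, k, v, scale=128**-0.5),
    q, k, v,
)
unfused_ms = time_fn(reference_attention, q, k, v, 128**-0.5)
print(f"fused: {fused_ms:.3f} ms  unfused: {unfused_ms:.3f} ms")
```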
mlx/fast.cpp
Outdated
if (qmode == QuantizationMode::Nvfp4) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention.");
}
Why not nvfp4?
It’s on the way! I just wanted to make sure the PR structure was okay first.
Added support
mlx/fast.cpp
Outdated
if (qmode == QuantizationMode::Affine) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported.");
}
Why not affine?
Btw not suggesting we necessarily do it. Maybe it's better to be more limited in the quants we support here. Maybe fp8, fp4 are fine to start?
For example, I don't think it's necessary to support every bit width, because in practice no one will ever use 2 or 3 bits for KV cache quantization.
Added initial support; there is still room for tuning bits 2/3/5/6.
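For readers following along, the group_size/bits knobs being discussed map onto MLX's existing affine quantization ops. A rough sketch of quantizing a KV cache on the unfused path (illustrative only, with made-up shapes; this is not the fused kernel added by this PR):

```python
import mlx.core as mx

B, H_k, L, D = 1, 8, 4096, 128
k = mx.random.normal((B, H_k, L, D)).astype(mx.float16)

# Affine quantization packs values along the last axis in groups of
# `group_size`, storing per-group scales and biases.
group_size, bits = 32, 8
k_q, k_scales, k_biases = mx.quantize(k, group_size=group_size, bits=bits)

# An unfused path can simply dequantize before computing attention scores.
k_deq = mx.dequantize(k_q, k_scales, k_biases, group_size=group_size, bits=bits)
print(k_q.dtype, k_scales.shape)  # packed uint32 data, scales of shape (B, H_k, L, D // group_size)
```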
@CC-Yeh I'm interested in this PR moving forward. Let me know if you have questions. Also no need to support everything on a first pass. I think doing one 8-bit (fp8 / int8) quant well for Metal / CUDA is already probably good enough to start.
Force-pushed from f3dc49d to 5af4060
What group sizes did you do for that? I'm not convinced we need broad support for bit width × group size. I expect bits < 4 to be used rarely, if ever.
Force-pushed from 3bc3e28 to c72fad9
What group sizes do you think we should support for affine? Currently it's templated, so it can handle various group size / bit width combinations:
template <typename T, int D, QuantMode mode, int group_size, int bits>
[[kernel]] void quant_sdpa_vector_2pass_1(
Yes totally. I think it's good to keep it generic. But probably better to limit initial support and grow than vice versa. I would maybe start with bits = {4, 6, 8} and just group_size = 32. I think 32 is most flexible for the head dimension, right?
Limited the affine support. Yeah, 32 is most flexible for head dim. |
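A quick illustration of the "most flexible for head dim" point: affine group quantization needs the quantized axis to be a multiple of the group size, so group_size = 32 covers the widest range of head dimensions. (The head-dim values below are illustrative, not from the PR.)

```python
# Smaller groups divide more of the head dimensions seen in practice.
for D in (64, 96, 128, 256):
    valid = [g for g in (32, 64, 128) if D % g == 0]
    print(f"D={D}: valid group sizes {valid}")
```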

Proposed changes
Add Metal quantized SDPA vector kernels based on #1515
Speedup vs fp16 (Quant SDPA)
Config: H=32, H_k=8, D=128, GQA=4x
TODO:
- Affine and NVFP4

What improves performance:
- k/v
- clang loop optimizer

Checklist
- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes