feat(quantized): add multi-dtype support for bf16/f16 activations #3310
Summary
This PR enables running quantized models with different activation data types (f32, bf16, f16) via the new `--dtype` flag in the quantized example. Using half-precision activations can significantly improve inference speed on GPUs with fast fp16/bf16 tensor cores while maintaining model quality.
Key Changes
User-Facing
- New `--dtype` flag in the quantized example to select activation precision:
  - `--dtype f32` (default): Standard 32-bit floating point
  - `--dtype bf16`: BFloat16 - better numerical range, ideal for newer NVIDIA GPUs
  - `--dtype f16`: Float16 - maximum memory savings, widely supported
CUDA Kernel Enhancements
- Extended `quantized.cu` with F16/BF16 output support for all quantized matmul kernels
Metal Kernel Enhancements
- Extended `quantized.metal` with F16/BF16 output support for Apple Silicon
Core Infrastructure
- `QMatMul::forward()` now handles dtype mismatches automatically via auto-conversion (see the sketch below)
- New `QMatMul::from_arc_with_transposed_data()` for GGUF files from diffusion tools (stable-diffusion.cpp) that use different data layouts
- New `RmsNorm::from_qtensor_with_dtype()` for eager dtype conversion at load time
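For illustration, here is a minimal sketch of the auto-conversion idea, assuming half-precision activations are round-tripped through f32 around the quantized matmul; the helper name is hypothetical and this is not the PR's actual implementation (the PR's kernels can also emit f16/bf16 directly):

```rust
use candle_core::quantized::QMatMul;
use candle_core::{DType, Module, Result, Tensor};

/// Hypothetical helper sketching the auto-conversion behaviour: accept
/// activations in any float dtype, run the quantized matmul in f32, and
/// return the result in the caller's original dtype.
fn qmatmul_any_dtype(qm: &QMatMul, xs: &Tensor) -> Result<Tensor> {
    let in_dtype = xs.dtype();
    // Quantized matmul kernels traditionally expect f32 activations, so
    // convert bf16/f16 inputs up-front (a cheap shallow clone for f32).
    let xs_f32 = if in_dtype == DType::F32 {
        xs.clone()
    } else {
        xs.to_dtype(DType::F32)?
    };
    let ys = qm.forward(&xs_f32)?;
    // Hand the output back in the caller's activation dtype (e.g. bf16/f16).
    if in_dtype == DType::F32 {
        Ok(ys)
    } else {
        ys.to_dtype(in_dtype)
    }
}
```

The PR performs this handling inside `QMatMul::forward()` itself, so callers do not need a wrapper like this.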
Model Loading Improvements
- `ModelWeights::from_gguf()` and `ModelWeights::from_ggml()` now auto-infer the activation dtype from the embedding tensor's storage format (F16/BF16 embeddings → matching activation dtype); see the sketch below
- New `ModelWeights::from_gguf_with_dtype()` and `ModelWeights::from_ggml_with_dtype()` for explicit dtype control
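As a rough sketch of that inference rule on the GGUF path: the tensor name `token_embd.weight`, the availability of a BF16 ggml dtype variant, and the f32 fallback are assumptions for illustration, not necessarily what the PR does:

```rust
use candle_core::quantized::{gguf_file, GgmlDType};
use candle_core::DType;

/// Hypothetical sketch: choose the activation dtype from how the token
/// embedding tensor is stored in the GGUF file. F16/BF16 embeddings imply
/// matching half-precision activations; anything else falls back to f32.
fn infer_activation_dtype(content: &gguf_file::Content) -> DType {
    match content.tensor_infos.get("token_embd.weight") {
        Some(info) => match info.ggml_dtype {
            GgmlDType::F16 => DType::F16,
            GgmlDType::BF16 => DType::BF16, // assumes a BF16 ggml dtype variant
            _ => DType::F32,
        },
        None => DType::F32,
    }
}
```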
Bug Fixes
Files Changed
- candle-core/src/quantized/cuda.rs
- candle-core/src/quantized/metal.rs
- candle-core/src/quantized/mod.rs
- candle-examples/examples/quantized/main.rs
- candle-kernels/src/quantized.cu
- candle-metal-kernels/src/kernels/quantized.rs
- candle-metal-kernels/src/metal_src/quantized.metal
- candle-nn/src/layer_norm.rs
- candle-nn/src/ops.rs
- candle-transformers/src/models/quantized_llama.rs
- candle-transformers/src/quantized_nn.rs
Total: +3,441 lines, -522 lines across 11 files
Usage Example
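A plausible invocation (the exact model, weight, and prompt flags are omitted here and depend on what you run): `cargo run --release --example quantized -- --dtype bf16`. Passing `--dtype f16` selects Float16 activations, and omitting the flag keeps the default f32 behaviour.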
Performance Impact
Using `--dtype bf16` or `--dtype f16` can improve inference speed on GPUs with fast fp16/bf16 tensor cores and reduce activation memory usage.
Compatibility Notes
- GGUF files from diffusion tools such as stable-diffusion.cpp, which use a transposed data layout, are supported via the new constructor (`QMatMul::from_arc_with_transposed_data()`)
Testing
Tested with various quantized models (Q4_0, Q4_K_M, Q5_K_M, Q8_0) using f32, bf16, and f16 activation dtypes on both Metal and CUDA hardware.
Although this is unrelated to my code changes, I'm fixing it here because it causes clippy to fail on macOS:
Clippy Error: