Add LoRA support for quantized FLUX transformer models #3364
Summary
This PR adds Low-Rank Adaptation (LoRA) support for quantized FLUX transformer models, enabling fine-tuning and adaptation of quantized models with minimal memory overhead.
Motivation
Quantized models significantly reduce the memory requirements of large transformers like FLUX, but they currently lack support for LoRA-based fine-tuning. This implementation bridges that gap by applying LoRA deltas on top of the quantized linear layers, so adapters can be injected and removed without touching the quantized base weights.
Changes
1. Core LoRA Infrastructure (`candle-transformers/src/quantized_nn.rs`)

Extended the `Linear` struct with LoRA parameters:
- `lora_a`: down-projection matrix (`in_dim × rank`)
- `lora_b`: up-projection matrix (`rank × out_dim`)
- `lora_scale`: pre-computed scaling factor (`(alpha / rank) * strength`)

Added methods:
- `set_lora()`: inject LoRA weights with validation
- `clear_lora()`: remove LoRA weights
- `has_lora()`: check whether LoRA is active

Modified `forward()` so the LoRA delta is added on top of the quantized matmul (see the sketch after this list):
- `y = x @ W_quantized`
- `y = y + (x @ A @ B) * scale`

2. FLUX-Specific Integration (`candle-transformers/src/models/flux/quantized_model.rs`)

Added to the `Flux` struct:
- `inject_loras()`: batch-inject LoRA weights into all model layers
- `clear_loras()`: remove all LoRA weights from the model

Layer name matching follows FLUX weight conventions:
- `double.blocks.{idx}.{img|txt}.{mod|attn|mlp}.{layer}.weight`
- `single.blocks.{idx}.{modulation|linear1|linear2}.weight`
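The sketch below is not the code from this PR; the struct, field, and method names are assumed for illustration. It shows the shape of the change: the quantized matmul is left untouched, and the low-rank delta is added on top, with inputs flattened from `[batch, seq, dim]` to `[batch*seq, dim]` for the LoRA matmuls.

```rust
use candle_core::quantized::QMatMul;
use candle_core::{Result, Tensor};

/// Illustrative LoRA-augmented quantized linear layer (names assumed, not the PR's exact code).
struct LoraLinear {
    weight: QMatMul,        // quantized base weights, untouched by LoRA
    bias: Option<Tensor>,
    lora_a: Option<Tensor>, // [in_dim, rank]
    lora_b: Option<Tensor>, // [rank, out_dim]
    lora_scale: f64,        // (alpha / rank) * strength
}

impl LoraLinear {
    /// Attach LoRA weights after basic shape validation.
    fn set_lora(&mut self, a: Tensor, b: Tensor, scale: f64) -> Result<()> {
        let (_, a_rank) = a.dims2()?;
        let (b_rank, _) = b.dims2()?;
        if a_rank != b_rank {
            candle_core::bail!("LoRA rank mismatch: A has rank {a_rank}, B has rank {b_rank}");
        }
        self.lora_a = Some(a);
        self.lora_b = Some(b);
        self.lora_scale = scale;
        Ok(())
    }

    /// Drop the LoRA weights, restoring the plain quantized layer.
    fn clear_lora(&mut self) {
        self.lora_a = None;
        self.lora_b = None;
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // y = x @ W_quantized (+ bias)
        let mut y = self.weight.forward(x)?;
        if let Some(bias) = &self.bias {
            y = y.broadcast_add(bias)?;
        }
        // y += (x @ A @ B) * scale
        if let (Some(a), Some(b)) = (&self.lora_a, &self.lora_b) {
            let (bsz, seq, dim) = x.dims3()?;       // [batch, seq, dim]
            let x2d = x.reshape((bsz * seq, dim))?; // -> [batch*seq, dim]
            let delta = x2d.matmul(a)?.matmul(b)?;  // [batch*seq, out_dim]
            let delta = (delta * self.lora_scale)?;
            let (_, out_dim) = delta.dims2()?;
            y = (y + delta.reshape((bsz, seq, out_dim))?)?;
        }
        Ok(y)
    }
}
```

Keeping the delta in a separate pair of small tensors means adapters can be swapped or removed without re-quantizing or reloading the base weights.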
Technical Details

LoRA Computation: `y = x @ W_quantized + (x @ A @ B) * scale`, with `scale = (alpha / rank) * strength`.

Memory Overhead: each adapted layer stores only the two LoRA matrices, i.e. `rank × (in_dim + out_dim)` extra parameters; the quantized base weights are unchanged.

Shape Handling: inputs of shape `[batch, seq, dim]` are flattened to `[batch*seq, dim]` for the LoRA matmuls.

Testing
Tested in production with:
Performance:
Example Usage
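The original usage example did not survive extraction; the following is a hypothetical sketch of driving the model-level API. The `inject_loras` signature, the return type of `clear_loras`, the concrete layer key, and the tensor shapes are assumptions, not the PR's exact interface.

```rust
use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};
use candle_transformers::models::flux::quantized_model::Flux;

fn run_with_lora(flux: &mut Flux) -> Result<()> {
    let device = Device::cuda_if_available(0)?;

    // LoRA pairs keyed by FLUX layer names (see the naming conventions above).
    // The key and the dimensions below are illustrative only.
    let (in_dim, rank, out_dim) = (3072, 16, 3072);
    let mut loras: HashMap<String, (Tensor, Tensor)> = HashMap::new();
    loras.insert(
        "double.blocks.0.img.attn.proj.weight".to_string(),
        (
            Tensor::zeros((in_dim, rank), DType::F32, &device)?,  // A: in_dim x rank
            Tensor::zeros((rank, out_dim), DType::F32, &device)?, // B: rank x out_dim
        ),
    );

    // Inject the adapters, run inference, then restore the base model.
    flux.inject_loras(&loras, /* strength */ 1.0)?; // signature assumed
    // ... run the sampling loop as usual ...
    flux.clear_loras(); // return type assumed

    Ok(())
}
```

Because `clear_loras()` removes all LoRA weights again, several adapters can be tried one after another against the same quantized checkpoint without reloading it.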
Future Work
Potential extensions:
Checklist
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>