Support direct export of block-wise FP8 weights and scaling factors #1994

eternally-z wants to merge 9 commits into NVIDIA-NeMo:main
Conversation
📝 Walkthrough

This PR adds FP8 weight export support to the Megatron bridge conversion pipeline. It introduces FP8 parameter detection, blockwise scale parameter handling, and new adapter weight conversion task structures. The changes extend the weight loading and export paths to support FP8 quantization with proper `scale_inv` parameter naming and unquantized state dict capture.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant AutoBridge
    participant MegatronModelBridge
    participant FP8Detector
    participant ParamMapping
    User->>AutoBridge: export_hf_weights(export_weight_dtype="fp8")
    AutoBridge->>MegatronModelBridge: build_export_fp8_tasks()
    MegatronModelBridge->>FP8Detector: _detect_fp8_params()
    FP8Detector->>FP8Detector: scan parameters for scale_inv attributes
    FP8Detector-->>MegatronModelBridge: fp8_param_map (param_name → is_fp8)
    MegatronModelBridge->>ParamMapping: build tasks with _HFNameSuffixMapping
    ParamMapping->>ParamMapping: split_qkv_weights (handle scale-domain tensors)
    MegatronModelBridge-->>AutoBridge: export_tasks (including FP8 scale_inv variants)
    AutoBridge-->>User: weights exported with FP8 parameters
```
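The detection step in the diagram above can be sketched as follows. This is a hypothetical illustration: the `Float8BlockwiseQTensor` import path matches TransformerEngine, but the function name and the shape of the returned map are illustrative, not the PR's actual code.

```python
def detect_fp8_params(named_params):
    """Map each parameter name to True if its tensor is blockwise FP8."""
    try:
        from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
    except ImportError:
        # TransformerEngine not installed: no parameter can be blockwise FP8
        Float8BlockwiseQTensor = None

    fp8_map = {}
    for name, param in named_params:
        fp8_map[name] = Float8BlockwiseQTensor is not None and isinstance(
            param, Float8BlockwiseQTensor
        )
    return fp8_map
```

When TransformerEngine is unavailable, every parameter is reported as non-FP8, so the export path degrades gracefully to the standard BF16/FP16 flow.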
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/megatron/bridge/models/conversion/model_bridge.py (1)
555-617: Update docstring return type to match actual return value.

The method now returns `(megatron_model, self.unquantized_state_dict)` - a tuple - but the docstring at line 509 still declares `Returns: List[MegatronModel]`. This is a breaking change affecting callers registered as pre_wrap_hooks (lines 869, 875 in auto_bridge.py). When a hook returns a tuple, subsequent code that iterates over `model` (line 560) will fail. Update the docstring:

Docstring fix

```diff
 Returns:
-    List[MegatronModel]: The input megatron_model as a list with loaded weights.
+    Tuple[List[MegatronModel], Dict[str, torch.Tensor] | None]: A tuple containing:
+        - The input megatron_model as a list with loaded weights
+        - The unquantized state dict if FP8 capture was enabled, otherwise None
```

Additionally, callers using this method as a pre_wrap_hook must wrap the return or adjust the hook signature to return only the model list, not a tuple. The direct caller at auto_bridge.py:324 already handles unpacking correctly, but hook-based registrations (lines 869, 875) will fail.
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/conversion/auto_bridge.py`:
- Around line 377-384: Add an explicit validation at the start of
build_export_fp8_tasks to raise a clear error if the input model list is empty:
check that the provided model parameter (used later with
unwrap_model(megatron_model)[0]) is a non-empty list and raise a ValueError (or
similar) with a descriptive message if empty; this prevents an IndexError
downstream and complements existing _detect_fp8_params behavior when callers
pass an empty list directly to build_export_fp8_tasks.
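The guard described in the prompt above could look like this. A minimal sketch only: the body is elided and the signature is simplified relative to the real method.

```python
def build_export_fp8_tasks(hf_pretrained, megatron_model):
    if not megatron_model:  # rejects both None and an empty list
        raise ValueError(
            "build_export_fp8_tasks expects a non-empty list of Megatron "
            "model chunks; an empty input would fail later in "
            "unwrap_model(megatron_model)[0]."
        )
    # ... build and return the FP8 export tasks ...
```

A `ValueError` with a descriptive message is far easier to diagnose than the `IndexError` that would otherwise surface deep inside the export path.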
In `@src/megatron/bridge/models/conversion/model_bridge.py`:
- Around line 160-187: The file defines a local dataclass
AdapterWeightConversionTask that duplicates and shadows the
AdapterWeightConversionTask imported from peft_bridge; remove the local
AdapterWeightConversionTask definition (the dataclass starting at the second
`@dataclass(frozen=True)` block) so the module uses the imported
AdapterWeightConversionTask, and keep the related AdapterWeight dataclass and
references to WeightConversionTask and MegatronWeightTuple unchanged; if the
local docstring contains important details, move that text into the
AdapterWeightConversionTask docstring in peft_bridge instead.
🧹 Nitpick comments (4)
src/megatron/bridge/models/conversion/auto_bridge.py (2)

324-328: Attribute initialization concern: `unquantized_state_dict` should be initialized in `__init__`.

The `unquantized_state_dict` attribute is assigned here but never initialized in `__init__`. Per coding guidelines: "Initialize all externally visible members of a class in the constructor." This could lead to `AttributeError` if accessed before `load_hf_weights` is called.

♻️ Proposed fix: add initialization in `__init__`:

```diff
 def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig):
     if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)):
         raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance")
     self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained
     # Data type for exporting weights
     self.export_weight_dtype: Literal["bf16", "fp16", "fp8"] = "bf16"
+    # Unquantized state dict captured during FP8 weight loading (None if not FP8)
+    self.unquantized_state_dict: dict[str, torch.Tensor] | None = None
```
998-1001: Property creates a new bridge instance on each access - potential inefficiency.

The `_model_bridge` property instantiates a new `MegatronModelBridge` each time it's accessed via `get_model_bridge()`, then sets `export_weight_dtype` on it. If `_model_bridge` is accessed multiple times in a single operation, this could cause redundant object creation, and any cached state on the bridge would be lost between accesses. Consider caching the bridge instance or documenting that each access returns a fresh instance.

♻️ Optional: cache the model bridge

```diff
+@cached_property
 def _model_bridge(self) -> "MegatronModelBridge":
     bridge = model_bridge.get_model_bridge(self._causal_lm_architecture)
     bridge.export_weight_dtype = self.export_weight_dtype
     return bridge
```

Note: if caching, you'd need to handle the case where `export_weight_dtype` changes after the first access.

src/megatron/bridge/models/conversion/model_bridge.py (2)
1093-1165: `_detect_fp8_params`: catching bare `Exception` is overly broad.

Lines 1122-1123 and 1147-1148 catch bare `Exception`, which can hide unexpected errors. For FP8 tensor type detection, catching `ImportError` and `TypeError` would be more appropriate.

♻️ Proposed fix for exception handling

```diff
 try:
     from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
-except Exception:
+except ImportError:
     Float8BlockwiseQTensor = None

 if Float8BlockwiseQTensor is not None:
     try:
         is_blockwise_fp8 = isinstance(local_weights, Float8BlockwiseQTensor)
-    except Exception:
+    except TypeError:
         is_blockwise_fp8 = False
```
1234-1244: Use direct attribute access instead of `setattr`/`getattr` with constant values.

Static analysis flagged B009/B010: using `setattr` and `getattr` with constant attribute names is not safer than direct attribute access.

♻️ Proposed fix

```diff
 if local_module is not None and not hasattr(local_module, "config"):
-    setattr(local_module, "config", model_config)
+    local_module.config = model_config

 # Main (weight/bias) task
 export_weight_tensor = local_weights
 if global_fp8_flags.get(global_name, False):
     if local_weights is not None and hasattr(local_weights, "_rowwise_data"):
-        rd = getattr(local_weights, "_rowwise_data")
+        rd = local_weights._rowwise_data
         if rd is not None:
             export_weight_tensor = rd
```
```python
# Build conversion tasks based on export_weight_dtype configuration
if conversion_tasks is None and self.export_weight_dtype == "fp8":
    if not isinstance(model, list):
        model = [model]
    # Use FP8 export tasks for blockwise FP8 weights
    conversion_tasks = self._model_bridge.build_export_fp8_tasks(
        self.hf_pretrained, model
    )
```
🧩 Analysis chain

🏁 Scripts executed (condensed): the reviewer inspected `build_export_fp8_tasks` and all of its callers, the calling context in `auto_bridge.py` (lines 331-400, including the `export_hf_weights` signature and the non-list-to-list conversion of `model`), the full method body in `model_bridge.py` (lines 1167-1300), and searched for the `unwrap_model` definition, which was not found in the repository. Representative commands:

```bash
rg -n -A 15 'def build_export_fp8_tasks' src/megatron/bridge/models/conversion/
sed -n '331,400p' src/megatron/bridge/models/conversion/auto_bridge.py
sed -n '1167,1300p' src/megatron/bridge/models/conversion/model_bridge.py
rg -n 'build_export_fp8_tasks' src/megatron/bridge/models/conversion/
rg -n 'def unwrap_model' --type py
```
Add validation for empty model list in build_export_fp8_tasks.
The FP8 export path correctly handles models without FP8 parameters through _detect_fp8_params and graceful flag checking with .get() defaults. However, build_export_fp8_tasks lacks explicit validation for empty model lists. While the calling code in export_hf_weights (line 379-380) converts non-list models to single-element lists, ensuring they're never empty in practice, adding an explicit check at the start of build_export_fp8_tasks would prevent an unhandled IndexError from unwrap_model(megatron_model)[0] if someone calls the method directly with an empty list.
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/conversion/auto_bridge.py` around lines 377 - 384,
Add an explicit validation at the start of build_export_fp8_tasks to raise a
clear error if the input model list is empty: check that the provided model
parameter (used later with unwrap_model(megatron_model)[0]) is a non-empty list
and raise a ValueError (or similar) with a descriptive message if empty; this
prevents an IndexError downstream and complements existing _detect_fp8_params
behavior when callers pass an empty list directly to build_export_fp8_tasks.
```python
@dataclass(frozen=True)
class AdapterWeightConversionTask:
    """Task describing an adapter's LoRA weights for conversion or merging.

    The task reuses :class:`WeightConversionTask` to gather the adapter's
    linear_in/linear_out weights (if they are tensor-parallel) and carries the
    adapter metadata required by the merge step.
    """

    global_base_prefix: str
    adapter_key: Optional[str]  # For canonical LoRA only
    alpha: int
    dim: int
    linear_in_task: WeightConversionTask
    linear_out_task: WeightConversionTask


@dataclass(frozen=True)
class AdapterWeight:
    """Materialized adapter weights ready for merge."""

    global_base_prefix: str
    adapter_key: Optional[str]  # For canonical LoRA only
    alpha: int
    dim: int
    linear_in_weight: MegatronWeightTuple
    linear_out_weight: MegatronWeightTuple
```
🧩 Analysis chain

🏁 Scripts executed (condensed): the reviewer confirmed that `AdapterWeightConversionTask` is imported from `peft_bridge` at line 53 of `model_bridge.py` and redefined locally at line 161, and compared the two definitions:

```bash
rg -n "from.*peft_bridge.*import|import.*peft_bridge" src/megatron/bridge/models/conversion/model_bridge.py
rg -n -A 20 "class AdapterWeightConversionTask" src/megatron/bridge/models/conversion/peft_bridge.py
sed -n '155,190p' src/megatron/bridge/models/conversion/model_bridge.py
rg -n "AdapterWeightConversionTask" src/megatron/bridge/models/conversion/model_bridge.py
```
Remove redundant AdapterWeightConversionTask definition that shadows the imported class.
AdapterWeightConversionTask is imported from peft_bridge at line 53 but redefined at line 161, making the import inaccessible. Both definitions are identical (same fields, same @dataclass(frozen=True) decorator). Remove the local definition at line 161 and use only the imported one.
Proposed fix

```diff
-@dataclass(frozen=True)
-class AdapterWeightConversionTask:
-    """Task describing an adapter's LoRA weights for conversion or merging.
-
-    The task reuses :class:`WeightConversionTask` to gather the adapter's
-    linear_in/linear_out weights (if they are tensor-parallel) and carries the
-    adapter metadata required by the merge step.
-    """
-
-    global_base_prefix: str
-    adapter_key: Optional[str]  # For canonical LoRA only
-    alpha: int
-    dim: int
-    linear_in_task: WeightConversionTask
-    linear_out_task: WeightConversionTask
-
-
 @dataclass(frozen=True)
 class AdapterWeight:
     """Materialized adapter weights ready for merge."""
```

If the more detailed docstring in the local definition is valuable, update the docstring in peft_bridge.py instead.
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
@dataclass(frozen=True)
class AdapterWeight:
    """Materialized adapter weights ready for merge."""

    global_base_prefix: str
    adapter_key: Optional[str]  # For canonical LoRA only
    alpha: int
    dim: int
    linear_in_weight: MegatronWeightTuple
    linear_out_weight: MegatronWeightTuple
```
🧰 Tools
🪛 Ruff (0.14.13)
161-161: Redefinition of unused AdapterWeightConversionTask from line 53: AdapterWeightConversionTask redefined here
(F811)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/conversion/model_bridge.py` around lines 160 -
187, The file defines a local dataclass AdapterWeightConversionTask that
duplicates and shadows the AdapterWeightConversionTask imported from
peft_bridge; remove the local AdapterWeightConversionTask definition (the
dataclass starting at the second `@dataclass(frozen=True)` block) so the module
uses the imported AdapterWeightConversionTask, and keep the related
AdapterWeight dataclass and references to WeightConversionTask and
MegatronWeightTuple unchanged; if the local docstring contains important
details, move that text into the AdapterWeightConversionTask docstring in
peft_bridge instead.
Signed-off-by: eternally-z <zzywzj@gmail.com>
Force-pushed d5e268d to 061a1f7

Force-pushed 061a1f7 to bfe788b
Signed-off-by: eternally-z <zzywzj@gmail.com>
```python
else:
    hidden_size = qkv.shape[-1]
qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size)
# NOTE: For standard (BF16/FP16) weights, `head_size` is the usual kv_channels/head_dim.
```
There was a problem hiding this comment.
Yes, besides QKV, we found that gdn_linear_weights and kv_weights are also affected.
The root cause is similar: the original code reshapes tensors using standard model dimensions (e.g., qkvz = qkvz.reshape(-1, hidden_size)). This logic fails for FP8 scale parameters because their shapes are determined by the block size (scaled down), causing dimension mismatches.
We verified this on Qwen3-8B and Qwen3-30B-A3B. For now, we have temporarily only updated the split logic for the QKV module to accommodate the FP8 scale dimensions, and we can easily apply this to gdn_linear_weights and kv_weights.
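The dimension mismatch described above can be made concrete. This is an illustrative sketch, not the PR's code: with 128x128 quantization blocks (the common blockwise-FP8 layout), a weight of shape `(out_features, in_features)` carries a `scale_inv` tensor of shape `(ceil(out/128), ceil(in/128))`, so reshape logic written against model dimensions such as `hidden_size` cannot apply to the scale tensor.

```python
import math

def scale_inv_shape(weight_shape, block_size=128):
    """Shape of the blockwise inverse-scale tensor for a given weight shape."""
    return tuple(math.ceil(d / block_size) for d in weight_shape)
```

For a 4096-wide weight, the matching scale dimension is 32, which is why a `reshape(-1, hidden_size)` written for the weight fails on the scale.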
```python
    return global_fp8_flags


def build_export_fp8_tasks(
```
is this only for export, import will not work
Yes, this PR specifically targets the export of block-wise FP8 weights and scales. Importing FP8 weights/scales is not supported at this stage.
/ok to test b2c60df

@eternally-z: The lint is failing, plz check Contribute.md, need to run pre-commit

/ok to test fd9ba1d
Signed-off-by: eternally-z <zzywzj@gmail.com>
Signed-off-by: eternally-z <zzywzj@gmail.com>
Signed-off-by: eternally-z <zzywzj@gmail.com>
Force-pushed fd9ba1d to 890403d
@yaoyu-33 hi, we updated the access method for `unquantized_state_dict`.
What does this PR do?
This PR adds support for directly exporting block-wise FP8 weights and their corresponding inverse scales (`scale_inv`) when calling `export_hf_weights`.

Motivation

In scenarios such as LLM Reinforcement Learning with FP8 training and rollout, the current workflow typically requires de-quantizing FP8 weights back to BF16 during export, only for them to be re-quantized during the rollout phase. This "de-quantize -> re-quantize" round trip introduces unnecessary conversion overhead. By allowing direct export of FP8 weights and scales, we enable a seamless and efficient pipeline from training to rollout.
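For intuition, the round trip being eliminated includes a dequantize step like the following. This is a pure-Python toy sketch, not the PR's code; the function name and the tiny block size are illustrative.

```python
def dequantize_blockwise(q, scale_inv, block_size=2):
    """Expand a blockwise inverse-scale grid back over the quantized weight."""
    rows, cols = len(q), len(q[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            # each (block_size x block_size) tile shares one scale entry
            out[i][j] = q[i][j] * scale_inv[i // block_size][j // block_size]
    return out
```

Exporting `q` and `scale_inv` directly lets the rollout engine consume them as-is instead of paying for this expansion and a subsequent re-quantization.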
Implementation Details
export_hf_weights:build_export_fp8_tasks.scale_invfactors directly, bypassing the conversion to BF16.load_hf_weights:state_dictwhenfp8_param_gatheris enabled.Scope & Limitations
@AniZpZ
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Summary by CodeRabbit
New Features
Improvements