Support direct export of block-wise FP8 weights and scaling factors#1994

Open
eternally-z wants to merge 9 commits into NVIDIA-NeMo:main from eternally-z:transfer_fp8_weights

Conversation


@eternally-z eternally-z commented Jan 20, 2026

What does this PR do?

This PR adds support for directly exporting block-wise FP8 weights and their corresponding inverse scales (scale_inv) when calling export_hf_weights.

Motivation

In scenarios such as LLM Reinforcement Learning with FP8 training and rollout, the current workflow typically requires de-quantizing FP8 weights back to BF16 during export, only for them to be re-quantized during the rollout phase.

This "de-quantize -> re-quantize" process introduces:

  1. Redundant computational overhead.
  2. Potential precision loss due to repeated conversion.
  3. Increased transfer size: Exporting as BF16 consumes 2x the storage compared to the native FP8 format.

By allowing direct export of FP8 weights and scales, we enable a seamless and efficient pipeline from training to rollout.
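For context, the block-wise scheme can be sketched as follows. This is an illustrative emulation only: the helper names and the 128x128 block size are assumptions, real code stores actual float8_e4m3 tensors via Transformer Engine rather than the float32 stand-in used here.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3


def blockwise_fp8_quant(w, block=128):
    """Illustrative block-wise quantization: one scale_inv per (block x block)
    tile, chosen so the tile fits the FP8 range. float32 stands in for the
    real FP8 storage type."""
    rows, cols = w.shape
    nbr, nbc = -(-rows // block), -(-cols // block)  # ceil-division
    scale_inv = np.empty((nbr, nbc), dtype=np.float32)
    q = np.empty_like(w, dtype=np.float32)
    for i in range(nbr):
        for j in range(nbc):
            r = slice(i * block, (i + 1) * block)
            c = slice(j * block, (j + 1) * block)
            s_inv = max(np.abs(w[r, c]).max(), 1e-12) / FP8_E4M3_MAX
            scale_inv[i, j] = s_inv
            q[r, c] = np.clip(w[r, c] / s_inv, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale_inv


def blockwise_fp8_dequant(q, scale_inv, block=128):
    """Inverse mapping: dequant = q * scale_inv, broadcast per tile."""
    out = np.empty_like(q, dtype=np.float32)
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            r = slice(i * block, (i + 1) * block)
            c = slice(j * block, (j + 1) * block)
            out[r, c] = q[r, c] * scale_inv[i, j]
    return out
```

Exporting q and scale_inv directly, instead of materializing the dequantized product, is what lets the rollout engine skip the BF16 round trip.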

Implementation Details

  1. export_hf_weights:

    • Introduced a new function build_export_fp8_tasks.
    • This function allows the exporter to dump raw FP8 weights and their scale_inv factors directly, bypassing the conversion to BF16.
  2. load_hf_weights:

    • Modified the loading process to retain the original state_dict when fp8_param_gather is enabled.
    • This ensures that Megatron initializes the model from the original BF16 weights rather than from the de-quantized FP8 parameters.
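The export side of the list above can be pictured roughly as follows. This is a sketch only: `_rowwise_data` appears in the PR discussion, while `_rowwise_scale_inv`, the helper name, and the `.weight_scale_inv` naming convention are assumptions here, not the bridge's actual API.

```python
def collect_fp8_export_entries(named_params):
    """Yield (hf_name, tensor) pairs for export. For block-wise FP8 params,
    emit the raw FP8 payload plus a companion '<name>_scale_inv' entry
    instead of dequantizing to BF16."""
    for name, param in named_params:
        rowwise = getattr(param, "_rowwise_data", None)
        scale_inv = getattr(param, "_rowwise_scale_inv", None)  # assumed attr
        if rowwise is not None and scale_inv is not None:
            yield name, rowwise  # raw FP8 data, no BF16 conversion
            yield name.replace(".weight", ".weight_scale_inv"), scale_inv
        else:
            yield name, param  # standard (non-FP8) path
```

Non-FP8 parameters (norms, embeddings) fall through unchanged, so the same task list covers mixed-precision checkpoints.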

Scope & Limitations

  • Supported: Dense models & MoE models.
  • Verification: Verified on Qwen3-8B (Dense) and Qwen3-30B-A3B (MoE).

@AniZpZ

Changelog

  • Add specific line-by-line info of high-level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • New Features

    • Introduced a new weight export format option providing enhanced control over model precision during conversion.
    • Added support for adapter weight conversion during model merging operations.
  • Improvements

    • Enhanced weight conversion with improved handling of quantized parameters and scale data.
    • Strengthened QKV weight tensor handling to support diverse weight configurations during model transformation.



copy-pr-bot bot commented Jan 20, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Jan 26, 2026

📝 Walkthrough

This PR adds FP8 weight export support to the Megatron bridge conversion pipeline. It introduces FP8 parameter detection, blockwise scale parameter handling, and new adapter weight conversion task structures. The changes extend the weight loading and export paths to support FP8 quantization with proper scale_inv parameter naming and unquantized state dict capture.

Changes

Cohort / File(s) / Summary:

  • FP8 Export Support (src/megatron/bridge/models/conversion/auto_bridge.py): Added export_weight_dtype attribute with "bf16"/"fp16"/"fp8" options; modified export_hf_weights to trigger FP8 task building when appropriate; extended _model_bridge to propagate export settings; updated load_hf_weights to capture and store unquantized_state_dict from the bridge.
  • FP8 Export Support (src/megatron/bridge/models/conversion/model_bridge.py): Introduced _HFNameSuffixMapping wrapper for FP8 scale parameter naming; added AdapterWeightConversionTask and AdapterWeight dataclasses; modified load_weights_hf_to_megatron to return the tuple (model, unquantized_state_dict); implemented _detect_fp8_params for identifying block-wise FP8 parameters; implemented build_export_fp8_tasks for constructing FP8-specific export tasks with appropriate scale_inv handling.
  • FP8 Export Support (src/megatron/bridge/models/conversion/param_mapping.py): Extended split_qkv_weights to handle scale-domain QKV tensors by inferring dimension divisors from provider.hidden_size when the last dimension doesn't match the model hidden size; added validation and resizing logic for FP8 scale tensors while preserving existing behavior for standard weights.
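The divisor inference for scale-domain tensors can be sketched like this. The helper is hypothetical; the actual logic lives inside split_qkv_weights and reads provider.hidden_size.

```python
import math


def infer_scale_divisor(last_dim: int, hidden_size: int) -> int:
    """Return the factor by which a tensor's last dim is scaled down relative
    to hidden_size. 1 means a standard weight; a larger value (typically the
    FP8 quantization block size, e.g. 128) means the tensor is a block-wise
    scale_inv whose dims are ceil(model_dim / block)."""
    if last_dim == hidden_size:
        return 1  # ordinary BF16/FP16 weight
    divisor = math.ceil(hidden_size / last_dim)
    # sanity-check that last_dim really is a block-scaled hidden_size
    if math.ceil(hidden_size / divisor) != last_dim:
        raise ValueError(
            f"last dim {last_dim} is not a block-scaled version of {hidden_size}"
        )
    return divisor
```

For hidden_size=4096 and a 128-wide block, a scale tensor's last dim is 32, and the divisor recovered is 128.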

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant AutoBridge
    participant MegatronModelBridge
    participant FP8Detector
    participant ParamMapping

    User->>AutoBridge: export_hf_weights(export_weight_dtype="fp8")
    AutoBridge->>MegatronModelBridge: build_export_fp8_tasks()
    MegatronModelBridge->>FP8Detector: _detect_fp8_params()
    FP8Detector->>FP8Detector: scan parameters for scale_inv attributes
    FP8Detector-->>MegatronModelBridge: fp8_param_map (param_name → is_fp8)
    MegatronModelBridge->>ParamMapping: build tasks with _HFNameSuffixMapping
    ParamMapping->>ParamMapping: split_qkv_weights (handle scale-domain tensors)
    MegatronModelBridge-->>AutoBridge: export_tasks (including FP8 scale_inv variants)
    AutoBridge-->>User: weights exported with FP8 parameters

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 57.89% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Test Results For Major Changes ⚠️ Warning PR introduces major FP8 weight export feature affecting numerics/convergence but lacks concrete test results, regression data, or performance benchmarks despite mentioning verification. Include explicit test results, regression testing documentation, and performance benchmarks demonstrating numerical correctness and absence of convergence regressions.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately and concisely describes the main feature addition: support for direct export of block-wise FP8 weights and their scaling factors, which is the primary objective of the PR.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/megatron/bridge/models/conversion/model_bridge.py (1)

555-617: Update docstring return type to match actual return value.

The method now returns (megatron_model, self.unquantized_state_dict) - a tuple - but the docstring at line 509 still declares Returns: List[MegatronModel]. This is a breaking change affecting callers registered as pre_wrap_hooks (lines 869, 875 in auto_bridge.py). When the hook returns a tuple, subsequent code that iterates over model (line 560) will fail.

Update the docstring:

Docstring fix
         Returns:
-            List[MegatronModel]: The input megatron_model as a list with loaded weights.
+            Tuple[List[MegatronModel], Dict[str, torch.Tensor] | None]: A tuple containing:
+                - The input megatron_model as a list with loaded weights
+                - The unquantized state dict if FP8 capture was enabled, otherwise None

Additionally, callers using this method as a pre_wrap_hook must wrap the return or adjust the hook signature to return only the model list, not a tuple. The direct caller at auto_bridge.py:324 already handles unpacking correctly, but hook-based registrations (lines 869, 875) will fail.

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/conversion/auto_bridge.py`:
- Around line 377-384: Add an explicit validation at the start of
build_export_fp8_tasks to raise a clear error if the input model list is empty:
check that the provided model parameter (used later with
unwrap_model(megatron_model)[0]) is a non-empty list and raise a ValueError (or
similar) with a descriptive message if empty; this prevents an IndexError
downstream and complements existing _detect_fp8_params behavior when callers
pass an empty list directly to build_export_fp8_tasks.

In `@src/megatron/bridge/models/conversion/model_bridge.py`:
- Around line 160-187: The file defines a local dataclass
AdapterWeightConversionTask that duplicates and shadows the
AdapterWeightConversionTask imported from peft_bridge; remove the local
AdapterWeightConversionTask definition (the dataclass starting at the second
`@dataclass`(frozen=True) block) so the module uses the imported
AdapterWeightConversionTask, and keep the related AdapterWeight dataclass and
references to WeightConversionTask and MegatronWeightTuple unchanged; if the
local docstring contains important details, move that text into the
AdapterWeightConversionTask docstring in peft_bridge instead.
🧹 Nitpick comments (4)
src/megatron/bridge/models/conversion/auto_bridge.py (2)

324-328: Attribute initialization concern: unquantized_state_dict should be initialized in __init__.

The unquantized_state_dict attribute is assigned here but never initialized in __init__. Per coding guidelines: "Initialize all externally visible members of a class in the constructor." This could lead to AttributeError if accessed before load_hf_weights is called.

♻️ Proposed fix

Add initialization in __init__:

     def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig):
         if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)):
             raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance")
         self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained
         # Data type for exporting weights
         self.export_weight_dtype: Literal["bf16", "fp16", "fp8"] = "bf16"
+        # Unquantized state dict captured during FP8 weight loading (None if not FP8)
+        self.unquantized_state_dict: dict[str, torch.Tensor] | None = None

998-1001: Property creates new bridge instance on each access - potential inefficiency.

The _model_bridge property instantiates a new MegatronModelBridge each time it's accessed via get_model_bridge(), then sets export_weight_dtype on it. If _model_bridge is accessed multiple times in a single operation, this could cause redundant object creation and any cached state on the bridge would be lost between accesses.

Consider caching the bridge instance or documenting that each access returns a fresh instance.

♻️ Optional: Cache the model bridge
+    @cached_property
     def _model_bridge(self) -> "MegatronModelBridge":
         bridge = model_bridge.get_model_bridge(self._causal_lm_architecture)
         bridge.export_weight_dtype = self.export_weight_dtype
         return bridge

Note: If caching, you'd need to handle the case where export_weight_dtype changes after the first access.

src/megatron/bridge/models/conversion/model_bridge.py (2)

1093-1165: _detect_fp8_params: Catching bare Exception is overly broad.

Lines 1122-1123 and 1147-1148 catch bare Exception, which can hide unexpected errors. For FP8 tensor type detection, catching ImportError and TypeError would be more appropriate.

♻️ Proposed fix for exception handling
         try:
             from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
-        except Exception:
+        except ImportError:
             Float8BlockwiseQTensor = None
                 if Float8BlockwiseQTensor is not None:
                     try:
                         is_blockwise_fp8 = isinstance(local_weights, Float8BlockwiseQTensor)
-                    except Exception:
+                    except TypeError:
                         is_blockwise_fp8 = False

1234-1244: Use direct attribute access instead of setattr/getattr with constant values.

Static analysis flagged B009/B010: using setattr and getattr with constant attribute names is not safer than direct attribute access.

♻️ Proposed fix
                 if local_module is not None and not hasattr(local_module, "config"):
-                    setattr(local_module, "config", model_config)
+                    local_module.config = model_config

                 # Main (weight/bias) task
                 export_weight_tensor = local_weights
                 if global_fp8_flags.get(global_name, False):
-                    if local_weights is not None and hasattr(local_weights, "_rowwise_data"):
-                        rd = getattr(local_weights, "_rowwise_data")
+                    if local_weights is not None and hasattr(local_weights, "_rowwise_data"):
+                        rd = local_weights._rowwise_data
                         if rd is not None:
                             export_weight_tensor = rd

Comment on lines +377 to +384
# Build conversion tasks based on export_weight_dtype configuration
if conversion_tasks is None and self.export_weight_dtype == "fp8":
    if not isinstance(model, list):
        model = [model]
    # Use FP8 export tasks for blockwise FP8 weights
    conversion_tasks = self._model_bridge.build_export_fp8_tasks(
        self.hf_pretrained, model
    )
⚠️ Potential issue | 🟡 Minor


Add validation for empty model list in build_export_fp8_tasks.

The FP8 export path correctly handles models without FP8 parameters through _detect_fp8_params and graceful flag checking with .get() defaults. However, build_export_fp8_tasks lacks explicit validation for empty model lists. While the calling code in export_hf_weights (line 379-380) converts non-list models to single-element lists, ensuring they're never empty in practice, adding an explicit check at the start of build_export_fp8_tasks would prevent an unhandled IndexError from unwrap_model(megatron_model)[0] if someone calls the method directly with an empty list.


Comment on lines +160 to +187
@dataclass(frozen=True)
class AdapterWeightConversionTask:
    """Task describing an adapter's LoRA weights for conversion or merging.

    The task reuses :class:`WeightConversionTask` to gather the adapter's
    linear_in/linear_out weights (if they are tensor-parallel) and carries the
    adapter metadata required by the merge step.
    """

    global_base_prefix: str
    adapter_key: Optional[str]  # For canonical LoRA only
    alpha: int
    dim: int
    linear_in_task: WeightConversionTask
    linear_out_task: WeightConversionTask


@dataclass(frozen=True)
class AdapterWeight:
    """Materialized adapter weights ready for merge."""

    global_base_prefix: str
    adapter_key: Optional[str]  # For canonical LoRA only
    alpha: int
    dim: int
    linear_in_weight: MegatronWeightTuple
    linear_out_weight: MegatronWeightTuple

⚠️ Potential issue | 🟠 Major



Remove redundant AdapterWeightConversionTask definition that shadows the imported class.

AdapterWeightConversionTask is imported from peft_bridge at line 53 but redefined at line 161, making the import inaccessible. Both definitions are identical (same fields, same @dataclass(frozen=True) decorator). Remove the local definition at line 161 and use only the imported one.

Proposed fix
-@dataclass(frozen=True)
-class AdapterWeightConversionTask:
-    """Task describing an adapter's LoRA weights for conversion or merging.
-
-    The task reuses :class:`WeightConversionTask` to gather the adapter's
-    linear_in/linear_out weights (if they are tensor-parallel) and carries the
-    adapter metadata required by the merge step.
-    """
-
-    global_base_prefix: str
-    adapter_key: Optional[str]  # For canonical LoRA only
-    alpha: int
-    dim: int
-    linear_in_task: WeightConversionTask
-    linear_out_task: WeightConversionTask
-
-
 @dataclass(frozen=True)
 class AdapterWeight:
     """Materialized adapter weights ready for merge."""

If the more detailed docstring in the local definition is valuable, update the docstring in peft_bridge.py instead.

🧰 Tools
🪛 Ruff (0.14.13)

161-161: Redefinition of unused AdapterWeightConversionTask from line 53: AdapterWeightConversionTask redefined here

(F811)


Signed-off-by: eternally-z <zzywzj@gmail.com>
@ISEEKYAN ISEEKYAN self-assigned this Feb 4, 2026
else:
    hidden_size = qkv.shape[-1]
qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size)
# NOTE: For standard (BF16/FP16) weights, `head_size` is the usual kv_channels/head_dim.
Contributor

will fp8 affect other module layers, other than qkv?

Author

Yes, besides QKV, we found that gdn_linear_weights and kv_weights are also affected.
The root cause is similar: the original code reshapes tensors using standard model dimensions (e.g., qkvz = qkvz.reshape(-1, hidden_size)). This logic fails for FP8 scale parameters because their shapes are determined by the block size (scaled down), causing dimension mismatches.
We verified this on Qwen3-8B and Qwen3-30B-A3B. For now, we have temporarily only updated the split logic for the QKV module to accommodate the FP8 scale dimensions, and we can easily apply this to gdn_linear_weights and kv_weights.
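The dimension mismatch described above is easy to reproduce in isolation (shapes are illustrative, assuming hidden_size=4096 and a 128-wide quantization block):

```python
import numpy as np

hidden, block = 4096, 128
# BF16-domain QKV weight: reshaping against hidden_size works
qkv_weight = np.zeros((3 * hidden, hidden), dtype=np.float32)
qkv_weight.reshape(-1, hidden)

# FP8 scale-domain tensor: dims are scaled down by the block size,
# i.e. shape (96, 32), so the same reshape cannot succeed
qkv_scale = np.zeros((3 * hidden // block, hidden // block), dtype=np.float32)
try:
    qkv_scale.reshape(-1, hidden)  # 96 * 32 = 3072 elements, not divisible by 4096
except ValueError:
    print("scale tensor cannot be reshaped with the model hidden_size")
```

This is why the split logic has to infer the block-scaled dimensions instead of reusing the standard model dimensions.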


return global_fp8_flags

def build_export_fp8_tasks(
Contributor

is this only for export, import will not work

Author

Yes, this PR specifically targets the export of block-wise FP8 weights and scales. Importing FP8 weights/scales is not supported at this stage.


yaoyu-33 commented Mar 4, 2026

/ok to test b2c60df


yaoyu-33 commented Mar 4, 2026

@eternally-z: The lint is failing; please check Contribute.md, you need to run pre-commit.


yaoyu-33 commented Mar 4, 2026

/ok to test fd9ba1d

@eternally-z eternally-z force-pushed the transfer_fp8_weights branch from fd9ba1d to 890403d Compare March 5, 2026 14:11

eternally-z commented Mar 5, 2026

/ok to test fd9ba1d

@yaoyu-33 hi, we updated the access method for unquantized_state_dict.
Previously, we added it to the return values of load_weights_hf_to_megatron. However, this caused tests to fail. This made us realize that modifying the function's return signature breaks the tests and risks breaking other downstream usages.
To fix this, it is now stored and accessed as an attribute of the model_bridge instance instead.

