From e1a75c0a49e319b70d2447b459daec2c6e062197 Mon Sep 17 00:00:00 2001 From: eternally-z Date: Wed, 14 Jan 2026 16:49:38 +0800 Subject: [PATCH 1/7] feat:support transfer blockwise fp8 weights directly Signed-off-by: eternally-z --- .../bridge/models/conversion/auto_bridge.py | 5 +- .../bridge/models/conversion/model_bridge.py | 301 +++++++++++++++++- .../bridge/models/conversion/param_mapping.py | 33 +- 3 files changed, 331 insertions(+), 8 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 014ea1fde1..e73874d307 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -319,10 +319,11 @@ def load_hf_weights( # Preserve trust_remote_code setting from the original bridge instance trust_remote_code = getattr(self.hf_pretrained, "trust_remote_code", False) pre_trained = PreTrainedCausalLM.from_pretrained(hf_path, trust_remote_code=trust_remote_code) - self._model_bridge.load_weights_hf_to_megatron( + _, unquantized_state_dict = self._model_bridge.load_weights_hf_to_megatron( pre_trained, model, allowed_mismatched_params=allowed_mismatched_params ) - + # Get unquantized_state_dict from the same instance that was used for loading + self.unquantized_state_dict = unquantized_state_dict return model def export_hf_weights( diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index fa3ccafaf5..eee8fb1186 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -20,6 +20,7 @@ import re from dataclasses import dataclass from typing import ( + Any, Callable, Dict, Generic, @@ -28,6 +29,7 @@ Mapping, NamedTuple, Optional, + Tuple, Type, TypeVar, Union, @@ -117,6 +119,73 @@ class WeightConversionTask(Generic[MappingT]): param_weight: Optional[torch.Tensor] = None +class 
_HFNameSuffixMapping: + """A lightweight wrapper mapping that appends a suffix to all HF param names on export. + + This is used by FP8 scale tasks: the task's `param_name/global_param_name` may carry + `*_scale_inv`, but the exported output names come from `mapping.megatron_to_hf()`. + Reusing the base mapping would otherwise emit the original HF weight names (no suffix). + """ + + def __init__(self, base_mapping: Any, suffix: str): + self._base_mapping = base_mapping + self._suffix = suffix + + def __getattr__(self, name: str) -> Any: + # Delegate everything else (e.g., broadcast/gather helpers, megatron_param, etc.) + return getattr(self._base_mapping, name) + + def resolve(self, captures: Tuple[str, ...]) -> "_HFNameSuffixMapping": + # Preserve wildcard resolution behavior if the base mapping supports it. + if hasattr(self._base_mapping, "resolve"): + return _HFNameSuffixMapping(self._base_mapping.resolve(captures), self._suffix) + return _HFNameSuffixMapping(self._base_mapping, self._suffix) + + def hf_to_megatron(self, hf_weights: Any, megatron_module: torch.nn.Module) -> torch.Tensor: + # Pass-through (not used by our export path, but keeps the wrapper mapping "complete"). + return self._base_mapping.hf_to_megatron(hf_weights, megatron_module) + + def megatron_to_hf( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[torch.nn.Module], + ) -> Dict[str, torch.Tensor]: + out = self._base_mapping.megatron_to_hf(megatron_weights, megatron_module) + if not out: + return out + # Append suffix to every exported HF parameter name. + return {f"{k}{self._suffix}": v for k, v in out.items()} + + +@dataclass(frozen=True) +class AdapterWeightConversionTask: + """Task describing an adapter's LoRA weights for conversion or merging. + + The task reuses :class:`WeightConversionTask` to gather the adapter's + linear_in/linear_out weights (if they are tensor-parallel) and carries the + adapter metadata required by the merge step. 
+ """ + + global_base_prefix: str + adapter_key: Optional[str] # For canonical LoRA only + alpha: int + dim: int + linear_in_task: WeightConversionTask + linear_out_task: WeightConversionTask + + +@dataclass(frozen=True) +class AdapterWeight: + """Materialized adapter weights ready for merge.""" + + global_base_prefix: str + adapter_key: Optional[str] # For canonical LoRA only + alpha: int + dim: int + linear_in_weight: MegatronWeightTuple + linear_out_weight: MegatronWeightTuple + + def _megatron_local_name_to_global( models: MegatronModule | List[MegatronModule], config: TransformerConfig, @@ -483,6 +552,14 @@ def load_weights_hf_to_megatron( hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state if hasattr(hf_pretrained, "state") else {} description = f"Loading from {hf_pretrained.model_name_or_path}" + + capture_unquantized_state_dict = True + # Optional: capture per-rank, per-parameter unquantized shards for initializing + # for loading original weights when fp8_param=True. + captured_state_dicts: dict[str, torch.Tensor] | None = ( + {} if capture_unquantized_state_dict else None + ) + self.unquantized_state_dict = None for task in self._with_progress_tracking(hf_to_megatron_tasks, description): # None means megatron module not on current rank, skip if this task is not going to happen if task.megatron_module is None: @@ -523,10 +600,19 @@ def load_weights_hf_to_megatron( f" Bridge type: {type(task.mapping).__name__}\n" f" HF mapping: {task.mapping.hf_param}" ) - task.param_weight.data.copy_(converted_weights) - + if capture_unquantized_state_dict: + captured_state_dicts[task.param_name] = converted_weights.detach() + # NOTE: + # For fp8_param (blockwise), `task.param_weight` can be a TransformerEngine + # Float8BlockwiseQTensor) that is a leaf with requires_grad=True. + # In-place updates under grad mode will raise: + # "a leaf Variable that requires grad is being used in an in-place operation." 
+ with torch.no_grad(): + task.param_weight.copy_(converted_weights) self._broadcast_shared_embeddings(megatron_model) - return megatron_model + if capture_unquantized_state_dict: + self.unquantized_state_dict = captured_state_dicts + return megatron_model, self.unquantized_state_dict def stream_weights_hf_to_megatron( self, @@ -661,7 +747,10 @@ def stream_weights_megatron_to_hf( # Use provided conversion tasks or build them if conversion_tasks is None: - conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + # NOTE: temporarily use our fp8 export tasks + # TODO: need a param from caller to control which tasks to use + # conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + conversion_tasks = self.build_export_fp8_tasks(hf_pretrained, megatron_model) # Collect adapter conversion tasks when merge is requested adapter_tasks_by_base: Dict[str, List[AdapterWeightConversionTask]] = {} @@ -1002,6 +1091,210 @@ def build_conversion_tasks( return tasks + def build_export_fp8_tasks( + self, + hf_pretrained: HFPreTrained, + megatron_model: List[MegatronModel], + *, + # NOTE: follow the export naming convention used by downstream (e.g., verl/rollout engine): + # weight name: model.layers.0.self_attn.q_proj.weight + # scale name: model.layers.0.self_attn.q_proj.weight_scale_inv + scale_inv_suffix: str = "_scale_inv", + fp8_scale_inv_attr: str = "_rowwise_scale_inv", + ) -> List[None | WeightConversionTask]: + """Build Megatron→(export) conversion tasks, inserting extra *scale_inv* tasks for blockwise FP8 params. + + Design goals: + - Deterministic: every rank gets the exact same task list length and ordering. + - Non-invasive: does not modify existing mappings; scale_inv tasks reuse the original mapping type. + + Notes: + - This method only *builds* tasks. The caller is responsible for interpreting + the `scale_inv_suffix` naming convention and exporting the actual fp8 payload. 
+ - For now we only attach the rowwise scale_inv tensor (default: `_rowwise_scale_inv`) + as requested. + """ + + # Ensure hf_pretrained has the required state structure (reuse existing ordering assumptions) + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + mapping_registry = self.mapping_registry() + unwrapped_model = unwrap_model(megatron_model)[0] + model_config = unwrapped_model.config + embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_group = parallel_state.get_pipeline_model_parallel_group() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name + ] + + # ------------------------------------------------------------ + # 1) Determine which global params are blockwise FP8 and thus + # should have an additional `*.scale_inv` export task. + # We need a *global* decision, so we gather flags across PP ranks. 
+ # ------------------------------------------------------------ + local_fp8_flags: Dict[str, bool] = {} + global_name_set = set(sorted_global_param_names_all_pp_ranks) + try: + from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor + except Exception: + Float8BlockwiseQTensor = None + + for vp_stage, model in enumerate(megatron_model): + for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global(megatron_model, model_config, local_name, vp_stage) + if global_name not in global_name_set: + continue + + _, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage) + if local_weights is None: + continue + + # Determine if this is a blockwise FP8 tensor and has the requested scale_inv attr. + # We intentionally require the scale_inv attribute to be non-None: + # - Some initialization paths may leave scale tensors unset; we should not emit + # a scale task in that case (would break deterministic export/consumer assumptions). 
+ is_blockwise_fp8 = False + if Float8BlockwiseQTensor is not None: + try: + is_blockwise_fp8 = isinstance(local_weights, Float8BlockwiseQTensor) + except Exception: + is_blockwise_fp8 = False + + # Float8BlockwiseQTensor should have the scale_inv attribute, but check if it's set (not None) + if is_blockwise_fp8 and getattr(local_weights, fp8_scale_inv_attr, None) is not None: + local_fp8_flags[global_name] = True + + # Gather across PP ranks to ensure consistent insertion decisions + fp8_flags_list: list[Dict[str, bool]] = [None] * get_pg_size(pp_group) + torch.distributed.all_gather_object(fp8_flags_list, local_fp8_flags, group=pp_group) + global_fp8_flags: Dict[str, bool] = {} + for d in fp8_flags_list: + if not d: + continue + for k, v in d.items(): + if v: + global_fp8_flags[k] = True + + # ------------------------------------------------------------ + # 2) Expand the global name list with `*.scale_inv` entries. + # This defines the final deterministic task ordering. + # ------------------------------------------------------------ + expanded_global_names: list[str] = [] + for global_name in sorted_global_param_names_all_pp_ranks: + expanded_global_names.append(global_name) + if global_fp8_flags.get(global_name, False): + expanded_global_names.append(f"{global_name}{scale_inv_suffix}") + + global_names_index_dict = {name: idx for idx, name in enumerate(expanded_global_names)} + + tasks: list[None | WeightConversionTask] = [None] * len(expanded_global_names) + + # ------------------------------------------------------------ + # 3) Fill tasks for params that are local to this rank. 
+ # ------------------------------------------------------------ + for vp_stage, model in enumerate(megatron_model): + for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global(megatron_model, model_config, local_name, vp_stage) + if global_name not in global_names_index_dict: + continue + + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) + if not mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + + local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + + # Main (weight/bias) task + export_weight_tensor = local_weights + if global_fp8_flags.get(global_name, False): + if local_weights is not None and hasattr(local_weights, "_rowwise_data"): + rd = getattr(local_weights, "_rowwise_data") + if rd is not None: + export_weight_tensor = rd + tasks[global_names_index_dict[global_name]] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=export_weight_tensor, + mapping=mapping, + ) + + # Optional scale_inv task (only for globally-detected FP8 params) + if global_fp8_flags.get(global_name, False): + scale_global_name = f"{global_name}{scale_inv_suffix}" + scale_local_name = f"{local_name}{scale_inv_suffix}" + scale_tensor = None + if local_weights is not None and hasattr(local_weights, fp8_scale_inv_attr): + scale_tensor = getattr(local_weights, fp8_scale_inv_attr) + + # Note: + # Do NOT reuse the same mapping instance as the base weight task. 
+ # We clone via `resolve(())` which returns a new mapping instance + base_mapping_for_scale = mapping.resolve(()) + tasks[global_names_index_dict[scale_global_name]] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=scale_local_name, + global_param_name=scale_global_name, + megatron_module=local_module, + param_weight=scale_tensor, + mapping=_HFNameSuffixMapping(base_mapping_for_scale, scale_inv_suffix), + ) + + # ------------------------------------------------------------ + # 4) Fill remaining placeholders (for PP ranks that don't own certain params). + # This keeps the tasks list aligned across all ranks. + # ------------------------------------------------------------ + for idx, global_name in enumerate(expanded_global_names): + if tasks[idx] is not None: + continue + + # For scale_inv entries, reuse the base param's mapping type. + if global_name.endswith(scale_inv_suffix): + base_global_name = global_name[: -len(scale_inv_suffix)] + base_mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(base_global_name)) + if base_mapping is not None: + # See note above: clone mapping instance to avoid sharing state across tasks. 
+ mapping = _HFNameSuffixMapping(base_mapping.resolve(()), scale_inv_suffix) + else: + mapping = None + else: + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) + if mapping is None: + logger.warning(f"No mapping found for global_name: {global_name}") + continue + + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, + global_param_name=global_name, + megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + return tasks + @classmethod def register_bridge( cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 1684828dcc..9c0c0666be 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -2292,8 +2292,37 @@ def split_qkv_weights( hidden_size = 1 qkv_reshaped = qkv.view(qkv_total_dim, head_size) else: - hidden_size = qkv.shape[-1] - qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size) + # NOTE: For standard (BF16/FP16) weights, `head_size` is the usual kv_channels/head_dim. + # For blockwise FP8 scale tensors (e.g. Float8BlockwiseQTensor._rowwise_scale_inv), + # the last dim is typically compressed by a block-size factor (e.g. 4096 -> 32). + # In that case we infer a divisor and scale down `head_size` accordingly so that the + # same QKV slicing logic works for both weight tensors and their scale tensors. + orig_hidden_size = provider.hidden_size + current_last_dim = qkv.shape[-1] + + # If last dim matches the model hidden size, it's a normal weight. + # Otherwise, treat it as a "scale-domain" tensor with compressed dims. + if current_last_dim == orig_hidden_size: + hidden_size = current_last_dim + scaled_head_size = head_size + else: + # Infer block divisor (e.g., 4096 / 32 = 128). 
+ if orig_hidden_size % current_last_dim != 0: + raise ValueError( + f"Cannot infer block divisor for qkv tensor: " + f"provider.hidden_size={orig_hidden_size} is not divisible by qkv.shape[-1]={current_last_dim}" + ) + divisor = orig_hidden_size // current_last_dim + if head_size % divisor != 0: + raise ValueError( + f"Cannot scale head_size for qkv tensor: " + f"head_size={head_size} is not divisible by divisor={divisor} " + f"(provider.hidden_size={orig_hidden_size}, qkv.shape[-1]={current_last_dim})" + ) + hidden_size = current_last_dim + scaled_head_size = head_size // divisor + + qkv_reshaped = qkv.view(qkv_total_dim, scaled_head_size, hidden_size) # Extract Q, K, V from interleaved pattern q_slice = torch.cat( From 56f45d04571494e96041c9c465a117aa09594491 Mon Sep 17 00:00:00 2001 From: eternally-z Date: Thu, 15 Jan 2026 17:15:00 +0800 Subject: [PATCH 2/7] minor fix Signed-off-by: eternally-z --- .../bridge/models/conversion/auto_bridge.py | 15 ++++++++++++++- .../bridge/models/conversion/model_bridge.py | 8 +++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index e73874d307..d0abca846b 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -101,6 +101,8 @@ def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance") self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained + # When True, directly export FP8 blockwise weights instead of dequantizing to BF16 + self.export_fp8_weights: bool = False @classmethod def list_supported_models(cls) -> list[str]: @@ -372,6 +374,15 @@ def export_hf_weights( ... cpu=True ... 
)) """ + # Build conversion tasks based on export_fp8_weights configuration + if conversion_tasks is None and self.export_fp8_weights: + if not isinstance(model, list): + model = [model] + # Use FP8 export tasks for blockwise FP8 weights + conversion_tasks = self._model_bridge.build_export_fp8_tasks( + self.hf_pretrained, model + ) + dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) return model_bridge.stream_weights_megatron_to_hf( dispatch_instance, @@ -985,7 +996,9 @@ def mla_transformer_config(self) -> MLATransformerConfig: @property def _model_bridge(self) -> "MegatronModelBridge": - return model_bridge.get_model_bridge(self._causal_lm_architecture) + bridge = model_bridge.get_model_bridge(self._causal_lm_architecture) + bridge.export_fp8_weights = self.export_fp8_weights + return bridge @property def _provider_bridge_input(self) -> PreTrainedCausalLM | _ConfigOnlyPretrainedShim: diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index eee8fb1186..21d9820c66 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -553,7 +553,8 @@ def load_weights_hf_to_megatron( description = f"Loading from {hf_pretrained.model_name_or_path}" - capture_unquantized_state_dict = True + # if export_fp8_weights is true,we need save the unquantized state dict for initializing optimizer main params + capture_unquantized_state_dict = getattr(self, "export_fp8_weights", False) # Optional: capture per-rank, per-parameter unquantized shards for initializing # for loading original weights when fp8_param=True. 
captured_state_dicts: dict[str, torch.Tensor] | None = ( @@ -747,10 +748,7 @@ def stream_weights_megatron_to_hf( # Use provided conversion tasks or build them if conversion_tasks is None: - # NOTE: temporarily use our fp8 export tasks - # TODO: need a param from caller to control which tasks to use - # conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) - conversion_tasks = self.build_export_fp8_tasks(hf_pretrained, megatron_model) + conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) # Collect adapter conversion tasks when merge is requested adapter_tasks_by_base: Dict[str, List[AdapterWeightConversionTask]] = {} From bfe788b508361044fbbc2925ac76673c6523a0e7 Mon Sep 17 00:00:00 2001 From: eternally-z Date: Tue, 20 Jan 2026 16:21:38 +0800 Subject: [PATCH 3/7] minor fix Signed-off-by: eternally-z --- .../bridge/models/conversion/auto_bridge.py | 12 +- .../bridge/models/conversion/model_bridge.py | 149 ++++++++---------- 2 files changed, 72 insertions(+), 89 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index d0abca846b..725d5189f7 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -15,7 +15,7 @@ import dataclasses from functools import cached_property, partial from pathlib import Path -from typing import Any, Generic, Iterable, List, Optional, Type, TypeVar, Union +from typing import Any, Generic, Iterable, List, Literal, Optional, Type, TypeVar, Union import torch.distributed as dist import transformers @@ -101,8 +101,8 @@ def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance") self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained - # When True, 
directly export FP8 blockwise weights instead of dequantizing to BF16 - self.export_fp8_weights: bool = False + # Data type for exporting weights + self.export_weight_dtype: Literal["bf16", "fp16", "fp8"] = "bf16" @classmethod def list_supported_models(cls) -> list[str]: @@ -374,8 +374,8 @@ def export_hf_weights( ... cpu=True ... )) """ - # Build conversion tasks based on export_fp8_weights configuration - if conversion_tasks is None and self.export_fp8_weights: + # Build conversion tasks based on export_weight_dtype configuration + if conversion_tasks is None and self.export_weight_dtype == "fp8": if not isinstance(model, list): model = [model] # Use FP8 export tasks for blockwise FP8 weights @@ -997,7 +997,7 @@ def mla_transformer_config(self) -> MLATransformerConfig: @property def _model_bridge(self) -> "MegatronModelBridge": bridge = model_bridge.get_model_bridge(self._causal_lm_architecture) - bridge.export_fp8_weights = self.export_fp8_weights + bridge.export_weight_dtype = self.export_weight_dtype return bridge @property diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 21d9820c66..c77dda1d62 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -156,36 +156,6 @@ def megatron_to_hf( # Append suffix to every exported HF parameter name. return {f"{k}{self._suffix}": v for k, v in out.items()} - -@dataclass(frozen=True) -class AdapterWeightConversionTask: - """Task describing an adapter's LoRA weights for conversion or merging. - - The task reuses :class:`WeightConversionTask` to gather the adapter's - linear_in/linear_out weights (if they are tensor-parallel) and carries the - adapter metadata required by the merge step. 
- """ - - global_base_prefix: str - adapter_key: Optional[str] # For canonical LoRA only - alpha: int - dim: int - linear_in_task: WeightConversionTask - linear_out_task: WeightConversionTask - - -@dataclass(frozen=True) -class AdapterWeight: - """Materialized adapter weights ready for merge.""" - - global_base_prefix: str - adapter_key: Optional[str] # For canonical LoRA only - alpha: int - dim: int - linear_in_weight: MegatronWeightTuple - linear_out_weight: MegatronWeightTuple - - def _megatron_local_name_to_global( models: MegatronModule | List[MegatronModule], config: TransformerConfig, @@ -553,8 +523,9 @@ def load_weights_hf_to_megatron( description = f"Loading from {hf_pretrained.model_name_or_path}" - # if export_fp8_weights is true,we need save the unquantized state dict for initializing optimizer main params - capture_unquantized_state_dict = getattr(self, "export_fp8_weights", False) + # if export_weight_dtype is FP8, we need save the unquantized state dict for initializing optimizer main params + export_weight_dtype = getattr(self, "export_weight_dtype", "bf16") + capture_unquantized_state_dict = export_weight_dtype == "fp8" # Optional: capture per-rank, per-parameter unquantized shards for initializing # for loading original weights when fp8_param=True. captured_state_dicts: dict[str, torch.Tensor] | None = ( @@ -1089,53 +1060,31 @@ def build_conversion_tasks( return tasks - def build_export_fp8_tasks( + def _detect_fp8_params( self, - hf_pretrained: HFPreTrained, megatron_model: List[MegatronModel], - *, - # NOTE: follow the export naming convention used by downstream (e.g., verl/rollout engine): - # weight name: model.layers.0.self_attn.q_proj.weight - # scale name: model.layers.0.self_attn.q_proj.weight_scale_inv - scale_inv_suffix: str = "_scale_inv", - fp8_scale_inv_attr: str = "_rowwise_scale_inv", - ) -> List[None | WeightConversionTask]: - """Build Megatron→(export) conversion tasks, inserting extra *scale_inv* tasks for blockwise FP8 params. 
+ model_config: TransformerConfig, + sorted_global_param_names_all_pp_ranks: List[str], + pp_group: Any, + fp8_scale_inv_attr: str, + ) -> Dict[str, bool]: + """Detect which global parameters are blockwise FP8 and gather flags across pipeline parallel ranks. - Design goals: - - Deterministic: every rank gets the exact same task list length and ordering. - - Non-invasive: does not modify existing mappings; scale_inv tasks reuse the original mapping type. - - Notes: - - This method only *builds* tasks. The caller is responsible for interpreting - the `scale_inv_suffix` naming convention and exporting the actual fp8 payload. - - For now we only attach the rowwise scale_inv tensor (default: `_rowwise_scale_inv`) - as requested. - """ + This method scans all parameters in the megatron model to determine which ones are + blockwise FP8 tensors with valid scale_inv attributes. It then gathers these flags + across all pipeline parallel ranks to ensure consistent decisions. - # Ensure hf_pretrained has the required state structure (reuse existing ordering assumptions) - if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): - raise ValueError("hf_pretrained.state.source is required for weight ordering") - - mapping_registry = self.mapping_registry() - unwrapped_model = unwrap_model(megatron_model)[0] - model_config = unwrapped_model.config - embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - pp_group = parallel_state.get_pipeline_model_parallel_group() - sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model) - - # Filter out output_layer related parameters if embeddings are tied - if embeddings_are_tied: - sorted_global_param_names_all_pp_ranks = [ - name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name - ] + Args: + megatron_model: List of Megatron model 
instances + model_config: Transformer configuration + sorted_global_param_names_all_pp_ranks: Sorted list of global parameter names + pp_group: Pipeline parallel group for distributed communication + fp8_scale_inv_attr: Attribute name for the FP8 scale_inv tensor - # ------------------------------------------------------------ - # 1) Determine which global params are blockwise FP8 and thus - # should have an additional `*.scale_inv` export task. - # We need a *global* decision, so we gather flags across PP ranks. - # ------------------------------------------------------------ + Returns: + Dictionary mapping global parameter names to boolean flags indicating + whether they are blockwise FP8 parameters with valid scale_inv attributes. + """ local_fp8_flags: Dict[str, bool] = {} global_name_set = set(sorted_global_param_names_all_pp_ranks) try: @@ -1167,7 +1116,7 @@ def build_export_fp8_tasks( is_blockwise_fp8 = isinstance(local_weights, Float8BlockwiseQTensor) except Exception: is_blockwise_fp8 = False - + # Float8BlockwiseQTensor should have the scale_inv attribute, but check if it's set (not None) if is_blockwise_fp8 and getattr(local_weights, fp8_scale_inv_attr, None) is not None: local_fp8_flags[global_name] = True @@ -1183,10 +1132,49 @@ def build_export_fp8_tasks( if v: global_fp8_flags[k] = True - # ------------------------------------------------------------ + return global_fp8_flags + + def build_export_fp8_tasks( + self, + hf_pretrained: HFPreTrained, + megatron_model: List[MegatronModel], + *, + scale_inv_suffix: str = "_scale_inv", + fp8_scale_inv_attr: str = "_rowwise_scale_inv", + ) -> List[None | WeightConversionTask]: + """ + Build Megatron→(export) conversion tasks, inserting extra *scale_inv* tasks for blockwise FP8 params. 
+ """ + + # Ensure hf_pretrained has the required state structure (reuse existing ordering assumptions) + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + mapping_registry = self.mapping_registry() + unwrapped_model = unwrap_model(megatron_model)[0] + model_config = unwrapped_model.config + embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_group = parallel_state.get_pipeline_model_parallel_group() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name + ] + + # 1) Determine which global params are blockwise FP8 and gather flags across PP ranks + global_fp8_flags = self._detect_fp8_params( + megatron_model, + model_config, + sorted_global_param_names_all_pp_ranks, + pp_group, + fp8_scale_inv_attr, + ) + # 2) Expand the global name list with `*.scale_inv` entries. # This defines the final deterministic task ordering. - # ------------------------------------------------------------ expanded_global_names: list[str] = [] for global_name in sorted_global_param_names_all_pp_ranks: expanded_global_names.append(global_name) @@ -1197,9 +1185,7 @@ def build_export_fp8_tasks( tasks: list[None | WeightConversionTask] = [None] * len(expanded_global_names) - # ------------------------------------------------------------ # 3) Fill tasks for params that are local to this rank. 
- # ------------------------------------------------------------ for vp_stage, model in enumerate(megatron_model): for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): if "_extra_state" in local_name or self._is_adapter_param_name(local_name): @@ -1258,10 +1244,7 @@ def build_export_fp8_tasks( mapping=_HFNameSuffixMapping(base_mapping_for_scale, scale_inv_suffix), ) - # ------------------------------------------------------------ # 4) Fill remaining placeholders (for PP ranks that don't own certain params). - # This keeps the tasks list aligned across all ranks. - # ------------------------------------------------------------ for idx, global_name in enumerate(expanded_global_names): if tasks[idx] is not None: continue @@ -1271,7 +1254,7 @@ def build_export_fp8_tasks( base_global_name = global_name[: -len(scale_inv_suffix)] base_mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(base_global_name)) if base_mapping is not None: - # See note above: clone mapping instance to avoid sharing state across tasks. + # clone mapping instance to avoid sharing state across tasks. 
mapping = _HFNameSuffixMapping(base_mapping.resolve(()), scale_inv_suffix)
             else:
                 mapping = None

From d6257ed8c1ccfe4b64d9f64bb1442b1b4daeec1f Mon Sep 17 00:00:00 2001
From: eternally-z
Date: Mon, 2 Feb 2026 15:07:21 +0800
Subject: [PATCH 4/7] fix: remove scale padding

Signed-off-by: eternally-z
---
 .../bridge/models/conversion/model_bridge.py | 23 ++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py
index c77dda1d62..98acb78606 100644
--- a/src/megatron/bridge/models/conversion/model_bridge.py
+++ b/src/megatron/bridge/models/conversion/model_bridge.py
@@ -18,6 +18,7 @@
 import itertools
 import logging
 import re
+import math
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -1229,7 +1230,7 @@ def build_export_fp8_tasks(
             scale_tensor = None
             if local_weights is not None and hasattr(local_weights, fp8_scale_inv_attr):
                 scale_tensor = getattr(local_weights, fp8_scale_inv_attr)
-
+            scale_tensor = self._trim_blockwise_fp8_scale_inv_padding(local_weights, scale_tensor)
             # Note:
             # Do NOT reuse the same mapping instance as the base weight task.
             # We clone via `resolve(())` which returns a new mapping instance
@@ -1276,6 +1277,26 @@ def build_export_fp8_tasks(
         return tasks

+    def _trim_blockwise_fp8_scale_inv_padding(
+        self,
+        local_weights: Optional[torch.Tensor],
+        scale_tensor: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
+        # Trim the padding in the scales for blockwise FP8 parameters.
+        # The GEMM for 2D blocks requires padding in the scales.
+ quantizer = getattr(local_weights, "_quantizer", None) + block_len = getattr(quantizer, "block_len", None) + is_2d_scaled = getattr(local_weights, "_is_2D_scaled", None) + if block_len is None or not is_2d_scaled: + logger.warning(f"WARNING: block_len or not is_2d_scaled") + return scale_tensor + + q_k = local_weights.shape[-1] + expected_k_tiles = math.ceil(q_k / block_len) + if scale_tensor.shape[1] == expected_k_tiles: + return scale_tensor + return scale_tensor[:, :expected_k_tiles].contiguous() + @classmethod def register_bridge( cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] From 1b320e79fbeaaf9952d7a61a6018aa3a48e09e82 Mon Sep 17 00:00:00 2001 From: eternally-z Date: Tue, 3 Mar 2026 14:32:02 +0800 Subject: [PATCH 5/7] fix type conversion & vpp format Signed-off-by: eternally-z --- .../bridge/models/conversion/model_bridge.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 98acb78606..9e008a3abb 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -529,9 +529,9 @@ def load_weights_hf_to_megatron( capture_unquantized_state_dict = export_weight_dtype == "fp8" # Optional: capture per-rank, per-parameter unquantized shards for initializing # for loading original weights when fp8_param=True. 
- captured_state_dicts: dict[str, torch.Tensor] | None = ( - {} if capture_unquantized_state_dict else None - ) + captured_state_dicts: dict[str, dict[str, torch.Tensor]] | None = None + if capture_unquantized_state_dict: + captured_state_dicts = {f"model{i}": {} for i in range(len(megatron_model))} self.unquantized_state_dict = None for task in self._with_progress_tracking(hf_to_megatron_tasks, description): # None means megatron module not on current rank, skip if this task is not going to happen @@ -574,7 +574,9 @@ def load_weights_hf_to_megatron( f" HF mapping: {task.mapping.hf_param}" ) if capture_unquantized_state_dict: - captured_state_dicts[task.param_name] = converted_weights.detach() + vp_stage = task.vp_stage if task.vp_stage is not None else 0 + chunk_key = f"model{vp_stage}" + captured_state_dicts[chunk_key][task.param_name] = converted_weights.detach() # NOTE: # For fp8_param (blockwise), `task.param_weight` can be a TransformerEngine # Float8BlockwiseQTensor) that is a leaf with requires_grad=True. @@ -584,7 +586,11 @@ def load_weights_hf_to_megatron( task.param_weight.copy_(converted_weights) self._broadcast_shared_embeddings(megatron_model) if capture_unquantized_state_dict: - self.unquantized_state_dict = captured_state_dicts + # Keep Megatron's single-chunk convention: "model" instead of "model0". + if len(megatron_model) == 1: + self.unquantized_state_dict = {"model": captured_state_dicts["model0"]} + else: + self.unquantized_state_dict = captured_state_dicts return megatron_model, self.unquantized_state_dict def stream_weights_hf_to_megatron( @@ -1174,6 +1180,7 @@ def build_export_fp8_tasks( fp8_scale_inv_attr, ) + from transformer_engine.pytorch.constants import TE_DType_To_Torch # 2) Expand the global name list with `*.scale_inv` entries. # This defines the final deterministic task ordering. 
expanded_global_names: list[str] = [] @@ -1213,6 +1220,19 @@ def build_export_fp8_tasks( rd = getattr(local_weights, "_rowwise_data") if rd is not None: export_weight_tensor = rd + # TE blockwise _rowwise_data is stored as uint8; view to correct FP8 type. + # Read _fp8_dtype from tensor when available (robust for future formats). + # Megatron fp8_param weights are always e4m3 (forward pass) in both + # fp8_format=e4m3 and fp8_format=hybrid; e5m2 is only for backward gradients. + fp8_dtype = getattr(local_weights, "_fp8_dtype", None) + torch_fp8_dtype = ( + TE_DType_To_Torch.get(fp8_dtype, torch.float8_e4m3fn) + if fp8_dtype is not None + else torch.float8_e4m3fn + ) + export_weight_tensor = export_weight_tensor.contiguous().view( + torch_fp8_dtype + ) tasks[global_names_index_dict[global_name]] = WeightConversionTask( pp_rank=pp_rank, vp_stage=vp_stage, From 836e63afc0a7c59f536f2973274d0377308b579a Mon Sep 17 00:00:00 2001 From: eternally-z Date: Wed, 4 Mar 2026 11:51:07 +0800 Subject: [PATCH 6/7] fix lint issues Signed-off-by: eternally-z --- src/megatron/bridge/models/conversion/auto_bridge.py | 4 +--- .../bridge/models/conversion/model_bridge.py | 12 ++++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index a29f02e43e..0f1e1a8fa8 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -379,9 +379,7 @@ def export_hf_weights( if not isinstance(model, list): model = [model] # Use FP8 export tasks for blockwise FP8 weights - conversion_tasks = self._model_bridge.build_export_fp8_tasks( - self.hf_pretrained, model - ) + conversion_tasks = self._model_bridge.build_export_fp8_tasks(self.hf_pretrained, model) dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) return model_bridge.stream_weights_megatron_to_hf( diff --git 
a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index 9e008a3abb..71bf13701b 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -17,8 +17,8 @@ import fnmatch import itertools import logging -import re import math +import re from dataclasses import dataclass from typing import ( Any, @@ -157,6 +157,7 @@ def megatron_to_hf( # Append suffix to every exported HF parameter name. return {f"{k}{self._suffix}": v for k, v in out.items()} + def _megatron_local_name_to_global( models: MegatronModule | List[MegatronModule], config: TransformerConfig, @@ -1150,7 +1151,7 @@ def build_export_fp8_tasks( fp8_scale_inv_attr: str = "_rowwise_scale_inv", ) -> List[None | WeightConversionTask]: """ - Build Megatron→(export) conversion tasks, inserting extra *scale_inv* tasks for blockwise FP8 params. + Build Megatron→(export) conversion tasks, inserting extra *scale_inv* tasks for blockwise FP8 params. """ # Ensure hf_pretrained has the required state structure (reuse existing ordering assumptions) @@ -1181,6 +1182,7 @@ def build_export_fp8_tasks( ) from transformer_engine.pytorch.constants import TE_DType_To_Torch + # 2) Expand the global name list with `*.scale_inv` entries. # This defines the final deterministic task ordering. 
expanded_global_names: list[str] = [] @@ -1230,9 +1232,7 @@ def build_export_fp8_tasks( if fp8_dtype is not None else torch.float8_e4m3fn ) - export_weight_tensor = export_weight_tensor.contiguous().view( - torch_fp8_dtype - ) + export_weight_tensor = export_weight_tensor.contiguous().view(torch_fp8_dtype) tasks[global_names_index_dict[global_name]] = WeightConversionTask( pp_rank=pp_rank, vp_stage=vp_stage, @@ -1308,7 +1308,7 @@ def _trim_blockwise_fp8_scale_inv_padding( block_len = getattr(quantizer, "block_len", None) is_2d_scaled = getattr(local_weights, "_is_2D_scaled", None) if block_len is None or not is_2d_scaled: - logger.warning(f"WARNING: block_len or not is_2d_scaled") + logger.warning("WARNING: block_len or not is_2d_scaled") return scale_tensor q_k = local_weights.shape[-1] From 2a4ec9b6a628156f9eb2102abcf1b395824c3cde Mon Sep 17 00:00:00 2001 From: eternally-z Date: Thu, 5 Mar 2026 22:02:30 +0800 Subject: [PATCH 7/7] refactor: expose unquantized_state_dict as an instance attribute Signed-off-by: eternally-z --- src/megatron/bridge/models/conversion/auto_bridge.py | 9 ++++----- src/megatron/bridge/models/conversion/model_bridge.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 0f1e1a8fa8..bcbaeac6d1 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -321,11 +321,10 @@ def load_hf_weights( # Preserve trust_remote_code setting from the original bridge instance trust_remote_code = getattr(self.hf_pretrained, "trust_remote_code", False) pre_trained = PreTrainedCausalLM.from_pretrained(hf_path, trust_remote_code=trust_remote_code) - _, unquantized_state_dict = self._model_bridge.load_weights_hf_to_megatron( - pre_trained, model, allowed_mismatched_params=allowed_mismatched_params - ) - # Get unquantized_state_dict from the same instance that was 
used for loading
-        self.unquantized_state_dict = unquantized_state_dict
+        bridge = self._model_bridge
+        bridge.load_weights_hf_to_megatron(pre_trained, model, allowed_mismatched_params=allowed_mismatched_params)
+        # Expose the unquantized_state_dict captured by the bridge during loading (consumed downstream, e.g. for optimizer reload)
+        self.unquantized_state_dict = getattr(bridge, "unquantized_state_dict", None)
         return model

     def export_hf_weights(
diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py
index 71bf13701b..a292ca7aee 100644
--- a/src/megatron/bridge/models/conversion/model_bridge.py
+++ b/src/megatron/bridge/models/conversion/model_bridge.py
@@ -592,7 +592,7 @@ def load_weights_hf_to_megatron(
             self.unquantized_state_dict = {"model": captured_state_dicts["model0"]}
         else:
             self.unquantized_state_dict = captured_state_dicts
-        return megatron_model, self.unquantized_state_dict
+        return megatron_model

     def stream_weights_hf_to_megatron(
         self,