diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 3e0ac4e02a..d095e1dd61 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -15,7 +15,7 @@ import dataclasses from functools import cached_property, partial from pathlib import Path -from typing import Any, Generic, Iterable, List, Optional, Type, TypeVar, Union +from typing import Any, Generic, Iterable, List, Literal, Optional, Type, TypeVar, Union import torch.distributed as dist import transformers @@ -101,6 +101,8 @@ def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance") self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained + # Data type for exporting weights + self.export_weight_dtype: Literal["bf16", "fp16", "fp8"] = "bf16" @classmethod def list_supported_models(cls) -> list[str]: @@ -319,10 +321,10 @@ def load_hf_weights( # Preserve trust_remote_code setting from the original bridge instance trust_remote_code = getattr(self.hf_pretrained, "trust_remote_code", False) pre_trained = PreTrainedCausalLM.from_pretrained(hf_path, trust_remote_code=trust_remote_code) - self._model_bridge.load_weights_hf_to_megatron( - pre_trained, model, allowed_mismatched_params=allowed_mismatched_params - ) - + bridge = self._model_bridge + bridge.load_weights_hf_to_megatron(pre_trained, model, allowed_mismatched_params=allowed_mismatched_params) + # Get unquantized_state_dict from the bridge instance that was used for optimizer reload + self.unquantized_state_dict = getattr(bridge, "unquantized_state_dict", None) return model def export_hf_weights( @@ -371,6 +373,13 @@ def export_hf_weights( ... cpu=True ... 
)) """ + # Build conversion tasks based on export_weight_dtype configuration + if conversion_tasks is None and self.export_weight_dtype == "fp8": + if not isinstance(model, list): + model = [model] + # Use FP8 export tasks for blockwise FP8 weights + conversion_tasks = self._model_bridge.build_export_fp8_tasks(self.hf_pretrained, model) + dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) return model_bridge.stream_weights_megatron_to_hf( dispatch_instance, @@ -1051,7 +1060,9 @@ def _model_bridge(self) -> "MegatronModelBridge": else: hf_config = self.hf_pretrained - return model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=hf_config) + bridge = model_bridge.get_model_bridge(self._causal_lm_architecture, hf_config=hf_config) + bridge.export_weight_dtype = self.export_weight_dtype + return bridge @property def _provider_bridge_input(self) -> PreTrainedCausalLM | _ConfigOnlyPretrainedShim: diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index c5a6155b12..e279189f52 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -17,9 +17,11 @@ import fnmatch import itertools import logging +import math import re from dataclasses import dataclass from typing import ( + Any, Callable, Dict, Generic, @@ -28,6 +30,7 @@ Mapping, NamedTuple, Optional, + Tuple, Type, TypeVar, Union, @@ -122,6 +125,44 @@ class WeightConversionTask(Generic[MappingT]): param_weight: Optional[torch.Tensor] = None +class _HFNameSuffixMapping: + """A lightweight wrapper mapping that appends a suffix to all HF param names on export. + + This is used by FP8 scale tasks: the task's `param_name/global_param_name` may carry + `*_scale_inv`, but the exported output names come from `mapping.megatron_to_hf()`. + Reusing the base mapping would otherwise emit the original HF weight names (no suffix). + """ + + def __init__(self, base_mapping: Any, suffix: str): + self._base_mapping = base_mapping + self._suffix = suffix + + def __getattr__(self, name: str) -> Any: + # Delegate everything else (e.g., broadcast/gather helpers, megatron_param, etc.) + return getattr(self._base_mapping, name) + + def resolve(self, captures: Tuple[str, ...]) -> "_HFNameSuffixMapping": + # Preserve wildcard resolution behavior if the base mapping supports it. + if hasattr(self._base_mapping, "resolve"): + return _HFNameSuffixMapping(self._base_mapping.resolve(captures), self._suffix) + return _HFNameSuffixMapping(self._base_mapping, self._suffix) + + def hf_to_megatron(self, hf_weights: Any, megatron_module: torch.nn.Module) -> torch.Tensor: + # Pass-through (not used by our export path, but keeps the wrapper mapping "complete"). + return self._base_mapping.hf_to_megatron(hf_weights, megatron_module) + + def megatron_to_hf( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[torch.nn.Module], + ) -> Dict[str, torch.Tensor]: + out = self._base_mapping.megatron_to_hf(megatron_weights, megatron_module) + if not out: + return out + # Append suffix to every exported HF parameter name. 
+ return {f"{k}{self._suffix}": v for k, v in out.items()} + + def _megatron_local_name_to_global( models: MegatronModule | List[MegatronModule], config: TransformerConfig, @@ -777,6 +818,16 @@ def load_weights_hf_to_megatron( hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state if hasattr(hf_pretrained, "state") else {} description = f"Loading from {hf_pretrained.model_name_or_path}" + + # if export_weight_dtype is FP8, we need save the unquantized state dict for initializing optimizer main params + export_weight_dtype = getattr(self, "export_weight_dtype", "bf16") + capture_unquantized_state_dict = export_weight_dtype == "fp8" + # Optional: capture per-rank, per-parameter unquantized shards for initializing + # for loading original weights when fp8_param=True. + captured_state_dicts: dict[str, dict[str, torch.Tensor]] | None = None + if capture_unquantized_state_dict: + captured_state_dicts = {f"model{i}": {} for i in range(len(megatron_model))} + self.unquantized_state_dict = None for task in self._with_progress_tracking(hf_to_megatron_tasks, description): # None means megatron module not on current rank, skip if this task is not going to happen if task.megatron_module is None: @@ -817,9 +868,24 @@ def load_weights_hf_to_megatron( f" Bridge type: {type(task.mapping).__name__}\n" f" HF mapping: {task.mapping.hf_param}" ) - task.param_weight.data.copy_(converted_weights) - + if capture_unquantized_state_dict: + vp_stage = task.vp_stage if task.vp_stage is not None else 0 + chunk_key = f"model{vp_stage}" + captured_state_dicts[chunk_key][task.param_name] = converted_weights.detach() + # NOTE: + # For fp8_param (blockwise), `task.param_weight` can be a TransformerEngine + # Float8BlockwiseQTensor) that is a leaf with requires_grad=True. + # In-place updates under grad mode will raise: + # "a leaf Variable that requires grad is being used in an in-place operation." + with torch.no_grad(): + task.param_weight.copy_(converted_weights) self._broadcast_shared_embeddings(megatron_model) + if capture_unquantized_state_dict: + # Keep Megatron's single-chunk convention: "model" instead of "model0". + if len(megatron_model) == 1: + self.unquantized_state_dict = {"model": captured_state_dicts["model0"]} + else: + self.unquantized_state_dict = captured_state_dicts return megatron_model def stream_weights_hf_to_megatron( @@ -1296,6 +1362,255 @@ def build_conversion_tasks( return tasks + def _detect_fp8_params( + self, + megatron_model: List[MegatronModel], + model_config: TransformerConfig, + sorted_global_param_names_all_pp_ranks: List[str], + pp_group: Any, + fp8_scale_inv_attr: str, + ) -> Dict[str, bool]: + """Detect which global parameters are blockwise FP8 and gather flags across pipeline parallel ranks. + + This method scans all parameters in the megatron model to determine which ones are + blockwise FP8 tensors with valid scale_inv attributes. It then gathers these flags + across all pipeline parallel ranks to ensure consistent decisions. + + Args: + megatron_model: List of Megatron model instances + model_config: Transformer configuration + sorted_global_param_names_all_pp_ranks: Sorted list of global parameter names + pp_group: Pipeline parallel group for distributed communication + fp8_scale_inv_attr: Attribute name for the FP8 scale_inv tensor + + Returns: + Dictionary mapping global parameter names to boolean flags indicating + whether they are blockwise FP8 parameters with valid scale_inv attributes. 
+ """ + local_fp8_flags: Dict[str, bool] = {} + global_name_set = set(sorted_global_param_names_all_pp_ranks) + try: + from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor + except Exception: + Float8BlockwiseQTensor = None + + for vp_stage, model in enumerate(megatron_model): + for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global(megatron_model, model_config, local_name, vp_stage) + if global_name not in global_name_set: + continue + + _, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage) + if local_weights is None: + continue + + # Determine if this is a blockwise FP8 tensor and has the requested scale_inv attr. + # We intentionally require the scale_inv attribute to be non-None: + # - Some initialization paths may leave scale tensors unset; we should not emit + # a scale task in that case (would break deterministic export/consumer assumptions). + is_blockwise_fp8 = False + if Float8BlockwiseQTensor is not None: + try: + is_blockwise_fp8 = isinstance(local_weights, Float8BlockwiseQTensor) + except Exception: + is_blockwise_fp8 = False + + # Float8BlockwiseQTensor should have the scale_inv attribute, but check if it's set (not None) + if is_blockwise_fp8 and getattr(local_weights, fp8_scale_inv_attr, None) is not None: + local_fp8_flags[global_name] = True + + # Gather across PP ranks to ensure consistent insertion decisions + fp8_flags_list: list[Dict[str, bool]] = [None] * get_pg_size(pp_group) + torch.distributed.all_gather_object(fp8_flags_list, local_fp8_flags, group=pp_group) + global_fp8_flags: Dict[str, bool] = {} + for d in fp8_flags_list: + if not d: + continue + for k, v in d.items(): + if v: + global_fp8_flags[k] = True + + return global_fp8_flags + + def build_export_fp8_tasks( + self, + hf_pretrained: HFPreTrained, + megatron_model: List[MegatronModel], + *, + scale_inv_suffix: str = "_scale_inv", + fp8_scale_inv_attr: str = "_rowwise_scale_inv", + ) -> List[None | WeightConversionTask]: + """ + Build Megatron→(export) conversion tasks, inserting extra *scale_inv* tasks for blockwise FP8 params. 
+ """ + + # Ensure hf_pretrained has the required state structure (reuse existing ordering assumptions) + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + mapping_registry = self.mapping_registry() + unwrapped_model = unwrap_model(megatron_model)[0] + model_config = unwrapped_model.config + embeddings_are_tied = self._share_embeddings_and_output_weights(model_config, unwrapped_model) + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_group = parallel_state.get_pipeline_model_parallel_group() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks(megatron_model) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name for name in sorted_global_param_names_all_pp_ranks if "output_layer" not in name + ] + + # 1) Determine which global params are blockwise FP8 and gather flags across PP ranks + global_fp8_flags = self._detect_fp8_params( + megatron_model, + model_config, + sorted_global_param_names_all_pp_ranks, + pp_group, + fp8_scale_inv_attr, + ) + + from transformer_engine.pytorch.constants import TE_DType_To_Torch + + # 2) Expand the global name list with `*.scale_inv` entries. + # This defines the final deterministic task ordering. + expanded_global_names: list[str] = [] + for global_name in sorted_global_param_names_all_pp_ranks: + expanded_global_names.append(global_name) + if global_fp8_flags.get(global_name, False): + expanded_global_names.append(f"{global_name}{scale_inv_suffix}") + + global_names_index_dict = {name: idx for idx, name in enumerate(expanded_global_names)} + + tasks: list[None | WeightConversionTask] = [None] * len(expanded_global_names) + + # 3) Fill tasks for params that are local to this rank. + for vp_stage, model in enumerate(megatron_model): + for local_name, _ in itertools.chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_name or self._is_adapter_param_name(local_name): + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global(megatron_model, model_config, local_name, vp_stage) + if global_name not in global_names_index_dict: + continue + + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) + if not mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + + local_module, local_weights = get_module_and_param_from_name(megatron_model, local_name, vp_stage) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + + # Main (weight/bias) task + export_weight_tensor = local_weights + if global_fp8_flags.get(global_name, False): + if local_weights is not None and hasattr(local_weights, "_rowwise_data"): + rd = getattr(local_weights, "_rowwise_data") + if rd is not None: + export_weight_tensor = rd + # TE blockwise _rowwise_data is stored as uint8; view to correct FP8 type. + # Read _fp8_dtype from tensor when available (robust for future formats). + # Megatron fp8_param weights are always e4m3 (forward pass) in both + # fp8_format=e4m3 and fp8_format=hybrid; e5m2 is only for backward gradients. 
+ fp8_dtype = getattr(local_weights, "_fp8_dtype", None) + torch_fp8_dtype = ( + TE_DType_To_Torch.get(fp8_dtype, torch.float8_e4m3fn) + if fp8_dtype is not None + else torch.float8_e4m3fn + ) + export_weight_tensor = export_weight_tensor.contiguous().view(torch_fp8_dtype) + tasks[global_names_index_dict[global_name]] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=export_weight_tensor, + mapping=mapping, + ) + + # Optional scale_inv task (only for globally-detected FP8 params) + if global_fp8_flags.get(global_name, False): + scale_global_name = f"{global_name}{scale_inv_suffix}" + scale_local_name = f"{local_name}{scale_inv_suffix}" + scale_tensor = None + if local_weights is not None and hasattr(local_weights, fp8_scale_inv_attr): + scale_tensor = getattr(local_weights, fp8_scale_inv_attr) + scale_tensor = self._trim_blockwise_fp8_scale_inv_padding(local_weights, scale_tensor) + # Note: + # Do NOT reuse the same mapping instance as the base weight task. + # We clone via `resolve(())` which returns a new mapping instance + base_mapping_for_scale = mapping.resolve(()) + tasks[global_names_index_dict[scale_global_name]] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=scale_local_name, + global_param_name=scale_global_name, + megatron_module=local_module, + param_weight=scale_tensor, + mapping=_HFNameSuffixMapping(base_mapping_for_scale, scale_inv_suffix), + ) + + # 4) Fill remaining placeholders (for PP ranks that don't own certain params). + for idx, global_name in enumerate(expanded_global_names): + if tasks[idx] is not None: + continue + + # For scale_inv entries, reuse the base param's mapping type. + if global_name.endswith(scale_inv_suffix): + base_global_name = global_name[: -len(scale_inv_suffix)] + base_mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(base_global_name)) + if base_mapping is not None: + # clone mapping instance to avoid sharing state across tasks. + mapping = _HFNameSuffixMapping(base_mapping.resolve(()), scale_inv_suffix) + else: + mapping = None + else: + mapping = mapping_registry.megatron_to_hf_lookup(self._get_lora_unwrapped_name(global_name)) + if mapping is None: + logger.warning(f"No mapping found for global_name: {global_name}") + continue + + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, + global_param_name=global_name, + megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + return tasks + + def _trim_blockwise_fp8_scale_inv_padding( + self, + local_weights: Optional[torch.Tensor], + scale_tensor: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # This function is used to trim the padding in the scales for blockwise FP8 parameters. + # The GEMM for 2D blocks required padding in the scales. 
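+        # Example (hypothetical shapes): with block_len=128, a weight with K=5120 needs
+        # ceil(5120 / 128) = 40 scale columns; if the stored scale_inv was padded out to,
+        # say, [32, 44] for the GEMM, only the first 40 columns are kept.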
+ quantizer = getattr(local_weights, "_quantizer", None) + block_len = getattr(quantizer, "block_len", None) + is_2d_scaled = getattr(local_weights, "_is_2D_scaled", None) + if block_len is None or not is_2d_scaled: + logger.warning("WARNING: block_len or not is_2d_scaled") + return scale_tensor + + q_k = local_weights.shape[-1] + expected_k_tiles = math.ceil(q_k / block_len) + if scale_tensor.shape[1] == expected_k_tiles: + return scale_tensor + return scale_tensor[:, :expected_k_tiles].contiguous() + @classmethod def register_bridge( cls, diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index d0c0d31853..9375a98a3a 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -2395,8 +2395,37 @@ def split_qkv_weights( hidden_size = 1 qkv_reshaped = qkv.view(qkv_total_dim, head_size) else: - hidden_size = qkv.shape[-1] - qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size) + # NOTE: For standard (BF16/FP16) weights, `head_size` is the usual kv_channels/head_dim. + # For blockwise FP8 scale tensors (e.g. Float8BlockwiseQTensor._rowwise_scale_inv), + # the last dim is typically compressed by a block-size factor (e.g. 4096 -> 32). + # In that case we infer a divisor and scale down `head_size` accordingly so that the + # same QKV slicing logic works for both weight tensors and their scale tensors. + orig_hidden_size = provider.hidden_size + current_last_dim = qkv.shape[-1] + + # If last dim matches the model hidden size, it's a normal weight. + # Otherwise, treat it as a "scale-domain" tensor with compressed dims. + if current_last_dim == orig_hidden_size: + hidden_size = current_last_dim + scaled_head_size = head_size + else: + # Infer block divisor (e.g., 4096 / 32 = 128). + if orig_hidden_size % current_last_dim != 0: + raise ValueError( + f"Cannot infer block divisor for qkv tensor: " + f"provider.hidden_size={orig_hidden_size} is not divisible by qkv.shape[-1]={current_last_dim}" + ) + divisor = orig_hidden_size // current_last_dim + if head_size % divisor != 0: + raise ValueError( + f"Cannot scale head_size for qkv tensor: " + f"head_size={head_size} is not divisible by divisor={divisor} " + f"(provider.hidden_size={orig_hidden_size}, qkv.shape[-1]={current_last_dim})" + ) + hidden_size = current_last_dim + scaled_head_size = head_size // divisor + + qkv_reshaped = qkv.view(qkv_total_dim, scaled_head_size, hidden_size) # Extract Q, K, V from interleaved pattern q_slice = torch.cat(
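# End-to-end sketch of the new export path (illustrative only: apart from
# export_weight_dtype, export_hf_weights(model, cpu=True), and load_hf_weights shown in
# this diff, argument lists and setup are assumptions and are elided with "..."):
#
#     bridge = AutoBridge(hf_config)                 # any existing AutoBridge instance
#     bridge.export_weight_dtype = "fp8"             # set before loading so the unquantized
#                                                    # state dict is captured for the optimizer
#     model = bridge.load_hf_weights(...)            # populates bridge.unquantized_state_dict
#     weights = list(bridge.export_hf_weights(model, cpu=True))
#     # For blockwise-FP8 params, the stream also emits "<hf_name>_scale_inv" tensors.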