55import torch
66import torch .distributed as dist
77import torch .nn .functional as F
8- import transformers
8+ from contextlib import contextmanager
99from megatron .core import mpu
1010from packaging import version
1111from peft import PeftModel
1212from peft .utils import ModulesToSaveWrapper
1313from tqdm import tqdm
14- from typing import List , Optional , Union
14+ from transformers import PreTrainedModel
15+ from transformers .utils import ContextManagers
16+ from typing import Callable , List , Optional , Union
1517
1618from mcore_bridge .tuners import LoraParallelLinear
1719from mcore_bridge .utils import (MxFp4Dequantizer , SafetensorLazyLoader , StreamingSafetensorSaver , deep_getattr ,
@@ -66,7 +68,6 @@ def __init__(self, config):
6668 self .pp_group = mpu .get_pipeline_model_parallel_group ()
6769 self .etp_group = mpu .get_expert_tensor_parallel_group ()
6870 self .ep_group = mpu .get_expert_model_parallel_group ()
69- self .is_transformers_5 = version .parse (transformers .__version__ ) >= version .parse ('5.0.0.dev' )
7071 self .tp_rank = mpu .get_tensor_model_parallel_rank ()
7172 self .pp_rank = mpu .get_pipeline_model_parallel_rank ()
7273 self .etp_rank = mpu .get_expert_tensor_parallel_rank ()
@@ -1615,7 +1616,14 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
16151616 hf_state_dict .update (origin_hf_state_dict )
16161617 return hf_state_dict
16171618
1618- def load_weights (self , mg_models , hf_model_dir : str , peft_format : bool = False , adapter_name : str = 'default' ):
1619+ def load_weights (
1620+ self ,
1621+ mg_models ,
1622+ hf_model_dir : str ,
1623+ peft_format : bool = False ,
1624+ adapter_name : str = 'default' ,
1625+ converter : Optional [Callable ] = None ,
1626+ ):
16191627 """Load weights from safetensors (HuggingFace) format into Megatron model.
16201628
16211629 Args:
@@ -1624,24 +1632,38 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False,
16241632 peft_format: Whether the weights are in PEFT (LoRA, etc.) format. Defaults to False.
16251633 If True, loads LoRA delta weights. If False, loads the full model weights.
16261634 adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1635+ converter: Used to perform key-value conversion on the newly loaded state_dict.
16271636 """
16281637 self ._peft_format = peft_format
16291638 self ._adapter_name = adapter_name
16301639 mg_models = unwrap_model (mg_models )
16311640 self ._disable_tqdm = False
16321641 with torch .no_grad (), SafetensorLazyLoader (hf_model_dir , peft_format = peft_format ) as loader :
16331642 state_dict = loader .get_state_dict ()
1643+ if converter :
1644+ new_state_dict = {}
1645+ for k , v in state_dict .items ():
1646+ kv = converter (k , v )
1647+ if kv is None :
1648+ continue
1649+ k , v = kv
1650+ new_state_dict [k ] = v
1651+ state_dict = new_state_dict
16341652 hf_prefix = 'base_model.model.' if peft_format else ''
16351653 for mg_model in mg_models :
16361654 list (self ._convert ([mg_model ], state_dict , hf_prefix , True , 'Loading: ' ))
16371655
1638- def export_weights (self ,
1639- mg_models ,
1640- target_device = None ,
1641- only_master_rank : bool = False ,
1642- peft_format : bool = False ,
1643- tqdm_desc : str = 'Exporting: ' ,
1644- disable_tqdm : bool = True ):
1656+ def export_weights (
1657+ self ,
1658+ mg_models ,
1659+ target_device = None ,
1660+ only_master_rank : bool = False ,
1661+ peft_format : bool = False ,
1662+ adapter_name : str = 'default' ,
1663+ converter : Optional [Callable ] = None ,
1664+ tqdm_desc : str = 'Exporting: ' ,
1665+ disable_tqdm : bool = True ,
1666+ ):
16451667 """Export Megatron model weights to safetensors (HuggingFace) format as a generator.
16461668
16471669 This method yields weight tensors one by one for streaming save operations or RL weight synchronization,
@@ -1654,6 +1676,8 @@ def export_weights(self,
16541676 peft_format: Whether to export in PEFT (LoRA, etc.) format. Defaults to False.
16551677 - If True, exports only LoRA delta weights. If False, exports the complete model weights
16561678 (e.g., after merge-lora or full-parameter fine-tuning).
1679+ adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1680+ converter: Used to perform key-value conversion on the newly exported state_dict.
16571681 tqdm_desc: Description text for the progress bar. Defaults to 'Exporting: '.
16581682 disable_tqdm: Whether to disable the tqdm progress bar. Defaults to True.
16591683
@@ -1663,8 +1687,8 @@ def export_weights(self,
16631687 self ._target_device = target_device
16641688 self ._only_master_rank = only_master_rank
16651689 self ._peft_format = peft_format
1690+ self ._adapter_name = adapter_name
16661691 self ._disable_tqdm = disable_tqdm
1667- self ._adapter_name = 'default'
16681692 self ._peft_target_modules = set ()
16691693 self ._peft_modules_to_save = set ()
16701694 hf_prefix = 'base_model.model.' if peft_format else ''
@@ -1674,13 +1698,21 @@ def export_weights(self,
16741698 mg_models [i ] = mg_model .model
16751699 self .config = mg_models [0 ].config
16761700 with torch .no_grad ():
1677- yield from self ._convert (mg_models , {}, hf_prefix , False , tqdm_desc = tqdm_desc )
1701+ for k , v in self ._convert (mg_models , {}, hf_prefix , False , tqdm_desc = tqdm_desc ):
1702+ if converter :
1703+ kv = converter (k , v )
1704+ if kv is None :
1705+ continue
1706+ k , v = kv
1707+ yield k , v
16781708
16791709 def save_weights (
16801710 self ,
16811711 mg_models ,
16821712 output_dir : str ,
16831713 peft_format : bool = False ,
1714+ adapter_name : str = 'default' ,
1715+ converter : Optional [Callable ] = None ,
16841716 max_shard_size : str = '5GB' ,
16851717 ) -> None :
16861718 """Save Megatron model checkpoint in safetensors (HuggingFace) format.
@@ -1695,6 +1727,8 @@ def save_weights(
16951727 peft_format: Whether to save in PEFT (LoRA, etc.) format. Defaults to False.
16961728 If True, saves LoRA delta weights. If False, saves the complete model weights
16971729 (e.g., after merge-lora or full-parameter fine-tuning).
1730+ adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
1731+ converter: Used to perform key-value conversion on the newly exported state_dict.
16981732 max_shard_size: Maximum size of a single storage file, default is '5GB'.
16991733 """
17001734 gc_collect ()
@@ -1705,12 +1739,51 @@ def save_weights(
17051739 target_device = 'cpu' ,
17061740 only_master_rank = True ,
17071741 peft_format = peft_format ,
1742+ adapter_name = adapter_name ,
1743+ converter = converter ,
17081744 tqdm_desc = 'Saving: ' ,
17091745 disable_tqdm = False ):
17101746 saver .add_tensor (k , v )
17111747 saver .finalize ()
17121748 dist .barrier () # Ensure all weights are saved completely
17131749
1750+ @contextmanager
1751+ def _patch_hf_initialize_weight (self ):
1752+
1753+ _origin_initialize_weight = PreTrainedModel ._initialize_weights
1754+
1755+ def _initialize_weight (self , * args , ** kwargs ):
1756+ return
1757+
1758+ PreTrainedModel ._initialize_weights = _initialize_weight
1759+ try :
1760+ yield
1761+ finally :
1762+ PreTrainedModel ._initialize_weights = _origin_initialize_weight
1763+
1764+ @contextmanager
1765+ def _patch_device_meta (self , model_cls ):
1766+ __origin_init__ = model_cls .__init__
1767+
1768+ def __init__ (self , * args , ** kwargs ):
1769+ with torch .device ('meta' ):
1770+ __origin_init__ (self , * args , ** kwargs )
1771+
1772+ model_cls .__init__ = __init__
1773+
1774+ try :
1775+ yield
1776+ finally :
1777+ model_cls .__init__ = __origin_init__
1778+
1779+ def _get_meta_model_context (self , ignore_init_model_cls = None ):
1780+ ignore_init_model_cls = ignore_init_model_cls or []
1781+ if not isinstance (ignore_init_model_cls , list ):
1782+ ignore_init_model_cls = [ignore_init_model_cls ]
1783+ context_list = [self ._patch_device_meta (model_cls ) for model_cls in ignore_init_model_cls ]
1784+ context_list .append (self ._patch_hf_initialize_weight ())
1785+ return ContextManagers (context_list )
1786+
17141787
17151788class MultimodalGPTBridge (GPTBridge ):
17161789 hf_layers_prefix = 'model.language_model.layers'
0 commit comments